Files
archived-rttys/device.go
Jianhui Zhao 21db9f8528 use SPDX-License-Identifier
Signed-off-by: Jianhui Zhao <zhaojh329@gmail.com>
2025-08-08 19:56:51 +08:00

622 lines
12 KiB
Go

/* SPDX-License-Identifier: MIT */
/*
* Author: Jianhui Zhao <zhaojh329@gmail.com>
*/
package main
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/binary"
"fmt"
"io"
"net"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/zhaojh329/rttys/v5/utils"
"github.com/gorilla/websocket"
jsoniter "github.com/json-iterator/go"
"github.com/rs/zerolog/log"
"github.com/valyala/bytebufferpool"
)
type DeviceInfo struct {
Group string `json:"group"`
ID string `json:"id"`
Connected uint32 `json:"connected"`
Uptime uint32 `json:"uptime"`
Desc string `json:"description"`
Proto uint8 `json:"proto"`
IPaddr string `json:"ipaddr"`
}
type Device struct {
group string
id string
proto uint8
desc string
timestamp int64
uptime uint32
token string
heartbeat time.Duration
users sync.Map
pending sync.Map
commands sync.Map
https sync.Map
conn net.Conn
br *bufio.Reader
readBuf []byte
close sync.Once
ctx context.Context
cancel context.CancelFunc
}
const (
msgTypeRegister = byte(iota)
msgTypeLogin
msgTypeLogout
msgTypeTermData
msgTypeWinsize
msgTypeCmd
msgTypeHeartbeat
msgTypeFile
msgTypeHttp
msgTypeAck
)
const (
msgTypeFileSend = byte(iota)
msgTypeFileRecv
msgTypeFileInfo
msgTypeFileData
msgTypeFileAck
msgTypeFileAbort
)
const (
msgRegAttrHeartbeat = iota
msgRegAttrDevid
msgRegAttrDescription
msgRegAttrToken
msgRegAttrGroup
)
const (
msgHeartbeatAttrUptime = iota
)
const (
devRegErrUnsupportedProto = iota + 1
devRegErrInvalidToken
devRegErrHookFailed
devRegErrIdConflicting
)
const (
RttyProtoRequired uint8 = 3
WaitRegistTimeout = 5 * time.Second
DefaultHeartbeat = 5 * time.Second
TermLoginTimeout = 5 * time.Second
CommandTimeout = 30
)
var DevRegErrMsg = map[byte]string{
0: "Success",
devRegErrUnsupportedProto: "Unsupported protocol",
devRegErrInvalidToken: "Invalid token",
devRegErrHookFailed: "Hook failed",
devRegErrIdConflicting: "ID conflict",
}
var DeviceMsgHandlers = map[byte]func(*Device, []byte) error{
msgTypeHeartbeat: handleHeartbeatMsg,
msgTypeLogin: handleLoginMsg,
msgTypeLogout: handleLogoutMsg,
msgTypeTermData: handleTermDataMsg,
msgTypeFile: handleFileMsg,
msgTypeCmd: handleCmdMsg,
msgTypeHttp: handleHttpMsg,
}
func (srv *RttyServer) ListenDevices() {
cfg := &srv.cfg
ln, err := net.Listen("tcp", cfg.AddrDev)
if err != nil {
log.Fatal().Msg(err.Error())
}
defer ln.Close()
if cfg.SslCert != "" && cfg.SslKey != "" {
cert, err := tls.LoadX509KeyPair(cfg.SslCert, cfg.SslKey)
if err != nil {
log.Fatal().Msg(err.Error())
}
config := &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12,
}
if cfg.CaCert != "" {
caCert, err := os.ReadFile(cfg.CaCert)
if err != nil {
log.Fatal().Msg(err.Error())
}
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
config.ClientCAs = caCertPool
config.ClientAuth = tls.RequireAndVerifyClientCert
}
ln = tls.NewListener(ln, config)
log.Info().Msgf("Listen devices on: %s SSL on", ln.Addr().(*net.TCPAddr))
} else {
log.Info().Msgf("Listen devices on: %s SSL off", ln.Addr().(*net.TCPAddr))
}
for {
conn, err := ln.Accept()
if err != nil {
log.Error().Msg(err.Error())
continue
}
go handleDeviceConnection(srv, conn)
}
}
func handleDeviceConnection(srv *RttyServer, conn net.Conn) {
defer logPanic()
dev := &Device{
conn: conn,
heartbeat: DefaultHeartbeat,
timestamp: time.Now().Unix(),
br: bufio.NewReader(conn),
}
defer dev.Close(srv)
dev.ctx, dev.cancel = context.WithCancel(context.Background())
log.Debug().Msgf("new device '%s' connected", conn.RemoteAddr())
conn.SetReadDeadline(time.Now().Add(WaitRegistTimeout))
typ, data, err := dev.ReadMsg()
if err != nil {
log.Error().Msgf("read register msg fail: %v", err)
return
}
if typ != msgTypeRegister {
log.Error().Msg("register msg expected first")
return
}
err = dev.ParseRegister(data)
if err != nil {
log.Error().Err(err).Msg("invalid device info")
return
}
code := dev.Register(srv)
err = dev.WriteMsg(msgTypeRegister, "", append([]byte{code}, DevRegErrMsg[code]...))
if err != nil {
log.Error().Err(err).Msgf("send register to device '%s' fail", dev.id)
return
}
if code != 0 {
return
}
log.Info().Msgf("device '%s' registered, group '%s' proto %d, heartbeat %v",
dev.id, dev.group, dev.proto, dev.heartbeat)
for {
conn.SetReadDeadline(time.Now().Add(dev.heartbeat * 3 / 2))
typ, data, err = dev.ReadMsg()
if err != nil {
if err != io.EOF {
log.Error().Msgf("read msg from device '%s' fail: %v", dev.id, err)
}
return
}
log.Debug().Msgf("device msg %s from device %s", msgTypeName(typ), dev.id)
handler, ok := DeviceMsgHandlers[typ]
if !ok {
log.Error().Msgf("unexpected message '%s' from device '%s'", msgTypeName(typ), dev.id)
return
}
err = handler(dev, data)
if err != nil {
log.Error().Msg(err.Error())
return
}
}
}
func msgTypeName(typ byte) string {
switch typ {
case msgTypeRegister:
return "register"
case msgTypeLogin:
return "login"
case msgTypeLogout:
return "logout"
case msgTypeTermData:
return "termdata"
case msgTypeWinsize:
return "winsize"
case msgTypeCmd:
return "cmd"
case msgTypeHeartbeat:
return "heartbeat"
case msgTypeFile:
return "file"
case msgTypeHttp:
return "http"
case msgTypeAck:
return "ack"
default:
return fmt.Sprintf("unknown(%d)", typ)
}
}
func (dev *Device) ReadMsg() (byte, []byte, error) {
head := make([]byte, 3)
br := dev.br
_, err := io.ReadFull(br, head)
if err != nil {
return 0, nil, err
}
typ := head[0]
msgLen := binary.BigEndian.Uint16(head[1:])
if cap(dev.readBuf) < int(msgLen) {
dev.readBuf = make([]byte, msgLen)
} else {
dev.readBuf = dev.readBuf[:msgLen]
}
_, err = io.ReadFull(br, dev.readBuf)
if err != nil {
return 0, nil, err
}
return typ, dev.readBuf, nil
}
func (dev *Device) WriteMsg(typ byte, sid string, data []byte) error {
bb := bytebufferpool.Get()
defer bytebufferpool.Put(bb)
b := []byte{typ, 0, 0}
binary.BigEndian.PutUint16(b[1:], uint16(len(sid)+len(data)))
bb.Write(b)
bb.WriteString(sid)
bb.Write(data)
_, err := bb.WriteTo(dev.conn)
return err
}
func (dev *Device) WriteFileMsg(typ byte, sid string, fileType byte, data []byte) error {
bb := bytebufferpool.Get()
defer bytebufferpool.Put(bb)
bb.WriteByte(fileType)
bb.Write(data)
return dev.WriteMsg(typ, sid, bb.Bytes())
}
func (dev *Device) Close(srv *RttyServer) {
dev.close.Do(func() {
log.Error().Msgf("device '%s' disconnected", dev.id)
srv.DelDevice(dev)
dev.cancel()
dev.conn.Close()
})
}
func (dev *Device) ParseRegister(b []byte) error {
if len(b) < 1 {
return fmt.Errorf("too short")
}
dev.proto = b[0]
if dev.proto > 4 {
attrs := utils.ParseTLV(b[1:])
if attrs == nil {
return fmt.Errorf("invalid tlv")
}
for typ, val := range attrs {
switch typ {
case msgRegAttrHeartbeat:
dev.heartbeat = time.Duration(val[0]) * time.Second
case msgRegAttrDevid:
dev.id = string(val)
case msgRegAttrDescription:
dev.desc = string(val)
case msgRegAttrToken:
dev.token = string(val)
case msgRegAttrGroup:
dev.group = string(val)
}
}
} else {
b = b[1:]
fields := bytes.Split(b, []byte{0})
if len(fields) < 3 {
return fmt.Errorf("invalid format")
}
dev.id = string(fields[0])
dev.desc = string(fields[1])
dev.token = string(fields[2])
}
if dev.id == "" {
return fmt.Errorf("not found device id")
}
if len(dev.id) > 32 {
return fmt.Errorf("device id too long")
}
if len(dev.desc) > 126 {
return fmt.Errorf("device desc too long")
}
if len(dev.group) > 16 {
return fmt.Errorf("device group too long")
}
return nil
}
func (dev *Device) Register(srv *RttyServer) byte {
cfg := &srv.cfg
if dev.proto < RttyProtoRequired {
log.Error().Msgf("minimum proto required %d, found %d for device '%s'", RttyProtoRequired, dev.proto, dev.id)
return devRegErrHookFailed
}
if cfg.Token != "" && dev.token != cfg.Token {
log.Error().Msgf("invalid token for device '%s'", dev.id)
return devRegErrInvalidToken
}
devHookUrl := cfg.DevHookUrl
if devHookUrl != "" {
cli := &http.Client{
Timeout: 3 * time.Second,
}
data := fmt.Sprintf(`{"group":"%s", "devid":"%s", "token":"%s"}`, dev.group, dev.id, dev.token)
resp, err := cli.Post(devHookUrl, "application/json", strings.NewReader(data))
if err != nil {
log.Error().Msgf("call device hook url fail for device %s: %v", dev.id, err)
return devRegErrHookFailed
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
log.Error().Msgf("call device hook url for device '%s', StatusCode: %d", dev.id, resp.StatusCode)
return devRegErrHookFailed
}
}
if !srv.AddDevice(dev) {
return devRegErrIdConflicting
}
return 0
}
func handleHeartbeatMsg(dev *Device, data []byte) error {
if !parseHeartbeat(dev, data) {
return fmt.Errorf("invalid heartbeat msg from device '%s'", dev.id)
}
return dev.WriteMsg(msgTypeHeartbeat, "", nil)
}
func parseHeartbeat(dev *Device, data []byte) bool {
if dev.proto > 4 {
attrs := utils.ParseTLV(data)
if attrs == nil {
return false
}
for typ, val := range attrs {
switch typ {
case msgHeartbeatAttrUptime:
dev.uptime = binary.BigEndian.Uint32(val)
}
}
} else {
if len(data) < 4 {
return false
}
dev.uptime = binary.BigEndian.Uint32(data[:4])
}
return true
}
func handleLogoutMsg(dev *Device, data []byte) error {
if len(data) < 32 {
return fmt.Errorf("invalid logout msg from device '%s'", dev.id)
}
sid := string(data[:32])
if val, loaded := dev.users.LoadAndDelete(sid); loaded {
user := val.(*User)
user.Close()
}
return nil
}
func handleLoginMsg(dev *Device, data []byte) error {
if len(data) < 33 {
return fmt.Errorf("invalid login msg from device '%s'", dev.id)
}
sid := string(data[:32])
code := data[32]
if val, loaded := dev.pending.LoadAndDelete(sid); loaded {
user := val.(*User)
ok := code == 0
errCode := 0
if ok {
log.Debug().Msgf("login session '%s' for device '%s' success", sid, dev.id)
dev.users.Store(sid, user)
} else {
errCode = LoginErrorBusy
log.Error().Msgf("login session '%s' for device '%s' fail, due to device busy", sid, dev.id)
}
if errCode == 0 {
user.WriteMsg(websocket.TextMessage, []byte(fmt.Appendf(nil, `{"type":"login"}`)))
} else {
user.SendCloseMsg(LoginErrorBusy, "device busy")
}
user.pending <- ok
}
return nil
}
func handleTermDataMsg(dev *Device, data []byte) error {
if len(data) < 32 {
return fmt.Errorf("invalid term data msg from device '%s'", dev.id)
}
sid := string(data[:32])
if val, ok := dev.users.Load(sid); ok {
user := val.(*User)
data[31] = 0
user.WriteMsg(websocket.BinaryMessage, data[31:])
}
return nil
}
func handleFileMsg(dev *Device, data []byte) error {
if len(data) < 33 {
return fmt.Errorf("invalid file msg from device '%s'", dev.id)
}
sid := string(data[:32])
typ := data[32]
if val, ok := dev.users.Load(sid); ok {
user := val.(*User)
switch typ {
case msgTypeFileSend:
user.WriteMsg(websocket.TextMessage,
fmt.Appendf(nil, `{"type":"sendfile", "name": "%s"}`, string(data[33:])))
case msgTypeFileRecv:
user.WriteMsg(websocket.TextMessage, []byte(`{"type":"recvfile"}`))
case msgTypeFileData:
data[32] = 1
user.WriteMsg(websocket.BinaryMessage, data[32:])
case msgTypeFileAck:
user.WriteMsg(websocket.TextMessage, []byte(`{"type":"fileAck"}`))
case msgTypeFileAbort:
user.WriteMsg(websocket.BinaryMessage, []byte{1})
}
}
return nil
}
func handleHttpMsg(dev *Device, data []byte) error {
if len(data) < 18 {
return fmt.Errorf("invalid http msg from device '%s'", dev.id)
}
addr := data[:18]
data = data[18:]
if c, ok := dev.https.Load(string(addr)); ok {
c := c.(net.Conn)
if len(data) == 0 {
c.Close()
} else {
c.Write(data)
}
}
return nil
}
func handleCmdMsg(dev *Device, data []byte) error {
info := &CommandRespInfo{}
err := jsoniter.Unmarshal(data, info)
if err != nil {
return fmt.Errorf("parse command resp info error: %v", err)
}
var attrs map[string]any
err = jsoniter.Unmarshal(info.Attrs, &attrs)
if err != nil {
return fmt.Errorf("parse command resp attrs error: %v", err)
}
attrs["devid"] = dev.id
if val, ok := dev.commands.Load(info.Token); ok {
req := val.(*CommandReq)
req.acked = true
req.c.JSON(http.StatusOK, attrs)
req.cancel()
}
return nil
}