diff --git a/command.go b/command.go index fb5e75a..33d5ff7 100644 --- a/command.go +++ b/command.go @@ -12,6 +12,7 @@ import ( "strconv" "time" + "github.com/zhaojh329/rtty-go/proto" "github.com/zhaojh329/rttys/v5/utils" "github.com/gin-gonic/gin" @@ -74,7 +75,7 @@ func (dev *Device) handleCmdReq(c *gin.Context, info *CommandReqInfo) { log.Debug().Msgf("send cmd request for device '%s', token '%s'", dev.id, token) - err := dev.WriteMsg(msgTypeCmd, "", msg.Bytes()) + err := dev.WriteMsg(proto.MsgTypeCmd, msg) if err != nil { cmdErrResp(c, rttyCmdErrOffline) return diff --git a/device.go b/device.go index 015a826..93c17b9 100644 --- a/device.go +++ b/device.go @@ -6,7 +6,6 @@ package main import ( - "bufio" "bytes" "context" "crypto/tls" @@ -21,12 +20,11 @@ import ( "sync" "time" - "github.com/zhaojh329/rttys/v5/utils" - "github.com/gorilla/websocket" jsoniter "github.com/json-iterator/go" "github.com/rs/zerolog/log" - "github.com/valyala/bytebufferpool" + "github.com/zhaojh329/rtty-go/proto" + "github.com/zhaojh329/rttys/v5/utils" ) type DeviceInfo struct { @@ -54,48 +52,14 @@ type Device struct { commands sync.Map https sync.Map - conn net.Conn - br *bufio.Reader - readBuf []byte - close sync.Once - ctx context.Context - cancel context.CancelFunc + conn net.Conn + close sync.Once + ctx context.Context + cancel context.CancelFunc + + msg *proto.MsgReaderWriter } -const ( - msgTypeRegister = byte(iota) - msgTypeLogin - msgTypeLogout - msgTypeTermData - msgTypeWinsize - msgTypeCmd - msgTypeHeartbeat - msgTypeFile - msgTypeHttp - msgTypeAck -) - -const ( - msgTypeFileSend = byte(iota) - msgTypeFileRecv - msgTypeFileInfo - msgTypeFileData - msgTypeFileAck - msgTypeFileAbort -) - -const ( - msgRegAttrHeartbeat = iota - msgRegAttrDevid - msgRegAttrDescription - msgRegAttrToken - msgRegAttrGroup -) - -const ( - msgHeartbeatAttrUptime = iota -) - const ( devRegErrUnsupportedProto = iota + 1 devRegErrInvalidToken @@ -120,13 +84,13 @@ var DevRegErrMsg = map[byte]string{ } var DeviceMsgHandlers = map[byte]func(*Device, []byte) error{ - msgTypeHeartbeat: handleHeartbeatMsg, - msgTypeLogin: handleLoginMsg, - msgTypeLogout: handleLogoutMsg, - msgTypeTermData: handleTermDataMsg, - msgTypeFile: handleFileMsg, - msgTypeCmd: handleCmdMsg, - msgTypeHttp: handleHttpMsg, + proto.MsgTypeHeartbeat: handleHeartbeatMsg, + proto.MsgTypeLogin: handleLoginMsg, + proto.MsgTypeLogout: handleLogoutMsg, + proto.MsgTypeTermData: handleTermDataMsg, + proto.MsgTypeFile: handleFileMsg, + proto.MsgTypeCmd: handleCmdMsg, + proto.MsgTypeHttp: handleHttpMsg, } func (srv *RttyServer) ListenDevices() { @@ -185,7 +149,8 @@ func handleDeviceConnection(srv *RttyServer, conn net.Conn) { conn: conn, heartbeat: DefaultHeartbeat, timestamp: time.Now().Unix(), - br: bufio.NewReader(conn), + + msg: proto.NewMsgReaderWriter(proto.RoleRttys, conn), } defer dev.Close(srv) @@ -201,7 +166,7 @@ func handleDeviceConnection(srv *RttyServer, conn net.Conn) { return } - if typ != msgTypeRegister { + if typ != proto.MsgTypeRegister { log.Error().Msg("register msg expected first") return } @@ -214,7 +179,7 @@ func handleDeviceConnection(srv *RttyServer, conn net.Conn) { code := dev.Register(srv) - err = dev.WriteMsg(msgTypeRegister, "", append([]byte{code}, DevRegErrMsg[code]...)) + err = dev.WriteMsg(proto.MsgTypeRegister, code, DevRegErrMsg[code]) if err != nil { log.Error().Err(err).Msgf("send register to device '%s' fail", dev.id) return @@ -238,11 +203,11 @@ func handleDeviceConnection(srv *RttyServer, conn net.Conn) { return } - log.Debug().Msgf("device msg %s from device %s", msgTypeName(typ), dev.id) + log.Debug().Msgf("device msg %s from device %s", proto.MsgTypeName(typ), dev.id) handler, ok := DeviceMsgHandlers[typ] if !ok { - log.Error().Msgf("unexpected message '%s' from device '%s'", msgTypeName(typ), dev.id) + log.Error().Msgf("unexpected message '%s' from device '%s'", proto.MsgTypeName(typ), dev.id) return } @@ -254,85 +219,12 @@ func handleDeviceConnection(srv *RttyServer, conn net.Conn) { } } -func msgTypeName(typ byte) string { - switch typ { - case msgTypeRegister: - return "register" - case msgTypeLogin: - return "login" - case msgTypeLogout: - return "logout" - case msgTypeTermData: - return "termdata" - case msgTypeWinsize: - return "winsize" - case msgTypeCmd: - return "cmd" - case msgTypeHeartbeat: - return "heartbeat" - case msgTypeFile: - return "file" - case msgTypeHttp: - return "http" - case msgTypeAck: - return "ack" - default: - return fmt.Sprintf("unknown(%d)", typ) - } -} - func (dev *Device) ReadMsg() (byte, []byte, error) { - head := make([]byte, 3) - br := dev.br - - _, err := io.ReadFull(br, head) - if err != nil { - return 0, nil, err - } - - typ := head[0] - - msgLen := binary.BigEndian.Uint16(head[1:]) - - if cap(dev.readBuf) < int(msgLen) { - dev.readBuf = make([]byte, msgLen) - } else { - dev.readBuf = dev.readBuf[:msgLen] - } - - _, err = io.ReadFull(br, dev.readBuf) - if err != nil { - return 0, nil, err - } - - return typ, dev.readBuf, nil + return dev.msg.Read() } -func (dev *Device) WriteMsg(typ byte, sid string, data []byte) error { - bb := bytebufferpool.Get() - defer bytebufferpool.Put(bb) - - b := []byte{typ, 0, 0} - - binary.BigEndian.PutUint16(b[1:], uint16(len(sid)+len(data))) - - bb.Write(b) - bb.WriteString(sid) - bb.Write(data) - - _, err := bb.WriteTo(dev.conn) - - return err -} - -func (dev *Device) WriteFileMsg(typ byte, sid string, fileType byte, data []byte) error { - bb := bytebufferpool.Get() - defer bytebufferpool.Put(bb) - - bb.WriteByte(fileType) - bb.Write(data) - - return dev.WriteMsg(typ, sid, bb.Bytes()) +func (dev *Device) WriteMsg(typ byte, data ...any) error { + return dev.msg.Write(typ, data...) } func (dev *Device) Close(srv *RttyServer) { @@ -345,10 +237,6 @@ func (dev *Device) Close(srv *RttyServer) { } func (dev *Device) ParseRegister(b []byte) error { - if len(b) < 1 { - return fmt.Errorf("too short") - } - dev.proto = b[0] if dev.proto > 4 { @@ -359,15 +247,15 @@ func (dev *Device) ParseRegister(b []byte) error { for typ, val := range attrs { switch typ { - case msgRegAttrHeartbeat: + case proto.MsgRegAttrHeartbeat: dev.heartbeat = time.Duration(val[0]) * time.Second - case msgRegAttrDevid: + case proto.MsgRegAttrDevid: dev.id = string(val) - case msgRegAttrDescription: + case proto.MsgRegAttrDescription: dev.desc = string(val) - case msgRegAttrToken: + case proto.MsgRegAttrToken: dev.token = string(val) - case msgRegAttrGroup: + case proto.MsgRegAttrGroup: dev.group = string(val) } } @@ -389,15 +277,15 @@ func (dev *Device) ParseRegister(b []byte) error { return fmt.Errorf("not found device id") } - if len(dev.id) > 32 { + if len(dev.id) > proto.MaximumDevIDLen { return fmt.Errorf("device id too long") } - if len(dev.desc) > 126 { + if len(dev.desc) > proto.MaximumDescLen { return fmt.Errorf("device desc too long") } - if len(dev.group) > 16 { + if len(dev.group) > proto.MaximumGroupLen { return fmt.Errorf("device group too long") } @@ -449,7 +337,7 @@ func handleHeartbeatMsg(dev *Device, data []byte) error { if !parseHeartbeat(dev, data) { return fmt.Errorf("invalid heartbeat msg from device '%s'", dev.id) } - return dev.WriteMsg(msgTypeHeartbeat, "", nil) + return dev.WriteMsg(proto.MsgTypeHeartbeat) } func parseHeartbeat(dev *Device, data []byte) bool { @@ -461,7 +349,7 @@ func parseHeartbeat(dev *Device, data []byte) bool { for typ, val := range attrs { switch typ { - case msgHeartbeatAttrUptime: + case proto.MsgHeartbeatAttrUptime: dev.uptime = binary.BigEndian.Uint32(val) } } @@ -476,10 +364,6 @@ func parseHeartbeat(dev *Device, data []byte) bool { } func handleLogoutMsg(dev *Device, data []byte) error { - if len(data) < 32 { - return fmt.Errorf("invalid logout msg from device '%s'", dev.id) - } - sid := string(data[:32]) if val, loaded := dev.users.LoadAndDelete(sid); loaded { @@ -491,10 +375,6 @@ func handleLogoutMsg(dev *Device, data []byte) error { } func handleLoginMsg(dev *Device, data []byte) error { - if len(data) < 33 { - return fmt.Errorf("invalid login msg from device '%s'", dev.id) - } - sid := string(data[:32]) code := data[32] @@ -525,10 +405,6 @@ func handleLoginMsg(dev *Device, data []byte) error { } func handleTermDataMsg(dev *Device, data []byte) error { - if len(data) < 32 { - return fmt.Errorf("invalid term data msg from device '%s'", dev.id) - } - sid := string(data[:32]) if val, ok := dev.users.Load(sid); ok { @@ -541,10 +417,6 @@ func handleTermDataMsg(dev *Device, data []byte) error { } func handleFileMsg(dev *Device, data []byte) error { - if len(data) < 33 { - return fmt.Errorf("invalid file msg from device '%s'", dev.id) - } - sid := string(data[:32]) typ := data[32] @@ -552,21 +424,21 @@ func handleFileMsg(dev *Device, data []byte) error { user := val.(*User) switch typ { - case msgTypeFileSend: + case proto.MsgTypeFileSend: user.WriteMsg(websocket.TextMessage, fmt.Appendf(nil, `{"type":"sendfile", "name": "%s"}`, string(data[33:]))) - case msgTypeFileRecv: + case proto.MsgTypeFileRecv: user.WriteMsg(websocket.TextMessage, []byte(`{"type":"recvfile"}`)) - case msgTypeFileData: + case proto.MsgTypeFileData: data[32] = 1 user.WriteMsg(websocket.BinaryMessage, data[32:]) - case msgTypeFileAck: + case proto.MsgTypeFileAck: user.WriteMsg(websocket.TextMessage, []byte(`{"type":"fileAck"}`)) - case msgTypeFileAbort: + case proto.MsgTypeFileAbort: user.WriteMsg(websocket.BinaryMessage, []byte{1}) } } @@ -575,10 +447,6 @@ func handleFileMsg(dev *Device, data []byte) error { } func handleHttpMsg(dev *Device, data []byte) error { - if len(data) < 18 { - return fmt.Errorf("invalid http msg from device '%s'", dev.id) - } - addr := data[:18] data = data[18:] diff --git a/go.mod b/go.mod index e1154fc..2c2d533 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,8 @@ module github.com/zhaojh329/rttys/v5 -go 1.24 +go 1.24.4 + +toolchain go1.24.5 require ( github.com/dwdcth/consoleEx v0.0.0-20180521133551-f56f6eb78b76 @@ -15,6 +17,7 @@ require ( github.com/rs/zerolog v1.34.0 github.com/urfave/cli/v3 v3.3.8 github.com/valyala/bytebufferpool v1.0.0 + github.com/zhaojh329/rtty-go v1.1.0 ) require ( @@ -39,7 +42,7 @@ require ( golang.org/x/arch v0.18.0 // indirect golang.org/x/crypto v0.39.0 // indirect golang.org/x/net v0.41.0 // indirect - golang.org/x/sys v0.33.0 // indirect + golang.org/x/sys v0.34.0 // indirect golang.org/x/text v0.26.0 // indirect google.golang.org/protobuf v1.36.6 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 2b7194e..775ee51 100644 --- a/go.sum +++ b/go.sum @@ -95,6 +95,8 @@ github.com/urfave/cli/v3 v3.3.8 h1:BzolUExliMdet9NlJ/u4m5vHSotJ3PzEqSAZ1oPMa/E= github.com/urfave/cli/v3 v3.3.8/go.mod h1:FJSKtM/9AiiTOJL4fJ6TbMUkxBXn7GO9guZqoZtpYpo= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/zhaojh329/rtty-go v1.1.0 h1:aZHI2prDFKSkn+8i5YxmS+ixpYGMgRlzqUZcxc7mlLQ= +github.com/zhaojh329/rtty-go v1.1.0/go.mod h1:U5f1woopWzirdSKpc6kI7NTFtAP8k8XU4mUQ0olm+i8= golang.org/x/arch v0.18.0 h1:WN9poc33zL4AzGxqf8VtpKUnGvMi8O9lhNyBMF/85qc= golang.org/x/arch v0.18.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk= golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= @@ -104,8 +106,8 @@ golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= -golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= +golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= diff --git a/http.go b/http.go index 47ae259..0ed3634 100644 --- a/http.go +++ b/http.go @@ -19,6 +19,7 @@ import ( "sync/atomic" "time" + "github.com/zhaojh329/rtty-go/proto" "github.com/zhaojh329/rttys/v5/utils" "github.com/gin-gonic/gin" @@ -305,7 +306,7 @@ func sendHttpReq(dev *Device, https bool, srcAddr []byte, destAddr []byte, data bb.Write(destAddr) bb.Write(data) - dev.WriteMsg(msgTypeHttp, "", bb.Bytes()) + dev.WriteMsg(proto.MsgTypeHttp, bb) } func genDestAddr(addr string) []byte { diff --git a/user.go b/user.go index b28ee88..0603603 100644 --- a/user.go +++ b/user.go @@ -7,12 +7,12 @@ package main import ( "context" - "encoding/binary" "net/http" "sync" "sync/atomic" "time" + "github.com/zhaojh329/rtty-go/proto" "github.com/zhaojh329/rttys/v5/utils" "github.com/gin-gonic/gin" @@ -84,7 +84,7 @@ func handleUserConnection(srv *RttyServer, c *gin.Context) { defer user.Close() - if err := dev.WriteMsg(msgTypeLogin, sid, nil); err != nil { + if err := dev.WriteMsg(proto.MsgTypeLogin, sid); err != nil { log.Error().Msgf("send login msg to device %s fail: %v", dev.id, err) return } @@ -117,7 +117,7 @@ func (user *User) Close() { user.closed.Store(true) if _, loaded := dev.users.LoadAndDelete(sid); loaded { - dev.WriteMsg(msgTypeLogout, sid, nil) + dev.WriteMsg(proto.MsgTypeLogout, sid) } dev.pending.Delete(sid) @@ -172,9 +172,9 @@ func (user *User) handleMsg() { return } - typ := msgTypeTermData + typ := proto.MsgTypeTermData if data[0] == 1 { - typ = msgTypeFile + typ = proto.MsgTypeFile } err = dev.WriteMsg(typ, sid, data[1:]) @@ -189,30 +189,19 @@ func (user *User) handleMsg() { switch msg.Type { case "winsize": - b := make([]byte, 4) - - binary.BigEndian.PutUint16(b, msg.Cols) - binary.BigEndian.PutUint16(b[2:], msg.Rows) - - err = dev.WriteMsg(msgTypeWinsize, sid, b) + err = dev.WriteMsg(proto.MsgTypeWinsize, sid, msg.Cols, msg.Rows) case "ack": - b := make([]byte, 2) - binary.BigEndian.PutUint16(b, msg.Ack) - err = dev.WriteMsg(msgTypeAck, sid, b) + err = dev.WriteMsg(proto.MsgTypeAck, sid, msg.Ack) case "fileInfo": - b := make([]byte, 4+len(msg.Name)) - binary.BigEndian.PutUint32(b, msg.Size) - copy(b[4:], []byte(msg.Name)) - - err = dev.WriteFileMsg(msgTypeFile, sid, msgTypeFileInfo, b) + err = dev.WriteMsg(proto.MsgTypeFile, sid, proto.MsgTypeFileInfo, msg.Size, msg.Name) case "fileCanceled": - err = dev.WriteFileMsg(msgTypeFile, sid, msgTypeFileAbort, nil) + err = dev.WriteMsg(proto.MsgTypeFile, sid, proto.MsgTypeFileAbort) case "fileAck": - err = dev.WriteFileMsg(msgTypeFile, sid, msgTypeFileAck, nil) + err = dev.WriteMsg(proto.MsgTypeFile, sid, proto.MsgTypeFileAck) } }