From a0bfc21d117ccec8232e84833e2e304c772d3599 Mon Sep 17 00:00:00 2001 From: Jianhui Zhao Date: Fri, 4 Jul 2025 11:54:19 +0800 Subject: [PATCH] Refactor the message distribution process To solve a major problem: when there are many devices connected, the broker will blocked in processing messages. By the way, I've rewritten and reorganized the entire code architecture, improved performance. Some new features has been added: * support device grouping. * support show device's IP address. Change-Id: I250e18091be7fd42028c82767b6edef50b3f6d8f Signed-off-by: Jianhui Zhao --- .github/workflows/release.yml | 25 ++ Dockerfile | 3 +- api.go | 221 ++++++---- broker.go | 366 ---------------- build.sh | 3 +- client/client.go | 18 - command.go | 177 ++++---- config/config.go => config.go | 81 ++-- device.go | 788 ++++++++++++++++++++-------------- go.mod | 1 + go.sum | 2 + http.go | 479 +++++++++++---------- log/log.go | 2 + main.go | 108 +++-- rttys.conf | 2 +- rttys_stress_test.go | 200 +++++++++ server.go | 113 +++++ tlv.go | 38 -- ui/src/components/RttyCmd.vue | 2 +- ui/src/components/RttyWeb.vue | 8 +- ui/src/i18n/en.json | 3 +- ui/src/i18n/zh-CN.json | 5 +- ui/src/views/Home.vue | 98 +++-- ui/src/views/Rtty.vue | 55 ++- ui/vite.config.js | 6 +- user.go | 297 ++++++++----- utils/utils.go | 59 ++- version/version.go | 23 - 28 files changed, 1762 insertions(+), 1421 deletions(-) create mode 100644 .github/workflows/release.yml delete mode 100644 broker.go delete mode 100644 client/client.go rename config/config.go => config.go (64%) create mode 100644 rttys_stress_test.go create mode 100644 server.go delete mode 100644 tlv.go delete mode 100644 version/version.go diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..2baa2e9 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,25 @@ +name: release + +on: + push: + tags: + - 'v*' + +jobs: + docker: + runs-on: ubuntu-24.04 + steps: + - id: get-version + uses: battila7/get-version-action@v2 + + - uses: docker/login-action@v3 + with: + username: zhaojh329 + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - uses: docker/build-push-action@v6 + with: + push: true + tags: | + zhaojh329/rttys:${{ steps.get-version.outputs.version-without-v }} + zhaojh329/rttys:latest diff --git a/Dockerfile b/Dockerfile index 443a693..10a9e7b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -8,10 +8,9 @@ WORKDIR /rttys-build COPY . . COPY --from=ui /rttys-ui/dist ui/dist RUN CGO_ENABLED=0 \ - VersionPath="rttys/version" \ GitCommit=$(git log --pretty=format:"%h" -1) \ BuildTime=$(date +%FT%T%z) \ - go build -ldflags="-s -w -X $VersionPath.gitCommit=$GitCommit -X $VersionPath.buildTime=$BuildTime" + go build -ldflags="-s -w -X main.gitCommit=$GitCommit -X main.buildTime=$BuildTime" FROM alpine:latest COPY --from=rttys /rttys-build/rttys /usr/bin/rttys diff --git a/api.go b/api.go index f868e4e..b703d33 100644 --- a/api.go +++ b/api.go @@ -1,3 +1,27 @@ +/* + * MIT License + * + * Copyright (c) 2019 Jianhui Zhao + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + package main import ( @@ -5,12 +29,10 @@ import ( "io/fs" "net" "net/http" - "os" "path" "strings" "time" - "rttys/config" "rttys/utils" "github.com/fanjindong/go-cache" @@ -26,38 +48,18 @@ const httpSessionExpire = 30 * time.Minute //go:embed ui/dist var staticFs embed.FS -func httpLogin(cfg *config.Config, password string) bool { - return cfg.Password == "" || cfg.Password == password -} - -func isLocalRequest(c *gin.Context) bool { - addr, _ := net.ResolveTCPAddr("tcp", c.Request.RemoteAddr) - return addr.IP.IsLoopback() -} - -func httpAuth(cfg *config.Config, c *gin.Context) bool { - if !cfg.LocalAuth && isLocalRequest(c) { - return true - } - - sid, err := c.Cookie("sid") - if err != nil || !httpSessions.Exists(sid) { - return false - } - - httpSessions.Expire(sid, httpSessionExpire) - - return true -} - -func apiStart(br *broker) { - cfg := br.cfg +func (srv *RttyServer) ListenAPI() error { + cfg := &srv.cfg gin.SetMode(gin.ReleaseMode) r := gin.New() - r.Use(gin.Recovery()) + r.Use(func(c *gin.Context) { + c.Next() + log.Debug().Msgf("%s - \"%s %s %s %d\"", c.ClientIP(), + c.Request.Method, c.Request.URL.Path, c.Request.Proto, c.Writer.Status()) + }) if cfg.AllowOrigins { log.Debug().Msg("Allow all origins") @@ -77,38 +79,52 @@ func apiStart(br *broker) { authorized.GET("/connect/:devid", func(c *gin.Context) { if c.GetHeader("Upgrade") != "websocket" { + group := c.Query("group") devid := c.Param("devid") - if _, ok := br.getDevice(devid); !ok { + if dev := srv.GetDevice(group, devid); dev == nil { c.Redirect(http.StatusFound, "/error/offline") return } - c.Redirect(http.StatusFound, "/rtty/"+devid) + c.Redirect(http.StatusFound, "/rtty/"+devid+"?group="+group) return } - serveUser(br, c) + handleUserConnection(srv, c) + }) + + authorized.GET("/groups", func(c *gin.Context) { + groups := []string{""} + + srv.groups.Range(func(key, value any) bool { + if key != "" { + groups = append(groups, key.(string)) + } + return true + }) + + c.JSON(http.StatusOK, groups) }) authorized.GET("/devs", func(c *gin.Context) { - type DeviceInfo struct { - ID string `json:"id"` - Connected uint32 `json:"connected"` - Uptime uint32 `json:"uptime"` - Description string `json:"description"` - Proto uint8 `json:"proto"` + devs := make([]*DeviceInfo, 0) + g := srv.GetGroup(c.Query("group"), false) + + if g == nil { + c.JSON(http.StatusOK, devs) + return } - devs := make([]DeviceInfo, 0) + g.devices.Range(func(key, value any) bool { + dev := value.(*Device) - br.devices.Range(func(key, value any) bool { - dev := value.(*device) - - devs = append(devs, DeviceInfo{ - ID: dev.id, - Description: dev.desc, - Connected: uint32(time.Now().Unix() - dev.timestamp), - Uptime: dev.uptime, - Proto: dev.proto, + devs = append(devs, &DeviceInfo{ + Group: dev.group, + ID: dev.id, + Desc: dev.desc, + Connected: uint32(time.Now().Unix() - dev.timestamp), + Uptime: dev.uptime, + Proto: dev.proto, + IPaddr: dev.conn.RemoteAddr().(*net.TCPAddr).IP.String(), }) return true @@ -118,24 +134,46 @@ func apiStart(br *broker) { }) authorized.GET("/dev/:devid", func(c *gin.Context) { - if dev, ok := br.getDevice(c.Param("devid")); ok { - c.JSON(http.StatusOK, gin.H{ - "description": dev.desc, - "connected": uint32(time.Now().Unix() - dev.timestamp), - "uptime": dev.uptime, - "proto": dev.proto, - }) + if dev := srv.GetDevice(c.Query("group"), c.Param("devid")); dev != nil { + info := &DeviceInfo{ + ID: dev.id, + Desc: dev.desc, + Connected: uint32(time.Now().Unix() - dev.timestamp), + Uptime: dev.uptime, + Proto: dev.proto, + IPaddr: dev.conn.RemoteAddr().(*net.TCPAddr).IP.String(), + } + c.JSON(http.StatusOK, info) } else { c.Status(http.StatusNotFound) } }) authorized.POST("/cmd/:devid", func(c *gin.Context) { - handleCmdReq(br, c) + cmdInfo := &CommandReqInfo{} + + err := c.BindJSON(&cmdInfo) + if err != nil || cmdInfo.Cmd == "" || cmdInfo.Username == "" { + cmdErrResp(c, rttyCmdErrInvalid) + return + } + + dev := srv.GetDevice(c.Query("group"), c.Param("devid")) + if dev == nil { + cmdErrResp(c, rttyCmdErrOffline) + return + } + + dev.handleCmdReq(c, cmdInfo) }) authorized.Any("/web/:devid/:proto/:addr/*path", func(c *gin.Context) { - httpProxyRedirect(br, c) + httpProxyRedirect(srv, c, "") + }) + + authorized.Any("/web2/:group/:devid/:proto/:addr/*path", func(c *gin.Context) { + group := c.Param("group") + httpProxyRedirect(srv, c, group) }) authorized.GET("/signout", func(c *gin.Context) { @@ -149,32 +187,6 @@ func apiStart(br *broker) { c.Status(http.StatusOK) }) - authorized.GET("/file/:sid", func(c *gin.Context) { - sid := c.Param("sid") - if fp, ok := br.fileProxy.Load(sid); ok { - fp := fp.(*fileProxy) - - if s, ok := br.getSession(sid); ok { - fp.Ack(s.dev, sid) - } - - defer func() { - if err := recover(); err != nil { - if ne, ok := err.(*net.OpError); ok { - if se, ok := ne.Err.(*os.SyscallError); ok { - if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") { - fp.reader.Close() - } - } - } - } - }() - - c.DataFromReader(http.StatusOK, -1, "application/octet-stream", fp.reader, nil) - br.fileProxy.Delete(sid) - } - }) - r.POST("/signin", func(c *gin.Context) { type credentials struct { Password string `json:"password"` @@ -209,7 +221,11 @@ func apiStart(br *broker) { } }) - fs, _ := fs.Sub(staticFs, "ui/dist") + fs, err := fs.Sub(staticFs, "ui/dist") + if err != nil { + return err + } + root := http.FS(fs) fh := http.FileServer(root) @@ -246,8 +262,37 @@ func apiStart(br *broker) { fh.ServeHTTP(c.Writer, c.Request) }) - go func() { - log.Info().Msgf("Listen user on: %s", cfg.AddrUser) - log.Fatal().Err(r.Run(cfg.AddrUser)) - }() + ln, err := net.Listen("tcp", cfg.AddrUser) + if err != nil { + return err + } + defer ln.Close() + + log.Info().Msgf("Listen users on: %s", ln.Addr().(*net.TCPAddr)) + + return r.RunListener(ln) +} + +func httpLogin(cfg *Config, password string) bool { + return cfg.Password == "" || cfg.Password == password +} + +func isLocalRequest(c *gin.Context) bool { + addr, _ := net.ResolveTCPAddr("tcp", c.Request.RemoteAddr) + return addr.IP.IsLoopback() +} + +func httpAuth(cfg *Config, c *gin.Context) bool { + if !cfg.LocalAuth && isLocalRequest(c) { + return true + } + + sid, err := c.Cookie("sid") + if err != nil || !httpSessions.Exists(sid) { + return false + } + + httpSessions.Expire(sid, httpSessionExpire) + + return true } diff --git a/broker.go b/broker.go deleted file mode 100644 index 2d4e5aa..0000000 --- a/broker.go +++ /dev/null @@ -1,366 +0,0 @@ -package main - -import ( - "encoding/binary" - "fmt" - "io" - "net/http" - "strings" - "sync" - "time" - - "rttys/client" - "rttys/config" - "rttys/utils" - - "github.com/gorilla/websocket" - jsoniter "github.com/json-iterator/go" - "github.com/rs/zerolog/log" -) - -type session struct { - dev client.Client - user client.Client - confirmed bool -} - -type broker struct { - cfg *config.Config - devices sync.Map - loginAck chan *loginAckMsg - logout chan string - register chan client.Client - unregister chan client.Client - sessions sync.Map - termMessage chan *termMessage - fileMessage chan *fileMessage - userMessage chan *usrMessage - cmdResp chan []byte - cmdReq chan *commandReq - httpResp chan *httpResp - httpReq chan *httpReq - fileProxy sync.Map -} - -func newBroker(cfg *config.Config) *broker { - return &broker{ - cfg: cfg, - loginAck: make(chan *loginAckMsg, 1000), - logout: make(chan string, 1000), - register: make(chan client.Client, 1000), - unregister: make(chan client.Client, 1000), - termMessage: make(chan *termMessage, 1000), - fileMessage: make(chan *fileMessage, 1000), - userMessage: make(chan *usrMessage, 1000), - cmdResp: make(chan []byte, 1000), - cmdReq: make(chan *commandReq, 1000), - httpResp: make(chan *httpResp, 1000), - httpReq: make(chan *httpReq, 1000), - } -} - -func (br *broker) getDevice(devid string) (*device, bool) { - if dev, ok := br.devices.Load(devid); ok { - return dev.(*device), true - } - return nil, false -} - -func (br *broker) getSession(sid string) (*session, bool) { - if s, ok := br.sessions.Load(sid); ok { - return s.(*session), true - } - return nil, false -} - -func (br *broker) devRegister(dev *device) { - defer func() { - br.register <- dev - }() - - devid := dev.id - cfg := br.cfg - - if dev.proto < rttyProtoRequired { - log.Error().Msgf("%s: unsupported protocol version: %d, need %d", devid, dev.proto, rttyProtoRequired) - dev.err = devRegErrHookFailed - return - } - - if cfg.Token != "" && dev.token != cfg.Token { - log.Error().Msgf("%s: invalid token", devid) - dev.err = devRegErrInvalidToken - return - } - - devHookUrl := br.cfg.DevHookUrl - - if devHookUrl == "" { - return - } - - cli := &http.Client{ - Timeout: 3 * time.Second, - } - - data := fmt.Sprintf(`{"devid":"%s", "token":"%s"}`, dev.id, dev.token) - - resp, err := cli.Post(devHookUrl, "application/json", strings.NewReader(data)) - if err != nil { - log.Error().Msgf("%s: call device hook url fail:"+err.Error(), devid) - dev.err = devRegErrHookFailed - return - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - log.Error().Msgf("%s: call device hook url, StatusCode: %d", devid, resp.StatusCode) - dev.err = devRegErrHookFailed - } -} - -func (br *broker) run() { - defer logPanic() - - for { - select { - case c := <-br.register: - if c.Closed() { - break - } - - devid := c.DeviceID() - - if c.IsDevice() { - dev := c.(*device) - - if dev.err == 0 { - if _, ok := br.getDevice(devid); ok { - log.Error().Msg("Device ID conflicting: " + devid) - dev.err = devRegErrIdConflicting - } - } - - if dev.err == 0 { - dev.registered = true - br.devices.Store(dev.id, dev) - log.Info().Msgf("Device '%s' registered, proto %d, heartbeat %v", dev.id, dev.proto, dev.heartbeat) - } else { - // ensure the last packet was sent - time.AfterFunc(time.Millisecond*100, func() { - dev.Close() - }) - } - - dev.WriteMsg(msgTypeRegister, append([]byte{dev.err}, DevRegErrMsg[dev.err]...)) - } else { - if dev, ok := br.getDevice(devid); ok { - sid := utils.GenUniqueID() - - c.(*user).sid = sid - - s := &session{ - dev: dev, - user: c, - } - - time.AfterFunc(time.Second*3, func() { - if !s.confirmed { - log.Error().Msgf("Session '%s' confirm timeout", sid) - c.CloseConn() - } - }) - - br.sessions.Store(sid, s) - - dev.WriteMsg(msgTypeLogin, []byte(sid)) - log.Info().Msg("New session: " + sid) - } else { - userLoginAck(loginErrorOffline, c) - log.Error().Msgf("Not found the device '%s'", devid) - } - } - - case c := <-br.unregister: - devid := c.DeviceID() - - c.Close() - - if c.IsDevice() { - dev := c.(*device) - - if !dev.registered { - break - } - - br.devices.Delete(devid) - - dev.registered = false - - br.sessions.Range(func(key, value any) bool { - sid := key.(string) - s := value.(*session) - - if s.dev == c { - br.sessions.Delete(sid) - s.user.Close() - log.Info().Msg("Delete session: " + sid) - } - return true - }) - - log.Info().Msgf("Device '%s' unregistered", devid) - } else { - sid := c.(*user).sid - - if _, ok := br.getSession(sid); ok { - br.sessions.Delete(sid) - - if dev, ok := br.getDevice(devid); ok { - dev.WriteMsg(msgTypeLogout, []byte(sid)) - } - - log.Info().Msg("Delete session: " + sid) - } - } - - case msg := <-br.loginAck: - if s, ok := br.getSession(msg.sid); ok { - if msg.isBusy { - userLoginAck(loginErrorBusy, s.user) - log.Error().Msg("login fail, device busy") - } else { - s.confirmed = true - - userLoginAck(loginErrorNone, s.user) - } - } - - // device active logout - // typically, executing the exit command at the terminal will case this - case sid := <-br.logout: - if s, ok := br.getSession(sid); ok { - br.sessions.Delete(sid) - s.user.Close() - - log.Info().Msg("Delete session: " + sid) - } - - case msg := <-br.termMessage: - if s, ok := br.getSession(msg.sid); ok { - s.user.WriteMsg(websocket.BinaryMessage, msg.data) - } - - case msg := <-br.fileMessage: - sid := msg.sid - if s, ok := br.getSession(sid); ok { - typ := msg.data[0] - data := msg.data[1:] - - switch typ { - case msgTypeFileSend: - pipereader, pipewriter := io.Pipe() - br.fileProxy.Store(sid, &fileProxy{pipereader, pipewriter}) - s.user.WriteMsg(websocket.TextMessage, fmt.Appendf(nil, `{"type":"sendfile", "name": "%s"}`, string(data))) - - case msgTypeFileRecv: - s.user.WriteMsg(websocket.TextMessage, []byte(`{"type":"recvfile"}`)) - - case msgTypeFileData: - if fp, ok := br.fileProxy.Load(sid); ok { - fp := fp.(*fileProxy) - if len(data) == 0 { - fp.Close() - } else { - fp.Write(s.dev, sid, data) - } - } - - case msgTypeFileAck: - s.user.WriteMsg(websocket.TextMessage, []byte(`{"type":"fileAck"}`)) - - case msgTypeFileAbort: - if fp, ok := br.fileProxy.Load(sid); ok { - fp := fp.(*fileProxy) - fp.Close() - } - } - } - - case msg := <-br.userMessage: - if s, ok := br.getSession(msg.sid); ok { - if dev, ok := br.getDevice(s.dev.DeviceID()); ok { - data := msg.data - - if msg.typ == websocket.BinaryMessage { - typ := msgTypeTermData - if data[0] == 1 { - typ = msgTypeFile - } - dev.WriteMsg(typ, append([]byte(msg.sid), data[1:]...)) - } else { - typ := jsoniter.Get(data, "type").ToString() - - switch typ { - case "winsize": - b := [32 + 4]byte{} - - copy(b[:], msg.sid) - - cols := jsoniter.Get(data, "cols").ToUint() - rows := jsoniter.Get(data, "rows").ToUint() - - binary.BigEndian.PutUint16(b[32:], uint16(cols)) - binary.BigEndian.PutUint16(b[34:], uint16(rows)) - - dev.WriteMsg(msgTypeWinsize, b[:]) - - case "ack": - b := [32 + 2]byte{} - copy(b[:], msg.sid) - - ack := jsoniter.Get(data, "ack").ToUint() - binary.BigEndian.PutUint16(b[32:], uint16(ack)) - dev.WriteMsg(msgTypeAck, b[:]) - - case "fileInfo": - size := jsoniter.Get(data, "size").ToUint32() - name := jsoniter.Get(data, "name").ToString() - - b := make([]byte, 32+1+4+len(name)) - copy(b[:], msg.sid) - b[32] = msgTypeFileInfo - binary.BigEndian.PutUint32(b[33:], size) - copy(b[37:], name) - dev.WriteMsg(msgTypeFile, b[:]) - - case "fileCanceled": - b := [33]byte{} - copy(b[:], msg.sid) - b[32] = msgTypeFileAbort - dev.WriteMsg(msgTypeFile, b[:]) - } - } - } - } else { - log.Error().Msg("Not found sid: " + msg.sid) - } - - case req := <-br.cmdReq: - if dev, ok := br.getDevice(req.devid); ok { - dev.WriteMsg(msgTypeCmd, req.data) - } - - case data := <-br.cmdResp: - handleCmdResp(data) - - case req := <-br.httpReq: - if dev, ok := br.getDevice(req.devid); ok { - dev.WriteMsg(msgTypeHttp, req.data) - } - - case resp := <-br.httpResp: - handleHttpProxyResp(resp) - } - } -} diff --git a/build.sh b/build.sh index 0a42d4d..46b5541 100755 --- a/build.sh +++ b/build.sh @@ -1,6 +1,5 @@ #!/bin/sh -VersionPath="rttys/version" GitCommit=$(git log --pretty=format:"%h" -1) BuildTime=$(date +%FT%T%z) @@ -23,7 +22,7 @@ generate() { bin="rttys.exe" } - GOOS=$os GOARCH=$arch CGO_ENABLED=0 go build -ldflags="-s -w -X $VersionPath.gitCommit=$GitCommit -X $VersionPath.buildTime=$BuildTime" -o $dir/$bin && cp rttys.service $dir + GOOS=$os GOARCH=$arch CGO_ENABLED=0 go build -ldflags="-s -w -X main.GitCommit=$GitCommit -X main.BuildTime=$BuildTime" -o $dir/$bin && cp rttys.service $dir } generate $1 $2 diff --git a/client/client.go b/client/client.go deleted file mode 100644 index abde620..0000000 --- a/client/client.go +++ /dev/null @@ -1,18 +0,0 @@ -package client - -// Client abstract device and user -type Client interface { - WriteMsg(typ int, data []byte) - - // For users, return the device ID that the user wants to access - // For devices, return the ID of the device - DeviceID() string - - IsDevice() bool - - Close() - - CloseConn() - - Closed() bool -} diff --git a/command.go b/command.go index ed3bb44..5c67fdb 100644 --- a/command.go +++ b/command.go @@ -1,31 +1,51 @@ +/* + * MIT License + * + * Copyright (c) 2019 Jianhui Zhao + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + package main import ( "context" "net/http" + "rttys/utils" "strconv" - "strings" - "sync" "time" - "rttys/utils" - "github.com/gin-gonic/gin" - jsoniter "github.com/json-iterator/go" + "github.com/rs/zerolog/log" + "github.com/valyala/bytebufferpool" ) -const commandTimeout = 30 // second +type CommandReq struct { + cancel context.CancelFunc + acked bool + c *gin.Context +} -const ( - rttyCmdErrInvalid = 1001 - rttyCmdErrOffline = 1002 - rttyCmdErrTimeout = 1003 -) - -var cmdErrMsg = map[int]string{ - rttyCmdErrInvalid: "invalid format", - rttyCmdErrOffline: "device offline", - rttyCmdErrTimeout: "timeout", +type CommandReqInfo struct { + Cmd string `json:"cmd"` + Username string `json:"username"` + Params []string `json:"params"` } type CommandRespAttrs struct { @@ -39,85 +59,51 @@ type CommandRespInfo struct { Attrs CommandRespAttrs `json:"attrs"` } -type commandReqInfo struct { - Cmd string `json:"cmd"` - Username string `json:"username"` - Params []string `json:"params"` +const ( + rttyCmdErrInvalid = 1001 + rttyCmdErrOffline = 1002 + rttyCmdErrTimeout = 1003 +) + +var cmdErrMsg = map[int]string{ + rttyCmdErrInvalid: "invalid format", + rttyCmdErrOffline: "device offline", + rttyCmdErrTimeout: "timeout", } -type commandReq struct { - cancel context.CancelFunc - c *gin.Context - devid string - data []byte -} +func (dev *Device) handleCmdReq(c *gin.Context, info *CommandReqInfo) { + ctx, cancel := context.WithCancel(dev.ctx) + defer cancel() -var commands sync.Map - -func handleCmdResp(data []byte) { - info := &CommandRespInfo{} - - jsoniter.Unmarshal(data, info) - - if req, ok := commands.Load(info.Token); ok { - req := req.(*commandReq) - req.c.JSON(http.StatusOK, info.Attrs) - req.cancel() - } -} - -func cmdErrReply(err int, req *commandReq) { - req.c.JSON(http.StatusOK, gin.H{ - "err": err, - "msg": cmdErrMsg[err], - }) - req.cancel() -} - -func handleCmdReq(br *broker, c *gin.Context) { - devid := c.Param("devid") - - ctx, cancel := context.WithCancel(context.Background()) - - req := &commandReq{ + req := &CommandReq{ cancel: cancel, c: c, - devid: devid, - } - - if _, ok := br.getDevice(devid); !ok { - cmdErrReply(rttyCmdErrOffline, req) - return - } - - cmdInfo := commandReqInfo{} - - err := c.BindJSON(&cmdInfo) - if err != nil || cmdInfo.Username == "" || cmdInfo.Cmd == "" { - cmdErrReply(rttyCmdErrInvalid, req) - return } token := utils.GenUniqueID() - data := make([]string, 4) + msg := bytebufferpool.Get() + defer bytebufferpool.Put(msg) - data[0] = cmdInfo.Username - data[1] = cmdInfo.Cmd - data[2] = token - data[3] = string(byte(len(cmdInfo.Params))) + BpWriteCString(msg, info.Username) + BpWriteCString(msg, info.Cmd) + BpWriteCString(msg, token) - msg := []byte(strings.Join(data, string(byte(0)))) + msg.WriteByte(byte(len(info.Params))) - for _, param := range cmdInfo.Params { - msg = append(msg, param...) - msg = append(msg, 0) + for _, param := range info.Params { + BpWriteCString(msg, param) } - req.data = msg - br.cmdReq <- req + log.Debug().Msgf("send cmd request for device '%s', token '%s'", dev.id, token) - waitTime := commandTimeout + err := dev.WriteMsg(msgTypeCmd, "", msg.Bytes()) + if err != nil { + cmdErrResp(c, rttyCmdErrOffline) + return + } + + waitTime := CommandTimeout wait := c.Query("wait") if wait != "" { @@ -129,18 +115,39 @@ func handleCmdReq(br *broker, c *gin.Context) { return } - commands.Store(token, req) + dev.commands.Store(token, req) - if waitTime < 0 || waitTime > commandTimeout { - waitTime = commandTimeout + if waitTime < 0 || waitTime > CommandTimeout { + waitTime = CommandTimeout } tmr := time.NewTimer(time.Second * time.Duration(waitTime)) + log.Debug().Msgf("wait for cmd response for device '%s', token '%s', waitTime %ds", dev.id, token, waitTime) + select { case <-tmr.C: - cmdErrReply(rttyCmdErrTimeout, req) - commands.Delete(token) + cmdErrResp(c, rttyCmdErrTimeout) case <-ctx.Done(): } + + dev.commands.Delete(token) + + if !req.acked { + cmdErrResp(c, rttyCmdErrOffline) + } + + log.Debug().Msgf("handle cmd request for device '%s', token '%s' done", dev.id, token) +} + +func cmdErrResp(c *gin.Context, err int) { + c.JSON(http.StatusOK, gin.H{ + "err": err, + "msg": cmdErrMsg[err], + }) +} + +func BpWriteCString(bb *bytebufferpool.ByteBuffer, s string) { + bb.WriteString(s) + bb.WriteByte(0) } diff --git a/config/config.go b/config.go similarity index 64% rename from config/config.go rename to config.go index 986cb28..3fda9c4 100644 --- a/config/config.go +++ b/config.go @@ -1,4 +1,28 @@ -package config +/* + * MIT License + * + * Copyright (c) 2019 Jianhui Zhao + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +package main import ( "fmt" @@ -8,14 +32,12 @@ import ( "github.com/urfave/cli/v3" ) -// Config struct type Config struct { AddrDev string AddrUser string AddrHttpProxy string HttpProxyRedirURL string HttpProxyRedirDomain string - HttpProxyPort int Token string DevHookUrl string LocalAuth bool @@ -23,6 +45,29 @@ type Config struct { AllowOrigins bool } +func (cfg *Config) Parse(c *cli.Command) error { + conf := c.String("conf") + if conf != "" { + err := parseYamlCfg(cfg, conf) + if err != nil { + return err + } + } + + getFlagOpt(c, "addr-dev", &cfg.AddrDev) + getFlagOpt(c, "addr-user", &cfg.AddrUser) + getFlagOpt(c, "addr-http-proxy", &cfg.AddrHttpProxy) + getFlagOpt(c, "http-proxy-redir-url", &cfg.HttpProxyRedirURL) + getFlagOpt(c, "http-proxy-redir-domain", &cfg.HttpProxyRedirDomain) + getFlagOpt(c, "dev-hook-url", &cfg.DevHookUrl) + getFlagOpt(c, "local-auth", &cfg.LocalAuth) + getFlagOpt(c, "token", &cfg.Token) + getFlagOpt(c, "password", &cfg.Password) + getFlagOpt(c, "allow-origins", &cfg.AllowOrigins) + + return nil +} + func getConfigOpt(yamlCfg *yaml.File, name string, opt any) { val, err := yamlCfg.Get(name) if err != nil { @@ -74,33 +119,3 @@ func getFlagOpt(c *cli.Command, name string, opt any) { *opt = c.Bool(name) } } - -// Parse config -func Parse(c *cli.Command) (*Config, error) { - cfg := &Config{ - AddrDev: ":5912", - AddrUser: ":5913", - LocalAuth: true, - } - - conf := c.String("conf") - if conf != "" { - err := parseYamlCfg(cfg, conf) - if err != nil { - return nil, err - } - } - - getFlagOpt(c, "addr-dev", &cfg.AddrDev) - getFlagOpt(c, "addr-user", &cfg.AddrUser) - getFlagOpt(c, "addr-http-proxy", &cfg.AddrHttpProxy) - getFlagOpt(c, "http-proxy-redir-url", &cfg.HttpProxyRedirURL) - getFlagOpt(c, "http-proxy-redir-domain", &cfg.HttpProxyRedirDomain) - getFlagOpt(c, "dev-hook-url", &cfg.DevHookUrl) - getFlagOpt(c, "local-auth", &cfg.LocalAuth) - getFlagOpt(c, "token", &cfg.Token) - getFlagOpt(c, "password", &cfg.Password) - getFlagOpt(c, "allow-origins", &cfg.AllowOrigins) - - return cfg, nil -} diff --git a/device.go b/device.go index bd0245e..6fa668f 100644 --- a/device.go +++ b/device.go @@ -1,21 +1,85 @@ +/* + * MIT License + * + * Copyright (c) 2019 Jianhui Zhao + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + package main import ( "bufio" "bytes" + "context" "encoding/binary" + "fmt" "io" "net" - "rttys/client" + "net/http" "strings" "sync" "time" + "rttys/utils" + + "github.com/gorilla/websocket" + jsoniter "github.com/json-iterator/go" "github.com/rs/zerolog/log" + "github.com/valyala/bytebufferpool" ) +type DeviceInfo struct { + Group string `json:"group"` + ID string `json:"id"` + Connected uint32 `json:"connected"` + Uptime uint32 `json:"uptime"` + Desc string `json:"description"` + Proto uint8 `json:"proto"` + IPaddr string `json:"ipaddr"` +} + +type Device struct { + group string + id string + proto uint8 + desc string + timestamp int64 + uptime uint32 + token string + heartbeat time.Duration + + users sync.Map + pending sync.Map + commands sync.Map + https sync.Map + + conn net.Conn + br *bufio.Reader + readBuf []byte + close sync.Once + ctx context.Context + cancel context.CancelFunc +} + const ( - msgTypeRegister = iota + msgTypeRegister = byte(iota) msgTypeLogin msgTypeLogout msgTypeTermData @@ -25,11 +89,10 @@ const ( msgTypeFile msgTypeHttp msgTypeAck - msgTypeMax = msgTypeAck ) const ( - msgTypeFileSend = iota + msgTypeFileSend = byte(iota) msgTypeFileRecv msgTypeFileInfo msgTypeFileData @@ -42,6 +105,7 @@ const ( msgRegAttrDevid msgRegAttrDescription msgRegAttrToken + msgRegAttrGroup ) const ( @@ -55,6 +119,20 @@ const ( devRegErrIdConflicting ) +const ( + RttyProtoRequired uint8 = 3 + WaitRegistTimeout = 5 * time.Second + DefaultHeartbeat = 5 * time.Second + TermLoginTimeout = 5 * time.Second + CommandTimeout = 30 +) + +const ( + loginErrorNone = 0x00 + loginErrorOffline = 0x01 + loginErrorBusy = 0x02 +) + var DevRegErrMsg = map[byte]string{ 0: "Success", devRegErrUnsupportedProto: "Unsupported protocol", @@ -63,189 +141,111 @@ var DevRegErrMsg = map[byte]string{ devRegErrIdConflicting: "ID conflict", } -// Minimum protocol version requirements of rtty -const rttyProtoRequired uint8 = 3 - -type device struct { - br *broker - proto uint8 - heartbeat time.Duration - id string - desc string /* description of the device */ - timestamp int64 /* Connection time */ - uptime uint32 - token string - conn net.Conn - registered bool - closed bool - close sync.Once - err byte - send chan []byte // Buffered channel of outbound messages. - httpProxyCons sync.Map +var DeviceMsgHandlers = map[byte]func(*Device, []byte) error{ + msgTypeHeartbeat: handleHeartbeatMsg, + msgTypeLogin: handleLoginMsg, + msgTypeLogout: handleLogoutMsg, + msgTypeTermData: handleTermDataMsg, + msgTypeFile: handleFileMsg, + msgTypeCmd: handleCmdMsg, + msgTypeHttp: handleHttpMsg, } -type termMessage struct { - sid string - data []byte -} +func (srv *RttyServer) ListenDevices() { + cfg := &srv.cfg -type fileMessage struct { - sid string - data []byte -} + ln, err := net.Listen("tcp", cfg.AddrDev) + if err != nil { + log.Fatal().Msg(err.Error()) + } + defer ln.Close() -type fileProxy struct { - reader *io.PipeReader - writer *io.PipeWriter -} + log.Info().Msgf("Listen devices on: %s", ln.Addr().(*net.TCPAddr)) -func (fp *fileProxy) Read(b []byte) (int, error) { - return fp.reader.Read(b) -} - -func (fp *fileProxy) Write(dev client.Client, sid string, b []byte) { - go func() { - _, err := fp.writer.Write(b) + for { + conn, err := ln.Accept() if err != nil { - fp.Cancel(dev, sid) - dev.(*device).br.fileProxy.Delete(sid) + log.Error().Msg(err.Error()) + continue + } + + go handleDeviceConnection(srv, conn) + } +} + +func handleDeviceConnection(srv *RttyServer, conn net.Conn) { + defer logPanic() + + dev := &Device{ + conn: conn, + heartbeat: DefaultHeartbeat, + timestamp: time.Now().Unix(), + br: bufio.NewReader(conn), + } + defer dev.Close(srv) + + dev.ctx, dev.cancel = context.WithCancel(context.Background()) + + log.Debug().Msgf("new device '%s' connected", conn.RemoteAddr()) + + conn.SetReadDeadline(time.Now().Add(WaitRegistTimeout)) + + typ, data, err := dev.ReadMsg() + if err != nil { + log.Error().Msgf("read register msg fail: %v", err) + return + } + + if typ != msgTypeRegister { + log.Error().Msg("register msg expected first") + return + } + + if !dev.ParseRegister(data) { + log.Error().Msg("invalid device info") + return + } + + code := dev.Register(srv) + + err = dev.WriteMsg(msgTypeRegister, "", append([]byte{code}, DevRegErrMsg[code]...)) + if err != nil { + log.Printf("send register to device '%s' fail: %v", dev.id, err) + return + } + + if code != 0 { + return + } + + log.Info().Msgf("device '%s' registered, group '%s' proto %d, heartbeat %v", + dev.id, dev.group, dev.proto, dev.heartbeat) + + for { + conn.SetReadDeadline(time.Now().Add(dev.heartbeat * 3 / 2)) + + typ, data, err = dev.ReadMsg() + if err != nil { + if err != io.EOF { + log.Error().Msgf("read msg from device '%s' fail: %v", dev.id, err) + } return } - fp.Ack(dev, sid) - }() -} -func (fp *fileProxy) Close() { - fp.writer.Close() -} + log.Debug().Msgf("device msg %s from device %s", msgTypeName(typ), dev.id) -func (fp *fileProxy) Cancel(dev client.Client, sid string) { - b := make([]byte, 33) - copy(b, sid) - b[32] = msgTypeFileAbort - dev.WriteMsg(msgTypeFile, b) -} + handler, ok := DeviceMsgHandlers[typ] + if !ok { + log.Error().Msgf("unexpected message '%s' from device '%s'", msgTypeName(typ), dev.id) + return + } -func (fp *fileProxy) Ack(dev client.Client, sid string) { - b := make([]byte, 33) - copy(b, sid) - b[32] = msgTypeFileAck - dev.WriteMsg(msgTypeFile, b) -} - -type loginAckMsg struct { - devid string - sid string - isBusy bool -} - -func (dev *device) IsDevice() bool { - return true -} - -func (dev *device) DeviceID() string { - return dev.id -} - -func buildMsg(typ int, data []byte) []byte { - b := []byte{byte(typ), 0, 0} - - binary.BigEndian.PutUint16(b[1:], uint16(len(data))) - - return append(b, data...) -} - -func (dev *device) WriteMsg(typ int, data []byte) { - dev.send <- buildMsg(typ, data) -} - -func (dev *device) Closed() bool { - return dev.closed -} - -func (dev *device) CloseConn() { - dev.conn.Close() -} - -func (dev *device) Close() { - dev.close.Do(func() { - dev.closed = true - - log.Debug().Msgf("Device '%s' disconnected", dev.conn.RemoteAddr()) - - dev.CloseConn() - - close(dev.send) - - dev.httpProxyCons.Clear() - }) -} - -func parseDeviceInfo(dev *device, b []byte) bool { - if len(b) < 1 { - return false + err = handler(dev, data) + if err != nil { + log.Error().Msg(err.Error()) + return + } } - - dev.proto = b[0] - - if dev.proto > 4 { - attrs := parseTLV(b[1:]) - if attrs == nil { - return false - } - - for typ, val := range attrs { - switch typ { - case msgRegAttrHeartbeat: - dev.heartbeat = time.Duration(val[0]) * time.Second - case msgRegAttrDevid: - dev.id = string(val) - case msgRegAttrDescription: - dev.desc = string(val) - case msgRegAttrToken: - dev.token = string(val) - } - } - - return true - } - - b = b[1:] - - fields := bytes.Split(b, []byte{0}) - - if len(fields) < 3 { - return false - } - - dev.id = string(fields[0]) - dev.desc = string(fields[1]) - dev.token = string(fields[2]) - - return true -} - -func parseHeartbeat(dev *device, b []byte) bool { - if dev.proto > 4 { - attrs := parseTLV(b) - if attrs == nil { - return false - } - - for typ, val := range attrs { - switch typ { - case msgHeartbeatAttrUptime: - dev.uptime = binary.BigEndian.Uint32(val) - } - } - } else { - if len(b) < 4 { - return false - } - dev.uptime = binary.BigEndian.Uint32(b[:4]) - } - - return true } func msgTypeName(typ byte) string { @@ -271,194 +271,320 @@ func msgTypeName(typ byte) string { case msgTypeAck: return "ack" default: - return "unknown" + return fmt.Sprintf("unknown(%d)", typ) } } -func (dev *device) readLoop() { - defer logPanic() +func (dev *Device) ReadMsg() (byte, []byte, error) { + head := make([]byte, 3) + br := dev.br - logPrefix := dev.conn.RemoteAddr().String() + _, err := io.ReadFull(br, head) + if err != nil { + return 0, nil, err + } - tmr := time.AfterFunc(time.Second*5, func() { - log.Error().Msgf("%s: timeout", logPrefix) - dev.Close() + typ := head[0] + + msgLen := binary.BigEndian.Uint16(head[1:]) + + if cap(dev.readBuf) < int(msgLen) { + dev.readBuf = make([]byte, msgLen) + } else { + dev.readBuf = dev.readBuf[:msgLen] + } + + _, err = io.ReadFull(br, dev.readBuf) + if err != nil { + return 0, nil, err + } + + return typ, dev.readBuf, nil +} + +func (dev *Device) WriteMsg(typ byte, sid string, data []byte) error { + bb := bytebufferpool.Get() + defer bytebufferpool.Put(bb) + + b := []byte{typ, 0, 0} + + binary.BigEndian.PutUint16(b[1:], uint16(len(sid)+len(data))) + + bb.Write(b) + bb.WriteString(sid) + bb.Write(data) + + _, err := bb.WriteTo(dev.conn) + + return err +} + +func (dev *Device) WriteFileMsg(typ byte, sid string, fileType byte, data []byte) error { + bb := bytebufferpool.Get() + defer bytebufferpool.Put(bb) + + bb.WriteByte(fileType) + bb.Write(data) + + return dev.WriteMsg(typ, sid, bb.Bytes()) +} + +func (dev *Device) Close(srv *RttyServer) { + dev.close.Do(func() { + log.Error().Msgf("device '%s' disconnected", dev.id) + srv.DelDevice(dev) + dev.cancel() + dev.conn.Close() }) +} - defer func() { - dev.br.unregister <- dev - tmr.Stop() - }() +func (dev *Device) ParseRegister(b []byte) bool { + if len(b) < 1 { + return false + } - br := bufio.NewReader(dev.conn) + dev.proto = b[0] - for { - b, err := br.Peek(3) - if err != nil { - if err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") { - log.Error().Msgf("%s: %s", logPrefix, err.Error()) + if dev.proto > 4 { + attrs := utils.ParseTLV(b[1:]) + if attrs == nil { + return false + } + + for typ, val := range attrs { + switch typ { + case msgRegAttrHeartbeat: + dev.heartbeat = time.Duration(val[0]) * time.Second + case msgRegAttrDevid: + dev.id = string(val) + case msgRegAttrDescription: + dev.desc = string(val) + case msgRegAttrToken: + dev.token = string(val) + case msgRegAttrGroup: + dev.group = string(val) } - return } - br.Discard(3) + return true + } - typ := b[0] + b = b[1:] - if typ > msgTypeMax { - log.Error().Msgf("%s: invalid msg type: %d", logPrefix, typ) - return + fields := bytes.Split(b, []byte{0}) + + if len(fields) < 3 { + return false + } + + dev.id = string(fields[0]) + dev.desc = string(fields[1]) + dev.token = string(fields[2]) + + return true +} + +func (dev *Device) Register(srv *RttyServer) byte { + cfg := &srv.cfg + + if dev.proto < RttyProtoRequired { + log.Error().Msgf("minimum proto required %d, found %d for device '%s'", RttyProtoRequired, dev.proto, dev.id) + return devRegErrHookFailed + } + + if cfg.Token != "" && dev.token != cfg.Token { + log.Error().Msgf("invalid token for device '%s'", dev.id) + return devRegErrInvalidToken + } + + devHookUrl := cfg.DevHookUrl + if devHookUrl != "" { + cli := &http.Client{ + Timeout: 3 * time.Second, } - log.Debug().Msgf("%s: recv msg: %s", logPrefix, msgTypeName(typ)) + data := fmt.Sprintf(`{"group":"%s", "devid":"%s", "token":"%s"}`, dev.group, dev.id, dev.token) - msgLen := binary.BigEndian.Uint16(b[1:]) - - b = make([]byte, msgLen) - _, err = io.ReadFull(br, b) + resp, err := cli.Post(devHookUrl, "application/json", strings.NewReader(data)) if err != nil { - log.Error().Msg(err.Error()) - return + log.Error().Msgf("call device hook url fail for device %s: %v", dev.id, err) + return devRegErrHookFailed } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + log.Error().Msgf("call device hook url for device '%s', StatusCode: %d", dev.id, resp.StatusCode) + return devRegErrHookFailed + } + } + + if !srv.AddDevice(dev) { + return devRegErrIdConflicting + } + + return 0 +} + +func handleHeartbeatMsg(dev *Device, data []byte) error { + if !parseHeartbeat(dev, data) { + return fmt.Errorf("invalid heartbeat msg from device '%s'", dev.id) + } + return dev.WriteMsg(msgTypeHeartbeat, "", nil) +} + +func parseHeartbeat(dev *Device, data []byte) bool { + if dev.proto > 4 { + attrs := utils.ParseTLV(data) + if attrs == nil { + return false + } + + for typ, val := range attrs { + switch typ { + case msgHeartbeatAttrUptime: + dev.uptime = binary.BigEndian.Uint32(val) + } + } + } else { + if len(data) < 4 { + return false + } + dev.uptime = binary.BigEndian.Uint32(data[:4]) + } + + return true +} + +func handleLogoutMsg(dev *Device, data []byte) error { + if len(data) < 32 { + return fmt.Errorf("invalid logout msg from device '%s'", dev.id) + } + + sid := string(data[:32]) + + if val, loaded := dev.users.LoadAndDelete(sid); loaded { + user := val.(*User) + user.Close() + } + + return nil +} + +func handleLoginMsg(dev *Device, data []byte) error { + if len(data) < 33 { + return fmt.Errorf("invalid login msg from device '%s'", dev.id) + } + + sid := string(data[:32]) + code := data[32] + + if val, loaded := dev.pending.LoadAndDelete(sid); loaded { + user := val.(*User) + + ok := code == 0 + errCode := loginErrorNone + + if ok { + log.Debug().Msgf("login session '%s' for device '%s' success", sid, dev.id) + dev.users.Store(sid, user) + } else { + errCode = loginErrorBusy + log.Error().Msgf("login session '%s' for device '%s' fail, due to device busy", sid, dev.id) + } + + user.WriteMsg(websocket.TextMessage, + []byte(fmt.Appendf(nil, `{"type":"login","err":%d}`, errCode))) + + user.pending <- ok + } + + return nil +} + +func handleTermDataMsg(dev *Device, data []byte) error { + if len(data) < 32 { + return fmt.Errorf("invalid term data msg from device '%s'", dev.id) + } + + sid := string(data[:32]) + + if val, ok := dev.users.Load(sid); ok { + user := val.(*User) + data[31] = 0 + user.WriteMsg(websocket.BinaryMessage, data[31:]) + } + + return nil +} + +func handleFileMsg(dev *Device, data []byte) error { + if len(data) < 33 { + return fmt.Errorf("invalid file msg from device '%s'", dev.id) + } + + sid := string(data[:32]) + typ := data[32] + + if val, ok := dev.users.Load(sid); ok { + user := val.(*User) switch typ { - case msgTypeRegister: - if !parseDeviceInfo(dev, b) { - log.Error().Msgf("%s: msgTypeRegister: invalid", logPrefix) - return - } + case msgTypeFileSend: + user.WriteMsg(websocket.TextMessage, + fmt.Appendf(nil, `{"type":"sendfile", "name": "%s"}`, string(data[33:]))) - if dev.id == "" { - log.Error().Msgf("%s: msgTypeRegister: devid is empty", logPrefix) - return - } + case msgTypeFileRecv: + user.WriteMsg(websocket.TextMessage, []byte(`{"type":"recvfile"}`)) - logPrefix = dev.id + case msgTypeFileData: + data[32] = 1 + user.WriteMsg(websocket.BinaryMessage, data[32:]) - tmr.Stop() + case msgTypeFileAck: + user.WriteMsg(websocket.TextMessage, []byte(`{"type":"fileAck"}`)) - dev.br.devRegister(dev) - - case msgTypeLogin: - if msgLen < 33 { - log.Error().Msgf("%s: msgTypeLogin: invalid", logPrefix) - return - } - - sid := string(b[:32]) - code := b[32] - - dev.br.loginAck <- &loginAckMsg{dev.id, sid, code == 1} - - case msgTypeLogout: - if msgLen < 32 { - log.Error().Msgf("%s: msgTypeLogout: invalid", logPrefix) - return - } - - dev.br.logout <- string(b[:32]) - - case msgTypeTermData: - fallthrough - case msgTypeFile: - if msgLen < 32 { - log.Error().Msgf("%s: msgTypeTermData|msgTypeFile: invalid", logPrefix) - return - } - - sid := string(b[:32]) - - if typ == msgTypeFile { - dev.br.fileMessage <- &fileMessage{sid, b[32:]} - } else { - dev.br.termMessage <- &termMessage{sid, b[32:]} - } - - case msgTypeCmd: - if msgLen < 1 { - log.Error().Msgf("%s: msgTypeCmd: invalid", logPrefix) - return - } - - dev.br.cmdResp <- b - - case msgTypeHttp: - if msgLen < 18 { - log.Error().Msgf("%s: msgTypeHttp: invalid", logPrefix) - return - } - - dev.br.httpResp <- &httpResp{b, dev} - - case msgTypeHeartbeat: - if !parseHeartbeat(dev, b) { - log.Error().Msgf("%s: msgTypeHeartbeat: invalid", logPrefix) - return - } - - _, err := dev.conn.Write(buildMsg(msgTypeHeartbeat, []byte{})) - if err != nil { - log.Error().Msg(err.Error()) - return - } - - default: - log.Error().Msgf("%s: invalid msg type: %d", logPrefix, typ) - return - } - - tmr.Reset(dev.heartbeat * 3 / 2) - } -} - -func (dev *device) writeLoop() { - defer logPanic() - - defer func() { - dev.br.unregister <- dev - }() - - for msg := range dev.send { - _, err := dev.conn.Write(msg) - if err != nil { - log.Error().Msg(err.Error()) - return + case msgTypeFileAbort: + user.WriteMsg(websocket.BinaryMessage, []byte{1}) } } + + return nil } -func listenDevice(br *broker) { - cfg := br.cfg +func handleHttpMsg(dev *Device, data []byte) error { + if len(data) < 18 { + return fmt.Errorf("invalid http msg from device '%s'", dev.id) + } - ln, err := net.Listen("tcp", cfg.AddrDev) + addr := data[:18] + data = data[18:] + + if c, ok := dev.https.Load(string(addr)); ok { + c := c.(net.Conn) + if len(data) == 0 { + c.Close() + } else { + c.Write(data) + } + } + + return nil +} + +func handleCmdMsg(dev *Device, data []byte) error { + info := &CommandRespInfo{} + + err := jsoniter.Unmarshal(data, info) if err != nil { - log.Fatal().Msg(err.Error()) + return fmt.Errorf("parse command resp info error: %v", err) } - log.Info().Msgf("Listen device on: %s", cfg.AddrDev) + if val, ok := dev.commands.Load(info.Token); ok { + req := val.(*CommandReq) + req.acked = true + req.c.JSON(http.StatusOK, info.Attrs) + req.cancel() + } - go func() { - defer ln.Close() - - for { - conn, err := ln.Accept() - if err != nil { - log.Error().Msg(err.Error()) - continue - } - - log.Debug().Msgf("Device '%s' connected", conn.RemoteAddr()) - - dev := &device{ - br: br, - conn: conn, - heartbeat: time.Second * 5, - timestamp: time.Now().Unix(), - send: make(chan []byte, 256), - } - - go dev.readLoop() - go dev.writeLoop() - } - }() + return nil } diff --git a/go.mod b/go.mod index 2bc54d2..0012d1f 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/mattn/go-colorable v0.1.14 github.com/rs/zerolog v1.34.0 github.com/urfave/cli/v3 v3.3.8 + github.com/valyala/bytebufferpool v1.0.0 golang.org/x/term v0.32.0 ) diff --git a/go.sum b/go.sum index b59c827..7241feb 100644 --- a/go.sum +++ b/go.sum @@ -93,6 +93,8 @@ github.com/ugorji/go/codec v1.3.0 h1:Qd2W2sQawAfG8XSvzwhBeoGq71zXOC/Q1E9y/wUcsUA github.com/ugorji/go/codec v1.3.0/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4= github.com/urfave/cli/v3 v3.3.8 h1:BzolUExliMdet9NlJ/u4m5vHSotJ3PzEqSAZ1oPMa/E= github.com/urfave/cli/v3 v3.3.8/go.mod h1:FJSKtM/9AiiTOJL4fJ6TbMUkxBXn7GO9guZqoZtpYpo= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= golang.org/x/arch v0.18.0 h1:WN9poc33zL4AzGxqf8VtpKUnGvMi8O9lhNyBMF/85qc= golang.org/x/arch v0.18.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk= golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= diff --git a/http.go b/http.go index f018a4e..1d49b73 100644 --- a/http.go +++ b/http.go @@ -1,7 +1,32 @@ +/* + * MIT License + * + * Copyright (c) 2019 Jianhui Zhao + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + package main import ( "bufio" + "context" "encoding/binary" "errors" "fmt" @@ -9,109 +34,83 @@ import ( "net/http" "net/url" "strconv" + "sync" + "sync/atomic" "time" - "rttys/client" "rttys/utils" - "github.com/fanjindong/go-cache" "github.com/gin-gonic/gin" "github.com/rs/zerolog/log" + "github.com/valyala/bytebufferpool" ) -type httpResp struct { - data []byte - dev client.Client +type HttpProxySession struct { + expire atomic.Int64 + ctx context.Context + cancel context.CancelFunc } -type httpReq struct { - devid string - data []byte -} - -var httpProxySessions = cache.NewMemCache(cache.WithClearInterval(time.Minute)) +var httpProxySessions = sync.Map{} const httpProxySessionsExpire = 15 * time.Minute -func handleHttpProxyResp(resp *httpResp) { - dev := resp.dev.(*device) - data := resp.data - addr := data[:18] - data = data[18:] - - if c, ok := dev.httpProxyCons.Load(string(addr)); ok { - c := c.(net.Conn) - if len(data) == 0 { - c.Close() - } else { - c.Write(data) - } - } +func (ses *HttpProxySession) Expire() { + ses.expire.Store(time.Now().Add(httpProxySessionsExpire).Unix()) } -func genDestAddr(addr string) []byte { - destIP, destPort, err := httpProxyVaildAddr(addr) +func (srv *RttyServer) ListenHttpProxy() { + cfg := &srv.cfg + + if cfg.AddrHttpProxy != "" { + addr, err := net.ResolveTCPAddr("tcp", cfg.AddrHttpProxy) + if err != nil { + log.Warn().Msg("invalid http proxy addr: " + err.Error()) + } else { + srv.httpProxyPort = addr.Port + } + } + + ln, err := net.Listen("tcp", cfg.AddrHttpProxy) if err != nil { - return nil + log.Fatal().Msg(err.Error()) } + defer ln.Close() - b := make([]byte, 6) - copy(b, destIP) + srv.httpProxyPort = ln.Addr().(*net.TCPAddr).Port - binary.BigEndian.PutUint16(b[4:], destPort) + log.Info().Msgf("Listen http proxy on: %s", ln.Addr().(*net.TCPAddr)) - return b -} + go httpProxySessionsClean() -func tcpAddr2Bytes(addr *net.TCPAddr) []byte { - b := make([]byte, 18) - - binary.BigEndian.PutUint16(b[:2], uint16(addr.Port)) - - copy(b[2:], addr.IP) - - return b -} - -type HttpProxyWriter struct { - destAddr []byte - srcAddr []byte - hostHeaderRewrite string - br *broker - dev client.Client - https bool -} - -func sendHttpReq(br *broker, c client.Client, https bool, srcAddr []byte, destAddr []byte, data []byte) { - msg := []byte{} - dev := c.(*device) - - if dev.proto > 3 { - if https { - msg = append(msg, 1) - } else { - msg = append(msg, 0) + for { + c, err := ln.Accept() + if err != nil { + log.Error().Msg(err.Error()) + continue } + + go doHttpProxy(srv, c) } - - msg = append(msg, srcAddr...) - msg = append(msg, destAddr...) - msg = append(msg, data...) - - br.httpReq <- &httpReq{dev.id, msg} } -func (rw *HttpProxyWriter) Write(p []byte) (n int, err error) { - sendHttpReq(rw.br, rw.dev, rw.https, rw.srcAddr, rw.destAddr, p) - return len(p), nil +func httpProxySessionsClean() { + for { + time.Sleep(time.Second * 30) + + httpProxySessions.Range(func(key, value any) bool { + ses := value.(*HttpProxySession) + if time.Now().Unix() > ses.expire.Load() { + log.Debug().Msgf("Http proxy session '%s' expired", key) + ses.cancel() + httpProxySessions.Delete(key) + } + return true + }) + } } -func (rw *HttpProxyWriter) WriteRequest(req *http.Request) { - req.Host = rw.hostHeaderRewrite - req.Write(rw) -} - -func doHttpProxy(brk *broker, c net.Conn) { +func doHttpProxy(srv *RttyServer, c net.Conn) { defer logPanic() defer c.Close() @@ -129,8 +128,15 @@ func doHttpProxy(brk *broker, c net.Conn) { } devid := cookie.Value - dev, ok := brk.getDevice(devid) - if !ok { + group := "" + + cookie, err = req.Cookie("rtty-http-group") + if err == nil { + group = cookie.Value + } + + dev := srv.GetDevice(group, devid) + if dev == nil { log.Debug().Msgf(`device "%s" offline`, devid) return } @@ -157,28 +163,29 @@ func doHttpProxy(brk *broker, c net.Conn) { destAddr := genDestAddr(hostHeaderRewrite) srcAddr := tcpAddr2Bytes(c.RemoteAddr().(*net.TCPAddr)) - dev.httpProxyCons.Store(string(srcAddr), c) - - exit := make(chan struct{}) - - if v, ok := httpProxySessions.Get(sid); ok { - go func() { - select { - case <-v.(chan struct{}): - c.Close() - case <-exit: - } - - dev.httpProxyCons.Delete(string(srcAddr)) - }() - } else { + sesVal, ok := httpProxySessions.Load(sid) + if !ok { log.Debug().Msgf(`not found httpProxySession "%s", devid "%s"`, sid, devid) return } - log.Debug().Msgf("doHttpProxy devid: %s, https: %v, destaddr: %s", devid, https, hostHeaderRewrite) + ses := sesVal.(*HttpProxySession) - hpw := &HttpProxyWriter{destAddr, srcAddr, hostHeaderRewrite, brk, dev, https} + ctx, cancel := context.WithCancel(ses.ctx) + defer cancel() + + go func() { + <-ctx.Done() + c.Close() + log.Debug().Msgf("http proxy conn closed, devid: %s, https: %v, destaddr: %s", devid, https, hostHeaderRewrite) + dev.https.Delete(string(srcAddr)) + }() + + log.Debug().Msgf("new http proxy conn, devid: %s, https: %v, destaddr: %s", devid, https, hostHeaderRewrite) + + dev.https.Store(string(srcAddr), c) + + hpw := &HttpProxyWriter{destAddr, srcAddr, hostHeaderRewrite, dev, https} req.Host = hostHeaderRewrite hpw.WriteRequest(req) @@ -189,67 +196,165 @@ func doHttpProxy(brk *broker, c net.Conn) { for { n, err := c.Read(b) if err != nil { - close(exit) return } - - sendHttpReq(brk, dev, https, srcAddr, destAddr, b[:n]) - - httpProxySessions.Expire(sid, httpProxySessionsExpire) + sendHttpReq(dev, https, srcAddr, destAddr, b[:n]) + ses.Expire() } } else { for { req, err := http.ReadRequest(br) if err != nil { - close(exit) return } - hpw.WriteRequest(req) - - httpProxySessions.Expire(sid, httpProxySessionsExpire) + ses.Expire() } } } -func listenHttpProxy(brk *broker) { - cfg := brk.cfg +func httpProxyRedirect(srv *RttyServer, c *gin.Context, group string) { + cfg := &srv.cfg + devid := c.Param("devid") + proto := c.Param("proto") + addr := c.Param("addr") + rawPath := c.Param("path") - if cfg.AddrHttpProxy != "" { - addr, err := net.ResolveTCPAddr("tcp", cfg.AddrHttpProxy) - if err != nil { - log.Warn().Msg("invalid http proxy addr: " + err.Error()) - } else { - cfg.HttpProxyPort = addr.Port - } - } + log.Debug().Msgf("httpProxyRedirect devid: %s, proto: %s, addr: %s, path: %s", devid, proto, addr, rawPath) - if cfg.HttpProxyPort == 0 { - log.Info().Msg("Automatically select an available port for http proxy") - } - - ln, err := net.Listen("tcp4", cfg.AddrHttpProxy) + _, _, err := httpProxyVaildAddr(addr) if err != nil { - log.Fatal().Msg(err.Error()) + log.Debug().Msgf("invalid addr: %s", addr) + c.Status(http.StatusBadRequest) + return } - cfg.HttpProxyPort = ln.Addr().(*net.TCPAddr).Port + path, err := url.Parse(rawPath) + if err != nil { + log.Debug().Msgf("invalid path: %s", rawPath) + c.Status(http.StatusBadRequest) + return + } - log.Info().Msgf("Listen http proxy on: %s", ln.Addr().(*net.TCPAddr)) + dev := srv.GetDevice(group, devid) + if dev == nil { + c.Redirect(http.StatusFound, "/error/offline") + return + } - go func() { - defer ln.Close() - - for { - c, err := ln.Accept() - if err != nil { - log.Error().Msg(err.Error()) - continue - } - - go doHttpProxy(brk, c) + location := c.Request.Header.Get("HttpProxyRedir") + if location == "" { + location = cfg.HttpProxyRedirURL + if location != "" { + log.Debug().Msgf("use HttpProxyRedirURL from config: %s, devid: %s", location, devid) } - }() + } else { + log.Debug().Msgf("use HttpProxyRedir from HTTP header: %s, devid: %s", location, devid) + } + + if location == "" { + host, _, err := net.SplitHostPort(c.Request.Host) + if err != nil { + host = c.Request.Host + } + + location = "http://" + host + + if srv.httpProxyPort != 80 { + location += fmt.Sprintf(":%d", srv.httpProxyPort) + } + } + + location += path.Path + + location += fmt.Sprintf("?_=%d", time.Now().Unix()) + + if path.RawQuery != "" { + location += "&" + path.RawQuery + } + + sid, err := c.Cookie("rtty-http-sid") + if err == nil { + if v, loaded := httpProxySessions.LoadAndDelete(sid); loaded { + s := v.(*HttpProxySession) + s.cancel() + log.Debug().Msgf(`del old httpProxySession "%s" for device "%s"`, sid, devid) + } + } + + sid = utils.GenUniqueID() + + ctx, cancel := context.WithCancel(dev.ctx) + + ses := &HttpProxySession{ + ctx: ctx, + cancel: cancel, + } + ses.Expire() + httpProxySessions.Store(sid, ses) + + log.Debug().Msgf(`new httpProxySession "%s" for device "%s"`, sid, devid) + + domain := c.Request.Header.Get("HttpProxyRedirDomain") + if domain == "" { + domain = cfg.HttpProxyRedirDomain + if domain != "" { + log.Debug().Msgf("set cookie domain from config: %s, devid: %s", domain, devid) + } + } else { + log.Debug().Msgf("set cookie domain from HTTP header: %s, devid: %s", domain, devid) + } + + c.SetCookie("rtty-http-sid", sid, 0, "", domain, false, true) + c.SetCookie("rtty-http-group", group, 0, "", domain, false, true) + c.SetCookie("rtty-http-devid", devid, 0, "", domain, false, true) + c.SetCookie("rtty-http-proto", proto, 0, "", domain, false, true) + c.SetCookie("rtty-http-destaddr", addr, 0, "", domain, false, true) + + c.Redirect(http.StatusFound, location) +} + +func sendHttpReq(dev *Device, https bool, srcAddr []byte, destAddr []byte, data []byte) { + bb := bytebufferpool.Get() + defer bytebufferpool.Put(bb) + + if dev.proto > 3 { + if https { + bb.WriteByte(1) + } else { + bb.WriteByte(0) + } + } + + bb.Write(srcAddr) + bb.Write(destAddr) + bb.Write(data) + + dev.WriteMsg(msgTypeHttp, "", bb.Bytes()) +} + +func genDestAddr(addr string) []byte { + destIP, destPort, err := httpProxyVaildAddr(addr) + if err != nil { + return nil + } + + b := make([]byte, 6) + copy(b, destIP) + + binary.BigEndian.PutUint16(b[4:], destPort) + + return b +} + +func tcpAddr2Bytes(addr *net.TCPAddr) []byte { + b := make([]byte, 18) + + binary.BigEndian.PutUint16(b[:2], uint16(addr.Port)) + + copy(b[2:], addr.IP) + + return b } func httpProxyVaildAddr(addr string) (net.IP, uint16, error) { @@ -274,94 +379,20 @@ func httpProxyVaildAddr(addr string) (net.IP, uint16, error) { return ip, uint16(port), nil } -func httpProxyRedirect(br *broker, c *gin.Context) { - cfg := br.cfg - devid := c.Param("devid") - proto := c.Param("proto") - addr := c.Param("addr") - rawPath := c.Param("path") - - log.Debug().Msgf("httpProxyRedirect devid: %s, proto: %s, addr: %s, path: %s", devid, proto, addr, rawPath) - - _, _, err := httpProxyVaildAddr(addr) - if err != nil { - log.Debug().Msgf("invalid addr: %s", addr) - c.Status(http.StatusBadRequest) - return - } - - path, err := url.Parse(rawPath) - if err != nil { - log.Debug().Msgf("invalid path: %s", rawPath) - c.Status(http.StatusBadRequest) - return - } - - if _, ok := br.getDevice(devid); !ok { - c.Redirect(http.StatusFound, "/error/offline") - return - } - - location := c.Request.Header.Get("HttpProxyRedir") - if location == "" { - location = cfg.HttpProxyRedirURL - if location != "" { - log.Debug().Msgf("use HttpProxyRedirURL from config: %s, devid: %s", location, devid) - } - } else { - log.Debug().Msgf("use HttpProxyRedir from HTTP header: %s, devid: %s", location, devid) - } - - if location == "" { - host, _, err := net.SplitHostPort(c.Request.Host) - if err != nil { - host = c.Request.Host - } - - location = "http://" + host - - if cfg.HttpProxyPort != 80 { - location += fmt.Sprintf(":%d", cfg.HttpProxyPort) - } - } - - location += path.Path - - location += fmt.Sprintf("?_=%d", time.Now().Unix()) - - if path.RawQuery != "" { - location += "&" + path.RawQuery - } - - sid, err := c.Cookie("rtty-http-sid") - if err == nil { - if v, ok := httpProxySessions.Get(sid); ok { - close(v.(chan struct{})) - httpProxySessions.Del(sid) - log.Debug().Msgf(`del old httpProxySession "%s" for device "%s"`, sid, devid) - } - } - - sid = utils.GenUniqueID() - - httpProxySessions.Set(sid, make(chan struct{}), cache.WithEx(httpProxySessionsExpire)) - - log.Debug().Msgf(`new httpProxySession "%s" for device "%s"`, sid, devid) - - domain := c.Request.Header.Get("HttpProxyRedirDomain") - if domain == "" { - domain = cfg.HttpProxyRedirDomain - if domain != "" { - log.Debug().Msgf("set cookie domain from config: %s, devid: %s", domain, devid) - } - } else { - log.Debug().Msgf("set cookie domain from HTTP header: %s, devid: %s", domain, devid) - } - - c.SetCookie("rtty-http-sid", sid, 0, "", domain, false, true) - c.SetCookie("rtty-http-devid", devid, 0, "", domain, false, true) - c.SetCookie("rtty-http-proto", proto, 0, "", domain, false, true) - c.SetCookie("rtty-http-destaddr", addr, 0, "", domain, false, true) - - c.Redirect(http.StatusFound, location) +type HttpProxyWriter struct { + destAddr []byte + srcAddr []byte + hostHeaderRewrite string + dev *Device + https bool +} + +func (rw *HttpProxyWriter) Write(p []byte) (n int, err error) { + sendHttpReq(rw.dev, rw.https, rw.srcAddr, rw.destAddr, p) + return len(p), nil +} + +func (rw *HttpProxyWriter) WriteRequest(req *http.Request) { + req.Host = rw.hostHeaderRewrite + req.Write(rw) } diff --git a/log/log.go b/log/log.go index 3236588..5d24e4a 100644 --- a/log/log.go +++ b/log/log.go @@ -59,6 +59,8 @@ func init() { } log.Logger = logger + + zerolog.SetGlobalLevel(zerolog.InfoLevel) } // SetPath set the log file path diff --git a/main.go b/main.go index d286335..da7b05a 100644 --- a/main.go +++ b/main.go @@ -1,13 +1,36 @@ +/* + * MIT License + * + * Copyright (c) 2019 Jianhui Zhao + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + package main import ( "context" "os" + "os/signal" "runtime" "runtime/debug" - - "rttys/config" - "rttys/version" + "syscall" xlog "rttys/log" @@ -16,6 +39,13 @@ import ( "github.com/urfave/cli/v3" ) +const RttysVersion = "4.4.5" + +var ( + GitCommit = "" + BuildTime = "" +) + func main() { defaultLogPath := "/var/log/rttys.log" if runtime.GOOS == "windows" { @@ -25,7 +55,7 @@ func main() { cmd := &cli.Command{ Name: "rttys", Usage: "The server side for rtty", - Version: version.Version(), + Version: RttysVersion, Flags: []cli.Flag{ &cli.StringFlag{ Name: "log", @@ -92,9 +122,7 @@ func main() { Usage: "more detailed output", }, }, - Action: func(c context.Context, cmd *cli.Command) error { - return runRttys(cmd) - }, + Action: cmdAction, } err := cmd.Run(context.Background(), os.Args) @@ -103,10 +131,12 @@ func main() { } } -func runRttys(c *cli.Command) error { - xlog.SetPath(c.String("log")) +func cmdAction(c context.Context, cmd *cli.Command) error { + defer logPanic() - switch c.String("log-level") { + xlog.SetPath(cmd.String("log")) + + switch cmd.String("log-level") { case "debug": zerolog.SetGlobalLevel(zerolog.DebugLevel) case "warn": @@ -117,41 +147,41 @@ func runRttys(c *cli.Command) error { zerolog.SetGlobalLevel(zerolog.InfoLevel) } - if c.Bool("verbose") { + if cmd.Bool("verbose") { xlog.Verbose() } - cfg, err := config.Parse(c) - if err != nil { - return err - } - log.Info().Msg("Go Version: " + runtime.Version()) log.Info().Msgf("Go OS/Arch: %s/%s", runtime.GOOS, runtime.GOARCH) - log.Info().Msg("Rttys Version: " + version.Version()) + log.Info().Msg("Rttys Version: " + RttysVersion) - gitCommit := version.GitCommit() - buildTime := version.BuildTime() - - if gitCommit != "" { - log.Info().Msg("Git Commit: " + version.GitCommit()) + if GitCommit != "" { + log.Info().Msg("Git Commit: " + GitCommit) } - if buildTime != "" { - log.Info().Msg("Build Time: " + version.BuildTime()) + if BuildTime != "" { + log.Info().Msg("Build Time: " + BuildTime) } - defer logPanic() + if runtime.GOOS != "windows" { + go signalHandle() + } - br := newBroker(cfg) - go br.run() + cfg := Config{ + AddrDev: ":5912", + AddrUser: ":5913", + LocalAuth: true, + } - listenDevice(br) - listenHttpProxy(br) - apiStart(br) + err := cfg.Parse(cmd) + if err != nil { + return err + } - select {} + srv := &RttyServer{cfg: cfg} + + return srv.Run() } func logPanic() { @@ -165,3 +195,19 @@ func saveCrashLog(p any, stack []byte) { log.Error().Msgf("%v", p) log.Error().Msg(string(stack)) } + +func signalHandle() { + + c := make(chan os.Signal, 1) + + signal.Notify(c, syscall.SIGUSR1) + + for s := range c { + switch s { + case syscall.SIGUSR1: + xlog.Verbose() + zerolog.SetGlobalLevel(zerolog.DebugLevel) + log.Debug().Msg("Debug mode enabled") + } + } +} diff --git a/rttys.conf b/rttys.conf index e0b4e74..beb9dc9 100644 --- a/rttys.conf +++ b/rttys.conf @@ -13,7 +13,7 @@ # This url will be called when the device is connected if this url is configured # The request method is POST -# The parameters are in JSON format: {"id": "device ID", "token": "device TOKEN"} +# The parameters are in JSON format: {"group": "group ID","id": "device ID", "token": "device TOKEN"} # Return HTTP 200 to indicate that the device is allowed to connect. #dev-hook-url: http://127.0.0.1:8080/rttys-dev-hook diff --git a/rttys_stress_test.go b/rttys_stress_test.go new file mode 100644 index 0000000..e030395 --- /dev/null +++ b/rttys_stress_test.go @@ -0,0 +1,200 @@ +/* + * MIT License + * + * Copyright (c) 2019 Jianhui Zhao + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +package main + +import ( + "context" + "flag" + "io" + "net/http" + "net/http/cookiejar" + "net/url" + "strings" + "sync" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/rs/zerolog/log" +) + +func TestRttysStress(t *testing.T) { + duration := 10 * time.Minute + + timeoutFlag := flag.Lookup("test.timeout") + if timeoutFlag != nil { + duration = timeoutFlag.Value.(flag.Getter).Get().(time.Duration) + } + + cfg := Config{ + AddrDev: ":5912", + AddrUser: ":5913", + } + + srv := &RttyServer{cfg: cfg} + + go func() { + err := srv.Run() + if err != nil { + log.Fatal().Msg(err.Error()) + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), duration-time.Second*2) + defer cancel() + + time.Sleep(time.Millisecond * 100) + + log.Info().Msg("Waiting for devices to connect for testing...") + + devices := &sync.Map{} + + for { + select { + case <-ctx.Done(): + log.Info().Msg("Test timeout, exiting...") + return + default: + time.Sleep(time.Second * 1) + + srv.groups.Range(func(key, value any) bool { + group := key.(string) + g := value.(*DeviceGroup) + g.devices.Range(func(key, value any) bool { + dev := value.(*Device) + if _, loaded := devices.LoadOrStore(dev.id, group+dev.id); !loaded { + go runDeviceTest(ctx, devices, group, dev.id) + } + return true + }) + return true + }) + } + } +} + +func runDeviceTest(ctx context.Context, devices *sync.Map, group, devID string) { + ctx, cancel := context.WithCancel(ctx) + + defer func() { + time.Sleep(time.Second) + cancel() + devices.Delete(group + devID) + }() + + go runHttpTest(ctx, group, devID) + + wg := &sync.WaitGroup{} + + for range 7 { + wg.Add(1) + go runWebSocketTest(ctx, group, devID, wg) + } + + wg.Wait() +} + +func runWebSocketTest(ctx context.Context, group, devID string, wg *sync.WaitGroup) { + conn, _, err := websocket.DefaultDialer.Dial("ws://127.0.0.1:5913/connect/"+devID+"?group="+group, nil) + if err != nil { + log.Fatal().Msg(err.Error()) + } + defer conn.Close() + defer wg.Done() + + go func() { + <-ctx.Done() + conn.Close() + }() + + go func() { + msg := []byte{0} + msg = append(msg, []byte("ttttttttttttttttttttttttttttt\n")...) + msg = append(msg, []byte("ttttttttttttttttttttttttttttt\n")...) + for { + err = conn.WriteMessage(websocket.BinaryMessage, msg) + if err != nil { + return + } + time.Sleep(time.Millisecond * 20) + } + }() + + for { + _, _, err := conn.ReadMessage() + if err != nil { + return + } + } +} + +func runHttpTest(ctx context.Context, group, devID string) { + for { + select { + case <-ctx.Done(): + return + default: + runHttpTestOnce(ctx, group, devID) + } + } +} + +func runHttpTestOnce(ctx context.Context, group, devID string) { + addr := "" + + if group == "" { + addr = "http://127.0.0.1:5913/web/" + } else { + addr = "http://127.0.0.1:5913/web2/" + group + "/" + } + + addr += devID + "/http/" + encodeURIComponent("127.0.0.1:80/") + + jar, _ := cookiejar.New(nil) + client := &http.Client{ + Jar: jar, + } + + request, _ := http.NewRequestWithContext(ctx, "GET", addr, nil) + + for range 10 { + res, err := client.Do(request) + if err != nil { + log.Info().Msg(err.Error()) + return + } + defer res.Body.Close() + + io.ReadAll(res.Body) + + time.Sleep(10 * time.Millisecond) + } +} + +func encodeURIComponent(str string) string { + r := url.QueryEscape(str) + r = strings.ReplaceAll(r, "+", "%20") + return r +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..c8a0c30 --- /dev/null +++ b/server.go @@ -0,0 +1,113 @@ +/* + * MIT License + * + * Copyright (c) 2019 Jianhui Zhao + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +package main + +import ( + "sync" + "sync/atomic" + + "github.com/rs/zerolog/log" +) + +type RttyServer struct { + mu sync.RWMutex + groups sync.Map + cfg Config + httpProxyPort int +} + +type DeviceGroup struct { + devices sync.Map + count atomic.Int32 +} + +func (srv *RttyServer) Run() error { + log.Debug().Msgf("%+v", srv.cfg) + + go srv.ListenDevices() + go srv.ListenHttpProxy() + + return srv.ListenAPI() +} + +func (srv *RttyServer) GetDevice(group, id string) *Device { + srv.mu.RLock() + defer srv.mu.RUnlock() + + g := srv.GetGroup(group, false) + if g == nil { + return nil + } + + if v, ok := g.devices.Load(id); ok { + return v.(*Device) + } + + return nil +} + +func (srv *RttyServer) AddDevice(dev *Device) bool { + srv.mu.Lock() + defer srv.mu.Unlock() + + g := srv.GetGroup(dev.group, true) + + if _, loaded := g.devices.LoadOrStore(dev.id, dev); loaded { + return false + } + + g.count.Add(1) + + return true +} + +func (srv *RttyServer) DelDevice(dev *Device) { + srv.mu.Lock() + defer srv.mu.Unlock() + + g := srv.GetGroup(dev.group, false) + if g == nil { + return + } + + if _, loaded := g.devices.LoadAndDelete(dev.id); loaded { + if g.count.Add(-1) == 0 { + srv.groups.Delete(dev.group) + } + } +} + +func (srv *RttyServer) GetGroup(group string, create bool) *DeviceGroup { + if create { + val, _ := srv.groups.LoadOrStore(group, &DeviceGroup{}) + return val.(*DeviceGroup) + } else { + val, ok := srv.groups.Load(group) + if !ok { + return nil + } + return val.(*DeviceGroup) + } +} diff --git a/tlv.go b/tlv.go deleted file mode 100644 index 4b61150..0000000 --- a/tlv.go +++ /dev/null @@ -1,38 +0,0 @@ -package main - -import ( - "bytes" - "encoding/binary" - "io" -) - -func parseTLV(data []byte) map[uint8][]byte { - if len(data) < 3 { - return nil - } - - tlvs := map[uint8][]byte{} - - reader := bytes.NewReader(data) - - for reader.Len() > 0 { - typ, _ := reader.ReadByte() - - var length uint16 - err := binary.Read(reader, binary.BigEndian, &length) - if err != nil { - return nil - } - - value := make([]byte, length) - - _, err = io.ReadFull(reader, value) - if err != nil { - return nil - } - - tlvs[typ] = value - } - - return tlvs -} diff --git a/ui/src/components/RttyCmd.vue b/ui/src/components/RttyCmd.vue index 84de02b..6289f4b 100644 --- a/ui/src/components/RttyCmd.vue +++ b/ui/src/components/RttyCmd.vue @@ -148,7 +148,7 @@ export default { params: this.cmdData.params } - this.axios.post(`/cmd/${item.id}?wait=${this.cmdData.wait}`, data).then((response) => { + this.axios.post(`/cmd/${item.id}?group=${item.group}&wait=${this.cmdData.wait}`, data).then((response) => { if (this.cmdData.wait === 0) { this.cmdStatus.responses.push({ err: 0, diff --git a/ui/src/components/RttyWeb.vue b/ui/src/components/RttyWeb.vue index ee5b33b..110d1f6 100644 --- a/ui/src/components/RttyWeb.vue +++ b/ui/src/components/RttyWeb.vue @@ -76,12 +76,14 @@ export default { callback() }}] }, + group: '', devid: '', devProto: null } }, methods: { show(dev) { + this.group = dev.group this.devid = dev.id this.devProto = dev.proto this.formData.proto = 'http' @@ -122,7 +124,11 @@ export default { path = '/' const addr = encodeURIComponent(`${ipaddr}:${port}${path}`) - window.open(`/web/${this.devid}/${proto}/${addr}`) + + if (this.group) + window.open(`/web2/${this.group}/${this.devid}/${proto}/${addr}`) + else + window.open(`/web/${this.devid}/${proto}/${addr}`) }, 100) }) } diff --git a/ui/src/i18n/en.json b/ui/src/i18n/en.json index 51ef652..3b2cbda 100644 --- a/ui/src/i18n/en.json +++ b/ui/src/i18n/en.json @@ -10,6 +10,7 @@ "Reset": "Reset", "Signin Fail! password wrong.": "Signin Fail! password wrong.", "Refresh List": "Refresh List", + "ungrouped": "Ungrouped", "Please enter the filter key...": "Please enter the filter key...", "Execute command": "Execute command", "device-count": "Device Count: {count}", @@ -68,4 +69,4 @@ "Already copied to clipboard": "Already copied to clipboard", "Please use shortcut \"Shift+Insert\"": "Please use shortcut \"Shift+Insert\"", "Your device's rtty does not support https proxy, please upgrade it.": "Your device's rtty does not support https proxy, please upgrade it." -} +} \ No newline at end of file diff --git a/ui/src/i18n/zh-CN.json b/ui/src/i18n/zh-CN.json index 52fe8ef..46c3a71 100644 --- a/ui/src/i18n/zh-CN.json +++ b/ui/src/i18n/zh-CN.json @@ -10,6 +10,7 @@ "Reset": "复位", "Signin Fail! password wrong.": "登录失败,密码错误", "Refresh List": "刷新列表", + "ungrouped": "未分组", "Please enter the filter key...": "请输入关键字进行过滤...", "Execute command": "执行命令", "device-count": "设备数: {count}", @@ -24,6 +25,7 @@ "Uptime": "运行时长", "Description": "描述", "Please select the devices you want to operate": "请选择您要操作的设备", + "Access your devices's Web": "访问您的设备的 Web", "Command": "命令", "Cancel": "取消", "OK": "确定", @@ -54,7 +56,6 @@ "Upload or download file": "上传或者下载文件", "About": "关于", "Please execute command \"rtty -R\" or \"rtty -S\" in current terminal!": "请在当前终端中执行命令 \"rtty -R\" 或者 \"rtty -S\"", - "Access your devices's Web": "访问您的设备的 Web", "Proto": "协议", "ipaddr": "IP 地址", "port": "端口", @@ -69,4 +70,4 @@ "Already copied to clipboard": "已复制到剪切板", "Please use shortcut \"Shift+Insert\"": "请使用快捷键 \"Shift+Insert\"", "Your device's rtty does not support https proxy, please upgrade it.": "你的设备的 rtty 不支持 https 代理,请升级" -} +} \ No newline at end of file diff --git a/ui/src/views/Home.vue b/ui/src/views/Home.vue index 7d5de7e..e04f621 100644 --- a/ui/src/views/Home.vue +++ b/ui/src/views/Home.vue @@ -1,44 +1,54 @@