fix: prevent concurrent map access with sync.Map

```
fatal error: concurrent map read and map write
```

Signed-off-by: Jianhui Zhao <zhaojh329@gmail.com>
This commit is contained in:
Jianhui Zhao
2025-05-28 14:08:37 +08:00
parent 181815046c
commit bb112787d0
4 changed files with 51 additions and 36 deletions

13
api.go
View File

@@ -252,7 +252,7 @@ func apiStart(br *broker) {
authorized.GET("/connect/:devid", func(c *gin.Context) {
if c.GetHeader("Upgrade") != "websocket" {
devid := c.Param("devid")
if _, ok := br.devices[devid]; !ok {
if _, ok := br.getDevice(devid); !ok {
c.Redirect(http.StatusFound, "/error/offline")
return
}
@@ -322,8 +322,7 @@ func apiStart(br *broker) {
Bound: username != "",
}
if dev, ok := br.devices[id]; ok {
dev := dev.(*device)
if dev, ok := br.getDevice(id); ok {
di.Connected = uint32(time.Now().Unix() - dev.timestamp)
di.Uptime = dev.uptime
di.Online = true
@@ -623,7 +622,7 @@ func apiStart(br *broker) {
}
for _, devid := range data.Devices {
if _, ok := br.devices[devid]; !ok {
if _, ok := br.getDevice(devid); !ok {
sql := fmt.Sprintf("DELETE FROM device WHERE id = '%s'", devid)
if username != "" {
@@ -641,8 +640,10 @@ func apiStart(br *broker) {
sid := c.Param("sid")
if fp, ok := br.fileProxy.Load(sid); ok {
fp := fp.(*fileProxy)
s := br.sessions[sid]
fp.Ack(s.dev, sid)
if s, ok := br.getSession(sid); ok {
fp.Ack(s.dev, sid)
}
defer func() {
if err := recover(); err != nil {

View File

@@ -28,12 +28,12 @@ type session struct {
type broker struct {
cfg *config.Config
devices map[string]client.Client
devices sync.Map
loginAck chan *loginAckMsg
logout chan string
register chan client.Client
unregister chan client.Client
sessions map[string]*session
sessions sync.Map
termMessage chan *termMessage
fileMessage chan *fileMessage
userMessage chan *usrMessage
@@ -52,8 +52,6 @@ func newBroker(cfg *config.Config) *broker {
logout: make(chan string, 1000),
register: make(chan client.Client, 1000),
unregister: make(chan client.Client, 1000),
devices: make(map[string]client.Client),
sessions: make(map[string]*session),
termMessage: make(chan *termMessage, 1000),
fileMessage: make(chan *fileMessage, 1000),
userMessage: make(chan *usrMessage, 1000),
@@ -90,6 +88,20 @@ func devAuth(cfg *config.Config, dev *device) bool {
return jsoniter.Get(body, "auth").ToBool()
}
func (br *broker) getDevice(devid string) (*device, bool) {
if dev, ok := br.devices.Load(devid); ok {
return dev.(*device), true
}
return nil, false
}
func (br *broker) getSession(sid string) (*session, bool) {
if dev, ok := br.sessions.Load(sid); ok {
return dev.(*session), true
}
return nil, false
}
func (br *broker) run() {
for {
select {
@@ -105,7 +117,7 @@ func (br *broker) run() {
err := byte(0)
msg := "OK"
if _, ok := br.devices[devid]; ok {
if _, ok := br.getDevice(devid); ok {
log.Error().Msg("Device ID conflicting: " + devid)
msg = "ID conflicting"
err = 1
@@ -121,7 +133,7 @@ func (br *broker) run() {
}
} else {
dev.registered = true
br.devices[devid] = c
br.devices.Store(devid, c)
dev.UpdateDb()
log.Info().Msgf("Device '%s' registered, proto %d", devid, dev.proto)
}
@@ -135,7 +147,7 @@ func (br *broker) run() {
})
}
} else {
if dev, ok := br.devices[devid]; ok {
if dev, ok := br.getDevice(devid); ok {
sid := utils.GenUniqueID("sid")
c.(*user).sid = sid
@@ -152,7 +164,7 @@ func (br *broker) run() {
}
})
br.sessions[sid] = s
br.sessions.Store(sid, s)
dev.WriteMsg(msgTypeLogin, []byte(sid))
log.Info().Msg("New session: " + sid)
@@ -174,26 +186,30 @@ func (br *broker) run() {
break
}
delete(br.devices, devid)
br.devices.Delete(devid)
dev.registered = false
for sid, s := range br.sessions {
br.sessions.Range(func(key, value any) bool {
sid := key.(string)
s := value.(*session)
if s.dev == c {
delete(br.sessions, sid)
br.sessions.Delete(sid)
s.user.Close()
log.Info().Msg("Delete session: " + sid)
}
}
return true
})
log.Info().Msgf("Device '%s' unregistered", devid)
} else {
sid := c.(*user).sid
if _, ok := br.sessions[sid]; ok {
delete(br.sessions, sid)
if _, ok := br.getSession(sid); ok {
br.sessions.Delete(sid)
if dev, ok := br.devices[devid]; ok {
if dev, ok := br.getDevice(devid); ok {
dev.WriteMsg(msgTypeLogout, []byte(sid))
}
@@ -202,7 +218,7 @@ func (br *broker) run() {
}
case msg := <-br.loginAck:
if s, ok := br.sessions[msg.sid]; ok {
if s, ok := br.getSession(msg.sid); ok {
if msg.isBusy {
userLoginAck(loginErrorBusy, s.user)
log.Error().Msg("login fail, device busy")
@@ -216,21 +232,21 @@ func (br *broker) run() {
// device active logout
// typically, executing the exit command at the terminal will case this
case sid := <-br.logout:
if s, ok := br.sessions[sid]; ok {
delete(br.sessions, sid)
if s, ok := br.getSession(sid); ok {
br.sessions.Delete(sid)
s.user.Close()
log.Info().Msg("Delete session: " + sid)
}
case msg := <-br.termMessage:
if s, ok := br.sessions[msg.sid]; ok {
if s, ok := br.getSession(msg.sid); ok {
s.user.WriteMsg(websocket.BinaryMessage, msg.data)
}
case msg := <-br.fileMessage:
sid := msg.sid
if s, ok := br.sessions[sid]; ok {
if s, ok := br.getSession(sid); ok {
typ := msg.data[0]
data := msg.data[1:]
@@ -265,8 +281,8 @@ func (br *broker) run() {
}
case msg := <-br.userMessage:
if s, ok := br.sessions[msg.sid]; ok {
if dev, ok := br.devices[s.dev.DeviceID()]; ok {
if s, ok := br.getSession(msg.sid); ok {
if dev, ok := br.getDevice(s.dev.DeviceID()); ok {
data := msg.data
if msg.typ == websocket.BinaryMessage {
@@ -324,7 +340,7 @@ func (br *broker) run() {
}
case req := <-br.cmdReq:
if dev, ok := br.devices[req.devid]; ok {
if dev, ok := br.getDevice(req.devid); ok {
dev.WriteMsg(msgTypeCmd, req.data)
}
@@ -332,7 +348,7 @@ func (br *broker) run() {
handleCmdResp(data)
case req := <-br.httpReq:
if dev, ok := br.devices[req.devid]; ok {
if dev, ok := br.getDevice(req.devid); ok {
dev.WriteMsg(msgTypeHttp, req.data)
}

View File

@@ -73,8 +73,7 @@ func handleCmdReq(br *broker, c *gin.Context) {
devid: devid,
}
_, ok := br.devices[devid]
if !ok {
if _, ok := br.getDevice(devid); !ok {
cmdErrReply(rttyCmdErrOffline, req)
return
}

View File

@@ -129,7 +129,7 @@ func doHttpProxy(brk *broker, c net.Conn) {
}
devid := cookie.Value
dev, ok := brk.devices[devid]
dev, ok := brk.getDevice(devid)
if !ok {
log.Debug().Msgf(`device "%s" offline`, devid)
return
@@ -185,7 +185,7 @@ func doHttpProxy(brk *broker, c net.Conn) {
log.Debug().Msgf("doHttpProxy devid: %s, https: %v, destaddr: %s", devid, https, hostHeaderRewrite)
hpw := &HttpProxyWriter{destAddr, srcAddr, hostHeaderRewrite, brk, dev.(*device), https}
hpw := &HttpProxyWriter{destAddr, srcAddr, hostHeaderRewrite, brk, dev, https}
req.Host = hostHeaderRewrite
hpw.WriteRequest(req)
@@ -315,8 +315,7 @@ func httpProxyRedirect(br *broker, c *gin.Context) {
return
}
_, ok := br.devices[devid]
if !ok {
if _, ok := br.getDevice(devid); !ok {
c.Redirect(http.StatusFound, "/error/offline")
return
}