diff --git a/internal/msggateway/client.go b/internal/msggateway/client.go index 9c165e4dd..b32130c9a 100644 --- a/internal/msggateway/client.go +++ b/internal/msggateway/client.go @@ -20,6 +20,7 @@ import ( "fmt" "runtime/debug" "sync" + "sync/atomic" "github.com/openimsdk/open-im-server/v3/pkg/msgprocessor" @@ -70,7 +71,7 @@ type Client struct { IsBackground bool `json:"isBackground"` ctx *UserConnContext longConnServer LongConnServer - closed bool + closed atomic.Bool closedErr error token string } @@ -102,18 +103,14 @@ func (c *Client) ResetClient( c.ctx = ctx c.longConnServer = longConnServer c.IsBackground = false - c.closed = false + c.closed.Store(false) c.closedErr = nil c.token = token } func (c *Client) pingHandler(_ string) error { - c.conn.SetReadDeadline(pongWait) - err := c.writePongMsg() - if err != nil { - return err - } - return nil + _ = c.conn.SetReadDeadline(pongWait) + return c.writePongMsg() } func (c *Client) readMessage() { @@ -124,9 +121,11 @@ func (c *Client) readMessage() { } c.close() }() + c.conn.SetReadLimit(maxMessageSize) _ = c.conn.SetReadDeadline(pongWait) c.conn.SetPingHandler(c.pingHandler) + for { messageType, message, returnErr := c.conn.ReadMessage() if returnErr != nil { @@ -134,11 +133,13 @@ func (c *Client) readMessage() { c.closedErr = returnErr return } + log.ZDebug(c.ctx, "readMessage", "messageType", messageType) - if c.closed { // 连接刚置位已经关闭,但是协程还没退出的场景 + if c.closed.Load() { // 连接刚置位已经关闭,但是协程还没退出的场景 c.closedErr = ErrConnClosed return } + switch messageType { case MessageBinary: _ = c.conn.SetReadDeadline(pongWait) @@ -150,9 +151,11 @@ func (c *Client) readMessage() { case MessageText: c.closedErr = ErrNotSupportMessageProtocol return + case PingMessage: err := c.writePongMsg() log.ZError(c.ctx, "writePongMsg", err) + case CloseMessage: c.closedErr = ErrClientClosed return @@ -163,29 +166,40 @@ func (c *Client) readMessage() { func (c *Client) handleMessage(message []byte) error { if c.IsCompress { - var decompressErr error - message, decompressErr = c.longConnServer.DeCompress(message) - if decompressErr != nil { - return utils.Wrap(decompressErr, "") + var err error + message, err = c.longConnServer.DeCompress(message) + if err != nil { + return utils.Wrap(err, "") } } - var binaryReq Req - err := c.longConnServer.Decode(message, &binaryReq) + + var binaryReq = getReq() + defer freeReq(binaryReq) + + err := c.longConnServer.Decode(message, binaryReq) if err != nil { return utils.Wrap(err, "") } + if err := c.longConnServer.Validate(binaryReq); err != nil { return utils.Wrap(err, "") } + if binaryReq.SendID != c.UserID { return utils.Wrap(errors.New("exception conn userID not same to req userID"), binaryReq.String()) } + ctx := mcontext.WithMustInfoCtx( []string{binaryReq.OperationID, binaryReq.SendID, constant.PlatformIDToName(c.PlatformID), c.ctx.GetConnID()}, ) + log.ZDebug(ctx, "gateway req message", "req", binaryReq.String()) - var messageErr error - var resp []byte + + var ( + resp []byte + messageErr error + ) + switch binaryReq.ReqIdentifier { case WSGetNewestSeq: resp, messageErr = c.longConnServer.GetSeq(ctx, binaryReq) @@ -208,23 +222,29 @@ func (c *Client) handleMessage(message []byte) error { ) } - return c.replyMessage(ctx, &binaryReq, messageErr, resp) + return c.replyMessage(ctx, binaryReq, messageErr, resp) } -func (c *Client) setAppBackgroundStatus(ctx context.Context, req Req) ([]byte, error) { +func (c *Client) setAppBackgroundStatus(ctx context.Context, req *Req) ([]byte, error) { resp, isBackground, messageErr := c.longConnServer.SetUserDeviceBackground(ctx, req) if messageErr != nil { return nil, messageErr } + c.IsBackground = isBackground // todo callback return resp, nil } func (c *Client) close() { + if c.closed.Load() { + return + } + c.w.Lock() defer c.w.Unlock() - c.closed = true + + c.closed.Store(true) c.conn.Close() c.longConnServer.UnRegister(c) } @@ -244,6 +264,7 @@ func (c *Client) replyMessage(ctx context.Context, binaryReq *Req, err error, re if err != nil { log.ZWarn(ctx, "wireBinaryMsg replyMessage", err, "resp", mReply.String()) } + if binaryReq.ReqIdentifier == WsLogoutMsg { return errors.New("user logout") } @@ -280,39 +301,42 @@ func (c *Client) KickOnlineMessage() error { } func (c *Client) writeBinaryMsg(resp Resp) error { - c.w.Lock() - defer c.w.Unlock() - if c.closed { + if c.closed.Load() { return nil } - resultBuf := bufferPool.Get().([]byte) encodedBuf, err := c.longConnServer.Encode(resp) if err != nil { return utils.Wrap(err, "") } + + c.w.Lock() + defer c.w.Unlock() + _ = c.conn.SetWriteDeadline(writeWait) if c.IsCompress { - var compressErr error - resultBuf, compressErr = c.longConnServer.Compress(encodedBuf) + resultBuf, compressErr := c.longConnServer.Compress(encodedBuf) if compressErr != nil { return utils.Wrap(compressErr, "") } return c.conn.WriteMessage(MessageBinary, resultBuf) - } else { - return c.conn.WriteMessage(MessageBinary, encodedBuf) } + + return c.conn.WriteMessage(MessageBinary, encodedBuf) } func (c *Client) writePongMsg() error { - c.w.Lock() - defer c.w.Unlock() - if c.closed { + if c.closed.Load() { return nil } + + c.w.Lock() + defer c.w.Unlock() + err := c.conn.SetWriteDeadline(writeWait) if err != nil { return utils.Wrap(err, "") } + return c.conn.WriteMessage(PongMessage, nil) } diff --git a/internal/msggateway/message_handler.go b/internal/msggateway/message_handler.go index a8de5e0b5..dd5e00f18 100644 --- a/internal/msggateway/message_handler.go +++ b/internal/msggateway/message_handler.go @@ -16,6 +16,7 @@ package msggateway import ( "context" + "sync" "github.com/OpenIMSDK/protocol/push" "github.com/OpenIMSDK/tools/discoveryregistry" @@ -49,6 +50,27 @@ func (r *Req) String() string { return utils.StructToJsonString(tReq) } +var reqPool = sync.Pool{ + New: func() any { + return new(Req) + }, +} + +func getReq() *Req { + req := reqPool.Get().(*Req) + req.Data = nil + req.MsgIncr = "" + req.OperationID = "" + req.ReqIdentifier = 0 + req.SendID = "" + req.Token = "" + return req +} + +func freeReq(req *Req) { + reqPool.Put(req) +} + type Resp struct { ReqIdentifier int32 `json:"reqIdentifier"` MsgIncr string `json:"msgIncr"` @@ -69,12 +91,12 @@ func (r *Resp) String() string { } type MessageHandler interface { - GetSeq(context context.Context, data Req) ([]byte, error) - SendMessage(context context.Context, data Req) ([]byte, error) - SendSignalMessage(context context.Context, data Req) ([]byte, error) - PullMessageBySeqList(context context.Context, data Req) ([]byte, error) - UserLogout(context context.Context, data Req) ([]byte, error) - SetUserDeviceBackground(context context.Context, data Req) ([]byte, bool, error) + GetSeq(context context.Context, data *Req) ([]byte, error) + SendMessage(context context.Context, data *Req) ([]byte, error) + SendSignalMessage(context context.Context, data *Req) ([]byte, error) + PullMessageBySeqList(context context.Context, data *Req) ([]byte, error) + UserLogout(context context.Context, data *Req) ([]byte, error) + SetUserDeviceBackground(context context.Context, data *Req) ([]byte, bool, error) } var _ MessageHandler = (*GrpcHandler)(nil) @@ -94,7 +116,7 @@ func NewGrpcHandler(validate *validator.Validate, client discoveryregistry.SvcDi } } -func (g GrpcHandler) GetSeq(context context.Context, data Req) ([]byte, error) { +func (g GrpcHandler) GetSeq(context context.Context, data *Req) ([]byte, error) { req := sdkws.GetMaxSeqReq{} if err := proto.Unmarshal(data.Data, &req); err != nil { return nil, err @@ -113,7 +135,7 @@ func (g GrpcHandler) GetSeq(context context.Context, data Req) ([]byte, error) { return c, nil } -func (g GrpcHandler) SendMessage(context context.Context, data Req) ([]byte, error) { +func (g GrpcHandler) SendMessage(context context.Context, data *Req) ([]byte, error) { msgData := sdkws.MsgData{} if err := proto.Unmarshal(data.Data, &msgData); err != nil { return nil, err @@ -133,7 +155,7 @@ func (g GrpcHandler) SendMessage(context context.Context, data Req) ([]byte, err return c, nil } -func (g GrpcHandler) SendSignalMessage(context context.Context, data Req) ([]byte, error) { +func (g GrpcHandler) SendSignalMessage(context context.Context, data *Req) ([]byte, error) { resp, err := g.msgRpcClient.SendMsg(context, nil) if err != nil { return nil, err @@ -145,7 +167,7 @@ func (g GrpcHandler) SendSignalMessage(context context.Context, data Req) ([]byt return c, nil } -func (g GrpcHandler) PullMessageBySeqList(context context.Context, data Req) ([]byte, error) { +func (g GrpcHandler) PullMessageBySeqList(context context.Context, data *Req) ([]byte, error) { req := sdkws.PullMessageBySeqsReq{} if err := proto.Unmarshal(data.Data, &req); err != nil { return nil, err @@ -164,7 +186,7 @@ func (g GrpcHandler) PullMessageBySeqList(context context.Context, data Req) ([] return c, nil } -func (g GrpcHandler) UserLogout(context context.Context, data Req) ([]byte, error) { +func (g GrpcHandler) UserLogout(context context.Context, data *Req) ([]byte, error) { req := push.DelUserPushTokenReq{} if err := proto.Unmarshal(data.Data, &req); err != nil { return nil, err @@ -180,7 +202,7 @@ func (g GrpcHandler) UserLogout(context context.Context, data Req) ([]byte, erro return c, nil } -func (g GrpcHandler) SetUserDeviceBackground(_ context.Context, data Req) ([]byte, bool, error) { +func (g GrpcHandler) SetUserDeviceBackground(_ context.Context, data *Req) ([]byte, bool, error) { req := sdkws.SetAppBackgroundStatusReq{} if err := proto.Unmarshal(data.Data, &req); err != nil { return nil, false, err