diff --git a/internal/msggateway/client.go b/internal/msggateway/client.go index a9c5b7672..1ba36f458 100644 --- a/internal/msggateway/client.go +++ b/internal/msggateway/client.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "github.com/OpenIMSDK/Open-IM-Server/pkg/apiresp" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/log" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/mcontext" @@ -14,6 +15,10 @@ import ( "sync" ) +var ErrConnClosed = errors.New("conn has closed") +var ErrNotSupportMessageProtocol = errors.New("not support message protocol") +var ErrClientClosed = errors.New("client actively close the connection") + const ( // MessageText is for UTF-8 encoded text messages like JSON. MessageText = iota + 1 @@ -33,6 +38,8 @@ const ( PongMessage = 10 ) +type PongHandler func(string) error + type Client struct { w *sync.Mutex conn LongConn @@ -41,9 +48,9 @@ type Client struct { userID string isBackground bool ctx *UserConnContext - onlineAt int64 // 上线时间戳(毫秒) longConnServer LongConnServer closed bool + closedErr error } func newClient(ctx *UserConnContext, conn LongConn, isCompress bool) *Client { @@ -54,7 +61,6 @@ func newClient(ctx *UserConnContext, conn LongConn, isCompress bool) *Client { isCompress: isCompress, userID: ctx.GetUserID(), ctx: ctx, - onlineAt: utils.GetCurrentTimestampByMill(), } } func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, isCompress bool, longConnServer LongConnServer) { @@ -64,9 +70,12 @@ func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, isCompress boo c.isCompress = isCompress c.userID = ctx.GetUserID() c.ctx = ctx - c.onlineAt = utils.GetCurrentTimestampByMill() c.longConnServer = longConnServer } +func (c *Client) pongHandler(_ string) error { + c.conn.SetReadDeadline(pongWait) + return nil +} func (c *Client) readMessage() { defer func() { if r := recover(); r != nil { @@ -74,31 +83,36 @@ func (c *Client) readMessage() { } c.close() }() - //var returnErr error + c.conn.SetReadLimit(maxMessageSize) + _ = c.conn.SetReadDeadline(pongWait) + c.conn.SetPongHandler(c.pongHandler) for { messageType, message, returnErr := c.conn.ReadMessage() if returnErr != nil { - break + c.closedErr = returnErr + return } if c.closed == true { //连接刚置位已经关闭,但是协程还没退出的场景 - break + c.closedErr = ErrConnClosed + return } switch messageType { - case PingMessage: - case PongMessage: - case CloseMessage: - return - case MessageText: case MessageBinary: - if len(message) == 0 { - continue + parseDataErr := c.handleMessage(message) + if parseDataErr != nil { + c.closedErr = parseDataErr + return } - returnErr = c.handleMessage(message) - if returnErr != nil { - log.ZError(context.Background(), "WSGetNewestSeq", returnErr) - break - } - + case MessageText: + c.closedErr = ErrNotSupportMessageProtocol + return + case PingMessage: + err := c.writePongMsg() + log.ZError(c.ctx, "writePongMsg", err) + case CloseMessage: + c.closedErr = ErrClientClosed + return + default: } } @@ -120,7 +134,7 @@ func (c *Client) handleMessage(message []byte) error { return utils.Wrap(err, "") } if binaryReq.SendID != c.userID { - return errors.New("exception conn userID not same to req userID") + return utils.Wrap(errors.New("exception conn userID not same to req userID"), binaryReq.String()) } ctx := mcontext.WithMustInfoCtx([]string{binaryReq.OperationID, binaryReq.SendID, constant.PlatformIDToName(c.platformID), c.ctx.GetConnID()}) var messageErr error @@ -128,8 +142,6 @@ func (c *Client) handleMessage(message []byte) error { switch binaryReq.ReqIdentifier { case WSGetNewestSeq: resp, messageErr = c.longConnServer.GetSeq(ctx, binaryReq) - log.ZError(ctx, "WSGetNewestSeq", messageErr, "resp", resp) - case WSSendMsg: resp, messageErr = c.longConnServer.SendMessage(ctx, binaryReq) case WSSendSignalMsg: @@ -166,13 +178,16 @@ func (c *Client) close() { } func (c *Client) replyMessage(binaryReq *Req, err error, resp []byte) { + errResp := apiresp.ParseError(err) mReply := Resp{ ReqIdentifier: binaryReq.ReqIdentifier, MsgIncr: binaryReq.MsgIncr, OperationID: binaryReq.OperationID, + ErrCode: errResp.ErrCode, + ErrMsg: errResp.ErrMsg, Data: resp, } - _ = c.writeMsg(mReply) + _ = c.writeBinaryMsg(mReply) } func (c *Client) PushMessage(ctx context.Context, msgData *sdkws.MsgData) error { data, err := proto.Marshal(msgData) @@ -184,15 +199,14 @@ func (c *Client) PushMessage(ctx context.Context, msgData *sdkws.MsgData) error OperationID: mcontext.GetOperationID(ctx), Data: data, } - return c.writeMsg(resp) - + return c.writeBinaryMsg(resp) } func (c *Client) KickOnlineMessage(ctx context.Context) error { return nil } -func (c *Client) writeMsg(resp Resp) error { +func (c *Client) writeBinaryMsg(resp Resp) error { c.w.Lock() defer c.w.Unlock() if c.closed == true { @@ -204,7 +218,7 @@ func (c *Client) writeMsg(resp Resp) error { if err != nil { return utils.Wrap(err, "") } - _ = c.conn.SetWriteTimeout(60) + _ = c.conn.SetWriteDeadline(writeWait) if c.isCompress { var compressErr error resultBuf, compressErr = c.longConnServer.Compress(encodeBuf) @@ -216,3 +230,14 @@ func (c *Client) writeMsg(resp Resp) error { return c.conn.WriteMessage(MessageBinary, encodedBuf) } } + +func (c *Client) writePongMsg() error { + c.w.Lock() + defer c.w.Unlock() + if c.closed == true { + return nil + } + _ = c.conn.SetWriteDeadline(writeWait) + return c.conn.WriteMessage(PongMessage, nil) + +} diff --git a/internal/msggateway/constant.go b/internal/msggateway/constant.go index 21a0435b0..5f2fa2b17 100644 --- a/internal/msggateway/constant.go +++ b/internal/msggateway/constant.go @@ -1,5 +1,7 @@ package msggateway +import "time" + const ( WsUserID = "sendID" CommonUserID = "userID" @@ -25,3 +27,14 @@ const ( WsSetBackgroundStatus = 2004 WSDataError = 3001 ) + +const ( + // Time allowed to write a message to the peer. + writeWait = 10 * time.Second + + // Time allowed to read the next pong message from the peer. + pongWait = 30 * time.Second + + // Maximum message size allowed from peer. + maxMessageSize = 51200 +) diff --git a/internal/msggateway/long_conn.go b/internal/msggateway/long_conn.go index 0fc6f1f35..819094ee0 100644 --- a/internal/msggateway/long_conn.go +++ b/internal/msggateway/long_conn.go @@ -13,31 +13,33 @@ type LongConn interface { WriteMessage(messageType int, message []byte) error //Read message from connection. ReadMessage() (int, []byte, error) - //SetReadTimeout sets the read deadline on the underlying network connection, + // SetReadDeadline sets the read deadline on the underlying network connection, //after a read has timed out, will return an error. - SetReadTimeout(timeout int) error - //SetWriteTimeout sets the write deadline when send message,when read has timed out,will return error. - SetWriteTimeout(timeout int) error + SetReadDeadline(timeout time.Duration) error + // SetWriteDeadline sets the write deadline when send message,when read has timed out,will return error. + SetWriteDeadline(timeout time.Duration) error //Try to dial a connection,url must set auth args,header can control compress data Dial(urlStr string, requestHeader http.Header) (*http.Response, error) //Whether the connection of the current long connection is nil IsNil() bool //Set the connection of the current long connection to nil SetConnNil() + // SetReadLimit sets the maximum size for a message read from the peer.bytes + SetReadLimit(limit int64) + SetPongHandler(handler PongHandler) //Check the connection of the current and when it was sent are the same //CheckSendConnDiffNow() bool // GenerateLongConn(w http.ResponseWriter, r *http.Request) error } type GWebSocket struct { - protocolType int - conn *websocket.Conn - handshakeTimeout time.Duration - readBufferSize, WriteBufferSize int + protocolType int + conn *websocket.Conn + handshakeTimeout time.Duration } -func newGWebSocket(protocolType int, handshakeTimeout time.Duration, readBufferSize int) *GWebSocket { - return &GWebSocket{protocolType: protocolType, handshakeTimeout: handshakeTimeout, readBufferSize: readBufferSize} +func newGWebSocket(protocolType int, handshakeTimeout time.Duration) *GWebSocket { + return &GWebSocket{protocolType: protocolType, handshakeTimeout: handshakeTimeout} } func (d *GWebSocket) Close() error { @@ -46,7 +48,6 @@ func (d *GWebSocket) Close() error { func (d *GWebSocket) GenerateLongConn(w http.ResponseWriter, r *http.Request) error { upgrader := &websocket.Upgrader{ HandshakeTimeout: d.handshakeTimeout, - ReadBufferSize: d.readBufferSize, CheckOrigin: func(r *http.Request) bool { return true }, } conn, err := upgrader.Upgrade(w, r, nil) @@ -69,12 +70,12 @@ func (d *GWebSocket) WriteMessage(messageType int, message []byte) error { func (d *GWebSocket) ReadMessage() (int, []byte, error) { return d.conn.ReadMessage() } -func (d *GWebSocket) SetReadTimeout(timeout int) error { - return d.conn.SetReadDeadline(time.Now().Add(time.Duration(timeout) * time.Second)) +func (d *GWebSocket) SetReadDeadline(timeout time.Duration) error { + return d.conn.SetReadDeadline(time.Now().Add(timeout)) } -func (d *GWebSocket) SetWriteTimeout(timeout int) error { - return d.conn.SetWriteDeadline(time.Now().Add(time.Duration(timeout) * time.Second)) +func (d *GWebSocket) SetWriteDeadline(timeout time.Duration) error { + return d.conn.SetWriteDeadline(time.Now().Add(timeout)) } func (d *GWebSocket) Dial(urlStr string, requestHeader http.Header) (*http.Response, error) { @@ -96,6 +97,12 @@ func (d *GWebSocket) IsNil() bool { func (d *GWebSocket) SetConnNil() { d.conn = nil } +func (d *GWebSocket) SetReadLimit(limit int64) { + d.conn.SetReadLimit(limit) +} +func (d *GWebSocket) SetPongHandler(handler PongHandler) { + d.conn.SetPongHandler(handler) +} //func (d *GWebSocket) CheckSendConnDiffNow() bool { // return d.conn == d.sendConn diff --git a/internal/msggateway/message_handler.go b/internal/msggateway/message_handler.go index 0e8b5fbe2..b38108c1a 100644 --- a/internal/msggateway/message_handler.go +++ b/internal/msggateway/message_handler.go @@ -2,10 +2,10 @@ package msggateway import ( "context" - "github.com/OpenIMSDK/Open-IM-Server/pkg/common/log" "github.com/OpenIMSDK/Open-IM-Server/pkg/proto/msg" "github.com/OpenIMSDK/Open-IM-Server/pkg/proto/sdkws" "github.com/OpenIMSDK/Open-IM-Server/pkg/rpcclient/notification" + "github.com/OpenIMSDK/Open-IM-Server/pkg/utils" "github.com/go-playground/validator/v10" "github.com/golang/protobuf/proto" ) @@ -18,11 +18,16 @@ type Req struct { MsgIncr string `json:"msgIncr" validate:"required"` Data []byte `json:"data"` } + +func (r *Req) String() string { + return utils.StructToJsonString(r) +} + type Resp struct { ReqIdentifier int32 `json:"reqIdentifier"` MsgIncr string `json:"msgIncr"` OperationID string `json:"operationID"` - ErrCode int32 `json:"errCode"` + ErrCode int `json:"errCode"` ErrMsg string `json:"errMsg"` Data []byte `json:"data"` } @@ -54,7 +59,6 @@ func (g GrpcHandler) GetSeq(context context.Context, data Req) ([]byte, error) { if err := g.validate.Struct(req); err != nil { return nil, err } - log.ZDebug(context, "msggateway GetSeq", "notification", g.notification, "msg", g.notification.Msg) resp, err := g.notification.Msg.GetMaxAndMinSeq(context, &req) if err != nil { return nil, err @@ -148,3 +152,23 @@ func (g GrpcHandler) SetUserDeviceBackground(_ context.Context, data Req) ([]byt } return nil, req.IsBackground, nil } + +//func (g GrpcHandler) call[T any](ctx context.Context, data Req, m proto.Message, rpc func(ctx context.Context, req proto.Message)) ([]byte, error) { +// if err := proto.Unmarshal(data.Data, m); err != nil { +// return nil, err +// } +// if err := g.validate.Struct(m); err != nil { +// return nil, err +// } +// rpc(ctx, m) +// req := msg.SendMsgReq{MsgData: &msgData} +// resp, err := g.notification.Msg.SendMsg(context, &req) +// if err != nil { +// return nil, err +// } +// c, err := proto.Marshal(resp) +// if err != nil { +// return nil, err +// } +// return c, nil +//} diff --git a/internal/msggateway/n_ws_server.go b/internal/msggateway/n_ws_server.go index 26d3ea4fa..7450a08ce 100644 --- a/internal/msggateway/n_ws_server.go +++ b/internal/msggateway/n_ws_server.go @@ -34,18 +34,17 @@ var bufferPool = sync.Pool{ } type WsServer struct { - port int - wsMaxConnNum int64 - registerChan chan *Client - unregisterChan chan *Client - clients *UserMap - clientPool sync.Pool - onlineUserNum int64 - onlineUserConnNum int64 - handshakeTimeout time.Duration - readBufferSize, WriteBufferSize int - hubServer *Server - validate *validator.Validate + port int + wsMaxConnNum int64 + registerChan chan *Client + unregisterChan chan *Client + clients *UserMap + clientPool sync.Pool + onlineUserNum int64 + onlineUserConnNum int64 + handshakeTimeout time.Duration + hubServer *Server + validate *validator.Validate Compressor Encoder MessageHandler @@ -85,7 +84,6 @@ func NewWsServer(opts ...Option) (*WsServer, error) { port: config.port, wsMaxConnNum: config.maxConnNum, handshakeTimeout: config.handshakeTimeout, - readBufferSize: config.messageMaxMsgLength, clientPool: sync.Pool{ New: func() interface{} { return new(Client) @@ -149,7 +147,7 @@ func (ws *WsServer) unregisterClient(client *Client) { atomic.AddInt64(&ws.onlineUserNum, -1) } atomic.AddInt64(&ws.onlineUserConnNum, -1) - log.ZInfo(client.ctx, "user offline", "online user Num", ws.onlineUserNum, "online user conn Num", ws.onlineUserConnNum) + log.ZInfo(client.ctx, "user offline", "close reason", client.closedErr, "online user Num", ws.onlineUserNum, "online user conn Num", ws.onlineUserConnNum) } func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { @@ -186,7 +184,7 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { httpError(context, err) return } - wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout, ws.readBufferSize) + wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout) err = wsLongConn.GenerateLongConn(w, r) if err != nil { httpError(context, err) diff --git a/pkg/apiresp/gin.go b/pkg/apiresp/gin.go index 772a15c52..26e7be317 100644 --- a/pkg/apiresp/gin.go +++ b/pkg/apiresp/gin.go @@ -10,7 +10,7 @@ func GinError(c *gin.Context, err error) { GinSuccess(c, nil) return } - c.JSON(http.StatusOK, apiError(err)) + c.JSON(http.StatusOK, ParseError(err)) } func GinSuccess(c *gin.Context, data any) { diff --git a/pkg/apiresp/resp.go b/pkg/apiresp/resp.go index 06001a541..adc4e6f61 100644 --- a/pkg/apiresp/resp.go +++ b/pkg/apiresp/resp.go @@ -5,7 +5,7 @@ import ( "reflect" ) -type apiResponse struct { +type ApiResponse struct { ErrCode int `json:"errCode"` ErrMsg string `json:"errMsg"` ErrDlt string `json:"errDlt"` @@ -30,23 +30,23 @@ func isAllFieldsPrivate(v any) bool { return true } -func apiSuccess(data any) *apiResponse { +func apiSuccess(data any) *ApiResponse { if isAllFieldsPrivate(data) { - return &apiResponse{} + return &ApiResponse{} } - return &apiResponse{ + return &ApiResponse{ Data: data, } } -func apiError(err error) *apiResponse { +func ParseError(err error) *ApiResponse { unwrap := errs.Unwrap(err) if codeErr, ok := unwrap.(errs.CodeError); ok { - resp := apiResponse{ErrCode: codeErr.Code(), ErrMsg: codeErr.Msg(), ErrDlt: codeErr.Detail()} + resp := ApiResponse{ErrCode: codeErr.Code(), ErrMsg: codeErr.Msg(), ErrDlt: codeErr.Detail()} if resp.ErrDlt == "" { resp.ErrDlt = err.Error() } return &resp } - return &apiResponse{ErrCode: errs.ServerInternalError, ErrMsg: err.Error()} + return &ApiResponse{ErrCode: errs.ServerInternalError, ErrMsg: err.Error()} }