diff --git a/internal/msggateway/client.go b/internal/msggateway/client.go index bdb62aece..2c2192e53 100644 --- a/internal/msggateway/client.go +++ b/internal/msggateway/client.go @@ -16,7 +16,6 @@ package msggateway import ( "context" - "encoding/json" "fmt" "sync" "sync/atomic" @@ -64,7 +63,7 @@ type PingPongHandler func(string) error type Client struct { w *sync.Mutex - conn LongConn + conn ClientConn PlatformID int `json:"platformID"` IsCompress bool `json:"isCompress"` UserID string `json:"userID"` @@ -83,7 +82,7 @@ type Client struct { } // ResetClient updates the client's state with new connection and context information. -func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, longConnServer LongConnServer) { +func (c *Client) ResetClient(ctx *UserConnContext, conn ClientConn, longConnServer LongConnServer) { c.w = new(sync.Mutex) c.conn = conn c.PlatformID = stringutil.StringToInt(ctx.GetPlatformID()) @@ -110,22 +109,6 @@ func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, longConnServer c.subUserIDs = make(map[string]struct{}) } -func (c *Client) pingHandler(appData string) error { - if err := c.conn.SetReadDeadline(pongWait); err != nil { - return err - } - - log.ZDebug(c.ctx, "ping Handler Success.", "appData", appData) - return c.writePongMsg(appData) -} - -func (c *Client) pongHandler(_ string) error { - if err := c.conn.SetReadDeadline(pongWait); err != nil { - return err - } - return nil -} - // readMessage continuously reads messages from the connection. func (c *Client) readMessage() { defer func() { @@ -136,52 +119,25 @@ func (c *Client) readMessage() { c.close() }() - c.conn.SetReadLimit(maxMessageSize) - _ = c.conn.SetReadDeadline(pongWait) - c.conn.SetPongHandler(c.pongHandler) - c.conn.SetPingHandler(c.pingHandler) - c.activeHeartbeat(c.hbCtx) - for { log.ZDebug(c.ctx, "readMessage") - messageType, message, returnErr := c.conn.ReadMessage() + message, returnErr := c.conn.ReadMessage() if returnErr != nil { - log.ZWarn(c.ctx, "readMessage", returnErr, "messageType", messageType) + log.ZWarn(c.ctx, "readMessage", returnErr) c.closedErr = returnErr return } - log.ZDebug(c.ctx, "readMessage", "messageType", messageType) if c.closed.Load() { // The scenario where the connection has just been closed, but the coroutine has not exited c.closedErr = ErrConnClosed return } - switch messageType { - case MessageBinary: - _ = c.conn.SetReadDeadline(pongWait) - parseDataErr := c.handleMessage(message) - if parseDataErr != nil { - c.closedErr = parseDataErr - return - } - case MessageText: - _ = c.conn.SetReadDeadline(pongWait) - parseDataErr := c.handlerTextMessage(message) - if parseDataErr != nil { - c.closedErr = parseDataErr - return - } - case PingMessage: - err := c.writePongMsg("") - log.ZError(c.ctx, "writePongMsg", err) - - case CloseMessage: - c.closedErr = ErrClientClosed + parseDataErr := c.handleMessage(message) + if parseDataErr != nil { + c.closedErr = parseDataErr return - - default: } } } @@ -356,109 +312,13 @@ func (c *Client) writeBinaryMsg(resp Resp) error { c.w.Lock() defer c.w.Unlock() - err = c.conn.SetWriteDeadline(writeWait) - if err != nil { - return err - } - if c.IsCompress { resultBuf, compressErr := c.longConnServer.CompressWithPool(encodedBuf) if compressErr != nil { return compressErr } - return c.conn.WriteMessage(MessageBinary, resultBuf) + return c.conn.WriteMessage(resultBuf) } - return c.conn.WriteMessage(MessageBinary, encodedBuf) -} - -// Actively initiate Heartbeat when platform in Web. -func (c *Client) activeHeartbeat(ctx context.Context) { - if c.PlatformID == constant.WebPlatformID { - go func() { - defer func() { - if r := recover(); r != nil { - log.ZPanic(ctx, "activeHeartbeat Panic", errs.ErrPanic(r)) - } - }() - log.ZDebug(ctx, "server initiative send heartbeat start.") - ticker := time.NewTicker(pingPeriod) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - if err := c.writePingMsg(); err != nil { - log.ZWarn(c.ctx, "send Ping Message error.", err) - return - } - case <-c.hbCtx.Done(): - return - } - } - }() - } -} -func (c *Client) writePingMsg() error { - if c.closed.Load() { - return nil - } - - c.w.Lock() - defer c.w.Unlock() - - err := c.conn.SetWriteDeadline(writeWait) - if err != nil { - return err - } - - return c.conn.WriteMessage(PingMessage, nil) -} - -func (c *Client) writePongMsg(appData string) error { - log.ZDebug(c.ctx, "write Pong Msg in Server", "appData", appData) - if c.closed.Load() { - log.ZWarn(c.ctx, "is closed in server", nil, "appdata", appData, "closed err", c.closedErr) - return nil - } - - c.w.Lock() - defer c.w.Unlock() - - err := c.conn.SetWriteDeadline(writeWait) - if err != nil { - log.ZWarn(c.ctx, "SetWriteDeadline in Server have error", errs.Wrap(err), "writeWait", writeWait, "appData", appData) - return errs.Wrap(err) - } - err = c.conn.WriteMessage(PongMessage, []byte(appData)) - if err != nil { - log.ZWarn(c.ctx, "Write Message have error", errs.Wrap(err), "Pong msg", PongMessage) - } - - return errs.Wrap(err) -} - -func (c *Client) handlerTextMessage(b []byte) error { - var msg TextMessage - if err := json.Unmarshal(b, &msg); err != nil { - return err - } - switch msg.Type { - case TextPong: - return nil - case TextPing: - msg.Type = TextPong - msgData, err := json.Marshal(msg) - if err != nil { - return err - } - c.w.Lock() - defer c.w.Unlock() - if err := c.conn.SetWriteDeadline(writeWait); err != nil { - return err - } - return c.conn.WriteMessage(MessageText, msgData) - default: - return fmt.Errorf("not support message type %s", msg.Type) - } + return c.conn.WriteMessage(encodedBuf) } diff --git a/internal/msggateway/client_conn.go b/internal/msggateway/client_conn.go new file mode 100644 index 000000000..7837380e1 --- /dev/null +++ b/internal/msggateway/client_conn.go @@ -0,0 +1,212 @@ +package msggateway + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sync/atomic" + "time" + + "github.com/gorilla/websocket" + + "github.com/openimsdk/tools/log" +) + +var ErrWriteFull = fmt.Errorf("websocket write buffer full,close connection") + +type ClientConn interface { + ReadMessage() ([]byte, error) + WriteMessage(message []byte) error + Close() error +} + +type websocketMessage struct { + MessageType int + Data []byte +} + +func NewWebSocketClientConn(conn *websocket.Conn, readLimit int64, readTimeout time.Duration, pingInterval time.Duration) ClientConn { + c := &websocketClientConn{ + readTimeout: readTimeout, + conn: conn, + writer: make(chan *websocketMessage, 256), + done: make(chan struct{}), + } + if readLimit > 0 { + c.conn.SetReadLimit(readLimit) + } + c.conn.SetPingHandler(c.pingHandler) + c.conn.SetPongHandler(c.pongHandler) + + go c.loopSend() + if pingInterval > 0 { + go c.doPing(pingInterval) + } + return c +} + +type websocketClientConn struct { + readTimeout time.Duration + conn *websocket.Conn + writer chan *websocketMessage + done chan struct{} + err atomic.Pointer[error] +} + +func (c *websocketClientConn) ReadMessage() ([]byte, error) { + buf, err := c.readMessage() + if err != nil { + return nil, c.closeBy(fmt.Errorf("read message %w", err)) + } + return buf, nil +} + +func (c *websocketClientConn) WriteMessage(message []byte) error { + return c.writeMessage(websocket.BinaryMessage, message) +} + +func (c *websocketClientConn) Close() error { + _ = c.closeBy(fmt.Errorf("websocket connection closed")) + return nil +} + +func (c *websocketClientConn) closeBy(err error) error { + if !c.err.CompareAndSwap(nil, &err) { + return *c.err.Load() + } + close(c.done) + log.ZWarn(context.Background(), "websocket connection closed", err, "remoteAddr", c.conn.RemoteAddr(), + "chan length", len(c.writer)) + _ = c.conn.Close() + return err +} + +func (c *websocketClientConn) writeMessage(messageType int, data []byte) error { + if errPtr := c.err.Load(); errPtr != nil { + return *errPtr + } + select { + case c.writer <- &websocketMessage{MessageType: messageType, Data: data}: + return nil + default: + return c.closeBy(ErrWriteFull) + } +} + +func (c *websocketClientConn) loopSend() { + var err error + for { + select { + case <-c.done: + return + case msg := <-c.writer: + switch msg.MessageType { + case websocket.TextMessage, websocket.BinaryMessage: + err = c.conn.WriteMessage(msg.MessageType, msg.Data) + default: + err = c.conn.WriteControl(msg.MessageType, msg.Data, time.Time{}) + } + if err != nil { + _ = c.closeBy(err) + return + } + } + } +} + +func (c *websocketClientConn) setReadDeadline() error { + deadline := time.Now().Add(c.readTimeout) + return c.conn.SetReadDeadline(deadline) +} + +func (c *websocketClientConn) readMessage() ([]byte, error) { + for { + if err := c.setReadDeadline(); err != nil { + return nil, err + } + messageType, buf, err := c.conn.ReadMessage() + if err != nil { + return nil, err + } + switch messageType { + case websocket.BinaryMessage: + return buf, nil + case websocket.TextMessage: + if err := c.onReadTextMessage(buf); err != nil { + return nil, err + } + case websocket.PingMessage: + if err := c.pingHandler(string(buf)); err != nil { + return nil, err + } + case websocket.PongMessage: + if err := c.pongHandler(string(buf)); err != nil { + return nil, err + } + case websocket.CloseMessage: + if len(buf) == 0 { + return nil, errors.New("websocket connection closed by peer") + } + return nil, fmt.Errorf("websocket connection closed by peer, data %s", string(buf)) + default: + return nil, fmt.Errorf("unknown websocket message type %d", messageType) + } + } +} + +func (c *websocketClientConn) onReadTextMessage(buf []byte) error { + var msg struct { + Type string `json:"type"` + Body json.RawMessage `json:"body"` + } + if err := json.Unmarshal(buf, &msg); err != nil { + return err + } + switch msg.Type { + case TextPong: + return nil + case TextPing: + msg.Type = TextPong + msgData, err := json.Marshal(msg) + if err != nil { + return err + } + return c.writeMessage(websocket.TextMessage, msgData) + default: + return fmt.Errorf("not support text message type %s", msg.Type) + } +} + +func (c *websocketClientConn) pingHandler(appData string) error { + log.ZWarn(context.Background(), "ping handler recv ping", nil, "remoteAddr", c.conn.RemoteAddr(), "appData", appData) + if err := c.setReadDeadline(); err != nil { + return err + } + err := c.conn.WriteControl(websocket.PongMessage, []byte(appData), time.Now().Add(time.Second*1)) + if err != nil { + log.ZWarn(context.Background(), "ping handler write pong error", err, "remoteAddr", c.conn.RemoteAddr(), "appData", appData) + } + log.ZWarn(context.Background(), "ping handler write pong success", nil, "remoteAddr", c.conn.RemoteAddr(), "appData", appData) + return nil +} + +func (c *websocketClientConn) pongHandler(string) error { + return nil +} + +func (c *websocketClientConn) doPing(d time.Duration) { + ticker := time.NewTicker(d) + defer ticker.Stop() + for { + select { + case <-c.done: + return + case <-ticker.C: + if err := c.writeMessage(websocket.PingMessage, nil); err != nil { + _ = c.closeBy(fmt.Errorf("send ping %w", err)) + return + } + } + } +} diff --git a/internal/msggateway/long_conn.go b/internal/msggateway/long_conn.go deleted file mode 100644 index c1b3e27c9..000000000 --- a/internal/msggateway/long_conn.go +++ /dev/null @@ -1,179 +0,0 @@ -// Copyright © 2023 OpenIM. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package msggateway - -import ( - "encoding/json" - "net/http" - "time" - - "github.com/openimsdk/tools/apiresp" - - "github.com/gorilla/websocket" - "github.com/openimsdk/tools/errs" -) - -type LongConn interface { - // Close this connection - Close() error - // WriteMessage Write message to connection,messageType means data type,can be set binary(2) and text(1). - WriteMessage(messageType int, message []byte) error - // ReadMessage Read message from connection. - ReadMessage() (int, []byte, error) - // SetReadDeadline sets the read deadline on the underlying network connection, - // after a read has timed out, will return an error. - SetReadDeadline(timeout time.Duration) error - // SetWriteDeadline sets to write deadline when send message,when read has timed out,will return error. - SetWriteDeadline(timeout time.Duration) error - // Dial Try to dial a connection,url must set auth args,header can control compress data - Dial(urlStr string, requestHeader http.Header) (*http.Response, error) - // IsNil Whether the connection of the current long connection is nil - IsNil() bool - // SetConnNil Set the connection of the current long connection to nil - SetConnNil() - // SetReadLimit sets the maximum size for a message read from the peer.bytes - SetReadLimit(limit int64) - SetPongHandler(handler PingPongHandler) - SetPingHandler(handler PingPongHandler) - // GenerateLongConn Check the connection of the current and when it was sent are the same - GenerateLongConn(w http.ResponseWriter, r *http.Request) error -} -type GWebSocket struct { - protocolType int - conn *websocket.Conn - handshakeTimeout time.Duration - writeBufferSize int -} - -func newGWebSocket(protocolType int, handshakeTimeout time.Duration, wbs int) *GWebSocket { - return &GWebSocket{protocolType: protocolType, handshakeTimeout: handshakeTimeout, writeBufferSize: wbs} -} - -func (d *GWebSocket) Close() error { - return d.conn.Close() -} - -func (d *GWebSocket) GenerateLongConn(w http.ResponseWriter, r *http.Request) error { - upgrader := &websocket.Upgrader{ - HandshakeTimeout: d.handshakeTimeout, - CheckOrigin: func(r *http.Request) bool { return true }, - } - if d.writeBufferSize > 0 { // default is 4kb. - upgrader.WriteBufferSize = d.writeBufferSize - } - - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - // The upgrader.Upgrade method usually returns enough error messages to diagnose problems that may occur during the upgrade - return errs.WrapMsg(err, "GenerateLongConn: WebSocket upgrade failed") - } - d.conn = conn - return nil -} - -func (d *GWebSocket) WriteMessage(messageType int, message []byte) error { - // d.setSendConn(d.conn) - return d.conn.WriteMessage(messageType, message) -} - -// func (d *GWebSocket) setSendConn(sendConn *websocket.Conn) { -// d.sendConn = sendConn -//} - -func (d *GWebSocket) ReadMessage() (int, []byte, error) { - return d.conn.ReadMessage() -} - -func (d *GWebSocket) SetReadDeadline(timeout time.Duration) error { - return d.conn.SetReadDeadline(time.Now().Add(timeout)) -} - -func (d *GWebSocket) SetWriteDeadline(timeout time.Duration) error { - if timeout <= 0 { - return errs.New("timeout must be greater than 0") - } - - // TODO SetWriteDeadline Future add error handling - if err := d.conn.SetWriteDeadline(time.Now().Add(timeout)); err != nil { - return errs.WrapMsg(err, "GWebSocket.SetWriteDeadline failed") - } - return nil -} - -func (d *GWebSocket) Dial(urlStr string, requestHeader http.Header) (*http.Response, error) { - conn, httpResp, err := websocket.DefaultDialer.Dial(urlStr, requestHeader) - if err != nil { - return httpResp, errs.WrapMsg(err, "GWebSocket.Dial failed", "url", urlStr) - } - d.conn = conn - return httpResp, nil -} - -func (d *GWebSocket) IsNil() bool { - return d.conn == nil - // - // if d.conn != nil { - // return false - // } - // return true -} - -func (d *GWebSocket) SetConnNil() { - d.conn = nil -} - -func (d *GWebSocket) SetReadLimit(limit int64) { - d.conn.SetReadLimit(limit) -} - -func (d *GWebSocket) SetPongHandler(handler PingPongHandler) { - d.conn.SetPongHandler(handler) -} - -func (d *GWebSocket) SetPingHandler(handler PingPongHandler) { - d.conn.SetPingHandler(handler) -} - -func (d *GWebSocket) RespondWithError(err error, w http.ResponseWriter, r *http.Request) error { - if err := d.GenerateLongConn(w, r); err != nil { - return err - } - data, err := json.Marshal(apiresp.ParseError(err)) - if err != nil { - _ = d.Close() - return errs.WrapMsg(err, "json marshal failed") - } - - if err := d.WriteMessage(MessageText, data); err != nil { - _ = d.Close() - return errs.WrapMsg(err, "WriteMessage failed") - } - _ = d.Close() - return nil -} - -func (d *GWebSocket) RespondWithSuccess() error { - data, err := json.Marshal(apiresp.ParseError(nil)) - if err != nil { - _ = d.Close() - return errs.WrapMsg(err, "json marshal failed") - } - - if err := d.WriteMessage(MessageText, data); err != nil { - _ = d.Close() - return errs.WrapMsg(err, "WriteMessage failed") - } - return nil -} diff --git a/internal/msggateway/ws_server.go b/internal/msggateway/ws_server.go index fbe159476..fe729d5fb 100644 --- a/internal/msggateway/ws_server.go +++ b/internal/msggateway/ws_server.go @@ -2,13 +2,17 @@ package msggateway import ( "context" + "encoding/json" "fmt" "net/http" + "strconv" "sync" "sync/atomic" "time" + "github.com/gorilla/websocket" "github.com/openimsdk/open-im-server/v3/pkg/rpcli" + "github.com/openimsdk/tools/apiresp" "github.com/go-playground/validator/v10" "github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics" @@ -42,6 +46,7 @@ type LongConnServer interface { } type WsServer struct { + websocket *websocket.Upgrader msgGatewayConfig *Config port int wsMaxConnNum int64 @@ -131,9 +136,13 @@ func NewWsServer(msgGatewayConfig *Config, opts ...Option) *WsServer { o(&config) } //userRpcClient := rpcclient.NewUserRpcClient(client, config.Share.RpcRegisterName.User, config.Share.IMAdminUserID) - + upgrader := &websocket.Upgrader{ + HandshakeTimeout: config.handshakeTimeout, + CheckOrigin: func(r *http.Request) bool { return true }, + } v := validator.New() return &WsServer{ + websocket: upgrader, msgGatewayConfig: msgGatewayConfig, port: config.port, wsMaxConnNum: config.maxConnNum, @@ -449,6 +458,29 @@ func (ws *WsServer) validateRespWithRequest(ctx *UserConnContext, resp *pbAuth.P return nil } +func (ws *WsServer) handlerError(ctx *UserConnContext, w http.ResponseWriter, r *http.Request, err error) { + if !ctx.ShouldSendResp() { + httpError(ctx, err) + return + } + // the browser cannot get the response of upgrade failure + data, err := json.Marshal(apiresp.ParseError(err)) + if err != nil { + log.ZError(ctx, "json marshal failed", err) + return + } + conn, upgradeErr := ws.websocket.Upgrade(w, r, nil) + if upgradeErr != nil { + log.ZWarn(ctx, "websocket upgrade failed", upgradeErr, "respErr", err, "resp", string(data)) + return + } + defer conn.Close() + if err := conn.WriteMessage(websocket.TextMessage, data); err != nil { + log.ZWarn(ctx, "WriteMessage failed", err, "respErr", err, "resp", string(data)) + return + } +} + func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { // Create a new connection context connContext := newContext(w, r) @@ -456,7 +488,7 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { // Check if the current number of online user connections exceeds the maximum limit if ws.onlineUserConnNum.Load() >= ws.wsMaxConnNum { // If it exceeds the maximum connection number, return an error via HTTP and stop processing - httpError(connContext, servererrs.ErrConnOverMaxNumLimit.WrapMsg("over max conn num limit")) + ws.handlerError(connContext, w, r, servererrs.ErrConnOverMaxNumLimit.WrapMsg("over max conn num limit")) return } @@ -464,26 +496,14 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { err := connContext.ParseEssentialArgs() if err != nil { // If there's an error during parsing, return an error via HTTP and stop processing - - httpError(connContext, err) + ws.handlerError(connContext, w, r, err) return } // Call the authentication client to parse the Token obtained from the context resp, err := ws.authClient.ParseToken(connContext, connContext.GetToken()) if err != nil { - // If there's an error parsing the Token, decide whether to send the error message via WebSocket based on the context flag - shouldSendError := connContext.ShouldSendResp() - if shouldSendError { - // Create a WebSocket connection object and attempt to send the error message via WebSocket - wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout, ws.writeBufferSize) - if err := wsLongConn.RespondWithError(err, w, r); err == nil { - // If the error message is successfully sent via WebSocket, stop processing - return - } - } - // If sending via WebSocket is not required or fails, return the error via HTTP and stop processing - httpError(connContext, err) + ws.handlerError(connContext, w, r, err) return } @@ -491,32 +511,24 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { err = ws.validateRespWithRequest(connContext, resp) if err != nil { // If validation fails, return an error via HTTP and stop processing - httpError(connContext, err) + ws.handlerError(connContext, w, r, err) return } - - log.ZDebug(connContext, "new conn", "token", connContext.GetToken()) - // Create a WebSocket long connection object - wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout, ws.writeBufferSize) - if err := wsLongConn.GenerateLongConn(w, r); err != nil { - //If the creation of the long connection fails, the error is handled internally during the handshake process. - log.ZWarn(connContext, "long connection fails", err) + conn, err := ws.websocket.Upgrade(w, r, nil) + if err != nil { + log.ZWarn(connContext, "websocket upgrade failed", err) return - } else { - // Check if a normal response should be sent via WebSocket - shouldSendSuccessResp := connContext.ShouldSendResp() - if shouldSendSuccessResp { - // Attempt to send a success message through WebSocket - if err := wsLongConn.RespondWithSuccess(); err != nil { - // If the success message is successfully sent, end further processing - return - } - } + } + log.ZDebug(connContext, "new conn", "token", connContext.GetToken()) + + var pingInterval time.Duration + if connContext.GetPlatformID() == strconv.Itoa(constant.WebPlatformID) { + pingInterval = pingPeriod } // Retrieve a client object from the client pool, reset its state, and associate it with the current WebSocket long connection client := ws.clientPool.Get().(*Client) - client.ResetClient(connContext, wsLongConn, ws) + client.ResetClient(connContext, NewWebSocketClientConn(conn, maxMessageSize, pongWait, pingInterval), ws) // Register the client with the server and start message processing ws.registerChan <- client