diff --git a/internal/msggateway/new/client.go b/internal/msggateway/new/client.go index 3feca66af..68178a1da 100644 --- a/internal/msggateway/new/client.go +++ b/internal/msggateway/new/client.go @@ -2,21 +2,20 @@ package new import ( "Open_IM/pkg/common/constant" - promePkg "Open_IM/pkg/common/prometheus" + "Open_IM/pkg/utils" + "bytes" "context" "errors" "fmt" - "github.com/envoyproxy/protoc-gen-validate/validate" "github.com/go-playground/validator/v10" - "open_im_sdk/pkg/log" - "open_im_sdk/pkg/utils" "runtime/debug" "sync" + "time" ) const ( // MessageText is for UTF-8 encoded text messages like JSON. - MessageText = iota + 1 + MessageText = iota + 1 // MessageBinary is for binary messages like protobufs. MessageBinary // CloseMessage denotes a close control message. The optional message @@ -34,48 +33,51 @@ const ( ) type Client struct { - w *sync.Mutex - conn LongConn - PlatformID int32 - PushedMaxSeq uint32 - IsCompress bool - userID string - IsBackground bool - token string - connID string - onlineAt int64 // 上线时间戳(毫秒) - handler MessageHandler - unregisterChan chan *Client + w *sync.Mutex + conn LongConn + PlatformID int32 + PushedMaxSeq uint32 + IsCompress bool + userID string + IsBackground bool + token string + connID string + onlineAt int64 // 上线时间戳(毫秒) + handler MessageHandler + unregisterChan chan *Client compressor Compressor encoder Encoder - userContext UserConnContext - validate *validator.Validate + userContext UserConnContext + validate *validator.Validate + closed bool } -func newClient( conn LongConn,isCompress bool, userID string, isBackground bool, token string, - connID string, onlineAt int64, handler MessageHandler,unregisterChan chan *Client) *Client { +func newClient(conn LongConn, isCompress bool, userID string, isBackground bool, token string, + connID string, onlineAt int64, handler MessageHandler, unregisterChan chan *Client) *Client { return &Client{ - conn: conn, + conn: conn, IsCompress: isCompress, - userID: userID, IsBackground: - isBackground, token: token, - connID: connID, - onlineAt: onlineAt, - handler: handler, - unregisterChan: unregisterChan, + userID: userID, IsBackground: isBackground, token: token, + connID: connID, + onlineAt: onlineAt, + handler: handler, + unregisterChan: unregisterChan, } } -func(c *Client) readMessage(){ +func (c *Client) readMessage() { defer func() { - if r:=recover(); r != nil { + if r := recover(); r != nil { fmt.Println("socket have panic err:", r, string(debug.Stack())) } //c.close() }() - var returnErr error - for { - messageType, message, returnErr := c.conn.ReadMessage() - if returnErr!=nil{ + var returnErr error + for { + messageType, message, returnErr := c.conn.ReadMessage() + if returnErr != nil { + break + } + if c.closed == true { break } switch messageType { @@ -89,7 +91,7 @@ func(c *Client) readMessage(){ continue } returnErr = c.handleMessage(message) - if returnErr!=nil{ + if returnErr != nil { break } @@ -97,52 +99,88 @@ func(c *Client) readMessage(){ } } -func (c *Client) handleMessage(message []byte)error { - if c.IsCompress { +func (c *Client) handleMessage(message []byte) error { + if c.IsCompress { var decompressErr error - message,decompressErr = c.compressor.DeCompress(message) + message, decompressErr = c.compressor.DeCompress(message) if decompressErr != nil { - return utils.Wrap(decompressErr,"") + return utils.Wrap(decompressErr, "") } } - var binaryReq Req + var binaryReq Req err := c.encoder.Decode(message, &binaryReq) if err != nil { - return utils.Wrap(err,"") + return utils.Wrap(err, "") } if err := c.validate.Struct(binaryReq); err != nil { - return utils.Wrap(err,"") + return utils.Wrap(err, "") } if binaryReq.SendID != c.userID { return errors.New("exception conn userID not same to req userID") } - ctx:=context.Background() - ctx =context.WithValue(ctx,"operationID",binaryReq.OperationID) - ctx = context.WithValue(ctx,"userID",binaryReq.SendID) + ctx := context.Background() + ctx = context.WithValue(ctx, "operationID", binaryReq.OperationID) + ctx = context.WithValue(ctx, "userID", binaryReq.SendID) var messageErr error - var resp []byte + var resp []byte switch binaryReq.ReqIdentifier { case constant.WSGetNewestSeq: - resp,messageErr=c.handler.GetSeq(ctx,binaryReq) + resp, messageErr = c.handler.GetSeq(ctx, binaryReq) case constant.WSSendMsg: - resp,messageErr=c.handler.SendMessage(ctx,binaryReq) + resp, messageErr = c.handler.SendMessage(ctx, binaryReq) case constant.WSSendSignalMsg: - resp,messageErr=c.handler.SendSignalMessage(ctx,binaryReq) + resp, messageErr = c.handler.SendSignalMessage(ctx, binaryReq) case constant.WSPullMsgBySeqList: - resp,messageErr=c.handler.PullMessageBySeqList(ctx,binaryReq) + resp, messageErr = c.handler.PullMessageBySeqList(ctx, binaryReq) case constant.WsLogoutMsg: - resp,messageErr=c.handler.UserLogout(ctx,binaryReq) + resp, messageErr = c.handler.UserLogout(ctx, binaryReq) case constant.WsSetBackgroundStatus: - resp,messageErr=c.handler.SetUserDeviceBackground(ctx,binaryReq) + resp, messageErr = c.handler.SetUserDeviceBackground(ctx, binaryReq) default: - return errors.New(fmt.Sprintf("ReqIdentifier failed,sendID:%d,msgIncr:%s,reqIdentifier:%s",binaryReq.SendID,binaryReq.MsgIncr,binaryReq.ReqIdentifier)) + return errors.New(fmt.Sprintf("ReqIdentifier failed,sendID:%d,msgIncr:%s,reqIdentifier:%s", binaryReq.SendID, binaryReq.MsgIncr, binaryReq.ReqIdentifier)) } - + c.replyMessage(binaryReq, messageErr, resp) + return nil } -func (c *Client) close() { +func (c *Client) close() { + c.w.Lock() + defer c.w.Unlock() + c.conn.Close() + c.unregisterChan <- c } -func () { - +func (c *Client) replyMessage(binaryReq Req, err error, resp []byte) { + mReply := Resp{ + ReqIdentifier: binaryReq.ReqIdentifier, + MsgIncr: binaryReq.MsgIncr, + OperationID: binaryReq.OperationID, + Data: resp, + } + _ = c.writeMsg(mReply) +} + +func (c *Client) writeMsg(resp Resp) error { + c.w.Lock() + defer c.w.Unlock() + if c.closed == true { + return nil + } + encodedBuf := bufferPool.Get().([]byte) + resultBuf := bufferPool.Get().([]byte) + encodeBuf, err := c.encoder.Encode(resp) + if err != nil { + return utils.Wrap(err, "") + } + _ = c.conn.SetWriteTimeout(60) + if c.IsCompress { + var compressErr error + resultBuf, compressErr = c.compressor.Compress(encodeBuf) + if compressErr != nil { + return utils.Wrap(compressErr, "") + } + return c.conn.WriteMessage(MessageBinary, resultBuf) + } else { + return c.conn.WriteMessage(MessageBinary, encodedBuf) + } } diff --git a/internal/msggateway/new/message_handler.go b/internal/msggateway/new/message_handler.go index f280992a9..e9285c5d3 100644 --- a/internal/msggateway/new/message_handler.go +++ b/internal/msggateway/new/message_handler.go @@ -10,6 +10,14 @@ type Req struct { MsgIncr string `json:"msgIncr" validate:"required"` Data []byte `json:"data"` } +type Resp struct { + ReqIdentifier int32 `json:"reqIdentifier"` + MsgIncr string `json:"msgIncr"` + OperationID string `json:"operationID"` + ErrCode int32 `json:"errCode"` + ErrMsg string `json:"errMsg"` + Data []byte `json:"data"` +} type MessageHandler interface { GetSeq(context context.Context, data Req) ([]byte, error) SendMessage(context context.Context, data Req) ([]byte, error) diff --git a/internal/msggateway/new/n_ws_server.go b/internal/msggateway/new/n_ws_server.go index 3c08385cd..7f8f1b131 100644 --- a/internal/msggateway/new/n_ws_server.go +++ b/internal/msggateway/new/n_ws_server.go @@ -1,14 +1,22 @@ package new import ( + "bytes" "errors" "github.com/gorilla/websocket" "net/http" "open_im_sdk/pkg/utils" "sync" + "sync/atomic" "time" ) + +var bufferPool = sync.Pool{ + New: func() interface{} { + return make([]byte, 1000) + }, +} type LongConnServer interface { Run() error } @@ -58,6 +66,41 @@ func newWsServer(opts ...Option) (*WsServer, error) { }, nil } func (ws *WsServer) Run() error { + var client *Client + go func() { + for { + select { + case client = <-ws.registerChan: + ws.registerClient(client) + case client = <-h.unregisterChan: + h.unregisterClient(client) + case msg = <-h.readChan: + h.messageHandler(msg) + } + } + }() +} + +func (ws *WsServer) registerClient(client *Client) { + var ( + ok bool + cli *Client + ) + + if cli, ok = h.clients.Get(client.key); ok == false { + h.clients.Set(client.key, client) + atomic.AddInt64(&h.onlineConnections, 1) + fmt.Println("R在线用户数量:", h.onlineConnections) + return + } + + if client.onlineAt > cli.onlineAt { + h.clients.Set(client.key, client) + h.close(cli) + return + } + h.close(client) +} http.HandleFunc("/", ws.wsHandler) //Get request from client to handle by wsHandler return http.ListenAndServe(":"+utils.IntToString(ws.port), nil) //Start listening