From 71f1fcee57a057416753da5eb467190c8ce5071d Mon Sep 17 00:00:00 2001 From: Gordon <1432970085@qq.com> Date: Tue, 14 Feb 2023 21:08:36 +0800 Subject: [PATCH] msggateway refactor --- cmd/msggateway/main.go | 6 +- internal/msggateway/new/client.go | 148 +++++++++++++++++++++ internal/msggateway/new/compressor.go | 44 ++++++ internal/msggateway/new/context.go | 27 ++++ internal/msggateway/new/encoder.go | 37 ++++++ internal/msggateway/new/long_conn.go | 83 ++++++++++++ internal/msggateway/new/message_handler.go | 49 +++++++ internal/msggateway/new/n_ws_server.go | 81 +++++++++++ internal/msggateway/new/options.go | 36 +++++ internal/msggateway/new/user_map.go | 64 +++++++++ 10 files changed, 572 insertions(+), 3 deletions(-) create mode 100644 internal/msggateway/new/client.go create mode 100644 internal/msggateway/new/compressor.go create mode 100644 internal/msggateway/new/context.go create mode 100644 internal/msggateway/new/encoder.go create mode 100644 internal/msggateway/new/long_conn.go create mode 100644 internal/msggateway/new/message_handler.go create mode 100644 internal/msggateway/new/n_ws_server.go create mode 100644 internal/msggateway/new/options.go create mode 100644 internal/msggateway/new/user_map.go diff --git a/cmd/msggateway/main.go b/cmd/msggateway/main.go index 0eb8adb1a..86de2dcc9 100644 --- a/cmd/msggateway/main.go +++ b/cmd/msggateway/main.go @@ -1,7 +1,7 @@ package main import ( - "Open_IM/internal/msg_gateway/gate" + "Open_IM/internal/msggateway" "Open_IM/pkg/common/config" "Open_IM/pkg/common/constant" "Open_IM/pkg/common/log" @@ -22,7 +22,7 @@ func main() { var wg sync.WaitGroup wg.Add(1) fmt.Println("start rpc/msg_gateway server, port: ", *rpcPort, *wsPort, *prometheusPort, ", OpenIM version: ", constant.CurrentVersion, "\n") - gate.Init(*rpcPort, *wsPort) - gate.Run(*prometheusPort) + msggateway.Init(*rpcPort, *wsPort) + msggateway.Run(*prometheusPort) wg.Wait() } diff --git a/internal/msggateway/new/client.go b/internal/msggateway/new/client.go new file mode 100644 index 000000000..3feca66af --- /dev/null +++ b/internal/msggateway/new/client.go @@ -0,0 +1,148 @@ +package new + +import ( + "Open_IM/pkg/common/constant" + promePkg "Open_IM/pkg/common/prometheus" + "context" + "errors" + "fmt" + "github.com/envoyproxy/protoc-gen-validate/validate" + "github.com/go-playground/validator/v10" + "open_im_sdk/pkg/log" + "open_im_sdk/pkg/utils" + "runtime/debug" + "sync" +) + +const ( + // MessageText is for UTF-8 encoded text messages like JSON. + MessageText = iota + 1 + // MessageBinary is for binary messages like protobufs. + MessageBinary + // CloseMessage denotes a close control message. The optional message + // payload contains a numeric code and text. Use the FormatCloseMessage + // function to format a close message payload. + CloseMessage = 8 + + // PingMessage denotes a ping control message. The optional message payload + // is UTF-8 encoded text. + PingMessage = 9 + + // PongMessage denotes a pong control message. The optional message payload + // is UTF-8 encoded text. + PongMessage = 10 +) + +type Client struct { + w *sync.Mutex + conn LongConn + PlatformID int32 + PushedMaxSeq uint32 + IsCompress bool + userID string + IsBackground bool + token string + connID string + onlineAt int64 // 上线时间戳(毫秒) + handler MessageHandler + unregisterChan chan *Client + compressor Compressor + encoder Encoder + userContext UserConnContext + validate *validator.Validate +} + +func newClient( conn LongConn,isCompress bool, userID string, isBackground bool, token string, + connID string, onlineAt int64, handler MessageHandler,unregisterChan chan *Client) *Client { + return &Client{ + conn: conn, + IsCompress: isCompress, + userID: userID, IsBackground: + isBackground, token: token, + connID: connID, + onlineAt: onlineAt, + handler: handler, + unregisterChan: unregisterChan, + } +} +func(c *Client) readMessage(){ + defer func() { + if r:=recover(); r != nil { + fmt.Println("socket have panic err:", r, string(debug.Stack())) + } + //c.close() + }() + var returnErr error + for { + messageType, message, returnErr := c.conn.ReadMessage() + if returnErr!=nil{ + break + } + switch messageType { + case PingMessage: + case PongMessage: + case CloseMessage: + return + case MessageText: + case MessageBinary: + if len(message) == 0 { + continue + } + returnErr = c.handleMessage(message) + if returnErr!=nil{ + break + } + + } + } + +} +func (c *Client) handleMessage(message []byte)error { + if c.IsCompress { + var decompressErr error + message,decompressErr = c.compressor.DeCompress(message) + if decompressErr != nil { + return utils.Wrap(decompressErr,"") + } + } + var binaryReq Req + err := c.encoder.Decode(message, &binaryReq) + if err != nil { + return utils.Wrap(err,"") + } + if err := c.validate.Struct(binaryReq); err != nil { + return utils.Wrap(err,"") + } + if binaryReq.SendID != c.userID { + return errors.New("exception conn userID not same to req userID") + } + ctx:=context.Background() + ctx =context.WithValue(ctx,"operationID",binaryReq.OperationID) + ctx = context.WithValue(ctx,"userID",binaryReq.SendID) + var messageErr error + var resp []byte + switch binaryReq.ReqIdentifier { + case constant.WSGetNewestSeq: + resp,messageErr=c.handler.GetSeq(ctx,binaryReq) + case constant.WSSendMsg: + resp,messageErr=c.handler.SendMessage(ctx,binaryReq) + case constant.WSSendSignalMsg: + resp,messageErr=c.handler.SendSignalMessage(ctx,binaryReq) + case constant.WSPullMsgBySeqList: + resp,messageErr=c.handler.PullMessageBySeqList(ctx,binaryReq) + case constant.WsLogoutMsg: + resp,messageErr=c.handler.UserLogout(ctx,binaryReq) + case constant.WsSetBackgroundStatus: + resp,messageErr=c.handler.SetUserDeviceBackground(ctx,binaryReq) + default: + return errors.New(fmt.Sprintf("ReqIdentifier failed,sendID:%d,msgIncr:%s,reqIdentifier:%s",binaryReq.SendID,binaryReq.MsgIncr,binaryReq.ReqIdentifier)) + } + + +} +func (c *Client) close() { + +} +func () { + +} diff --git a/internal/msggateway/new/compressor.go b/internal/msggateway/new/compressor.go new file mode 100644 index 000000000..42abcccd1 --- /dev/null +++ b/internal/msggateway/new/compressor.go @@ -0,0 +1,44 @@ +package new + +import ( + "bytes" + "compress/gzip" + "io/ioutil" + "open_im_sdk/pkg/utils" +) + +type Compressor interface { + Compress(rawData []byte) ([]byte, error) + DeCompress(compressedData []byte) ([]byte, error) +} +type GzipCompressor struct { + compressProtocol string +} + +func NewGzipCompressor() *GzipCompressor { + return &GzipCompressor{compressProtocol: "gzip"} +} +func (g *GzipCompressor) Compress(rawData []byte) ([]byte, error) { + gzipBuffer := bytes.Buffer{} + gz := gzip.NewWriter(&gzipBuffer) + if _, err := gz.Write(rawData); err != nil { + return nil, utils.Wrap(err, "") + } + if err := gz.Close(); err != nil { + return nil, utils.Wrap(err, "") + } + return gzipBuffer.Bytes(), nil +} +func (g *GzipCompressor) DeCompress(compressedData []byte) ([]byte, error) { + buff := bytes.NewBuffer(compressedData) + reader, err := gzip.NewReader(buff) + if err != nil { + return nil, utils.Wrap(err, "NewReader failed") + } + compressedData, err = ioutil.ReadAll(reader) + if err != nil { + return nil, utils.Wrap(err, "ReadAll failed") + } + _ = reader.Close() + return compressedData, nil +} diff --git a/internal/msggateway/new/context.go b/internal/msggateway/new/context.go new file mode 100644 index 000000000..9ab353351 --- /dev/null +++ b/internal/msggateway/new/context.go @@ -0,0 +1,27 @@ +package new + +import "net/http" + +type UserConnContext struct { + RespWriter http.ResponseWriter + Req *http.Request + Path string + Method string + RemoteAddr string +} + +func newContext(respWriter http.ResponseWriter, req *http.Request) *UserConnContext { + return &UserConnContext{ + RespWriter: respWriter, + Req: req, + Path: req.URL.Path, + Method: req.Method, + RemoteAddr: req.RemoteAddr, + } +} +func (c *UserConnContext) Query(key string) string { + return c.Req.URL.Query().Get(key) +} +func (c *UserConnContext) GetHeader(key string) string { + return c.Req.Header.Get(key) +} diff --git a/internal/msggateway/new/encoder.go b/internal/msggateway/new/encoder.go new file mode 100644 index 000000000..10f369d2b --- /dev/null +++ b/internal/msggateway/new/encoder.go @@ -0,0 +1,37 @@ +package new + +import ( + "bytes" + "encoding/gob" + "open_im_sdk/pkg/utils" +) + +type Encoder interface { + Encode(data interface{}) ([]byte, error) + Decode(encodeData []byte, decodeData interface{}) error +} + +type GobEncoder struct { +} + +func NewGobEncoder() *GobEncoder { + return &GobEncoder{} +} +func (g *GobEncoder) Encode(data interface{}) ([]byte, error) { + buff := bytes.Buffer{} + enc := gob.NewEncoder(&buff) + err := enc.Encode(data) + if err != nil { + return nil, err + } + return buff.Bytes(), nil +} +func (g *GobEncoder) Decode(encodeData []byte, decodeData interface{}) error { + buff := bytes.NewBuffer(encodeData) + dec := gob.NewDecoder(buff) + err := dec.Decode(decodeData) + if err != nil { + return utils.Wrap(err, "") + } + return nil +} diff --git a/internal/msggateway/new/long_conn.go b/internal/msggateway/new/long_conn.go new file mode 100644 index 000000000..823964e40 --- /dev/null +++ b/internal/msggateway/new/long_conn.go @@ -0,0 +1,83 @@ +package new + +import ( + "github.com/gorilla/websocket" + "net/http" + "time" +) + +type LongConn interface { + //Close this connection + Close() error + //Write message to connection,messageType means data type,can be set binary(2) and text(1). + WriteMessage(messageType int, message []byte) error + //Read message from connection. + ReadMessage() (int, []byte, error) + //SetReadTimeout sets the read deadline on the underlying network connection, + //after a read has timed out, will return an error. + SetReadTimeout(timeout int) error + //SetWriteTimeout sets the write deadline when send message,when read has timed out,will return error. + SetWriteTimeout(timeout int) error + //Try to dial a connection,url must set auth args,header can control compress data + Dial(urlStr string, requestHeader http.Header) (*http.Response, error) + //Whether the connection of the current long connection is nil + IsNil() bool + //Set the connection of the current long connection to nil + SetConnNil() + //Check the connection of the current and when it was sent are the same + CheckSendConnDiffNow() bool +} +type GWebSocket struct { + protocolType int + conn *websocket.Conn +} + +func NewDefault(protocolType int) *GWebSocket { + return &GWebSocket{protocolType: protocolType} +} +func (d *GWebSocket) Close() error { + return d.conn.Close() +} + +func (d *GWebSocket) WriteMessage(messageType int, message []byte) error { + d.setSendConn(d.conn) + return d.conn.WriteMessage(messageType, message) +} + +func (d *GWebSocket) setSendConn(sendConn *websocket.Conn) { + d.sendConn = sendConn +} + +func (d *GWebSocket) ReadMessage() (int, []byte, error) { + return d.conn.ReadMessage() +} +func (d *GWebSocket) SetReadTimeout(timeout int) error { + return d.conn.SetReadDeadline(time.Now().Add(time.Duration(timeout) * time.Second)) +} + +func (d *GWebSocket) SetWriteTimeout(timeout int) error { + return d.conn.SetWriteDeadline(time.Now().Add(time.Duration(timeout) * time.Second)) +} + +func (d *GWebSocket) Dial(urlStr string, requestHeader http.Header) (*http.Response, error) { + conn, httpResp, err := websocket.DefaultDialer.Dial(urlStr, requestHeader) + if err == nil { + d.conn = conn + } + return httpResp, err + +} + +func (d *GWebSocket) IsNil() bool { + if d.conn != nil { + return false + } + return true +} + +func (d *GWebSocket) SetConnNil() { + d.conn = nil +} +func (d *GWebSocket) CheckSendConnDiffNow() bool { + return d.conn == d.sendConn +} diff --git a/internal/msggateway/new/message_handler.go b/internal/msggateway/new/message_handler.go new file mode 100644 index 000000000..f280992a9 --- /dev/null +++ b/internal/msggateway/new/message_handler.go @@ -0,0 +1,49 @@ +package new + +import "context" + +type Req struct { + ReqIdentifier int32 `json:"reqIdentifier" validate:"required"` + Token string `json:"token" ` + SendID string `json:"sendID" validate:"required"` + OperationID string `json:"operationID" validate:"required"` + MsgIncr string `json:"msgIncr" validate:"required"` + Data []byte `json:"data"` +} +type MessageHandler interface { + GetSeq(context context.Context, data Req) ([]byte, error) + SendMessage(context context.Context, data Req) ([]byte, error) + SendSignalMessage(context context.Context, data Req) ([]byte, error) + PullMessageBySeqList(context context.Context, data Req) ([]byte, error) + UserLogout(context context.Context, data Req) ([]byte, error) + SetUserDeviceBackground(context context.Context, data Req) ([]byte, error) +} + +var _ MessageHandler = (*GrpcHandler)(nil) + +type GrpcHandler struct { +} + +func (g GrpcHandler) GetSeq(context context.Context, data Req) ([]byte, error) { + panic("implement me") +} + +func (g GrpcHandler) SendMessage(context context.Context, data Req) ([]byte, error) { + panic("implement me") +} + +func (g GrpcHandler) SendSignalMessage(context context.Context, data Req) ([]byte, error) { + panic("implement me") +} + +func (g GrpcHandler) PullMessageBySeqList(context context.Context, data Req) ([]byte, error) { + panic("implement me") +} + +func (g GrpcHandler) UserLogout(context context.Context, data Req) ([]byte, error) { + panic("implement me") +} + +func (g GrpcHandler) SetUserDeviceBackground(context context.Context, data Req) ([]byte, error) { + panic("implement me") +} diff --git a/internal/msggateway/new/n_ws_server.go b/internal/msggateway/new/n_ws_server.go new file mode 100644 index 000000000..3c08385cd --- /dev/null +++ b/internal/msggateway/new/n_ws_server.go @@ -0,0 +1,81 @@ +package new + +import ( + "errors" + "github.com/gorilla/websocket" + "net/http" + "open_im_sdk/pkg/utils" + "sync" + "time" +) + +type LongConnServer interface { + Run() error +} + +type Server struct { + rpcPort int + wsMaxConnNum int + longConnServer *LongConnServer + rpcServer *RpcServer +} +type WsServer struct { + port int + wsMaxConnNum int + wsUpGrader *websocket.Upgrader + registerChan chan *Client + unregisterChan chan *Client + clients *UserMap + clientPool sync.Pool + onlineUserNum int64 + onlineUserConnNum int64 + compressor Compressor + handler MessageHandler +} + +func newWsServer(opts ...Option) (*WsServer, error) { + var config configs + for _, o := range opts { + o(&config) + } + if config.port < 1024 { + return nil, errors.New("port not allow to listen") + + } + return &WsServer{ + port: config.port, + wsMaxConnNum: config.maxConnNum, + wsUpGrader: &websocket.Upgrader{ + HandshakeTimeout: config.handshakeTimeout, + ReadBufferSize: config.messageMaxMsgLength, + CheckOrigin: func(r *http.Request) bool { return true }, + }, + clientPool: sync.Pool{ + New: func() interface{} { + return new(Client) + }, + }, + }, nil +} +func (ws *WsServer) Run() error { + http.HandleFunc("/", ws.wsHandler) //Get request from client to handle by wsHandler + return http.ListenAndServe(":"+utils.IntToString(ws.port), nil) //Start listening + +} +func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { + context := newContext(w, r) + if isPass, compression := ws.headerCheck(w, r, operationID); isPass { + conn, err := ws.wsUpGrader.Upgrade(w, r, nil) //Conn is obtained through the upgraded escalator + if err != nil { + log.Error(operationID, "upgrade http conn err", err.Error(), query) + return + } else { + newConn := &UserConn{conn, new(sync.Mutex), utils.StringToInt32(query["platformID"][0]), 0, compression, query["sendID"][0], false, query["token"][0], conn.RemoteAddr().String() + "_" + strconv.Itoa(int(utils.GetCurrentTimestampByMill()))} + userCount++ + ws.addUserConn(query["sendID"][0], utils.StringToInt(query["platformID"][0]), newConn, query["token"][0], newConn.connID, operationID) + go ws.readMsg(newConn) + } + } else { + log.Error(operationID, "headerCheck failed ") + } +} diff --git a/internal/msggateway/new/options.go b/internal/msggateway/new/options.go new file mode 100644 index 000000000..71a732751 --- /dev/null +++ b/internal/msggateway/new/options.go @@ -0,0 +1,36 @@ +package new + +import "time" + +type Option func(opt *configs) +type configs struct { + //长连接监听端口 + port int + //长连接允许最大链接数 + maxConnNum int + //连接握手超时时间 + handshakeTimeout time.Duration + //允许消息最大长度 + messageMaxMsgLength int +} + +func WithPort(port int) Option { + return func(opt *configs) { + opt.port = port + } +} +func WithMaxConnNum(num int) Option { + return func(opt *configs) { + opt.maxConnNum = num + } +} +func WithHandshakeTimeout(t time.Duration) Option { + return func(opt *configs) { + opt.handshakeTimeout = t + } +} +func WithMessageMaxMsgLength(length int) Option { + return func(opt *configs) { + opt.messageMaxMsgLength = length + } +} diff --git a/internal/msggateway/new/user_map.go b/internal/msggateway/new/user_map.go new file mode 100644 index 000000000..82615e827 --- /dev/null +++ b/internal/msggateway/new/user_map.go @@ -0,0 +1,64 @@ +package new + +import "sync" + +type UserMap struct { + m sync.Map +} + +func newUserMap() *UserMap { + return &UserMap{} +} +func (u *UserMap) GetAll(key string) []*Client { + allClients, ok := u.m.Load(key) + if ok { + return allClients.([]*Client) + } + return nil +} +func (u *UserMap) Get(key string, platformID int32) (*Client, bool) { + allClients, existed := u.m.Load(key) + if existed { + for _, client := range allClients.([]*Client) { + if client.PlatformID == platformID { + return client, existed + } + } + return nil, false + } + return nil, existed +} +func (u *UserMap) Set(key string, v *Client) { + allClients, existed := u.m.Load(key) + if existed { + oldClients := allClients.([]*Client) + oldClients = append(oldClients, v) + u.m.Store(key, oldClients) + } else { + clients := make([]*Client, 3) + clients = append(clients, v) + u.m.Store(key, clients) + } +} +func (u *UserMap) delete(key string, platformID int32) { + allClients, existed := u.m.Load(key) + if existed { + oldClients := allClients.([]*Client) + + a := make([]*Client, len(oldClients)) + for _, client := range oldClients { + if client.PlatformID != platformID { + a = append(a, client) + } + } + if len(a) == 0 { + u.m.Delete(key) + } else { + u.m.Store(key, a) + + } + } +} +func (u *UserMap) DeleteAll(key string) { + u.m.Delete(key) +}