From d54cab5d037398c2486314de1aa08bc9f356dd97 Mon Sep 17 00:00:00 2001 From: Gordon <1432970085@qq.com> Date: Wed, 22 Feb 2023 21:06:55 +0800 Subject: [PATCH] gateway update --- internal/common/check/msg.go | 37 +++++++ internal/common/notification/c.go | 6 +- internal/common/notification/extend_msg.go | 2 +- internal/msggateway/new/client.go | 25 +++-- internal/msggateway/new/compressor.go | 1 + internal/msggateway/new/constant.go | 4 +- internal/msggateway/new/context.go | 2 +- internal/msggateway/new/encoder.go | 1 + internal/msggateway/new/long_conn.go | 17 ++-- internal/msggateway/new/message_handler.go | 110 ++++++++++++++++++--- internal/msggateway/new/n_ws_server.go | 66 ++++++------- 11 files changed, 203 insertions(+), 68 deletions(-) diff --git a/internal/common/check/msg.go b/internal/common/check/msg.go index 010ee1606..06dd82c3e 100644 --- a/internal/common/check/msg.go +++ b/internal/common/check/msg.go @@ -4,6 +4,7 @@ import ( "Open_IM/pkg/common/config" discoveryRegistry "Open_IM/pkg/discoveryregistry" "Open_IM/pkg/proto/msg" + "Open_IM/pkg/proto/sdkws" "context" "google.golang.org/grpc" ) @@ -28,3 +29,39 @@ func (m *MsgCheck) SendMsg(ctx context.Context, req *msg.SendMsgReq) (*msg.SendM resp, err := msg.NewMsgClient(cc).SendMsg(ctx, req) return resp, err } + +func (m *MsgCheck) GetMaxAndMinSeq(ctx context.Context, req *sdkws.GetMaxAndMinSeqReq) (*sdkws.GetMaxAndMinSeqResp, error) { + cc, err := m.getConn() + if err != nil { + return nil, err + } + resp, err := msg.NewMsgClient(cc).GetMaxAndMinSeq(ctx, req) + return resp, err +} + +func (m *MsgCheck) PullMessageBySeqList(ctx context.Context, req *sdkws.PullMessageBySeqListReq) (*sdkws.PullMessageBySeqListResp, error) { + cc, err := m.getConn() + if err != nil { + return nil, err + } + resp, err := msg.NewMsgClient(cc).PullMessageBySeqList(ctx, req) + return resp, err +} + +//func (m *MsgCheck) SendMsg(ctx context.Context, req *msg.SendMsgReq) (*msg.SendMsgResp, error) { +// cc, err := m.getConn() +// if err != nil { +// return nil, err +// } +// resp, err := msg.NewMsgClient(cc).SendMsg(ctx, req) +// return resp, err +//} +// +//func (m *MsgCheck) SendMsg(ctx context.Context, req *msg.SendMsgReq) (*msg.SendMsgResp, error) { +// cc, err := m.getConn() +// if err != nil { +// return nil, err +// } +// resp, err := msg.NewMsgClient(cc).SendMsg(ctx, req) +// return resp, err +//} diff --git a/internal/common/notification/c.go b/internal/common/notification/c.go index 52f6853b2..fe765c41f 100644 --- a/internal/common/notification/c.go +++ b/internal/common/notification/c.go @@ -16,7 +16,7 @@ import ( type Check struct { user *check.UserCheck group *check.GroupChecker - msg *check.MsgCheck + Msg *check.MsgCheck friend *check.FriendChecker conversation *check.ConversationChecker } @@ -25,7 +25,7 @@ func NewCheck(zk discoveryRegistry.SvcDiscoveryRegistry) *Check { return &Check{ user: check.NewUserCheck(zk), group: check.NewGroupChecker(zk), - msg: check.NewMsgCheck(zk), + Msg: check.NewMsgCheck(zk), friend: check.NewFriendChecker(zk), conversation: check.NewConversationChecker(zk), } @@ -301,5 +301,5 @@ func (c *Check) Notification(ctx context.Context, notificationMsg *NotificationM msg.OfflinePushInfo = &offlineInfo req.MsgData = &msg - _, err = c.msg.SendMsg(ctx, &req) + _, err = c.Msg.SendMsg(ctx, &req) } diff --git a/internal/common/notification/extend_msg.go b/internal/common/notification/extend_msg.go index 18f0d7b91..40416f30d 100644 --- a/internal/common/notification/extend_msg.go +++ b/internal/common/notification/extend_msg.go @@ -87,5 +87,5 @@ func (c *Check) messageReactionSender(ctx context.Context, sendID string, source case constant.GroupChatType, constant.SuperGroupChatType: pbData.MsgData.GroupID = sourceID } - _, err = c.msg.SendMsg(ctx, &pbData) + _, err = c.Msg.SendMsg(ctx, &pbData) } diff --git a/internal/msggateway/new/client.go b/internal/msggateway/new/client.go index 61ade23fe..85dd68a00 100644 --- a/internal/msggateway/new/client.go +++ b/internal/msggateway/new/client.go @@ -86,7 +86,7 @@ func (c *Client) readMessage() { } //c.close() }() - var returnErr error + //var returnErr error for { messageType, message, returnErr := c.conn.ReadMessage() if returnErr != nil { @@ -115,7 +115,7 @@ func (c *Client) readMessage() { } func (c *Client) handleMessage(message []byte) error { - if c.IsCompress { + if c.isCompress { var decompressErr error message, decompressErr = c.compressor.DeCompress(message) if decompressErr != nil { @@ -134,9 +134,10 @@ func (c *Client) handleMessage(message []byte) error { return errors.New("exception conn userID not same to req userID") } ctx := context.Background() - ctx = context.WithValue(ctx, c.connID, binaryReq.OperationID) + ctx = context.WithValue(ctx, CONN_ID, c.connID) ctx = context.WithValue(ctx, OPERATION_ID, binaryReq.OperationID) - ctx = context.WithValue(ctx, "userID", binaryReq.SendID) + ctx = context.WithValue(ctx, COMMON_USERID, binaryReq.SendID) + ctx = context.WithValue(ctx, PLATFORM_ID, c.platformID) var messageErr error var resp []byte switch binaryReq.ReqIdentifier { @@ -151,13 +152,23 @@ func (c *Client) handleMessage(message []byte) error { case constant.WsLogoutMsg: resp, messageErr = c.handler.UserLogout(ctx, binaryReq) case constant.WsSetBackgroundStatus: - resp, messageErr = c.handler.SetUserDeviceBackground(ctx, binaryReq) + resp, messageErr = c.setAppBackgroundStatus(ctx, binaryReq) default: return errors.New(fmt.Sprintf("ReqIdentifier failed,sendID:%d,msgIncr:%s,reqIdentifier:%s", binaryReq.SendID, binaryReq.MsgIncr, binaryReq.ReqIdentifier)) } - c.replyMessage(binaryReq, messageErr, resp) + c.replyMessage(&binaryReq, messageErr, resp) return nil +} +func (c *Client) setAppBackgroundStatus(ctx context.Context, req Req) ([]byte, error) { + resp, isBackground, messageErr := c.handler.SetUserDeviceBackground(ctx, req) + if messageErr != nil { + return nil, messageErr + } + c.isBackground = isBackground + //todo callback + return resp, nil + } func (c *Client) close() { c.w.Lock() @@ -166,7 +177,7 @@ func (c *Client) close() { c.unregisterChan <- c } -func (c *Client) replyMessage(binaryReq Req, err error, resp []byte) { +func (c *Client) replyMessage(binaryReq *Req, err error, resp []byte) { mReply := Resp{ ReqIdentifier: binaryReq.ReqIdentifier, MsgIncr: binaryReq.MsgIncr, diff --git a/internal/msggateway/new/compressor.go b/internal/msggateway/new/compressor.go index fbc9ea7bf..c36e381d3 100644 --- a/internal/msggateway/new/compressor.go +++ b/internal/msggateway/new/compressor.go @@ -1,6 +1,7 @@ package new import ( + "Open_IM/pkg/utils" "bytes" "compress/gzip" "io/ioutil" diff --git a/internal/msggateway/new/constant.go b/internal/msggateway/new/constant.go index b9e5c6aea..6b646e24e 100644 --- a/internal/msggateway/new/constant.go +++ b/internal/msggateway/new/constant.go @@ -1,8 +1,10 @@ package new const ( - USERID = "sendID" + WS_USERID = "sendID" + COMMON_USERID = "userID" PLATFORM_ID = "platformID" + CONN_ID = "connID" TOKEN = "token" OPERATION_ID = "operationID" COMPRESSION = "compression" diff --git a/internal/msggateway/new/context.go b/internal/msggateway/new/context.go index 45e923a34..39b44e0e8 100644 --- a/internal/msggateway/new/context.go +++ b/internal/msggateway/new/context.go @@ -47,7 +47,7 @@ func (c *UserConnContext) GetConnID() string { return c.RemoteAddr + "_" + strconv.Itoa(int(utils.GetCurrentTimestampByMill())) } func (c *UserConnContext) GetUserID() string { - return c.Req.URL.Query().Get(USERID) + return c.Req.URL.Query().Get(WS_USERID) } func (c *UserConnContext) GetPlatformID() string { return c.Req.URL.Query().Get(PLATFORM_ID) diff --git a/internal/msggateway/new/encoder.go b/internal/msggateway/new/encoder.go index 522d40c6a..bbc3d5c86 100644 --- a/internal/msggateway/new/encoder.go +++ b/internal/msggateway/new/encoder.go @@ -1,6 +1,7 @@ package new import ( + "Open_IM/pkg/utils" "bytes" "encoding/gob" ) diff --git a/internal/msggateway/new/long_conn.go b/internal/msggateway/new/long_conn.go index fd33ea615..8845504a7 100644 --- a/internal/msggateway/new/long_conn.go +++ b/internal/msggateway/new/long_conn.go @@ -25,7 +25,7 @@ type LongConn interface { //Set the connection of the current long connection to nil SetConnNil() //Check the connection of the current and when it was sent are the same - CheckSendConnDiffNow() bool + //CheckSendConnDiffNow() bool // GenerateLongConn(w http.ResponseWriter, r *http.Request) error } @@ -58,13 +58,13 @@ func (d *GWebSocket) GenerateLongConn(w http.ResponseWriter, r *http.Request) er } func (d *GWebSocket) WriteMessage(messageType int, message []byte) error { - d.setSendConn(d.conn) + //d.setSendConn(d.conn) return d.conn.WriteMessage(messageType, message) } -func (d *GWebSocket) setSendConn(sendConn *websocket.Conn) { - d.sendConn = sendConn -} +//func (d *GWebSocket) setSendConn(sendConn *websocket.Conn) { +// d.sendConn = sendConn +//} func (d *GWebSocket) ReadMessage() (int, []byte, error) { return d.conn.ReadMessage() @@ -96,6 +96,7 @@ func (d *GWebSocket) IsNil() bool { func (d *GWebSocket) SetConnNil() { d.conn = nil } -func (d *GWebSocket) CheckSendConnDiffNow() bool { - return d.conn == d.sendConn -} + +//func (d *GWebSocket) CheckSendConnDiffNow() bool { +// return d.conn == d.sendConn +//} diff --git a/internal/msggateway/new/message_handler.go b/internal/msggateway/new/message_handler.go index 52e36414d..f75e16c05 100644 --- a/internal/msggateway/new/message_handler.go +++ b/internal/msggateway/new/message_handler.go @@ -1,8 +1,13 @@ package new import ( - "Open_IM/internal/common/check" + "Open_IM/internal/common/notification" + "Open_IM/pkg/proto/msg" + pbRtc "Open_IM/pkg/proto/rtc" + "Open_IM/pkg/proto/sdkws" "context" + "github.com/go-playground/validator/v10" + "github.com/golang/protobuf/proto" ) type Req struct { @@ -27,39 +32,118 @@ type MessageHandler interface { SendSignalMessage(context context.Context, data Req) ([]byte, error) PullMessageBySeqList(context context.Context, data Req) ([]byte, error) UserLogout(context context.Context, data Req) ([]byte, error) - SetUserDeviceBackground(context context.Context, data Req) ([]byte, error) + SetUserDeviceBackground(context context.Context, data Req) ([]byte, bool, error) } var _ MessageHandler = (*GrpcHandler)(nil) type GrpcHandler struct { - msg *check.MsgCheck + notification *notification.Check + validate *validator.Validate } -func NewGrpcHandler(msg *check.MsgCheck) *GrpcHandler { - return &GrpcHandler{msg: msg} +func NewGrpcHandler(validate *validator.Validate, notification *notification.Check) *GrpcHandler { + return &GrpcHandler{notification: notification, validate: validate} } func (g GrpcHandler) GetSeq(context context.Context, data Req) ([]byte, error) { - panic("implement me") + req := sdkws.GetMaxAndMinSeqReq{} + if err := proto.Unmarshal(data.Data, &req); err != nil { + return nil, err + } + if err := g.validate.Struct(req); err != nil { + return nil, err + } + resp, err := g.notification.Msg.GetMaxAndMinSeq(context, &req) + if err != nil { + return nil, err + } + c, err := proto.Marshal(resp) + if err != nil { + return nil, err + } + return c, nil } func (g GrpcHandler) SendMessage(context context.Context, data Req) ([]byte, error) { - panic("implement me") + msgData := sdkws.MsgData{} + if err := proto.Unmarshal(data.Data, &msgData); err != nil { + return nil, err + } + if err := g.validate.Struct(msgData); err != nil { + return nil, err + } + req := msg.SendMsgReq{MsgData: &msgData} + resp, err := g.notification.Msg.SendMsg(context, &req) + if err != nil { + return nil, err + } + c, err := proto.Marshal(resp) + if err != nil { + return nil, err + } + return c, nil } func (g GrpcHandler) SendSignalMessage(context context.Context, data Req) ([]byte, error) { - panic("implement me") + signalReq := pbRtc.SignalReq{} + if err := proto.Unmarshal(data.Data, &signalReq); err != nil { + return nil, err + } + if err := g.validate.Struct(signalReq); err != nil { + return nil, err + } + //req := pbRtc.SignalMessageAssembleReq{SignalReq: &signalReq, OperationID: "111"} + //todo rtc rpc call + resp, err := g.notification.Msg.SendMsg(context, nil) + if err != nil { + return nil, err + } + c, err := proto.Marshal(resp) + if err != nil { + return nil, err + } + return c, nil } func (g GrpcHandler) PullMessageBySeqList(context context.Context, data Req) ([]byte, error) { - panic("implement me") + req := sdkws.PullMessageBySeqListReq{} + if err := proto.Unmarshal(data.Data, &req); err != nil { + return nil, err + } + if err := g.validate.Struct(data); err != nil { + return nil, err + } + resp, err := g.notification.Msg.PullMessageBySeqList(context, &req) + if err != nil { + return nil, err + } + c, err := proto.Marshal(resp) + if err != nil { + return nil, err + } + return c, nil } func (g GrpcHandler) UserLogout(context context.Context, data Req) ([]byte, error) { - panic("implement me") + //todo + resp, err := g.notification.Msg.PullMessageBySeqList(context, nil) + if err != nil { + return nil, err + } + c, err := proto.Marshal(resp) + if err != nil { + return nil, err + } + return c, nil } - -func (g GrpcHandler) SetUserDeviceBackground(context context.Context, data Req) ([]byte, error) { - panic("implement me") +func (g GrpcHandler) SetUserDeviceBackground(_ context.Context, data Req) ([]byte, bool, error) { + req := sdkws.SetAppBackgroundStatusReq{} + if err := proto.Unmarshal(data.Data, &req); err != nil { + return nil, false, err + } + if err := g.validate.Struct(data); err != nil { + return nil, false, err + } + return nil, req.IsBackground, nil } diff --git a/internal/msggateway/new/n_ws_server.go b/internal/msggateway/new/n_ws_server.go index 65ef48460..83d7a31e5 100644 --- a/internal/msggateway/new/n_ws_server.go +++ b/internal/msggateway/new/n_ws_server.go @@ -2,6 +2,7 @@ package new import ( "Open_IM/pkg/common/constant" + "Open_IM/pkg/common/tokenverify" "Open_IM/pkg/utils" "errors" "fmt" @@ -15,7 +16,7 @@ import ( var bufferPool = sync.Pool{ New: func() interface{} { - return make([]byte, 1000) + return make([]byte, 1024) }, } @@ -27,7 +28,7 @@ type Server struct { rpcPort int wsMaxConnNum int longConnServer *LongConnServer - rpcServer *RpcServer + //rpcServer *RpcServer } type WsServer struct { port int @@ -40,11 +41,11 @@ type WsServer struct { onlineUserNum int64 onlineUserConnNum int64 gzipCompressor Compressor - encoder Encoder + encoder Encoder handler MessageHandler handshakeTimeout time.Duration readBufferSize, WriteBufferSize int - validate *validator.Validate + validate *validator.Validate } func newWsServer(opts ...Option) (*WsServer, error) { @@ -67,8 +68,8 @@ func newWsServer(opts ...Option) (*WsServer, error) { }, }, validate: validator.New(), - clients: newUserMap(), - handler: NewGrpcHandler(), + clients: newUserMap(), + //handler: NewGrpcHandler(validate), }, nil } func (ws *WsServer) Run() error { @@ -83,29 +84,29 @@ func (ws *WsServer) Run() error { } } }() - http.HandleFunc("/", ws.wsHandler) //Get request from client to handle by wsHandler + http.HandleFunc("/", ws.wsHandler) //Get request from client to handle by wsHandler return http.ListenAndServe(":"+utils.IntToString(ws.port), nil) //Start listening } func (ws *WsServer) registerClient(client *Client) { var ( - userOK bool + userOK bool clientOK bool - cli *Client + cli *Client ) - cli, userOK,clientOK = ws.clients.Get(client.userID,client.platformID) - if !userOK { - ws.clients.Set(client.userID,client) + cli, userOK, clientOK = ws.clients.Get(client.userID, client.platformID) + if !userOK { + ws.clients.Set(client.userID, client) atomic.AddInt64(&ws.onlineUserNum, 1) atomic.AddInt64(&ws.onlineUserConnNum, 1) fmt.Println("R在线用户数量:", ws.onlineUserNum) fmt.Println("R在线用户连接数量:", ws.onlineUserConnNum) - }else{ - if clientOK {//已经有同平台的连接存在 - ws.clients.Set(client.userID,client) + } else { + if clientOK { //已经有同平台的连接存在 + ws.clients.Set(client.userID, client) ws.multiTerminalLoginChecker(cli) - }else{ - ws.clients.Set(client.userID,client) + } else { + ws.clients.Set(client.userID, client) atomic.AddInt64(&ws.onlineUserConnNum, 1) fmt.Println("R在线用户数量:", ws.onlineUserNum) fmt.Println("R在线用户连接数量:", ws.onlineUserConnNum) @@ -118,8 +119,8 @@ func (ws *WsServer) multiTerminalLoginChecker(client *Client) { } func (ws *WsServer) unregisterClient(client *Client) { - isDeleteUser:=ws.clients.delete(client.userID,client.platformID) - if isDeleteUser { + isDeleteUser := ws.clients.delete(client.userID, client.platformID) + if isDeleteUser { atomic.AddInt64(&ws.onlineUserNum, -1) } atomic.AddInt64(&ws.onlineUserConnNum, -1) @@ -127,8 +128,6 @@ func (ws *WsServer) unregisterClient(client *Client) { fmt.Println("R在线用户连接数量:", ws.onlineUserConnNum) } - -} func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { context := newContext(w, r) if ws.onlineUserConnNum >= ws.wsMaxConnNum { @@ -136,13 +135,12 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { return } var ( - token string - userID string - platformID string - exists bool - compression bool - compressor Compressor - + token string + userID string + platformID string + exists bool + compression bool + compressor Compressor ) token, exists = context.Query(TOKEN) @@ -150,7 +148,7 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { httpError(context, constant.ErrConnArgsErr) return } - userID, exists = context.Query(USERID) + userID, exists = context.Query(WS_USERID) if !exists { httpError(context, constant.ErrConnArgsErr) return @@ -165,7 +163,7 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { httpError(context, err) return } - wsLongConn:=newGWebSocket(constant.WebSocket,ws.handshakeTimeout,ws.readBufferSize) + wsLongConn := newGWebSocket(constant.WebSocket, ws.handshakeTimeout, ws.readBufferSize) err = wsLongConn.GenerateLongConn(w, r) if err != nil { httpError(context, err) @@ -173,20 +171,20 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { } compressProtoc, exists := context.Query(COMPRESSION) if exists { - if compressProtoc==GZIP_COMPRESSION_PROTOCAL{ + if compressProtoc == GZIP_COMPRESSION_PROTOCAL { compression = true compressor = ws.gzipCompressor } } compressProtoc, exists = context.GetHeader(COMPRESSION) if exists { - if compressProtoc==GZIP_COMPRESSION_PROTOCAL { + if compressProtoc == GZIP_COMPRESSION_PROTOCAL { compression = true compressor = ws.gzipCompressor } } - client:=ws.clientPool.Get().(*Client) - client.ResetClient(context,wsLongConn,compression,compressor,ws.encoder,ws.handler,ws.unregisterChan,ws.validate) + client := ws.clientPool.Get().(*Client) + client.ResetClient(context, wsLongConn, compression, compressor, ws.encoder, ws.handler, ws.unregisterChan, ws.validate) ws.registerChan <- client go client.readMessage() }