diff --git a/internal/common/check/msg.go b/internal/common/check/msg.go index 010ee1606..bc3adf16f 100644 --- a/internal/common/check/msg.go +++ b/internal/common/check/msg.go @@ -28,3 +28,39 @@ func (m *MsgCheck) SendMsg(ctx context.Context, req *msg.SendMsgReq) (*msg.SendM resp, err := msg.NewMsgClient(cc).SendMsg(ctx, req) return resp, err } + +func (m *MsgCheck) SendMsg(ctx context.Context, req *msg.SendMsgReq) (*msg.SendMsgResp, error) { + cc, err := m.getConn() + if err != nil { + return nil, err + } + resp, err := msg.NewMsgClient(cc).SendMsg(ctx, req) + return resp, err +} + +func (m *MsgCheck) SendMsg(ctx context.Context, req *msg.SendMsgReq) (*msg.SendMsgResp, error) { + cc, err := m.getConn() + if err != nil { + return nil, err + } + resp, err := msg.NewMsgClient(cc).SendMsg(ctx, req) + return resp, err +} + +func (m *MsgCheck) SendMsg(ctx context.Context, req *msg.SendMsgReq) (*msg.SendMsgResp, error) { + cc, err := m.getConn() + if err != nil { + return nil, err + } + resp, err := msg.NewMsgClient(cc).SendMsg(ctx, req) + return resp, err +} + +func (m *MsgCheck) SendMsg(ctx context.Context, req *msg.SendMsgReq) (*msg.SendMsgResp, error) { + cc, err := m.getConn() + if err != nil { + return nil, err + } + resp, err := msg.NewMsgClient(cc).SendMsg(ctx, req) + return resp, err +} diff --git a/internal/msggateway/new/client.go b/internal/msggateway/new/client.go index 68178a1da..61ade23fe 100644 --- a/internal/msggateway/new/client.go +++ b/internal/msggateway/new/client.go @@ -3,14 +3,12 @@ package new import ( "Open_IM/pkg/common/constant" "Open_IM/pkg/utils" - "bytes" "context" "errors" "fmt" "github.com/go-playground/validator/v10" "runtime/debug" "sync" - "time" ) const ( @@ -35,35 +33,52 @@ const ( type Client struct { w *sync.Mutex conn LongConn - PlatformID int32 - PushedMaxSeq uint32 - IsCompress bool + platformID int + isCompress bool userID string - IsBackground bool - token string + isBackground bool connID string onlineAt int64 // 上线时间戳(毫秒) handler MessageHandler unregisterChan chan *Client compressor Compressor encoder Encoder - userContext UserConnContext validate *validator.Validate closed bool } -func newClient(conn LongConn, isCompress bool, userID string, isBackground bool, token string, - connID string, onlineAt int64, handler MessageHandler, unregisterChan chan *Client) *Client { +func newClient(ctx *UserConnContext, conn LongConn, isCompress bool, compressor Compressor, encoder Encoder, + handler MessageHandler, unregisterChan chan *Client, validate *validator.Validate) *Client { return &Client{ - conn: conn, - IsCompress: isCompress, - userID: userID, IsBackground: isBackground, token: token, - connID: connID, - onlineAt: onlineAt, + w: new(sync.Mutex), + conn: conn, + platformID: utils.StringToInt(ctx.GetPlatformID()), + isCompress: isCompress, + userID: ctx.GetUserID(), + compressor: compressor, + encoder: encoder, + connID: ctx.GetConnID(), + onlineAt: utils.GetCurrentTimestampByMill(), handler: handler, unregisterChan: unregisterChan, + validate: validate, } } +func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, isCompress bool, compressor Compressor, encoder Encoder, + handler MessageHandler, unregisterChan chan *Client, validate *validator.Validate) { + c.w = new(sync.Mutex) + c.conn = conn + c.platformID = utils.StringToInt(ctx.GetPlatformID()) + c.isCompress = isCompress + c.userID = ctx.GetUserID() + c.compressor = compressor + c.encoder = encoder + c.connID = ctx.GetConnID() + c.onlineAt = utils.GetCurrentTimestampByMill() + c.handler = handler + c.unregisterChan = unregisterChan + c.validate = validate +} func (c *Client) readMessage() { defer func() { if r := recover(); r != nil { @@ -77,7 +92,7 @@ func (c *Client) readMessage() { if returnErr != nil { break } - if c.closed == true { + if c.closed == true { //连接刚置位已经关闭,但是协程还没退出的场景 break } switch messageType { @@ -119,7 +134,8 @@ func (c *Client) handleMessage(message []byte) error { return errors.New("exception conn userID not same to req userID") } ctx := context.Background() - ctx = context.WithValue(ctx, "operationID", binaryReq.OperationID) + ctx = context.WithValue(ctx, c.connID, binaryReq.OperationID) + ctx = context.WithValue(ctx, OPERATION_ID, binaryReq.OperationID) ctx = context.WithValue(ctx, "userID", binaryReq.SendID) var messageErr error var resp []byte @@ -173,7 +189,7 @@ func (c *Client) writeMsg(resp Resp) error { return utils.Wrap(err, "") } _ = c.conn.SetWriteTimeout(60) - if c.IsCompress { + if c.isCompress { var compressErr error resultBuf, compressErr = c.compressor.Compress(encodeBuf) if compressErr != nil { diff --git a/internal/msggateway/new/constant.go b/internal/msggateway/new/constant.go new file mode 100644 index 000000000..b9e5c6aea --- /dev/null +++ b/internal/msggateway/new/constant.go @@ -0,0 +1,10 @@ +package new + +const ( + USERID = "sendID" + PLATFORM_ID = "platformID" + TOKEN = "token" + OPERATION_ID = "operationID" + COMPRESSION = "compression" + GZIP_COMPRESSION_PROTOCAL = "gzip" +) diff --git a/internal/msggateway/new/context.go b/internal/msggateway/new/context.go index 9ab353351..45e923a34 100644 --- a/internal/msggateway/new/context.go +++ b/internal/msggateway/new/context.go @@ -1,6 +1,10 @@ package new -import "net/http" +import ( + "Open_IM/pkg/utils" + "net/http" + "strconv" +) type UserConnContext struct { RespWriter http.ResponseWriter @@ -19,9 +23,32 @@ func newContext(respWriter http.ResponseWriter, req *http.Request) *UserConnCont RemoteAddr: req.RemoteAddr, } } -func (c *UserConnContext) Query(key string) string { - return c.Req.URL.Query().Get(key) +func (c *UserConnContext) Query(key string) (string, bool) { + var value string + if value = c.Req.URL.Query().Get(key); value == "" { + return value, false + } + return value, true } -func (c *UserConnContext) GetHeader(key string) string { - return c.Req.Header.Get(key) +func (c *UserConnContext) GetHeader(key string) (string, bool) { + var value string + if value = c.Req.Header.Get(key); value == "" { + return value, false + } + return value, true +} +func (c *UserConnContext) SetHeader(key, value string) { + c.RespWriter.Header().Set(key, value) +} +func (c *UserConnContext) ErrReturn(error string, code int) { + http.Error(c.RespWriter, error, code) +} +func (c *UserConnContext) GetConnID() string { + return c.RemoteAddr + "_" + strconv.Itoa(int(utils.GetCurrentTimestampByMill())) +} +func (c *UserConnContext) GetUserID() string { + return c.Req.URL.Query().Get(USERID) +} +func (c *UserConnContext) GetPlatformID() string { + return c.Req.URL.Query().Get(PLATFORM_ID) } diff --git a/internal/msggateway/new/http_error.go b/internal/msggateway/new/http_error.go new file mode 100644 index 000000000..8686d78d7 --- /dev/null +++ b/internal/msggateway/new/http_error.go @@ -0,0 +1,44 @@ +package new + +import ( + "Open_IM/pkg/common/constant" + "errors" + "net/http" +) + +func httpError(ctx *UserConnContext, err error) { + code := http.StatusUnauthorized + ctx.SetHeader("Sec-Websocket-Version", "13") + ctx.SetHeader("ws_err_msg", err.Error()) + if errors.Is(err, constant.ErrTokenExpired) { + code = int(constant.ErrTokenExpired.ErrCode) + } + if errors.Is(err, constant.ErrTokenInvalid) { + code = int(constant.ErrTokenInvalid.ErrCode) + } + if errors.Is(err, constant.ErrTokenMalformed) { + code = int(constant.ErrTokenMalformed.ErrCode) + } + if errors.Is(err, constant.ErrTokenNotValidYet) { + code = int(constant.ErrTokenNotValidYet.ErrCode) + } + if errors.Is(err, constant.ErrTokenUnknown) { + code = int(constant.ErrTokenUnknown.ErrCode) + } + if errors.Is(err, constant.ErrTokenKicked) { + code = int(constant.ErrTokenKicked.ErrCode) + } + if errors.Is(err, constant.ErrTokenDifferentPlatformID) { + code = int(constant.ErrTokenDifferentPlatformID.ErrCode) + } + if errors.Is(err, constant.ErrTokenDifferentUserID) { + code = int(constant.ErrTokenDifferentUserID.ErrCode) + } + if errors.Is(err, constant.ErrConnOverMaxNumLimit) { + code = int(constant.ErrConnOverMaxNumLimit.ErrCode) + } + if errors.Is(err, constant.ErrConnArgsErr) { + code = int(constant.ErrConnArgsErr.ErrCode) + } + ctx.ErrReturn(err.Error(), code) +} diff --git a/internal/msggateway/new/long_conn.go b/internal/msggateway/new/long_conn.go index 823964e40..fd33ea615 100644 --- a/internal/msggateway/new/long_conn.go +++ b/internal/msggateway/new/long_conn.go @@ -26,19 +26,37 @@ type LongConn interface { SetConnNil() //Check the connection of the current and when it was sent are the same CheckSendConnDiffNow() bool + // + GenerateLongConn(w http.ResponseWriter, r *http.Request) error } type GWebSocket struct { - protocolType int - conn *websocket.Conn + protocolType int + conn *websocket.Conn + handshakeTimeout time.Duration + readBufferSize, WriteBufferSize int } -func NewDefault(protocolType int) *GWebSocket { - return &GWebSocket{protocolType: protocolType} +func newGWebSocket(protocolType int, handshakeTimeout time.Duration, readBufferSize int) *GWebSocket { + return &GWebSocket{protocolType: protocolType, handshakeTimeout: handshakeTimeout, readBufferSize: readBufferSize} } + func (d *GWebSocket) Close() error { return d.conn.Close() } +func (d *GWebSocket) GenerateLongConn(w http.ResponseWriter, r *http.Request) error { + upgrader := &websocket.Upgrader{ + HandshakeTimeout: d.handshakeTimeout, + ReadBufferSize: d.readBufferSize, + CheckOrigin: func(r *http.Request) bool { return true }, + } + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return err + } + d.conn = conn + return nil +} func (d *GWebSocket) WriteMessage(messageType int, message []byte) error { d.setSendConn(d.conn) return d.conn.WriteMessage(messageType, message) diff --git a/internal/msggateway/new/message_handler.go b/internal/msggateway/new/message_handler.go index e9285c5d3..52e36414d 100644 --- a/internal/msggateway/new/message_handler.go +++ b/internal/msggateway/new/message_handler.go @@ -1,6 +1,9 @@ package new -import "context" +import ( + "Open_IM/internal/common/check" + "context" +) type Req struct { ReqIdentifier int32 `json:"reqIdentifier" validate:"required"` @@ -30,6 +33,11 @@ type MessageHandler interface { var _ MessageHandler = (*GrpcHandler)(nil) type GrpcHandler struct { + msg *check.MsgCheck +} + +func NewGrpcHandler(msg *check.MsgCheck) *GrpcHandler { + return &GrpcHandler{msg: msg} } func (g GrpcHandler) GetSeq(context context.Context, data Req) ([]byte, error) { diff --git a/internal/msggateway/new/n_ws_server.go b/internal/msggateway/new/n_ws_server.go index d74d715b7..65ef48460 100644 --- a/internal/msggateway/new/n_ws_server.go +++ b/internal/msggateway/new/n_ws_server.go @@ -1,8 +1,11 @@ package new import ( - "bytes" + "Open_IM/pkg/common/constant" + "Open_IM/pkg/utils" "errors" + "fmt" + "github.com/go-playground/validator/v10" "github.com/gorilla/websocket" "net/http" "sync" @@ -10,12 +13,12 @@ import ( "time" ) - var bufferPool = sync.Pool{ New: func() interface{} { return make([]byte, 1000) }, } + type LongConnServer interface { Run() error } @@ -27,17 +30,21 @@ type Server struct { rpcServer *RpcServer } type WsServer struct { - port int - wsMaxConnNum int - wsUpGrader *websocket.Upgrader - registerChan chan *Client - unregisterChan chan *Client - clients *UserMap - clientPool sync.Pool - onlineUserNum int64 - onlineUserConnNum int64 - compressor Compressor - handler MessageHandler + port int + wsMaxConnNum int64 + wsUpGrader *websocket.Upgrader + registerChan chan *Client + unregisterChan chan *Client + clients *UserMap + clientPool sync.Pool + onlineUserNum int64 + onlineUserConnNum int64 + gzipCompressor Compressor + encoder Encoder + handler MessageHandler + handshakeTimeout time.Duration + readBufferSize, WriteBufferSize int + validate *validator.Validate } func newWsServer(opts ...Option) (*WsServer, error) { @@ -50,18 +57,18 @@ func newWsServer(opts ...Option) (*WsServer, error) { } return &WsServer{ - port: config.port, - wsMaxConnNum: config.maxConnNum, - wsUpGrader: &websocket.Upgrader{ - HandshakeTimeout: config.handshakeTimeout, - ReadBufferSize: config.messageMaxMsgLength, - CheckOrigin: func(r *http.Request) bool { return true }, - }, + port: config.port, + wsMaxConnNum: config.maxConnNum, + handshakeTimeout: config.handshakeTimeout, + readBufferSize: config.messageMaxMsgLength, clientPool: sync.Pool{ New: func() interface{} { return new(Client) }, }, + validate: validator.New(), + clients: newUserMap(), + handler: NewGrpcHandler(), }, nil } func (ws *WsServer) Run() error { @@ -71,53 +78,115 @@ func (ws *WsServer) Run() error { select { case client = <-ws.registerChan: ws.registerClient(client) - case client = <-h.unregisterChan: - h.unregisterClient(client) - case msg = <-h.readChan: - h.messageHandler(msg) + case client = <-ws.unregisterChan: + ws.unregisterClient(client) } } }() + http.HandleFunc("/", ws.wsHandler) //Get request from client to handle by wsHandler + return http.ListenAndServe(":"+utils.IntToString(ws.port), nil) //Start listening } func (ws *WsServer) registerClient(client *Client) { var ( - ok bool + userOK bool + clientOK bool cli *Client ) - - if cli, ok = h.clients.Get(client.key); ok == false { - h.clients.Set(client.key, client) - atomic.AddInt64(&h.onlineConnections, 1) - fmt.Println("R在线用户数量:", h.onlineConnections) - return + cli, userOK,clientOK = ws.clients.Get(client.userID,client.platformID) + if !userOK { + ws.clients.Set(client.userID,client) + atomic.AddInt64(&ws.onlineUserNum, 1) + atomic.AddInt64(&ws.onlineUserConnNum, 1) + fmt.Println("R在线用户数量:", ws.onlineUserNum) + fmt.Println("R在线用户连接数量:", ws.onlineUserConnNum) + }else{ + if clientOK {//已经有同平台的连接存在 + ws.clients.Set(client.userID,client) + ws.multiTerminalLoginChecker(cli) + }else{ + ws.clients.Set(client.userID,client) + atomic.AddInt64(&ws.onlineUserConnNum, 1) + fmt.Println("R在线用户数量:", ws.onlineUserNum) + fmt.Println("R在线用户连接数量:", ws.onlineUserConnNum) + } } - if client.onlineAt > cli.onlineAt { - h.clients.Set(client.key, client) - h.close(cli) - return - } - h.close(client) } - http.HandleFunc("/", ws.wsHandler) //Get request from client to handle by wsHandler - return http.ListenAndServe(":"+utils.IntToString(ws.port), nil) //Start listening + +func (ws *WsServer) multiTerminalLoginChecker(client *Client) { + +} +func (ws *WsServer) unregisterClient(client *Client) { + isDeleteUser:=ws.clients.delete(client.userID,client.platformID) + if isDeleteUser { + atomic.AddInt64(&ws.onlineUserNum, -1) + } + atomic.AddInt64(&ws.onlineUserConnNum, -1) + fmt.Println("R在线用户数量:", ws.onlineUserNum) + fmt.Println("R在线用户连接数量:", ws.onlineUserConnNum) +} + } func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { context := newContext(w, r) - if isPass, compression := ws.headerCheck(w, r, operationID); isPass { - conn, err := ws.wsUpGrader.Upgrade(w, r, nil) //Conn is obtained through the upgraded escalator - if err != nil { - log.Error(operationID, "upgrade http conn err", err.Error(), query) - return - } else { - newConn := &UserConn{conn, new(sync.Mutex), utils.StringToInt32(query["platformID"][0]), 0, compression, query["sendID"][0], false, query["token"][0], conn.RemoteAddr().String() + "_" + strconv.Itoa(int(utils.GetCurrentTimestampByMill()))} - userCount++ - ws.addUserConn(query["sendID"][0], utils.StringToInt(query["platformID"][0]), newConn, query["token"][0], newConn.connID, operationID) - go ws.readMsg(newConn) - } - } else { - log.Error(operationID, "headerCheck failed ") + if ws.onlineUserConnNum >= ws.wsMaxConnNum { + httpError(context, constant.ErrConnOverMaxNumLimit) + return } + var ( + token string + userID string + platformID string + exists bool + compression bool + compressor Compressor + + ) + + token, exists = context.Query(TOKEN) + if !exists { + httpError(context, constant.ErrConnArgsErr) + return + } + userID, exists = context.Query(USERID) + if !exists { + httpError(context, constant.ErrConnArgsErr) + return + } + platformID, exists = context.Query(PLATFORM_ID) + if !exists { + httpError(context, constant.ErrConnArgsErr) + return + } + err := tokenverify.WsVerifyToken(token, userID, platformID) + if err != nil { + httpError(context, err) + return + } + wsLongConn:=newGWebSocket(constant.WebSocket,ws.handshakeTimeout,ws.readBufferSize) + err = wsLongConn.GenerateLongConn(w, r) + if err != nil { + httpError(context, err) + return + } + compressProtoc, exists := context.Query(COMPRESSION) + if exists { + if compressProtoc==GZIP_COMPRESSION_PROTOCAL{ + compression = true + compressor = ws.gzipCompressor + } + } + compressProtoc, exists = context.GetHeader(COMPRESSION) + if exists { + if compressProtoc==GZIP_COMPRESSION_PROTOCAL { + compression = true + compressor = ws.gzipCompressor + } + } + client:=ws.clientPool.Get().(*Client) + client.ResetClient(context,wsLongConn,compression,compressor,ws.encoder,ws.handler,ws.unregisterChan,ws.validate) + ws.registerChan <- client + go client.readMessage() } diff --git a/internal/msggateway/new/options.go b/internal/msggateway/new/options.go index 71a732751..0fb575fc1 100644 --- a/internal/msggateway/new/options.go +++ b/internal/msggateway/new/options.go @@ -7,7 +7,7 @@ type configs struct { //长连接监听端口 port int //长连接允许最大链接数 - maxConnNum int + maxConnNum int64 //连接握手超时时间 handshakeTimeout time.Duration //允许消息最大长度 @@ -19,7 +19,7 @@ func WithPort(port int) Option { opt.port = port } } -func WithMaxConnNum(num int) Option { +func WithMaxConnNum(num int64) Option { return func(opt *configs) { opt.maxConnNum = num } diff --git a/internal/msggateway/new/user_map.go b/internal/msggateway/new/user_map.go index 82615e827..0f45aa018 100644 --- a/internal/msggateway/new/user_map.go +++ b/internal/msggateway/new/user_map.go @@ -9,24 +9,24 @@ type UserMap struct { func newUserMap() *UserMap { return &UserMap{} } -func (u *UserMap) GetAll(key string) []*Client { +func (u *UserMap) GetAll(key string) ([]*Client, bool) { allClients, ok := u.m.Load(key) if ok { - return allClients.([]*Client) + return allClients.([]*Client), ok } - return nil + return nil, ok } -func (u *UserMap) Get(key string, platformID int32) (*Client, bool) { - allClients, existed := u.m.Load(key) - if existed { +func (u *UserMap) Get(key string, platformID int) (*Client, bool, bool) { + allClients, userExisted := u.m.Load(key) + if userExisted { for _, client := range allClients.([]*Client) { - if client.PlatformID == platformID { - return client, existed + if client.platformID == platformID { + return client, userExisted, true } } - return nil, false + return nil, userExisted, false } - return nil, existed + return nil, userExisted, false } func (u *UserMap) Set(key string, v *Client) { allClients, existed := u.m.Load(key) @@ -40,24 +40,25 @@ func (u *UserMap) Set(key string, v *Client) { u.m.Store(key, clients) } } -func (u *UserMap) delete(key string, platformID int32) { +func (u *UserMap) delete(key string, platformID int) (isDeleteUser bool) { allClients, existed := u.m.Load(key) if existed { oldClients := allClients.([]*Client) - - a := make([]*Client, len(oldClients)) + a := make([]*Client, 3) for _, client := range oldClients { - if client.PlatformID != platformID { + if client.platformID != platformID { a = append(a, client) } } if len(a) == 0 { u.m.Delete(key) + return true } else { u.m.Store(key, a) - + return false } } + return existed } func (u *UserMap) DeleteAll(key string) { u.m.Delete(key) diff --git a/pkg/common/constant/errors.go b/pkg/common/constant/errors.go index b998fc71d..f59216967 100644 --- a/pkg/common/constant/errors.go +++ b/pkg/common/constant/errors.go @@ -54,6 +54,14 @@ var ( // ErrMutedInGroup = &ErrInfo{MutedInGroup, "MutedInGroup", ""} ErrMutedGroup = &ErrInfo{MutedGroup, "MutedGroup", ""} + + ErrConnOverMaxNumLimit = &ErrInfo{ConnOverMaxNumLimit, "ConnOverMaxNumLimit", ""} + + ErrConnOverMaxNumLimit = &ErrInfo{ConnOverMaxNumLimit, "ConnOverMaxNumLimit", ""} + + ErrConnOverMaxNumLimit = &ErrInfo{ConnOverMaxNumLimit, "ConnOverMaxNumLimit", ""} + ErrConnArgsErr = &ErrInfo{ConnArgsErr, "args err, need token, sendID, platformID", ""} + ErrConnUpdateErr = &ErrInfo{ConnArgsErr, "upgrade http conn err", ""} ) const ( @@ -142,6 +150,13 @@ const ( MessageHasReadDisable = 96001 ) +// 长连接网关错误码 +const ( + ConnOverMaxNumLimit = 970001 + ConnArgsErr = 970002 + ConnUpdateErr = 970003 +) + // temp var (