diff --git a/internal/msggateway/client.go b/internal/msggateway/client.go index e7d794324..6bce68c85 100644 --- a/internal/msggateway/client.go +++ b/internal/msggateway/client.go @@ -224,8 +224,11 @@ func (c *Client) PushMessage(ctx context.Context, msgData *sdkws.MsgData) error return c.writeBinaryMsg(resp) } -func (c *Client) KickOnlineMessage(ctx context.Context) error { - return nil +func (c *Client) KickOnlineMessage() error { + resp := Resp{ + ReqIdentifier: WSKickOnlineMsg, + } + return c.writeBinaryMsg(resp) } func (c *Client) writeBinaryMsg(resp Resp) error { diff --git a/internal/msggateway/hub_server.go b/internal/msggateway/hub_server.go index e93497de4..786d3eeff 100644 --- a/internal/msggateway/hub_server.go +++ b/internal/msggateway/hub_server.go @@ -2,6 +2,7 @@ package msggateway import ( "context" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/cache" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/config" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" @@ -17,7 +18,13 @@ import ( ) func (s *Server) InitServer(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { + rdb, err := cache.NewRedis() + if err != nil { + return err + } + msgModel := cache.NewMsgCacheModel(rdb) s.LongConnServer.SetDiscoveryRegistry(client) + s.LongConnServer.SetCacheHandler(msgModel) msggateway.RegisterMsgGatewayServer(server, s) return nil } @@ -131,7 +138,7 @@ func (s *Server) KickUserOffline(ctx context.Context, req *msggateway.KickUserOf for _, v := range req.KickUserIDList { if clients, _, ok := s.LongConnServer.GetUserPlatformCons(v, int(req.PlatformID)); ok { for _, client := range clients { - err := client.KickOnlineMessage(ctx) + err := client.KickOnlineMessage() if err != nil { return nil, err } diff --git a/internal/msggateway/n_ws_server.go b/internal/msggateway/n_ws_server.go index 749287e7f..7318b2cf1 100644 --- a/internal/msggateway/n_ws_server.go +++ b/internal/msggateway/n_ws_server.go @@ -2,6 +2,9 @@ package msggateway import ( "errors" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/config" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/cache" "net/http" "sync" "sync/atomic" @@ -22,7 +25,7 @@ type LongConnServer interface { GetUserAllCons(userID string) ([]*Client, bool) GetUserPlatformCons(userID string, platform int) ([]*Client, bool, bool) Validate(s interface{}) error - //SetMessageHandler(msgRpcClient *rpcclient.MsgClient) + SetCacheHandler(cache cache.MsgModel) SetDiscoveryRegistry(client discoveryregistry.SvcDiscoveryRegistry) UnRegister(c *Client) Compressor @@ -41,6 +44,7 @@ type WsServer struct { wsMaxConnNum int64 registerChan chan *Client unregisterChan chan *Client + kickHandlerChan chan *kickHandler clients *UserMap clientPool sync.Pool onlineUserNum int64 @@ -48,14 +52,23 @@ type WsServer struct { handshakeTimeout time.Duration hubServer *Server validate *validator.Validate + cache cache.MsgModel Compressor Encoder MessageHandler } +type kickHandler struct { + clientOK bool + oldClients []*Client + newClient *Client +} func (ws *WsServer) SetDiscoveryRegistry(client discoveryregistry.SvcDiscoveryRegistry) { ws.MessageHandler = NewGrpcHandler(ws.validate, client) } +func (ws *WsServer) SetCacheHandler(cache cache.MsgModel) { + ws.cache = cache +} func (ws *WsServer) UnRegister(c *Client) { ws.unregisterChan <- c @@ -92,12 +105,13 @@ func NewWsServer(opts ...Option) (*WsServer, error) { return new(Client) }, }, - registerChan: make(chan *Client, 1000), - unregisterChan: make(chan *Client, 1000), - validate: v, - clients: newUserMap(), - Compressor: NewGzipCompressor(), - Encoder: NewGobEncoder(), + registerChan: make(chan *Client, 1000), + unregisterChan: make(chan *Client, 1000), + kickHandlerChan: make(chan *kickHandler, 1000), + validate: v, + clients: newUserMap(), + Compressor: NewGzipCompressor(), + Encoder: NewGobEncoder(), }, nil } func (ws *WsServer) Run() error { @@ -109,6 +123,8 @@ func (ws *WsServer) Run() error { ws.registerClient(client) case client = <-ws.unregisterChan: ws.unregisterClient(client) + case onlineInfo := <-ws.kickHandlerChan: + ws.multiTerminalLoginChecker(onlineInfo) } } }() @@ -119,26 +135,29 @@ func (ws *WsServer) Run() error { func (ws *WsServer) registerClient(client *Client) { var ( - userOK bool - clientOK bool - cli []*Client + userOK bool + clientOK bool + oldClients []*Client ) - cli, userOK, clientOK = ws.clients.Get(client.UserID, client.PlatformID) + ws.clients.Set(client.UserID, client) + oldClients, userOK, clientOK = ws.clients.Get(client.UserID, client.PlatformID) if !userOK { log.ZDebug(client.ctx, "user not exist", "userID", client.UserID, "platformID", client.PlatformID) - ws.clients.Set(client.UserID, client) atomic.AddInt64(&ws.onlineUserNum, 1) atomic.AddInt64(&ws.onlineUserConnNum, 1) } else { + i := &kickHandler{ + clientOK: clientOK, + oldClients: oldClients, + newClient: client, + } + ws.kickHandlerChan <- i log.ZDebug(client.ctx, "user exist", "userID", client.UserID, "platformID", client.PlatformID) if clientOK { //已经有同平台的连接存在 - ws.clients.Set(client.UserID, client) - ws.multiTerminalLoginChecker(cli) - log.ZInfo(client.ctx, "repeat login", "userID", client.UserID, "platformID", client.PlatformID, "old remote addr", getRemoteAdders(cli)) + log.ZInfo(client.ctx, "repeat login", "userID", client.UserID, "platformID", client.PlatformID, "old remote addr", getRemoteAdders(oldClients)) atomic.AddInt64(&ws.onlineUserConnNum, 1) } else { - ws.clients.Set(client.UserID, client) atomic.AddInt64(&ws.onlineUserConnNum, 1) } } @@ -156,7 +175,24 @@ func getRemoteAdders(client []*Client) string { return ret } -func (ws *WsServer) multiTerminalLoginChecker(client []*Client) { +func (ws *WsServer) multiTerminalLoginChecker(info *kickHandler) { + switch config.Config.MultiLoginPolicy { + case constant.DefalutNotKick: + case constant.PCAndOther: + if constant.PlatformIDToClass(info.newClient.PlatformID) == constant.TerminalPC { + return + } + fallthrough + case constant.AllLoginButSameTermKick: + if info.clientOK { + for _, c := range info.oldClients { + err := c.KickOnlineMessage() + if err != nil { + log.ZWarn() + } + } + } + } } func (ws *WsServer) unregisterClient(client *Client) { @@ -198,7 +234,6 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { httpError(context, errs.ErrConnArgsErr) return } - // log.ZDebug(context2.Background(), "conn", "platformID", platformID) err := tokenverify.WsVerifyToken(token, userID, platformID) if err != nil { httpError(context, err) diff --git a/pkg/common/constant/constant.go b/pkg/common/constant/constant.go index 697694e87..8be23c10f 100644 --- a/pkg/common/constant/constant.go +++ b/pkg/common/constant/constant.go @@ -118,6 +118,7 @@ const ( ExpiredToken = 3 //MultiTerminalLogin + DefalutNotKick = 0 //Full-end login, but the same end is mutually exclusive AllLoginButSameTermKick = 1 //Only one of the endpoints can log in diff --git a/pkg/common/constant/platform_id_to_name.go b/pkg/common/constant/platform_id_to_name.go index 3d5ab059b..e8bb129eb 100644 --- a/pkg/common/constant/platform_id_to_name.go +++ b/pkg/common/constant/platform_id_to_name.go @@ -57,7 +57,7 @@ var PlatformName2ID = map[string]int{ IPadPlatformStr: IPadPlatformID, AdminPlatformStr: AdminPlatformID, } -var Platform2class = map[string]string{ +var PlatformName2class = map[string]string{ IOSPlatformStr: TerminalMobile, AndroidPlatformStr: TerminalMobile, MiniWebPlatformStr: WebPlatformStr, @@ -66,6 +66,15 @@ var Platform2class = map[string]string{ OSXPlatformStr: TerminalPC, LinuxPlatformStr: TerminalPC, } +var PlatformID2class = map[int]string{ + IOSPlatformID: TerminalMobile, + AndroidPlatformID: TerminalMobile, + MiniWebPlatformID: WebPlatformStr, + WebPlatformID: WebPlatformStr, + WindowsPlatformID: TerminalPC, + OSXPlatformID: TerminalPC, + LinuxPlatformID: TerminalPC, +} func PlatformIDToName(num int) string { return PlatformID2Name[num] @@ -74,5 +83,8 @@ func PlatformNameToID(name string) int { return PlatformName2ID[name] } func PlatformNameToClass(name string) string { - return Platform2class[name] + return PlatformName2class[name] +} +func PlatformIDToClass(num int) string { + return PlatformID2class[num] }