feat: kick user when same terminal login

This commit is contained in:
Gordon 2023-06-14 09:58:10 +08:00
parent f6dcc2ba44
commit ad42eaed11
5 changed files with 81 additions and 23 deletions

View File

@ -224,8 +224,11 @@ func (c *Client) PushMessage(ctx context.Context, msgData *sdkws.MsgData) error
return c.writeBinaryMsg(resp) return c.writeBinaryMsg(resp)
} }
func (c *Client) KickOnlineMessage(ctx context.Context) error { func (c *Client) KickOnlineMessage() error {
return nil resp := Resp{
ReqIdentifier: WSKickOnlineMsg,
}
return c.writeBinaryMsg(resp)
} }
func (c *Client) writeBinaryMsg(resp Resp) error { func (c *Client) writeBinaryMsg(resp Resp) error {

View File

@ -2,6 +2,7 @@ package msggateway
import ( import (
"context" "context"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/cache"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/config" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/config"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant"
@ -17,7 +18,13 @@ import (
) )
func (s *Server) InitServer(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { func (s *Server) InitServer(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error {
rdb, err := cache.NewRedis()
if err != nil {
return err
}
msgModel := cache.NewMsgCacheModel(rdb)
s.LongConnServer.SetDiscoveryRegistry(client) s.LongConnServer.SetDiscoveryRegistry(client)
s.LongConnServer.SetCacheHandler(msgModel)
msggateway.RegisterMsgGatewayServer(server, s) msggateway.RegisterMsgGatewayServer(server, s)
return nil return nil
} }
@ -131,7 +138,7 @@ func (s *Server) KickUserOffline(ctx context.Context, req *msggateway.KickUserOf
for _, v := range req.KickUserIDList { for _, v := range req.KickUserIDList {
if clients, _, ok := s.LongConnServer.GetUserPlatformCons(v, int(req.PlatformID)); ok { if clients, _, ok := s.LongConnServer.GetUserPlatformCons(v, int(req.PlatformID)); ok {
for _, client := range clients { for _, client := range clients {
err := client.KickOnlineMessage(ctx) err := client.KickOnlineMessage()
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -2,6 +2,9 @@ package msggateway
import ( import (
"errors" "errors"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/config"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/cache"
"net/http" "net/http"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -22,7 +25,7 @@ type LongConnServer interface {
GetUserAllCons(userID string) ([]*Client, bool) GetUserAllCons(userID string) ([]*Client, bool)
GetUserPlatformCons(userID string, platform int) ([]*Client, bool, bool) GetUserPlatformCons(userID string, platform int) ([]*Client, bool, bool)
Validate(s interface{}) error Validate(s interface{}) error
//SetMessageHandler(msgRpcClient *rpcclient.MsgClient) SetCacheHandler(cache cache.MsgModel)
SetDiscoveryRegistry(client discoveryregistry.SvcDiscoveryRegistry) SetDiscoveryRegistry(client discoveryregistry.SvcDiscoveryRegistry)
UnRegister(c *Client) UnRegister(c *Client)
Compressor Compressor
@ -41,6 +44,7 @@ type WsServer struct {
wsMaxConnNum int64 wsMaxConnNum int64
registerChan chan *Client registerChan chan *Client
unregisterChan chan *Client unregisterChan chan *Client
kickHandlerChan chan *kickHandler
clients *UserMap clients *UserMap
clientPool sync.Pool clientPool sync.Pool
onlineUserNum int64 onlineUserNum int64
@ -48,14 +52,23 @@ type WsServer struct {
handshakeTimeout time.Duration handshakeTimeout time.Duration
hubServer *Server hubServer *Server
validate *validator.Validate validate *validator.Validate
cache cache.MsgModel
Compressor Compressor
Encoder Encoder
MessageHandler MessageHandler
} }
type kickHandler struct {
clientOK bool
oldClients []*Client
newClient *Client
}
func (ws *WsServer) SetDiscoveryRegistry(client discoveryregistry.SvcDiscoveryRegistry) { func (ws *WsServer) SetDiscoveryRegistry(client discoveryregistry.SvcDiscoveryRegistry) {
ws.MessageHandler = NewGrpcHandler(ws.validate, client) ws.MessageHandler = NewGrpcHandler(ws.validate, client)
} }
func (ws *WsServer) SetCacheHandler(cache cache.MsgModel) {
ws.cache = cache
}
func (ws *WsServer) UnRegister(c *Client) { func (ws *WsServer) UnRegister(c *Client) {
ws.unregisterChan <- c ws.unregisterChan <- c
@ -94,6 +107,7 @@ func NewWsServer(opts ...Option) (*WsServer, error) {
}, },
registerChan: make(chan *Client, 1000), registerChan: make(chan *Client, 1000),
unregisterChan: make(chan *Client, 1000), unregisterChan: make(chan *Client, 1000),
kickHandlerChan: make(chan *kickHandler, 1000),
validate: v, validate: v,
clients: newUserMap(), clients: newUserMap(),
Compressor: NewGzipCompressor(), Compressor: NewGzipCompressor(),
@ -109,6 +123,8 @@ func (ws *WsServer) Run() error {
ws.registerClient(client) ws.registerClient(client)
case client = <-ws.unregisterChan: case client = <-ws.unregisterChan:
ws.unregisterClient(client) ws.unregisterClient(client)
case onlineInfo := <-ws.kickHandlerChan:
ws.multiTerminalLoginChecker(onlineInfo)
} }
} }
}() }()
@ -121,24 +137,27 @@ func (ws *WsServer) registerClient(client *Client) {
var ( var (
userOK bool userOK bool
clientOK bool clientOK bool
cli []*Client oldClients []*Client
) )
cli, userOK, clientOK = ws.clients.Get(client.UserID, client.PlatformID) ws.clients.Set(client.UserID, client)
oldClients, userOK, clientOK = ws.clients.Get(client.UserID, client.PlatformID)
if !userOK { if !userOK {
log.ZDebug(client.ctx, "user not exist", "userID", client.UserID, "platformID", client.PlatformID) log.ZDebug(client.ctx, "user not exist", "userID", client.UserID, "platformID", client.PlatformID)
ws.clients.Set(client.UserID, client)
atomic.AddInt64(&ws.onlineUserNum, 1) atomic.AddInt64(&ws.onlineUserNum, 1)
atomic.AddInt64(&ws.onlineUserConnNum, 1) atomic.AddInt64(&ws.onlineUserConnNum, 1)
} else { } else {
i := &kickHandler{
clientOK: clientOK,
oldClients: oldClients,
newClient: client,
}
ws.kickHandlerChan <- i
log.ZDebug(client.ctx, "user exist", "userID", client.UserID, "platformID", client.PlatformID) log.ZDebug(client.ctx, "user exist", "userID", client.UserID, "platformID", client.PlatformID)
if clientOK { //已经有同平台的连接存在 if clientOK { //已经有同平台的连接存在
ws.clients.Set(client.UserID, client) log.ZInfo(client.ctx, "repeat login", "userID", client.UserID, "platformID", client.PlatformID, "old remote addr", getRemoteAdders(oldClients))
ws.multiTerminalLoginChecker(cli)
log.ZInfo(client.ctx, "repeat login", "userID", client.UserID, "platformID", client.PlatformID, "old remote addr", getRemoteAdders(cli))
atomic.AddInt64(&ws.onlineUserConnNum, 1) atomic.AddInt64(&ws.onlineUserConnNum, 1)
} else { } else {
ws.clients.Set(client.UserID, client)
atomic.AddInt64(&ws.onlineUserConnNum, 1) atomic.AddInt64(&ws.onlineUserConnNum, 1)
} }
} }
@ -156,7 +175,24 @@ func getRemoteAdders(client []*Client) string {
return ret return ret
} }
func (ws *WsServer) multiTerminalLoginChecker(client []*Client) { func (ws *WsServer) multiTerminalLoginChecker(info *kickHandler) {
switch config.Config.MultiLoginPolicy {
case constant.DefalutNotKick:
case constant.PCAndOther:
if constant.PlatformIDToClass(info.newClient.PlatformID) == constant.TerminalPC {
return
}
fallthrough
case constant.AllLoginButSameTermKick:
if info.clientOK {
for _, c := range info.oldClients {
err := c.KickOnlineMessage()
if err != nil {
log.ZWarn()
}
}
}
}
} }
func (ws *WsServer) unregisterClient(client *Client) { func (ws *WsServer) unregisterClient(client *Client) {
@ -198,7 +234,6 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
httpError(context, errs.ErrConnArgsErr) httpError(context, errs.ErrConnArgsErr)
return return
} }
// log.ZDebug(context2.Background(), "conn", "platformID", platformID)
err := tokenverify.WsVerifyToken(token, userID, platformID) err := tokenverify.WsVerifyToken(token, userID, platformID)
if err != nil { if err != nil {
httpError(context, err) httpError(context, err)

View File

@ -118,6 +118,7 @@ const (
ExpiredToken = 3 ExpiredToken = 3
//MultiTerminalLogin //MultiTerminalLogin
DefalutNotKick = 0
//Full-end login, but the same end is mutually exclusive //Full-end login, but the same end is mutually exclusive
AllLoginButSameTermKick = 1 AllLoginButSameTermKick = 1
//Only one of the endpoints can log in //Only one of the endpoints can log in

View File

@ -57,7 +57,7 @@ var PlatformName2ID = map[string]int{
IPadPlatformStr: IPadPlatformID, IPadPlatformStr: IPadPlatformID,
AdminPlatformStr: AdminPlatformID, AdminPlatformStr: AdminPlatformID,
} }
var Platform2class = map[string]string{ var PlatformName2class = map[string]string{
IOSPlatformStr: TerminalMobile, IOSPlatformStr: TerminalMobile,
AndroidPlatformStr: TerminalMobile, AndroidPlatformStr: TerminalMobile,
MiniWebPlatformStr: WebPlatformStr, MiniWebPlatformStr: WebPlatformStr,
@ -66,6 +66,15 @@ var Platform2class = map[string]string{
OSXPlatformStr: TerminalPC, OSXPlatformStr: TerminalPC,
LinuxPlatformStr: TerminalPC, LinuxPlatformStr: TerminalPC,
} }
var PlatformID2class = map[int]string{
IOSPlatformID: TerminalMobile,
AndroidPlatformID: TerminalMobile,
MiniWebPlatformID: WebPlatformStr,
WebPlatformID: WebPlatformStr,
WindowsPlatformID: TerminalPC,
OSXPlatformID: TerminalPC,
LinuxPlatformID: TerminalPC,
}
func PlatformIDToName(num int) string { func PlatformIDToName(num int) string {
return PlatformID2Name[num] return PlatformID2Name[num]
@ -74,5 +83,8 @@ func PlatformNameToID(name string) int {
return PlatformName2ID[name] return PlatformName2ID[name]
} }
func PlatformNameToClass(name string) string { func PlatformNameToClass(name string) string {
return Platform2class[name] return PlatformName2class[name]
}
func PlatformIDToClass(num int) string {
return PlatformID2class[num]
} }