diff --git a/internal/rpc/auth/auth.go b/internal/rpc/auth/auth.go index 804375e4f..ab2e3a7a9 100644 --- a/internal/rpc/auth/auth.go +++ b/internal/rpc/auth/auth.go @@ -45,10 +45,11 @@ type authServer struct { } type Config struct { - RpcConfig config.Auth - RedisConfig config.Redis - Share config.Share - Discovery config.Discovery + RpcConfig config.Auth + MsgGatewayConfig config.MsgGateway + RedisConfig config.Redis + Share config.Share + Discovery config.Discovery } func Start(ctx context.Context, config *Config, client discovery.SvcDiscoveryRegistry, server *grpc.Server) error { @@ -64,6 +65,7 @@ func Start(ctx context.Context, config *Config, client discovery.SvcDiscoveryReg redis2.NewTokenCacheModel(rdb, config.RpcConfig.TokenPolicy.Expire), config.Share.Secret, config.RpcConfig.TokenPolicy.Expire, + config.MsgGatewayConfig.MultiLoginPolicy, ), config: config, }) diff --git a/pkg/common/cmd/auth.go b/pkg/common/cmd/auth.go index b35a95f39..a4f6a50c4 100644 --- a/pkg/common/cmd/auth.go +++ b/pkg/common/cmd/auth.go @@ -35,10 +35,11 @@ func NewAuthRpcCmd() *AuthRpcCmd { var authConfig auth.Config ret := &AuthRpcCmd{authConfig: &authConfig} ret.configMap = map[string]any{ - OpenIMRPCAuthCfgFileName: &authConfig.RpcConfig, - RedisConfigFileName: &authConfig.RedisConfig, - ShareFileName: &authConfig.Share, - DiscoveryConfigFilename: &authConfig.Discovery, + OpenIMRPCAuthCfgFileName: &authConfig.RpcConfig, + OpenIMMsgGatewayCfgFileName: &authConfig.MsgGatewayConfig, + RedisConfigFileName: &authConfig.RedisConfig, + ShareFileName: &authConfig.Share, + DiscoveryConfigFilename: &authConfig.Discovery, } ret.RootCmd = NewRootCmd(program.GetProcessName(), WithConfigMap(ret.configMap)) ret.ctx = context.WithValue(context.Background(), "version", version.Version) diff --git a/pkg/common/storage/controller/auth.go b/pkg/common/storage/controller/auth.go index 410283927..cb06a197d 100644 --- a/pkg/common/storage/controller/auth.go +++ b/pkg/common/storage/controller/auth.go @@ -35,13 +35,14 @@ type AuthDatabase interface { } type authDatabase struct { - cache cache.TokenModel - accessSecret string - accessExpire int64 + cache cache.TokenModel + accessSecret string + accessExpire int64 + multiLoginPolicy int } -func NewAuthDatabase(cache cache.TokenModel, accessSecret string, accessExpire int64) AuthDatabase { - return &authDatabase{cache: cache, accessSecret: accessSecret, accessExpire: accessExpire} +func NewAuthDatabase(cache cache.TokenModel, accessSecret string, accessExpire int64, policy int) AuthDatabase { + return &authDatabase{cache: cache, accessSecret: accessSecret, accessExpire: accessExpire, multiLoginPolicy: policy} } // If the result is empty. @@ -55,6 +56,7 @@ func (a *authDatabase) SetTokenMapByUidPid(ctx context.Context, userID string, p // Create Token. func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformID int) (string, error) { + // todo: get all platform token tokens, err := a.cache.GetTokensWithoutError(ctx, userID, platformID) if err != nil { return "", err @@ -65,7 +67,7 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI t, err := tokenverify.GetClaimFromToken(k, authverify.Secret(a.accessSecret)) if err != nil || v != constant.NormalToken { deleteTokenKey = append(deleteTokenKey, k) - } else if t.UserID == userID && t.PlatformID == platformID { + } else if a.checkKickToken(ctx, platformID, t) { kickedTokenKey = append(kickedTokenKey, k) } } @@ -96,3 +98,23 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI } return tokenString, nil } + +func (a *authDatabase) checkKickToken(ctx context.Context, platformID int, token *tokenverify.Claims) bool { + switch a.multiLoginPolicy { + case constant.DefalutNotKick: + return false + case constant.PCAndOther: + if constant.PlatformIDToClass(platformID) == constant.TerminalPC || + constant.PlatformIDToClass(token.PlatformID) == constant.TerminalPC { + return false + } + return true + case constant.AllLoginButSameTermKick: + if platformID == token.PlatformID { + return true + } + return false + default: + return false + } +}