fix: admin token limit (#2871)

This commit is contained in:
icey-yu 2024-11-22 12:25:28 +08:00 committed by OpenIM-Robot
parent c9e2f7d375
commit 0e07ad70c3
3 changed files with 59 additions and 34 deletions

View File

@ -16,6 +16,7 @@ package auth
import ( import (
"context" "context"
"errors"
"github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/config"
redis2 "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/redis" redis2 "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/redis"
@ -66,6 +67,7 @@ func Start(ctx context.Context, config *Config, client discovery.SvcDiscoveryReg
config.Share.Secret, config.Share.Secret,
config.RpcConfig.TokenPolicy.Expire, config.RpcConfig.TokenPolicy.Expire,
config.Share.MultiLogin, config.Share.MultiLogin,
config.Share.IMAdminUserID,
), ),
config: config, config: config,
}) })
@ -129,6 +131,10 @@ func (s *authServer) parseToken(ctx context.Context, tokensString string) (claim
if err != nil { if err != nil {
return nil, errs.Wrap(err) return nil, errs.Wrap(err)
} }
isAdmin := authverify.IsManagerUserID(claims.UserID, s.config.Share.IMAdminUserID)
if isAdmin {
return claims, nil
}
m, err := s.authDatabase.GetTokensWithoutError(ctx, claims.UserID, claims.PlatformID) m, err := s.authDatabase.GetTokensWithoutError(ctx, claims.UserID, claims.PlatformID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -190,7 +196,7 @@ func (s *authServer) forceKickOff(ctx context.Context, userID string, platformID
} }
m, err := s.authDatabase.GetTokensWithoutError(ctx, userID, int(platformID)) m, err := s.authDatabase.GetTokensWithoutError(ctx, userID, int(platformID))
if err != nil && err != redis.Nil { if err != nil && errors.Is(err, redis.Nil) {
return err return err
} }
for k := range m { for k := range m {
@ -208,7 +214,7 @@ func (s *authServer) forceKickOff(ctx context.Context, userID string, platformID
func (s *authServer) InvalidateToken(ctx context.Context, req *pbauth.InvalidateTokenReq) (*pbauth.InvalidateTokenResp, error) { func (s *authServer) InvalidateToken(ctx context.Context, req *pbauth.InvalidateTokenReq) (*pbauth.InvalidateTokenResp, error) {
m, err := s.authDatabase.GetTokensWithoutError(ctx, req.UserID, int(req.PlatformID)) m, err := s.authDatabase.GetTokensWithoutError(ctx, req.UserID, int(req.PlatformID))
if err != nil && err != redis.Nil { if err != nil && errors.Is(err, redis.Nil) {
return nil, err return nil, err
} }
if m == nil { if m == nil {

View File

@ -34,14 +34,26 @@ type authDatabase struct {
accessSecret string accessSecret string
accessExpire int64 accessExpire int64
multiLogin multiLoginConfig multiLogin multiLoginConfig
adminUserIDs []string
} }
func NewAuthDatabase(cache cache.TokenModel, accessSecret string, accessExpire int64, multiLogin config.MultiLogin) AuthDatabase { func NewAuthDatabase(cache cache.TokenModel, accessSecret string, accessExpire int64, multiLogin config.MultiLogin, adminUserIDs []string) AuthDatabase {
return &authDatabase{cache: cache, accessSecret: accessSecret, accessExpire: accessExpire, multiLogin: multiLoginConfig{ return &authDatabase{cache: cache, accessSecret: accessSecret, accessExpire: accessExpire, multiLogin: multiLoginConfig{
Policy: multiLogin.Policy, Policy: multiLogin.Policy,
MaxNumOneEnd: multiLogin.MaxNumOneEnd, MaxNumOneEnd: multiLogin.MaxNumOneEnd,
CustomizeLoginNum: map[int]int{
constant.IOSPlatformID: multiLogin.CustomizeLoginNum.IOS,
constant.AndroidPlatformID: multiLogin.CustomizeLoginNum.Android,
constant.WindowsPlatformID: multiLogin.CustomizeLoginNum.Windows,
constant.OSXPlatformID: multiLogin.CustomizeLoginNum.OSX,
constant.WebPlatformID: multiLogin.CustomizeLoginNum.Web,
constant.MiniWebPlatformID: multiLogin.CustomizeLoginNum.MiniWeb,
constant.LinuxPlatformID: multiLogin.CustomizeLoginNum.Linux,
constant.AndroidPadPlatformID: multiLogin.CustomizeLoginNum.APad,
constant.IPadPlatformID: multiLogin.CustomizeLoginNum.IPad,
constant.AdminPlatformID: multiLogin.CustomizeLoginNum.Admin,
}, },
adminUserIDs: adminUserIDs, }, adminUserIDs: adminUserIDs,
} }
} }
@ -79,10 +91,13 @@ func (a *authDatabase) BatchSetTokenMapByUidPid(ctx context.Context, tokens []st
// Create Token. // Create Token.
func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformID int) (string, error) { func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformID int) (string, error) {
isAdmin := authverify.IsManagerUserID(userID, a.adminUserIDs)
if !isAdmin {
tokens, err := a.cache.GetAllTokensWithoutError(ctx, userID) tokens, err := a.cache.GetAllTokensWithoutError(ctx, userID)
if err != nil { if err != nil {
return "", err return "", err
} }
deleteTokenKey, kickedTokenKey, err := a.checkToken(ctx, tokens, platformID) deleteTokenKey, kickedTokenKey, err := a.checkToken(ctx, tokens, platformID)
if err != nil { if err != nil {
return "", err return "", err
@ -102,6 +117,7 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI
log.ZDebug(ctx, "kicked token in create token", "token", k) log.ZDebug(ctx, "kicked token in create token", "token", k)
} }
} }
}
claims := tokenverify.BuildClaims(userID, platformID, a.accessExpire) claims := tokenverify.BuildClaims(userID, platformID, a.accessExpire)
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
@ -110,9 +126,12 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI
return "", errs.WrapMsg(err, "token.SignedString") return "", errs.WrapMsg(err, "token.SignedString")
} }
if !isAdmin {
if err = a.cache.SetTokenFlagEx(ctx, userID, platformID, tokenString, constant.NormalToken); err != nil { if err = a.cache.SetTokenFlagEx(ctx, userID, platformID, tokenString, constant.NormalToken); err != nil {
return "", err return "", err
} }
}
return tokenString, nil return tokenString, nil
} }
@ -215,16 +234,16 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string
return nil, nil, errs.New("unknown multiLogin policy").Wrap() return nil, nil, errs.New("unknown multiLogin policy").Wrap()
} }
var adminTokenMaxNum = a.multiLogin.MaxNumOneEnd //var adminTokenMaxNum = a.multiLogin.MaxNumOneEnd
if a.multiLogin.Policy == constant.Customize { //if a.multiLogin.Policy == constant.Customize {
adminTokenMaxNum = a.multiLogin.CustomizeLoginNum[constant.AdminPlatformID] // adminTokenMaxNum = a.multiLogin.CustomizeLoginNum[constant.AdminPlatformID]
} //}
l := len(adminToken) //l := len(adminToken)
if platformID == constant.AdminPlatformID { //if platformID == constant.AdminPlatformID {
l++ // l++
} //}
if l > adminTokenMaxNum { //if l > adminTokenMaxNum {
kickToken = append(kickToken, adminToken[:l-adminTokenMaxNum]...) // kickToken = append(kickToken, adminToken[:l-adminTokenMaxNum]...)
} //}
return deleteToken, kickToken, nil return deleteToken, kickToken, nil
} }

View File

@ -490,7 +490,7 @@ func (db *commonMsgDatabase) GetMsgBySeqs(ctx context.Context, userID string, co
} }
successMsgs, failedSeqs, err := db.msg.GetMessagesBySeq(ctx, conversationID, newSeqs) successMsgs, failedSeqs, err := db.msg.GetMessagesBySeq(ctx, conversationID, newSeqs)
if err != nil { if err != nil {
if err != redis.Nil { if errors.Is(err, redis.Nil) {
log.ZError(ctx, "get message from redis exception", err, "failedSeqs", failedSeqs, "conversationID", conversationID) log.ZError(ctx, "get message from redis exception", err, "failedSeqs", failedSeqs, "conversationID", conversationID)
} }
} }