fix: delete token by correct platformID && feat: adminToken can be retained for five minutes after deleting (#3313)

This commit is contained in:
icey-yu 2025-05-06 15:10:10 +08:00 committed by GitHub
parent dd981b976d
commit 56c5c1f015
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 177 additions and 45 deletions

View File

@ -159,15 +159,17 @@ func (s *authServer) parseToken(ctx context.Context, tokensString string) (claim
if err != nil { if err != nil {
return nil, err return nil, err
} }
isAdmin := authverify.IsManagerUserID(claims.UserID, s.config.Share.IMAdminUserID)
if isAdmin {
return claims, nil
}
m, err := s.authDatabase.GetTokensWithoutError(ctx, claims.UserID, claims.PlatformID) m, err := s.authDatabase.GetTokensWithoutError(ctx, claims.UserID, claims.PlatformID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(m) == 0 { if len(m) == 0 {
isAdmin := authverify.IsManagerUserID(claims.UserID, s.config.Share.IMAdminUserID)
if isAdmin {
if err = s.authDatabase.GetTemporaryTokensWithoutError(ctx, claims.UserID, claims.PlatformID, tokensString); err == nil {
return claims, nil
}
}
return nil, servererrs.ErrTokenNotExist.Wrap() return nil, servererrs.ErrTokenNotExist.Wrap()
} }
if v, ok := m[tokensString]; ok { if v, ok := m[tokensString]; ok {
@ -179,6 +181,13 @@ func (s *authServer) parseToken(ctx context.Context, tokensString string) (claim
default: default:
return nil, errs.Wrap(errs.ErrTokenUnknown) return nil, errs.Wrap(errs.ErrTokenUnknown)
} }
} else {
isAdmin := authverify.IsManagerUserID(claims.UserID, s.config.Share.IMAdminUserID)
if isAdmin {
if err = s.authDatabase.GetTemporaryTokensWithoutError(ctx, claims.UserID, claims.PlatformID, tokensString); err == nil {
return claims, nil
}
}
} }
return nil, servererrs.ErrTokenNotExist.Wrap() return nil, servererrs.ErrTokenNotExist.Wrap()
} }

View File

@ -1,8 +1,9 @@
package cachekey package cachekey
import ( import (
"github.com/openimsdk/protocol/constant"
"strings" "strings"
"github.com/openimsdk/protocol/constant"
) )
const ( const (
@ -13,6 +14,10 @@ func GetTokenKey(userID string, platformID int) string {
return UidPidToken + userID + ":" + constant.PlatformIDToName(platformID) return UidPidToken + userID + ":" + constant.PlatformIDToName(platformID)
} }
func GetTemporaryTokenKey(userID string, platformID int, token string) string {
return UidPidToken + ":TEMPORARY:" + userID + ":" + constant.PlatformIDToName(platformID) + ":" + token
}
func GetAllPlatformTokenKey(userID string) []string { func GetAllPlatformTokenKey(userID string) []string {
res := make([]string, len(constant.PlatformID2Name)) res := make([]string, len(constant.PlatformID2Name))
for k := range constant.PlatformID2Name { for k := range constant.PlatformID2Name {

View File

@ -27,7 +27,6 @@ type tokenCache struct {
func (x *tokenCache) getTokenKey(userID string, platformID int, token string) string { func (x *tokenCache) getTokenKey(userID string, platformID int, token string) string {
return cachekey.GetTokenKey(userID, platformID) + ":" + token return cachekey.GetTokenKey(userID, platformID) + ":" + token
} }
func (x *tokenCache) SetTokenFlag(ctx context.Context, userID string, platformID int, token string, flag int) error { func (x *tokenCache) SetTokenFlag(ctx context.Context, userID string, platformID int, token string, flag int) error {
@ -57,6 +56,14 @@ func (x *tokenCache) GetTokensWithoutError(ctx context.Context, userID string, p
return mm, nil return mm, nil
} }
func (x *tokenCache) HasTemporaryToken(ctx context.Context, userID string, platformID int, token string) error {
key := cachekey.GetTemporaryTokenKey(userID, platformID, token)
if _, err := x.cache.Get(ctx, []string{key}); err != nil {
return err
}
return nil
}
func (x *tokenCache) GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error) { func (x *tokenCache) GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error) {
prefix := cachekey.UidPidToken + userID + ":" prefix := cachekey.UidPidToken + userID + ":"
tokens, err := x.cache.Prefix(ctx, prefix) tokens, err := x.cache.Prefix(ctx, prefix)
@ -128,3 +135,32 @@ func (x *tokenCache) DeleteTokenByUidPid(ctx context.Context, userID string, pla
func (x *tokenCache) getExpireTime(t int64) time.Duration { func (x *tokenCache) getExpireTime(t int64) time.Duration {
return time.Hour * 24 * time.Duration(t) return time.Hour * 24 * time.Duration(t)
} }
func (x *tokenCache) DeleteTokenByTokenMap(ctx context.Context, userID string, tokens map[int][]string) error {
keys := make([]string, 0, len(tokens))
for platformID, ts := range tokens {
for _, t := range ts {
keys = append(keys, x.getTokenKey(userID, platformID, t))
}
}
return x.cache.Del(ctx, keys)
}
func (x *tokenCache) DeleteAndSetTemporary(ctx context.Context, userID string, platformID int, fields []string) error {
keys := make([]string, 0, len(fields))
for _, f := range fields {
keys = append(keys, x.getTokenKey(userID, platformID, f))
}
if err := x.cache.Del(ctx, keys); err != nil {
return err
}
for _, f := range fields {
k := cachekey.GetTemporaryTokenKey(userID, platformID, f)
if err := x.cache.Set(ctx, k, "", time.Minute*5); err != nil {
return errs.Wrap(err)
}
}
return nil
}

View File

@ -9,6 +9,7 @@ import (
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey"
"github.com/openimsdk/tools/errs" "github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/utils/datautil"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )
@ -55,6 +56,14 @@ func (c *tokenCache) GetTokensWithoutError(ctx context.Context, userID string, p
return mm, nil return mm, nil
} }
func (c *tokenCache) HasTemporaryToken(ctx context.Context, userID string, platformID int, token string) error {
err := c.rdb.Get(ctx, cachekey.GetTemporaryTokenKey(userID, platformID, token)).Err()
if err != nil {
return errs.Wrap(err)
}
return nil
}
func (c *tokenCache) GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error) { func (c *tokenCache) GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error) {
var ( var (
res = make(map[int]map[string]int) res = make(map[int]map[string]int)
@ -101,6 +110,8 @@ func (c *tokenCache) SetTokenMapByUidPid(ctx context.Context, userID string, pla
} }
func (c *tokenCache) BatchSetTokenMapByUidPid(ctx context.Context, tokens map[string]map[string]any) error { func (c *tokenCache) BatchSetTokenMapByUidPid(ctx context.Context, tokens map[string]map[string]any) error {
keys := datautil.Keys(tokens)
if err := ProcessKeysBySlot(ctx, c.rdb, keys, func(ctx context.Context, slot int64, keys []string) error {
pipe := c.rdb.Pipeline() pipe := c.rdb.Pipeline()
for k, v := range tokens { for k, v := range tokens {
pipe.HSet(ctx, k, v) pipe.HSet(ctx, k, v)
@ -110,6 +121,10 @@ func (c *tokenCache) BatchSetTokenMapByUidPid(ctx context.Context, tokens map[st
return errs.Wrap(err) return errs.Wrap(err)
} }
return nil return nil
}); err != nil {
return err
}
return nil
} }
func (c *tokenCache) DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error { func (c *tokenCache) DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error {
@ -119,3 +134,47 @@ func (c *tokenCache) DeleteTokenByUidPid(ctx context.Context, userID string, pla
func (c *tokenCache) getExpireTime(t int64) time.Duration { func (c *tokenCache) getExpireTime(t int64) time.Duration {
return time.Hour * 24 * time.Duration(t) return time.Hour * 24 * time.Duration(t)
} }
// DeleteTokenByTokenMap tokens key is platformID, value is token slice
func (c *tokenCache) DeleteTokenByTokenMap(ctx context.Context, userID string, tokens map[int][]string) error {
var (
keys = make([]string, 0, len(tokens))
keyMap = make(map[string][]string)
)
for k, v := range tokens {
k1 := cachekey.GetTokenKey(userID, k)
keys = append(keys, k1)
keyMap[k1] = v
}
if err := ProcessKeysBySlot(ctx, c.rdb, keys, func(ctx context.Context, slot int64, keys []string) error {
pipe := c.rdb.Pipeline()
for k, v := range tokens {
pipe.HDel(ctx, cachekey.GetTokenKey(userID, k), v...)
}
_, err := pipe.Exec(ctx)
if err != nil {
return errs.Wrap(err)
}
return nil
}); err != nil {
return err
}
return nil
}
func (c *tokenCache) DeleteAndSetTemporary(ctx context.Context, userID string, platformID int, fields []string) error {
key := cachekey.GetTokenKey(userID, platformID)
if err := c.rdb.HDel(ctx, key, fields...).Err(); err != nil {
return errs.Wrap(err)
}
for _, f := range fields {
k := cachekey.GetTemporaryTokenKey(userID, platformID, f)
if err := c.rdb.Set(ctx, k, "", time.Minute*5).Err(); err != nil {
return errs.Wrap(err)
}
}
return nil
}

View File

@ -9,8 +9,11 @@ type TokenModel interface {
// SetTokenFlagEx set token and flag with expire time // SetTokenFlagEx set token and flag with expire time
SetTokenFlagEx(ctx context.Context, userID string, platformID int, token string, flag int) error SetTokenFlagEx(ctx context.Context, userID string, platformID int, token string, flag int) error
GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error)
HasTemporaryToken(ctx context.Context, userID string, platformID int, token string) error
GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error) GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error)
SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error
BatchSetTokenMapByUidPid(ctx context.Context, tokens map[string]map[string]any) error BatchSetTokenMapByUidPid(ctx context.Context, tokens map[string]map[string]any) error
DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error
DeleteTokenByTokenMap(ctx context.Context, userID string, tokens map[int][]string) error
DeleteAndSetTemporary(ctx context.Context, userID string, platformID int, fields []string) error
} }

View File

@ -17,6 +17,8 @@ import (
type AuthDatabase interface { type AuthDatabase interface {
// If the result is empty, no error is returned. // If the result is empty, no error is returned.
GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error)
GetTemporaryTokensWithoutError(ctx context.Context, userID string, platformID int, token string) error
// Create token // Create token
CreateToken(ctx context.Context, userID string, platformID int) (string, error) CreateToken(ctx context.Context, userID string, platformID int) (string, error)
@ -52,6 +54,10 @@ func (a *authDatabase) GetTokensWithoutError(ctx context.Context, userID string,
return a.cache.GetTokensWithoutError(ctx, userID, platformID) return a.cache.GetTokensWithoutError(ctx, userID, platformID)
} }
func (a *authDatabase) GetTemporaryTokensWithoutError(ctx context.Context, userID string, platformID int, token string) error {
return a.cache.HasTemporaryToken(ctx, userID, platformID, token)
}
func (a *authDatabase) SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error { func (a *authDatabase) SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error {
return a.cache.SetTokenMapByUidPid(ctx, userID, platformID, m) return a.cache.SetTokenMapByUidPid(ctx, userID, platformID, m)
} }
@ -85,25 +91,32 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI
return "", err return "", err
} }
deleteTokenKey, kickedTokenKey, err := a.checkToken(ctx, tokens, platformID) deleteTokenKey, kickedTokenKey, adminTokens, err := a.checkToken(ctx, tokens, platformID)
if err != nil { if err != nil {
return "", err return "", err
} }
if len(deleteTokenKey) != 0 { if len(deleteTokenKey) != 0 {
err = a.cache.DeleteTokenByUidPid(ctx, userID, platformID, deleteTokenKey) err = a.cache.DeleteTokenByTokenMap(ctx, userID, deleteTokenKey)
if err != nil { if err != nil {
return "", err return "", err
} }
} }
if len(kickedTokenKey) != 0 { if len(kickedTokenKey) != 0 {
for _, k := range kickedTokenKey { for plt, ks := range kickedTokenKey {
err := a.cache.SetTokenFlagEx(ctx, userID, platformID, k, constant.KickedToken) for _, k := range ks {
err := a.cache.SetTokenFlagEx(ctx, userID, plt, k, constant.KickedToken)
if err != nil { if err != nil {
return "", err return "", err
} }
log.ZDebug(ctx, "kicked token in create token", "token", k) log.ZDebug(ctx, "kicked token in create token", "token", k)
} }
} }
}
if len(adminTokens) != 0 {
if err = a.cache.DeleteAndSetTemporary(ctx, userID, constant.AdminPlatformID, adminTokens); err != nil {
return "", err
}
}
claims := tokenverify.BuildClaims(userID, platformID, a.accessExpire) claims := tokenverify.BuildClaims(userID, platformID, a.accessExpire)
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
@ -119,12 +132,13 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI
return tokenString, nil return tokenString, nil
} }
func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string]int, platformID int) ([]string, []string, error) { // checkToken will check token by tokenPolicy and return deleteToken,kickToken,deleteAdminToken
// todo: Move the logic for handling old data to another location. func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string]int, platformID int) (map[int][]string, map[int][]string, []string, error) {
// todo: Asynchronous deletion of old data.
var ( var (
loginTokenMap = make(map[int][]string) // The length of the value of the map must be greater than 0 loginTokenMap = make(map[int][]string) // The length of the value of the map must be greater than 0
deleteToken = make([]string, 0) deleteToken = make(map[int][]string)
kickToken = make([]string, 0) kickToken = make(map[int][]string)
adminToken = make([]string, 0) adminToken = make([]string, 0)
unkickTerminal = "" unkickTerminal = ""
) )
@ -133,7 +147,7 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string
for k, v := range tks { for k, v := range tks {
_, err := tokenverify.GetClaimFromToken(k, authverify.Secret(a.accessSecret)) _, err := tokenverify.GetClaimFromToken(k, authverify.Secret(a.accessSecret))
if err != nil || v != constant.NormalToken { if err != nil || v != constant.NormalToken {
deleteToken = append(deleteToken, k) deleteToken[plfID] = append(deleteToken[plfID], k)
} else { } else {
if plfID != constant.AdminPlatformID { if plfID != constant.AdminPlatformID {
loginTokenMap[plfID] = append(loginTokenMap[plfID], k) loginTokenMap[plfID] = append(loginTokenMap[plfID], k)
@ -153,14 +167,15 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string
} }
limit := a.multiLogin.MaxNumOneEnd limit := a.multiLogin.MaxNumOneEnd
if l > limit { if l > limit {
kickToken = append(kickToken, ts[:l-limit]...) kickToken[plt] = ts[:l-limit]
} }
} }
case constant.AllLoginButSameTermKick: case constant.AllLoginButSameTermKick:
for plt, ts := range loginTokenMap { for plt, ts := range loginTokenMap {
kickToken = append(kickToken, ts[:len(ts)-1]...) kickToken[plt] = ts[:len(ts)-1]
if plt == platformID { if plt == platformID {
kickToken = append(kickToken, ts[len(ts)-1]) kickToken[plt] = append(kickToken[plt], ts[len(ts)-1])
} }
} }
case constant.PCAndOther: case constant.PCAndOther:
@ -168,29 +183,33 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string
if constant.PlatformIDToClass(platformID) != unkickTerminal { if constant.PlatformIDToClass(platformID) != unkickTerminal {
for plt, ts := range loginTokenMap { for plt, ts := range loginTokenMap {
if constant.PlatformIDToClass(plt) != unkickTerminal { if constant.PlatformIDToClass(plt) != unkickTerminal {
kickToken = append(kickToken, ts...) kickToken[plt] = ts
} }
} }
} else { } else {
var ( var (
preKick []string preKickToken string
isReserve = true preKickPlt int
reserveToken = false
) )
for plt, ts := range loginTokenMap { for plt, ts := range loginTokenMap {
if constant.PlatformIDToClass(plt) != unkickTerminal { if constant.PlatformIDToClass(plt) != unkickTerminal {
// Keep a token from another end // Keep a token from another end
if isReserve { if !reserveToken {
isReserve = false reserveToken = true
kickToken = append(kickToken, ts[:len(ts)-1]...) kickToken[plt] = ts[:len(ts)-1]
preKick = append(preKick, ts[len(ts)-1]) preKickToken = ts[len(ts)-1]
preKickPlt = plt
continue continue
} else { } else {
// Prioritize keeping Android // Prioritize keeping Android
if plt == constant.AndroidPlatformID { if plt == constant.AndroidPlatformID {
kickToken = append(kickToken, preKick...) if preKickToken != "" {
kickToken = append(kickToken, ts[:len(ts)-1]...) kickToken[preKickPlt] = append(kickToken[preKickPlt], preKickToken)
}
kickToken[plt] = ts[:len(ts)-1]
} else { } else {
kickToken = append(kickToken, ts...) kickToken[plt] = ts
} }
} }
} }
@ -203,19 +222,19 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string
for plt, ts := range loginTokenMap { for plt, ts := range loginTokenMap {
if constant.PlatformIDToClass(plt) == constant.PlatformIDToClass(platformID) { if constant.PlatformIDToClass(plt) == constant.PlatformIDToClass(platformID) {
kickToken = append(kickToken, ts...) kickToken[plt] = ts
} else { } else {
if _, ok := reserved[constant.PlatformIDToClass(plt)]; !ok { if _, ok := reserved[constant.PlatformIDToClass(plt)]; !ok {
reserved[constant.PlatformIDToClass(plt)] = struct{}{} reserved[constant.PlatformIDToClass(plt)] = struct{}{}
kickToken = append(kickToken, ts[:len(ts)-1]...) kickToken[plt] = ts[:len(ts)-1]
continue continue
} else { } else {
kickToken = append(kickToken, ts...) kickToken[plt] = ts
} }
} }
} }
default: default:
return nil, nil, errs.New("unknown multiLogin policy").Wrap() return nil, nil, nil, errs.New("unknown multiLogin policy").Wrap()
} }
//var adminTokenMaxNum = a.multiLogin.MaxNumOneEnd //var adminTokenMaxNum = a.multiLogin.MaxNumOneEnd
@ -226,8 +245,9 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string
//if l > adminTokenMaxNum { //if l > adminTokenMaxNum {
// kickToken = append(kickToken, adminToken[:l-adminTokenMaxNum]...) // kickToken = append(kickToken, adminToken[:l-adminTokenMaxNum]...)
//} //}
var deleteAdminToken []string
if platformID == constant.AdminPlatformID { if platformID == constant.AdminPlatformID {
kickToken = append(kickToken, adminToken...) deleteAdminToken = adminToken
} }
return deleteToken, kickToken, nil return deleteToken, kickToken, deleteAdminToken, nil
} }

View File

@ -127,7 +127,7 @@ func (x *CacheMgo) Del(ctx context.Context, key []string) error {
return nil return nil
} }
_, err := x.coll.DeleteMany(ctx, bson.M{"key": bson.M{"$in": key}}) _, err := x.coll.DeleteMany(ctx, bson.M{"key": bson.M{"$in": key}})
return err return errs.Wrap(err)
} }
func (x *CacheMgo) lockKey(key string) string { func (x *CacheMgo) lockKey(key string) string {