mirror of
				https://github.com/openimsdk/open-im-server.git
				synced 2025-10-25 20:52:11 +08:00 
			
		
		
		
	fix: delete token by correct platformID && feat: adminToken can be retained for five minutes after deleting (#3313)
This commit is contained in:
		
							parent
							
								
									c29e2a9a28
								
							
						
					
					
						commit
						78aaf6abce
					
				| @ -140,15 +140,17 @@ func (s *authServer) parseToken(ctx context.Context, tokensString string) (claim | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	isAdmin := authverify.IsManagerUserID(claims.UserID, s.config.Share.IMAdminUserID) | ||||
| 	if isAdmin { | ||||
| 		return claims, nil | ||||
| 	} | ||||
| 	m, err := s.authDatabase.GetTokensWithoutError(ctx, claims.UserID, claims.PlatformID) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if len(m) == 0 { | ||||
| 		isAdmin := authverify.IsManagerUserID(claims.UserID, s.config.Share.IMAdminUserID) | ||||
| 		if isAdmin { | ||||
| 			if err = s.authDatabase.GetTemporaryTokensWithoutError(ctx, claims.UserID, claims.PlatformID, tokensString); err == nil { | ||||
| 				return claims, nil | ||||
| 			} | ||||
| 		} | ||||
| 		return nil, servererrs.ErrTokenNotExist.Wrap() | ||||
| 	} | ||||
| 	if v, ok := m[tokensString]; ok { | ||||
| @ -160,6 +162,13 @@ func (s *authServer) parseToken(ctx context.Context, tokensString string) (claim | ||||
| 		default: | ||||
| 			return nil, errs.Wrap(errs.ErrTokenUnknown) | ||||
| 		} | ||||
| 	} else { | ||||
| 		isAdmin := authverify.IsManagerUserID(claims.UserID, s.config.Share.IMAdminUserID) | ||||
| 		if isAdmin { | ||||
| 			if err = s.authDatabase.GetTemporaryTokensWithoutError(ctx, claims.UserID, claims.PlatformID, tokensString); err == nil { | ||||
| 				return claims, nil | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return nil, servererrs.ErrTokenNotExist.Wrap() | ||||
| } | ||||
|  | ||||
							
								
								
									
										7
									
								
								pkg/common/storage/cache/cachekey/token.go
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										7
									
								
								pkg/common/storage/cache/cachekey/token.go
									
									
									
									
										vendored
									
									
								
							| @ -1,8 +1,9 @@ | ||||
| package cachekey | ||||
| 
 | ||||
| import ( | ||||
| 	"github.com/openimsdk/protocol/constant" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"github.com/openimsdk/protocol/constant" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| @ -13,6 +14,10 @@ func GetTokenKey(userID string, platformID int) string { | ||||
| 	return UidPidToken + userID + ":" + constant.PlatformIDToName(platformID) | ||||
| } | ||||
| 
 | ||||
| func GetTemporaryTokenKey(userID string, platformID int, token string) string { | ||||
| 	return UidPidToken + ":TEMPORARY:" + userID + ":" + constant.PlatformIDToName(platformID) + ":" + token | ||||
| } | ||||
| 
 | ||||
| func GetAllPlatformTokenKey(userID string) []string { | ||||
| 	res := make([]string, len(constant.PlatformID2Name)) | ||||
| 	for k := range constant.PlatformID2Name { | ||||
|  | ||||
							
								
								
									
										166
									
								
								pkg/common/storage/cache/mcache/token.go
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										166
									
								
								pkg/common/storage/cache/mcache/token.go
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @ -0,0 +1,166 @@ | ||||
| package mcache | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache" | ||||
| 	"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey" | ||||
| 	"github.com/openimsdk/open-im-server/v3/pkg/common/storage/database" | ||||
| 	"github.com/openimsdk/tools/errs" | ||||
| 	"github.com/openimsdk/tools/log" | ||||
| ) | ||||
| 
 | ||||
| func NewTokenCacheModel(cache database.Cache, accessExpire int64) cache.TokenModel { | ||||
| 	c := &tokenCache{cache: cache} | ||||
| 	c.accessExpire = c.getExpireTime(accessExpire) | ||||
| 	return c | ||||
| } | ||||
| 
 | ||||
| type tokenCache struct { | ||||
| 	cache        database.Cache | ||||
| 	accessExpire time.Duration | ||||
| } | ||||
| 
 | ||||
| func (x *tokenCache) getTokenKey(userID string, platformID int, token string) string { | ||||
| 	return cachekey.GetTokenKey(userID, platformID) + ":" + token | ||||
| } | ||||
| 
 | ||||
| func (x *tokenCache) SetTokenFlag(ctx context.Context, userID string, platformID int, token string, flag int) error { | ||||
| 	return x.cache.Set(ctx, x.getTokenKey(userID, platformID, token), strconv.Itoa(flag), x.accessExpire) | ||||
| } | ||||
| 
 | ||||
| // SetTokenFlagEx set token and flag with expire time | ||||
| func (x *tokenCache) SetTokenFlagEx(ctx context.Context, userID string, platformID int, token string, flag int) error { | ||||
| 	return x.SetTokenFlag(ctx, userID, platformID, token, flag) | ||||
| } | ||||
| 
 | ||||
| func (x *tokenCache) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) { | ||||
| 	prefix := x.getTokenKey(userID, platformID, "") | ||||
| 	m, err := x.cache.Prefix(ctx, prefix) | ||||
| 	if err != nil { | ||||
| 		return nil, errs.Wrap(err) | ||||
| 	} | ||||
| 	mm := make(map[string]int) | ||||
| 	for k, v := range m { | ||||
| 		state, err := strconv.Atoi(v) | ||||
| 		if err != nil { | ||||
| 			log.ZError(ctx, "token value is not int", err, "value", v, "userID", userID, "platformID", platformID) | ||||
| 			continue | ||||
| 		} | ||||
| 		mm[strings.TrimPrefix(k, prefix)] = state | ||||
| 	} | ||||
| 	return mm, nil | ||||
| } | ||||
| 
 | ||||
| func (x *tokenCache) HasTemporaryToken(ctx context.Context, userID string, platformID int, token string) error { | ||||
| 	key := cachekey.GetTemporaryTokenKey(userID, platformID, token) | ||||
| 	if _, err := x.cache.Get(ctx, []string{key}); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (x *tokenCache) GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error) { | ||||
| 	prefix := cachekey.UidPidToken + userID + ":" | ||||
| 	tokens, err := x.cache.Prefix(ctx, prefix) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	res := make(map[int]map[string]int) | ||||
| 	for key, flagStr := range tokens { | ||||
| 		flag, err := strconv.Atoi(flagStr) | ||||
| 		if err != nil { | ||||
| 			log.ZError(ctx, "token value is not int", err, "key", key, "value", flagStr, "userID", userID) | ||||
| 			continue | ||||
| 		} | ||||
| 		arr := strings.SplitN(strings.TrimPrefix(key, prefix), ":", 2) | ||||
| 		if len(arr) != 2 { | ||||
| 			log.ZError(ctx, "token value is not int", err, "key", key, "value", flagStr, "userID", userID) | ||||
| 			continue | ||||
| 		} | ||||
| 		platformID, err := strconv.Atoi(arr[0]) | ||||
| 		if err != nil { | ||||
| 			log.ZError(ctx, "token value is not int", err, "key", key, "value", flagStr, "userID", userID) | ||||
| 			continue | ||||
| 		} | ||||
| 		token := arr[1] | ||||
| 		if token == "" { | ||||
| 			log.ZError(ctx, "token value is not int", err, "key", key, "value", flagStr, "userID", userID) | ||||
| 			continue | ||||
| 		} | ||||
| 		tk, ok := res[platformID] | ||||
| 		if !ok { | ||||
| 			tk = make(map[string]int) | ||||
| 			res[platformID] = tk | ||||
| 		} | ||||
| 		tk[token] = flag | ||||
| 	} | ||||
| 	return res, nil | ||||
| } | ||||
| 
 | ||||
| func (x *tokenCache) SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error { | ||||
| 	for token, flag := range m { | ||||
| 		err := x.SetTokenFlag(ctx, userID, platformID, token, flag) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (x *tokenCache) BatchSetTokenMapByUidPid(ctx context.Context, tokens map[string]map[string]any) error { | ||||
| 	for prefix, tokenFlag := range tokens { | ||||
| 		for token, flag := range tokenFlag { | ||||
| 			flagStr := fmt.Sprintf("%v", flag) | ||||
| 			if err := x.cache.Set(ctx, prefix+":"+token, flagStr, x.accessExpire); err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (x *tokenCache) DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error { | ||||
| 	keys := make([]string, 0, len(fields)) | ||||
| 	for _, token := range fields { | ||||
| 		keys = append(keys, x.getTokenKey(userID, platformID, token)) | ||||
| 	} | ||||
| 	return x.cache.Del(ctx, keys) | ||||
| } | ||||
| 
 | ||||
| func (x *tokenCache) getExpireTime(t int64) time.Duration { | ||||
| 	return time.Hour * 24 * time.Duration(t) | ||||
| } | ||||
| 
 | ||||
| func (x *tokenCache) DeleteTokenByTokenMap(ctx context.Context, userID string, tokens map[int][]string) error { | ||||
| 	keys := make([]string, 0, len(tokens)) | ||||
| 	for platformID, ts := range tokens { | ||||
| 		for _, t := range ts { | ||||
| 			keys = append(keys, x.getTokenKey(userID, platformID, t)) | ||||
| 		} | ||||
| 	} | ||||
| 	return x.cache.Del(ctx, keys) | ||||
| } | ||||
| 
 | ||||
| func (x *tokenCache) DeleteAndSetTemporary(ctx context.Context, userID string, platformID int, fields []string) error { | ||||
| 	keys := make([]string, 0, len(fields)) | ||||
| 	for _, f := range fields { | ||||
| 		keys = append(keys, x.getTokenKey(userID, platformID, f)) | ||||
| 	} | ||||
| 	if err := x.cache.Del(ctx, keys); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	for _, f := range fields { | ||||
| 		k := cachekey.GetTemporaryTokenKey(userID, platformID, f) | ||||
| 		if err := x.cache.Set(ctx, k, "", time.Minute*5); err != nil { | ||||
| 			return errs.Wrap(err) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										59
									
								
								pkg/common/storage/cache/redis/token.go
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										59
									
								
								pkg/common/storage/cache/redis/token.go
									
									
									
									
										vendored
									
									
								
							| @ -9,6 +9,7 @@ import ( | ||||
| 	"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache" | ||||
| 	"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey" | ||||
| 	"github.com/openimsdk/tools/errs" | ||||
| 	"github.com/openimsdk/tools/utils/datautil" | ||||
| 	"github.com/redis/go-redis/v9" | ||||
| ) | ||||
| 
 | ||||
| @ -55,6 +56,14 @@ func (c *tokenCache) GetTokensWithoutError(ctx context.Context, userID string, p | ||||
| 	return mm, nil | ||||
| } | ||||
| 
 | ||||
| func (c *tokenCache) HasTemporaryToken(ctx context.Context, userID string, platformID int, token string) error { | ||||
| 	err := c.rdb.Get(ctx, cachekey.GetTemporaryTokenKey(userID, platformID, token)).Err() | ||||
| 	if err != nil { | ||||
| 		return errs.Wrap(err) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (c *tokenCache) GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error) { | ||||
| 	var ( | ||||
| 		res     = make(map[int]map[string]int) | ||||
| @ -101,6 +110,8 @@ func (c *tokenCache) SetTokenMapByUidPid(ctx context.Context, userID string, pla | ||||
| } | ||||
| 
 | ||||
| func (c *tokenCache) BatchSetTokenMapByUidPid(ctx context.Context, tokens map[string]map[string]any) error { | ||||
| 	keys := datautil.Keys(tokens) | ||||
| 	if err := ProcessKeysBySlot(ctx, c.rdb, keys, func(ctx context.Context, slot int64, keys []string) error { | ||||
| 		pipe := c.rdb.Pipeline() | ||||
| 		for k, v := range tokens { | ||||
| 			pipe.HSet(ctx, k, v) | ||||
| @ -110,6 +121,10 @@ func (c *tokenCache) BatchSetTokenMapByUidPid(ctx context.Context, tokens map[st | ||||
| 			return errs.Wrap(err) | ||||
| 		} | ||||
| 		return nil | ||||
| 	}); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (c *tokenCache) DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error { | ||||
| @ -119,3 +134,47 @@ func (c *tokenCache) DeleteTokenByUidPid(ctx context.Context, userID string, pla | ||||
| func (c *tokenCache) getExpireTime(t int64) time.Duration { | ||||
| 	return time.Hour * 24 * time.Duration(t) | ||||
| } | ||||
| 
 | ||||
| // DeleteTokenByTokenMap tokens key is platformID, value is token slice | ||||
| func (c *tokenCache) DeleteTokenByTokenMap(ctx context.Context, userID string, tokens map[int][]string) error { | ||||
| 	var ( | ||||
| 		keys   = make([]string, 0, len(tokens)) | ||||
| 		keyMap = make(map[string][]string) | ||||
| 	) | ||||
| 	for k, v := range tokens { | ||||
| 		k1 := cachekey.GetTokenKey(userID, k) | ||||
| 		keys = append(keys, k1) | ||||
| 		keyMap[k1] = v | ||||
| 	} | ||||
| 
 | ||||
| 	if err := ProcessKeysBySlot(ctx, c.rdb, keys, func(ctx context.Context, slot int64, keys []string) error { | ||||
| 		pipe := c.rdb.Pipeline() | ||||
| 		for k, v := range tokens { | ||||
| 			pipe.HDel(ctx, cachekey.GetTokenKey(userID, k), v...) | ||||
| 		} | ||||
| 		_, err := pipe.Exec(ctx) | ||||
| 		if err != nil { | ||||
| 			return errs.Wrap(err) | ||||
| 		} | ||||
| 		return nil | ||||
| 	}); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (c *tokenCache) DeleteAndSetTemporary(ctx context.Context, userID string, platformID int, fields []string) error { | ||||
| 	key := cachekey.GetTokenKey(userID, platformID) | ||||
| 	if err := c.rdb.HDel(ctx, key, fields...).Err(); err != nil { | ||||
| 		return errs.Wrap(err) | ||||
| 	} | ||||
| 	for _, f := range fields { | ||||
| 		k := cachekey.GetTemporaryTokenKey(userID, platformID, f) | ||||
| 		if err := c.rdb.Set(ctx, k, "", time.Minute*5).Err(); err != nil { | ||||
| 			return errs.Wrap(err) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
							
								
								
									
										3
									
								
								pkg/common/storage/cache/token.go
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								pkg/common/storage/cache/token.go
									
									
									
									
										vendored
									
									
								
							| @ -9,8 +9,11 @@ type TokenModel interface { | ||||
| 	// SetTokenFlagEx set token and flag with expire time | ||||
| 	SetTokenFlagEx(ctx context.Context, userID string, platformID int, token string, flag int) error | ||||
| 	GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) | ||||
| 	HasTemporaryToken(ctx context.Context, userID string, platformID int, token string) error | ||||
| 	GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error) | ||||
| 	SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error | ||||
| 	BatchSetTokenMapByUidPid(ctx context.Context, tokens map[string]map[string]any) error | ||||
| 	DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error | ||||
| 	DeleteTokenByTokenMap(ctx context.Context, userID string, tokens map[int][]string) error | ||||
| 	DeleteAndSetTemporary(ctx context.Context, userID string, platformID int, fields []string) error | ||||
| } | ||||
|  | ||||
| @ -17,6 +17,8 @@ import ( | ||||
| type AuthDatabase interface { | ||||
| 	// If the result is empty, no error is returned. | ||||
| 	GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) | ||||
| 
 | ||||
| 	GetTemporaryTokensWithoutError(ctx context.Context, userID string, platformID int, token string) error | ||||
| 	// Create token | ||||
| 	CreateToken(ctx context.Context, userID string, platformID int) (string, error) | ||||
| 
 | ||||
| @ -51,6 +53,10 @@ func (a *authDatabase) GetTokensWithoutError(ctx context.Context, userID string, | ||||
| 	return a.cache.GetTokensWithoutError(ctx, userID, platformID) | ||||
| } | ||||
| 
 | ||||
| func (a *authDatabase) GetTemporaryTokensWithoutError(ctx context.Context, userID string, platformID int, token string) error { | ||||
| 	return a.cache.HasTemporaryToken(ctx, userID, platformID, token) | ||||
| } | ||||
| 
 | ||||
| func (a *authDatabase) SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error { | ||||
| 	return a.cache.SetTokenMapByUidPid(ctx, userID, platformID, m) | ||||
| } | ||||
| @ -86,19 +92,20 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI | ||||
| 			return "", err | ||||
| 		} | ||||
| 
 | ||||
| 		deleteTokenKey, kickedTokenKey, err := a.checkToken(ctx, tokens, platformID) | ||||
| 	deleteTokenKey, kickedTokenKey, adminTokens, err := a.checkToken(ctx, tokens, platformID) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	if len(deleteTokenKey) != 0 { | ||||
| 			err = a.cache.DeleteTokenByUidPid(ctx, userID, platformID, deleteTokenKey) | ||||
| 		err = a.cache.DeleteTokenByTokenMap(ctx, userID, deleteTokenKey) | ||||
| 		if err != nil { | ||||
| 			return "", err | ||||
| 		} | ||||
| 	} | ||||
| 	if len(kickedTokenKey) != 0 { | ||||
| 			for _, k := range kickedTokenKey { | ||||
| 				err := a.cache.SetTokenFlagEx(ctx, userID, platformID, k, constant.KickedToken) | ||||
| 		for plt, ks := range kickedTokenKey { | ||||
| 			for _, k := range ks { | ||||
| 				err := a.cache.SetTokenFlagEx(ctx, userID, plt, k, constant.KickedToken) | ||||
| 				if err != nil { | ||||
| 					return "", err | ||||
| 				} | ||||
| @ -106,6 +113,11 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	if len(adminTokens) != 0 { | ||||
| 		if err = a.cache.DeleteAndSetTemporary(ctx, userID, constant.AdminPlatformID, adminTokens); err != nil { | ||||
| 			return "", err | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	claims := tokenverify.BuildClaims(userID, platformID, a.accessExpire) | ||||
| 	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) | ||||
| @ -123,12 +135,13 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI | ||||
| 	return tokenString, nil | ||||
| } | ||||
| 
 | ||||
| func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string]int, platformID int) ([]string, []string, error) { | ||||
| 	// todo: Move the logic for handling old data to another location. | ||||
| // checkToken will check token by tokenPolicy and return deleteToken,kickToken,deleteAdminToken | ||||
| func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string]int, platformID int) (map[int][]string, map[int][]string, []string, error) { | ||||
| 	// todo: Asynchronous deletion of old data. | ||||
| 	var ( | ||||
| 		loginTokenMap  = make(map[int][]string) // The length of the value of the map must be greater than 0 | ||||
| 		deleteToken    = make([]string, 0) | ||||
| 		kickToken      = make([]string, 0) | ||||
| 		deleteToken    = make(map[int][]string) | ||||
| 		kickToken      = make(map[int][]string) | ||||
| 		adminToken     = make([]string, 0) | ||||
| 		unkickTerminal = "" | ||||
| 	) | ||||
| @ -137,7 +150,7 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string | ||||
| 		for k, v := range tks { | ||||
| 			_, err := tokenverify.GetClaimFromToken(k, authverify.Secret(a.accessSecret)) | ||||
| 			if err != nil || v != constant.NormalToken { | ||||
| 				deleteToken = append(deleteToken, k) | ||||
| 				deleteToken[plfID] = append(deleteToken[plfID], k) | ||||
| 			} else { | ||||
| 				if plfID != constant.AdminPlatformID { | ||||
| 					loginTokenMap[plfID] = append(loginTokenMap[plfID], k) | ||||
| @ -157,14 +170,15 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string | ||||
| 			} | ||||
| 			limit := a.multiLogin.MaxNumOneEnd | ||||
| 			if l > limit { | ||||
| 				kickToken = append(kickToken, ts[:l-limit]...) | ||||
| 				kickToken[plt] = ts[:l-limit] | ||||
| 			} | ||||
| 		} | ||||
| 	case constant.AllLoginButSameTermKick: | ||||
| 		for plt, ts := range loginTokenMap { | ||||
| 			kickToken = append(kickToken, ts[:len(ts)-1]...) | ||||
| 			kickToken[plt] = ts[:len(ts)-1] | ||||
| 
 | ||||
| 			if plt == platformID { | ||||
| 				kickToken = append(kickToken, ts[len(ts)-1]) | ||||
| 				kickToken[plt] = append(kickToken[plt], ts[len(ts)-1]) | ||||
| 			} | ||||
| 		} | ||||
| 	case constant.PCAndOther: | ||||
| @ -172,29 +186,33 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string | ||||
| 		if constant.PlatformIDToClass(platformID) != unkickTerminal { | ||||
| 			for plt, ts := range loginTokenMap { | ||||
| 				if constant.PlatformIDToClass(plt) != unkickTerminal { | ||||
| 					kickToken = append(kickToken, ts...) | ||||
| 					kickToken[plt] = ts | ||||
| 				} | ||||
| 			} | ||||
| 		} else { | ||||
| 			var ( | ||||
| 				preKick   []string | ||||
| 				isReserve = true | ||||
| 				preKickToken string | ||||
| 				preKickPlt   int | ||||
| 				reserveToken = false | ||||
| 			) | ||||
| 			for plt, ts := range loginTokenMap { | ||||
| 				if constant.PlatformIDToClass(plt) != unkickTerminal { | ||||
| 					// Keep a token from another end | ||||
| 					if isReserve { | ||||
| 						isReserve = false | ||||
| 						kickToken = append(kickToken, ts[:len(ts)-1]...) | ||||
| 						preKick = append(preKick, ts[len(ts)-1]) | ||||
| 					if !reserveToken { | ||||
| 						reserveToken = true | ||||
| 						kickToken[plt] = ts[:len(ts)-1] | ||||
| 						preKickToken = ts[len(ts)-1] | ||||
| 						preKickPlt = plt | ||||
| 						continue | ||||
| 					} else { | ||||
| 						// Prioritize keeping Android | ||||
| 						if plt == constant.AndroidPlatformID { | ||||
| 							kickToken = append(kickToken, preKick...) | ||||
| 							kickToken = append(kickToken, ts[:len(ts)-1]...) | ||||
| 							if preKickToken != "" { | ||||
| 								kickToken[preKickPlt] = append(kickToken[preKickPlt], preKickToken) | ||||
| 							} | ||||
| 							kickToken[plt] = ts[:len(ts)-1] | ||||
| 						} else { | ||||
| 							kickToken = append(kickToken, ts...) | ||||
| 							kickToken[plt] = ts | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| @ -207,19 +225,19 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string | ||||
| 
 | ||||
| 		for plt, ts := range loginTokenMap { | ||||
| 			if constant.PlatformIDToClass(plt) == constant.PlatformIDToClass(platformID) { | ||||
| 				kickToken = append(kickToken, ts...) | ||||
| 				kickToken[plt] = ts | ||||
| 			} else { | ||||
| 				if _, ok := reserved[constant.PlatformIDToClass(plt)]; !ok { | ||||
| 					reserved[constant.PlatformIDToClass(plt)] = struct{}{} | ||||
| 					kickToken = append(kickToken, ts[:len(ts)-1]...) | ||||
| 					kickToken[plt] = ts[:len(ts)-1] | ||||
| 					continue | ||||
| 				} else { | ||||
| 					kickToken = append(kickToken, ts...) | ||||
| 					kickToken[plt] = ts | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	default: | ||||
| 		return nil, nil, errs.New("unknown multiLogin policy").Wrap() | ||||
| 		return nil, nil, nil, errs.New("unknown multiLogin policy").Wrap() | ||||
| 	} | ||||
| 
 | ||||
| 	//var adminTokenMaxNum = a.multiLogin.MaxNumOneEnd | ||||
| @ -233,5 +251,9 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string | ||||
| 	//if l > adminTokenMaxNum { | ||||
| 	//	kickToken = append(kickToken, adminToken[:l-adminTokenMaxNum]...) | ||||
| 	//} | ||||
| 	return deleteToken, kickToken, nil | ||||
| 	var deleteAdminToken []string | ||||
| 	if platformID == constant.AdminPlatformID { | ||||
| 		deleteAdminToken = adminToken | ||||
| 	} | ||||
| 	return deleteToken, kickToken, deleteAdminToken, nil | ||||
| } | ||||
|  | ||||
							
								
								
									
										183
									
								
								pkg/common/storage/database/mgo/cache.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										183
									
								
								pkg/common/storage/database/mgo/cache.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,183 @@ | ||||
| package mgo | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/google/uuid" | ||||
| 	"github.com/openimsdk/open-im-server/v3/pkg/common/storage/database" | ||||
| 	"github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" | ||||
| 	"github.com/openimsdk/tools/db/mongoutil" | ||||
| 	"github.com/openimsdk/tools/errs" | ||||
| 	"go.mongodb.org/mongo-driver/bson" | ||||
| 	"go.mongodb.org/mongo-driver/mongo" | ||||
| 	"go.mongodb.org/mongo-driver/mongo/options" | ||||
| ) | ||||
| 
 | ||||
| func NewCacheMgo(db *mongo.Database) (*CacheMgo, error) { | ||||
| 	coll := db.Collection(database.CacheName) | ||||
| 	_, err := coll.Indexes().CreateMany(context.Background(), []mongo.IndexModel{ | ||||
| 		{ | ||||
| 			Keys: bson.D{ | ||||
| 				{Key: "key", Value: 1}, | ||||
| 			}, | ||||
| 			Options: options.Index().SetUnique(true), | ||||
| 		}, | ||||
| 		{ | ||||
| 			Keys: bson.D{ | ||||
| 				{Key: "expire_at", Value: 1}, | ||||
| 			}, | ||||
| 			Options: options.Index().SetExpireAfterSeconds(0), | ||||
| 		}, | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		return nil, errs.Wrap(err) | ||||
| 	} | ||||
| 	return &CacheMgo{coll: coll}, nil | ||||
| } | ||||
| 
 | ||||
| type CacheMgo struct { | ||||
| 	coll *mongo.Collection | ||||
| } | ||||
| 
 | ||||
| func (x *CacheMgo) findToMap(res []model.Cache, now time.Time) map[string]string { | ||||
| 	kv := make(map[string]string) | ||||
| 	for _, re := range res { | ||||
| 		if re.ExpireAt != nil && re.ExpireAt.Before(now) { | ||||
| 			continue | ||||
| 		} | ||||
| 		kv[re.Key] = re.Value | ||||
| 	} | ||||
| 	return kv | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| func (x *CacheMgo) Get(ctx context.Context, key []string) (map[string]string, error) { | ||||
| 	if len(key) == 0 { | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 	now := time.Now() | ||||
| 	res, err := mongoutil.Find[model.Cache](ctx, x.coll, bson.M{ | ||||
| 		"key": bson.M{"$in": key}, | ||||
| 		"$or": []bson.M{ | ||||
| 			{"expire_at": bson.M{"$gt": now}}, | ||||
| 			{"expire_at": nil}, | ||||
| 		}, | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return x.findToMap(res, now), nil | ||||
| } | ||||
| 
 | ||||
| func (x *CacheMgo) Prefix(ctx context.Context, prefix string) (map[string]string, error) { | ||||
| 	now := time.Now() | ||||
| 	res, err := mongoutil.Find[model.Cache](ctx, x.coll, bson.M{ | ||||
| 		"key": bson.M{"$regex": "^" + prefix}, | ||||
| 		"$or": []bson.M{ | ||||
| 			{"expire_at": bson.M{"$gt": now}}, | ||||
| 			{"expire_at": nil}, | ||||
| 		}, | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return x.findToMap(res, now), nil | ||||
| } | ||||
| 
 | ||||
| func (x *CacheMgo) Set(ctx context.Context, key string, value string, expireAt time.Duration) error { | ||||
| 	cv := &model.Cache{ | ||||
| 		Key:   key, | ||||
| 		Value: value, | ||||
| 	} | ||||
| 	if expireAt > 0 { | ||||
| 		now := time.Now().Add(expireAt) | ||||
| 		cv.ExpireAt = &now | ||||
| 	} | ||||
| 	opt := options.Update().SetUpsert(true) | ||||
| 	return mongoutil.UpdateOne(ctx, x.coll, bson.M{"key": key}, bson.M{"$set": cv}, false, opt) | ||||
| } | ||||
| 
 | ||||
| func (x *CacheMgo) Incr(ctx context.Context, key string, value int) (int, error) { | ||||
| 	pipeline := mongo.Pipeline{ | ||||
| 		{ | ||||
| 			{"$set", bson.M{ | ||||
| 				"value": bson.M{ | ||||
| 					"$toString": bson.M{ | ||||
| 						"$add": bson.A{ | ||||
| 							bson.M{"$toInt": "$value"}, | ||||
| 							value, | ||||
| 						}, | ||||
| 					}, | ||||
| 				}, | ||||
| 			}}, | ||||
| 		}, | ||||
| 	} | ||||
| 	opt := options.FindOneAndUpdate().SetReturnDocument(options.After) | ||||
| 	res, err := mongoutil.FindOneAndUpdate[model.Cache](ctx, x.coll, bson.M{"key": key}, pipeline, opt) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	return strconv.Atoi(res.Value) | ||||
| } | ||||
| 
 | ||||
| func (x *CacheMgo) Del(ctx context.Context, key []string) error { | ||||
| 	if len(key) == 0 { | ||||
| 		return nil | ||||
| 	} | ||||
| 	_, err := x.coll.DeleteMany(ctx, bson.M{"key": bson.M{"$in": key}}) | ||||
| 	return errs.Wrap(err) | ||||
| } | ||||
| 
 | ||||
| func (x *CacheMgo) lockKey(key string) string { | ||||
| 	return "LOCK_" + key | ||||
| } | ||||
| 
 | ||||
| func (x *CacheMgo) Lock(ctx context.Context, key string, duration time.Duration) (string, error) { | ||||
| 	tmp, err := uuid.NewUUID() | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	if duration <= 0 || duration > time.Minute*10 { | ||||
| 		duration = time.Minute * 10 | ||||
| 	} | ||||
| 	cv := &model.Cache{ | ||||
| 		Key:      x.lockKey(key), | ||||
| 		Value:    tmp.String(), | ||||
| 		ExpireAt: nil, | ||||
| 	} | ||||
| 	ctx, cancel := context.WithTimeout(ctx, time.Second*30) | ||||
| 	defer cancel() | ||||
| 	wait := func() error { | ||||
| 		timeout := time.NewTimer(time.Millisecond * 100) | ||||
| 		defer timeout.Stop() | ||||
| 		select { | ||||
| 		case <-ctx.Done(): | ||||
| 			return ctx.Err() | ||||
| 		case <-timeout.C: | ||||
| 			return nil | ||||
| 		} | ||||
| 	} | ||||
| 	for { | ||||
| 		if err := mongoutil.DeleteOne(ctx, x.coll, bson.M{"key": key, "expire_at": bson.M{"$lt": time.Now()}}); err != nil { | ||||
| 			return "", err | ||||
| 		} | ||||
| 		expireAt := time.Now().Add(duration) | ||||
| 		cv.ExpireAt = &expireAt | ||||
| 		if err := mongoutil.InsertMany[*model.Cache](ctx, x.coll, []*model.Cache{cv}); err != nil { | ||||
| 			if mongo.IsDuplicateKeyError(err) { | ||||
| 				if err := wait(); err != nil { | ||||
| 					return "", err | ||||
| 				} | ||||
| 				continue | ||||
| 			} | ||||
| 			return "", err | ||||
| 		} | ||||
| 		return cv.Value, nil | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (x *CacheMgo) Unlock(ctx context.Context, key string, value string) error { | ||||
| 	return mongoutil.DeleteOne(ctx, x.coll, bson.M{"key": x.lockKey(key), "value": value}) | ||||
| } | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user