Merge pull request #3341 from openimsdk/cherry-pick-56c5c1f

fix: delete token by correct platformID && feat: adminToken can be re… [Created by @icey-yu from #3313]
This commit is contained in:
chao 2025-05-14 16:37:55 +08:00 committed by GitHub
commit c7a934e8db
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 492 additions and 45 deletions

View File

@ -140,15 +140,17 @@ func (s *authServer) parseToken(ctx context.Context, tokensString string) (claim
if err != nil { if err != nil {
return nil, err return nil, err
} }
isAdmin := authverify.IsManagerUserID(claims.UserID, s.config.Share.IMAdminUserID)
if isAdmin {
return claims, nil
}
m, err := s.authDatabase.GetTokensWithoutError(ctx, claims.UserID, claims.PlatformID) m, err := s.authDatabase.GetTokensWithoutError(ctx, claims.UserID, claims.PlatformID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(m) == 0 { if len(m) == 0 {
isAdmin := authverify.IsManagerUserID(claims.UserID, s.config.Share.IMAdminUserID)
if isAdmin {
if err = s.authDatabase.GetTemporaryTokensWithoutError(ctx, claims.UserID, claims.PlatformID, tokensString); err == nil {
return claims, nil
}
}
return nil, servererrs.ErrTokenNotExist.Wrap() return nil, servererrs.ErrTokenNotExist.Wrap()
} }
if v, ok := m[tokensString]; ok { if v, ok := m[tokensString]; ok {
@ -160,6 +162,13 @@ func (s *authServer) parseToken(ctx context.Context, tokensString string) (claim
default: default:
return nil, errs.Wrap(errs.ErrTokenUnknown) return nil, errs.Wrap(errs.ErrTokenUnknown)
} }
} else {
isAdmin := authverify.IsManagerUserID(claims.UserID, s.config.Share.IMAdminUserID)
if isAdmin {
if err = s.authDatabase.GetTemporaryTokensWithoutError(ctx, claims.UserID, claims.PlatformID, tokensString); err == nil {
return claims, nil
}
}
} }
return nil, servererrs.ErrTokenNotExist.Wrap() return nil, servererrs.ErrTokenNotExist.Wrap()
} }

View File

@ -1,8 +1,9 @@
package cachekey package cachekey
import ( import (
"github.com/openimsdk/protocol/constant"
"strings" "strings"
"github.com/openimsdk/protocol/constant"
) )
const ( const (
@ -13,6 +14,10 @@ func GetTokenKey(userID string, platformID int) string {
return UidPidToken + userID + ":" + constant.PlatformIDToName(platformID) return UidPidToken + userID + ":" + constant.PlatformIDToName(platformID)
} }
func GetTemporaryTokenKey(userID string, platformID int, token string) string {
return UidPidToken + ":TEMPORARY:" + userID + ":" + constant.PlatformIDToName(platformID) + ":" + token
}
func GetAllPlatformTokenKey(userID string) []string { func GetAllPlatformTokenKey(userID string) []string {
res := make([]string, len(constant.PlatformID2Name)) res := make([]string, len(constant.PlatformID2Name))
for k := range constant.PlatformID2Name { for k := range constant.PlatformID2Name {

166
pkg/common/storage/cache/mcache/token.go vendored Normal file
View File

@ -0,0 +1,166 @@
package mcache
import (
"context"
"fmt"
"strconv"
"strings"
"time"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/database"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
)
func NewTokenCacheModel(cache database.Cache, accessExpire int64) cache.TokenModel {
c := &tokenCache{cache: cache}
c.accessExpire = c.getExpireTime(accessExpire)
return c
}
type tokenCache struct {
cache database.Cache
accessExpire time.Duration
}
func (x *tokenCache) getTokenKey(userID string, platformID int, token string) string {
return cachekey.GetTokenKey(userID, platformID) + ":" + token
}
func (x *tokenCache) SetTokenFlag(ctx context.Context, userID string, platformID int, token string, flag int) error {
return x.cache.Set(ctx, x.getTokenKey(userID, platformID, token), strconv.Itoa(flag), x.accessExpire)
}
// SetTokenFlagEx set token and flag with expire time
func (x *tokenCache) SetTokenFlagEx(ctx context.Context, userID string, platformID int, token string, flag int) error {
return x.SetTokenFlag(ctx, userID, platformID, token, flag)
}
func (x *tokenCache) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) {
prefix := x.getTokenKey(userID, platformID, "")
m, err := x.cache.Prefix(ctx, prefix)
if err != nil {
return nil, errs.Wrap(err)
}
mm := make(map[string]int)
for k, v := range m {
state, err := strconv.Atoi(v)
if err != nil {
log.ZError(ctx, "token value is not int", err, "value", v, "userID", userID, "platformID", platformID)
continue
}
mm[strings.TrimPrefix(k, prefix)] = state
}
return mm, nil
}
func (x *tokenCache) HasTemporaryToken(ctx context.Context, userID string, platformID int, token string) error {
key := cachekey.GetTemporaryTokenKey(userID, platformID, token)
if _, err := x.cache.Get(ctx, []string{key}); err != nil {
return err
}
return nil
}
func (x *tokenCache) GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error) {
prefix := cachekey.UidPidToken + userID + ":"
tokens, err := x.cache.Prefix(ctx, prefix)
if err != nil {
return nil, err
}
res := make(map[int]map[string]int)
for key, flagStr := range tokens {
flag, err := strconv.Atoi(flagStr)
if err != nil {
log.ZError(ctx, "token value is not int", err, "key", key, "value", flagStr, "userID", userID)
continue
}
arr := strings.SplitN(strings.TrimPrefix(key, prefix), ":", 2)
if len(arr) != 2 {
log.ZError(ctx, "token value is not int", err, "key", key, "value", flagStr, "userID", userID)
continue
}
platformID, err := strconv.Atoi(arr[0])
if err != nil {
log.ZError(ctx, "token value is not int", err, "key", key, "value", flagStr, "userID", userID)
continue
}
token := arr[1]
if token == "" {
log.ZError(ctx, "token value is not int", err, "key", key, "value", flagStr, "userID", userID)
continue
}
tk, ok := res[platformID]
if !ok {
tk = make(map[string]int)
res[platformID] = tk
}
tk[token] = flag
}
return res, nil
}
func (x *tokenCache) SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error {
for token, flag := range m {
err := x.SetTokenFlag(ctx, userID, platformID, token, flag)
if err != nil {
return err
}
}
return nil
}
func (x *tokenCache) BatchSetTokenMapByUidPid(ctx context.Context, tokens map[string]map[string]any) error {
for prefix, tokenFlag := range tokens {
for token, flag := range tokenFlag {
flagStr := fmt.Sprintf("%v", flag)
if err := x.cache.Set(ctx, prefix+":"+token, flagStr, x.accessExpire); err != nil {
return err
}
}
}
return nil
}
func (x *tokenCache) DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error {
keys := make([]string, 0, len(fields))
for _, token := range fields {
keys = append(keys, x.getTokenKey(userID, platformID, token))
}
return x.cache.Del(ctx, keys)
}
func (x *tokenCache) getExpireTime(t int64) time.Duration {
return time.Hour * 24 * time.Duration(t)
}
func (x *tokenCache) DeleteTokenByTokenMap(ctx context.Context, userID string, tokens map[int][]string) error {
keys := make([]string, 0, len(tokens))
for platformID, ts := range tokens {
for _, t := range ts {
keys = append(keys, x.getTokenKey(userID, platformID, t))
}
}
return x.cache.Del(ctx, keys)
}
func (x *tokenCache) DeleteAndSetTemporary(ctx context.Context, userID string, platformID int, fields []string) error {
keys := make([]string, 0, len(fields))
for _, f := range fields {
keys = append(keys, x.getTokenKey(userID, platformID, f))
}
if err := x.cache.Del(ctx, keys); err != nil {
return err
}
for _, f := range fields {
k := cachekey.GetTemporaryTokenKey(userID, platformID, f)
if err := x.cache.Set(ctx, k, "", time.Minute*5); err != nil {
return errs.Wrap(err)
}
}
return nil
}

View File

@ -9,6 +9,7 @@ import (
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey"
"github.com/openimsdk/tools/errs" "github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/utils/datautil"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )
@ -55,6 +56,14 @@ func (c *tokenCache) GetTokensWithoutError(ctx context.Context, userID string, p
return mm, nil return mm, nil
} }
func (c *tokenCache) HasTemporaryToken(ctx context.Context, userID string, platformID int, token string) error {
err := c.rdb.Get(ctx, cachekey.GetTemporaryTokenKey(userID, platformID, token)).Err()
if err != nil {
return errs.Wrap(err)
}
return nil
}
func (c *tokenCache) GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error) { func (c *tokenCache) GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error) {
var ( var (
res = make(map[int]map[string]int) res = make(map[int]map[string]int)
@ -101,6 +110,8 @@ func (c *tokenCache) SetTokenMapByUidPid(ctx context.Context, userID string, pla
} }
func (c *tokenCache) BatchSetTokenMapByUidPid(ctx context.Context, tokens map[string]map[string]any) error { func (c *tokenCache) BatchSetTokenMapByUidPid(ctx context.Context, tokens map[string]map[string]any) error {
keys := datautil.Keys(tokens)
if err := ProcessKeysBySlot(ctx, c.rdb, keys, func(ctx context.Context, slot int64, keys []string) error {
pipe := c.rdb.Pipeline() pipe := c.rdb.Pipeline()
for k, v := range tokens { for k, v := range tokens {
pipe.HSet(ctx, k, v) pipe.HSet(ctx, k, v)
@ -110,6 +121,10 @@ func (c *tokenCache) BatchSetTokenMapByUidPid(ctx context.Context, tokens map[st
return errs.Wrap(err) return errs.Wrap(err)
} }
return nil return nil
}); err != nil {
return err
}
return nil
} }
func (c *tokenCache) DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error { func (c *tokenCache) DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error {
@ -119,3 +134,47 @@ func (c *tokenCache) DeleteTokenByUidPid(ctx context.Context, userID string, pla
func (c *tokenCache) getExpireTime(t int64) time.Duration { func (c *tokenCache) getExpireTime(t int64) time.Duration {
return time.Hour * 24 * time.Duration(t) return time.Hour * 24 * time.Duration(t)
} }
// DeleteTokenByTokenMap tokens key is platformID, value is token slice
func (c *tokenCache) DeleteTokenByTokenMap(ctx context.Context, userID string, tokens map[int][]string) error {
var (
keys = make([]string, 0, len(tokens))
keyMap = make(map[string][]string)
)
for k, v := range tokens {
k1 := cachekey.GetTokenKey(userID, k)
keys = append(keys, k1)
keyMap[k1] = v
}
if err := ProcessKeysBySlot(ctx, c.rdb, keys, func(ctx context.Context, slot int64, keys []string) error {
pipe := c.rdb.Pipeline()
for k, v := range tokens {
pipe.HDel(ctx, cachekey.GetTokenKey(userID, k), v...)
}
_, err := pipe.Exec(ctx)
if err != nil {
return errs.Wrap(err)
}
return nil
}); err != nil {
return err
}
return nil
}
func (c *tokenCache) DeleteAndSetTemporary(ctx context.Context, userID string, platformID int, fields []string) error {
key := cachekey.GetTokenKey(userID, platformID)
if err := c.rdb.HDel(ctx, key, fields...).Err(); err != nil {
return errs.Wrap(err)
}
for _, f := range fields {
k := cachekey.GetTemporaryTokenKey(userID, platformID, f)
if err := c.rdb.Set(ctx, k, "", time.Minute*5).Err(); err != nil {
return errs.Wrap(err)
}
}
return nil
}

View File

@ -9,8 +9,11 @@ type TokenModel interface {
// SetTokenFlagEx set token and flag with expire time // SetTokenFlagEx set token and flag with expire time
SetTokenFlagEx(ctx context.Context, userID string, platformID int, token string, flag int) error SetTokenFlagEx(ctx context.Context, userID string, platformID int, token string, flag int) error
GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error)
HasTemporaryToken(ctx context.Context, userID string, platformID int, token string) error
GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error) GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error)
SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error
BatchSetTokenMapByUidPid(ctx context.Context, tokens map[string]map[string]any) error BatchSetTokenMapByUidPid(ctx context.Context, tokens map[string]map[string]any) error
DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error
DeleteTokenByTokenMap(ctx context.Context, userID string, tokens map[int][]string) error
DeleteAndSetTemporary(ctx context.Context, userID string, platformID int, fields []string) error
} }

View File

@ -17,6 +17,8 @@ import (
type AuthDatabase interface { type AuthDatabase interface {
// If the result is empty, no error is returned. // If the result is empty, no error is returned.
GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error)
GetTemporaryTokensWithoutError(ctx context.Context, userID string, platformID int, token string) error
// Create token // Create token
CreateToken(ctx context.Context, userID string, platformID int) (string, error) CreateToken(ctx context.Context, userID string, platformID int) (string, error)
@ -51,6 +53,10 @@ func (a *authDatabase) GetTokensWithoutError(ctx context.Context, userID string,
return a.cache.GetTokensWithoutError(ctx, userID, platformID) return a.cache.GetTokensWithoutError(ctx, userID, platformID)
} }
func (a *authDatabase) GetTemporaryTokensWithoutError(ctx context.Context, userID string, platformID int, token string) error {
return a.cache.HasTemporaryToken(ctx, userID, platformID, token)
}
func (a *authDatabase) SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error { func (a *authDatabase) SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error {
return a.cache.SetTokenMapByUidPid(ctx, userID, platformID, m) return a.cache.SetTokenMapByUidPid(ctx, userID, platformID, m)
} }
@ -86,19 +92,20 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI
return "", err return "", err
} }
deleteTokenKey, kickedTokenKey, err := a.checkToken(ctx, tokens, platformID) deleteTokenKey, kickedTokenKey, adminTokens, err := a.checkToken(ctx, tokens, platformID)
if err != nil { if err != nil {
return "", err return "", err
} }
if len(deleteTokenKey) != 0 { if len(deleteTokenKey) != 0 {
err = a.cache.DeleteTokenByUidPid(ctx, userID, platformID, deleteTokenKey) err = a.cache.DeleteTokenByTokenMap(ctx, userID, deleteTokenKey)
if err != nil { if err != nil {
return "", err return "", err
} }
} }
if len(kickedTokenKey) != 0 { if len(kickedTokenKey) != 0 {
for _, k := range kickedTokenKey { for plt, ks := range kickedTokenKey {
err := a.cache.SetTokenFlagEx(ctx, userID, platformID, k, constant.KickedToken) for _, k := range ks {
err := a.cache.SetTokenFlagEx(ctx, userID, plt, k, constant.KickedToken)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -106,6 +113,11 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI
} }
} }
} }
if len(adminTokens) != 0 {
if err = a.cache.DeleteAndSetTemporary(ctx, userID, constant.AdminPlatformID, adminTokens); err != nil {
return "", err
}
}
claims := tokenverify.BuildClaims(userID, platformID, a.accessExpire) claims := tokenverify.BuildClaims(userID, platformID, a.accessExpire)
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
@ -123,12 +135,13 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI
return tokenString, nil return tokenString, nil
} }
func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string]int, platformID int) ([]string, []string, error) { // checkToken will check token by tokenPolicy and return deleteToken,kickToken,deleteAdminToken
// todo: Move the logic for handling old data to another location. func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string]int, platformID int) (map[int][]string, map[int][]string, []string, error) {
// todo: Asynchronous deletion of old data.
var ( var (
loginTokenMap = make(map[int][]string) // The length of the value of the map must be greater than 0 loginTokenMap = make(map[int][]string) // The length of the value of the map must be greater than 0
deleteToken = make([]string, 0) deleteToken = make(map[int][]string)
kickToken = make([]string, 0) kickToken = make(map[int][]string)
adminToken = make([]string, 0) adminToken = make([]string, 0)
unkickTerminal = "" unkickTerminal = ""
) )
@ -137,7 +150,7 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string
for k, v := range tks { for k, v := range tks {
_, err := tokenverify.GetClaimFromToken(k, authverify.Secret(a.accessSecret)) _, err := tokenverify.GetClaimFromToken(k, authverify.Secret(a.accessSecret))
if err != nil || v != constant.NormalToken { if err != nil || v != constant.NormalToken {
deleteToken = append(deleteToken, k) deleteToken[plfID] = append(deleteToken[plfID], k)
} else { } else {
if plfID != constant.AdminPlatformID { if plfID != constant.AdminPlatformID {
loginTokenMap[plfID] = append(loginTokenMap[plfID], k) loginTokenMap[plfID] = append(loginTokenMap[plfID], k)
@ -157,14 +170,15 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string
} }
limit := a.multiLogin.MaxNumOneEnd limit := a.multiLogin.MaxNumOneEnd
if l > limit { if l > limit {
kickToken = append(kickToken, ts[:l-limit]...) kickToken[plt] = ts[:l-limit]
} }
} }
case constant.AllLoginButSameTermKick: case constant.AllLoginButSameTermKick:
for plt, ts := range loginTokenMap { for plt, ts := range loginTokenMap {
kickToken = append(kickToken, ts[:len(ts)-1]...) kickToken[plt] = ts[:len(ts)-1]
if plt == platformID { if plt == platformID {
kickToken = append(kickToken, ts[len(ts)-1]) kickToken[plt] = append(kickToken[plt], ts[len(ts)-1])
} }
} }
case constant.PCAndOther: case constant.PCAndOther:
@ -172,29 +186,33 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string
if constant.PlatformIDToClass(platformID) != unkickTerminal { if constant.PlatformIDToClass(platformID) != unkickTerminal {
for plt, ts := range loginTokenMap { for plt, ts := range loginTokenMap {
if constant.PlatformIDToClass(plt) != unkickTerminal { if constant.PlatformIDToClass(plt) != unkickTerminal {
kickToken = append(kickToken, ts...) kickToken[plt] = ts
} }
} }
} else { } else {
var ( var (
preKick []string preKickToken string
isReserve = true preKickPlt int
reserveToken = false
) )
for plt, ts := range loginTokenMap { for plt, ts := range loginTokenMap {
if constant.PlatformIDToClass(plt) != unkickTerminal { if constant.PlatformIDToClass(plt) != unkickTerminal {
// Keep a token from another end // Keep a token from another end
if isReserve { if !reserveToken {
isReserve = false reserveToken = true
kickToken = append(kickToken, ts[:len(ts)-1]...) kickToken[plt] = ts[:len(ts)-1]
preKick = append(preKick, ts[len(ts)-1]) preKickToken = ts[len(ts)-1]
preKickPlt = plt
continue continue
} else { } else {
// Prioritize keeping Android // Prioritize keeping Android
if plt == constant.AndroidPlatformID { if plt == constant.AndroidPlatformID {
kickToken = append(kickToken, preKick...) if preKickToken != "" {
kickToken = append(kickToken, ts[:len(ts)-1]...) kickToken[preKickPlt] = append(kickToken[preKickPlt], preKickToken)
}
kickToken[plt] = ts[:len(ts)-1]
} else { } else {
kickToken = append(kickToken, ts...) kickToken[plt] = ts
} }
} }
} }
@ -207,19 +225,19 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string
for plt, ts := range loginTokenMap { for plt, ts := range loginTokenMap {
if constant.PlatformIDToClass(plt) == constant.PlatformIDToClass(platformID) { if constant.PlatformIDToClass(plt) == constant.PlatformIDToClass(platformID) {
kickToken = append(kickToken, ts...) kickToken[plt] = ts
} else { } else {
if _, ok := reserved[constant.PlatformIDToClass(plt)]; !ok { if _, ok := reserved[constant.PlatformIDToClass(plt)]; !ok {
reserved[constant.PlatformIDToClass(plt)] = struct{}{} reserved[constant.PlatformIDToClass(plt)] = struct{}{}
kickToken = append(kickToken, ts[:len(ts)-1]...) kickToken[plt] = ts[:len(ts)-1]
continue continue
} else { } else {
kickToken = append(kickToken, ts...) kickToken[plt] = ts
} }
} }
} }
default: default:
return nil, nil, errs.New("unknown multiLogin policy").Wrap() return nil, nil, nil, errs.New("unknown multiLogin policy").Wrap()
} }
//var adminTokenMaxNum = a.multiLogin.MaxNumOneEnd //var adminTokenMaxNum = a.multiLogin.MaxNumOneEnd
@ -233,5 +251,9 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string
//if l > adminTokenMaxNum { //if l > adminTokenMaxNum {
// kickToken = append(kickToken, adminToken[:l-adminTokenMaxNum]...) // kickToken = append(kickToken, adminToken[:l-adminTokenMaxNum]...)
//} //}
return deleteToken, kickToken, nil var deleteAdminToken []string
if platformID == constant.AdminPlatformID {
deleteAdminToken = adminToken
}
return deleteToken, kickToken, deleteAdminToken, nil
} }

View File

@ -0,0 +1,183 @@
package mgo
import (
"context"
"strconv"
"time"
"github.com/google/uuid"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/database"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/model"
"github.com/openimsdk/tools/db/mongoutil"
"github.com/openimsdk/tools/errs"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
func NewCacheMgo(db *mongo.Database) (*CacheMgo, error) {
coll := db.Collection(database.CacheName)
_, err := coll.Indexes().CreateMany(context.Background(), []mongo.IndexModel{
{
Keys: bson.D{
{Key: "key", Value: 1},
},
Options: options.Index().SetUnique(true),
},
{
Keys: bson.D{
{Key: "expire_at", Value: 1},
},
Options: options.Index().SetExpireAfterSeconds(0),
},
})
if err != nil {
return nil, errs.Wrap(err)
}
return &CacheMgo{coll: coll}, nil
}
type CacheMgo struct {
coll *mongo.Collection
}
func (x *CacheMgo) findToMap(res []model.Cache, now time.Time) map[string]string {
kv := make(map[string]string)
for _, re := range res {
if re.ExpireAt != nil && re.ExpireAt.Before(now) {
continue
}
kv[re.Key] = re.Value
}
return kv
}
func (x *CacheMgo) Get(ctx context.Context, key []string) (map[string]string, error) {
if len(key) == 0 {
return nil, nil
}
now := time.Now()
res, err := mongoutil.Find[model.Cache](ctx, x.coll, bson.M{
"key": bson.M{"$in": key},
"$or": []bson.M{
{"expire_at": bson.M{"$gt": now}},
{"expire_at": nil},
},
})
if err != nil {
return nil, err
}
return x.findToMap(res, now), nil
}
func (x *CacheMgo) Prefix(ctx context.Context, prefix string) (map[string]string, error) {
now := time.Now()
res, err := mongoutil.Find[model.Cache](ctx, x.coll, bson.M{
"key": bson.M{"$regex": "^" + prefix},
"$or": []bson.M{
{"expire_at": bson.M{"$gt": now}},
{"expire_at": nil},
},
})
if err != nil {
return nil, err
}
return x.findToMap(res, now), nil
}
func (x *CacheMgo) Set(ctx context.Context, key string, value string, expireAt time.Duration) error {
cv := &model.Cache{
Key: key,
Value: value,
}
if expireAt > 0 {
now := time.Now().Add(expireAt)
cv.ExpireAt = &now
}
opt := options.Update().SetUpsert(true)
return mongoutil.UpdateOne(ctx, x.coll, bson.M{"key": key}, bson.M{"$set": cv}, false, opt)
}
func (x *CacheMgo) Incr(ctx context.Context, key string, value int) (int, error) {
pipeline := mongo.Pipeline{
{
{"$set", bson.M{
"value": bson.M{
"$toString": bson.M{
"$add": bson.A{
bson.M{"$toInt": "$value"},
value,
},
},
},
}},
},
}
opt := options.FindOneAndUpdate().SetReturnDocument(options.After)
res, err := mongoutil.FindOneAndUpdate[model.Cache](ctx, x.coll, bson.M{"key": key}, pipeline, opt)
if err != nil {
return 0, err
}
return strconv.Atoi(res.Value)
}
func (x *CacheMgo) Del(ctx context.Context, key []string) error {
if len(key) == 0 {
return nil
}
_, err := x.coll.DeleteMany(ctx, bson.M{"key": bson.M{"$in": key}})
return errs.Wrap(err)
}
func (x *CacheMgo) lockKey(key string) string {
return "LOCK_" + key
}
func (x *CacheMgo) Lock(ctx context.Context, key string, duration time.Duration) (string, error) {
tmp, err := uuid.NewUUID()
if err != nil {
return "", err
}
if duration <= 0 || duration > time.Minute*10 {
duration = time.Minute * 10
}
cv := &model.Cache{
Key: x.lockKey(key),
Value: tmp.String(),
ExpireAt: nil,
}
ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel()
wait := func() error {
timeout := time.NewTimer(time.Millisecond * 100)
defer timeout.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-timeout.C:
return nil
}
}
for {
if err := mongoutil.DeleteOne(ctx, x.coll, bson.M{"key": key, "expire_at": bson.M{"$lt": time.Now()}}); err != nil {
return "", err
}
expireAt := time.Now().Add(duration)
cv.ExpireAt = &expireAt
if err := mongoutil.InsertMany[*model.Cache](ctx, x.coll, []*model.Cache{cv}); err != nil {
if mongo.IsDuplicateKeyError(err) {
if err := wait(); err != nil {
return "", err
}
continue
}
return "", err
}
return cv.Value, nil
}
}
func (x *CacheMgo) Unlock(ctx context.Context, key string, value string) error {
return mongoutil.DeleteOne(ctx, x.coll, bson.M{"key": x.lockKey(key), "value": value})
}