From 78aaf6abce679ab1091f48eb7e2590748fafd4ad Mon Sep 17 00:00:00 2001 From: icey-yu <119291641+icey-yu@users.noreply.github.com> Date: Tue, 6 May 2025 15:10:10 +0800 Subject: [PATCH] fix: delete token by correct platformID && feat: adminToken can be retained for five minutes after deleting (#3313) --- internal/rpc/auth/auth.go | 17 +- pkg/common/storage/cache/cachekey/token.go | 7 +- pkg/common/storage/cache/mcache/token.go | 166 +++++++++++++++++++ pkg/common/storage/cache/redis/token.go | 73 +++++++- pkg/common/storage/cache/token.go | 3 + pkg/common/storage/controller/auth.go | 88 ++++++---- pkg/common/storage/database/mgo/cache.go | 183 +++++++++++++++++++++ 7 files changed, 492 insertions(+), 45 deletions(-) create mode 100644 pkg/common/storage/cache/mcache/token.go create mode 100644 pkg/common/storage/database/mgo/cache.go diff --git a/internal/rpc/auth/auth.go b/internal/rpc/auth/auth.go index 2e64c365c..d34630b2f 100644 --- a/internal/rpc/auth/auth.go +++ b/internal/rpc/auth/auth.go @@ -140,15 +140,17 @@ func (s *authServer) parseToken(ctx context.Context, tokensString string) (claim if err != nil { return nil, err } - isAdmin := authverify.IsManagerUserID(claims.UserID, s.config.Share.IMAdminUserID) - if isAdmin { - return claims, nil - } m, err := s.authDatabase.GetTokensWithoutError(ctx, claims.UserID, claims.PlatformID) if err != nil { return nil, err } if len(m) == 0 { + isAdmin := authverify.IsManagerUserID(claims.UserID, s.config.Share.IMAdminUserID) + if isAdmin { + if err = s.authDatabase.GetTemporaryTokensWithoutError(ctx, claims.UserID, claims.PlatformID, tokensString); err == nil { + return claims, nil + } + } return nil, servererrs.ErrTokenNotExist.Wrap() } if v, ok := m[tokensString]; ok { @@ -160,6 +162,13 @@ func (s *authServer) parseToken(ctx context.Context, tokensString string) (claim default: return nil, errs.Wrap(errs.ErrTokenUnknown) } + } else { + isAdmin := authverify.IsManagerUserID(claims.UserID, s.config.Share.IMAdminUserID) + if isAdmin { + if err = s.authDatabase.GetTemporaryTokensWithoutError(ctx, claims.UserID, claims.PlatformID, tokensString); err == nil { + return claims, nil + } + } } return nil, servererrs.ErrTokenNotExist.Wrap() } diff --git a/pkg/common/storage/cache/cachekey/token.go b/pkg/common/storage/cache/cachekey/token.go index 83ba2f211..6fe1bdfef 100644 --- a/pkg/common/storage/cache/cachekey/token.go +++ b/pkg/common/storage/cache/cachekey/token.go @@ -1,8 +1,9 @@ package cachekey import ( - "github.com/openimsdk/protocol/constant" "strings" + + "github.com/openimsdk/protocol/constant" ) const ( @@ -13,6 +14,10 @@ func GetTokenKey(userID string, platformID int) string { return UidPidToken + userID + ":" + constant.PlatformIDToName(platformID) } +func GetTemporaryTokenKey(userID string, platformID int, token string) string { + return UidPidToken + ":TEMPORARY:" + userID + ":" + constant.PlatformIDToName(platformID) + ":" + token +} + func GetAllPlatformTokenKey(userID string) []string { res := make([]string, len(constant.PlatformID2Name)) for k := range constant.PlatformID2Name { diff --git a/pkg/common/storage/cache/mcache/token.go b/pkg/common/storage/cache/mcache/token.go new file mode 100644 index 000000000..98b9cc066 --- /dev/null +++ b/pkg/common/storage/cache/mcache/token.go @@ -0,0 +1,166 @@ +package mcache + +import ( + "context" + "fmt" + "strconv" + "strings" + "time" + + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/database" + "github.com/openimsdk/tools/errs" + "github.com/openimsdk/tools/log" +) + +func NewTokenCacheModel(cache database.Cache, accessExpire int64) cache.TokenModel { + c := &tokenCache{cache: cache} + c.accessExpire = c.getExpireTime(accessExpire) + return c +} + +type tokenCache struct { + cache database.Cache + accessExpire time.Duration +} + +func (x *tokenCache) getTokenKey(userID string, platformID int, token string) string { + return cachekey.GetTokenKey(userID, platformID) + ":" + token +} + +func (x *tokenCache) SetTokenFlag(ctx context.Context, userID string, platformID int, token string, flag int) error { + return x.cache.Set(ctx, x.getTokenKey(userID, platformID, token), strconv.Itoa(flag), x.accessExpire) +} + +// SetTokenFlagEx set token and flag with expire time +func (x *tokenCache) SetTokenFlagEx(ctx context.Context, userID string, platformID int, token string, flag int) error { + return x.SetTokenFlag(ctx, userID, platformID, token, flag) +} + +func (x *tokenCache) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) { + prefix := x.getTokenKey(userID, platformID, "") + m, err := x.cache.Prefix(ctx, prefix) + if err != nil { + return nil, errs.Wrap(err) + } + mm := make(map[string]int) + for k, v := range m { + state, err := strconv.Atoi(v) + if err != nil { + log.ZError(ctx, "token value is not int", err, "value", v, "userID", userID, "platformID", platformID) + continue + } + mm[strings.TrimPrefix(k, prefix)] = state + } + return mm, nil +} + +func (x *tokenCache) HasTemporaryToken(ctx context.Context, userID string, platformID int, token string) error { + key := cachekey.GetTemporaryTokenKey(userID, platformID, token) + if _, err := x.cache.Get(ctx, []string{key}); err != nil { + return err + } + return nil +} + +func (x *tokenCache) GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error) { + prefix := cachekey.UidPidToken + userID + ":" + tokens, err := x.cache.Prefix(ctx, prefix) + if err != nil { + return nil, err + } + res := make(map[int]map[string]int) + for key, flagStr := range tokens { + flag, err := strconv.Atoi(flagStr) + if err != nil { + log.ZError(ctx, "token value is not int", err, "key", key, "value", flagStr, "userID", userID) + continue + } + arr := strings.SplitN(strings.TrimPrefix(key, prefix), ":", 2) + if len(arr) != 2 { + log.ZError(ctx, "token value is not int", err, "key", key, "value", flagStr, "userID", userID) + continue + } + platformID, err := strconv.Atoi(arr[0]) + if err != nil { + log.ZError(ctx, "token value is not int", err, "key", key, "value", flagStr, "userID", userID) + continue + } + token := arr[1] + if token == "" { + log.ZError(ctx, "token value is not int", err, "key", key, "value", flagStr, "userID", userID) + continue + } + tk, ok := res[platformID] + if !ok { + tk = make(map[string]int) + res[platformID] = tk + } + tk[token] = flag + } + return res, nil +} + +func (x *tokenCache) SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error { + for token, flag := range m { + err := x.SetTokenFlag(ctx, userID, platformID, token, flag) + if err != nil { + return err + } + } + return nil +} + +func (x *tokenCache) BatchSetTokenMapByUidPid(ctx context.Context, tokens map[string]map[string]any) error { + for prefix, tokenFlag := range tokens { + for token, flag := range tokenFlag { + flagStr := fmt.Sprintf("%v", flag) + if err := x.cache.Set(ctx, prefix+":"+token, flagStr, x.accessExpire); err != nil { + return err + } + } + } + return nil +} + +func (x *tokenCache) DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error { + keys := make([]string, 0, len(fields)) + for _, token := range fields { + keys = append(keys, x.getTokenKey(userID, platformID, token)) + } + return x.cache.Del(ctx, keys) +} + +func (x *tokenCache) getExpireTime(t int64) time.Duration { + return time.Hour * 24 * time.Duration(t) +} + +func (x *tokenCache) DeleteTokenByTokenMap(ctx context.Context, userID string, tokens map[int][]string) error { + keys := make([]string, 0, len(tokens)) + for platformID, ts := range tokens { + for _, t := range ts { + keys = append(keys, x.getTokenKey(userID, platformID, t)) + } + } + return x.cache.Del(ctx, keys) +} + +func (x *tokenCache) DeleteAndSetTemporary(ctx context.Context, userID string, platformID int, fields []string) error { + keys := make([]string, 0, len(fields)) + for _, f := range fields { + keys = append(keys, x.getTokenKey(userID, platformID, f)) + } + if err := x.cache.Del(ctx, keys); err != nil { + return err + } + + for _, f := range fields { + k := cachekey.GetTemporaryTokenKey(userID, platformID, f) + if err := x.cache.Set(ctx, k, "", time.Minute*5); err != nil { + return errs.Wrap(err) + } + } + + return nil +} diff --git a/pkg/common/storage/cache/redis/token.go b/pkg/common/storage/cache/redis/token.go index 510da43e3..b3870daee 100644 --- a/pkg/common/storage/cache/redis/token.go +++ b/pkg/common/storage/cache/redis/token.go @@ -9,6 +9,7 @@ import ( "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey" "github.com/openimsdk/tools/errs" + "github.com/openimsdk/tools/utils/datautil" "github.com/redis/go-redis/v9" ) @@ -55,6 +56,14 @@ func (c *tokenCache) GetTokensWithoutError(ctx context.Context, userID string, p return mm, nil } +func (c *tokenCache) HasTemporaryToken(ctx context.Context, userID string, platformID int, token string) error { + err := c.rdb.Get(ctx, cachekey.GetTemporaryTokenKey(userID, platformID, token)).Err() + if err != nil { + return errs.Wrap(err) + } + return nil +} + func (c *tokenCache) GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error) { var ( res = make(map[int]map[string]int) @@ -101,13 +110,19 @@ func (c *tokenCache) SetTokenMapByUidPid(ctx context.Context, userID string, pla } func (c *tokenCache) BatchSetTokenMapByUidPid(ctx context.Context, tokens map[string]map[string]any) error { - pipe := c.rdb.Pipeline() - for k, v := range tokens { - pipe.HSet(ctx, k, v) - } - _, err := pipe.Exec(ctx) - if err != nil { - return errs.Wrap(err) + keys := datautil.Keys(tokens) + if err := ProcessKeysBySlot(ctx, c.rdb, keys, func(ctx context.Context, slot int64, keys []string) error { + pipe := c.rdb.Pipeline() + for k, v := range tokens { + pipe.HSet(ctx, k, v) + } + _, err := pipe.Exec(ctx) + if err != nil { + return errs.Wrap(err) + } + return nil + }); err != nil { + return err } return nil } @@ -119,3 +134,47 @@ func (c *tokenCache) DeleteTokenByUidPid(ctx context.Context, userID string, pla func (c *tokenCache) getExpireTime(t int64) time.Duration { return time.Hour * 24 * time.Duration(t) } + +// DeleteTokenByTokenMap tokens key is platformID, value is token slice +func (c *tokenCache) DeleteTokenByTokenMap(ctx context.Context, userID string, tokens map[int][]string) error { + var ( + keys = make([]string, 0, len(tokens)) + keyMap = make(map[string][]string) + ) + for k, v := range tokens { + k1 := cachekey.GetTokenKey(userID, k) + keys = append(keys, k1) + keyMap[k1] = v + } + + if err := ProcessKeysBySlot(ctx, c.rdb, keys, func(ctx context.Context, slot int64, keys []string) error { + pipe := c.rdb.Pipeline() + for k, v := range tokens { + pipe.HDel(ctx, cachekey.GetTokenKey(userID, k), v...) + } + _, err := pipe.Exec(ctx) + if err != nil { + return errs.Wrap(err) + } + return nil + }); err != nil { + return err + } + + return nil +} + +func (c *tokenCache) DeleteAndSetTemporary(ctx context.Context, userID string, platformID int, fields []string) error { + key := cachekey.GetTokenKey(userID, platformID) + if err := c.rdb.HDel(ctx, key, fields...).Err(); err != nil { + return errs.Wrap(err) + } + for _, f := range fields { + k := cachekey.GetTemporaryTokenKey(userID, platformID, f) + if err := c.rdb.Set(ctx, k, "", time.Minute*5).Err(); err != nil { + return errs.Wrap(err) + } + } + + return nil +} diff --git a/pkg/common/storage/cache/token.go b/pkg/common/storage/cache/token.go index e5e0a9383..441c08939 100644 --- a/pkg/common/storage/cache/token.go +++ b/pkg/common/storage/cache/token.go @@ -9,8 +9,11 @@ type TokenModel interface { // SetTokenFlagEx set token and flag with expire time SetTokenFlagEx(ctx context.Context, userID string, platformID int, token string, flag int) error GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) + HasTemporaryToken(ctx context.Context, userID string, platformID int, token string) error GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error) SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error BatchSetTokenMapByUidPid(ctx context.Context, tokens map[string]map[string]any) error DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error + DeleteTokenByTokenMap(ctx context.Context, userID string, tokens map[int][]string) error + DeleteAndSetTemporary(ctx context.Context, userID string, platformID int, fields []string) error } diff --git a/pkg/common/storage/controller/auth.go b/pkg/common/storage/controller/auth.go index f9061a73b..496a434bf 100644 --- a/pkg/common/storage/controller/auth.go +++ b/pkg/common/storage/controller/auth.go @@ -17,6 +17,8 @@ import ( type AuthDatabase interface { // If the result is empty, no error is returned. GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) + + GetTemporaryTokensWithoutError(ctx context.Context, userID string, platformID int, token string) error // Create token CreateToken(ctx context.Context, userID string, platformID int) (string, error) @@ -51,6 +53,10 @@ func (a *authDatabase) GetTokensWithoutError(ctx context.Context, userID string, return a.cache.GetTokensWithoutError(ctx, userID, platformID) } +func (a *authDatabase) GetTemporaryTokensWithoutError(ctx context.Context, userID string, platformID int, token string) error { + return a.cache.HasTemporaryToken(ctx, userID, platformID, token) +} + func (a *authDatabase) SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error { return a.cache.SetTokenMapByUidPid(ctx, userID, platformID, m) } @@ -86,19 +92,20 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI return "", err } - deleteTokenKey, kickedTokenKey, err := a.checkToken(ctx, tokens, platformID) + deleteTokenKey, kickedTokenKey, adminTokens, err := a.checkToken(ctx, tokens, platformID) + if err != nil { + return "", err + } + if len(deleteTokenKey) != 0 { + err = a.cache.DeleteTokenByTokenMap(ctx, userID, deleteTokenKey) if err != nil { return "", err } - if len(deleteTokenKey) != 0 { - err = a.cache.DeleteTokenByUidPid(ctx, userID, platformID, deleteTokenKey) - if err != nil { - return "", err - } - } - if len(kickedTokenKey) != 0 { - for _, k := range kickedTokenKey { - err := a.cache.SetTokenFlagEx(ctx, userID, platformID, k, constant.KickedToken) + } + if len(kickedTokenKey) != 0 { + for plt, ks := range kickedTokenKey { + for _, k := range ks { + err := a.cache.SetTokenFlagEx(ctx, userID, plt, k, constant.KickedToken) if err != nil { return "", err } @@ -106,6 +113,11 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI } } } + if len(adminTokens) != 0 { + if err = a.cache.DeleteAndSetTemporary(ctx, userID, constant.AdminPlatformID, adminTokens); err != nil { + return "", err + } + } claims := tokenverify.BuildClaims(userID, platformID, a.accessExpire) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) @@ -123,12 +135,13 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI return tokenString, nil } -func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string]int, platformID int) ([]string, []string, error) { - // todo: Move the logic for handling old data to another location. +// checkToken will check token by tokenPolicy and return deleteToken,kickToken,deleteAdminToken +func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string]int, platformID int) (map[int][]string, map[int][]string, []string, error) { + // todo: Asynchronous deletion of old data. var ( loginTokenMap = make(map[int][]string) // The length of the value of the map must be greater than 0 - deleteToken = make([]string, 0) - kickToken = make([]string, 0) + deleteToken = make(map[int][]string) + kickToken = make(map[int][]string) adminToken = make([]string, 0) unkickTerminal = "" ) @@ -137,7 +150,7 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string for k, v := range tks { _, err := tokenverify.GetClaimFromToken(k, authverify.Secret(a.accessSecret)) if err != nil || v != constant.NormalToken { - deleteToken = append(deleteToken, k) + deleteToken[plfID] = append(deleteToken[plfID], k) } else { if plfID != constant.AdminPlatformID { loginTokenMap[plfID] = append(loginTokenMap[plfID], k) @@ -157,14 +170,15 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string } limit := a.multiLogin.MaxNumOneEnd if l > limit { - kickToken = append(kickToken, ts[:l-limit]...) + kickToken[plt] = ts[:l-limit] } } case constant.AllLoginButSameTermKick: for plt, ts := range loginTokenMap { - kickToken = append(kickToken, ts[:len(ts)-1]...) + kickToken[plt] = ts[:len(ts)-1] + if plt == platformID { - kickToken = append(kickToken, ts[len(ts)-1]) + kickToken[plt] = append(kickToken[plt], ts[len(ts)-1]) } } case constant.PCAndOther: @@ -172,29 +186,33 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string if constant.PlatformIDToClass(platformID) != unkickTerminal { for plt, ts := range loginTokenMap { if constant.PlatformIDToClass(plt) != unkickTerminal { - kickToken = append(kickToken, ts...) + kickToken[plt] = ts } } } else { var ( - preKick []string - isReserve = true + preKickToken string + preKickPlt int + reserveToken = false ) for plt, ts := range loginTokenMap { if constant.PlatformIDToClass(plt) != unkickTerminal { // Keep a token from another end - if isReserve { - isReserve = false - kickToken = append(kickToken, ts[:len(ts)-1]...) - preKick = append(preKick, ts[len(ts)-1]) + if !reserveToken { + reserveToken = true + kickToken[plt] = ts[:len(ts)-1] + preKickToken = ts[len(ts)-1] + preKickPlt = plt continue } else { // Prioritize keeping Android if plt == constant.AndroidPlatformID { - kickToken = append(kickToken, preKick...) - kickToken = append(kickToken, ts[:len(ts)-1]...) + if preKickToken != "" { + kickToken[preKickPlt] = append(kickToken[preKickPlt], preKickToken) + } + kickToken[plt] = ts[:len(ts)-1] } else { - kickToken = append(kickToken, ts...) + kickToken[plt] = ts } } } @@ -207,19 +225,19 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string for plt, ts := range loginTokenMap { if constant.PlatformIDToClass(plt) == constant.PlatformIDToClass(platformID) { - kickToken = append(kickToken, ts...) + kickToken[plt] = ts } else { if _, ok := reserved[constant.PlatformIDToClass(plt)]; !ok { reserved[constant.PlatformIDToClass(plt)] = struct{}{} - kickToken = append(kickToken, ts[:len(ts)-1]...) + kickToken[plt] = ts[:len(ts)-1] continue } else { - kickToken = append(kickToken, ts...) + kickToken[plt] = ts } } } default: - return nil, nil, errs.New("unknown multiLogin policy").Wrap() + return nil, nil, nil, errs.New("unknown multiLogin policy").Wrap() } //var adminTokenMaxNum = a.multiLogin.MaxNumOneEnd @@ -233,5 +251,9 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string //if l > adminTokenMaxNum { // kickToken = append(kickToken, adminToken[:l-adminTokenMaxNum]...) //} - return deleteToken, kickToken, nil + var deleteAdminToken []string + if platformID == constant.AdminPlatformID { + deleteAdminToken = adminToken + } + return deleteToken, kickToken, deleteAdminToken, nil } diff --git a/pkg/common/storage/database/mgo/cache.go b/pkg/common/storage/database/mgo/cache.go new file mode 100644 index 000000000..991dfa874 --- /dev/null +++ b/pkg/common/storage/database/mgo/cache.go @@ -0,0 +1,183 @@ +package mgo + +import ( + "context" + "strconv" + "time" + + "github.com/google/uuid" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/database" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" + "github.com/openimsdk/tools/db/mongoutil" + "github.com/openimsdk/tools/errs" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" +) + +func NewCacheMgo(db *mongo.Database) (*CacheMgo, error) { + coll := db.Collection(database.CacheName) + _, err := coll.Indexes().CreateMany(context.Background(), []mongo.IndexModel{ + { + Keys: bson.D{ + {Key: "key", Value: 1}, + }, + Options: options.Index().SetUnique(true), + }, + { + Keys: bson.D{ + {Key: "expire_at", Value: 1}, + }, + Options: options.Index().SetExpireAfterSeconds(0), + }, + }) + if err != nil { + return nil, errs.Wrap(err) + } + return &CacheMgo{coll: coll}, nil +} + +type CacheMgo struct { + coll *mongo.Collection +} + +func (x *CacheMgo) findToMap(res []model.Cache, now time.Time) map[string]string { + kv := make(map[string]string) + for _, re := range res { + if re.ExpireAt != nil && re.ExpireAt.Before(now) { + continue + } + kv[re.Key] = re.Value + } + return kv + +} + +func (x *CacheMgo) Get(ctx context.Context, key []string) (map[string]string, error) { + if len(key) == 0 { + return nil, nil + } + now := time.Now() + res, err := mongoutil.Find[model.Cache](ctx, x.coll, bson.M{ + "key": bson.M{"$in": key}, + "$or": []bson.M{ + {"expire_at": bson.M{"$gt": now}}, + {"expire_at": nil}, + }, + }) + if err != nil { + return nil, err + } + return x.findToMap(res, now), nil +} + +func (x *CacheMgo) Prefix(ctx context.Context, prefix string) (map[string]string, error) { + now := time.Now() + res, err := mongoutil.Find[model.Cache](ctx, x.coll, bson.M{ + "key": bson.M{"$regex": "^" + prefix}, + "$or": []bson.M{ + {"expire_at": bson.M{"$gt": now}}, + {"expire_at": nil}, + }, + }) + if err != nil { + return nil, err + } + return x.findToMap(res, now), nil +} + +func (x *CacheMgo) Set(ctx context.Context, key string, value string, expireAt time.Duration) error { + cv := &model.Cache{ + Key: key, + Value: value, + } + if expireAt > 0 { + now := time.Now().Add(expireAt) + cv.ExpireAt = &now + } + opt := options.Update().SetUpsert(true) + return mongoutil.UpdateOne(ctx, x.coll, bson.M{"key": key}, bson.M{"$set": cv}, false, opt) +} + +func (x *CacheMgo) Incr(ctx context.Context, key string, value int) (int, error) { + pipeline := mongo.Pipeline{ + { + {"$set", bson.M{ + "value": bson.M{ + "$toString": bson.M{ + "$add": bson.A{ + bson.M{"$toInt": "$value"}, + value, + }, + }, + }, + }}, + }, + } + opt := options.FindOneAndUpdate().SetReturnDocument(options.After) + res, err := mongoutil.FindOneAndUpdate[model.Cache](ctx, x.coll, bson.M{"key": key}, pipeline, opt) + if err != nil { + return 0, err + } + return strconv.Atoi(res.Value) +} + +func (x *CacheMgo) Del(ctx context.Context, key []string) error { + if len(key) == 0 { + return nil + } + _, err := x.coll.DeleteMany(ctx, bson.M{"key": bson.M{"$in": key}}) + return errs.Wrap(err) +} + +func (x *CacheMgo) lockKey(key string) string { + return "LOCK_" + key +} + +func (x *CacheMgo) Lock(ctx context.Context, key string, duration time.Duration) (string, error) { + tmp, err := uuid.NewUUID() + if err != nil { + return "", err + } + if duration <= 0 || duration > time.Minute*10 { + duration = time.Minute * 10 + } + cv := &model.Cache{ + Key: x.lockKey(key), + Value: tmp.String(), + ExpireAt: nil, + } + ctx, cancel := context.WithTimeout(ctx, time.Second*30) + defer cancel() + wait := func() error { + timeout := time.NewTimer(time.Millisecond * 100) + defer timeout.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timeout.C: + return nil + } + } + for { + if err := mongoutil.DeleteOne(ctx, x.coll, bson.M{"key": key, "expire_at": bson.M{"$lt": time.Now()}}); err != nil { + return "", err + } + expireAt := time.Now().Add(duration) + cv.ExpireAt = &expireAt + if err := mongoutil.InsertMany[*model.Cache](ctx, x.coll, []*model.Cache{cv}); err != nil { + if mongo.IsDuplicateKeyError(err) { + if err := wait(); err != nil { + return "", err + } + continue + } + return "", err + } + return cv.Value, nil + } +} + +func (x *CacheMgo) Unlock(ctx context.Context, key string, value string) error { + return mongoutil.DeleteOne(ctx, x.coll, bson.M{"key": x.lockKey(key), "value": value}) +}