diff --git a/internal/common/rpc_server/a.go b/internal/common/rpc_server/a.go index 75223eaf6..107c0cbf8 100644 --- a/internal/common/rpc_server/a.go +++ b/internal/common/rpc_server/a.go @@ -36,6 +36,7 @@ func NewRpcServer(registerIPInConfig string, port int, registerName string, zkSe return nil, err } s.RegisterCenter = zkClient + return s, nil } diff --git a/internal/rpc/auth/auth.go b/internal/rpc/auth/auth.go index eb0e35d5f..f723731e4 100644 --- a/internal/rpc/auth/auth.go +++ b/internal/rpc/auth/auth.go @@ -5,6 +5,7 @@ import ( "Open_IM/internal/common/rpc_server" "Open_IM/pkg/common/config" "Open_IM/pkg/common/constant" + "Open_IM/pkg/common/db/cache" "Open_IM/pkg/common/db/controller" "Open_IM/pkg/common/log" promePkg "Open_IM/pkg/common/prometheus" @@ -23,8 +24,11 @@ func NewRpcAuthServer(port int) *rpcAuth { if err != nil { panic(err) } + var redis cache.RedisClient + redis.InitRedis() return &rpcAuth{ - RpcServer: r, + RpcServer: r, + AuthInterface: controller.NewAuthController(redis.GetClient(), config.Config.TokenPolicy.AccessSecret, config.Config.TokenPolicy.AccessExpire), } } @@ -64,7 +68,7 @@ func (s *rpcAuth) UserToken(ctx context.Context, req *pbAuth.UserTokenReq) (*pbA if _, err := check.GetUsersInfo(ctx, req.UserID); err != nil { return nil, err } - token, err := s.CreateToken(ctx, req.UserID, int(req.PlatformID), config.Config.TokenPolicy.AccessExpire) + token, err := s.CreateToken(ctx, req.UserID, constant.PlatformIDToName(int(req.PlatformID))) if err != nil { return nil, err } @@ -73,39 +77,41 @@ func (s *rpcAuth) UserToken(ctx context.Context, req *pbAuth.UserTokenReq) (*pbA return &resp, nil } -func (s *rpcAuth) parseToken(ctx context.Context, tokensString, operationID string) (claims *tokenverify.Claims, err error) { +func (s *rpcAuth) parseToken(ctx context.Context, tokensString string) (claims *tokenverify.Claims, err error) { claims, err = tokenverify.GetClaimFromToken(tokensString) if err != nil { return nil, utils.Wrap(err, "") } - m, err := s.GetTokens(ctx, claims.UID, claims.Platform) + m, err := s.GetTokensWithoutError(ctx, claims.UID, claims.Platform) if err != nil { return nil, err } - + if len(m) == 0 { + return nil, constant.ErrTokenNotExist.Wrap() + } if v, ok := m[tokensString]; ok { switch v { case constant.NormalToken: return claims, nil case constant.KickedToken: - return nil, utils.Wrap(constant.ErrTokenKicked, "this token has been kicked by other same terminal ") + return nil, constant.ErrTokenKicked.Wrap() default: return nil, utils.Wrap(constant.ErrTokenUnknown, "") } } - return nil, utils.Wrap(constant.ErrTokenNotExist, "redis token map not find") + return nil, constant.ErrTokenNotExist.Wrap() } -func (s *rpcAuth) ParseToken(ctx context.Context, req *pbAuth.ParseTokenReq) (*pbAuth.ParseTokenResp, error) { - resp := pbAuth.ParseTokenResp{} - claims, err := s.parseToken(ctx, req.Token, req.OperationID) +func (s *rpcAuth) ParseToken(ctx context.Context, req *pbAuth.ParseTokenReq) (resp *pbAuth.ParseTokenResp, err error) { + resp = &pbAuth.ParseTokenResp{} + claims, err := s.parseToken(ctx, req.Token) if err != nil { return nil, err } resp.UserID = claims.UID resp.Platform = claims.Platform resp.ExpireTimeSeconds = claims.ExpiresAt.Unix() - return &resp, nil + return resp, nil } func (s *rpcAuth) ForceLogout(ctx context.Context, req *pbAuth.ForceLogoutReq) (*pbAuth.ForceLogoutResp, error) { diff --git a/pkg/common/config/config.go b/pkg/common/config/config.go index 647887063..35a3f93f5 100644 --- a/pkg/common/config/config.go +++ b/pkg/common/config/config.go @@ -665,15 +665,6 @@ func initConfig(config interface{}, configName, configPath string) { func InitConfig(configPath string) { initConfig(&Config, "config.yaml", configPath) initConfig(&UsualConfig, "usualConfig.yaml", configPath) - if Config.Etcd.UserName == "" { - Config.Etcd.UserName = UsualConfig.Etcd.UserName - } - if Config.Etcd.Password == "" { - Config.Etcd.Password = UsualConfig.Etcd.Password - } - if Config.Etcd.Secret == "" { - Config.Etcd.Secret = UsualConfig.Etcd.Secret - } if Config.Mysql.DBUserName == "" { Config.Mysql.DBUserName = UsualConfig.Mysql.DBUserName diff --git a/pkg/common/db/cache/redis.go b/pkg/common/db/cache/redis.go index f9b090f2e..cf06782ce 100644 --- a/pkg/common/db/cache/redis.go +++ b/pkg/common/db/cache/redis.go @@ -20,10 +20,10 @@ import ( ) const ( - userIncrSeq = "REDIS_USER_INCR_SEQ:" // user incr seq - appleDeviceToken = "DEVICE_TOKEN" - userMinSeq = "REDIS_USER_MIN_SEQ:" - uidPidToken = "UID_PID_TOKEN_STATUS:" + userIncrSeq = "REDIS_USER_INCR_SEQ:" // user incr seq + appleDeviceToken = "DEVICE_TOKEN" + userMinSeq = "REDIS_USER_MIN_SEQ:" + getuiToken = "GETUI_TOKEN" getuiTaskID = "GETUI_TASK_ID" messageCache = "MESSAGE_CACHE:" @@ -94,33 +94,33 @@ func NewRedisClient(rdb redis.UniversalClient) *RedisClient { return &RedisClient{rdb: rdb} } -//Perform seq auto-increment operation of user messages +// Perform seq auto-increment operation of user messages func (r *RedisClient) IncrUserSeq(uid string) (uint64, error) { key := userIncrSeq + uid seq, err := r.rdb.Incr(context.Background(), key).Result() return uint64(seq), err } -//Get the largest Seq +// Get the largest Seq func (r *RedisClient) GetUserMaxSeq(uid string) (uint64, error) { key := userIncrSeq + uid seq, err := r.rdb.Get(context.Background(), key).Result() return uint64(utils.StringToInt(seq)), err } -//set the largest Seq +// set the largest Seq func (r *RedisClient) SetUserMaxSeq(uid string, maxSeq uint64) error { key := userIncrSeq + uid return r.rdb.Set(context.Background(), key, maxSeq, 0).Err() } -//Set the user's minimum seq +// Set the user's minimum seq func (r *RedisClient) SetUserMinSeq(uid string, minSeq uint32) (err error) { key := userMinSeq + uid return r.rdb.Set(context.Background(), key, minSeq, 0).Err() } -//Get the smallest Seq +// Get the smallest Seq func (r *RedisClient) GetUserMinSeq(uid string) (uint64, error) { key := userMinSeq + uid seq, err := r.rdb.Get(context.Background(), key).Result() @@ -159,7 +159,7 @@ func (r *RedisClient) SetGroupMinSeq(groupID string, minSeq uint32) error { return r.rdb.Set(context.Background(), key, minSeq, 0).Err() } -//Store userid and platform class to redis +// Store userid and platform class to redis func (r *RedisClient) AddTokenFlag(userID string, platformID int, token string, flag int) error { key := uidPidToken + userID + ":" + constant.PlatformIDToName(platformID) log2.NewDebug("", "add token key is ", key) diff --git a/pkg/common/db/cache/token.go b/pkg/common/db/cache/token.go new file mode 100644 index 000000000..2aba43a0e --- /dev/null +++ b/pkg/common/db/cache/token.go @@ -0,0 +1,75 @@ +package cache + +import ( + "Open_IM/pkg/common/constant" + "Open_IM/pkg/common/tokenverify" + "Open_IM/pkg/utils" + "context" + go_redis "github.com/go-redis/redis/v8" + "github.com/golang-jwt/jwt/v4" +) + +const ( + uidPidToken = "UID_PID_TOKEN_STATUS:" +) + +type Token interface { + //结果为空 不返回错误 + GetTokensWithoutError(ctx context.Context, userID, platform string) (map[string]int, error) + //创建token + CreateToken(ctx context.Context, userID string, platformID int) (string, error) +} + +type TokenRedis struct { + RedisClient *RedisClient + AccessSecret string + AccessExpire int64 +} + +func NewTokenRedis(redisClient *RedisClient, accessSecret string, accessExpire int64) *TokenRedis { + return &TokenRedis{redisClient, accessSecret, accessExpire} +} + +// 结果为空 不返回错误 +func (t *TokenRedis) GetTokensWithoutError(ctx context.Context, userID, platform string) (map[string]int, error) { + key := uidPidToken + userID + ":" + platform + m, err := t.RedisClient.GetClient().HGetAll(context.Background(), key).Result() + if err != nil && err == go_redis.Nil { + return nil, nil + } + mm := make(map[string]int) + for k, v := range m { + mm[k] = utils.StringToInt(v) + } + return mm, utils.Wrap(err, "") +} + +// 创建token +func (t *TokenRedis) CreateToken(ctx context.Context, userID string, platform string) (string, error) { + tokens, err := t.GetTokensWithoutError(ctx, userID, platform) + if err != nil { + return "", err + } + var deleteTokenKey []string + for k, v := range tokens { + _, err = tokenverify.GetClaimFromToken(k) + if err != nil || v != constant.NormalToken { + deleteTokenKey = append(deleteTokenKey, k) + } + } + if len(deleteTokenKey) != 0 { + key := uidPidToken + userID + ":" + platform + err := t.RedisClient.GetClient().HDel(context.Background(), key, deleteTokenKey...).Err() + if err != nil { + return "", err + } + } + claims := tokenverify.BuildClaims(userID, platform, t.AccessExpire) + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenString, err := token.SignedString([]byte(t.AccessSecret)) + if err != nil { + return "", utils.Wrap(err, "") + } + key := uidPidToken + userID + ":" + platform + return "", utils.Wrap(t.RedisClient.GetClient().HSet(context.Background(), key, tokenString, constant.NormalToken).Err(), "") +} diff --git a/pkg/common/db/controller/auth.go b/pkg/common/db/controller/auth.go index 9c9bcf6df..eba88fd7d 100644 --- a/pkg/common/db/controller/auth.go +++ b/pkg/common/db/controller/auth.go @@ -1,9 +1,34 @@ package controller -import "context" +import ( + "Open_IM/pkg/common/db/cache" + "context" + "github.com/go-redis/redis/v8" +) type AuthInterface interface { - GetTokens(ctx context.Context, userID, platform string) (map[string]int, error) - DeleteToken(ctx context.Context, userID, platform string) error - CreateToken(ctx context.Context, userID string, platformID int, ttl int64) (string, error) + //结果为空 不返回错误 + GetTokensWithoutError(ctx context.Context, userID, platform string) (map[string]int, error) + + //创建token + CreateToken(ctx context.Context, userID string, platform string) (string, error) +} + +type AuthController struct { + database *cache.TokenRedis +} + +func NewAuthController(rdb redis.UniversalClient, accessSecret string, accessExpire int64) *AuthController { + cache.NewRedisClient(rdb) + return &AuthController{database: cache.NewTokenRedis(cache.NewRedisClient(rdb), accessSecret, accessExpire)} +} + +// 结果为空 不返回错误 +func (a *AuthController) GetTokensWithoutError(ctx context.Context, userID, platform string) (map[string]int, error) { + return a.database.GetTokensWithoutError(ctx, userID, platform) +} + +// 创建token +func (a *AuthController) CreateToken(ctx context.Context, userID string, platform string) (string, error) { + return a.database.CreateToken(ctx, userID, platform) }