diff --git a/internal/msggateway/n_ws_server.go b/internal/msggateway/n_ws_server.go index 7318b2cf1..2563ab3ce 100644 --- a/internal/msggateway/n_ws_server.go +++ b/internal/msggateway/n_ws_server.go @@ -188,9 +188,10 @@ func (ws *WsServer) multiTerminalLoginChecker(info *kickHandler) { for _, c := range info.oldClients { err := c.KickOnlineMessage() if err != nil { - log.ZWarn() + log.ZWarn(c.ctx, "kick online message error", err) } } + ws.cache.GetTokensWithoutError(info.newClient.ctx, info.newClient.UserID, info.newClient.PlatformID) } } diff --git a/internal/rpc/auth/auth.go b/internal/rpc/auth/auth.go index 8298c9a87..fe5dfd21e 100644 --- a/internal/rpc/auth/auth.go +++ b/internal/rpc/auth/auth.go @@ -42,7 +42,7 @@ func (s *authServer) UserToken(ctx context.Context, req *pbAuth.UserTokenReq) (* if _, err := s.userRpcClient.GetUserInfo(ctx, req.UserID); err != nil { return nil, err } - token, err := s.authDatabase.CreateToken(ctx, req.UserID, constant.PlatformIDToName(int(req.PlatformID))) + token, err := s.authDatabase.CreateToken(ctx, req.UserID, int(req.PlatformID)) if err != nil { return nil, err } @@ -56,7 +56,7 @@ func (s *authServer) parseToken(ctx context.Context, tokensString string) (claim if err != nil { return nil, utils.Wrap(err, "") } - m, err := s.authDatabase.GetTokensWithoutError(ctx, claims.UID, claims.Platform) + m, err := s.authDatabase.GetTokensWithoutError(ctx, claims.UserID, claims.PlatformID) if err != nil { return nil, err } @@ -82,8 +82,8 @@ func (s *authServer) ParseToken(ctx context.Context, req *pbAuth.ParseTokenReq) if err != nil { return nil, err } - resp.UserID = claims.UID - resp.Platform = claims.Platform + resp.UserID = claims.UserID + resp.Platform = constant.PlatformIDToName(claims.PlatformID) resp.ExpireTimeSeconds = claims.ExpiresAt.Unix() return resp, nil } diff --git a/pkg/common/db/cache/msg.go b/pkg/common/db/cache/msg.go index d3bb47b91..30d5f1ffc 100644 --- a/pkg/common/db/cache/msg.go +++ b/pkg/common/db/cache/msg.go @@ -88,9 +88,9 @@ type MsgModel interface { SeqCache thirdCache AddTokenFlag(ctx context.Context, userID string, platformID int, token string, flag int) error - GetTokensWithoutError(ctx context.Context, userID, platformID string) (map[string]int, error) - SetTokenMapByUidPid(ctx context.Context, userID string, platform string, m map[string]int) error - DeleteTokenByUidPid(ctx context.Context, userID string, platform string, fields []string) error + GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) + SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error + DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error GetMessagesBySeq(ctx context.Context, conversationID string, seqs []int64) (seqMsg []*sdkws.MsgData, failedSeqList []int64, err error) SetMessageToCache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (int, error) UserDeleteMsgs(ctx context.Context, conversationID string, seqs []int64, userID string) error @@ -260,8 +260,8 @@ func (c *msgCache) AddTokenFlag(ctx context.Context, userID string, platformID i return errs.Wrap(c.rdb.HSet(ctx, key, token, flag).Err()) } -func (c *msgCache) GetTokensWithoutError(ctx context.Context, userID, platformID string) (map[string]int, error) { - key := uidPidToken + userID + ":" + platformID +func (c *msgCache) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) { + key := uidPidToken + userID + ":" + constant.PlatformIDToName(platformID) m, err := c.rdb.HGetAll(ctx, key).Result() if err != nil { return nil, errs.Wrap(err) @@ -273,8 +273,8 @@ func (c *msgCache) GetTokensWithoutError(ctx context.Context, userID, platformID return mm, nil } -func (c *msgCache) SetTokenMapByUidPid(ctx context.Context, userID string, platform string, m map[string]int) error { - key := uidPidToken + userID + ":" + platform +func (c *msgCache) SetTokenMapByUidPid(ctx context.Context, userID string, platform int, m map[string]int) error { + key := uidPidToken + userID + ":" + constant.PlatformIDToName(platform) mm := make(map[string]interface{}) for k, v := range m { mm[k] = v @@ -282,8 +282,8 @@ func (c *msgCache) SetTokenMapByUidPid(ctx context.Context, userID string, platf return errs.Wrap(c.rdb.HSet(ctx, key, mm).Err()) } -func (c *msgCache) DeleteTokenByUidPid(ctx context.Context, userID string, platform string, fields []string) error { - key := uidPidToken + userID + ":" + platform +func (c *msgCache) DeleteTokenByUidPid(ctx context.Context, userID string, platform int, fields []string) error { + key := uidPidToken + userID + ":" + constant.PlatformIDToName(platform) return errs.Wrap(c.rdb.HDel(ctx, key, fields...).Err()) } diff --git a/pkg/common/db/controller/auth.go b/pkg/common/db/controller/auth.go index 148ef6c96..6d6add902 100644 --- a/pkg/common/db/controller/auth.go +++ b/pkg/common/db/controller/auth.go @@ -12,9 +12,9 @@ import ( type AuthDatabase interface { //结果为空 不返回错误 - GetTokensWithoutError(ctx context.Context, userID, platform string) (map[string]int, error) + GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) //创建token - CreateToken(ctx context.Context, userID string, platform string) (string, error) + CreateToken(ctx context.Context, userID string, platformID int) (string, error) } type authDatabase struct { @@ -29,13 +29,13 @@ func NewAuthDatabase(cache cache.MsgModel, accessSecret string, accessExpire int } // 结果为空 不返回错误 -func (a *authDatabase) GetTokensWithoutError(ctx context.Context, userID, platform string) (map[string]int, error) { - return a.cache.GetTokensWithoutError(ctx, userID, platform) +func (a *authDatabase) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) { + return a.cache.GetTokensWithoutError(ctx, userID, platformID) } // 创建token -func (a *authDatabase) CreateToken(ctx context.Context, userID string, platform string) (string, error) { - tokens, err := a.cache.GetTokensWithoutError(ctx, userID, platform) +func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformID int) (string, error) { + tokens, err := a.cache.GetTokensWithoutError(ctx, userID, platformID) if err != nil { return "", err } @@ -47,16 +47,16 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platform } } if len(deleteTokenKey) != 0 { - err := a.cache.DeleteTokenByUidPid(ctx, userID, platform, deleteTokenKey) + err := a.cache.DeleteTokenByUidPid(ctx, userID, platformID, deleteTokenKey) if err != nil { return "", err } } - claims := tokenverify.BuildClaims(userID, platform, a.accessExpire) + claims := tokenverify.BuildClaims(userID, platformID, a.accessExpire) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, err := token.SignedString([]byte(a.accessSecret)) if err != nil { return "", utils.Wrap(err, "") } - return tokenString, a.cache.AddTokenFlag(ctx, userID, constant.PlatformNameToID(platform), tokenString, constant.NormalToken) + return tokenString, a.cache.AddTokenFlag(ctx, userID, platformID, tokenString, constant.NormalToken) } diff --git a/pkg/common/mw/gin.go b/pkg/common/mw/gin.go index 4b12a8244..6c87fbb5b 100644 --- a/pkg/common/mw/gin.go +++ b/pkg/common/mw/gin.go @@ -128,7 +128,7 @@ func GinParseToken(rdb redis.UniversalClient) gin.HandlerFunc { c.Abort() return } - m, err := dataBase.GetTokensWithoutError(c, claims.UID, claims.Platform) + m, err := dataBase.GetTokensWithoutError(c, claims.UserID, claims.PlatformID) if err != nil { log.ZWarn(c, "cache get token error", errs.ErrTokenNotExist.Wrap()) apiresp.GinError(c, errs.ErrTokenNotExist.Wrap()) @@ -156,8 +156,8 @@ func GinParseToken(rdb redis.UniversalClient) gin.HandlerFunc { return } } - c.Set(constant.OpUserPlatform, claims.Platform) - c.Set(constant.OpUserID, claims.UID) + c.Set(constant.OpUserPlatform, constant.PlatformIDToName(claims.PlatformID)) + c.Set(constant.OpUserID, claims.UserID) c.Next() } } diff --git a/pkg/common/tokenverify/jwt_token.go b/pkg/common/tokenverify/jwt_token.go index 1f2f0797c..65a31545e 100644 --- a/pkg/common/tokenverify/jwt_token.go +++ b/pkg/common/tokenverify/jwt_token.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/config" - "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/mcontext" "github.com/OpenIMSDK/Open-IM-Server/pkg/errs" "github.com/OpenIMSDK/Open-IM-Server/pkg/utils" @@ -14,17 +13,17 @@ import ( ) type Claims struct { - UID string - Platform string //login platform + UserID string + PlatformID int //login platform jwt.RegisteredClaims } -func BuildClaims(uid, platform string, ttl int64) Claims { +func BuildClaims(uid string, platformID int, ttl int64) Claims { now := time.Now() before := now.Add(-time.Minute * 5) return Claims{ - UID: uid, - Platform: platform, + UserID: uid, + PlatformID: platformID, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(now.Add(time.Duration(ttl*24) * time.Hour)), //Expiration time IssuedAt: jwt.NewNumericDate(now), //Issuing time @@ -95,19 +94,15 @@ func WsVerifyToken(token, userID, platformID string) error { if err != nil { return errs.ErrArgs.Wrap(fmt.Sprintf("platformID %s is not int", platformID)) } - platform := constant.PlatformIDToName(platformIDInt) - if platform == "" { - return errs.ErrArgs.Wrap(fmt.Sprintf("platformID %s is not exist", platformID)) - } claim, err := GetClaimFromToken(token) if err != nil { return err } - if claim.UID != userID { - return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token uid %s != userID %s", claim.UID, userID)) + if claim.UserID != userID { + return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token uid %s != userID %s", claim.UserID, userID)) } - if claim.Platform != platform { - return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token platform %s != %s", claim.Platform, platform)) + if claim.PlatformID != platformIDInt { + return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token platform %d != %d", claim.PlatformID, platformIDInt)) } return nil }