refactor: token platformID update

This commit is contained in:
Gordon 2023-06-14 10:47:18 +08:00
parent 87d64c6afe
commit 0124a5c05d
6 changed files with 36 additions and 40 deletions

View File

@ -188,9 +188,10 @@ func (ws *WsServer) multiTerminalLoginChecker(info *kickHandler) {
for _, c := range info.oldClients { for _, c := range info.oldClients {
err := c.KickOnlineMessage() err := c.KickOnlineMessage()
if err != nil { if err != nil {
log.ZWarn() log.ZWarn(c.ctx, "kick online message error", err)
} }
} }
ws.cache.GetTokensWithoutError(info.newClient.ctx, info.newClient.UserID, info.newClient.PlatformID)
} }
} }

View File

@ -42,7 +42,7 @@ func (s *authServer) UserToken(ctx context.Context, req *pbAuth.UserTokenReq) (*
if _, err := s.userRpcClient.GetUserInfo(ctx, req.UserID); err != nil { if _, err := s.userRpcClient.GetUserInfo(ctx, req.UserID); err != nil {
return nil, err return nil, err
} }
token, err := s.authDatabase.CreateToken(ctx, req.UserID, constant.PlatformIDToName(int(req.PlatformID))) token, err := s.authDatabase.CreateToken(ctx, req.UserID, int(req.PlatformID))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -56,7 +56,7 @@ func (s *authServer) parseToken(ctx context.Context, tokensString string) (claim
if err != nil { if err != nil {
return nil, utils.Wrap(err, "") return nil, utils.Wrap(err, "")
} }
m, err := s.authDatabase.GetTokensWithoutError(ctx, claims.UID, claims.Platform) m, err := s.authDatabase.GetTokensWithoutError(ctx, claims.UserID, claims.PlatformID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -82,8 +82,8 @@ func (s *authServer) ParseToken(ctx context.Context, req *pbAuth.ParseTokenReq)
if err != nil { if err != nil {
return nil, err return nil, err
} }
resp.UserID = claims.UID resp.UserID = claims.UserID
resp.Platform = claims.Platform resp.Platform = constant.PlatformIDToName(claims.PlatformID)
resp.ExpireTimeSeconds = claims.ExpiresAt.Unix() resp.ExpireTimeSeconds = claims.ExpiresAt.Unix()
return resp, nil return resp, nil
} }

View File

@ -88,9 +88,9 @@ type MsgModel interface {
SeqCache SeqCache
thirdCache thirdCache
AddTokenFlag(ctx context.Context, userID string, platformID int, token string, flag int) error AddTokenFlag(ctx context.Context, userID string, platformID int, token string, flag int) error
GetTokensWithoutError(ctx context.Context, userID, platformID string) (map[string]int, error) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error)
SetTokenMapByUidPid(ctx context.Context, userID string, platform string, m map[string]int) error SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error
DeleteTokenByUidPid(ctx context.Context, userID string, platform string, fields []string) error DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error
GetMessagesBySeq(ctx context.Context, conversationID string, seqs []int64) (seqMsg []*sdkws.MsgData, failedSeqList []int64, err error) GetMessagesBySeq(ctx context.Context, conversationID string, seqs []int64) (seqMsg []*sdkws.MsgData, failedSeqList []int64, err error)
SetMessageToCache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (int, error) SetMessageToCache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (int, error)
UserDeleteMsgs(ctx context.Context, conversationID string, seqs []int64, userID string) error UserDeleteMsgs(ctx context.Context, conversationID string, seqs []int64, userID string) error
@ -260,8 +260,8 @@ func (c *msgCache) AddTokenFlag(ctx context.Context, userID string, platformID i
return errs.Wrap(c.rdb.HSet(ctx, key, token, flag).Err()) return errs.Wrap(c.rdb.HSet(ctx, key, token, flag).Err())
} }
func (c *msgCache) GetTokensWithoutError(ctx context.Context, userID, platformID string) (map[string]int, error) { func (c *msgCache) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) {
key := uidPidToken + userID + ":" + platformID key := uidPidToken + userID + ":" + constant.PlatformIDToName(platformID)
m, err := c.rdb.HGetAll(ctx, key).Result() m, err := c.rdb.HGetAll(ctx, key).Result()
if err != nil { if err != nil {
return nil, errs.Wrap(err) return nil, errs.Wrap(err)
@ -273,8 +273,8 @@ func (c *msgCache) GetTokensWithoutError(ctx context.Context, userID, platformID
return mm, nil return mm, nil
} }
func (c *msgCache) SetTokenMapByUidPid(ctx context.Context, userID string, platform string, m map[string]int) error { func (c *msgCache) SetTokenMapByUidPid(ctx context.Context, userID string, platform int, m map[string]int) error {
key := uidPidToken + userID + ":" + platform key := uidPidToken + userID + ":" + constant.PlatformIDToName(platform)
mm := make(map[string]interface{}) mm := make(map[string]interface{})
for k, v := range m { for k, v := range m {
mm[k] = v mm[k] = v
@ -282,8 +282,8 @@ func (c *msgCache) SetTokenMapByUidPid(ctx context.Context, userID string, platf
return errs.Wrap(c.rdb.HSet(ctx, key, mm).Err()) return errs.Wrap(c.rdb.HSet(ctx, key, mm).Err())
} }
func (c *msgCache) DeleteTokenByUidPid(ctx context.Context, userID string, platform string, fields []string) error { func (c *msgCache) DeleteTokenByUidPid(ctx context.Context, userID string, platform int, fields []string) error {
key := uidPidToken + userID + ":" + platform key := uidPidToken + userID + ":" + constant.PlatformIDToName(platform)
return errs.Wrap(c.rdb.HDel(ctx, key, fields...).Err()) return errs.Wrap(c.rdb.HDel(ctx, key, fields...).Err())
} }

View File

@ -12,9 +12,9 @@ import (
type AuthDatabase interface { type AuthDatabase interface {
//结果为空 不返回错误 //结果为空 不返回错误
GetTokensWithoutError(ctx context.Context, userID, platform string) (map[string]int, error) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error)
//创建token //创建token
CreateToken(ctx context.Context, userID string, platform string) (string, error) CreateToken(ctx context.Context, userID string, platformID int) (string, error)
} }
type authDatabase struct { type authDatabase struct {
@ -29,13 +29,13 @@ func NewAuthDatabase(cache cache.MsgModel, accessSecret string, accessExpire int
} }
// 结果为空 不返回错误 // 结果为空 不返回错误
func (a *authDatabase) GetTokensWithoutError(ctx context.Context, userID, platform string) (map[string]int, error) { func (a *authDatabase) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) {
return a.cache.GetTokensWithoutError(ctx, userID, platform) return a.cache.GetTokensWithoutError(ctx, userID, platformID)
} }
// 创建token // 创建token
func (a *authDatabase) CreateToken(ctx context.Context, userID string, platform string) (string, error) { func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformID int) (string, error) {
tokens, err := a.cache.GetTokensWithoutError(ctx, userID, platform) tokens, err := a.cache.GetTokensWithoutError(ctx, userID, platformID)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -47,16 +47,16 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platform
} }
} }
if len(deleteTokenKey) != 0 { if len(deleteTokenKey) != 0 {
err := a.cache.DeleteTokenByUidPid(ctx, userID, platform, deleteTokenKey) err := a.cache.DeleteTokenByUidPid(ctx, userID, platformID, deleteTokenKey)
if err != nil { if err != nil {
return "", err return "", err
} }
} }
claims := tokenverify.BuildClaims(userID, platform, a.accessExpire) claims := tokenverify.BuildClaims(userID, platformID, a.accessExpire)
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString([]byte(a.accessSecret)) tokenString, err := token.SignedString([]byte(a.accessSecret))
if err != nil { if err != nil {
return "", utils.Wrap(err, "") return "", utils.Wrap(err, "")
} }
return tokenString, a.cache.AddTokenFlag(ctx, userID, constant.PlatformNameToID(platform), tokenString, constant.NormalToken) return tokenString, a.cache.AddTokenFlag(ctx, userID, platformID, tokenString, constant.NormalToken)
} }

View File

@ -128,7 +128,7 @@ func GinParseToken(rdb redis.UniversalClient) gin.HandlerFunc {
c.Abort() c.Abort()
return return
} }
m, err := dataBase.GetTokensWithoutError(c, claims.UID, claims.Platform) m, err := dataBase.GetTokensWithoutError(c, claims.UserID, claims.PlatformID)
if err != nil { if err != nil {
log.ZWarn(c, "cache get token error", errs.ErrTokenNotExist.Wrap()) log.ZWarn(c, "cache get token error", errs.ErrTokenNotExist.Wrap())
apiresp.GinError(c, errs.ErrTokenNotExist.Wrap()) apiresp.GinError(c, errs.ErrTokenNotExist.Wrap())
@ -156,8 +156,8 @@ func GinParseToken(rdb redis.UniversalClient) gin.HandlerFunc {
return return
} }
} }
c.Set(constant.OpUserPlatform, claims.Platform) c.Set(constant.OpUserPlatform, constant.PlatformIDToName(claims.PlatformID))
c.Set(constant.OpUserID, claims.UID) c.Set(constant.OpUserID, claims.UserID)
c.Next() c.Next()
} }
} }

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/config" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/config"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/mcontext" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/mcontext"
"github.com/OpenIMSDK/Open-IM-Server/pkg/errs" "github.com/OpenIMSDK/Open-IM-Server/pkg/errs"
"github.com/OpenIMSDK/Open-IM-Server/pkg/utils" "github.com/OpenIMSDK/Open-IM-Server/pkg/utils"
@ -14,17 +13,17 @@ import (
) )
type Claims struct { type Claims struct {
UID string UserID string
Platform string //login platform PlatformID int //login platform
jwt.RegisteredClaims jwt.RegisteredClaims
} }
func BuildClaims(uid, platform string, ttl int64) Claims { func BuildClaims(uid string, platformID int, ttl int64) Claims {
now := time.Now() now := time.Now()
before := now.Add(-time.Minute * 5) before := now.Add(-time.Minute * 5)
return Claims{ return Claims{
UID: uid, UserID: uid,
Platform: platform, PlatformID: platformID,
RegisteredClaims: jwt.RegisteredClaims{ RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(time.Duration(ttl*24) * time.Hour)), //Expiration time ExpiresAt: jwt.NewNumericDate(now.Add(time.Duration(ttl*24) * time.Hour)), //Expiration time
IssuedAt: jwt.NewNumericDate(now), //Issuing time IssuedAt: jwt.NewNumericDate(now), //Issuing time
@ -95,19 +94,15 @@ func WsVerifyToken(token, userID, platformID string) error {
if err != nil { if err != nil {
return errs.ErrArgs.Wrap(fmt.Sprintf("platformID %s is not int", platformID)) return errs.ErrArgs.Wrap(fmt.Sprintf("platformID %s is not int", platformID))
} }
platform := constant.PlatformIDToName(platformIDInt)
if platform == "" {
return errs.ErrArgs.Wrap(fmt.Sprintf("platformID %s is not exist", platformID))
}
claim, err := GetClaimFromToken(token) claim, err := GetClaimFromToken(token)
if err != nil { if err != nil {
return err return err
} }
if claim.UID != userID { if claim.UserID != userID {
return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token uid %s != userID %s", claim.UID, userID)) return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token uid %s != userID %s", claim.UserID, userID))
} }
if claim.Platform != platform { if claim.PlatformID != platformIDInt {
return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token platform %s != %s", claim.Platform, platform)) return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token platform %d != %d", claim.PlatformID, platformIDInt))
} }
return nil return nil
} }