mirror of
https://github.com/openimsdk/open-im-server.git
synced 2025-04-26 03:26:57 +08:00
refactor: token platformID update
This commit is contained in:
parent
87d64c6afe
commit
0124a5c05d
@ -188,9 +188,10 @@ func (ws *WsServer) multiTerminalLoginChecker(info *kickHandler) {
|
|||||||
for _, c := range info.oldClients {
|
for _, c := range info.oldClients {
|
||||||
err := c.KickOnlineMessage()
|
err := c.KickOnlineMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.ZWarn()
|
log.ZWarn(c.ctx, "kick online message error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
ws.cache.GetTokensWithoutError(info.newClient.ctx, info.newClient.UserID, info.newClient.PlatformID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -42,7 +42,7 @@ func (s *authServer) UserToken(ctx context.Context, req *pbAuth.UserTokenReq) (*
|
|||||||
if _, err := s.userRpcClient.GetUserInfo(ctx, req.UserID); err != nil {
|
if _, err := s.userRpcClient.GetUserInfo(ctx, req.UserID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
token, err := s.authDatabase.CreateToken(ctx, req.UserID, constant.PlatformIDToName(int(req.PlatformID)))
|
token, err := s.authDatabase.CreateToken(ctx, req.UserID, int(req.PlatformID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -56,7 +56,7 @@ func (s *authServer) parseToken(ctx context.Context, tokensString string) (claim
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, utils.Wrap(err, "")
|
return nil, utils.Wrap(err, "")
|
||||||
}
|
}
|
||||||
m, err := s.authDatabase.GetTokensWithoutError(ctx, claims.UID, claims.Platform)
|
m, err := s.authDatabase.GetTokensWithoutError(ctx, claims.UserID, claims.PlatformID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -82,8 +82,8 @@ func (s *authServer) ParseToken(ctx context.Context, req *pbAuth.ParseTokenReq)
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
resp.UserID = claims.UID
|
resp.UserID = claims.UserID
|
||||||
resp.Platform = claims.Platform
|
resp.Platform = constant.PlatformIDToName(claims.PlatformID)
|
||||||
resp.ExpireTimeSeconds = claims.ExpiresAt.Unix()
|
resp.ExpireTimeSeconds = claims.ExpiresAt.Unix()
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
18
pkg/common/db/cache/msg.go
vendored
18
pkg/common/db/cache/msg.go
vendored
@ -88,9 +88,9 @@ type MsgModel interface {
|
|||||||
SeqCache
|
SeqCache
|
||||||
thirdCache
|
thirdCache
|
||||||
AddTokenFlag(ctx context.Context, userID string, platformID int, token string, flag int) error
|
AddTokenFlag(ctx context.Context, userID string, platformID int, token string, flag int) error
|
||||||
GetTokensWithoutError(ctx context.Context, userID, platformID string) (map[string]int, error)
|
GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error)
|
||||||
SetTokenMapByUidPid(ctx context.Context, userID string, platform string, m map[string]int) error
|
SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error
|
||||||
DeleteTokenByUidPid(ctx context.Context, userID string, platform string, fields []string) error
|
DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error
|
||||||
GetMessagesBySeq(ctx context.Context, conversationID string, seqs []int64) (seqMsg []*sdkws.MsgData, failedSeqList []int64, err error)
|
GetMessagesBySeq(ctx context.Context, conversationID string, seqs []int64) (seqMsg []*sdkws.MsgData, failedSeqList []int64, err error)
|
||||||
SetMessageToCache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (int, error)
|
SetMessageToCache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (int, error)
|
||||||
UserDeleteMsgs(ctx context.Context, conversationID string, seqs []int64, userID string) error
|
UserDeleteMsgs(ctx context.Context, conversationID string, seqs []int64, userID string) error
|
||||||
@ -260,8 +260,8 @@ func (c *msgCache) AddTokenFlag(ctx context.Context, userID string, platformID i
|
|||||||
return errs.Wrap(c.rdb.HSet(ctx, key, token, flag).Err())
|
return errs.Wrap(c.rdb.HSet(ctx, key, token, flag).Err())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *msgCache) GetTokensWithoutError(ctx context.Context, userID, platformID string) (map[string]int, error) {
|
func (c *msgCache) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) {
|
||||||
key := uidPidToken + userID + ":" + platformID
|
key := uidPidToken + userID + ":" + constant.PlatformIDToName(platformID)
|
||||||
m, err := c.rdb.HGetAll(ctx, key).Result()
|
m, err := c.rdb.HGetAll(ctx, key).Result()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.Wrap(err)
|
return nil, errs.Wrap(err)
|
||||||
@ -273,8 +273,8 @@ func (c *msgCache) GetTokensWithoutError(ctx context.Context, userID, platformID
|
|||||||
return mm, nil
|
return mm, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *msgCache) SetTokenMapByUidPid(ctx context.Context, userID string, platform string, m map[string]int) error {
|
func (c *msgCache) SetTokenMapByUidPid(ctx context.Context, userID string, platform int, m map[string]int) error {
|
||||||
key := uidPidToken + userID + ":" + platform
|
key := uidPidToken + userID + ":" + constant.PlatformIDToName(platform)
|
||||||
mm := make(map[string]interface{})
|
mm := make(map[string]interface{})
|
||||||
for k, v := range m {
|
for k, v := range m {
|
||||||
mm[k] = v
|
mm[k] = v
|
||||||
@ -282,8 +282,8 @@ func (c *msgCache) SetTokenMapByUidPid(ctx context.Context, userID string, platf
|
|||||||
return errs.Wrap(c.rdb.HSet(ctx, key, mm).Err())
|
return errs.Wrap(c.rdb.HSet(ctx, key, mm).Err())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *msgCache) DeleteTokenByUidPid(ctx context.Context, userID string, platform string, fields []string) error {
|
func (c *msgCache) DeleteTokenByUidPid(ctx context.Context, userID string, platform int, fields []string) error {
|
||||||
key := uidPidToken + userID + ":" + platform
|
key := uidPidToken + userID + ":" + constant.PlatformIDToName(platform)
|
||||||
return errs.Wrap(c.rdb.HDel(ctx, key, fields...).Err())
|
return errs.Wrap(c.rdb.HDel(ctx, key, fields...).Err())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -12,9 +12,9 @@ import (
|
|||||||
|
|
||||||
type AuthDatabase interface {
|
type AuthDatabase interface {
|
||||||
//结果为空 不返回错误
|
//结果为空 不返回错误
|
||||||
GetTokensWithoutError(ctx context.Context, userID, platform string) (map[string]int, error)
|
GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error)
|
||||||
//创建token
|
//创建token
|
||||||
CreateToken(ctx context.Context, userID string, platform string) (string, error)
|
CreateToken(ctx context.Context, userID string, platformID int) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type authDatabase struct {
|
type authDatabase struct {
|
||||||
@ -29,13 +29,13 @@ func NewAuthDatabase(cache cache.MsgModel, accessSecret string, accessExpire int
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 结果为空 不返回错误
|
// 结果为空 不返回错误
|
||||||
func (a *authDatabase) GetTokensWithoutError(ctx context.Context, userID, platform string) (map[string]int, error) {
|
func (a *authDatabase) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) {
|
||||||
return a.cache.GetTokensWithoutError(ctx, userID, platform)
|
return a.cache.GetTokensWithoutError(ctx, userID, platformID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 创建token
|
// 创建token
|
||||||
func (a *authDatabase) CreateToken(ctx context.Context, userID string, platform string) (string, error) {
|
func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformID int) (string, error) {
|
||||||
tokens, err := a.cache.GetTokensWithoutError(ctx, userID, platform)
|
tokens, err := a.cache.GetTokensWithoutError(ctx, userID, platformID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@ -47,16 +47,16 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platform
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(deleteTokenKey) != 0 {
|
if len(deleteTokenKey) != 0 {
|
||||||
err := a.cache.DeleteTokenByUidPid(ctx, userID, platform, deleteTokenKey)
|
err := a.cache.DeleteTokenByUidPid(ctx, userID, platformID, deleteTokenKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
claims := tokenverify.BuildClaims(userID, platform, a.accessExpire)
|
claims := tokenverify.BuildClaims(userID, platformID, a.accessExpire)
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
tokenString, err := token.SignedString([]byte(a.accessSecret))
|
tokenString, err := token.SignedString([]byte(a.accessSecret))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", utils.Wrap(err, "")
|
return "", utils.Wrap(err, "")
|
||||||
}
|
}
|
||||||
return tokenString, a.cache.AddTokenFlag(ctx, userID, constant.PlatformNameToID(platform), tokenString, constant.NormalToken)
|
return tokenString, a.cache.AddTokenFlag(ctx, userID, platformID, tokenString, constant.NormalToken)
|
||||||
}
|
}
|
||||||
|
@ -128,7 +128,7 @@ func GinParseToken(rdb redis.UniversalClient) gin.HandlerFunc {
|
|||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
m, err := dataBase.GetTokensWithoutError(c, claims.UID, claims.Platform)
|
m, err := dataBase.GetTokensWithoutError(c, claims.UserID, claims.PlatformID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.ZWarn(c, "cache get token error", errs.ErrTokenNotExist.Wrap())
|
log.ZWarn(c, "cache get token error", errs.ErrTokenNotExist.Wrap())
|
||||||
apiresp.GinError(c, errs.ErrTokenNotExist.Wrap())
|
apiresp.GinError(c, errs.ErrTokenNotExist.Wrap())
|
||||||
@ -156,8 +156,8 @@ func GinParseToken(rdb redis.UniversalClient) gin.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.Set(constant.OpUserPlatform, claims.Platform)
|
c.Set(constant.OpUserPlatform, constant.PlatformIDToName(claims.PlatformID))
|
||||||
c.Set(constant.OpUserID, claims.UID)
|
c.Set(constant.OpUserID, claims.UserID)
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/config"
|
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/config"
|
||||||
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant"
|
|
||||||
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/mcontext"
|
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/mcontext"
|
||||||
"github.com/OpenIMSDK/Open-IM-Server/pkg/errs"
|
"github.com/OpenIMSDK/Open-IM-Server/pkg/errs"
|
||||||
"github.com/OpenIMSDK/Open-IM-Server/pkg/utils"
|
"github.com/OpenIMSDK/Open-IM-Server/pkg/utils"
|
||||||
@ -14,17 +13,17 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Claims struct {
|
type Claims struct {
|
||||||
UID string
|
UserID string
|
||||||
Platform string //login platform
|
PlatformID int //login platform
|
||||||
jwt.RegisteredClaims
|
jwt.RegisteredClaims
|
||||||
}
|
}
|
||||||
|
|
||||||
func BuildClaims(uid, platform string, ttl int64) Claims {
|
func BuildClaims(uid string, platformID int, ttl int64) Claims {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
before := now.Add(-time.Minute * 5)
|
before := now.Add(-time.Minute * 5)
|
||||||
return Claims{
|
return Claims{
|
||||||
UID: uid,
|
UserID: uid,
|
||||||
Platform: platform,
|
PlatformID: platformID,
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
ExpiresAt: jwt.NewNumericDate(now.Add(time.Duration(ttl*24) * time.Hour)), //Expiration time
|
ExpiresAt: jwt.NewNumericDate(now.Add(time.Duration(ttl*24) * time.Hour)), //Expiration time
|
||||||
IssuedAt: jwt.NewNumericDate(now), //Issuing time
|
IssuedAt: jwt.NewNumericDate(now), //Issuing time
|
||||||
@ -95,19 +94,15 @@ func WsVerifyToken(token, userID, platformID string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return errs.ErrArgs.Wrap(fmt.Sprintf("platformID %s is not int", platformID))
|
return errs.ErrArgs.Wrap(fmt.Sprintf("platformID %s is not int", platformID))
|
||||||
}
|
}
|
||||||
platform := constant.PlatformIDToName(platformIDInt)
|
|
||||||
if platform == "" {
|
|
||||||
return errs.ErrArgs.Wrap(fmt.Sprintf("platformID %s is not exist", platformID))
|
|
||||||
}
|
|
||||||
claim, err := GetClaimFromToken(token)
|
claim, err := GetClaimFromToken(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if claim.UID != userID {
|
if claim.UserID != userID {
|
||||||
return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token uid %s != userID %s", claim.UID, userID))
|
return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token uid %s != userID %s", claim.UserID, userID))
|
||||||
}
|
}
|
||||||
if claim.Platform != platform {
|
if claim.PlatformID != platformIDInt {
|
||||||
return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token platform %s != %s", claim.Platform, platform))
|
return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token platform %d != %d", claim.PlatformID, platformIDInt))
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user