diff --git a/internal/msggateway/n_ws_server.go b/internal/msggateway/n_ws_server.go index d7175ed36..c665c5556 100644 --- a/internal/msggateway/n_ws_server.go +++ b/internal/msggateway/n_ws_server.go @@ -7,6 +7,7 @@ import ( "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/cache" "net/http" + "strconv" "sync" "sync/atomic" "time" @@ -214,11 +215,11 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { return } var ( - token string - userID string - platformID string - exists bool - compression bool + token string + userID string + platformIDStr string + exists bool + compression bool ) token, exists = connContext.Query(Token) @@ -231,13 +232,17 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { httpError(connContext, errs.ErrConnArgsErr) return } - platformID, exists = connContext.Query(PlatformID) - if !exists || utils.StringToInt(platformID) == 0 { + platformIDStr, exists = connContext.Query(PlatformID) + if !exists { httpError(connContext, errs.ErrConnArgsErr) return } - err := tokenverify.WsVerifyToken(token, userID, platformID) + platformID, err := strconv.Atoi(platformIDStr) if err != nil { + httpError(connContext, errs.ErrConnArgsErr) + return + } + if err := tokenverify.WsVerifyToken(token, userID, platformID); err != nil { httpError(connContext, err) return } diff --git a/pkg/common/tokenverify/jwt_token.go b/pkg/common/tokenverify/jwt_token.go index 65a31545e..bc7ca62e6 100644 --- a/pkg/common/tokenverify/jwt_token.go +++ b/pkg/common/tokenverify/jwt_token.go @@ -8,7 +8,6 @@ import ( "github.com/OpenIMSDK/Open-IM-Server/pkg/errs" "github.com/OpenIMSDK/Open-IM-Server/pkg/utils" "github.com/golang-jwt/jwt/v4" - "strconv" "time" ) @@ -89,11 +88,7 @@ func ParseRedisInterfaceToken(redisToken interface{}) (*Claims, error) { func IsManagerUserID(opUserID string) bool { return utils.IsContain(opUserID, config.Config.Manager.AppManagerUid) } -func WsVerifyToken(token, userID, platformID string) error { - platformIDInt, err := strconv.Atoi(platformID) - if err != nil { - return errs.ErrArgs.Wrap(fmt.Sprintf("platformID %s is not int", platformID)) - } +func WsVerifyToken(token, userID string, platformID int) error { claim, err := GetClaimFromToken(token) if err != nil { return err @@ -101,8 +96,8 @@ func WsVerifyToken(token, userID, platformID string) error { if claim.UserID != userID { return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token uid %s != userID %s", claim.UserID, userID)) } - if claim.PlatformID != platformIDInt { - return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token platform %d != %d", claim.PlatformID, platformIDInt)) + if claim.PlatformID != platformID { + return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token platform %d != %d", claim.PlatformID, platformID)) } return nil }