mirror of
https://github.com/openimsdk/open-im-server.git
synced 2025-04-06 04:15:46 +08:00
Feature/optimise jwt token (#24)
This commit is contained in:
parent
0dfaa49c97
commit
0e6432f95a
1
go.mod
1
go.mod
@ -32,6 +32,7 @@ require (
|
||||
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 // indirect
|
||||
github.com/rifflock/lfshook v0.0.0-20180920164130-b9218ef580f5
|
||||
github.com/sirupsen/logrus v1.6.0
|
||||
github.com/stretchr/testify v1.7.0
|
||||
github.com/tencentyun/qcloud-cos-sts-sdk v0.0.0-20210325043845-84a0811633ca
|
||||
github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5 // indirect
|
||||
go.etcd.io/etcd v0.0.0-20200402134248-51bdeb39e698
|
||||
|
2
go.sum
2
go.sum
@ -233,6 +233,8 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/tencentyun/qcloud-cos-sts-sdk v0.0.0-20210325043845-84a0811633ca h1:G/aIr3WiUesWHL2YGYgEqjM5tCAJ43Ml+0C18wDkWWs=
|
||||
github.com/tencentyun/qcloud-cos-sts-sdk v0.0.0-20210325043845-84a0811633ca/go.mod h1:b18KQa4IxHbxeseW1GcZox53d7J0z39VNONTxvvlkXw=
|
||||
github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
|
||||
|
@ -4,8 +4,9 @@ import (
|
||||
"Open_IM/src/common/config"
|
||||
"Open_IM/src/common/db"
|
||||
"errors"
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"time"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -25,21 +26,16 @@ type Claims struct {
|
||||
func BuildClaims(uid, accountAddr, platform string, ttl int64) Claims {
|
||||
now := time.Now().Unix()
|
||||
//if ttl=-1 Permanent token
|
||||
if ttl == -1 {
|
||||
return Claims{
|
||||
UID: uid,
|
||||
Platform: platform,
|
||||
StandardClaims: jwt.StandardClaims{
|
||||
ExpiresAt: -1,
|
||||
IssuedAt: now,
|
||||
NotBefore: now,
|
||||
}}
|
||||
expiresAt := int64(-1)
|
||||
if ttl != -1 {
|
||||
expiresAt = now + ttl
|
||||
}
|
||||
|
||||
return Claims{
|
||||
UID: uid,
|
||||
Platform: platform,
|
||||
StandardClaims: jwt.StandardClaims{
|
||||
ExpiresAt: now + ttl, //Expiration time
|
||||
ExpiresAt: expiresAt, //Expiration time
|
||||
IssuedAt: now, //Issuing time
|
||||
NotBefore: now, //Begin Effective time
|
||||
}}
|
||||
@ -59,7 +55,7 @@ func secret() jwt.Keyfunc {
|
||||
}
|
||||
}
|
||||
|
||||
func ParseToken(tokensString string) (claims *Claims, err error) {
|
||||
func getClaimFromToken(tokensString string) (*Claims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokensString, &Claims{}, secret())
|
||||
if err != nil {
|
||||
if ve, ok := err.(*jwt.ValidationError); ok {
|
||||
@ -75,73 +71,63 @@ func ParseToken(tokensString string) (claims *Claims, err error) {
|
||||
}
|
||||
}
|
||||
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
|
||||
// 1.check userid and platform class 0 not exists and 1 exists
|
||||
existsInterface, err := db.DB.ExistsUserIDAndPlatform(claims.UID, Platform2class[claims.Platform])
|
||||
return claims, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func ParseToken(tokensString string) (claims *Claims, err error) {
|
||||
claims, err = getClaimFromToken(tokensString)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 1.check userid and platform class 0 not exists and 1 exists
|
||||
existsInterface, err := db.DB.ExistsUserIDAndPlatform(claims.UID, Platform2class[claims.Platform])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
exists := existsInterface.(int64)
|
||||
//get config multi login policy
|
||||
if config.Config.MultiLoginPolicy.OnlyOneTerminalAccess {
|
||||
//OnlyOneTerminalAccess policy need to check all terminal
|
||||
//When only one end is allowed to log in, there is a situation that needs to be paid attention to. After PC login,
|
||||
//mobile login should check two platform times. One of them is less than the redis storage time, which is the invalid token.
|
||||
platform := "PC"
|
||||
if Platform2class[claims.Platform] == "PC" {
|
||||
platform = "Mobile"
|
||||
}
|
||||
|
||||
existsInterface, err = db.DB.ExistsUserIDAndPlatform(claims.UID, platform)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
exists := existsInterface.(int64)
|
||||
//get config multi login policy
|
||||
if config.Config.MultiLoginPolicy.OnlyOneTerminalAccess {
|
||||
//OnlyOneTerminalAccess policy need to check all terminal
|
||||
//When only one end is allowed to log in, there is a situation that needs to be paid attention to. After PC login,
|
||||
//mobile login should check two platform times. One of them is less than the redis storage time, which is the invalid token.
|
||||
if Platform2class[claims.Platform] == "PC" {
|
||||
existsInterface, err = db.DB.ExistsUserIDAndPlatform(claims.UID, "Mobile")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
exists = existsInterface.(int64)
|
||||
if exists == 1 {
|
||||
res, err := MakeTheTokenInvalid(*claims, "Mobile")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if res {
|
||||
return nil, TokenInvalid
|
||||
}
|
||||
}
|
||||
} else {
|
||||
existsInterface, err = db.DB.ExistsUserIDAndPlatform(claims.UID, "PC")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
exists = existsInterface.(int64)
|
||||
if exists == 1 {
|
||||
res, err := MakeTheTokenInvalid(*claims, "PC")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if res {
|
||||
return nil, TokenInvalid
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if exists == 1 {
|
||||
res, err := MakeTheTokenInvalid(*claims, Platform2class[claims.Platform])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if res {
|
||||
return nil, TokenInvalid
|
||||
}
|
||||
exists = existsInterface.(int64)
|
||||
if exists == 1 {
|
||||
res, err := MakeTheTokenInvalid(*claims, platform)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
} else if config.Config.MultiLoginPolicy.MobileAndPCTerminalAccessButOtherTerminalKickEachOther {
|
||||
if exists == 1 {
|
||||
res, err := MakeTheTokenInvalid(*claims, Platform2class[claims.Platform])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if res {
|
||||
return nil, TokenInvalid
|
||||
}
|
||||
if res {
|
||||
return nil, TokenInvalid
|
||||
}
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
return nil, TokenUnknown
|
||||
// config.Config.MultiLoginPolicy.MobileAndPCTerminalAccessButOtherTerminalKickEachOther == true
|
||||
// or PC/Mobile validate success
|
||||
// final check
|
||||
if exists == 1 {
|
||||
res, err := MakeTheTokenInvalid(*claims, Platform2class[claims.Platform])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if res {
|
||||
return nil, TokenInvalid
|
||||
}
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func MakeTheTokenInvalid(currentClaims Claims, platformClass string) (bool, error) {
|
||||
@ -159,35 +145,16 @@ func MakeTheTokenInvalid(currentClaims Claims, platformClass string) (bool, erro
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func ParseRedisInterfaceToken(redisToken interface{}) (*Claims, error) {
|
||||
token, err := jwt.ParseWithClaims(string(redisToken.([]uint8)), &Claims{}, secret())
|
||||
if err != nil {
|
||||
if ve, ok := err.(*jwt.ValidationError); ok {
|
||||
if ve.Errors&jwt.ValidationErrorMalformed != 0 {
|
||||
return nil, TokenMalformed
|
||||
} else if ve.Errors&jwt.ValidationErrorExpired != 0 {
|
||||
return nil, TokenExpired
|
||||
} else if ve.Errors&jwt.ValidationErrorNotValidYet != 0 {
|
||||
return nil, TokenNotValidYet
|
||||
} else {
|
||||
return nil, TokenInvalid
|
||||
}
|
||||
}
|
||||
}
|
||||
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
|
||||
return claims, nil
|
||||
}
|
||||
return nil, err
|
||||
return getClaimFromToken(string(redisToken.([]uint8)))
|
||||
}
|
||||
|
||||
//Validation token, false means failure, true means successful verification
|
||||
func VerifyToken(token, uid string) bool {
|
||||
claims, err := ParseToken(token)
|
||||
if err != nil {
|
||||
if err != nil || claims.UID != uid {
|
||||
return false
|
||||
} else if claims.UID != uid {
|
||||
return false
|
||||
} else {
|
||||
return true
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
86
src/utils/jwt_token_test.go
Normal file
86
src/utils/jwt_token_test.go
Normal file
@ -0,0 +1,86 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"Open_IM/src/common/config"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_BuildClaims(t *testing.T) {
|
||||
uid := "1"
|
||||
accountAddr := "accountAddr"
|
||||
platform := "PC"
|
||||
ttl := int64(-1)
|
||||
claim := BuildClaims(uid, accountAddr, platform, ttl)
|
||||
now := time.Now().Unix()
|
||||
|
||||
assert.Equal(t, claim.UID, uid, "uid should equal")
|
||||
assert.Equal(t, claim.Platform, platform, "platform should equal")
|
||||
assert.Equal(t, claim.StandardClaims.ExpiresAt, int64(-1), "StandardClaims.ExpiresAt should be equal")
|
||||
// time difference within 1s
|
||||
assert.Equal(t, claim.StandardClaims.IssuedAt, now, "StandardClaims.IssuedAt should be equal")
|
||||
assert.Equal(t, claim.StandardClaims.NotBefore, now, "StandardClaims.NotBefore should be equal")
|
||||
|
||||
ttl = int64(60)
|
||||
now = time.Now().Unix()
|
||||
claim = BuildClaims(uid, accountAddr, platform, ttl)
|
||||
// time difference within 1s
|
||||
assert.Equal(t, claim.StandardClaims.ExpiresAt, int64(60)+now, "StandardClaims.ExpiresAt should be equal")
|
||||
assert.Equal(t, claim.StandardClaims.IssuedAt, now, "StandardClaims.IssuedAt should be equal")
|
||||
assert.Equal(t, claim.StandardClaims.NotBefore, now, "StandardClaims.NotBefore should be equal")
|
||||
}
|
||||
|
||||
func Test_CreateToken(t *testing.T) {
|
||||
uid := "1"
|
||||
accountAddr := "accountAddr"
|
||||
platform := int32(1)
|
||||
now := time.Now().Unix()
|
||||
|
||||
tokenString, expiresAt, err := CreateToken(uid, accountAddr, platform)
|
||||
|
||||
assert.NotEmpty(t, tokenString)
|
||||
assert.Equal(t, expiresAt, 604800+now)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func Test_VerifyToken(t *testing.T) {
|
||||
uid := "1"
|
||||
accountAddr := "accountAddr"
|
||||
platform := int32(1)
|
||||
tokenString, _, _ := CreateToken(uid, accountAddr, platform)
|
||||
result := VerifyToken(tokenString, uid)
|
||||
assert.True(t, result)
|
||||
result = VerifyToken(tokenString, "2")
|
||||
assert.False(t, result)
|
||||
}
|
||||
|
||||
func Test_ParseRedisInterfaceToken(t *testing.T) {
|
||||
uid := "1"
|
||||
accountAddr := "accountAddr"
|
||||
platform := int32(1)
|
||||
tokenString, _, _ := CreateToken(uid, accountAddr, platform)
|
||||
|
||||
claims, err := ParseRedisInterfaceToken([]uint8(tokenString))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, claims.UID, uid)
|
||||
|
||||
// timeout
|
||||
config.Config.TokenPolicy.AccessExpire = -80
|
||||
tokenString, _, _ = CreateToken(uid, accountAddr, platform)
|
||||
claims, err = ParseRedisInterfaceToken([]uint8(tokenString))
|
||||
assert.Equal(t, err, TokenExpired)
|
||||
assert.Nil(t, claims)
|
||||
}
|
||||
|
||||
func Test_ParseToken(t *testing.T) {
|
||||
uid := "1"
|
||||
accountAddr := "accountAddr"
|
||||
platform := int32(1)
|
||||
tokenString, _, _ := CreateToken(uid, accountAddr, platform)
|
||||
claims, err := ParseToken(tokenString)
|
||||
if err == nil {
|
||||
assert.Equal(t, claims.UID, uid)
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user