Merge branch 'errcode' of github.com:OpenIMSDK/Open-IM-Server into errcode

This commit is contained in:
wangchuxiao 2023-06-14 12:20:52 +08:00
commit bef052ed8c
12 changed files with 216 additions and 88 deletions

View File

@ -224,8 +224,11 @@ func (c *Client) PushMessage(ctx context.Context, msgData *sdkws.MsgData) error
return c.writeBinaryMsg(resp) return c.writeBinaryMsg(resp)
} }
func (c *Client) KickOnlineMessage(ctx context.Context) error { func (c *Client) KickOnlineMessage() error {
return nil resp := Resp{
ReqIdentifier: WSKickOnlineMsg,
}
return c.writeBinaryMsg(resp)
} }
func (c *Client) writeBinaryMsg(resp Resp) error { func (c *Client) writeBinaryMsg(resp Resp) error {

View File

@ -91,6 +91,9 @@ func (c *UserConnContext) GetPlatformID() string {
func (c *UserConnContext) GetOperationID() string { func (c *UserConnContext) GetOperationID() string {
return c.Req.URL.Query().Get(OperationID) return c.Req.URL.Query().Get(OperationID)
} }
func (c *UserConnContext) GetToken() string {
return c.Req.URL.Query().Get(Token)
}
func (c *UserConnContext) GetBackground() bool { func (c *UserConnContext) GetBackground() bool {
b, err := strconv.ParseBool(c.Req.URL.Query().Get(BackgroundStatus)) b, err := strconv.ParseBool(c.Req.URL.Query().Get(BackgroundStatus))
if err != nil { if err != nil {

View File

@ -2,6 +2,7 @@ package msggateway
import ( import (
"context" "context"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/cache"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/config" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/config"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant"
@ -17,7 +18,13 @@ import (
) )
func (s *Server) InitServer(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { func (s *Server) InitServer(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error {
rdb, err := cache.NewRedis()
if err != nil {
return err
}
msgModel := cache.NewMsgCacheModel(rdb)
s.LongConnServer.SetDiscoveryRegistry(client) s.LongConnServer.SetDiscoveryRegistry(client)
s.LongConnServer.SetCacheHandler(msgModel)
msggateway.RegisterMsgGatewayServer(server, s) msggateway.RegisterMsgGatewayServer(server, s)
return nil return nil
} }
@ -131,7 +138,7 @@ func (s *Server) KickUserOffline(ctx context.Context, req *msggateway.KickUserOf
for _, v := range req.KickUserIDList { for _, v := range req.KickUserIDList {
if clients, _, ok := s.LongConnServer.GetUserPlatformCons(v, int(req.PlatformID)); ok { if clients, _, ok := s.LongConnServer.GetUserPlatformCons(v, int(req.PlatformID)); ok {
for _, client := range clients { for _, client := range clients {
err := client.KickOnlineMessage(ctx) err := client.KickOnlineMessage()
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,13 +1,19 @@
package msggateway package msggateway
import ( import (
"context"
"errors" "errors"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/config"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/cache"
"net/http" "net/http"
"strconv"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/OpenIMSDK/Open-IM-Server/pkg/discoveryregistry" "github.com/OpenIMSDK/Open-IM-Server/pkg/discoveryregistry"
redis "github.com/go-redis/redis/v8"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/log" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/log"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/tokenverify" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/tokenverify"
@ -22,7 +28,7 @@ type LongConnServer interface {
GetUserAllCons(userID string) ([]*Client, bool) GetUserAllCons(userID string) ([]*Client, bool)
GetUserPlatformCons(userID string, platform int) ([]*Client, bool, bool) GetUserPlatformCons(userID string, platform int) ([]*Client, bool, bool)
Validate(s interface{}) error Validate(s interface{}) error
//SetMessageHandler(msgRpcClient *rpcclient.MsgClient) SetCacheHandler(cache cache.MsgModel)
SetDiscoveryRegistry(client discoveryregistry.SvcDiscoveryRegistry) SetDiscoveryRegistry(client discoveryregistry.SvcDiscoveryRegistry)
UnRegister(c *Client) UnRegister(c *Client)
Compressor Compressor
@ -41,6 +47,7 @@ type WsServer struct {
wsMaxConnNum int64 wsMaxConnNum int64
registerChan chan *Client registerChan chan *Client
unregisterChan chan *Client unregisterChan chan *Client
kickHandlerChan chan *kickHandler
clients *UserMap clients *UserMap
clientPool sync.Pool clientPool sync.Pool
onlineUserNum int64 onlineUserNum int64
@ -48,14 +55,23 @@ type WsServer struct {
handshakeTimeout time.Duration handshakeTimeout time.Duration
hubServer *Server hubServer *Server
validate *validator.Validate validate *validator.Validate
cache cache.MsgModel
Compressor Compressor
Encoder Encoder
MessageHandler MessageHandler
} }
type kickHandler struct {
clientOK bool
oldClients []*Client
newClient *Client
}
func (ws *WsServer) SetDiscoveryRegistry(client discoveryregistry.SvcDiscoveryRegistry) { func (ws *WsServer) SetDiscoveryRegistry(client discoveryregistry.SvcDiscoveryRegistry) {
ws.MessageHandler = NewGrpcHandler(ws.validate, client) ws.MessageHandler = NewGrpcHandler(ws.validate, client)
} }
func (ws *WsServer) SetCacheHandler(cache cache.MsgModel) {
ws.cache = cache
}
func (ws *WsServer) UnRegister(c *Client) { func (ws *WsServer) UnRegister(c *Client) {
ws.unregisterChan <- c ws.unregisterChan <- c
@ -92,12 +108,13 @@ func NewWsServer(opts ...Option) (*WsServer, error) {
return new(Client) return new(Client)
}, },
}, },
registerChan: make(chan *Client, 1000), registerChan: make(chan *Client, 1000),
unregisterChan: make(chan *Client, 1000), unregisterChan: make(chan *Client, 1000),
validate: v, kickHandlerChan: make(chan *kickHandler, 1000),
clients: newUserMap(), validate: v,
Compressor: NewGzipCompressor(), clients: newUserMap(),
Encoder: NewGobEncoder(), Compressor: NewGzipCompressor(),
Encoder: NewGobEncoder(),
}, nil }, nil
} }
func (ws *WsServer) Run() error { func (ws *WsServer) Run() error {
@ -109,6 +126,8 @@ func (ws *WsServer) Run() error {
ws.registerClient(client) ws.registerClient(client)
case client = <-ws.unregisterChan: case client = <-ws.unregisterChan:
ws.unregisterClient(client) ws.unregisterClient(client)
case onlineInfo := <-ws.kickHandlerChan:
ws.multiTerminalLoginChecker(onlineInfo)
} }
} }
}() }()
@ -119,26 +138,29 @@ func (ws *WsServer) Run() error {
func (ws *WsServer) registerClient(client *Client) { func (ws *WsServer) registerClient(client *Client) {
var ( var (
userOK bool userOK bool
clientOK bool clientOK bool
cli []*Client oldClients []*Client
) )
cli, userOK, clientOK = ws.clients.Get(client.UserID, client.PlatformID) ws.clients.Set(client.UserID, client)
oldClients, userOK, clientOK = ws.clients.Get(client.UserID, client.PlatformID)
if !userOK { if !userOK {
log.ZDebug(client.ctx, "user not exist", "userID", client.UserID, "platformID", client.PlatformID) log.ZDebug(client.ctx, "user not exist", "userID", client.UserID, "platformID", client.PlatformID)
ws.clients.Set(client.UserID, client)
atomic.AddInt64(&ws.onlineUserNum, 1) atomic.AddInt64(&ws.onlineUserNum, 1)
atomic.AddInt64(&ws.onlineUserConnNum, 1) atomic.AddInt64(&ws.onlineUserConnNum, 1)
} else { } else {
i := &kickHandler{
clientOK: clientOK,
oldClients: oldClients,
newClient: client,
}
ws.kickHandlerChan <- i
log.ZDebug(client.ctx, "user exist", "userID", client.UserID, "platformID", client.PlatformID) log.ZDebug(client.ctx, "user exist", "userID", client.UserID, "platformID", client.PlatformID)
if clientOK { //已经有同平台的连接存在 if clientOK { //已经有同平台的连接存在
ws.clients.Set(client.UserID, client) log.ZInfo(client.ctx, "repeat login", "userID", client.UserID, "platformID", client.PlatformID, "old remote addr", getRemoteAdders(oldClients))
ws.multiTerminalLoginChecker(cli)
log.ZInfo(client.ctx, "repeat login", "userID", client.UserID, "platformID", client.PlatformID, "old remote addr", getRemoteAdders(cli))
atomic.AddInt64(&ws.onlineUserConnNum, 1) atomic.AddInt64(&ws.onlineUserConnNum, 1)
} else { } else {
ws.clients.Set(client.UserID, client)
atomic.AddInt64(&ws.onlineUserConnNum, 1) atomic.AddInt64(&ws.onlineUserConnNum, 1)
} }
} }
@ -156,7 +178,47 @@ func getRemoteAdders(client []*Client) string {
return ret return ret
} }
func (ws *WsServer) multiTerminalLoginChecker(client []*Client) { func (ws *WsServer) multiTerminalLoginChecker(info *kickHandler) {
switch config.Config.MultiLoginPolicy {
case constant.DefalutNotKick:
case constant.PCAndOther:
if constant.PlatformIDToClass(info.newClient.PlatformID) == constant.TerminalPC {
return
}
fallthrough
case constant.AllLoginButSameTermKick:
if info.clientOK {
ws.clients.deleteClients(info.newClient.UserID, info.oldClients)
for _, c := range info.oldClients {
err := c.KickOnlineMessage()
if err != nil {
log.ZWarn(c.ctx, "KickOnlineMessage", err)
}
}
m, err := ws.cache.GetTokensWithoutError(info.newClient.ctx, info.newClient.UserID, info.newClient.PlatformID)
if err != nil && err != redis.Nil {
log.ZWarn(info.newClient.ctx, "get token from redis err", err, "userID", info.newClient.UserID, "platformID", info.newClient.PlatformID)
return
}
if m == nil {
log.ZWarn(info.newClient.ctx, "m is nil", errors.New("m is nil"), "userID", info.newClient.UserID, "platformID", info.newClient.PlatformID)
return
}
log.ZDebug(info.newClient.ctx, "get token from redis", "userID", info.newClient.UserID, "platformID", info.newClient.PlatformID, "tokenMap", m)
for k, _ := range m {
if k != info.newClient.ctx.GetToken() {
m[k] = constant.KickedToken
}
}
log.ZDebug(info.newClient.ctx, "set token map is ", "token map", m, "userID", info.newClient.UserID)
err = ws.cache.SetTokenMapByUidPid(info.newClient.ctx, info.newClient.UserID, info.newClient.PlatformID, m)
if err != nil {
log.ZWarn(info.newClient.ctx, "SetTokenMapByUidPid err", err, "userID", info.newClient.UserID, "platformID", info.newClient.PlatformID)
return
}
}
}
} }
func (ws *WsServer) unregisterClient(client *Client) { func (ws *WsServer) unregisterClient(client *Client) {
@ -170,60 +232,83 @@ func (ws *WsServer) unregisterClient(client *Client) {
} }
func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
context := newContext(w, r) defer log.ZInfo(context.Background(), "wsHandler", "remote addr", "url", r.URL.String())
connContext := newContext(w, r)
if ws.onlineUserConnNum >= ws.wsMaxConnNum { if ws.onlineUserConnNum >= ws.wsMaxConnNum {
httpError(context, errs.ErrConnOverMaxNumLimit) httpError(connContext, errs.ErrConnOverMaxNumLimit)
return return
} }
var ( var (
token string token string
userID string userID string
platformID string platformIDStr string
exists bool exists bool
compression bool compression bool
) )
token, exists = context.Query(Token) token, exists = connContext.Query(Token)
if !exists { if !exists {
httpError(context, errs.ErrConnArgsErr) httpError(connContext, errs.ErrConnArgsErr)
return return
} }
userID, exists = context.Query(WsUserID) userID, exists = connContext.Query(WsUserID)
if !exists { if !exists {
httpError(context, errs.ErrConnArgsErr) httpError(connContext, errs.ErrConnArgsErr)
return return
} }
platformID, exists = context.Query(PlatformID) platformIDStr, exists = connContext.Query(PlatformID)
if !exists || utils.StringToInt(platformID) == 0 { if !exists {
httpError(context, errs.ErrConnArgsErr) httpError(connContext, errs.ErrConnArgsErr)
return return
} }
// log.ZDebug(context2.Background(), "conn", "platformID", platformID) platformID, err := strconv.Atoi(platformIDStr)
err := tokenverify.WsVerifyToken(token, userID, platformID)
if err != nil { if err != nil {
httpError(context, err) httpError(connContext, errs.ErrConnArgsErr)
return
}
if err := tokenverify.WsVerifyToken(token, userID, platformID); err != nil {
httpError(connContext, err)
return
}
m, err := ws.cache.GetTokensWithoutError(context.Background(), userID, platformID)
if err != nil {
httpError(connContext, err)
return
}
if v, ok := m[token]; ok {
switch v {
case constant.NormalToken:
case constant.KickedToken:
httpError(connContext, errs.ErrTokenKicked.Wrap())
return
default:
httpError(connContext, errs.ErrTokenUnknown.Wrap())
return
}
} else {
httpError(connContext, errs.ErrTokenNotExist.Wrap())
return return
} }
wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout) wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout)
err = wsLongConn.GenerateLongConn(w, r) err = wsLongConn.GenerateLongConn(w, r)
if err != nil { if err != nil {
httpError(context, err) httpError(connContext, err)
return return
} }
compressProtoc, exists := context.Query(Compression) compressProtoc, exists := connContext.Query(Compression)
if exists { if exists {
if compressProtoc == GzipCompressionProtocol { if compressProtoc == GzipCompressionProtocol {
compression = true compression = true
} }
} }
compressProtoc, exists = context.GetHeader(Compression) compressProtoc, exists = connContext.GetHeader(Compression)
if exists { if exists {
if compressProtoc == GzipCompressionProtocol { if compressProtoc == GzipCompressionProtocol {
compression = true compression = true
} }
} }
client := ws.clientPool.Get().(*Client) client := ws.clientPool.Get().(*Client)
client.ResetClient(context, wsLongConn, context.GetBackground(), compression, ws) client.ResetClient(connContext, wsLongConn, connContext.GetBackground(), compression, ws)
ws.registerChan <- client ws.registerChan <- client
go client.readMessage() go client.readMessage()
} }

View File

@ -3,6 +3,7 @@ package msggateway
import ( import (
"context" "context"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/log" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/log"
"github.com/OpenIMSDK/Open-IM-Server/pkg/utils"
"sync" "sync"
) )
@ -71,6 +72,29 @@ func (u *UserMap) delete(key string, connRemoteAddr string) (isDeleteUser bool)
} }
return existed return existed
} }
func (u *UserMap) deleteClients(key string, clients []*Client) (isDeleteUser bool) {
m := utils.SliceToMapAny(clients, func(c *Client) (string, struct{}) {
return c.ctx.GetRemoteAddr(), struct{}{}
})
allClients, existed := u.m.Load(key)
if existed {
oldClients := allClients.([]*Client)
var a []*Client
for _, client := range oldClients {
if _, ok := m[client.ctx.GetRemoteAddr()]; !ok {
a = append(a, client)
}
}
if len(a) == 0 {
u.m.Delete(key)
return true
} else {
u.m.Store(key, a)
return false
}
}
return existed
}
func (u *UserMap) DeleteAll(key string) { func (u *UserMap) DeleteAll(key string) {
u.m.Delete(key) u.m.Delete(key)
} }

View File

@ -42,7 +42,7 @@ func (s *authServer) UserToken(ctx context.Context, req *pbAuth.UserTokenReq) (*
if _, err := s.userRpcClient.GetUserInfo(ctx, req.UserID); err != nil { if _, err := s.userRpcClient.GetUserInfo(ctx, req.UserID); err != nil {
return nil, err return nil, err
} }
token, err := s.authDatabase.CreateToken(ctx, req.UserID, constant.PlatformIDToName(int(req.PlatformID))) token, err := s.authDatabase.CreateToken(ctx, req.UserID, int(req.PlatformID))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -56,7 +56,7 @@ func (s *authServer) parseToken(ctx context.Context, tokensString string) (claim
if err != nil { if err != nil {
return nil, utils.Wrap(err, "") return nil, utils.Wrap(err, "")
} }
m, err := s.authDatabase.GetTokensWithoutError(ctx, claims.UID, claims.Platform) m, err := s.authDatabase.GetTokensWithoutError(ctx, claims.UserID, claims.PlatformID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -82,8 +82,8 @@ func (s *authServer) ParseToken(ctx context.Context, req *pbAuth.ParseTokenReq)
if err != nil { if err != nil {
return nil, err return nil, err
} }
resp.UserID = claims.UID resp.UserID = claims.UserID
resp.Platform = claims.Platform resp.Platform = constant.PlatformIDToName(claims.PlatformID)
resp.ExpireTimeSeconds = claims.ExpiresAt.Unix() resp.ExpireTimeSeconds = claims.ExpiresAt.Unix()
return resp, nil return resp, nil
} }

View File

@ -118,6 +118,7 @@ const (
ExpiredToken = 3 ExpiredToken = 3
//MultiTerminalLogin //MultiTerminalLogin
DefalutNotKick = 0
//Full-end login, but the same end is mutually exclusive //Full-end login, but the same end is mutually exclusive
AllLoginButSameTermKick = 1 AllLoginButSameTermKick = 1
//Only one of the endpoints can log in //Only one of the endpoints can log in

View File

@ -57,7 +57,7 @@ var PlatformName2ID = map[string]int{
IPadPlatformStr: IPadPlatformID, IPadPlatformStr: IPadPlatformID,
AdminPlatformStr: AdminPlatformID, AdminPlatformStr: AdminPlatformID,
} }
var Platform2class = map[string]string{ var PlatformName2class = map[string]string{
IOSPlatformStr: TerminalMobile, IOSPlatformStr: TerminalMobile,
AndroidPlatformStr: TerminalMobile, AndroidPlatformStr: TerminalMobile,
MiniWebPlatformStr: WebPlatformStr, MiniWebPlatformStr: WebPlatformStr,
@ -66,6 +66,15 @@ var Platform2class = map[string]string{
OSXPlatformStr: TerminalPC, OSXPlatformStr: TerminalPC,
LinuxPlatformStr: TerminalPC, LinuxPlatformStr: TerminalPC,
} }
var PlatformID2class = map[int]string{
IOSPlatformID: TerminalMobile,
AndroidPlatformID: TerminalMobile,
MiniWebPlatformID: WebPlatformStr,
WebPlatformID: WebPlatformStr,
WindowsPlatformID: TerminalPC,
OSXPlatformID: TerminalPC,
LinuxPlatformID: TerminalPC,
}
func PlatformIDToName(num int) string { func PlatformIDToName(num int) string {
return PlatformID2Name[num] return PlatformID2Name[num]
@ -74,5 +83,8 @@ func PlatformNameToID(name string) int {
return PlatformName2ID[name] return PlatformName2ID[name]
} }
func PlatformNameToClass(name string) string { func PlatformNameToClass(name string) string {
return Platform2class[name] return PlatformName2class[name]
}
func PlatformIDToClass(num int) string {
return PlatformID2class[num]
} }

View File

@ -88,9 +88,9 @@ type MsgModel interface {
SeqCache SeqCache
thirdCache thirdCache
AddTokenFlag(ctx context.Context, userID string, platformID int, token string, flag int) error AddTokenFlag(ctx context.Context, userID string, platformID int, token string, flag int) error
GetTokensWithoutError(ctx context.Context, userID, platformID string) (map[string]int, error) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error)
SetTokenMapByUidPid(ctx context.Context, userID string, platform string, m map[string]int) error SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error
DeleteTokenByUidPid(ctx context.Context, userID string, platform string, fields []string) error DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error
GetMessagesBySeq(ctx context.Context, conversationID string, seqs []int64) (seqMsg []*sdkws.MsgData, failedSeqList []int64, err error) GetMessagesBySeq(ctx context.Context, conversationID string, seqs []int64) (seqMsg []*sdkws.MsgData, failedSeqList []int64, err error)
SetMessageToCache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (int, error) SetMessageToCache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (int, error)
UserDeleteMsgs(ctx context.Context, conversationID string, seqs []int64, userID string) error UserDeleteMsgs(ctx context.Context, conversationID string, seqs []int64, userID string) error
@ -260,8 +260,8 @@ func (c *msgCache) AddTokenFlag(ctx context.Context, userID string, platformID i
return errs.Wrap(c.rdb.HSet(ctx, key, token, flag).Err()) return errs.Wrap(c.rdb.HSet(ctx, key, token, flag).Err())
} }
func (c *msgCache) GetTokensWithoutError(ctx context.Context, userID, platformID string) (map[string]int, error) { func (c *msgCache) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) {
key := uidPidToken + userID + ":" + platformID key := uidPidToken + userID + ":" + constant.PlatformIDToName(platformID)
m, err := c.rdb.HGetAll(ctx, key).Result() m, err := c.rdb.HGetAll(ctx, key).Result()
if err != nil { if err != nil {
return nil, errs.Wrap(err) return nil, errs.Wrap(err)
@ -273,8 +273,8 @@ func (c *msgCache) GetTokensWithoutError(ctx context.Context, userID, platformID
return mm, nil return mm, nil
} }
func (c *msgCache) SetTokenMapByUidPid(ctx context.Context, userID string, platform string, m map[string]int) error { func (c *msgCache) SetTokenMapByUidPid(ctx context.Context, userID string, platform int, m map[string]int) error {
key := uidPidToken + userID + ":" + platform key := uidPidToken + userID + ":" + constant.PlatformIDToName(platform)
mm := make(map[string]interface{}) mm := make(map[string]interface{})
for k, v := range m { for k, v := range m {
mm[k] = v mm[k] = v
@ -282,8 +282,8 @@ func (c *msgCache) SetTokenMapByUidPid(ctx context.Context, userID string, platf
return errs.Wrap(c.rdb.HSet(ctx, key, mm).Err()) return errs.Wrap(c.rdb.HSet(ctx, key, mm).Err())
} }
func (c *msgCache) DeleteTokenByUidPid(ctx context.Context, userID string, platform string, fields []string) error { func (c *msgCache) DeleteTokenByUidPid(ctx context.Context, userID string, platform int, fields []string) error {
key := uidPidToken + userID + ":" + platform key := uidPidToken + userID + ":" + constant.PlatformIDToName(platform)
return errs.Wrap(c.rdb.HDel(ctx, key, fields...).Err()) return errs.Wrap(c.rdb.HDel(ctx, key, fields...).Err())
} }

View File

@ -12,9 +12,9 @@ import (
type AuthDatabase interface { type AuthDatabase interface {
//结果为空 不返回错误 //结果为空 不返回错误
GetTokensWithoutError(ctx context.Context, userID, platform string) (map[string]int, error) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error)
//创建token //创建token
CreateToken(ctx context.Context, userID string, platform string) (string, error) CreateToken(ctx context.Context, userID string, platformID int) (string, error)
} }
type authDatabase struct { type authDatabase struct {
@ -29,13 +29,13 @@ func NewAuthDatabase(cache cache.MsgModel, accessSecret string, accessExpire int
} }
// 结果为空 不返回错误 // 结果为空 不返回错误
func (a *authDatabase) GetTokensWithoutError(ctx context.Context, userID, platform string) (map[string]int, error) { func (a *authDatabase) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) {
return a.cache.GetTokensWithoutError(ctx, userID, platform) return a.cache.GetTokensWithoutError(ctx, userID, platformID)
} }
// 创建token // 创建token
func (a *authDatabase) CreateToken(ctx context.Context, userID string, platform string) (string, error) { func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformID int) (string, error) {
tokens, err := a.cache.GetTokensWithoutError(ctx, userID, platform) tokens, err := a.cache.GetTokensWithoutError(ctx, userID, platformID)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -47,16 +47,16 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platform
} }
} }
if len(deleteTokenKey) != 0 { if len(deleteTokenKey) != 0 {
err := a.cache.DeleteTokenByUidPid(ctx, userID, platform, deleteTokenKey) err := a.cache.DeleteTokenByUidPid(ctx, userID, platformID, deleteTokenKey)
if err != nil { if err != nil {
return "", err return "", err
} }
} }
claims := tokenverify.BuildClaims(userID, platform, a.accessExpire) claims := tokenverify.BuildClaims(userID, platformID, a.accessExpire)
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString([]byte(a.accessSecret)) tokenString, err := token.SignedString([]byte(a.accessSecret))
if err != nil { if err != nil {
return "", utils.Wrap(err, "") return "", utils.Wrap(err, "")
} }
return tokenString, a.cache.AddTokenFlag(ctx, userID, constant.PlatformNameToID(platform), tokenString, constant.NormalToken) return tokenString, a.cache.AddTokenFlag(ctx, userID, platformID, tokenString, constant.NormalToken)
} }

View File

@ -128,7 +128,7 @@ func GinParseToken(rdb redis.UniversalClient) gin.HandlerFunc {
c.Abort() c.Abort()
return return
} }
m, err := dataBase.GetTokensWithoutError(c, claims.UID, claims.Platform) m, err := dataBase.GetTokensWithoutError(c, claims.UserID, claims.PlatformID)
if err != nil { if err != nil {
log.ZWarn(c, "cache get token error", errs.ErrTokenNotExist.Wrap()) log.ZWarn(c, "cache get token error", errs.ErrTokenNotExist.Wrap())
apiresp.GinError(c, errs.ErrTokenNotExist.Wrap()) apiresp.GinError(c, errs.ErrTokenNotExist.Wrap())
@ -155,9 +155,12 @@ func GinParseToken(rdb redis.UniversalClient) gin.HandlerFunc {
c.Abort() c.Abort()
return return
} }
} else {
apiresp.GinError(c, errs.ErrTokenNotExist.Wrap())
return
} }
c.Set(constant.OpUserPlatform, claims.Platform) c.Set(constant.OpUserPlatform, constant.PlatformIDToName(claims.PlatformID))
c.Set(constant.OpUserID, claims.UID) c.Set(constant.OpUserID, claims.UserID)
c.Next() c.Next()
} }
} }

View File

@ -4,27 +4,25 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/config" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/config"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/mcontext" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/mcontext"
"github.com/OpenIMSDK/Open-IM-Server/pkg/errs" "github.com/OpenIMSDK/Open-IM-Server/pkg/errs"
"github.com/OpenIMSDK/Open-IM-Server/pkg/utils" "github.com/OpenIMSDK/Open-IM-Server/pkg/utils"
"github.com/golang-jwt/jwt/v4" "github.com/golang-jwt/jwt/v4"
"strconv"
"time" "time"
) )
type Claims struct { type Claims struct {
UID string UserID string
Platform string //login platform PlatformID int //login platform
jwt.RegisteredClaims jwt.RegisteredClaims
} }
func BuildClaims(uid, platform string, ttl int64) Claims { func BuildClaims(uid string, platformID int, ttl int64) Claims {
now := time.Now() now := time.Now()
before := now.Add(-time.Minute * 5) before := now.Add(-time.Minute * 5)
return Claims{ return Claims{
UID: uid, UserID: uid,
Platform: platform, PlatformID: platformID,
RegisteredClaims: jwt.RegisteredClaims{ RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(time.Duration(ttl*24) * time.Hour)), //Expiration time ExpiresAt: jwt.NewNumericDate(now.Add(time.Duration(ttl*24) * time.Hour)), //Expiration time
IssuedAt: jwt.NewNumericDate(now), //Issuing time IssuedAt: jwt.NewNumericDate(now), //Issuing time
@ -90,24 +88,16 @@ func ParseRedisInterfaceToken(redisToken interface{}) (*Claims, error) {
func IsManagerUserID(opUserID string) bool { func IsManagerUserID(opUserID string) bool {
return utils.IsContain(opUserID, config.Config.Manager.AppManagerUid) return utils.IsContain(opUserID, config.Config.Manager.AppManagerUid)
} }
func WsVerifyToken(token, userID, platformID string) error { func WsVerifyToken(token, userID string, platformID int) error {
platformIDInt, err := strconv.Atoi(platformID)
if err != nil {
return errs.ErrArgs.Wrap(fmt.Sprintf("platformID %s is not int", platformID))
}
platform := constant.PlatformIDToName(platformIDInt)
if platform == "" {
return errs.ErrArgs.Wrap(fmt.Sprintf("platformID %s is not exist", platformID))
}
claim, err := GetClaimFromToken(token) claim, err := GetClaimFromToken(token)
if err != nil { if err != nil {
return err return err
} }
if claim.UID != userID { if claim.UserID != userID {
return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token uid %s != userID %s", claim.UID, userID)) return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token uid %s != userID %s", claim.UserID, userID))
} }
if claim.Platform != platform { if claim.PlatformID != platformID {
return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token platform %s != %s", claim.Platform, platform)) return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token platform %d != %d", claim.PlatformID, platformID))
} }
return nil return nil
} }