mirror of
https://github.com/openimsdk/open-im-server.git
synced 2025-04-06 04:15:46 +08:00
Merge branch 'errcode' of github.com:OpenIMSDK/Open-IM-Server into errcode
This commit is contained in:
commit
bef052ed8c
@ -224,8 +224,11 @@ func (c *Client) PushMessage(ctx context.Context, msgData *sdkws.MsgData) error
|
|||||||
return c.writeBinaryMsg(resp)
|
return c.writeBinaryMsg(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) KickOnlineMessage(ctx context.Context) error {
|
func (c *Client) KickOnlineMessage() error {
|
||||||
return nil
|
resp := Resp{
|
||||||
|
ReqIdentifier: WSKickOnlineMsg,
|
||||||
|
}
|
||||||
|
return c.writeBinaryMsg(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) writeBinaryMsg(resp Resp) error {
|
func (c *Client) writeBinaryMsg(resp Resp) error {
|
||||||
|
@ -91,6 +91,9 @@ func (c *UserConnContext) GetPlatformID() string {
|
|||||||
func (c *UserConnContext) GetOperationID() string {
|
func (c *UserConnContext) GetOperationID() string {
|
||||||
return c.Req.URL.Query().Get(OperationID)
|
return c.Req.URL.Query().Get(OperationID)
|
||||||
}
|
}
|
||||||
|
func (c *UserConnContext) GetToken() string {
|
||||||
|
return c.Req.URL.Query().Get(Token)
|
||||||
|
}
|
||||||
func (c *UserConnContext) GetBackground() bool {
|
func (c *UserConnContext) GetBackground() bool {
|
||||||
b, err := strconv.ParseBool(c.Req.URL.Query().Get(BackgroundStatus))
|
b, err := strconv.ParseBool(c.Req.URL.Query().Get(BackgroundStatus))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -2,6 +2,7 @@ package msggateway
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/cache"
|
||||||
|
|
||||||
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/config"
|
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/config"
|
||||||
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant"
|
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant"
|
||||||
@ -17,7 +18,13 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) InitServer(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error {
|
func (s *Server) InitServer(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error {
|
||||||
|
rdb, err := cache.NewRedis()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
msgModel := cache.NewMsgCacheModel(rdb)
|
||||||
s.LongConnServer.SetDiscoveryRegistry(client)
|
s.LongConnServer.SetDiscoveryRegistry(client)
|
||||||
|
s.LongConnServer.SetCacheHandler(msgModel)
|
||||||
msggateway.RegisterMsgGatewayServer(server, s)
|
msggateway.RegisterMsgGatewayServer(server, s)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -131,7 +138,7 @@ func (s *Server) KickUserOffline(ctx context.Context, req *msggateway.KickUserOf
|
|||||||
for _, v := range req.KickUserIDList {
|
for _, v := range req.KickUserIDList {
|
||||||
if clients, _, ok := s.LongConnServer.GetUserPlatformCons(v, int(req.PlatformID)); ok {
|
if clients, _, ok := s.LongConnServer.GetUserPlatformCons(v, int(req.PlatformID)); ok {
|
||||||
for _, client := range clients {
|
for _, client := range clients {
|
||||||
err := client.KickOnlineMessage(ctx)
|
err := client.KickOnlineMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -1,13 +1,19 @@
|
|||||||
package msggateway
|
package msggateway
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/config"
|
||||||
|
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant"
|
||||||
|
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/cache"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/OpenIMSDK/Open-IM-Server/pkg/discoveryregistry"
|
"github.com/OpenIMSDK/Open-IM-Server/pkg/discoveryregistry"
|
||||||
|
redis "github.com/go-redis/redis/v8"
|
||||||
|
|
||||||
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/log"
|
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/log"
|
||||||
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/tokenverify"
|
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/tokenverify"
|
||||||
@ -22,7 +28,7 @@ type LongConnServer interface {
|
|||||||
GetUserAllCons(userID string) ([]*Client, bool)
|
GetUserAllCons(userID string) ([]*Client, bool)
|
||||||
GetUserPlatformCons(userID string, platform int) ([]*Client, bool, bool)
|
GetUserPlatformCons(userID string, platform int) ([]*Client, bool, bool)
|
||||||
Validate(s interface{}) error
|
Validate(s interface{}) error
|
||||||
//SetMessageHandler(msgRpcClient *rpcclient.MsgClient)
|
SetCacheHandler(cache cache.MsgModel)
|
||||||
SetDiscoveryRegistry(client discoveryregistry.SvcDiscoveryRegistry)
|
SetDiscoveryRegistry(client discoveryregistry.SvcDiscoveryRegistry)
|
||||||
UnRegister(c *Client)
|
UnRegister(c *Client)
|
||||||
Compressor
|
Compressor
|
||||||
@ -41,6 +47,7 @@ type WsServer struct {
|
|||||||
wsMaxConnNum int64
|
wsMaxConnNum int64
|
||||||
registerChan chan *Client
|
registerChan chan *Client
|
||||||
unregisterChan chan *Client
|
unregisterChan chan *Client
|
||||||
|
kickHandlerChan chan *kickHandler
|
||||||
clients *UserMap
|
clients *UserMap
|
||||||
clientPool sync.Pool
|
clientPool sync.Pool
|
||||||
onlineUserNum int64
|
onlineUserNum int64
|
||||||
@ -48,14 +55,23 @@ type WsServer struct {
|
|||||||
handshakeTimeout time.Duration
|
handshakeTimeout time.Duration
|
||||||
hubServer *Server
|
hubServer *Server
|
||||||
validate *validator.Validate
|
validate *validator.Validate
|
||||||
|
cache cache.MsgModel
|
||||||
Compressor
|
Compressor
|
||||||
Encoder
|
Encoder
|
||||||
MessageHandler
|
MessageHandler
|
||||||
}
|
}
|
||||||
|
type kickHandler struct {
|
||||||
|
clientOK bool
|
||||||
|
oldClients []*Client
|
||||||
|
newClient *Client
|
||||||
|
}
|
||||||
|
|
||||||
func (ws *WsServer) SetDiscoveryRegistry(client discoveryregistry.SvcDiscoveryRegistry) {
|
func (ws *WsServer) SetDiscoveryRegistry(client discoveryregistry.SvcDiscoveryRegistry) {
|
||||||
ws.MessageHandler = NewGrpcHandler(ws.validate, client)
|
ws.MessageHandler = NewGrpcHandler(ws.validate, client)
|
||||||
}
|
}
|
||||||
|
func (ws *WsServer) SetCacheHandler(cache cache.MsgModel) {
|
||||||
|
ws.cache = cache
|
||||||
|
}
|
||||||
|
|
||||||
func (ws *WsServer) UnRegister(c *Client) {
|
func (ws *WsServer) UnRegister(c *Client) {
|
||||||
ws.unregisterChan <- c
|
ws.unregisterChan <- c
|
||||||
@ -92,12 +108,13 @@ func NewWsServer(opts ...Option) (*WsServer, error) {
|
|||||||
return new(Client)
|
return new(Client)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
registerChan: make(chan *Client, 1000),
|
registerChan: make(chan *Client, 1000),
|
||||||
unregisterChan: make(chan *Client, 1000),
|
unregisterChan: make(chan *Client, 1000),
|
||||||
validate: v,
|
kickHandlerChan: make(chan *kickHandler, 1000),
|
||||||
clients: newUserMap(),
|
validate: v,
|
||||||
Compressor: NewGzipCompressor(),
|
clients: newUserMap(),
|
||||||
Encoder: NewGobEncoder(),
|
Compressor: NewGzipCompressor(),
|
||||||
|
Encoder: NewGobEncoder(),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
func (ws *WsServer) Run() error {
|
func (ws *WsServer) Run() error {
|
||||||
@ -109,6 +126,8 @@ func (ws *WsServer) Run() error {
|
|||||||
ws.registerClient(client)
|
ws.registerClient(client)
|
||||||
case client = <-ws.unregisterChan:
|
case client = <-ws.unregisterChan:
|
||||||
ws.unregisterClient(client)
|
ws.unregisterClient(client)
|
||||||
|
case onlineInfo := <-ws.kickHandlerChan:
|
||||||
|
ws.multiTerminalLoginChecker(onlineInfo)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@ -119,26 +138,29 @@ func (ws *WsServer) Run() error {
|
|||||||
|
|
||||||
func (ws *WsServer) registerClient(client *Client) {
|
func (ws *WsServer) registerClient(client *Client) {
|
||||||
var (
|
var (
|
||||||
userOK bool
|
userOK bool
|
||||||
clientOK bool
|
clientOK bool
|
||||||
cli []*Client
|
oldClients []*Client
|
||||||
)
|
)
|
||||||
cli, userOK, clientOK = ws.clients.Get(client.UserID, client.PlatformID)
|
ws.clients.Set(client.UserID, client)
|
||||||
|
oldClients, userOK, clientOK = ws.clients.Get(client.UserID, client.PlatformID)
|
||||||
if !userOK {
|
if !userOK {
|
||||||
log.ZDebug(client.ctx, "user not exist", "userID", client.UserID, "platformID", client.PlatformID)
|
log.ZDebug(client.ctx, "user not exist", "userID", client.UserID, "platformID", client.PlatformID)
|
||||||
ws.clients.Set(client.UserID, client)
|
|
||||||
atomic.AddInt64(&ws.onlineUserNum, 1)
|
atomic.AddInt64(&ws.onlineUserNum, 1)
|
||||||
atomic.AddInt64(&ws.onlineUserConnNum, 1)
|
atomic.AddInt64(&ws.onlineUserConnNum, 1)
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
i := &kickHandler{
|
||||||
|
clientOK: clientOK,
|
||||||
|
oldClients: oldClients,
|
||||||
|
newClient: client,
|
||||||
|
}
|
||||||
|
ws.kickHandlerChan <- i
|
||||||
log.ZDebug(client.ctx, "user exist", "userID", client.UserID, "platformID", client.PlatformID)
|
log.ZDebug(client.ctx, "user exist", "userID", client.UserID, "platformID", client.PlatformID)
|
||||||
if clientOK { //已经有同平台的连接存在
|
if clientOK { //已经有同平台的连接存在
|
||||||
ws.clients.Set(client.UserID, client)
|
log.ZInfo(client.ctx, "repeat login", "userID", client.UserID, "platformID", client.PlatformID, "old remote addr", getRemoteAdders(oldClients))
|
||||||
ws.multiTerminalLoginChecker(cli)
|
|
||||||
log.ZInfo(client.ctx, "repeat login", "userID", client.UserID, "platformID", client.PlatformID, "old remote addr", getRemoteAdders(cli))
|
|
||||||
atomic.AddInt64(&ws.onlineUserConnNum, 1)
|
atomic.AddInt64(&ws.onlineUserConnNum, 1)
|
||||||
} else {
|
} else {
|
||||||
ws.clients.Set(client.UserID, client)
|
|
||||||
atomic.AddInt64(&ws.onlineUserConnNum, 1)
|
atomic.AddInt64(&ws.onlineUserConnNum, 1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -156,7 +178,47 @@ func getRemoteAdders(client []*Client) string {
|
|||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ws *WsServer) multiTerminalLoginChecker(client []*Client) {
|
func (ws *WsServer) multiTerminalLoginChecker(info *kickHandler) {
|
||||||
|
switch config.Config.MultiLoginPolicy {
|
||||||
|
case constant.DefalutNotKick:
|
||||||
|
case constant.PCAndOther:
|
||||||
|
if constant.PlatformIDToClass(info.newClient.PlatformID) == constant.TerminalPC {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fallthrough
|
||||||
|
case constant.AllLoginButSameTermKick:
|
||||||
|
if info.clientOK {
|
||||||
|
ws.clients.deleteClients(info.newClient.UserID, info.oldClients)
|
||||||
|
for _, c := range info.oldClients {
|
||||||
|
err := c.KickOnlineMessage()
|
||||||
|
if err != nil {
|
||||||
|
log.ZWarn(c.ctx, "KickOnlineMessage", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m, err := ws.cache.GetTokensWithoutError(info.newClient.ctx, info.newClient.UserID, info.newClient.PlatformID)
|
||||||
|
if err != nil && err != redis.Nil {
|
||||||
|
log.ZWarn(info.newClient.ctx, "get token from redis err", err, "userID", info.newClient.UserID, "platformID", info.newClient.PlatformID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if m == nil {
|
||||||
|
log.ZWarn(info.newClient.ctx, "m is nil", errors.New("m is nil"), "userID", info.newClient.UserID, "platformID", info.newClient.PlatformID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.ZDebug(info.newClient.ctx, "get token from redis", "userID", info.newClient.UserID, "platformID", info.newClient.PlatformID, "tokenMap", m)
|
||||||
|
|
||||||
|
for k, _ := range m {
|
||||||
|
if k != info.newClient.ctx.GetToken() {
|
||||||
|
m[k] = constant.KickedToken
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.ZDebug(info.newClient.ctx, "set token map is ", "token map", m, "userID", info.newClient.UserID)
|
||||||
|
err = ws.cache.SetTokenMapByUidPid(info.newClient.ctx, info.newClient.UserID, info.newClient.PlatformID, m)
|
||||||
|
if err != nil {
|
||||||
|
log.ZWarn(info.newClient.ctx, "SetTokenMapByUidPid err", err, "userID", info.newClient.UserID, "platformID", info.newClient.PlatformID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
func (ws *WsServer) unregisterClient(client *Client) {
|
func (ws *WsServer) unregisterClient(client *Client) {
|
||||||
@ -170,60 +232,83 @@ func (ws *WsServer) unregisterClient(client *Client) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
|
func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
context := newContext(w, r)
|
defer log.ZInfo(context.Background(), "wsHandler", "remote addr", "url", r.URL.String())
|
||||||
|
connContext := newContext(w, r)
|
||||||
if ws.onlineUserConnNum >= ws.wsMaxConnNum {
|
if ws.onlineUserConnNum >= ws.wsMaxConnNum {
|
||||||
httpError(context, errs.ErrConnOverMaxNumLimit)
|
httpError(connContext, errs.ErrConnOverMaxNumLimit)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var (
|
var (
|
||||||
token string
|
token string
|
||||||
userID string
|
userID string
|
||||||
platformID string
|
platformIDStr string
|
||||||
exists bool
|
exists bool
|
||||||
compression bool
|
compression bool
|
||||||
)
|
)
|
||||||
|
|
||||||
token, exists = context.Query(Token)
|
token, exists = connContext.Query(Token)
|
||||||
if !exists {
|
if !exists {
|
||||||
httpError(context, errs.ErrConnArgsErr)
|
httpError(connContext, errs.ErrConnArgsErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
userID, exists = context.Query(WsUserID)
|
userID, exists = connContext.Query(WsUserID)
|
||||||
if !exists {
|
if !exists {
|
||||||
httpError(context, errs.ErrConnArgsErr)
|
httpError(connContext, errs.ErrConnArgsErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
platformID, exists = context.Query(PlatformID)
|
platformIDStr, exists = connContext.Query(PlatformID)
|
||||||
if !exists || utils.StringToInt(platformID) == 0 {
|
if !exists {
|
||||||
httpError(context, errs.ErrConnArgsErr)
|
httpError(connContext, errs.ErrConnArgsErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// log.ZDebug(context2.Background(), "conn", "platformID", platformID)
|
platformID, err := strconv.Atoi(platformIDStr)
|
||||||
err := tokenverify.WsVerifyToken(token, userID, platformID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(context, err)
|
httpError(connContext, errs.ErrConnArgsErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := tokenverify.WsVerifyToken(token, userID, platformID); err != nil {
|
||||||
|
httpError(connContext, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m, err := ws.cache.GetTokensWithoutError(context.Background(), userID, platformID)
|
||||||
|
if err != nil {
|
||||||
|
httpError(connContext, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if v, ok := m[token]; ok {
|
||||||
|
switch v {
|
||||||
|
case constant.NormalToken:
|
||||||
|
case constant.KickedToken:
|
||||||
|
httpError(connContext, errs.ErrTokenKicked.Wrap())
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
httpError(connContext, errs.ErrTokenUnknown.Wrap())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
httpError(connContext, errs.ErrTokenNotExist.Wrap())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout)
|
wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout)
|
||||||
err = wsLongConn.GenerateLongConn(w, r)
|
err = wsLongConn.GenerateLongConn(w, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(context, err)
|
httpError(connContext, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
compressProtoc, exists := context.Query(Compression)
|
compressProtoc, exists := connContext.Query(Compression)
|
||||||
if exists {
|
if exists {
|
||||||
if compressProtoc == GzipCompressionProtocol {
|
if compressProtoc == GzipCompressionProtocol {
|
||||||
compression = true
|
compression = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
compressProtoc, exists = context.GetHeader(Compression)
|
compressProtoc, exists = connContext.GetHeader(Compression)
|
||||||
if exists {
|
if exists {
|
||||||
if compressProtoc == GzipCompressionProtocol {
|
if compressProtoc == GzipCompressionProtocol {
|
||||||
compression = true
|
compression = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
client := ws.clientPool.Get().(*Client)
|
client := ws.clientPool.Get().(*Client)
|
||||||
client.ResetClient(context, wsLongConn, context.GetBackground(), compression, ws)
|
client.ResetClient(connContext, wsLongConn, connContext.GetBackground(), compression, ws)
|
||||||
ws.registerChan <- client
|
ws.registerChan <- client
|
||||||
go client.readMessage()
|
go client.readMessage()
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@ package msggateway
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/log"
|
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/log"
|
||||||
|
"github.com/OpenIMSDK/Open-IM-Server/pkg/utils"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -71,6 +72,29 @@ func (u *UserMap) delete(key string, connRemoteAddr string) (isDeleteUser bool)
|
|||||||
}
|
}
|
||||||
return existed
|
return existed
|
||||||
}
|
}
|
||||||
|
func (u *UserMap) deleteClients(key string, clients []*Client) (isDeleteUser bool) {
|
||||||
|
m := utils.SliceToMapAny(clients, func(c *Client) (string, struct{}) {
|
||||||
|
return c.ctx.GetRemoteAddr(), struct{}{}
|
||||||
|
})
|
||||||
|
allClients, existed := u.m.Load(key)
|
||||||
|
if existed {
|
||||||
|
oldClients := allClients.([]*Client)
|
||||||
|
var a []*Client
|
||||||
|
for _, client := range oldClients {
|
||||||
|
if _, ok := m[client.ctx.GetRemoteAddr()]; !ok {
|
||||||
|
a = append(a, client)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(a) == 0 {
|
||||||
|
u.m.Delete(key)
|
||||||
|
return true
|
||||||
|
} else {
|
||||||
|
u.m.Store(key, a)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return existed
|
||||||
|
}
|
||||||
func (u *UserMap) DeleteAll(key string) {
|
func (u *UserMap) DeleteAll(key string) {
|
||||||
u.m.Delete(key)
|
u.m.Delete(key)
|
||||||
}
|
}
|
||||||
|
@ -42,7 +42,7 @@ func (s *authServer) UserToken(ctx context.Context, req *pbAuth.UserTokenReq) (*
|
|||||||
if _, err := s.userRpcClient.GetUserInfo(ctx, req.UserID); err != nil {
|
if _, err := s.userRpcClient.GetUserInfo(ctx, req.UserID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
token, err := s.authDatabase.CreateToken(ctx, req.UserID, constant.PlatformIDToName(int(req.PlatformID)))
|
token, err := s.authDatabase.CreateToken(ctx, req.UserID, int(req.PlatformID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -56,7 +56,7 @@ func (s *authServer) parseToken(ctx context.Context, tokensString string) (claim
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, utils.Wrap(err, "")
|
return nil, utils.Wrap(err, "")
|
||||||
}
|
}
|
||||||
m, err := s.authDatabase.GetTokensWithoutError(ctx, claims.UID, claims.Platform)
|
m, err := s.authDatabase.GetTokensWithoutError(ctx, claims.UserID, claims.PlatformID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -82,8 +82,8 @@ func (s *authServer) ParseToken(ctx context.Context, req *pbAuth.ParseTokenReq)
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
resp.UserID = claims.UID
|
resp.UserID = claims.UserID
|
||||||
resp.Platform = claims.Platform
|
resp.Platform = constant.PlatformIDToName(claims.PlatformID)
|
||||||
resp.ExpireTimeSeconds = claims.ExpiresAt.Unix()
|
resp.ExpireTimeSeconds = claims.ExpiresAt.Unix()
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
@ -118,6 +118,7 @@ const (
|
|||||||
ExpiredToken = 3
|
ExpiredToken = 3
|
||||||
|
|
||||||
//MultiTerminalLogin
|
//MultiTerminalLogin
|
||||||
|
DefalutNotKick = 0
|
||||||
//Full-end login, but the same end is mutually exclusive
|
//Full-end login, but the same end is mutually exclusive
|
||||||
AllLoginButSameTermKick = 1
|
AllLoginButSameTermKick = 1
|
||||||
//Only one of the endpoints can log in
|
//Only one of the endpoints can log in
|
||||||
|
@ -57,7 +57,7 @@ var PlatformName2ID = map[string]int{
|
|||||||
IPadPlatformStr: IPadPlatformID,
|
IPadPlatformStr: IPadPlatformID,
|
||||||
AdminPlatformStr: AdminPlatformID,
|
AdminPlatformStr: AdminPlatformID,
|
||||||
}
|
}
|
||||||
var Platform2class = map[string]string{
|
var PlatformName2class = map[string]string{
|
||||||
IOSPlatformStr: TerminalMobile,
|
IOSPlatformStr: TerminalMobile,
|
||||||
AndroidPlatformStr: TerminalMobile,
|
AndroidPlatformStr: TerminalMobile,
|
||||||
MiniWebPlatformStr: WebPlatformStr,
|
MiniWebPlatformStr: WebPlatformStr,
|
||||||
@ -66,6 +66,15 @@ var Platform2class = map[string]string{
|
|||||||
OSXPlatformStr: TerminalPC,
|
OSXPlatformStr: TerminalPC,
|
||||||
LinuxPlatformStr: TerminalPC,
|
LinuxPlatformStr: TerminalPC,
|
||||||
}
|
}
|
||||||
|
var PlatformID2class = map[int]string{
|
||||||
|
IOSPlatformID: TerminalMobile,
|
||||||
|
AndroidPlatformID: TerminalMobile,
|
||||||
|
MiniWebPlatformID: WebPlatformStr,
|
||||||
|
WebPlatformID: WebPlatformStr,
|
||||||
|
WindowsPlatformID: TerminalPC,
|
||||||
|
OSXPlatformID: TerminalPC,
|
||||||
|
LinuxPlatformID: TerminalPC,
|
||||||
|
}
|
||||||
|
|
||||||
func PlatformIDToName(num int) string {
|
func PlatformIDToName(num int) string {
|
||||||
return PlatformID2Name[num]
|
return PlatformID2Name[num]
|
||||||
@ -74,5 +83,8 @@ func PlatformNameToID(name string) int {
|
|||||||
return PlatformName2ID[name]
|
return PlatformName2ID[name]
|
||||||
}
|
}
|
||||||
func PlatformNameToClass(name string) string {
|
func PlatformNameToClass(name string) string {
|
||||||
return Platform2class[name]
|
return PlatformName2class[name]
|
||||||
|
}
|
||||||
|
func PlatformIDToClass(num int) string {
|
||||||
|
return PlatformID2class[num]
|
||||||
}
|
}
|
||||||
|
18
pkg/common/db/cache/msg.go
vendored
18
pkg/common/db/cache/msg.go
vendored
@ -88,9 +88,9 @@ type MsgModel interface {
|
|||||||
SeqCache
|
SeqCache
|
||||||
thirdCache
|
thirdCache
|
||||||
AddTokenFlag(ctx context.Context, userID string, platformID int, token string, flag int) error
|
AddTokenFlag(ctx context.Context, userID string, platformID int, token string, flag int) error
|
||||||
GetTokensWithoutError(ctx context.Context, userID, platformID string) (map[string]int, error)
|
GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error)
|
||||||
SetTokenMapByUidPid(ctx context.Context, userID string, platform string, m map[string]int) error
|
SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error
|
||||||
DeleteTokenByUidPid(ctx context.Context, userID string, platform string, fields []string) error
|
DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error
|
||||||
GetMessagesBySeq(ctx context.Context, conversationID string, seqs []int64) (seqMsg []*sdkws.MsgData, failedSeqList []int64, err error)
|
GetMessagesBySeq(ctx context.Context, conversationID string, seqs []int64) (seqMsg []*sdkws.MsgData, failedSeqList []int64, err error)
|
||||||
SetMessageToCache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (int, error)
|
SetMessageToCache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (int, error)
|
||||||
UserDeleteMsgs(ctx context.Context, conversationID string, seqs []int64, userID string) error
|
UserDeleteMsgs(ctx context.Context, conversationID string, seqs []int64, userID string) error
|
||||||
@ -260,8 +260,8 @@ func (c *msgCache) AddTokenFlag(ctx context.Context, userID string, platformID i
|
|||||||
return errs.Wrap(c.rdb.HSet(ctx, key, token, flag).Err())
|
return errs.Wrap(c.rdb.HSet(ctx, key, token, flag).Err())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *msgCache) GetTokensWithoutError(ctx context.Context, userID, platformID string) (map[string]int, error) {
|
func (c *msgCache) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) {
|
||||||
key := uidPidToken + userID + ":" + platformID
|
key := uidPidToken + userID + ":" + constant.PlatformIDToName(platformID)
|
||||||
m, err := c.rdb.HGetAll(ctx, key).Result()
|
m, err := c.rdb.HGetAll(ctx, key).Result()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.Wrap(err)
|
return nil, errs.Wrap(err)
|
||||||
@ -273,8 +273,8 @@ func (c *msgCache) GetTokensWithoutError(ctx context.Context, userID, platformID
|
|||||||
return mm, nil
|
return mm, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *msgCache) SetTokenMapByUidPid(ctx context.Context, userID string, platform string, m map[string]int) error {
|
func (c *msgCache) SetTokenMapByUidPid(ctx context.Context, userID string, platform int, m map[string]int) error {
|
||||||
key := uidPidToken + userID + ":" + platform
|
key := uidPidToken + userID + ":" + constant.PlatformIDToName(platform)
|
||||||
mm := make(map[string]interface{})
|
mm := make(map[string]interface{})
|
||||||
for k, v := range m {
|
for k, v := range m {
|
||||||
mm[k] = v
|
mm[k] = v
|
||||||
@ -282,8 +282,8 @@ func (c *msgCache) SetTokenMapByUidPid(ctx context.Context, userID string, platf
|
|||||||
return errs.Wrap(c.rdb.HSet(ctx, key, mm).Err())
|
return errs.Wrap(c.rdb.HSet(ctx, key, mm).Err())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *msgCache) DeleteTokenByUidPid(ctx context.Context, userID string, platform string, fields []string) error {
|
func (c *msgCache) DeleteTokenByUidPid(ctx context.Context, userID string, platform int, fields []string) error {
|
||||||
key := uidPidToken + userID + ":" + platform
|
key := uidPidToken + userID + ":" + constant.PlatformIDToName(platform)
|
||||||
return errs.Wrap(c.rdb.HDel(ctx, key, fields...).Err())
|
return errs.Wrap(c.rdb.HDel(ctx, key, fields...).Err())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -12,9 +12,9 @@ import (
|
|||||||
|
|
||||||
type AuthDatabase interface {
|
type AuthDatabase interface {
|
||||||
//结果为空 不返回错误
|
//结果为空 不返回错误
|
||||||
GetTokensWithoutError(ctx context.Context, userID, platform string) (map[string]int, error)
|
GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error)
|
||||||
//创建token
|
//创建token
|
||||||
CreateToken(ctx context.Context, userID string, platform string) (string, error)
|
CreateToken(ctx context.Context, userID string, platformID int) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type authDatabase struct {
|
type authDatabase struct {
|
||||||
@ -29,13 +29,13 @@ func NewAuthDatabase(cache cache.MsgModel, accessSecret string, accessExpire int
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 结果为空 不返回错误
|
// 结果为空 不返回错误
|
||||||
func (a *authDatabase) GetTokensWithoutError(ctx context.Context, userID, platform string) (map[string]int, error) {
|
func (a *authDatabase) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) {
|
||||||
return a.cache.GetTokensWithoutError(ctx, userID, platform)
|
return a.cache.GetTokensWithoutError(ctx, userID, platformID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 创建token
|
// 创建token
|
||||||
func (a *authDatabase) CreateToken(ctx context.Context, userID string, platform string) (string, error) {
|
func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformID int) (string, error) {
|
||||||
tokens, err := a.cache.GetTokensWithoutError(ctx, userID, platform)
|
tokens, err := a.cache.GetTokensWithoutError(ctx, userID, platformID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@ -47,16 +47,16 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platform
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(deleteTokenKey) != 0 {
|
if len(deleteTokenKey) != 0 {
|
||||||
err := a.cache.DeleteTokenByUidPid(ctx, userID, platform, deleteTokenKey)
|
err := a.cache.DeleteTokenByUidPid(ctx, userID, platformID, deleteTokenKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
claims := tokenverify.BuildClaims(userID, platform, a.accessExpire)
|
claims := tokenverify.BuildClaims(userID, platformID, a.accessExpire)
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
tokenString, err := token.SignedString([]byte(a.accessSecret))
|
tokenString, err := token.SignedString([]byte(a.accessSecret))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", utils.Wrap(err, "")
|
return "", utils.Wrap(err, "")
|
||||||
}
|
}
|
||||||
return tokenString, a.cache.AddTokenFlag(ctx, userID, constant.PlatformNameToID(platform), tokenString, constant.NormalToken)
|
return tokenString, a.cache.AddTokenFlag(ctx, userID, platformID, tokenString, constant.NormalToken)
|
||||||
}
|
}
|
||||||
|
@ -128,7 +128,7 @@ func GinParseToken(rdb redis.UniversalClient) gin.HandlerFunc {
|
|||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
m, err := dataBase.GetTokensWithoutError(c, claims.UID, claims.Platform)
|
m, err := dataBase.GetTokensWithoutError(c, claims.UserID, claims.PlatformID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.ZWarn(c, "cache get token error", errs.ErrTokenNotExist.Wrap())
|
log.ZWarn(c, "cache get token error", errs.ErrTokenNotExist.Wrap())
|
||||||
apiresp.GinError(c, errs.ErrTokenNotExist.Wrap())
|
apiresp.GinError(c, errs.ErrTokenNotExist.Wrap())
|
||||||
@ -155,9 +155,12 @@ func GinParseToken(rdb redis.UniversalClient) gin.HandlerFunc {
|
|||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
apiresp.GinError(c, errs.ErrTokenNotExist.Wrap())
|
||||||
|
return
|
||||||
}
|
}
|
||||||
c.Set(constant.OpUserPlatform, claims.Platform)
|
c.Set(constant.OpUserPlatform, constant.PlatformIDToName(claims.PlatformID))
|
||||||
c.Set(constant.OpUserID, claims.UID)
|
c.Set(constant.OpUserID, claims.UserID)
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,27 +4,25 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/config"
|
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/config"
|
||||||
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant"
|
|
||||||
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/mcontext"
|
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/mcontext"
|
||||||
"github.com/OpenIMSDK/Open-IM-Server/pkg/errs"
|
"github.com/OpenIMSDK/Open-IM-Server/pkg/errs"
|
||||||
"github.com/OpenIMSDK/Open-IM-Server/pkg/utils"
|
"github.com/OpenIMSDK/Open-IM-Server/pkg/utils"
|
||||||
"github.com/golang-jwt/jwt/v4"
|
"github.com/golang-jwt/jwt/v4"
|
||||||
"strconv"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Claims struct {
|
type Claims struct {
|
||||||
UID string
|
UserID string
|
||||||
Platform string //login platform
|
PlatformID int //login platform
|
||||||
jwt.RegisteredClaims
|
jwt.RegisteredClaims
|
||||||
}
|
}
|
||||||
|
|
||||||
func BuildClaims(uid, platform string, ttl int64) Claims {
|
func BuildClaims(uid string, platformID int, ttl int64) Claims {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
before := now.Add(-time.Minute * 5)
|
before := now.Add(-time.Minute * 5)
|
||||||
return Claims{
|
return Claims{
|
||||||
UID: uid,
|
UserID: uid,
|
||||||
Platform: platform,
|
PlatformID: platformID,
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
ExpiresAt: jwt.NewNumericDate(now.Add(time.Duration(ttl*24) * time.Hour)), //Expiration time
|
ExpiresAt: jwt.NewNumericDate(now.Add(time.Duration(ttl*24) * time.Hour)), //Expiration time
|
||||||
IssuedAt: jwt.NewNumericDate(now), //Issuing time
|
IssuedAt: jwt.NewNumericDate(now), //Issuing time
|
||||||
@ -90,24 +88,16 @@ func ParseRedisInterfaceToken(redisToken interface{}) (*Claims, error) {
|
|||||||
func IsManagerUserID(opUserID string) bool {
|
func IsManagerUserID(opUserID string) bool {
|
||||||
return utils.IsContain(opUserID, config.Config.Manager.AppManagerUid)
|
return utils.IsContain(opUserID, config.Config.Manager.AppManagerUid)
|
||||||
}
|
}
|
||||||
func WsVerifyToken(token, userID, platformID string) error {
|
func WsVerifyToken(token, userID string, platformID int) error {
|
||||||
platformIDInt, err := strconv.Atoi(platformID)
|
|
||||||
if err != nil {
|
|
||||||
return errs.ErrArgs.Wrap(fmt.Sprintf("platformID %s is not int", platformID))
|
|
||||||
}
|
|
||||||
platform := constant.PlatformIDToName(platformIDInt)
|
|
||||||
if platform == "" {
|
|
||||||
return errs.ErrArgs.Wrap(fmt.Sprintf("platformID %s is not exist", platformID))
|
|
||||||
}
|
|
||||||
claim, err := GetClaimFromToken(token)
|
claim, err := GetClaimFromToken(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if claim.UID != userID {
|
if claim.UserID != userID {
|
||||||
return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token uid %s != userID %s", claim.UID, userID))
|
return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token uid %s != userID %s", claim.UserID, userID))
|
||||||
}
|
}
|
||||||
if claim.Platform != platform {
|
if claim.PlatformID != platformID {
|
||||||
return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token platform %s != %s", claim.Platform, platform))
|
return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token platform %d != %d", claim.PlatformID, platformID))
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user