WsVerifyToken

This commit is contained in:
withchao 2023-06-14 10:15:58 +08:00
parent 87d64c6afe
commit 82e85c7083
2 changed files with 38 additions and 15 deletions

View File

@ -1,6 +1,7 @@
package msggateway package msggateway
import ( import (
"context"
"errors" "errors"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/config" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/config"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant"
@ -188,13 +189,13 @@ func (ws *WsServer) multiTerminalLoginChecker(info *kickHandler) {
for _, c := range info.oldClients { for _, c := range info.oldClients {
err := c.KickOnlineMessage() err := c.KickOnlineMessage()
if err != nil { if err != nil {
log.ZWarn() log.ZError(c.ctx, "KickOnlineMessage", err)
} }
} }
} }
} }
} }
func (ws *WsServer) unregisterClient(client *Client) { func (ws *WsServer) unregisterClient(client *Client) {
defer ws.clientPool.Put(client) defer ws.clientPool.Put(client)
isDeleteUser := ws.clients.delete(client.UserID, client.ctx.GetRemoteAddr()) isDeleteUser := ws.clients.delete(client.UserID, client.ctx.GetRemoteAddr())
@ -206,9 +207,9 @@ func (ws *WsServer) unregisterClient(client *Client) {
} }
func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
context := newContext(w, r) connContext := newContext(w, r)
if ws.onlineUserConnNum >= ws.wsMaxConnNum { if ws.onlineUserConnNum >= ws.wsMaxConnNum {
httpError(context, errs.ErrConnOverMaxNumLimit) httpError(connContext, errs.ErrConnOverMaxNumLimit)
return return
} }
var ( var (
@ -219,46 +220,65 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
compression bool compression bool
) )
token, exists = context.Query(Token) token, exists = connContext.Query(Token)
if !exists { if !exists {
httpError(context, errs.ErrConnArgsErr) httpError(connContext, errs.ErrConnArgsErr)
return return
} }
userID, exists = context.Query(WsUserID) userID, exists = connContext.Query(WsUserID)
if !exists { if !exists {
httpError(context, errs.ErrConnArgsErr) httpError(connContext, errs.ErrConnArgsErr)
return return
} }
platformID, exists = context.Query(PlatformID) platformID, exists = connContext.Query(PlatformID)
if !exists || utils.StringToInt(platformID) == 0 { if !exists || utils.StringToInt(platformID) == 0 {
httpError(context, errs.ErrConnArgsErr) httpError(connContext, errs.ErrConnArgsErr)
return return
} }
err := tokenverify.WsVerifyToken(token, userID, platformID) err := tokenverify.WsVerifyToken(token, userID, platformID)
if err != nil { if err != nil {
httpError(context, err) httpError(connContext, err)
return
}
m, err := ws.cache.GetTokensWithoutError(context.Background(), userID, platformID)
if err != nil {
httpError(connContext, err)
return
}
if v, ok := m[token]; ok {
switch v {
case constant.NormalToken:
case constant.KickedToken:
httpError(connContext, errs.ErrTokenKicked.Wrap())
return
default:
httpError(connContext, errs.ErrTokenUnknown.Wrap())
return
}
} else {
httpError(connContext, errs.ErrTokenNotExist.Wrap())
return return
} }
wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout) wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout)
err = wsLongConn.GenerateLongConn(w, r) err = wsLongConn.GenerateLongConn(w, r)
if err != nil { if err != nil {
httpError(context, err) httpError(connContext, err)
return return
} }
compressProtoc, exists := context.Query(Compression) compressProtoc, exists := connContext.Query(Compression)
if exists { if exists {
if compressProtoc == GzipCompressionProtocol { if compressProtoc == GzipCompressionProtocol {
compression = true compression = true
} }
} }
compressProtoc, exists = context.GetHeader(Compression) compressProtoc, exists = connContext.GetHeader(Compression)
if exists { if exists {
if compressProtoc == GzipCompressionProtocol { if compressProtoc == GzipCompressionProtocol {
compression = true compression = true
} }
} }
client := ws.clientPool.Get().(*Client) client := ws.clientPool.Get().(*Client)
client.ResetClient(context, wsLongConn, context.GetBackground(), compression, ws) client.ResetClient(connContext, wsLongConn, connContext.GetBackground(), compression, ws)
ws.registerChan <- client ws.registerChan <- client
go client.readMessage() go client.readMessage()
} }

View File

@ -155,6 +155,9 @@ func GinParseToken(rdb redis.UniversalClient) gin.HandlerFunc {
c.Abort() c.Abort()
return return
} }
} else {
apiresp.GinError(c, errs.ErrTokenNotExist.Wrap())
return
} }
c.Set(constant.OpUserPlatform, claims.Platform) c.Set(constant.OpUserPlatform, claims.Platform)
c.Set(constant.OpUserID, claims.UID) c.Set(constant.OpUserID, claims.UID)