From e379067601a7c6e50cc272b8a2da7723f4734b58 Mon Sep 17 00:00:00 2001 From: Gordon <46924906+FGadvancer@users.noreply.github.com> Date: Thu, 11 Apr 2024 21:12:45 +0800 Subject: [PATCH] refactor: websocket auth change to call rpc of auth. --- internal/msggateway/client.go | 8 +- internal/msggateway/constant.go | 2 +- internal/msggateway/context.go | 47 ++++++++++ internal/msggateway/long_conn.go | 20 +++++ internal/msggateway/n_ws_server.go | 133 +++++++++-------------------- 5 files changed, 114 insertions(+), 96 deletions(-) diff --git a/internal/msggateway/client.go b/internal/msggateway/client.go index 8be20ce9c..806b8cac2 100644 --- a/internal/msggateway/client.go +++ b/internal/msggateway/client.go @@ -87,19 +87,19 @@ type Client struct { // } // ResetClient updates the client's state with new connection and context information. -func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, isBackground, isCompress bool, longConnServer LongConnServer, token string) { +func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, longConnServer LongConnServer) { c.w = new(sync.Mutex) c.conn = conn c.PlatformID = stringutil.StringToInt(ctx.GetPlatformID()) - c.IsCompress = isCompress - c.IsBackground = isBackground + c.IsCompress = ctx.GetCompression() + c.IsBackground = ctx.GetBackground() c.UserID = ctx.GetUserID() c.ctx = ctx c.longConnServer = longConnServer c.IsBackground = false c.closed.Store(false) c.closedErr = nil - c.token = token + c.token = ctx.GetToken() } func (c *Client) pingHandler(_ string) error { diff --git a/internal/msggateway/constant.go b/internal/msggateway/constant.go index 045629b4e..477fcfde6 100644 --- a/internal/msggateway/constant.go +++ b/internal/msggateway/constant.go @@ -26,7 +26,7 @@ const ( Compression = "compression" GzipCompressionProtocol = "gzip" BackgroundStatus = "isBackground" - MsgResp = "isMsgResp" + ErrResp = "errResp" ) const ( diff --git a/internal/msggateway/context.go b/internal/msggateway/context.go index c3e6f5014..87ed25d75 100644 --- a/internal/msggateway/context.go +++ b/internal/msggateway/context.go @@ -15,6 +15,7 @@ package msggateway import ( + "github.com/openimsdk/open-im-server/v3/pkg/common/servererrs" "net/http" "net/url" "strconv" @@ -135,6 +136,32 @@ func (c *UserConnContext) GetToken() string { return c.Req.URL.Query().Get(Token) } +func (c *UserConnContext) GetCompression() bool { + compression, exists := c.Query(Compression) + if exists && compression == GzipCompressionProtocol { + return true + } else { + compression, exists := c.GetHeader(Compression) + if exists && compression == GzipCompressionProtocol { + return true + } + } + return false +} + +func (c *UserConnContext) ShouldSendError() bool { + errResp, exists := c.Query(ErrResp) + if exists { + b, err := strconv.ParseBool(errResp) + if err != nil { + return false + } else { + return b + } + } + return false +} + func (c *UserConnContext) SetToken(token string) { c.Req.URL.RawQuery = Token + "=" + token } @@ -146,3 +173,23 @@ func (c *UserConnContext) GetBackground() bool { } return b } +func (c *UserConnContext) ParseEssentialArgs() error { + _, exists := c.Query(Token) + if !exists { + return servererrs.ErrConnArgsErr.WrapMsg("token is empty") + } + _, exists = c.Query(WsUserID) + if !exists { + return servererrs.ErrConnArgsErr.WrapMsg("sendID is empty") + } + platformIDStr, exists := c.Query(PlatformID) + if !exists { + return servererrs.ErrConnArgsErr.WrapMsg("platformID is empty") + } + _, err := strconv.Atoi(platformIDStr) + if err != nil { + return servererrs.ErrConnArgsErr.WrapMsg("platformID is not int") + + } + return nil +} diff --git a/internal/msggateway/long_conn.go b/internal/msggateway/long_conn.go index dfd5e2e87..92bc189b1 100644 --- a/internal/msggateway/long_conn.go +++ b/internal/msggateway/long_conn.go @@ -15,6 +15,8 @@ package msggateway import ( + "encoding/json" + "github.com/openimsdk/tools/apiresp" "net/http" "time" @@ -143,6 +145,24 @@ func (d *GWebSocket) SetPingHandler(handler PingPongHandler) { d.conn.SetPingHandler(handler) } +func (d *GWebSocket) RespErrInfo(err error, w http.ResponseWriter, r *http.Request) error { + if err := d.GenerateLongConn(w, r); err != nil { + return err + } + data, err := json.Marshal(apiresp.ParseError(err)) + if err != nil { + _ = d.Close() + return errs.WrapMsg(err, "json marshal failed") + } + + if err := d.WriteMessage(MessageText, data); err != nil { + _ = d.Close() + return errs.WrapMsg(err, "WriteMessage failed") + } + _ = d.Close() + return nil +} + // func (d *GWebSocket) CheckSendConnDiffNow() bool { // return d.conn == d.sendConn //} diff --git a/internal/msggateway/n_ws_server.go b/internal/msggateway/n_ws_server.go index 44c17a6a9..0a0e2bd8b 100644 --- a/internal/msggateway/n_ws_server.go +++ b/internal/msggateway/n_ws_server.go @@ -16,23 +16,20 @@ package msggateway import ( "context" - "encoding/json" "fmt" + pbAuth "github.com/openimsdk/protocol/auth" "net/http" - "strconv" "sync" "sync/atomic" "time" "github.com/go-playground/validator/v10" - "github.com/openimsdk/open-im-server/v3/pkg/authverify" "github.com/openimsdk/open-im-server/v3/pkg/common/db/cache" "github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics" "github.com/openimsdk/open-im-server/v3/pkg/common/servererrs" "github.com/openimsdk/open-im-server/v3/pkg/rpcclient" "github.com/openimsdk/protocol/constant" "github.com/openimsdk/protocol/msggateway" - "github.com/openimsdk/tools/apiresp" "github.com/openimsdk/tools/discovery" "github.com/openimsdk/tools/errs" "github.com/openimsdk/tools/log" @@ -73,6 +70,7 @@ type WsServer struct { validate *validator.Validate cache cache.TokenModel userClient *rpcclient.UserRpcClient + authClient *rpcclient.Auth disCov discovery.SvcDiscoveryRegistry Compressor Encoder @@ -88,6 +86,7 @@ type kickHandler struct { func (ws *WsServer) SetDiscoveryRegistry(disCov discovery.SvcDiscoveryRegistry, config *Config) { ws.MessageHandler = NewGrpcHandler(ws.validate, disCov, &config.Share.RpcRegisterName) u := rpcclient.NewUserRpcClient(disCov, config.Share.RpcRegisterName.User, config.Share.IMAdminUserID) + ws.authClient = rpcclient.NewAuth(disCov, config.Share.RpcRegisterName.Auth) ws.userClient = &u ws.disCov = disCov } @@ -408,102 +407,54 @@ func (ws *WsServer) unregisterClient(client *Client) { ) } -func (ws *WsServer) ParseWSArgs(r *http.Request) (args *WSArgs, err error) { - var v WSArgs - defer func() { - args = &v - }() - query := r.URL.Query() - v.MsgResp, _ = strconv.ParseBool(query.Get(MsgResp)) - if ws.onlineUserConnNum.Load() >= ws.wsMaxConnNum { - return nil, servererrs.ErrConnOverMaxNumLimit.WrapMsg("over max conn num limit") +// validateRespWithRequest checks if the response matches the expected userID and platformID. +func (ws *WsServer) validateRespWithRequest(ctx *UserConnContext, resp *pbAuth.ParseTokenResp) error { + userID := ctx.GetUserID() + platformID := stringutil.StringToInt32(ctx.GetPlatformID()) + if resp.UserID != userID { + return servererrs.ErrTokenInvalid.WrapMsg(fmt.Sprintf("token uid %s != userID %s", resp.UserID, userID)) } - if v.Token = query.Get(Token); v.Token == "" { - return nil, servererrs.ErrConnArgsErr.WrapMsg("token is empty") + if resp.PlatformID != platformID { + return servererrs.ErrTokenInvalid.WrapMsg(fmt.Sprintf("token platform %d != platformID %d", resp.PlatformID, platformID)) } - if v.UserID = query.Get(WsUserID); v.UserID == "" { - return nil, servererrs.ErrConnArgsErr.WrapMsg("sendID is empty") - } - platformIDStr := query.Get(PlatformID) - if platformIDStr == "" { - return nil, servererrs.ErrConnArgsErr.WrapMsg("platformID is empty") - } - platformID, err := strconv.Atoi(platformIDStr) - if err != nil { - return nil, servererrs.ErrConnArgsErr.WrapMsg("platformID is not int") - } - v.PlatformID = platformID - if err = authverify.WsVerifyToken(v.Token, v.UserID, ws.msgGatewayConfig.Share.Secret, platformID); err != nil { - return nil, err - } - if query.Get(Compression) == GzipCompressionProtocol { - v.Compression = true - } - if r.Header.Get(Compression) == GzipCompressionProtocol { - v.Compression = true - } - m, err := ws.cache.GetTokensWithoutError(context.Background(), v.UserID, platformID) - if err != nil { - return nil, err - } - if v, ok := m[v.Token]; ok { - switch v { - case constant.NormalToken: - case constant.KickedToken: - return nil, servererrs.ErrTokenKicked.Wrap() - default: - return nil, servererrs.ErrTokenUnknown.WrapMsg(fmt.Sprintf("token status is %d", v)) - } - } else { - return nil, servererrs.ErrTokenNotExist.Wrap() - } - return &v, nil -} - -type WSArgs struct { - Token string - UserID string - PlatformID int - Compression bool - MsgResp bool + return nil } func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { connContext := newContext(w, r) - args, pErr := ws.ParseWSArgs(r) - var wsLongConn *GWebSocket - if args.MsgResp { - wsLongConn = newGWebSocket(WebSocket, ws.handshakeTimeout, ws.writeBufferSize) - if err := wsLongConn.GenerateLongConn(w, r); err != nil { - httpError(connContext, err) - return - } - data, err := json.Marshal(apiresp.ParseError(pErr)) - if err != nil { - _ = wsLongConn.Close() - return - } - if err := wsLongConn.WriteMessage(MessageText, data); err != nil { - _ = wsLongConn.Close() - return - } - if pErr != nil { - _ = wsLongConn.Close() - return - } - } else { - if pErr != nil { - httpError(connContext, pErr) - return - } - wsLongConn = newGWebSocket(WebSocket, ws.handshakeTimeout, ws.writeBufferSize) - if err := wsLongConn.GenerateLongConn(w, r); err != nil { - httpError(connContext, err) - return + if ws.onlineUserConnNum.Load() >= ws.wsMaxConnNum { + httpError(connContext, servererrs.ErrConnOverMaxNumLimit.WrapMsg("over max conn num limit")) + return + } + err := connContext.ParseEssentialArgs() + if err != nil { + httpError(connContext, err) + return + } + resp, err := ws.authClient.ParseToken(connContext, connContext.GetToken()) + if err != nil { + shouldSendError := connContext.ShouldSendError() + if shouldSendError { + wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout, ws.writeBufferSize) + if err := wsLongConn.RespErrInfo(err, w, r); err == nil { + return + } } + httpError(connContext, err) + return + } + err = ws.validateRespWithRequest(connContext, resp) + if err != nil { + httpError(connContext, err) + return + } + wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout, ws.writeBufferSize) + if err := wsLongConn.GenerateLongConn(w, r); err != nil { + httpError(connContext, err) + return } client := ws.clientPool.Get().(*Client) - client.ResetClient(connContext, wsLongConn, connContext.GetBackground(), args.Compression, ws, args.Token) + client.ResetClient(connContext, wsLongConn, ws) ws.registerChan <- client go client.readMessage() }