refactor: websocket auth change to call rpc of auth.

This commit is contained in:
Gordon 2024-04-11 21:12:45 +08:00
parent 95b180e7dc
commit e379067601
5 changed files with 114 additions and 96 deletions

View File

@ -87,19 +87,19 @@ type Client struct {
// } // }
// ResetClient updates the client's state with new connection and context information. // ResetClient updates the client's state with new connection and context information.
func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, isBackground, isCompress bool, longConnServer LongConnServer, token string) { func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, longConnServer LongConnServer) {
c.w = new(sync.Mutex) c.w = new(sync.Mutex)
c.conn = conn c.conn = conn
c.PlatformID = stringutil.StringToInt(ctx.GetPlatformID()) c.PlatformID = stringutil.StringToInt(ctx.GetPlatformID())
c.IsCompress = isCompress c.IsCompress = ctx.GetCompression()
c.IsBackground = isBackground c.IsBackground = ctx.GetBackground()
c.UserID = ctx.GetUserID() c.UserID = ctx.GetUserID()
c.ctx = ctx c.ctx = ctx
c.longConnServer = longConnServer c.longConnServer = longConnServer
c.IsBackground = false c.IsBackground = false
c.closed.Store(false) c.closed.Store(false)
c.closedErr = nil c.closedErr = nil
c.token = token c.token = ctx.GetToken()
} }
func (c *Client) pingHandler(_ string) error { func (c *Client) pingHandler(_ string) error {

View File

@ -26,7 +26,7 @@ const (
Compression = "compression" Compression = "compression"
GzipCompressionProtocol = "gzip" GzipCompressionProtocol = "gzip"
BackgroundStatus = "isBackground" BackgroundStatus = "isBackground"
MsgResp = "isMsgResp" ErrResp = "errResp"
) )
const ( const (

View File

@ -15,6 +15,7 @@
package msggateway package msggateway
import ( import (
"github.com/openimsdk/open-im-server/v3/pkg/common/servererrs"
"net/http" "net/http"
"net/url" "net/url"
"strconv" "strconv"
@ -135,6 +136,32 @@ func (c *UserConnContext) GetToken() string {
return c.Req.URL.Query().Get(Token) return c.Req.URL.Query().Get(Token)
} }
func (c *UserConnContext) GetCompression() bool {
compression, exists := c.Query(Compression)
if exists && compression == GzipCompressionProtocol {
return true
} else {
compression, exists := c.GetHeader(Compression)
if exists && compression == GzipCompressionProtocol {
return true
}
}
return false
}
func (c *UserConnContext) ShouldSendError() bool {
errResp, exists := c.Query(ErrResp)
if exists {
b, err := strconv.ParseBool(errResp)
if err != nil {
return false
} else {
return b
}
}
return false
}
func (c *UserConnContext) SetToken(token string) { func (c *UserConnContext) SetToken(token string) {
c.Req.URL.RawQuery = Token + "=" + token c.Req.URL.RawQuery = Token + "=" + token
} }
@ -146,3 +173,23 @@ func (c *UserConnContext) GetBackground() bool {
} }
return b return b
} }
func (c *UserConnContext) ParseEssentialArgs() error {
_, exists := c.Query(Token)
if !exists {
return servererrs.ErrConnArgsErr.WrapMsg("token is empty")
}
_, exists = c.Query(WsUserID)
if !exists {
return servererrs.ErrConnArgsErr.WrapMsg("sendID is empty")
}
platformIDStr, exists := c.Query(PlatformID)
if !exists {
return servererrs.ErrConnArgsErr.WrapMsg("platformID is empty")
}
_, err := strconv.Atoi(platformIDStr)
if err != nil {
return servererrs.ErrConnArgsErr.WrapMsg("platformID is not int")
}
return nil
}

View File

@ -15,6 +15,8 @@
package msggateway package msggateway
import ( import (
"encoding/json"
"github.com/openimsdk/tools/apiresp"
"net/http" "net/http"
"time" "time"
@ -143,6 +145,24 @@ func (d *GWebSocket) SetPingHandler(handler PingPongHandler) {
d.conn.SetPingHandler(handler) d.conn.SetPingHandler(handler)
} }
func (d *GWebSocket) RespErrInfo(err error, w http.ResponseWriter, r *http.Request) error {
if err := d.GenerateLongConn(w, r); err != nil {
return err
}
data, err := json.Marshal(apiresp.ParseError(err))
if err != nil {
_ = d.Close()
return errs.WrapMsg(err, "json marshal failed")
}
if err := d.WriteMessage(MessageText, data); err != nil {
_ = d.Close()
return errs.WrapMsg(err, "WriteMessage failed")
}
_ = d.Close()
return nil
}
// func (d *GWebSocket) CheckSendConnDiffNow() bool { // func (d *GWebSocket) CheckSendConnDiffNow() bool {
// return d.conn == d.sendConn // return d.conn == d.sendConn
//} //}

View File

@ -16,23 +16,20 @@ package msggateway
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
pbAuth "github.com/openimsdk/protocol/auth"
"net/http" "net/http"
"strconv"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/go-playground/validator/v10" "github.com/go-playground/validator/v10"
"github.com/openimsdk/open-im-server/v3/pkg/authverify"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/cache" "github.com/openimsdk/open-im-server/v3/pkg/common/db/cache"
"github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics" "github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics"
"github.com/openimsdk/open-im-server/v3/pkg/common/servererrs" "github.com/openimsdk/open-im-server/v3/pkg/common/servererrs"
"github.com/openimsdk/open-im-server/v3/pkg/rpcclient" "github.com/openimsdk/open-im-server/v3/pkg/rpcclient"
"github.com/openimsdk/protocol/constant" "github.com/openimsdk/protocol/constant"
"github.com/openimsdk/protocol/msggateway" "github.com/openimsdk/protocol/msggateway"
"github.com/openimsdk/tools/apiresp"
"github.com/openimsdk/tools/discovery" "github.com/openimsdk/tools/discovery"
"github.com/openimsdk/tools/errs" "github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log" "github.com/openimsdk/tools/log"
@ -73,6 +70,7 @@ type WsServer struct {
validate *validator.Validate validate *validator.Validate
cache cache.TokenModel cache cache.TokenModel
userClient *rpcclient.UserRpcClient userClient *rpcclient.UserRpcClient
authClient *rpcclient.Auth
disCov discovery.SvcDiscoveryRegistry disCov discovery.SvcDiscoveryRegistry
Compressor Compressor
Encoder Encoder
@ -88,6 +86,7 @@ type kickHandler struct {
func (ws *WsServer) SetDiscoveryRegistry(disCov discovery.SvcDiscoveryRegistry, config *Config) { func (ws *WsServer) SetDiscoveryRegistry(disCov discovery.SvcDiscoveryRegistry, config *Config) {
ws.MessageHandler = NewGrpcHandler(ws.validate, disCov, &config.Share.RpcRegisterName) ws.MessageHandler = NewGrpcHandler(ws.validate, disCov, &config.Share.RpcRegisterName)
u := rpcclient.NewUserRpcClient(disCov, config.Share.RpcRegisterName.User, config.Share.IMAdminUserID) u := rpcclient.NewUserRpcClient(disCov, config.Share.RpcRegisterName.User, config.Share.IMAdminUserID)
ws.authClient = rpcclient.NewAuth(disCov, config.Share.RpcRegisterName.Auth)
ws.userClient = &u ws.userClient = &u
ws.disCov = disCov ws.disCov = disCov
} }
@ -408,102 +407,54 @@ func (ws *WsServer) unregisterClient(client *Client) {
) )
} }
func (ws *WsServer) ParseWSArgs(r *http.Request) (args *WSArgs, err error) { // validateRespWithRequest checks if the response matches the expected userID and platformID.
var v WSArgs func (ws *WsServer) validateRespWithRequest(ctx *UserConnContext, resp *pbAuth.ParseTokenResp) error {
defer func() { userID := ctx.GetUserID()
args = &v platformID := stringutil.StringToInt32(ctx.GetPlatformID())
}() if resp.UserID != userID {
query := r.URL.Query() return servererrs.ErrTokenInvalid.WrapMsg(fmt.Sprintf("token uid %s != userID %s", resp.UserID, userID))
v.MsgResp, _ = strconv.ParseBool(query.Get(MsgResp))
if ws.onlineUserConnNum.Load() >= ws.wsMaxConnNum {
return nil, servererrs.ErrConnOverMaxNumLimit.WrapMsg("over max conn num limit")
} }
if v.Token = query.Get(Token); v.Token == "" { if resp.PlatformID != platformID {
return nil, servererrs.ErrConnArgsErr.WrapMsg("token is empty") return servererrs.ErrTokenInvalid.WrapMsg(fmt.Sprintf("token platform %d != platformID %d", resp.PlatformID, platformID))
} }
if v.UserID = query.Get(WsUserID); v.UserID == "" { return nil
return nil, servererrs.ErrConnArgsErr.WrapMsg("sendID is empty")
}
platformIDStr := query.Get(PlatformID)
if platformIDStr == "" {
return nil, servererrs.ErrConnArgsErr.WrapMsg("platformID is empty")
}
platformID, err := strconv.Atoi(platformIDStr)
if err != nil {
return nil, servererrs.ErrConnArgsErr.WrapMsg("platformID is not int")
}
v.PlatformID = platformID
if err = authverify.WsVerifyToken(v.Token, v.UserID, ws.msgGatewayConfig.Share.Secret, platformID); err != nil {
return nil, err
}
if query.Get(Compression) == GzipCompressionProtocol {
v.Compression = true
}
if r.Header.Get(Compression) == GzipCompressionProtocol {
v.Compression = true
}
m, err := ws.cache.GetTokensWithoutError(context.Background(), v.UserID, platformID)
if err != nil {
return nil, err
}
if v, ok := m[v.Token]; ok {
switch v {
case constant.NormalToken:
case constant.KickedToken:
return nil, servererrs.ErrTokenKicked.Wrap()
default:
return nil, servererrs.ErrTokenUnknown.WrapMsg(fmt.Sprintf("token status is %d", v))
}
} else {
return nil, servererrs.ErrTokenNotExist.Wrap()
}
return &v, nil
}
type WSArgs struct {
Token string
UserID string
PlatformID int
Compression bool
MsgResp bool
} }
func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
connContext := newContext(w, r) connContext := newContext(w, r)
args, pErr := ws.ParseWSArgs(r) if ws.onlineUserConnNum.Load() >= ws.wsMaxConnNum {
var wsLongConn *GWebSocket httpError(connContext, servererrs.ErrConnOverMaxNumLimit.WrapMsg("over max conn num limit"))
if args.MsgResp {
wsLongConn = newGWebSocket(WebSocket, ws.handshakeTimeout, ws.writeBufferSize)
if err := wsLongConn.GenerateLongConn(w, r); err != nil {
httpError(connContext, err)
return return
} }
data, err := json.Marshal(apiresp.ParseError(pErr)) err := connContext.ParseEssentialArgs()
if err != nil { if err != nil {
_ = wsLongConn.Close()
return
}
if err := wsLongConn.WriteMessage(MessageText, data); err != nil {
_ = wsLongConn.Close()
return
}
if pErr != nil {
_ = wsLongConn.Close()
return
}
} else {
if pErr != nil {
httpError(connContext, pErr)
return
}
wsLongConn = newGWebSocket(WebSocket, ws.handshakeTimeout, ws.writeBufferSize)
if err := wsLongConn.GenerateLongConn(w, r); err != nil {
httpError(connContext, err) httpError(connContext, err)
return return
} }
resp, err := ws.authClient.ParseToken(connContext, connContext.GetToken())
if err != nil {
shouldSendError := connContext.ShouldSendError()
if shouldSendError {
wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout, ws.writeBufferSize)
if err := wsLongConn.RespErrInfo(err, w, r); err == nil {
return
}
}
httpError(connContext, err)
return
}
err = ws.validateRespWithRequest(connContext, resp)
if err != nil {
httpError(connContext, err)
return
}
wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout, ws.writeBufferSize)
if err := wsLongConn.GenerateLongConn(w, r); err != nil {
httpError(connContext, err)
return
} }
client := ws.clientPool.Get().(*Client) client := ws.clientPool.Get().(*Client)
client.ResetClient(connContext, wsLongConn, connContext.GetBackground(), args.Compression, ws, args.Token) client.ResetClient(connContext, wsLongConn, ws)
ws.registerChan <- client ws.registerChan <- client
go client.readMessage() go client.readMessage()
} }