mirror of
https://github.com/openimsdk/open-im-server.git
synced 2025-11-05 11:52:10 +08:00
refactor: websocket auth change to call rpc of auth.
This commit is contained in:
parent
95b180e7dc
commit
e379067601
@ -87,19 +87,19 @@ type Client struct {
|
|||||||
// }
|
// }
|
||||||
|
|
||||||
// ResetClient updates the client's state with new connection and context information.
|
// ResetClient updates the client's state with new connection and context information.
|
||||||
func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, isBackground, isCompress bool, longConnServer LongConnServer, token string) {
|
func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, longConnServer LongConnServer) {
|
||||||
c.w = new(sync.Mutex)
|
c.w = new(sync.Mutex)
|
||||||
c.conn = conn
|
c.conn = conn
|
||||||
c.PlatformID = stringutil.StringToInt(ctx.GetPlatformID())
|
c.PlatformID = stringutil.StringToInt(ctx.GetPlatformID())
|
||||||
c.IsCompress = isCompress
|
c.IsCompress = ctx.GetCompression()
|
||||||
c.IsBackground = isBackground
|
c.IsBackground = ctx.GetBackground()
|
||||||
c.UserID = ctx.GetUserID()
|
c.UserID = ctx.GetUserID()
|
||||||
c.ctx = ctx
|
c.ctx = ctx
|
||||||
c.longConnServer = longConnServer
|
c.longConnServer = longConnServer
|
||||||
c.IsBackground = false
|
c.IsBackground = false
|
||||||
c.closed.Store(false)
|
c.closed.Store(false)
|
||||||
c.closedErr = nil
|
c.closedErr = nil
|
||||||
c.token = token
|
c.token = ctx.GetToken()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) pingHandler(_ string) error {
|
func (c *Client) pingHandler(_ string) error {
|
||||||
|
|||||||
@ -26,7 +26,7 @@ const (
|
|||||||
Compression = "compression"
|
Compression = "compression"
|
||||||
GzipCompressionProtocol = "gzip"
|
GzipCompressionProtocol = "gzip"
|
||||||
BackgroundStatus = "isBackground"
|
BackgroundStatus = "isBackground"
|
||||||
MsgResp = "isMsgResp"
|
ErrResp = "errResp"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
package msggateway
|
package msggateway
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/openimsdk/open-im-server/v3/pkg/common/servererrs"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -135,6 +136,32 @@ func (c *UserConnContext) GetToken() string {
|
|||||||
return c.Req.URL.Query().Get(Token)
|
return c.Req.URL.Query().Get(Token)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *UserConnContext) GetCompression() bool {
|
||||||
|
compression, exists := c.Query(Compression)
|
||||||
|
if exists && compression == GzipCompressionProtocol {
|
||||||
|
return true
|
||||||
|
} else {
|
||||||
|
compression, exists := c.GetHeader(Compression)
|
||||||
|
if exists && compression == GzipCompressionProtocol {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserConnContext) ShouldSendError() bool {
|
||||||
|
errResp, exists := c.Query(ErrResp)
|
||||||
|
if exists {
|
||||||
|
b, err := strconv.ParseBool(errResp)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
} else {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (c *UserConnContext) SetToken(token string) {
|
func (c *UserConnContext) SetToken(token string) {
|
||||||
c.Req.URL.RawQuery = Token + "=" + token
|
c.Req.URL.RawQuery = Token + "=" + token
|
||||||
}
|
}
|
||||||
@ -146,3 +173,23 @@ func (c *UserConnContext) GetBackground() bool {
|
|||||||
}
|
}
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
func (c *UserConnContext) ParseEssentialArgs() error {
|
||||||
|
_, exists := c.Query(Token)
|
||||||
|
if !exists {
|
||||||
|
return servererrs.ErrConnArgsErr.WrapMsg("token is empty")
|
||||||
|
}
|
||||||
|
_, exists = c.Query(WsUserID)
|
||||||
|
if !exists {
|
||||||
|
return servererrs.ErrConnArgsErr.WrapMsg("sendID is empty")
|
||||||
|
}
|
||||||
|
platformIDStr, exists := c.Query(PlatformID)
|
||||||
|
if !exists {
|
||||||
|
return servererrs.ErrConnArgsErr.WrapMsg("platformID is empty")
|
||||||
|
}
|
||||||
|
_, err := strconv.Atoi(platformIDStr)
|
||||||
|
if err != nil {
|
||||||
|
return servererrs.ErrConnArgsErr.WrapMsg("platformID is not int")
|
||||||
|
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@ -15,6 +15,8 @@
|
|||||||
package msggateway
|
package msggateway
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"github.com/openimsdk/tools/apiresp"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -143,6 +145,24 @@ func (d *GWebSocket) SetPingHandler(handler PingPongHandler) {
|
|||||||
d.conn.SetPingHandler(handler)
|
d.conn.SetPingHandler(handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *GWebSocket) RespErrInfo(err error, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
if err := d.GenerateLongConn(w, r); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(apiresp.ParseError(err))
|
||||||
|
if err != nil {
|
||||||
|
_ = d.Close()
|
||||||
|
return errs.WrapMsg(err, "json marshal failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := d.WriteMessage(MessageText, data); err != nil {
|
||||||
|
_ = d.Close()
|
||||||
|
return errs.WrapMsg(err, "WriteMessage failed")
|
||||||
|
}
|
||||||
|
_ = d.Close()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// func (d *GWebSocket) CheckSendConnDiffNow() bool {
|
// func (d *GWebSocket) CheckSendConnDiffNow() bool {
|
||||||
// return d.conn == d.sendConn
|
// return d.conn == d.sendConn
|
||||||
//}
|
//}
|
||||||
|
|||||||
@ -16,23 +16,20 @@ package msggateway
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
pbAuth "github.com/openimsdk/protocol/auth"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-playground/validator/v10"
|
"github.com/go-playground/validator/v10"
|
||||||
"github.com/openimsdk/open-im-server/v3/pkg/authverify"
|
|
||||||
"github.com/openimsdk/open-im-server/v3/pkg/common/db/cache"
|
"github.com/openimsdk/open-im-server/v3/pkg/common/db/cache"
|
||||||
"github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics"
|
"github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics"
|
||||||
"github.com/openimsdk/open-im-server/v3/pkg/common/servererrs"
|
"github.com/openimsdk/open-im-server/v3/pkg/common/servererrs"
|
||||||
"github.com/openimsdk/open-im-server/v3/pkg/rpcclient"
|
"github.com/openimsdk/open-im-server/v3/pkg/rpcclient"
|
||||||
"github.com/openimsdk/protocol/constant"
|
"github.com/openimsdk/protocol/constant"
|
||||||
"github.com/openimsdk/protocol/msggateway"
|
"github.com/openimsdk/protocol/msggateway"
|
||||||
"github.com/openimsdk/tools/apiresp"
|
|
||||||
"github.com/openimsdk/tools/discovery"
|
"github.com/openimsdk/tools/discovery"
|
||||||
"github.com/openimsdk/tools/errs"
|
"github.com/openimsdk/tools/errs"
|
||||||
"github.com/openimsdk/tools/log"
|
"github.com/openimsdk/tools/log"
|
||||||
@ -73,6 +70,7 @@ type WsServer struct {
|
|||||||
validate *validator.Validate
|
validate *validator.Validate
|
||||||
cache cache.TokenModel
|
cache cache.TokenModel
|
||||||
userClient *rpcclient.UserRpcClient
|
userClient *rpcclient.UserRpcClient
|
||||||
|
authClient *rpcclient.Auth
|
||||||
disCov discovery.SvcDiscoveryRegistry
|
disCov discovery.SvcDiscoveryRegistry
|
||||||
Compressor
|
Compressor
|
||||||
Encoder
|
Encoder
|
||||||
@ -88,6 +86,7 @@ type kickHandler struct {
|
|||||||
func (ws *WsServer) SetDiscoveryRegistry(disCov discovery.SvcDiscoveryRegistry, config *Config) {
|
func (ws *WsServer) SetDiscoveryRegistry(disCov discovery.SvcDiscoveryRegistry, config *Config) {
|
||||||
ws.MessageHandler = NewGrpcHandler(ws.validate, disCov, &config.Share.RpcRegisterName)
|
ws.MessageHandler = NewGrpcHandler(ws.validate, disCov, &config.Share.RpcRegisterName)
|
||||||
u := rpcclient.NewUserRpcClient(disCov, config.Share.RpcRegisterName.User, config.Share.IMAdminUserID)
|
u := rpcclient.NewUserRpcClient(disCov, config.Share.RpcRegisterName.User, config.Share.IMAdminUserID)
|
||||||
|
ws.authClient = rpcclient.NewAuth(disCov, config.Share.RpcRegisterName.Auth)
|
||||||
ws.userClient = &u
|
ws.userClient = &u
|
||||||
ws.disCov = disCov
|
ws.disCov = disCov
|
||||||
}
|
}
|
||||||
@ -408,102 +407,54 @@ func (ws *WsServer) unregisterClient(client *Client) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ws *WsServer) ParseWSArgs(r *http.Request) (args *WSArgs, err error) {
|
// validateRespWithRequest checks if the response matches the expected userID and platformID.
|
||||||
var v WSArgs
|
func (ws *WsServer) validateRespWithRequest(ctx *UserConnContext, resp *pbAuth.ParseTokenResp) error {
|
||||||
defer func() {
|
userID := ctx.GetUserID()
|
||||||
args = &v
|
platformID := stringutil.StringToInt32(ctx.GetPlatformID())
|
||||||
}()
|
if resp.UserID != userID {
|
||||||
query := r.URL.Query()
|
return servererrs.ErrTokenInvalid.WrapMsg(fmt.Sprintf("token uid %s != userID %s", resp.UserID, userID))
|
||||||
v.MsgResp, _ = strconv.ParseBool(query.Get(MsgResp))
|
|
||||||
if ws.onlineUserConnNum.Load() >= ws.wsMaxConnNum {
|
|
||||||
return nil, servererrs.ErrConnOverMaxNumLimit.WrapMsg("over max conn num limit")
|
|
||||||
}
|
}
|
||||||
if v.Token = query.Get(Token); v.Token == "" {
|
if resp.PlatformID != platformID {
|
||||||
return nil, servererrs.ErrConnArgsErr.WrapMsg("token is empty")
|
return servererrs.ErrTokenInvalid.WrapMsg(fmt.Sprintf("token platform %d != platformID %d", resp.PlatformID, platformID))
|
||||||
}
|
}
|
||||||
if v.UserID = query.Get(WsUserID); v.UserID == "" {
|
return nil
|
||||||
return nil, servererrs.ErrConnArgsErr.WrapMsg("sendID is empty")
|
|
||||||
}
|
|
||||||
platformIDStr := query.Get(PlatformID)
|
|
||||||
if platformIDStr == "" {
|
|
||||||
return nil, servererrs.ErrConnArgsErr.WrapMsg("platformID is empty")
|
|
||||||
}
|
|
||||||
platformID, err := strconv.Atoi(platformIDStr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, servererrs.ErrConnArgsErr.WrapMsg("platformID is not int")
|
|
||||||
}
|
|
||||||
v.PlatformID = platformID
|
|
||||||
if err = authverify.WsVerifyToken(v.Token, v.UserID, ws.msgGatewayConfig.Share.Secret, platformID); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if query.Get(Compression) == GzipCompressionProtocol {
|
|
||||||
v.Compression = true
|
|
||||||
}
|
|
||||||
if r.Header.Get(Compression) == GzipCompressionProtocol {
|
|
||||||
v.Compression = true
|
|
||||||
}
|
|
||||||
m, err := ws.cache.GetTokensWithoutError(context.Background(), v.UserID, platformID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if v, ok := m[v.Token]; ok {
|
|
||||||
switch v {
|
|
||||||
case constant.NormalToken:
|
|
||||||
case constant.KickedToken:
|
|
||||||
return nil, servererrs.ErrTokenKicked.Wrap()
|
|
||||||
default:
|
|
||||||
return nil, servererrs.ErrTokenUnknown.WrapMsg(fmt.Sprintf("token status is %d", v))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return nil, servererrs.ErrTokenNotExist.Wrap()
|
|
||||||
}
|
|
||||||
return &v, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type WSArgs struct {
|
|
||||||
Token string
|
|
||||||
UserID string
|
|
||||||
PlatformID int
|
|
||||||
Compression bool
|
|
||||||
MsgResp bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
|
func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
connContext := newContext(w, r)
|
connContext := newContext(w, r)
|
||||||
args, pErr := ws.ParseWSArgs(r)
|
if ws.onlineUserConnNum.Load() >= ws.wsMaxConnNum {
|
||||||
var wsLongConn *GWebSocket
|
httpError(connContext, servererrs.ErrConnOverMaxNumLimit.WrapMsg("over max conn num limit"))
|
||||||
if args.MsgResp {
|
|
||||||
wsLongConn = newGWebSocket(WebSocket, ws.handshakeTimeout, ws.writeBufferSize)
|
|
||||||
if err := wsLongConn.GenerateLongConn(w, r); err != nil {
|
|
||||||
httpError(connContext, err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
data, err := json.Marshal(apiresp.ParseError(pErr))
|
err := connContext.ParseEssentialArgs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = wsLongConn.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := wsLongConn.WriteMessage(MessageText, data); err != nil {
|
|
||||||
_ = wsLongConn.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if pErr != nil {
|
|
||||||
_ = wsLongConn.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if pErr != nil {
|
|
||||||
httpError(connContext, pErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
wsLongConn = newGWebSocket(WebSocket, ws.handshakeTimeout, ws.writeBufferSize)
|
|
||||||
if err := wsLongConn.GenerateLongConn(w, r); err != nil {
|
|
||||||
httpError(connContext, err)
|
httpError(connContext, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
resp, err := ws.authClient.ParseToken(connContext, connContext.GetToken())
|
||||||
|
if err != nil {
|
||||||
|
shouldSendError := connContext.ShouldSendError()
|
||||||
|
if shouldSendError {
|
||||||
|
wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout, ws.writeBufferSize)
|
||||||
|
if err := wsLongConn.RespErrInfo(err, w, r); err == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
httpError(connContext, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = ws.validateRespWithRequest(connContext, resp)
|
||||||
|
if err != nil {
|
||||||
|
httpError(connContext, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout, ws.writeBufferSize)
|
||||||
|
if err := wsLongConn.GenerateLongConn(w, r); err != nil {
|
||||||
|
httpError(connContext, err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
client := ws.clientPool.Get().(*Client)
|
client := ws.clientPool.Get().(*Client)
|
||||||
client.ResetClient(connContext, wsLongConn, connContext.GetBackground(), args.Compression, ws, args.Token)
|
client.ResetClient(connContext, wsLongConn, ws)
|
||||||
ws.registerChan <- client
|
ws.registerChan <- client
|
||||||
go client.readMessage()
|
go client.readMessage()
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user