mirror of
https://github.com/openimsdk/open-im-server.git
synced 2025-12-24 08:07:00 +08:00
refactor: replace LongConn with ClientConn interface and simplify message handling
(cherry picked from commit a1dd79a4592f1be55ad0c96fece250df7e185002)
This commit is contained in:
parent
dd5ff6f0e4
commit
efa709c267
@ -16,7 +16,6 @@ package msggateway
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@ -31,7 +30,6 @@ import (
|
|||||||
"github.com/openimsdk/tools/errs"
|
"github.com/openimsdk/tools/errs"
|
||||||
"github.com/openimsdk/tools/log"
|
"github.com/openimsdk/tools/log"
|
||||||
"github.com/openimsdk/tools/mcontext"
|
"github.com/openimsdk/tools/mcontext"
|
||||||
"github.com/openimsdk/tools/utils/stringutil"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -64,13 +62,12 @@ type PingPongHandler func(string) error
|
|||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
w *sync.Mutex
|
w *sync.Mutex
|
||||||
conn LongConn
|
conn ClientConn
|
||||||
PlatformID int `json:"platformID"`
|
PlatformID int `json:"platformID"`
|
||||||
IsCompress bool `json:"isCompress"`
|
IsCompress bool `json:"isCompress"`
|
||||||
UserID string `json:"userID"`
|
UserID string `json:"userID"`
|
||||||
IsBackground bool `json:"isBackground"`
|
IsBackground bool `json:"isBackground"`
|
||||||
SDKType string `json:"sdkType"`
|
SDKType string `json:"sdkType"`
|
||||||
SDKVersion string `json:"sdkVersion"`
|
|
||||||
Encoder Encoder
|
Encoder Encoder
|
||||||
ctx *UserConnContext
|
ctx *UserConnContext
|
||||||
longConnServer LongConnServer
|
longConnServer LongConnServer
|
||||||
@ -84,10 +81,10 @@ type Client struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ResetClient updates the client's state with new connection and context information.
|
// ResetClient updates the client's state with new connection and context information.
|
||||||
func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, longConnServer LongConnServer) {
|
func (c *Client) ResetClient(ctx *UserConnContext, conn ClientConn, longConnServer LongConnServer) {
|
||||||
c.w = new(sync.Mutex)
|
c.w = new(sync.Mutex)
|
||||||
c.conn = conn
|
c.conn = conn
|
||||||
c.PlatformID = stringutil.StringToInt(ctx.GetPlatformID())
|
c.PlatformID = ctx.GetPlatformID()
|
||||||
c.IsCompress = ctx.GetCompression()
|
c.IsCompress = ctx.GetCompression()
|
||||||
c.IsBackground = ctx.GetBackground()
|
c.IsBackground = ctx.GetBackground()
|
||||||
c.UserID = ctx.GetUserID()
|
c.UserID = ctx.GetUserID()
|
||||||
@ -98,7 +95,6 @@ func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, longConnServer
|
|||||||
c.closedErr = nil
|
c.closedErr = nil
|
||||||
c.token = ctx.GetToken()
|
c.token = ctx.GetToken()
|
||||||
c.SDKType = ctx.GetSDKType()
|
c.SDKType = ctx.GetSDKType()
|
||||||
c.SDKVersion = ctx.GetSDKVersion()
|
|
||||||
c.hbCtx, c.hbCancel = context.WithCancel(c.ctx)
|
c.hbCtx, c.hbCancel = context.WithCancel(c.ctx)
|
||||||
c.subLock = new(sync.Mutex)
|
c.subLock = new(sync.Mutex)
|
||||||
if c.subUserIDs != nil {
|
if c.subUserIDs != nil {
|
||||||
@ -112,22 +108,6 @@ func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, longConnServer
|
|||||||
c.subUserIDs = make(map[string]struct{})
|
c.subUserIDs = make(map[string]struct{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) pingHandler(appData string) error {
|
|
||||||
if err := c.conn.SetReadDeadline(pongWait); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.ZDebug(c.ctx, "ping Handler Success.", "appData", appData)
|
|
||||||
return c.writePongMsg(appData)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) pongHandler(_ string) error {
|
|
||||||
if err := c.conn.SetReadDeadline(pongWait); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// readMessage continuously reads messages from the connection.
|
// readMessage continuously reads messages from the connection.
|
||||||
func (c *Client) readMessage() {
|
func (c *Client) readMessage() {
|
||||||
defer func() {
|
defer func() {
|
||||||
@ -138,52 +118,25 @@ func (c *Client) readMessage() {
|
|||||||
c.close()
|
c.close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
c.conn.SetReadLimit(maxMessageSize)
|
|
||||||
_ = c.conn.SetReadDeadline(pongWait)
|
|
||||||
c.conn.SetPongHandler(c.pongHandler)
|
|
||||||
c.conn.SetPingHandler(c.pingHandler)
|
|
||||||
c.activeHeartbeat(c.hbCtx)
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
log.ZDebug(c.ctx, "readMessage")
|
log.ZDebug(c.ctx, "readMessage")
|
||||||
messageType, message, returnErr := c.conn.ReadMessage()
|
message, returnErr := c.conn.ReadMessage()
|
||||||
if returnErr != nil {
|
if returnErr != nil {
|
||||||
log.ZWarn(c.ctx, "readMessage", returnErr, "messageType", messageType)
|
log.ZWarn(c.ctx, "readMessage", returnErr)
|
||||||
c.closedErr = returnErr
|
c.closedErr = returnErr
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.ZDebug(c.ctx, "readMessage", "messageType", messageType)
|
|
||||||
if c.closed.Load() {
|
if c.closed.Load() {
|
||||||
// The scenario where the connection has just been closed, but the coroutine has not exited
|
// The scenario where the connection has just been closed, but the coroutine has not exited
|
||||||
c.closedErr = ErrConnClosed
|
c.closedErr = ErrConnClosed
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
switch messageType {
|
parseDataErr := c.handleMessage(message)
|
||||||
case MessageBinary:
|
if parseDataErr != nil {
|
||||||
_ = c.conn.SetReadDeadline(pongWait)
|
c.closedErr = parseDataErr
|
||||||
parseDataErr := c.handleMessage(message)
|
|
||||||
if parseDataErr != nil {
|
|
||||||
c.closedErr = parseDataErr
|
|
||||||
return
|
|
||||||
}
|
|
||||||
case MessageText:
|
|
||||||
_ = c.conn.SetReadDeadline(pongWait)
|
|
||||||
parseDataErr := c.handlerTextMessage(message)
|
|
||||||
if parseDataErr != nil {
|
|
||||||
c.closedErr = parseDataErr
|
|
||||||
return
|
|
||||||
}
|
|
||||||
case PingMessage:
|
|
||||||
err := c.writePongMsg("")
|
|
||||||
log.ZError(c.ctx, "writePongMsg", err)
|
|
||||||
|
|
||||||
case CloseMessage:
|
|
||||||
c.closedErr = ErrClientClosed
|
|
||||||
return
|
return
|
||||||
|
|
||||||
default:
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -358,109 +311,13 @@ func (c *Client) writeBinaryMsg(resp Resp) error {
|
|||||||
c.w.Lock()
|
c.w.Lock()
|
||||||
defer c.w.Unlock()
|
defer c.w.Unlock()
|
||||||
|
|
||||||
err = c.conn.SetWriteDeadline(writeWait)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.IsCompress {
|
if c.IsCompress {
|
||||||
resultBuf, compressErr := c.longConnServer.CompressWithPool(encodedBuf)
|
resultBuf, compressErr := c.longConnServer.CompressWithPool(encodedBuf)
|
||||||
if compressErr != nil {
|
if compressErr != nil {
|
||||||
return compressErr
|
return compressErr
|
||||||
}
|
}
|
||||||
return c.conn.WriteMessage(MessageBinary, resultBuf)
|
return c.conn.WriteMessage(resultBuf)
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.conn.WriteMessage(MessageBinary, encodedBuf)
|
return c.conn.WriteMessage(encodedBuf)
|
||||||
}
|
|
||||||
|
|
||||||
// Actively initiate Heartbeat when platform in Web.
|
|
||||||
func (c *Client) activeHeartbeat(ctx context.Context) {
|
|
||||||
if c.PlatformID == constant.WebPlatformID {
|
|
||||||
go func() {
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
log.ZPanic(ctx, "activeHeartbeat Panic", errs.ErrPanic(r))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
log.ZDebug(ctx, "server initiative send heartbeat start.")
|
|
||||||
ticker := time.NewTicker(pingPeriod)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ticker.C:
|
|
||||||
if err := c.writePingMsg(); err != nil {
|
|
||||||
log.ZWarn(c.ctx, "send Ping Message error.", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
case <-c.hbCtx.Done():
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
func (c *Client) writePingMsg() error {
|
|
||||||
if c.closed.Load() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
c.w.Lock()
|
|
||||||
defer c.w.Unlock()
|
|
||||||
|
|
||||||
err := c.conn.SetWriteDeadline(writeWait)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.conn.WriteMessage(PingMessage, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) writePongMsg(appData string) error {
|
|
||||||
log.ZDebug(c.ctx, "write Pong Msg in Server", "appData", appData)
|
|
||||||
if c.closed.Load() {
|
|
||||||
log.ZWarn(c.ctx, "is closed in server", nil, "appdata", appData, "closed err", c.closedErr)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
c.w.Lock()
|
|
||||||
defer c.w.Unlock()
|
|
||||||
|
|
||||||
err := c.conn.SetWriteDeadline(writeWait)
|
|
||||||
if err != nil {
|
|
||||||
log.ZWarn(c.ctx, "SetWriteDeadline in Server have error", errs.Wrap(err), "writeWait", writeWait, "appData", appData)
|
|
||||||
return errs.Wrap(err)
|
|
||||||
}
|
|
||||||
err = c.conn.WriteMessage(PongMessage, []byte(appData))
|
|
||||||
if err != nil {
|
|
||||||
log.ZWarn(c.ctx, "Write Message have error", errs.Wrap(err), "Pong msg", PongMessage)
|
|
||||||
}
|
|
||||||
|
|
||||||
return errs.Wrap(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) handlerTextMessage(b []byte) error {
|
|
||||||
var msg TextMessage
|
|
||||||
if err := json.Unmarshal(b, &msg); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
switch msg.Type {
|
|
||||||
case TextPong:
|
|
||||||
return nil
|
|
||||||
case TextPing:
|
|
||||||
msg.Type = TextPong
|
|
||||||
msgData, err := json.Marshal(msg)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.w.Lock()
|
|
||||||
defer c.w.Unlock()
|
|
||||||
if err := c.conn.SetWriteDeadline(writeWait); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return c.conn.WriteMessage(MessageText, msgData)
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("not support message type %s", msg.Type)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
229
internal/msggateway/client_conn.go
Normal file
229
internal/msggateway/client_conn.go
Normal file
@ -0,0 +1,229 @@
|
|||||||
|
package msggateway
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
|
||||||
|
"github.com/openimsdk/tools/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrWriteFull = fmt.Errorf("websocket write buffer full,close connection")
|
||||||
|
|
||||||
|
type ClientConn interface {
|
||||||
|
ReadMessage() ([]byte, error)
|
||||||
|
WriteMessage(message []byte) error
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
type websocketMessage struct {
|
||||||
|
MessageType int
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewWebSocketClientConn(conn *websocket.Conn, readLimit int64, readTimeout time.Duration, pingInterval time.Duration) ClientConn {
|
||||||
|
c := &websocketClientConn{
|
||||||
|
readTimeout: readTimeout,
|
||||||
|
conn: conn,
|
||||||
|
writer: make(chan *websocketMessage, 256),
|
||||||
|
done: make(chan struct{}),
|
||||||
|
}
|
||||||
|
if readLimit > 0 {
|
||||||
|
c.conn.SetReadLimit(readLimit)
|
||||||
|
}
|
||||||
|
c.conn.SetPingHandler(c.pingHandler)
|
||||||
|
c.conn.SetPongHandler(c.pongHandler)
|
||||||
|
|
||||||
|
go c.loopSend()
|
||||||
|
if pingInterval > 0 {
|
||||||
|
go c.doPing(pingInterval)
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
type websocketClientConn struct {
|
||||||
|
readTimeout time.Duration
|
||||||
|
conn *websocket.Conn
|
||||||
|
writer chan *websocketMessage
|
||||||
|
done chan struct{}
|
||||||
|
err atomic.Pointer[error]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *websocketClientConn) ReadMessage() ([]byte, error) {
|
||||||
|
buf, err := c.readMessage()
|
||||||
|
if err != nil {
|
||||||
|
return nil, c.closeBy(fmt.Errorf("read message %w", err))
|
||||||
|
}
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *websocketClientConn) WriteMessage(message []byte) error {
|
||||||
|
return c.writeMessage(websocket.BinaryMessage, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *websocketClientConn) Close() error {
|
||||||
|
return c.closeBy(fmt.Errorf("websocket connection closed"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *websocketClientConn) closeBy(err error) error {
|
||||||
|
if !c.err.CompareAndSwap(nil, &err) {
|
||||||
|
return *c.err.Load()
|
||||||
|
}
|
||||||
|
close(c.done)
|
||||||
|
log.ZWarn(context.Background(), "websocket connection closed", err, "remoteAddr", c.conn.RemoteAddr(),
|
||||||
|
"chan length", len(c.writer))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *websocketClientConn) writeMessage(messageType int, data []byte) error {
|
||||||
|
if errPtr := c.err.Load(); errPtr != nil {
|
||||||
|
return *errPtr
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case c.writer <- &websocketMessage{MessageType: messageType, Data: data}:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return c.closeBy(ErrWriteFull)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *websocketClientConn) loopSend() {
|
||||||
|
defer func() {
|
||||||
|
_ = c.conn.Close()
|
||||||
|
}()
|
||||||
|
var err error
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case msg := <-c.writer:
|
||||||
|
switch msg.MessageType {
|
||||||
|
case websocket.TextMessage, websocket.BinaryMessage:
|
||||||
|
err = c.conn.WriteMessage(msg.MessageType, msg.Data)
|
||||||
|
default:
|
||||||
|
err = c.conn.WriteControl(msg.MessageType, msg.Data, time.Time{})
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
_ = c.closeBy(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case msg := <-c.writer:
|
||||||
|
switch msg.MessageType {
|
||||||
|
case websocket.TextMessage, websocket.BinaryMessage:
|
||||||
|
err = c.conn.WriteMessage(msg.MessageType, msg.Data)
|
||||||
|
default:
|
||||||
|
err = c.conn.WriteControl(msg.MessageType, msg.Data, time.Time{})
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
_ = c.closeBy(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *websocketClientConn) setReadDeadline() error {
|
||||||
|
deadline := time.Now().Add(c.readTimeout)
|
||||||
|
return c.conn.SetReadDeadline(deadline)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *websocketClientConn) readMessage() ([]byte, error) {
|
||||||
|
for {
|
||||||
|
if err := c.setReadDeadline(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
messageType, buf, err := c.conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
switch messageType {
|
||||||
|
case websocket.BinaryMessage:
|
||||||
|
return buf, nil
|
||||||
|
case websocket.TextMessage:
|
||||||
|
if err := c.onReadTextMessage(buf); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
case websocket.PingMessage:
|
||||||
|
if err := c.pingHandler(string(buf)); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
case websocket.PongMessage:
|
||||||
|
if err := c.pongHandler(string(buf)); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
case websocket.CloseMessage:
|
||||||
|
if len(buf) == 0 {
|
||||||
|
return nil, errors.New("websocket connection closed by peer")
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("websocket connection closed by peer, data %s", string(buf))
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown websocket message type %d", messageType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *websocketClientConn) onReadTextMessage(buf []byte) error {
|
||||||
|
var msg struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Body json.RawMessage `json:"body"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(buf, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
switch msg.Type {
|
||||||
|
case TextPong:
|
||||||
|
return nil
|
||||||
|
case TextPing:
|
||||||
|
msg.Type = TextPong
|
||||||
|
msgData, err := json.Marshal(msg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.writeMessage(websocket.TextMessage, msgData)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("not support text message type %s", msg.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *websocketClientConn) pingHandler(appData string) error {
|
||||||
|
log.ZDebug(context.Background(), "ping handler recv ping", "remoteAddr", c.conn.RemoteAddr(), "appData", appData)
|
||||||
|
if err := c.setReadDeadline(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err := c.conn.WriteControl(websocket.PongMessage, []byte(appData), time.Now().Add(time.Second*1))
|
||||||
|
if err != nil {
|
||||||
|
log.ZWarn(context.Background(), "ping handler write pong error", err, "remoteAddr", c.conn.RemoteAddr(), "appData", appData)
|
||||||
|
}
|
||||||
|
log.ZDebug(context.Background(), "ping handler write pong success", "remoteAddr", c.conn.RemoteAddr(), "appData", appData)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *websocketClientConn) pongHandler(string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *websocketClientConn) doPing(d time.Duration) {
|
||||||
|
ticker := time.NewTicker(d)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
if err := c.writeMessage(websocket.PingMessage, nil); err != nil {
|
||||||
|
_ = c.closeBy(fmt.Errorf("send ping %w", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -15,6 +15,8 @@
|
|||||||
package msggateway
|
package msggateway
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -24,10 +26,20 @@ import (
|
|||||||
|
|
||||||
"github.com/openimsdk/protocol/constant"
|
"github.com/openimsdk/protocol/constant"
|
||||||
"github.com/openimsdk/tools/utils/encrypt"
|
"github.com/openimsdk/tools/utils/encrypt"
|
||||||
"github.com/openimsdk/tools/utils/stringutil"
|
|
||||||
"github.com/openimsdk/tools/utils/timeutil"
|
"github.com/openimsdk/tools/utils/timeutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type UserConnContextInfo struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
UserID string `json:"userID"`
|
||||||
|
PlatformID int `json:"platformID"`
|
||||||
|
OperationID string `json:"operationID"`
|
||||||
|
Compression string `json:"compression"`
|
||||||
|
SDKType string `json:"sdkType"`
|
||||||
|
SendResponse bool `json:"sendResponse"`
|
||||||
|
Background bool `json:"background"`
|
||||||
|
}
|
||||||
|
|
||||||
type UserConnContext struct {
|
type UserConnContext struct {
|
||||||
RespWriter http.ResponseWriter
|
RespWriter http.ResponseWriter
|
||||||
Req *http.Request
|
Req *http.Request
|
||||||
@ -35,6 +47,7 @@ type UserConnContext struct {
|
|||||||
Method string
|
Method string
|
||||||
RemoteAddr string
|
RemoteAddr string
|
||||||
ConnID string
|
ConnID string
|
||||||
|
info *UserConnContextInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *UserConnContext) Deadline() (deadline time.Time, ok bool) {
|
func (c *UserConnContext) Deadline() (deadline time.Time, ok bool) {
|
||||||
@ -58,7 +71,7 @@ func (c *UserConnContext) Value(key any) any {
|
|||||||
case constant.ConnID:
|
case constant.ConnID:
|
||||||
return c.GetConnID()
|
return c.GetConnID()
|
||||||
case constant.OpUserPlatform:
|
case constant.OpUserPlatform:
|
||||||
return constant.PlatformIDToName(stringutil.StringToInt(c.GetPlatformID()))
|
return c.GetPlatformID()
|
||||||
case constant.RemoteAddr:
|
case constant.RemoteAddr:
|
||||||
return c.RemoteAddr
|
return c.RemoteAddr
|
||||||
default:
|
default:
|
||||||
@ -83,30 +96,91 @@ func newContext(respWriter http.ResponseWriter, req *http.Request) *UserConnCont
|
|||||||
|
|
||||||
func newTempContext() *UserConnContext {
|
func newTempContext() *UserConnContext {
|
||||||
return &UserConnContext{
|
return &UserConnContext{
|
||||||
Req: &http.Request{URL: &url.URL{}},
|
Req: &http.Request{URL: &url.URL{}},
|
||||||
|
info: &UserConnContextInfo{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *UserConnContext) ParseEssentialArgs() error {
|
||||||
|
query := c.Req.URL.Query()
|
||||||
|
if data := query.Get("v"); data != "" {
|
||||||
|
return c.parseByJson(data)
|
||||||
|
} else {
|
||||||
|
return c.parseByQuery(query, c.Req.Header)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserConnContext) parseByQuery(query url.Values, header http.Header) error {
|
||||||
|
info := UserConnContextInfo{
|
||||||
|
Token: query.Get(Token),
|
||||||
|
UserID: query.Get(WsUserID),
|
||||||
|
OperationID: query.Get(OperationID),
|
||||||
|
Compression: query.Get(Compression),
|
||||||
|
SDKType: query.Get(SDKType),
|
||||||
|
}
|
||||||
|
platformID, err := strconv.Atoi(query.Get(PlatformID))
|
||||||
|
if err != nil {
|
||||||
|
return servererrs.ErrConnArgsErr.WrapMsg("platformID is not int")
|
||||||
|
}
|
||||||
|
info.PlatformID = platformID
|
||||||
|
if val := query.Get(SendResponse); val != "" {
|
||||||
|
ok, err := strconv.ParseBool(val)
|
||||||
|
if err != nil {
|
||||||
|
return servererrs.ErrConnArgsErr.WrapMsg("isMsgResp is not bool")
|
||||||
|
}
|
||||||
|
info.SendResponse = ok
|
||||||
|
}
|
||||||
|
if info.Compression == "" {
|
||||||
|
info.Compression = header.Get(Compression)
|
||||||
|
}
|
||||||
|
background, err := strconv.ParseBool(query.Get(BackgroundStatus))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
info.Background = background
|
||||||
|
return c.checkInfo(&info)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserConnContext) parseByJson(data string) error {
|
||||||
|
reqInfo, err := base64.RawURLEncoding.DecodeString(data)
|
||||||
|
if err != nil {
|
||||||
|
return servererrs.ErrConnArgsErr.WrapMsg("data is not base64")
|
||||||
|
}
|
||||||
|
var info UserConnContextInfo
|
||||||
|
if err := json.Unmarshal(reqInfo, &info); err != nil {
|
||||||
|
return servererrs.ErrConnArgsErr.WrapMsg("data is not json", "info", err.Error())
|
||||||
|
}
|
||||||
|
return c.checkInfo(&info)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserConnContext) checkInfo(info *UserConnContextInfo) error {
|
||||||
|
if info.OperationID == "" {
|
||||||
|
return servererrs.ErrConnArgsErr.WrapMsg("operationID is empty")
|
||||||
|
}
|
||||||
|
if info.Token == "" {
|
||||||
|
return servererrs.ErrConnArgsErr.WrapMsg("token is empty")
|
||||||
|
}
|
||||||
|
if info.UserID == "" {
|
||||||
|
return servererrs.ErrConnArgsErr.WrapMsg("sendID is empty")
|
||||||
|
}
|
||||||
|
if _, ok := constant.PlatformID2Name[info.PlatformID]; !ok {
|
||||||
|
return servererrs.ErrConnArgsErr.WrapMsg("platformID is invalid")
|
||||||
|
}
|
||||||
|
switch info.SDKType {
|
||||||
|
case "":
|
||||||
|
info.SDKType = GoSDK
|
||||||
|
case GoSDK, JsSDK:
|
||||||
|
default:
|
||||||
|
return servererrs.ErrConnArgsErr.WrapMsg("sdkType is invalid")
|
||||||
|
}
|
||||||
|
c.info = info
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *UserConnContext) GetRemoteAddr() string {
|
func (c *UserConnContext) GetRemoteAddr() string {
|
||||||
return c.RemoteAddr
|
return c.RemoteAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *UserConnContext) Query(key string) (string, bool) {
|
|
||||||
var value string
|
|
||||||
if value = c.Req.URL.Query().Get(key); value == "" {
|
|
||||||
return value, false
|
|
||||||
}
|
|
||||||
return value, true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *UserConnContext) GetHeader(key string) (string, bool) {
|
|
||||||
var value string
|
|
||||||
if value = c.Req.Header.Get(key); value == "" {
|
|
||||||
return value, false
|
|
||||||
}
|
|
||||||
return value, true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *UserConnContext) SetHeader(key, value string) {
|
func (c *UserConnContext) SetHeader(key, value string) {
|
||||||
c.RespWriter.Header().Set(key, value)
|
c.RespWriter.Header().Set(key, value)
|
||||||
}
|
}
|
||||||
@ -120,97 +194,69 @@ func (c *UserConnContext) GetConnID() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *UserConnContext) GetUserID() string {
|
func (c *UserConnContext) GetUserID() string {
|
||||||
return c.Req.URL.Query().Get(WsUserID)
|
if c == nil || c.info == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return c.info.UserID
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *UserConnContext) GetPlatformID() string {
|
func (c *UserConnContext) GetPlatformID() int {
|
||||||
return c.Req.URL.Query().Get(PlatformID)
|
if c == nil || c.info == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return c.info.PlatformID
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *UserConnContext) GetOperationID() string {
|
func (c *UserConnContext) GetOperationID() string {
|
||||||
return c.Req.URL.Query().Get(OperationID)
|
if c == nil || c.info == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return c.info.OperationID
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *UserConnContext) SetOperationID(operationID string) {
|
func (c *UserConnContext) SetOperationID(operationID string) {
|
||||||
values := c.Req.URL.Query()
|
if c.info == nil {
|
||||||
values.Set(OperationID, operationID)
|
c.info = &UserConnContextInfo{}
|
||||||
c.Req.URL.RawQuery = values.Encode()
|
}
|
||||||
|
c.info.OperationID = operationID
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *UserConnContext) GetToken() string {
|
func (c *UserConnContext) GetToken() string {
|
||||||
return c.Req.URL.Query().Get(Token)
|
if c == nil || c.info == nil {
|
||||||
}
|
return ""
|
||||||
|
}
|
||||||
func (c *UserConnContext) GetSDKVersion() string {
|
return c.info.Token
|
||||||
return c.Req.URL.Query().Get(SDKVersion)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *UserConnContext) GetCompression() bool {
|
func (c *UserConnContext) GetCompression() bool {
|
||||||
compression, exists := c.Query(Compression)
|
return c != nil && c.info != nil && c.info.Compression == GzipCompressionProtocol
|
||||||
if exists && compression == GzipCompressionProtocol {
|
|
||||||
return true
|
|
||||||
} else {
|
|
||||||
compression, exists := c.GetHeader(Compression)
|
|
||||||
if exists && compression == GzipCompressionProtocol {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *UserConnContext) GetSDKType() string {
|
func (c *UserConnContext) GetSDKType() string {
|
||||||
sdkType := c.Req.URL.Query().Get(SDKType)
|
if c == nil || c.info == nil {
|
||||||
if sdkType == "" {
|
return GoSDK
|
||||||
sdkType = GoSDK
|
}
|
||||||
|
switch c.info.SDKType {
|
||||||
|
case "", GoSDK:
|
||||||
|
return GoSDK
|
||||||
|
case JsSDK:
|
||||||
|
return JsSDK
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
}
|
}
|
||||||
return sdkType
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *UserConnContext) ShouldSendResp() bool {
|
func (c *UserConnContext) ShouldSendResp() bool {
|
||||||
errResp, exists := c.Query(SendResponse)
|
return c != nil && c.info != nil && c.info.SendResponse
|
||||||
if exists {
|
|
||||||
b, err := strconv.ParseBool(errResp)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
} else {
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *UserConnContext) SetToken(token string) {
|
func (c *UserConnContext) SetToken(token string) {
|
||||||
c.Req.URL.RawQuery = Token + "=" + token
|
if c.info == nil {
|
||||||
|
c.info = &UserConnContextInfo{}
|
||||||
|
}
|
||||||
|
c.info.Token = token
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *UserConnContext) GetBackground() bool {
|
func (c *UserConnContext) GetBackground() bool {
|
||||||
b, err := strconv.ParseBool(c.Req.URL.Query().Get(BackgroundStatus))
|
return c != nil && c.info != nil && c.info.Background
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
func (c *UserConnContext) ParseEssentialArgs() error {
|
|
||||||
_, exists := c.Query(Token)
|
|
||||||
if !exists {
|
|
||||||
return servererrs.ErrConnArgsErr.WrapMsg("token is empty")
|
|
||||||
}
|
|
||||||
_, exists = c.Query(WsUserID)
|
|
||||||
if !exists {
|
|
||||||
return servererrs.ErrConnArgsErr.WrapMsg("sendID is empty")
|
|
||||||
}
|
|
||||||
platformIDStr, exists := c.Query(PlatformID)
|
|
||||||
if !exists {
|
|
||||||
return servererrs.ErrConnArgsErr.WrapMsg("platformID is empty")
|
|
||||||
}
|
|
||||||
_, err := strconv.Atoi(platformIDStr)
|
|
||||||
if err != nil {
|
|
||||||
return servererrs.ErrConnArgsErr.WrapMsg("platformID is not int")
|
|
||||||
}
|
|
||||||
switch sdkType, _ := c.Query(SDKType); sdkType {
|
|
||||||
case "", GoSDK, JsSDK:
|
|
||||||
default:
|
|
||||||
return servererrs.ErrConnArgsErr.WrapMsg("sdkType is not go or js")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,179 +0,0 @@
|
|||||||
// Copyright © 2023 OpenIM. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package msggateway
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/openimsdk/tools/apiresp"
|
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
"github.com/openimsdk/tools/errs"
|
|
||||||
)
|
|
||||||
|
|
||||||
type LongConn interface {
|
|
||||||
// Close this connection
|
|
||||||
Close() error
|
|
||||||
// WriteMessage Write message to connection,messageType means data type,can be set binary(2) and text(1).
|
|
||||||
WriteMessage(messageType int, message []byte) error
|
|
||||||
// ReadMessage Read message from connection.
|
|
||||||
ReadMessage() (int, []byte, error)
|
|
||||||
// SetReadDeadline sets the read deadline on the underlying network connection,
|
|
||||||
// after a read has timed out, will return an error.
|
|
||||||
SetReadDeadline(timeout time.Duration) error
|
|
||||||
// SetWriteDeadline sets to write deadline when send message,when read has timed out,will return error.
|
|
||||||
SetWriteDeadline(timeout time.Duration) error
|
|
||||||
// Dial Try to dial a connection,url must set auth args,header can control compress data
|
|
||||||
Dial(urlStr string, requestHeader http.Header) (*http.Response, error)
|
|
||||||
// IsNil Whether the connection of the current long connection is nil
|
|
||||||
IsNil() bool
|
|
||||||
// SetConnNil Set the connection of the current long connection to nil
|
|
||||||
SetConnNil()
|
|
||||||
// SetReadLimit sets the maximum size for a message read from the peer.bytes
|
|
||||||
SetReadLimit(limit int64)
|
|
||||||
SetPongHandler(handler PingPongHandler)
|
|
||||||
SetPingHandler(handler PingPongHandler)
|
|
||||||
// GenerateLongConn Check the connection of the current and when it was sent are the same
|
|
||||||
GenerateLongConn(w http.ResponseWriter, r *http.Request) error
|
|
||||||
}
|
|
||||||
type GWebSocket struct {
|
|
||||||
protocolType int
|
|
||||||
conn *websocket.Conn
|
|
||||||
handshakeTimeout time.Duration
|
|
||||||
writeBufferSize int
|
|
||||||
}
|
|
||||||
|
|
||||||
func newGWebSocket(protocolType int, handshakeTimeout time.Duration, wbs int) *GWebSocket {
|
|
||||||
return &GWebSocket{protocolType: protocolType, handshakeTimeout: handshakeTimeout, writeBufferSize: wbs}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *GWebSocket) Close() error {
|
|
||||||
return d.conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *GWebSocket) GenerateLongConn(w http.ResponseWriter, r *http.Request) error {
|
|
||||||
upgrader := &websocket.Upgrader{
|
|
||||||
HandshakeTimeout: d.handshakeTimeout,
|
|
||||||
CheckOrigin: func(r *http.Request) bool { return true },
|
|
||||||
}
|
|
||||||
if d.writeBufferSize > 0 { // default is 4kb.
|
|
||||||
upgrader.WriteBufferSize = d.writeBufferSize
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := upgrader.Upgrade(w, r, nil)
|
|
||||||
if err != nil {
|
|
||||||
// The upgrader.Upgrade method usually returns enough error messages to diagnose problems that may occur during the upgrade
|
|
||||||
return errs.WrapMsg(err, "GenerateLongConn: WebSocket upgrade failed")
|
|
||||||
}
|
|
||||||
d.conn = conn
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *GWebSocket) WriteMessage(messageType int, message []byte) error {
|
|
||||||
// d.setSendConn(d.conn)
|
|
||||||
return d.conn.WriteMessage(messageType, message)
|
|
||||||
}
|
|
||||||
|
|
||||||
// func (d *GWebSocket) setSendConn(sendConn *websocket.Conn) {
|
|
||||||
// d.sendConn = sendConn
|
|
||||||
//}
|
|
||||||
|
|
||||||
func (d *GWebSocket) ReadMessage() (int, []byte, error) {
|
|
||||||
return d.conn.ReadMessage()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *GWebSocket) SetReadDeadline(timeout time.Duration) error {
|
|
||||||
return d.conn.SetReadDeadline(time.Now().Add(timeout))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *GWebSocket) SetWriteDeadline(timeout time.Duration) error {
|
|
||||||
if timeout <= 0 {
|
|
||||||
return errs.New("timeout must be greater than 0")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO SetWriteDeadline Future add error handling
|
|
||||||
if err := d.conn.SetWriteDeadline(time.Now().Add(timeout)); err != nil {
|
|
||||||
return errs.WrapMsg(err, "GWebSocket.SetWriteDeadline failed")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *GWebSocket) Dial(urlStr string, requestHeader http.Header) (*http.Response, error) {
|
|
||||||
conn, httpResp, err := websocket.DefaultDialer.Dial(urlStr, requestHeader)
|
|
||||||
if err != nil {
|
|
||||||
return httpResp, errs.WrapMsg(err, "GWebSocket.Dial failed", "url", urlStr)
|
|
||||||
}
|
|
||||||
d.conn = conn
|
|
||||||
return httpResp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *GWebSocket) IsNil() bool {
|
|
||||||
return d.conn == nil
|
|
||||||
//
|
|
||||||
// if d.conn != nil {
|
|
||||||
// return false
|
|
||||||
// }
|
|
||||||
// return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *GWebSocket) SetConnNil() {
|
|
||||||
d.conn = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *GWebSocket) SetReadLimit(limit int64) {
|
|
||||||
d.conn.SetReadLimit(limit)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *GWebSocket) SetPongHandler(handler PingPongHandler) {
|
|
||||||
d.conn.SetPongHandler(handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *GWebSocket) SetPingHandler(handler PingPongHandler) {
|
|
||||||
d.conn.SetPingHandler(handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *GWebSocket) RespondWithError(err error, w http.ResponseWriter, r *http.Request) error {
|
|
||||||
if err := d.GenerateLongConn(w, r); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
data, err := json.Marshal(apiresp.ParseError(err))
|
|
||||||
if err != nil {
|
|
||||||
_ = d.Close()
|
|
||||||
return errs.WrapMsg(err, "json marshal failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := d.WriteMessage(MessageText, data); err != nil {
|
|
||||||
_ = d.Close()
|
|
||||||
return errs.WrapMsg(err, "WriteMessage failed")
|
|
||||||
}
|
|
||||||
_ = d.Close()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *GWebSocket) RespondWithSuccess() error {
|
|
||||||
data, err := json.Marshal(apiresp.ParseError(nil))
|
|
||||||
if err != nil {
|
|
||||||
_ = d.Close()
|
|
||||||
return errs.WrapMsg(err, "json marshal failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := d.WriteMessage(MessageText, data); err != nil {
|
|
||||||
_ = d.Close()
|
|
||||||
return errs.WrapMsg(err, "WriteMessage failed")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@ -2,18 +2,20 @@ package msggateway
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
"github.com/openimsdk/open-im-server/v3/pkg/rpcli"
|
"github.com/openimsdk/open-im-server/v3/pkg/rpcli"
|
||||||
|
"github.com/openimsdk/tools/apiresp"
|
||||||
|
|
||||||
"github.com/openimsdk/open-im-server/v3/pkg/common/webhook"
|
"github.com/openimsdk/open-im-server/v3/pkg/common/webhook"
|
||||||
"github.com/openimsdk/open-im-server/v3/pkg/rpccache"
|
"github.com/openimsdk/open-im-server/v3/pkg/rpccache"
|
||||||
pbAuth "github.com/openimsdk/protocol/auth"
|
pbAuth "github.com/openimsdk/protocol/auth"
|
||||||
"github.com/openimsdk/tools/errs"
|
|
||||||
"github.com/openimsdk/tools/mcontext"
|
"github.com/openimsdk/tools/mcontext"
|
||||||
|
|
||||||
"github.com/go-playground/validator/v10"
|
"github.com/go-playground/validator/v10"
|
||||||
@ -23,10 +25,11 @@ import (
|
|||||||
"github.com/openimsdk/protocol/msggateway"
|
"github.com/openimsdk/protocol/msggateway"
|
||||||
"github.com/openimsdk/tools/discovery"
|
"github.com/openimsdk/tools/discovery"
|
||||||
"github.com/openimsdk/tools/log"
|
"github.com/openimsdk/tools/log"
|
||||||
"github.com/openimsdk/tools/utils/stringutil"
|
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var wsSuccessResponse, _ = json.Marshal(&apiresp.ApiResponse{})
|
||||||
|
|
||||||
type LongConnServer interface {
|
type LongConnServer interface {
|
||||||
Run(ctx context.Context) error
|
Run(ctx context.Context) error
|
||||||
wsHandler(w http.ResponseWriter, r *http.Request)
|
wsHandler(w http.ResponseWriter, r *http.Request)
|
||||||
@ -43,6 +46,7 @@ type LongConnServer interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type WsServer struct {
|
type WsServer struct {
|
||||||
|
websocket *websocket.Upgrader
|
||||||
msgGatewayConfig *Config
|
msgGatewayConfig *Config
|
||||||
port int
|
port int
|
||||||
wsMaxConnNum int64
|
wsMaxConnNum int64
|
||||||
@ -136,9 +140,13 @@ func NewWsServer(msgGatewayConfig *Config, opts ...Option) *WsServer {
|
|||||||
o(&config)
|
o(&config)
|
||||||
}
|
}
|
||||||
//userRpcClient := rpcclient.NewUserRpcClient(client, config.Discovery.RpcService.User, config.Share.IMAdminUser)
|
//userRpcClient := rpcclient.NewUserRpcClient(client, config.Discovery.RpcService.User, config.Share.IMAdminUser)
|
||||||
|
upgrader := &websocket.Upgrader{
|
||||||
|
HandshakeTimeout: config.handshakeTimeout,
|
||||||
|
CheckOrigin: func(r *http.Request) bool { return true },
|
||||||
|
}
|
||||||
v := validator.New()
|
v := validator.New()
|
||||||
return &WsServer{
|
return &WsServer{
|
||||||
|
websocket: upgrader,
|
||||||
msgGatewayConfig: msgGatewayConfig,
|
msgGatewayConfig: msgGatewayConfig,
|
||||||
port: config.port,
|
port: config.port,
|
||||||
wsMaxConnNum: config.maxConnNum,
|
wsMaxConnNum: config.maxConnNum,
|
||||||
@ -260,8 +268,7 @@ func (ws *WsServer) registerClient(client *Client) {
|
|||||||
)
|
)
|
||||||
oldClients, userOK, clientOK = ws.clients.Get(client.UserID, client.PlatformID)
|
oldClients, userOK, clientOK = ws.clients.Get(client.UserID, client.PlatformID)
|
||||||
|
|
||||||
log.ZInfo(client.ctx, "registerClient", "userID", client.UserID, "platformID", client.PlatformID,
|
log.ZInfo(client.ctx, "registerClient", "userID", client.UserID, "platformID", client.PlatformID)
|
||||||
"sdkVersion", client.SDKVersion)
|
|
||||||
|
|
||||||
if !userOK {
|
if !userOK {
|
||||||
ws.clients.Set(client.UserID, client)
|
ws.clients.Set(client.UserID, client)
|
||||||
@ -448,7 +455,7 @@ func (ws *WsServer) unregisterClient(client *Client) {
|
|||||||
// validateRespWithRequest checks if the response matches the expected userID and platformID.
|
// validateRespWithRequest checks if the response matches the expected userID and platformID.
|
||||||
func (ws *WsServer) validateRespWithRequest(ctx *UserConnContext, resp *pbAuth.ParseTokenResp) error {
|
func (ws *WsServer) validateRespWithRequest(ctx *UserConnContext, resp *pbAuth.ParseTokenResp) error {
|
||||||
userID := ctx.GetUserID()
|
userID := ctx.GetUserID()
|
||||||
platformID := stringutil.StringToInt32(ctx.GetPlatformID())
|
platformID := int32(ctx.GetPlatformID())
|
||||||
if resp.UserID != userID {
|
if resp.UserID != userID {
|
||||||
return servererrs.ErrTokenInvalid.WrapMsg(fmt.Sprintf("token uid %s != userID %s", resp.UserID, userID))
|
return servererrs.ErrTokenInvalid.WrapMsg(fmt.Sprintf("token uid %s != userID %s", resp.UserID, userID))
|
||||||
}
|
}
|
||||||
@ -458,19 +465,37 @@ func (ws *WsServer) validateRespWithRequest(ctx *UserConnContext, resp *pbAuth.P
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ws *WsServer) handlerError(ctx *UserConnContext, w http.ResponseWriter, r *http.Request, err error) {
|
||||||
|
if !ctx.ShouldSendResp() {
|
||||||
|
httpError(ctx, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// the browser cannot get the response of upgrade failure
|
||||||
|
data, err := json.Marshal(apiresp.ParseError(err))
|
||||||
|
if err != nil {
|
||||||
|
log.ZError(ctx, "json marshal failed", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn, upgradeErr := ws.websocket.Upgrade(w, r, nil)
|
||||||
|
if upgradeErr != nil {
|
||||||
|
log.ZWarn(ctx, "websocket upgrade failed", upgradeErr, "respErr", err, "resp", string(data))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
if err := conn.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||||
|
log.ZWarn(ctx, "WriteMessage failed", err, "respErr", err, "resp", string(data))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
|
func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
// Create a new connection context
|
// Create a new connection context
|
||||||
connContext := newContext(w, r)
|
connContext := newContext(w, r)
|
||||||
|
|
||||||
if !ws.ready.Load() {
|
|
||||||
httpError(connContext, errs.New("ws server not ready"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the current number of online user connections exceeds the maximum limit
|
// Check if the current number of online user connections exceeds the maximum limit
|
||||||
if ws.onlineUserConnNum.Load() >= ws.wsMaxConnNum {
|
if ws.onlineUserConnNum.Load() >= ws.wsMaxConnNum {
|
||||||
// If it exceeds the maximum connection number, return an error via HTTP and stop processing
|
// If it exceeds the maximum connection number, return an error via HTTP and stop processing
|
||||||
httpError(connContext, servererrs.ErrConnOverMaxNumLimit.WrapMsg("over max conn num limit"))
|
ws.handlerError(connContext, w, r, servererrs.ErrConnOverMaxNumLimit.WrapMsg("over max conn num limit"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -478,31 +503,14 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
err := connContext.ParseEssentialArgs()
|
err := connContext.ParseEssentialArgs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If there's an error during parsing, return an error via HTTP and stop processing
|
// If there's an error during parsing, return an error via HTTP and stop processing
|
||||||
|
ws.handlerError(connContext, w, r, err)
|
||||||
httpError(connContext, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if ws.authClient == nil {
|
|
||||||
httpError(connContext, errs.New("auth client is not initialized"))
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call the authentication client to parse the Token obtained from the context
|
// Call the authentication client to parse the Token obtained from the context
|
||||||
resp, err := ws.authClient.ParseToken(connContext, connContext.GetToken())
|
resp, err := ws.authClient.ParseToken(connContext, connContext.GetToken())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If there's an error parsing the Token, decide whether to send the error message via WebSocket based on the context flag
|
ws.handlerError(connContext, w, r, err)
|
||||||
shouldSendError := connContext.ShouldSendResp()
|
|
||||||
if shouldSendError {
|
|
||||||
// Create a WebSocket connection object and attempt to send the error message via WebSocket
|
|
||||||
wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout, ws.writeBufferSize)
|
|
||||||
if err := wsLongConn.RespondWithError(err, w, r); err == nil {
|
|
||||||
// If the error message is successfully sent via WebSocket, stop processing
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// If sending via WebSocket is not required or fails, return the error via HTTP and stop processing
|
|
||||||
httpError(connContext, err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -510,32 +518,30 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
err = ws.validateRespWithRequest(connContext, resp)
|
err = ws.validateRespWithRequest(connContext, resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If validation fails, return an error via HTTP and stop processing
|
// If validation fails, return an error via HTTP and stop processing
|
||||||
httpError(connContext, err)
|
ws.handlerError(connContext, w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
conn, err := ws.websocket.Upgrade(w, r, nil)
|
||||||
log.ZDebug(connContext, "new conn", "token", connContext.GetToken())
|
if err != nil {
|
||||||
// Create a WebSocket long connection object
|
log.ZWarn(connContext, "websocket upgrade failed", err)
|
||||||
wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout, ws.writeBufferSize)
|
|
||||||
if err := wsLongConn.GenerateLongConn(w, r); err != nil {
|
|
||||||
//If the creation of the long connection fails, the error is handled internally during the handshake process.
|
|
||||||
log.ZWarn(connContext, "long connection fails", err)
|
|
||||||
return
|
return
|
||||||
} else {
|
}
|
||||||
// Check if a normal response should be sent via WebSocket
|
if connContext.ShouldSendResp() {
|
||||||
shouldSendSuccessResp := connContext.ShouldSendResp()
|
if err := conn.WriteMessage(websocket.TextMessage, wsSuccessResponse); err != nil {
|
||||||
if shouldSendSuccessResp {
|
log.ZWarn(connContext, "WriteMessage first response", err)
|
||||||
// Attempt to send a success message through WebSocket
|
return
|
||||||
if err := wsLongConn.RespondWithSuccess(); err != nil {
|
|
||||||
// If the success message is successfully sent, end further processing
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Retrieve a client object from the client pool, reset its state, and associate it with the current WebSocket long connection
|
log.ZDebug(connContext, "new conn", "token", connContext.GetToken())
|
||||||
client := ws.clientPool.Get().(*Client)
|
|
||||||
client.ResetClient(connContext, wsLongConn, ws)
|
var pingInterval time.Duration
|
||||||
|
if connContext.GetPlatformID() == constant.WebPlatformID {
|
||||||
|
pingInterval = pingPeriod
|
||||||
|
}
|
||||||
|
|
||||||
|
client := new(Client)
|
||||||
|
client.ResetClient(connContext, NewWebSocketClientConn(conn, maxMessageSize, pongWait, pingInterval), ws)
|
||||||
|
|
||||||
// Register the client with the server and start message processing
|
// Register the client with the server and start message processing
|
||||||
ws.registerChan <- client
|
ws.registerChan <- client
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user