ws update

This commit is contained in:
Gordon 2023-03-24 16:39:33 +08:00
parent 98e22cc699
commit 1291420db0
7 changed files with 135 additions and 68 deletions

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/OpenIMSDK/Open-IM-Server/pkg/apiresp"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/log" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/log"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/mcontext" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/mcontext"
@ -14,6 +15,10 @@ import (
"sync" "sync"
) )
var ErrConnClosed = errors.New("conn has closed")
var ErrNotSupportMessageProtocol = errors.New("not support message protocol")
var ErrClientClosed = errors.New("client actively close the connection")
const ( const (
// MessageText is for UTF-8 encoded text messages like JSON. // MessageText is for UTF-8 encoded text messages like JSON.
MessageText = iota + 1 MessageText = iota + 1
@ -33,6 +38,8 @@ const (
PongMessage = 10 PongMessage = 10
) )
type PongHandler func(string) error
type Client struct { type Client struct {
w *sync.Mutex w *sync.Mutex
conn LongConn conn LongConn
@ -41,9 +48,9 @@ type Client struct {
userID string userID string
isBackground bool isBackground bool
ctx *UserConnContext ctx *UserConnContext
onlineAt int64 // 上线时间戳(毫秒)
longConnServer LongConnServer longConnServer LongConnServer
closed bool closed bool
closedErr error
} }
func newClient(ctx *UserConnContext, conn LongConn, isCompress bool) *Client { func newClient(ctx *UserConnContext, conn LongConn, isCompress bool) *Client {
@ -54,7 +61,6 @@ func newClient(ctx *UserConnContext, conn LongConn, isCompress bool) *Client {
isCompress: isCompress, isCompress: isCompress,
userID: ctx.GetUserID(), userID: ctx.GetUserID(),
ctx: ctx, ctx: ctx,
onlineAt: utils.GetCurrentTimestampByMill(),
} }
} }
func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, isCompress bool, longConnServer LongConnServer) { func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, isCompress bool, longConnServer LongConnServer) {
@ -64,9 +70,12 @@ func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, isCompress boo
c.isCompress = isCompress c.isCompress = isCompress
c.userID = ctx.GetUserID() c.userID = ctx.GetUserID()
c.ctx = ctx c.ctx = ctx
c.onlineAt = utils.GetCurrentTimestampByMill()
c.longConnServer = longConnServer c.longConnServer = longConnServer
} }
func (c *Client) pongHandler(_ string) error {
c.conn.SetReadDeadline(pongWait)
return nil
}
func (c *Client) readMessage() { func (c *Client) readMessage() {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
@ -74,31 +83,36 @@ func (c *Client) readMessage() {
} }
c.close() c.close()
}() }()
//var returnErr error c.conn.SetReadLimit(maxMessageSize)
_ = c.conn.SetReadDeadline(pongWait)
c.conn.SetPongHandler(c.pongHandler)
for { for {
messageType, message, returnErr := c.conn.ReadMessage() messageType, message, returnErr := c.conn.ReadMessage()
if returnErr != nil { if returnErr != nil {
break c.closedErr = returnErr
return
} }
if c.closed == true { //连接刚置位已经关闭,但是协程还没退出的场景 if c.closed == true { //连接刚置位已经关闭,但是协程还没退出的场景
break c.closedErr = ErrConnClosed
return
} }
switch messageType { switch messageType {
case PingMessage:
case PongMessage:
case CloseMessage:
return
case MessageText:
case MessageBinary: case MessageBinary:
if len(message) == 0 { parseDataErr := c.handleMessage(message)
continue if parseDataErr != nil {
c.closedErr = parseDataErr
return
} }
returnErr = c.handleMessage(message) case MessageText:
if returnErr != nil { c.closedErr = ErrNotSupportMessageProtocol
log.ZError(context.Background(), "WSGetNewestSeq", returnErr) return
break case PingMessage:
} err := c.writePongMsg()
log.ZError(c.ctx, "writePongMsg", err)
case CloseMessage:
c.closedErr = ErrClientClosed
return
default:
} }
} }
@ -120,7 +134,7 @@ func (c *Client) handleMessage(message []byte) error {
return utils.Wrap(err, "") return utils.Wrap(err, "")
} }
if binaryReq.SendID != c.userID { if binaryReq.SendID != c.userID {
return errors.New("exception conn userID not same to req userID") return utils.Wrap(errors.New("exception conn userID not same to req userID"), binaryReq.String())
} }
ctx := mcontext.WithMustInfoCtx([]string{binaryReq.OperationID, binaryReq.SendID, constant.PlatformIDToName(c.platformID), c.ctx.GetConnID()}) ctx := mcontext.WithMustInfoCtx([]string{binaryReq.OperationID, binaryReq.SendID, constant.PlatformIDToName(c.platformID), c.ctx.GetConnID()})
var messageErr error var messageErr error
@ -128,8 +142,6 @@ func (c *Client) handleMessage(message []byte) error {
switch binaryReq.ReqIdentifier { switch binaryReq.ReqIdentifier {
case WSGetNewestSeq: case WSGetNewestSeq:
resp, messageErr = c.longConnServer.GetSeq(ctx, binaryReq) resp, messageErr = c.longConnServer.GetSeq(ctx, binaryReq)
log.ZError(ctx, "WSGetNewestSeq", messageErr, "resp", resp)
case WSSendMsg: case WSSendMsg:
resp, messageErr = c.longConnServer.SendMessage(ctx, binaryReq) resp, messageErr = c.longConnServer.SendMessage(ctx, binaryReq)
case WSSendSignalMsg: case WSSendSignalMsg:
@ -166,13 +178,16 @@ func (c *Client) close() {
} }
func (c *Client) replyMessage(binaryReq *Req, err error, resp []byte) { func (c *Client) replyMessage(binaryReq *Req, err error, resp []byte) {
errResp := apiresp.ParseError(err)
mReply := Resp{ mReply := Resp{
ReqIdentifier: binaryReq.ReqIdentifier, ReqIdentifier: binaryReq.ReqIdentifier,
MsgIncr: binaryReq.MsgIncr, MsgIncr: binaryReq.MsgIncr,
OperationID: binaryReq.OperationID, OperationID: binaryReq.OperationID,
ErrCode: errResp.ErrCode,
ErrMsg: errResp.ErrMsg,
Data: resp, Data: resp,
} }
_ = c.writeMsg(mReply) _ = c.writeBinaryMsg(mReply)
} }
func (c *Client) PushMessage(ctx context.Context, msgData *sdkws.MsgData) error { func (c *Client) PushMessage(ctx context.Context, msgData *sdkws.MsgData) error {
data, err := proto.Marshal(msgData) data, err := proto.Marshal(msgData)
@ -184,15 +199,14 @@ func (c *Client) PushMessage(ctx context.Context, msgData *sdkws.MsgData) error
OperationID: mcontext.GetOperationID(ctx), OperationID: mcontext.GetOperationID(ctx),
Data: data, Data: data,
} }
return c.writeMsg(resp) return c.writeBinaryMsg(resp)
} }
func (c *Client) KickOnlineMessage(ctx context.Context) error { func (c *Client) KickOnlineMessage(ctx context.Context) error {
return nil return nil
} }
func (c *Client) writeMsg(resp Resp) error { func (c *Client) writeBinaryMsg(resp Resp) error {
c.w.Lock() c.w.Lock()
defer c.w.Unlock() defer c.w.Unlock()
if c.closed == true { if c.closed == true {
@ -204,7 +218,7 @@ func (c *Client) writeMsg(resp Resp) error {
if err != nil { if err != nil {
return utils.Wrap(err, "") return utils.Wrap(err, "")
} }
_ = c.conn.SetWriteTimeout(60) _ = c.conn.SetWriteDeadline(writeWait)
if c.isCompress { if c.isCompress {
var compressErr error var compressErr error
resultBuf, compressErr = c.longConnServer.Compress(encodeBuf) resultBuf, compressErr = c.longConnServer.Compress(encodeBuf)
@ -216,3 +230,14 @@ func (c *Client) writeMsg(resp Resp) error {
return c.conn.WriteMessage(MessageBinary, encodedBuf) return c.conn.WriteMessage(MessageBinary, encodedBuf)
} }
} }
func (c *Client) writePongMsg() error {
c.w.Lock()
defer c.w.Unlock()
if c.closed == true {
return nil
}
_ = c.conn.SetWriteDeadline(writeWait)
return c.conn.WriteMessage(PongMessage, nil)
}

View File

@ -1,5 +1,7 @@
package msggateway package msggateway
import "time"
const ( const (
WsUserID = "sendID" WsUserID = "sendID"
CommonUserID = "userID" CommonUserID = "userID"
@ -25,3 +27,14 @@ const (
WsSetBackgroundStatus = 2004 WsSetBackgroundStatus = 2004
WSDataError = 3001 WSDataError = 3001
) )
const (
// Time allowed to write a message to the peer.
writeWait = 10 * time.Second
// Time allowed to read the next pong message from the peer.
pongWait = 30 * time.Second
// Maximum message size allowed from peer.
maxMessageSize = 51200
)

View File

@ -13,17 +13,20 @@ type LongConn interface {
WriteMessage(messageType int, message []byte) error WriteMessage(messageType int, message []byte) error
//Read message from connection. //Read message from connection.
ReadMessage() (int, []byte, error) ReadMessage() (int, []byte, error)
//SetReadTimeout sets the read deadline on the underlying network connection, // SetReadDeadline sets the read deadline on the underlying network connection,
//after a read has timed out, will return an error. //after a read has timed out, will return an error.
SetReadTimeout(timeout int) error SetReadDeadline(timeout time.Duration) error
//SetWriteTimeout sets the write deadline when send message,when read has timed out,will return error. // SetWriteDeadline sets the write deadline when send message,when read has timed out,will return error.
SetWriteTimeout(timeout int) error SetWriteDeadline(timeout time.Duration) error
//Try to dial a connection,url must set auth args,header can control compress data //Try to dial a connection,url must set auth args,header can control compress data
Dial(urlStr string, requestHeader http.Header) (*http.Response, error) Dial(urlStr string, requestHeader http.Header) (*http.Response, error)
//Whether the connection of the current long connection is nil //Whether the connection of the current long connection is nil
IsNil() bool IsNil() bool
//Set the connection of the current long connection to nil //Set the connection of the current long connection to nil
SetConnNil() SetConnNil()
// SetReadLimit sets the maximum size for a message read from the peer.bytes
SetReadLimit(limit int64)
SetPongHandler(handler PongHandler)
//Check the connection of the current and when it was sent are the same //Check the connection of the current and when it was sent are the same
//CheckSendConnDiffNow() bool //CheckSendConnDiffNow() bool
// //
@ -33,11 +36,10 @@ type GWebSocket struct {
protocolType int protocolType int
conn *websocket.Conn conn *websocket.Conn
handshakeTimeout time.Duration handshakeTimeout time.Duration
readBufferSize, WriteBufferSize int
} }
func newGWebSocket(protocolType int, handshakeTimeout time.Duration, readBufferSize int) *GWebSocket { func newGWebSocket(protocolType int, handshakeTimeout time.Duration) *GWebSocket {
return &GWebSocket{protocolType: protocolType, handshakeTimeout: handshakeTimeout, readBufferSize: readBufferSize} return &GWebSocket{protocolType: protocolType, handshakeTimeout: handshakeTimeout}
} }
func (d *GWebSocket) Close() error { func (d *GWebSocket) Close() error {
@ -46,7 +48,6 @@ func (d *GWebSocket) Close() error {
func (d *GWebSocket) GenerateLongConn(w http.ResponseWriter, r *http.Request) error { func (d *GWebSocket) GenerateLongConn(w http.ResponseWriter, r *http.Request) error {
upgrader := &websocket.Upgrader{ upgrader := &websocket.Upgrader{
HandshakeTimeout: d.handshakeTimeout, HandshakeTimeout: d.handshakeTimeout,
ReadBufferSize: d.readBufferSize,
CheckOrigin: func(r *http.Request) bool { return true }, CheckOrigin: func(r *http.Request) bool { return true },
} }
conn, err := upgrader.Upgrade(w, r, nil) conn, err := upgrader.Upgrade(w, r, nil)
@ -69,12 +70,12 @@ func (d *GWebSocket) WriteMessage(messageType int, message []byte) error {
func (d *GWebSocket) ReadMessage() (int, []byte, error) { func (d *GWebSocket) ReadMessage() (int, []byte, error) {
return d.conn.ReadMessage() return d.conn.ReadMessage()
} }
func (d *GWebSocket) SetReadTimeout(timeout int) error { func (d *GWebSocket) SetReadDeadline(timeout time.Duration) error {
return d.conn.SetReadDeadline(time.Now().Add(time.Duration(timeout) * time.Second)) return d.conn.SetReadDeadline(time.Now().Add(timeout))
} }
func (d *GWebSocket) SetWriteTimeout(timeout int) error { func (d *GWebSocket) SetWriteDeadline(timeout time.Duration) error {
return d.conn.SetWriteDeadline(time.Now().Add(time.Duration(timeout) * time.Second)) return d.conn.SetWriteDeadline(time.Now().Add(timeout))
} }
func (d *GWebSocket) Dial(urlStr string, requestHeader http.Header) (*http.Response, error) { func (d *GWebSocket) Dial(urlStr string, requestHeader http.Header) (*http.Response, error) {
@ -96,6 +97,12 @@ func (d *GWebSocket) IsNil() bool {
func (d *GWebSocket) SetConnNil() { func (d *GWebSocket) SetConnNil() {
d.conn = nil d.conn = nil
} }
func (d *GWebSocket) SetReadLimit(limit int64) {
d.conn.SetReadLimit(limit)
}
func (d *GWebSocket) SetPongHandler(handler PongHandler) {
d.conn.SetPongHandler(handler)
}
//func (d *GWebSocket) CheckSendConnDiffNow() bool { //func (d *GWebSocket) CheckSendConnDiffNow() bool {
// return d.conn == d.sendConn // return d.conn == d.sendConn

View File

@ -2,10 +2,10 @@ package msggateway
import ( import (
"context" "context"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/log"
"github.com/OpenIMSDK/Open-IM-Server/pkg/proto/msg" "github.com/OpenIMSDK/Open-IM-Server/pkg/proto/msg"
"github.com/OpenIMSDK/Open-IM-Server/pkg/proto/sdkws" "github.com/OpenIMSDK/Open-IM-Server/pkg/proto/sdkws"
"github.com/OpenIMSDK/Open-IM-Server/pkg/rpcclient/notification" "github.com/OpenIMSDK/Open-IM-Server/pkg/rpcclient/notification"
"github.com/OpenIMSDK/Open-IM-Server/pkg/utils"
"github.com/go-playground/validator/v10" "github.com/go-playground/validator/v10"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
) )
@ -18,11 +18,16 @@ type Req struct {
MsgIncr string `json:"msgIncr" validate:"required"` MsgIncr string `json:"msgIncr" validate:"required"`
Data []byte `json:"data"` Data []byte `json:"data"`
} }
func (r *Req) String() string {
return utils.StructToJsonString(r)
}
type Resp struct { type Resp struct {
ReqIdentifier int32 `json:"reqIdentifier"` ReqIdentifier int32 `json:"reqIdentifier"`
MsgIncr string `json:"msgIncr"` MsgIncr string `json:"msgIncr"`
OperationID string `json:"operationID"` OperationID string `json:"operationID"`
ErrCode int32 `json:"errCode"` ErrCode int `json:"errCode"`
ErrMsg string `json:"errMsg"` ErrMsg string `json:"errMsg"`
Data []byte `json:"data"` Data []byte `json:"data"`
} }
@ -54,7 +59,6 @@ func (g GrpcHandler) GetSeq(context context.Context, data Req) ([]byte, error) {
if err := g.validate.Struct(req); err != nil { if err := g.validate.Struct(req); err != nil {
return nil, err return nil, err
} }
log.ZDebug(context, "msggateway GetSeq", "notification", g.notification, "msg", g.notification.Msg)
resp, err := g.notification.Msg.GetMaxAndMinSeq(context, &req) resp, err := g.notification.Msg.GetMaxAndMinSeq(context, &req)
if err != nil { if err != nil {
return nil, err return nil, err
@ -148,3 +152,23 @@ func (g GrpcHandler) SetUserDeviceBackground(_ context.Context, data Req) ([]byt
} }
return nil, req.IsBackground, nil return nil, req.IsBackground, nil
} }
//func (g GrpcHandler) call[T any](ctx context.Context, data Req, m proto.Message, rpc func(ctx context.Context, req proto.Message)) ([]byte, error) {
// if err := proto.Unmarshal(data.Data, m); err != nil {
// return nil, err
// }
// if err := g.validate.Struct(m); err != nil {
// return nil, err
// }
// rpc(ctx, m)
// req := msg.SendMsgReq{MsgData: &msgData}
// resp, err := g.notification.Msg.SendMsg(context, &req)
// if err != nil {
// return nil, err
// }
// c, err := proto.Marshal(resp)
// if err != nil {
// return nil, err
// }
// return c, nil
//}

View File

@ -43,7 +43,6 @@ type WsServer struct {
onlineUserNum int64 onlineUserNum int64
onlineUserConnNum int64 onlineUserConnNum int64
handshakeTimeout time.Duration handshakeTimeout time.Duration
readBufferSize, WriteBufferSize int
hubServer *Server hubServer *Server
validate *validator.Validate validate *validator.Validate
Compressor Compressor
@ -85,7 +84,6 @@ func NewWsServer(opts ...Option) (*WsServer, error) {
port: config.port, port: config.port,
wsMaxConnNum: config.maxConnNum, wsMaxConnNum: config.maxConnNum,
handshakeTimeout: config.handshakeTimeout, handshakeTimeout: config.handshakeTimeout,
readBufferSize: config.messageMaxMsgLength,
clientPool: sync.Pool{ clientPool: sync.Pool{
New: func() interface{} { New: func() interface{} {
return new(Client) return new(Client)
@ -149,7 +147,7 @@ func (ws *WsServer) unregisterClient(client *Client) {
atomic.AddInt64(&ws.onlineUserNum, -1) atomic.AddInt64(&ws.onlineUserNum, -1)
} }
atomic.AddInt64(&ws.onlineUserConnNum, -1) atomic.AddInt64(&ws.onlineUserConnNum, -1)
log.ZInfo(client.ctx, "user offline", "online user Num", ws.onlineUserNum, "online user conn Num", ws.onlineUserConnNum) log.ZInfo(client.ctx, "user offline", "close reason", client.closedErr, "online user Num", ws.onlineUserNum, "online user conn Num", ws.onlineUserConnNum)
} }
func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
@ -186,7 +184,7 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
httpError(context, err) httpError(context, err)
return return
} }
wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout, ws.readBufferSize) wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout)
err = wsLongConn.GenerateLongConn(w, r) err = wsLongConn.GenerateLongConn(w, r)
if err != nil { if err != nil {
httpError(context, err) httpError(context, err)

View File

@ -10,7 +10,7 @@ func GinError(c *gin.Context, err error) {
GinSuccess(c, nil) GinSuccess(c, nil)
return return
} }
c.JSON(http.StatusOK, apiError(err)) c.JSON(http.StatusOK, ParseError(err))
} }
func GinSuccess(c *gin.Context, data any) { func GinSuccess(c *gin.Context, data any) {

View File

@ -5,7 +5,7 @@ import (
"reflect" "reflect"
) )
type apiResponse struct { type ApiResponse struct {
ErrCode int `json:"errCode"` ErrCode int `json:"errCode"`
ErrMsg string `json:"errMsg"` ErrMsg string `json:"errMsg"`
ErrDlt string `json:"errDlt"` ErrDlt string `json:"errDlt"`
@ -30,23 +30,23 @@ func isAllFieldsPrivate(v any) bool {
return true return true
} }
func apiSuccess(data any) *apiResponse { func apiSuccess(data any) *ApiResponse {
if isAllFieldsPrivate(data) { if isAllFieldsPrivate(data) {
return &apiResponse{} return &ApiResponse{}
} }
return &apiResponse{ return &ApiResponse{
Data: data, Data: data,
} }
} }
func apiError(err error) *apiResponse { func ParseError(err error) *ApiResponse {
unwrap := errs.Unwrap(err) unwrap := errs.Unwrap(err)
if codeErr, ok := unwrap.(errs.CodeError); ok { if codeErr, ok := unwrap.(errs.CodeError); ok {
resp := apiResponse{ErrCode: codeErr.Code(), ErrMsg: codeErr.Msg(), ErrDlt: codeErr.Detail()} resp := ApiResponse{ErrCode: codeErr.Code(), ErrMsg: codeErr.Msg(), ErrDlt: codeErr.Detail()}
if resp.ErrDlt == "" { if resp.ErrDlt == "" {
resp.ErrDlt = err.Error() resp.ErrDlt = err.Error()
} }
return &resp return &resp
} }
return &apiResponse{ErrCode: errs.ServerInternalError, ErrMsg: err.Error()} return &ApiResponse{ErrCode: errs.ServerInternalError, ErrMsg: err.Error()}
} }