Merge remote-tracking branch 'origin/errcode' into errcode

This commit is contained in:
withchao 2023-02-16 17:11:00 +08:00
commit ff542e83fe
11 changed files with 341 additions and 97 deletions

View File

@ -28,3 +28,39 @@ func (m *MsgCheck) SendMsg(ctx context.Context, req *msg.SendMsgReq) (*msg.SendM
resp, err := msg.NewMsgClient(cc).SendMsg(ctx, req) resp, err := msg.NewMsgClient(cc).SendMsg(ctx, req)
return resp, err return resp, err
} }
func (m *MsgCheck) SendMsg(ctx context.Context, req *msg.SendMsgReq) (*msg.SendMsgResp, error) {
cc, err := m.getConn()
if err != nil {
return nil, err
}
resp, err := msg.NewMsgClient(cc).SendMsg(ctx, req)
return resp, err
}
func (m *MsgCheck) SendMsg(ctx context.Context, req *msg.SendMsgReq) (*msg.SendMsgResp, error) {
cc, err := m.getConn()
if err != nil {
return nil, err
}
resp, err := msg.NewMsgClient(cc).SendMsg(ctx, req)
return resp, err
}
func (m *MsgCheck) SendMsg(ctx context.Context, req *msg.SendMsgReq) (*msg.SendMsgResp, error) {
cc, err := m.getConn()
if err != nil {
return nil, err
}
resp, err := msg.NewMsgClient(cc).SendMsg(ctx, req)
return resp, err
}
func (m *MsgCheck) SendMsg(ctx context.Context, req *msg.SendMsgReq) (*msg.SendMsgResp, error) {
cc, err := m.getConn()
if err != nil {
return nil, err
}
resp, err := msg.NewMsgClient(cc).SendMsg(ctx, req)
return resp, err
}

View File

@ -3,14 +3,12 @@ package new
import ( import (
"Open_IM/pkg/common/constant" "Open_IM/pkg/common/constant"
"Open_IM/pkg/utils" "Open_IM/pkg/utils"
"bytes"
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/go-playground/validator/v10" "github.com/go-playground/validator/v10"
"runtime/debug" "runtime/debug"
"sync" "sync"
"time"
) )
const ( const (
@ -35,35 +33,52 @@ const (
type Client struct { type Client struct {
w *sync.Mutex w *sync.Mutex
conn LongConn conn LongConn
PlatformID int32 platformID int
PushedMaxSeq uint32 isCompress bool
IsCompress bool
userID string userID string
IsBackground bool isBackground bool
token string
connID string connID string
onlineAt int64 // 上线时间戳(毫秒) onlineAt int64 // 上线时间戳(毫秒)
handler MessageHandler handler MessageHandler
unregisterChan chan *Client unregisterChan chan *Client
compressor Compressor compressor Compressor
encoder Encoder encoder Encoder
userContext UserConnContext
validate *validator.Validate validate *validator.Validate
closed bool closed bool
} }
func newClient(conn LongConn, isCompress bool, userID string, isBackground bool, token string, func newClient(ctx *UserConnContext, conn LongConn, isCompress bool, compressor Compressor, encoder Encoder,
connID string, onlineAt int64, handler MessageHandler, unregisterChan chan *Client) *Client { handler MessageHandler, unregisterChan chan *Client, validate *validator.Validate) *Client {
return &Client{ return &Client{
conn: conn, w: new(sync.Mutex),
IsCompress: isCompress, conn: conn,
userID: userID, IsBackground: isBackground, token: token, platformID: utils.StringToInt(ctx.GetPlatformID()),
connID: connID, isCompress: isCompress,
onlineAt: onlineAt, userID: ctx.GetUserID(),
compressor: compressor,
encoder: encoder,
connID: ctx.GetConnID(),
onlineAt: utils.GetCurrentTimestampByMill(),
handler: handler, handler: handler,
unregisterChan: unregisterChan, unregisterChan: unregisterChan,
validate: validate,
} }
} }
func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, isCompress bool, compressor Compressor, encoder Encoder,
handler MessageHandler, unregisterChan chan *Client, validate *validator.Validate) {
c.w = new(sync.Mutex)
c.conn = conn
c.platformID = utils.StringToInt(ctx.GetPlatformID())
c.isCompress = isCompress
c.userID = ctx.GetUserID()
c.compressor = compressor
c.encoder = encoder
c.connID = ctx.GetConnID()
c.onlineAt = utils.GetCurrentTimestampByMill()
c.handler = handler
c.unregisterChan = unregisterChan
c.validate = validate
}
func (c *Client) readMessage() { func (c *Client) readMessage() {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
@ -77,7 +92,7 @@ func (c *Client) readMessage() {
if returnErr != nil { if returnErr != nil {
break break
} }
if c.closed == true { if c.closed == true { //连接刚置位已经关闭,但是协程还没退出的场景
break break
} }
switch messageType { switch messageType {
@ -119,7 +134,8 @@ func (c *Client) handleMessage(message []byte) error {
return errors.New("exception conn userID not same to req userID") return errors.New("exception conn userID not same to req userID")
} }
ctx := context.Background() ctx := context.Background()
ctx = context.WithValue(ctx, "operationID", binaryReq.OperationID) ctx = context.WithValue(ctx, c.connID, binaryReq.OperationID)
ctx = context.WithValue(ctx, OPERATION_ID, binaryReq.OperationID)
ctx = context.WithValue(ctx, "userID", binaryReq.SendID) ctx = context.WithValue(ctx, "userID", binaryReq.SendID)
var messageErr error var messageErr error
var resp []byte var resp []byte
@ -173,7 +189,7 @@ func (c *Client) writeMsg(resp Resp) error {
return utils.Wrap(err, "") return utils.Wrap(err, "")
} }
_ = c.conn.SetWriteTimeout(60) _ = c.conn.SetWriteTimeout(60)
if c.IsCompress { if c.isCompress {
var compressErr error var compressErr error
resultBuf, compressErr = c.compressor.Compress(encodeBuf) resultBuf, compressErr = c.compressor.Compress(encodeBuf)
if compressErr != nil { if compressErr != nil {

View File

@ -0,0 +1,10 @@
package new
const (
USERID = "sendID"
PLATFORM_ID = "platformID"
TOKEN = "token"
OPERATION_ID = "operationID"
COMPRESSION = "compression"
GZIP_COMPRESSION_PROTOCAL = "gzip"
)

View File

@ -1,6 +1,10 @@
package new package new
import "net/http" import (
"Open_IM/pkg/utils"
"net/http"
"strconv"
)
type UserConnContext struct { type UserConnContext struct {
RespWriter http.ResponseWriter RespWriter http.ResponseWriter
@ -19,9 +23,32 @@ func newContext(respWriter http.ResponseWriter, req *http.Request) *UserConnCont
RemoteAddr: req.RemoteAddr, RemoteAddr: req.RemoteAddr,
} }
} }
func (c *UserConnContext) Query(key string) string { func (c *UserConnContext) Query(key string) (string, bool) {
return c.Req.URL.Query().Get(key) var value string
if value = c.Req.URL.Query().Get(key); value == "" {
return value, false
}
return value, true
} }
func (c *UserConnContext) GetHeader(key string) string { func (c *UserConnContext) GetHeader(key string) (string, bool) {
return c.Req.Header.Get(key) var value string
if value = c.Req.Header.Get(key); value == "" {
return value, false
}
return value, true
}
func (c *UserConnContext) SetHeader(key, value string) {
c.RespWriter.Header().Set(key, value)
}
func (c *UserConnContext) ErrReturn(error string, code int) {
http.Error(c.RespWriter, error, code)
}
func (c *UserConnContext) GetConnID() string {
return c.RemoteAddr + "_" + strconv.Itoa(int(utils.GetCurrentTimestampByMill()))
}
func (c *UserConnContext) GetUserID() string {
return c.Req.URL.Query().Get(USERID)
}
func (c *UserConnContext) GetPlatformID() string {
return c.Req.URL.Query().Get(PLATFORM_ID)
} }

View File

@ -0,0 +1,44 @@
package new
import (
"Open_IM/pkg/common/constant"
"errors"
"net/http"
)
func httpError(ctx *UserConnContext, err error) {
code := http.StatusUnauthorized
ctx.SetHeader("Sec-Websocket-Version", "13")
ctx.SetHeader("ws_err_msg", err.Error())
if errors.Is(err, constant.ErrTokenExpired) {
code = int(constant.ErrTokenExpired.ErrCode)
}
if errors.Is(err, constant.ErrTokenInvalid) {
code = int(constant.ErrTokenInvalid.ErrCode)
}
if errors.Is(err, constant.ErrTokenMalformed) {
code = int(constant.ErrTokenMalformed.ErrCode)
}
if errors.Is(err, constant.ErrTokenNotValidYet) {
code = int(constant.ErrTokenNotValidYet.ErrCode)
}
if errors.Is(err, constant.ErrTokenUnknown) {
code = int(constant.ErrTokenUnknown.ErrCode)
}
if errors.Is(err, constant.ErrTokenKicked) {
code = int(constant.ErrTokenKicked.ErrCode)
}
if errors.Is(err, constant.ErrTokenDifferentPlatformID) {
code = int(constant.ErrTokenDifferentPlatformID.ErrCode)
}
if errors.Is(err, constant.ErrTokenDifferentUserID) {
code = int(constant.ErrTokenDifferentUserID.ErrCode)
}
if errors.Is(err, constant.ErrConnOverMaxNumLimit) {
code = int(constant.ErrConnOverMaxNumLimit.ErrCode)
}
if errors.Is(err, constant.ErrConnArgsErr) {
code = int(constant.ErrConnArgsErr.ErrCode)
}
ctx.ErrReturn(err.Error(), code)
}

View File

@ -26,19 +26,37 @@ type LongConn interface {
SetConnNil() SetConnNil()
//Check the connection of the current and when it was sent are the same //Check the connection of the current and when it was sent are the same
CheckSendConnDiffNow() bool CheckSendConnDiffNow() bool
//
GenerateLongConn(w http.ResponseWriter, r *http.Request) error
} }
type GWebSocket struct { type GWebSocket struct {
protocolType int protocolType int
conn *websocket.Conn conn *websocket.Conn
handshakeTimeout time.Duration
readBufferSize, WriteBufferSize int
} }
func NewDefault(protocolType int) *GWebSocket { func newGWebSocket(protocolType int, handshakeTimeout time.Duration, readBufferSize int) *GWebSocket {
return &GWebSocket{protocolType: protocolType} return &GWebSocket{protocolType: protocolType, handshakeTimeout: handshakeTimeout, readBufferSize: readBufferSize}
} }
func (d *GWebSocket) Close() error { func (d *GWebSocket) Close() error {
return d.conn.Close() return d.conn.Close()
} }
func (d *GWebSocket) GenerateLongConn(w http.ResponseWriter, r *http.Request) error {
upgrader := &websocket.Upgrader{
HandshakeTimeout: d.handshakeTimeout,
ReadBufferSize: d.readBufferSize,
CheckOrigin: func(r *http.Request) bool { return true },
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return err
}
d.conn = conn
return nil
}
func (d *GWebSocket) WriteMessage(messageType int, message []byte) error { func (d *GWebSocket) WriteMessage(messageType int, message []byte) error {
d.setSendConn(d.conn) d.setSendConn(d.conn)
return d.conn.WriteMessage(messageType, message) return d.conn.WriteMessage(messageType, message)

View File

@ -1,6 +1,9 @@
package new package new
import "context" import (
"Open_IM/internal/common/check"
"context"
)
type Req struct { type Req struct {
ReqIdentifier int32 `json:"reqIdentifier" validate:"required"` ReqIdentifier int32 `json:"reqIdentifier" validate:"required"`
@ -30,6 +33,11 @@ type MessageHandler interface {
var _ MessageHandler = (*GrpcHandler)(nil) var _ MessageHandler = (*GrpcHandler)(nil)
type GrpcHandler struct { type GrpcHandler struct {
msg *check.MsgCheck
}
func NewGrpcHandler(msg *check.MsgCheck) *GrpcHandler {
return &GrpcHandler{msg: msg}
} }
func (g GrpcHandler) GetSeq(context context.Context, data Req) ([]byte, error) { func (g GrpcHandler) GetSeq(context context.Context, data Req) ([]byte, error) {

View File

@ -1,8 +1,11 @@
package new package new
import ( import (
"bytes" "Open_IM/pkg/common/constant"
"Open_IM/pkg/utils"
"errors" "errors"
"fmt"
"github.com/go-playground/validator/v10"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"net/http" "net/http"
"sync" "sync"
@ -10,12 +13,12 @@ import (
"time" "time"
) )
var bufferPool = sync.Pool{ var bufferPool = sync.Pool{
New: func() interface{} { New: func() interface{} {
return make([]byte, 1000) return make([]byte, 1000)
}, },
} }
type LongConnServer interface { type LongConnServer interface {
Run() error Run() error
} }
@ -27,17 +30,21 @@ type Server struct {
rpcServer *RpcServer rpcServer *RpcServer
} }
type WsServer struct { type WsServer struct {
port int port int
wsMaxConnNum int wsMaxConnNum int64
wsUpGrader *websocket.Upgrader wsUpGrader *websocket.Upgrader
registerChan chan *Client registerChan chan *Client
unregisterChan chan *Client unregisterChan chan *Client
clients *UserMap clients *UserMap
clientPool sync.Pool clientPool sync.Pool
onlineUserNum int64 onlineUserNum int64
onlineUserConnNum int64 onlineUserConnNum int64
compressor Compressor gzipCompressor Compressor
handler MessageHandler encoder Encoder
handler MessageHandler
handshakeTimeout time.Duration
readBufferSize, WriteBufferSize int
validate *validator.Validate
} }
func newWsServer(opts ...Option) (*WsServer, error) { func newWsServer(opts ...Option) (*WsServer, error) {
@ -50,18 +57,18 @@ func newWsServer(opts ...Option) (*WsServer, error) {
} }
return &WsServer{ return &WsServer{
port: config.port, port: config.port,
wsMaxConnNum: config.maxConnNum, wsMaxConnNum: config.maxConnNum,
wsUpGrader: &websocket.Upgrader{ handshakeTimeout: config.handshakeTimeout,
HandshakeTimeout: config.handshakeTimeout, readBufferSize: config.messageMaxMsgLength,
ReadBufferSize: config.messageMaxMsgLength,
CheckOrigin: func(r *http.Request) bool { return true },
},
clientPool: sync.Pool{ clientPool: sync.Pool{
New: func() interface{} { New: func() interface{} {
return new(Client) return new(Client)
}, },
}, },
validate: validator.New(),
clients: newUserMap(),
handler: NewGrpcHandler(),
}, nil }, nil
} }
func (ws *WsServer) Run() error { func (ws *WsServer) Run() error {
@ -71,53 +78,115 @@ func (ws *WsServer) Run() error {
select { select {
case client = <-ws.registerChan: case client = <-ws.registerChan:
ws.registerClient(client) ws.registerClient(client)
case client = <-h.unregisterChan: case client = <-ws.unregisterChan:
h.unregisterClient(client) ws.unregisterClient(client)
case msg = <-h.readChan:
h.messageHandler(msg)
} }
} }
}() }()
http.HandleFunc("/", ws.wsHandler) //Get request from client to handle by wsHandler
return http.ListenAndServe(":"+utils.IntToString(ws.port), nil) //Start listening
} }
func (ws *WsServer) registerClient(client *Client) { func (ws *WsServer) registerClient(client *Client) {
var ( var (
ok bool userOK bool
clientOK bool
cli *Client cli *Client
) )
cli, userOK,clientOK = ws.clients.Get(client.userID,client.platformID)
if cli, ok = h.clients.Get(client.key); ok == false { if !userOK {
h.clients.Set(client.key, client) ws.clients.Set(client.userID,client)
atomic.AddInt64(&h.onlineConnections, 1) atomic.AddInt64(&ws.onlineUserNum, 1)
fmt.Println("R在线用户数量:", h.onlineConnections) atomic.AddInt64(&ws.onlineUserConnNum, 1)
return fmt.Println("R在线用户数量:", ws.onlineUserNum)
fmt.Println("R在线用户连接数量:", ws.onlineUserConnNum)
}else{
if clientOK {//已经有同平台的连接存在
ws.clients.Set(client.userID,client)
ws.multiTerminalLoginChecker(cli)
}else{
ws.clients.Set(client.userID,client)
atomic.AddInt64(&ws.onlineUserConnNum, 1)
fmt.Println("R在线用户数量:", ws.onlineUserNum)
fmt.Println("R在线用户连接数量:", ws.onlineUserConnNum)
}
} }
if client.onlineAt > cli.onlineAt {
h.clients.Set(client.key, client)
h.close(cli)
return
}
h.close(client)
} }
http.HandleFunc("/", ws.wsHandler) //Get request from client to handle by wsHandler
return http.ListenAndServe(":"+utils.IntToString(ws.port), nil) //Start listening func (ws *WsServer) multiTerminalLoginChecker(client *Client) {
}
func (ws *WsServer) unregisterClient(client *Client) {
isDeleteUser:=ws.clients.delete(client.userID,client.platformID)
if isDeleteUser {
atomic.AddInt64(&ws.onlineUserNum, -1)
}
atomic.AddInt64(&ws.onlineUserConnNum, -1)
fmt.Println("R在线用户数量:", ws.onlineUserNum)
fmt.Println("R在线用户连接数量:", ws.onlineUserConnNum)
}
} }
func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
context := newContext(w, r) context := newContext(w, r)
if isPass, compression := ws.headerCheck(w, r, operationID); isPass { if ws.onlineUserConnNum >= ws.wsMaxConnNum {
conn, err := ws.wsUpGrader.Upgrade(w, r, nil) //Conn is obtained through the upgraded escalator httpError(context, constant.ErrConnOverMaxNumLimit)
if err != nil { return
log.Error(operationID, "upgrade http conn err", err.Error(), query)
return
} else {
newConn := &UserConn{conn, new(sync.Mutex), utils.StringToInt32(query["platformID"][0]), 0, compression, query["sendID"][0], false, query["token"][0], conn.RemoteAddr().String() + "_" + strconv.Itoa(int(utils.GetCurrentTimestampByMill()))}
userCount++
ws.addUserConn(query["sendID"][0], utils.StringToInt(query["platformID"][0]), newConn, query["token"][0], newConn.connID, operationID)
go ws.readMsg(newConn)
}
} else {
log.Error(operationID, "headerCheck failed ")
} }
var (
token string
userID string
platformID string
exists bool
compression bool
compressor Compressor
)
token, exists = context.Query(TOKEN)
if !exists {
httpError(context, constant.ErrConnArgsErr)
return
}
userID, exists = context.Query(USERID)
if !exists {
httpError(context, constant.ErrConnArgsErr)
return
}
platformID, exists = context.Query(PLATFORM_ID)
if !exists {
httpError(context, constant.ErrConnArgsErr)
return
}
err := tokenverify.WsVerifyToken(token, userID, platformID)
if err != nil {
httpError(context, err)
return
}
wsLongConn:=newGWebSocket(constant.WebSocket,ws.handshakeTimeout,ws.readBufferSize)
err = wsLongConn.GenerateLongConn(w, r)
if err != nil {
httpError(context, err)
return
}
compressProtoc, exists := context.Query(COMPRESSION)
if exists {
if compressProtoc==GZIP_COMPRESSION_PROTOCAL{
compression = true
compressor = ws.gzipCompressor
}
}
compressProtoc, exists = context.GetHeader(COMPRESSION)
if exists {
if compressProtoc==GZIP_COMPRESSION_PROTOCAL {
compression = true
compressor = ws.gzipCompressor
}
}
client:=ws.clientPool.Get().(*Client)
client.ResetClient(context,wsLongConn,compression,compressor,ws.encoder,ws.handler,ws.unregisterChan,ws.validate)
ws.registerChan <- client
go client.readMessage()
} }

View File

@ -7,7 +7,7 @@ type configs struct {
//长连接监听端口 //长连接监听端口
port int port int
//长连接允许最大链接数 //长连接允许最大链接数
maxConnNum int maxConnNum int64
//连接握手超时时间 //连接握手超时时间
handshakeTimeout time.Duration handshakeTimeout time.Duration
//允许消息最大长度 //允许消息最大长度
@ -19,7 +19,7 @@ func WithPort(port int) Option {
opt.port = port opt.port = port
} }
} }
func WithMaxConnNum(num int) Option { func WithMaxConnNum(num int64) Option {
return func(opt *configs) { return func(opt *configs) {
opt.maxConnNum = num opt.maxConnNum = num
} }

View File

@ -9,24 +9,24 @@ type UserMap struct {
func newUserMap() *UserMap { func newUserMap() *UserMap {
return &UserMap{} return &UserMap{}
} }
func (u *UserMap) GetAll(key string) []*Client { func (u *UserMap) GetAll(key string) ([]*Client, bool) {
allClients, ok := u.m.Load(key) allClients, ok := u.m.Load(key)
if ok { if ok {
return allClients.([]*Client) return allClients.([]*Client), ok
} }
return nil return nil, ok
} }
func (u *UserMap) Get(key string, platformID int32) (*Client, bool) { func (u *UserMap) Get(key string, platformID int) (*Client, bool, bool) {
allClients, existed := u.m.Load(key) allClients, userExisted := u.m.Load(key)
if existed { if userExisted {
for _, client := range allClients.([]*Client) { for _, client := range allClients.([]*Client) {
if client.PlatformID == platformID { if client.platformID == platformID {
return client, existed return client, userExisted, true
} }
} }
return nil, false return nil, userExisted, false
} }
return nil, existed return nil, userExisted, false
} }
func (u *UserMap) Set(key string, v *Client) { func (u *UserMap) Set(key string, v *Client) {
allClients, existed := u.m.Load(key) allClients, existed := u.m.Load(key)
@ -40,24 +40,25 @@ func (u *UserMap) Set(key string, v *Client) {
u.m.Store(key, clients) u.m.Store(key, clients)
} }
} }
func (u *UserMap) delete(key string, platformID int32) { func (u *UserMap) delete(key string, platformID int) (isDeleteUser bool) {
allClients, existed := u.m.Load(key) allClients, existed := u.m.Load(key)
if existed { if existed {
oldClients := allClients.([]*Client) oldClients := allClients.([]*Client)
a := make([]*Client, 3)
a := make([]*Client, len(oldClients))
for _, client := range oldClients { for _, client := range oldClients {
if client.PlatformID != platformID { if client.platformID != platformID {
a = append(a, client) a = append(a, client)
} }
} }
if len(a) == 0 { if len(a) == 0 {
u.m.Delete(key) u.m.Delete(key)
return true
} else { } else {
u.m.Store(key, a) u.m.Store(key, a)
return false
} }
} }
return existed
} }
func (u *UserMap) DeleteAll(key string) { func (u *UserMap) DeleteAll(key string) {
u.m.Delete(key) u.m.Delete(key)

View File

@ -54,6 +54,14 @@ var (
// //
ErrMutedInGroup = &ErrInfo{MutedInGroup, "MutedInGroup", ""} ErrMutedInGroup = &ErrInfo{MutedInGroup, "MutedInGroup", ""}
ErrMutedGroup = &ErrInfo{MutedGroup, "MutedGroup", ""} ErrMutedGroup = &ErrInfo{MutedGroup, "MutedGroup", ""}
ErrConnOverMaxNumLimit = &ErrInfo{ConnOverMaxNumLimit, "ConnOverMaxNumLimit", ""}
ErrConnOverMaxNumLimit = &ErrInfo{ConnOverMaxNumLimit, "ConnOverMaxNumLimit", ""}
ErrConnOverMaxNumLimit = &ErrInfo{ConnOverMaxNumLimit, "ConnOverMaxNumLimit", ""}
ErrConnArgsErr = &ErrInfo{ConnArgsErr, "args err, need token, sendID, platformID", ""}
ErrConnUpdateErr = &ErrInfo{ConnArgsErr, "upgrade http conn err", ""}
) )
const ( const (
@ -142,6 +150,13 @@ const (
MessageHasReadDisable = 96001 MessageHasReadDisable = 96001
) )
// 长连接网关错误码
const (
ConnOverMaxNumLimit = 970001
ConnArgsErr = 970002
ConnUpdateErr = 970003
)
// temp // temp
var ( var (