Merge remote-tracking branch 'origin/errcode' into errcode

This commit is contained in:
withchao 2023-02-16 17:11:00 +08:00
commit ff542e83fe
11 changed files with 341 additions and 97 deletions

View File

@ -28,3 +28,39 @@ func (m *MsgCheck) SendMsg(ctx context.Context, req *msg.SendMsgReq) (*msg.SendM
resp, err := msg.NewMsgClient(cc).SendMsg(ctx, req)
return resp, err
}
func (m *MsgCheck) SendMsg(ctx context.Context, req *msg.SendMsgReq) (*msg.SendMsgResp, error) {
cc, err := m.getConn()
if err != nil {
return nil, err
}
resp, err := msg.NewMsgClient(cc).SendMsg(ctx, req)
return resp, err
}
func (m *MsgCheck) SendMsg(ctx context.Context, req *msg.SendMsgReq) (*msg.SendMsgResp, error) {
cc, err := m.getConn()
if err != nil {
return nil, err
}
resp, err := msg.NewMsgClient(cc).SendMsg(ctx, req)
return resp, err
}
func (m *MsgCheck) SendMsg(ctx context.Context, req *msg.SendMsgReq) (*msg.SendMsgResp, error) {
cc, err := m.getConn()
if err != nil {
return nil, err
}
resp, err := msg.NewMsgClient(cc).SendMsg(ctx, req)
return resp, err
}
func (m *MsgCheck) SendMsg(ctx context.Context, req *msg.SendMsgReq) (*msg.SendMsgResp, error) {
cc, err := m.getConn()
if err != nil {
return nil, err
}
resp, err := msg.NewMsgClient(cc).SendMsg(ctx, req)
return resp, err
}

View File

@ -3,14 +3,12 @@ package new
import (
"Open_IM/pkg/common/constant"
"Open_IM/pkg/utils"
"bytes"
"context"
"errors"
"fmt"
"github.com/go-playground/validator/v10"
"runtime/debug"
"sync"
"time"
)
const (
@ -35,35 +33,52 @@ const (
type Client struct {
w *sync.Mutex
conn LongConn
PlatformID int32
PushedMaxSeq uint32
IsCompress bool
platformID int
isCompress bool
userID string
IsBackground bool
token string
isBackground bool
connID string
onlineAt int64 // 上线时间戳(毫秒)
handler MessageHandler
unregisterChan chan *Client
compressor Compressor
encoder Encoder
userContext UserConnContext
validate *validator.Validate
closed bool
}
func newClient(conn LongConn, isCompress bool, userID string, isBackground bool, token string,
connID string, onlineAt int64, handler MessageHandler, unregisterChan chan *Client) *Client {
func newClient(ctx *UserConnContext, conn LongConn, isCompress bool, compressor Compressor, encoder Encoder,
handler MessageHandler, unregisterChan chan *Client, validate *validator.Validate) *Client {
return &Client{
w: new(sync.Mutex),
conn: conn,
IsCompress: isCompress,
userID: userID, IsBackground: isBackground, token: token,
connID: connID,
onlineAt: onlineAt,
platformID: utils.StringToInt(ctx.GetPlatformID()),
isCompress: isCompress,
userID: ctx.GetUserID(),
compressor: compressor,
encoder: encoder,
connID: ctx.GetConnID(),
onlineAt: utils.GetCurrentTimestampByMill(),
handler: handler,
unregisterChan: unregisterChan,
validate: validate,
}
}
func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, isCompress bool, compressor Compressor, encoder Encoder,
handler MessageHandler, unregisterChan chan *Client, validate *validator.Validate) {
c.w = new(sync.Mutex)
c.conn = conn
c.platformID = utils.StringToInt(ctx.GetPlatformID())
c.isCompress = isCompress
c.userID = ctx.GetUserID()
c.compressor = compressor
c.encoder = encoder
c.connID = ctx.GetConnID()
c.onlineAt = utils.GetCurrentTimestampByMill()
c.handler = handler
c.unregisterChan = unregisterChan
c.validate = validate
}
func (c *Client) readMessage() {
defer func() {
if r := recover(); r != nil {
@ -77,7 +92,7 @@ func (c *Client) readMessage() {
if returnErr != nil {
break
}
if c.closed == true {
if c.closed == true { //连接刚置位已经关闭,但是协程还没退出的场景
break
}
switch messageType {
@ -119,7 +134,8 @@ func (c *Client) handleMessage(message []byte) error {
return errors.New("exception conn userID not same to req userID")
}
ctx := context.Background()
ctx = context.WithValue(ctx, "operationID", binaryReq.OperationID)
ctx = context.WithValue(ctx, c.connID, binaryReq.OperationID)
ctx = context.WithValue(ctx, OPERATION_ID, binaryReq.OperationID)
ctx = context.WithValue(ctx, "userID", binaryReq.SendID)
var messageErr error
var resp []byte
@ -173,7 +189,7 @@ func (c *Client) writeMsg(resp Resp) error {
return utils.Wrap(err, "")
}
_ = c.conn.SetWriteTimeout(60)
if c.IsCompress {
if c.isCompress {
var compressErr error
resultBuf, compressErr = c.compressor.Compress(encodeBuf)
if compressErr != nil {

View File

@ -0,0 +1,10 @@
package new
const (
USERID = "sendID"
PLATFORM_ID = "platformID"
TOKEN = "token"
OPERATION_ID = "operationID"
COMPRESSION = "compression"
GZIP_COMPRESSION_PROTOCAL = "gzip"
)

View File

@ -1,6 +1,10 @@
package new
import "net/http"
import (
"Open_IM/pkg/utils"
"net/http"
"strconv"
)
type UserConnContext struct {
RespWriter http.ResponseWriter
@ -19,9 +23,32 @@ func newContext(respWriter http.ResponseWriter, req *http.Request) *UserConnCont
RemoteAddr: req.RemoteAddr,
}
}
func (c *UserConnContext) Query(key string) string {
return c.Req.URL.Query().Get(key)
func (c *UserConnContext) Query(key string) (string, bool) {
var value string
if value = c.Req.URL.Query().Get(key); value == "" {
return value, false
}
return value, true
}
func (c *UserConnContext) GetHeader(key string) string {
return c.Req.Header.Get(key)
func (c *UserConnContext) GetHeader(key string) (string, bool) {
var value string
if value = c.Req.Header.Get(key); value == "" {
return value, false
}
return value, true
}
func (c *UserConnContext) SetHeader(key, value string) {
c.RespWriter.Header().Set(key, value)
}
func (c *UserConnContext) ErrReturn(error string, code int) {
http.Error(c.RespWriter, error, code)
}
func (c *UserConnContext) GetConnID() string {
return c.RemoteAddr + "_" + strconv.Itoa(int(utils.GetCurrentTimestampByMill()))
}
func (c *UserConnContext) GetUserID() string {
return c.Req.URL.Query().Get(USERID)
}
func (c *UserConnContext) GetPlatformID() string {
return c.Req.URL.Query().Get(PLATFORM_ID)
}

View File

@ -0,0 +1,44 @@
package new
import (
"Open_IM/pkg/common/constant"
"errors"
"net/http"
)
func httpError(ctx *UserConnContext, err error) {
code := http.StatusUnauthorized
ctx.SetHeader("Sec-Websocket-Version", "13")
ctx.SetHeader("ws_err_msg", err.Error())
if errors.Is(err, constant.ErrTokenExpired) {
code = int(constant.ErrTokenExpired.ErrCode)
}
if errors.Is(err, constant.ErrTokenInvalid) {
code = int(constant.ErrTokenInvalid.ErrCode)
}
if errors.Is(err, constant.ErrTokenMalformed) {
code = int(constant.ErrTokenMalformed.ErrCode)
}
if errors.Is(err, constant.ErrTokenNotValidYet) {
code = int(constant.ErrTokenNotValidYet.ErrCode)
}
if errors.Is(err, constant.ErrTokenUnknown) {
code = int(constant.ErrTokenUnknown.ErrCode)
}
if errors.Is(err, constant.ErrTokenKicked) {
code = int(constant.ErrTokenKicked.ErrCode)
}
if errors.Is(err, constant.ErrTokenDifferentPlatformID) {
code = int(constant.ErrTokenDifferentPlatformID.ErrCode)
}
if errors.Is(err, constant.ErrTokenDifferentUserID) {
code = int(constant.ErrTokenDifferentUserID.ErrCode)
}
if errors.Is(err, constant.ErrConnOverMaxNumLimit) {
code = int(constant.ErrConnOverMaxNumLimit.ErrCode)
}
if errors.Is(err, constant.ErrConnArgsErr) {
code = int(constant.ErrConnArgsErr.ErrCode)
}
ctx.ErrReturn(err.Error(), code)
}

View File

@ -26,19 +26,37 @@ type LongConn interface {
SetConnNil()
//Check the connection of the current and when it was sent are the same
CheckSendConnDiffNow() bool
//
GenerateLongConn(w http.ResponseWriter, r *http.Request) error
}
type GWebSocket struct {
protocolType int
conn *websocket.Conn
handshakeTimeout time.Duration
readBufferSize, WriteBufferSize int
}
func NewDefault(protocolType int) *GWebSocket {
return &GWebSocket{protocolType: protocolType}
func newGWebSocket(protocolType int, handshakeTimeout time.Duration, readBufferSize int) *GWebSocket {
return &GWebSocket{protocolType: protocolType, handshakeTimeout: handshakeTimeout, readBufferSize: readBufferSize}
}
func (d *GWebSocket) Close() error {
return d.conn.Close()
}
func (d *GWebSocket) GenerateLongConn(w http.ResponseWriter, r *http.Request) error {
upgrader := &websocket.Upgrader{
HandshakeTimeout: d.handshakeTimeout,
ReadBufferSize: d.readBufferSize,
CheckOrigin: func(r *http.Request) bool { return true },
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return err
}
d.conn = conn
return nil
}
func (d *GWebSocket) WriteMessage(messageType int, message []byte) error {
d.setSendConn(d.conn)
return d.conn.WriteMessage(messageType, message)

View File

@ -1,6 +1,9 @@
package new
import "context"
import (
"Open_IM/internal/common/check"
"context"
)
type Req struct {
ReqIdentifier int32 `json:"reqIdentifier" validate:"required"`
@ -30,6 +33,11 @@ type MessageHandler interface {
var _ MessageHandler = (*GrpcHandler)(nil)
type GrpcHandler struct {
msg *check.MsgCheck
}
func NewGrpcHandler(msg *check.MsgCheck) *GrpcHandler {
return &GrpcHandler{msg: msg}
}
func (g GrpcHandler) GetSeq(context context.Context, data Req) ([]byte, error) {

View File

@ -1,8 +1,11 @@
package new
import (
"bytes"
"Open_IM/pkg/common/constant"
"Open_IM/pkg/utils"
"errors"
"fmt"
"github.com/go-playground/validator/v10"
"github.com/gorilla/websocket"
"net/http"
"sync"
@ -10,12 +13,12 @@ import (
"time"
)
var bufferPool = sync.Pool{
New: func() interface{} {
return make([]byte, 1000)
},
}
type LongConnServer interface {
Run() error
}
@ -28,7 +31,7 @@ type Server struct {
}
type WsServer struct {
port int
wsMaxConnNum int
wsMaxConnNum int64
wsUpGrader *websocket.Upgrader
registerChan chan *Client
unregisterChan chan *Client
@ -36,8 +39,12 @@ type WsServer struct {
clientPool sync.Pool
onlineUserNum int64
onlineUserConnNum int64
compressor Compressor
gzipCompressor Compressor
encoder Encoder
handler MessageHandler
handshakeTimeout time.Duration
readBufferSize, WriteBufferSize int
validate *validator.Validate
}
func newWsServer(opts ...Option) (*WsServer, error) {
@ -52,16 +59,16 @@ func newWsServer(opts ...Option) (*WsServer, error) {
return &WsServer{
port: config.port,
wsMaxConnNum: config.maxConnNum,
wsUpGrader: &websocket.Upgrader{
HandshakeTimeout: config.handshakeTimeout,
ReadBufferSize: config.messageMaxMsgLength,
CheckOrigin: func(r *http.Request) bool { return true },
},
handshakeTimeout: config.handshakeTimeout,
readBufferSize: config.messageMaxMsgLength,
clientPool: sync.Pool{
New: func() interface{} {
return new(Client)
},
},
validate: validator.New(),
clients: newUserMap(),
handler: NewGrpcHandler(),
}, nil
}
func (ws *WsServer) Run() error {
@ -71,53 +78,115 @@ func (ws *WsServer) Run() error {
select {
case client = <-ws.registerChan:
ws.registerClient(client)
case client = <-h.unregisterChan:
h.unregisterClient(client)
case msg = <-h.readChan:
h.messageHandler(msg)
case client = <-ws.unregisterChan:
ws.unregisterClient(client)
}
}
}()
http.HandleFunc("/", ws.wsHandler) //Get request from client to handle by wsHandler
return http.ListenAndServe(":"+utils.IntToString(ws.port), nil) //Start listening
}
func (ws *WsServer) registerClient(client *Client) {
var (
ok bool
userOK bool
clientOK bool
cli *Client
)
if cli, ok = h.clients.Get(client.key); ok == false {
h.clients.Set(client.key, client)
atomic.AddInt64(&h.onlineConnections, 1)
fmt.Println("R在线用户数量:", h.onlineConnections)
return
cli, userOK,clientOK = ws.clients.Get(client.userID,client.platformID)
if !userOK {
ws.clients.Set(client.userID,client)
atomic.AddInt64(&ws.onlineUserNum, 1)
atomic.AddInt64(&ws.onlineUserConnNum, 1)
fmt.Println("R在线用户数量:", ws.onlineUserNum)
fmt.Println("R在线用户连接数量:", ws.onlineUserConnNum)
}else{
if clientOK {//已经有同平台的连接存在
ws.clients.Set(client.userID,client)
ws.multiTerminalLoginChecker(cli)
}else{
ws.clients.Set(client.userID,client)
atomic.AddInt64(&ws.onlineUserConnNum, 1)
fmt.Println("R在线用户数量:", ws.onlineUserNum)
fmt.Println("R在线用户连接数量:", ws.onlineUserConnNum)
}
}
if client.onlineAt > cli.onlineAt {
h.clients.Set(client.key, client)
h.close(cli)
return
}
h.close(client)
}
http.HandleFunc("/", ws.wsHandler) //Get request from client to handle by wsHandler
return http.ListenAndServe(":"+utils.IntToString(ws.port), nil) //Start listening
func (ws *WsServer) multiTerminalLoginChecker(client *Client) {
}
func (ws *WsServer) unregisterClient(client *Client) {
isDeleteUser:=ws.clients.delete(client.userID,client.platformID)
if isDeleteUser {
atomic.AddInt64(&ws.onlineUserNum, -1)
}
atomic.AddInt64(&ws.onlineUserConnNum, -1)
fmt.Println("R在线用户数量:", ws.onlineUserNum)
fmt.Println("R在线用户连接数量:", ws.onlineUserConnNum)
}
}
func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
context := newContext(w, r)
if isPass, compression := ws.headerCheck(w, r, operationID); isPass {
conn, err := ws.wsUpGrader.Upgrade(w, r, nil) //Conn is obtained through the upgraded escalator
if err != nil {
log.Error(operationID, "upgrade http conn err", err.Error(), query)
if ws.onlineUserConnNum >= ws.wsMaxConnNum {
httpError(context, constant.ErrConnOverMaxNumLimit)
return
} else {
newConn := &UserConn{conn, new(sync.Mutex), utils.StringToInt32(query["platformID"][0]), 0, compression, query["sendID"][0], false, query["token"][0], conn.RemoteAddr().String() + "_" + strconv.Itoa(int(utils.GetCurrentTimestampByMill()))}
userCount++
ws.addUserConn(query["sendID"][0], utils.StringToInt(query["platformID"][0]), newConn, query["token"][0], newConn.connID, operationID)
go ws.readMsg(newConn)
}
} else {
log.Error(operationID, "headerCheck failed ")
var (
token string
userID string
platformID string
exists bool
compression bool
compressor Compressor
)
token, exists = context.Query(TOKEN)
if !exists {
httpError(context, constant.ErrConnArgsErr)
return
}
userID, exists = context.Query(USERID)
if !exists {
httpError(context, constant.ErrConnArgsErr)
return
}
platformID, exists = context.Query(PLATFORM_ID)
if !exists {
httpError(context, constant.ErrConnArgsErr)
return
}
err := tokenverify.WsVerifyToken(token, userID, platformID)
if err != nil {
httpError(context, err)
return
}
wsLongConn:=newGWebSocket(constant.WebSocket,ws.handshakeTimeout,ws.readBufferSize)
err = wsLongConn.GenerateLongConn(w, r)
if err != nil {
httpError(context, err)
return
}
compressProtoc, exists := context.Query(COMPRESSION)
if exists {
if compressProtoc==GZIP_COMPRESSION_PROTOCAL{
compression = true
compressor = ws.gzipCompressor
}
}
compressProtoc, exists = context.GetHeader(COMPRESSION)
if exists {
if compressProtoc==GZIP_COMPRESSION_PROTOCAL {
compression = true
compressor = ws.gzipCompressor
}
}
client:=ws.clientPool.Get().(*Client)
client.ResetClient(context,wsLongConn,compression,compressor,ws.encoder,ws.handler,ws.unregisterChan,ws.validate)
ws.registerChan <- client
go client.readMessage()
}

View File

@ -7,7 +7,7 @@ type configs struct {
//长连接监听端口
port int
//长连接允许最大链接数
maxConnNum int
maxConnNum int64
//连接握手超时时间
handshakeTimeout time.Duration
//允许消息最大长度
@ -19,7 +19,7 @@ func WithPort(port int) Option {
opt.port = port
}
}
func WithMaxConnNum(num int) Option {
func WithMaxConnNum(num int64) Option {
return func(opt *configs) {
opt.maxConnNum = num
}

View File

@ -9,24 +9,24 @@ type UserMap struct {
func newUserMap() *UserMap {
return &UserMap{}
}
func (u *UserMap) GetAll(key string) []*Client {
func (u *UserMap) GetAll(key string) ([]*Client, bool) {
allClients, ok := u.m.Load(key)
if ok {
return allClients.([]*Client)
return allClients.([]*Client), ok
}
return nil
return nil, ok
}
func (u *UserMap) Get(key string, platformID int32) (*Client, bool) {
allClients, existed := u.m.Load(key)
if existed {
func (u *UserMap) Get(key string, platformID int) (*Client, bool, bool) {
allClients, userExisted := u.m.Load(key)
if userExisted {
for _, client := range allClients.([]*Client) {
if client.PlatformID == platformID {
return client, existed
if client.platformID == platformID {
return client, userExisted, true
}
}
return nil, false
return nil, userExisted, false
}
return nil, existed
return nil, userExisted, false
}
func (u *UserMap) Set(key string, v *Client) {
allClients, existed := u.m.Load(key)
@ -40,24 +40,25 @@ func (u *UserMap) Set(key string, v *Client) {
u.m.Store(key, clients)
}
}
func (u *UserMap) delete(key string, platformID int32) {
func (u *UserMap) delete(key string, platformID int) (isDeleteUser bool) {
allClients, existed := u.m.Load(key)
if existed {
oldClients := allClients.([]*Client)
a := make([]*Client, len(oldClients))
a := make([]*Client, 3)
for _, client := range oldClients {
if client.PlatformID != platformID {
if client.platformID != platformID {
a = append(a, client)
}
}
if len(a) == 0 {
u.m.Delete(key)
return true
} else {
u.m.Store(key, a)
return false
}
}
return existed
}
func (u *UserMap) DeleteAll(key string) {
u.m.Delete(key)

View File

@ -54,6 +54,14 @@ var (
//
ErrMutedInGroup = &ErrInfo{MutedInGroup, "MutedInGroup", ""}
ErrMutedGroup = &ErrInfo{MutedGroup, "MutedGroup", ""}
ErrConnOverMaxNumLimit = &ErrInfo{ConnOverMaxNumLimit, "ConnOverMaxNumLimit", ""}
ErrConnOverMaxNumLimit = &ErrInfo{ConnOverMaxNumLimit, "ConnOverMaxNumLimit", ""}
ErrConnOverMaxNumLimit = &ErrInfo{ConnOverMaxNumLimit, "ConnOverMaxNumLimit", ""}
ErrConnArgsErr = &ErrInfo{ConnArgsErr, "args err, need token, sendID, platformID", ""}
ErrConnUpdateErr = &ErrInfo{ConnArgsErr, "upgrade http conn err", ""}
)
const (
@ -142,6 +150,13 @@ const (
MessageHasReadDisable = 96001
)
// 长连接网关错误码
const (
ConnOverMaxNumLimit = 970001
ConnArgsErr = 970002
ConnUpdateErr = 970003
)
// temp
var (