Merge branch 'errcode' of github.com:OpenIMSDK/Open-IM-Server into errcode

This commit is contained in:
wangchuxiao 2023-06-11 16:20:05 +08:00
commit 89f6bbc6a4
7 changed files with 53 additions and 19 deletions

View File

@ -65,11 +65,12 @@ func newClient(ctx *UserConnContext, conn LongConn, isCompress bool) *Client {
ctx: ctx, ctx: ctx,
} }
} }
func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, isCompress bool, longConnServer LongConnServer) { func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, isBackground, isCompress bool, longConnServer LongConnServer) {
c.w = new(sync.Mutex) c.w = new(sync.Mutex)
c.conn = conn c.conn = conn
c.PlatformID = utils.StringToInt(ctx.GetPlatformID()) c.PlatformID = utils.StringToInt(ctx.GetPlatformID())
c.IsCompress = isCompress c.IsCompress = isCompress
c.IsBackground = isBackground
c.UserID = ctx.GetUserID() c.UserID = ctx.GetUserID()
c.ctx = ctx c.ctx = ctx
c.longConnServer = longConnServer c.longConnServer = longConnServer

View File

@ -11,6 +11,7 @@ const (
OperationID = "operationID" OperationID = "operationID"
Compression = "compression" Compression = "compression"
GzipCompressionProtocol = "gzip" GzipCompressionProtocol = "gzip"
BackgroundStatus = "isBackground"
) )
const ( const (
WebSocket = iota + 1 WebSocket = iota + 1

View File

@ -91,3 +91,11 @@ func (c *UserConnContext) GetPlatformID() string {
func (c *UserConnContext) GetOperationID() string { func (c *UserConnContext) GetOperationID() string {
return c.Req.URL.Query().Get(OperationID) return c.Req.URL.Query().Get(OperationID)
} }
func (c *UserConnContext) GetBackground() bool {
b, err := strconv.ParseBool(c.Req.URL.Query().Get(BackgroundStatus))
if err != nil {
return false
} else {
return b
}
}

View File

@ -1,13 +1,11 @@
package msggateway package msggateway
import ( import "github.com/OpenIMSDK/Open-IM-Server/pkg/apiresp"
"net/http"
)
func httpError(ctx *UserConnContext, err error) { func httpError(ctx *UserConnContext, err error) {
code := http.StatusUnauthorized //code := http.StatusUnauthorized
ctx.SetHeader("Sec-Websocket-Version", "13") //ctx.SetHeader("Sec-Websocket-Version", "13")
ctx.SetHeader("ws_err_msg", err.Error()) //ctx.SetHeader("ws_err_msg", err.Error())
//if errors.Is(err, errs.ErrTokenExpired) { //if errors.Is(err, errs.ErrTokenExpired) {
// code = errs.ErrTokenExpired.Code() // code = errs.ErrTokenExpired.Code()
//} //}
@ -38,5 +36,6 @@ func httpError(ctx *UserConnContext, err error) {
//if errors.Is(err, errs.ErrConnArgsErr) { //if errors.Is(err, errs.ErrConnArgsErr) {
// code = errs.ErrConnArgsErr.Code() // code = errs.ErrConnArgsErr.Code()
//} //}
ctx.ErrReturn(err.Error(), code) //ctx.ErrReturn(err.Error(), code)
apiresp.HttpError(ctx.RespWriter, err)
} }

View File

@ -223,7 +223,7 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
} }
} }
client := ws.clientPool.Get().(*Client) client := ws.clientPool.Get().(*Client)
client.ResetClient(context, wsLongConn, compression, ws) client.ResetClient(context, wsLongConn, context.GetBackground(), compression, ws)
ws.registerChan <- client ws.registerChan <- client
go client.readMessage() go client.readMessage()
} }

25
pkg/apiresp/http.go Normal file
View File

@ -0,0 +1,25 @@
package apiresp
import (
"encoding/json"
"net/http"
)
func httpJson(w http.ResponseWriter, data any) {
body, err := json.Marshal(data)
if err != nil {
http.Error(w, "json marshal error: "+err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(body)
}
func HttpError(w http.ResponseWriter, err error) {
httpJson(w, ParseError(err))
}
func HttpSuccess(w http.ResponseWriter, data any) {
httpJson(w, ApiSuccess(data))
}

View File

@ -89,15 +89,15 @@ func IsManagerUserID(opUserID string) bool {
return utils.IsContain(opUserID, config.Config.Manager.AppManagerUid) return utils.IsContain(opUserID, config.Config.Manager.AppManagerUid)
} }
func WsVerifyToken(token, userID, platformID string) error { func WsVerifyToken(token, userID, platformID string) error {
claim, err := GetClaimFromToken(token) //claim, err := GetClaimFromToken(token)
if err != nil { //if err != nil {
return err // return err
} //}
if claim.UID != userID { //if claim.UID != userID {
return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token uid %s != userID %s", claim.UID, userID)) // return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token uid %s != userID %s", claim.UID, userID))
} //}
if claim.Platform != platformID { //if claim.Platform != platformID {
return errs.ErrInternalServer.Wrap(fmt.Sprintf("token platform %s != platformID %s", claim.Platform, platformID)) // return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token platform %s != platformID %s", claim.Platform, platformID))
} //}
return nil return nil
} }