refactor: simplify platformID handling and enhance UserConnContext structure

This commit is contained in:
withchao 2025-12-25 15:52:19 +08:00
parent 95ab761d8f
commit 9fefa916c8
5 changed files with 441 additions and 375 deletions

View File

@ -30,7 +30,6 @@ import (
"github.com/openimsdk/tools/errs" "github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log" "github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/mcontext" "github.com/openimsdk/tools/mcontext"
"github.com/openimsdk/tools/utils/stringutil"
) )
var ( var (
@ -85,7 +84,7 @@ type Client struct {
func (c *Client) ResetClient(ctx *UserConnContext, conn ClientConn, longConnServer LongConnServer) { func (c *Client) ResetClient(ctx *UserConnContext, conn ClientConn, longConnServer LongConnServer) {
c.w = new(sync.Mutex) c.w = new(sync.Mutex)
c.conn = conn c.conn = conn
c.PlatformID = stringutil.StringToInt(ctx.GetPlatformID()) c.PlatformID = ctx.GetPlatformID()
c.IsCompress = ctx.GetCompression() c.IsCompress = ctx.GetCompression()
c.IsBackground = ctx.GetBackground() c.IsBackground = ctx.GetBackground()
c.UserID = ctx.GetUserID() c.UserID = ctx.GetUserID()

View File

@ -15,18 +15,31 @@
package msggateway package msggateway
import ( import (
"github.com/openimsdk/open-im-server/v3/pkg/common/servererrs" "encoding/base64"
"encoding/json"
"net/http" "net/http"
"net/url" "net/url"
"strconv" "strconv"
"time" "time"
"github.com/openimsdk/open-im-server/v3/pkg/common/servererrs"
"github.com/openimsdk/protocol/constant" "github.com/openimsdk/protocol/constant"
"github.com/openimsdk/tools/utils/encrypt" "github.com/openimsdk/tools/utils/encrypt"
"github.com/openimsdk/tools/utils/stringutil"
"github.com/openimsdk/tools/utils/timeutil" "github.com/openimsdk/tools/utils/timeutil"
) )
type UserConnContextInfo struct {
Token string `json:"token"`
UserID string `json:"userID"`
PlatformID int `json:"platformID"`
OperationID string `json:"operationID"`
Compression string `json:"compression"`
SDKType string `json:"sdkType"`
SendResponse bool `json:"sendResponse"`
Background bool `json:"background"`
}
type UserConnContext struct { type UserConnContext struct {
RespWriter http.ResponseWriter RespWriter http.ResponseWriter
Req *http.Request Req *http.Request
@ -34,6 +47,7 @@ type UserConnContext struct {
Method string Method string
RemoteAddr string RemoteAddr string
ConnID string ConnID string
info *UserConnContextInfo
} }
func (c *UserConnContext) Deadline() (deadline time.Time, ok bool) { func (c *UserConnContext) Deadline() (deadline time.Time, ok bool) {
@ -57,7 +71,7 @@ func (c *UserConnContext) Value(key any) any {
case constant.ConnID: case constant.ConnID:
return c.GetConnID() return c.GetConnID()
case constant.OpUserPlatform: case constant.OpUserPlatform:
return constant.PlatformIDToName(stringutil.StringToInt(c.GetPlatformID())) return c.GetPlatformID()
case constant.RemoteAddr: case constant.RemoteAddr:
return c.RemoteAddr return c.RemoteAddr
default: default:
@ -82,30 +96,91 @@ func newContext(respWriter http.ResponseWriter, req *http.Request) *UserConnCont
func newTempContext() *UserConnContext { func newTempContext() *UserConnContext {
return &UserConnContext{ return &UserConnContext{
Req: &http.Request{URL: &url.URL{}}, Req: &http.Request{URL: &url.URL{}},
info: &UserConnContextInfo{},
} }
} }
func (c *UserConnContext) ParseEssentialArgs() error {
query := c.Req.URL.Query()
if data := query.Get("v"); data != "" {
return c.parseByJson(data)
} else {
return c.parseByQuery(query, c.Req.Header)
}
}
func (c *UserConnContext) parseByQuery(query url.Values, header http.Header) error {
info := UserConnContextInfo{
Token: query.Get(Token),
UserID: query.Get(WsUserID),
OperationID: query.Get(OperationID),
Compression: query.Get(Compression),
SDKType: query.Get(SDKType),
}
platformID, err := strconv.Atoi(query.Get(PlatformID))
if err != nil {
return servererrs.ErrConnArgsErr.WrapMsg("platformID is not int")
}
info.PlatformID = platformID
if val := query.Get(SendResponse); val != "" {
ok, err := strconv.ParseBool(val)
if err != nil {
return servererrs.ErrConnArgsErr.WrapMsg("isMsgResp is not bool")
}
info.SendResponse = ok
}
if info.Compression == "" {
info.Compression = header.Get(Compression)
}
background, err := strconv.ParseBool(query.Get(BackgroundStatus))
if err != nil {
return err
}
info.Background = background
return c.checkInfo(&info)
}
func (c *UserConnContext) parseByJson(data string) error {
reqInfo, err := base64.RawURLEncoding.DecodeString(data)
if err != nil {
return servererrs.ErrConnArgsErr.WrapMsg("data is not base64")
}
var info UserConnContextInfo
if err := json.Unmarshal(reqInfo, &info); err != nil {
return servererrs.ErrConnArgsErr.WrapMsg("data is not json", "info", err.Error())
}
return c.checkInfo(&info)
}
func (c *UserConnContext) checkInfo(info *UserConnContextInfo) error {
if info.OperationID == "" {
return servererrs.ErrConnArgsErr.WrapMsg("operationID is empty")
}
if info.Token == "" {
return servererrs.ErrConnArgsErr.WrapMsg("token is empty")
}
if info.UserID == "" {
return servererrs.ErrConnArgsErr.WrapMsg("sendID is empty")
}
if _, ok := constant.PlatformID2Name[info.PlatformID]; !ok {
return servererrs.ErrConnArgsErr.WrapMsg("platformID is invalid")
}
switch info.SDKType {
case "":
info.SDKType = GoSDK
case GoSDK, JsSDK:
default:
return servererrs.ErrConnArgsErr.WrapMsg("sdkType is invalid")
}
c.info = info
return nil
}
func (c *UserConnContext) GetRemoteAddr() string { func (c *UserConnContext) GetRemoteAddr() string {
return c.RemoteAddr return c.RemoteAddr
} }
func (c *UserConnContext) Query(key string) (string, bool) {
var value string
if value = c.Req.URL.Query().Get(key); value == "" {
return value, false
}
return value, true
}
func (c *UserConnContext) GetHeader(key string) (string, bool) {
var value string
if value = c.Req.Header.Get(key); value == "" {
return value, false
}
return value, true
}
func (c *UserConnContext) SetHeader(key, value string) { func (c *UserConnContext) SetHeader(key, value string) {
c.RespWriter.Header().Set(key, value) c.RespWriter.Header().Set(key, value)
} }
@ -119,93 +194,69 @@ func (c *UserConnContext) GetConnID() string {
} }
func (c *UserConnContext) GetUserID() string { func (c *UserConnContext) GetUserID() string {
return c.Req.URL.Query().Get(WsUserID) if c == nil || c.info == nil {
return ""
}
return c.info.UserID
} }
func (c *UserConnContext) GetPlatformID() string { func (c *UserConnContext) GetPlatformID() int {
return c.Req.URL.Query().Get(PlatformID) if c == nil || c.info == nil {
return 0
}
return c.info.PlatformID
} }
func (c *UserConnContext) GetOperationID() string { func (c *UserConnContext) GetOperationID() string {
return c.Req.URL.Query().Get(OperationID) if c == nil || c.info == nil {
return ""
}
return c.info.OperationID
} }
func (c *UserConnContext) SetOperationID(operationID string) { func (c *UserConnContext) SetOperationID(operationID string) {
values := c.Req.URL.Query() if c.info == nil {
values.Set(OperationID, operationID) c.info = &UserConnContextInfo{}
c.Req.URL.RawQuery = values.Encode() }
c.info.OperationID = operationID
} }
func (c *UserConnContext) GetToken() string { func (c *UserConnContext) GetToken() string {
return c.Req.URL.Query().Get(Token) if c == nil || c.info == nil {
return ""
}
return c.info.Token
} }
func (c *UserConnContext) GetCompression() bool { func (c *UserConnContext) GetCompression() bool {
compression, exists := c.Query(Compression) return c != nil && c.info != nil && c.info.Compression == GzipCompressionProtocol
if exists && compression == GzipCompressionProtocol {
return true
} else {
compression, exists := c.GetHeader(Compression)
if exists && compression == GzipCompressionProtocol {
return true
}
}
return false
} }
func (c *UserConnContext) GetSDKType() string { func (c *UserConnContext) GetSDKType() string {
sdkType := c.Req.URL.Query().Get(SDKType) if c == nil || c.info == nil {
if sdkType == "" { return GoSDK
sdkType = GoSDK }
switch c.info.SDKType {
case "", GoSDK:
return GoSDK
case JsSDK:
return JsSDK
default:
return ""
} }
return sdkType
} }
func (c *UserConnContext) ShouldSendResp() bool { func (c *UserConnContext) ShouldSendResp() bool {
errResp, exists := c.Query(SendResponse) return c != nil && c.info != nil && c.info.SendResponse
if exists {
b, err := strconv.ParseBool(errResp)
if err != nil {
return false
} else {
return b
}
}
return false
} }
func (c *UserConnContext) SetToken(token string) { func (c *UserConnContext) SetToken(token string) {
c.Req.URL.RawQuery = Token + "=" + token if c.info == nil {
c.info = &UserConnContextInfo{}
}
c.info.Token = token
} }
func (c *UserConnContext) GetBackground() bool { func (c *UserConnContext) GetBackground() bool {
b, err := strconv.ParseBool(c.Req.URL.Query().Get(BackgroundStatus)) return c != nil && c.info != nil && c.info.Background
if err != nil {
return false
}
return b
}
func (c *UserConnContext) ParseEssentialArgs() error {
_, exists := c.Query(Token)
if !exists {
return servererrs.ErrConnArgsErr.WrapMsg("token is empty")
}
_, exists = c.Query(WsUserID)
if !exists {
return servererrs.ErrConnArgsErr.WrapMsg("sendID is empty")
}
platformIDStr, exists := c.Query(PlatformID)
if !exists {
return servererrs.ErrConnArgsErr.WrapMsg("platformID is empty")
}
_, err := strconv.Atoi(platformIDStr)
if err != nil {
return servererrs.ErrConnArgsErr.WrapMsg("platformID is not int")
}
switch sdkType, _ := c.Query(SDKType); sdkType {
case "", GoSDK, JsSDK:
default:
return servererrs.ErrConnArgsErr.WrapMsg("sdkType is not go or js")
}
return nil
} }

View File

@ -5,7 +5,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"strconv"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -30,6 +29,8 @@ import (
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
) )
var wsSuccessResponse, _ = json.Marshal(&apiresp.ApiResponse{})
type LongConnServer interface { type LongConnServer interface {
Run(done chan error) error Run(done chan error) error
wsHandler(w http.ResponseWriter, r *http.Request) wsHandler(w http.ResponseWriter, r *http.Request)
@ -448,11 +449,11 @@ func (ws *WsServer) unregisterClient(client *Client) {
// validateRespWithRequest checks if the response matches the expected userID and platformID. // validateRespWithRequest checks if the response matches the expected userID and platformID.
func (ws *WsServer) validateRespWithRequest(ctx *UserConnContext, resp *pbAuth.ParseTokenResp) error { func (ws *WsServer) validateRespWithRequest(ctx *UserConnContext, resp *pbAuth.ParseTokenResp) error {
userID := ctx.GetUserID() userID := ctx.GetUserID()
platformID := stringutil.StringToInt32(ctx.GetPlatformID()) platformID := ctx.GetPlatformID()
if resp.UserID != userID { if resp.UserID != userID {
return servererrs.ErrTokenInvalid.WrapMsg(fmt.Sprintf("token uid %s != userID %s", resp.UserID, userID)) return servererrs.ErrTokenInvalid.WrapMsg(fmt.Sprintf("token uid %s != userID %s", resp.UserID, userID))
} }
if resp.PlatformID != platformID { if int(resp.PlatformID) != platformID {
return servererrs.ErrTokenInvalid.WrapMsg(fmt.Sprintf("token platform %d != platformID %d", resp.PlatformID, platformID)) return servererrs.ErrTokenInvalid.WrapMsg(fmt.Sprintf("token platform %d != platformID %d", resp.PlatformID, platformID))
} }
return nil return nil
@ -519,10 +520,16 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
log.ZWarn(connContext, "websocket upgrade failed", err) log.ZWarn(connContext, "websocket upgrade failed", err)
return return
} }
if connContext.ShouldSendResp() {
if err := conn.WriteMessage(websocket.TextMessage, wsSuccessResponse); err != nil {
log.ZWarn(connContext, "WriteMessage first response", err)
return
}
}
log.ZDebug(connContext, "new conn", "token", connContext.GetToken()) log.ZDebug(connContext, "new conn", "token", connContext.GetToken())
var pingInterval time.Duration var pingInterval time.Duration
if connContext.GetPlatformID() == strconv.Itoa(constant.WebPlatformID) { if connContext.GetPlatformID() == constant.WebPlatformID {
pingInterval = pingPeriod pingInterval = pingPeriod
} }

File diff suppressed because it is too large Load Diff

View File

@ -2,6 +2,7 @@ package group
import ( import (
"context" "context"
"errors"
"github.com/openimsdk/open-im-server/v3/internal/rpc/incrversion" "github.com/openimsdk/open-im-server/v3/internal/rpc/incrversion"
"github.com/openimsdk/open-im-server/v3/pkg/authverify" "github.com/openimsdk/open-im-server/v3/pkg/authverify"
@ -12,23 +13,24 @@ import (
pbgroup "github.com/openimsdk/protocol/group" pbgroup "github.com/openimsdk/protocol/group"
"github.com/openimsdk/protocol/sdkws" "github.com/openimsdk/protocol/sdkws"
"github.com/openimsdk/tools/errs" "github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/mcontext" "github.com/openimsdk/tools/mcontext"
"github.com/openimsdk/tools/utils/datautil" "github.com/openimsdk/tools/utils/datautil"
) )
const versionSyncLimit = 500 const versionSyncLimit = 500
func (g *groupServer) GetFullGroupMemberUserIDs(ctx context.Context, req *pbgroup.GetFullGroupMemberUserIDsReq) (*pbgroup.GetFullGroupMemberUserIDsResp, error) { func (s *groupServer) GetFullGroupMemberUserIDs(ctx context.Context, req *pbgroup.GetFullGroupMemberUserIDsReq) (*pbgroup.GetFullGroupMemberUserIDsResp, error) {
userIDs, err := g.db.FindGroupMemberUserID(ctx, req.GroupID) userIDs, err := s.db.FindGroupMemberUserID(ctx, req.GroupID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) { if !authverify.IsAppManagerUid(ctx, s.config.Share.IMAdminUserID) {
if !datautil.Contain(mcontext.GetOpUserID(ctx), userIDs...) { if !datautil.Contain(mcontext.GetOpUserID(ctx), userIDs...) {
return nil, errs.ErrNoPermission.WrapMsg("op user not in group") return nil, errs.ErrNoPermission.WrapMsg("op user not in group")
} }
} }
vl, err := g.db.FindMaxGroupMemberVersionCache(ctx, req.GroupID) vl, err := s.db.FindMaxGroupMemberVersionCache(ctx, req.GroupID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -146,8 +148,8 @@ func (s *groupServer) GetIncrementalGroupMember(ctx context.Context, req *pbgrou
return resp, nil return resp, nil
} }
func (g *groupServer) GetIncrementalJoinGroup(ctx context.Context, req *pbgroup.GetIncrementalJoinGroupReq) (*pbgroup.GetIncrementalJoinGroupResp, error) { func (s *groupServer) GetIncrementalJoinGroup(ctx context.Context, req *pbgroup.GetIncrementalJoinGroupReq) (*pbgroup.GetIncrementalJoinGroupResp, error) {
if err := authverify.CheckAccessV3(ctx, req.UserID, g.config.Share.IMAdminUserID); err != nil { if err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID); err != nil {
return nil, err return nil, err
} }
opt := incrversion.Option[*sdkws.GroupInfo, pbgroup.GetIncrementalJoinGroupResp]{ opt := incrversion.Option[*sdkws.GroupInfo, pbgroup.GetIncrementalJoinGroupResp]{
@ -155,9 +157,9 @@ func (g *groupServer) GetIncrementalJoinGroup(ctx context.Context, req *pbgroup.
VersionKey: req.UserID, VersionKey: req.UserID,
VersionID: req.VersionID, VersionID: req.VersionID,
VersionNumber: req.Version, VersionNumber: req.Version,
Version: g.db.FindJoinIncrVersion, Version: s.db.FindJoinIncrVersion,
CacheMaxVersion: g.db.FindMaxJoinGroupVersionCache, CacheMaxVersion: s.db.FindMaxJoinGroupVersionCache,
Find: g.getGroupsInfo, Find: s.getGroupsInfo,
Resp: func(version *model.VersionLog, delIDs []string, insertList, updateList []*sdkws.GroupInfo, full bool) *pbgroup.GetIncrementalJoinGroupResp { Resp: func(version *model.VersionLog, delIDs []string, insertList, updateList []*sdkws.GroupInfo, full bool) *pbgroup.GetIncrementalJoinGroupResp {
return &pbgroup.GetIncrementalJoinGroupResp{ return &pbgroup.GetIncrementalJoinGroupResp{
VersionID: version.ID.Hex(), VersionID: version.ID.Hex(),
@ -172,22 +174,29 @@ func (g *groupServer) GetIncrementalJoinGroup(ctx context.Context, req *pbgroup.
return opt.Build() return opt.Build()
} }
func (g *groupServer) BatchGetIncrementalGroupMember(ctx context.Context, req *pbgroup.BatchGetIncrementalGroupMemberReq) (*pbgroup.BatchGetIncrementalGroupMemberResp, error) { func (s *groupServer) BatchGetIncrementalGroupMember(ctx context.Context, req *pbgroup.BatchGetIncrementalGroupMemberReq) (*pbgroup.BatchGetIncrementalGroupMemberResp, error) {
var num int var num int
resp := make(map[string]*pbgroup.GetIncrementalGroupMemberResp) resp := make(map[string]*pbgroup.GetIncrementalGroupMemberResp)
for _, memberReq := range req.ReqList { for _, memberReq := range req.ReqList {
if _, ok := resp[memberReq.GroupID]; ok { if _, ok := resp[memberReq.GroupID]; ok {
continue continue
} }
memberResp, err := g.GetIncrementalGroupMember(ctx, memberReq) memberResp, err := s.GetIncrementalGroupMember(ctx, memberReq)
if err != nil { if err != nil {
if errors.Is(err, servererrs.ErrDismissedAlready) {
log.ZWarn(ctx, "Failed to get incremental group member", err, "groupID", memberReq.GroupID, "request", memberReq)
continue
}
return nil, err return nil, err
} }
resp[memberReq.GroupID] = memberResp resp[memberReq.GroupID] = memberResp
num += len(memberResp.Insert) + len(memberResp.Update) + len(memberResp.Delete) num += len(memberResp.Insert) + len(memberResp.Update) + len(memberResp.Delete)
if num >= versionSyncLimit { if num >= versionSyncLimit {
break break
} }
} }
return &pbgroup.BatchGetIncrementalGroupMemberResp{RespList: resp}, nil return &pbgroup.BatchGetIncrementalGroupMemberResp{RespList: resp}, nil
} }