mirror of
https://github.com/openimsdk/open-im-server.git
synced 2025-12-30 14:27:02 +08:00
refactor: simplify platformID handling and enhance UserConnContext structure
This commit is contained in:
parent
95ab761d8f
commit
9fefa916c8
@ -30,7 +30,6 @@ import (
|
||||
"github.com/openimsdk/tools/errs"
|
||||
"github.com/openimsdk/tools/log"
|
||||
"github.com/openimsdk/tools/mcontext"
|
||||
"github.com/openimsdk/tools/utils/stringutil"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -85,7 +84,7 @@ type Client struct {
|
||||
func (c *Client) ResetClient(ctx *UserConnContext, conn ClientConn, longConnServer LongConnServer) {
|
||||
c.w = new(sync.Mutex)
|
||||
c.conn = conn
|
||||
c.PlatformID = stringutil.StringToInt(ctx.GetPlatformID())
|
||||
c.PlatformID = ctx.GetPlatformID()
|
||||
c.IsCompress = ctx.GetCompression()
|
||||
c.IsBackground = ctx.GetBackground()
|
||||
c.UserID = ctx.GetUserID()
|
||||
|
||||
@ -15,18 +15,31 @@
|
||||
package msggateway
|
||||
|
||||
import (
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/servererrs"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/servererrs"
|
||||
|
||||
"github.com/openimsdk/protocol/constant"
|
||||
"github.com/openimsdk/tools/utils/encrypt"
|
||||
"github.com/openimsdk/tools/utils/stringutil"
|
||||
"github.com/openimsdk/tools/utils/timeutil"
|
||||
)
|
||||
|
||||
type UserConnContextInfo struct {
|
||||
Token string `json:"token"`
|
||||
UserID string `json:"userID"`
|
||||
PlatformID int `json:"platformID"`
|
||||
OperationID string `json:"operationID"`
|
||||
Compression string `json:"compression"`
|
||||
SDKType string `json:"sdkType"`
|
||||
SendResponse bool `json:"sendResponse"`
|
||||
Background bool `json:"background"`
|
||||
}
|
||||
|
||||
type UserConnContext struct {
|
||||
RespWriter http.ResponseWriter
|
||||
Req *http.Request
|
||||
@ -34,6 +47,7 @@ type UserConnContext struct {
|
||||
Method string
|
||||
RemoteAddr string
|
||||
ConnID string
|
||||
info *UserConnContextInfo
|
||||
}
|
||||
|
||||
func (c *UserConnContext) Deadline() (deadline time.Time, ok bool) {
|
||||
@ -57,7 +71,7 @@ func (c *UserConnContext) Value(key any) any {
|
||||
case constant.ConnID:
|
||||
return c.GetConnID()
|
||||
case constant.OpUserPlatform:
|
||||
return constant.PlatformIDToName(stringutil.StringToInt(c.GetPlatformID()))
|
||||
return c.GetPlatformID()
|
||||
case constant.RemoteAddr:
|
||||
return c.RemoteAddr
|
||||
default:
|
||||
@ -82,30 +96,91 @@ func newContext(respWriter http.ResponseWriter, req *http.Request) *UserConnCont
|
||||
|
||||
func newTempContext() *UserConnContext {
|
||||
return &UserConnContext{
|
||||
Req: &http.Request{URL: &url.URL{}},
|
||||
Req: &http.Request{URL: &url.URL{}},
|
||||
info: &UserConnContextInfo{},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *UserConnContext) ParseEssentialArgs() error {
|
||||
query := c.Req.URL.Query()
|
||||
if data := query.Get("v"); data != "" {
|
||||
return c.parseByJson(data)
|
||||
} else {
|
||||
return c.parseByQuery(query, c.Req.Header)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *UserConnContext) parseByQuery(query url.Values, header http.Header) error {
|
||||
info := UserConnContextInfo{
|
||||
Token: query.Get(Token),
|
||||
UserID: query.Get(WsUserID),
|
||||
OperationID: query.Get(OperationID),
|
||||
Compression: query.Get(Compression),
|
||||
SDKType: query.Get(SDKType),
|
||||
}
|
||||
platformID, err := strconv.Atoi(query.Get(PlatformID))
|
||||
if err != nil {
|
||||
return servererrs.ErrConnArgsErr.WrapMsg("platformID is not int")
|
||||
}
|
||||
info.PlatformID = platformID
|
||||
if val := query.Get(SendResponse); val != "" {
|
||||
ok, err := strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
return servererrs.ErrConnArgsErr.WrapMsg("isMsgResp is not bool")
|
||||
}
|
||||
info.SendResponse = ok
|
||||
}
|
||||
if info.Compression == "" {
|
||||
info.Compression = header.Get(Compression)
|
||||
}
|
||||
background, err := strconv.ParseBool(query.Get(BackgroundStatus))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
info.Background = background
|
||||
return c.checkInfo(&info)
|
||||
}
|
||||
|
||||
func (c *UserConnContext) parseByJson(data string) error {
|
||||
reqInfo, err := base64.RawURLEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
return servererrs.ErrConnArgsErr.WrapMsg("data is not base64")
|
||||
}
|
||||
var info UserConnContextInfo
|
||||
if err := json.Unmarshal(reqInfo, &info); err != nil {
|
||||
return servererrs.ErrConnArgsErr.WrapMsg("data is not json", "info", err.Error())
|
||||
}
|
||||
return c.checkInfo(&info)
|
||||
}
|
||||
|
||||
func (c *UserConnContext) checkInfo(info *UserConnContextInfo) error {
|
||||
if info.OperationID == "" {
|
||||
return servererrs.ErrConnArgsErr.WrapMsg("operationID is empty")
|
||||
}
|
||||
if info.Token == "" {
|
||||
return servererrs.ErrConnArgsErr.WrapMsg("token is empty")
|
||||
}
|
||||
if info.UserID == "" {
|
||||
return servererrs.ErrConnArgsErr.WrapMsg("sendID is empty")
|
||||
}
|
||||
if _, ok := constant.PlatformID2Name[info.PlatformID]; !ok {
|
||||
return servererrs.ErrConnArgsErr.WrapMsg("platformID is invalid")
|
||||
}
|
||||
switch info.SDKType {
|
||||
case "":
|
||||
info.SDKType = GoSDK
|
||||
case GoSDK, JsSDK:
|
||||
default:
|
||||
return servererrs.ErrConnArgsErr.WrapMsg("sdkType is invalid")
|
||||
}
|
||||
c.info = info
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *UserConnContext) GetRemoteAddr() string {
|
||||
return c.RemoteAddr
|
||||
}
|
||||
|
||||
func (c *UserConnContext) Query(key string) (string, bool) {
|
||||
var value string
|
||||
if value = c.Req.URL.Query().Get(key); value == "" {
|
||||
return value, false
|
||||
}
|
||||
return value, true
|
||||
}
|
||||
|
||||
func (c *UserConnContext) GetHeader(key string) (string, bool) {
|
||||
var value string
|
||||
if value = c.Req.Header.Get(key); value == "" {
|
||||
return value, false
|
||||
}
|
||||
return value, true
|
||||
}
|
||||
|
||||
func (c *UserConnContext) SetHeader(key, value string) {
|
||||
c.RespWriter.Header().Set(key, value)
|
||||
}
|
||||
@ -119,93 +194,69 @@ func (c *UserConnContext) GetConnID() string {
|
||||
}
|
||||
|
||||
func (c *UserConnContext) GetUserID() string {
|
||||
return c.Req.URL.Query().Get(WsUserID)
|
||||
if c == nil || c.info == nil {
|
||||
return ""
|
||||
}
|
||||
return c.info.UserID
|
||||
}
|
||||
|
||||
func (c *UserConnContext) GetPlatformID() string {
|
||||
return c.Req.URL.Query().Get(PlatformID)
|
||||
func (c *UserConnContext) GetPlatformID() int {
|
||||
if c == nil || c.info == nil {
|
||||
return 0
|
||||
}
|
||||
return c.info.PlatformID
|
||||
}
|
||||
|
||||
func (c *UserConnContext) GetOperationID() string {
|
||||
return c.Req.URL.Query().Get(OperationID)
|
||||
if c == nil || c.info == nil {
|
||||
return ""
|
||||
}
|
||||
return c.info.OperationID
|
||||
}
|
||||
|
||||
func (c *UserConnContext) SetOperationID(operationID string) {
|
||||
values := c.Req.URL.Query()
|
||||
values.Set(OperationID, operationID)
|
||||
c.Req.URL.RawQuery = values.Encode()
|
||||
if c.info == nil {
|
||||
c.info = &UserConnContextInfo{}
|
||||
}
|
||||
c.info.OperationID = operationID
|
||||
}
|
||||
|
||||
func (c *UserConnContext) GetToken() string {
|
||||
return c.Req.URL.Query().Get(Token)
|
||||
if c == nil || c.info == nil {
|
||||
return ""
|
||||
}
|
||||
return c.info.Token
|
||||
}
|
||||
|
||||
func (c *UserConnContext) GetCompression() bool {
|
||||
compression, exists := c.Query(Compression)
|
||||
if exists && compression == GzipCompressionProtocol {
|
||||
return true
|
||||
} else {
|
||||
compression, exists := c.GetHeader(Compression)
|
||||
if exists && compression == GzipCompressionProtocol {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
return c != nil && c.info != nil && c.info.Compression == GzipCompressionProtocol
|
||||
}
|
||||
|
||||
func (c *UserConnContext) GetSDKType() string {
|
||||
sdkType := c.Req.URL.Query().Get(SDKType)
|
||||
if sdkType == "" {
|
||||
sdkType = GoSDK
|
||||
if c == nil || c.info == nil {
|
||||
return GoSDK
|
||||
}
|
||||
switch c.info.SDKType {
|
||||
case "", GoSDK:
|
||||
return GoSDK
|
||||
case JsSDK:
|
||||
return JsSDK
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
return sdkType
|
||||
}
|
||||
|
||||
func (c *UserConnContext) ShouldSendResp() bool {
|
||||
errResp, exists := c.Query(SendResponse)
|
||||
if exists {
|
||||
b, err := strconv.ParseBool(errResp)
|
||||
if err != nil {
|
||||
return false
|
||||
} else {
|
||||
return b
|
||||
}
|
||||
}
|
||||
return false
|
||||
return c != nil && c.info != nil && c.info.SendResponse
|
||||
}
|
||||
|
||||
func (c *UserConnContext) SetToken(token string) {
|
||||
c.Req.URL.RawQuery = Token + "=" + token
|
||||
if c.info == nil {
|
||||
c.info = &UserConnContextInfo{}
|
||||
}
|
||||
c.info.Token = token
|
||||
}
|
||||
|
||||
func (c *UserConnContext) GetBackground() bool {
|
||||
b, err := strconv.ParseBool(c.Req.URL.Query().Get(BackgroundStatus))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return b
|
||||
}
|
||||
func (c *UserConnContext) ParseEssentialArgs() error {
|
||||
_, exists := c.Query(Token)
|
||||
if !exists {
|
||||
return servererrs.ErrConnArgsErr.WrapMsg("token is empty")
|
||||
}
|
||||
_, exists = c.Query(WsUserID)
|
||||
if !exists {
|
||||
return servererrs.ErrConnArgsErr.WrapMsg("sendID is empty")
|
||||
}
|
||||
platformIDStr, exists := c.Query(PlatformID)
|
||||
if !exists {
|
||||
return servererrs.ErrConnArgsErr.WrapMsg("platformID is empty")
|
||||
}
|
||||
_, err := strconv.Atoi(platformIDStr)
|
||||
if err != nil {
|
||||
return servererrs.ErrConnArgsErr.WrapMsg("platformID is not int")
|
||||
}
|
||||
switch sdkType, _ := c.Query(SDKType); sdkType {
|
||||
case "", GoSDK, JsSDK:
|
||||
default:
|
||||
return servererrs.ErrConnArgsErr.WrapMsg("sdkType is not go or js")
|
||||
}
|
||||
return nil
|
||||
return c != nil && c.info != nil && c.info.Background
|
||||
}
|
||||
|
||||
@ -5,7 +5,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@ -30,6 +29,8 @@ import (
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
var wsSuccessResponse, _ = json.Marshal(&apiresp.ApiResponse{})
|
||||
|
||||
type LongConnServer interface {
|
||||
Run(done chan error) error
|
||||
wsHandler(w http.ResponseWriter, r *http.Request)
|
||||
@ -448,11 +449,11 @@ func (ws *WsServer) unregisterClient(client *Client) {
|
||||
// validateRespWithRequest checks if the response matches the expected userID and platformID.
|
||||
func (ws *WsServer) validateRespWithRequest(ctx *UserConnContext, resp *pbAuth.ParseTokenResp) error {
|
||||
userID := ctx.GetUserID()
|
||||
platformID := stringutil.StringToInt32(ctx.GetPlatformID())
|
||||
platformID := ctx.GetPlatformID()
|
||||
if resp.UserID != userID {
|
||||
return servererrs.ErrTokenInvalid.WrapMsg(fmt.Sprintf("token uid %s != userID %s", resp.UserID, userID))
|
||||
}
|
||||
if resp.PlatformID != platformID {
|
||||
if int(resp.PlatformID) != platformID {
|
||||
return servererrs.ErrTokenInvalid.WrapMsg(fmt.Sprintf("token platform %d != platformID %d", resp.PlatformID, platformID))
|
||||
}
|
||||
return nil
|
||||
@ -519,10 +520,16 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
log.ZWarn(connContext, "websocket upgrade failed", err)
|
||||
return
|
||||
}
|
||||
if connContext.ShouldSendResp() {
|
||||
if err := conn.WriteMessage(websocket.TextMessage, wsSuccessResponse); err != nil {
|
||||
log.ZWarn(connContext, "WriteMessage first response", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
log.ZDebug(connContext, "new conn", "token", connContext.GetToken())
|
||||
|
||||
var pingInterval time.Duration
|
||||
if connContext.GetPlatformID() == strconv.Itoa(constant.WebPlatformID) {
|
||||
if connContext.GetPlatformID() == constant.WebPlatformID {
|
||||
pingInterval = pingPeriod
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -2,6 +2,7 @@ package group
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/openimsdk/open-im-server/v3/internal/rpc/incrversion"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/authverify"
|
||||
@ -12,23 +13,24 @@ import (
|
||||
pbgroup "github.com/openimsdk/protocol/group"
|
||||
"github.com/openimsdk/protocol/sdkws"
|
||||
"github.com/openimsdk/tools/errs"
|
||||
"github.com/openimsdk/tools/log"
|
||||
"github.com/openimsdk/tools/mcontext"
|
||||
"github.com/openimsdk/tools/utils/datautil"
|
||||
)
|
||||
|
||||
const versionSyncLimit = 500
|
||||
|
||||
func (g *groupServer) GetFullGroupMemberUserIDs(ctx context.Context, req *pbgroup.GetFullGroupMemberUserIDsReq) (*pbgroup.GetFullGroupMemberUserIDsResp, error) {
|
||||
userIDs, err := g.db.FindGroupMemberUserID(ctx, req.GroupID)
|
||||
func (s *groupServer) GetFullGroupMemberUserIDs(ctx context.Context, req *pbgroup.GetFullGroupMemberUserIDsReq) (*pbgroup.GetFullGroupMemberUserIDsResp, error) {
|
||||
userIDs, err := s.db.FindGroupMemberUserID(ctx, req.GroupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) {
|
||||
if !authverify.IsAppManagerUid(ctx, s.config.Share.IMAdminUserID) {
|
||||
if !datautil.Contain(mcontext.GetOpUserID(ctx), userIDs...) {
|
||||
return nil, errs.ErrNoPermission.WrapMsg("op user not in group")
|
||||
}
|
||||
}
|
||||
vl, err := g.db.FindMaxGroupMemberVersionCache(ctx, req.GroupID)
|
||||
vl, err := s.db.FindMaxGroupMemberVersionCache(ctx, req.GroupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -146,8 +148,8 @@ func (s *groupServer) GetIncrementalGroupMember(ctx context.Context, req *pbgrou
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (g *groupServer) GetIncrementalJoinGroup(ctx context.Context, req *pbgroup.GetIncrementalJoinGroupReq) (*pbgroup.GetIncrementalJoinGroupResp, error) {
|
||||
if err := authverify.CheckAccessV3(ctx, req.UserID, g.config.Share.IMAdminUserID); err != nil {
|
||||
func (s *groupServer) GetIncrementalJoinGroup(ctx context.Context, req *pbgroup.GetIncrementalJoinGroupReq) (*pbgroup.GetIncrementalJoinGroupResp, error) {
|
||||
if err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
opt := incrversion.Option[*sdkws.GroupInfo, pbgroup.GetIncrementalJoinGroupResp]{
|
||||
@ -155,9 +157,9 @@ func (g *groupServer) GetIncrementalJoinGroup(ctx context.Context, req *pbgroup.
|
||||
VersionKey: req.UserID,
|
||||
VersionID: req.VersionID,
|
||||
VersionNumber: req.Version,
|
||||
Version: g.db.FindJoinIncrVersion,
|
||||
CacheMaxVersion: g.db.FindMaxJoinGroupVersionCache,
|
||||
Find: g.getGroupsInfo,
|
||||
Version: s.db.FindJoinIncrVersion,
|
||||
CacheMaxVersion: s.db.FindMaxJoinGroupVersionCache,
|
||||
Find: s.getGroupsInfo,
|
||||
Resp: func(version *model.VersionLog, delIDs []string, insertList, updateList []*sdkws.GroupInfo, full bool) *pbgroup.GetIncrementalJoinGroupResp {
|
||||
return &pbgroup.GetIncrementalJoinGroupResp{
|
||||
VersionID: version.ID.Hex(),
|
||||
@ -172,22 +174,29 @@ func (g *groupServer) GetIncrementalJoinGroup(ctx context.Context, req *pbgroup.
|
||||
return opt.Build()
|
||||
}
|
||||
|
||||
func (g *groupServer) BatchGetIncrementalGroupMember(ctx context.Context, req *pbgroup.BatchGetIncrementalGroupMemberReq) (*pbgroup.BatchGetIncrementalGroupMemberResp, error) {
|
||||
func (s *groupServer) BatchGetIncrementalGroupMember(ctx context.Context, req *pbgroup.BatchGetIncrementalGroupMemberReq) (*pbgroup.BatchGetIncrementalGroupMemberResp, error) {
|
||||
var num int
|
||||
resp := make(map[string]*pbgroup.GetIncrementalGroupMemberResp)
|
||||
|
||||
for _, memberReq := range req.ReqList {
|
||||
if _, ok := resp[memberReq.GroupID]; ok {
|
||||
continue
|
||||
}
|
||||
memberResp, err := g.GetIncrementalGroupMember(ctx, memberReq)
|
||||
memberResp, err := s.GetIncrementalGroupMember(ctx, memberReq)
|
||||
if err != nil {
|
||||
if errors.Is(err, servererrs.ErrDismissedAlready) {
|
||||
log.ZWarn(ctx, "Failed to get incremental group member", err, "groupID", memberReq.GroupID, "request", memberReq)
|
||||
continue
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp[memberReq.GroupID] = memberResp
|
||||
num += len(memberResp.Insert) + len(memberResp.Update) + len(memberResp.Delete)
|
||||
if num >= versionSyncLimit {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return &pbgroup.BatchGetIncrementalGroupMemberResp{RespList: resp}, nil
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user