refactor: msggateway update

This commit is contained in:
Gordon 2024-04-08 21:16:44 +08:00
parent 93b68f070c
commit 220eeca866
7 changed files with 87 additions and 94 deletions

View File

@ -20,11 +20,7 @@ import (
) )
func main() { func main() {
msgGatewayCmd := cmd.NewMsgGatewayCmd(cmd.MsgGatewayServer) if err := cmd.NewMsgGatewayCmd().Exec(); err != nil {
msgGatewayCmd.AddWsPortFlag()
msgGatewayCmd.AddPortFlag()
msgGatewayCmd.AddPrometheusPortFlag()
if err := msgGatewayCmd.Exec(); err != nil {
program.ExitWithError(err) program.ExitWithError(err)
} }
} }

View File

@ -25,8 +25,8 @@ import (
"github.com/openimsdk/tools/mcontext" "github.com/openimsdk/tools/mcontext"
) )
func CallbackUserOnline(ctx context.Context, callback *config.Callback, userID string, platformID int, isAppBackground bool, connID string) error { func CallbackUserOnline(ctx context.Context, callback *config.Webhooks, userID string, platformID int, isAppBackground bool, connID string) error {
if !callback.CallbackUserOnline.Enable { if !callback.AfterUserOnline.Enable {
return nil return nil
} }
req := cbapi.CallbackUserOnlineReq{ req := cbapi.CallbackUserOnlineReq{
@ -44,14 +44,14 @@ func CallbackUserOnline(ctx context.Context, callback *config.Callback, userID s
ConnID: connID, ConnID: connID,
} }
resp := cbapi.CommonCallbackResp{} resp := cbapi.CommonCallbackResp{}
if err := http.CallBackPostReturn(ctx, callback.CallbackUrl, &req, &resp, callback.CallbackUserOnline); err != nil { if err := http.CallBackPostReturn(ctx, callback.URL, &req, &resp, callback.AfterUserOnline); err != nil {
return err return err
} }
return nil return nil
} }
func CallbackUserOffline(ctx context.Context, callback *config.Callback, userID string, platformID int, connID string) error { func CallbackUserOffline(ctx context.Context, callback *config.Webhooks, userID string, platformID int, connID string) error {
if !callback.CallbackUserOffline.Enable { if !callback.AfterUserOffline.Enable {
return nil return nil
} }
req := &cbapi.CallbackUserOfflineReq{ req := &cbapi.CallbackUserOfflineReq{
@ -68,14 +68,14 @@ func CallbackUserOffline(ctx context.Context, callback *config.Callback, userID
ConnID: connID, ConnID: connID,
} }
resp := &cbapi.CallbackUserOfflineResp{} resp := &cbapi.CallbackUserOfflineResp{}
if err := http.CallBackPostReturn(ctx, callback.CallbackUrl, req, resp, callback.CallbackUserOffline); err != nil { if err := http.CallBackPostReturn(ctx, callback.URL, req, resp, callback.AfterUserOffline); err != nil {
return err return err
} }
return nil return nil
} }
func CallbackUserKickOff(ctx context.Context, callback *config.Callback, userID string, platformID int) error { func CallbackUserKickOff(ctx context.Context, callback *config.Webhooks, userID string, platformID int) error {
if !callback.CallbackUserKickOff.Enable { if !callback.AfterUserKickOff.Enable {
return nil return nil
} }
req := &cbapi.CallbackUserKickOffReq{ req := &cbapi.CallbackUserKickOffReq{
@ -91,7 +91,7 @@ func CallbackUserKickOff(ctx context.Context, callback *config.Callback, userID
Seq: time.Now().UnixMilli(), Seq: time.Now().UnixMilli(),
} }
resp := &cbapi.CommonCallbackResp{} resp := &cbapi.CommonCallbackResp{}
if err := http.CallBackPostReturn(ctx, callback.CallbackUrl, req, resp, callback.CallbackUserOffline); err != nil { if err := http.CallBackPostReturn(ctx, callback.URL, req, resp, callback.AfterUserOffline); err != nil {
return err return err
} }
return nil return nil

View File

@ -16,10 +16,10 @@ package msggateway
import ( import (
"context" "context"
"github.com/openimsdk/open-im-server/v3/pkg/common/cmd"
"github.com/openimsdk/tools/db/redisutil" "github.com/openimsdk/tools/db/redisutil"
"github.com/openimsdk/open-im-server/v3/pkg/authverify" "github.com/openimsdk/open-im-server/v3/pkg/authverify"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/cache" "github.com/openimsdk/open-im-server/v3/pkg/common/db/cache"
"github.com/openimsdk/open-im-server/v3/pkg/common/servererrs" "github.com/openimsdk/open-im-server/v3/pkg/common/servererrs"
"github.com/openimsdk/open-im-server/v3/pkg/common/startrpc" "github.com/openimsdk/open-im-server/v3/pkg/common/startrpc"
@ -32,8 +32,8 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
) )
func (s *Server) InitServer(ctx context.Context, config *config.GlobalConfig, disCov discovery.SvcDiscoveryRegistry, server *grpc.Server) error { func (s *Server) InitServer(ctx context.Context, config *cmd.MsgGatewayConfig, disCov discovery.SvcDiscoveryRegistry, server *grpc.Server) error {
rdb, err := redisutil.NewRedisClient(ctx, config.Redis.Build()) rdb, err := redisutil.NewRedisClient(ctx, config.RedisConfig.Build())
if err != nil { if err != nil {
return err return err
} }
@ -45,11 +45,11 @@ func (s *Server) InitServer(ctx context.Context, config *config.GlobalConfig, di
return nil return nil
} }
func (s *Server) Start(ctx context.Context, conf *config.GlobalConfig) error { func (s *Server) Start(ctx context.Context, index int, conf *cmd.MsgGatewayConfig) error {
return startrpc.Start(ctx, return startrpc.Start(ctx, &conf.ZookeeperConfig, &conf.MsgGateway.Prometheus, conf.MsgGateway.ListenIP,
s.rpcPort, conf.MsgGateway.RPC.RegisterIP,
conf.RpcRegisterName.OpenImMessageGatewayName, conf.MsgGateway.RPC.Ports, index,
s.prometheusPort, conf.Share.RpcRegisterName.MessageGateway,
conf, conf,
s.InitServer, s.InitServer,
) )
@ -59,7 +59,7 @@ type Server struct {
rpcPort int rpcPort int
prometheusPort int prometheusPort int
LongConnServer LongConnServer LongConnServer LongConnServer
config *config.GlobalConfig config *cmd.MsgGatewayConfig
pushTerminal map[int]struct{} pushTerminal map[int]struct{}
} }
@ -67,7 +67,7 @@ func (s *Server) SetLongConnServer(LongConnServer LongConnServer) {
s.LongConnServer = LongConnServer s.LongConnServer = LongConnServer
} }
func NewServer(rpcPort int, proPort int, longConnServer LongConnServer, conf *config.GlobalConfig) *Server { func NewServer(rpcPort int, proPort int, longConnServer LongConnServer, conf *cmd.MsgGatewayConfig) *Server {
s := &Server{ s := &Server{
rpcPort: rpcPort, rpcPort: rpcPort,
prometheusPort: proPort, prometheusPort: proPort,
@ -91,7 +91,7 @@ func (s *Server) GetUsersOnlineStatus(
ctx context.Context, ctx context.Context,
req *msggateway.GetUsersOnlineStatusReq, req *msggateway.GetUsersOnlineStatusReq,
) (*msggateway.GetUsersOnlineStatusResp, error) { ) (*msggateway.GetUsersOnlineStatusResp, error) {
if !authverify.IsAppManagerUid(ctx, &s.config.Manager, &s.config.IMAdmin) { if !authverify.IsAppManagerUid(ctx, &s.config.Share.IMAdmin) {
return nil, errs.ErrNoPermission.WrapMsg("only app manager") return nil, errs.ErrNoPermission.WrapMsg("only app manager")
} }
var resp msggateway.GetUsersOnlineStatusResp var resp msggateway.GetUsersOnlineStatusResp

View File

@ -16,23 +16,35 @@ package msggateway
import ( import (
"context" "context"
"github.com/openimsdk/open-im-server/v3/pkg/common/cmd"
"github.com/openimsdk/tools/utils/datautil"
"time" "time"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/openimsdk/tools/log" "github.com/openimsdk/tools/log"
) )
// Start run ws server. // Start run ws server.
func Start(ctx context.Context, conf *config.GlobalConfig, rpcPort, wsPort, prometheusPort int) error { func Start(ctx context.Context, index int, conf *cmd.MsgGatewayConfig) error {
log.CInfo(ctx, "MSG-GATEWAY server is initializing", "rpcPort", rpcPort, "wsPort", wsPort, log.CInfo(ctx, "MSG-GATEWAY server is initializing", "rpcPorts", conf.MsgGateway.RPC.Ports,
"prometheusPort", prometheusPort) "wsPort", conf.MsgGateway.LongConnSvr.Ports, "prometheusPorts", conf.MsgGateway.Prometheus.Ports)
wsPort, err := datautil.GetElemByIndex(conf.MsgGateway.LongConnSvr.Ports, index)
if err != nil {
return err
}
prometheusPort, err := datautil.GetElemByIndex(conf.MsgGateway.Prometheus.Ports, index)
if err != nil {
return err
}
rpcPort, err := datautil.GetElemByIndex(conf.MsgGateway.RPC.Ports, index)
if err != nil {
return err
}
longServer, err := NewWsServer( longServer, err := NewWsServer(
conf, conf,
WithPort(wsPort), WithPort(wsPort),
WithMaxConnNum(int64(conf.LongConnSvr.WebsocketMaxConnNum)), WithMaxConnNum(int64(conf.MsgGateway.LongConnSvr.WebsocketMaxConnNum)),
WithHandshakeTimeout(time.Duration(conf.LongConnSvr.WebsocketTimeout)*time.Second), WithHandshakeTimeout(time.Duration(conf.MsgGateway.LongConnSvr.WebsocketTimeout)*time.Second),
WithMessageMaxMsgLength(conf.LongConnSvr.WebsocketMaxMsgLen), WithMessageMaxMsgLength(conf.MsgGateway.LongConnSvr.WebsocketMaxMsgLen),
WithWriteBufferSize(conf.LongConnSvr.WebsocketWriteBufferSize),
) )
if err != nil { if err != nil {
return err return err
@ -41,7 +53,7 @@ func Start(ctx context.Context, conf *config.GlobalConfig, rpcPort, wsPort, prom
hubServer := NewServer(rpcPort, prometheusPort, longServer, conf) hubServer := NewServer(rpcPort, prometheusPort, longServer, conf)
netDone := make(chan error) netDone := make(chan error)
go func() { go func() {
err = hubServer.Start(ctx, conf) err = hubServer.Start(ctx, index, conf)
netDone <- err netDone <- err
}() }()
return hubServer.LongConnServer.Run(netDone) return hubServer.LongConnServer.Run(netDone)

View File

@ -107,8 +107,8 @@ type GrpcHandler struct {
} }
func NewGrpcHandler(validate *validator.Validate, client discovery.SvcDiscoveryRegistry, rpcRegisterName *config.RpcRegisterName) *GrpcHandler { func NewGrpcHandler(validate *validator.Validate, client discovery.SvcDiscoveryRegistry, rpcRegisterName *config.RpcRegisterName) *GrpcHandler {
msgRpcClient := rpcclient.NewMessageRpcClient(client, rpcRegisterName.OpenImMsgName) msgRpcClient := rpcclient.NewMessageRpcClient(client, rpcRegisterName.Msg)
pushRpcClient := rpcclient.NewPushRpcClient(client, rpcRegisterName.OpenImPushName) pushRpcClient := rpcclient.NewPushRpcClient(client, rpcRegisterName.Push)
return &GrpcHandler{ return &GrpcHandler{
msgRpcClient: &msgRpcClient, msgRpcClient: &msgRpcClient,
pushClient: &pushRpcClient, validate: validate, pushClient: &pushRpcClient, validate: validate,

View File

@ -18,6 +18,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/openimsdk/open-im-server/v3/pkg/common/cmd"
"net/http" "net/http"
"strconv" "strconv"
"sync" "sync"
@ -26,7 +27,6 @@ import (
"github.com/go-playground/validator/v10" "github.com/go-playground/validator/v10"
"github.com/openimsdk/open-im-server/v3/pkg/authverify" "github.com/openimsdk/open-im-server/v3/pkg/authverify"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/cache" "github.com/openimsdk/open-im-server/v3/pkg/common/db/cache"
"github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics" "github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics"
"github.com/openimsdk/open-im-server/v3/pkg/common/servererrs" "github.com/openimsdk/open-im-server/v3/pkg/common/servererrs"
@ -49,7 +49,7 @@ type LongConnServer interface {
GetUserPlatformCons(userID string, platform int) ([]*Client, bool, bool) GetUserPlatformCons(userID string, platform int) ([]*Client, bool, bool)
Validate(s any) error Validate(s any) error
SetCacheHandler(cache cache.TokenModel) SetCacheHandler(cache cache.TokenModel)
SetDiscoveryRegistry(client discovery.SvcDiscoveryRegistry, config *config.GlobalConfig) SetDiscoveryRegistry(client discovery.SvcDiscoveryRegistry, config *cmd.MsgGatewayConfig)
KickUserConn(client *Client) error KickUserConn(client *Client) error
UnRegister(c *Client) UnRegister(c *Client)
SetKickHandlerInfo(i *kickHandler) SetKickHandlerInfo(i *kickHandler)
@ -59,7 +59,7 @@ type LongConnServer interface {
} }
type WsServer struct { type WsServer struct {
globalConfig *config.GlobalConfig msgGatewayConfig *cmd.MsgGatewayConfig
port int port int
wsMaxConnNum int64 wsMaxConnNum int64
registerChan chan *Client registerChan chan *Client
@ -86,9 +86,9 @@ type kickHandler struct {
newClient *Client newClient *Client
} }
func (ws *WsServer) SetDiscoveryRegistry(disCov discovery.SvcDiscoveryRegistry, config *config.GlobalConfig) { func (ws *WsServer) SetDiscoveryRegistry(disCov discovery.SvcDiscoveryRegistry, config *cmd.MsgGatewayConfig) {
ws.MessageHandler = NewGrpcHandler(ws.validate, disCov, &config.RpcRegisterName) ws.MessageHandler = NewGrpcHandler(ws.validate, disCov, &config.Share.RpcRegisterName)
u := rpcclient.NewUserRpcClient(disCov, config.RpcRegisterName.OpenImUserName, &config.Manager, &config.IMAdmin) u := rpcclient.NewUserRpcClient(disCov, config.Share.RpcRegisterName.User, &config.Share.IMAdmin)
ws.userClient = &u ws.userClient = &u
ws.disCov = disCov ws.disCov = disCov
} }
@ -100,12 +100,12 @@ func (ws *WsServer) SetUserOnlineStatus(ctx context.Context, client *Client, sta
} }
switch status { switch status {
case constant.Online: case constant.Online:
err := CallbackUserOnline(ctx, &ws.globalConfig.Callback, client.UserID, client.PlatformID, client.IsBackground, client.ctx.GetConnID()) err := CallbackUserOnline(ctx, &ws.msgGatewayConfig.WebhooksConfig, client.UserID, client.PlatformID, client.IsBackground, client.ctx.GetConnID())
if err != nil { if err != nil {
log.ZWarn(ctx, "CallbackUserOnline err", err) log.ZWarn(ctx, "CallbackUserOnline err", err)
} }
case constant.Offline: case constant.Offline:
err := CallbackUserOffline(ctx, &ws.globalConfig.Callback, client.UserID, client.PlatformID, client.ctx.GetConnID()) err := CallbackUserOffline(ctx, &ws.msgGatewayConfig.WebhooksConfig, client.UserID, client.PlatformID, client.ctx.GetConnID())
if err != nil { if err != nil {
log.ZWarn(ctx, "CallbackUserOffline err", err) log.ZWarn(ctx, "CallbackUserOffline err", err)
} }
@ -132,14 +132,14 @@ func (ws *WsServer) GetUserPlatformCons(userID string, platform int) ([]*Client,
return ws.clients.Get(userID, platform) return ws.clients.Get(userID, platform)
} }
func NewWsServer(globalConfig *config.GlobalConfig, opts ...Option) (*WsServer, error) { func NewWsServer(msgGatewayConfig *cmd.MsgGatewayConfig, opts ...Option) (*WsServer, error) {
var config configs var config configs
for _, o := range opts { for _, o := range opts {
o(&config) o(&config)
} }
v := validator.New() v := validator.New()
return &WsServer{ return &WsServer{
globalConfig: globalConfig, msgGatewayConfig: msgGatewayConfig,
port: config.port, port: config.port,
wsMaxConnNum: config.maxConnNum, wsMaxConnNum: config.maxConnNum,
writeBufferSize: config.writeBufferSize, writeBufferSize: config.writeBufferSize,
@ -213,7 +213,7 @@ func (ws *WsServer) Run(done chan error) error {
var concurrentRequest = 3 var concurrentRequest = 3
func (ws *WsServer) sendUserOnlineInfoToOtherNode(ctx context.Context, client *Client) error { func (ws *WsServer) sendUserOnlineInfoToOtherNode(ctx context.Context, client *Client) error {
conns, err := ws.disCov.GetConns(ctx, ws.globalConfig.RpcRegisterName.OpenImMessageGatewayName) conns, err := ws.disCov.GetConns(ctx, ws.msgGatewayConfig.Share.RpcRegisterName.MessageGateway)
if err != nil { if err != nil {
return err return err
} }
@ -278,7 +278,7 @@ func (ws *WsServer) registerClient(client *Client) {
} }
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
if ws.globalConfig.Envs.Discovery == "zookeeper" { if ws.msgGatewayConfig.Share.Env == "zookeeper" {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
@ -322,7 +322,7 @@ func (ws *WsServer) KickUserConn(client *Client) error {
} }
func (ws *WsServer) multiTerminalLoginChecker(clientOK bool, oldClients []*Client, newClient *Client) { func (ws *WsServer) multiTerminalLoginChecker(clientOK bool, oldClients []*Client, newClient *Client) {
switch ws.globalConfig.MultiLoginPolicy { switch ws.msgGatewayConfig.MsgGateway.MultiLoginPolicy {
case constant.DefalutNotKick: case constant.DefalutNotKick:
case constant.PCAndOther: case constant.PCAndOther:
if constant.PlatformIDToClass(newClient.PlatformID) == constant.TerminalPC { if constant.PlatformIDToClass(newClient.PlatformID) == constant.TerminalPC {
@ -434,7 +434,7 @@ func (ws *WsServer) ParseWSArgs(r *http.Request) (args *WSArgs, err error) {
return nil, servererrs.ErrConnArgsErr.WrapMsg("platformID is not int") return nil, servererrs.ErrConnArgsErr.WrapMsg("platformID is not int")
} }
v.PlatformID = platformID v.PlatformID = platformID
if err = authverify.WsVerifyToken(v.Token, v.UserID, ws.globalConfig.Secret, platformID); err != nil { if err = authverify.WsVerifyToken(v.Token, v.UserID, ws.msgGatewayConfig.Share.Secret, platformID); err != nil {
return nil, err return nil, err
} }
if query.Get(Compression) == GzipCompressionProtocol { if query.Get(Compression) == GzipCompressionProtocol {

View File

@ -16,11 +16,10 @@ package cmd
import ( import (
"context" "context"
"log" "github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/openimsdk/open-im-server/v3/internal/msggateway" "github.com/openimsdk/open-im-server/v3/internal/msggateway"
config2 "github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/openimsdk/protocol/constant"
"github.com/openimsdk/tools/system/program" "github.com/openimsdk/tools/system/program"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -28,53 +27,39 @@ import (
type MsgGatewayCmd struct { type MsgGatewayCmd struct {
*RootCmd *RootCmd
ctx context.Context ctx context.Context
configMap map[string]StructEnvPrefix
msgGatewayConfig MsgGatewayConfig
}
type MsgGatewayConfig struct {
MsgGateway config.MsgGateway
RedisConfig config.Redis
ZookeeperConfig config.ZooKeeper
Share config.Share
WebhooksConfig config.Webhooks
} }
func NewMsgGatewayCmd(name string) *MsgGatewayCmd { func NewMsgGatewayCmd() *MsgGatewayCmd {
ret := &MsgGatewayCmd{RootCmd: NewRootCmd(program.GetProcessName(), name)} var msgGatewayConfig MsgGatewayConfig
ret.ctx = context.WithValue(context.Background(), "version", config2.Version) ret := &MsgGatewayCmd{msgGatewayConfig: msgGatewayConfig}
ret.addRunE() ret.configMap = map[string]StructEnvPrefix{
ret.SetRootCmdPt(ret) OpenIMAPICfgFileName: {EnvPrefix: apiEnvPrefix, ConfigStruct: &msgGatewayConfig.MsgGateway},
RedisConfigFileName: {EnvPrefix: redisEnvPrefix, ConfigStruct: &msgGatewayConfig.RedisConfig},
ZookeeperConfigFileName: {EnvPrefix: zoopkeeperEnvPrefix, ConfigStruct: &msgGatewayConfig.ZookeeperConfig},
ShareFileName: {EnvPrefix: shareEnvPrefix, ConfigStruct: &msgGatewayConfig.Share},
WebhooksConfigFileName: {EnvPrefix: webhooksEnvPrefix, ConfigStruct: &msgGatewayConfig.WebhooksConfig},
}
ret.RootCmd = NewRootCmd(program.GetProcessName(), WithConfigMap(ret.configMap))
ret.ctx = context.WithValue(context.Background(), "version", config.Version)
ret.Command.PreRunE = func(cmd *cobra.Command, args []string) error {
return ret.preRunE()
}
return ret return ret
} }
func (m *MsgGatewayCmd) AddWsPortFlag() {
m.Command.Flags().IntP(constant.FlagWsPort, "w", 0, "ws server listen port")
}
func (m *MsgGatewayCmd) getWsPortFlag(cmd *cobra.Command) int {
port, err := cmd.Flags().GetInt(constant.FlagWsPort)
if err != nil {
log.Println("Error getting ws port flag:", err)
}
if port == 0 {
port = m.PortFromConfig(constant.FlagWsPort)
}
return port
}
func (m *MsgGatewayCmd) addRunE() {
m.Command.RunE = func(cmd *cobra.Command, args []string) error {
return msggateway.Start(m.ctx, m.config, m.getPortFlag(cmd), m.getWsPortFlag(cmd), m.getPrometheusPortFlag(cmd))
}
}
func (m *MsgGatewayCmd) Exec() error { func (m *MsgGatewayCmd) Exec() error {
return m.Execute() return m.Execute()
} }
func (m *MsgGatewayCmd) GetPortFromConfig(portType string) int { func (m *MsgGatewayCmd) preRunE() error {
switch portType { return msggateway.Start(m.ctx, m.Index(), &m.msgGatewayConfig)
case constant.FlagWsPort:
return m.config.LongConnSvr.OpenImWsPort[0]
case constant.FlagPort:
return m.config.LongConnSvr.OpenImMessageGatewayPort[0]
case constant.FlagPrometheusPort:
return m.config.Prometheus.MessageGatewayPrometheusPort[0]
default:
return 0
}
} }