diff --git a/cmd/openim-msggateway/main.go b/cmd/openim-msggateway/main.go index ded134916..6e3eda6bf 100644 --- a/cmd/openim-msggateway/main.go +++ b/cmd/openim-msggateway/main.go @@ -20,11 +20,7 @@ import ( ) func main() { - msgGatewayCmd := cmd.NewMsgGatewayCmd(cmd.MsgGatewayServer) - msgGatewayCmd.AddWsPortFlag() - msgGatewayCmd.AddPortFlag() - msgGatewayCmd.AddPrometheusPortFlag() - if err := msgGatewayCmd.Exec(); err != nil { + if err := cmd.NewMsgGatewayCmd().Exec(); err != nil { program.ExitWithError(err) } } diff --git a/internal/msggateway/callback.go b/internal/msggateway/callback.go index ceb8af1db..46a09c3de 100644 --- a/internal/msggateway/callback.go +++ b/internal/msggateway/callback.go @@ -25,8 +25,8 @@ import ( "github.com/openimsdk/tools/mcontext" ) -func CallbackUserOnline(ctx context.Context, callback *config.Callback, userID string, platformID int, isAppBackground bool, connID string) error { - if !callback.CallbackUserOnline.Enable { +func CallbackUserOnline(ctx context.Context, callback *config.Webhooks, userID string, platformID int, isAppBackground bool, connID string) error { + if !callback.AfterUserOnline.Enable { return nil } req := cbapi.CallbackUserOnlineReq{ @@ -44,14 +44,14 @@ func CallbackUserOnline(ctx context.Context, callback *config.Callback, userID s ConnID: connID, } resp := cbapi.CommonCallbackResp{} - if err := http.CallBackPostReturn(ctx, callback.CallbackUrl, &req, &resp, callback.CallbackUserOnline); err != nil { + if err := http.CallBackPostReturn(ctx, callback.URL, &req, &resp, callback.AfterUserOnline); err != nil { return err } return nil } -func CallbackUserOffline(ctx context.Context, callback *config.Callback, userID string, platformID int, connID string) error { - if !callback.CallbackUserOffline.Enable { +func CallbackUserOffline(ctx context.Context, callback *config.Webhooks, userID string, platformID int, connID string) error { + if !callback.AfterUserOffline.Enable { return nil } req := &cbapi.CallbackUserOfflineReq{ @@ -68,14 +68,14 @@ func CallbackUserOffline(ctx context.Context, callback *config.Callback, userID ConnID: connID, } resp := &cbapi.CallbackUserOfflineResp{} - if err := http.CallBackPostReturn(ctx, callback.CallbackUrl, req, resp, callback.CallbackUserOffline); err != nil { + if err := http.CallBackPostReturn(ctx, callback.URL, req, resp, callback.AfterUserOffline); err != nil { return err } return nil } -func CallbackUserKickOff(ctx context.Context, callback *config.Callback, userID string, platformID int) error { - if !callback.CallbackUserKickOff.Enable { +func CallbackUserKickOff(ctx context.Context, callback *config.Webhooks, userID string, platformID int) error { + if !callback.AfterUserKickOff.Enable { return nil } req := &cbapi.CallbackUserKickOffReq{ @@ -91,7 +91,7 @@ func CallbackUserKickOff(ctx context.Context, callback *config.Callback, userID Seq: time.Now().UnixMilli(), } resp := &cbapi.CommonCallbackResp{} - if err := http.CallBackPostReturn(ctx, callback.CallbackUrl, req, resp, callback.CallbackUserOffline); err != nil { + if err := http.CallBackPostReturn(ctx, callback.URL, req, resp, callback.AfterUserOffline); err != nil { return err } return nil diff --git a/internal/msggateway/hub_server.go b/internal/msggateway/hub_server.go index a502e455a..1e00931e7 100644 --- a/internal/msggateway/hub_server.go +++ b/internal/msggateway/hub_server.go @@ -16,10 +16,10 @@ package msggateway import ( "context" + "github.com/openimsdk/open-im-server/v3/pkg/common/cmd" "github.com/openimsdk/tools/db/redisutil" "github.com/openimsdk/open-im-server/v3/pkg/authverify" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/db/cache" "github.com/openimsdk/open-im-server/v3/pkg/common/servererrs" "github.com/openimsdk/open-im-server/v3/pkg/common/startrpc" @@ -32,8 +32,8 @@ import ( "google.golang.org/grpc" ) -func (s *Server) InitServer(ctx context.Context, config *config.GlobalConfig, disCov discovery.SvcDiscoveryRegistry, server *grpc.Server) error { - rdb, err := redisutil.NewRedisClient(ctx, config.Redis.Build()) +func (s *Server) InitServer(ctx context.Context, config *cmd.MsgGatewayConfig, disCov discovery.SvcDiscoveryRegistry, server *grpc.Server) error { + rdb, err := redisutil.NewRedisClient(ctx, config.RedisConfig.Build()) if err != nil { return err } @@ -45,11 +45,11 @@ func (s *Server) InitServer(ctx context.Context, config *config.GlobalConfig, di return nil } -func (s *Server) Start(ctx context.Context, conf *config.GlobalConfig) error { - return startrpc.Start(ctx, - s.rpcPort, - conf.RpcRegisterName.OpenImMessageGatewayName, - s.prometheusPort, +func (s *Server) Start(ctx context.Context, index int, conf *cmd.MsgGatewayConfig) error { + return startrpc.Start(ctx, &conf.ZookeeperConfig, &conf.MsgGateway.Prometheus, conf.MsgGateway.ListenIP, + conf.MsgGateway.RPC.RegisterIP, + conf.MsgGateway.RPC.Ports, index, + conf.Share.RpcRegisterName.MessageGateway, conf, s.InitServer, ) @@ -59,7 +59,7 @@ type Server struct { rpcPort int prometheusPort int LongConnServer LongConnServer - config *config.GlobalConfig + config *cmd.MsgGatewayConfig pushTerminal map[int]struct{} } @@ -67,7 +67,7 @@ func (s *Server) SetLongConnServer(LongConnServer LongConnServer) { s.LongConnServer = LongConnServer } -func NewServer(rpcPort int, proPort int, longConnServer LongConnServer, conf *config.GlobalConfig) *Server { +func NewServer(rpcPort int, proPort int, longConnServer LongConnServer, conf *cmd.MsgGatewayConfig) *Server { s := &Server{ rpcPort: rpcPort, prometheusPort: proPort, @@ -91,7 +91,7 @@ func (s *Server) GetUsersOnlineStatus( ctx context.Context, req *msggateway.GetUsersOnlineStatusReq, ) (*msggateway.GetUsersOnlineStatusResp, error) { - if !authverify.IsAppManagerUid(ctx, &s.config.Manager, &s.config.IMAdmin) { + if !authverify.IsAppManagerUid(ctx, &s.config.Share.IMAdmin) { return nil, errs.ErrNoPermission.WrapMsg("only app manager") } var resp msggateway.GetUsersOnlineStatusResp diff --git a/internal/msggateway/init.go b/internal/msggateway/init.go index 4fcb4b201..6377d258c 100644 --- a/internal/msggateway/init.go +++ b/internal/msggateway/init.go @@ -16,23 +16,35 @@ package msggateway import ( "context" + "github.com/openimsdk/open-im-server/v3/pkg/common/cmd" + "github.com/openimsdk/tools/utils/datautil" "time" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/tools/log" ) // Start run ws server. -func Start(ctx context.Context, conf *config.GlobalConfig, rpcPort, wsPort, prometheusPort int) error { - log.CInfo(ctx, "MSG-GATEWAY server is initializing", "rpcPort", rpcPort, "wsPort", wsPort, - "prometheusPort", prometheusPort) +func Start(ctx context.Context, index int, conf *cmd.MsgGatewayConfig) error { + log.CInfo(ctx, "MSG-GATEWAY server is initializing", "rpcPorts", conf.MsgGateway.RPC.Ports, + "wsPort", conf.MsgGateway.LongConnSvr.Ports, "prometheusPorts", conf.MsgGateway.Prometheus.Ports) + wsPort, err := datautil.GetElemByIndex(conf.MsgGateway.LongConnSvr.Ports, index) + if err != nil { + return err + } + prometheusPort, err := datautil.GetElemByIndex(conf.MsgGateway.Prometheus.Ports, index) + if err != nil { + return err + } + rpcPort, err := datautil.GetElemByIndex(conf.MsgGateway.RPC.Ports, index) + if err != nil { + return err + } longServer, err := NewWsServer( conf, WithPort(wsPort), - WithMaxConnNum(int64(conf.LongConnSvr.WebsocketMaxConnNum)), - WithHandshakeTimeout(time.Duration(conf.LongConnSvr.WebsocketTimeout)*time.Second), - WithMessageMaxMsgLength(conf.LongConnSvr.WebsocketMaxMsgLen), - WithWriteBufferSize(conf.LongConnSvr.WebsocketWriteBufferSize), + WithMaxConnNum(int64(conf.MsgGateway.LongConnSvr.WebsocketMaxConnNum)), + WithHandshakeTimeout(time.Duration(conf.MsgGateway.LongConnSvr.WebsocketTimeout)*time.Second), + WithMessageMaxMsgLength(conf.MsgGateway.LongConnSvr.WebsocketMaxMsgLen), ) if err != nil { return err @@ -41,7 +53,7 @@ func Start(ctx context.Context, conf *config.GlobalConfig, rpcPort, wsPort, prom hubServer := NewServer(rpcPort, prometheusPort, longServer, conf) netDone := make(chan error) go func() { - err = hubServer.Start(ctx, conf) + err = hubServer.Start(ctx, index, conf) netDone <- err }() return hubServer.LongConnServer.Run(netDone) diff --git a/internal/msggateway/message_handler.go b/internal/msggateway/message_handler.go index d568aeb3e..a1037a259 100644 --- a/internal/msggateway/message_handler.go +++ b/internal/msggateway/message_handler.go @@ -107,8 +107,8 @@ type GrpcHandler struct { } func NewGrpcHandler(validate *validator.Validate, client discovery.SvcDiscoveryRegistry, rpcRegisterName *config.RpcRegisterName) *GrpcHandler { - msgRpcClient := rpcclient.NewMessageRpcClient(client, rpcRegisterName.OpenImMsgName) - pushRpcClient := rpcclient.NewPushRpcClient(client, rpcRegisterName.OpenImPushName) + msgRpcClient := rpcclient.NewMessageRpcClient(client, rpcRegisterName.Msg) + pushRpcClient := rpcclient.NewPushRpcClient(client, rpcRegisterName.Push) return &GrpcHandler{ msgRpcClient: &msgRpcClient, pushClient: &pushRpcClient, validate: validate, diff --git a/internal/msggateway/n_ws_server.go b/internal/msggateway/n_ws_server.go index 526ea628c..12b0c7764 100644 --- a/internal/msggateway/n_ws_server.go +++ b/internal/msggateway/n_ws_server.go @@ -18,6 +18,7 @@ import ( "context" "encoding/json" "fmt" + "github.com/openimsdk/open-im-server/v3/pkg/common/cmd" "net/http" "strconv" "sync" @@ -26,7 +27,6 @@ import ( "github.com/go-playground/validator/v10" "github.com/openimsdk/open-im-server/v3/pkg/authverify" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/db/cache" "github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics" "github.com/openimsdk/open-im-server/v3/pkg/common/servererrs" @@ -49,7 +49,7 @@ type LongConnServer interface { GetUserPlatformCons(userID string, platform int) ([]*Client, bool, bool) Validate(s any) error SetCacheHandler(cache cache.TokenModel) - SetDiscoveryRegistry(client discovery.SvcDiscoveryRegistry, config *config.GlobalConfig) + SetDiscoveryRegistry(client discovery.SvcDiscoveryRegistry, config *cmd.MsgGatewayConfig) KickUserConn(client *Client) error UnRegister(c *Client) SetKickHandlerInfo(i *kickHandler) @@ -59,7 +59,7 @@ type LongConnServer interface { } type WsServer struct { - globalConfig *config.GlobalConfig + msgGatewayConfig *cmd.MsgGatewayConfig port int wsMaxConnNum int64 registerChan chan *Client @@ -86,9 +86,9 @@ type kickHandler struct { newClient *Client } -func (ws *WsServer) SetDiscoveryRegistry(disCov discovery.SvcDiscoveryRegistry, config *config.GlobalConfig) { - ws.MessageHandler = NewGrpcHandler(ws.validate, disCov, &config.RpcRegisterName) - u := rpcclient.NewUserRpcClient(disCov, config.RpcRegisterName.OpenImUserName, &config.Manager, &config.IMAdmin) +func (ws *WsServer) SetDiscoveryRegistry(disCov discovery.SvcDiscoveryRegistry, config *cmd.MsgGatewayConfig) { + ws.MessageHandler = NewGrpcHandler(ws.validate, disCov, &config.Share.RpcRegisterName) + u := rpcclient.NewUserRpcClient(disCov, config.Share.RpcRegisterName.User, &config.Share.IMAdmin) ws.userClient = &u ws.disCov = disCov } @@ -100,12 +100,12 @@ func (ws *WsServer) SetUserOnlineStatus(ctx context.Context, client *Client, sta } switch status { case constant.Online: - err := CallbackUserOnline(ctx, &ws.globalConfig.Callback, client.UserID, client.PlatformID, client.IsBackground, client.ctx.GetConnID()) + err := CallbackUserOnline(ctx, &ws.msgGatewayConfig.WebhooksConfig, client.UserID, client.PlatformID, client.IsBackground, client.ctx.GetConnID()) if err != nil { log.ZWarn(ctx, "CallbackUserOnline err", err) } case constant.Offline: - err := CallbackUserOffline(ctx, &ws.globalConfig.Callback, client.UserID, client.PlatformID, client.ctx.GetConnID()) + err := CallbackUserOffline(ctx, &ws.msgGatewayConfig.WebhooksConfig, client.UserID, client.PlatformID, client.ctx.GetConnID()) if err != nil { log.ZWarn(ctx, "CallbackUserOffline err", err) } @@ -132,14 +132,14 @@ func (ws *WsServer) GetUserPlatformCons(userID string, platform int) ([]*Client, return ws.clients.Get(userID, platform) } -func NewWsServer(globalConfig *config.GlobalConfig, opts ...Option) (*WsServer, error) { +func NewWsServer(msgGatewayConfig *cmd.MsgGatewayConfig, opts ...Option) (*WsServer, error) { var config configs for _, o := range opts { o(&config) } v := validator.New() return &WsServer{ - globalConfig: globalConfig, + msgGatewayConfig: msgGatewayConfig, port: config.port, wsMaxConnNum: config.maxConnNum, writeBufferSize: config.writeBufferSize, @@ -213,7 +213,7 @@ func (ws *WsServer) Run(done chan error) error { var concurrentRequest = 3 func (ws *WsServer) sendUserOnlineInfoToOtherNode(ctx context.Context, client *Client) error { - conns, err := ws.disCov.GetConns(ctx, ws.globalConfig.RpcRegisterName.OpenImMessageGatewayName) + conns, err := ws.disCov.GetConns(ctx, ws.msgGatewayConfig.Share.RpcRegisterName.MessageGateway) if err != nil { return err } @@ -278,7 +278,7 @@ func (ws *WsServer) registerClient(client *Client) { } wg := sync.WaitGroup{} - if ws.globalConfig.Envs.Discovery == "zookeeper" { + if ws.msgGatewayConfig.Share.Env == "zookeeper" { wg.Add(1) go func() { defer wg.Done() @@ -322,7 +322,7 @@ func (ws *WsServer) KickUserConn(client *Client) error { } func (ws *WsServer) multiTerminalLoginChecker(clientOK bool, oldClients []*Client, newClient *Client) { - switch ws.globalConfig.MultiLoginPolicy { + switch ws.msgGatewayConfig.MsgGateway.MultiLoginPolicy { case constant.DefalutNotKick: case constant.PCAndOther: if constant.PlatformIDToClass(newClient.PlatformID) == constant.TerminalPC { @@ -434,7 +434,7 @@ func (ws *WsServer) ParseWSArgs(r *http.Request) (args *WSArgs, err error) { return nil, servererrs.ErrConnArgsErr.WrapMsg("platformID is not int") } v.PlatformID = platformID - if err = authverify.WsVerifyToken(v.Token, v.UserID, ws.globalConfig.Secret, platformID); err != nil { + if err = authverify.WsVerifyToken(v.Token, v.UserID, ws.msgGatewayConfig.Share.Secret, platformID); err != nil { return nil, err } if query.Get(Compression) == GzipCompressionProtocol { diff --git a/pkg/common/cmd/msg_gateway.go b/pkg/common/cmd/msg_gateway.go index 58410c1b3..cf2b57b2e 100644 --- a/pkg/common/cmd/msg_gateway.go +++ b/pkg/common/cmd/msg_gateway.go @@ -16,65 +16,50 @@ package cmd import ( "context" - "log" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/internal/msggateway" - config2 "github.com/openimsdk/open-im-server/v3/pkg/common/config" - "github.com/openimsdk/protocol/constant" + "github.com/openimsdk/tools/system/program" "github.com/spf13/cobra" ) type MsgGatewayCmd struct { *RootCmd - ctx context.Context + ctx context.Context + configMap map[string]StructEnvPrefix + msgGatewayConfig MsgGatewayConfig +} +type MsgGatewayConfig struct { + MsgGateway config.MsgGateway + RedisConfig config.Redis + ZookeeperConfig config.ZooKeeper + Share config.Share + WebhooksConfig config.Webhooks } -func NewMsgGatewayCmd(name string) *MsgGatewayCmd { - ret := &MsgGatewayCmd{RootCmd: NewRootCmd(program.GetProcessName(), name)} - ret.ctx = context.WithValue(context.Background(), "version", config2.Version) - ret.addRunE() - ret.SetRootCmdPt(ret) +func NewMsgGatewayCmd() *MsgGatewayCmd { + var msgGatewayConfig MsgGatewayConfig + ret := &MsgGatewayCmd{msgGatewayConfig: msgGatewayConfig} + ret.configMap = map[string]StructEnvPrefix{ + OpenIMAPICfgFileName: {EnvPrefix: apiEnvPrefix, ConfigStruct: &msgGatewayConfig.MsgGateway}, + RedisConfigFileName: {EnvPrefix: redisEnvPrefix, ConfigStruct: &msgGatewayConfig.RedisConfig}, + ZookeeperConfigFileName: {EnvPrefix: zoopkeeperEnvPrefix, ConfigStruct: &msgGatewayConfig.ZookeeperConfig}, + ShareFileName: {EnvPrefix: shareEnvPrefix, ConfigStruct: &msgGatewayConfig.Share}, + WebhooksConfigFileName: {EnvPrefix: webhooksEnvPrefix, ConfigStruct: &msgGatewayConfig.WebhooksConfig}, + } + ret.RootCmd = NewRootCmd(program.GetProcessName(), WithConfigMap(ret.configMap)) + ret.ctx = context.WithValue(context.Background(), "version", config.Version) + ret.Command.PreRunE = func(cmd *cobra.Command, args []string) error { + return ret.preRunE() + } return ret } -func (m *MsgGatewayCmd) AddWsPortFlag() { - m.Command.Flags().IntP(constant.FlagWsPort, "w", 0, "ws server listen port") -} - -func (m *MsgGatewayCmd) getWsPortFlag(cmd *cobra.Command) int { - port, err := cmd.Flags().GetInt(constant.FlagWsPort) - if err != nil { - log.Println("Error getting ws port flag:", err) - } - if port == 0 { - port = m.PortFromConfig(constant.FlagWsPort) - } - return port -} - -func (m *MsgGatewayCmd) addRunE() { - m.Command.RunE = func(cmd *cobra.Command, args []string) error { - return msggateway.Start(m.ctx, m.config, m.getPortFlag(cmd), m.getWsPortFlag(cmd), m.getPrometheusPortFlag(cmd)) - } -} - func (m *MsgGatewayCmd) Exec() error { return m.Execute() } -func (m *MsgGatewayCmd) GetPortFromConfig(portType string) int { - switch portType { - case constant.FlagWsPort: - return m.config.LongConnSvr.OpenImWsPort[0] - - case constant.FlagPort: - return m.config.LongConnSvr.OpenImMessageGatewayPort[0] - - case constant.FlagPrometheusPort: - return m.config.Prometheus.MessageGatewayPrometheusPort[0] - - default: - return 0 - } +func (m *MsgGatewayCmd) preRunE() error { + return msggateway.Start(m.ctx, m.Index(), &m.msgGatewayConfig) }