diff --git a/.golangci.yml b/.golangci.yml index 3dad0af30..050025b6e 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -510,7 +510,7 @@ linters-settings: nestif: # minimal complexity of if statements to report, 5 by default - min-complexity: 4 + min-complexity: 6 nilnil: # By default, nilnil checks all returned types below. diff --git a/go.work b/go.work index 1c819212c..33faf5195 100644 --- a/go.work +++ b/go.work @@ -1,16 +1,18 @@ go 1.19 + use ( . ./test/typecheck ./tools/changelog + ./tools/component + ./tools/data-conversion + ./tools/imctl //./tools/imctl ./tools/infra ./tools/ncpu ./tools/openim-web + ./tools/url2im ./tools/versionchecker ./tools/yamlfmt - ./tools/component - ./tools/url2im - ./tools/data-conversion ) diff --git a/internal/msgtransfer/init.go b/internal/msgtransfer/init.go index 4487826ee..7babd9a07 100644 --- a/internal/msgtransfer/init.go +++ b/internal/msgtransfer/init.go @@ -16,10 +16,11 @@ package msgtransfer import ( "fmt" - "github.com/openimsdk/open-im-server/v3/pkg/common/discovery_register" + + "sync" + "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" - "sync" "github.com/OpenIMSDK/tools/mw" @@ -29,6 +30,7 @@ import ( "github.com/openimsdk/open-im-server/v3/pkg/common/db/relation" relationtb "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation" "github.com/openimsdk/open-im-server/v3/pkg/common/db/unrelation" + "github.com/openimsdk/open-im-server/v3/pkg/common/discovery_register" "github.com/openimsdk/open-im-server/v3/pkg/common/prome" "github.com/openimsdk/open-im-server/v3/pkg/rpcclient" ) @@ -63,8 +65,7 @@ func StartTransfer(prometheusPort int) error { /* client, err := openkeeper.NewClient(config.Config.Zookeeper.ZkAddr, config.Config.Zookeeper.Schema, openkeeper.WithFreq(time.Hour), openkeeper.WithRoundRobin(), openkeeper.WithUserNameAndPassword(config.Config.Zookeeper.Username, - config.Config.Zookeeper.Password), openkeeper.WithTimeout(10), openkeeper.WithLogger(log.NewZkLogger()))*/ - if err != nil { + config.Config.Zookeeper.Password), openkeeper.WithTimeout(10), openkeeper.WithLogger(log.NewZkLogger()))*/if err != nil { return err } if err := client.CreateRpcRootNodes(config.Config.GetServiceNames()); err != nil { diff --git a/internal/rpc/msg/as_read.go b/internal/rpc/msg/as_read.go index c31cd02dd..efa322774 100644 --- a/internal/rpc/msg/as_read.go +++ b/internal/rpc/msg/as_read.go @@ -165,7 +165,6 @@ func (m *msgServer) MarkConversationAsRead( m.conversationAndGetRecvID(conversation, req.UserID), seqs, hasReadSeq); err != nil { return nil, err } - } else if conversation.ConversationType == constant.SuperGroupChatType { if req.HasReadSeq > hasReadSeq { err = m.MsgDatabase.SetHasReadSeq(ctx, req.UserID, req.ConversationID, req.HasReadSeq) @@ -178,7 +177,6 @@ func (m *msgServer) MarkConversationAsRead( req.UserID, seqs, hasReadSeq); err != nil { return nil, err } - } return &msg.MarkConversationAsReadResp{}, nil diff --git a/internal/rpc/user/user.go b/internal/rpc/user/user.go index f2ceb3beb..83573eeef 100644 --- a/internal/rpc/user/user.go +++ b/internal/rpc/user/user.go @@ -290,7 +290,8 @@ func (s *userServer) SubscribeOrCancelUsersStatus(ctx context.Context, req *pbus // GetUserStatus Get the online status of the user. func (s *userServer) GetUserStatus(ctx context.Context, req *pbuser.GetUserStatusReq) (resp *pbuser.GetUserStatusResp, - err error) { + err error, +) { onlineStatusList, err := s.UserDatabase.GetUserStatus(ctx, req.UserIDs) if err != nil { return nil, err @@ -300,7 +301,8 @@ func (s *userServer) GetUserStatus(ctx context.Context, req *pbuser.GetUserStatu // SetUserStatus Synchronize user's online status. func (s *userServer) SetUserStatus(ctx context.Context, req *pbuser.SetUserStatusReq) (resp *pbuser.SetUserStatusResp, - err error) { + err error, +) { err = s.UserDatabase.SetUserStatus(ctx, req.UserID, req.Status, req.PlatformID) if err != nil { return nil, err @@ -324,7 +326,8 @@ func (s *userServer) SetUserStatus(ctx context.Context, req *pbuser.SetUserStatu // GetSubscribeUsersStatus Get the online status of subscribers. func (s *userServer) GetSubscribeUsersStatus(ctx context.Context, - req *pbuser.GetSubscribeUsersStatusReq) (*pbuser.GetSubscribeUsersStatusResp, error) { + req *pbuser.GetSubscribeUsersStatusReq, +) (*pbuser.GetSubscribeUsersStatusResp, error) { userList, err := s.UserDatabase.GetAllSubscribeList(ctx, req.UserID) if err != nil { return nil, err diff --git a/internal/tools/msg.go b/internal/tools/msg.go index 5397689b2..94ce2dec0 100644 --- a/internal/tools/msg.go +++ b/internal/tools/msg.go @@ -17,11 +17,11 @@ package tools import ( "context" "fmt" - "github.com/openimsdk/open-im-server/v3/pkg/common/discovery_register" + "math" + "github.com/redis/go-redis/v9" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" - "math" "github.com/OpenIMSDK/tools/errs" "github.com/OpenIMSDK/tools/log" @@ -35,6 +35,7 @@ import ( "github.com/openimsdk/open-im-server/v3/pkg/common/db/controller" "github.com/openimsdk/open-im-server/v3/pkg/common/db/relation" "github.com/openimsdk/open-im-server/v3/pkg/common/db/unrelation" + "github.com/openimsdk/open-im-server/v3/pkg/common/discovery_register" "github.com/openimsdk/open-im-server/v3/pkg/rpcclient" "github.com/openimsdk/open-im-server/v3/pkg/rpcclient/notification" ) @@ -76,8 +77,7 @@ func InitMsgTool() (*MsgTool, error) { /* discov, err := zookeeper.NewClient(config.Config.Zookeeper.ZkAddr, config.Config.Zookeeper.Schema, zookeeper.WithFreq(time.Hour), zookeeper.WithRoundRobin(), zookeeper.WithUserNameAndPassword(config.Config.Zookeeper.Username, - config.Config.Zookeeper.Password), zookeeper.WithTimeout(10), zookeeper.WithLogger(log.NewZkLogger()))*/ - if err != nil { + config.Config.Zookeeper.Password), zookeeper.WithTimeout(10), zookeeper.WithLogger(log.NewZkLogger()))*/if err != nil { return nil, err } discov.AddOption(mw.GrpcClient(), grpc.WithTransportCredentials(insecure.NewCredentials())) diff --git a/pkg/authverify/token.go b/pkg/authverify/token.go index a8e577fde..724d4934f 100644 --- a/pkg/authverify/token.go +++ b/pkg/authverify/token.go @@ -41,6 +41,7 @@ func CheckAccessV3(ctx context.Context, ownerUserID string) (err error) { if opUserID == ownerUserID { return nil } + return errs.ErrNoPermission.Wrap(utils.GetSelfFuncName()) } @@ -52,6 +53,7 @@ func CheckAdmin(ctx context.Context) error { if utils.IsContain(mcontext.GetOpUserID(ctx), config.Config.Manager.UserID) { return nil } + return errs.ErrNoPermission.Wrap(fmt.Sprintf("user %s is not admin userID", mcontext.GetOpUserID(ctx))) } @@ -74,5 +76,6 @@ func WsVerifyToken(token, userID string, platformID int) error { if claim.PlatformID != platformID { return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token platform %d != %d", claim.PlatformID, platformID)) } + return nil } diff --git a/pkg/callbackstruct/common.go b/pkg/callbackstruct/common.go index ef84d52b9..8b320a04f 100644 --- a/pkg/callbackstruct/common.go +++ b/pkg/callbackstruct/common.go @@ -61,6 +61,7 @@ func (c CommonCallbackResp) Parse() error { if c.ActionCode != errs.NoError || c.ErrCode != errs.NoError { return errs.NewCodeError(int(c.ErrCode), c.ErrMsg).WithDetail(c.ErrDlt) } + return nil } diff --git a/pkg/common/cmd/api.go b/pkg/common/cmd/api.go index 7ce872fac..98a200f14 100644 --- a/pkg/common/cmd/api.go +++ b/pkg/common/cmd/api.go @@ -16,9 +16,11 @@ package cmd import ( "fmt" + "github.com/OpenIMSDK/protocol/constant" - config2 "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/spf13/cobra" + + config2 "github.com/openimsdk/open-im-server/v3/pkg/common/config" ) type ApiCmd struct { @@ -42,8 +44,7 @@ func (a *ApiCmd) GetPortFromConfig(portType string) int { fmt.Println("GetPortFromConfig:", portType) if portType == constant.FlagPort { return config2.Config.Api.OpenImApiPort[0] - } else { - - return 0 } + + return 0 } diff --git a/pkg/common/cmd/cron_task.go b/pkg/common/cmd/cron_task.go index 1b0e796ac..c1e18e5cf 100644 --- a/pkg/common/cmd/cron_task.go +++ b/pkg/common/cmd/cron_task.go @@ -23,6 +23,7 @@ type CronTaskCmd struct { func NewCronTaskCmd() *CronTaskCmd { ret := &CronTaskCmd{NewRootCmd("cronTask", WithCronTaskLogName())} ret.SetRootCmdPt(ret) + return ret } @@ -34,5 +35,6 @@ func (c *CronTaskCmd) addRunE(f func() error) { func (c *CronTaskCmd) Exec(f func() error) error { c.addRunE(f) + return c.Execute() } diff --git a/pkg/common/cmd/msg_gateway.go b/pkg/common/cmd/msg_gateway.go index c96bbd7af..c9ecfa42b 100644 --- a/pkg/common/cmd/msg_gateway.go +++ b/pkg/common/cmd/msg_gateway.go @@ -31,6 +31,7 @@ type MsgGatewayCmd struct { func NewMsgGatewayCmd() *MsgGatewayCmd { ret := &MsgGatewayCmd{NewRootCmd("msgGateway")} ret.SetRootCmdPt(ret) + return ret } @@ -43,6 +44,7 @@ func (m *MsgGatewayCmd) getWsPortFlag(cmd *cobra.Command) int { if port == 0 { port = m.PortFromConfig(constant.FlagWsPort) } + return port } @@ -54,8 +56,10 @@ func (m *MsgGatewayCmd) addRunE() { func (m *MsgGatewayCmd) Exec() error { m.addRunE() + return m.Execute() } + func (m *MsgGatewayCmd) GetPortFromConfig(portType string) int { if portType == constant.FlagWsPort { return config2.Config.LongConnSvr.OpenImWsPort[0] diff --git a/pkg/common/cmd/msg_transfer.go b/pkg/common/cmd/msg_transfer.go index 20349ebbb..1e0449b4c 100644 --- a/pkg/common/cmd/msg_transfer.go +++ b/pkg/common/cmd/msg_transfer.go @@ -27,6 +27,7 @@ type MsgTransferCmd struct { func NewMsgTransferCmd() *MsgTransferCmd { ret := &MsgTransferCmd{NewRootCmd("msgTransfer")} ret.SetRootCmdPt(ret) + return ret } @@ -38,5 +39,6 @@ func (m *MsgTransferCmd) addRunE() { func (m *MsgTransferCmd) Exec() error { m.addRunE() + return m.Execute() } diff --git a/pkg/common/cmd/msg_utils.go b/pkg/common/cmd/msg_utils.go index cfaf631ec..82306da8c 100644 --- a/pkg/common/cmd/msg_utils.go +++ b/pkg/common/cmd/msg_utils.go @@ -22,7 +22,7 @@ import ( type MsgUtilsCmd struct { cobra.Command - msgTool *tools.MsgTool + // msgTool *tools.MsgTool } func (m *MsgUtilsCmd) AddUserIDFlag() { @@ -31,6 +31,7 @@ func (m *MsgUtilsCmd) AddUserIDFlag() { func (m *MsgUtilsCmd) getUserIDFlag(cmdLines *cobra.Command) string { userID, _ := cmdLines.Flags().GetString("userID") + return userID } @@ -38,26 +39,17 @@ func (m *MsgUtilsCmd) AddFixAllFlag() { m.Command.PersistentFlags().BoolP("fixAll", "f", false, "openIM fix all seqs") } -func (m *MsgUtilsCmd) getFixAllFlag(cmdLines *cobra.Command) bool { - fixAll, _ := cmdLines.Flags().GetBool("fixAll") - return fixAll -} - func (m *MsgUtilsCmd) AddClearAllFlag() { m.Command.PersistentFlags().BoolP("clearAll", "c", false, "openIM clear all seqs") } -func (m *MsgUtilsCmd) getClearAllFlag(cmdLines *cobra.Command) bool { - clearAll, _ := cmdLines.Flags().GetBool("clearAll") - return clearAll -} - func (m *MsgUtilsCmd) AddSuperGroupIDFlag() { m.Command.PersistentFlags().StringP("superGroupID", "g", "", "openIM superGroupID") } func (m *MsgUtilsCmd) getSuperGroupIDFlag(cmdLines *cobra.Command) string { superGroupID, _ := cmdLines.Flags().GetString("superGroupID") + return superGroupID } @@ -65,20 +57,10 @@ func (m *MsgUtilsCmd) AddBeginSeqFlag() { m.Command.PersistentFlags().Int64P("beginSeq", "b", 0, "openIM beginSeq") } -func (m *MsgUtilsCmd) getBeginSeqFlag(cmdLines *cobra.Command) int64 { - beginSeq, _ := cmdLines.Flags().GetInt64("beginSeq") - return beginSeq -} - func (m *MsgUtilsCmd) AddLimitFlag() { m.Command.PersistentFlags().Int64P("limit", "l", 0, "openIM limit") } -func (m *MsgUtilsCmd) getLimitFlag(cmdLines *cobra.Command) int64 { - limit, _ := cmdLines.Flags().GetInt64("limit") - return limit -} - func (m *MsgUtilsCmd) Execute() error { return m.Command.Execute() } @@ -131,6 +113,7 @@ func NewSeqCmd() *SeqCmd { seqCmd := &SeqCmd{ NewMsgUtilsCmd("seq", "seq", nil), } + return seqCmd } @@ -158,6 +141,7 @@ func (s *SeqCmd) GetSeqCmd() *cobra.Command { // println(seq) } } + return &s.Command } @@ -173,6 +157,7 @@ func NewMsgCmd() *MsgCmd { msgCmd := &MsgCmd{ NewMsgUtilsCmd("msg", "msg", nil), } + return msgCmd } diff --git a/pkg/common/cmd/root.go b/pkg/common/cmd/root.go index 4ea36cfeb..d1deb628c 100644 --- a/pkg/common/cmd/root.go +++ b/pkg/common/cmd/root.go @@ -17,8 +17,6 @@ package cmd import ( "fmt" - config2 "github.com/openimsdk/open-im-server/v3/pkg/common/config" - "github.com/spf13/cobra" "github.com/OpenIMSDK/protocol/constant" @@ -61,62 +59,81 @@ func NewRootCmd(name string, opts ...func(*CmdOpts)) *RootCmd { Short: fmt.Sprintf(`Start %s `, name), Long: fmt.Sprintf(`Start %s `, name), PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - return rootCmd.persistentPreRun(cmd, opts...) + if err := rootCmd.getConfFromCmdAndInit(cmd); err != nil { + panic(err) + } + cmdOpts := &CmdOpts{} + for _, opt := range opts { + opt(cmdOpts) + } + if cmdOpts.loggerPrefixName == "" { + cmdOpts.loggerPrefixName = "OpenIM.log.all" + } + err := log.InitFromConfig(cmdOpts.loggerPrefixName, name, config.Config.Log.RemainLogLevel, + config.Config.Log.IsStdout, config.Config.Log.IsJson, config.Config.Log.StorageLocation, + config.Config.Log.RemainRotationCount, config.Config.Log.RotationTime) + if err != nil { + panic(err) + } + + return nil }, } rootCmd.Command = cmd rootCmd.addConfFlag() + return rootCmd } -func (rc *RootCmd) persistentPreRun(cmd *cobra.Command, opts ...func(*CmdOpts)) error { - if err := rc.initializeConfiguration(cmd); err != nil { - return fmt.Errorf("failed to get configuration from command: %w", err) - } +// func (rc *RootCmd) persistentPreRun(cmd *cobra.Command, opts ...func(*CmdOpts)) error { +// if err := rc.initializeConfiguration(cmd); err != nil { +// return fmt.Errorf("failed to get configuration from command: %w", err) +// } - cmdOpts := rc.applyOptions(opts...) +// cmdOpts := rc.applyOptions(opts...) - if err := rc.initializeLogger(cmdOpts); err != nil { - return fmt.Errorf("failed to initialize from config: %w", err) - } +// if err := rc.initializeLogger(cmdOpts); err != nil { +// return fmt.Errorf("failed to initialize from config: %w", err) +// } - return nil -} +// return nil +// } +//nolint:unused //unused work wrongly func (rc *RootCmd) initializeConfiguration(cmd *cobra.Command) error { return rc.getConfFromCmdAndInit(cmd) } -func (rc *RootCmd) applyOptions(opts ...func(*CmdOpts)) *CmdOpts { - cmdOpts := defaultCmdOpts() - for _, opt := range opts { - opt(cmdOpts) - } +// func (rc *RootCmd) applyOptions(opts ...func(*CmdOpts)) *CmdOpts { +// cmdOpts := defaultCmdOpts() +// for _, opt := range opts { +// opt(cmdOpts) +// } - return cmdOpts -} +// return cmdOpts +// } -func (rc *RootCmd) initializeLogger(cmdOpts *CmdOpts) error { - logConfig := config.Config.Log - - return log.InitFromConfig( - - cmdOpts.loggerPrefixName, - rc.Name, - logConfig.RemainLogLevel, - logConfig.IsStdout, - logConfig.IsJson, - logConfig.StorageLocation, - logConfig.RemainRotationCount, - logConfig.RotationTime, - ) -} +// func (rc *RootCmd) initializeLogger(cmdOpts *CmdOpts) error { +// logConfig := config.Config.Log -func defaultCmdOpts() *CmdOpts { - return &CmdOpts{ - loggerPrefixName: "OpenIM.log.all", - } -} +// return log.InitFromConfig( + +// cmdOpts.loggerPrefixName, +// rc.Name, +// logConfig.RemainLogLevel, +// logConfig.IsStdout, +// logConfig.IsJson, +// logConfig.StorageLocation, +// logConfig.RemainRotationCount, +// logConfig.RotationTime, +// ) +// } + +// func defaultCmdOpts() *CmdOpts { +// return &CmdOpts{ +// loggerPrefixName: "OpenIM.log.all", +// } +// } func (r *RootCmd) SetRootCmdPt(cmdItf RootCmdPt) { r.cmdItf = cmdItf @@ -135,6 +152,7 @@ func (r *RootCmd) getPortFlag(cmd *cobra.Command) int { if port == 0 { port = r.PortFromConfig(constant.FlagPort) } + return port } @@ -151,6 +169,7 @@ func (r *RootCmd) getPrometheusPortFlag(cmd *cobra.Command) int { if port == 0 { port = r.PortFromConfig(constant.FlagPrometheusPort) } + return port } @@ -161,7 +180,8 @@ func (r *RootCmd) GetPrometheusPortFlag() int { func (r *RootCmd) getConfFromCmdAndInit(cmdLines *cobra.Command) error { configFolderPath, _ := cmdLines.Flags().GetString(constant.FlagConf) fmt.Println("configFolderPath:", configFolderPath) - return config2.InitConfig(configFolderPath) + + return config.InitConfig(configFolderPath) } func (r *RootCmd) Execute() error { @@ -174,9 +194,12 @@ func (r *RootCmd) AddCommand(cmds ...*cobra.Command) { func (r *RootCmd) GetPortFromConfig(portType string) int { fmt.Println("RootCmd.GetPortFromConfig:", portType) + return 0 } + func (r *RootCmd) PortFromConfig(portType string) int { fmt.Println("PortFromConfig:", portType) + return r.cmdItf.GetPortFromConfig(portType) } diff --git a/pkg/common/cmd/rpc.go b/pkg/common/cmd/rpc.go index 224edc0a0..6d34c6603 100644 --- a/pkg/common/cmd/rpc.go +++ b/pkg/common/cmd/rpc.go @@ -16,11 +16,13 @@ package cmd import ( "errors" + "github.com/OpenIMSDK/protocol/constant" - config2 "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/spf13/cobra" "google.golang.org/grpc" + config2 "github.com/openimsdk/open-im-server/v3/pkg/common/config" + "github.com/OpenIMSDK/tools/discoveryregistry" "github.com/openimsdk/open-im-server/v3/pkg/common/startrpc" @@ -33,6 +35,7 @@ type RpcCmd struct { func NewRpcCmd(name string) *RpcCmd { ret := &RpcCmd{NewRootCmd(name)} ret.SetRootCmdPt(ret) + return ret } @@ -41,6 +44,7 @@ func (a *RpcCmd) Exec() error { a.port = a.getPortFlag(cmd) a.prometheusPort = a.getPrometheusPortFlag(cmd) } + return a.Execute() } @@ -51,8 +55,10 @@ func (a *RpcCmd) StartSvr( if a.GetPortFlag() == 0 { return errors.New("port is required") } + return startrpc.Start(a.GetPortFlag(), name, a.GetPrometheusPortFlag(), rpcFn) } + func (a *RpcCmd) GetPortFromConfig(portType string) int { switch a.Name { case RpcPushServer: @@ -88,5 +94,6 @@ func (a *RpcCmd) GetPortFromConfig(portType string) int { return config2.Config.RpcPort.OpenImUserPort[0] } } + return 0 } diff --git a/pkg/common/config/config.go b/pkg/common/config/config.go index 95f4a864e..d521fbd51 100644 --- a/pkg/common/config/config.go +++ b/pkg/common/config/config.go @@ -336,6 +336,7 @@ func (c *configStruct) RegisterConf2Registry(registry discoveryregistry.SvcDisco if err != nil { return err } + return registry.RegisterConf2Registry(ConfKey, data) } @@ -348,5 +349,6 @@ func (c *configStruct) EncodeConfig() []byte { if err := yaml.NewEncoder(buf).Encode(c); err != nil { panic(err) } + return buf.Bytes() } diff --git a/pkg/common/config/parse.go b/pkg/common/config/parse.go index e37514ecd..86292a597 100644 --- a/pkg/common/config/parse.go +++ b/pkg/common/config/parse.go @@ -21,8 +21,9 @@ import ( "path/filepath" "github.com/OpenIMSDK/protocol/constant" - "github.com/openimsdk/open-im-server/v3/pkg/msgprocessor" "gopkg.in/yaml.v3" + + "github.com/openimsdk/open-im-server/v3/pkg/msgprocessor" ) //go:embed version @@ -34,7 +35,7 @@ const ( DefaultFolderPath = "../config/" ) -// getProjectRoot returns the absolute path of the project root directory +// getProjectRoot returns the absolute path of the project root directory. func GetProjectRoot() string { b, _ := filepath.Abs(os.Args[0]) @@ -56,6 +57,7 @@ func GetOptionsByNotification(cfg NotificationConf) msgprocessor.Options { opts = msgprocessor.WithOptions(opts, msgprocessor.WithHistory(true), msgprocessor.WithPersistent()) } opts = msgprocessor.WithOptions(opts, msgprocessor.WithSendMsg(cfg.IsSendMsg)) + return opts } @@ -76,6 +78,7 @@ func initConfig(config interface{}, configName, configFolderPath string) error { return fmt.Errorf("unmarshal yaml error: %w", err) } fmt.Println("use config", configFolderPath) + return nil } @@ -92,6 +95,6 @@ func InitConfig(configFolderPath string) error { if err := initConfig(&Config, FileName, configFolderPath); err != nil { return err } - - return initConfig(&Config.Notification, NotificationFileName, configFolderPath) + + return nil } diff --git a/pkg/common/convert/black.go b/pkg/common/convert/black.go index 50c270dcb..9c862d5b7 100644 --- a/pkg/common/convert/black.go +++ b/pkg/common/convert/black.go @@ -18,7 +18,6 @@ import ( "context" "github.com/OpenIMSDK/protocol/sdkws" - sdk "github.com/OpenIMSDK/protocol/sdkws" "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation" ) @@ -27,11 +26,11 @@ func BlackDB2Pb( ctx context.Context, blackDBs []*relation.BlackModel, f func(ctx context.Context, userIDs []string) (map[string]*sdkws.UserInfo, error), -) (blackPbs []*sdk.BlackInfo, err error) { +) (blackPbs []*sdkws.BlackInfo, err error) { if len(blackDBs) == 0 { return nil, nil } - var userIDs []string + userIDs := make([]string, 0, len(blackDBs)) for _, blackDB := range blackDBs { userIDs = append(userIDs, blackDB.BlockUserID) } @@ -40,7 +39,7 @@ func BlackDB2Pb( return nil, err } for _, blackDB := range blackDBs { - blackPb := &sdk.BlackInfo{ + blackPb := &sdkws.BlackInfo{ OwnerUserID: blackDB.OwnerUserID, CreateTime: blackDB.CreateTime.Unix(), AddSource: blackDB.AddSource, @@ -55,5 +54,6 @@ func BlackDB2Pb( } blackPbs = append(blackPbs, blackPb) } + return blackPbs, nil } diff --git a/pkg/common/convert/conversation.go b/pkg/common/convert/conversation.go index 165262b7f..295ff4c82 100644 --- a/pkg/common/convert/conversation.go +++ b/pkg/common/convert/conversation.go @@ -27,6 +27,7 @@ func ConversationDB2Pb(conversationDB *relation.ConversationModel) *conversation if err := utils.CopyStructFields(conversationPB, conversationDB); err != nil { return nil } + return conversationPB } @@ -39,6 +40,7 @@ func ConversationsDB2Pb(conversationsDB []*relation.ConversationModel) (conversa conversationPB.LatestMsgDestructTime = conversationDB.LatestMsgDestructTime.Unix() conversationsPB = append(conversationsPB, conversationPB) } + return conversationsPB } @@ -47,6 +49,7 @@ func ConversationPb2DB(conversationPB *conversation.Conversation) *relation.Conv if err := utils.CopyStructFields(conversationDB, conversationPB); err != nil { return nil } + return conversationDB } @@ -58,5 +61,6 @@ func ConversationsPb2DB(conversationsPB []*conversation.Conversation) (conversat } conversationsDB = append(conversationsDB, conversationDB) } + return conversationsDB } diff --git a/pkg/common/convert/friend.go b/pkg/common/convert/friend.go index 7003c8aa6..41907d5ca 100644 --- a/pkg/common/convert/friend.go +++ b/pkg/common/convert/friend.go @@ -16,6 +16,7 @@ package convert import ( "context" + "fmt" "github.com/OpenIMSDK/protocol/sdkws" "github.com/OpenIMSDK/tools/utils" @@ -25,9 +26,13 @@ import ( func FriendPb2DB(friend *sdkws.FriendInfo) *relation.FriendModel { dbFriend := &relation.FriendModel{} - utils.CopyStructFields(dbFriend, friend) + err := utils.CopyStructFields(dbFriend, friend) + if err != nil { + panic(err) + } dbFriend.FriendUserID = friend.FriendUser.UserID dbFriend.CreateTime = utils.UnixSecondToTime(friend.CreateTime) + return dbFriend } @@ -37,7 +42,10 @@ func FriendDB2Pb( getUsers func(ctx context.Context, userIDs []string) (map[string]*sdkws.UserInfo, error), ) (*sdkws.FriendInfo, error) { pbfriend := &sdkws.FriendInfo{FriendUser: &sdkws.UserInfo{}} - utils.CopyStructFields(pbfriend, friendDB) + err := utils.CopyStructFields(pbfriend, friendDB) + if err != nil { + panic(err) + } users, err := getUsers(ctx, []string{friendDB.FriendUserID}) if err != nil { return nil, err @@ -47,6 +55,7 @@ func FriendDB2Pb( pbfriend.FriendUser.FaceURL = users[friendDB.FriendUserID].FaceURL pbfriend.FriendUser.Ex = users[friendDB.FriendUserID].Ex pbfriend.CreateTime = friendDB.CreateTime.Unix() + return pbfriend, nil } @@ -58,7 +67,7 @@ func FriendsDB2Pb( if len(friendsDB) == 0 { return nil, nil } - var userID []string + userID := make([]string, 0, len(friendsDB)) for _, friendDB := range friendsDB { userID = append(userID, friendDB.FriendUserID) } @@ -68,7 +77,8 @@ func FriendsDB2Pb( } for _, friend := range friendsDB { friendPb := &sdkws.FriendInfo{FriendUser: &sdkws.UserInfo{}} - utils.CopyStructFields(friendPb, friend) + err2 := utils.CopyStructFields(friendPb, friend) + err = fmt.Errorf("%w, %w", err, err2) friendPb.FriendUser.UserID = users[friend.FriendUserID].UserID friendPb.FriendUser.Nickname = users[friend.FriendUserID].Nickname friendPb.FriendUser.FaceURL = users[friend.FriendUserID].FaceURL @@ -76,7 +86,8 @@ func FriendsDB2Pb( friendPb.CreateTime = friend.CreateTime.Unix() friendsPb = append(friendsPb, friendPb) } - return friendsPb, nil + + return friendsPb, err } func FriendRequestDB2Pb( @@ -116,5 +127,6 @@ func FriendRequestDB2Pb( Ex: friendRequest.Ex, }) } + return res, nil } diff --git a/pkg/common/convert/msg.go b/pkg/common/convert/msg.go index 56f71f018..23bd5dfea 100644 --- a/pkg/common/convert/msg.go +++ b/pkg/common/convert/msg.go @@ -55,6 +55,7 @@ func MsgPb2DB(msg *sdkws.MsgData) *unrelation.MsgDataModel { msgDataModel.AtUserIDList = msg.AtUserIDList msgDataModel.AttachedInfo = msg.AttachedInfo msgDataModel.Ex = msg.Ex + return &msgDataModel } @@ -95,5 +96,6 @@ func MsgDB2Pb(msgModel *unrelation.MsgDataModel) *sdkws.MsgData { msg.AtUserIDList = msgModel.AtUserIDList msg.AttachedInfo = msgModel.AttachedInfo msg.Ex = msgModel.Ex + return &msg } diff --git a/pkg/common/convert/user.go b/pkg/common/convert/user.go index abb3a2144..38496515f 100644 --- a/pkg/common/convert/user.go +++ b/pkg/common/convert/user.go @@ -34,6 +34,7 @@ func UsersDB2Pb(users []*relationtb.UserModel) (result []*sdkws.UserInfo) { userPb.GlobalRecvMsgOpt = user.GlobalRecvMsgOpt result = append(result, &userPb) } + return result } @@ -46,5 +47,6 @@ func UserPb2DB(user *sdkws.UserInfo) *relationtb.UserModel { userDB.CreateTime = time.UnixMilli(user.CreateTime) userDB.AppMangerLevel = user.AppMangerLevel userDB.GlobalRecvMsgOpt = user.GlobalRecvMsgOpt + return &userDB } diff --git a/pkg/common/db/cache/black.go b/pkg/common/db/cache/black.go index 6da7d5d05..d1abe945c 100644 --- a/pkg/common/db/cache/black.go +++ b/pkg/common/db/cache/black.go @@ -52,6 +52,7 @@ func NewBlackCacheRedis( options rockscache.Options, ) BlackCache { rcClient := rockscache.NewClient(rdb, options) + return &BlackCacheRedis{ expireTime: blackExpireTime, rcClient: rcClient, @@ -88,5 +89,6 @@ func (b *BlackCacheRedis) GetBlackIDs(ctx context.Context, userID string) (black func (b *BlackCacheRedis) DelBlackIDs(ctx context.Context, userID string) BlackCache { cache := b.NewCache() cache.AddKeys(b.getBlackIDsKey(userID)) + return cache } diff --git a/pkg/common/db/cache/conversation.go b/pkg/common/db/cache/conversation.go index d755de645..890552160 100644 --- a/pkg/common/db/cache/conversation.go +++ b/pkg/common/db/cache/conversation.go @@ -89,6 +89,7 @@ func NewConversationRedis( db relationtb.ConversationModelInterface, ) ConversationCache { rcClient := rockscache.NewClient(rdb, opts) + return &ConversationRedisCache{ rcClient: rcClient, metaCache: NewMetaCacheRedis(rcClient), @@ -110,6 +111,7 @@ func NewNewConversationRedis( options rockscache.Options, ) ConversationCache { rcClient := rockscache.NewClient(rdb, options) + return &ConversationRedisCache{ rcClient: rcClient, metaCache: NewMetaCacheRedis(rcClient), @@ -168,12 +170,13 @@ func (c *ConversationRedisCache) GetUserConversationIDs(ctx context.Context, own } func (c *ConversationRedisCache) DelConversationIDs(userIDs ...string) ConversationCache { - var keys []string + keys := make([]string, 0, len(userIDs)) for _, userID := range userIDs { keys = append(keys, c.getConversationIDsKey(userID)) } cache := c.NewCache() cache.AddKeys(keys...) + return cache } @@ -198,18 +201,20 @@ func (c *ConversationRedisCache) GetUserConversationIDsHash( utils.Sort(conversationIDs, true) bi := big.NewInt(0) bi.SetString(utils.Md5(strings.Join(conversationIDs, ";"))[0:8], 16) + return bi.Uint64(), nil }, ) } func (c *ConversationRedisCache) DelUserConversationIDsHash(ownerUserIDs ...string) ConversationCache { - var keys []string + keys := make([]string, 0, len(ownerUserIDs)) for _, ownerUserID := range ownerUserIDs { keys = append(keys, c.getUserConversationIDsHashKey(ownerUserID)) } cache := c.NewCache() cache.AddKeys(keys...) + return cache } @@ -229,12 +234,13 @@ func (c *ConversationRedisCache) GetConversation( } func (c *ConversationRedisCache) DelConversations(ownerUserID string, conversationIDs ...string) ConversationCache { - var keys []string + keys := make([]string, 0, len(conversationIDs)) for _, conversationID := range conversationIDs { keys = append(keys, c.getConversationKey(ownerUserID, conversationID)) } cache := c.NewCache() cache.AddKeys(keys...) + return cache } @@ -248,6 +254,7 @@ func (c *ConversationRedisCache) getConversationIndex( return _i, nil } } + return 0, errors.New("not found key:" + key + " in keys") } @@ -256,10 +263,11 @@ func (c *ConversationRedisCache) GetConversations( ownerUserID string, conversationIDs []string, ) ([]*relationtb.ConversationModel, error) { - var keys []string + keys := make([]string, 0, len(conversationIDs)) for _, conversarionID := range conversationIDs { keys = append(keys, c.getConversationKey(ownerUserID, conversarionID)) } + return batchGetCache( ctx, c.rcClient, @@ -280,10 +288,11 @@ func (c *ConversationRedisCache) GetUserAllConversations( if err != nil { return nil, err } - var keys []string + keys := make([]string, 0, len(conversationIDs)) for _, conversarionID := range conversationIDs { keys = append(keys, c.getConversationKey(ownerUserID, conversarionID)) } + return batchGetCache( ctx, c.rcClient, @@ -327,24 +336,27 @@ func (c *ConversationRedisCache) GetSuperGroupRecvMsgNotNotifyUserIDs( } func (c *ConversationRedisCache) DelUsersConversation(conversationID string, ownerUserIDs ...string) ConversationCache { - var keys []string + keys := make([]string, 0, len(ownerUserIDs)) for _, ownerUserID := range ownerUserIDs { keys = append(keys, c.getConversationKey(ownerUserID, conversationID)) } cache := c.NewCache() cache.AddKeys(keys...) + return cache } func (c *ConversationRedisCache) DelUserRecvMsgOpt(ownerUserID, conversationID string) ConversationCache { cache := c.NewCache() cache.AddKeys(c.getRecvMsgOptKey(ownerUserID, conversationID)) + return cache } func (c *ConversationRedisCache) DelSuperGroupRecvMsgNotNotifyUserIDs(groupID string) ConversationCache { cache := c.NewCache() cache.AddKeys(c.getSuperGroupRecvNotNotifyUserIDsKey(groupID)) + return cache } @@ -365,6 +377,7 @@ func (c *ConversationRedisCache) GetSuperGroupRecvMsgNotNotifyUserIDsHash( utils.Sort(userIDs, true) bi := big.NewInt(0) bi.SetString(utils.Md5(strings.Join(userIDs, ";"))[0:8], 16) + return bi.Uint64(), nil }, ) @@ -373,6 +386,7 @@ func (c *ConversationRedisCache) GetSuperGroupRecvMsgNotNotifyUserIDsHash( func (c *ConversationRedisCache) DelSuperGroupRecvMsgNotNotifyUserIDsHash(groupID string) ConversationCache { cache := c.NewCache() cache.AddKeys(c.getSuperGroupRecvNotNotifyUserIDsHashKey(groupID)) + return cache } @@ -385,6 +399,7 @@ func (c *ConversationRedisCache) getUserAllHasReadSeqsIndex( return _i, nil } } + return 0, errors.New("not found key:" + conversationID + " in keys") } @@ -396,10 +411,11 @@ func (c *ConversationRedisCache) GetUserAllHasReadSeqs( if err != nil { return nil, err } - var keys []string + keys := make([]string, 0, len(conversationIDs)) for _, conversarionID := range conversationIDs { keys = append(keys, c.getConversationHasReadSeqKey(ownerUserID, conversarionID)) } + return batchGetCacheMap( ctx, c.rcClient, @@ -420,6 +436,7 @@ func (c *ConversationRedisCache) DelUserAllHasReadSeqs(ownerUserID string, for _, conversationID := range conversationIDs { cache.AddKeys(c.getConversationHasReadSeqKey(ownerUserID, conversationID)) } + return cache } @@ -451,5 +468,6 @@ func (c *ConversationRedisCache) DelConversationNotReceiveMessageUserIDs(convers for _, conversationID := range conversationIDs { cache.AddKeys(c.getConversationNotReceiveMessageUserIDsKey(conversationID)) } + return cache } diff --git a/pkg/common/db/cache/friend.go b/pkg/common/db/cache/friend.go index fd8c1d3c0..37f5b0a98 100644 --- a/pkg/common/db/cache/friend.go +++ b/pkg/common/db/cache/friend.go @@ -59,6 +59,7 @@ func NewFriendCacheRedis( options rockscache.Options, ) FriendCache { rcClient := rockscache.NewClient(rdb, options) + return &FriendCacheRedis{ metaCache: NewMetaCacheRedis(rcClient), friendDB: friendDB, @@ -100,14 +101,15 @@ func (f *FriendCacheRedis) GetFriendIDs(ctx context.Context, ownerUserID string) ) } -func (f *FriendCacheRedis) DelFriendIDs(ownerUserID ...string) FriendCache { - new := f.NewCache() - var keys []string - for _, userID := range ownerUserID { +func (f *FriendCacheRedis) DelFriendIDs(ownerUserIDs ...string) FriendCache { + newGroupCache := f.NewCache() + keys := make([]string, 0, len(ownerUserIDs)) + for _, userID := range ownerUserIDs { keys = append(keys, f.getFriendIDsKey(userID)) } - new.AddKeys(keys...) - return new + newGroupCache.AddKeys(keys...) + + return newGroupCache } // todo. @@ -128,13 +130,15 @@ func (f *FriendCacheRedis) GetTwoWayFriendIDs( twoWayFriendIDs = append(twoWayFriendIDs, ownerUserID) } } + return twoWayFriendIDs, nil } func (f *FriendCacheRedis) DelTwoWayFriendIDs(ctx context.Context, ownerUserID string) FriendCache { - new := f.NewCache() - new.AddKeys(f.getTwoWayFriendsIDsKey(ownerUserID)) - return new + newFriendCache := f.NewCache() + newFriendCache.AddKeys(f.getTwoWayFriendsIDsKey(ownerUserID)) + + return newFriendCache } func (f *FriendCacheRedis) GetFriend( @@ -153,7 +157,8 @@ func (f *FriendCacheRedis) GetFriend( } func (f *FriendCacheRedis) DelFriend(ownerUserID, friendUserID string) FriendCache { - new := f.NewCache() - new.AddKeys(f.getFriendKey(ownerUserID, friendUserID)) - return new + newFriendCache := f.NewCache() + newFriendCache.AddKeys(f.getFriendKey(ownerUserID, friendUserID)) + + return newFriendCache } diff --git a/pkg/common/db/cache/group.go b/pkg/common/db/cache/group.go index 7d4c2b043..0505241d0 100644 --- a/pkg/common/db/cache/group.go +++ b/pkg/common/db/cache/group.go @@ -109,6 +109,7 @@ func NewGroupCacheRedis( opts rockscache.Options, ) GroupCache { rcClient := rockscache.NewClient(rdb, opts) + return &GroupCacheRedis{ rcClient: rcClient, expireTime: groupExpireTime, groupDB: groupDB, groupMemberDB: groupMemberDB, groupRequestDB: groupRequestDB, @@ -169,6 +170,7 @@ func (g *GroupCacheRedis) GetGroupIndex(group *relationtb.GroupModel, keys []str return i, nil } } + return 0, errIndex } @@ -179,6 +181,7 @@ func (g *GroupCacheRedis) GetGroupMemberIndex(groupMember *relationtb.GroupMembe return i, nil } } + return 0, errIndex } @@ -187,10 +190,11 @@ func (g *GroupCacheRedis) GetGroupsInfo( ctx context.Context, groupIDs []string, ) (groups []*relationtb.GroupModel, err error) { - var keys []string + keys := make([]string, 0, len(groupIDs)) for _, group := range groupIDs { keys = append(keys, g.getGroupInfoKey(group)) } + return batchGetCache( ctx, g.rcClient, @@ -216,13 +220,14 @@ func (g *GroupCacheRedis) GetGroupInfo(ctx context.Context, groupID string) (gro } func (g *GroupCacheRedis) DelGroupsInfo(groupIDs ...string) GroupCache { - new := g.NewCache() - var keys []string + newGroupCache := g.NewCache() + keys := make([]string, 0, len(groupIDs)) for _, groupID := range groupIDs { keys = append(keys, g.getGroupInfoKey(groupID)) } - new.AddKeys(keys...) - return new + newGroupCache.AddKeys(keys...) + + return newGroupCache } func (g *GroupCacheRedis) GetJoinedSuperGroupIDs( @@ -239,6 +244,7 @@ func (g *GroupCacheRedis) GetJoinedSuperGroupIDs( if err != nil { return nil, err } + return userGroup.GroupIDs, nil }, ) @@ -248,10 +254,11 @@ func (g *GroupCacheRedis) GetSuperGroupMemberIDs( ctx context.Context, groupIDs ...string, ) (models []*unrelationtb.SuperGroupModel, err error) { - var keys []string + keys := make([]string, 0, len(groupIDs)) for _, group := range groupIDs { keys = append(keys, g.getSuperGroupMemberIDsKey(group)) } + return batchGetCache( ctx, g.rcClient, @@ -263,6 +270,7 @@ func (g *GroupCacheRedis) GetSuperGroupMemberIDs( return i, nil } } + return 0, errIndex }, func(ctx context.Context) ([]*unrelationtb.SuperGroupModel, error) { @@ -273,23 +281,25 @@ func (g *GroupCacheRedis) GetSuperGroupMemberIDs( // userJoinSuperGroup. func (g *GroupCacheRedis) DelJoinedSuperGroupIDs(userIDs ...string) GroupCache { - new := g.NewCache() - var keys []string + newGroupCache := g.NewCache() + keys := make([]string, 0, len(userIDs)) for _, userID := range userIDs { keys = append(keys, g.getJoinedSuperGroupsIDKey(userID)) } - new.AddKeys(keys...) - return new + newGroupCache.AddKeys(keys...) + + return newGroupCache } func (g *GroupCacheRedis) DelSuperGroupMemberIDs(groupIDs ...string) GroupCache { - new := g.NewCache() - var keys []string + newGroupCache := g.NewCache() + keys := make([]string, 0, len(groupIDs)) for _, groupID := range groupIDs { keys = append(keys, g.getSuperGroupMemberIDsKey(groupID)) } - new.AddKeys(keys...) - return new + newGroupCache.AddKeys(keys...) + + return newGroupCache } // groupMembersHash. @@ -368,12 +378,14 @@ func (g *GroupCacheRedis) GetGroupMemberHashMap( } res[groupID] = &relationtb.GroupSimpleUserID{Hash: hash, MemberNum: uint32(num)} } + return res, nil } func (g *GroupCacheRedis) DelGroupMembersHash(groupID string) GroupCache { cache := g.NewCache() cache.AddKeys(g.getGroupMembersHashKey(groupID)) + return cache } @@ -399,12 +411,14 @@ func (g *GroupCacheRedis) GetGroupsMemberIDs(ctx context.Context, groupIDs []str } m[groupID] = userIDs } + return m, nil } func (g *GroupCacheRedis) DelGroupMemberIDs(groupID string) GroupCache { cache := g.NewCache() cache.AddKeys(g.getGroupMemberIDsKey(groupID)) + return cache } @@ -421,12 +435,13 @@ func (g *GroupCacheRedis) GetJoinedGroupIDs(ctx context.Context, userID string) } func (g *GroupCacheRedis) DelJoinedGroupID(userIDs ...string) GroupCache { - var keys []string + keys := make([]string, 0, len(userIDs)) for _, userID := range userIDs { keys = append(keys, g.getJoinedGroupsKey(userID)) } cache := g.NewCache() cache.AddKeys(keys...) + return cache } @@ -450,10 +465,11 @@ func (g *GroupCacheRedis) GetGroupMembersInfo( groupID string, userIDs []string, ) ([]*relationtb.GroupMemberModel, error) { - var keys []string + keys := make([]string, 0, len(userIDs)) for _, userID := range userIDs { keys = append(keys, g.getGroupMemberInfoKey(groupID, userID)) } + return batchGetCache( ctx, g.rcClient, @@ -482,6 +498,7 @@ func (g *GroupCacheRedis) GetGroupMembersPage( userIDs = groupMemberIDs } groupMembers, err = g.GetGroupMembersInfo(ctx, groupID, utils.Paginate(userIDs, int(showNumber), int(showNumber))) + return uint32(len(userIDs)), groupMembers, err } @@ -493,6 +510,7 @@ func (g *GroupCacheRedis) GetAllGroupMembersInfo( if err != nil { return nil, err } + return g.GetGroupMembersInfo(ctx, groupID, groupMemberIDs) } @@ -504,10 +522,11 @@ func (g *GroupCacheRedis) GetAllGroupMemberInfo( if err != nil { return nil, err } - var keys []string + keys := make([]string, 0, len(groupMemberIDs)) for _, groupMemberID := range groupMemberIDs { keys = append(keys, g.getGroupMemberInfoKey(groupID, groupMemberID)) } + return batchGetCache( ctx, g.rcClient, @@ -521,12 +540,13 @@ func (g *GroupCacheRedis) GetAllGroupMemberInfo( } func (g *GroupCacheRedis) DelGroupMembersInfo(groupID string, userIDs ...string) GroupCache { - var keys []string + keys := make([]string, 0, len(userIDs)) for _, userID := range userIDs { keys = append(keys, g.getGroupMemberInfoKey(groupID, userID)) } cache := g.NewCache() cache.AddKeys(keys...) + return cache } @@ -543,11 +563,12 @@ func (g *GroupCacheRedis) GetGroupMemberNum(ctx context.Context, groupID string) } func (g *GroupCacheRedis) DelGroupsMemberNum(groupID ...string) GroupCache { - var keys []string + keys := make([]string, 0, len(groupID)) for _, groupID := range groupID { keys = append(keys, g.getGroupMemberNumKey(groupID)) } cache := g.NewCache() cache.AddKeys(keys...) + return cache } diff --git a/pkg/common/db/cache/meta_cache.go b/pkg/common/db/cache/meta_cache.go index ca742d4a3..3d62255a7 100644 --- a/pkg/common/db/cache/meta_cache.go +++ b/pkg/common/db/cache/meta_cache.go @@ -72,6 +72,7 @@ func (m *metaCacheRedis) ExecDel(ctx context.Context) error { ), ) log.ZWarn(ctx, "delete cache failed, please handle keys", err, "keys", m.keys) + return err } retryTimes++ @@ -80,6 +81,7 @@ func (m *metaCacheRedis) ExecDel(ctx context.Context) error { } } } + return nil } @@ -103,6 +105,7 @@ func GetDefaultOpt() rockscache.Options { opts := rockscache.NewDefaultOptions() opts.StrongConsistency = true opts.RandomExpireAdjustment = 0.2 + return opts } @@ -125,6 +128,7 @@ func getCache[T any]( return "", utils.Wrap(err, "") } write = true + return string(bs), nil }) if err != nil { @@ -139,8 +143,10 @@ func getCache[T any]( err = json.Unmarshal([]byte(v), &t) if err != nil { log.ZError(ctx, "cache json.Unmarshal failed", err, "key", key, "value", v, "expire", expire) + return t, utils.Wrap(err, "") } + return t, nil } @@ -169,6 +175,7 @@ func batchGetCache[T any]( } values[index] = string(bs) } + return values, nil }) if err != nil { @@ -185,6 +192,7 @@ func batchGetCache[T any]( tArrays = append(tArrays, t) } } + return tArrays, nil } @@ -213,6 +221,7 @@ func batchGetCacheMap[T any]( } values[index] = string(bs) } + return values, nil }) if err != nil { @@ -229,5 +238,6 @@ func batchGetCacheMap[T any]( tMap[originKeys[i]] = t } } + return tMap, nil } diff --git a/pkg/common/db/cache/msg.go b/pkg/common/db/cache/msg.go index 65b8d63de..66161c424 100644 --- a/pkg/common/db/cache/msg.go +++ b/pkg/common/db/cache/msg.go @@ -16,13 +16,12 @@ package cache import ( "context" + "errors" "strconv" "time" "github.com/openimsdk/open-im-server/v3/pkg/msgprocessor" - "github.com/dtm-labs/rockscache" - "github.com/OpenIMSDK/tools/errs" "github.com/gogo/protobuf/jsonpb" @@ -33,7 +32,6 @@ import ( "github.com/OpenIMSDK/tools/utils" "github.com/openimsdk/open-im-server/v3/pkg/common/config" - unrelationtb "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/unrelation" "github.com/redis/go-redis/v9" ) @@ -140,10 +138,10 @@ func NewMsgCacheModel(client redis.UniversalClient) MsgModel { type msgCache struct { metaCache - rdb redis.UniversalClient - expireTime time.Duration - rcClient *rockscache.Client - msgDocDatabase unrelationtb.MsgDocModelInterface + rdb redis.UniversalClient + // expireTime time.Duration + // rcClient *rockscache.Client + // msgDocDatabase unrelationtb.MsgDocModelInterface } func (c *msgCache) getMaxSeqKey(conversationID string) string { @@ -182,18 +180,20 @@ func (c *msgCache) getSeqs( ) (m map[string]int64, err error) { pipe := c.rdb.Pipeline() for _, v := range items { - if err := pipe.Get(ctx, getkey(v)).Err(); err != nil && err != redis.Nil { - return nil, errs.Wrap(err) + err2 := pipe.Get(ctx, getkey(v)).Err() + if err2 != nil && !errors.Is(err2, redis.Nil) { + return nil, errs.Wrap(err2) } } result, err := pipe.Exec(ctx) - if err != nil && err != redis.Nil { + if err != nil && !errors.Is(err, redis.Nil) { return nil, errs.Wrap(err) } m = make(map[string]int64, len(items)) for i, v := range result { seq := v.(*redis.StringCmd) - if seq.Err() != nil && seq.Err() != redis.Nil { + + if seq.Err() != nil && !errors.Is(seq.Err(), redis.Nil) { return nil, errs.Wrap(v.Err()) } val := utils.StringToInt64(seq.Val()) @@ -201,6 +201,7 @@ func (c *msgCache) getSeqs( m[items[i]] = val } } + return m, nil } @@ -229,6 +230,7 @@ func (c *msgCache) setSeqs(ctx context.Context, seqs map[string]int64, getkey fu } } _, err := pipe.Exec(ctx) + return err } @@ -319,6 +321,7 @@ func (c *msgCache) GetHasReadSeq(ctx context.Context, userID string, conversatio func (c *msgCache) AddTokenFlag(ctx context.Context, userID string, platformID int, token string, flag int) error { key := uidPidToken + userID + ":" + constant.PlatformIDToName(platformID) + return errs.Wrap(c.rdb.HSet(ctx, key, token, flag).Err()) } @@ -332,6 +335,7 @@ func (c *msgCache) GetTokensWithoutError(ctx context.Context, userID string, pla for k, v := range m { mm[k] = utils.StringToInt(v) } + return mm, nil } @@ -341,11 +345,13 @@ func (c *msgCache) SetTokenMapByUidPid(ctx context.Context, userID string, platf for k, v := range m { mm[k] = v } + return errs.Wrap(c.rdb.HSet(ctx, key, mm).Err()) } func (c *msgCache) DeleteTokenByUidPid(ctx context.Context, userID string, platform int, fields []string) error { key := uidPidToken + userID + ":" + constant.PlatformIDToName(platform) + return errs.Wrap(c.rdb.HDel(ctx, key, fields...).Err()) } @@ -366,8 +372,9 @@ func (c *msgCache) GetMessagesBySeq( for _, v := range seqs { // MESSAGE_CACHE:169.254.225.224_reliability1653387820_0_1 key := c.getMessageCacheKey(conversationID, v) - if err := pipe.Get(ctx, key).Err(); err != nil && err != redis.Nil { - return nil, nil, err + err2 := pipe.Get(ctx, key).Err() + if err2 != nil && errors.Is(err2, redis.Nil) { + return nil, nil, err2 } } result, err := pipe.Exec(ctx) @@ -381,6 +388,7 @@ func (c *msgCache) GetMessagesBySeq( if err == nil { if msg.Status != constant.MsgDeleted { seqMsgs = append(seqMsgs, &msg) + continue } } else { @@ -389,6 +397,7 @@ func (c *msgCache) GetMessagesBySeq( failedSeqs = append(failedSeqs, seqs[i]) } } + return seqMsgs, failedSeqs, err } @@ -408,6 +417,7 @@ func (c *msgCache) SetMessageToCache(ctx context.Context, conversationID string, } } _, err := pipe.Exec(ctx) + return len(failedMsgs), err } @@ -440,6 +450,7 @@ func (c *msgCache) UserDeleteMsgs(ctx context.Context, conversationID string, se } } _, err := pipe.Exec(ctx) + return errs.Wrap(err) } @@ -452,6 +463,7 @@ func (c *msgCache) GetUserDelList(ctx context.Context, userID, conversationID st for i, v := range result { seqs[i] = utils.StringToInt64(v) } + return seqs, nil } @@ -460,6 +472,7 @@ func (c *msgCache) DelUserDeleteMsgsList(ctx context.Context, conversationID str delUsers, err := c.rdb.SMembers(ctx, c.getMessageDelUserListKey(conversationID, seq)).Result() if err != nil { log.ZWarn(ctx, "DelUserDeleteMsgsList failed", err, "conversationID", conversationID, "seq", seq) + continue } if len(delUsers) > 0 { @@ -502,12 +515,13 @@ func (c *msgCache) DeleteMessages(ctx context.Context, conversationID string, se } } _, err := pipe.Exec(ctx) + return errs.Wrap(err) } func (c *msgCache) CleanUpOneConversationAllMsg(ctx context.Context, conversationID string) error { vals, err := c.rdb.Keys(ctx, c.allMessageCacheKey(conversationID)).Result() - if err == redis.Nil { + if errors.Is(err, redis.Nil) { return nil } if err != nil { @@ -515,11 +529,13 @@ func (c *msgCache) CleanUpOneConversationAllMsg(ctx context.Context, conversatio } pipe := c.rdb.Pipeline() for _, v := range vals { - if err := pipe.Del(ctx, v).Err(); err != nil { - return errs.Wrap(err) + err2 := pipe.Del(ctx, v).Err() + if err2 != nil { + return errs.Wrap(err2) } } _, err = pipe.Exec(ctx) + return errs.Wrap(err) } @@ -528,13 +544,15 @@ func (c *msgCache) DelMsgFromCache(ctx context.Context, userID string, seqs []in key := c.getMessageCacheKey(userID, seq) result, err := c.rdb.Get(ctx, key).Result() if err != nil { - if err == redis.Nil { + if errors.Is(err, redis.Nil) { continue } + return errs.Wrap(err) } var msg sdkws.MsgData - if err := jsonpb.UnmarshalString(result, &msg); err != nil { + err = jsonpb.UnmarshalString(result, &msg) + if err != nil { return err } msg.Status = constant.MsgDeleted @@ -546,6 +564,7 @@ func (c *msgCache) DelMsgFromCache(ctx context.Context, userID string, seqs []in return errs.Wrap(err) } } + return nil } @@ -571,6 +590,7 @@ func (c *msgCache) SetSendMsgStatus(ctx context.Context, id string, status int32 func (c *msgCache) GetSendMsgStatus(ctx context.Context, id string) (int32, error) { result, err := c.rdb.Get(ctx, sendMsgFailedFlag+id).Int() + return int32(result), errs.Wrap(err) } @@ -597,6 +617,7 @@ func (c *msgCache) DelFcmToken(ctx context.Context, account string, platformID i func (c *msgCache) IncrUserBadgeUnreadCountSum(ctx context.Context, userID string) (int, error) { seq, err := c.rdb.Incr(ctx, userBadgeUnreadCountSum+userID).Result() + return int(seq), errs.Wrap(err) } @@ -610,11 +631,13 @@ func (c *msgCache) GetUserBadgeUnreadCountSum(ctx context.Context, userID string func (c *msgCache) LockMessageTypeKey(ctx context.Context, clientMsgID string, TypeKey string) error { key := exTypeKeyLocker + clientMsgID + "_" + TypeKey + return errs.Wrap(c.rdb.SetNX(ctx, key, 1, time.Minute).Err()) } func (c *msgCache) UnLockMessageTypeKey(ctx context.Context, clientMsgID string, TypeKey string) error { key := exTypeKeyLocker + clientMsgID + "_" + TypeKey + return errs.Wrap(c.rdb.Del(ctx, key).Err()) } @@ -629,6 +652,7 @@ func (c *msgCache) getMessageReactionExPrefix(clientMsgID string, sessionType in case constant.NotificationChatType: return "EX_NOTIFICATION" + clientMsgID } + return "" } @@ -637,6 +661,7 @@ func (c *msgCache) JudgeMessageReactionExist(ctx context.Context, clientMsgID st if err != nil { return false, utils.Wrap(err, "") } + return n > 0, nil } diff --git a/pkg/common/db/cache/user.go b/pkg/common/db/cache/user.go index b821b4a52..0afbd595e 100644 --- a/pkg/common/db/cache/user.go +++ b/pkg/common/db/cache/user.go @@ -17,6 +17,7 @@ package cache import ( "context" "encoding/json" + "errors" "hash/crc32" "strconv" "time" @@ -70,6 +71,7 @@ func NewUserCacheRedis( options rockscache.Options, ) UserCache { rcClient := rockscache.NewClient(rdb, options) + return &UserCacheRedis{ rdb: rdb, metaCache: NewMetaCacheRedis(rcClient), @@ -97,10 +99,6 @@ func (u *UserCacheRedis) getUserGlobalRecvMsgOptKey(userID string) string { return userGlobalRecvMsgOptKey + userID } -func (u *UserCacheRedis) getUserStatusHashKey(userID string, Id int32) string { - return userID + "_" + string(Id) + platformID -} - func (u *UserCacheRedis) GetUserInfo(ctx context.Context, userID string) (userInfo *relationtb.UserModel, err error) { return getCache( ctx, @@ -114,10 +112,11 @@ func (u *UserCacheRedis) GetUserInfo(ctx context.Context, userID string) (userIn } func (u *UserCacheRedis) GetUsersInfo(ctx context.Context, userIDs []string) ([]*relationtb.UserModel, error) { - var keys []string + keys := make([]string, 0, len(userIDs)) for _, userID := range userIDs { keys = append(keys, u.getUserInfoKey(userID)) } + return batchGetCache( ctx, u.rcClient, @@ -129,6 +128,7 @@ func (u *UserCacheRedis) GetUsersInfo(ctx context.Context, userIDs []string) ([] return i, nil } } + return 0, errIndex }, func(ctx context.Context) ([]*relationtb.UserModel, error) { @@ -138,12 +138,13 @@ func (u *UserCacheRedis) GetUsersInfo(ctx context.Context, userIDs []string) ([] } func (u *UserCacheRedis) DelUsersInfo(userIDs ...string) UserCache { - var keys []string + keys := make([]string, 0, len(userIDs)) for _, userID := range userIDs { keys = append(keys, u.getUserInfoKey(userID)) } cache := u.NewCache() cache.AddKeys(keys...) + return cache } @@ -160,22 +161,19 @@ func (u *UserCacheRedis) GetUserGlobalRecvMsgOpt(ctx context.Context, userID str } func (u *UserCacheRedis) DelUsersGlobalRecvMsgOpt(userIDs ...string) UserCache { - var keys []string + keys := make([]string, 0, len(userIDs)) for _, userID := range userIDs { keys = append(keys, u.getUserGlobalRecvMsgOptKey(userID)) } cache := u.NewCache() cache.AddKeys(keys...) - return cache -} -func (u *UserCacheRedis) getOnlineStatusKey(userID string) string { - return olineStatusKey + userID + return cache } // GetUserStatus get user status. func (u *UserCacheRedis) GetUserStatus(ctx context.Context, userIDs []string) ([]*user.OnlineStatus, error) { - var res []*user.OnlineStatus + userStatus := make([]*user.OnlineStatus, 0, len(userIDs)) for _, userID := range userIDs { UserIDNum := crc32.ChecksumIEEE([]byte(userID)) modKey := strconv.Itoa(int(UserIDNum % statusMod)) @@ -183,13 +181,14 @@ func (u *UserCacheRedis) GetUserStatus(ctx context.Context, userIDs []string) ([ key := olineStatusKey + modKey result, err := u.rdb.HGet(ctx, key, userID).Result() if err != nil { - if err == redis.Nil { + if errors.Is(err, redis.Nil) { // key or field does not exist - res = append(res, &user.OnlineStatus{ + userStatus = append(userStatus, &user.OnlineStatus{ UserID: userID, Status: constant.Offline, PlatformIDs: nil, }) + continue } else { return nil, errs.Wrap(err) @@ -201,9 +200,10 @@ func (u *UserCacheRedis) GetUserStatus(ctx context.Context, userIDs []string) ([ } onlineStatus.UserID = userID onlineStatus.Status = constant.Online - res = append(res, &onlineStatus) + userStatus = append(userStatus, &onlineStatus) } - return res, nil + + return userStatus, nil } // SetUserStatus Set the user status and save it in redis. @@ -224,15 +224,16 @@ func (u *UserCacheRedis) SetUserStatus(ctx context.Context, userID string, statu Status: constant.Online, PlatformIDs: []int32{platformID}, } - jsonData, err := json.Marshal(onlineStatus) - if err != nil { - return errs.Wrap(err) + jsonData, err2 := json.Marshal(&onlineStatus) + if err2 != nil { + return errs.Wrap(err2) } - _, err = u.rdb.HSet(ctx, key, userID, string(jsonData)).Result() - if err != nil { - return errs.Wrap(err) + _, err2 = u.rdb.HSet(ctx, key, userID, string(jsonData)).Result() + if err2 != nil { + return errs.Wrap(err2) } u.rdb.Expire(ctx, key, userOlineStatusExpireTime) + return nil } } @@ -240,7 +241,7 @@ func (u *UserCacheRedis) SetUserStatus(ctx context.Context, userID string, statu isNil := false result, err := u.rdb.HGet(ctx, key, userID).Result() if err != nil { - if err == redis.Nil { + if errors.Is(err, redis.Nil) { isNil = true } else { return errs.Wrap(err) @@ -248,51 +249,45 @@ func (u *UserCacheRedis) SetUserStatus(ctx context.Context, userID string, statu } if status == constant.Offline { - if isNil { - log.ZWarn(ctx, "this user not online,maybe trigger order not right", - err, "userStatus", status) - return nil + err = u.refreshStatusOffline(ctx, userID, status, platformID, isNil, err, result, key) + if err != nil { + return err } - var onlineStatus user.OnlineStatus - err = json.Unmarshal([]byte(result), &onlineStatus) + } else { + err = u.refreshStatusOnline(ctx, userID, platformID, isNil, err, result, key) if err != nil { return errs.Wrap(err) } - var newPlatformIDs []int32 - for _, val := range onlineStatus.PlatformIDs { - if val != platformID { - newPlatformIDs = append(newPlatformIDs, val) - } + } + + return nil +} + +func (u *UserCacheRedis) refreshStatusOffline(ctx context.Context, userID string, status, platformID int32, isNil bool, err error, result, key string) error { + if isNil { + log.ZWarn(ctx, "this user not online,maybe trigger order not right", + err, "userStatus", status) + + return nil + } + var onlineStatus user.OnlineStatus + err = json.Unmarshal([]byte(result), &onlineStatus) + if err != nil { + return errs.Wrap(err) + } + var newPlatformIDs []int32 + for _, val := range onlineStatus.PlatformIDs { + if val != platformID { + newPlatformIDs = append(newPlatformIDs, val) } - if newPlatformIDs == nil { - _, err = u.rdb.HDel(ctx, key, userID).Result() - if err != nil { - return errs.Wrap(err) - } - } else { - onlineStatus.PlatformIDs = newPlatformIDs - newjsonData, err := json.Marshal(&onlineStatus) - if err != nil { - return errs.Wrap(err) - } - _, err = u.rdb.HSet(ctx, key, userID, string(newjsonData)).Result() - if err != nil { - return errs.Wrap(err) - } + } + if newPlatformIDs == nil { + _, err = u.rdb.HDel(ctx, key, userID).Result() + if err != nil { + return errs.Wrap(err) } } else { - var onlineStatus user.OnlineStatus - if !isNil { - err = json.Unmarshal([]byte(result), &onlineStatus) - if err != nil { - return errs.Wrap(err) - } - onlineStatus.PlatformIDs = RemoveRepeatedElementsInList(append(onlineStatus.PlatformIDs, platformID)) - } else { - onlineStatus.PlatformIDs = append(onlineStatus.PlatformIDs, platformID) - } - onlineStatus.Status = constant.Online - onlineStatus.UserID = userID + onlineStatus.PlatformIDs = newPlatformIDs newjsonData, err := json.Marshal(&onlineStatus) if err != nil { return errs.Wrap(err) @@ -301,7 +296,31 @@ func (u *UserCacheRedis) SetUserStatus(ctx context.Context, userID string, statu if err != nil { return errs.Wrap(err) } + } + return nil +} + +func (u *UserCacheRedis) refreshStatusOnline(ctx context.Context, userID string, platformID int32, isNil bool, err error, result, key string) error { + var onlineStatus user.OnlineStatus + if !isNil { + err2 := json.Unmarshal([]byte(result), &onlineStatus) + if err != nil { + return errs.Wrap(err2) + } + onlineStatus.PlatformIDs = RemoveRepeatedElementsInList(append(onlineStatus.PlatformIDs, platformID)) + } else { + onlineStatus.PlatformIDs = append(onlineStatus.PlatformIDs, platformID) + } + onlineStatus.Status = constant.Online + onlineStatus.UserID = userID + newjsonData, err := json.Marshal(&onlineStatus) + if err != nil { + return errs.Wrap(err) + } + _, err = u.rdb.HSet(ctx, key, userID, string(newjsonData)).Result() + if err != nil { + return errs.Wrap(err) } return nil diff --git a/pkg/common/db/controller/auth.go b/pkg/common/db/controller/auth.go index 17b4a440d..13d06a964 100644 --- a/pkg/common/db/controller/auth.go +++ b/pkg/common/db/controller/auth.go @@ -69,9 +69,9 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI } } if len(deleteTokenKey) != 0 { - err := a.cache.DeleteTokenByUidPid(ctx, userID, platformID, deleteTokenKey) - if err != nil { - return "", err + err2 := a.cache.DeleteTokenByUidPid(ctx, userID, platformID, deleteTokenKey) + if err2 != nil { + return "", err2 } } claims := tokenverify.BuildClaims(userID, platformID, a.accessExpire) @@ -80,5 +80,6 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI if err != nil { return "", utils.Wrap(err, "") } + return tokenString, a.cache.AddTokenFlag(ctx, userID, platformID, tokenString, constant.NormalToken) } diff --git a/pkg/common/db/controller/black.go b/pkg/common/db/controller/black.go index 70e942a77..38147e4e9 100644 --- a/pkg/common/db/controller/black.go +++ b/pkg/common/db/controller/black.go @@ -55,6 +55,7 @@ func (b *blackDatabase) Create(ctx context.Context, blacks []*relation.BlackMode if err := b.black.Create(ctx, blacks); err != nil { return err } + return b.deleteBlackIDsCache(ctx, blacks) } @@ -63,6 +64,7 @@ func (b *blackDatabase) Delete(ctx context.Context, blacks []*relation.BlackMode if err := b.black.Delete(ctx, blacks); err != nil { return err } + return b.deleteBlackIDsCache(ctx, blacks) } @@ -71,6 +73,7 @@ func (b *blackDatabase) deleteBlackIDsCache(ctx context.Context, blacks []*relat for _, black := range blacks { cache = cache.DelBlackIDs(ctx, black.OwnerUserID) } + return cache.ExecDel(ctx) } @@ -97,6 +100,7 @@ func (b *blackDatabase) CheckIn( return } log.ZDebug(ctx, "blackIDs", "user1BlackIDs", userID1BlackIDs, "user2BlackIDs", userID2BlackIDs) + return utils.IsContain(userID2, userID1BlackIDs), utils.IsContain(userID1, userID2BlackIDs), nil } diff --git a/pkg/common/db/controller/conversation.go b/pkg/common/db/controller/conversation.go index c3dd6980e..e68eb25ba 100644 --- a/pkg/common/db/controller/conversation.go +++ b/pkg/common/db/controller/conversation.go @@ -99,8 +99,8 @@ func (c *conversationDatabase) SetUsersConversationFiledTx(ctx context.Context, now := time.Now() for _, v := range NotUserIDs { temp := new(relationtb.ConversationModel) - if err := utils.CopyStructFields(temp, conversation); err != nil { - return err + if err2 := utils.CopyStructFields(temp, conversation); err2 != nil { + return err2 } temp.OwnerUserID = v temp.CreateTime = now @@ -113,10 +113,12 @@ func (c *conversationDatabase) SetUsersConversationFiledTx(ctx context.Context, } cache = cache.DelConversationIDs(NotUserIDs...).DelUserConversationIDsHash(NotUserIDs...).DelConversations(conversation.ConversationID, NotUserIDs...) } + return nil }); err != nil { return err } + return cache.ExecDel(ctx) } @@ -130,6 +132,7 @@ func (c *conversationDatabase) UpdateUsersConversationFiled(ctx context.Context, if _, ok := args["recv_msg_opt"]; ok { cache = cache.DelConversationNotReceiveMessageUserIDs(conversationID) } + return cache.ExecDel(ctx) } @@ -137,13 +140,14 @@ func (c *conversationDatabase) CreateConversation(ctx context.Context, conversat if err := c.conversationDB.Create(ctx, conversations); err != nil { return err } - var userIDs []string + userIDs := make([]string, 0, len(conversations)) cache := c.cache.NewCache() for _, conversation := range conversations { cache = cache.DelConversations(conversation.OwnerUserID, conversation.ConversationID) cache = cache.DelConversationNotReceiveMessageUserIDs(conversation.ConversationID) userIDs = append(userIDs, conversation.OwnerUserID) } + return cache.DelConversationIDs(userIDs...).DelUserConversationIDsHash(userIDs...).ExecDel(ctx) } @@ -178,10 +182,12 @@ func (c *conversationDatabase) SyncPeerUserPrivateConversationTx(ctx context.Con } } } + return nil }); err != nil { return err } + return cache.ExecDel(ctx) } @@ -234,12 +240,15 @@ func (c *conversationDatabase) SetUserConversations(ctx context.Context, ownerUs if err != nil { return err } - cache = cache.DelConversationIDs(ownerUserID).DelUserConversationIDsHash(ownerUserID).DelConversationNotReceiveMessageUserIDs(utils.Slice(notExistConversations, func(e *relationtb.ConversationModel) string { return e.ConversationID })...) + cache = cache.DelConversationIDs(ownerUserID).DelUserConversationIDsHash(ownerUserID) + cache = cache.DelConversationNotReceiveMessageUserIDs(utils.Slice(notExistConversations, func(e *relationtb.ConversationModel) string { return e.ConversationID })...) } + return nil }); err != nil { return err } + return cache.ExecDel(ctx) } @@ -276,10 +285,12 @@ func (c *conversationDatabase) CreateGroupChatConversation(ctx context.Context, for _, v := range existConversationUserIDs { cache = cache.DelConversations(v, conversationID) } + return nil }); err != nil { return err } + return cache.ExecDel(ctx) } diff --git a/pkg/common/db/controller/friend.go b/pkg/common/db/controller/friend.go index 7816ef935..f35d6728b 100644 --- a/pkg/common/db/controller/friend.go +++ b/pkg/common/db/controller/friend.go @@ -16,6 +16,7 @@ package controller import ( "context" + "errors" "time" "gorm.io/gorm" @@ -109,6 +110,7 @@ func (f *friendDatabase) CheckIn( if err != nil { return } + return utils.IsContain(userID2, userID1FriendIDs), utils.IsContain(userID1, userID2FriendIDs), nil } @@ -121,8 +123,8 @@ func (f *friendDatabase) AddFriendRequest( ) (err error) { return f.tx.Transaction(func(tx any) error { _, err := f.friendRequest.NewTx(tx).Take(ctx, fromUserID, toUserID) - // 有db错误 - if err != nil && errs.Unwrap(err) != gorm.ErrRecordNotFound { + // if there is a db error + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return err } // 无错误 则更新 @@ -136,12 +138,14 @@ func (f *friendDatabase) AddFriendRequest( if err := f.friendRequest.NewTx(tx).UpdateByMap(ctx, fromUserID, toUserID, m); err != nil { return err } + return nil } // gorm.ErrRecordNotFound 错误,则新增 if err := f.friendRequest.NewTx(tx).Create(ctx, []*relation.FriendRequestModel{{FromUserID: fromUserID, ToUserID: toUserID, ReqMsg: reqMsg, Ex: ex, CreateTime: time.Now(), HandleTime: time.Unix(0, 0)}}); err != nil { return err } + return nil }) } @@ -154,11 +158,11 @@ func (f *friendDatabase) BecomeFriends( addSource int32, ) (err error) { cache := f.cache.NewCache() - if err := f.tx.Transaction(func(tx any) error { - // 先find 找出重复的 去掉重复的 - fs1, err := f.friend.NewTx(tx).FindFriends(ctx, ownerUserID, friendUserIDs) - if err != nil { - return err + fn := func(tx any) error { + // first,find and drop delete ones + fs1, err2 := f.friend.NewTx(tx).FindFriends(ctx, ownerUserID, friendUserIDs) + if err2 != nil { + return err2 } opUserID := mcontext.GetOperationID(ctx) for _, v := range friendUserIDs { @@ -168,13 +172,13 @@ func (f *friendDatabase) BecomeFriends( return e.FriendUserID }) - err = f.friend.NewTx(tx).Create(ctx, fs11) - if err != nil { - return err + err2 = f.friend.NewTx(tx).Create(ctx, fs11) + if err2 != nil { + return err2 } - fs2, err := f.friend.NewTx(tx).FindReversalFriends(ctx, ownerUserID, friendUserIDs) - if err != nil { - return err + fs2, err2 := f.friend.NewTx(tx).FindReversalFriends(ctx, ownerUserID, friendUserIDs) + if err2 != nil { + return err2 } var newFriendIDs []string for _, v := range friendUserIDs { @@ -184,16 +188,20 @@ func (f *friendDatabase) BecomeFriends( fs22 := utils.DistinctAny(fs2, func(e *relation.FriendModel) string { return e.OwnerUserID }) - err = f.friend.NewTx(tx).Create(ctx, fs22) - if err != nil { - return err + err2 = f.friend.NewTx(tx).Create(ctx, fs22) + if err2 != nil { + return err2 } newFriendIDs = append(newFriendIDs, ownerUserID) cache = cache.DelFriendIDs(newFriendIDs...) - return nil - }); err != nil { + return nil } + err = f.tx.Transaction(fn) + if err != nil { + return err + } + return cache.ExecDel(ctx) } @@ -216,6 +224,7 @@ func (f *friendDatabase) RefuseFriendRequest( if err != nil { return err } + return nil } @@ -251,7 +260,7 @@ func (f *friendDatabase) AgreeFriendRequest( if err != nil { return err } - } else if err != nil && errs.Unwrap(err) != gorm.ErrRecordNotFound { + } else if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return err } @@ -290,6 +299,7 @@ func (f *friendDatabase) AgreeFriendRequest( return err } } + return f.cache.DelFriendIDs(friendRequest.ToUserID, friendRequest.FromUserID).ExecDel(ctx) }) } @@ -299,6 +309,7 @@ func (f *friendDatabase) Delete(ctx context.Context, ownerUserID string, friendU if err := f.friend.Delete(ctx, ownerUserID, friendUserIDs); err != nil { return err } + return f.cache.DelFriendIDs(append(friendUserIDs, ownerUserID)...).ExecDel(ctx) } @@ -307,6 +318,7 @@ func (f *friendDatabase) UpdateRemark(ctx context.Context, ownerUserID, friendUs if err := f.friend.UpdateRemark(ctx, ownerUserID, friendUserID, remark); err != nil { return err } + return f.cache.DelFriend(ownerUserID, friendUserID).ExecDel(ctx) } @@ -359,6 +371,7 @@ func (f *friendDatabase) FindFriendsWithError( if len(friends) != len(friendUserIDs) { err = errs.ErrRecordNotFound.Wrap() } + return } diff --git a/pkg/common/db/controller/group.go b/pkg/common/db/controller/group.go index 194f3e8b2..0788429a8 100644 --- a/pkg/common/db/controller/group.go +++ b/pkg/common/db/controller/group.go @@ -102,6 +102,7 @@ func NewGroupDatabase( cache: cache, mongoDB: superGroup, } + return database } @@ -109,6 +110,7 @@ func InitGroupDatabase(db *gorm.DB, rdb redis.UniversalClient, database *mongo.D rcOptions := rockscache.NewDefaultOptions() rcOptions.StrongConsistency = true rcOptions.RandomExpireAdjustment = 0.2 + return NewGroupDatabase( relation.NewGroupDB(db), relation.NewGroupMemberDB(db), @@ -151,6 +153,7 @@ func (g *groupDatabase) FindGroupMemberNum(ctx context.Context, groupID string) if err != nil { return 0, err } + return uint32(num), nil } @@ -184,10 +187,12 @@ func (g *groupDatabase) CreateGroup( cache = cache.DelJoinedGroupID(groupMember.UserID).DelGroupMembersInfo(groupMember.GroupID, groupMember.UserID) } cache = cache.DelGroupsInfo(createGroupIDs...) + return nil }); err != nil { return err } + return cache.ExecDel(ctx) } @@ -211,6 +216,7 @@ func (g *groupDatabase) UpdateGroup(ctx context.Context, groupID string, data ma if err := g.groupDB.UpdateMap(ctx, groupID, data); err != nil { return err } + return g.cache.DelGroupsInfo(groupID).ExecDel(ctx) } @@ -231,10 +237,12 @@ func (g *groupDatabase) DismissGroup(ctx context.Context, groupID string, delete cache = cache.DelJoinedGroupID(userIDs...).DelGroupMemberIDs(groupID).DelGroupsMemberNum(groupID).DelGroupMembersHash(groupID) } cache = cache.DelGroupsInfo(groupID) + return nil }); err != nil { return err } + return cache.ExecDel(ctx) } @@ -276,6 +284,7 @@ func (g *groupDatabase) FindGroupMember(ctx context.Context, groupIDs []string, } res = append(res, v) } + return res, nil } if len(roleLevels) == 0 { @@ -286,8 +295,10 @@ func (g *groupDatabase) FindGroupMember(ctx context.Context, groupIDs []string, } totalGroupMembers = append(totalGroupMembers, groupMembers...) } + return totalGroupMembers, nil } + return g.groupMemberDB.Find(ctx, groupIDs, userIDs, roleLevels) } @@ -307,6 +318,7 @@ func (g *groupDatabase) PageGetJoinGroup( } totalGroupMembers = append(totalGroupMembers, groupMembers...) } + return uint32(len(groupIDs)), totalGroupMembers, nil } @@ -327,6 +339,7 @@ func (g *groupDatabase) PageGetGroupMember( if err != nil { return 0, nil, err } + return uint32(len(groupMemberIDs)), members, nil } @@ -378,6 +391,7 @@ func (g *groupDatabase) HandlerGroupRequest( return err } } + return nil }) } @@ -386,6 +400,7 @@ func (g *groupDatabase) DeleteGroupMember(ctx context.Context, groupID string, u if err := g.groupMemberDB.Delete(ctx, groupID, userIDs); err != nil { return err } + return g.cache.DelGroupMembersHash(groupID). DelGroupMemberIDs(groupID). DelGroupsMemberNum(groupID). @@ -410,6 +425,7 @@ func (g *groupDatabase) MapGroupMemberNum(ctx context.Context, groupIDs []string } m[groupID] = uint32(num) } + return m, nil } @@ -429,6 +445,7 @@ func (g *groupDatabase) TransferGroupOwner(ctx context.Context, groupID string, if rowsAffected != 1 { return utils.Wrap(fmt.Errorf("newOwnerUserID %s rowsAffected = %d", newOwnerUserID, rowsAffected), "") } + return g.cache.DelGroupMembersInfo(groupID, oldOwnerUserID, newOwnerUserID).DelGroupMembersHash(groupID).ExecDel(ctx) }) } @@ -442,6 +459,7 @@ func (g *groupDatabase) UpdateGroupMember( if err := g.groupMemberDB.Update(ctx, groupID, userID, data); err != nil { return err } + return g.cache.DelGroupMembersInfo(groupID, userID).ExecDel(ctx) } @@ -454,10 +472,12 @@ func (g *groupDatabase) UpdateGroupMembers(ctx context.Context, data []*relation } cache = cache.DelGroupMembersInfo(item.GroupID, item.UserID) } + return nil }); err != nil { return err } + return cache.ExecDel(ctx) } @@ -469,6 +489,7 @@ func (g *groupDatabase) CreateGroupRequest(ctx context.Context, requests []*rela return err } } + return db.Create(ctx, requests) }) } @@ -504,6 +525,7 @@ func (g *groupDatabase) CreateSuperGroup(ctx context.Context, groupID string, in if err := g.mongoDB.CreateSuperGroup(ctx, groupID, initMemberIDs); err != nil { return err } + return g.cache.DelSuperGroupMemberIDs(groupID).DelJoinedSuperGroupIDs(initMemberIDs...).ExecDel(ctx) } @@ -521,10 +543,12 @@ func (g *groupDatabase) DeleteSuperGroup(ctx context.Context, groupID string) er if len(models) > 0 { cache = cache.DelJoinedSuperGroupIDs(models[0].MemberIDs...) } + return nil }); err != nil { return err } + return cache.ExecDel(ctx) } @@ -532,6 +556,7 @@ func (g *groupDatabase) DeleteSuperGroupMember(ctx context.Context, groupID stri if err := g.mongoDB.RemoverUserFromSuperGroup(ctx, groupID, userIDs); err != nil { return err } + return g.cache.DelSuperGroupMemberIDs(groupID).DelJoinedSuperGroupIDs(userIDs...).ExecDel(ctx) } @@ -539,6 +564,7 @@ func (g *groupDatabase) CreateSuperGroupMember(ctx context.Context, groupID stri if err := g.mongoDB.AddUserToSuperGroup(ctx, groupID, userIDs); err != nil { return err } + return g.cache.DelSuperGroupMemberIDs(groupID).DelJoinedSuperGroupIDs(userIDs...).ExecDel(ctx) } diff --git a/pkg/common/db/controller/msg.go b/pkg/common/db/controller/msg.go index af678f92c..1bbf4cdf6 100644 --- a/pkg/common/db/controller/msg.go +++ b/pkg/common/db/controller/msg.go @@ -135,6 +135,7 @@ func InitCommonMsgDatabase(rdb redis.UniversalClient, database *mongo.Database) cacheModel := cache.NewMsgCacheModel(rdb) msgDocModel := unrelation.NewMsgMongoDriver(database) CommonMsgDatabase := NewCommonMsgDatabase(msgDocModel, cacheModel) + return CommonMsgDatabase } @@ -150,14 +151,17 @@ type commonMsgDatabase struct { func (db *commonMsgDatabase) MsgToMQ(ctx context.Context, key string, msg2mq *sdkws.MsgData) error { _, _, err := db.producer.SendMessage(ctx, key, msg2mq) + return err } func (db *commonMsgDatabase) MsgToModifyMQ(ctx context.Context, key, conversationID string, messages []*sdkws.MsgData) error { if len(messages) > 0 { _, _, err := db.producerToModify.SendMessage(ctx, key, &pbmsg.MsgDataToModifyByMQ{ConversationID: conversationID, Messages: messages}) + return err } + return nil } @@ -165,26 +169,26 @@ func (db *commonMsgDatabase) MsgToPushMQ(ctx context.Context, key, conversationI partition, offset, err := db.producerToPush.SendMessage(ctx, key, &pbmsg.PushMsgDataToMQ{MsgData: msg2mq, ConversationID: conversationID}) if err != nil { log.ZError(ctx, "MsgToPushMQ", err, "key", key, "msg2mq", msg2mq) + return 0, 0, err } + return partition, offset, nil } func (db *commonMsgDatabase) MsgToMongoMQ(ctx context.Context, key, conversationID string, messages []*sdkws.MsgData, lastSeq int64) error { if len(messages) > 0 { _, _, err := db.producerToMongo.SendMessage(ctx, key, &pbmsg.MsgDataToMongoByMQ{LastSeq: lastSeq, ConversationID: conversationID, MsgData: messages}) + return err } + return nil } -func (db *commonMsgDatabase) BatchInsertBlock(ctx context.Context, conversationID string, fields []any, key int8, firstSeq int64) error { - if len(fields) == 0 { - return nil - } - num := db.msg.GetSingleGocMsgNum() +func checkTypeForBatchInsertBlock(fields []any, key int8, firstSeq int64) error { // num = 100 - for i, field := range fields { // 检查类型 + for i, field := range fields { // check type var ok bool switch key { case updateKeyMsg: @@ -202,80 +206,106 @@ func (db *commonMsgDatabase) BatchInsertBlock(ctx context.Context, conversationI return errs.ErrInternalServer.Wrap("field type is invalid") } } - // 返回值为true表示数据库存在该文档,false表示数据库不存在该文档 - updateMsgModel := func(seq int64, i int) (bool, error) { - var ( - res *mongo.UpdateResult - err error - ) - docID := db.msg.GetDocID(conversationID, seq) - index := db.msg.GetMsgIndex(seq) - field := fields[i] + + return nil +} + +func (db *commonMsgDatabase) updateMsgModelForBatchInsertBlock(ctx context.Context, conversationID string, fields []any, key int8, seq int64, i int) (bool, error) { + var ( + res *mongo.UpdateResult + err error + ) + docID := db.msg.GetDocID(conversationID, seq) + index := db.msg.GetMsgIndex(seq) + field := fields[i] + switch key { + case updateKeyMsg: + res, err = db.msgDocDatabase.UpdateMsg(ctx, docID, index, "msg", field) + case updateKeyRevoke: + res, err = db.msgDocDatabase.UpdateMsg(ctx, docID, index, "revoke", field) + } + if err != nil { + return false, err + } + + return res.MatchedCount > 0, nil +} + +func (db *commonMsgDatabase) newDocForBatchInsertBlock(conversationID string, fields []any, key int8, seq, firstSeq, num int64, i int) (unrelationtb.MsgDocModel, int) { + doc := unrelationtb.MsgDocModel{ + DocID: db.msg.GetDocID(conversationID, seq), + Msg: make([]*unrelationtb.MsgInfoModel, num), + } + var insert int // number of inserted + for j := i; j < len(fields); j++ { + seq = firstSeq + int64(j) + if db.msg.GetDocID(conversationID, seq) != doc.DocID { + break + } + insert++ switch key { case updateKeyMsg: - res, err = db.msgDocDatabase.UpdateMsg(ctx, docID, index, "msg", field) + doc.Msg[db.msg.GetMsgIndex(seq)] = &unrelationtb.MsgInfoModel{ + Msg: fields[j].(*unrelationtb.MsgDataModel), + } case updateKeyRevoke: - res, err = db.msgDocDatabase.UpdateMsg(ctx, docID, index, "revoke", field) + doc.Msg[db.msg.GetMsgIndex(seq)] = &unrelationtb.MsgInfoModel{ + Revoke: fields[j].(*unrelationtb.RevokeModel), + } } - if err != nil { - return false, err + } + for i, model := range doc.Msg { + if model == nil { + model = &unrelationtb.MsgInfoModel{} + doc.Msg[i] = model } - return res.MatchedCount > 0, nil + if model.DelList == nil { + doc.Msg[i].DelList = []string{} + } + } + + return doc, insert +} + +func (db *commonMsgDatabase) BatchInsertBlock(ctx context.Context, conversationID string, fields []any, key int8, firstSeq int64) error { + if len(fields) == 0 { + return nil + } + num := db.msg.GetSingleGocMsgNum() + // num = 100 + err := checkTypeForBatchInsertBlock(fields, key, firstSeq) + if err != nil { + return err } tryUpdate := true for i := 0; i < len(fields); i++ { - seq := firstSeq + int64(i) // 当前seq + seq := firstSeq + int64(i) // current seq + // try update if tryUpdate { - matched, err := updateMsgModel(seq, i) + matched, err := db.updateMsgModelForBatchInsertBlock(ctx, conversationID, fields, key, seq, i) if err != nil { return err } if matched { - continue // 匹配到了,继续下一个(不一定修改) - } - } - doc := unrelationtb.MsgDocModel{ - DocID: db.msg.GetDocID(conversationID, seq), - Msg: make([]*unrelationtb.MsgInfoModel, num), - } - var insert int // 插入的数量 - for j := i; j < len(fields); j++ { - seq = firstSeq + int64(j) - if db.msg.GetDocID(conversationID, seq) != doc.DocID { - break - } - insert++ - switch key { - case updateKeyMsg: - doc.Msg[db.msg.GetMsgIndex(seq)] = &unrelationtb.MsgInfoModel{ - Msg: fields[j].(*unrelationtb.MsgDataModel), - } - case updateKeyRevoke: - doc.Msg[db.msg.GetMsgIndex(seq)] = &unrelationtb.MsgInfoModel{ - Revoke: fields[j].(*unrelationtb.RevokeModel), - } - } - } - for i, model := range doc.Msg { - if model == nil { - model = &unrelationtb.MsgInfoModel{} - doc.Msg[i] = model - } - if model.DelList == nil { - doc.Msg[i].DelList = []string{} + continue // if matched,skip } } + doc, insert := db.newDocForBatchInsertBlock(conversationID, fields, key, seq, firstSeq, num, i) + // insert doc into db if err := db.msgDocDatabase.Create(ctx, &doc); err != nil { if mongo.IsDuplicateKeyError(err) { - i-- // 存在并发,重试当前数据 - tryUpdate = true // 以修改模式 + i-- // exists concurrent, + tryUpdate = true // try update + continue } + return err } - tryUpdate = false // 当前以插入成功,下一块优先插入模式 - i += insert - 1 // 跳过已插入的数据 + tryUpdate = false // if insert success,change to insert mode + i += insert - 1 // skip inserted data } + return nil } @@ -322,6 +352,7 @@ func (db *commonMsgDatabase) BatchInsertChat2DB(ctx context.Context, conversatio Ex: msg.Ex, } } + return db.BatchInsertBlock(ctx, conversationID, msgs, updateKeyMsg, msgList[0].Seq) } @@ -338,9 +369,11 @@ func (db *commonMsgDatabase) MarkSingleChatMsgsAsRead(ctx context.Context, userI log.ZDebug(ctx, "MarkSingleChatMsgsAsRead", "userID", userID, "docID", docID, "indexes", indexes) if err := db.msgDocDatabase.MarkSingleChatMsgsAsRead(ctx, userID, docID, indexes); err != nil { log.ZError(ctx, "MarkSingleChatMsgsAsRead", err, "userID", userID, "docID", docID, "indexes", indexes) + return err } } + return nil } @@ -354,8 +387,9 @@ func (db *commonMsgDatabase) DelUserDeleteMsgsList(ctx context.Context, conversa func (db *commonMsgDatabase) BatchInsertChat2Cache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (seq int64, isNew bool, err error) { currentMaxSeq, err := db.cache.GetMaxSeq(ctx, conversationID) - if err != nil && errs.Unwrap(err) != redis.Nil { + if err != nil && !errors.Is(err, redis.Nil) { prome.Inc(prome.SeqGetFailedCounter) + return 0, false, err } prome.Inc(prome.SeqGetSuccessCounter) @@ -366,7 +400,7 @@ func (db *commonMsgDatabase) BatchInsertChat2Cache(ctx context.Context, conversa if lenList < 1 { return 0, false, errors.New("too short as 0") } - if errs.Unwrap(err) == redis.Nil { + if errors.Is(err, redis.Nil) { isNew = true } lastMaxSeq := currentMaxSeq @@ -396,6 +430,7 @@ func (db *commonMsgDatabase) BatchInsertChat2Cache(ctx context.Context, conversa } else { prome.Inc(prome.SeqSetSuccessCounter) } + return lastMaxSeq, isNew, utils.Wrap(err, "") } @@ -410,6 +445,7 @@ func (db *commonMsgDatabase) getMsgBySeqs(ctx context.Context, userID, conversat totalMsgs = append(totalMsgs, convert.MsgDB2Pb(msg.Msg)) } } + return totalMsgs, nil } @@ -420,6 +456,7 @@ func (db *commonMsgDatabase) findMsgInfoBySeq(ctx context.Context, userID, docID msg.Msg.IsRead = true } } + return msgs, err } @@ -438,16 +475,76 @@ func (db *commonMsgDatabase) getMsgBySeqsRange(ctx context.Context, userID strin seqMsgs = append(seqMsgs, convert.MsgDB2Pb(msg.Msg)) } } + return seqMsgs, nil } +func (db *commonMsgDatabase) getCacheMsgForGetMsgBySeqsRange(ctx context.Context, userID string, conversationID string, begin int64, seqs []int64) ([]*sdkws.MsgData, []int64, error) { + newBegin := seqs[0] + newEnd := seqs[len(seqs)-1] + log.ZDebug(ctx, "GetMsgBySeqsRange", "first seqs", seqs, "newBegin", newBegin, "newEnd", newEnd) + cachedMsgs, failedSeqs, err := db.cache.GetMessagesBySeq(ctx, conversationID, seqs) + if err != nil { + if !errors.Is(err, redis.Nil) { + prome.Add(prome.MsgPullFromRedisFailedCounter, len(failedSeqs)) + log.ZError(ctx, "get message from redis exception", err, "conversationID", conversationID, "seqs", seqs) + } + } + var successMsgs []*sdkws.MsgData + if len(cachedMsgs) == 0 { + return successMsgs, failedSeqs, err + } + // if len(cachedMsgs) > 0 + delSeqs, err2 := db.cache.GetUserDelList(ctx, userID, conversationID) + if err2 != nil && !errors.Is(err2, redis.Nil) { + return nil, nil, err2 + } + var cacheDelNum int + for _, msg := range cachedMsgs { + if !utils.Contain(msg.Seq, delSeqs...) { + successMsgs = append(successMsgs, msg) + } else { + cacheDelNum += 1 + } + } + log.ZDebug(ctx, "get delSeqs from redis", "delSeqs", delSeqs, "userID", userID, "conversationID", conversationID, "cacheDelNum", cacheDelNum) + var reGetSeqsCache []int64 + for i := 1; i <= cacheDelNum; { + newSeq := newBegin - int64(i) + if newSeq >= begin { + if !utils.Contain(newSeq, delSeqs...) { + log.ZDebug(ctx, "seq del in cache, a new seq in range append", "new seq", newSeq) + reGetSeqsCache = append(reGetSeqsCache, newSeq) + i++ + } + } else { + break + } + } + if len(reGetSeqsCache) > 0 { + log.ZDebug(ctx, "reGetSeqsCache", "reGetSeqsCache", reGetSeqsCache) + cachedMsgs, failedSeqs2, err2 := db.cache.GetMessagesBySeq(ctx, conversationID, reGetSeqsCache) + if err2 != nil { + if !errors.Is(err2, redis.Nil) { + prome.Add(prome.MsgPullFromRedisFailedCounter, len(failedSeqs2)) + log.ZError(ctx, "get message from redis exception", err2, "conversationID", conversationID, "seqs", reGetSeqsCache) + } + } + failedSeqs = append(failedSeqs, failedSeqs2...) + successMsgs = append(successMsgs, cachedMsgs...) + } + + return successMsgs, failedSeqs, err +} + func (db *commonMsgDatabase) GetMsgBySeqsRange(ctx context.Context, userID string, conversationID string, begin, end, num, userMaxSeq int64) (int64, int64, []*sdkws.MsgData, error) { + // 从缓存中获取最小和最大序列号,并根据给定的范围值进行调整 userMinSeq, err := db.cache.GetConversationUserMinSeq(ctx, conversationID, userID) - if err != nil && errs.Unwrap(err) != redis.Nil { + if err != nil && !errors.Is(err, redis.Nil) { return 0, 0, nil, err } minSeq, err := db.cache.GetMinSeq(ctx, conversationID) - if err != nil && errs.Unwrap(err) != redis.Nil { + if err != nil && !errors.Is(err, redis.Nil) { return 0, 0, nil, err } if userMinSeq > minSeq { @@ -455,18 +552,25 @@ func (db *commonMsgDatabase) GetMsgBySeqsRange(ctx context.Context, userID strin } if minSeq > end { log.ZInfo(ctx, "minSeq > end", "minSeq", minSeq, "end", end) + return 0, 0, nil, nil } maxSeq, err := db.cache.GetMaxSeq(ctx, conversationID) - if err != nil && errs.Unwrap(err) != redis.Nil { + if err != nil && !errors.Is(err, redis.Nil) { return 0, 0, nil, err } + + // log out debug info log.ZDebug(ctx, "GetMsgBySeqsRange", "userMinSeq", userMinSeq, "conMinSeq", minSeq, "conMaxSeq", maxSeq, "userMaxSeq", userMaxSeq) + + // adjust maxSeq according to userMaxSeq if userMaxSeq != 0 { if userMaxSeq < maxSeq { maxSeq = userMaxSeq } } + + // adjust begin and end according to minSeq and maxSeq if begin < minSeq { begin = minSeq } @@ -476,6 +580,8 @@ func (db *commonMsgDatabase) GetMsgBySeqsRange(ctx context.Context, userID strin if end < begin { return 0, 0, nil, errs.ErrArgs.Wrap("seq end < begin") } + + // get seqs to search var seqs []int64 for i := end; i > end-num; i-- { if i >= begin { @@ -487,67 +593,24 @@ func (db *commonMsgDatabase) GetMsgBySeqsRange(ctx context.Context, userID strin if len(seqs) == 0 { return 0, 0, nil, nil } - newBegin := seqs[0] - newEnd := seqs[len(seqs)-1] - log.ZDebug(ctx, "GetMsgBySeqsRange", "first seqs", seqs, "newBegin", newBegin, "newEnd", newEnd) - cachedMsgs, failedSeqs, err := db.cache.GetMessagesBySeq(ctx, conversationID, seqs) + + // get info from cache,and filter deleted msg + successMsgs, failedSeqs, err := db.getCacheMsgForGetMsgBySeqsRange(ctx, userID, conversationID, begin, seqs) if err != nil { - if err != redis.Nil { - prome.Add(prome.MsgPullFromRedisFailedCounter, len(failedSeqs)) - log.ZError(ctx, "get message from redis exception", err, "conversationID", conversationID, "seqs", seqs) - } - } - var successMsgs []*sdkws.MsgData - if len(cachedMsgs) > 0 { - delSeqs, err := db.cache.GetUserDelList(ctx, userID, conversationID) - if err != nil && errs.Unwrap(err) != redis.Nil { - return 0, 0, nil, err - } - var cacheDelNum int - for _, msg := range cachedMsgs { - if !utils.Contain(msg.Seq, delSeqs...) { - successMsgs = append(successMsgs, msg) - } else { - cacheDelNum += 1 - } - } - log.ZDebug(ctx, "get delSeqs from redis", "delSeqs", delSeqs, "userID", userID, "conversationID", conversationID, "cacheDelNum", cacheDelNum) - var reGetSeqsCache []int64 - for i := 1; i <= cacheDelNum; { - newSeq := newBegin - int64(i) - if newSeq >= begin { - if !utils.Contain(newSeq, delSeqs...) { - log.ZDebug(ctx, "seq del in cache, a new seq in range append", "new seq", newSeq) - reGetSeqsCache = append(reGetSeqsCache, newSeq) - i++ - } - } else { - break - } - } - if len(reGetSeqsCache) > 0 { - log.ZDebug(ctx, "reGetSeqsCache", "reGetSeqsCache", reGetSeqsCache) - cachedMsgs, failedSeqs2, err := db.cache.GetMessagesBySeq(ctx, conversationID, reGetSeqsCache) - if err != nil { - if err != redis.Nil { - prome.Add(prome.MsgPullFromRedisFailedCounter, len(failedSeqs2)) - log.ZError(ctx, "get message from redis exception", err, "conversationID", conversationID, "seqs", reGetSeqsCache) - } - } - failedSeqs = append(failedSeqs, failedSeqs2...) - successMsgs = append(successMsgs, cachedMsgs...) - } + return 0, 0, nil, err } + // log out debug info log.ZDebug(ctx, "get msgs from cache", "successMsgs", successMsgs) if len(failedSeqs) != 0 { log.ZDebug(ctx, "msgs not exist in redis", "seqs", failedSeqs) } - // get from cache or db + // if not found in cache,find in mongo prome.Add(prome.MsgPullFromRedisSuccessCounter, len(successMsgs)) if len(failedSeqs) > 0 { mongoMsgs, err := db.getMsgBySeqsRange(ctx, userID, conversationID, failedSeqs, begin, end) if err != nil { prome.Add(prome.MsgPullFromMongoFailedCounter, len(failedSeqs)) + return 0, 0, nil, err } prome.Add(prome.MsgPullFromMongoSuccessCounter, len(mongoMsgs)) @@ -559,15 +622,15 @@ func (db *commonMsgDatabase) GetMsgBySeqsRange(ctx context.Context, userID strin func (db *commonMsgDatabase) GetMsgBySeqs(ctx context.Context, userID string, conversationID string, seqs []int64) (int64, int64, []*sdkws.MsgData, error) { userMinSeq, err := db.cache.GetConversationUserMinSeq(ctx, conversationID, userID) - if err != nil && errs.Unwrap(err) != redis.Nil { + if err != nil && !errors.Is(err, redis.Nil) { return 0, 0, nil, err } minSeq, err := db.cache.GetMinSeq(ctx, conversationID) - if err != nil && errs.Unwrap(err) != redis.Nil { + if err != nil && !errors.Is(err, redis.Nil) { return 0, 0, nil, err } maxSeq, err := db.cache.GetMaxSeq(ctx, conversationID) - if err != nil && errs.Unwrap(err) != redis.Nil { + if err != nil && !errors.Is(err, redis.Nil) { return 0, 0, nil, err } if userMinSeq < minSeq { @@ -581,7 +644,7 @@ func (db *commonMsgDatabase) GetMsgBySeqs(ctx context.Context, userID string, co } successMsgs, failedSeqs, err := db.cache.GetMessagesBySeq(ctx, conversationID, newSeqs) if err != nil { - if err != redis.Nil { + if !errors.Is(err, redis.Nil) { prome.Add(prome.MsgPullFromRedisFailedCounter, len(failedSeqs)) log.ZError(ctx, "get message from redis exception", err, "failedSeqs", failedSeqs, "conversationID", conversationID) } @@ -607,11 +670,13 @@ func (db *commonMsgDatabase) GetMsgBySeqs(ctx context.Context, userID string, co mongoMsgs, err := db.getMsgBySeqs(ctx, userID, conversationID, failedSeqs) if err != nil { prome.Add(prome.MsgPullFromMongoFailedCounter, len(failedSeqs)) + return 0, 0, nil, err } prome.Add(prome.MsgPullFromMongoSuccessCounter, len(mongoMsgs)) successMsgs = append(successMsgs, mongoMsgs...) } + return minSeq, maxSeq, successMsgs, nil } @@ -632,61 +697,74 @@ func (db *commonMsgDatabase) DeleteConversationMsgsAndSetMinSeq(ctx context.Cont log.ZWarn(ctx, "CleanUpOneUserAllMsg", err, "conversationID", conversationID) } } + return db.cache.SetMinSeq(ctx, conversationID, minSeq) } -func (db *commonMsgDatabase) UserMsgsDestruct(ctx context.Context, userID string, conversationID string, destructTime int64, lastMsgDestructTime time.Time) (seqs []int64, err error) { +func processMsgDocModel(ctx context.Context, msgDocModel *unrelationtb.MsgDocModel, userID, conversationID string, index int64, destructTime int64, lastMsgDestructTime time.Time) (seqs []int64, over bool) { + if len(msgDocModel.Msg) > 0 { + i := 0 + for _, msg := range msgDocModel.Msg { + i++ + if msg != nil && msg.Msg != nil && msg.Msg.SendTime+destructTime*1000 <= time.Now().UnixMilli() { + if msg.Msg.SendTime+destructTime*1000 > lastMsgDestructTime.UnixMilli() && !utils.Contain(userID, msg.DelList...) { + seqs = append(seqs, msg.Msg.Seq) + } + } else { + log.ZDebug(ctx, "all msg need destruct is found", "conversationID", conversationID, "userID", userID, "index", index, "stop index", i) + over = true + + return seqs, over + } + } + } + + return seqs, over +} + +func (db *commonMsgDatabase) UserMsgsDestruct(ctx context.Context, userID, conversationID string, destructTime int64, lastMsgDestructTime time.Time) (seqs []int64, err error) { var index int64 + + // refresh msg list for { - // from oldest 2 newest - msgDocModel, err := db.msgDocDatabase.GetMsgDocModelByIndex(ctx, conversationID, index, 1) - if err != nil || msgDocModel.DocID == "" { - if err != nil { - if err == unrelation.ErrMsgListNotExist { + // from oldest to newest + msgDocModel, err2 := db.msgDocDatabase.GetMsgDocModelByIndex(ctx, conversationID, index, 1) + if err2 != nil || msgDocModel.DocID == "" { + if err2 != nil { + if errors.Is(err2, unrelation.ErrMsgListNotExist) { log.ZDebug(ctx, "not doc find", "conversationID", conversationID, "userID", userID, "index", index) } else { - log.ZError(ctx, "deleteMsgRecursion GetUserMsgListByIndex failed", err, "conversationID", conversationID, "index", index) + log.ZError(ctx, "deleteMsgRecursion GetUserMsgListByIndex failed", err2, "conversationID", conversationID, "index", index) } } - // 获取报错,或者获取不到了,物理删除并且返回seq delMongoMsgsPhysical(delStruct.delDocIDList), 结束递归 + // If there is an error or no message document is found, delete the message physically and return the sequence number, then end the recursion. break } index++ //&& msgDocModel.Msg[0].Msg.SendTime > lastMsgDestructTime.UnixMilli() - if len(msgDocModel.Msg) > 0 { - i := 0 - var over bool - for _, msg := range msgDocModel.Msg { - i++ - if msg != nil && msg.Msg != nil && msg.Msg.SendTime+destructTime*1000 <= time.Now().UnixMilli() { - if msg.Msg.SendTime+destructTime*1000 > lastMsgDestructTime.UnixMilli() && !utils.Contain(userID, msg.DelList...) { - seqs = append(seqs, msg.Msg.Seq) - } - } else { - log.ZDebug(ctx, "all msg need destruct is found", "conversationID", conversationID, "userID", userID, "index", index, "stop index", i) - over = true - break - } - } - if over { - break - } + curSeqs, over := processMsgDocModel(ctx, msgDocModel, userID, conversationID, index, destructTime, lastMsgDestructTime) + seqs = append(seqs, curSeqs...) + if over { + break + } + } + // Log the result of the function call. + log.ZDebug(ctx, "UserMsgsDestruct", "conversationID", conversationID, "userID", userID, "seqs", seqs) + if len(seqs) == 0 { + return seqs, nil + } + // if len(seqs) > 0 + userMinSeq := seqs[len(seqs)-1] + 1 + currentUserMinSeq, err := db.cache.GetConversationUserMinSeq(ctx, conversationID, userID) + if err != nil && !errors.Is(err, redis.Nil) { + return nil, err + } + if currentUserMinSeq < userMinSeq { + if err := db.cache.SetConversationUserMinSeq(ctx, conversationID, userID, userMinSeq); err != nil { + return nil, err } } - log.ZDebug(ctx, "UserMsgsDestruct", "conversationID", conversationID, "userID", userID, "seqs", seqs) - if len(seqs) > 0 { - userMinSeq := seqs[len(seqs)-1] + 1 - currentUserMinSeq, err := db.cache.GetConversationUserMinSeq(ctx, conversationID, userID) - if err != nil && errs.Unwrap(err) != redis.Nil { - return nil, err - } - if currentUserMinSeq < userMinSeq { - if err := db.cache.SetConversationUserMinSeq(ctx, conversationID, userID, userMinSeq); err != nil { - return nil, err - } - } - } return seqs, nil } @@ -709,47 +787,60 @@ func (db *commonMsgDatabase) deleteMsgRecursion(ctx context.Context, conversatio msgDocModel, err := db.msgDocDatabase.GetMsgDocModelByIndex(ctx, conversationID, index, 1) if err != nil || msgDocModel.DocID == "" { if err != nil { - if err == unrelation.ErrMsgListNotExist { + if errors.Is(err, unrelation.ErrMsgListNotExist) { log.ZDebug(ctx, "deleteMsgRecursion ErrMsgListNotExist", "conversationID", conversationID, "index:", index) } else { log.ZError(ctx, "deleteMsgRecursion GetUserMsgListByIndex failed", err, "conversationID", conversationID, "index", index) } } - // 获取报错,或者获取不到了,物理删除并且返回seq delMongoMsgsPhysical(delStruct.delDocIDList), 结束递归 + // get error or miss content, delete physically and return minSeq,delMongoMsgsPhysical(delStruct.delDocIDList), end recursion err = db.msgDocDatabase.DeleteDocs(ctx, delStruct.delDocIDs) if err != nil { return 0, err } + return delStruct.getSetMinSeq() + 1, nil } + log.ZDebug(ctx, "doc info", "conversationID", conversationID, "index", index, "docID", msgDocModel.DocID, "len", len(msgDocModel.Msg)) if int64(len(msgDocModel.Msg)) > db.msg.GetSingleGocMsgNum() { log.ZWarn(ctx, "msgs too large", nil, "lenth", len(msgDocModel.Msg), "docID:", msgDocModel.DocID) } - if msgDocModel.IsFull() && msgDocModel.Msg[len(msgDocModel.Msg)-1].Msg.SendTime+(remainTime*1000) < utils.GetCurrentTimestampByMill() { - log.ZDebug(ctx, "doc is full and all msg is expired", "docID", msgDocModel.DocID) - delStruct.delDocIDs = append(delStruct.delDocIDs, msgDocModel.DocID) - delStruct.minSeq = msgDocModel.Msg[len(msgDocModel.Msg)-1].Msg.Seq + fullAndExpired := msgDocModel.IsFull() && msgDocModel.Msg[len(msgDocModel.Msg)-1].Msg.SendTime+(remainTime*1000) < utils.GetCurrentTimestampByMill() + if fullAndExpired { + handleFullAndExpiredForDeleteMsgRecursion(ctx, msgDocModel, delStruct) } else { - var delMsgIndexs []int - for i, MsgInfoModel := range msgDocModel.Msg { - if MsgInfoModel != nil && MsgInfoModel.Msg != nil { - if utils.GetCurrentTimestampByMill() > MsgInfoModel.Msg.SendTime+(remainTime*1000) { - delMsgIndexs = append(delMsgIndexs, i) - } - } - } - if len(delMsgIndexs) > 0 { - if err := db.msgDocDatabase.DeleteMsgsInOneDocByIndex(ctx, msgDocModel.DocID, delMsgIndexs); err != nil { - log.ZError(ctx, "deleteMsgRecursion DeleteMsgsInOneDocByIndex failed", err, "conversationID", conversationID, "index", index) - } - delStruct.minSeq = int64(msgDocModel.Msg[delMsgIndexs[len(delMsgIndexs)-1]].Msg.Seq) - } + handleNotFullAndExpiredForDeleteMsgRecursion(ctx, msgDocModel, remainTime, index, conversationID, delStruct, db) } seq, err := db.deleteMsgRecursion(ctx, conversationID, index+1, delStruct, remainTime) + return seq, err } +func handleFullAndExpiredForDeleteMsgRecursion(ctx context.Context, msgDocModel *unrelationtb.MsgDocModel, delStruct *delMsgRecursionStruct) { + log.ZDebug(ctx, "doc is full and all msg is expired", "docID", msgDocModel.DocID) + delStruct.delDocIDs = append(delStruct.delDocIDs, msgDocModel.DocID) + delStruct.minSeq = msgDocModel.Msg[len(msgDocModel.Msg)-1].Msg.Seq +} + +func handleNotFullAndExpiredForDeleteMsgRecursion(ctx context.Context, msgDocModel *unrelationtb.MsgDocModel, remainTime, index int64, conversationID string, delStruct *delMsgRecursionStruct, db *commonMsgDatabase) { + var delMsgIndexs []int + for i, MsgInfoModel := range msgDocModel.Msg { + if MsgInfoModel != nil && MsgInfoModel.Msg != nil { + if utils.GetCurrentTimestampByMill() > MsgInfoModel.Msg.SendTime+(remainTime*1000) { + delMsgIndexs = append(delMsgIndexs, i) + } + } + } + if len(delMsgIndexs) > 0 { + err2 := db.msgDocDatabase.DeleteMsgsInOneDocByIndex(ctx, msgDocModel.DocID, delMsgIndexs) + if err2 != nil { + log.ZError(ctx, "deleteMsgRecursion DeleteMsgsInOneDocByIndex failed", err2, "conversationID", conversationID, "index", index) + } + delStruct.minSeq = msgDocModel.Msg[delMsgIndexs[len(delMsgIndexs)-1]].Msg.Seq + } +} + func (db *commonMsgDatabase) DeleteMsgsPhysicalBySeqs(ctx context.Context, conversationID string, allSeqs []int64) error { if err := db.cache.DeleteMessages(ctx, conversationID, allSeqs); err != nil { return err @@ -763,13 +854,15 @@ func (db *commonMsgDatabase) DeleteMsgsPhysicalBySeqs(ctx context.Context, conve return err } } + return nil } func (db *commonMsgDatabase) DeleteUserMsgsBySeqs(ctx context.Context, userID string, conversationID string, seqs []int64) error { cachedMsgs, _, err := db.cache.GetMessagesBySeq(ctx, conversationID, seqs) - if err != nil && errs.Unwrap(err) != redis.Nil { + if err != nil && errors.Is(err, redis.Nil) { log.ZWarn(ctx, "DeleteUserMsgsBySeqs", err, "conversationID", conversationID, "seqs", seqs) + return err } if len(cachedMsgs) > 0 { @@ -789,6 +882,7 @@ func (db *commonMsgDatabase) DeleteUserMsgsBySeqs(ctx context.Context, userID st } } } + return nil } @@ -800,11 +894,12 @@ func (db *commonMsgDatabase) CleanUpUserConversationsMsgs(ctx context.Context, u for _, conversationID := range conversationIDs { maxSeq, err := db.cache.GetMaxSeq(ctx, conversationID) if err != nil { - if err == redis.Nil { + if errors.Is(err, redis.Nil) { log.ZInfo(ctx, "max seq is nil", "conversationID", conversationID) } else { log.ZError(ctx, "get max seq failed", err, "conversationID", conversationID) } + continue } if err := db.cache.SetMinSeq(ctx, conversationID, maxSeq+1); err != nil { @@ -898,6 +993,7 @@ func (db *commonMsgDatabase) GetConversationMinMaxSeqInMongoAndCache(ctx context if err != nil { return } + return } @@ -916,6 +1012,7 @@ func (db *commonMsgDatabase) GetMinMaxSeqMongo(ctx context.Context, conversation return } maxSeqMongo = newestMsgMongo.Msg.Seq + return } @@ -943,7 +1040,7 @@ func (db *commonMsgDatabase) RangeGroupSendCount( } func (db *commonMsgDatabase) SearchMessage(ctx context.Context, req *pbmsg.SearchMessageReq) (total int32, msgData []*sdkws.MsgData, err error) { - var totalMsgs []*sdkws.MsgData + totalMsgs := make([]*sdkws.MsgData, 0) total, msgs, err := db.msgDocDatabase.SearchMessage(ctx, req) if err != nil { return 0, nil, err @@ -954,6 +1051,7 @@ func (db *commonMsgDatabase) SearchMessage(ctx context.Context, req *pbmsg.Searc } totalMsgs = append(totalMsgs, convert.MsgDB2Pb(msg.Msg)) } + return total, totalMsgs, nil } diff --git a/pkg/common/db/controller/msg_test.go b/pkg/common/db/controller/msg_test.go index 80e2db122..15448674b 100644 --- a/pkg/common/db/controller/msg_test.go +++ b/pkg/common/db/controller/msg_test.go @@ -162,6 +162,7 @@ func GetDB() *commonMsgDatabase { if err != nil { panic(err) } + return &commonMsgDatabase{ msgDocDatabase: unrelation.NewMsgMongoDriver(mongo.GetDatabase()), } diff --git a/pkg/common/db/controller/s3.go b/pkg/common/db/controller/s3.go index 6ef3e73b3..f848f15a4 100644 --- a/pkg/common/db/controller/s3.go +++ b/pkg/common/db/controller/s3.go @@ -89,5 +89,6 @@ func (s *s3Database) AccessURL(ctx context.Context, name string, expire time.Dur if err != nil { return time.Time{}, "", err } + return expireTime, rawURL, nil } diff --git a/pkg/common/db/controller/user.go b/pkg/common/db/controller/user.go index 9c6fdc5c4..d4a120f1c 100644 --- a/pkg/common/db/controller/user.go +++ b/pkg/common/db/controller/user.go @@ -90,6 +90,7 @@ func (u *userDatabase) InitOnce(ctx context.Context, users []*relation.UserModel if len(miss) > 0 { _ = u.userDB.Create(ctx, miss) } + return nil } @@ -102,30 +103,35 @@ func (u *userDatabase) FindWithError(ctx context.Context, userIDs []string) (use if len(users) != len(userIDs) { err = errs.ErrRecordNotFound.Wrap("userID not found") } + return } // Find Get the information of the specified user. If the userID is not found, no error will be returned. func (u *userDatabase) Find(ctx context.Context, userIDs []string) (users []*relation.UserModel, err error) { users, err = u.cache.GetUsersInfo(ctx, userIDs) + return } // Create Insert multiple external guarantees that the userID is not repeated and does not exist in the db. func (u *userDatabase) Create(ctx context.Context, users []*relation.UserModel) (err error) { - if err := u.tx.Transaction(func(tx any) error { + err = u.tx.Transaction(func(tx any) error { err = u.userDB.Create(ctx, users) if err != nil { return err } + return nil - }); err != nil { + }) + if err != nil { return err } - var userIDs []string + userIDs := make([]string, 0, len(users)) for _, user := range users { userIDs = append(userIDs, user.UserID) } + return u.cache.DelUsersInfo(userIDs...).ExecDel(ctx) } @@ -134,6 +140,7 @@ func (u *userDatabase) Update(ctx context.Context, user *relation.UserModel) (er if err := u.userDB.Update(ctx, user); err != nil { return err } + return u.cache.DelUsersInfo(user.UserID).ExecDel(ctx) } @@ -142,6 +149,7 @@ func (u *userDatabase) UpdateByMap(ctx context.Context, userID string, args map[ if err := u.userDB.UpdateByMap(ctx, userID, args); err != nil { return err } + return u.cache.DelUsersInfo(userID).ExecDel(ctx) } @@ -162,6 +170,7 @@ func (u *userDatabase) IsExist(ctx context.Context, userIDs []string) (exist boo if len(users) > 0 { return true, nil } + return false, nil } @@ -183,12 +192,14 @@ func (u *userDatabase) CountRangeEverydayTotal(ctx context.Context, start time.T // SubscribeUsersStatus Subscribe or unsubscribe a user's presence status. func (u *userDatabase) SubscribeUsersStatus(ctx context.Context, userID string, userIDs []string) error { err := u.mongoDB.AddSubscriptionList(ctx, userID, userIDs) + return err } // UnsubscribeUsersStatus unsubscribe a user's presence status. func (u *userDatabase) UnsubscribeUsersStatus(ctx context.Context, userID string, userIDs []string) error { err := u.mongoDB.UnsubscriptionList(ctx, userID, userIDs) + return err } @@ -198,6 +209,7 @@ func (u *userDatabase) GetAllSubscribeList(ctx context.Context, userID string) ( if err != nil { return nil, err } + return list, nil } @@ -207,12 +219,14 @@ func (u *userDatabase) GetSubscribedList(ctx context.Context, userID string) ([] if err != nil { return nil, err } + return list, nil } // GetUserStatus get user status. func (u *userDatabase) GetUserStatus(ctx context.Context, userIDs []string) ([]*user.OnlineStatus, error) { onlineStatusList, err := u.cache.GetUserStatus(ctx, userIDs) + return onlineStatusList, err } diff --git a/pkg/common/db/localcache/conversation.go b/pkg/common/db/localcache/conversation.go index c40bcdbce..b43e58257 100644 --- a/pkg/common/db/localcache/conversation.go +++ b/pkg/common/db/localcache/conversation.go @@ -50,6 +50,7 @@ func (g *ConversationLocalCache) GetRecvMsgNotNotifyUserIDs(ctx context.Context, if err != nil { return nil, err } + return resp.UserIDs, nil } diff --git a/pkg/common/db/localcache/group.go b/pkg/common/db/localcache/group.go index 4958d91ee..140c3aeaf 100644 --- a/pkg/common/db/localcache/group.go +++ b/pkg/common/db/localcache/group.go @@ -57,6 +57,7 @@ func (g *GroupLocalCache) GetGroupMemberIDs(ctx context.Context, groupID string) localHashInfo, ok := g.cache[groupID] if ok && localHashInfo.memberListHash == resp.GroupAbstractInfos[0].GroupMemberListHash { g.lock.Unlock() + return localHashInfo.userIDs, nil } g.lock.Unlock() @@ -74,5 +75,6 @@ func (g *GroupLocalCache) GetGroupMemberIDs(ctx context.Context, groupID string) memberListHash: resp.GroupAbstractInfos[0].GroupMemberListHash, userIDs: groupMembersResp.UserIDs, } + return g.cache[groupID].userIDs, nil } diff --git a/pkg/common/db/relation/black_model.go b/pkg/common/db/relation/black_model.go index 34123c7a3..58dae3745 100644 --- a/pkg/common/db/relation/black_model.go +++ b/pkg/common/db/relation/black_model.go @@ -63,10 +63,11 @@ func (b *BlackGorm) Find( ctx context.Context, blacks []*relation.BlackModel, ) (blackList []*relation.BlackModel, err error) { - var where [][]interface{} + where := make([][]interface{}, 0, len(blacks)) for _, black := range blacks { where = append(where, []interface{}{black.OwnerUserID, black.BlockUserID}) } + return blackList, utils.Wrap( b.db(ctx).Where("(owner_user_id, block_user_id) in ?", where).Find(&blackList).Error, "", @@ -75,6 +76,7 @@ func (b *BlackGorm) Find( func (b *BlackGorm) Take(ctx context.Context, ownerUserID, blockUserID string) (black *relation.BlackModel, err error) { black = &relation.BlackModel{} + return black, utils.Wrap( b.db(ctx).Where("owner_user_id = ? and block_user_id = ?", ownerUserID, blockUserID).Take(black).Error, "", @@ -96,6 +98,7 @@ func (b *BlackGorm) FindOwnerBlacks( showNumber, ) total = int64(totalUint32) + return } diff --git a/pkg/common/db/relation/chat_log_model.go b/pkg/common/db/relation/chat_log_model.go index f183a543f..f474a2d34 100644 --- a/pkg/common/db/relation/chat_log_model.go +++ b/pkg/common/db/relation/chat_log_model.go @@ -15,6 +15,8 @@ package relation import ( + + //nolint:staticcheck //tofix: SA1019: "github.com/golang/protobuf/jsonpb" is deprecated: Use the "google.golang.org/protobuf/encoding/protojson" package instead. "github.com/golang/protobuf/jsonpb" "github.com/jinzhu/copier" "google.golang.org/protobuf/proto" @@ -38,7 +40,10 @@ func NewChatLogGorm(db *gorm.DB) relation.ChatLogModelInterface { func (c *ChatLogGorm) Create(msg *pbmsg.MsgDataToMQ) error { chatLog := new(relation.ChatLogModel) - copier.Copy(chatLog, msg.MsgData) + err := copier.Copy(chatLog, msg.MsgData) + if err != nil { + return err + } switch msg.MsgData.SessionType { case constant.GroupChatType, constant.SuperGroupChatType: chatLog.RecvID = msg.MsgData.GroupID @@ -59,5 +64,6 @@ func (c *ChatLogGorm) Create(msg *pbmsg.MsgDataToMQ) error { } chatLog.CreateTime = utils.UnixMillSecondToTime(msg.MsgData.CreateTime) chatLog.SendTime = utils.UnixMillSecondToTime(msg.MsgData.SendTime) + return c.DB.Create(chatLog).Error } diff --git a/pkg/common/db/relation/conversation_model.go b/pkg/common/db/relation/conversation_model.go index d5ca92ec2..37a4e02be 100644 --- a/pkg/common/db/relation/conversation_model.go +++ b/pkg/common/db/relation/conversation_model.go @@ -54,6 +54,7 @@ func (c *ConversationGorm) UpdateByMap( args map[string]interface{}, ) (rows int64, err error) { result := c.db(ctx).Where("owner_user_id IN (?) and conversation_id=?", userIDList, conversationID).Updates(args) + return result.RowsAffected, utils.Wrap(result.Error, "") } @@ -79,6 +80,7 @@ func (c *ConversationGorm) Find( Error, "", ) + return conversations, err } @@ -87,6 +89,7 @@ func (c *ConversationGorm) Take( userID, conversationID string, ) (conversation *relation.ConversationModel, err error) { cc := &relation.ConversationModel{} + return cc, utils.Wrap( c.db(ctx).Where("conversation_id = ? And owner_user_id = ?", conversationID, userID).Take(cc).Error, "", @@ -169,6 +172,7 @@ func (c *ConversationGorm) GetUserRecvMsgOpt( ownerUserID, conversationID string, ) (opt int, err error) { var conversation relation.ConversationModel + return int( conversation.RecvMsgOpt, ), utils.Wrap( @@ -219,6 +223,7 @@ func (c *ConversationGorm) GetConversationIDsNeedDestruct( func (c *ConversationGorm) GetConversationRecvMsgOpt(ctx context.Context, userID string, conversationID string) (int32, error) { var recvMsgOpt int32 + return recvMsgOpt, errs.Wrap( c.db(ctx). Model(&relation.ConversationModel{}). @@ -230,6 +235,7 @@ func (c *ConversationGorm) GetConversationRecvMsgOpt(ctx context.Context, userID func (c *ConversationGorm) GetConversationNotReceiveMessageUserIDs(ctx context.Context, conversationID string) ([]string, error) { var userIDs []string + return userIDs, errs.Wrap( c.db(ctx). Model(&relation.ConversationModel{}). diff --git a/pkg/common/db/relation/friend_model.go b/pkg/common/db/relation/friend_model.go index 869254455..103afd8aa 100644 --- a/pkg/common/db/relation/friend_model.go +++ b/pkg/common/db/relation/friend_model.go @@ -50,6 +50,7 @@ func (f *FriendGorm) Delete(ctx context.Context, ownerUserID string, friendUserI Error, "", ) + return err } @@ -84,6 +85,7 @@ func (f *FriendGorm) UpdateRemark(ctx context.Context, ownerUserID, friendUserID } m := make(map[string]interface{}, 1) m["remark"] = "" + return utils.Wrap(f.db(ctx).Where("owner_user_id = ?", ownerUserID).Updates(m).Error, "") } @@ -93,6 +95,7 @@ func (f *FriendGorm) Take( ownerUserID, friendUserID string, ) (friend *relation.FriendModel, err error) { friend = &relation.FriendModel{} + return friend, utils.Wrap( f.db(ctx).Where("owner_user_id = ? and friend_user_id", ownerUserID, friendUserID).Take(friend).Error, "", @@ -156,6 +159,7 @@ func (f *FriendGorm) FindOwnerFriends( Error, "", ) + return } @@ -178,6 +182,7 @@ func (f *FriendGorm) FindInWhoseFriends( Error, "", ) + return } diff --git a/pkg/common/db/relation/friend_request_model.go b/pkg/common/db/relation/friend_request_model.go index 5678f7b7b..4fd372c0c 100644 --- a/pkg/common/db/relation/friend_request_model.go +++ b/pkg/common/db/relation/friend_request_model.go @@ -74,6 +74,7 @@ func (f *FriendRequestGorm) Update(ctx context.Context, friendRequest *relation. fr2 := *friendRequest fr2.FromUserID = "" fr2.ToUserID = "" + return utils.Wrap( f.db(ctx). Where("from_user_id = ? AND to_user_id =?", friendRequest.FromUserID, friendRequest.ToUserID). @@ -93,6 +94,7 @@ func (f *FriendRequestGorm) Find( f.db(ctx).Where("from_user_id = ? and to_user_id = ?", fromUserID, toUserID).Find(friendRequest).Error, "", ) + return friendRequest, err } @@ -105,6 +107,7 @@ func (f *FriendRequestGorm) Take( f.db(ctx).Where("from_user_id = ? and to_user_id = ?", fromUserID, toUserID).Take(friendRequest).Error, "", ) + return friendRequest, err } @@ -127,6 +130,7 @@ func (f *FriendRequestGorm) FindToUserID( Error, "", ) + return } @@ -149,6 +153,7 @@ func (f *FriendRequestGorm) FindFromUserID( Error, "", ) + return } @@ -160,5 +165,6 @@ func (f *FriendRequestGorm) FindBothFriendRequests(ctx context.Context, fromUser Error, "", ) + return } diff --git a/pkg/common/db/relation/group_member_model.go b/pkg/common/db/relation/group_member_model.go index 312e32054..48baea61c 100644 --- a/pkg/common/db/relation/group_member_model.go +++ b/pkg/common/db/relation/group_member_model.go @@ -68,6 +68,7 @@ func (g *GroupMemberGorm) UpdateRoleLevel( db := g.db(ctx).Where("group_id = ? and user_id = ?", groupID, userID).Updates(map[string]any{ "role_level": roleLevel, }) + return db.RowsAffected, utils.Wrap(db.Error, "") } @@ -87,6 +88,7 @@ func (g *GroupMemberGorm) Find( if len(roleLevels) > 0 { db = db.Where("role_level in (?)", roleLevels) } + return groupMembers, utils.Wrap(db.Find(&groupMembers).Error, "") } @@ -96,6 +98,7 @@ func (g *GroupMemberGorm) Take( userID string, ) (groupMember *relation.GroupMemberModel, err error) { groupMember = &relation.GroupMemberModel{} + return groupMember, utils.Wrap( g.db(ctx).Where("group_id = ? and user_id = ?", groupID, userID).Take(groupMember).Error, "", @@ -107,6 +110,7 @@ func (g *GroupMemberGorm) TakeOwner( groupID string, ) (groupMember *relation.GroupMemberModel, err error) { groupMember = &relation.GroupMemberModel{} + return groupMember, utils.Wrap( g.db(ctx).Where("group_id = ? and role_level = ?", groupID, constant.GroupOwner).Take(groupMember).Error, "", @@ -125,6 +129,7 @@ func (g *GroupMemberGorm) SearchMember( ormutil.GormIn(&db, "group_id", groupIDs) ormutil.GormIn(&db, "user_id", userIDs) ormutil.GormIn(&db, "role_level", roleLevels) + return ormutil.GormSearch[relation.GroupMemberModel](db, []string{"nickname"}, keyword, pageNumber, showNumber) } @@ -152,6 +157,7 @@ func (g *GroupMemberGorm) FindJoinUserID( groupUsers[item.GroupID] = append(v, item.UserID) } } + return groupUsers, nil } @@ -182,6 +188,7 @@ func (g *GroupMemberGorm) FindUsersJoinedGroupID(ctx context.Context, userIDs [] result[groupMember.UserID] = append(v, groupMember.GroupID) } } + return result, nil } diff --git a/pkg/common/db/relation/group_model.go b/pkg/common/db/relation/group_model.go index 7a8eee9f0..508a86f5f 100644 --- a/pkg/common/db/relation/group_model.go +++ b/pkg/common/db/relation/group_model.go @@ -61,12 +61,14 @@ func (g *GroupGorm) Find(ctx context.Context, groupIDs []string) (groups []*rela func (g *GroupGorm) Take(ctx context.Context, groupID string) (group *relation.GroupModel, err error) { group = &relation.GroupModel{} + return group, utils.Wrap(g.DB.Where("group_id = ?", groupID).Take(group).Error, "") } func (g *GroupGorm) Search(ctx context.Context, keyword string, pageNumber, showNumber int32) (total uint32, groups []*relation.GroupModel, err error) { db := g.DB db = db.WithContext(ctx).Where("status!=?", constant.GroupStatusDismissed) + return ormutil.GormSearch[relation.GroupModel](db, []string{"name"}, keyword, pageNumber, showNumber) } @@ -82,6 +84,7 @@ func (g *GroupGorm) CountTotal(ctx context.Context, before *time.Time) (count in if err := db.Count(&count).Error; err != nil { return 0, err } + return count, nil } @@ -98,6 +101,7 @@ func (g *GroupGorm) CountRangeEverydayTotal(ctx context.Context, start time.Time for _, r := range res { v[r.Date.Format("2006-01-02")] = r.Count } + return v, nil } diff --git a/pkg/common/db/relation/group_request_model.go b/pkg/common/db/relation/group_request_model.go index af3f277e8..691a83bb0 100644 --- a/pkg/common/db/relation/group_request_model.go +++ b/pkg/common/db/relation/group_request_model.go @@ -80,6 +80,7 @@ func (g *GroupRequestGorm) Take( userID string, ) (groupRequest *relation.GroupRequestModel, err error) { groupRequest = &relation.GroupRequestModel{} + return groupRequest, utils.Wrap( g.DB.WithContext(ctx).Where("group_id = ? and user_id = ? ", groupID, userID).Take(groupRequest).Error, utils.GetSelfFuncName(), @@ -114,5 +115,6 @@ func (g *GroupRequestGorm) PageGroup( func (g *GroupRequestGorm) FindGroupRequests(ctx context.Context, groupID string, userIDs []string) (total int64, groupRequests []*relation.GroupRequestModel, err error) { err = g.DB.WithContext(ctx).Where("group_id = ? and user_id in ?", groupID, userIDs).Find(&groupRequests).Error + return int64(len(groupRequests)), groupRequests, utils.Wrap(err, utils.GetSelfFuncName()) } diff --git a/pkg/common/db/relation/log_model.go b/pkg/common/db/relation/log_model.go index 53365ca5b..fc1a82cdb 100644 --- a/pkg/common/db/relation/log_model.go +++ b/pkg/common/db/relation/log_model.go @@ -25,6 +25,7 @@ func (l *LogGorm) Search(ctx context.Context, keyword string, start time.Time, e db = l.db.WithContext(ctx).Where("create_time <= ?", end) } db = db.Order("create_time desc") + return ormutil.GormSearch[relationtb.Log](db, []string{"user_id"}, keyword, pageNumber, showNumber) } @@ -32,6 +33,7 @@ func (l *LogGorm) Delete(ctx context.Context, logIDs []string, userID string) er if userID == "" { return errs.Wrap(l.db.WithContext(ctx).Where("log_id in ?", logIDs).Delete(&relationtb.Log{}).Error) } + return errs.Wrap(l.db.WithContext(ctx).Where("log_id in ? and user_id=?", logIDs, userID).Delete(&relationtb.Log{}).Error) } @@ -40,10 +42,15 @@ func (l *LogGorm) Get(ctx context.Context, logIDs []string, userID string) ([]*r if userID == "" { return logs, errs.Wrap(l.db.WithContext(ctx).Where("log_id in ?", logIDs).Find(&logs).Error) } + return logs, errs.Wrap(l.db.WithContext(ctx).Where("log_id in ? and user_id=?", logIDs, userID).Find(&logs).Error) } func NewLogGorm(db *gorm.DB) relationtb.LogInterface { - db.AutoMigrate(&relationtb.Log{}) + err := db.AutoMigrate(&relationtb.Log{}) + if err != nil { + panic(err) + } + return &LogGorm{db: db} } diff --git a/pkg/common/db/relation/meta_db.go b/pkg/common/db/relation/meta_db.go index 6ab980120..00c1b76bc 100644 --- a/pkg/common/db/relation/meta_db.go +++ b/pkg/common/db/relation/meta_db.go @@ -34,5 +34,6 @@ func NewMetaDB(db *gorm.DB, table any) *MetaDB { func (g *MetaDB) db(ctx context.Context) *gorm.DB { db := g.DB.WithContext(ctx).Model(g.table) + return db } diff --git a/pkg/common/db/relation/mysql_init.go b/pkg/common/db/relation/mysql_init.go index 0e5ea5e43..550053ea2 100644 --- a/pkg/common/db/relation/mysql_init.go +++ b/pkg/common/db/relation/mysql_init.go @@ -15,6 +15,7 @@ package relation import ( + "errors" "fmt" "time" @@ -82,6 +83,7 @@ func newMysqlGormDB() (*gorm.DB, error) { sqlDB.SetConnMaxLifetime(time.Second * time.Duration(config.Config.Mysql.MaxLifeTime)) sqlDB.SetMaxOpenConns(config.Config.Mysql.MaxOpenConn) sqlDB.SetMaxIdleConns(config.Config.Mysql.MaxIdleConn) + return db, nil } @@ -94,11 +96,13 @@ func connectToDatabase(dsn string, maxRetry int) (*gorm.DB, error) { if err == nil { return db, nil } - if mysqlErr, ok := err.(*mysqldriver.MySQLError); ok && mysqlErr.Number == 1045 { + var mysqlErr *mysqldriver.MySQLError + if errors.As(err, &mysqlErr) && mysqlErr.Number == 1045 { return nil, err } time.Sleep(time.Duration(1) * time.Second) } + return nil, err } @@ -106,6 +110,7 @@ func connectToDatabase(dsn string, maxRetry int) (*gorm.DB, error) { func NewGormDB() (*gorm.DB, error) { specialerror.AddReplace(gorm.ErrRecordNotFound, errs.ErrRecordNotFound) specialerror.AddErrHandler(replaceDuplicateKey) + return newMysqlGormDB() } @@ -113,12 +118,15 @@ func replaceDuplicateKey(err error) errs.CodeError { if IsMysqlDuplicateKey(err) { return errs.ErrDuplicateKey } + return nil } func IsMysqlDuplicateKey(err error) bool { - if mysqlErr, ok := err.(*mysqldriver.MySQLError); ok { + var mysqlErr *mysqldriver.MySQLError + if errors.As(err, &mysqlErr) { return mysqlErr.Number == 1062 } + return false } diff --git a/pkg/common/db/relation/object_model.go b/pkg/common/db/relation/object_model.go index c5624a8d4..34b511c6a 100644 --- a/pkg/common/db/relation/object_model.go +++ b/pkg/common/db/relation/object_model.go @@ -44,10 +44,12 @@ func (o *ObjectInfoGorm) SetObject(ctx context.Context, obj *relation.ObjectMode if err := o.DB.WithContext(ctx).Where("name = ?", obj.Name).FirstOrCreate(obj).Error; err != nil { return errs.Wrap(err) } + return nil } func (o *ObjectInfoGorm) Take(ctx context.Context, name string) (info *relation.ObjectModel, err error) { info = &relation.ObjectModel{} + return info, errs.Wrap(o.DB.WithContext(ctx).Where("name = ?", name).Take(info).Error) } diff --git a/pkg/common/db/relation/user_model.go b/pkg/common/db/relation/user_model.go index b04c29816..ef605abd9 100644 --- a/pkg/common/db/relation/user_model.go +++ b/pkg/common/db/relation/user_model.go @@ -53,6 +53,7 @@ func (u *UserGorm) Update(ctx context.Context, user *relation.UserModel) (err er // 获取指定用户信息 不存在,也不返回错误. func (u *UserGorm) Find(ctx context.Context, userIDs []string) (users []*relation.UserModel, err error) { err = utils.Wrap(u.db(ctx).Where("user_id in (?)", userIDs).Find(&users).Error, "") + return users, err } @@ -60,6 +61,7 @@ func (u *UserGorm) Find(ctx context.Context, userIDs []string) (users []*relatio func (u *UserGorm) Take(ctx context.Context, userID string) (user *relation.UserModel, err error) { user = &relation.UserModel{} err = utils.Wrap(u.db(ctx).Where("user_id = ?", userID).Take(&user).Error, "") + return user, err } @@ -81,6 +83,7 @@ func (u *UserGorm) Page( Error, "", ) + return } @@ -88,13 +91,14 @@ func (u *UserGorm) Page( func (u *UserGorm) GetAllUserID(ctx context.Context, pageNumber, showNumber int32) (userIDs []string, err error) { if pageNumber == 0 || showNumber == 0 { return userIDs, errs.Wrap(u.db(ctx).Pluck("user_id", &userIDs).Error) - } else { - return userIDs, errs.Wrap(u.db(ctx).Limit(int(showNumber)).Offset(int((pageNumber-1)*showNumber)).Pluck("user_id", &userIDs).Error) } + + return userIDs, errs.Wrap(u.db(ctx).Limit(int(showNumber)).Offset(int((pageNumber-1)*showNumber)).Pluck("user_id", &userIDs).Error) } func (u *UserGorm) GetUserGlobalRecvMsgOpt(ctx context.Context, userID string) (opt int, err error) { err = u.db(ctx).Model(&relation.UserModel{}).Where("user_id = ?", userID).Pluck("global_recv_msg_opt", &opt).Error + return opt, err } @@ -106,6 +110,7 @@ func (u *UserGorm) CountTotal(ctx context.Context, before *time.Time) (count int if err := db.Count(&count).Error; err != nil { return 0, err } + return count, nil } @@ -132,5 +137,6 @@ func (u *UserGorm) CountRangeEverydayTotal( for _, r := range res { v[r.Date.Format("2006-01-02")] = r.Count } + return v, nil } diff --git a/pkg/common/db/s3/cont/controller.go b/pkg/common/db/s3/cont/controller.go index 6faa997a9..7ff6fa755 100644 --- a/pkg/common/db/s3/cont/controller.go +++ b/pkg/common/db/s3/cont/controller.go @@ -46,6 +46,7 @@ func (c *Controller) HashPath(md5 string) string { func (c *Controller) NowPath() string { now := time.Now() + return path.Join( fmt.Sprintf("%04d", now.Year()), fmt.Sprintf("%02d", now.Month()), @@ -58,6 +59,7 @@ func (c *Controller) NowPath() string { func (c *Controller) UUID() string { id := uuid.New() + return hex.EncodeToString(id[:]) } @@ -92,20 +94,24 @@ func (c *Controller) InitiateUpload(ctx context.Context, hash string, size int64 partNumber++ } if maxParts > 0 && partNumber > 0 && partNumber < maxParts { - return nil, errors.New(fmt.Sprintf("too many parts: %d", partNumber)) + return nil, fmt.Errorf("too few parts: %d", partNumber) } - if info, err := c.impl.StatObject(ctx, c.HashPath(hash)); err == nil { + info, err := c.impl.StatObject(ctx, c.HashPath(hash)) + if err == nil { return nil, &HashAlreadyExistsError{Object: info} - } else if !c.impl.IsNotFound(err) { + } + if !c.impl.IsNotFound(err) { return nil, err } + if size <= partSize { // 预签名上传 key := path.Join(tempPath, c.NowPath(), fmt.Sprintf("%s_%d_%s.presigned", hash, size, c.UUID())) - rawURL, err := c.impl.PresignedPutObject(ctx, key, expire) - if err != nil { - return nil, err + rawURL, err2 := c.impl.PresignedPutObject(ctx, key, expire) + if err2 != nil { + return nil, err2 } + return &InitiateUploadResult{ UploadID: newMultipartUploadID(multipartUploadID{ Type: UploadTypePresigned, @@ -124,38 +130,39 @@ func (c *Controller) InitiateUpload(ctx context.Context, hash string, size int64 }, }, }, nil - } else { - // 分片上传 - upload, err := c.impl.InitiateMultipartUpload(ctx, c.HashPath(hash)) + } + + // 分片上传 + upload, err := c.impl.InitiateMultipartUpload(ctx, c.HashPath(hash)) + if err != nil { + return nil, err + } + if maxParts < 0 { + maxParts = partNumber + } + var authSign *s3.AuthSignResult + if maxParts > 0 { + partNumbers := make([]int, partNumber) + for i := 0; i < maxParts; i++ { + partNumbers[i] = i + 1 + } + authSign, err = c.impl.AuthSign(ctx, upload.UploadID, upload.Key, time.Hour*24, partNumbers) if err != nil { return nil, err } - if maxParts < 0 { - maxParts = partNumber - } - var authSign *s3.AuthSignResult - if maxParts > 0 { - partNumbers := make([]int, partNumber) - for i := 0; i < maxParts; i++ { - partNumbers[i] = i + 1 - } - authSign, err = c.impl.AuthSign(ctx, upload.UploadID, upload.Key, time.Hour*24, partNumbers) - if err != nil { - return nil, err - } - } - return &InitiateUploadResult{ - UploadID: newMultipartUploadID(multipartUploadID{ - Type: UploadTypeMultipart, - ID: upload.UploadID, - Key: upload.Key, - Size: size, - Hash: hash, - }), - PartSize: partSize, - Sign: authSign, - }, nil } + + return &InitiateUploadResult{ + UploadID: newMultipartUploadID(multipartUploadID{ + Type: UploadTypeMultipart, + ID: upload.UploadID, + Key: upload.Key, + Size: size, + Hash: hash, + }), + PartSize: partSize, + Sign: authSign, + }, nil } func (c *Controller) CompleteUpload(ctx context.Context, uploadID string, partHashs []string) (*UploadResult, error) { @@ -164,8 +171,10 @@ func (c *Controller) CompleteUpload(ctx context.Context, uploadID string, partHa if err != nil { return nil, err } + //nolint:gosec //tofix G401: Use of weak cryptographic primitive if md5Sum := md5.Sum([]byte(strings.Join(partHashs, partSeparator))); hex.EncodeToString(md5Sum[:]) != upload.Hash { fmt.Println("CompleteUpload sum:", hex.EncodeToString(md5Sum[:]), "upload hash:", upload.Hash) + return nil, errors.New("md5 mismatching") } if info, err := c.impl.StatObject(ctx, c.HashPath(upload.Hash)); err == nil { @@ -193,7 +202,7 @@ func (c *Controller) CompleteUpload(ctx context.Context, uploadID string, partHa ETag: part, } } - // todo: 验证大小 + // todo: verify size result, err := c.impl.CompleteMultipartUpload(ctx, upload.ID, upload.Key, parts) if err != nil { return nil, err @@ -208,11 +217,12 @@ func (c *Controller) CompleteUpload(ctx context.Context, uploadID string, partHa if uploadInfo.Size != upload.Size { return nil, errors.New("upload size mismatching") } + //nolint:gosec //G401: Use of weak cryptographic primitive md5Sum := md5.Sum([]byte(strings.Join([]string{uploadInfo.ETag}, partSeparator))) if md5val := hex.EncodeToString(md5Sum[:]); md5val != upload.Hash { return nil, errs.ErrArgs.Wrap(fmt.Sprintf("md5 mismatching %s != %s", md5val, upload.Hash)) } - // 防止在这个时候,并发操作,导致文件被覆盖 + // Prevent concurrent operations at this time to avoid file overwrite copyInfo, err := c.impl.CopyObject(ctx, uploadInfo.Key, upload.Key+"."+c.UUID()) if err != nil { return nil, err @@ -230,6 +240,7 @@ func (c *Controller) CompleteUpload(ctx context.Context, uploadID string, partHa default: return nil, errors.New("invalid upload id type") } + return &UploadResult{ Key: targetKey, Size: upload.Size, @@ -261,5 +272,6 @@ func (c *Controller) AccessURL(ctx context.Context, name string, expire time.Dur opt.Filename = "" opt.ContentType = "" } + return c.impl.AccessURL(ctx, name, expire, opt) } diff --git a/pkg/common/db/s3/cont/id.go b/pkg/common/db/s3/cont/id.go index 47f37d4aa..a2b723b83 100644 --- a/pkg/common/db/s3/cont/id.go +++ b/pkg/common/db/s3/cont/id.go @@ -33,6 +33,7 @@ func newMultipartUploadID(id multipartUploadID) string { if err != nil { panic(err) } + return base64.StdEncoding.EncodeToString(data) } @@ -45,5 +46,6 @@ func parseMultipartUploadID(id string) (*multipartUploadID, error) { if err := json.Unmarshal(data, &upload); err != nil { return nil, fmt.Errorf("invalid multipart upload id: %w", err) } + return &upload, nil } diff --git a/pkg/common/db/s3/cos/cos.go b/pkg/common/db/s3/cos/cos.go index 7add88487..5484778a5 100644 --- a/pkg/common/db/s3/cos/cos.go +++ b/pkg/common/db/s3/cos/cos.go @@ -44,11 +44,6 @@ const ( imageWebp = "webp" ) -const ( - videoSnapshotImagePng = "png" - videoSnapshotImageJpg = "jpg" -) - func NewCos() (s3.Interface, error) { conf := config.Config.Object.Cos u, err := url.Parse(conf.BucketURL) @@ -62,6 +57,7 @@ func NewCos() (s3.Interface, error) { SessionToken: conf.SessionToken, }, }) + return &Cos{ copyURL: u.Host + "/", client: client, @@ -92,6 +88,7 @@ func (c *Cos) InitiateMultipartUpload(ctx context.Context, name string) (*s3.Ini if err != nil { return nil, err } + return &s3.InitiateMultipartUploadResult{ UploadID: result.UploadID, Bucket: result.Bucket, @@ -113,6 +110,7 @@ func (c *Cos) CompleteMultipartUpload(ctx context.Context, uploadID string, name if err != nil { return nil, err } + return &s3.CompleteMultipartUploadResult{ Location: result.Location, Bucket: result.Bucket, @@ -135,6 +133,7 @@ func (c *Cos) PartSize(ctx context.Context, size int64) (int64, error) { if size%maxNumSize != 0 { partSize++ } + return partSize, nil } @@ -157,6 +156,7 @@ func (c *Cos) AuthSign(ctx context.Context, uploadID string, name string, expire Query: url.Values{"partNumber": {strconv.Itoa(partNumber)}}, } } + return &result, nil } @@ -165,11 +165,13 @@ func (c *Cos) PresignedPutObject(ctx context.Context, name string, expire time.D if err != nil { return "", err } + return rawURL.String(), nil } func (c *Cos) DeleteObject(ctx context.Context, name string) error { _, err := c.client.Object.Delete(ctx, name) + return err } @@ -185,25 +187,26 @@ func (c *Cos) StatObject(ctx context.Context, name string) (*s3.ObjectInfo, erro if res.ETag = strings.ToLower(strings.ReplaceAll(info.Header.Get("ETag"), `"`, "")); res.ETag == "" { return nil, errors.New("StatObject etag not found") } - if contentLengthStr := info.Header.Get("Content-Length"); contentLengthStr == "" { + contentLengthStr := info.Header.Get("Content-Length") + if contentLengthStr == "" { return nil, errors.New("StatObject content-length not found") - } else { - res.Size, err = strconv.ParseInt(contentLengthStr, 10, 64) - if err != nil { - return nil, fmt.Errorf("StatObject content-length parse error: %w", err) - } - if res.Size < 0 { - return nil, errors.New("StatObject content-length must be greater than 0") - } } - if lastModified := info.Header.Get("Last-Modified"); lastModified == "" { + res.Size, err = strconv.ParseInt(contentLengthStr, 10, 64) + if err != nil { + return nil, fmt.Errorf("StatObject content-length parse error: %w", err) + } + if res.Size < 0 { + return nil, errors.New("StatObject content-length must be greater than 0") + } + lastModified := info.Header.Get("Last-Modified") + if lastModified == "" { return nil, errors.New("StatObject last-modified not found") - } else { - res.LastModified, err = time.Parse(http.TimeFormat, lastModified) - if err != nil { - return nil, fmt.Errorf("StatObject last-modified parse error: %w", err) - } } + res.LastModified, err = time.Parse(http.TimeFormat, lastModified) + if err != nil { + return nil, fmt.Errorf("StatObject last-modified parse error: %w", err) + } + return res, nil } @@ -213,6 +216,7 @@ func (c *Cos) CopyObject(ctx context.Context, src string, dst string) (*s3.CopyO if err != nil { return nil, err } + return &s3.CopyObjectInfo{ Key: dst, ETag: strings.ReplaceAll(result.ETag, `"`, ``), @@ -220,16 +224,17 @@ func (c *Cos) CopyObject(ctx context.Context, src string, dst string) (*s3.CopyO } func (c *Cos) IsNotFound(err error) bool { - switch e := err.(type) { - case *cos.ErrorResponse: - return e.Response.StatusCode == http.StatusNotFound || e.Code == "NoSuchKey" - default: - return false + var cosErr *cos.ErrorResponse + if errors.As(err, &cosErr) { + return cosErr.Response.StatusCode == http.StatusNotFound || cosErr.Code == "NoSuchKey" } + + return false } func (c *Cos) AbortMultipartUpload(ctx context.Context, uploadID string, name string) error { _, err := c.client.Object.AbortMultipartUpload(ctx, name, uploadID) + return err } @@ -257,46 +262,59 @@ func (c *Cos) ListUploadedParts(ctx context.Context, uploadID string, name strin Size: part.Size, } } + return res, nil } func (c *Cos) AccessURL(ctx context.Context, name string, expire time.Duration, opt *s3.AccessURLOption) (string, error) { var imageMogr string var option cos.PresignedURLOptions - if opt != nil { - query := make(url.Values) - if opt.Image != nil { - // https://cloud.tencent.com/document/product/436/44880 - style := make([]string, 0, 2) - wh := make([]string, 2) - if opt.Image.Width > 0 { - wh[0] = strconv.Itoa(opt.Image.Width) - } - if opt.Image.Height > 0 { - wh[1] = strconv.Itoa(opt.Image.Height) - } - if opt.Image.Width > 0 || opt.Image.Height > 0 { - style = append(style, strings.Join(wh, "x")) - } - switch opt.Image.Format { - case - imagePng, - imageJpg, - imageJpeg, - imageGif, - imageWebp: - style = append(style, "format/"+opt.Image.Format) - } - if len(style) > 0 { - imageMogr = "imageMogr2/thumbnail/" + strings.Join(style, "/") + "/ignore-error/1" - } + getImageMogr := func(opt *s3.AccessURLOption) (imageMogr string) { + if opt.Image == nil { + return imageMogr } + // https://cloud.tencent.com/document/product/436/44880 + style := make([]string, 0, 2) + wh := make([]string, 2) + if opt.Image.Width > 0 { + wh[0] = strconv.Itoa(opt.Image.Width) + } + if opt.Image.Height > 0 { + wh[1] = strconv.Itoa(opt.Image.Height) + } + if opt.Image.Width > 0 || opt.Image.Height > 0 { + style = append(style, strings.Join(wh, "x")) + } + switch opt.Image.Format { + case + imagePng, + imageJpg, + imageJpeg, + imageGif, + imageWebp: + style = append(style, "format/"+opt.Image.Format) + } + if len(style) > 0 { + imageMogr = "imageMogr2/thumbnail/" + strings.Join(style, "/") + "/ignore-error/1" + } + + return imageMogr + } + getQuery := func(opt *s3.AccessURLOption) (query url.Values) { + query = make(url.Values) if opt.ContentType != "" { query.Set("response-content-type", opt.ContentType) } if opt.Filename != "" { query.Set("response-content-disposition", `attachment; filename=`+strconv.Quote(opt.Filename)) } + + return query + } + + if opt != nil { + imageMogr = getImageMogr(opt) + query := getQuery(opt) if len(query) > 0 { option.Query = &query } @@ -317,6 +335,7 @@ func (c *Cos) AccessURL(ctx context.Context, name string, expire time.Duration, rawURL.RawQuery = rawURL.RawQuery + "&" + imageMogr } } + return rawURL.String(), nil } @@ -324,5 +343,6 @@ func (c *Cos) getPresignedURL(ctx context.Context, name string, expire time.Dura if !config.Config.Object.Cos.PublicRead { return c.client.Object.GetPresignedURL(ctx, http.MethodGet, name, c.credential.SecretID, c.credential.SecretKey, expire, opt) } + return c.client.Object.GetObjectURL(name), nil } diff --git a/pkg/common/db/s3/minio/image.go b/pkg/common/db/s3/minio/image.go index 71db1ea51..4812f47ca 100644 --- a/pkg/common/db/s3/minio/image.go +++ b/pkg/common/db/s3/minio/image.go @@ -39,6 +39,7 @@ func ImageStat(reader io.Reader) (image.Image, string, error) { func ImageWidthHeight(img image.Image) (int, int) { bounds := img.Bounds().Max + return bounds.X, bounds.Y } @@ -47,27 +48,27 @@ func resizeImage(img image.Image, maxWidth, maxHeight int) image.Image { imgWidth := bounds.Max.X imgHeight := bounds.Max.Y - // 计算缩放比例 + // Calculate scaling ratio scaleWidth := float64(maxWidth) / float64(imgWidth) scaleHeight := float64(maxHeight) / float64(imgHeight) - // 如果都为0,则不缩放,返回原始图片 + // If both maxWidth and maxHeight are 0, return the original image if maxWidth == 0 && maxHeight == 0 { return img } - // 如果宽度和高度都大于0,则选择较小的缩放比例,以保持宽高比 + // If both maxWidth and maxHeight are greater than 0, choose the smaller scaling ratio to maintain aspect ratio if maxWidth > 0 && maxHeight > 0 { scale := scaleWidth if scaleHeight < scaleWidth { scale = scaleHeight } - // 计算缩略图尺寸 + // Calculate thumbnail size thumbnailWidth := int(float64(imgWidth) * scale) thumbnailHeight := int(float64(imgHeight) * scale) - // 使用"image"库的Resample方法生成缩略图 + // Generate thumbnail using the Resample method of the "image" library thumbnail := image.NewRGBA(image.Rect(0, 0, thumbnailWidth, thumbnailHeight)) for y := 0; y < thumbnailHeight; y++ { for x := 0; x < thumbnailWidth; x++ { @@ -80,12 +81,12 @@ func resizeImage(img image.Image, maxWidth, maxHeight int) image.Image { return thumbnail } - // 如果只指定了宽度或高度,则根据最大不超过的规则生成缩略图 + // If only maxWidth or maxHeight is specified, generate thumbnail according to the "max not exceed" rule if maxWidth > 0 { thumbnailWidth := maxWidth thumbnailHeight := int(float64(imgHeight) * scaleWidth) - // 使用"image"库的Resample方法生成缩略图 + // Generate thumbnail using the Resample method of the "image" library thumbnail := image.NewRGBA(image.Rect(0, 0, thumbnailWidth, thumbnailHeight)) for y := 0; y < thumbnailHeight; y++ { for x := 0; x < thumbnailWidth; x++ { @@ -102,7 +103,7 @@ func resizeImage(img image.Image, maxWidth, maxHeight int) image.Image { thumbnailWidth := int(float64(imgWidth) * scaleHeight) thumbnailHeight := maxHeight - // 使用"image"库的Resample方法生成缩略图 + // Generate thumbnail using the Resample method of the "image" library thumbnail := image.NewRGBA(image.Rect(0, 0, thumbnailWidth, thumbnailHeight)) for y := 0; y < thumbnailHeight; y++ { for x := 0; x < thumbnailWidth; x++ { @@ -115,6 +116,6 @@ func resizeImage(img image.Image, maxWidth, maxHeight int) image.Image { return thumbnail } - // 默认情况下,返回原始图片 + // By default, return the original image return img } diff --git a/pkg/common/db/s3/minio/minio.go b/pkg/common/db/s3/minio/minio.go index 7984df5a0..a84b8c3f7 100644 --- a/pkg/common/db/s3/minio/minio.go +++ b/pkg/common/db/s3/minio/minio.go @@ -111,6 +111,7 @@ func NewMinio() (s3.Interface, error) { if err := m.initMinio(ctx); err != nil { fmt.Println("init minio error:", err) } + return m, nil } @@ -141,8 +142,9 @@ func (m *Minio) initMinio(ctx context.Context) error { return fmt.Errorf("check bucket exists error: %w", err) } if !exists { - if err := m.core.Client.MakeBucket(ctx, conf.Bucket, minio.MakeBucketOptions{}); err != nil { - return fmt.Errorf("make bucket error: %w", err) + err2 := m.core.Client.MakeBucket(ctx, conf.Bucket, minio.MakeBucketOptions{}) + if err2 != nil { + return fmt.Errorf("make bucket error: %w", err2) } } if conf.PublicRead { @@ -150,8 +152,9 @@ func (m *Minio) initMinio(ctx context.Context) error { `{"Version": "2012-10-17","Statement": [{"Action": ["s3:GetObject","s3:PutObject"],"Effect": "Allow","Principal": {"AWS": ["*"]},"Resource": ["arn:aws:s3:::%s/*"],"Sid": ""}]}`, conf.Bucket, ) - if err := m.core.Client.SetBucketPolicy(ctx, conf.Bucket, policy); err != nil { - return err + err2 := m.core.Client.SetBucketPolicy(ctx, conf.Bucket, policy) + if err2 != nil { + return err2 } } m.location, err = m.core.Client.GetBucketLocation(ctx, conf.Bucket) @@ -182,6 +185,7 @@ func (m *Minio) initMinio(ctx context.Context) error { vblc.Elem().Elem().Interface().(interface{ Set(string, string) }).Set(conf.Bucket, m.location) }() m.init = true + return nil } @@ -205,6 +209,7 @@ func (m *Minio) InitiateMultipartUpload(ctx context.Context, name string) (*s3.I if err != nil { return nil, err } + return &s3.InitiateMultipartUploadResult{ Bucket: m.bucket, Key: name, @@ -227,6 +232,7 @@ func (m *Minio) CompleteMultipartUpload(ctx context.Context, uploadID string, na if err != nil { return nil, err } + return &s3.CompleteMultipartUploadResult{ Location: upload.Location, Bucket: upload.Bucket, @@ -249,6 +255,7 @@ func (m *Minio) PartSize(ctx context.Context, size int64) (int64, error) { if size%maxNumSize != 0 { partSize++ } + return partSize, nil } @@ -282,6 +289,7 @@ func (m *Minio) AuthSign(ctx context.Context, uploadID string, name string, expi if m.prefix != "" { result.URL = m.signEndpoint + m.prefix + "/" + m.bucket + "/" + name } + return &result, nil } @@ -296,6 +304,7 @@ func (m *Minio) PresignedPutObject(ctx context.Context, name string, expire time if m.prefix != "" { rawURL.Path = path.Join(m.prefix, rawURL.Path) } + return rawURL.String(), nil } @@ -303,6 +312,7 @@ func (m *Minio) DeleteObject(ctx context.Context, name string) error { if err := m.initMinio(ctx); err != nil { return err } + return m.core.Client.RemoveObject(ctx, m.bucket, name, minio.RemoveObjectOptions{}) } @@ -314,6 +324,7 @@ func (m *Minio) StatObject(ctx context.Context, name string) (*s3.ObjectInfo, er if err != nil { return nil, err } + return &s3.ObjectInfo{ ETag: strings.ToLower(info.ETag), Key: info.Key, @@ -336,6 +347,7 @@ func (m *Minio) CopyObject(ctx context.Context, src string, dst string) (*s3.Cop if err != nil { return nil, err } + return &s3.CopyObjectInfo{ Key: dst, ETag: strings.ToLower(result.ETag), @@ -346,20 +358,23 @@ func (m *Minio) IsNotFound(err error) bool { if err == nil { return false } - switch e := err.(type) { - case minio.ErrorResponse: - return e.StatusCode == http.StatusNotFound || e.Code == "NoSuchKey" - case *minio.ErrorResponse: - return e.StatusCode == http.StatusNotFound || e.Code == "NoSuchKey" - default: - return false + var minioErr minio.ErrorResponse + if errors.As(err, &minio.ErrorResponse{}) { + return minioErr.StatusCode == http.StatusNotFound || minioErr.Code == "NoSuchKey" } + var minioErr2 *minio.ErrorResponse + if errors.As(err, &minioErr2) { + return minioErr2.StatusCode == http.StatusNotFound || minioErr2.Code == "NoSuchKey" + } + + return false } func (m *Minio) AbortMultipartUpload(ctx context.Context, uploadID string, name string) error { if err := m.initMinio(ctx); err != nil { return err } + return m.core.AbortMultipartUpload(ctx, m.bucket, name, uploadID) } @@ -386,6 +401,7 @@ func (m *Minio) ListUploadedParts(ctx context.Context, uploadID string, name str Size: part.Size, } } + return res, nil } @@ -410,14 +426,11 @@ func (m *Minio) presignedGetObject(ctx context.Context, name string, expire time if m.prefix != "" { rawURL.Path = path.Join(m.prefix, rawURL.Path) } + return rawURL.String(), nil } -func (m *Minio) AccessURL(ctx context.Context, name string, expire time.Duration, opt *s3.AccessURLOption) (string, error) { - if err := m.initMinio(ctx); err != nil { - return "", err - } - reqParams := make(url.Values) +func (m *Minio) getImageInfoForAccessURL(ctx context.Context, name string, expire time.Duration, opt *s3.AccessURLOption, reqParams url.Values) (fileInfo *s3.ObjectInfo, objectInfoPath, msg string, err error) { if opt != nil { if opt.ContentType != "" { reqParams.Set("response-content-type", opt.ContentType) @@ -427,35 +440,47 @@ func (m *Minio) AccessURL(ctx context.Context, name string, expire time.Duration } } if opt.Image == nil || (opt.Image.Width < 0 && opt.Image.Height < 0 && opt.Image.Format == "") || (opt.Image.Width > maxImageWidth || opt.Image.Height > maxImageHeight) { - return m.presignedGetObject(ctx, name, expire, reqParams) + msg, err = m.presignedGetObject(ctx, name, expire, reqParams) + + return nil, "", msg, err } - fileInfo, err := m.StatObject(ctx, name) + fileInfo, err = m.StatObject(ctx, name) + objectInfoPath = path.Join(pathInfo, fileInfo.ETag, "image.json") if err != nil { - return "", err + return nil, "", msg, err } if fileInfo.Size > maxImageSize { - return "", errors.New("file size too large") + return nil, "", "", errors.New("file size too large") } - objectInfoPath := path.Join(pathInfo, fileInfo.ETag, "image.json") - var ( - img image.Image - info minioImageInfo - ) - data, err := m.getObjectData(ctx, objectInfoPath, 1024) + + return fileInfo, objectInfoPath, "", nil +} + +func (m *Minio) loadImgDataForAccessURL(objectInfoPath string, ctx context.Context, name string, info *minioImageInfo) (img image.Image, msg string, err error) { + var data []byte + data, err = m.getObjectData(ctx, objectInfoPath, 1024) + + //nolint:nestif //easy enough to understand if err == nil { - if err := json.Unmarshal(data, &info); err != nil { - return "", fmt.Errorf("unmarshal minio image info.json error: %w", err) + err = json.Unmarshal(data, &info) + if err != nil { + return nil, "", fmt.Errorf("unmarshal minio image info.json error: %w", err) } if info.NotImage { - return "", errors.New("not image") + return nil, "", errors.New("not image") } } else if m.IsNotFound(err) { - reader, err := m.core.Client.GetObject(ctx, m.bucket, name, minio.GetObjectOptions{}) + var reader *minio.Object + reader, err = m.core.Client.GetObject(ctx, m.bucket, name, minio.GetObjectOptions{}) if err != nil { - return "", err + return img, msg, err } defer reader.Close() - imageInfo, format, err := ImageStat(reader) + var ( + imageInfo image.Image + format string + ) + imageInfo, format, err = ImageStat(reader) if err == nil { info.NotImage = false info.Format = format @@ -464,16 +489,22 @@ func (m *Minio) AccessURL(ctx context.Context, name string, expire time.Duration } else { info.NotImage = true } - data, err := json.Marshal(&info) + + data, err = json.Marshal(&info) if err != nil { - return "", err + return img, msg, err } - if _, err := m.core.Client.PutObject(ctx, m.bucket, objectInfoPath, bytes.NewReader(data), int64(len(data)), minio.PutObjectOptions{}); err != nil { - return "", err + + _, err = m.core.Client.PutObject(ctx, m.bucket, objectInfoPath, bytes.NewReader(data), int64(len(data)), minio.PutObjectOptions{}) + if err != nil { + return img, msg, err } - } else { - return "", err } + + return img, msg, err +} + +func (m *Minio) formatImgInfoForAccessURL(opt *s3.AccessURLOption, info *minioImageInfo, reqParams url.Values) { if opt.Image.Width > info.Width || opt.Image.Width <= 0 { opt.Image.Width = info.Width } @@ -496,24 +527,24 @@ func (m *Minio) AccessURL(ctx context.Context, name string, expire time.Duration } } reqParams.Set("response-content-type", "image/"+opt.Image.Format) - if opt.Image.Width == info.Width && opt.Image.Height == info.Height && opt.Image.Format == info.Format { - return m.presignedGetObject(ctx, name, expire, reqParams) - } - cacheKey := filepath.Join(pathInfo, fileInfo.ETag, fmt.Sprintf("image_w%d_h%d.%s", opt.Image.Width, opt.Image.Height, opt.Image.Format)) - if _, err := m.core.Client.StatObject(ctx, m.bucket, cacheKey, minio.StatObjectOptions{}); err == nil { +} + +func (m *Minio) cacheImgInfoForAccessURL(ctx context.Context, name, cacheKey string, img image.Image, expire time.Duration, opt *s3.AccessURLOption, reqParams url.Values) (string, error) { + _, err := m.core.Client.StatObject(ctx, m.bucket, cacheKey, minio.StatObjectOptions{}) + if err == nil { return m.presignedGetObject(ctx, cacheKey, expire, reqParams) } else if !m.IsNotFound(err) { return "", err } if img == nil { - reader, err := m.core.Client.GetObject(ctx, m.bucket, name, minio.GetObjectOptions{}) - if err != nil { - return "", err + reader, err2 := m.core.Client.GetObject(ctx, m.bucket, name, minio.GetObjectOptions{}) + if err2 != nil { + return "", err2 } defer reader.Close() - img, _, err = ImageStat(reader) - if err != nil { - return "", err + img, _, err2 = ImageStat(reader) + if err2 != nil { + return "", err2 } } thumbnail := resizeImage(img, opt.Image.Width, opt.Image.Height) @@ -526,9 +557,48 @@ func (m *Minio) AccessURL(ctx context.Context, name string, expire time.Duration case formatGif: err = gif.Encode(buf, thumbnail, nil) } + if err != nil { + return "", err + } if _, err := m.core.Client.PutObject(ctx, m.bucket, cacheKey, buf, int64(buf.Len()), minio.PutObjectOptions{}); err != nil { return "", err } + + return "", nil +} + +func (m *Minio) AccessURL(ctx context.Context, name string, expire time.Duration, opt *s3.AccessURLOption) (string, error) { + errInit := m.initMinio(ctx) + if errInit != nil { + return "", errInit + } + reqParams := make(url.Values) + fileInfo, objectInfoPath, msg, err := m.getImageInfoForAccessURL(ctx, name, expire, opt, reqParams) + if err != nil { + return msg, err + } + // load-cache img data + var ( + img image.Image + info minioImageInfo + ) + img, msg, err = m.loadImgDataForAccessURL(objectInfoPath, ctx, name, &info) + if err != nil { + return msg, err + } + // format img info + m.formatImgInfoForAccessURL(opt, &info, reqParams) + // no need resize + if opt.Image.Width == info.Width && opt.Image.Height == info.Height && opt.Image.Format == info.Format { + return m.presignedGetObject(ctx, name, expire, reqParams) + } + // cache img + cacheKey := filepath.Join(pathInfo, fileInfo.ETag, fmt.Sprintf("image_w%d_h%d.%s", opt.Image.Width, opt.Image.Height, opt.Image.Format)) + msg, err = m.cacheImgInfoForAccessURL(ctx, name, cacheKey, img, expire, opt, reqParams) + if err != nil { + return msg, err + } + // return cache img return m.presignedGetObject(ctx, cacheKey, expire, reqParams) } @@ -541,5 +611,6 @@ func (m *Minio) getObjectData(ctx context.Context, name string, limit int64) ([] if limit < 0 { return io.ReadAll(object) } + return io.ReadAll(io.LimitReader(object, 1024)) } diff --git a/pkg/common/db/s3/oss/oss.go b/pkg/common/db/s3/oss/oss.go old mode 100644 new mode 100755 index 6a728127b..4f7f37497 --- a/pkg/common/db/s3/oss/oss.go +++ b/pkg/common/db/s3/oss/oss.go @@ -45,11 +45,6 @@ const ( imageWebp = "webp" ) -const ( - videoSnapshotImagePng = "png" - videoSnapshotImageJpg = "jpg" -) - func NewOSS() (s3.Interface, error) { conf := config.Config.Object.Oss if conf.BucketURL == "" { @@ -66,6 +61,7 @@ func NewOSS() (s3.Interface, error) { if conf.BucketURL[len(conf.BucketURL)-1] != '/' { conf.BucketURL += "/" } + return &OSS{ bucketURL: conf.BucketURL, bucket: bucket, @@ -98,6 +94,7 @@ func (o *OSS) InitiateMultipartUpload(ctx context.Context, name string) (*s3.Ini if err != nil { return nil, err } + return &s3.InitiateMultipartUploadResult{ UploadID: result.UploadID, Bucket: result.Bucket, @@ -121,6 +118,7 @@ func (o *OSS) CompleteMultipartUpload(ctx context.Context, uploadID string, name if err != nil { return nil, err } + return &s3.CompleteMultipartUploadResult{ Location: result.Location, Bucket: result.Bucket, @@ -143,6 +141,7 @@ func (o *OSS) PartSize(ctx context.Context, size int64) (int64, error) { if size%maxNumSize != 0 { partSize++ } + return partSize, nil } @@ -155,7 +154,7 @@ func (o *OSS) AuthSign(ctx context.Context, uploadID string, name string, expire } for i, partNumber := range partNumbers { rawURL := fmt.Sprintf(`%s%s?partNumber=%d&uploadId=%s`, o.bucketURL, name, partNumber, uploadID) - request, err := http.NewRequest(http.MethodPut, rawURL, nil) + request, err := http.NewRequestWithContext(context.Background(), http.MethodPut, rawURL, nil) if err != nil { return nil, err } @@ -175,6 +174,7 @@ func (o *OSS) AuthSign(ctx context.Context, uploadID string, name string, expire Header: request.Header, } } + return &result, nil } @@ -191,25 +191,26 @@ func (o *OSS) StatObject(ctx context.Context, name string) (*s3.ObjectInfo, erro if res.ETag = strings.ToLower(strings.ReplaceAll(header.Get("ETag"), `"`, ``)); res.ETag == "" { return nil, errors.New("StatObject etag not found") } - if contentLengthStr := header.Get("Content-Length"); contentLengthStr == "" { + contentLengthStr := header.Get("Content-Length") + if contentLengthStr == "" { return nil, errors.New("StatObject content-length not found") - } else { - res.Size, err = strconv.ParseInt(contentLengthStr, 10, 64) - if err != nil { - return nil, fmt.Errorf("StatObject content-length parse error: %w", err) - } - if res.Size < 0 { - return nil, errors.New("StatObject content-length must be greater than 0") - } } - if lastModified := header.Get("Last-Modified"); lastModified == "" { + res.Size, err = strconv.ParseInt(contentLengthStr, 10, 64) + if err != nil { + return nil, fmt.Errorf("StatObject content-length parse error: %w", err) + } + if res.Size < 0 { + return nil, errors.New("StatObject content-length must be greater than 0") + } + lastModified := header.Get("Last-Modified") + if lastModified == "" { return nil, errors.New("StatObject last-modified not found") - } else { - res.LastModified, err = time.Parse(http.TimeFormat, lastModified) - if err != nil { - return nil, fmt.Errorf("StatObject last-modified parse error: %w", err) - } } + res.LastModified, err = time.Parse(http.TimeFormat, lastModified) + if err != nil { + return nil, fmt.Errorf("StatObject last-modified parse error: %w", err) + } + return res, nil } @@ -222,6 +223,7 @@ func (o *OSS) CopyObject(ctx context.Context, src string, dst string) (*s3.CopyO if err != nil { return nil, err } + return &s3.CopyObjectInfo{ Key: dst, ETag: strings.ToLower(strings.ReplaceAll(result.ETag, `"`, ``)), @@ -229,6 +231,7 @@ func (o *OSS) CopyObject(ctx context.Context, src string, dst string) (*s3.CopyO } func (o *OSS) IsNotFound(err error) bool { + //nolint:errorlint //this is exactly what we want,there is no risk for no wrapped errors switch e := err.(type) { case oss.ServiceError: return e.StatusCode == http.StatusNotFound || e.Code == "NoSuchKey" @@ -271,6 +274,7 @@ func (o *OSS) ListUploadedParts(ctx context.Context, uploadID string, name strin Size: int64(part.Size), } } + return res, nil } @@ -278,39 +282,7 @@ func (o *OSS) AccessURL(ctx context.Context, name string, expire time.Duration, publicRead := config.Config.Object.Oss.PublicRead var opts []oss.Option if opt != nil { - if opt.Image != nil { - // 文档地址: https://help.aliyun.com/zh/oss/user-guide/resize-images-4?spm=a2c4g.11186623.0.0.4b3b1e4fWW6yji - var format string - switch opt.Image.Format { - case - imagePng, - imageJpg, - imageJpeg, - imageGif, - imageWebp: - format = opt.Image.Format - default: - opt.Image.Format = imageJpg - } - // https://oss-console-img-demo-cn-hangzhou.oss-cn-hangzhou.aliyuncs.com/example.jpg?x-oss-process=image/resize,h_100,m_lfit - process := "image/resize,m_lfit" - if opt.Image.Width > 0 { - process += ",w_" + strconv.Itoa(opt.Image.Width) - } - if opt.Image.Height > 0 { - process += ",h_" + strconv.Itoa(opt.Image.Height) - } - process += ",format," + format - opts = append(opts, oss.Process(process)) - } - if !publicRead { - if opt.ContentType != "" { - opts = append(opts, oss.ResponseContentType(opt.ContentType)) - } - if opt.Filename != "" { - opts = append(opts, oss.ResponseContentDisposition(`attachment; filename=`+strconv.Quote(opt.Filename))) - } - } + opts = optsForAccessURL(opt, opts, publicRead) } if expire <= 0 { expire = time.Hour * 24 * 365 * 99 // 99 years @@ -325,5 +297,44 @@ func (o *OSS) AccessURL(ctx context.Context, name string, expire time.Duration, return "", err } params := getURLParams(*o.bucket.Client.Conn, rawParams) + return getURL(o.um, o.bucket.BucketName, name, params).String(), nil } + +func optsForAccessURL(opt *s3.AccessURLOption, opts []oss.Option, publicRead bool) []oss.Option { + if opt.Image != nil { + // 文档地址: https://help.aliyun.com/zh/oss/user-guide/resize-images-4?spm=a2c4g.11186623.0.0.4b3b1e4fWW6yji + var format string + switch opt.Image.Format { + case + imagePng, + imageJpg, + imageJpeg, + imageGif, + imageWebp: + format = opt.Image.Format + default: + opt.Image.Format = imageJpg + } + // https://oss-console-img-demo-cn-hangzhou.oss-cn-hangzhou.aliyuncs.com/example.jpg?x-oss-process=image/resize,h_100,m_lfit + process := "image/resize,m_lfit" + if opt.Image.Width > 0 { + process += ",w_" + strconv.Itoa(opt.Image.Width) + } + if opt.Image.Height > 0 { + process += ",h_" + strconv.Itoa(opt.Image.Height) + } + process += ",format," + format + opts = append(opts, oss.Process(process)) + } + if !publicRead { + if opt.ContentType != "" { + opts = append(opts, oss.ResponseContentType(opt.ContentType)) + } + if opt.Filename != "" { + opts = append(opts, oss.ResponseContentDisposition(`attachment; filename=`+strconv.Quote(opt.Filename))) + } + } + + return opts +} diff --git a/pkg/common/db/table/relation/group.go b/pkg/common/db/table/relation/group.go index 6759e0d35..24a75173d 100644 --- a/pkg/common/db/table/relation/group.go +++ b/pkg/common/db/table/relation/group.go @@ -30,7 +30,7 @@ type GroupModel struct { Introduction string `gorm:"column:introduction;size:255" json:"introduction"` FaceURL string `gorm:"column:face_url;size:255" json:"faceURL"` CreateTime time.Time `gorm:"column:create_time;index:create_time;autoCreateTime"` - Ex string `gorm:"column:ex" json:"ex;size:1024"` + Ex string `gorm:"column:ex;size:1024" json:"ex"` Status int32 `gorm:"column:status"` CreatorUserID string `gorm:"column:creator_user_id;size:64"` GroupType int32 `gorm:"column:group_type"` diff --git a/pkg/common/db/table/relation/utils.go b/pkg/common/db/table/relation/utils.go index c944eae8b..bc2639e1a 100644 --- a/pkg/common/db/table/relation/utils.go +++ b/pkg/common/db/table/relation/utils.go @@ -15,6 +15,8 @@ package relation import ( + "errors" + "gorm.io/gorm" "github.com/OpenIMSDK/tools/utils" @@ -32,5 +34,5 @@ type GroupSimpleUserID struct { } func IsNotFound(err error) bool { - return utils.Unwrap(err) == gorm.ErrRecordNotFound + return errors.Is(utils.Unwrap(err), gorm.ErrRecordNotFound) } diff --git a/pkg/common/db/table/unrelation/msg.go b/pkg/common/db/table/unrelation/msg.go index c95b211a8..542f318ad 100644 --- a/pkg/common/db/table/unrelation/msg.go +++ b/pkg/common/db/table/unrelation/msg.go @@ -150,6 +150,7 @@ func (m *MsgDocModel) IsFull() bool { func (m MsgDocModel) GetDocID(conversationID string, seq int64) string { seqSuffix := (seq - 1) / singleGocMsgNum + return m.indexGen(conversationID, seqSuffix) } @@ -164,6 +165,7 @@ func (m MsgDocModel) GetDocIDSeqsMap(conversationID string, seqs []int64) map[st t[docID] = append(value, seqs[i]) } } + return t } @@ -181,5 +183,6 @@ func (MsgDocModel) GenExceptionMessageBySeqs(seqs []int64) (exceptionMsg []*sdkw msgModel.Seq = v exceptionMsg = append(exceptionMsg, msgModel) } + return exceptionMsg } diff --git a/pkg/common/db/unrelation/mongo.go b/pkg/common/db/unrelation/mongo.go old mode 100644 new mode 100755 index 09e3e904e..8a90a2a2c --- a/pkg/common/db/unrelation/mongo.go +++ b/pkg/common/db/unrelation/mongo.go @@ -16,6 +16,7 @@ package unrelation import ( "context" + "errors" "fmt" "strings" "time" @@ -44,27 +45,12 @@ type Mongo struct { // NewMongo Initialize MongoDB connection. func NewMongo() (*Mongo, error) { specialerror.AddReplace(mongo.ErrNoDocuments, errs.ErrRecordNotFound) - uri := "mongodb://sample.host:27017/?maxPoolSize=20&w=majority" + // uri := "mongodb://sample.host:27017/?maxPoolSize=20&w=majority" + var uri string if config.Config.Mongo.Uri != "" { uri = config.Config.Mongo.Uri } else { - mongodbHosts := "" - for i, v := range config.Config.Mongo.Address { - if i == len(config.Config.Mongo.Address)-1 { - mongodbHosts += v - } else { - mongodbHosts += v + "," - } - } - if config.Config.Mongo.Password != "" && config.Config.Mongo.Username != "" { - uri = fmt.Sprintf("mongodb://%s:%s@%s/%s?maxPoolSize=%d&authSource=admin", - config.Config.Mongo.Username, config.Config.Mongo.Password, mongodbHosts, - config.Config.Mongo.Database, config.Config.Mongo.MaxPoolSize) - } else { - uri = fmt.Sprintf("mongodb://%s/%s/?maxPoolSize=%d&authSource=admin", - mongodbHosts, config.Config.Mongo.Database, - config.Config.Mongo.MaxPoolSize) - } + uri = defaultMongoUriForNewMongo() } fmt.Println("mongo:", uri) var mongoClient *mongo.Client @@ -76,17 +62,41 @@ func NewMongo() (*Mongo, error) { if err == nil { return &Mongo{db: mongoClient}, nil } - if cmdErr, ok := err.(mongo.CommandError); ok { + var cmdErr mongo.CommandError + if errors.As(err, &cmdErr) { if cmdErr.Code == 13 || cmdErr.Code == 18 { return nil, err - } else { - fmt.Printf("Failed to connect to MongoDB: %s\n", err) } + fmt.Printf("Failed to connect to MongoDB: %s\n", err) } } + return nil, err } +func defaultMongoUriForNewMongo() string { + var uri string + mongodbHosts := "" + for i, v := range config.Config.Mongo.Address { + if i == len(config.Config.Mongo.Address)-1 { + mongodbHosts += v + } else { + mongodbHosts += v + "," + } + } + if config.Config.Mongo.Password != "" && config.Config.Mongo.Username != "" { + uri = fmt.Sprintf("mongodb://%s:%s@%s/%s?maxPoolSize=%d&authSource=admin", + config.Config.Mongo.Username, config.Config.Mongo.Password, mongodbHosts, + config.Config.Mongo.Database, config.Config.Mongo.MaxPoolSize) + } else { + uri = fmt.Sprintf("mongodb://%s/%s/?maxPoolSize=%d&authSource=admin", + mongodbHosts, config.Config.Mongo.Database, + config.Config.Mongo.MaxPoolSize) + } + + return uri +} + func (m *Mongo) GetClient() *mongo.Client { return m.db } @@ -106,6 +116,7 @@ func (m *Mongo) CreateSuperGroupIndex() error { if err := m.createMongoIndex(unrelation.CUserToSuperGroup, true, "user_id"); err != nil { return err } + return nil } @@ -139,5 +150,6 @@ func (m *Mongo) createMongoIndex(collection string, isUnique bool, keys ...strin if err != nil { return utils.Wrap(err, result) } + return nil } diff --git a/pkg/common/db/unrelation/msg.go b/pkg/common/db/unrelation/msg.go old mode 100644 new mode 100755 index 9b461dd1f..afa2f81e4 --- a/pkg/common/db/unrelation/msg.go +++ b/pkg/common/db/unrelation/msg.go @@ -49,6 +49,7 @@ type MsgMongoDriver struct { func NewMsgMongoDriver(database *mongo.Database) table.MsgDocModelInterface { collection := database.Collection(table.MsgDocModel{}.TableName()) + return &MsgMongoDriver{MsgCollection: collection} } @@ -59,6 +60,7 @@ func (m *MsgMongoDriver) PushMsgsToDoc(ctx context.Context, docID string, msgsTo func (m *MsgMongoDriver) Create(ctx context.Context, model *table.MsgDocModel) error { _, err := m.MsgCollection.InsertOne(ctx, model) + return err } @@ -81,6 +83,7 @@ func (m *MsgMongoDriver) UpdateMsg( if err != nil { return nil, utils.Wrap(err, "") } + return res, nil } @@ -108,6 +111,7 @@ func (m *MsgMongoDriver) PushUnique( if err != nil { return nil, utils.Wrap(err, "") } + return res, nil } @@ -120,6 +124,7 @@ func (m *MsgMongoDriver) UpdateMsgContent(ctx context.Context, docID string, ind if err != nil { return utils.Wrap(err, "") } + return nil } @@ -143,12 +148,14 @@ func (m *MsgMongoDriver) UpdateMsgStatusByIndexInOneDoc( if err != nil { return utils.Wrap(err, "") } + return nil } func (m *MsgMongoDriver) FindOneByDocID(ctx context.Context, docID string) (*table.MsgDocModel, error) { doc := &table.MsgDocModel{} err := m.MsgCollection.FindOne(ctx, bson.M{"doc_id": docID}).Decode(doc) + return doc, err } @@ -177,6 +184,7 @@ func (m *MsgMongoDriver) GetMsgDocModelByIndex( if len(msgs) > 0 { return &msgs[0], nil } + return nil, ErrMsgListNotExist } @@ -225,6 +233,7 @@ func (m *MsgMongoDriver) DeleteMsgsInOneDocByIndex(ctx context.Context, docID st if err != nil { return utils.Wrap(err, "") } + return nil } @@ -233,6 +242,7 @@ func (m *MsgMongoDriver) DeleteDocs(ctx context.Context, docIDs []string) error return nil } _, err := m.MsgCollection.DeleteMany(ctx, bson.M{"doc_id": bson.M{"$in": docIDs}}) + return err } @@ -246,6 +256,7 @@ func (m *MsgMongoDriver) GetMsgBySeqIndexIn1Doc( for _, seq := range seqs { indexs = append(indexs, m.model.GetMsgIndex(seq)) } + //nolint:govet //This is already the officially recommended standard practice. pipeline := mongo.Pipeline{ { {"$match", bson.D{ @@ -336,6 +347,7 @@ func (m *MsgMongoDriver) GetMsgBySeqIndexIn1Doc( } msgs = append(msgs, msg) } + return msgs, nil } @@ -344,6 +356,7 @@ func (m *MsgMongoDriver) IsExistDocID(ctx context.Context, docID string) (bool, if err != nil { return false, errs.Wrap(err) } + return count > 0, nil } @@ -372,6 +385,7 @@ func (m *MsgMongoDriver) MarkSingleChatMsgsAsRead( updates = append(updates, updateModel) } _, err := m.MsgCollection.BulkWrite(ctx, updates) + return err } @@ -611,7 +625,39 @@ func (m *MsgMongoDriver) RangeUserSendCount( }, ) } - pipeline := bson.A{ + pipeline := buildPiplineForRangeUserSendCount(or, start, end, sort, pageNumber, showNumber) + cur, err := m.MsgCollection.Aggregate(ctx, pipeline, options.Aggregate().SetAllowDiskUse(true)) + if err != nil { + return 0, 0, nil, nil, errs.Wrap(err) + } + defer cur.Close(ctx) + var result []Result + if err = cur.All(ctx, &result); err != nil { + return 0, 0, nil, nil, errs.Wrap(err) + } + if len(result) == 0 { + return 0, 0, nil, nil, errs.Wrap(err) + } + users = make([]*table.UserCount, len(result[0].Users)) + for i, r := range result[0].Users { + users[i] = &table.UserCount{ + UserID: r.UserID, + Count: r.Count, + } + } + dateCount = make(map[string]int64) + for _, r := range result[0].Dates { + dateCount[r.Date] = r.Count + } + + return result[0].MsgCount, result[0].UserCount, users, dateCount, nil +} + +//nolint:funlen // it need to be such long +func buildPiplineForRangeUserSendCount(or bson.A, start time.Time, + end time.Time, sort int, pageNumber, showNumber int32, +) bson.A { + return bson.A{ bson.M{ "$match": bson.M{ "$and": bson.A{ @@ -795,30 +841,6 @@ func (m *MsgMongoDriver) RangeUserSendCount( }, }, } - cur, err := m.MsgCollection.Aggregate(ctx, pipeline, options.Aggregate().SetAllowDiskUse(true)) - if err != nil { - return 0, 0, nil, nil, errs.Wrap(err) - } - defer cur.Close(ctx) - var result []Result - if err := cur.All(ctx, &result); err != nil { - return 0, 0, nil, nil, errs.Wrap(err) - } - if len(result) == 0 { - return 0, 0, nil, nil, errs.Wrap(err) - } - users = make([]*table.UserCount, len(result[0].Users)) - for i, r := range result[0].Users { - users[i] = &table.UserCount{ - UserID: r.UserID, - Count: r.Count, - } - } - dateCount = make(map[string]int64) - for _, r := range result[0].Dates { - dateCount[r.Date] = r.Count - } - return result[0].MsgCount, result[0].UserCount, users, dateCount, nil } func (m *MsgMongoDriver) RangeGroupSendCount( @@ -847,7 +869,39 @@ func (m *MsgMongoDriver) RangeGroupSendCount( Count int64 `bson:"count"` } `bson:"dates"` } - pipeline := bson.A{ + pipeline := buildPiplineForRangeGroupSendCount(start, end, sort, pageNumber, showNumber) + cur, err := m.MsgCollection.Aggregate(ctx, pipeline, options.Aggregate().SetAllowDiskUse(true)) + if err != nil { + return 0, 0, nil, nil, errs.Wrap(err) + } + defer cur.Close(ctx) + var result []Result + if err = cur.All(ctx, &result); err != nil { + return 0, 0, nil, nil, errs.Wrap(err) + } + if len(result) == 0 { + return 0, 0, nil, nil, errs.Wrap(err) + } + groups = make([]*table.GroupCount, len(result[0].Groups)) + for i, r := range result[0].Groups { + groups[i] = &table.GroupCount{ + GroupID: r.GroupID, + Count: r.Count, + } + } + dateCount = make(map[string]int64) + for _, r := range result[0].Dates { + dateCount[r.Date] = r.Count + } + + return result[0].MsgCount, result[0].UserCount, groups, dateCount, nil +} + +//nolint:funlen //it need to has such length +func buildPiplineForRangeGroupSendCount(start time.Time, + end time.Time, sort int, pageNumber, showNumber int32, +) bson.A { + return bson.A{ bson.M{ "$match": bson.M{ "$and": bson.A{ @@ -1044,30 +1098,6 @@ func (m *MsgMongoDriver) RangeGroupSendCount( }, }, } - cur, err := m.MsgCollection.Aggregate(ctx, pipeline, options.Aggregate().SetAllowDiskUse(true)) - if err != nil { - return 0, 0, nil, nil, errs.Wrap(err) - } - defer cur.Close(ctx) - var result []Result - if err := cur.All(ctx, &result); err != nil { - return 0, 0, nil, nil, errs.Wrap(err) - } - if len(result) == 0 { - return 0, 0, nil, nil, errs.Wrap(err) - } - groups = make([]*table.GroupCount, len(result[0].Groups)) - for i, r := range result[0].Groups { - groups[i] = &table.GroupCount{ - GroupID: r.GroupID, - Count: r.Count, - } - } - dateCount = make(map[string]int64) - for _, r := range result[0].Dates { - dateCount[r.Date] = r.Count - } - return result[0].MsgCount, result[0].UserCount, groups, dateCount, nil } func (m *MsgMongoDriver) SearchMessage(ctx context.Context, req *msg.SearchMessageReq) (int32, []*table.MsgInfoModel, error) { @@ -1075,6 +1105,7 @@ func (m *MsgMongoDriver) SearchMessage(ctx context.Context, req *msg.SearchMessa if err != nil { return 0, nil, err } + return total, msgs, nil } @@ -1119,7 +1150,7 @@ func (m *MsgMongoDriver) searchMessage(ctx context.Context, req *msg.SearchMessa }, }, ) - + //nolint:govet //this is already standard pipe = mongo.Pipeline{ { {"$match", bson.D{ @@ -1214,5 +1245,6 @@ func (m *MsgMongoDriver) searchMessage(ctx context.Context, req *msg.SearchMessa } else { msgs = msgs[start:] } + return n, msgs, nil } diff --git a/pkg/common/db/unrelation/msg_convert.go b/pkg/common/db/unrelation/msg_convert.go index 810b4f419..a5b28a5c7 100644 --- a/pkg/common/db/unrelation/msg_convert.go +++ b/pkg/common/db/unrelation/msg_convert.go @@ -31,12 +31,14 @@ func (m *MsgMongoDriver) ConvertMsgsDocLen(ctx context.Context, conversationIDs cursor, err := m.MsgCollection.Find(ctx, bson.M{"doc_id": regex}) if err != nil { log.ZError(ctx, "convertAll find msg doc failed", err, "conversationID", conversationID) + continue } var msgDocs []table.MsgDocModel err = cursor.All(ctx, &msgDocs) if err != nil { log.ZError(ctx, "convertAll cursor all failed", err, "conversationID", conversationID) + continue } if len(msgDocs) < 1 { @@ -44,39 +46,45 @@ func (m *MsgMongoDriver) ConvertMsgsDocLen(ctx context.Context, conversationIDs } log.ZInfo(ctx, "msg doc convert", "conversationID", conversationID, "len(msgDocs)", len(msgDocs)) if len(msgDocs[0].Msg) == int(m.model.GetSingleGocMsgNum5000()) { - if _, err := m.MsgCollection.DeleteMany(ctx, bson.M{"doc_id": regex}); err != nil { - log.ZError(ctx, "convertAll delete many failed", err, "conversationID", conversationID) - continue - } - var newMsgDocs []interface{} - for _, msgDoc := range msgDocs { - if int64(len(msgDoc.Msg)) == m.model.GetSingleGocMsgNum() { - continue - } - var index int64 - for index < int64(len(msgDoc.Msg)) { - msg := msgDoc.Msg[index] - if msg != nil && msg.Msg != nil { - msgDocModel := table.MsgDocModel{DocID: m.model.GetDocID(conversationID, msg.Msg.Seq)} - end := index + m.model.GetSingleGocMsgNum() - if int(end) >= len(msgDoc.Msg) { - msgDocModel.Msg = msgDoc.Msg[index:] - } else { - msgDocModel.Msg = msgDoc.Msg[index:end] - } - newMsgDocs = append(newMsgDocs, msgDocModel) - index = end - } else { - break - } - } - } - _, err = m.MsgCollection.InsertMany(ctx, newMsgDocs) - if err != nil { - log.ZError(ctx, "convertAll insert many failed", err, "conversationID", conversationID, "len(newMsgDocs)", len(newMsgDocs)) - } else { - log.ZInfo(ctx, "msg doc convert", "conversationID", conversationID, "len(newMsgDocs)", len(newMsgDocs)) - } + convertMsgDocs(m, ctx, msgDocs, conversationID, regex) } } } + +func convertMsgDocs(m *MsgMongoDriver, ctx context.Context, msgDocs []table.MsgDocModel, conversationID string, regex primitive.Regex) { + var err error + if _, err = m.MsgCollection.DeleteMany(ctx, bson.M{"doc_id": regex}); err != nil { + log.ZError(ctx, "convertAll delete many failed", err, "conversationID", conversationID) + + return + } + var newMsgDocs []interface{} + for _, msgDoc := range msgDocs { + if int64(len(msgDoc.Msg)) == m.model.GetSingleGocMsgNum() { + continue + } + var index int64 + for index < int64(len(msgDoc.Msg)) { + msg := msgDoc.Msg[index] + if msg != nil && msg.Msg != nil { + msgDocModel := table.MsgDocModel{DocID: m.model.GetDocID(conversationID, msg.Msg.Seq)} + end := index + m.model.GetSingleGocMsgNum() + if int(end) >= len(msgDoc.Msg) { + msgDocModel.Msg = msgDoc.Msg[index:] + } else { + msgDocModel.Msg = msgDoc.Msg[index:end] + } + newMsgDocs = append(newMsgDocs, msgDocModel) + index = end + } else { + break + } + } + } + _, err = m.MsgCollection.InsertMany(ctx, newMsgDocs) + if err != nil { + log.ZError(ctx, "convertAll insert many failed", err, "conversationID", conversationID, "len(newMsgDocs)", len(newMsgDocs)) + } else { + log.ZInfo(ctx, "msg doc convert", "conversationID", conversationID, "len(newMsgDocs)", len(newMsgDocs)) + } +} diff --git a/pkg/common/db/unrelation/super_group.go b/pkg/common/db/unrelation/super_group.go index c762140a2..7f9aecfd6 100644 --- a/pkg/common/db/unrelation/super_group.go +++ b/pkg/common/db/unrelation/super_group.go @@ -59,6 +59,7 @@ func (s *SuperGroupMongoDriver) CreateSuperGroup(ctx context.Context, groupID st return err } } + return nil } @@ -69,6 +70,7 @@ func (s *SuperGroupMongoDriver) TakeSuperGroup( if err := s.superGroupCollection.FindOne(ctx, bson.M{"group_id": groupID}).Decode(&group); err != nil { return nil, utils.Wrap(err, "") } + return group, nil } @@ -86,6 +88,7 @@ func (s *SuperGroupMongoDriver) FindSuperGroup( if err := cursor.All(ctx, &groups); err != nil { return nil, utils.Wrap(err, "") } + return groups, nil } @@ -113,6 +116,7 @@ func (s *SuperGroupMongoDriver) AddUserToSuperGroup(ctx context.Context, groupID return utils.Wrap(err, "transaction failed") } } + return nil } @@ -129,6 +133,7 @@ func (s *SuperGroupMongoDriver) RemoverUserFromSuperGroup(ctx context.Context, g if err != nil { return err } + return nil } @@ -138,6 +143,7 @@ func (s *SuperGroupMongoDriver) GetSuperGroupByUserID( ) (*unrelation.UserToSuperGroupModel, error) { var user unrelation.UserToSuperGroupModel err := s.userToSuperGroupCollection.FindOne(ctx, bson.M{"user_id": userID}).Decode(&user) + return &user, utils.Wrap(err, "") } @@ -149,6 +155,7 @@ func (s *SuperGroupMongoDriver) DeleteSuperGroup(ctx context.Context, groupID st if _, err := s.superGroupCollection.DeleteOne(ctx, bson.M{"group_id": groupID}); err != nil { return utils.Wrap(err, "") } + return s.RemoveGroupFromUser(ctx, groupID, group.MemberIDs) } @@ -158,5 +165,6 @@ func (s *SuperGroupMongoDriver) RemoveGroupFromUser(ctx context.Context, groupID bson.M{"user_id": bson.M{"$in": userIDs}}, bson.M{"$pull": bson.M{"group_id_list": groupID}}, ) + return utils.Wrap(err, "") } diff --git a/pkg/common/db/unrelation/user.go b/pkg/common/db/unrelation/user.go old mode 100644 new mode 100755 index 4b4a78c79..ad02968bd --- a/pkg/common/db/unrelation/user.go +++ b/pkg/common/db/unrelation/user.go @@ -16,6 +16,7 @@ package unrelation import ( "context" + "errors" "github.com/OpenIMSDK/tools/errs" "github.com/OpenIMSDK/tools/utils" @@ -50,6 +51,7 @@ type UserMongoDriver struct { // AddSubscriptionList Subscriber's handling of thresholds. func (u *UserMongoDriver) AddSubscriptionList(ctx context.Context, userID string, userIDList []string) error { // Check the number of lists in the key. + //nolint:govet //this has already been the standard format for mongo.Pipeline pipeline := mongo.Pipeline{ {{"$match", bson.D{{"user_id", SubscriptionPrefix + userID}}}}, {{"$project", bson.D{{"count", bson.D{{"$size", "$user_id_list"}}}}}}, @@ -65,7 +67,7 @@ func (u *UserMongoDriver) AddSubscriptionList(ctx context.Context, userID string } // iterate over aggregated results for cursor.Next(ctx) { - err := cursor.Decode(&cnt) + err = cursor.Decode(&cnt) if err != nil { return errs.Wrap(err) } @@ -122,6 +124,7 @@ func (u *UserMongoDriver) AddSubscriptionList(ctx context.Context, userID string return utils.Wrap(err, "transaction failed") } } + return nil } @@ -139,6 +142,7 @@ func (u *UserMongoDriver) UnsubscriptionList(ctx context.Context, userID string, if err != nil { return errs.Wrap(err) } + return nil } @@ -152,6 +156,7 @@ func (u *UserMongoDriver) RemoveSubscribedListFromUser(ctx context.Context, user bson.M{"$pull": bson.M{"user_id_list": userID}}, ) } + return errs.Wrap(err) } @@ -163,12 +168,13 @@ func (u *UserMongoDriver) GetAllSubscribeList(ctx context.Context, userID string bson.M{"user_id": SubscriptionPrefix + userID}) err = cursor.Decode(&user) if err != nil { - if err == mongo.ErrNoDocuments { + if errors.Is(err, mongo.ErrNoDocuments) { return []string{}, nil - } else { - return nil, errs.Wrap(err) } + + return nil, errs.Wrap(err) } + return user.UserIDList, nil } @@ -180,11 +186,12 @@ func (u *UserMongoDriver) GetSubscribedList(ctx context.Context, userID string) bson.M{"user_id": SubscribedPrefix + userID}) err = cursor.Decode(&user) if err != nil { - if err == mongo.ErrNoDocuments { + if errors.Is(err, mongo.ErrNoDocuments) { return []string{}, nil - } else { - return nil, errs.Wrap(err) } + + return nil, errs.Wrap(err) } + return user.UserIDList, nil } diff --git a/pkg/common/discovery_register/k8s_discovery_register.go b/pkg/common/discovery_register/k8s_discovery_register.go index 70f9f39f3..72179fdbd 100644 --- a/pkg/common/discovery_register/k8s_discovery_register.go +++ b/pkg/common/discovery_register/k8s_discovery_register.go @@ -4,12 +4,14 @@ import ( "context" "errors" "fmt" + "time" + "github.com/OpenIMSDK/tools/discoveryregistry" openkeeper "github.com/OpenIMSDK/tools/discoveryregistry/zookeeper" "github.com/OpenIMSDK/tools/log" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" "google.golang.org/grpc" - "time" + + "github.com/openimsdk/open-im-server/v3/pkg/common/config" ) func NewDiscoveryRegister(envType string) (discoveryregistry.SvcDiscoveryRegistry, error) { @@ -28,6 +30,7 @@ func NewDiscoveryRegister(envType string) (discoveryregistry.SvcDiscoveryRegistr client = nil err = errors.New("envType not correct") } + return client, err } @@ -42,47 +45,51 @@ func NewK8sDiscoveryRegister() (discoveryregistry.SvcDiscoveryRegistry, error) { func (cli *K8sDR) Register(serviceName, host string, port int, opts ...grpc.DialOption) error { cli.rpcRegisterAddr = serviceName + return nil } + func (cli *K8sDR) UnRegister() error { - return nil } + func (cli *K8sDR) CreateRpcRootNodes(serviceNames []string) error { - return nil } -func (cli *K8sDR) RegisterConf2Registry(key string, conf []byte) error { +func (cli *K8sDR) RegisterConf2Registry(key string, conf []byte) error { return nil } func (cli *K8sDR) GetConfFromRegistry(key string) ([]byte, error) { - return nil, nil } -func (cli *K8sDR) GetConns(ctx context.Context, serviceName string, opts ...grpc.DialOption) ([]*grpc.ClientConn, error) { +func (cli *K8sDR) GetConns(ctx context.Context, serviceName string, opts ...grpc.DialOption) ([]*grpc.ClientConn, error) { conn, err := grpc.DialContext(ctx, serviceName, append(cli.options, opts...)...) + return []*grpc.ClientConn{conn}, err } -func (cli *K8sDR) GetConn(ctx context.Context, serviceName string, opts ...grpc.DialOption) (*grpc.ClientConn, error) { +func (cli *K8sDR) GetConn(ctx context.Context, serviceName string, opts ...grpc.DialOption) (*grpc.ClientConn, error) { return grpc.DialContext(ctx, serviceName, append(cli.options, opts...)...) } -func (cli *K8sDR) GetSelfConnTarget() string { +func (cli *K8sDR) GetSelfConnTarget() string { return cli.rpcRegisterAddr } + func (cli *K8sDR) AddOption(opts ...grpc.DialOption) { cli.options = append(cli.options, opts...) } + func (cli *K8sDR) CloseConn(conn *grpc.ClientConn) { conn.Close() } -// do not use this method for call rpc +// do not use this method for call rpc. func (cli *K8sDR) GetClientLocalConns() map[string][]*grpc.ClientConn { fmt.Println("should not call this function!!!!!!!!!!!!!!!!!!!!!!!!!") + return nil } diff --git a/pkg/common/http/http_client.go b/pkg/common/http/http_client.go index 2d7c24c77..579643964 100644 --- a/pkg/common/http/http_client.go +++ b/pkg/common/http/http_client.go @@ -34,16 +34,21 @@ import ( var client http.Client func Get(url string) (response []byte, err error) { - client := http.Client{Timeout: 5 * time.Second} - resp, err := client.Get(url) + clientGet := http.Client{Timeout: 5 * time.Second} + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, nil) if err != nil { return nil, err } + resp, err2 := clientGet.Do(req) + if err2 != nil { + return nil, err + } defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) - if err != nil { + body, err3 := io.ReadAll(resp.Body) + if err3 != nil { return nil, err } + return body, nil } @@ -83,6 +88,7 @@ func Post( if err != nil { return nil, err } + return result, nil } @@ -98,6 +104,7 @@ func PostReturn( return err } err = json.Unmarshal(b, output) + return err } @@ -116,17 +123,22 @@ func callBackPostReturn( if err != nil { if callbackConfig.CallbackFailedContinue != nil && *callbackConfig.CallbackFailedContinue { log.ZWarn(ctx, "callback failed but continue", err, "url", url) + return errs.ErrCallbackContinue } + return errs.ErrNetwork.Wrap(err.Error()) } if err = json.Unmarshal(b, output); err != nil { if callbackConfig.CallbackFailedContinue != nil && *callbackConfig.CallbackFailedContinue { log.ZWarn(ctx, "callback failed but continue", err, "url", url) + return errs.ErrCallbackContinue } + return errs.ErrData.Wrap(err.Error()) } + return output.Parse() } diff --git a/pkg/common/kafka/consumer_group.go b/pkg/common/kafka/consumer_group.go index 1eb7b522a..c5ec69533 100644 --- a/pkg/common/kafka/consumer_group.go +++ b/pkg/common/kafka/consumer_group.go @@ -51,6 +51,7 @@ func NewMConsumerGroup(consumerConfig *MConsumerGroupConfig, topics, addrs []str if err != nil { panic(err.Error()) } + return &MConsumerGroup{ consumerGroup, groupID, diff --git a/pkg/common/kafka/producer.go b/pkg/common/kafka/producer.go index 4a52d2bef..b4545af9b 100644 --- a/pkg/common/kafka/producer.go +++ b/pkg/common/kafka/producer.go @@ -67,6 +67,7 @@ func NewKafkaProducer(addr []string, topic string) *Producer { producer, err = sarama.NewSyncProducer(p.addr, p.config) // Initialize the client if err == nil { p.producer = producer + return &p } //TODO If the password is wrong, exit directly @@ -83,6 +84,7 @@ func NewKafkaProducer(addr []string, topic string) *Producer { panic(err.Error()) } p.producer = producer + return &p } @@ -91,6 +93,7 @@ func GetMQHeaderWithContext(ctx context.Context) ([]sarama.RecordHeader, error) if err != nil { return nil, err } + return []sarama.RecordHeader{ {Key: []byte(constant.OperationID), Value: []byte(operationID)}, {Key: []byte(constant.OpUserID), Value: []byte(opUserID)}, @@ -100,10 +103,11 @@ func GetMQHeaderWithContext(ctx context.Context) ([]sarama.RecordHeader, error) } func GetContextWithMQHeader(header []*sarama.RecordHeader) context.Context { - var values []string + values := make([]string, 0, len(header)) for _, recordHeader := range header { values = append(values, string(recordHeader.Value)) } + return mcontext.WithMustInfoCtx(values) // TODO } @@ -134,5 +138,6 @@ func (p *Producer) SendMessage(ctx context.Context, key string, msg proto.Messag if err == nil { prome.Inc(prome.SendMsgCounter) } + return partition, offset, utils.Wrap(err, "") } diff --git a/pkg/common/locker/message_locker.go b/pkg/common/locker/message_locker.go index 55241eb5f..108dbbd48 100644 --- a/pkg/common/locker/message_locker.go +++ b/pkg/common/locker/message_locker.go @@ -42,11 +42,13 @@ func (l *LockerMessage) LockMessageTypeKey(ctx context.Context, clientMsgID, typ err = l.cache.LockMessageTypeKey(ctx, clientMsgID, typeKey) if err != nil { time.Sleep(time.Millisecond * 100) + continue } else { break } } + return err } @@ -55,11 +57,13 @@ func (l *LockerMessage) LockGlobalMessage(ctx context.Context, clientMsgID strin err = l.cache.LockMessageTypeKey(ctx, clientMsgID, GlOBALLOCK) if err != nil { time.Sleep(time.Millisecond * 100) + continue } else { break } } + return err } diff --git a/pkg/common/prome/gather.go b/pkg/common/prome/gather.go index eb4bc6c3b..184034e25 100644 --- a/pkg/common/prome/gather.go +++ b/pkg/common/prome/gather.go @@ -79,6 +79,7 @@ var ( ConversationCreateFailedCounter prometheus.Counter ) +//nolint:promlinter //no idea to fix it func NewUserLoginCounter() { if UserLoginCounter != nil { return @@ -89,6 +90,7 @@ func NewUserLoginCounter() { }) } +//nolint:promlinter //no idea to fix it func NewUserRegisterCounter() { if UserRegisterCounter != nil { return @@ -99,6 +101,7 @@ func NewUserRegisterCounter() { }) } +//nolint:promlinter //no idea to fix it func NewSeqGetSuccessCounter() { if SeqGetSuccessCounter != nil { return @@ -109,6 +112,7 @@ func NewSeqGetSuccessCounter() { }) } +//nolint:promlinter //no idea to fix it func NewSeqGetFailedCounter() { if SeqGetFailedCounter != nil { return @@ -119,6 +123,7 @@ func NewSeqGetFailedCounter() { }) } +//nolint:promlinter //no idea to fix it func NewSeqSetSuccessCounter() { if SeqSetSuccessCounter != nil { return @@ -129,6 +134,7 @@ func NewSeqSetSuccessCounter() { }) } +//nolint:promlinter //no idea to fix it func NewSeqSetFailedCounter() { if SeqSetFailedCounter != nil { return @@ -139,6 +145,7 @@ func NewSeqSetFailedCounter() { }) } +//nolint:promlinter //no idea to fix it func NewApiRequestCounter() { if ApiRequestCounter != nil { return @@ -149,6 +156,7 @@ func NewApiRequestCounter() { }) } +//nolint:promlinter //no idea to fix it func NewApiRequestSuccessCounter() { if ApiRequestSuccessCounter != nil { return @@ -159,6 +167,7 @@ func NewApiRequestSuccessCounter() { }) } +//nolint:promlinter //no idea to fix it func NewApiRequestFailedCounter() { if ApiRequestFailedCounter != nil { return @@ -169,6 +178,7 @@ func NewApiRequestFailedCounter() { }) } +//nolint:promlinter //no idea to fix it func NewGrpcRequestCounter() { if GrpcRequestCounter != nil { return @@ -179,6 +189,7 @@ func NewGrpcRequestCounter() { }) } +//nolint:promlinter //no idea to fix it func NewGrpcRequestSuccessCounter() { if GrpcRequestSuccessCounter != nil { return @@ -189,6 +200,7 @@ func NewGrpcRequestSuccessCounter() { }) } +//nolint:promlinter //no idea to fix it func NewGrpcRequestFailedCounter() { if GrpcRequestFailedCounter != nil { return @@ -199,6 +211,7 @@ func NewGrpcRequestFailedCounter() { }) } +//nolint:promlinter //no idea to fix it func NewSendMsgCount() { if SendMsgCounter != nil { return @@ -209,6 +222,7 @@ func NewSendMsgCount() { }) } +//nolint:promlinter //no idea to fix it func NewMsgInsertRedisSuccessCounter() { if MsgInsertRedisSuccessCounter != nil { return @@ -219,6 +233,7 @@ func NewMsgInsertRedisSuccessCounter() { }) } +//nolint:promlinter //no idea to fix its func NewMsgInsertRedisFailedCounter() { if MsgInsertRedisFailedCounter != nil { return @@ -229,6 +244,7 @@ func NewMsgInsertRedisFailedCounter() { }) } +//nolint:promlinter //no idea to fix it func NewMsgInsertMongoSuccessCounter() { if MsgInsertMongoSuccessCounter != nil { return @@ -239,6 +255,7 @@ func NewMsgInsertMongoSuccessCounter() { }) } +//nolint:promlinter //no idea to fix it func NewMsgInsertMongoFailedCounter() { if MsgInsertMongoFailedCounter != nil { return @@ -249,6 +266,7 @@ func NewMsgInsertMongoFailedCounter() { }) } +//nolint:promlinter //no idea to fix it func NewMsgPullFromRedisSuccessCounter() { if MsgPullFromRedisSuccessCounter != nil { return @@ -259,6 +277,7 @@ func NewMsgPullFromRedisSuccessCounter() { }) } +//nolint:promlinter //no idea to fix it func NewMsgPullFromRedisFailedCounter() { if MsgPullFromRedisFailedCounter != nil { return @@ -269,6 +288,7 @@ func NewMsgPullFromRedisFailedCounter() { }) } +//nolint:promlinter //no idea to fix it func NewMsgPullFromMongoSuccessCounter() { if MsgPullFromMongoSuccessCounter != nil { return @@ -279,6 +299,7 @@ func NewMsgPullFromMongoSuccessCounter() { }) } +//nolint:promlinter //no idea to fix it func NewMsgPullFromMongoFailedCounter() { if MsgPullFromMongoFailedCounter != nil { return @@ -319,6 +340,7 @@ func NewPullMsgBySeqListTotalCounter() { }) } +//nolint:promlinter //no idea to fix it func NewSingleChatMsgRecvSuccessCounter() { if SingleChatMsgRecvSuccessCounter != nil { return @@ -329,6 +351,7 @@ func NewSingleChatMsgRecvSuccessCounter() { }) } +//nolint:promlinter //no idea to fix it func NewGroupChatMsgRecvSuccessCounter() { if GroupChatMsgRecvSuccessCounter != nil { return @@ -339,6 +362,7 @@ func NewGroupChatMsgRecvSuccessCounter() { }) } +//nolint:promlinter //no idea to fix it func NewWorkSuperGroupChatMsgRecvSuccessCounter() { if WorkSuperGroupChatMsgRecvSuccessCounter != nil { return @@ -359,6 +383,7 @@ func NewOnlineUserGauges() { }) } +//nolint:promlinter //no idea to fix it func NewSingleChatMsgProcessSuccessCounter() { if SingleChatMsgProcessSuccessCounter != nil { return @@ -369,6 +394,7 @@ func NewSingleChatMsgProcessSuccessCounter() { }) } +//nolint:promlinter //no idea to fix it func NewSingleChatMsgProcessFailedCounter() { if SingleChatMsgProcessFailedCounter != nil { return @@ -379,6 +405,7 @@ func NewSingleChatMsgProcessFailedCounter() { }) } +//nolint:promlinter //no idea to fix it func NewGroupChatMsgProcessSuccessCounter() { if GroupChatMsgProcessSuccessCounter != nil { return @@ -389,6 +416,7 @@ func NewGroupChatMsgProcessSuccessCounter() { }) } +//nolint:promlinter //no idea to fix it func NewGroupChatMsgProcessFailedCounter() { if GroupChatMsgProcessFailedCounter != nil { return @@ -399,6 +427,7 @@ func NewGroupChatMsgProcessFailedCounter() { }) } +//nolint:promlinter //no idea to fix it func NewWorkSuperGroupChatMsgProcessSuccessCounter() { if WorkSuperGroupChatMsgProcessSuccessCounter != nil { return @@ -409,6 +438,7 @@ func NewWorkSuperGroupChatMsgProcessSuccessCounter() { }) } +//nolint:promlinter //no idea to fix it func NewWorkSuperGroupChatMsgProcessFailedCounter() { if WorkSuperGroupChatMsgProcessFailedCounter != nil { return @@ -419,6 +449,7 @@ func NewWorkSuperGroupChatMsgProcessFailedCounter() { }) } +//nolint:promlinter //no idea to fix it func NewMsgOnlinePushSuccessCounter() { if MsgOnlinePushSuccessCounter != nil { return @@ -429,6 +460,7 @@ func NewMsgOnlinePushSuccessCounter() { }) } +//nolint:promlinter //no idea to fix it func NewMsgOfflinePushSuccessCounter() { if MsgOfflinePushSuccessCounter != nil { return @@ -439,6 +471,7 @@ func NewMsgOfflinePushSuccessCounter() { }) } +//nolint:promlinter //no idea to fix it func NewMsgOfflinePushFailedCounter() { if MsgOfflinePushFailedCounter != nil { return @@ -449,6 +482,7 @@ func NewMsgOfflinePushFailedCounter() { }) } +//nolint:promlinter //no idea to fix it func NewConversationCreateSuccessCounter() { if ConversationCreateSuccessCounter != nil { return @@ -459,6 +493,7 @@ func NewConversationCreateSuccessCounter() { }) } +//nolint:promlinter //no idea to fix it func NewConversationCreateFailedCounter() { if ConversationCreateFailedCounter != nil { return diff --git a/pkg/common/prome/prometheus.go b/pkg/common/prome/prometheus.go index 254a6c9ea..60df5b0af 100644 --- a/pkg/common/prome/prometheus.go +++ b/pkg/common/prome/prometheus.go @@ -30,13 +30,16 @@ func StartPrometheusSrv(prometheusPort int) error { if config.Config.Prometheus.Enable { http.Handle("/metrics", promhttp.Handler()) err := http.ListenAndServe(":"+strconv.Itoa(prometheusPort), nil) + return err } + return nil } func PrometheusHandler() gin.HandlerFunc { h := promhttp.Handler() + return func(c *gin.Context) { h.ServeHTTP(c.Writer, c.Request) } @@ -49,6 +52,7 @@ type responseBodyWriter struct { func (r responseBodyWriter) Write(b []byte) (int, error) { r.body.Write(b) + return r.ResponseWriter.Write(b) } diff --git a/pkg/common/tls/tls.go b/pkg/common/tls/tls.go old mode 100644 new mode 100755 index 3bf91beb9..7b3e9033e --- a/pkg/common/tls/tls.go +++ b/pkg/common/tls/tls.go @@ -24,6 +24,7 @@ import ( "github.com/openimsdk/open-im-server/v3/pkg/common/config" ) +//nolint:staticcheck //we have not time looking for a replacement for x509 to fix the security valnerability func decryptPEM(data []byte, passphrase []byte) ([]byte, error) { if len(passphrase) == 0 { return data, nil @@ -33,6 +34,7 @@ func decryptPEM(data []byte, passphrase []byte) ([]byte, error) { if err != nil { return nil, err } + return pem.EncodeToMemory(&pem.Block{ Type: b.Type, Bytes: d, @@ -44,6 +46,7 @@ func readEncryptablePEMBlock(path string, pwd []byte) ([]byte, error) { if err != nil { return nil, err } + return decryptPEM(data, pwd) } diff --git a/pkg/msgprocessor/conversation.go b/pkg/msgprocessor/conversation.go index ca77438ea..559994eaf 100644 --- a/pkg/msgprocessor/conversation.go +++ b/pkg/msgprocessor/conversation.go @@ -28,6 +28,7 @@ func GetNotificationConversationIDByMsg(msg *sdkws.MsgData) string { case constant.SingleChatType: l := []string{msg.SendID, msg.RecvID} sort.Strings(l) + return "n_" + strings.Join(l, "_") case constant.GroupChatType: return "n_" + msg.GroupID @@ -36,6 +37,7 @@ func GetNotificationConversationIDByMsg(msg *sdkws.MsgData) string { case constant.NotificationChatType: return "n_" + msg.SendID + "_" + msg.RecvID } + return "" } @@ -44,6 +46,7 @@ func GetChatConversationIDByMsg(msg *sdkws.MsgData) string { case constant.SingleChatType: l := []string{msg.SendID, msg.RecvID} sort.Strings(l) + return "si_" + strings.Join(l, "_") case constant.GroupChatType: return "g_" + msg.GroupID @@ -52,6 +55,7 @@ func GetChatConversationIDByMsg(msg *sdkws.MsgData) string { case constant.NotificationChatType: return "sn_" + msg.SendID + "_" + msg.RecvID } + return "" } @@ -60,10 +64,12 @@ func GenConversationUniqueKey(msg *sdkws.MsgData) string { case constant.SingleChatType, constant.NotificationChatType: l := []string{msg.SendID, msg.RecvID} sort.Strings(l) + return strings.Join(l, "_") case constant.SuperGroupChatType: return msg.GroupID } + return "" } @@ -76,23 +82,28 @@ func GetConversationIDByMsg(msg *sdkws.MsgData) string { if !options.IsNotNotification() { return "n_" + strings.Join(l, "_") } + return "si_" + strings.Join(l, "_") // single chat case constant.GroupChatType: if !options.IsNotNotification() { return "n_" + msg.GroupID // group chat } + return "g_" + msg.GroupID // group chat case constant.SuperGroupChatType: if !options.IsNotNotification() { return "n_" + msg.GroupID // super group chat } + return "sg_" + msg.GroupID // super group chat case constant.NotificationChatType: if !options.IsNotNotification() { return "n_" + msg.SendID + "_" + msg.RecvID // super group chat } + return "sn_" + msg.SendID + "_" + msg.RecvID // server notification chat } + return "" } @@ -111,6 +122,7 @@ func GetConversationIDBySessionType(sessionType int, ids ...string) string { case constant.NotificationChatType: return "sn_" + ids[0] // server notification chat } + return "" } @@ -118,10 +130,11 @@ func GetNotificationConversationIDByConversationID(conversationID string) string l := strings.Split(conversationID, "_") if len(l) > 1 { l[0] = "n" + return strings.Join(l, "_") - } else { - return "" } + + return "" } func GetNotificationConversationID(sessionType int, ids ...string) string { @@ -135,6 +148,7 @@ func GetNotificationConversationID(sessionType int, ids ...string) string { case constant.SuperGroupChatType: return "n_" + ids[0] // super group chat } + return "" } @@ -155,18 +169,22 @@ func ParseConversationID(msg *sdkws.MsgData) (isNotification bool, conversationI if !options.IsNotNotification() { return true, "n_" + strings.Join(l, "_") } + return false, "si_" + strings.Join(l, "_") // single chat case constant.SuperGroupChatType: if !options.IsNotNotification() { return true, "n_" + msg.GroupID // super group chat } + return false, "sg_" + msg.GroupID // super group chat case constant.NotificationChatType: if !options.IsNotNotification() { return true, "n_" + msg.SendID + "_" + msg.RecvID // super group chat } + return false, "sn_" + msg.SendID + "_" + msg.RecvID // server notification chat } + return false, "" } @@ -189,6 +207,7 @@ func Pb2String(pb proto.Message) (string, error) { if err != nil { return "", err } + return string(s), nil } diff --git a/pkg/msgprocessor/options.go b/pkg/msgprocessor/options.go index c17c7cb05..27bbb839a 100644 --- a/pkg/msgprocessor/options.go +++ b/pkg/msgprocessor/options.go @@ -38,12 +38,14 @@ func NewOptions(opts ...OptionsOpt) Options { for _, opt := range opts { opt(options) } + return options } func NewMsgOptions() Options { options := make(map[string]bool, 11) options[constant.IsOfflinePush] = false + return make(map[string]bool) } @@ -51,6 +53,7 @@ func WithOptions(options Options, opts ...OptionsOpt) Options { for _, opt := range opts { opt(options) } + return options } @@ -131,6 +134,7 @@ func (o Options) Is(notification string) bool { if !ok || v { return true } + return false } diff --git a/pkg/rpcclient/auth.go b/pkg/rpcclient/auth.go index 0ee021de1..4859e541d 100644 --- a/pkg/rpcclient/auth.go +++ b/pkg/rpcclient/auth.go @@ -31,6 +31,7 @@ func NewAuth(discov discoveryregistry.SvcDiscoveryRegistry) *Auth { panic(err) } client := auth.NewAuthClient(conn) + return &Auth{discov: discov, conn: conn, Client: client} } diff --git a/pkg/rpcclient/conversation.go b/pkg/rpcclient/conversation.go index 60ca53351..30b0b4b77 100644 --- a/pkg/rpcclient/conversation.go +++ b/pkg/rpcclient/conversation.go @@ -39,6 +39,7 @@ func NewConversation(discov discoveryregistry.SvcDiscoveryRegistry) *Conversatio panic(err) } client := pbconversation.NewConversationClient(conn) + return &Conversation{discov: discov, conn: conn, Client: client} } @@ -56,26 +57,31 @@ func (c *ConversationRpcClient) GetSingleConversationRecvMsgOpt(ctx context.Cont if err != nil { return 0, err } + return conversation.GetConversation().RecvMsgOpt, err } func (c *ConversationRpcClient) SingleChatFirstCreateConversation(ctx context.Context, recvID, sendID string) error { _, err := c.Client.CreateSingleChatConversations(ctx, &pbconversation.CreateSingleChatConversationsReq{RecvID: recvID, SendID: sendID}) + return err } func (c *ConversationRpcClient) GroupChatFirstCreateConversation(ctx context.Context, groupID string, userIDs []string) error { _, err := c.Client.CreateGroupChatConversations(ctx, &pbconversation.CreateGroupChatConversationsReq{UserIDs: userIDs, GroupID: groupID}) + return err } func (c *ConversationRpcClient) SetConversationMaxSeq(ctx context.Context, ownerUserIDs []string, conversationID string, maxSeq int64) error { _, err := c.Client.SetConversationMaxSeq(ctx, &pbconversation.SetConversationMaxSeqReq{OwnerUserID: ownerUserIDs, ConversationID: conversationID, MaxSeq: maxSeq}) + return err } func (c *ConversationRpcClient) SetConversations(ctx context.Context, userIDs []string, conversation *pbconversation.ConversationReq) error { _, err := c.Client.SetConversations(ctx, &pbconversation.SetConversationsReq{UserIDs: userIDs, Conversation: conversation}) + return err } @@ -84,6 +90,7 @@ func (c *ConversationRpcClient) GetConversationIDs(ctx context.Context, ownerUse if err != nil { return nil, err } + return resp.ConversationIDs, nil } @@ -92,6 +99,7 @@ func (c *ConversationRpcClient) GetConversation(ctx context.Context, ownerUserID if err != nil { return nil, err } + return resp.Conversation, nil } @@ -106,6 +114,7 @@ func (c *ConversationRpcClient) GetConversationsByConversationID(ctx context.Con if len(resp.Conversations) == 0 { return nil, errs.ErrRecordNotFound.Wrap(fmt.Sprintf("conversationIDs: %v not found", conversationIDs)) } + return resp.Conversations, nil } @@ -124,5 +133,6 @@ func (c *ConversationRpcClient) GetConversations( if err != nil { return nil, err } + return resp.Conversations, nil } diff --git a/pkg/rpcclient/friend.go b/pkg/rpcclient/friend.go index b84db40d4..6b214aaf2 100644 --- a/pkg/rpcclient/friend.go +++ b/pkg/rpcclient/friend.go @@ -38,6 +38,7 @@ func NewFriend(discov discoveryregistry.SvcDiscoveryRegistry) *Friend { panic(err) } client := friend.NewFriendClient(conn) + return &Friend{discov: discov, conn: conn, Client: client} } @@ -59,6 +60,7 @@ func (f *FriendRpcClient) GetFriendsInfo( return nil, err } resp = r.FriendsInfo[0] + return } @@ -68,6 +70,7 @@ func (f *FriendRpcClient) IsFriend(ctx context.Context, possibleFriendUserID, us if err != nil { return false, err } + return resp.InUser1Friends, nil } @@ -77,6 +80,7 @@ func (f *FriendRpcClient) GetFriendIDs(ctx context.Context, ownerUserID string) if err != nil { return nil, err } + return resp.FriendIDs, nil } @@ -85,5 +89,6 @@ func (b *FriendRpcClient) IsBlocked(ctx context.Context, possibleBlackUserID, us if err != nil { return false, err } + return r.InUser2Blacks, nil } diff --git a/pkg/rpcclient/group.go b/pkg/rpcclient/group.go index bf0efe60c..5a340875b 100644 --- a/pkg/rpcclient/group.go +++ b/pkg/rpcclient/group.go @@ -42,6 +42,7 @@ func NewGroup(discov discoveryregistry.SvcDiscoveryRegistry) *Group { panic(err) } client := group.NewGroupClient(conn) + return &Group{discov: discov, conn: conn, Client: client} } @@ -69,6 +70,7 @@ func (g *GroupRpcClient) GetGroupInfos( return nil, errs.ErrGroupIDNotFound.Wrap(strings.Join(ids, ",")) } } + return resp.GroupInfos, nil } @@ -77,6 +79,7 @@ func (g *GroupRpcClient) GetGroupInfo(ctx context.Context, groupID string) (*sdk if err != nil { return nil, err } + return groups[0], nil } @@ -89,6 +92,7 @@ func (g *GroupRpcClient) GetGroupInfoMap( if err != nil { return nil, err } + return utils.SliceToMap(groups, func(e *sdkws.GroupInfo) string { return e.GroupID }), nil @@ -114,6 +118,7 @@ func (g *GroupRpcClient) GetGroupMemberInfos( return nil, errs.ErrNotInGroupYet.Wrap(strings.Join(ids, ",")) } } + return resp.Members, nil } @@ -126,6 +131,7 @@ func (g *GroupRpcClient) GetGroupMemberInfo( if err != nil { return nil, err } + return members[0], nil } @@ -139,6 +145,7 @@ func (g *GroupRpcClient) GetGroupMemberInfoMap( if err != nil { return nil, err } + return utils.SliceToMap(members, func(e *sdkws.GroupMemberFullInfo) string { return e.UserID }), nil @@ -155,6 +162,7 @@ func (g *GroupRpcClient) GetOwnerAndAdminInfos( if err != nil { return nil, err } + return resp.Members, nil } @@ -163,6 +171,7 @@ func (g *GroupRpcClient) GetOwnerInfo(ctx context.Context, groupID string) (*sdk GroupID: groupID, RoleLevels: []int32{constant.GroupOwner}, }) + return resp.Members[0], err } @@ -173,6 +182,7 @@ func (g *GroupRpcClient) GetGroupMemberIDs(ctx context.Context, groupID string) if err != nil { return nil, err } + return resp.UserIDs, nil } @@ -183,6 +193,7 @@ func (g *GroupRpcClient) GetGroupInfoCache(ctx context.Context, groupID string) if err != nil { return nil, err } + return resp.GroupInfo, nil } @@ -198,6 +209,7 @@ func (g *GroupRpcClient) GetGroupMemberCache( if err != nil { return nil, err } + return resp.Member, nil } @@ -206,6 +218,7 @@ func (g *GroupRpcClient) DismissGroup(ctx context.Context, groupID string) error GroupID: groupID, DeleteMember: true, }) + return err } @@ -213,5 +226,6 @@ func (g *GroupRpcClient) NotificationUserInfoUpdate(ctx context.Context, userID _, err := g.Client.NotificationUserInfoUpdate(ctx, &group.NotificationUserInfoUpdateReq{ UserID: userID, }) + return err } diff --git a/pkg/rpcclient/msg.go b/pkg/rpcclient/msg.go index e3dad874e..51e29c7d8 100644 --- a/pkg/rpcclient/msg.go +++ b/pkg/rpcclient/msg.go @@ -136,6 +136,7 @@ func NewMessage(discov discoveryregistry.SvcDiscoveryRegistry) *Message { panic(err) } client := msg.NewMsgClient(conn) + return &Message{discov: discov, conn: conn, Client: client} } @@ -147,16 +148,19 @@ func NewMessageRpcClient(discov discoveryregistry.SvcDiscoveryRegistry) MessageR func (m *MessageRpcClient) SendMsg(ctx context.Context, req *msg.SendMsgReq) (*msg.SendMsgResp, error) { resp, err := m.Client.SendMsg(ctx, req) + return resp, err } func (m *MessageRpcClient) GetMaxSeq(ctx context.Context, req *sdkws.GetMaxSeqReq) (*sdkws.GetMaxSeqResp, error) { resp, err := m.Client.GetMaxSeq(ctx, req) + return resp, err } func (m *MessageRpcClient) PullMessageBySeqList(ctx context.Context, req *sdkws.PullMessageBySeqsReq) (*sdkws.PullMessageBySeqsResp, error) { resp, err := m.Client.PullMessageBySeqs(ctx, req) + return resp, err } @@ -165,6 +169,7 @@ func (m *MessageRpcClient) GetConversationMaxSeq(ctx context.Context, conversati if err != nil { return 0, err } + return resp.MaxSeq, nil } @@ -200,6 +205,7 @@ func NewNotificationSender(opts ...NotificationSenderOptions) *NotificationSende for _, opt := range opts { opt(notificationSender) } + return notificationSender } @@ -220,6 +226,7 @@ func (s *NotificationSender) NotificationWithSesstionType(ctx context.Context, s content, err := json.Marshal(&n) if err != nil { log.ZError(ctx, "MsgClient Notification json.Marshal failed", err, "sendID", sendID, "recvID", recvID, "contentType", contentType, "msg", m) + return err } notificationOpt := ¬ificationOpt{} @@ -229,7 +236,8 @@ func (s *NotificationSender) NotificationWithSesstionType(ctx context.Context, s var req msg.SendMsgReq var msg sdkws.MsgData if notificationOpt.WithRpcGetUsername && s.getUserInfo != nil { - userInfo, err := s.getUserInfo(ctx, sendID) + var userInfo *sdkws.UserInfo + userInfo, err = s.getUserInfo(ctx, sendID) if err != nil { log.ZWarn(ctx, "getUserInfo failed", err, "sendID", sendID) } else { @@ -267,6 +275,7 @@ func (s *NotificationSender) NotificationWithSesstionType(ctx context.Context, s } else { log.ZError(ctx, "MsgClient Notification SendMsg failed", err, "req", &req) } + return err } diff --git a/pkg/rpcclient/notification/conevrsation.go b/pkg/rpcclient/notification/conversation.go similarity index 99% rename from pkg/rpcclient/notification/conevrsation.go rename to pkg/rpcclient/notification/conversation.go index 77fc623e4..0fefb147e 100644 --- a/pkg/rpcclient/notification/conevrsation.go +++ b/pkg/rpcclient/notification/conversation.go @@ -41,6 +41,7 @@ func (c *ConversationNotificationSender) ConversationSetPrivateNotification(ctx IsPrivate: isPrivateChat, ConversationID: conversationID, } + return c.Notification(ctx, sendID, recvID, constant.ConversationPrivateChatNotification, tips) } @@ -50,6 +51,7 @@ func (c *ConversationNotificationSender) ConversationChangeNotification(ctx cont UserID: userID, ConversationIDList: conversationIDs, } + return c.Notification(ctx, userID, userID, constant.ConversationChangeNotification, tips) } @@ -65,5 +67,6 @@ func (c *ConversationNotificationSender) ConversationUnreadChangeNotification( HasReadSeq: hasReadSeq, UnreadCountTime: unreadCountTime, } + return c.Notification(ctx, userID, userID, constant.ConversationUnreadNotification, tips) } diff --git a/pkg/rpcclient/notification/friend.go b/pkg/rpcclient/notification/friend.go index b061a24ae..9dae27c6e 100644 --- a/pkg/rpcclient/notification/friend.go +++ b/pkg/rpcclient/notification/friend.go @@ -57,6 +57,7 @@ func WithDBFunc( for _, user := range users { result = append(result, user) } + return result, nil } s.getUsersInfo = f @@ -75,6 +76,7 @@ func WithRpcFunc( for _, user := range users { result = append(result, user) } + return result, err } s.getUsersInfo = f @@ -91,6 +93,7 @@ func NewFriendNotificationSender( for _, opt := range opts { opt(f) } + return f } @@ -106,22 +109,13 @@ func (f *FriendNotificationSender) getUsersInfoMap( for _, user := range users { result[user.GetUserID()] = user.(*sdkws.UserInfo) } - return result, nil -} -func (f *FriendNotificationSender) getFromToUserNickname( - ctx context.Context, - fromUserID, toUserID string, -) (string, string, error) { - users, err := f.getUsersInfoMap(ctx, []string{fromUserID, toUserID}) - if err != nil { - return "", "", nil - } - return users[fromUserID].Nickname, users[toUserID].Nickname, nil + return result, nil } func (f *FriendNotificationSender) UserInfoUpdatedNotification(ctx context.Context, changedUserID string) error { tips := sdkws.UserInfoUpdatedTips{UserID: changedUserID} + return f.Notification(ctx, mcontext.GetOpUserID(ctx), changedUserID, constant.UserInfoUpdatedNotification, &tips) } @@ -133,6 +127,7 @@ func (f *FriendNotificationSender) FriendApplicationAddNotification( FromUserID: req.FromUserID, ToUserID: req.ToUserID, }} + return f.Notification(ctx, req.FromUserID, req.ToUserID, constant.FriendApplicationNotification, &tips) } @@ -144,6 +139,7 @@ func (f *FriendNotificationSender) FriendApplicationAgreedNotification( FromUserID: req.FromUserID, ToUserID: req.ToUserID, }, HandleMsg: req.HandleMsg} + return f.Notification(ctx, req.ToUserID, req.FromUserID, constant.FriendApplicationApprovedNotification, &tips) } @@ -155,6 +151,7 @@ func (f *FriendNotificationSender) FriendApplicationRefusedNotification( FromUserID: req.FromUserID, ToUserID: req.ToUserID, }, HandleMsg: req.HandleMsg} + return f.Notification(ctx, req.ToUserID, req.FromUserID, constant.FriendApplicationRejectedNotification, &tips) } @@ -179,6 +176,7 @@ func (f *FriendNotificationSender) FriendAddedNotification( if err != nil { return err } + return f.Notification(ctx, fromUserID, toUserID, constant.FriendAddedNotification, &tips) } @@ -187,6 +185,7 @@ func (f *FriendNotificationSender) FriendDeletedNotification(ctx context.Context FromUserID: req.OwnerUserID, ToUserID: req.FriendUserID, }} + return f.Notification(ctx, req.OwnerUserID, req.FriendUserID, constant.FriendDeletedNotification, &tips) } @@ -194,6 +193,7 @@ func (f *FriendNotificationSender) FriendRemarkSetNotification(ctx context.Conte tips := sdkws.FriendInfoChangedTips{FromToUserID: &sdkws.FromToUserID{}} tips.FromToUserID.FromUserID = fromUserID tips.FromToUserID.ToUserID = toUserID + return f.Notification(ctx, fromUserID, toUserID, constant.FriendRemarkSetNotification, &tips) } @@ -201,6 +201,7 @@ func (f *FriendNotificationSender) BlackAddedNotification(ctx context.Context, r tips := sdkws.BlackAddedTips{FromToUserID: &sdkws.FromToUserID{}} tips.FromToUserID.FromUserID = req.OwnerUserID tips.FromToUserID.ToUserID = req.BlackUserID + return f.Notification(ctx, req.OwnerUserID, req.BlackUserID, constant.BlackAddedNotification, &tips) } @@ -209,7 +210,10 @@ func (f *FriendNotificationSender) BlackDeletedNotification(ctx context.Context, FromUserID: req.OwnerUserID, ToUserID: req.BlackUserID, }} - f.Notification(ctx, req.OwnerUserID, req.BlackUserID, constant.BlackDeletedNotification, &blackDeletedTips) + err := f.Notification(ctx, req.OwnerUserID, req.BlackUserID, constant.BlackDeletedNotification, &blackDeletedTips) + if err != nil { + panic(err) + } } func (f *FriendNotificationSender) FriendInfoUpdatedNotification( @@ -218,5 +222,8 @@ func (f *FriendNotificationSender) FriendInfoUpdatedNotification( needNotifiedUserID string, ) { tips := sdkws.UserInfoUpdatedTips{UserID: changedUserID} - f.Notification(ctx, mcontext.GetOpUserID(ctx), needNotifiedUserID, constant.FriendInfoUpdatedNotification, &tips) + err := f.Notification(ctx, mcontext.GetOpUserID(ctx), needNotifiedUserID, constant.FriendInfoUpdatedNotification, &tips) + if err != nil { + panic(err) + } } diff --git a/pkg/rpcclient/notification/group.go b/pkg/rpcclient/notification/group.go old mode 100644 new mode 100755 index 8e71f61c3..23341af70 --- a/pkg/rpcclient/notification/group.go +++ b/pkg/rpcclient/notification/group.go @@ -60,6 +60,7 @@ func (g *GroupNotificationSender) getUser(ctx context.Context, userID string) (* if len(users) == 0 { return nil, errs.ErrUserIDNotFound.Wrap(fmt.Sprintf("user %s not found", userID)) } + return &sdkws.PublicUserInfo{ UserID: users[0].GetUserID(), Nickname: users[0].GetNickname(), @@ -68,6 +69,23 @@ func (g *GroupNotificationSender) getUser(ctx context.Context, userID string) (* }, nil } +func (g *GroupNotificationSender) groupMemberDB2PB(member *relation.GroupMemberModel, appMangerLevel int32) *sdkws.GroupMemberFullInfo { + return &sdkws.GroupMemberFullInfo{ + GroupID: member.GroupID, + UserID: member.UserID, + RoleLevel: member.RoleLevel, + JoinTime: member.JoinTime.UnixMilli(), + Nickname: member.Nickname, + FaceURL: member.FaceURL, + AppMangerLevel: appMangerLevel, + JoinSource: member.JoinSource, + OperatorUserID: member.OperatorUserID, + Ex: member.Ex, + MuteEndTime: member.MuteEndTime.UnixMilli(), + InviterUserID: member.InviterUserID, + } +} + func (g *GroupNotificationSender) getGroupInfo(ctx context.Context, groupID string) (*sdkws.GroupInfo, error) { gm, err := g.db.TakeGroup(ctx, groupID) if err != nil { @@ -81,6 +99,7 @@ func (g *GroupNotificationSender) getGroupInfo(ctx context.Context, groupID stri if err != nil { return nil, err } + return &sdkws.GroupInfo{ GroupID: gm.GroupID, GroupName: gm.GroupName, @@ -148,6 +167,7 @@ func (g *GroupNotificationSender) getGroupMemberMap(ctx context.Context, groupID for i, member := range members { m[member.UserID] = members[i] } + return m, nil } @@ -159,6 +179,7 @@ func (g *GroupNotificationSender) getGroupMember(ctx context.Context, groupID st if len(members) == 0 { return nil, errs.ErrInternalServer.Wrap(fmt.Sprintf("group %s member %s not found", groupID, userID)) } + return members[0], nil } @@ -168,48 +189,10 @@ func (g *GroupNotificationSender) getGroupOwnerAndAdminUserID(ctx context.Contex return nil, err } fn := func(e *relation.GroupMemberModel) string { return e.UserID } + return utils.Slice(members, fn), nil } -func (g *GroupNotificationSender) groupDB2PB(group *relation.GroupModel, ownerUserID string, memberCount uint32) *sdkws.GroupInfo { - return &sdkws.GroupInfo{ - GroupID: group.GroupID, - GroupName: group.GroupName, - Notification: group.Notification, - Introduction: group.Introduction, - FaceURL: group.FaceURL, - OwnerUserID: ownerUserID, - CreateTime: group.CreateTime.UnixMilli(), - MemberCount: memberCount, - Ex: group.Ex, - Status: group.Status, - CreatorUserID: group.CreatorUserID, - GroupType: group.GroupType, - NeedVerification: group.NeedVerification, - LookMemberInfo: group.LookMemberInfo, - ApplyMemberFriend: group.ApplyMemberFriend, - NotificationUpdateTime: group.NotificationUpdateTime.UnixMilli(), - NotificationUserID: group.NotificationUserID, - } -} - -func (g *GroupNotificationSender) groupMemberDB2PB(member *relation.GroupMemberModel, appMangerLevel int32) *sdkws.GroupMemberFullInfo { - return &sdkws.GroupMemberFullInfo{ - GroupID: member.GroupID, - UserID: member.UserID, - RoleLevel: member.RoleLevel, - JoinTime: member.JoinTime.UnixMilli(), - Nickname: member.Nickname, - FaceURL: member.FaceURL, - AppMangerLevel: appMangerLevel, - JoinSource: member.JoinSource, - OperatorUserID: member.OperatorUserID, - Ex: member.Ex, - MuteEndTime: member.MuteEndTime.UnixMilli(), - InviterUserID: member.InviterUserID, - } -} - func (g *GroupNotificationSender) getUsersInfoMap(ctx context.Context, userIDs []string) (map[string]*sdkws.UserInfo, error) { users, err := g.getUsersInfo(ctx, userIDs) if err != nil { @@ -219,6 +202,7 @@ func (g *GroupNotificationSender) getUsersInfoMap(ctx context.Context, userIDs [ for _, user := range users { result[user.GetUserID()] = user.(*sdkws.UserInfo) } + return result, nil } @@ -236,21 +220,31 @@ func (g *GroupNotificationSender) fillOpUser(ctx context.Context, opUser **sdkws return nil } userID := mcontext.GetOpUserID(ctx) - if groupID != "" { + getOpUser := func(g *GroupNotificationSender, groupID, userID string) (opUser *sdkws.GroupMemberFullInfo, err error) { if authverify.IsManagerUserID(userID) { - *opUser = &sdkws.GroupMemberFullInfo{ + opUser = &sdkws.GroupMemberFullInfo{ GroupID: groupID, UserID: userID, RoleLevel: constant.GroupAdmin, AppMangerLevel: constant.AppAdmin, } - } else { - member, err := g.db.TakeGroupMember(ctx, groupID, userID) - if err == nil { - *opUser = g.groupMemberDB2PB(member, 0) - } else if !errs.ErrRecordNotFound.Is(err) { - return err - } + + return opUser, nil + } + var member *relation.GroupMemberModel + member, err = g.db.TakeGroupMember(ctx, groupID, userID) + if err == nil { + opUser = g.groupMemberDB2PB(member, 0) + } else if !errs.ErrRecordNotFound.Is(err) { + return nil, err + } + + return opUser, nil + } + if groupID != "" { + *opUser, err = getOpUser(g, groupID, userID) + if err != nil { + return err } } user, err := g.getUser(ctx, userID) @@ -273,6 +267,7 @@ func (g *GroupNotificationSender) fillOpUser(ctx context.Context, opUser **sdkws (*opUser).FaceURL = user.FaceURL } } + return nil } @@ -286,6 +281,7 @@ func (g *GroupNotificationSender) GroupCreatedNotification(ctx context.Context, if err := g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID); err != nil { return err } + return g.Notification(ctx, mcontext.GetOpUserID(ctx), tips.Group.GroupID, constant.GroupCreatedNotification, tips) } @@ -299,6 +295,7 @@ func (g *GroupNotificationSender) GroupInfoSetNotification(ctx context.Context, if err := g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID); err != nil { return err } + return g.Notification(ctx, mcontext.GetOpUserID(ctx), tips.Group.GroupID, constant.GroupInfoSetNotification, tips, rpcclient.WithRpcGetUserName()) } @@ -312,6 +309,7 @@ func (g *GroupNotificationSender) GroupInfoSetNameNotification(ctx context.Conte if err := g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID); err != nil { return err } + return g.Notification(ctx, mcontext.GetOpUserID(ctx), tips.Group.GroupID, constant.GroupInfoSetNameNotification, tips) } @@ -325,6 +323,7 @@ func (g *GroupNotificationSender) GroupInfoSetAnnouncementNotification(ctx conte if err := g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID); err != nil { return err } + return g.Notification(ctx, mcontext.GetOpUserID(ctx), tips.Group.GroupID, constant.GroupInfoSetAnnouncementNotification, tips, rpcclient.WithRpcGetUserName()) } @@ -355,6 +354,7 @@ func (g *GroupNotificationSender) JoinGroupApplicationNotification(ctx context.C log.ZError(ctx, "JoinGroupApplicationNotification failed", err, "group", req.GroupID, "userID", userID) } } + return nil } @@ -370,6 +370,7 @@ func (g *GroupNotificationSender) MemberQuitNotification(ctx context.Context, me return err } tips := &sdkws.MemberQuitTips{Group: group, QuitUser: member} + return g.Notification(ctx, mcontext.GetOpUserID(ctx), member.GroupID, constant.MemberQuitNotification, tips) } @@ -389,7 +390,8 @@ func (g *GroupNotificationSender) GroupApplicationAcceptedNotification(ctx conte return err } tips := &sdkws.GroupApplicationAcceptedTips{Group: group, HandleMsg: req.HandledMsg, ReceiverAs: 1} - if err := g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID); err != nil { + err = g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID) + if err != nil { return err } for _, userID := range append(userIDs, mcontext.GetOpUserID(ctx)) { @@ -398,6 +400,7 @@ func (g *GroupNotificationSender) GroupApplicationAcceptedNotification(ctx conte log.ZError(ctx, "failed", err) } } + return nil } @@ -417,7 +420,8 @@ func (g *GroupNotificationSender) GroupApplicationRejectedNotification(ctx conte return err } tips := &sdkws.GroupApplicationRejectedTips{Group: group, HandleMsg: req.HandledMsg} - if err := g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID); err != nil { + err = g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID) + if err != nil { return err } for _, userID := range append(userIDs, mcontext.GetOpUserID(ctx)) { @@ -426,6 +430,7 @@ func (g *GroupNotificationSender) GroupApplicationRejectedNotification(ctx conte log.ZError(ctx, "failed", err) } } + return nil } @@ -449,6 +454,7 @@ func (g *GroupNotificationSender) GroupOwnerTransferredNotification(ctx context. if err := g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID); err != nil { return err } + return g.Notification(ctx, mcontext.GetOpUserID(ctx), group.GroupID, constant.GroupOwnerTransferredNotification, tips) } @@ -462,6 +468,7 @@ func (g *GroupNotificationSender) MemberKickedNotification(ctx context.Context, if err := g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID); err != nil { return err } + return g.Notification(ctx, mcontext.GetOpUserID(ctx), tips.Group.GroupID, constant.MemberKickedNotification, tips) } @@ -487,6 +494,7 @@ func (g *GroupNotificationSender) MemberInvitedNotification(ctx context.Context, if err := g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID); err != nil { return err } + return g.Notification(ctx, mcontext.GetOpUserID(ctx), group.GroupID, constant.MemberInvitedNotification, tips) } @@ -506,6 +514,7 @@ func (g *GroupNotificationSender) MemberEnterNotification(ctx context.Context, g return err } tips := &sdkws.MemberEnterTips{Group: group, EntrantUser: user} + return g.Notification(ctx, mcontext.GetOpUserID(ctx), group.GroupID, constant.MemberEnterNotification, tips) } @@ -519,6 +528,7 @@ func (g *GroupNotificationSender) GroupDismissedNotification(ctx context.Context if err := g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID); err != nil { return err } + return g.Notification(ctx, mcontext.GetOpUserID(ctx), tips.Group.GroupID, constant.GroupDismissedNotification, tips) } @@ -544,6 +554,7 @@ func (g *GroupNotificationSender) GroupMemberMutedNotification(ctx context.Conte if err := g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID); err != nil { return err } + return g.Notification(ctx, mcontext.GetOpUserID(ctx), group.GroupID, constant.GroupMemberMutedNotification, tips) } @@ -566,6 +577,7 @@ func (g *GroupNotificationSender) GroupMemberCancelMutedNotification(ctx context if err := g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID); err != nil { return err } + return g.Notification(ctx, mcontext.GetOpUserID(ctx), group.GroupID, constant.GroupMemberCancelMutedNotification, tips) } @@ -591,6 +603,7 @@ func (g *GroupNotificationSender) GroupMutedNotification(ctx context.Context, gr if err := g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID); err != nil { return err } + return g.Notification(ctx, mcontext.GetOpUserID(ctx), group.GroupID, constant.GroupMutedNotification, tips) } @@ -616,6 +629,7 @@ func (g *GroupNotificationSender) GroupCancelMutedNotification(ctx context.Conte if err := g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID); err != nil { return err } + return g.Notification(ctx, mcontext.GetOpUserID(ctx), group.GroupID, constant.GroupCancelMutedNotification, tips) } @@ -638,6 +652,7 @@ func (g *GroupNotificationSender) GroupMemberInfoSetNotification(ctx context.Con if err := g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID); err != nil { return err } + return g.Notification(ctx, mcontext.GetOpUserID(ctx), group.GroupID, constant.GroupMemberInfoSetNotification, tips) } @@ -660,6 +675,7 @@ func (g *GroupNotificationSender) GroupMemberSetToAdminNotification(ctx context. if err := g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID); err != nil { return err } + return g.Notification(ctx, mcontext.GetOpUserID(ctx), group.GroupID, constant.GroupMemberSetToAdminNotification, tips) } @@ -682,6 +698,7 @@ func (g *GroupNotificationSender) GroupMemberSetToOrdinaryUserNotification(ctx c if err := g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID); err != nil { return err } + return g.Notification(ctx, mcontext.GetOpUserID(ctx), group.GroupID, constant.GroupMemberSetToOrdinaryUserNotification, tips) } @@ -693,5 +710,6 @@ func (g *GroupNotificationSender) SuperGroupNotification(ctx context.Context, se } }() err = g.Notification(ctx, sendID, recvID, constant.SuperGroupUpdateNotification, nil) + return err } diff --git a/pkg/rpcclient/notification/msg.go b/pkg/rpcclient/notification/msg.go index 60fa64f40..6e367ac04 100644 --- a/pkg/rpcclient/notification/msg.go +++ b/pkg/rpcclient/notification/msg.go @@ -37,6 +37,7 @@ func (m *MsgNotificationSender) UserDeleteMsgsNotification(ctx context.Context, ConversationID: conversationID, Seqs: seqs, } + return m.Notification(ctx, userID, userID, constant.DeleteMsgsNotification, &tips) } @@ -47,5 +48,6 @@ func (m *MsgNotificationSender) MarkAsReadNotification(ctx context.Context, conv Seqs: seqs, HasReadSeq: hasReadSeq, } + return m.NotificationWithSesstionType(ctx, sendID, recvID, constant.HasReadReceipt, sesstionType, tips) } diff --git a/pkg/rpcclient/notification/user.go b/pkg/rpcclient/notification/user.go index 4feebf7b9..f6e592d18 100644 --- a/pkg/rpcclient/notification/user.go +++ b/pkg/rpcclient/notification/user.go @@ -52,6 +52,7 @@ func WithUserFunc( for _, user := range users { result = append(result, user) } + return result, nil } u.getUsersInfo = f @@ -68,34 +69,37 @@ func NewUserNotificationSender( for _, opt := range opts { opt(f) } + return f } -func (u *UserNotificationSender) getUsersInfoMap( - ctx context.Context, - userIDs []string, -) (map[string]*sdkws.UserInfo, error) { - users, err := u.getUsersInfo(ctx, userIDs) - if err != nil { - return nil, err - } - result := make(map[string]*sdkws.UserInfo) - for _, user := range users { - result[user.GetUserID()] = user.(*sdkws.UserInfo) - } - return result, nil -} +// func (u *UserNotificationSender) getUsersInfoMap( +// ctx context.Context, +// userIDs []string, +// ) (map[string]*sdkws.UserInfo, error) { +// users, err := u.getUsersInfo(ctx, userIDs) +// if err != nil { +// return nil, err +// } +// result := make(map[string]*sdkws.UserInfo) +// for _, user := range users { +// result[user.GetUserID()] = user.(*sdkws.UserInfo) +// } -func (u *UserNotificationSender) getFromToUserNickname( - ctx context.Context, - fromUserID, toUserID string, -) (string, string, error) { - users, err := u.getUsersInfoMap(ctx, []string{fromUserID, toUserID}) - if err != nil { - return "", "", nil - } - return users[fromUserID].Nickname, users[toUserID].Nickname, nil -} +// return result, nil +// } + +// func (u *UserNotificationSender) getFromToUserNickname( +// ctx context.Context, +// fromUserID, toUserID string, +// ) (string, string, error) { +// users, err := u.getUsersInfoMap(ctx, []string{fromUserID, toUserID}) +// if err != nil { +// return "", "", err +// } + +// return users[fromUserID].Nickname, users[toUserID].Nickname, nil +// } func (u *UserNotificationSender) UserStatusChangeNotification( ctx context.Context, diff --git a/pkg/rpcclient/push.go b/pkg/rpcclient/push.go index 6d0876972..7733572bf 100644 --- a/pkg/rpcclient/push.go +++ b/pkg/rpcclient/push.go @@ -36,6 +36,7 @@ func NewPush(discov discoveryregistry.SvcDiscoveryRegistry) *Push { if err != nil { panic(err) } + return &Push{ discov: discov, conn: conn, diff --git a/pkg/rpcclient/third.go b/pkg/rpcclient/third.go old mode 100644 new mode 100755 index 48a537112..2bb761450 --- a/pkg/rpcclient/third.go +++ b/pkg/rpcclient/third.go @@ -42,14 +42,13 @@ func NewThird(discov discoveryregistry.SvcDiscoveryRegistry) *Third { panic(err) } client := third.NewThirdClient(conn) - minioClient, err := minioInit() + minioClient, _ := minioInit() + return &Third{discov: discov, Client: client, conn: conn, MinioClient: minioClient} } func minioInit() (*minio.Client, error) { - minioClient := &minio.Client{} - var initUrl string - initUrl = config.Config.Object.Minio.Endpoint + initUrl := config.Config.Object.Minio.Endpoint minioUrl, err := url.Parse(initUrl) if err != nil { return nil, err @@ -63,9 +62,11 @@ func minioInit() (*minio.Client, error) { } else if minioUrl.Scheme == "https" { opts.Secure = true } + var minioClient *minio.Client minioClient, err = minio.New(minioUrl.Host, opts) if err != nil { return nil, err } + return minioClient, nil } diff --git a/pkg/rpcclient/user.go b/pkg/rpcclient/user.go index c40d95727..dfd93fb0b 100644 --- a/pkg/rpcclient/user.go +++ b/pkg/rpcclient/user.go @@ -45,6 +45,7 @@ func NewUser(discov discoveryregistry.SvcDiscoveryRegistry) *User { panic(err) } client := user.NewUserClient(conn) + return &User{Discov: discov, Client: client, conn: conn} } @@ -54,6 +55,7 @@ type UserRpcClient User // NewUserRpcClientByUser initializes a UserRpcClient based on the provided User instance. func NewUserRpcClientByUser(user *User) *UserRpcClient { rpc := UserRpcClient(*user) + return &rpc } @@ -75,6 +77,7 @@ func (u *UserRpcClient) GetUsersInfo(ctx context.Context, userIDs []string) ([]* })); len(ids) > 0 { return nil, errs.ErrUserIDNotFound.Wrap(strings.Join(ids, ",")) } + return resp.UsersInfo, nil } @@ -84,6 +87,7 @@ func (u *UserRpcClient) GetUserInfo(ctx context.Context, userID string) (*sdkws. if err != nil { return nil, err } + return users[0], nil } @@ -93,6 +97,7 @@ func (u *UserRpcClient) GetUsersInfoMap(ctx context.Context, userIDs []string) ( if err != nil { return nil, err } + return utils.SliceToMap(users, func(e *sdkws.UserInfo) string { return e.UserID }), nil @@ -108,6 +113,7 @@ func (u *UserRpcClient) GetPublicUserInfos( if err != nil { return nil, err } + return utils.Slice(users, func(e *sdkws.UserInfo) *sdkws.PublicUserInfo { return &sdkws.PublicUserInfo{ UserID: e.UserID, @@ -124,6 +130,7 @@ func (u *UserRpcClient) GetPublicUserInfo(ctx context.Context, userID string) (* if err != nil { return nil, err } + return users[0], nil } @@ -137,6 +144,7 @@ func (u *UserRpcClient) GetPublicUserInfoMap( if err != nil { return nil, err } + return utils.SliceToMap(users, func(e *sdkws.PublicUserInfo) string { return e.UserID }), nil @@ -150,6 +158,7 @@ func (u *UserRpcClient) GetUserGlobalMsgRecvOpt(ctx context.Context, userID stri if err != nil { return 0, err } + return resp.GlobalRecvMsgOpt, nil } @@ -159,6 +168,7 @@ func (u *UserRpcClient) Access(ctx context.Context, ownerUserID string) error { if err != nil { return err } + return authverify.CheckAccessV3(ctx, ownerUserID) } @@ -168,6 +178,7 @@ func (u *UserRpcClient) GetAllUserIDs(ctx context.Context, pageNumber, showNumbe if err != nil { return nil, err } + return resp.UserIDs, nil } @@ -177,5 +188,6 @@ func (u *UserRpcClient) SetUserStatus(ctx context.Context, userID string, status UserID: userID, Status: status, PlatformID: int32(platformID), }) + return err } diff --git a/pkg/statistics/statistics.go b/pkg/statistics/statistics.go index de6d04fec..080933c73 100644 --- a/pkg/statistics/statistics.go +++ b/pkg/statistics/statistics.go @@ -36,9 +36,8 @@ func (s *Statistics) output() { var timeIntervalNum uint64 for { sum = *s.AllCount - select { - case <-t.C: - } + <-t.C + if *s.AllCount-sum <= 0 { intervalCount = 0 } else { @@ -66,5 +65,6 @@ func (s *Statistics) output() { func NewStatistics(allCount *uint64, moduleName, printArgs string, sleepTime int) *Statistics { p := &Statistics{AllCount: allCount, ModuleName: moduleName, SleepTime: uint64(sleepTime), PrintArgs: printArgs} go p.output() + return p }