diff --git a/internal/common/check/friend.go b/internal/common/check/friend.go index 4b120096d..622b4fd9d 100644 --- a/internal/common/check/friend.go +++ b/internal/common/check/friend.go @@ -25,7 +25,7 @@ func (f *FriendChecker) GetFriendsInfo(ctx context.Context, ownerUserID, friendU if err != nil { return nil, err } - r, err := friend.NewFriendClient(cc).GetPaginationFriends(ctx, &friend.GetPaginationFriendsReq{OwnerUserID: ownerUserID, FriendUserIDs: []string{friendUserID}}) + r, err := friend.NewFriendClient(cc).GetDesignatedFriends(ctx, &friend.GetDesignatedFriendsReq{OwnerUserID: ownerUserID, FriendUserIDs: []string{friendUserID}}) if err != nil { return nil, err } @@ -67,7 +67,7 @@ func (f *FriendChecker) GetAllPageFriends(ctx context.Context, ownerUserID strin if tmp.Total == int32(len(resp)) { return resp, nil } - return nil, constant.ErrData.Wrap("total != resp, but result is nil") + return nil, constant.ErrData.Wrap("The total number of results and expectations are different, but result is nil") } resp = append(resp, tmp.FriendsInfo...) page++ diff --git a/internal/common/check/user.go b/internal/common/check/user.go index f5328f0aa..66b2a8ce9 100644 --- a/internal/common/check/user.go +++ b/internal/common/check/user.go @@ -100,5 +100,12 @@ func (u *UserCheck) GetPublicUserInfoMap(ctx context.Context, userIDs []string, } func (u *UserCheck) GetUserGlobalMsgRecvOpt(ctx context.Context, userID string) (int32, error) { - return 0, nil + cc, err := u.getConn() + if err != nil { + return 0, err + } + resp, err := user.NewUserClient(cc).GetGlobalRecvMessageOpt(ctx, &user.GetGlobalRecvMessageOptReq{ + UserID: userID, + }) + return resp.GlobalRecvMsgOpt, err } diff --git a/internal/rpc/friend/friend.go b/internal/rpc/friend/friend.go index 11b776475..dd819967f 100644 --- a/internal/rpc/friend/friend.go +++ b/internal/rpc/friend/friend.go @@ -4,6 +4,7 @@ import ( "Open_IM/internal/common/check" "Open_IM/internal/common/convert" "Open_IM/internal/common/notification" + "Open_IM/internal/tx" "Open_IM/pkg/common/constant" "Open_IM/pkg/common/db/controller" "Open_IM/pkg/common/db/relation" @@ -27,16 +28,16 @@ type friendServer struct { } func Start(client *openKeeper.ZkClient, server *grpc.Server) error { - mysql, err := relation.NewGormDB() + db, err := relation.NewGormDB() if err != nil { return err } - if err := mysql.AutoMigrate(&tablerelation.FriendModel{}, &tablerelation.FriendRequestModel{}, &tablerelation.BlackModel{}); err != nil { + if err := db.AutoMigrate(&tablerelation.FriendModel{}, &tablerelation.FriendRequestModel{}, &tablerelation.BlackModel{}); err != nil { return err } pbfriend.RegisterFriendServer(server, &friendServer{ - FriendInterface: controller.NewFriendController(mysql), - BlackInterface: controller.NewBlackController(mysql), + FriendInterface: controller.NewFriendController(controller.NewFriendDatabase(relation.NewFriendGorm(db), relation.NewFriendRequestGorm(db), tx.NewGorm(db))), + BlackInterface: controller.NewBlackController(controller.NewBlackDatabase(relation.NewBlackGorm(db))), notification: notification.NewCheck(client), userCheck: check.NewUserCheck(client), RegisterCenter: client, diff --git a/internal/rpc/msg/send_msg.go b/internal/rpc/msg/send_msg.go index 150da4ae1..8cc24550d 100644 --- a/internal/rpc/msg/send_msg.go +++ b/internal/rpc/msg/send_msg.go @@ -165,14 +165,14 @@ func (m *msgServer) messageVerification(ctx context.Context, data *msg.SendMsgRe } if revokeMessage.RevokerID != revokeMessage.SourceMessageSendID { - resp, err := m.MsgInterface.GetSuperGroupMsg(ctx, data.MsgData.GroupID, revokeMessage.Seq) + resp, err := m.MsgInterface.GetSuperGroupMsgBySeqs(ctx, data.MsgData.GroupID, []int64{int64(revokeMessage.Seq)}) if err != nil { return nil, err } - if resp.ClientMsgID == revokeMessage.ClientMsgID && resp.Seq == revokeMessage.Seq { - revokeMessage.SourceMessageSendTime = resp.SendTime - revokeMessage.SourceMessageSenderNickname = resp.SenderNickname - revokeMessage.SourceMessageSendID = resp.SendID + if resp[0].ClientMsgID == revokeMessage.ClientMsgID && resp[0].Seq == int64(revokeMessage.Seq) { + revokeMessage.SourceMessageSendTime = resp[0].SendTime + revokeMessage.SourceMessageSenderNickname = resp[0].SenderNickname + revokeMessage.SourceMessageSendID = resp[0].SendID data.MsgData.Content = []byte(utils.StructToJsonString(revokeMessage)) } else { return nil, constant.ErrData.Wrap("MsgData") diff --git a/internal/rpc/msg/send_pull.go b/internal/rpc/msg/send_pull.go index 09502d4f3..cba7e619c 100644 --- a/internal/rpc/msg/send_pull.go +++ b/internal/rpc/msg/send_pull.go @@ -67,7 +67,7 @@ func (m *msgServer) sendMsgSingleChat(ctx context.Context, req *msg.SendMsgReq) if err != nil { return nil, err } - isSend, err := modifyMessageByUserMessageReceiveOpt(req.MsgData.RecvID, req.MsgData.SendID, constant.SingleChatType, req) + isSend, err := m.modifyMessageByUserMessageReceiveOpt(ctx, req.MsgData.RecvID, req.MsgData.SendID, constant.SingleChatType, req) if err != nil { return nil, err } diff --git a/internal/rpc/user/user.go b/internal/rpc/user/user.go index e81653780..d29dd01c8 100644 --- a/internal/rpc/user/user.go +++ b/internal/rpc/user/user.go @@ -4,6 +4,7 @@ import ( "Open_IM/internal/common/check" "Open_IM/internal/common/convert" "Open_IM/internal/common/notification" + "Open_IM/pkg/common/config" "Open_IM/pkg/common/constant" "Open_IM/pkg/common/db/controller" "Open_IM/pkg/common/db/relation" @@ -29,19 +30,28 @@ type userServer struct { } func Start(client *openKeeper.ZkClient, server *grpc.Server) error { - mysql, err := relation.NewGormDB() + gormDB, err := relation.NewGormDB() if err != nil { return err } - if err := mysql.AutoMigrate(&tablerelation.UserModel{}); err != nil { + if err := gormDB.AutoMigrate(&tablerelation.UserModel{}); err != nil { return err } - pbuser.RegisterUserServer(server, &userServer{ - UserInterface: controller.NewUserController(mysql), + u := &userServer{ + UserInterface: controller.NewUserController(controller.NewUserDatabase(relation.NewUserGorm(gormDB))), notification: notification.NewCheck(client), userCheck: check.NewUserCheck(client), RegisterCenter: client, - }) + } + pbuser.RegisterUserServer(server, u) + users := make([]*tablerelation.UserModel, 0) + if len(config.Config.Manager.AppManagerUid) != len(config.Config.Manager.Nickname) { + return constant.ErrConfig.Wrap("len(config.Config.Manager.AppManagerUid) != len(config.Config.Manager.Nickname)") + } + for k, v := range config.Config.Manager.AppManagerUid { + users = append(users, &tablerelation.UserModel{UserID: v, Nickname: config.Config.Manager.Nickname[k]}) + } + u.UserInterface.InitOnce(context.Background(), users) return nil } diff --git a/pkg/common/config/config.go b/pkg/common/config/config.go index ecf6b7774..b137e96ba 100644 --- a/pkg/common/config/config.go +++ b/pkg/common/config/config.go @@ -213,9 +213,9 @@ type config struct { } } Manager struct { - AppManagerUid []string `yaml:"appManagerUid"` - Secrets []string `yaml:"secrets"` - AppSysNotificationName string `yaml:"appSysNotificationName"` + AppManagerUid []string `yaml:"appManagerUid"` + // AppSysNotificationName string `yaml:"appSysNotificationName"` + Nickname []string `yaml:"nickname"` } Kafka struct { diff --git a/pkg/common/constant/constant.go b/pkg/common/constant/constant.go index 521fbd09e..ced99672a 100644 --- a/pkg/common/constant/constant.go +++ b/pkg/common/constant/constant.go @@ -319,8 +319,4 @@ const BigVersion = "v2" const LogFileName = "OpenIM.log" -const StatisticsTimeInterval = 60 - -const MaxNotificationNum = 500 - const CurrentVersion = "v2.3.4-rc0" diff --git a/pkg/common/constant/errors.go b/pkg/common/constant/errors.go index 6bf324906..03b45fa2f 100644 --- a/pkg/common/constant/errors.go +++ b/pkg/common/constant/errors.go @@ -58,6 +58,8 @@ var ( ErrConnArgsErr = &ErrInfo{ConnArgsErr, "args err, need token, sendID, platformID", ""} ErrConnUpdateErr = &ErrInfo{ConnArgsErr, "upgrade http conn err", ""} + + ErrConfig = &ErrInfo{ConfigError, "ConfigError", ""} ) const ( @@ -91,6 +93,8 @@ const ( DataError = 90007 //数据错误 IdentityError = 90008 // 身份错误 非管理员token,且token中userID与请求userID不一致 + + ConfigError = 90009 ) // 账号错误码 diff --git a/pkg/common/constant/limit.go b/pkg/common/constant/limit.go index cdd19e240..57e764eef 100644 --- a/pkg/common/constant/limit.go +++ b/pkg/common/constant/limit.go @@ -1,5 +1,7 @@ package constant const ( - ShowNumber = 1000 + ShowNumber = 1000 + StatisticsTimeInterval = 60 + MaxNotificationNum = 500 ) diff --git a/pkg/common/db/controller/black.go b/pkg/common/db/controller/black.go index 98493fd3a..6a761c5c5 100644 --- a/pkg/common/db/controller/black.go +++ b/pkg/common/db/controller/black.go @@ -1,7 +1,6 @@ package controller import ( - relation2 "Open_IM/pkg/common/db/relation" "Open_IM/pkg/common/db/table/relation" "Open_IM/pkg/utils" "context" @@ -24,8 +23,8 @@ type BlackController struct { database BlackDatabaseInterface } -func NewBlackController(db *gorm.DB) *BlackController { - return &BlackController{database: NewBlackDatabase(db)} +func NewBlackController(database BlackDatabaseInterface) BlackInterface { + return &BlackController{database: database} } // Create 增加黑名单 @@ -60,35 +59,31 @@ type BlackDatabaseInterface interface { } type BlackDatabase struct { - sqlDB *relation2.BlackGorm + black relation.BlackModelInterface } -func NewBlackDatabase(db *gorm.DB) *BlackDatabase { - sqlDB := relation2.NewBlackGorm(db) - database := &BlackDatabase{ - sqlDB: sqlDB, - } - return database +func NewBlackDatabase(black relation.BlackModelInterface) *BlackDatabase { + return &BlackDatabase{black} } // Create 增加黑名单 func (b *BlackDatabase) Create(ctx context.Context, blacks []*relation.BlackModel) (err error) { - return b.sqlDB.Create(ctx, blacks) + return b.black.Create(ctx, blacks) } // Delete 删除黑名单 func (b *BlackDatabase) Delete(ctx context.Context, blacks []*relation.BlackModel) (err error) { - return b.sqlDB.Delete(ctx, blacks) + return b.black.Delete(ctx, blacks) } // FindOwnerBlacks 获取黑名单列表 func (b *BlackDatabase) FindOwnerBlacks(ctx context.Context, ownerUserID string, pageNumber, showNumber int32) (blacks []*relation.BlackModel, total int64, err error) { - return b.sqlDB.FindOwnerBlacks(ctx, ownerUserID, pageNumber, showNumber) + return b.black.FindOwnerBlacks(ctx, ownerUserID, pageNumber, showNumber) } // CheckIn 检查user2是否在user1的黑名单列表中(inUser1Blacks==true) 检查user1是否在user2的黑名单列表中(inUser2Blacks==true) func (b *BlackDatabase) CheckIn(ctx context.Context, userID1, userID2 string) (inUser1Blacks bool, inUser2Blacks bool, err error) { - _, err = b.sqlDB.Take(ctx, userID1, userID2) + _, err = b.black.Take(ctx, userID1, userID2) if err != nil { if errors.Unwrap(err) != gorm.ErrRecordNotFound { return @@ -99,7 +94,7 @@ func (b *BlackDatabase) CheckIn(ctx context.Context, userID1, userID2 string) (i } inUser2Blacks = true - _, err = b.sqlDB.Take(ctx, userID2, userID1) + _, err = b.black.Take(ctx, userID2, userID1) if err != nil { if utils.Unwrap(err) != gorm.ErrRecordNotFound { return diff --git a/pkg/common/db/controller/friend.go b/pkg/common/db/controller/friend.go index ea322a9d3..1b38cf634 100644 --- a/pkg/common/db/controller/friend.go +++ b/pkg/common/db/controller/friend.go @@ -1,8 +1,8 @@ package controller import ( + "Open_IM/internal/tx" "Open_IM/pkg/common/constant" - relation1 "Open_IM/pkg/common/db/relation" "Open_IM/pkg/common/db/table/relation" "Open_IM/pkg/utils" "context" @@ -41,8 +41,8 @@ type FriendController struct { database FriendDatabaseInterface } -func NewFriendController(db *gorm.DB) *FriendController { - return &FriendController{database: NewFriendDatabase(db)} +func NewFriendController(database FriendDatabaseInterface) FriendInterface { + return &FriendController{database: database} } // 检查user2是否在user1的好友列表中(inUser1Friends==true) 检查user1是否在user2的好友列表中(inUser2Friends==true) @@ -133,12 +133,13 @@ type FriendDatabaseInterface interface { } type FriendDatabase struct { - friend *relation1.FriendGorm - friendRequest *relation1.FriendRequestGorm + friend relation.FriendModelInterface + friendRequest relation.FriendRequestModelInterface + tx tx.Tx } -func NewFriendDatabase(db *gorm.DB) *FriendDatabase { - return &FriendDatabase{friend: relation1.NewFriendGorm(db), friendRequest: relation1.NewFriendRequestGorm(db)} +func NewFriendDatabase(friend relation.FriendModelInterface, friendRequest relation.FriendRequestModelInterface, tx tx.Tx) *FriendDatabase { + return &FriendDatabase{friend: friend, friendRequest: friendRequest, tx: tx} } // ok 检查user2是否在user1的好友列表中(inUser1Friends==true) 检查user1是否在user2的好友列表中(inUser2Friends==true) @@ -160,8 +161,8 @@ func (f *FriendDatabase) CheckIn(ctx context.Context, userID1, userID2 string) ( // 增加或者更新好友申请 如果之前有记录则更新,没有记录则新增 func (f *FriendDatabase) AddFriendRequest(ctx context.Context, fromUserID, toUserID string, reqMsg string, ex string) (err error) { - return f.friendRequest.DB.Transaction(func(tx *gorm.DB) error { - _, err := f.friendRequest.Take(ctx, fromUserID, toUserID, tx) + return f.tx.Transaction(func(tx any) error { + _, err := f.friendRequest.NewTx(tx).Take(ctx, fromUserID, toUserID) //有db错误 if err != nil && errors.Unwrap(err) != gorm.ErrRecordNotFound { return err @@ -173,13 +174,13 @@ func (f *FriendDatabase) AddFriendRequest(ctx context.Context, fromUserID, toUse m["handle_msg"] = "" m["req_msg"] = reqMsg m["ex"] = ex - if err := f.friendRequest.UpdateByMap(ctx, fromUserID, toUserID, m, tx); err != nil { + if err := f.friendRequest.NewTx(tx).UpdateByMap(ctx, fromUserID, toUserID, m); err != nil { return err } return nil } //gorm.ErrRecordNotFound 错误,则新增 - if err := f.friendRequest.Create(ctx, []*relation.FriendRequestModel{&relation.FriendRequestModel{FromUserID: fromUserID, ToUserID: toUserID, ReqMsg: reqMsg, Ex: ex}}, tx); err != nil { + if err := f.friendRequest.NewTx(tx).Create(ctx, []*relation.FriendRequestModel{&relation.FriendRequestModel{FromUserID: fromUserID, ToUserID: toUserID, ReqMsg: reqMsg, Ex: ex}}); err != nil { return err } return nil @@ -188,9 +189,9 @@ func (f *FriendDatabase) AddFriendRequest(ctx context.Context, fromUserID, toUse // (1)先判断是否在好友表 (在不在都不返回错误) (2)对于不在好友列表的 插入即可 func (f *FriendDatabase) BecomeFriends(ctx context.Context, ownerUserID string, friendUserIDs []string, addSource int32, OperatorUserID string) (err error) { - return f.friend.DB.Transaction(func(tx *gorm.DB) error { + return f.tx.Transaction(func(tx any) error { //先find 找出重复的 去掉重复的 - fs1, err := f.friend.FindFriends(ctx, ownerUserID, friendUserIDs, tx) + fs1, err := f.friend.NewTx(tx).FindFriends(ctx, ownerUserID, friendUserIDs) if err != nil { return err } @@ -201,12 +202,12 @@ func (f *FriendDatabase) BecomeFriends(ctx context.Context, ownerUserID string, return e.FriendUserID }) - err = f.friend.Create(ctx, fs11, tx) + err = f.friend.NewTx(tx).Create(ctx, fs11) if err != nil { return err } - fs2, err := f.friend.FindReversalFriends(ctx, ownerUserID, friendUserIDs, tx) + fs2, err := f.friend.NewTx(tx).FindReversalFriends(ctx, ownerUserID, friendUserIDs) if err != nil { return err } @@ -216,7 +217,7 @@ func (f *FriendDatabase) BecomeFriends(ctx context.Context, ownerUserID string, fs22 := utils.DistinctAny(fs2, func(e *relation.FriendModel) string { return e.OwnerUserID }) - err = f.friend.Create(ctx, fs22, tx) + err = f.friend.NewTx(tx).Create(ctx, fs22) if err != nil { return err } @@ -240,14 +241,14 @@ func (f *FriendDatabase) RefuseFriendRequest(ctx context.Context, friendRequest // 同意好友申请 (1)检查是否有申请记录且为未处理状态 (没有记录返回错误) (2)检查是否好友(不返回错误) (3) 不是好友则建立双向好友关系 (4)修改申请记录 已同意 func (f *FriendDatabase) AgreeFriendRequest(ctx context.Context, friendRequest *relation.FriendRequestModel) (err error) { - return f.friend.DB.Transaction(func(tx *gorm.DB) error { - _, err = f.friendRequest.Take(ctx, friendRequest.FromUserID, friendRequest.ToUserID) + return f.tx.Transaction(func(tx any) error { + _, err = f.friendRequest.NewTx(tx).Take(ctx, friendRequest.FromUserID, friendRequest.ToUserID) if err != nil { return err } friendRequest.HandlerUserID = friendRequest.FromUserID friendRequest.HandleResult = constant.FriendResponseAgree - err = f.friendRequest.Update(ctx, []*relation.FriendRequestModel{friendRequest}, tx) + err = f.friendRequest.NewTx(tx).Update(ctx, []*relation.FriendRequestModel{friendRequest}) if err != nil { return err } @@ -257,7 +258,7 @@ func (f *FriendDatabase) AgreeFriendRequest(ctx context.Context, friendRequest * addSource := int32(constant.BecomeFriendByApply) OperatorUserID := friendRequest.FromUserID //先find 找出重复的 去掉重复的 - fs1, err := f.friend.FindFriends(ctx, ownerUserID, friendUserIDs, tx) + fs1, err := f.friend.NewTx(tx).FindFriends(ctx, ownerUserID, friendUserIDs) if err != nil { return err } @@ -268,12 +269,12 @@ func (f *FriendDatabase) AgreeFriendRequest(ctx context.Context, friendRequest * return e.FriendUserID }) - err = f.friend.Create(ctx, fs11, tx) + err = f.friend.NewTx(tx).Create(ctx, fs11) if err != nil { return err } - fs2, err := f.friend.FindReversalFriends(ctx, ownerUserID, friendUserIDs, tx) + fs2, err := f.friend.NewTx(tx).FindReversalFriends(ctx, ownerUserID, friendUserIDs) if err != nil { return err } @@ -283,7 +284,7 @@ func (f *FriendDatabase) AgreeFriendRequest(ctx context.Context, friendRequest * fs22 := utils.DistinctAny(fs2, func(e *relation.FriendModel) string { return e.OwnerUserID }) - err = f.friend.Create(ctx, fs22, tx) + err = f.friend.NewTx(tx).Create(ctx, fs22) if err != nil { return err } diff --git a/pkg/common/db/controller/msg.go b/pkg/common/db/controller/msg.go index 307244349..e1c45dbfb 100644 --- a/pkg/common/db/controller/msg.go +++ b/pkg/common/db/controller/msg.go @@ -33,7 +33,7 @@ type MsgInterface interface { DelMsgBySeqs(ctx context.Context, userID string, seqs []int64) (totalUnExistSeqs []int64, err error) // 通过seqList获取db中写扩散消息 GetMsgBySeqs(ctx context.Context, userID string, seqs []int64) (seqMsg []*sdkws.MsgData, err error) - // 通过seqList获取大群在db里面的消息 + // 通过seqList获取大群在db里面的消息 没找到返回错误 GetSuperGroupMsgBySeqs(ctx context.Context, groupID string, seqs []int64) (seqMsg []*sdkws.MsgData, err error) // 删除用户所有消息/cache/db然后重置seq CleanUpUserMsg(ctx context.Context, userID string) error @@ -49,6 +49,8 @@ type MsgInterface interface { SetGroupUserMinSeq(ctx context.Context, groupID, userID string, minSeq int64) (err error) // 设置用户最小seq 直接调用cache SetUserMinSeq(ctx context.Context, userID string, minSeq int64) (err error) + + MsgToMQ(ctx context.Context, key string, data *pbMsg.MsgDataToMQ) (err error) } func NewMsgController(mgo *mongo.Client, rdb redis.UniversalClient) MsgInterface { diff --git a/pkg/common/db/controller/user.go b/pkg/common/db/controller/user.go index 0cdab97e1..d36100898 100644 --- a/pkg/common/db/controller/user.go +++ b/pkg/common/db/controller/user.go @@ -2,29 +2,30 @@ package controller import ( "Open_IM/pkg/common/constant" - "Open_IM/pkg/common/db/relation" - relationTb "Open_IM/pkg/common/db/table/relation" + "Open_IM/pkg/common/db/table/relation" + "Open_IM/pkg/utils" "context" - "gorm.io/gorm" ) type UserInterface interface { //获取指定用户的信息 如有userID未找到 也返回错误 - FindWithError(ctx context.Context, userIDs []string) (users []*relationTb.UserModel, err error) + FindWithError(ctx context.Context, userIDs []string) (users []*relation.UserModel, err error) //获取指定用户的信息 如有userID未找到 不返回错误 - Find(ctx context.Context, userIDs []string) (users []*relationTb.UserModel, err error) + Find(ctx context.Context, userIDs []string) (users []*relation.UserModel, err error) //插入多条 外部保证userID 不重复 且在db中不存在 - Create(ctx context.Context, users []*relationTb.UserModel) (err error) + Create(ctx context.Context, users []*relation.UserModel) (err error) //更新(非零值) 外部保证userID存在 - Update(ctx context.Context, users []*relationTb.UserModel) (err error) + Update(ctx context.Context, users []*relation.UserModel) (err error) //更新(零值) 外部保证userID存在 UpdateByMap(ctx context.Context, userID string, args map[string]interface{}) (err error) //如果没找到,不返回错误 - Page(ctx context.Context, pageNumber, showNumber int32) (users []*relationTb.UserModel, count int64, err error) + Page(ctx context.Context, pageNumber, showNumber int32) (users []*relation.UserModel, count int64, err error) //只要有一个存在就为true IsExist(ctx context.Context, userIDs []string) (exist bool, err error) //获取所有用户ID GetAllUserID(ctx context.Context) ([]string, error) + //函数内部先查询db中是否存在,存在则什么都不做;不存在则插入 + InitOnce(ctx context.Context, users []*relation.UserModel) (err error) } type UserController struct { @@ -32,25 +33,25 @@ type UserController struct { } // 获取指定用户的信息 如有userID未找到 也返回错误 -func (u *UserController) FindWithError(ctx context.Context, userIDs []string) (users []*relationTb.UserModel, err error) { +func (u *UserController) FindWithError(ctx context.Context, userIDs []string) (users []*relation.UserModel, err error) { return u.database.FindWithError(ctx, userIDs) } -func (u *UserController) Find(ctx context.Context, userIDs []string) (users []*relationTb.UserModel, err error) { +func (u *UserController) Find(ctx context.Context, userIDs []string) (users []*relation.UserModel, err error) { return u.database.Find(ctx, userIDs) } -func (u *UserController) Create(ctx context.Context, users []*relationTb.UserModel) error { +func (u *UserController) Create(ctx context.Context, users []*relation.UserModel) error { return u.database.Create(ctx, users) } -func (u *UserController) Update(ctx context.Context, users []*relationTb.UserModel) (err error) { +func (u *UserController) Update(ctx context.Context, users []*relation.UserModel) (err error) { return u.database.Update(ctx, users) } func (u *UserController) UpdateByMap(ctx context.Context, userID string, args map[string]interface{}) (err error) { return u.database.UpdateByMap(ctx, userID, args) } -func (u *UserController) Page(ctx context.Context, pageNumber, showNumber int32) (users []*relationTb.UserModel, count int64, err error) { +func (u *UserController) Page(ctx context.Context, pageNumber, showNumber int32) (users []*relation.UserModel, count int64, err error) { return u.database.Page(ctx, pageNumber, showNumber) } @@ -62,45 +63,58 @@ func (u *UserController) GetAllUserID(ctx context.Context) ([]string, error) { return u.database.GetAllUserID(ctx) } -func NewUserController(db *gorm.DB) *UserController { - controller := &UserController{database: newUserDatabase(db)} - return controller +func (u *UserController) InitOnce(ctx context.Context, users []*relation.UserModel) (err error) { + return u.database.InitOnce(ctx, users) +} + +func NewUserController(database UserDatabaseInterface) UserInterface { + return &UserController{database} } type UserDatabaseInterface interface { //获取指定用户的信息 如有userID未找到 也返回错误 - FindWithError(ctx context.Context, userIDs []string) (users []*relationTb.UserModel, err error) + FindWithError(ctx context.Context, userIDs []string) (users []*relation.UserModel, err error) //获取指定用户的信息 如有userID未找到 不返回错误 - Find(ctx context.Context, userIDs []string) (users []*relationTb.UserModel, err error) + Find(ctx context.Context, userIDs []string) (users []*relation.UserModel, err error) //插入多条 外部保证userID 不重复 且在db中不存在 - Create(ctx context.Context, users []*relationTb.UserModel) (err error) + Create(ctx context.Context, users []*relation.UserModel) (err error) //更新(非零值) 外部保证userID存在 - Update(ctx context.Context, users []*relationTb.UserModel) (err error) + Update(ctx context.Context, users []*relation.UserModel) (err error) //更新(零值) 外部保证userID存在 UpdateByMap(ctx context.Context, userID string, args map[string]interface{}) (err error) //如果没找到,不返回错误 - Page(ctx context.Context, pageNumber, showNumber int32) (users []*relationTb.UserModel, count int64, err error) + Page(ctx context.Context, pageNumber, showNumber int32) (users []*relation.UserModel, count int64, err error) //只要有一个存在就为true IsExist(ctx context.Context, userIDs []string) (exist bool, err error) //获取所有用户ID GetAllUserID(ctx context.Context) ([]string, error) + //函数内部先查询db中是否存在,存在则什么都不做;不存在则插入 + InitOnce(ctx context.Context, users []*relation.UserModel) (err error) } type UserDatabase struct { - user *relation.UserGorm + userDB relation.UserModelInterface } -func newUserDatabase(db *gorm.DB) *UserDatabase { - sqlDB := relation.NewUserGorm(db) - database := &UserDatabase{ - user: sqlDB, +func NewUserDatabase(userDB relation.UserModelInterface) *UserDatabase { + return &UserDatabase{userDB: userDB} +} + +func (u *UserDatabase) InitOnce(ctx context.Context, users []*relation.UserModel) (err error) { + userIDs := utils.Slice(users, func(e *relation.UserModel) string { + return e.UserID + }) + result, err := u.userDB.Find(ctx, userIDs) + if err != nil { + return err } - return database + } // 获取指定用户的信息 如有userID未找到 也返回错误 -func (u *UserDatabase) FindWithError(ctx context.Context, userIDs []string) (users []*relationTb.UserModel, err error) { - users, err = u.user.Find(ctx, userIDs) +func (u *UserDatabase) FindWithError(ctx context.Context, userIDs []string) (users []*relation.UserModel, err error) { + + users, err = u.userDB.Find(ctx, userIDs) if err != nil { return } @@ -111,34 +125,34 @@ func (u *UserDatabase) FindWithError(ctx context.Context, userIDs []string) (use } // 获取指定用户的信息 如有userID未找到 不返回错误 -func (u *UserDatabase) Find(ctx context.Context, userIDs []string) (users []*relationTb.UserModel, err error) { - users, err = u.user.Find(ctx, userIDs) +func (u *UserDatabase) Find(ctx context.Context, userIDs []string) (users []*relation.UserModel, err error) { + users, err = u.userDB.Find(ctx, userIDs) return } // 插入多条 外部保证userID 不重复 且在db中不存在 -func (u *UserDatabase) Create(ctx context.Context, users []*relationTb.UserModel) (err error) { - return u.user.Create(ctx, users) +func (u *UserDatabase) Create(ctx context.Context, users []*relation.UserModel) (err error) { + return u.userDB.Create(ctx, users) } // 更新(非零值) 外部保证userID存在 -func (u *UserDatabase) Update(ctx context.Context, users []*relationTb.UserModel) (err error) { - return u.user.Update(ctx, users) +func (u *UserDatabase) Update(ctx context.Context, users []*relation.UserModel) (err error) { + return u.userDB.Update(ctx, users) } // 更新(零值) 外部保证userID存在 func (u *UserDatabase) UpdateByMap(ctx context.Context, userID string, args map[string]interface{}) (err error) { - return u.user.UpdateByMap(ctx, userID, args) + return u.userDB.UpdateByMap(ctx, userID, args) } // 获取,如果没找到,不返回错误 -func (u *UserDatabase) Page(ctx context.Context, showNumber, pageNumber int32) (users []*relationTb.UserModel, count int64, err error) { - return u.user.Page(ctx, showNumber, pageNumber) +func (u *UserDatabase) Page(ctx context.Context, pageNumber, showNumber int32) (users []*relation.UserModel, count int64, err error) { + return u.userDB.Page(ctx, pageNumber, showNumber) } // userIDs是否存在 只要有一个存在就为true func (u *UserDatabase) IsExist(ctx context.Context, userIDs []string) (exist bool, err error) { - users, err := u.user.Find(ctx, userIDs) + users, err := u.userDB.Find(ctx, userIDs) if err != nil { return false, err } @@ -148,6 +162,21 @@ func (u *UserDatabase) IsExist(ctx context.Context, userIDs []string) (exist boo return false, nil } -func (u *UserDatabase) GetAllUserID(ctx context.Context) ([]string, error) { - return u.user.GetAllUserID(ctx) +func (u *UserDatabase) GetAllUserID(ctx context.Context) (userIDs []string, err error) { + pageNumber := int32(0) + for { + tmp, total, err := u.userDB.PageUserID(ctx, pageNumber, constant.ShowNumber) + if err != nil { + return nil, err + } + if len(tmp) == 0 { + if total == int64(len(userIDs)) { + return userIDs, nil + } + return nil, constant.ErrData.Wrap("The total number of results and expectations are different, but result is nil") + } + userIDs = append(userIDs, tmp...) + pageNumber++ + } + return userIDs, nil } diff --git a/pkg/common/db/relation/black_model.go b/pkg/common/db/relation/black_model.go index 07498bb48..dbef9e116 100644 --- a/pkg/common/db/relation/black_model.go +++ b/pkg/common/db/relation/black_model.go @@ -12,10 +12,8 @@ type BlackGorm struct { DB *gorm.DB } -func NewBlackGorm(db *gorm.DB) *BlackGorm { - var black BlackGorm - black.DB = db - return &black +func NewBlackGorm(db *gorm.DB) relation.BlackModelInterface { + return &BlackGorm{db} } func (b *BlackGorm) Create(ctx context.Context, blacks []*relation.BlackModel) (err error) { diff --git a/pkg/common/db/relation/friend_model.go b/pkg/common/db/relation/friend_model.go index a6f423d16..15d541855 100644 --- a/pkg/common/db/relation/friend_model.go +++ b/pkg/common/db/relation/friend_model.go @@ -8,135 +8,125 @@ import ( "gorm.io/gorm" ) -type FriendDB interface { - Create(ctx context.Context, friends []*relation.FriendModel) (err error) - Delete(ctx context.Context, ownerUserID string, friendUserIDs string) (err error) - UpdateByMap(ctx context.Context, ownerUserID string, args map[string]interface{}) (err error) - Update(ctx context.Context, friends []*relation.FriendModel) (err error) - UpdateRemark(ctx context.Context, ownerUserID, friendUserID, remark string) (err error) - FindOwnerUserID(ctx context.Context, ownerUserID string) (friends []*relation.FriendModel, err error) -} - type FriendGorm struct { - DB *gorm.DB `gorm:"-"` + DB *gorm.DB } -func NewFriendGorm(DB *gorm.DB) *FriendGorm { - return &FriendGorm{DB: DB} +func NewFriendGorm(db *gorm.DB) relation.FriendModelInterface { + return &FriendGorm{DB: db} } -type FriendUser struct { - FriendGorm - Nickname string `gorm:"column:name;size:255"` +func (f *FriendGorm) NewTx(tx any) relation.FriendModelInterface { + return &FriendGorm{DB: tx.(*gorm.DB)} } // 插入多条记录 -func (f *FriendGorm) Create(ctx context.Context, friends []*relation.FriendModel, tx ...any) (err error) { +func (f *FriendGorm) Create(ctx context.Context, friends []*relation.FriendModel) (err error) { defer func() { tracelog.SetCtxDebug(ctx, utils.GetFuncName(1), err, "friends", friends) }() - return utils.Wrap(getDBConn(f.DB, tx).Create(&friends).Error, "") + return utils.Wrap(f.DB.Create(&friends).Error, "") } // 删除ownerUserID指定的好友 -func (f *FriendGorm) Delete(ctx context.Context, ownerUserID string, friendUserIDs []string, tx ...any) (err error) { +func (f *FriendGorm) Delete(ctx context.Context, ownerUserID string, friendUserIDs []string) (err error) { defer func() { tracelog.SetCtxDebug(ctx, utils.GetFuncName(1), err, "ownerUserID", ownerUserID, "friendUserIDs", friendUserIDs) }() - err = utils.Wrap(getDBConn(f.DB, tx).Where("owner_user_id = ? AND friend_user_id in ( ?)", ownerUserID, friendUserIDs).Delete(&relation.FriendModel{}).Error, "") + err = utils.Wrap(f.DB.Where("owner_user_id = ? AND friend_user_id in ( ?)", ownerUserID, friendUserIDs).Delete(&relation.FriendModel{}).Error, "") return err } // 更新ownerUserID单个好友信息 更新零值 -func (f *FriendGorm) UpdateByMap(ctx context.Context, ownerUserID string, friendUserID string, args map[string]interface{}, tx ...any) (err error) { +func (f *FriendGorm) UpdateByMap(ctx context.Context, ownerUserID string, friendUserID string, args map[string]interface{}) (err error) { defer func() { tracelog.SetCtxDebug(ctx, utils.GetFuncName(1), err, "ownerUserID", ownerUserID, "friendUserID", friendUserID, "args", args) }() - return utils.Wrap(getDBConn(f.DB, tx).Model(&relation.FriendModel{}).Where("owner_user_id = ? AND friend_user_id = ? ", ownerUserID, friendUserID).Updates(args).Error, "") + return utils.Wrap(f.DB.Model(&relation.FriendModel{}).Where("owner_user_id = ? AND friend_user_id = ? ", ownerUserID, friendUserID).Updates(args).Error, "") } // 更新好友信息的非零值 -func (f *FriendGorm) Update(ctx context.Context, friends []*relation.FriendModel, tx ...any) (err error) { +func (f *FriendGorm) Update(ctx context.Context, friends []*relation.FriendModel) (err error) { defer func() { tracelog.SetCtxDebug(ctx, utils.GetFuncName(1), err, "friends", friends) }() - return utils.Wrap(getDBConn(f.DB, tx).Updates(&friends).Error, "") + return utils.Wrap(f.DB.Updates(&friends).Error, "") } // 更新好友备注(也支持零值 ) -func (f *FriendGorm) UpdateRemark(ctx context.Context, ownerUserID, friendUserID, remark string, tx ...any) (err error) { +func (f *FriendGorm) UpdateRemark(ctx context.Context, ownerUserID, friendUserID, remark string) (err error) { defer func() { tracelog.SetCtxDebug(ctx, utils.GetFuncName(1), err, "ownerUserID", ownerUserID, "friendUserID", friendUserID, "remark", remark) }() if remark != "" { - return utils.Wrap(getDBConn(f.DB, tx).Model(&relation.FriendModel{}).Where("owner_user_id = ? and friend_user_id = ?", ownerUserID, friendUserID).Update("remark", remark).Error, "") + return utils.Wrap(f.DB.Model(&relation.FriendModel{}).Where("owner_user_id = ? and friend_user_id = ?", ownerUserID, friendUserID).Update("remark", remark).Error, "") } m := make(map[string]interface{}, 1) m["remark"] = "" - return utils.Wrap(getDBConn(f.DB, tx).Model(&relation.FriendModel{}).Where("owner_user_id = ?", ownerUserID).Updates(m).Error, "") + return utils.Wrap(f.DB.Model(&relation.FriendModel{}).Where("owner_user_id = ?", ownerUserID).Updates(m).Error, "") } // 获取单个好友信息,如没找到 返回错误 -func (f *FriendGorm) Take(ctx context.Context, ownerUserID, friendUserID string, tx ...any) (friend *relation.FriendModel, err error) { +func (f *FriendGorm) Take(ctx context.Context, ownerUserID, friendUserID string) (friend *relation.FriendModel, err error) { friend = &relation.FriendModel{} defer tracelog.SetCtxDebug(ctx, utils.GetFuncName(1), err, "ownerUserID", ownerUserID, "friendUserID", friendUserID, "friend", *friend) - return friend, utils.Wrap(getDBConn(f.DB, tx).Where("owner_user_id = ? and friend_user_id", ownerUserID, friendUserID).Take(friend).Error, "") + return friend, utils.Wrap(f.DB.Where("owner_user_id = ? and friend_user_id", ownerUserID, friendUserID).Take(friend).Error, "") } // 查找好友关系,如果是双向关系,则都返回 -func (f *FriendGorm) FindUserState(ctx context.Context, userID1, userID2 string, tx ...any) (friends []*relation.FriendModel, err error) { +func (f *FriendGorm) FindUserState(ctx context.Context, userID1, userID2 string) (friends []*relation.FriendModel, err error) { defer func() { tracelog.SetCtxDebug(ctx, utils.GetFuncName(1), err, "userID1", userID1, "userID2", userID2, "friends", friends) }() - return friends, utils.Wrap(getDBConn(f.DB, tx).Where("(owner_user_id = ? and friend_user_id = ?) or (owner_user_id = ? and friend_user_id = ?)", userID1, userID2, userID2, userID1).Find(&friends).Error, "") + return friends, utils.Wrap(f.DB.Where("(owner_user_id = ? and friend_user_id = ?) or (owner_user_id = ? and friend_user_id = ?)", userID1, userID2, userID2, userID1).Find(&friends).Error, "") } // 获取 owner指定的好友列表 如果有friendUserIDs不存在,也不返回错误 -func (f *FriendGorm) FindFriends(ctx context.Context, ownerUserID string, friendUserIDs []string, tx ...any) (friends []*relation.FriendModel, err error) { +func (f *FriendGorm) FindFriends(ctx context.Context, ownerUserID string, friendUserIDs []string) (friends []*relation.FriendModel, err error) { defer func() { tracelog.SetCtxDebug(ctx, utils.GetFuncName(1), err, "ownerUserID", ownerUserID, "friendUserIDs", friendUserIDs, "friends", friends) }() - return friends, utils.Wrap(getDBConn(f.DB, tx).Where("owner_user_id = ? AND friend_user_id in (?)", ownerUserID, friendUserIDs).Find(&friends).Error, "") + return friends, utils.Wrap(f.DB.Where("owner_user_id = ? AND friend_user_id in (?)", ownerUserID, friendUserIDs).Find(&friends).Error, "") } // 获取哪些人添加了friendUserID 如果有ownerUserIDs不存在,也不返回错误 -func (f *FriendGorm) FindReversalFriends(ctx context.Context, friendUserID string, ownerUserIDs []string, tx ...any) (friends []*relation.FriendModel, err error) { +func (f *FriendGorm) FindReversalFriends(ctx context.Context, friendUserID string, ownerUserIDs []string) (friends []*relation.FriendModel, err error) { defer func() { tracelog.SetCtxDebug(ctx, utils.GetFuncName(1), err, "friendUserID", friendUserID, "ownerUserIDs", ownerUserIDs, "friends", friends) }() - return friends, utils.Wrap(getDBConn(f.DB, tx).Where("friend_user_id = ? AND owner_user_id in (?)", friendUserID, ownerUserIDs).Find(&friends).Error, "") + return friends, utils.Wrap(f.DB.Where("friend_user_id = ? AND owner_user_id in (?)", friendUserID, ownerUserIDs).Find(&friends).Error, "") } // 获取ownerUserID好友列表 支持翻页 -func (f *FriendGorm) FindOwnerFriends(ctx context.Context, ownerUserID string, pageNumber, showNumber int32, tx ...any) (friends []*relation.FriendModel, total int64, err error) { +func (f *FriendGorm) FindOwnerFriends(ctx context.Context, ownerUserID string, pageNumber, showNumber int32) (friends []*relation.FriendModel, total int64, err error) { defer func() { tracelog.SetCtxDebug(ctx, utils.GetFuncName(1), err, "ownerUserID", ownerUserID, "pageNumber", pageNumber, "showNumber", showNumber, "friends", friends, "total", total) }() - err = getDBConn(f.DB, tx).Model(&relation.FriendModel{}).Where("owner_user_id = ? ", ownerUserID).Count(&total).Error + err = f.DB.Model(&relation.FriendModel{}).Where("owner_user_id = ? ", ownerUserID).Count(&total).Error if err != nil { return nil, 0, utils.Wrap(err, "") } - err = utils.Wrap(getDBConn(f.DB, tx).Where("owner_user_id = ? ", ownerUserID).Limit(int(showNumber)).Offset(int(pageNumber*showNumber)).Find(&friends).Error, "") + err = utils.Wrap(f.DB.Where("owner_user_id = ? ", ownerUserID).Limit(int(showNumber)).Offset(int(pageNumber*showNumber)).Find(&friends).Error, "") return } // 获取哪些人添加了friendUserID 支持翻页 -func (f *FriendGorm) FindInWhoseFriends(ctx context.Context, friendUserID string, pageNumber, showNumber int32, tx ...any) (friends []*relation.FriendModel, total int64, err error) { +func (f *FriendGorm) FindInWhoseFriends(ctx context.Context, friendUserID string, pageNumber, showNumber int32) (friends []*relation.FriendModel, total int64, err error) { defer func() { tracelog.SetCtxDebug(ctx, utils.GetFuncName(1), err, "friendUserID", friendUserID, "pageNumber", pageNumber, "showNumber", showNumber, "friends", friends, "total", total) }() - err = getDBConn(f.DB, tx).Model(&relation.FriendModel{}).Where("friend_user_id = ? ", friendUserID).Count(&total).Error + err = f.DB.Model(&relation.FriendModel{}).Where("friend_user_id = ? ", friendUserID).Count(&total).Error if err != nil { return nil, 0, utils.Wrap(err, "") } - err = utils.Wrap(getDBConn(f.DB, tx).Where("friend_user_id = ? ", friendUserID).Limit(int(showNumber)).Offset(int(pageNumber*showNumber)).Find(&friends).Error, "") + err = utils.Wrap(f.DB.Where("friend_user_id = ? ", friendUserID).Limit(int(showNumber)).Offset(int(pageNumber*showNumber)).Find(&friends).Error, "") return } -func (f *FriendGorm) FindFriendUserIDs(ctx context.Context, ownerUserID string, tx ...any) (friendUserIDs []string, err error) { +func (f *FriendGorm) FindFriendUserIDs(ctx context.Context, ownerUserID string) (friendUserIDs []string, err error) { defer func() { tracelog.SetCtxDebug(ctx, utils.GetFuncName(1), err, "ownerUserID", ownerUserID, "friendUserIDs", friendUserIDs) }() - return friendUserIDs, utils.Wrap(getDBConn(f.DB, tx).Model(&relation.FriendModel{}).Where("owner_user_id = ? ", ownerUserID).Pluck("friend_user_id", &friendUserIDs).Error, "") + return friendUserIDs, utils.Wrap(f.DB.Model(&relation.FriendModel{}).Where("owner_user_id = ? ", ownerUserID).Pluck("friend_user_id", &friendUserIDs).Error, "") } diff --git a/pkg/common/db/relation/friend_request_model.go b/pkg/common/db/relation/friend_request_model.go index 060de47d2..1326048ec 100644 --- a/pkg/common/db/relation/friend_request_model.go +++ b/pkg/common/db/relation/friend_request_model.go @@ -8,92 +8,92 @@ import ( "gorm.io/gorm" ) -//var FriendRequestDB *gorm.DB - -func NewFriendRequestGorm(db *gorm.DB) *FriendRequestGorm { - var fr FriendRequestGorm - fr.DB = db - return &fr +func NewFriendRequestGorm(db *gorm.DB) relation.FriendRequestModelInterface { + return &FriendRequestGorm{db} } type FriendRequestGorm struct { - DB *gorm.DB `gorm:"-"` + DB *gorm.DB +} + +func (f *FriendRequestGorm) NewTx(tx any) relation.FriendRequestModelInterface { + return &FriendRequestGorm{DB: tx.(*gorm.DB)} } // 插入多条记录 -func (f *FriendRequestGorm) Create(ctx context.Context, friendRequests []*relation.FriendRequestModel, tx ...any) (err error) { +func (f *FriendRequestGorm) Create(ctx context.Context, friendRequests []*relation.FriendRequestModel) (err error) { defer func() { tracelog.SetCtxDebug(ctx, utils.GetFuncName(1), err, "friendRequests", friendRequests) }() - return utils.Wrap(getDBConn(f.DB, tx).Create(&friendRequests).Error, "") + return utils.Wrap(f.DB.Create(&friendRequests).Error, "") } // 删除记录 -func (f *FriendRequestGorm) Delete(ctx context.Context, fromUserID, toUserID string, tx ...any) (err error) { +func (f *FriendRequestGorm) Delete(ctx context.Context, fromUserID, toUserID string) (err error) { defer func() { tracelog.SetCtxDebug(ctx, utils.GetFuncName(1), err, "fromUserID", fromUserID, "toUserID", toUserID) }() - return utils.Wrap(getDBConn(f.DB, tx).Where("from_user_id = ? AND to_user_id = ?", fromUserID, toUserID).Delete(&relation.FriendRequestModel{}).Error, "") + return utils.Wrap(f.DB.Where("from_user_id = ? AND to_user_id = ?", fromUserID, toUserID).Delete(&relation.FriendRequestModel{}).Error, "") } // 更新零值 -func (f *FriendRequestGorm) UpdateByMap(ctx context.Context, formUserID string, toUserID string, args map[string]interface{}, tx ...any) (err error) { +func (f *FriendRequestGorm) UpdateByMap(ctx context.Context, formUserID string, toUserID string, args map[string]interface{}) (err error) { defer func() { tracelog.SetCtxDebug(ctx, utils.GetFuncName(1), err, "formUserID", formUserID, "toUserID", toUserID, "args", args) }() - return utils.Wrap(getDBConn(f.DB, tx).Model(&relation.FriendRequestModel{}).Where("from_user_id = ? AND to_user_id ", formUserID, toUserID).Updates(args).Error, "") + return utils.Wrap(f.DB.Model(&relation.FriendRequestModel{}).Where("from_user_id = ? AND to_user_id ", formUserID, toUserID).Updates(args).Error, "") } // 更新多条记录 (非零值) -func (f *FriendRequestGorm) Update(ctx context.Context, friendRequests []*relation.FriendRequestModel, tx ...any) (err error) { +func (f *FriendRequestGorm) Update(ctx context.Context, friendRequests []*relation.FriendRequestModel) (err error) { defer func() { tracelog.SetCtxDebug(ctx, utils.GetFuncName(1), err, "friendRequests", friendRequests) }() - return utils.Wrap(getDBConn(f.DB, tx).Updates(&friendRequests).Error, "") + return utils.Wrap(f.DB.Updates(&friendRequests).Error, "") } // 获取来指定用户的好友申请 未找到 不返回错误 -func (f *FriendRequestGorm) Find(ctx context.Context, fromUserID, toUserID string, tx ...any) (friendRequest *relation.FriendRequestModel, err error) { +func (f *FriendRequestGorm) Find(ctx context.Context, fromUserID, toUserID string) (friendRequest *relation.FriendRequestModel, err error) { friendRequest = &relation.FriendRequestModel{} defer func() { tracelog.SetCtxDebug(ctx, utils.GetFuncName(1), err, "fromUserID", fromUserID, "toUserID", toUserID, "friendRequest", *friendRequest) }() - utils.Wrap(getDBConn(f.DB, tx).Where("from_user_id = ? and to_user_id", fromUserID, toUserID).Find(friendRequest).Error, "") + utils.Wrap(f.DB.Where("from_user_id = ? and to_user_id", fromUserID, toUserID).Find(friendRequest).Error, "") return } -func (f *FriendRequestGorm) Take(ctx context.Context, fromUserID, toUserID string, tx ...any) (friendRequest *relation.FriendRequestModel, err error) { +func (f *FriendRequestGorm) Take(ctx context.Context, fromUserID, toUserID string) (friendRequest *relation.FriendRequestModel, err error) { friendRequest = &relation.FriendRequestModel{} defer func() { tracelog.SetCtxDebug(ctx, utils.GetFuncName(1), err, "fromUserID", fromUserID, "toUserID", toUserID, "friendRequest", *friendRequest) }() - utils.Wrap(getDBConn(f.DB, tx).Where("from_user_id = ? and to_user_id", fromUserID, toUserID).Take(friendRequest).Error, "") + utils.Wrap(f.DB.Where("from_user_id = ? and to_user_id", fromUserID, toUserID).Take(friendRequest).Error, "") return } // 获取toUserID收到的好友申请列表 -func (f *FriendRequestGorm) FindToUserID(ctx context.Context, toUserID string, pageNumber, showNumber int32, tx ...any) (friendRequests []*relation.FriendRequestModel, total int64, err error) { +func (f *FriendRequestGorm) FindToUserID(ctx context.Context, toUserID string, pageNumber, showNumber int32) (friendRequests []*relation.FriendRequestModel, total int64, err error) { defer func() { tracelog.SetCtxDebug(ctx, utils.GetFuncName(1), err, "toUserID", toUserID, "friendRequests", friendRequests) }() - err = getDBConn(f.DB, tx).Model(&relation.FriendRequestModel{}).Where("to_user_id = ? ", toUserID).Count(&total).Error + err = f.DB.Model(&relation.FriendRequestModel{}).Where("to_user_id = ? ", toUserID).Count(&total).Error if err != nil { return nil, 0, utils.Wrap(err, "") } - err = utils.Wrap(getDBConn(f.DB, tx).Where("to_user_id = ? ", toUserID).Limit(int(showNumber)).Offset(int(pageNumber*showNumber)).Find(&friendRequests).Error, "") + err = utils.Wrap(f.DB.Where("to_user_id = ? ", toUserID).Limit(int(showNumber)).Offset(int(pageNumber*showNumber)).Find(&friendRequests).Error, "") return } // 获取fromUserID发出去的好友申请列表 -func (f *FriendRequestGorm) FindFromUserID(ctx context.Context, fromUserID string, pageNumber, showNumber int32, tx ...any) (friendRequests []*relation.FriendRequestModel, total int64, err error) { +func (f *FriendRequestGorm) FindFromUserID(ctx context.Context, fromUserID string, pageNumber, showNumber int32) (friendRequests []*relation.FriendRequestModel, total int64, err error) { defer func() { tracelog.SetCtxDebug(ctx, utils.GetFuncName(1), err, "fromUserID", fromUserID, "friendRequests", friendRequests) }() - err = getDBConn(f.DB, tx).Model(&relation.FriendRequestModel{}).Where("from_user_id = ? ", fromUserID).Count(&total).Error + err = f.DB.Model(&relation.FriendRequestModel{}).Where("from_user_id = ? ", fromUserID).Count(&total).Error if err != nil { return nil, 0, utils.Wrap(err, "") } - err = utils.Wrap(getDBConn(f.DB, tx).Where("from_user_id = ? ", fromUserID).Limit(int(showNumber)).Offset(int(pageNumber*showNumber)).Find(&friendRequests).Error, "") + err = utils.Wrap(f.DB.Where("from_user_id = ? ", fromUserID).Limit(int(showNumber)).Offset(int(pageNumber*showNumber)).Find(&friendRequests).Error, "") return } diff --git a/pkg/common/db/relation/mysql_init.go b/pkg/common/db/relation/mysql_init.go index 92f80a176..1047f51cc 100644 --- a/pkg/common/db/relation/mysql_init.go +++ b/pkg/common/db/relation/mysql_init.go @@ -58,6 +58,7 @@ func newMysqlGormDB() (*gorm.DB, error) { return db, nil } +// gorm mysql func NewGormDB() (*gorm.DB, error) { return newMysqlGormDB() } @@ -67,12 +68,3 @@ type Writer struct{} func (w Writer) Printf(format string, args ...interface{}) { fmt.Printf(format, args...) } - -func getDBConn(db *gorm.DB, tx []any) *gorm.DB { - if len(tx) > 0 { - if txDB, ok := tx[0].(*gorm.DB); ok { - return txDB - } - } - return db -} diff --git a/pkg/common/db/relation/user_model.go b/pkg/common/db/relation/user_model.go index a3b573d94..c76aca95e 100644 --- a/pkg/common/db/relation/user_model.go +++ b/pkg/common/db/relation/user_model.go @@ -5,7 +5,6 @@ import ( "Open_IM/pkg/common/tracelog" "Open_IM/pkg/utils" "context" - "fmt" "gorm.io/gorm" ) @@ -13,97 +12,75 @@ type UserGorm struct { DB *gorm.DB } -func NewUserGorm(db *gorm.DB) *UserGorm { - var user UserGorm - user.DB = db - return &user +func NewUserGorm(DB *gorm.DB) relation.UserModelInterface { + return &UserGorm{DB: DB} } // 插入多条 -func (u *UserGorm) Create(ctx context.Context, users []*relation.UserModel, tx ...any) (err error) { +func (u *UserGorm) Create(ctx context.Context, users []*relation.UserModel) (err error) { defer func() { tracelog.SetCtxDebug(ctx, utils.GetFuncName(1), err, "users", users) }() - return utils.Wrap(getDBConn(u.DB, tx).Create(&users).Error, "") + return utils.Wrap(u.DB.Create(&users).Error, "") } // 更新用户信息 零值 -func (u *UserGorm) UpdateByMap(ctx context.Context, userID string, args map[string]interface{}, tx ...any) (err error) { +func (u *UserGorm) UpdateByMap(ctx context.Context, userID string, args map[string]interface{}) (err error) { defer func() { tracelog.SetCtxDebug(ctx, utils.GetFuncName(1), err, "userID", userID, "args", args) }() - return utils.Wrap(getDBConn(u.DB, tx).Model(&relation.UserModel{}).Where("user_id = ?", userID).Updates(args).Error, "") + return utils.Wrap(u.DB.Model(&relation.UserModel{}).Where("user_id = ?", userID).Updates(args).Error, "") } // 更新多个用户信息 非零值 -func (u *UserGorm) Update(ctx context.Context, users []*relation.UserModel, tx ...any) (err error) { +func (u *UserGorm) Update(ctx context.Context, users []*relation.UserModel) (err error) { defer func() { tracelog.SetCtxDebug(ctx, utils.GetFuncName(1), err, "users", users) }() - return utils.Wrap(getDBConn(u.DB, tx).Updates(&users).Error, "") + return utils.Wrap(u.DB.Updates(&users).Error, "") } // 获取指定用户信息 不存在,也不返回错误 -func (u *UserGorm) Find(ctx context.Context, userIDs []string, tx ...any) (users []*relation.UserModel, err error) { +func (u *UserGorm) Find(ctx context.Context, userIDs []string) (users []*relation.UserModel, err error) { defer func() { tracelog.SetCtxDebug(ctx, utils.GetFuncName(1), err, "userIDs", userIDs, "users", users) }() - err = utils.Wrap(getDBConn(u.DB, tx).Where("user_id in (?)", userIDs).Find(&users).Error, "") + err = utils.Wrap(u.DB.Where("user_id in (?)", userIDs).Find(&users).Error, "") return users, err } // 获取某个用户信息 不存在,则返回错误 -func (u *UserGorm) Take(ctx context.Context, userID string, tx ...any) (user *relation.UserModel, err error) { +func (u *UserGorm) Take(ctx context.Context, userID string) (user *relation.UserModel, err error) { user = &relation.UserModel{} defer func() { tracelog.SetCtxDebug(ctx, utils.GetFuncName(1), err, "userID", userID, "user", *user) }() - err = utils.Wrap(getDBConn(u.DB, tx).Where("user_id = ?", userID).Take(&user).Error, "") + err = utils.Wrap(u.DB.Where("user_id = ?", userID).Take(&user).Error, "") return user, err } -// 通过名字查找用户 不存在,不返回错误 -func (u *UserGorm) GetByName(ctx context.Context, userName string, pageNumber, showNumber int32, tx ...any) (users []*relation.UserModel, count int64, err error) { - defer func() { - tracelog.SetCtxDebug(ctx, utils.GetFuncName(1), err, "userName", userName, "pageNumber", pageNumber, "showNumber", showNumber, "users", users, "count", count) - }() - err = utils.Wrap(getDBConn(u.DB, tx).Model(&relation.UserModel{}).Where(" name like ?", fmt.Sprintf("%%%s%%", userName)).Count(&count).Error, "") - if err != nil { - return - } - err = utils.Wrap(getDBConn(u.DB, tx).Model(&relation.UserModel{}).Where(" name like ?", fmt.Sprintf("%%%s%%", userName)).Limit(int(showNumber)).Offset(int(showNumber*pageNumber)).Find(&users).Error, "") - return -} - -// 通过名字或userID查找用户 不存在,不返回错误 -func (u *UserGorm) GetByNameAndID(ctx context.Context, content string, pageNumber, showNumber int32, tx ...any) (users []*relation.UserModel, count int64, err error) { - defer func() { - tracelog.SetCtxDebug(ctx, utils.GetFuncName(1), err, "content", content, "pageNumber", pageNumber, "showNumber", showNumber, "users", users, "count", count) - }() - db := getDBConn(u.DB, tx).Model(&relation.UserModel{}).Where(" name like ? or user_id = ? ", fmt.Sprintf("%%%s%%", content), content) - if err = db.Count(&count).Error; err != nil { - return - } - err = utils.Wrap(db.Limit(int(showNumber)).Offset(int(showNumber*pageNumber)).Find(&users).Error, "") - return -} - // 获取用户信息 不存在,不返回错误 -func (u *UserGorm) Page(ctx context.Context, pageNumber, showNumber int32, tx ...any) (users []*relation.UserModel, count int64, err error) { +func (u *UserGorm) Page(ctx context.Context, pageNumber, showNumber int32) (users []*relation.UserModel, count int64, err error) { defer func() { tracelog.SetCtxDebug(ctx, utils.GetFuncName(1), err, "pageNumber", pageNumber, "showNumber", showNumber, "users", users, "count", count) }() - err = utils.Wrap(getDBConn(u.DB, tx).Model(&relation.UserModel{}).Count(&count).Error, "") + err = utils.Wrap(u.DB.Model(&relation.UserModel{}).Count(&count).Error, "") if err != nil { return } - err = utils.Wrap(getDBConn(u.DB, tx).Limit(int(showNumber)).Offset(int(pageNumber*showNumber)).Find(&users).Error, "") + err = utils.Wrap(u.DB.Limit(int(showNumber)).Offset(int(pageNumber*showNumber)).Find(&users).Error, "") return } // 获取所有用户ID -func (u *UserGorm) GetAllUserID(ctx context.Context) ([]string, error) { - var userIDs []string - err := u.DB.Pluck("user_id", &userIDs).Error - return userIDs, err +func (u *UserGorm) PageUserID(ctx context.Context, pageNumber, showNumber int32) (userIDs []string, count int64, err error) { + defer func() { + tracelog.SetCtxDebug(ctx, utils.GetFuncName(1), err, "pageNumber", pageNumber, "showNumber", showNumber, "userIDs", userIDs, "count", count) + }() + err = utils.Wrap(u.DB.Model(&relation.UserModel{}).Count(&count).Error, "") + if err != nil { + return + } + err = u.DB.Limit(int(showNumber)).Offset(int(pageNumber*showNumber)).Pluck("user_id", &userIDs).Error + return userIDs, count, err } diff --git a/pkg/common/db/table/relation/black.go b/pkg/common/db/table/relation/black.go index 6dcf8fda0..fd25cff7f 100644 --- a/pkg/common/db/table/relation/black.go +++ b/pkg/common/db/table/relation/black.go @@ -1,6 +1,7 @@ package relation import ( + "context" "time" ) @@ -22,4 +23,12 @@ func (BlackModel) TableName() string { } type BlackModelInterface interface { + Create(ctx context.Context, blacks []*BlackModel) (err error) + Delete(ctx context.Context, blacks []*BlackModel) (err error) + UpdateByMap(ctx context.Context, ownerUserID, blockUserID string, args map[string]interface{}) (err error) + Update(ctx context.Context, blacks []*BlackModel) (err error) + Find(ctx context.Context, blacks []*BlackModel) (blackList []*BlackModel, err error) + Take(ctx context.Context, ownerUserID, blockUserID string) (black *BlackModel, err error) + FindOwnerBlacks(ctx context.Context, ownerUserID string, pageNumber, showNumber int32) (blacks []*BlackModel, total int64, err error) + FindBlackUserIDs(ctx context.Context, ownerUserID string) (blackUserIDs []string, err error) } diff --git a/pkg/common/db/table/relation/friend.go b/pkg/common/db/table/relation/friend.go index de00f7747..e9eb35c65 100644 --- a/pkg/common/db/table/relation/friend.go +++ b/pkg/common/db/table/relation/friend.go @@ -1,6 +1,9 @@ package relation -import "time" +import ( + "context" + "time" +) const ( FriendModelTableName = "friends" @@ -21,4 +24,30 @@ func (FriendModel) TableName() string { } type FriendModelInterface interface { + // 插入多条记录 + Create(ctx context.Context, friends []*FriendModel) (err error) + // 删除ownerUserID指定的好友 + Delete(ctx context.Context, ownerUserID string, friendUserIDs []string) (err error) + // 更新ownerUserID单个好友信息 更新零值 + UpdateByMap(ctx context.Context, ownerUserID string, friendUserID string, args map[string]interface{}) (err error) + // 更新好友信息的非零值 + Update(ctx context.Context, friends []*FriendModel) (err error) + // 更新好友备注(也支持零值 ) + UpdateRemark(ctx context.Context, ownerUserID, friendUserID, remark string) (err error) + // 获取单个好友信息,如没找到 返回错误 + Take(ctx context.Context, ownerUserID, friendUserID string) (friend *FriendModel, err error) + // 查找好友关系,如果是双向关系,则都返回 + FindUserState(ctx context.Context, userID1, userID2 string) (friends []*FriendModel, err error) + // 获取 owner指定的好友列表 如果有friendUserIDs不存在,也不返回错误 + FindFriends(ctx context.Context, ownerUserID string, friendUserIDs []string) (friends []*FriendModel, err error) + // 获取哪些人添加了friendUserID 如果有ownerUserIDs不存在,也不返回错误 + FindReversalFriends(ctx context.Context, friendUserID string, ownerUserIDs []string) (friends []*FriendModel, err error) + // 获取ownerUserID好友列表 支持翻页 + FindOwnerFriends(ctx context.Context, ownerUserID string, pageNumber, showNumber int32) (friends []*FriendModel, total int64, err error) + // 获取哪些人添加了friendUserID 支持翻页 + FindInWhoseFriends(ctx context.Context, friendUserID string, pageNumber, showNumber int32) (friends []*FriendModel, total int64, err error) + // 获取好友UserID列表 + FindFriendUserIDs(ctx context.Context, ownerUserID string) (friendUserIDs []string, err error) + + NewTx(tx any) FriendModelInterface } diff --git a/pkg/common/db/table/relation/friend_request.go b/pkg/common/db/table/relation/friend_request.go index 8a4ea2805..559b67bb2 100644 --- a/pkg/common/db/table/relation/friend_request.go +++ b/pkg/common/db/table/relation/friend_request.go @@ -1,6 +1,9 @@ package relation -import "time" +import ( + "context" + "time" +) const FriendRequestModelTableName = "friend_requests" @@ -21,4 +24,21 @@ func (FriendRequestModel) TableName() string { } type FriendRequestModelInterface interface { + // 插入多条记录 + Create(ctx context.Context, friendRequests []*FriendRequestModel) (err error) + // 删除记录 + Delete(ctx context.Context, fromUserID, toUserID string) (err error) + // 更新零值 + UpdateByMap(ctx context.Context, formUserID string, toUserID string, args map[string]interface{}) (err error) + // 更新多条记录 (非零值) + Update(ctx context.Context, friendRequests []*FriendRequestModel) (err error) + // 获取来指定用户的好友申请 未找到 不返回错误 + Find(ctx context.Context, fromUserID, toUserID string) (friendRequest *FriendRequestModel, err error) + Take(ctx context.Context, fromUserID, toUserID string) (friendRequest *FriendRequestModel, err error) + // 获取toUserID收到的好友申请列表 + FindToUserID(ctx context.Context, toUserID string, pageNumber, showNumber int32) (friendRequests []*FriendRequestModel, total int64, err error) + // 获取fromUserID发出去的好友申请列表 + FindFromUserID(ctx context.Context, fromUserID string, pageNumber, showNumber int32) (friendRequests []*FriendRequestModel, total int64, err error) + + NewTx(tx any) FriendRequestModelInterface } diff --git a/pkg/common/db/table/relation/user.go b/pkg/common/db/table/relation/user.go index 2dd119e13..2b0d2d441 100644 --- a/pkg/common/db/table/relation/user.go +++ b/pkg/common/db/table/relation/user.go @@ -1,6 +1,9 @@ package relation -import "time" +import ( + "context" + "time" +) const ( UserModelTableName = "users" @@ -17,13 +20,23 @@ type UserModel struct { Email string `gorm:"column:email;size:64"` Ex string `gorm:"column:ex;size:1024"` CreateTime time.Time `gorm:"column:create_time;index:create_time; autoCreateTime"` - AppMangerLevel int32 `gorm:"column:app_manger_level"` + AppMangerLevel int32 `gorm:"column:app_manger_level;default:18"` GlobalRecvMsgOpt int32 `gorm:"column:global_recv_msg_opt"` } func (UserModel) TableName() string { - return GroupRequestModelTableName + return UserModelTableName } type UserModelInterface interface { + Create(ctx context.Context, users []*UserModel) (err error) + UpdateByMap(ctx context.Context, userID string, args map[string]interface{}) (err error) + Update(ctx context.Context, users []*UserModel) (err error) + // 获取指定用户信息 不存在,也不返回错误 + Find(ctx context.Context, userIDs []string) (users []*UserModel, err error) + // 获取某个用户信息 不存在,则返回错误 + Take(ctx context.Context, userID string) (user *UserModel, err error) + // 获取用户信息 不存在,不返回错误 + Page(ctx context.Context, pageNumber, showNumber int32) (users []*UserModel, count int64, err error) + PageUserID(ctx context.Context, pageNumber, showNumber int32) (userIDs []string, count int64, err error) } diff --git a/pkg/utils/utils_v2.go b/pkg/utils/utils_v2.go index 2300556a9..a28efcf14 100644 --- a/pkg/utils/utils_v2.go +++ b/pkg/utils/utils_v2.go @@ -5,6 +5,28 @@ import ( "sort" ) +// SliceSub a中存在,b中不存在 (a-b) +func SliceSub[E comparable](a, b []E) []E { + k := make(map[E]struct{}) + for i := 0; i < len(b); i++ { + k[b[i]] = struct{}{} + } + t := make(map[E]struct{}) + rs := make([]E, 0, len(a)) + for i := 0; i < len(a); i++ { + e := a[i] + if _, ok := t[e]; ok { + continue + } + if _, ok := k[e]; ok { + continue + } + rs = append(rs, e) + t[e] = struct{}{} + } + return rs +} + // DistinctAny 去重 func DistinctAny[E any, K comparable](es []E, fn func(e E) K) []E { v := make([]E, 0, len(es))