diff --git a/pkg/common/db/controller/conversation.go b/pkg/common/db/controller/conversation.go index d8425240a..73af59d30 100644 --- a/pkg/common/db/controller/conversation.go +++ b/pkg/common/db/controller/conversation.go @@ -32,7 +32,7 @@ type ConversationDatabase interface { SetUsersConversationFiledTx(ctx context.Context, userIDList []string, conversation *relationTb.ConversationModel, filedMap map[string]interface{}) error } -func NewConversationDatabase(conversation relation.Conversation, cache cache.ConversationCache, tx tx.Tx) ConversationDatabase { +func NewConversationDatabase(conversation relationTb.ConversationModelInterface, cache cache.ConversationCache, tx tx.Tx) ConversationDatabase { return &ConversationDataBase{ conversationDB: conversation, cache: cache, @@ -41,7 +41,7 @@ func NewConversationDatabase(conversation relation.Conversation, cache cache.Con } type ConversationDataBase struct { - conversationDB relation.Conversation + conversationDB relationTb.ConversationModelInterface cache cache.ConversationCache tx tx.Tx } diff --git a/pkg/common/db/relation/black_model.go b/pkg/common/db/relation/black_model.go index faae6bf2c..58b8edd21 100644 --- a/pkg/common/db/relation/black_model.go +++ b/pkg/common/db/relation/black_model.go @@ -8,11 +8,11 @@ import ( ) type BlackGorm struct { - DB *gorm.DB + *MetaDB } func NewBlackGorm(db *gorm.DB) relation.BlackModelInterface { - return &BlackGorm{db} + return &BlackGorm{NewMetaDB(db, &relation.BlackModel{})} } func (b *BlackGorm) Create(ctx context.Context, blacks []*relation.BlackModel) (err error) { diff --git a/pkg/common/db/relation/chat_log_model.go b/pkg/common/db/relation/chat_log_model.go index 80ee30758..85e1ba175 100644 --- a/pkg/common/db/relation/chat_log_model.go +++ b/pkg/common/db/relation/chat_log_model.go @@ -14,11 +14,11 @@ import ( ) type ChatLogGorm struct { - DB *gorm.DB + *MetaDB } func NewChatLogGorm(db *gorm.DB) relation.ChatLogModelInterface { - return &ChatLogGorm{DB: db} + return &ChatLogGorm{NewMetaDB(db, &relation.ChatLogModel{})} } func (c *ChatLogGorm) Create(msg pbMsg.MsgDataToMQ) error { diff --git a/pkg/common/db/relation/conversation_model.go b/pkg/common/db/relation/conversation_model.go index 2b03966c8..4d572b314 100644 --- a/pkg/common/db/relation/conversation_model.go +++ b/pkg/common/db/relation/conversation_model.go @@ -8,29 +8,16 @@ import ( "gorm.io/gorm" ) -type Conversation interface { - Create(ctx context.Context, conversations []*relation.ConversationModel) (err error) - Delete(ctx context.Context, groupIDs []string) (err error) - UpdateByMap(ctx context.Context, userIDList []string, conversationID string, args map[string]interface{}) (err error) - Update(ctx context.Context, conversations []*relation.ConversationModel) (err error) - Find(ctx context.Context, ownerUserID string, conversationIDs []string) (conversations []*relation.ConversationModel, err error) - FindUserID(ctx context.Context, userIDList []string, conversationID string) ([]string, error) - FindUserIDAllConversationID(ctx context.Context, userID string) ([]string, error) - Take(ctx context.Context, userID, conversationID string) (conversation *relation.ConversationModel, err error) - FindConversationID(ctx context.Context, userID string, conversationIDList []string) (existConversationID []string, err error) - FindRecvMsgNotNotifyUserIDs(ctx context.Context, groupID string) ([]string, error) - NewTx(tx any) Conversation -} type ConversationGorm struct { - DB *gorm.DB + *MetaDB } -func NewConversationGorm(DB *gorm.DB) Conversation { - return &ConversationGorm{DB: DB} +func NewConversationGorm(db *gorm.DB) relation.ConversationModelInterface { + return &ConversationGorm{NewMetaDB(db, &relation.ConversationModel{})} } -func (c *ConversationGorm) NewTx(tx any) Conversation { - return &ConversationGorm{DB: tx.(*gorm.DB)} +func (c *ConversationGorm) NewTx(tx any) relation.ConversationModelInterface { + return &ConversationGorm{NewMetaDB(tx.(*gorm.DB), &relation.ConversationModel{})} } func (c *ConversationGorm) Create(ctx context.Context, conversations []*relation.ConversationModel) (err error) { diff --git a/pkg/common/db/relation/friend_model.go b/pkg/common/db/relation/friend_model.go index d48e6476e..e149922c6 100644 --- a/pkg/common/db/relation/friend_model.go +++ b/pkg/common/db/relation/friend_model.go @@ -8,15 +8,15 @@ import ( ) type FriendGorm struct { - DB *gorm.DB + *MetaDB } func NewFriendGorm(db *gorm.DB) relation.FriendModelInterface { - return &FriendGorm{DB: db} + return &FriendGorm{NewMetaDB(db, &relation.FriendModel{})} } func (f *FriendGorm) NewTx(tx any) relation.FriendModelInterface { - return &FriendGorm{DB: tx.(*gorm.DB)} + return &FriendGorm{NewMetaDB(tx.(*gorm.DB), &relation.FriendModel{})} } // 插入多条记录 diff --git a/pkg/common/db/relation/friend_request_model.go b/pkg/common/db/relation/friend_request_model.go index 0c79eeefb..70afaa3cb 100644 --- a/pkg/common/db/relation/friend_request_model.go +++ b/pkg/common/db/relation/friend_request_model.go @@ -7,16 +7,16 @@ import ( "gorm.io/gorm" ) -func NewFriendRequestGorm(db *gorm.DB) relation.FriendRequestModelInterface { - return &FriendRequestGorm{db} +type FriendRequestGorm struct { + *MetaDB } -type FriendRequestGorm struct { - DB *gorm.DB +func NewFriendRequestGorm(db *gorm.DB) relation.FriendRequestModelInterface { + return &FriendRequestGorm{NewMetaDB(db, &relation.FriendModel{})} } func (f *FriendRequestGorm) NewTx(tx any) relation.FriendRequestModelInterface { - return &FriendRequestGorm{DB: tx.(*gorm.DB)} + return &FriendRequestGorm{NewMetaDB(tx.(*gorm.DB), &relation.FriendModel{})} } // 插入多条记录 diff --git a/pkg/common/db/relation/group_member_model.go b/pkg/common/db/relation/group_member_model.go index 884f0f4d7..c32d1ace9 100644 --- a/pkg/common/db/relation/group_member_model.go +++ b/pkg/common/db/relation/group_member_model.go @@ -11,15 +11,15 @@ import ( var _ relation.GroupMemberModelInterface = (*GroupMemberGorm)(nil) type GroupMemberGorm struct { - DB *gorm.DB + *MetaDB } func NewGroupMemberDB(db *gorm.DB) relation.GroupMemberModelInterface { - return &GroupMemberGorm{DB: db} + return &GroupMemberGorm{NewMetaDB(db, &relation.GroupMemberModel{})} } func (g *GroupMemberGorm) NewTx(tx any) relation.GroupMemberModelInterface { - return &GroupMemberGorm{DB: tx.(*gorm.DB)} + return &GroupMemberGorm{NewMetaDB(tx.(*gorm.DB), &relation.GroupMemberModel{})} } func (g *GroupMemberGorm) Create(ctx context.Context, groupMemberList []*relation.GroupMemberModel) (err error) { diff --git a/pkg/common/db/relation/group_model.go b/pkg/common/db/relation/group_model.go index 8ff5cd36f..99c230080 100644 --- a/pkg/common/db/relation/group_model.go +++ b/pkg/common/db/relation/group_model.go @@ -10,15 +10,15 @@ import ( var _ relation.GroupModelInterface = (*GroupGorm)(nil) type GroupGorm struct { - DB *gorm.DB + *MetaDB } func NewGroupDB(db *gorm.DB) relation.GroupModelInterface { - return &GroupGorm{DB: db} + return &GroupGorm{NewMetaDB(db, &relation.GroupModel{})} } func (g *GroupGorm) NewTx(tx any) relation.GroupModelInterface { - return &GroupGorm{DB: tx.(*gorm.DB)} + return &GroupGorm{NewMetaDB(tx.(*gorm.DB), &relation.GroupModel{})} } func (g *GroupGorm) Create(ctx context.Context, groups []*relation.GroupModel) (err error) { diff --git a/pkg/common/db/relation/group_request_model.go b/pkg/common/db/relation/group_request_model.go index cd03c3113..67c8f6bc7 100644 --- a/pkg/common/db/relation/group_request_model.go +++ b/pkg/common/db/relation/group_request_model.go @@ -8,21 +8,19 @@ import ( ) type GroupRequestGorm struct { - DB *gorm.DB -} - -func (g *GroupRequestGorm) NewTx(tx any) relation.GroupRequestModelInterface { - return &GroupRequestGorm{ - DB: tx.(*gorm.DB), - } + *MetaDB } func NewGroupRequest(db *gorm.DB) relation.GroupRequestModelInterface { return &GroupRequestGorm{ - DB: db, + NewMetaDB(db, &relation.GroupRequestModel{}), } } +func (g *GroupRequestGorm) NewTx(tx any) relation.GroupRequestModelInterface { + return &GroupRequestGorm{NewMetaDB(tx.(*gorm.DB), &relation.GroupRequestModel{})} +} + func (g *GroupRequestGorm) Create(ctx context.Context, groupRequests []*relation.GroupRequestModel) (err error) { return utils.Wrap(g.DB.Create(&groupRequests).Error, utils.GetSelfFuncName()) } diff --git a/pkg/common/db/relation/meta_db.go b/pkg/common/db/relation/meta_db.go new file mode 100644 index 000000000..b758bb863 --- /dev/null +++ b/pkg/common/db/relation/meta_db.go @@ -0,0 +1,22 @@ +package relation + +import ( + "context" + "gorm.io/gorm" +) + +type MetaDB struct { + DB *gorm.DB + table interface{} +} + +func NewMetaDB(db *gorm.DB, table any) *MetaDB { + return &MetaDB{ + DB: db, + table: table, + } +} + +func (g *MetaDB) db(ctx context.Context) *gorm.DB { + return g.DB.WithContext(ctx).Model(g.table) +} diff --git a/pkg/common/db/relation/object_hash_model.go b/pkg/common/db/relation/object_hash_model.go index 25626407f..807853e98 100644 --- a/pkg/common/db/relation/object_hash_model.go +++ b/pkg/common/db/relation/object_hash_model.go @@ -7,19 +7,19 @@ import ( "gorm.io/gorm" ) -func NewObjectHash(db *gorm.DB) relation.ObjectHashModelInterface { - return &ObjectHashGorm{ - DB: db, - } +type ObjectHashGorm struct { + *MetaDB } -type ObjectHashGorm struct { - DB *gorm.DB +func NewObjectHash(db *gorm.DB) relation.ObjectHashModelInterface { + return &ObjectHashGorm{ + NewMetaDB(db, &relation.ObjectHashModel{}), + } } func (o *ObjectHashGorm) NewTx(tx any) relation.ObjectHashModelInterface { return &ObjectHashGorm{ - DB: tx.(*gorm.DB), + NewMetaDB(tx.(*gorm.DB), &relation.ObjectHashModel{}), } } diff --git a/pkg/common/db/relation/object_info_model.go b/pkg/common/db/relation/object_info_model.go index c7ca1eaa0..8f73ac6d7 100644 --- a/pkg/common/db/relation/object_info_model.go +++ b/pkg/common/db/relation/object_info_model.go @@ -8,19 +8,19 @@ import ( "time" ) -func NewObjectInfo(db *gorm.DB) relation.ObjectInfoModelInterface { - return &ObjectInfoGorm{ - DB: db, - } +type ObjectInfoGorm struct { + *MetaDB } -type ObjectInfoGorm struct { - DB *gorm.DB +func NewObjectInfo(db *gorm.DB) relation.ObjectInfoModelInterface { + return &ObjectInfoGorm{ + NewMetaDB(db, &relation.ObjectInfoModel{}), + } } func (o *ObjectInfoGorm) NewTx(tx any) relation.ObjectInfoModelInterface { return &ObjectInfoGorm{ - DB: tx.(*gorm.DB), + NewMetaDB(tx.(*gorm.DB), &relation.ObjectInfoModel{}), } } diff --git a/pkg/common/db/relation/object_put_model.go b/pkg/common/db/relation/object_put_model.go index 5249923c2..135cc7e4c 100644 --- a/pkg/common/db/relation/object_put_model.go +++ b/pkg/common/db/relation/object_put_model.go @@ -8,19 +8,19 @@ import ( "time" ) -func NewObjectPut(db *gorm.DB) relation.ObjectPutModelInterface { - return &ObjectPutGorm{ - DB: db, - } +type ObjectPutGorm struct { + *MetaDB } -type ObjectPutGorm struct { - DB *gorm.DB +func NewObjectPut(db *gorm.DB) relation.ObjectPutModelInterface { + return &ObjectPutGorm{ + NewMetaDB(db, &relation.ObjectPutModel{}), + } } func (o *ObjectPutGorm) NewTx(tx any) relation.ObjectPutModelInterface { return &ObjectPutGorm{ - DB: tx.(*gorm.DB), + NewMetaDB(tx.(*gorm.DB), &relation.ObjectPutModel{}), } } diff --git a/pkg/common/db/relation/user_model.go b/pkg/common/db/relation/user_model.go index b99cc2539..668b35a21 100644 --- a/pkg/common/db/relation/user_model.go +++ b/pkg/common/db/relation/user_model.go @@ -2,65 +2,59 @@ package relation import ( "OpenIM/pkg/common/db/table/relation" - "OpenIM/pkg/common/log" "OpenIM/pkg/utils" "context" "gorm.io/gorm" ) type UserGorm struct { - DB *gorm.DB + *MetaDB } func NewUserGorm(db *gorm.DB) relation.UserModelInterface { - return &UserGorm{DB: db} -} - -func (u *UserGorm) db() *gorm.DB { - return u.DB.Model(&relation.UserModel{}) + return &UserGorm{NewMetaDB(db, &relation.UserModel{})} } // 插入多条 func (u *UserGorm) Create(ctx context.Context, users []*relation.UserModel) (err error) { - return utils.Wrap(u.db().Create(&users).Error, "") + return utils.Wrap(u.db(ctx).Create(&users).Error, "") } // 更新用户信息 零值 func (u *UserGorm) UpdateByMap(ctx context.Context, userID string, args map[string]interface{}) (err error) { - return utils.Wrap(u.db().Where("user_id = ?", userID).Updates(args).Error, "") + return utils.Wrap(u.db(ctx).Where("user_id = ?", userID).Updates(args).Error, "") } // 更新多个用户信息 非零值 func (u *UserGorm) Update(ctx context.Context, users []*relation.UserModel) (err error) { - return utils.Wrap(u.db().Updates(&users).Error, "") + return utils.Wrap(u.db(ctx).Updates(&users).Error, "") } // 获取指定用户信息 不存在,也不返回错误 func (u *UserGorm) Find(ctx context.Context, userIDs []string) (users []*relation.UserModel, err error) { - log.ZDebug(ctx, "Find args", "userIDs", userIDs, "db", u.db()) - err = utils.Wrap(u.db().Where("user_id in (?)", userIDs).Find(&users).Error, "") + err = utils.Wrap(u.db(ctx).Where("user_id in (?)", userIDs).Find(&users).Error, "") return users, err } // 获取某个用户信息 不存在,则返回错误 func (u *UserGorm) Take(ctx context.Context, userID string) (user *relation.UserModel, err error) { user = &relation.UserModel{} - err = utils.Wrap(u.db().Where("user_id = ?", userID).Take(&user).Error, "") + err = utils.Wrap(u.db(ctx).Where("user_id = ?", userID).Take(&user).Error, "") return user, err } // 获取用户信息 不存在,不返回错误 func (u *UserGorm) Page(ctx context.Context, pageNumber, showNumber int32) (users []*relation.UserModel, count int64, err error) { - err = utils.Wrap(u.db().Count(&count).Error, "") + err = utils.Wrap(u.db(ctx).Count(&count).Error, "") if err != nil { return } - err = utils.Wrap(u.db().Limit(int(showNumber)).Offset(int(pageNumber*showNumber)).Find(&users).Error, "") + err = utils.Wrap(u.db(ctx).Limit(int(showNumber)).Offset(int(pageNumber*showNumber)).Find(&users).Error, "") return } // 获取所有用户ID func (u *UserGorm) GetAllUserID(ctx context.Context) (userIDs []string, err error) { - err = u.db().Pluck("user_id", &userIDs).Error + err = u.db(ctx).Pluck("user_id", &userIDs).Error return userIDs, err } diff --git a/pkg/common/db/table/relation/conversation.go b/pkg/common/db/table/relation/conversation.go index 454d82b45..00e3a26a6 100644 --- a/pkg/common/db/table/relation/conversation.go +++ b/pkg/common/db/table/relation/conversation.go @@ -1,5 +1,7 @@ package relation +import "context" + const ( conversationModelTableName = "conversations" ) @@ -28,4 +30,15 @@ func (ConversationModel) TableName() string { } type ConversationModelInterface interface { + Create(ctx context.Context, conversations []*ConversationModel) (err error) + Delete(ctx context.Context, groupIDs []string) (err error) + UpdateByMap(ctx context.Context, userIDList []string, conversationID string, args map[string]interface{}) (err error) + Update(ctx context.Context, conversations []*ConversationModel) (err error) + Find(ctx context.Context, ownerUserID string, conversationIDs []string) (conversations []*ConversationModel, err error) + FindUserID(ctx context.Context, userIDList []string, conversationID string) ([]string, error) + FindUserIDAllConversationID(ctx context.Context, userID string) ([]string, error) + Take(ctx context.Context, userID, conversationID string) (conversation *ConversationModel, err error) + FindConversationID(ctx context.Context, userID string, conversationIDList []string) (existConversationID []string, err error) + FindRecvMsgNotNotifyUserIDs(ctx context.Context, groupID string) ([]string, error) + NewTx(tx any) ConversationModelInterface }