From 141cfef1b67303e8371cca9697a3378cbcb205ef Mon Sep 17 00:00:00 2001 From: Gordon <1432970085@qq.com> Date: Mon, 20 Feb 2023 16:42:52 +0800 Subject: [PATCH] conversation update --- pkg/common/db/controller/conversation.go | 14 ++++++++------ pkg/common/db/relation/conversation_model.go | 11 +++++++++++ 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/pkg/common/db/controller/conversation.go b/pkg/common/db/controller/conversation.go index 863026909..3bf6c4d78 100644 --- a/pkg/common/db/controller/conversation.go +++ b/pkg/common/db/controller/conversation.go @@ -94,17 +94,17 @@ var _ ConversationDataBaseInterface = (*ConversationDataBase)(nil) type ConversationDataBase struct { conversationDB relation.Conversation cache cache.ConversationCache - db *gorm.DB } func (c ConversationDataBase) SetUsersConversationFiledTx(ctx context.Context, userIDList []string, conversation *relationTb.ConversationModel, filedMap map[string]interface{}) error { - return c.db.Transaction(func(tx *gorm.DB) error { - haveUserID, err := c.conversationDB.FindUserID(ctx, userIDList, conversation.ConversationID, tx) + fn := func(tx any) error { + temp := c.conversationDB.NewTx(tx) + haveUserID, err := temp.FindUserID(ctx, userIDList, conversation.ConversationID, tx) if err != nil { return err } if len(haveUserID) > 0 { - err = c.conversationDB.UpdateByMap(ctx, haveUserID, conversation.ConversationID, filedMap, tx) + err = temp.UpdateByMap(ctx, haveUserID, conversation.ConversationID, filedMap, tx) if err != nil { return err } @@ -119,7 +119,7 @@ func (c ConversationDataBase) SetUsersConversationFiledTx(ctx context.Context, u temp.OwnerUserID = v cList = append(cList, temp) } - err = c.conversationDB.Create(ctx, cList) + err = temp.Create(ctx, cList) if err != nil { return err } @@ -134,7 +134,9 @@ func (c ConversationDataBase) SetUsersConversationFiledTx(ctx context.Context, u return err } return nil - }) + } + + return c.conversationDB.Transaction(fn) } func NewConversationDataBase(db relation.Conversation, cache cache.ConversationCache) *ConversationDataBase { diff --git a/pkg/common/db/relation/conversation_model.go b/pkg/common/db/relation/conversation_model.go index aef72365b..a975928e5 100644 --- a/pkg/common/db/relation/conversation_model.go +++ b/pkg/common/db/relation/conversation_model.go @@ -18,6 +18,8 @@ type Conversation interface { FindUserIDAllConversationID(ctx context.Context, userID string, tx ...any) ([]string, error) Take(ctx context.Context, userID, conversationID string, tx ...any) (conversation *relation.ConversationModel, err error) FindConversationID(ctx context.Context, userID string, conversationIDList []string, tx ...any) (existConversationID []string, err error) + Transaction(func(tx any) error) error + NewTx(tx any) Conversation } type ConversationGorm struct { DB *gorm.DB @@ -26,6 +28,15 @@ type ConversationGorm struct { func NewConversationGorm(DB *gorm.DB) Conversation { return &ConversationGorm{DB: DB} } +func (c *ConversationGorm) Transaction(fn func(tx any) error) error { + return c.DB.Transaction(func(tx *gorm.DB) error { + return fn(tx) + }) +} + +func (c *ConversationGorm) NewTx(tx any) Conversation { + return &ConversationGorm{DB: tx.(*gorm.DB)} +} func (c *ConversationGorm) Create(ctx context.Context, conversations []*relation.ConversationModel, tx ...any) (err error) { defer func() {