Merge remote-tracking branch 'origin/errcode' into errcode

This commit is contained in:
skiffer-git 2023-02-20 17:13:37 +08:00
commit 9b0682705e
2 changed files with 19 additions and 6 deletions

View File

@ -94,17 +94,17 @@ var _ ConversationDataBaseInterface = (*ConversationDataBase)(nil)
type ConversationDataBase struct { type ConversationDataBase struct {
conversationDB relation.Conversation conversationDB relation.Conversation
cache cache.ConversationCache cache cache.ConversationCache
db *gorm.DB
} }
func (c ConversationDataBase) SetUsersConversationFiledTx(ctx context.Context, userIDList []string, conversation *relationTb.ConversationModel, filedMap map[string]interface{}) error { func (c ConversationDataBase) SetUsersConversationFiledTx(ctx context.Context, userIDList []string, conversation *relationTb.ConversationModel, filedMap map[string]interface{}) error {
return c.db.Transaction(func(tx *gorm.DB) error { fn := func(tx any) error {
haveUserID, err := c.conversationDB.FindUserID(ctx, userIDList, conversation.ConversationID, tx) temp := c.conversationDB.NewTx(tx)
haveUserID, err := temp.FindUserID(ctx, userIDList, conversation.ConversationID, tx)
if err != nil { if err != nil {
return err return err
} }
if len(haveUserID) > 0 { if len(haveUserID) > 0 {
err = c.conversationDB.UpdateByMap(ctx, haveUserID, conversation.ConversationID, filedMap, tx) err = temp.UpdateByMap(ctx, haveUserID, conversation.ConversationID, filedMap, tx)
if err != nil { if err != nil {
return err return err
} }
@ -119,7 +119,7 @@ func (c ConversationDataBase) SetUsersConversationFiledTx(ctx context.Context, u
temp.OwnerUserID = v temp.OwnerUserID = v
cList = append(cList, temp) cList = append(cList, temp)
} }
err = c.conversationDB.Create(ctx, cList) err = temp.Create(ctx, cList)
if err != nil { if err != nil {
return err return err
} }
@ -134,7 +134,9 @@ func (c ConversationDataBase) SetUsersConversationFiledTx(ctx context.Context, u
return err return err
} }
return nil return nil
}) }
return c.conversationDB.Transaction(fn)
} }
func NewConversationDataBase(db relation.Conversation, cache cache.ConversationCache) *ConversationDataBase { func NewConversationDataBase(db relation.Conversation, cache cache.ConversationCache) *ConversationDataBase {

View File

@ -18,6 +18,8 @@ type Conversation interface {
FindUserIDAllConversationID(ctx context.Context, userID string, tx ...any) ([]string, error) FindUserIDAllConversationID(ctx context.Context, userID string, tx ...any) ([]string, error)
Take(ctx context.Context, userID, conversationID string, tx ...any) (conversation *relation.ConversationModel, err error) Take(ctx context.Context, userID, conversationID string, tx ...any) (conversation *relation.ConversationModel, err error)
FindConversationID(ctx context.Context, userID string, conversationIDList []string, tx ...any) (existConversationID []string, err error) FindConversationID(ctx context.Context, userID string, conversationIDList []string, tx ...any) (existConversationID []string, err error)
Transaction(func(tx any) error) error
NewTx(tx any) Conversation
} }
type ConversationGorm struct { type ConversationGorm struct {
DB *gorm.DB DB *gorm.DB
@ -26,6 +28,15 @@ type ConversationGorm struct {
func NewConversationGorm(DB *gorm.DB) Conversation { func NewConversationGorm(DB *gorm.DB) Conversation {
return &ConversationGorm{DB: DB} return &ConversationGorm{DB: DB}
} }
func (c *ConversationGorm) Transaction(fn func(tx any) error) error {
return c.DB.Transaction(func(tx *gorm.DB) error {
return fn(tx)
})
}
func (c *ConversationGorm) NewTx(tx any) Conversation {
return &ConversationGorm{DB: tx.(*gorm.DB)}
}
func (c *ConversationGorm) Create(ctx context.Context, conversations []*relation.ConversationModel, tx ...any) (err error) { func (c *ConversationGorm) Create(ctx context.Context, conversations []*relation.ConversationModel, tx ...any) (err error) {
defer func() { defer func() {