diff --git a/pkg/common/db/controller/group.go b/pkg/common/db/controller/group.go index 1a90b5b0a..a920fb54e 100644 --- a/pkg/common/db/controller/group.go +++ b/pkg/common/db/controller/group.go @@ -339,7 +339,15 @@ func (g *groupDatabase) UpdateGroupMembers(ctx context.Context, data []*relation } func (g *groupDatabase) CreateGroupRequest(ctx context.Context, requests []*relationTb.GroupRequestModel) error { - return g.groupRequestDB.Create(ctx, requests) + return g.tx.Transaction(func(tx any) error { + db := g.groupRequestDB.NewTx(tx) + for _, request := range requests { + if err := db.Delete(ctx, request.GroupID, request.UserID); err != nil { + return err + } + } + return db.Create(ctx, requests) + }) } func (g *groupDatabase) TakeGroupRequest(ctx context.Context, groupID string, userID string) (*relationTb.GroupRequestModel, error) { diff --git a/pkg/common/db/relation/group_request_model.go b/pkg/common/db/relation/group_request_model.go index e7e15539b..ea419fabf 100644 --- a/pkg/common/db/relation/group_request_model.go +++ b/pkg/common/db/relation/group_request_model.go @@ -24,11 +24,15 @@ func (g *GroupRequestGorm) NewTx(tx any) relation.GroupRequestModelInterface { } func (g *GroupRequestGorm) Create(ctx context.Context, groupRequests []*relation.GroupRequestModel) (err error) { - return utils.Wrap(g.DB.Create(&groupRequests).Error, utils.GetSelfFuncName()) + return utils.Wrap(g.DB.WithContext(ctx).Create(&groupRequests).Error, utils.GetSelfFuncName()) +} + +func (g *GroupRequestGorm) Delete(ctx context.Context, groupID string, userID string) (err error) { + return utils.Wrap(g.DB.WithContext(ctx).Where("group_id = ? and user_id = ? ", groupID, userID).Delete(&relation.GroupRequestModel{}).Error, utils.GetSelfFuncName()) } func (g *GroupRequestGorm) UpdateHandler(ctx context.Context, groupID string, userID string, handledMsg string, handleResult int32) (err error) { - return utils.Wrap(g.DB.Model(&relation.GroupRequestModel{}).Where("group_id = ? and user_id = ? ", groupID, userID).Updates(map[string]any{ + return utils.Wrap(g.DB.WithContext(ctx).Model(&relation.GroupRequestModel{}).Where("group_id = ? and user_id = ? ", groupID, userID).Updates(map[string]any{ "handle_msg": handledMsg, "handle_result": handleResult, }).Error, utils.GetSelfFuncName()) @@ -36,9 +40,9 @@ func (g *GroupRequestGorm) UpdateHandler(ctx context.Context, groupID string, us func (g *GroupRequestGorm) Take(ctx context.Context, groupID string, userID string) (groupRequest *relation.GroupRequestModel, err error) { groupRequest = &relation.GroupRequestModel{} - return groupRequest, utils.Wrap(g.DB.Where("group_id = ? and user_id = ? ", groupID, userID).Take(groupRequest).Error, utils.GetSelfFuncName()) + return groupRequest, utils.Wrap(g.DB.WithContext(ctx).Where("group_id = ? and user_id = ? ", groupID, userID).Take(groupRequest).Error, utils.GetSelfFuncName()) } func (g *GroupRequestGorm) Page(ctx context.Context, userID string, pageNumber, showNumber int32) (total uint32, groups []*relation.GroupRequestModel, err error) { - return ormutil.GormSearch[relation.GroupRequestModel](g.DB.Where("user_id = ?", userID), nil, "", pageNumber, showNumber) + return ormutil.GormSearch[relation.GroupRequestModel](g.DB.WithContext(ctx).Where("user_id = ?", userID), nil, "", pageNumber, showNumber) } diff --git a/pkg/common/db/table/relation/group_request.go b/pkg/common/db/table/relation/group_request.go index 03ca5fda0..a40da1ed0 100644 --- a/pkg/common/db/table/relation/group_request.go +++ b/pkg/common/db/table/relation/group_request.go @@ -30,6 +30,7 @@ func (GroupRequestModel) TableName() string { type GroupRequestModelInterface interface { NewTx(tx any) GroupRequestModelInterface Create(ctx context.Context, groupRequests []*GroupRequestModel) (err error) + Delete(ctx context.Context, groupID string, userID string) (err error) UpdateHandler(ctx context.Context, groupID string, userID string, handledMsg string, handleResult int32) (err error) Take(ctx context.Context, groupID string, userID string) (groupRequest *GroupRequestModel, err error) Page(ctx context.Context, userID string, pageNumber, showNumber int32) (total uint32, groups []*GroupRequestModel, err error)