diff --git a/internal/rpc/group/group.go b/internal/rpc/group/group.go index b758db1f3..a088e7788 100644 --- a/internal/rpc/group/group.go +++ b/internal/rpc/group/group.go @@ -147,15 +147,15 @@ func (s *groupServer) CreateGroup(ctx context.Context, req *pbGroup.CreateGroupR return nil, constant.ErrArgs.Wrap("no group owner") } var userIDs []string - for _, ordinaryUserID := range req.InitMemberList { - userIDs = append(userIDs, ordinaryUserID) + for _, userID := range req.InitMemberList { + userIDs = append(userIDs, userID) + } + for _, userID := range req.AdminUserIDs { + userIDs = append(userIDs, userID) } userIDs = append(userIDs, req.OwnerUserID) - for _, adminUserID := range req.AdminUserIDs { - userIDs = append(userIDs, adminUserID) - } if utils.IsDuplicateID(userIDs) { - return nil, constant.ErrArgs.Wrap("group member is repeated") + return nil, constant.ErrArgs.Wrap("group member repeated") } users, err := getUsersInfo(ctx, userIDs) if err != nil { @@ -165,68 +165,61 @@ func (s *groupServer) CreateGroup(ctx context.Context, req *pbGroup.CreateGroupR for i, user := range users { userMap[user.UserID] = users[i] } + for _, userID := range userIDs { + if userMap[userID] == nil { + return nil, constant.ErrUserIDNotFound.Wrap(userID) + } + } if err := callbackBeforeCreateGroup(ctx, req); err != nil { return nil, err } - var groupInfo relation.Group - utils.CopyStructFields(&groupInfo, req.GroupInfo) - - groupInfo, err := (&cp.PBGroup{req.GroupInfo}).Convert() - groupInfo.GroupID = genGroupID(ctx, req.GroupInfo.GroupID) - if req.GroupInfo.GroupType != constant.SuperGroup { - var groupMembers []*relation.GroupMember + var group relation.Group + var groupMembers []*relation.GroupMember + utils.CopyStructFields(&group, req.GroupInfo) + group.GroupID = genGroupID(ctx, req.GroupInfo.GroupID) + if req.GroupInfo.GroupType == constant.SuperGroup { + if err := s.GroupInterface.CreateSuperGroup(ctx, group.GroupID, userIDs); err != nil { + return nil, err + } + } else { + opUserID := tools.OpUserID(ctx) joinGroup := func(userID string, roleLevel int32) error { - groupMember := &relation.GroupMember{GroupID: groupInfo.GroupID, RoleLevel: roleLevel, OperatorUserID: tools.OpUserID(ctx), JoinSource: constant.JoinByInvitation, InviterUserID: tools.OpUserID(ctx)} user := userMap[userID] + groupMember := &relation.GroupMember{GroupID: group.GroupID, RoleLevel: roleLevel, OperatorUserID: opUserID, JoinSource: constant.JoinByInvitation, InviterUserID: opUserID} utils.CopyStructFields(&groupMember, user) - if err := CallbackBeforeMemberJoinGroup(ctx, tools.OperationID(ctx), groupMember, groupInfo.Ex); err != nil { + if err := CallbackBeforeMemberJoinGroup(ctx, tools.OperationID(ctx), groupMember, group.Ex); err != nil { return err } groupMembers = append(groupMembers, groupMember) return nil } - if err := joinGroup(req.OwnerUserID, constant.GroupOwner); err != nil { return nil, err } - - for _, info := range req.InitMemberList { - if err := joinGroup(info, constant.GroupOrdinaryUsers); err != nil { + for _, userID := range req.AdminUserIDs { + if err := joinGroup(userID, constant.GroupAdmin); err != nil { return nil, err } } - for _, info := range req.AdminUserIDs { - if err := joinGroup(info, constant.GroupAdmin); err != nil { + for _, userID := range req.InitMemberList { + if err := joinGroup(userID, constant.GroupOrdinaryUsers); err != nil { return nil, err } } - if err := (*relation.GroupMember)(nil).Create(ctx, groupMembers); err != nil { - return nil, err - } - - } else { - if err := db.DB.CreateSuperGroup(groupId, userIDs, len(userIDs)); err != nil { - return nil, err - } } - if err := (*relation.Group)(nil).Create(ctx, []*relation.Group{&groupInfo}); err != nil { + if err := s.GroupInterface.CreateGroup(ctx, []*relation.Group{&group}, groupMembers); err != nil { return nil, err } - utils.CopyStructFields(resp.GroupInfo, groupInfo) + utils.CopyStructFields(resp.GroupInfo, group) resp.GroupInfo.MemberCount = uint32(len(userIDs)) - if req.GroupInfo.GroupType != constant.SuperGroup { - chat.GroupCreatedNotification(tools.OperationID(ctx), tools.OpUserID(ctx), groupId, userIDs) - } else { - for _, userID := range userIDs { - if err := rocksCache.DelJoinedSuperGroupIDListFromCache(ctx, userID); err != nil { - trace_log.SetCtxInfo(ctx, "DelJoinedSuperGroupIDListFromCache", err, "userID", userID) - } - } + if req.GroupInfo.GroupType == constant.SuperGroup { go func() { - for _, v := range userIDs { - chat.SuperGroupNotification(tools.OperationID(ctx), v, v) + for _, userID := range userIDs { + chat.SuperGroupNotification(tools.OperationID(ctx), userID, userID) } }() + } else { + chat.GroupCreatedNotification(tools.OperationID(ctx), tools.OpUserID(ctx), group.GroupID, userIDs) } return resp, nil } diff --git a/pkg/common/db/controller/black.go b/pkg/common/db/controller/black.go index 3ec2f154f..50b5ca4ee 100644 --- a/pkg/common/db/controller/black.go +++ b/pkg/common/db/controller/black.go @@ -2,7 +2,7 @@ package controller import ( "Open_IM/pkg/common/db/cache" - "Open_IM/pkg/common/db/mysql" + "Open_IM/pkg/common/db/relation" "context" "errors" "gorm.io/gorm" diff --git a/pkg/common/db/controller/friend.go b/pkg/common/db/controller/friend.go index 957158d5c..e59c71b5a 100644 --- a/pkg/common/db/controller/friend.go +++ b/pkg/common/db/controller/friend.go @@ -2,7 +2,6 @@ package controller import ( "Open_IM/pkg/common/db/cache" - "Open_IM/pkg/common/db/mysql" "Open_IM/pkg/common/db/relation" "context" "errors" diff --git a/pkg/common/db/controller/friend_request.go b/pkg/common/db/controller/friend_request.go index 647016152..e301bc92a 100644 --- a/pkg/common/db/controller/friend_request.go +++ b/pkg/common/db/controller/friend_request.go @@ -2,7 +2,7 @@ package controller import ( "Open_IM/pkg/common/db/cache" - "Open_IM/pkg/common/db/mysql" + "Open_IM/pkg/common/db/relation" "context" ) diff --git a/pkg/common/db/controller/group.go b/pkg/common/db/controller/group.go index 329c9cbc5..293cbf85a 100644 --- a/pkg/common/db/controller/group.go +++ b/pkg/common/db/controller/group.go @@ -14,12 +14,12 @@ import ( type GroupInterface interface { FindGroupsByID(ctx context.Context, groupIDs []string) (groups []*relation.Group, err error) - CreateGroup(ctx context.Context, groups []*relation.Group) error + CreateGroup(ctx context.Context, groups []*relation.Group, groupMember []*relation.GroupMember) error DeleteGroupByIDs(ctx context.Context, groupIDs []string) error TakeGroupByID(ctx context.Context, groupID string) (group *relation.Group, err error) //mongo - CreateSuperGroup(ctx context.Context, groupID string, initMemberIDList []string, memberNumCount int) error + CreateSuperGroup(ctx context.Context, groupID string, initMemberIDList []string) error GetSuperGroupByID(ctx context.Context, groupID string) (superGroup *unrelation.SuperGroup, err error) } @@ -36,8 +36,8 @@ func (g *GroupController) FindGroupsByID(ctx context.Context, groupIDs []string) return g.database.FindGroupsByID(ctx, groupIDs) } -func (g *GroupController) CreateGroup(ctx context.Context, groups []*relation.Group) error { - return g.database.CreateGroup(ctx, groups) +func (g *GroupController) CreateGroup(ctx context.Context, groups []*relation.Group, groupMember []*relation.GroupMember) error { + return g.database.CreateGroup(ctx, groups, groupMember) } func (g *GroupController) DeleteGroupByIDs(ctx context.Context, groupIDs []string) error { @@ -52,17 +52,17 @@ func (g *GroupController) GetSuperGroupByID(ctx context.Context, groupID string) return g.database.GetSuperGroupByID(ctx, groupID) } -func (g *GroupController) CreateSuperGroup(ctx context.Context, groupID string, initMemberIDList []string, memberNumCount int) error { - return g.database.CreateSuperGroup(ctx, groupID, initMemberIDList, memberNumCount) +func (g *GroupController) CreateSuperGroup(ctx context.Context, groupID string, initMemberIDList []string) error { + return g.database.CreateSuperGroup(ctx, groupID, initMemberIDList) } type DataBase interface { FindGroupsByID(ctx context.Context, groupIDs []string) (groups []*relation.Group, err error) - CreateGroup(ctx context.Context, groups []*relation.Group) error + CreateGroup(ctx context.Context, groups []*relation.Group, groupMember []*relation.GroupMember) error DeleteGroupByIDs(ctx context.Context, groupIDs []string) error TakeGroupByID(ctx context.Context, groupID string) (group *relation.Group, err error) GetSuperGroupByID(ctx context.Context, groupID string) (superGroup *unrelation.SuperGroup, err error) - CreateSuperGroup(ctx context.Context, groupID string, initMemberIDList []string, memberNumCount int) error + CreateSuperGroup(ctx context.Context, groupID string, initMemberIDList []string) error } type GroupDataBase struct { @@ -100,8 +100,18 @@ func (g *GroupDataBase) FindGroupsByID(ctx context.Context, groupIDs []string) ( return g.cache.GetGroupsInfo(ctx, groupIDs) } -func (g *GroupDataBase) CreateGroup(ctx context.Context, groups []*relation.Group) error { - return g.groupDB.Create(ctx, groups) +func (g *GroupDataBase) CreateGroup(ctx context.Context, groups []*relation.Group, groupMember []*relation.GroupMember) error { + return g.db.Transaction(func(tx *gorm.DB) error { + if err := g.groupDB.Create(ctx, groups, tx); err != nil { + return err + } + if len(groupMember) > 0 { + if err := g.groupMemberDB.Create(ctx, groupMember, tx); err != nil { + return err + } + } + return nil + }) } func (g *GroupDataBase) DeleteGroupByIDs(ctx context.Context, groupIDs []string) error { @@ -136,14 +146,14 @@ func (g *GroupDataBase) Update(ctx context.Context, groups []*relation.Group) er }) } -func (g *GroupDataBase) CreateSuperGroup(ctx context.Context, groupID string, initMemberIDList []string, memberNumCount int) error { +func (g *GroupDataBase) CreateSuperGroup(ctx context.Context, groupID string, initMemberIDList []string) error { sess, err := g.mongoDB.MgoClient.StartSession() if err != nil { return err } defer sess.EndSession(ctx) sCtx := mongo.NewSessionContext(ctx, sess) - if err = g.mongoDB.CreateSuperGroup(sCtx, groupID, initMemberIDList, memberNumCount); err != nil { + if err = g.mongoDB.CreateSuperGroup(sCtx, groupID, initMemberIDList); err != nil { _ = sess.AbortTransaction(ctx) return err } @@ -158,15 +168,3 @@ func (g *GroupDataBase) CreateSuperGroup(ctx context.Context, groupID string, in func (g *GroupDataBase) GetSuperGroupByID(ctx context.Context, groupID string) (superGroup *unrelation.SuperGroup, err error) { return g.mongoDB.GetSuperGroup(ctx, groupID) } - -func (g *GroupDataBase) CreateGroupAndMember(ctx context.Context, groups []*relation.Group, groupMember []*relation.GroupMember) error { - return g.db.Transaction(func(tx *gorm.DB) error { - if err := g.groupDB.Create(ctx, groups, tx); err != nil { - return err - } - if err := g.groupMemberDB.Create(ctx, groupMember, tx); err != nil { - return err - } - return nil - }) -} diff --git a/pkg/common/db/unrelation/super_group.go b/pkg/common/db/unrelation/super_group.go index 5358f7b4c..c51bfcec9 100644 --- a/pkg/common/db/unrelation/super_group.go +++ b/pkg/common/db/unrelation/super_group.go @@ -37,7 +37,7 @@ func NewSuperGroupMgoDB(mgoClient *mongo.Client) *SuperGroupMgoDB { return &SuperGroupMgoDB{MgoDB: mgoDB, MgoClient: mgoClient, superGroupCollection: mgoDB.Collection(cSuperGroup), userToSuperGroupCollection: mgoDB.Collection(cUserToSuperGroup)} } -func (db *SuperGroupMgoDB) CreateSuperGroup(sCtx mongo.SessionContext, groupID string, initMemberIDList []string, memberNumCount int) error { +func (db *SuperGroupMgoDB) CreateSuperGroup(sCtx mongo.SessionContext, groupID string, initMemberIDList []string) error { superGroup := SuperGroup{ GroupID: groupID, MemberIDList: initMemberIDList,