diff --git a/internal/rpc/group/super_group.go b/internal/rpc/group/super_group.go index 5e0fc1892..ebdab3723 100644 --- a/internal/rpc/group/super_group.go +++ b/internal/rpc/group/super_group.go @@ -3,42 +3,53 @@ package group import ( "Open_IM/pkg/common/constant" "Open_IM/pkg/common/db/table/relation" + "Open_IM/pkg/common/db/table/unrelation" pbGroup "Open_IM/pkg/proto/group" sdk_ws "Open_IM/pkg/proto/sdk_ws" "Open_IM/pkg/utils" "context" + "fmt" + "strings" ) func (s *groupServer) GetJoinedSuperGroupList(ctx context.Context, req *pbGroup.GetJoinedSuperGroupListReq) (*pbGroup.GetJoinedSuperGroupListResp, error) { resp := &pbGroup.GetJoinedSuperGroupListResp{} - total, groupIDs, err := s.GroupInterface.FindJoinSuperGroup(ctx, req.UserID, req.Pagination.PageNumber, req.Pagination.ShowNumber) + joinSuperGroup, err := s.GroupInterface.FindJoinSuperGroup(ctx, req.UserID) if err != nil { return nil, err } - resp.Total = total - if len(groupIDs) == 0 { + if len(joinSuperGroup.GroupIDs) == 0 { return resp, nil } - numMap, err := s.GroupInterface.MapSuperGroupMemberNum(ctx, groupIDs) - if err != nil { - return nil, err - } - owners, err := s.GroupInterface.FindGroupMember(ctx, groupIDs, nil, []int32{constant.GroupOwner}) + owners, err := s.GroupInterface.FindGroupMember(ctx, joinSuperGroup.GroupIDs, nil, []int32{constant.GroupOwner}) if err != nil { return nil, err } ownerMap := utils.SliceToMap(owners, func(e *relation.GroupMemberModel) string { return e.GroupID }) - groups, err := s.GroupInterface.FindGroup(ctx, groupIDs) + if ids := utils.Single(joinSuperGroup.GroupIDs, utils.Keys(ownerMap)); len(ids) > 0 { + return nil, constant.ErrData.Wrap(fmt.Sprintf("super group %s not owner", strings.Join(ids, ","))) + } + groups, err := s.GroupInterface.FindGroup(ctx, joinSuperGroup.GroupIDs) if err != nil { return nil, err } groupMap := utils.SliceToMap(groups, func(e *relation.GroupModel) string { return e.GroupID }) - resp.Groups = utils.Slice(groupIDs, func(groupID string) *sdk_ws.GroupInfo { - return DbToPbGroupInfo(groupMap[groupID], ownerMap[groupID].UserID, numMap[groupID]) + if ids := utils.Single(joinSuperGroup.GroupIDs, utils.Keys(groupMap)); len(ids) > 0 { + return nil, constant.ErrData.Wrap(fmt.Sprintf("super group info %s not found", strings.Join(ids, ","))) + } + superGroupMembers, err := s.GroupInterface.FindSuperGroup(ctx, joinSuperGroup.GroupIDs) + if err != nil { + return nil, err + } + superGroupMemberMap := utils.SliceToMapAny(superGroupMembers, func(e *unrelation.SuperGroupModel) (string, []string) { + return e.GroupID, e.MemberIDs + }) + resp.Groups = utils.Slice(joinSuperGroup.GroupIDs, func(groupID string) *sdk_ws.GroupInfo { + return DbToPbGroupInfo(groupMap[groupID], ownerMap[groupID].UserID, uint32(len(superGroupMemberMap))) }) return resp, nil } @@ -52,10 +63,13 @@ func (s *groupServer) GetSuperGroupsInfo(ctx context.Context, req *pbGroup.GetSu if err != nil { return nil, err } - numMap, err := s.GroupInterface.MapSuperGroupMemberNum(ctx, req.GroupIDs) + superGroupMembers, err := s.GroupInterface.FindSuperGroup(ctx, req.GroupIDs) if err != nil { return nil, err } + superGroupMemberMap := utils.SliceToMapAny(superGroupMembers, func(e *unrelation.SuperGroupModel) (string, []string) { + return e.GroupID, e.MemberIDs + }) owners, err := s.GroupInterface.FindGroupMember(ctx, req.GroupIDs, nil, []int32{constant.GroupOwner}) if err != nil { return nil, err @@ -64,7 +78,7 @@ func (s *groupServer) GetSuperGroupsInfo(ctx context.Context, req *pbGroup.GetSu return e.GroupID }) resp.GroupInfos = utils.Slice(groups, func(e *relation.GroupModel) *sdk_ws.GroupInfo { - return DbToPbGroupInfo(e, ownerMap[e.GroupID].UserID, numMap[e.GroupID]) + return DbToPbGroupInfo(e, ownerMap[e.GroupID].UserID, uint32(len(superGroupMemberMap[e.GroupID]))) }) return resp, nil } diff --git a/pkg/common/db/controller/group.go b/pkg/common/db/controller/group.go index d715e400b..a10abdb61 100644 --- a/pkg/common/db/controller/group.go +++ b/pkg/common/db/controller/group.go @@ -43,13 +43,12 @@ type GroupInterface interface { TakeGroupRequest(ctx context.Context, groupID string, userID string) (*relation2.GroupRequestModel, error) PageGroupRequestUser(ctx context.Context, userID string, pageNumber, showNumber int32) (int32, []*relation2.GroupRequestModel, error) // SuperGroup - TakeSuperGroup(ctx context.Context, groupID string) (superGroup *unrelation2.SuperGroupModel, err error) - FindJoinSuperGroup(ctx context.Context, userID string, pageNumber, showNumber int32) (total int32, groupIDs []string, err error) + FindSuperGroup(ctx context.Context, groupIDs []string) ([]*unrelation2.SuperGroupModel, error) + FindJoinSuperGroup(ctx context.Context, userID string) (superGroup *unrelation2.UserToSuperGroupModel, err error) CreateSuperGroup(ctx context.Context, groupID string, initMemberIDList []string) error DeleteSuperGroup(ctx context.Context, groupID string) error DeleteSuperGroupMember(ctx context.Context, groupID string, userIDs []string) error CreateSuperGroupMember(ctx context.Context, groupID string, userIDs []string) error - MapSuperGroupMemberNum(ctx context.Context, groupIDs []string) (map[string]uint32, error) } var _ GroupInterface = (*GroupController)(nil) @@ -138,12 +137,15 @@ func (g *GroupController) PageGroupRequestUser(ctx context.Context, userID strin return g.database.PageGroupRequestUser(ctx, userID, pageNumber, showNumber) } -func (g *GroupController) TakeSuperGroup(ctx context.Context, groupID string) (superGroup *unrelation2.SuperGroupModel, err error) { - return g.database.TakeSuperGroup(ctx, groupID) +// func (g *GroupController) TakeSuperGroup(ctx context.Context, groupID string) (superGroup *unrelation2.SuperGroupModel, err error) { +// return g.database.TakeSuperGroup(ctx, groupID) +// } +func (g *GroupController) FindSuperGroup(ctx context.Context, groupIDs []string) ([]*unrelation2.SuperGroupModel, error) { + return g.database.FindSuperGroup(ctx, groupIDs) } -func (g *GroupController) FindJoinSuperGroup(ctx context.Context, userID string, pageNumber, showNumber int32) (total int32, groupIDs []string, err error) { - return g.database.FindJoinSuperGroup(ctx, userID, pageNumber, showNumber) +func (g *GroupController) FindJoinSuperGroup(ctx context.Context, userID string) (*unrelation2.UserToSuperGroupModel, error) { + return g.database.FindJoinSuperGroup(ctx, userID) } func (g *GroupController) CreateSuperGroup(ctx context.Context, groupID string, initMemberIDList []string) error { @@ -162,10 +164,6 @@ func (g *GroupController) CreateSuperGroupMember(ctx context.Context, groupID st return g.database.CreateSuperGroupMember(ctx, groupID, userIDs) } -func (g *GroupController) MapSuperGroupMemberNum(ctx context.Context, groupIDs []string) (map[string]uint32, error) { - return g.database.MapSuperGroupMemberNum(ctx, groupIDs) -} - type GroupDataBaseInterface interface { CreateGroup(ctx context.Context, groups []*relation2.GroupModel, groupMembers []*relation2.GroupMemberModel) error TakeGroup(ctx context.Context, groupID string) (group *relation2.GroupModel, err error) @@ -190,13 +188,13 @@ type GroupDataBaseInterface interface { TakeGroupRequest(ctx context.Context, groupID string, userID string) (*relation2.GroupRequestModel, error) PageGroupRequestUser(ctx context.Context, userID string, pageNumber, showNumber int32) (int32, []*relation2.GroupRequestModel, error) // SuperGroup - TakeSuperGroup(ctx context.Context, groupID string) (superGroup *unrelation2.SuperGroupModel, err error) - FindJoinSuperGroup(ctx context.Context, userID string, pageNumber, showNumber int32) (total int32, groupIDs []string, err error) + //TakeSuperGroup(ctx context.Context, groupID string) (superGroup *unrelation2.SuperGroupModel, err error) + FindSuperGroup(ctx context.Context, groupIDs []string) ([]*unrelation2.SuperGroupModel, error) + FindJoinSuperGroup(ctx context.Context, userID string) (*unrelation2.UserToSuperGroupModel, error) CreateSuperGroup(ctx context.Context, groupID string, initMemberIDList []string) error DeleteSuperGroup(ctx context.Context, groupID string) error DeleteSuperGroupMember(ctx context.Context, groupID string, userIDs []string) error CreateSuperGroupMember(ctx context.Context, groupID string, userIDs []string) error - MapSuperGroupMemberNum(ctx context.Context, groupIDs []string) (map[string]uint32, error) } func newGroupDatabase(db *gorm.DB, rdb redis.UniversalClient, mgoClient *mongo.Client) GroupDataBaseInterface { @@ -356,12 +354,16 @@ func (g *GroupDataBase) PageGroupRequestUser(ctx context.Context, userID string, return g.groupRequestDB.Page(ctx, userID, pageNumber, showNumber) } -func (g *GroupDataBase) TakeSuperGroup(ctx context.Context, groupID string) (superGroup *unrelation2.SuperGroupModel, err error) { - return g.mongoDB.GetSuperGroup(ctx, groupID) +//func (g *GroupDataBase) TakeSuperGroup(ctx context.Context, groupID string) (superGroup *unrelation2.SuperGroupModel, err error) { +// return g.mongoDB.GetSuperGroup(ctx, groupID) +//} + +func (g *GroupDataBase) FindSuperGroup(ctx context.Context, groupIDs []string) ([]*unrelation2.SuperGroupModel, error) { + return g.mongoDB.FindSuperGroup(ctx, groupIDs) } -func (g *GroupDataBase) FindJoinSuperGroup(ctx context.Context, userID string, pageNumber, showNumber int32) (total int32, groupIDs []string, err error) { - return g.mongoDB.GetJoinGroup(ctx, userID, pageNumber, showNumber) +func (g *GroupDataBase) FindJoinSuperGroup(ctx context.Context, userID string) (*unrelation2.UserToSuperGroupModel, error) { + return g.mongoDB.GetSuperGroupByUserID(ctx, userID) } func (g *GroupDataBase) CreateSuperGroup(ctx context.Context, groupID string, initMemberIDList []string) error { @@ -385,10 +387,6 @@ func (g *GroupDataBase) CreateSuperGroupMember(ctx context.Context, groupID stri return g.mongoDB.AddUserToSuperGroup(ctx, groupID, userIDs) } -func (g *GroupDataBase) MapSuperGroupMemberNum(ctx context.Context, groupIDs []string) (map[string]uint32, error) { - return g.mongoDB.MapGroupMemberCount(ctx, groupIDs) -} - func MongoTransaction(ctx context.Context, mgo *mongo.Client, fn func(ctx mongo.SessionContext) error) error { sess, err := mgo.StartSession() if err != nil { diff --git a/pkg/common/db/unrelation/super_group.go b/pkg/common/db/unrelation/super_group.go index e3a57d624..6f7ee5fd3 100644 --- a/pkg/common/db/unrelation/super_group.go +++ b/pkg/common/db/unrelation/super_group.go @@ -46,10 +46,24 @@ func (db *SuperGroupMongoDriver) CreateSuperGroup(sCtx mongo.SessionContext, gro } -func (db *SuperGroupMongoDriver) GetSuperGroup(ctx context.Context, groupID string) (*unrelation.SuperGroupModel, error) { - superGroup := unrelation.SuperGroupModel{} - err := db.superGroupCollection.FindOne(ctx, bson.M{"group_id": groupID}).Decode(&superGroup) - return &superGroup, err +//func (db *SuperGroupMongoDriver) GetSuperGroup(ctx context.Context, groupID string) (*unrelation.SuperGroupModel, error) { +// superGroup := unrelation.SuperGroupModel{} +// err := db.superGroupCollection.FindOne(ctx, bson.M{"group_id": groupID}).Decode(&superGroup) +// return &superGroup, err +//} + +func (db *SuperGroupMongoDriver) FindSuperGroup(ctx context.Context, groupIDs []string) (groups []*unrelation.SuperGroupModel, err error) { + cursor, err := db.superGroupCollection.Find(ctx, bson.M{"group_id": bson.M{ + "$in": groupIDs, + }}) + if err != nil { + return nil, utils.Wrap(err, "") + } + defer cursor.Close(ctx) + if err := cursor.All(ctx, &groups); err != nil { + return nil, utils.Wrap(err, "") + } + return groups, nil } func (db *SuperGroupMongoDriver) AddUserToSuperGroup(ctx context.Context, groupID string, userIDs []string) error { @@ -98,16 +112,6 @@ func (db *SuperGroupMongoDriver) GetSuperGroupByUserID(ctx context.Context, user return &user, utils.Wrap(err, "") } -func (db *SuperGroupMongoDriver) GetJoinGroup(ctx context.Context, userID string, pageNumber, showNumber int32) (int32, []string, error) { - //TODO implement me - panic("implement me") -} - -func (db *SuperGroupMongoDriver) MapGroupMemberCount(ctx context.Context, groupIDs []string) (map[string]uint32, error) { - //TODO implement me - panic("implement me") -} - func (db *SuperGroupMongoDriver) DeleteSuperGroup(ctx context.Context, groupID string) error { opts := options.Session().SetDefaultReadConcern(readconcern.Majority()) return db.MgoDB.Client().UseSessionWithOptions(ctx, opts, func(sCtx mongo.SessionContext) error { diff --git a/pkg/utils/utils_v2.go b/pkg/utils/utils_v2.go index 76d3946f8..c01103bfa 100644 --- a/pkg/utils/utils_v2.go +++ b/pkg/utils/utils_v2.go @@ -322,6 +322,10 @@ func If[T any](isa bool, a, b T) T { return b } +func ToPtr[T any](t T) *T { + return &t +} + // Equal 比较切片是否相对(包括元素顺序) func Equal[E comparable](a []E, b []E) bool { if len(a) != len(b) {