diff --git a/pkg/common/db/cache/group.go b/pkg/common/db/cache/group.go index ab5c3f9b2..578fe8154 100644 --- a/pkg/common/db/cache/group.go +++ b/pkg/common/db/cache/group.go @@ -45,12 +45,14 @@ type GroupCache interface { GetGroupsMemberIDs(ctx context.Context, groupIDs []string) (groupMemberIDs map[string][]string, err error) DelGroupMemberIDs(groupID string) GroupCache + + GetJoinedGroupIDs(ctx context.Context, userID string) (joinedGroupIDs []string, err error) DelJoinedGroupID(userID ...string) GroupCache GetGroupMemberInfo(ctx context.Context, groupID, userID string) (groupMember *relationTb.GroupMemberModel, err error) - GetGroupMembersInfo(ctx context.Context, groupID string, userID []string, roleLevel []int32) (groupMembers []*relationTb.GroupMemberModel, err error) + GetGroupMembersInfo(ctx context.Context, groupID string, userID []string) (groupMembers []*relationTb.GroupMemberModel, err error) GetAllGroupMembersInfo(ctx context.Context, groupID string) (groupMembers []*relationTb.GroupMemberModel, err error) - GetGroupMembersPage(ctx context.Context, groupID string, showNumber, pageNumber int32) (groupMembers []*relationTb.GroupMemberModel, err error) + GetGroupMembersPage(ctx context.Context, groupID string, userID []string, showNumber, pageNumber int32) (groupMembers []*relationTb.GroupMemberModel, err error) DelGroupMembersInfo(groupID string, userID ...string) GroupCache @@ -294,22 +296,23 @@ func (g *GroupCacheRedis) GetGroupMemberInfo(ctx context.Context, groupID, userI }) } -func (g *GroupCacheRedis) GetGroupMembersInfo(ctx context.Context, groupID string, userIDs []string, roleLevel []int32) ([]*relationTb.GroupMemberModel, error) { +func (g *GroupCacheRedis) GetGroupMembersInfo(ctx context.Context, groupID string, userIDs []string) ([]*relationTb.GroupMemberModel, error) { var keys []string for _, userID := range userIDs { keys = append(keys, g.getGroupMemberInfoKey(groupID, userID)) } return batchGetCache(ctx, g.rcClient, keys, g.expireTime, g.GetGroupMemberIndex, func(ctx context.Context) ([]*relationTb.GroupMemberModel, error) { - return g.groupMemberDB.Find(ctx, []string{groupID}, userIDs, roleLevel) + return g.groupMemberDB.Find(ctx, []string{groupID}, userIDs, nil) }) } -func (g *GroupCacheRedis) GetGroupMembersPage(ctx context.Context, groupID string, showNumber, pageNumber int32) (groupMembers []*relationTb.GroupMemberModel, err error) { +func (g *GroupCacheRedis) GetGroupMembersPage(ctx context.Context, groupID string, userIDs []string, showNumber, pageNumber int32) (groupMembers []*relationTb.GroupMemberModel, err error) { groupMemberIDs, err := g.GetGroupMemberIDs(ctx, groupID) if err != nil { return nil, err } - return g.GetGroupMembersInfo(ctx, groupID, utils.Paginate(groupMemberIDs, int(showNumber), int(showNumber)), nil) + userIDs = utils.BothExist(userIDs, groupMemberIDs) + return g.GetGroupMembersInfo(ctx, groupID, utils.Paginate(userIDs, int(showNumber), int(showNumber))) } func (g *GroupCacheRedis) GetAllGroupMembersInfo(ctx context.Context, groupID string) (groupMembers []*relationTb.GroupMemberModel, err error) { @@ -317,7 +320,7 @@ func (g *GroupCacheRedis) GetAllGroupMembersInfo(ctx context.Context, groupID st if err != nil { return nil, err } - return g.GetGroupMembersInfo(ctx, groupID, groupMemberIDs, nil) + return g.GetGroupMembersInfo(ctx, groupID, groupMemberIDs) } func (g *GroupCacheRedis) GetAllGroupMemberInfo(ctx context.Context, groupID string) ([]*relationTb.GroupMemberModel, error) { diff --git a/pkg/common/db/controller/group.go b/pkg/common/db/controller/group.go index 87b2320f6..6a741ec41 100644 --- a/pkg/common/db/controller/group.go +++ b/pkg/common/db/controller/group.go @@ -177,9 +177,16 @@ func (g *groupDatabase) TakeGroupOwner(ctx context.Context, groupID string) (*re return g.groupMemberDB.TakeOwner(ctx, groupID) // todo cache group owner } -func (g *groupDatabase) FindGroupMember(ctx context.Context, groupIDs []string, userIDs []string, roleLevels []int32) ([]*relationTb.GroupMemberModel, error) { +func (g *groupDatabase) FindGroupMember(ctx context.Context, groupIDs []string, userIDs []string, roleLevels []int32) (totalGroupMembers []*relationTb.GroupMemberModel, err error) { if roleLevels == nil { - return g.cache.GetGroupMembersInfo(ctx, groupIDs[0], userIDs, nil) + for _, groupID := range groupIDs { + groupMembers, err := g.cache.GetGroupMembersInfo(ctx, groupID, userIDs) + if err != nil { + return nil, err + } + totalGroupMembers = append(totalGroupMembers, groupMembers...) + } + return totalGroupMembers, nil } return g.groupMemberDB.Find(ctx, groupIDs, userIDs, roleLevels) } @@ -187,8 +194,25 @@ func (g *groupDatabase) FindGroupMember(ctx context.Context, groupIDs []string, func (g *groupDatabase) PageGroupMember(ctx context.Context, groupIDs []string, userIDs []string, roleLevels []int32, pageNumber, showNumber int32) (total uint32, totalGroupMembers []*relationTb.GroupMemberModel, err error) { if roleLevels == nil { if pageNumber == 0 || showNumber == 0 { + if groupIDs == nil { + for _, userID := range userIDs { + groupIDs, err := g.cache.GetJoinedGroupIDs(ctx, userID) + if err != nil { + return 0, nil, err + } + for _, groupID := range groupIDs { + groupMembers, err := g.cache.GetGroupMembersInfo(ctx, groupID, []string{userID}) + if err != nil { + return 0, nil, err + } + totalGroupMembers = append(totalGroupMembers, groupMembers...) + } + } + return + } + for _, groupID := range groupIDs { - groupMembers, err := g.cache.GetAllGroupMembersInfo(ctx, groupID) + groupMembers, err := g.cache.GetGroupMembersInfo(ctx, groupID, userIDs) if err != nil { return 0, nil, err } @@ -197,7 +221,7 @@ func (g *groupDatabase) PageGroupMember(ctx context.Context, groupIDs []string, return uint32(len(totalGroupMembers)), totalGroupMembers, nil } else { for _, groupID := range groupIDs { - groupMembers, err := g.cache.GetGroupMembersPage(ctx, groupID, pageNumber, showNumber) + groupMembers, err := g.cache.GetGroupMembersPage(ctx, groupID, userIDs, pageNumber, showNumber) if err != nil { return 0, nil, err } @@ -205,7 +229,6 @@ func (g *groupDatabase) PageGroupMember(ctx context.Context, groupIDs []string, } return uint32(len(totalGroupMembers)), totalGroupMembers, nil } - } return g.groupMemberDB.SearchMember(ctx, "", groupIDs, userIDs, roleLevels, pageNumber, showNumber) } diff --git a/pkg/common/db/relation/group_member_model.go b/pkg/common/db/relation/group_member_model.go index 68062e2e3..7adaa62e6 100644 --- a/pkg/common/db/relation/group_member_model.go +++ b/pkg/common/db/relation/group_member_model.go @@ -84,28 +84,48 @@ func (g *GroupMemberGorm) MapGroupMemberNum(ctx context.Context, groupIDs []stri } func (g *GroupMemberGorm) FindJoinUserID(ctx context.Context, groupIDs []string) (groupUsers map[string][]string, err error) { - var items []struct { - GroupID string `gorm:"group_id"` - UserID string `gorm:"user_id"` - } - if err := g.db(ctx).Model(&relation.GroupMemberModel{}).Where("group_id in (?)", groupIDs).Find(&items).Error; err != nil { + var groupMembers []*relation.GroupMemberModel + if err := g.db(ctx).Select("group_id, user_id").Where("group_id in (?)", groupIDs).Find(&groupMembers).Error; err != nil { return nil, utils.Wrap(err, "") } groupUsers = make(map[string][]string) - for _, item := range items { - groupUsers[item.GroupID] = append(groupUsers[item.GroupID], item.UserID) + for _, item := range groupMembers { + v, ok := groupUsers[item.GroupID] + if !ok { + groupUsers[item.GroupID] = []string{item.UserID} + } else { + groupUsers[item.GroupID] = append(v, item.UserID) + } } return groupUsers, nil } func (g *GroupMemberGorm) FindMemberUserID(ctx context.Context, groupID string) (userIDs []string, err error) { - return userIDs, utils.Wrap(g.db(ctx).Model(&relation.GroupMemberModel{}).Where("group_id = ?", groupID).Pluck("user_id", &userIDs).Error, "") + return userIDs, utils.Wrap(g.db(ctx).Where("group_id = ?", groupID).Pluck("user_id", &userIDs).Error, "") } func (g *GroupMemberGorm) FindUserJoinedGroupID(ctx context.Context, userID string) (groupIDs []string, err error) { - return groupIDs, utils.Wrap(g.db(ctx).Model(&relation.GroupMemberModel{}).Where("user_id = ?", userID).Pluck("group_id", &groupIDs).Error, "") + return groupIDs, utils.Wrap(g.db(ctx).Where("user_id = ?", userID).Pluck("group_id", &groupIDs).Error, "") } func (g *GroupMemberGorm) TakeGroupMemberNum(ctx context.Context, groupID string) (count int64, err error) { - return count, utils.Wrap(g.db(ctx).Model(&relation.GroupMemberModel{}).Where("group_id = ?", groupID).Count(&count).Error, "") + return count, utils.Wrap(g.db(ctx).Where("group_id = ?", groupID).Count(&count).Error, "") +} + +func (g *GroupMemberGorm) FindUsersJoinedGroupID(ctx context.Context, userIDs []string) (map[string][]string, error) { + var groupMembers []*relation.GroupMemberModel + err := g.db(ctx).Select("group_id, user_id").Where("user_id IN (?)", userIDs).Find(&groupMembers).Error + if err != nil { + return nil, err + } + result := make(map[string][]string) + for _, groupMember := range groupMembers { + v, ok := result[groupMember.UserID] + if !ok { + result[groupMember.UserID] = []string{groupMember.GroupID} + } else { + result[groupMember.UserID] = append(v, groupMember.GroupID) + } + } + return result, nil } diff --git a/pkg/common/db/table/relation/group_member.go b/pkg/common/db/table/relation/group_member.go index fb1acf3ab..c49fbcc48 100644 --- a/pkg/common/db/table/relation/group_member.go +++ b/pkg/common/db/table/relation/group_member.go @@ -43,4 +43,5 @@ type GroupMemberModelInterface interface { FindJoinUserID(ctx context.Context, groupIDs []string) (groupUsers map[string][]string, err error) FindUserJoinedGroupID(ctx context.Context, userID string) (groupIDs []string, err error) TakeGroupMemberNum(ctx context.Context, groupID string) (count int64, err error) + FindUsersJoinedGroupID(ctx context.Context, userIDs []string) (map[string][]string, error) }