diff --git a/internal/rpc/group/group.go b/internal/rpc/group/group.go index 870646631..8bad860f3 100644 --- a/internal/rpc/group/group.go +++ b/internal/rpc/group/group.go @@ -122,11 +122,47 @@ func (s *groupServer) CreateGroup(ctx context.Context, req *pbGroup.CreateGroupR trace_log.SetContextInfo(ctx, utils.GetFuncName(1), nil, "req", req.String(), "resp", resp.String()) trace_log.ShowLog(ctx) }() - if err := token_verify.CheckAccessV2(ctx, req.OpUserID, req.OwnerUserID); err != nil { SetErrorForResp(err, resp.CommonResp) return } + var groupOwnerNum int + var userIDs []string + for _, info := range req.InitMemberList { + if info.RoleLevel == constant.GroupOwner { + groupOwnerNum++ + } + userIDs = append(userIDs, info.UserID) + } + if req.OwnerUserID != "" { + groupOwnerNum++ + userIDs = append(userIDs, req.OwnerUserID) + } + if groupOwnerNum != 1 { + SetErrorForResp(constant.ErrArgs, resp.CommonResp) + return + } + if utils.IsRepeatStringSlice(userIDs) { + SetErrorForResp(constant.ErrArgs, resp.CommonResp) + return + } + users, err := rocksCache.GetUserInfoFromCacheBatch(userIDs) + if err != nil { + SetErrorForResp(err, resp.CommonResp) + return + } + if len(users) != len(userIDs) { + SetErrorForResp(constant.ErrArgs, resp.CommonResp) + return + } + userMap := make(map[string]*imdb.User) + for i, user := range users { + userMap[user.UserID] = users[i] + } + if err := s.DelGroupAndUserCache(req.OperationID, "", userIDs); err != nil { + SetErrorForResp(err, resp.CommonResp) + return + } if err := callbackBeforeCreateGroup(ctx, req); err != nil { SetErrorForResp(err, resp.CommonResp) return @@ -138,7 +174,6 @@ func (s *groupServer) CreateGroup(ctx context.Context, req *pbGroup.CreateGroupR bi.SetString(groupId[0:8], 16) groupId = bi.String() } - //to group groupInfo := imdb.Group{} utils.CopyStructFields(&groupInfo, req.GroupInfo) groupInfo.CreatorUserID = req.OpUserID @@ -147,125 +182,69 @@ func (s *groupServer) CreateGroup(ctx context.Context, req *pbGroup.CreateGroupR if groupInfo.NotificationUpdateTime.Unix() < 0 { groupInfo.NotificationUpdateTime = utils.UnixSecondToTime(0) } - if err := (*imdb.Group)(nil).Create(ctx, []*imdb.Group{&groupInfo}); err != nil { - SetErrorForResp(err, resp.CommonResp) - return - } - groupMember := imdb.GroupMember{} - us := &imdb.User{} - if req.OwnerUserID != "" { - var userIDList []string - for _, v := range req.InitMemberList { - userIDList = append(userIDList, v.UserID) - } - userIDList = append(userIDList, req.OwnerUserID) - if err := s.DelGroupAndUserCache(req.OperationID, "", userIDList); err != nil { - SetErr(ctx, "DelGroupAndUserCache", err, &resp.CommonResp.ErrCode, &resp.CommonResp.ErrMsg) - return - } - var err error - us, err = imdb.GetUserByUserID(req.OwnerUserID) - if err != nil { - SetErr(ctx, "GetUserByUserID", err, &resp.CommonResp.ErrCode, &resp.CommonResp.ErrMsg, "userID", req.OwnerUserID) - return - } - //to group member - groupMember = imdb.GroupMember{GroupID: groupId, RoleLevel: constant.GroupOwner, OperatorUserID: req.OpUserID, JoinSource: constant.JoinByInvitation, InviterUserID: req.OpUserID} - utils.CopyStructFields(&groupMember, us) - if err := CallbackBeforeMemberJoinGroup(ctx, req.OperationID, &groupMember, groupInfo.Ex); err != nil { - SetErrorForResp(err, resp.CommonResp) - return - } - if err := (*imdb.GroupMember)(nil).Create(ctx, []*imdb.GroupMember{&groupMember}); err != nil { - SetErrorForResp(err, resp.CommonResp) - return - } - } - var okUserIDList []string if req.GroupInfo.GroupType != constant.SuperGroup { - //to group member var groupMembers []*imdb.GroupMember - for _, user := range req.InitMemberList { - us, err := rocksCache.GetUserInfoFromCache(user.UserID) - if err != nil { - trace_log.SetContextInfo(ctx, "GetUserInfoFromCache", err, "userID", user.UserID) - continue + joinGroup := func(userID string, roleLevel int32) error { + groupMember := &imdb.GroupMember{GroupID: groupId, RoleLevel: roleLevel, OperatorUserID: req.OpUserID, JoinSource: constant.JoinByInvitation, InviterUserID: req.OpUserID} + user := userMap[userID] + utils.CopyStructFields(&groupMember, user) + if err := CallbackBeforeMemberJoinGroup(ctx, req.OperationID, groupMember, groupInfo.Ex); err != nil { + return err } - if user.RoleLevel == constant.GroupOwner { - trace_log.SetContextInfo(ctx, "GetUserInfoFromCache", nil, "userID", user.UserID, "msg", "only one owner, failed ") - continue - } - groupMember.RoleLevel = user.RoleLevel - groupMember.JoinSource = constant.JoinByInvitation - groupMember.InviterUserID = req.OpUserID - utils.CopyStructFields(&groupMember, us) - if err := CallbackBeforeMemberJoinGroup(ctx, req.OperationID, &groupMember, groupInfo.Ex); err != nil { + groupMembers = append(groupMembers, groupMember) + return nil + } + if req.OwnerUserID == "" { + if err := joinGroup(req.OwnerUserID, constant.GroupOwner); err != nil { + SetErrorForResp(err, resp.CommonResp) + return + } + } + for _, info := range req.InitMemberList { + if err := joinGroup(info.UserID, info.RoleLevel); err != nil { SetErrorForResp(err, resp.CommonResp) return } - groupMembers = append(groupMembers, &groupMember) - okUserIDList = append(okUserIDList, user.UserID) } if err := (*imdb.GroupMember)(nil).Create(ctx, groupMembers); err != nil { SetErrorForResp(err, resp.CommonResp) return } - group, err := rocksCache.GetGroupInfoFromCache(ctx, groupId) - if err != nil { - SetErrorForResp(err, resp.CommonResp) - return - } - utils.CopyStructFields(resp.GroupInfo, group) - memberCount, err := rocksCache.GetGroupMemberNumFromCache(groupId) - if err != nil { - SetErrorForResp(err, resp.CommonResp) - return - } - resp.GroupInfo.MemberCount = uint32(memberCount) - if req.OwnerUserID != "" { - resp.GroupInfo.OwnerUserID = req.OwnerUserID - okUserIDList = append(okUserIDList, req.OwnerUserID) - } - // superGroup stored in mongodb } else { - for _, v := range req.InitMemberList { - okUserIDList = append(okUserIDList, v.UserID) - } - if err := db.DB.CreateSuperGroup(groupId, okUserIDList, len(okUserIDList)); err != nil { + if err := db.DB.CreateSuperGroup(groupId, userIDs, len(userIDs)); err != nil { SetErrorForResp(err, resp.CommonResp) return } } - - if len(okUserIDList) != 0 { - if req.GroupInfo.GroupType != constant.SuperGroup { - chat.GroupCreatedNotification(req.OperationID, req.OpUserID, groupId, okUserIDList) - } else { - for _, userID := range okUserIDList { - if err := rocksCache.DelJoinedSuperGroupIDListFromCache(userID); err != nil { - trace_log.SetContextInfo(ctx, "DelJoinedSuperGroupIDListFromCache", err, "userID", userID) - //log.NewWarn(req.OperationID, utils.GetSelfFuncName(), userID, err.Error()) - } + if err := (*imdb.Group)(nil).Create(ctx, []*imdb.Group{&groupInfo}); err != nil { + SetErrorForResp(err, resp.CommonResp) + return + } + utils.CopyStructFields(resp.GroupInfo, groupInfo) + resp.GroupInfo.MemberCount = uint32(len(userIDs)) + if req.GroupInfo.GroupType != constant.SuperGroup { + chat.GroupCreatedNotification(req.OperationID, req.OpUserID, groupId, userIDs) + } else { + for _, userID := range userIDs { + if err := rocksCache.DelJoinedSuperGroupIDListFromCache(userID); err != nil { + trace_log.SetContextInfo(ctx, "DelJoinedSuperGroupIDListFromCache", err, "userID", userID) } - go func() { - for _, v := range okUserIDList { - chat.SuperGroupNotification(req.OperationID, v, v) - } - }() } - return - } else { - //log.NewInfo(req.OperationID, "rpc CreateGroup return ", resp.String()) - return + go func() { + for _, v := range userIDs { + chat.SuperGroupNotification(req.OperationID, v, v) + } + }() } + return } -func (s *groupServer) GetJoinedGroupList(ctx context.Context, req *pbGroup.GetJoinedGroupListReq) (resp *pbGroup.GetJoinedGroupListResp,_ error) { +func (s *groupServer) GetJoinedGroupList(ctx context.Context, req *pbGroup.GetJoinedGroupListReq) (resp *pbGroup.GetJoinedGroupListResp, _ error) { resp = &pbGroup.GetJoinedGroupListResp{CommonResp: &open_im_sdk.CommonResp{}} ctx = trace_log.NewRpcCtx(ctx, utils.GetSelfFuncName(), req.OperationID) trace_log.SetContextInfo(ctx, utils.GetSelfFuncName(), nil, "req", req, "resp", resp) defer trace_log.ShowLog(ctx) - if err := token_verify.CheckAccessV2(ctx, req.OpUserID, req.FromUserID);err != nil { + if err := token_verify.CheckAccessV2(ctx, req.OpUserID, req.FromUserID); err != nil { SetErrorForResp(err, &resp.CommonResp.ErrCode, &resp.CommonResp.ErrMsg) return } diff --git a/pkg/common/db/rocks_cache/rocks_cache.go b/pkg/common/db/rocks_cache/rocks_cache.go index 95298e458..d11459110 100644 --- a/pkg/common/db/rocks_cache/rocks_cache.go +++ b/pkg/common/db/rocks_cache/rocks_cache.go @@ -205,6 +205,18 @@ func GetUserInfoFromCache(userID string) (*imdb.User, error) { return userInfo, utils.Wrap(err, "") } +func GetUserInfoFromCacheBatch(userIDs []string) ([]*imdb.User, error) { + var users []*imdb.User + for _, userID := range userIDs { + user, err := GetUserInfoFromCache(userID) + if err != nil { + return nil, err + } + users = append(users, user) + } + return users, nil +} + func DelUserInfoFromCache(userID string) error { return db.DB.Rc.TagAsDeleted(userInfoCache + userID) } diff --git a/pkg/utils/strings.go b/pkg/utils/strings.go index f611f1a35..ef744730f 100644 --- a/pkg/utils/strings.go +++ b/pkg/utils/strings.go @@ -37,7 +37,7 @@ func Uint32ToString(i uint32) string { return strconv.FormatInt(int64(i), 10) } -//judge a string whether in the string list +// judge a string whether in the string list func IsContain(target string, List []string) bool { for _, element := range List { @@ -80,7 +80,7 @@ func StructToJsonBytes(param interface{}) []byte { return dataType } -//The incoming parameter must be a pointer +// The incoming parameter must be a pointer func JsonStringToStruct(s string, args interface{}) error { err := json.Unmarshal([]byte(s), args) return err @@ -121,3 +121,14 @@ func RemoveDuplicateElement(idList []string) []string { } return result } + +func IsRepeatStringSlice(arr []string) bool { + t := make(map[string]struct{}) + for _, s := range arr { + if _, ok := t[s]; ok { + return true + } + t[s] = struct{}{} + } + return false +}