diff --git a/cmd/api/main.go b/cmd/api/main.go index f154b3bd7..b39a294a8 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -4,15 +4,17 @@ import ( "bytes" "context" "fmt" - "gopkg.in/yaml.v3" "net" "strconv" + "gopkg.in/yaml.v3" + "github.com/OpenIMSDK/Open-IM-Server/internal/api" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/cmd" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/config" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/cache" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/log" + "github.com/OpenIMSDK/Open-IM-Server/pkg/discoveryregistry" "github.com/OpenIMSDK/openKeeper" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" @@ -31,11 +33,13 @@ func run(port int) error { if port == 0 { port = config.Config.Api.GinPort[0] } + var err error rdb, err := cache.NewRedis() if err != nil { return err } - zk, err := openKeeper.NewClient(config.Config.Zookeeper.ZkAddr, config.Config.Zookeeper.Schema, 10, config.Config.Zookeeper.UserName, config.Config.Zookeeper.Password) + var client discoveryregistry.SvcDiscoveryRegistry + client, err = openKeeper.NewClient(config.Config.Zookeeper.ZkAddr, config.Config.Zookeeper.Schema, 10, config.Config.Zookeeper.UserName, config.Config.Zookeeper.Password) if err != nil { return err } @@ -43,11 +47,11 @@ func run(port int) error { if err := yaml.NewEncoder(buf).Encode(config.Config); err != nil { return err } - if err := zk.RegisterConf2Registry(constant.OpenIMCommonConfigKey, buf.Bytes()); err != nil { + if err := client.RegisterConf2Registry(constant.OpenIMCommonConfigKey, buf.Bytes()); err != nil { return err } log.NewPrivateLog(constant.LogFileName) - router := api.NewGinRouter(zk, rdb) + router := api.NewGinRouter(client, rdb) var address string if config.Config.Api.ListenIP != "" { address = net.JoinHostPort(config.Config.Api.ListenIP, strconv.Itoa(port)) diff --git a/internal/rpc/group/group.go b/internal/rpc/group/group.go index ee414e726..70a98afba 100644 --- a/internal/rpc/group/group.go +++ b/internal/rpc/group/group.go @@ -256,17 +256,6 @@ func (s *groupServer) InviteUserToGroup(ctx context.Context, req *pbGroup.Invite if group.Status == constant.GroupStatusDismissed { return nil, errs.ErrDismissedAlready.Wrap() } - members, err := s.GroupDatabase.FindGroupMember(ctx, []string{group.GroupID}, nil, nil) - if err != nil { - return nil, err - } - memberMap := utils.SliceToMap(members, func(e *relationTb.GroupMemberModel) string { - return e.UserID - }) - if ids := utils.Single(req.InvitedUserIDs, utils.Keys(memberMap)); len(ids) > 0 { - log.ZDebug(ctx, "user in group", "ids", ids) - return nil, errs.ErrArgs.Wrap("user in group " + strings.Join(ids, ",")) - } userMap, err := s.UserCheck.GetUsersInfoMap(ctx, req.InvitedUserIDs, true) if err != nil { return nil, err @@ -274,11 +263,14 @@ func (s *groupServer) InviteUserToGroup(ctx context.Context, req *pbGroup.Invite if group.NeedVerification == constant.AllNeedVerification { if !tokenverify.IsAppManagerUid(ctx) { opUserID := mcontext.GetOpUserID(ctx) - member, ok := memberMap[opUserID] - if !ok { + groupMembers, err := s.GroupDatabase.FindGroupMember(ctx, []string{req.GroupID}, []string{opUserID}, nil) + if err != nil { + return nil, err + } + if len(groupMembers) <= 0 { return nil, errs.ErrNoPermission.Wrap("not in group") } - if !(member.RoleLevel == constant.GroupOwner || member.RoleLevel == constant.GroupAdmin) { + if !(groupMembers[0].RoleLevel == constant.GroupOwner || groupMembers[0].RoleLevel == constant.GroupAdmin) { var requests []*relationTb.GroupRequestModel for _, userID := range req.InvitedUserIDs { requests = append(requests, &relationTb.GroupRequestModel{ diff --git a/pkg/common/db/controller/group.go b/pkg/common/db/controller/group.go index da2b3062d..f130cc333 100644 --- a/pkg/common/db/controller/group.go +++ b/pkg/common/db/controller/group.go @@ -208,7 +208,8 @@ func (g *groupDatabase) PageGroupMember(ctx context.Context, groupIDs []string, totalGroupMembers = append(totalGroupMembers, groupMembers...) } } - return + + return uint32(len(totalGroupMembers)), totalGroupMembers, nil } for _, groupID := range groupIDs { groupMembers, err := g.cache.GetGroupMembersInfo(ctx, groupID, userIDs) @@ -234,7 +235,7 @@ func (g *groupDatabase) PageGroupMember(ctx context.Context, groupIDs []string, totalGroupMembers = append(totalGroupMembers, groupMembers...) } } - return + return uint32(len(totalGroupMembers)), totalGroupMembers, nil } for _, groupID := range groupIDs { groupMembers, err := g.cache.GetGroupMembersPage(ctx, groupID, userIDs, pageNumber, showNumber)