diff --git a/go.mod b/go.mod index d65977757..0e3a13904 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ require ( github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 github.com/mitchellh/mapstructure v1.5.0 github.com/openimsdk/protocol v0.0.73-alpha.6 - github.com/openimsdk/tools v0.0.50-alpha.79 + github.com/openimsdk/tools v0.0.50-alpha.81 github.com/pkg/errors v0.9.1 // indirect github.com/prometheus/client_golang v1.18.0 github.com/stretchr/testify v1.9.0 diff --git a/go.sum b/go.sum index 390a51c4a..6bc410a2d 100644 --- a/go.sum +++ b/go.sum @@ -345,12 +345,12 @@ github.com/onsi/ginkgo/v2 v2.19.0 h1:9Cnnf7UHo57Hy3k6/m5k3dRfGTMXGvxhHFvkDTCTpvA github.com/onsi/ginkgo/v2 v2.19.0/go.mod h1:rlwLi9PilAFJ8jCg9UE1QP6VBpd6/xj3SRC0d6TU0To= github.com/onsi/gomega v1.25.0 h1:Vw7br2PCDYijJHSfBOWhov+8cAnUf8MfMaIOV323l6Y= github.com/onsi/gomega v1.25.0/go.mod h1:r+zV744Re+DiYCIPRlYOTxn0YkOLcAnW8k1xXdMPGhM= -github.com/openimsdk/gomake v0.0.15-alpha.2 h1:5Q8yl8ezy2yx+q8/ucU/t4kJnDfCzNOrkXcDACCqtyM= -github.com/openimsdk/gomake v0.0.15-alpha.2/go.mod h1:PndCozNc2IsQIciyn9mvEblYWZwJmAI+06z94EY+csI= +github.com/openimsdk/gomake v0.0.15-alpha.5 h1:eEZCEHm+NsmcO3onXZPIUbGFCYPYbsX5beV3ZyOsGhY= +github.com/openimsdk/gomake v0.0.15-alpha.5/go.mod h1:PndCozNc2IsQIciyn9mvEblYWZwJmAI+06z94EY+csI= github.com/openimsdk/protocol v0.0.73-alpha.6 h1:sna9coWG7HN1zObBPtvG0Ki/vzqHXiB4qKbA5P3w7kc= github.com/openimsdk/protocol v0.0.73-alpha.6/go.mod h1:WF7EuE55vQvpyUAzDXcqg+B+446xQyEba0X35lTINmw= -github.com/openimsdk/tools v0.0.50-alpha.79 h1:jxYEbrzaze4Z2r4NrKad816buZ690ix0L9MTOOOH3ik= -github.com/openimsdk/tools v0.0.50-alpha.79/go.mod h1:n2poR3asX1e1XZce4O+MOWAp+X02QJRFvhcLCXZdzRo= +github.com/openimsdk/tools v0.0.50-alpha.81 h1:VbuJKtigNXLkCKB/Q6f2UHsqoSaTOAwS8F51c1nhOCA= +github.com/openimsdk/tools v0.0.50-alpha.81/go.mod h1:n2poR3asX1e1XZce4O+MOWAp+X02QJRFvhcLCXZdzRo= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ= diff --git a/internal/api/config_manager.go b/internal/api/config_manager.go index 4d846dd9b..35ab060a7 100644 --- a/internal/api/config_manager.go +++ b/internal/api/config_manager.go @@ -45,7 +45,7 @@ func NewConfigManager(IMAdminUserID []string, cfg *config.AllConfig, client *cli } func (cm *ConfigManager) CheckAdmin(c *gin.Context) { - if err := authverify.CheckAdmin(c, cm.imAdminUserID); err != nil { + if err := authverify.CheckAdmin(c); err != nil { apiresp.GinError(c, err) c.Abort() } diff --git a/internal/api/msg.go b/internal/api/msg.go index 1d53cbc48..8be4832e6 100644 --- a/internal/api/msg.go +++ b/internal/api/msg.go @@ -281,7 +281,7 @@ func (m *MessageApi) SendMessage(c *gin.Context) { } // Check if the user has the app manager role. - if !authverify.IsAppManagerUid(c, m.imAdminUserID) { + if !authverify.IsAdmin(c) { // Respond with a permission error if the user is not an app manager. apiresp.GinError(c, errs.ErrNoPermission.WrapMsg("only app manager can send message")) return @@ -355,7 +355,7 @@ func (m *MessageApi) SendBusinessNotification(c *gin.Context) { if req.ReliabilityLevel == nil { req.ReliabilityLevel = datautil.ToPtr(1) } - if !authverify.IsAppManagerUid(c, m.imAdminUserID) { + if !authverify.IsAdmin(c) { apiresp.GinError(c, errs.ErrNoPermission.WrapMsg("only app manager can send message")) return } @@ -399,7 +399,7 @@ func (m *MessageApi) BatchSendMsg(c *gin.Context) { apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap()) return } - if err := authverify.CheckAdmin(c, m.imAdminUserID); err != nil { + if err := authverify.CheckAdmin(c); err != nil { apiresp.GinError(c, errs.ErrNoPermission.WrapMsg("only app manager can send message")) return } diff --git a/internal/api/router.go b/internal/api/router.go index c7bc3c724..e9e5f6d5f 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -9,6 +9,11 @@ import ( "github.com/gin-gonic/gin" "github.com/gin-gonic/gin/binding" "github.com/go-playground/validator/v10" + "github.com/openimsdk/open-im-server/v3/pkg/authverify" + "github.com/openimsdk/tools/mcontext" + "github.com/openimsdk/tools/utils/datautil" + clientv3 "go.etcd.io/etcd/client/v3" + "github.com/openimsdk/open-im-server/v3/internal/api/jssdk" "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics" @@ -96,7 +101,7 @@ func newGinRouter(ctx context.Context, client discovery.SvcDiscoveryRegistry, cf r.Use(gzip.Gzip(gzip.BestSpeed)) } r.Use(prommetricsGin(), gin.RecoveryWithWriter(gin.DefaultErrorWriter, mw.GinPanicErr), mw.CorsHandler(), - mw.GinParseOperationID(), GinParseToken(rpcli.NewAuthClient(authConn))) + mw.GinParseOperationID(), GinParseToken(rpcli.NewAuthClient(authConn)), setGinIsAdmin(cfg.Share.IMAdminUserID)) u := NewUserApi(user.NewUserClient(userConn), client, cfg.Discovery.RpcService) { @@ -352,6 +357,14 @@ func GinParseToken(authClient *rpcli.AuthClient) gin.HandlerFunc { } } +func setGinIsAdmin(imAdminUserID []string) gin.HandlerFunc { + return func(c *gin.Context) { + opUserID := mcontext.GetOpUserID(c) + admin := datautil.Contain(opUserID, imAdminUserID...) + c.Set(authverify.CtxIsAdminKey, admin) + } +} + // Whitelist api not parse token var Whitelist = []string{ "/auth/get_admin_token", diff --git a/internal/msggateway/hub_server.go b/internal/msggateway/hub_server.go index 887a90d7a..8c744b7d1 100644 --- a/internal/msggateway/hub_server.go +++ b/internal/msggateway/hub_server.go @@ -101,7 +101,7 @@ func NewServer(longConnServer LongConnServer, conf *Config, ready func(srv *Serv } func (s *Server) GetUsersOnlineStatus(ctx context.Context, req *msggateway.GetUsersOnlineStatusReq) (*msggateway.GetUsersOnlineStatusResp, error) { - if !authverify.IsAppManagerUid(ctx, s.config.Share.IMAdminUserID) { + if !authverify.IsAdmin(ctx) { return nil, errs.ErrNoPermission.WrapMsg("only app manager") } var resp msggateway.GetUsersOnlineStatusResp diff --git a/internal/rpc/auth/auth.go b/internal/rpc/auth/auth.go index d34630b2f..2c2691d1d 100644 --- a/internal/rpc/auth/auth.go +++ b/internal/rpc/auth/auth.go @@ -18,11 +18,14 @@ import ( "context" "errors" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/mcache" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/database/mgo" + "github.com/openimsdk/open-im-server/v3/pkg/dbbuild" "github.com/openimsdk/open-im-server/v3/pkg/rpcli" "github.com/openimsdk/open-im-server/v3/pkg/common/config" redis2 "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/redis" - "github.com/openimsdk/tools/db/redisutil" "github.com/openimsdk/tools/utils/datautil" "github.com/redis/go-redis/v9" @@ -43,7 +46,7 @@ import ( type authServer struct { pbauth.UnimplementedAuthServer authDatabase controller.AuthDatabase - RegisterCenter discovery.SvcDiscoveryRegistry + RegisterCenter discovery.Conn config *Config userClient *rpcli.UserClient } @@ -51,15 +54,31 @@ type authServer struct { type Config struct { RpcConfig config.Auth RedisConfig config.Redis + MongoConfig config.Mongo Share config.Share Discovery config.Discovery } -func Start(ctx context.Context, config *Config, client discovery.SvcDiscoveryRegistry, server *grpc.Server) error { - rdb, err := redisutil.NewRedisClient(ctx, config.RedisConfig.Build()) +func Start(ctx context.Context, config *Config, client discovery.Conn, server grpc.ServiceRegistrar) error { + dbb := dbbuild.NewBuilder(&config.MongoConfig, &config.RedisConfig) + rdb, err := dbb.Redis(ctx) if err != nil { return err } + var token cache.TokenModel + if rdb == nil { + mdb, err := dbb.Mongo(ctx) + if err != nil { + return err + } + mc, err := mgo.NewCacheMgo(mdb.GetDB()) + if err != nil { + return err + } + token = mcache.NewTokenCacheModel(mc, config.RpcConfig.TokenPolicy.Expire) + } else { + token = redis2.NewTokenCacheModel(rdb, config.RpcConfig.TokenPolicy.Expire) + } userConn, err := client.GetConn(ctx, config.Discovery.RpcService.User) if err != nil { return err @@ -67,7 +86,7 @@ func Start(ctx context.Context, config *Config, client discovery.SvcDiscoveryReg pbauth.RegisterAuthServer(server, &authServer{ RegisterCenter: client, authDatabase: controller.NewAuthDatabase( - redis2.NewTokenCacheModel(rdb, config.RpcConfig.TokenPolicy.Expire), + token, config.Share.Secret, config.RpcConfig.TokenPolicy.Expire, config.Share.MultiLogin, @@ -106,7 +125,7 @@ func (s *authServer) GetAdminToken(ctx context.Context, req *pbauth.GetAdminToke } func (s *authServer) GetUserToken(ctx context.Context, req *pbauth.GetUserTokenReq) (*pbauth.GetUserTokenResp, error) { - if err := authverify.CheckAdmin(ctx, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAdmin(ctx); err != nil { return nil, err } @@ -116,7 +135,7 @@ func (s *authServer) GetUserToken(ctx context.Context, req *pbauth.GetUserTokenR resp := pbauth.GetUserTokenResp{} - if authverify.IsManagerUserID(req.UserID, s.config.Share.IMAdminUserID) { + if authverify.CheckUserIsAdmin(ctx, req.UserID) { return nil, errs.ErrNoPermission.WrapMsg("don't get Admin token") } user, err := s.userClient.GetUserInfo(ctx, req.UserID) @@ -145,7 +164,7 @@ func (s *authServer) parseToken(ctx context.Context, tokensString string) (claim return nil, err } if len(m) == 0 { - isAdmin := authverify.IsManagerUserID(claims.UserID, s.config.Share.IMAdminUserID) + isAdmin := authverify.CheckUserIsAdmin(ctx, claims.UserID) if isAdmin { if err = s.authDatabase.GetTemporaryTokensWithoutError(ctx, claims.UserID, claims.PlatformID, tokensString); err == nil { return claims, nil @@ -163,7 +182,7 @@ func (s *authServer) parseToken(ctx context.Context, tokensString string) (claim return nil, errs.Wrap(errs.ErrTokenUnknown) } } else { - isAdmin := authverify.IsManagerUserID(claims.UserID, s.config.Share.IMAdminUserID) + isAdmin := authverify.CheckUserIsAdmin(ctx, claims.UserID) if isAdmin { if err = s.authDatabase.GetTemporaryTokensWithoutError(ctx, claims.UserID, claims.PlatformID, tokensString); err == nil { return claims, nil @@ -186,7 +205,7 @@ func (s *authServer) ParseToken(ctx context.Context, req *pbauth.ParseTokenReq) } func (s *authServer) ForceLogout(ctx context.Context, req *pbauth.ForceLogoutReq) (*pbauth.ForceLogoutResp, error) { - if err := authverify.CheckAdmin(ctx, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAdmin(ctx); err != nil { return nil, err } if err := s.forceKickOff(ctx, req.UserID, req.PlatformID); err != nil { diff --git a/internal/rpc/conversation/sync.go b/internal/rpc/conversation/sync.go index ad88b2bbd..cee74b319 100644 --- a/internal/rpc/conversation/sync.go +++ b/internal/rpc/conversation/sync.go @@ -4,12 +4,16 @@ import ( "context" "github.com/openimsdk/open-im-server/v3/internal/rpc/incrversion" + "github.com/openimsdk/open-im-server/v3/pkg/authverify" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" "github.com/openimsdk/open-im-server/v3/pkg/util/hashutil" "github.com/openimsdk/protocol/conversation" ) func (c *conversationServer) GetFullOwnerConversationIDs(ctx context.Context, req *conversation.GetFullOwnerConversationIDsReq) (*conversation.GetFullOwnerConversationIDsResp, error) { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { + return nil, err + } vl, err := c.conversationDatabase.FindMaxConversationUserVersionCache(ctx, req.UserID) if err != nil { return nil, err diff --git a/internal/rpc/group/group.go b/internal/rpc/group/group.go index 3d8d35960..1ed3ce799 100644 --- a/internal/rpc/group/group.go +++ b/internal/rpc/group/group.go @@ -156,7 +156,7 @@ func (g *groupServer) NotificationUserInfoUpdate(ctx context.Context, req *pbgro } func (g *groupServer) CheckGroupAdmin(ctx context.Context, groupID string) error { - if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) { + if !authverify.IsAdmin(ctx) { groupMember, err := g.db.TakeGroupMember(ctx, groupID, mcontext.GetOpUserID(ctx)) if err != nil { return err @@ -208,7 +208,7 @@ func (g *groupServer) CreateGroup(ctx context.Context, req *pbgroup.CreateGroupR if req.OwnerUserID == "" { return nil, errs.ErrArgs.WrapMsg("no group owner") } - if err := authverify.CheckAccessV3(ctx, req.OwnerUserID, g.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil { return nil, err } userIDs := append(append(req.MemberUserIDs, req.AdminUserIDs...), req.OwnerUserID) @@ -311,7 +311,7 @@ func (g *groupServer) CreateGroup(ctx context.Context, req *pbgroup.CreateGroupR } func (g *groupServer) GetJoinedGroupList(ctx context.Context, req *pbgroup.GetJoinedGroupListReq) (*pbgroup.GetJoinedGroupListResp, error) { - if err := authverify.CheckAccessV3(ctx, req.FromUserID, g.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.FromUserID); err != nil { return nil, err } total, members, err := g.db.PageGetJoinGroup(ctx, req.FromUserID, req.Pagination) @@ -383,7 +383,7 @@ func (g *groupServer) InviteUserToGroup(ctx context.Context, req *pbgroup.Invite var groupMember *model.GroupMember var opUserID string - if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) { + if !authverify.IsAdmin(ctx) { opUserID = mcontext.GetOpUserID(ctx) var err error groupMember, err = g.db.TakeGroupMember(ctx, req.GroupID, opUserID) @@ -402,7 +402,7 @@ func (g *groupServer) InviteUserToGroup(ctx context.Context, req *pbgroup.Invite } if group.NeedVerification == constant.AllNeedVerification { - if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) { + if !authverify.IsAdmin(ctx) { if !(groupMember.RoleLevel == constant.GroupOwner || groupMember.RoleLevel == constant.GroupAdmin) { var requests []*model.GroupRequest for _, userID := range req.InvitedUserIDs { @@ -490,6 +490,11 @@ func (g *groupServer) GetGroupAllMember(ctx context.Context, req *pbgroup.GetGro } func (g *groupServer) GetGroupMemberList(ctx context.Context, req *pbgroup.GetGroupMemberListReq) (*pbgroup.GetGroupMemberListResp, error) { + if opUserID := mcontext.GetOpUserID(ctx); !datautil.Contain(opUserID, g.config.Share.IMAdminUserID...) { + if _, err := g.db.TakeGroupMember(ctx, req.GroupID, opUserID); err != nil { + return nil, err + } + } var ( total int64 members []*model.GroupMember @@ -498,7 +503,7 @@ func (g *groupServer) GetGroupMemberList(ctx context.Context, req *pbgroup.GetGr if req.Keyword == "" { total, members, err = g.db.PageGetGroupMember(ctx, req.GroupID, req.Pagination) } else { - members, err = g.db.FindGroupMemberAll(ctx, req.GroupID) + total, members, err = g.db.SearchGroupMember(ctx, req.GroupID, req.Keyword, req.Pagination) } if err != nil { return nil, err @@ -506,27 +511,6 @@ func (g *groupServer) GetGroupMemberList(ctx context.Context, req *pbgroup.GetGr if err := g.PopulateGroupMember(ctx, members...); err != nil { return nil, err } - if req.Keyword != "" { - groupMembers := make([]*model.GroupMember, 0) - for _, member := range members { - if member.UserID == req.Keyword { - groupMembers = append(groupMembers, member) - total++ - continue - } - if member.Nickname == req.Keyword { - groupMembers = append(groupMembers, member) - total++ - continue - } - } - - members := datautil.Paginate(groupMembers, int(req.Pagination.GetPageNumber()), int(req.Pagination.GetShowNumber())) - return &pbgroup.GetGroupMemberListResp{ - Total: uint32(total), - Members: datautil.Batch(convert.Db2PbGroupMember, members), - }, nil - } return &pbgroup.GetGroupMemberListResp{ Total: uint32(total), Members: datautil.Batch(convert.Db2PbGroupMember, members), @@ -567,7 +551,7 @@ func (g *groupServer) KickGroupMember(ctx context.Context, req *pbgroup.KickGrou for i, member := range members { memberMap[member.UserID] = members[i] } - isAppManagerUid := authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) + isAppManagerUid := authverify.IsAdmin(ctx) opMember := memberMap[opUserID] for _, userID := range req.KickedUserIDs { member, ok := memberMap[userID] @@ -785,7 +769,7 @@ func (g *groupServer) GroupApplicationResponse(ctx context.Context, req *pbgroup if !datautil.Contain(req.HandleResult, constant.GroupResponseAgree, constant.GroupResponseRefuse) { return nil, errs.ErrArgs.WrapMsg("HandleResult unknown") } - if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) { + if !authverify.IsAdmin(ctx) { groupMember, err := g.db.TakeGroupMember(ctx, req.GroupID, mcontext.GetOpUserID(ctx)) if err != nil { return nil, err @@ -933,7 +917,7 @@ func (g *groupServer) QuitGroup(ctx context.Context, req *pbgroup.QuitGroupReq) if req.UserID == "" { req.UserID = mcontext.GetOpUserID(ctx) } else { - if err := authverify.CheckAccessV3(ctx, req.UserID, g.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { return nil, err } } @@ -971,7 +955,7 @@ func (g *groupServer) deleteMemberAndSetConversationSeq(ctx context.Context, gro func (g *groupServer) SetGroupInfo(ctx context.Context, req *pbgroup.SetGroupInfoReq) (*pbgroup.SetGroupInfoResp, error) { var opMember *model.GroupMember - if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) { + if !authverify.IsAdmin(ctx) { var err error opMember, err = g.db.TakeGroupMember(ctx, req.GroupInfoForSet.GroupID, mcontext.GetOpUserID(ctx)) if err != nil { @@ -1064,7 +1048,7 @@ func (g *groupServer) SetGroupInfo(ctx context.Context, req *pbgroup.SetGroupInf func (g *groupServer) SetGroupInfoEx(ctx context.Context, req *pbgroup.SetGroupInfoExReq) (*pbgroup.SetGroupInfoExResp, error) { var opMember *model.GroupMember - if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) { + if !authverify.IsAdmin(ctx) { var err error opMember, err = g.db.TakeGroupMember(ctx, req.GroupID, mcontext.GetOpUserID(ctx)) @@ -1216,7 +1200,7 @@ func (g *groupServer) TransferGroupOwner(ctx context.Context, req *pbgroup.Trans return nil, errs.ErrArgs.WrapMsg("NewOwnerUser not in group " + req.NewOwnerUserID) } - if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) { + if !authverify.IsAdmin(ctx) { if !(mcontext.GetOpUserID(ctx) == oldOwner.UserID && oldOwner.RoleLevel == constant.GroupOwner) { return nil, errs.ErrNoPermission.WrapMsg("no permission transfer group owner") } @@ -1359,7 +1343,7 @@ func (g *groupServer) DismissGroup(ctx context.Context, req *pbgroup.DismissGrou if err != nil { return nil, err } - if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) { + if !authverify.IsAdmin(ctx) { if owner.UserID != mcontext.GetOpUserID(ctx) { return nil, errs.ErrNoPermission.WrapMsg("not group owner") } @@ -1416,7 +1400,7 @@ func (g *groupServer) MuteGroupMember(ctx context.Context, req *pbgroup.MuteGrou if err := g.PopulateGroupMember(ctx, member); err != nil { return nil, err } - if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) { + if !authverify.IsAdmin(ctx) { opMember, err := g.db.TakeGroupMember(ctx, req.GroupID, mcontext.GetOpUserID(ctx)) if err != nil { return nil, err @@ -1452,7 +1436,7 @@ func (g *groupServer) CancelMuteGroupMember(ctx context.Context, req *pbgroup.Ca return nil, err } - if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) { + if !authverify.IsAdmin(ctx) { opMember, err := g.db.TakeGroupMember(ctx, req.GroupID, mcontext.GetOpUserID(ctx)) if err != nil { return nil, err @@ -1512,7 +1496,7 @@ func (g *groupServer) SetGroupMemberInfo(ctx context.Context, req *pbgroup.SetGr if opUserID == "" { return nil, errs.ErrNoPermission.WrapMsg("no op user id") } - isAppManagerUid := authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) + isAppManagerUid := authverify.IsAdmin(ctx) groupMembers := make(map[string][]*pbgroup.SetGroupMemberInfo) for i, member := range req.Members { if member.RoleLevel != nil { diff --git a/internal/rpc/group/notification.go b/internal/rpc/group/notification.go index 77ec7dbdc..a6c715735 100644 --- a/internal/rpc/group/notification.go +++ b/internal/rpc/group/notification.go @@ -242,8 +242,8 @@ func (g *NotificationSender) fillOpUserByUserID(ctx context.Context, userID stri return errs.ErrInternalServer.WrapMsg("**sdkws.GroupMemberFullInfo is nil") } if groupID != "" { - if authverify.IsManagerUserID(userID, g.config.Share.IMAdminUserID) { - *opUser = &sdkws.GroupMemberFullInfo{ + if authverify.CheckUserIsAdmin(ctx, userID) { + *targetUser = &sdkws.GroupMemberFullInfo{ GroupID: groupID, UserID: userID, RoleLevel: constant.GroupAdmin, diff --git a/internal/rpc/group/sync.go b/internal/rpc/group/sync.go index 0592aa811..ed608dea3 100644 --- a/internal/rpc/group/sync.go +++ b/internal/rpc/group/sync.go @@ -12,15 +12,23 @@ import ( pbgroup "github.com/openimsdk/protocol/group" "github.com/openimsdk/protocol/sdkws" "github.com/openimsdk/tools/errs" - "github.com/openimsdk/tools/log" + "github.com/openimsdk/tools/mcontext" + "github.com/openimsdk/tools/utils/datautil" ) -func (s *groupServer) GetFullGroupMemberUserIDs(ctx context.Context, req *pbgroup.GetFullGroupMemberUserIDsReq) (*pbgroup.GetFullGroupMemberUserIDsResp, error) { - vl, err := s.db.FindMaxGroupMemberVersionCache(ctx, req.GroupID) +const versionSyncLimit = 500 + +func (g *groupServer) GetFullGroupMemberUserIDs(ctx context.Context, req *pbgroup.GetFullGroupMemberUserIDsReq) (*pbgroup.GetFullGroupMemberUserIDsResp, error) { + userIDs, err := g.db.FindGroupMemberUserID(ctx, req.GroupID) if err != nil { return nil, err } - userIDs, err := s.db.FindGroupMemberUserID(ctx, req.GroupID) + if opUserID := mcontext.GetOpUserID(ctx); !datautil.Contain(opUserID, g.config.Share.IMAdminUserID...) { + if !datautil.Contain(opUserID, userIDs...) { + return nil, errs.ErrNoPermission.WrapMsg("user not in group") + } + } + vl, err := g.db.FindMaxGroupMemberVersionCache(ctx, req.GroupID) if err != nil { return nil, err } @@ -36,8 +44,11 @@ func (s *groupServer) GetFullGroupMemberUserIDs(ctx context.Context, req *pbgrou }, nil } -func (s *groupServer) GetFullJoinGroupIDs(ctx context.Context, req *pbgroup.GetFullJoinGroupIDsReq) (*pbgroup.GetFullJoinGroupIDsResp, error) { - vl, err := s.db.FindMaxJoinGroupVersionCache(ctx, req.UserID) +func (g *groupServer) GetFullJoinGroupIDs(ctx context.Context, req *pbgroup.GetFullJoinGroupIDsReq) (*pbgroup.GetFullJoinGroupIDsResp, error) { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { + return nil, err + } + vl, err := g.db.FindMaxJoinGroupVersionCache(ctx, req.UserID) if err != nil { return nil, err } @@ -65,6 +76,9 @@ func (s *groupServer) GetIncrementalGroupMember(ctx context.Context, req *pbgrou if group.Status == constant.GroupStatusDismissed { return nil, servererrs.ErrDismissedAlready.Wrap() } + if _, err := g.db.TakeGroupMember(ctx, req.GroupID, mcontext.GetOpUserID(ctx)); err != nil { + return nil, err + } var ( hasGroupUpdate bool sortVersion uint64 @@ -132,152 +146,8 @@ func (s *groupServer) GetIncrementalGroupMember(ctx context.Context, req *pbgrou return resp, nil } -func (s *groupServer) BatchGetIncrementalGroupMember(ctx context.Context, req *pbgroup.BatchGetIncrementalGroupMemberReq) (resp *pbgroup.BatchGetIncrementalGroupMemberResp, err error) { - type VersionInfo struct { - GroupID string - VersionID string - VersionNumber uint64 - } - - var groupIDs []string - - groupsVersionMap := make(map[string]*VersionInfo) - groupsMap := make(map[string]*model.Group) - hasGroupUpdateMap := make(map[string]bool) - sortVersionMap := make(map[string]uint64) - - var targetKeys, versionIDs []string - var versionNumbers []uint64 - - var requestBodyLen int - - for _, group := range req.ReqList { - groupsVersionMap[group.GroupID] = &VersionInfo{ - GroupID: group.GroupID, - VersionID: group.VersionID, - VersionNumber: group.Version, - } - - groupIDs = append(groupIDs, group.GroupID) - } - - groups, err := s.db.FindGroup(ctx, groupIDs) - if err != nil { - return nil, errs.Wrap(err) - } - - for _, group := range groups { - if group.Status == constant.GroupStatusDismissed { - err = servererrs.ErrDismissedAlready.Wrap() - log.ZError(ctx, "This group is Dismissed Already", err, "group is", group.GroupID) - - delete(groupsVersionMap, group.GroupID) - } else { - groupsMap[group.GroupID] = group - } - } - - for groupID, vInfo := range groupsVersionMap { - targetKeys = append(targetKeys, groupID) - versionIDs = append(versionIDs, vInfo.VersionID) - versionNumbers = append(versionNumbers, vInfo.VersionNumber) - } - - opt := incrversion.BatchOption[[]*sdkws.GroupMemberFullInfo, pbgroup.BatchGetIncrementalGroupMemberResp]{ - Ctx: ctx, - TargetKeys: targetKeys, - VersionIDs: versionIDs, - VersionNumbers: versionNumbers, - Versions: func(ctx context.Context, groupIDs []string, versions []uint64, limits []int) (map[string]*model.VersionLog, error) { - vLogs, err := s.db.BatchFindMemberIncrVersion(ctx, groupIDs, versions, limits) - if err != nil { - return nil, errs.Wrap(err) - } - - for groupID, vlog := range vLogs { - vlogElems := make([]model.VersionLogElem, 0, len(vlog.Logs)) - for i, log := range vlog.Logs { - switch log.EID { - case model.VersionGroupChangeID: - vlog.LogLen-- - hasGroupUpdateMap[groupID] = true - case model.VersionSortChangeID: - vlog.LogLen-- - sortVersionMap[groupID] = uint64(log.Version) - default: - vlogElems = append(vlogElems, vlog.Logs[i]) - } - } - vlog.Logs = vlogElems - if vlog.LogLen > 0 { - hasGroupUpdateMap[groupID] = true - } - } - - return vLogs, nil - }, - CacheMaxVersions: s.db.BatchFindMaxGroupMemberVersionCache, - Find: func(ctx context.Context, groupID string, ids []string) ([]*sdkws.GroupMemberFullInfo, error) { - memberInfo, err := s.getGroupMembersInfo(ctx, groupID, ids) - if err != nil { - return nil, err - } - - return memberInfo, err - }, - Resp: func(versions map[string]*model.VersionLog, deleteIdsMap map[string][]string, insertListMap, updateListMap map[string][]*sdkws.GroupMemberFullInfo, fullMap map[string]bool) *pbgroup.BatchGetIncrementalGroupMemberResp { - resList := make(map[string]*pbgroup.GetIncrementalGroupMemberResp) - - for groupID, versionLog := range versions { - resList[groupID] = &pbgroup.GetIncrementalGroupMemberResp{ - VersionID: versionLog.ID.Hex(), - Version: uint64(versionLog.Version), - Full: fullMap[groupID], - Delete: deleteIdsMap[groupID], - Insert: insertListMap[groupID], - Update: updateListMap[groupID], - SortVersion: sortVersionMap[groupID], - } - - requestBodyLen += len(insertListMap[groupID]) + len(updateListMap[groupID]) + len(deleteIdsMap[groupID]) - if requestBodyLen > 200 { - break - } - } - - return &pbgroup.BatchGetIncrementalGroupMemberResp{ - RespList: resList, - } - }, - } - - resp, err = opt.Build() - if err != nil { - return nil, errs.Wrap(err) - } - - for groupID, val := range resp.RespList { - if val.Full || hasGroupUpdateMap[groupID] { - count, err := s.db.FindGroupMemberNum(ctx, groupID) - if err != nil { - return nil, err - } - - owner, err := s.db.TakeGroupOwner(ctx, groupID) - if err != nil { - return nil, err - } - - resp.RespList[groupID].Group = s.groupDB2PB(groupsMap[groupID], owner.UserID, count) - } - } - - return resp, nil - -} - -func (s *groupServer) GetIncrementalJoinGroup(ctx context.Context, req *pbgroup.GetIncrementalJoinGroupReq) (*pbgroup.GetIncrementalJoinGroupResp, error) { - if err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID); err != nil { +func (g *groupServer) GetIncrementalJoinGroup(ctx context.Context, req *pbgroup.GetIncrementalJoinGroupReq) (*pbgroup.GetIncrementalJoinGroupResp, error) { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { return nil, err } opt := incrversion.Option[*sdkws.GroupInfo, pbgroup.GetIncrementalJoinGroupResp]{ diff --git a/internal/rpc/msg/clear.go b/internal/rpc/msg/clear.go index 8e14b281e..96eb99aed 100644 --- a/internal/rpc/msg/clear.go +++ b/internal/rpc/msg/clear.go @@ -2,15 +2,16 @@ package msg import ( "context" + "strings" + "github.com/openimsdk/open-im-server/v3/pkg/authverify" "github.com/openimsdk/protocol/msg" "github.com/openimsdk/tools/log" - "strings" ) // DestructMsgs hard delete in Database. func (m *msgServer) DestructMsgs(ctx context.Context, req *msg.DestructMsgsReq) (*msg.DestructMsgsResp, error) { - if err := authverify.CheckAdmin(ctx, m.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAdmin(ctx); err != nil { return nil, err } docs, err := m.MsgDatabase.GetRandBeforeMsg(ctx, req.Timestamp, int(req.Limit)) diff --git a/internal/rpc/msg/delete.go b/internal/rpc/msg/delete.go index d3485faaa..4590523d5 100644 --- a/internal/rpc/msg/delete.go +++ b/internal/rpc/msg/delete.go @@ -42,7 +42,7 @@ func (m *msgServer) validateDeleteSyncOpt(opt *msg.DeleteSyncOpt) (isSyncSelf, i } func (m *msgServer) ClearConversationsMsg(ctx context.Context, req *msg.ClearConversationsMsgReq) (*msg.ClearConversationsMsgResp, error) { - if err := authverify.CheckAccessV3(ctx, req.UserID, m.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { return nil, err } if err := m.clearConversation(ctx, req.ConversationIDs, req.UserID, req.DeleteSyncOpt); err != nil { @@ -52,7 +52,7 @@ func (m *msgServer) ClearConversationsMsg(ctx context.Context, req *msg.ClearCon } func (m *msgServer) UserClearAllMsg(ctx context.Context, req *msg.UserClearAllMsgReq) (*msg.UserClearAllMsgResp, error) { - if err := authverify.CheckAccessV3(ctx, req.UserID, m.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { return nil, err } conversationIDs, err := m.ConversationLocalCache.GetConversationIDs(ctx, req.UserID) @@ -66,7 +66,7 @@ func (m *msgServer) UserClearAllMsg(ctx context.Context, req *msg.UserClearAllMs } func (m *msgServer) DeleteMsgs(ctx context.Context, req *msg.DeleteMsgsReq) (*msg.DeleteMsgsResp, error) { - if err := authverify.CheckAccessV3(ctx, req.UserID, m.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { return nil, err } isSyncSelf, isSyncOther := m.validateDeleteSyncOpt(req.DeleteSyncOpt) @@ -102,7 +102,7 @@ func (m *msgServer) DeleteMsgPhysicalBySeq(ctx context.Context, req *msg.DeleteM } func (m *msgServer) DeleteMsgPhysical(ctx context.Context, req *msg.DeleteMsgPhysicalReq) (*msg.DeleteMsgPhysicalResp, error) { - if err := authverify.CheckAdmin(ctx, m.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAdmin(ctx); err != nil { return nil, err } remainTime := timeutil.GetCurrentTimestampBySecond() - req.Timestamp diff --git a/internal/rpc/msg/revoke.go b/internal/rpc/msg/revoke.go index c2fb5833f..bd1d66ba1 100644 --- a/internal/rpc/msg/revoke.go +++ b/internal/rpc/msg/revoke.go @@ -42,7 +42,7 @@ func (m *msgServer) RevokeMsg(ctx context.Context, req *msg.RevokeMsgReq) (*msg. if req.Seq < 0 { return nil, errs.ErrArgs.WrapMsg("seq is invalid") } - if err := authverify.CheckAccessV3(ctx, req.UserID, m.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { return nil, err } user, err := m.UserLocalCache.GetUserInfo(ctx, req.UserID) @@ -63,11 +63,11 @@ func (m *msgServer) RevokeMsg(ctx context.Context, req *msg.RevokeMsgReq) (*msg. data, _ := json.Marshal(msgs[0]) log.ZDebug(ctx, "GetMsgBySeqs", "conversationID", req.ConversationID, "seq", req.Seq, "msg", string(data)) var role int32 - if !authverify.IsAppManagerUid(ctx, m.config.Share.IMAdminUserID) { + if !authverify.IsAdmin(ctx) { sessionType := msgs[0].SessionType switch sessionType { case constant.SingleChatType: - if err := authverify.CheckAccessV3(ctx, msgs[0].SendID, m.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, msgs[0].SendID); err != nil { return nil, err } role = user.AppMangerLevel diff --git a/internal/rpc/msg/sync_msg.go b/internal/rpc/msg/sync_msg.go index 6cf1c21d3..38eed93bc 100644 --- a/internal/rpc/msg/sync_msg.go +++ b/internal/rpc/msg/sync_msg.go @@ -118,7 +118,7 @@ func (m *msgServer) GetSeqMessage(ctx context.Context, req *msg.GetSeqMessageReq } func (m *msgServer) GetMaxSeq(ctx context.Context, req *sdkws.GetMaxSeqReq) (*sdkws.GetMaxSeqResp, error) { - if err := authverify.CheckAccessV3(ctx, req.UserID, m.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { return nil, err } conversationIDs, err := m.ConversationLocalCache.GetConversationIDs(ctx, req.UserID) diff --git a/internal/rpc/relation/black.go b/internal/rpc/relation/black.go index b795d6248..381a56273 100644 --- a/internal/rpc/relation/black.go +++ b/internal/rpc/relation/black.go @@ -30,10 +30,9 @@ import ( ) func (s *friendServer) GetPaginationBlacks(ctx context.Context, req *relation.GetPaginationBlacksReq) (resp *relation.GetPaginationBlacksResp, err error) { - if err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { return nil, err } - total, blacks, err := s.blackDatabase.FindOwnerBlacks(ctx, req.UserID, req.Pagination) if err != nil { return nil, err @@ -59,7 +58,7 @@ func (s *friendServer) IsBlack(ctx context.Context, req *relation.IsBlackReq) (* } func (s *friendServer) RemoveBlack(ctx context.Context, req *relation.RemoveBlackReq) (*relation.RemoveBlackResp, error) { - if err := authverify.CheckAccessV3(ctx, req.OwnerUserID, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil { return nil, err } @@ -74,7 +73,7 @@ func (s *friendServer) RemoveBlack(ctx context.Context, req *relation.RemoveBlac } func (s *friendServer) AddBlack(ctx context.Context, req *relation.AddBlackReq) (*relation.AddBlackResp, error) { - if err := authverify.CheckAccessV3(ctx, req.OwnerUserID, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil { return nil, err } @@ -100,7 +99,7 @@ func (s *friendServer) AddBlack(ctx context.Context, req *relation.AddBlackReq) } func (s *friendServer) GetSpecifiedBlacks(ctx context.Context, req *relation.GetSpecifiedBlacksReq) (*relation.GetSpecifiedBlacksResp, error) { - if err := authverify.CheckAccessV3(ctx, req.OwnerUserID, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil { return nil, err } diff --git a/internal/rpc/relation/friend.go b/internal/rpc/relation/friend.go index 8172b8681..50a8667ea 100644 --- a/internal/rpc/relation/friend.go +++ b/internal/rpc/relation/friend.go @@ -135,7 +135,7 @@ func Start(ctx context.Context, config *Config, client discovery.SvcDiscoveryReg // ok. func (s *friendServer) ApplyToAddFriend(ctx context.Context, req *relation.ApplyToAddFriendReq) (resp *relation.ApplyToAddFriendResp, err error) { resp = &relation.ApplyToAddFriendResp{} - if err := authverify.CheckAccessV3(ctx, req.FromUserID, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.FromUserID); err != nil { return nil, err } if req.ToUserID == req.FromUserID { @@ -165,7 +165,7 @@ func (s *friendServer) ApplyToAddFriend(ctx context.Context, req *relation.Apply // ok. func (s *friendServer) ImportFriends(ctx context.Context, req *relation.ImportFriendReq) (resp *relation.ImportFriendResp, err error) { - if err := authverify.CheckAdmin(ctx, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAdmin(ctx); err != nil { return nil, err } @@ -201,7 +201,7 @@ func (s *friendServer) ImportFriends(ctx context.Context, req *relation.ImportFr // ok. func (s *friendServer) RespondFriendApply(ctx context.Context, req *relation.RespondFriendApplyReq) (resp *relation.RespondFriendApplyResp, err error) { resp = &relation.RespondFriendApplyResp{} - if err := authverify.CheckAccessV3(ctx, req.ToUserID, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.ToUserID); err != nil { return nil, err } @@ -236,7 +236,7 @@ func (s *friendServer) RespondFriendApply(ctx context.Context, req *relation.Res // ok. func (s *friendServer) DeleteFriend(ctx context.Context, req *relation.DeleteFriendReq) (resp *relation.DeleteFriendResp, err error) { - if err := authverify.CheckAccessV3(ctx, req.OwnerUserID, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil { return nil, err } @@ -261,7 +261,7 @@ func (s *friendServer) SetFriendRemark(ctx context.Context, req *relation.SetFri return nil, err } - if err := authverify.CheckAccessV3(ctx, req.OwnerUserID, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil { return nil, err } @@ -331,7 +331,7 @@ func (s *friendServer) GetDesignatedFriendsApply(ctx context.Context, // Get received friend requests (i.e., those initiated by others). func (s *friendServer) GetPaginationFriendsApplyTo(ctx context.Context, req *relation.GetPaginationFriendsApplyToReq) (resp *relation.GetPaginationFriendsApplyToResp, err error) { - if err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { return nil, err } @@ -354,7 +354,7 @@ func (s *friendServer) GetPaginationFriendsApplyTo(ctx context.Context, req *rel func (s *friendServer) GetPaginationFriendsApplyFrom(ctx context.Context, req *relation.GetPaginationFriendsApplyFromReq) (resp *relation.GetPaginationFriendsApplyFromResp, err error) { resp = &relation.GetPaginationFriendsApplyFromResp{} - if err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { return nil, err } @@ -384,7 +384,7 @@ func (s *friendServer) IsFriend(ctx context.Context, req *relation.IsFriendReq) } func (s *friendServer) GetPaginationFriends(ctx context.Context, req *relation.GetPaginationFriendsReq) (resp *relation.GetPaginationFriendsResp, err error) { - if err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { return nil, err } @@ -405,7 +405,7 @@ func (s *friendServer) GetPaginationFriends(ctx context.Context, req *relation.G } func (s *friendServer) GetFriendIDs(ctx context.Context, req *relation.GetFriendIDsReq) (resp *relation.GetFriendIDsResp, err error) { - if err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { return nil, err } diff --git a/internal/rpc/relation/sync.go b/internal/rpc/relation/sync.go index 0ad94fe82..79fa0858c 100644 --- a/internal/rpc/relation/sync.go +++ b/internal/rpc/relation/sync.go @@ -2,10 +2,11 @@ package relation import ( "context" + "slices" + "github.com/openimsdk/open-im-server/v3/pkg/util/hashutil" "github.com/openimsdk/protocol/sdkws" "github.com/openimsdk/tools/log" - "slices" "github.com/openimsdk/open-im-server/v3/internal/rpc/incrversion" "github.com/openimsdk/open-im-server/v3/pkg/authverify" @@ -39,6 +40,9 @@ func (s *friendServer) NotificationUserInfoUpdate(ctx context.Context, req *rela } func (s *friendServer) GetFullFriendUserIDs(ctx context.Context, req *relation.GetFullFriendUserIDsReq) (*relation.GetFullFriendUserIDsResp, error) { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { + return nil, err + } vl, err := s.db.FindMaxFriendVersionCache(ctx, req.UserID) if err != nil { return nil, err @@ -60,7 +64,7 @@ func (s *friendServer) GetFullFriendUserIDs(ctx context.Context, req *relation.G } func (s *friendServer) GetIncrementalFriends(ctx context.Context, req *relation.GetIncrementalFriendsReq) (*relation.GetIncrementalFriendsResp, error) { - if err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { return nil, err } var sortVersion uint64 diff --git a/internal/rpc/third/log.go b/internal/rpc/third/log.go index 4d8cbc0bb..fba3ecb88 100644 --- a/internal/rpc/third/log.go +++ b/internal/rpc/third/log.go @@ -82,7 +82,7 @@ func (t *thirdServer) UploadLogs(ctx context.Context, req *third.UploadLogsReq) } func (t *thirdServer) DeleteLogs(ctx context.Context, req *third.DeleteLogsReq) (*third.DeleteLogsResp, error) { - if err := authverify.CheckAdmin(ctx, t.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAdmin(ctx); err != nil { return nil, err } userID := "" @@ -123,7 +123,7 @@ func dbToPbLogInfos(logs []*relationtb.Log) []*third.LogInfo { } func (t *thirdServer) SearchLogs(ctx context.Context, req *third.SearchLogsReq) (*third.SearchLogsResp, error) { - if err := authverify.CheckAdmin(ctx, t.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAdmin(ctx); err != nil { return nil, err } var ( diff --git a/internal/rpc/third/s3.go b/internal/rpc/third/s3.go index 757320dac..bfdd9f1b8 100644 --- a/internal/rpc/third/s3.go +++ b/internal/rpc/third/s3.go @@ -198,7 +198,7 @@ func (t *thirdServer) InitiateFormData(ctx context.Context, req *third.InitiateF var duration time.Duration opUserID := mcontext.GetOpUserID(ctx) var key string - if t.IsManagerUserID(opUserID) { + if authverify.CheckUserIsAdmin(ctx, opUserID) { if req.Millisecond <= 0 { duration = time.Minute * 10 } else { @@ -289,7 +289,7 @@ func (t *thirdServer) apiAddress(prefix, name string) string { } func (t *thirdServer) DeleteOutdatedData(ctx context.Context, req *third.DeleteOutdatedDataReq) (*third.DeleteOutdatedDataResp, error) { - if err := authverify.CheckAdmin(ctx, t.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAdmin(ctx); err != nil { return nil, err } engine := t.config.RpcConfig.Object.Enable diff --git a/internal/rpc/third/tool.go b/internal/rpc/third/tool.go index 4e22ffbf9..a063fa654 100644 --- a/internal/rpc/third/tool.go +++ b/internal/rpc/third/tool.go @@ -54,7 +54,7 @@ func (t *thirdServer) checkUploadName(ctx context.Context, name string) error { if opUserID == "" { return errs.ErrNoPermission.WrapMsg("opUserID is empty") } - if !authverify.IsManagerUserID(opUserID, t.config.Share.IMAdminUserID) { + if !authverify.CheckUserIsAdmin(ctx, opUserID) { if !strings.HasPrefix(name, opUserID+"/") { return errs.ErrNoPermission.WrapMsg(fmt.Sprintf("name must start with `%s/`", opUserID)) } @@ -79,10 +79,6 @@ func checkValidObjectName(objectName string) error { return checkValidObjectNamePrefix(objectName) } -func (t *thirdServer) IsManagerUserID(opUserID string) bool { - return authverify.IsManagerUserID(opUserID, t.config.Share.IMAdminUserID) -} - func putUpdate[T any](update map[string]any, name string, val interface{ GetValuePtr() *T }) { ptrVal := val.GetValuePtr() if ptrVal == nil { diff --git a/internal/rpc/user/config.go b/internal/rpc/user/config.go index 5a9a46359..f3f5a7a96 100644 --- a/internal/rpc/user/config.go +++ b/internal/rpc/user/config.go @@ -11,7 +11,7 @@ import ( func (s *userServer) GetUserClientConfig(ctx context.Context, req *pbuser.GetUserClientConfigReq) (*pbuser.GetUserClientConfigResp, error) { if req.UserID != "" { - if err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { return nil, err } if _, err := s.db.GetUserByID(ctx, req.UserID); err != nil { @@ -26,7 +26,7 @@ func (s *userServer) GetUserClientConfig(ctx context.Context, req *pbuser.GetUse } func (s *userServer) SetUserClientConfig(ctx context.Context, req *pbuser.SetUserClientConfigReq) (*pbuser.SetUserClientConfigResp, error) { - if err := authverify.CheckAdmin(ctx, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAdmin(ctx); err != nil { return nil, err } if req.UserID != "" { @@ -41,7 +41,7 @@ func (s *userServer) SetUserClientConfig(ctx context.Context, req *pbuser.SetUse } func (s *userServer) DelUserClientConfig(ctx context.Context, req *pbuser.DelUserClientConfigReq) (*pbuser.DelUserClientConfigResp, error) { - if err := authverify.CheckAdmin(ctx, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAdmin(ctx); err != nil { return nil, err } if err := s.clientConfig.DelUserConfig(ctx, req.UserID, req.Keys); err != nil { @@ -51,7 +51,7 @@ func (s *userServer) DelUserClientConfig(ctx context.Context, req *pbuser.DelUse } func (s *userServer) PageUserClientConfig(ctx context.Context, req *pbuser.PageUserClientConfigReq) (*pbuser.PageUserClientConfigResp, error) { - if err := authverify.CheckAdmin(ctx, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAdmin(ctx); err != nil { return nil, err } total, res, err := s.clientConfig.GetUserConfigPage(ctx, req.UserID, req.Key, req.Pagination) diff --git a/internal/rpc/user/user.go b/internal/rpc/user/user.go index 0e35aba6e..7f082f784 100644 --- a/internal/rpc/user/user.go +++ b/internal/rpc/user/user.go @@ -23,29 +23,27 @@ import ( "time" "github.com/openimsdk/open-im-server/v3/internal/rpc/relation" + "github.com/openimsdk/open-im-server/v3/pkg/authverify" "github.com/openimsdk/open-im-server/v3/pkg/common/config" + "github.com/openimsdk/open-im-server/v3/pkg/common/convert" "github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics" + "github.com/openimsdk/open-im-server/v3/pkg/common/servererrs" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/redis" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/controller" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/database/mgo" tablerelation "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" "github.com/openimsdk/open-im-server/v3/pkg/common/webhook" + "github.com/openimsdk/open-im-server/v3/pkg/dbbuild" "github.com/openimsdk/open-im-server/v3/pkg/localcache" "github.com/openimsdk/open-im-server/v3/pkg/rpcli" + "github.com/openimsdk/protocol/constant" "github.com/openimsdk/protocol/group" friendpb "github.com/openimsdk/protocol/relation" - "github.com/openimsdk/tools/db/redisutil" - - "github.com/openimsdk/open-im-server/v3/pkg/authverify" - "github.com/openimsdk/open-im-server/v3/pkg/common/convert" - "github.com/openimsdk/open-im-server/v3/pkg/common/servererrs" - "github.com/openimsdk/open-im-server/v3/pkg/common/storage/controller" - "github.com/openimsdk/protocol/constant" "github.com/openimsdk/protocol/sdkws" pbuser "github.com/openimsdk/protocol/user" - "github.com/openimsdk/tools/db/mongoutil" "github.com/openimsdk/tools/db/pagination" - registry "github.com/openimsdk/tools/discovery" + "github.com/openimsdk/tools/discovery" "github.com/openimsdk/tools/errs" "github.com/openimsdk/tools/utils/datautil" "google.golang.org/grpc" @@ -61,7 +59,7 @@ type userServer struct { db controller.UserDatabase friendNotificationSender *relation.FriendNotificationSender userNotificationSender *UserNotificationSender - RegisterCenter registry.SvcDiscoveryRegistry + RegisterCenter discovery.Conn config *Config webhookClient *webhook.Client groupClient *rpcli.GroupClient @@ -81,19 +79,21 @@ type Config struct { Discovery config.Discovery } -func Start(ctx context.Context, config *Config, client registry.SvcDiscoveryRegistry, server *grpc.Server) error { - mgocli, err := mongoutil.NewMongoDB(ctx, config.MongodbConfig.Build()) +func Start(ctx context.Context, config *Config, client discovery.Conn, server grpc.ServiceRegistrar) error { + dbb := dbbuild.NewBuilder(&config.MongodbConfig, &config.RedisConfig) + mgocli, err := dbb.Mongo(ctx) if err != nil { return err } - rdb, err := redisutil.NewRedisClient(ctx, config.RedisConfig.Build()) + rdb, err := dbb.Redis(ctx) if err != nil { return err } + users := make([]*tablerelation.User, 0) for _, v := range config.Share.IMAdminUserID { - users = append(users, &tablerelation.User{UserID: v, Nickname: v, AppMangerLevel: constant.AppNotificationAdmin}) + users = append(users, &tablerelation.User{UserID: v, Nickname: v, AppMangerLevel: constant.AppAdmin}) } userDB, err := mgo.NewUserMongo(mgocli.GetDB()) if err != nil { @@ -150,7 +150,7 @@ func (s *userServer) GetDesignateUsers(ctx context.Context, req *pbuser.GetDesig // UpdateUserInfo func (s *userServer) UpdateUserInfo(ctx context.Context, req *pbuser.UpdateUserInfoReq) (resp *pbuser.UpdateUserInfoResp, err error) { resp = &pbuser.UpdateUserInfoResp{} - err = authverify.CheckAccessV3(ctx, req.UserInfo.UserID, s.config.Share.IMAdminUserID) + err = authverify.CheckAccess(ctx, req.UserInfo.UserID) if err != nil { return nil, err } @@ -177,7 +177,7 @@ func (s *userServer) UpdateUserInfo(ctx context.Context, req *pbuser.UpdateUserI func (s *userServer) UpdateUserInfoEx(ctx context.Context, req *pbuser.UpdateUserInfoExReq) (resp *pbuser.UpdateUserInfoExResp, err error) { resp = &pbuser.UpdateUserInfoExResp{} - err = authverify.CheckAccessV3(ctx, req.UserInfo.UserID, s.config.Share.IMAdminUserID) + err = authverify.CheckAccess(ctx, req.UserInfo.UserID) if err != nil { return nil, err } @@ -235,8 +235,7 @@ func (s *userServer) AccountCheck(ctx context.Context, req *pbuser.AccountCheckR if datautil.Duplicate(req.CheckUserIDs) { return nil, errs.ErrArgs.WrapMsg("userID repeated") } - err = authverify.CheckAdmin(ctx, s.config.Share.IMAdminUserID) - if err != nil { + if err = authverify.CheckAdmin(ctx); err != nil { return nil, err } users, err := s.db.Find(ctx, req.CheckUserIDs) @@ -283,14 +282,12 @@ func (s *userServer) UserRegister(ctx context.Context, req *pbuser.UserRegisterR return nil, errs.ErrArgs.WrapMsg("users is empty") } // check if secret is changed - if s.config.Share.Secret == defaultSecret { - return nil, servererrs.ErrSecretNotChanged.Wrap() - } - - if err = authverify.CheckAdmin(ctx, s.config.Share.IMAdminUserID); err != nil { + //if s.config.Share.Secret == defaultSecret { + // return nil, servererrs.ErrSecretNotChanged.Wrap() + //} + if err = authverify.CheckAdmin(ctx); err != nil { return nil, err } - if datautil.DuplicateAny(req.Users, func(e *sdkws.UserInfo) string { return e.UserID }) { return nil, errs.ErrArgs.WrapMsg("userID repeated") } @@ -356,7 +353,7 @@ func (s *userServer) GetAllUserID(ctx context.Context, req *pbuser.GetAllUserIDR // ProcessUserCommandAdd user general function add. func (s *userServer) ProcessUserCommandAdd(ctx context.Context, req *pbuser.ProcessUserCommandAddReq) (*pbuser.ProcessUserCommandAddResp, error) { - err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID) + err := authverify.CheckAccess(ctx, req.UserID) if err != nil { return nil, err } @@ -384,7 +381,7 @@ func (s *userServer) ProcessUserCommandAdd(ctx context.Context, req *pbuser.Proc // ProcessUserCommandDelete user general function delete. func (s *userServer) ProcessUserCommandDelete(ctx context.Context, req *pbuser.ProcessUserCommandDeleteReq) (*pbuser.ProcessUserCommandDeleteResp, error) { - err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID) + err := authverify.CheckAccess(ctx, req.UserID) if err != nil { return nil, err } @@ -403,7 +400,7 @@ func (s *userServer) ProcessUserCommandDelete(ctx context.Context, req *pbuser.P // ProcessUserCommandUpdate user general function update. func (s *userServer) ProcessUserCommandUpdate(ctx context.Context, req *pbuser.ProcessUserCommandUpdateReq) (*pbuser.ProcessUserCommandUpdateResp, error) { - err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID) + err := authverify.CheckAccess(ctx, req.UserID) if err != nil { return nil, err } @@ -432,7 +429,7 @@ func (s *userServer) ProcessUserCommandUpdate(ctx context.Context, req *pbuser.P func (s *userServer) ProcessUserCommandGet(ctx context.Context, req *pbuser.ProcessUserCommandGetReq) (*pbuser.ProcessUserCommandGetResp, error) { - err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID) + err := authverify.CheckAccess(ctx, req.UserID) if err != nil { return nil, err } @@ -461,7 +458,7 @@ func (s *userServer) ProcessUserCommandGet(ctx context.Context, req *pbuser.Proc } func (s *userServer) ProcessUserCommandGetAll(ctx context.Context, req *pbuser.ProcessUserCommandGetAllReq) (*pbuser.ProcessUserCommandGetAllResp, error) { - err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID) + err := authverify.CheckAccess(ctx, req.UserID) if err != nil { return nil, err } @@ -490,7 +487,7 @@ func (s *userServer) ProcessUserCommandGetAll(ctx context.Context, req *pbuser.P } func (s *userServer) AddNotificationAccount(ctx context.Context, req *pbuser.AddNotificationAccountReq) (*pbuser.AddNotificationAccountResp, error) { - if err := authverify.CheckAdmin(ctx, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAdmin(ctx); err != nil { return nil, err } if req.AppMangerLevel < constant.AppNotificationAdmin { @@ -536,7 +533,7 @@ func (s *userServer) AddNotificationAccount(ctx context.Context, req *pbuser.Add } func (s *userServer) UpdateNotificationAccountInfo(ctx context.Context, req *pbuser.UpdateNotificationAccountInfoReq) (*pbuser.UpdateNotificationAccountInfoResp, error) { - if err := authverify.CheckAdmin(ctx, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAdmin(ctx); err != nil { return nil, err } @@ -563,7 +560,7 @@ func (s *userServer) UpdateNotificationAccountInfo(ctx context.Context, req *pbu func (s *userServer) SearchNotificationAccount(ctx context.Context, req *pbuser.SearchNotificationAccountReq) (*pbuser.SearchNotificationAccountResp, error) { // Check if user is an admin - if err := authverify.CheckAdmin(ctx, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAdmin(ctx); err != nil { return nil, err } @@ -618,7 +615,7 @@ func (s *userServer) GetNotificationAccount(ctx context.Context, req *pbuser.Get if err != nil { return nil, servererrs.ErrUserIDNotFound.Wrap() } - if user.AppMangerLevel == constant.AppAdmin || user.AppMangerLevel >= constant.AppNotificationAdmin { + if user.AppMangerLevel >= constant.AppAdmin { return &pbuser.GetNotificationAccountResp{Account: &pbuser.NotificationAccountInfo{ UserID: user.UserID, FaceURL: user.FaceURL, diff --git a/pkg/authverify/token.go b/pkg/authverify/token.go index 872feb1cf..75fb1448b 100644 --- a/pkg/authverify/token.go +++ b/pkg/authverify/token.go @@ -31,32 +31,49 @@ func Secret(secret string) jwt.Keyfunc { } } -func CheckAccessV3(ctx context.Context, ownerUserID string, imAdminUserID []string) (err error) { - opUserID := mcontext.GetOpUserID(ctx) - if datautil.Contain(opUserID, imAdminUserID...) { - return nil - } - if opUserID == ownerUserID { - return nil - } - return servererrs.ErrNoPermission.WrapMsg("ownerUserID", ownerUserID) -} - -func IsAppManagerUid(ctx context.Context, imAdminUserID []string) bool { - return datautil.Contain(mcontext.GetOpUserID(ctx), imAdminUserID...) -} - -func CheckAdmin(ctx context.Context, imAdminUserID []string) error { - if datautil.Contain(mcontext.GetOpUserID(ctx), imAdminUserID...) { +func CheckAdmin(ctx context.Context) error { + if IsAdmin(ctx) { return nil } return servererrs.ErrNoPermission.WrapMsg(fmt.Sprintf("user %s is not admin userID", mcontext.GetOpUserID(ctx))) } -func IsManagerUserID(opUserID string, imAdminUserID []string) bool { - return datautil.Contain(opUserID, imAdminUserID...) +//func IsManagerUserID(opUserID string, imAdminUserID []string) bool { +// return datautil.Contain(opUserID, imAdminUserID...) +//} + +func CheckUserIsAdmin(ctx context.Context, userID string) bool { + return datautil.Contain(userID, GetIMAdminUserIDs(ctx)...) } func CheckSystemAccount(ctx context.Context, level int32) bool { return level >= constant.AppAdmin } + +const ( + CtxIsAdminKey = "CtxIsAdminKey" +) + +func WithIMAdminUserIDs(ctx context.Context, imAdminUserID []string) context.Context { + return context.WithValue(ctx, CtxIsAdminKey, imAdminUserID) +} + +func GetIMAdminUserIDs(ctx context.Context) []string { + imAdminUserID, _ := ctx.Value(CtxIsAdminKey).([]string) + return imAdminUserID +} + +func IsAdmin(ctx context.Context) bool { + return datautil.Contain(mcontext.GetOpUserID(ctx), GetIMAdminUserIDs(ctx)...) +} + +func CheckAccess(ctx context.Context, ownerUserID string) error { + opUserID := mcontext.GetOpUserID(ctx) + if opUserID == ownerUserID { + return nil + } + if datautil.Contain(mcontext.GetOpUserID(ctx), GetIMAdminUserIDs(ctx)...) { + return nil + } + return servererrs.ErrNoPermission.WrapMsg("ownerUserID", ownerUserID) +} diff --git a/pkg/common/startrpc/mw.go b/pkg/common/startrpc/mw.go new file mode 100644 index 000000000..c6cd55380 --- /dev/null +++ b/pkg/common/startrpc/mw.go @@ -0,0 +1,15 @@ +package startrpc + +import ( + "context" + + "github.com/openimsdk/open-im-server/v3/pkg/authverify" + "google.golang.org/grpc" +) + +func grpcServerIMAdminUserID(imAdminUserID []string) grpc.ServerOption { + return grpc.ChainUnaryInterceptor(func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { + ctx = authverify.WithIMAdminUserIDs(ctx, imAdminUserID) + return handler(ctx, req) + }) +} diff --git a/pkg/common/startrpc/start.go b/pkg/common/startrpc/start.go index 70ea885f7..03621343b 100644 --- a/pkg/common/startrpc/start.go +++ b/pkg/common/startrpc/start.go @@ -19,7 +19,6 @@ import ( "errors" "fmt" "net" - "net/http" "os" "os/signal" "reflect" @@ -28,21 +27,18 @@ import ( "time" conf "github.com/openimsdk/open-im-server/v3/pkg/common/config" - disetcd "github.com/openimsdk/open-im-server/v3/pkg/common/discovery/etcd" - "github.com/openimsdk/tools/discovery/etcd" "github.com/openimsdk/tools/utils/datautil" "github.com/openimsdk/tools/utils/jsonutil" + "github.com/openimsdk/tools/utils/network" "google.golang.org/grpc/status" - "github.com/openimsdk/tools/utils/runtimeenv" - kdisc "github.com/openimsdk/open-im-server/v3/pkg/common/discovery" "github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics" "github.com/openimsdk/tools/discovery" "github.com/openimsdk/tools/errs" "github.com/openimsdk/tools/log" - "github.com/openimsdk/tools/mw" - "github.com/openimsdk/tools/utils/network" + grpccli "github.com/openimsdk/tools/mw/grpc/client" + grpcsrv "github.com/openimsdk/tools/mw/grpc/server" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" ) @@ -81,31 +77,59 @@ func getConfigRpcMaxRequestBody(value reflect.Value) *conf.MaxRequestBody { return nil } +func getConfigShare(value reflect.Value) *conf.Share { + for value.Kind() == reflect.Pointer { + value = value.Elem() + } + if value.Kind() == reflect.Struct { + num := value.NumField() + for i := 0; i < num; i++ { + field := value.Field(i) + if !field.CanInterface() { + continue + } + for field.Kind() == reflect.Pointer { + field = field.Elem() + } + switch elem := field.Interface().(type) { + case conf.Share: + return &elem + } + if field.Kind() == reflect.Struct { + if elem := getConfigShare(field); elem != nil { + return elem + } + } + } + } + return nil +} + func Start[T any](ctx context.Context, disc *conf.Discovery, prometheusConfig *conf.Prometheus, listenIP, registerIP string, autoSetPorts bool, rpcPorts []int, index int, rpcRegisterName string, notification *conf.Notification, config T, watchConfigNames []string, watchServiceNames []string, - rpcFn func(ctx context.Context, config T, client discovery.SvcDiscoveryRegistry, server *grpc.Server) error, + rpcFn func(ctx context.Context, config T, client discovery.Conn, server grpc.ServiceRegistrar) error, options ...grpc.ServerOption) error { - watchConfigNames = append(watchConfigNames, conf.LogConfigFileName) - var ( - rpcTcpAddr string - netDone = make(chan struct{}, 2) - netErr error - prometheusPort int - ) - if notification != nil { conf.InitNotification(notification) } maxRequestBody := getConfigRpcMaxRequestBody(reflect.ValueOf(config)) + shareConfig := getConfigShare(reflect.ValueOf(config)) log.ZDebug(ctx, "rpc start", "rpcMaxRequestBody", maxRequestBody, "rpcRegisterName", rpcRegisterName, "registerIP", registerIP, "listenIP", listenIP) options = append(options, - mw.GrpcServer(), + grpcsrv.GrpcServerMetadataContext(), + grpcsrv.GrpcServerLogger(), + grpcsrv.GrpcServerErrorConvert(), + grpcsrv.GrpcServerRequestValidate(), + grpcsrv.GrpcServerPanicCapture(), ) + if shareConfig != nil && len(shareConfig.IMAdminUserID) > 0 { + options = append(options, grpcServerIMAdminUserID(shareConfig.IMAdminUserID)) + } var clientOptions []grpc.DialOption if maxRequestBody != nil { if maxRequestBody.RequestMaxBodySize > 0 { @@ -122,41 +146,32 @@ func Start[T any](ctx context.Context, disc *conf.Discovery, prometheusConfig *c if err != nil { return err } - - runTimeEnv := runtimeenv.RuntimeEnvironment() - - if !autoSetPorts { - rpcPort, err := datautil.GetElemByIndex(rpcPorts, index) + var prometheusListenAddr string + if autoSetPorts { + prometheusListenAddr = net.JoinHostPort(listenIP, "0") + } else { + prometheusPort, err := datautil.GetElemByIndex(prometheusConfig.Ports, index) if err != nil { return err } - rpcTcpAddr = net.JoinHostPort(network.GetListenIP(listenIP), strconv.Itoa(rpcPort)) - } else { - rpcTcpAddr = net.JoinHostPort(network.GetListenIP(listenIP), "0") + prometheusListenAddr = net.JoinHostPort(listenIP, strconv.Itoa(prometheusPort)) } - getAutoPort := func() (net.Listener, int, error) { - listener, err := net.Listen("tcp", rpcTcpAddr) - if err != nil { - return nil, 0, errs.WrapMsg(err, "listen err", "rpcTcpAddr", rpcTcpAddr) - } - _, portStr, _ := net.SplitHostPort(listener.Addr().String()) - port, _ := strconv.Atoi(portStr) - return listener, port, nil - } + watchConfigNames = append(watchConfigNames, conf.LogConfigFileName) - if autoSetPorts && discovery.Enable != conf.ETCD { - return errs.New("only etcd support autoSetPorts", "rpcRegisterName", rpcRegisterName).Wrap() - } - client, err := kdisc.NewDiscoveryRegister(discovery, runTimeEnv, watchServiceNames) + client, err := kdisc.NewDiscoveryRegister(disc, watchServiceNames) if err != nil { return err } defer client.Close() client.AddOption( - mw.GrpcClient(), grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"LoadBalancingPolicy": "%s"}`, "round_robin")), + + grpccli.GrpcClientLogger(), + grpccli.GrpcClientContext(), + grpccli.GrpcClientErrorConvert(), ) if len(clientOptions) > 0 { client.AddOption(clientOptions...) @@ -178,122 +193,111 @@ func Start[T any](ctx context.Context, disc *conf.Discovery, prometheusConfig *c if prometheusListenAddr != "" { options = append( - options, mw.GrpcServer(), + options, prommetricsUnaryInterceptor(rpcRegisterName), prommetricsStreamInterceptor(rpcRegisterName), ) - - var ( - listener net.Listener - ) - - if autoSetPorts { - listener, prometheusPort, err = getAutoPort() - if err != nil { - return err - } - - etcdClient := client.(*etcd.SvcDiscoveryRegistryImpl).GetClient() - - _, err = etcdClient.Put(ctx, prommetrics.BuildDiscoveryKey(rpcRegisterName), jsonutil.StructToJsonString(prommetrics.BuildDefaultTarget(registerIP, prometheusPort))) - if err != nil { - return errs.WrapMsg(err, "etcd put err") - } - } else { - prometheusPort, err = datautil.GetElemByIndex(prometheusConfig.Ports, index) - if err != nil { - return err - } - listener, err = net.Listen("tcp", fmt.Sprintf(":%d", prometheusPort)) - if err != nil { - return errs.WrapMsg(err, "listen err", "rpcTcpAddr", rpcTcpAddr) - } - } - cs := prommetrics.GetGrpcCusMetrics(rpcRegisterName, discovery) - go func() { - if err := prommetrics.RpcInit(cs, listener); err != nil && !errors.Is(err, http.ErrServerClosed) { - netErr = errs.WrapMsg(err, fmt.Sprintf("rpc %s prometheus start err: %d", rpcRegisterName, prometheusPort)) - netDone <- struct{}{} - } - //metric.InitializeMetrics(srv) - // Create a HTTP server for prometheus. - // httpServer = &http.Server{Handler: promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), Addr: fmt.Sprintf("0.0.0.0:%d", prometheusPort)} - // if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { - // netErr = errs.WrapMsg(err, "prometheus start err", httpServer.Addr) - // netDone <- struct{}{} - // } - }() - } else { - options = append(options, mw.GrpcServer()) - } - - listener, port, err := getAutoPort() - if err != nil { - return err - } - - log.CInfo(ctx, "RPC server is initializing", "rpcRegisterName", rpcRegisterName, "rpcPort", port, - "prometheusPort", prometheusPort) - - defer listener.Close() - srv := grpc.NewServer(options...) - - err = rpcFn(ctx, config, client, srv) - if err != nil { - return err - } - - err = client.Register( - ctx, - rpcRegisterName, - registerIP, - port, - grpc.WithTransportCredentials(insecure.NewCredentials()), - ) - if err != nil { - return err - } - - go func() { - err := srv.Serve(listener) - if err != nil && !errors.Is(err, http.ErrServerClosed) { - netErr = errs.WrapMsg(err, "rpc start err: ", rpcTcpAddr) - netDone <- struct{}{} - } - }() - - if discovery.Enable == conf.ETCD { - cm := disetcd.NewConfigManager(client.(*etcd.SvcDiscoveryRegistryImpl).GetClient(), watchConfigNames) - cm.Watch(ctx) - } - - sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGTERM) - select { - case <-sigs: - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - if err := gracefulStopWithCtx(ctx, srv.GracefulStop); err != nil { + prometheusListener, prometheusPort, err := listenTCP(prometheusListenAddr) + if err != nil { return err } - return nil - case <-netDone: - return netErr + log.ZDebug(ctx, "prometheus start", "addr", prometheusListener.Addr(), "rpcRegisterName", rpcRegisterName) + target, err := jsonutil.JsonMarshal(prommetrics.BuildDefaultTarget(registerIP, prometheusPort)) + if err != nil { + return err + } + if err := client.SetKey(ctx, prommetrics.BuildDiscoveryKey(prommetrics.APIKeyName), target); err != nil { + if !errors.Is(err, discovery.ErrNotSupportedKeyValue) { + return err + } + } + go func() { + err := prommetrics.Start(prometheusListener) + if err == nil { + err = fmt.Errorf("listener done") + } + cancel(fmt.Errorf("prommetrics %s %w", rpcRegisterName, err)) + }() } + + var ( + rpcServer *grpc.Server + rpcGracefulStop chan struct{} + ) + + onGrpcServiceRegistrar := func(desc *grpc.ServiceDesc, impl any) { + if rpcServer != nil { + rpcServer.RegisterService(desc, impl) + return + } + var rpcListenAddr string + if autoSetPorts { + rpcListenAddr = net.JoinHostPort(listenIP, "0") + } else { + rpcPort, err := datautil.GetElemByIndex(rpcPorts, index) + if err != nil { + cancel(fmt.Errorf("rpcPorts index out of range %s %w", rpcRegisterName, err)) + return + } + rpcListenAddr = net.JoinHostPort(listenIP, strconv.Itoa(rpcPort)) + } + rpcListener, err := net.Listen("tcp", rpcListenAddr) + if err != nil { + cancel(fmt.Errorf("listen rpc %s %s %w", rpcRegisterName, rpcListenAddr, err)) + return + } + + rpcServer = grpc.NewServer(options...) + rpcServer.RegisterService(desc, impl) + rpcGracefulStop = make(chan struct{}) + rpcPort := rpcListener.Addr().(*net.TCPAddr).Port + log.ZDebug(ctx, "rpc start register", "rpcRegisterName", rpcRegisterName, "registerIP", registerIP, "rpcPort", rpcPort) + grpcOpt := grpc.WithTransportCredentials(insecure.NewCredentials()) + rpcGracefulStop = make(chan struct{}) + go func() { + <-ctx.Done() + rpcServer.GracefulStop() + close(rpcGracefulStop) + }() + if err := client.Register(ctx, rpcRegisterName, registerIP, rpcListener.Addr().(*net.TCPAddr).Port, grpcOpt); err != nil { + cancel(fmt.Errorf("rpc register %s %w", rpcRegisterName, err)) + return + } + + go func() { + err := rpcServer.Serve(rpcListener) + if err == nil { + err = fmt.Errorf("serve end") + } + cancel(fmt.Errorf("rpc %s %w", rpcRegisterName, err)) + }() + } + + err = rpcFn(ctx, config, client, &grpcServiceRegistrar{onRegisterService: onGrpcServiceRegistrar}) + if err != nil { + return err + } + <-ctx.Done() + log.ZDebug(ctx, "cmd wait done", "err", context.Cause(ctx)) + if rpcGracefulStop != nil { + timeout := time.NewTimer(time.Second * 15) + defer timeout.Stop() + select { + case <-timeout.C: + log.ZWarn(ctx, "rcp graceful stop timeout", nil) + case <-rpcGracefulStop: + log.ZDebug(ctx, "rcp graceful stop done") + } + } + return context.Cause(ctx) } -func gracefulStopWithCtx(ctx context.Context, f func()) error { - done := make(chan struct{}, 1) - go func() { - f() - close(done) - }() - select { - case <-ctx.Done(): - return errs.New("timeout, ctx graceful stop") - case <-done: - return nil +func listenTCP(addr string) (net.Listener, int, error) { + listener, err := net.Listen("tcp", addr) + if err != nil { + return nil, 0, errs.WrapMsg(err, "listen err", "addr", addr) } + return listener, listener.Addr().(*net.TCPAddr).Port, nil } func prommetricsUnaryInterceptor(rpcRegisterName string) grpc.ServerOption { @@ -317,3 +321,11 @@ func prommetricsUnaryInterceptor(rpcRegisterName string) grpc.ServerOption { func prommetricsStreamInterceptor(rpcRegisterName string) grpc.ServerOption { return grpc.ChainStreamInterceptor() } + +type grpcServiceRegistrar struct { + onRegisterService func(desc *grpc.ServiceDesc, impl any) +} + +func (x *grpcServiceRegistrar) RegisterService(desc *grpc.ServiceDesc, impl any) { + x.onRegisterService(desc, impl) +}