diff --git a/internal/rpc/group/group.go b/internal/rpc/group/group.go index 540c3f20d..9ef0b1667 100644 --- a/internal/rpc/group/group.go +++ b/internal/rpc/group/group.go @@ -12,6 +12,7 @@ import ( "Open_IM/pkg/common/middleware" promePkg "Open_IM/pkg/common/prometheus" "Open_IM/pkg/common/token_verify" + "Open_IM/pkg/common/tools" "Open_IM/pkg/common/trace_log" cp "Open_IM/pkg/common/utils" "Open_IM/pkg/getcdv3" @@ -115,20 +116,9 @@ func (s *groupServer) Run() { log.NewInfo("", "group rpc success") } -func OperationID(ctx context.Context) string { - s, _ := ctx.Value("operationID").(string) - return s -} - -func OpUserID(ctx context.Context) string { - s, _ := ctx.Value("opUserID").(string) - return s -} - func (s *groupServer) CreateGroup(ctx context.Context, req *pbGroup.CreateGroupReq) (*pbGroup.CreateGroupResp, error) { - resp := &pbGroup.CreateGroupResp{GroupInfo: &open_im_sdk.GroupInfo{}} - if err := token_verify.CheckAccessV2(ctx, req.OpUserID, req.OwnerUserID); err != nil { + if err := token_verify.CheckAccessV3(ctx, req.OwnerUserID); err != nil { return nil, err } var groupOwnerNum int @@ -236,7 +226,7 @@ func (s *groupServer) CreateGroup(ctx context.Context, req *pbGroup.CreateGroupR func (s *groupServer) GetJoinedGroupList(ctx context.Context, req *pbGroup.GetJoinedGroupListReq) (*pbGroup.GetJoinedGroupListResp, error) { resp := &pbGroup.GetJoinedGroupListResp{} - if err := token_verify.CheckAccessV2(ctx, req.OpUserID, req.FromUserID); err != nil { + if err := token_verify.CheckAccessV3(ctx, req.FromUserID); err != nil { return nil, err } joinedGroupList, err := rocksCache.GetJoinedGroupIDListFromCache(ctx, req.FromUserID) @@ -281,10 +271,11 @@ func (s *groupServer) GetJoinedGroupList(ctx context.Context, req *pbGroup.GetJo func (s *groupServer) InviteUserToGroup(ctx context.Context, req *pbGroup.InviteUserToGroupReq) (*pbGroup.InviteUserToGroupResp, error) { resp := &pbGroup.InviteUserToGroupResp{} - - if !imdb.IsExistGroupMember(req.GroupID, req.OpUserID) && !token_verify.IsManagerUserID(req.OpUserID) { - constant.SetErrorForResp(constant.ErrIdentity, resp.CommonResp) - return nil, utils.Wrap(constant.ErrIdentity, "") + opUserID := tools.OpUserID(ctx) + if err := token_verify.CheckManagerUserID(ctx, opUserID); err != nil { + if err := imdb.CheckIsExistGroupMember(ctx, req.GroupID, opUserID); err != nil { + return nil, err + } } groupInfo, err := (*imdb.Group)(nil).Take(ctx, req.GroupID) if err != nil { @@ -741,7 +732,7 @@ func (s *groupServer) GroupApplicationResponse(ctx context.Context, req *pbGroup chat.GroupApplicationRejectedNotification(req) } else { //return nil, utils.Wrap(constant.ErrArgs, "") - return nil, constant.ErrArgs.Warp() + return nil, constant.ErrArgs.Wrap() } return resp, nil } diff --git a/pkg/common/constant/err_info.go b/pkg/common/constant/err_info.go index 7c9263665..1e030e659 100644 --- a/pkg/common/constant/err_info.go +++ b/pkg/common/constant/err_info.go @@ -7,6 +7,7 @@ import ( "fmt" "github.com/pkg/errors" "gorm.io/gorm" + "strings" ) type ErrInfo struct { @@ -23,20 +24,8 @@ func (e *ErrInfo) Code() int32 { return e.ErrCode } -func (e *ErrInfo) Msg(msg string) *ErrInfo { - return &ErrInfo{ - ErrCode: e.ErrCode, - ErrMsg: msg, - DetailErrMsg: e.DetailErrMsg, - } -} - -func (e *ErrInfo) Warp() error { - return errors.WithStack(e) -} - -func (e *ErrInfo) WarpMessage(msg string) error { - return errors.WithMessage(e, "") +func (e *ErrInfo) Wrap(msg ...string) error { + return errors.Wrap(e, strings.Join(msg, "--")) } func NewErrNetwork(err error) error { diff --git a/pkg/common/db/mysql_model/im_mysql_model/group_member_model_k.go b/pkg/common/db/mysql_model/im_mysql_model/group_member_model_k.go index 03ef2b741..b695119ad 100644 --- a/pkg/common/db/mysql_model/im_mysql_model/group_member_model_k.go +++ b/pkg/common/db/mysql_model/im_mysql_model/group_member_model_k.go @@ -211,6 +211,18 @@ func IsExistGroupMember(groupID, userID string) bool { return true } +func CheckIsExistGroupMember(ctx context.Context, groupID, userID string) error { + var number int64 + err := GroupMemberDB.Table("group_members").Where("group_id = ? and user_id = ?", groupID, userID).Count(&number).Error + if err != nil { + return constant.ErrDB.Wrap() + } + if number != 1 { + return constant.ErrData.Wrap() + } + return nil +} + func GetGroupMemberByGroupID(groupID string, filter int32, begin int32, maxNumber int32) ([]GroupMember, error) { var memberList []GroupMember var err error diff --git a/pkg/common/token_verify/jwt_token.go b/pkg/common/token_verify/jwt_token.go index 47b48a0a9..438a972ca 100644 --- a/pkg/common/token_verify/jwt_token.go +++ b/pkg/common/token_verify/jwt_token.go @@ -5,6 +5,7 @@ import ( "Open_IM/pkg/common/constant" commonDB "Open_IM/pkg/common/db" "Open_IM/pkg/common/log" + "Open_IM/pkg/common/tools" "Open_IM/pkg/common/trace_log" "Open_IM/pkg/utils" "context" @@ -142,6 +143,13 @@ func IsManagerUserID(OpUserID string) bool { } } +func CheckManagerUserID(ctx context.Context, userID string) error { + if utils.IsContain(userID, config.Config.Manager.AppManagerUid) { + return nil + } + return constant.ErrNoPermission.Wrap() +} + func CheckAccess(ctx context.Context, OpUserID string, OwnerUserID string) bool { if utils.IsContain(OpUserID, config.Config.Manager.AppManagerUid) { return true @@ -165,6 +173,20 @@ func CheckAccessV2(ctx context.Context, OpUserID string, OwnerUserID string) (er return utils.Wrap(constant.ErrIdentity, open_utils.GetSelfFuncName()) } +func CheckAccessV3(ctx context.Context, OwnerUserID string) (err error) { + opUserID := tools.OpUserID(ctx) + defer func() { + trace_log.SetCtxInfo(ctx, utils.GetFuncName(1), err, "OpUserID", opUserID, "OwnerUserID", OwnerUserID) + }() + if utils.IsContain(opUserID, config.Config.Manager.AppManagerUid) { + return nil + } + if opUserID == OwnerUserID { + return nil + } + return constant.ErrIdentity.Wrap(utils.GetSelfFuncName()) +} + func GetUserIDFromToken(token string, operationID string) (bool, string, string) { claims, err := ParseToken(token, operationID) if err != nil { diff --git a/pkg/common/tools/op.go b/pkg/common/tools/op.go new file mode 100644 index 000000000..df901055a --- /dev/null +++ b/pkg/common/tools/op.go @@ -0,0 +1,13 @@ +package tools + +import "context" + +func OperationID(ctx context.Context) string { + s, _ := ctx.Value("operationID").(string) + return s +} + +func OpUserID(ctx context.Context) string { + s, _ := ctx.Value("opUserID").(string) + return s +}