diff --git a/internal/push/push_to_client.go b/internal/push/push_to_client.go index 31cafb0f0..abc4ade1e 100644 --- a/internal/push/push_to_client.go +++ b/internal/push/push_to_client.go @@ -18,7 +18,6 @@ import ( "github.com/OpenIMSDK/Open-IM-Server/pkg/common/mcontext" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/prome" "github.com/OpenIMSDK/Open-IM-Server/pkg/discoveryregistry" - "github.com/OpenIMSDK/Open-IM-Server/pkg/errs" "github.com/OpenIMSDK/Open-IM-Server/pkg/proto/group" "github.com/OpenIMSDK/Open-IM-Server/pkg/proto/msggateway" "github.com/OpenIMSDK/Open-IM-Server/pkg/proto/sdkws" @@ -79,7 +78,7 @@ func (p *Pusher) DismissGroup(ctx context.Context, groupID string) error { func (p *Pusher) Push2User(ctx context.Context, userIDs []string, msg *sdkws.MsgData) error { log.ZDebug(ctx, "Get msg from msg_transfer And push msg", "userIDs", userIDs, "msg", msg.String()) // callback - if err := callbackOnlinePush(ctx, userIDs, msg); err != nil && err != errs.ErrCallbackContinue { + if err := callbackOnlinePush(ctx, userIDs, msg); err != nil { return err } // push @@ -132,7 +131,7 @@ func (p *Pusher) UnmarshalNotificationElem(bytes []byte, t interface{}) error { func (p *Pusher) Push2SuperGroup(ctx context.Context, groupID string, msg *sdkws.MsgData) (err error) { log.ZDebug(ctx, "Get super group msg from msg_transfer and push msg", "msg", msg.String(), "groupID", groupID) var pushToUserIDs []string - if err := callbackBeforeSuperGroupOnlinePush(ctx, groupID, msg, &pushToUserIDs); err != nil && err != errs.ErrCallbackContinue { + if err := callbackBeforeSuperGroupOnlinePush(ctx, groupID, msg, &pushToUserIDs); err != nil { return err } if len(pushToUserIDs) == 0 { diff --git a/internal/rpc/group/group.go b/internal/rpc/group/group.go index cf1c443e6..2fb35c0b3 100644 --- a/internal/rpc/group/group.go +++ b/internal/rpc/group/group.go @@ -155,7 +155,7 @@ func (s *groupServer) CreateGroup(ctx context.Context, req *pbGroup.CreateGroupR if len(userMap) != len(userIDs) { return nil, errs.ErrUserIDNotFound.Wrap("user not found") } - if err := CallbackBeforeCreateGroup(ctx, req); err != nil && err != errs.ErrCallbackContinue { + if err := CallbackBeforeCreateGroup(ctx, req); err != nil { return nil, err } var groupMembers []*relationTb.GroupMemberModel @@ -173,7 +173,7 @@ func (s *groupServer) CreateGroup(ctx context.Context, req *pbGroup.CreateGroupR groupMember.InviterUserID = mcontext.GetOpUserID(ctx) groupMember.JoinTime = time.Now() groupMember.MuteEndTime = time.Unix(0, 0) - if err := CallbackBeforeMemberJoinGroup(ctx, groupMember, group.Ex); err != nil && err != errs.ErrCallbackContinue { + if err := CallbackBeforeMemberJoinGroup(ctx, groupMember, group.Ex); err != nil { return err } groupMembers = append(groupMembers, groupMember) @@ -364,7 +364,7 @@ func (s *groupServer) InviteUserToGroup(ctx context.Context, req *pbGroup.Invite member.JoinSource = constant.JoinByInvitation member.JoinTime = time.Now() member.MuteEndTime = time.Unix(0, 0) - if err := CallbackBeforeMemberJoinGroup(ctx, member, group.Ex); err != nil && err != errs.ErrCallbackContinue { + if err := CallbackBeforeMemberJoinGroup(ctx, member, group.Ex); err != nil { return nil, err } groupMembers = append(groupMembers, member) @@ -704,7 +704,7 @@ func (s *groupServer) GroupApplicationResponse(ctx context.Context, req *pbGroup OperatorUserID: mcontext.GetOpUserID(ctx), Ex: groupRequest.Ex, } - if err = CallbackBeforeMemberJoinGroup(ctx, member, group.Ex); err != nil && err != errs.ErrCallbackContinue { + if err = CallbackBeforeMemberJoinGroup(ctx, member, group.Ex); err != nil { return nil, err } } @@ -756,7 +756,7 @@ func (s *groupServer) JoinGroup(ctx context.Context, req *pbGroup.JoinGroupReq) groupMember.InviterUserID = req.InviterUserID groupMember.JoinTime = time.Now() groupMember.MuteEndTime = time.Unix(0, 0) - if err := CallbackBeforeMemberJoinGroup(ctx, groupMember, group.Ex); err != nil && err != errs.ErrCallbackContinue { + if err := CallbackBeforeMemberJoinGroup(ctx, groupMember, group.Ex); err != nil { return nil, err } if err := s.GroupDatabase.CreateGroup(ctx, nil, []*relationTb.GroupMemberModel{groupMember}); err != nil { diff --git a/internal/rpc/msg/message_interceptor.go b/internal/rpc/msg/message_interceptor.go index d142abbbe..702cdc764 100644 --- a/internal/rpc/msg/message_interceptor.go +++ b/internal/rpc/msg/message_interceptor.go @@ -5,7 +5,6 @@ import ( "github.com/OpenIMSDK/Open-IM-Server/pkg/common/config" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" - "github.com/OpenIMSDK/Open-IM-Server/pkg/common/log" "github.com/OpenIMSDK/Open-IM-Server/pkg/errs" "github.com/OpenIMSDK/Open-IM-Server/pkg/proto/msg" "github.com/OpenIMSDK/Open-IM-Server/pkg/proto/sdkws" @@ -30,28 +29,3 @@ func MessageHasReadEnabled(_ context.Context, req *msg.SendMsgReq) (*sdkws.MsgDa } return req.MsgData, nil } -func MessageModifyCallback(ctx context.Context, req *msg.SendMsgReq) (*sdkws.MsgData, error) { - if err := callbackMsgModify(ctx, req); err != nil && err != errs.ErrCallbackContinue { - log.ZWarn(ctx, "CallbackMsgModify failed", err, "req", req.String()) - return nil, err - } - return req.MsgData, nil -} -func MessageBeforeSendCallback(ctx context.Context, req *msg.SendMsgReq) (*sdkws.MsgData, error) { - switch req.MsgData.SessionType { - case constant.SingleChatType: - if err := callbackBeforeSendSingleMsg(ctx, req); err != nil && err != errs.ErrCallbackContinue { - log.ZWarn(ctx, "CallbackBeforeSendSingleMsg failed", err, "req", req.String()) - return nil, err - } - case constant.NotificationChatType: - case constant.SuperGroupChatType: - if err := callbackBeforeSendGroupMsg(ctx, req); err != nil && err != errs.ErrCallbackContinue { - log.ZWarn(ctx, "CallbackBeforeSendGroupMsg failed", err, "req", req.String()) - return nil, err - } - default: - return nil, errs.ErrArgs.Wrap("unknown sessionType") - } - return req.MsgData, nil -} diff --git a/internal/rpc/msg/send.go b/internal/rpc/msg/send.go index 12c614fea..5c5e303b1 100644 --- a/internal/rpc/msg/send.go +++ b/internal/rpc/msg/send.go @@ -18,9 +18,6 @@ func (m *msgServer) SendMsg(ctx context.Context, req *pbMsg.SendMsgReq) (resp *p return nil, errs.ErrMessageHasReadDisable.Wrap() } m.encapsulateMsgData(req.MsgData) - if err := callbackMsgModify(ctx, req); err != nil && err != errs.ErrCallbackContinue { - return nil, err - } switch req.MsgData.SessionType { case constant.SingleChatType: return m.sendMsgSingleChat(ctx, req) @@ -40,12 +37,18 @@ func (m *msgServer) sendMsgSuperGroupChat(ctx context.Context, req *pbMsg.SendMs promePkg.Inc(promePkg.WorkSuperGroupChatMsgProcessFailedCounter) return nil, err } + if err = callbackBeforeSendGroupMsg(ctx, req); err != nil { + return nil, err + } + if err := callbackMsgModify(ctx, req); err != nil { + return nil, err + } err = m.MsgDatabase.MsgToMQ(ctx, utils.GenConversationUniqueKeyForGroup(req.MsgData.GroupID), req.MsgData) if err != nil { return nil, err } if err = callbackAfterSendGroupMsg(ctx, req); err != nil { - log.ZError(ctx, "CallbackAfterSendGroupMsg", err) + log.ZWarn(ctx, "CallbackAfterSendGroupMsg", err) } promePkg.Inc(promePkg.WorkSuperGroupChatMsgProcessSuccessCounter) resp.SendTime = req.MsgData.SendTime @@ -85,13 +88,19 @@ func (m *msgServer) sendMsgSingleChat(ctx context.Context, req *pbMsg.SendMsgReq promePkg.Inc(promePkg.SingleChatMsgProcessFailedCounter) return nil, errs.ErrUserNotRecvMsg } else { + if err = callbackBeforeSendSingleMsg(ctx, req); err != nil { + return nil, err + } + if err := callbackMsgModify(ctx, req); err != nil { + return nil, err + } if err := m.MsgDatabase.MsgToMQ(ctx, utils.GenConversationUniqueKeyForSingle(req.MsgData.SendID, req.MsgData.RecvID), req.MsgData); err != nil { promePkg.Inc(promePkg.SingleChatMsgProcessFailedCounter) return nil, err } err = callbackAfterSendSingleMsg(ctx, req) - if err != nil && err != errs.ErrCallbackContinue { - return nil, err + if err != nil { + log.ZWarn(ctx, "CallbackAfterSendSingleMsg", err, "req", req) } resp = &pbMsg.SendMsgResp{ ServerMsgID: req.MsgData.ServerMsgID, diff --git a/internal/rpc/msg/server.go b/internal/rpc/msg/server.go index 85b247ea3..c2b30191e 100644 --- a/internal/rpc/msg/server.go +++ b/internal/rpc/msg/server.go @@ -81,7 +81,7 @@ func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) e MessageLocker: NewLockerMessage(cacheModel), } s.notificationSender = rpcclient.NewNotificationSender(rpcclient.WithLocalSendMsg(s.SendMsg)) - s.addInterceptorHandler(MessageHasReadEnabled, MessageModifyCallback) + s.addInterceptorHandler(MessageHasReadEnabled) s.initPrometheus() msg.RegisterMsgServer(server, s) return nil