diff --git a/internal/api/custom_validator.go b/internal/api/custom_validator.go new file mode 100644 index 000000000..541702677 --- /dev/null +++ b/internal/api/custom_validator.go @@ -0,0 +1,23 @@ +package api + +import ( + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" + "github.com/go-playground/validator/v10" +) + +func RequiredIf(fl validator.FieldLevel) bool { + sessionType := fl.Parent().FieldByName("SessionType").Int() + switch sessionType { + case constant.SingleChatType, constant.NotificationChatType: + if fl.FieldName() == "RecvID" { + return fl.Field().String() != "" + } + case constant.GroupChatType, constant.SuperGroupChatType: + if fl.FieldName() == "GroupID" { + return fl.Field().String() != "" + } + default: + return true + } + return true +} diff --git a/internal/api/route.go b/internal/api/route.go index cc9f51788..823778bca 100644 --- a/internal/api/route.go +++ b/internal/api/route.go @@ -2,7 +2,6 @@ package api import ( "github.com/OpenIMSDK/Open-IM-Server/pkg/common/config" - "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/log" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/mw" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/prome" @@ -162,20 +161,3 @@ func NewGinRouter(zk discoveryregistry.SvcDiscoveryRegistry, rdb redis.Universal } return r } - -func RequiredIf(fl validator.FieldLevel) bool { - sessionType := fl.Parent().FieldByName("SessionType").Int() - switch sessionType { - case constant.SingleChatType, constant.NotificationChatType: - if fl.FieldName() == "RecvID" { - return fl.Field().String() != "" - } - case constant.GroupChatType, constant.SuperGroupChatType: - if fl.FieldName() == "GroupID" { - return fl.Field().String() != "" - } - default: - return true - } - return true -} diff --git a/internal/push/consumer_init.go b/internal/push/consumer_init.go index 6eba5306e..283cf1d1b 100644 --- a/internal/push/consumer_init.go +++ b/internal/push/consumer_init.go @@ -1,9 +1,3 @@ -/* -** description(""). -** copyright('open-im,www.open-im.io'). -** author("fg,Gordon@open-im.io"). -** time(2021/3/22 15:33). - */ package push import ( diff --git a/internal/push/push_handler.go b/internal/push/push_handler.go index 6d3f5d61b..c6a1ac660 100644 --- a/internal/push/push_handler.go +++ b/internal/push/push_handler.go @@ -1,9 +1,3 @@ -/* -** description(""). -** copyright('OpenIM,www.OpenIM.io'). -** author("fg,Gordon@tuoyun.net"). -** time(2021/5/13 10:33). - */ package push import ( diff --git a/internal/push/push_to_client.go b/internal/push/push_to_client.go index f17e09e83..eac8010ae 100644 --- a/internal/push/push_to_client.go +++ b/internal/push/push_to_client.go @@ -1,9 +1,3 @@ -/* -** description(""). -** copyright('open-im,www.open-im.io'). -** author("fg,Gordon@open-im.io"). -** time(2021/3/5 14:31). - */ package push import ( diff --git a/internal/rpc/msg/MessageInterceptor.go b/internal/rpc/msg/MessageInterceptor.go new file mode 100644 index 000000000..e07db2ec8 --- /dev/null +++ b/internal/rpc/msg/MessageInterceptor.go @@ -0,0 +1,61 @@ +package msg + +import ( + "context" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/config" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/log" + "github.com/OpenIMSDK/Open-IM-Server/pkg/errs" + "github.com/OpenIMSDK/Open-IM-Server/pkg/proto/msg" + "github.com/OpenIMSDK/Open-IM-Server/pkg/proto/sdkws" +) + +type MessageInterceptorFunc func(ctx context.Context, req *msg.SendMsgReq) (*sdkws.MsgData, error) + +func MessageHasReadEnabled(_ context.Context, req *msg.SendMsgReq) (*sdkws.MsgData, error) { + switch req.MsgData.ContentType { + case constant.HasReadReceipt: + if config.Config.SingleMessageHasReadReceiptEnable { + return req.MsgData, nil + } else { + return nil, errs.ErrMessageHasReadDisable.Wrap() + } + case constant.GroupHasReadReceipt: + if config.Config.GroupMessageHasReadReceiptEnable { + return req.MsgData, nil + } else { + return nil, errs.ErrMessageHasReadDisable.Wrap() + } + } + return req.MsgData, nil +} +func MessageModifyCallback(ctx context.Context, req *msg.SendMsgReq) (*sdkws.MsgData, error) { + if err := CallbackMsgModify(ctx, req); err != nil && err != errs.ErrCallbackContinue { + log.ZWarn(ctx, "CallbackMsgModify failed", err, "req", req.String()) + return nil, err + } + return req.MsgData, nil +} +func MessageBeforeSendCallback(ctx context.Context, req *msg.SendMsgReq) (*sdkws.MsgData, error) { + switch req.MsgData.SessionType { + case constant.SingleChatType: + if err := CallbackBeforeSendSingleMsg(ctx, req); err != nil && err != errs.ErrCallbackContinue { + log.ZWarn(ctx, "CallbackBeforeSendSingleMsg failed", err, "req", req.String()) + return nil, err + } + case constant.GroupChatType: + if err := CallbackBeforeSendGroupMsg(ctx, req); err != nil && err != errs.ErrCallbackContinue { + log.ZWarn(ctx, "CallbackBeforeSendGroupMsg failed", err, "req", req.String()) + return nil, err + } + case constant.NotificationChatType: + case constant.SuperGroupChatType: + if err := CallbackBeforeSendGroupMsg(ctx, req); err != nil && err != errs.ErrCallbackContinue { + log.ZWarn(ctx, "CallbackBeforeSendGroupMsg failed", err, "req", req.String()) + return nil, err + } + default: + return nil, errs.ErrArgs.Wrap("unknown sessionType") + } + return req.MsgData, nil +} diff --git a/internal/rpc/msg/send_pull.go b/internal/rpc/msg/send_pull.go index f1987bdf7..40d08f30f 100644 --- a/internal/rpc/msg/send_pull.go +++ b/internal/rpc/msg/send_pull.go @@ -16,11 +16,6 @@ import ( func (m *msgServer) sendMsgSuperGroupChat(ctx context.Context, req *msg.SendMsgReq) (resp *msg.SendMsgResp, err error) { resp = &msg.SendMsgResp{} promePkg.Inc(promePkg.WorkSuperGroupChatMsgRecvSuccessCounter) - // callback - if err = CallbackBeforeSendGroupMsg(ctx, req); err != nil && err != errs.ErrCallbackContinue { - return nil, err - } - if _, err = m.messageVerification(ctx, req); err != nil { promePkg.Inc(promePkg.WorkSuperGroupChatMsgProcessFailedCounter) return nil, err @@ -63,9 +58,6 @@ func (m *msgServer) sendMsgNotification(ctx context.Context, req *msg.SendMsgReq func (m *msgServer) sendMsgSingleChat(ctx context.Context, req *msg.SendMsgReq) (resp *msg.SendMsgResp, err error) { promePkg.Inc(promePkg.SingleChatMsgRecvSuccessCounter) - if err = CallbackBeforeSendSingleMsg(ctx, req); err != nil && err != errs.ErrCallbackContinue { - return nil, err - } _, err = m.messageVerification(ctx, req) if err != nil { return nil, err @@ -103,10 +95,6 @@ func (m *msgServer) sendMsgSingleChat(ctx context.Context, req *msg.SendMsgReq) func (m *msgServer) sendMsgGroupChat(ctx context.Context, req *msg.SendMsgReq) (resp *msg.SendMsgResp, err error) { // callback promePkg.Inc(promePkg.GroupChatMsgRecvSuccessCounter) - err = CallbackBeforeSendGroupMsg(ctx, req) - if err != nil && err != errs.ErrCallbackContinue { - return nil, err - } var memberUserIDList []string if memberUserIDList, err = m.messageVerification(ctx, req); err != nil { diff --git a/internal/rpc/msg/server.go b/internal/rpc/msg/server.go index 16b2b901f..29731656f 100644 --- a/internal/rpc/msg/server.go +++ b/internal/rpc/msg/server.go @@ -18,6 +18,7 @@ import ( "google.golang.org/grpc" ) +type MessageInterceptorChain []MessageInterceptorFunc type msgServer struct { RegisterCenter discoveryregistry.SvcDiscoveryRegistry MsgDatabase controller.MsgDatabase @@ -29,8 +30,22 @@ type msgServer struct { *localcache.GroupLocalCache black *check.BlackChecker MessageLocker MessageLocker + Handlers MessageInterceptorChain } +func (m *msgServer) addInterceptorHandler(interceptorFunc ...MessageInterceptorFunc) { + m.Handlers = append(m.Handlers, interceptorFunc...) +} +func (m *msgServer) execInterceptorHandler(ctx context.Context, req *msg.SendMsgReq) error { + for _, handler := range m.Handlers { + msgData, err := handler(ctx, req) + if err != nil { + return err + } + req.MsgData = msgData + } + return nil +} func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { rdb, err := cache.NewRedis() if err != nil { @@ -40,7 +55,6 @@ func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) e if err != nil { return err } - cacheModel := cache.NewCacheModel(rdb) msgDocModel := unrelation.NewMsgMongoDriver(mongo.GetDatabase()) extendMsgModel := unrelation.NewExtendMsgSetMongoDriver(mongo.GetDatabase()) @@ -60,6 +74,7 @@ func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) e friend: check.NewFriendChecker(client), MessageLocker: NewLockerMessage(cacheModel), } + s.addInterceptorHandler(MessageHasReadEnabled, MessageModifyCallback) s.initPrometheus() msg.RegisterMsgServer(server, s) return nil diff --git a/pkg/utils/utils_v2.go b/pkg/utils/utils_v2.go index 13d52cf15..15433cd8b 100644 --- a/pkg/utils/utils_v2.go +++ b/pkg/utils/utils_v2.go @@ -472,7 +472,7 @@ func Unwrap(err error) error { // NotNilReplace 当new_不为空时, 将old设置为new_ func NotNilReplace[T any](old, new_ *T) { - if old == nil || new_ == nil { + if new_ == nil { return } *old = *new_