diff --git a/pkg/common/mw/check.go b/pkg/common/mw/check.go index e62ef3fa5..5c8a3ee0c 100644 --- a/pkg/common/mw/check.go +++ b/pkg/common/mw/check.go @@ -7,11 +7,12 @@ import ( "encoding/base64" "errors" "fmt" - "github.com/OpenIMSDK/Open-IM-Server/pkg/common/config" "math/rand" "strings" "sync" "time" + + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/config" ) var ( diff --git a/pkg/common/mw/intercept_chain.go b/pkg/common/mw/intercept_chain.go new file mode 100644 index 000000000..8feebab90 --- /dev/null +++ b/pkg/common/mw/intercept_chain.go @@ -0,0 +1,27 @@ +package mw + +import ( + "context" + + "google.golang.org/grpc" +) + +func InterceptChain(intercepts ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor { + l := len(intercepts) + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + chain := func(currentInter grpc.UnaryServerInterceptor, currentHandler grpc.UnaryHandler) grpc.UnaryHandler { + return func(currentCtx context.Context, currentReq interface{}) (interface{}, error) { + return currentInter( + currentCtx, + currentReq, + info, + currentHandler) + } + } + chainHandler := handler + for i := l - 1; i >= 0; i-- { + chainHandler = chain(intercepts[i], chainHandler) + } + return chainHandler(ctx, req) + } +} diff --git a/pkg/common/mw/rpc_server_interceptor.go b/pkg/common/mw/rpc_server_interceptor.go index e923f1319..9da31b716 100644 --- a/pkg/common/mw/rpc_server_interceptor.go +++ b/pkg/common/mw/rpc_server_interceptor.go @@ -3,12 +3,13 @@ package mw import ( "context" "fmt" - "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" "math" "runtime" "runtime/debug" "strings" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/config" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/log" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/mw/specialerror" diff --git a/pkg/startrpc/start.go b/pkg/startrpc/start.go index f65164b6f..a9db548a6 100644 --- a/pkg/startrpc/start.go +++ b/pkg/startrpc/start.go @@ -28,7 +28,6 @@ func Start(rpcPort int, rpcRegisterName string, prometheusPort int, rpcFn func(c return err } defer listener.Close() - fmt.Println(config.Config.Zookeeper.ZkAddr, config.Config.Zookeeper.Schema, rpcRegisterName) zkClient, err := openKeeper.NewClient(config.Config.Zookeeper.ZkAddr, config.Config.Zookeeper.Schema, openKeeper.WithFreq(time.Hour), openKeeper.WithUserNameAndPassword(config.Config.Zookeeper.UserName, config.Config.Zookeeper.Password), openKeeper.WithRoundRobin(), openKeeper.WithTimeout(10)) @@ -46,10 +45,10 @@ func Start(rpcPort int, rpcRegisterName string, prometheusPort int, rpcFn func(c prome.NewGrpcRequestCounter() prome.NewGrpcRequestFailedCounter() prome.NewGrpcRequestSuccessCounter() + unaryInterceptor := mw.InterceptChain(grpcPrometheus.UnaryServerInterceptor, grpcPrometheus.UnaryServerInterceptor) options = append(options, []grpc.ServerOption{ - //grpc.UnaryInterceptor(prome.UnaryServerInterceptorPrometheus), grpc.StreamInterceptor(grpcPrometheus.StreamServerInterceptor), - grpc.UnaryInterceptor(grpcPrometheus.UnaryServerInterceptor), + grpc.UnaryInterceptor(unaryInterceptor), }...) } srv := grpc.NewServer(options...)