From 900858e4ea8ece03d43570e93eb6629887b8ff36 Mon Sep 17 00:00:00 2001 From: withchao <993506633@qq.com> Date: Wed, 22 Mar 2023 10:11:18 +0800 Subject: [PATCH] rpc custom header --- pkg/common/constant/constant.go | 2 +- pkg/common/mw/check.go | 4 ++-- pkg/common/mw/rpc_client_interceptor.go | 6 +++++- pkg/common/mw/rpc_server_interceptor.go | 9 +++++++++ 4 files changed, 17 insertions(+), 4 deletions(-) diff --git a/pkg/common/constant/constant.go b/pkg/common/constant/constant.go index d2188cbd9..cbb350eb0 100644 --- a/pkg/common/constant/constant.go +++ b/pkg/common/constant/constant.go @@ -275,7 +275,7 @@ const OpUserID = "opUserID" const ConnID = "connID" const OpUserPlatform = "platform" const Token = "token" -const RpcMwCustom = "hCustom" +const RpcCustomHeader = "customHeader" // rpc中间件自定义ctx参数 const CheckKey = "CheckKey" const ( diff --git a/pkg/common/mw/check.go b/pkg/common/mw/check.go index 65c332da2..e62ef3fa5 100644 --- a/pkg/common/mw/check.go +++ b/pkg/common/mw/check.go @@ -15,8 +15,8 @@ import ( ) var ( - block cipher.Block once sync.Once + block cipher.Block ) func init() { @@ -37,7 +37,7 @@ func initAesKey() { func genReqKey(args []string) string { initAesKey() plaintext := md5.Sum([]byte(strings.Join(args, ":"))) - var iv = make([]byte, aes.BlockSize, aes.BlockSize+md5.Size) + iv := make([]byte, aes.BlockSize, aes.BlockSize+md5.Size) if _, err := rand.Read(iv); err != nil { panic(err) } diff --git a/pkg/common/mw/rpc_client_interceptor.go b/pkg/common/mw/rpc_client_interceptor.go index d558965e0..eda323836 100644 --- a/pkg/common/mw/rpc_client_interceptor.go +++ b/pkg/common/mw/rpc_client_interceptor.go @@ -24,14 +24,18 @@ func rpcClientInterceptor(ctx context.Context, method string, req, resp interfac } log.ZInfo(ctx, "rpc client req", "funcName", method, "req", rpcString(req)) md := metadata.Pairs() - if keys, _ := ctx.Value(constant.RpcMwCustom).([]string); len(keys) > 0 { + if keys, _ := ctx.Value(constant.RpcCustomHeader).([]string); len(keys) > 0 { for _, key := range keys { val, ok := ctx.Value(key).([]string) if !ok { return errs.ErrInternalServer.Wrap(fmt.Sprintf("ctx missing key %s", key)) } + if len(val) == 0 { + return errs.ErrInternalServer.Wrap(fmt.Sprintf("ctx key %s value is empty", key)) + } md.Set(key, val...) } + md.Set(constant.RpcCustomHeader, keys...) } operationID, ok := ctx.Value(constant.OperationID).(string) if !ok { diff --git a/pkg/common/mw/rpc_server_interceptor.go b/pkg/common/mw/rpc_server_interceptor.go index f76fa5324..f47a4321b 100644 --- a/pkg/common/mw/rpc_server_interceptor.go +++ b/pkg/common/mw/rpc_server_interceptor.go @@ -56,6 +56,15 @@ func rpcServerInterceptor(ctx context.Context, req interface{}, info *grpc.Unary if !ok { return nil, status.New(codes.InvalidArgument, "missing metadata").Err() } + if keys := md.Get(constant.RpcCustomHeader); len(keys) > 0 { + for _, key := range keys { + values := md.Get(key) + if len(values) == 0 { + return nil, status.New(codes.InvalidArgument, fmt.Sprintf("missing metadata key %s", key)).Err() + } + ctx = context.WithValue(ctx, key, values) + } + } args := make([]string, 0, 4) if opts := md.Get(constant.OperationID); len(opts) != 1 || opts[0] == "" { return nil, status.New(codes.InvalidArgument, "operationID error").Err()