diff --git a/pkg/common/startrpc/start.go b/pkg/common/startrpc/start.go index bf469480e..b3fba038d 100644 --- a/pkg/common/startrpc/start.go +++ b/pkg/common/startrpc/start.go @@ -99,13 +99,6 @@ func Start[T any](ctx context.Context, disc *conf.Discovery, prometheusConfig *c sigs := make(chan os.Signal, 1) signal.Notify(sigs, syscall.SIGTERM) - var gsrv grpcServiceRegistrar - - err = rpcFn(ctx, config, client, &gsrv) - if err != nil { - return err - } - ctx, cancel := context.WithCancelCause(ctx) if prometheusListenAddr != "" { @@ -137,38 +130,51 @@ func Start[T any](ctx context.Context, disc *conf.Discovery, prometheusConfig *c }() } - var rpcGracefulStop chan struct{} + var ( + rpcServer *grpc.Server + rpcGracefulStop chan struct{} + ) - if len(gsrv.services) > 0 { - rpcListener, rpcPort, err := listenTCP(rpcListenAddr) + onGrpcServiceRegistrar := func(desc *grpc.ServiceDesc, impl any) { + if rpcServer != nil { + rpcServer.RegisterService(desc, impl) + return + } + rpcListener, err := net.Listen("tcp", rpcListenAddr) if err != nil { - return err - } - srv := grpc.NewServer(options...) - - for _, service := range gsrv.services { - srv.RegisterService(service.desc, service.impl) - } - grpcOpt := grpc.WithTransportCredentials(insecure.NewCredentials()) - if err := client.Register(ctx, rpcRegisterName, registerIP, rpcPort, grpcOpt); err != nil { - return err + cancel(fmt.Errorf("listen rpc %s %s %w", rpcRegisterName, rpcListenAddr, err)) + return } + rpcServer = grpc.NewServer(options...) + rpcServer.RegisterService(desc, impl) rpcGracefulStop = make(chan struct{}) + rpcPort := rpcListener.Addr().(*net.TCPAddr).Port + log.ZDebug(ctx, "rpc start register", "rpcRegisterName", rpcRegisterName, "registerIP", registerIP, "rpcPort", rpcPort) + grpcOpt := grpc.WithTransportCredentials(insecure.NewCredentials()) + rpcGracefulStop = make(chan struct{}) + go func() { + <-ctx.Done() + rpcServer.GracefulStop() + close(rpcGracefulStop) + }() + if err := client.Register(ctx, rpcRegisterName, registerIP, rpcListener.Addr().(*net.TCPAddr).Port, grpcOpt); err != nil { + cancel(fmt.Errorf("rpc register %s %w", rpcRegisterName, err)) + return + } go func() { - err := srv.Serve(rpcListener) + err := rpcServer.Serve(rpcListener) if err == nil { err = fmt.Errorf("serve end") } cancel(fmt.Errorf("rpc %s %w", rpcRegisterName, err)) }() + } - go func() { - <-ctx.Done() - srv.GracefulStop() - close(rpcGracefulStop) - }() + err = rpcFn(ctx, config, client, &grpcServiceRegistrar{onRegisterService: onGrpcServiceRegistrar}) + if err != nil { + return err } select { @@ -220,18 +226,10 @@ func prommetricsStreamInterceptor(rpcRegisterName string) grpc.ServerOption { return grpc.ChainStreamInterceptor() } -type grpcService struct { - desc *grpc.ServiceDesc - impl any -} - type grpcServiceRegistrar struct { - services []*grpcService + onRegisterService func(desc *grpc.ServiceDesc, impl any) } func (x *grpcServiceRegistrar) RegisterService(desc *grpc.ServiceDesc, impl any) { - x.services = append(x.services, &grpcService{ - desc: desc, - impl: impl, - }) + x.onRegisterService(desc, impl) }