// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package startrpc

import (
	"context"
	"errors"
	"fmt"
	"net"
	"os"
	"os/signal"
	"strconv"
	"syscall"
	"time"

	conf "github.com/openimsdk/open-im-server/v3/pkg/common/config"
	"github.com/openimsdk/tools/utils/datautil"
	"github.com/openimsdk/tools/utils/jsonutil"
	"github.com/openimsdk/tools/utils/network"
	"google.golang.org/grpc/status"

	kdisc "github.com/openimsdk/open-im-server/v3/pkg/common/discovery"
	"github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics"
	"github.com/openimsdk/tools/discovery"
	"github.com/openimsdk/tools/errs"
	"github.com/openimsdk/tools/log"
	"github.com/openimsdk/tools/mw"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

// init calls prommetrics.RegistryAll once, when this package is loaded and
// before any function in it (including Start) can run. Presumably this
// registers the Prometheus collectors that the RPC metric interceptors below
// report into — confirm against the prommetrics package.
func init() {
	prommetrics.RegistryAll()
}

// Start boots an RPC service and blocks until it shuts down.
//
// It creates a discovery client, starts a Prometheus metrics listener
// (on port 0 when autoSetPorts is set, otherwise on prometheusConfig.Ports[index])
// and then calls rpcFn so the caller can register its gRPC services. The first
// RegisterService call made through the registrar passed to rpcFn opens the RPC
// listener, builds the gRPC server, registers this instance with discovery and
// starts serving; later calls only add services to that same server.
//
// Return value:
//   - a setup error or the error returned by rpcFn. The internal context is
//     cancelled before returning, so any gRPC server already started is stopped
//     gracefully in the background instead of being left running.
//   - otherwise, once the internal context is done (parent cancelled,
//     SIGINT/SIGTERM received, metrics listener or RPC serve/register failure),
//     its cancellation cause, after waiting up to 15s for the gRPC server's
//     graceful stop.
func Start[T any](ctx context.Context, disc *conf.Discovery, prometheusConfig *conf.Prometheus, listenIP,
	registerIP string, autoSetPorts bool, rpcPorts []int, index int, rpcRegisterName string, notification *conf.Notification, config T,
	watchConfigNames []string, watchServiceNames []string,
	rpcFn func(ctx context.Context, config T, client discovery.Conn, server grpc.ServiceRegistrar) error,
	options ...grpc.ServerOption) error {

	if notification != nil {
		conf.InitNotification(notification)
	}

	options = append(options, mw.GrpcServer())

	registerIP, err := network.GetRpcRegisterIP(registerIP)
	if err != nil {
		return err
	}
	var prometheusListenAddr string
	if autoSetPorts {
		prometheusListenAddr = net.JoinHostPort(listenIP, "0")
	} else {
		prometheusPort, err := datautil.GetElemByIndex(prometheusConfig.Ports, index)
		if err != nil {
			return err
		}
		prometheusListenAddr = net.JoinHostPort(listenIP, strconv.Itoa(prometheusPort))
	}

	// NOTE(review): watchConfigNames is not read after this point in Start.
	watchConfigNames = append(watchConfigNames, conf.LogConfigFileName)

	client, err := kdisc.NewDiscoveryRegister(disc, watchServiceNames)
	if err != nil {
		return err
	}

	defer client.Close()
	client.AddOption(
		mw.GrpcClient(), grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"LoadBalancingPolicy": "%s"}`, "round_robin")),
	)

	ctx, cancel := context.WithCancelCause(ctx)
	// Cancel on every return path. This stops the signal watcher and triggers
	// GracefulStop on any gRPC server already started. On the normal path the
	// context is already done, so the first recorded cause is kept.
	// Deferred after client.Close, so it runs before the client is closed.
	defer cancel(nil)

	go func() {
		sigs := make(chan os.Signal, 1)
		// SIGKILL cannot be caught, so only SIGTERM and SIGINT are trapped.
		signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT)
		defer signal.Stop(sigs)
		select {
		case <-ctx.Done():
		case val := <-sigs:
			log.ZDebug(ctx, "recv signal", "signal", val.String())
			cancel(fmt.Errorf("signal %s", val.String()))
		}
	}()

	if prometheusListenAddr != "" {
		options = append(
			options,
			prommetricsUnaryInterceptor(rpcRegisterName),
			prommetricsStreamInterceptor(rpcRegisterName),
		)
		prometheusListener, prometheusPort, err := listenTCP(prometheusListenAddr)
		if err != nil {
			return err
		}
		log.ZDebug(ctx, "prometheus start", "addr", prometheusListener.Addr(), "rpcRegisterName", rpcRegisterName)
		target, err := jsonutil.JsonMarshal(prommetrics.BuildDefaultTarget(registerIP, prometheusPort))
		if err != nil {
			_ = prometheusListener.Close()
			return err
		}
		// NOTE(review): the target is published under prommetrics.APIKeyName for
		// every RPC service — confirm this is the intended discovery key.
		if err := client.SetKey(ctx, prommetrics.BuildDiscoveryKey(prommetrics.APIKeyName), target); err != nil &&
			!errors.Is(err, discovery.ErrNotSupportedKeyValue) {
			// The listener has not been handed to prommetrics.Start yet, so close it here.
			_ = prometheusListener.Close()
			return err
		}
		go func() {
			err := prommetrics.Start(prometheusListener)
			if err == nil {
				err = fmt.Errorf("listener done")
			}
			cancel(fmt.Errorf("prommetrics %s %w", rpcRegisterName, err))
		}()
	}

	var (
		rpcServer       *grpc.Server
		rpcGracefulStop chan struct{}
	)

	onGrpcServiceRegistrar := func(desc *grpc.ServiceDesc, impl any) {
		if rpcServer != nil {
			rpcServer.RegisterService(desc, impl)
			return
		}
		// Shutdown is already in progress, possibly because an earlier
		// registration failed. Don't bind or register a new server.
		if ctx.Err() != nil {
			return
		}
		var rpcListenAddr string
		if autoSetPorts {
			rpcListenAddr = net.JoinHostPort(listenIP, "0")
		} else {
			rpcPort, err := datautil.GetElemByIndex(rpcPorts, index)
			if err != nil {
				cancel(fmt.Errorf("rpcPorts index out of range %s %w", rpcRegisterName, err))
				return
			}
			rpcListenAddr = net.JoinHostPort(listenIP, strconv.Itoa(rpcPort))
		}
		rpcListener, err := net.Listen("tcp", rpcListenAddr)
		if err != nil {
			cancel(fmt.Errorf("listen rpc %s %s %w", rpcRegisterName, rpcListenAddr, err))
			return
		}

		server := grpc.NewServer(options...)
		server.RegisterService(desc, impl)
		stopped := make(chan struct{})
		rpcServer, rpcGracefulStop = server, stopped

		rpcPort := rpcListener.Addr().(*net.TCPAddr).Port
		log.ZDebug(ctx, "rpc start register", "rpcRegisterName", rpcRegisterName, "registerIP", registerIP, "rpcPort", rpcPort)
		go func() {
			<-ctx.Done()
			server.GracefulStop()
			close(stopped)
		}()
		grpcOpt := grpc.WithTransportCredentials(insecure.NewCredentials())
		if err := client.Register(ctx, rpcRegisterName, registerIP, rpcPort, grpcOpt); err != nil {
			// Serve is never reached, so GracefulStop would not close this listener.
			_ = rpcListener.Close()
			cancel(fmt.Errorf("rpc register %s %w", rpcRegisterName, err))
			return
		}

		go func() {
			err := server.Serve(rpcListener)
			if err == nil {
				err = fmt.Errorf("serve end")
			}
			cancel(fmt.Errorf("rpc %s %w", rpcRegisterName, err))
		}()
	}

	if err := rpcFn(ctx, config, client, &grpcServiceRegistrar{onRegisterService: onGrpcServiceRegistrar}); err != nil {
		return err
	}
	<-ctx.Done()
	log.ZDebug(ctx, "cmd wait done", "err", context.Cause(ctx))
	if rpcGracefulStop != nil {
		timeout := time.NewTimer(15 * time.Second)
		defer timeout.Stop()
		select {
		case <-timeout.C:
			log.ZWarn(ctx, "rpc graceful stop timeout", nil)
		case <-rpcGracefulStop:
			log.ZDebug(ctx, "rpc graceful stop done")
		}
	}
	return context.Cause(ctx)
}

// listenTCP opens a TCP listener on addr and also returns the port that was
// actually bound, which differs from addr's port when addr requests port 0.
// Listen failures are wrapped with the offending address.
func listenTCP(addr string) (net.Listener, int, error) {
	ln, listenErr := net.Listen("tcp", addr)
	if listenErr != nil {
		return nil, 0, errs.WrapMsg(listenErr, "listen err", "addr", addr)
	}
	bound := ln.Addr().(*net.TCPAddr)
	return ln, bound.Port, nil
}

// prommetricsUnaryInterceptor returns a server option that reports every unary
// RPC handled by this service through prommetrics.RPCCall. Each call is
// labelled with rpcRegisterName, the full method name and a code derived from
// the handler's error:
//   - 0 when the handler returns a nil error;
//   - the gRPC status code when the error, or any error it wraps, implements
//     GRPCStatus();
//   - -1 for any other error.
func prommetricsUnaryInterceptor(rpcRegisterName string) grpc.ServerOption {
	getCode := func(err error) int {
		if err == nil {
			return 0
		}
		// errors.As also finds a status carried by a wrapped error (fmt.Errorf
		// with %w). A plain type assertion would miss it and report -1.
		var rpcErr interface{ GRPCStatus() *status.Status }
		if !errors.As(err, &rpcErr) {
			return -1
		}
		return int(rpcErr.GRPCStatus().Code())
	}
	return grpc.ChainUnaryInterceptor(func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
		resp, err := handler(ctx, req)
		prommetrics.RPCCall(rpcRegisterName, info.FullMethod, getCode(err))
		return resp, err
	})
}

// prommetricsStreamInterceptor returns a stream interceptor chain for the RPC
// server. The chain is currently empty.
//
// NOTE(review): no interceptor is installed, so streaming RPCs are not reported
// to prommetrics and rpcRegisterName is unused — confirm whether this is intended.
func prommetricsStreamInterceptor(rpcRegisterName string) grpc.ServerOption {
	var interceptors []grpc.StreamServerInterceptor
	return grpc.ChainStreamInterceptor(interceptors...)
}

// grpcServiceRegistrar implements grpc.ServiceRegistrar by handing every
// registration to a callback. This lets the owner observe and control how
// services are registered.
type grpcServiceRegistrar struct {
	// onRegisterService receives each registration. It must be non-nil,
	// otherwise RegisterService panics.
	onRegisterService func(desc *grpc.ServiceDesc, impl any)
}

// Compile-time check that *grpcServiceRegistrar satisfies grpc.ServiceRegistrar.
var _ grpc.ServiceRegistrar = (*grpcServiceRegistrar)(nil)

// RegisterService forwards desc and impl unchanged to the configured callback.
func (r *grpcServiceRegistrar) RegisterService(desc *grpc.ServiceDesc, impl any) {
	forward := r.onRegisterService
	forward(desc, impl)
}