2025-10-14 15:48:01 +08:00

71 lines
2.1 KiB
Go

package startrpc
import (
"context"
"time"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/stability/ratelimit"
"github.com/openimsdk/tools/stability/ratelimit/bbr"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// RateLimiter holds the settings NewRateLimiter uses to build an RPC rate
// limiter backed by the bbr package.
type RateLimiter struct {
	// Enable turns rate limiting on; when false NewRateLimiter returns nil.
	Enable bool

	// Window is forwarded to bbr.WithWindow (presumably the length of the
	// sliding statistics window — confirm against the bbr package).
	Window time.Duration

	// Bucket is forwarded to bbr.WithBucket (presumably the number of buckets
	// the window is split into — confirm against the bbr package).
	Bucket int

	// CPUThreshold is forwarded to bbr.WithCPUThreshold.
	// NOTE(review): its unit/scale (percent vs. per-mille) is defined by the
	// bbr package — confirm before tuning.
	CPUThreshold int64
}
// NewRateLimiter builds a BBR-based limiter from config.
//
// It returns nil when config is nil or config.Enable is false. The
// UnaryRateLimitInterceptor and StreamRateLimitInterceptor constructors treat
// a nil limiter as "no rate limiting".
func NewRateLimiter(config *RateLimiter) ratelimit.Limiter {
	// A nil config is treated like a disabled one instead of panicking on the
	// field access.
	if config == nil || !config.Enable {
		return nil
	}
	return bbr.NewBBRLimiter(
		bbr.WithWindow(config.Window),
		bbr.WithBucket(config.Bucket),
		bbr.WithCPUThreshold(config.CPUThreshold),
	)
}
// UnaryRateLimitInterceptor returns a server option that applies limiter to
// every unary RPC.
//
// When limiter is nil, no interceptor is installed, so disabled rate limiting
// adds no per-call overhead. Otherwise each call first asks the limiter for
// admission. A rejected call is logged and answered with
// codes.ResourceExhausted without reaching the handler. An admitted call
// reports completion to the limiter when the handler returns.
func UnaryRateLimitInterceptor(limiter ratelimit.Limiter) grpc.ServerOption {
	if limiter == nil {
		// Chaining zero interceptors is a valid no-op option. It avoids wrapping
		// every call in a pass-through closure.
		return grpc.ChainUnaryInterceptor()
	}
	return grpc.ChainUnaryInterceptor(func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
		done, err := limiter.Allow()
		if err != nil {
			log.ZWarn(ctx, "rpc rate limited", err, "method", info.FullMethod)
			return nil, status.Errorf(codes.ResourceExhausted, "rpc request rate limit exceeded: %v, please try again later", err)
		}
		// Deferred so the limiter is notified even if the handler panics.
		defer done(ratelimit.DoneInfo{})
		return handler(ctx, req)
	})
}
// StreamRateLimitInterceptor returns a server option that applies limiter to
// every streaming RPC.
//
// When limiter is nil, no interceptor is installed, so disabled rate limiting
// adds no per-stream overhead. Otherwise each stream first asks the limiter
// for admission. A rejected stream is logged and answered with
// codes.ResourceExhausted without reaching the handler. An admitted stream
// reports completion to the limiter when the handler returns.
func StreamRateLimitInterceptor(limiter ratelimit.Limiter) grpc.ServerOption {
	if limiter == nil {
		// Chaining zero interceptors is a valid no-op option. It avoids wrapping
		// every stream in a pass-through closure.
		return grpc.ChainStreamInterceptor()
	}
	return grpc.ChainStreamInterceptor(func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		done, err := limiter.Allow()
		if err != nil {
			log.ZWarn(ss.Context(), "rpc rate limited", err, "method", info.FullMethod)
			return status.Errorf(codes.ResourceExhausted, "rpc request rate limit exceeded: %v, please try again later", err)
		}
		// Deferred so the limiter is notified even if the handler panics.
		defer done(ratelimit.DoneInfo{})
		return handler(srv, ss)
	})
}