diff --git a/gin.go b/gin.go index 2e033bf3..71d18aad 100644 --- a/gin.go +++ b/gin.go @@ -5,6 +5,7 @@ package gin import ( + "context" "fmt" "html/template" "net" @@ -186,6 +187,11 @@ type Engine struct { maxSections uint16 trustedProxies []string trustedCIDRs []*net.IPNet + + // server holds a reference to the HTTP server for graceful shutdown. + // This is set when one of the Run* methods is called. + server *http.Server + serverLock sync.Mutex } var _ IRouter = (*Engine)(nil) @@ -534,6 +540,30 @@ func parseIP(ip string) net.IP { return parsedIP } +// Shutdown gracefully shuts down the server without interrupting any active connections. +// Shutdown works by first closing all open listeners, then closing all idle connections, +// and then waiting indefinitely for connections to return to idle and then shut down. +// If the provided context expires before the shutdown is complete, Shutdown returns the +// context's error, otherwise it returns any error returned from closing the Server's +// underlying Listener(s). +// +// When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS immediately +// return ErrServerClosed. Make sure the program doesn't exit and waits instead for +// Shutdown to return. +// +// This method returns nil if the server has not been started. +func (engine *Engine) Shutdown(ctx context.Context) error { + engine.serverLock.Lock() + srv := engine.server + engine.serverLock.Unlock() + + if srv == nil { + return nil + } + + return srv.Shutdown(ctx) +} + // Run attaches the router to a http.Server and starts listening and serving HTTP requests. // It is a shortcut for http.ListenAndServe(addr, router) // Note: this method will block the calling goroutine indefinitely unless an error happens. @@ -551,6 +581,11 @@ func (engine *Engine) Run(addr ...string) (err error) { Addr: address, Handler: engine.Handler(), } + + engine.serverLock.Lock() + engine.server = server + engine.serverLock.Unlock() + err = server.ListenAndServe() return } @@ -571,6 +606,11 @@ func (engine *Engine) RunTLS(addr, certFile, keyFile string) (err error) { Addr: addr, Handler: engine.Handler(), } + + engine.serverLock.Lock() + engine.server = server + engine.serverLock.Unlock() + err = server.ListenAndServeTLS(certFile, keyFile) return } @@ -597,6 +637,11 @@ func (engine *Engine) RunUnix(file string) (err error) { server := &http.Server{ // #nosec G112 Handler: engine.Handler(), } + + engine.serverLock.Lock() + engine.server = server + engine.serverLock.Unlock() + err = server.Serve(listener) return } @@ -654,6 +699,11 @@ func (engine *Engine) RunListener(listener net.Listener) (err error) { server := &http.Server{ // #nosec G112 Handler: engine.Handler(), } + + engine.serverLock.Lock() + engine.server = server + engine.serverLock.Unlock() + err = server.Serve(listener) return } diff --git a/graceful.go b/graceful.go new file mode 100644 index 00000000..4a5d3105 --- /dev/null +++ b/graceful.go @@ -0,0 +1,69 @@ +// Copyright 2014 Manu Martinez-Almeida. All rights reserved. +// Use of this source code is governed by a MIT style +// license that can be found in the LICENSE file. + +package gin + +import ( + "context" + "errors" + "net/http" + "os" + "os/signal" + "syscall" + "time" +) + +// ShutdownConfig holds configuration for graceful shutdown. +type ShutdownConfig struct { + // Timeout is the maximum duration to wait for active connections to finish. + // Default: 10 seconds + Timeout time.Duration + + // Signals are the OS signals that will trigger shutdown. + // Default: SIGINT, SIGTERM + Signals []os.Signal +} + +// RunWithShutdown starts the HTTP server and handles graceful shutdown on SIGINT/SIGTERM. +// It blocks until the server is shut down. +// The timeout parameter specifies the maximum duration to wait for active connections to finish. +func (engine *Engine) RunWithShutdown(addr string, timeout time.Duration) error { + return engine.RunWithShutdownConfig(addr, ShutdownConfig{ + Timeout: timeout, + Signals: []os.Signal{syscall.SIGINT, syscall.SIGTERM}, + }) +} + +// RunWithShutdownConfig starts the HTTP server with custom shutdown configuration. +// It blocks until the server is shut down. +func (engine *Engine) RunWithShutdownConfig(addr string, config ShutdownConfig) error { + if config.Timeout == 0 { + config.Timeout = 10 * time.Second + } + if len(config.Signals) == 0 { + config.Signals = []os.Signal{syscall.SIGINT, syscall.SIGTERM} + } + + ctx, stop := signal.NotifyContext(context.Background(), config.Signals...) + defer stop() + + errCh := make(chan error, 1) + go func() { + if err := engine.Run(addr); err != nil && !errors.Is(err, http.ErrServerClosed) { + errCh <- err + } + close(errCh) + }() + + select { + case err := <-errCh: + return err + case <-ctx.Done(): + } + + shutdownCtx, cancel := context.WithTimeout(context.Background(), config.Timeout) + defer cancel() + + return engine.Shutdown(shutdownCtx) +} diff --git a/graceful_test.go b/graceful_test.go new file mode 100644 index 00000000..e3121782 --- /dev/null +++ b/graceful_test.go @@ -0,0 +1,267 @@ +// Copyright 2014 Manu Martinez-Almeida. All rights reserved. +// Use of this source code is governed by a MIT style +// license that can be found in the LICENSE file. + +package gin + +import ( + "context" + "net" + "net/http" + "os" + "syscall" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEngineShutdown(t *testing.T) { + router := New() + router.GET("/", func(c *Context) { + c.String(http.StatusOK, "ok") + }) + + // Start server in goroutine + go func() { + err := router.Run(":18080") + assert.ErrorIs(t, err, http.ErrServerClosed) + }() + time.Sleep(100 * time.Millisecond) // Wait for server start + + // Verify server is running + resp, err := http.Get("http://localhost:18080/") + require.NoError(t, err) + resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + + // Shutdown + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err = router.Shutdown(ctx) + require.NoError(t, err) + + // Wait a moment for server to fully stop + time.Sleep(50 * time.Millisecond) + + // Verify server is stopped + _, err = http.Get("http://localhost:18080/") + require.Error(t, err) +} + +func TestEngineShutdownBeforeStart(t *testing.T) { + router := New() + + // Shutdown before starting should not error + err := router.Shutdown(context.Background()) + require.NoError(t, err) +} + +func TestEngineShutdownTLS(t *testing.T) { + router := New() + router.GET("/", func(c *Context) { + c.String(http.StatusOK, "ok") + }) + + // Start TLS server in goroutine + go func() { + err := router.RunTLS(":18443", "./testdata/certificate/cert.pem", "./testdata/certificate/key.pem") + assert.ErrorIs(t, err, http.ErrServerClosed) + }() + time.Sleep(100 * time.Millisecond) // Wait for server start + + // Shutdown + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err := router.Shutdown(ctx) + require.NoError(t, err) +} + +func TestEngineShutdownWithActiveRequest(t *testing.T) { + router := New() + + requestStarted := make(chan struct{}) + requestDone := make(chan struct{}) + + router.GET("/slow", func(c *Context) { + close(requestStarted) + time.Sleep(500 * time.Millisecond) // Simulate slow request + c.String(http.StatusOK, "done") + close(requestDone) + }) + + // Start server + go func() { + _ = router.Run(":18081") + }() + time.Sleep(100 * time.Millisecond) + + // Start slow request + go func() { + resp, err := http.Get("http://localhost:18081/slow") + if err == nil { + resp.Body.Close() + } + }() + + // Wait for request to start + <-requestStarted + + // Initiate shutdown while request is in progress + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + shutdownDone := make(chan error, 1) + go func() { + shutdownDone <- router.Shutdown(ctx) + }() + + // Verify request completes before shutdown finishes + select { + case <-requestDone: + // Request completed - this is expected + case err := <-shutdownDone: + t.Errorf("Shutdown completed before request finished: %v", err) + } + + // Wait for shutdown to complete + err := <-shutdownDone + require.NoError(t, err) +} + +func TestRunWithShutdown(t *testing.T) { + router := New() + router.GET("/", func(c *Context) { + c.String(http.StatusOK, "ok") + }) + + errCh := make(chan error, 1) + go func() { + errCh <- router.RunWithShutdown(":18082", 5*time.Second) + }() + + // Wait for server to start + time.Sleep(100 * time.Millisecond) + + // Verify server is running + resp, err := http.Get("http://localhost:18082/") + require.NoError(t, err) + resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + + // Send shutdown signal to self + p, err := os.FindProcess(os.Getpid()) + require.NoError(t, err) + err = p.Signal(syscall.SIGINT) + require.NoError(t, err) + + // Wait for shutdown to complete + select { + case err := <-errCh: + require.NoError(t, err) + case <-time.After(10 * time.Second): + t.Fatal("Shutdown timed out") + } +} + +func TestRunWithShutdownConfig(t *testing.T) { + router := New() + router.GET("/", func(c *Context) { + c.String(http.StatusOK, "ok") + }) + + config := ShutdownConfig{ + Timeout: 5 * time.Second, + Signals: []os.Signal{syscall.SIGUSR1}, + } + + errCh := make(chan error, 1) + go func() { + errCh <- router.RunWithShutdownConfig(":18083", config) + }() + + // Wait for server to start + time.Sleep(100 * time.Millisecond) + + // Verify server is running + resp, err := http.Get("http://localhost:18083/") + require.NoError(t, err) + resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + + // Send custom signal + p, err := os.FindProcess(os.Getpid()) + require.NoError(t, err) + err = p.Signal(syscall.SIGUSR1) + require.NoError(t, err) + + // Wait for shutdown to complete + select { + case err := <-errCh: + require.NoError(t, err) + case <-time.After(10 * time.Second): + t.Fatal("Shutdown timed out") + } +} + +func TestRunWithShutdownConfigDefaults(t *testing.T) { + router := New() + router.GET("/", func(c *Context) { + c.String(http.StatusOK, "ok") + }) + + // Test with zero values to check defaults are applied + config := ShutdownConfig{} + + errCh := make(chan error, 1) + go func() { + errCh <- router.RunWithShutdownConfig(":18084", config) + }() + + // Wait for server to start + time.Sleep(100 * time.Millisecond) + + // Verify server is running + resp, err := http.Get("http://localhost:18084/") + require.NoError(t, err) + resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + + // Send SIGINT (default signal) + p, err := os.FindProcess(os.Getpid()) + require.NoError(t, err) + err = p.Signal(syscall.SIGINT) + require.NoError(t, err) + + // Wait for shutdown to complete + select { + case err := <-errCh: + require.NoError(t, err) + case <-time.After(15 * time.Second): + t.Fatal("Shutdown timed out") + } +} + +func TestRunWithShutdownServerError(t *testing.T) { + router := New() + + // Start a server on the same port first + listener, err := net.Listen("tcp", ":18085") + require.NoError(t, err) + defer listener.Close() + + // Try to run on the same port - should fail + errCh := make(chan error, 1) + go func() { + errCh <- router.RunWithShutdown(":18085", 5*time.Second) + }() + + // Should get an error because port is already in use + select { + case err := <-errCh: + require.Error(t, err) + case <-time.After(2 * time.Second): + t.Fatal("Expected error but timed out") + } +}