From 42008a7ca617d390e5dd13f38f091158ce51bf07 Mon Sep 17 00:00:00 2001 From: mehrdadbn9 Date: Fri, 13 Feb 2026 12:15:54 +0330 Subject: [PATCH] fix: concurrent-safe route registration with mutex synchronization - Add sync.RWMutex to Engine struct for thread-safe access - Protect addRoute() with Lock() for exclusive write access - Protect Routes() with RLock() for concurrent read access - Add concurrent test to verify no data race - Add tests for serveError function coverage Fixes #4457 --- gin.go | 8 ++++++++ gin_test.go | 54 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+) diff --git a/gin.go b/gin.go index 2e033bf3..bd8a824c 100644 --- a/gin.go +++ b/gin.go @@ -96,6 +96,9 @@ type Engine struct { // (used for routing HTTP requests) happens only once, even if called multiple times concurrently. routeTreesUpdated sync.Once + // mu protects the route trees during concurrent access. + mu sync.RWMutex + // RedirectTrailingSlash enables automatic redirection if the current route can't be matched but a // handler for the path with (without) the trailing slash exists. // For example if /foo/ is requested but a route only exists for /foo, the @@ -368,6 +371,9 @@ func (engine *Engine) addRoute(method, path string, handlers HandlersChain) { debugPrintRoute(method, path, handlers) + engine.mu.Lock() + defer engine.mu.Unlock() + root := engine.trees.get(method) if root == nil { root = new(node) @@ -388,6 +394,8 @@ func (engine *Engine) addRoute(method, path string, handlers HandlersChain) { // Routes returns a slice of registered routes, including some useful information, such as: // the http method, path, and the handler name. func (engine *Engine) Routes() (routes RoutesInfo) { + engine.mu.RLock() + defer engine.mu.RUnlock() for _, tree := range engine.trees { routes = iterate("", tree.method, routes, tree.root) } diff --git a/gin_test.go b/gin_test.go index 43c9494d..a5ce3cd2 100644 --- a/gin_test.go +++ b/gin_test.go @@ -15,6 +15,7 @@ import ( "reflect" "strconv" "strings" + "sync" "sync/atomic" "testing" "time" @@ -1084,3 +1085,56 @@ func TestUpdateRouteTreesCalledOnce(t *testing.T) { assert.Equal(t, "ok", w.Body.String()) } } + +func TestConcurrentAddRouteAndRoutes(t *testing.T) { + router := New() + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(2) + go func(i int) { + defer wg.Done() + router.GET(fmt.Sprintf("/route%d", i), func(c *Context) { + c.String(http.StatusOK, "ok") + }) + }(i) + go func(i int) { + defer wg.Done() + _ = router.Routes() + }(i) + } + wg.Wait() + assert.Len(t, router.Routes(), 100) +} + +func TestServeErrorWritten(t *testing.T) { + router := New() + router.GET("/", func(c *Context) { + c.Status(http.StatusInternalServerError) + _, _ = c.Writer.Write([]byte("custom error")) + }) + router.NoRoute(func(c *Context) { + c.Status(http.StatusNotFound) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/notfound", nil) + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusNotFound, w.Code) +} + +func TestServeErrorStatusMismatch(t *testing.T) { + router := New() + router.HandleMethodNotAllowed = true + router.NoMethod(func(c *Context) { + c.Status(http.StatusForbidden) + }) + + router.GET("/exists", func(c *Context) { + c.String(http.StatusOK, "ok") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, "/exists", nil) + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusForbidden, w.Code) +}