mirror of
https://github.com/gin-gonic/gin.git
synced 2026-06-05 10:28:21 +08:00
fix: concurrent-safe route registration with mutex synchronization
- Add sync.RWMutex to Engine struct for thread-safe access - Protect addRoute() with Lock() for exclusive write access - Protect Routes() with RLock() for concurrent read access - Add concurrent test to verify no data race - Add tests for serveError function coverage Fixes #4457
This commit is contained in:
parent
f5c267d2f8
commit
42008a7ca6
8
gin.go
8
gin.go
@ -96,6 +96,9 @@ type Engine struct {
|
|||||||
// (used for routing HTTP requests) happens only once, even if called multiple times concurrently.
|
// (used for routing HTTP requests) happens only once, even if called multiple times concurrently.
|
||||||
routeTreesUpdated sync.Once
|
routeTreesUpdated sync.Once
|
||||||
|
|
||||||
|
// mu protects the route trees during concurrent access.
|
||||||
|
mu sync.RWMutex
|
||||||
|
|
||||||
// RedirectTrailingSlash enables automatic redirection if the current route can't be matched but a
|
// RedirectTrailingSlash enables automatic redirection if the current route can't be matched but a
|
||||||
// handler for the path with (without) the trailing slash exists.
|
// handler for the path with (without) the trailing slash exists.
|
||||||
// For example if /foo/ is requested but a route only exists for /foo, the
|
// For example if /foo/ is requested but a route only exists for /foo, the
|
||||||
@ -368,6 +371,9 @@ func (engine *Engine) addRoute(method, path string, handlers HandlersChain) {
|
|||||||
|
|
||||||
debugPrintRoute(method, path, handlers)
|
debugPrintRoute(method, path, handlers)
|
||||||
|
|
||||||
|
engine.mu.Lock()
|
||||||
|
defer engine.mu.Unlock()
|
||||||
|
|
||||||
root := engine.trees.get(method)
|
root := engine.trees.get(method)
|
||||||
if root == nil {
|
if root == nil {
|
||||||
root = new(node)
|
root = new(node)
|
||||||
@ -388,6 +394,8 @@ func (engine *Engine) addRoute(method, path string, handlers HandlersChain) {
|
|||||||
// Routes returns a slice of registered routes, including some useful information, such as:
|
// Routes returns a slice of registered routes, including some useful information, such as:
|
||||||
// the http method, path, and the handler name.
|
// the http method, path, and the handler name.
|
||||||
func (engine *Engine) Routes() (routes RoutesInfo) {
|
func (engine *Engine) Routes() (routes RoutesInfo) {
|
||||||
|
engine.mu.RLock()
|
||||||
|
defer engine.mu.RUnlock()
|
||||||
for _, tree := range engine.trees {
|
for _, tree := range engine.trees {
|
||||||
routes = iterate("", tree.method, routes, tree.root)
|
routes = iterate("", tree.method, routes, tree.root)
|
||||||
}
|
}
|
||||||
|
|||||||
54
gin_test.go
54
gin_test.go
@ -15,6 +15,7 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -1084,3 +1085,56 @@ func TestUpdateRouteTreesCalledOnce(t *testing.T) {
|
|||||||
assert.Equal(t, "ok", w.Body.String())
|
assert.Equal(t, "ok", w.Body.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConcurrentAddRouteAndRoutes(t *testing.T) {
|
||||||
|
router := New()
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
wg.Add(2)
|
||||||
|
go func(i int) {
|
||||||
|
defer wg.Done()
|
||||||
|
router.GET(fmt.Sprintf("/route%d", i), func(c *Context) {
|
||||||
|
c.String(http.StatusOK, "ok")
|
||||||
|
})
|
||||||
|
}(i)
|
||||||
|
go func(i int) {
|
||||||
|
defer wg.Done()
|
||||||
|
_ = router.Routes()
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
assert.Len(t, router.Routes(), 100)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServeErrorWritten(t *testing.T) {
|
||||||
|
router := New()
|
||||||
|
router.GET("/", func(c *Context) {
|
||||||
|
c.Status(http.StatusInternalServerError)
|
||||||
|
_, _ = c.Writer.Write([]byte("custom error"))
|
||||||
|
})
|
||||||
|
router.NoRoute(func(c *Context) {
|
||||||
|
c.Status(http.StatusNotFound)
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/notfound", nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServeErrorStatusMismatch(t *testing.T) {
|
||||||
|
router := New()
|
||||||
|
router.HandleMethodNotAllowed = true
|
||||||
|
router.NoMethod(func(c *Context) {
|
||||||
|
c.Status(http.StatusForbidden)
|
||||||
|
})
|
||||||
|
|
||||||
|
router.GET("/exists", func(c *Context) {
|
||||||
|
c.String(http.StatusOK, "ok")
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/exists", nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
assert.Equal(t, http.StatusForbidden, w.Code)
|
||||||
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user