From 8763f33c65f7df8be5b9fe7504ab7fcf20abb41d Mon Sep 17 00:00:00 2001 From: bound2 <9380102+bound2@users.noreply.github.com> Date: Thu, 20 Mar 2025 17:40:41 +0200 Subject: [PATCH] fix: prevent middleware re-entry issue in HandleContext (#3987) --- gin.go | 2 ++ gin_test.go | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/gin.go b/gin.go index e17596aa..0761c14d 100644 --- a/gin.go +++ b/gin.go @@ -637,10 +637,12 @@ func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) { // Disclaimer: You can loop yourself to deal with this, use wisely. func (engine *Engine) HandleContext(c *Context) { oldIndexValue := c.index + oldHandlers := c.handlers c.reset() engine.handleHTTPRequest(c) c.index = oldIndexValue + c.handlers = oldHandlers } func (engine *Engine) handleHTTPRequest(c *Context) { diff --git a/gin_test.go b/gin_test.go index 732da18b..850ae09b 100644 --- a/gin_test.go +++ b/gin_test.go @@ -573,6 +573,44 @@ func TestEngineHandleContextManyReEntries(t *testing.T) { assert.Equal(t, int64(expectValue), middlewareCounter) } +func TestEngineHandleContextPreventsMiddlewareReEntry(t *testing.T) { + // given + var handlerCounterV1, handlerCounterV2, middlewareCounterV1 int64 + + r := New() + v1 := r.Group("/v1") + { + v1.Use(func(c *Context) { + atomic.AddInt64(&middlewareCounterV1, 1) + }) + v1.GET("/test", func(c *Context) { + atomic.AddInt64(&handlerCounterV1, 1) + c.Status(http.StatusOK) + }) + } + + v2 := r.Group("/v2") + { + v2.GET("/test", func(c *Context) { + c.Request.URL.Path = "/v1/test" + r.HandleContext(c) + }, func(c *Context) { + atomic.AddInt64(&handlerCounterV2, 1) + }) + } + + // when + responseV1 := PerformRequest(r, "GET", "/v1/test") + responseV2 := PerformRequest(r, "GET", "/v2/test") + + // then + assert.Equal(t, 200, responseV1.Code) + assert.Equal(t, 200, responseV2.Code) + assert.Equal(t, int64(2), handlerCounterV1) + assert.Equal(t, int64(2), middlewareCounterV1) + assert.Equal(t, int64(1), handlerCounterV2) +} + func TestPrepareTrustedCIRDsWith(t *testing.T) { r := New()