diff --git a/gin.go b/gin.go index f9813e1d..e64e8938 100644 --- a/gin.go +++ b/gin.go @@ -171,8 +171,10 @@ type Engine struct { FuncMap template.FuncMap allNoRoute HandlersChain allNoMethod HandlersChain + allAutoRedirect HandlersChain noRoute HandlersChain noMethod HandlersChain + autoRedirect HandlersChain pool sync.Pool trees methodTrees maxParams uint16 @@ -325,6 +327,13 @@ func (engine *Engine) NoMethod(handlers ...HandlerFunc) { engine.rebuild405Handlers() } +// AutoRedirect sets the handlers called when auto redirected +// (RedirectTrailingSlash and RedirectFixedPath) +func (engine *Engine) AutoRedirect(handlers ...HandlerFunc) { + engine.autoRedirect = handlers + engine.rebuildAutoRedirectHandlers() +} + // Use attaches a global middleware to the router. i.e. the middleware attached through Use() will be // included in the handlers chain for every single request. Even 404, 405, static files... // For example, this is the right place for a logger or error management middleware. @@ -332,6 +341,7 @@ func (engine *Engine) Use(middleware ...HandlerFunc) IRoutes { engine.RouterGroup.Use(middleware...) engine.rebuild404Handlers() engine.rebuild405Handlers() + engine.rebuildAutoRedirectHandlers() return engine } @@ -352,6 +362,10 @@ func (engine *Engine) rebuild405Handlers() { engine.allNoMethod = engine.combineHandlers(engine.noMethod) } +func (engine *Engine) rebuildAutoRedirectHandlers() { + engine.allAutoRedirect = engine.combineHandlers(engine.autoRedirect) +} + func (engine *Engine) addRoute(method, path string, handlers HandlersChain) { assert1(path[0] == '/', "path must begin with '/'") assert1(method != "", "HTTP method can not be empty") @@ -692,6 +706,7 @@ func (engine *Engine) handleHTTPRequest(c *Context) { return } if httpMethod != http.MethodConnect && rPath != "/" { + c.handlers = engine.allAutoRedirect if value.tsr && engine.RedirectTrailingSlash { redirectTrailingSlash(c) return @@ -776,13 +791,14 @@ func redirectFixedPath(c *Context, root *node, trailingSlash bool) bool { func redirectRequest(c *Context) { req := c.Request - rPath := req.URL.Path - rURL := req.URL.String() - code := http.StatusMovedPermanently // Permanent redirect, request with GET method if req.Method != http.MethodGet { code = http.StatusTemporaryRedirect } + c.Next() + + rPath := req.URL.Path + rURL := req.URL.String() debugPrint("redirecting request %d: %s --> %s", code, rPath, rURL) http.Redirect(c.Writer, req, rURL, code) c.writermem.WriteHeaderNow() diff --git a/gin_test.go b/gin_test.go index 250269e5..443087b9 100644 --- a/gin_test.go +++ b/gin_test.go @@ -578,6 +578,59 @@ func TestNoMethodWithGlobalHandlers(t *testing.T) { compareFunc(t, router.allNoMethod[2], middleware0) } +func TestAutoRedirectWithoutGlobalHandlers(t *testing.T) { + var middleware0 HandlerFunc = func(c *Context) {} + var middleware1 HandlerFunc = func(c *Context) {} + + router := New() + + router.AutoRedirect(middleware0) + assert.Nil(t, router.Handlers) + assert.Len(t, router.autoRedirect, 1) + assert.Len(t, router.allAutoRedirect, 1) + compareFunc(t, router.autoRedirect[0], middleware0) + compareFunc(t, router.allAutoRedirect[0], middleware0) + + router.AutoRedirect(middleware1, middleware0) + assert.Len(t, router.autoRedirect, 2) + assert.Len(t, router.allAutoRedirect, 2) + compareFunc(t, router.autoRedirect[0], middleware1) + compareFunc(t, router.allAutoRedirect[0], middleware1) + compareFunc(t, router.autoRedirect[1], middleware0) + compareFunc(t, router.allAutoRedirect[1], middleware0) +} + +func TestAutoRedirectWithGlobalHandlers(t *testing.T) { + var middleware0 HandlerFunc = func(c *Context) {} + var middleware1 HandlerFunc = func(c *Context) {} + var middleware2 HandlerFunc = func(c *Context) {} + + router := New() + router.Use(middleware2) + + router.AutoRedirect(middleware0) + assert.Len(t, router.allAutoRedirect, 2) + assert.Len(t, router.Handlers, 1) + assert.Len(t, router.autoRedirect, 1) + + compareFunc(t, router.Handlers[0], middleware2) + compareFunc(t, router.autoRedirect[0], middleware0) + compareFunc(t, router.allAutoRedirect[0], middleware2) + compareFunc(t, router.allAutoRedirect[1], middleware0) + + router.Use(middleware1) + assert.Len(t, router.allAutoRedirect, 3) + assert.Len(t, router.Handlers, 2) + assert.Len(t, router.autoRedirect, 1) + + compareFunc(t, router.Handlers[0], middleware2) + compareFunc(t, router.Handlers[1], middleware1) + compareFunc(t, router.autoRedirect[0], middleware0) + compareFunc(t, router.allAutoRedirect[0], middleware2) + compareFunc(t, router.allAutoRedirect[1], middleware1) + compareFunc(t, router.allAutoRedirect[2], middleware0) +} + func compareFunc(t *testing.T, a, b any) { sf1 := reflect.ValueOf(a) sf2 := reflect.ValueOf(b) diff --git a/routes_test.go b/routes_test.go index 995ff51c..72cc3087 100644 --- a/routes_test.go +++ b/routes_test.go @@ -273,6 +273,27 @@ func TestRouteRedirectFixedPath(t *testing.T) { assert.Equal(t, http.StatusTemporaryRedirect, w.Code) } +func TestRouteRedirectWithHandler(t *testing.T) { + router := New() + router.RedirectTrailingSlash = true + router.GET("/path", func(c *Context) {}) + passed := []bool{false, false} + router.Use(func(c *Context) { + passed[0] = true + c.Next() + }) + router.AutoRedirect(func(c *Context) { + passed[1] = true + c.Next() + }) + + w := performRequest(router, http.MethodGet, "/path/") + assert.Equal(t, "/path", w.Header().Get("Location")) + assert.Equal(t, http.StatusMovedPermanently, w.Code) + assert.True(t, passed[0]) + assert.True(t, passed[1]) +} + // TestContextParamsGet tests that a parameter can be parsed from the URL. func TestRouteParamsByName(t *testing.T) { name := ""