Merge 96f63d68d3086cd0883f77ac39c5cc73aedd530d into 19f5a13fb421fd71c10a06f244734bba317f63a6

This commit is contained in:
unbyte 2025-05-21 09:11:45 +07:00 committed by GitHub
commit f26ebea157
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 93 additions and 3 deletions

22
gin.go
View File

@ -171,8 +171,10 @@ type Engine struct {
FuncMap template.FuncMap FuncMap template.FuncMap
allNoRoute HandlersChain allNoRoute HandlersChain
allNoMethod HandlersChain allNoMethod HandlersChain
allAutoRedirect HandlersChain
noRoute HandlersChain noRoute HandlersChain
noMethod HandlersChain noMethod HandlersChain
autoRedirect HandlersChain
pool sync.Pool pool sync.Pool
trees methodTrees trees methodTrees
maxParams uint16 maxParams uint16
@ -325,6 +327,13 @@ func (engine *Engine) NoMethod(handlers ...HandlerFunc) {
engine.rebuild405Handlers() engine.rebuild405Handlers()
} }
// AutoRedirect sets the handlers called when auto redirected
// (RedirectTrailingSlash and RedirectFixedPath)
func (engine *Engine) AutoRedirect(handlers ...HandlerFunc) {
engine.autoRedirect = handlers
engine.rebuildAutoRedirectHandlers()
}
// Use attaches a global middleware to the router. i.e. the middleware attached through Use() will be // Use attaches a global middleware to the router. i.e. the middleware attached through Use() will be
// included in the handlers chain for every single request. Even 404, 405, static files... // included in the handlers chain for every single request. Even 404, 405, static files...
// For example, this is the right place for a logger or error management middleware. // For example, this is the right place for a logger or error management middleware.
@ -332,6 +341,7 @@ func (engine *Engine) Use(middleware ...HandlerFunc) IRoutes {
engine.RouterGroup.Use(middleware...) engine.RouterGroup.Use(middleware...)
engine.rebuild404Handlers() engine.rebuild404Handlers()
engine.rebuild405Handlers() engine.rebuild405Handlers()
engine.rebuildAutoRedirectHandlers()
return engine return engine
} }
@ -352,6 +362,10 @@ func (engine *Engine) rebuild405Handlers() {
engine.allNoMethod = engine.combineHandlers(engine.noMethod) engine.allNoMethod = engine.combineHandlers(engine.noMethod)
} }
func (engine *Engine) rebuildAutoRedirectHandlers() {
engine.allAutoRedirect = engine.combineHandlers(engine.autoRedirect)
}
func (engine *Engine) addRoute(method, path string, handlers HandlersChain) { func (engine *Engine) addRoute(method, path string, handlers HandlersChain) {
assert1(path[0] == '/', "path must begin with '/'") assert1(path[0] == '/', "path must begin with '/'")
assert1(method != "", "HTTP method can not be empty") assert1(method != "", "HTTP method can not be empty")
@ -692,6 +706,7 @@ func (engine *Engine) handleHTTPRequest(c *Context) {
return return
} }
if httpMethod != http.MethodConnect && rPath != "/" { if httpMethod != http.MethodConnect && rPath != "/" {
c.handlers = engine.allAutoRedirect
if value.tsr && engine.RedirectTrailingSlash { if value.tsr && engine.RedirectTrailingSlash {
redirectTrailingSlash(c) redirectTrailingSlash(c)
return return
@ -776,13 +791,14 @@ func redirectFixedPath(c *Context, root *node, trailingSlash bool) bool {
func redirectRequest(c *Context) { func redirectRequest(c *Context) {
req := c.Request req := c.Request
rPath := req.URL.Path
rURL := req.URL.String()
code := http.StatusMovedPermanently // Permanent redirect, request with GET method code := http.StatusMovedPermanently // Permanent redirect, request with GET method
if req.Method != http.MethodGet { if req.Method != http.MethodGet {
code = http.StatusTemporaryRedirect code = http.StatusTemporaryRedirect
} }
c.Next()
rPath := req.URL.Path
rURL := req.URL.String()
debugPrint("redirecting request %d: %s --> %s", code, rPath, rURL) debugPrint("redirecting request %d: %s --> %s", code, rPath, rURL)
http.Redirect(c.Writer, req, rURL, code) http.Redirect(c.Writer, req, rURL, code)
c.writermem.WriteHeaderNow() c.writermem.WriteHeaderNow()

View File

@ -578,6 +578,59 @@ func TestNoMethodWithGlobalHandlers(t *testing.T) {
compareFunc(t, router.allNoMethod[2], middleware0) compareFunc(t, router.allNoMethod[2], middleware0)
} }
func TestAutoRedirectWithoutGlobalHandlers(t *testing.T) {
var middleware0 HandlerFunc = func(c *Context) {}
var middleware1 HandlerFunc = func(c *Context) {}
router := New()
router.AutoRedirect(middleware0)
assert.Nil(t, router.Handlers)
assert.Len(t, router.autoRedirect, 1)
assert.Len(t, router.allAutoRedirect, 1)
compareFunc(t, router.autoRedirect[0], middleware0)
compareFunc(t, router.allAutoRedirect[0], middleware0)
router.AutoRedirect(middleware1, middleware0)
assert.Len(t, router.autoRedirect, 2)
assert.Len(t, router.allAutoRedirect, 2)
compareFunc(t, router.autoRedirect[0], middleware1)
compareFunc(t, router.allAutoRedirect[0], middleware1)
compareFunc(t, router.autoRedirect[1], middleware0)
compareFunc(t, router.allAutoRedirect[1], middleware0)
}
func TestAutoRedirectWithGlobalHandlers(t *testing.T) {
var middleware0 HandlerFunc = func(c *Context) {}
var middleware1 HandlerFunc = func(c *Context) {}
var middleware2 HandlerFunc = func(c *Context) {}
router := New()
router.Use(middleware2)
router.AutoRedirect(middleware0)
assert.Len(t, router.allAutoRedirect, 2)
assert.Len(t, router.Handlers, 1)
assert.Len(t, router.autoRedirect, 1)
compareFunc(t, router.Handlers[0], middleware2)
compareFunc(t, router.autoRedirect[0], middleware0)
compareFunc(t, router.allAutoRedirect[0], middleware2)
compareFunc(t, router.allAutoRedirect[1], middleware0)
router.Use(middleware1)
assert.Len(t, router.allAutoRedirect, 3)
assert.Len(t, router.Handlers, 2)
assert.Len(t, router.autoRedirect, 1)
compareFunc(t, router.Handlers[0], middleware2)
compareFunc(t, router.Handlers[1], middleware1)
compareFunc(t, router.autoRedirect[0], middleware0)
compareFunc(t, router.allAutoRedirect[0], middleware2)
compareFunc(t, router.allAutoRedirect[1], middleware1)
compareFunc(t, router.allAutoRedirect[2], middleware0)
}
func compareFunc(t *testing.T, a, b any) { func compareFunc(t *testing.T, a, b any) {
sf1 := reflect.ValueOf(a) sf1 := reflect.ValueOf(a)
sf2 := reflect.ValueOf(b) sf2 := reflect.ValueOf(b)

View File

@ -273,6 +273,27 @@ func TestRouteRedirectFixedPath(t *testing.T) {
assert.Equal(t, http.StatusTemporaryRedirect, w.Code) assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
} }
func TestRouteRedirectWithHandler(t *testing.T) {
router := New()
router.RedirectTrailingSlash = true
router.GET("/path", func(c *Context) {})
passed := []bool{false, false}
router.Use(func(c *Context) {
passed[0] = true
c.Next()
})
router.AutoRedirect(func(c *Context) {
passed[1] = true
c.Next()
})
w := performRequest(router, http.MethodGet, "/path/")
assert.Equal(t, "/path", w.Header().Get("Location"))
assert.Equal(t, http.StatusMovedPermanently, w.Code)
assert.True(t, passed[0])
assert.True(t, passed[1])
}
// TestContextParamsGet tests that a parameter can be parsed from the URL. // TestContextParamsGet tests that a parameter can be parsed from the URL.
func TestRouteParamsByName(t *testing.T) { func TestRouteParamsByName(t *testing.T) {
name := "" name := ""