Merge 90d541ec284b55d7c86bc7a51b9740a7d1940887 into 8acbe657f1c140e3fba38f869978cab2376500c9

This commit is contained in:
Yashvardhan Kukreja 2024-04-02 17:30:47 +08:00 committed by GitHub
commit 679596bf98
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 172 additions and 22 deletions

View File

@ -121,6 +121,13 @@ func Benchmark404Many(B *testing.B) {
runRequest(B, router, "GET", "/viewfake")
}
func BenchmarkOnRedirect(B *testing.B) {
router := New()
router.GET("/something", func(c *Context) {})
router.OnRedirect(func(c *Context) {})
runRequest(B, router, "GET", "/something/")
}
type mockWriter struct {
headers http.Header
}

35
gin.go
View File

@ -165,8 +165,10 @@ type Engine struct {
FuncMap template.FuncMap
allNoRoute HandlersChain
allNoMethod HandlersChain
allRedirectMethod HandlersChain
noRoute HandlersChain
noMethod HandlersChain
redirectMethod HandlersChain
pool sync.Pool
trees methodTrees
maxParams uint16
@ -300,6 +302,12 @@ func (engine *Engine) NoRoute(handlers ...HandlerFunc) {
engine.rebuild404Handlers()
}
// OnRedirect adds handlers for when redirects are made.
func (engine *Engine) OnRedirect(handlers ...HandlerFunc) {
engine.redirectMethod = handlers
engine.rebuildRedirectHandlers()
}
// NoMethod sets the handlers called when Engine.HandleMethodNotAllowed = true.
func (engine *Engine) NoMethod(handlers ...HandlerFunc) {
engine.noMethod = handlers
@ -313,6 +321,7 @@ func (engine *Engine) Use(middleware ...HandlerFunc) IRoutes {
engine.RouterGroup.Use(middleware...)
engine.rebuild404Handlers()
engine.rebuild405Handlers()
engine.rebuildRedirectHandlers()
return engine
}
@ -333,6 +342,10 @@ func (engine *Engine) rebuild405Handlers() {
engine.allNoMethod = engine.combineHandlers(engine.noMethod)
}
func (engine *Engine) rebuildRedirectHandlers() {
engine.allRedirectMethod = engine.combineHandlers(engine.redirectMethod)
}
func (engine *Engine) addRoute(method, path string, handlers HandlersChain) {
assert1(path[0] == '/', "path must begin with '/'")
assert1(method != "", "HTTP method can not be empty")
@ -635,14 +648,19 @@ func (engine *Engine) handleHTTPRequest(c *Context) {
return
}
if httpMethod != http.MethodConnect && rPath != "/" {
executeRedirectionMiddlewares := len(engine.redirectMethod) > 0
if value.tsr && engine.RedirectTrailingSlash {
redirectTrailingSlash(c)
c.handlers = engine.allRedirectMethod
redirectTrailingSlash(c, executeRedirectionMiddlewares)
return
}
if engine.RedirectFixedPath && redirectFixedPath(c, root, engine.RedirectFixedPath) {
if engine.RedirectFixedPath {
c.handlers = engine.allRedirectMethod
if redirectFixedPath(c, root, engine.RedirectFixedPath, executeRedirectionMiddlewares) {
return
}
}
}
break
}
@ -689,7 +707,7 @@ func serveError(c *Context, code int, defaultMessage []byte) {
c.writermem.WriteHeaderNow()
}
func redirectTrailingSlash(c *Context) {
func redirectTrailingSlash(c *Context, withRedirectionMiddlewares bool) {
req := c.Request
p := req.URL.Path
if prefix := path.Clean(c.Request.Header.Get("X-Forwarded-Prefix")); prefix != "." {
@ -702,22 +720,22 @@ func redirectTrailingSlash(c *Context) {
if length := len(p); length > 1 && p[length-1] == '/' {
req.URL.Path = p[:length-1]
}
redirectRequest(c)
redirectRequest(c, withRedirectionMiddlewares)
}
func redirectFixedPath(c *Context, root *node, trailingSlash bool) bool {
func redirectFixedPath(c *Context, root *node, trailingSlash, withRedirectionMiddlewares bool) bool {
req := c.Request
rPath := req.URL.Path
if fixedPath, ok := root.findCaseInsensitivePath(cleanPath(rPath), trailingSlash); ok {
req.URL.Path = bytesconv.BytesToString(fixedPath)
redirectRequest(c)
redirectRequest(c, withRedirectionMiddlewares)
return true
}
return false
}
func redirectRequest(c *Context) {
func redirectRequest(c *Context, executeRequestChain bool) {
req := c.Request
rPath := req.URL.Path
rURL := req.URL.String()
@ -727,6 +745,9 @@ func redirectRequest(c *Context) {
code = http.StatusTemporaryRedirect
}
debugPrint("redirecting request %d: %s --> %s", code, rPath, rURL)
if executeRequestChain {
c.Next()
}
http.Redirect(c.Writer, req, rURL, code)
c.writermem.WriteHeaderNow()
}

View File

@ -42,6 +42,11 @@ func NoRoute(handlers ...gin.HandlerFunc) {
engine().NoRoute(handlers...)
}
// OnRedirect is a wrapper for Engine.OnRedirect.
func OnRedirect(handlers ...gin.HandlerFunc) {
engine().OnRedirect(handlers...)
}
// NoMethod is a wrapper for Engine.NoMethod.
func NoMethod(handlers ...gin.HandlerFunc) {
engine().NoMethod(handlers...)

View File

@ -433,6 +433,59 @@ func TestNoMethodWithoutGlobalHandlers(t *testing.T) {
compareFunc(t, router.allNoMethod[1], middleware0)
}
func TestOnRedirectWithoutGlobalHandlers(t *testing.T) {
var middleware0 HandlerFunc = func(c *Context) {}
var middleware1 HandlerFunc = func(c *Context) {}
router := New()
router.OnRedirect(middleware0)
assert.Nil(t, router.Handlers)
assert.Len(t, router.redirectMethod, 1)
assert.Len(t, router.allRedirectMethod, 1)
compareFunc(t, router.redirectMethod[0], middleware0)
compareFunc(t, router.allRedirectMethod[0], middleware0)
router.OnRedirect(middleware1, middleware0)
assert.Len(t, router.redirectMethod, 2)
assert.Len(t, router.allRedirectMethod, 2)
compareFunc(t, router.redirectMethod[0], middleware1)
compareFunc(t, router.allRedirectMethod[0], middleware1)
compareFunc(t, router.redirectMethod[1], middleware0)
compareFunc(t, router.allRedirectMethod[1], middleware0)
}
func TestOnRedirectWithGlobalHandlers(t *testing.T) {
var middleware0 HandlerFunc = func(c *Context) {}
var middleware1 HandlerFunc = func(c *Context) {}
var middleware2 HandlerFunc = func(c *Context) {}
router := New()
router.Use(middleware2)
router.OnRedirect(middleware0)
assert.Len(t, router.allRedirectMethod, 2)
assert.Len(t, router.Handlers, 1)
assert.Len(t, router.redirectMethod, 1)
compareFunc(t, router.Handlers[0], middleware2)
compareFunc(t, router.redirectMethod[0], middleware0)
compareFunc(t, router.allRedirectMethod[0], middleware2)
compareFunc(t, router.allRedirectMethod[1], middleware0)
router.Use(middleware1)
assert.Len(t, router.allRedirectMethod, 3)
assert.Len(t, router.Handlers, 2)
assert.Len(t, router.redirectMethod, 1)
compareFunc(t, router.Handlers[0], middleware2)
compareFunc(t, router.Handlers[1], middleware1)
compareFunc(t, router.redirectMethod[0], middleware0)
compareFunc(t, router.allRedirectMethod[0], middleware2)
compareFunc(t, router.allRedirectMethod[1], middleware1)
compareFunc(t, router.allRedirectMethod[2], middleware0)
}
func TestRebuild404Handlers(t *testing.T) {
}

2
go.mod
View File

@ -32,6 +32,8 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
golang.org/x/arch v0.7.0 // indirect
golang.org/x/crypto v0.19.0 // indirect
golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 // indirect
golang.org/x/sys v0.17.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/tools v0.18.0 // indirect
)

15
go.sum
View File

@ -63,16 +63,31 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.7.0 h1:pskyeJh/3AmoQ8CPE95vxHLqp1G1GfGNXTmcl9NEKTc=
golang.org/x/arch v0.7.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 h1:VLliZ0d+/avPrXXH+OakdXhpJuEoBZuwh1m2j7U6Iug=
golang.org/x/lint v0.0.0-20210508222113-6edffad5e616/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.18.0 h1:k8NLag8AGHnn+PHbl7g43CtqZAwG60vZkLqgyZgIHgQ=
golang.org/x/tools v0.18.0/go.mod h1:GL7B4CwcLLeo59yx/9UWWuNOW1n3VZ4f5axWfML7Lcg=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=

View File

@ -34,6 +34,9 @@ func TestMiddlewareGeneralCase(t *testing.T) {
router.NoMethod(func(c *Context) {
signature += " XX "
})
router.OnRedirect(func(c *Context) {
signature += " YYY "
})
// RUN
w := PerformRequest(router, "GET", "/")
@ -78,6 +81,50 @@ func TestMiddlewareNoRoute(t *testing.T) {
assert.Equal(t, "ACEGHFDB", signature)
}
func TestMiddlewareOnRedirect(t *testing.T) {
signature := ""
router := New()
router.Use(func(c *Context) {
signature += "A"
c.Next()
signature += "B"
})
router.Use(func(c *Context) {
signature += "C"
c.Next()
c.Next()
c.Next()
c.Next()
signature += "D"
})
router.NoRoute(func(c *Context) {
signature += "E"
c.Next()
signature += "F"
}, func(c *Context) {
signature += "G"
c.Next()
signature += "H"
})
router.NoMethod(func(c *Context) {
signature += " X "
})
router.OnRedirect(func(c *Context) {
signature += "Y"
})
router.GET("/foo", func(c *Context) {
c.String(200, "Hello, World!")
})
// RUN
w := PerformRequest(router, "GET", "/foo/")
// TEST
assert.Equal(t, http.StatusMovedPermanently, w.Code)
assert.Equal(t, "ACYDB", signature)
}
func TestMiddlewareNoMethodEnabled(t *testing.T) {
signature := ""
router := New()