Merge 90d541ec284b55d7c86bc7a51b9740a7d1940887 into 8acbe657f1c140e3fba38f869978cab2376500c9

This commit is contained in:
Yashvardhan Kukreja 2024-04-02 17:30:47 +08:00 committed by GitHub
commit 679596bf98
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 172 additions and 22 deletions

View File

@ -121,6 +121,13 @@ func Benchmark404Many(B *testing.B) {
runRequest(B, router, "GET", "/viewfake")
}
func BenchmarkOnRedirect(B *testing.B) {
router := New()
router.GET("/something", func(c *Context) {})
router.OnRedirect(func(c *Context) {})
runRequest(B, router, "GET", "/something/")
}
type mockWriter struct {
headers http.Header
}

65
gin.go
View File

@ -159,20 +159,22 @@ type Engine struct {
// ContextWithFallback enable fallback Context.Deadline(), Context.Done(), Context.Err() and Context.Value() when Context.Request.Context() is not nil.
ContextWithFallback bool
delims render.Delims
secureJSONPrefix string
HTMLRender render.HTMLRender
FuncMap template.FuncMap
allNoRoute HandlersChain
allNoMethod HandlersChain
noRoute HandlersChain
noMethod HandlersChain
pool sync.Pool
trees methodTrees
maxParams uint16
maxSections uint16
trustedProxies []string
trustedCIDRs []*net.IPNet
delims render.Delims
secureJSONPrefix string
HTMLRender render.HTMLRender
FuncMap template.FuncMap
allNoRoute HandlersChain
allNoMethod HandlersChain
allRedirectMethod HandlersChain
noRoute HandlersChain
noMethod HandlersChain
redirectMethod HandlersChain
pool sync.Pool
trees methodTrees
maxParams uint16
maxSections uint16
trustedProxies []string
trustedCIDRs []*net.IPNet
}
var _ IRouter = (*Engine)(nil)
@ -300,6 +302,12 @@ func (engine *Engine) NoRoute(handlers ...HandlerFunc) {
engine.rebuild404Handlers()
}
// OnRedirect adds handlers for when redirects are made.
func (engine *Engine) OnRedirect(handlers ...HandlerFunc) {
engine.redirectMethod = handlers
engine.rebuildRedirectHandlers()
}
// NoMethod sets the handlers called when Engine.HandleMethodNotAllowed = true.
func (engine *Engine) NoMethod(handlers ...HandlerFunc) {
engine.noMethod = handlers
@ -313,6 +321,7 @@ func (engine *Engine) Use(middleware ...HandlerFunc) IRoutes {
engine.RouterGroup.Use(middleware...)
engine.rebuild404Handlers()
engine.rebuild405Handlers()
engine.rebuildRedirectHandlers()
return engine
}
@ -333,6 +342,10 @@ func (engine *Engine) rebuild405Handlers() {
engine.allNoMethod = engine.combineHandlers(engine.noMethod)
}
func (engine *Engine) rebuildRedirectHandlers() {
engine.allRedirectMethod = engine.combineHandlers(engine.redirectMethod)
}
func (engine *Engine) addRoute(method, path string, handlers HandlersChain) {
assert1(path[0] == '/', "path must begin with '/'")
assert1(method != "", "HTTP method can not be empty")
@ -635,12 +648,17 @@ func (engine *Engine) handleHTTPRequest(c *Context) {
return
}
if httpMethod != http.MethodConnect && rPath != "/" {
executeRedirectionMiddlewares := len(engine.redirectMethod) > 0
if value.tsr && engine.RedirectTrailingSlash {
redirectTrailingSlash(c)
c.handlers = engine.allRedirectMethod
redirectTrailingSlash(c, executeRedirectionMiddlewares)
return
}
if engine.RedirectFixedPath && redirectFixedPath(c, root, engine.RedirectFixedPath) {
return
if engine.RedirectFixedPath {
c.handlers = engine.allRedirectMethod
if redirectFixedPath(c, root, engine.RedirectFixedPath, executeRedirectionMiddlewares) {
return
}
}
}
break
@ -689,7 +707,7 @@ func serveError(c *Context, code int, defaultMessage []byte) {
c.writermem.WriteHeaderNow()
}
func redirectTrailingSlash(c *Context) {
func redirectTrailingSlash(c *Context, withRedirectionMiddlewares bool) {
req := c.Request
p := req.URL.Path
if prefix := path.Clean(c.Request.Header.Get("X-Forwarded-Prefix")); prefix != "." {
@ -702,22 +720,22 @@ func redirectTrailingSlash(c *Context) {
if length := len(p); length > 1 && p[length-1] == '/' {
req.URL.Path = p[:length-1]
}
redirectRequest(c)
redirectRequest(c, withRedirectionMiddlewares)
}
func redirectFixedPath(c *Context, root *node, trailingSlash bool) bool {
func redirectFixedPath(c *Context, root *node, trailingSlash, withRedirectionMiddlewares bool) bool {
req := c.Request
rPath := req.URL.Path
if fixedPath, ok := root.findCaseInsensitivePath(cleanPath(rPath), trailingSlash); ok {
req.URL.Path = bytesconv.BytesToString(fixedPath)
redirectRequest(c)
redirectRequest(c, withRedirectionMiddlewares)
return true
}
return false
}
func redirectRequest(c *Context) {
func redirectRequest(c *Context, executeRequestChain bool) {
req := c.Request
rPath := req.URL.Path
rURL := req.URL.String()
@ -727,6 +745,9 @@ func redirectRequest(c *Context) {
code = http.StatusTemporaryRedirect
}
debugPrint("redirecting request %d: %s --> %s", code, rPath, rURL)
if executeRequestChain {
c.Next()
}
http.Redirect(c.Writer, req, rURL, code)
c.writermem.WriteHeaderNow()
}

View File

@ -42,6 +42,11 @@ func NoRoute(handlers ...gin.HandlerFunc) {
engine().NoRoute(handlers...)
}
// OnRedirect is a wrapper for Engine.OnRedirect.
func OnRedirect(handlers ...gin.HandlerFunc) {
engine().OnRedirect(handlers...)
}
// NoMethod is a wrapper for Engine.NoMethod.
func NoMethod(handlers ...gin.HandlerFunc) {
engine().NoMethod(handlers...)

View File

@ -433,6 +433,59 @@ func TestNoMethodWithoutGlobalHandlers(t *testing.T) {
compareFunc(t, router.allNoMethod[1], middleware0)
}
func TestOnRedirectWithoutGlobalHandlers(t *testing.T) {
var middleware0 HandlerFunc = func(c *Context) {}
var middleware1 HandlerFunc = func(c *Context) {}
router := New()
router.OnRedirect(middleware0)
assert.Nil(t, router.Handlers)
assert.Len(t, router.redirectMethod, 1)
assert.Len(t, router.allRedirectMethod, 1)
compareFunc(t, router.redirectMethod[0], middleware0)
compareFunc(t, router.allRedirectMethod[0], middleware0)
router.OnRedirect(middleware1, middleware0)
assert.Len(t, router.redirectMethod, 2)
assert.Len(t, router.allRedirectMethod, 2)
compareFunc(t, router.redirectMethod[0], middleware1)
compareFunc(t, router.allRedirectMethod[0], middleware1)
compareFunc(t, router.redirectMethod[1], middleware0)
compareFunc(t, router.allRedirectMethod[1], middleware0)
}
func TestOnRedirectWithGlobalHandlers(t *testing.T) {
var middleware0 HandlerFunc = func(c *Context) {}
var middleware1 HandlerFunc = func(c *Context) {}
var middleware2 HandlerFunc = func(c *Context) {}
router := New()
router.Use(middleware2)
router.OnRedirect(middleware0)
assert.Len(t, router.allRedirectMethod, 2)
assert.Len(t, router.Handlers, 1)
assert.Len(t, router.redirectMethod, 1)
compareFunc(t, router.Handlers[0], middleware2)
compareFunc(t, router.redirectMethod[0], middleware0)
compareFunc(t, router.allRedirectMethod[0], middleware2)
compareFunc(t, router.allRedirectMethod[1], middleware0)
router.Use(middleware1)
assert.Len(t, router.allRedirectMethod, 3)
assert.Len(t, router.Handlers, 2)
assert.Len(t, router.redirectMethod, 1)
compareFunc(t, router.Handlers[0], middleware2)
compareFunc(t, router.Handlers[1], middleware1)
compareFunc(t, router.redirectMethod[0], middleware0)
compareFunc(t, router.allRedirectMethod[0], middleware2)
compareFunc(t, router.allRedirectMethod[1], middleware1)
compareFunc(t, router.allRedirectMethod[2], middleware0)
}
func TestRebuild404Handlers(t *testing.T) {
}

2
go.mod
View File

@ -32,6 +32,8 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
golang.org/x/arch v0.7.0 // indirect
golang.org/x/crypto v0.19.0 // indirect
golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 // indirect
golang.org/x/sys v0.17.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/tools v0.18.0 // indirect
)

15
go.sum
View File

@ -63,16 +63,31 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.7.0 h1:pskyeJh/3AmoQ8CPE95vxHLqp1G1GfGNXTmcl9NEKTc=
golang.org/x/arch v0.7.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 h1:VLliZ0d+/avPrXXH+OakdXhpJuEoBZuwh1m2j7U6Iug=
golang.org/x/lint v0.0.0-20210508222113-6edffad5e616/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.18.0 h1:k8NLag8AGHnn+PHbl7g43CtqZAwG60vZkLqgyZgIHgQ=
golang.org/x/tools v0.18.0/go.mod h1:GL7B4CwcLLeo59yx/9UWWuNOW1n3VZ4f5axWfML7Lcg=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=

View File

@ -34,6 +34,9 @@ func TestMiddlewareGeneralCase(t *testing.T) {
router.NoMethod(func(c *Context) {
signature += " XX "
})
router.OnRedirect(func(c *Context) {
signature += " YYY "
})
// RUN
w := PerformRequest(router, "GET", "/")
@ -78,6 +81,50 @@ func TestMiddlewareNoRoute(t *testing.T) {
assert.Equal(t, "ACEGHFDB", signature)
}
func TestMiddlewareOnRedirect(t *testing.T) {
signature := ""
router := New()
router.Use(func(c *Context) {
signature += "A"
c.Next()
signature += "B"
})
router.Use(func(c *Context) {
signature += "C"
c.Next()
c.Next()
c.Next()
c.Next()
signature += "D"
})
router.NoRoute(func(c *Context) {
signature += "E"
c.Next()
signature += "F"
}, func(c *Context) {
signature += "G"
c.Next()
signature += "H"
})
router.NoMethod(func(c *Context) {
signature += " X "
})
router.OnRedirect(func(c *Context) {
signature += "Y"
})
router.GET("/foo", func(c *Context) {
c.String(200, "Hello, World!")
})
// RUN
w := PerformRequest(router, "GET", "/foo/")
// TEST
assert.Equal(t, http.StatusMovedPermanently, w.Code)
assert.Equal(t, "ACYDB", signature)
}
func TestMiddlewareNoMethodEnabled(t *testing.T) {
signature := ""
router := New()