Merge 90d541ec284b55d7c86bc7a51b9740a7d1940887 into 8acbe657f1c140e3fba38f869978cab2376500c9

This commit is contained in:
Yashvardhan Kukreja 2024-04-02 17:30:47 +08:00 committed by GitHub
commit 679596bf98
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 172 additions and 22 deletions

View File

@ -121,6 +121,13 @@ func Benchmark404Many(B *testing.B) {
runRequest(B, router, "GET", "/viewfake") runRequest(B, router, "GET", "/viewfake")
} }
func BenchmarkOnRedirect(B *testing.B) {
router := New()
router.GET("/something", func(c *Context) {})
router.OnRedirect(func(c *Context) {})
runRequest(B, router, "GET", "/something/")
}
type mockWriter struct { type mockWriter struct {
headers http.Header headers http.Header
} }

65
gin.go
View File

@ -159,20 +159,22 @@ type Engine struct {
// ContextWithFallback enable fallback Context.Deadline(), Context.Done(), Context.Err() and Context.Value() when Context.Request.Context() is not nil. // ContextWithFallback enable fallback Context.Deadline(), Context.Done(), Context.Err() and Context.Value() when Context.Request.Context() is not nil.
ContextWithFallback bool ContextWithFallback bool
delims render.Delims delims render.Delims
secureJSONPrefix string secureJSONPrefix string
HTMLRender render.HTMLRender HTMLRender render.HTMLRender
FuncMap template.FuncMap FuncMap template.FuncMap
allNoRoute HandlersChain allNoRoute HandlersChain
allNoMethod HandlersChain allNoMethod HandlersChain
noRoute HandlersChain allRedirectMethod HandlersChain
noMethod HandlersChain noRoute HandlersChain
pool sync.Pool noMethod HandlersChain
trees methodTrees redirectMethod HandlersChain
maxParams uint16 pool sync.Pool
maxSections uint16 trees methodTrees
trustedProxies []string maxParams uint16
trustedCIDRs []*net.IPNet maxSections uint16
trustedProxies []string
trustedCIDRs []*net.IPNet
} }
var _ IRouter = (*Engine)(nil) var _ IRouter = (*Engine)(nil)
@ -300,6 +302,12 @@ func (engine *Engine) NoRoute(handlers ...HandlerFunc) {
engine.rebuild404Handlers() engine.rebuild404Handlers()
} }
// OnRedirect adds handlers for when redirects are made.
func (engine *Engine) OnRedirect(handlers ...HandlerFunc) {
engine.redirectMethod = handlers
engine.rebuildRedirectHandlers()
}
// NoMethod sets the handlers called when Engine.HandleMethodNotAllowed = true. // NoMethod sets the handlers called when Engine.HandleMethodNotAllowed = true.
func (engine *Engine) NoMethod(handlers ...HandlerFunc) { func (engine *Engine) NoMethod(handlers ...HandlerFunc) {
engine.noMethod = handlers engine.noMethod = handlers
@ -313,6 +321,7 @@ func (engine *Engine) Use(middleware ...HandlerFunc) IRoutes {
engine.RouterGroup.Use(middleware...) engine.RouterGroup.Use(middleware...)
engine.rebuild404Handlers() engine.rebuild404Handlers()
engine.rebuild405Handlers() engine.rebuild405Handlers()
engine.rebuildRedirectHandlers()
return engine return engine
} }
@ -333,6 +342,10 @@ func (engine *Engine) rebuild405Handlers() {
engine.allNoMethod = engine.combineHandlers(engine.noMethod) engine.allNoMethod = engine.combineHandlers(engine.noMethod)
} }
func (engine *Engine) rebuildRedirectHandlers() {
engine.allRedirectMethod = engine.combineHandlers(engine.redirectMethod)
}
func (engine *Engine) addRoute(method, path string, handlers HandlersChain) { func (engine *Engine) addRoute(method, path string, handlers HandlersChain) {
assert1(path[0] == '/', "path must begin with '/'") assert1(path[0] == '/', "path must begin with '/'")
assert1(method != "", "HTTP method can not be empty") assert1(method != "", "HTTP method can not be empty")
@ -635,12 +648,17 @@ func (engine *Engine) handleHTTPRequest(c *Context) {
return return
} }
if httpMethod != http.MethodConnect && rPath != "/" { if httpMethod != http.MethodConnect && rPath != "/" {
executeRedirectionMiddlewares := len(engine.redirectMethod) > 0
if value.tsr && engine.RedirectTrailingSlash { if value.tsr && engine.RedirectTrailingSlash {
redirectTrailingSlash(c) c.handlers = engine.allRedirectMethod
redirectTrailingSlash(c, executeRedirectionMiddlewares)
return return
} }
if engine.RedirectFixedPath && redirectFixedPath(c, root, engine.RedirectFixedPath) { if engine.RedirectFixedPath {
return c.handlers = engine.allRedirectMethod
if redirectFixedPath(c, root, engine.RedirectFixedPath, executeRedirectionMiddlewares) {
return
}
} }
} }
break break
@ -689,7 +707,7 @@ func serveError(c *Context, code int, defaultMessage []byte) {
c.writermem.WriteHeaderNow() c.writermem.WriteHeaderNow()
} }
func redirectTrailingSlash(c *Context) { func redirectTrailingSlash(c *Context, withRedirectionMiddlewares bool) {
req := c.Request req := c.Request
p := req.URL.Path p := req.URL.Path
if prefix := path.Clean(c.Request.Header.Get("X-Forwarded-Prefix")); prefix != "." { if prefix := path.Clean(c.Request.Header.Get("X-Forwarded-Prefix")); prefix != "." {
@ -702,22 +720,22 @@ func redirectTrailingSlash(c *Context) {
if length := len(p); length > 1 && p[length-1] == '/' { if length := len(p); length > 1 && p[length-1] == '/' {
req.URL.Path = p[:length-1] req.URL.Path = p[:length-1]
} }
redirectRequest(c) redirectRequest(c, withRedirectionMiddlewares)
} }
func redirectFixedPath(c *Context, root *node, trailingSlash bool) bool { func redirectFixedPath(c *Context, root *node, trailingSlash, withRedirectionMiddlewares bool) bool {
req := c.Request req := c.Request
rPath := req.URL.Path rPath := req.URL.Path
if fixedPath, ok := root.findCaseInsensitivePath(cleanPath(rPath), trailingSlash); ok { if fixedPath, ok := root.findCaseInsensitivePath(cleanPath(rPath), trailingSlash); ok {
req.URL.Path = bytesconv.BytesToString(fixedPath) req.URL.Path = bytesconv.BytesToString(fixedPath)
redirectRequest(c) redirectRequest(c, withRedirectionMiddlewares)
return true return true
} }
return false return false
} }
func redirectRequest(c *Context) { func redirectRequest(c *Context, executeRequestChain bool) {
req := c.Request req := c.Request
rPath := req.URL.Path rPath := req.URL.Path
rURL := req.URL.String() rURL := req.URL.String()
@ -727,6 +745,9 @@ func redirectRequest(c *Context) {
code = http.StatusTemporaryRedirect code = http.StatusTemporaryRedirect
} }
debugPrint("redirecting request %d: %s --> %s", code, rPath, rURL) debugPrint("redirecting request %d: %s --> %s", code, rPath, rURL)
if executeRequestChain {
c.Next()
}
http.Redirect(c.Writer, req, rURL, code) http.Redirect(c.Writer, req, rURL, code)
c.writermem.WriteHeaderNow() c.writermem.WriteHeaderNow()
} }

View File

@ -42,6 +42,11 @@ func NoRoute(handlers ...gin.HandlerFunc) {
engine().NoRoute(handlers...) engine().NoRoute(handlers...)
} }
// OnRedirect is a wrapper for Engine.OnRedirect.
func OnRedirect(handlers ...gin.HandlerFunc) {
engine().OnRedirect(handlers...)
}
// NoMethod is a wrapper for Engine.NoMethod. // NoMethod is a wrapper for Engine.NoMethod.
func NoMethod(handlers ...gin.HandlerFunc) { func NoMethod(handlers ...gin.HandlerFunc) {
engine().NoMethod(handlers...) engine().NoMethod(handlers...)

View File

@ -433,6 +433,59 @@ func TestNoMethodWithoutGlobalHandlers(t *testing.T) {
compareFunc(t, router.allNoMethod[1], middleware0) compareFunc(t, router.allNoMethod[1], middleware0)
} }
func TestOnRedirectWithoutGlobalHandlers(t *testing.T) {
var middleware0 HandlerFunc = func(c *Context) {}
var middleware1 HandlerFunc = func(c *Context) {}
router := New()
router.OnRedirect(middleware0)
assert.Nil(t, router.Handlers)
assert.Len(t, router.redirectMethod, 1)
assert.Len(t, router.allRedirectMethod, 1)
compareFunc(t, router.redirectMethod[0], middleware0)
compareFunc(t, router.allRedirectMethod[0], middleware0)
router.OnRedirect(middleware1, middleware0)
assert.Len(t, router.redirectMethod, 2)
assert.Len(t, router.allRedirectMethod, 2)
compareFunc(t, router.redirectMethod[0], middleware1)
compareFunc(t, router.allRedirectMethod[0], middleware1)
compareFunc(t, router.redirectMethod[1], middleware0)
compareFunc(t, router.allRedirectMethod[1], middleware0)
}
func TestOnRedirectWithGlobalHandlers(t *testing.T) {
var middleware0 HandlerFunc = func(c *Context) {}
var middleware1 HandlerFunc = func(c *Context) {}
var middleware2 HandlerFunc = func(c *Context) {}
router := New()
router.Use(middleware2)
router.OnRedirect(middleware0)
assert.Len(t, router.allRedirectMethod, 2)
assert.Len(t, router.Handlers, 1)
assert.Len(t, router.redirectMethod, 1)
compareFunc(t, router.Handlers[0], middleware2)
compareFunc(t, router.redirectMethod[0], middleware0)
compareFunc(t, router.allRedirectMethod[0], middleware2)
compareFunc(t, router.allRedirectMethod[1], middleware0)
router.Use(middleware1)
assert.Len(t, router.allRedirectMethod, 3)
assert.Len(t, router.Handlers, 2)
assert.Len(t, router.redirectMethod, 1)
compareFunc(t, router.Handlers[0], middleware2)
compareFunc(t, router.Handlers[1], middleware1)
compareFunc(t, router.redirectMethod[0], middleware0)
compareFunc(t, router.allRedirectMethod[0], middleware2)
compareFunc(t, router.allRedirectMethod[1], middleware1)
compareFunc(t, router.allRedirectMethod[2], middleware0)
}
func TestRebuild404Handlers(t *testing.T) { func TestRebuild404Handlers(t *testing.T) {
} }

2
go.mod
View File

@ -32,6 +32,8 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
golang.org/x/arch v0.7.0 // indirect golang.org/x/arch v0.7.0 // indirect
golang.org/x/crypto v0.19.0 // indirect golang.org/x/crypto v0.19.0 // indirect
golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 // indirect
golang.org/x/sys v0.17.0 // indirect golang.org/x/sys v0.17.0 // indirect
golang.org/x/text v0.14.0 // indirect golang.org/x/text v0.14.0 // indirect
golang.org/x/tools v0.18.0 // indirect
) )

15
go.sum
View File

@ -63,16 +63,31 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.7.0 h1:pskyeJh/3AmoQ8CPE95vxHLqp1G1GfGNXTmcl9NEKTc= golang.org/x/arch v0.7.0 h1:pskyeJh/3AmoQ8CPE95vxHLqp1G1GfGNXTmcl9NEKTc=
golang.org/x/arch v0.7.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/arch v0.7.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 h1:VLliZ0d+/avPrXXH+OakdXhpJuEoBZuwh1m2j7U6Iug=
golang.org/x/lint v0.0.0-20210508222113-6edffad5e616/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.18.0 h1:k8NLag8AGHnn+PHbl7g43CtqZAwG60vZkLqgyZgIHgQ=
golang.org/x/tools v0.18.0/go.mod h1:GL7B4CwcLLeo59yx/9UWWuNOW1n3VZ4f5axWfML7Lcg=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=

View File

@ -34,6 +34,9 @@ func TestMiddlewareGeneralCase(t *testing.T) {
router.NoMethod(func(c *Context) { router.NoMethod(func(c *Context) {
signature += " XX " signature += " XX "
}) })
router.OnRedirect(func(c *Context) {
signature += " YYY "
})
// RUN // RUN
w := PerformRequest(router, "GET", "/") w := PerformRequest(router, "GET", "/")
@ -78,6 +81,50 @@ func TestMiddlewareNoRoute(t *testing.T) {
assert.Equal(t, "ACEGHFDB", signature) assert.Equal(t, "ACEGHFDB", signature)
} }
func TestMiddlewareOnRedirect(t *testing.T) {
signature := ""
router := New()
router.Use(func(c *Context) {
signature += "A"
c.Next()
signature += "B"
})
router.Use(func(c *Context) {
signature += "C"
c.Next()
c.Next()
c.Next()
c.Next()
signature += "D"
})
router.NoRoute(func(c *Context) {
signature += "E"
c.Next()
signature += "F"
}, func(c *Context) {
signature += "G"
c.Next()
signature += "H"
})
router.NoMethod(func(c *Context) {
signature += " X "
})
router.OnRedirect(func(c *Context) {
signature += "Y"
})
router.GET("/foo", func(c *Context) {
c.String(200, "Hello, World!")
})
// RUN
w := PerformRequest(router, "GET", "/foo/")
// TEST
assert.Equal(t, http.StatusMovedPermanently, w.Code)
assert.Equal(t, "ACYDB", signature)
}
func TestMiddlewareNoMethodEnabled(t *testing.T) { func TestMiddlewareNoMethodEnabled(t *testing.T) {
signature := "" signature := ""
router := New() router := New()