feat: support prefixPath and method matching for middleware

This commit is contained in:
x1a0t 2023-11-17 19:39:27 +08:00
parent 44d0dd7092
commit b4018c5d91
No known key found for this signature in database
GPG Key ID: 36A18C0B1CC975D3
2 changed files with 52 additions and 0 deletions

View File

@ -67,6 +67,20 @@ func (group *RouterGroup) Use(middleware ...HandlerFunc) IRoutes {
return group.returnObj()
}
func (group *RouterGroup) UseWithFilter(httpMethod, prefixRelativePath string, middleware ...HandlerFunc) IRoutes {
wrapper := func(c *Context) {
if c.Request.Method == httpMethod && strings.HasPrefix(c.FullPath(), group.calculateAbsolutePath(prefixRelativePath)) {
for _, m := range middleware {
m(c)
}
}
c.Next()
}
return group.Use(wrapper)
}
// Group creates a new router group. You should add all the routes that have common middlewares or the same path prefix.
// For example, all the routes that use a common middleware for authorization could be grouped.
func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) *RouterGroup {

View File

@ -5,7 +5,9 @@
package gin
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
@ -193,3 +195,39 @@ func testRoutesInterface(t *testing.T, r IRoutes) {
assert.Equal(t, r, r.Static("/static", "."))
assert.Equal(t, r, r.StaticFS("/static2", Dir(".", false)))
}
func TestUseWithFilter(t *testing.T) {
router := New()
records := make(map[string]bool)
router.UseWithFilter("GET", "/api/v1", func(c *Context) {
records[fmt.Sprintf("%s,%s", c.Request.Method, c.FullPath())] = true
})
assert.Len(t, router.Handlers, 1)
assert.Equal(t, "/", router.BasePath())
router.GET("/api/v1/hello", func(c *Context) {
c.Status(http.StatusOK)
})
router.GET("/api/v2/hello", func(c *Context) {
c.Status(http.StatusOK)
})
router.POST("/api/v1/hello", func(c *Context) {
c.Status(http.StatusOK)
})
req, _ := http.NewRequest("GET", "/api/v1/hello", nil)
router.ServeHTTP(httptest.NewRecorder(), req)
req2, _ := http.NewRequest("GET", "/api/v2/hello", nil)
router.ServeHTTP(httptest.NewRecorder(), req2)
req3, _ := http.NewRequest("POST", "/api/v2/hello", nil)
router.ServeHTTP(httptest.NewRecorder(), req3)
assert.Equal(t, records["GET,/api/v1/hello"], true)
assert.Equal(t, records["GET,/api/v2/hello"], false)
assert.Equal(t, records["POST,/api/v1/hello"], false)
}