diff --git a/routergroup.go b/routergroup.go index c833fe8f..42137dfc 100644 --- a/routergroup.go +++ b/routergroup.go @@ -67,6 +67,20 @@ func (group *RouterGroup) Use(middleware ...HandlerFunc) IRoutes { return group.returnObj() } +func (group *RouterGroup) UseWithFilter(httpMethod, prefixRelativePath string, middleware ...HandlerFunc) IRoutes { + wrapper := func(c *Context) { + if c.Request.Method == httpMethod && strings.HasPrefix(c.FullPath(), group.calculateAbsolutePath(prefixRelativePath)) { + for _, m := range middleware { + m(c) + } + } + + c.Next() + } + + return group.Use(wrapper) +} + // Group creates a new router group. You should add all the routes that have common middlewares or the same path prefix. // For example, all the routes that use a common middleware for authorization could be grouped. func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) *RouterGroup { diff --git a/routergroup_test.go b/routergroup_test.go index 6848063e..374f9e9f 100644 --- a/routergroup_test.go +++ b/routergroup_test.go @@ -5,7 +5,9 @@ package gin import ( + "fmt" "net/http" + "net/http/httptest" "testing" "github.com/stretchr/testify/assert" @@ -193,3 +195,39 @@ func testRoutesInterface(t *testing.T, r IRoutes) { assert.Equal(t, r, r.Static("/static", ".")) assert.Equal(t, r, r.StaticFS("/static2", Dir(".", false))) } + +func TestUseWithFilter(t *testing.T) { + router := New() + records := make(map[string]bool) + router.UseWithFilter("GET", "/api/v1", func(c *Context) { + records[fmt.Sprintf("%s,%s", c.Request.Method, c.FullPath())] = true + }) + + assert.Len(t, router.Handlers, 1) + assert.Equal(t, "/", router.BasePath()) + + router.GET("/api/v1/hello", func(c *Context) { + c.Status(http.StatusOK) + }) + + router.GET("/api/v2/hello", func(c *Context) { + c.Status(http.StatusOK) + }) + + router.POST("/api/v1/hello", func(c *Context) { + c.Status(http.StatusOK) + }) + + req, _ := http.NewRequest("GET", "/api/v1/hello", nil) + router.ServeHTTP(httptest.NewRecorder(), req) + + req2, _ := http.NewRequest("GET", "/api/v2/hello", nil) + router.ServeHTTP(httptest.NewRecorder(), req2) + + req3, _ := http.NewRequest("POST", "/api/v2/hello", nil) + router.ServeHTTP(httptest.NewRecorder(), req3) + + assert.Equal(t, records["GET,/api/v1/hello"], true) + assert.Equal(t, records["GET,/api/v2/hello"], false) + assert.Equal(t, records["POST,/api/v1/hello"], false) +}