diff --git a/gin.go b/gin.go index 1633fe13..049d4b87 100644 --- a/gin.go +++ b/gin.go @@ -47,9 +47,6 @@ var regRemoveRepeatedChar = regexp.MustCompile("/{2,}") // HandlerFunc defines the handler used by gin middleware as return value. type HandlerFunc func(*Context) -// OptionFunc defines the function to change the default configuration -type OptionFunc func(*Engine) - // HandlersChain defines a HandlerFunc slice. type HandlersChain []HandlerFunc diff --git a/option.go b/option.go new file mode 100644 index 00000000..0c9bc911 --- /dev/null +++ b/option.go @@ -0,0 +1,113 @@ +package gin + +import "net/http" + +// OptionFunc defines the function to change the default configuration +type OptionFunc func(*Engine) + +// Use attaches a global middleware to the router +func Use(middleware ...HandlerFunc) OptionFunc { + return func(e *Engine) { + e.Use(middleware...) + } +} + +// Get is a shortcut for RouterGroup.Handle("GET", path, handle) +func Get(path string, handlers ...HandlerFunc) OptionFunc { + return func(e *Engine) { + e.GET(path, handlers...) + } +} + +// Post is a shortcut for RouterGroup.Handle("POST", path, handle) +func Post(path string, handlers ...HandlerFunc) OptionFunc { + return func(e *Engine) { + e.POST(path, handlers...) + } +} + +// Put is a shortcut for RouterGroup.Handle("PUT", path, handle) +func Put(path string, handlers ...HandlerFunc) OptionFunc { + return func(e *Engine) { + e.PUT(path, handlers...) + } +} + +// Delete is a shortcut for RouterGroup.Handle("DELETE", path, handle) +func Delete(path string, handlers ...HandlerFunc) OptionFunc { + return func(e *Engine) { + e.DELETE(path, handlers...) + } +} + +// Patch is a shortcut for RouterGroup.Handle("PATCH", path, handle) +func Patch(path string, handlers ...HandlerFunc) OptionFunc { + return func(e *Engine) { + e.PATCH(path, handlers...) + } +} + +// Head is a shortcut for RouterGroup.Handle("HEAD", path, handle) +func Head(path string, handlers ...HandlerFunc) OptionFunc { + return func(e *Engine) { + e.HEAD(path, handlers...) + } +} + +// Options is a shortcut for RouterGroup.Handle("OPTIONS", path, handle) +func Options(path string, handlers ...HandlerFunc) OptionFunc { + return func(e *Engine) { + e.OPTIONS(path, handlers...) + } +} + +// Any is a shortcut for RouterGroup.Handle("GET", path, handle) +func Any(path string, handlers ...HandlerFunc) OptionFunc { + return func(e *Engine) { + e.Any(path, handlers...) + } +} + +// Group is used to create a new router group. You should add all the routes that have common middlewares or the same path prefix +func Group(path string, groupFunc func(*RouterGroup), handlers ...HandlerFunc) OptionFunc { + return func(e *Engine) { + groupFunc( + e.Group(path, handlers...), + ) + } +} + +// Route is a shortcut for RouterGroup.Handle +func Route(httpMethod, relativePath string, handlers ...HandlerFunc) OptionFunc { + return func(e *Engine) { + e.Handle(httpMethod, relativePath, handlers...) + } +} + +// StaticFS returns a middleware that serves static files in the given file system +func StaticFS(path string, fs http.FileSystem) OptionFunc { + return func(e *Engine) { + e.StaticFS(path, fs) + } +} + +// StaticFile returns a middleware that serves a single file +func StaticFile(path, file string) OptionFunc { + return func(e *Engine) { + e.StaticFile(path, file) + } +} + +// Static returns a middleware that serves static files from a directory +func Static(path, root string) OptionFunc { + return func(e *Engine) { + e.Static(path, root) + } +} + +// NoRoute is a global handler for no matching routes +func NoRoute(handlers ...HandlerFunc) OptionFunc { + return func(e *Engine) { + e.NoRoute(handlers...) + } +} diff --git a/option_test.go b/option_test.go new file mode 100644 index 00000000..574a148d --- /dev/null +++ b/option_test.go @@ -0,0 +1,202 @@ +package gin + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestOption_Use(t *testing.T) { + var middleware1 HandlerFunc = func(c *Context) {} + var middleware2 HandlerFunc = func(c *Context) {} + + router := New( + Use(middleware1, middleware2), + ) + + assert.Equal(t, 2, len(router.Handlers)) + compareFunc(t, middleware1, router.Handlers[0]) + compareFunc(t, middleware2, router.Handlers[1]) +} + +func TestOption_HttpMethod(t *testing.T) { + tests := []struct { + method string + path string + optionFunc OptionFunc + want int + }{ + { + method: http.MethodGet, + path: "/get", + optionFunc: Get("/get", func(c *Context) { + assert.Equal(t, http.MethodGet, c.Request.Method) + assert.Equal(t, "/get", c.Request.URL.Path) + }), + }, + { + method: http.MethodPut, + path: "/put", + optionFunc: Put("/put", func(c *Context) { + assert.Equal(t, http.MethodPut, c.Request.Method) + assert.Equal(t, "/put", c.Request.URL.Path) + }), + }, + { + method: http.MethodPost, + path: "/post", + optionFunc: Post("/post", func(c *Context) { + assert.Equal(t, http.MethodPost, c.Request.Method) + assert.Equal(t, "/post", c.Request.URL.Path) + }), + }, + { + method: http.MethodDelete, + path: "/delete", + optionFunc: Delete("/delete", func(c *Context) { + assert.Equal(t, http.MethodDelete, c.Request.Method) + assert.Equal(t, "/delete", c.Request.URL.Path) + }), + }, + { + method: http.MethodPatch, + path: "/patch", + optionFunc: Patch("/patch", func(c *Context) { + assert.Equal(t, http.MethodPatch, c.Request.Method) + assert.Equal(t, "/patch", c.Request.URL.Path) + }), + }, + { + method: http.MethodOptions, + path: "/options", + optionFunc: Options("/options", func(c *Context) { + assert.Equal(t, http.MethodOptions, c.Request.Method) + assert.Equal(t, "/options", c.Request.URL.Path) + }), + }, + { + method: http.MethodHead, + path: "/head", + optionFunc: Head("/head", func(c *Context) { + assert.Equal(t, http.MethodHead, c.Request.Method) + assert.Equal(t, "/head", c.Request.URL.Path) + }), + }, + { + method: "GET", + path: "/any", + optionFunc: Any("/any", func(c *Context) { + assert.Equal(t, http.MethodGet, c.Request.Method) + assert.Equal(t, "/any", c.Request.URL.Path) + }), + }, + } + for _, tt := range tests { + t.Run(tt.method, func(t *testing.T) { + router := New(tt.optionFunc) + w := PerformRequest(router, tt.method, tt.path) + assert.Equal(t, 200, w.Code) + }) + } +} + +func TestOption_Any(t *testing.T) { + method := make(chan string, 1) + router := New( + Any("/any", func(c *Context) { + method <- c.Request.Method + assert.Equal(t, "/any", c.Request.URL.Path) + }), + ) + + tests := []struct { + method string + }{ + {http.MethodGet}, + {http.MethodPost}, + {http.MethodPut}, + {http.MethodPatch}, + {http.MethodDelete}, + {http.MethodHead}, + {http.MethodOptions}, + } + + for _, tt := range tests { + t.Run(tt.method, func(t *testing.T) { + w := PerformRequest(router, tt.method, "/any") + assert.Equal(t, 200, w.Code) + assert.Equal(t, tt.method, <-method) + }) + } +} + +func TestOption_Group(t *testing.T) { + router := New( + Group("/v1", func(group *RouterGroup) { + group.GET("/test", func(c *Context) { + assert.Equal(t, http.MethodGet, c.Request.Method) + assert.Equal(t, "/v1/test", c.Request.URL.Path) + }) + }), + ) + + w := PerformRequest(router, http.MethodGet, "/v1/test") + assert.Equal(t, 200, w.Code) +} + +func TestOption_Route(t *testing.T) { + router := New( + Route(http.MethodGet, "/test", func(c *Context) { + assert.Equal(t, http.MethodGet, c.Request.Method) + assert.Equal(t, "/test", c.Request.URL.Path) + }), + ) + + w := PerformRequest(router, http.MethodGet, "/test") + assert.Equal(t, 200, w.Code) +} + +func TestOption_StaticFS(t *testing.T) { + router := New( + StaticFS("/", http.Dir("./")), + ) + + w := PerformRequest(router, http.MethodGet, "/gin.go") + assert.Equal(t, 200, w.Code) + assert.Contains(t, w.Body.String(), "package gin") +} + +func TestOption_StaticFile(t *testing.T) { + router := New( + StaticFile("/gin.go", "gin.go"), + ) + + w := PerformRequest(router, http.MethodGet, "/gin.go") + assert.Equal(t, 200, w.Code) + assert.Contains(t, w.Body.String(), "package gin") +} + +func TestOption_Static(t *testing.T) { + router := New( + Static("/static", "./"), + ) + + w := PerformRequest(router, http.MethodGet, "/static/gin.go") + assert.Equal(t, 200, w.Code) + assert.Contains(t, w.Body.String(), "package gin") +} + +func TestOption_NoRoute(t *testing.T) { + router := New( + NoRoute(func(c *Context) { + c.String(http.StatusNotFound, "no route") + }), + ) + + assert.Equal(t, 1, len(router.noRoute)) + + w := PerformRequest(router, http.MethodGet, "/no-route") + assert.Equal(t, 404, w.Code) + assert.Equal(t, "no route", w.Body.String()) +}