mirror of
https://github.com/gin-gonic/gin.git
synced 2026-06-06 03:08:11 +08:00
Adds a new Engine option that, when enabled, prevents global middleware registered via Use() from being executed for 405 Method Not Allowed responses. This allows NoMethod handlers to run without triggering middleware that may reject the request (e.g., authentication, checksum validation) before the 405 response can be sent. Fixes gin-gonic/gin#4189
1218 lines
29 KiB
Go
1218 lines
29 KiB
Go
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
|
|
// Use of this source code is governed by a MIT style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package gin
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"fmt"
|
|
"html/template"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"golang.org/x/net/http2"
|
|
)
|
|
|
|
// formatAsDate renders t as "YYYY/MM/DD". The year is printed as a plain
// integer (no padding) while month and day always use two digits. It is
// registered as the "formatAsDate" template function in setupHTMLFiles.
func formatAsDate(t time.Time) string {
	y, m, d := t.Date()
	return fmt.Sprintf("%d/%02d/%02d", y, int(m), d)
}
|
|
|
|
func setupHTMLFiles(t *testing.T, mode string, tls bool, loadMethod func(*Engine)) *httptest.Server {
|
|
SetMode(mode)
|
|
defer SetMode(TestMode)
|
|
|
|
var router *Engine
|
|
captureOutput(t, func() {
|
|
router = New()
|
|
router.Delims("{[{", "}]}")
|
|
router.SetFuncMap(template.FuncMap{
|
|
"formatAsDate": formatAsDate,
|
|
})
|
|
loadMethod(router)
|
|
router.GET("/test", func(c *Context) {
|
|
c.HTML(http.StatusOK, "hello.tmpl", map[string]string{"name": "world"})
|
|
})
|
|
router.GET("/raw", func(c *Context) {
|
|
c.HTML(http.StatusOK, "raw.tmpl", map[string]any{
|
|
"now": time.Date(2017, 07, 01, 0, 0, 0, 0, time.UTC), //nolint:gofumpt
|
|
})
|
|
})
|
|
})
|
|
|
|
var ts *httptest.Server
|
|
|
|
if tls {
|
|
ts = httptest.NewTLSServer(router)
|
|
} else {
|
|
ts = httptest.NewServer(router)
|
|
}
|
|
|
|
return ts
|
|
}
|
|
|
|
func TestLoadHTMLGlobDebugMode(t *testing.T) {
|
|
ts := setupHTMLFiles(
|
|
t,
|
|
DebugMode,
|
|
false,
|
|
func(router *Engine) {
|
|
router.LoadHTMLGlob("./testdata/template/*")
|
|
},
|
|
)
|
|
defer ts.Close()
|
|
|
|
res, err := http.Get(ts.URL + "/test")
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
resp, _ := io.ReadAll(res.Body)
|
|
assert.Equal(t, "<h1>Hello world</h1>", string(resp))
|
|
}
|
|
|
|
func TestH2c(t *testing.T) {
|
|
ln, err := net.Listen("tcp", localhostIP+":0")
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
r := Default()
|
|
r.UseH2C = true
|
|
r.GET("/", func(c *Context) {
|
|
c.String(200, "<h1>Hello world</h1>")
|
|
})
|
|
go func() {
|
|
err := http.Serve(ln, r.Handler())
|
|
if err != nil {
|
|
t.Log(err)
|
|
}
|
|
}()
|
|
defer ln.Close()
|
|
|
|
url := "http://" + ln.Addr().String() + "/"
|
|
|
|
httpClient := http.Client{
|
|
Transport: &http2.Transport{
|
|
AllowHTTP: true,
|
|
DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) {
|
|
return net.Dial(netw, addr)
|
|
},
|
|
},
|
|
}
|
|
|
|
res, err := httpClient.Get(url)
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
resp, _ := io.ReadAll(res.Body)
|
|
assert.Equal(t, "<h1>Hello world</h1>", string(resp))
|
|
}
|
|
|
|
func TestLoadHTMLGlobTestMode(t *testing.T) {
|
|
ts := setupHTMLFiles(
|
|
t,
|
|
TestMode,
|
|
false,
|
|
func(router *Engine) {
|
|
router.LoadHTMLGlob("./testdata/template/*")
|
|
},
|
|
)
|
|
defer ts.Close()
|
|
|
|
res, err := http.Get(ts.URL + "/test")
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
resp, _ := io.ReadAll(res.Body)
|
|
assert.Equal(t, "<h1>Hello world</h1>", string(resp))
|
|
}
|
|
|
|
func TestLoadHTMLGlobReleaseMode(t *testing.T) {
|
|
ts := setupHTMLFiles(
|
|
t,
|
|
ReleaseMode,
|
|
false,
|
|
func(router *Engine) {
|
|
router.LoadHTMLGlob("./testdata/template/*")
|
|
},
|
|
)
|
|
defer ts.Close()
|
|
|
|
res, err := http.Get(ts.URL + "/test")
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
resp, _ := io.ReadAll(res.Body)
|
|
assert.Equal(t, "<h1>Hello world</h1>", string(resp))
|
|
}
|
|
|
|
func TestLoadHTMLGlobUsingTLS(t *testing.T) {
|
|
ts := setupHTMLFiles(
|
|
t,
|
|
DebugMode,
|
|
true,
|
|
func(router *Engine) {
|
|
router.LoadHTMLGlob("./testdata/template/*")
|
|
},
|
|
)
|
|
defer ts.Close()
|
|
|
|
// Use InsecureSkipVerify for avoiding `x509: certificate signed by unknown authority` error
|
|
tr := &http.Transport{
|
|
TLSClientConfig: &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
},
|
|
}
|
|
client := &http.Client{Transport: tr}
|
|
res, err := client.Get(ts.URL + "/test")
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
resp, _ := io.ReadAll(res.Body)
|
|
assert.Equal(t, "<h1>Hello world</h1>", string(resp))
|
|
}
|
|
|
|
func TestLoadHTMLGlobFromFuncMap(t *testing.T) {
|
|
ts := setupHTMLFiles(
|
|
t,
|
|
DebugMode,
|
|
false,
|
|
func(router *Engine) {
|
|
router.LoadHTMLGlob("./testdata/template/*")
|
|
},
|
|
)
|
|
defer ts.Close()
|
|
|
|
res, err := http.Get(ts.URL + "/raw")
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
resp, _ := io.ReadAll(res.Body)
|
|
assert.Equal(t, "Date: 2017/07/01", string(resp))
|
|
}
|
|
|
|
func init() {
|
|
SetMode(TestMode)
|
|
}
|
|
|
|
func TestCreateEngine(t *testing.T) {
|
|
router := New()
|
|
assert.Equal(t, "/", router.basePath)
|
|
assert.Equal(t, router.engine, router)
|
|
assert.Empty(t, router.Handlers)
|
|
}
|
|
|
|
func TestLoadHTMLFilesTestMode(t *testing.T) {
|
|
ts := setupHTMLFiles(
|
|
t,
|
|
TestMode,
|
|
false,
|
|
func(router *Engine) {
|
|
router.LoadHTMLFiles("./testdata/template/hello.tmpl", "./testdata/template/raw.tmpl")
|
|
},
|
|
)
|
|
defer ts.Close()
|
|
|
|
res, err := http.Get(ts.URL + "/test")
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
resp, _ := io.ReadAll(res.Body)
|
|
assert.Equal(t, "<h1>Hello world</h1>", string(resp))
|
|
}
|
|
|
|
func TestLoadHTMLFilesDebugMode(t *testing.T) {
|
|
ts := setupHTMLFiles(
|
|
t,
|
|
DebugMode,
|
|
false,
|
|
func(router *Engine) {
|
|
router.LoadHTMLFiles("./testdata/template/hello.tmpl", "./testdata/template/raw.tmpl")
|
|
},
|
|
)
|
|
defer ts.Close()
|
|
|
|
res, err := http.Get(ts.URL + "/test")
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
resp, _ := io.ReadAll(res.Body)
|
|
assert.Equal(t, "<h1>Hello world</h1>", string(resp))
|
|
}
|
|
|
|
func TestLoadHTMLFilesReleaseMode(t *testing.T) {
|
|
ts := setupHTMLFiles(
|
|
t,
|
|
ReleaseMode,
|
|
false,
|
|
func(router *Engine) {
|
|
router.LoadHTMLFiles("./testdata/template/hello.tmpl", "./testdata/template/raw.tmpl")
|
|
},
|
|
)
|
|
defer ts.Close()
|
|
|
|
res, err := http.Get(ts.URL + "/test")
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
resp, _ := io.ReadAll(res.Body)
|
|
assert.Equal(t, "<h1>Hello world</h1>", string(resp))
|
|
}
|
|
|
|
func TestLoadHTMLFilesUsingTLS(t *testing.T) {
|
|
ts := setupHTMLFiles(
|
|
t,
|
|
TestMode,
|
|
true,
|
|
func(router *Engine) {
|
|
router.LoadHTMLFiles("./testdata/template/hello.tmpl", "./testdata/template/raw.tmpl")
|
|
},
|
|
)
|
|
defer ts.Close()
|
|
|
|
// Use InsecureSkipVerify for avoiding `x509: certificate signed by unknown authority` error
|
|
tr := &http.Transport{
|
|
TLSClientConfig: &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
},
|
|
}
|
|
client := &http.Client{Transport: tr}
|
|
res, err := client.Get(ts.URL + "/test")
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
resp, _ := io.ReadAll(res.Body)
|
|
assert.Equal(t, "<h1>Hello world</h1>", string(resp))
|
|
}
|
|
|
|
func TestLoadHTMLFilesFuncMap(t *testing.T) {
|
|
ts := setupHTMLFiles(
|
|
t,
|
|
TestMode,
|
|
false,
|
|
func(router *Engine) {
|
|
router.LoadHTMLFiles("./testdata/template/hello.tmpl", "./testdata/template/raw.tmpl")
|
|
},
|
|
)
|
|
defer ts.Close()
|
|
|
|
res, err := http.Get(ts.URL + "/raw")
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
resp, _ := io.ReadAll(res.Body)
|
|
assert.Equal(t, "Date: 2017/07/01", string(resp))
|
|
}
|
|
|
|
// tmplFS serves the template fixtures from testdata/template for the
// LoadHTMLFS tests.
var tmplFS http.Dir = "testdata/template"
|
|
|
|
func TestLoadHTMLFSTestMode(t *testing.T) {
|
|
ts := setupHTMLFiles(
|
|
t,
|
|
TestMode,
|
|
false,
|
|
func(router *Engine) {
|
|
router.LoadHTMLFS(tmplFS, "hello.tmpl", "raw.tmpl")
|
|
},
|
|
)
|
|
defer ts.Close()
|
|
|
|
res, err := http.Get(ts.URL + "/test")
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
resp, _ := io.ReadAll(res.Body)
|
|
assert.Equal(t, "<h1>Hello world</h1>", string(resp))
|
|
}
|
|
|
|
func TestLoadHTMLFSDebugMode(t *testing.T) {
|
|
ts := setupHTMLFiles(
|
|
t,
|
|
DebugMode,
|
|
false,
|
|
func(router *Engine) {
|
|
router.LoadHTMLFS(tmplFS, "hello.tmpl", "raw.tmpl")
|
|
},
|
|
)
|
|
defer ts.Close()
|
|
|
|
res, err := http.Get(ts.URL + "/test")
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
resp, _ := io.ReadAll(res.Body)
|
|
assert.Equal(t, "<h1>Hello world</h1>", string(resp))
|
|
}
|
|
|
|
func TestLoadHTMLFSReleaseMode(t *testing.T) {
|
|
ts := setupHTMLFiles(
|
|
t,
|
|
ReleaseMode,
|
|
false,
|
|
func(router *Engine) {
|
|
router.LoadHTMLFS(tmplFS, "hello.tmpl", "raw.tmpl")
|
|
},
|
|
)
|
|
defer ts.Close()
|
|
|
|
res, err := http.Get(ts.URL + "/test")
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
resp, _ := io.ReadAll(res.Body)
|
|
assert.Equal(t, "<h1>Hello world</h1>", string(resp))
|
|
}
|
|
|
|
func TestLoadHTMLFSUsingTLS(t *testing.T) {
|
|
ts := setupHTMLFiles(
|
|
t,
|
|
TestMode,
|
|
true,
|
|
func(router *Engine) {
|
|
router.LoadHTMLFS(tmplFS, "hello.tmpl", "raw.tmpl")
|
|
},
|
|
)
|
|
defer ts.Close()
|
|
|
|
// Use InsecureSkipVerify for avoiding `x509: certificate signed by unknown authority` error
|
|
tr := &http.Transport{
|
|
TLSClientConfig: &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
},
|
|
}
|
|
client := &http.Client{Transport: tr}
|
|
res, err := client.Get(ts.URL + "/test")
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
resp, _ := io.ReadAll(res.Body)
|
|
assert.Equal(t, "<h1>Hello world</h1>", string(resp))
|
|
}
|
|
|
|
func TestLoadHTMLFSFuncMap(t *testing.T) {
|
|
ts := setupHTMLFiles(
|
|
t,
|
|
TestMode,
|
|
false,
|
|
func(router *Engine) {
|
|
router.LoadHTMLFS(tmplFS, "hello.tmpl", "raw.tmpl")
|
|
},
|
|
)
|
|
defer ts.Close()
|
|
|
|
res, err := http.Get(ts.URL + "/raw")
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
resp, _ := io.ReadAll(res.Body)
|
|
assert.Equal(t, "Date: 2017/07/01", string(resp))
|
|
}
|
|
|
|
func TestAddRoute(t *testing.T) {
|
|
router := New()
|
|
router.addRoute(http.MethodGet, "/", HandlersChain{func(_ *Context) {}})
|
|
|
|
assert.Len(t, router.trees, 1)
|
|
assert.NotNil(t, router.trees.get(http.MethodGet))
|
|
assert.Nil(t, router.trees.get(http.MethodPost))
|
|
|
|
router.addRoute(http.MethodPost, "/", HandlersChain{func(_ *Context) {}})
|
|
|
|
assert.Len(t, router.trees, 2)
|
|
assert.NotNil(t, router.trees.get(http.MethodGet))
|
|
assert.NotNil(t, router.trees.get(http.MethodPost))
|
|
|
|
router.addRoute(http.MethodPost, "/post", HandlersChain{func(_ *Context) {}})
|
|
assert.Len(t, router.trees, 2)
|
|
}
|
|
|
|
func TestAddRouteFails(t *testing.T) {
|
|
router := New()
|
|
assert.Panics(t, func() { router.addRoute("", "/", HandlersChain{func(_ *Context) {}}) })
|
|
assert.Panics(t, func() { router.addRoute(http.MethodGet, "a", HandlersChain{func(_ *Context) {}}) })
|
|
assert.Panics(t, func() { router.addRoute(http.MethodGet, "/", HandlersChain{}) })
|
|
|
|
router.addRoute(http.MethodPost, "/post", HandlersChain{func(_ *Context) {}})
|
|
assert.Panics(t, func() {
|
|
router.addRoute(http.MethodPost, "/post", HandlersChain{func(_ *Context) {}})
|
|
})
|
|
}
|
|
|
|
func TestCreateDefaultRouter(t *testing.T) {
|
|
router := Default()
|
|
assert.Len(t, router.Handlers, 2)
|
|
}
|
|
|
|
func TestNoRouteWithoutGlobalHandlers(t *testing.T) {
|
|
var middleware0 HandlerFunc = func(c *Context) {}
|
|
var middleware1 HandlerFunc = func(c *Context) {}
|
|
|
|
router := New()
|
|
|
|
router.NoRoute(middleware0)
|
|
assert.Nil(t, router.Handlers)
|
|
assert.Len(t, router.noRoute, 1)
|
|
assert.Len(t, router.allNoRoute, 1)
|
|
compareFunc(t, router.noRoute[0], middleware0)
|
|
compareFunc(t, router.allNoRoute[0], middleware0)
|
|
|
|
router.NoRoute(middleware1, middleware0)
|
|
assert.Len(t, router.noRoute, 2)
|
|
assert.Len(t, router.allNoRoute, 2)
|
|
compareFunc(t, router.noRoute[0], middleware1)
|
|
compareFunc(t, router.allNoRoute[0], middleware1)
|
|
compareFunc(t, router.noRoute[1], middleware0)
|
|
compareFunc(t, router.allNoRoute[1], middleware0)
|
|
}
|
|
|
|
func TestNoRouteWithGlobalHandlers(t *testing.T) {
|
|
var middleware0 HandlerFunc = func(c *Context) {}
|
|
var middleware1 HandlerFunc = func(c *Context) {}
|
|
var middleware2 HandlerFunc = func(c *Context) {}
|
|
|
|
router := New()
|
|
router.Use(middleware2)
|
|
|
|
router.NoRoute(middleware0)
|
|
assert.Len(t, router.allNoRoute, 2)
|
|
assert.Len(t, router.Handlers, 1)
|
|
assert.Len(t, router.noRoute, 1)
|
|
|
|
compareFunc(t, router.Handlers[0], middleware2)
|
|
compareFunc(t, router.noRoute[0], middleware0)
|
|
compareFunc(t, router.allNoRoute[0], middleware2)
|
|
compareFunc(t, router.allNoRoute[1], middleware0)
|
|
|
|
router.Use(middleware1)
|
|
assert.Len(t, router.allNoRoute, 3)
|
|
assert.Len(t, router.Handlers, 2)
|
|
assert.Len(t, router.noRoute, 1)
|
|
|
|
compareFunc(t, router.Handlers[0], middleware2)
|
|
compareFunc(t, router.Handlers[1], middleware1)
|
|
compareFunc(t, router.noRoute[0], middleware0)
|
|
compareFunc(t, router.allNoRoute[0], middleware2)
|
|
compareFunc(t, router.allNoRoute[1], middleware1)
|
|
compareFunc(t, router.allNoRoute[2], middleware0)
|
|
}
|
|
|
|
func TestNoMethodWithoutGlobalHandlers(t *testing.T) {
|
|
var middleware0 HandlerFunc = func(c *Context) {}
|
|
var middleware1 HandlerFunc = func(c *Context) {}
|
|
|
|
router := New()
|
|
|
|
router.NoMethod(middleware0)
|
|
assert.Empty(t, router.Handlers)
|
|
assert.Len(t, router.noMethod, 1)
|
|
assert.Len(t, router.allNoMethod, 1)
|
|
compareFunc(t, router.noMethod[0], middleware0)
|
|
compareFunc(t, router.allNoMethod[0], middleware0)
|
|
|
|
router.NoMethod(middleware1, middleware0)
|
|
assert.Len(t, router.noMethod, 2)
|
|
assert.Len(t, router.allNoMethod, 2)
|
|
compareFunc(t, router.noMethod[0], middleware1)
|
|
compareFunc(t, router.allNoMethod[0], middleware1)
|
|
compareFunc(t, router.noMethod[1], middleware0)
|
|
compareFunc(t, router.allNoMethod[1], middleware0)
|
|
}
|
|
|
|
func TestRebuild404Handlers(t *testing.T) {
|
|
var middleware0 HandlerFunc = func(c *Context) {}
|
|
var middleware1 HandlerFunc = func(c *Context) {}
|
|
|
|
router := New()
|
|
|
|
// Initially, allNoRoute should be nil
|
|
assert.Nil(t, router.allNoRoute)
|
|
|
|
// Set NoRoute handlers
|
|
router.NoRoute(middleware0)
|
|
assert.Len(t, router.allNoRoute, 1)
|
|
assert.Len(t, router.noRoute, 1)
|
|
compareFunc(t, router.allNoRoute[0], middleware0)
|
|
|
|
// Add Use middleware should trigger rebuild404Handlers
|
|
router.Use(middleware1)
|
|
assert.Len(t, router.allNoRoute, 2)
|
|
assert.Len(t, router.Handlers, 1)
|
|
assert.Len(t, router.noRoute, 1)
|
|
|
|
// Global middleware should come first
|
|
compareFunc(t, router.allNoRoute[0], middleware1)
|
|
compareFunc(t, router.allNoRoute[1], middleware0)
|
|
}
|
|
|
|
func TestNoMethodWithGlobalHandlers(t *testing.T) {
|
|
var middleware0 HandlerFunc = func(c *Context) {}
|
|
var middleware1 HandlerFunc = func(c *Context) {}
|
|
var middleware2 HandlerFunc = func(c *Context) {}
|
|
|
|
router := New()
|
|
router.Use(middleware2)
|
|
|
|
router.NoMethod(middleware0)
|
|
assert.Len(t, router.allNoMethod, 2)
|
|
assert.Len(t, router.Handlers, 1)
|
|
assert.Len(t, router.noMethod, 1)
|
|
|
|
compareFunc(t, router.Handlers[0], middleware2)
|
|
compareFunc(t, router.noMethod[0], middleware0)
|
|
compareFunc(t, router.allNoMethod[0], middleware2)
|
|
compareFunc(t, router.allNoMethod[1], middleware0)
|
|
|
|
router.Use(middleware1)
|
|
assert.Len(t, router.allNoMethod, 3)
|
|
assert.Len(t, router.Handlers, 2)
|
|
assert.Len(t, router.noMethod, 1)
|
|
|
|
compareFunc(t, router.Handlers[0], middleware2)
|
|
compareFunc(t, router.Handlers[1], middleware1)
|
|
compareFunc(t, router.noMethod[0], middleware0)
|
|
compareFunc(t, router.allNoMethod[0], middleware2)
|
|
compareFunc(t, router.allNoMethod[1], middleware1)
|
|
compareFunc(t, router.allNoMethod[2], middleware0)
|
|
}
|
|
|
|
func compareFunc(t *testing.T, a, b any) {
|
|
sf1 := reflect.ValueOf(a)
|
|
sf2 := reflect.ValueOf(b)
|
|
if sf1.Pointer() != sf2.Pointer() {
|
|
t.Error("different functions")
|
|
}
|
|
}
|
|
|
|
func TestListOfRoutes(t *testing.T) {
|
|
router := New()
|
|
router.GET("/favicon.ico", handlerTest1)
|
|
router.GET("/", handlerTest1)
|
|
group := router.Group("/users")
|
|
{
|
|
group.GET("/", handlerTest2)
|
|
group.GET("/:id", handlerTest1)
|
|
group.POST("/:id", handlerTest2)
|
|
}
|
|
router.Static("/static", ".")
|
|
|
|
list := router.Routes()
|
|
|
|
assert.Len(t, list, 7)
|
|
assertRoutePresent(t, list, RouteInfo{
|
|
Method: http.MethodGet,
|
|
Path: "/favicon.ico",
|
|
Handler: "^(.*/vendor/)?github.com/gin-gonic/gin.handlerTest1$",
|
|
})
|
|
assertRoutePresent(t, list, RouteInfo{
|
|
Method: http.MethodGet,
|
|
Path: "/",
|
|
Handler: "^(.*/vendor/)?github.com/gin-gonic/gin.handlerTest1$",
|
|
})
|
|
assertRoutePresent(t, list, RouteInfo{
|
|
Method: http.MethodGet,
|
|
Path: "/users/",
|
|
Handler: "^(.*/vendor/)?github.com/gin-gonic/gin.handlerTest2$",
|
|
})
|
|
assertRoutePresent(t, list, RouteInfo{
|
|
Method: http.MethodGet,
|
|
Path: "/users/:id",
|
|
Handler: "^(.*/vendor/)?github.com/gin-gonic/gin.handlerTest1$",
|
|
})
|
|
assertRoutePresent(t, list, RouteInfo{
|
|
Method: http.MethodPost,
|
|
Path: "/users/:id",
|
|
Handler: "^(.*/vendor/)?github.com/gin-gonic/gin.handlerTest2$",
|
|
})
|
|
}
|
|
|
|
func TestEngineHandleContext(t *testing.T) {
|
|
r := New()
|
|
r.GET("/", func(c *Context) {
|
|
c.Request.URL.Path = "/v2"
|
|
r.HandleContext(c)
|
|
})
|
|
v2 := r.Group("/v2")
|
|
{
|
|
v2.GET("/", func(c *Context) {})
|
|
}
|
|
|
|
assert.NotPanics(t, func() {
|
|
w := PerformRequest(r, http.MethodGet, "/")
|
|
assert.Equal(t, 301, w.Code)
|
|
})
|
|
}
|
|
|
|
func TestEngineHandleContextManyReEntries(t *testing.T) {
|
|
expectValue := 10000
|
|
|
|
var handlerCounter, middlewareCounter int64
|
|
|
|
r := New()
|
|
r.Use(func(c *Context) {
|
|
atomic.AddInt64(&middlewareCounter, 1)
|
|
})
|
|
r.GET("/:count", func(c *Context) {
|
|
countStr := c.Param("count")
|
|
count, err := strconv.Atoi(countStr)
|
|
require.NoError(t, err)
|
|
|
|
n, err := c.Writer.Write([]byte("."))
|
|
require.NoError(t, err)
|
|
assert.Equal(t, 1, n)
|
|
|
|
switch {
|
|
case count > 0:
|
|
c.Request.URL.Path = "/" + strconv.Itoa(count-1)
|
|
r.HandleContext(c)
|
|
}
|
|
}, func(c *Context) {
|
|
atomic.AddInt64(&handlerCounter, 1)
|
|
})
|
|
|
|
assert.NotPanics(t, func() {
|
|
w := PerformRequest(r, http.MethodGet, "/"+strconv.Itoa(expectValue-1)) // include 0 value
|
|
assert.Equal(t, 200, w.Code)
|
|
assert.Equal(t, expectValue, w.Body.Len())
|
|
})
|
|
|
|
assert.Equal(t, int64(expectValue), handlerCounter)
|
|
assert.Equal(t, int64(expectValue), middlewareCounter)
|
|
}
|
|
|
|
func TestEngineHandleContextPreventsMiddlewareReEntry(t *testing.T) {
|
|
// given
|
|
var handlerCounterV1, handlerCounterV2, middlewareCounterV1 int64
|
|
|
|
r := New()
|
|
v1 := r.Group("/v1")
|
|
{
|
|
v1.Use(func(c *Context) {
|
|
atomic.AddInt64(&middlewareCounterV1, 1)
|
|
})
|
|
v1.GET("/test", func(c *Context) {
|
|
atomic.AddInt64(&handlerCounterV1, 1)
|
|
c.Status(http.StatusOK)
|
|
})
|
|
}
|
|
|
|
v2 := r.Group("/v2")
|
|
{
|
|
v2.GET("/test", func(c *Context) {
|
|
c.Request.URL.Path = "/v1/test"
|
|
r.HandleContext(c)
|
|
}, func(c *Context) {
|
|
atomic.AddInt64(&handlerCounterV2, 1)
|
|
})
|
|
}
|
|
|
|
// when
|
|
responseV1 := PerformRequest(r, "GET", "/v1/test")
|
|
responseV2 := PerformRequest(r, "GET", "/v2/test")
|
|
|
|
// then
|
|
assert.Equal(t, 200, responseV1.Code)
|
|
assert.Equal(t, 200, responseV2.Code)
|
|
assert.Equal(t, int64(2), handlerCounterV1)
|
|
assert.Equal(t, int64(2), middlewareCounterV1)
|
|
assert.Equal(t, int64(1), handlerCounterV2)
|
|
}
|
|
|
|
func TestEngineHandleContextNoRouteWithGroupMiddleware(t *testing.T) {
|
|
// Scenario from issue #1848:
|
|
// - Engine with no global middleware (gin.New())
|
|
// - A group with middleware
|
|
// - A route in that group
|
|
// - NoRoute handler that redirects via HandleContext
|
|
// The group middleware should run exactly once per HandleContext call,
|
|
// not accumulate across redirects.
|
|
|
|
var middlewareCount, handlerCount int64
|
|
|
|
r := New()
|
|
grp := r.Group("", func(c *Context) {
|
|
atomic.AddInt64(&middlewareCount, 1)
|
|
c.Next()
|
|
})
|
|
grp.GET("/target", func(c *Context) {
|
|
atomic.AddInt64(&handlerCount, 1)
|
|
c.String(http.StatusOK, "ok")
|
|
})
|
|
|
|
r.NoRoute(func(c *Context) {
|
|
c.Request.URL.Path = "/target"
|
|
r.HandleContext(c)
|
|
})
|
|
|
|
// Access a non-existent route to trigger NoRoute -> HandleContext
|
|
w := PerformRequest(r, "GET", "/nonexistent")
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Equal(t, "ok", w.Body.String())
|
|
// Middleware and handler should each run exactly once
|
|
assert.Equal(t, int64(1), atomic.LoadInt64(&middlewareCount))
|
|
assert.Equal(t, int64(1), atomic.LoadInt64(&handlerCount))
|
|
}
|
|
|
|
func TestEngineHandleContextNoRouteWithEngineMiddleware(t *testing.T) {
|
|
// When engine middleware exists and NoRoute redirects via HandleContext,
|
|
// verify the handlers run the expected number of times.
|
|
|
|
var engineMwCount, groupMwCount, handlerCount int64
|
|
|
|
r := New()
|
|
r.Use(func(c *Context) {
|
|
atomic.AddInt64(&engineMwCount, 1)
|
|
c.Next()
|
|
})
|
|
|
|
grp := r.Group("", func(c *Context) {
|
|
atomic.AddInt64(&groupMwCount, 1)
|
|
c.Next()
|
|
})
|
|
grp.GET("/target", func(c *Context) {
|
|
atomic.AddInt64(&handlerCount, 1)
|
|
c.String(http.StatusOK, "ok")
|
|
})
|
|
|
|
r.NoRoute(func(c *Context) {
|
|
c.Request.URL.Path = "/target"
|
|
r.HandleContext(c)
|
|
})
|
|
|
|
w := PerformRequest(r, "GET", "/nonexistent")
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Equal(t, "ok", w.Body.String())
|
|
// Handler and group middleware should each run once (from HandleContext)
|
|
assert.Equal(t, int64(1), atomic.LoadInt64(&handlerCount))
|
|
assert.Equal(t, int64(1), atomic.LoadInt64(&groupMwCount))
|
|
// Engine middleware runs twice: once for the NoRoute chain, once for the HandleContext chain
|
|
// This is expected behavior since HandleContext re-enters the full handler chain
|
|
assert.Equal(t, int64(2), atomic.LoadInt64(&engineMwCount))
|
|
}
|
|
|
|
func TestEngineHandleContextUseEscapedPathPercentEncoded(t *testing.T) {
|
|
r := New()
|
|
r.UseEscapedPath = true
|
|
r.UnescapePathValues = false
|
|
|
|
r.GET("/v1/:path", func(c *Context) {
|
|
// Path is Escaped, the %25 is not interpreted as %
|
|
assert.Equal(t, "foo%252Fbar", c.Param("path"))
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/v1/foo%252Fbar", nil)
|
|
w := httptest.NewRecorder()
|
|
r.ServeHTTP(w, req)
|
|
}
|
|
|
|
func TestEngineHandleContextUseRawPathPercentEncoded(t *testing.T) {
|
|
r := New()
|
|
r.UseRawPath = true
|
|
r.UnescapePathValues = false
|
|
|
|
r.GET("/v1/:path", func(c *Context) {
|
|
// Path is used, the %25 is interpreted as %
|
|
assert.Equal(t, "foo%2Fbar", c.Param("path"))
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/v1/foo%252Fbar", nil)
|
|
w := httptest.NewRecorder()
|
|
r.ServeHTTP(w, req)
|
|
}
|
|
|
|
func TestEngineHandleContextUseEscapedPathOverride(t *testing.T) {
|
|
r := New()
|
|
r.UseEscapedPath = true
|
|
r.UseRawPath = true
|
|
r.UnescapePathValues = false
|
|
|
|
r.GET("/v1/:path", func(c *Context) {
|
|
assert.Equal(t, "foo%25bar", c.Param("path"))
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
assert.NotPanics(t, func() {
|
|
w := PerformRequest(r, http.MethodGet, "/v1/foo%25bar")
|
|
assert.Equal(t, 200, w.Code)
|
|
})
|
|
}
|
|
|
|
func TestPrepareTrustedCIRDsWith(t *testing.T) {
|
|
r := New()
|
|
|
|
// valid ipv4 cidr
|
|
{
|
|
expectedTrustedCIDRs := []*net.IPNet{parseCIDR("0.0.0.0/0")}
|
|
err := r.SetTrustedProxies([]string{"0.0.0.0/0"})
|
|
|
|
require.NoError(t, err)
|
|
assert.Equal(t, expectedTrustedCIDRs, r.trustedCIDRs)
|
|
}
|
|
|
|
// invalid ipv4 cidr
|
|
{
|
|
err := r.SetTrustedProxies([]string{"192.168.1.33/33"})
|
|
|
|
require.Error(t, err)
|
|
}
|
|
|
|
// valid ipv4 address
|
|
{
|
|
expectedTrustedCIDRs := []*net.IPNet{parseCIDR("192.168.1.33/32")}
|
|
|
|
err := r.SetTrustedProxies([]string{"192.168.1.33"})
|
|
|
|
require.NoError(t, err)
|
|
assert.Equal(t, expectedTrustedCIDRs, r.trustedCIDRs)
|
|
}
|
|
|
|
// invalid ipv4 address
|
|
{
|
|
err := r.SetTrustedProxies([]string{"192.168.1.256"})
|
|
|
|
require.Error(t, err)
|
|
}
|
|
|
|
// valid ipv6 address
|
|
{
|
|
expectedTrustedCIDRs := []*net.IPNet{parseCIDR("2002:0000:0000:1234:abcd:ffff:c0a8:0101/128")}
|
|
err := r.SetTrustedProxies([]string{"2002:0000:0000:1234:abcd:ffff:c0a8:0101"})
|
|
|
|
require.NoError(t, err)
|
|
assert.Equal(t, expectedTrustedCIDRs, r.trustedCIDRs)
|
|
}
|
|
|
|
// invalid ipv6 address
|
|
{
|
|
err := r.SetTrustedProxies([]string{"gggg:0000:0000:1234:abcd:ffff:c0a8:0101"})
|
|
|
|
require.Error(t, err)
|
|
}
|
|
|
|
// valid ipv6 cidr
|
|
{
|
|
expectedTrustedCIDRs := []*net.IPNet{parseCIDR("::/0")}
|
|
err := r.SetTrustedProxies([]string{"::/0"})
|
|
|
|
require.NoError(t, err)
|
|
assert.Equal(t, expectedTrustedCIDRs, r.trustedCIDRs)
|
|
}
|
|
|
|
// invalid ipv6 cidr
|
|
{
|
|
err := r.SetTrustedProxies([]string{"gggg:0000:0000:1234:abcd:ffff:c0a8:0101/129"})
|
|
|
|
require.Error(t, err)
|
|
}
|
|
|
|
// valid combination
|
|
{
|
|
expectedTrustedCIDRs := []*net.IPNet{
|
|
parseCIDR("::/0"),
|
|
parseCIDR("192.168.0.0/16"),
|
|
parseCIDR("172.16.0.1/32"),
|
|
}
|
|
err := r.SetTrustedProxies([]string{
|
|
"::/0",
|
|
"192.168.0.0/16",
|
|
"172.16.0.1",
|
|
})
|
|
|
|
require.NoError(t, err)
|
|
assert.Equal(t, expectedTrustedCIDRs, r.trustedCIDRs)
|
|
}
|
|
|
|
// invalid combination
|
|
{
|
|
err := r.SetTrustedProxies([]string{
|
|
"::/0",
|
|
"192.168.0.0/16",
|
|
"172.16.0.256",
|
|
})
|
|
|
|
require.Error(t, err)
|
|
}
|
|
|
|
// nil value
|
|
{
|
|
err := r.SetTrustedProxies(nil)
|
|
|
|
assert.Nil(t, r.trustedCIDRs)
|
|
require.NoError(t, err)
|
|
}
|
|
}
|
|
|
|
// parseCIDR parses a hard-coded CIDR fixture and returns its network (the
// address masked to the prefix). Fixtures are literals written by the test
// author, so a parse failure is a bug in the test. It panics rather than
// printing the error and returning nil, which made a broken fixture look like
// an ordinary assertion mismatch.
func parseCIDR(cidr string) *net.IPNet {
	_, network, err := net.ParseCIDR(cidr)
	if err != nil {
		panic(fmt.Sprintf("parseCIDR(%q): %v", cidr, err))
	}
	return network
}
|
|
|
|
func assertRoutePresent(t *testing.T, gotRoutes RoutesInfo, wantRoute RouteInfo) {
|
|
for _, gotRoute := range gotRoutes {
|
|
if gotRoute.Path == wantRoute.Path && gotRoute.Method == wantRoute.Method {
|
|
assert.Regexp(t, wantRoute.Handler, gotRoute.Handler)
|
|
return
|
|
}
|
|
}
|
|
t.Errorf("route not found: %v", wantRoute)
|
|
}
|
|
|
|
func handlerTest1(c *Context) {}
|
|
func handlerTest2(c *Context) {}
|
|
|
|
func TestNewOptionFunc(t *testing.T) {
|
|
fc := func(e *Engine) {
|
|
e.GET("/test1", handlerTest1)
|
|
e.GET("/test2", handlerTest2)
|
|
|
|
e.Use(func(c *Context) {
|
|
c.Next()
|
|
})
|
|
}
|
|
|
|
r := New(fc)
|
|
|
|
routes := r.Routes()
|
|
assertRoutePresent(t, routes, RouteInfo{Path: "/test1", Method: http.MethodGet, Handler: "github.com/gin-gonic/gin.handlerTest1"})
|
|
assertRoutePresent(t, routes, RouteInfo{Path: "/test2", Method: http.MethodGet, Handler: "github.com/gin-gonic/gin.handlerTest2"})
|
|
}
|
|
|
|
func TestWithOptionFunc(t *testing.T) {
|
|
r := New()
|
|
|
|
r.With(func(e *Engine) {
|
|
e.GET("/test1", handlerTest1)
|
|
e.GET("/test2", handlerTest2)
|
|
|
|
e.Use(func(c *Context) {
|
|
c.Next()
|
|
})
|
|
})
|
|
|
|
routes := r.Routes()
|
|
assertRoutePresent(t, routes, RouteInfo{Path: "/test1", Method: http.MethodGet, Handler: "github.com/gin-gonic/gin.handlerTest1"})
|
|
assertRoutePresent(t, routes, RouteInfo{Path: "/test2", Method: http.MethodGet, Handler: "github.com/gin-gonic/gin.handlerTest2"})
|
|
}
|
|
|
|
// Birthday is a string-backed field type with a custom UnmarshalParam,
// used to exercise custom parameter binding.
type Birthday string

// UnmarshalParam stores param with every "-" replaced by "/", so
// "2000-01-01" becomes "2000/01/01". It always returns nil.
func (b *Birthday) UnmarshalParam(param string) error {
	dashesToSlashes := strings.NewReplacer("-", "/")
	*b = Birthday(dashesToSlashes.Replace(param))
	return nil
}
|
|
|
|
func TestCustomUnmarshalStruct(t *testing.T) {
|
|
route := Default()
|
|
var request struct {
|
|
Birthday Birthday `form:"birthday"`
|
|
}
|
|
route.GET("/test", func(ctx *Context) {
|
|
_ = ctx.BindQuery(&request)
|
|
ctx.JSON(200, request.Birthday)
|
|
})
|
|
req := httptest.NewRequest(http.MethodGet, "/test?birthday=2000-01-01", nil)
|
|
w := httptest.NewRecorder()
|
|
route.ServeHTTP(w, req)
|
|
assert.Equal(t, 200, w.Code)
|
|
assert.Equal(t, `"2000/01/01"`, w.Body.String())
|
|
}
|
|
|
|
// Test the fix for https://github.com/gin-gonic/gin/issues/4002
|
|
func TestMethodNotAllowedNoRoute(t *testing.T) {
|
|
g := New()
|
|
g.HandleMethodNotAllowed = true
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
resp := httptest.NewRecorder()
|
|
assert.NotPanics(t, func() { g.ServeHTTP(resp, req) })
|
|
assert.Equal(t, http.StatusNotFound, resp.Code)
|
|
}
|
|
|
|
// Test the fix for https://github.com/gin-gonic/gin/pull/4415
|
|
func TestLiteralColonWithRun(t *testing.T) {
|
|
SetMode(TestMode)
|
|
router := New()
|
|
|
|
router.GET(`/test\:action`, func(c *Context) {
|
|
c.JSON(http.StatusOK, H{"path": "literal_colon"})
|
|
})
|
|
|
|
router.updateRouteTrees()
|
|
|
|
w := httptest.NewRecorder()
|
|
|
|
req, _ := http.NewRequest(http.MethodGet, "/test:action", nil)
|
|
router.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Contains(t, w.Body.String(), "literal_colon")
|
|
}
|
|
|
|
func TestLiteralColonWithDirectServeHTTP(t *testing.T) {
|
|
SetMode(TestMode)
|
|
router := New()
|
|
|
|
router.GET(`/test\:action`, func(c *Context) {
|
|
c.JSON(http.StatusOK, H{"path": "literal_colon"})
|
|
})
|
|
|
|
w := httptest.NewRecorder()
|
|
req, _ := http.NewRequest(http.MethodGet, "/test:action", nil)
|
|
router.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Contains(t, w.Body.String(), "literal_colon")
|
|
}
|
|
|
|
func TestLiteralColonWithHandler(t *testing.T) {
|
|
SetMode(TestMode)
|
|
router := New()
|
|
|
|
router.GET(`/test\:action`, func(c *Context) {
|
|
c.JSON(http.StatusOK, H{"path": "literal_colon"})
|
|
})
|
|
|
|
handler := router.Handler()
|
|
|
|
w := httptest.NewRecorder()
|
|
req, _ := http.NewRequest(http.MethodGet, "/test:action", nil)
|
|
handler.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Contains(t, w.Body.String(), "literal_colon")
|
|
}
|
|
|
|
func TestLiteralColonWithHTTPServer(t *testing.T) {
|
|
SetMode(TestMode)
|
|
router := New()
|
|
|
|
router.GET(`/test\:action`, func(c *Context) {
|
|
c.JSON(http.StatusOK, H{"path": "literal_colon"})
|
|
})
|
|
|
|
router.GET("/test/:param", func(c *Context) {
|
|
c.JSON(http.StatusOK, H{"param": c.Param("param")})
|
|
})
|
|
|
|
w := httptest.NewRecorder()
|
|
req, _ := http.NewRequest(http.MethodGet, "/test:action", nil)
|
|
router.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Contains(t, w.Body.String(), "literal_colon")
|
|
|
|
w2 := httptest.NewRecorder()
|
|
req2, _ := http.NewRequest(http.MethodGet, "/test/foo", nil)
|
|
router.ServeHTTP(w2, req2)
|
|
|
|
assert.Equal(t, http.StatusOK, w2.Code)
|
|
assert.Contains(t, w2.Body.String(), "foo")
|
|
}
|
|
|
|
// Test that updateRouteTrees is called only once
|
|
func TestUpdateRouteTreesCalledOnce(t *testing.T) {
|
|
SetMode(TestMode)
|
|
router := New()
|
|
|
|
router.GET(`/test\:action`, func(c *Context) {
|
|
c.String(http.StatusOK, "ok")
|
|
})
|
|
|
|
for range 5 {
|
|
w := httptest.NewRecorder()
|
|
req, _ := http.NewRequest(http.MethodGet, "/test:action", nil)
|
|
router.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Equal(t, "ok", w.Body.String())
|
|
}
|
|
}
|
|
|
|
// Test the fix for https://github.com/gin-gonic/gin/issues/4189
|
|
func TestSkipMethodNotAllowedMiddleware(t *testing.T) {
|
|
g := New()
|
|
g.HandleMethodNotAllowed = true
|
|
g.SkipMethodNotAllowedMiddleware = true
|
|
|
|
var middlewareCalled bool
|
|
middleware := func(c *Context) {
|
|
middlewareCalled = true
|
|
c.Next()
|
|
}
|
|
noMethodHandler := func(c *Context) {
|
|
c.String(http.StatusMethodNotAllowed, "method not allowed")
|
|
}
|
|
|
|
g.Use(middleware)
|
|
g.NoMethod(noMethodHandler)
|
|
g.POST("/test", func(c *Context) {
|
|
c.String(http.StatusOK, "ok")
|
|
})
|
|
|
|
w := httptest.NewRecorder()
|
|
req, _ := http.NewRequest(http.MethodGet, "/test", nil)
|
|
g.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusMethodNotAllowed, w.Code)
|
|
assert.Equal(t, "method not allowed", w.Body.String())
|
|
assert.False(t, middlewareCalled, "middleware should not be called when SkipMethodNotAllowedMiddleware is true")
|
|
}
|
|
|
|
func TestSkipMethodNotAllowedMiddlewareDisabled(t *testing.T) {
|
|
g := New()
|
|
g.HandleMethodNotAllowed = true
|
|
g.SkipMethodNotAllowedMiddleware = false
|
|
|
|
var middlewareCalled bool
|
|
middleware := func(c *Context) {
|
|
middlewareCalled = true
|
|
c.Next()
|
|
}
|
|
noMethodHandler := func(c *Context) {
|
|
c.String(http.StatusMethodNotAllowed, "method not allowed")
|
|
}
|
|
|
|
g.Use(middleware)
|
|
g.NoMethod(noMethodHandler)
|
|
g.POST("/test", func(c *Context) {
|
|
c.String(http.StatusOK, "ok")
|
|
})
|
|
|
|
w := httptest.NewRecorder()
|
|
req, _ := http.NewRequest(http.MethodGet, "/test", nil)
|
|
g.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusMethodNotAllowed, w.Code)
|
|
assert.Equal(t, "method not allowed", w.Body.String())
|
|
assert.True(t, middlewareCalled, "middleware should be called when SkipMethodNotAllowedMiddleware is false")
|
|
}
|