diff --git a/.github/workflows/gin.yml b/.github/workflows/gin.yml index 21895f28..15c2530a 100644 --- a/.github/workflows/gin.yml +++ b/.github/workflows/gin.yml @@ -21,13 +21,14 @@ jobs: - name: Setup golangci-lint uses: golangci/golangci-lint-action@v2 with: - version: v1.42.0 + version: v1.42.1 args: --verbose test: + needs: lint strategy: matrix: os: [ubuntu-latest, macos-latest] - go: [1.13, 1.14, 1.15, 1.16] + go: [1.13, 1.14, 1.15, 1.16, 1.17] test-tags: ['', nomsgpack] include: - os: ubuntu-latest diff --git a/.golangci.yml b/.golangci.yml index 78a4259a..c5e1de38 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -2,21 +2,38 @@ run: timeout: 5m linters: enable: - - gofmt - - misspell - - revive + - asciicheck + - depguard + - dogsled + - durationcheck + - errcheck + - errorlint + - exportloopref + - gci + - gofmt + - goimports + - gosec + - misspell + - nakedret + - nilerr + - nolintlint + - revive + - wastedassign issues: exclude-rules: - - linters: - - structcheck - - unused - text: "`data` is unused" - - linters: - - staticcheck - text: "SA1019:" - - linters: - - revive - text: "var-naming:" - - linters: - - revive - text: "exported:" + - linters: + - structcheck + - unused + text: "`data` is unused" + - linters: + - staticcheck + text: "SA1019:" + - linters: + - revive + text: "var-naming:" + - linters: + - revive + text: "exported:" + - path: _test\.go + linters: + - gosec # security is not make sense in tests diff --git a/README.md b/README.md index 198bf011..cad746d6 100644 --- a/README.md +++ b/README.md @@ -78,6 +78,7 @@ Gin is a web framework written in Go (Golang). It features a martini-like API wi - [http2 server push](#http2-server-push) - [Define format for the log of routes](#define-format-for-the-log-of-routes) - [Set and get a cookie](#set-and-get-a-cookie) + - [Don't trust all proxies](#don't-trust-all-proxies) - [Testing](#testing) - [Users](#users) @@ -1827,7 +1828,7 @@ func main() { quit := make(chan os.Signal) // kill (no param) default send syscall.SIGTERM // kill -2 is syscall.SIGINT - // kill -9 is syscall.SIGKILL but can't be catch, so don't need add it + // kill -9 is syscall.SIGKILL but can't be caught, so don't need to add it signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) <-quit log.Println("Shutting down server...") @@ -2000,7 +2001,7 @@ func SomeHandler(c *gin.Context) { objA := formA{} objB := formB{} // This reads c.Request.Body and stores the result into the context. - if errA := c.ShouldBindBodyWith(&objA, binding.JSON); errA == nil { + if errA := c.ShouldBindBodyWith(&objA, binding.Form); errA == nil { c.String(http.StatusOK, `the body should be formA`) // At this time, it reuses body stored in the context. } else if errB := c.ShouldBindBodyWith(&objB, binding.JSON); errB == nil { @@ -2202,11 +2203,17 @@ Gin lets you specify which headers to hold the real client IP (if any), as well as specifying which proxies (or direct clients) you trust to specify one of these headers. -The `TrustedProxies` slice on your `gin.Engine` specifes network addresses or -network CIDRs from where clients which their request headers related to client +Use function `SetTrustedProxies()` on your `gin.Engine` to specify network addresses +or network CIDRs from where clients which their request headers related to client IP can be trusted. They can be IPv4 addresses, IPv4 CIDRs, IPv6 addresses or IPv6 CIDRs. +**Attention:** Gin trust all proxies by default if you don't specify a trusted +proxy using the function above, **this is NOT safe**. At the same time, if you don't +use any proxy, you can disable this feature by using `Engine.SetTrustedProxies(nil)`, +then `Context.ClientIP()` will return the remote address directly to avoid some +unnecessary computation. + ```go import ( "fmt" @@ -2217,7 +2224,7 @@ import ( func main() { router := gin.Default() - router.TrustedProxies = []string{"192.168.1.2"} + router.SetTrustedProxies([]string{"192.168.1.2"}) router.GET("/", func(c *gin.Context) { // If the client is 192.168.1.2, use the X-Forwarded-For @@ -2230,6 +2237,34 @@ func main() { } ``` +**Notice:** If you are using a CDN service, you can set the `Engine.TrustedPlatform` +to skip TrustedProxies check, it has a higher priority than TrustedProxies. +Look at the example below: +```go +import ( + "fmt" + + "github.com/gin-gonic/gin" +) + +func main() { + + router := gin.Default() + // Use predefined header gin.PlatformXXX + router.TrustedPlatform = gin.PlatformGoogleAppEngine + // Or set your own trusted request header for another trusted proxy service + // Don't set it to any suspect request header, it's unsafe + router.TrustedPlatform = "X-CDN-IP" + + router.GET("/", func(c *gin.Context) { + // If you set TrustedPlatform, ClientIP() will resolve the + // corresponding header and return IP directly + fmt.Printf("ClientIP: %s\n", c.ClientIP()) + }) + router.Run() +} +``` + ## Testing The `net/http/httptest` package is preferable way for HTTP testing. diff --git a/binding/binding_test.go b/binding/binding_test.go index 17df7dc5..5b0ce39d 100644 --- a/binding/binding_test.go +++ b/binding/binding_test.go @@ -1339,6 +1339,13 @@ func testProtoBodyBindingFail(t *testing.T, b Binding, name, path, badPath, body err := b.Bind(req, &obj) assert.Error(t, err) + invalid_obj := FooStruct{} + req.Body = ioutil.NopCloser(strings.NewReader(`{"msg":"hello"}`)) + req.Header.Add("Content-Type", MIMEPROTOBUF) + err = b.Bind(req, &invalid_obj) + assert.Error(t, err) + assert.Equal(t, err.Error(), "obj is not ProtoMessage") + obj = protoexample.Test{} req = requestWithBody("POST", badPath, badBody) req.Header.Add("Content-Type", MIMEPROTOBUF) diff --git a/binding/form.go b/binding/form.go index 040af9e2..fa2a6540 100644 --- a/binding/form.go +++ b/binding/form.go @@ -5,6 +5,7 @@ package binding import ( + "errors" "net/http" ) @@ -22,7 +23,7 @@ func (formBinding) Bind(req *http.Request, obj interface{}) error { if err := req.ParseForm(); err != nil { return err } - if err := req.ParseMultipartForm(defaultMemory); err != nil && err != http.ErrNotMultipart { + if err := req.ParseMultipartForm(defaultMemory); err != nil && !errors.Is(err, http.ErrNotMultipart) { return err } if err := mapForm(obj, req.Form); err != nil { diff --git a/binding/form_mapping.go b/binding/form_mapping.go index f8b4b123..fa7ad1bc 100644 --- a/binding/form_mapping.go +++ b/binding/form_mapping.go @@ -61,7 +61,7 @@ func mapFormByTag(ptr interface{}, form map[string][]string, tag string) error { // setter tries to set value on a walking by fields of a struct type setter interface { - TrySet(value reflect.Value, field reflect.StructField, key string, opt setOptions) (isSetted bool, err error) + TrySet(value reflect.Value, field reflect.StructField, key string, opt setOptions) (isSet bool, err error) } type formSource map[string][]string @@ -69,7 +69,7 @@ type formSource map[string][]string var _ setter = formSource(nil) // TrySet tries to set a value by request's form source (like map[string][]string) -func (form formSource) TrySet(value reflect.Value, field reflect.StructField, tagValue string, opt setOptions) (isSetted bool, err error) { +func (form formSource) TrySet(value reflect.Value, field reflect.StructField, tagValue string, opt setOptions) (isSet bool, err error) { return setByForm(value, field, form, tagValue, opt) } @@ -92,14 +92,14 @@ func mapping(value reflect.Value, field reflect.StructField, setter setter, tag isNew = true vPtr = reflect.New(value.Type().Elem()) } - isSetted, err := mapping(vPtr.Elem(), field, setter, tag) + isSet, err := mapping(vPtr.Elem(), field, setter, tag) if err != nil { return false, err } - if isNew && isSetted { + if isNew && isSet { value.Set(vPtr) } - return isSetted, nil + return isSet, nil } if vKind != reflect.Struct || !field.Anonymous { @@ -115,7 +115,7 @@ func mapping(value reflect.Value, field reflect.StructField, setter setter, tag if vKind == reflect.Struct { tValue := value.Type() - var isSetted bool + var isSet bool for i := 0; i < value.NumField(); i++ { sf := tValue.Field(i) if sf.PkgPath != "" && !sf.Anonymous { // unexported @@ -125,9 +125,9 @@ func mapping(value reflect.Value, field reflect.StructField, setter setter, tag if err != nil { return false, err } - isSetted = isSetted || ok + isSet = isSet || ok } - return isSetted, nil + return isSet, nil } return false, nil } @@ -164,7 +164,7 @@ func tryToSetValue(value reflect.Value, field reflect.StructField, setter setter return setter.TrySet(value, field, tagValue, setOpt) } -func setByForm(value reflect.Value, field reflect.StructField, form map[string][]string, tagValue string, opt setOptions) (isSetted bool, err error) { +func setByForm(value reflect.Value, field reflect.StructField, form map[string][]string, tagValue string, opt setOptions) (isSet bool, err error) { vs, ok := form[tagValue] if !ok && !opt.isDefaultExists { return false, nil diff --git a/binding/multipart_form_mapping.go b/binding/multipart_form_mapping.go index 69c0a544..c4d7ed74 100644 --- a/binding/multipart_form_mapping.go +++ b/binding/multipart_form_mapping.go @@ -32,7 +32,7 @@ func (r *multipartRequest) TrySet(value reflect.Value, field reflect.StructField return setByForm(value, field, r.MultipartForm.Value, key, opt) } -func setByMultipartFormFile(value reflect.Value, field reflect.StructField, files []*multipart.FileHeader) (isSetted bool, err error) { +func setByMultipartFormFile(value reflect.Value, field reflect.StructField, files []*multipart.FileHeader) (isSet bool, err error) { switch value.Kind() { case reflect.Ptr: switch value.Interface().(type) { @@ -48,9 +48,9 @@ func setByMultipartFormFile(value reflect.Value, field reflect.StructField, file } case reflect.Slice: slice := reflect.MakeSlice(value.Type(), len(files), len(files)) - isSetted, err = setArrayOfMultipartFormFiles(slice, field, files) - if err != nil || !isSetted { - return isSetted, err + isSet, err = setArrayOfMultipartFormFiles(slice, field, files) + if err != nil || !isSet { + return isSet, err } value.Set(slice) return true, nil @@ -60,14 +60,14 @@ func setByMultipartFormFile(value reflect.Value, field reflect.StructField, file return false, ErrMultiFileHeader } -func setArrayOfMultipartFormFiles(value reflect.Value, field reflect.StructField, files []*multipart.FileHeader) (isSetted bool, err error) { +func setArrayOfMultipartFormFiles(value reflect.Value, field reflect.StructField, files []*multipart.FileHeader) (isSet bool, err error) { if value.Len() != len(files) { return false, ErrMultiFileHeaderLenInvalid } for i := range files { - setted, err := setByMultipartFormFile(value.Index(i), field, files[i:i+1]) - if err != nil || !setted { - return setted, err + set, err := setByMultipartFormFile(value.Index(i), field, files[i:i+1]) + if err != nil || !set { + return set, err } } return true, nil diff --git a/binding/protobuf.go b/binding/protobuf.go index ca02897a..a4e47153 100644 --- a/binding/protobuf.go +++ b/binding/protobuf.go @@ -5,6 +5,7 @@ package binding import ( + "errors" "io/ioutil" "net/http" @@ -26,7 +27,11 @@ func (b protobufBinding) Bind(req *http.Request, obj interface{}) error { } func (protobufBinding) BindBody(body []byte, obj interface{}) error { - if err := proto.Unmarshal(body, obj.(proto.Message)); err != nil { + msg, ok := obj.(proto.Message) + if !ok { + return errors.New("obj is not ProtoMessage") + } + if err := proto.Unmarshal(body, msg); err != nil { return err } // Here it's same to return validate(obj), but util now we can't add diff --git a/context.go b/context.go index 56c20f2c..58f38c88 100644 --- a/context.go +++ b/context.go @@ -55,8 +55,9 @@ type Context struct { index int8 fullPath string - engine *Engine - params *Params + engine *Engine + params *Params + skippedNodes *[]skippedNode // This mutex protect Keys map mu sync.RWMutex @@ -99,6 +100,7 @@ func (c *Context) reset() { c.queryCache = nil c.formCache = nil *c.params = (*c.params)[:0] + *c.skippedNodes = (*c.skippedNodes)[:0] } // Copy returns a copy of the current context that can be safely used outside the request's scope. @@ -220,7 +222,8 @@ func (c *Context) Error(err error) *Error { panic("err is nil") } - parsedError, ok := err.(*Error) + var parsedError *Error + ok := errors.As(err, &parsedError) if !ok { parsedError = &Error{ Err: err, @@ -383,6 +386,15 @@ func (c *Context) Param(key string) string { return c.Params.ByName(key) } +// AddParam adds param to context and +// replaces path param key with given value for e2e testing purposes +// Example Route: "/user/:id" +// AddParam("id", 1) +// Result: "/user/1" +func (c *Context) AddParam(key, value string) { + c.Params = append(c.Params, Param{Key: key, Value: value}) +} + // Query returns the keyed url query value if it exists, // otherwise it returns an empty string `("")`. // It is shortcut for `c.Request.URL.Query().Get(key)` @@ -506,7 +518,7 @@ func (c *Context) initFormCache() { c.formCache = make(url.Values) req := c.Request if err := req.ParseMultipartForm(c.engine.MaxMultipartMemory); err != nil { - if err != http.ErrNotMultipart { + if !errors.Is(err, http.ErrNotMultipart) { debugPrint("error on parse multipart form array: %v", err) } } @@ -723,20 +735,16 @@ func (c *Context) ShouldBindBodyWith(obj interface{}, bb binding.BindingBody) (e return bb.BindBody(body, obj) } -// ClientIP implements a best effort algorithm to return the real client IP. +// ClientIP implements one best effort algorithm to return the real client IP. // It called c.RemoteIP() under the hood, to check if the remote IP is a trusted proxy or not. -// If it's it will then try to parse the headers defined in Engine.RemoteIPHeaders (defaulting to [X-Forwarded-For, X-Real-Ip]). -// If the headers are nots syntactically valid OR the remote IP does not correspong to a trusted proxy, +// If it is it will then try to parse the headers defined in Engine.RemoteIPHeaders (defaulting to [X-Forwarded-For, X-Real-Ip]). +// If the headers are not syntactically valid OR the remote IP does not correspond to a trusted proxy, // the remote IP (coming form Request.RemoteAddr) is returned. func (c *Context) ClientIP() string { - // Check if we're running on a trusted platform - switch c.engine.TrustedPlatform { - case PlatformGoogleAppEngine: - if addr := c.requestHeader("X-Appengine-Remote-Addr"); addr != "" { - return addr - } - case PlatformCloudflare: - if addr := c.requestHeader("CF-Connecting-IP"); addr != "" { + // Check if we're running on a trusted platform, continue running backwards if error + if c.engine.TrustedPlatform != "" { + // Developers can define their own header of Trusted Platform or use predefined constants + if addr := c.requestHeader(c.engine.TrustedPlatform); addr != "" { return addr } } @@ -756,7 +764,7 @@ func (c *Context) ClientIP() string { if trusted && c.engine.ForwardedByClientIP && c.engine.RemoteIPHeaders != nil { for _, headerName := range c.engine.RemoteIPHeaders { - ip, valid := validateHeader(c.requestHeader(headerName)) + ip, valid := c.engine.validateHeader(c.requestHeader(headerName)) if valid { return ip } @@ -765,10 +773,21 @@ func (c *Context) ClientIP() string { return remoteIP.String() } +func (e *Engine) isTrustedProxy(ip net.IP) bool { + if e.trustedCIDRs != nil { + for _, cidr := range e.trustedCIDRs { + if cidr.Contains(ip) { + return true + } + } + } + return false +} + // RemoteIP parses the IP from Request.RemoteAddr, normalizes and returns the IP (without the port). // It also checks if the remoteIP is a trusted proxy or not. // In order to perform this validation, it will see if the IP is contained within at least one of the CIDR blocks -// defined in Engine.TrustedProxies +// defined by Engine.SetTrustedProxies() func (c *Context) RemoteIP() (net.IP, bool) { ip, _, err := net.SplitHostPort(strings.TrimSpace(c.Request.RemoteAddr)) if err != nil { @@ -779,35 +798,25 @@ func (c *Context) RemoteIP() (net.IP, bool) { return nil, false } - if c.engine.trustedCIDRs != nil { - for _, cidr := range c.engine.trustedCIDRs { - if cidr.Contains(remoteIP) { - return remoteIP, true - } - } - } - - return remoteIP, false + return remoteIP, c.engine.isTrustedProxy(remoteIP) } -func validateHeader(header string) (clientIP string, valid bool) { +func (e *Engine) validateHeader(header string) (clientIP string, valid bool) { if header == "" { return "", false } items := strings.Split(header, ",") - for i, ipStr := range items { - ipStr = strings.TrimSpace(ipStr) + for i := len(items) - 1; i >= 0; i-- { + ipStr := strings.TrimSpace(items[i]) ip := net.ParseIP(ipStr) if ip == nil { return "", false } - // We need to return the first IP in the list, but, - // we should not early return since we need to validate that - // the rest of the header is syntactically valid - if i == 0 { - clientIP = ipStr - valid = true + // X-Forwarded-For is appended by proxy + // Check IPs in reverse order and stop when find untrusted proxy + if (i == 0) || (!e.isTrustedProxy(ip)) { + return ipStr, true } } return diff --git a/context_1.16_test.go b/context_1.16_test.go new file mode 100644 index 00000000..053e6c5a --- /dev/null +++ b/context_1.16_test.go @@ -0,0 +1,31 @@ +// Copyright 2021 Gin Core Team. All rights reserved. +// Use of this source code is governed by a MIT style +// license that can be found in the LICENSE file. + +//go:build !go1.17 +// +build !go1.17 + +package gin + +import ( + "bytes" + "mime/multipart" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestContextFormFileFailed16(t *testing.T) { + buf := new(bytes.Buffer) + mw := multipart.NewWriter(buf) + mw.Close() + c, _ := CreateTestContext(httptest.NewRecorder()) + c.Request, _ = http.NewRequest("POST", "/", nil) + c.Request.Header.Set("Content-Type", mw.FormDataContentType()) + c.engine.MaxMultipartMemory = 8 << 20 + f, err := c.FormFile("file") + assert.Error(t, err) + assert.Nil(t, f) +} diff --git a/context_1.17_test.go b/context_1.17_test.go new file mode 100644 index 00000000..431d54c7 --- /dev/null +++ b/context_1.17_test.go @@ -0,0 +1,33 @@ +// Copyright 2021 Gin Core Team. All rights reserved. +// Use of this source code is governed by a MIT style +// license that can be found in the LICENSE file. + +//go:build go1.17 +// +build go1.17 + +package gin + +import ( + "bytes" + "mime/multipart" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestContextFormFileFailed17(t *testing.T) { + buf := new(bytes.Buffer) + mw := multipart.NewWriter(buf) + mw.Close() + c, _ := CreateTestContext(httptest.NewRecorder()) + c.Request, _ = http.NewRequest("POST", "/", nil) + c.Request.Header.Set("Content-Type", mw.FormDataContentType()) + c.engine.MaxMultipartMemory = 8 << 20 + assert.Panics(t, func() { + f, err := c.FormFile("file") + assert.Error(t, err) + assert.Nil(t, f) + }) +} diff --git a/context_test.go b/context_test.go index 176eaae6..c286c0f4 100644 --- a/context_test.go +++ b/context_test.go @@ -23,10 +23,9 @@ import ( "github.com/gin-contrib/sse" "github.com/gin-gonic/gin/binding" + testdata "github.com/gin-gonic/gin/testdata/protoexample" "github.com/stretchr/testify/assert" "google.golang.org/protobuf/proto" - - testdata "github.com/gin-gonic/gin/testdata/protoexample" ) var _ context.Context = &Context{} @@ -87,19 +86,6 @@ func TestContextFormFile(t *testing.T) { assert.NoError(t, c.SaveUploadedFile(f, "test")) } -func TestContextFormFileFailed(t *testing.T) { - buf := new(bytes.Buffer) - mw := multipart.NewWriter(buf) - mw.Close() - c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", nil) - c.Request.Header.Set("Content-Type", mw.FormDataContentType()) - c.engine.MaxMultipartMemory = 8 << 20 - f, err := c.FormFile("file") - assert.Error(t, err) - assert.Nil(t, f) -} - func TestContextMultipartForm(t *testing.T) { buf := new(bytes.Buffer) mw := multipart.NewWriter(buf) @@ -1423,13 +1409,17 @@ func TestContextClientIP(t *testing.T) { c.engine.RemoteIPHeaders = []string{"X-Forwarded-For"} assert.Equal(t, "40.40.40.40", c.ClientIP()) + // Disabled TrustedProxies feature + _ = c.engine.SetTrustedProxies(nil) + assert.Equal(t, "40.40.40.40", c.ClientIP()) + // Last proxy is trusted, but the RemoteAddr is not _ = c.engine.SetTrustedProxies([]string{"30.30.30.30"}) assert.Equal(t, "40.40.40.40", c.ClientIP()) // Only trust RemoteAddr _ = c.engine.SetTrustedProxies([]string{"40.40.40.40"}) - assert.Equal(t, "20.20.20.20", c.ClientIP()) + assert.Equal(t, "30.30.30.30", c.ClientIP()) // All steps are trusted _ = c.engine.SetTrustedProxies([]string{"40.40.40.40", "30.30.30.30", "20.20.20.20"}) @@ -1468,8 +1458,20 @@ func TestContextClientIP(t *testing.T) { c.engine.TrustedPlatform = PlatformGoogleAppEngine assert.Equal(t, "50.50.50.50", c.ClientIP()) - // Test the legacy flag + // Use custom TrustedPlatform header + c.engine.TrustedPlatform = "X-CDN-IP" + c.Request.Header.Set("X-CDN-IP", "80.80.80.80") + assert.Equal(t, "80.80.80.80", c.ClientIP()) + // wrong header + c.engine.TrustedPlatform = "X-Wrong-Header" + assert.Equal(t, "40.40.40.40", c.ClientIP()) + + c.Request.Header.Del("X-CDN-IP") + // TrustedPlatform is empty c.engine.TrustedPlatform = "" + assert.Equal(t, "40.40.40.40", c.ClientIP()) + + // Test the legacy flag c.engine.AppEngine = true assert.Equal(t, "50.50.50.50", c.ClientIP()) c.engine.AppEngine = false @@ -2150,3 +2152,14 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) { }) } } + +func TestContextAddParam(t *testing.T) { + c := &Context{} + id := "id" + value := "1" + c.AddParam(id, value) + + v, ok := c.Params.Get(id) + assert.Equal(t, ok, true) + assert.Equal(t, value, v) +} diff --git a/errors_test.go b/errors_test.go index ee95ab31..9a17d859 100644 --- a/errors_test.go +++ b/errors_test.go @@ -113,7 +113,7 @@ func (e TestErr) Error() string { return string(e) } // TestErrorUnwrap tests the behavior of gin.Error with "errors.Is()" and "errors.As()". // "errors.Is()" and "errors.As()" have been added to the standard library in go 1.13. func TestErrorUnwrap(t *testing.T) { - innerErr := TestErr("somme error") + innerErr := TestErr("some error") // 2 layers of wrapping : use 'fmt.Errorf("%w")' to wrap a gin.Error{}, which itself wraps innerErr err := fmt.Errorf("wrapped: %w", &Error{ diff --git a/gin.go b/gin.go index 6ab2be66..1d9c9fee 100644 --- a/gin.go +++ b/gin.go @@ -11,6 +11,7 @@ import ( "net/http" "os" "path" + "reflect" "strings" "sync" @@ -27,6 +28,8 @@ var ( var defaultPlatform string +var defaultTrustedCIDRs = []*net.IPNet{{IP: net.IP{0x0, 0x0, 0x0, 0x0}, Mask: net.IPMask{0x0, 0x0, 0x0, 0x0}}} // 0.0.0.0/0 + // HandlerFunc defines the handler used by gin middleware as return value. type HandlerFunc func(*Context) @@ -56,10 +59,10 @@ type RoutesInfo []RouteInfo const ( // When running on Google App Engine. Trust X-Appengine-Remote-Addr // for determining the client's IP - PlatformGoogleAppEngine = "google-app-engine" + PlatformGoogleAppEngine = "X-Appengine-Remote-Addr" // When using Cloudflare's CDN. Trust CF-Connecting-IP for determining // the client's IP - PlatformCloudflare = "cloudflare" + PlatformCloudflare = "CF-Connecting-IP" ) // Engine is the framework's instance, it contains the muxer, middleware and configuration settings. @@ -119,15 +122,9 @@ type Engine struct { // List of headers used to obtain the client IP when // `(*gin.Engine).ForwardedByClientIP` is `true` and // `(*gin.Context).Request.RemoteAddr` is matched by at least one of the - // network origins of `(*gin.Engine).TrustedProxies`. + // network origins of list defined by `(*gin.Engine).SetTrustedProxies()`. RemoteIPHeaders []string - // List of network origins (IPv4 addresses, IPv4 CIDRs, IPv6 addresses or - // IPv6 CIDRs) from which to trust request's headers that contain - // alternative client IP when `(*gin.Engine).ForwardedByClientIP` is - // `true`. - TrustedProxies []string - // If set to a constant of value gin.Platform*, trusts the headers set by // that platform, for example to determine the client IP TrustedPlatform string @@ -147,6 +144,8 @@ type Engine struct { pool sync.Pool trees methodTrees maxParams uint16 + maxSections uint16 + trustedProxies []string trustedCIDRs []*net.IPNet } @@ -174,7 +173,6 @@ func New() *Engine { HandleMethodNotAllowed: false, ForwardedByClientIP: true, RemoteIPHeaders: []string{"X-Forwarded-For", "X-Real-IP"}, - TrustedProxies: []string{"0.0.0.0/0"}, TrustedPlatform: defaultPlatform, UseRawPath: false, RemoveExtraSlash: false, @@ -183,6 +181,8 @@ func New() *Engine { trees: make(methodTrees, 0, 9), delims: render.Delims{Left: "{{", Right: "}}"}, secureJSONPrefix: "while(1);", + trustedProxies: []string{"0.0.0.0/0"}, + trustedCIDRs: defaultTrustedCIDRs, } engine.RouterGroup.engine = engine engine.pool.New = func() interface{} { @@ -201,7 +201,8 @@ func Default() *Engine { func (engine *Engine) allocateContext() *Context { v := make(Params, 0, engine.maxParams) - return &Context{engine: engine, params: &v} + skippedNodes := make([]skippedNode, 0, engine.maxSections) + return &Context{engine: engine, params: &v, skippedNodes: &skippedNodes} } // Delims sets template left and right delims and returns a Engine instance. @@ -264,7 +265,7 @@ func (engine *Engine) NoRoute(handlers ...HandlerFunc) { engine.rebuild404Handlers() } -// NoMethod sets the handlers called when... TODO. +// NoMethod sets the handlers called when Engine.HandleMethodNotAllowed = true. func (engine *Engine) NoMethod(handlers ...HandlerFunc) { engine.noMethod = handlers engine.rebuild405Handlers() @@ -307,6 +308,10 @@ func (engine *Engine) addRoute(method, path string, handlers HandlersChain) { if paramsCount := countParams(path); paramsCount > engine.maxParams { engine.maxParams = paramsCount } + + if sectionsCount := countSections(path); sectionsCount > engine.maxSections { + engine.maxSections = sectionsCount + } } // Routes returns a slice of registered routes, including some useful information, such as: @@ -341,9 +346,9 @@ func iterate(path, method string, routes RoutesInfo, root *node) RoutesInfo { func (engine *Engine) Run(addr ...string) (err error) { defer func() { debugPrintError(err) }() - err = engine.parseTrustedProxies() - if err != nil { - return err + if engine.isUnsafeTrustedProxies() { + debugPrint("[WARNING] You trusted all proxies, this is NOT safe. We recommend you to set a value.\n" + + "Please check https://pkg.go.dev/github.com/gin-gonic/gin#readme-don-t-trust-all-proxies for details.") } address := resolveAddress(addr) @@ -353,12 +358,12 @@ func (engine *Engine) Run(addr ...string) (err error) { } func (engine *Engine) prepareTrustedCIDRs() ([]*net.IPNet, error) { - if engine.TrustedProxies == nil { + if engine.trustedProxies == nil { return nil, nil } - cidr := make([]*net.IPNet, 0, len(engine.TrustedProxies)) - for _, trustedProxy := range engine.TrustedProxies { + cidr := make([]*net.IPNet, 0, len(engine.trustedProxies)) + for _, trustedProxy := range engine.trustedProxies { if !strings.Contains(trustedProxy, "/") { ip := parseIP(trustedProxy) if ip == nil { @@ -381,13 +386,25 @@ func (engine *Engine) prepareTrustedCIDRs() ([]*net.IPNet, error) { return cidr, nil } -// SetTrustedProxies set Engine.TrustedProxies +// SetTrustedProxies set a list of network origins (IPv4 addresses, +// IPv4 CIDRs, IPv6 addresses or IPv6 CIDRs) from which to trust +// request's headers that contain alternative client IP when +// `(*gin.Engine).ForwardedByClientIP` is `true`. `TrustedProxies` +// feature is enabled by default, and it also trusts all proxies +// by default. If you want to disable this feature, use +// Engine.SetTrustedProxies(nil), then Context.ClientIP() will +// return the remote address directly. func (engine *Engine) SetTrustedProxies(trustedProxies []string) error { - engine.TrustedProxies = trustedProxies + engine.trustedProxies = trustedProxies return engine.parseTrustedProxies() } -// parseTrustedProxies parse Engine.TrustedProxies to Engine.trustedCIDRs +// isUnsafeTrustedProxies compares Engine.trustedCIDRs and defaultTrustedCIDRs, it's not safe if equal (returns true) +func (engine *Engine) isUnsafeTrustedProxies() bool { + return reflect.DeepEqual(engine.trustedCIDRs, defaultTrustedCIDRs) +} + +// parseTrustedProxies parse Engine.trustedProxies to Engine.trustedCIDRs func (engine *Engine) parseTrustedProxies() error { trustedCIDRs, err := engine.prepareTrustedCIDRs() engine.trustedCIDRs = trustedCIDRs @@ -415,9 +432,9 @@ func (engine *Engine) RunTLS(addr, certFile, keyFile string) (err error) { debugPrint("Listening and serving HTTPS on %s\n", addr) defer func() { debugPrintError(err) }() - err = engine.parseTrustedProxies() - if err != nil { - return err + if engine.isUnsafeTrustedProxies() { + debugPrint("[WARNING] You trusted all proxies, this is NOT safe. We recommend you to set a value.\n" + + "Please check https://pkg.go.dev/github.com/gin-gonic/gin#readme-don-t-trust-all-proxies for details.") } err = http.ListenAndServeTLS(addr, certFile, keyFile, engine) @@ -431,9 +448,9 @@ func (engine *Engine) RunUnix(file string) (err error) { debugPrint("Listening and serving HTTP on unix:/%s", file) defer func() { debugPrintError(err) }() - err = engine.parseTrustedProxies() - if err != nil { - return err + if engine.isUnsafeTrustedProxies() { + debugPrint("[WARNING] You trusted all proxies, this is NOT safe. We recommend you to set a value.\n" + + "Please check https://pkg.go.dev/github.com/gin-gonic/gin#readme-don-t-trust-all-proxies for details.") } listener, err := net.Listen("unix", file) @@ -454,9 +471,9 @@ func (engine *Engine) RunFd(fd int) (err error) { debugPrint("Listening and serving HTTP on fd@%d", fd) defer func() { debugPrintError(err) }() - err = engine.parseTrustedProxies() - if err != nil { - return err + if engine.isUnsafeTrustedProxies() { + debugPrint("[WARNING] You trusted all proxies, this is NOT safe. We recommend you to set a value.\n" + + "Please check https://pkg.go.dev/github.com/gin-gonic/gin#readme-don-t-trust-all-proxies for details.") } f := os.NewFile(uintptr(fd), fmt.Sprintf("fd@%d", fd)) @@ -475,9 +492,9 @@ func (engine *Engine) RunListener(listener net.Listener) (err error) { debugPrint("Listening and serving HTTP on listener what's bind with address@%s", listener.Addr()) defer func() { debugPrintError(err) }() - err = engine.parseTrustedProxies() - if err != nil { - return err + if engine.isUnsafeTrustedProxies() { + debugPrint("[WARNING] You trusted all proxies, this is NOT safe. We recommend you to set a value.\n" + + "Please check https://pkg.go.dev/github.com/gin-gonic/gin#readme-don-t-trust-all-proxies for details.") } err = http.Serve(listener, engine) @@ -528,7 +545,7 @@ func (engine *Engine) handleHTTPRequest(c *Context) { } root := t[i].root // Find route in tree - value := root.getValue(rPath, c.params, unescape) + value := root.getValue(rPath, c.params, c.skippedNodes, unescape) if value.params != nil { c.Params = *value.params } @@ -556,7 +573,7 @@ func (engine *Engine) handleHTTPRequest(c *Context) { if tree.method == httpMethod { continue } - if value := tree.root.getValue(rPath, nil, unescape); value.handlers != nil { + if value := tree.root.getValue(rPath, nil, c.skippedNodes, unescape); value.handlers != nil { c.handlers = engine.allNoMethod serveError(c, http.StatusMethodNotAllowed, default405Body) return diff --git a/gin_integration_test.go b/gin_integration_test.go index 094c46e8..8c22e7bd 100644 --- a/gin_integration_test.go +++ b/gin_integration_test.go @@ -76,6 +76,12 @@ func TestRunEmpty(t *testing.T) { testRequest(t, "http://localhost:8080/example") } +func TestBadTrustedCIDRs(t *testing.T) { + router := New() + assert.Error(t, router.SetTrustedProxies([]string{"hello/world"})) +} + +/* legacy tests func TestBadTrustedCIDRsForRun(t *testing.T) { os.Setenv("PORT", "") router := New() @@ -143,6 +149,7 @@ func TestBadTrustedCIDRsForRunTLS(t *testing.T) { router.TrustedProxies = []string{"hello/world"} assert.Error(t, router.RunTLS(":8080", "./testdata/certificate/cert.pem", "./testdata/certificate/key.pem")) } +*/ func TestRunTLS(t *testing.T) { router := New() @@ -401,8 +408,13 @@ func TestTreeRunDynamicRouting(t *testing.T) { router.GET("/ab/*xx", func(c *Context) { c.String(http.StatusOK, "/ab/*xx") }) router.GET("/", func(c *Context) { c.String(http.StatusOK, "home") }) router.GET("/:cc", func(c *Context) { c.String(http.StatusOK, "/:cc") }) + router.GET("/c1/:dd/e", func(c *Context) { c.String(http.StatusOK, "/c1/:dd/e") }) + router.GET("/c1/:dd/e1", func(c *Context) { c.String(http.StatusOK, "/c1/:dd/e1") }) + router.GET("/c1/:dd/f1", func(c *Context) { c.String(http.StatusOK, "/c1/:dd/f1") }) + router.GET("/c1/:dd/f2", func(c *Context) { c.String(http.StatusOK, "/c1/:dd/f2") }) router.GET("/:cc/cc", func(c *Context) { c.String(http.StatusOK, "/:cc/cc") }) router.GET("/:cc/:dd/ee", func(c *Context) { c.String(http.StatusOK, "/:cc/:dd/ee") }) + router.GET("/:cc/:dd/f", func(c *Context) { c.String(http.StatusOK, "/:cc/:dd/f") }) router.GET("/:cc/:dd/:ee/ff", func(c *Context) { c.String(http.StatusOK, "/:cc/:dd/:ee/ff") }) router.GET("/:cc/:dd/:ee/:ff/gg", func(c *Context) { c.String(http.StatusOK, "/:cc/:dd/:ee/:ff/gg") }) router.GET("/:cc/:dd/:ee/:ff/:gg/hh", func(c *Context) { c.String(http.StatusOK, "/:cc/:dd/:ee/:ff/:gg/hh") }) @@ -439,6 +451,10 @@ func TestTreeRunDynamicRouting(t *testing.T) { testRequest(t, ts.URL+"/all", "", "/:cc") testRequest(t, ts.URL+"/all/cc", "", "/:cc/cc") testRequest(t, ts.URL+"/a/cc", "", "/:cc/cc") + testRequest(t, ts.URL+"/c1/d/e", "", "/c1/:dd/e") + testRequest(t, ts.URL+"/c1/d/e1", "", "/c1/:dd/e1") + testRequest(t, ts.URL+"/c1/d/ee", "", "/:cc/:dd/ee") + testRequest(t, ts.URL+"/c1/d/f", "", "/:cc/:dd/f") testRequest(t, ts.URL+"/c/d/ee", "", "/:cc/:dd/ee") testRequest(t, ts.URL+"/c/d/e/ff", "", "/:cc/:dd/:ee/ff") testRequest(t, ts.URL+"/c/d/e/f/gg", "", "/:cc/:dd/:ee/:ff/gg") @@ -521,6 +537,12 @@ func TestTreeRunDynamicRouting(t *testing.T) { testRequest(t, ts.URL+"/get/abc/123abf/testss", "", "/get/abc/123abf/:param") testRequest(t, ts.URL+"/get/abc/123abfff/te", "", "/get/abc/123abfff/:param") // 404 not found + testRequest(t, ts.URL+"/c/d/e", "404 Not Found") + testRequest(t, ts.URL+"/c/d/e1", "404 Not Found") + testRequest(t, ts.URL+"/c/d/eee", "404 Not Found") + testRequest(t, ts.URL+"/c1/d/eee", "404 Not Found") + testRequest(t, ts.URL+"/c1/d/e2", "404 Not Found") + testRequest(t, ts.URL+"/cc/dd/ee/ff/gg/hh1", "404 Not Found") testRequest(t, ts.URL+"/a/dd", "404 Not Found") testRequest(t, ts.URL+"/addr/dd/aa", "404 Not Found") testRequest(t, ts.URL+"/something/secondthing/121", "404 Not Found") diff --git a/gin_test.go b/gin_test.go index 678d95f2..21c43d15 100644 --- a/gin_test.go +++ b/gin_test.go @@ -539,19 +539,15 @@ func TestPrepareTrustedCIRDsWith(t *testing.T) { // valid ipv4 cidr { expectedTrustedCIDRs := []*net.IPNet{parseCIDR("0.0.0.0/0")} - r.TrustedProxies = []string{"0.0.0.0/0"} - - trustedCIDRs, err := r.prepareTrustedCIDRs() + err := r.SetTrustedProxies([]string{"0.0.0.0/0"}) assert.NoError(t, err) - assert.Equal(t, expectedTrustedCIDRs, trustedCIDRs) + assert.Equal(t, expectedTrustedCIDRs, r.trustedCIDRs) } // invalid ipv4 cidr { - r.TrustedProxies = []string{"192.168.1.33/33"} - - _, err := r.prepareTrustedCIDRs() + err := r.SetTrustedProxies([]string{"192.168.1.33/33"}) assert.Error(t, err) } @@ -559,19 +555,16 @@ func TestPrepareTrustedCIRDsWith(t *testing.T) { // valid ipv4 address { expectedTrustedCIDRs := []*net.IPNet{parseCIDR("192.168.1.33/32")} - r.TrustedProxies = []string{"192.168.1.33"} - trustedCIDRs, err := r.prepareTrustedCIDRs() + err := r.SetTrustedProxies([]string{"192.168.1.33"}) assert.NoError(t, err) - assert.Equal(t, expectedTrustedCIDRs, trustedCIDRs) + assert.Equal(t, expectedTrustedCIDRs, r.trustedCIDRs) } // invalid ipv4 address { - r.TrustedProxies = []string{"192.168.1.256"} - - _, err := r.prepareTrustedCIDRs() + err := r.SetTrustedProxies([]string{"192.168.1.256"}) assert.Error(t, err) } @@ -579,19 +572,15 @@ func TestPrepareTrustedCIRDsWith(t *testing.T) { // valid ipv6 address { expectedTrustedCIDRs := []*net.IPNet{parseCIDR("2002:0000:0000:1234:abcd:ffff:c0a8:0101/128")} - r.TrustedProxies = []string{"2002:0000:0000:1234:abcd:ffff:c0a8:0101"} - - trustedCIDRs, err := r.prepareTrustedCIDRs() + err := r.SetTrustedProxies([]string{"2002:0000:0000:1234:abcd:ffff:c0a8:0101"}) assert.NoError(t, err) - assert.Equal(t, expectedTrustedCIDRs, trustedCIDRs) + assert.Equal(t, expectedTrustedCIDRs, r.trustedCIDRs) } // invalid ipv6 address { - r.TrustedProxies = []string{"gggg:0000:0000:1234:abcd:ffff:c0a8:0101"} - - _, err := r.prepareTrustedCIDRs() + err := r.SetTrustedProxies([]string{"gggg:0000:0000:1234:abcd:ffff:c0a8:0101"}) assert.Error(t, err) } @@ -599,19 +588,15 @@ func TestPrepareTrustedCIRDsWith(t *testing.T) { // valid ipv6 cidr { expectedTrustedCIDRs := []*net.IPNet{parseCIDR("::/0")} - r.TrustedProxies = []string{"::/0"} - - trustedCIDRs, err := r.prepareTrustedCIDRs() + err := r.SetTrustedProxies([]string{"::/0"}) assert.NoError(t, err) - assert.Equal(t, expectedTrustedCIDRs, trustedCIDRs) + assert.Equal(t, expectedTrustedCIDRs, r.trustedCIDRs) } // invalid ipv6 cidr { - r.TrustedProxies = []string{"gggg:0000:0000:1234:abcd:ffff:c0a8:0101/129"} - - _, err := r.prepareTrustedCIDRs() + err := r.SetTrustedProxies([]string{"gggg:0000:0000:1234:abcd:ffff:c0a8:0101/129"}) assert.Error(t, err) } @@ -623,36 +608,32 @@ func TestPrepareTrustedCIRDsWith(t *testing.T) { parseCIDR("192.168.0.0/16"), parseCIDR("172.16.0.1/32"), } - r.TrustedProxies = []string{ + err := r.SetTrustedProxies([]string{ "::/0", "192.168.0.0/16", "172.16.0.1", - } - - trustedCIDRs, err := r.prepareTrustedCIDRs() + }) assert.NoError(t, err) - assert.Equal(t, expectedTrustedCIDRs, trustedCIDRs) + assert.Equal(t, expectedTrustedCIDRs, r.trustedCIDRs) } // invalid combination { - r.TrustedProxies = []string{ + err := r.SetTrustedProxies([]string{ "::/0", "192.168.0.0/16", "172.16.0.256", - } - _, err := r.prepareTrustedCIDRs() + }) assert.Error(t, err) } // nil value { - r.TrustedProxies = nil - trustedCIDRs, err := r.prepareTrustedCIDRs() + err := r.SetTrustedProxies(nil) - assert.Nil(t, trustedCIDRs) + assert.Nil(t, r.trustedCIDRs) assert.Nil(t, err) } diff --git a/go.mod b/go.mod index bf682593..c25eecf9 100644 --- a/go.mod +++ b/go.mod @@ -5,9 +5,9 @@ go 1.13 require ( github.com/gin-contrib/sse v0.1.0 github.com/go-playground/validator/v10 v10.9.0 - github.com/goccy/go-json v0.7.6 - github.com/json-iterator/go v1.1.11 - github.com/mattn/go-isatty v0.0.13 + github.com/goccy/go-json v0.7.10 + github.com/json-iterator/go v1.1.12 + github.com/mattn/go-isatty v0.0.14 github.com/stretchr/testify v1.7.0 github.com/ugorji/go/codec v1.2.6 google.golang.org/protobuf v1.27.1 diff --git a/go.sum b/go.sum index 351ee253..51497a0b 100644 --- a/go.sum +++ b/go.sum @@ -12,14 +12,14 @@ github.com/go-playground/universal-translator v0.18.0 h1:82dyy6p4OuJq4/CByFNOn/j github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA= github.com/go-playground/validator/v10 v10.9.0 h1:NgTtmN58D0m8+UuxtYmGztBJB7VnPgjj221I1QHci2A= github.com/go-playground/validator/v10 v10.9.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos= -github.com/goccy/go-json v0.7.6 h1:H0wq4jppBQ+9222sk5+hPLL25abZQiRuQ6YPnjO9c+A= -github.com/goccy/go-json v0.7.6/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/goccy/go-json v0.7.10 h1:ulhbuNe1JqE68nMRXXTJRrUu0uhouf0VevLINxQq4Ec= +github.com/goccy/go-json v0.7.10/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/json-iterator/go v1.1.11 h1:uVUAXhF2To8cbw/3xN3pxj6kk7TYKs98NIrTqPlMWAQ= -github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= @@ -30,12 +30,12 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= -github.com/mattn/go-isatty v0.0.13 h1:qdl+GuBjcsKKDco5BsxPJlId98mSWNKqYA+Co0SC1yA= -github.com/mattn/go-isatty v0.0.13/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg= -github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -54,9 +54,9 @@ github.com/ugorji/go/codec v1.2.6/go.mod h1:V6TCNZ4PHqoHGFZuSG1W8nrCzzdgA2DozYxW golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069 h1:siQdpVirKtzPhKl3lZWozZraCFObP8S1v6PRp0bLrtU= golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= diff --git a/middleware_test.go b/middleware_test.go index fca1c530..4b4afd4a 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -118,7 +118,10 @@ func TestMiddlewareNoMethodEnabled(t *testing.T) { func TestMiddlewareNoMethodDisabled(t *testing.T) { signature := "" router := New() + + // NoMethod disabled router.HandleMethodNotAllowed = false + router.Use(func(c *Context) { signature += "A" c.Next() @@ -144,6 +147,7 @@ func TestMiddlewareNoMethodDisabled(t *testing.T) { router.POST("/", func(c *Context) { signature += " XX " }) + // RUN w := performRequest(router, "GET", "/") diff --git a/recovery.go b/recovery.go index 3101fe28..39f13551 100644 --- a/recovery.go +++ b/recovery.go @@ -6,6 +6,7 @@ package gin import ( "bytes" + "errors" "fmt" "io" "io/ioutil" @@ -60,7 +61,8 @@ func CustomRecoveryWithWriter(out io.Writer, handle RecoveryFunc) HandlerFunc { // condition that warrants a panic stack trace. var brokenPipe bool if ne, ok := err.(*net.OpError); ok { - if se, ok := ne.Err.(*os.SyscallError); ok { + var se *os.SyscallError + if errors.As(ne, &se) { if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") { brokenPipe = true } diff --git a/render/render_test.go b/render/render_test.go index 22e7d5ab..e417731a 100644 --- a/render/render_test.go +++ b/render/render_test.go @@ -14,10 +14,9 @@ import ( "strings" "testing" + testdata "github.com/gin-gonic/gin/testdata/protoexample" "github.com/stretchr/testify/assert" "google.golang.org/protobuf/proto" - - testdata "github.com/gin-gonic/gin/testdata/protoexample" ) // TODO unit tests diff --git a/routergroup.go b/routergroup.go index bb24bd52..27d7aad6 100644 --- a/routergroup.go +++ b/routergroup.go @@ -14,6 +14,13 @@ import ( var ( // reg match english letters for http method name regEnLetter = regexp.MustCompile("^[A-Z]+$") + + // anyMethods for RouterGroup Any method + anyMethods = []string{ + http.MethodGet, http.MethodPost, http.MethodPut, http.MethodPatch, + http.MethodHead, http.MethodOptions, http.MethodDelete, http.MethodConnect, + http.MethodTrace, + } ) // IRouter defines all router handle interface includes single and group router. @@ -136,15 +143,10 @@ func (group *RouterGroup) HEAD(relativePath string, handlers ...HandlerFunc) IRo // Any registers a route that matches all the HTTP methods. // GET, POST, PUT, PATCH, HEAD, OPTIONS, DELETE, CONNECT, TRACE. func (group *RouterGroup) Any(relativePath string, handlers ...HandlerFunc) IRoutes { - group.handle(http.MethodGet, relativePath, handlers) - group.handle(http.MethodPost, relativePath, handlers) - group.handle(http.MethodPut, relativePath, handlers) - group.handle(http.MethodPatch, relativePath, handlers) - group.handle(http.MethodHead, relativePath, handlers) - group.handle(http.MethodOptions, relativePath, handlers) - group.handle(http.MethodDelete, relativePath, handlers) - group.handle(http.MethodConnect, relativePath, handlers) - group.handle(http.MethodTrace, relativePath, handlers) + for _, method := range anyMethods { + group.handle(method, relativePath, handlers) + } + return group.returnObj() } diff --git a/tree.go b/tree.go index b23ccfd5..a0d98352 100644 --- a/tree.go +++ b/tree.go @@ -17,6 +17,7 @@ import ( var ( strColon = []byte(":") strStar = []byte("*") + strSlash = []byte("/") ) // Param is a single URL parameter, consisting of a key and a value. @@ -98,6 +99,11 @@ func countParams(path string) uint16 { return n } +func countSections(path string) uint16 { + s := bytesconv.StringToBytes(path) + return uint16(bytes.Count(s, strSlash)) +} + type nodeType uint8 const ( @@ -393,16 +399,19 @@ type nodeValue struct { fullPath string } +type skippedNode struct { + path string + node *node + paramsCount int16 +} + // Returns the handle registered with the given path (key). The values of // wildcards are saved to a map. // If no handle can be found, a TSR (trailing slash redirect) recommendation is // made if a handle exists with an extra (without the) trailing slash for the // given path. -func (n *node) getValue(path string, params *Params, unescape bool) (value nodeValue) { - var ( - skippedPath string - latestNode = n // Caching the latest node - ) +func (n *node) getValue(path string, params *Params, skippedNodes *[]skippedNode, unescape bool) (value nodeValue) { + var globalParamsCount int16 walk: // Outer loop for walking the tree for { @@ -417,15 +426,20 @@ walk: // Outer loop for walking the tree if c == idxc { // strings.HasPrefix(n.children[len(n.children)-1].path, ":") == n.wildChild if n.wildChild { - skippedPath = prefix + path - latestNode = &node{ - path: n.path, - wildChild: n.wildChild, - nType: n.nType, - priority: n.priority, - children: n.children, - handlers: n.handlers, - fullPath: n.fullPath, + index := len(*skippedNodes) + *skippedNodes = (*skippedNodes)[:index+1] + (*skippedNodes)[index] = skippedNode{ + path: prefix + path, + node: &node{ + path: n.path, + wildChild: n.wildChild, + nType: n.nType, + priority: n.priority, + children: n.children, + handlers: n.handlers, + fullPath: n.fullPath, + }, + paramsCount: globalParamsCount, } } @@ -434,10 +448,22 @@ walk: // Outer loop for walking the tree } } // If the path at the end of the loop is not equal to '/' and the current node has no child nodes - // the current node needs to be equal to the latest matching node - matched := path != "/" && !n.wildChild - if matched { - n = latestNode + // the current node needs to roll back to last vaild skippedNode + + if path != "/" && !n.wildChild { + for l := len(*skippedNodes); l > 0; { + skippedNode := (*skippedNodes)[l-1] + *skippedNodes = (*skippedNodes)[:l-1] + if strings.HasSuffix(skippedNode.path, path) { + path = skippedNode.path + n = skippedNode.node + if value.params != nil { + *value.params = (*value.params)[:skippedNode.paramsCount] + } + globalParamsCount = skippedNode.paramsCount + continue walk + } + } } // If there is no wildcard pattern, recommend a redirection @@ -451,18 +477,12 @@ walk: // Outer loop for walking the tree // Handle wildcard child, which is always at the end of the array n = n.children[len(n.children)-1] + globalParamsCount++ switch n.nType { case param: // fix truncate the parameter // tree_test.go line: 204 - if matched { - path = prefix + path - // The saved path is used after the prefix route is intercepted by matching - if n.indices == "/" { - path = skippedPath[1:] - } - } // Find param end (either '/' or path end) end := 0 @@ -548,9 +568,22 @@ walk: // Outer loop for walking the tree if path == prefix { // If the current path does not equal '/' and the node does not have a registered handle and the most recently matched node has a child node - // the current node needs to be equal to the latest matching node - if latestNode.wildChild && n.handlers == nil && path != "/" { - n = latestNode.children[len(latestNode.children)-1] + // the current node needs to roll back to last vaild skippedNode + if n.handlers == nil && path != "/" { + for l := len(*skippedNodes); l > 0; { + skippedNode := (*skippedNodes)[l-1] + *skippedNodes = (*skippedNodes)[:l-1] + if strings.HasSuffix(skippedNode.path, path) { + path = skippedNode.path + n = skippedNode.node + if value.params != nil { + *value.params = (*value.params)[:skippedNode.paramsCount] + } + globalParamsCount = skippedNode.paramsCount + continue walk + } + } + // n = latestNode.children[len(latestNode.children)-1] } // We should have reached the node containing the handle. // Check if this node has a handle registered. @@ -581,19 +614,21 @@ walk: // Outer loop for walking the tree return } - if path != "/" && len(skippedPath) > 0 && strings.HasSuffix(skippedPath, path) { - path = skippedPath - // Reduce the number of cycles - n, latestNode = latestNode, n - // skippedPath cannot execute - // example: - // * /:cc/cc - // call /a/cc expectations:match/200 Actual:match/200 - // call /a/dd expectations:unmatch/404 Actual: panic - // call /addr/dd/aa expectations:unmatch/404 Actual: panic - // skippedPath: It can only be executed if the secondary route is not found - skippedPath = "" - continue walk + // roll back to last vaild skippedNode + if path != "/" { + for l := len(*skippedNodes); l > 0; { + skippedNode := (*skippedNodes)[l-1] + *skippedNodes = (*skippedNodes)[:l-1] + if strings.HasSuffix(skippedNode.path, path) { + path = skippedNode.path + n = skippedNode.node + if value.params != nil { + *value.params = (*value.params)[:skippedNode.paramsCount] + } + globalParamsCount = skippedNode.paramsCount + continue walk + } + } } // Nothing found. We can recommend to redirect to the same URL with an diff --git a/tree_test.go b/tree_test.go index cbb37340..49b3b57e 100644 --- a/tree_test.go +++ b/tree_test.go @@ -33,6 +33,11 @@ func getParams() *Params { return &ps } +func getSkippedNodes() *[]skippedNode { + ps := make([]skippedNode, 0, 20) + return &ps +} + func checkRequests(t *testing.T, tree *node, requests testRequests, unescapes ...bool) { unescape := false if len(unescapes) >= 1 { @@ -40,7 +45,7 @@ func checkRequests(t *testing.T, tree *node, requests testRequests, unescapes .. } for _, request := range requests { - value := tree.getValue(request.path, getParams(), unescape) + value := tree.getValue(request.path, getParams(), getSkippedNodes(), unescape) if value.handlers == nil { if !request.nilHandler { @@ -157,6 +162,8 @@ func TestTreeWildcard(t *testing.T) { "/aa/*xx", "/ab/*xx", "/:cc", + "/c1/:dd/e", + "/c1/:dd/e1", "/:cc/cc", "/:cc/:dd/ee", "/:cc/:dd/:ee/ff", @@ -238,6 +245,9 @@ func TestTreeWildcard(t *testing.T) { {"/alldd", false, "/:cc", Params{Param{Key: "cc", Value: "alldd"}}}, {"/all/cc", false, "/:cc/cc", Params{Param{Key: "cc", Value: "all"}}}, {"/a/cc", false, "/:cc/cc", Params{Param{Key: "cc", Value: "a"}}}, + {"/c1/d/e", false, "/c1/:dd/e", Params{Param{Key: "dd", Value: "d"}}}, + {"/c1/d/e1", false, "/c1/:dd/e1", Params{Param{Key: "dd", Value: "d"}}}, + {"/c1/d/ee", false, "/:cc/:dd/ee", Params{Param{Key: "cc", Value: "c1"}, Param{Key: "dd", Value: "d"}}}, {"/cc/cc", false, "/:cc/cc", Params{Param{Key: "cc", Value: "cc"}}}, {"/ccc/cc", false, "/:cc/cc", Params{Param{Key: "cc", Value: "ccc"}}}, {"/deedwjfs/cc", false, "/:cc/cc", Params{Param{Key: "cc", Value: "deedwjfs"}}}, @@ -437,7 +447,7 @@ func TestTreeChildConflict(t *testing.T) { testRoutes(t, routes) } -func TestTreeDupliatePath(t *testing.T) { +func TestTreeDuplicatePath(t *testing.T) { tree := &node{} routes := [...]string{ @@ -605,7 +615,7 @@ func TestTreeTrailingSlashRedirect(t *testing.T) { "/doc/", } for _, route := range tsrRoutes { - value := tree.getValue(route, nil, false) + value := tree.getValue(route, nil, getSkippedNodes(), false) if value.handlers != nil { t.Fatalf("non-nil handler for TSR route '%s", route) } else if !value.tsr { @@ -622,7 +632,7 @@ func TestTreeTrailingSlashRedirect(t *testing.T) { "/api/world/abc", } for _, route := range noTsrRoutes { - value := tree.getValue(route, nil, false) + value := tree.getValue(route, nil, getSkippedNodes(), false) if value.handlers != nil { t.Fatalf("non-nil handler for No-TSR route '%s", route) } else if value.tsr { @@ -641,7 +651,7 @@ func TestTreeRootTrailingSlashRedirect(t *testing.T) { t.Fatalf("panic inserting test route: %v", recv) } - value := tree.getValue("/", nil, false) + value := tree.getValue("/", nil, getSkippedNodes(), false) if value.handlers != nil { t.Fatalf("non-nil handler") } else if value.tsr { @@ -821,7 +831,7 @@ func TestTreeInvalidNodeType(t *testing.T) { // normal lookup recv := catchPanic(func() { - tree.getValue("/test", nil, false) + tree.getValue("/test", nil, getSkippedNodes(), false) }) if rs, ok := recv.(string); !ok || rs != panicMsg { t.Fatalf("Expected panic '"+panicMsg+"', got '%v'", recv) @@ -846,7 +856,7 @@ func TestTreeInvalidParamsType(t *testing.T) { params := make(Params, 0) // try to trigger slice bounds out of range with capacity 0 - tree.getValue("/test", ¶ms, false) + tree.getValue("/test", ¶ms, getSkippedNodes(), false) } func TestTreeWildcardConflictEx(t *testing.T) {