diff --git a/.github/workflows/gin.yml b/.github/workflows/gin.yml index 75e6d05d..1062252a 100644 --- a/.github/workflows/gin.yml +++ b/.github/workflows/gin.yml @@ -15,24 +15,33 @@ jobs: lint: runs-on: ubuntu-latest steps: - - name: Setup go + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Set up Go uses: actions/setup-go@v5 with: - go-version: "^1.18" - - name: Checkout repository - uses: actions/checkout@v4 + go-version: "^1" - name: Setup golangci-lint - uses: golangci/golangci-lint-action@v3.7.0 + uses: golangci/golangci-lint-action@v6 with: - version: v1.55.2 + version: v1.61.0 args: --verbose test: needs: lint strategy: matrix: os: [ubuntu-latest, macos-latest] - go: ["1.18", "1.19", "1.20", "1.21"] - test-tags: ["", "-tags nomsgpack", '-tags "sonic avx"', "-tags go_json"] + go: ["1.23", "1.24"] + test-tags: + [ + "", + "-tags nomsgpack", + '--ldflags="-checklinkname=0" -tags "sonic avx"', + "-tags go_json", + "-race", + ] include: - os: ubuntu-latest go-build: ~/.cache/go-build @@ -72,7 +81,3 @@ jobs: uses: codecov/codecov-action@v4 with: flags: ${{ matrix.os }},go-${{ matrix.go }},${{ matrix.test-tags }} - - - name: Format - if: matrix.go-version == '1.21.x' - run: diff -u <(echo -n) <(gofmt -d .) diff --git a/.github/workflows/goreleaser.yml b/.github/workflows/goreleaser.yml index cbd5d418..22edf453 100644 --- a/.github/workflows/goreleaser.yml +++ b/.github/workflows/goreleaser.yml @@ -16,14 +16,12 @@ jobs: uses: actions/checkout@v4 with: fetch-depth: 0 - - - name: Set up Go + - name: Set up Go uses: actions/setup-go@v5 with: go-version: "^1" - - - name: Run GoReleaser - uses: goreleaser/goreleaser-action@v5 + - name: Run GoReleaser + uses: goreleaser/goreleaser-action@v6 with: # either 'goreleaser' (default) or 'goreleaser-pro' distribution: goreleaser diff --git a/.gitignore b/.gitignore index bdd50c95..1ea0e2b9 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,7 @@ count.out test profile.out tmp.out + +# Develop tools +.idea/ +.vscode/ diff --git a/.golangci.yml b/.golangci.yml index 4a72f734..925e1306 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -7,7 +7,7 @@ linters: - durationcheck - errcheck - errorlint - - exportloopref + - copyloopvar - gci - gofmt - goimports @@ -16,7 +16,10 @@ linters: - nakedret - nilerr - nolintlint + - perfsprint - revive + - testifylint + - usestdlibvars - wastedassign linters-settings: @@ -33,6 +36,14 @@ linters-settings: - G112 - G201 - G203 + perfsprint: + err-error: true + errorf: true + int-conversion: true + sprintf1: true + strconcat: true + testifylint: + enable-all: true issues: exclude-rules: @@ -55,3 +66,6 @@ issues: - linters: - revive path: _test\.go + - path: gin.go + linters: + - gci diff --git a/.goreleaser.yaml b/.goreleaser.yaml index e435e56a..99b66fee 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -1,8 +1,7 @@ project_name: gin builds: - - - # If true, skip the build. + - # If true, skip the build. # Useful for library projects. # Default is false skip: true @@ -10,7 +9,7 @@ builds: changelog: # Set it to true if you wish to skip the changelog generation. # This may result in an empty release notes on GitHub/GitLab/Gitea. - skip: false + disable: false # Changelog generation implementation to use. # @@ -21,7 +20,7 @@ changelog: # - `github-native`: uses the GitHub release notes generation API, disables the groups feature. # # Defaults to `git`. - use: git + use: github # Sorts the changelog by the commit's messages. # Could either be asc, desc or empty @@ -38,20 +37,20 @@ changelog: - title: Features regexp: "^.*feat[(\\w)]*:+.*$" order: 0 - - title: 'Bug fixes' + - title: "Bug fixes" regexp: "^.*fix[(\\w)]*:+.*$" order: 1 - - title: 'Enhancements' + - title: "Enhancements" regexp: "^.*chore[(\\w)]*:+.*$" order: 2 + - title: "Refactor" + regexp: "^.*refactor[(\\w)]*:+.*$" + order: 3 + - title: "Build process updates" + regexp: ^.*?(build|ci)(\(.+\))??!?:.+$ + order: 4 + - title: "Documentation updates" + regexp: ^.*?docs?(\(.+\))??!?:.+$ + order: 4 - title: Others order: 999 - - filters: - # Commit messages matching the regexp listed here will be removed from - # the changelog - # Default is empty - exclude: - - '^docs' - - 'CICD' - - typo diff --git a/CHANGELOG.md b/CHANGELOG.md index 79685205..5648902d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,8 +1,82 @@ # Gin ChangeLog +## Gin v1.10.0 + +### Features + +* feat(auth): add proxy-server authentication (#3877) (@EndlessParadox1) +* feat(bind): ShouldBindBodyWith shortcut and change doc (#3871) (@RedCrazyGhost) +* feat(binding): Support custom BindUnmarshaler for binding. (#3933) (@dkkb) +* feat(binding): support override default binding implement (#3514) (@ssfyn) +* feat(engine): Added `OptionFunc` and `With` (#3572) (@flc1125) +* feat(logger): ability to skip logs based on user-defined logic (#3593) (@palvaneh) + +### Bug fixes + +* Revert "fix(uri): query binding bug (#3236)" (#3899) (@appleboy) +* fix(binding): binding error while not upload file (#3819) (#3820) (@clearcodecn) +* fix(binding): dereference pointer to struct (#3199) (@echovl) +* fix(context): make context Value method adhere to Go standards (#3897) (@FarmerChillax) +* fix(engine): fix unit test (#3878) (@flc1125) +* fix(header): Allow header according to RFC 7231 (HTTP 405) (#3759) (@Crocmagnon) +* fix(route): Add fullPath in context copy (#3784) (@KarthikReddyPuli) +* fix(router): catch-all conflicting wildcard (#3812) (@FirePing32) +* fix(sec): upgrade golang.org/x/crypto to 0.17.0 (#3832) (@chncaption) +* fix(tree): correctly expand the capacity of params (#3502) (@georgijd-form3) +* fix(uri): query binding bug (#3236) (@illiafox) +* fix: Add pointer support for url query params (#3659) (#3666) (@omkar-foss) +* fix: protect Context.Keys map when call Copy method (#3873) (@kingcanfish) + +### Enhancements + +* chore(CI): update release args (#3595) (@qloog) +* chore(IP): add TrustedPlatform constant for Fly.io. (#3839) (@ab) +* chore(debug): add ability to override the debugPrint statement (#2337) (@josegonzalez) +* chore(deps): update dependencies to latest versions (#3835) (@appleboy) +* chore(header): Add support for RFC 9512: application/yaml (#3851) (@vincentbernat) +* chore(http): use white color for HTTP 1XX (#3741) (@viralparmarme) +* chore(optimize): the ShouldBindUri method of the Context struct (#3911) (@1911860538) +* chore(perf): Optimize the Copy method of the Context struct (#3859) (@1911860538) +* chore(refactor): modify interface check way (#3855) (@demoManito) +* chore(request): check reader if it's nil before reading (#3419) (@noahyao1024) +* chore(security): upgrade Protobuf for CVE-2024-24786 (#3893) (@Fotkurz) +* chore: refactor CI and update dependencies (#3848) (@appleboy) +* chore: refactor configuration files for better readability (#3951) (@appleboy) +* chore: update GitHub Actions configuration (#3792) (@appleboy) +* chore: update changelog categories and improve documentation (#3917) (@appleboy) +* chore: update dependencies to latest versions (#3694) (@appleboy) +* chore: update external dependencies to latest versions (#3950) (@appleboy) +* chore: update various Go dependencies to latest versions (#3901) (@appleboy) + +### Build process updates + +* build(codecov): Added a codecov configuration (#3891) (@flc1125) +* ci(Makefile): vet command add .PHONY (#3915) (@imalasong) +* ci(lint): update tooling and workflows for consistency (#3834) (@appleboy) +* ci(release): refactor changelog regex patterns and exclusions (#3914) (@appleboy) +* ci(testing): add go1.22 version (#3842) (@appleboy) + +### Documentation updates + +* docs(context): Added deprecation comments to BindWith (#3880) (@flc1125) +* docs(middleware): comments to function `BasicAuthForProxy` (#3881) (@EndlessParadox1) +* docs: Add document to constant `AuthProxyUserKey` and `BasicAuthForProxy`. (#3887) (@EndlessParadox1) +* docs: fix typo in comment (#3868) (@testwill) +* docs: fix typo in function documentation (#3872) (@TotomiEcio) +* docs: remove redundant comments (#3765) (@WeiTheShinobi) +* feat: update version constant to v1.10.0 (#3952) (@appleboy) + +### Others + +* Upgrade golang.org/x/net -> v0.13.0 (#3684) (@cpcf) +* test(git): gitignore add develop tools (#3370) (@demoManito) +* test(http): use constant instead of numeric literal (#3863) (@testwill) +* test(path): Optimize unit test execution results (#3883) (@flc1125) +* test(render): increased unit tests coverage (#3691) (@araujo88) + ## Gin v1.9.1 -### BUG FIXES +### BUG FIXES * fix Request.Context() checks [#3512](https://github.com/gin-gonic/gin/pull/3512) @@ -414,7 +488,7 @@ - [FIX] Refactor render - [FIX] Reworked tests - [FIX] logger now supports cygwin -- [FIX] Use X-Forwarded-For before X-Real-Ip +- [FIX] Use X-Forwarded-For before X-Real-IP - [FIX] time.Time binding (#904) ## Gin 1.1.4 diff --git a/Makefile b/Makefile index ebde4ee8..1a7de86b 100644 --- a/Makefile +++ b/Makefile @@ -8,6 +8,7 @@ TESTFOLDER := $(shell $(GO) list ./... | grep -E 'gin$$|binding$$|render$$' | gr TESTTAGS ?= "" .PHONY: test +# Run tests to verify code functionality. test: echo "mode: count" > coverage.out for d in $(TESTFOLDER); do \ @@ -30,10 +31,12 @@ test: done .PHONY: fmt +# Ensure consistent code formatting. fmt: $(GOFMT) -w $(GOFILES) .PHONY: fmt-check +# format (check only). fmt-check: @diff=$$($(GOFMT) -d $(GOFILES)); \ if [ -n "$$diff" ]; then \ @@ -42,31 +45,37 @@ fmt-check: exit 1; \ fi; +.PHONY: vet +# Examine packages and report suspicious constructs if any. vet: $(GO) vet $(VETPACKAGES) .PHONY: lint +# Inspect source code for stylistic errors or potential bugs. lint: @hash golint > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ $(GO) get -u golang.org/x/lint/golint; \ fi for PKG in $(PACKAGES); do golint -set_exit_status $$PKG || exit 1; done; -.PHONY: misspell-check -misspell-check: - @hash misspell > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ - $(GO) get -u github.com/client9/misspell/cmd/misspell; \ - fi - misspell -error $(GOFILES) - .PHONY: misspell +# Correct commonly misspelled English words in source code. misspell: @hash misspell > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ $(GO) get -u github.com/client9/misspell/cmd/misspell; \ fi misspell -w $(GOFILES) +.PHONY: misspell-check +# misspell (check only). +misspell-check: + @hash misspell > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ + $(GO) get -u github.com/client9/misspell/cmd/misspell; \ + fi + misspell -error $(GOFILES) + .PHONY: tools +# Install tools (golint and misspell). tools: @if [ $(GO_VERSION) -gt 15 ]; then \ $(GO) install golang.org/x/lint/golint@latest; \ @@ -75,3 +84,23 @@ tools: $(GO) install golang.org/x/lint/golint; \ $(GO) install github.com/client9/misspell/cmd/misspell; \ fi + +.PHONY: help +# Help. +help: + @echo '' + @echo 'Usage:' + @echo ' make [target]' + @echo '' + @echo 'Targets:' + @awk '/^[a-zA-Z\-\0-9]+:/ { \ + helpMessage = match(lastLine, /^# (.*)/); \ + if (helpMessage) { \ + helpCommand = substr($$1, 0, index($$1, ":")-1); \ + helpMessage = substr(lastLine, RSTART + 2, RLENGTH); \ + printf " - \033[36m%-20s\033[0m %s\n", helpCommand, helpMessage; \ + } \ + } \ + { lastLine = $$0 }' $(MAKEFILE_LIST) + +.DEFAULT_GOAL := help diff --git a/README.md b/README.md index e007bf2f..9f548cc0 100644 --- a/README.md +++ b/README.md @@ -5,52 +5,50 @@ [![Build Status](https://github.com/gin-gonic/gin/workflows/Run%20Tests/badge.svg?branch=master)](https://github.com/gin-gonic/gin/actions?query=branch%3Amaster) [![codecov](https://codecov.io/gh/gin-gonic/gin/branch/master/graph/badge.svg)](https://codecov.io/gh/gin-gonic/gin) [![Go Report Card](https://goreportcard.com/badge/github.com/gin-gonic/gin)](https://goreportcard.com/report/github.com/gin-gonic/gin) -[![GoDoc](https://pkg.go.dev/badge/github.com/gin-gonic/gin?status.svg)](https://pkg.go.dev/github.com/gin-gonic/gin?tab=doc) +[![Go Reference](https://pkg.go.dev/badge/github.com/gin-gonic/gin?status.svg)](https://pkg.go.dev/github.com/gin-gonic/gin?tab=doc) [![Sourcegraph](https://sourcegraph.com/github.com/gin-gonic/gin/-/badge.svg)](https://sourcegraph.com/github.com/gin-gonic/gin?badge) [![Open Source Helpers](https://www.codetriage.com/gin-gonic/gin/badges/users.svg)](https://www.codetriage.com/gin-gonic/gin) [![Release](https://img.shields.io/github/release/gin-gonic/gin.svg?style=flat-square)](https://github.com/gin-gonic/gin/releases) [![TODOs](https://badgen.net/https/api.tickgit.com/badgen/github.com/gin-gonic/gin)](https://www.tickgit.com/browse?repo=github.com/gin-gonic/gin) -Gin is a web framework written in [Go](https://go.dev/). It features a martini-like API with performance that is up to 40 times faster thanks to [httprouter](https://github.com/julienschmidt/httprouter). If you need performance and good productivity, you will love Gin. +Gin is a web framework written in [Go](https://go.dev/). It features a martini-like API with performance that is up to 40 times faster thanks to [httprouter](https://github.com/julienschmidt/httprouter). +If you need performance and good productivity, you will love Gin. -**The key features of Gin are:** +**Gin's key features are:** - Zero allocation router -- Fast +- Speed - Middleware support - Crash-free - JSON validation -- Routes grouping +- Route grouping - Error management -- Rendering built-in -- Extendable - +- Built-in rendering +- Extensible ## Getting started ### Prerequisites -- **[Go](https://go.dev/)**: any one of the **three latest major** [releases](https://go.dev/doc/devel/release) (we test it with these). +Gin requires [Go](https://go.dev/) version [1.23](https://go.dev/doc/devel/release#go1.23.0) or above. ### Getting Gin -With [Go module](https://github.com/golang/go/wiki/Modules) support, simply add the following import +With [Go's module support](https://go.dev/wiki/Modules#how-to-use-modules), `go [build|run|test]` automatically fetches the necessary dependencies when you add the import in your code: -``` +```sh import "github.com/gin-gonic/gin" ``` -to your code, and then `go [build|run|test]` will automatically fetch the necessary dependencies. - -Otherwise, run the following Go command to install the `gin` package: +Alternatively, use `go get`: ```sh -$ go get -u github.com/gin-gonic/gin +go get -u github.com/gin-gonic/gin ``` ### Running Gin -First you need to import Gin package for using Gin, one simplest example likes the follow `example.go`: +A basic example: ```go package main @@ -72,31 +70,31 @@ func main() { } ``` -And use the Go command to run the demo: +To run the code, use the `go run` command, like: -``` -# run example.go and visit 0.0.0.0:8080/ping on browser -$ go run example.go +```sh +go run example.go ``` -### Learn more examples +Then visit [`0.0.0.0:8080/ping`](http://0.0.0.0:8080/ping) in your browser to see the response! + +### See more examples #### Quick Start -Learn and practice more examples, please read the [Gin Quick Start](docs/doc.md) which includes API examples and builds tag. +Learn and practice with the [Gin Quick Start](docs/doc.md), which includes API examples and builds tag. #### Examples -A number of ready-to-run examples demonstrating various use cases of Gin on the [Gin examples](https://github.com/gin-gonic/examples) repository. - +A number of ready-to-run examples demonstrating various use cases of Gin are available in the [Gin examples](https://github.com/gin-gonic/examples) repository. ## Documentation -See [API documentation and descriptions](https://godoc.org/github.com/gin-gonic/gin) for package. +See the [API documentation on go.dev](https://pkg.go.dev/github.com/gin-gonic/gin). -All documentation is available on the Gin website. +The documentation is also available on [gin-gonic.com](https://gin-gonic.com) in several languages: -- [English](https://gin-gonic.com/docs/) +- [English](https://gin-gonic.com/en/docs/) - [简体中文](https://gin-gonic.com/zh-cn/docs/) - [繁體中文](https://gin-gonic.com/zh-tw/docs/) - [日本語](https://gin-gonic.com/ja/docs/) @@ -104,19 +102,19 @@ All documentation is available on the Gin website. - [한국어](https://gin-gonic.com/ko-kr/docs/) - [Turkish](https://gin-gonic.com/tr/docs/) - [Persian](https://gin-gonic.com/fa/docs/) +- [Português](https://gin-gonic.com/pt/docs/) +- [Russian](https://gin-gonic.com/ru/docs/) -### Articles about Gin - -A curated list of awesome Gin framework. +### Articles - [Tutorial: Developing a RESTful API with Go and Gin](https://go.dev/doc/tutorial/web-service-gin) ## Benchmarks -Gin uses a custom version of [HttpRouter](https://github.com/julienschmidt/httprouter), [see all benchmarks details](/BENCHMARKS.md). +Gin uses a custom version of [HttpRouter](https://github.com/julienschmidt/httprouter), [see all benchmarks](/BENCHMARKS.md). | Benchmark name | (1) | (2) | (3) | (4) | -| ------------------------------ | ---------:| ---------------:| ------------:| ---------------:| +| ------------------------------ | --------: | --------------: | -----------: | --------------: | | BenchmarkGin_GithubAll | **43550** | **27364 ns/op** | **0 B/op** | **0 allocs/op** | | BenchmarkAce_GithubAll | 40543 | 29670 ns/op | 0 B/op | 0 allocs/op | | BenchmarkAero_GithubAll | 57632 | 20648 ns/op | 0 B/op | 0 allocs/op | @@ -153,26 +151,23 @@ Gin uses a custom version of [HttpRouter](https://github.com/julienschmidt/httpr - (3): Heap Memory (B/op), lower is better - (4): Average Allocations per Repetition (allocs/op), lower is better - -## Middlewares +## Middleware You can find many useful Gin middlewares at [gin-contrib](https://github.com/gin-contrib). +## Uses -## Users - -Awesome project lists using [Gin](https://github.com/gin-gonic/gin) web framework. - -* [gorush](https://github.com/appleboy/gorush): A push notification server written in Go. -* [fnproject](https://github.com/fnproject/fn): The container native, cloud agnostic serverless platform. -* [photoprism](https://github.com/photoprism/photoprism): Personal photo management powered by Go and Google TensorFlow. -* [lura](https://github.com/luraproject/lura): Ultra performant API Gateway with middlewares. -* [picfit](https://github.com/thoas/picfit): An image resizing server written in Go. -* [dkron](https://github.com/distribworks/dkron): Distributed, fault tolerant job scheduling system. +Here are some awesome projects that are using the [Gin](https://github.com/gin-gonic/gin) web framework. +- [gorush](https://github.com/appleboy/gorush): A push notification server. +- [fnproject](https://github.com/fnproject/fn): A container native, cloud agnostic serverless platform. +- [photoprism](https://github.com/photoprism/photoprism): Personal photo management powered by Google TensorFlow. +- [lura](https://github.com/luraproject/lura): Ultra performant API Gateway with middleware. +- [picfit](https://github.com/thoas/picfit): An image resizing server. +- [dkron](https://github.com/distribworks/dkron): Distributed, fault tolerant job scheduling system. ## Contributing Gin is the work of hundreds of contributors. We appreciate your help! -Please see [CONTRIBUTING](CONTRIBUTING.md) for details on submitting patches and the contribution workflow. +Please see [CONTRIBUTING.md](CONTRIBUTING.md) for details on submitting patches and the contribution workflow. diff --git a/auth.go b/auth.go index 2503c515..5d3222d5 100644 --- a/auth.go +++ b/auth.go @@ -16,6 +16,9 @@ import ( // AuthUserKey is the cookie name for user credential in basic auth. const AuthUserKey = "user" +// AuthProxyUserKey is the cookie name for proxy_user credential in basic auth for proxy. +const AuthProxyUserKey = "proxy_user" + // Accounts defines a key/value for user/pass list of authorized logins. type Accounts map[string]string @@ -89,3 +92,25 @@ func authorizationHeader(user, password string) string { base := user + ":" + password return "Basic " + base64.StdEncoding.EncodeToString(bytesconv.StringToBytes(base)) } + +// BasicAuthForProxy returns a Basic HTTP Proxy-Authorization middleware. +// If the realm is empty, "Proxy Authorization Required" will be used by default. +func BasicAuthForProxy(accounts Accounts, realm string) HandlerFunc { + if realm == "" { + realm = "Proxy Authorization Required" + } + realm = "Basic realm=" + strconv.Quote(realm) + pairs := processAccounts(accounts) + return func(c *Context) { + proxyUser, found := pairs.searchCredential(c.requestHeader("Proxy-Authorization")) + if !found { + // Credentials doesn't match, we return 407 and abort handlers chain. + c.Header("Proxy-Authenticate", realm) + c.AbortWithStatus(http.StatusProxyAuthRequired) + return + } + // The proxy_user credentials was found, set proxy_user's id to key AuthProxyUserKey in this context, the proxy_user's id can be read later using + // c.MustGet(gin.AuthProxyUserKey). + c.Set(AuthProxyUserKey, proxyUser) + } +} diff --git a/auth_test.go b/auth_test.go index 42b6f8fd..9166e3b0 100644 --- a/auth_test.go +++ b/auth_test.go @@ -90,7 +90,7 @@ func TestBasicAuthSucceed(t *testing.T) { }) w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/login", nil) + req, _ := http.NewRequest(http.MethodGet, "/login", nil) req.Header.Set("Authorization", authorizationHeader("admin", "password")) router.ServeHTTP(w, req) @@ -109,7 +109,7 @@ func TestBasicAuth401(t *testing.T) { }) w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/login", nil) + req, _ := http.NewRequest(http.MethodGet, "/login", nil) req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password"))) router.ServeHTTP(w, req) @@ -129,7 +129,7 @@ func TestBasicAuth401WithCustomRealm(t *testing.T) { }) w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/login", nil) + req, _ := http.NewRequest(http.MethodGet, "/login", nil) req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password"))) router.ServeHTTP(w, req) @@ -137,3 +137,40 @@ func TestBasicAuth401WithCustomRealm(t *testing.T) { assert.Equal(t, http.StatusUnauthorized, w.Code) assert.Equal(t, "Basic realm=\"My Custom \\\"Realm\\\"\"", w.Header().Get("WWW-Authenticate")) } + +func TestBasicAuthForProxySucceed(t *testing.T) { + accounts := Accounts{"admin": "password"} + router := New() + router.Use(BasicAuthForProxy(accounts, "")) + router.Any("/*proxyPath", func(c *Context) { + c.String(http.StatusOK, c.MustGet(AuthProxyUserKey).(string)) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("Proxy-Authorization", authorizationHeader("admin", "password")) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "admin", w.Body.String()) +} + +func TestBasicAuthForProxy407(t *testing.T) { + called := false + accounts := Accounts{"foo": "bar"} + router := New() + router.Use(BasicAuthForProxy(accounts, "")) + router.Any("/*proxyPath", func(c *Context) { + called = true + c.String(http.StatusOK, c.MustGet(AuthProxyUserKey).(string)) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("Proxy-Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password"))) + router.ServeHTTP(w, req) + + assert.False(t, called) + assert.Equal(t, http.StatusProxyAuthRequired, w.Code) + assert.Equal(t, "Basic realm=\"Proxy Authorization Required\"", w.Header().Get("Proxy-Authenticate")) +} diff --git a/benchmarks_test.go b/benchmarks_test.go index 5b7929b8..3a8d53f3 100644 --- a/benchmarks_test.go +++ b/benchmarks_test.go @@ -14,21 +14,21 @@ import ( func BenchmarkOneRoute(B *testing.B) { router := New() router.GET("/ping", func(c *Context) {}) - runRequest(B, router, "GET", "/ping") + runRequest(B, router, http.MethodGet, "/ping") } func BenchmarkRecoveryMiddleware(B *testing.B) { router := New() router.Use(Recovery()) router.GET("/", func(c *Context) {}) - runRequest(B, router, "GET", "/") + runRequest(B, router, http.MethodGet, "/") } func BenchmarkLoggerMiddleware(B *testing.B) { router := New() router.Use(LoggerWithWriter(newMockWriter())) router.GET("/", func(c *Context) {}) - runRequest(B, router, "GET", "/") + runRequest(B, router, http.MethodGet, "/") } func BenchmarkManyHandlers(B *testing.B) { @@ -37,7 +37,7 @@ func BenchmarkManyHandlers(B *testing.B) { router.Use(func(c *Context) {}) router.Use(func(c *Context) {}) router.GET("/ping", func(c *Context) {}) - runRequest(B, router, "GET", "/ping") + runRequest(B, router, http.MethodGet, "/ping") } func Benchmark5Params(B *testing.B) { @@ -45,7 +45,7 @@ func Benchmark5Params(B *testing.B) { router := New() router.Use(func(c *Context) {}) router.GET("/param/:param1/:params2/:param3/:param4/:param5", func(c *Context) {}) - runRequest(B, router, "GET", "/param/path/to/parameter/john/12345") + runRequest(B, router, http.MethodGet, "/param/path/to/parameter/john/12345") } func BenchmarkOneRouteJSON(B *testing.B) { @@ -56,7 +56,7 @@ func BenchmarkOneRouteJSON(B *testing.B) { router.GET("/json", func(c *Context) { c.JSON(http.StatusOK, data) }) - runRequest(B, router, "GET", "/json") + runRequest(B, router, http.MethodGet, "/json") } func BenchmarkOneRouteHTML(B *testing.B) { @@ -68,7 +68,7 @@ func BenchmarkOneRouteHTML(B *testing.B) { router.GET("/html", func(c *Context) { c.HTML(http.StatusOK, "index", "hola") }) - runRequest(B, router, "GET", "/html") + runRequest(B, router, http.MethodGet, "/html") } func BenchmarkOneRouteSet(B *testing.B) { @@ -76,7 +76,7 @@ func BenchmarkOneRouteSet(B *testing.B) { router.GET("/ping", func(c *Context) { c.Set("key", "value") }) - runRequest(B, router, "GET", "/ping") + runRequest(B, router, http.MethodGet, "/ping") } func BenchmarkOneRouteString(B *testing.B) { @@ -84,13 +84,13 @@ func BenchmarkOneRouteString(B *testing.B) { router.GET("/text", func(c *Context) { c.String(http.StatusOK, "this is a plain text") }) - runRequest(B, router, "GET", "/text") + runRequest(B, router, http.MethodGet, "/text") } func BenchmarkManyRoutesFist(B *testing.B) { router := New() router.Any("/ping", func(c *Context) {}) - runRequest(B, router, "GET", "/ping") + runRequest(B, router, http.MethodGet, "/ping") } func BenchmarkManyRoutesLast(B *testing.B) { @@ -103,7 +103,7 @@ func Benchmark404(B *testing.B) { router := New() router.Any("/something", func(c *Context) {}) router.NoRoute(func(c *Context) {}) - runRequest(B, router, "GET", "/ping") + runRequest(B, router, http.MethodGet, "/ping") } func Benchmark404Many(B *testing.B) { @@ -118,7 +118,7 @@ func Benchmark404Many(B *testing.B) { router.GET("/user/:id/:mode", func(c *Context) {}) router.NoRoute(func(c *Context) {}) - runRequest(B, router, "GET", "/viewfake") + runRequest(B, router, http.MethodGet, "/viewfake") } type mockWriter struct { diff --git a/binding/binding.go b/binding/binding.go index 40948529..702d0e82 100644 --- a/binding/binding.go +++ b/binding/binding.go @@ -21,6 +21,7 @@ const ( MIMEMSGPACK = "application/x-msgpack" MIMEMSGPACK2 = "application/msgpack" MIMEYAML = "application/x-yaml" + MIMEYAML2 = "application/yaml" MIMETOML = "application/toml" ) @@ -72,18 +73,19 @@ var Validator StructValidator = &defaultValidator{} // These implement the Binding interface and can be used to bind the data // present in the request to struct instances. var ( - JSON = jsonBinding{} - XML = xmlBinding{} - Form = formBinding{} - Query = queryBinding{} - FormPost = formPostBinding{} - FormMultipart = formMultipartBinding{} - ProtoBuf = protobufBinding{} - MsgPack = msgpackBinding{} - YAML = yamlBinding{} - Uri = uriBinding{} - Header = headerBinding{} - TOML = tomlBinding{} + JSON BindingBody = jsonBinding{} + XML BindingBody = xmlBinding{} + Form Binding = formBinding{} + Query Binding = queryBinding{} + FormPost Binding = formPostBinding{} + FormMultipart Binding = formMultipartBinding{} + ProtoBuf BindingBody = protobufBinding{} + MsgPack BindingBody = msgpackBinding{} + YAML BindingBody = yamlBinding{} + Uri BindingUri = uriBinding{} + Header Binding = headerBinding{} + Plain BindingBody = plainBinding{} + TOML BindingBody = tomlBinding{} ) // Default returns the appropriate Binding instance based on the HTTP method @@ -102,7 +104,7 @@ func Default(method, contentType string) Binding { return ProtoBuf case MIMEMSGPACK, MIMEMSGPACK2: return MsgPack - case MIMEYAML: + case MIMEYAML, MIMEYAML2: return YAML case MIMETOML: return TOML diff --git a/binding/binding_msgpack_test.go b/binding/binding_msgpack_test.go index a6cd6aa8..7a5db34b 100644 --- a/binding/binding_msgpack_test.go +++ b/binding/binding_msgpack_test.go @@ -8,9 +8,11 @@ package binding import ( "bytes" + "net/http" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/ugorji/go/codec" ) @@ -24,7 +26,7 @@ func TestBindingMsgPack(t *testing.T) { buf := bytes.NewBuffer([]byte{}) assert.NotNil(t, buf) err := codec.NewEncoder(buf, h).Encode(test) - assert.NoError(t, err) + require.NoError(t, err) data := buf.Bytes() @@ -38,20 +40,20 @@ func testMsgPackBodyBinding(t *testing.T, b Binding, name, path, badPath, body, assert.Equal(t, name, b.Name()) obj := FooStruct{} - req := requestWithBody("POST", path, body) + req := requestWithBody(http.MethodPost, path, body) req.Header.Add("Content-Type", MIMEMSGPACK) err := b.Bind(req, &obj) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "bar", obj.Foo) obj = FooStruct{} - req = requestWithBody("POST", badPath, badBody) + req = requestWithBody(http.MethodPost, badPath, badBody) req.Header.Add("Content-Type", MIMEMSGPACK) err = MsgPack.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) } func TestBindingDefaultMsgPack(t *testing.T) { - assert.Equal(t, MsgPack, Default("POST", MIMEMSGPACK)) - assert.Equal(t, MsgPack, Default("PUT", MIMEMSGPACK2)) + assert.Equal(t, MsgPack, Default(http.MethodPost, MIMEMSGPACK)) + assert.Equal(t, MsgPack, Default(http.MethodPut, MIMEMSGPACK2)) } diff --git a/binding/binding_nomsgpack.go b/binding/binding_nomsgpack.go index 93ad8ba3..c8e61310 100644 --- a/binding/binding_nomsgpack.go +++ b/binding/binding_nomsgpack.go @@ -19,6 +19,7 @@ const ( MIMEMultipartPOSTForm = "multipart/form-data" MIMEPROTOBUF = "application/x-protobuf" MIMEYAML = "application/x-yaml" + MIMEYAML2 = "application/yaml" MIMETOML = "application/toml" ) @@ -80,6 +81,7 @@ var ( Uri = uriBinding{} Header = headerBinding{} TOML = tomlBinding{} + Plain = plainBinding{} ) // Default returns the appropriate Binding instance based on the HTTP method @@ -96,7 +98,7 @@ func Default(method, contentType string) Binding { return XML case MIMEPROTOBUF: return ProtoBuf - case MIMEYAML: + case MIMEYAML, MIMEYAML2: return YAML case MIMEMultipartPOSTForm: return FormMultipart diff --git a/binding/binding_test.go b/binding/binding_test.go index 9af4f88a..bdab3694 100644 --- a/binding/binding_test.go +++ b/binding/binding_test.go @@ -20,6 +20,7 @@ import ( "github.com/gin-gonic/gin/testdata/protoexample" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" ) @@ -68,15 +69,19 @@ type FooStructDisallowUnknownFields struct { } type FooBarStructForTimeType struct { - TimeFoo time.Time `form:"time_foo" time_format:"2006-01-02" time_utc:"1" time_location:"Asia/Chongqing"` - TimeBar time.Time `form:"time_bar" time_format:"2006-01-02" time_utc:"1"` - CreateTime time.Time `form:"createTime" time_format:"unixNano"` - UnixTime time.Time `form:"unixTime" time_format:"unix"` + TimeFoo time.Time `form:"time_foo" time_format:"2006-01-02" time_utc:"1" time_location:"Asia/Chongqing"` + TimeBar time.Time `form:"time_bar" time_format:"2006-01-02" time_utc:"1"` + CreateTime time.Time `form:"createTime" time_format:"unixNano"` + UnixTime time.Time `form:"unixTime" time_format:"unix"` + UnixMilliTime time.Time `form:"unixMilliTime" time_format:"unixmilli"` + UnixMicroTime time.Time `form:"unixMicroTime" time_format:"uNiXmiCrO"` } type FooStructForTimeTypeNotUnixFormat struct { - CreateTime time.Time `form:"createTime" time_format:"unixNano"` - UnixTime time.Time `form:"unixTime" time_format:"unix"` + CreateTime time.Time `form:"createTime" time_format:"unixNano"` + UnixTime time.Time `form:"unixTime" time_format:"unix"` + UnixMilliTime time.Time `form:"unixMilliTime" time_format:"unixMilli"` + UnixMicroTime time.Time `form:"unixMicroTime" time_format:"unixMicro"` } type FooStructForTimeTypeNotFormat struct { @@ -144,36 +149,38 @@ type FooStructForMapPtrType struct { } func TestBindingDefault(t *testing.T) { - assert.Equal(t, Form, Default("GET", "")) - assert.Equal(t, Form, Default("GET", MIMEJSON)) + assert.Equal(t, Form, Default(http.MethodGet, "")) + assert.Equal(t, Form, Default(http.MethodGet, MIMEJSON)) - assert.Equal(t, JSON, Default("POST", MIMEJSON)) - assert.Equal(t, JSON, Default("PUT", MIMEJSON)) + assert.Equal(t, JSON, Default(http.MethodPost, MIMEJSON)) + assert.Equal(t, JSON, Default(http.MethodPut, MIMEJSON)) - assert.Equal(t, XML, Default("POST", MIMEXML)) - assert.Equal(t, XML, Default("PUT", MIMEXML2)) + assert.Equal(t, XML, Default(http.MethodPost, MIMEXML)) + assert.Equal(t, XML, Default(http.MethodPut, MIMEXML2)) - assert.Equal(t, Form, Default("POST", MIMEPOSTForm)) - assert.Equal(t, Form, Default("PUT", MIMEPOSTForm)) + assert.Equal(t, Form, Default(http.MethodPost, MIMEPOSTForm)) + assert.Equal(t, Form, Default(http.MethodPut, MIMEPOSTForm)) - assert.Equal(t, FormMultipart, Default("POST", MIMEMultipartPOSTForm)) - assert.Equal(t, FormMultipart, Default("PUT", MIMEMultipartPOSTForm)) + assert.Equal(t, FormMultipart, Default(http.MethodPost, MIMEMultipartPOSTForm)) + assert.Equal(t, FormMultipart, Default(http.MethodPut, MIMEMultipartPOSTForm)) - assert.Equal(t, ProtoBuf, Default("POST", MIMEPROTOBUF)) - assert.Equal(t, ProtoBuf, Default("PUT", MIMEPROTOBUF)) + assert.Equal(t, ProtoBuf, Default(http.MethodPost, MIMEPROTOBUF)) + assert.Equal(t, ProtoBuf, Default(http.MethodPut, MIMEPROTOBUF)) - assert.Equal(t, YAML, Default("POST", MIMEYAML)) - assert.Equal(t, YAML, Default("PUT", MIMEYAML)) + assert.Equal(t, YAML, Default(http.MethodPost, MIMEYAML)) + assert.Equal(t, YAML, Default(http.MethodPut, MIMEYAML)) + assert.Equal(t, YAML, Default(http.MethodPost, MIMEYAML2)) + assert.Equal(t, YAML, Default(http.MethodPut, MIMEYAML2)) - assert.Equal(t, TOML, Default("POST", MIMETOML)) - assert.Equal(t, TOML, Default("PUT", MIMETOML)) + assert.Equal(t, TOML, Default(http.MethodPost, MIMETOML)) + assert.Equal(t, TOML, Default(http.MethodPut, MIMETOML)) } func TestBindingJSONNilBody(t *testing.T) { var obj FooStruct req, _ := http.NewRequest(http.MethodPost, "/", nil) err := JSON.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) } func TestBindingJSON(t *testing.T) { @@ -224,137 +231,137 @@ func TestBindingJSONStringMap(t *testing.T) { } func TestBindingForm(t *testing.T) { - testFormBinding(t, "POST", + testFormBinding(t, http.MethodPost, "/", "/", "foo=bar&bar=foo", "bar2=foo") } func TestBindingForm2(t *testing.T) { - testFormBinding(t, "GET", + testFormBinding(t, http.MethodGet, "/?foo=bar&bar=foo", "/?bar2=foo", "", "") } func TestBindingFormEmbeddedStruct(t *testing.T) { - testFormBindingEmbeddedStruct(t, "POST", + testFormBindingEmbeddedStruct(t, http.MethodPost, "/", "/", "page=1&size=2&appkey=test-appkey", "bar2=foo") } func TestBindingFormEmbeddedStruct2(t *testing.T) { - testFormBindingEmbeddedStruct(t, "GET", + testFormBindingEmbeddedStruct(t, http.MethodGet, "/?page=1&size=2&appkey=test-appkey", "/?bar2=foo", "", "") } func TestBindingFormDefaultValue(t *testing.T) { - testFormBindingDefaultValue(t, "POST", + testFormBindingDefaultValue(t, http.MethodPost, "/", "/", "foo=bar", "bar2=foo") } func TestBindingFormDefaultValue2(t *testing.T) { - testFormBindingDefaultValue(t, "GET", + testFormBindingDefaultValue(t, http.MethodGet, "/?foo=bar", "/?bar2=foo", "", "") } func TestBindingFormForTime(t *testing.T) { - testFormBindingForTime(t, "POST", + testFormBindingForTime(t, http.MethodPost, "/", "/", - "time_foo=2017-11-15&time_bar=&createTime=1562400033000000123&unixTime=1562400033", "bar2=foo") - testFormBindingForTimeNotUnixFormat(t, "POST", + "time_foo=2017-11-15&time_bar=&createTime=1562400033000000123&unixTime=1562400033&unixMilliTime=1562400033001&unixMicroTime=1562400033000012", "bar2=foo") + testFormBindingForTimeNotUnixFormat(t, http.MethodPost, "/", "/", - "time_foo=2017-11-15&createTime=bad&unixTime=bad", "bar2=foo") - testFormBindingForTimeNotFormat(t, "POST", + "time_foo=2017-11-15&createTime=bad&unixTime=bad&unixMilliTime=bad&unixMicroTime=bad", "bar2=foo") + testFormBindingForTimeNotFormat(t, http.MethodPost, "/", "/", "time_foo=2017-11-15", "bar2=foo") - testFormBindingForTimeFailFormat(t, "POST", + testFormBindingForTimeFailFormat(t, http.MethodPost, "/", "/", "time_foo=2017-11-15", "bar2=foo") - testFormBindingForTimeFailLocation(t, "POST", + testFormBindingForTimeFailLocation(t, http.MethodPost, "/", "/", "time_foo=2017-11-15", "bar2=foo") } func TestBindingFormForTime2(t *testing.T) { - testFormBindingForTime(t, "GET", - "/?time_foo=2017-11-15&time_bar=&createTime=1562400033000000123&unixTime=1562400033", "/?bar2=foo", + testFormBindingForTime(t, http.MethodGet, + "/?time_foo=2017-11-15&time_bar=&createTime=1562400033000000123&unixTime=1562400033&unixMilliTime=1562400033001&unixMicroTime=1562400033000012", "/?bar2=foo", "", "") - testFormBindingForTimeNotUnixFormat(t, "POST", + testFormBindingForTimeNotUnixFormat(t, http.MethodPost, "/", "/", - "time_foo=2017-11-15&createTime=bad&unixTime=bad", "bar2=foo") - testFormBindingForTimeNotFormat(t, "GET", + "time_foo=2017-11-15&createTime=bad&unixTime=bad&unixMilliTime=bad&unixMicroTime=bad", "bar2=foo") + testFormBindingForTimeNotFormat(t, http.MethodGet, "/?time_foo=2017-11-15", "/?bar2=foo", "", "") - testFormBindingForTimeFailFormat(t, "GET", + testFormBindingForTimeFailFormat(t, http.MethodGet, "/?time_foo=2017-11-15", "/?bar2=foo", "", "") - testFormBindingForTimeFailLocation(t, "GET", + testFormBindingForTimeFailLocation(t, http.MethodGet, "/?time_foo=2017-11-15", "/?bar2=foo", "", "") } func TestFormBindingIgnoreField(t *testing.T) { - testFormBindingIgnoreField(t, "POST", + testFormBindingIgnoreField(t, http.MethodPost, "/", "/", "-=bar", "") } func TestBindingFormInvalidName(t *testing.T) { - testFormBindingInvalidName(t, "POST", + testFormBindingInvalidName(t, http.MethodPost, "/", "/", "test_name=bar", "bar2=foo") } func TestBindingFormInvalidName2(t *testing.T) { - testFormBindingInvalidName2(t, "POST", + testFormBindingInvalidName2(t, http.MethodPost, "/", "/", "map_foo=bar", "bar2=foo") } func TestBindingFormForType(t *testing.T) { - testFormBindingForType(t, "POST", + testFormBindingForType(t, http.MethodPost, "/", "/", "map_foo={\"bar\":123}", "map_foo=1", "Map") - testFormBindingForType(t, "POST", + testFormBindingForType(t, http.MethodPost, "/", "/", "slice_foo=1&slice_foo=2", "bar2=1&bar2=2", "Slice") - testFormBindingForType(t, "GET", + testFormBindingForType(t, http.MethodGet, "/?slice_foo=1&slice_foo=2", "/?bar2=1&bar2=2", "", "", "Slice") - testFormBindingForType(t, "POST", + testFormBindingForType(t, http.MethodPost, "/", "/", "slice_map_foo=1&slice_map_foo=2", "bar2=1&bar2=2", "SliceMap") - testFormBindingForType(t, "GET", + testFormBindingForType(t, http.MethodGet, "/?slice_map_foo=1&slice_map_foo=2", "/?bar2=1&bar2=2", "", "", "SliceMap") - testFormBindingForType(t, "POST", + testFormBindingForType(t, http.MethodPost, "/", "/", "ptr_bar=test", "bar2=test", "Ptr") - testFormBindingForType(t, "GET", + testFormBindingForType(t, http.MethodGet, "/?ptr_bar=test", "/?bar2=test", "", "", "Ptr") - testFormBindingForType(t, "POST", + testFormBindingForType(t, http.MethodPost, "/", "/", "idx=123", "id1=1", "Struct") - testFormBindingForType(t, "GET", + testFormBindingForType(t, http.MethodGet, "/?idx=123", "/?id1=1", "", "", "Struct") - testFormBindingForType(t, "POST", + testFormBindingForType(t, http.MethodPost, "/", "/", "name=thinkerou", "name1=ou", "StructPointer") - testFormBindingForType(t, "GET", + testFormBindingForType(t, http.MethodGet, "/?name=thinkerou", "/?name1=ou", "", "", "StructPointer") } @@ -371,10 +378,10 @@ func TestBindingFormStringMap(t *testing.T) { func TestBindingFormStringSliceMap(t *testing.T) { obj := make(map[string][]string) - req := requestWithBody("POST", "/", "foo=something&foo=bar&hello=world") + req := requestWithBody(http.MethodPost, "/", "foo=something&foo=bar&hello=world") req.Header.Add("Content-Type", MIMEPOSTForm) err := Form.Bind(req, &obj) - assert.NoError(t, err) + require.NoError(t, err) assert.NotNil(t, obj) assert.Len(t, obj, 2) target := map[string][]string{ @@ -384,38 +391,38 @@ func TestBindingFormStringSliceMap(t *testing.T) { assert.True(t, reflect.DeepEqual(obj, target)) objInvalid := make(map[string][]int) - req = requestWithBody("POST", "/", "foo=something&foo=bar&hello=world") + req = requestWithBody(http.MethodPost, "/", "foo=something&foo=bar&hello=world") req.Header.Add("Content-Type", MIMEPOSTForm) err = Form.Bind(req, &objInvalid) - assert.Error(t, err) + require.Error(t, err) } func TestBindingQuery(t *testing.T) { - testQueryBinding(t, "POST", + testQueryBinding(t, http.MethodPost, "/?foo=bar&bar=foo", "/", "foo=unused", "bar2=foo") } func TestBindingQuery2(t *testing.T) { - testQueryBinding(t, "GET", + testQueryBinding(t, http.MethodGet, "/?foo=bar&bar=foo", "/?bar2=foo", "foo=unused", "") } func TestBindingQueryFail(t *testing.T) { - testQueryBindingFail(t, "POST", + testQueryBindingFail(t, http.MethodPost, "/?map_foo=", "/", "map_foo=unused", "bar2=foo") } func TestBindingQueryFail2(t *testing.T) { - testQueryBindingFail(t, "GET", + testQueryBindingFail(t, http.MethodGet, "/?map_foo=", "/?bar2=foo", "map_foo=unused", "") } func TestBindingQueryBoolFail(t *testing.T) { - testQueryBindingBoolFail(t, "GET", + testQueryBindingBoolFail(t, http.MethodGet, "/?bool_foo=fasl", "/?bar2=foo", "bool_foo=unused", "") } @@ -424,18 +431,18 @@ func TestBindingQueryStringMap(t *testing.T) { b := Query obj := make(map[string]string) - req := requestWithBody("GET", "/?foo=bar&hello=world", "") + req := requestWithBody(http.MethodGet, "/?foo=bar&hello=world", "") err := b.Bind(req, &obj) - assert.NoError(t, err) + require.NoError(t, err) assert.NotNil(t, obj) assert.Len(t, obj, 2) assert.Equal(t, "bar", obj["foo"]) assert.Equal(t, "world", obj["hello"]) obj = make(map[string]string) - req = requestWithBody("GET", "/?foo=bar&foo=2&hello=world", "") // should pick last + req = requestWithBody(http.MethodGet, "/?foo=bar&foo=2&hello=world", "") // should pick last err = b.Bind(req, &obj) - assert.NoError(t, err) + require.NoError(t, err) assert.NotNil(t, obj) assert.Len(t, obj, 2) assert.Equal(t, "2", obj["foo"]) @@ -492,29 +499,29 @@ func TestBindingYAMLFail(t *testing.T) { } func createFormPostRequest(t *testing.T) *http.Request { - req, err := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", bytes.NewBufferString("foo=bar&bar=foo")) - assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, "/?foo=getfoo&bar=getbar", bytes.NewBufferString("foo=bar&bar=foo")) + require.NoError(t, err) req.Header.Set("Content-Type", MIMEPOSTForm) return req } func createDefaultFormPostRequest(t *testing.T) *http.Request { - req, err := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", bytes.NewBufferString("foo=bar")) - assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, "/?foo=getfoo&bar=getbar", bytes.NewBufferString("foo=bar")) + require.NoError(t, err) req.Header.Set("Content-Type", MIMEPOSTForm) return req } func createFormPostRequestForMap(t *testing.T) *http.Request { - req, err := http.NewRequest("POST", "/?map_foo=getfoo", bytes.NewBufferString("map_foo={\"bar\":123}")) - assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, "/?map_foo=getfoo", bytes.NewBufferString("map_foo={\"bar\":123}")) + require.NoError(t, err) req.Header.Set("Content-Type", MIMEPOSTForm) return req } func createFormPostRequestForMapFail(t *testing.T) *http.Request { - req, err := http.NewRequest("POST", "/?map_foo=getfoo", bytes.NewBufferString("map_foo=hello")) - assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, "/?map_foo=getfoo", bytes.NewBufferString("map_foo=hello")) + require.NoError(t, err) req.Header.Set("Content-Type", MIMEPOSTForm) return req } @@ -525,20 +532,20 @@ func createFormFilesMultipartRequest(t *testing.T) *http.Request { mw := multipart.NewWriter(body) defer mw.Close() - assert.NoError(t, mw.SetBoundary(boundary)) - assert.NoError(t, mw.WriteField("foo", "bar")) - assert.NoError(t, mw.WriteField("bar", "foo")) + require.NoError(t, mw.SetBoundary(boundary)) + require.NoError(t, mw.WriteField("foo", "bar")) + require.NoError(t, mw.WriteField("bar", "foo")) f, err := os.Open("form.go") - assert.NoError(t, err) + require.NoError(t, err) defer f.Close() fw, err1 := mw.CreateFormFile("file", "form.go") - assert.NoError(t, err1) + require.NoError(t, err1) _, err = io.Copy(fw, f) - assert.NoError(t, err) + require.NoError(t, err) - req, err2 := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", body) - assert.NoError(t, err2) + req, err2 := http.NewRequest(http.MethodPost, "/?foo=getfoo&bar=getbar", body) + require.NoError(t, err2) req.Header.Set("Content-Type", MIMEMultipartPOSTForm+"; boundary="+boundary) return req @@ -550,20 +557,20 @@ func createFormFilesMultipartRequestFail(t *testing.T) *http.Request { mw := multipart.NewWriter(body) defer mw.Close() - assert.NoError(t, mw.SetBoundary(boundary)) - assert.NoError(t, mw.WriteField("foo", "bar")) - assert.NoError(t, mw.WriteField("bar", "foo")) + require.NoError(t, mw.SetBoundary(boundary)) + require.NoError(t, mw.WriteField("foo", "bar")) + require.NoError(t, mw.WriteField("bar", "foo")) f, err := os.Open("form.go") - assert.NoError(t, err) + require.NoError(t, err) defer f.Close() fw, err1 := mw.CreateFormFile("file_foo", "form_foo.go") - assert.NoError(t, err1) + require.NoError(t, err1) _, err = io.Copy(fw, f) - assert.NoError(t, err) + require.NoError(t, err) - req, err2 := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", body) - assert.NoError(t, err2) + req, err2 := http.NewRequest(http.MethodPost, "/?foo=getfoo&bar=getbar", body) + require.NoError(t, err2) req.Header.Set("Content-Type", MIMEMultipartPOSTForm+"; boundary="+boundary) return req @@ -575,11 +582,11 @@ func createFormMultipartRequest(t *testing.T) *http.Request { mw := multipart.NewWriter(body) defer mw.Close() - assert.NoError(t, mw.SetBoundary(boundary)) - assert.NoError(t, mw.WriteField("foo", "bar")) - assert.NoError(t, mw.WriteField("bar", "foo")) - req, err := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", body) - assert.NoError(t, err) + require.NoError(t, mw.SetBoundary(boundary)) + require.NoError(t, mw.WriteField("foo", "bar")) + require.NoError(t, mw.WriteField("bar", "foo")) + req, err := http.NewRequest(http.MethodPost, "/?foo=getfoo&bar=getbar", body) + require.NoError(t, err) req.Header.Set("Content-Type", MIMEMultipartPOSTForm+"; boundary="+boundary) return req } @@ -590,10 +597,10 @@ func createFormMultipartRequestForMap(t *testing.T) *http.Request { mw := multipart.NewWriter(body) defer mw.Close() - assert.NoError(t, mw.SetBoundary(boundary)) - assert.NoError(t, mw.WriteField("map_foo", "{\"bar\":123, \"name\":\"thinkerou\", \"pai\": 3.14}")) - req, err := http.NewRequest("POST", "/?map_foo=getfoo", body) - assert.NoError(t, err) + require.NoError(t, mw.SetBoundary(boundary)) + require.NoError(t, mw.WriteField("map_foo", "{\"bar\":123, \"name\":\"thinkerou\", \"pai\": 3.14}")) + req, err := http.NewRequest(http.MethodPost, "/?map_foo=getfoo", body) + require.NoError(t, err) req.Header.Set("Content-Type", MIMEMultipartPOSTForm+"; boundary="+boundary) return req } @@ -604,10 +611,10 @@ func createFormMultipartRequestForMapFail(t *testing.T) *http.Request { mw := multipart.NewWriter(body) defer mw.Close() - assert.NoError(t, mw.SetBoundary(boundary)) - assert.NoError(t, mw.WriteField("map_foo", "3.14")) - req, err := http.NewRequest("POST", "/?map_foo=getfoo", body) - assert.NoError(t, err) + require.NoError(t, mw.SetBoundary(boundary)) + require.NoError(t, mw.WriteField("map_foo", "3.14")) + req, err := http.NewRequest(http.MethodPost, "/?map_foo=getfoo", body) + require.NoError(t, err) req.Header.Set("Content-Type", MIMEMultipartPOSTForm+"; boundary="+boundary) return req } @@ -615,7 +622,7 @@ func createFormMultipartRequestForMapFail(t *testing.T) *http.Request { func TestBindingFormPost(t *testing.T) { req := createFormPostRequest(t) var obj FooBarStruct - assert.NoError(t, FormPost.Bind(req, &obj)) + require.NoError(t, FormPost.Bind(req, &obj)) assert.Equal(t, "form-urlencoded", FormPost.Name()) assert.Equal(t, "bar", obj.Foo) @@ -625,7 +632,7 @@ func TestBindingFormPost(t *testing.T) { func TestBindingDefaultValueFormPost(t *testing.T) { req := createDefaultFormPostRequest(t) var obj FooDefaultBarStruct - assert.NoError(t, FormPost.Bind(req, &obj)) + require.NoError(t, FormPost.Bind(req, &obj)) assert.Equal(t, "bar", obj.Foo) assert.Equal(t, "hello", obj.Bar) @@ -635,22 +642,22 @@ func TestBindingFormPostForMap(t *testing.T) { req := createFormPostRequestForMap(t) var obj FooStructForMapType err := FormPost.Bind(req, &obj) - assert.NoError(t, err) - assert.Equal(t, float64(123), obj.MapFoo["bar"].(float64)) + require.NoError(t, err) + assert.InDelta(t, float64(123), obj.MapFoo["bar"].(float64), 0.01) } func TestBindingFormPostForMapFail(t *testing.T) { req := createFormPostRequestForMapFail(t) var obj FooStructForMapType err := FormPost.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) } func TestBindingFormFilesMultipart(t *testing.T) { req := createFormFilesMultipartRequest(t) var obj FooBarFileStruct err := FormMultipart.Bind(req, &obj) - assert.NoError(t, err) + require.NoError(t, err) // file from os f, _ := os.Open("form.go") @@ -662,9 +669,9 @@ func TestBindingFormFilesMultipart(t *testing.T) { defer mf.Close() fileExpect, _ := io.ReadAll(mf) - assert.Equal(t, FormMultipart.Name(), "multipart/form-data") - assert.Equal(t, obj.Foo, "bar") - assert.Equal(t, obj.Bar, "foo") + assert.Equal(t, "multipart/form-data", FormMultipart.Name()) + assert.Equal(t, "bar", obj.Foo) + assert.Equal(t, "foo", obj.Bar) assert.Equal(t, fileExpect, fileActual) } @@ -672,13 +679,13 @@ func TestBindingFormFilesMultipartFail(t *testing.T) { req := createFormFilesMultipartRequestFail(t) var obj FooBarFileFailStruct err := FormMultipart.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) } func TestBindingFormMultipart(t *testing.T) { req := createFormMultipartRequest(t) var obj FooBarStruct - assert.NoError(t, FormMultipart.Bind(req, &obj)) + require.NoError(t, FormMultipart.Bind(req, &obj)) assert.Equal(t, "multipart/form-data", FormMultipart.Name()) assert.Equal(t, "bar", obj.Foo) @@ -689,17 +696,17 @@ func TestBindingFormMultipartForMap(t *testing.T) { req := createFormMultipartRequestForMap(t) var obj FooStructForMapType err := FormMultipart.Bind(req, &obj) - assert.NoError(t, err) - assert.Equal(t, float64(123), obj.MapFoo["bar"].(float64)) + require.NoError(t, err) + assert.InDelta(t, float64(123), obj.MapFoo["bar"].(float64), 0.01) assert.Equal(t, "thinkerou", obj.MapFoo["name"].(string)) - assert.Equal(t, float64(3.14), obj.MapFoo["pai"].(float64)) + assert.InDelta(t, float64(3.14), obj.MapFoo["pai"].(float64), 0.01) } func TestBindingFormMultipartForMapFail(t *testing.T) { req := createFormMultipartRequestForMapFail(t) var obj FooStructForMapType err := FormMultipart.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) } func TestBindingProtoBuf(t *testing.T) { @@ -728,9 +735,9 @@ func TestBindingProtoBufFail(t *testing.T) { func TestValidationFails(t *testing.T) { var obj FooStruct - req := requestWithBody("POST", "/", `{"bar": "foo"}`) + req := requestWithBody(http.MethodPost, "/", `{"bar": "foo"}`) err := JSON.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) } func TestValidationDisabled(t *testing.T) { @@ -739,9 +746,9 @@ func TestValidationDisabled(t *testing.T) { defer func() { Validator = backup }() var obj FooStruct - req := requestWithBody("POST", "/", `{"bar": "foo"}`) + req := requestWithBody(http.MethodPost, "/", `{"bar": "foo"}`) err := JSON.Bind(req, &obj) - assert.NoError(t, err) + require.NoError(t, err) } func TestRequiredSucceeds(t *testing.T) { @@ -750,9 +757,9 @@ func TestRequiredSucceeds(t *testing.T) { } var obj HogeStruct - req := requestWithBody("POST", "/", `{"hoge": 0}`) + req := requestWithBody(http.MethodPost, "/", `{"hoge": 0}`) err := JSON.Bind(req, &obj) - assert.NoError(t, err) + require.NoError(t, err) } func TestRequiredFails(t *testing.T) { @@ -761,9 +768,9 @@ func TestRequiredFails(t *testing.T) { } var obj HogeStruct - req := requestWithBody("POST", "/", `{"boen": 0}`) + req := requestWithBody(http.MethodPost, "/", `{"boen": 0}`) err := JSON.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) } func TestHeaderBinding(t *testing.T) { @@ -775,12 +782,12 @@ func TestHeaderBinding(t *testing.T) { } var theader tHeader - req := requestWithBody("GET", "/", "") + req := requestWithBody(http.MethodGet, "/", "") req.Header.Add("limit", "1000") - assert.NoError(t, h.Bind(req, &theader)) + require.NoError(t, h.Bind(req, &theader)) assert.Equal(t, 1000, theader.Limit) - req = requestWithBody("GET", "/", "") + req = requestWithBody(http.MethodGet, "/", "") req.Header.Add("fail", `{fail:fail}`) type failStruct struct { @@ -788,7 +795,7 @@ func TestHeaderBinding(t *testing.T) { } err := h.Bind(req, &failStruct{}) - assert.Error(t, err) + require.Error(t, err) } func TestUriBinding(t *testing.T) { @@ -801,14 +808,14 @@ func TestUriBinding(t *testing.T) { var tag Tag m := make(map[string][]string) m["name"] = []string{"thinkerou"} - assert.NoError(t, b.BindUri(m, &tag)) + require.NoError(t, b.BindUri(m, &tag)) assert.Equal(t, "thinkerou", tag.Name) type NotSupportStruct struct { Name map[string]any `uri:"name"` } var not NotSupportStruct - assert.Error(t, b.BindUri(m, ¬)) + require.Error(t, b.BindUri(m, ¬)) assert.Equal(t, map[string]any(nil), not.Name) } @@ -829,9 +836,9 @@ func TestUriInnerBinding(t *testing.T) { } var tag Tag - assert.NoError(t, Uri.BindUri(m, &tag)) - assert.Equal(t, tag.Name, expectedName) - assert.Equal(t, tag.S.Age, expectedAge) + require.NoError(t, Uri.BindUri(m, &tag)) + assert.Equal(t, expectedName, tag.Name) + assert.Equal(t, expectedAge, tag.S.Age) } func testFormBindingEmbeddedStruct(t *testing.T, method, path, badPath, body, badBody string) { @@ -840,11 +847,11 @@ func testFormBindingEmbeddedStruct(t *testing.T, method, path, badPath, body, ba obj := QueryTest{} req := requestWithBody(method, path, body) - if method == "POST" { + if method == http.MethodPost { req.Header.Add("Content-Type", MIMEPOSTForm) } err := b.Bind(req, &obj) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, 1, obj.Page) assert.Equal(t, 2, obj.Size) assert.Equal(t, "test-appkey", obj.Appkey) @@ -856,18 +863,18 @@ func testFormBinding(t *testing.T, method, path, badPath, body, badBody string) obj := FooBarStruct{} req := requestWithBody(method, path, body) - if method == "POST" { + if method == http.MethodPost { req.Header.Add("Content-Type", MIMEPOSTForm) } err := b.Bind(req, &obj) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "bar", obj.Foo) assert.Equal(t, "foo", obj.Bar) obj = FooBarStruct{} req = requestWithBody(method, badPath, badBody) err = JSON.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) } func testFormBindingDefaultValue(t *testing.T, method, path, badPath, body, badBody string) { @@ -876,18 +883,18 @@ func testFormBindingDefaultValue(t *testing.T, method, path, badPath, body, badB obj := FooDefaultBarStruct{} req := requestWithBody(method, path, body) - if method == "POST" { + if method == http.MethodPost { req.Header.Add("Content-Type", MIMEPOSTForm) } err := b.Bind(req, &obj) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "bar", obj.Foo) assert.Equal(t, "hello", obj.Bar) obj = FooDefaultBarStruct{} req = requestWithBody(method, badPath, badBody) err = JSON.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) } func TestFormBindingFail(t *testing.T) { @@ -895,20 +902,20 @@ func TestFormBindingFail(t *testing.T) { assert.Equal(t, "form", b.Name()) obj := FooBarStruct{} - req, _ := http.NewRequest("POST", "/", nil) + req, _ := http.NewRequest(http.MethodPost, "/", nil) err := b.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) } func TestFormBindingMultipartFail(t *testing.T) { obj := FooBarStruct{} - req, err := http.NewRequest("POST", "/", strings.NewReader("foo=bar")) - assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")) + require.NoError(t, err) req.Header.Set("Content-Type", MIMEMultipartPOSTForm+";boundary=testboundary") _, err = req.MultipartReader() - assert.NoError(t, err) + require.NoError(t, err) err = Form.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) } func TestFormPostBindingFail(t *testing.T) { @@ -916,9 +923,9 @@ func TestFormPostBindingFail(t *testing.T) { assert.Equal(t, "form-urlencoded", b.Name()) obj := FooBarStruct{} - req, _ := http.NewRequest("POST", "/", nil) + req, _ := http.NewRequest(http.MethodPost, "/", nil) err := b.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) } func TestFormMultipartBindingFail(t *testing.T) { @@ -926,9 +933,9 @@ func TestFormMultipartBindingFail(t *testing.T) { assert.Equal(t, "multipart/form-data", b.Name()) obj := FooBarStruct{} - req, _ := http.NewRequest("POST", "/", nil) + req, _ := http.NewRequest(http.MethodPost, "/", nil) err := b.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) } func testFormBindingForTime(t *testing.T, method, path, badPath, body, badBody string) { @@ -937,23 +944,25 @@ func testFormBindingForTime(t *testing.T, method, path, badPath, body, badBody s obj := FooBarStructForTimeType{} req := requestWithBody(method, path, body) - if method == "POST" { + if method == http.MethodPost { req.Header.Add("Content-Type", MIMEPOSTForm) } err := b.Bind(req, &obj) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, int64(1510675200), obj.TimeFoo.Unix()) assert.Equal(t, "Asia/Chongqing", obj.TimeFoo.Location().String()) assert.Equal(t, int64(-62135596800), obj.TimeBar.Unix()) assert.Equal(t, "UTC", obj.TimeBar.Location().String()) assert.Equal(t, int64(1562400033000000123), obj.CreateTime.UnixNano()) assert.Equal(t, int64(1562400033), obj.UnixTime.Unix()) + assert.Equal(t, int64(1562400033001), obj.UnixMilliTime.UnixMilli()) + assert.Equal(t, int64(1562400033000012), obj.UnixMicroTime.UnixMicro()) obj = FooBarStructForTimeType{} req = requestWithBody(method, badPath, badBody) err = JSON.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) } func testFormBindingForTimeNotUnixFormat(t *testing.T, method, path, badPath, body, badBody string) { @@ -962,16 +971,16 @@ func testFormBindingForTimeNotUnixFormat(t *testing.T, method, path, badPath, bo obj := FooStructForTimeTypeNotUnixFormat{} req := requestWithBody(method, path, body) - if method == "POST" { + if method == http.MethodPost { req.Header.Add("Content-Type", MIMEPOSTForm) } err := b.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) obj = FooStructForTimeTypeNotUnixFormat{} req = requestWithBody(method, badPath, badBody) err = JSON.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) } func testFormBindingForTimeNotFormat(t *testing.T, method, path, badPath, body, badBody string) { @@ -980,16 +989,16 @@ func testFormBindingForTimeNotFormat(t *testing.T, method, path, badPath, body, obj := FooStructForTimeTypeNotFormat{} req := requestWithBody(method, path, body) - if method == "POST" { + if method == http.MethodPost { req.Header.Add("Content-Type", MIMEPOSTForm) } err := b.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) obj = FooStructForTimeTypeNotFormat{} req = requestWithBody(method, badPath, badBody) err = JSON.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) } func testFormBindingForTimeFailFormat(t *testing.T, method, path, badPath, body, badBody string) { @@ -998,16 +1007,16 @@ func testFormBindingForTimeFailFormat(t *testing.T, method, path, badPath, body, obj := FooStructForTimeTypeFailFormat{} req := requestWithBody(method, path, body) - if method == "POST" { + if method == http.MethodPost { req.Header.Add("Content-Type", MIMEPOSTForm) } err := b.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) obj = FooStructForTimeTypeFailFormat{} req = requestWithBody(method, badPath, badBody) err = JSON.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) } func testFormBindingForTimeFailLocation(t *testing.T, method, path, badPath, body, badBody string) { @@ -1016,16 +1025,16 @@ func testFormBindingForTimeFailLocation(t *testing.T, method, path, badPath, bod obj := FooStructForTimeTypeFailLocation{} req := requestWithBody(method, path, body) - if method == "POST" { + if method == http.MethodPost { req.Header.Add("Content-Type", MIMEPOSTForm) } err := b.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) obj = FooStructForTimeTypeFailLocation{} req = requestWithBody(method, badPath, badBody) err = JSON.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) } func testFormBindingIgnoreField(t *testing.T, method, path, badPath, body, badBody string) { @@ -1034,11 +1043,11 @@ func testFormBindingIgnoreField(t *testing.T, method, path, badPath, body, badBo obj := FooStructForIgnoreFormTag{} req := requestWithBody(method, path, body) - if method == "POST" { + if method == http.MethodPost { req.Header.Add("Content-Type", MIMEPOSTForm) } err := b.Bind(req, &obj) - assert.NoError(t, err) + require.NoError(t, err) assert.Nil(t, obj.Foo) } @@ -1049,17 +1058,17 @@ func testFormBindingInvalidName(t *testing.T, method, path, badPath, body, badBo obj := InvalidNameType{} req := requestWithBody(method, path, body) - if method == "POST" { + if method == http.MethodPost { req.Header.Add("Content-Type", MIMEPOSTForm) } err := b.Bind(req, &obj) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "", obj.TestName) obj = InvalidNameType{} req = requestWithBody(method, badPath, badBody) err = JSON.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) } func testFormBindingInvalidName2(t *testing.T, method, path, badPath, body, badBody string) { @@ -1068,16 +1077,16 @@ func testFormBindingInvalidName2(t *testing.T, method, path, badPath, body, badB obj := InvalidNameMapType{} req := requestWithBody(method, path, body) - if method == "POST" { + if method == http.MethodPost { req.Header.Add("Content-Type", MIMEPOSTForm) } err := b.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) obj = InvalidNameMapType{} req = requestWithBody(method, badPath, badBody) err = JSON.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) } func testFormBindingForType(t *testing.T, method, path, badPath, body, badBody string, typ string) { @@ -1085,24 +1094,24 @@ func testFormBindingForType(t *testing.T, method, path, badPath, body, badBody s assert.Equal(t, "form", b.Name()) req := requestWithBody(method, path, body) - if method == "POST" { + if method == http.MethodPost { req.Header.Add("Content-Type", MIMEPOSTForm) } switch typ { case "Slice": obj := FooStructForSliceType{} err := b.Bind(req, &obj) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, []int{1, 2}, obj.SliceFoo) obj = FooStructForSliceType{} req = requestWithBody(method, badPath, badBody) err = JSON.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) case "Struct": obj := FooStructForStructType{} err := b.Bind(req, &obj) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, struct { Idx int "form:\"idx\"" @@ -1111,7 +1120,7 @@ func testFormBindingForType(t *testing.T, method, path, badPath, body, badBody s case "StructPointer": obj := FooStructForStructPointerType{} err := b.Bind(req, &obj) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, struct { Name string "form:\"name\"" @@ -1120,33 +1129,33 @@ func testFormBindingForType(t *testing.T, method, path, badPath, body, badBody s case "Map": obj := FooStructForMapType{} err := b.Bind(req, &obj) - assert.NoError(t, err) - assert.Equal(t, float64(123), obj.MapFoo["bar"].(float64)) + require.NoError(t, err) + assert.InDelta(t, float64(123), obj.MapFoo["bar"].(float64), 0.01) case "SliceMap": obj := FooStructForSliceMapType{} err := b.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) case "Ptr": obj := FooStructForStringPtrType{} err := b.Bind(req, &obj) - assert.NoError(t, err) + require.NoError(t, err) assert.Nil(t, obj.PtrFoo) assert.Equal(t, "test", *obj.PtrBar) obj = FooStructForStringPtrType{} obj.PtrBar = new(string) err = b.Bind(req, &obj) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "test", *obj.PtrBar) objErr := FooStructForMapPtrType{} err = b.Bind(req, &objErr) - assert.Error(t, err) + require.Error(t, err) obj = FooStructForStringPtrType{} req = requestWithBody(method, badPath, badBody) err = b.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) } } @@ -1156,11 +1165,11 @@ func testQueryBinding(t *testing.T, method, path, badPath, body, badBody string) obj := FooBarStruct{} req := requestWithBody(method, path, body) - if method == "POST" { + if method == http.MethodPost { req.Header.Add("Content-Type", MIMEPOSTForm) } err := b.Bind(req, &obj) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "bar", obj.Foo) assert.Equal(t, "foo", obj.Bar) } @@ -1171,11 +1180,11 @@ func testQueryBindingFail(t *testing.T, method, path, badPath, body, badBody str obj := FooStructForMapType{} req := requestWithBody(method, path, body) - if method == "POST" { + if method == http.MethodPost { req.Header.Add("Content-Type", MIMEPOSTForm) } err := b.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) } func testQueryBindingBoolFail(t *testing.T, method, path, badPath, body, badBody string) { @@ -1184,50 +1193,50 @@ func testQueryBindingBoolFail(t *testing.T, method, path, badPath, body, badBody obj := FooStructForBoolType{} req := requestWithBody(method, path, body) - if method == "POST" { + if method == http.MethodPost { req.Header.Add("Content-Type", MIMEPOSTForm) } err := b.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) } func testBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody string) { assert.Equal(t, name, b.Name()) obj := FooStruct{} - req := requestWithBody("POST", path, body) + req := requestWithBody(http.MethodPost, path, body) err := b.Bind(req, &obj) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "bar", obj.Foo) obj = FooStruct{} - req = requestWithBody("POST", badPath, badBody) + req = requestWithBody(http.MethodPost, badPath, badBody) err = JSON.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) } func testBodyBindingSlice(t *testing.T, b Binding, name, path, badPath, body, badBody string) { assert.Equal(t, name, b.Name()) var obj1 []FooStruct - req := requestWithBody("POST", path, body) + req := requestWithBody(http.MethodPost, path, body) err := b.Bind(req, &obj1) - assert.NoError(t, err) + require.NoError(t, err) var obj2 []FooStruct - req = requestWithBody("POST", badPath, badBody) + req = requestWithBody(http.MethodPost, badPath, badBody) err = JSON.Bind(req, &obj2) - assert.Error(t, err) + require.Error(t, err) } func testBodyBindingStringMap(t *testing.T, b Binding, path, badPath, body, badBody string) { obj := make(map[string]string) - req := requestWithBody("POST", path, body) + req := requestWithBody(http.MethodPost, path, body) if b.Name() == "form" { req.Header.Add("Content-Type", MIMEPOSTForm) } err := b.Bind(req, &obj) - assert.NoError(t, err) + require.NoError(t, err) assert.NotNil(t, obj) assert.Len(t, obj, 2) assert.Equal(t, "bar", obj["foo"]) @@ -1235,52 +1244,52 @@ func testBodyBindingStringMap(t *testing.T, b Binding, path, badPath, body, badB if badPath != "" && badBody != "" { obj = make(map[string]string) - req = requestWithBody("POST", badPath, badBody) + req = requestWithBody(http.MethodPost, badPath, badBody) err = b.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) } objInt := make(map[string]int) - req = requestWithBody("POST", path, body) + req = requestWithBody(http.MethodPost, path, body) err = b.Bind(req, &objInt) - assert.Error(t, err) + require.Error(t, err) } func testBodyBindingUseNumber(t *testing.T, b Binding, name, path, badPath, body, badBody string) { assert.Equal(t, name, b.Name()) obj := FooStructUseNumber{} - req := requestWithBody("POST", path, body) + req := requestWithBody(http.MethodPost, path, body) EnableDecoderUseNumber = true err := b.Bind(req, &obj) - assert.NoError(t, err) + require.NoError(t, err) // we hope it is int64(123) v, e := obj.Foo.(json.Number).Int64() - assert.NoError(t, e) + require.NoError(t, e) assert.Equal(t, int64(123), v) obj = FooStructUseNumber{} - req = requestWithBody("POST", badPath, badBody) + req = requestWithBody(http.MethodPost, badPath, badBody) err = JSON.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) } func testBodyBindingUseNumber2(t *testing.T, b Binding, name, path, badPath, body, badBody string) { assert.Equal(t, name, b.Name()) obj := FooStructUseNumber{} - req := requestWithBody("POST", path, body) + req := requestWithBody(http.MethodPost, path, body) EnableDecoderUseNumber = false err := b.Bind(req, &obj) - assert.NoError(t, err) + require.NoError(t, err) // it will return float64(123) if not use EnableDecoderUseNumber // maybe it is not hoped - assert.Equal(t, float64(123), obj.Foo) + assert.InDelta(t, float64(123), obj.Foo, 0.01) obj = FooStructUseNumber{} - req = requestWithBody("POST", badPath, badBody) + req = requestWithBody(http.MethodPost, badPath, badBody) err = JSON.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) } func testBodyBindingDisallowUnknownFields(t *testing.T, b Binding, path, badPath, body, badBody string) { @@ -1290,15 +1299,15 @@ func testBodyBindingDisallowUnknownFields(t *testing.T, b Binding, path, badPath }() obj := FooStructDisallowUnknownFields{} - req := requestWithBody("POST", path, body) + req := requestWithBody(http.MethodPost, path, body) err := b.Bind(req, &obj) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "bar", obj.Foo) obj = FooStructDisallowUnknownFields{} - req = requestWithBody("POST", badPath, badBody) + req = requestWithBody(http.MethodPost, badPath, badBody) err = JSON.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) assert.Contains(t, err.Error(), "what") } @@ -1306,32 +1315,32 @@ func testBodyBindingFail(t *testing.T, b Binding, name, path, badPath, body, bad assert.Equal(t, name, b.Name()) obj := FooStruct{} - req := requestWithBody("POST", path, body) + req := requestWithBody(http.MethodPost, path, body) err := b.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) assert.Equal(t, "", obj.Foo) obj = FooStruct{} - req = requestWithBody("POST", badPath, badBody) + req = requestWithBody(http.MethodPost, badPath, badBody) err = JSON.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) } func testProtoBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody string) { assert.Equal(t, name, b.Name()) obj := protoexample.Test{} - req := requestWithBody("POST", path, body) + req := requestWithBody(http.MethodPost, path, body) req.Header.Add("Content-Type", MIMEPROTOBUF) err := b.Bind(req, &obj) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "yes", *obj.Label) obj = protoexample.Test{} - req = requestWithBody("POST", badPath, badBody) + req = requestWithBody(http.MethodPost, badPath, badBody) req.Header.Add("Content-Type", MIMEPROTOBUF) err = ProtoBuf.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) } type hook struct{} @@ -1340,29 +1349,69 @@ func (h hook) Read([]byte) (int, error) { return 0, errors.New("error") } +type failRead struct{} + +func (f *failRead) Read(b []byte) (n int, err error) { + return 0, errors.New("my fail") +} + +func (f *failRead) Close() error { + return nil +} + +func TestPlainBinding(t *testing.T) { + p := Plain + assert.Equal(t, "plain", p.Name()) + + var s string + req := requestWithBody(http.MethodPost, "/", "test string") + require.NoError(t, p.Bind(req, &s)) + assert.Equal(t, "test string", s) + + var bs []byte + req = requestWithBody(http.MethodPost, "/", "test []byte") + require.NoError(t, p.Bind(req, &bs)) + assert.Equal(t, bs, []byte("test []byte")) + + var i int + req = requestWithBody(http.MethodPost, "/", "test fail") + require.Error(t, p.Bind(req, &i)) + + req = requestWithBody(http.MethodPost, "/", "") + req.Body = &failRead{} + require.Error(t, p.Bind(req, &s)) + + req = requestWithBody(http.MethodPost, "/", "") + require.NoError(t, p.Bind(req, nil)) + + var ptr *string + req = requestWithBody(http.MethodPost, "/", "") + require.NoError(t, p.Bind(req, ptr)) +} + func testProtoBodyBindingFail(t *testing.T, b Binding, name, path, badPath, body, badBody string) { assert.Equal(t, name, b.Name()) obj := protoexample.Test{} - req := requestWithBody("POST", path, body) + req := requestWithBody(http.MethodPost, path, body) req.Body = io.NopCloser(&hook{}) req.Header.Add("Content-Type", MIMEPROTOBUF) err := b.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) invalidobj := FooStruct{} req.Body = io.NopCloser(strings.NewReader(`{"msg":"hello"}`)) req.Header.Add("Content-Type", MIMEPROTOBUF) err = b.Bind(req, &invalidobj) - assert.Error(t, err) - assert.Equal(t, err.Error(), "obj is not ProtoMessage") + require.Error(t, err) + assert.Equal(t, "obj is not ProtoMessage", err.Error()) obj = protoexample.Test{} - req = requestWithBody("POST", badPath, badBody) + req = requestWithBody(http.MethodPost, badPath, badBody) req.Header.Add("Content-Type", MIMEPROTOBUF) err = ProtoBuf.Bind(req, &obj) - assert.Error(t, err) + require.Error(t, err) } func requestWithBody(method, path, body string) (req *http.Request) { diff --git a/binding/default_validator.go b/binding/default_validator.go index ac43d7cc..44b7a2ac 100644 --- a/binding/default_validator.go +++ b/binding/default_validator.go @@ -5,8 +5,8 @@ package binding import ( - "fmt" "reflect" + "strconv" "strings" "sync" @@ -22,25 +22,20 @@ type SliceValidationError []error // Error concatenates all error elements in SliceValidationError into a single string separated by \n. func (err SliceValidationError) Error() string { - n := len(err) - switch n { - case 0: + if len(err) == 0 { return "" - default: - var b strings.Builder - if err[0] != nil { - fmt.Fprintf(&b, "[%d]: %s", 0, err[0].Error()) - } - if n > 1 { - for i := 1; i < n; i++ { - if err[i] != nil { - b.WriteString("\n") - fmt.Fprintf(&b, "[%d]: %s", i, err[i].Error()) - } - } - } - return b.String() } + + var b strings.Builder + for i := 0; i < len(err); i++ { + if err[i] != nil { + if b.Len() > 0 { + b.WriteString("\n") + } + b.WriteString("[" + strconv.Itoa(i) + "]: " + err[i].Error()) + } + } + return b.String() } var _ StructValidator = (*defaultValidator)(nil) diff --git a/binding/default_validator_benchmark_test.go b/binding/default_validator_benchmark_test.go index 9292e2aa..44547412 100644 --- a/binding/default_validator_benchmark_test.go +++ b/binding/default_validator_benchmark_test.go @@ -12,11 +12,15 @@ import ( func BenchmarkSliceValidationError(b *testing.B) { const size int = 100 + e := make(SliceValidationError, size) + for j := 0; j < size; j++ { + e[j] = errors.New(strconv.Itoa(j)) + } + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { - e := make(SliceValidationError, size) - for j := 0; j < size; j++ { - e[j] = errors.New(strconv.Itoa(j)) - } if len(e.Error()) == 0 { b.Errorf("error") } diff --git a/binding/form_mapping.go b/binding/form_mapping.go index 77a1bde6..235692d2 100644 --- a/binding/form_mapping.go +++ b/binding/form_mapping.go @@ -159,12 +159,69 @@ func tryToSetValue(value reflect.Value, field reflect.StructField, setter setter if k, v := head(opt, "="); k == "default" { setOpt.isDefaultExists = true setOpt.defaultValue = v + + // convert semicolon-separated default values to csv-separated values for processing in setByForm + if field.Type.Kind() == reflect.Slice || field.Type.Kind() == reflect.Array { + cfTag := field.Tag.Get("collection_format") + if cfTag == "" || cfTag == "multi" || cfTag == "csv" { + setOpt.defaultValue = strings.ReplaceAll(v, ";", ",") + } + } } } return setter.TrySet(value, field, tagValue, setOpt) } +// BindUnmarshaler is the interface used to wrap the UnmarshalParam method. +type BindUnmarshaler interface { + // UnmarshalParam decodes and assigns a value from an form or query param. + UnmarshalParam(param string) error +} + +// trySetCustom tries to set a custom type value +// If the value implements the BindUnmarshaler interface, it will be used to set the value, we will return `true` +// to skip the default value setting. +func trySetCustom(val string, value reflect.Value) (isSet bool, err error) { + switch v := value.Addr().Interface().(type) { + case BindUnmarshaler: + return true, v.UnmarshalParam(val) + } + return false, nil +} + +func trySplit(vs []string, field reflect.StructField) (newVs []string, err error) { + cfTag := field.Tag.Get("collection_format") + if cfTag == "" || cfTag == "multi" { + return vs, nil + } + + var sep string + switch cfTag { + case "csv": + sep = "," + case "ssv": + sep = " " + case "tsv": + sep = "\t" + case "pipes": + sep = "|" + default: + return vs, fmt.Errorf("%s is not supported in the collection_format. (csv, ssv, pipes)", cfTag) + } + + totalLength := 0 + for _, v := range vs { + totalLength += strings.Count(v, sep) + 1 + } + newVs = make([]string, 0, totalLength) + for _, v := range vs { + newVs = append(newVs, strings.Split(v, sep)...) + } + + return newVs, nil +} + func setByForm(value reflect.Value, field reflect.StructField, form map[string][]string, tagValue string, opt setOptions) (isSet bool, err error) { vs, ok := form[tagValue] if !ok && !opt.isDefaultExists { @@ -175,15 +232,46 @@ func setByForm(value reflect.Value, field reflect.StructField, form map[string][ case reflect.Slice: if !ok { vs = []string{opt.defaultValue} + + // pre-process the default value for multi if present + cfTag := field.Tag.Get("collection_format") + if cfTag == "" || cfTag == "multi" { + vs = strings.Split(opt.defaultValue, ",") + } } + + if ok, err = trySetCustom(vs[0], value); ok { + return ok, err + } + + if vs, err = trySplit(vs, field); err != nil { + return false, err + } + return true, setSlice(vs, value, field) case reflect.Array: if !ok { vs = []string{opt.defaultValue} + + // pre-process the default value for multi if present + cfTag := field.Tag.Get("collection_format") + if cfTag == "" || cfTag == "multi" { + vs = strings.Split(opt.defaultValue, ",") + } } + + if ok, err = trySetCustom(vs[0], value); ok { + return ok, err + } + + if vs, err = trySplit(vs, field); err != nil { + return false, err + } + if len(vs) != value.Len() { return false, fmt.Errorf("%q is not valid value for %s", vs, value.Type().String()) } + return true, setArray(vs, value, field) default: var val string @@ -193,6 +281,12 @@ func setByForm(value reflect.Value, field reflect.StructField, form map[string][ if len(vs) > 0 { val = vs[0] + if val == "" { + val = opt.defaultValue + } + } + if ok, err := trySetCustom(val, value); ok { + return ok, err } return true, setWithProperType(val, value, field) } @@ -304,18 +398,24 @@ func setTimeField(val string, structField reflect.StructField, value reflect.Val } switch tf := strings.ToLower(timeFormat); tf { - case "unix", "unixnano": + case "unix", "unixmilli", "unixmicro", "unixnano": tv, err := strconv.ParseInt(val, 10, 64) if err != nil { return err } - d := time.Duration(1) - if tf == "unixnano" { - d = time.Second + var t time.Time + switch tf { + case "unix": + t = time.Unix(tv, 0) + case "unixmilli": + t = time.UnixMilli(tv) + case "unixmicro": + t = time.UnixMicro(tv) + default: + t = time.Unix(0, tv) } - t := time.Unix(tv/int64(d), tv%int64(d)) value.Set(reflect.ValueOf(t)) return nil } @@ -377,11 +477,8 @@ func setTimeDuration(val string, value reflect.Value) error { } func head(str, sep string) (head string, tail string) { - idx := strings.Index(str, sep) - if idx < 0 { - return str, "" - } - return str[:idx], str[idx+len(sep):] + head, tail, _ = strings.Cut(str, sep) + return head, tail } func setFormMap(ptr any, form map[string][]string) error { diff --git a/binding/form_mapping_test.go b/binding/form_mapping_test.go index 16527eb9..1277fd5f 100644 --- a/binding/form_mapping_test.go +++ b/binding/form_mapping_test.go @@ -5,12 +5,17 @@ package binding import ( + "encoding/hex" + "errors" "mime/multipart" "reflect" + "strconv" + "strings" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestMappingBaseTypes(t *testing.T) { @@ -55,7 +60,7 @@ func TestMappingBaseTypes(t *testing.T) { field := val.Elem().Type().Field(0) _, err := mapping(val, emptyField, formSource{field.Name: {tt.form}}, "form") - assert.NoError(t, err, testName) + require.NoError(t, err, testName) actual := val.Elem().Field(0).Interface() assert.Equal(t, tt.expect, actual, testName) @@ -64,13 +69,15 @@ func TestMappingBaseTypes(t *testing.T) { func TestMappingDefault(t *testing.T) { var s struct { + Str string `form:",default=defaultVal"` Int int `form:",default=9"` Slice []int `form:",default=9"` Array [1]int `form:",default=9"` } err := mappingByPtr(&s, formSource{}, "form") - assert.NoError(t, err) + require.NoError(t, err) + assert.Equal(t, "defaultVal", s.Str) assert.Equal(t, 9, s.Int) assert.Equal(t, []int{9}, s.Slice) assert.Equal(t, [1]int{9}, s.Array) @@ -81,7 +88,7 @@ func TestMappingSkipField(t *testing.T) { A int } err := mappingByPtr(&s, formSource{}, "form") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, 0, s.A) } @@ -92,7 +99,7 @@ func TestMappingIgnoreField(t *testing.T) { B int `form:"-"` } err := mappingByPtr(&s, formSource{"A": {"9"}, "B": {"9"}}, "form") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, 9, s.A) assert.Equal(t, 0, s.B) @@ -104,7 +111,7 @@ func TestMappingUnexportedField(t *testing.T) { b int `form:"b"` } err := mappingByPtr(&s, formSource{"a": {"9"}, "b": {"9"}}, "form") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, 9, s.A) assert.Equal(t, 0, s.b) @@ -115,7 +122,7 @@ func TestMappingPrivateField(t *testing.T) { f int `form:"field"` } err := mappingByPtr(&s, formSource{"field": {"6"}}, "form") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, 0, s.f) } @@ -125,7 +132,7 @@ func TestMappingUnknownFieldType(t *testing.T) { } err := mappingByPtr(&s, formSource{"U": {"unknown"}}, "form") - assert.Error(t, err) + require.Error(t, err) assert.Equal(t, errUnknownType, err) } @@ -134,7 +141,7 @@ func TestMappingURI(t *testing.T) { F int `uri:"field"` } err := mapURI(&s, map[string][]string{"field": {"6"}}) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, 6, s.F) } @@ -143,16 +150,34 @@ func TestMappingForm(t *testing.T) { F int `form:"field"` } err := mapForm(&s, map[string][]string{"field": {"6"}}) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, 6, s.F) } +func TestMappingFormFieldNotSent(t *testing.T) { + var s struct { + F string `form:"field,default=defVal"` + } + err := mapForm(&s, map[string][]string{}) + require.NoError(t, err) + assert.Equal(t, "defVal", s.F) +} + +func TestMappingFormWithEmptyToDefault(t *testing.T) { + var s struct { + F string `form:"field,default=DefVal"` + } + err := mapForm(&s, map[string][]string{"field": {""}}) + require.NoError(t, err) + assert.Equal(t, "DefVal", s.F) +} + func TestMapFormWithTag(t *testing.T) { var s struct { F int `externalTag:"field"` } err := MapFormWithTag(&s, map[string][]string{"field": {"6"}}, "externalTag") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, 6, s.F) } @@ -167,7 +192,7 @@ func TestMappingTime(t *testing.T) { var err error time.Local, err = time.LoadLocation("Europe/Berlin") - assert.NoError(t, err) + require.NoError(t, err) err = mapForm(&s, map[string][]string{ "Time": {"2019-01-20T16:02:58Z"}, @@ -176,7 +201,7 @@ func TestMappingTime(t *testing.T) { "CSTTime": {"2019-01-20"}, "UTCTime": {"2019-01-20"}, }) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "2019-01-20 16:02:58 +0000 UTC", s.Time.String()) assert.Equal(t, "2019-01-20 00:00:00 +0100 CET", s.LocalTime.String()) @@ -191,14 +216,14 @@ func TestMappingTime(t *testing.T) { Time time.Time `time_location:"wrong"` } err = mapForm(&wrongLoc, map[string][]string{"Time": {"2019-01-20T16:02:58Z"}}) - assert.Error(t, err) + require.Error(t, err) // wrong time value var wrongTime struct { Time time.Time } err = mapForm(&wrongTime, map[string][]string{"Time": {"wrong"}}) - assert.Error(t, err) + require.Error(t, err) } func TestMappingTimeDuration(t *testing.T) { @@ -208,12 +233,12 @@ func TestMappingTimeDuration(t *testing.T) { // ok err := mappingByPtr(&s, formSource{"D": {"5s"}}, "form") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, 5*time.Second, s.D) // error err = mappingByPtr(&s, formSource{"D": {"wrong"}}, "form") - assert.Error(t, err) + require.Error(t, err) } func TestMappingSlice(t *testing.T) { @@ -223,17 +248,17 @@ func TestMappingSlice(t *testing.T) { // default value err := mappingByPtr(&s, formSource{}, "form") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, []int{9}, s.Slice) // ok err = mappingByPtr(&s, formSource{"slice": {"3", "4"}}, "form") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, []int{3, 4}, s.Slice) // error err = mappingByPtr(&s, formSource{"slice": {"wrong"}}, "form") - assert.Error(t, err) + require.Error(t, err) } func TestMappingArray(t *testing.T) { @@ -243,20 +268,125 @@ func TestMappingArray(t *testing.T) { // wrong default err := mappingByPtr(&s, formSource{}, "form") - assert.Error(t, err) + require.Error(t, err) // ok err = mappingByPtr(&s, formSource{"array": {"3", "4"}}, "form") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, [2]int{3, 4}, s.Array) // error - not enough vals err = mappingByPtr(&s, formSource{"array": {"3"}}, "form") - assert.Error(t, err) + require.Error(t, err) // error - wrong value err = mappingByPtr(&s, formSource{"array": {"wrong"}}, "form") - assert.Error(t, err) + require.Error(t, err) +} + +func TestMappingCollectionFormat(t *testing.T) { + var s struct { + SliceMulti []int `form:"slice_multi" collection_format:"multi"` + SliceCsv []int `form:"slice_csv" collection_format:"csv"` + SliceSsv []int `form:"slice_ssv" collection_format:"ssv"` + SliceTsv []int `form:"slice_tsv" collection_format:"tsv"` + SlicePipes []int `form:"slice_pipes" collection_format:"pipes"` + ArrayMulti [2]int `form:"array_multi" collection_format:"multi"` + ArrayCsv [2]int `form:"array_csv" collection_format:"csv"` + ArraySsv [2]int `form:"array_ssv" collection_format:"ssv"` + ArrayTsv [2]int `form:"array_tsv" collection_format:"tsv"` + ArrayPipes [2]int `form:"array_pipes" collection_format:"pipes"` + } + err := mappingByPtr(&s, formSource{ + "slice_multi": {"1", "2"}, + "slice_csv": {"1,2"}, + "slice_ssv": {"1 2"}, + "slice_tsv": {"1 2"}, + "slice_pipes": {"1|2"}, + "array_multi": {"1", "2"}, + "array_csv": {"1,2"}, + "array_ssv": {"1 2"}, + "array_tsv": {"1 2"}, + "array_pipes": {"1|2"}, + }, "form") + require.NoError(t, err) + + assert.Equal(t, []int{1, 2}, s.SliceMulti) + assert.Equal(t, []int{1, 2}, s.SliceCsv) + assert.Equal(t, []int{1, 2}, s.SliceSsv) + assert.Equal(t, []int{1, 2}, s.SliceTsv) + assert.Equal(t, []int{1, 2}, s.SlicePipes) + assert.Equal(t, [2]int{1, 2}, s.ArrayMulti) + assert.Equal(t, [2]int{1, 2}, s.ArrayCsv) + assert.Equal(t, [2]int{1, 2}, s.ArraySsv) + assert.Equal(t, [2]int{1, 2}, s.ArrayTsv) + assert.Equal(t, [2]int{1, 2}, s.ArrayPipes) +} + +func TestMappingCollectionFormatInvalid(t *testing.T) { + var s struct { + SliceCsv []int `form:"slice_csv" collection_format:"xxx"` + } + err := mappingByPtr(&s, formSource{ + "slice_csv": {"1,2"}, + }, "form") + require.Error(t, err) + + var s2 struct { + ArrayCsv [2]int `form:"array_csv" collection_format:"xxx"` + } + err = mappingByPtr(&s2, formSource{ + "array_csv": {"1,2"}, + }, "form") + require.Error(t, err) +} + +func TestMappingMultipleDefaultWithCollectionFormat(t *testing.T) { + var s struct { + SliceMulti []int `form:",default=1;2;3" collection_format:"multi"` + SliceCsv []int `form:",default=1;2;3" collection_format:"csv"` + SliceSsv []int `form:",default=1 2 3" collection_format:"ssv"` + SliceTsv []int `form:",default=1\t2\t3" collection_format:"tsv"` + SlicePipes []int `form:",default=1|2|3" collection_format:"pipes"` + ArrayMulti [2]int `form:",default=1;2" collection_format:"multi"` + ArrayCsv [2]int `form:",default=1;2" collection_format:"csv"` + ArraySsv [2]int `form:",default=1 2" collection_format:"ssv"` + ArrayTsv [2]int `form:",default=1\t2" collection_format:"tsv"` + ArrayPipes [2]int `form:",default=1|2" collection_format:"pipes"` + SliceStringMulti []string `form:",default=1;2;3" collection_format:"multi"` + SliceStringCsv []string `form:",default=1;2;3" collection_format:"csv"` + SliceStringSsv []string `form:",default=1 2 3" collection_format:"ssv"` + SliceStringTsv []string `form:",default=1\t2\t3" collection_format:"tsv"` + SliceStringPipes []string `form:",default=1|2|3" collection_format:"pipes"` + ArrayStringMulti [2]string `form:",default=1;2" collection_format:"multi"` + ArrayStringCsv [2]string `form:",default=1;2" collection_format:"csv"` + ArrayStringSsv [2]string `form:",default=1 2" collection_format:"ssv"` + ArrayStringTsv [2]string `form:",default=1\t2" collection_format:"tsv"` + ArrayStringPipes [2]string `form:",default=1|2" collection_format:"pipes"` + } + err := mappingByPtr(&s, formSource{}, "form") + require.NoError(t, err) + + assert.Equal(t, []int{1, 2, 3}, s.SliceMulti) + assert.Equal(t, []int{1, 2, 3}, s.SliceCsv) + assert.Equal(t, []int{1, 2, 3}, s.SliceSsv) + assert.Equal(t, []int{1, 2, 3}, s.SliceTsv) + assert.Equal(t, []int{1, 2, 3}, s.SlicePipes) + assert.Equal(t, [2]int{1, 2}, s.ArrayMulti) + assert.Equal(t, [2]int{1, 2}, s.ArrayCsv) + assert.Equal(t, [2]int{1, 2}, s.ArraySsv) + assert.Equal(t, [2]int{1, 2}, s.ArrayTsv) + assert.Equal(t, [2]int{1, 2}, s.ArrayPipes) + assert.Equal(t, []string{"1", "2", "3"}, s.SliceStringMulti) + assert.Equal(t, []string{"1", "2", "3"}, s.SliceStringCsv) + assert.Equal(t, []string{"1", "2", "3"}, s.SliceStringSsv) + assert.Equal(t, []string{"1", "2", "3"}, s.SliceStringTsv) + assert.Equal(t, []string{"1", "2", "3"}, s.SliceStringPipes) + assert.Equal(t, [2]string{"1", "2"}, s.ArrayStringMulti) + assert.Equal(t, [2]string{"1", "2"}, s.ArrayStringCsv) + assert.Equal(t, [2]string{"1", "2"}, s.ArrayStringSsv) + assert.Equal(t, [2]string{"1", "2"}, s.ArrayStringTsv) + assert.Equal(t, [2]string{"1", "2"}, s.ArrayStringPipes) } func TestMappingStructField(t *testing.T) { @@ -267,7 +397,7 @@ func TestMappingStructField(t *testing.T) { } err := mappingByPtr(&s, formSource{"J": {`{"I": 9}`}}, "form") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, 9, s.J.I) } @@ -285,20 +415,20 @@ func TestMappingPtrField(t *testing.T) { // With 0 items. var req0 ptrRequest err = mappingByPtr(&req0, formSource{}, "form") - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, req0.Items) // With 1 item. var req1 ptrRequest err = mappingByPtr(&req1, formSource{"items": {`{"key": 1}`}}, "form") - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, req1.Items, 1) assert.EqualValues(t, 1, req1.Items[0].Key) // With 2 items. var req2 ptrRequest err = mappingByPtr(&req2, formSource{"items": {`{"key": 1}`, `{"key": 2}`}}, "form") - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, req2.Items, 2) assert.EqualValues(t, 1, req2.Items[0].Key) assert.EqualValues(t, 2, req2.Items[1].Key) @@ -310,7 +440,7 @@ func TestMappingMapField(t *testing.T) { } err := mappingByPtr(&s, formSource{"M": {`{"one": 1}`}}, "form") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, map[string]int{"one": 1}, s.M) } @@ -321,5 +451,187 @@ func TestMappingIgnoredCircularRef(t *testing.T) { var s S err := mappingByPtr(&s, formSource{}, "form") - assert.NoError(t, err) + require.NoError(t, err) +} + +type customUnmarshalParamHex int + +func (f *customUnmarshalParamHex) UnmarshalParam(param string) error { + v, err := strconv.ParseInt(param, 16, 64) + if err != nil { + return err + } + *f = customUnmarshalParamHex(v) + return nil +} + +func TestMappingCustomUnmarshalParamHexWithFormTag(t *testing.T) { + var s struct { + Foo customUnmarshalParamHex `form:"foo"` + } + err := mappingByPtr(&s, formSource{"foo": {`f5`}}, "form") + require.NoError(t, err) + + assert.EqualValues(t, 245, s.Foo) +} + +func TestMappingCustomUnmarshalParamHexWithURITag(t *testing.T) { + var s struct { + Foo customUnmarshalParamHex `uri:"foo"` + } + err := mappingByPtr(&s, formSource{"foo": {`f5`}}, "uri") + require.NoError(t, err) + + assert.EqualValues(t, 245, s.Foo) +} + +type customUnmarshalParamType struct { + Protocol string + Path string + Name string +} + +func (f *customUnmarshalParamType) UnmarshalParam(param string) error { + parts := strings.Split(param, ":") + if len(parts) != 3 { + return errors.New("invalid format") + } + f.Protocol = parts[0] + f.Path = parts[1] + f.Name = parts[2] + return nil +} + +func TestMappingCustomStructTypeWithFormTag(t *testing.T) { + var s struct { + FileData customUnmarshalParamType `form:"data"` + } + err := mappingByPtr(&s, formSource{"data": {`file:/foo:happiness`}}, "form") + require.NoError(t, err) + + assert.EqualValues(t, "file", s.FileData.Protocol) + assert.EqualValues(t, "/foo", s.FileData.Path) + assert.EqualValues(t, "happiness", s.FileData.Name) +} + +func TestMappingCustomStructTypeWithURITag(t *testing.T) { + var s struct { + FileData customUnmarshalParamType `uri:"data"` + } + err := mappingByPtr(&s, formSource{"data": {`file:/foo:happiness`}}, "uri") + require.NoError(t, err) + + assert.EqualValues(t, "file", s.FileData.Protocol) + assert.EqualValues(t, "/foo", s.FileData.Path) + assert.EqualValues(t, "happiness", s.FileData.Name) +} + +func TestMappingCustomPointerStructTypeWithFormTag(t *testing.T) { + var s struct { + FileData *customUnmarshalParamType `form:"data"` + } + err := mappingByPtr(&s, formSource{"data": {`file:/foo:happiness`}}, "form") + require.NoError(t, err) + + assert.EqualValues(t, "file", s.FileData.Protocol) + assert.EqualValues(t, "/foo", s.FileData.Path) + assert.EqualValues(t, "happiness", s.FileData.Name) +} + +func TestMappingCustomPointerStructTypeWithURITag(t *testing.T) { + var s struct { + FileData *customUnmarshalParamType `uri:"data"` + } + err := mappingByPtr(&s, formSource{"data": {`file:/foo:happiness`}}, "uri") + require.NoError(t, err) + + assert.EqualValues(t, "file", s.FileData.Protocol) + assert.EqualValues(t, "/foo", s.FileData.Path) + assert.EqualValues(t, "happiness", s.FileData.Name) +} + +type customPath []string + +func (p *customPath) UnmarshalParam(param string) error { + elems := strings.Split(param, "/") + n := len(elems) + if n < 2 { + return errors.New("invalid format") + } + + *p = elems + return nil +} + +func TestMappingCustomSliceUri(t *testing.T) { + var s struct { + FileData customPath `uri:"path"` + } + err := mappingByPtr(&s, formSource{"path": {`bar/foo`}}, "uri") + require.NoError(t, err) + + assert.EqualValues(t, "bar", s.FileData[0]) + assert.EqualValues(t, "foo", s.FileData[1]) +} + +func TestMappingCustomSliceForm(t *testing.T) { + var s struct { + FileData customPath `form:"path"` + } + err := mappingByPtr(&s, formSource{"path": {`bar/foo`}}, "form") + require.NoError(t, err) + + assert.EqualValues(t, "bar", s.FileData[0]) + assert.EqualValues(t, "foo", s.FileData[1]) +} + +type objectID [12]byte + +func (o *objectID) UnmarshalParam(param string) error { + oid, err := convertTo(param) + if err != nil { + return err + } + + *o = oid + return nil +} + +func convertTo(s string) (objectID, error) { + var nilObjectID objectID + if len(s) != 24 { + return nilObjectID, errors.New("invalid format") + } + + var oid [12]byte + _, err := hex.Decode(oid[:], []byte(s)) + if err != nil { + return nilObjectID, err + } + + return oid, nil +} + +func TestMappingCustomArrayUri(t *testing.T) { + var s struct { + FileData objectID `uri:"id"` + } + val := `664a062ac74a8ad104e0e80f` + err := mappingByPtr(&s, formSource{"id": {val}}, "uri") + require.NoError(t, err) + + expected, _ := convertTo(val) + assert.EqualValues(t, expected, s.FileData) +} + +func TestMappingCustomArrayForm(t *testing.T) { + var s struct { + FileData objectID `form:"id"` + } + val := `664a062ac74a8ad104e0e80f` + err := mappingByPtr(&s, formSource{"id": {val}}, "form") + require.NoError(t, err) + + expected, _ := convertTo(val) + assert.EqualValues(t, expected, s.FileData) } diff --git a/binding/multipart_form_mapping_test.go b/binding/multipart_form_mapping_test.go index 4e97c0f0..c93f2141 100644 --- a/binding/multipart_form_mapping_test.go +++ b/binding/multipart_form_mapping_test.go @@ -12,6 +12,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestFormMultipartBindingBindOneFile(t *testing.T) { @@ -27,7 +28,7 @@ func TestFormMultipartBindingBindOneFile(t *testing.T) { req := createRequestMultipartFiles(t, file) err := FormMultipart.Bind(req, &s) - assert.NoError(t, err) + require.NoError(t, err) assertMultipartFileHeader(t, &s.FileValue, file) assertMultipartFileHeader(t, s.FilePtr, file) @@ -53,7 +54,7 @@ func TestFormMultipartBindingBindTwoFiles(t *testing.T) { req := createRequestMultipartFiles(t, files...) err := FormMultipart.Bind(req, &s) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, s.SliceValues, len(files)) assert.Len(t, s.SlicePtrs, len(files)) @@ -90,7 +91,7 @@ func TestFormMultipartBindingBindError(t *testing.T) { } { req := createRequestMultipartFiles(t, files...) err := FormMultipart.Bind(req, tt.s) - assert.Error(t, err) + require.Error(t, err) } } @@ -106,17 +107,17 @@ func createRequestMultipartFiles(t *testing.T, files ...testFile) *http.Request mw := multipart.NewWriter(&body) for _, file := range files { fw, err := mw.CreateFormFile(file.Fieldname, file.Filename) - assert.NoError(t, err) + require.NoError(t, err) n, err := fw.Write(file.Content) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, len(file.Content), n) } err := mw.Close() - assert.NoError(t, err) + require.NoError(t, err) - req, err := http.NewRequest("POST", "/", &body) - assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, "/", &body) + require.NoError(t, err) req.Header.Set("Content-Type", MIMEMultipartPOSTForm+"; boundary="+mw.Boundary()) return req @@ -127,12 +128,12 @@ func assertMultipartFileHeader(t *testing.T, fh *multipart.FileHeader, file test assert.Equal(t, int64(len(file.Content)), fh.Size) fl, err := fh.Open() - assert.NoError(t, err) + require.NoError(t, err) body, err := io.ReadAll(fl) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, string(file.Content), string(body)) err = fl.Close() - assert.NoError(t, err) + require.NoError(t, err) } diff --git a/binding/plain.go b/binding/plain.go new file mode 100644 index 00000000..3b250bb0 --- /dev/null +++ b/binding/plain.go @@ -0,0 +1,56 @@ +package binding + +import ( + "fmt" + "io" + "net/http" + "reflect" + + "github.com/gin-gonic/gin/internal/bytesconv" +) + +type plainBinding struct{} + +func (plainBinding) Name() string { + return "plain" +} + +func (plainBinding) Bind(req *http.Request, obj interface{}) error { + all, err := io.ReadAll(req.Body) + if err != nil { + return err + } + + return decodePlain(all, obj) +} + +func (plainBinding) BindBody(body []byte, obj any) error { + return decodePlain(body, obj) +} + +func decodePlain(data []byte, obj any) error { + if obj == nil { + return nil + } + + v := reflect.ValueOf(obj) + + for v.Kind() == reflect.Ptr { + if v.IsNil() { + return nil + } + v = v.Elem() + } + + if v.Kind() == reflect.String { + v.SetString(bytesconv.BytesToString(data)) + return nil + } + + if _, ok := v.Interface().([]byte); ok { + v.SetBytes(data) + return nil + } + + return fmt.Errorf("type (%T) unknown type", v) +} diff --git a/binding/protobuf.go b/binding/protobuf.go index 57721fc9..259ae8e7 100644 --- a/binding/protobuf.go +++ b/binding/protobuf.go @@ -34,7 +34,7 @@ func (protobufBinding) BindBody(body []byte, obj any) error { if err := proto.Unmarshal(body, msg); err != nil { return err } - // Here it's same to return validate(obj), but util now we can't add + // Here it's same to return validate(obj), but until now we can't add // `binding:""` to the struct which automatically generate by gen-proto return nil // return validate(obj) diff --git a/binding/toml.go b/binding/toml.go index a66b93aa..2681231d 100644 --- a/binding/toml.go +++ b/binding/toml.go @@ -31,5 +31,5 @@ func decodeToml(r io.Reader, obj any) error { if err := decoder.Decode(obj); err != nil { return err } - return decoder.Decode(obj) + return validate(obj) } diff --git a/binding/validate_test.go b/binding/validate_test.go index 1fc15ff0..c9bbe601 100644 --- a/binding/validate_test.go +++ b/binding/validate_test.go @@ -11,6 +11,7 @@ import ( "github.com/go-playground/validator/v10" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type testInterface interface { @@ -113,10 +114,10 @@ func TestValidateNoValidationValues(t *testing.T) { test := createNoValidationValues() empty := structNoValidationValues{} - assert.Nil(t, validate(test)) - assert.Nil(t, validate(&test)) - assert.Nil(t, validate(empty)) - assert.Nil(t, validate(&empty)) + require.NoError(t, validate(test)) + require.NoError(t, validate(&test)) + require.NoError(t, validate(empty)) + require.NoError(t, validate(&empty)) assert.Equal(t, origin, test) } @@ -163,8 +164,8 @@ func TestValidateNoValidationPointers(t *testing.T) { //assert.Nil(t, validate(test)) //assert.Nil(t, validate(&test)) - assert.Nil(t, validate(empty)) - assert.Nil(t, validate(&empty)) + require.NoError(t, validate(empty)) + require.NoError(t, validate(&empty)) //assert.Equal(t, origin, test) } @@ -173,22 +174,22 @@ type Object map[string]any func TestValidatePrimitives(t *testing.T) { obj := Object{"foo": "bar", "bar": 1} - assert.NoError(t, validate(obj)) - assert.NoError(t, validate(&obj)) + require.NoError(t, validate(obj)) + require.NoError(t, validate(&obj)) assert.Equal(t, Object{"foo": "bar", "bar": 1}, obj) obj2 := []Object{{"foo": "bar", "bar": 1}, {"foo": "bar", "bar": 1}} - assert.NoError(t, validate(obj2)) - assert.NoError(t, validate(&obj2)) + require.NoError(t, validate(obj2)) + require.NoError(t, validate(&obj2)) nu := 10 - assert.NoError(t, validate(nu)) - assert.NoError(t, validate(&nu)) + require.NoError(t, validate(nu)) + require.NoError(t, validate(&nu)) assert.Equal(t, 10, nu) str := "value" - assert.NoError(t, validate(str)) - assert.NoError(t, validate(&str)) + require.NoError(t, validate(str)) + require.NoError(t, validate(&str)) assert.Equal(t, "value", str) } @@ -212,8 +213,8 @@ func TestValidateAndModifyStruct(t *testing.T) { s := structModifyValidation{Integer: 1} errs := validate(&s) - assert.Nil(t, errs) - assert.Equal(t, s, structModifyValidation{Integer: 0}) + require.NoError(t, errs) + assert.Equal(t, structModifyValidation{Integer: 0}, s) } // structCustomValidation is a helper struct we use to check that @@ -239,14 +240,14 @@ func TestValidatorEngine(t *testing.T) { err := engine.RegisterValidation("notone", notOne) // Check that we can register custom validation without error - assert.Nil(t, err) + require.NoError(t, err) // Create an instance which will fail validation withOne := structCustomValidation{Integer: 1} errs := validate(withOne) // Check that we got back non-nil errs - assert.NotNil(t, errs) + require.Error(t, errs) // Check that the error matches expectation - assert.Error(t, errs, "", "", "notone") + require.Error(t, errs, "", "", "notone") } diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 00000000..47782e50 --- /dev/null +++ b/codecov.yml @@ -0,0 +1,13 @@ +coverage: + require_ci_to_pass: true + + status: + project: + default: + target: 99% + threshold: 99% + + patch: + default: + target: 99% + threshold: 95% \ No newline at end of file diff --git a/context.go b/context.go index 1d7b47b3..288c0afb 100644 --- a/context.go +++ b/context.go @@ -7,6 +7,7 @@ package gin import ( "errors" "io" + "io/fs" "log" "math" "mime/multipart" @@ -34,6 +35,7 @@ const ( MIMEPOSTForm = binding.MIMEPOSTForm MIMEMultipartPOSTForm = binding.MIMEMultipartPOSTForm MIMEYAML = binding.MIMEYAML + MIMEYAML2 = binding.MIMEYAML2 MIMETOML = binding.MIMETOML ) @@ -43,6 +45,10 @@ const BodyBytesKey = "_gin-gonic/gin/bodybyteskey" // ContextKey is the key that a Context returns itself for. const ContextKey = "_gin-gonic/gin/contextkey" +type ContextKeyType int + +const ContextRequestKey ContextKeyType = 0 + // abortIndex represents a typical value used in abort functions. const abortIndex int8 = math.MaxInt8 >> 1 @@ -113,20 +119,27 @@ func (c *Context) Copy() *Context { cp := Context{ writermem: c.writermem, Request: c.Request, - Params: c.Params, engine: c.engine, } + cp.writermem.ResponseWriter = nil cp.Writer = &cp.writermem cp.index = abortIndex cp.handlers = nil - cp.Keys = map[string]any{} - for k, v := range c.Keys { + cp.fullPath = c.fullPath + + cKeys := c.Keys + cp.Keys = make(map[string]any, len(cKeys)) + c.mu.RLock() + for k, v := range cKeys { cp.Keys[k] = v } - paramCopy := make([]Param, len(cp.Params)) - copy(paramCopy, cp.Params) - cp.Params = paramCopy + c.mu.RUnlock() + + cParams := c.Params + cp.Params = make([]Param, len(cParams)) + copy(cp.Params, cParams) + return &cp } @@ -141,6 +154,9 @@ func (c *Context) HandlerName() string { func (c *Context) HandlerNames() []string { hn := make([]string, 0, len(c.handlers)) for _, val := range c.handlers { + if val == nil { + continue + } hn = append(hn, nameOfFunction(val)) } return hn @@ -171,7 +187,9 @@ func (c *Context) FullPath() string { func (c *Context) Next() { c.index++ for c.index < int8(len(c.handlers)) { - c.handlers[c.index](c) + if c.handlers[c.index] != nil { + c.handlers[c.index](c) + } c.index++ } } @@ -273,108 +291,171 @@ func (c *Context) MustGet(key string) any { panic("Key \"" + key + "\" does not exist") } -// GetString returns the value associated with the key as a string. -func (c *Context) GetString(key string) (s string) { +func getTyped[T any](c *Context, key string) (res T) { if val, ok := c.Get(key); ok && val != nil { - s, _ = val.(string) + res, _ = val.(T) } return } +// GetString returns the value associated with the key as a string. +func (c *Context) GetString(key string) (s string) { + return getTyped[string](c, key) +} + // GetBool returns the value associated with the key as a boolean. func (c *Context) GetBool(key string) (b bool) { - if val, ok := c.Get(key); ok && val != nil { - b, _ = val.(bool) - } - return + return getTyped[bool](c, key) } // GetInt returns the value associated with the key as an integer. func (c *Context) GetInt(key string) (i int) { - if val, ok := c.Get(key); ok && val != nil { - i, _ = val.(int) - } - return + return getTyped[int](c, key) } -// GetInt64 returns the value associated with the key as an integer. +// GetInt8 returns the value associated with the key as an integer 8. +func (c *Context) GetInt8(key string) (i8 int8) { + return getTyped[int8](c, key) +} + +// GetInt16 returns the value associated with the key as an integer 16. +func (c *Context) GetInt16(key string) (i16 int16) { + return getTyped[int16](c, key) +} + +// GetInt32 returns the value associated with the key as an integer 32. +func (c *Context) GetInt32(key string) (i32 int32) { + return getTyped[int32](c, key) +} + +// GetInt64 returns the value associated with the key as an integer 64. func (c *Context) GetInt64(key string) (i64 int64) { - if val, ok := c.Get(key); ok && val != nil { - i64, _ = val.(int64) - } - return + return getTyped[int64](c, key) } // GetUint returns the value associated with the key as an unsigned integer. func (c *Context) GetUint(key string) (ui uint) { - if val, ok := c.Get(key); ok && val != nil { - ui, _ = val.(uint) - } - return + return getTyped[uint](c, key) } -// GetUint64 returns the value associated with the key as an unsigned integer. +// GetUint8 returns the value associated with the key as an unsigned integer 8. +func (c *Context) GetUint8(key string) (ui8 uint8) { + return getTyped[uint8](c, key) +} + +// GetUint16 returns the value associated with the key as an unsigned integer 16. +func (c *Context) GetUint16(key string) (ui16 uint16) { + return getTyped[uint16](c, key) +} + +// GetUint32 returns the value associated with the key as an unsigned integer 32. +func (c *Context) GetUint32(key string) (ui32 uint32) { + return getTyped[uint32](c, key) +} + +// GetUint64 returns the value associated with the key as an unsigned integer 64. func (c *Context) GetUint64(key string) (ui64 uint64) { - if val, ok := c.Get(key); ok && val != nil { - ui64, _ = val.(uint64) - } - return + return getTyped[uint64](c, key) +} + +// GetFloat32 returns the value associated with the key as a float32. +func (c *Context) GetFloat32(key string) (f32 float32) { + return getTyped[float32](c, key) } // GetFloat64 returns the value associated with the key as a float64. func (c *Context) GetFloat64(key string) (f64 float64) { - if val, ok := c.Get(key); ok && val != nil { - f64, _ = val.(float64) - } - return + return getTyped[float64](c, key) } // GetTime returns the value associated with the key as time. func (c *Context) GetTime(key string) (t time.Time) { - if val, ok := c.Get(key); ok && val != nil { - t, _ = val.(time.Time) - } - return + return getTyped[time.Time](c, key) } // GetDuration returns the value associated with the key as a duration. func (c *Context) GetDuration(key string) (d time.Duration) { - if val, ok := c.Get(key); ok && val != nil { - d, _ = val.(time.Duration) - } - return + return getTyped[time.Duration](c, key) +} + +// GetIntSlice returns the value associated with the key as a slice of integers. +func (c *Context) GetIntSlice(key string) (is []int) { + return getTyped[[]int](c, key) +} + +// GetInt8Slice returns the value associated with the key as a slice of int8 integers. +func (c *Context) GetInt8Slice(key string) (i8s []int8) { + return getTyped[[]int8](c, key) +} + +// GetInt16Slice returns the value associated with the key as a slice of int16 integers. +func (c *Context) GetInt16Slice(key string) (i16s []int16) { + return getTyped[[]int16](c, key) +} + +// GetInt32Slice returns the value associated with the key as a slice of int32 integers. +func (c *Context) GetInt32Slice(key string) (i32s []int32) { + return getTyped[[]int32](c, key) +} + +// GetInt64Slice returns the value associated with the key as a slice of int64 integers. +func (c *Context) GetInt64Slice(key string) (i64s []int64) { + return getTyped[[]int64](c, key) +} + +// GetUintSlice returns the value associated with the key as a slice of unsigned integers. +func (c *Context) GetUintSlice(key string) (uis []uint) { + return getTyped[[]uint](c, key) +} + +// GetUint8Slice returns the value associated with the key as a slice of uint8 integers. +func (c *Context) GetUint8Slice(key string) (ui8s []uint8) { + return getTyped[[]uint8](c, key) +} + +// GetUint16Slice returns the value associated with the key as a slice of uint16 integers. +func (c *Context) GetUint16Slice(key string) (ui16s []uint16) { + return getTyped[[]uint16](c, key) +} + +// GetUint32Slice returns the value associated with the key as a slice of uint32 integers. +func (c *Context) GetUint32Slice(key string) (ui32s []uint32) { + return getTyped[[]uint32](c, key) +} + +// GetUint64Slice returns the value associated with the key as a slice of uint64 integers. +func (c *Context) GetUint64Slice(key string) (ui64s []uint64) { + return getTyped[[]uint64](c, key) +} + +// GetFloat32Slice returns the value associated with the key as a slice of float32 numbers. +func (c *Context) GetFloat32Slice(key string) (f32s []float32) { + return getTyped[[]float32](c, key) +} + +// GetFloat64Slice returns the value associated with the key as a slice of float64 numbers. +func (c *Context) GetFloat64Slice(key string) (f64s []float64) { + return getTyped[[]float64](c, key) } // GetStringSlice returns the value associated with the key as a slice of strings. func (c *Context) GetStringSlice(key string) (ss []string) { - if val, ok := c.Get(key); ok && val != nil { - ss, _ = val.([]string) - } - return + return getTyped[[]string](c, key) } // GetStringMap returns the value associated with the key as a map of interfaces. func (c *Context) GetStringMap(key string) (sm map[string]any) { - if val, ok := c.Get(key); ok && val != nil { - sm, _ = val.(map[string]any) - } - return + return getTyped[map[string]any](c, key) } // GetStringMapString returns the value associated with the key as a map of strings. func (c *Context) GetStringMapString(key string) (sms map[string]string) { - if val, ok := c.Get(key); ok && val != nil { - sms, _ = val.(map[string]string) - } - return + return getTyped[map[string]string](c, key) } // GetStringMapStringSlice returns the value associated with the key as a map to a slice of strings. func (c *Context) GetStringMapStringSlice(key string) (smss map[string][]string) { - if val, ok := c.Get(key); ok && val != nil { - smss, _ = val.(map[string][]string) - } - return + return getTyped[map[string][]string](c, key) } /************************************/ @@ -386,7 +467,7 @@ func (c *Context) GetStringMapStringSlice(key string) (smss map[string][]string) // // router.GET("/user/:id", func(c *gin.Context) { // // a GET request to /user/john -// id := c.Param("id") // id == "/john" +// id := c.Param("id") // id == "john" // // a GET request to /user/john/ // id := c.Param("id") // id == "/john/" // }) @@ -457,7 +538,7 @@ func (c *Context) QueryArray(key string) (values []string) { func (c *Context) initQueryCache() { if c.queryCache == nil { - if c.Request != nil { + if c.Request != nil && c.Request.URL != nil { c.queryCache = c.Request.URL.Query() } else { c.queryCache = url.Values{} @@ -596,14 +677,22 @@ func (c *Context) MultipartForm() (*multipart.Form, error) { } // SaveUploadedFile uploads the form file to specific dst. -func (c *Context) SaveUploadedFile(file *multipart.FileHeader, dst string) error { +func (c *Context) SaveUploadedFile(file *multipart.FileHeader, dst string, perm ...fs.FileMode) error { src, err := file.Open() if err != nil { return err } defer src.Close() - if err = os.MkdirAll(filepath.Dir(dst), 0750); err != nil { + var mode os.FileMode = 0o750 + if len(perm) > 0 { + mode = perm[0] + } + dir := filepath.Dir(dst) + if err = os.MkdirAll(dir, mode); err != nil { + return err + } + if err = os.Chmod(dir, mode); err != nil { return err } @@ -656,6 +745,11 @@ func (c *Context) BindTOML(obj any) error { return c.MustBindWith(obj, binding.TOML) } +// BindPlain is a shortcut for c.MustBindWith(obj, binding.Plain). +func (c *Context) BindPlain(obj any) error { + return c.MustBindWith(obj, binding.Plain) +} + // BindHeader is a shortcut for c.MustBindWith(obj, binding.Header). func (c *Context) BindHeader(obj any) error { return c.MustBindWith(obj, binding.Header) @@ -728,6 +822,11 @@ func (c *Context) ShouldBindTOML(obj any) error { return c.ShouldBindWith(obj, binding.TOML) } +// ShouldBindPlain is a shortcut for c.ShouldBindWith(obj, binding.Plain). +func (c *Context) ShouldBindPlain(obj any) error { + return c.ShouldBindWith(obj, binding.Plain) +} + // ShouldBindHeader is a shortcut for c.ShouldBindWith(obj, binding.Header). func (c *Context) ShouldBindHeader(obj any) error { return c.ShouldBindWith(obj, binding.Header) @@ -735,7 +834,7 @@ func (c *Context) ShouldBindHeader(obj any) error { // ShouldBindUri binds the passed struct pointer using the specified binding engine. func (c *Context) ShouldBindUri(obj any) error { - m := make(map[string][]string) + m := make(map[string][]string, len(c.Params)) for _, v := range c.Params { m[v.Key] = []string{v.Value} } @@ -770,9 +869,34 @@ func (c *Context) ShouldBindBodyWith(obj any, bb binding.BindingBody) (err error return bb.BindBody(body, obj) } +// ShouldBindBodyWithJSON is a shortcut for c.ShouldBindBodyWith(obj, binding.JSON). +func (c *Context) ShouldBindBodyWithJSON(obj any) error { + return c.ShouldBindBodyWith(obj, binding.JSON) +} + +// ShouldBindBodyWithXML is a shortcut for c.ShouldBindBodyWith(obj, binding.XML). +func (c *Context) ShouldBindBodyWithXML(obj any) error { + return c.ShouldBindBodyWith(obj, binding.XML) +} + +// ShouldBindBodyWithYAML is a shortcut for c.ShouldBindBodyWith(obj, binding.YAML). +func (c *Context) ShouldBindBodyWithYAML(obj any) error { + return c.ShouldBindBodyWith(obj, binding.YAML) +} + +// ShouldBindBodyWithTOML is a shortcut for c.ShouldBindBodyWith(obj, binding.TOML). +func (c *Context) ShouldBindBodyWithTOML(obj any) error { + return c.ShouldBindBodyWith(obj, binding.TOML) +} + +// ShouldBindBodyWithPlain is a shortcut for c.ShouldBindBodyWith(obj, binding.Plain). +func (c *Context) ShouldBindBodyWithPlain(obj any) error { + return c.ShouldBindBodyWith(obj, binding.Plain) +} + // ClientIP implements one best effort algorithm to return the real client IP. // It calls c.RemoteIP() under the hood, to check if the remote IP is a trusted proxy or not. -// If it is it will then try to parse the headers defined in Engine.RemoteIPHeaders (defaulting to [X-Forwarded-For, X-Real-Ip]). +// If it is it will then try to parse the headers defined in Engine.RemoteIPHeaders (defaulting to [X-Forwarded-For, X-Real-IP]). // If the headers are not syntactically valid OR the remote IP does not correspond to a trusted proxy, // the remote IP (coming from Request.RemoteAddr) is returned. func (c *Context) ClientIP() string { @@ -880,6 +1004,9 @@ func (c *Context) GetHeader(key string) string { // GetRawData returns stream data. func (c *Context) GetRawData() ([]byte, error) { + if c.Request.Body == nil { + return nil, errors.New("cannot read nil body") + } return io.ReadAll(c.Request.Body) } @@ -1134,7 +1261,7 @@ func (c *Context) Negotiate(code int, config Negotiate) { data := chooseData(config.XMLData, config.Data) c.XML(code, data) - case binding.MIMEYAML: + case binding.MIMEYAML, binding.MIMEYAML2: data := chooseData(config.YAMLData, config.Data) c.YAML(code, data) @@ -1222,7 +1349,7 @@ func (c *Context) Err() error { // if no value is associated with key. Successive calls to Value with // the same key returns the same result. func (c *Context) Value(key any) any { - if key == 0 { + if key == ContextRequestKey { return c.Request } if key == ContextKey { diff --git a/context_1.18_test.go b/context_1.18_test.go deleted file mode 100644 index 6118beaa..00000000 --- a/context_1.18_test.go +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2021 Gin Core Team. All rights reserved. -// Use of this source code is governed by a MIT style -// license that can be found in the LICENSE file. - -//go:build !go1.19 - -package gin - -import ( - "bytes" - "mime/multipart" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestContextFormFileFailed18(t *testing.T) { - buf := new(bytes.Buffer) - mw := multipart.NewWriter(buf) - defer func(mw *multipart.Writer) { - err := mw.Close() - if err != nil { - assert.Error(t, err) - } - }(mw) - c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", nil) - c.Request.Header.Set("Content-Type", mw.FormDataContentType()) - c.engine.MaxMultipartMemory = 8 << 20 - assert.Panics(t, func() { - f, err := c.FormFile("file") - assert.Error(t, err) - assert.Nil(t, f) - }) -} diff --git a/context_1.19_test.go b/context_1.19_test.go deleted file mode 100644 index dd75325b..00000000 --- a/context_1.19_test.go +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2022 Gin Core Team. All rights reserved. -// Use of this source code is governed by a MIT style -// license that can be found in the LICENSE file. - -//go:build go1.19 - -package gin - -import ( - "bytes" - "mime/multipart" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestContextFormFileFailed19(t *testing.T) { - buf := new(bytes.Buffer) - mw := multipart.NewWriter(buf) - mw.Close() - c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", nil) - c.Request.Header.Set("Content-Type", mw.FormDataContentType()) - c.engine.MaxMultipartMemory = 8 << 20 - f, err := c.FormFile("file") - assert.Error(t, err) - assert.Nil(t, f) -} diff --git a/context_test.go b/context_test.go index 237eab0e..562df979 100644 --- a/context_test.go +++ b/context_test.go @@ -11,13 +11,16 @@ import ( "fmt" "html/template" "io" + "io/fs" "mime/multipart" "net" "net/http" "net/http/httptest" "net/url" "os" + "path/filepath" "reflect" + "strconv" "strings" "sync" "testing" @@ -27,6 +30,7 @@ import ( "github.com/gin-gonic/gin/binding" testdata "github.com/gin-gonic/gin/testdata/protoexample" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" ) @@ -58,7 +62,7 @@ func createMultipartRequest() *http.Request { must(mw.WriteField("time_location", "31/12/2016 14:55")) must(mw.WriteField("names[a]", "thinkerou")) must(mw.WriteField("names[b]", "tianou")) - req, err := http.NewRequest("POST", "/", body) + req, err := http.NewRequest(http.MethodPost, "/", body) must(err) req.Header.Set("Content-Type", MIMEMultipartPOSTForm+"; boundary="+boundary) return req @@ -74,41 +78,50 @@ func TestContextFormFile(t *testing.T) { buf := new(bytes.Buffer) mw := multipart.NewWriter(buf) w, err := mw.CreateFormFile("file", "test") - if assert.NoError(t, err) { - _, err = w.Write([]byte("test")) - assert.NoError(t, err) - } + require.NoError(t, err) + _, err = w.Write([]byte("test")) + require.NoError(t, err) mw.Close() c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", buf) + c.Request, _ = http.NewRequest(http.MethodPost, "/", buf) c.Request.Header.Set("Content-Type", mw.FormDataContentType()) f, err := c.FormFile("file") - if assert.NoError(t, err) { - assert.Equal(t, "test", f.Filename) - } + require.NoError(t, err) + assert.Equal(t, "test", f.Filename) - assert.NoError(t, c.SaveUploadedFile(f, "test")) + require.NoError(t, c.SaveUploadedFile(f, "test")) +} + +func TestContextFormFileFailed(t *testing.T) { + buf := new(bytes.Buffer) + mw := multipart.NewWriter(buf) + mw.Close() + c, _ := CreateTestContext(httptest.NewRecorder()) + c.Request, _ = http.NewRequest(http.MethodPost, "/", nil) + c.Request.Header.Set("Content-Type", mw.FormDataContentType()) + c.engine.MaxMultipartMemory = 8 << 20 + f, err := c.FormFile("file") + require.Error(t, err) + assert.Nil(t, f) } func TestContextMultipartForm(t *testing.T) { buf := new(bytes.Buffer) mw := multipart.NewWriter(buf) - assert.NoError(t, mw.WriteField("foo", "bar")) + require.NoError(t, mw.WriteField("foo", "bar")) w, err := mw.CreateFormFile("file", "test") - if assert.NoError(t, err) { - _, err = w.Write([]byte("test")) - assert.NoError(t, err) - } + require.NoError(t, err) + _, err = w.Write([]byte("test")) + require.NoError(t, err) mw.Close() c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", buf) + c.Request, _ = http.NewRequest(http.MethodPost, "/", buf) c.Request.Header.Set("Content-Type", mw.FormDataContentType()) f, err := c.MultipartForm() - if assert.NoError(t, err) { - assert.NotNil(t, f) - } + require.NoError(t, err) + assert.NotNil(t, f) - assert.NoError(t, c.SaveUploadedFile(f.File["file"][0], "test")) + require.NoError(t, c.SaveUploadedFile(f.File["file"][0], "test")) } func TestSaveUploadedOpenFailed(t *testing.T) { @@ -117,33 +130,70 @@ func TestSaveUploadedOpenFailed(t *testing.T) { mw.Close() c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", buf) + c.Request, _ = http.NewRequest(http.MethodPost, "/", buf) c.Request.Header.Set("Content-Type", mw.FormDataContentType()) f := &multipart.FileHeader{ Filename: "file", } - assert.Error(t, c.SaveUploadedFile(f, "test")) + require.Error(t, c.SaveUploadedFile(f, "test")) } func TestSaveUploadedCreateFailed(t *testing.T) { buf := new(bytes.Buffer) mw := multipart.NewWriter(buf) w, err := mw.CreateFormFile("file", "test") - if assert.NoError(t, err) { - _, err = w.Write([]byte("test")) - assert.NoError(t, err) - } + require.NoError(t, err) + _, err = w.Write([]byte("test")) + require.NoError(t, err) mw.Close() c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", buf) + c.Request, _ = http.NewRequest(http.MethodPost, "/", buf) c.Request.Header.Set("Content-Type", mw.FormDataContentType()) f, err := c.FormFile("file") - if assert.NoError(t, err) { - assert.Equal(t, "test", f.Filename) - } + require.NoError(t, err) + assert.Equal(t, "test", f.Filename) - assert.Error(t, c.SaveUploadedFile(f, "/")) + require.Error(t, c.SaveUploadedFile(f, "/")) +} + +func TestSaveUploadedFileWithPermission(t *testing.T) { + buf := new(bytes.Buffer) + mw := multipart.NewWriter(buf) + w, err := mw.CreateFormFile("file", "permission_test") + require.NoError(t, err) + _, err = w.Write([]byte("permission_test")) + require.NoError(t, err) + mw.Close() + c, _ := CreateTestContext(httptest.NewRecorder()) + c.Request, _ = http.NewRequest(http.MethodPost, "/", buf) + c.Request.Header.Set("Content-Type", mw.FormDataContentType()) + f, err := c.FormFile("file") + require.NoError(t, err) + assert.Equal(t, "permission_test", f.Filename) + var mode fs.FileMode = 0o755 + require.NoError(t, c.SaveUploadedFile(f, "permission_test", mode)) + info, err := os.Stat(filepath.Dir("permission_test")) + require.NoError(t, err) + assert.Equal(t, info.Mode().Perm(), mode) +} + +func TestSaveUploadedFileWithPermissionFailed(t *testing.T) { + buf := new(bytes.Buffer) + mw := multipart.NewWriter(buf) + w, err := mw.CreateFormFile("file", "permission_test") + require.NoError(t, err) + _, err = w.Write([]byte("permission_test")) + require.NoError(t, err) + mw.Close() + c, _ := CreateTestContext(httptest.NewRecorder()) + c.Request, _ = http.NewRequest(http.MethodPost, "/", buf) + c.Request.Header.Set("Content-Type", mw.FormDataContentType()) + f, err := c.FormFile("file") + require.NoError(t, err) + assert.Equal(t, "permission_test", f.Filename) + var mode fs.FileMode = 0o644 + require.Error(t, c.SaveUploadedFile(f, "test/permission_test", mode)) } func TestContextReset(t *testing.T) { @@ -161,10 +211,10 @@ func TestContextReset(t *testing.T) { assert.False(t, c.IsAborted()) assert.Nil(t, c.Keys) assert.Nil(t, c.Accepted) - assert.Len(t, c.Errors, 0) + assert.Empty(t, c.Errors) assert.Empty(t, c.Errors.Errors()) assert.Empty(t, c.Errors.ByType(ErrorTypeAny)) - assert.Len(t, c.Params, 0) + assert.Empty(t, c.Params) assert.EqualValues(t, c.index, -1) assert.Equal(t, c.Writer.(*responseWriter), &c.writermem) } @@ -217,13 +267,13 @@ func TestContextSetGetValues(t *testing.T) { var a any = 1 c.Set("intInterface", a) - assert.Exactly(t, c.MustGet("string").(string), "this is a string") + assert.Exactly(t, "this is a string", c.MustGet("string").(string)) assert.Exactly(t, c.MustGet("int32").(int32), int32(-42)) - assert.Exactly(t, c.MustGet("int64").(int64), int64(42424242424242)) - assert.Exactly(t, c.MustGet("uint64").(uint64), uint64(42)) - assert.Exactly(t, c.MustGet("float32").(float32), float32(4.2)) - assert.Exactly(t, c.MustGet("float64").(float64), 4.2) - assert.Exactly(t, c.MustGet("intInterface").(int), 1) + assert.Exactly(t, int64(42424242424242), c.MustGet("int64").(int64)) + assert.Exactly(t, uint64(42), c.MustGet("uint64").(uint64)) + assert.InDelta(t, float32(4.2), c.MustGet("float32").(float32), 0.01) + assert.InDelta(t, 4.2, c.MustGet("float64").(float64), 0.01) + assert.Exactly(t, 1, c.MustGet("intInterface").(int)) } func TestContextGetString(t *testing.T) { @@ -244,6 +294,30 @@ func TestContextGetInt(t *testing.T) { assert.Equal(t, 1, c.GetInt("int")) } +func TestContextGetInt8(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + key := "int8" + value := int8(0x7F) + c.Set(key, value) + assert.Equal(t, value, c.GetInt8(key)) +} + +func TestContextGetInt16(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + key := "int16" + value := int16(0x7FFF) + c.Set(key, value) + assert.Equal(t, value, c.GetInt16(key)) +} + +func TestContextGetInt32(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + key := "int32" + value := int32(0x7FFFFFFF) + c.Set(key, value) + assert.Equal(t, value, c.GetInt32(key)) +} + func TestContextGetInt64(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) c.Set("int64", int64(42424242424242)) @@ -256,16 +330,48 @@ func TestContextGetUint(t *testing.T) { assert.Equal(t, uint(1), c.GetUint("uint")) } +func TestContextGetUint8(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + key := "uint8" + value := uint8(0xFF) + c.Set(key, value) + assert.Equal(t, value, c.GetUint8(key)) +} + +func TestContextGetUint16(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + key := "uint16" + value := uint16(0xFFFF) + c.Set(key, value) + assert.Equal(t, value, c.GetUint16(key)) +} + +func TestContextGetUint32(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + key := "uint32" + value := uint32(0xFFFFFFFF) + c.Set(key, value) + assert.Equal(t, value, c.GetUint32(key)) +} + func TestContextGetUint64(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) c.Set("uint64", uint64(18446744073709551615)) assert.Equal(t, uint64(18446744073709551615), c.GetUint64("uint64")) } +func TestContextGetFloat32(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + key := "float32" + value := float32(3.14) + c.Set(key, value) + assert.InDelta(t, value, c.GetFloat32(key), 0.01) +} + func TestContextGetFloat64(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) c.Set("float64", 4.2) - assert.Equal(t, 4.2, c.GetFloat64("float64")) + assert.InDelta(t, 4.2, c.GetFloat64("float64"), 0.01) } func TestContextGetTime(t *testing.T) { @@ -281,6 +387,102 @@ func TestContextGetDuration(t *testing.T) { assert.Equal(t, time.Second, c.GetDuration("duration")) } +func TestContextGetIntSlice(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + key := "int-slice" + value := []int{1, 2} + c.Set(key, value) + assert.Equal(t, value, c.GetIntSlice(key)) +} + +func TestContextGetInt8Slice(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + key := "int8-slice" + value := []int8{1, 2} + c.Set(key, value) + assert.Equal(t, value, c.GetInt8Slice(key)) +} + +func TestContextGetInt16Slice(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + key := "int16-slice" + value := []int16{1, 2} + c.Set(key, value) + assert.Equal(t, value, c.GetInt16Slice(key)) +} + +func TestContextGetInt32Slice(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + key := "int32-slice" + value := []int32{1, 2} + c.Set(key, value) + assert.Equal(t, value, c.GetInt32Slice(key)) +} + +func TestContextGetInt64Slice(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + key := "int64-slice" + value := []int64{1, 2} + c.Set(key, value) + assert.Equal(t, value, c.GetInt64Slice(key)) +} + +func TestContextGetUintSlice(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + key := "uint-slice" + value := []uint{1, 2} + c.Set(key, value) + assert.Equal(t, value, c.GetUintSlice(key)) +} + +func TestContextGetUint8Slice(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + key := "uint8-slice" + value := []uint8{1, 2} + c.Set(key, value) + assert.Equal(t, value, c.GetUint8Slice(key)) +} + +func TestContextGetUint16Slice(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + key := "uint16-slice" + value := []uint16{1, 2} + c.Set(key, value) + assert.Equal(t, value, c.GetUint16Slice(key)) +} + +func TestContextGetUint32Slice(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + key := "uint32-slice" + value := []uint32{1, 2} + c.Set(key, value) + assert.Equal(t, value, c.GetUint32Slice(key)) +} + +func TestContextGetUint64Slice(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + key := "uint64-slice" + value := []uint64{1, 2} + c.Set(key, value) + assert.Equal(t, value, c.GetUint64Slice(key)) +} + +func TestContextGetFloat32Slice(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + key := "float32-slice" + value := []float32{1, 2} + c.Set(key, value) + assert.Equal(t, value, c.GetFloat32Slice(key)) +} + +func TestContextGetFloat64Slice(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + key := "float64-slice" + value := []float64{1, 2} + c.Set(key, value) + assert.Equal(t, value, c.GetFloat64Slice(key)) +} + func TestContextGetStringSlice(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) c.Set("slice", []string{"foo"}) @@ -320,22 +522,24 @@ func TestContextGetStringMapStringSlice(t *testing.T) { func TestContextCopy(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) c.index = 2 - c.Request, _ = http.NewRequest("POST", "/hola", nil) + c.Request, _ = http.NewRequest(http.MethodPost, "/hola", nil) c.handlers = HandlersChain{func(c *Context) {}} c.Params = Params{Param{Key: "foo", Value: "bar"}} c.Set("foo", "bar") + c.fullPath = "/hola" cp := c.Copy() assert.Nil(t, cp.handlers) assert.Nil(t, cp.writermem.ResponseWriter) assert.Equal(t, &cp.writermem, cp.Writer.(*responseWriter)) assert.Equal(t, cp.Request, c.Request) - assert.Equal(t, cp.index, abortIndex) + assert.Equal(t, abortIndex, cp.index) assert.Equal(t, cp.Keys, c.Keys) assert.Equal(t, cp.engine, c.engine) assert.Equal(t, cp.Params, c.Params) cp.Set("foo", "notBar") - assert.False(t, cp.Keys["foo"] == c.Keys["foo"]) + assert.NotEqual(t, cp.Keys["foo"], c.Keys["foo"]) + assert.Equal(t, cp.fullPath, c.fullPath) } func TestContextHandlerName(t *testing.T) { @@ -347,11 +551,11 @@ func TestContextHandlerName(t *testing.T) { func TestContextHandlerNames(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.handlers = HandlersChain{func(c *Context) {}, handlerNameTest, func(c *Context) {}, handlerNameTest2} + c.handlers = HandlersChain{func(c *Context) {}, nil, handlerNameTest, func(c *Context) {}, handlerNameTest2} names := c.HandlerNames() - assert.True(t, len(names) == 4) + assert.Len(t, names, 4) for _, name := range names { assert.Regexp(t, `^(.*/vendor/)?(github\.com/gin-gonic/gin\.){1}(TestContextHandlerNames\.func.*){0,1}(handlerNameTest.*){0,1}`, name) } @@ -375,7 +579,7 @@ func TestContextHandler(t *testing.T) { func TestContextQuery(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("GET", "http://example.com/?foo=bar&page=10&id=", nil) + c.Request, _ = http.NewRequest(http.MethodGet, "http://example.com/?foo=bar&page=10&id=", nil) value, ok := c.GetQuery("foo") assert.True(t, ok) @@ -408,6 +612,48 @@ func TestContextQuery(t *testing.T) { assert.Empty(t, c.PostForm("foo")) } +func TestContextInitQueryCache(t *testing.T) { + validURL, err := url.Parse("https://github.com/gin-gonic/gin/pull/3969?key=value&otherkey=othervalue") + require.NoError(t, err) + + tests := []struct { + testName string + testContext *Context + expectedQueryCache url.Values + }{ + { + testName: "queryCache should remain unchanged if already not nil", + testContext: &Context{ + queryCache: url.Values{"a": []string{"b"}}, + Request: &http.Request{URL: validURL}, // valid request for evidence that values weren't extracted + }, + expectedQueryCache: url.Values{"a": []string{"b"}}, + }, + { + testName: "queryCache should be empty when Request is nil", + testContext: &Context{Request: nil}, // explicit nil for readability + expectedQueryCache: url.Values{}, + }, + { + testName: "queryCache should be empty when Request.URL is nil", + testContext: &Context{Request: &http.Request{URL: nil}}, // explicit nil for readability + expectedQueryCache: url.Values{}, + }, + { + testName: "queryCache should be populated when it not yet populated and Request + Request.URL are non nil", + testContext: &Context{Request: &http.Request{URL: validURL}}, // explicit nil for readability + expectedQueryCache: url.Values{"key": []string{"value"}, "otherkey": []string{"othervalue"}}, + }, + } + + for _, test := range tests { + t.Run(test.testName, func(t *testing.T) { + test.testContext.initQueryCache() + assert.Equal(t, test.expectedQueryCache, test.testContext.queryCache) + }) + } +} + func TestContextDefaultQueryOnEmptyRequest(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) // here c.Request == nil assert.NotPanics(t, func() { @@ -426,7 +672,7 @@ func TestContextDefaultQueryOnEmptyRequest(t *testing.T) { func TestContextQueryAndPostForm(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) body := strings.NewReader("foo=bar&page=11&both=&foo=second") - c.Request, _ = http.NewRequest("POST", + c.Request, _ = http.NewRequest(http.MethodPost, "/?both=GET&id=main&id=omit&array[]=first&array[]=second&ids[a]=hi&ids[b]=3.14", body) c.Request.Header.Add("Content-Type", MIMEPOSTForm) @@ -446,7 +692,7 @@ func TestContextQueryAndPostForm(t *testing.T) { assert.Empty(t, value) assert.Empty(t, c.PostForm("both")) assert.Empty(t, c.DefaultPostForm("both", "nothing")) - assert.Equal(t, "GET", c.Query("both"), "GET") + assert.Equal(t, http.MethodGet, c.Query("both"), http.MethodGet) value, ok = c.GetQuery("id") assert.True(t, ok) @@ -473,7 +719,7 @@ func TestContextQueryAndPostForm(t *testing.T) { Both string `form:"both"` Array []string `form:"array[]"` } - assert.NoError(t, c.Bind(&obj)) + require.NoError(t, c.Bind(&obj)) assert.Equal(t, "bar", obj.Foo, "bar") assert.Equal(t, "main", obj.ID, "main") assert.Equal(t, 11, obj.Page, 11) @@ -490,11 +736,11 @@ func TestContextQueryAndPostForm(t *testing.T) { assert.Equal(t, "second", values[1]) values = c.QueryArray("nokey") - assert.Equal(t, 0, len(values)) + assert.Empty(t, values) values = c.QueryArray("both") - assert.Equal(t, 1, len(values)) - assert.Equal(t, "GET", values[0]) + assert.Len(t, values, 1) + assert.Equal(t, http.MethodGet, values[0]) dicts, ok := c.GetQueryMap("ids") assert.True(t, ok) @@ -503,22 +749,22 @@ func TestContextQueryAndPostForm(t *testing.T) { dicts, ok = c.GetQueryMap("nokey") assert.False(t, ok) - assert.Equal(t, 0, len(dicts)) + assert.Empty(t, dicts) dicts, ok = c.GetQueryMap("both") assert.False(t, ok) - assert.Equal(t, 0, len(dicts)) + assert.Empty(t, dicts) dicts, ok = c.GetQueryMap("array") assert.False(t, ok) - assert.Equal(t, 0, len(dicts)) + assert.Empty(t, dicts) dicts = c.QueryMap("ids") assert.Equal(t, "hi", dicts["a"]) assert.Equal(t, "3.14", dicts["b"]) dicts = c.QueryMap("nokey") - assert.Equal(t, 0, len(dicts)) + assert.Empty(t, dicts) } func TestContextPostFormMultipart(t *testing.T) { @@ -536,7 +782,7 @@ func TestContextPostFormMultipart(t *testing.T) { TimeLocation time.Time `form:"time_location" time_format:"02/01/2006 15:04" time_location:"Asia/Tokyo"` BlankTime time.Time `form:"blank_time" time_format:"02/01/2006 15:04"` } - assert.NoError(t, c.Bind(&obj)) + require.NoError(t, c.Bind(&obj)) assert.Equal(t, "bar", obj.Foo) assert.Equal(t, "10", obj.Bar) assert.Equal(t, 10, obj.BarAsInt) @@ -590,10 +836,10 @@ func TestContextPostFormMultipart(t *testing.T) { assert.Equal(t, "second", values[1]) values = c.PostFormArray("nokey") - assert.Equal(t, 0, len(values)) + assert.Empty(t, values) values = c.PostFormArray("foo") - assert.Equal(t, 1, len(values)) + assert.Len(t, values, 1) assert.Equal(t, "bar", values[0]) dicts, ok := c.GetPostFormMap("names") @@ -603,14 +849,14 @@ func TestContextPostFormMultipart(t *testing.T) { dicts, ok = c.GetPostFormMap("nokey") assert.False(t, ok) - assert.Equal(t, 0, len(dicts)) + assert.Empty(t, dicts) dicts = c.PostFormMap("names") assert.Equal(t, "thinkerou", dicts["a"]) assert.Equal(t, "tianou", dicts["b"]) dicts = c.PostFormMap("nokey") - assert.Equal(t, 0, len(dicts)) + assert.Empty(t, dicts) } func TestContextSetCookie(t *testing.T) { @@ -629,13 +875,13 @@ func TestContextSetCookiePathEmpty(t *testing.T) { func TestContextGetCookie(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("GET", "/get", nil) + c.Request, _ = http.NewRequest(http.MethodGet, "/get", nil) c.Request.Header.Set("Cookie", "user=gin") cookie, _ := c.Cookie("user") assert.Equal(t, "gin", cookie) _, err := c.Cookie("nokey") - assert.Error(t, err) + require.Error(t, err) } func TestContextBodyAllowedForStatus(t *testing.T) { @@ -681,7 +927,7 @@ func TestContextRenderJSON(t *testing.T) { func TestContextRenderJSONP(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("GET", "http://example.com/?callback=x", nil) + c.Request, _ = http.NewRequest(http.MethodGet, "http://example.com/?callback=x", nil) c.JSONP(http.StatusCreated, H{"foo": "bar"}) @@ -695,7 +941,7 @@ func TestContextRenderJSONP(t *testing.T) { func TestContextRenderJSONPWithoutCallback(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("GET", "http://example.com", nil) + c.Request, _ = http.NewRequest(http.MethodGet, "http://example.com", nil) c.JSONP(http.StatusCreated, H{"foo": "bar"}) @@ -740,7 +986,7 @@ func TestContextRenderNoContentAPIJSON(t *testing.T) { assert.Equal(t, http.StatusNoContent, w.Code) assert.Empty(t, w.Body.String()) - assert.Equal(t, w.Header().Get("Content-Type"), "application/vnd.api+json") + assert.Equal(t, "application/vnd.api+json", w.Header().Get("Content-Type")) } // Tests that the response is serialized as JSON @@ -838,7 +1084,7 @@ func TestContextRenderHTML2(t *testing.T) { c, router := CreateTestContext(w) // print debug warning log when Engine.trees > 0 - router.addRoute("GET", "/", HandlersChain{func(_ *Context) {}}) + router.addRoute(http.MethodGet, "/", HandlersChain{func(_ *Context) {}}) assert.Len(t, router.trees, 1) templ := template.Must(template.New("t").Parse(`Hello {{.name}}`)) @@ -948,7 +1194,7 @@ func TestContextRenderNoContentHTMLString(t *testing.T) { assert.Equal(t, "text/html; charset=utf-8", w.Header().Get("Content-Type")) } -// TestContextData tests that the response can be written from `bytestring` +// TestContextRenderData tests that the response can be written from `bytestring` // with specified MIME type func TestContextRenderData(t *testing.T) { w := httptest.NewRecorder() @@ -994,11 +1240,11 @@ func TestContextRenderFile(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("GET", "/", nil) + c.Request, _ = http.NewRequest(http.MethodGet, "/", nil) c.File("./gin.go") assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), "func New() *Engine {") + assert.Contains(t, w.Body.String(), "func New(opts ...OptionFunc) *Engine {") // Content-Type='text/plain; charset=utf-8' when go version <= 1.16, // else, Content-Type='text/x-go; charset=utf-8' assert.NotEqual(t, "", w.Header().Get("Content-Type")) @@ -1008,11 +1254,11 @@ func TestContextRenderFileFromFS(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("GET", "/some/path", nil) + c.Request, _ = http.NewRequest(http.MethodGet, "/some/path", nil) c.FileFromFS("./gin.go", Dir(".", false)) assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), "func New() *Engine {") + assert.Contains(t, w.Body.String(), "func New(opts ...OptionFunc) *Engine {") // Content-Type='text/plain; charset=utf-8' when go version <= 1.16, // else, Content-Type='text/x-go; charset=utf-8' assert.NotEqual(t, "", w.Header().Get("Content-Type")) @@ -1024,11 +1270,11 @@ func TestContextRenderAttachment(t *testing.T) { c, _ := CreateTestContext(w) newFilename := "new_filename.go" - c.Request, _ = http.NewRequest("GET", "/", nil) + c.Request, _ = http.NewRequest(http.MethodGet, "/", nil) c.FileAttachment("./gin.go", newFilename) assert.Equal(t, 200, w.Code) - assert.Contains(t, w.Body.String(), "func New() *Engine {") + assert.Contains(t, w.Body.String(), "func New(opts ...OptionFunc) *Engine {") assert.Equal(t, fmt.Sprintf("attachment; filename=\"%s\"", newFilename), w.Header().Get("Content-Disposition")) } @@ -1038,11 +1284,11 @@ func TestContextRenderAndEscapeAttachment(t *testing.T) { maliciousFilename := "tampering_field.sh\"; \\\"; dummy=.go" actualEscapedResponseFilename := "tampering_field.sh\\\"; \\\\\\\"; dummy=.go" - c.Request, _ = http.NewRequest("GET", "/", nil) + c.Request, _ = http.NewRequest(http.MethodGet, "/", nil) c.FileAttachment("./gin.go", maliciousFilename) assert.Equal(t, 200, w.Code) - assert.Contains(t, w.Body.String(), "func New() *Engine {") + assert.Contains(t, w.Body.String(), "func New(opts ...OptionFunc) *Engine {") assert.Equal(t, fmt.Sprintf("attachment; filename=\"%s\"", actualEscapedResponseFilename), w.Header().Get("Content-Disposition")) } @@ -1051,16 +1297,16 @@ func TestContextRenderUTF8Attachment(t *testing.T) { c, _ := CreateTestContext(w) newFilename := "new🧡_filename.go" - c.Request, _ = http.NewRequest("GET", "/", nil) + c.Request, _ = http.NewRequest(http.MethodGet, "/", nil) c.FileAttachment("./gin.go", newFilename) assert.Equal(t, 200, w.Code) - assert.Contains(t, w.Body.String(), "func New() *Engine {") + assert.Contains(t, w.Body.String(), "func New(opts ...OptionFunc) *Engine {") assert.Equal(t, `attachment; filename*=UTF-8''`+url.QueryEscape(newFilename), w.Header().Get("Content-Disposition")) } // TestContextRenderYAML tests that the response is serialized as YAML -// and Content-Type is set to application/x-yaml +// and Content-Type is set to application/yaml func TestContextRenderYAML(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) @@ -1069,7 +1315,7 @@ func TestContextRenderYAML(t *testing.T) { assert.Equal(t, http.StatusCreated, w.Code) assert.Equal(t, "foo: bar\n", w.Body.String()) - assert.Equal(t, "application/x-yaml; charset=utf-8", w.Header().Get("Content-Type")) + assert.Equal(t, "application/yaml; charset=utf-8", w.Header().Get("Content-Type")) } // TestContextRenderTOML tests that the response is serialized as TOML @@ -1102,7 +1348,7 @@ func TestContextRenderProtoBuf(t *testing.T) { c.ProtoBuf(http.StatusCreated, data) protoData, err := proto.Marshal(data) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, http.StatusCreated, w.Code) assert.Equal(t, string(protoData), w.Body.String()) @@ -1130,7 +1376,7 @@ func TestContextRenderRedirectWithRelativePath(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "http://example.com", nil) + c.Request, _ = http.NewRequest(http.MethodPost, "http://example.com", nil) assert.Panics(t, func() { c.Redirect(299, "/new_path") }) assert.Panics(t, func() { c.Redirect(309, "/new_path") }) @@ -1144,7 +1390,7 @@ func TestContextRenderRedirectWithAbsolutePath(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "http://example.com", nil) + c.Request, _ = http.NewRequest(http.MethodPost, "http://example.com", nil) c.Redirect(http.StatusFound, "http://google.com") c.Writer.WriteHeaderNow() @@ -1156,7 +1402,7 @@ func TestContextRenderRedirectWith201(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "http://example.com", nil) + c.Request, _ = http.NewRequest(http.MethodPost, "http://example.com", nil) c.Redirect(http.StatusCreated, "/resource") c.Writer.WriteHeaderNow() @@ -1166,7 +1412,7 @@ func TestContextRenderRedirectWith201(t *testing.T) { func TestContextRenderRedirectAll(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "http://example.com", nil) + c.Request, _ = http.NewRequest(http.MethodPost, "http://example.com", nil) assert.Panics(t, func() { c.Redirect(http.StatusOK, "/resource") }) assert.Panics(t, func() { c.Redirect(http.StatusAccepted, "/resource") }) assert.Panics(t, func() { c.Redirect(299, "/resource") }) @@ -1178,10 +1424,10 @@ func TestContextRenderRedirectAll(t *testing.T) { func TestContextNegotiationWithJSON(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "", nil) + c.Request, _ = http.NewRequest(http.MethodPost, "", nil) c.Negotiate(http.StatusOK, Negotiate{ - Offered: []string{MIMEJSON, MIMEXML, MIMEYAML}, + Offered: []string{MIMEJSON, MIMEXML, MIMEYAML, MIMEYAML2}, Data: H{"foo": "bar"}, }) @@ -1193,10 +1439,10 @@ func TestContextNegotiationWithJSON(t *testing.T) { func TestContextNegotiationWithXML(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "", nil) + c.Request, _ = http.NewRequest(http.MethodPost, "", nil) c.Negotiate(http.StatusOK, Negotiate{ - Offered: []string{MIMEXML, MIMEJSON, MIMEYAML}, + Offered: []string{MIMEXML, MIMEJSON, MIMEYAML, MIMEYAML2}, Data: H{"foo": "bar"}, }) @@ -1208,25 +1454,25 @@ func TestContextNegotiationWithXML(t *testing.T) { func TestContextNegotiationWithYAML(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "", nil) + c.Request, _ = http.NewRequest(http.MethodPost, "", nil) c.Negotiate(http.StatusOK, Negotiate{ - Offered: []string{MIMEYAML, MIMEXML, MIMEJSON, MIMETOML}, + Offered: []string{MIMEYAML, MIMEXML, MIMEJSON, MIMETOML, MIMEYAML2}, Data: H{"foo": "bar"}, }) assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, "foo: bar\n", w.Body.String()) - assert.Equal(t, "application/x-yaml; charset=utf-8", w.Header().Get("Content-Type")) + assert.Equal(t, "application/yaml; charset=utf-8", w.Header().Get("Content-Type")) } func TestContextNegotiationWithTOML(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "", nil) + c.Request, _ = http.NewRequest(http.MethodPost, "", nil) c.Negotiate(http.StatusOK, Negotiate{ - Offered: []string{MIMETOML, MIMEXML, MIMEJSON, MIMEYAML}, + Offered: []string{MIMETOML, MIMEXML, MIMEJSON, MIMEYAML, MIMEYAML2}, Data: H{"foo": "bar"}, }) @@ -1238,7 +1484,7 @@ func TestContextNegotiationWithTOML(t *testing.T) { func TestContextNegotiationWithHTML(t *testing.T) { w := httptest.NewRecorder() c, router := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "", nil) + c.Request, _ = http.NewRequest(http.MethodPost, "", nil) templ := template.Must(template.New("t").Parse(`Hello {{.name}}`)) router.SetHTMLTemplate(templ) @@ -1256,20 +1502,20 @@ func TestContextNegotiationWithHTML(t *testing.T) { func TestContextNegotiationNotSupport(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "", nil) + c.Request, _ = http.NewRequest(http.MethodPost, "", nil) c.Negotiate(http.StatusOK, Negotiate{ Offered: []string{MIMEPOSTForm}, }) assert.Equal(t, http.StatusNotAcceptable, w.Code) - assert.Equal(t, c.index, abortIndex) + assert.Equal(t, abortIndex, c.index) assert.True(t, c.IsAborted()) } func TestContextNegotiationFormat(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "", nil) + c.Request, _ = http.NewRequest(http.MethodPost, "", nil) assert.Panics(t, func() { c.NegotiateFormat() }) assert.Equal(t, MIMEJSON, c.NegotiateFormat(MIMEJSON, MIMEXML)) @@ -1278,7 +1524,7 @@ func TestContextNegotiationFormat(t *testing.T) { func TestContextNegotiationFormatWithAccept(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", nil) + c.Request, _ = http.NewRequest(http.MethodPost, "/", nil) c.Request.Header.Add("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9;q=0.8") assert.Equal(t, MIMEXML, c.NegotiateFormat(MIMEJSON, MIMEXML)) @@ -1288,31 +1534,31 @@ func TestContextNegotiationFormatWithAccept(t *testing.T) { func TestContextNegotiationFormatWithWildcardAccept(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", nil) + c.Request, _ = http.NewRequest(http.MethodPost, "/", nil) c.Request.Header.Add("Accept", "*/*") - assert.Equal(t, c.NegotiateFormat("*/*"), "*/*") - assert.Equal(t, c.NegotiateFormat("text/*"), "text/*") - assert.Equal(t, c.NegotiateFormat("application/*"), "application/*") - assert.Equal(t, c.NegotiateFormat(MIMEJSON), MIMEJSON) - assert.Equal(t, c.NegotiateFormat(MIMEXML), MIMEXML) - assert.Equal(t, c.NegotiateFormat(MIMEHTML), MIMEHTML) + assert.Equal(t, "*/*", c.NegotiateFormat("*/*")) + assert.Equal(t, "text/*", c.NegotiateFormat("text/*")) + assert.Equal(t, "application/*", c.NegotiateFormat("application/*")) + assert.Equal(t, MIMEJSON, c.NegotiateFormat(MIMEJSON)) + assert.Equal(t, MIMEXML, c.NegotiateFormat(MIMEXML)) + assert.Equal(t, MIMEHTML, c.NegotiateFormat(MIMEHTML)) c, _ = CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", nil) + c.Request, _ = http.NewRequest(http.MethodPost, "/", nil) c.Request.Header.Add("Accept", "text/*") - assert.Equal(t, c.NegotiateFormat("*/*"), "*/*") - assert.Equal(t, c.NegotiateFormat("text/*"), "text/*") - assert.Equal(t, c.NegotiateFormat("application/*"), "") - assert.Equal(t, c.NegotiateFormat(MIMEJSON), "") - assert.Equal(t, c.NegotiateFormat(MIMEXML), "") - assert.Equal(t, c.NegotiateFormat(MIMEHTML), MIMEHTML) + assert.Equal(t, "*/*", c.NegotiateFormat("*/*")) + assert.Equal(t, "text/*", c.NegotiateFormat("text/*")) + assert.Equal(t, "", c.NegotiateFormat("application/*")) + assert.Equal(t, "", c.NegotiateFormat(MIMEJSON)) + assert.Equal(t, "", c.NegotiateFormat(MIMEXML)) + assert.Equal(t, MIMEHTML, c.NegotiateFormat(MIMEHTML)) } func TestContextNegotiationFormatCustom(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", nil) + c.Request, _ = http.NewRequest(http.MethodPost, "/", nil) c.Request.Header.Add("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9;q=0.8") c.Accepted = nil @@ -1325,7 +1571,7 @@ func TestContextNegotiationFormatCustom(t *testing.T) { func TestContextNegotiationFormat2(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", nil) + c.Request, _ = http.NewRequest(http.MethodPost, "/", nil) c.Request.Header.Add("Accept", "image/tiff-fx") assert.Equal(t, "", c.NegotiateFormat("image/tiff")) @@ -1345,7 +1591,7 @@ func TestContextIsAborted(t *testing.T) { assert.True(t, c.IsAborted()) } -// TestContextData tests that the response can be written from `bytestring` +// TestContextAbortWithStatus tests that the response can be written from `bytestring` // with specified MIME type func TestContextAbortWithStatus(t *testing.T) { w := httptest.NewRecorder() @@ -1386,7 +1632,7 @@ func TestContextAbortWithStatusJSON(t *testing.T) { buf := new(bytes.Buffer) _, err := buf.ReadFrom(w.Body) - assert.NoError(t, err) + require.NoError(t, err) jsonStringBody := buf.String() assert.Equal(t, "{\"foo\":\"fooValue\",\"bar\":\"barValue\"}", jsonStringBody) } @@ -1453,7 +1699,7 @@ func TestContextAbortWithError(t *testing.T) { func TestContextClientIP(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", nil) + c.Request, _ = http.NewRequest(http.MethodPost, "/", nil) c.engine.trustedCIDRs, _ = c.engine.prepareTrustedCIDRs() resetContextForClientIPTests(c) @@ -1596,7 +1842,7 @@ func resetContextForClientIPTests(c *Context) { func TestContextContentType(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", nil) + c.Request, _ = http.NewRequest(http.MethodPost, "/", nil) c.Request.Header.Set("Content-Type", "application/json; charset=utf-8") assert.Equal(t, "application/json", c.ContentType()) @@ -1624,14 +1870,14 @@ func TestContextBindRequestTooLarge(t *testing.T) { func TestContextAutoBindJSON(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", strings.NewReader(`{"foo":"bar", "bar":"foo"}`)) + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader(`{"foo":"bar", "bar":"foo"}`)) c.Request.Header.Add("Content-Type", MIMEJSON) var obj struct { Foo string `json:"foo"` Bar string `json:"bar"` } - assert.NoError(t, c.Bind(&obj)) + require.NoError(t, c.Bind(&obj)) assert.Equal(t, "foo", obj.Bar) assert.Equal(t, "bar", obj.Foo) assert.Empty(t, c.Errors) @@ -1641,14 +1887,14 @@ func TestContextBindWithJSON(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "/", strings.NewReader(`{"foo":"bar", "bar":"foo"}`)) + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader(`{"foo":"bar", "bar":"foo"}`)) c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type var obj struct { Foo string `json:"foo"` Bar string `json:"bar"` } - assert.NoError(t, c.BindJSON(&obj)) + require.NoError(t, c.BindJSON(&obj)) assert.Equal(t, "foo", obj.Bar) assert.Equal(t, "bar", obj.Foo) assert.Equal(t, 0, w.Body.Len()) @@ -1658,7 +1904,7 @@ func TestContextBindWithXML(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "/", strings.NewReader(` + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader(` FOO BAR @@ -1669,17 +1915,41 @@ func TestContextBindWithXML(t *testing.T) { Foo string `xml:"foo"` Bar string `xml:"bar"` } - assert.NoError(t, c.BindXML(&obj)) + require.NoError(t, c.BindXML(&obj)) assert.Equal(t, "FOO", obj.Foo) assert.Equal(t, "BAR", obj.Bar) assert.Equal(t, 0, w.Body.Len()) } +func TestContextBindPlain(t *testing.T) { + // string + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + c.Request, _ = http.NewRequest(http.MethodPost, "/", bytes.NewBufferString(`test string`)) + c.Request.Header.Add("Content-Type", MIMEPlain) + + var s string + + require.NoError(t, c.BindPlain(&s)) + assert.Equal(t, "test string", s) + assert.Equal(t, 0, w.Body.Len()) + + // []byte + c.Request, _ = http.NewRequest(http.MethodPost, "/", bytes.NewBufferString(`test []byte`)) + c.Request.Header.Add("Content-Type", MIMEPlain) + + var bs []byte + + require.NoError(t, c.BindPlain(&bs)) + assert.Equal(t, []byte("test []byte"), bs) + assert.Equal(t, 0, w.Body.Len()) +} + func TestContextBindHeader(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "/", nil) + c.Request, _ = http.NewRequest(http.MethodPost, "/", nil) c.Request.Header.Add("rate", "8000") c.Request.Header.Add("domain", "music") c.Request.Header.Add("limit", "1000") @@ -1690,7 +1960,7 @@ func TestContextBindHeader(t *testing.T) { Limit int `header:"limit"` } - assert.NoError(t, c.BindHeader(&testHeader)) + require.NoError(t, c.BindHeader(&testHeader)) assert.Equal(t, 8000, testHeader.Rate) assert.Equal(t, "music", testHeader.Domain) assert.Equal(t, 1000, testHeader.Limit) @@ -1701,13 +1971,13 @@ func TestContextBindWithQuery(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "/?foo=bar&bar=foo", strings.NewReader("foo=unused")) + c.Request, _ = http.NewRequest(http.MethodPost, "/?foo=bar&bar=foo", strings.NewReader("foo=unused")) var obj struct { Foo string `form:"foo"` Bar string `form:"bar"` } - assert.NoError(t, c.BindQuery(&obj)) + require.NoError(t, c.BindQuery(&obj)) assert.Equal(t, "foo", obj.Bar) assert.Equal(t, "bar", obj.Foo) assert.Equal(t, 0, w.Body.Len()) @@ -1717,14 +1987,14 @@ func TestContextBindWithYAML(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "/", strings.NewReader("foo: bar\nbar: foo")) + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader("foo: bar\nbar: foo")) c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type var obj struct { Foo string `yaml:"foo"` Bar string `yaml:"bar"` } - assert.NoError(t, c.BindYAML(&obj)) + require.NoError(t, c.BindYAML(&obj)) assert.Equal(t, "foo", obj.Bar) assert.Equal(t, "bar", obj.Foo) assert.Equal(t, 0, w.Body.Len()) @@ -1734,14 +2004,14 @@ func TestContextBindWithTOML(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "/", strings.NewReader("foo = 'bar'\nbar = 'foo'")) + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader("foo = 'bar'\nbar = 'foo'")) c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type var obj struct { Foo string `toml:"foo"` Bar string `toml:"bar"` } - assert.NoError(t, c.BindTOML(&obj)) + require.NoError(t, c.BindTOML(&obj)) assert.Equal(t, "foo", obj.Bar) assert.Equal(t, "bar", obj.Foo) assert.Equal(t, 0, w.Body.Len()) @@ -1751,7 +2021,7 @@ func TestContextBadAutoBind(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "http://example.com", strings.NewReader("\"foo\":\"bar\", \"bar\":\"foo\"}")) + c.Request, _ = http.NewRequest(http.MethodPost, "http://example.com", strings.NewReader("\"foo\":\"bar\", \"bar\":\"foo\"}")) c.Request.Header.Add("Content-Type", MIMEJSON) var obj struct { Foo string `json:"foo"` @@ -1759,7 +2029,7 @@ func TestContextBadAutoBind(t *testing.T) { } assert.False(t, c.IsAborted()) - assert.Error(t, c.Bind(&obj)) + require.Error(t, c.Bind(&obj)) c.Writer.WriteHeaderNow() assert.Empty(t, obj.Bar) @@ -1770,14 +2040,14 @@ func TestContextBadAutoBind(t *testing.T) { func TestContextAutoShouldBindJSON(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", strings.NewReader(`{"foo":"bar", "bar":"foo"}`)) + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader(`{"foo":"bar", "bar":"foo"}`)) c.Request.Header.Add("Content-Type", MIMEJSON) var obj struct { Foo string `json:"foo"` Bar string `json:"bar"` } - assert.NoError(t, c.ShouldBind(&obj)) + require.NoError(t, c.ShouldBind(&obj)) assert.Equal(t, "foo", obj.Bar) assert.Equal(t, "bar", obj.Foo) assert.Empty(t, c.Errors) @@ -1787,14 +2057,14 @@ func TestContextShouldBindWithJSON(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "/", strings.NewReader(`{"foo":"bar", "bar":"foo"}`)) + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader(`{"foo":"bar", "bar":"foo"}`)) c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type var obj struct { Foo string `json:"foo"` Bar string `json:"bar"` } - assert.NoError(t, c.ShouldBindJSON(&obj)) + require.NoError(t, c.ShouldBindJSON(&obj)) assert.Equal(t, "foo", obj.Bar) assert.Equal(t, "bar", obj.Foo) assert.Equal(t, 0, w.Body.Len()) @@ -1804,7 +2074,7 @@ func TestContextShouldBindWithXML(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "/", strings.NewReader(` + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader(` FOO BAR @@ -1815,17 +2085,41 @@ func TestContextShouldBindWithXML(t *testing.T) { Foo string `xml:"foo"` Bar string `xml:"bar"` } - assert.NoError(t, c.ShouldBindXML(&obj)) + require.NoError(t, c.ShouldBindXML(&obj)) assert.Equal(t, "FOO", obj.Foo) assert.Equal(t, "BAR", obj.Bar) assert.Equal(t, 0, w.Body.Len()) } +func TestContextShouldBindPlain(t *testing.T) { + // string + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + c.Request, _ = http.NewRequest(http.MethodPost, "/", bytes.NewBufferString(`test string`)) + c.Request.Header.Add("Content-Type", MIMEPlain) + + var s string + + require.NoError(t, c.ShouldBindPlain(&s)) + assert.Equal(t, "test string", s) + assert.Equal(t, 0, w.Body.Len()) + // []byte + + c.Request, _ = http.NewRequest(http.MethodPost, "/", bytes.NewBufferString(`test []byte`)) + c.Request.Header.Add("Content-Type", MIMEPlain) + + var bs []byte + + require.NoError(t, c.ShouldBindPlain(&bs)) + assert.Equal(t, []byte("test []byte"), bs) + assert.Equal(t, 0, w.Body.Len()) +} + func TestContextShouldBindHeader(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "/", nil) + c.Request, _ = http.NewRequest(http.MethodPost, "/", nil) c.Request.Header.Add("rate", "8000") c.Request.Header.Add("domain", "music") c.Request.Header.Add("limit", "1000") @@ -1836,7 +2130,7 @@ func TestContextShouldBindHeader(t *testing.T) { Limit int `header:"limit"` } - assert.NoError(t, c.ShouldBindHeader(&testHeader)) + require.NoError(t, c.ShouldBindHeader(&testHeader)) assert.Equal(t, 8000, testHeader.Rate) assert.Equal(t, "music", testHeader.Domain) assert.Equal(t, 1000, testHeader.Limit) @@ -1847,7 +2141,7 @@ func TestContextShouldBindWithQuery(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "/?foo=bar&bar=foo&Foo=bar1&Bar=foo1", strings.NewReader("foo=unused")) + c.Request, _ = http.NewRequest(http.MethodPost, "/?foo=bar&bar=foo&Foo=bar1&Bar=foo1", strings.NewReader("foo=unused")) var obj struct { Foo string `form:"foo"` @@ -1855,7 +2149,7 @@ func TestContextShouldBindWithQuery(t *testing.T) { Foo1 string `form:"Foo"` Bar1 string `form:"Bar"` } - assert.NoError(t, c.ShouldBindQuery(&obj)) + require.NoError(t, c.ShouldBindQuery(&obj)) assert.Equal(t, "foo", obj.Bar) assert.Equal(t, "bar", obj.Foo) assert.Equal(t, "foo1", obj.Bar1) @@ -1867,14 +2161,14 @@ func TestContextShouldBindWithYAML(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "/", strings.NewReader("foo: bar\nbar: foo")) + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader("foo: bar\nbar: foo")) c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type var obj struct { Foo string `yaml:"foo"` Bar string `yaml:"bar"` } - assert.NoError(t, c.ShouldBindYAML(&obj)) + require.NoError(t, c.ShouldBindYAML(&obj)) assert.Equal(t, "foo", obj.Bar) assert.Equal(t, "bar", obj.Foo) assert.Equal(t, 0, w.Body.Len()) @@ -1884,14 +2178,14 @@ func TestContextShouldBindWithTOML(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "/", strings.NewReader("foo='bar'\nbar= 'foo'")) + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader("foo='bar'\nbar= 'foo'")) c.Request.Header.Add("Content-Type", MIMETOML) // set fake content-type var obj struct { Foo string `toml:"foo"` Bar string `toml:"bar"` } - assert.NoError(t, c.ShouldBindTOML(&obj)) + require.NoError(t, c.ShouldBindTOML(&obj)) assert.Equal(t, "foo", obj.Bar) assert.Equal(t, "bar", obj.Foo) assert.Equal(t, 0, w.Body.Len()) @@ -1901,7 +2195,7 @@ func TestContextBadAutoShouldBind(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "http://example.com", strings.NewReader("\"foo\":\"bar\", \"bar\":\"foo\"}")) + c.Request, _ = http.NewRequest(http.MethodPost, "http://example.com", strings.NewReader(`"foo":"bar", "bar":"foo"}`)) c.Request.Header.Add("Content-Type", MIMEJSON) var obj struct { Foo string `json:"foo"` @@ -1909,7 +2203,7 @@ func TestContextBadAutoShouldBind(t *testing.T) { } assert.False(t, c.IsAborted()) - assert.Error(t, c.ShouldBind(&obj)) + require.Error(t, c.ShouldBind(&obj)) assert.Empty(t, obj.Bar) assert.Empty(t, obj.Foo) @@ -1965,15 +2259,15 @@ func TestContextShouldBindBodyWith(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) c.Request, _ = http.NewRequest( - "POST", "http://example.com", strings.NewReader(tt.bodyA), + http.MethodPost, "http://example.com", strings.NewReader(tt.bodyA), ) // When it binds to typeA and typeB, it finds the body is // not typeB but typeA. objA := typeA{} - assert.NoError(t, c.ShouldBindBodyWith(&objA, tt.bindingA)) + require.NoError(t, c.ShouldBindBodyWith(&objA, tt.bindingA)) assert.Equal(t, typeA{"FOO"}, objA) objB := typeB{} - assert.Error(t, c.ShouldBindBodyWith(&objB, tt.bindingB)) + require.Error(t, c.ShouldBindBodyWith(&objB, tt.bindingB)) assert.NotEqual(t, typeB{"BAR"}, objB) } // bodyB to typeA and typeB @@ -1983,27 +2277,359 @@ func TestContextShouldBindBodyWith(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) c.Request, _ = http.NewRequest( - "POST", "http://example.com", strings.NewReader(tt.bodyB), + http.MethodPost, "http://example.com", strings.NewReader(tt.bodyB), ) objA := typeA{} - assert.Error(t, c.ShouldBindBodyWith(&objA, tt.bindingA)) + require.Error(t, c.ShouldBindBodyWith(&objA, tt.bindingA)) assert.NotEqual(t, typeA{"FOO"}, objA) objB := typeB{} - assert.NoError(t, c.ShouldBindBodyWith(&objB, tt.bindingB)) + require.NoError(t, c.ShouldBindBodyWith(&objB, tt.bindingB)) assert.Equal(t, typeB{"BAR"}, objB) } } } +func TestContextShouldBindBodyWithJSON(t *testing.T) { + for _, tt := range []struct { + name string + bindingBody binding.BindingBody + body string + }{ + { + name: " JSON & JSON-BODY ", + bindingBody: binding.JSON, + body: `{"foo":"FOO"}`, + }, + { + name: " JSON & XML-BODY ", + bindingBody: binding.XML, + body: ` + +FOO +`, + }, + { + name: " JSON & YAML-BODY ", + bindingBody: binding.YAML, + body: `foo: FOO`, + }, + { + name: " JSON & TOM-BODY ", + bindingBody: binding.TOML, + body: `foo=FOO`, + }, + } { + t.Logf("testing: %s", tt.name) + + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + + c.Request, _ = http.NewRequest(http.MethodPost, "/", bytes.NewBufferString(tt.body)) + + type typeJSON struct { + Foo string `json:"foo" binding:"required"` + } + objJSON := typeJSON{} + + if tt.bindingBody == binding.JSON { + require.NoError(t, c.ShouldBindBodyWithJSON(&objJSON)) + assert.Equal(t, typeJSON{"FOO"}, objJSON) + } + + if tt.bindingBody == binding.XML { + require.Error(t, c.ShouldBindBodyWithJSON(&objJSON)) + assert.Equal(t, typeJSON{}, objJSON) + } + + if tt.bindingBody == binding.YAML { + require.Error(t, c.ShouldBindBodyWithJSON(&objJSON)) + assert.Equal(t, typeJSON{}, objJSON) + } + + if tt.bindingBody == binding.TOML { + require.Error(t, c.ShouldBindBodyWithJSON(&objJSON)) + assert.Equal(t, typeJSON{}, objJSON) + } + } +} + +func TestContextShouldBindBodyWithXML(t *testing.T) { + for _, tt := range []struct { + name string + bindingBody binding.BindingBody + body string + }{ + { + name: " XML & JSON-BODY ", + bindingBody: binding.JSON, + body: `{"foo":"FOO"}`, + }, + { + name: " XML & XML-BODY ", + bindingBody: binding.XML, + body: ` + +FOO +`, + }, + { + name: " XML & YAML-BODY ", + bindingBody: binding.YAML, + body: `foo: FOO`, + }, + { + name: " XML & TOM-BODY ", + bindingBody: binding.TOML, + body: `foo=FOO`, + }, + } { + t.Logf("testing: %s", tt.name) + + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + + c.Request, _ = http.NewRequest(http.MethodPost, "/", bytes.NewBufferString(tt.body)) + + type typeXML struct { + Foo string `xml:"foo" binding:"required"` + } + objXML := typeXML{} + + if tt.bindingBody == binding.JSON { + require.Error(t, c.ShouldBindBodyWithXML(&objXML)) + assert.Equal(t, typeXML{}, objXML) + } + + if tt.bindingBody == binding.XML { + require.NoError(t, c.ShouldBindBodyWithXML(&objXML)) + assert.Equal(t, typeXML{"FOO"}, objXML) + } + + if tt.bindingBody == binding.YAML { + require.Error(t, c.ShouldBindBodyWithXML(&objXML)) + assert.Equal(t, typeXML{}, objXML) + } + + if tt.bindingBody == binding.TOML { + require.Error(t, c.ShouldBindBodyWithXML(&objXML)) + assert.Equal(t, typeXML{}, objXML) + } + } +} + +func TestContextShouldBindBodyWithYAML(t *testing.T) { + for _, tt := range []struct { + name string + bindingBody binding.BindingBody + body string + }{ + { + name: " YAML & JSON-BODY ", + bindingBody: binding.JSON, + body: `{"foo":"FOO"}`, + }, + { + name: " YAML & XML-BODY ", + bindingBody: binding.XML, + body: ` + +FOO +`, + }, + { + name: " YAML & YAML-BODY ", + bindingBody: binding.YAML, + body: `foo: FOO`, + }, + { + name: " YAML & TOM-BODY ", + bindingBody: binding.TOML, + body: `foo=FOO`, + }, + } { + t.Logf("testing: %s", tt.name) + + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + + c.Request, _ = http.NewRequest(http.MethodPost, "/", bytes.NewBufferString(tt.body)) + + type typeYAML struct { + Foo string `yaml:"foo" binding:"required"` + } + objYAML := typeYAML{} + + // YAML belongs to a super collection of JSON, so JSON can be parsed by YAML + if tt.bindingBody == binding.JSON { + require.NoError(t, c.ShouldBindBodyWithYAML(&objYAML)) + assert.Equal(t, typeYAML{"FOO"}, objYAML) + } + + if tt.bindingBody == binding.XML { + require.Error(t, c.ShouldBindBodyWithYAML(&objYAML)) + assert.Equal(t, typeYAML{}, objYAML) + } + + if tt.bindingBody == binding.YAML { + require.NoError(t, c.ShouldBindBodyWithYAML(&objYAML)) + assert.Equal(t, typeYAML{"FOO"}, objYAML) + } + + if tt.bindingBody == binding.TOML { + require.Error(t, c.ShouldBindBodyWithYAML(&objYAML)) + assert.Equal(t, typeYAML{}, objYAML) + } + } +} + +func TestContextShouldBindBodyWithTOML(t *testing.T) { + for _, tt := range []struct { + name string + bindingBody binding.BindingBody + body string + }{ + { + name: " TOML & JSON-BODY ", + bindingBody: binding.JSON, + body: `{"foo":"FOO"}`, + }, + { + name: " TOML & XML-BODY ", + bindingBody: binding.XML, + body: ` + +FOO +`, + }, + { + name: " TOML & YAML-BODY ", + bindingBody: binding.YAML, + body: `foo: FOO`, + }, + { + name: " TOML & TOM-BODY ", + bindingBody: binding.TOML, + body: `foo = 'FOO'`, + }, + } { + t.Logf("testing: %s", tt.name) + + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + + c.Request, _ = http.NewRequest(http.MethodPost, "/", bytes.NewBufferString(tt.body)) + + type typeTOML struct { + Foo string `toml:"foo" binding:"required"` + } + objTOML := typeTOML{} + + if tt.bindingBody == binding.JSON { + require.Error(t, c.ShouldBindBodyWithTOML(&objTOML)) + assert.Equal(t, typeTOML{}, objTOML) + } + + if tt.bindingBody == binding.XML { + require.Error(t, c.ShouldBindBodyWithTOML(&objTOML)) + assert.Equal(t, typeTOML{}, objTOML) + } + + if tt.bindingBody == binding.YAML { + require.Error(t, c.ShouldBindBodyWithTOML(&objTOML)) + assert.Equal(t, typeTOML{}, objTOML) + } + + if tt.bindingBody == binding.TOML { + require.NoError(t, c.ShouldBindBodyWithTOML(&objTOML)) + assert.Equal(t, typeTOML{"FOO"}, objTOML) + } + } +} + +func TestContextShouldBindBodyWithPlain(t *testing.T) { + for _, tt := range []struct { + name string + bindingBody binding.BindingBody + body string + }{ + { + name: " JSON & JSON-BODY ", + bindingBody: binding.JSON, + body: `{"foo":"FOO"}`, + }, + { + name: " JSON & XML-BODY ", + bindingBody: binding.XML, + body: ` + +FOO +`, + }, + { + name: " JSON & YAML-BODY ", + bindingBody: binding.YAML, + body: `foo: FOO`, + }, + { + name: " JSON & TOM-BODY ", + bindingBody: binding.TOML, + body: `foo=FOO`, + }, + { + name: " JSON & Plain-BODY ", + bindingBody: binding.Plain, + body: `foo=FOO`, + }, + } { + t.Logf("testing: %s", tt.name) + + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + + c.Request, _ = http.NewRequest(http.MethodPost, "/", bytes.NewBufferString(tt.body)) + + type typeJSON struct { + Foo string `json:"foo" binding:"required"` + } + objJSON := typeJSON{} + + if tt.bindingBody == binding.Plain { + body := "" + require.NoError(t, c.ShouldBindBodyWithPlain(&body)) + assert.Equal(t, "foo=FOO", body) + } + + if tt.bindingBody == binding.JSON { + require.NoError(t, c.ShouldBindBodyWithJSON(&objJSON)) + assert.Equal(t, typeJSON{"FOO"}, objJSON) + } + + if tt.bindingBody == binding.XML { + require.Error(t, c.ShouldBindBodyWithJSON(&objJSON)) + assert.Equal(t, typeJSON{}, objJSON) + } + + if tt.bindingBody == binding.YAML { + require.Error(t, c.ShouldBindBodyWithJSON(&objJSON)) + assert.Equal(t, typeJSON{}, objJSON) + } + + if tt.bindingBody == binding.TOML { + require.Error(t, c.ShouldBindBodyWithJSON(&objJSON)) + assert.Equal(t, typeJSON{}, objJSON) + } + } +} + func TestContextGolangContext(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", strings.NewReader(`{"foo":"bar", "bar":"foo"}`)) - assert.NoError(t, c.Err()) + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader(`{"foo":"bar", "bar":"foo"}`)) + require.NoError(t, c.Err()) assert.Nil(t, c.Done()) ti, ok := c.Deadline() - assert.Equal(t, ti, time.Time{}) + assert.Equal(t, time.Time{}, ti) assert.False(t, ok) - assert.Equal(t, c.Value(0), c.Request) + assert.Equal(t, c.Value(ContextRequestKey), c.Request) assert.Equal(t, c.Value(ContextKey), c) assert.Nil(t, c.Value("foo")) @@ -2015,7 +2641,7 @@ func TestContextGolangContext(t *testing.T) { func TestWebsocketsRequired(t *testing.T) { // Example request from spec: https://tools.ietf.org/html/rfc6455#section-1.2 c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("GET", "/chat", nil) + c.Request, _ = http.NewRequest(http.MethodGet, "/chat", nil) c.Request.Header.Set("Host", "server.example.com") c.Request.Header.Set("Upgrade", "websocket") c.Request.Header.Set("Connection", "Upgrade") @@ -2028,7 +2654,7 @@ func TestWebsocketsRequired(t *testing.T) { // Normal request, no websocket required. c, _ = CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("GET", "/chat", nil) + c.Request, _ = http.NewRequest(http.MethodGet, "/chat", nil) c.Request.Header.Set("Host", "server.example.com") assert.False(t, c.IsWebsocket()) @@ -2036,7 +2662,7 @@ func TestWebsocketsRequired(t *testing.T) { func TestGetRequestHeaderValue(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("GET", "/chat", nil) + c.Request, _ = http.NewRequest(http.MethodGet, "/chat", nil) c.Request.Header.Set("Gin-Version", "1.0.0") assert.Equal(t, "1.0.0", c.GetHeader("Gin-Version")) @@ -2046,11 +2672,11 @@ func TestGetRequestHeaderValue(t *testing.T) { func TestContextGetRawData(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) body := strings.NewReader("Fetch binary post data") - c.Request, _ = http.NewRequest("POST", "/", body) + c.Request, _ = http.NewRequest(http.MethodPost, "/", body) c.Request.Header.Add("Content-Type", MIMEPOSTForm) data, err := c.GetRawData() - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, "Fetch binary post data", string(data)) } @@ -2069,7 +2695,7 @@ func TestContextRenderDataFromReader(t *testing.T) { assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, body, w.Body.String()) assert.Equal(t, contentType, w.Header().Get("Content-Type")) - assert.Equal(t, fmt.Sprintf("%d", contentLength), w.Header().Get("Content-Length")) + assert.Equal(t, strconv.FormatInt(contentLength, 10), w.Header().Get("Content-Length")) assert.Equal(t, extraHeaders["Content-Disposition"], w.Header().Get("Content-Disposition")) } @@ -2087,7 +2713,7 @@ func TestContextRenderDataFromReaderNoHeaders(t *testing.T) { assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, body, w.Body.String()) assert.Equal(t, contentType, w.Header().Get("Content-Type")) - assert.Equal(t, fmt.Sprintf("%d", contentLength), w.Header().Get("Content-Length")) + assert.Equal(t, strconv.FormatInt(contentLength, 10), w.Header().Get("Content-Length")) } type TestResponseRecorder struct { @@ -2121,7 +2747,7 @@ func TestContextStream(t *testing.T) { }() _, err := w.Write([]byte("test")) - assert.NoError(t, err) + require.NoError(t, err) return stopStream }) @@ -2139,7 +2765,7 @@ func TestContextStreamWithClientGone(t *testing.T) { }() _, err := writer.Write([]byte("test")) - assert.NoError(t, err) + require.NoError(t, err) return true }) @@ -2175,8 +2801,8 @@ func TestRaceParamsContextCopy(t *testing.T) { }(c.Copy(), c.Param("name")) }) } - PerformRequest(router, "GET", "/name1/api") - PerformRequest(router, "GET", "/name2/api") + PerformRequest(router, http.MethodGet, "/name1/api") + PerformRequest(router, http.MethodGet, "/name2/api") wg.Wait() } @@ -2195,7 +2821,7 @@ func TestContextWithKeysMutex(t *testing.T) { func TestRemoteIPFail(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", nil) + c.Request, _ = http.NewRequest(http.MethodPost, "/", nil) c.Request.RemoteAddr = "[:::]:80" ip := net.ParseIP(c.RemoteIP()) trust := c.engine.isTrustedProxy(ip) @@ -2267,7 +2893,7 @@ func TestContextWithFallbackErrFromRequestContext(t *testing.T) { // enable ContextWithFallback feature flag c.engine.ContextWithFallback = true - assert.Nil(t, c.Err()) + require.NoError(t, c.Err()) c2, _ := CreateTestContext(httptest.NewRecorder()) // enable ContextWithFallback feature flag @@ -2292,11 +2918,12 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) { { name: "c with struct context key", getContextAndKey: func() (*Context, any) { - var key struct{} + type KeyStruct struct{} // https://staticcheck.dev/docs/checks/#SA1029 + var key KeyStruct c, _ := CreateTestContext(httptest.NewRecorder()) // enable ContextWithFallback feature flag c.engine.ContextWithFallback = true - c.Request, _ = http.NewRequest("POST", "/", nil) + c.Request, _ = http.NewRequest(http.MethodPost, "/", nil) c.Request = c.Request.WithContext(context.WithValue(context.TODO(), key, "value")) return c, key }, @@ -2308,7 +2935,7 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) // enable ContextWithFallback feature flag c.engine.ContextWithFallback = true - c.Request, _ = http.NewRequest("POST", "/", nil) + c.Request, _ = http.NewRequest(http.MethodPost, "/", nil) c.Request = c.Request.WithContext(context.WithValue(context.TODO(), contextKey("key"), "value")) return c, contextKey("key") }, @@ -2331,7 +2958,7 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) // enable ContextWithFallback feature flag c.engine.ContextWithFallback = true - c.Request, _ = http.NewRequest("POST", "/", nil) + c.Request, _ = http.NewRequest(http.MethodPost, "/", nil) return c, "key" }, value: nil, @@ -2416,7 +3043,7 @@ func TestContextAddParam(t *testing.T) { c.AddParam(id, value) v, ok := c.Params.Get(id) - assert.Equal(t, ok, true) + assert.True(t, ok) assert.Equal(t, value, v) } @@ -2463,7 +3090,7 @@ func TestInterceptedHeader(t *testing.T) { c.Header("X-Test-2", "present") c.String(http.StatusOK, "hello world") }) - c.Request = httptest.NewRequest("GET", "/", nil) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) r.HandleContext(c) // Result() has headers frozen when WriteHeaderNow() has been called // Compared to this time, this is when the response headers will be flushed @@ -2472,3 +3099,47 @@ func TestInterceptedHeader(t *testing.T) { assert.Equal(t, "", w.Result().Header.Get("X-Test")) assert.Equal(t, "present", w.Result().Header.Get("X-Test-2")) } + +func TestContextNext(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + + // Test with no handlers + c.Next() + assert.Equal(t, int8(0), c.index) + + // Test with one handler + c.index = -1 + c.handlers = HandlersChain{func(c *Context) { + c.Set("key", "value") + }} + c.Next() + assert.Equal(t, int8(1), c.index) + value, exists := c.Get("key") + assert.True(t, exists) + assert.Equal(t, "value", value) + + // Test with multiple handlers + c.handlers = HandlersChain{ + func(c *Context) { + c.Set("key1", "value1") + c.Next() + c.Set("key2", "value2") + }, + nil, + func(c *Context) { + c.Set("key3", "value3") + }, + } + c.index = -1 + c.Next() + assert.Equal(t, int8(4), c.index) + value, exists = c.Get("key1") + assert.True(t, exists) + assert.Equal(t, "value1", value) + value, exists = c.Get("key2") + assert.True(t, exists) + assert.Equal(t, "value2", value) + value, exists = c.Get("key3") + assert.True(t, exists) + assert.Equal(t, "value3", value) +} diff --git a/debug.go b/debug.go index 1fc0cafe..f2016168 100644 --- a/debug.go +++ b/debug.go @@ -10,19 +10,23 @@ import ( "runtime" "strconv" "strings" + "sync/atomic" ) -const ginSupportMinGoVer = 18 +const ginSupportMinGoVer = 21 // IsDebugging returns true if the framework is running in debug mode. // Use SetMode(gin.ReleaseMode) to disable debug mode. func IsDebugging() bool { - return ginMode == debugCode + return atomic.LoadInt32(&ginMode) == debugCode } // DebugPrintRouteFunc indicates debug log output format. var DebugPrintRouteFunc func(httpMethod, absolutePath, handlerName string, nuHandlers int) +// DebugPrintFunc indicates debug log output format. +var DebugPrintFunc func(format string, values ...interface{}) + func debugPrintRoute(httpMethod, absolutePath string, handlers HandlersChain) { if IsDebugging() { nuHandlers := len(handlers) @@ -48,12 +52,19 @@ func debugPrintLoadTemplate(tmpl *template.Template) { } func debugPrint(format string, values ...any) { - if IsDebugging() { - if !strings.HasSuffix(format, "\n") { - format += "\n" - } - fmt.Fprintf(DefaultWriter, "[GIN-debug] "+format, values...) + if !IsDebugging() { + return } + + if DebugPrintFunc != nil { + DebugPrintFunc(format, values...) + return + } + + if !strings.HasSuffix(format, "\n") { + format += "\n" + } + fmt.Fprintf(DefaultWriter, "[GIN-debug] "+format, values...) } func getMinVer(v string) (uint64, error) { @@ -67,7 +78,7 @@ func getMinVer(v string) (uint64, error) { func debugPrintWARNINGDefault() { if v, e := getMinVer(runtime.Version()); e == nil && v < ginSupportMinGoVer { - debugPrint(`[WARNING] Now Gin requires Go 1.18+. + debugPrint(`[WARNING] Now Gin requires Go 1.23+. `) } diff --git a/debug_test.go b/debug_test.go index 2d5e9a56..59b61beb 100644 --- a/debug_test.go +++ b/debug_test.go @@ -10,6 +10,7 @@ import ( "html/template" "io" "log" + "net/http" "os" "runtime" "strings" @@ -17,6 +18,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // TODO @@ -59,7 +61,7 @@ func TestDebugPrintError(t *testing.T) { func TestDebugPrintRoutes(t *testing.T) { re := captureOutput(t, func() { SetMode(DebugMode) - debugPrintRoute("GET", "/path/to/route/:param", HandlersChain{func(c *Context) {}, handlerNameTest}) + debugPrintRoute(http.MethodGet, "/path/to/route/:param", HandlersChain{func(c *Context) {}, handlerNameTest}) SetMode(TestMode) }) assert.Regexp(t, `^\[GIN-debug\] GET /path/to/route/:param --> (.*/vendor/)?github.com/gin-gonic/gin.handlerNameTest \(2 handlers\)\n$`, re) @@ -71,7 +73,7 @@ func TestDebugPrintRouteFunc(t *testing.T) { } re := captureOutput(t, func() { SetMode(DebugMode) - debugPrintRoute("GET", "/path/to/route/:param1/:param2", HandlersChain{func(c *Context) {}, handlerNameTest}) + debugPrintRoute(http.MethodGet, "/path/to/route/:param1/:param2", HandlersChain{func(c *Context) {}, handlerNameTest}) SetMode(TestMode) }) assert.Regexp(t, `^\[GIN-debug\] GET /path/to/route/:param1/:param2 --> (.*/vendor/)?github.com/gin-gonic/gin.handlerNameTest \(2 handlers\)\n$`, re) @@ -104,7 +106,7 @@ func TestDebugPrintWARNINGDefault(t *testing.T) { }) m, e := getMinVer(runtime.Version()) if e == nil && m < ginSupportMinGoVer { - assert.Equal(t, "[GIN-debug] [WARNING] Now Gin requires Go 1.18+.\n\n[GIN-debug] [WARNING] Creating an Engine instance with the Logger and Recovery middleware already attached.\n\n", re) + assert.Equal(t, "[GIN-debug] [WARNING] Now Gin requires Go 1.23+.\n\n[GIN-debug] [WARNING] Creating an Engine instance with the Logger and Recovery middleware already attached.\n\n", re) } else { assert.Equal(t, "[GIN-debug] [WARNING] Creating an Engine instance with the Logger and Recovery middleware already attached.\n\n", re) } @@ -154,13 +156,13 @@ func TestGetMinVer(t *testing.T) { var m uint64 var e error _, e = getMinVer("go1") - assert.NotNil(t, e) + require.Error(t, e) m, e = getMinVer("go1.1") assert.Equal(t, uint64(1), m) - assert.Nil(t, e) + require.NoError(t, e) m, e = getMinVer("go1.1.1") - assert.Nil(t, e) + require.NoError(t, e) assert.Equal(t, uint64(1), m) _, e = getMinVer("go1.1.1.1") - assert.NotNil(t, e) + require.Error(t, e) } diff --git a/deprecated.go b/deprecated.go index 9521308f..b4c6cd88 100644 --- a/deprecated.go +++ b/deprecated.go @@ -12,6 +12,8 @@ import ( // BindWith binds the passed struct pointer using the specified binding engine. // See the binding package. +// +// Deprecated: Use MustBindWith or ShouldBindWith. func (c *Context) BindWith(obj any, b binding.Binding) error { log.Println(`BindWith(\"any, binding.Binding\") error is going to be deprecated, please check issue #662 and either use MustBindWith() if you diff --git a/deprecated_test.go b/deprecated_test.go index 0240b2ec..6c8f2a7f 100644 --- a/deprecated_test.go +++ b/deprecated_test.go @@ -18,7 +18,7 @@ func TestBindWith(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "/?foo=bar&bar=foo", bytes.NewBufferString("foo=unused")) + c.Request, _ = http.NewRequest(http.MethodPost, "/?foo=bar&bar=foo", bytes.NewBufferString("foo=unused")) var obj struct { Foo string `form:"foo"` diff --git a/docs/doc.md b/docs/doc.md index df006e87..ce466652 100644 --- a/docs/doc.md +++ b/docs/doc.md @@ -26,7 +26,10 @@ - [Custom Validators](#custom-validators) - [Only Bind Query String](#only-bind-query-string) - [Bind Query String or Post Data](#bind-query-string-or-post-data) + - [Bind default value if none provided](#bind-default-value-if-none-provided) + - [Collection format for arrays](#collection-format-for-arrays) - [Bind Uri](#bind-uri) + - [Bind custom unmarshaler](#bind-custom-unmarshaler) - [Bind Header](#bind-header) - [Bind HTML checkboxes](#bind-html-checkboxes) - [Multipart/Urlencoded binding](#multiparturlencoded-binding) @@ -67,7 +70,7 @@ ### Build with json replacement -Gin uses `encoding/json` as default json package but you can change it by build from other tags. +Gin uses `encoding/json` as the default JSON package but you can change it by building from other tags. [jsoniter](https://github.com/json-iterator/go) @@ -81,7 +84,7 @@ go build -tags=jsoniter . go build -tags=go_json . ``` -[sonic](https://github.com/bytedance/sonic) (you have to ensure that your cpu support avx instruction.) +[sonic](https://github.com/bytedance/sonic) (you have to ensure that your cpu supports avx instruction.) ```sh $ go build -tags="sonic avx" . @@ -117,7 +120,7 @@ func main() { router.HEAD("/someHead", head) router.OPTIONS("/someOptions", options) - // By default it serves on :8080 unless a + // By default, it serves on :8080 unless a // PORT environment variable was defined. router.Run() // router.Run(":3000") for a hard coded port @@ -169,7 +172,7 @@ func main() { router := gin.Default() // Query string parameters are parsed using the existing underlying request object. - // The request responds to an url matching: /welcome?firstname=Jane&lastname=Doe + // The request responds to a URL matching: /welcome?firstname=Jane&lastname=Doe router.GET("/welcome", func(c *gin.Context) { firstname := c.DefaultQuery("firstname", "Guest") lastname := c.Query("lastname") // shortcut for c.Request.URL.Query().Get("lastname") @@ -297,7 +300,7 @@ curl -X POST http://localhost:8080/upload \ #### Multiple files -See the detail [example code](https://github.com/gin-gonic/examples/tree/master/upload-file/multiple). +See the detailed [example code](https://github.com/gin-gonic/examples/tree/master/upload-file/multiple). ```go func main() { @@ -337,16 +340,16 @@ func main() { router := gin.Default() // Simple group: v1 - v1 := router.Group("/v1") { + v1 := router.Group("/v1") v1.POST("/login", loginEndpoint) v1.POST("/submit", submitEndpoint) v1.POST("/read", readEndpoint) } // Simple group: v2 - v2 := router.Group("/v2") { + v2 := router.Group("/v2") v2.POST("/login", loginEndpoint) v2.POST("/submit", submitEndpoint) v2.POST("/read", readEndpoint) @@ -513,19 +516,19 @@ Sample Output ```go func main() { router := gin.New() - + // skip logging for desired paths by setting SkipPaths in LoggerConfig loggerConfig := gin.LoggerConfig{SkipPaths: []string{"/metrics"}} - + // skip logging based on your logic by setting Skip func in LoggerConfig loggerConfig.Skip = func(c *gin.Context) bool { // as an example skip non server side errors return c.Writer.Status() < http.StatusInternalServerError } - - engine.Use(gin.LoggerWithConfig(loggerConfig)) + + router.Use(gin.LoggerWithConfig(loggerConfig)) router.Use(gin.Recovery()) - + // skipped router.GET("/metrics", func(c *gin.Context) { c.Status(http.StatusNotImplemented) @@ -540,7 +543,7 @@ func main() { router.GET("/data", func(c *gin.Context) { c.Status(http.StatusNotImplemented) }) - + router.Run(":8080") } @@ -612,7 +615,7 @@ You can also specify that specific fields are required. If a field is decorated ```go // Binding from JSON type Login struct { - User string `form:"user" json:"user" xml:"user" binding:"required"` + User string `form:"user" json:"user" xml:"user" binding:"required"` Password string `form:"password" json:"password" xml:"password" binding:"required"` } @@ -701,7 +704,7 @@ $ curl -v -X POST \ {"error":"Key: 'Login.Password' Error:Field validation for 'Password' failed on the 'required' tag"} ``` -Skip validate: when running the above example using the above the `curl` command, it returns error. Because the example use `binding:"required"` for `Password`. If use `binding:"-"` for `Password`, then it will not return error when running the above example again. +Skip-validation: Running the example above using the `curl` command returns an error. This is because the example uses `binding:"required"` for `Password`. If instead, you use `binding:"-"` for `Password`, then it will not return an error when you run the example again. ### Custom Validators @@ -829,6 +832,8 @@ type Person struct { Birthday time.Time `form:"birthday" time_format:"2006-01-02" time_utc:"1"` CreateTime time.Time `form:"createTime" time_format:"unixNano"` UnixTime time.Time `form:"unixTime" time_format:"unix"` + UnixMilliTime time.Time `form:"unixMilliTime" time_format:"unixmilli"` + UnixMicroTime time.Time `form:"unixMicroTime" time_format:"uNiXmIcRo"` // case does not matter for "unix*" time formats } func main() { @@ -848,6 +853,8 @@ func startPage(c *gin.Context) { log.Println(person.Birthday) log.Println(person.CreateTime) log.Println(person.UnixTime) + log.Println(person.UnixMilliTime) + log.Println(person.UnixMicroTime) } c.String(http.StatusOK, "Success") @@ -857,7 +864,107 @@ func startPage(c *gin.Context) { Test it with: ```sh -curl -X GET "localhost:8085/testing?name=appleboy&address=xyz&birthday=1992-03-15&createTime=1562400033000000123&unixTime=1562400033" +curl -X GET "localhost:8085/testing?name=appleboy&address=xyz&birthday=1992-03-15&createTime=1562400033000000123&unixTime=1562400033&unixMilliTime=1562400033001&unixMicroTime=1562400033000012" +``` + + +### Bind default value if none provided + +If the server should bind a default value to a field when the client does not provide one, specify the default value using the `default` key within the `form` tag: + +``` +package main + +import ( + "net/http" + + "github.com/gin-gonic/gin" +) + +type Person struct { + Name string `form:"name,default=William"` + Age int `form:"age,default=10"` + Friends []string `form:"friends,default=Will;Bill"` + Addresses [2]string `form:"addresses,default=foo bar" collection_format:"ssv"` + LapTimes []int `form:"lap_times,default=1;2;3" collection_format:"csv"` +} + +func main() { + g := gin.Default() + g.POST("/person", func(c *gin.Context) { + var req Person + if err := c.ShouldBindQuery(&req); err != nil { + c.JSON(http.StatusBadRequest, err) + return + } + c.JSON(http.StatusOK, req) + }) + _ = g.Run("localhost:8080") +} +``` + +``` +curl -X POST http://localhost:8080/person +{"Name":"William","Age":10,"Friends":["Will","Bill"],"Colors":["red","blue"],"LapTimes":[1,2,3]} +``` + +NOTE: For default [collection values](#collection-format-for-arrays), the following rules apply: +- Since commas are used to delimit tag options, they are not supported within a default value and will result in undefined behavior +- For the collection formats "multi" and "csv", a semicolon should be used in place of a comma to delimited default values +- Since semicolons are used to delimit default values for "multi" and "csv", they are not supported within a default value for "multi" and "csv" + + +#### Collection format for arrays + +| Format | Description | Example | +| --------------- | --------------------------------------------------------- | ----------------------- | +| multi (default) | Multiple parameter instances rather than multiple values. | key=foo&key=bar&key=baz | +| csv | Comma-separated values. | foo,bar,baz | +| ssv | Space-separated values. | foo bar baz | +| tsv | Tab-separated values. | "foo\tbar\tbaz" | +| pipes | Pipe-separated values. | foo\|bar\|baz | + +```go +package main + +import ( + "log" + "time" + "github.com/gin-gonic/gin" +) + +type Person struct { + Name string `form:"name"` + Addresses []string `form:"addresses" collection_format:"csv"` + Birthday time.Time `form:"birthday" time_format:"2006-01-02" time_utc:"1"` + CreateTime time.Time `form:"createTime" time_format:"unixNano"` + UnixTime time.Time `form:"unixTime" time_format:"unix"` +} + +func main() { + route := gin.Default() + route.GET("/testing", startPage) + route.Run(":8085") +} +func startPage(c *gin.Context) { + var person Person + // If `GET`, only `Form` binding engine (`query`) used. + // If `POST`, first checks the `content-type` for `JSON` or `XML`, then uses `Form` (`form-data`). + // See more at https://github.com/gin-gonic/gin/blob/master/binding/binding.go#L48 + if c.ShouldBind(&person) == nil { + log.Println(person.Name) + log.Println(person.Addresses) + log.Println(person.Birthday) + log.Println(person.CreateTime) + log.Println(person.UnixTime) + } + c.String(200, "Success") +} +``` + +Test it with: +```sh +$ curl -X GET "localhost:8085/testing?name=appleboy&addresses=foo,bar&birthday=1992-03-15&createTime=1562400033000000123&unixTime=1562400033" ``` ### Bind Uri @@ -899,6 +1006,46 @@ curl -v localhost:8088/thinkerou/987fbc97-4bed-5078-9f07-9141ba07c9f3 curl -v localhost:8088/thinkerou/not-uuid ``` +### Bind custom unmarshaler + +```go +package main + +import ( + "github.com/gin-gonic/gin" + "strings" +) + +type Birthday string + +func (b *Birthday) UnmarshalParam(param string) error { + *b = Birthday(strings.Replace(param, "-", "/", -1)) + return nil +} + +func main() { + route := gin.Default() + var request struct { + Birthday Birthday `form:"birthday"` + } + route.GET("/test", func(ctx *gin.Context) { + _ = ctx.BindQuery(&request) + ctx.JSON(200, request.Birthday) + }) + route.Run(":8088") +} +``` + +Test it with: + +```sh +curl 'localhost:8088/test?birthday=2000-01-01' +``` +Result +```sh +"2000/01/01" +``` + ### Bind Header ```go @@ -1040,7 +1187,7 @@ func main() { }) r.GET("/moreJSON", func(c *gin.Context) { - // You also can use a struct + // You can also use a struct var msg struct { Name string `json:"user"` Message string @@ -1109,7 +1256,7 @@ func main() { #### JSONP -Using JSONP to request data from a server in a different domain. Add callback to response body if the query parameter callback exists. +Using JSONP to request data from a server in a different domain. Add callback to response body if the query parameter callback exists. ```go func main() { @@ -1158,7 +1305,7 @@ func main() { #### PureJSON -Normally, JSON replaces special HTML characters with their unicode entities, e.g. `<` becomes `\u003c`. If you want to encode such characters literally, you can use PureJSON instead. +Normally, JSON replaces special HTML characters with their unicode entities, e.g. `<` becomes `\u003c`. If you want to encode such characters literally, you can use PureJSON instead. This feature is unavailable in Go 1.6 and lower. ```go @@ -1193,7 +1340,7 @@ func main() { router.StaticFS("/more_static", http.Dir("my_file_system")) router.StaticFile("/favicon.ico", "./resources/favicon.ico") router.StaticFileFS("/more_favicon.ico", "more_favicon.ico", http.Dir("my_file_system")) - + // Listen and serve on 0.0.0.0:8080 router.Run(":8080") } @@ -1246,13 +1393,19 @@ func main() { ### HTML rendering -Using LoadHTMLGlob() or LoadHTMLFiles() +Using LoadHTMLGlob() or LoadHTMLFiles() or LoadHTMLFS() ```go +//go:embed templates/* +var templates embed.FS + func main() { router := gin.Default() router.LoadHTMLGlob("templates/*") //router.LoadHTMLFiles("templates/template1.html", "templates/template2.html") + //router.LoadHTMLFS(http.Dir("templates"), "template1.html", "template2.html") + //or + //router.LoadHTMLFS(http.FS(templates), "templates/template1.html", "templates/template2.html") router.GET("/index", func(c *gin.Context) { c.HTML(http.StatusOK, "index.tmpl", gin.H{ "title": "Main website", @@ -1343,7 +1496,7 @@ You may use custom delims #### Custom Template Funcs -See the detail [example code](https://github.com/gin-gonic/examples/tree/master/template). +See the detailed [example code](https://github.com/gin-gonic/examples/tree/master/template). main.go @@ -1395,7 +1548,7 @@ Date: 2017/07/01 ### Multitemplate -Gin allow by default use only one html.Template. Check [a multitemplate render](https://github.com/gin-contrib/multitemplate) for using features like go 1.6 `block template`. +Gin allows only one html.Template by default. Check [a multitemplate render](https://github.com/gin-contrib/multitemplate) for using features like go 1.6 `block template`. ### Redirects @@ -1944,7 +2097,7 @@ type formB struct { func SomeHandler(c *gin.Context) { objA := formA{} objB := formB{} - // This c.ShouldBind consumes c.Request.Body and it cannot be reused. + // Calling c.ShouldBind consumes c.Request.Body and it cannot be reused. if errA := c.ShouldBind(&objA); errA == nil { c.String(http.StatusOK, `the body should be formA`) // Always an error is occurred by this because c.Request.Body is EOF now. @@ -1956,7 +2109,12 @@ func SomeHandler(c *gin.Context) { } ``` -For this, you can use `c.ShouldBindBodyWith`. +For this, you can use `c.ShouldBindBodyWith` or shortcuts. + +- `c.ShouldBindBodyWithJSON` is a shortcut for c.ShouldBindBodyWith(obj, binding.JSON). +- `c.ShouldBindBodyWithXML` is a shortcut for c.ShouldBindBodyWith(obj, binding.XML). +- `c.ShouldBindBodyWithYAML` is a shortcut for c.ShouldBindBodyWith(obj, binding.YAML). +- `c.ShouldBindBodyWithTOML` is a shortcut for c.ShouldBindBodyWith(obj, binding.TOML). ```go func SomeHandler(c *gin.Context) { @@ -1969,7 +2127,7 @@ func SomeHandler(c *gin.Context) { } else if errB := c.ShouldBindBodyWith(&objB, binding.JSON); errB == nil { c.String(http.StatusOK, `the body should be formB JSON`) // And it can accepts other formats - } else if errB2 := c.ShouldBindBodyWith(&objB, binding.XML); errB2 == nil { + } else if errB2 := c.ShouldBindBodyWithXML(&objB); errB2 == nil { c.String(http.StatusOK, `the body should be formB XML`) } else { ... @@ -2172,7 +2330,7 @@ or network CIDRs from where clients which their request headers related to clien IP can be trusted. They can be IPv4 addresses, IPv4 CIDRs, IPv6 addresses or IPv6 CIDRs. -**Attention:** Gin trust all proxies by default if you don't specify a trusted +**Attention:** Gin trusts all proxies by default if you don't specify a trusted proxy using the function above, **this is NOT safe**. At the same time, if you don't use any proxy, you can disable this feature by using `Engine.SetTrustedProxies(nil)`, then `Context.ClientIP()` will return the remote address directly to avoid some @@ -2201,7 +2359,7 @@ func main() { ``` **Notice:** If you are using a CDN service, you can set the `Engine.TrustedPlatform` -to skip TrustedProxies check, it has a higher priority than TrustedProxies. +to skip TrustedProxies check, it has a higher priority than TrustedProxies. Look at the example below: ```go diff --git a/errors_test.go b/errors_test.go index f77a6342..72a36992 100644 --- a/errors_test.go +++ b/errors_test.go @@ -11,6 +11,7 @@ import ( "github.com/gin-gonic/gin/internal/json" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestError(t *testing.T) { @@ -122,7 +123,7 @@ func TestErrorUnwrap(t *testing.T) { }) // check that 'errors.Is()' and 'errors.As()' behave as expected : - assert.True(t, errors.Is(err, innerErr)) + require.ErrorIs(t, err, innerErr) var testErr TestErr - assert.True(t, errors.As(err, &testErr)) + require.ErrorAs(t, err, &testErr) } diff --git a/fs.go b/fs.go index f17d7434..51c3db86 100644 --- a/fs.go +++ b/fs.go @@ -9,37 +9,43 @@ import ( "os" ) -type onlyFilesFS struct { - fs http.FileSystem +// OnlyFilesFS implements an http.FileSystem without `Readdir` functionality. +type OnlyFilesFS struct { + FileSystem http.FileSystem } -type neuteredReaddirFile struct { - http.File -} +// Open passes `Open` to the upstream implementation without `Readdir` functionality. +func (o OnlyFilesFS) Open(name string) (http.File, error) { + f, err := o.FileSystem.Open(name) -// Dir returns a http.FileSystem that can be used by http.FileServer(). It is used internally -// in router.Static(). -// if listDirectory == true, then it works the same as http.Dir() otherwise it returns -// a filesystem that prevents http.FileServer() to list the directory files. -func Dir(root string, listDirectory bool) http.FileSystem { - fs := http.Dir(root) - if listDirectory { - return fs - } - return &onlyFilesFS{fs} -} - -// Open conforms to http.Filesystem. -func (fs onlyFilesFS) Open(name string) (http.File, error) { - f, err := fs.fs.Open(name) if err != nil { return nil, err } - return neuteredReaddirFile{f}, nil + + return neutralizedReaddirFile{f}, nil } -// Readdir overrides the http.File default implementation. -func (f neuteredReaddirFile) Readdir(_ int) ([]os.FileInfo, error) { +// neutralizedReaddirFile wraps http.File with a specific implementation of `Readdir`. +type neutralizedReaddirFile struct { + http.File +} + +// Readdir overrides the http.File default implementation and always returns nil. +func (n neutralizedReaddirFile) Readdir(_ int) ([]os.FileInfo, error) { // this disables directory listing return nil, nil } + +// Dir returns an http.FileSystem that can be used by http.FileServer(). +// It is used internally in router.Static(). +// if listDirectory == true, then it works the same as http.Dir(), +// otherwise it returns a filesystem that prevents http.FileServer() to list the directory files. +func Dir(root string, listDirectory bool) http.FileSystem { + fs := http.Dir(root) + + if listDirectory { + return fs + } + + return &OnlyFilesFS{FileSystem: fs} +} diff --git a/fs_test.go b/fs_test.go new file mode 100644 index 00000000..167ac1af --- /dev/null +++ b/fs_test.go @@ -0,0 +1,72 @@ +package gin + +import ( + "errors" + "net/http" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type mockFileSystem struct { + open func(name string) (http.File, error) +} + +func (m *mockFileSystem) Open(name string) (http.File, error) { + return m.open(name) +} + +func TestOnlyFilesFS_Open(t *testing.T) { + var testFile *os.File + mockFS := &mockFileSystem{ + open: func(name string) (http.File, error) { + return testFile, nil + }, + } + fs := &OnlyFilesFS{FileSystem: mockFS} + + file, err := fs.Open("foo") + + require.NoError(t, err) + assert.Equal(t, testFile, file.(neutralizedReaddirFile).File) +} + +func TestOnlyFilesFS_Open_err(t *testing.T) { + testError := errors.New("mock") + mockFS := &mockFileSystem{ + open: func(_ string) (http.File, error) { + return nil, testError + }, + } + fs := &OnlyFilesFS{FileSystem: mockFS} + + file, err := fs.Open("foo") + + require.ErrorIs(t, err, testError) + assert.Nil(t, file) +} + +func Test_neuteredReaddirFile_Readdir(t *testing.T) { + n := neutralizedReaddirFile{} + + res, err := n.Readdir(0) + + require.NoError(t, err) + assert.Nil(t, res) +} + +func TestDir_listDirectory(t *testing.T) { + testRoot := "foo" + fs := Dir(testRoot, true) + + assert.Equal(t, http.Dir(testRoot), fs) +} + +func TestDir(t *testing.T) { + testRoot := "foo" + fs := Dir(testRoot, false) + + assert.Equal(t, &OnlyFilesFS{FileSystem: http.Dir(testRoot)}, fs) +} diff --git a/gin.go b/gin.go index 24a9864a..f9813e1d 100644 --- a/gin.go +++ b/gin.go @@ -16,12 +16,18 @@ import ( "sync" "github.com/gin-gonic/gin/internal/bytesconv" + filesystem "github.com/gin-gonic/gin/internal/fs" "github.com/gin-gonic/gin/render" + + "github.com/quic-go/quic-go/http3" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" ) const defaultMultipartMemory = 32 << 20 // 32 MB +const escapedColon = "\\:" +const colon = ":" +const backslash = "\\" var ( default404Body = []byte("404 page not found") @@ -47,6 +53,9 @@ var regRemoveRepeatedChar = regexp.MustCompile("/{2,}") // HandlerFunc defines the handler used by gin middleware as return value. type HandlerFunc func(*Context) +// OptionFunc defines the function to change the default configuration +type OptionFunc func(*Engine) + // HandlersChain defines a HandlerFunc slice. type HandlersChain []HandlerFunc @@ -182,7 +191,7 @@ var _ IRouter = (*Engine)(nil) // - ForwardedByClientIP: true // - UseRawPath: false // - UnescapePathValues: true -func New() *Engine { +func New(opts ...OptionFunc) *Engine { debugPrintWARNINGNew() engine := &Engine{ RouterGroup: RouterGroup{ @@ -211,15 +220,15 @@ func New() *Engine { engine.pool.New = func() any { return engine.allocateContext(engine.maxParams) } - return engine + return engine.With(opts...) } // Default returns an Engine instance with the Logger and Recovery middleware already attached. -func Default() *Engine { +func Default(opts ...OptionFunc) *Engine { debugPrintWARNINGDefault() engine := New() engine.Use(Logger(), Recovery()) - return engine + return engine.With(opts...) } func (engine *Engine) Handler() http.Handler { @@ -277,6 +286,19 @@ func (engine *Engine) LoadHTMLFiles(files ...string) { engine.SetHTMLTemplate(templ) } +// LoadHTMLFS loads an http.FileSystem and a slice of patterns +// and associates the result with HTML renderer. +func (engine *Engine) LoadHTMLFS(fs http.FileSystem, patterns ...string) { + if IsDebugging() { + engine.HTMLRender = render.HTMLDebug{FileSystem: fs, Patterns: patterns, FuncMap: engine.FuncMap, Delims: engine.delims} + return + } + + templ := template.Must(template.New("").Delims(engine.delims.Left, engine.delims.Right).Funcs(engine.FuncMap).ParseFS( + filesystem.FileSystem{FileSystem: fs}, patterns...)) + engine.SetHTMLTemplate(templ) +} + // SetHTMLTemplate associate a template with HTML renderer. func (engine *Engine) SetHTMLTemplate(templ *template.Template) { if len(engine.trees) > 0 { @@ -313,6 +335,15 @@ func (engine *Engine) Use(middleware ...HandlerFunc) IRoutes { return engine } +// With returns a Engine with the configuration set in the OptionFunc. +func (engine *Engine) With(opts ...OptionFunc) *Engine { + for _, opt := range opts { + opt(engine) + } + + return engine +} + func (engine *Engine) rebuild404Handlers() { engine.allNoRoute = engine.combineHandlers(engine.noRoute) } @@ -371,23 +402,6 @@ func iterate(path, method string, routes RoutesInfo, root *node) RoutesInfo { return routes } -// Run attaches the router to a http.Server and starts listening and serving HTTP requests. -// It is a shortcut for http.ListenAndServe(addr, router) -// Note: this method will block the calling goroutine indefinitely unless an error happens. -func (engine *Engine) Run(addr ...string) (err error) { - defer func() { debugPrintError(err) }() - - if engine.isUnsafeTrustedProxies() { - debugPrint("[WARNING] You trusted all proxies, this is NOT safe. We recommend you to set a value.\n" + - "Please check https://pkg.go.dev/github.com/gin-gonic/gin#readme-don-t-trust-all-proxies for details.") - } - - address := resolveAddress(addr) - debugPrint("Listening and serving HTTP on %s\n", address) - err = http.ListenAndServe(address, engine.Handler()) - return -} - func (engine *Engine) prepareTrustedCIDRs() ([]*net.IPNet, error) { if engine.trustedProxies == nil { return nil, nil @@ -477,6 +491,26 @@ func (engine *Engine) validateHeader(header string) (clientIP string, valid bool return "", false } +// updateRouteTree do update to the route tree recursively +func updateRouteTree(n *node) { + n.path = strings.ReplaceAll(n.path, escapedColon, colon) + n.fullPath = strings.ReplaceAll(n.fullPath, escapedColon, colon) + n.indices = strings.ReplaceAll(n.indices, backslash, colon) + if n.children == nil { + return + } + for _, child := range n.children { + updateRouteTree(child) + } +} + +// updateRouteTrees do update to the route trees +func (engine *Engine) updateRouteTrees() { + for _, tree := range engine.trees { + updateRouteTree(tree.root) + } +} + // parseIP parse a string representation of an IP and returns a net.IP with the // minimum byte representation or nil if input is invalid. func parseIP(ip string) net.IP { @@ -491,6 +525,23 @@ func parseIP(ip string) net.IP { return parsedIP } +// Run attaches the router to a http.Server and starts listening and serving HTTP requests. +// It is a shortcut for http.ListenAndServe(addr, router) +// Note: this method will block the calling goroutine indefinitely unless an error happens. +func (engine *Engine) Run(addr ...string) (err error) { + defer func() { debugPrintError(err) }() + + if engine.isUnsafeTrustedProxies() { + debugPrint("[WARNING] You trusted all proxies, this is NOT safe. We recommend you to set a value.\n" + + "Please check https://github.com/gin-gonic/gin/blob/master/docs/doc.md#dont-trust-all-proxies for details.") + } + engine.updateRouteTrees() + address := resolveAddress(addr) + debugPrint("Listening and serving HTTP on %s\n", address) + err = http.ListenAndServe(address, engine.Handler()) + return +} + // RunTLS attaches the router to a http.Server and starts listening and serving HTTPS (secure) requests. // It is a shortcut for http.ListenAndServeTLS(addr, certFile, keyFile, router) // Note: this method will block the calling goroutine indefinitely unless an error happens. @@ -500,7 +551,7 @@ func (engine *Engine) RunTLS(addr, certFile, keyFile string) (err error) { if engine.isUnsafeTrustedProxies() { debugPrint("[WARNING] You trusted all proxies, this is NOT safe. We recommend you to set a value.\n" + - "Please check https://pkg.go.dev/github.com/gin-gonic/gin#readme-don-t-trust-all-proxies for details.") + "Please check https://github.com/gin-gonic/gin/blob/master/docs/doc.md#dont-trust-all-proxies for details.") } err = http.ListenAndServeTLS(addr, certFile, keyFile, engine.Handler()) @@ -552,6 +603,22 @@ func (engine *Engine) RunFd(fd int) (err error) { return } +// RunQUIC attaches the router to a http.Server and starts listening and serving QUIC requests. +// It is a shortcut for http3.ListenAndServeQUIC(addr, certFile, keyFile, router) +// Note: this method will block the calling goroutine indefinitely unless an error happens. +func (engine *Engine) RunQUIC(addr, certFile, keyFile string) (err error) { + debugPrint("Listening and serving QUIC on %s\n", addr) + defer func() { debugPrintError(err) }() + + if engine.isUnsafeTrustedProxies() { + debugPrint("[WARNING] You trusted all proxies, this is NOT safe. We recommend you to set a value.\n" + + "Please check https://github.com/gin-gonic/gin/blob/master/docs/doc.md#dont-trust-all-proxies for details.") + } + + err = http3.ListenAndServeQUIC(addr, certFile, keyFile, engine.Handler()) + return +} + // RunListener attaches the router to a http.Server and starts listening and serving HTTP requests // through the specified net.Listener func (engine *Engine) RunListener(listener net.Listener) (err error) { @@ -584,10 +651,12 @@ func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) { // Disclaimer: You can loop yourself to deal with this, use wisely. func (engine *Engine) HandleContext(c *Context) { oldIndexValue := c.index + oldHandlers := c.handlers c.reset() engine.handleHTTPRequest(c) c.index = oldIndexValue + c.handlers = oldHandlers } func (engine *Engine) handleHTTPRequest(c *Context) { @@ -634,7 +703,7 @@ func (engine *Engine) handleHTTPRequest(c *Context) { break } - if engine.HandleMethodNotAllowed { + if engine.HandleMethodNotAllowed && len(t) > 0 { // According to RFC 7231 section 6.5.5, MUST generate an Allow header field in response // containing a list of the target resource's currently supported methods. allowed := make([]string, 0, len(t)-1) diff --git a/ginS/gins.go b/ginS/gins.go index ea38c613..3e6a92eb 100644 --- a/ginS/gins.go +++ b/ginS/gins.go @@ -32,6 +32,11 @@ func LoadHTMLFiles(files ...string) { engine().LoadHTMLFiles(files...) } +// LoadHTMLFS is a wrapper for Engine.LoadHTMLFS. +func LoadHTMLFS(fs http.FileSystem, patterns ...string) { + engine().LoadHTMLFS(fs, patterns...) +} + // SetHTMLTemplate is a wrapper for Engine.SetHTMLTemplate. func SetHTMLTemplate(templ *template.Template) { engine().SetHTMLTemplate(templ) @@ -154,7 +159,7 @@ func RunUnix(file string) (err error) { // RunFd attaches the router to a http.Server and starts listening and serving HTTP requests // through the specified file descriptor. -// Note: the method will block the calling goroutine indefinitely unless on error happens. +// Note: the method will block the calling goroutine indefinitely unless an error happens. func RunFd(fd int) (err error) { return engine().RunFd(fd) } diff --git a/gin_integration_test.go b/gin_integration_test.go index 02b96221..3082bc2c 100644 --- a/gin_integration_test.go +++ b/gin_integration_test.go @@ -21,6 +21,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // params[0]=url example:http://127.0.0.1:8080/index (cannot be empty) @@ -40,11 +41,11 @@ func testRequest(t *testing.T, params ...string) { client := &http.Client{Transport: tr} resp, err := client.Get(params[0]) - assert.NoError(t, err) + require.NoError(t, err) defer resp.Body.Close() body, ioerr := io.ReadAll(resp.Body) - assert.NoError(t, ioerr) + require.NoError(t, ioerr) var responseStatus = "200 OK" if len(params) > 1 && params[1] != "" { @@ -73,13 +74,13 @@ func TestRunEmpty(t *testing.T) { // otherwise the main thread will complete time.Sleep(5 * time.Millisecond) - assert.Error(t, router.Run(":8080")) + require.Error(t, router.Run(":8080")) testRequest(t, "http://localhost:8080/example") } func TestBadTrustedCIDRs(t *testing.T) { router := New() - assert.Error(t, router.SetTrustedProxies([]string{"hello/world"})) + require.Error(t, router.SetTrustedProxies([]string{"hello/world"})) } /* legacy tests @@ -87,7 +88,7 @@ func TestBadTrustedCIDRsForRun(t *testing.T) { os.Setenv("PORT", "") router := New() router.TrustedProxies = []string{"hello/world"} - assert.Error(t, router.Run(":8080")) + require.Error(t, router.Run(":8080")) } func TestBadTrustedCIDRsForRunUnix(t *testing.T) { @@ -100,7 +101,7 @@ func TestBadTrustedCIDRsForRunUnix(t *testing.T) { go func() { router.GET("/example", func(c *Context) { c.String(http.StatusOK, "it worked") }) - assert.Error(t, router.RunUnix(unixTestSocket)) + require.Error(t, router.RunUnix(unixTestSocket)) }() // have to wait for the goroutine to start and run the server // otherwise the main thread will complete @@ -112,15 +113,15 @@ func TestBadTrustedCIDRsForRunFd(t *testing.T) { router.TrustedProxies = []string{"hello/world"} addr, err := net.ResolveTCPAddr("tcp", "localhost:0") - assert.NoError(t, err) + require.NoError(t, err) listener, err := net.ListenTCP("tcp", addr) - assert.NoError(t, err) + require.NoError(t, err) socketFile, err := listener.File() - assert.NoError(t, err) + require.NoError(t, err) go func() { router.GET("/example", func(c *Context) { c.String(http.StatusOK, "it worked") }) - assert.Error(t, router.RunFd(int(socketFile.Fd()))) + require.Error(t, router.RunFd(int(socketFile.Fd()))) }() // have to wait for the goroutine to start and run the server // otherwise the main thread will complete @@ -132,12 +133,12 @@ func TestBadTrustedCIDRsForRunListener(t *testing.T) { router.TrustedProxies = []string{"hello/world"} addr, err := net.ResolveTCPAddr("tcp", "localhost:0") - assert.NoError(t, err) + require.NoError(t, err) listener, err := net.ListenTCP("tcp", addr) - assert.NoError(t, err) + require.NoError(t, err) go func() { router.GET("/example", func(c *Context) { c.String(http.StatusOK, "it worked") }) - assert.Error(t, router.RunListener(listener)) + require.Error(t, router.RunListener(listener)) }() // have to wait for the goroutine to start and run the server // otherwise the main thread will complete @@ -148,7 +149,7 @@ func TestBadTrustedCIDRsForRunTLS(t *testing.T) { os.Setenv("PORT", "") router := New() router.TrustedProxies = []string{"hello/world"} - assert.Error(t, router.RunTLS(":8080", "./testdata/certificate/cert.pem", "./testdata/certificate/key.pem")) + require.Error(t, router.RunTLS(":8080", "./testdata/certificate/cert.pem", "./testdata/certificate/key.pem")) } */ @@ -164,7 +165,7 @@ func TestRunTLS(t *testing.T) { // otherwise the main thread will complete time.Sleep(5 * time.Millisecond) - assert.Error(t, router.RunTLS(":8443", "./testdata/certificate/cert.pem", "./testdata/certificate/key.pem")) + require.Error(t, router.RunTLS(":8443", "./testdata/certificate/cert.pem", "./testdata/certificate/key.pem")) testRequest(t, "https://localhost:8443/example") } @@ -201,7 +202,7 @@ func TestPusher(t *testing.T) { // otherwise the main thread will complete time.Sleep(5 * time.Millisecond) - assert.Error(t, router.RunTLS(":8449", "./testdata/certificate/cert.pem", "./testdata/certificate/key.pem")) + require.Error(t, router.RunTLS(":8449", "./testdata/certificate/cert.pem", "./testdata/certificate/key.pem")) testRequest(t, "https://localhost:8449/pusher") } @@ -216,14 +217,14 @@ func TestRunEmptyWithEnv(t *testing.T) { // otherwise the main thread will complete time.Sleep(5 * time.Millisecond) - assert.Error(t, router.Run(":3123")) + require.Error(t, router.Run(":3123")) testRequest(t, "http://localhost:3123/example") } func TestRunTooMuchParams(t *testing.T) { router := New() assert.Panics(t, func() { - assert.NoError(t, router.Run("2", "2")) + require.NoError(t, router.Run("2", "2")) }) } @@ -237,7 +238,7 @@ func TestRunWithPort(t *testing.T) { // otherwise the main thread will complete time.Sleep(5 * time.Millisecond) - assert.Error(t, router.Run(":5150")) + require.Error(t, router.Run(":5150")) testRequest(t, "http://localhost:5150/example") } @@ -257,7 +258,7 @@ func TestUnixSocket(t *testing.T) { time.Sleep(5 * time.Millisecond) c, err := net.Dial("unix", unixTestSocket) - assert.NoError(t, err) + require.NoError(t, err) fmt.Fprint(c, "GET /example HTTP/1.0\r\n\r\n") scanner := bufio.NewScanner(c) @@ -271,22 +272,38 @@ func TestUnixSocket(t *testing.T) { func TestBadUnixSocket(t *testing.T) { router := New() - assert.Error(t, router.RunUnix("#/tmp/unix_unit_test")) + require.Error(t, router.RunUnix("#/tmp/unix_unit_test")) +} + +func TestRunQUIC(t *testing.T) { + router := New() + go func() { + router.GET("/example", func(c *Context) { c.String(http.StatusOK, "it worked") }) + + assert.NoError(t, router.RunQUIC(":8443", "./testdata/certificate/cert.pem", "./testdata/certificate/key.pem")) + }() + + // have to wait for the goroutine to start and run the server + // otherwise the main thread will complete + time.Sleep(5 * time.Millisecond) + + require.Error(t, router.RunQUIC(":8443", "./testdata/certificate/cert.pem", "./testdata/certificate/key.pem")) + testRequest(t, "https://localhost:8443/example") } func TestFileDescriptor(t *testing.T) { router := New() addr, err := net.ResolveTCPAddr("tcp", "localhost:0") - assert.NoError(t, err) + require.NoError(t, err) listener, err := net.ListenTCP("tcp", addr) - assert.NoError(t, err) + require.NoError(t, err) socketFile, err := listener.File() if isWindows() { // not supported by windows, it is unimplemented now - assert.Error(t, err) + require.Error(t, err) } else { - assert.NoError(t, err) + require.NoError(t, err) } if socketFile == nil { @@ -302,7 +319,7 @@ func TestFileDescriptor(t *testing.T) { time.Sleep(5 * time.Millisecond) c, err := net.Dial("tcp", listener.Addr().String()) - assert.NoError(t, err) + require.NoError(t, err) fmt.Fprintf(c, "GET /example HTTP/1.0\r\n\r\n") scanner := bufio.NewScanner(c) @@ -316,15 +333,15 @@ func TestFileDescriptor(t *testing.T) { func TestBadFileDescriptor(t *testing.T) { router := New() - assert.Error(t, router.RunFd(0)) + require.Error(t, router.RunFd(0)) } func TestListener(t *testing.T) { router := New() addr, err := net.ResolveTCPAddr("tcp", "localhost:0") - assert.NoError(t, err) + require.NoError(t, err) listener, err := net.ListenTCP("tcp", addr) - assert.NoError(t, err) + require.NoError(t, err) go func() { router.GET("/example", func(c *Context) { c.String(http.StatusOK, "it worked") }) assert.NoError(t, router.RunListener(listener)) @@ -334,7 +351,7 @@ func TestListener(t *testing.T) { time.Sleep(5 * time.Millisecond) c, err := net.Dial("tcp", listener.Addr().String()) - assert.NoError(t, err) + require.NoError(t, err) fmt.Fprintf(c, "GET /example HTTP/1.0\r\n\r\n") scanner := bufio.NewScanner(c) @@ -349,11 +366,11 @@ func TestListener(t *testing.T) { func TestBadListener(t *testing.T) { router := New() addr, err := net.ResolveTCPAddr("tcp", "localhost:10086") - assert.NoError(t, err) + require.NoError(t, err) listener, err := net.ListenTCP("tcp", addr) - assert.NoError(t, err) + require.NoError(t, err) listener.Close() - assert.Error(t, router.RunListener(listener)) + require.Error(t, router.RunListener(listener)) } func TestWithHttptestWithAutoSelectedPort(t *testing.T) { @@ -379,7 +396,14 @@ func TestConcurrentHandleContext(t *testing.T) { wg.Add(iterations) for i := 0; i < iterations; i++ { go func() { - testGetRequestHandler(t, router, "/") + req, err := http.NewRequest(http.MethodGet, "/", nil) + assert.NoError(t, err) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + assert.Equal(t, "it worked", w.Body.String(), "resp body should match") + assert.Equal(t, 200, w.Code, "should get a 200") wg.Done() }() } @@ -401,17 +425,6 @@ func TestConcurrentHandleContext(t *testing.T) { // testRequest(t, "http://localhost:8033/example") // } -func testGetRequestHandler(t *testing.T, h http.Handler, url string) { - req, err := http.NewRequest(http.MethodGet, url, nil) - assert.NoError(t, err) - - w := httptest.NewRecorder() - h.ServeHTTP(w, req) - - assert.Equal(t, "it worked", w.Body.String(), "resp body should match") - assert.Equal(t, 200, w.Code, "should get a 200") -} - func TestTreeRunDynamicRouting(t *testing.T) { router := New() router.GET("/aa/*xx", func(c *Context) { c.String(http.StatusOK, "/aa/*xx") }) @@ -561,3 +574,28 @@ func TestTreeRunDynamicRouting(t *testing.T) { func isWindows() bool { return runtime.GOOS == "windows" } + +func TestEscapedColon(t *testing.T) { + router := New() + f := func(u string) { + router.GET(u, func(c *Context) { c.String(http.StatusOK, u) }) + } + f("/r/r\\:r") + f("/r/r:r") + f("/r/r/:r") + f("/r/r/\\:r") + f("/r/r/r\\:r") + assert.Panics(t, func() { + f("\\foo:") + }) + + router.updateRouteTrees() + ts := httptest.NewServer(router) + defer ts.Close() + + testRequest(t, ts.URL+"/r/r123", "", "/r/r:r") + testRequest(t, ts.URL+"/r/r:r", "", "/r/r\\:r") + testRequest(t, ts.URL+"/r/r/r123", "", "/r/r/:r") + testRequest(t, ts.URL+"/r/r/:r", "", "/r/r/\\:r") + testRequest(t, ts.URL+"/r/r/r:r", "", "/r/r/r\\:r") +} diff --git a/gin_test.go b/gin_test.go index 8825ac7e..a80b690e 100644 --- a/gin_test.go +++ b/gin_test.go @@ -14,11 +14,13 @@ import ( "net/http/httptest" "reflect" "strconv" + "strings" "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/net/http2" ) @@ -71,7 +73,7 @@ func TestLoadHTMLGlobDebugMode(t *testing.T) { ) defer ts.Close() - res, err := http.Get(fmt.Sprintf("%s/test", ts.URL)) + res, err := http.Get(ts.URL + "/test") if err != nil { t.Error(err) } @@ -129,7 +131,7 @@ func TestLoadHTMLGlobTestMode(t *testing.T) { ) defer ts.Close() - res, err := http.Get(fmt.Sprintf("%s/test", ts.URL)) + res, err := http.Get(ts.URL + "/test") if err != nil { t.Error(err) } @@ -149,7 +151,7 @@ func TestLoadHTMLGlobReleaseMode(t *testing.T) { ) defer ts.Close() - res, err := http.Get(fmt.Sprintf("%s/test", ts.URL)) + res, err := http.Get(ts.URL + "/test") if err != nil { t.Error(err) } @@ -176,7 +178,7 @@ func TestLoadHTMLGlobUsingTLS(t *testing.T) { }, } client := &http.Client{Transport: tr} - res, err := client.Get(fmt.Sprintf("%s/test", ts.URL)) + res, err := client.Get(ts.URL + "/test") if err != nil { t.Error(err) } @@ -196,7 +198,7 @@ func TestLoadHTMLGlobFromFuncMap(t *testing.T) { ) defer ts.Close() - res, err := http.Get(fmt.Sprintf("%s/raw", ts.URL)) + res, err := http.Get(ts.URL + "/raw") if err != nil { t.Error(err) } @@ -227,7 +229,7 @@ func TestLoadHTMLFilesTestMode(t *testing.T) { ) defer ts.Close() - res, err := http.Get(fmt.Sprintf("%s/test", ts.URL)) + res, err := http.Get(ts.URL + "/test") if err != nil { t.Error(err) } @@ -247,7 +249,7 @@ func TestLoadHTMLFilesDebugMode(t *testing.T) { ) defer ts.Close() - res, err := http.Get(fmt.Sprintf("%s/test", ts.URL)) + res, err := http.Get(ts.URL + "/test") if err != nil { t.Error(err) } @@ -267,7 +269,7 @@ func TestLoadHTMLFilesReleaseMode(t *testing.T) { ) defer ts.Close() - res, err := http.Get(fmt.Sprintf("%s/test", ts.URL)) + res, err := http.Get(ts.URL + "/test") if err != nil { t.Error(err) } @@ -294,7 +296,7 @@ func TestLoadHTMLFilesUsingTLS(t *testing.T) { }, } client := &http.Client{Transport: tr} - res, err := client.Get(fmt.Sprintf("%s/test", ts.URL)) + res, err := client.Get(ts.URL + "/test") if err != nil { t.Error(err) } @@ -314,6 +316,115 @@ func TestLoadHTMLFilesFuncMap(t *testing.T) { ) defer ts.Close() + res, err := http.Get(ts.URL + "/raw") + if err != nil { + t.Error(err) + } + + resp, _ := io.ReadAll(res.Body) + assert.Equal(t, "Date: 2017/07/01", string(resp)) +} + +var tmplFS = http.Dir("testdata/template") + +func TestLoadHTMLFSTestMode(t *testing.T) { + ts := setupHTMLFiles( + t, + TestMode, + false, + func(router *Engine) { + router.LoadHTMLFS(tmplFS, "hello.tmpl", "raw.tmpl") + }, + ) + defer ts.Close() + + res, err := http.Get(fmt.Sprintf("%s/test", ts.URL)) + if err != nil { + t.Error(err) + } + + resp, _ := io.ReadAll(res.Body) + assert.Equal(t, "

Hello world

", string(resp)) +} + +func TestLoadHTMLFSDebugMode(t *testing.T) { + ts := setupHTMLFiles( + t, + DebugMode, + false, + func(router *Engine) { + router.LoadHTMLFS(tmplFS, "hello.tmpl", "raw.tmpl") + }, + ) + defer ts.Close() + + res, err := http.Get(fmt.Sprintf("%s/test", ts.URL)) + if err != nil { + t.Error(err) + } + + resp, _ := io.ReadAll(res.Body) + assert.Equal(t, "

Hello world

", string(resp)) +} + +func TestLoadHTMLFSReleaseMode(t *testing.T) { + ts := setupHTMLFiles( + t, + ReleaseMode, + false, + func(router *Engine) { + router.LoadHTMLFS(tmplFS, "hello.tmpl", "raw.tmpl") + }, + ) + defer ts.Close() + + res, err := http.Get(fmt.Sprintf("%s/test", ts.URL)) + if err != nil { + t.Error(err) + } + + resp, _ := io.ReadAll(res.Body) + assert.Equal(t, "

Hello world

", string(resp)) +} + +func TestLoadHTMLFSUsingTLS(t *testing.T) { + ts := setupHTMLFiles( + t, + TestMode, + true, + func(router *Engine) { + router.LoadHTMLFS(tmplFS, "hello.tmpl", "raw.tmpl") + }, + ) + defer ts.Close() + + // Use InsecureSkipVerify for avoiding `x509: certificate signed by unknown authority` error + tr := &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + } + client := &http.Client{Transport: tr} + res, err := client.Get(fmt.Sprintf("%s/test", ts.URL)) + if err != nil { + t.Error(err) + } + + resp, _ := io.ReadAll(res.Body) + assert.Equal(t, "

Hello world

", string(resp)) +} + +func TestLoadHTMLFSFuncMap(t *testing.T) { + ts := setupHTMLFiles( + t, + TestMode, + false, + func(router *Engine) { + router.LoadHTMLFS(tmplFS, "hello.tmpl", "raw.tmpl") + }, + ) + defer ts.Close() + res, err := http.Get(fmt.Sprintf("%s/raw", ts.URL)) if err != nil { t.Error(err) @@ -325,31 +436,31 @@ func TestLoadHTMLFilesFuncMap(t *testing.T) { func TestAddRoute(t *testing.T) { router := New() - router.addRoute("GET", "/", HandlersChain{func(_ *Context) {}}) + router.addRoute(http.MethodGet, "/", HandlersChain{func(_ *Context) {}}) assert.Len(t, router.trees, 1) - assert.NotNil(t, router.trees.get("GET")) - assert.Nil(t, router.trees.get("POST")) + assert.NotNil(t, router.trees.get(http.MethodGet)) + assert.Nil(t, router.trees.get(http.MethodPost)) - router.addRoute("POST", "/", HandlersChain{func(_ *Context) {}}) + router.addRoute(http.MethodPost, "/", HandlersChain{func(_ *Context) {}}) assert.Len(t, router.trees, 2) - assert.NotNil(t, router.trees.get("GET")) - assert.NotNil(t, router.trees.get("POST")) + assert.NotNil(t, router.trees.get(http.MethodGet)) + assert.NotNil(t, router.trees.get(http.MethodPost)) - router.addRoute("POST", "/post", HandlersChain{func(_ *Context) {}}) + router.addRoute(http.MethodPost, "/post", HandlersChain{func(_ *Context) {}}) assert.Len(t, router.trees, 2) } func TestAddRouteFails(t *testing.T) { router := New() assert.Panics(t, func() { router.addRoute("", "/", HandlersChain{func(_ *Context) {}}) }) - assert.Panics(t, func() { router.addRoute("GET", "a", HandlersChain{func(_ *Context) {}}) }) - assert.Panics(t, func() { router.addRoute("GET", "/", HandlersChain{}) }) + assert.Panics(t, func() { router.addRoute(http.MethodGet, "a", HandlersChain{func(_ *Context) {}}) }) + assert.Panics(t, func() { router.addRoute(http.MethodGet, "/", HandlersChain{}) }) - router.addRoute("POST", "/post", HandlersChain{func(_ *Context) {}}) + router.addRoute(http.MethodPost, "/post", HandlersChain{func(_ *Context) {}}) assert.Panics(t, func() { - router.addRoute("POST", "/post", HandlersChain{func(_ *Context) {}}) + router.addRoute(http.MethodPost, "/post", HandlersChain{func(_ *Context) {}}) }) } @@ -491,27 +602,27 @@ func TestListOfRoutes(t *testing.T) { assert.Len(t, list, 7) assertRoutePresent(t, list, RouteInfo{ - Method: "GET", + Method: http.MethodGet, Path: "/favicon.ico", Handler: "^(.*/vendor/)?github.com/gin-gonic/gin.handlerTest1$", }) assertRoutePresent(t, list, RouteInfo{ - Method: "GET", + Method: http.MethodGet, Path: "/", Handler: "^(.*/vendor/)?github.com/gin-gonic/gin.handlerTest1$", }) assertRoutePresent(t, list, RouteInfo{ - Method: "GET", + Method: http.MethodGet, Path: "/users/", Handler: "^(.*/vendor/)?github.com/gin-gonic/gin.handlerTest2$", }) assertRoutePresent(t, list, RouteInfo{ - Method: "GET", + Method: http.MethodGet, Path: "/users/:id", Handler: "^(.*/vendor/)?github.com/gin-gonic/gin.handlerTest1$", }) assertRoutePresent(t, list, RouteInfo{ - Method: "POST", + Method: http.MethodPost, Path: "/users/:id", Handler: "^(.*/vendor/)?github.com/gin-gonic/gin.handlerTest2$", }) @@ -529,7 +640,7 @@ func TestEngineHandleContext(t *testing.T) { } assert.NotPanics(t, func() { - w := PerformRequest(r, "GET", "/") + w := PerformRequest(r, http.MethodGet, "/") assert.Equal(t, 301, w.Code) }) } @@ -546,10 +657,10 @@ func TestEngineHandleContextManyReEntries(t *testing.T) { r.GET("/:count", func(c *Context) { countStr := c.Param("count") count, err := strconv.Atoi(countStr) - assert.NoError(t, err) + require.NoError(t, err) n, err := c.Writer.Write([]byte(".")) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, 1, n) switch { @@ -562,7 +673,7 @@ func TestEngineHandleContextManyReEntries(t *testing.T) { }) assert.NotPanics(t, func() { - w := PerformRequest(r, "GET", "/"+strconv.Itoa(expectValue-1)) // include 0 value + w := PerformRequest(r, http.MethodGet, "/"+strconv.Itoa(expectValue-1)) // include 0 value assert.Equal(t, 200, w.Code) assert.Equal(t, expectValue, w.Body.Len()) }) @@ -571,6 +682,44 @@ func TestEngineHandleContextManyReEntries(t *testing.T) { assert.Equal(t, int64(expectValue), middlewareCounter) } +func TestEngineHandleContextPreventsMiddlewareReEntry(t *testing.T) { + // given + var handlerCounterV1, handlerCounterV2, middlewareCounterV1 int64 + + r := New() + v1 := r.Group("/v1") + { + v1.Use(func(c *Context) { + atomic.AddInt64(&middlewareCounterV1, 1) + }) + v1.GET("/test", func(c *Context) { + atomic.AddInt64(&handlerCounterV1, 1) + c.Status(http.StatusOK) + }) + } + + v2 := r.Group("/v2") + { + v2.GET("/test", func(c *Context) { + c.Request.URL.Path = "/v1/test" + r.HandleContext(c) + }, func(c *Context) { + atomic.AddInt64(&handlerCounterV2, 1) + }) + } + + // when + responseV1 := PerformRequest(r, "GET", "/v1/test") + responseV2 := PerformRequest(r, "GET", "/v2/test") + + // then + assert.Equal(t, 200, responseV1.Code) + assert.Equal(t, 200, responseV2.Code) + assert.Equal(t, int64(2), handlerCounterV1) + assert.Equal(t, int64(2), middlewareCounterV1) + assert.Equal(t, int64(1), handlerCounterV2) +} + func TestPrepareTrustedCIRDsWith(t *testing.T) { r := New() @@ -579,7 +728,7 @@ func TestPrepareTrustedCIRDsWith(t *testing.T) { expectedTrustedCIDRs := []*net.IPNet{parseCIDR("0.0.0.0/0")} err := r.SetTrustedProxies([]string{"0.0.0.0/0"}) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, expectedTrustedCIDRs, r.trustedCIDRs) } @@ -587,7 +736,7 @@ func TestPrepareTrustedCIRDsWith(t *testing.T) { { err := r.SetTrustedProxies([]string{"192.168.1.33/33"}) - assert.Error(t, err) + require.Error(t, err) } // valid ipv4 address @@ -596,7 +745,7 @@ func TestPrepareTrustedCIRDsWith(t *testing.T) { err := r.SetTrustedProxies([]string{"192.168.1.33"}) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, expectedTrustedCIDRs, r.trustedCIDRs) } @@ -604,7 +753,7 @@ func TestPrepareTrustedCIRDsWith(t *testing.T) { { err := r.SetTrustedProxies([]string{"192.168.1.256"}) - assert.Error(t, err) + require.Error(t, err) } // valid ipv6 address @@ -612,7 +761,7 @@ func TestPrepareTrustedCIRDsWith(t *testing.T) { expectedTrustedCIDRs := []*net.IPNet{parseCIDR("2002:0000:0000:1234:abcd:ffff:c0a8:0101/128")} err := r.SetTrustedProxies([]string{"2002:0000:0000:1234:abcd:ffff:c0a8:0101"}) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, expectedTrustedCIDRs, r.trustedCIDRs) } @@ -620,7 +769,7 @@ func TestPrepareTrustedCIRDsWith(t *testing.T) { { err := r.SetTrustedProxies([]string{"gggg:0000:0000:1234:abcd:ffff:c0a8:0101"}) - assert.Error(t, err) + require.Error(t, err) } // valid ipv6 cidr @@ -628,7 +777,7 @@ func TestPrepareTrustedCIRDsWith(t *testing.T) { expectedTrustedCIDRs := []*net.IPNet{parseCIDR("::/0")} err := r.SetTrustedProxies([]string{"::/0"}) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, expectedTrustedCIDRs, r.trustedCIDRs) } @@ -636,7 +785,7 @@ func TestPrepareTrustedCIRDsWith(t *testing.T) { { err := r.SetTrustedProxies([]string{"gggg:0000:0000:1234:abcd:ffff:c0a8:0101/129"}) - assert.Error(t, err) + require.Error(t, err) } // valid combination @@ -652,7 +801,7 @@ func TestPrepareTrustedCIRDsWith(t *testing.T) { "172.16.0.1", }) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, expectedTrustedCIDRs, r.trustedCIDRs) } @@ -664,7 +813,7 @@ func TestPrepareTrustedCIRDsWith(t *testing.T) { "172.16.0.256", }) - assert.Error(t, err) + require.Error(t, err) } // nil value @@ -672,7 +821,7 @@ func TestPrepareTrustedCIRDsWith(t *testing.T) { err := r.SetTrustedProxies(nil) assert.Nil(t, r.trustedCIDRs) - assert.Nil(t, err) + require.NoError(t, err) } } @@ -696,3 +845,71 @@ func assertRoutePresent(t *testing.T, gotRoutes RoutesInfo, wantRoute RouteInfo) func handlerTest1(c *Context) {} func handlerTest2(c *Context) {} + +func TestNewOptionFunc(t *testing.T) { + var fc = func(e *Engine) { + e.GET("/test1", handlerTest1) + e.GET("/test2", handlerTest2) + + e.Use(func(c *Context) { + c.Next() + }) + } + + r := New(fc) + + routes := r.Routes() + assertRoutePresent(t, routes, RouteInfo{Path: "/test1", Method: http.MethodGet, Handler: "github.com/gin-gonic/gin.handlerTest1"}) + assertRoutePresent(t, routes, RouteInfo{Path: "/test2", Method: http.MethodGet, Handler: "github.com/gin-gonic/gin.handlerTest2"}) +} + +func TestWithOptionFunc(t *testing.T) { + r := New() + + r.With(func(e *Engine) { + e.GET("/test1", handlerTest1) + e.GET("/test2", handlerTest2) + + e.Use(func(c *Context) { + c.Next() + }) + }) + + routes := r.Routes() + assertRoutePresent(t, routes, RouteInfo{Path: "/test1", Method: http.MethodGet, Handler: "github.com/gin-gonic/gin.handlerTest1"}) + assertRoutePresent(t, routes, RouteInfo{Path: "/test2", Method: http.MethodGet, Handler: "github.com/gin-gonic/gin.handlerTest2"}) +} + +type Birthday string + +func (b *Birthday) UnmarshalParam(param string) error { + *b = Birthday(strings.Replace(param, "-", "/", -1)) + return nil +} + +func TestCustomUnmarshalStruct(t *testing.T) { + route := Default() + var request struct { + Birthday Birthday `form:"birthday"` + } + route.GET("/test", func(ctx *Context) { + _ = ctx.BindQuery(&request) + ctx.JSON(200, request.Birthday) + }) + req := httptest.NewRequest(http.MethodGet, "/test?birthday=2000-01-01", nil) + w := httptest.NewRecorder() + route.ServeHTTP(w, req) + assert.Equal(t, 200, w.Code) + assert.Equal(t, `"2000/01/01"`, w.Body.String()) +} + +// Test the fix for https://github.com/gin-gonic/gin/issues/4002 +func TestMethodNotAllowedNoRoute(t *testing.T) { + g := New() + g.HandleMethodNotAllowed = true + + req := httptest.NewRequest(http.MethodGet, "/", nil) + resp := httptest.NewRecorder() + assert.NotPanics(t, func() { g.ServeHTTP(resp, req) }) + assert.Equal(t, http.StatusNotFound, resp.Code) +} diff --git a/githubapi_test.go b/githubapi_test.go index 9276bed5..0c86af2e 100644 --- a/githubapi_test.go +++ b/githubapi_test.go @@ -10,10 +10,12 @@ import ( "net/http" "net/http/httptest" "os" + "strconv" "strings" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type route struct { @@ -295,9 +297,9 @@ func TestShouldBindUri(t *testing.T) { } router.Handle(http.MethodGet, "/rest/:name/:id", func(c *Context) { var person Person - assert.NoError(t, c.ShouldBindUri(&person)) - assert.True(t, person.Name != "") - assert.True(t, person.ID != "") + require.NoError(t, c.ShouldBindUri(&person)) + assert.NotEqual(t, "", person.Name) + assert.NotEqual(t, "", person.ID) c.String(http.StatusOK, "ShouldBindUri test OK") }) @@ -317,9 +319,9 @@ func TestBindUri(t *testing.T) { } router.Handle(http.MethodGet, "/rest/:name/:id", func(c *Context) { var person Person - assert.NoError(t, c.BindUri(&person)) - assert.True(t, person.Name != "") - assert.True(t, person.ID != "") + require.NoError(t, c.BindUri(&person)) + assert.NotEqual(t, "", person.Name) + assert.NotEqual(t, "", person.ID) c.String(http.StatusOK, "BindUri test OK") }) @@ -338,7 +340,7 @@ func TestBindUriError(t *testing.T) { } router.Handle(http.MethodGet, "/new/rest/:num", func(c *Context) { var m Member - assert.Error(t, c.BindUri(&m)) + require.Error(t, c.BindUri(&m)) }) path1, _ := exampleFromPath("/new/rest/:num") @@ -410,7 +412,7 @@ func exampleFromPath(path string) (string, Params) { } if start >= 0 { if c == '/' { - value := fmt.Sprint(rand.Intn(100000)) + value := strconv.Itoa(rand.Intn(100000)) params = append(params, Param{ Key: path[start:i], Value: value, @@ -424,7 +426,7 @@ func exampleFromPath(path string) (string, Params) { } } if start >= 0 { - value := fmt.Sprint(rand.Intn(100000)) + value := strconv.Itoa(rand.Intn(100000)) params = append(params, Param{ Key: path[start:], Value: value, diff --git a/go.mod b/go.mod index 0b60c5d7..3a7e1ba6 100644 --- a/go.mod +++ b/go.mod @@ -1,37 +1,46 @@ module github.com/gin-gonic/gin -go 1.20 +go 1.23.0 require ( - github.com/bytedance/sonic v1.10.2 - github.com/gin-contrib/sse v0.1.0 - github.com/go-playground/validator/v10 v10.17.0 + github.com/bytedance/sonic v1.13.1 + github.com/gin-contrib/sse v1.1.0 + github.com/go-playground/validator/v10 v10.26.0 github.com/goccy/go-json v0.10.2 github.com/json-iterator/go v1.1.12 github.com/mattn/go-isatty v0.0.20 - github.com/pelletier/go-toml/v2 v2.1.1 - github.com/stretchr/testify v1.8.4 + github.com/pelletier/go-toml/v2 v2.2.2 + github.com/quic-go/quic-go v0.51.0 + github.com/stretchr/testify v1.10.0 github.com/ugorji/go/codec v1.2.12 - golang.org/x/net v0.20.0 - google.golang.org/protobuf v1.32.0 + golang.org/x/net v0.38.0 + google.golang.org/protobuf v1.34.1 gopkg.in/yaml.v3 v3.0.1 ) require ( - github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d // indirect - github.com/chenzhuoyu/iasm v0.9.1 // indirect + github.com/bytedance/sonic/loader v0.2.4 // indirect + github.com/cloudwego/base64x v0.1.5 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/gabriel-vasile/mimetype v1.4.3 // indirect + github.com/gabriel-vasile/mimetype v1.4.8 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect - github.com/klauspost/cpuid/v2 v2.2.6 // indirect - github.com/leodido/go-urn v1.3.0 // indirect - github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect + github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect + github.com/klauspost/cpuid/v2 v2.0.9 // indirect + github.com/leodido/go-urn v1.4.0 // indirect + github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/onsi/ginkgo/v2 v2.9.5 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/quic-go/qpack v0.5.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect - golang.org/x/arch v0.7.0 // indirect - golang.org/x/crypto v0.18.0 // indirect - golang.org/x/sys v0.16.0 // indirect - golang.org/x/text v0.14.0 // indirect + go.uber.org/mock v0.5.0 // indirect + golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect + golang.org/x/crypto v0.36.0 // indirect + golang.org/x/mod v0.18.0 // indirect + golang.org/x/sync v0.12.0 // indirect + golang.org/x/sys v0.31.0 // indirect + golang.org/x/text v0.23.0 // indirect + golang.org/x/tools v0.22.0 // indirect ) diff --git a/go.sum b/go.sum index e360d9d2..5a7f2adf 100644 --- a/go.sum +++ b/go.sum @@ -1,85 +1,111 @@ -github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= -github.com/bytedance/sonic v1.10.0-rc/go.mod h1:ElCzW+ufi8qKqNW0FY314xriJhyJhuoJ3gFZdAHF7NM= -github.com/bytedance/sonic v1.10.2 h1:GQebETVBxYB7JGWJtLBi07OVzWwt+8dWA00gEVW2ZFE= -github.com/bytedance/sonic v1.10.2/go.mod h1:iZcSUejdk5aukTND/Eu/ivjQuEL0Cu9/rf50Hi0u/g4= -github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= -github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= -github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d h1:77cEq6EriyTZ0g/qfRdp61a3Uu/AWrgIq2s0ClJV1g0= -github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d/go.mod h1:8EPpVsBuRksnlj1mLy4AWzRNQYxauNi62uWcE3to6eA= -github.com/chenzhuoyu/iasm v0.9.0/go.mod h1:Xjy2NpN3h7aUqeqM+woSuuvxmIe6+DDsiNLIrkAmYog= -github.com/chenzhuoyu/iasm v0.9.1 h1:tUHQJXo3NhBqw6s33wkGn9SP3bvrWLdlVIJ3hQBL7P0= -github.com/chenzhuoyu/iasm v0.9.1/go.mod h1:Xjy2NpN3h7aUqeqM+woSuuvxmIe6+DDsiNLIrkAmYog= +github.com/bytedance/sonic v1.13.1 h1:Jyd5CIvdFnkOWuKXr+wm4Nyk2h0yAFsr8ucJgEasO3g= +github.com/bytedance/sonic v1.13.1/go.mod h1:o68xyaF9u2gvVBuGHPlUVCy+ZfmNNO5ETf1+KgkJhz4= +github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= +github.com/bytedance/sonic/loader v0.2.4 h1:ZWCw4stuXUsn1/+zQDqeE7JKP+QO47tz7QCNan80NzY= +github.com/bytedance/sonic/loader v0.2.4/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= +github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= +github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= +github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4= +github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= +github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= -github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= -github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= -github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM= +github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8= +github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w= +github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM= +github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= +github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= -github.com/go-playground/validator/v10 v10.17.0 h1:SmVVlfAOtlZncTxRuinDPomC2DkXJ4E5T9gDA0AIH74= -github.com/go-playground/validator/v10 v10.17.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= +github.com/go-playground/validator/v10 v10.26.0 h1:SP05Nqhjcvz81uJaRfEV0YBSSSGMc/iMaVtFbr3Sw2k= +github.com/go-playground/validator/v10 v10.26.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= -github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= +github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= +github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= +github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= -github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc= -github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= -github.com/leodido/go-urn v1.3.0 h1:jX8FDLfW4ThVXctBNZ+3cIWnCSnrACDV73r76dy0aQQ= -github.com/leodido/go-urn v1.3.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= -github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= -github.com/pelletier/go-toml/v2 v2.1.1 h1:LWAJwfNvjQZCFIDKWYQaM62NcYeYViCmWIwmOStowAI= -github.com/pelletier/go-toml/v2 v2.1.1/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= +github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q= +github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k= +github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= +github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg= +github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= +github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= +github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= +github.com/quic-go/quic-go v0.51.0 h1:K8exxe9zXxeRKxaXxi/GpUqYiTrtdiWP8bo1KFya6Wc= +github.com/quic-go/quic-go v0.51.0/go.mod h1:MFlGGpcpJqRAfmYi6NC2cptDPSxRWTOGNuP4wqrWmzQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= +go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= -golang.org/x/arch v0.7.0 h1:pskyeJh/3AmoQ8CPE95vxHLqp1G1GfGNXTmcl9NEKTc= -golang.org/x/arch v0.7.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= -golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc= -golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo= -golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= +golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= +golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0= +golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= +golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= +golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= -golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= -google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I= -google.golang.org/protobuf v1.32.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= +golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= +golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA= +golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c= +google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= +google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= -rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/internal/bytesconv/bytesconv_1.20.go b/internal/bytesconv/bytesconv.go similarity index 97% rename from internal/bytesconv/bytesconv_1.20.go rename to internal/bytesconv/bytesconv.go index 5b6040a6..a02c53c3 100644 --- a/internal/bytesconv/bytesconv_1.20.go +++ b/internal/bytesconv/bytesconv.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a MIT style // license that can be found in the LICENSE file. -//go:build go1.20 - package bytesconv import ( diff --git a/internal/bytesconv/bytesconv_1.19.go b/internal/bytesconv/bytesconv_1.19.go deleted file mode 100644 index 669c9c91..00000000 --- a/internal/bytesconv/bytesconv_1.19.go +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2020 Gin Core Team. All rights reserved. -// Use of this source code is governed by a MIT style -// license that can be found in the LICENSE file. - -//go:build !go1.20 - -package bytesconv - -import ( - "unsafe" -) - -// StringToBytes converts string to byte slice without a memory allocation. -func StringToBytes(s string) []byte { - return *(*[]byte)(unsafe.Pointer( - &struct { - string - Cap int - }{s, len(s)}, - )) -} - -// BytesToString converts byte slice to string without a memory allocation. -func BytesToString(b []byte) string { - return *(*string)(unsafe.Pointer(&b)) -} diff --git a/internal/fs/fs.go b/internal/fs/fs.go new file mode 100644 index 00000000..524ac08b --- /dev/null +++ b/internal/fs/fs.go @@ -0,0 +1,22 @@ +package fs + +import ( + "io/fs" + "net/http" +) + +// FileSystem implements an [fs.FS]. +type FileSystem struct { + http.FileSystem +} + +// Open passes `Open` to the upstream implementation and return an [fs.File]. +func (o FileSystem) Open(name string) (fs.File, error) { + f, err := o.FileSystem.Open(name) + + if err != nil { + return nil, err + } + + return fs.File(f), nil +} diff --git a/internal/fs/fs_test.go b/internal/fs/fs_test.go new file mode 100644 index 00000000..113e92b6 --- /dev/null +++ b/internal/fs/fs_test.go @@ -0,0 +1,49 @@ +package fs + +import ( + "errors" + "net/http" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type mockFileSystem struct { + open func(name string) (http.File, error) +} + +func (m *mockFileSystem) Open(name string) (http.File, error) { + return m.open(name) +} + +func TesFileSystem_Open(t *testing.T) { + var testFile *os.File + mockFS := &mockFileSystem{ + open: func(name string) (http.File, error) { + return testFile, nil + }, + } + fs := &FileSystem{mockFS} + + file, err := fs.Open("foo") + + require.NoError(t, err) + assert.Equal(t, testFile, file) +} + +func TestFileSystem_Open_err(t *testing.T) { + testError := errors.New("mock") + mockFS := &mockFileSystem{ + open: func(_ string) (http.File, error) { + return nil, testError + }, + } + fs := &FileSystem{mockFS} + + file, err := fs.Open("foo") + + require.ErrorIs(t, err, testError) + assert.Nil(t, file) +} diff --git a/logger_test.go b/logger_test.go index 6c1814dc..de00c499 100644 --- a/logger_test.go +++ b/logger_test.go @@ -31,9 +31,9 @@ func TestLogger(t *testing.T) { router.HEAD("/example", func(c *Context) {}) router.OPTIONS("/example", func(c *Context) {}) - PerformRequest(router, "GET", "/example?a=100") + PerformRequest(router, http.MethodGet, "/example?a=100") assert.Contains(t, buffer.String(), "200") - assert.Contains(t, buffer.String(), "GET") + assert.Contains(t, buffer.String(), http.MethodGet) assert.Contains(t, buffer.String(), "/example") assert.Contains(t, buffer.String(), "a=100") @@ -41,21 +41,21 @@ func TestLogger(t *testing.T) { // like integration tests because they test the whole logging process rather // than individual functions. Im not sure where these should go. buffer.Reset() - PerformRequest(router, "POST", "/example") + PerformRequest(router, http.MethodPost, "/example") assert.Contains(t, buffer.String(), "200") - assert.Contains(t, buffer.String(), "POST") + assert.Contains(t, buffer.String(), http.MethodPost) assert.Contains(t, buffer.String(), "/example") buffer.Reset() - PerformRequest(router, "PUT", "/example") + PerformRequest(router, http.MethodPut, "/example") assert.Contains(t, buffer.String(), "200") - assert.Contains(t, buffer.String(), "PUT") + assert.Contains(t, buffer.String(), http.MethodPut) assert.Contains(t, buffer.String(), "/example") buffer.Reset() - PerformRequest(router, "DELETE", "/example") + PerformRequest(router, http.MethodDelete, "/example") assert.Contains(t, buffer.String(), "200") - assert.Contains(t, buffer.String(), "DELETE") + assert.Contains(t, buffer.String(), http.MethodDelete) assert.Contains(t, buffer.String(), "/example") buffer.Reset() @@ -77,9 +77,9 @@ func TestLogger(t *testing.T) { assert.Contains(t, buffer.String(), "/example") buffer.Reset() - PerformRequest(router, "GET", "/notfound") + PerformRequest(router, http.MethodGet, "/notfound") assert.Contains(t, buffer.String(), "404") - assert.Contains(t, buffer.String(), "GET") + assert.Contains(t, buffer.String(), http.MethodGet) assert.Contains(t, buffer.String(), "/notfound") } @@ -95,9 +95,9 @@ func TestLoggerWithConfig(t *testing.T) { router.HEAD("/example", func(c *Context) {}) router.OPTIONS("/example", func(c *Context) {}) - PerformRequest(router, "GET", "/example?a=100") + PerformRequest(router, http.MethodGet, "/example?a=100") assert.Contains(t, buffer.String(), "200") - assert.Contains(t, buffer.String(), "GET") + assert.Contains(t, buffer.String(), http.MethodGet) assert.Contains(t, buffer.String(), "/example") assert.Contains(t, buffer.String(), "a=100") @@ -105,21 +105,21 @@ func TestLoggerWithConfig(t *testing.T) { // like integration tests because they test the whole logging process rather // than individual functions. Im not sure where these should go. buffer.Reset() - PerformRequest(router, "POST", "/example") + PerformRequest(router, http.MethodPost, "/example") assert.Contains(t, buffer.String(), "200") - assert.Contains(t, buffer.String(), "POST") + assert.Contains(t, buffer.String(), http.MethodPost) assert.Contains(t, buffer.String(), "/example") buffer.Reset() - PerformRequest(router, "PUT", "/example") + PerformRequest(router, http.MethodPut, "/example") assert.Contains(t, buffer.String(), "200") - assert.Contains(t, buffer.String(), "PUT") + assert.Contains(t, buffer.String(), http.MethodPut) assert.Contains(t, buffer.String(), "/example") buffer.Reset() - PerformRequest(router, "DELETE", "/example") + PerformRequest(router, http.MethodDelete, "/example") assert.Contains(t, buffer.String(), "200") - assert.Contains(t, buffer.String(), "DELETE") + assert.Contains(t, buffer.String(), http.MethodDelete) assert.Contains(t, buffer.String(), "/example") buffer.Reset() @@ -141,9 +141,9 @@ func TestLoggerWithConfig(t *testing.T) { assert.Contains(t, buffer.String(), "/example") buffer.Reset() - PerformRequest(router, "GET", "/notfound") + PerformRequest(router, http.MethodGet, "/notfound") assert.Contains(t, buffer.String(), "404") - assert.Contains(t, buffer.String(), "GET") + assert.Contains(t, buffer.String(), http.MethodGet) assert.Contains(t, buffer.String(), "/notfound") } @@ -169,12 +169,12 @@ func TestLoggerWithFormatter(t *testing.T) { ) })) router.GET("/example", func(c *Context) {}) - PerformRequest(router, "GET", "/example?a=100") + PerformRequest(router, http.MethodGet, "/example?a=100") // output test assert.Contains(t, buffer.String(), "[FORMATTER TEST]") assert.Contains(t, buffer.String(), "200") - assert.Contains(t, buffer.String(), "GET") + assert.Contains(t, buffer.String(), http.MethodGet) assert.Contains(t, buffer.String(), "/example") assert.Contains(t, buffer.String(), "a=100") } @@ -210,12 +210,12 @@ func TestLoggerWithConfigFormatting(t *testing.T) { gotKeys = c.Keys time.Sleep(time.Millisecond) }) - PerformRequest(router, "GET", "/example?a=100") + PerformRequest(router, http.MethodGet, "/example?a=100") // output test assert.Contains(t, buffer.String(), "[FORMATTER TEST]") assert.Contains(t, buffer.String(), "200") - assert.Contains(t, buffer.String(), "GET") + assert.Contains(t, buffer.String(), http.MethodGet) assert.Contains(t, buffer.String(), "/example") assert.Contains(t, buffer.String(), "a=100") @@ -225,7 +225,7 @@ func TestLoggerWithConfigFormatting(t *testing.T) { assert.Equal(t, 200, gotParam.StatusCode) assert.NotEmpty(t, gotParam.Latency) assert.Equal(t, "20.20.20.20", gotParam.ClientIP) - assert.Equal(t, "GET", gotParam.Method) + assert.Equal(t, http.MethodGet, gotParam.Method) assert.Equal(t, "/example?a=100", gotParam.Path) assert.Empty(t, gotParam.ErrorMessage) assert.Equal(t, gotKeys, gotParam.Keys) @@ -239,7 +239,7 @@ func TestDefaultLogFormatter(t *testing.T) { StatusCode: 200, Latency: time.Second * 5, ClientIP: "20.20.20.20", - Method: "GET", + Method: http.MethodGet, Path: "/", ErrorMessage: "", isTerm: false, @@ -250,7 +250,7 @@ func TestDefaultLogFormatter(t *testing.T) { StatusCode: 200, Latency: time.Second * 5, ClientIP: "20.20.20.20", - Method: "GET", + Method: http.MethodGet, Path: "/", ErrorMessage: "", isTerm: true, @@ -260,7 +260,7 @@ func TestDefaultLogFormatter(t *testing.T) { StatusCode: 200, Latency: time.Millisecond * 9876543210, ClientIP: "20.20.20.20", - Method: "GET", + Method: http.MethodGet, Path: "/", ErrorMessage: "", isTerm: true, @@ -271,7 +271,7 @@ func TestDefaultLogFormatter(t *testing.T) { StatusCode: 200, Latency: time.Millisecond * 9876543210, ClientIP: "20.20.20.20", - Method: "GET", + Method: http.MethodGet, Path: "/", ErrorMessage: "", isTerm: false, @@ -292,10 +292,10 @@ func TestColorForMethod(t *testing.T) { return p.MethodColor() } - assert.Equal(t, blue, colorForMethod("GET"), "get should be blue") - assert.Equal(t, cyan, colorForMethod("POST"), "post should be cyan") - assert.Equal(t, yellow, colorForMethod("PUT"), "put should be yellow") - assert.Equal(t, red, colorForMethod("DELETE"), "delete should be red") + assert.Equal(t, blue, colorForMethod(http.MethodGet), "get should be blue") + assert.Equal(t, cyan, colorForMethod(http.MethodPost), "post should be cyan") + assert.Equal(t, yellow, colorForMethod(http.MethodPut), "put should be yellow") + assert.Equal(t, red, colorForMethod(http.MethodDelete), "delete should be red") assert.Equal(t, green, colorForMethod("PATCH"), "patch should be green") assert.Equal(t, magenta, colorForMethod("HEAD"), "head should be magenta") assert.Equal(t, white, colorForMethod("OPTIONS"), "options should be white") @@ -329,13 +329,13 @@ func TestIsOutputColor(t *testing.T) { } consoleColorMode = autoColor - assert.Equal(t, true, p.IsOutputColor()) + assert.True(t, p.IsOutputColor()) ForceConsoleColor() - assert.Equal(t, true, p.IsOutputColor()) + assert.True(t, p.IsOutputColor()) DisableConsoleColor() - assert.Equal(t, false, p.IsOutputColor()) + assert.False(t, p.IsOutputColor()) // test with isTerm flag false. p = LogFormatterParams{ @@ -343,13 +343,13 @@ func TestIsOutputColor(t *testing.T) { } consoleColorMode = autoColor - assert.Equal(t, false, p.IsOutputColor()) + assert.False(t, p.IsOutputColor()) ForceConsoleColor() - assert.Equal(t, true, p.IsOutputColor()) + assert.True(t, p.IsOutputColor()) DisableConsoleColor() - assert.Equal(t, false, p.IsOutputColor()) + assert.False(t, p.IsOutputColor()) // reset console color mode. consoleColorMode = autoColor @@ -369,15 +369,15 @@ func TestErrorLogger(t *testing.T) { c.String(http.StatusInternalServerError, "hola!") }) - w := PerformRequest(router, "GET", "/error") + w := PerformRequest(router, http.MethodGet, "/error") assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, "{\"error\":\"this is an error\"}", w.Body.String()) - w = PerformRequest(router, "GET", "/abort") + w = PerformRequest(router, http.MethodGet, "/abort") assert.Equal(t, http.StatusUnauthorized, w.Code) assert.Equal(t, "{\"error\":\"no authorized\"}", w.Body.String()) - w = PerformRequest(router, "GET", "/print") + w = PerformRequest(router, http.MethodGet, "/print") assert.Equal(t, http.StatusInternalServerError, w.Code) assert.Equal(t, "hola!{\"error\":\"this is an error\"}", w.Body.String()) } @@ -389,11 +389,11 @@ func TestLoggerWithWriterSkippingPaths(t *testing.T) { router.GET("/logged", func(c *Context) {}) router.GET("/skipped", func(c *Context) {}) - PerformRequest(router, "GET", "/logged") + PerformRequest(router, http.MethodGet, "/logged") assert.Contains(t, buffer.String(), "200") buffer.Reset() - PerformRequest(router, "GET", "/skipped") + PerformRequest(router, http.MethodGet, "/skipped") assert.Contains(t, buffer.String(), "") } @@ -407,11 +407,11 @@ func TestLoggerWithConfigSkippingPaths(t *testing.T) { router.GET("/logged", func(c *Context) {}) router.GET("/skipped", func(c *Context) {}) - PerformRequest(router, "GET", "/logged") + PerformRequest(router, http.MethodGet, "/logged") assert.Contains(t, buffer.String(), "200") buffer.Reset() - PerformRequest(router, "GET", "/skipped") + PerformRequest(router, http.MethodGet, "/skipped") assert.Contains(t, buffer.String(), "") } @@ -427,11 +427,11 @@ func TestLoggerWithConfigSkipper(t *testing.T) { router.GET("/logged", func(c *Context) { c.Status(http.StatusOK) }) router.GET("/skipped", func(c *Context) { c.Status(http.StatusNoContent) }) - PerformRequest(router, "GET", "/logged") + PerformRequest(router, http.MethodGet, "/logged") assert.Contains(t, buffer.String(), "200") buffer.Reset() - PerformRequest(router, "GET", "/skipped") + PerformRequest(router, http.MethodGet, "/skipped") assert.Contains(t, buffer.String(), "") } diff --git a/middleware_test.go b/middleware_test.go index acdf89c4..eafc60ad 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -35,7 +35,7 @@ func TestMiddlewareGeneralCase(t *testing.T) { signature += " XX " }) // RUN - w := PerformRequest(router, "GET", "/") + w := PerformRequest(router, http.MethodGet, "/") // TEST assert.Equal(t, http.StatusOK, w.Code) @@ -71,7 +71,7 @@ func TestMiddlewareNoRoute(t *testing.T) { signature += " X " }) // RUN - w := PerformRequest(router, "GET", "/") + w := PerformRequest(router, http.MethodGet, "/") // TEST assert.Equal(t, http.StatusNotFound, w.Code) @@ -108,7 +108,7 @@ func TestMiddlewareNoMethodEnabled(t *testing.T) { signature += " XX " }) // RUN - w := PerformRequest(router, "GET", "/") + w := PerformRequest(router, http.MethodGet, "/") // TEST assert.Equal(t, http.StatusMethodNotAllowed, w.Code) @@ -149,7 +149,7 @@ func TestMiddlewareNoMethodDisabled(t *testing.T) { }) // RUN - w := PerformRequest(router, "GET", "/") + w := PerformRequest(router, http.MethodGet, "/") // TEST assert.Equal(t, http.StatusNotFound, w.Code) @@ -175,7 +175,7 @@ func TestMiddlewareAbort(t *testing.T) { }) // RUN - w := PerformRequest(router, "GET", "/") + w := PerformRequest(router, http.MethodGet, "/") // TEST assert.Equal(t, http.StatusUnauthorized, w.Code) @@ -196,7 +196,7 @@ func TestMiddlewareAbortHandlersChainAndNext(t *testing.T) { c.Next() }) // RUN - w := PerformRequest(router, "GET", "/") + w := PerformRequest(router, http.MethodGet, "/") // TEST assert.Equal(t, http.StatusGone, w.Code) @@ -219,7 +219,7 @@ func TestMiddlewareFailHandlersChain(t *testing.T) { signature += "C" }) // RUN - w := PerformRequest(router, "GET", "/") + w := PerformRequest(router, http.MethodGet, "/") // TEST assert.Equal(t, http.StatusInternalServerError, w.Code) @@ -246,7 +246,7 @@ func TestMiddlewareWrite(t *testing.T) { }) }) - w := PerformRequest(router, "GET", "/") + w := PerformRequest(router, http.MethodGet, "/") assert.Equal(t, http.StatusBadRequest, w.Code) assert.Equal(t, strings.Replace("hola\nbar{\"foo\":\"bar\"}{\"foo\":\"bar\"}event:test\ndata:message\n\n", " ", "", -1), strings.Replace(w.Body.String(), " ", "", -1)) diff --git a/mode.go b/mode.go index fd26d907..13aa3be0 100644 --- a/mode.go +++ b/mode.go @@ -8,6 +8,7 @@ import ( "flag" "io" "os" + "sync/atomic" "github.com/gin-gonic/gin/binding" ) @@ -43,10 +44,8 @@ var DefaultWriter io.Writer = os.Stdout // DefaultErrorWriter is the default io.Writer used by Gin to debug errors var DefaultErrorWriter io.Writer = os.Stderr -var ( - ginMode = debugCode - modeName = DebugMode -) +var ginMode int32 = debugCode +var modeName atomic.Value func init() { mode := os.Getenv(EnvGinMode) @@ -64,17 +63,16 @@ func SetMode(value string) { } switch value { - case DebugMode: - ginMode = debugCode + case DebugMode, "": + atomic.StoreInt32(&ginMode, debugCode) case ReleaseMode: - ginMode = releaseCode + atomic.StoreInt32(&ginMode, releaseCode) case TestMode: - ginMode = testCode + atomic.StoreInt32(&ginMode, testCode) default: panic("gin mode unknown: " + value + " (available mode: debug release test)") } - - modeName = value + modeName.Store(value) } // DisableBindValidation closes the default validator. @@ -96,5 +94,5 @@ func EnableJsonDecoderDisallowUnknownFields() { // Mode returns current gin mode. func Mode() string { - return modeName + return modeName.Load().(string) } diff --git a/mode_test.go b/mode_test.go index 2407f463..be03a9d0 100644 --- a/mode_test.go +++ b/mode_test.go @@ -5,8 +5,8 @@ package gin import ( - "flag" "os" + "sync/atomic" "testing" "github.com/gin-gonic/gin/binding" @@ -18,31 +18,24 @@ func init() { } func TestSetMode(t *testing.T) { - assert.Equal(t, testCode, ginMode) + assert.Equal(t, int32(testCode), atomic.LoadInt32(&ginMode)) assert.Equal(t, TestMode, Mode()) os.Unsetenv(EnvGinMode) SetMode("") - assert.Equal(t, testCode, ginMode) + assert.Equal(t, int32(testCode), atomic.LoadInt32(&ginMode)) assert.Equal(t, TestMode, Mode()) - tmp := flag.CommandLine - flag.CommandLine = flag.NewFlagSet("", flag.ContinueOnError) - SetMode("") - assert.Equal(t, debugCode, ginMode) - assert.Equal(t, DebugMode, Mode()) - flag.CommandLine = tmp - SetMode(DebugMode) - assert.Equal(t, debugCode, ginMode) + assert.Equal(t, int32(debugCode), atomic.LoadInt32(&ginMode)) assert.Equal(t, DebugMode, Mode()) SetMode(ReleaseMode) - assert.Equal(t, releaseCode, ginMode) + assert.Equal(t, int32(releaseCode), atomic.LoadInt32(&ginMode)) assert.Equal(t, ReleaseMode, Mode()) SetMode(TestMode) - assert.Equal(t, testCode, ginMode) + assert.Equal(t, int32(testCode), atomic.LoadInt32(&ginMode)) assert.Equal(t, TestMode, Mode()) assert.Panics(t, func() { SetMode("unknown") }) diff --git a/path_test.go b/path_test.go index caefd63a..2269b78e 100644 --- a/path_test.go +++ b/path_test.go @@ -6,6 +6,7 @@ package gin import ( + "runtime" "strings" "testing" @@ -80,9 +81,13 @@ func TestPathCleanMallocs(t *testing.T) { t.Skip("skipping malloc count in short mode") } + if runtime.GOMAXPROCS(0) > 1 { + t.Skip("skipping malloc count; GOMAXPROCS>1") + } + for _, test := range cleanTests { allocs := testing.AllocsPerRun(100, func() { cleanPath(test.result) }) - assert.EqualValues(t, allocs, 0) + assert.InDelta(t, 0, allocs, 0.01) } } diff --git a/recovery_test.go b/recovery_test.go index fa8ab894..08eec1e4 100644 --- a/recovery_test.go +++ b/recovery_test.go @@ -5,7 +5,6 @@ package gin import ( - "fmt" "net" "net/http" "os" @@ -26,14 +25,14 @@ func TestPanicClean(t *testing.T) { panic("Oupps, Houston, we have a problem") }) // RUN - w := PerformRequest(router, "GET", "/recovery", + w := PerformRequest(router, http.MethodGet, "/recovery", header{ Key: "Host", Value: "www.google.com", }, header{ Key: "Authorization", - Value: fmt.Sprintf("Bearer %s", password), + Value: "Bearer " + password, }, header{ Key: "Content-Type", @@ -56,7 +55,7 @@ func TestPanicInHandler(t *testing.T) { panic("Oupps, Houston, we have a problem") }) // RUN - w := PerformRequest(router, "GET", "/recovery") + w := PerformRequest(router, http.MethodGet, "/recovery") // TEST assert.Equal(t, http.StatusInternalServerError, w.Code) assert.Contains(t, buffer.String(), "panic recovered") @@ -67,7 +66,7 @@ func TestPanicInHandler(t *testing.T) { // Debug mode prints the request SetMode(DebugMode) // RUN - w = PerformRequest(router, "GET", "/recovery") + w = PerformRequest(router, http.MethodGet, "/recovery") // TEST assert.Equal(t, http.StatusInternalServerError, w.Code) assert.Contains(t, buffer.String(), "GET /recovery") @@ -84,7 +83,7 @@ func TestPanicWithAbort(t *testing.T) { panic("Oupps, Houston, we have a problem") }) // RUN - w := PerformRequest(router, "GET", "/recovery") + w := PerformRequest(router, http.MethodGet, "/recovery") // TEST assert.Equal(t, http.StatusBadRequest, w.Code) } @@ -135,7 +134,7 @@ func TestPanicWithBrokenPipe(t *testing.T) { panic(e) }) // RUN - w := PerformRequest(router, "GET", "/recovery") + w := PerformRequest(router, http.MethodGet, "/recovery") // TEST assert.Equal(t, expectCode, w.Code) assert.Contains(t, strings.ToLower(buf.String()), expectMsg) @@ -156,7 +155,7 @@ func TestCustomRecoveryWithWriter(t *testing.T) { panic("Oupps, Houston, we have a problem") }) // RUN - w := PerformRequest(router, "GET", "/recovery") + w := PerformRequest(router, http.MethodGet, "/recovery") // TEST assert.Equal(t, http.StatusBadRequest, w.Code) assert.Contains(t, buffer.String(), "panic recovered") @@ -167,7 +166,7 @@ func TestCustomRecoveryWithWriter(t *testing.T) { // Debug mode prints the request SetMode(DebugMode) // RUN - w = PerformRequest(router, "GET", "/recovery") + w = PerformRequest(router, http.MethodGet, "/recovery") // TEST assert.Equal(t, http.StatusBadRequest, w.Code) assert.Contains(t, buffer.String(), "GET /recovery") @@ -191,7 +190,7 @@ func TestCustomRecovery(t *testing.T) { panic("Oupps, Houston, we have a problem") }) // RUN - w := PerformRequest(router, "GET", "/recovery") + w := PerformRequest(router, http.MethodGet, "/recovery") // TEST assert.Equal(t, http.StatusBadRequest, w.Code) assert.Contains(t, buffer.String(), "panic recovered") @@ -202,7 +201,7 @@ func TestCustomRecovery(t *testing.T) { // Debug mode prints the request SetMode(DebugMode) // RUN - w = PerformRequest(router, "GET", "/recovery") + w = PerformRequest(router, http.MethodGet, "/recovery") // TEST assert.Equal(t, http.StatusBadRequest, w.Code) assert.Contains(t, buffer.String(), "GET /recovery") @@ -226,7 +225,7 @@ func TestRecoveryWithWriterWithCustomRecovery(t *testing.T) { panic("Oupps, Houston, we have a problem") }) // RUN - w := PerformRequest(router, "GET", "/recovery") + w := PerformRequest(router, http.MethodGet, "/recovery") // TEST assert.Equal(t, http.StatusBadRequest, w.Code) assert.Contains(t, buffer.String(), "panic recovered") @@ -237,7 +236,7 @@ func TestRecoveryWithWriterWithCustomRecovery(t *testing.T) { // Debug mode prints the request SetMode(DebugMode) // RUN - w = PerformRequest(router, "GET", "/recovery") + w = PerformRequest(router, http.MethodGet, "/recovery") // TEST assert.Equal(t, http.StatusBadRequest, w.Code) assert.Contains(t, buffer.String(), "GET /recovery") diff --git a/render/html.go b/render/html.go index c308408d..f5e7455a 100644 --- a/render/html.go +++ b/render/html.go @@ -7,6 +7,8 @@ package render import ( "html/template" "net/http" + + "github.com/gin-gonic/gin/internal/fs" ) // Delims represents a set of Left and Right delimiters for HTML template rendering. @@ -31,10 +33,12 @@ type HTMLProduction struct { // HTMLDebug contains template delims and pattern and function with file list. type HTMLDebug struct { - Files []string - Glob string - Delims Delims - FuncMap template.FuncMap + Files []string + Glob string + FileSystem http.FileSystem + Patterns []string + Delims Delims + FuncMap template.FuncMap } // HTML contains template reference and its name with given interface object. @@ -73,7 +77,11 @@ func (r HTMLDebug) loadTemplate() *template.Template { if r.Glob != "" { return template.Must(template.New("").Delims(r.Delims.Left, r.Delims.Right).Funcs(r.FuncMap).ParseGlob(r.Glob)) } - panic("the HTML debug render was created without files or glob pattern") + if r.FileSystem != nil && len(r.Patterns) > 0 { + return template.Must(template.New("").Delims(r.Delims.Left, r.Delims.Right).Funcs(r.FuncMap).ParseFS( + fs.FileSystem{FileSystem: r.FileSystem}, r.Patterns...)) + } + panic("the HTML debug render was created without files or glob pattern or file system with patterns") } // Render (HTML) executes template and writes its result with custom ContentType for response. diff --git a/render/json.go b/render/json.go index fc8dea45..23923c44 100644 --- a/render/json.go +++ b/render/json.go @@ -9,6 +9,7 @@ import ( "fmt" "html/template" "net/http" + "unicode" "github.com/gin-gonic/gin/internal/bytesconv" "github.com/gin-gonic/gin/internal/json" @@ -151,7 +152,7 @@ func (r JsonpJSON) WriteContentType(w http.ResponseWriter) { } // Render (AsciiJSON) marshals the given interface object and writes it with custom ContentType. -func (r AsciiJSON) Render(w http.ResponseWriter) (err error) { +func (r AsciiJSON) Render(w http.ResponseWriter) error { r.WriteContentType(w) ret, err := json.Marshal(r.Data) if err != nil { @@ -159,12 +160,15 @@ func (r AsciiJSON) Render(w http.ResponseWriter) (err error) { } var buffer bytes.Buffer + escapeBuf := make([]byte, 0, 6) // Preallocate 6 bytes for Unicode escape sequences + for _, r := range bytesconv.BytesToString(ret) { - cvt := string(r) - if r >= 128 { - cvt = fmt.Sprintf("\\u%04x", int64(r)) + if r > unicode.MaxASCII { + escapeBuf = fmt.Appendf(escapeBuf[:0], "\\u%04x", r) // Reuse escapeBuf + buffer.Write(escapeBuf) + } else { + buffer.WriteByte(byte(r)) } - buffer.WriteString(cvt) } _, err = w.Write(buffer.Bytes()) diff --git a/render/render.go b/render/render.go index 7955000c..4bdcfa23 100644 --- a/render/render.go +++ b/render/render.go @@ -15,22 +15,22 @@ type Render interface { } var ( - _ Render = JSON{} - _ Render = IndentedJSON{} - _ Render = SecureJSON{} - _ Render = JsonpJSON{} - _ Render = XML{} - _ Render = String{} - _ Render = Redirect{} - _ Render = Data{} - _ Render = HTML{} - _ HTMLRender = HTMLDebug{} - _ HTMLRender = HTMLProduction{} - _ Render = YAML{} - _ Render = Reader{} - _ Render = AsciiJSON{} - _ Render = ProtoBuf{} - _ Render = TOML{} + _ Render = (*JSON)(nil) + _ Render = (*IndentedJSON)(nil) + _ Render = (*SecureJSON)(nil) + _ Render = (*JsonpJSON)(nil) + _ Render = (*XML)(nil) + _ Render = (*String)(nil) + _ Render = (*Redirect)(nil) + _ Render = (*Data)(nil) + _ Render = (*HTML)(nil) + _ HTMLRender = (*HTMLDebug)(nil) + _ HTMLRender = (*HTMLProduction)(nil) + _ Render = (*YAML)(nil) + _ Render = (*Reader)(nil) + _ Render = (*AsciiJSON)(nil) + _ Render = (*ProtoBuf)(nil) + _ Render = (*TOML)(nil) ) func writeContentType(w http.ResponseWriter, value []string) { diff --git a/render/render_msgpack_test.go b/render/render_msgpack_test.go index db4b71e5..579897cc 100644 --- a/render/render_msgpack_test.go +++ b/render/render_msgpack_test.go @@ -12,6 +12,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/ugorji/go/codec" ) @@ -29,7 +30,7 @@ func TestRenderMsgPack(t *testing.T) { err := (MsgPack{data}).Render(w) - assert.NoError(t, err) + require.NoError(t, err) h := new(codec.MsgpackHandle) assert.NotNil(t, h) @@ -37,7 +38,7 @@ func TestRenderMsgPack(t *testing.T) { assert.NotNil(t, buf) err = codec.NewEncoder(buf, h).Encode(data) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, w.Body.String(), buf.String()) assert.Equal(t, "application/msgpack; charset=utf-8", w.Header().Get("Content-Type")) } diff --git a/render/render_test.go b/render/render_test.go index c9db635f..4dd2a3af 100644 --- a/render/render_test.go +++ b/render/render_test.go @@ -18,6 +18,7 @@ import ( "github.com/gin-gonic/gin/internal/json" testdata "github.com/gin-gonic/gin/testdata/protoexample" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" ) @@ -36,7 +37,7 @@ func TestRenderJSON(t *testing.T) { err := (JSON{data}).Render(w) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "{\"foo\":\"bar\",\"html\":\"\\u003cb\\u003e\"}", w.Body.String()) assert.Equal(t, "application/json; charset=utf-8", w.Header().Get("Content-Type")) } @@ -46,7 +47,7 @@ func TestRenderJSONError(t *testing.T) { data := make(chan int) // json: unsupported type: chan int - assert.Error(t, (JSON{data}).Render(w)) + require.Error(t, (JSON{data}).Render(w)) } func TestRenderIndentedJSON(t *testing.T) { @@ -58,7 +59,7 @@ func TestRenderIndentedJSON(t *testing.T) { err := (IndentedJSON{data}).Render(w) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "{\n \"bar\": \"foo\",\n \"foo\": \"bar\"\n}", w.Body.String()) assert.Equal(t, "application/json; charset=utf-8", w.Header().Get("Content-Type")) } @@ -69,7 +70,7 @@ func TestRenderIndentedJSONPanics(t *testing.T) { // json: unsupported type: chan int err := (IndentedJSON{data}).Render(w) - assert.Error(t, err) + require.Error(t, err) } func TestRenderSecureJSON(t *testing.T) { @@ -83,7 +84,7 @@ func TestRenderSecureJSON(t *testing.T) { err1 := (SecureJSON{"while(1);", data}).Render(w1) - assert.NoError(t, err1) + require.NoError(t, err1) assert.Equal(t, "{\"foo\":\"bar\"}", w1.Body.String()) assert.Equal(t, "application/json; charset=utf-8", w1.Header().Get("Content-Type")) @@ -95,7 +96,7 @@ func TestRenderSecureJSON(t *testing.T) { }} err2 := (SecureJSON{"while(1);", datas}).Render(w2) - assert.NoError(t, err2) + require.NoError(t, err2) assert.Equal(t, "while(1);[{\"foo\":\"bar\"},{\"bar\":\"foo\"}]", w2.Body.String()) assert.Equal(t, "application/json; charset=utf-8", w2.Header().Get("Content-Type")) } @@ -106,7 +107,7 @@ func TestRenderSecureJSONFail(t *testing.T) { // json: unsupported type: chan int err := (SecureJSON{"while(1);", data}).Render(w) - assert.Error(t, err) + require.Error(t, err) } func TestRenderJsonpJSON(t *testing.T) { @@ -120,7 +121,7 @@ func TestRenderJsonpJSON(t *testing.T) { err1 := (JsonpJSON{"x", data}).Render(w1) - assert.NoError(t, err1) + require.NoError(t, err1) assert.Equal(t, "x({\"foo\":\"bar\"});", w1.Body.String()) assert.Equal(t, "application/javascript; charset=utf-8", w1.Header().Get("Content-Type")) @@ -132,7 +133,7 @@ func TestRenderJsonpJSON(t *testing.T) { }} err2 := (JsonpJSON{"x", datas}).Render(w2) - assert.NoError(t, err2) + require.NoError(t, err2) assert.Equal(t, "x([{\"foo\":\"bar\"},{\"bar\":\"foo\"}]);", w2.Body.String()) assert.Equal(t, "application/javascript; charset=utf-8", w2.Header().Get("Content-Type")) } @@ -191,7 +192,7 @@ func TestRenderJsonpJSONError2(t *testing.T) { assert.Equal(t, "application/javascript; charset=utf-8", w.Header().Get("Content-Type")) e := (JsonpJSON{"", data}).Render(w) - assert.NoError(t, e) + require.NoError(t, e) assert.Equal(t, "{\"foo\":\"bar\"}", w.Body.String()) assert.Equal(t, "application/javascript; charset=utf-8", w.Header().Get("Content-Type")) @@ -203,7 +204,7 @@ func TestRenderJsonpJSONFail(t *testing.T) { // json: unsupported type: chan int err := (JsonpJSON{"x", data}).Render(w) - assert.Error(t, err) + require.Error(t, err) } func TestRenderAsciiJSON(t *testing.T) { @@ -215,7 +216,7 @@ func TestRenderAsciiJSON(t *testing.T) { err := (AsciiJSON{data1}).Render(w1) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "{\"lang\":\"GO\\u8bed\\u8a00\",\"tag\":\"\\u003cbr\\u003e\"}", w1.Body.String()) assert.Equal(t, "application/json", w1.Header().Get("Content-Type")) @@ -223,7 +224,7 @@ func TestRenderAsciiJSON(t *testing.T) { data2 := 3.1415926 err = (AsciiJSON{data2}).Render(w2) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "3.1415926", w2.Body.String()) } @@ -232,7 +233,7 @@ func TestRenderAsciiJSONFail(t *testing.T) { data := make(chan int) // json: unsupported type: chan int - assert.Error(t, (AsciiJSON{data}).Render(w)) + require.Error(t, (AsciiJSON{data}).Render(w)) } func TestRenderPureJSON(t *testing.T) { @@ -242,7 +243,7 @@ func TestRenderPureJSON(t *testing.T) { "html": "", } err := (PureJSON{data}).Render(w) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "{\"foo\":\"bar\",\"html\":\"\"}\n", w.Body.String()) assert.Equal(t, "application/json; charset=utf-8", w.Header().Get("Content-Type")) } @@ -280,12 +281,12 @@ b: d: [3, 4] ` (YAML{data}).WriteContentType(w) - assert.Equal(t, "application/x-yaml; charset=utf-8", w.Header().Get("Content-Type")) + assert.Equal(t, "application/yaml; charset=utf-8", w.Header().Get("Content-Type")) err := (YAML{data}).Render(w) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "|4-\n a : Easy!\n b:\n \tc: 2\n \td: [3, 4]\n \t\n", w.Body.String()) - assert.Equal(t, "application/x-yaml; charset=utf-8", w.Header().Get("Content-Type")) + assert.Equal(t, "application/yaml; charset=utf-8", w.Header().Get("Content-Type")) } type fail struct{} @@ -298,7 +299,7 @@ func (ft *fail) MarshalYAML() (any, error) { func TestRenderYAMLFail(t *testing.T) { w := httptest.NewRecorder() err := (YAML{&fail{}}).Render(w) - assert.Error(t, err) + require.Error(t, err) } func TestRenderTOML(t *testing.T) { @@ -311,7 +312,7 @@ func TestRenderTOML(t *testing.T) { assert.Equal(t, "application/toml; charset=utf-8", w.Header().Get("Content-Type")) err := (TOML{data}).Render(w) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "foo = 'bar'\nhtml = ''\n", w.Body.String()) assert.Equal(t, "application/toml; charset=utf-8", w.Header().Get("Content-Type")) } @@ -319,7 +320,7 @@ func TestRenderTOML(t *testing.T) { func TestRenderTOMLFail(t *testing.T) { w := httptest.NewRecorder() err := (TOML{net.IPv4bcast}).Render(w) - assert.Error(t, err) + require.Error(t, err) } // test Protobuf rendering @@ -334,12 +335,12 @@ func TestRenderProtoBuf(t *testing.T) { (ProtoBuf{data}).WriteContentType(w) protoData, err := proto.Marshal(data) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "application/x-protobuf", w.Header().Get("Content-Type")) err = (ProtoBuf{data}).Render(w) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, string(protoData), w.Body.String()) assert.Equal(t, "application/x-protobuf", w.Header().Get("Content-Type")) } @@ -348,7 +349,7 @@ func TestRenderProtoBufFail(t *testing.T) { w := httptest.NewRecorder() data := &testdata.Test{} err := (ProtoBuf{data}).Render(w) - assert.Error(t, err) + require.Error(t, err) } func TestRenderXML(t *testing.T) { @@ -362,14 +363,14 @@ func TestRenderXML(t *testing.T) { err := (XML{data}).Render(w) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "bar", w.Body.String()) assert.Equal(t, "application/xml; charset=utf-8", w.Header().Get("Content-Type")) } func TestRenderRedirect(t *testing.T) { - req, err := http.NewRequest("GET", "/test-redirect", nil) - assert.NoError(t, err) + req, err := http.NewRequest(http.MethodGet, "/test-redirect", nil) + require.NoError(t, err) data1 := Redirect{ Code: http.StatusMovedPermanently, @@ -379,7 +380,7 @@ func TestRenderRedirect(t *testing.T) { w := httptest.NewRecorder() err = data1.Render(w) - assert.NoError(t, err) + require.NoError(t, err) data2 := Redirect{ Code: http.StatusOK, @@ -390,7 +391,7 @@ func TestRenderRedirect(t *testing.T) { w = httptest.NewRecorder() assert.PanicsWithValue(t, "Cannot redirect with status code 200", func() { err := data2.Render(w) - assert.NoError(t, err) + require.NoError(t, err) }) data3 := Redirect{ @@ -401,7 +402,7 @@ func TestRenderRedirect(t *testing.T) { w = httptest.NewRecorder() err = data3.Render(w) - assert.NoError(t, err) + require.NoError(t, err) // only improve coverage data2.WriteContentType(w) @@ -416,7 +417,7 @@ func TestRenderData(t *testing.T) { Data: data, }).Render(w) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "#!PNG some raw data", w.Body.String()) assert.Equal(t, "image/png", w.Header().Get("Content-Type")) } @@ -435,7 +436,7 @@ func TestRenderString(t *testing.T) { Data: []any{"manu", 2}, }).Render(w) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "hola manu 2", w.Body.String()) assert.Equal(t, "text/plain; charset=utf-8", w.Header().Get("Content-Type")) } @@ -448,7 +449,7 @@ func TestRenderStringLenZero(t *testing.T) { Data: []any{}, }).Render(w) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "hola %s %d", w.Body.String()) assert.Equal(t, "text/plain; charset=utf-8", w.Header().Get("Content-Type")) } @@ -464,7 +465,7 @@ func TestRenderHTMLTemplate(t *testing.T) { err := instance.Render(w) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "Hello alexandernyquist", w.Body.String()) assert.Equal(t, "text/html; charset=utf-8", w.Header().Get("Content-Type")) } @@ -480,7 +481,7 @@ func TestRenderHTMLTemplateEmptyName(t *testing.T) { err := instance.Render(w) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "Hello alexandernyquist", w.Body.String()) assert.Equal(t, "text/html; charset=utf-8", w.Header().Get("Content-Type")) } @@ -488,10 +489,12 @@ func TestRenderHTMLTemplateEmptyName(t *testing.T) { func TestRenderHTMLDebugFiles(t *testing.T) { w := httptest.NewRecorder() htmlRender := HTMLDebug{ - Files: []string{"../testdata/template/hello.tmpl"}, - Glob: "", - Delims: Delims{Left: "{[{", Right: "}]}"}, - FuncMap: nil, + Files: []string{"../testdata/template/hello.tmpl"}, + Glob: "", + FileSystem: nil, + Patterns: nil, + Delims: Delims{Left: "{[{", Right: "}]}"}, + FuncMap: nil, } instance := htmlRender.Instance("hello.tmpl", map[string]any{ "name": "thinkerou", @@ -499,7 +502,7 @@ func TestRenderHTMLDebugFiles(t *testing.T) { err := instance.Render(w) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "

Hello thinkerou

", w.Body.String()) assert.Equal(t, "text/html; charset=utf-8", w.Header().Get("Content-Type")) } @@ -507,10 +510,12 @@ func TestRenderHTMLDebugFiles(t *testing.T) { func TestRenderHTMLDebugGlob(t *testing.T) { w := httptest.NewRecorder() htmlRender := HTMLDebug{ - Files: nil, - Glob: "../testdata/template/hello*", - Delims: Delims{Left: "{[{", Right: "}]}"}, - FuncMap: nil, + Files: nil, + Glob: "../testdata/template/hello*", + FileSystem: nil, + Patterns: nil, + Delims: Delims{Left: "{[{", Right: "}]}"}, + FuncMap: nil, } instance := htmlRender.Instance("hello.tmpl", map[string]any{ "name": "thinkerou", @@ -518,17 +523,40 @@ func TestRenderHTMLDebugGlob(t *testing.T) { err := instance.Render(w) - assert.NoError(t, err) + require.NoError(t, err) + assert.Equal(t, "

Hello thinkerou

", w.Body.String()) + assert.Equal(t, "text/html; charset=utf-8", w.Header().Get("Content-Type")) +} + +func TestRenderHTMLDebugFS(t *testing.T) { + w := httptest.NewRecorder() + htmlRender := HTMLDebug{ + Files: nil, + Glob: "", + FileSystem: http.Dir("../testdata/template"), + Patterns: []string{"hello.tmpl"}, + Delims: Delims{Left: "{[{", Right: "}]}"}, + FuncMap: nil, + } + instance := htmlRender.Instance("hello.tmpl", map[string]any{ + "name": "thinkerou", + }) + + err := instance.Render(w) + + require.NoError(t, err) assert.Equal(t, "

Hello thinkerou

", w.Body.String()) assert.Equal(t, "text/html; charset=utf-8", w.Header().Get("Content-Type")) } func TestRenderHTMLDebugPanics(t *testing.T) { htmlRender := HTMLDebug{ - Files: nil, - Glob: "", - Delims: Delims{"{{", "}}"}, - FuncMap: nil, + Files: nil, + Glob: "", + FileSystem: nil, + Patterns: nil, + Delims: Delims{"{{", "}}"}, + FuncMap: nil, } assert.Panics(t, func() { htmlRender.Instance("", nil) }) } @@ -548,7 +576,7 @@ func TestRenderReader(t *testing.T) { Headers: headers, }).Render(w) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, body, w.Body.String()) assert.Equal(t, "image/png", w.Header().Get("Content-Type")) assert.Equal(t, strconv.Itoa(len(body)), w.Header().Get("Content-Length")) @@ -571,7 +599,7 @@ func TestRenderReaderNoContentLength(t *testing.T) { Headers: headers, }).Render(w) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, body, w.Body.String()) assert.Equal(t, "image/png", w.Header().Get("Content-Type")) assert.NotContains(t, "Content-Length", w.Header()) @@ -588,6 +616,6 @@ func TestRenderWriteError(t *testing.T) { ResponseRecorder: httptest.NewRecorder(), } err := r.Render(ew) - assert.NotNil(t, err) + require.Error(t, err) assert.Equal(t, `write "my-prefix:" error`, err.Error()) } diff --git a/render/yaml.go b/render/yaml.go index fc927c1f..042bb821 100644 --- a/render/yaml.go +++ b/render/yaml.go @@ -15,7 +15,7 @@ type YAML struct { Data any } -var yamlContentType = []string{"application/x-yaml; charset=utf-8"} +var yamlContentType = []string{"application/yaml; charset=utf-8"} // Render (YAML) marshals the given interface object and writes data with custom ContentType. func (r YAML) Render(w http.ResponseWriter) error { diff --git a/response_writer_test.go b/response_writer_test.go index 964aa307..259b8fa8 100644 --- a/response_writer_test.go +++ b/response_writer_test.go @@ -10,6 +10,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // TODO @@ -95,13 +96,13 @@ func TestResponseWriterWrite(t *testing.T) { assert.Equal(t, http.StatusOK, w.Status()) assert.Equal(t, http.StatusOK, testWriter.Code) assert.Equal(t, "hola", testWriter.Body.String()) - assert.NoError(t, err) + require.NoError(t, err) n, err = w.Write([]byte(" adios")) assert.Equal(t, 6, n) assert.Equal(t, 10, w.Size()) assert.Equal(t, "hola adios", testWriter.Body.String()) - assert.NoError(t, err) + require.NoError(t, err) } func TestResponseWriterHijack(t *testing.T) { @@ -112,7 +113,7 @@ func TestResponseWriterHijack(t *testing.T) { assert.Panics(t, func() { _, _, err := w.Hijack() - assert.NoError(t, err) + require.NoError(t, err) }) assert.True(t, w.Written()) @@ -135,7 +136,7 @@ func TestResponseWriterFlush(t *testing.T) { // should return 500 resp, err := http.Get(testServer.URL) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) } diff --git a/routergroup.go b/routergroup.go index c833fe8f..b2540ec1 100644 --- a/routergroup.go +++ b/routergroup.go @@ -218,7 +218,7 @@ func (group *RouterGroup) createStaticHandler(relativePath string, fs http.FileS fileServer := http.StripPrefix(absolutePath, http.FileServer(fs)) return func(c *Context) { - if _, noListing := fs.(*onlyFilesFS); noListing { + if _, noListing := fs.(*OnlyFilesFS); noListing { c.Writer.WriteHeader(http.StatusNotFound) } diff --git a/routes_test.go b/routes_test.go index a0ff695f..995ff51c 100644 --- a/routes_test.go +++ b/routes_test.go @@ -13,6 +13,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type header struct { @@ -180,58 +181,58 @@ func TestRouteRedirectTrailingSlash(t *testing.T) { w = PerformRequest(router, http.MethodGet, "/path2", header{Key: "X-Forwarded-Prefix", Value: "/api"}) assert.Equal(t, "/api/path2/", w.Header().Get("Location")) - assert.Equal(t, 301, w.Code) + assert.Equal(t, http.StatusMovedPermanently, w.Code) w = PerformRequest(router, http.MethodGet, "/path2/", header{Key: "X-Forwarded-Prefix", Value: "/api/"}) - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) w = PerformRequest(router, http.MethodGet, "/path/", header{Key: "X-Forwarded-Prefix", Value: "../../api#?"}) assert.Equal(t, "/api/path", w.Header().Get("Location")) - assert.Equal(t, 301, w.Code) + assert.Equal(t, http.StatusMovedPermanently, w.Code) w = PerformRequest(router, http.MethodGet, "/path/", header{Key: "X-Forwarded-Prefix", Value: "../../api"}) assert.Equal(t, "/api/path", w.Header().Get("Location")) - assert.Equal(t, 301, w.Code) + assert.Equal(t, http.StatusMovedPermanently, w.Code) w = PerformRequest(router, http.MethodGet, "/path2", header{Key: "X-Forwarded-Prefix", Value: "../../api"}) assert.Equal(t, "/api/path2/", w.Header().Get("Location")) - assert.Equal(t, 301, w.Code) + assert.Equal(t, http.StatusMovedPermanently, w.Code) w = PerformRequest(router, http.MethodGet, "/path2", header{Key: "X-Forwarded-Prefix", Value: "/../../api"}) assert.Equal(t, "/api/path2/", w.Header().Get("Location")) - assert.Equal(t, 301, w.Code) + assert.Equal(t, http.StatusMovedPermanently, w.Code) w = PerformRequest(router, http.MethodGet, "/path/", header{Key: "X-Forwarded-Prefix", Value: "api/../../"}) assert.Equal(t, "//path", w.Header().Get("Location")) - assert.Equal(t, 301, w.Code) + assert.Equal(t, http.StatusMovedPermanently, w.Code) w = PerformRequest(router, http.MethodGet, "/path/", header{Key: "X-Forwarded-Prefix", Value: "api/../../../"}) assert.Equal(t, "/path", w.Header().Get("Location")) - assert.Equal(t, 301, w.Code) + assert.Equal(t, http.StatusMovedPermanently, w.Code) w = PerformRequest(router, http.MethodGet, "/path2", header{Key: "X-Forwarded-Prefix", Value: "../../gin-gonic.com"}) assert.Equal(t, "/gin-goniccom/path2/", w.Header().Get("Location")) - assert.Equal(t, 301, w.Code) + assert.Equal(t, http.StatusMovedPermanently, w.Code) w = PerformRequest(router, http.MethodGet, "/path2", header{Key: "X-Forwarded-Prefix", Value: "/../../gin-gonic.com"}) assert.Equal(t, "/gin-goniccom/path2/", w.Header().Get("Location")) - assert.Equal(t, 301, w.Code) + assert.Equal(t, http.StatusMovedPermanently, w.Code) w = PerformRequest(router, http.MethodGet, "/path/", header{Key: "X-Forwarded-Prefix", Value: "https://gin-gonic.com/#"}) assert.Equal(t, "https/gin-goniccom/https/gin-goniccom/path", w.Header().Get("Location")) - assert.Equal(t, 301, w.Code) + assert.Equal(t, http.StatusMovedPermanently, w.Code) w = PerformRequest(router, http.MethodGet, "/path/", header{Key: "X-Forwarded-Prefix", Value: "#api"}) assert.Equal(t, "api/api/path", w.Header().Get("Location")) - assert.Equal(t, 301, w.Code) + assert.Equal(t, http.StatusMovedPermanently, w.Code) w = PerformRequest(router, http.MethodGet, "/path/", header{Key: "X-Forwarded-Prefix", Value: "/nor-mal/#?a=1"}) assert.Equal(t, "/nor-mal/a1/path", w.Header().Get("Location")) - assert.Equal(t, 301, w.Code) + assert.Equal(t, http.StatusMovedPermanently, w.Code) w = PerformRequest(router, http.MethodGet, "/path/", header{Key: "X-Forwarded-Prefix", Value: "/nor-mal/%2e%2e/"}) assert.Equal(t, "/nor-mal/2e2e/path", w.Header().Get("Location")) - assert.Equal(t, 301, w.Code) + assert.Equal(t, http.StatusMovedPermanently, w.Code) router.RedirectTrailingSlash = false @@ -339,7 +340,7 @@ func TestRouteParamsByNameWithExtraSlash(t *testing.T) { // TestRouteParamsNotEmpty tests that context parameters will be set // even if a route with params/wildcards is registered after the context -// initialisation (which happened in a previous requets). +// initialisation (which happened in a previous requests). func TestRouteParamsNotEmpty(t *testing.T) { name := "" lastName := "" @@ -386,7 +387,7 @@ func TestRouteStaticFile(t *testing.T) { } defer os.Remove(f.Name()) _, err = f.WriteString("Gin Web Framework") - assert.NoError(t, err) + require.NoError(t, err) f.Close() dir, filename := filepath.Split(f.Name()) @@ -421,7 +422,7 @@ func TestRouteStaticFileFS(t *testing.T) { } defer os.Remove(f.Name()) _, err = f.WriteString("Gin Web Framework") - assert.NoError(t, err) + require.NoError(t, err) f.Close() dir, filename := filepath.Split(f.Name()) @@ -484,7 +485,7 @@ func TestRouterMiddlewareAndStatic(t *testing.T) { // Content-Type='text/plain; charset=utf-8' when go version <= 1.16, // else, Content-Type='text/x-go; charset=utf-8' assert.NotEqual(t, "", w.Header().Get("Content-Type")) - assert.NotEqual(t, w.Header().Get("Last-Modified"), "Mon, 02 Jan 2006 15:04:05 MST") + assert.NotEqual(t, "Mon, 02 Jan 2006 15:04:05 MST", w.Header().Get("Last-Modified")) assert.Equal(t, "Mon, 02 Jan 2006 15:04:05 MST", w.Header().Get("Expires")) assert.Equal(t, "Gin Framework", w.Header().Get("x-GIN")) } @@ -522,8 +523,8 @@ func TestRouteNotAllowedEnabled3(t *testing.T) { w := PerformRequest(router, http.MethodPut, "/path") assert.Equal(t, http.StatusMethodNotAllowed, w.Code) allowed := w.Header().Get("Allow") - assert.Contains(t, allowed, "GET") - assert.Contains(t, allowed, "POST") + assert.Contains(t, allowed, http.MethodGet) + assert.Contains(t, allowed, http.MethodPost) } func TestRouteNotAllowedDisabled(t *testing.T) { @@ -556,10 +557,10 @@ func TestRouterNotFoundWithRemoveExtraSlash(t *testing.T) { {"/nope", http.StatusNotFound, ""}, // NotFound } for _, tr := range testRoutes { - w := PerformRequest(router, "GET", tr.route) + w := PerformRequest(router, http.MethodGet, tr.route) assert.Equal(t, tr.code, w.Code) if w.Code != http.StatusNotFound { - assert.Equal(t, tr.location, fmt.Sprint(w.Header().Get("Location"))) + assert.Equal(t, tr.location, w.Header().Get("Location")) } } } @@ -589,7 +590,7 @@ func TestRouterNotFound(t *testing.T) { w := PerformRequest(router, http.MethodGet, tr.route) assert.Equal(t, tr.code, w.Code) if w.Code != http.StatusNotFound { - assert.Equal(t, tr.location, fmt.Sprint(w.Header().Get("Location"))) + assert.Equal(t, tr.location, w.Header().Get("Location")) } } @@ -619,11 +620,11 @@ func TestRouterNotFound(t *testing.T) { router = New() router.NoRoute(func(c *Context) { if c.Request.RequestURI == "/login" { - c.String(200, "login") + c.String(http.StatusOK, "login") } }) router.GET("/logout", func(c *Context) { - c.String(200, "logout") + c.String(http.StatusOK, "logout") }) w = PerformRequest(router, http.MethodGet, "/login") assert.Equal(t, "login", w.Body.String()) @@ -635,7 +636,7 @@ func TestRouterStaticFSNotFound(t *testing.T) { router := New() router.StaticFS("/", http.FileSystem(http.Dir("/thisreallydoesntexist/"))) router.NoRoute(func(c *Context) { - c.String(404, "non existent") + c.String(http.StatusNotFound, "non existent") }) w := PerformRequest(router, http.MethodGet, "/nonexistent") @@ -718,12 +719,12 @@ func TestRouteRawPathNoUnescape(t *testing.T) { func TestRouteServeErrorWithWriteHeader(t *testing.T) { route := New() route.Use(func(c *Context) { - c.Status(421) + c.Status(http.StatusMisdirectedRequest) c.Next() }) w := PerformRequest(route, http.MethodGet, "/NotFound") - assert.Equal(t, 421, w.Code) + assert.Equal(t, http.StatusMisdirectedRequest, w.Code) assert.Equal(t, 0, w.Body.Len()) } @@ -785,6 +786,6 @@ func TestEngineHandleMethodNotAllowedCornerCase(t *testing.T) { v1.GET("/orgs/:id", handlerTest1) v1.DELETE("/orgs/:id", handlerTest1) - w := PerformRequest(r, "GET", "/base/v1/user/groups") + w := PerformRequest(r, http.MethodGet, "/base/v1/user/groups") assert.Equal(t, http.StatusNotFound, w.Code) } diff --git a/tree.go b/tree.go index 878023d1..0d3e5a8c 100644 --- a/tree.go +++ b/tree.go @@ -65,17 +65,10 @@ func (trees methodTrees) get(method string) *node { return nil } -func min(a, b int) int { - if a <= b { - return a - } - return b -} - func longestCommonPrefix(a, b string) int { i := 0 - max := min(len(a), len(b)) - for i < max && a[i] == b[i] { + max_ := min(len(a), len(b)) + for i < max_ && a[i] == b[i] { i++ } return i @@ -205,7 +198,7 @@ walk: } // Check if a child with the next path byte exists - for i, max := 0, len(n.indices); i < max; i++ { + for i, max_ := 0, len(n.indices); i < max_; i++ { if c == n.indices[i] { parentFullPathIndex += len(n.path) i = n.incrementChildPrio(i) @@ -269,7 +262,19 @@ walk: // Returns -1 as index, if no wildcard was found. func findWildcard(path string) (wildcard string, i int, valid bool) { // Find start + escapeColon := false for start, c := range []byte(path) { + if escapeColon { + escapeColon = false + if c == ':' { + continue + } + panic("invalid escape string in path '" + path + "'") + } + if c == '\\' { + escapeColon = true + continue + } // A wildcard starts with ':' (param) or '*' (catch-all) if c != ':' && c != '*' { continue @@ -364,7 +369,7 @@ func (n *node) insertChild(path string, fullPath string, handlers HandlersChain) // currently fixed width 1 for '/' i-- - if path[i] != '/' { + if i < 0 || path[i] != '/' { panic("no / before catch-all in path '" + fullPath + "'") } @@ -770,7 +775,7 @@ walk: // Outer loop for walking the tree // Runes are up to 4 byte long, // -4 would definitely be another rune. var off int - for max := min(npLen, 3); off < max; off++ { + for max_ := min(npLen, 3); off < max_; off++ { if i := npLen - off; utf8.RuneStart(oldPath[i]) { // read rune from cached path rv, _ = utf8.DecodeRuneInString(oldPath[i:]) diff --git a/tree_test.go b/tree_test.go index c9b03130..74eb6104 100644 --- a/tree_test.go +++ b/tree_test.go @@ -192,6 +192,7 @@ func TestTreeWildcard(t *testing.T) { "/get/abc/123abg/:param", "/get/abc/123abf/:param", "/get/abc/123abfff/:param", + "/get/abc/escaped_colon/test\\:param", } for _, route := range routes { tree.addRoute(route, fakeHandler(route)) @@ -315,6 +316,7 @@ func TestTreeWildcard(t *testing.T) { {"/get/abc/123abg/test", false, "/get/abc/123abg/:param", Params{Param{Key: "param", Value: "test"}}}, {"/get/abc/123abf/testss", false, "/get/abc/123abf/:param", Params{Param{Key: "param", Value: "testss"}}}, {"/get/abc/123abfff/te", false, "/get/abc/123abfff/:param", Params{Param{Key: "param", Value: "te"}}}, + {"/get/abc/escaped_colon/test\\:param", false, "/get/abc/escaped_colon/test\\:param", nil}, }) checkPriorities(t, tree) @@ -419,6 +421,9 @@ func TestTreeWildcardConflict(t *testing.T) { {"/id/:id", false}, {"/static/*file", false}, {"/static/", true}, + {"/escape/test\\:d1", false}, + {"/escape/test\\:d2", false}, + {"/escape/test:param", false}, } testRoutes(t, routes) } @@ -971,3 +976,45 @@ func TestTreeWildcardConflictEx(t *testing.T) { } } } + +func TestTreeInvalidEscape(t *testing.T) { + routes := map[string]bool{ + "/r1/r": true, + "/r2/:r": true, + "/r3/\\:r": true, + } + tree := &node{} + for route, valid := range routes { + recv := catchPanic(func() { + tree.addRoute(route, fakeHandler(route)) + }) + if recv == nil != valid { + t.Fatalf("%s should be %t but got %v", route, valid, recv) + } + } +} + +func TestWildcardInvalidSlash(t *testing.T) { + const panicMsgPrefix = "no / before catch-all in path" + + routes := map[string]bool{ + "/foo/bar": true, + "/foo/x*zy": false, + "/foo/b*r": false, + } + + for route, valid := range routes { + tree := &node{} + recv := catchPanic(func() { + tree.addRoute(route, nil) + }) + + if recv == nil != valid { + t.Fatalf("%s should be %t but got %v", route, valid, recv) + } + + if rs, ok := recv.(string); recv != nil && (!ok || !strings.HasPrefix(rs, panicMsgPrefix)) { + t.Fatalf(`"Expected panic "%s" for route '%s', got "%v"`, panicMsgPrefix, route, recv) + } + } +} diff --git a/utils_test.go b/utils_test.go index 058ddb9d..8098c681 100644 --- a/utils_test.go +++ b/utils_test.go @@ -29,7 +29,7 @@ type testStruct struct { } func (t *testStruct) ServeHTTP(w http.ResponseWriter, req *http.Request) { - assert.Equal(t.T, "POST", req.Method) + assert.Equal(t.T, http.MethodPost, req.Method) assert.Equal(t.T, "/path", req.URL.Path) w.WriteHeader(http.StatusInternalServerError) fmt.Fprint(w, "hello") @@ -39,17 +39,17 @@ func TestWrap(t *testing.T) { router := New() router.POST("/path", WrapH(&testStruct{t})) router.GET("/path2", WrapF(func(w http.ResponseWriter, req *http.Request) { - assert.Equal(t, "GET", req.Method) + assert.Equal(t, http.MethodGet, req.Method) assert.Equal(t, "/path2", req.URL.Path) w.WriteHeader(http.StatusBadRequest) fmt.Fprint(w, "hola!") })) - w := PerformRequest(router, "POST", "/path") + w := PerformRequest(router, http.MethodPost, "/path") assert.Equal(t, http.StatusInternalServerError, w.Code) assert.Equal(t, "hello", w.Body.String()) - w = PerformRequest(router, "GET", "/path2") + w = PerformRequest(router, http.MethodGet, "/path2") assert.Equal(t, http.StatusBadRequest, w.Code) assert.Equal(t, "hola!", w.Body.String()) } @@ -119,13 +119,13 @@ func TestBindMiddleware(t *testing.T) { called = true value = c.MustGet(BindKey).(*bindTestStruct) }) - PerformRequest(router, "GET", "/?foo=hola&bar=10") + PerformRequest(router, http.MethodGet, "/?foo=hola&bar=10") assert.True(t, called) assert.Equal(t, "hola", value.Foo) assert.Equal(t, 10, value.Bar) called = false - PerformRequest(router, "GET", "/?foo=hola&bar=1") + PerformRequest(router, http.MethodGet, "/?foo=hola&bar=1") assert.False(t, called) assert.Panics(t, func() { @@ -145,6 +145,6 @@ func TestMarshalXMLforH(t *testing.T) { } func TestIsASCII(t *testing.T) { - assert.Equal(t, isASCII("test"), true) - assert.Equal(t, isASCII("🧡💛💚💙💜"), false) + assert.True(t, isASCII("test")) + assert.False(t, isASCII("🧡💛💚💙💜")) } diff --git a/version.go b/version.go index 85462e55..93ad9654 100644 --- a/version.go +++ b/version.go @@ -5,4 +5,4 @@ package gin // Version is the current gin framework's version. -const Version = "v1.9.1" +const Version = "v1.10.0"