diff --git a/context.go b/context.go index b66b8adc..75a2d9ef 100644 --- a/context.go +++ b/context.go @@ -5,6 +5,7 @@ package gin import ( + "context" "errors" "io" "io/ioutil" @@ -50,9 +51,10 @@ const abortIndex int8 = math.MaxInt8 >> 1 // Context is the most important part of gin. It allows us to pass variables between middleware, // manage the flow, validate the JSON of a request and render a JSON response for example. type Context struct { - writermem responseWriter - Request *http.Request - Writer ResponseWriter + writermem responseWriter + Request *http.Request + Writer ResponseWriter + requestContext context.Context Params Params handlers HandlersChain @@ -108,6 +110,22 @@ func (c *Context) reset() { *c.skippedNodes = (*c.skippedNodes)[:0] } +// Initializes c.Request and c.requestContext according to the ContextWithFallback feature flag +func (c *Context) setRequest(req *http.Request) { + if req == nil { + c.Request = nil + c.requestContext = nil + return + } + + c.requestContext = req.Context() + if c.engine.ContextWithFallback { + c.Request = req.WithContext(c) + } else { + c.Request = req + } +} + // Copy returns a copy of the current context that can be safely used outside the request's scope. // This has to be used when the context has to be passed to a goroutine. func (c *Context) Copy() *Context { @@ -1173,26 +1191,26 @@ func (c *Context) SetAccepted(formats ...string) { // Deadline returns that there is no deadline (ok==false) when c.Request has no Context. func (c *Context) Deadline() (deadline time.Time, ok bool) { - if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil { + if !c.engine.ContextWithFallback || c.Request == nil || c.requestContext == nil { return } - return c.Request.Context().Deadline() + return c.requestContext.Deadline() } // Done returns nil (chan which will wait forever) when c.Request has no Context. func (c *Context) Done() <-chan struct{} { - if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil { + if !c.engine.ContextWithFallback || c.Request == nil || c.requestContext == nil { return nil } - return c.Request.Context().Done() + return c.requestContext.Done() } // Err returns nil when c.Request has no Context. func (c *Context) Err() error { - if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil { + if !c.engine.ContextWithFallback || c.Request == nil || c.requestContext == nil { return nil } - return c.Request.Context().Err() + return c.requestContext.Err() } // Value returns the value associated with this context for key, or nil @@ -1210,8 +1228,8 @@ func (c *Context) Value(key any) any { return val } } - if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil { + if !c.engine.ContextWithFallback || c.Request == nil || c.requestContext == nil { return nil } - return c.Request.Context().Value(key) + return c.requestContext.Value(key) } diff --git a/context_test.go b/context_test.go index b3e81c14..012ea2b5 100644 --- a/context_test.go +++ b/context_test.go @@ -78,8 +78,9 @@ func TestContextFormFile(t *testing.T) { } mw.Close() c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", buf) - c.Request.Header.Set("Content-Type", mw.FormDataContentType()) + req, _ := http.NewRequest("POST", "/", buf) + req.Header.Set("Content-Type", mw.FormDataContentType()) + c.setRequest(req) f, err := c.FormFile("file") if assert.NoError(t, err) { assert.Equal(t, "test", f.Filename) @@ -99,8 +100,9 @@ func TestContextMultipartForm(t *testing.T) { } mw.Close() c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", buf) - c.Request.Header.Set("Content-Type", mw.FormDataContentType()) + req, _ := http.NewRequest("POST", "/", buf) + req.Header.Set("Content-Type", mw.FormDataContentType()) + c.setRequest(req) f, err := c.MultipartForm() if assert.NoError(t, err) { assert.NotNil(t, f) @@ -115,8 +117,9 @@ func TestSaveUploadedOpenFailed(t *testing.T) { mw.Close() c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", buf) - c.Request.Header.Set("Content-Type", mw.FormDataContentType()) + req, _ := http.NewRequest("POST", "/", buf) + req.Header.Set("Content-Type", mw.FormDataContentType()) + c.setRequest(req) f := &multipart.FileHeader{ Filename: "file", @@ -134,8 +137,9 @@ func TestSaveUploadedCreateFailed(t *testing.T) { } mw.Close() c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", buf) - c.Request.Header.Set("Content-Type", mw.FormDataContentType()) + req, _ := http.NewRequest("POST", "/", buf) + req.Header.Set("Content-Type", mw.FormDataContentType()) + c.setRequest(req) f, err := c.FormFile("file") if assert.NoError(t, err) { assert.Equal(t, "test", f.Filename) @@ -318,7 +322,8 @@ func TestContextGetStringMapStringSlice(t *testing.T) { func TestContextCopy(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) c.index = 2 - c.Request, _ = http.NewRequest("POST", "/hola", nil) + req, _ := http.NewRequest("POST", "/hola", nil) + c.setRequest(req) c.handlers = HandlersChain{func(c *Context) {}} c.Params = Params{Param{Key: "foo", Value: "bar"}} c.Set("foo", "bar") @@ -373,7 +378,8 @@ func TestContextHandler(t *testing.T) { func TestContextQuery(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("GET", "http://example.com/?foo=bar&page=10&id=", nil) + req, _ := http.NewRequest("GET", "http://example.com/?foo=bar&page=10&id=", nil) + c.setRequest(req) value, ok := c.GetQuery("foo") assert.True(t, ok) @@ -424,9 +430,10 @@ func TestContextDefaultQueryOnEmptyRequest(t *testing.T) { func TestContextQueryAndPostForm(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) body := bytes.NewBufferString("foo=bar&page=11&both=&foo=second") - c.Request, _ = http.NewRequest("POST", + req, _ := http.NewRequest("POST", "/?both=GET&id=main&id=omit&array[]=first&array[]=second&ids[a]=hi&ids[b]=3.14", body) - c.Request.Header.Add("Content-Type", MIMEPOSTForm) + req.Header.Add("Content-Type", MIMEPOSTForm) + c.setRequest(req) assert.Equal(t, "bar", c.DefaultPostForm("foo", "none")) assert.Equal(t, "bar", c.PostForm("foo")) @@ -521,7 +528,7 @@ func TestContextQueryAndPostForm(t *testing.T) { func TestContextPostFormMultipart(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request = createMultipartRequest() + c.setRequest(createMultipartRequest()) var obj struct { Foo string `form:"foo"` @@ -627,8 +634,9 @@ func TestContextSetCookiePathEmpty(t *testing.T) { func TestContextGetCookie(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("GET", "/get", nil) - c.Request.Header.Set("Cookie", "user=gin") + req, _ := http.NewRequest("GET", "/get", nil) + req.Header.Set("Cookie", "user=gin") + c.setRequest(req) cookie, _ := c.Cookie("user") assert.Equal(t, "gin", cookie) @@ -683,7 +691,8 @@ func TestContextRenderJSON(t *testing.T) { func TestContextRenderJSONP(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("GET", "http://example.com/?callback=x", nil) + req, _ := http.NewRequest("GET", "http://example.com/?callback=x", nil) + c.setRequest(req) c.JSONP(http.StatusCreated, H{"foo": "bar"}) @@ -697,7 +706,8 @@ func TestContextRenderJSONP(t *testing.T) { func TestContextRenderJSONPWithoutCallback(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("GET", "http://example.com", nil) + req, _ := http.NewRequest("GET", "http://example.com", nil) + c.setRequest(req) c.JSONP(http.StatusCreated, H{"foo": "bar"}) @@ -996,7 +1006,8 @@ func TestContextRenderFile(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("GET", "/", nil) + req, _ := http.NewRequest("GET", "/", nil) + c.setRequest(req) c.File("./gin.go") assert.Equal(t, http.StatusOK, w.Code) @@ -1010,7 +1021,8 @@ func TestContextRenderFileFromFS(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("GET", "/some/path", nil) + req, _ := http.NewRequest("GET", "/some/path", nil) + c.setRequest(req) c.FileFromFS("./gin.go", Dir(".", false)) assert.Equal(t, http.StatusOK, w.Code) @@ -1026,7 +1038,8 @@ func TestContextRenderAttachment(t *testing.T) { c, _ := CreateTestContext(w) newFilename := "new_filename.go" - c.Request, _ = http.NewRequest("GET", "/", nil) + req, _ := http.NewRequest("GET", "/", nil) + c.setRequest(req) c.FileAttachment("./gin.go", newFilename) assert.Equal(t, 200, w.Code) @@ -1039,7 +1052,8 @@ func TestContextRenderUTF8Attachment(t *testing.T) { c, _ := CreateTestContext(w) newFilename := "new🧡_filename.go" - c.Request, _ = http.NewRequest("GET", "/", nil) + req, _ := http.NewRequest("GET", "/", nil) + c.setRequest(req) c.FileAttachment("./gin.go", newFilename) assert.Equal(t, 200, w.Code) @@ -1118,7 +1132,8 @@ func TestContextRenderRedirectWithRelativePath(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "http://example.com", nil) + req, _ := http.NewRequest("POST", "http://example.com", nil) + c.setRequest(req) assert.Panics(t, func() { c.Redirect(299, "/new_path") }) assert.Panics(t, func() { c.Redirect(309, "/new_path") }) @@ -1132,7 +1147,8 @@ func TestContextRenderRedirectWithAbsolutePath(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "http://example.com", nil) + req, _ := http.NewRequest("POST", "http://example.com", nil) + c.setRequest(req) c.Redirect(http.StatusFound, "http://google.com") c.Writer.WriteHeaderNow() @@ -1144,7 +1160,8 @@ func TestContextRenderRedirectWith201(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "http://example.com", nil) + req, _ := http.NewRequest("POST", "http://example.com", nil) + c.setRequest(req) c.Redirect(http.StatusCreated, "/resource") c.Writer.WriteHeaderNow() @@ -1154,7 +1171,8 @@ func TestContextRenderRedirectWith201(t *testing.T) { func TestContextRenderRedirectAll(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "http://example.com", nil) + req, _ := http.NewRequest("POST", "http://example.com", nil) + c.setRequest(req) assert.Panics(t, func() { c.Redirect(http.StatusOK, "/resource") }) assert.Panics(t, func() { c.Redirect(http.StatusAccepted, "/resource") }) assert.Panics(t, func() { c.Redirect(299, "/resource") }) @@ -1166,7 +1184,8 @@ func TestContextRenderRedirectAll(t *testing.T) { func TestContextNegotiationWithJSON(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "", nil) + req, _ := http.NewRequest("POST", "", nil) + c.setRequest(req) c.Negotiate(http.StatusOK, Negotiate{ Offered: []string{MIMEJSON, MIMEXML, MIMEYAML}, @@ -1181,7 +1200,8 @@ func TestContextNegotiationWithJSON(t *testing.T) { func TestContextNegotiationWithXML(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "", nil) + req, _ := http.NewRequest("POST", "", nil) + c.setRequest(req) c.Negotiate(http.StatusOK, Negotiate{ Offered: []string{MIMEXML, MIMEJSON, MIMEYAML}, @@ -1196,7 +1216,8 @@ func TestContextNegotiationWithXML(t *testing.T) { func TestContextNegotiationWithYAML(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "", nil) + req, _ := http.NewRequest("POST", "", nil) + c.setRequest(req) c.Negotiate(http.StatusOK, Negotiate{ Offered: []string{MIMEYAML, MIMEXML, MIMEJSON, MIMETOML}, @@ -1211,7 +1232,8 @@ func TestContextNegotiationWithYAML(t *testing.T) { func TestContextNegotiationWithTOML(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "", nil) + req, _ := http.NewRequest("POST", "", nil) + c.setRequest(req) c.Negotiate(http.StatusOK, Negotiate{ Offered: []string{MIMETOML, MIMEXML, MIMEJSON, MIMEYAML}, @@ -1226,7 +1248,8 @@ func TestContextNegotiationWithTOML(t *testing.T) { func TestContextNegotiationWithHTML(t *testing.T) { w := httptest.NewRecorder() c, router := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "", nil) + req, _ := http.NewRequest("POST", "", nil) + c.setRequest(req) templ := template.Must(template.New("t").Parse(`Hello {{.name}}`)) router.SetHTMLTemplate(templ) @@ -1244,7 +1267,8 @@ func TestContextNegotiationWithHTML(t *testing.T) { func TestContextNegotiationNotSupport(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "", nil) + req, _ := http.NewRequest("POST", "", nil) + c.setRequest(req) c.Negotiate(http.StatusOK, Negotiate{ Offered: []string{MIMEPOSTForm}, @@ -1257,7 +1281,8 @@ func TestContextNegotiationNotSupport(t *testing.T) { func TestContextNegotiationFormat(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "", nil) + req, _ := http.NewRequest("POST", "", nil) + c.setRequest(req) assert.Panics(t, func() { c.NegotiateFormat() }) assert.Equal(t, MIMEJSON, c.NegotiateFormat(MIMEJSON, MIMEXML)) @@ -1266,8 +1291,9 @@ func TestContextNegotiationFormat(t *testing.T) { func TestContextNegotiationFormatWithAccept(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", nil) - c.Request.Header.Add("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9;q=0.8") + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Add("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9;q=0.8") + c.setRequest(req) assert.Equal(t, MIMEXML, c.NegotiateFormat(MIMEJSON, MIMEXML)) assert.Equal(t, MIMEHTML, c.NegotiateFormat(MIMEXML, MIMEHTML)) @@ -1276,8 +1302,9 @@ func TestContextNegotiationFormatWithAccept(t *testing.T) { func TestContextNegotiationFormatWithWildcardAccept(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", nil) - c.Request.Header.Add("Accept", "*/*") + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Add("Accept", "*/*") + c.setRequest(req) assert.Equal(t, c.NegotiateFormat("*/*"), "*/*") assert.Equal(t, c.NegotiateFormat("text/*"), "text/*") @@ -1287,8 +1314,9 @@ func TestContextNegotiationFormatWithWildcardAccept(t *testing.T) { assert.Equal(t, c.NegotiateFormat(MIMEHTML), MIMEHTML) c, _ = CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", nil) - c.Request.Header.Add("Accept", "text/*") + req, _ = http.NewRequest("POST", "/", nil) + req.Header.Add("Accept", "text/*") + c.setRequest(req) assert.Equal(t, c.NegotiateFormat("*/*"), "*/*") assert.Equal(t, c.NegotiateFormat("text/*"), "text/*") @@ -1300,8 +1328,9 @@ func TestContextNegotiationFormatWithWildcardAccept(t *testing.T) { func TestContextNegotiationFormatCustom(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", nil) - c.Request.Header.Add("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9;q=0.8") + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Add("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9;q=0.8") + c.setRequest(req) c.Accepted = nil c.SetAccepted(MIMEJSON, MIMEXML) @@ -1433,7 +1462,8 @@ func TestContextAbortWithError(t *testing.T) { func TestContextClientIP(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", nil) + req, _ := http.NewRequest("POST", "/", nil) + c.setRequest(req) c.engine.trustedCIDRs, _ = c.engine.prepareTrustedCIDRs() resetContextForClientIPTests(c) @@ -1569,16 +1599,18 @@ func resetContextForClientIPTests(c *Context) { func TestContextContentType(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", nil) - c.Request.Header.Set("Content-Type", "application/json; charset=utf-8") + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Content-Type", "application/json; charset=utf-8") + c.setRequest(req) assert.Equal(t, "application/json", c.ContentType()) } func TestContextAutoBindJSON(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}")) - c.Request.Header.Add("Content-Type", MIMEJSON) + req, _ := http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}")) + req.Header.Add("Content-Type", MIMEJSON) + c.setRequest(req) var obj struct { Foo string `json:"foo"` @@ -1594,8 +1626,9 @@ func TestContextBindWithJSON(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}")) - c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type + req, _ := http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}")) + req.Header.Add("Content-Type", MIMEXML) // set fake content-type + c.setRequest(req) var obj struct { Foo string `json:"foo"` @@ -1611,12 +1644,13 @@ func TestContextBindWithXML(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString(` + req, _ := http.NewRequest("POST", "/", bytes.NewBufferString(` FOO BAR `)) - c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type + req.Header.Add("Content-Type", MIMEXML) // set fake content-type + c.setRequest(req) var obj struct { Foo string `xml:"foo"` @@ -1632,10 +1666,11 @@ func TestContextBindHeader(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "/", nil) - c.Request.Header.Add("rate", "8000") - c.Request.Header.Add("domain", "music") - c.Request.Header.Add("limit", "1000") + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Add("rate", "8000") + req.Header.Add("domain", "music") + req.Header.Add("limit", "1000") + c.setRequest(req) var testHeader struct { Rate int `header:"Rate"` @@ -1654,7 +1689,8 @@ func TestContextBindWithQuery(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "/?foo=bar&bar=foo", bytes.NewBufferString("foo=unused")) + req, _ := http.NewRequest("POST", "/?foo=bar&bar=foo", bytes.NewBufferString("foo=unused")) + c.setRequest(req) var obj struct { Foo string `form:"foo"` @@ -1670,8 +1706,9 @@ func TestContextBindWithYAML(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("foo: bar\nbar: foo")) - c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type + req, _ := http.NewRequest("POST", "/", bytes.NewBufferString("foo: bar\nbar: foo")) + req.Header.Add("Content-Type", MIMEXML) // set fake content-type + c.setRequest(req) var obj struct { Foo string `yaml:"foo"` @@ -1687,8 +1724,9 @@ func TestContextBindWithTOML(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("foo = 'bar'\nbar = 'foo'")) - c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type + req, _ := http.NewRequest("POST", "/", bytes.NewBufferString("foo = 'bar'\nbar = 'foo'")) + req.Header.Add("Content-Type", MIMEXML) // set fake content-type + c.setRequest(req) var obj struct { Foo string `toml:"foo"` @@ -1704,8 +1742,9 @@ func TestContextBadAutoBind(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "http://example.com", bytes.NewBufferString("\"foo\":\"bar\", \"bar\":\"foo\"}")) - c.Request.Header.Add("Content-Type", MIMEJSON) + req, _ := http.NewRequest("POST", "http://example.com", bytes.NewBufferString("\"foo\":\"bar\", \"bar\":\"foo\"}")) + req.Header.Add("Content-Type", MIMEJSON) + c.setRequest(req) var obj struct { Foo string `json:"foo"` Bar string `json:"bar"` @@ -1723,8 +1762,9 @@ func TestContextBadAutoBind(t *testing.T) { func TestContextAutoShouldBindJSON(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}")) - c.Request.Header.Add("Content-Type", MIMEJSON) + req, _ := http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}")) + req.Header.Add("Content-Type", MIMEJSON) + c.setRequest(req) var obj struct { Foo string `json:"foo"` @@ -1740,8 +1780,9 @@ func TestContextShouldBindWithJSON(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}")) - c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type + req, _ := http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}")) + req.Header.Add("Content-Type", MIMEXML) // set fake content-type + c.setRequest(req) var obj struct { Foo string `json:"foo"` @@ -1757,12 +1798,13 @@ func TestContextShouldBindWithXML(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString(` + req, _ := http.NewRequest("POST", "/", bytes.NewBufferString(` FOO BAR `)) - c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type + req.Header.Add("Content-Type", MIMEXML) // set fake content-type + c.setRequest(req) var obj struct { Foo string `xml:"foo"` @@ -1778,10 +1820,11 @@ func TestContextShouldBindHeader(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "/", nil) - c.Request.Header.Add("rate", "8000") - c.Request.Header.Add("domain", "music") - c.Request.Header.Add("limit", "1000") + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Add("rate", "8000") + req.Header.Add("domain", "music") + req.Header.Add("limit", "1000") + c.setRequest(req) var testHeader struct { Rate int `header:"Rate"` @@ -1800,7 +1843,8 @@ func TestContextShouldBindWithQuery(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "/?foo=bar&bar=foo&Foo=bar1&Bar=foo1", bytes.NewBufferString("foo=unused")) + req, _ := http.NewRequest("POST", "/?foo=bar&bar=foo&Foo=bar1&Bar=foo1", bytes.NewBufferString("foo=unused")) + c.setRequest(req) var obj struct { Foo string `form:"foo"` @@ -1820,8 +1864,9 @@ func TestContextShouldBindWithYAML(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("foo: bar\nbar: foo")) - c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type + req, _ := http.NewRequest("POST", "/", bytes.NewBufferString("foo: bar\nbar: foo")) + req.Header.Add("Content-Type", MIMEXML) // set fake content-type + c.setRequest(req) var obj struct { Foo string `yaml:"foo"` @@ -1837,8 +1882,9 @@ func TestContextShouldBindWithTOML(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("foo='bar'\nbar= 'foo'")) - c.Request.Header.Add("Content-Type", MIMETOML) // set fake content-type + req, _ := http.NewRequest("POST", "/", bytes.NewBufferString("foo='bar'\nbar= 'foo'")) + req.Header.Add("Content-Type", MIMETOML) // set fake content-type + c.setRequest(req) var obj struct { Foo string `toml:"foo"` @@ -1854,8 +1900,9 @@ func TestContextBadAutoShouldBind(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest("POST", "http://example.com", bytes.NewBufferString("\"foo\":\"bar\", \"bar\":\"foo\"}")) - c.Request.Header.Add("Content-Type", MIMEJSON) + req, _ := http.NewRequest("POST", "http://example.com", bytes.NewBufferString("\"foo\":\"bar\", \"bar\":\"foo\"}")) + req.Header.Add("Content-Type", MIMEJSON) + c.setRequest(req) var obj struct { Foo string `json:"foo"` Bar string `json:"bar"` @@ -1917,9 +1964,10 @@ func TestContextShouldBindBodyWith(t *testing.T) { { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest( + req, _ := http.NewRequest( "POST", "http://example.com", bytes.NewBufferString(tt.bodyA), ) + c.setRequest(req) // When it binds to typeA and typeB, it finds the body is // not typeB but typeA. objA := typeA{} @@ -1935,9 +1983,10 @@ func TestContextShouldBindBodyWith(t *testing.T) { // not typeA but typeB. w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest( + req, _ := http.NewRequest( "POST", "http://example.com", bytes.NewBufferString(tt.bodyB), ) + c.setRequest(req) objA := typeA{} assert.Error(t, c.ShouldBindBodyWith(&objA, tt.bindingA)) assert.NotEqual(t, typeA{"FOO"}, objA) @@ -1950,7 +1999,8 @@ func TestContextShouldBindBodyWith(t *testing.T) { func TestContextGolangContext(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}")) + req, _ := http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}")) + c.setRequest(req) assert.NoError(t, c.Err()) assert.Nil(t, c.Done()) ti, ok := c.Deadline() @@ -1968,29 +2018,32 @@ func TestContextGolangContext(t *testing.T) { func TestWebsocketsRequired(t *testing.T) { // Example request from spec: https://tools.ietf.org/html/rfc6455#section-1.2 c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("GET", "/chat", nil) - c.Request.Header.Set("Host", "server.example.com") - c.Request.Header.Set("Upgrade", "websocket") - c.Request.Header.Set("Connection", "Upgrade") - c.Request.Header.Set("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") - c.Request.Header.Set("Origin", "http://example.com") - c.Request.Header.Set("Sec-WebSocket-Protocol", "chat, superchat") - c.Request.Header.Set("Sec-WebSocket-Version", "13") + req, _ := http.NewRequest("GET", "/chat", nil) + req.Header.Set("Host", "server.example.com") + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") + req.Header.Set("Origin", "http://example.com") + req.Header.Set("Sec-WebSocket-Protocol", "chat, superchat") + req.Header.Set("Sec-WebSocket-Version", "13") + c.setRequest(req) assert.True(t, c.IsWebsocket()) // Normal request, no websocket required. c, _ = CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("GET", "/chat", nil) - c.Request.Header.Set("Host", "server.example.com") + req, _ = http.NewRequest("GET", "/chat", nil) + req.Header.Set("Host", "server.example.com") + c.setRequest(req) assert.False(t, c.IsWebsocket()) } func TestGetRequestHeaderValue(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("GET", "/chat", nil) - c.Request.Header.Set("Gin-Version", "1.0.0") + req, _ := http.NewRequest("GET", "/chat", nil) + req.Header.Set("Gin-Version", "1.0.0") + c.setRequest(req) assert.Equal(t, "1.0.0", c.GetHeader("Gin-Version")) assert.Empty(t, c.GetHeader("Connection")) @@ -1999,8 +2052,9 @@ func TestGetRequestHeaderValue(t *testing.T) { func TestContextGetRawData(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) body := bytes.NewBufferString("Fetch binary post data") - c.Request, _ = http.NewRequest("POST", "/", body) - c.Request.Header.Add("Content-Type", MIMEPOSTForm) + req, _ := http.NewRequest("POST", "/", body) + req.Header.Add("Content-Type", MIMEPOSTForm) + c.setRequest(req) data, err := c.GetRawData() assert.Nil(t, err) @@ -2148,8 +2202,9 @@ func TestContextWithKeysMutex(t *testing.T) { func TestRemoteIPFail(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", nil) - c.Request.RemoteAddr = "[:::]:80" + req, _ := http.NewRequest("POST", "/", nil) + req.RemoteAddr = "[:::]:80" + c.setRequest(req) ip := net.ParseIP(c.RemoteIP()) trust := c.engine.isTrustedProxy(ip) assert.Nil(t, ip) @@ -2169,11 +2224,11 @@ func TestContextWithFallbackDeadlineFromRequestContext(t *testing.T) { // enable ContextWithFallback feature flag c2.engine.ContextWithFallback = true - c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil) d := time.Now().Add(time.Second) ctx, cancel := context.WithDeadline(context.Background(), d) defer cancel() - c2.Request = c2.Request.WithContext(ctx) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/", nil) + c2.setRequest(req) deadline, ok = c2.Deadline() assert.Equal(t, d, deadline) assert.True(t, ok) @@ -2190,11 +2245,13 @@ func TestContextWithFallbackDoneFromRequestContext(t *testing.T) { // enable ContextWithFallback feature flag c2.engine.ContextWithFallback = true - c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil) ctx, cancel := context.WithCancel(context.Background()) - c2.Request = c2.Request.WithContext(ctx) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/", nil) + c2.setRequest(req) cancel() - assert.NotNil(t, <-c2.Done()) + if assert.NotNil(t, c2.Done()) { + assert.NotNil(t, <-c2.Done()) + } } func TestContextWithFallbackErrFromRequestContext(t *testing.T) { @@ -2208,9 +2265,9 @@ func TestContextWithFallbackErrFromRequestContext(t *testing.T) { // enable ContextWithFallback feature flag c2.engine.ContextWithFallback = true - c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil) ctx, cancel := context.WithCancel(context.Background()) - c2.Request = c2.Request.WithContext(ctx) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/", nil) + c2.setRequest(req) cancel() assert.EqualError(t, c2.Err(), context.Canceled.Error()) @@ -2231,8 +2288,9 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) // enable ContextWithFallback feature flag c.engine.ContextWithFallback = true - c.Request, _ = http.NewRequest("POST", "/", nil) - c.Request = c.Request.WithContext(context.WithValue(context.TODO(), key, "value")) + ctx := context.WithValue(context.TODO(), key, "value") + req, _ := http.NewRequestWithContext(ctx, "POST", "/", nil) + c.setRequest(req) return c, key }, value: "value", @@ -2243,8 +2301,9 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) // enable ContextWithFallback feature flag c.engine.ContextWithFallback = true - c.Request, _ = http.NewRequest("POST", "/", nil) - c.Request = c.Request.WithContext(context.WithValue(context.TODO(), contextKey("key"), "value")) + ctx := context.WithValue(context.TODO(), contextKey("key"), "value") + req, _ := http.NewRequestWithContext(ctx, "POST", "/", nil) + c.setRequest(req) return c, contextKey("key") }, value: "value", @@ -2255,7 +2314,7 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) // enable ContextWithFallback feature flag c.engine.ContextWithFallback = true - c.Request = nil + c.setRequest(nil) return c, "key" }, value: nil, @@ -2266,7 +2325,8 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) // enable ContextWithFallback feature flag c.engine.ContextWithFallback = true - c.Request, _ = http.NewRequest("POST", "/", nil) + req, _ := http.NewRequest("POST", "/", nil) + c.setRequest(req) return c, "key" }, value: nil, @@ -2280,6 +2340,16 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) { } } +func TestContextWithFallbackModifiesRequestContext(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + // enable ContextWithFallback feature flag + c.engine.ContextWithFallback = true + req, _ := http.NewRequest(http.MethodGet, "/", nil) + c.setRequest(req) + + assert.Equal(t, c, c.Request.Context()) +} + func TestContextCopyShouldNotCancel(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) diff --git a/gin.go b/gin.go index f9324299..c81cda6b 100644 --- a/gin.go +++ b/gin.go @@ -566,7 +566,7 @@ func (engine *Engine) RunListener(listener net.Listener) (err error) { func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) { c := engine.pool.Get().(*Context) c.writermem.reset(w) - c.Request = req + c.setRequest(req) c.reset() engine.handleHTTPRequest(c)