diff --git a/context.go b/context.go
index b66b8adc..75a2d9ef 100644
--- a/context.go
+++ b/context.go
@@ -5,6 +5,7 @@
package gin
import (
+ "context"
"errors"
"io"
"io/ioutil"
@@ -50,9 +51,10 @@ const abortIndex int8 = math.MaxInt8 >> 1
// Context is the most important part of gin. It allows us to pass variables between middleware,
// manage the flow, validate the JSON of a request and render a JSON response for example.
type Context struct {
- writermem responseWriter
- Request *http.Request
- Writer ResponseWriter
+ writermem responseWriter
+ Request *http.Request
+ Writer ResponseWriter
+ requestContext context.Context
Params Params
handlers HandlersChain
@@ -108,6 +110,22 @@ func (c *Context) reset() {
*c.skippedNodes = (*c.skippedNodes)[:0]
}
+// Initializes c.Request and c.requestContext according to the ContextWithFallback feature flag
+func (c *Context) setRequest(req *http.Request) {
+ if req == nil {
+ c.Request = nil
+ c.requestContext = nil
+ return
+ }
+
+ c.requestContext = req.Context()
+ if c.engine.ContextWithFallback {
+ c.Request = req.WithContext(c)
+ } else {
+ c.Request = req
+ }
+}
+
// Copy returns a copy of the current context that can be safely used outside the request's scope.
// This has to be used when the context has to be passed to a goroutine.
func (c *Context) Copy() *Context {
@@ -1173,26 +1191,26 @@ func (c *Context) SetAccepted(formats ...string) {
// Deadline returns that there is no deadline (ok==false) when c.Request has no Context.
func (c *Context) Deadline() (deadline time.Time, ok bool) {
- if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
+ if !c.engine.ContextWithFallback || c.Request == nil || c.requestContext == nil {
return
}
- return c.Request.Context().Deadline()
+ return c.requestContext.Deadline()
}
// Done returns nil (chan which will wait forever) when c.Request has no Context.
func (c *Context) Done() <-chan struct{} {
- if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
+ if !c.engine.ContextWithFallback || c.Request == nil || c.requestContext == nil {
return nil
}
- return c.Request.Context().Done()
+ return c.requestContext.Done()
}
// Err returns nil when c.Request has no Context.
func (c *Context) Err() error {
- if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
+ if !c.engine.ContextWithFallback || c.Request == nil || c.requestContext == nil {
return nil
}
- return c.Request.Context().Err()
+ return c.requestContext.Err()
}
// Value returns the value associated with this context for key, or nil
@@ -1210,8 +1228,8 @@ func (c *Context) Value(key any) any {
return val
}
}
- if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
+ if !c.engine.ContextWithFallback || c.Request == nil || c.requestContext == nil {
return nil
}
- return c.Request.Context().Value(key)
+ return c.requestContext.Value(key)
}
diff --git a/context_test.go b/context_test.go
index b3e81c14..012ea2b5 100644
--- a/context_test.go
+++ b/context_test.go
@@ -78,8 +78,9 @@ func TestContextFormFile(t *testing.T) {
}
mw.Close()
c, _ := CreateTestContext(httptest.NewRecorder())
- c.Request, _ = http.NewRequest("POST", "/", buf)
- c.Request.Header.Set("Content-Type", mw.FormDataContentType())
+ req, _ := http.NewRequest("POST", "/", buf)
+ req.Header.Set("Content-Type", mw.FormDataContentType())
+ c.setRequest(req)
f, err := c.FormFile("file")
if assert.NoError(t, err) {
assert.Equal(t, "test", f.Filename)
@@ -99,8 +100,9 @@ func TestContextMultipartForm(t *testing.T) {
}
mw.Close()
c, _ := CreateTestContext(httptest.NewRecorder())
- c.Request, _ = http.NewRequest("POST", "/", buf)
- c.Request.Header.Set("Content-Type", mw.FormDataContentType())
+ req, _ := http.NewRequest("POST", "/", buf)
+ req.Header.Set("Content-Type", mw.FormDataContentType())
+ c.setRequest(req)
f, err := c.MultipartForm()
if assert.NoError(t, err) {
assert.NotNil(t, f)
@@ -115,8 +117,9 @@ func TestSaveUploadedOpenFailed(t *testing.T) {
mw.Close()
c, _ := CreateTestContext(httptest.NewRecorder())
- c.Request, _ = http.NewRequest("POST", "/", buf)
- c.Request.Header.Set("Content-Type", mw.FormDataContentType())
+ req, _ := http.NewRequest("POST", "/", buf)
+ req.Header.Set("Content-Type", mw.FormDataContentType())
+ c.setRequest(req)
f := &multipart.FileHeader{
Filename: "file",
@@ -134,8 +137,9 @@ func TestSaveUploadedCreateFailed(t *testing.T) {
}
mw.Close()
c, _ := CreateTestContext(httptest.NewRecorder())
- c.Request, _ = http.NewRequest("POST", "/", buf)
- c.Request.Header.Set("Content-Type", mw.FormDataContentType())
+ req, _ := http.NewRequest("POST", "/", buf)
+ req.Header.Set("Content-Type", mw.FormDataContentType())
+ c.setRequest(req)
f, err := c.FormFile("file")
if assert.NoError(t, err) {
assert.Equal(t, "test", f.Filename)
@@ -318,7 +322,8 @@ func TestContextGetStringMapStringSlice(t *testing.T) {
func TestContextCopy(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
c.index = 2
- c.Request, _ = http.NewRequest("POST", "/hola", nil)
+ req, _ := http.NewRequest("POST", "/hola", nil)
+ c.setRequest(req)
c.handlers = HandlersChain{func(c *Context) {}}
c.Params = Params{Param{Key: "foo", Value: "bar"}}
c.Set("foo", "bar")
@@ -373,7 +378,8 @@ func TestContextHandler(t *testing.T) {
func TestContextQuery(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
- c.Request, _ = http.NewRequest("GET", "http://example.com/?foo=bar&page=10&id=", nil)
+ req, _ := http.NewRequest("GET", "http://example.com/?foo=bar&page=10&id=", nil)
+ c.setRequest(req)
value, ok := c.GetQuery("foo")
assert.True(t, ok)
@@ -424,9 +430,10 @@ func TestContextDefaultQueryOnEmptyRequest(t *testing.T) {
func TestContextQueryAndPostForm(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
body := bytes.NewBufferString("foo=bar&page=11&both=&foo=second")
- c.Request, _ = http.NewRequest("POST",
+ req, _ := http.NewRequest("POST",
"/?both=GET&id=main&id=omit&array[]=first&array[]=second&ids[a]=hi&ids[b]=3.14", body)
- c.Request.Header.Add("Content-Type", MIMEPOSTForm)
+ req.Header.Add("Content-Type", MIMEPOSTForm)
+ c.setRequest(req)
assert.Equal(t, "bar", c.DefaultPostForm("foo", "none"))
assert.Equal(t, "bar", c.PostForm("foo"))
@@ -521,7 +528,7 @@ func TestContextQueryAndPostForm(t *testing.T) {
func TestContextPostFormMultipart(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
- c.Request = createMultipartRequest()
+ c.setRequest(createMultipartRequest())
var obj struct {
Foo string `form:"foo"`
@@ -627,8 +634,9 @@ func TestContextSetCookiePathEmpty(t *testing.T) {
func TestContextGetCookie(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
- c.Request, _ = http.NewRequest("GET", "/get", nil)
- c.Request.Header.Set("Cookie", "user=gin")
+ req, _ := http.NewRequest("GET", "/get", nil)
+ req.Header.Set("Cookie", "user=gin")
+ c.setRequest(req)
cookie, _ := c.Cookie("user")
assert.Equal(t, "gin", cookie)
@@ -683,7 +691,8 @@ func TestContextRenderJSON(t *testing.T) {
func TestContextRenderJSONP(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
- c.Request, _ = http.NewRequest("GET", "http://example.com/?callback=x", nil)
+ req, _ := http.NewRequest("GET", "http://example.com/?callback=x", nil)
+ c.setRequest(req)
c.JSONP(http.StatusCreated, H{"foo": "bar"})
@@ -697,7 +706,8 @@ func TestContextRenderJSONP(t *testing.T) {
func TestContextRenderJSONPWithoutCallback(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
- c.Request, _ = http.NewRequest("GET", "http://example.com", nil)
+ req, _ := http.NewRequest("GET", "http://example.com", nil)
+ c.setRequest(req)
c.JSONP(http.StatusCreated, H{"foo": "bar"})
@@ -996,7 +1006,8 @@ func TestContextRenderFile(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
- c.Request, _ = http.NewRequest("GET", "/", nil)
+ req, _ := http.NewRequest("GET", "/", nil)
+ c.setRequest(req)
c.File("./gin.go")
assert.Equal(t, http.StatusOK, w.Code)
@@ -1010,7 +1021,8 @@ func TestContextRenderFileFromFS(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
- c.Request, _ = http.NewRequest("GET", "/some/path", nil)
+ req, _ := http.NewRequest("GET", "/some/path", nil)
+ c.setRequest(req)
c.FileFromFS("./gin.go", Dir(".", false))
assert.Equal(t, http.StatusOK, w.Code)
@@ -1026,7 +1038,8 @@ func TestContextRenderAttachment(t *testing.T) {
c, _ := CreateTestContext(w)
newFilename := "new_filename.go"
- c.Request, _ = http.NewRequest("GET", "/", nil)
+ req, _ := http.NewRequest("GET", "/", nil)
+ c.setRequest(req)
c.FileAttachment("./gin.go", newFilename)
assert.Equal(t, 200, w.Code)
@@ -1039,7 +1052,8 @@ func TestContextRenderUTF8Attachment(t *testing.T) {
c, _ := CreateTestContext(w)
newFilename := "new🧡_filename.go"
- c.Request, _ = http.NewRequest("GET", "/", nil)
+ req, _ := http.NewRequest("GET", "/", nil)
+ c.setRequest(req)
c.FileAttachment("./gin.go", newFilename)
assert.Equal(t, 200, w.Code)
@@ -1118,7 +1132,8 @@ func TestContextRenderRedirectWithRelativePath(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
- c.Request, _ = http.NewRequest("POST", "http://example.com", nil)
+ req, _ := http.NewRequest("POST", "http://example.com", nil)
+ c.setRequest(req)
assert.Panics(t, func() { c.Redirect(299, "/new_path") })
assert.Panics(t, func() { c.Redirect(309, "/new_path") })
@@ -1132,7 +1147,8 @@ func TestContextRenderRedirectWithAbsolutePath(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
- c.Request, _ = http.NewRequest("POST", "http://example.com", nil)
+ req, _ := http.NewRequest("POST", "http://example.com", nil)
+ c.setRequest(req)
c.Redirect(http.StatusFound, "http://google.com")
c.Writer.WriteHeaderNow()
@@ -1144,7 +1160,8 @@ func TestContextRenderRedirectWith201(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
- c.Request, _ = http.NewRequest("POST", "http://example.com", nil)
+ req, _ := http.NewRequest("POST", "http://example.com", nil)
+ c.setRequest(req)
c.Redirect(http.StatusCreated, "/resource")
c.Writer.WriteHeaderNow()
@@ -1154,7 +1171,8 @@ func TestContextRenderRedirectWith201(t *testing.T) {
func TestContextRenderRedirectAll(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
- c.Request, _ = http.NewRequest("POST", "http://example.com", nil)
+ req, _ := http.NewRequest("POST", "http://example.com", nil)
+ c.setRequest(req)
assert.Panics(t, func() { c.Redirect(http.StatusOK, "/resource") })
assert.Panics(t, func() { c.Redirect(http.StatusAccepted, "/resource") })
assert.Panics(t, func() { c.Redirect(299, "/resource") })
@@ -1166,7 +1184,8 @@ func TestContextRenderRedirectAll(t *testing.T) {
func TestContextNegotiationWithJSON(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
- c.Request, _ = http.NewRequest("POST", "", nil)
+ req, _ := http.NewRequest("POST", "", nil)
+ c.setRequest(req)
c.Negotiate(http.StatusOK, Negotiate{
Offered: []string{MIMEJSON, MIMEXML, MIMEYAML},
@@ -1181,7 +1200,8 @@ func TestContextNegotiationWithJSON(t *testing.T) {
func TestContextNegotiationWithXML(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
- c.Request, _ = http.NewRequest("POST", "", nil)
+ req, _ := http.NewRequest("POST", "", nil)
+ c.setRequest(req)
c.Negotiate(http.StatusOK, Negotiate{
Offered: []string{MIMEXML, MIMEJSON, MIMEYAML},
@@ -1196,7 +1216,8 @@ func TestContextNegotiationWithXML(t *testing.T) {
func TestContextNegotiationWithYAML(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
- c.Request, _ = http.NewRequest("POST", "", nil)
+ req, _ := http.NewRequest("POST", "", nil)
+ c.setRequest(req)
c.Negotiate(http.StatusOK, Negotiate{
Offered: []string{MIMEYAML, MIMEXML, MIMEJSON, MIMETOML},
@@ -1211,7 +1232,8 @@ func TestContextNegotiationWithYAML(t *testing.T) {
func TestContextNegotiationWithTOML(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
- c.Request, _ = http.NewRequest("POST", "", nil)
+ req, _ := http.NewRequest("POST", "", nil)
+ c.setRequest(req)
c.Negotiate(http.StatusOK, Negotiate{
Offered: []string{MIMETOML, MIMEXML, MIMEJSON, MIMEYAML},
@@ -1226,7 +1248,8 @@ func TestContextNegotiationWithTOML(t *testing.T) {
func TestContextNegotiationWithHTML(t *testing.T) {
w := httptest.NewRecorder()
c, router := CreateTestContext(w)
- c.Request, _ = http.NewRequest("POST", "", nil)
+ req, _ := http.NewRequest("POST", "", nil)
+ c.setRequest(req)
templ := template.Must(template.New("t").Parse(`Hello {{.name}}`))
router.SetHTMLTemplate(templ)
@@ -1244,7 +1267,8 @@ func TestContextNegotiationWithHTML(t *testing.T) {
func TestContextNegotiationNotSupport(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
- c.Request, _ = http.NewRequest("POST", "", nil)
+ req, _ := http.NewRequest("POST", "", nil)
+ c.setRequest(req)
c.Negotiate(http.StatusOK, Negotiate{
Offered: []string{MIMEPOSTForm},
@@ -1257,7 +1281,8 @@ func TestContextNegotiationNotSupport(t *testing.T) {
func TestContextNegotiationFormat(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
- c.Request, _ = http.NewRequest("POST", "", nil)
+ req, _ := http.NewRequest("POST", "", nil)
+ c.setRequest(req)
assert.Panics(t, func() { c.NegotiateFormat() })
assert.Equal(t, MIMEJSON, c.NegotiateFormat(MIMEJSON, MIMEXML))
@@ -1266,8 +1291,9 @@ func TestContextNegotiationFormat(t *testing.T) {
func TestContextNegotiationFormatWithAccept(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
- c.Request, _ = http.NewRequest("POST", "/", nil)
- c.Request.Header.Add("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9;q=0.8")
+ req, _ := http.NewRequest("POST", "/", nil)
+ req.Header.Add("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9;q=0.8")
+ c.setRequest(req)
assert.Equal(t, MIMEXML, c.NegotiateFormat(MIMEJSON, MIMEXML))
assert.Equal(t, MIMEHTML, c.NegotiateFormat(MIMEXML, MIMEHTML))
@@ -1276,8 +1302,9 @@ func TestContextNegotiationFormatWithAccept(t *testing.T) {
func TestContextNegotiationFormatWithWildcardAccept(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
- c.Request, _ = http.NewRequest("POST", "/", nil)
- c.Request.Header.Add("Accept", "*/*")
+ req, _ := http.NewRequest("POST", "/", nil)
+ req.Header.Add("Accept", "*/*")
+ c.setRequest(req)
assert.Equal(t, c.NegotiateFormat("*/*"), "*/*")
assert.Equal(t, c.NegotiateFormat("text/*"), "text/*")
@@ -1287,8 +1314,9 @@ func TestContextNegotiationFormatWithWildcardAccept(t *testing.T) {
assert.Equal(t, c.NegotiateFormat(MIMEHTML), MIMEHTML)
c, _ = CreateTestContext(httptest.NewRecorder())
- c.Request, _ = http.NewRequest("POST", "/", nil)
- c.Request.Header.Add("Accept", "text/*")
+ req, _ = http.NewRequest("POST", "/", nil)
+ req.Header.Add("Accept", "text/*")
+ c.setRequest(req)
assert.Equal(t, c.NegotiateFormat("*/*"), "*/*")
assert.Equal(t, c.NegotiateFormat("text/*"), "text/*")
@@ -1300,8 +1328,9 @@ func TestContextNegotiationFormatWithWildcardAccept(t *testing.T) {
func TestContextNegotiationFormatCustom(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
- c.Request, _ = http.NewRequest("POST", "/", nil)
- c.Request.Header.Add("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9;q=0.8")
+ req, _ := http.NewRequest("POST", "/", nil)
+ req.Header.Add("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9;q=0.8")
+ c.setRequest(req)
c.Accepted = nil
c.SetAccepted(MIMEJSON, MIMEXML)
@@ -1433,7 +1462,8 @@ func TestContextAbortWithError(t *testing.T) {
func TestContextClientIP(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
- c.Request, _ = http.NewRequest("POST", "/", nil)
+ req, _ := http.NewRequest("POST", "/", nil)
+ c.setRequest(req)
c.engine.trustedCIDRs, _ = c.engine.prepareTrustedCIDRs()
resetContextForClientIPTests(c)
@@ -1569,16 +1599,18 @@ func resetContextForClientIPTests(c *Context) {
func TestContextContentType(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
- c.Request, _ = http.NewRequest("POST", "/", nil)
- c.Request.Header.Set("Content-Type", "application/json; charset=utf-8")
+ req, _ := http.NewRequest("POST", "/", nil)
+ req.Header.Set("Content-Type", "application/json; charset=utf-8")
+ c.setRequest(req)
assert.Equal(t, "application/json", c.ContentType())
}
func TestContextAutoBindJSON(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
- c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}"))
- c.Request.Header.Add("Content-Type", MIMEJSON)
+ req, _ := http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}"))
+ req.Header.Add("Content-Type", MIMEJSON)
+ c.setRequest(req)
var obj struct {
Foo string `json:"foo"`
@@ -1594,8 +1626,9 @@ func TestContextBindWithJSON(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
- c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}"))
- c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type
+ req, _ := http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}"))
+ req.Header.Add("Content-Type", MIMEXML) // set fake content-type
+ c.setRequest(req)
var obj struct {
Foo string `json:"foo"`
@@ -1611,12 +1644,13 @@ func TestContextBindWithXML(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
- c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString(`
+ req, _ := http.NewRequest("POST", "/", bytes.NewBufferString(`
FOO
BAR
`))
- c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type
+ req.Header.Add("Content-Type", MIMEXML) // set fake content-type
+ c.setRequest(req)
var obj struct {
Foo string `xml:"foo"`
@@ -1632,10 +1666,11 @@ func TestContextBindHeader(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
- c.Request, _ = http.NewRequest("POST", "/", nil)
- c.Request.Header.Add("rate", "8000")
- c.Request.Header.Add("domain", "music")
- c.Request.Header.Add("limit", "1000")
+ req, _ := http.NewRequest("POST", "/", nil)
+ req.Header.Add("rate", "8000")
+ req.Header.Add("domain", "music")
+ req.Header.Add("limit", "1000")
+ c.setRequest(req)
var testHeader struct {
Rate int `header:"Rate"`
@@ -1654,7 +1689,8 @@ func TestContextBindWithQuery(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
- c.Request, _ = http.NewRequest("POST", "/?foo=bar&bar=foo", bytes.NewBufferString("foo=unused"))
+ req, _ := http.NewRequest("POST", "/?foo=bar&bar=foo", bytes.NewBufferString("foo=unused"))
+ c.setRequest(req)
var obj struct {
Foo string `form:"foo"`
@@ -1670,8 +1706,9 @@ func TestContextBindWithYAML(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
- c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("foo: bar\nbar: foo"))
- c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type
+ req, _ := http.NewRequest("POST", "/", bytes.NewBufferString("foo: bar\nbar: foo"))
+ req.Header.Add("Content-Type", MIMEXML) // set fake content-type
+ c.setRequest(req)
var obj struct {
Foo string `yaml:"foo"`
@@ -1687,8 +1724,9 @@ func TestContextBindWithTOML(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
- c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("foo = 'bar'\nbar = 'foo'"))
- c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type
+ req, _ := http.NewRequest("POST", "/", bytes.NewBufferString("foo = 'bar'\nbar = 'foo'"))
+ req.Header.Add("Content-Type", MIMEXML) // set fake content-type
+ c.setRequest(req)
var obj struct {
Foo string `toml:"foo"`
@@ -1704,8 +1742,9 @@ func TestContextBadAutoBind(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
- c.Request, _ = http.NewRequest("POST", "http://example.com", bytes.NewBufferString("\"foo\":\"bar\", \"bar\":\"foo\"}"))
- c.Request.Header.Add("Content-Type", MIMEJSON)
+ req, _ := http.NewRequest("POST", "http://example.com", bytes.NewBufferString("\"foo\":\"bar\", \"bar\":\"foo\"}"))
+ req.Header.Add("Content-Type", MIMEJSON)
+ c.setRequest(req)
var obj struct {
Foo string `json:"foo"`
Bar string `json:"bar"`
@@ -1723,8 +1762,9 @@ func TestContextBadAutoBind(t *testing.T) {
func TestContextAutoShouldBindJSON(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
- c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}"))
- c.Request.Header.Add("Content-Type", MIMEJSON)
+ req, _ := http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}"))
+ req.Header.Add("Content-Type", MIMEJSON)
+ c.setRequest(req)
var obj struct {
Foo string `json:"foo"`
@@ -1740,8 +1780,9 @@ func TestContextShouldBindWithJSON(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
- c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}"))
- c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type
+ req, _ := http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}"))
+ req.Header.Add("Content-Type", MIMEXML) // set fake content-type
+ c.setRequest(req)
var obj struct {
Foo string `json:"foo"`
@@ -1757,12 +1798,13 @@ func TestContextShouldBindWithXML(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
- c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString(`
+ req, _ := http.NewRequest("POST", "/", bytes.NewBufferString(`
FOO
BAR
`))
- c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type
+ req.Header.Add("Content-Type", MIMEXML) // set fake content-type
+ c.setRequest(req)
var obj struct {
Foo string `xml:"foo"`
@@ -1778,10 +1820,11 @@ func TestContextShouldBindHeader(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
- c.Request, _ = http.NewRequest("POST", "/", nil)
- c.Request.Header.Add("rate", "8000")
- c.Request.Header.Add("domain", "music")
- c.Request.Header.Add("limit", "1000")
+ req, _ := http.NewRequest("POST", "/", nil)
+ req.Header.Add("rate", "8000")
+ req.Header.Add("domain", "music")
+ req.Header.Add("limit", "1000")
+ c.setRequest(req)
var testHeader struct {
Rate int `header:"Rate"`
@@ -1800,7 +1843,8 @@ func TestContextShouldBindWithQuery(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
- c.Request, _ = http.NewRequest("POST", "/?foo=bar&bar=foo&Foo=bar1&Bar=foo1", bytes.NewBufferString("foo=unused"))
+ req, _ := http.NewRequest("POST", "/?foo=bar&bar=foo&Foo=bar1&Bar=foo1", bytes.NewBufferString("foo=unused"))
+ c.setRequest(req)
var obj struct {
Foo string `form:"foo"`
@@ -1820,8 +1864,9 @@ func TestContextShouldBindWithYAML(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
- c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("foo: bar\nbar: foo"))
- c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type
+ req, _ := http.NewRequest("POST", "/", bytes.NewBufferString("foo: bar\nbar: foo"))
+ req.Header.Add("Content-Type", MIMEXML) // set fake content-type
+ c.setRequest(req)
var obj struct {
Foo string `yaml:"foo"`
@@ -1837,8 +1882,9 @@ func TestContextShouldBindWithTOML(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
- c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("foo='bar'\nbar= 'foo'"))
- c.Request.Header.Add("Content-Type", MIMETOML) // set fake content-type
+ req, _ := http.NewRequest("POST", "/", bytes.NewBufferString("foo='bar'\nbar= 'foo'"))
+ req.Header.Add("Content-Type", MIMETOML) // set fake content-type
+ c.setRequest(req)
var obj struct {
Foo string `toml:"foo"`
@@ -1854,8 +1900,9 @@ func TestContextBadAutoShouldBind(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
- c.Request, _ = http.NewRequest("POST", "http://example.com", bytes.NewBufferString("\"foo\":\"bar\", \"bar\":\"foo\"}"))
- c.Request.Header.Add("Content-Type", MIMEJSON)
+ req, _ := http.NewRequest("POST", "http://example.com", bytes.NewBufferString("\"foo\":\"bar\", \"bar\":\"foo\"}"))
+ req.Header.Add("Content-Type", MIMEJSON)
+ c.setRequest(req)
var obj struct {
Foo string `json:"foo"`
Bar string `json:"bar"`
@@ -1917,9 +1964,10 @@ func TestContextShouldBindBodyWith(t *testing.T) {
{
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
- c.Request, _ = http.NewRequest(
+ req, _ := http.NewRequest(
"POST", "http://example.com", bytes.NewBufferString(tt.bodyA),
)
+ c.setRequest(req)
// When it binds to typeA and typeB, it finds the body is
// not typeB but typeA.
objA := typeA{}
@@ -1935,9 +1983,10 @@ func TestContextShouldBindBodyWith(t *testing.T) {
// not typeA but typeB.
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
- c.Request, _ = http.NewRequest(
+ req, _ := http.NewRequest(
"POST", "http://example.com", bytes.NewBufferString(tt.bodyB),
)
+ c.setRequest(req)
objA := typeA{}
assert.Error(t, c.ShouldBindBodyWith(&objA, tt.bindingA))
assert.NotEqual(t, typeA{"FOO"}, objA)
@@ -1950,7 +1999,8 @@ func TestContextShouldBindBodyWith(t *testing.T) {
func TestContextGolangContext(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
- c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}"))
+ req, _ := http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}"))
+ c.setRequest(req)
assert.NoError(t, c.Err())
assert.Nil(t, c.Done())
ti, ok := c.Deadline()
@@ -1968,29 +2018,32 @@ func TestContextGolangContext(t *testing.T) {
func TestWebsocketsRequired(t *testing.T) {
// Example request from spec: https://tools.ietf.org/html/rfc6455#section-1.2
c, _ := CreateTestContext(httptest.NewRecorder())
- c.Request, _ = http.NewRequest("GET", "/chat", nil)
- c.Request.Header.Set("Host", "server.example.com")
- c.Request.Header.Set("Upgrade", "websocket")
- c.Request.Header.Set("Connection", "Upgrade")
- c.Request.Header.Set("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
- c.Request.Header.Set("Origin", "http://example.com")
- c.Request.Header.Set("Sec-WebSocket-Protocol", "chat, superchat")
- c.Request.Header.Set("Sec-WebSocket-Version", "13")
+ req, _ := http.NewRequest("GET", "/chat", nil)
+ req.Header.Set("Host", "server.example.com")
+ req.Header.Set("Upgrade", "websocket")
+ req.Header.Set("Connection", "Upgrade")
+ req.Header.Set("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
+ req.Header.Set("Origin", "http://example.com")
+ req.Header.Set("Sec-WebSocket-Protocol", "chat, superchat")
+ req.Header.Set("Sec-WebSocket-Version", "13")
+ c.setRequest(req)
assert.True(t, c.IsWebsocket())
// Normal request, no websocket required.
c, _ = CreateTestContext(httptest.NewRecorder())
- c.Request, _ = http.NewRequest("GET", "/chat", nil)
- c.Request.Header.Set("Host", "server.example.com")
+ req, _ = http.NewRequest("GET", "/chat", nil)
+ req.Header.Set("Host", "server.example.com")
+ c.setRequest(req)
assert.False(t, c.IsWebsocket())
}
func TestGetRequestHeaderValue(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
- c.Request, _ = http.NewRequest("GET", "/chat", nil)
- c.Request.Header.Set("Gin-Version", "1.0.0")
+ req, _ := http.NewRequest("GET", "/chat", nil)
+ req.Header.Set("Gin-Version", "1.0.0")
+ c.setRequest(req)
assert.Equal(t, "1.0.0", c.GetHeader("Gin-Version"))
assert.Empty(t, c.GetHeader("Connection"))
@@ -1999,8 +2052,9 @@ func TestGetRequestHeaderValue(t *testing.T) {
func TestContextGetRawData(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
body := bytes.NewBufferString("Fetch binary post data")
- c.Request, _ = http.NewRequest("POST", "/", body)
- c.Request.Header.Add("Content-Type", MIMEPOSTForm)
+ req, _ := http.NewRequest("POST", "/", body)
+ req.Header.Add("Content-Type", MIMEPOSTForm)
+ c.setRequest(req)
data, err := c.GetRawData()
assert.Nil(t, err)
@@ -2148,8 +2202,9 @@ func TestContextWithKeysMutex(t *testing.T) {
func TestRemoteIPFail(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
- c.Request, _ = http.NewRequest("POST", "/", nil)
- c.Request.RemoteAddr = "[:::]:80"
+ req, _ := http.NewRequest("POST", "/", nil)
+ req.RemoteAddr = "[:::]:80"
+ c.setRequest(req)
ip := net.ParseIP(c.RemoteIP())
trust := c.engine.isTrustedProxy(ip)
assert.Nil(t, ip)
@@ -2169,11 +2224,11 @@ func TestContextWithFallbackDeadlineFromRequestContext(t *testing.T) {
// enable ContextWithFallback feature flag
c2.engine.ContextWithFallback = true
- c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil)
d := time.Now().Add(time.Second)
ctx, cancel := context.WithDeadline(context.Background(), d)
defer cancel()
- c2.Request = c2.Request.WithContext(ctx)
+ req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/", nil)
+ c2.setRequest(req)
deadline, ok = c2.Deadline()
assert.Equal(t, d, deadline)
assert.True(t, ok)
@@ -2190,11 +2245,13 @@ func TestContextWithFallbackDoneFromRequestContext(t *testing.T) {
// enable ContextWithFallback feature flag
c2.engine.ContextWithFallback = true
- c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil)
ctx, cancel := context.WithCancel(context.Background())
- c2.Request = c2.Request.WithContext(ctx)
+ req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/", nil)
+ c2.setRequest(req)
cancel()
- assert.NotNil(t, <-c2.Done())
+ if assert.NotNil(t, c2.Done()) {
+ assert.NotNil(t, <-c2.Done())
+ }
}
func TestContextWithFallbackErrFromRequestContext(t *testing.T) {
@@ -2208,9 +2265,9 @@ func TestContextWithFallbackErrFromRequestContext(t *testing.T) {
// enable ContextWithFallback feature flag
c2.engine.ContextWithFallback = true
- c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil)
ctx, cancel := context.WithCancel(context.Background())
- c2.Request = c2.Request.WithContext(ctx)
+ req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/", nil)
+ c2.setRequest(req)
cancel()
assert.EqualError(t, c2.Err(), context.Canceled.Error())
@@ -2231,8 +2288,9 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c.engine.ContextWithFallback = true
- c.Request, _ = http.NewRequest("POST", "/", nil)
- c.Request = c.Request.WithContext(context.WithValue(context.TODO(), key, "value"))
+ ctx := context.WithValue(context.TODO(), key, "value")
+ req, _ := http.NewRequestWithContext(ctx, "POST", "/", nil)
+ c.setRequest(req)
return c, key
},
value: "value",
@@ -2243,8 +2301,9 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c.engine.ContextWithFallback = true
- c.Request, _ = http.NewRequest("POST", "/", nil)
- c.Request = c.Request.WithContext(context.WithValue(context.TODO(), contextKey("key"), "value"))
+ ctx := context.WithValue(context.TODO(), contextKey("key"), "value")
+ req, _ := http.NewRequestWithContext(ctx, "POST", "/", nil)
+ c.setRequest(req)
return c, contextKey("key")
},
value: "value",
@@ -2255,7 +2314,7 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c.engine.ContextWithFallback = true
- c.Request = nil
+ c.setRequest(nil)
return c, "key"
},
value: nil,
@@ -2266,7 +2325,8 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c.engine.ContextWithFallback = true
- c.Request, _ = http.NewRequest("POST", "/", nil)
+ req, _ := http.NewRequest("POST", "/", nil)
+ c.setRequest(req)
return c, "key"
},
value: nil,
@@ -2280,6 +2340,16 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
}
}
+func TestContextWithFallbackModifiesRequestContext(t *testing.T) {
+ c, _ := CreateTestContext(httptest.NewRecorder())
+ // enable ContextWithFallback feature flag
+ c.engine.ContextWithFallback = true
+ req, _ := http.NewRequest(http.MethodGet, "/", nil)
+ c.setRequest(req)
+
+ assert.Equal(t, c, c.Request.Context())
+}
+
func TestContextCopyShouldNotCancel(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
diff --git a/gin.go b/gin.go
index f9324299..c81cda6b 100644
--- a/gin.go
+++ b/gin.go
@@ -566,7 +566,7 @@ func (engine *Engine) RunListener(listener net.Listener) (err error) {
func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) {
c := engine.pool.Get().(*Context)
c.writermem.reset(w)
- c.Request = req
+ c.setRequest(req)
c.reset()
engine.handleHTTPRequest(c)