Modify c.Request.Context() when using ContextWithFallback

This commit is contained in:
Ben Visness 2022-10-18 17:11:02 -05:00
parent 33ab0fc155
commit 4a3008f3a0
3 changed files with 209 additions and 121 deletions

View File

@ -5,6 +5,7 @@
package gin package gin
import ( import (
"context"
"errors" "errors"
"io" "io"
"io/ioutil" "io/ioutil"
@ -50,9 +51,10 @@ const abortIndex int8 = math.MaxInt8 >> 1
// Context is the most important part of gin. It allows us to pass variables between middleware, // Context is the most important part of gin. It allows us to pass variables between middleware,
// manage the flow, validate the JSON of a request and render a JSON response for example. // manage the flow, validate the JSON of a request and render a JSON response for example.
type Context struct { type Context struct {
writermem responseWriter writermem responseWriter
Request *http.Request Request *http.Request
Writer ResponseWriter Writer ResponseWriter
requestContext context.Context
Params Params Params Params
handlers HandlersChain handlers HandlersChain
@ -108,6 +110,22 @@ func (c *Context) reset() {
*c.skippedNodes = (*c.skippedNodes)[:0] *c.skippedNodes = (*c.skippedNodes)[:0]
} }
// Initializes c.Request and c.requestContext according to the ContextWithFallback feature flag
func (c *Context) setRequest(req *http.Request) {
if req == nil {
c.Request = nil
c.requestContext = nil
return
}
c.requestContext = req.Context()
if c.engine.ContextWithFallback {
c.Request = req.WithContext(c)
} else {
c.Request = req
}
}
// Copy returns a copy of the current context that can be safely used outside the request's scope. // Copy returns a copy of the current context that can be safely used outside the request's scope.
// This has to be used when the context has to be passed to a goroutine. // This has to be used when the context has to be passed to a goroutine.
func (c *Context) Copy() *Context { func (c *Context) Copy() *Context {
@ -1173,26 +1191,26 @@ func (c *Context) SetAccepted(formats ...string) {
// Deadline returns that there is no deadline (ok==false) when c.Request has no Context. // Deadline returns that there is no deadline (ok==false) when c.Request has no Context.
func (c *Context) Deadline() (deadline time.Time, ok bool) { func (c *Context) Deadline() (deadline time.Time, ok bool) {
if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil { if !c.engine.ContextWithFallback || c.Request == nil || c.requestContext == nil {
return return
} }
return c.Request.Context().Deadline() return c.requestContext.Deadline()
} }
// Done returns nil (chan which will wait forever) when c.Request has no Context. // Done returns nil (chan which will wait forever) when c.Request has no Context.
func (c *Context) Done() <-chan struct{} { func (c *Context) Done() <-chan struct{} {
if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil { if !c.engine.ContextWithFallback || c.Request == nil || c.requestContext == nil {
return nil return nil
} }
return c.Request.Context().Done() return c.requestContext.Done()
} }
// Err returns nil when c.Request has no Context. // Err returns nil when c.Request has no Context.
func (c *Context) Err() error { func (c *Context) Err() error {
if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil { if !c.engine.ContextWithFallback || c.Request == nil || c.requestContext == nil {
return nil return nil
} }
return c.Request.Context().Err() return c.requestContext.Err()
} }
// Value returns the value associated with this context for key, or nil // Value returns the value associated with this context for key, or nil
@ -1210,8 +1228,8 @@ func (c *Context) Value(key any) any {
return val return val
} }
} }
if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil { if !c.engine.ContextWithFallback || c.Request == nil || c.requestContext == nil {
return nil return nil
} }
return c.Request.Context().Value(key) return c.requestContext.Value(key)
} }

View File

@ -78,8 +78,9 @@ func TestContextFormFile(t *testing.T) {
} }
mw.Close() mw.Close()
c, _ := CreateTestContext(httptest.NewRecorder()) c, _ := CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest("POST", "/", buf) req, _ := http.NewRequest("POST", "/", buf)
c.Request.Header.Set("Content-Type", mw.FormDataContentType()) req.Header.Set("Content-Type", mw.FormDataContentType())
c.setRequest(req)
f, err := c.FormFile("file") f, err := c.FormFile("file")
if assert.NoError(t, err) { if assert.NoError(t, err) {
assert.Equal(t, "test", f.Filename) assert.Equal(t, "test", f.Filename)
@ -99,8 +100,9 @@ func TestContextMultipartForm(t *testing.T) {
} }
mw.Close() mw.Close()
c, _ := CreateTestContext(httptest.NewRecorder()) c, _ := CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest("POST", "/", buf) req, _ := http.NewRequest("POST", "/", buf)
c.Request.Header.Set("Content-Type", mw.FormDataContentType()) req.Header.Set("Content-Type", mw.FormDataContentType())
c.setRequest(req)
f, err := c.MultipartForm() f, err := c.MultipartForm()
if assert.NoError(t, err) { if assert.NoError(t, err) {
assert.NotNil(t, f) assert.NotNil(t, f)
@ -115,8 +117,9 @@ func TestSaveUploadedOpenFailed(t *testing.T) {
mw.Close() mw.Close()
c, _ := CreateTestContext(httptest.NewRecorder()) c, _ := CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest("POST", "/", buf) req, _ := http.NewRequest("POST", "/", buf)
c.Request.Header.Set("Content-Type", mw.FormDataContentType()) req.Header.Set("Content-Type", mw.FormDataContentType())
c.setRequest(req)
f := &multipart.FileHeader{ f := &multipart.FileHeader{
Filename: "file", Filename: "file",
@ -134,8 +137,9 @@ func TestSaveUploadedCreateFailed(t *testing.T) {
} }
mw.Close() mw.Close()
c, _ := CreateTestContext(httptest.NewRecorder()) c, _ := CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest("POST", "/", buf) req, _ := http.NewRequest("POST", "/", buf)
c.Request.Header.Set("Content-Type", mw.FormDataContentType()) req.Header.Set("Content-Type", mw.FormDataContentType())
c.setRequest(req)
f, err := c.FormFile("file") f, err := c.FormFile("file")
if assert.NoError(t, err) { if assert.NoError(t, err) {
assert.Equal(t, "test", f.Filename) assert.Equal(t, "test", f.Filename)
@ -318,7 +322,8 @@ func TestContextGetStringMapStringSlice(t *testing.T) {
func TestContextCopy(t *testing.T) { func TestContextCopy(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder()) c, _ := CreateTestContext(httptest.NewRecorder())
c.index = 2 c.index = 2
c.Request, _ = http.NewRequest("POST", "/hola", nil) req, _ := http.NewRequest("POST", "/hola", nil)
c.setRequest(req)
c.handlers = HandlersChain{func(c *Context) {}} c.handlers = HandlersChain{func(c *Context) {}}
c.Params = Params{Param{Key: "foo", Value: "bar"}} c.Params = Params{Param{Key: "foo", Value: "bar"}}
c.Set("foo", "bar") c.Set("foo", "bar")
@ -373,7 +378,8 @@ func TestContextHandler(t *testing.T) {
func TestContextQuery(t *testing.T) { func TestContextQuery(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder()) c, _ := CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest("GET", "http://example.com/?foo=bar&page=10&id=", nil) req, _ := http.NewRequest("GET", "http://example.com/?foo=bar&page=10&id=", nil)
c.setRequest(req)
value, ok := c.GetQuery("foo") value, ok := c.GetQuery("foo")
assert.True(t, ok) assert.True(t, ok)
@ -424,9 +430,10 @@ func TestContextDefaultQueryOnEmptyRequest(t *testing.T) {
func TestContextQueryAndPostForm(t *testing.T) { func TestContextQueryAndPostForm(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder()) c, _ := CreateTestContext(httptest.NewRecorder())
body := bytes.NewBufferString("foo=bar&page=11&both=&foo=second") body := bytes.NewBufferString("foo=bar&page=11&both=&foo=second")
c.Request, _ = http.NewRequest("POST", req, _ := http.NewRequest("POST",
"/?both=GET&id=main&id=omit&array[]=first&array[]=second&ids[a]=hi&ids[b]=3.14", body) "/?both=GET&id=main&id=omit&array[]=first&array[]=second&ids[a]=hi&ids[b]=3.14", body)
c.Request.Header.Add("Content-Type", MIMEPOSTForm) req.Header.Add("Content-Type", MIMEPOSTForm)
c.setRequest(req)
assert.Equal(t, "bar", c.DefaultPostForm("foo", "none")) assert.Equal(t, "bar", c.DefaultPostForm("foo", "none"))
assert.Equal(t, "bar", c.PostForm("foo")) assert.Equal(t, "bar", c.PostForm("foo"))
@ -521,7 +528,7 @@ func TestContextQueryAndPostForm(t *testing.T) {
func TestContextPostFormMultipart(t *testing.T) { func TestContextPostFormMultipart(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder()) c, _ := CreateTestContext(httptest.NewRecorder())
c.Request = createMultipartRequest() c.setRequest(createMultipartRequest())
var obj struct { var obj struct {
Foo string `form:"foo"` Foo string `form:"foo"`
@ -627,8 +634,9 @@ func TestContextSetCookiePathEmpty(t *testing.T) {
func TestContextGetCookie(t *testing.T) { func TestContextGetCookie(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder()) c, _ := CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest("GET", "/get", nil) req, _ := http.NewRequest("GET", "/get", nil)
c.Request.Header.Set("Cookie", "user=gin") req.Header.Set("Cookie", "user=gin")
c.setRequest(req)
cookie, _ := c.Cookie("user") cookie, _ := c.Cookie("user")
assert.Equal(t, "gin", cookie) assert.Equal(t, "gin", cookie)
@ -683,7 +691,8 @@ func TestContextRenderJSON(t *testing.T) {
func TestContextRenderJSONP(t *testing.T) { func TestContextRenderJSONP(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)
c.Request, _ = http.NewRequest("GET", "http://example.com/?callback=x", nil) req, _ := http.NewRequest("GET", "http://example.com/?callback=x", nil)
c.setRequest(req)
c.JSONP(http.StatusCreated, H{"foo": "bar"}) c.JSONP(http.StatusCreated, H{"foo": "bar"})
@ -697,7 +706,8 @@ func TestContextRenderJSONP(t *testing.T) {
func TestContextRenderJSONPWithoutCallback(t *testing.T) { func TestContextRenderJSONPWithoutCallback(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)
c.Request, _ = http.NewRequest("GET", "http://example.com", nil) req, _ := http.NewRequest("GET", "http://example.com", nil)
c.setRequest(req)
c.JSONP(http.StatusCreated, H{"foo": "bar"}) c.JSONP(http.StatusCreated, H{"foo": "bar"})
@ -996,7 +1006,8 @@ func TestContextRenderFile(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)
c.Request, _ = http.NewRequest("GET", "/", nil) req, _ := http.NewRequest("GET", "/", nil)
c.setRequest(req)
c.File("./gin.go") c.File("./gin.go")
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
@ -1010,7 +1021,8 @@ func TestContextRenderFileFromFS(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)
c.Request, _ = http.NewRequest("GET", "/some/path", nil) req, _ := http.NewRequest("GET", "/some/path", nil)
c.setRequest(req)
c.FileFromFS("./gin.go", Dir(".", false)) c.FileFromFS("./gin.go", Dir(".", false))
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
@ -1026,7 +1038,8 @@ func TestContextRenderAttachment(t *testing.T) {
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)
newFilename := "new_filename.go" newFilename := "new_filename.go"
c.Request, _ = http.NewRequest("GET", "/", nil) req, _ := http.NewRequest("GET", "/", nil)
c.setRequest(req)
c.FileAttachment("./gin.go", newFilename) c.FileAttachment("./gin.go", newFilename)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@ -1039,7 +1052,8 @@ func TestContextRenderUTF8Attachment(t *testing.T) {
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)
newFilename := "new🧡_filename.go" newFilename := "new🧡_filename.go"
c.Request, _ = http.NewRequest("GET", "/", nil) req, _ := http.NewRequest("GET", "/", nil)
c.setRequest(req)
c.FileAttachment("./gin.go", newFilename) c.FileAttachment("./gin.go", newFilename)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@ -1118,7 +1132,8 @@ func TestContextRenderRedirectWithRelativePath(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)
c.Request, _ = http.NewRequest("POST", "http://example.com", nil) req, _ := http.NewRequest("POST", "http://example.com", nil)
c.setRequest(req)
assert.Panics(t, func() { c.Redirect(299, "/new_path") }) assert.Panics(t, func() { c.Redirect(299, "/new_path") })
assert.Panics(t, func() { c.Redirect(309, "/new_path") }) assert.Panics(t, func() { c.Redirect(309, "/new_path") })
@ -1132,7 +1147,8 @@ func TestContextRenderRedirectWithAbsolutePath(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)
c.Request, _ = http.NewRequest("POST", "http://example.com", nil) req, _ := http.NewRequest("POST", "http://example.com", nil)
c.setRequest(req)
c.Redirect(http.StatusFound, "http://google.com") c.Redirect(http.StatusFound, "http://google.com")
c.Writer.WriteHeaderNow() c.Writer.WriteHeaderNow()
@ -1144,7 +1160,8 @@ func TestContextRenderRedirectWith201(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)
c.Request, _ = http.NewRequest("POST", "http://example.com", nil) req, _ := http.NewRequest("POST", "http://example.com", nil)
c.setRequest(req)
c.Redirect(http.StatusCreated, "/resource") c.Redirect(http.StatusCreated, "/resource")
c.Writer.WriteHeaderNow() c.Writer.WriteHeaderNow()
@ -1154,7 +1171,8 @@ func TestContextRenderRedirectWith201(t *testing.T) {
func TestContextRenderRedirectAll(t *testing.T) { func TestContextRenderRedirectAll(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder()) c, _ := CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest("POST", "http://example.com", nil) req, _ := http.NewRequest("POST", "http://example.com", nil)
c.setRequest(req)
assert.Panics(t, func() { c.Redirect(http.StatusOK, "/resource") }) assert.Panics(t, func() { c.Redirect(http.StatusOK, "/resource") })
assert.Panics(t, func() { c.Redirect(http.StatusAccepted, "/resource") }) assert.Panics(t, func() { c.Redirect(http.StatusAccepted, "/resource") })
assert.Panics(t, func() { c.Redirect(299, "/resource") }) assert.Panics(t, func() { c.Redirect(299, "/resource") })
@ -1166,7 +1184,8 @@ func TestContextRenderRedirectAll(t *testing.T) {
func TestContextNegotiationWithJSON(t *testing.T) { func TestContextNegotiationWithJSON(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)
c.Request, _ = http.NewRequest("POST", "", nil) req, _ := http.NewRequest("POST", "", nil)
c.setRequest(req)
c.Negotiate(http.StatusOK, Negotiate{ c.Negotiate(http.StatusOK, Negotiate{
Offered: []string{MIMEJSON, MIMEXML, MIMEYAML}, Offered: []string{MIMEJSON, MIMEXML, MIMEYAML},
@ -1181,7 +1200,8 @@ func TestContextNegotiationWithJSON(t *testing.T) {
func TestContextNegotiationWithXML(t *testing.T) { func TestContextNegotiationWithXML(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)
c.Request, _ = http.NewRequest("POST", "", nil) req, _ := http.NewRequest("POST", "", nil)
c.setRequest(req)
c.Negotiate(http.StatusOK, Negotiate{ c.Negotiate(http.StatusOK, Negotiate{
Offered: []string{MIMEXML, MIMEJSON, MIMEYAML}, Offered: []string{MIMEXML, MIMEJSON, MIMEYAML},
@ -1196,7 +1216,8 @@ func TestContextNegotiationWithXML(t *testing.T) {
func TestContextNegotiationWithYAML(t *testing.T) { func TestContextNegotiationWithYAML(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)
c.Request, _ = http.NewRequest("POST", "", nil) req, _ := http.NewRequest("POST", "", nil)
c.setRequest(req)
c.Negotiate(http.StatusOK, Negotiate{ c.Negotiate(http.StatusOK, Negotiate{
Offered: []string{MIMEYAML, MIMEXML, MIMEJSON, MIMETOML}, Offered: []string{MIMEYAML, MIMEXML, MIMEJSON, MIMETOML},
@ -1211,7 +1232,8 @@ func TestContextNegotiationWithYAML(t *testing.T) {
func TestContextNegotiationWithTOML(t *testing.T) { func TestContextNegotiationWithTOML(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)
c.Request, _ = http.NewRequest("POST", "", nil) req, _ := http.NewRequest("POST", "", nil)
c.setRequest(req)
c.Negotiate(http.StatusOK, Negotiate{ c.Negotiate(http.StatusOK, Negotiate{
Offered: []string{MIMETOML, MIMEXML, MIMEJSON, MIMEYAML}, Offered: []string{MIMETOML, MIMEXML, MIMEJSON, MIMEYAML},
@ -1226,7 +1248,8 @@ func TestContextNegotiationWithTOML(t *testing.T) {
func TestContextNegotiationWithHTML(t *testing.T) { func TestContextNegotiationWithHTML(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, router := CreateTestContext(w) c, router := CreateTestContext(w)
c.Request, _ = http.NewRequest("POST", "", nil) req, _ := http.NewRequest("POST", "", nil)
c.setRequest(req)
templ := template.Must(template.New("t").Parse(`Hello {{.name}}`)) templ := template.Must(template.New("t").Parse(`Hello {{.name}}`))
router.SetHTMLTemplate(templ) router.SetHTMLTemplate(templ)
@ -1244,7 +1267,8 @@ func TestContextNegotiationWithHTML(t *testing.T) {
func TestContextNegotiationNotSupport(t *testing.T) { func TestContextNegotiationNotSupport(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)
c.Request, _ = http.NewRequest("POST", "", nil) req, _ := http.NewRequest("POST", "", nil)
c.setRequest(req)
c.Negotiate(http.StatusOK, Negotiate{ c.Negotiate(http.StatusOK, Negotiate{
Offered: []string{MIMEPOSTForm}, Offered: []string{MIMEPOSTForm},
@ -1257,7 +1281,8 @@ func TestContextNegotiationNotSupport(t *testing.T) {
func TestContextNegotiationFormat(t *testing.T) { func TestContextNegotiationFormat(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder()) c, _ := CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest("POST", "", nil) req, _ := http.NewRequest("POST", "", nil)
c.setRequest(req)
assert.Panics(t, func() { c.NegotiateFormat() }) assert.Panics(t, func() { c.NegotiateFormat() })
assert.Equal(t, MIMEJSON, c.NegotiateFormat(MIMEJSON, MIMEXML)) assert.Equal(t, MIMEJSON, c.NegotiateFormat(MIMEJSON, MIMEXML))
@ -1266,8 +1291,9 @@ func TestContextNegotiationFormat(t *testing.T) {
func TestContextNegotiationFormatWithAccept(t *testing.T) { func TestContextNegotiationFormatWithAccept(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder()) c, _ := CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest("POST", "/", nil) req, _ := http.NewRequest("POST", "/", nil)
c.Request.Header.Add("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9;q=0.8") req.Header.Add("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9;q=0.8")
c.setRequest(req)
assert.Equal(t, MIMEXML, c.NegotiateFormat(MIMEJSON, MIMEXML)) assert.Equal(t, MIMEXML, c.NegotiateFormat(MIMEJSON, MIMEXML))
assert.Equal(t, MIMEHTML, c.NegotiateFormat(MIMEXML, MIMEHTML)) assert.Equal(t, MIMEHTML, c.NegotiateFormat(MIMEXML, MIMEHTML))
@ -1276,8 +1302,9 @@ func TestContextNegotiationFormatWithAccept(t *testing.T) {
func TestContextNegotiationFormatWithWildcardAccept(t *testing.T) { func TestContextNegotiationFormatWithWildcardAccept(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder()) c, _ := CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest("POST", "/", nil) req, _ := http.NewRequest("POST", "/", nil)
c.Request.Header.Add("Accept", "*/*") req.Header.Add("Accept", "*/*")
c.setRequest(req)
assert.Equal(t, c.NegotiateFormat("*/*"), "*/*") assert.Equal(t, c.NegotiateFormat("*/*"), "*/*")
assert.Equal(t, c.NegotiateFormat("text/*"), "text/*") assert.Equal(t, c.NegotiateFormat("text/*"), "text/*")
@ -1287,8 +1314,9 @@ func TestContextNegotiationFormatWithWildcardAccept(t *testing.T) {
assert.Equal(t, c.NegotiateFormat(MIMEHTML), MIMEHTML) assert.Equal(t, c.NegotiateFormat(MIMEHTML), MIMEHTML)
c, _ = CreateTestContext(httptest.NewRecorder()) c, _ = CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest("POST", "/", nil) req, _ = http.NewRequest("POST", "/", nil)
c.Request.Header.Add("Accept", "text/*") req.Header.Add("Accept", "text/*")
c.setRequest(req)
assert.Equal(t, c.NegotiateFormat("*/*"), "*/*") assert.Equal(t, c.NegotiateFormat("*/*"), "*/*")
assert.Equal(t, c.NegotiateFormat("text/*"), "text/*") assert.Equal(t, c.NegotiateFormat("text/*"), "text/*")
@ -1300,8 +1328,9 @@ func TestContextNegotiationFormatWithWildcardAccept(t *testing.T) {
func TestContextNegotiationFormatCustom(t *testing.T) { func TestContextNegotiationFormatCustom(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder()) c, _ := CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest("POST", "/", nil) req, _ := http.NewRequest("POST", "/", nil)
c.Request.Header.Add("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9;q=0.8") req.Header.Add("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9;q=0.8")
c.setRequest(req)
c.Accepted = nil c.Accepted = nil
c.SetAccepted(MIMEJSON, MIMEXML) c.SetAccepted(MIMEJSON, MIMEXML)
@ -1433,7 +1462,8 @@ func TestContextAbortWithError(t *testing.T) {
func TestContextClientIP(t *testing.T) { func TestContextClientIP(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder()) c, _ := CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest("POST", "/", nil) req, _ := http.NewRequest("POST", "/", nil)
c.setRequest(req)
c.engine.trustedCIDRs, _ = c.engine.prepareTrustedCIDRs() c.engine.trustedCIDRs, _ = c.engine.prepareTrustedCIDRs()
resetContextForClientIPTests(c) resetContextForClientIPTests(c)
@ -1569,16 +1599,18 @@ func resetContextForClientIPTests(c *Context) {
func TestContextContentType(t *testing.T) { func TestContextContentType(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder()) c, _ := CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest("POST", "/", nil) req, _ := http.NewRequest("POST", "/", nil)
c.Request.Header.Set("Content-Type", "application/json; charset=utf-8") req.Header.Set("Content-Type", "application/json; charset=utf-8")
c.setRequest(req)
assert.Equal(t, "application/json", c.ContentType()) assert.Equal(t, "application/json", c.ContentType())
} }
func TestContextAutoBindJSON(t *testing.T) { func TestContextAutoBindJSON(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder()) c, _ := CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}")) req, _ := http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}"))
c.Request.Header.Add("Content-Type", MIMEJSON) req.Header.Add("Content-Type", MIMEJSON)
c.setRequest(req)
var obj struct { var obj struct {
Foo string `json:"foo"` Foo string `json:"foo"`
@ -1594,8 +1626,9 @@ func TestContextBindWithJSON(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)
c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}")) req, _ := http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}"))
c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type req.Header.Add("Content-Type", MIMEXML) // set fake content-type
c.setRequest(req)
var obj struct { var obj struct {
Foo string `json:"foo"` Foo string `json:"foo"`
@ -1611,12 +1644,13 @@ func TestContextBindWithXML(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)
c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString(`<?xml version="1.0" encoding="UTF-8"?> req, _ := http.NewRequest("POST", "/", bytes.NewBufferString(`<?xml version="1.0" encoding="UTF-8"?>
<root> <root>
<foo>FOO</foo> <foo>FOO</foo>
<bar>BAR</bar> <bar>BAR</bar>
</root>`)) </root>`))
c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type req.Header.Add("Content-Type", MIMEXML) // set fake content-type
c.setRequest(req)
var obj struct { var obj struct {
Foo string `xml:"foo"` Foo string `xml:"foo"`
@ -1632,10 +1666,11 @@ func TestContextBindHeader(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)
c.Request, _ = http.NewRequest("POST", "/", nil) req, _ := http.NewRequest("POST", "/", nil)
c.Request.Header.Add("rate", "8000") req.Header.Add("rate", "8000")
c.Request.Header.Add("domain", "music") req.Header.Add("domain", "music")
c.Request.Header.Add("limit", "1000") req.Header.Add("limit", "1000")
c.setRequest(req)
var testHeader struct { var testHeader struct {
Rate int `header:"Rate"` Rate int `header:"Rate"`
@ -1654,7 +1689,8 @@ func TestContextBindWithQuery(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)
c.Request, _ = http.NewRequest("POST", "/?foo=bar&bar=foo", bytes.NewBufferString("foo=unused")) req, _ := http.NewRequest("POST", "/?foo=bar&bar=foo", bytes.NewBufferString("foo=unused"))
c.setRequest(req)
var obj struct { var obj struct {
Foo string `form:"foo"` Foo string `form:"foo"`
@ -1670,8 +1706,9 @@ func TestContextBindWithYAML(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)
c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("foo: bar\nbar: foo")) req, _ := http.NewRequest("POST", "/", bytes.NewBufferString("foo: bar\nbar: foo"))
c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type req.Header.Add("Content-Type", MIMEXML) // set fake content-type
c.setRequest(req)
var obj struct { var obj struct {
Foo string `yaml:"foo"` Foo string `yaml:"foo"`
@ -1687,8 +1724,9 @@ func TestContextBindWithTOML(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)
c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("foo = 'bar'\nbar = 'foo'")) req, _ := http.NewRequest("POST", "/", bytes.NewBufferString("foo = 'bar'\nbar = 'foo'"))
c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type req.Header.Add("Content-Type", MIMEXML) // set fake content-type
c.setRequest(req)
var obj struct { var obj struct {
Foo string `toml:"foo"` Foo string `toml:"foo"`
@ -1704,8 +1742,9 @@ func TestContextBadAutoBind(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)
c.Request, _ = http.NewRequest("POST", "http://example.com", bytes.NewBufferString("\"foo\":\"bar\", \"bar\":\"foo\"}")) req, _ := http.NewRequest("POST", "http://example.com", bytes.NewBufferString("\"foo\":\"bar\", \"bar\":\"foo\"}"))
c.Request.Header.Add("Content-Type", MIMEJSON) req.Header.Add("Content-Type", MIMEJSON)
c.setRequest(req)
var obj struct { var obj struct {
Foo string `json:"foo"` Foo string `json:"foo"`
Bar string `json:"bar"` Bar string `json:"bar"`
@ -1723,8 +1762,9 @@ func TestContextBadAutoBind(t *testing.T) {
func TestContextAutoShouldBindJSON(t *testing.T) { func TestContextAutoShouldBindJSON(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder()) c, _ := CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}")) req, _ := http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}"))
c.Request.Header.Add("Content-Type", MIMEJSON) req.Header.Add("Content-Type", MIMEJSON)
c.setRequest(req)
var obj struct { var obj struct {
Foo string `json:"foo"` Foo string `json:"foo"`
@ -1740,8 +1780,9 @@ func TestContextShouldBindWithJSON(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)
c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}")) req, _ := http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}"))
c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type req.Header.Add("Content-Type", MIMEXML) // set fake content-type
c.setRequest(req)
var obj struct { var obj struct {
Foo string `json:"foo"` Foo string `json:"foo"`
@ -1757,12 +1798,13 @@ func TestContextShouldBindWithXML(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)
c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString(`<?xml version="1.0" encoding="UTF-8"?> req, _ := http.NewRequest("POST", "/", bytes.NewBufferString(`<?xml version="1.0" encoding="UTF-8"?>
<root> <root>
<foo>FOO</foo> <foo>FOO</foo>
<bar>BAR</bar> <bar>BAR</bar>
</root>`)) </root>`))
c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type req.Header.Add("Content-Type", MIMEXML) // set fake content-type
c.setRequest(req)
var obj struct { var obj struct {
Foo string `xml:"foo"` Foo string `xml:"foo"`
@ -1778,10 +1820,11 @@ func TestContextShouldBindHeader(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)
c.Request, _ = http.NewRequest("POST", "/", nil) req, _ := http.NewRequest("POST", "/", nil)
c.Request.Header.Add("rate", "8000") req.Header.Add("rate", "8000")
c.Request.Header.Add("domain", "music") req.Header.Add("domain", "music")
c.Request.Header.Add("limit", "1000") req.Header.Add("limit", "1000")
c.setRequest(req)
var testHeader struct { var testHeader struct {
Rate int `header:"Rate"` Rate int `header:"Rate"`
@ -1800,7 +1843,8 @@ func TestContextShouldBindWithQuery(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)
c.Request, _ = http.NewRequest("POST", "/?foo=bar&bar=foo&Foo=bar1&Bar=foo1", bytes.NewBufferString("foo=unused")) req, _ := http.NewRequest("POST", "/?foo=bar&bar=foo&Foo=bar1&Bar=foo1", bytes.NewBufferString("foo=unused"))
c.setRequest(req)
var obj struct { var obj struct {
Foo string `form:"foo"` Foo string `form:"foo"`
@ -1820,8 +1864,9 @@ func TestContextShouldBindWithYAML(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)
c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("foo: bar\nbar: foo")) req, _ := http.NewRequest("POST", "/", bytes.NewBufferString("foo: bar\nbar: foo"))
c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type req.Header.Add("Content-Type", MIMEXML) // set fake content-type
c.setRequest(req)
var obj struct { var obj struct {
Foo string `yaml:"foo"` Foo string `yaml:"foo"`
@ -1837,8 +1882,9 @@ func TestContextShouldBindWithTOML(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)
c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("foo='bar'\nbar= 'foo'")) req, _ := http.NewRequest("POST", "/", bytes.NewBufferString("foo='bar'\nbar= 'foo'"))
c.Request.Header.Add("Content-Type", MIMETOML) // set fake content-type req.Header.Add("Content-Type", MIMETOML) // set fake content-type
c.setRequest(req)
var obj struct { var obj struct {
Foo string `toml:"foo"` Foo string `toml:"foo"`
@ -1854,8 +1900,9 @@ func TestContextBadAutoShouldBind(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)
c.Request, _ = http.NewRequest("POST", "http://example.com", bytes.NewBufferString("\"foo\":\"bar\", \"bar\":\"foo\"}")) req, _ := http.NewRequest("POST", "http://example.com", bytes.NewBufferString("\"foo\":\"bar\", \"bar\":\"foo\"}"))
c.Request.Header.Add("Content-Type", MIMEJSON) req.Header.Add("Content-Type", MIMEJSON)
c.setRequest(req)
var obj struct { var obj struct {
Foo string `json:"foo"` Foo string `json:"foo"`
Bar string `json:"bar"` Bar string `json:"bar"`
@ -1917,9 +1964,10 @@ func TestContextShouldBindBodyWith(t *testing.T) {
{ {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)
c.Request, _ = http.NewRequest( req, _ := http.NewRequest(
"POST", "http://example.com", bytes.NewBufferString(tt.bodyA), "POST", "http://example.com", bytes.NewBufferString(tt.bodyA),
) )
c.setRequest(req)
// When it binds to typeA and typeB, it finds the body is // When it binds to typeA and typeB, it finds the body is
// not typeB but typeA. // not typeB but typeA.
objA := typeA{} objA := typeA{}
@ -1935,9 +1983,10 @@ func TestContextShouldBindBodyWith(t *testing.T) {
// not typeA but typeB. // not typeA but typeB.
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)
c.Request, _ = http.NewRequest( req, _ := http.NewRequest(
"POST", "http://example.com", bytes.NewBufferString(tt.bodyB), "POST", "http://example.com", bytes.NewBufferString(tt.bodyB),
) )
c.setRequest(req)
objA := typeA{} objA := typeA{}
assert.Error(t, c.ShouldBindBodyWith(&objA, tt.bindingA)) assert.Error(t, c.ShouldBindBodyWith(&objA, tt.bindingA))
assert.NotEqual(t, typeA{"FOO"}, objA) assert.NotEqual(t, typeA{"FOO"}, objA)
@ -1950,7 +1999,8 @@ func TestContextShouldBindBodyWith(t *testing.T) {
func TestContextGolangContext(t *testing.T) { func TestContextGolangContext(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder()) c, _ := CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}")) req, _ := http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}"))
c.setRequest(req)
assert.NoError(t, c.Err()) assert.NoError(t, c.Err())
assert.Nil(t, c.Done()) assert.Nil(t, c.Done())
ti, ok := c.Deadline() ti, ok := c.Deadline()
@ -1968,29 +2018,32 @@ func TestContextGolangContext(t *testing.T) {
func TestWebsocketsRequired(t *testing.T) { func TestWebsocketsRequired(t *testing.T) {
// Example request from spec: https://tools.ietf.org/html/rfc6455#section-1.2 // Example request from spec: https://tools.ietf.org/html/rfc6455#section-1.2
c, _ := CreateTestContext(httptest.NewRecorder()) c, _ := CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest("GET", "/chat", nil) req, _ := http.NewRequest("GET", "/chat", nil)
c.Request.Header.Set("Host", "server.example.com") req.Header.Set("Host", "server.example.com")
c.Request.Header.Set("Upgrade", "websocket") req.Header.Set("Upgrade", "websocket")
c.Request.Header.Set("Connection", "Upgrade") req.Header.Set("Connection", "Upgrade")
c.Request.Header.Set("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") req.Header.Set("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
c.Request.Header.Set("Origin", "http://example.com") req.Header.Set("Origin", "http://example.com")
c.Request.Header.Set("Sec-WebSocket-Protocol", "chat, superchat") req.Header.Set("Sec-WebSocket-Protocol", "chat, superchat")
c.Request.Header.Set("Sec-WebSocket-Version", "13") req.Header.Set("Sec-WebSocket-Version", "13")
c.setRequest(req)
assert.True(t, c.IsWebsocket()) assert.True(t, c.IsWebsocket())
// Normal request, no websocket required. // Normal request, no websocket required.
c, _ = CreateTestContext(httptest.NewRecorder()) c, _ = CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest("GET", "/chat", nil) req, _ = http.NewRequest("GET", "/chat", nil)
c.Request.Header.Set("Host", "server.example.com") req.Header.Set("Host", "server.example.com")
c.setRequest(req)
assert.False(t, c.IsWebsocket()) assert.False(t, c.IsWebsocket())
} }
func TestGetRequestHeaderValue(t *testing.T) { func TestGetRequestHeaderValue(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder()) c, _ := CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest("GET", "/chat", nil) req, _ := http.NewRequest("GET", "/chat", nil)
c.Request.Header.Set("Gin-Version", "1.0.0") req.Header.Set("Gin-Version", "1.0.0")
c.setRequest(req)
assert.Equal(t, "1.0.0", c.GetHeader("Gin-Version")) assert.Equal(t, "1.0.0", c.GetHeader("Gin-Version"))
assert.Empty(t, c.GetHeader("Connection")) assert.Empty(t, c.GetHeader("Connection"))
@ -1999,8 +2052,9 @@ func TestGetRequestHeaderValue(t *testing.T) {
func TestContextGetRawData(t *testing.T) { func TestContextGetRawData(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder()) c, _ := CreateTestContext(httptest.NewRecorder())
body := bytes.NewBufferString("Fetch binary post data") body := bytes.NewBufferString("Fetch binary post data")
c.Request, _ = http.NewRequest("POST", "/", body) req, _ := http.NewRequest("POST", "/", body)
c.Request.Header.Add("Content-Type", MIMEPOSTForm) req.Header.Add("Content-Type", MIMEPOSTForm)
c.setRequest(req)
data, err := c.GetRawData() data, err := c.GetRawData()
assert.Nil(t, err) assert.Nil(t, err)
@ -2148,8 +2202,9 @@ func TestContextWithKeysMutex(t *testing.T) {
func TestRemoteIPFail(t *testing.T) { func TestRemoteIPFail(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder()) c, _ := CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest("POST", "/", nil) req, _ := http.NewRequest("POST", "/", nil)
c.Request.RemoteAddr = "[:::]:80" req.RemoteAddr = "[:::]:80"
c.setRequest(req)
ip := net.ParseIP(c.RemoteIP()) ip := net.ParseIP(c.RemoteIP())
trust := c.engine.isTrustedProxy(ip) trust := c.engine.isTrustedProxy(ip)
assert.Nil(t, ip) assert.Nil(t, ip)
@ -2169,11 +2224,11 @@ func TestContextWithFallbackDeadlineFromRequestContext(t *testing.T) {
// enable ContextWithFallback feature flag // enable ContextWithFallback feature flag
c2.engine.ContextWithFallback = true c2.engine.ContextWithFallback = true
c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil)
d := time.Now().Add(time.Second) d := time.Now().Add(time.Second)
ctx, cancel := context.WithDeadline(context.Background(), d) ctx, cancel := context.WithDeadline(context.Background(), d)
defer cancel() defer cancel()
c2.Request = c2.Request.WithContext(ctx) req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/", nil)
c2.setRequest(req)
deadline, ok = c2.Deadline() deadline, ok = c2.Deadline()
assert.Equal(t, d, deadline) assert.Equal(t, d, deadline)
assert.True(t, ok) assert.True(t, ok)
@ -2190,11 +2245,13 @@ func TestContextWithFallbackDoneFromRequestContext(t *testing.T) {
// enable ContextWithFallback feature flag // enable ContextWithFallback feature flag
c2.engine.ContextWithFallback = true c2.engine.ContextWithFallback = true
c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
c2.Request = c2.Request.WithContext(ctx) req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/", nil)
c2.setRequest(req)
cancel() cancel()
assert.NotNil(t, <-c2.Done()) if assert.NotNil(t, c2.Done()) {
assert.NotNil(t, <-c2.Done())
}
} }
func TestContextWithFallbackErrFromRequestContext(t *testing.T) { func TestContextWithFallbackErrFromRequestContext(t *testing.T) {
@ -2208,9 +2265,9 @@ func TestContextWithFallbackErrFromRequestContext(t *testing.T) {
// enable ContextWithFallback feature flag // enable ContextWithFallback feature flag
c2.engine.ContextWithFallback = true c2.engine.ContextWithFallback = true
c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
c2.Request = c2.Request.WithContext(ctx) req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/", nil)
c2.setRequest(req)
cancel() cancel()
assert.EqualError(t, c2.Err(), context.Canceled.Error()) assert.EqualError(t, c2.Err(), context.Canceled.Error())
@ -2231,8 +2288,9 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder()) c, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag // enable ContextWithFallback feature flag
c.engine.ContextWithFallback = true c.engine.ContextWithFallback = true
c.Request, _ = http.NewRequest("POST", "/", nil) ctx := context.WithValue(context.TODO(), key, "value")
c.Request = c.Request.WithContext(context.WithValue(context.TODO(), key, "value")) req, _ := http.NewRequestWithContext(ctx, "POST", "/", nil)
c.setRequest(req)
return c, key return c, key
}, },
value: "value", value: "value",
@ -2243,8 +2301,9 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder()) c, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag // enable ContextWithFallback feature flag
c.engine.ContextWithFallback = true c.engine.ContextWithFallback = true
c.Request, _ = http.NewRequest("POST", "/", nil) ctx := context.WithValue(context.TODO(), contextKey("key"), "value")
c.Request = c.Request.WithContext(context.WithValue(context.TODO(), contextKey("key"), "value")) req, _ := http.NewRequestWithContext(ctx, "POST", "/", nil)
c.setRequest(req)
return c, contextKey("key") return c, contextKey("key")
}, },
value: "value", value: "value",
@ -2255,7 +2314,7 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder()) c, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag // enable ContextWithFallback feature flag
c.engine.ContextWithFallback = true c.engine.ContextWithFallback = true
c.Request = nil c.setRequest(nil)
return c, "key" return c, "key"
}, },
value: nil, value: nil,
@ -2266,7 +2325,8 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder()) c, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag // enable ContextWithFallback feature flag
c.engine.ContextWithFallback = true c.engine.ContextWithFallback = true
c.Request, _ = http.NewRequest("POST", "/", nil) req, _ := http.NewRequest("POST", "/", nil)
c.setRequest(req)
return c, "key" return c, "key"
}, },
value: nil, value: nil,
@ -2280,6 +2340,16 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
} }
} }
func TestContextWithFallbackModifiesRequestContext(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c.engine.ContextWithFallback = true
req, _ := http.NewRequest(http.MethodGet, "/", nil)
c.setRequest(req)
assert.Equal(t, c, c.Request.Context())
}
func TestContextCopyShouldNotCancel(t *testing.T) { func TestContextCopyShouldNotCancel(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)

2
gin.go
View File

@ -566,7 +566,7 @@ func (engine *Engine) RunListener(listener net.Listener) (err error) {
func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) { func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) {
c := engine.pool.Get().(*Context) c := engine.pool.Get().(*Context)
c.writermem.reset(w) c.writermem.reset(w)
c.Request = req c.setRequest(req)
c.reset() c.reset()
engine.handleHTTPRequest(c) engine.handleHTTPRequest(c)