From 2efd359c205ccc6827a59ecf056441b90b5b35a1 Mon Sep 17 00:00:00 2001 From: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> Date: Fri, 25 Apr 2025 22:26:08 -0700 Subject: [PATCH] Bind: return 413 status code when error is `http.MaxBytesError` The Go standard library includes a method `http.MaxBytesReader` that allows limiting the request body. For example, users can create a middleware like: ```go func MiddlewareMaxBodySize(c *gin.Context) { // Limit request body to 100 bytes c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, 100) c.Next() } ``` When the body exceeds the limit, reading from the request body returns an error of type `http.MaxBytesError`. This PR makes sure that when the error is of kind `http.MaxBytesError`, Gin returns the correct status code 413 (Request Entity Too Large) instead of a generic 400 (Bad Request). --- context.go | 11 +++++++-- context_test.go | 64 ++++++++++++++++++++++++++++++++----------------- gin_test.go | 10 ++++---- 3 files changed, 56 insertions(+), 29 deletions(-) diff --git a/context.go b/context.go index 1c76c0f6..288c0afb 100644 --- a/context.go +++ b/context.go @@ -769,8 +769,15 @@ func (c *Context) BindUri(obj any) error { // It will abort the request with HTTP 400 if any error occurs. // See the binding package. func (c *Context) MustBindWith(obj any, b binding.Binding) error { - if err := c.ShouldBindWith(obj, b); err != nil { - c.AbortWithError(http.StatusBadRequest, err).SetType(ErrorTypeBind) //nolint: errcheck + err := c.ShouldBindWith(obj, b) + if err != nil { + maxBytesErr := &http.MaxBytesError{} + switch { + case errors.As(err, &maxBytesErr): + c.AbortWithError(http.StatusRequestEntityTooLarge, err).SetType(ErrorTypeBind) //nolint: errcheck + default: + c.AbortWithError(http.StatusBadRequest, err).SetType(ErrorTypeBind) //nolint: errcheck + } return err } return nil diff --git a/context_test.go b/context_test.go index ef0cfccd..3c0ab82e 100644 --- a/context_test.go +++ b/context_test.go @@ -671,7 +671,7 @@ func TestContextDefaultQueryOnEmptyRequest(t *testing.T) { func TestContextQueryAndPostForm(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - body := bytes.NewBufferString("foo=bar&page=11&both=&foo=second") + body := strings.NewReader("foo=bar&page=11&both=&foo=second") c.Request, _ = http.NewRequest(http.MethodPost, "/?both=GET&id=main&id=omit&array[]=first&array[]=second&ids[a]=hi&ids[b]=3.14", body) c.Request.Header.Add("Content-Type", MIMEPOSTForm) @@ -946,7 +946,7 @@ func TestContextRenderJSONPWithoutCallback(t *testing.T) { c.JSONP(http.StatusCreated, H{"foo": "bar"}) assert.Equal(t, http.StatusCreated, w.Code) - assert.Equal(t, "{\"foo\":\"bar\"}", w.Body.String()) + assert.Equal(t, `{"foo":"bar"}`, w.Body.String()) assert.Equal(t, "application/json; charset=utf-8", w.Header().Get("Content-Type")) } @@ -972,7 +972,7 @@ func TestContextRenderAPIJSON(t *testing.T) { c.JSON(http.StatusCreated, H{"foo": "bar"}) assert.Equal(t, http.StatusCreated, w.Code) - assert.Equal(t, "{\"foo\":\"bar\"}", w.Body.String()) + assert.Equal(t, `{"foo":"bar"}`, w.Body.String()) assert.Equal(t, "application/vnd.api+json", w.Header().Get("Content-Type")) } @@ -1432,7 +1432,7 @@ func TestContextNegotiationWithJSON(t *testing.T) { }) assert.Equal(t, http.StatusOK, w.Code) - assert.Equal(t, "{\"foo\":\"bar\"}", w.Body.String()) + assert.Equal(t, `{"foo":"bar"}`, w.Body.String()) assert.Equal(t, "application/json; charset=utf-8", w.Header().Get("Content-Type")) } @@ -1848,9 +1848,29 @@ func TestContextContentType(t *testing.T) { assert.Equal(t, "application/json", c.ContentType()) } +func TestContextBindRequestTooLarge(t *testing.T) { + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader(`{"foo":"bar", "bar":"foo"}`)) + c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, 10) + + var obj struct { + Foo string `json:"foo"` + Bar string `json:"bar"` + } + require.Error(t, c.BindJSON(&obj)) + c.Writer.WriteHeaderNow() + + assert.Empty(t, obj.Bar) + assert.Empty(t, obj.Foo) + assert.Equal(t, http.StatusRequestEntityTooLarge, w.Code) + assert.True(t, c.IsAborted()) +} + func TestContextAutoBindJSON(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest(http.MethodPost, "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}")) + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader(`{"foo":"bar", "bar":"foo"}`)) c.Request.Header.Add("Content-Type", MIMEJSON) var obj struct { @@ -1867,7 +1887,7 @@ func TestContextBindWithJSON(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest(http.MethodPost, "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}")) + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader(`{"foo":"bar", "bar":"foo"}`)) c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type var obj struct { @@ -1884,7 +1904,7 @@ func TestContextBindWithXML(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest(http.MethodPost, "/", bytes.NewBufferString(` + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader(` FOO BAR @@ -1951,7 +1971,7 @@ func TestContextBindWithQuery(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest(http.MethodPost, "/?foo=bar&bar=foo", bytes.NewBufferString("foo=unused")) + c.Request, _ = http.NewRequest(http.MethodPost, "/?foo=bar&bar=foo", strings.NewReader("foo=unused")) var obj struct { Foo string `form:"foo"` @@ -1967,7 +1987,7 @@ func TestContextBindWithYAML(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest(http.MethodPost, "/", bytes.NewBufferString("foo: bar\nbar: foo")) + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader("foo: bar\nbar: foo")) c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type var obj struct { @@ -1984,7 +2004,7 @@ func TestContextBindWithTOML(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest(http.MethodPost, "/", bytes.NewBufferString("foo = 'bar'\nbar = 'foo'")) + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader("foo = 'bar'\nbar = 'foo'")) c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type var obj struct { @@ -2001,7 +2021,7 @@ func TestContextBadAutoBind(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest(http.MethodPost, "http://example.com", bytes.NewBufferString("\"foo\":\"bar\", \"bar\":\"foo\"}")) + c.Request, _ = http.NewRequest(http.MethodPost, "http://example.com", strings.NewReader("\"foo\":\"bar\", \"bar\":\"foo\"}")) c.Request.Header.Add("Content-Type", MIMEJSON) var obj struct { Foo string `json:"foo"` @@ -2020,7 +2040,7 @@ func TestContextBadAutoBind(t *testing.T) { func TestContextAutoShouldBindJSON(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest(http.MethodPost, "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}")) + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader(`{"foo":"bar", "bar":"foo"}`)) c.Request.Header.Add("Content-Type", MIMEJSON) var obj struct { @@ -2037,7 +2057,7 @@ func TestContextShouldBindWithJSON(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest(http.MethodPost, "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}")) + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader(`{"foo":"bar", "bar":"foo"}`)) c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type var obj struct { @@ -2054,7 +2074,7 @@ func TestContextShouldBindWithXML(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest(http.MethodPost, "/", bytes.NewBufferString(` + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader(` FOO BAR @@ -2121,7 +2141,7 @@ func TestContextShouldBindWithQuery(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest(http.MethodPost, "/?foo=bar&bar=foo&Foo=bar1&Bar=foo1", bytes.NewBufferString("foo=unused")) + c.Request, _ = http.NewRequest(http.MethodPost, "/?foo=bar&bar=foo&Foo=bar1&Bar=foo1", strings.NewReader("foo=unused")) var obj struct { Foo string `form:"foo"` @@ -2141,7 +2161,7 @@ func TestContextShouldBindWithYAML(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest(http.MethodPost, "/", bytes.NewBufferString("foo: bar\nbar: foo")) + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader("foo: bar\nbar: foo")) c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type var obj struct { @@ -2158,7 +2178,7 @@ func TestContextShouldBindWithTOML(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest(http.MethodPost, "/", bytes.NewBufferString("foo='bar'\nbar= 'foo'")) + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader("foo='bar'\nbar= 'foo'")) c.Request.Header.Add("Content-Type", MIMETOML) // set fake content-type var obj struct { @@ -2175,7 +2195,7 @@ func TestContextBadAutoShouldBind(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Request, _ = http.NewRequest(http.MethodPost, "http://example.com", bytes.NewBufferString("\"foo\":\"bar\", \"bar\":\"foo\"}")) + c.Request, _ = http.NewRequest(http.MethodPost, "http://example.com", strings.NewReader(`"foo":"bar", "bar":"foo"}`)) c.Request.Header.Add("Content-Type", MIMEJSON) var obj struct { Foo string `json:"foo"` @@ -2239,7 +2259,7 @@ func TestContextShouldBindBodyWith(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) c.Request, _ = http.NewRequest( - http.MethodPost, "http://example.com", bytes.NewBufferString(tt.bodyA), + http.MethodPost, "http://example.com", strings.NewReader(tt.bodyA), ) // When it binds to typeA and typeB, it finds the body is // not typeB but typeA. @@ -2257,7 +2277,7 @@ func TestContextShouldBindBodyWith(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) c.Request, _ = http.NewRequest( - http.MethodPost, "http://example.com", bytes.NewBufferString(tt.bodyB), + http.MethodPost, "http://example.com", strings.NewReader(tt.bodyB), ) objA := typeA{} require.Error(t, c.ShouldBindBodyWith(&objA, tt.bindingA)) @@ -2603,7 +2623,7 @@ func TestContextShouldBindBodyWithPlain(t *testing.T) { func TestContextGolangContext(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest(http.MethodPost, "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}")) + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader(`{"foo":"bar", "bar":"foo"}`)) require.NoError(t, c.Err()) assert.Nil(t, c.Done()) ti, ok := c.Deadline() @@ -2651,7 +2671,7 @@ func TestGetRequestHeaderValue(t *testing.T) { func TestContextGetRawData(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - body := bytes.NewBufferString("Fetch binary post data") + body := strings.NewReader("Fetch binary post data") c.Request, _ = http.NewRequest(http.MethodPost, "/", body) c.Request.Header.Add("Content-Type", MIMEPOSTForm) diff --git a/gin_test.go b/gin_test.go index a80b690e..5eba4e30 100644 --- a/gin_test.go +++ b/gin_test.go @@ -338,7 +338,7 @@ func TestLoadHTMLFSTestMode(t *testing.T) { ) defer ts.Close() - res, err := http.Get(fmt.Sprintf("%s/test", ts.URL)) + res, err := http.Get(ts.URL + "/test") if err != nil { t.Error(err) } @@ -358,7 +358,7 @@ func TestLoadHTMLFSDebugMode(t *testing.T) { ) defer ts.Close() - res, err := http.Get(fmt.Sprintf("%s/test", ts.URL)) + res, err := http.Get(ts.URL + "/test") if err != nil { t.Error(err) } @@ -378,7 +378,7 @@ func TestLoadHTMLFSReleaseMode(t *testing.T) { ) defer ts.Close() - res, err := http.Get(fmt.Sprintf("%s/test", ts.URL)) + res, err := http.Get(ts.URL + "/test") if err != nil { t.Error(err) } @@ -405,7 +405,7 @@ func TestLoadHTMLFSUsingTLS(t *testing.T) { }, } client := &http.Client{Transport: tr} - res, err := client.Get(fmt.Sprintf("%s/test", ts.URL)) + res, err := client.Get(ts.URL + "/test") if err != nil { t.Error(err) } @@ -425,7 +425,7 @@ func TestLoadHTMLFSFuncMap(t *testing.T) { ) defer ts.Close() - res, err := http.Get(fmt.Sprintf("%s/raw", ts.URL)) + res, err := http.Get(ts.URL + "/raw") if err != nil { t.Error(err) }