From 159db68702aac16693797b76428495d141b3cf84 Mon Sep 17 00:00:00 2001 From: walle250ai Date: Mon, 27 Apr 2026 20:42:42 +0800 Subject: [PATCH] feat(binding): add JSONStrict binding for per-request unknown field rejection Add jsonStrictBinding that always enables DisallowUnknownFields on the JSON decoder, independent of the global EnableDecoderDisallowUnknownFields flag. Exposes BindJSONStrict and ShouldBindJSONStrict on Context. Co-Authored-By: Claude Sonnet 4.6 --- binding/binding.go | 1 + binding/json.go | 29 ++++++++ binding/json_test.go | 66 ++++++++++++++++++ context.go | 14 ++++ context_test.go | 161 +++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 271 insertions(+) diff --git a/binding/binding.go b/binding/binding.go index eced0ae2..4ef52500 100644 --- a/binding/binding.go +++ b/binding/binding.go @@ -75,6 +75,7 @@ var Validator StructValidator = &defaultValidator{} // present in the request to struct instances. var ( JSON BindingBody = jsonBinding{} + JSONStrict BindingBody = jsonStrictBinding{} XML BindingBody = xmlBinding{} Form Binding = formBinding{} Query Binding = queryBinding{} diff --git a/binding/json.go b/binding/json.go index f4ae921a..a61b5a22 100644 --- a/binding/json.go +++ b/binding/json.go @@ -54,3 +54,32 @@ func decodeJSON(r io.Reader, obj any) error { } return validate(obj) } + +type jsonStrictBinding struct{} + +func (jsonStrictBinding) Name() string { + return "json-strict" +} + +func (jsonStrictBinding) Bind(req *http.Request, obj any) error { + if req == nil || req.Body == nil { + return errors.New("invalid request") + } + return decodeJSONStrict(req.Body, obj) +} + +func (jsonStrictBinding) BindBody(body []byte, obj any) error { + return decodeJSONStrict(bytes.NewReader(body), obj) +} + +func decodeJSONStrict(r io.Reader, obj any) error { + decoder := json.API.NewDecoder(r) + if EnableDecoderUseNumber { + decoder.UseNumber() + } + decoder.DisallowUnknownFields() + if err := decoder.Decode(obj); err != nil { + return err + } + return validate(obj) +} diff --git a/binding/json_test.go b/binding/json_test.go index 942ee3eb..719ccf5b 100644 --- a/binding/json_test.go +++ b/binding/json_test.go @@ -214,3 +214,69 @@ func (tpc timePointerCodec) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) } // endregion + +func TestJSONStrictBindingBindBody(t *testing.T) { + t.Run("normal request with known fields", func(t *testing.T) { + var s struct { + Foo string `json:"foo"` + } + err := jsonStrictBinding{}.BindBody([]byte(`{"foo": "FOO"}`), &s) + require.NoError(t, err) + assert.Equal(t, "FOO", s.Foo) + }) + + t.Run("request with unknown fields should error", func(t *testing.T) { + var s struct { + Foo string `json:"foo"` + } + err := jsonStrictBinding{}.BindBody([]byte(`{"foo": "FOO", "bar": "BAR"}`), &s) + require.Error(t, err) + assert.Contains(t, err.Error(), "unknown field") + }) + + t.Run("empty body should error", func(t *testing.T) { + var s struct { + Foo string `json:"foo"` + } + err := jsonStrictBinding{}.BindBody([]byte{}, &s) + require.Error(t, err) + }) + + t.Run("invalid JSON should error", func(t *testing.T) { + var s struct { + Foo string `json:"foo"` + } + err := jsonStrictBinding{}.BindBody([]byte(`{"foo": "FOO"`), &s) + require.Error(t, err) + }) + + t.Run("jsonBinding should ignore unknown fields when global switch is off", func(t *testing.T) { + oldValue := EnableDecoderDisallowUnknownFields + defer func() { + EnableDecoderDisallowUnknownFields = oldValue + }() + EnableDecoderDisallowUnknownFields = false + + var s struct { + Foo string `json:"foo"` + } + err := jsonBinding{}.BindBody([]byte(`{"foo": "FOO", "bar": "BAR"}`), &s) + require.NoError(t, err) + assert.Equal(t, "FOO", s.Foo) + }) + + t.Run("jsonStrictBinding should always reject unknown fields regardless of global switch", func(t *testing.T) { + oldValue := EnableDecoderDisallowUnknownFields + defer func() { + EnableDecoderDisallowUnknownFields = oldValue + }() + EnableDecoderDisallowUnknownFields = false + + var s struct { + Foo string `json:"foo"` + } + err := jsonStrictBinding{}.BindBody([]byte(`{"foo": "FOO", "bar": "BAR"}`), &s) + require.Error(t, err) + assert.Contains(t, err.Error(), "unknown field") + }) +} diff --git a/context.go b/context.go index 5174033e..479b84aa 100644 --- a/context.go +++ b/context.go @@ -764,6 +764,13 @@ func (c *Context) BindJSON(obj any) error { return c.MustBindWith(obj, binding.JSON) } +// BindJSONStrict is a shortcut for c.MustBindWith(obj, binding.JSONStrict). +// It will return an error and abort the request with HTTP 400 if any error occurs, +// including when the JSON contains unknown fields. +func (c *Context) BindJSONStrict(obj any) error { + return c.MustBindWith(obj, binding.JSONStrict) +} + // BindXML is a shortcut for c.MustBindWith(obj, binding.BindXML). func (c *Context) BindXML(obj any) error { return c.MustBindWith(obj, binding.XML) @@ -868,6 +875,13 @@ func (c *Context) ShouldBindJSON(obj any) error { return c.ShouldBindWith(obj, binding.JSON) } +// ShouldBindJSONStrict is a shortcut for c.ShouldBindWith(obj, binding.JSONStrict). +// It works like ShouldBindJSON but returns an error if the JSON contains unknown fields. +// This method does not set the response status code to 400 or abort if input is not valid. +func (c *Context) ShouldBindJSONStrict(obj any) error { + return c.ShouldBindWith(obj, binding.JSONStrict) +} + // ShouldBindXML is a shortcut for c.ShouldBindWith(obj, binding.XML). // It works like ShouldBindJSON but binds the request body as XML data. func (c *Context) ShouldBindXML(obj any) error { diff --git a/context_test.go b/context_test.go index ef60379d..a7489e41 100644 --- a/context_test.go +++ b/context_test.go @@ -3808,3 +3808,164 @@ func BenchmarkGetMapFromFormData(b *testing.B) { }) } } + +func TestContextBindJSONStrict(t *testing.T) { + t.Run("normal request with known fields should succeed", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader(`{"foo":"bar", "bar":"foo"}`)) + c.Request.Header.Add("Content-Type", MIMEJSON) + + var obj struct { + Foo string `json:"foo"` + Bar string `json:"bar"` + } + require.NoError(t, c.BindJSONStrict(&obj)) + assert.Equal(t, "foo", obj.Bar) + assert.Equal(t, "bar", obj.Foo) + assert.Equal(t, 0, w.Body.Len()) + assert.False(t, c.IsAborted()) + }) + + t.Run("request with unknown fields should return 400 and abort", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader(`{"foo":"bar", "unknown":"field"}`)) + c.Request.Header.Add("Content-Type", MIMEJSON) + + var obj struct { + Foo string `json:"foo"` + } + require.Error(t, c.BindJSONStrict(&obj)) + assert.Contains(t, c.Errors.Last().Err.Error(), "unknown field") + assert.Equal(t, http.StatusBadRequest, w.Code) + assert.True(t, c.IsAborted()) + }) + + t.Run("invalid JSON should return 400 and abort", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader(`{"foo":"bar"`)) + c.Request.Header.Add("Content-Type", MIMEJSON) + + var obj struct { + Foo string `json:"foo"` + } + require.Error(t, c.BindJSONStrict(&obj)) + assert.Equal(t, http.StatusBadRequest, w.Code) + assert.True(t, c.IsAborted()) + }) + + t.Run("empty body should return 400 and abort", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader("")) + c.Request.Header.Add("Content-Type", MIMEJSON) + + var obj struct { + Foo string `json:"foo"` + } + require.Error(t, c.BindJSONStrict(&obj)) + assert.Equal(t, http.StatusBadRequest, w.Code) + assert.True(t, c.IsAborted()) + }) +} + +func TestContextShouldBindJSONStrict(t *testing.T) { + t.Run("normal request with known fields should succeed", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader(`{"foo":"bar", "bar":"foo"}`)) + c.Request.Header.Add("Content-Type", MIMEJSON) + + var obj struct { + Foo string `json:"foo"` + Bar string `json:"bar"` + } + require.NoError(t, c.ShouldBindJSONStrict(&obj)) + assert.Equal(t, "foo", obj.Bar) + assert.Equal(t, "bar", obj.Foo) + assert.Equal(t, 0, w.Body.Len()) + assert.False(t, c.IsAborted()) + }) + + t.Run("request with unknown fields should return error but not abort", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader(`{"foo":"bar", "unknown":"field"}`)) + c.Request.Header.Add("Content-Type", MIMEJSON) + + var obj struct { + Foo string `json:"foo"` + } + err := c.ShouldBindJSONStrict(&obj) + require.Error(t, err) + assert.Contains(t, err.Error(), "unknown field") + assert.False(t, c.IsAborted()) + }) + + t.Run("invalid JSON should return error but not abort", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader(`{"foo":"bar"`)) + c.Request.Header.Add("Content-Type", MIMEJSON) + + var obj struct { + Foo string `json:"foo"` + } + err := c.ShouldBindJSONStrict(&obj) + require.Error(t, err) + assert.False(t, c.IsAborted()) + }) +} + +func TestContextJSONBindingIndependence(t *testing.T) { + t.Run("BindJSON should ignore unknown fields when global switch is off", func(t *testing.T) { + oldValue := binding.EnableDecoderDisallowUnknownFields + defer func() { + binding.EnableDecoderDisallowUnknownFields = oldValue + }() + binding.EnableDecoderDisallowUnknownFields = false + + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader(`{"foo":"bar", "unknown":"field"}`)) + c.Request.Header.Add("Content-Type", MIMEJSON) + + var obj struct { + Foo string `json:"foo"` + } + require.NoError(t, c.BindJSON(&obj)) + assert.Equal(t, "bar", obj.Foo) + assert.False(t, c.IsAborted()) + }) + + t.Run("BindJSONStrict should always reject unknown fields regardless of global switch", func(t *testing.T) { + oldValue := binding.EnableDecoderDisallowUnknownFields + defer func() { + binding.EnableDecoderDisallowUnknownFields = oldValue + }() + binding.EnableDecoderDisallowUnknownFields = false + + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + + c.Request, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader(`{"foo":"bar", "unknown":"field"}`)) + c.Request.Header.Add("Content-Type", MIMEJSON) + + var obj struct { + Foo string `json:"foo"` + } + require.Error(t, c.BindJSONStrict(&obj)) + assert.Contains(t, c.Errors.Last().Err.Error(), "unknown field") + assert.True(t, c.IsAborted()) + }) +}