From b626f639064cf3a72e4698c7804ecef2b06b7ec4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Szafra=C5=84ski?= Date: Fri, 5 Nov 2021 18:16:37 +0100 Subject: [PATCH] Add support for contextual validation By passing the Gin context to bindings, custom validators can take advantage of the information in the context. --- binding/binding.go | 45 ++++++++++-- binding/binding_msgpack_test.go | 15 +++- binding/binding_nomsgpack.go | 47 +++++++++++-- binding/binding_test.go | 75 +++++++++++++++++--- binding/default_validator.go | 27 +++++--- binding/form.go | 26 +++++-- binding/header.go | 9 ++- binding/json.go | 21 ++++-- binding/msgpack.go | 21 ++++-- binding/query.go | 13 +++- binding/uri.go | 10 ++- binding/validate_test.go | 5 ++ binding/xml.go | 22 ++++-- binding/yaml.go | 21 ++++-- context.go | 8 ++- context_test.go | 118 ++++++++++++++++++++++++++++++++ 16 files changed, 410 insertions(+), 73 deletions(-) diff --git a/binding/binding.go b/binding/binding.go index 7042101d..b400e6c5 100644 --- a/binding/binding.go +++ b/binding/binding.go @@ -7,7 +7,10 @@ package binding -import "net/http" +import ( + "context" + "net/http" +) // Content-Type MIME of the most common data formats. const ( @@ -32,20 +35,41 @@ type Binding interface { Bind(*http.Request, interface{}) error } -// BindingBody adds BindBody method to Binding. BindBody is similar with Bind, +// ContextBinding enables contextual validation by adding BindContext to Binding. +// Custom validators can take advantage of the information in the context. +type ContextBinding interface { + Binding + BindContext(context.Context, *http.Request, interface{}) error +} + +// BindingBody adds BindBody method to Binding. BindBody is similar to Bind, // but it reads the body from supplied bytes instead of req.Body. type BindingBody interface { Binding BindBody([]byte, interface{}) error } -// BindingUri adds BindUri method to Binding. BindUri is similar with Bind, -// but it read the Params. +// ContextBindingBody enables contextual validation by adding BindBodyContext to BindingBody. +// Custom validators can take advantage of the information in the context. +type ContextBindingBody interface { + BindingBody + BindContext(context.Context, *http.Request, interface{}) error + BindBodyContext(context.Context, []byte, interface{}) error +} + +// BindingUri is similar to Bind, but it read the Params. type BindingUri interface { Name() string BindUri(map[string][]string, interface{}) error } +// ContextBindingUri enables contextual validation by adding BindUriContext to BindingUri. +// Custom validators can take advantage of the information in the context. +type ContextBindingUri interface { + BindingUri + BindUriContext(context.Context, map[string][]string, interface{}) error +} + // StructValidator is the minimal interface which needs to be implemented in // order for it to be used as the validator engine for ensuring the correctness // of the request. Gin provides a default implementation for this using @@ -64,6 +88,14 @@ type StructValidator interface { Engine() interface{} } +// ContextStructValidator is an extension of StructValidator that requires implementing +// context-aware validation. +// Custom validators can take advantage of the information in the context. +type ContextStructValidator interface { + StructValidator + ValidateStructContext(context.Context, interface{}) error +} + // Validator is the default validator which implements the StructValidator // interface. It uses https://github.com/go-playground/validator/tree/v10.6.1 // under the hood. @@ -110,9 +142,12 @@ func Default(method, contentType string) Binding { } } -func validate(obj interface{}) error { +func validateContext(ctx context.Context, obj interface{}) error { if Validator == nil { return nil } + if v, ok := Validator.(ContextStructValidator); ok { + return v.ValidateStructContext(ctx, obj) + } return Validator.ValidateStruct(obj) } diff --git a/binding/binding_msgpack_test.go b/binding/binding_msgpack_test.go index 04d94079..7bc6d47d 100644 --- a/binding/binding_msgpack_test.go +++ b/binding/binding_msgpack_test.go @@ -9,6 +9,7 @@ package binding import ( "bytes" + "context" "testing" "github.com/stretchr/testify/assert" @@ -35,7 +36,7 @@ func TestBindingMsgPack(t *testing.T) { string(data), string(data[1:])) } -func testMsgPackBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody string) { +func testMsgPackBodyBinding(t *testing.T, b ContextBinding, name, path, badPath, body, badBody string) { assert.Equal(t, name, b.Name()) obj := FooStruct{} @@ -48,7 +49,17 @@ func testMsgPackBodyBinding(t *testing.T, b Binding, name, path, badPath, body, obj = FooStruct{} req = requestWithBody("POST", badPath, badBody) req.Header.Add("Content-Type", MIMEMSGPACK) - err = MsgPack.Bind(req, &obj) + err = b.Bind(req, &obj) + assert.Error(t, err) + + obj2 := ConditionalFooStruct{} + req = requestWithBody("POST", path, body) + req.Header.Add("Content-Type", MIMEMSGPACK) + err = b.BindContext(context.Background(), req, &obj2) + assert.NoError(t, err) + assert.Equal(t, "bar", obj2.Foo) + + err = b.BindContext(context.WithValue(context.Background(), "condition", true), req, &obj2) // nolint assert.Error(t, err) } diff --git a/binding/binding_nomsgpack.go b/binding/binding_nomsgpack.go index 00d63036..f9b6f511 100644 --- a/binding/binding_nomsgpack.go +++ b/binding/binding_nomsgpack.go @@ -7,7 +7,10 @@ package binding -import "net/http" +import ( + "context" + "net/http" +) // Content-Type MIME of the most common data formats. const ( @@ -30,20 +33,41 @@ type Binding interface { Bind(*http.Request, interface{}) error } -// BindingBody adds BindBody method to Binding. BindBody is similar with Bind, +// ContextBinding enables contextual validation by adding BindContext to Binding. +// Custom validators can take advantage of the information in the context. +type ContextBinding interface { + Binding + BindContext(context.Context, *http.Request, interface{}) error +} + +// BindingBody adds BindBody method to Binding. BindBody is similar to Bind, // but it reads the body from supplied bytes instead of req.Body. type BindingBody interface { Binding BindBody([]byte, interface{}) error } -// BindingUri adds BindUri method to Binding. BindUri is similar with Bind, -// but it read the Params. +// ContextBindingBody enables contextual validation by adding BindBodyContext to BindingBody. +// Custom validators can take advantage of the information in the context. +type ContextBindingBody interface { + BindingBody + BindContext(context.Context, *http.Request, interface{}) error + BindBodyContext(context.Context, []byte, interface{}) error +} + +// BindingUri is similar to Bind, but it read the Params. type BindingUri interface { Name() string BindUri(map[string][]string, interface{}) error } +// ContextBindingUri enables contextual validation by adding BindUriContext to BindingUri. +// Custom validators can take advantage of the information in the context. +type ContextBindingUri interface { + BindingUri + BindUriContext(context.Context, map[string][]string, interface{}) error +} + // StructValidator is the minimal interface which needs to be implemented in // order for it to be used as the validator engine for ensuring the correctness // of the request. Gin provides a default implementation for this using @@ -62,6 +86,14 @@ type StructValidator interface { Engine() interface{} } +// ContextStructValidator is an extension of StructValidator that requires implementing +// context-aware validation. +// Custom validators can take advantage of the information in the context. +type ContextStructValidator interface { + StructValidator + ValidateStructContext(context.Context, interface{}) error +} + // Validator is the default validator which implements the StructValidator // interface. It uses https://github.com/go-playground/validator/tree/v10.6.1 // under the hood. @@ -85,7 +117,7 @@ var ( // Default returns the appropriate Binding instance based on the HTTP method // and the content type. func Default(method, contentType string) Binding { - if method == "GET" { + if method == http.MethodGet { return Form } @@ -105,9 +137,12 @@ func Default(method, contentType string) Binding { } } -func validate(obj interface{}) error { +func validateContext(ctx context.Context, obj interface{}) error { if Validator == nil { return nil } + if v, ok := Validator.(ContextStructValidator); ok { + return v.ValidateStructContext(ctx, obj) + } return Validator.ValidateStruct(obj) } diff --git a/binding/binding_test.go b/binding/binding_test.go index 5b0ce39d..c1d449a0 100644 --- a/binding/binding_test.go +++ b/binding/binding_test.go @@ -6,6 +6,7 @@ package binding import ( "bytes" + "context" "encoding/json" "errors" "io" @@ -20,6 +21,7 @@ import ( "time" "github.com/gin-gonic/gin/testdata/protoexample" + "github.com/go-playground/validator/v10" "github.com/stretchr/testify/assert" "google.golang.org/protobuf/proto" ) @@ -38,6 +40,10 @@ type FooStruct struct { Foo string `msgpack:"foo" json:"foo" form:"foo" xml:"foo" binding:"required,max=32"` } +type ConditionalFooStruct struct { + Foo string `msgpack:"foo" json:"foo" form:"foo" xml:"foo" binding:"required_if_condition,max=32"` +} + type FooBarStruct struct { FooStruct Bar string `msgpack:"bar" json:"bar" form:"bar" xml:"bar" binding:"required"` @@ -144,6 +150,16 @@ type FooStructForMapPtrType struct { PtrBar *map[string]interface{} `form:"ptr_bar"` } +func init() { + _ = Validator.Engine().(*validator.Validate).RegisterValidationCtx( + "required_if_condition", func(ctx context.Context, fl validator.FieldLevel) bool { + if ctx.Value("condition") == true { + return !fl.Field().IsZero() + } + return true + }) +} + func TestBindingDefault(t *testing.T) { assert.Equal(t, Form, Default("GET", "")) assert.Equal(t, Form, Default("GET", MIMEJSON)) @@ -796,6 +812,38 @@ func TestUriBinding(t *testing.T) { assert.Equal(t, map[string]interface{}(nil), not.Name) } +func TestUriBindingWithContext(t *testing.T) { + b := Uri + + type Tag struct { + Name string `uri:"name" binding:"required_if_condition"` + } + + empty := make(map[string][]string) + assert.NoError(t, b.BindUriContext(context.Background(), empty, new(Tag))) + assert.Error(t, b.BindUriContext(context.WithValue(context.Background(), "condition", true), empty, new(Tag))) // nolint +} + +func TestUriBindingWithNotContextValidator(t *testing.T) { + prev := Validator + defer func() { + Validator = prev + }() + Validator = ¬ContextValidator{} + + TestUriBinding(t) +} + +type notContextValidator defaultValidator + +func (v *notContextValidator) ValidateStruct(obj interface{}) error { + return (*defaultValidator)(v).ValidateStruct(obj) +} + +func (v *notContextValidator) Engine() interface{} { + return (*defaultValidator)(v).Engine() +} + func TestUriInnerBinding(t *testing.T) { type Tag struct { Name string `uri:"name"` @@ -1179,7 +1227,7 @@ func testQueryBindingBoolFail(t *testing.T, method, path, badPath, body, badBody assert.Error(t, err) } -func testBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody string) { +func testBodyBinding(t *testing.T, b ContextBinding, name, path, badPath, body, badBody string) { assert.Equal(t, name, b.Name()) obj := FooStruct{} @@ -1190,7 +1238,16 @@ func testBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody obj = FooStruct{} req = requestWithBody("POST", badPath, badBody) - err = JSON.Bind(req, &obj) + err = b.Bind(req, &obj) + assert.Error(t, err) + + obj2 := ConditionalFooStruct{} + req = requestWithBody("POST", path, body) + err = b.BindContext(context.Background(), req, &obj2) + assert.NoError(t, err) + assert.Equal(t, "bar", obj2.Foo) + + err = b.BindContext(context.WithValue(context.Background(), "condition", true), req, &obj2) // nolint assert.Error(t, err) } @@ -1204,7 +1261,7 @@ func testBodyBindingSlice(t *testing.T, b Binding, name, path, badPath, body, ba var obj2 []FooStruct req = requestWithBody("POST", badPath, badBody) - err = JSON.Bind(req, &obj2) + err = b.Bind(req, &obj2) assert.Error(t, err) } @@ -1249,7 +1306,7 @@ func testBodyBindingUseNumber(t *testing.T, b Binding, name, path, badPath, body obj = FooStructUseNumber{} req = requestWithBody("POST", badPath, badBody) - err = JSON.Bind(req, &obj) + err = b.Bind(req, &obj) assert.Error(t, err) } @@ -1267,7 +1324,7 @@ func testBodyBindingUseNumber2(t *testing.T, b Binding, name, path, badPath, bod obj = FooStructUseNumber{} req = requestWithBody("POST", badPath, badBody) - err = JSON.Bind(req, &obj) + err = b.Bind(req, &obj) assert.Error(t, err) } @@ -1285,7 +1342,7 @@ func testBodyBindingDisallowUnknownFields(t *testing.T, b Binding, path, badPath obj = FooStructDisallowUnknownFields{} req = requestWithBody("POST", badPath, badBody) - err = JSON.Bind(req, &obj) + err = b.Bind(req, &obj) assert.Error(t, err) assert.Contains(t, err.Error(), "what") } @@ -1301,7 +1358,7 @@ func testBodyBindingFail(t *testing.T, b Binding, name, path, badPath, body, bad obj = FooStruct{} req = requestWithBody("POST", badPath, badBody) - err = JSON.Bind(req, &obj) + err = b.Bind(req, &obj) assert.Error(t, err) } @@ -1318,7 +1375,7 @@ func testProtoBodyBinding(t *testing.T, b Binding, name, path, badPath, body, ba obj = protoexample.Test{} req = requestWithBody("POST", badPath, badBody) req.Header.Add("Content-Type", MIMEPROTOBUF) - err = ProtoBuf.Bind(req, &obj) + err = b.Bind(req, &obj) assert.Error(t, err) } @@ -1349,7 +1406,7 @@ func testProtoBodyBindingFail(t *testing.T, b Binding, name, path, badPath, body obj = protoexample.Test{} req = requestWithBody("POST", badPath, badBody) req.Header.Add("Content-Type", MIMEPROTOBUF) - err = ProtoBuf.Bind(req, &obj) + err = b.Bind(req, &obj) assert.Error(t, err) } diff --git a/binding/default_validator.go b/binding/default_validator.go index b60e3cf6..90a458d3 100644 --- a/binding/default_validator.go +++ b/binding/default_validator.go @@ -5,6 +5,7 @@ package binding import ( + "context" "fmt" "reflect" "sync" @@ -92,10 +93,14 @@ func (fe mapFieldError) Unwrap() error { return fe.FieldError } -var _ StructValidator = &defaultValidator{} +var _ ContextStructValidator = &defaultValidator{} // ValidateStruct receives any kind of type, but validates only structs, pointers, slices, arrays, and maps. func (v *defaultValidator) ValidateStruct(obj interface{}) error { + return v.ValidateStructContext(context.Background(), obj) +} + +func (v *defaultValidator) ValidateStructContext(ctx context.Context, obj interface{}) error { if obj == nil { return nil } @@ -103,21 +108,21 @@ func (v *defaultValidator) ValidateStruct(obj interface{}) error { value := reflect.ValueOf(obj) switch value.Kind() { case reflect.Ptr: - return v.ValidateStruct(value.Elem().Interface()) + return v.ValidateStructContext(ctx, value.Elem().Interface()) case reflect.Struct: - return v.validateStruct(obj) + return v.validateStruct(ctx, obj) case reflect.Slice, reflect.Array: var errs validator.ValidationErrors if tag, ok := validatorTags[value.Type()]; ok { - if err := v.validateVar(obj, tag); err != nil { + if err := v.validateVar(ctx, obj, tag); err != nil { errs = append(errs, err.(validator.ValidationErrors)...) // nolint: errorlint } } count := value.Len() for i := 0; i < count; i++ { - if err := v.ValidateStruct(value.Index(i).Interface()); err != nil { + if err := v.ValidateStructContext(ctx, value.Index(i).Interface()); err != nil { for _, fieldError := range err.(validator.ValidationErrors) { // nolint: errorlint errs = append(errs, sliceFieldError{fieldError, i}) } @@ -132,13 +137,13 @@ func (v *defaultValidator) ValidateStruct(obj interface{}) error { var errs validator.ValidationErrors if tag, ok := validatorTags[value.Type()]; ok { - if err := v.validateVar(obj, tag); err != nil { + if err := v.validateVar(ctx, obj, tag); err != nil { errs = append(errs, err.(validator.ValidationErrors)...) // nolint: errorlint } } for _, key := range value.MapKeys() { - if err := v.ValidateStruct(value.MapIndex(key).Interface()); err != nil { + if err := v.ValidateStructContext(ctx, value.MapIndex(key).Interface()); err != nil { for _, fieldError := range err.(validator.ValidationErrors) { // nolint: errorlint errs = append(errs, mapFieldError{fieldError, key.Interface()}) } @@ -154,15 +159,15 @@ func (v *defaultValidator) ValidateStruct(obj interface{}) error { } // validateStruct receives struct type -func (v *defaultValidator) validateStruct(obj interface{}) error { +func (v *defaultValidator) validateStruct(ctx context.Context, obj interface{}) error { v.lazyinit() - return v.validate.Struct(obj) + return v.validate.StructCtx(ctx, obj) } // validateStruct receives slice, array, and map types -func (v *defaultValidator) validateVar(obj interface{}, tag string) error { +func (v *defaultValidator) validateVar(ctx context.Context, obj interface{}, tag string) error { v.lazyinit() - return v.validate.Var(obj, tag) + return v.validate.VarCtx(ctx, obj, tag) } // Engine returns the underlying validator engine which powers the default diff --git a/binding/form.go b/binding/form.go index fa2a6540..5020bfc2 100644 --- a/binding/form.go +++ b/binding/form.go @@ -5,6 +5,7 @@ package binding import ( + "context" "errors" "net/http" ) @@ -19,7 +20,11 @@ func (formBinding) Name() string { return "form" } -func (formBinding) Bind(req *http.Request, obj interface{}) error { +func (b formBinding) Bind(req *http.Request, obj interface{}) error { + return b.BindContext(context.Background(), req, obj) +} + +func (formBinding) BindContext(ctx context.Context, req *http.Request, obj interface{}) error { if err := req.ParseForm(); err != nil { return err } @@ -29,34 +34,41 @@ func (formBinding) Bind(req *http.Request, obj interface{}) error { if err := mapForm(obj, req.Form); err != nil { return err } - return validate(obj) + return validateContext(ctx, obj) } func (formPostBinding) Name() string { return "form-urlencoded" } -func (formPostBinding) Bind(req *http.Request, obj interface{}) error { +func (b formPostBinding) Bind(req *http.Request, obj interface{}) error { + return b.BindContext(context.Background(), req, obj) +} + +func (formPostBinding) BindContext(ctx context.Context, req *http.Request, obj interface{}) error { if err := req.ParseForm(); err != nil { return err } if err := mapForm(obj, req.PostForm); err != nil { return err } - return validate(obj) + return validateContext(ctx, obj) } func (formMultipartBinding) Name() string { return "multipart/form-data" } -func (formMultipartBinding) Bind(req *http.Request, obj interface{}) error { +func (b formMultipartBinding) Bind(req *http.Request, obj interface{}) error { + return b.BindContext(context.Background(), req, obj) +} + +func (formMultipartBinding) BindContext(ctx context.Context, req *http.Request, obj interface{}) error { if err := req.ParseMultipartForm(defaultMemory); err != nil { return err } if err := mappingByPtr(obj, (*multipartRequest)(req), "form"); err != nil { return err } - - return validate(obj) + return validateContext(ctx, obj) } diff --git a/binding/header.go b/binding/header.go index b99302af..55b90dfb 100644 --- a/binding/header.go +++ b/binding/header.go @@ -1,6 +1,7 @@ package binding import ( + "context" "net/http" "net/textproto" "reflect" @@ -12,13 +13,15 @@ func (headerBinding) Name() string { return "header" } -func (headerBinding) Bind(req *http.Request, obj interface{}) error { +func (b headerBinding) Bind(req *http.Request, obj interface{}) error { + return b.BindContext(context.Background(), req, obj) +} +func (headerBinding) BindContext(ctx context.Context, req *http.Request, obj interface{}) error { if err := mapHeader(obj, req.Header); err != nil { return err } - - return validate(obj) + return validateContext(ctx, obj) } func mapHeader(ptr interface{}, h map[string][]string) error { diff --git a/binding/json.go b/binding/json.go index 45aaa494..3f28c7ad 100644 --- a/binding/json.go +++ b/binding/json.go @@ -6,6 +6,7 @@ package binding import ( "bytes" + "context" "errors" "io" "net/http" @@ -30,18 +31,26 @@ func (jsonBinding) Name() string { return "json" } -func (jsonBinding) Bind(req *http.Request, obj interface{}) error { +func (b jsonBinding) Bind(req *http.Request, obj interface{}) error { + return b.BindContext(context.Background(), req, obj) +} + +func (jsonBinding) BindContext(ctx context.Context, req *http.Request, obj interface{}) error { if req == nil || req.Body == nil { return errors.New("invalid request") } - return decodeJSON(req.Body, obj) + return decodeJSON(ctx, req.Body, obj) } -func (jsonBinding) BindBody(body []byte, obj interface{}) error { - return decodeJSON(bytes.NewReader(body), obj) +func (b jsonBinding) BindBody(body []byte, obj interface{}) error { + return b.BindBodyContext(context.Background(), body, obj) } -func decodeJSON(r io.Reader, obj interface{}) error { +func (jsonBinding) BindBodyContext(ctx context.Context, body []byte, obj interface{}) error { + return decodeJSON(ctx, bytes.NewReader(body), obj) +} + +func decodeJSON(ctx context.Context, r io.Reader, obj interface{}) error { decoder := json.NewDecoder(r) if EnableDecoderUseNumber { decoder.UseNumber() @@ -52,5 +61,5 @@ func decodeJSON(r io.Reader, obj interface{}) error { if err := decoder.Decode(obj); err != nil { return err } - return validate(obj) + return validateContext(ctx, obj) } diff --git a/binding/msgpack.go b/binding/msgpack.go index 2a442996..0d93bb3b 100644 --- a/binding/msgpack.go +++ b/binding/msgpack.go @@ -9,6 +9,7 @@ package binding import ( "bytes" + "context" "io" "net/http" @@ -21,18 +22,26 @@ func (msgpackBinding) Name() string { return "msgpack" } -func (msgpackBinding) Bind(req *http.Request, obj interface{}) error { - return decodeMsgPack(req.Body, obj) +func (b msgpackBinding) Bind(req *http.Request, obj interface{}) error { + return b.BindContext(context.Background(), req, obj) } -func (msgpackBinding) BindBody(body []byte, obj interface{}) error { - return decodeMsgPack(bytes.NewReader(body), obj) +func (msgpackBinding) BindContext(ctx context.Context, req *http.Request, obj interface{}) error { + return decodeMsgPack(ctx, req.Body, obj) } -func decodeMsgPack(r io.Reader, obj interface{}) error { +func (b msgpackBinding) BindBody(body []byte, obj interface{}) error { + return b.BindBodyContext(context.Background(), body, obj) +} + +func (msgpackBinding) BindBodyContext(ctx context.Context, body []byte, obj interface{}) error { + return decodeMsgPack(ctx, bytes.NewReader(body), obj) +} + +func decodeMsgPack(ctx context.Context, r io.Reader, obj interface{}) error { cdc := new(codec.MsgpackHandle) if err := codec.NewDecoder(r, cdc).Decode(&obj); err != nil { return err } - return validate(obj) + return validateContext(ctx, obj) } diff --git a/binding/query.go b/binding/query.go index 219743f2..9fe11368 100644 --- a/binding/query.go +++ b/binding/query.go @@ -4,7 +4,10 @@ package binding -import "net/http" +import ( + "context" + "net/http" +) type queryBinding struct{} @@ -12,10 +15,14 @@ func (queryBinding) Name() string { return "query" } -func (queryBinding) Bind(req *http.Request, obj interface{}) error { +func (b queryBinding) Bind(req *http.Request, obj interface{}) error { + return b.BindContext(context.Background(), req, obj) +} + +func (queryBinding) BindContext(ctx context.Context, req *http.Request, obj interface{}) error { values := req.URL.Query() if err := mapForm(obj, values); err != nil { return err } - return validate(obj) + return validateContext(ctx, obj) } diff --git a/binding/uri.go b/binding/uri.go index a3c0df51..dd6cf655 100644 --- a/binding/uri.go +++ b/binding/uri.go @@ -4,15 +4,21 @@ package binding +import "context" + type uriBinding struct{} func (uriBinding) Name() string { return "uri" } -func (uriBinding) BindUri(m map[string][]string, obj interface{}) error { +func (b uriBinding) BindUri(m map[string][]string, obj interface{}) error { + return b.BindUriContext(context.Background(), m, obj) +} + +func (uriBinding) BindUriContext(ctx context.Context, m map[string][]string, obj interface{}) error { if err := mapURI(obj, m); err != nil { return err } - return validate(obj) + return validateContext(ctx, obj) } diff --git a/binding/validate_test.go b/binding/validate_test.go index 5299fbf6..c05bbdc8 100644 --- a/binding/validate_test.go +++ b/binding/validate_test.go @@ -6,6 +6,7 @@ package binding import ( "bytes" + "context" "testing" "time" @@ -226,3 +227,7 @@ func TestValidatorEngine(t *testing.T) { // Check that the error matches expectation assert.Error(t, errs, "", "", "notone") } + +func validate(obj interface{}) error { + return validateContext(context.Background(), obj) +} diff --git a/binding/xml.go b/binding/xml.go index 4e901149..51d2f110 100644 --- a/binding/xml.go +++ b/binding/xml.go @@ -6,6 +6,7 @@ package binding import ( "bytes" + "context" "encoding/xml" "io" "net/http" @@ -17,17 +18,26 @@ func (xmlBinding) Name() string { return "xml" } -func (xmlBinding) Bind(req *http.Request, obj interface{}) error { - return decodeXML(req.Body, obj) +func (b xmlBinding) Bind(req *http.Request, obj interface{}) error { + return b.BindContext(context.Background(), req, obj) } -func (xmlBinding) BindBody(body []byte, obj interface{}) error { - return decodeXML(bytes.NewReader(body), obj) +func (xmlBinding) BindContext(ctx context.Context, req *http.Request, obj interface{}) error { + return decodeXML(ctx, req.Body, obj) } -func decodeXML(r io.Reader, obj interface{}) error { + +func (b xmlBinding) BindBody(body []byte, obj interface{}) error { + return b.BindBodyContext(context.Background(), body, obj) +} + +func (xmlBinding) BindBodyContext(ctx context.Context, body []byte, obj interface{}) error { + return decodeXML(ctx, bytes.NewReader(body), obj) +} + +func decodeXML(ctx context.Context, r io.Reader, obj interface{}) error { decoder := xml.NewDecoder(r) if err := decoder.Decode(obj); err != nil { return err } - return validate(obj) + return validateContext(ctx, obj) } diff --git a/binding/yaml.go b/binding/yaml.go index a2d36d6a..816f67e5 100644 --- a/binding/yaml.go +++ b/binding/yaml.go @@ -6,6 +6,7 @@ package binding import ( "bytes" + "context" "io" "net/http" @@ -18,18 +19,26 @@ func (yamlBinding) Name() string { return "yaml" } -func (yamlBinding) Bind(req *http.Request, obj interface{}) error { - return decodeYAML(req.Body, obj) +func (b yamlBinding) Bind(req *http.Request, obj interface{}) error { + return b.BindContext(context.Background(), req, obj) } -func (yamlBinding) BindBody(body []byte, obj interface{}) error { - return decodeYAML(bytes.NewReader(body), obj) +func (yamlBinding) BindContext(ctx context.Context, req *http.Request, obj interface{}) error { + return decodeYAML(ctx, req.Body, obj) } -func decodeYAML(r io.Reader, obj interface{}) error { +func (b yamlBinding) BindBody(body []byte, obj interface{}) error { + return b.BindBodyContext(context.Background(), body, obj) +} + +func (yamlBinding) BindBodyContext(ctx context.Context, body []byte, obj interface{}) error { + return decodeYAML(ctx, bytes.NewReader(body), obj) +} + +func decodeYAML(ctx context.Context, r io.Reader, obj interface{}) error { decoder := yaml.NewDecoder(r) if err := decoder.Decode(obj); err != nil { return err } - return validate(obj) + return validateContext(ctx, obj) } diff --git a/context.go b/context.go index 58f38c88..e98fee72 100644 --- a/context.go +++ b/context.go @@ -704,12 +704,15 @@ func (c *Context) ShouldBindUri(obj interface{}) error { for _, v := range c.Params { m[v.Key] = []string{v.Value} } - return binding.Uri.BindUri(m, obj) + return binding.Uri.BindUriContext(c, m, obj) } // ShouldBindWith binds the passed struct pointer using the specified binding engine. // See the binding package. func (c *Context) ShouldBindWith(obj interface{}, b binding.Binding) error { + if b, ok := b.(binding.ContextBinding); ok { + return b.BindContext(c, c.Request, obj) + } return b.Bind(c.Request, obj) } @@ -732,6 +735,9 @@ func (c *Context) ShouldBindBodyWith(obj interface{}, bb binding.BindingBody) (e } c.Set(BodyBytesKey, body) } + if bb, ok := bb.(binding.ContextBindingBody); ok { + return bb.BindBodyContext(c, body, obj) + } return bb.BindBody(body, obj) } diff --git a/context_test.go b/context_test.go index c286c0f4..1e30af29 100644 --- a/context_test.go +++ b/context_test.go @@ -24,6 +24,7 @@ import ( "github.com/gin-contrib/sse" "github.com/gin-gonic/gin/binding" testdata "github.com/gin-gonic/gin/testdata/protoexample" + "github.com/go-playground/validator/v10" "github.com/stretchr/testify/assert" "google.golang.org/protobuf/proto" ) @@ -36,6 +37,16 @@ var _ context.Context = &Context{} // BAD case: func (c *Context) Render(code int, render render.Render, obj ...interface{}) { // test that information is not leaked when reusing Contexts (using the Pool) +func init() { + _ = binding.Validator.Engine().(*validator.Validate).RegisterValidationCtx( + "required_if_condition", func(ctx context.Context, fl validator.FieldLevel) bool { + if ctx.Value("condition") == true { + return !fl.Field().IsZero() + } + return true + }) +} + func createMultipartRequest() *http.Request { boundary := "--testboundary" body := new(bytes.Buffer) @@ -1543,6 +1554,27 @@ func TestContextBindWithJSON(t *testing.T) { assert.Equal(t, 0, w.Body.Len()) } +func TestContextBindWithJSONContextual(t *testing.T) { + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + + c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("{\"bar\":\"foo\"}")) + c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type + + var obj struct { + Foo string `json:"foo" binding:"required_if_condition"` + Bar string `json:"bar"` + } + c.Set("condition", true) + assert.Error(t, c.BindJSON(&obj)) + + c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}")) + assert.NoError(t, c.BindJSON(&obj)) + assert.Equal(t, "foo", obj.Bar) + assert.Equal(t, "bar", obj.Foo) + assert.Equal(t, 0, w.Body.Len()) +} + func TestContextBindWithXML(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) @@ -1672,6 +1704,92 @@ func TestContextShouldBindWithJSON(t *testing.T) { assert.Equal(t, 0, w.Body.Len()) } +func TestContextShouldBindWithJSONContextual(t *testing.T) { + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + + c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("{\"bar\":\"foo\"}")) + c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type + + var obj struct { + Foo string `json:"foo" binding:"required_if_condition"` + Bar string `json:"bar"` + } + c.Set("condition", true) + assert.Error(t, c.ShouldBindJSON(&obj)) + + c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}")) + assert.NoError(t, c.ShouldBindJSON(&obj)) + assert.Equal(t, "foo", obj.Bar) + assert.Equal(t, "bar", obj.Foo) + assert.Equal(t, 0, w.Body.Len()) +} + +func TestContextShouldBindBodyWithJSONContextual(t *testing.T) { + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + + var obj struct { + Foo string `json:"foo" binding:"required_if_condition"` + Bar string `json:"bar"` + } + c.Set("condition", true) + c.Set(BodyBytesKey, []byte("{\"bar\":\"foo\"}")) + assert.Error(t, c.ShouldBindBodyWith(&obj, binding.JSON)) + + c.Set(BodyBytesKey, []byte("{\"foo\":\"bar\", \"bar\":\"foo\"}")) + assert.NoError(t, c.ShouldBindBodyWith(&obj, binding.JSON)) + assert.Equal(t, "foo", obj.Bar) + assert.Equal(t, "bar", obj.Foo) + assert.Equal(t, 0, w.Body.Len()) +} + +func TestContextShouldBindWithNotContextBinding(t *testing.T) { + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + + c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}")) + c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type + + var obj struct { + Foo string `json:"foo" binding:"required_if_condition"` + Bar string `json:"bar"` + } + assert.NoError(t, c.ShouldBindWith(&obj, notContextBinding{})) + assert.Equal(t, "foo", obj.Bar) + assert.Equal(t, "bar", obj.Foo) + assert.Equal(t, 0, w.Body.Len()) +} + +func TestContextShouldBindBodyWithNotContextBinding(t *testing.T) { + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + + var obj struct { + Foo string `json:"foo"` + Bar string `json:"bar"` + } + c.Set(BodyBytesKey, []byte("{\"foo\":\"bar\", \"bar\":\"foo\"}")) + assert.NoError(t, c.ShouldBindBodyWith(&obj, notContextBinding{})) + assert.Equal(t, "foo", obj.Bar) + assert.Equal(t, "bar", obj.Foo) + assert.Equal(t, 0, w.Body.Len()) +} + +type notContextBinding struct{} + +func (notContextBinding) Name() string { + return binding.JSON.Name() +} + +func (b notContextBinding) Bind(req *http.Request, obj interface{}) error { + return binding.JSON.Bind(req, obj) +} + +func (b notContextBinding) BindBody(body []byte, obj interface{}) error { + return binding.JSON.BindBody(body, obj) +} + func TestContextShouldBindWithXML(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w)