Add support for contextual validation

By passing the Gin context to bindings, custom validators can take
advantage of the information in the context.
This commit is contained in:
Krzysztof Szafrański 2021-11-05 18:16:37 +01:00
parent f26790bbda
commit b626f63906
16 changed files with 410 additions and 73 deletions

View File

@ -7,7 +7,10 @@
package binding
import "net/http"
import (
"context"
"net/http"
)
// Content-Type MIME of the most common data formats.
const (
@ -32,20 +35,41 @@ type Binding interface {
Bind(*http.Request, interface{}) error
}
// BindingBody adds BindBody method to Binding. BindBody is similar with Bind,
// ContextBinding enables contextual validation by adding BindContext to Binding.
// Custom validators can take advantage of the information in the context.
type ContextBinding interface {
Binding
BindContext(context.Context, *http.Request, interface{}) error
}
// BindingBody adds BindBody method to Binding. BindBody is similar to Bind,
// but it reads the body from supplied bytes instead of req.Body.
type BindingBody interface {
Binding
BindBody([]byte, interface{}) error
}
// BindingUri adds BindUri method to Binding. BindUri is similar with Bind,
// but it read the Params.
// ContextBindingBody enables contextual validation by adding BindBodyContext to BindingBody.
// Custom validators can take advantage of the information in the context.
type ContextBindingBody interface {
BindingBody
BindContext(context.Context, *http.Request, interface{}) error
BindBodyContext(context.Context, []byte, interface{}) error
}
// BindingUri is similar to Bind, but it read the Params.
type BindingUri interface {
Name() string
BindUri(map[string][]string, interface{}) error
}
// ContextBindingUri enables contextual validation by adding BindUriContext to BindingUri.
// Custom validators can take advantage of the information in the context.
type ContextBindingUri interface {
BindingUri
BindUriContext(context.Context, map[string][]string, interface{}) error
}
// StructValidator is the minimal interface which needs to be implemented in
// order for it to be used as the validator engine for ensuring the correctness
// of the request. Gin provides a default implementation for this using
@ -64,6 +88,14 @@ type StructValidator interface {
Engine() interface{}
}
// ContextStructValidator is an extension of StructValidator that requires implementing
// context-aware validation.
// Custom validators can take advantage of the information in the context.
type ContextStructValidator interface {
StructValidator
ValidateStructContext(context.Context, interface{}) error
}
// Validator is the default validator which implements the StructValidator
// interface. It uses https://github.com/go-playground/validator/tree/v10.6.1
// under the hood.
@ -110,9 +142,12 @@ func Default(method, contentType string) Binding {
}
}
func validate(obj interface{}) error {
func validateContext(ctx context.Context, obj interface{}) error {
if Validator == nil {
return nil
}
if v, ok := Validator.(ContextStructValidator); ok {
return v.ValidateStructContext(ctx, obj)
}
return Validator.ValidateStruct(obj)
}

View File

@ -9,6 +9,7 @@ package binding
import (
"bytes"
"context"
"testing"
"github.com/stretchr/testify/assert"
@ -35,7 +36,7 @@ func TestBindingMsgPack(t *testing.T) {
string(data), string(data[1:]))
}
func testMsgPackBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody string) {
func testMsgPackBodyBinding(t *testing.T, b ContextBinding, name, path, badPath, body, badBody string) {
assert.Equal(t, name, b.Name())
obj := FooStruct{}
@ -48,7 +49,17 @@ func testMsgPackBodyBinding(t *testing.T, b Binding, name, path, badPath, body,
obj = FooStruct{}
req = requestWithBody("POST", badPath, badBody)
req.Header.Add("Content-Type", MIMEMSGPACK)
err = MsgPack.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)
obj2 := ConditionalFooStruct{}
req = requestWithBody("POST", path, body)
req.Header.Add("Content-Type", MIMEMSGPACK)
err = b.BindContext(context.Background(), req, &obj2)
assert.NoError(t, err)
assert.Equal(t, "bar", obj2.Foo)
err = b.BindContext(context.WithValue(context.Background(), "condition", true), req, &obj2) // nolint
assert.Error(t, err)
}

View File

@ -7,7 +7,10 @@
package binding
import "net/http"
import (
"context"
"net/http"
)
// Content-Type MIME of the most common data formats.
const (
@ -30,20 +33,41 @@ type Binding interface {
Bind(*http.Request, interface{}) error
}
// BindingBody adds BindBody method to Binding. BindBody is similar with Bind,
// ContextBinding enables contextual validation by adding BindContext to Binding.
// Custom validators can take advantage of the information in the context.
type ContextBinding interface {
Binding
BindContext(context.Context, *http.Request, interface{}) error
}
// BindingBody adds BindBody method to Binding. BindBody is similar to Bind,
// but it reads the body from supplied bytes instead of req.Body.
type BindingBody interface {
Binding
BindBody([]byte, interface{}) error
}
// BindingUri adds BindUri method to Binding. BindUri is similar with Bind,
// but it read the Params.
// ContextBindingBody enables contextual validation by adding BindBodyContext to BindingBody.
// Custom validators can take advantage of the information in the context.
type ContextBindingBody interface {
BindingBody
BindContext(context.Context, *http.Request, interface{}) error
BindBodyContext(context.Context, []byte, interface{}) error
}
// BindingUri is similar to Bind, but it read the Params.
type BindingUri interface {
Name() string
BindUri(map[string][]string, interface{}) error
}
// ContextBindingUri enables contextual validation by adding BindUriContext to BindingUri.
// Custom validators can take advantage of the information in the context.
type ContextBindingUri interface {
BindingUri
BindUriContext(context.Context, map[string][]string, interface{}) error
}
// StructValidator is the minimal interface which needs to be implemented in
// order for it to be used as the validator engine for ensuring the correctness
// of the request. Gin provides a default implementation for this using
@ -62,6 +86,14 @@ type StructValidator interface {
Engine() interface{}
}
// ContextStructValidator is an extension of StructValidator that requires implementing
// context-aware validation.
// Custom validators can take advantage of the information in the context.
type ContextStructValidator interface {
StructValidator
ValidateStructContext(context.Context, interface{}) error
}
// Validator is the default validator which implements the StructValidator
// interface. It uses https://github.com/go-playground/validator/tree/v10.6.1
// under the hood.
@ -85,7 +117,7 @@ var (
// Default returns the appropriate Binding instance based on the HTTP method
// and the content type.
func Default(method, contentType string) Binding {
if method == "GET" {
if method == http.MethodGet {
return Form
}
@ -105,9 +137,12 @@ func Default(method, contentType string) Binding {
}
}
func validate(obj interface{}) error {
func validateContext(ctx context.Context, obj interface{}) error {
if Validator == nil {
return nil
}
if v, ok := Validator.(ContextStructValidator); ok {
return v.ValidateStructContext(ctx, obj)
}
return Validator.ValidateStruct(obj)
}

View File

@ -6,6 +6,7 @@ package binding
import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
@ -20,6 +21,7 @@ import (
"time"
"github.com/gin-gonic/gin/testdata/protoexample"
"github.com/go-playground/validator/v10"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
)
@ -38,6 +40,10 @@ type FooStruct struct {
Foo string `msgpack:"foo" json:"foo" form:"foo" xml:"foo" binding:"required,max=32"`
}
type ConditionalFooStruct struct {
Foo string `msgpack:"foo" json:"foo" form:"foo" xml:"foo" binding:"required_if_condition,max=32"`
}
type FooBarStruct struct {
FooStruct
Bar string `msgpack:"bar" json:"bar" form:"bar" xml:"bar" binding:"required"`
@ -144,6 +150,16 @@ type FooStructForMapPtrType struct {
PtrBar *map[string]interface{} `form:"ptr_bar"`
}
func init() {
_ = Validator.Engine().(*validator.Validate).RegisterValidationCtx(
"required_if_condition", func(ctx context.Context, fl validator.FieldLevel) bool {
if ctx.Value("condition") == true {
return !fl.Field().IsZero()
}
return true
})
}
func TestBindingDefault(t *testing.T) {
assert.Equal(t, Form, Default("GET", ""))
assert.Equal(t, Form, Default("GET", MIMEJSON))
@ -796,6 +812,38 @@ func TestUriBinding(t *testing.T) {
assert.Equal(t, map[string]interface{}(nil), not.Name)
}
func TestUriBindingWithContext(t *testing.T) {
b := Uri
type Tag struct {
Name string `uri:"name" binding:"required_if_condition"`
}
empty := make(map[string][]string)
assert.NoError(t, b.BindUriContext(context.Background(), empty, new(Tag)))
assert.Error(t, b.BindUriContext(context.WithValue(context.Background(), "condition", true), empty, new(Tag))) // nolint
}
func TestUriBindingWithNotContextValidator(t *testing.T) {
prev := Validator
defer func() {
Validator = prev
}()
Validator = &notContextValidator{}
TestUriBinding(t)
}
type notContextValidator defaultValidator
func (v *notContextValidator) ValidateStruct(obj interface{}) error {
return (*defaultValidator)(v).ValidateStruct(obj)
}
func (v *notContextValidator) Engine() interface{} {
return (*defaultValidator)(v).Engine()
}
func TestUriInnerBinding(t *testing.T) {
type Tag struct {
Name string `uri:"name"`
@ -1179,7 +1227,7 @@ func testQueryBindingBoolFail(t *testing.T, method, path, badPath, body, badBody
assert.Error(t, err)
}
func testBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody string) {
func testBodyBinding(t *testing.T, b ContextBinding, name, path, badPath, body, badBody string) {
assert.Equal(t, name, b.Name())
obj := FooStruct{}
@ -1190,7 +1238,16 @@ func testBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody
obj = FooStruct{}
req = requestWithBody("POST", badPath, badBody)
err = JSON.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)
obj2 := ConditionalFooStruct{}
req = requestWithBody("POST", path, body)
err = b.BindContext(context.Background(), req, &obj2)
assert.NoError(t, err)
assert.Equal(t, "bar", obj2.Foo)
err = b.BindContext(context.WithValue(context.Background(), "condition", true), req, &obj2) // nolint
assert.Error(t, err)
}
@ -1204,7 +1261,7 @@ func testBodyBindingSlice(t *testing.T, b Binding, name, path, badPath, body, ba
var obj2 []FooStruct
req = requestWithBody("POST", badPath, badBody)
err = JSON.Bind(req, &obj2)
err = b.Bind(req, &obj2)
assert.Error(t, err)
}
@ -1249,7 +1306,7 @@ func testBodyBindingUseNumber(t *testing.T, b Binding, name, path, badPath, body
obj = FooStructUseNumber{}
req = requestWithBody("POST", badPath, badBody)
err = JSON.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)
}
@ -1267,7 +1324,7 @@ func testBodyBindingUseNumber2(t *testing.T, b Binding, name, path, badPath, bod
obj = FooStructUseNumber{}
req = requestWithBody("POST", badPath, badBody)
err = JSON.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)
}
@ -1285,7 +1342,7 @@ func testBodyBindingDisallowUnknownFields(t *testing.T, b Binding, path, badPath
obj = FooStructDisallowUnknownFields{}
req = requestWithBody("POST", badPath, badBody)
err = JSON.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)
assert.Contains(t, err.Error(), "what")
}
@ -1301,7 +1358,7 @@ func testBodyBindingFail(t *testing.T, b Binding, name, path, badPath, body, bad
obj = FooStruct{}
req = requestWithBody("POST", badPath, badBody)
err = JSON.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)
}
@ -1318,7 +1375,7 @@ func testProtoBodyBinding(t *testing.T, b Binding, name, path, badPath, body, ba
obj = protoexample.Test{}
req = requestWithBody("POST", badPath, badBody)
req.Header.Add("Content-Type", MIMEPROTOBUF)
err = ProtoBuf.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)
}
@ -1349,7 +1406,7 @@ func testProtoBodyBindingFail(t *testing.T, b Binding, name, path, badPath, body
obj = protoexample.Test{}
req = requestWithBody("POST", badPath, badBody)
req.Header.Add("Content-Type", MIMEPROTOBUF)
err = ProtoBuf.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)
}

View File

@ -5,6 +5,7 @@
package binding
import (
"context"
"fmt"
"reflect"
"sync"
@ -92,10 +93,14 @@ func (fe mapFieldError) Unwrap() error {
return fe.FieldError
}
var _ StructValidator = &defaultValidator{}
var _ ContextStructValidator = &defaultValidator{}
// ValidateStruct receives any kind of type, but validates only structs, pointers, slices, arrays, and maps.
func (v *defaultValidator) ValidateStruct(obj interface{}) error {
return v.ValidateStructContext(context.Background(), obj)
}
func (v *defaultValidator) ValidateStructContext(ctx context.Context, obj interface{}) error {
if obj == nil {
return nil
}
@ -103,21 +108,21 @@ func (v *defaultValidator) ValidateStruct(obj interface{}) error {
value := reflect.ValueOf(obj)
switch value.Kind() {
case reflect.Ptr:
return v.ValidateStruct(value.Elem().Interface())
return v.ValidateStructContext(ctx, value.Elem().Interface())
case reflect.Struct:
return v.validateStruct(obj)
return v.validateStruct(ctx, obj)
case reflect.Slice, reflect.Array:
var errs validator.ValidationErrors
if tag, ok := validatorTags[value.Type()]; ok {
if err := v.validateVar(obj, tag); err != nil {
if err := v.validateVar(ctx, obj, tag); err != nil {
errs = append(errs, err.(validator.ValidationErrors)...) // nolint: errorlint
}
}
count := value.Len()
for i := 0; i < count; i++ {
if err := v.ValidateStruct(value.Index(i).Interface()); err != nil {
if err := v.ValidateStructContext(ctx, value.Index(i).Interface()); err != nil {
for _, fieldError := range err.(validator.ValidationErrors) { // nolint: errorlint
errs = append(errs, sliceFieldError{fieldError, i})
}
@ -132,13 +137,13 @@ func (v *defaultValidator) ValidateStruct(obj interface{}) error {
var errs validator.ValidationErrors
if tag, ok := validatorTags[value.Type()]; ok {
if err := v.validateVar(obj, tag); err != nil {
if err := v.validateVar(ctx, obj, tag); err != nil {
errs = append(errs, err.(validator.ValidationErrors)...) // nolint: errorlint
}
}
for _, key := range value.MapKeys() {
if err := v.ValidateStruct(value.MapIndex(key).Interface()); err != nil {
if err := v.ValidateStructContext(ctx, value.MapIndex(key).Interface()); err != nil {
for _, fieldError := range err.(validator.ValidationErrors) { // nolint: errorlint
errs = append(errs, mapFieldError{fieldError, key.Interface()})
}
@ -154,15 +159,15 @@ func (v *defaultValidator) ValidateStruct(obj interface{}) error {
}
// validateStruct receives struct type
func (v *defaultValidator) validateStruct(obj interface{}) error {
func (v *defaultValidator) validateStruct(ctx context.Context, obj interface{}) error {
v.lazyinit()
return v.validate.Struct(obj)
return v.validate.StructCtx(ctx, obj)
}
// validateStruct receives slice, array, and map types
func (v *defaultValidator) validateVar(obj interface{}, tag string) error {
func (v *defaultValidator) validateVar(ctx context.Context, obj interface{}, tag string) error {
v.lazyinit()
return v.validate.Var(obj, tag)
return v.validate.VarCtx(ctx, obj, tag)
}
// Engine returns the underlying validator engine which powers the default

View File

@ -5,6 +5,7 @@
package binding
import (
"context"
"errors"
"net/http"
)
@ -19,7 +20,11 @@ func (formBinding) Name() string {
return "form"
}
func (formBinding) Bind(req *http.Request, obj interface{}) error {
func (b formBinding) Bind(req *http.Request, obj interface{}) error {
return b.BindContext(context.Background(), req, obj)
}
func (formBinding) BindContext(ctx context.Context, req *http.Request, obj interface{}) error {
if err := req.ParseForm(); err != nil {
return err
}
@ -29,34 +34,41 @@ func (formBinding) Bind(req *http.Request, obj interface{}) error {
if err := mapForm(obj, req.Form); err != nil {
return err
}
return validate(obj)
return validateContext(ctx, obj)
}
func (formPostBinding) Name() string {
return "form-urlencoded"
}
func (formPostBinding) Bind(req *http.Request, obj interface{}) error {
func (b formPostBinding) Bind(req *http.Request, obj interface{}) error {
return b.BindContext(context.Background(), req, obj)
}
func (formPostBinding) BindContext(ctx context.Context, req *http.Request, obj interface{}) error {
if err := req.ParseForm(); err != nil {
return err
}
if err := mapForm(obj, req.PostForm); err != nil {
return err
}
return validate(obj)
return validateContext(ctx, obj)
}
func (formMultipartBinding) Name() string {
return "multipart/form-data"
}
func (formMultipartBinding) Bind(req *http.Request, obj interface{}) error {
func (b formMultipartBinding) Bind(req *http.Request, obj interface{}) error {
return b.BindContext(context.Background(), req, obj)
}
func (formMultipartBinding) BindContext(ctx context.Context, req *http.Request, obj interface{}) error {
if err := req.ParseMultipartForm(defaultMemory); err != nil {
return err
}
if err := mappingByPtr(obj, (*multipartRequest)(req), "form"); err != nil {
return err
}
return validate(obj)
return validateContext(ctx, obj)
}

View File

@ -1,6 +1,7 @@
package binding
import (
"context"
"net/http"
"net/textproto"
"reflect"
@ -12,13 +13,15 @@ func (headerBinding) Name() string {
return "header"
}
func (headerBinding) Bind(req *http.Request, obj interface{}) error {
func (b headerBinding) Bind(req *http.Request, obj interface{}) error {
return b.BindContext(context.Background(), req, obj)
}
func (headerBinding) BindContext(ctx context.Context, req *http.Request, obj interface{}) error {
if err := mapHeader(obj, req.Header); err != nil {
return err
}
return validate(obj)
return validateContext(ctx, obj)
}
func mapHeader(ptr interface{}, h map[string][]string) error {

View File

@ -6,6 +6,7 @@ package binding
import (
"bytes"
"context"
"errors"
"io"
"net/http"
@ -30,18 +31,26 @@ func (jsonBinding) Name() string {
return "json"
}
func (jsonBinding) Bind(req *http.Request, obj interface{}) error {
func (b jsonBinding) Bind(req *http.Request, obj interface{}) error {
return b.BindContext(context.Background(), req, obj)
}
func (jsonBinding) BindContext(ctx context.Context, req *http.Request, obj interface{}) error {
if req == nil || req.Body == nil {
return errors.New("invalid request")
}
return decodeJSON(req.Body, obj)
return decodeJSON(ctx, req.Body, obj)
}
func (jsonBinding) BindBody(body []byte, obj interface{}) error {
return decodeJSON(bytes.NewReader(body), obj)
func (b jsonBinding) BindBody(body []byte, obj interface{}) error {
return b.BindBodyContext(context.Background(), body, obj)
}
func decodeJSON(r io.Reader, obj interface{}) error {
func (jsonBinding) BindBodyContext(ctx context.Context, body []byte, obj interface{}) error {
return decodeJSON(ctx, bytes.NewReader(body), obj)
}
func decodeJSON(ctx context.Context, r io.Reader, obj interface{}) error {
decoder := json.NewDecoder(r)
if EnableDecoderUseNumber {
decoder.UseNumber()
@ -52,5 +61,5 @@ func decodeJSON(r io.Reader, obj interface{}) error {
if err := decoder.Decode(obj); err != nil {
return err
}
return validate(obj)
return validateContext(ctx, obj)
}

View File

@ -9,6 +9,7 @@ package binding
import (
"bytes"
"context"
"io"
"net/http"
@ -21,18 +22,26 @@ func (msgpackBinding) Name() string {
return "msgpack"
}
func (msgpackBinding) Bind(req *http.Request, obj interface{}) error {
return decodeMsgPack(req.Body, obj)
func (b msgpackBinding) Bind(req *http.Request, obj interface{}) error {
return b.BindContext(context.Background(), req, obj)
}
func (msgpackBinding) BindBody(body []byte, obj interface{}) error {
return decodeMsgPack(bytes.NewReader(body), obj)
func (msgpackBinding) BindContext(ctx context.Context, req *http.Request, obj interface{}) error {
return decodeMsgPack(ctx, req.Body, obj)
}
func decodeMsgPack(r io.Reader, obj interface{}) error {
func (b msgpackBinding) BindBody(body []byte, obj interface{}) error {
return b.BindBodyContext(context.Background(), body, obj)
}
func (msgpackBinding) BindBodyContext(ctx context.Context, body []byte, obj interface{}) error {
return decodeMsgPack(ctx, bytes.NewReader(body), obj)
}
func decodeMsgPack(ctx context.Context, r io.Reader, obj interface{}) error {
cdc := new(codec.MsgpackHandle)
if err := codec.NewDecoder(r, cdc).Decode(&obj); err != nil {
return err
}
return validate(obj)
return validateContext(ctx, obj)
}

View File

@ -4,7 +4,10 @@
package binding
import "net/http"
import (
"context"
"net/http"
)
type queryBinding struct{}
@ -12,10 +15,14 @@ func (queryBinding) Name() string {
return "query"
}
func (queryBinding) Bind(req *http.Request, obj interface{}) error {
func (b queryBinding) Bind(req *http.Request, obj interface{}) error {
return b.BindContext(context.Background(), req, obj)
}
func (queryBinding) BindContext(ctx context.Context, req *http.Request, obj interface{}) error {
values := req.URL.Query()
if err := mapForm(obj, values); err != nil {
return err
}
return validate(obj)
return validateContext(ctx, obj)
}

View File

@ -4,15 +4,21 @@
package binding
import "context"
type uriBinding struct{}
func (uriBinding) Name() string {
return "uri"
}
func (uriBinding) BindUri(m map[string][]string, obj interface{}) error {
func (b uriBinding) BindUri(m map[string][]string, obj interface{}) error {
return b.BindUriContext(context.Background(), m, obj)
}
func (uriBinding) BindUriContext(ctx context.Context, m map[string][]string, obj interface{}) error {
if err := mapURI(obj, m); err != nil {
return err
}
return validate(obj)
return validateContext(ctx, obj)
}

View File

@ -6,6 +6,7 @@ package binding
import (
"bytes"
"context"
"testing"
"time"
@ -226,3 +227,7 @@ func TestValidatorEngine(t *testing.T) {
// Check that the error matches expectation
assert.Error(t, errs, "", "", "notone")
}
func validate(obj interface{}) error {
return validateContext(context.Background(), obj)
}

View File

@ -6,6 +6,7 @@ package binding
import (
"bytes"
"context"
"encoding/xml"
"io"
"net/http"
@ -17,17 +18,26 @@ func (xmlBinding) Name() string {
return "xml"
}
func (xmlBinding) Bind(req *http.Request, obj interface{}) error {
return decodeXML(req.Body, obj)
func (b xmlBinding) Bind(req *http.Request, obj interface{}) error {
return b.BindContext(context.Background(), req, obj)
}
func (xmlBinding) BindBody(body []byte, obj interface{}) error {
return decodeXML(bytes.NewReader(body), obj)
func (xmlBinding) BindContext(ctx context.Context, req *http.Request, obj interface{}) error {
return decodeXML(ctx, req.Body, obj)
}
func decodeXML(r io.Reader, obj interface{}) error {
func (b xmlBinding) BindBody(body []byte, obj interface{}) error {
return b.BindBodyContext(context.Background(), body, obj)
}
func (xmlBinding) BindBodyContext(ctx context.Context, body []byte, obj interface{}) error {
return decodeXML(ctx, bytes.NewReader(body), obj)
}
func decodeXML(ctx context.Context, r io.Reader, obj interface{}) error {
decoder := xml.NewDecoder(r)
if err := decoder.Decode(obj); err != nil {
return err
}
return validate(obj)
return validateContext(ctx, obj)
}

View File

@ -6,6 +6,7 @@ package binding
import (
"bytes"
"context"
"io"
"net/http"
@ -18,18 +19,26 @@ func (yamlBinding) Name() string {
return "yaml"
}
func (yamlBinding) Bind(req *http.Request, obj interface{}) error {
return decodeYAML(req.Body, obj)
func (b yamlBinding) Bind(req *http.Request, obj interface{}) error {
return b.BindContext(context.Background(), req, obj)
}
func (yamlBinding) BindBody(body []byte, obj interface{}) error {
return decodeYAML(bytes.NewReader(body), obj)
func (yamlBinding) BindContext(ctx context.Context, req *http.Request, obj interface{}) error {
return decodeYAML(ctx, req.Body, obj)
}
func decodeYAML(r io.Reader, obj interface{}) error {
func (b yamlBinding) BindBody(body []byte, obj interface{}) error {
return b.BindBodyContext(context.Background(), body, obj)
}
func (yamlBinding) BindBodyContext(ctx context.Context, body []byte, obj interface{}) error {
return decodeYAML(ctx, bytes.NewReader(body), obj)
}
func decodeYAML(ctx context.Context, r io.Reader, obj interface{}) error {
decoder := yaml.NewDecoder(r)
if err := decoder.Decode(obj); err != nil {
return err
}
return validate(obj)
return validateContext(ctx, obj)
}

View File

@ -704,12 +704,15 @@ func (c *Context) ShouldBindUri(obj interface{}) error {
for _, v := range c.Params {
m[v.Key] = []string{v.Value}
}
return binding.Uri.BindUri(m, obj)
return binding.Uri.BindUriContext(c, m, obj)
}
// ShouldBindWith binds the passed struct pointer using the specified binding engine.
// See the binding package.
func (c *Context) ShouldBindWith(obj interface{}, b binding.Binding) error {
if b, ok := b.(binding.ContextBinding); ok {
return b.BindContext(c, c.Request, obj)
}
return b.Bind(c.Request, obj)
}
@ -732,6 +735,9 @@ func (c *Context) ShouldBindBodyWith(obj interface{}, bb binding.BindingBody) (e
}
c.Set(BodyBytesKey, body)
}
if bb, ok := bb.(binding.ContextBindingBody); ok {
return bb.BindBodyContext(c, body, obj)
}
return bb.BindBody(body, obj)
}

View File

@ -24,6 +24,7 @@ import (
"github.com/gin-contrib/sse"
"github.com/gin-gonic/gin/binding"
testdata "github.com/gin-gonic/gin/testdata/protoexample"
"github.com/go-playground/validator/v10"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
)
@ -36,6 +37,16 @@ var _ context.Context = &Context{}
// BAD case: func (c *Context) Render(code int, render render.Render, obj ...interface{}) {
// test that information is not leaked when reusing Contexts (using the Pool)
func init() {
_ = binding.Validator.Engine().(*validator.Validate).RegisterValidationCtx(
"required_if_condition", func(ctx context.Context, fl validator.FieldLevel) bool {
if ctx.Value("condition") == true {
return !fl.Field().IsZero()
}
return true
})
}
func createMultipartRequest() *http.Request {
boundary := "--testboundary"
body := new(bytes.Buffer)
@ -1543,6 +1554,27 @@ func TestContextBindWithJSON(t *testing.T) {
assert.Equal(t, 0, w.Body.Len())
}
func TestContextBindWithJSONContextual(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("{\"bar\":\"foo\"}"))
c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type
var obj struct {
Foo string `json:"foo" binding:"required_if_condition"`
Bar string `json:"bar"`
}
c.Set("condition", true)
assert.Error(t, c.BindJSON(&obj))
c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}"))
assert.NoError(t, c.BindJSON(&obj))
assert.Equal(t, "foo", obj.Bar)
assert.Equal(t, "bar", obj.Foo)
assert.Equal(t, 0, w.Body.Len())
}
func TestContextBindWithXML(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
@ -1672,6 +1704,92 @@ func TestContextShouldBindWithJSON(t *testing.T) {
assert.Equal(t, 0, w.Body.Len())
}
func TestContextShouldBindWithJSONContextual(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("{\"bar\":\"foo\"}"))
c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type
var obj struct {
Foo string `json:"foo" binding:"required_if_condition"`
Bar string `json:"bar"`
}
c.Set("condition", true)
assert.Error(t, c.ShouldBindJSON(&obj))
c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}"))
assert.NoError(t, c.ShouldBindJSON(&obj))
assert.Equal(t, "foo", obj.Bar)
assert.Equal(t, "bar", obj.Foo)
assert.Equal(t, 0, w.Body.Len())
}
func TestContextShouldBindBodyWithJSONContextual(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
var obj struct {
Foo string `json:"foo" binding:"required_if_condition"`
Bar string `json:"bar"`
}
c.Set("condition", true)
c.Set(BodyBytesKey, []byte("{\"bar\":\"foo\"}"))
assert.Error(t, c.ShouldBindBodyWith(&obj, binding.JSON))
c.Set(BodyBytesKey, []byte("{\"foo\":\"bar\", \"bar\":\"foo\"}"))
assert.NoError(t, c.ShouldBindBodyWith(&obj, binding.JSON))
assert.Equal(t, "foo", obj.Bar)
assert.Equal(t, "bar", obj.Foo)
assert.Equal(t, 0, w.Body.Len())
}
func TestContextShouldBindWithNotContextBinding(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}"))
c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type
var obj struct {
Foo string `json:"foo" binding:"required_if_condition"`
Bar string `json:"bar"`
}
assert.NoError(t, c.ShouldBindWith(&obj, notContextBinding{}))
assert.Equal(t, "foo", obj.Bar)
assert.Equal(t, "bar", obj.Foo)
assert.Equal(t, 0, w.Body.Len())
}
func TestContextShouldBindBodyWithNotContextBinding(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)
var obj struct {
Foo string `json:"foo"`
Bar string `json:"bar"`
}
c.Set(BodyBytesKey, []byte("{\"foo\":\"bar\", \"bar\":\"foo\"}"))
assert.NoError(t, c.ShouldBindBodyWith(&obj, notContextBinding{}))
assert.Equal(t, "foo", obj.Bar)
assert.Equal(t, "bar", obj.Foo)
assert.Equal(t, 0, w.Body.Len())
}
type notContextBinding struct{}
func (notContextBinding) Name() string {
return binding.JSON.Name()
}
func (b notContextBinding) Bind(req *http.Request, obj interface{}) error {
return binding.JSON.Bind(req, obj)
}
func (b notContextBinding) BindBody(body []byte, obj interface{}) error {
return binding.JSON.BindBody(body, obj)
}
func TestContextShouldBindWithXML(t *testing.T) {
w := httptest.NewRecorder()
c, _ := CreateTestContext(w)