diff --git a/binding/form_mapping.go b/binding/form_mapping.go index b81ad195..e4fb1b8b 100644 --- a/binding/form_mapping.go +++ b/binding/form_mapping.go @@ -5,6 +5,7 @@ package binding import ( + "encoding" "errors" "fmt" "reflect" @@ -16,6 +17,11 @@ import ( "github.com/gin-gonic/gin/internal/json" ) +var ( + typeTextUnmarshaler = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() + typeTime = reflect.TypeOf(time.Time{}) +) + var errUnknownType = errors.New("unknown type") func mapUri(ptr interface{}, m map[string][]string) error { @@ -205,11 +211,15 @@ func setWithProperType(val string, value reflect.Value, field reflect.StructFiel case reflect.String: value.SetString(val) case reflect.Struct: - switch value.Interface().(type) { - case time.Time: + tValue := value.Type() + switch { + case tValue == typeTime: return setTimeField(val, field, value) + case reflect.PtrTo(tValue).Implements(typeTextUnmarshaler): + return setTextUnmarshalerField(val, field, value.Addr()) + default: + return json.Unmarshal(bytesconv.StringToBytes(val), value.Addr().Interface()) } - return json.Unmarshal(bytesconv.StringToBytes(val), value.Addr().Interface()) case reflect.Map: return json.Unmarshal(bytesconv.StringToBytes(val), value.Addr().Interface()) default: @@ -218,6 +228,11 @@ func setWithProperType(val string, value reflect.Value, field reflect.StructFiel return nil } +func setTextUnmarshalerField(val string, field reflect.StructField, value reflect.Value) error { + u := value.Interface().(encoding.TextUnmarshaler) + return u.UnmarshalText(bytesconv.StringToBytes(val)) +} + func setIntField(val string, bitSize int, field reflect.Value) error { if val == "" { val = "0" diff --git a/binding/form_mapping_test.go b/binding/form_mapping_test.go index 2a560371..ba06a01c 100644 --- a/binding/form_mapping_test.go +++ b/binding/form_mapping_test.go @@ -5,6 +5,7 @@ package binding import ( + "github.com/gin-gonic/gin/internal/bytesconv" "reflect" "testing" "time" @@ -279,3 +280,43 @@ func TestMappingIgnoredCircularRef(t *testing.T) { err := mappingByPtr(&s, formSource{}, "form") assert.NoError(t, err) } + +type TestStringWrapper struct { + defined bool + val string +} + +func (t *TestStringWrapper) UnmarshalText(text []byte) error { + t.defined = true + t.val = bytesconv.BytesToString(text) + return nil +} + +func (t *TestStringWrapper) String() string { + return t.val +} + +func (t *TestStringWrapper) Undefined() bool { + return !t.defined +} + +func TestMappingTextUnmarshaler(t *testing.T) { + type Query struct { + Name TestStringWrapper `json:"name" form:"name"` + } + + q := Query{} + err := mappingByPtr(&q, formSource{}, "form") + assert.NoError(t, err) + assert.True(t, q.Name.Undefined()) + assert.Empty(t, q.Name.String()) + + form := map[string][]string{ + "name": {"test"}, + } + + err = mappingByPtr(&q, formSource(form), "form") + assert.NoError(t, err) + assert.False(t, q.Name.Undefined()) + assert.Equal(t, "test", q.Name.String()) +}