diff --git a/README.md b/README.md index 4aa638d6..7bab0267 100644 --- a/README.md +++ b/README.md @@ -2080,6 +2080,67 @@ func ListHandler(s *Service) func(ctx *gin.Context) { } ``` +### Bind form-data request with custom field type + +Gin can support the encoding.TextUnmarshaler interface for non-struct types + +```go +type HexInteger int + +func (f *HexInteger) UnmarshalText(text []byte) error { + v, err := strconv.ParseInt(string(text), 16, 64) + if err != nil { + return err + } + *f = HexInteger(v) + return nil +} + +type FormA struct { + FieldA HexInteger `form:"field_a"` +} + +// query with field_a = "0f" +func GetDataA(c *gin.Context) { + var a FormA + c.Bind(&a) + // a.FieldA == 15 +} +``` + +For struct types, you can implement your own custom Unmarshaler using the `binding.BindUnmarshaler` +interface, which has the interface signature of `UnmarshalParam(param string) error`. + +```go +type customType struct { + Protocol string + Path string + Name string +} + +func (f *customType) UnmarshalParam(param string) error { + parts := strings.Split(param, ":") + if len(parts) != 3 { + return fmt.Errorf("invalid format") + } + f.Protocol = parts[0] + f.Path = parts[1] + f.Name = parts[2] + return nil +} + +type FormA struct { + FieldA customType `form:"field_a"` +} + +// query with field_a = "file:/:foo" +func GetDataA(c *gin.Context) { + var a FormA + c.Bind(&a) + // a.FieldA.Protocol == "file", a.FieldA.Path == "/", and a.FieldA.Name == "foo" +} +``` + ### http2 server push http.Pusher is supported only **go1.8+**. See the [golang blog](https://blog.golang.org/h2push) for detail information. diff --git a/binding/binding.go b/binding/binding.go index 703a1cf8..b54b0a8e 100644 --- a/binding/binding.go +++ b/binding/binding.go @@ -46,6 +46,10 @@ type BindingUri interface { BindUri(map[string][]string, any) error } +type BindUnmarshaler interface { + UnmarshalParam(param string) error +} + // StructValidator is the minimal interface which needs to be implemented in // order for it to be used as the validator engine for ensuring the correctness // of the request. Gin provides a default implementation for this using diff --git a/binding/form_mapping.go b/binding/form_mapping.go index c24dd553..44eb0071 100644 --- a/binding/form_mapping.go +++ b/binding/form_mapping.go @@ -5,6 +5,7 @@ package binding import ( + "encoding" "errors" "fmt" "reflect" @@ -175,11 +176,17 @@ func setByForm(value reflect.Value, field reflect.StructField, form map[string][ if !ok { vs = []string{opt.defaultValue} } + if ok, err := trySetCustom(vs[0], value); ok || err != nil { + return ok, err + } return true, setSlice(vs, value, field) case reflect.Array: if !ok { vs = []string{opt.defaultValue} } + if ok, err := trySetCustom(vs[0], value); ok || err != nil { + return ok, err + } if len(vs) != value.Len() { return false, fmt.Errorf("%q is not valid value for %s", vs, value.Type().String()) } @@ -193,10 +200,26 @@ func setByForm(value reflect.Value, field reflect.StructField, form map[string][ if len(vs) > 0 { val = vs[0] } + if ok, err := trySetCustom(val, value); ok || err != nil { + return ok, err + } return true, setWithProperType(val, value, field) } } +func trySetCustom(val string, value reflect.Value) (isSet bool, err error) { + switch v := value.Addr().Interface().(type) { + case encoding.TextUnmarshaler: + if value.Kind() != reflect.Struct { + return true, v.UnmarshalText([]byte(val)) + } + case BindUnmarshaler: + return true, v.UnmarshalParam(val) + } + + return false, nil +} + func setWithProperType(val string, value reflect.Value, field reflect.StructField) error { switch value.Kind() { case reflect.Int: diff --git a/binding/form_mapping_test.go b/binding/form_mapping_test.go index 78f4df0e..e6e086c3 100644 --- a/binding/form_mapping_test.go +++ b/binding/form_mapping_test.go @@ -5,7 +5,10 @@ package binding import ( + "fmt" "reflect" + "strconv" + "strings" "testing" "time" @@ -288,3 +291,110 @@ func TestMappingIgnoredCircularRef(t *testing.T) { err := mappingByPtr(&s, formSource{}, "form") assert.NoError(t, err) } + +type foohex int + +func (f *foohex) UnmarshalText(text []byte) error { + v, err := strconv.ParseInt(string(text), 16, 64) + if err != nil { + return err + } + *f = foohex(v) + return nil +} + +func TestMappingCustomFieldType(t *testing.T) { + var s struct { + Foo foohex `form:"foo"` + } + err := mappingByPtr(&s, formSource{"foo": {`f5`}}, "form") + assert.NoError(t, err) + + assert.EqualValues(t, 245, s.Foo) +} + +func TestMappingCustomFieldTypeWithURI(t *testing.T) { + var s struct { + Foo foohex `uri:"foo"` + } + err := mappingByPtr(&s, formSource{"foo": {`f5`}}, "uri") + assert.NoError(t, err) + + assert.EqualValues(t, 245, s.Foo) +} + +type customType struct { + Protocol string + Path string + Name string +} + +func (f *customType) UnmarshalParam(param string) error { + parts := strings.Split(param, ":") + if len(parts) != 3 { + return fmt.Errorf("invalid format") + } + f.Protocol = parts[0] + f.Path = parts[1] + f.Name = parts[2] + return nil +} + +func TestMappingCustomStructType(t *testing.T) { + var s struct { + FileData customType `form:"data"` + } + err := mappingByPtr(&s, formSource{"data": {`file:/foo:happiness`}}, "form") + assert.NoError(t, err) + + assert.EqualValues(t, "file", s.FileData.Protocol) + assert.EqualValues(t, "/foo", s.FileData.Path) + assert.EqualValues(t, "happiness", s.FileData.Name) +} + +func TestMappingCustomPointerStructType(t *testing.T) { + var s struct { + FileData *customType `form:"data"` + } + err := mappingByPtr(&s, formSource{"data": {`file:/foo:happiness`}}, "form") + assert.NoError(t, err) + + assert.EqualValues(t, "file", s.FileData.Protocol) + assert.EqualValues(t, "/foo", s.FileData.Path) + assert.EqualValues(t, "happiness", s.FileData.Name) +} + +type MySlice []string + +func (s *MySlice) UnmarshalParam(param string) error { + *s = MySlice(strings.Split(param, ",")) + return nil +} + +func TestMappingCustomSliceType(t *testing.T) { + var s struct { + Permissions MySlice `form:"permissions"` + } + err := mappingByPtr(&s, formSource{"permissions": {"read,write,delete"}}, "form") + assert.NoError(t, err) + + assert.EqualValues(t, []string{"read", "write", "delete"}, s.Permissions) +} + +type MyArray [3]string + +func (s *MyArray) UnmarshalParam(param string) error { + parts := strings.Split(param, ",") + *s = MyArray([3]string{parts[0], parts[1], parts[2]}) + return nil +} + +func TestMappingCustomArrayType(t *testing.T) { + var s struct { + Permissions MyArray `form:"permissions"` + } + err := mappingByPtr(&s, formSource{"permissions": {"read,write,delete"}}, "form") + assert.NoError(t, err) + + assert.EqualValues(t, [3]string{"read", "write", "delete"}, s.Permissions) +}