diff --git a/binding/binding.go b/binding/binding.go index 26d71c9f..520c5109 100644 --- a/binding/binding.go +++ b/binding/binding.go @@ -98,7 +98,9 @@ func Default(method, contentType string) Binding { return MsgPack case MIMEYAML: return YAML - default: //case MIMEPOSTForm, MIMEMultipartPOSTForm: + case MIMEMultipartPOSTForm: + return FormMultipart + default: // case MIMEPOSTForm: return Form } } diff --git a/binding/binding_test.go b/binding/binding_test.go index b265af36..ee788225 100644 --- a/binding/binding_test.go +++ b/binding/binding_test.go @@ -8,9 +8,11 @@ import ( "bytes" "encoding/json" "errors" + "io" "io/ioutil" "mime/multipart" "net/http" + "os" "strconv" "strings" "testing" @@ -31,6 +33,18 @@ type FooBarStruct struct { Bar string `msgpack:"bar" json:"bar" form:"bar" xml:"bar" binding:"required"` } +type FooBarFileStruct struct { + FooBarStruct + File *multipart.FileHeader `form:"file" binding:"required"` +} + +type FooBarFileFailStruct struct { + FooBarStruct + File *multipart.FileHeader `invalid_name:"file" binding:"required"` + // for unexport test + data *multipart.FileHeader `form:"data" binding:"required"` +} + type FooDefaultBarStruct struct { FooStruct Bar string `msgpack:"bar" json:"bar" form:"bar,default=hello" xml:"bar" binding:"required"` @@ -187,8 +201,8 @@ func TestBindingDefault(t *testing.T) { assert.Equal(t, Form, Default("POST", MIMEPOSTForm)) assert.Equal(t, Form, Default("PUT", MIMEPOSTForm)) - assert.Equal(t, Form, Default("POST", MIMEMultipartPOSTForm)) - assert.Equal(t, Form, Default("PUT", MIMEMultipartPOSTForm)) + assert.Equal(t, FormMultipart, Default("POST", MIMEMultipartPOSTForm)) + assert.Equal(t, FormMultipart, Default("PUT", MIMEMultipartPOSTForm)) assert.Equal(t, ProtoBuf, Default("POST", MIMEPROTOBUF)) assert.Equal(t, ProtoBuf, Default("PUT", MIMEPROTOBUF)) @@ -536,6 +550,54 @@ func createFormPostRequestForMapFail(t *testing.T) *http.Request { return req } +func createFormFilesMultipartRequest(t *testing.T) *http.Request { + boundary := "--testboundary" + body := new(bytes.Buffer) + mw := multipart.NewWriter(body) + defer mw.Close() + + assert.NoError(t, mw.SetBoundary(boundary)) + assert.NoError(t, mw.WriteField("foo", "bar")) + assert.NoError(t, mw.WriteField("bar", "foo")) + + f, err := os.Open("form.go") + assert.NoError(t, err) + defer f.Close() + fw, err1 := mw.CreateFormFile("file", "form.go") + assert.NoError(t, err1) + io.Copy(fw, f) + + req, err2 := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", body) + assert.NoError(t, err2) + req.Header.Set("Content-Type", MIMEMultipartPOSTForm+"; boundary="+boundary) + + return req +} + +func createFormFilesMultipartRequestFail(t *testing.T) *http.Request { + boundary := "--testboundary" + body := new(bytes.Buffer) + mw := multipart.NewWriter(body) + defer mw.Close() + + assert.NoError(t, mw.SetBoundary(boundary)) + assert.NoError(t, mw.WriteField("foo", "bar")) + assert.NoError(t, mw.WriteField("bar", "foo")) + + f, err := os.Open("form.go") + assert.NoError(t, err) + defer f.Close() + fw, err1 := mw.CreateFormFile("file_foo", "form_foo.go") + assert.NoError(t, err1) + io.Copy(fw, f) + + req, err2 := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", body) + assert.NoError(t, err2) + req.Header.Set("Content-Type", MIMEMultipartPOSTForm+"; boundary="+boundary) + + return req +} + func createFormMultipartRequest(t *testing.T) *http.Request { boundary := "--testboundary" body := new(bytes.Buffer) @@ -613,6 +675,34 @@ func TestBindingFormPostForMapFail(t *testing.T) { assert.Error(t, err) } +func TestBindingFormFilesMultipart(t *testing.T) { + req := createFormFilesMultipartRequest(t) + var obj FooBarFileStruct + FormMultipart.Bind(req, &obj) + + // file from os + f, _ := os.Open("form.go") + defer f.Close() + fileActual, _ := ioutil.ReadAll(f) + + // file from multipart + mf, _ := obj.File.Open() + defer mf.Close() + fileExpect, _ := ioutil.ReadAll(mf) + + assert.Equal(t, FormMultipart.Name(), "multipart/form-data") + assert.Equal(t, obj.Foo, "bar") + assert.Equal(t, obj.Bar, "foo") + assert.Equal(t, fileExpect, fileActual) +} + +func TestBindingFormFilesMultipartFail(t *testing.T) { + req := createFormFilesMultipartRequestFail(t) + var obj FooBarFileFailStruct + err := FormMultipart.Bind(req, &obj) + assert.Error(t, err) +} + func TestBindingFormMultipart(t *testing.T) { req := createFormMultipartRequest(t) var obj FooBarStruct diff --git a/binding/form.go b/binding/form.go index 8955c95b..f1f89195 100644 --- a/binding/form.go +++ b/binding/form.go @@ -56,5 +56,10 @@ func (formMultipartBinding) Bind(req *http.Request, obj interface{}) error { if err := mapForm(obj, req.MultipartForm.Value); err != nil { return err } + + if err := mapFiles(obj, req); err != nil { + return err + } + return validate(obj) } diff --git a/binding/form_mapping.go b/binding/form_mapping.go index ba9d2c4f..fc33b1df 100644 --- a/binding/form_mapping.go +++ b/binding/form_mapping.go @@ -7,6 +7,7 @@ package binding import ( "errors" "fmt" + "net/http" "reflect" "strconv" "strings" @@ -15,6 +16,34 @@ import ( "github.com/gin-gonic/gin/internal/json" ) +func mapFiles(ptr interface{}, req *http.Request) error { + typ := reflect.TypeOf(ptr).Elem() + val := reflect.ValueOf(ptr).Elem() + for i := 0; i < typ.NumField(); i++ { + typeField := typ.Field(i) + structField := val.Field(i) + + t := fmt.Sprintf("%s", typeField.Type) + if string(t) != "*multipart.FileHeader" { + continue + } + + inputFieldName := typeField.Tag.Get("form") + if inputFieldName == "" { + inputFieldName = typeField.Name + } + + _, fileHeader, err := req.FormFile(inputFieldName) + if err != nil { + return err + } + + structField.Set(reflect.ValueOf(fileHeader)) + + } + return nil +} + var errUnknownType = errors.New("Unknown type") func mapUri(ptr interface{}, m map[string][]string) error {