From edd2f42e35286ab4aae431f08d690d90b63cbaf3 Mon Sep 17 00:00:00 2001 From: 581 Date: Fri, 2 Mar 2018 13:01:02 +0800 Subject: [PATCH] update for supporting file binding --- binding/binding.go | 2 + binding/binding_test.go | 88 ++++++++++++++++++++++++++++++++++++++++- binding/form.go | 5 +++ binding/form_mapping.go | 30 ++++++++++++++ 4 files changed, 123 insertions(+), 2 deletions(-) diff --git a/binding/binding.go b/binding/binding.go index dc32d538..7e65c7f7 100644 --- a/binding/binding.go +++ b/binding/binding.go @@ -69,6 +69,8 @@ func Default(method, contentType string) Binding { return ProtoBuf case MIMEMSGPACK, MIMEMSGPACK2: return MsgPack + case MIMEMultipartPOSTForm: + return FormMultipart default: //case MIMEPOSTForm, MIMEMultipartPOSTForm: return Form } diff --git a/binding/binding_test.go b/binding/binding_test.go index 0c0f9f81..c6ad0c77 100644 --- a/binding/binding_test.go +++ b/binding/binding_test.go @@ -8,9 +8,11 @@ import ( "bytes" "encoding/json" "errors" + "io" "io/ioutil" "mime/multipart" "net/http" + "os" "testing" "time" @@ -29,6 +31,18 @@ type FooBarStruct struct { Bar string `msgpack:"bar" json:"bar" form:"bar" xml:"bar" binding:"required"` } +type FooBarFileStruct struct { + FooBarStruct + File *multipart.FileHeader `form:"file" binding:"required"` +} + +type FooBarFileFailStruct struct { + FooBarStruct + File *multipart.FileHeader `invalid_name:"file" binding:"required"` + // for unexport test + data *multipart.FileHeader `form:"data" binding:"required"` +} + type FooStructUseNumber struct { Foo interface{} `json:"foo" binding:"required"` } @@ -152,8 +166,8 @@ func TestBindingDefault(t *testing.T) { assert.Equal(t, Default("POST", MIMEPOSTForm), Form) assert.Equal(t, Default("PUT", MIMEPOSTForm), Form) - assert.Equal(t, Default("POST", MIMEMultipartPOSTForm), Form) - assert.Equal(t, Default("PUT", MIMEMultipartPOSTForm), Form) + assert.Equal(t, Default("POST", MIMEMultipartPOSTForm), FormMultipart) + assert.Equal(t, Default("PUT", MIMEMultipartPOSTForm), FormMultipart) assert.Equal(t, Default("POST", MIMEPROTOBUF), ProtoBuf) assert.Equal(t, Default("PUT", MIMEPROTOBUF), ProtoBuf) @@ -413,6 +427,48 @@ func createFormPostRequestFail() *http.Request { return req } +func createFormFilesMultipartRequest() *http.Request { + boundary := "--testboundary" + body := new(bytes.Buffer) + mw := multipart.NewWriter(body) + defer mw.Close() + + mw.SetBoundary(boundary) + mw.WriteField("foo", "bar") + mw.WriteField("bar", "foo") + + f, _ := os.Open("form.go") + defer f.Close() + fw, _ := mw.CreateFormFile("file", "form.go") + io.Copy(fw, f) + + req, _ := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", body) + req.Header.Set("Content-Type", MIMEMultipartPOSTForm+"; boundary="+boundary) + + return req +} + +func createFormFilesMultipartRequestFail() *http.Request { + boundary := "--testboundary" + body := new(bytes.Buffer) + mw := multipart.NewWriter(body) + defer mw.Close() + + mw.SetBoundary(boundary) + mw.WriteField("foo", "bar") + mw.WriteField("bar", "foo") + + f, _ := os.Open("form.go") + defer f.Close() + fw, _ := mw.CreateFormFile("file_foo", "form_foo.go") + io.Copy(fw, f) + + req, _ := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", body) + req.Header.Set("Content-Type", MIMEMultipartPOSTForm+"; boundary="+boundary) + + return req +} + func createFormMultipartRequest() *http.Request { boundary := "--testboundary" body := new(bytes.Buffer) @@ -457,6 +513,34 @@ func TestBindingFormPostFail(t *testing.T) { assert.Error(t, err) } +func TestBindingFormFilesMultipart(t *testing.T) { + req := createFormFilesMultipartRequest() + var obj FooBarFileStruct + FormMultipart.Bind(req, &obj) + + // file from os + f, _ := os.Open("form.go") + defer f.Close() + fileActual, _ := ioutil.ReadAll(f) + + // file from multipart + mf, _ := obj.File.Open() + defer mf.Close() + fileExpect, _ := ioutil.ReadAll(mf) + + assert.Equal(t, FormMultipart.Name(), "multipart/form-data") + assert.Equal(t, obj.Foo, "bar") + assert.Equal(t, obj.Bar, "foo") + assert.Equal(t, fileExpect, fileActual) +} + +func TestBindingFormFilesMultipartFail(t *testing.T) { + req := createFormFilesMultipartRequestFail() + var obj FooBarFileFailStruct + err := FormMultipart.Bind(req, &obj) + assert.Error(t, err) +} + func TestBindingFormMultipart(t *testing.T) { req := createFormMultipartRequest() var obj FooBarStruct diff --git a/binding/form.go b/binding/form.go index 0be59660..b26e036b 100644 --- a/binding/form.go +++ b/binding/form.go @@ -52,5 +52,10 @@ func (formMultipartBinding) Bind(req *http.Request, obj interface{}) error { if err := mapForm(obj, req.MultipartForm.Value); err != nil { return err } + + if err := mapFiles(obj, req); err != nil { + return err + } + return validate(obj) } diff --git a/binding/form_mapping.go b/binding/form_mapping.go index dd8c6246..ae57edff 100644 --- a/binding/form_mapping.go +++ b/binding/form_mapping.go @@ -6,11 +6,41 @@ package binding import ( "errors" + "fmt" + "net/http" "reflect" "strconv" "time" ) +func mapFiles(ptr interface{}, req *http.Request) error { + typ := reflect.TypeOf(ptr).Elem() + val := reflect.ValueOf(ptr).Elem() + for i := 0; i < typ.NumField(); i++ { + typeField := typ.Field(i) + structField := val.Field(i) + + t := fmt.Sprintf("%s", typeField.Type) + if string(t) != "*multipart.FileHeader" { + continue + } + + inputFieldName := typeField.Tag.Get("form") + if inputFieldName == "" { + inputFieldName = typeField.Name + } + + _, fileHeader, err := req.FormFile(inputFieldName) + if err != nil { + return err + } + + structField.Set(reflect.ValueOf(fileHeader)) + + } + return nil +} + func mapForm(ptr interface{}, form map[string][]string) error { typ := reflect.TypeOf(ptr).Elem() val := reflect.ValueOf(ptr).Elem()