diff --git a/README.md b/README.md index 05b696df..ccc5bcc5 100644 --- a/README.md +++ b/README.md @@ -275,8 +275,11 @@ func main() { // single file file, _ := c.FormFile("file") log.Println(file.Filename) - - c.String(http.StatusOK, fmt.Printf("'%s' uploaded!", file.Filename)) + + // Upload the file to specific dst. + // c.SaveUploadedFile(file, dst) + + c.String(http.StatusOK, fmt.Sprintf("'%s' uploaded!", file.Filename)) }) router.Run(":8080") } @@ -304,8 +307,11 @@ func main() { for _, file := range files { log.Println(file.Filename) + + // Upload the file to specific dst. + // c.SaveUploadedFile(file, dst) } - c.String(http.StatusOK, fmt.Printf("%d files uploaded!", len(files))) + c.String(http.StatusOK, fmt.Sprintf("%d files uploaded!", len(files))) }) router.Run(":8080") } diff --git a/context.go b/context.go index 198dd3e3..f29464d7 100644 --- a/context.go +++ b/context.go @@ -13,6 +13,7 @@ import ( "net" "net/http" "net/url" + "os" "strings" "time" @@ -431,6 +432,24 @@ func (c *Context) MultipartForm() (*multipart.Form, error) { return c.Request.MultipartForm, err } +// SaveUploadedFile uploads the form file to specific dst. +func (c *Context) SaveUploadedFile(file *multipart.FileHeader, dst string) error { + src, err := file.Open() + if err != nil { + return err + } + defer src.Close() + + out, err := os.Create(dst) + if err != nil { + return err + } + defer out.Close() + + io.Copy(out, src) + return nil +} + // Bind checks the Content-Type to select a binding engine automatically, // Depending the "Content-Type" header different bindings are used: // "application/json" --> JSON binding diff --git a/context_test.go b/context_test.go index 758fecdc..db960fb3 100644 --- a/context_test.go +++ b/context_test.go @@ -72,12 +72,18 @@ func TestContextFormFile(t *testing.T) { if assert.NoError(t, err) { assert.Equal(t, "test", f.Filename) } + + assert.NoError(t, c.SaveUploadedFile(f, "test")) } func TestContextMultipartForm(t *testing.T) { buf := new(bytes.Buffer) mw := multipart.NewWriter(buf) mw.WriteField("foo", "bar") + w, err := mw.CreateFormFile("file", "test") + if assert.NoError(t, err) { + w.Write([]byte("test")) + } mw.Close() c, _ := CreateTestContext(httptest.NewRecorder()) c.Request, _ = http.NewRequest("POST", "/", buf) @@ -86,6 +92,42 @@ func TestContextMultipartForm(t *testing.T) { if assert.NoError(t, err) { assert.NotNil(t, f) } + + assert.NoError(t, c.SaveUploadedFile(f.File["file"][0], "test")) +} + +func TestSaveUploadedOpenFailed(t *testing.T) { + buf := new(bytes.Buffer) + mw := multipart.NewWriter(buf) + mw.Close() + + c, _ := CreateTestContext(httptest.NewRecorder()) + c.Request, _ = http.NewRequest("POST", "/", buf) + c.Request.Header.Set("Content-Type", mw.FormDataContentType()) + + f := &multipart.FileHeader{ + Filename: "file", + } + assert.Error(t, c.SaveUploadedFile(f, "test")) +} + +func TestSaveUploadedCreateFailed(t *testing.T) { + buf := new(bytes.Buffer) + mw := multipart.NewWriter(buf) + w, err := mw.CreateFormFile("file", "test") + if assert.NoError(t, err) { + w.Write([]byte("test")) + } + mw.Close() + c, _ := CreateTestContext(httptest.NewRecorder()) + c.Request, _ = http.NewRequest("POST", "/", buf) + c.Request.Header.Set("Content-Type", mw.FormDataContentType()) + f, err := c.FormFile("file") + if assert.NoError(t, err) { + assert.Equal(t, "test", f.Filename) + } + + assert.Error(t, c.SaveUploadedFile(f, "/")) } func TestContextReset(t *testing.T) { diff --git a/examples/http2/main.go b/examples/http2/main.go index 19e65f84..07df01e2 100644 --- a/examples/http2/main.go +++ b/examples/http2/main.go @@ -2,6 +2,8 @@ package main import ( "html/template" + "log" + "os" "github.com/gin-gonic/gin" ) @@ -18,6 +20,9 @@ var html = template.Must(template.New("https").Parse(` `)) func main() { + logger := log.New(os.Stderr, "", 0) + logger.Println("[WARNING] DON'T USE THE EMBED CERTS FROM THIS EXAMPLE IN PRODUCTION ENVIRONMENT, GENERATE YOUR OWN!") + r := gin.Default() r.SetHTMLTemplate(html) diff --git a/examples/upload-file/multiple/main.go b/examples/upload-file/multiple/main.go index 22588348..4bb4cdcb 100644 --- a/examples/upload-file/multiple/main.go +++ b/examples/upload-file/multiple/main.go @@ -2,9 +2,7 @@ package main import ( "fmt" - "io" "net/http" - "os" "github.com/gin-gonic/gin" ) @@ -25,24 +23,10 @@ func main() { files := form.File["files"] for _, file := range files { - // Source - src, err := file.Open() - if err != nil { - c.String(http.StatusBadRequest, fmt.Sprintf("file open err: %s", err.Error())) + if err := c.SaveUploadedFile(file, file.Filename); err != nil { + c.String(http.StatusBadRequest, fmt.Sprintf("upload file err: %s", err.Error())) return } - defer src.Close() - - // Destination - dst, err := os.Create(file.Filename) - if err != nil { - c.String(http.StatusBadRequest, fmt.Sprintf("Create file err: %s", err.Error())) - return - } - defer dst.Close() - - // Copy - io.Copy(dst, src) } c.String(http.StatusOK, fmt.Sprintf("Uploaded successfully %d files with fields name=%s and email=%s.", len(files), name, email)) diff --git a/examples/upload-file/single/main.go b/examples/upload-file/single/main.go index 1e9596cb..372a2994 100644 --- a/examples/upload-file/single/main.go +++ b/examples/upload-file/single/main.go @@ -2,9 +2,7 @@ package main import ( "fmt" - "io" "net/http" - "os" "github.com/gin-gonic/gin" ) @@ -22,23 +20,11 @@ func main() { c.String(http.StatusBadRequest, fmt.Sprintf("get form err: %s", err.Error())) return } - src, err := file.Open() - if err != nil { - c.String(http.StatusBadRequest, fmt.Sprintf("file open err: %s", err.Error())) + + if err := c.SaveUploadedFile(file, file.Filename); err != nil { + c.String(http.StatusBadRequest, fmt.Sprintf("upload file err: %s", err.Error())) return } - defer src.Close() - - // Destination - dst, err := os.Create(file.Filename) - if err != nil { - c.String(http.StatusBadRequest, fmt.Sprintf("Create file err: %s", err.Error())) - return - } - defer dst.Close() - - // Copy - io.Copy(dst, src) c.String(http.StatusOK, fmt.Sprintf("File %s uploaded successfully with fields name=%s and email=%s.", file.Filename, name, email)) }) diff --git a/vendor/vendor.json b/vendor/vendor.json index 3af3cb55..29141c4c 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -34,10 +34,10 @@ "revisionTime": "2017-06-01T23:02:30Z" }, { - "checksumSHA1": "bgb/lk2wroBJ5z+JI2xnVj7WkwI=", + "checksumSHA1": "WFJPa8cL6nzQU3yA1iN+gmaqrSU=", "path": "github.com/json-iterator/go", - "revision": "845d8438db34cc782608bbee7647a522b4e87de0", - "revisionTime": "2017-07-10T17:07:18Z" + "revision": "4b33139ad07fda872cb378bb4218b2fab74ce62b", + "revisionTime": "2017-07-12T09:56:51Z" }, { "checksumSHA1": "9if9IBLsxkarJ804NPWAzgskIAk=",