From 7fb5d3b261e003713a3434051484f5beb2b10610 Mon Sep 17 00:00:00 2001 From: Alex Guerra Date: Tue, 14 Jun 2016 20:36:17 -0500 Subject: [PATCH] Implement io.ReaderFrom on responseWriter The private ResponseWriter implementation in the stdlib implements this interface to enable some optimizations around sendfile. The interface is asserted for in io.Copy, probably the most common way to serve files, so without this method we are missing out on those optimizations. As we're only interested in writing the headers and capturing the number of bytes written, we can just call io.Copy on the wrapped http.ResponseWriter and it will do the right thing, calling ReadFrom if it's implemented and doing a regular copy if not. --- response_writer.go | 9 +++++++++ response_writer_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/response_writer.go b/response_writer.go index fcbe230d..99176424 100644 --- a/response_writer.go +++ b/response_writer.go @@ -22,6 +22,7 @@ type ( http.Hijacker http.Flusher http.CloseNotifier + io.ReaderFrom // Returns the HTTP response status code of the current request. Status() int @@ -114,3 +115,11 @@ func (w *responseWriter) CloseNotify() <-chan bool { func (w *responseWriter) Flush() { w.ResponseWriter.(http.Flusher).Flush() } + +// Implements the io.ReaderFrom interface +func (w *responseWriter) ReadFrom(src io.Reader) (n int64, err error) { + w.WriteHeaderNow() + n, err = io.Copy(w.ResponseWriter, src) + w.size += int(n) + return +} diff --git a/response_writer_test.go b/response_writer_test.go index 14ff3a89..3b98691c 100644 --- a/response_writer_test.go +++ b/response_writer_test.go @@ -5,8 +5,10 @@ package gin import ( + "io" "net/http" "net/http/httptest" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -23,6 +25,7 @@ var _ http.ResponseWriter = ResponseWriter(&responseWriter{}) var _ http.Hijacker = ResponseWriter(&responseWriter{}) var _ http.Flusher = ResponseWriter(&responseWriter{}) var _ http.CloseNotifier = ResponseWriter(&responseWriter{}) +var _ io.ReaderFrom = ResponseWriter(&responseWriter{}) func init() { SetMode(TestMode) @@ -113,3 +116,24 @@ func TestResponseWriterHijack(t *testing.T) { w.Flush() } + +func TestResponseWriterReadFrom(t *testing.T) { + testWriter := httptest.NewRecorder() + writer := &responseWriter{} + writer.reset(testWriter) + w := ResponseWriter(writer) + + n, err := io.Copy(w, strings.NewReader("hola")) + assert.Equal(t, n, int64(4)) + assert.Equal(t, w.Size(), 4) + assert.Equal(t, w.Status(), 200) + assert.Equal(t, testWriter.Code, 200) + assert.Equal(t, testWriter.Body.String(), "hola") + assert.NoError(t, err) + + n, err = writer.ReadFrom(strings.NewReader(" adios")) + assert.Equal(t, n, int64(6)) + assert.Equal(t, w.Size(), 10) + assert.Equal(t, testWriter.Body.String(), "hola adios") + assert.NoError(t, err) +}