From cade018a7b00fecb6157488f2c413d8f79b8de78 Mon Sep 17 00:00:00 2001 From: Alex Guerra Date: Thu, 25 Oct 2018 14:13:15 -0500 Subject: [PATCH] Implement io.ReaderFrom on responseWriter The private ResponseWriter implementation in the stdlib implements this interface to enable some optimizations around sendfile. The interface is asserted for in io.Copy, probably the most common way to serve files, so without this method we are missing out on those optimizations. As we're only interested in writing the headers and capturing the number of bytes written, we can just call io.Copy on the wrapped http.ResponseWriter and it will do the right thing, calling ReadFrom if it's implemented and doing a regular copy if not. --- response_writer.go | 9 +++++++++ response_writer_test.go | 23 +++++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/response_writer.go b/response_writer.go index 923b53f8..680bc1e2 100644 --- a/response_writer.go +++ b/response_writer.go @@ -21,6 +21,7 @@ type responseWriterBase interface { http.Hijacker http.Flusher http.CloseNotifier + io.ReaderFrom // Returns the HTTP response status code of the current request. Status() int @@ -113,3 +114,11 @@ func (w *responseWriter) Flush() { w.WriteHeaderNow() w.ResponseWriter.(http.Flusher).Flush() } + +// ReadFrom implements the io.ReaderFrom interface. +func (w *responseWriter) ReadFrom(src io.Reader) (n int64, err error) { + w.WriteHeaderNow() + n, err = io.Copy(w.ResponseWriter, src) + w.size += int(n) + return +} diff --git a/response_writer_test.go b/response_writer_test.go index b8c5c885..caa64064 100644 --- a/response_writer_test.go +++ b/response_writer_test.go @@ -5,8 +5,10 @@ package gin import ( + "io" "net/http" "net/http/httptest" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -23,6 +25,7 @@ var _ http.ResponseWriter = ResponseWriter(&responseWriter{}) var _ http.Hijacker = ResponseWriter(&responseWriter{}) var _ http.Flusher = ResponseWriter(&responseWriter{}) var _ http.CloseNotifier = ResponseWriter(&responseWriter{}) +var _ io.ReaderFrom = ResponseWriter(&responseWriter{}) func init() { SetMode(TestMode) @@ -129,3 +132,23 @@ func TestResponseWriterFlush(t *testing.T) { assert.NoError(t, err) assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) } + +func TestResponseWriterReadFrom(t *testing.T) { + testWriter := httptest.NewRecorder() + writer := &responseWriter{} + writer.reset(testWriter) + w := ResponseWriter(writer) + + n, err := io.Copy(w, strings.NewReader("hola")) + assert.Equal(t, int64(4), n) + assert.Equal(t, 4, w.Size()) + assert.Equal(t, http.StatusOK, w.Status()) + assert.Equal(t, http.StatusOK, testWriter.Code) + assert.Equal(t, "hola", testWriter.Body.String()) + assert.NoError(t, err) + n, err = writer.ReadFrom(strings.NewReader(" adios")) + assert.Equal(t, int64(6), n) + assert.Equal(t, 10, w.Size()) + assert.Equal(t, testWriter.Body.String(), "hola adios") + assert.NoError(t, err) +}