Implement io.ReaderFrom on responseWriter

The private ResponseWriter implementation in the stdlib implements this
interface to enable some optimizations around sendfile. The interface is
asserted for in io.Copy, probably the most common way to serve files, so
without this method we are missing out on those optimizations. As we're
only interested in writing the headers and capturing the number of bytes
written, we can just call io.Copy on the wrapped http.ResponseWriter and
it will do the right thing, calling ReadFrom if it's implemented and
doing a regular copy if not.
This commit is contained in:
Alex Guerra 2018-10-25 14:13:15 -05:00
parent 37a58e1db0
commit cade018a7b
2 changed files with 32 additions and 0 deletions

View File

@ -21,6 +21,7 @@ type responseWriterBase interface {
http.Hijacker
http.Flusher
http.CloseNotifier
io.ReaderFrom
// Returns the HTTP response status code of the current request.
Status() int
@ -113,3 +114,11 @@ func (w *responseWriter) Flush() {
w.WriteHeaderNow()
w.ResponseWriter.(http.Flusher).Flush()
}
// ReadFrom implements the io.ReaderFrom interface.
func (w *responseWriter) ReadFrom(src io.Reader) (n int64, err error) {
w.WriteHeaderNow()
n, err = io.Copy(w.ResponseWriter, src)
w.size += int(n)
return
}

View File

@ -5,8 +5,10 @@
package gin
import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/assert"
@ -23,6 +25,7 @@ var _ http.ResponseWriter = ResponseWriter(&responseWriter{})
var _ http.Hijacker = ResponseWriter(&responseWriter{})
var _ http.Flusher = ResponseWriter(&responseWriter{})
var _ http.CloseNotifier = ResponseWriter(&responseWriter{})
var _ io.ReaderFrom = ResponseWriter(&responseWriter{})
func init() {
SetMode(TestMode)
@ -129,3 +132,23 @@ func TestResponseWriterFlush(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
}
func TestResponseWriterReadFrom(t *testing.T) {
testWriter := httptest.NewRecorder()
writer := &responseWriter{}
writer.reset(testWriter)
w := ResponseWriter(writer)
n, err := io.Copy(w, strings.NewReader("hola"))
assert.Equal(t, int64(4), n)
assert.Equal(t, 4, w.Size())
assert.Equal(t, http.StatusOK, w.Status())
assert.Equal(t, http.StatusOK, testWriter.Code)
assert.Equal(t, "hola", testWriter.Body.String())
assert.NoError(t, err)
n, err = writer.ReadFrom(strings.NewReader(" adios"))
assert.Equal(t, int64(6), n)
assert.Equal(t, 10, w.Size())
assert.Equal(t, testWriter.Body.String(), "hola adios")
assert.NoError(t, err)
}