Implement io.ReaderFrom on responseWriter

The private ResponseWriter implementation in the stdlib implements this
interface to enable some optimizations around sendfile. The interface is
asserted for in io.Copy, probably the most common way to serve files, so
without this method we are missing out on those optimizations. As we're
only interested in writing the headers and capturing the number of bytes
written, we can just call io.Copy on the wrapped http.ResponseWriter and
it will do the right thing, calling ReadFrom if it's implemented and
doing a regular copy if not.
This commit is contained in:
Alex Guerra 2016-06-14 20:36:17 -05:00
parent e66f3b5a53
commit 7fb5d3b261
2 changed files with 33 additions and 0 deletions

View File

@ -22,6 +22,7 @@ type (
http.Hijacker
http.Flusher
http.CloseNotifier
io.ReaderFrom
// Returns the HTTP response status code of the current request.
Status() int
@ -114,3 +115,11 @@ func (w *responseWriter) CloseNotify() <-chan bool {
func (w *responseWriter) Flush() {
w.ResponseWriter.(http.Flusher).Flush()
}
// Implements the io.ReaderFrom interface
func (w *responseWriter) ReadFrom(src io.Reader) (n int64, err error) {
w.WriteHeaderNow()
n, err = io.Copy(w.ResponseWriter, src)
w.size += int(n)
return
}

View File

@ -5,8 +5,10 @@
package gin
import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/assert"
@ -23,6 +25,7 @@ var _ http.ResponseWriter = ResponseWriter(&responseWriter{})
var _ http.Hijacker = ResponseWriter(&responseWriter{})
var _ http.Flusher = ResponseWriter(&responseWriter{})
var _ http.CloseNotifier = ResponseWriter(&responseWriter{})
var _ io.ReaderFrom = ResponseWriter(&responseWriter{})
func init() {
SetMode(TestMode)
@ -113,3 +116,24 @@ func TestResponseWriterHijack(t *testing.T) {
w.Flush()
}
func TestResponseWriterReadFrom(t *testing.T) {
testWriter := httptest.NewRecorder()
writer := &responseWriter{}
writer.reset(testWriter)
w := ResponseWriter(writer)
n, err := io.Copy(w, strings.NewReader("hola"))
assert.Equal(t, n, int64(4))
assert.Equal(t, w.Size(), 4)
assert.Equal(t, w.Status(), 200)
assert.Equal(t, testWriter.Code, 200)
assert.Equal(t, testWriter.Body.String(), "hola")
assert.NoError(t, err)
n, err = writer.ReadFrom(strings.NewReader(" adios"))
assert.Equal(t, n, int64(6))
assert.Equal(t, w.Size(), 10)
assert.Equal(t, testWriter.Body.String(), "hola adios")
assert.NoError(t, err)
}