From fdb1c440e5edc88881dba0675bd320b7ce6ad6ab Mon Sep 17 00:00:00 2001 From: Aaron Janse Date: Mon, 13 Jun 2022 14:20:02 -0400 Subject: [PATCH] implement io.ReaderFrom for gin.ResponseWriter This causes io.Copy, which calls io.copyBuffer, to automatically use the sendfile syscall by calling the underlying http.ResponseWriter's ReadFrom implementation [1]. This can be a significant performance improvement. [1] https://github.com/golang/go/blob/7eeec1f6e4/src/io/io.go#L410-L413 Co-authored-by: Alex Guerra --- response_writer.go | 10 ++++++++++ response_writer_test.go | 22 ++++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/response_writer.go b/response_writer.go index 77c7ed8f..6e6a2a8e 100644 --- a/response_writer.go +++ b/response_writer.go @@ -22,6 +22,7 @@ type ResponseWriter interface { http.Hijacker http.Flusher http.CloseNotifier + io.ReaderFrom // Status returns the HTTP response status code of the current request. Status() int @@ -87,6 +88,15 @@ func (w *responseWriter) WriteString(s string) (n int, err error) { return } +// ReadFrom implements the io.ReaderFrom interface, allowing Go to automatically +// use the sendfile syscall in methods such as http.ServeFile +func (w *responseWriter) ReadFrom(src io.Reader) (n int64, err error) { + w.WriteHeaderNow() + n, err = io.Copy(w.ResponseWriter, src) + w.size += int(n) + return +} + func (w *responseWriter) Status() int { return w.status } diff --git a/response_writer_test.go b/response_writer_test.go index 57d163c9..e97f403c 100644 --- a/response_writer_test.go +++ b/response_writer_test.go @@ -5,8 +5,10 @@ package gin import ( + "io" "net/http" "net/http/httptest" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -132,3 +134,23 @@ func TestResponseWriterFlush(t *testing.T) { assert.NoError(t, err) assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) } + +func TestResponseWriterReadFrom(t *testing.T) { + testWriter := httptest.NewRecorder() + writer := &responseWriter{} + writer.reset(testWriter) + w := ResponseWriter(writer) + + n, err := io.Copy(w, strings.NewReader("hola")) + assert.Equal(t, int64(4), n) + assert.Equal(t, 4, w.Size()) + assert.Equal(t, http.StatusOK, w.Status()) + assert.Equal(t, http.StatusOK, testWriter.Code) + assert.Equal(t, "hola", testWriter.Body.String()) + assert.NoError(t, err) + n, err = writer.ReadFrom(strings.NewReader(" adios")) + assert.Equal(t, int64(6), n) + assert.Equal(t, 10, w.Size()) + assert.Equal(t, testWriter.Body.String(), "hola adios") + assert.NoError(t, err) +}