diff --git a/response_writer.go b/response_writer.go index 9035e6f1..0479ae1b 100644 --- a/response_writer.go +++ b/response_writer.go @@ -17,7 +17,10 @@ const ( defaultStatus = http.StatusOK ) -var errHijackAlreadyWritten = errors.New("gin: response body already written") +var ( + errHijackAlreadyWritten = errors.New("gin: response body already written") + errHijackNotSupported = errors.New("gin: underlying ResponseWriter does not implement http.Hijacker") +) // ResponseWriter ... type ResponseWriter interface { @@ -117,12 +120,25 @@ func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { if w.size < 0 { w.size = 0 } - return w.ResponseWriter.(http.Hijacker).Hijack() + if hijacker, ok := w.ResponseWriter.(http.Hijacker); ok { + return hijacker.Hijack() + } + return nil, nil, errHijackNotSupported } // CloseNotify implements the http.CloseNotifier interface. +// +// If the underlying ResponseWriter doesn't implement http.CloseNotifier +// (e.g. httptest.NewRecorder), the returned channel will never fire. +// Use Request.Context().Done() to observe client disconnects instead. +// +// Deprecated: the CloseNotifier interface predates Go's context package. +// New code should use Request.Context instead. func (w *responseWriter) CloseNotify() <-chan bool { - return w.ResponseWriter.(http.CloseNotifier).CloseNotify() + if cn, ok := w.ResponseWriter.(http.CloseNotifier); ok { + return cn.CloseNotify() + } + return make(chan bool) } // Flush implements the http.Flusher interface. diff --git a/response_writer_test.go b/response_writer_test.go index dfc1d2c6..a58068f0 100644 --- a/response_writer_test.go +++ b/response_writer_test.go @@ -113,15 +113,12 @@ func TestResponseWriterHijack(t *testing.T) { writer.reset(testWriter) w := ResponseWriter(writer) - assert.Panics(t, func() { - _, _, err := w.Hijack() - require.NoError(t, err) - }) + _, _, err := w.Hijack() + require.ErrorIs(t, err, errHijackNotSupported) assert.True(t, w.Written()) - assert.Panics(t, func() { - w.CloseNotify() - }) + ch := w.CloseNotify() + assert.NotNil(t, ch) w.Flush() } @@ -315,3 +312,48 @@ func TestPusherWithoutPusher(t *testing.T) { pusher := w.Pusher() assert.Nil(t, pusher, "Expected pusher to be nil") } + +// mockCloseNotifier is an http.ResponseWriter that implements http.CloseNotifier. +type mockCloseNotifier struct { + *httptest.ResponseRecorder +} + +func (m *mockCloseNotifier) CloseNotify() <-chan bool { + return make(chan bool) +} + +func TestCloseNotifyWithCloseNotifier(t *testing.T) { + rw := &mockCloseNotifier{ResponseRecorder: httptest.NewRecorder()} + w := &responseWriter{} + w.reset(rw) + + ch := w.CloseNotify() + assert.NotNil(t, ch, "Expected CloseNotify channel to be non-nil") +} + +func TestCloseNotifyWithoutCloseNotifier(t *testing.T) { + // httptest.NewRecorder does not implement http.CloseNotifier + rw := httptest.NewRecorder() + w := &responseWriter{} + w.reset(rw) + + ch := w.CloseNotify() + assert.NotNil(t, ch, "Expected non-nil channel when CloseNotifier is not supported") + select { + case <-ch: + t.Fatal("channel should never fire when CloseNotifier is not supported") + default: + } +} + +func TestHijackWithoutHijacker(t *testing.T) { + // httptest.NewRecorder does not implement http.Hijacker + rw := httptest.NewRecorder() + w := &responseWriter{} + w.reset(rw) + + conn, buf, err := w.Hijack() + assert.Nil(t, conn) + assert.Nil(t, buf) + require.ErrorIs(t, err, errHijackNotSupported) +}