diff --git a/response_writer.go b/response_writer.go index 753a0b09..ab2f5fec 100644 --- a/response_writer.go +++ b/response_writer.go @@ -6,6 +6,7 @@ package gin import ( "bufio" + "errors" "io" "net" "net/http" @@ -16,6 +17,8 @@ const ( defaultStatus = http.StatusOK ) +var errHijackAlreadyWritten = errors.New("gin: response already written") + // ResponseWriter ... type ResponseWriter interface { http.ResponseWriter @@ -106,6 +109,9 @@ func (w *responseWriter) Written() bool { // Hijack implements the http.Hijacker interface. func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if w.Written() { + return nil, nil, errHijackAlreadyWritten + } if w.size < 0 { w.size = 0 } diff --git a/response_writer_test.go b/response_writer_test.go index 259b8fa8..ef198418 100644 --- a/response_writer_test.go +++ b/response_writer_test.go @@ -5,6 +5,8 @@ package gin import ( + "bufio" + "net" "net/http" "net/http/httptest" "testing" @@ -124,6 +126,74 @@ func TestResponseWriterHijack(t *testing.T) { w.Flush() } +type mockHijacker struct { + *httptest.ResponseRecorder + hijacked bool +} + +// Hijack implements the http.Hijacker interface. It just records that it was called. +func (m *mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { + m.hijacked = true + return nil, nil, nil +} + +func TestResponseWriterHijackAfterWrite(t *testing.T) { + tests := []struct { + name string + action func(w ResponseWriter) error // Action to perform before hijacking + expectWrittenBeforeHijack bool + expectHijackSuccess bool + expectWrittenAfterHijack bool + expectError error + }{ + { + name: "hijack before write should succeed", + action: func(w ResponseWriter) error { return nil }, + expectWrittenBeforeHijack: false, + expectHijackSuccess: true, + expectWrittenAfterHijack: true, // Hijack itself marks the writer as written + expectError: nil, + }, + { + name: "hijack after write should fail", + action: func(w ResponseWriter) error { + _, err := w.Write([]byte("test")) + return err + }, + expectWrittenBeforeHijack: true, + expectHijackSuccess: false, + expectWrittenAfterHijack: true, + expectError: errHijackAlreadyWritten, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + hijacker := &mockHijacker{ResponseRecorder: httptest.NewRecorder()} + writer := &responseWriter{} + writer.reset(hijacker) + w := ResponseWriter(writer) + + // Check initial state + assert.False(t, w.Written(), "should not be written initially") + + // Perform pre-hijack action + require.NoError(t, tc.action(w), "unexpected error during pre-hijack action") + + // Check state before hijacking + assert.Equal(t, tc.expectWrittenBeforeHijack, w.Written(), "unexpected w.Written() state before hijack") + + // Attempt to hijack + _, _, hijackErr := w.Hijack() + + // Check results + require.ErrorIs(t, hijackErr, tc.expectError, "unexpected error from Hijack()") + assert.Equal(t, tc.expectHijackSuccess, hijacker.hijacked, "unexpected hijacker.hijacked state") + assert.Equal(t, tc.expectWrittenAfterHijack, w.Written(), "unexpected w.Written() state after hijack") + }) + } +} + func TestResponseWriterFlush(t *testing.T) { testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { writer := &responseWriter{}