diff --git a/response_writer.go b/response_writer.go index ab2f5fec..170c42f4 100644 --- a/response_writer.go +++ b/response_writer.go @@ -17,7 +17,7 @@ const ( defaultStatus = http.StatusOK ) -var errHijackAlreadyWritten = errors.New("gin: response already written") +var errHijackAlreadyWritten = errors.New("gin: response body already written") // ResponseWriter ... type ResponseWriter interface { @@ -109,7 +109,9 @@ func (w *responseWriter) Written() bool { // Hijack implements the http.Hijacker interface. func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - if w.Written() { + // Allow hijacking after headers are written (size == 0), but not after body data (size > 0) + // For compatibility with websocket libraries (e.g., github.com/coder/websocket) + if w.size > 0 { return nil, nil, errHijackAlreadyWritten } if w.size < 0 { diff --git a/response_writer_test.go b/response_writer_test.go index ef198418..86206959 100644 --- a/response_writer_test.go +++ b/response_writer_test.go @@ -194,6 +194,64 @@ func TestResponseWriterHijackAfterWrite(t *testing.T) { } } +// Test: WebSocket compatibility - allow hijack after WriteHeaderNow(), but block after body data. +func TestResponseWriterHijackAfterWriteHeaderNow(t *testing.T) { + tests := []struct { + name string + action func(w ResponseWriter) error + expectWrittenBeforeHijack bool + expectHijackSuccess bool + expectWrittenAfterHijack bool + expectError error + }{ + { + name: "hijack after WriteHeaderNow only should succeed (websocket pattern)", + action: func(w ResponseWriter) error { + w.WriteHeaderNow() // Simulate websocket.Accept() behavior + return nil + }, + expectWrittenBeforeHijack: true, + expectHijackSuccess: true, // NEW BEHAVIOR: allow hijack after just header write + expectWrittenAfterHijack: true, + expectError: nil, + }, + { + name: "hijack after WriteHeaderNow + Write should fail", + action: func(w ResponseWriter) error { + w.WriteHeaderNow() + _, err := w.Write([]byte("test")) + return err + }, + expectWrittenBeforeHijack: true, + expectHijackSuccess: false, + expectWrittenAfterHijack: true, + expectError: errHijackAlreadyWritten, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + hijacker := &mockHijacker{ResponseRecorder: httptest.NewRecorder()} + writer := &responseWriter{} + writer.reset(hijacker) + w := ResponseWriter(writer) + + require.NoError(t, tc.action(w), "unexpected error during pre-hijack action") + + assert.Equal(t, tc.expectWrittenBeforeHijack, w.Written(), "unexpected w.Written() state before hijack") + + _, _, hijackErr := w.Hijack() + + if tc.expectError == nil { + assert.NoError(t, hijackErr, "expected hijack to succeed") + } else { + assert.ErrorIs(t, hijackErr, tc.expectError, "unexpected error from Hijack()") + } + assert.Equal(t, tc.expectHijackSuccess, hijacker.hijacked, "unexpected hijacker.hijacked state") + assert.Equal(t, tc.expectWrittenAfterHijack, w.Written(), "unexpected w.Written() state after hijack") + }) + } +} + func TestResponseWriterFlush(t *testing.T) { testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { writer := &responseWriter{}