diff --git a/response_writer_test.go b/response_writer_test.go index fe4fe25e..d68ce5ee 100644 --- a/response_writer_test.go +++ b/response_writer_test.go @@ -131,39 +131,54 @@ type mockHijacker struct { hijacked bool } +// Hijack implements the http.Hijacker interface. It just records that it was called. func (m *mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { m.hijacked = true return nil, nil, nil } func TestResponseWriterHijackAfterWrite(t *testing.T) { - // Test case 1: Hijack before writing - hijacker := &mockHijacker{ResponseRecorder: httptest.NewRecorder()} - writer := &responseWriter{} - writer.reset(hijacker) - w := ResponseWriter(writer) + tests := []struct { + name string + action func(w ResponseWriter) error // Action to perform before hijacking + expectHijack bool + expectWritten bool + expectError error + }{ + { + name: "hijack before write should succeed", + action: func(w ResponseWriter) error { return nil }, + expectHijack: true, + expectWritten: true, // Hijack itself marks the writer as written + expectError: nil, + }, + { + name: "hijack after write should fail", + action: func(w ResponseWriter) error { + _, err := w.Write([]byte("test")) + return err + }, + expectHijack: false, + expectWritten: true, + expectError: errHijackAlreadyWritten, + }, + } - _, _, err := w.Hijack() - require.NoError(t, err) - assert.True(t, hijacker.hijacked, "Hijack should be called") - assert.True(t, w.Written(), "Written() should be true after Hijack") + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + hijacker := &mockHijacker{ResponseRecorder: httptest.NewRecorder()} + writer := &responseWriter{} + writer.reset(hijacker) + w := ResponseWriter(writer) - // Test case 2: Hijack after writing - hijacker2 := &mockHijacker{ResponseRecorder: httptest.NewRecorder()} - writer2 := &responseWriter{} - writer2.reset(hijacker2) - w2 := ResponseWriter(writer2) + require.NoError(t, tc.action(w), "unexpected error during pre-hijack action") - _, err = w2.Write([]byte("test")) - require.NoError(t, err) - assert.True(t, w2.Written(), "Written() should be true after Write") - - // Now, try to hijack - _, _, err = w2.Hijack() - require.Error(t, err) - assert.Equal(t, errHijackAlreadyWritten, err, "Hijack after write should return errHijackAlreadyWritten") - assert.False(t, hijacker2.hijacked, "Hijack should not be called after write") -} + _, _, hijackErr := w.Hijack() + assert.ErrorIs(t, hijackErr, tc.expectError, "unexpected error from Hijack()") + assert.Equal(t, tc.expectHijack, hijacker.hijacked, "unexpected hijacker.hijacked state") + assert.Equal(t, tc.expectWritten, w.Written(), "unexpected w.Written() state") + }) + }} func TestResponseWriterFlush(t *testing.T) { testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {