diff --git a/response_writer_test.go b/response_writer_test.go index b3cb9e72..ef198418 100644 --- a/response_writer_test.go +++ b/response_writer_test.go @@ -139,18 +139,20 @@ func (m *mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { func TestResponseWriterHijackAfterWrite(t *testing.T) { tests := []struct { - name string - action func(w ResponseWriter) error // Action to perform before hijacking - expectHijack bool - expectWritten bool - expectError error + name string + action func(w ResponseWriter) error // Action to perform before hijacking + expectWrittenBeforeHijack bool + expectHijackSuccess bool + expectWrittenAfterHijack bool + expectError error }{ { - name: "hijack before write should succeed", - action: func(w ResponseWriter) error { return nil }, - expectHijack: true, - expectWritten: true, // Hijack itself marks the writer as written - expectError: nil, + name: "hijack before write should succeed", + action: func(w ResponseWriter) error { return nil }, + expectWrittenBeforeHijack: false, + expectHijackSuccess: true, + expectWrittenAfterHijack: true, // Hijack itself marks the writer as written + expectError: nil, }, { name: "hijack after write should fail", @@ -158,9 +160,10 @@ func TestResponseWriterHijackAfterWrite(t *testing.T) { _, err := w.Write([]byte("test")) return err }, - expectHijack: false, - expectWritten: true, - expectError: errHijackAlreadyWritten, + expectWrittenBeforeHijack: true, + expectHijackSuccess: false, + expectWrittenAfterHijack: true, + expectError: errHijackAlreadyWritten, }, } @@ -171,12 +174,22 @@ func TestResponseWriterHijackAfterWrite(t *testing.T) { writer.reset(hijacker) w := ResponseWriter(writer) + // Check initial state + assert.False(t, w.Written(), "should not be written initially") + + // Perform pre-hijack action require.NoError(t, tc.action(w), "unexpected error during pre-hijack action") + // Check state before hijacking + assert.Equal(t, tc.expectWrittenBeforeHijack, w.Written(), "unexpected w.Written() state before hijack") + + // Attempt to hijack _, _, hijackErr := w.Hijack() + + // Check results require.ErrorIs(t, hijackErr, tc.expectError, "unexpected error from Hijack()") - assert.Equal(t, tc.expectHijack, hijacker.hijacked, "unexpected hijacker.hijacked state") - assert.Equal(t, tc.expectWritten, w.Written(), "unexpected w.Written() state") + assert.Equal(t, tc.expectHijackSuccess, hijacker.hijacked, "unexpected hijacker.hijacked state") + assert.Equal(t, tc.expectWrittenAfterHijack, w.Written(), "unexpected w.Written() state after hijack") }) } }