diff --git a/context.go b/context.go index 5174033e..52d8a0a6 100644 --- a/context.go +++ b/context.go @@ -1328,10 +1328,16 @@ func (c *Context) SSEvent(name string, message any) { func (c *Context) Stream(step func(w io.Writer) bool) bool { w := c.Writer clientGone := w.CloseNotify() + var requestGone <-chan struct{} + if c.Request != nil { + requestGone = c.Request.Context().Done() + } for { select { case <-clientGone: return true + case <-requestGone: + return true default: keepOpen := step(w) w.Flush() diff --git a/context_test.go b/context_test.go index ef60379d..80fac6d2 100644 --- a/context_test.go +++ b/context_test.go @@ -3078,6 +3078,25 @@ func TestContextStreamWithClientGone(t *testing.T) { assert.Equal(t, "test", w.Body.String()) } +func TestContextStreamWithRequestContextDone(t *testing.T) { + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + + reqCtx, cancel := context.WithCancel(context.Background()) + cancel() + c.Request = httptest.NewRequest(http.MethodGet, "/", nil).WithContext(reqCtx) + + disconnected := c.Stream(func(writer io.Writer) bool { + _, err := writer.Write([]byte("test")) + require.NoError(t, err) + + return true + }) + + assert.True(t, disconnected) + assert.Empty(t, w.Body.String()) +} + func TestContextResetInHandler(t *testing.T) { w := CreateTestResponseRecorder() c, _ := CreateTestContext(w)