diff --git a/context.go b/context.go index a8ef30cb..48c75e0f 100644 --- a/context.go +++ b/context.go @@ -1335,6 +1335,7 @@ func (c *Context) InitSSE() { c.Writer.Header().Set("Cache-Control", "no-cache") c.Writer.Header().Set("Connection", "keep-alive") c.Writer.WriteHeaderNow() + c.Writer.Flush() } // SSEvent writes a Server-Sent Event into the body stream. @@ -1377,10 +1378,13 @@ func (c *Context) SSEvent(name string, message any) { // }) func (c *Context) SSEStream(step func(c *Context) bool) bool { c.InitSSE() - ctx := c.Request.Context() + var done <-chan struct{} + if c.Request != nil { + done = c.Request.Context().Done() + } for { select { - case <-ctx.Done(): + case <-done: return true default: if !step(c) { @@ -1395,7 +1399,10 @@ func (c *Context) SSEStream(step func(c *Context) bool) bool { // indicates "Is client disconnected in middle of stream" func (c *Context) Stream(step func(w io.Writer) bool) bool { w := c.Writer - clientGone := c.Request.Context().Done() + var clientGone <-chan struct{} + if c.Request != nil { + clientGone = c.Request.Context().Done() + } for { select { case <-clientGone: diff --git a/context_test.go b/context_test.go index 41f2ab12..5a8cff64 100644 --- a/context_test.go +++ b/context_test.go @@ -1452,6 +1452,7 @@ func TestContextInitSSE(t *testing.T) { assert.Equal(t, "no-cache", w.Header().Get("Cache-Control")) assert.Equal(t, "keep-alive", w.Header().Get("Connection")) assert.Equal(t, http.StatusOK, w.Code) + assert.True(t, w.Flushed) } func TestContextSSEStreamNormalEnd(t *testing.T) { @@ -1474,6 +1475,22 @@ func TestContextSSEStreamNormalEnd(t *testing.T) { assert.Contains(t, w.Body.String(), "event:ping") } +func TestContextSSEStreamNilRequest(t *testing.T) { + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + // c.Request is intentionally left nil to verify no panic + + count := 0 + assert.NotPanics(t, func() { + disconnected := c.SSEStream(func(c *Context) bool { + count++ + return count < 2 + }) + assert.False(t, disconnected) + }) + assert.Equal(t, 2, count) +} + func TestContextSSEStreamClientDisconnect(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) @@ -3131,6 +3148,24 @@ func TestContextStreamWithClientGone(t *testing.T) { assert.Equal(t, "test", w.Body.String()) } +func TestContextStreamNilRequest(t *testing.T) { + w := CreateTestResponseRecorder() + c, _ := CreateTestContext(w) + // c.Request is intentionally left nil to verify no panic + + count := 0 + assert.NotPanics(t, func() { + disconnected := c.Stream(func(writer io.Writer) bool { + count++ + _, err := writer.Write([]byte("x")) + require.NoError(t, err) + return count < 2 + }) + assert.False(t, disconnected) + }) + assert.Equal(t, 2, count) +} + func TestContextResetInHandler(t *testing.T) { w := CreateTestResponseRecorder() c, _ := CreateTestContext(w)