From 8f662295ce4cd1d1470ab6276ac7f9cfa382bfd1 Mon Sep 17 00:00:00 2001 From: shahariaz Date: Sat, 23 May 2026 01:01:30 +0600 Subject: [PATCH] feat: add InitSSE(), SSEStream() and fix deprecated CloseNotifier in Stream() --- context.go | 70 +++++++++++++++++++++++++++++++++++++++++++++++- context_test.go | 71 ++++++++++++++++++++++++++++++++++++++++++------- docs/doc.md | 30 +++++++++++++++++++++ 3 files changed, 161 insertions(+), 10 deletions(-) diff --git a/context.go b/context.go index 5174033e..a8ef30cb 100644 --- a/context.go +++ b/context.go @@ -1315,7 +1315,33 @@ func (c *Context) FileAttachment(filepath, filename string) { http.ServeFile(c.Writer, c.Request, filepath) } +// InitSSE prepares the response for a Server-Sent Events stream by setting the +// required HTTP headers: Content-Type is set to "text/event-stream", +// Cache-Control to "no-cache", and Connection to "keep-alive". +// The headers are flushed to the client immediately so that the browser opens +// the stream before the first event is sent. +// +// Call this once at the beginning of your SSE handler, before any SSEvent call: +// +// router.GET("/stream", func(c *gin.Context) { +// c.InitSSE() +// for i := range 5 { +// c.SSEvent("message", i) +// c.Writer.Flush() +// } +// }) +func (c *Context) InitSSE() { + c.Writer.Header().Set("Content-Type", sse.ContentType) + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.WriteHeaderNow() +} + // SSEvent writes a Server-Sent Event into the body stream. +// It sets Content-Type and Cache-Control headers on the first call if they have +// not already been set (e.g. by InitSSE). The writer is NOT flushed automatically; +// call c.Writer.Flush() after each event to push it to the client immediately. +// To include the optional id or retry fields use c.Render(-1, sse.Event{…}) directly. func (c *Context) SSEvent(name string, message any) { c.Render(-1, sse.Event{ Event: name, @@ -1323,11 +1349,53 @@ func (c *Context) SSEvent(name string, message any) { }) } +// SSEStream initializes an SSE response and calls step in a loop to send events +// until either the client disconnects or step returns false. +// +// It returns true when the client disconnected (c.Request.Context() was cancelled) +// and false when step returned false (normal end-of-stream). +// +// The writer is flushed automatically after every successful step call. +// step receives the current Context so it can call c.SSEvent, c.Render, or +// inspect c.Request.Context().Done() for its own blocking select: +// +// router.GET("/events", func(c *gin.Context) { +// ch := make(chan string) +// go produce(ch) +// c.SSEStream(func(c *gin.Context) bool { +// select { +// case msg, ok := <-ch: +// if !ok { +// return false // channel closed → end stream normally +// } +// c.SSEvent("message", msg) +// return true +// case <-c.Request.Context().Done(): +// return false // client gone → end stream +// } +// }) +// }) +func (c *Context) SSEStream(step func(c *Context) bool) bool { + c.InitSSE() + ctx := c.Request.Context() + for { + select { + case <-ctx.Done(): + return true + default: + if !step(c) { + return false + } + c.Writer.Flush() + } + } +} + // Stream sends a streaming response and returns a boolean // indicates "Is client disconnected in middle of stream" func (c *Context) Stream(step func(w io.Writer) bool) bool { w := c.Writer - clientGone := w.CloseNotify() + clientGone := c.Request.Context().Done() for { select { case <-clientGone: diff --git a/context_test.go b/context_test.go index ef60379d..41f2ab12 100644 --- a/context_test.go +++ b/context_test.go @@ -1441,6 +1441,58 @@ func TestContextRenderSSE(t *testing.T) { assert.Equal(t, strings.ReplaceAll(w.Body.String(), " ", ""), strings.ReplaceAll("event:float\ndata:1.5\n\nid:123\ndata:text\n\nevent:chat\ndata:{\"bar\":\"foo\",\"foo\":\"bar\"}\n\n", " ", "")) } +func TestContextInitSSE(t *testing.T) { + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + c.Request, _ = http.NewRequest(http.MethodGet, "/", nil) + + c.InitSSE() + + assert.Equal(t, sse.ContentType, w.Header().Get("Content-Type")) + assert.Equal(t, "no-cache", w.Header().Get("Cache-Control")) + assert.Equal(t, "keep-alive", w.Header().Get("Connection")) + assert.Equal(t, http.StatusOK, w.Code) +} + +func TestContextSSEStreamNormalEnd(t *testing.T) { + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + c.Request, _ = http.NewRequest(http.MethodGet, "/", nil) + + count := 0 + disconnected := c.SSEStream(func(c *Context) bool { + count++ + c.SSEvent("ping", count) + return count < 3 + }) + + assert.False(t, disconnected) + assert.Equal(t, 3, count) + assert.Equal(t, sse.ContentType, w.Header().Get("Content-Type")) + assert.Equal(t, "no-cache", w.Header().Get("Cache-Control")) + assert.Equal(t, "keep-alive", w.Header().Get("Connection")) + assert.Contains(t, w.Body.String(), "event:ping") +} + +func TestContextSSEStreamClientDisconnect(t *testing.T) { + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + c.Request, _ = http.NewRequestWithContext(ctx, http.MethodGet, "/", nil) + + // step cancels the context and returns true (keep streaming). + // On the next loop iteration SSEStream's outer select sees ctx.Done() + // is closed and returns true, indicating client disconnected. + result := c.SSEStream(func(c *Context) bool { + cancel() // simulate client disconnect + return true + }) + + assert.True(t, result) +} + func TestContextRenderFile(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) @@ -3030,10 +3082,6 @@ func (r *TestResponseRecorder) CloseNotify() <-chan bool { return r.closeChannel } -func (r *TestResponseRecorder) closeClient() { - r.closeChannel <- true -} - func CreateTestResponseRecorder() *TestResponseRecorder { return &TestResponseRecorder{ httptest.NewRecorder(), @@ -3044,6 +3092,7 @@ func CreateTestResponseRecorder() *TestResponseRecorder { func TestContextStream(t *testing.T) { w := CreateTestResponseRecorder() c, _ := CreateTestContext(w) + c.Request, _ = http.NewRequest(http.MethodGet, "/", nil) stopStream := true c.Stream(func(w io.Writer) bool { @@ -3064,17 +3113,21 @@ func TestContextStreamWithClientGone(t *testing.T) { w := CreateTestResponseRecorder() c, _ := CreateTestContext(w) - c.Stream(func(writer io.Writer) bool { - defer func() { - w.closeClient() - }() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + c.Request, _ = http.NewRequestWithContext(ctx, http.MethodGet, "/", nil) + // step cancels the context and returns true (keep streaming). + // On the next loop iteration Stream's outer select sees clientGone + // is closed and returns true, indicating client disconnected. + result := c.Stream(func(writer io.Writer) bool { _, err := writer.Write([]byte("test")) require.NoError(t, err) - + cancel() // simulate client disconnect return true }) + assert.True(t, result) assert.Equal(t, "test", w.Body.String()) } diff --git a/docs/doc.md b/docs/doc.md index d1c33b87..d974cd24 100644 --- a/docs/doc.md +++ b/docs/doc.md @@ -1879,6 +1879,36 @@ func main() { } ``` +### Server-Sent Events (SSE) + +Use `c.InitSSE()` to set the required headers, then `c.SSEvent()` + `c.Writer.Flush()` to push events: + +```go +router.GET("/stream", func(c *gin.Context) { + c.InitSSE() + for i := range 5 { + c.SSEvent("message", gin.H{"count": i}) + c.Writer.Flush() + } +}) +``` + +For a long-running stream that stops when the client disconnects, use `c.SSEStream()`: + +```go +router.GET("/stream", func(c *gin.Context) { + i := 0 + c.SSEStream(func(c *gin.Context) bool { + i++ + c.SSEvent("message", gin.H{"count": i}) + return i < 10 // return false to end the stream normally + }) +}) +``` + +`SSEStream` returns `true` if the client disconnected mid-stream, `false` if the step +function ended the stream by returning `false`. + ### HTML rendering Using LoadHTMLGlob() or LoadHTMLFiles() or LoadHTMLFS()