diff --git a/context.go b/context.go index 5716318e..0f412725 100644 --- a/context.go +++ b/context.go @@ -1075,7 +1075,7 @@ func (c *Context) SSEvent(name string, message any) { // indicates "Is client disconnected in middle of stream" func (c *Context) Stream(step func(w io.Writer) bool) bool { w := c.Writer - clientGone := w.CloseNotify() + clientGone := c.Request.Context().Done() for { select { case <-clientGone: diff --git a/context_test.go b/context_test.go index 1dec902c..16e51f5b 100644 --- a/context_test.go +++ b/context_test.go @@ -2049,29 +2049,9 @@ func TestContextRenderDataFromReaderNoHeaders(t *testing.T) { assert.Equal(t, fmt.Sprintf("%d", contentLength), w.Header().Get("Content-Length")) } -type TestResponseRecorder struct { - *httptest.ResponseRecorder - closeChannel chan bool -} - -func (r *TestResponseRecorder) CloseNotify() <-chan bool { - return r.closeChannel -} - -func (r *TestResponseRecorder) closeClient() { - r.closeChannel <- true -} - -func CreateTestResponseRecorder() *TestResponseRecorder { - return &TestResponseRecorder{ - httptest.NewRecorder(), - make(chan bool, 1), - } -} - func TestContextStream(t *testing.T) { - w := CreateTestResponseRecorder() - c, _ := CreateTestContext(w) + w := httptest.NewRecorder() + c, _ := CreateTestContextWithCloser(w) stopStream := true c.Stream(func(w io.Writer) bool { @@ -2089,12 +2069,12 @@ func TestContextStream(t *testing.T) { } func TestContextStreamWithClientGone(t *testing.T) { - w := CreateTestResponseRecorder() - c, _ := CreateTestContext(w) + w := httptest.NewRecorder() + c, closeClient := CreateTestContextWithCloser(w) c.Stream(func(writer io.Writer) bool { defer func() { - w.closeClient() + closeClient() }() _, err := writer.Write([]byte("test")) @@ -2107,7 +2087,7 @@ func TestContextStreamWithClientGone(t *testing.T) { } func TestContextResetInHandler(t *testing.T) { - w := CreateTestResponseRecorder() + w := httptest.NewRecorder() c, _ := CreateTestContext(w) c.handlers = []HandlerFunc{ diff --git a/test_helpers.go b/test_helpers.go index 7508c5c9..e14ca27f 100644 --- a/test_helpers.go +++ b/test_helpers.go @@ -4,7 +4,10 @@ package gin -import "net/http" +import ( + "context" + "net/http" +) // CreateTestContext returns a fresh engine and context for testing purposes func CreateTestContext(w http.ResponseWriter) (c *Context, r *Engine) { @@ -22,3 +25,15 @@ func CreateTestContextOnly(w http.ResponseWriter, r *Engine) (c *Context) { c.writermem.reset(w) return } + +// CreateTestContextOnly returns a fresh context and its closer +func CreateTestContextWithCloser(w http.ResponseWriter) (c *Context, closeClient context.CancelFunc) { + r := New() + c = r.allocateContext(0) + c.reset() + c.writermem.reset(w) + ctx, closeClient := context.WithCancel(context.Background()) + var req http.Request + c.Request = req.WithContext(ctx) + return c, closeClient +}