From f80371bc27018ef05c823b804a870bbff3d0b675 Mon Sep 17 00:00:00 2001 From: Glonee Date: Fri, 28 Apr 2023 23:51:20 +0800 Subject: [PATCH] use Request.Context().Done() instead of CloseNotify() --- context.go | 2 +- context_test.go | 32 ++++++-------------------------- test_helpers.go | 17 ++++++++++++++++- 3 files changed, 23 insertions(+), 28 deletions(-) diff --git a/context.go b/context.go index 5716318e..0f412725 100644 --- a/context.go +++ b/context.go @@ -1075,7 +1075,7 @@ func (c *Context) SSEvent(name string, message any) { // indicates "Is client disconnected in middle of stream" func (c *Context) Stream(step func(w io.Writer) bool) bool { w := c.Writer - clientGone := w.CloseNotify() + clientGone := c.Request.Context().Done() for { select { case <-clientGone: diff --git a/context_test.go b/context_test.go index 1dec902c..16e51f5b 100644 --- a/context_test.go +++ b/context_test.go @@ -2049,29 +2049,9 @@ func TestContextRenderDataFromReaderNoHeaders(t *testing.T) { assert.Equal(t, fmt.Sprintf("%d", contentLength), w.Header().Get("Content-Length")) } -type TestResponseRecorder struct { - *httptest.ResponseRecorder - closeChannel chan bool -} - -func (r *TestResponseRecorder) CloseNotify() <-chan bool { - return r.closeChannel -} - -func (r *TestResponseRecorder) closeClient() { - r.closeChannel <- true -} - -func CreateTestResponseRecorder() *TestResponseRecorder { - return &TestResponseRecorder{ - httptest.NewRecorder(), - make(chan bool, 1), - } -} - func TestContextStream(t *testing.T) { - w := CreateTestResponseRecorder() - c, _ := CreateTestContext(w) + w := httptest.NewRecorder() + c, _ := CreateTestContextWithCloser(w) stopStream := true c.Stream(func(w io.Writer) bool { @@ -2089,12 +2069,12 @@ func TestContextStream(t *testing.T) { } func TestContextStreamWithClientGone(t *testing.T) { - w := CreateTestResponseRecorder() - c, _ := CreateTestContext(w) + w := httptest.NewRecorder() + c, closeClient := CreateTestContextWithCloser(w) c.Stream(func(writer io.Writer) bool { defer func() { - w.closeClient() + closeClient() }() _, err := writer.Write([]byte("test")) @@ -2107,7 +2087,7 @@ func TestContextStreamWithClientGone(t *testing.T) { } func TestContextResetInHandler(t *testing.T) { - w := CreateTestResponseRecorder() + w := httptest.NewRecorder() c, _ := CreateTestContext(w) c.handlers = []HandlerFunc{ diff --git a/test_helpers.go b/test_helpers.go index 7508c5c9..e14ca27f 100644 --- a/test_helpers.go +++ b/test_helpers.go @@ -4,7 +4,10 @@ package gin -import "net/http" +import ( + "context" + "net/http" +) // CreateTestContext returns a fresh engine and context for testing purposes func CreateTestContext(w http.ResponseWriter) (c *Context, r *Engine) { @@ -22,3 +25,15 @@ func CreateTestContextOnly(w http.ResponseWriter, r *Engine) (c *Context) { c.writermem.reset(w) return } + +// CreateTestContextOnly returns a fresh context and its closer +func CreateTestContextWithCloser(w http.ResponseWriter) (c *Context, closeClient context.CancelFunc) { + r := New() + c = r.allocateContext(0) + c.reset() + c.writermem.reset(w) + ctx, closeClient := context.WithCancel(context.Background()) + var req http.Request + c.Request = req.WithContext(ctx) + return c, closeClient +}