use Request.Context().Done() instead of CloseNotify()

This commit is contained in:
Glonee 2023-04-28 23:51:20 +08:00
parent eac2daac64
commit f80371bc27
No known key found for this signature in database
GPG Key ID: 647C8F09A5A4BB89
3 changed files with 23 additions and 28 deletions

View File

@ -1075,7 +1075,7 @@ func (c *Context) SSEvent(name string, message any) {
// indicates "Is client disconnected in middle of stream" // indicates "Is client disconnected in middle of stream"
func (c *Context) Stream(step func(w io.Writer) bool) bool { func (c *Context) Stream(step func(w io.Writer) bool) bool {
w := c.Writer w := c.Writer
clientGone := w.CloseNotify() clientGone := c.Request.Context().Done()
for { for {
select { select {
case <-clientGone: case <-clientGone:

View File

@ -2049,29 +2049,9 @@ func TestContextRenderDataFromReaderNoHeaders(t *testing.T) {
assert.Equal(t, fmt.Sprintf("%d", contentLength), w.Header().Get("Content-Length")) assert.Equal(t, fmt.Sprintf("%d", contentLength), w.Header().Get("Content-Length"))
} }
type TestResponseRecorder struct {
*httptest.ResponseRecorder
closeChannel chan bool
}
func (r *TestResponseRecorder) CloseNotify() <-chan bool {
return r.closeChannel
}
func (r *TestResponseRecorder) closeClient() {
r.closeChannel <- true
}
func CreateTestResponseRecorder() *TestResponseRecorder {
return &TestResponseRecorder{
httptest.NewRecorder(),
make(chan bool, 1),
}
}
func TestContextStream(t *testing.T) { func TestContextStream(t *testing.T) {
w := CreateTestResponseRecorder() w := httptest.NewRecorder()
c, _ := CreateTestContext(w) c, _ := CreateTestContextWithCloser(w)
stopStream := true stopStream := true
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
@ -2089,12 +2069,12 @@ func TestContextStream(t *testing.T) {
} }
func TestContextStreamWithClientGone(t *testing.T) { func TestContextStreamWithClientGone(t *testing.T) {
w := CreateTestResponseRecorder() w := httptest.NewRecorder()
c, _ := CreateTestContext(w) c, closeClient := CreateTestContextWithCloser(w)
c.Stream(func(writer io.Writer) bool { c.Stream(func(writer io.Writer) bool {
defer func() { defer func() {
w.closeClient() closeClient()
}() }()
_, err := writer.Write([]byte("test")) _, err := writer.Write([]byte("test"))
@ -2107,7 +2087,7 @@ func TestContextStreamWithClientGone(t *testing.T) {
} }
func TestContextResetInHandler(t *testing.T) { func TestContextResetInHandler(t *testing.T) {
w := CreateTestResponseRecorder() w := httptest.NewRecorder()
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)
c.handlers = []HandlerFunc{ c.handlers = []HandlerFunc{

View File

@ -4,7 +4,10 @@
package gin package gin
import "net/http" import (
"context"
"net/http"
)
// CreateTestContext returns a fresh engine and context for testing purposes // CreateTestContext returns a fresh engine and context for testing purposes
func CreateTestContext(w http.ResponseWriter) (c *Context, r *Engine) { func CreateTestContext(w http.ResponseWriter) (c *Context, r *Engine) {
@ -22,3 +25,15 @@ func CreateTestContextOnly(w http.ResponseWriter, r *Engine) (c *Context) {
c.writermem.reset(w) c.writermem.reset(w)
return return
} }
// CreateTestContextOnly returns a fresh context and its closer
func CreateTestContextWithCloser(w http.ResponseWriter) (c *Context, closeClient context.CancelFunc) {
r := New()
c = r.allocateContext(0)
c.reset()
c.writermem.reset(w)
ctx, closeClient := context.WithCancel(context.Background())
var req http.Request
c.Request = req.WithContext(ctx)
return c, closeClient
}