diff --git a/context.go b/context.go index 1c76c0f6..9bdc8348 100644 --- a/context.go +++ b/context.go @@ -5,6 +5,7 @@ package gin import ( + "context" "errors" "io" "io/fs" @@ -90,6 +91,9 @@ type Context struct { // SameSite allows a server to define a cookie attribute making it impossible for // the browser to send this cookie along with cross-site requests. sameSite http.SameSite + + internalContextMu sync.RWMutex + internalContext context.Context } /************************************/ @@ -111,6 +115,10 @@ func (c *Context) reset() { c.sameSite = 0 *c.params = (*c.params)[:0] *c.skippedNodes = (*c.skippedNodes)[:0] + + if c.useInternalContext() { + c.WithInternalContext(context.Background()) + } } // Copy returns a copy of the current context that can be safely used outside the request's scope. @@ -1307,6 +1315,49 @@ func (c *Context) SetAccepted(formats ...string) { /***** GOLANG.ORG/X/NET/CONTEXT *****/ /************************************/ +// WithInternalContext replaces the internal context stored with the provided one in a thread safe manner. +// It's important that any context you pass in is not something the wraps *gin.Context, +// if you want to wrap a context and then provide it to WithInternalContext, use InternalContext(). +// If you don't plan to provide the context back to WithInternalContext you can safely use *Context directly. +// Otherwise you'll end up with a stack overflow. +// +// For example: +// var c *Context // given a context +// // you can safely wrap it and pass it downstream +// myDownstreamFunction(context.WithValue(c, ...)) +// +// // but when you want to call WithInternalContext you should do it like this +// c.WithInternalContext(context.WithValue(c.InternalContext(), ...)) +func (c *Context) WithInternalContext(ctx context.Context) { + if !c.useInternalContext() { + panic("Can't use WithInternalContext when UseInternalContext is false") + } + + c.internalContextMu.Lock() + defer c.internalContextMu.Unlock() + + c.internalContext = ctx +} + +// InternalContext provides the currently stored internal context in a thread safe manner. +// Use this if you want to wrap a context.Context which you'll end up providing to WithInternalContext. +// If you don't plan to provide the context back to WithInternalContext you can safely use *Context directly. +func (c *Context) InternalContext() context.Context { + if !c.useInternalContext() { + panic("Can't use InternalContext when UseInternalContext is false") + } + + c.internalContextMu.RLock() + defer c.internalContextMu.RUnlock() + + return c.internalContext +} + +// hasRequestContext returns whether c.Request has Context and fallback. +func (c *Context) useInternalContext() bool { + return c.engine != nil && c.engine.UseInternalContext +} + // hasRequestContext returns whether c.Request has Context and fallback. func (c *Context) hasRequestContext() bool { hasFallback := c.engine != nil && c.engine.ContextWithFallback @@ -1316,26 +1367,44 @@ func (c *Context) hasRequestContext() bool { // Deadline returns that there is no deadline (ok==false) when c.Request has no Context. func (c *Context) Deadline() (deadline time.Time, ok bool) { - if !c.hasRequestContext() { - return + if c.useInternalContext() { + c.internalContextMu.RLock() + defer c.internalContextMu.RUnlock() + + return c.internalContext.Deadline() + } else if c.hasRequestContext() { + return c.Request.Context().Deadline() } - return c.Request.Context().Deadline() + + return } // Done returns nil (chan which will wait forever) when c.Request has no Context. func (c *Context) Done() <-chan struct{} { - if !c.hasRequestContext() { - return nil + if c.useInternalContext() { + c.internalContextMu.RLock() + defer c.internalContextMu.RUnlock() + + return c.internalContext.Done() + } else if c.hasRequestContext() { + return c.Request.Context().Done() } - return c.Request.Context().Done() + + return nil } // Err returns nil when c.Request has no Context. func (c *Context) Err() error { - if !c.hasRequestContext() { - return nil + if c.useInternalContext() { + c.internalContextMu.RLock() + defer c.internalContextMu.RUnlock() + + return c.internalContext.Err() + } else if c.hasRequestContext() { + return c.Request.Context().Err() } - return c.Request.Context().Err() + + return nil } // Value returns the value associated with this context for key, or nil @@ -1353,8 +1422,14 @@ func (c *Context) Value(key any) any { return val } } - if !c.hasRequestContext() { - return nil + if c.useInternalContext() { + c.internalContextMu.RLock() + defer c.internalContextMu.RUnlock() + + return c.internalContext.Value(key) + } else if c.hasRequestContext() { + return c.Request.Context().Value(key) } - return c.Request.Context().Value(key) + + return nil } diff --git a/context_test.go b/context_test.go index ef0cfccd..ba21f65f 100644 --- a/context_test.go +++ b/context_test.go @@ -2952,6 +2952,142 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) { } } +func TestContextUseInternalContextDeadline(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder(), func(c *Context) { + // enable UseInternalContext feature flag + c.engine.UseInternalContext = true + }) + + deadline, ok := c.Deadline() + assert.Zero(t, deadline) + assert.False(t, ok) + + c2, _ := CreateTestContext(httptest.NewRecorder(), func(c *Context) { + // enable UseInternalContext feature flag + c.engine.UseInternalContext = true + }) + + d := time.Now().Add(time.Second) + ctx, cancel := context.WithDeadline(context.Background(), d) + defer cancel() + c2.WithInternalContext(ctx) + deadline, ok = c2.Deadline() + assert.Equal(t, d, deadline) + assert.True(t, ok) +} + +func TestContextUseInternalContextDone(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder(), func(c *Context) { + // enable UseInternalContext feature flag + c.engine.UseInternalContext = true + }) + + assert.Nil(t, c.Done()) + + c2, _ := CreateTestContext(httptest.NewRecorder(), func(c *Context) { + // enable UseInternalContext feature flag + c.engine.UseInternalContext = true + }) + + ctx, cancel := context.WithCancel(context.Background()) + c2.WithInternalContext(ctx) + cancel() + assert.NotNil(t, <-c2.Done()) +} + +func TestContextUseInternalContextErr(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder(), func(c *Context) { + // enable UseInternalContext feature flag + c.engine.UseInternalContext = true + }) + + require.NoError(t, c.Err()) + + c2, _ := CreateTestContext(httptest.NewRecorder(), func(c *Context) { + // enable UseInternalContext feature flag + c.engine.UseInternalContext = true + }) + + ctx, cancel := context.WithCancel(context.Background()) + c2.WithInternalContext(ctx) + cancel() + + assert.EqualError(t, c2.Err(), context.Canceled.Error()) +} + +func TestContextUseInternalContextValue(t *testing.T) { + type contextKey string + + tests := []struct { + name string + getContextAndKey func() (*Context, any) + value any + }{ + { + name: "c with struct context key", + getContextAndKey: func() (*Context, any) { + type KeyStruct struct{} // https://staticcheck.dev/docs/checks/#SA1029 + var key KeyStruct + c, _ := CreateTestContext(httptest.NewRecorder(), func(c *Context) { + // enable UseInternalContext feature flag + c.engine.UseInternalContext = true + }) + c.WithInternalContext(context.WithValue(context.TODO(), key, "value")) + return c, key + }, + value: "value", + }, + { + name: "c with struct context key and request context with different value", + getContextAndKey: func() (*Context, any) { + type KeyStruct struct{} // https://staticcheck.dev/docs/checks/#SA1029 + var key KeyStruct + c, _ := CreateTestContext(httptest.NewRecorder(), func(c *Context) { + // enable UseInternalContext feature flag + c.engine.UseInternalContext = true + // enable ContextWithFallback feature flag + c.engine.ContextWithFallback = true + c.Request, _ = http.NewRequest(http.MethodPost, "/", nil) + }) + c.WithInternalContext(context.WithValue(context.TODO(), key, "value")) + c.Request = c.Request.WithContext(context.WithValue(context.TODO(), key, "other value")) + return c, key + }, + value: "value", + }, + { + name: "c with string context key", + getContextAndKey: func() (*Context, any) { + c, _ := CreateTestContext(httptest.NewRecorder(), func(c *Context) { + // enable UseInternalContext feature flag + c.engine.UseInternalContext = true + }) + c.WithInternalContext(context.WithValue(context.TODO(), contextKey("key"), "value")) + return c, contextKey("key") + }, + value: "value", + }, + { + name: "c with background internal context", + getContextAndKey: func() (*Context, any) { + c, _ := CreateTestContext(httptest.NewRecorder(), func(c *Context) { + // enable UseInternalContext feature flag + c.engine.UseInternalContext = true + }) + c.WithInternalContext(context.Background()) + return c, "key" + }, + value: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, key := tt.getContextAndKey() + assert.Equal(t, tt.value, c.Value(key)) + }) + } +} + func TestContextCopyShouldNotCancel(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) diff --git a/gin.go b/gin.go index 0761c14d..6f6bf510 100644 --- a/gin.go +++ b/gin.go @@ -17,7 +17,6 @@ import ( "github.com/gin-gonic/gin/internal/bytesconv" "github.com/gin-gonic/gin/render" - "github.com/quic-go/quic-go/http3" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" @@ -161,9 +160,14 @@ type Engine struct { // UseH2C enable h2c support. UseH2C bool - // ContextWithFallback enable fallback Context.Deadline(), Context.Done(), Context.Err() and Context.Value() when Context.Request.Context() is not nil. + // ContextWithFallback enable fallback Context.Deadline(), Context.Done(), Context.Err() and Context.Value() + // through Context.Request when Context.Request.Context() is not nil. ContextWithFallback bool + // UseInternalContext enable fallback Context.Deadline(), Context.Done(), Context.Err() + // through InternalContext and supersedes ContextWithFallback + UseInternalContext bool + delims render.Delims secureJSONPrefix string HTMLRender render.HTMLRender diff --git a/test_helpers.go b/test_helpers.go index 7508c5c9..b3cb1035 100644 --- a/test_helpers.go +++ b/test_helpers.go @@ -7,9 +7,12 @@ package gin import "net/http" // CreateTestContext returns a fresh engine and context for testing purposes -func CreateTestContext(w http.ResponseWriter) (c *Context, r *Engine) { +func CreateTestContext(w http.ResponseWriter, opts ...func(c *Context)) (c *Context, r *Engine) { r = New() c = r.allocateContext(0) + for _, opt := range opts { + opt(c) + } c.reset() c.writermem.reset(w) return