Merge 443ef770a76e15ee1175d06c8587f1767e0ad94a into 8763f33c65f7df8be5b9fe7504ab7fcf20abb41d

This commit is contained in:
Ruben de Vries 2025-03-23 04:01:20 +05:30 committed by GitHub
commit ef57e78f56
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 233 additions and 15 deletions

View File

@ -5,6 +5,7 @@
package gin
import (
"context"
"errors"
"io"
"io/fs"
@ -90,6 +91,9 @@ type Context struct {
// SameSite allows a server to define a cookie attribute making it impossible for
// the browser to send this cookie along with cross-site requests.
sameSite http.SameSite
internalContextMu sync.RWMutex
internalContext context.Context
}
/************************************/
@ -111,6 +115,10 @@ func (c *Context) reset() {
c.sameSite = 0
*c.params = (*c.params)[:0]
*c.skippedNodes = (*c.skippedNodes)[:0]
if c.useInternalContext() {
c.WithInternalContext(context.Background())
}
}
// Copy returns a copy of the current context that can be safely used outside the request's scope.
@ -1307,6 +1315,49 @@ func (c *Context) SetAccepted(formats ...string) {
/***** GOLANG.ORG/X/NET/CONTEXT *****/
/************************************/
// WithInternalContext replaces the internal context stored with the provided one in a thread safe manner.
// It's important that any context you pass in is not something the wraps *gin.Context,
// if you want to wrap a context and then provide it to WithInternalContext, use InternalContext().
// If you don't plan to provide the context back to WithInternalContext you can safely use *Context directly.
// Otherwise you'll end up with a stack overflow.
//
// For example:
// var c *Context // given a context
// // you can safely wrap it and pass it downstream
// myDownstreamFunction(context.WithValue(c, ...))
//
// // but when you want to call WithInternalContext you should do it like this
// c.WithInternalContext(context.WithValue(c.InternalContext(), ...))
func (c *Context) WithInternalContext(ctx context.Context) {
if !c.useInternalContext() {
panic("Can't use WithInternalContext when UseInternalContext is false")
}
c.internalContextMu.Lock()
defer c.internalContextMu.Unlock()
c.internalContext = ctx
}
// InternalContext provides the currently stored internal context in a thread safe manner.
// Use this if you want to wrap a context.Context which you'll end up providing to WithInternalContext.
// If you don't plan to provide the context back to WithInternalContext you can safely use *Context directly.
func (c *Context) InternalContext() context.Context {
if !c.useInternalContext() {
panic("Can't use InternalContext when UseInternalContext is false")
}
c.internalContextMu.RLock()
defer c.internalContextMu.RUnlock()
return c.internalContext
}
// hasRequestContext returns whether c.Request has Context and fallback.
func (c *Context) useInternalContext() bool {
return c.engine != nil && c.engine.UseInternalContext
}
// hasRequestContext returns whether c.Request has Context and fallback.
func (c *Context) hasRequestContext() bool {
hasFallback := c.engine != nil && c.engine.ContextWithFallback
@ -1316,26 +1367,44 @@ func (c *Context) hasRequestContext() bool {
// Deadline returns that there is no deadline (ok==false) when c.Request has no Context.
func (c *Context) Deadline() (deadline time.Time, ok bool) {
if !c.hasRequestContext() {
return
if c.useInternalContext() {
c.internalContextMu.RLock()
defer c.internalContextMu.RUnlock()
return c.internalContext.Deadline()
} else if c.hasRequestContext() {
return c.Request.Context().Deadline()
}
return c.Request.Context().Deadline()
return
}
// Done returns nil (chan which will wait forever) when c.Request has no Context.
func (c *Context) Done() <-chan struct{} {
if !c.hasRequestContext() {
return nil
if c.useInternalContext() {
c.internalContextMu.RLock()
defer c.internalContextMu.RUnlock()
return c.internalContext.Done()
} else if c.hasRequestContext() {
return c.Request.Context().Done()
}
return c.Request.Context().Done()
return nil
}
// Err returns nil when c.Request has no Context.
func (c *Context) Err() error {
if !c.hasRequestContext() {
return nil
if c.useInternalContext() {
c.internalContextMu.RLock()
defer c.internalContextMu.RUnlock()
return c.internalContext.Err()
} else if c.hasRequestContext() {
return c.Request.Context().Err()
}
return c.Request.Context().Err()
return nil
}
// Value returns the value associated with this context for key, or nil
@ -1353,8 +1422,14 @@ func (c *Context) Value(key any) any {
return val
}
}
if !c.hasRequestContext() {
return nil
if c.useInternalContext() {
c.internalContextMu.RLock()
defer c.internalContextMu.RUnlock()
return c.internalContext.Value(key)
} else if c.hasRequestContext() {
return c.Request.Context().Value(key)
}
return c.Request.Context().Value(key)
return nil
}

View File

@ -2952,6 +2952,142 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
}
}
func TestContextUseInternalContextDeadline(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder(), func(c *Context) {
// enable UseInternalContext feature flag
c.engine.UseInternalContext = true
})
deadline, ok := c.Deadline()
assert.Zero(t, deadline)
assert.False(t, ok)
c2, _ := CreateTestContext(httptest.NewRecorder(), func(c *Context) {
// enable UseInternalContext feature flag
c.engine.UseInternalContext = true
})
d := time.Now().Add(time.Second)
ctx, cancel := context.WithDeadline(context.Background(), d)
defer cancel()
c2.WithInternalContext(ctx)
deadline, ok = c2.Deadline()
assert.Equal(t, d, deadline)
assert.True(t, ok)
}
func TestContextUseInternalContextDone(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder(), func(c *Context) {
// enable UseInternalContext feature flag
c.engine.UseInternalContext = true
})
assert.Nil(t, c.Done())
c2, _ := CreateTestContext(httptest.NewRecorder(), func(c *Context) {
// enable UseInternalContext feature flag
c.engine.UseInternalContext = true
})
ctx, cancel := context.WithCancel(context.Background())
c2.WithInternalContext(ctx)
cancel()
assert.NotNil(t, <-c2.Done())
}
func TestContextUseInternalContextErr(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder(), func(c *Context) {
// enable UseInternalContext feature flag
c.engine.UseInternalContext = true
})
require.NoError(t, c.Err())
c2, _ := CreateTestContext(httptest.NewRecorder(), func(c *Context) {
// enable UseInternalContext feature flag
c.engine.UseInternalContext = true
})
ctx, cancel := context.WithCancel(context.Background())
c2.WithInternalContext(ctx)
cancel()
assert.EqualError(t, c2.Err(), context.Canceled.Error())
}
func TestContextUseInternalContextValue(t *testing.T) {
type contextKey string
tests := []struct {
name string
getContextAndKey func() (*Context, any)
value any
}{
{
name: "c with struct context key",
getContextAndKey: func() (*Context, any) {
type KeyStruct struct{} // https://staticcheck.dev/docs/checks/#SA1029
var key KeyStruct
c, _ := CreateTestContext(httptest.NewRecorder(), func(c *Context) {
// enable UseInternalContext feature flag
c.engine.UseInternalContext = true
})
c.WithInternalContext(context.WithValue(context.TODO(), key, "value"))
return c, key
},
value: "value",
},
{
name: "c with struct context key and request context with different value",
getContextAndKey: func() (*Context, any) {
type KeyStruct struct{} // https://staticcheck.dev/docs/checks/#SA1029
var key KeyStruct
c, _ := CreateTestContext(httptest.NewRecorder(), func(c *Context) {
// enable UseInternalContext feature flag
c.engine.UseInternalContext = true
// enable ContextWithFallback feature flag
c.engine.ContextWithFallback = true
c.Request, _ = http.NewRequest(http.MethodPost, "/", nil)
})
c.WithInternalContext(context.WithValue(context.TODO(), key, "value"))
c.Request = c.Request.WithContext(context.WithValue(context.TODO(), key, "other value"))
return c, key
},
value: "value",
},
{
name: "c with string context key",
getContextAndKey: func() (*Context, any) {
c, _ := CreateTestContext(httptest.NewRecorder(), func(c *Context) {
// enable UseInternalContext feature flag
c.engine.UseInternalContext = true
})
c.WithInternalContext(context.WithValue(context.TODO(), contextKey("key"), "value"))
return c, contextKey("key")
},
value: "value",
},
{
name: "c with background internal context",
getContextAndKey: func() (*Context, any) {
c, _ := CreateTestContext(httptest.NewRecorder(), func(c *Context) {
// enable UseInternalContext feature flag
c.engine.UseInternalContext = true
})
c.WithInternalContext(context.Background())
return c, "key"
},
value: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, key := tt.getContextAndKey()
assert.Equal(t, tt.value, c.Value(key))
})
}
}
func TestContextCopyShouldNotCancel(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)

8
gin.go
View File

@ -17,7 +17,6 @@ import (
"github.com/gin-gonic/gin/internal/bytesconv"
"github.com/gin-gonic/gin/render"
"github.com/quic-go/quic-go/http3"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
@ -161,9 +160,14 @@ type Engine struct {
// UseH2C enable h2c support.
UseH2C bool
// ContextWithFallback enable fallback Context.Deadline(), Context.Done(), Context.Err() and Context.Value() when Context.Request.Context() is not nil.
// ContextWithFallback enable fallback Context.Deadline(), Context.Done(), Context.Err() and Context.Value()
// through Context.Request when Context.Request.Context() is not nil.
ContextWithFallback bool
// UseInternalContext enable fallback Context.Deadline(), Context.Done(), Context.Err()
// through InternalContext and supersedes ContextWithFallback
UseInternalContext bool
delims render.Delims
secureJSONPrefix string
HTMLRender render.HTMLRender

View File

@ -7,9 +7,12 @@ package gin
import "net/http"
// CreateTestContext returns a fresh engine and context for testing purposes
func CreateTestContext(w http.ResponseWriter) (c *Context, r *Engine) {
func CreateTestContext(w http.ResponseWriter, opts ...func(c *Context)) (c *Context, r *Engine) {
r = New()
c = r.allocateContext(0)
for _, opt := range opts {
opt(c)
}
c.reset()
c.writermem.reset(w)
return