mirror of
https://github.com/gin-gonic/gin.git
synced 2025-10-18 23:12:17 +08:00
make Context.Value method concurrent safe
This commit is contained in:
parent
168fa94516
commit
53422c4926
30
context.go
30
context.go
@ -16,7 +16,10 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gin-contrib/sse"
|
||||
"github.com/gin-gonic/gin/binding"
|
||||
@ -53,7 +56,7 @@ type Context struct {
|
||||
engine *Engine
|
||||
|
||||
// Keys is a key/value pair exclusively for the context of each request.
|
||||
Keys map[string]interface{}
|
||||
Keys *sync.Map
|
||||
|
||||
// Errors is a list of errors attached to all the handlers/middlewares who used this context.
|
||||
Errors errorMsgs
|
||||
@ -94,9 +97,13 @@ func (c *Context) Copy() *Context {
|
||||
cp.Writer = &cp.writermem
|
||||
cp.index = abortIndex
|
||||
cp.handlers = nil
|
||||
cp.Keys = map[string]interface{}{}
|
||||
for k, v := range c.Keys {
|
||||
cp.Keys[k] = v
|
||||
keys := cp.Keys
|
||||
if keys != nil {
|
||||
cp.Keys = new(sync.Map)
|
||||
keys.Range(func(key, value interface{}) bool {
|
||||
cp.Keys.Store(key, value)
|
||||
return true
|
||||
})
|
||||
}
|
||||
paramCopy := make([]Param, len(cp.Params))
|
||||
copy(paramCopy, cp.Params)
|
||||
@ -219,16 +226,23 @@ func (c *Context) Error(err error) *Error {
|
||||
// Set is used to store a new key/value pair exclusively for this context.
|
||||
// It also lazy initializes c.Keys if it was not used previously.
|
||||
func (c *Context) Set(key string, value interface{}) {
|
||||
if c.Keys == nil {
|
||||
c.Keys = make(map[string]interface{})
|
||||
keys := c.Keys
|
||||
if keys == nil {
|
||||
atomic.CompareAndSwapPointer((*unsafe.Pointer)(unsafe.Pointer(&c.Keys)), nil, unsafe.Pointer(new(sync.Map)))
|
||||
keys = c.Keys
|
||||
}
|
||||
if keys != nil {
|
||||
keys.Store(key, value)
|
||||
}
|
||||
c.Keys[key] = value
|
||||
}
|
||||
|
||||
// Get returns the value for the given key, ie: (value, true).
|
||||
// If the value does not exists it returns (nil, false)
|
||||
func (c *Context) Get(key string) (value interface{}, exists bool) {
|
||||
value, exists = c.Keys[key]
|
||||
keys := c.Keys
|
||||
if keys != nil {
|
||||
value, exists = keys.Load(key)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -330,11 +330,31 @@ func TestContextCopy(t *testing.T) {
|
||||
assert.Equal(t, &cp.writermem, cp.Writer.(*responseWriter))
|
||||
assert.Equal(t, cp.Request, c.Request)
|
||||
assert.Equal(t, cp.index, abortIndex)
|
||||
assert.Equal(t, cp.Keys, c.Keys)
|
||||
toMap := func(p *sync.Map) map[string]interface{} {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
m := map[string]interface{}{}
|
||||
p.Range(func(key, value interface{}) bool {
|
||||
if str, ok := key.(string); ok {
|
||||
m[str] = value
|
||||
}
|
||||
return true
|
||||
})
|
||||
return m
|
||||
}
|
||||
assert.Equal(t, toMap(cp.Keys), toMap(c.Keys))
|
||||
assert.Equal(t, cp.engine, c.engine)
|
||||
assert.Equal(t, cp.Params, c.Params)
|
||||
cp.Set("foo", "notBar")
|
||||
assert.False(t, cp.Keys["foo"] == c.Keys["foo"])
|
||||
var vc, vcp interface{}
|
||||
if c.Keys != nil {
|
||||
vc, _ = c.Keys.Load("foo")
|
||||
}
|
||||
if cp.Keys != nil {
|
||||
vcp, _ = cp.Keys.Load("foo")
|
||||
}
|
||||
assert.False(t, vc == vcp)
|
||||
}
|
||||
|
||||
func TestContextHandlerName(t *testing.T) {
|
||||
|
12
logger.go
12
logger.go
@ -245,7 +245,17 @@ func LoggerWithConfig(conf LoggerConfig) HandlerFunc {
|
||||
param := LogFormatterParams{
|
||||
Request: c.Request,
|
||||
isTerm: isTerm,
|
||||
Keys: c.Keys,
|
||||
}
|
||||
|
||||
keys := c.Keys
|
||||
if keys != nil {
|
||||
param.Keys = make(map[string]interface{})
|
||||
keys.Range(func(key, value interface{}) bool {
|
||||
if str, ok := key.(string); ok {
|
||||
param.Keys[str] = value
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// Stop timer
|
||||
|
@ -205,7 +205,16 @@ func TestLoggerWithConfigFormatting(t *testing.T) {
|
||||
router.GET("/example", func(c *Context) {
|
||||
// set dummy ClientIP
|
||||
c.Request.Header.Set("X-Forwarded-For", "20.20.20.20")
|
||||
gotKeys = c.Keys
|
||||
keys := c.Keys
|
||||
if keys != nil {
|
||||
gotKeys = make(map[string]interface{})
|
||||
keys.Range(func(key, value interface{}) bool {
|
||||
if str, ok := key.(string); ok {
|
||||
gotKeys[str] = value
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
})
|
||||
performRequest(router, "GET", "/example?a=100")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user