mirror of
https://github.com/gin-gonic/gin.git
synced 2025-10-18 23:12:17 +08:00
make Context.Value method concurrent safe
This commit is contained in:
parent
168fa94516
commit
53422c4926
30
context.go
30
context.go
@ -16,7 +16,10 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
"github.com/gin-contrib/sse"
|
"github.com/gin-contrib/sse"
|
||||||
"github.com/gin-gonic/gin/binding"
|
"github.com/gin-gonic/gin/binding"
|
||||||
@ -53,7 +56,7 @@ type Context struct {
|
|||||||
engine *Engine
|
engine *Engine
|
||||||
|
|
||||||
// Keys is a key/value pair exclusively for the context of each request.
|
// Keys is a key/value pair exclusively for the context of each request.
|
||||||
Keys map[string]interface{}
|
Keys *sync.Map
|
||||||
|
|
||||||
// Errors is a list of errors attached to all the handlers/middlewares who used this context.
|
// Errors is a list of errors attached to all the handlers/middlewares who used this context.
|
||||||
Errors errorMsgs
|
Errors errorMsgs
|
||||||
@ -94,9 +97,13 @@ func (c *Context) Copy() *Context {
|
|||||||
cp.Writer = &cp.writermem
|
cp.Writer = &cp.writermem
|
||||||
cp.index = abortIndex
|
cp.index = abortIndex
|
||||||
cp.handlers = nil
|
cp.handlers = nil
|
||||||
cp.Keys = map[string]interface{}{}
|
keys := cp.Keys
|
||||||
for k, v := range c.Keys {
|
if keys != nil {
|
||||||
cp.Keys[k] = v
|
cp.Keys = new(sync.Map)
|
||||||
|
keys.Range(func(key, value interface{}) bool {
|
||||||
|
cp.Keys.Store(key, value)
|
||||||
|
return true
|
||||||
|
})
|
||||||
}
|
}
|
||||||
paramCopy := make([]Param, len(cp.Params))
|
paramCopy := make([]Param, len(cp.Params))
|
||||||
copy(paramCopy, cp.Params)
|
copy(paramCopy, cp.Params)
|
||||||
@ -219,16 +226,23 @@ func (c *Context) Error(err error) *Error {
|
|||||||
// Set is used to store a new key/value pair exclusively for this context.
|
// Set is used to store a new key/value pair exclusively for this context.
|
||||||
// It also lazy initializes c.Keys if it was not used previously.
|
// It also lazy initializes c.Keys if it was not used previously.
|
||||||
func (c *Context) Set(key string, value interface{}) {
|
func (c *Context) Set(key string, value interface{}) {
|
||||||
if c.Keys == nil {
|
keys := c.Keys
|
||||||
c.Keys = make(map[string]interface{})
|
if keys == nil {
|
||||||
|
atomic.CompareAndSwapPointer((*unsafe.Pointer)(unsafe.Pointer(&c.Keys)), nil, unsafe.Pointer(new(sync.Map)))
|
||||||
|
keys = c.Keys
|
||||||
|
}
|
||||||
|
if keys != nil {
|
||||||
|
keys.Store(key, value)
|
||||||
}
|
}
|
||||||
c.Keys[key] = value
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get returns the value for the given key, ie: (value, true).
|
// Get returns the value for the given key, ie: (value, true).
|
||||||
// If the value does not exists it returns (nil, false)
|
// If the value does not exists it returns (nil, false)
|
||||||
func (c *Context) Get(key string) (value interface{}, exists bool) {
|
func (c *Context) Get(key string) (value interface{}, exists bool) {
|
||||||
value, exists = c.Keys[key]
|
keys := c.Keys
|
||||||
|
if keys != nil {
|
||||||
|
value, exists = keys.Load(key)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -330,11 +330,31 @@ func TestContextCopy(t *testing.T) {
|
|||||||
assert.Equal(t, &cp.writermem, cp.Writer.(*responseWriter))
|
assert.Equal(t, &cp.writermem, cp.Writer.(*responseWriter))
|
||||||
assert.Equal(t, cp.Request, c.Request)
|
assert.Equal(t, cp.Request, c.Request)
|
||||||
assert.Equal(t, cp.index, abortIndex)
|
assert.Equal(t, cp.index, abortIndex)
|
||||||
assert.Equal(t, cp.Keys, c.Keys)
|
toMap := func(p *sync.Map) map[string]interface{} {
|
||||||
|
if p == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
m := map[string]interface{}{}
|
||||||
|
p.Range(func(key, value interface{}) bool {
|
||||||
|
if str, ok := key.(string); ok {
|
||||||
|
m[str] = value
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
assert.Equal(t, toMap(cp.Keys), toMap(c.Keys))
|
||||||
assert.Equal(t, cp.engine, c.engine)
|
assert.Equal(t, cp.engine, c.engine)
|
||||||
assert.Equal(t, cp.Params, c.Params)
|
assert.Equal(t, cp.Params, c.Params)
|
||||||
cp.Set("foo", "notBar")
|
cp.Set("foo", "notBar")
|
||||||
assert.False(t, cp.Keys["foo"] == c.Keys["foo"])
|
var vc, vcp interface{}
|
||||||
|
if c.Keys != nil {
|
||||||
|
vc, _ = c.Keys.Load("foo")
|
||||||
|
}
|
||||||
|
if cp.Keys != nil {
|
||||||
|
vcp, _ = cp.Keys.Load("foo")
|
||||||
|
}
|
||||||
|
assert.False(t, vc == vcp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestContextHandlerName(t *testing.T) {
|
func TestContextHandlerName(t *testing.T) {
|
||||||
|
12
logger.go
12
logger.go
@ -245,7 +245,17 @@ func LoggerWithConfig(conf LoggerConfig) HandlerFunc {
|
|||||||
param := LogFormatterParams{
|
param := LogFormatterParams{
|
||||||
Request: c.Request,
|
Request: c.Request,
|
||||||
isTerm: isTerm,
|
isTerm: isTerm,
|
||||||
Keys: c.Keys,
|
}
|
||||||
|
|
||||||
|
keys := c.Keys
|
||||||
|
if keys != nil {
|
||||||
|
param.Keys = make(map[string]interface{})
|
||||||
|
keys.Range(func(key, value interface{}) bool {
|
||||||
|
if str, ok := key.(string); ok {
|
||||||
|
param.Keys[str] = value
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop timer
|
// Stop timer
|
||||||
|
@ -205,7 +205,16 @@ func TestLoggerWithConfigFormatting(t *testing.T) {
|
|||||||
router.GET("/example", func(c *Context) {
|
router.GET("/example", func(c *Context) {
|
||||||
// set dummy ClientIP
|
// set dummy ClientIP
|
||||||
c.Request.Header.Set("X-Forwarded-For", "20.20.20.20")
|
c.Request.Header.Set("X-Forwarded-For", "20.20.20.20")
|
||||||
gotKeys = c.Keys
|
keys := c.Keys
|
||||||
|
if keys != nil {
|
||||||
|
gotKeys = make(map[string]interface{})
|
||||||
|
keys.Range(func(key, value interface{}) bool {
|
||||||
|
if str, ok := key.(string); ok {
|
||||||
|
gotKeys[str] = value
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
})
|
})
|
||||||
performRequest(router, "GET", "/example?a=100")
|
performRequest(router, "GET", "/example?a=100")
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user