make Context.Value method concurrent safe

This commit is contained in:
Youlin Feng 2019-12-12 18:14:38 +08:00
parent 168fa94516
commit dda2ee6d95
4 changed files with 71 additions and 12 deletions

View File

@ -16,6 +16,7 @@ import (
"net/url"
"os"
"strings"
"sync"
"time"
"github.com/gin-contrib/sse"
@ -53,7 +54,10 @@ type Context struct {
engine *Engine
// Keys is a key/value pair exclusively for the context of each request.
Keys map[string]interface{}
Keys *sync.Map
// keysLocker protects the creation and reset of Keys *sync.Map
keysLocker sync.Mutex
// Errors is a list of errors attached to all the handlers/middlewares who used this context.
Errors errorMsgs
@ -79,7 +83,9 @@ func (c *Context) reset() {
c.handlers = nil
c.index = -1
c.fullPath = ""
c.keysLocker.Lock()
c.Keys = nil
c.keysLocker.Unlock()
c.Errors = c.Errors[0:0]
c.Accepted = nil
c.queryCache = nil
@ -94,10 +100,15 @@ func (c *Context) Copy() *Context {
cp.Writer = &cp.writermem
cp.index = abortIndex
cp.handlers = nil
cp.Keys = map[string]interface{}{}
for k, v := range c.Keys {
cp.Keys[k] = v
keys := cp.Keys
if keys != nil {
cp.Keys = new(sync.Map)
keys.Range(func(key, value interface{}) bool {
cp.Keys.Store(key, value)
return true
})
}
cp.keysLocker = sync.Mutex{}
paramCopy := make([]Param, len(cp.Params))
copy(paramCopy, cp.Params)
cp.Params = paramCopy
@ -219,16 +230,25 @@ func (c *Context) Error(err error) *Error {
// Set is used to store a new key/value pair exclusively for this context.
// It also lazy initializes c.Keys if it was not used previously.
func (c *Context) Set(key string, value interface{}) {
if c.Keys == nil {
c.Keys = make(map[string]interface{})
keys := c.Keys
if keys == nil {
c.keysLocker.Lock()
if c.Keys == nil {
c.Keys = new(sync.Map)
}
keys = c.Keys
c.keysLocker.Unlock()
}
c.Keys[key] = value
keys.Store(key, value)
}
// Get returns the value for the given key, ie: (value, true).
// If the value does not exists it returns (nil, false)
func (c *Context) Get(key string) (value interface{}, exists bool) {
value, exists = c.Keys[key]
keys := c.Keys
if keys != nil {
value, exists = keys.Load(key)
}
return
}

View File

@ -330,11 +330,31 @@ func TestContextCopy(t *testing.T) {
assert.Equal(t, &cp.writermem, cp.Writer.(*responseWriter))
assert.Equal(t, cp.Request, c.Request)
assert.Equal(t, cp.index, abortIndex)
assert.Equal(t, cp.Keys, c.Keys)
toMap := func(p *sync.Map) map[string]interface{} {
if p == nil {
return nil
}
m := map[string]interface{}{}
p.Range(func(key, value interface{}) bool {
if str, ok := key.(string); ok {
m[str] = value
}
return true
})
return m
}
assert.Equal(t, toMap(cp.Keys), toMap(c.Keys))
assert.Equal(t, cp.engine, c.engine)
assert.Equal(t, cp.Params, c.Params)
cp.Set("foo", "notBar")
assert.False(t, cp.Keys["foo"] == c.Keys["foo"])
var vc, vcp interface{}
if c.Keys != nil {
vc, _ = c.Keys.Load("foo")
}
if cp.Keys != nil {
vcp, _ = cp.Keys.Load("foo")
}
assert.False(t, vc == vcp)
}
func TestContextHandlerName(t *testing.T) {

View File

@ -245,7 +245,17 @@ func LoggerWithConfig(conf LoggerConfig) HandlerFunc {
param := LogFormatterParams{
Request: c.Request,
isTerm: isTerm,
Keys: c.Keys,
}
keys := c.Keys
if keys != nil {
param.Keys = make(map[string]interface{})
keys.Range(func(key, value interface{}) bool {
if str, ok := key.(string); ok {
param.Keys[str] = value
}
return true
})
}
// Stop timer

View File

@ -205,7 +205,16 @@ func TestLoggerWithConfigFormatting(t *testing.T) {
router.GET("/example", func(c *Context) {
// set dummy ClientIP
c.Request.Header.Set("X-Forwarded-For", "20.20.20.20")
gotKeys = c.Keys
keys := c.Keys
if keys != nil {
gotKeys = make(map[string]interface{})
keys.Range(func(key, value interface{}) bool {
if str, ok := key.(string); ok {
gotKeys[str] = value
}
return true
})
}
})
performRequest(router, "GET", "/example?a=100")