diff --git a/context.go b/context.go index 046f284e..1cddcbf3 100644 --- a/context.go +++ b/context.go @@ -16,7 +16,10 @@ import ( "net/url" "os" "strings" + "sync" + "sync/atomic" "time" + "unsafe" "github.com/gin-contrib/sse" "github.com/gin-gonic/gin/binding" @@ -53,7 +56,7 @@ type Context struct { engine *Engine // Keys is a key/value pair exclusively for the context of each request. - Keys map[string]interface{} + Keys *sync.Map // Errors is a list of errors attached to all the handlers/middlewares who used this context. Errors errorMsgs @@ -94,9 +97,13 @@ func (c *Context) Copy() *Context { cp.Writer = &cp.writermem cp.index = abortIndex cp.handlers = nil - cp.Keys = map[string]interface{}{} - for k, v := range c.Keys { - cp.Keys[k] = v + keys := cp.Keys + if keys != nil { + cp.Keys = new(sync.Map) + keys.Range(func(key, value interface{}) bool { + cp.Keys.Store(key, value) + return true + }) } paramCopy := make([]Param, len(cp.Params)) copy(paramCopy, cp.Params) @@ -219,16 +226,23 @@ func (c *Context) Error(err error) *Error { // Set is used to store a new key/value pair exclusively for this context. // It also lazy initializes c.Keys if it was not used previously. func (c *Context) Set(key string, value interface{}) { - if c.Keys == nil { - c.Keys = make(map[string]interface{}) + keys := c.Keys + if keys == nil { + atomic.CompareAndSwapPointer((*unsafe.Pointer)(unsafe.Pointer(&c.Keys)), nil, unsafe.Pointer(new(sync.Map))) + keys = c.Keys + } + if keys != nil { + keys.Store(key, value) } - c.Keys[key] = value } // Get returns the value for the given key, ie: (value, true). // If the value does not exists it returns (nil, false) func (c *Context) Get(key string) (value interface{}, exists bool) { - value, exists = c.Keys[key] + keys := c.Keys + if keys != nil { + value, exists = keys.Load(key) + } return } diff --git a/context_test.go b/context_test.go index 18709d3d..ca67dcb0 100644 --- a/context_test.go +++ b/context_test.go @@ -330,11 +330,31 @@ func TestContextCopy(t *testing.T) { assert.Equal(t, &cp.writermem, cp.Writer.(*responseWriter)) assert.Equal(t, cp.Request, c.Request) assert.Equal(t, cp.index, abortIndex) - assert.Equal(t, cp.Keys, c.Keys) + toMap := func(p *sync.Map) map[string]interface{} { + if p == nil { + return nil + } + m := map[string]interface{}{} + p.Range(func(key, value interface{}) bool { + if str, ok := key.(string); ok { + m[str] = value + } + return true + }) + return m + } + assert.Equal(t, toMap(cp.Keys), toMap(c.Keys)) assert.Equal(t, cp.engine, c.engine) assert.Equal(t, cp.Params, c.Params) cp.Set("foo", "notBar") - assert.False(t, cp.Keys["foo"] == c.Keys["foo"]) + var vc, vcp interface{} + if c.Keys != nil { + vc, _ = c.Keys.Load("foo") + } + if cp.Keys != nil { + vcp, _ = cp.Keys.Load("foo") + } + assert.False(t, vc == vcp) } func TestContextHandlerName(t *testing.T) { diff --git a/logger.go b/logger.go index d5b96b3e..6d80b026 100644 --- a/logger.go +++ b/logger.go @@ -245,7 +245,17 @@ func LoggerWithConfig(conf LoggerConfig) HandlerFunc { param := LogFormatterParams{ Request: c.Request, isTerm: isTerm, - Keys: c.Keys, + } + + keys := c.Keys + if keys != nil { + param.Keys = make(map[string]interface{}) + keys.Range(func(key, value interface{}) bool { + if str, ok := key.(string); ok { + param.Keys[str] = value + } + return true + }) } // Stop timer diff --git a/logger_test.go b/logger_test.go index fc53f356..9feac5b2 100644 --- a/logger_test.go +++ b/logger_test.go @@ -205,7 +205,16 @@ func TestLoggerWithConfigFormatting(t *testing.T) { router.GET("/example", func(c *Context) { // set dummy ClientIP c.Request.Header.Set("X-Forwarded-For", "20.20.20.20") - gotKeys = c.Keys + keys := c.Keys + if keys != nil { + gotKeys = make(map[string]interface{}) + keys.Range(func(key, value interface{}) bool { + if str, ok := key.(string); ok { + gotKeys[str] = value + } + return true + }) + } }) performRequest(router, "GET", "/example?a=100")