diff --git a/context.go b/context.go index d7280c5d..9b6716c5 100644 --- a/context.go +++ b/context.go @@ -73,6 +73,9 @@ type Context struct { // This mutex protects Keys map. mu sync.RWMutex + // This mutex protects headers map + hmu sync.RWMutex + // Keys is a key/value pair exclusively for the context of each request. Keys map[any]any @@ -982,6 +985,8 @@ func (c *Context) IsWebsocket() bool { } func (c *Context) requestHeader(key string) string { + c.hmu.RLock() + defer c.hmu.RUnlock() return c.Request.Header.Get(key) } @@ -1011,6 +1016,8 @@ func (c *Context) Status(code int) { // It writes a header in the response. // If value == "", this method removes the header `c.Writer.Header().Del(key)` func (c *Context) Header(key, value string) { + c.hmu.Lock() + defer c.hmu.Unlock() if value == "" { c.Writer.Header().Del(key) return diff --git a/context_test.go b/context_test.go index cc066ef8..7d140bf4 100644 --- a/context_test.go +++ b/context_test.go @@ -3423,6 +3423,48 @@ func TestContextSetCookieData(t *testing.T) { }) } +func TestParallelHeaderAccess(t *testing.T) { + t.Parallel() + const iterations = 1000 + const goroutines = 8 + + testCases := []struct { + name string + writerCount int + readerCount int + }{ + {"parallel_write_only", goroutines, 0}, + {"parallel_write_and_read", goroutines / 2, goroutines / 2}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + c.Request, _ = http.NewRequest(http.MethodGet, "/", nil) + wg := sync.WaitGroup{} + for i := 0; i < tc.writerCount; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for range iterations { + c.Header("key", "value") + } + }() + } + for i := 0; i < tc.readerCount; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for range iterations { + _ = c.GetHeader("key") + } + }() + } + wg.Wait() + }) + } +} + func TestGetMapFromFormData(t *testing.T) { testCases := []struct { name string