diff --git a/context.go b/context.go index a2e28e5b..640e20b2 100644 --- a/context.go +++ b/context.go @@ -141,6 +141,16 @@ func (c *Context) Copy() *Context { cp.Params = make([]Param, len(cParams)) copy(cp.Params, cParams) + if c.Errors != nil { + cp.Errors = make(errorMsgs, len(c.Errors)) + copy(cp.Errors, c.Errors) + } + + if c.Accepted != nil { + cp.Accepted = make([]string, len(c.Accepted)) + copy(cp.Accepted, c.Accepted) + } + return &cp } diff --git a/context_test.go b/context_test.go index 364a92ae..2dfdd392 100644 --- a/context_test.go +++ b/context_test.go @@ -689,6 +689,50 @@ func TestContextCopy(t *testing.T) { assert.Equal(t, cp.fullPath, c.fullPath) } +func TestContextCopyCopiesErrors(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + c.Request, _ = http.NewRequest(http.MethodGet, "/", nil) + _ = c.Error(errors.New("first error")) + _ = c.Error(errors.New("second error")) + + cp := c.Copy() + + // copied context has the same errors + assert.Len(t, cp.Errors, 2) + assert.Equal(t, c.Errors[0].Error(), cp.Errors[0].Error()) + assert.Equal(t, c.Errors[1].Error(), cp.Errors[1].Error()) + + // mutations on the copy do not affect the original + _ = cp.Error(errors.New("third error")) + assert.Len(t, c.Errors, 2) + assert.Len(t, cp.Errors, 3) +} + +func TestContextCopyCopiesAccepted(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + c.Request, _ = http.NewRequest(http.MethodGet, "/", nil) + c.SetAccepted("application/json", "text/html") + + cp := c.Copy() + + assert.Equal(t, c.Accepted, cp.Accepted) + + // mutations on the copy do not affect the original + cp.SetAccepted("text/plain") + assert.Equal(t, []string{"application/json", "text/html"}, c.Accepted) + assert.Equal(t, []string{"text/plain"}, cp.Accepted) +} + +func TestContextCopyNilErrorsAndAccepted(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + c.Request, _ = http.NewRequest(http.MethodGet, "/", nil) + + cp := c.Copy() + + assert.Nil(t, cp.Errors) + assert.Nil(t, cp.Accepted) +} + func TestContextHandlerName(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) c.handlers = HandlersChain{func(c *Context) {}, handlerNameTest}