diff --git a/context.go b/context.go index a2e28e5b..d2295165 100644 --- a/context.go +++ b/context.go @@ -66,7 +66,9 @@ type Context struct { Params Params handlers HandlersChain index int8 - fullPath string + // abortedBy is the index of the handler that called Abort(), if any. + abortedBy int8 + fullPath string engine *Engine params *Params @@ -105,6 +107,7 @@ func (c *Context) reset() { c.Params = c.Params[:0] c.handlers = nil c.index = -1 + c.abortedBy = -1 c.fullPath = "" c.Keys = nil @@ -129,6 +132,7 @@ func (c *Context) Copy() *Context { cp.writermem.ResponseWriter = nil cp.Writer = &cp.writermem cp.index = abortIndex + cp.abortedBy = c.abortedBy cp.handlers = nil cp.fullPath = c.fullPath @@ -205,9 +209,29 @@ func (c *Context) IsAborted() bool { // If the authorization fails (ex: the password does not match), call Abort to ensure the remaining handlers // for this request are not called. func (c *Context) Abort() { + if !c.IsAborted() { + c.abortedBy = c.index + } c.index = abortIndex } +// AbortedByHandler returns the handler that called Abort(), if available. +func (c *Context) AbortedByHandler() HandlerFunc { + if c.abortedBy < 0 || int(c.abortedBy) >= len(c.handlers) { + return nil + } + return c.handlers[c.abortedBy] +} + +// AbortedBy returns the handler name that called Abort(). +func (c *Context) AbortedBy() string { + h := c.AbortedByHandler() + if h == nil { + return "" + } + return nameOfFunction(h) +} + // AbortWithStatus calls `Abort()` and writes the headers with the specified status code. // For example, a failed attempt to authenticate a request could use: context.AbortWithStatus(401). func (c *Context) AbortWithStatus(code int) { diff --git a/context_test.go b/context_test.go index 364a92ae..edacdff7 100644 --- a/context_test.go +++ b/context_test.go @@ -714,6 +714,18 @@ func handlerNameTest(c *Context) { func handlerNameTest2(c *Context) { } +func abortedByMiddleware1(c *Context) { + c.Next() +} + +func abortedByMiddleware2(c *Context) { + c.Abort() +} + +func abortedByMiddleware3(c *Context) { + c.Abort() +} + var handlerTest HandlerFunc = func(c *Context) { } @@ -724,6 +736,42 @@ func TestContextHandler(t *testing.T) { assert.Equal(t, reflect.ValueOf(handlerTest).Pointer(), reflect.ValueOf(c.Handler()).Pointer()) } +func TestContextAbortedByWithoutHandler(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + + assert.Nil(t, c.AbortedByHandler()) + assert.Empty(t, c.AbortedBy()) + + c.Abort() + + assert.True(t, c.IsAborted()) + assert.Nil(t, c.AbortedByHandler()) + assert.Empty(t, c.AbortedBy()) +} + +func TestContextAbortedByHandler(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + c.handlers = HandlersChain{abortedByMiddleware1, abortedByMiddleware2, abortedByMiddleware3} + + c.Next() + + assert.True(t, c.IsAborted()) + assert.Equal(t, reflect.ValueOf(abortedByMiddleware2).Pointer(), reflect.ValueOf(c.AbortedByHandler()).Pointer()) + assert.Regexp(t, "^(.*/vendor/)?github.com/gin-gonic/gin.abortedByMiddleware2$", c.AbortedBy()) +} + +func TestContextAbortedByPreserveFirstAborter(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + c.handlers = HandlersChain{abortedByMiddleware2} + + c.Next() + firstAborter := c.AbortedByHandler() + c.Abort() + + assert.True(t, c.IsAborted()) + assert.Equal(t, reflect.ValueOf(firstAborter).Pointer(), reflect.ValueOf(c.AbortedByHandler()).Pointer()) +} + func TestContextQuery(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) c.Request, _ = http.NewRequest(http.MethodGet, "http://example.com/?foo=bar&page=10&id=", nil)