diff --git a/auth.go b/auth.go index 2503c515..0439e50f 100644 --- a/auth.go +++ b/auth.go @@ -7,6 +7,7 @@ package gin import ( "crypto/subtle" "encoding/base64" + "errors" "net/http" "strconv" @@ -26,6 +27,11 @@ type authPair struct { type authPairs []authPair +var ( + // ErrUnauthorized cannot authorize the request. + ErrUnauthorized = errors.New("unauthorized") +) + func (a authPairs) searchCredential(authValue string) (string, bool) { if authValue == "" { return "", false @@ -53,6 +59,7 @@ func BasicAuthForRealm(accounts Accounts, realm string) HandlerFunc { user, found := pairs.searchCredential(c.requestHeader("Authorization")) if !found { // Credentials doesn't match, we return 401 and abort handlers chain. + c.Error(ErrUnauthorized) c.Header("WWW-Authenticate", realm) c.AbortWithStatus(http.StatusUnauthorized) return diff --git a/auth_test.go b/auth_test.go index 42b6f8fd..4e6a936c 100644 --- a/auth_test.go +++ b/auth_test.go @@ -137,3 +137,27 @@ func TestBasicAuth401WithCustomRealm(t *testing.T) { assert.Equal(t, http.StatusUnauthorized, w.Code) assert.Equal(t, "Basic realm=\"My Custom \\\"Realm\\\"\"", w.Header().Get("WWW-Authenticate")) } + +func TestBasicAuthWithMiddleware(t *testing.T) { + called := false + router := New() + router.Use(func(c *Context) { + called = true + c.Next() + if c.Errors.Last().Err == ErrUnauthorized { + c.JSON(401, H{"message": "Begone!"}) + } + }, BasicAuth(Accounts{"foo": "bar"})) + router.GET("/login", func(c *Context) { + c.String(http.StatusOK, c.MustGet(AuthUserKey).(string)) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/login", nil) + req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password"))) + router.ServeHTTP(w, req) + + assert.True(t, called) + assert.Equal(t, http.StatusUnauthorized, w.Code) + assert.JSONEq(t, `{"message": "Begone!"}`, w.Body.String()) +} diff --git a/docs/doc.md b/docs/doc.md index e48c2ba1..d7d47cd9 100644 --- a/docs/doc.md +++ b/docs/doc.md @@ -46,6 +46,7 @@ - [Redirects](#redirects) - [Custom Middleware](#custom-middleware) - [Using BasicAuth() middleware](#using-basicauth-middleware) + - [Detecting authorization failure in custom middleware](#detecting-authorization-failure-in-custom-middleware) - [Goroutines inside a middleware](#goroutines-inside-a-middleware) - [Custom HTTP configuration](#custom-http-configuration) - [Support Let's Encrypt](#support-lets-encrypt) @@ -1468,6 +1469,26 @@ func main() { } ``` +#### Detecting authorization failure in custom middleware + +When the `BasicAuth` middleware fails authorization, an `Error` is added to the `gin.Context.Errors` slice. You can detect this failure in a custom middleware with code like this: + +```go +func main() { + router := New() + router.Use(func(c *Context) { + c.Next() + if c.Errors.Last().Err == ErrUnauthorized { + // Unauthorized detected, act accordingly + } + }) + router.Use(BasicAuth(Accounts{"admin": "password"})) + router.GET("/login", func(c *Context) { + c.String(http.StatusOK, c.MustGet(AuthUserKey).(string)) + }) +} +``` + ### Goroutines inside a middleware When starting new Goroutines inside a middleware or handler, you **SHOULD NOT** use the original context inside it, you have to use a read-only copy.