diff --git a/auth.go b/auth.go index 2503c515..d7bcb1cc 100644 --- a/auth.go +++ b/auth.go @@ -25,6 +25,8 @@ type authPair struct { } type authPairs []authPair +type authHeaderValidator func(c *Context) (authenticatedUser string, ok bool) +type UsernamePasswordValidator func(username, password string) bool func (a authPairs) searchCredential(authValue string) (string, bool) { if authValue == "" { @@ -38,30 +40,26 @@ func (a authPairs) searchCredential(authValue string) (string, bool) { return "", false } -// BasicAuthForRealm returns a Basic HTTP Authorization middleware. It takes as arguments a map[string]string where -// the key is the user name and the value is the password, as well as the name of the Realm. +// BasicAuthForRealmWithValidator returns a Basic HTTP Authorization middleware. +// Its first argument is a function that checks the username and password and returns true if an account matches. +// The second parameter is the name of the realm. If the realm is empty, "Authorization Required" will be used by default. // If the realm is empty, "Authorization Required" will be used by default. // (see http://tools.ietf.org/html/rfc2617#section-1.2) -func BasicAuthForRealm(accounts Accounts, realm string) HandlerFunc { - if realm == "" { - realm = "Authorization Required" - } - realm = "Basic realm=" + strconv.Quote(realm) - pairs := processAccounts(accounts) - return func(c *Context) { - // Search user in the slice of allowed credentials - user, found := pairs.searchCredential(c.requestHeader("Authorization")) - if !found { - // Credentials doesn't match, we return 401 and abort handlers chain. - c.Header("WWW-Authenticate", realm) - c.AbortWithStatus(http.StatusUnauthorized) - return +func BasicAuthForRealmWithValidator(validator UsernamePasswordValidator, realm string) HandlerFunc { + headerValidator := func(c *Context) (string, bool) { + username, password, ok := c.Request.BasicAuth() + if !ok { + return username, false } - // The user credentials was found, set user's id to key AuthUserKey in this context, the user's id can be read later using - // c.MustGet(gin.AuthUserKey). - c.Set(AuthUserKey, user) + ok = validator(username, password) + if ok { + return username, true + } + return "", false } + + return basicAuthForRealmWithValidator(headerValidator, realm) } // BasicAuth returns a Basic HTTP Authorization middleware. It takes as argument a map[string]string where @@ -70,6 +68,48 @@ func BasicAuth(accounts Accounts) HandlerFunc { return BasicAuthForRealm(accounts, "") } +// basicAuthForRealmWithValidator returns a Basic HTTP Authorization middleware. It takes as arguments a function and the realm. +// The function takes the context and returns the user if found and a boolean indicating whether or not authentication was successful. +// the second parameter is the name of the realm. If the realm is empty, "Authorization Required" will be used by default. +// (see http://tools.ietf.org/html/rfc2617#section-1.2) +func basicAuthForRealmWithValidator(validateUser authHeaderValidator, realm string) HandlerFunc { + if realm == "" { + realm = "Authorization Required" + } + realm = "Basic realm=" + strconv.Quote(realm) + + return func(c *Context) { + // Search user in the slice of allowed credentials + user, ok := validateUser(c) + + if !ok { + // Credentials doesn't match, we return 401 and abort handlers chain. + c.Header("WWW-Authenticate", realm) + c.AbortWithStatus(http.StatusUnauthorized) + return + } + // The user credentials was found, set user's id to key AuthUserKey in this context, the user's id can be read later using + // c.MustGet(gin.AuthUserKey). + c.Set(AuthUserKey, user) + } +} + +// BasicAuthForRealm returns a Basic HTTP Authorization middleware. It takes as arguments a map[string]string where +// the key is the user name and the value is the password, as well as the name of the Realm. +// If the realm is empty, "Authorization Required" will be used by default. +// (see http://tools.ietf.org/html/rfc2617#section-1.2) +func BasicAuthForRealm(accounts Accounts, realm string) HandlerFunc { + return basicAuthForRealmWithValidator(accountsValidator(accounts), realm) +} + +// accountsValidator returns a validator that searches for the right account using the given authorization header +func accountsValidator(accounts Accounts) authHeaderValidator { + pairs := processAccounts(accounts) + return func(c *Context) (string, bool) { + return pairs.searchCredential(c.requestHeader("Authorization")) + } +} + func processAccounts(accounts Accounts) authPairs { length := len(accounts) assert1(length > 0, "Empty list of authorized credentials") diff --git a/auth_test.go b/auth_test.go index 42b6f8fd..c7cdd4bc 100644 --- a/auth_test.go +++ b/auth_test.go @@ -77,6 +77,52 @@ func TestBasicAuthSearchCredential(t *testing.T) { assert.False(t, found) } +// test basic auth middleware with a custom validator (successful) +func TestBasicAuthWithValidatorSucceed(t *testing.T) { + middleware := BasicAuthForRealmWithValidator(func(username, password string) bool { + return username == "admin" && password == "password" + }, "") + + called := false + router := New() + router.Use(middleware) + router.GET("/login", func(c *Context) { + called = true + c.String(http.StatusOK, c.MustGet(AuthUserKey).(string)) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/login", nil) + req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password"))) + router.ServeHTTP(w, req) + + assert.True(t, called) + assert.Equal(t, http.StatusOK, w.Code) +} + +// test basic auth middleware with a custom validator (wrong password) +func TestBasicAuthWithValidatorFail(t *testing.T) { + middleware := BasicAuthForRealmWithValidator(func(username, password string) bool { + return username == "admin" && password == "password" + }, "") + + called := false + router := New() + router.Use(middleware) + router.GET("/login", func(c *Context) { + called = true + c.String(http.StatusOK, c.MustGet(AuthUserKey).(string)) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/login", nil) + req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:wrong_password"))) + router.ServeHTTP(w, req) + assert.False(t, called) + assert.Equal(t, http.StatusUnauthorized, w.Code) + assert.Equal(t, "Basic realm=\"Authorization Required\"", w.Header().Get("WWW-Authenticate")) +} + func TestBasicAuthAuthorizationHeader(t *testing.T) { assert.Equal(t, "Basic YWRtaW46cGFzc3dvcmQ=", authorizationHeader("admin", "password")) }