loginsrv

Unnamed repository; edit this file 'description' to name the repository.
git clone git@jamesshield.xyz:repos/loginsrv.git
Log | Files | Refs | README | LICENSE

commit 5707da95e2cdb1c3fbe2905e64236114e8ce06a8
parent 1f46a6546c6b81a03ecac4af660271faea54dee6
Author: Sebastian Mancke <sebastian.mancke@snabble.io>
Date:   Mon,  3 Jun 2019 11:54:21 +0200

Merge pull request #130 from g-w/user-claims-404

Return empty claims if user claims endpoint responds with not found
Diffstat:
Mlogin/user_claims_provider.go | 14++++++++++----
Mlogin/user_claims_provider_test.go | 30++++++++++++++++++++++++++++++
2 files changed, 40 insertions(+), 4 deletions(-)

diff --git a/login/user_claims_provider.go b/login/user_claims_provider.go @@ -46,6 +46,9 @@ func (provider *userClaimsProvider) Claims(userInfo model.UserInfo) (jwt.Claims, resp.Body.Close() }() + if resp.StatusCode == http.StatusNotFound { + return customClaims(userInfo.AsMap()), nil + } if resp.StatusCode != http.StatusOK { return nil, errors.Errorf("bad http response code %d", resp.StatusCode) } @@ -58,10 +61,7 @@ func (provider *userClaimsProvider) Claims(userInfo model.UserInfo) (jwt.Claims, return nil, err } - claims := customClaims(userInfo.AsMap()) - claims.merge(remoteClaims) - - return claims, nil + return mergeClaims(userInfo, remoteClaims), nil } func (provider *userClaimsProvider) buildURL(userInfo model.UserInfo) string { @@ -91,6 +91,12 @@ func (provider *userClaimsProvider) buildURL(userInfo model.UserInfo) string { return u.String() } +func mergeClaims(userInfo model.UserInfo, remoteClaims map[string]interface{}) customClaims { + claims := customClaims(userInfo.AsMap()) + claims.merge(remoteClaims) + return claims +} + func validateURL(s string) error { _, err := url.Parse(s) return errors.Wrap(err, "invalid claims provider url") diff --git a/login/user_claims_provider_test.go b/login/user_claims_provider_test.go @@ -78,6 +78,36 @@ func Test_userClaimsProvider_Claims(t *testing.T) { ) } +func Test_userClaimsProvider_Claims_NotFound(t *testing.T) { + mock := createMockServer( + mockResponse{ + url: endpointPath, + status: http.StatusNotFound, + body: ``, + }, + ) + defer mock.Close() + provider, err := newUserClaimsProvider(mock.URL+endpointPath, token, time.Minute) + require.NoError(t, err) + + claims, err := provider.Claims(model.UserInfo{ + Sub: "test@example.com", + Origin: "origin", + Domain: "example.com", + }) + + require.NoError(t, err) + + assert.Equal(t, + customClaims{ + "domain": "example.com", + "origin": "origin", + "sub": "test@example.com", + }, + claims, + ) +} + func Test_userClaimsProvider_Claims_EndpointNotReachable(t *testing.T) { provider, err := newUserClaimsProvider("http://not-exists.example.com", token, time.Millisecond) require.NoError(t, err)