func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error) { ok, err := s.ClientIdentityRepo.Authenticate(creds) if err != nil { log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err) return nil, oauth2.NewError(oauth2.ErrorServerError) } if !ok { return nil, oauth2.NewError(oauth2.ErrorInvalidClient) } signer, err := s.KeyManager.Signer() if err != nil { log.Errorf("Failed to generate ID token: %v", err) return nil, oauth2.NewError(oauth2.ErrorServerError) } now := time.Now() exp := now.Add(s.SessionManager.ValidityWindow) claims := oidc.NewClaims(s.IssuerURL.String(), creds.ID, creds.ID, now, exp) claims.Add("name", creds.ID) jwt, err := jose.NewSignedJWT(claims, signer) if err != nil { log.Errorf("Failed to generate ID token: %v", err) return nil, oauth2.NewError(oauth2.ErrorServerError) } log.Infof("Client token sent: clientID=%s", creds.ID) return jwt, nil }
// addClaimsFromScope adds claims that are based on the scopes that the client requested. // Currently, these include cross-client claims (aud, azp). func (s *Server) addClaimsFromScope(claims jose.Claims, scopes scope.Scopes, clientID string) error { crossClientIDs := scopes.CrossClientIDs() if len(crossClientIDs) > 0 { var aud []string for _, id := range crossClientIDs { if clientID == id { aud = append(aud, id) continue } allowed, err := s.CrossClientAuthAllowed(clientID, id) if err != nil { log.Errorf("Failed to check cross client auth. reqClientID %v; authClient:ID %v; err: %v", clientID, id, err) return oauth2.NewError(oauth2.ErrorServerError) } if !allowed { err := oauth2.NewError(oauth2.ErrorInvalidRequest) err.Description = fmt.Sprintf( "%q is not authorized to perform cross-client requests for %q", clientID, id) return err } aud = append(aud, id) } if len(aud) == 1 { claims.Add("aud", aud[0]) } else { claims.Add("aud", aud) } claims.Add("azp", clientID) } return nil }
func (s *grpcServer) Token(userID, clientID string, iat, exp time.Time) (*jose.JWT, string, error) { signer, err := s.server.KeyManager.Signer() if err != nil { log.Errorf("grpc.go: Failed to generate ID token: %v", err) return nil, "", oauth2.NewError(oauth2.ErrorServerError) } user, err := s.server.UserRepo.Get(nil, userID) if err != nil { log.Errorf("grpc.go: Failed to fetch user %q from repo: %v: ", userID, err) return nil, "", oauth2.NewError(oauth2.ErrorServerError) } claims := oidc.NewClaims(s.server.IssuerURL.String(), userID, clientID, iat, exp) user.AddToClaims(claims) if user.Admin { claims.Add(OtsimoUserTypeClaim, "adm") } jwt, err := jose.NewSignedJWT(claims, signer) if err != nil { log.Errorf("grpc.go: Failed to generate ID token: %v", err) return nil, "", oauth2.NewError(oauth2.ErrorServerError) } refreshToken, err := s.server.RefreshTokenRepo.Create(user.ID, clientID) if err != nil { log.Errorf("grpc.go: Failed to generate refresh token: %v", err) return nil, "", oauth2.NewError(oauth2.ErrorServerError) } return jwt, refreshToken, nil }
func TestRedirectAuthError(t *testing.T) { wantCode := http.StatusFound tests := []struct { err error state string redirectURL url.URL wantLoc string }{ { err: errors.New("foobar"), state: "bazinga", redirectURL: url.URL{Scheme: "http", Host: "server.example.com"}, wantLoc: "http://server.example.com?error=server_error&state=bazinga", }, { err: oauth2.NewError(oauth2.ErrorInvalidRequest), state: "foo", redirectURL: url.URL{Scheme: "http", Host: "server.example.com"}, wantLoc: "http://server.example.com?error=invalid_request&state=foo", }, { err: oauth2.NewError(oauth2.ErrorUnsupportedResponseType), state: "bar", redirectURL: url.URL{Scheme: "http", Host: "server.example.com"}, wantLoc: "http://server.example.com?error=unsupported_response_type&state=bar", }, } for i, tt := range tests { w := httptest.NewRecorder() redirectAuthError(w, tt.err, tt.state, tt.redirectURL) if wantCode != w.Code { t.Errorf("case %d: incorrect HTTP status: want=%d got=%d", i, wantCode, w.Code) } wantHeader := http.Header{"Location": []string{tt.wantLoc}} gotHeader := w.Header() if !reflect.DeepEqual(wantHeader, gotHeader) { t.Errorf("case %d: incorrect HTTP headers: want=%#v got=%#v", i, wantHeader, gotHeader) } gotBody := w.Body.String() if gotBody != "" { t.Errorf("case %d: incorrect empty HTTP body, got=%q", i, gotBody) } } }
func writeAuthError(w http.ResponseWriter, err error, state string) { oerr, ok := err.(*oauth2.Error) if !ok { oerr = oauth2.NewError(oauth2.ErrorServerError) } oerr.State = state writeResponseWithBody(w, http.StatusBadRequest, oerr) }
func TestWriteAuthError(t *testing.T) { wantCode := http.StatusBadRequest wantHeader := http.Header{"Content-Type": []string{"application/json"}} tests := []struct { err error state string wantBody string }{ { err: errors.New("foobar"), state: "bazinga", wantBody: `{"error":"server_error","state":"bazinga"}`, }, { err: oauth2.NewError(oauth2.ErrorInvalidRequest), state: "foo", wantBody: `{"error":"invalid_request","state":"foo"}`, }, { err: oauth2.NewError(oauth2.ErrorUnsupportedResponseType), state: "bar", wantBody: `{"error":"unsupported_response_type","state":"bar"}`, }, } for i, tt := range tests { w := httptest.NewRecorder() writeAuthError(w, tt.err, tt.state) if wantCode != w.Code { t.Errorf("case %d: incorrect HTTP status: want=%d got=%d", i, wantCode, w.Code) } gotHeader := w.Header() if !reflect.DeepEqual(wantHeader, gotHeader) { t.Errorf("case %d: incorrect HTTP headers: want=%#v got=%#v", i, wantHeader, gotHeader) } gotBody := w.Body.String() if tt.wantBody != gotBody { t.Errorf("case %d: incorrect HTTP body: want=%q got=%q", i, tt.wantBody, gotBody) } } }
func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose.JWT, error) { ok, err := s.ClientIdentityRepo.Authenticate(creds) if err != nil { log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err) return nil, oauth2.NewError(oauth2.ErrorServerError) } if !ok { log.Errorf("Failed to Authenticate client %s", creds.ID) return nil, oauth2.NewError(oauth2.ErrorInvalidClient) } userID, err := s.RefreshTokenRepo.Verify(creds.ID, token) switch err { case nil: break case refresh.ErrorInvalidToken: return nil, oauth2.NewError(oauth2.ErrorInvalidRequest) case refresh.ErrorInvalidClientID: return nil, oauth2.NewError(oauth2.ErrorInvalidClient) default: return nil, oauth2.NewError(oauth2.ErrorServerError) } user, err := s.UserRepo.Get(nil, userID) if err != nil { // The error can be user.ErrorNotFound, but we are not deleting // user at this moment, so this shouldn't happen. log.Errorf("Failed to fetch user %q from repo: %v: ", userID, err) return nil, oauth2.NewError(oauth2.ErrorServerError) } signer, err := s.KeyManager.Signer() if err != nil { log.Errorf("Failed to refresh ID token: %v", err) return nil, oauth2.NewError(oauth2.ErrorServerError) } now := time.Now() expireAt := now.Add(session.DefaultSessionValidityWindow) claims := oidc.NewClaims(s.IssuerURL.String(), user.ID, creds.ID, now, expireAt) user.AddToClaims(claims) jwt, err := jose.NewSignedJWT(claims, signer) if err != nil { log.Errorf("Failed to generate ID token: %v", err) return nil, oauth2.NewError(oauth2.ErrorServerError) } log.Infof("New token sent: clientID=%s", creds.ID) return jwt, nil }
func redirectAuthError(w http.ResponseWriter, err error, state string, redirectURL url.URL) { oerr, ok := err.(*oauth2.Error) if !ok { oerr = oauth2.NewError(oauth2.ErrorServerError) } q := redirectURL.Query() q.Set("error", oerr.Type) q.Set("state", state) redirectURL.RawQuery = q.Encode() w.Header().Set("Location", redirectURL.String()) w.WriteHeader(http.StatusFound) }
func writeTokenError(w http.ResponseWriter, err error, state string) { oerr, ok := err.(*oauth2.Error) if !ok { oerr = oauth2.NewError(oauth2.ErrorServerError) } oerr.State = state var status int switch oerr.Type { case oauth2.ErrorInvalidClient: status = http.StatusUnauthorized w.Header().Set("WWW-Authenticate", "Basic") default: status = http.StatusBadRequest } writeResponseWithBody(w, status, oerr) }
func (c *uaaOAuth2Connector) Identity(cli chttp.Client) (oidc.Identity, error) { uaaUserInfoURL := *c.uaaBaseURL uaaUserInfoURL.Path = path.Join(uaaUserInfoURL.Path, "/userinfo") req, err := http.NewRequest("GET", uaaUserInfoURL.String(), nil) if err != nil { return oidc.Identity{}, err } resp, err := cli.Do(req) if err != nil { return oidc.Identity{}, fmt.Errorf("get: %v", err) } defer resp.Body.Close() switch { case resp.StatusCode >= 400 && resp.StatusCode < 600: // attempt to decode error from UAA var authErr uaaError if err := json.NewDecoder(resp.Body).Decode(&authErr); err != nil { return oidc.Identity{}, oauth2.NewError(oauth2.ErrorAccessDenied) } return oidc.Identity{}, authErr case resp.StatusCode == http.StatusOK: default: return oidc.Identity{}, fmt.Errorf("unexpected status from providor %s", resp.Status) } var user struct { UserID string `json:"user_id"` Email string `json:"email"` Name string `json:"name"` UserName string `json:"user_name"` } if err := json.NewDecoder(resp.Body).Decode(&user); err != nil { return oidc.Identity{}, fmt.Errorf("getting user info: %v", err) } name := user.Name if name == "" { name = user.UserName } return oidc.Identity{ ID: user.UserID, Name: name, Email: user.Email, }, nil }
func (c *githubOAuth2Connector) Identity(cli chttp.Client) (oidc.Identity, error) { req, err := http.NewRequest("GET", githubAPIUserURL, nil) if err != nil { return oidc.Identity{}, err } resp, err := cli.Do(req) if err != nil { return oidc.Identity{}, fmt.Errorf("get: %v", err) } defer resp.Body.Close() switch { case resp.StatusCode >= 400 && resp.StatusCode < 600: // attempt to decode error from github var authErr githubError if err := json.NewDecoder(resp.Body).Decode(&authErr); err != nil { return oidc.Identity{}, oauth2.NewError(oauth2.ErrorAccessDenied) } return oidc.Identity{}, authErr case resp.StatusCode == http.StatusOK: default: return oidc.Identity{}, fmt.Errorf("unexpected status from providor %s", resp.Status) } var user struct { Login string `json:"login"` ID int64 `json:"id"` Email string `json:"email"` Name string `json:"name"` } if err := json.NewDecoder(resp.Body).Decode(&user); err != nil { return oidc.Identity{}, fmt.Errorf("getting user info: %v", err) } name := user.Name if name == "" { name = user.Login } return oidc.Identity{ ID: strconv.FormatInt(user.ID, 10), Name: name, Email: user.Email, }, nil }
func getAndDecode(cli chttp.Client, url string, v interface{}) error { req, err := http.NewRequest("GET", url, nil) if err != nil { return err } resp, err := cli.Do(req) if err != nil { return fmt.Errorf("get: %v", err) } defer resp.Body.Close() switch { case resp.StatusCode >= 400 && resp.StatusCode < 500: return oauth2.NewError(oauth2.ErrorAccessDenied) case resp.StatusCode == http.StatusOK: default: return fmt.Errorf("unexpected status from providor %s", resp.Status) } if err := json.NewDecoder(resp.Body).Decode(v); err != nil { return fmt.Errorf("decode body: %v", err) } return nil }
func (c *facebookOAuth2Connector) Identity(cli chttp.Client) (oidc.Identity, error) { var user struct { ID string `json:"id"` Email string `json:"email"` Name string `json:"name"` } req, err := http.NewRequest("GET", facebookGraphAPIURL, nil) if err != nil { return oidc.Identity{}, err } resp, err := cli.Do(req) if err != nil { return oidc.Identity{}, fmt.Errorf("get: %v", err) } defer resp.Body.Close() switch { case resp.StatusCode >= 400 && resp.StatusCode < 600: var authErr facebookErr if err := json.NewDecoder(resp.Body).Decode(&authErr); err != nil { return oidc.Identity{}, oauth2.NewError(oauth2.ErrorAccessDenied) } return oidc.Identity{}, authErr case resp.StatusCode == http.StatusOK: default: return oidc.Identity{}, fmt.Errorf("unexpected status from providor %s", resp.Status) } if err := json.NewDecoder(resp.Body).Decode(&user); err != nil { return oidc.Identity{}, fmt.Errorf("decode body: %v", err) } return oidc.Identity{ ID: user.ID, Name: user.Name, Email: user.Email, }, nil }
func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, error) { ok, err := s.ClientManager.Authenticate(creds) if err != nil { log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err) return nil, "", oauth2.NewError(oauth2.ErrorServerError) } if !ok { log.Errorf("Failed to Authenticate client %s", creds.ID) return nil, "", oauth2.NewError(oauth2.ErrorInvalidClient) } sessionID, err := s.SessionManager.ExchangeKey(sessionKey) if err != nil { return nil, "", oauth2.NewError(oauth2.ErrorInvalidGrant) } ses, err := s.SessionManager.Kill(sessionID) if err != nil { return nil, "", oauth2.NewError(oauth2.ErrorInvalidRequest) } if ses.ClientID != creds.ID { return nil, "", oauth2.NewError(oauth2.ErrorInvalidGrant) } signer, err := s.KeyManager.Signer() if err != nil { log.Errorf("Failed to generate ID token: %v", err) return nil, "", oauth2.NewError(oauth2.ErrorServerError) } user, err := s.UserRepo.Get(nil, ses.UserID) if err != nil { log.Errorf("Failed to fetch user %q from repo: %v: ", ses.UserID, err) return nil, "", oauth2.NewError(oauth2.ErrorServerError) } claims := ses.Claims(s.IssuerURL.String()) user.AddToClaims(claims) s.addClaimsFromScope(claims, ses.Scope, ses.ClientID) jwt, err := jose.NewSignedJWT(claims, signer) if err != nil { log.Errorf("Failed to generate ID token: %v", err) return nil, "", oauth2.NewError(oauth2.ErrorServerError) } // Generate refresh token when 'scope' contains 'offline_access'. var refreshToken string for _, scope := range ses.Scope { if scope == "offline_access" { log.Infof("Session %s requests offline access, will generate refresh token", sessionID) refreshToken, err = s.RefreshTokenRepo.Create(ses.UserID, creds.ID, ses.Scope) switch err { case nil: break default: log.Errorf("Failed to generate refresh token: %v", err) return nil, "", oauth2.NewError(oauth2.ErrorServerError) } break } } log.Infof("Session %s token sent: clientID=%s", sessionID, creds.ID) return jwt, refreshToken, nil }
func TestServerRefreshToken(t *testing.T) { issuerURL := url.URL{Scheme: "http", Host: "server.example.com"} credXXX := oidc.ClientCredentials{ ID: "XXX", Secret: "secret", } credYYY := oidc.ClientCredentials{ ID: "YYY", Secret: "secret", } signerFixture := &StaticSigner{sig: []byte("beer"), err: nil} tests := []struct { token string clientID string // The client that associates with the token. creds oidc.ClientCredentials signer jose.Signer err error }{ // Everything is good. { fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), "XXX", credXXX, signerFixture, nil, }, // Invalid refresh token(malformatted). { "invalid-token", "XXX", credXXX, signerFixture, oauth2.NewError(oauth2.ErrorInvalidRequest), }, // Invalid refresh token(invalid payload content). { fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))), "XXX", credXXX, signerFixture, oauth2.NewError(oauth2.ErrorInvalidRequest), }, // Invalid refresh token(invalid ID content). { fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), "XXX", credXXX, signerFixture, oauth2.NewError(oauth2.ErrorInvalidRequest), }, // Invalid client(client is not associated with the token). { fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), "XXX", credYYY, signerFixture, oauth2.NewError(oauth2.ErrorInvalidClient), }, // Invalid client(no client ID). { fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), "XXX", oidc.ClientCredentials{ID: "", Secret: "aaa"}, signerFixture, oauth2.NewError(oauth2.ErrorInvalidClient), }, // Invalid client(no such client). { fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), "XXX", oidc.ClientCredentials{ID: "AAA", Secret: "aaa"}, signerFixture, oauth2.NewError(oauth2.ErrorInvalidClient), }, // Invalid client(no secrets). { fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), "XXX", oidc.ClientCredentials{ID: "XXX"}, signerFixture, oauth2.NewError(oauth2.ErrorInvalidClient), }, // Invalid client(invalid secret). { fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), "XXX", oidc.ClientCredentials{ID: "XXX", Secret: "bad-secret"}, signerFixture, oauth2.NewError(oauth2.ErrorInvalidClient), }, // Signing operation fails. { fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), "XXX", credXXX, &StaticSigner{sig: nil, err: errors.New("fail")}, oauth2.NewError(oauth2.ErrorServerError), }, } for i, tt := range tests { km := &StaticKeyManager{ signer: tt.signer, } ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ oidc.ClientIdentity{Credentials: credXXX}, oidc.ClientIdentity{Credentials: credYYY}, }) userRepo, err := makeNewUserRepo() if err != nil { t.Fatalf("Unexpected error: %v", err) } refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo() if err != nil { t.Fatalf("Unexpected error: %v", err) } srv := &Server{ IssuerURL: issuerURL, KeyManager: km, ClientIdentityRepo: ciRepo, UserRepo: userRepo, RefreshTokenRepo: refreshTokenRepo, } if _, err := refreshTokenRepo.Create("testid-1", tt.clientID); err != nil { t.Fatalf("Unexpected error: %v", err) } jwt, err := srv.RefreshToken(tt.creds, tt.token) if !reflect.DeepEqual(err, tt.err) { t.Errorf("Case %d: expect: %v, got: %v", i, tt.err, err) } if jwt != nil { if string(jwt.Signature) != "beer" { t.Errorf("Case %d: expect signature: beer, got signature: %v", i, jwt.Signature) } claims, err := jwt.Claims() if err != nil { t.Errorf("Case %d: unexpected error: %v", i, err) } if claims["iss"] != issuerURL.String() || claims["sub"] != "testid-1" || claims["aud"] != "XXX" { t.Errorf("Case %d: invalid claims: %v", i, claims) } } } // Test that we should return error when user cannot be found after // verifying the token. km := &StaticKeyManager{ signer: signerFixture, } ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ oidc.ClientIdentity{Credentials: credXXX}, oidc.ClientIdentity{Credentials: credYYY}, }) userRepo, err := makeNewUserRepo() if err != nil { t.Fatalf("Unexpected error: %v", err) } // Create a user that will be removed later. if err := userRepo.Create(nil, user.User{ ID: "testid-2", Email: "*****@*****.**", }); err != nil { t.Fatalf("Unexpected error: %v", err) } refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo() if err != nil { t.Fatalf("Unexpected error: %v", err) } srv := &Server{ IssuerURL: issuerURL, KeyManager: km, ClientIdentityRepo: ciRepo, UserRepo: userRepo, RefreshTokenRepo: refreshTokenRepo, } if _, err := refreshTokenRepo.Create("testid-2", credXXX.ID); err != nil { t.Fatalf("Unexpected error: %v", err) } // Recreate the user repo to remove the user we created. userRepo, err = makeNewUserRepo() if err != nil { t.Fatalf("Unexpected error: %v", err) } srv.UserRepo = userRepo _, err = srv.RefreshToken(credXXX, fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1")))) if !reflect.DeepEqual(err, oauth2.NewError(oauth2.ErrorServerError)) { t.Errorf("Expect: %v, got: %v", oauth2.NewError(oauth2.ErrorServerError), err) } }
func TestServerTokenFail(t *testing.T) { issuerURL := url.URL{Scheme: "http", Host: "server.example.com"} keyFixture := "goodkey" ccFixture := oidc.ClientCredentials{ ID: "XXX", Secret: "secrete", } signerFixture := &StaticSigner{sig: []byte("beer"), err: nil} tests := []struct { signer jose.Signer argCC oidc.ClientCredentials argKey string err error scope []string refreshToken string }{ // control test case to make sure fixtures check out { signer: signerFixture, argCC: ccFixture, argKey: keyFixture, scope: []string{"openid", "offline_access"}, refreshToken: fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), }, // no 'offline_access' in 'scope', should get empty refresh token { signer: signerFixture, argCC: ccFixture, argKey: keyFixture, scope: []string{"openid"}, }, // unrecognized key { signer: signerFixture, argCC: ccFixture, argKey: "foo", err: oauth2.NewError(oauth2.ErrorInvalidGrant), scope: []string{"openid", "offline_access"}, }, // unrecognized client { signer: signerFixture, argCC: oidc.ClientCredentials{ID: "YYY"}, argKey: keyFixture, err: oauth2.NewError(oauth2.ErrorInvalidClient), scope: []string{"openid", "offline_access"}, }, // signing operation fails { signer: &StaticSigner{sig: nil, err: errors.New("fail")}, argCC: ccFixture, argKey: keyFixture, err: oauth2.NewError(oauth2.ErrorServerError), scope: []string{"openid", "offline_access"}, }, } for i, tt := range tests { sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sm.GenerateCode = func() (string, error) { return keyFixture, nil } sessionID, err := sm.NewSession("connector_id", ccFixture.ID, "bogus", url.URL{}, "", false, tt.scope) if err != nil { t.Fatalf("Unexpected error: %v", err) } _, err = sm.AttachRemoteIdentity(sessionID, oidc.Identity{}) if err != nil { t.Errorf("case %d: unexpected error: %v", i, err) continue } km := &StaticKeyManager{ signer: tt.signer, } ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ oidc.ClientIdentity{Credentials: ccFixture}, }) _, err = sm.AttachUser(sessionID, "testid-1") if err != nil { t.Fatalf("case %d: unexpected error: %v", i, err) } userRepo, err := makeNewUserRepo() if err != nil { t.Fatalf("Unexpected error: %v", err) } refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo() if err != nil { t.Fatalf("Unexpected error: %v", err) } srv := &Server{ IssuerURL: issuerURL, KeyManager: km, SessionManager: sm, ClientIdentityRepo: ciRepo, UserRepo: userRepo, RefreshTokenRepo: refreshTokenRepo, } _, err = sm.NewSessionKey(sessionID) if err != nil { t.Fatalf("Unexpected error: %v", err) } jwt, token, err := srv.CodeToken(tt.argCC, tt.argKey) if token != tt.refreshToken { fmt.Printf("case %d: expect refresh token %q, got %q\n", i, tt.refreshToken, token) t.Fatalf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token) panic("") } if !reflect.DeepEqual(err, tt.err) { t.Errorf("case %d: expect %v, got %v", i, tt.err, err) } if err == nil && jwt == nil { t.Errorf("case %d: got nil JWT", i) } if err != nil && jwt != nil { t.Errorf("case %d: got non-nil JWT %v", i, jwt) } } }
func TestServerRefreshToken(t *testing.T) { clientB := client.Client{ Credentials: oidc.ClientCredentials{ ID: "example2.com", Secret: clientTestSecret, }, Metadata: oidc.ClientMetadata{ RedirectURIs: []url.URL{ url.URL{Scheme: "https", Host: "example2.com", Path: "one/two/three"}, }, }, } signerFixture := &StaticSigner{sig: []byte("beer"), err: nil} // NOTE(ericchiang): These tests assume that the database ID of the first // refresh token will be "1". tests := []struct { token string expectedRefreshToken string clientID string // The client that associates with the token. creds oidc.ClientCredentials signer jose.Signer createScopes []string refreshScopes []string expectedAud []string err error }{ // Everything is good. { token: getRefreshTokenEncoded("1", "refresh-1"), expectedRefreshToken: getRefreshTokenEncoded("1", "refresh-2"), clientID: testClientID, creds: testClientCredentials, signer: signerFixture, createScopes: []string{"openid", "profile"}, refreshScopes: []string{"openid", "profile"}, }, // Asking for a scope not originally granted to you. { token: getRefreshTokenEncoded("1", "refresh-1"), clientID: testClientID, creds: testClientCredentials, signer: signerFixture, createScopes: []string{"openid", "profile"}, refreshScopes: []string{"openid", "profile", "extra_scope"}, err: oauth2.NewError(oauth2.ErrorInvalidRequest), }, // Invalid refresh token(malformatted). { token: "invalid-token", clientID: testClientID, creds: testClientCredentials, signer: signerFixture, createScopes: []string{"openid", "profile"}, refreshScopes: []string{"openid", "profile"}, err: oauth2.NewError(oauth2.ErrorInvalidRequest), }, // Invalid refresh token(invalid payload content). { token: getRefreshTokenEncoded("1", "refresh-2"), clientID: testClientID, creds: testClientCredentials, signer: signerFixture, createScopes: []string{"openid", "profile"}, refreshScopes: []string{"openid", "profile"}, err: oauth2.NewError(oauth2.ErrorInvalidRequest), }, // Invalid refresh token(invalid ID content). { token: getRefreshTokenEncoded("0", "refresh-1"), clientID: testClientID, creds: testClientCredentials, signer: signerFixture, createScopes: []string{"openid", "profile"}, refreshScopes: []string{"openid", "profile"}, err: oauth2.NewError(oauth2.ErrorInvalidRequest), }, // Invalid client(client is not associated with the token). { token: getRefreshTokenEncoded("1", "refresh-1"), clientID: testClientID, creds: clientB.Credentials, signer: signerFixture, createScopes: []string{"openid", "profile"}, refreshScopes: []string{"openid", "profile"}, err: oauth2.NewError(oauth2.ErrorInvalidClient), }, // Invalid client(no client ID). { token: getRefreshTokenEncoded("1", "refresh-1"), clientID: testClientID, creds: oidc.ClientCredentials{ID: "", Secret: "aaa"}, signer: signerFixture, createScopes: []string{"openid", "profile"}, refreshScopes: []string{"openid", "profile"}, err: oauth2.NewError(oauth2.ErrorInvalidClient), }, // Invalid client(no such client). { token: getRefreshTokenEncoded("1", "refresh-1"), clientID: testClientID, creds: oidc.ClientCredentials{ID: "AAA", Secret: "aaa"}, signer: signerFixture, createScopes: []string{"openid", "profile"}, refreshScopes: []string{"openid", "profile"}, err: oauth2.NewError(oauth2.ErrorInvalidClient), }, // Invalid client(no secrets). { token: getRefreshTokenEncoded("1", "refresh-1"), clientID: testClientID, creds: oidc.ClientCredentials{ID: testClientID}, signer: signerFixture, createScopes: []string{"openid", "profile"}, refreshScopes: []string{"openid", "profile"}, err: oauth2.NewError(oauth2.ErrorInvalidClient), }, // Invalid client(invalid secret). { token: getRefreshTokenEncoded("1", "refresh-1"), clientID: testClientID, creds: oidc.ClientCredentials{ID: "bad-id", Secret: "bad-secret"}, signer: signerFixture, createScopes: []string{"openid", "profile"}, refreshScopes: []string{"openid", "profile"}, err: oauth2.NewError(oauth2.ErrorInvalidClient), }, // Signing operation fails. { token: getRefreshTokenEncoded("1", "refresh-1"), clientID: testClientID, creds: testClientCredentials, signer: &StaticSigner{sig: nil, err: errors.New("fail")}, createScopes: []string{"openid", "profile"}, refreshScopes: []string{"openid", "profile"}, err: oauth2.NewError(oauth2.ErrorServerError), }, // Valid Cross-Client { token: getRefreshTokenEncoded("1", "refresh-1"), expectedRefreshToken: getRefreshTokenEncoded("1", "refresh-2"), clientID: "client_a", creds: oidc.ClientCredentials{ ID: "client_a", Secret: base64.URLEncoding.EncodeToString( []byte("client_a_secret")), }, signer: signerFixture, createScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b"}, refreshScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b"}, expectedAud: []string{"client_b"}, }, // Valid Cross-Client - but this time we leave out the scopes in the // refresh request, which should result in the original stored scopes // being used. { token: getRefreshTokenEncoded("1", "refresh-1"), expectedRefreshToken: getRefreshTokenEncoded("1", "refresh-2"), clientID: "client_a", creds: oidc.ClientCredentials{ ID: "client_a", Secret: base64.URLEncoding.EncodeToString( []byte("client_a_secret")), }, signer: signerFixture, createScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b"}, refreshScopes: []string{}, expectedAud: []string{"client_b"}, }, // Valid Cross-Client - asking for fewer scopes than originally used // when creating the refresh token, which is ok. { token: getRefreshTokenEncoded("1", "refresh-1"), expectedRefreshToken: getRefreshTokenEncoded("1", "refresh-2"), clientID: "client_a", creds: oidc.ClientCredentials{ ID: "client_a", Secret: base64.URLEncoding.EncodeToString( []byte("client_a_secret")), }, signer: signerFixture, createScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b", scope.ScopeGoogleCrossClient + "client_c"}, refreshScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b"}, expectedAud: []string{"client_b"}, }, // Valid Cross-Client - asking for multiple clients in the audience. { token: getRefreshTokenEncoded("1", "refresh-1"), expectedRefreshToken: getRefreshTokenEncoded("1", "refresh-2"), clientID: "client_a", creds: oidc.ClientCredentials{ ID: "client_a", Secret: base64.URLEncoding.EncodeToString( []byte("client_a_secret")), }, signer: signerFixture, createScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b", scope.ScopeGoogleCrossClient + "client_c"}, refreshScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b", scope.ScopeGoogleCrossClient + "client_c"}, expectedAud: []string{"client_b", "client_c"}, }, // Invalid Cross-Client - didn't orignally request cross-client when // refresh token was created. { token: getRefreshTokenEncoded("1", "refresh-1"), clientID: "client_a", creds: oidc.ClientCredentials{ ID: "client_a", Secret: base64.URLEncoding.EncodeToString( []byte("client_a_secret")), }, signer: signerFixture, createScopes: []string{"openid", "profile"}, refreshScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b"}, err: oauth2.NewError(oauth2.ErrorInvalidRequest), }, } for i, tt := range tests { km := &StaticKeyManager{ signer: tt.signer, } f, err := makeCrossClientTestFixtures() if err != nil { t.Fatalf("error making test fixtures: %v", err) } f.srv.RefreshTokenRepo = refreshtest.NewTestRefreshTokenRepo() f.srv.KeyManager = km _, err = f.clientRepo.New(nil, clientB) if err != nil { t.Errorf("case %d: error creating other client: %v", i, err) } if _, err := f.srv.RefreshTokenRepo.Create(testUserID1, tt.clientID, "", tt.createScopes); err != nil { t.Fatalf("Unexpected error: %v", err) } jwt, refreshToken, expiresIn, err := f.srv.RefreshToken(tt.creds, tt.refreshScopes, tt.token) if !reflect.DeepEqual(err, tt.err) { t.Errorf("Case %d: expect: %v, got: %v", i, tt.err, err) } if jwt != nil { if string(jwt.Signature) != "beer" { t.Errorf("Case %d: expect signature: beer, got signature: %v", i, jwt.Signature) } claims, err := jwt.Claims() if err != nil { t.Errorf("Case %d: unexpected error: %v", i, err) } var expectedAud interface{} if tt.expectedAud == nil { expectedAud = testClientID } else if len(tt.expectedAud) == 1 { expectedAud = tt.expectedAud[0] } else { expectedAud = tt.expectedAud } if claims["iss"] != testIssuerURL.String() { t.Errorf("Case %d: want=%v, got=%v", i, testIssuerURL.String(), claims["iss"]) } if claims["sub"] != testUserID1 { t.Errorf("Case %d: want=%v, got=%v", i, testUserID1, claims["sub"]) } if diff := pretty.Compare(claims["aud"], expectedAud); diff != "" { t.Errorf("Case %d: want=%v, got=%v", i, expectedAud, claims["aud"]) } } if diff := pretty.Compare(refreshToken, tt.expectedRefreshToken); diff != "" { t.Errorf("Case %d: want=%v, got=%v", i, tt.expectedRefreshToken, refreshToken) } if err == nil && expiresIn.IsZero() { t.Errorf("case %d: got zero expiration time %v", i, expiresIn) } } }
func TestServerTokenFail(t *testing.T) { keyFixture := "goodkey" signerFixture := &StaticSigner{sig: []byte("beer"), err: nil} tests := []struct { signer jose.Signer argCC oidc.ClientCredentials argKey string err error scope []string refreshToken string }{ // control test case to make sure fixtures check out { // NOTE(ericchiang): This test assumes that the database ID of the first // refresh token will be "1". signer: signerFixture, argCC: testClientCredentials, argKey: keyFixture, scope: []string{"openid", "offline_access"}, refreshToken: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), }, // no 'offline_access' in 'scope', should get empty refresh token { signer: signerFixture, argCC: testClientCredentials, argKey: keyFixture, scope: []string{"openid"}, }, // unrecognized key { signer: signerFixture, argCC: testClientCredentials, argKey: "foo", err: oauth2.NewError(oauth2.ErrorInvalidGrant), scope: []string{"openid", "offline_access"}, }, // unrecognized client { signer: signerFixture, argCC: oidc.ClientCredentials{ID: "YYY"}, argKey: keyFixture, err: oauth2.NewError(oauth2.ErrorInvalidClient), scope: []string{"openid", "offline_access"}, }, // signing operation fails { signer: &StaticSigner{sig: nil, err: errors.New("fail")}, argCC: testClientCredentials, argKey: keyFixture, err: oauth2.NewError(oauth2.ErrorServerError), scope: []string{"openid", "offline_access"}, }, } for i, tt := range tests { f, err := makeTestFixtures() if err != nil { t.Fatalf("error making test fixtures: %v", err) } sm := f.sessionManager sm.GenerateCode = func() (string, error) { return keyFixture, nil } f.srv.RefreshTokenRepo = refreshtest.NewTestRefreshTokenRepo() f.srv.KeyManager = &StaticKeyManager{ signer: tt.signer, } sessionID, err := sm.NewSession(testConnectorID1, testClientID, "bogus", url.URL{}, "", false, tt.scope) if err != nil { t.Fatalf("Unexpected error: %v", err) } _, err = sm.AttachRemoteIdentity(sessionID, oidc.Identity{}) if err != nil { t.Errorf("case %d: unexpected error: %v", i, err) continue } _, err = sm.AttachUser(sessionID, testUserID1) if err != nil { t.Fatalf("case %d: unexpected error: %v", i, err) } _, err = sm.NewSessionKey(sessionID) if err != nil { t.Fatalf("Unexpected error: %v", err) } jwt, token, expiresAt, err := f.srv.CodeToken(tt.argCC, tt.argKey) if token != tt.refreshToken { fmt.Printf("case %d: expect refresh token %q, got %q\n", i, tt.refreshToken, token) t.Fatalf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token) panic("") } if !reflect.DeepEqual(err, tt.err) { t.Errorf("case %d: expect %v, got %v", i, tt.err, err) } if err == nil && jwt == nil { t.Errorf("case %d: got nil JWT", i) } if err != nil && jwt != nil { t.Errorf("case %d: got non-nil JWT %v", i, jwt) } if err == nil && expiresAt.IsZero() { t.Errorf("case %d: got zero expiration time %v", i, expiresAt) } } }
func validateScopes(srv OIDCServer, clientID string, scopes []string) error { foundOpenIDScope := false for i, curScope := range scopes { if i > 0 && curScope == scopes[i-1] { err := oauth2.NewError(oauth2.ErrorInvalidRequest) err.Description = fmt.Sprintf( "Duplicate scopes are not allowed: %q", curScope) return err } switch { case strings.HasPrefix(curScope, scope.ScopeGoogleCrossClient): otherClient := curScope[len(scope.ScopeGoogleCrossClient):] var allowed bool var err error if otherClient == clientID { allowed = true } else { allowed, err = srv.CrossClientAuthAllowed(clientID, otherClient) if err != nil { return err } } if !allowed { err := oauth2.NewError(oauth2.ErrorInvalidRequest) err.Description = fmt.Sprintf( "%q is not authorized to perform cross-client requests for %q", clientID, otherClient) return err } case curScope == "openid": foundOpenIDScope = true case curScope == "profile": case curScope == "email": case curScope == scope.ScopeGroups: case curScope == "offline_access": // According to the spec, for offline_access scope, the client must // use a response_type value that would result in an Authorization // Code. Currently oauth2.ResponseTypeCode is the only supported // response type, and it's been checked above, so we don't need to // check it again here. // // TODO(yifan): Verify that 'consent' should be in 'prompt'. default: // Reject all other scopes. err := oauth2.NewError(oauth2.ErrorInvalidRequest) err.Description = fmt.Sprintf("%q is not a recognized scope", curScope) return err } } if !foundOpenIDScope { log.Errorf("Invalid auth request: missing 'openid' in 'scope'") err := oauth2.NewError(oauth2.ErrorInvalidRequest) err.Description = "Invalid auth request: missing 'openid' in 'scope'" return err } return nil }
func handleTokenFunc(srv OIDCServer) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { w.Header().Set("Allow", "POST") phttp.WriteError(w, http.StatusMethodNotAllowed, fmt.Sprintf("POST only acceptable method")) return } err := r.ParseForm() if err != nil { log.Errorf("error parsing request: %v", err) writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), "") return } state := r.PostForm.Get("state") user, password, ok := r.BasicAuth() if !ok { log.Errorf("error parsing basic auth") writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidClient), state) return } creds := oidc.ClientCredentials{ID: user, Secret: password} var jwt *jose.JWT var refreshToken string grantType := r.PostForm.Get("grant_type") switch grantType { case oauth2.GrantTypeAuthCode: code := r.PostForm.Get("code") if code == "" { log.Errorf("missing code param") writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state) return } jwt, refreshToken, err = srv.CodeToken(creds, code) if err != nil { log.Errorf("couldn't exchange code for token: %v", err) writeTokenError(w, err, state) return } case oauth2.GrantTypeClientCreds: jwt, err = srv.ClientCredsToken(creds) if err != nil { log.Errorf("couldn't creds for token: %v", err) writeTokenError(w, err, state) return } case oauth2.GrantTypeRefreshToken: token := r.PostForm.Get("refresh_token") if token == "" { writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state) return } jwt, err = srv.RefreshToken(creds, token) if err != nil { writeTokenError(w, err, state) return } default: log.Errorf("unsupported grant: %v", grantType) writeTokenError(w, oauth2.NewError(oauth2.ErrorUnsupportedGrantType), state) return } t := oAuth2Token{ AccessToken: jwt.Encode(), IDToken: jwt.Encode(), TokenType: "bearer", RefreshToken: refreshToken, } b, err := json.Marshal(t) if err != nil { log.Errorf("Failed marshaling %#v to JSON: %v", t, err) writeTokenError(w, oauth2.NewError(oauth2.ErrorServerError), state) return } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write(b) } }
func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.Template, registrationEnabled bool) http.HandlerFunc { idx := makeConnectorMap(idpcs) return func(w http.ResponseWriter, r *http.Request) { if r.Method != "GET" { w.Header().Set("Allow", "GET") phttp.WriteError(w, http.StatusMethodNotAllowed, "GET only acceptable method") return } q := r.URL.Query() register := q.Get("register") == "1" && registrationEnabled e := q.Get("error") if e != "" { sessionKey := q.Get("state") if err := srv.KillSession(sessionKey); err != nil { log.Errorf("Failed killing sessionKey %q: %v", sessionKey, err) } renderLoginPage(w, r, srv, idpcs, register, tpl) return } connectorID := q.Get("connector_id") idpc, ok := idx[connectorID] if !ok { renderLoginPage(w, r, srv, idpcs, register, tpl) return } acr, err := oauth2.ParseAuthCodeRequest(q) if err != nil { log.Errorf("Invalid auth request: %v", err) writeAuthError(w, err, acr.State) return } cm, err := srv.ClientMetadata(acr.ClientID) if err != nil { log.Errorf("Failed fetching client %q from repo: %v", acr.ClientID, err) writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), acr.State) return } if cm == nil { log.Errorf("Client %q not found", acr.ClientID) writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State) return } if len(cm.RedirectURLs) == 0 { log.Errorf("Client %q has no redirect URLs", acr.ClientID) writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), acr.State) return } redirectURL, err := client.ValidRedirectURL(acr.RedirectURL, cm.RedirectURLs) if err != nil { switch err { case (client.ErrorCantChooseRedirectURL): log.Errorf("Request must provide redirect URL as client %q has registered many", acr.ClientID) writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State) return case (client.ErrorInvalidRedirectURL): log.Errorf("Request provided unregistered redirect URL: %s", acr.RedirectURL) writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State) return case (client.ErrorNoValidRedirectURLs): log.Errorf("There are no registered URLs for the requested client: %s", acr.RedirectURL) writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State) return } } if acr.ResponseType != oauth2.ResponseTypeCode { log.Errorf("unexpected ResponseType: %v: ", acr.ResponseType) redirectAuthError(w, oauth2.NewError(oauth2.ErrorUnsupportedResponseType), acr.State, redirectURL) return } // Check scopes. var scopes []string foundOpenIDScope := false for _, scope := range acr.Scope { switch scope { case "openid": foundOpenIDScope = true scopes = append(scopes, scope) case "offline_access": // According to the spec, for offline_access scope, the client must // use a response_type value that would result in an Authorization Code. // Currently oauth2.ResponseTypeCode is the only supported response type, // and it's been checked above, so we don't need to check it again here. // // TODO(yifan): Verify that 'consent' should be in 'prompt'. scopes = append(scopes, scope) default: // Pass all other scopes. scopes = append(scopes, scope) } } if !foundOpenIDScope { log.Errorf("Invalid auth request: missing 'openid' in 'scope'") writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State) return } nonce := q.Get("nonce") key, err := srv.NewSession(connectorID, acr.ClientID, acr.State, redirectURL, nonce, register, acr.Scope) if err != nil { log.Errorf("Error creating new session: %v: ", err) redirectAuthError(w, err, acr.State, redirectURL) return } if register { _, ok := idpc.(*connector.LocalConnector) if ok { q := url.Values{} q.Set("code", key) ru := httpPathRegister + "?" + q.Encode() w.Header().Set("Location", ru) w.WriteHeader(http.StatusFound) return } } var p string if register { p = "select_account consent" } if shouldReprompt(r) || register { p = "select_account" } lu, err := idpc.LoginURL(key, p) if err != nil { log.Errorf("Connector.LoginURL failed: %v", err) redirectAuthError(w, err, acr.State, redirectURL) return } http.SetCookie(w, createLastSeenCookie()) w.Header().Set("Location", lu) w.WriteHeader(http.StatusFound) return } }
func TestServerRefreshToken(t *testing.T) { issuerURL := url.URL{Scheme: "http", Host: "server.example.com"} clientA := client.Client{ Credentials: oidc.ClientCredentials{ ID: testClientID, Secret: clientTestSecret, }, Metadata: oidc.ClientMetadata{ RedirectURIs: []url.URL{ url.URL{Scheme: "https", Host: "client.example.com", Path: "one/two/three"}, }, }, } clientB := client.Client{ Credentials: oidc.ClientCredentials{ ID: "example2.com", Secret: clientTestSecret, }, Metadata: oidc.ClientMetadata{ RedirectURIs: []url.URL{ url.URL{Scheme: "https", Host: "example2.com", Path: "one/two/three"}, }, }, } signerFixture := &StaticSigner{sig: []byte("beer"), err: nil} // NOTE(ericchiang): These tests assume that the database ID of the first // refresh token will be "1". tests := []struct { token string clientID string // The client that associates with the token. creds oidc.ClientCredentials signer jose.Signer err error }{ // Everything is good. { fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), clientA.Credentials.ID, clientA.Credentials, signerFixture, nil, }, // Invalid refresh token(malformatted). { "invalid-token", clientA.Credentials.ID, clientA.Credentials, signerFixture, oauth2.NewError(oauth2.ErrorInvalidRequest), }, // Invalid refresh token(invalid payload content). { fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))), clientA.Credentials.ID, clientA.Credentials, signerFixture, oauth2.NewError(oauth2.ErrorInvalidRequest), }, // Invalid refresh token(invalid ID content). { fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), clientA.Credentials.ID, clientA.Credentials, signerFixture, oauth2.NewError(oauth2.ErrorInvalidRequest), }, // Invalid client(client is not associated with the token). { fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), clientA.Credentials.ID, clientB.Credentials, signerFixture, oauth2.NewError(oauth2.ErrorInvalidClient), }, // Invalid client(no client ID). { fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), clientA.Credentials.ID, oidc.ClientCredentials{ID: "", Secret: "aaa"}, signerFixture, oauth2.NewError(oauth2.ErrorInvalidClient), }, // Invalid client(no such client). { fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), clientA.Credentials.ID, oidc.ClientCredentials{ID: "AAA", Secret: "aaa"}, signerFixture, oauth2.NewError(oauth2.ErrorInvalidClient), }, // Invalid client(no secrets). { fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), clientA.Credentials.ID, oidc.ClientCredentials{ID: testClientID}, signerFixture, oauth2.NewError(oauth2.ErrorInvalidClient), }, // Invalid client(invalid secret). { fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), clientA.Credentials.ID, oidc.ClientCredentials{ID: "bad-id", Secret: "bad-secret"}, signerFixture, oauth2.NewError(oauth2.ErrorInvalidClient), }, // Signing operation fails. { fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), clientA.Credentials.ID, clientA.Credentials, &StaticSigner{sig: nil, err: errors.New("fail")}, oauth2.NewError(oauth2.ErrorServerError), }, } for i, tt := range tests { km := &StaticKeyManager{ signer: tt.signer, } clients := []client.Client{ clientA, clientB, } clientIDGenerator := func(hostport string) (string, error) { return hostport, nil } secGen := func() ([]byte, error) { return []byte("secret"), nil } dbm := db.NewMemDB() clientRepo := db.NewClientRepo(dbm) clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen}) if err != nil { t.Fatalf("Failed to create client identity manager: %v", err) } userRepo, err := makeNewUserRepo() if err != nil { t.Fatalf("Unexpected error: %v", err) } refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo() srv := &Server{ IssuerURL: issuerURL, KeyManager: km, ClientRepo: clientRepo, ClientManager: clientManager, UserRepo: userRepo, RefreshTokenRepo: refreshTokenRepo, } if _, err := refreshTokenRepo.Create("testid-1", tt.clientID); err != nil { t.Fatalf("Unexpected error: %v", err) } jwt, err := srv.RefreshToken(tt.creds, tt.token) if !reflect.DeepEqual(err, tt.err) { t.Errorf("Case %d: expect: %v, got: %v", i, tt.err, err) } if jwt != nil { if string(jwt.Signature) != "beer" { t.Errorf("Case %d: expect signature: beer, got signature: %v", i, jwt.Signature) } claims, err := jwt.Claims() if err != nil { t.Errorf("Case %d: unexpected error: %v", i, err) } if claims["iss"] != issuerURL.String() || claims["sub"] != "testid-1" || claims["aud"] != testClientID { t.Errorf("Case %d: invalid claims: %v", i, claims) } } } // Test that we should return error when user cannot be found after // verifying the token. km := &StaticKeyManager{ signer: signerFixture, } clients := []client.Client{ clientA, clientB, } clientIDGenerator := func(hostport string) (string, error) { return hostport, nil } secGen := func() ([]byte, error) { return []byte("secret"), nil } dbm := db.NewMemDB() clientRepo := db.NewClientRepo(dbm) clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen}) if err != nil { t.Fatalf("Failed to create client identity manager: %v", err) } userRepo, err := makeNewUserRepo() if err != nil { t.Fatalf("Unexpected error: %v", err) } // Create a user that will be removed later. if err := userRepo.Create(nil, user.User{ ID: "testid-2", Email: "*****@*****.**", }); err != nil { t.Fatalf("Unexpected error: %v", err) } refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo() srv := &Server{ IssuerURL: issuerURL, KeyManager: km, ClientRepo: clientRepo, ClientManager: clientManager, UserRepo: userRepo, RefreshTokenRepo: refreshTokenRepo, } if _, err := refreshTokenRepo.Create("testid-2", clientA.Credentials.ID); err != nil { t.Fatalf("Unexpected error: %v", err) } // Recreate the user repo to remove the user we created. userRepo, err = makeNewUserRepo() if err != nil { t.Fatalf("Unexpected error: %v", err) } srv.UserRepo = userRepo _, err = srv.RefreshToken(clientA.Credentials, fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1")))) if !reflect.DeepEqual(err, oauth2.NewError(oauth2.ErrorServerError)) { t.Errorf("Expect: %v, got: %v", oauth2.NewError(oauth2.ErrorServerError), err) } }
func TestServerTokenFail(t *testing.T) { issuerURL := url.URL{Scheme: "http", Host: "server.example.com"} keyFixture := "goodkey" ccFixture := oidc.ClientCredentials{ ID: testClientID, Secret: clientTestSecret, } signerFixture := &StaticSigner{sig: []byte("beer"), err: nil} tests := []struct { signer jose.Signer argCC oidc.ClientCredentials argKey string err error scope []string refreshToken string }{ // control test case to make sure fixtures check out { // NOTE(ericchiang): This test assumes that the database ID of the first // refresh token will be "1". signer: signerFixture, argCC: ccFixture, argKey: keyFixture, scope: []string{"openid", "offline_access"}, refreshToken: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), }, // no 'offline_access' in 'scope', should get empty refresh token { signer: signerFixture, argCC: ccFixture, argKey: keyFixture, scope: []string{"openid"}, }, // unrecognized key { signer: signerFixture, argCC: ccFixture, argKey: "foo", err: oauth2.NewError(oauth2.ErrorInvalidGrant), scope: []string{"openid", "offline_access"}, }, // unrecognized client { signer: signerFixture, argCC: oidc.ClientCredentials{ID: "YYY"}, argKey: keyFixture, err: oauth2.NewError(oauth2.ErrorInvalidClient), scope: []string{"openid", "offline_access"}, }, // signing operation fails { signer: &StaticSigner{sig: nil, err: errors.New("fail")}, argCC: ccFixture, argKey: keyFixture, err: oauth2.NewError(oauth2.ErrorServerError), scope: []string{"openid", "offline_access"}, }, } for i, tt := range tests { sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())) sm.GenerateCode = func() (string, error) { return keyFixture, nil } sessionID, err := sm.NewSession("connector_id", ccFixture.ID, "bogus", url.URL{}, "", false, tt.scope) if err != nil { t.Fatalf("Unexpected error: %v", err) } _, err = sm.AttachRemoteIdentity(sessionID, oidc.Identity{}) if err != nil { t.Errorf("case %d: unexpected error: %v", i, err) continue } km := &StaticKeyManager{ signer: tt.signer, } clients := []client.Client{ client.Client{ Credentials: ccFixture, Metadata: oidc.ClientMetadata{ RedirectURIs: []url.URL{ validRedirURL, }, }, }, } dbm := db.NewMemDB() clientIDGenerator := func(hostport string) (string, error) { return hostport, nil } secGen := func() ([]byte, error) { return []byte("secret"), nil } clientRepo := db.NewClientRepo(dbm) clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen}) if err != nil { t.Fatalf("Failed to create client identity manager: %v", err) } _, err = sm.AttachUser(sessionID, "testid-1") if err != nil { t.Fatalf("case %d: unexpected error: %v", i, err) } userRepo, err := makeNewUserRepo() if err != nil { t.Fatalf("Unexpected error: %v", err) } refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo() srv := &Server{ IssuerURL: issuerURL, KeyManager: km, SessionManager: sm, ClientRepo: clientRepo, ClientManager: clientManager, UserRepo: userRepo, RefreshTokenRepo: refreshTokenRepo, } _, err = sm.NewSessionKey(sessionID) if err != nil { t.Fatalf("Unexpected error: %v", err) } jwt, token, err := srv.CodeToken(tt.argCC, tt.argKey) if token != tt.refreshToken { fmt.Printf("case %d: expect refresh token %q, got %q\n", i, tt.refreshToken, token) t.Fatalf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token) panic("") } if !reflect.DeepEqual(err, tt.err) { t.Errorf("case %d: expect %v, got %v", i, tt.err, err) } if err == nil && jwt == nil { t.Errorf("case %d: got nil JWT", i) } if err != nil && jwt != nil { t.Errorf("case %d: got non-nil JWT %v", i, jwt) } } }
func handleAuthFunc(srv OIDCServer, baseURL url.URL, idpcs []connector.Connector, tpl *template.Template, registrationEnabled bool) http.HandlerFunc { idx := makeConnectorMap(idpcs) return func(w http.ResponseWriter, r *http.Request) { if r.Method != "GET" { w.Header().Set("Allow", "GET") phttp.WriteError(w, http.StatusMethodNotAllowed, "GET only acceptable method") return } q := r.URL.Query() register := q.Get("register") == "1" && registrationEnabled e := q.Get("error") if e != "" { sessionKey := q.Get("state") if err := srv.KillSession(sessionKey); err != nil { log.Errorf("Failed killing sessionKey %q: %v", sessionKey, err) } renderLoginPage(w, r, srv, idpcs, register, tpl) return } connectorID := q.Get("connector_id") idpc, ok := idx[connectorID] if !ok { renderLoginPage(w, r, srv, idpcs, register, tpl) return } acr, err := oauth2.ParseAuthCodeRequest(q) if err != nil { log.Errorf("Invalid auth request: %v", err) writeAuthError(w, err, acr.State) return } cli, err := srv.Client(acr.ClientID) if err != nil { log.Errorf("Failed fetching client %q from repo: %v", acr.ClientID, err) writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), acr.State) return } if err == client.ErrorNotFound { log.Errorf("Client %q not found", acr.ClientID) writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State) return } redirectURL, err := cli.ValidRedirectURL(acr.RedirectURL) if err != nil { switch err { case (client.ErrorCantChooseRedirectURL): log.Errorf("Request must provide redirect URL as client %q has registered many", acr.ClientID) writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State) return case (client.ErrorInvalidRedirectURL): log.Errorf("Request provided unregistered redirect URL: %s", acr.RedirectURL) writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State) return case (client.ErrorNoValidRedirectURLs): log.Errorf("There are no registered URLs for the requested client: %s", acr.RedirectURL) writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State) return } } if acr.ResponseType != oauth2.ResponseTypeCode { log.Errorf("unexpected ResponseType: %v: ", acr.ResponseType) redirectAuthError(w, oauth2.NewError(oauth2.ErrorUnsupportedResponseType), acr.State, redirectURL) return } // Check scopes. if scopeErr := validateScopes(srv, acr.ClientID, acr.Scope); scopeErr != nil { log.Error(scopeErr) writeAuthError(w, scopeErr, acr.State) return } nonce := q.Get("nonce") key, err := srv.NewSession(connectorID, acr.ClientID, acr.State, redirectURL, nonce, register, acr.Scope) if err != nil { log.Errorf("Error creating new session: %v: ", err) redirectAuthError(w, err, acr.State, redirectURL) return } if register { _, ok := idpc.(*connector.LocalConnector) if ok { q := url.Values{} q.Set("code", key) ru := path.Join(baseURL.Path, httpPathRegister) + "?" + q.Encode() w.Header().Set("Location", ru) w.WriteHeader(http.StatusFound) return } } var p string if register { p = "select_account consent" } if shouldReprompt(r) || register { p = "select_account" } lu, err := idpc.LoginURL(key, p) if err != nil { log.Errorf("Connector.LoginURL failed: %v", err) redirectAuthError(w, err, acr.State, redirectURL) return } http.SetCookie(w, createLastSeenCookie()) w.Header().Set("Location", lu) w.WriteHeader(http.StatusFound) return } }
func TestWriteTokenError(t *testing.T) { tests := []struct { err error state string wantCode int wantHeader http.Header wantBody string }{ { err: oauth2.NewError(oauth2.ErrorInvalidRequest), state: "bazinga", wantCode: http.StatusBadRequest, wantHeader: http.Header{ "Content-Type": []string{"application/json"}, }, wantBody: `{"error":"invalid_request","state":"bazinga"}`, }, { err: oauth2.NewError(oauth2.ErrorInvalidRequest), wantCode: http.StatusBadRequest, wantHeader: http.Header{ "Content-Type": []string{"application/json"}, }, wantBody: `{"error":"invalid_request"}`, }, { err: oauth2.NewError(oauth2.ErrorInvalidGrant), wantCode: http.StatusBadRequest, wantHeader: http.Header{ "Content-Type": []string{"application/json"}, }, wantBody: `{"error":"invalid_grant"}`, }, { err: oauth2.NewError(oauth2.ErrorInvalidClient), wantCode: http.StatusUnauthorized, wantHeader: http.Header{ "Content-Type": []string{"application/json"}, "Www-Authenticate": []string{"Basic"}, }, wantBody: `{"error":"invalid_client"}`, }, { err: oauth2.NewError(oauth2.ErrorServerError), wantCode: http.StatusBadRequest, wantHeader: http.Header{ "Content-Type": []string{"application/json"}, }, wantBody: `{"error":"server_error"}`, }, { err: oauth2.NewError(oauth2.ErrorUnsupportedGrantType), wantCode: http.StatusBadRequest, wantHeader: http.Header{ "Content-Type": []string{"application/json"}, }, wantBody: `{"error":"unsupported_grant_type"}`, }, { err: errors.New("generic failure"), wantCode: http.StatusBadRequest, wantHeader: http.Header{ "Content-Type": []string{"application/json"}, }, wantBody: `{"error":"server_error"}`, }, } for i, tt := range tests { w := httptest.NewRecorder() writeTokenError(w, tt.err, tt.state) if tt.wantCode != w.Code { t.Errorf("case %d: incorrect HTTP status: want=%d got=%d", i, tt.wantCode, w.Code) } gotHeader := w.Header() if !reflect.DeepEqual(tt.wantHeader, gotHeader) { t.Errorf("case %d: incorrect HTTP headers: want=%#v got=%#v", i, tt.wantHeader, gotHeader) } gotBody := w.Body.String() if tt.wantBody != gotBody { t.Errorf("case %d: incorrect HTTP body: want=%q got=%q", i, tt.wantBody, gotBody) } } }
func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, string, time.Time, error) { ok, err := s.ClientManager.Authenticate(creds) if err != nil { log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err) return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } if !ok { log.Errorf("Failed to Authenticate client %s", creds.ID) return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidClient) } userID, connectorID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token) switch err { case nil: break case refresh.ErrorInvalidToken: return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidRequest) case refresh.ErrorInvalidClientID: return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidClient) default: return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } if len(scopes) == 0 { scopes = rtScopes } else { if !rtScopes.Contains(scopes) { return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidRequest) } } usr, err := s.UserRepo.Get(nil, userID) if err != nil { // The error can be user.ErrorNotFound, but we are not deleting // user at this moment, so this shouldn't happen. log.Errorf("Failed to fetch user %q from repo: %v: ", userID, err) return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } var groups []string if rtScopes.HasScope(scope.ScopeGroups) { conn, ok := s.connector(connectorID) if !ok { log.Errorf("refresh token contained invalid connector ID (%s)", connectorID) return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } grouper, ok := conn.(connector.GroupsConnector) if !ok { log.Errorf("refresh token requested groups for connector (%s) that doesn't support groups", connectorID) return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } remoteIdentities, err := s.UserRepo.GetRemoteIdentities(nil, userID) if err != nil { log.Errorf("failed to get remote identities: %v", err) return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } remoteIdentity, ok := func() (user.RemoteIdentity, bool) { for _, ri := range remoteIdentities { if ri.ConnectorID == connectorID { return ri, true } } return user.RemoteIdentity{}, false }() if !ok { log.Errorf("failed to get remote identity for connector %s", connectorID) return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } if groups, err = grouper.Groups(remoteIdentity.ID); err != nil { log.Errorf("failed to get groups for refresh token: %v", connectorID) return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } } signer, err := s.KeyManager.Signer() if err != nil { log.Errorf("Failed to refresh ID token: %v", err) return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } now := time.Now() expiresAt := now.Add(session.DefaultSessionValidityWindow) claims := oidc.NewClaims(s.IssuerURL.String(), usr.ID, creds.ID, now, expiresAt) usr.AddToClaims(claims) if rtScopes.HasScope(scope.ScopeGroups) { if groups == nil { groups = []string{} } claims["groups"] = groups } s.addClaimsFromScope(claims, scope.Scopes(scopes), creds.ID) jwt, err := jose.NewSignedJWT(claims, signer) if err != nil { log.Errorf("Failed to generate ID token: %v", err) return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } refreshToken, err := s.RefreshTokenRepo.RenewRefreshToken(creds.ID, userID, token) if err != nil { log.Errorf("Failed to generate new refresh token: %v", err) return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } log.Infof("New token sent: clientID=%s", creds.ID) return jwt, refreshToken, expiresAt, nil }