func testKeysCRUD(t *testing.T, s storage.Storage) { updateAndCompare := func(k storage.Keys) { err := s.UpdateKeys(func(oldKeys storage.Keys) (storage.Keys, error) { return k, nil }) if err != nil { t.Errorf("failed to update keys: %v", err) return } if got, err := s.GetKeys(); err != nil { t.Errorf("failed to get keys: %v", err) } else { got.NextRotation = got.NextRotation.UTC() if diff := pretty.Compare(k, got); diff != "" { t.Errorf("got keys did not equal expected: %s", diff) } } } // Postgres isn't as accurate with nano seconds as we'd like n := time.Now().UTC().Round(time.Second) keys1 := storage.Keys{ SigningKey: jsonWebKeys[0].Private, SigningKeyPub: jsonWebKeys[0].Public, NextRotation: n, } keys2 := storage.Keys{ SigningKey: jsonWebKeys[2].Private, SigningKeyPub: jsonWebKeys[2].Public, NextRotation: n.Add(time.Hour), VerificationKeys: []storage.VerificationKey{ { PublicKey: jsonWebKeys[0].Public, Expiry: n.Add(time.Hour), }, { PublicKey: jsonWebKeys[1].Public, Expiry: n.Add(time.Hour * 2), }, }, } updateAndCompare(keys1) updateAndCompare(keys2) }
func startGarbageCollection(ctx context.Context, s storage.Storage, frequency time.Duration, now func() time.Time) { go func() { for { select { case <-ctx.Done(): return case <-time.After(frequency): if r, err := s.GarbageCollect(now()); err != nil { log.Printf("garbage collection failed: %v", err) } else if r.AuthRequests > 0 || r.AuthCodes > 0 { log.Printf("garbage collection run, delete auth requests=%d, auth codes=%d", r.AuthRequests, r.AuthCodes) } } } }() return }
// testTimezones tests that backends either fully support timezones or
// do the correct standardization.
//
// It stores an auth code, reads it back, converts the returned Expiry into
// the America/New_York location and fails unless that instant is Equal to
// the locally computed, millisecond-rounded expiry. Only the instant is
// compared; the location the backend returns is not checked.
//
// NOTE(review): the storage.AuthCode literal in this copy looks truncated.
// Its string literals appear redacted, and it has an unmatched "}," before
// its closing brace. No Expiry field is visible, even though the assertion
// compares got.Expiry against `expiry`. As shown this will not compile.
// Restore the missing fields (presumably Expiry: expiry and a Claims value
// holding EmailVerified/Groups) and confirm against version control.
func testTimezones(t *testing.T, s storage.Storage) {
	est, err := time.LoadLocation("America/New_York")
	if err != nil {
		t.Fatal(err)
	}
	// Create an expiry with timezone info. Only expect backends to be
	// accurate to the millisecond
	expiry := time.Now().In(est).Round(time.Millisecond)

	c := storage.AuthCode{
		ID:          storage.NewID(),
		ClientID:    "foobar",
		RedirectURI: "https://*****:*****@example.com",
		EmailVerified: true,
		Groups:        []string{"a", "b"},
		},
	}
	if err := s.CreateAuthCode(c); err != nil {
		t.Fatalf("failed creating auth code: %v", err)
	}
	got, err := s.GetAuthCode(c.ID)
	if err != nil {
		t.Fatalf("failed to get auth code: %v", err)
	}

	// Ensure that if the resulting time is converted to the same
	// timezone, it's the same value. We DO NOT expect timezones
	// to be preserved.
	gotTime := got.Expiry.In(est)
	wantTime := expiry
	if !gotTime.Equal(wantTime) {
		t.Fatalf("expected expiry %v got %v", wantTime, gotTime)
	}
}
func validateCrossClientTrust(s storage.Storage, clientID, peerID string) (trusted bool, err error) { if peerID == clientID { return true, nil } peer, err := s.GetClient(peerID) if err != nil { if err != storage.ErrNotFound { log.Printf("Failed to get client: %v", err) return false, err } return false, nil } for _, id := range peer.TrustedPeers { if id == clientID { return true, nil } } return false, nil }
func testClientConcurrentUpdate(t *testing.T, s storage.Storage) { c := storage.Client{ ID: storage.NewID(), Secret: "foobar", RedirectURIs: []string{"foo://bar.com/", "https://auth.example.com"}, Name: "dex client", LogoURL: "https://goo.gl/JIyzIC", } if err := s.CreateClient(c); err != nil { t.Fatalf("create client: %v", err) } var err1, err2 error err1 = s.UpdateClient(c.ID, func(old storage.Client) (storage.Client, error) { old.Secret = "new secret 1" err2 = s.UpdateClient(c.ID, func(old storage.Client) (storage.Client, error) { old.Secret = "new secret 2" return old, nil }) return old, nil }) if (err1 == nil) == (err2 == nil) { t.Errorf("update client:\nupdate1: %v\nupdate2: %v\n", err1, err2) } }
func testKeysConcurrentUpdate(t *testing.T, s storage.Storage) { // Test twice. Once for a create, once for an update. for i := 0; i < 2; i++ { n := time.Now().UTC().Round(time.Second) keys1 := storage.Keys{ SigningKey: jsonWebKeys[0].Private, SigningKeyPub: jsonWebKeys[0].Public, NextRotation: n, } keys2 := storage.Keys{ SigningKey: jsonWebKeys[2].Private, SigningKeyPub: jsonWebKeys[2].Public, NextRotation: n.Add(time.Hour), VerificationKeys: []storage.VerificationKey{ { PublicKey: jsonWebKeys[0].Public, Expiry: n.Add(time.Hour), }, { PublicKey: jsonWebKeys[1].Public, Expiry: n.Add(time.Hour * 2), }, }, } var err1, err2 error err1 = s.UpdateKeys(func(old storage.Keys) (storage.Keys, error) { err2 = s.UpdateKeys(func(old storage.Keys) (storage.Keys, error) { return keys1, nil }) return keys2, nil }) if (err1 == nil) == (err2 == nil) { t.Errorf("update keys: concurrent updates both returned no error") } } }
func testPasswordConcurrentUpdate(t *testing.T, s storage.Storage) { // Use bcrypt.MinCost to keep the tests short. passwordHash, err := bcrypt.GenerateFromPassword([]byte("secret"), bcrypt.MinCost) if err != nil { t.Fatal(err) } password := storage.Password{ Email: "*****@*****.**", Hash: passwordHash, Username: "******", UserID: "foobar", } if err := s.CreatePassword(password); err != nil { t.Fatalf("create password token: %v", err) } var err1, err2 error err1 = s.UpdatePassword(password.Email, func(old storage.Password) (storage.Password, error) { old.Username = "******" err2 = s.UpdatePassword(password.Email, func(old storage.Password) (storage.Password, error) { old.Username = "******" return old, nil }) return old, nil }) if (err1 == nil) == (err2 == nil) { t.Errorf("update password: concurrent updates both returned no error") } }
// testAuthCodeCRUD stores an auth code, reads it back and compares it with
// what was stored, then deletes it and checks that a further lookup reports
// not-found via mustBeErrNotFound.
//
// Expiry is compared at one-second (Unix) granularity. It is then copied from
// the original into the retrieved value, so that pretty.Compare does not trip
// over time.Time internals.
//
// NOTE(review): the storage.AuthCode literal in this copy looks truncated.
// Its string literals appear redacted, and it has an unmatched "}," before
// its closing brace. As shown this will not compile. Restore the missing
// fields (presumably including Expiry and a Claims value) and confirm
// against version control.
// NOTE(review): the GetAuthCode failure message says "auth req" although
// this test is about auth codes.
func testAuthCodeCRUD(t *testing.T, s storage.Storage) {
	a := storage.AuthCode{
		ID:          storage.NewID(),
		ClientID:    "foobar",
		RedirectURI: "https://*****:*****@example.com",
		EmailVerified: true,
		Groups:        []string{"a", "b"},
		},
	}

	if err := s.CreateAuthCode(a); err != nil {
		t.Fatalf("failed creating auth code: %v", err)
	}

	got, err := s.GetAuthCode(a.ID)
	if err != nil {
		t.Fatalf("failed to get auth req: %v", err)
	}
	if a.Expiry.Unix() != got.Expiry.Unix() {
		t.Errorf("auth code expiry did not match want=%s vs got=%s", a.Expiry, got.Expiry)
	}
	got.Expiry = a.Expiry // time fields do not compare well
	if diff := pretty.Compare(a, got); diff != "" {
		t.Errorf("auth code retrieved from storage did not match: %s", diff)
	}

	if err := s.DeleteAuthCode(a.ID); err != nil {
		t.Fatalf("delete auth code: %v", err)
	}

	_, err = s.GetAuthCode(a.ID)
	mustBeErrNotFound(t, "auth code", err)
}
func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { id := storage.NewID() refresh := storage.RefreshToken{ RefreshToken: id, ClientID: "client_id", ConnectorID: "client_secret", Scopes: []string{"openid", "email", "profile"}, Claims: storage.Claims{ UserID: "1", Username: "******", Email: "*****@*****.**", EmailVerified: true, Groups: []string{"a", "b"}, }, } if err := s.CreateRefresh(refresh); err != nil { t.Fatalf("create refresh token: %v", err) } getAndCompare := func(id string, want storage.RefreshToken) { gr, err := s.GetRefresh(id) if err != nil { t.Errorf("get refresh: %v", err) return } if diff := pretty.Compare(want, gr); diff != "" { t.Errorf("refresh token retrieved from storage did not match: %s", diff) } } getAndCompare(id, refresh) if err := s.DeleteRefresh(id); err != nil { t.Fatalf("failed to delete refresh request: %v", err) } if _, err := s.GetRefresh(id); err != storage.ErrNotFound { t.Errorf("after deleting refresh expected storage.ErrNotFound, got %v", err) } }
// testAuthRequestCRUD creates an auth request, then uses UpdateAuthRequest to
// set its Claims (to an identity with Email "foobar") and ConnectorID. It
// fails unless a subsequent GetAuthRequest returns Claims that are
// reflect.DeepEqual to that identity.
//
// NOTE(review): ConnectorID is set to "connID" in the update callback but is
// never checked afterwards.
// NOTE(review): the storage.AuthRequest literal in this copy looks truncated.
// Its string literals appear redacted, and it has an unmatched "}," before
// its closing brace. As shown this will not compile. Restore the missing
// fields (presumably a Claims value holding EmailVerified/Groups) and confirm
// against version control.
func testAuthRequestCRUD(t *testing.T, s storage.Storage) {
	a := storage.AuthRequest{
		ID:            storage.NewID(),
		ClientID:      "foobar",
		ResponseTypes: []string{"code"},
		Scopes:        []string{"openid", "email"},
		RedirectURI:   "https://*****:*****@example.com",
		EmailVerified: true,
		Groups:        []string{"a", "b"},
		},
	}

	identity := storage.Claims{Email: "foobar"}

	if err := s.CreateAuthRequest(a); err != nil {
		t.Fatalf("failed creating auth request: %v", err)
	}
	if err := s.UpdateAuthRequest(a.ID, func(old storage.AuthRequest) (storage.AuthRequest, error) {
		old.Claims = identity
		old.ConnectorID = "connID"
		return old, nil
	}); err != nil {
		t.Fatalf("failed to update auth request: %v", err)
	}

	got, err := s.GetAuthRequest(a.ID)
	if err != nil {
		t.Fatalf("failed to get auth req: %v", err)
	}
	if !reflect.DeepEqual(got.Claims, identity) {
		t.Fatalf("update failed, wanted identity=%#v got %#v", identity, got.Claims)
	}
}
// testAuthRequestConcurrentUpdate checks that a nested UpdateAuthRequest
// call, issued from inside another UpdateAuthRequest callback for the same
// ID, is treated as a conflicting write. Exactly one of the two calls must
// return an error; both errors are printed on failure.
//
// NOTE(review): the storage.AuthRequest literal in this copy looks truncated.
// Its string literals appear redacted, and it has an unmatched "}," before
// its closing brace. As shown this will not compile. Restore the missing
// fields and confirm against version control.
func testAuthRequestConcurrentUpdate(t *testing.T, s storage.Storage) {
	a := storage.AuthRequest{
		ID:            storage.NewID(),
		ClientID:      "foobar",
		ResponseTypes: []string{"code"},
		Scopes:        []string{"openid", "email"},
		RedirectURI:   "https://*****:*****@example.com",
		EmailVerified: true,
		Groups:        []string{"a", "b"},
		},
	}

	if err := s.CreateAuthRequest(a); err != nil {
		t.Fatalf("failed creating auth request: %v", err)
	}

	var err1, err2 error

	// err2 is assigned inside the outer callback, i.e. while the outer
	// update is still in progress.
	err1 = s.UpdateAuthRequest(a.ID, func(old storage.AuthRequest) (storage.AuthRequest, error) {
		old.State = "state 1"
		err2 = s.UpdateAuthRequest(a.ID, func(old storage.AuthRequest) (storage.AuthRequest, error) {
			old.State = "state 2"
			return old, nil
		})
		return old, nil
	})

	// Fails when both updates succeeded or both failed.
	if (err1 == nil) == (err2 == nil) {
		t.Errorf("update auth request:\nupdate1: %v\nupdate2: %v\n", err1, err2)
	}
}
// testGC checks that GarbageCollect removes an expired auth code and an
// expired auth request, and honors the instant rather than the location of
// the time it is given.
//
// For each object, GarbageCollect is first called at expiry-1h expressed in
// UTC, America/New_York and America/Los_Angeles. Each call must delete
// nothing, and the object must remain readable. A call at expiry+1h must then
// report exactly one deletion, after which a lookup must return
// storage.ErrNotFound.
//
// NOTE(review): both composite literals in this copy look truncated. Their
// string literals appear redacted, and each has an unmatched "}," before its
// closing brace. No Expiry field is visible, even though the test relies on
// `expiry`. As shown this will not compile; confirm against version control.
// NOTE(review): this file also contains a second func named testGC. Go does
// not allow two functions with the same name in one package, so one of them
// must be removed or renamed.
// NOTE(review): the auth-request half reuses "auth code" in its failure
// messages ("expected to be able to get auth code after GC", "expected auth
// code to be GC'd").
func testGC(t *testing.T, s storage.Storage) {
	est, err := time.LoadLocation("America/New_York")
	if err != nil {
		t.Fatal(err)
	}
	pst, err := time.LoadLocation("America/Los_Angeles")
	if err != nil {
		t.Fatal(err)
	}

	expiry := time.Now().In(est)
	c := storage.AuthCode{
		ID:          storage.NewID(),
		ClientID:    "foobar",
		RedirectURI: "https://*****:*****@example.com",
		EmailVerified: true,
		Groups:        []string{"a", "b"},
		},
	}

	if err := s.CreateAuthCode(c); err != nil {
		t.Fatalf("failed creating auth code: %v", err)
	}

	// One hour before expiry, in several locations: nothing may be collected.
	for _, tz := range []*time.Location{time.UTC, est, pst} {
		result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz))
		if err != nil {
			t.Errorf("garbage collection failed: %v", err)
		} else {
			if result.AuthCodes != 0 || result.AuthRequests != 0 {
				t.Errorf("expected no garbage collection results, got %#v", result)
			}
		}
		if _, err := s.GetAuthCode(c.ID); err != nil {
			t.Errorf("expected to be able to get auth code after GC: %v", err)
		}
	}

	// One hour after expiry: the auth code must be collected.
	if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
		t.Errorf("garbage collection failed: %v", err)
	} else if r.AuthCodes != 1 {
		t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthCodes)
	}

	if _, err := s.GetAuthCode(c.ID); err == nil {
		t.Errorf("expected auth code to be GC'd")
	} else if err != storage.ErrNotFound {
		t.Errorf("expected storage.ErrNotFound, got %v", err)
	}

	a := storage.AuthRequest{
		ID:            storage.NewID(),
		ClientID:      "foobar",
		ResponseTypes: []string{"code"},
		Scopes:        []string{"openid", "email"},
		RedirectURI:   "https://*****:*****@example.com",
		EmailVerified: true,
		Groups:        []string{"a", "b"},
		},
	}

	if err := s.CreateAuthRequest(a); err != nil {
		t.Fatalf("failed creating auth request: %v", err)
	}

	// Same checks for the auth request.
	for _, tz := range []*time.Location{time.UTC, est, pst} {
		result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz))
		if err != nil {
			t.Errorf("garbage collection failed: %v", err)
		} else {
			if result.AuthCodes != 0 || result.AuthRequests != 0 {
				t.Errorf("expected no garbage collection results, got %#v", result)
			}
		}
		if _, err := s.GetAuthRequest(a.ID); err != nil {
			t.Errorf("expected to be able to get auth code after GC: %v", err)
		}
	}

	if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
		t.Errorf("garbage collection failed: %v", err)
	} else if r.AuthRequests != 1 {
		t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthRequests)
	}

	if _, err := s.GetAuthRequest(a.ID); err == nil {
		t.Errorf("expected auth code to be GC'd")
	} else if err != storage.ErrNotFound {
		t.Errorf("expected storage.ErrNotFound, got %v", err)
	}
}
func testPasswordCRUD(t *testing.T, s storage.Storage) { // Use bcrypt.MinCost to keep the tests short. passwordHash, err := bcrypt.GenerateFromPassword([]byte("secret"), bcrypt.MinCost) if err != nil { t.Fatal(err) } password := storage.Password{ Email: "*****@*****.**", Hash: passwordHash, Username: "******", UserID: "foobar", } if err := s.CreatePassword(password); err != nil { t.Fatalf("create password token: %v", err) } getAndCompare := func(id string, want storage.Password) { gr, err := s.GetPassword(id) if err != nil { t.Errorf("get password %q: %v", id, err) return } if diff := pretty.Compare(want, gr); diff != "" { t.Errorf("password retrieved from storage did not match: %s", diff) } } getAndCompare("*****@*****.**", password) getAndCompare("*****@*****.**", password) // Emails should be case insensitive if err := s.UpdatePassword(password.Email, func(old storage.Password) (storage.Password, error) { old.Username = "******" return old, nil }); err != nil { t.Fatalf("failed to update auth request: %v", err) } password.Username = "******" getAndCompare("*****@*****.**", password) var passwordList []storage.Password passwordList = append(passwordList, password) listAndCompare := func(want []storage.Password) { passwords, err := s.ListPasswords() if err != nil { t.Errorf("list password: %v", err) return } sort.Sort(byEmail(want)) sort.Sort(byEmail(passwords)) if diff := pretty.Compare(want, passwords); diff != "" { t.Errorf("password list retrieved from storage did not match: %s", diff) } } listAndCompare(passwordList) if err := s.DeletePassword(password.Email); err != nil { t.Fatalf("failed to delete password: %v", err) } if _, err := s.GetPassword(password.Email); err != storage.ErrNotFound { t.Errorf("after deleting password expected storage.ErrNotFound, got %v", err) } }
func testClientCRUD(t *testing.T, s storage.Storage) { id := storage.NewID() c := storage.Client{ ID: id, Secret: "foobar", RedirectURIs: []string{"foo://bar.com/", "https://auth.example.com"}, Name: "dex client", LogoURL: "https://goo.gl/JIyzIC", } err := s.DeleteClient(id) mustBeErrNotFound(t, "client", err) if err := s.CreateClient(c); err != nil { t.Fatalf("create client: %v", err) } getAndCompare := func(id string, want storage.Client) { gc, err := s.GetClient(id) if err != nil { t.Errorf("get client: %v", err) return } if diff := pretty.Compare(want, gc); diff != "" { t.Errorf("client retrieved from storage did not match: %s", diff) } } getAndCompare(id, c) newSecret := "barfoo" err = s.UpdateClient(id, func(old storage.Client) (storage.Client, error) { old.Secret = newSecret return old, nil }) if err != nil { t.Errorf("update client: %v", err) } c.Secret = newSecret getAndCompare(id, c) if err := s.DeleteClient(id); err != nil { t.Fatalf("delete client: %v", err) } _, err = s.GetClient(id) mustBeErrNotFound(t, "client", err) }
// parse the initial request from the OAuth2 client. // // For correctness the logic is largely copied from https://github.com/RangelReale/osin. func parseAuthorizationRequest(s storage.Storage, supportedResponseTypes map[string]bool, r *http.Request) (req storage.AuthRequest, oauth2Err *authErr) { if err := r.ParseForm(); err != nil { return req, &authErr{"", "", errInvalidRequest, "Failed to parse request."} } redirectURI, err := url.QueryUnescape(r.Form.Get("redirect_uri")) if err != nil { return req, &authErr{"", "", errInvalidRequest, "No redirect_uri provided."} } state := r.FormValue("state") clientID := r.Form.Get("client_id") client, err := s.GetClient(clientID) if err != nil { if err == storage.ErrNotFound { description := fmt.Sprintf("Invalid client_id (%q).", clientID) return req, &authErr{"", "", errUnauthorizedClient, description} } log.Printf("Failed to get client: %v", err) return req, &authErr{"", "", errServerError, ""} } if !validateRedirectURI(client, redirectURI) { description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI) return req, &authErr{"", "", errInvalidRequest, description} } newErr := func(typ, format string, a ...interface{}) *authErr { return &authErr{state, redirectURI, typ, fmt.Sprintf(format, a...)} } scopes := strings.Fields(r.Form.Get("scope")) var ( unrecognized []string invalidScopes []string ) hasOpenIDScope := false for _, scope := range scopes { switch scope { case scopeOpenID: hasOpenIDScope = true case scopeOfflineAccess, scopeEmail, scopeProfile, scopeGroups: default: peerID, ok := parseCrossClientScope(scope) if !ok { unrecognized = append(unrecognized, scope) continue } isTrusted, err := validateCrossClientTrust(s, clientID, peerID) if err != nil { return req, newErr(errServerError, "") } if !isTrusted { invalidScopes = append(invalidScopes, scope) } } } if !hasOpenIDScope { return req, newErr("invalid_scope", `Missing required scope(s) ["openid"].`) } if len(unrecognized) > 0 { return req, 
newErr("invalid_scope", "Unrecognized scope(s) %q", unrecognized) } if len(invalidScopes) > 0 { return req, newErr("invalid_scope", "Client can't request scope(s) %q", invalidScopes) } nonce := r.Form.Get("nonce") responseTypes := strings.Split(r.Form.Get("response_type"), " ") for _, responseType := range responseTypes { if !supportedResponseTypes[responseType] { return req, newErr("invalid_request", "Invalid response type %q", responseType) } switch responseType { case responseTypeCode: case responseTypeToken: // Implicit flow requires a nonce value. // https://openid.net/specs/openid-connect-core-1_0.html#ImplicitAuthRequest if nonce == "" { return req, newErr("invalid_request", "Response type 'token' requires a 'nonce' value.") } if redirectURI == redirectURIOOB { err := fmt.Sprintf("Cannot use response type 'token' with redirect_uri '%s'.", redirectURIOOB) return req, newErr("invalid_request", err) } default: return req, newErr("invalid_request", "Invalid response type %q", responseType) } } return storage.AuthRequest{ ID: storage.NewID(), ClientID: client.ID, State: r.Form.Get("state"), Nonce: nonce, ForceApprovalPrompt: r.Form.Get("approval_prompt") == "force", Scopes: scopes, RedirectURI: redirectURI, ResponseTypes: responseTypes, }, nil }
// testGC checks that GarbageCollect removes an expired auth code and an
// expired auth request.
//
// For each object, GarbageCollect is first called at n, the time the test
// started (UTC), and the object must still be readable afterwards.
// GarbageCollect is then called at n+1m: it must report exactly one deletion
// of that kind, and a following lookup must return storage.ErrNotFound.
//
// NOTE(review): both composite literals in this copy look truncated. Their
// string literals appear redacted, and each has an unmatched "}," before its
// closing brace. No Expiry field is visible, even though the test depends on
// the objects expiring between n and n+1m. As shown this will not compile;
// confirm against version control.
// NOTE(review): this file also contains a second func named testGC. Go does
// not allow two functions with the same name in one package, so one of them
// must be removed or renamed.
// NOTE(review): the auth-request half reuses "auth code" in its failure
// messages.
func testGC(t *testing.T, s storage.Storage) {
	n := time.Now().UTC()
	c := storage.AuthCode{
		ID:          storage.NewID(),
		ClientID:    "foobar",
		RedirectURI: "https://*****:*****@example.com",
		EmailVerified: true,
		Groups:        []string{"a", "b"},
		},
	}

	if err := s.CreateAuthCode(c); err != nil {
		t.Fatalf("failed creating auth code: %v", err)
	}

	if _, err := s.GarbageCollect(n); err != nil {
		t.Errorf("garbage collection failed: %v", err)
	}
	if _, err := s.GetAuthCode(c.ID); err != nil {
		t.Errorf("expected to be able to get auth code after GC: %v", err)
	}

	if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil {
		t.Errorf("garbage collection failed: %v", err)
	} else if r.AuthCodes != 1 {
		t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthCodes)
	}

	if _, err := s.GetAuthCode(c.ID); err == nil {
		t.Errorf("expected auth code to be GC'd")
	} else if err != storage.ErrNotFound {
		t.Errorf("expected storage.ErrNotFound, got %v", err)
	}

	a := storage.AuthRequest{
		ID:            storage.NewID(),
		ClientID:      "foobar",
		ResponseTypes: []string{"code"},
		Scopes:        []string{"openid", "email"},
		RedirectURI:   "https://*****:*****@example.com",
		EmailVerified: true,
		Groups:        []string{"a", "b"},
		},
	}

	if err := s.CreateAuthRequest(a); err != nil {
		t.Fatalf("failed creating auth request: %v", err)
	}

	if _, err := s.GarbageCollect(n); err != nil {
		t.Errorf("garbage collection failed: %v", err)
	}
	if _, err := s.GetAuthRequest(a.ID); err != nil {
		t.Errorf("expected to be able to get auth code after GC: %v", err)
	}

	if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil {
		t.Errorf("garbage collection failed: %v", err)
	} else if r.AuthRequests != 1 {
		t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthRequests)
	}

	if _, err := s.GetAuthRequest(a.ID); err == nil {
		t.Errorf("expected auth code to be GC'd")
	} else if err != storage.ErrNotFound {
		t.Errorf("expected storage.ErrNotFound, got %v", err)
	}
}