func TestClientKeysFuncAll(t *testing.T) { priv1, err := key.GeneratePrivateKey() if err != nil { t.Fatalf("failed to generate private key, error=%v", err) } priv2, err := key.GeneratePrivateKey() if err != nil { t.Fatalf("failed to generate private key, error=%v", err) } now := time.Now() future := now.Add(time.Hour) past := now.Add(-1 * time.Hour) tests := []struct { keySet *key.PublicKeySet want []key.PublicKey }{ // two keys, non-expired set { keySet: key.NewPublicKeySet([]jose.JWK{priv2.JWK(), priv1.JWK()}, future), want: []key.PublicKey{*key.NewPublicKey(priv2.JWK()), *key.NewPublicKey(priv1.JWK())}, }, // no keys, non-expired set { keySet: key.NewPublicKeySet([]jose.JWK{}, future), want: []key.PublicKey{}, }, // two keys, expired set { keySet: key.NewPublicKeySet([]jose.JWK{priv2.JWK(), priv1.JWK()}, past), want: []key.PublicKey{}, }, // no keys, expired set { keySet: key.NewPublicKeySet([]jose.JWK{}, past), want: []key.PublicKey{}, }, } for i, tt := range tests { var c Client c.keySet = *tt.keySet keysFunc := c.keysFuncAll() got := keysFunc() if !reflect.DeepEqual(tt.want, got) { t.Errorf("case %d: want=%#v got=%#v", i, tt.want, got) } } }
func mockServer(cis []client.LoadableClient) (*server.Server, error) { dbMap := db.NewMemDB() k, err := key.GeneratePrivateKey() if err != nil { return nil, fmt.Errorf("Unable to generate private key: %v", err) } km := key.NewPrivateKeyManager() err = km.Set(key.NewPrivateKeySet([]*key.PrivateKey{k}, time.Now().Add(time.Minute))) if err != nil { return nil, err } clientRepo, clientManager, err := makeClientRepoAndManager(dbMap, cis) if err != nil { return nil, err } sm := manager.NewSessionManager(db.NewSessionRepo(dbMap), db.NewSessionKeyRepo(dbMap)) srv := &server.Server{ IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, KeyManager: km, ClientRepo: clientRepo, ClientManager: clientManager, SessionManager: sm, } return srv, nil }
func TestGetClientIDFromAuthorizedRequest(t *testing.T) { now := time.Now() tomorrow := now.Add(24 * time.Hour) privKey, err := key.GeneratePrivateKey() if err != nil { t.Fatalf("Failed to generate private key, error=%v", err) } signer := privKey.Signer() makeToken := func(iss, sub, aud string, iat, exp time.Time) string { claims := oidc.NewClaims(iss, sub, aud, iat, exp) jwt, err := jose.NewSignedJWT(claims, signer) if err != nil { t.Fatalf("Failed to generate JWT, error=%v", err) } return jwt.Encode() } tests := []struct { header string wantClient string wantErr bool }{ { header: fmt.Sprintf("BEARER %s", makeToken("iss", "CLIENT_ID", "", now, tomorrow)), wantClient: "CLIENT_ID", wantErr: false, }, { header: fmt.Sprintf("BEARER %s", makeToken("iss", "", "", now, tomorrow)), wantErr: true, }, } for i, tt := range tests { req := &http.Request{ Header: http.Header{ "Authorization": []string{tt.header}, }, } gotClient, err := getClientIDFromAuthorizedRequest(req) if tt.wantErr { if err == nil { t.Errorf("case %d: want non-nil err", i) } continue } if err != nil { t.Errorf("case %d: got err: %q", i, err) continue } if gotClient != tt.wantClient { t.Errorf("case %d: want=%v, got=%v", i, tt.wantClient, gotClient) } } }
func makeTestFixtures() (*UserEmailer, *testEmailer, *key.PublicKey) { ur := user.NewUserRepoFromUsers([]user.UserWithRemoteIdentities{ { User: user.User{ ID: "ID-1", Email: "*****@*****.**", Admin: true, }, }, { User: user.User{ ID: "ID-2", Email: "*****@*****.**", }, }, { User: user.User{ ID: "ID-3", Email: "*****@*****.**", }, }, }) pwr := user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{ { UserID: "ID-1", Password: []byte("password-1"), }, { UserID: "ID-2", Password: []byte("password-2"), }, }) privKey, err := key.GeneratePrivateKey() if err != nil { panic(fmt.Sprintf("Failed to generate private key, error=%v", err)) } publicKey := key.NewPublicKey(privKey.JWK()) signer := privKey.Signer() signerFn := func() (jose.Signer, error) { return signer, nil } textTemplateString := `{{define "password-reset.txt"}}{{.link}}{{end}} {{define "verify-email.txt"}}{{.link}}{{end}}"` textTemplates := template.New("text") _, err = textTemplates.Parse(textTemplateString) if err != nil { panic(fmt.Sprintf("error parsing text templates: %v", err)) } htmlTemplates := htmltemplate.New("html") emailer := &testEmailer{} tEmailer := email.NewTemplatizedEmailerFromTemplates(textTemplates, htmlTemplates, emailer) userEmailer := NewUserEmailer(ur, pwr, signerFn, validityWindow, issuerURL, tEmailer, fromAddress, passwordResetURL, verifyEmailURL, acceptInvitationURL) return userEmailer, emailer, publicKey }
func (cfg *SingleServerConfig) Configure(srv *Server) error { k, err := key.GeneratePrivateKey() if err != nil { return err } ks := key.NewPrivateKeySet([]*key.PrivateKey{k}, time.Now().Add(24*time.Hour)) kRepo := key.NewPrivateKeySetRepo() if err = kRepo.Set(ks); err != nil { return err } cf, err := os.Open(cfg.ClientsFile) if err != nil { return fmt.Errorf("unable to read clients from file %s: %v", cfg.ClientsFile, err) } defer cf.Close() ciRepo, err := client.NewClientIdentityRepoFromReader(cf) if err != nil { return fmt.Errorf("unable to read client identities from file %s: %v", cfg.ClientsFile, err) } f, err := os.Open(cfg.ConnectorsFile) if err != nil { return fmt.Errorf("opening connectors file: %v", err) } defer f.Close() cfgs, err := connector.ReadConfigs(f) if err != nil { return fmt.Errorf("decoding connector configs: %v", err) } cfgRepo := connector.NewConnectorConfigRepoFromConfigs(cfgs) sRepo := session.NewSessionRepo() skRepo := session.NewSessionKeyRepo() sm := session.NewSessionManager(sRepo, skRepo) userRepo, err := user.NewUserRepoFromFile(cfg.UsersFile) if err != nil { return fmt.Errorf("unable to read users from file: %v", err) } pwiRepo := user.NewPasswordInfoRepo() refTokRepo := refresh.NewRefreshTokenRepo() txnFactory := repo.InMemTransactionFactory userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, manager.ManagerOptions{}) srv.ClientIdentityRepo = ciRepo srv.KeySetRepo = kRepo srv.ConnectorConfigRepo = cfgRepo srv.UserRepo = userRepo srv.UserManager = userManager srv.PasswordInfoRepo = pwiRepo srv.SessionManager = sm srv.RefreshTokenRepo = refTokRepo return nil }
// NewOIDCProvider provides a bare minimum OIDC IdP Server useful for testing. func NewOIDCProvider(t *testing.T) *OIDCProvider { privKey, err := key.GeneratePrivateKey() if err != nil { t.Fatalf("Cannot create OIDC Provider: %v", err) return nil } op := &OIDCProvider{ Mux: http.NewServeMux(), PrivKey: privKey, } op.Mux.HandleFunc("/.well-known/openid-configuration", op.handleConfig) op.Mux.HandleFunc("/keys", op.handleKeys) return op }
func mockServer(cis []oidc.ClientIdentity) (*server.Server, error) { k, err := key.GeneratePrivateKey() if err != nil { return nil, fmt.Errorf("Unable to generate private key: %v", err) } km := key.NewPrivateKeyManager() err = km.Set(key.NewPrivateKeySet([]*key.PrivateKey{k}, time.Now().Add(time.Minute))) if err != nil { return nil, err } sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) srv := &server.Server{ IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, KeyManager: km, ClientIdentityRepo: client.NewClientIdentityRepo(cis), SessionManager: sm, } return srv, nil }
func mockServer(cis []client.Client) (*server.Server, error) { dbMap := db.NewMemDB() k, err := key.GeneratePrivateKey() if err != nil { return nil, fmt.Errorf("Unable to generate private key: %v", err) } km := key.NewPrivateKeyManager() err = km.Set(key.NewPrivateKeySet([]*key.PrivateKey{k}, time.Now().Add(time.Minute))) if err != nil { return nil, err } clientIDGenerator := func(hostport string) (string, error) { return hostport, nil } secGen := func() ([]byte, error) { return []byte("secret"), nil } clientRepo := db.NewClientRepo(dbMap) clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbMap), cis, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen}) if err != nil { return nil, err } sm := manager.NewSessionManager(db.NewSessionRepo(dbMap), db.NewSessionKeyRepo(dbMap)) srv := &server.Server{ IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, KeyManager: km, ClientRepo: clientRepo, ClientManager: clientManager, SessionManager: sm, } return srv, nil }
func TestHTTPExchangeTokenRefreshToken(t *testing.T) { password, err := user.NewPasswordFromPlaintext("woof") if err != nil { t.Fatalf("unexpectd error: %q", err) } passwordInfo := user.PasswordInfo{ UserID: "elroy77", Password: password, } cfg := &connector.LocalConnectorConfig{ ID: "local", } validRedirURL := url.URL{ Scheme: "http", Host: "client.example.com", Path: "/callback", } ci := client.Client{ Credentials: oidc.ClientCredentials{ ID: validRedirURL.Host, Secret: base64.URLEncoding.EncodeToString([]byte("secret")), }, Metadata: oidc.ClientMetadata{ RedirectURIs: []url.URL{ validRedirURL, }, }, } dbMap := db.NewMemDB() clientRepo, clientManager, err := makeClientRepoAndManager(dbMap, []client.LoadableClient{{ Client: ci, }}) if err != nil { t.Fatalf("Failed to create client identity manager: " + err.Error()) } passwordInfoRepo, err := db.NewPasswordInfoRepoFromPasswordInfos(db.NewMemDB(), []user.PasswordInfo{passwordInfo}) if err != nil { t.Fatalf("Failed to create password info repo: %v", err) } issuerURL := url.URL{Scheme: "http", Host: "server.example.com"} sm := manager.NewSessionManager(db.NewSessionRepo(dbMap), db.NewSessionKeyRepo(dbMap)) k, err := key.GeneratePrivateKey() if err != nil { t.Fatalf("Unable to generate RSA key: %v", err) } km := key.NewPrivateKeyManager() err = km.Set(key.NewPrivateKeySet([]*key.PrivateKey{k}, time.Now().Add(time.Minute))) if err != nil { t.Fatalf("Unexpected error: %v", err) } usr := user.User{ ID: "ID-test", Email: "*****@*****.**", DisplayName: "displayname", } userRepo := db.NewUserRepo(db.NewMemDB()) if err := userRepo.Create(nil, usr); err != nil { t.Fatalf("Unexpected error: %v", err) } refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo() srv := &server.Server{ IssuerURL: issuerURL, KeyManager: km, SessionManager: sm, ClientRepo: clientRepo, ClientManager: clientManager, Templates: template.New(connector.LoginPageTemplateName), Connectors: []connector.Connector{}, UserRepo: userRepo, PasswordInfoRepo: passwordInfoRepo, RefreshTokenRepo: refreshTokenRepo, } if err = srv.AddConnector(cfg); err != nil { t.Fatalf("Unexpected error: %v", err) } sClient := &phttp.HandlerClient{Handler: srv.HTTPHandler()} pcfg, err := oidc.FetchProviderConfig(sClient, issuerURL.String()) if err != nil { t.Fatalf("Failed to fetch provider config: %v", err) } ks := key.NewPublicKeySet([]jose.JWK{k.JWK()}, time.Now().Add(1*time.Hour)) ccfg := oidc.ClientConfig{ HTTPClient: sClient, ProviderConfig: pcfg, Credentials: ci.Credentials, RedirectURL: validRedirURL.String(), KeySet: *ks, } cl, err := oidc.NewClient(ccfg) if err != nil { t.Fatalf("Failed creating oidc.Client: %v", err) } m := http.NewServeMux() var claims jose.Claims var refresh string m.HandleFunc("/callback", handleCallbackFunc(cl, &claims, &refresh)) cClient := &phttp.HandlerClient{Handler: m} // this will actually happen due to some interaction between the // end-user and a remote identity provider sessionID, err := sm.NewSession("bogus_idpc", ci.Credentials.ID, "bogus", url.URL{}, "", false, []string{"openid", "offline_access", "email", "profile"}) if err != nil { t.Fatalf("Unexpected error: %v", err) } if _, err = sm.AttachRemoteIdentity(sessionID, passwordInfo.Identity()); err != nil { t.Fatalf("Unexpected error: %v", err) } if _, err = sm.AttachUser(sessionID, usr.ID); err != nil { t.Fatalf("Unexpected error: %v", err) } key, err := sm.NewSessionKey(sessionID) if err != nil { t.Fatalf("Unexpected error: %v", err) } req, err := http.NewRequest("GET", fmt.Sprintf("http://client.example.com/callback?code=%s", key), nil) if err != nil { t.Fatalf("Failed creating HTTP request: %v", err) } resp, err := cClient.Do(req) if err != nil { t.Fatalf("Failed resolving HTTP requests against /callback: %v", err) } if err := verifyUserClaims(claims, &ci, &usr, issuerURL); err != nil { t.Fatalf("Failed to verify claims: %v", err) } if resp.StatusCode != http.StatusOK { t.Fatalf("Received status code %d, want %d", resp.StatusCode, http.StatusOK) } if refresh == "" { t.Fatalf("No refresh token") } // Use refresh token to get a new ID token. token, err := cl.RefreshToken(refresh) if err != nil { t.Fatalf("Unexpected error: %v", err) } claims, err = token.Claims() if err != nil { t.Fatalf("Failed parsing claims from client token: %v", err) } if err := verifyUserClaims(claims, &ci, &usr, issuerURL); err != nil { t.Fatalf("Failed to verify claims: %v", err) } }
func TestInvitationParseAndVerify(t *testing.T) { issuer, _ := url.Parse("http://example.com") notIssuer, _ := url.Parse("http://other.com") client := "myclient" user := User{ID: "1234", Email: "*****@*****.**"} callback, _ := url.Parse("http://client.example.com") expires := time.Hour * 3 password := Password("Halloween is the best holiday") privKey, _ := key.GeneratePrivateKey() signer := privKey.Signer() publicKeys := []key.PublicKey{*key.NewPublicKey(privKey.JWK())} tests := []struct { invite Invitation wantErr bool signer jose.Signer }{ { invite: NewInvitation(user, password, *issuer, client, *callback, expires), signer: signer, wantErr: false, }, { invite: NewInvitation(user, password, *issuer, client, *callback, expires), signer: signer, wantErr: false, }, { invite: NewInvitation(user, password, *issuer, client, *callback, -expires), signer: signer, wantErr: true, }, { invite: NewInvitation(user, password, *notIssuer, client, *callback, expires), signer: signer, wantErr: true, }, { invite: NewInvitation(User{Email: "*****@*****.**"}, password, *issuer, client, *callback, expires), signer: signer, wantErr: true, }, { invite: NewInvitation(User{ID: "JONNY_NO_EMAIL"}, password, *issuer, client, *callback, expires), signer: signer, wantErr: true, }, { invite: NewInvitation(user, Password(""), *issuer, client, *callback, expires), signer: signer, wantErr: true, }, { invite: NewInvitation(user, password, *issuer, "", *callback, expires), signer: signer, wantErr: true, }, { invite: NewInvitation(user, password, *issuer, "", url.URL{}, expires), signer: signer, wantErr: true, }, } for i, tt := range tests { jwt, err := jose.NewSignedJWT(tt.invite.Claims, tt.signer) if err != nil { t.Fatalf("case %d: failed to generate JWT, error: %v", i, err) } token := jwt.Encode() parsed, err := ParseAndVerifyInvitationToken(token, *issuer, publicKeys) if tt.wantErr { if err == nil { t.Errorf("case %d: want no-nil error, got nil", i) } continue } if err != nil { t.Errorf("case %d: unexpected error: %v", i, err) continue } if diff := pretty.Compare(tt.invite, parsed); diff != "" { t.Errorf("case %d: Compare(want, got): %v", i, diff) } } }
func (cfg *SingleServerConfig) Configure(srv *Server) error { k, err := key.GeneratePrivateKey() if err != nil { return err } dbMap := db.NewMemDB() ks := key.NewPrivateKeySet([]*key.PrivateKey{k}, time.Now().Add(24*time.Hour)) kRepo := key.NewPrivateKeySetRepo() if err = kRepo.Set(ks); err != nil { return err } clients, err := loadClients(cfg.ClientsFile) if err != nil { return fmt.Errorf("unable to read clients from file %s: %v", cfg.ClientsFile, err) } clientRepo, err := db.NewClientRepoFromClients(dbMap, clients) if err != nil { return err } f, err := os.Open(cfg.ConnectorsFile) if err != nil { return fmt.Errorf("opening connectors file: %v", err) } defer f.Close() cfgs, err := connector.ReadConfigs(f) if err != nil { return fmt.Errorf("decoding connector configs: %v", err) } cfgRepo := db.NewConnectorConfigRepo(dbMap) if err := cfgRepo.Set(cfgs); err != nil { return fmt.Errorf("failed to set connectors: %v", err) } sRepo := db.NewSessionRepo(dbMap) skRepo := db.NewSessionKeyRepo(dbMap) sm := sessionmanager.NewSessionManager(sRepo, skRepo) users, pwis, err := loadUsers(cfg.UsersFile) if err != nil { return fmt.Errorf("unable to read users from file: %v", err) } userRepo, err := db.NewUserRepoFromUsers(dbMap, users) if err != nil { return err } pwiRepo, err := db.NewPasswordInfoRepoFromPasswordInfos(dbMap, pwis) if err != nil { return err } refTokRepo := db.NewRefreshTokenRepo(dbMap) txnFactory := db.TransactionFactory(dbMap) userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, usermanager.ManagerOptions{}) clientManager := clientmanager.NewClientManager(clientRepo, db.TransactionFactory(dbMap), clientmanager.ManagerOptions{}) if err != nil { return fmt.Errorf("Failed to create client identity manager: %v", err) } srv.ClientRepo = clientRepo srv.ClientManager = clientManager srv.KeySetRepo = kRepo srv.ConnectorConfigRepo = cfgRepo srv.UserRepo = userRepo srv.UserManager = userManager srv.PasswordInfoRepo = pwiRepo srv.SessionManager = sm srv.RefreshTokenRepo = refTokRepo srv.HealthChecks = append(srv.HealthChecks, db.NewHealthChecker(dbMap)) srv.dbMap = dbMap return nil }
func TestInvitationHandler(t *testing.T) { invUserID := "ID-1" invVerifiedID := "ID-Verified" invGoodSigner := key.NewPrivateKeySet([]*key.PrivateKey{testPrivKey}, time.Now().Add(time.Minute)).Active().Signer() badKey, err := key.GeneratePrivateKey() if err != nil { panic(fmt.Sprintf("couldn't make new key: %q", err)) } invBadSigner := key.NewPrivateKeySet([]*key.PrivateKey{badKey}, time.Now().Add(time.Minute)).Active().Signer() makeInvitationToken := func(password, userID, clientID, email string, callback url.URL, expires time.Duration, signer jose.Signer) string { iv := user.NewInvitation( user.User{ID: userID, Email: email}, user.Password(password), testIssuerURL, clientID, callback, expires) jwt, err := jose.NewSignedJWT(iv.Claims, signer) if err != nil { t.Fatalf("couldn't make token: %q", err) } token := jwt.Encode() return token } tests := []struct { userID string query url.Values signer jose.Signer wantCode int wantCallback url.URL wantEmailVerified bool }{ { // Case 0 Happy Path userID: invUserID, query: url.Values{ "token": []string{makeInvitationToken("password", invUserID, testClientID, "*****@*****.**", testRedirectURL, time.Hour*1, invGoodSigner)}, }, signer: invGoodSigner, wantCode: http.StatusSeeOther, wantCallback: testRedirectURL, wantEmailVerified: true, }, { // Case 1 user already verified userID: invVerifiedID, query: url.Values{ "token": []string{makeInvitationToken("password", invVerifiedID, testClientID, "*****@*****.**", testRedirectURL, time.Hour*1, invGoodSigner)}, }, signer: invGoodSigner, wantCode: http.StatusSeeOther, wantCallback: testRedirectURL, wantEmailVerified: true, }, { // Case 2 bad email userID: invUserID, query: url.Values{ "token": []string{makeInvitationToken("password", invVerifiedID, testClientID, "*****@*****.**", testRedirectURL, time.Hour*1, invGoodSigner)}, }, signer: invGoodSigner, wantCode: http.StatusBadRequest, wantCallback: testRedirectURL, wantEmailVerified: false, }, { // Case 3 bad signer userID: invUserID, query: url.Values{ "token": []string{makeInvitationToken("password", invUserID, testClientID, "*****@*****.**", testRedirectURL, time.Hour*1, invBadSigner)}, }, signer: invGoodSigner, wantCode: http.StatusBadRequest, wantCallback: testRedirectURL, wantEmailVerified: false, }, } for i, tt := range tests { f, err := makeTestFixtures() if err != nil { t.Fatalf("case %d: could not make test fixtures: %v", i, err) } keys, err := f.srv.KeyManager.PublicKeys() if err != nil { t.Fatalf("case %d: test fixture key infrastructure is broken: %v", i, err) } tZero := clock.Now() handler := &InvitationHandler{ passwordResetURL: f.srv.absURL("RESETME"), issuerURL: testIssuerURL, um: f.srv.UserManager, keysFunc: f.srv.KeyManager.PublicKeys, signerFunc: func() (jose.Signer, error) { return tt.signer, nil }, redirectValidityWindow: 100 * time.Second, } w := httptest.NewRecorder() u := testIssuerURL u.RawQuery = tt.query.Encode() req, err := http.NewRequest("GET", u.String(), nil) if err != nil { t.Fatalf("case %d: impossible error: %v", i, err) } handler.ServeHTTP(w, req) if tt.wantCode != w.Code { t.Errorf("case %d: wantCode=%v, got=%v", i, tt.wantCode, w.Code) continue } usr, err := f.srv.UserManager.Get(tt.userID) if err != nil { t.Fatalf("case %d: unexpected error: %v", i, err) } if usr.EmailVerified != tt.wantEmailVerified { t.Errorf("case %d: wantEmailVerified=%v got=%v", i, tt.wantEmailVerified, usr.EmailVerified) } if w.Code == http.StatusSeeOther { locString := w.HeaderMap.Get("Location") loc, err := url.Parse(locString) if err != nil { t.Fatalf("case %d: redirect returned nonsense url: '%v', %v", i, locString, err) } pwrToken := loc.Query().Get("token") pwrReset, err := user.ParseAndVerifyPasswordResetToken(pwrToken, testIssuerURL, keys) if err != nil { t.Errorf("case %d: password token is invalid: %v", i, err) } expTime := pwrReset.Claims["exp"].(float64) if expTime > float64(tZero.Add(handler.redirectValidityWindow).Unix()) || expTime < float64(tZero.Unix()) { t.Errorf("case %d: funny expiration time detected: %d", i, pwrReset.Claims["exp"]) } if pwrReset.Claims["aud"] != testClientID { t.Errorf("case %d: wanted \"aud\"=%v got=%v", i, testClientID, pwrReset.Claims["aud"]) } if pwrReset.Claims["iss"] != testIssuerURL.String() { t.Errorf("case %d: wanted \"iss\"=%v got=%v", i, testIssuerURL, pwrReset.Claims["iss"]) } if pwrReset.UserID() != tt.userID { t.Errorf("case %d: wanted UserID=%v got=%v", i, tt.userID, pwrReset.UserID()) } if bytes.Compare(pwrReset.Password(), user.Password("password")) != 0 { t.Errorf("case %d: wanted Password=%v got=%v", i, user.Password("password"), pwrReset.Password()) } if *pwrReset.Callback() != testRedirectURL { t.Errorf("case %d: wanted callback=%v got=%v", i, testRedirectURL, pwrReset.Callback()) } } } }
func TestHandleVerifyEmailResend(t *testing.T) { now := time.Now() tomorrow := now.Add(24 * time.Hour) yesterday := now.Add(-24 * time.Hour) privKey, err := key.GeneratePrivateKey() if err != nil { t.Fatalf("Failed to generate private key, error=%v", err) } signer := privKey.Signer() pubKey := *key.NewPublicKey(privKey.JWK()) keysFunc := func() ([]key.PublicKey, error) { return []key.PublicKey{pubKey}, nil } makeToken := func(iss, sub, aud string, iat, exp time.Time) string { claims := oidc.NewClaims(iss, sub, aud, iat, exp) jwt, err := jose.NewSignedJWT(claims, signer) if err != nil { t.Fatalf("Failed to generate JWT, error=%v", err) } return jwt.Encode() } tests := []struct { bearerJWT string userJWT string redirectURL url.URL wantCode int verifyEmailUserID string }{ { // The happy case bearerJWT: makeToken(testIssuerURL.String(), testClientID, testClientID, now, tomorrow), userJWT: makeToken(testIssuerURL.String(), "ID-1", testClientID, now, tomorrow), redirectURL: testRedirectURL, wantCode: http.StatusOK, }, { // Already verified bearerJWT: makeToken(testIssuerURL.String(), testClientID, testClientID, now, tomorrow), userJWT: makeToken(testIssuerURL.String(), "ID-1", testClientID, now, tomorrow), redirectURL: testRedirectURL, wantCode: http.StatusBadRequest, verifyEmailUserID: "ID-1", }, { // Expired userJWT bearerJWT: makeToken(testIssuerURL.String(), testClientID, testClientID, now, tomorrow), userJWT: makeToken(testIssuerURL.String(), "ID-1", testClientID, now, yesterday), redirectURL: testRedirectURL, wantCode: http.StatusUnauthorized, }, { // Client ID is unknown bearerJWT: makeToken(testIssuerURL.String(), "fakeclientid", testClientID, now, tomorrow), userJWT: makeToken(testIssuerURL.String(), "ID-1", testClientID, now, tomorrow), redirectURL: testRedirectURL, wantCode: http.StatusBadRequest, }, { // No sub in user JWT bearerJWT: makeToken(testIssuerURL.String(), testClientID, testClientID, now, tomorrow), userJWT: makeToken(testIssuerURL.String(), "", testClientID, now, tomorrow), redirectURL: testRedirectURL, wantCode: http.StatusBadRequest, }, { // Unknown user bearerJWT: makeToken(testIssuerURL.String(), testClientID, testClientID, now, tomorrow), userJWT: makeToken(testIssuerURL.String(), "NonExistent", testClientID, now, tomorrow), redirectURL: testRedirectURL, wantCode: http.StatusBadRequest, }, { // No redirect URL bearerJWT: makeToken(testIssuerURL.String(), testClientID, testClientID, now, tomorrow), userJWT: makeToken(testIssuerURL.String(), "ID-1", testClientID, now, tomorrow), redirectURL: url.URL{}, wantCode: http.StatusBadRequest, }, } for i, tt := range tests { f, err := makeTestFixtures() if tt.verifyEmailUserID != "" { usr, _ := f.userRepo.Get(nil, tt.verifyEmailUserID) usr.EmailVerified = true f.userRepo.Update(nil, usr) } if err != nil { t.Fatalf("case %d: could not make test fixtures: %v", i, err) } hdlr := handleVerifyEmailResendFunc( testIssuerURL, keysFunc, f.srv.UserEmailer, f.userRepo, f.clientManager) w := httptest.NewRecorder() u := "http://example.com" q := struct { Token string `json:"token"` RedirectURI string `json:"redirectURI"` }{ Token: tt.userJWT, RedirectURI: tt.redirectURL.String(), } qBytes, err := json.Marshal(&q) if err != nil { t.Errorf("case %d: unable to marshal JSON: %q", i, err) } req, err := http.NewRequest("POST", u, bytes.NewReader(qBytes)) req.Header.Set("Authorization", "Bearer "+tt.bearerJWT) if err != nil { t.Errorf("case %d: unable to form HTTP request: %v", i, err) } hdlr.ServeHTTP(w, req) if tt.wantCode != w.Code { t.Errorf("case %d: wantCode=%v, got=%v", i, tt.wantCode, w.Code) t.Logf("case %d: response body was: %v", i, w.Body.String()) } } }
func TestVerifyJWTExpiry(t *testing.T) { privKey, err := key.GeneratePrivateKey() if err != nil { t.Fatalf("can't generate private key: %v", err) } makeToken := func(s string, exp time.Time, count int) *jose.JWT { jwt, err := jose.NewSignedJWT(jose.Claims(map[string]interface{}{ "test": s, "exp": exp.UTC().Unix(), "count": count, }), privKey.Signer()) if err != nil { t.Fatalf("Could not create signed JWT %v", err) } return jwt } t0 := time.Now() tests := []struct { name string jwt *jose.JWT now time.Time wantErr bool wantExpired bool }{ { name: "valid jwt", jwt: makeToken("foo", t0.Add(time.Hour), 1), now: t0, }, { name: "invalid jwt", jwt: &jose.JWT{}, now: t0, wantErr: true, }, { name: "expired jwt", jwt: makeToken("foo", t0.Add(-time.Hour), 1), now: t0, wantExpired: true, }, { name: "jwt expires soon enough to be marked expired", jwt: makeToken("foo", t0, 1), now: t0, wantExpired: true, }, } for _, tc := range tests { func() { valid, err := verifyJWTExpiry(tc.now, tc.jwt.Encode()) if err != nil { if !tc.wantErr { t.Errorf("%s: %v", tc.name, err) } return } if tc.wantErr { t.Errorf("%s: expected error", tc.name) return } if valid && tc.wantExpired { t.Errorf("%s: expected token to be expired", tc.name) } if !valid && !tc.wantExpired { t.Errorf("%s: expected token to be valid", tc.name) } }() } }
func TestWrapTranport(t *testing.T) { oldBackoff := backoff defer func() { backoff = oldBackoff }() backoff = wait.Backoff{ Duration: 1 * time.Nanosecond, Steps: 3, } privKey, err := key.GeneratePrivateKey() if err != nil { t.Fatalf("can't generate private key: %v", err) } makeToken := func(s string, exp time.Time, count int) *jose.JWT { jwt, err := jose.NewSignedJWT(jose.Claims(map[string]interface{}{ "test": s, "exp": exp.UTC().Unix(), "count": count, }), privKey.Signer()) if err != nil { t.Fatalf("Could not create signed JWT %v", err) } return jwt } goodToken := makeToken("good", time.Now().Add(time.Hour), 0) goodToken2 := makeToken("good", time.Now().Add(time.Hour), 1) expiredToken := makeToken("good", time.Now().Add(-time.Hour), 0) str := func(s string) *string { return &s } tests := []struct { cfgIDToken *jose.JWT cfgRefreshToken *string expectRequests []testRoundTrip expectRefreshes []testRefresh expectPersists []testPersist wantStatus int wantErr bool }{ { // Initial JWT is set, it is good, it is set as bearer. cfgIDToken: goodToken, expectRequests: []testRoundTrip{ { expectBearerToken: goodToken.Encode(), returnHTTPStatus: 200, }, }, wantStatus: 200, }, { // Initial JWT is set, but it's expired, so it gets refreshed. cfgIDToken: expiredToken, cfgRefreshToken: str("rt1"), expectRefreshes: []testRefresh{ { expectRefreshToken: "rt1", returnTokens: oauth2.TokenResponse{ IDToken: goodToken.Encode(), }, }, }, expectRequests: []testRoundTrip{ { expectBearerToken: goodToken.Encode(), returnHTTPStatus: 200, }, }, expectPersists: []testPersist{ { cfg: map[string]string{ cfgIDToken: goodToken.Encode(), cfgRefreshToken: "rt1", }, }, }, wantStatus: 200, }, { // Initial JWT is set, but it's expired, so it gets refreshed - this // time the refresh token itself is also refreshed cfgIDToken: expiredToken, cfgRefreshToken: str("rt1"), expectRefreshes: []testRefresh{ { expectRefreshToken: "rt1", returnTokens: oauth2.TokenResponse{ IDToken: goodToken.Encode(), RefreshToken: "rt2", }, }, }, expectRequests: []testRoundTrip{ { expectBearerToken: goodToken.Encode(), returnHTTPStatus: 200, }, }, expectPersists: []testPersist{ { cfg: map[string]string{ cfgIDToken: goodToken.Encode(), cfgRefreshToken: "rt2", }, }, }, wantStatus: 200, }, { // Initial JWT is not set, so it gets refreshed. cfgRefreshToken: str("rt1"), expectRefreshes: []testRefresh{ { expectRefreshToken: "rt1", returnTokens: oauth2.TokenResponse{ IDToken: goodToken.Encode(), }, }, }, expectRequests: []testRoundTrip{ { expectBearerToken: goodToken.Encode(), returnHTTPStatus: 200, }, }, expectPersists: []testPersist{ { cfg: map[string]string{ cfgIDToken: goodToken.Encode(), cfgRefreshToken: "rt1", }, }, }, wantStatus: 200, }, { // Expired token, but no refresh token. cfgIDToken: expiredToken, wantErr: true, }, { // Initial JWT is not set, so it gets refreshed, but the server // rejects it when it is used, so it refreshes again, which // succeeds. cfgRefreshToken: str("rt1"), expectRefreshes: []testRefresh{ { expectRefreshToken: "rt1", returnTokens: oauth2.TokenResponse{ IDToken: goodToken.Encode(), }, }, { expectRefreshToken: "rt1", returnTokens: oauth2.TokenResponse{ IDToken: goodToken2.Encode(), }, }, }, expectRequests: []testRoundTrip{ { expectBearerToken: goodToken.Encode(), returnHTTPStatus: http.StatusUnauthorized, }, { expectBearerToken: goodToken2.Encode(), returnHTTPStatus: http.StatusOK, }, }, expectPersists: []testPersist{ { cfg: map[string]string{ cfgIDToken: goodToken.Encode(), cfgRefreshToken: "rt1", }, }, { cfg: map[string]string{ cfgIDToken: goodToken2.Encode(), cfgRefreshToken: "rt1", }, }, }, wantStatus: 200, }, { // Initial JWT is but the server rejects it when it is used, so it // refreshes again, which succeeds. cfgRefreshToken: str("rt1"), cfgIDToken: goodToken, expectRefreshes: []testRefresh{ { expectRefreshToken: "rt1", returnTokens: oauth2.TokenResponse{ IDToken: goodToken2.Encode(), }, }, }, expectRequests: []testRoundTrip{ { expectBearerToken: goodToken.Encode(), returnHTTPStatus: http.StatusUnauthorized, }, { expectBearerToken: goodToken2.Encode(), returnHTTPStatus: http.StatusOK, }, }, expectPersists: []testPersist{ { cfg: map[string]string{ cfgIDToken: goodToken2.Encode(), cfgRefreshToken: "rt1", }, }, }, wantStatus: 200, }, } for i, tt := range tests { client := &testOIDCClient{ refreshes: tt.expectRefreshes, } persister := &testPersister{ tt.expectPersists, } cfg := map[string]string{} if tt.cfgIDToken != nil { cfg[cfgIDToken] = tt.cfgIDToken.Encode() } if tt.cfgRefreshToken != nil { cfg[cfgRefreshToken] = *tt.cfgRefreshToken } ap := &oidcAuthProvider{ refresher: &idTokenRefresher{ client: client, cfg: cfg, persister: persister, }, } if tt.cfgIDToken != nil { ap.initialIDToken = *tt.cfgIDToken } tstRT := &testRoundTripper{ tt.expectRequests, } rt := ap.WrapTransport(tstRT) req, err := http.NewRequest("GET", "http://cluster.example.com", nil) if err != nil { t.Errorf("case %d: unexpected error making request: %v", i, err) } res, err := rt.RoundTrip(req) if tt.wantErr { if err == nil { t.Errorf("case %d: Expected non-nil error", i) } } else if err != nil { t.Errorf("case %d: unexpected error making round trip: %v", i, err) } else { if res.StatusCode != tt.wantStatus { t.Errorf("case %d: want=%d, got=%d", i, tt.wantStatus, res.StatusCode) } } if err = client.verify(); err != nil { t.Errorf("case %d: %v", i, err) } if err = persister.verify(); err != nil { t.Errorf("case %d: %v", i, err) } if err = tstRT.verify(); err != nil { t.Errorf("case %d: %v", i, err) continue } } }
"github.com/coreos/dex/connector" "github.com/coreos/dex/db" "github.com/coreos/dex/user" "github.com/coreos/dex/user/manager" ) var ( clock = clockwork.NewFakeClock() testIssuerURL = url.URL{Scheme: "https", Host: "auth.example.com"} testClientID = "client.example.com" testClientSecret = base64.URLEncoding.EncodeToString([]byte("secret")) testRedirectURL = url.URL{Scheme: "https", Host: "client.example.com", Path: "/redirect"} testBadRedirectURL = url.URL{Scheme: "https", Host: "bad.example.com", Path: "/redirect"} testResetPasswordURL = url.URL{Scheme: "https", Host: "auth.example.com", Path: "/resetPassword"} testPrivKey, _ = key.GeneratePrivateKey() ) type tokenHandlerTransport struct { Handler http.Handler Token string } func (t *tokenHandlerTransport) RoundTrip(r *http.Request) (*http.Response, error) { r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", t.Token)) w := httptest.NewRecorder() t.Handler.ServeHTTP(w, r) resp := http.Response{ StatusCode: w.Code, Header: w.Header(), Body: ioutil.NopCloser(w.Body),
func TestEmailVerificationParseAndVerify(t *testing.T) { issuer, _ := url.Parse("http://example.com") otherIssuer, _ := url.Parse("http://bad.example.com") client := "myclient" user := User{ID: "1234", Email: "*****@*****.**"} callback, _ := url.Parse("http://client.example.com") expires := time.Hour * 3 goodEV := NewEmailVerification(user, client, *issuer, *callback, expires) expiredEV := NewEmailVerification(user, client, *issuer, *callback, -expires) wrongIssuerEV := NewEmailVerification(user, client, *otherIssuer, *callback, expires) noSubEV := NewEmailVerification(User{}, client, *issuer, *callback, expires) privKey, err := key.GeneratePrivateKey() if err != nil { t.Fatalf("Failed to generate private key, error=%v", err) } signer := privKey.Signer() privKey2, err := key.GeneratePrivateKey() if err != nil { t.Fatalf("Failed to generate private key, error=%v", err) } otherSigner := privKey2.Signer() tests := []struct { ev EmailVerification wantErr bool signer jose.Signer }{ { ev: goodEV, signer: signer, wantErr: false, }, { ev: expiredEV, signer: signer, wantErr: true, }, { ev: wrongIssuerEV, signer: signer, wantErr: true, }, { ev: goodEV, signer: otherSigner, wantErr: true, }, { ev: noSubEV, signer: signer, wantErr: true, }, } for i, tt := range tests { jwt, err := jose.NewSignedJWT(tt.ev.Claims, tt.signer) if err != nil { t.Fatalf("Failed to generate JWT, error=%v", err) } token := jwt.Encode() ev, err := ParseAndVerifyEmailVerificationToken(token, *issuer, []key.PublicKey{*key.NewPublicKey(privKey.JWK())}) if tt.wantErr { t.Logf("err: %v", err) if err == nil { t.Errorf("case %d: want non-nil err, got nil", i) } continue } if err != nil { t.Errorf("case %d: non-nil err: %q", i, err) } if diff := pretty.Compare(tt.ev.Claims, ev.Claims); diff != "" { t.Errorf("case %d: Compare(want, got): %v", i, diff) } } }
func TestJWTVerifier(t *testing.T) { iss := "http://example.com" now := time.Now() future12 := now.Add(12 * time.Hour) past36 := now.Add(-36 * time.Hour) past12 := now.Add(-12 * time.Hour) priv1, err := key.GeneratePrivateKey() if err != nil { t.Fatalf("failed to generate private key, error=%v", err) } pk1 := *key.NewPublicKey(priv1.JWK()) priv2, err := key.GeneratePrivateKey() if err != nil { t.Fatalf("failed to generate private key, error=%v", err) } pk2 := *key.NewPublicKey(priv2.JWK()) jwtPK1, err := jose.NewSignedJWT(NewClaims(iss, "XXX", "XXX", past12, future12), priv1.Signer()) if err != nil { t.Fatalf("unexpected error: %v", err) } jwtPK1BadClaims, err := jose.NewSignedJWT(NewClaims(iss, "XXX", "YYY", past12, future12), priv1.Signer()) if err != nil { t.Fatalf("unexpected error: %v", err) } jwtExpired, err := jose.NewSignedJWT(NewClaims(iss, "XXX", "XXX", past36, past12), priv1.Signer()) if err != nil { t.Fatalf("unexpected error: %v", err) } jwtPK2, err := jose.NewSignedJWT(NewClaims(iss, "XXX", "XXX", past12, future12), priv2.Signer()) if err != nil { t.Fatalf("unexpected error: %v", err) } tests := []struct { verifier JWTVerifier jwt jose.JWT wantErr bool }{ // JWT signed with available key { verifier: JWTVerifier{ issuer: "example.com", clientID: "XXX", syncFunc: func() error { return nil }, keysFunc: func() []key.PublicKey { return []key.PublicKey{pk1} }, }, jwt: *jwtPK1, wantErr: false, }, // JWT signed with available key, with bad claims { verifier: JWTVerifier{ issuer: "example.com", clientID: "XXX", syncFunc: func() error { return nil }, keysFunc: func() []key.PublicKey { return []key.PublicKey{pk1} }, }, jwt: *jwtPK1BadClaims, wantErr: true, }, // expired JWT signed with available key { verifier: JWTVerifier{ issuer: "example.com", clientID: "XXX", syncFunc: func() error { return nil }, keysFunc: func() []key.PublicKey { return []key.PublicKey{pk1} }, }, jwt: *jwtExpired, wantErr: true, }, // JWT signed with unrecognized key, verifiable after sync { verifier: JWTVerifier{ issuer: "example.com", clientID: "XXX", syncFunc: func() error { return nil }, keysFunc: func() func() []key.PublicKey { var i int return func() []key.PublicKey { defer func() { i++ }() return [][]key.PublicKey{ []key.PublicKey{pk1}, []key.PublicKey{pk2}, }[i] } }(), }, jwt: *jwtPK2, wantErr: false, }, // JWT signed with unrecognized key, not verifiable after sync { verifier: JWTVerifier{ issuer: "example.com", clientID: "XXX", syncFunc: func() error { return nil }, keysFunc: func() []key.PublicKey { return []key.PublicKey{pk1} }, }, jwt: *jwtPK2, wantErr: true, }, // verifier gets no keys from keysFunc, still not verifiable after sync { verifier: JWTVerifier{ issuer: "example.com", clientID: "XXX", syncFunc: func() error { return nil }, keysFunc: func() []key.PublicKey { return []key.PublicKey{} }, }, jwt: *jwtPK1, wantErr: true, }, // verifier gets no keys from keysFunc, verifiable after sync { verifier: JWTVerifier{ issuer: "example.com", clientID: "XXX", syncFunc: func() error { return nil }, keysFunc: func() func() []key.PublicKey { var i int return func() []key.PublicKey { defer func() { i++ }() return [][]key.PublicKey{ []key.PublicKey{}, []key.PublicKey{pk2}, }[i] } }(), }, jwt: *jwtPK2, wantErr: false, }, } for i, tt := range tests { err := tt.verifier.Verify(tt.jwt) if tt.wantErr && (err == nil) { t.Errorf("case %d: wanted non-nil error", i) } else if !tt.wantErr && (err != nil) { t.Errorf("case %d: wanted nil error, got %v", i, err) } } }
func TestDBPrivateKeySetRepoSetGet(t *testing.T) { s1 := []byte("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx") s2 := []byte("oooooooooooooooooooooooooooooooo") s3 := []byte("wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww") keys := []*key.PrivateKey{} for i := 0; i < 2; i++ { k, err := key.GeneratePrivateKey() if err != nil { t.Fatalf("Unable to generate RSA key: %v", err) } keys = append(keys, k) } ks := key.NewPrivateKeySet( []*key.PrivateKey{keys[0], keys[1]}, time.Now().Add(time.Minute)) tests := []struct { setSecrets [][]byte getSecrets [][]byte wantErr bool }{ { // same secrets used to encrypt, decrypt setSecrets: [][]byte{s1, s2}, getSecrets: [][]byte{s1, s2}, }, { // setSecrets got rotated, but getSecrets didn't yet. setSecrets: [][]byte{s2, s3}, getSecrets: [][]byte{s1, s2}, }, { // getSecrets doesn't have s3 setSecrets: [][]byte{s3}, getSecrets: [][]byte{s1, s2}, wantErr: true, }, } for i, tt := range tests { dbMap := connect(t) setRepo, err := db.NewPrivateKeySetRepo(dbMap, false, tt.setSecrets...) if err != nil { t.Fatalf(err.Error()) } getRepo, err := db.NewPrivateKeySetRepo(dbMap, false, tt.getSecrets...) if err != nil { t.Fatalf(err.Error()) } if err := setRepo.Set(ks); err != nil { t.Fatalf("case %d: Unexpected error: %v", i, err) } got, err := getRepo.Get() if tt.wantErr { if err == nil { t.Errorf("case %d: want err, got nil", i) } continue } if err != nil { t.Fatalf("case %d: Unexpected error: %v", i, err) } if diff := pretty.Compare(ks, got); diff != "" { t.Fatalf("case %d:Retrieved incorrect KeySet: Compare(want,got): %v", i, diff) } } }
func TestClientToken(t *testing.T) { now := time.Now() tomorrow := now.Add(24 * time.Hour) validClientID := "valid-client" ci := oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{ ID: validClientID, Secret: base64.URLEncoding.EncodeToString([]byte("secret")), }, Metadata: oidc.ClientMetadata{ RedirectURIs: []url.URL{ {Scheme: "https", Host: "authn.example.com", Path: "/callback"}, }, }, } repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ci}) if err != nil { t.Fatalf("Failed to create client identity repo: %v", err) } privKey, err := key.GeneratePrivateKey() if err != nil { t.Fatalf("Failed to generate private key, error=%v", err) } signer := privKey.Signer() pubKey := *key.NewPublicKey(privKey.JWK()) validIss := "https://example.com" makeToken := func(iss, sub, aud string, iat, exp time.Time) string { claims := oidc.NewClaims(iss, sub, aud, iat, exp) jwt, err := jose.NewSignedJWT(claims, signer) if err != nil { t.Fatalf("Failed to generate JWT, error=%v", err) } return jwt.Encode() } validJWT := makeToken(validIss, validClientID, validClientID, now, tomorrow) invalidJWT := makeToken("", "", "", now, tomorrow) tests := []struct { keys []key.PublicKey repo client.ClientIdentityRepo header string wantCode int }{ // valid token { keys: []key.PublicKey{pubKey}, repo: repo, header: fmt.Sprintf("BEARER %s", validJWT), wantCode: http.StatusOK, }, // invalid token { keys: []key.PublicKey{pubKey}, repo: repo, header: fmt.Sprintf("BEARER %s", invalidJWT), wantCode: http.StatusUnauthorized, }, // empty header { keys: []key.PublicKey{pubKey}, repo: repo, header: "", wantCode: http.StatusUnauthorized, }, // unparsable token { keys: []key.PublicKey{pubKey}, repo: repo, header: "BEARER xxx", wantCode: http.StatusUnauthorized, }, // no verification keys { keys: []key.PublicKey{}, repo: repo, header: fmt.Sprintf("BEARER %s", validJWT), wantCode: http.StatusUnauthorized, }, // nil repo { keys: []key.PublicKey{pubKey}, repo: nil, header: fmt.Sprintf("BEARER %s", validJWT), wantCode: http.StatusUnauthorized, }, // empty repo { keys: []key.PublicKey{pubKey}, repo: db.NewClientIdentityRepo(db.NewMemDB()), header: fmt.Sprintf("BEARER %s", validJWT), wantCode: http.StatusUnauthorized, }, // client not in repo { keys: []key.PublicKey{pubKey}, repo: repo, header: fmt.Sprintf("BEARER %s", makeToken(validIss, "DOESNT-EXIST", "DOESNT-EXIST", now, tomorrow)), wantCode: http.StatusUnauthorized, }, } for i, tt := range tests { w := httptest.NewRecorder() mw := &clientTokenMiddleware{ issuerURL: validIss, ciRepo: tt.repo, keysFunc: func() ([]key.PublicKey, error) { return tt.keys, nil }, next: staticHandler{}, } req := &http.Request{ Header: http.Header{ "Authorization": []string{tt.header}, }, } mw.ServeHTTP(w, req) if tt.wantCode != w.Code { t.Errorf("case %d: invalid response code, want=%d, got=%d", i, tt.wantCode, w.Code) } } }
func TestJWTVerifier(t *testing.T) { iss := "http://example.com" now := time.Now() future12 := now.Add(12 * time.Hour) past36 := now.Add(-36 * time.Hour) past12 := now.Add(-12 * time.Hour) priv1, err := key.GeneratePrivateKey() if err != nil { t.Fatalf("failed to generate private key, error=%v", err) } pk1 := *key.NewPublicKey(priv1.JWK()) priv2, err := key.GeneratePrivateKey() if err != nil { t.Fatalf("failed to generate private key, error=%v", err) } pk2 := *key.NewPublicKey(priv2.JWK()) newJWT := func(issuer, subject string, aud interface{}, issuedAt, exp time.Time, signer jose.Signer) jose.JWT { jwt, err := jose.NewSignedJWT(NewClaims(issuer, subject, aud, issuedAt, exp), signer) if err != nil { t.Fatal(err) } return *jwt } tests := []struct { name string verifier JWTVerifier jwt jose.JWT wantErr bool }{ { name: "JWT signed with available key", verifier: JWTVerifier{ issuer: "example.com", clientID: "XXX", syncFunc: func() error { return nil }, keysFunc: func() []key.PublicKey { return []key.PublicKey{pk1} }, }, jwt: newJWT(iss, "XXX", "XXX", past12, future12, priv1.Signer()), wantErr: false, }, { name: "JWT signed with available key, with bad claims", verifier: JWTVerifier{ issuer: "example.com", clientID: "XXX", syncFunc: func() error { return nil }, keysFunc: func() []key.PublicKey { return []key.PublicKey{pk1} }, }, jwt: newJWT(iss, "XXX", "YYY", past12, future12, priv1.Signer()), wantErr: true, }, { name: "JWT signed with available key", verifier: JWTVerifier{ issuer: "example.com", clientID: "XXX", syncFunc: func() error { return nil }, keysFunc: func() []key.PublicKey { return []key.PublicKey{pk1} }, }, jwt: newJWT(iss, "XXX", []string{"YYY", "ZZZ"}, past12, future12, priv1.Signer()), wantErr: true, }, { name: "expired JWT signed with available key", verifier: JWTVerifier{ issuer: "example.com", clientID: "XXX", syncFunc: func() error { return nil }, keysFunc: func() []key.PublicKey { return []key.PublicKey{pk1} }, }, jwt: newJWT(iss, "XXX", "XXX", past36, past12, priv1.Signer()), wantErr: true, }, { name: "JWT signed with unrecognized key, verifiable after sync", verifier: JWTVerifier{ issuer: "example.com", clientID: "XXX", syncFunc: func() error { return nil }, keysFunc: func() func() []key.PublicKey { var i int return func() []key.PublicKey { defer func() { i++ }() return [][]key.PublicKey{ []key.PublicKey{pk1}, []key.PublicKey{pk2}, }[i] } }(), }, jwt: newJWT(iss, "XXX", "XXX", past36, future12, priv2.Signer()), wantErr: false, }, { name: "JWT signed with unrecognized key, not verifiable after sync", verifier: JWTVerifier{ issuer: "example.com", clientID: "XXX", syncFunc: func() error { return nil }, keysFunc: func() []key.PublicKey { return []key.PublicKey{pk1} }, }, jwt: newJWT(iss, "XXX", "XXX", past12, future12, priv2.Signer()), wantErr: true, }, { name: "verifier gets no keys from keysFunc, still not verifiable after sync", verifier: JWTVerifier{ issuer: "example.com", clientID: "XXX", syncFunc: func() error { return nil }, keysFunc: func() []key.PublicKey { return []key.PublicKey{} }, }, jwt: newJWT(iss, "XXX", "XXX", past12, future12, priv1.Signer()), wantErr: true, }, { name: "verifier gets no keys from keysFunc, verifiable after sync", verifier: JWTVerifier{ issuer: "example.com", clientID: "XXX", syncFunc: func() error { return nil }, keysFunc: func() func() []key.PublicKey { var i int return func() []key.PublicKey { defer func() { i++ }() return [][]key.PublicKey{ []key.PublicKey{}, []key.PublicKey{pk2}, }[i] } }(), }, jwt: newJWT(iss, "XXX", "XXX", past12, future12, priv2.Signer()), wantErr: false, }, { name: "JWT signed with available key, 'aud' is a string array", verifier: JWTVerifier{ issuer: "example.com", clientID: "XXX", syncFunc: func() error { return nil }, keysFunc: func() []key.PublicKey { return []key.PublicKey{pk1} }, }, jwt: newJWT(iss, "XXX", []string{"ZZZ", "XXX"}, past12, future12, priv1.Signer()), wantErr: false, }, { name: "invalid issuer claim shouldn't trigger sync", verifier: JWTVerifier{ issuer: "example.com", clientID: "XXX", syncFunc: func() error { t.Errorf("invalid issuer claim shouldn't trigger a sync") return nil }, keysFunc: func() func() []key.PublicKey { var i int return func() []key.PublicKey { defer func() { i++ }() return [][]key.PublicKey{ []key.PublicKey{}, []key.PublicKey{pk2}, }[i] } }(), }, jwt: newJWT("invalid-issuer", "XXX", []string{"ZZZ", "XXX"}, past12, future12, priv2.Signer()), wantErr: true, }, } for _, tt := range tests { err := tt.verifier.Verify(tt.jwt) if tt.wantErr && (err == nil) { t.Errorf("case %q: wanted non-nil error", tt.name) } else if !tt.wantErr && (err != nil) { t.Errorf("case %q: wanted nil error, got %v", tt.name, err) } } }
func TestPasswordResetParseAndVerify(t *testing.T) { issuer, _ := url.Parse("http://example.com") otherIssuer, _ := url.Parse("http://bad.example.com") client := "myclient" user := User{ID: "1234", Email: "*****@*****.**"} callback, _ := url.Parse("http://client.example.com") expires := time.Hour * 3 password := Password("passy") goodPR := NewPasswordReset(user, password, *issuer, client, *callback, expires) goodPRNoCB := NewPasswordReset(user, password, *issuer, "", url.URL{}, expires) expiredPR := NewPasswordReset(user, password, *issuer, client, *callback, -expires) wrongIssuerPR := NewPasswordReset(user, password, *otherIssuer, client, *callback, expires) noSubPR := NewPasswordReset(User{}, password, *issuer, client, *callback, expires) noPWPR := NewPasswordReset(user, Password(""), *issuer, client, *callback, expires) noClientPR := NewPasswordReset(user, password, *issuer, "", *callback, expires) privKey, err := key.GeneratePrivateKey() if err != nil { t.Fatalf("Failed to generate private key, error=%v", err) } signer := privKey.Signer() privKey2, err := key.GeneratePrivateKey() if err != nil { t.Fatalf("Failed to generate private key, error=%v", err) } otherSigner := privKey2.Signer() tests := []struct { ev PasswordReset wantErr bool signer jose.Signer }{ { ev: goodPR, signer: signer, wantErr: false, }, { ev: goodPRNoCB, signer: signer, wantErr: false, }, { ev: expiredPR, signer: signer, wantErr: true, }, { ev: wrongIssuerPR, signer: signer, wantErr: true, }, { ev: goodPR, signer: otherSigner, wantErr: true, }, { ev: noSubPR, signer: signer, wantErr: true, }, { ev: noPWPR, signer: signer, wantErr: true, }, { ev: noClientPR, signer: signer, wantErr: true, }, } for i, tt := range tests { token, err := tt.ev.Token(tt.signer) if err != nil { t.Errorf("case %d: non-nil error creating token: %v", i, err) } ev, err := ParseAndVerifyPasswordResetToken(token, *issuer, []key.PublicKey{*key.NewPublicKey(privKey.JWK())}) if tt.wantErr { t.Logf("err: %v", err) if err == nil { t.Errorf("case %d: want non-nil err, got nil", i) } continue } if err != nil { t.Errorf("case %d: non-nil err: %q", i, err) } if diff := pretty.Compare(tt.ev.claims, ev.claims); diff != "" { t.Errorf("case %d: Compare(want, got): %v", i, diff) } } }
func TestResetPasswordHandler(t *testing.T) { makeToken := func(userID, password, clientID string, callback url.URL, expires time.Duration, signer jose.Signer) string { pr := user.NewPasswordReset("ID-1", user.Password(password), testIssuerURL, clientID, callback, expires) jwt, err := jose.NewSignedJWT(pr.Claims, signer) if err != nil { t.Fatalf("couldn't make token: %q", err) } token := jwt.Encode() return token } goodSigner := key.NewPrivateKeySet([]*key.PrivateKey{testPrivKey}, time.Now().Add(time.Minute)).Active().Signer() badKey, err := key.GeneratePrivateKey() if err != nil { t.Fatalf("couldn't make new key: %q", err) } badSigner := key.NewPrivateKeySet([]*key.PrivateKey{badKey}, time.Now().Add(time.Minute)).Active().Signer() str := func(s string) []string { return []string{s} } user.PasswordHasher = func(s string) ([]byte, error) { return []byte(strings.ToUpper(s)), nil } defer func() { user.PasswordHasher = user.DefaultPasswordHasher }() tokenForCase := map[int]string{ 0: makeToken("ID-1", "password", testClientID, testRedirectURL, time.Hour*1, goodSigner), 2: makeToken("ID-1", "password", testClientID, url.URL{}, time.Hour*1, goodSigner), 5: makeToken("ID-1", "password", testClientID, url.URL{}, time.Hour*1, goodSigner), } tests := []struct { query url.Values method string wantFormValues *url.Values wantCode int wantPassword string }{ // Scenario 1: Happy Path { // Case 0 // Step 1.1 - User clicks link in email, has valid token. query: url.Values{ "token": str(tokenForCase[0]), }, method: "GET", wantCode: http.StatusOK, wantFormValues: &url.Values{ "password": str(""), "token": str(tokenForCase[0]), }, wantPassword: "******", }, { // Case 1 // Step 1.2 - User enters in new valid password, password is changed, user is redirected. query: url.Values{ "token": str(makeToken("ID-1", "password", testClientID, testRedirectURL, time.Hour*1, goodSigner)), "password": str("new_password"), }, method: "POST", wantCode: http.StatusSeeOther, wantFormValues: &url.Values{}, wantPassword: "******", }, // Scenario 2: Happy Path, but without redirect. { // Case 2 // Step 2.1 - User clicks link in email, has valid token. query: url.Values{ "token": str(tokenForCase[2]), }, method: "GET", wantCode: http.StatusOK, wantFormValues: &url.Values{ "password": str(""), "token": str(tokenForCase[2]), }, wantPassword: "******", }, { // Case 3 // Step 2.2 - User enters in new valid password, password is changed, user is redirected. query: url.Values{ "token": str(makeToken("ID-1", "password", testClientID, url.URL{}, time.Hour*1, goodSigner)), "password": str("new_password"), }, method: "POST", // no redirect wantCode: http.StatusOK, wantFormValues: &url.Values{}, wantPassword: "******", }, // Errors { // Case 4 // Step 1.1.1 - User clicks link in email, has invalid token. query: url.Values{ "token": str(makeToken("ID-1", "password", testClientID, testRedirectURL, time.Hour*1, badSigner)), }, method: "GET", wantCode: http.StatusBadRequest, wantFormValues: &url.Values{}, wantPassword: "******", }, { // Case 5 // Step 2.2.1 - User enters in new valid password, password is changed, no redirect query: url.Values{ "token": str(tokenForCase[5]), "password": str("shrt"), }, method: "POST", // no redirect wantCode: http.StatusBadRequest, wantFormValues: &url.Values{ "password": str(""), "token": str(tokenForCase[5]), }, wantPassword: "******", }, { // Case 6 // Step 2.2.2 - User enters in new valid password, with suspicious token. query: url.Values{ "token": str(makeToken("ID-1", "password", testClientID, url.URL{}, time.Hour*1, badSigner)), "password": str("shrt"), }, method: "POST", // no redirect wantCode: http.StatusBadRequest, wantFormValues: &url.Values{}, wantPassword: "******", }, { // Case 7 // Token lacking client id query: url.Values{ "token": str(makeToken("ID-1", "password", "", url.URL{}, time.Hour*1, goodSigner)), "password": str("shrt"), }, method: "GET", wantCode: http.StatusBadRequest, wantPassword: "******", }, { // Case 8 // Token lacking client id query: url.Values{ "token": str(makeToken("ID-1", "password", "", url.URL{}, time.Hour*1, goodSigner)), "password": str("shrt"), }, method: "POST", wantCode: http.StatusBadRequest, wantPassword: "******", }, } for i, tt := range tests { f, err := makeTestFixtures() if err != nil { t.Fatalf("case %d: could not make test fixtures: %v", i, err) } hdlr := ResetPasswordHandler{ tpl: f.srv.ResetPasswordTemplate, issuerURL: testIssuerURL, um: f.srv.UserManager, keysFunc: f.srv.KeyManager.PublicKeys, } w := httptest.NewRecorder() var req *http.Request u := testIssuerURL u.Path = httpPathResetPassword if tt.method == "POST" { req, err = http.NewRequest(tt.method, u.String(), strings.NewReader(tt.query.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") } else { u.RawQuery = tt.query.Encode() req, err = http.NewRequest(tt.method, u.String(), nil) } if err != nil { t.Errorf("case %d: unable to form HTTP request: %v", i, err) } hdlr.ServeHTTP(w, req) if tt.wantCode != w.Code { t.Errorf("case %d: wantCode=%v, got=%v", i, tt.wantCode, w.Code) continue } values, err := html.FormValues("#resetPasswordForm", bytes.NewReader(w.Body.Bytes())) if err != nil { t.Errorf("case %d: could not parse form: %v", i, err) } if tt.wantFormValues != nil { if diff := pretty.Compare(*tt.wantFormValues, values); diff != "" { t.Errorf("case %d: Compare(wantFormValues, got) = %v", i, diff) } } pwi, err := f.srv.PasswordInfoRepo.Get(nil, "ID-1") if err != nil { t.Errorf("case %d: Error getting Password info: %v", i, err) } if tt.wantPassword != string(pwi.Password) { t.Errorf("case %d: wantPassword=%v, got=%v", i, tt.wantPassword, string(pwi.Password)) } } }
func TestClientToken(t *testing.T) { now := time.Now() tomorrow := now.Add(24 * time.Hour) clientMetadata := oidc.ClientMetadata{ RedirectURIs: []url.URL{ {Scheme: "https", Host: "authn.example.com", Path: "/callback"}, }, } dbm := db.NewMemDB() clientRepo := db.NewClientRepo(dbm) clientManager := clientmanager.NewClientManager(clientRepo, db.TransactionFactory(dbm), clientmanager.ManagerOptions{}) cli := client.Client{ Metadata: clientMetadata, } creds, err := clientManager.New(cli) if err != nil { t.Fatalf("Failed to create client: %v", err) } validClientID := creds.ID privKey, err := key.GeneratePrivateKey() if err != nil { t.Fatalf("Failed to generate private key, error=%v", err) } signer := privKey.Signer() pubKey := *key.NewPublicKey(privKey.JWK()) validIss := "https://example.com" makeToken := func(iss, sub, aud string, iat, exp time.Time) string { claims := oidc.NewClaims(iss, sub, aud, iat, exp) jwt, err := jose.NewSignedJWT(claims, signer) if err != nil { t.Fatalf("Failed to generate JWT, error=%v", err) } return jwt.Encode() } validJWT := makeToken(validIss, validClientID, validClientID, now, tomorrow) invalidJWT := makeToken("", "", "", now, tomorrow) tests := []struct { keys []key.PublicKey manager *clientmanager.ClientManager header string wantCode int }{ // valid token { keys: []key.PublicKey{pubKey}, manager: clientManager, header: fmt.Sprintf("BEARER %s", validJWT), wantCode: http.StatusOK, }, // invalid token { keys: []key.PublicKey{pubKey}, manager: clientManager, header: fmt.Sprintf("BEARER %s", invalidJWT), wantCode: http.StatusUnauthorized, }, // empty header { keys: []key.PublicKey{pubKey}, manager: clientManager, header: "", wantCode: http.StatusUnauthorized, }, // unparsable token { keys: []key.PublicKey{pubKey}, manager: clientManager, header: "BEARER xxx", wantCode: http.StatusUnauthorized, }, // no verification keys { keys: []key.PublicKey{}, manager: clientManager, header: fmt.Sprintf("BEARER %s", validJWT), wantCode: http.StatusUnauthorized, }, // nil repo { keys: []key.PublicKey{pubKey}, manager: nil, header: fmt.Sprintf("BEARER %s", validJWT), wantCode: http.StatusUnauthorized, }, // empty repo { keys: []key.PublicKey{pubKey}, manager: clientmanager.NewClientManager(db.NewClientRepo(db.NewMemDB()), db.TransactionFactory(db.NewMemDB()), clientmanager.ManagerOptions{}), header: fmt.Sprintf("BEARER %s", validJWT), wantCode: http.StatusUnauthorized, }, // client not in repo { keys: []key.PublicKey{pubKey}, manager: clientManager, header: fmt.Sprintf("BEARER %s", makeToken(validIss, "DOESNT-EXIST", "DOESNT-EXIST", now, tomorrow)), wantCode: http.StatusUnauthorized, }, } for i, tt := range tests { w := httptest.NewRecorder() mw := &clientTokenMiddleware{ issuerURL: validIss, ciManager: tt.manager, keysFunc: func() ([]key.PublicKey, error) { return tt.keys, nil }, next: staticHandler{}, } req := &http.Request{ Header: http.Header{ "Authorization": []string{tt.header}, }, } mw.ServeHTTP(w, req) if tt.wantCode != w.Code { t.Errorf("case %d: invalid response code, want=%d, got=%d", i, tt.wantCode, w.Code) } } }