Example #1
0
// TestClientKeysFuncAll exercises the function returned by
// Client.keysFuncAll against key sets that are populated or empty, and
// expired or still valid. Only the populated, non-expired set is expected to
// yield keys; every other combination must produce an empty slice.
func TestClientKeysFuncAll(t *testing.T) {
	keyA, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("failed to generate private key, error=%v", err)
	}

	keyB, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("failed to generate private key, error=%v", err)
	}

	// Expiry instants on either side of the current time.
	now := time.Now()
	notYetExpired := now.Add(time.Hour)
	alreadyExpired := now.Add(-1 * time.Hour)

	type testCase struct {
		keySet *key.PublicKeySet
		want   []key.PublicKey
	}

	cases := []testCase{
		{ // two keys, non-expired set
			keySet: key.NewPublicKeySet([]jose.JWK{keyB.JWK(), keyA.JWK()}, notYetExpired),
			want:   []key.PublicKey{*key.NewPublicKey(keyB.JWK()), *key.NewPublicKey(keyA.JWK())},
		},
		{ // no keys, non-expired set
			keySet: key.NewPublicKeySet([]jose.JWK{}, notYetExpired),
			want:   []key.PublicKey{},
		},
		{ // two keys, expired set
			keySet: key.NewPublicKeySet([]jose.JWK{keyB.JWK(), keyA.JWK()}, alreadyExpired),
			want:   []key.PublicKey{},
		},
		{ // no keys, expired set
			keySet: key.NewPublicKeySet([]jose.JWK{}, alreadyExpired),
			want:   []key.PublicKey{},
		},
	}

	for idx, tc := range cases {
		client := Client{keySet: *tc.keySet}
		got := client.keysFuncAll()()
		if reflect.DeepEqual(tc.want, got) {
			continue
		}
		t.Errorf("case %d: want=%#v got=%#v", idx, tc.want, got)
	}
}
Example #2
0
// mockServer builds a server.Server for tests. It generates a fresh private
// key, stores it in a key manager as a set that expires one minute from now,
// and creates the client repo/manager from cis plus a session manager, all
// sharing one in-memory database.
func mockServer(cis []client.LoadableClient) (*server.Server, error) {
	memDB := db.NewMemDB()

	privKey, err := key.GeneratePrivateKey()
	if err != nil {
		return nil, fmt.Errorf("Unable to generate private key: %v", err)
	}

	keyMgr := key.NewPrivateKeyManager()
	keySet := key.NewPrivateKeySet([]*key.PrivateKey{privKey}, time.Now().Add(time.Minute))
	if err := keyMgr.Set(keySet); err != nil {
		return nil, err
	}

	repo, mgr, err := makeClientRepoAndManager(memDB, cis)
	if err != nil {
		return nil, err
	}

	sessions := manager.NewSessionManager(db.NewSessionRepo(memDB), db.NewSessionKeyRepo(memDB))

	return &server.Server{
		IssuerURL:      url.URL{Scheme: "http", Host: "server.example.com"},
		KeyManager:     keyMgr,
		ClientRepo:     repo,
		ClientManager:  mgr,
		SessionManager: sessions,
	}, nil
}
Example #3
0
// TestGetClientIDFromAuthorizedRequest verifies that
// getClientIDFromAuthorizedRequest returns the "sub" of a bearer JWT found in
// the Authorization header and reports an error when that subject is empty.
func TestGetClientIDFromAuthorizedRequest(t *testing.T) {
	issuedAt := time.Now()
	expiresAt := issuedAt.Add(24 * time.Hour)

	privKey, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("Failed to generate private key, error=%v", err)
	}
	signer := privKey.Signer()

	// bearerFor signs a token carrying the given subject and formats it as
	// an Authorization header value.
	bearerFor := func(sub string) string {
		claims := oidc.NewClaims("iss", sub, "", issuedAt, expiresAt)
		jwt, err := jose.NewSignedJWT(claims, signer)
		if err != nil {
			t.Fatalf("Failed to generate JWT, error=%v", err)
		}
		return fmt.Sprintf("BEARER %s", jwt.Encode())
	}

	type testCase struct {
		header     string
		wantClient string
		wantErr    bool
	}

	cases := []testCase{
		{
			header:     bearerFor("CLIENT_ID"),
			wantClient: "CLIENT_ID",
			wantErr:    false,
		},
		{
			header:  bearerFor(""),
			wantErr: true,
		},
	}

	for idx, tc := range cases {
		req := &http.Request{
			Header: http.Header{"Authorization": []string{tc.header}},
		}

		gotClient, err := getClientIDFromAuthorizedRequest(req)
		switch {
		case tc.wantErr && err == nil:
			t.Errorf("case %d: want non-nil err", idx)
		case tc.wantErr:
			// The expected failure occurred; nothing more to check.
		case err != nil:
			t.Errorf("case %d: got err: %q", idx, err)
		case gotClient != tc.wantClient:
			t.Errorf("case %d: want=%v, got=%v", idx, tc.wantClient, gotClient)
		}
	}
}
Example #4
0
// makeTestFixtures builds the objects shared by the emailer tests: a
// UserEmailer wired to three users (two of which have passwords), the
// testEmailer handed to the templatized emailer, and the public half of the
// key whose signer the UserEmailer is given. It panics if key generation or
// template parsing fails.
func makeTestFixtures() (*UserEmailer, *testEmailer, *key.PublicKey) {
	users := []user.UserWithRemoteIdentities{
		{User: user.User{ID: "ID-1", Email: "*****@*****.**", Admin: true}},
		{User: user.User{ID: "ID-2", Email: "*****@*****.**"}},
		{User: user.User{ID: "ID-3", Email: "*****@*****.**"}},
	}
	userRepo := user.NewUserRepoFromUsers(users)

	passwords := []user.PasswordInfo{
		{UserID: "ID-1", Password: []byte("password-1")},
		{UserID: "ID-2", Password: []byte("password-2")},
	}
	passwordRepo := user.NewPasswordInfoRepoFromPasswordInfos(passwords)

	privKey, err := key.GeneratePrivateKey()
	if err != nil {
		panic(fmt.Sprintf("Failed to generate private key, error=%v", err))
	}
	publicKey := key.NewPublicKey(privKey.JWK())

	// Every call to signerFn hands back the same signer instance.
	signer := privKey.Signer()
	signerFn := func() (jose.Signer, error) {
		return signer, nil
	}

	// NOTE(review): the double quote after the final {{end}} looks
	// unintentional; it is preserved here exactly as found.
	textTemplateString := `{{define "password-reset.txt"}}{{.link}}{{end}}
{{define "verify-email.txt"}}{{.link}}{{end}}"`
	textTemplates := template.New("text")
	if _, err = textTemplates.Parse(textTemplateString); err != nil {
		panic(fmt.Sprintf("error parsing text templates: %v", err))
	}

	// No HTML templates are defined.
	htmlTemplates := htmltemplate.New("html")

	emailer := &testEmailer{}
	tEmailer := email.NewTemplatizedEmailerFromTemplates(textTemplates, htmlTemplates, emailer)

	return NewUserEmailer(
		userRepo,
		passwordRepo,
		signerFn,
		validityWindow,
		issuerURL,
		tEmailer,
		fromAddress,
		passwordResetURL,
		verifyEmailURL,
		acceptInvitationURL,
	), emailer, publicKey
}
Example #5
0
// Configure populates srv with repositories and managers for single-server
// operation. It generates a private key valid for 24 hours, reads client
// identities from cfg.ClientsFile, connector configs from
// cfg.ConnectorsFile and users from cfg.UsersFile, and assigns the results
// to srv only after every step has succeeded. The first failing step's error
// is returned.
func (cfg *SingleServerConfig) Configure(srv *Server) error {
	privKey, err := key.GeneratePrivateKey()
	if err != nil {
		return err
	}

	keySet := key.NewPrivateKeySet([]*key.PrivateKey{privKey}, time.Now().Add(24*time.Hour))
	keySetRepo := key.NewPrivateKeySetRepo()
	if err := keySetRepo.Set(keySet); err != nil {
		return err
	}

	// Client identities.
	clientsFile, err := os.Open(cfg.ClientsFile)
	if err != nil {
		return fmt.Errorf("unable to read clients from file %s: %v", cfg.ClientsFile, err)
	}
	defer clientsFile.Close()

	clientIdentityRepo, err := client.NewClientIdentityRepoFromReader(clientsFile)
	if err != nil {
		return fmt.Errorf("unable to read client identities from file %s: %v", cfg.ClientsFile, err)
	}

	// Connector configurations.
	connectorsFile, err := os.Open(cfg.ConnectorsFile)
	if err != nil {
		return fmt.Errorf("opening connectors file: %v", err)
	}
	defer connectorsFile.Close()

	connectorCfgs, err := connector.ReadConfigs(connectorsFile)
	if err != nil {
		return fmt.Errorf("decoding connector configs: %v", err)
	}
	connectorCfgRepo := connector.NewConnectorConfigRepoFromConfigs(connectorCfgs)

	// Sessions.
	sessionMgr := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())

	// Users and their credentials.
	userRepo, err := user.NewUserRepoFromFile(cfg.UsersFile)
	if err != nil {
		return fmt.Errorf("unable to read users from file: %v", err)
	}
	passwordInfoRepo := user.NewPasswordInfoRepo()
	refreshTokenRepo := refresh.NewRefreshTokenRepo()

	userMgr := manager.NewUserManager(
		userRepo,
		passwordInfoRepo,
		connectorCfgRepo,
		repo.InMemTransactionFactory,
		manager.ManagerOptions{},
	)

	// Everything succeeded; publish the collaborators on the server.
	srv.ClientIdentityRepo = clientIdentityRepo
	srv.KeySetRepo = keySetRepo
	srv.ConnectorConfigRepo = connectorCfgRepo
	srv.UserRepo = userRepo
	srv.UserManager = userMgr
	srv.PasswordInfoRepo = passwordInfoRepo
	srv.SessionManager = sessionMgr
	srv.RefreshTokenRepo = refreshTokenRepo
	return nil
}
Example #6
0
// NewOIDCProvider provides a bare minimum OIDC IdP Server useful for testing.
//
// It generates a fresh private key, stores it on the returned provider
// together with a new ServeMux, and registers the provider's handleConfig
// and handleKeys handlers on "/.well-known/openid-configuration" and "/keys".
// If key generation fails the test is aborted through t.Fatalf.
func NewOIDCProvider(t *testing.T) *OIDCProvider {
	privKey, err := key.GeneratePrivateKey()
	if err != nil {
		// Fatalf stops the calling goroutine, so nothing after it would run
		// and no explicit return is needed.
		t.Fatalf("Cannot create OIDC Provider: %v", err)
	}

	op := &OIDCProvider{
		Mux:     http.NewServeMux(),
		PrivKey: privKey,
	}

	op.Mux.HandleFunc("/.well-known/openid-configuration", op.handleConfig)
	op.Mux.HandleFunc("/keys", op.handleKeys)

	return op
}
Example #7
0
// mockServer builds a server.Server for tests. It generates a fresh private
// key, stores it in a key manager as a set that expires one minute from now,
// and attaches a client identity repo created from cis together with a new
// session manager.
func mockServer(cis []oidc.ClientIdentity) (*server.Server, error) {
	privKey, err := key.GeneratePrivateKey()
	if err != nil {
		return nil, fmt.Errorf("Unable to generate private key: %v", err)
	}

	keyMgr := key.NewPrivateKeyManager()
	keySet := key.NewPrivateKeySet([]*key.PrivateKey{privKey}, time.Now().Add(time.Minute))
	if err := keyMgr.Set(keySet); err != nil {
		return nil, err
	}

	sessions := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())

	return &server.Server{
		IssuerURL:          url.URL{Scheme: "http", Host: "server.example.com"},
		KeyManager:         keyMgr,
		ClientIdentityRepo: client.NewClientIdentityRepo(cis),
		SessionManager:     sessions,
	}, nil
}
Example #8
0
// mockServer builds a server.Server for tests on top of one in-memory
// database. It generates a fresh private key held in a key manager as a set
// that expires one minute from now, and creates a client manager from cis
// whose ID generator echoes the host:port it is given and whose secret
// generator always returns "secret".
func mockServer(cis []client.Client) (*server.Server, error) {
	memDB := db.NewMemDB()

	privKey, err := key.GeneratePrivateKey()
	if err != nil {
		return nil, fmt.Errorf("Unable to generate private key: %v", err)
	}

	keyMgr := key.NewPrivateKeyManager()
	keySet := key.NewPrivateKeySet([]*key.PrivateKey{privKey}, time.Now().Add(time.Minute))
	if err := keyMgr.Set(keySet); err != nil {
		return nil, err
	}

	// Deterministic generators keep client IDs and secrets predictable.
	opts := clientmanager.ManagerOptions{
		ClientIDGenerator: func(hostport string) (string, error) {
			return hostport, nil
		},
		SecretGenerator: func() ([]byte, error) {
			return []byte("secret"), nil
		},
	}

	repo := db.NewClientRepo(memDB)
	mgr, err := clientmanager.NewClientManagerFromClients(repo, db.TransactionFactory(memDB), cis, opts)
	if err != nil {
		return nil, err
	}

	sessions := manager.NewSessionManager(db.NewSessionRepo(memDB), db.NewSessionKeyRepo(memDB))

	return &server.Server{
		IssuerURL:      url.URL{Scheme: "http", Host: "server.example.com"},
		KeyManager:     keyMgr,
		ClientRepo:     repo,
		ClientManager:  mgr,
		SessionManager: sessions,
	}, nil
}
Example #9
0
// TestHTTPExchangeTokenRefreshToken drives a code exchange against an
// in-process server and then uses the refresh token obtained from it to
// fetch a new ID token. The claims of both tokens are checked with
// verifyUserClaims.
func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
	password, err := user.NewPasswordFromPlaintext("woof")
	if err != nil {
		t.Fatalf("unexpected error: %q", err)
	}

	passwordInfo := user.PasswordInfo{
		UserID:   "elroy77",
		Password: password,
	}

	cfg := &connector.LocalConnectorConfig{
		ID: "local",
	}

	validRedirURL := url.URL{
		Scheme: "http",
		Host:   "client.example.com",
		Path:   "/callback",
	}
	ci := client.Client{
		Credentials: oidc.ClientCredentials{
			ID:     validRedirURL.Host,
			Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
		},
		Metadata: oidc.ClientMetadata{
			RedirectURIs: []url.URL{
				validRedirURL,
			},
		},
	}

	dbMap := db.NewMemDB()
	clientRepo, clientManager, err := makeClientRepoAndManager(dbMap,
		[]client.LoadableClient{{
			Client: ci,
		}})
	if err != nil {
		// Pass the error as an argument: a format string built by
		// concatenation would misrender any '%' in the error text.
		t.Fatalf("Failed to create client identity manager: %v", err)
	}

	passwordInfoRepo, err := db.NewPasswordInfoRepoFromPasswordInfos(db.NewMemDB(), []user.PasswordInfo{passwordInfo})
	if err != nil {
		t.Fatalf("Failed to create password info repo: %v", err)
	}

	issuerURL := url.URL{Scheme: "http", Host: "server.example.com"}
	sm := manager.NewSessionManager(db.NewSessionRepo(dbMap), db.NewSessionKeyRepo(dbMap))

	k, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("Unable to generate RSA key: %v", err)
	}

	km := key.NewPrivateKeyManager()
	err = km.Set(key.NewPrivateKeySet([]*key.PrivateKey{k}, time.Now().Add(time.Minute)))
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	usr := user.User{
		ID:          "ID-test",
		Email:       "*****@*****.**",
		DisplayName: "displayname",
	}
	userRepo := db.NewUserRepo(db.NewMemDB())
	if err := userRepo.Create(nil, usr); err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()

	srv := &server.Server{
		IssuerURL:        issuerURL,
		KeyManager:       km,
		SessionManager:   sm,
		ClientRepo:       clientRepo,
		ClientManager:    clientManager,
		Templates:        template.New(connector.LoginPageTemplateName),
		Connectors:       []connector.Connector{},
		UserRepo:         userRepo,
		PasswordInfoRepo: passwordInfoRepo,
		RefreshTokenRepo: refreshTokenRepo,
	}

	if err = srv.AddConnector(cfg); err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	// The OIDC client talks to the server's handler directly, without a
	// network listener.
	sClient := &phttp.HandlerClient{Handler: srv.HTTPHandler()}
	pcfg, err := oidc.FetchProviderConfig(sClient, issuerURL.String())
	if err != nil {
		t.Fatalf("Failed to fetch provider config: %v", err)
	}

	ks := key.NewPublicKeySet([]jose.JWK{k.JWK()}, time.Now().Add(1*time.Hour))

	ccfg := oidc.ClientConfig{
		HTTPClient:     sClient,
		ProviderConfig: pcfg,
		Credentials:    ci.Credentials,
		RedirectURL:    validRedirURL.String(),
		KeySet:         *ks,
	}

	cl, err := oidc.NewClient(ccfg)
	if err != nil {
		t.Fatalf("Failed creating oidc.Client: %v", err)
	}

	m := http.NewServeMux()

	// Filled in by the callback handler during the exchange below.
	var claims jose.Claims
	var refresh string

	m.HandleFunc("/callback", handleCallbackFunc(cl, &claims, &refresh))
	cClient := &phttp.HandlerClient{Handler: m}

	// this will actually happen due to some interaction between the
	// end-user and a remote identity provider
	sessionID, err := sm.NewSession("bogus_idpc", ci.Credentials.ID, "bogus", url.URL{}, "", false, []string{"openid", "offline_access", "email", "profile"})
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	if _, err = sm.AttachRemoteIdentity(sessionID, passwordInfo.Identity()); err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	if _, err = sm.AttachUser(sessionID, usr.ID); err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	// Named sessionKey rather than key so the imported key package is not
	// shadowed for the remainder of the function.
	sessionKey, err := sm.NewSessionKey(sessionID)
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	req, err := http.NewRequest("GET", fmt.Sprintf("http://client.example.com/callback?code=%s", sessionKey), nil)
	if err != nil {
		t.Fatalf("Failed creating HTTP request: %v", err)
	}

	resp, err := cClient.Do(req)
	if err != nil {
		t.Fatalf("Failed resolving HTTP requests against /callback: %v", err)
	}

	if err := verifyUserClaims(claims, &ci, &usr, issuerURL); err != nil {
		t.Fatalf("Failed to verify claims: %v", err)
	}

	if resp.StatusCode != http.StatusOK {
		t.Fatalf("Received status code %d, want %d", resp.StatusCode, http.StatusOK)
	}

	if refresh == "" {
		t.Fatalf("No refresh token")
	}

	// Use refresh token to get a new ID token.
	token, err := cl.RefreshToken(refresh)
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	claims, err = token.Claims()
	if err != nil {
		t.Fatalf("Failed parsing claims from client token: %v", err)
	}

	if err := verifyUserClaims(claims, &ci, &usr, issuerURL); err != nil {
		t.Fatalf("Failed to verify claims: %v", err)
	}
}
Example #10
0
// TestInvitationParseAndVerify signs a series of invitations and checks that
// ParseAndVerifyInvitationToken accepts the well-formed ones, returning an
// invitation equal to the input, and rejects those that are expired, carry a
// different issuer, or lack a user ID, email, password, client ID or
// callback.
func TestInvitationParseAndVerify(t *testing.T) {
	// mustParseURL aborts the test instead of silently handing back a nil
	// URL that would be dereferenced later.
	mustParseURL := func(raw string) *url.URL {
		u, err := url.Parse(raw)
		if err != nil {
			t.Fatalf("failed to parse URL %q: %v", raw, err)
		}
		return u
	}

	issuer := mustParseURL("http://example.com")
	notIssuer := mustParseURL("http://other.com")
	client := "myclient"
	user := User{ID: "1234", Email: "*****@*****.**"}
	callback := mustParseURL("http://client.example.com")
	expires := time.Hour * 3
	password := Password("Halloween is the best holiday")

	privKey, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("failed to generate private key: %v", err)
	}
	signer := privKey.Signer()
	publicKeys := []key.PublicKey{*key.NewPublicKey(privKey.JWK())}

	tests := []struct {
		invite  Invitation
		wantErr bool
		signer  jose.Signer
	}{
		{
			invite:  NewInvitation(user, password, *issuer, client, *callback, expires),
			signer:  signer,
			wantErr: false,
		},
		{
			invite:  NewInvitation(user, password, *issuer, client, *callback, expires),
			signer:  signer,
			wantErr: false,
		},
		{
			// Negative validity: the invitation is already expired.
			invite:  NewInvitation(user, password, *issuer, client, *callback, -expires),
			signer:  signer,
			wantErr: true,
		},
		{
			// Issued by a different issuer than the one verified against.
			invite:  NewInvitation(user, password, *notIssuer, client, *callback, expires),
			signer:  signer,
			wantErr: true,
		},
		{
			// User without an ID.
			invite:  NewInvitation(User{Email: "*****@*****.**"}, password, *issuer, client, *callback, expires),
			signer:  signer,
			wantErr: true,
		},
		{
			// User without an email.
			invite:  NewInvitation(User{ID: "JONNY_NO_EMAIL"}, password, *issuer, client, *callback, expires),
			signer:  signer,
			wantErr: true,
		},
		{
			// Empty password.
			invite:  NewInvitation(user, Password(""), *issuer, client, *callback, expires),
			signer:  signer,
			wantErr: true,
		},
		{
			// Empty client ID.
			invite:  NewInvitation(user, password, *issuer, "", *callback, expires),
			signer:  signer,
			wantErr: true,
		},
		{
			// Empty client ID and empty callback.
			invite:  NewInvitation(user, password, *issuer, "", url.URL{}, expires),
			signer:  signer,
			wantErr: true,
		},
	}

	for i, tt := range tests {
		jwt, err := jose.NewSignedJWT(tt.invite.Claims, tt.signer)
		if err != nil {
			t.Fatalf("case %d: failed to generate JWT, error: %v", i, err)
		}
		token := jwt.Encode()

		parsed, err := ParseAndVerifyInvitationToken(token, *issuer, publicKeys)

		if tt.wantErr {
			if err == nil {
				t.Errorf("case %d: want non-nil error, got nil", i)
			}
			continue
		}

		if err != nil {
			t.Errorf("case %d: unexpected error: %v", i, err)
			continue
		}

		if diff := pretty.Compare(tt.invite, parsed); diff != "" {
			t.Errorf("case %d: Compare(want, got): %v", i, diff)
		}
	}
}
Example #11
0
// Configure populates srv with repositories and managers that all share one
// freshly created in-memory database. It generates a private key valid for
// 24 hours, loads clients from cfg.ClientsFile, connector configs from
// cfg.ConnectorsFile and users plus password infos from cfg.UsersFile, and
// assigns the results to srv only after every step has succeeded. A health
// checker for the database is appended to srv.HealthChecks. The first
// failing step's error is returned.
func (cfg *SingleServerConfig) Configure(srv *Server) error {
	k, err := key.GeneratePrivateKey()
	if err != nil {
		return err
	}

	dbMap := db.NewMemDB()

	ks := key.NewPrivateKeySet([]*key.PrivateKey{k}, time.Now().Add(24*time.Hour))
	kRepo := key.NewPrivateKeySetRepo()
	if err = kRepo.Set(ks); err != nil {
		return err
	}

	clients, err := loadClients(cfg.ClientsFile)
	if err != nil {
		return fmt.Errorf("unable to read clients from file %s: %v", cfg.ClientsFile, err)
	}

	clientRepo, err := db.NewClientRepoFromClients(dbMap, clients)
	if err != nil {
		return err
	}

	f, err := os.Open(cfg.ConnectorsFile)
	if err != nil {
		return fmt.Errorf("opening connectors file: %v", err)
	}
	defer f.Close()
	cfgs, err := connector.ReadConfigs(f)
	if err != nil {
		return fmt.Errorf("decoding connector configs: %v", err)
	}
	cfgRepo := db.NewConnectorConfigRepo(dbMap)
	if err := cfgRepo.Set(cfgs); err != nil {
		return fmt.Errorf("failed to set connectors: %v", err)
	}

	sRepo := db.NewSessionRepo(dbMap)
	skRepo := db.NewSessionKeyRepo(dbMap)
	sm := sessionmanager.NewSessionManager(sRepo, skRepo)

	users, pwis, err := loadUsers(cfg.UsersFile)
	if err != nil {
		return fmt.Errorf("unable to read users from file: %v", err)
	}
	userRepo, err := db.NewUserRepoFromUsers(dbMap, users)
	if err != nil {
		return err
	}

	pwiRepo, err := db.NewPasswordInfoRepoFromPasswordInfos(dbMap, pwis)
	if err != nil {
		return err
	}

	refTokRepo := db.NewRefreshTokenRepo(dbMap)

	txnFactory := db.TransactionFactory(dbMap)
	userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, usermanager.ManagerOptions{})
	// NewClientManager returns no error, so there is nothing to check here;
	// err is already known to be nil at this point.
	clientManager := clientmanager.NewClientManager(clientRepo, db.TransactionFactory(dbMap), clientmanager.ManagerOptions{})

	srv.ClientRepo = clientRepo
	srv.ClientManager = clientManager
	srv.KeySetRepo = kRepo
	srv.ConnectorConfigRepo = cfgRepo
	srv.UserRepo = userRepo
	srv.UserManager = userManager
	srv.PasswordInfoRepo = pwiRepo
	srv.SessionManager = sm
	srv.RefreshTokenRepo = refTokRepo
	srv.HealthChecks = append(srv.HealthChecks, db.NewHealthChecker(dbMap))
	srv.dbMap = dbMap
	return nil
}
Example #12
0
// TestInvitationHandler sends invitation tokens to an InvitationHandler and
// checks the HTTP status, the user's resulting EmailVerified flag and, for
// redirects, that the Location URL carries a valid password-reset token
// whose claims match the invitation.
func TestInvitationHandler(t *testing.T) {
	invUserID := "ID-1"
	invVerifiedID := "ID-Verified"
	invGoodSigner := key.NewPrivateKeySet([]*key.PrivateKey{testPrivKey},
		time.Now().Add(time.Minute)).Active().Signer()

	badKey, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("couldn't make new key: %q", err)
	}

	// A signer built from a freshly generated key rather than testPrivKey.
	invBadSigner := key.NewPrivateKeySet([]*key.PrivateKey{badKey},
		time.Now().Add(time.Minute)).Active().Signer()

	makeInvitationToken := func(password, userID, clientID, email string, callback url.URL, expires time.Duration, signer jose.Signer) string {
		iv := user.NewInvitation(
			user.User{ID: userID, Email: email},
			user.Password(password),
			testIssuerURL,
			clientID,
			callback,
			expires)

		jwt, err := jose.NewSignedJWT(iv.Claims, signer)
		if err != nil {
			t.Fatalf("couldn't make token: %q", err)
		}
		token := jwt.Encode()
		return token
	}

	tests := []struct {
		userID            string
		query             url.Values
		signer            jose.Signer
		wantCode          int
		wantCallback      url.URL
		wantEmailVerified bool
	}{
		{ // Case 0 Happy Path
			userID: invUserID,
			query: url.Values{
				"token": []string{makeInvitationToken("password", invUserID, testClientID, "*****@*****.**", testRedirectURL, time.Hour*1, invGoodSigner)},
			},
			signer:            invGoodSigner,
			wantCode:          http.StatusSeeOther,
			wantCallback:      testRedirectURL,
			wantEmailVerified: true,
		},
		{ // Case 1 user already verified
			userID: invVerifiedID,
			query: url.Values{
				"token": []string{makeInvitationToken("password", invVerifiedID, testClientID, "*****@*****.**", testRedirectURL, time.Hour*1, invGoodSigner)},
			},
			signer:            invGoodSigner,
			wantCode:          http.StatusSeeOther,
			wantCallback:      testRedirectURL,
			wantEmailVerified: true,
		},
		{ // Case 2 bad email
			userID: invUserID,
			query: url.Values{
				"token": []string{makeInvitationToken("password", invVerifiedID, testClientID, "*****@*****.**", testRedirectURL, time.Hour*1, invGoodSigner)},
			},
			signer:            invGoodSigner,
			wantCode:          http.StatusBadRequest,
			wantCallback:      testRedirectURL,
			wantEmailVerified: false,
		},
		{ // Case 3 bad signer
			userID: invUserID,
			query: url.Values{
				"token": []string{makeInvitationToken("password", invUserID, testClientID, "*****@*****.**", testRedirectURL, time.Hour*1, invBadSigner)},
			},
			signer:            invGoodSigner,
			wantCode:          http.StatusBadRequest,
			wantCallback:      testRedirectURL,
			wantEmailVerified: false,
		},
	}

	for i, tt := range tests {
		f, err := makeTestFixtures()
		if err != nil {
			t.Fatalf("case %d: could not make test fixtures: %v", i, err)
		}

		keys, err := f.srv.KeyManager.PublicKeys()
		if err != nil {
			t.Fatalf("case %d: test fixture key infrastructure is broken: %v", i, err)
		}

		// Reference instant for the expiry-window check further down.
		tZero := clock.Now()
		handler := &InvitationHandler{
			passwordResetURL:       f.srv.absURL("RESETME"),
			issuerURL:              testIssuerURL,
			um:                     f.srv.UserManager,
			keysFunc:               f.srv.KeyManager.PublicKeys,
			signerFunc:             func() (jose.Signer, error) { return tt.signer, nil },
			redirectValidityWindow: 100 * time.Second,
		}

		w := httptest.NewRecorder()
		u := testIssuerURL
		u.RawQuery = tt.query.Encode()
		req, err := http.NewRequest("GET", u.String(), nil)
		if err != nil {
			t.Fatalf("case %d: impossible error: %v", i, err)
		}

		handler.ServeHTTP(w, req)

		if tt.wantCode != w.Code {
			t.Errorf("case %d: wantCode=%v, got=%v", i, tt.wantCode, w.Code)
			continue
		}

		usr, err := f.srv.UserManager.Get(tt.userID)
		if err != nil {
			t.Fatalf("case %d: unexpected error: %v", i, err)
		}

		if usr.EmailVerified != tt.wantEmailVerified {
			t.Errorf("case %d: wantEmailVerified=%v got=%v", i, tt.wantEmailVerified, usr.EmailVerified)
		}

		if w.Code == http.StatusSeeOther {
			locString := w.Header().Get("Location")
			loc, err := url.Parse(locString)
			if err != nil {
				t.Fatalf("case %d: redirect returned nonsense url: '%v', %v", i, locString, err)
			}

			pwrToken := loc.Query().Get("token")
			pwrReset, err := user.ParseAndVerifyPasswordResetToken(pwrToken, testIssuerURL, keys)
			if err != nil {
				// Without a parsed token the claim checks below cannot
				// run, so move on to the next case.
				t.Errorf("case %d: password token is invalid: %v", i, err)
				continue
			}

			// Checked assertion: a missing or non-numeric "exp" is reported
			// as a test failure rather than a panic.
			expTime, ok := pwrReset.Claims["exp"].(float64)
			if !ok {
				t.Errorf("case %d: \"exp\" claim is not a number: %v", i, pwrReset.Claims["exp"])
				continue
			}
			if expTime > float64(tZero.Add(handler.redirectValidityWindow).Unix()) ||
				expTime < float64(tZero.Unix()) {
				t.Errorf("case %d: funny expiration time detected: %v", i, expTime)
			}

			if pwrReset.Claims["aud"] != testClientID {
				t.Errorf("case %d: wanted \"aud\"=%v got=%v", i, testClientID, pwrReset.Claims["aud"])
			}

			if pwrReset.Claims["iss"] != testIssuerURL.String() {
				t.Errorf("case %d: wanted \"iss\"=%v got=%v", i, testIssuerURL, pwrReset.Claims["iss"])
			}

			if pwrReset.UserID() != tt.userID {
				t.Errorf("case %d: wanted UserID=%v got=%v", i, tt.userID, pwrReset.UserID())
			}

			if !bytes.Equal(pwrReset.Password(), user.Password("password")) {
				t.Errorf("case %d: wanted Password=%v got=%v", i, user.Password("password"), pwrReset.Password())
			}

			if *pwrReset.Callback() != testRedirectURL {
				t.Errorf("case %d: wanted callback=%v got=%v", i, testRedirectURL, pwrReset.Callback())
			}
		}
	}
}
Example #13
0
// TestHandleVerifyEmailResend posts verify-email resend requests to the
// handler built by handleVerifyEmailResendFunc and checks the HTTP status
// returned for valid requests and for a range of invalid ones (already
// verified user, expired token, unknown client, missing subject, unknown
// user, missing redirect URL).
func TestHandleVerifyEmailResend(t *testing.T) {
	now := time.Now()
	tomorrow := now.Add(24 * time.Hour)
	yesterday := now.Add(-24 * time.Hour)

	privKey, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("Failed to generate private key, error=%v", err)
	}

	signer := privKey.Signer()

	// The handler verifies tokens against the public half of the key that
	// signs them in makeToken.
	pubKey := *key.NewPublicKey(privKey.JWK())
	keysFunc := func() ([]key.PublicKey, error) {
		return []key.PublicKey{pubKey}, nil
	}

	makeToken := func(iss, sub, aud string, iat, exp time.Time) string {
		claims := oidc.NewClaims(iss, sub, aud, iat, exp)
		jwt, err := jose.NewSignedJWT(claims, signer)
		if err != nil {
			t.Fatalf("Failed to generate JWT, error=%v", err)
		}
		return jwt.Encode()
	}

	tests := []struct {
		bearerJWT         string
		userJWT           string
		redirectURL       url.URL
		wantCode          int
		verifyEmailUserID string
	}{
		{
			// The happy case
			bearerJWT: makeToken(testIssuerURL.String(),
				testClientID, testClientID, now, tomorrow),
			userJWT: makeToken(testIssuerURL.String(),
				"ID-1", testClientID, now, tomorrow),
			redirectURL: testRedirectURL,
			wantCode:    http.StatusOK,
		},
		{
			// Already verified
			bearerJWT: makeToken(testIssuerURL.String(),
				testClientID, testClientID, now, tomorrow),
			userJWT: makeToken(testIssuerURL.String(),
				"ID-1", testClientID, now, tomorrow),
			redirectURL:       testRedirectURL,
			wantCode:          http.StatusBadRequest,
			verifyEmailUserID: "ID-1",
		},
		{
			// Expired userJWT
			bearerJWT: makeToken(testIssuerURL.String(),
				testClientID, testClientID, now, tomorrow),
			userJWT: makeToken(testIssuerURL.String(),
				"ID-1", testClientID, now, yesterday),
			redirectURL: testRedirectURL,
			wantCode:    http.StatusUnauthorized,
		},
		{
			// Client ID is unknown
			bearerJWT: makeToken(testIssuerURL.String(),
				"fakeclientid", testClientID, now, tomorrow),
			userJWT: makeToken(testIssuerURL.String(),
				"ID-1", testClientID, now, tomorrow),
			redirectURL: testRedirectURL,
			wantCode:    http.StatusBadRequest,
		},
		{
			// No sub in user JWT
			bearerJWT: makeToken(testIssuerURL.String(),
				testClientID, testClientID, now, tomorrow),
			userJWT: makeToken(testIssuerURL.String(),
				"", testClientID, now, tomorrow),
			redirectURL: testRedirectURL,
			wantCode:    http.StatusBadRequest,
		},
		{
			// Unknown user
			bearerJWT: makeToken(testIssuerURL.String(),
				testClientID, testClientID, now, tomorrow),
			userJWT: makeToken(testIssuerURL.String(),
				"NonExistent", testClientID, now, tomorrow),
			redirectURL: testRedirectURL,
			wantCode:    http.StatusBadRequest,
		},
		{
			// No redirect URL
			bearerJWT: makeToken(testIssuerURL.String(),
				testClientID, testClientID, now, tomorrow),
			userJWT: makeToken(testIssuerURL.String(),
				"ID-1", testClientID, now, tomorrow),
			redirectURL: url.URL{},
			wantCode:    http.StatusBadRequest,
		},
	}

	for i, tt := range tests {
		// Check the fixture error before touching f.
		f, err := makeTestFixtures()
		if err != nil {
			t.Fatalf("case %d: could not make test fixtures: %v", i, err)
		}

		// Optionally mark a user as verified before the request is sent.
		if tt.verifyEmailUserID != "" {
			usr, err := f.userRepo.Get(nil, tt.verifyEmailUserID)
			if err != nil {
				t.Fatalf("case %d: could not load user %q: %v", i, tt.verifyEmailUserID, err)
			}
			usr.EmailVerified = true
			if err := f.userRepo.Update(nil, usr); err != nil {
				t.Fatalf("case %d: could not update user %q: %v", i, tt.verifyEmailUserID, err)
			}
		}

		hdlr := handleVerifyEmailResendFunc(
			testIssuerURL,
			keysFunc,
			f.srv.UserEmailer,
			f.userRepo,
			f.clientManager)

		w := httptest.NewRecorder()
		u := "http://example.com"
		q := struct {
			Token       string `json:"token"`
			RedirectURI string `json:"redirectURI"`
		}{
			Token:       tt.userJWT,
			RedirectURI: tt.redirectURL.String(),
		}
		qBytes, err := json.Marshal(&q)
		if err != nil {
			t.Fatalf("case %d: unable to marshal JSON: %q", i, err)
		}

		// req is only safe to use once the error has been ruled out.
		req, err := http.NewRequest("POST", u, bytes.NewReader(qBytes))
		if err != nil {
			t.Fatalf("case %d: unable to form HTTP request: %v", i, err)
		}
		req.Header.Set("Authorization", "Bearer "+tt.bearerJWT)

		hdlr.ServeHTTP(w, req)
		if tt.wantCode != w.Code {
			t.Errorf("case %d: wantCode=%v, got=%v", i, tt.wantCode, w.Code)
			t.Logf("case %d: response body was: %v", i, w.Body.String())
		}
	}
}
Example #14
0
// TestVerifyJWTExpiry feeds verifyJWTExpiry tokens that are valid, malformed,
// long expired, or expiring exactly at the supplied "now", and checks the
// returned validity flag and error against expectations.
func TestVerifyJWTExpiry(t *testing.T) {
	privKey, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("can't generate private key: %v", err)
	}

	// signedToken returns a JWT signed with privKey carrying the given
	// "test", "exp" and "count" claims.
	signedToken := func(label string, expiry time.Time, count int) *jose.JWT {
		claims := jose.Claims{
			"test":  label,
			"exp":   expiry.UTC().Unix(),
			"count": count,
		}
		token, err := jose.NewSignedJWT(claims, privKey.Signer())
		if err != nil {
			t.Fatalf("Could not create signed JWT %v", err)
		}
		return token
	}

	start := time.Now()

	type testCase struct {
		name        string
		jwt         *jose.JWT
		now         time.Time
		wantErr     bool
		wantExpired bool
	}

	cases := []testCase{
		{
			name: "valid jwt",
			jwt:  signedToken("foo", start.Add(time.Hour), 1),
			now:  start,
		},
		{
			name:    "invalid jwt",
			jwt:     &jose.JWT{},
			now:     start,
			wantErr: true,
		},
		{
			name:        "expired jwt",
			jwt:         signedToken("foo", start.Add(-time.Hour), 1),
			now:         start,
			wantExpired: true,
		},
		{
			name:        "jwt expires soon enough to be marked expired",
			jwt:         signedToken("foo", start, 1),
			now:         start,
			wantExpired: true,
		},
	}

	for _, c := range cases {
		valid, err := verifyJWTExpiry(c.now, c.jwt.Encode())
		switch {
		case err != nil && !c.wantErr:
			t.Errorf("%s: %v", c.name, err)
		case err != nil:
			// An error was expected and occurred.
		case c.wantErr:
			t.Errorf("%s: expected error", c.name)
		case valid && c.wantExpired:
			t.Errorf("%s: expected token to be expired", c.name)
		case !valid && !c.wantExpired:
			t.Errorf("%s: expected token to be valid", c.name)
		}
	}
}
Example #15
0
// TestWrapTranport drives the round tripper returned by
// oidcAuthProvider.WrapTransport through a table of scenarios. Each scenario
// fixes the initially configured ID/refresh tokens and scripts the refresh
// calls, persisted configs and HTTP round trips that are expected; the
// scripted fakes are verified after the single request has been sent.
//
// The package-level backoff is replaced with a 1ns/3-step backoff for the
// duration of the test and restored on return.
func TestWrapTranport(t *testing.T) {
	oldBackoff := backoff
	defer func() {
		backoff = oldBackoff
	}()
	backoff = wait.Backoff{
		Duration: 1 * time.Nanosecond,
		Steps:    3,
	}

	privKey, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("can't generate private key: %v", err)
	}

	// makeToken signs a JWT carrying the "test", "exp" and "count" claims.
	makeToken := func(s string, exp time.Time, count int) *jose.JWT {
		jwt, err := jose.NewSignedJWT(jose.Claims(map[string]interface{}{
			"test":  s,
			"exp":   exp.UTC().Unix(),
			"count": count,
		}), privKey.Signer())
		if err != nil {
			t.Fatalf("Could not create signed JWT %v", err)
		}
		return jwt
	}

	// goodToken and goodToken2 differ only in their "count" claim, which
	// gives them distinct encodings.
	goodToken := makeToken("good", time.Now().Add(time.Hour), 0)
	goodToken2 := makeToken("good", time.Now().Add(time.Hour), 1)
	expiredToken := makeToken("good", time.Now().Add(-time.Hour), 0)

	str := func(s string) *string {
		return &s
	}
	tests := []struct {
		// Initial provider configuration; nil means "not configured".
		cfgIDToken      *jose.JWT
		cfgRefreshToken *string

		// Scripted interactions, in the order they are expected.
		expectRequests []testRoundTrip

		expectRefreshes []testRefresh

		expectPersists []testPersist

		wantStatus int
		wantErr    bool
	}{
		{
			// Initial JWT is set, it is good, it is set as bearer.
			cfgIDToken: goodToken,

			expectRequests: []testRoundTrip{
				{
					expectBearerToken: goodToken.Encode(),
					returnHTTPStatus:  200,
				},
			},

			wantStatus: 200,
		},
		{
			// Initial JWT is set, but it's expired, so it gets refreshed.
			cfgIDToken:      expiredToken,
			cfgRefreshToken: str("rt1"),

			expectRefreshes: []testRefresh{
				{
					expectRefreshToken: "rt1",
					returnTokens: oauth2.TokenResponse{
						IDToken: goodToken.Encode(),
					},
				},
			},

			expectRequests: []testRoundTrip{
				{
					expectBearerToken: goodToken.Encode(),
					returnHTTPStatus:  200,
				},
			},

			expectPersists: []testPersist{
				{
					cfg: map[string]string{
						cfgIDToken:      goodToken.Encode(),
						cfgRefreshToken: "rt1",
					},
				},
			},

			wantStatus: 200,
		},
		{
			// Initial JWT is set, but it's expired, so it gets refreshed - this
			// time the refresh token itself is also refreshed
			cfgIDToken:      expiredToken,
			cfgRefreshToken: str("rt1"),

			expectRefreshes: []testRefresh{
				{
					expectRefreshToken: "rt1",
					returnTokens: oauth2.TokenResponse{
						IDToken:      goodToken.Encode(),
						RefreshToken: "rt2",
					},
				},
			},

			expectRequests: []testRoundTrip{
				{
					expectBearerToken: goodToken.Encode(),
					returnHTTPStatus:  200,
				},
			},

			expectPersists: []testPersist{
				{
					cfg: map[string]string{
						cfgIDToken:      goodToken.Encode(),
						cfgRefreshToken: "rt2",
					},
				},
			},

			wantStatus: 200,
		},
		{
			// Initial JWT is not set, so it gets refreshed.
			cfgRefreshToken: str("rt1"),

			expectRefreshes: []testRefresh{
				{
					expectRefreshToken: "rt1",
					returnTokens: oauth2.TokenResponse{
						IDToken: goodToken.Encode(),
					},
				},
			},

			expectRequests: []testRoundTrip{
				{
					expectBearerToken: goodToken.Encode(),
					returnHTTPStatus:  200,
				},
			},

			expectPersists: []testPersist{
				{
					cfg: map[string]string{
						cfgIDToken:      goodToken.Encode(),
						cfgRefreshToken: "rt1",
					},
				},
			},

			wantStatus: 200,
		},
		{
			// Expired token, but no refresh token.
			cfgIDToken: expiredToken,

			wantErr: true,
		},
		{
			// Initial JWT is not set, so it gets refreshed, but the server
			// rejects it when it is used, so it refreshes again, which
			// succeeds.
			cfgRefreshToken: str("rt1"),

			expectRefreshes: []testRefresh{
				{
					expectRefreshToken: "rt1",
					returnTokens: oauth2.TokenResponse{
						IDToken: goodToken.Encode(),
					},
				},
				{
					expectRefreshToken: "rt1",
					returnTokens: oauth2.TokenResponse{
						IDToken: goodToken2.Encode(),
					},
				},
			},

			expectRequests: []testRoundTrip{
				{
					expectBearerToken: goodToken.Encode(),
					returnHTTPStatus:  http.StatusUnauthorized,
				},
				{
					expectBearerToken: goodToken2.Encode(),
					returnHTTPStatus:  http.StatusOK,
				},
			},

			expectPersists: []testPersist{
				{
					cfg: map[string]string{
						cfgIDToken:      goodToken.Encode(),
						cfgRefreshToken: "rt1",
					},
				},
				{
					cfg: map[string]string{
						cfgIDToken:      goodToken2.Encode(),
						cfgRefreshToken: "rt1",
					},
				},
			},

			wantStatus: 200,
		},
		{
			// Initial JWT is set and unexpired, but the server rejects it
			// when it is used, so it gets refreshed, which succeeds.
			cfgRefreshToken: str("rt1"),
			cfgIDToken:      goodToken,

			expectRefreshes: []testRefresh{
				{
					expectRefreshToken: "rt1",
					returnTokens: oauth2.TokenResponse{
						IDToken: goodToken2.Encode(),
					},
				},
			},

			expectRequests: []testRoundTrip{
				{
					expectBearerToken: goodToken.Encode(),
					returnHTTPStatus:  http.StatusUnauthorized,
				},
				{
					expectBearerToken: goodToken2.Encode(),
					returnHTTPStatus:  http.StatusOK,
				},
			},

			expectPersists: []testPersist{
				{
					cfg: map[string]string{
						cfgIDToken:      goodToken2.Encode(),
						cfgRefreshToken: "rt1",
					},
				},
			},
			wantStatus: 200,
		},
	}

	for i, tt := range tests {
		client := &testOIDCClient{
			refreshes: tt.expectRefreshes,
		}

		persister := &testPersister{
			tt.expectPersists,
		}

		// Only the tokens the case configures end up in the provider config.
		cfg := map[string]string{}
		if tt.cfgIDToken != nil {
			cfg[cfgIDToken] = tt.cfgIDToken.Encode()
		}

		if tt.cfgRefreshToken != nil {
			cfg[cfgRefreshToken] = *tt.cfgRefreshToken
		}

		ap := &oidcAuthProvider{
			refresher: &idTokenRefresher{
				client:    client,
				cfg:       cfg,
				persister: persister,
			},
		}

		if tt.cfgIDToken != nil {
			ap.initialIDToken = *tt.cfgIDToken
		}

		tstRT := &testRoundTripper{
			tt.expectRequests,
		}

		rt := ap.WrapTransport(tstRT)

		req, err := http.NewRequest("GET", "http://cluster.example.com", nil)
		if err != nil {
			// req is nil on error, so the round trip below cannot proceed.
			t.Fatalf("case %d: unexpected error making request: %v", i, err)
		}

		res, err := rt.RoundTrip(req)
		switch {
		case tt.wantErr:
			if err == nil {
				t.Errorf("case %d: Expected non-nil error", i)
			}
		case err != nil:
			t.Errorf("case %d: unexpected error making round trip: %v", i, err)
		case res.StatusCode != tt.wantStatus:
			t.Errorf("case %d: want=%d, got=%d", i, tt.wantStatus, res.StatusCode)
		}

		// Each fake reports whether its scripted interactions were consumed
		// as expected.
		if err = client.verify(); err != nil {
			t.Errorf("case %d: %v", i, err)
		}

		if err = persister.verify(); err != nil {
			t.Errorf("case %d: %v", i, err)
		}

		if err = tstRT.verify(); err != nil {
			t.Errorf("case %d: %v", i, err)
		}
	}
}
Example #16
0
	"github.com/coreos/dex/connector"
	"github.com/coreos/dex/db"
	"github.com/coreos/dex/user"
	"github.com/coreos/dex/user/manager"
)

// Package-level fixtures shared by the tests in this file.
var (
	// clock is a fake clock created with clockwork.NewFakeClock.
	clock = clockwork.NewFakeClock()

	// Issuer, client and URL fixtures. testClientSecret is the URL-safe
	// base64 encoding of the bytes "secret".
	testIssuerURL        = url.URL{Scheme: "https", Host: "auth.example.com"}
	testClientID         = "client.example.com"
	testClientSecret     = base64.URLEncoding.EncodeToString([]byte("secret"))
	testRedirectURL      = url.URL{Scheme: "https", Host: "client.example.com", Path: "/redirect"}
	testBadRedirectURL   = url.URL{Scheme: "https", Host: "bad.example.com", Path: "/redirect"}
	testResetPasswordURL = url.URL{Scheme: "https", Host: "auth.example.com", Path: "/resetPassword"}
	// NOTE(review): the error from key.GeneratePrivateKey is discarded here,
	// so a generation failure would go unreported at initialization time.
	testPrivKey, _       = key.GeneratePrivateKey()
)

// tokenHandlerTransport pairs an http.Handler with a token string. Its
// RoundTrip method sets the request's Authorization header to
// "Bearer <Token>" and serves the request through Handler into an
// httptest recorder.
type tokenHandlerTransport struct {
	// Handler is invoked by RoundTrip for every request.
	Handler http.Handler
	// Token is written after "Bearer " in the Authorization header.
	Token   string
}

func (t *tokenHandlerTransport) RoundTrip(r *http.Request) (*http.Response, error) {
	r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", t.Token))
	w := httptest.NewRecorder()
	t.Handler.ServeHTTP(w, r)
	resp := http.Response{
		StatusCode: w.Code,
		Header:     w.Header(),
		Body:       ioutil.NopCloser(w.Body),
Example #17
0
// TestEmailVerificationParseAndVerify signs the claims of several
// EmailVerification values and checks that
// ParseAndVerifyEmailVerificationToken accepts or rejects the resulting
// token as expected. Every token is verified against the `issuer` URL and
// the public half of privKey, so tokens signed by otherSigner or built with
// otherIssuer are expected to fail.
func TestEmailVerificationParseAndVerify(t *testing.T) {

	issuer, _ := url.Parse("http://example.com")
	otherIssuer, _ := url.Parse("http://bad.example.com")
	client := "myclient"
	user := User{ID: "1234", Email: "*****@*****.**"}
	callback, _ := url.Parse("http://client.example.com")
	expires := time.Hour * 3

	goodEV := NewEmailVerification(user, client, *issuer, *callback, expires)
	// A negative lifetime is passed here to build the "expired" fixture.
	expiredEV := NewEmailVerification(user, client, *issuer, *callback, -expires)
	wrongIssuerEV := NewEmailVerification(user, client, *otherIssuer, *callback, expires)
	noSubEV := NewEmailVerification(User{}, client, *issuer, *callback, expires)

	privKey, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("Failed to generate private key, error=%v", err)
	}
	signer := privKey.Signer()

	// otherSigner signs with a key whose public half is never handed to
	// the verification call below.
	privKey2, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("Failed to generate private key, error=%v", err)
	}
	otherSigner := privKey2.Signer()

	tests := []struct {
		ev      EmailVerification
		wantErr bool
		signer  jose.Signer
	}{

		{
			ev:      goodEV,
			signer:  signer,
			wantErr: false,
		},
		{
			ev:      expiredEV,
			signer:  signer,
			wantErr: true,
		},
		{
			ev:      wrongIssuerEV,
			signer:  signer,
			wantErr: true,
		},
		{
			ev:      goodEV,
			signer:  otherSigner,
			wantErr: true,
		},
		{
			ev:      noSubEV,
			signer:  signer,
			wantErr: true,
		},
	}

	for i, tt := range tests {

		jwt, err := jose.NewSignedJWT(tt.ev.Claims, tt.signer)
		if err != nil {
			t.Fatalf("Failed to generate JWT, error=%v", err)
		}
		token := jwt.Encode()

		ev, err := ParseAndVerifyEmailVerificationToken(token, *issuer,
			[]key.PublicKey{*key.NewPublicKey(privKey.JWK())})

		if tt.wantErr {
			t.Logf("err: %v", err)
			if err == nil {
				t.Errorf("case %d: want non-nil err, got nil", i)
			}
			continue
		}

		if err != nil {
			t.Errorf("case %d: non-nil err: %q", i, err)
			// The returned value is not meaningful after an error, so
			// comparing its claims would only add a misleading failure.
			continue
		}

		if diff := pretty.Compare(tt.ev.Claims, ev.Claims); diff != "" {
			t.Errorf("case %d: Compare(want, got): %v", i, diff)
		}
	}
}
Example #18
0
// TestJWTVerifier checks JWTVerifier.Verify against tokens signed by a key
// the verifier already has, by a key it only sees on a later keysFunc call,
// and by a key it never sees, plus tokens with a mismatched audience or an
// expiry in the past.
func TestJWTVerifier(t *testing.T) {
	issuerURL := "http://example.com"
	start := time.Now()
	in12h := start.Add(12 * time.Hour)
	ago36h := start.Add(-36 * time.Hour)
	ago12h := start.Add(-12 * time.Hour)

	priv1, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("failed to generate private key, error=%v", err)
	}
	pk1 := *key.NewPublicKey(priv1.JWK())

	priv2, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("failed to generate private key, error=%v", err)
	}
	pk2 := *key.NewPublicKey(priv2.JWK())

	// signed builds a JWT for subject "XXX" with the given audience and
	// validity window, signed by priv.
	signed := func(aud string, iat, exp time.Time, priv *key.PrivateKey) jose.JWT {
		token, err := jose.NewSignedJWT(NewClaims(issuerURL, "XXX", aud, iat, exp), priv.Signer())
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		return *token
	}

	jwtPK1 := signed("XXX", ago12h, in12h, priv1)
	jwtPK1BadClaims := signed("YYY", ago12h, in12h, priv1)
	jwtExpired := signed("XXX", ago36h, ago12h, priv1)
	jwtPK2 := signed("XXX", ago12h, in12h, priv2)

	// always returns a keysFunc that yields a fresh slice holding the same
	// keys on every call.
	always := func(keys ...key.PublicKey) func() []key.PublicKey {
		return func() []key.PublicKey {
			return append([]key.PublicKey{}, keys...)
		}
	}

	// firstThen returns a keysFunc that yields `first` on its first call
	// and `second` on its second call; a third call indexes out of range.
	firstThen := func(first, second []key.PublicKey) func() []key.PublicKey {
		var calls int
		return func() []key.PublicKey {
			defer func() { calls++ }()
			return [][]key.PublicKey{first, second}[calls]
		}
	}

	// verifierWith builds the verifier shared by all cases; only the key
	// source varies.
	verifierWith := func(keysFunc func() []key.PublicKey) JWTVerifier {
		return JWTVerifier{
			issuer:   "example.com",
			clientID: "XXX",
			syncFunc: func() error { return nil },
			keysFunc: keysFunc,
		}
	}

	cases := []struct {
		verifier JWTVerifier
		jwt      jose.JWT
		wantErr  bool
	}{
		// JWT signed with available key
		{verifier: verifierWith(always(pk1)), jwt: jwtPK1, wantErr: false},

		// JWT signed with available key, with bad claims
		{verifier: verifierWith(always(pk1)), jwt: jwtPK1BadClaims, wantErr: true},

		// expired JWT signed with available key
		{verifier: verifierWith(always(pk1)), jwt: jwtExpired, wantErr: true},

		// JWT signed with unrecognized key, verifiable after sync
		{
			verifier: verifierWith(firstThen([]key.PublicKey{pk1}, []key.PublicKey{pk2})),
			jwt:      jwtPK2,
			wantErr:  false,
		},

		// JWT signed with unrecognized key, not verifiable after sync
		{verifier: verifierWith(always(pk1)), jwt: jwtPK2, wantErr: true},

		// verifier gets no keys from keysFunc, still not verifiable after sync
		{verifier: verifierWith(always()), jwt: jwtPK1, wantErr: true},

		// verifier gets no keys from keysFunc, verifiable after sync
		{
			verifier: verifierWith(firstThen([]key.PublicKey{}, []key.PublicKey{pk2})),
			jwt:      jwtPK2,
			wantErr:  false,
		},
	}

	for idx, tc := range cases {
		err := tc.verifier.Verify(tc.jwt)
		switch {
		case tc.wantErr && err == nil:
			t.Errorf("case %d: wanted non-nil error", idx)
		case !tc.wantErr && err != nil:
			t.Errorf("case %d: wanted nil error, got %v", idx, err)
		}
	}
}
Example #19
0
// TestDBPrivateKeySetRepoSetGet stores a private key set through one repo
// and reads it back through a second repo built over the same dbMap with a
// (possibly different) list of secrets. Reading is expected to succeed when
// the two secret lists share an entry and to fail when they share none.
func TestDBPrivateKeySetRepoSetGet(t *testing.T) {
	s1 := []byte("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
	s2 := []byte("oooooooooooooooooooooooooooooooo")
	s3 := []byte("wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww")

	keys := []*key.PrivateKey{}
	for i := 0; i < 2; i++ {
		k, err := key.GeneratePrivateKey()
		if err != nil {
			t.Fatalf("Unable to generate RSA key: %v", err)
		}
		keys = append(keys, k)
	}

	ks := key.NewPrivateKeySet(
		[]*key.PrivateKey{keys[0], keys[1]}, time.Now().Add(time.Minute))

	tests := []struct {
		setSecrets [][]byte
		getSecrets [][]byte
		wantErr    bool
	}{
		{
			// same secrets used to encrypt, decrypt
			setSecrets: [][]byte{s1, s2},
			getSecrets: [][]byte{s1, s2},
		},
		{
			// setSecrets got rotated, but getSecrets didn't yet.
			setSecrets: [][]byte{s2, s3},
			getSecrets: [][]byte{s1, s2},
		},
		{
			// getSecrets doesn't have s3
			setSecrets: [][]byte{s3},
			getSecrets: [][]byte{s1, s2},
			wantErr:    true,
		},
	}

	for i, tt := range tests {
		dbMap := connect(t)
		// The error is passed as an argument, never as the format string:
		// a '%' in its text would otherwise be read as a formatting verb.
		setRepo, err := db.NewPrivateKeySetRepo(dbMap, false, tt.setSecrets...)
		if err != nil {
			t.Fatalf("case %d: %v", i, err)
		}

		getRepo, err := db.NewPrivateKeySetRepo(dbMap, false, tt.getSecrets...)
		if err != nil {
			t.Fatalf("case %d: %v", i, err)
		}

		if err := setRepo.Set(ks); err != nil {
			t.Fatalf("case %d: Unexpected error: %v", i, err)
		}

		got, err := getRepo.Get()
		if tt.wantErr {
			if err == nil {
				t.Errorf("case %d: want err, got nil", i)
			}
			continue
		}
		if err != nil {
			t.Fatalf("case %d: Unexpected error: %v", i, err)
		}

		if diff := pretty.Compare(ks, got); diff != "" {
			t.Fatalf("case %d:Retrieved incorrect KeySet: Compare(want,got): %v", i, diff)
		}

	}
}
Example #20
0
// TestClientToken sends requests with various Authorization headers through
// clientTokenMiddleware and checks the response code written to the
// recorder. Only a parsable token for a client present in the repo, checked
// against a non-empty key list, is expected to produce 200; every other
// case expects 401.
func TestClientToken(t *testing.T) {
	now := time.Now()
	tomorrow := now.Add(24 * time.Hour)
	validClientID := "valid-client"

	registered := oidc.ClientIdentity{
		Credentials: oidc.ClientCredentials{
			ID:     validClientID,
			Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
		},
		Metadata: oidc.ClientMetadata{
			RedirectURIs: []url.URL{
				{Scheme: "https", Host: "authn.example.com", Path: "/callback"},
			},
		},
	}
	repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{registered})
	if err != nil {
		t.Fatalf("Failed to create client identity repo: %v", err)
	}

	privKey, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("Failed to generate private key, error=%v", err)
	}
	signer := privKey.Signer()
	pubKey := *key.NewPublicKey(privKey.JWK())
	trusted := []key.PublicKey{pubKey}

	validIss := "https://example.com"

	// makeToken returns the encoded form of a JWT signed with signer.
	makeToken := func(iss, sub, aud string, iat, exp time.Time) string {
		signedJWT, err := jose.NewSignedJWT(oidc.NewClaims(iss, sub, aud, iat, exp), signer)
		if err != nil {
			t.Fatalf("Failed to generate JWT, error=%v", err)
		}
		return signedJWT.Encode()
	}

	// bearer formats an Authorization header value for the given token.
	bearer := func(token string) string {
		return fmt.Sprintf("BEARER %s", token)
	}

	validJWT := makeToken(validIss, validClientID, validClientID, now, tomorrow)
	invalidJWT := makeToken("", "", "", now, tomorrow)

	cases := []struct {
		keys     []key.PublicKey
		repo     client.ClientIdentityRepo
		header   string
		wantCode int
	}{
		// valid token
		{keys: trusted, repo: repo, header: bearer(validJWT), wantCode: http.StatusOK},
		// invalid token
		{keys: trusted, repo: repo, header: bearer(invalidJWT), wantCode: http.StatusUnauthorized},
		// empty header
		{keys: trusted, repo: repo, header: "", wantCode: http.StatusUnauthorized},
		// unparsable token
		{keys: trusted, repo: repo, header: "BEARER xxx", wantCode: http.StatusUnauthorized},
		// no verification keys
		{keys: []key.PublicKey{}, repo: repo, header: bearer(validJWT), wantCode: http.StatusUnauthorized},
		// nil repo
		{keys: trusted, repo: nil, header: bearer(validJWT), wantCode: http.StatusUnauthorized},
		// empty repo
		{
			keys:     trusted,
			repo:     db.NewClientIdentityRepo(db.NewMemDB()),
			header:   bearer(validJWT),
			wantCode: http.StatusUnauthorized,
		},
		// client not in repo
		{
			keys:     trusted,
			repo:     repo,
			header:   bearer(makeToken(validIss, "DOESNT-EXIST", "DOESNT-EXIST", now, tomorrow)),
			wantCode: http.StatusUnauthorized,
		},
	}

	for idx, tc := range cases {
		caseKeys := tc.keys
		rec := httptest.NewRecorder()

		middleware := &clientTokenMiddleware{
			issuerURL: validIss,
			ciRepo:    tc.repo,
			keysFunc: func() ([]key.PublicKey, error) {
				return caseKeys, nil
			},
			next: staticHandler{},
		}

		// The request carries nothing but the Authorization header.
		req := &http.Request{
			Header: http.Header{"Authorization": []string{tc.header}},
		}
		middleware.ServeHTTP(rec, req)

		if got := rec.Code; got != tc.wantCode {
			t.Errorf("case %d: invalid response code, want=%d, got=%d", idx, tc.wantCode, got)
		}
	}
}
Example #21
0
// TestJWTVerifier checks JWTVerifier.Verify against tokens that differ in
// signing key, audience (string or string array), validity window and
// issuer. The key source of each verifier either returns the same keys on
// every call or a different set on the second call. Case names are unique
// so that a failure message identifies exactly one table entry.
func TestJWTVerifier(t *testing.T) {
	iss := "http://example.com"
	now := time.Now()
	future12 := now.Add(12 * time.Hour)
	past36 := now.Add(-36 * time.Hour)
	past12 := now.Add(-12 * time.Hour)

	priv1, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("failed to generate private key, error=%v", err)
	}
	pk1 := *key.NewPublicKey(priv1.JWK())

	priv2, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("failed to generate private key, error=%v", err)
	}
	pk2 := *key.NewPublicKey(priv2.JWK())

	// newJWT signs a token with the given claims; aud may be a string or a
	// string slice.
	newJWT := func(issuer, subject string, aud interface{}, issuedAt, exp time.Time, signer jose.Signer) jose.JWT {
		jwt, err := jose.NewSignedJWT(NewClaims(issuer, subject, aud, issuedAt, exp), signer)
		if err != nil {
			t.Fatal(err)
		}
		return *jwt
	}

	// always returns a keysFunc that yields a fresh slice holding the same
	// keys on every call.
	always := func(keys ...key.PublicKey) func() []key.PublicKey {
		return func() []key.PublicKey {
			return append([]key.PublicKey{}, keys...)
		}
	}

	// firstThen returns a keysFunc that yields `first` on its first call
	// and `second` on its second call; a third call indexes out of range.
	firstThen := func(first, second []key.PublicKey) func() []key.PublicKey {
		var i int
		return func() []key.PublicKey {
			defer func() { i++ }()
			return [][]key.PublicKey{first, second}[i]
		}
	}

	noopSync := func() error { return nil }

	// newVerifier builds the verifier shared by all cases; only the sync
	// hook and the key source vary.
	newVerifier := func(syncFunc func() error, keysFunc func() []key.PublicKey) JWTVerifier {
		return JWTVerifier{
			issuer:   "example.com",
			clientID: "XXX",
			syncFunc: syncFunc,
			keysFunc: keysFunc,
		}
	}

	tests := []struct {
		name     string
		verifier JWTVerifier
		jwt      jose.JWT
		wantErr  bool
	}{
		{
			name:     "JWT signed with available key",
			verifier: newVerifier(noopSync, always(pk1)),
			jwt:      newJWT(iss, "XXX", "XXX", past12, future12, priv1.Signer()),
			wantErr:  false,
		},
		{
			name:     "JWT signed with available key, with bad claims",
			verifier: newVerifier(noopSync, always(pk1)),
			jwt:      newJWT(iss, "XXX", "YYY", past12, future12, priv1.Signer()),
			wantErr:  true,
		},
		{
			// Previously shared its name with the first case.
			name:     "JWT signed with available key, 'aud' array missing client ID",
			verifier: newVerifier(noopSync, always(pk1)),
			jwt:      newJWT(iss, "XXX", []string{"YYY", "ZZZ"}, past12, future12, priv1.Signer()),
			wantErr:  true,
		},
		{
			name:     "expired JWT signed with available key",
			verifier: newVerifier(noopSync, always(pk1)),
			jwt:      newJWT(iss, "XXX", "XXX", past36, past12, priv1.Signer()),
			wantErr:  true,
		},
		{
			name:     "JWT signed with unrecognized key, verifiable after sync",
			verifier: newVerifier(noopSync, firstThen([]key.PublicKey{pk1}, []key.PublicKey{pk2})),
			jwt:      newJWT(iss, "XXX", "XXX", past36, future12, priv2.Signer()),
			wantErr:  false,
		},
		{
			name:     "JWT signed with unrecognized key, not verifiable after sync",
			verifier: newVerifier(noopSync, always(pk1)),
			jwt:      newJWT(iss, "XXX", "XXX", past12, future12, priv2.Signer()),
			wantErr:  true,
		},
		{
			name:     "verifier gets no keys from keysFunc, still not verifiable after sync",
			verifier: newVerifier(noopSync, always()),
			jwt:      newJWT(iss, "XXX", "XXX", past12, future12, priv1.Signer()),
			wantErr:  true,
		},
		{
			name:     "verifier gets no keys from keysFunc, verifiable after sync",
			verifier: newVerifier(noopSync, firstThen([]key.PublicKey{}, []key.PublicKey{pk2})),
			jwt:      newJWT(iss, "XXX", "XXX", past12, future12, priv2.Signer()),
			wantErr:  false,
		},
		{
			name:     "JWT signed with available key, 'aud' is a string array",
			verifier: newVerifier(noopSync, always(pk1)),
			jwt:      newJWT(iss, "XXX", []string{"ZZZ", "XXX"}, past12, future12, priv1.Signer()),
			wantErr:  false,
		},
		{
			name: "invalid issuer claim shouldn't trigger sync",
			verifier: newVerifier(
				// The sync hook itself fails the test if it is ever called.
				func() error {
					t.Errorf("invalid issuer claim shouldn't trigger a sync")
					return nil
				},
				firstThen([]key.PublicKey{}, []key.PublicKey{pk2}),
			),
			jwt:     newJWT("invalid-issuer", "XXX", []string{"ZZZ", "XXX"}, past12, future12, priv2.Signer()),
			wantErr: true,
		},
	}

	for _, tt := range tests {
		err := tt.verifier.Verify(tt.jwt)
		if tt.wantErr && (err == nil) {
			t.Errorf("case %q: wanted non-nil error", tt.name)
		} else if !tt.wantErr && (err != nil) {
			t.Errorf("case %q: wanted nil error, got %v", tt.name, err)
		}
	}
}
Example #22
0
// TestPasswordResetParseAndVerify turns several PasswordReset values into
// signed tokens and checks that ParseAndVerifyPasswordResetToken accepts or
// rejects each one as expected. Every token is verified against the
// `issuer` URL and the public half of privKey, so tokens signed by
// otherSigner or built with otherIssuer are expected to fail.
func TestPasswordResetParseAndVerify(t *testing.T) {

	issuer, _ := url.Parse("http://example.com")
	otherIssuer, _ := url.Parse("http://bad.example.com")
	client := "myclient"
	user := User{ID: "1234", Email: "*****@*****.**"}
	callback, _ := url.Parse("http://client.example.com")
	expires := time.Hour * 3
	password := Password("passy")

	goodPR := NewPasswordReset(user, password, *issuer, client, *callback, expires)
	// No client ID and no callback: expected to verify (see table).
	goodPRNoCB := NewPasswordReset(user, password, *issuer, "", url.URL{}, expires)
	// A negative lifetime is passed here to build the "expired" fixture.
	expiredPR := NewPasswordReset(user, password, *issuer, client, *callback, -expires)
	wrongIssuerPR := NewPasswordReset(user, password, *otherIssuer, client, *callback, expires)
	noSubPR := NewPasswordReset(User{}, password, *issuer, client, *callback, expires)
	noPWPR := NewPasswordReset(user, Password(""), *issuer, client, *callback, expires)
	// A callback without a client ID: expected to be rejected (see table).
	noClientPR := NewPasswordReset(user, password, *issuer, "", *callback, expires)

	privKey, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("Failed to generate private key, error=%v", err)
	}
	signer := privKey.Signer()

	// otherSigner signs with a key whose public half is never handed to
	// the verification call below.
	privKey2, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("Failed to generate private key, error=%v", err)
	}
	otherSigner := privKey2.Signer()

	tests := []struct {
		ev      PasswordReset
		wantErr bool
		signer  jose.Signer
	}{

		{
			ev:      goodPR,
			signer:  signer,
			wantErr: false,
		},
		{
			ev:      goodPRNoCB,
			signer:  signer,
			wantErr: false,
		},

		{
			ev:      expiredPR,
			signer:  signer,
			wantErr: true,
		},
		{
			ev:      wrongIssuerPR,
			signer:  signer,
			wantErr: true,
		},
		{
			ev:      goodPR,
			signer:  otherSigner,
			wantErr: true,
		},
		{
			ev:      noSubPR,
			signer:  signer,
			wantErr: true,
		},
		{
			ev:      noPWPR,
			signer:  signer,
			wantErr: true,
		},
		{
			ev:      noClientPR,
			signer:  signer,
			wantErr: true,
		},
	}

	for i, tt := range tests {

		token, err := tt.ev.Token(tt.signer)
		if err != nil {
			t.Errorf("case %d: non-nil error creating token: %v", i, err)
			// Without a token there is nothing meaningful to verify; a
			// parse failure on it could even satisfy a wantErr case for
			// the wrong reason.
			continue
		}

		ev, err := ParseAndVerifyPasswordResetToken(token, *issuer,
			[]key.PublicKey{*key.NewPublicKey(privKey.JWK())})

		if tt.wantErr {
			t.Logf("err: %v", err)
			if err == nil {
				t.Errorf("case %d: want non-nil err, got nil", i)
			}
			continue
		}

		if err != nil {
			t.Errorf("case %d: non-nil err: %q", i, err)
			// The returned value is not meaningful after an error, so
			// comparing its claims would only add a misleading failure.
			continue
		}

		if diff := pretty.Compare(tt.ev.claims, ev.claims); diff != "" {
			t.Errorf("case %d: Compare(want, got): %v", i, diff)
		}
	}
}
Example #23
0
// TestResetPasswordHandler issues GET and POST requests carrying
// password-reset tokens to ResetPasswordHandler and checks, per case, the
// response code, the values of the "#resetPasswordForm" form in the
// response body, and the password stored for user "ID-1" afterwards.
//
// user.PasswordHasher is replaced with an upper-casing function for the
// duration of the test and set back to user.DefaultPasswordHasher on return.
func TestResetPasswordHandler(t *testing.T) {
	// makeToken returns an encoded, signed password-reset token for userID.
	makeToken := func(userID, password, clientID string, callback url.URL, expires time.Duration, signer jose.Signer) string {
		pr := user.NewPasswordReset(userID,
			user.Password(password),
			testIssuerURL,
			clientID,
			callback,
			expires)

		jwt, err := jose.NewSignedJWT(pr.Claims, signer)
		if err != nil {
			t.Fatalf("couldn't make token: %q", err)
		}
		token := jwt.Encode()
		return token
	}
	goodSigner := key.NewPrivateKeySet([]*key.PrivateKey{testPrivKey},
		time.Now().Add(time.Minute)).Active().Signer()

	// badSigner signs with a freshly generated key rather than testPrivKey.
	badKey, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("couldn't make new key: %q", err)
	}
	badSigner := key.NewPrivateKeySet([]*key.PrivateKey{badKey},
		time.Now().Add(time.Minute)).Active().Signer()

	str := func(s string) []string {
		return []string{s}
	}

	user.PasswordHasher = func(s string) ([]byte, error) {
		return []byte(strings.ToUpper(s)), nil
	}
	defer func() {
		user.PasswordHasher = user.DefaultPasswordHasher
	}()

	// Tokens that appear both in a case's query and in its expected form
	// values are created once here so that both places hold the same string.
	tokenForCase := map[int]string{
		0: makeToken("ID-1", "password", testClientID, testRedirectURL, time.Hour*1, goodSigner),
		2: makeToken("ID-1", "password", testClientID, url.URL{}, time.Hour*1, goodSigner),
		5: makeToken("ID-1", "password", testClientID, url.URL{}, time.Hour*1, goodSigner),
	}

	tests := []struct {
		query url.Values

		method string

		wantFormValues *url.Values
		wantCode       int
		wantPassword   string
	}{
		// Scenario 1: Happy Path
		{ // Case 0
			// Step 1.1 - User clicks link in email, has valid token.
			query: url.Values{
				"token": str(tokenForCase[0]),
			},
			method: "GET",

			wantCode: http.StatusOK,
			wantFormValues: &url.Values{
				"password": str(""),
				"token":    str(tokenForCase[0]),
			},
			wantPassword: "******",
		},
		{ // Case 1
			// Step 1.2 - User enters in new valid password, password is changed, user is redirected.
			query: url.Values{
				"token":    str(makeToken("ID-1", "password", testClientID, testRedirectURL, time.Hour*1, goodSigner)),
				"password": str("new_password"),
			},
			method: "POST",

			wantCode:       http.StatusSeeOther,
			wantFormValues: &url.Values{},
			wantPassword:   "******",
		},
		// Scenario 2: Happy Path, but without redirect.
		{ // Case 2
			// Step 2.1 - User clicks link in email, has valid token.
			query: url.Values{
				"token": str(tokenForCase[2]),
			},
			method: "GET",

			wantCode: http.StatusOK,
			wantFormValues: &url.Values{
				"password": str(""),
				"token":    str(tokenForCase[2]),
			},
			wantPassword: "******",
		},
		{ // Case 3
			// Step 2.2 - User enters in new valid password, password is changed, no redirect.
			query: url.Values{
				"token":    str(makeToken("ID-1", "password", testClientID, url.URL{}, time.Hour*1, goodSigner)),
				"password": str("new_password"),
			},
			method: "POST",

			// no redirect
			wantCode:       http.StatusOK,
			wantFormValues: &url.Values{},
			wantPassword:   "******",
		},
		// Errors
		{ // Case 4
			// Step 1.1.1 - User clicks link in email, has invalid token.
			query: url.Values{
				"token": str(makeToken("ID-1", "password", testClientID, testRedirectURL, time.Hour*1, badSigner)),
			},
			method: "GET",

			wantCode:       http.StatusBadRequest,
			wantFormValues: &url.Values{},
			wantPassword:   "******",
		},

		{ // Case 5
			// Step 2.2.1 - User submits the password "shrt" with a valid
			// token; the request is expected to be rejected and the form
			// shown again.
			query: url.Values{
				"token":    str(tokenForCase[5]),
				"password": str("shrt"),
			},
			method: "POST",

			// no redirect
			wantCode: http.StatusBadRequest,
			wantFormValues: &url.Values{
				"password": str(""),
				"token":    str(tokenForCase[5]),
			},
			wantPassword: "******",
		},
		{ // Case 6
			// Step 2.2.2 - User submits a password with a token signed by
			// the wrong key.
			query: url.Values{
				"token":    str(makeToken("ID-1", "password", testClientID, url.URL{}, time.Hour*1, badSigner)),
				"password": str("shrt"),
			},
			method: "POST",

			// no redirect
			wantCode:       http.StatusBadRequest,
			wantFormValues: &url.Values{},
			wantPassword:   "******",
		},
		{ // Case 7
			// Token lacking client id
			query: url.Values{
				"token":    str(makeToken("ID-1", "password", "", url.URL{}, time.Hour*1, goodSigner)),
				"password": str("shrt"),
			},
			method: "GET",

			wantCode:     http.StatusBadRequest,
			wantPassword: "******",
		},
		{ // Case 8
			// Token lacking client id
			query: url.Values{
				"token":    str(makeToken("ID-1", "password", "", url.URL{}, time.Hour*1, goodSigner)),
				"password": str("shrt"),
			},
			method: "POST",

			wantCode:     http.StatusBadRequest,
			wantPassword: "******",
		},
	}
	for i, tt := range tests {
		// Fresh fixtures per case, so one case's stored password does not
		// carry over into the next.
		f, err := makeTestFixtures()
		if err != nil {
			t.Fatalf("case %d: could not make test fixtures: %v", i, err)
		}

		hdlr := ResetPasswordHandler{
			tpl:       f.srv.ResetPasswordTemplate,
			issuerURL: testIssuerURL,
			um:        f.srv.UserManager,
			keysFunc:  f.srv.KeyManager.PublicKeys,
		}

		w := httptest.NewRecorder()
		var req *http.Request
		u := testIssuerURL
		u.Path = httpPathResetPassword
		// POST sends the values as a form body, GET as the query string.
		if tt.method == "POST" {
			req, err = http.NewRequest(tt.method, u.String(), strings.NewReader(tt.query.Encode()))
		} else {
			u.RawQuery = tt.query.Encode()
			req, err = http.NewRequest(tt.method, u.String(), nil)
		}
		if err != nil {
			// req is nil on error, so nothing below can run.
			t.Fatalf("case %d: unable to form HTTP request: %v", i, err)
		}
		if tt.method == "POST" {
			req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
		}

		hdlr.ServeHTTP(w, req)

		if tt.wantCode != w.Code {
			t.Errorf("case %d: wantCode=%v, got=%v", i, tt.wantCode, w.Code)
			continue
		}

		values, err := html.FormValues("#resetPasswordForm", bytes.NewReader(w.Body.Bytes()))
		if err != nil {
			t.Errorf("case %d: could not parse form: %v", i, err)
		}

		if tt.wantFormValues != nil {
			if diff := pretty.Compare(*tt.wantFormValues, values); diff != "" {
				t.Errorf("case %d: Compare(wantFormValues, got) = %v", i, diff)
			}
		}
		pwi, err := f.srv.PasswordInfoRepo.Get(nil, "ID-1")
		if err != nil {
			t.Errorf("case %d: Error getting Password info: %v", i, err)
		}
		if tt.wantPassword != string(pwi.Password) {
			t.Errorf("case %d: wantPassword=%v, got=%v", i, tt.wantPassword, string(pwi.Password))
		}

	}
}
Example #24
0
func TestClientToken(t *testing.T) {
	now := time.Now()
	tomorrow := now.Add(24 * time.Hour)
	clientMetadata := oidc.ClientMetadata{
		RedirectURIs: []url.URL{
			{Scheme: "https", Host: "authn.example.com", Path: "/callback"},
		},
	}

	dbm := db.NewMemDB()
	clientRepo := db.NewClientRepo(dbm)
	clientManager := clientmanager.NewClientManager(clientRepo, db.TransactionFactory(dbm), clientmanager.ManagerOptions{})
	cli := client.Client{
		Metadata: clientMetadata,
	}
	creds, err := clientManager.New(cli)
	if err != nil {
		t.Fatalf("Failed to create client: %v", err)
	}
	validClientID := creds.ID

	privKey, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("Failed to generate private key, error=%v", err)
	}
	signer := privKey.Signer()
	pubKey := *key.NewPublicKey(privKey.JWK())

	validIss := "https://example.com"

	makeToken := func(iss, sub, aud string, iat, exp time.Time) string {
		claims := oidc.NewClaims(iss, sub, aud, iat, exp)
		jwt, err := jose.NewSignedJWT(claims, signer)
		if err != nil {
			t.Fatalf("Failed to generate JWT, error=%v", err)
		}
		return jwt.Encode()
	}

	validJWT := makeToken(validIss, validClientID, validClientID, now, tomorrow)
	invalidJWT := makeToken("", "", "", now, tomorrow)

	tests := []struct {
		keys     []key.PublicKey
		manager  *clientmanager.ClientManager
		header   string
		wantCode int
	}{
		// valid token
		{
			keys:     []key.PublicKey{pubKey},
			manager:  clientManager,
			header:   fmt.Sprintf("BEARER %s", validJWT),
			wantCode: http.StatusOK,
		},
		// invalid token
		{
			keys:     []key.PublicKey{pubKey},
			manager:  clientManager,
			header:   fmt.Sprintf("BEARER %s", invalidJWT),
			wantCode: http.StatusUnauthorized,
		},
		// empty header
		{
			keys:     []key.PublicKey{pubKey},
			manager:  clientManager,
			header:   "",
			wantCode: http.StatusUnauthorized,
		},
		// unparsable token
		{
			keys:     []key.PublicKey{pubKey},
			manager:  clientManager,
			header:   "BEARER xxx",
			wantCode: http.StatusUnauthorized,
		},
		// no verification keys
		{
			keys:     []key.PublicKey{},
			manager:  clientManager,
			header:   fmt.Sprintf("BEARER %s", validJWT),
			wantCode: http.StatusUnauthorized,
		},
		// nil repo
		{
			keys:     []key.PublicKey{pubKey},
			manager:  nil,
			header:   fmt.Sprintf("BEARER %s", validJWT),
			wantCode: http.StatusUnauthorized,
		},
		// empty repo
		{
			keys:     []key.PublicKey{pubKey},
			manager:  clientmanager.NewClientManager(db.NewClientRepo(db.NewMemDB()), db.TransactionFactory(db.NewMemDB()), clientmanager.ManagerOptions{}),
			header:   fmt.Sprintf("BEARER %s", validJWT),
			wantCode: http.StatusUnauthorized,
		},
		// client not in repo
		{
			keys:     []key.PublicKey{pubKey},
			manager:  clientManager,
			header:   fmt.Sprintf("BEARER %s", makeToken(validIss, "DOESNT-EXIST", "DOESNT-EXIST", now, tomorrow)),
			wantCode: http.StatusUnauthorized,
		},
	}

	for i, tt := range tests {
		w := httptest.NewRecorder()
		mw := &clientTokenMiddleware{
			issuerURL: validIss,
			ciManager: tt.manager,
			keysFunc: func() ([]key.PublicKey, error) {
				return tt.keys, nil
			},
			next: staticHandler{},
		}
		req := &http.Request{
			Header: http.Header{
				"Authorization": []string{tt.header},
			},
		}

		mw.ServeHTTP(w, req)
		if tt.wantCode != w.Code {
			t.Errorf("case %d: invalid response code, want=%d, got=%d", i, tt.wantCode, w.Code)
		}
	}
}