예제 #1
0
파일: capabilities.go 프로젝트: tmc/twilio
// Generate builds and signs a capability token from the given configuration.
//
// The token carries three claims: "iss" (c.AccountSid), "exp" (Clock.Now()
// plus expires, as Unix seconds) and "scope" (the granted scopes joined by a
// single space). An outgoing scope is granted when c.AllowClientOutgoing is
// set, and additionally names the client when c.AllowClientIncoming is set.
// An incoming scope is granted when c.AllowClientIncoming is set. With neither
// field set the scope claim is the empty string.
//
// The token is signed with a signer built from c.AuthToken. Any signing error
// is returned unchanged together with an empty string.
//
// See https://www.twilio.com/docs/api/client/capability-tokens for details.
func Generate(c Capabilities, expires time.Duration) (string, error) {
	hmacSigner := jose.NewSignerHMAC("", []byte(c.AuthToken))

	claims := jose.Claims{}
	claims.Add("iss", c.AccountSid)
	claims.Add("exp", Clock.Now().Add(expires).Unix())

	var granted []string
	if c.AllowClientOutgoing != "" {
		outgoing := fmt.Sprintf("scope:client:outgoing?appSid=%s", c.AllowClientOutgoing)
		if c.AllowClientIncoming != "" {
			// Outgoing calls are attributed to the named client when one is configured.
			outgoing += fmt.Sprintf("&clientName=%s", c.AllowClientIncoming)
		}
		granted = append(granted, outgoing)
	}
	if c.AllowClientIncoming != "" {
		incoming := fmt.Sprintf("scope:client:incoming?clientName=%s", c.AllowClientIncoming)
		granted = append(granted, incoming)
	}
	claims.Add("scope", strings.Join(granted, " "))

	token, err := jose.NewSignedJWT(claims, hmacSigner)
	if err != nil {
		return "", err
	}
	return token.Encode(), nil
}
예제 #2
0
// TestGetClientIDFromAuthorizedRequest checks that the client ID is taken from
// the subject of the bearer token in the Authorization header, and that a
// token with an empty subject is rejected.
func TestGetClientIDFromAuthorizedRequest(t *testing.T) {
	now := time.Now()
	tomorrow := now.Add(24 * time.Hour)

	privKey, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("Failed to generate private key, error=%v", err)
	}
	signer := privKey.Signer()

	// bearerFor returns an Authorization header value holding a freshly signed
	// token for the given subject, valid from now until tomorrow.
	bearerFor := func(sub string) string {
		signed, err := jose.NewSignedJWT(oidc.NewClaims("iss", sub, "", now, tomorrow), signer)
		if err != nil {
			t.Fatalf("Failed to generate JWT, error=%v", err)
		}
		return "BEARER " + signed.Encode()
	}

	tests := []struct {
		header     string
		wantClient string
		wantErr    bool
	}{
		{header: bearerFor("CLIENT_ID"), wantClient: "CLIENT_ID"},
		{header: bearerFor(""), wantErr: true},
	}

	for i, tt := range tests {
		req := &http.Request{Header: http.Header{"Authorization": []string{tt.header}}}

		gotClient, err := getClientIDFromAuthorizedRequest(req)
		switch {
		case tt.wantErr:
			if err == nil {
				t.Errorf("case %d: want non-nil err", i)
			}
		case err != nil:
			t.Errorf("case %d: got err: %q", i, err)
		case gotClient != tt.wantClient:
			t.Errorf("case %d: want=%v, got=%v", i, tt.wantClient, gotClient)
		}
	}
}
예제 #3
0
파일: server.go 프로젝트: derekparker/dex
// ClientCredsToken authenticates the client described by creds and, on
// success, returns a signed token issued to the client itself: subject,
// audience and the "name" claim are all creds.ID. The token is valid from now
// for s.SessionManager.ValidityWindow.
//
// Failures are reported as oauth2 errors: ErrorInvalidClient when the
// credentials are rejected, and ErrorServerError when the repo lookup, signer
// retrieval or signing fails (the underlying cause is logged, not returned).
func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error) {
	authenticated, err := s.ClientIdentityRepo.Authenticate(creds)
	switch {
	case err != nil:
		log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
		return nil, oauth2.NewError(oauth2.ErrorServerError)
	case !authenticated:
		return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
	}

	signer, err := s.KeyManager.Signer()
	if err != nil {
		log.Errorf("Failed to generate ID token: %v", err)
		return nil, oauth2.NewError(oauth2.ErrorServerError)
	}

	issuedAt := time.Now()
	expiresAt := issuedAt.Add(s.SessionManager.ValidityWindow)
	// The client acts on its own behalf, so it is both subject and audience.
	claims := oidc.NewClaims(s.IssuerURL.String(), creds.ID, creds.ID, issuedAt, expiresAt)
	claims.Add("name", creds.ID)

	token, err := jose.NewSignedJWT(claims, signer)
	if err != nil {
		log.Errorf("Failed to generate ID token: %v", err)
		return nil, oauth2.NewError(oauth2.ErrorServerError)
	}

	log.Infof("Client token sent: clientID=%s", creds.ID)
	return token, nil
}
예제 #4
0
파일: grpc.go 프로젝트: otsimo/accounts
// Token issues a signed ID token and a refresh token for userID on behalf of
// clientID. The ID token is valid from iat to exp and carries the claims the
// user record adds to it; admin users additionally get OtsimoUserTypeClaim set
// to "adm".
//
// Every failure (signer retrieval, user lookup, signing, refresh token
// creation) is logged and reported to the caller as an oauth2
// ErrorServerError, with a nil token and an empty refresh token.
func (s *grpcServer) Token(userID, clientID string, iat, exp time.Time) (*jose.JWT, string, error) {
	srv := s.server

	signer, err := srv.KeyManager.Signer()
	if err != nil {
		log.Errorf("grpc.go: Failed to generate ID token: %v", err)
		return nil, "", oauth2.NewError(oauth2.ErrorServerError)
	}

	account, err := srv.UserRepo.Get(nil, userID)
	if err != nil {
		log.Errorf("grpc.go: Failed to fetch user %q from repo: %v: ", userID, err)
		return nil, "", oauth2.NewError(oauth2.ErrorServerError)
	}

	claims := oidc.NewClaims(srv.IssuerURL.String(), userID, clientID, iat, exp)
	account.AddToClaims(claims)
	if account.Admin {
		claims.Add(OtsimoUserTypeClaim, "adm")
	}

	idToken, err := jose.NewSignedJWT(claims, signer)
	if err != nil {
		log.Errorf("grpc.go: Failed to generate ID token: %v", err)
		return nil, "", oauth2.NewError(oauth2.ErrorServerError)
	}

	refresh, err := srv.RefreshTokenRepo.Create(account.ID, clientID)
	if err != nil {
		log.Errorf("grpc.go: Failed to generate refresh token: %v", err)
		return nil, "", oauth2.NewError(oauth2.ErrorServerError)
	}

	return idToken, refresh, nil
}
예제 #5
0
파일: email.go 프로젝트: philips/dex
// resetSignerUnavailableError is returned by SendResetPasswordEmail when the
// signer function reports no error but also yields no signer.
type resetSignerUnavailableError struct{}

// Error implements the error interface.
func (resetSignerUnavailableError) Error() string { return "no signer available" }

// SendResetPasswordEmail sends a password reset email to the user with the
// given email address. The email contains a link carrying a signed token which
// can be visited to initiate the password change/reset process.
//
// This method DOES NOT check the validity of the client ID or redirect URL; it
// is expected that upstream callers have already done so.
//
// If no emailer is configured, the URL of the aforementioned link is returned;
// otherwise a nil URL is returned together with the result of sending the mail.
//
// A non-nil error is returned when the user or their password record cannot be
// loaded (user.ErrorNotFound is passed through unchanged), when no signer is
// available, when signing fails, or when the emailer fails to send.
func (u *UserEmailer) SendResetPasswordEmail(email string, redirectURL url.URL, clientID string) (*url.URL, error) {
	usr, err := u.ur.GetByEmail(nil, email)
	if err == user.ErrorNotFound {
		log.Errorf("No Such user for email: %q", email)
		return nil, err
	}
	if err != nil {
		log.Errorf("Error getting user: %q", err)
		return nil, err
	}

	pwi, err := u.pwi.Get(nil, usr.ID)
	if err == user.ErrorNotFound {
		// TODO(bobbyrullo): In this case, maybe send a different email explaining that
		// they don't have a local password.
		log.Errorf("No Password for userID: %q", usr.ID)
		return nil, err
	}
	if err != nil {
		log.Errorf("Error getting password: %q", err)
		return nil, err
	}

	signer, err := u.signerFn()
	if err == nil && signer == nil {
		// A missing signer without an error must not be reported as success:
		// a (nil, nil) return means "email sent" to callers.
		err = resetSignerUnavailableError{}
	}
	if err != nil {
		log.Errorf("error getting signer: %v (%v)", err, signer)
		return nil, err
	}

	passwordReset := user.NewPasswordReset(usr, pwi.Password, u.issuerURL,
		clientID, redirectURL, u.tokenValidityWindow)
	jwt, err := jose.NewSignedJWT(passwordReset.Claims, signer)
	if err != nil {
		log.Errorf("error constructing or signing PasswordReset JWT: %v", err)
		return nil, err
	}
	token := jwt.Encode()

	// Work on a copy of the configured URL so the token is never written back
	// into the emailer's own state.
	resetURL := u.passwordResetURL
	q := resetURL.Query()
	q.Set("token", token)
	resetURL.RawQuery = q.Encode()

	if u.emailer == nil {
		return &resetURL, nil
	}

	err = u.emailer.SendMail(u.fromAddress, "Reset your password.", "password-reset",
		map[string]interface{}{
			"email": usr.Email,
			"link":  resetURL.String(),
		}, usr.Email)
	if err != nil {
		log.Errorf("error sending password reset email: %v", err)
	}
	return nil, err
}
예제 #6
0
// makeUserToken returns an encoded, signed ID token for userID issued by
// issuerURL to clientID. The token is valid from now for the expires duration.
//
// The token is signed with privKey. Earlier versions ignored this parameter and
// always signed with the package-level testPrivKey; that key is still used when
// privKey is nil so existing callers keep working.
//
// It panics if the token cannot be signed.
func makeUserToken(issuerURL url.URL, userID, clientID string, expires time.Duration, privKey *key.PrivateKey) string {
	if privKey == nil {
		privKey = testPrivKey
	}

	// Use a single timestamp so iat, exp and the key set expiry are consistent.
	now := time.Now()
	signer := key.NewPrivateKeySet([]*key.PrivateKey{privKey},
		now.Add(time.Minute)).Active().Signer()
	claims := oidc.NewClaims(issuerURL.String(), userID, clientID, now, now.Add(expires))

	jwt, err := jose.NewSignedJWT(claims, signer)
	if err != nil {
		panic(fmt.Sprintf("could not make token: %v", err))
	}
	return jwt.Encode()
}
예제 #7
0
// generateToken returns an encoded JWT signed with the provider's private key.
// The token holds the claims built from iss, sub, aud, iat and exp plus one
// extra claim named usernameClaim with the given value. The test is failed
// immediately if the token cannot be signed.
func (op *oidcProvider) generateToken(t *testing.T, iss, sub, aud string, usernameClaim, value string, iat, exp time.Time) string {
	keySigner := op.privKey.Signer()

	tokenClaims := oidc.NewClaims(iss, sub, aud, iat, exp)
	tokenClaims.Add(usernameClaim, value)

	signed, err := jose.NewSignedJWT(tokenClaims, keySigner)
	if err != nil {
		// Fatalf stops the calling test goroutine, so nothing below runs.
		t.Fatalf("Cannot generate token: %v", err)
	}
	return signed.Encode()
}
예제 #8
0
// Token signs the EmailVerification's claims with signer and returns the
// encoded JWT. It returns an error if signer is nil or if signing fails; the
// returned string is empty in both cases.
func (e EmailVerification) Token(signer jose.Signer) (string, error) {
	// Reject a missing signer up front instead of handing nil to the JWT library.
	if signer == nil {
		return "", errors.New("no signer")
	}

	signed, signErr := jose.NewSignedJWT(e.claims, signer)
	if signErr != nil {
		return "", signErr
	}
	encoded := signed.Encode()
	return encoded, nil
}
예제 #9
0
파일: password.go 프로젝트: no2key/dex
// Token signs the PasswordReset's claims with signer and returns the encoded
// JWT. A nil signer or a signing failure yields an empty string and a non-nil
// error.
func (e PasswordReset) Token(signer jose.Signer) (string, error) {
	var encoded string

	if signer == nil {
		return encoded, errors.New("no signer")
	}

	resetJWT, err := jose.NewSignedJWT(e.claims, signer)
	if err != nil {
		return encoded, err
	}

	encoded = resetJWT.Encode()
	return encoded, nil
}
예제 #10
0
파일: server.go 프로젝트: derekparker/dex
// RefreshToken exchanges a refresh token for a new signed ID token.
//
// The client is authenticated first, then the refresh token is verified
// against the client ID to recover the user it was issued for. The new token
// is issued to that user with creds.ID as audience and is valid from now for
// session.DefaultSessionValidityWindow.
//
// Errors are always oauth2 errors:
//   - ErrorInvalidClient when the client is not authenticated or the refresh
//     token belongs to a different client,
//   - ErrorInvalidRequest when the refresh token itself is invalid,
//   - ErrorServerError for any other failure.
func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose.JWT, error) {
	authenticated, err := s.ClientIdentityRepo.Authenticate(creds)
	if err != nil {
		log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
		return nil, oauth2.NewError(oauth2.ErrorServerError)
	}
	if !authenticated {
		log.Errorf("Failed to Authenticate client %s", creds.ID)
		return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
	}

	userID, err := s.RefreshTokenRepo.Verify(creds.ID, token)
	if err != nil {
		// Map the known verification failures onto specific oauth2 errors;
		// anything else is treated as a server-side problem.
		switch err {
		case refresh.ErrorInvalidToken:
			return nil, oauth2.NewError(oauth2.ErrorInvalidRequest)
		case refresh.ErrorInvalidClientID:
			return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
		}
		return nil, oauth2.NewError(oauth2.ErrorServerError)
	}

	account, err := s.UserRepo.Get(nil, userID)
	if err != nil {
		// The error can be user.ErrorNotFound, but we are not deleting
		// user at this moment, so this shouldn't happen.
		log.Errorf("Failed to fetch user %q from repo: %v: ", userID, err)
		return nil, oauth2.NewError(oauth2.ErrorServerError)
	}

	signer, err := s.KeyManager.Signer()
	if err != nil {
		log.Errorf("Failed to refresh ID token: %v", err)
		return nil, oauth2.NewError(oauth2.ErrorServerError)
	}

	issuedAt := time.Now()
	claims := oidc.NewClaims(s.IssuerURL.String(), account.ID, creds.ID,
		issuedAt, issuedAt.Add(session.DefaultSessionValidityWindow))
	account.AddToClaims(claims)

	idToken, err := jose.NewSignedJWT(claims, signer)
	if err != nil {
		log.Errorf("Failed to generate ID token: %v", err)
		return nil, oauth2.NewError(oauth2.ErrorServerError)
	}

	log.Infof("New token sent: clientID=%s", creds.ID)
	return idToken, nil
}
예제 #11
0
파일: email.go 프로젝트: GamerockSA/dex
// signerUnavailableError is returned by signedClaimsToken when the signer
// function reports no error but also yields no signer.
type signerUnavailableError struct{}

// Error implements the error interface.
func (signerUnavailableError) Error() string { return "no signer available" }

// signedClaimsToken signs claims with the signer obtained from u.signerFn and
// returns the encoded JWT.
//
// It returns an empty string and a non-nil error when the signer cannot be
// obtained, when no signer is available, or when signing fails. Each failure
// is also logged.
func (u *UserEmailer) signedClaimsToken(claims jose.Claims) (string, error) {
	signer, err := u.signerFn()
	if err == nil && signer == nil {
		// Without this, a missing signer would return ("", nil) and callers
		// would embed an empty token as if signing had succeeded.
		err = signerUnavailableError{}
	}
	if err != nil {
		log.Errorf("error getting signer: %v (%v)", err, signer)
		return "", err
	}

	jwt, err := jose.NewSignedJWT(claims, signer)
	if err != nil {
		log.Errorf("error constructing or signing a JWT: %v", err)
		return "", err
	}
	return jwt.Encode(), nil
}
예제 #12
0
// generateToken returns an encoded JWT signed with op's private key. The token
// holds the claims built from iss, sub, aud, iat and exp, a claim named
// usernameClaim set to value, and, only when both groupsClaim is non-empty and
// groups is non-nil, a claim named groupsClaim set to groups. The test is
// failed immediately if the token cannot be signed.
func generateToken(t *testing.T, op *oidctesting.OIDCProvider, iss, sub, aud string, usernameClaim, value, groupsClaim string, groups interface{}, iat, exp time.Time) string {
	keySigner := op.PrivKey.Signer()

	tokenClaims := oidc.NewClaims(iss, sub, aud, iat, exp)
	tokenClaims.Add(usernameClaim, value)
	if groupsClaim != "" && groups != nil {
		tokenClaims.Add(groupsClaim, groups)
	}

	signed, err := jose.NewSignedJWT(tokenClaims, keySigner)
	if err != nil {
		// Fatalf stops the calling test goroutine, so nothing below runs.
		t.Fatalf("Cannot generate token: %v", err)
	}
	return signed.Encode()
}
예제 #13
0
// createJWTToken builds a JWT for the given subject and issuer URL that expires
// tokenttl from now.
//
// The claims always contain "sub", "iat", "exp", "iss" and "aud" (set to
// testProviderAud). An "email" claim is added when scopesMap has an "email"
// key, and a "nickname" claim when it has a "profile" key.
//
// When unsignedToken is false the JWT is signed with the RSA signer built from
// the private key. When it is true the JWT is created without a signature, but
// its header still carries the signer's "alg" and "kid".
//
// An HTTP 500 error is returned if the private key cannot be obtained; errors
// from building or signing the JWT are returned unchanged.
func createJWTToken(subject string, issuerUrl string, tokenttl time.Duration, scopesMap map[string]struct{}, unsignedToken bool) (jwt *jose.JWT, err error) {
	rsaKey, err := privateKey()
	if err != nil {
		return nil, base.HTTPErrorf(http.StatusInternalServerError, "Error getting private RSA Key")
	}

	issuedAt := time.Now()
	claims := jose.Claims{
		"sub": subject,
		"iat": issuedAt.Unix(),
		"exp": issuedAt.Add(tokenttl).Unix(),
		"iss": issuerUrl,
		"aud": testProviderAud,
	}
	if _, wantsEmail := scopesMap["email"]; wantsEmail {
		claims["email"] = subject + "@syncgatewayoidctesting.com"
	}
	if _, wantsProfile := scopesMap["profile"]; wantsProfile {
		claims["nickname"] = "slim jim"
	}

	signer := jose.NewSignerRSA(testProviderKeyIdentifier, *rsaKey)

	if unsignedToken {
		// Keep the header a signed token would have, but skip the signature.
		header := jose.JOSEHeader{
			"alg": signer.Alg(),
			"kid": signer.ID(),
		}
		unsigned, err := jose.NewJWT(header, claims)
		if err != nil {
			return nil, err
		}
		return &unsigned, nil
	}

	signed, err := jose.NewSignedJWT(claims, signer)
	if err != nil {
		return nil, err
	}
	return signed, nil
}
예제 #14
0
파일: email.go 프로젝트: GamerockSA/dex
// verifySignerUnavailableError is returned by SendEmailVerification when the
// signer function reports no error but also yields no signer.
type verifySignerUnavailableError struct{}

// Error implements the error interface.
func (verifySignerUnavailableError) Error() string { return "no signer available" }

// SendEmailVerification sends an email to the user with the given userID. The
// email contains a link which, when visited, marks the user as having had
// their email verified.
//
// If no emailer is configured, the URL of the aforementioned link is returned;
// otherwise a nil URL is returned together with the result of sending the mail.
//
// A non-nil error is returned when the user cannot be loaded
// (user.ErrorNotFound is passed through unchanged), when no signer is
// available, when signing fails, or when the emailer fails to send.
func (u *UserEmailer) SendEmailVerification(userID, clientID string, redirectURL url.URL) (*url.URL, error) {
	usr, err := u.ur.Get(nil, userID)
	if err == user.ErrorNotFound {
		log.Errorf("No Such user for ID: %q", userID)
		return nil, err
	}
	if err != nil {
		log.Errorf("Error getting user: %q", err)
		return nil, err
	}

	ev := user.NewEmailVerification(usr, clientID, u.issuerURL, redirectURL, u.tokenValidityWindow)

	signer, err := u.signerFn()
	if err == nil && signer == nil {
		// A missing signer without an error must not be reported as success:
		// a (nil, nil) return means "email sent" to callers.
		err = verifySignerUnavailableError{}
	}
	if err != nil {
		log.Errorf("error getting signer: %v (signer: %v)", err, signer)
		return nil, err
	}

	jwt, err := jose.NewSignedJWT(ev.Claims, signer)
	if err != nil {
		log.Errorf("error constructing or signing EmailVerification JWT: %v", err)
		return nil, err
	}
	token := jwt.Encode()

	// Work on a copy of the configured URL so the token is never written back
	// into the emailer's own state.
	verifyURL := u.verifyEmailURL
	q := verifyURL.Query()
	q.Set("token", token)
	verifyURL.RawQuery = q.Encode()

	if u.emailer == nil {
		return &verifyURL, nil
	}

	err = u.emailer.SendMail(u.fromAddress, "Please verify your email address.", "verify-email",
		map[string]interface{}{
			"email": usr.Email,
			"link":  verifyURL.String(),
		}, usr.Email)
	if err != nil {
		log.Errorf("error sending email verification email: %v", err)
	}
	return nil, err
}
예제 #15
0
// TestHandleVerifyEmailResend drives the verify-email-resend handler with a
// bearer token identifying the client and a JSON body carrying the user's
// token and a redirect URI, then checks the HTTP status code of the response.
//
// Each case runs against fresh fixtures. Setup failures abort the test right
// away instead of being reported after the broken value has already been used.
func TestHandleVerifyEmailResend(t *testing.T) {
	now := time.Now()
	tomorrow := now.Add(24 * time.Hour)
	yesterday := now.Add(-24 * time.Hour)

	privKey, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("Failed to generate private key, error=%v", err)
	}

	signer := privKey.Signer()

	pubKey := *key.NewPublicKey(privKey.JWK())
	keysFunc := func() ([]key.PublicKey, error) {
		return []key.PublicKey{pubKey}, nil
	}

	makeToken := func(iss, sub, aud string, iat, exp time.Time) string {
		claims := oidc.NewClaims(iss, sub, aud, iat, exp)
		jwt, err := jose.NewSignedJWT(claims, signer)
		if err != nil {
			t.Fatalf("Failed to generate JWT, error=%v", err)
		}
		return jwt.Encode()
	}

	issuer := testIssuerURL.String()

	tests := []struct {
		bearerJWT   string
		userJWT     string
		redirectURL url.URL
		wantCode    int
		// verifyEmailUserID, when set, names a user whose email is marked as
		// verified before the request is made.
		verifyEmailUserID string
	}{
		{
			// The happy case
			bearerJWT:   makeToken(issuer, testClientID, testClientID, now, tomorrow),
			userJWT:     makeToken(issuer, "ID-1", testClientID, now, tomorrow),
			redirectURL: testRedirectURL,
			wantCode:    http.StatusOK,
		},
		{
			// Already verified
			bearerJWT:         makeToken(issuer, testClientID, testClientID, now, tomorrow),
			userJWT:           makeToken(issuer, "ID-1", testClientID, now, tomorrow),
			redirectURL:       testRedirectURL,
			wantCode:          http.StatusBadRequest,
			verifyEmailUserID: "ID-1",
		},
		{
			// Expired userJWT
			bearerJWT:   makeToken(issuer, testClientID, testClientID, now, tomorrow),
			userJWT:     makeToken(issuer, "ID-1", testClientID, now, yesterday),
			redirectURL: testRedirectURL,
			wantCode:    http.StatusUnauthorized,
		},
		{
			// Client ID is unknown
			bearerJWT:   makeToken(issuer, "fakeclientid", testClientID, now, tomorrow),
			userJWT:     makeToken(issuer, "ID-1", testClientID, now, tomorrow),
			redirectURL: testRedirectURL,
			wantCode:    http.StatusBadRequest,
		},
		{
			// No sub in user JWT
			bearerJWT:   makeToken(issuer, testClientID, testClientID, now, tomorrow),
			userJWT:     makeToken(issuer, "", testClientID, now, tomorrow),
			redirectURL: testRedirectURL,
			wantCode:    http.StatusBadRequest,
		},
		{
			// Unknown user
			bearerJWT:   makeToken(issuer, testClientID, testClientID, now, tomorrow),
			userJWT:     makeToken(issuer, "NonExistent", testClientID, now, tomorrow),
			redirectURL: testRedirectURL,
			wantCode:    http.StatusBadRequest,
		},
		{
			// No redirect URL
			bearerJWT:   makeToken(issuer, testClientID, testClientID, now, tomorrow),
			userJWT:     makeToken(issuer, "ID-1", testClientID, now, tomorrow),
			redirectURL: url.URL{},
			wantCode:    http.StatusBadRequest,
		},
	}

	for i, tt := range tests {
		f, err := makeTestFixtures()
		if err != nil {
			// Check before touching f: it is not usable when setup failed.
			t.Fatalf("case %d: could not make test fixtures: %v", i, err)
		}

		if tt.verifyEmailUserID != "" {
			usr, err := f.userRepo.Get(nil, tt.verifyEmailUserID)
			if err != nil {
				t.Fatalf("case %d: could not get user %q: %v", i, tt.verifyEmailUserID, err)
			}
			usr.EmailVerified = true
			if err := f.userRepo.Update(nil, usr); err != nil {
				t.Fatalf("case %d: could not update user %q: %v", i, tt.verifyEmailUserID, err)
			}
		}

		hdlr := handleVerifyEmailResendFunc(
			testIssuerURL,
			keysFunc,
			f.srv.UserEmailer,
			f.userRepo,
			f.clientManager)

		w := httptest.NewRecorder()
		u := "http://example.com"
		q := struct {
			Token       string `json:"token"`
			RedirectURI string `json:"redirectURI"`
		}{
			Token:       tt.userJWT,
			RedirectURI: tt.redirectURL.String(),
		}
		qBytes, err := json.Marshal(&q)
		if err != nil {
			t.Fatalf("case %d: unable to marshal JSON: %q", i, err)
		}

		req, err := http.NewRequest("POST", u, bytes.NewReader(qBytes))
		if err != nil {
			// req is nil on error, so this must be checked before setting headers.
			t.Fatalf("case %d: unable to form HTTP request: %v", i, err)
		}
		req.Header.Set("Authorization", "Bearer "+tt.bearerJWT)

		hdlr.ServeHTTP(w, req)
		if tt.wantCode != w.Code {
			t.Errorf("case %d: wantCode=%v, got=%v", i, tt.wantCode, w.Code)
			t.Logf("case %d: response body was: %v", i, w.Body.String())
		}
	}
}
예제 #16
0
// TestVerifyJWTExpiry checks verifyJWTExpiry against a token that expires in
// the future, an empty JWT, a token that expired an hour ago, and a token
// whose expiry equals the supplied "now".
func TestVerifyJWTExpiry(t *testing.T) {
	privKey, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("can't generate private key: %v", err)
	}

	// makeToken signs a JWT with "test", "exp" and "count" claims.
	makeToken := func(s string, exp time.Time, count int) *jose.JWT {
		claims := jose.Claims(map[string]interface{}{
			"test":  s,
			"exp":   exp.UTC().Unix(),
			"count": count,
		})
		signed, err := jose.NewSignedJWT(claims, privKey.Signer())
		if err != nil {
			t.Fatalf("Could not create signed JWT %v", err)
		}
		return signed
	}

	t0 := time.Now()

	tests := []struct {
		name        string
		jwt         *jose.JWT
		now         time.Time
		wantErr     bool
		wantExpired bool
	}{
		{name: "valid jwt", jwt: makeToken("foo", t0.Add(time.Hour), 1), now: t0},
		{name: "invalid jwt", jwt: &jose.JWT{}, now: t0, wantErr: true},
		{name: "expired jwt", jwt: makeToken("foo", t0.Add(-time.Hour), 1), now: t0, wantExpired: true},
		{
			name:        "jwt expires soon enough to be marked expired",
			jwt:         makeToken("foo", t0, 1),
			now:         t0,
			wantExpired: true,
		},
	}

	for _, tc := range tests {
		valid, err := verifyJWTExpiry(tc.now, tc.jwt.Encode())
		switch {
		case err != nil:
			// An error ends the case; it is only a failure when unexpected.
			if !tc.wantErr {
				t.Errorf("%s: %v", tc.name, err)
			}
		case tc.wantErr:
			t.Errorf("%s: expected error", tc.name)
		case valid && tc.wantExpired:
			t.Errorf("%s: expected token to be expired", tc.name)
		case !valid && !tc.wantExpired:
			t.Errorf("%s: expected token to be valid", tc.name)
		}
	}
}
예제 #17
0
// TestNewOIDCAuthProvider covers construction of the OIDC auth provider and its
// first ID token lookup.
//
// Each case supplies a configuration map. Cases with wantInitErr expect
// newOIDCAuthProvider to fail. The remaining cases install the case's client,
// pin the provider's clock to t0, request an ID token and then compare the
// provider's resulting configuration with wantCfg.
func TestNewOIDCAuthProvider(t *testing.T) {
	tempDir, err := ioutil.TempDir(os.TempDir(), "oidc_test")
	if err != nil {
		t.Fatalf("Cannot make temp dir %v", err)
	}
	defer os.RemoveAll(tempDir)
	cert := path.Join(tempDir, "oidc-cert")
	key := path.Join(tempDir, "oidc-key")

	oidctesting.GenerateSelfSignedCert(t, "127.0.0.1", cert, key)
	op := oidctesting.NewOIDCProvider(t, "")
	srv, err := op.ServeTLSWithKeyPair(cert, key)
	if err != nil {
		t.Fatalf("Cannot start server %v", err)
	}
	defer srv.Close()

	certData, err := ioutil.ReadFile(cert)
	if err != nil {
		t.Fatalf("Could not read cert bytes %v", err)
	}

	makeToken := func(exp time.Time) *jose.JWT {
		jwt, err := jose.NewSignedJWT(jose.Claims(map[string]interface{}{
			"exp": exp.UTC().Unix(),
		}), op.PrivKey.Signer())
		if err != nil {
			t.Fatalf("Could not create signed JWT %v", err)
		}
		return jwt
	}

	t0 := time.Now()

	goodToken := makeToken(t0.Add(time.Hour)).Encode()
	expiredToken := makeToken(t0.Add(-time.Hour)).Encode()

	// validCfg returns a fresh map holding the four settings every valid
	// configuration in this test shares, with extra merged on top. A new map is
	// built on every call so no two cases (or cfg/wantCfg pairs) share storage.
	validCfg := func(extra map[string]string) map[string]string {
		cfg := map[string]string{
			cfgIssuerUrl:            srv.URL,
			cfgCertificateAuthority: cert,
			cfgClientID:             "client-id",
			cfgClientSecret:         "client-secret",
		}
		for k, v := range extra {
			cfg[k] = v
		}
		return cfg
	}

	tests := []struct {
		name string

		cfg         map[string]string
		wantInitErr bool

		client       OIDCClient
		wantCfg      map[string]string
		wantTokenErr bool
	}{
		{
			// A Valid configuration
			name:         "no id token and no refresh token",
			cfg:          validCfg(nil),
			wantTokenErr: true,
		},
		{
			name:    "valid config with an initial token",
			cfg:     validCfg(map[string]string{cfgIDToken: goodToken}),
			client:  new(noRefreshOIDCClient),
			wantCfg: validCfg(map[string]string{cfgIDToken: goodToken}),
		},
		{
			name: "invalid ID token with a refresh token",
			cfg: validCfg(map[string]string{
				cfgRefreshToken: "foo",
				cfgIDToken:      expiredToken,
			}),
			client: &mockOIDCClient{
				tokenResponse: oauth2.TokenResponse{
					IDToken: goodToken,
				},
			},
			wantCfg: validCfg(map[string]string{
				cfgRefreshToken: "foo",
				cfgIDToken:      goodToken,
			}),
		},
		{
			name: "invalid ID token with a refresh token, server returns new refresh token",
			cfg: validCfg(map[string]string{
				cfgRefreshToken: "foo",
				cfgIDToken:      expiredToken,
			}),
			client: &mockOIDCClient{
				tokenResponse: oauth2.TokenResponse{
					IDToken:      goodToken,
					RefreshToken: "bar",
				},
			},
			wantCfg: validCfg(map[string]string{
				cfgRefreshToken: "bar",
				cfgIDToken:      goodToken,
			}),
		},
		{
			name:         "expired token and no refresh otken",
			cfg:          validCfg(map[string]string{cfgIDToken: expiredToken}),
			wantTokenErr: true,
		},
		{
			name: "valid base64d ca",
			cfg: map[string]string{
				cfgIssuerUrl:                srv.URL,
				cfgCertificateAuthorityData: base64.StdEncoding.EncodeToString(certData),
				cfgClientID:                 "client-id",
				cfgClientSecret:             "client-secret",
			},
			client:       new(noRefreshOIDCClient),
			wantTokenErr: true,
		},
		{
			name: "missing client ID",
			cfg: map[string]string{
				cfgIssuerUrl:            srv.URL,
				cfgCertificateAuthority: cert,
				cfgClientSecret:         "client-secret",
			},
			wantInitErr: true,
		},
		{
			name: "missing client secret",
			cfg: map[string]string{
				cfgIssuerUrl:            srv.URL,
				cfgCertificateAuthority: cert,
				cfgClientID:             "client-id",
			},
			wantInitErr: true,
		},
		{
			name: "missing issuer URL",
			cfg: map[string]string{
				cfgCertificateAuthority: cert,
				cfgClientID:             "client-id",
				cfgClientSecret:         "secret",
			},
			wantInitErr: true,
		},
		{
			name: "missing TLS config",
			cfg: map[string]string{
				cfgIssuerUrl:    srv.URL,
				cfgClientID:     "client-id",
				cfgClientSecret: "secret",
			},
			wantInitErr: true,
		},
	}

	for _, tt := range tests {
		clearCache()

		p, err := newOIDCAuthProvider("cluster.example.com", tt.cfg, new(persister))
		if tt.wantInitErr {
			if err == nil {
				t.Errorf("%s: want non-nil err", tt.name)
			}
			continue
		}

		if err != nil {
			t.Errorf("%s: unexpected error on newOIDCAuthProvider: %v", tt.name, err)
			continue
		}

		provider := p.(*oidcAuthProvider)
		provider.client = tt.client
		provider.now = func() time.Time { return t0 }

		_, err = provider.idToken()
		if err != nil {
			if !tt.wantTokenErr {
				t.Errorf("%s: failed to get id token: %v", tt.name, err)
			}
			continue
		}
		if tt.wantTokenErr {
			// There is no error to report here: the lookup succeeded when it
			// was expected to fail.
			t.Errorf("%s: expected to not get id token", tt.name)
			continue
		}

		if !reflect.DeepEqual(tt.wantCfg, provider.cfg) {
			t.Errorf("%s: expected config %#v got %#v", tt.name, tt.wantCfg, provider.cfg)
		}
	}
}
예제 #18
0
파일: server.go 프로젝트: GamerockSA/dex
// CodeToken exchanges a session key for a signed ID token and, when the
// session's scope contains "offline_access", a refresh token.
//
// The client is authenticated, the session key is exchanged for a session ID,
// and that session is killed before its client ID is compared with creds.ID.
// The token's claims come from the session, the user record and
// s.addClaimsFromScope.
//
// Errors are always oauth2 errors:
//   - ErrorInvalidClient when the client is not authenticated,
//   - ErrorInvalidGrant when the key cannot be exchanged or the session
//     belongs to another client,
//   - ErrorInvalidRequest when the session cannot be killed,
//   - ErrorServerError for any other failure.
func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, error) {
	authenticated, err := s.ClientManager.Authenticate(creds)
	if err != nil {
		log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
		return nil, "", oauth2.NewError(oauth2.ErrorServerError)
	}
	if !authenticated {
		log.Errorf("Failed to Authenticate client %s", creds.ID)
		return nil, "", oauth2.NewError(oauth2.ErrorInvalidClient)
	}

	sessionID, err := s.SessionManager.ExchangeKey(sessionKey)
	if err != nil {
		return nil, "", oauth2.NewError(oauth2.ErrorInvalidGrant)
	}

	ses, err := s.SessionManager.Kill(sessionID)
	if err != nil {
		return nil, "", oauth2.NewError(oauth2.ErrorInvalidRequest)
	}
	if ses.ClientID != creds.ID {
		return nil, "", oauth2.NewError(oauth2.ErrorInvalidGrant)
	}

	signer, err := s.KeyManager.Signer()
	if err != nil {
		log.Errorf("Failed to generate ID token: %v", err)
		return nil, "", oauth2.NewError(oauth2.ErrorServerError)
	}

	account, err := s.UserRepo.Get(nil, ses.UserID)
	if err != nil {
		log.Errorf("Failed to fetch user %q from repo: %v: ", ses.UserID, err)
		return nil, "", oauth2.NewError(oauth2.ErrorServerError)
	}

	claims := ses.Claims(s.IssuerURL.String())
	account.AddToClaims(claims)
	s.addClaimsFromScope(claims, ses.Scope, ses.ClientID)

	idToken, err := jose.NewSignedJWT(claims, signer)
	if err != nil {
		log.Errorf("Failed to generate ID token: %v", err)
		return nil, "", oauth2.NewError(oauth2.ErrorServerError)
	}

	// Generate refresh token when 'scope' contains 'offline_access'.
	var refreshToken string
	for _, scope := range ses.Scope {
		if scope != "offline_access" {
			continue
		}
		log.Infof("Session %s requests offline access, will generate refresh token", sessionID)

		refreshToken, err = s.RefreshTokenRepo.Create(ses.UserID, creds.ID, ses.Scope)
		if err != nil {
			log.Errorf("Failed to generate refresh token: %v", err)
			return nil, "", oauth2.NewError(oauth2.ErrorServerError)
		}
		// One refresh token is enough even if the scope is listed twice.
		break
	}

	log.Infof("Session %s token sent: clientID=%s", sessionID, creds.ID)
	return idToken, refreshToken, nil
}
예제 #19
0
// TestJWTVerifier runs JWTVerifier.Verify over a table of tokens. Every
// verifier is configured with issuer "example.com" and client ID "XXX"; the
// cases vary the key source, the sync function and the token's claims and
// signing key, and only check whether Verify returns an error.
func TestJWTVerifier(t *testing.T) {
	iss := "http://example.com"
	now := time.Now()
	future12 := now.Add(12 * time.Hour)
	past36 := now.Add(-36 * time.Hour)
	past12 := now.Add(-12 * time.Hour)

	priv1, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("failed to generate private key, error=%v", err)
	}
	pk1 := *key.NewPublicKey(priv1.JWK())

	priv2, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("failed to generate private key, error=%v", err)
	}
	pk2 := *key.NewPublicKey(priv2.JWK())

	newJWT := func(issuer, subject string, aud interface{}, issuedAt, exp time.Time, signer jose.Signer) jose.JWT {
		signed, err := jose.NewSignedJWT(NewClaims(issuer, subject, aud, issuedAt, exp), signer)
		if err != nil {
			t.Fatal(err)
		}
		return *signed
	}

	// noSync is a sync function that always succeeds and does nothing.
	noSync := func() error { return nil }

	// fixedKeys returns a key function that yields a fresh, non-nil slice
	// holding the same keys on every call.
	fixedKeys := func(keys ...key.PublicKey) func() []key.PublicKey {
		return func() []key.PublicKey {
			return append([]key.PublicKey{}, keys...)
		}
	}

	// keySequence returns a key function that yields sets[0] on its first
	// call, sets[1] on its second, and so on. Calling it more often than there
	// are sets panics with an index-out-of-range error.
	keySequence := func(sets ...[]key.PublicKey) func() []key.PublicKey {
		var calls int
		return func() []key.PublicKey {
			current := sets[calls]
			calls++
			return current
		}
	}

	// newVerifier builds the verifier shared by all cases, varying only the
	// sync and key functions.
	newVerifier := func(syncFunc func() error, keysFunc func() []key.PublicKey) JWTVerifier {
		return JWTVerifier{
			issuer:   "example.com",
			clientID: "XXX",
			syncFunc: syncFunc,
			keysFunc: keysFunc,
		}
	}

	tests := []struct {
		name     string
		verifier JWTVerifier
		jwt      jose.JWT
		wantErr  bool
	}{
		{
			name:     "JWT signed with available key",
			verifier: newVerifier(noSync, fixedKeys(pk1)),
			jwt:      newJWT(iss, "XXX", "XXX", past12, future12, priv1.Signer()),
			wantErr:  false,
		},
		{
			name:     "JWT signed with available key, with bad claims",
			verifier: newVerifier(noSync, fixedKeys(pk1)),
			jwt:      newJWT(iss, "XXX", "YYY", past12, future12, priv1.Signer()),
			wantErr:  true,
		},
		{
			name:     "JWT signed with available key",
			verifier: newVerifier(noSync, fixedKeys(pk1)),
			jwt:      newJWT(iss, "XXX", []string{"YYY", "ZZZ"}, past12, future12, priv1.Signer()),
			wantErr:  true,
		},
		{
			name:     "expired JWT signed with available key",
			verifier: newVerifier(noSync, fixedKeys(pk1)),
			jwt:      newJWT(iss, "XXX", "XXX", past36, past12, priv1.Signer()),
			wantErr:  true,
		},
		{
			name: "JWT signed with unrecognized key, verifiable after sync",
			verifier: newVerifier(noSync, keySequence(
				[]key.PublicKey{pk1},
				[]key.PublicKey{pk2},
			)),
			jwt:     newJWT(iss, "XXX", "XXX", past36, future12, priv2.Signer()),
			wantErr: false,
		},
		{
			name:     "JWT signed with unrecognized key, not verifiable after sync",
			verifier: newVerifier(noSync, fixedKeys(pk1)),
			jwt:      newJWT(iss, "XXX", "XXX", past12, future12, priv2.Signer()),
			wantErr:  true,
		},
		{
			name:     "verifier gets no keys from keysFunc, still not verifiable after sync",
			verifier: newVerifier(noSync, fixedKeys()),
			jwt:      newJWT(iss, "XXX", "XXX", past12, future12, priv1.Signer()),
			wantErr:  true,
		},
		{
			name: "verifier gets no keys from keysFunc, verifiable after sync",
			verifier: newVerifier(noSync, keySequence(
				[]key.PublicKey{},
				[]key.PublicKey{pk2},
			)),
			jwt:     newJWT(iss, "XXX", "XXX", past12, future12, priv2.Signer()),
			wantErr: false,
		},
		{
			name:     "JWT signed with available key, 'aud' is a string array",
			verifier: newVerifier(noSync, fixedKeys(pk1)),
			jwt:      newJWT(iss, "XXX", []string{"ZZZ", "XXX"}, past12, future12, priv1.Signer()),
			wantErr:  false,
		},
		{
			name: "invalid issuer claim shouldn't trigger sync",
			verifier: newVerifier(
				func() error {
					t.Errorf("invalid issuer claim shouldn't trigger a sync")
					return nil
				},
				keySequence(
					[]key.PublicKey{},
					[]key.PublicKey{pk2},
				),
			),
			jwt:     newJWT("invalid-issuer", "XXX", []string{"ZZZ", "XXX"}, past12, future12, priv2.Signer()),
			wantErr: true,
		},
	}

	for _, tt := range tests {
		err := tt.verifier.Verify(tt.jwt)
		switch {
		case tt.wantErr && err == nil:
			t.Errorf("case %q: wanted non-nil error", tt.name)
		case !tt.wantErr && err != nil:
			t.Errorf("case %q: wanted nil error, got %v", tt.name, err)
		}
	}
}
예제 #20
0
// TestWrapTranport drives the round tripper returned by
// oidcAuthProvider.WrapTransport through a table of scenarios. Each case
// seeds the provider config with an optional ID token and refresh token, then
// checks the refresh calls, the outgoing bearer tokens, the persisted
// configuration and the final HTTP status (or error) against expectations.
func TestWrapTranport(t *testing.T) {
	// Run with a tiny backoff (1ns, 3 steps) and restore the package-level
	// value when the test returns.
	oldBackoff := backoff
	defer func() {
		backoff = oldBackoff
	}()
	backoff = wait.Backoff{
		Duration: 1 * time.Nanosecond,
		Steps:    3,
	}

	privKey, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("can't generate private key: %v", err)
	}

	// makeToken signs a JWT with the test key. count only serves to make two
	// otherwise identical tokens encode differently.
	makeToken := func(s string, exp time.Time, count int) *jose.JWT {
		jwt, err := jose.NewSignedJWT(jose.Claims(map[string]interface{}{
			"test":  s,
			"exp":   exp.UTC().Unix(),
			"count": count,
		}), privKey.Signer())
		if err != nil {
			t.Fatalf("Could not create signed JWT %v", err)
		}
		return jwt
	}

	goodToken := makeToken("good", time.Now().Add(time.Hour), 0)
	goodToken2 := makeToken("good", time.Now().Add(time.Hour), 1)
	expiredToken := makeToken("good", time.Now().Add(-time.Hour), 0)

	// str returns a pointer to s so that "no refresh token" (nil) can be told
	// apart from a configured one.
	str := func(s string) *string {
		return &s
	}
	tests := []struct {
		// Initial provider configuration; nil means the entry is absent.
		cfgIDToken      *jose.JWT
		cfgRefreshToken *string

		// Round trips the fake transport expects, in order.
		expectRequests []testRoundTrip

		// Refresh calls the fake OIDC client expects, in order.
		expectRefreshes []testRefresh

		// Configurations the fake persister expects, in order.
		expectPersists []testPersist

		wantStatus int
		wantErr    bool
	}{
		{
			// Initial JWT is set, it is good, it is set as bearer.
			cfgIDToken: goodToken,

			expectRequests: []testRoundTrip{
				{
					expectBearerToken: goodToken.Encode(),
					returnHTTPStatus:  200,
				},
			},

			wantStatus: 200,
		},
		{
			// Initial JWT is set, but it's expired, so it gets refreshed.
			cfgIDToken:      expiredToken,
			cfgRefreshToken: str("rt1"),

			expectRefreshes: []testRefresh{
				{
					expectRefreshToken: "rt1",
					returnTokens: oauth2.TokenResponse{
						IDToken: goodToken.Encode(),
					},
				},
			},

			expectRequests: []testRoundTrip{
				{
					expectBearerToken: goodToken.Encode(),
					returnHTTPStatus:  200,
				},
			},

			expectPersists: []testPersist{
				{
					cfg: map[string]string{
						cfgIDToken:      goodToken.Encode(),
						cfgRefreshToken: "rt1",
					},
				},
			},

			wantStatus: 200,
		},
		{
			// Initial JWT is set, but it's expired, so it gets refreshed - this
			// time the refresh token itself is also refreshed
			cfgIDToken:      expiredToken,
			cfgRefreshToken: str("rt1"),

			expectRefreshes: []testRefresh{
				{
					expectRefreshToken: "rt1",
					returnTokens: oauth2.TokenResponse{
						IDToken:      goodToken.Encode(),
						RefreshToken: "rt2",
					},
				},
			},

			expectRequests: []testRoundTrip{
				{
					expectBearerToken: goodToken.Encode(),
					returnHTTPStatus:  200,
				},
			},

			expectPersists: []testPersist{
				{
					cfg: map[string]string{
						cfgIDToken:      goodToken.Encode(),
						cfgRefreshToken: "rt2",
					},
				},
			},

			wantStatus: 200,
		},
		{
			// Initial JWT is not set, so it gets refreshed.
			cfgRefreshToken: str("rt1"),

			expectRefreshes: []testRefresh{
				{
					expectRefreshToken: "rt1",
					returnTokens: oauth2.TokenResponse{
						IDToken: goodToken.Encode(),
					},
				},
			},

			expectRequests: []testRoundTrip{
				{
					expectBearerToken: goodToken.Encode(),
					returnHTTPStatus:  200,
				},
			},

			expectPersists: []testPersist{
				{
					cfg: map[string]string{
						cfgIDToken:      goodToken.Encode(),
						cfgRefreshToken: "rt1",
					},
				},
			},

			wantStatus: 200,
		},
		{
			// Expired token, but no refresh token.
			cfgIDToken: expiredToken,

			wantErr: true,
		},
		{
			// Initial JWT is not set, so it gets refreshed, but the server
			// rejects it when it is used, so it refreshes again, which
			// succeeds.
			cfgRefreshToken: str("rt1"),

			expectRefreshes: []testRefresh{
				{
					expectRefreshToken: "rt1",
					returnTokens: oauth2.TokenResponse{
						IDToken: goodToken.Encode(),
					},
				},
				{
					expectRefreshToken: "rt1",
					returnTokens: oauth2.TokenResponse{
						IDToken: goodToken2.Encode(),
					},
				},
			},

			expectRequests: []testRoundTrip{
				{
					expectBearerToken: goodToken.Encode(),
					returnHTTPStatus:  http.StatusUnauthorized,
				},
				{
					expectBearerToken: goodToken2.Encode(),
					returnHTTPStatus:  http.StatusOK,
				},
			},

			expectPersists: []testPersist{
				{
					cfg: map[string]string{
						cfgIDToken:      goodToken.Encode(),
						cfgRefreshToken: "rt1",
					},
				},
				{
					cfg: map[string]string{
						cfgIDToken:      goodToken2.Encode(),
						cfgRefreshToken: "rt1",
					},
				},
			},

			wantStatus: 200,
		},
		{
			// Initial JWT is set and not expired, but the server rejects it
			// when it is used, so it gets refreshed, which succeeds.
			cfgRefreshToken: str("rt1"),
			cfgIDToken:      goodToken,

			expectRefreshes: []testRefresh{
				{
					expectRefreshToken: "rt1",
					returnTokens: oauth2.TokenResponse{
						IDToken: goodToken2.Encode(),
					},
				},
			},

			expectRequests: []testRoundTrip{
				{
					expectBearerToken: goodToken.Encode(),
					returnHTTPStatus:  http.StatusUnauthorized,
				},
				{
					expectBearerToken: goodToken2.Encode(),
					returnHTTPStatus:  http.StatusOK,
				},
			},

			expectPersists: []testPersist{
				{
					cfg: map[string]string{
						cfgIDToken:      goodToken2.Encode(),
						cfgRefreshToken: "rt1",
					},
				},
			},
			wantStatus: 200,
		},
	}

	for i, tt := range tests {
		client := &testOIDCClient{
			refreshes: tt.expectRefreshes,
		}

		persister := &testPersister{
			tt.expectPersists,
		}

		// Only configured entries are written to the config map.
		cfg := map[string]string{}
		if tt.cfgIDToken != nil {
			cfg[cfgIDToken] = tt.cfgIDToken.Encode()
		}

		if tt.cfgRefreshToken != nil {
			cfg[cfgRefreshToken] = *tt.cfgRefreshToken
		}

		ap := &oidcAuthProvider{
			refresher: &idTokenRefresher{
				client:    client,
				cfg:       cfg,
				persister: persister,
			},
		}

		if tt.cfgIDToken != nil {
			ap.initialIDToken = *tt.cfgIDToken
		}

		tstRT := &testRoundTripper{
			tt.expectRequests,
		}

		rt := ap.WrapTransport(tstRT)

		req, err := http.NewRequest("GET", "http://cluster.example.com", nil)
		if err != nil {
			// Without a request there is nothing meaningful to round-trip, so
			// give up on this case instead of passing a nil request along.
			t.Errorf("case %d: unexpected error making request: %v", i, err)
			continue
		}

		res, err := rt.RoundTrip(req)
		if tt.wantErr {
			if err == nil {
				t.Errorf("case %d: Expected non-nil error", i)
			}
		} else if err != nil {
			t.Errorf("case %d: unexpected error making round trip: %v", i, err)

		} else {
			if res.StatusCode != tt.wantStatus {
				t.Errorf("case %d: want=%d, got=%d", i, tt.wantStatus, res.StatusCode)
			}
		}

		// Ask each fake to verify itself and report whatever it returns.
		if err = client.verify(); err != nil {
			t.Errorf("case %d: %v", i, err)
		}

		if err = persister.verify(); err != nil {
			t.Errorf("case %d: %v", i, err)
		}

		if err = tstRT.verify(); err != nil {
			t.Errorf("case %d: %v", i, err)
		}
	}
}
예제 #21
0
// TestNewOIDCAuthProvider checks newOIDCAuthProvider against a table of
// configurations: complete ones must yield an *oidcAuthProvider whose
// initialIDToken matches the configured ID token, and incomplete ones (missing
// client id, client secret, issuer URL or TLS settings) must be rejected.
// The issuer is a test OIDC provider served over TLS with a self-signed
// certificate generated into a temporary directory.
func TestNewOIDCAuthProvider(t *testing.T) {
	tempDir, err := ioutil.TempDir(os.TempDir(), "oidc_test")
	if err != nil {
		t.Fatalf("Cannot make temp dir %v", err)
	}
	// Remove the directory together with everything written into it. Separate
	// deferred os.Remove calls run last-in-first-out, so removing the
	// directory would be attempted while it still holds the cert and key
	// files, fail, and leave the directory behind.
	defer os.RemoveAll(tempDir)

	certFile := path.Join(tempDir, "oidc-cert")
	keyFile := path.Join(tempDir, "oidc-key")

	oidctesting.GenerateSelfSignedCert(t, "127.0.0.1", certFile, keyFile)
	op := oidctesting.NewOIDCProvider(t)
	srv, err := op.ServeTLSWithKeyPair(certFile, keyFile)
	if err != nil {
		t.Fatalf("Cannot start server %v", err)
	}
	defer srv.Close()
	op.AddMinimalProviderConfig(srv)

	// Raw certificate bytes, used by the case that passes the CA inline
	// (base64-encoded) instead of as a file path.
	certData, err := ioutil.ReadFile(certFile)
	if err != nil {
		t.Fatalf("Could not read cert bytes %v", err)
	}

	jwt, err := jose.NewSignedJWT(jose.Claims(map[string]interface{}{
		"test": "jwt",
	}), op.PrivKey.Signer())
	if err != nil {
		t.Fatalf("Could not create signed JWT %v", err)
	}

	tests := []struct {
		// Provider configuration handed to newOIDCAuthProvider.
		cfg map[string]string

		wantErr bool
		// Expected initialIDToken of the resulting provider; the zero value
		// when no ID token is configured.
		wantInitialIDToken jose.JWT
	}{
		{
			// A Valid configuration
			cfg: map[string]string{
				cfgIssuerUrl:            srv.URL,
				cfgCertificateAuthority: certFile,
				cfgClientID:             "client-id",
				cfgClientSecret:         "client-secret",
			},
		},
		{
			// A Valid configuration with an Initial JWT
			cfg: map[string]string{
				cfgIssuerUrl:            srv.URL,
				cfgCertificateAuthority: certFile,
				cfgClientID:             "client-id",
				cfgClientSecret:         "client-secret",
				cfgIDToken:              jwt.Encode(),
			},
			wantInitialIDToken: *jwt,
		},
		{
			// Valid config, but using cfgCertificateAuthorityData
			cfg: map[string]string{
				cfgIssuerUrl:                srv.URL,
				cfgCertificateAuthorityData: base64.StdEncoding.EncodeToString(certData),
				cfgClientID:                 "client-id",
				cfgClientSecret:             "client-secret",
			},
		},
		{
			// Missing client id
			cfg: map[string]string{
				cfgIssuerUrl:            srv.URL,
				cfgCertificateAuthority: certFile,
				cfgClientSecret:         "client-secret",
			},
			wantErr: true,
		},
		{
			// Missing client secret
			cfg: map[string]string{
				cfgIssuerUrl:            srv.URL,
				cfgCertificateAuthority: certFile,
				cfgClientID:             "client-id",
			},
			wantErr: true,
		},
		{
			// Missing issuer url.
			cfg: map[string]string{
				cfgCertificateAuthority: certFile,
				cfgClientID:             "client-id",
				cfgClientSecret:         "secret",
			},
			wantErr: true,
		},
		{
			// No TLS config
			cfg: map[string]string{
				cfgIssuerUrl:    srv.URL,
				cfgClientID:     "client-id",
				cfgClientSecret: "secret",
			},
			wantErr: true,
		},
	}

	for i, tt := range tests {
		ap, err := newOIDCAuthProvider("cluster.example.com", tt.cfg, nil)
		if tt.wantErr {
			if err == nil {
				t.Errorf("case %d: want non-nil err", i)
			}
			continue
		}

		if err != nil {
			t.Errorf("case %d: unexpected error on newOIDCAuthProvider: %v", i, err)
			continue
		}

		oidcAP, ok := ap.(*oidcAuthProvider)
		if !ok {
			t.Errorf("case %d: expected ap to be an oidcAuthProvider", i)
			continue
		}

		if diff := compareJWTs(tt.wantInitialIDToken, oidcAP.initialIDToken); diff != "" {
			t.Errorf("case %d: compareJWTs(tt.wantInitialIDToken, oidcAP.initialIDToken)=%v", i, diff)
		}
	}
}
예제 #22
0
// TestEmailVerificationParseAndVerify signs the claims of several
// EmailVerification values and feeds the encoded tokens to
// ParseAndVerifyEmailVerificationToken. Only the well-formed verification
// signed with the trusted key is expected to parse; for it the returned claims
// must equal the ones that were signed.
func TestEmailVerificationParseAndVerify(t *testing.T) {
	// url.Parse errors are ignored here: the inputs are fixed literals.
	issuer, _ := url.Parse("http://example.com")
	otherIssuer, _ := url.Parse("http://bad.example.com")
	client := "myclient"
	user := User{ID: "1234", Email: "*****@*****.**"}
	callback, _ := url.Parse("http://client.example.com")
	expires := time.Hour * 3

	goodEV := NewEmailVerification(user, client, *issuer, *callback, expires)
	expiredEV := NewEmailVerification(user, client, *issuer, *callback, -expires)
	wrongIssuerEV := NewEmailVerification(user, client, *otherIssuer, *callback, expires)
	noSubEV := NewEmailVerification(User{}, client, *issuer, *callback, expires)

	privKey, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("Failed to generate private key, error=%v", err)
	}
	signer := privKey.Signer()

	// A second key whose public half is never handed to the verifier.
	privKey2, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("Failed to generate private key, error=%v", err)
	}
	otherSigner := privKey2.Signer()

	// The verification keys are the same for every case, so build them once.
	verifyKeys := []key.PublicKey{*key.NewPublicKey(privKey.JWK())}

	tests := []struct {
		ev      EmailVerification
		wantErr bool
		signer  jose.Signer
	}{

		{
			ev:      goodEV,
			signer:  signer,
			wantErr: false,
		},
		{
			ev:      expiredEV,
			signer:  signer,
			wantErr: true,
		},
		{
			ev:      wrongIssuerEV,
			signer:  signer,
			wantErr: true,
		},
		{
			ev:      goodEV,
			signer:  otherSigner,
			wantErr: true,
		},
		{
			ev:      noSubEV,
			signer:  signer,
			wantErr: true,
		},
	}

	for i, tt := range tests {

		jwt, err := jose.NewSignedJWT(tt.ev.Claims, tt.signer)
		if err != nil {
			t.Fatalf("Failed to generate JWT, error=%v", err)
		}
		token := jwt.Encode()

		ev, err := ParseAndVerifyEmailVerificationToken(token, *issuer, verifyKeys)

		if tt.wantErr {
			t.Logf("err: %v", err)
			if err == nil {
				t.Errorf("case %d: want non-nil err, got nil", i)
			}
			continue
		}

		if err != nil {
			t.Errorf("case %d: non-nil err: %q", i, err)
			// The result of a failed parse is not worth comparing; doing so
			// would only add a second, misleading failure for this case.
			continue
		}

		if diff := pretty.Compare(tt.ev.Claims, ev.Claims); diff != "" {
			t.Errorf("case %d: Compare(want, got): %v", i, diff)
		}
	}
}
예제 #23
0
// TestJWTVerifier runs JWTVerifier.Verify over a table of verifier/token
// pairs. The cases vary which public keys the verifier's keysFunc returns on
// each call and which private key, audience and validity window the token was
// created with.
func TestJWTVerifier(t *testing.T) {
	const iss = "http://example.com"

	now := time.Now()
	var (
		future12 = now.Add(12 * time.Hour)
		past36   = now.Add(-36 * time.Hour)
		past12   = now.Add(-12 * time.Hour)
	)

	// newKeyPair generates a private key and returns it with its public half.
	newKeyPair := func() (*key.PrivateKey, key.PublicKey) {
		priv, err := key.GeneratePrivateKey()
		if err != nil {
			t.Fatalf("failed to generate private key, error=%v", err)
		}
		return priv, *key.NewPublicKey(priv.JWK())
	}
	priv1, pk1 := newKeyPair()
	priv2, pk2 := newKeyPair()

	// sign builds claims for subject "XXX" with the given audience and
	// validity window and signs them with priv.
	sign := func(aud string, iat, exp time.Time, priv *key.PrivateKey) jose.JWT {
		signed, err := jose.NewSignedJWT(NewClaims(iss, "XXX", aud, iat, exp), priv.Signer())
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		return *signed
	}
	jwtPK1 := sign("XXX", past12, future12, priv1)
	jwtPK1BadClaims := sign("YYY", past12, future12, priv1)
	jwtExpired := sign("XXX", past36, past12, priv1)
	jwtPK2 := sign("XXX", past12, future12, priv2)

	// fixedKeys yields a keysFunc that returns the same keys on every call.
	fixedKeys := func(keys []key.PublicKey) func() []key.PublicKey {
		return func() []key.PublicKey { return keys }
	}
	// keySequence yields a keysFunc that returns the next entry of seq on each
	// successive call.
	keySequence := func(seq ...[]key.PublicKey) func() []key.PublicKey {
		var calls int
		return func() []key.PublicKey {
			keys := seq[calls]
			calls++
			return keys
		}
	}
	// verifierWith builds the verifier shared by all cases; only the key
	// source differs between them.
	verifierWith := func(keysFunc func() []key.PublicKey) JWTVerifier {
		return JWTVerifier{
			issuer:   "example.com",
			clientID: "XXX",
			syncFunc: func() error { return nil },
			keysFunc: keysFunc,
		}
	}

	cases := []struct {
		verifier JWTVerifier
		jwt      jose.JWT
		wantErr  bool
	}{
		// JWT signed with available key
		{
			verifier: verifierWith(fixedKeys([]key.PublicKey{pk1})),
			jwt:      jwtPK1,
			wantErr:  false,
		},
		// JWT signed with available key, with bad claims
		{
			verifier: verifierWith(fixedKeys([]key.PublicKey{pk1})),
			jwt:      jwtPK1BadClaims,
			wantErr:  true,
		},
		// expired JWT signed with available key
		{
			verifier: verifierWith(fixedKeys([]key.PublicKey{pk1})),
			jwt:      jwtExpired,
			wantErr:  true,
		},
		// JWT signed with unrecognized key, verifiable after sync
		{
			verifier: verifierWith(keySequence(
				[]key.PublicKey{pk1},
				[]key.PublicKey{pk2},
			)),
			jwt:     jwtPK2,
			wantErr: false,
		},
		// JWT signed with unrecognized key, not verifiable after sync
		{
			verifier: verifierWith(fixedKeys([]key.PublicKey{pk1})),
			jwt:      jwtPK2,
			wantErr:  true,
		},
		// verifier gets no keys from keysFunc, still not verifiable after sync
		{
			verifier: verifierWith(fixedKeys([]key.PublicKey{})),
			jwt:      jwtPK1,
			wantErr:  true,
		},
		// verifier gets no keys from keysFunc, verifiable after sync
		{
			verifier: verifierWith(keySequence(
				[]key.PublicKey{},
				[]key.PublicKey{pk2},
			)),
			jwt:     jwtPK2,
			wantErr: false,
		},
	}

	for idx, tc := range cases {
		err := tc.verifier.Verify(tc.jwt)
		switch {
		case tc.wantErr && err == nil:
			t.Errorf("case %d: wanted non-nil error", idx)
		case !tc.wantErr && err != nil:
			t.Errorf("case %d: wanted nil error, got %v", idx, err)
		}
	}
}
예제 #24
0
// TestInvitationParseAndVerify signs the claims of several Invitation values
// and feeds the encoded tokens to ParseAndVerifyInvitationToken. Invitations
// that are expired, carry a different issuer, or lack a user ID, email,
// password or client are expected to be rejected; accepted ones must round-trip
// to an Invitation equal to the original.
func TestInvitationParseAndVerify(t *testing.T) {
	// url.Parse errors are ignored here: the inputs are fixed literals.
	issuer, _ := url.Parse("http://example.com")
	notIssuer, _ := url.Parse("http://other.com")
	client := "myclient"
	user := User{ID: "1234", Email: "*****@*****.**"}
	callback, _ := url.Parse("http://client.example.com")
	expires := time.Hour * 3
	password := Password("Halloween is the best holiday")

	// Fail loudly if no key could be generated instead of dereferencing an
	// unusable key below.
	privKey, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("failed to generate private key, error: %v", err)
	}
	signer := privKey.Signer()
	publicKeys := []key.PublicKey{*key.NewPublicKey(privKey.JWK())}

	tests := []struct {
		invite  Invitation
		wantErr bool
		signer  jose.Signer
	}{
		{
			invite:  NewInvitation(user, password, *issuer, client, *callback, expires),
			signer:  signer,
			wantErr: false,
		},
		{
			invite:  NewInvitation(user, password, *issuer, client, *callback, expires),
			signer:  signer,
			wantErr: false,
		},
		{
			// Negative validity window: already expired.
			invite:  NewInvitation(user, password, *issuer, client, *callback, -expires),
			signer:  signer,
			wantErr: true,
		},
		{
			// Issuer differs from the one passed to the verifier.
			invite:  NewInvitation(user, password, *notIssuer, client, *callback, expires),
			signer:  signer,
			wantErr: true,
		},
		{
			// User without an ID.
			invite:  NewInvitation(User{Email: "*****@*****.**"}, password, *issuer, client, *callback, expires),
			signer:  signer,
			wantErr: true,
		},
		{
			// User without an email.
			invite:  NewInvitation(User{ID: "JONNY_NO_EMAIL"}, password, *issuer, client, *callback, expires),
			signer:  signer,
			wantErr: true,
		},
		{
			// Empty password.
			invite:  NewInvitation(user, Password(""), *issuer, client, *callback, expires),
			signer:  signer,
			wantErr: true,
		},
		{
			// Empty client.
			invite:  NewInvitation(user, password, *issuer, "", *callback, expires),
			signer:  signer,
			wantErr: true,
		},
		{
			// Empty client and empty callback URL.
			invite:  NewInvitation(user, password, *issuer, "", url.URL{}, expires),
			signer:  signer,
			wantErr: true,
		},
	}

	for i, tt := range tests {
		jwt, err := jose.NewSignedJWT(tt.invite.Claims, tt.signer)
		if err != nil {
			t.Fatalf("case %d: failed to generate JWT, error: %v", i, err)
		}
		token := jwt.Encode()

		parsed, err := ParseAndVerifyInvitationToken(token, *issuer, publicKeys)

		if tt.wantErr {
			if err == nil {
				t.Errorf("case %d: want non-nil error, got nil", i)
			}
			continue
		}

		if err != nil {
			t.Errorf("case %d: unexpected error: %v", i, err)
			continue
		}

		if diff := pretty.Compare(tt.invite, parsed); diff != "" {
			t.Errorf("case %d: Compare(want, got): %v", i, diff)
		}
	}
}
예제 #25
0
파일: server.go 프로젝트: Tecsisa/dex
// RefreshToken exchanges a refresh token for a newly signed ID token and a
// renewed refresh token.
//
// The client credentials are authenticated first, then the refresh token is
// verified for that client. If scopes is empty the scopes stored with the
// refresh token are used; otherwise the requested scopes must be contained in
// them. When the refresh token carries the groups scope, the user's groups are
// looked up through the connector recorded with the token and added to the
// claims as "groups" (an empty list if the lookup returned none).
//
// On success it returns the signed JWT, the renewed refresh token and the time
// at which the ID token expires. Every failure is logged where useful and
// returned as an oauth2 error: invalid client, invalid request, or server
// error.
func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, string, time.Time, error) {
	ok, err := s.ClientManager.Authenticate(creds)
	if err != nil {
		log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
		return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
	}
	if !ok {
		log.Errorf("Failed to Authenticate client %s", creds.ID)
		return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidClient)
	}

	userID, connectorID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token)
	switch err {
	case nil:
		// Token verified; carry on.
	case refresh.ErrorInvalidToken:
		return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidRequest)
	case refresh.ErrorInvalidClientID:
		return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidClient)
	default:
		return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
	}

	// A request may narrow the scopes of the refresh token but never widen
	// them.
	if len(scopes) == 0 {
		scopes = rtScopes
	} else if !rtScopes.Contains(scopes) {
		return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidRequest)
	}

	usr, err := s.UserRepo.Get(nil, userID)
	if err != nil {
		// The error can be user.ErrorNotFound, but we are not deleting
		// user at this moment, so this shouldn't happen.
		log.Errorf("Failed to fetch user %q from repo: %v: ", userID, err)
		return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
	}

	var groups []string
	if rtScopes.HasScope(scope.ScopeGroups) {
		conn, ok := s.connector(connectorID)
		if !ok {
			log.Errorf("refresh token contained invalid connector ID (%s)", connectorID)
			return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
		}

		grouper, ok := conn.(connector.GroupsConnector)
		if !ok {
			log.Errorf("refresh token requested groups for connector (%s) that doesn't support groups", connectorID)
			return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
		}

		remoteIdentities, err := s.UserRepo.GetRemoteIdentities(nil, userID)
		if err != nil {
			log.Errorf("failed to get remote identities: %v", err)
			return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
		}
		// Pick the user's identity that belongs to the token's connector.
		remoteIdentity, ok := func() (user.RemoteIdentity, bool) {
			for _, ri := range remoteIdentities {
				if ri.ConnectorID == connectorID {
					return ri, true
				}
			}
			return user.RemoteIdentity{}, false
		}()
		if !ok {
			log.Errorf("failed to get remote identity for connector %s", connectorID)
			return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
		}
		if groups, err = grouper.Groups(remoteIdentity.ID); err != nil {
			// Log the underlying error as well as the connector it came from.
			log.Errorf("failed to get groups for refresh token (connector %s): %v", connectorID, err)
			return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
		}
	}

	signer, err := s.KeyManager.Signer()
	if err != nil {
		log.Errorf("Failed to refresh ID token: %v", err)
		return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
	}

	now := time.Now()
	expiresAt := now.Add(session.DefaultSessionValidityWindow)

	claims := oidc.NewClaims(s.IssuerURL.String(), usr.ID, creds.ID, now, expiresAt)
	usr.AddToClaims(claims)
	if rtScopes.HasScope(scope.ScopeGroups) {
		// Store a non-nil slice so the claim is always a list.
		if groups == nil {
			groups = []string{}
		}
		claims["groups"] = groups
	}

	s.addClaimsFromScope(claims, scope.Scopes(scopes), creds.ID)

	jwt, err := jose.NewSignedJWT(claims, signer)
	if err != nil {
		log.Errorf("Failed to generate ID token: %v", err)
		return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
	}

	refreshToken, err := s.RefreshTokenRepo.RenewRefreshToken(creds.ID, userID, token)
	if err != nil {
		log.Errorf("Failed to generate new refresh token: %v", err)
		return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
	}

	log.Infof("New token sent: clientID=%s", creds.ID)

	return jwt, refreshToken, expiresAt, nil
}
예제 #26
0
파일: password_test.go 프로젝트: ryanj/dex
// TestPasswordResetParseAndVerify signs the claims of several PasswordReset
// values and feeds the encoded tokens to ParseAndVerifyPasswordResetToken.
// Resets that are expired, carry a different issuer, lack a subject, password
// or client, or are signed with an untrusted key are expected to be rejected;
// accepted ones must return claims equal to the ones that were signed.
func TestPasswordResetParseAndVerify(t *testing.T) {
	// url.Parse errors are ignored here: the inputs are fixed literals.
	issuer, _ := url.Parse("http://example.com")
	otherIssuer, _ := url.Parse("http://bad.example.com")
	client := "myclient"
	user := User{ID: "1234", Email: "*****@*****.**"}
	callback, _ := url.Parse("http://client.example.com")
	expires := time.Hour * 3
	password := Password("passy")
	userID := user.ID

	goodPR := NewPasswordReset(userID, password, *issuer, client, *callback, expires)
	goodPRNoCB := NewPasswordReset(userID, password, *issuer, client, url.URL{}, expires)
	expiredPR := NewPasswordReset(userID, password, *issuer, client, *callback, -expires)
	wrongIssuerPR := NewPasswordReset(userID, password, *otherIssuer, client, *callback, expires)
	noSubPR := NewPasswordReset("", password, *issuer, client, *callback, expires)
	noPWPR := NewPasswordReset(userID, Password(""), *issuer, client, *callback, expires)
	noClientPR := NewPasswordReset(userID, password, *issuer, "", *callback, expires)
	noClientNoCBPR := NewPasswordReset(userID, password, *issuer, "", url.URL{}, expires)

	privKey, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("Failed to generate private key, error=%v", err)
	}
	signer := privKey.Signer()

	// A second key whose public half is never handed to the verifier.
	privKey2, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("Failed to generate private key, error=%v", err)
	}
	otherSigner := privKey2.Signer()

	// The verification keys are the same for every case, so build them once.
	verifyKeys := []key.PublicKey{*key.NewPublicKey(privKey.JWK())}

	tests := []struct {
		pr      PasswordReset
		wantErr bool
		signer  jose.Signer
	}{

		{
			pr:      goodPR,
			signer:  signer,
			wantErr: false,
		},
		{
			pr:      goodPRNoCB,
			signer:  signer,
			wantErr: false,
		},

		{
			pr:      expiredPR,
			signer:  signer,
			wantErr: true,
		},
		{
			pr:      wrongIssuerPR,
			signer:  signer,
			wantErr: true,
		},
		{
			pr:      goodPR,
			signer:  otherSigner,
			wantErr: true,
		},
		{
			pr:      noSubPR,
			signer:  signer,
			wantErr: true,
		},
		{
			pr:      noPWPR,
			signer:  signer,
			wantErr: true,
		},
		{
			pr:      noClientPR,
			signer:  signer,
			wantErr: true,
		},
		{
			pr:      noClientNoCBPR,
			signer:  signer,
			wantErr: true,
		},
	}

	for i, tt := range tests {

		jwt, err := jose.NewSignedJWT(tt.pr.Claims, tt.signer)
		if err != nil {
			t.Fatalf("Failed to generate JWT, error=%v", err)
		}
		token := jwt.Encode()

		got, err := ParseAndVerifyPasswordResetToken(token, *issuer, verifyKeys)

		if tt.wantErr {
			t.Logf("err: %v", err)
			if err == nil {
				t.Errorf("case %d: want non-nil err, got nil", i)
			}
			continue
		}

		if err != nil {
			t.Errorf("case %d: non-nil err: %q", i, err)
			// The result of a failed parse is not worth comparing; doing so
			// would only add a second, misleading failure for this case.
			continue
		}

		if diff := pretty.Compare(tt.pr.Claims, got.Claims); diff != "" {
			t.Errorf("case %d: Compare(want, got): %v", i, diff)
		}
	}
}
예제 #27
0
// TestClientToken sends requests with various Authorization headers through
// clientTokenMiddleware and checks the recorded response code. Each case picks
// the verification keys and the client identity repo the middleware is built
// with.
func TestClientToken(t *testing.T) {
	now := time.Now()
	tomorrow := now.Add(24 * time.Hour)

	// The single client registered in the populated repo.
	validClientID := "valid-client"
	registered := oidc.ClientIdentity{
		Credentials: oidc.ClientCredentials{
			ID:     validClientID,
			Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
		},
		Metadata: oidc.ClientMetadata{
			RedirectURIs: []url.URL{
				{Scheme: "https", Host: "authn.example.com", Path: "/callback"},
			},
		},
	}
	repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{registered})
	if err != nil {
		t.Fatalf("Failed to create client identity repo: %v", err)
	}

	privKey, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("Failed to generate private key, error=%v", err)
	}
	signer := privKey.Signer()
	pubKey := *key.NewPublicKey(privKey.JWK())
	trusted := []key.PublicKey{pubKey}

	validIss := "https://example.com"

	// makeToken signs claims built from its arguments and returns the encoded
	// token.
	makeToken := func(iss, sub, aud string, iat, exp time.Time) string {
		signed, err := jose.NewSignedJWT(oidc.NewClaims(iss, sub, aud, iat, exp), signer)
		if err != nil {
			t.Fatalf("Failed to generate JWT, error=%v", err)
		}
		return signed.Encode()
	}
	// bearer formats an Authorization header value for the given token.
	bearer := func(token string) string {
		return fmt.Sprintf("BEARER %s", token)
	}

	validJWT := makeToken(validIss, validClientID, validClientID, now, tomorrow)
	invalidJWT := makeToken("", "", "", now, tomorrow)

	cases := []struct {
		keys     []key.PublicKey
		repo     client.ClientIdentityRepo
		header   string
		wantCode int
	}{
		// valid token
		{keys: trusted, repo: repo, header: bearer(validJWT), wantCode: http.StatusOK},
		// invalid token
		{keys: trusted, repo: repo, header: bearer(invalidJWT), wantCode: http.StatusUnauthorized},
		// empty header
		{keys: trusted, repo: repo, header: "", wantCode: http.StatusUnauthorized},
		// unparsable token
		{keys: trusted, repo: repo, header: "BEARER xxx", wantCode: http.StatusUnauthorized},
		// no verification keys
		{keys: []key.PublicKey{}, repo: repo, header: bearer(validJWT), wantCode: http.StatusUnauthorized},
		// nil repo
		{keys: trusted, repo: nil, header: bearer(validJWT), wantCode: http.StatusUnauthorized},
		// empty repo
		{
			keys:     trusted,
			repo:     db.NewClientIdentityRepo(db.NewMemDB()),
			header:   bearer(validJWT),
			wantCode: http.StatusUnauthorized,
		},
		// client not in repo
		{
			keys:     trusted,
			repo:     repo,
			header:   bearer(makeToken(validIss, "DOESNT-EXIST", "DOESNT-EXIST", now, tomorrow)),
			wantCode: http.StatusUnauthorized,
		},
	}

	for idx, tc := range cases {
		keys := tc.keys
		handler := &clientTokenMiddleware{
			issuerURL: validIss,
			ciRepo:    tc.repo,
			keysFunc: func() ([]key.PublicKey, error) {
				return keys, nil
			},
			next: staticHandler{},
		}

		hdr := http.Header{}
		hdr.Set("Authorization", tc.header)
		req := &http.Request{Header: hdr}

		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, req)

		if rec.Code != tc.wantCode {
			t.Errorf("case %d: invalid response code, want=%d, got=%d", idx, tc.wantCode, rec.Code)
		}
	}
}
예제 #28
0
파일: invitation.go 프로젝트: Tecsisa/dex
// handleGET processes an invitation link. It reads the "token" query
// parameter, verifies it as an invitation token against the handler's public
// keys and issuer URL, marks the invited email as verified, and then redirects
// (303 See Other) to the password reset URL with a freshly signed password
// reset token in its "token" query parameter.
//
// An unverifiable token or an unknown user produces a 400 API error; any other
// failure produces a 500 API error.
func (h *InvitationHandler) handleGET(w http.ResponseWriter, r *http.Request) {
	// serverError writes the generic 500 response shared by all internal
	// failures below.
	serverError := func() {
		writeAPIError(w, http.StatusInternalServerError, newAPIError(errorServerError,
			"There's been an error processing your request."))
	}

	token := r.URL.Query().Get("token")

	keys, err := h.keysFunc()
	if err != nil {
		log.Errorf("internal error getting public keys: %v", err)
		serverError()
		return
	}

	invite, err := user.ParseAndVerifyInvitationToken(token, h.issuerURL, keys)
	if err != nil {
		log.Debugf("invalid invitation token: %v (%v)", err, token)
		writeAPIError(w, http.StatusBadRequest, newAPIError(errorInvalidRequest,
			"Your invitation could not be verified"))
		return
	}

	// An already-verified email is let through on purpose: otherwise someone
	// who hit an error after this step could never get to set a password.
	if _, err = h.um.VerifyEmail(invite); err != nil && err != manager.ErrorEmailAlreadyVerified {
		log.Debugf("error attempting to verify email: %v", err)
		if err == user.ErrorNotFound {
			writeAPIError(w, http.StatusBadRequest, newAPIError(errorInvalidRequest,
				"Your email does not match the email address on file"))
			return
		}
		log.Errorf("internal error verifying email: %v", err)
		serverError()
		return
	}

	passwordReset := invite.PasswordReset(h.issuerURL, h.redirectValidityWindow)

	signer, err := h.signerFunc()
	if err != nil || signer == nil {
		log.Errorf("error getting signer: %v (signer: %v)", err, signer)
		serverError()
		return
	}

	jwt, err := jose.NewSignedJWT(passwordReset.Claims, signer)
	if err != nil {
		log.Errorf("error constructing or signing PasswordReset from Invitation JWT: %v", err)
		serverError()
		return
	}

	// Attach the signed token to the reset URL's query and send the user there.
	redirectURL := h.passwordResetURL
	query := redirectURL.Query()
	query.Set("token", jwt.Encode())
	redirectURL.RawQuery = query.Encode()
	http.Redirect(w, r, redirectURL.String(), http.StatusSeeOther)
}
예제 #29
0
// TestResetPasswordHandler drives ResetPasswordHandler through a table of GET
// and POST requests. For every case it checks the HTTP status code, the values
// of the rendered "#resetPasswordForm" form (when the case specifies them) and
// the password stored for user "ID-1" once the request has been served.
func TestResetPasswordHandler(t *testing.T) {
	// makeToken builds a password-reset token for userID, signs it with signer
	// and returns its encoded form. Any failure aborts the test.
	makeToken := func(userID, password, clientID string, callback url.URL, expires time.Duration, signer jose.Signer) string {
		pr := user.NewPasswordReset(userID,
			user.Password(password),
			testIssuerURL,
			clientID,
			callback,
			expires)

		jwt, err := jose.NewSignedJWT(pr.Claims, signer)
		if err != nil {
			t.Fatalf("couldn't make token: %q", err)
		}
		return jwt.Encode()
	}

	// goodSigner signs with testPrivKey; badSigner signs with a freshly
	// generated key and is used by the cases that expect token rejection.
	goodSigner := key.NewPrivateKeySet([]*key.PrivateKey{testPrivKey},
		time.Now().Add(time.Minute)).Active().Signer()

	badKey, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("couldn't make new key: %q", err)
	}
	badSigner := key.NewPrivateKeySet([]*key.PrivateKey{badKey},
		time.Now().Add(time.Minute)).Active().Signer()

	// str wraps a single string in a slice, the value shape url.Values uses.
	str := func(s string) []string {
		return []string{s}
	}

	// Swap in a hasher that merely upper-cases its input for the duration of
	// the test, and restore the default one when the test returns.
	user.PasswordHasher = func(s string) ([]byte, error) {
		return []byte(strings.ToUpper(s)), nil
	}
	defer func() {
		user.PasswordHasher = user.DefaultPasswordHasher
	}()

	// Tokens that must appear both in a case's request and in its expected
	// form values are created once here so both places see the same string.
	tokenForCase := map[int]string{
		0: makeToken("ID-1", "password", testClientID, testRedirectURL, time.Hour*1, goodSigner),
		2: makeToken("ID-1", "password", testClientID, url.URL{}, time.Hour*1, goodSigner),
		5: makeToken("ID-1", "password", testClientID, url.URL{}, time.Hour*1, goodSigner),
	}

	tests := []struct {
		// query holds the request parameters: sent as the URL query string
		// for GET and as the form-encoded body for POST.
		query url.Values

		method string

		// wantFormValues, when non-nil, is compared against the values parsed
		// from the response's reset-password form.
		wantFormValues *url.Values
		wantCode       int
		// wantPassword is the password expected to be stored for "ID-1"
		// after the request.
		wantPassword string
	}{
		// Scenario 1: Happy Path
		{ // Case 0
			// Step 1.1 - User clicks link in email, has valid token.
			query: url.Values{
				"token": str(tokenForCase[0]),
			},
			method: "GET",

			wantCode: http.StatusOK,
			wantFormValues: &url.Values{
				"password": str(""),
				"token":    str(tokenForCase[0]),
			},
			wantPassword: "******",
		},
		{ // Case 1
			// Step 1.2 - User enters in new valid password, password is changed, user is redirected.
			query: url.Values{
				"token":    str(makeToken("ID-1", "password", testClientID, testRedirectURL, time.Hour*1, goodSigner)),
				"password": str("new_password"),
			},
			method: "POST",

			wantCode:       http.StatusSeeOther,
			wantFormValues: &url.Values{},
			wantPassword:   "******",
		},
		// Scenario 2: Happy Path, but without redirect.
		{ // Case 2
			// Step 2.1 - User clicks link in email, has valid token.
			query: url.Values{
				"token": str(tokenForCase[2]),
			},
			method: "GET",

			wantCode: http.StatusOK,
			wantFormValues: &url.Values{
				"password": str(""),
				"token":    str(tokenForCase[2]),
			},
			wantPassword: "******",
		},
		{ // Case 3
			// Step 2.2 - User enters in new valid password, password is changed, no redirect.
			query: url.Values{
				"token":    str(makeToken("ID-1", "password", testClientID, url.URL{}, time.Hour*1, goodSigner)),
				"password": str("new_password"),
			},
			method: "POST",

			// no redirect
			wantCode:       http.StatusOK,
			wantFormValues: &url.Values{},
			wantPassword:   "******",
		},
		// Errors
		{ // Case 4
			// Step 1.1.1 - User clicks link in email, has invalid token.
			query: url.Values{
				"token": str(makeToken("ID-1", "password", testClientID, testRedirectURL, time.Hour*1, badSigner)),
			},
			method: "GET",

			wantCode:       http.StatusBadRequest,
			wantFormValues: &url.Values{},
			wantPassword:   "******",
		},

		{ // Case 5
			// Step 2.2.1 - User POSTs the password "shrt" with a valid token;
			// the request is rejected (presumably because the password is too
			// short) and the form is rendered again with the same token.
			query: url.Values{
				"token":    str(tokenForCase[5]),
				"password": str("shrt"),
			},
			method: "POST",

			// no redirect
			wantCode: http.StatusBadRequest,
			wantFormValues: &url.Values{
				"password": str(""),
				"token":    str(tokenForCase[5]),
			},
			wantPassword: "******",
		},
		{ // Case 6
			// Step 2.2.2 - User POSTs a password with a suspicious token
			// (signed with badSigner); the request is rejected.
			query: url.Values{
				"token":    str(makeToken("ID-1", "password", testClientID, url.URL{}, time.Hour*1, badSigner)),
				"password": str("shrt"),
			},
			method: "POST",

			// no redirect
			wantCode:       http.StatusBadRequest,
			wantFormValues: &url.Values{},
			wantPassword:   "******",
		},
		{ // Case 7
			// Token lacking client id (GET).
			query: url.Values{
				"token":    str(makeToken("ID-1", "password", "", url.URL{}, time.Hour*1, goodSigner)),
				"password": str("shrt"),
			},
			method: "GET",

			wantCode:     http.StatusBadRequest,
			wantPassword: "******",
		},
		{ // Case 8
			// Token lacking client id (POST).
			query: url.Values{
				"token":    str(makeToken("ID-1", "password", "", url.URL{}, time.Hour*1, goodSigner)),
				"password": str("shrt"),
			},
			method: "POST",

			wantCode:     http.StatusBadRequest,
			wantPassword: "******",
		},
	}
	for i, tt := range tests {
		// Fresh fixtures per case so one case's password change cannot leak
		// into the next.
		f, err := makeTestFixtures()
		if err != nil {
			t.Fatalf("case %d: could not make test fixtures: %v", i, err)
		}

		hdlr := ResetPasswordHandler{
			tpl:       f.srv.ResetPasswordTemplate,
			issuerURL: testIssuerURL,
			um:        f.srv.UserManager,
			keysFunc:  f.srv.KeyManager.PublicKeys,
		}

		w := httptest.NewRecorder()
		u := testIssuerURL // value copy; the package-level URL is not modified
		u.Path = httpPathResetPassword

		var req *http.Request
		if tt.method == "POST" {
			req, err = http.NewRequest(tt.method, u.String(), strings.NewReader(tt.query.Encode()))
		} else {
			u.RawQuery = tt.query.Encode()
			req, err = http.NewRequest(tt.method, u.String(), nil)
		}
		// Check the error before touching req: on failure req is nil, and
		// there is nothing meaningful left to serve.
		if err != nil {
			t.Fatalf("case %d: unable to form HTTP request: %v", i, err)
		}
		if tt.method == "POST" {
			req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
		}

		hdlr.ServeHTTP(w, req)

		if tt.wantCode != w.Code {
			t.Errorf("case %d: wantCode=%v, got=%v", i, tt.wantCode, w.Code)
			continue
		}

		values, err := html.FormValues("#resetPasswordForm", bytes.NewReader(w.Body.Bytes()))
		if err != nil {
			t.Errorf("case %d: could not parse form: %v", i, err)
		}

		if tt.wantFormValues != nil {
			if diff := pretty.Compare(*tt.wantFormValues, values); diff != "" {
				t.Errorf("case %d: Compare(wantFormValues, got) = %v", i, diff)
			}
		}

		pwi, err := f.srv.PasswordInfoRepo.Get(nil, "ID-1")
		if err != nil {
			// Without the stored record the password comparison below would
			// only add a second, misleading failure.
			t.Errorf("case %d: Error getting Password info: %v", i, err)
			continue
		}
		if tt.wantPassword != string(pwi.Password) {
			t.Errorf("case %d: wantPassword=%v, got=%v", i, tt.wantPassword, string(pwi.Password))
		}
	}
}
예제 #30
0
// TestInvitationHandler sends invitation tokens to InvitationHandler via GET
// and checks the response code and the user's EmailVerified flag. For cases
// that redirect (303 See Other) it also parses the password-reset token out of
// the Location header and verifies its claims, user ID, password and callback.
func TestInvitationHandler(t *testing.T) {
	invUserID := "ID-1"
	invVerifiedID := "ID-Verified"

	// invGoodSigner signs with testPrivKey; invBadSigner signs with a freshly
	// generated key and is used to produce a token the handler must reject.
	invGoodSigner := key.NewPrivateKeySet([]*key.PrivateKey{testPrivKey},
		time.Now().Add(time.Minute)).Active().Signer()

	badKey, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("couldn't make new key: %q", err)
	}

	invBadSigner := key.NewPrivateKeySet([]*key.PrivateKey{badKey},
		time.Now().Add(time.Minute)).Active().Signer()

	// makeInvitationToken builds an invitation for the given user and email,
	// signs it with signer and returns the encoded token. Any failure aborts
	// the test.
	makeInvitationToken := func(password, userID, clientID, email string, callback url.URL, expires time.Duration, signer jose.Signer) string {
		iv := user.NewInvitation(
			user.User{ID: userID, Email: email},
			user.Password(password),
			testIssuerURL,
			clientID,
			callback,
			expires)

		jwt, err := jose.NewSignedJWT(iv.Claims, signer)
		if err != nil {
			t.Fatalf("couldn't make token: %q", err)
		}
		return jwt.Encode()
	}

	tests := []struct {
		// userID is the user looked up after the request to check
		// EmailVerified, and the ID expected in the password-reset token.
		userID string
		// query is sent as the URL query string of the GET request.
		query url.Values
		// signer is what the handler's signerFunc returns for this case.
		signer   jose.Signer
		wantCode int
		// wantCallback is the callback expected in the password-reset token
		// of a redirecting case.
		wantCallback      url.URL
		wantEmailVerified bool
	}{
		{ // Case 0 Happy Path
			userID: invUserID,
			query: url.Values{
				"token": []string{makeInvitationToken("password", invUserID, testClientID, "*****@*****.**", testRedirectURL, time.Hour*1, invGoodSigner)},
			},
			signer:            invGoodSigner,
			wantCode:          http.StatusSeeOther,
			wantCallback:      testRedirectURL,
			wantEmailVerified: true,
		},
		{ // Case 1 user already verified
			userID: invVerifiedID,
			query: url.Values{
				"token": []string{makeInvitationToken("password", invVerifiedID, testClientID, "*****@*****.**", testRedirectURL, time.Hour*1, invGoodSigner)},
			},
			signer:            invGoodSigner,
			wantCode:          http.StatusSeeOther,
			wantCallback:      testRedirectURL,
			wantEmailVerified: true,
		},
		{ // Case 2 bad email
			userID: invUserID,
			query: url.Values{
				"token": []string{makeInvitationToken("password", invVerifiedID, testClientID, "*****@*****.**", testRedirectURL, time.Hour*1, invGoodSigner)},
			},
			signer:            invGoodSigner,
			wantCode:          http.StatusBadRequest,
			wantCallback:      testRedirectURL,
			wantEmailVerified: false,
		},
		{ // Case 3 bad signer
			userID: invUserID,
			query: url.Values{
				"token": []string{makeInvitationToken("password", invUserID, testClientID, "*****@*****.**", testRedirectURL, time.Hour*1, invBadSigner)},
			},
			signer:            invGoodSigner,
			wantCode:          http.StatusBadRequest,
			wantCallback:      testRedirectURL,
			wantEmailVerified: false,
		},
	}

	for i, tt := range tests {
		// Fresh fixtures per case so verification state does not carry over.
		f, err := makeTestFixtures()
		if err != nil {
			t.Fatalf("case %d: could not make test fixtures: %v", i, err)
		}

		keys, err := f.srv.KeyManager.PublicKeys()
		if err != nil {
			t.Fatalf("case %d: test fixture key infrastructure is broken: %v", i, err)
		}

		// tZero is the reference instant for the expiry-window check below.
		tZero := clock.Now()
		handler := &InvitationHandler{
			passwordResetURL: f.srv.absURL("RESETME"),
			issuerURL:        testIssuerURL,
			um:               f.srv.UserManager,
			keysFunc:         f.srv.KeyManager.PublicKeys,
			// The closure is only invoked by ServeHTTP within this iteration.
			signerFunc:             func() (jose.Signer, error) { return tt.signer, nil },
			redirectValidityWindow: 100 * time.Second,
		}

		w := httptest.NewRecorder()
		u := testIssuerURL // value copy; the package-level URL is not modified
		u.RawQuery = tt.query.Encode()
		req, err := http.NewRequest("GET", u.String(), nil)
		if err != nil {
			t.Fatalf("case %d: impossible error: %v", i, err)
		}

		handler.ServeHTTP(w, req)

		if tt.wantCode != w.Code {
			t.Errorf("case %d: wantCode=%v, got=%v", i, tt.wantCode, w.Code)
			continue
		}

		usr, err := f.srv.UserManager.Get(tt.userID)
		if err != nil {
			t.Fatalf("case %d: unexpected error: %v", i, err)
		}

		if usr.EmailVerified != tt.wantEmailVerified {
			t.Errorf("case %d: wantEmailVerified=%v got=%v", i, tt.wantEmailVerified, usr.EmailVerified)
		}

		// Only redirecting responses carry a password-reset token to inspect.
		if w.Code != http.StatusSeeOther {
			continue
		}

		locString := w.Header().Get("Location")
		loc, err := url.Parse(locString)
		if err != nil {
			t.Fatalf("case %d: redirect returned nonsense url: '%v', %v", i, locString, err)
		}

		pwrToken := loc.Query().Get("token")
		pwrReset, err := user.ParseAndVerifyPasswordResetToken(pwrToken, testIssuerURL, keys)
		if err != nil {
			// The parsed value is meaningless on error; inspecting it further
			// would only produce noise or a panic.
			t.Errorf("case %d: password token is invalid: %v", i, err)
			continue
		}

		// The expiry must fall between the moment the handler was set up and
		// the end of the configured validity window.
		expTime, ok := pwrReset.Claims["exp"].(float64)
		if !ok {
			t.Errorf("case %d: \"exp\" claim missing or not a number: %v", i, pwrReset.Claims["exp"])
		} else if expTime > float64(tZero.Add(handler.redirectValidityWindow).Unix()) ||
			expTime < float64(tZero.Unix()) {
			t.Errorf("case %d: funny expiration time detected: %v", i, pwrReset.Claims["exp"])
		}

		if pwrReset.Claims["aud"] != testClientID {
			t.Errorf("case %d: wanted \"aud\"=%v got=%v", i, testClientID, pwrReset.Claims["aud"])
		}

		if pwrReset.Claims["iss"] != testIssuerURL.String() {
			t.Errorf("case %d: wanted \"iss\"=%v got=%v", i, testIssuerURL, pwrReset.Claims["iss"])
		}

		if pwrReset.UserID() != tt.userID {
			t.Errorf("case %d: wanted UserID=%v got=%v", i, tt.userID, pwrReset.UserID())
		}

		if !bytes.Equal(pwrReset.Password(), user.Password("password")) {
			t.Errorf("case %d: wanted Password=%v got=%v", i, user.Password("password"), pwrReset.Password())
		}

		// Guard the dereference: a token without a callback must fail the
		// case, not crash the whole test binary.
		if cb := pwrReset.Callback(); cb == nil || *cb != tt.wantCallback {
			t.Errorf("case %d: wanted callback=%v got=%v", i, tt.wantCallback, cb)
		}
	}
}