func compareJWTs(a, b jose.JWT) string { if a.Encode() == b.Encode() { return "" } var aClaims, bClaims jose.Claims for _, j := range []struct { claims *jose.Claims jwt jose.JWT }{ {&aClaims, a}, {&bClaims, b}, } { var err error *j.claims, err = j.jwt.Claims() if err != nil { *j.claims = jose.Claims(map[string]interface{}{ "msg": "bad claims", "err": err, }) } } return diff.ObjectDiff(aClaims, bClaims) }
func TestNewOIDCAuthProvider(t *testing.T) { tempDir, err := ioutil.TempDir(os.TempDir(), "oidc_test") if err != nil { t.Fatalf("Cannot make temp dir %v", err) } cert := path.Join(tempDir, "oidc-cert") key := path.Join(tempDir, "oidc-key") defer os.RemoveAll(tempDir) oidctesting.GenerateSelfSignedCert(t, "127.0.0.1", cert, key) op := oidctesting.NewOIDCProvider(t, "") srv, err := op.ServeTLSWithKeyPair(cert, key) if err != nil { t.Fatalf("Cannot start server %v", err) } defer srv.Close() certData, err := ioutil.ReadFile(cert) if err != nil { t.Fatalf("Could not read cert bytes %v", err) } makeToken := func(exp time.Time) *jose.JWT { jwt, err := jose.NewSignedJWT(jose.Claims(map[string]interface{}{ "exp": exp.UTC().Unix(), }), op.PrivKey.Signer()) if err != nil { t.Fatalf("Could not create signed JWT %v", err) } return jwt } t0 := time.Now() goodToken := makeToken(t0.Add(time.Hour)).Encode() expiredToken := makeToken(t0.Add(-time.Hour)).Encode() tests := []struct { name string cfg map[string]string wantInitErr bool client OIDCClient wantCfg map[string]string wantTokenErr bool }{ { // A Valid configuration name: "no id token and no refresh token", cfg: map[string]string{ cfgIssuerUrl: srv.URL, cfgCertificateAuthority: cert, cfgClientID: "client-id", cfgClientSecret: "client-secret", }, wantTokenErr: true, }, { name: "valid config with an initial token", cfg: map[string]string{ cfgIssuerUrl: srv.URL, cfgCertificateAuthority: cert, cfgClientID: "client-id", cfgClientSecret: "client-secret", cfgIDToken: goodToken, }, client: new(noRefreshOIDCClient), wantCfg: map[string]string{ cfgIssuerUrl: srv.URL, cfgCertificateAuthority: cert, cfgClientID: "client-id", cfgClientSecret: "client-secret", cfgIDToken: goodToken, }, }, { name: "invalid ID token with a refresh token", cfg: map[string]string{ cfgIssuerUrl: srv.URL, cfgCertificateAuthority: cert, cfgClientID: "client-id", cfgClientSecret: "client-secret", cfgRefreshToken: "foo", cfgIDToken: expiredToken, }, client: &mockOIDCClient{ tokenResponse: oauth2.TokenResponse{ IDToken: goodToken, }, }, wantCfg: map[string]string{ cfgIssuerUrl: srv.URL, cfgCertificateAuthority: cert, cfgClientID: "client-id", cfgClientSecret: "client-secret", cfgRefreshToken: "foo", cfgIDToken: goodToken, }, }, { name: "invalid ID token with a refresh token, server returns new refresh token", cfg: map[string]string{ cfgIssuerUrl: srv.URL, cfgCertificateAuthority: cert, cfgClientID: "client-id", cfgClientSecret: "client-secret", cfgRefreshToken: "foo", cfgIDToken: expiredToken, }, client: &mockOIDCClient{ tokenResponse: oauth2.TokenResponse{ IDToken: goodToken, RefreshToken: "bar", }, }, wantCfg: map[string]string{ cfgIssuerUrl: srv.URL, cfgCertificateAuthority: cert, cfgClientID: "client-id", cfgClientSecret: "client-secret", cfgRefreshToken: "bar", cfgIDToken: goodToken, }, }, { name: "expired token and no refresh otken", cfg: map[string]string{ cfgIssuerUrl: srv.URL, cfgCertificateAuthority: cert, cfgClientID: "client-id", cfgClientSecret: "client-secret", cfgIDToken: expiredToken, }, wantTokenErr: true, }, { name: "valid base64d ca", cfg: map[string]string{ cfgIssuerUrl: srv.URL, cfgCertificateAuthorityData: base64.StdEncoding.EncodeToString(certData), cfgClientID: "client-id", cfgClientSecret: "client-secret", }, client: new(noRefreshOIDCClient), wantTokenErr: true, }, { name: "missing client ID", cfg: map[string]string{ cfgIssuerUrl: srv.URL, cfgCertificateAuthority: cert, cfgClientSecret: "client-secret", }, wantInitErr: true, }, { name: "missing client secret", cfg: map[string]string{ cfgIssuerUrl: srv.URL, cfgCertificateAuthority: cert, cfgClientID: "client-id", }, wantInitErr: true, }, { name: "missing issuer URL", cfg: map[string]string{ cfgCertificateAuthority: cert, cfgClientID: "client-id", cfgClientSecret: "secret", }, wantInitErr: true, }, { name: "missing TLS config", cfg: map[string]string{ cfgIssuerUrl: srv.URL, cfgClientID: "client-id", cfgClientSecret: "secret", }, wantInitErr: true, }, } for _, tt := range tests { clearCache() p, err := newOIDCAuthProvider("cluster.example.com", tt.cfg, new(persister)) if tt.wantInitErr { if err == nil { t.Errorf("%s: want non-nil err", tt.name) } continue } if err != nil { t.Errorf("%s: unexpected error on newOIDCAuthProvider: %v", tt.name, err) continue } provider := p.(*oidcAuthProvider) provider.client = tt.client provider.now = func() time.Time { return t0 } if _, err := provider.idToken(); err != nil { if !tt.wantTokenErr { t.Errorf("%s: failed to get id token: %v", tt.name, err) } continue } if tt.wantTokenErr { t.Errorf("%s: expected to not get id token: %v", tt.name, err) continue } if !reflect.DeepEqual(tt.wantCfg, provider.cfg) { t.Errorf("%s: expected config %#v got %#v", tt.name, tt.wantCfg, provider.cfg) } } }
func TestVerifyJWTExpiry(t *testing.T) { privKey, err := key.GeneratePrivateKey() if err != nil { t.Fatalf("can't generate private key: %v", err) } makeToken := func(s string, exp time.Time, count int) *jose.JWT { jwt, err := jose.NewSignedJWT(jose.Claims(map[string]interface{}{ "test": s, "exp": exp.UTC().Unix(), "count": count, }), privKey.Signer()) if err != nil { t.Fatalf("Could not create signed JWT %v", err) } return jwt } t0 := time.Now() tests := []struct { name string jwt *jose.JWT now time.Time wantErr bool wantExpired bool }{ { name: "valid jwt", jwt: makeToken("foo", t0.Add(time.Hour), 1), now: t0, }, { name: "invalid jwt", jwt: &jose.JWT{}, now: t0, wantErr: true, }, { name: "expired jwt", jwt: makeToken("foo", t0.Add(-time.Hour), 1), now: t0, wantExpired: true, }, { name: "jwt expires soon enough to be marked expired", jwt: makeToken("foo", t0, 1), now: t0, wantExpired: true, }, } for _, tc := range tests { func() { valid, err := verifyJWTExpiry(tc.now, tc.jwt.Encode()) if err != nil { if !tc.wantErr { t.Errorf("%s: %v", tc.name, err) } return } if tc.wantErr { t.Errorf("%s: expected error", tc.name) return } if valid && tc.wantExpired { t.Errorf("%s: expected token to be expired", tc.name) } if !valid && !tc.wantExpired { t.Errorf("%s: expected token to be valid", tc.name) } }() } }
func TestNewOIDCAuthProvider(t *testing.T) { tempDir, err := ioutil.TempDir(os.TempDir(), "oidc_test") if err != nil { t.Fatalf("Cannot make temp dir %v", err) } cert := path.Join(tempDir, "oidc-cert") key := path.Join(tempDir, "oidc-key") defer os.Remove(cert) defer os.Remove(key) defer os.Remove(tempDir) oidctesting.GenerateSelfSignedCert(t, "127.0.0.1", cert, key) op := oidctesting.NewOIDCProvider(t) srv, err := op.ServeTLSWithKeyPair(cert, key) if err != nil { t.Fatalf("Cannot start server %v", err) } defer srv.Close() op.AddMinimalProviderConfig(srv) certData, err := ioutil.ReadFile(cert) if err != nil { t.Fatalf("Could not read cert bytes %v", err) } jwt, err := jose.NewSignedJWT(jose.Claims(map[string]interface{}{ "test": "jwt", }), op.PrivKey.Signer()) if err != nil { t.Fatalf("Could not create signed JWT %v", err) } tests := []struct { cfg map[string]string wantErr bool wantInitialIDToken jose.JWT }{ { // A Valid configuration cfg: map[string]string{ cfgIssuerUrl: srv.URL, cfgCertificateAuthority: cert, cfgClientID: "client-id", cfgClientSecret: "client-secret", }, }, { // A Valid configuration with an Initial JWT cfg: map[string]string{ cfgIssuerUrl: srv.URL, cfgCertificateAuthority: cert, cfgClientID: "client-id", cfgClientSecret: "client-secret", cfgIDToken: jwt.Encode(), }, wantInitialIDToken: *jwt, }, { // Valid config, but using cfgCertificateAuthorityData cfg: map[string]string{ cfgIssuerUrl: srv.URL, cfgCertificateAuthorityData: base64.StdEncoding.EncodeToString(certData), cfgClientID: "client-id", cfgClientSecret: "client-secret", }, }, { // Missing client id cfg: map[string]string{ cfgIssuerUrl: srv.URL, cfgCertificateAuthority: cert, cfgClientSecret: "client-secret", }, wantErr: true, }, { // Missing client secret cfg: map[string]string{ cfgIssuerUrl: srv.URL, cfgCertificateAuthority: cert, cfgClientID: "client-id", }, wantErr: true, }, { // Missing issuer url. cfg: map[string]string{ cfgCertificateAuthority: cert, cfgClientID: "client-id", cfgClientSecret: "secret", }, wantErr: true, }, { // No TLS config cfg: map[string]string{ cfgIssuerUrl: srv.URL, cfgClientID: "client-id", cfgClientSecret: "secret", }, wantErr: true, }, } for i, tt := range tests { ap, err := newOIDCAuthProvider("cluster.example.com", tt.cfg, nil) if tt.wantErr { if err == nil { t.Errorf("case %d: want non-nil err", i) } continue } if err != nil { t.Errorf("case %d: unexpected error on newOIDCAuthProvider: %v", i, err) continue } oidcAP, ok := ap.(*oidcAuthProvider) if !ok { t.Errorf("case %d: expected ap to be an oidcAuthProvider", i) continue } if diff := compareJWTs(tt.wantInitialIDToken, oidcAP.initialIDToken); diff != "" { t.Errorf("case %d: compareJWTs(tt.wantInitialIDToken, oidcAP.initialIDToken)=%v", i, diff) } } }
func TestWrapTranport(t *testing.T) { oldBackoff := backoff defer func() { backoff = oldBackoff }() backoff = wait.Backoff{ Duration: 1 * time.Nanosecond, Steps: 3, } privKey, err := key.GeneratePrivateKey() if err != nil { t.Fatalf("can't generate private key: %v", err) } makeToken := func(s string, exp time.Time, count int) *jose.JWT { jwt, err := jose.NewSignedJWT(jose.Claims(map[string]interface{}{ "test": s, "exp": exp.UTC().Unix(), "count": count, }), privKey.Signer()) if err != nil { t.Fatalf("Could not create signed JWT %v", err) } return jwt } goodToken := makeToken("good", time.Now().Add(time.Hour), 0) goodToken2 := makeToken("good", time.Now().Add(time.Hour), 1) expiredToken := makeToken("good", time.Now().Add(-time.Hour), 0) str := func(s string) *string { return &s } tests := []struct { cfgIDToken *jose.JWT cfgRefreshToken *string expectRequests []testRoundTrip expectRefreshes []testRefresh expectPersists []testPersist wantStatus int wantErr bool }{ { // Initial JWT is set, it is good, it is set as bearer. cfgIDToken: goodToken, expectRequests: []testRoundTrip{ { expectBearerToken: goodToken.Encode(), returnHTTPStatus: 200, }, }, wantStatus: 200, }, { // Initial JWT is set, but it's expired, so it gets refreshed. cfgIDToken: expiredToken, cfgRefreshToken: str("rt1"), expectRefreshes: []testRefresh{ { expectRefreshToken: "rt1", returnTokens: oauth2.TokenResponse{ IDToken: goodToken.Encode(), }, }, }, expectRequests: []testRoundTrip{ { expectBearerToken: goodToken.Encode(), returnHTTPStatus: 200, }, }, expectPersists: []testPersist{ { cfg: map[string]string{ cfgIDToken: goodToken.Encode(), cfgRefreshToken: "rt1", }, }, }, wantStatus: 200, }, { // Initial JWT is set, but it's expired, so it gets refreshed - this // time the refresh token itself is also refreshed cfgIDToken: expiredToken, cfgRefreshToken: str("rt1"), expectRefreshes: []testRefresh{ { expectRefreshToken: "rt1", returnTokens: oauth2.TokenResponse{ IDToken: goodToken.Encode(), RefreshToken: "rt2", }, }, }, expectRequests: []testRoundTrip{ { expectBearerToken: goodToken.Encode(), returnHTTPStatus: 200, }, }, expectPersists: []testPersist{ { cfg: map[string]string{ cfgIDToken: goodToken.Encode(), cfgRefreshToken: "rt2", }, }, }, wantStatus: 200, }, { // Initial JWT is not set, so it gets refreshed. cfgRefreshToken: str("rt1"), expectRefreshes: []testRefresh{ { expectRefreshToken: "rt1", returnTokens: oauth2.TokenResponse{ IDToken: goodToken.Encode(), }, }, }, expectRequests: []testRoundTrip{ { expectBearerToken: goodToken.Encode(), returnHTTPStatus: 200, }, }, expectPersists: []testPersist{ { cfg: map[string]string{ cfgIDToken: goodToken.Encode(), cfgRefreshToken: "rt1", }, }, }, wantStatus: 200, }, { // Expired token, but no refresh token. cfgIDToken: expiredToken, wantErr: true, }, { // Initial JWT is not set, so it gets refreshed, but the server // rejects it when it is used, so it refreshes again, which // succeeds. cfgRefreshToken: str("rt1"), expectRefreshes: []testRefresh{ { expectRefreshToken: "rt1", returnTokens: oauth2.TokenResponse{ IDToken: goodToken.Encode(), }, }, { expectRefreshToken: "rt1", returnTokens: oauth2.TokenResponse{ IDToken: goodToken2.Encode(), }, }, }, expectRequests: []testRoundTrip{ { expectBearerToken: goodToken.Encode(), returnHTTPStatus: http.StatusUnauthorized, }, { expectBearerToken: goodToken2.Encode(), returnHTTPStatus: http.StatusOK, }, }, expectPersists: []testPersist{ { cfg: map[string]string{ cfgIDToken: goodToken.Encode(), cfgRefreshToken: "rt1", }, }, { cfg: map[string]string{ cfgIDToken: goodToken2.Encode(), cfgRefreshToken: "rt1", }, }, }, wantStatus: 200, }, { // Initial JWT is but the server rejects it when it is used, so it // refreshes again, which succeeds. cfgRefreshToken: str("rt1"), cfgIDToken: goodToken, expectRefreshes: []testRefresh{ { expectRefreshToken: "rt1", returnTokens: oauth2.TokenResponse{ IDToken: goodToken2.Encode(), }, }, }, expectRequests: []testRoundTrip{ { expectBearerToken: goodToken.Encode(), returnHTTPStatus: http.StatusUnauthorized, }, { expectBearerToken: goodToken2.Encode(), returnHTTPStatus: http.StatusOK, }, }, expectPersists: []testPersist{ { cfg: map[string]string{ cfgIDToken: goodToken2.Encode(), cfgRefreshToken: "rt1", }, }, }, wantStatus: 200, }, } for i, tt := range tests { client := &testOIDCClient{ refreshes: tt.expectRefreshes, } persister := &testPersister{ tt.expectPersists, } cfg := map[string]string{} if tt.cfgIDToken != nil { cfg[cfgIDToken] = tt.cfgIDToken.Encode() } if tt.cfgRefreshToken != nil { cfg[cfgRefreshToken] = *tt.cfgRefreshToken } ap := &oidcAuthProvider{ refresher: &idTokenRefresher{ client: client, cfg: cfg, persister: persister, }, } if tt.cfgIDToken != nil { ap.initialIDToken = *tt.cfgIDToken } tstRT := &testRoundTripper{ tt.expectRequests, } rt := ap.WrapTransport(tstRT) req, err := http.NewRequest("GET", "http://cluster.example.com", nil) if err != nil { t.Errorf("case %d: unexpected error making request: %v", i, err) } res, err := rt.RoundTrip(req) if tt.wantErr { if err == nil { t.Errorf("case %d: Expected non-nil error", i) } } else if err != nil { t.Errorf("case %d: unexpected error making round trip: %v", i, err) } else { if res.StatusCode != tt.wantStatus { t.Errorf("case %d: want=%d, got=%d", i, tt.wantStatus, res.StatusCode) } } if err = client.verify(); err != nil { t.Errorf("case %d: %v", i, err) } if err = persister.verify(); err != nil { t.Errorf("case %d: %v", i, err) } if err = tstRT.verify(); err != nil { t.Errorf("case %d: %v", i, err) continue } } }