func TestServerLoginUnrecognizedSessionKey(t *testing.T) { ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{ ID: "XXX", Secret: "secrete", }, }, }) km := &StaticKeyManager{ signer: &StaticSigner{sig: nil, err: errors.New("fail")}, } sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) srv := &Server{ IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, KeyManager: km, SessionManager: sm, ClientIdentityRepo: ciRepo, } ident := oidc.Identity{ID: "YYY", Name: "elroy", Email: "*****@*****.**"} code, err := srv.Login(ident, "XXX") if err == nil { t.Fatalf("Expected non-nil error") } if code != "" { t.Fatalf("Expected empty code, got=%s", code) } }
func (cfg *SingleServerConfig) Configure(srv *Server) error { k, err := key.GeneratePrivateKey() if err != nil { return err } ks := key.NewPrivateKeySet([]*key.PrivateKey{k}, time.Now().Add(24*time.Hour)) kRepo := key.NewPrivateKeySetRepo() if err = kRepo.Set(ks); err != nil { return err } cf, err := os.Open(cfg.ClientsFile) if err != nil { return fmt.Errorf("unable to read clients from file %s: %v", cfg.ClientsFile, err) } defer cf.Close() ciRepo, err := client.NewClientIdentityRepoFromReader(cf) if err != nil { return fmt.Errorf("unable to read client identities from file %s: %v", cfg.ClientsFile, err) } f, err := os.Open(cfg.ConnectorsFile) if err != nil { return fmt.Errorf("opening connectors file: %v", err) } defer f.Close() cfgs, err := connector.ReadConfigs(f) if err != nil { return fmt.Errorf("decoding connector configs: %v", err) } cfgRepo := connector.NewConnectorConfigRepoFromConfigs(cfgs) sRepo := session.NewSessionRepo() skRepo := session.NewSessionKeyRepo() sm := session.NewSessionManager(sRepo, skRepo) userRepo, err := user.NewUserRepoFromFile(cfg.UsersFile) if err != nil { return fmt.Errorf("unable to read users from file: %v", err) } pwiRepo := user.NewPasswordInfoRepo() refTokRepo := refresh.NewRefreshTokenRepo() txnFactory := repo.InMemTransactionFactory userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, manager.ManagerOptions{}) srv.ClientIdentityRepo = ciRepo srv.KeySetRepo = kRepo srv.ConnectorConfigRepo = cfgRepo srv.UserRepo = userRepo srv.UserManager = userManager srv.PasswordInfoRepo = pwiRepo srv.SessionManager = sm srv.RefreshTokenRepo = refTokRepo return nil }
func TestServerLogin(t *testing.T) { ci := oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{ ID: "XXX", Secret: "secrete", }, Metadata: oidc.ClientMetadata{ RedirectURLs: []url.URL{ url.URL{ Scheme: "http", Host: "client.example.com", Path: "/callback", }, }, }, } ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci}) km := &StaticKeyManager{ signer: &StaticSigner{sig: []byte("beer"), err: nil}, } sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sm.GenerateCode = staticGenerateCodeFunc("fakecode") sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURLs[0], "", false, []string{"openid"}) if err != nil { t.Fatalf("Unexpected error: %v", err) } userRepo, err := makeNewUserRepo() if err != nil { t.Fatalf("Unexpected error: %v", err) } srv := &Server{ IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, KeyManager: km, SessionManager: sm, ClientIdentityRepo: ciRepo, UserRepo: userRepo, } ident := oidc.Identity{ID: "YYY", Name: "elroy", Email: "*****@*****.**"} key, err := sm.NewSessionKey(sessionID) if err != nil { t.Fatalf("Unexpected error: %v", err) } redirectURL, err := srv.Login(ident, key) if err != nil { t.Fatalf("Unexpected err from Server.Login: %v", err) } wantRedirectURL := "http://client.example.com/callback?code=fakecode&state=bogus" if wantRedirectURL != redirectURL { t.Fatalf("Unexpected redirectURL: want=%q, got=%q", wantRedirectURL, redirectURL) } }
func TestServerNewSession(t *testing.T) { sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) srv := &Server{ SessionManager: sm, } state := "pants" nonce := "oncenay" ci := oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{ ID: "XXX", Secret: "secrete", }, Metadata: oidc.ClientMetadata{ RedirectURLs: []url.URL{ url.URL{ Scheme: "http", Host: "client.example.com", Path: "/callback", }, }, }, } key, err := srv.NewSession("bogus_idpc", ci.Credentials.ID, state, ci.Metadata.RedirectURLs[0], nonce, false, []string{"openid"}) if err != nil { t.Fatalf("Unexpected error: %v", err) } sessionID, err := sm.ExchangeKey(key) if err != nil { t.Fatalf("Session not retreivable: %v", err) } ses, err := sm.AttachRemoteIdentity(sessionID, oidc.Identity{}) if err != nil { t.Fatalf("Unable to add Identity to Session: %v", err) } if !reflect.DeepEqual(ci.Metadata.RedirectURLs[0], ses.RedirectURL) { t.Fatalf("Session created with incorrect RedirectURL: want=%#v got=%#v", ci.Metadata.RedirectURLs[0], ses.RedirectURL) } if ci.Credentials.ID != ses.ClientID { t.Fatalf("Session created with incorrect ClientID: want=%q got=%q", ci.Credentials.ID, ses.ClientID) } if state != ses.ClientState { t.Fatalf("Session created with incorrect State: want=%q got=%q", state, ses.ClientState) } if nonce != ses.Nonce { t.Fatalf("Session created with incorrect Nonce: want=%q got=%q", nonce, ses.Nonce) } }
func TestServerTokenUnrecognizedKey(t *testing.T) { ci := oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{ ID: "XXX", Secret: "secrete", }, } ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci}) km := &StaticKeyManager{ signer: &StaticSigner{sig: []byte("beer"), err: nil}, } sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) srv := &Server{ IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, KeyManager: km, SessionManager: sm, ClientIdentityRepo: ciRepo, } sessionID, err := sm.NewSession("connector_id", ci.Credentials.ID, "bogus", url.URL{}, "", false, []string{"openid", "offline_access"}) if err != nil { t.Fatalf("Unexpected error: %v", err) } _, err = sm.AttachRemoteIdentity(sessionID, oidc.Identity{}) if err != nil { t.Fatalf("Unexpected error: %v", err) } jwt, token, err := srv.CodeToken(ci.Credentials, "foo") if err == nil { t.Fatalf("Expected non-nil error") } if jwt != nil { t.Fatalf("Expected nil jwt") } if token != "" { t.Fatalf("Expected empty refresh token") } }
func (cfg *MultiServerConfig) Configure(srv *Server) error { if len(cfg.KeySecrets) == 0 { return errors.New("missing key secret") } if cfg.DatabaseConfig.DSN == "" { return errors.New("missing database connection string") } dbc, err := db.NewConnection(cfg.DatabaseConfig) if err != nil { return fmt.Errorf("unable to initialize database connection: %v", err) } kRepo, err := db.NewPrivateKeySetRepo(dbc, cfg.UseOldFormat, cfg.KeySecrets...) if err != nil { return fmt.Errorf("unable to create PrivateKeySetRepo: %v", err) } ciRepo := db.NewClientIdentityRepo(dbc) sRepo := db.NewSessionRepo(dbc) skRepo := db.NewSessionKeyRepo(dbc) cfgRepo := db.NewConnectorConfigRepo(dbc) userRepo := db.NewUserRepo(dbc) pwiRepo := db.NewPasswordInfoRepo(dbc) userManager := user.NewManager(userRepo, pwiRepo, db.TransactionFactory(dbc), user.ManagerOptions{}) refreshTokenRepo := db.NewRefreshTokenRepo(dbc) sm := session.NewSessionManager(sRepo, skRepo) srv.ClientIdentityRepo = ciRepo srv.KeySetRepo = kRepo srv.ConnectorConfigRepo = cfgRepo srv.UserRepo = userRepo srv.UserManager = userManager srv.PasswordInfoRepo = pwiRepo srv.SessionManager = sm srv.RefreshTokenRepo = refreshTokenRepo return nil }
func mockServer(cis []oidc.ClientIdentity) (*server.Server, error) { k, err := key.GeneratePrivateKey() if err != nil { return nil, fmt.Errorf("Unable to generate private key: %v", err) } km := key.NewPrivateKeyManager() err = km.Set(key.NewPrivateKeySet([]*key.PrivateKey{k}, time.Now().Add(time.Minute))) if err != nil { return nil, err } sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) srv := &server.Server{ IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, KeyManager: km, ClientIdentityRepo: client.NewClientIdentityRepo(cis), SessionManager: sm, } return srv, nil }
func TestServerTokenFail(t *testing.T) { issuerURL := url.URL{Scheme: "http", Host: "server.example.com"} keyFixture := "goodkey" ccFixture := oidc.ClientCredentials{ ID: "XXX", Secret: "secrete", } signerFixture := &StaticSigner{sig: []byte("beer"), err: nil} tests := []struct { signer jose.Signer argCC oidc.ClientCredentials argKey string err error scope []string refreshToken string }{ // control test case to make sure fixtures check out { signer: signerFixture, argCC: ccFixture, argKey: keyFixture, scope: []string{"openid", "offline_access"}, refreshToken: fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), }, // no 'offline_access' in 'scope', should get empty refresh token { signer: signerFixture, argCC: ccFixture, argKey: keyFixture, scope: []string{"openid"}, }, // unrecognized key { signer: signerFixture, argCC: ccFixture, argKey: "foo", err: oauth2.NewError(oauth2.ErrorInvalidGrant), scope: []string{"openid", "offline_access"}, }, // unrecognized client { signer: signerFixture, argCC: oidc.ClientCredentials{ID: "YYY"}, argKey: keyFixture, err: oauth2.NewError(oauth2.ErrorInvalidClient), scope: []string{"openid", "offline_access"}, }, // signing operation fails { signer: &StaticSigner{sig: nil, err: errors.New("fail")}, argCC: ccFixture, argKey: keyFixture, err: oauth2.NewError(oauth2.ErrorServerError), scope: []string{"openid", "offline_access"}, }, } for i, tt := range tests { sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sm.GenerateCode = func() (string, error) { return keyFixture, nil } sessionID, err := sm.NewSession("connector_id", ccFixture.ID, "bogus", url.URL{}, "", false, tt.scope) if err != nil { t.Fatalf("Unexpected error: %v", err) } _, err = sm.AttachRemoteIdentity(sessionID, oidc.Identity{}) if err != nil { t.Errorf("case %d: unexpected error: %v", i, err) continue } km := &StaticKeyManager{ signer: tt.signer, } ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ oidc.ClientIdentity{Credentials: ccFixture}, }) _, err = sm.AttachUser(sessionID, "testid-1") if err != nil { t.Fatalf("case %d: unexpected error: %v", i, err) } userRepo, err := makeNewUserRepo() if err != nil { t.Fatalf("Unexpected error: %v", err) } refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo() if err != nil { t.Fatalf("Unexpected error: %v", err) } srv := &Server{ IssuerURL: issuerURL, KeyManager: km, SessionManager: sm, ClientIdentityRepo: ciRepo, UserRepo: userRepo, RefreshTokenRepo: refreshTokenRepo, } _, err = sm.NewSessionKey(sessionID) if err != nil { t.Fatalf("Unexpected error: %v", err) } jwt, token, err := srv.CodeToken(tt.argCC, tt.argKey) if token != tt.refreshToken { fmt.Printf("case %d: expect refresh token %q, got %q\n", i, tt.refreshToken, token) t.Fatalf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token) panic("") } if !reflect.DeepEqual(err, tt.err) { t.Errorf("case %d: expect %v, got %v", i, tt.err, err) } if err == nil && jwt == nil { t.Errorf("case %d: got nil JWT", i) } if err != nil && jwt != nil { t.Errorf("case %d: got non-nil JWT %v", i, jwt) } } }
func TestServerCodeToken(t *testing.T) { ci := oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{ ID: "XXX", Secret: "secrete", }, } ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci}) km := &StaticKeyManager{ signer: &StaticSigner{sig: []byte("beer"), err: nil}, } sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) userRepo, err := makeNewUserRepo() if err != nil { t.Fatalf("Unexpected error: %v", err) } refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo() if err != nil { t.Fatalf("Unexpected error: %v", err) } srv := &Server{ IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, KeyManager: km, SessionManager: sm, ClientIdentityRepo: ciRepo, UserRepo: userRepo, RefreshTokenRepo: refreshTokenRepo, } tests := []struct { scope []string refreshToken string }{ // No 'offline_access' in scope, should get empty refresh token. { scope: []string{"openid"}, refreshToken: "", }, // Have 'offline_access' in scope, should get non-empty refresh token. { scope: []string{"openid", "offline_access"}, refreshToken: fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), }, } for i, tt := range tests { sessionID, err := sm.NewSession("bogus_idpc", ci.Credentials.ID, "bogus", url.URL{}, "", false, tt.scope) if err != nil { t.Fatalf("case %d: unexpected error: %v", i, err) } _, err = sm.AttachRemoteIdentity(sessionID, oidc.Identity{}) if err != nil { t.Fatalf("case %d: unexpected error: %v", i, err) } _, err = sm.AttachUser(sessionID, "testid-1") if err != nil { t.Fatalf("case %d: unexpected error: %v", i, err) } key, err := sm.NewSessionKey(sessionID) if err != nil { t.Fatalf("case %d: unexpected error: %v", i, err) } jwt, token, err := srv.CodeToken(ci.Credentials, key) if err != nil { t.Fatalf("case %d: unexpected error: %v", i, err) } if jwt == nil { t.Fatalf("case %d: expect non-nil jwt", i) } if token != tt.refreshToken { t.Fatalf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token) } } }
func TestServerLoginDisabledUser(t *testing.T) { ci := oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{ ID: "XXX", Secret: "secrete", }, Metadata: oidc.ClientMetadata{ RedirectURLs: []url.URL{ url.URL{ Scheme: "http", Host: "client.example.com", Path: "/callback", }, }, }, } ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci}) km := &StaticKeyManager{ signer: &StaticSigner{sig: []byte("beer"), err: nil}, } sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sm.GenerateCode = staticGenerateCodeFunc("fakecode") sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURLs[0], "", false, []string{"openid"}) if err != nil { t.Fatalf("Unexpected error: %v", err) } userRepo, err := makeNewUserRepo() if err != nil { t.Fatalf("Unexpected error: %v", err) } err = userRepo.Create(nil, user.User{ ID: "disabled-1", Email: "*****@*****.**", Disabled: true, }) if err != nil { t.Fatalf("Unexpected error: %v", err) } err = userRepo.AddRemoteIdentity(nil, "disabled-1", user.RemoteIdentity{ ConnectorID: "test_connector_id", ID: "disabled-connector-id", }) srv := &Server{ IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, KeyManager: km, SessionManager: sm, ClientIdentityRepo: ciRepo, UserRepo: userRepo, } ident := oidc.Identity{ID: "disabled-connector-id", Name: "elroy", Email: "*****@*****.**"} key, err := sm.NewSessionKey(sessionID) if err != nil { t.Fatalf("Unexpected error: %v", err) } _, err = srv.Login(ident, key) if err == nil { t.Errorf("disabled user was allowed to log in") } }
func makeTestFixtures() (*testFixtures, error) { userRepo := user.NewUserRepoFromUsers(testUsers) pwRepo := user.NewPasswordInfoRepoFromPasswordInfos(testPasswordInfos) manager := user.NewManager(userRepo, pwRepo, repo.InMemTransactionFactory, user.ManagerOptions{}) connConfigs := []connector.ConnectorConfig{ &connector.OIDCConnectorConfig{ ID: "oidc", IssuerURL: testIssuerURL.String(), ClientID: "12345", ClientSecret: "567789", }, &connector.OIDCConnectorConfig{ ID: "oidc-trusted", IssuerURL: testIssuerURL.String(), ClientID: "12345-trusted", ClientSecret: "567789-trusted", TrustedEmailProvider: true, }, &connector.LocalConnectorConfig{ ID: "local", }, } sessionManager := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sessionManager.GenerateCode = sequentialGenerateCodeFunc() emailer, err := email.NewTemplatizedEmailerFromGlobs( emailTemplatesLocation+"/*.txt", emailTemplatesLocation+"/*.html", &email.FakeEmailer{}) if err != nil { return nil, err } clientIdentityRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{ ID: "XXX", Secret: testClientSecret, }, Metadata: oidc.ClientMetadata{ RedirectURLs: []url.URL{ testRedirectURL, }, }, }, }) km := key.NewPrivateKeyManager() err = km.Set(key.NewPrivateKeySet([]*key.PrivateKey{testPrivKey}, time.Now().Add(time.Minute))) if err != nil { return nil, err } tpl, err := getTemplates("dex", "https://coreos.com/assets/images/brand/coreos-mark-30px.png", templatesLocation) if err != nil { return nil, err } srv := &Server{ IssuerURL: testIssuerURL, SessionManager: sessionManager, ClientIdentityRepo: clientIdentityRepo, Templates: tpl, UserRepo: userRepo, PasswordInfoRepo: pwRepo, UserManager: manager, KeyManager: km, } err = setTemplates(srv, tpl) if err != nil { return nil, err } for _, config := range connConfigs { if err := srv.AddConnector(config); err != nil { return nil, err } } srv.UserEmailer = useremail.NewUserEmailer(srv.UserRepo, srv.PasswordInfoRepo, srv.KeyManager.Signer, srv.SessionManager.ValidityWindow, srv.IssuerURL, emailer, "*****@*****.**", srv.absURL(httpPathResetPassword), srv.absURL(httpPathEmailVerify)) return &testFixtures{ srv: srv, redirectURL: testRedirectURL, userRepo: userRepo, sessionManager: sessionManager, emailer: emailer, clientIdentityRepo: clientIdentityRepo, }, nil }
func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) { idpcs := []connector.Connector{ &fakeConnector{loginURL: "http://fake.example.com"}, } srv := &Server{ IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, SessionManager: session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()), ClientIdentityRepo: client.NewClientIdentityRepo([]oidc.ClientIdentity{ oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{ ID: "XXX", Secret: "secrete", }, Metadata: oidc.ClientMetadata{ RedirectURLs: []url.URL{ url.URL{Scheme: "http", Host: "client.example.com", Path: "/callback"}, }, }, }, }), } tests := []struct { query url.Values wantCode int wantLocation string }{ // no redirect_uri provided, but client only has one, so it's usable { query: url.Values{ "response_type": []string{"code"}, "client_id": []string{"XXX"}, "connector_id": []string{"fake"}, "scope": []string{"openid"}, }, wantCode: http.StatusTemporaryRedirect, wantLocation: "http://fake.example.com", }, // provided redirect_uri matches client { query: url.Values{ "response_type": []string{"code"}, "redirect_uri": []string{"http://client.example.com/callback"}, "client_id": []string{"XXX"}, "connector_id": []string{"fake"}, "scope": []string{"openid"}, }, wantCode: http.StatusTemporaryRedirect, wantLocation: "http://fake.example.com", }, // provided redirect_uri does not match client { query: url.Values{ "response_type": []string{"code"}, "redirect_uri": []string{"http://unrecognized.example.com/callback"}, "client_id": []string{"XXX"}, "connector_id": []string{"fake"}, "scope": []string{"openid"}, }, wantCode: http.StatusBadRequest, }, // nonexistant client_id { query: url.Values{ "response_type": []string{"code"}, "redirect_uri": []string{"http://client.example.com/callback"}, "client_id": []string{"YYY"}, "connector_id": []string{"fake"}, "scope": []string{"openid"}, }, wantCode: http.StatusBadRequest, }, // unsupported response type, redirects back to client { query: url.Values{ "response_type": []string{"token"}, "client_id": []string{"XXX"}, "connector_id": []string{"fake"}, "scope": []string{"openid"}, }, wantCode: http.StatusTemporaryRedirect, wantLocation: "http://client.example.com/callback?error=unsupported_response_type&state=", }, // no 'openid' in scope { query: url.Values{ "response_type": []string{"code"}, "redirect_uri": []string{"http://client.example.com/callback"}, "client_id": []string{"XXX"}, "connector_id": []string{"fake"}, }, wantCode: http.StatusBadRequest, }, } for i, tt := range tests { hdlr := handleAuthFunc(srv, idpcs, nil) w := httptest.NewRecorder() u := fmt.Sprintf("http://server.example.com?%s", tt.query.Encode()) req, err := http.NewRequest("GET", u, nil) if err != nil { t.Errorf("case %d: unable to form HTTP request: %v", i, err) continue } hdlr.ServeHTTP(w, req) if tt.wantCode != w.Code { t.Errorf("case %d: HTTP code mismatch: want=%d got=%d", i, tt.wantCode, w.Code) continue } gotLocation := w.Header().Get("Location") if tt.wantLocation != gotLocation { t.Errorf("case %d: HTTP Location header mismatch: want=%s got=%s", i, tt.wantLocation, gotLocation) } } }
func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) { idpcs := []connector.Connector{ &fakeConnector{loginURL: "http://fake.example.com"}, } srv := &Server{ IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, SessionManager: session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()), ClientIdentityRepo: client.NewClientIdentityRepo([]oidc.ClientIdentity{ oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{ ID: "XXX", Secret: "secrete", }, Metadata: oidc.ClientMetadata{ RedirectURLs: []url.URL{ url.URL{Scheme: "http", Host: "foo.example.com", Path: "/callback"}, url.URL{Scheme: "http", Host: "bar.example.com", Path: "/callback"}, }, }, }, }), } tests := []struct { query url.Values wantCode int wantLocation string }{ // provided redirect_uri matches client's first { query: url.Values{ "response_type": []string{"code"}, "redirect_uri": []string{"http://foo.example.com/callback"}, "client_id": []string{"XXX"}, "connector_id": []string{"fake"}, "scope": []string{"openid"}, }, wantCode: http.StatusTemporaryRedirect, wantLocation: "http://fake.example.com", }, // provided redirect_uri matches client's second { query: url.Values{ "response_type": []string{"code"}, "redirect_uri": []string{"http://bar.example.com/callback"}, "client_id": []string{"XXX"}, "connector_id": []string{"fake"}, "scope": []string{"openid"}, }, wantCode: http.StatusTemporaryRedirect, wantLocation: "http://fake.example.com", }, // provided redirect_uri does not match either of client's { query: url.Values{ "response_type": []string{"code"}, "redirect_uri": []string{"http://unrecognized.example.com/callback"}, "client_id": []string{"XXX"}, "connector_id": []string{"fake"}, "scope": []string{"openid"}, }, wantCode: http.StatusBadRequest, }, // no redirect_uri provided { query: url.Values{ "response_type": []string{"code"}, "client_id": []string{"XXX"}, "connector_id": []string{"fake"}, "scope": []string{"openid"}, }, wantCode: http.StatusBadRequest, }, } for i, tt := range tests { hdlr := handleAuthFunc(srv, idpcs, nil, true) w := httptest.NewRecorder() u := fmt.Sprintf("http://server.example.com?%s", tt.query.Encode()) req, err := http.NewRequest("GET", u, nil) if err != nil { t.Errorf("case %d: unable to form HTTP request: %v", i, err) continue } hdlr.ServeHTTP(w, req) if tt.wantCode != w.Code { t.Errorf("case %d: HTTP code mismatch: want=%d got=%d", i, tt.wantCode, w.Code) t.Errorf("case %d: BODY: %v", i, w.Body.String()) t.Errorf("case %d: LOCO: %v", i, w.HeaderMap.Get("Location")) continue } gotLocation := w.Header().Get("Location") if tt.wantLocation != gotLocation { t.Errorf("case %d: HTTP Location header mismatch: want=%s got=%s", i, tt.wantLocation, gotLocation) } } }
func TestHTTPExchangeTokenRefreshToken(t *testing.T) { password, err := user.NewPasswordFromPlaintext("woof") if err != nil { t.Fatalf("unexpectd error: %q", err) } passwordInfo := user.PasswordInfo{ UserID: "elroy77", Password: password, } cfg := &connector.LocalConnectorConfig{ PasswordInfos: []user.PasswordInfo{passwordInfo}, } ci := oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{ ID: "72de74a9", Secret: "XXX", }, } cir := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci}) issuerURL := url.URL{Scheme: "http", Host: "server.example.com"} sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) k, err := key.GeneratePrivateKey() if err != nil { t.Fatalf("Unable to generate RSA key: %v", err) } km := key.NewPrivateKeyManager() err = km.Set(key.NewPrivateKeySet([]*key.PrivateKey{k}, time.Now().Add(time.Minute))) if err != nil { t.Fatalf("Unexpected error: %v", err) } usr := user.User{ ID: "ID-test", Email: "*****@*****.**", DisplayName: "displayname", } userRepo := user.NewUserRepo() if err := userRepo.Create(nil, usr); err != nil { t.Fatalf("Unexpected error: %v", err) } passwordInfoRepo := user.NewPasswordInfoRepo() refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo() if err != nil { t.Fatalf("Unexpected error: %v", err) } srv := &server.Server{ IssuerURL: issuerURL, KeyManager: km, SessionManager: sm, ClientIdentityRepo: cir, Templates: template.New(connector.LoginPageTemplateName), Connectors: []connector.Connector{}, UserRepo: userRepo, PasswordInfoRepo: passwordInfoRepo, RefreshTokenRepo: refreshTokenRepo, } if err = srv.AddConnector(cfg); err != nil { t.Fatalf("Unexpected error: %v", err) } sClient := &phttp.HandlerClient{Handler: srv.HTTPHandler()} pcfg, err := oidc.FetchProviderConfig(sClient, issuerURL.String()) if err != nil { t.Fatalf("Failed to fetch provider config: %v", err) } ks := key.NewPublicKeySet([]jose.JWK{k.JWK()}, time.Now().Add(1*time.Hour)) ccfg := oidc.ClientConfig{ HTTPClient: sClient, ProviderConfig: pcfg, Credentials: ci.Credentials, RedirectURL: "http://client.example.com", KeySet: *ks, } cl, err := oidc.NewClient(ccfg) if err != nil { t.Fatalf("Failed creating oidc.Client: %v", err) } m := http.NewServeMux() var claims jose.Claims var refresh string m.HandleFunc("/callback", handleCallbackFunc(cl, &claims, &refresh)) cClient := &phttp.HandlerClient{Handler: m} // this will actually happen due to some interaction between the // end-user and a remote identity provider sessionID, err := sm.NewSession("bogus_idpc", ci.Credentials.ID, "bogus", url.URL{}, "", false, []string{"openid", "offline_access"}) if err != nil { t.Fatalf("Unexpected error: %v", err) } if _, err = sm.AttachRemoteIdentity(sessionID, passwordInfo.Identity()); err != nil { t.Fatalf("Unexpected error: %v", err) } if _, err = sm.AttachUser(sessionID, usr.ID); err != nil { t.Fatalf("Unexpected error: %v", err) } key, err := sm.NewSessionKey(sessionID) if err != nil { t.Fatalf("Unexpected error: %v", err) } req, err := http.NewRequest("GET", fmt.Sprintf("http://client.example.com/callback?code=%s", key), nil) if err != nil { t.Fatalf("Failed creating HTTP request: %v", err) } resp, err := cClient.Do(req) if err != nil { t.Fatalf("Failed resolving HTTP requests against /callback: %v", err) } if err := verifyUserClaims(claims, &ci, &usr, issuerURL); err != nil { t.Fatalf("Failed to verify claims: %v", err) } if resp.StatusCode != http.StatusOK { t.Fatalf("Received status code %d, want %d", resp.StatusCode, http.StatusOK) } if refresh == "" { t.Fatalf("No refresh token") } // Use refresh token to get a new ID token. token, err := cl.RefreshToken(refresh) if err != nil { t.Fatalf("Unexpected error: %v", err) } claims, err = token.Claims() if err != nil { t.Fatalf("Failed parsing claims from client token: %v", err) } if err := verifyUserClaims(claims, &ci, &usr, issuerURL); err != nil { t.Fatalf("Failed to verify claims: %v", err) } }