Ejemplo n.º 1
0
// TestDBClientIdentityRepoMetadataNoExist verifies that asking the repo for
// the metadata of an unknown client ID yields client.ErrorNotFound and a nil
// metadata pointer.
func TestDBClientIdentityRepoMetadataNoExist(t *testing.T) {
	repo := db.NewClientIdentityRepo(connect(t))

	meta, err := repo.Metadata("noexist")
	if want := client.ErrorNotFound; err != want {
		t.Errorf("want==%q, got==%q", want, err)
	}
	if meta != nil {
		t.Fatalf("Retrieved incorrect ClientMetadata: want=nil got=%#v", meta)
	}
}
Ejemplo n.º 2
0
// TestCreate posts valid client-registration requests to clientResource and
// checks, for each set of redirect URIs:
//   - the response code is 201 Created,
//   - the returned client echoes the redirect URIs unchanged,
//   - a non-empty client ID and secret were generated,
//   - the Location header points at endpoint/<client ID>.
func TestCreate(t *testing.T) {
	repo := db.NewClientIdentityRepo(db.NewMemDB())
	res := &clientResource{repo: repo}
	tests := [][]string{
		{"http://example.com"},
		{"https://example.com"},
		{"http://example.com/foo"},
		{"http://example.com/bar", "http://example.com/foo"},
	}
	endpoint := "http://example.com/clients"

	for i, tt := range tests {
		// Encode the payload with encoding/json rather than splicing the URIs
		// into a hand-written JSON string, so they are always escaped correctly.
		payload, err := json.Marshal(map[string][]string{"redirectURIs": tt})
		if err != nil {
			t.Fatalf("case %d: failed encoding request body: %v", i, err)
		}
		r, err := http.NewRequest("POST", endpoint, strings.NewReader(string(payload)))
		if err != nil {
			t.Fatalf("Failed creating http.Request: %v", err)
		}
		r.Header.Set("content-type", "application/json")
		w := httptest.NewRecorder()
		res.ServeHTTP(w, r)

		if w.Code != http.StatusCreated {
			t.Errorf("case %d: invalid response code, want=%d, got=%d", i, http.StatusCreated, w.Code)
			// The body is an error document, not a client; checking it
			// further would only report misleading follow-up failures.
			continue
		}

		// Named cli (not client) so the client package is not shadowed.
		var cli schema.ClientWithSecret
		if err := json.Unmarshal(w.Body.Bytes(), &cli); err != nil {
			t.Errorf("case %d: unexpected error=%v", i, err)
			continue
		}
		if len(cli.RedirectURIs) != len(tt) {
			t.Errorf("case %d: unexpected number of redirect URIs, want=%d, got=%d", i, len(tt), len(cli.RedirectURIs))
		}

		if !reflect.DeepEqual(tt, cli.RedirectURIs) {
			t.Errorf("case %d: unexpected client redirect URIs: want=%v got=%v", i, tt, cli.RedirectURIs)
		}

		if cli.Id == "" {
			t.Errorf("case %d: empty client ID in response", i)
		}

		if cli.Secret == "" {
			t.Errorf("case %d: empty client secret in response", i)
		}

		wantLoc := fmt.Sprintf("%s/%s", endpoint, cli.Id)
		gotLoc := w.Header().Get("Location")
		if gotLoc != wantLoc {
			t.Errorf("case %d: invalid location header, want=%v, got=%v", i, wantLoc, gotLoc)
		}
	}
}
Ejemplo n.º 3
0
// newDBDriver opens a database connection for dsn and returns a driver whose
// client-identity, connector-config and user repositories all share that
// single connection. A connection error is returned unchanged.
func newDBDriver(dsn string) (driver, error) {
	conn, err := db.NewConnection(db.Config{DSN: dsn})
	if err != nil {
		return nil, err
	}
	return &dbDriver{
		ciRepo:  db.NewClientIdentityRepo(conn),
		cfgRepo: db.NewConnectorConfigRepo(conn),
		usrRepo: db.NewUserRepo(conn),
	}, nil
}
Ejemplo n.º 4
0
// TestDBClientIdentityRepoAuthenticate registers a client and checks two
// things: Authenticate accepts the returned credentials exactly, and it
// rejects every near-miss without an error. The near-misses are an unknown
// ID, a wrong secret, an unknown ID with the right secret, and the right
// secret with a suffix appended.
func TestDBClientIdentityRepoAuthenticate(t *testing.T) {
	r := db.NewClientIdentityRepo(connect(t))

	cm := oidc.ClientMetadata{
		RedirectURIs: []url.URL{
			{Scheme: "http", Host: "127.0.0.1:5556", Path: "/cb"},
		},
	}

	cc, err := r.New("baz", cm)
	if err != nil {
		// t.Fatal rather than t.Fatalf(err.Error()): the error text must not
		// be interpreted as a format string.
		t.Fatal(err)
	}

	if cc.ID != "baz" {
		t.Fatalf("Returned ClientCredentials has incorrect ID: want=baz got=%s", cc.ID)
	}

	ok, err := r.Authenticate(*cc)
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	if !ok {
		t.Fatalf("Authentication failed for good creds")
	}

	creds := []oidc.ClientCredentials{
		// completely made up
		{ID: "foo", Secret: "bar"},

		// good client ID, bad secret
		{ID: cc.ID, Secret: "bar"},

		// bad client ID, good secret
		{ID: "foo", Secret: cc.Secret},

		// good client ID, secret with some fluff on the end
		{ID: cc.ID, Secret: fmt.Sprintf("%sfluff", cc.Secret)},
	}
	for i, c := range creds {
		ok, err := r.Authenticate(c)
		if err != nil {
			t.Errorf("case %d: unexpected error: %v", i, err)
		} else if ok {
			t.Errorf("case %d: authentication succeeded for bad creds", i)
		}
	}
}
Ejemplo n.º 5
0
// TestDBClientIdentityAll checks that All reflects every registered client.
// After one registration it must return exactly that client with identical
// metadata. After a second registration it must return two entries.
func TestDBClientIdentityAll(t *testing.T) {
	r := db.NewClientIdentityRepo(connect(t))

	cm := oidc.ClientMetadata{
		RedirectURIs: []url.URL{
			{Scheme: "http", Host: "127.0.0.1:5556", Path: "/cb"},
		},
	}

	// t.Fatal(err) instead of t.Fatalf(err.Error()) throughout: error text
	// must never be treated as a format string.
	if _, err := r.New("foo", cm); err != nil {
		t.Fatal(err)
	}

	got, err := r.All()
	if err != nil {
		t.Fatal(err)
	}
	if count := len(got); count != 1 {
		t.Fatalf("Retrieved incorrect number of ClientIdentities: want=1 got=%d", count)
	}

	if diff := pretty.Compare(cm, got[0].Metadata); diff != "" {
		t.Fatalf("Retrieved incorrect ClientMetadata: Compare(want,got): %v", diff)
	}

	cm = oidc.ClientMetadata{
		RedirectURIs: []url.URL{
			{Scheme: "http", Host: "foo.com", Path: "/cb"},
		},
	}
	if _, err := r.New("bar", cm); err != nil {
		t.Fatal(err)
	}

	got, err = r.All()
	if err != nil {
		t.Fatal(err)
	}
	if count := len(got); count != 2 {
		t.Fatalf("Retrieved incorrect number of ClientIdentities: want=2 got=%d", count)
	}
}
Ejemplo n.º 6
0
Archivo: config.go Proyecto: ryanj/dex
// Configure wires srv with database-backed repositories and managers built
// from cfg, and registers a database health check on it.
//
// It returns an error when any of the following holds:
//   - no key secrets are configured,
//   - the DSN is empty,
//   - the database connection cannot be opened,
//   - the backend is not Postgres,
//   - the private key set repository cannot be created.
//
// On every failure after the connection is opened, the connection is closed
// before returning so a failed Configure does not leak it.
func (cfg *MultiServerConfig) Configure(srv *Server) error {
	if len(cfg.KeySecrets) == 0 {
		return errors.New("missing key secret")
	}

	if cfg.DatabaseConfig.DSN == "" {
		return errors.New("missing database connection string")
	}

	dbc, err := db.NewConnection(cfg.DatabaseConfig)
	if err != nil {
		return fmt.Errorf("unable to initialize database connection: %v", err)
	}
	// closeDB releases the connection pool on early-return paths. dbc is a
	// gorp DbMap (see the Dialect assertion below); Db is its *sql.DB.
	// Closing is best-effort: the configuration error is what callers need.
	closeDB := func() { _ = dbc.Db.Close() }

	if _, ok := dbc.Dialect.(gorp.PostgresDialect); !ok {
		closeDB()
		return errors.New("only postgres backend supported for multi server configurations")
	}

	kRepo, err := db.NewPrivateKeySetRepo(dbc, cfg.UseOldFormat, cfg.KeySecrets...)
	if err != nil {
		closeDB()
		return fmt.Errorf("unable to create PrivateKeySetRepo: %v", err)
	}

	ciRepo := db.NewClientIdentityRepo(dbc)
	sRepo := db.NewSessionRepo(dbc)
	skRepo := db.NewSessionKeyRepo(dbc)
	cfgRepo := db.NewConnectorConfigRepo(dbc)
	userRepo := db.NewUserRepo(dbc)
	pwiRepo := db.NewPasswordInfoRepo(dbc)
	userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), usermanager.ManagerOptions{})
	refreshTokenRepo := db.NewRefreshTokenRepo(dbc)

	sm := sessionmanager.NewSessionManager(sRepo, skRepo)

	srv.ClientIdentityRepo = ciRepo
	srv.KeySetRepo = kRepo
	srv.ConnectorConfigRepo = cfgRepo
	srv.UserRepo = userRepo
	srv.UserManager = userManager
	srv.PasswordInfoRepo = pwiRepo
	srv.SessionManager = sm
	srv.RefreshTokenRepo = refreshTokenRepo
	srv.HealthChecks = append(srv.HealthChecks, db.NewHealthChecker(dbc))
	return nil
}
Ejemplo n.º 7
0
// Configure wires srv with database-backed repositories and managers built
// from cfg.
//
// It returns an error when any of the following holds:
//   - no key secrets are configured,
//   - the DSN is empty,
//   - the database connection cannot be opened,
//   - the private key set repository cannot be created.
//
// If the connection was already opened when a later step fails, it is closed
// before returning so a failed Configure does not leak it.
func (cfg *MultiServerConfig) Configure(srv *Server) error {
	if len(cfg.KeySecrets) == 0 {
		return errors.New("missing key secret")
	}

	if cfg.DatabaseConfig.DSN == "" {
		return errors.New("missing database connection string")
	}

	dbc, err := db.NewConnection(cfg.DatabaseConfig)
	if err != nil {
		return fmt.Errorf("unable to initialize database connection: %v", err)
	}

	kRepo, err := db.NewPrivateKeySetRepo(dbc, cfg.UseOldFormat, cfg.KeySecrets...)
	if err != nil {
		// NOTE(review): assumes db.NewConnection returns a gorp DbMap whose Db
		// field is the underlying *sql.DB — confirm for this version.
		// Best-effort close; the configuration error is what callers need.
		_ = dbc.Db.Close()
		return fmt.Errorf("unable to create PrivateKeySetRepo: %v", err)
	}

	ciRepo := db.NewClientIdentityRepo(dbc)
	sRepo := db.NewSessionRepo(dbc)
	skRepo := db.NewSessionKeyRepo(dbc)
	cfgRepo := db.NewConnectorConfigRepo(dbc)
	userRepo := db.NewUserRepo(dbc)
	pwiRepo := db.NewPasswordInfoRepo(dbc)
	userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
	refreshTokenRepo := db.NewRefreshTokenRepo(dbc)

	sm := session.NewSessionManager(sRepo, skRepo)

	srv.ClientIdentityRepo = ciRepo
	srv.KeySetRepo = kRepo
	srv.ConnectorConfigRepo = cfgRepo
	srv.UserRepo = userRepo
	srv.UserManager = userManager
	srv.PasswordInfoRepo = pwiRepo
	srv.SessionManager = sm
	srv.RefreshTokenRepo = refreshTokenRepo
	return nil
}
Ejemplo n.º 8
0
// TestDBClientIdentityRepoNewDuplicate verifies that registering a second
// client under an ID that is already taken is rejected, even when the new
// metadata differs from the original.
func TestDBClientIdentityRepoNewDuplicate(t *testing.T) {
	repo := db.NewClientIdentityRepo(connect(t))

	// metaFor builds metadata with a single http redirect URI on host.
	metaFor := func(host string) oidc.ClientMetadata {
		return oidc.ClientMetadata{
			RedirectURIs: []url.URL{{Scheme: "http", Host: host}},
		}
	}

	if _, err := repo.New("foo", metaFor("foo.example.com")); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Same ID, different metadata: the repo must refuse it.
	_, err := repo.New("foo", metaFor("bar.example.com"))
	if err == nil {
		t.Fatalf("expected non-nil error")
	}
}
Ejemplo n.º 9
0
// TestDBClientIdentityRepoMetadata registers a client with two redirect URIs
// and checks that Metadata returns exactly the metadata that was stored.
func TestDBClientIdentityRepoMetadata(t *testing.T) {
	r := db.NewClientIdentityRepo(connect(t))

	cm := oidc.ClientMetadata{
		RedirectURIs: []url.URL{
			{Scheme: "http", Host: "127.0.0.1:5556", Path: "/cb"},
			{Scheme: "https", Host: "example.com", Path: "/callback"},
		},
	}

	// t.Fatal(err) rather than t.Fatalf(err.Error()): error text must not be
	// interpreted as a format string.
	if _, err := r.New("foo", cm); err != nil {
		t.Fatal(err)
	}

	got, err := r.Metadata("foo")
	if err != nil {
		t.Fatal(err)
	}
	// Guard the dereference below: a nil result with a nil error would
	// otherwise panic instead of failing the test cleanly.
	if got == nil {
		t.Fatal("Retrieved nil ClientMetadata for existing client")
	}

	if diff := pretty.Compare(cm, *got); diff != "" {
		t.Fatalf("Retrieved incorrect ClientMetadata: Compare(want,got): %v", diff)
	}
}
Ejemplo n.º 10
0
// TestClientToken drives clientTokenMiddleware with a range of Authorization
// headers, verification key sets and client repos.
//
// Only the first case expects http.StatusOK: a bearer JWT for a client in the
// repo, checked against the matching public key. Every other case expects
// http.StatusUnauthorized.
func TestClientToken(t *testing.T) {
	issuedAt := time.Now()
	expiresAt := issuedAt.Add(24 * time.Hour)
	clientID := "valid-client"

	ident := oidc.ClientIdentity{
		Credentials: oidc.ClientCredentials{
			ID:     clientID,
			Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
		},
		Metadata: oidc.ClientMetadata{
			RedirectURIs: []url.URL{
				{Scheme: "https", Host: "authn.example.com", Path: "/callback"},
			},
		},
	}
	populated, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ident})
	if err != nil {
		t.Fatalf("Failed to create client identity repo: %v", err)
	}

	privKey, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("Failed to generate private key, error=%v", err)
	}
	signer := privKey.Signer()
	pubKey := *key.NewPublicKey(privKey.JWK())

	issuer := "https://example.com"

	// sign builds claims and returns the encoded JWT signed with privKey.
	sign := func(iss, sub, aud string, iat, exp time.Time) string {
		jwt, err := jose.NewSignedJWT(oidc.NewClaims(iss, sub, aud, iat, exp), signer)
		if err != nil {
			t.Fatalf("Failed to generate JWT, error=%v", err)
		}
		return jwt.Encode()
	}
	// bearer formats token as an Authorization header value.
	bearer := func(token string) string {
		return fmt.Sprintf("BEARER %s", token)
	}

	goodJWT := sign(issuer, clientID, clientID, issuedAt, expiresAt)
	badJWT := sign("", "", "", issuedAt, expiresAt)

	cases := []struct {
		keys     []key.PublicKey
		repo     client.ClientIdentityRepo
		header   string
		wantCode int
	}{
		// valid token
		{[]key.PublicKey{pubKey}, populated, bearer(goodJWT), http.StatusOK},
		// invalid token
		{[]key.PublicKey{pubKey}, populated, bearer(badJWT), http.StatusUnauthorized},
		// empty header
		{[]key.PublicKey{pubKey}, populated, "", http.StatusUnauthorized},
		// unparsable token
		{[]key.PublicKey{pubKey}, populated, "BEARER xxx", http.StatusUnauthorized},
		// no verification keys
		{[]key.PublicKey{}, populated, bearer(goodJWT), http.StatusUnauthorized},
		// nil repo
		{[]key.PublicKey{pubKey}, nil, bearer(goodJWT), http.StatusUnauthorized},
		// empty repo
		{[]key.PublicKey{pubKey}, db.NewClientIdentityRepo(db.NewMemDB()), bearer(goodJWT), http.StatusUnauthorized},
		// client not in repo
		{[]key.PublicKey{pubKey}, populated, bearer(sign(issuer, "DOESNT-EXIST", "DOESNT-EXIST", issuedAt, expiresAt)), http.StatusUnauthorized},
	}

	for i, tc := range cases {
		keys := tc.keys
		mw := &clientTokenMiddleware{
			issuerURL: issuer,
			ciRepo:    tc.repo,
			keysFunc:  func() ([]key.PublicKey, error) { return keys, nil },
			next:      staticHandler{},
		}
		req := &http.Request{
			Header: http.Header{"Authorization": []string{tc.header}},
		}

		rec := httptest.NewRecorder()
		mw.ServeHTTP(rec, req)
		if rec.Code != tc.wantCode {
			t.Errorf("case %d: invalid response code, want=%d, got=%d", i, tc.wantCode, rec.Code)
		}
	}
}
Ejemplo n.º 11
0
func TestCreateInvalidRequest(t *testing.T) {
	u := &url.URL{Scheme: "http", Host: "example.com", Path: "clients"}
	h := http.Header{"Content-Type": []string{"application/json"}}
	repo := db.NewClientIdentityRepo(db.NewMemDB())
	res := &clientResource{repo: repo}
	tests := []struct {
		req      *http.Request
		wantCode int
		wantBody string
	}{
		// invalid content-type
		{
			req:      &http.Request{Method: "POST", URL: u, Header: http.Header{"Content-Type": []string{"application/xml"}}},
			wantCode: http.StatusBadRequest,
			wantBody: `{"error":"invalid_request","error_description":"unsupported content-type"}`,
		},
		// invalid method
		{
			req:      &http.Request{Method: "DELETE", URL: u, Header: h},
			wantCode: http.StatusMethodNotAllowed,
			wantBody: `{"error":"invalid_request","error_description":"HTTP DELETE method not supported for this resource"}`,
		},
		// invalid method
		{
			req:      &http.Request{Method: "PUT", URL: u, Header: h},
			wantCode: http.StatusMethodNotAllowed,
			wantBody: `{"error":"invalid_request","error_description":"HTTP PUT method not supported for this resource"}`,
		},
		// invalid method
		{
			req:      &http.Request{Method: "HEAD", URL: u, Header: h},
			wantCode: http.StatusMethodNotAllowed,
			wantBody: `{"error":"invalid_request","error_description":"HTTP HEAD method not supported for this resource"}`,
		},
		// unserializable body
		{
			req:      &http.Request{Method: "POST", URL: u, Header: h, Body: makeBody("asdf")},
			wantCode: http.StatusBadRequest,
			wantBody: `{"error":"invalid_request","error_description":"unable to decode request body"}`,
		},
		// empty body
		{
			req:      &http.Request{Method: "POST", URL: u, Header: h, Body: makeBody("")},
			wantCode: http.StatusBadRequest,
			wantBody: `{"error":"invalid_request","error_description":"unable to decode request body"}`,
		},
		// missing url field
		{
			req:      &http.Request{Method: "POST", URL: u, Header: h, Body: makeBody(`{"id":"foo"}`)},
			wantCode: http.StatusBadRequest,
			wantBody: `{"error":"invalid_client_metadata","error_description":"zero redirect URLs"}`,
		},
		// empty url array
		{
			req:      &http.Request{Method: "POST", URL: u, Header: h, Body: makeBody(`{"redirectURIs":[]}`)},
			wantCode: http.StatusBadRequest,
			wantBody: `{"error":"invalid_client_metadata","error_description":"zero redirect URLs"}`,
		},
		// array with empty string
		{
			req:      &http.Request{Method: "POST", URL: u, Header: h, Body: makeBody(`{"redirectURIs":[""]}`)},
			wantCode: http.StatusBadRequest,
			wantBody: `{"error":"invalid_client_metadata","error_description":"missing or invalid field: redirectURIs"}`,
		},
		// uri with unusable scheme
		{
			req:      &http.Request{Method: "POST", URL: u, Header: h, Body: makeBody(`{"redirectURIs":["asdf.com"]}`)},
			wantCode: http.StatusBadRequest,
			wantBody: `{"error":"invalid_client_metadata","error_description":"no host for uri field redirect_uris"}`,
		},
		// uri missing host
		{
			req:      &http.Request{Method: "POST", URL: u, Header: h, Body: makeBody(`{"redirectURIs":["http://"]}`)},
			wantCode: http.StatusBadRequest,
			wantBody: `{"error":"invalid_client_metadata","error_description":"no host for uri field redirect_uris"}`,
		},
	}

	for i, tt := range tests {
		w := httptest.NewRecorder()
		res.ServeHTTP(w, tt.req)

		if w.Code != tt.wantCode {
			t.Errorf("case %d: invalid response code, want=%d, got=%d", i, tt.wantCode, w.Code)
		}

		gotBody := w.Body.String()
		if gotBody != tt.wantBody {
			t.Errorf("case %d: invalid response body, want=%s, got=%s", i, tt.wantBody, gotBody)
		}
	}
}