Beispiel #1
0
// Configure validates the multi-server configuration, opens a database
// connection from cfg.DatabaseConfig and wires database-backed repositories,
// managers and a database health check into srv.
//
// It returns an error when:
//   - no key secrets are configured,
//   - the DSN is empty,
//   - the connection cannot be established,
//   - the connection's dialect is not Postgres, or
//   - the private key set repository cannot be created.
//
// Underlying errors are wrapped with %w so callers can inspect them with
// errors.Is / errors.As. If a check fails after the connection is opened,
// the connection is closed before returning. srv is only modified once every
// check has passed.
func (cfg *MultiServerConfig) Configure(srv *Server) error {
	if len(cfg.KeySecrets) == 0 {
		return errors.New("missing key secret")
	}

	if cfg.DatabaseConfig.DSN == "" {
		return errors.New("missing database connection string")
	}

	dbc, err := db.NewConnection(cfg.DatabaseConfig)
	if err != nil {
		return fmt.Errorf("unable to initialize database connection: %w", err)
	}
	// closeDB releases the connection on the early-return paths below, where
	// nothing else holds a reference to it. The Close error is deliberately
	// ignored: the configuration error being returned is the one the caller
	// needs to see.
	closeDB := func() { _ = dbc.Db.Close() }

	// Only the Postgres dialect is accepted for multi-server configurations.
	if _, ok := dbc.Dialect.(gorp.PostgresDialect); !ok {
		closeDB()
		return errors.New("only postgres backend supported for multi server configurations")
	}

	// The first secret is the active encryption key; the rest are accepted
	// for decryption (see the --key-secrets flag description).
	kRepo, err := db.NewPrivateKeySetRepo(dbc, cfg.UseOldFormat, cfg.KeySecrets...)
	if err != nil {
		closeDB()
		return fmt.Errorf("unable to create PrivateKeySetRepo: %w", err)
	}

	ciRepo := db.NewClientRepo(dbc)
	sRepo := db.NewSessionRepo(dbc)
	skRepo := db.NewSessionKeyRepo(dbc)
	cfgRepo := db.NewConnectorConfigRepo(dbc)
	userRepo := db.NewUserRepo(dbc)
	pwiRepo := db.NewPasswordInfoRepo(dbc)
	userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), usermanager.ManagerOptions{})
	clientManager := clientmanager.NewClientManager(ciRepo, db.TransactionFactory(dbc), clientmanager.ManagerOptions{})
	refreshTokenRepo := db.NewRefreshTokenRepo(dbc)

	sm := sessionmanager.NewSessionManager(sRepo, skRepo)

	srv.ClientRepo = ciRepo
	srv.ClientManager = clientManager
	srv.KeySetRepo = kRepo
	srv.ConnectorConfigRepo = cfgRepo
	srv.UserRepo = userRepo
	srv.UserManager = userManager
	srv.PasswordInfoRepo = pwiRepo
	srv.SessionManager = sm
	srv.RefreshTokenRepo = refreshTokenRepo
	srv.HealthChecks = append(srv.HealthChecks, db.NewHealthChecker(dbc))
	srv.dbMap = dbc
	return nil
}
Beispiel #2
0
// Configure validates the multi-server configuration, opens a database
// connection from cfg.DatabaseConfig and wires database-backed repositories
// and managers into srv.
//
// It returns an error when:
//   - no key secrets are configured,
//   - the DSN is empty,
//   - the connection cannot be established, or
//   - the private key set repository cannot be created.
//
// Underlying errors are wrapped with %w so callers can inspect them with
// errors.Is / errors.As. If key-repo creation fails, the connection is closed
// before returning. srv is only modified once every check has passed.
func (cfg *MultiServerConfig) Configure(srv *Server) error {
	if len(cfg.KeySecrets) == 0 {
		return errors.New("missing key secret")
	}

	if cfg.DatabaseConfig.DSN == "" {
		return errors.New("missing database connection string")
	}

	dbc, err := db.NewConnection(cfg.DatabaseConfig)
	if err != nil {
		return fmt.Errorf("unable to initialize database connection: %w", err)
	}

	kRepo, err := db.NewPrivateKeySetRepo(dbc, cfg.UseOldFormat, cfg.KeySecrets...)
	if err != nil {
		// Nothing else references the connection yet, so release it here.
		// The Close error is ignored in favour of the configuration error.
		_ = dbc.Db.Close()
		return fmt.Errorf("unable to create PrivateKeySetRepo: %w", err)
	}

	ciRepo := db.NewClientIdentityRepo(dbc)
	sRepo := db.NewSessionRepo(dbc)
	skRepo := db.NewSessionKeyRepo(dbc)
	cfgRepo := db.NewConnectorConfigRepo(dbc)
	userRepo := db.NewUserRepo(dbc)
	pwiRepo := db.NewPasswordInfoRepo(dbc)
	userManager := user.NewManager(userRepo, pwiRepo, db.TransactionFactory(dbc), user.ManagerOptions{})
	refreshTokenRepo := db.NewRefreshTokenRepo(dbc)

	sm := session.NewSessionManager(sRepo, skRepo)

	srv.ClientIdentityRepo = ciRepo
	srv.KeySetRepo = kRepo
	srv.ConnectorConfigRepo = cfgRepo
	srv.UserRepo = userRepo
	srv.UserManager = userManager
	srv.PasswordInfoRepo = pwiRepo
	srv.SessionManager = sm
	srv.RefreshTokenRepo = refreshTokenRepo
	return nil
}
Beispiel #3
0
// TestDBPrivateKeySetRepoSetGet writes a private key set through one
// PrivateKeySetRepo and reads it back through another repo on the same
// database, configured with a possibly different list of secrets.
//
// The read must succeed and return an identical key set when the reader holds
// the secret the writer encrypted with. It must fail when the reader holds
// none of the writer's secrets.
func TestDBPrivateKeySetRepoSetGet(t *testing.T) {
	s1 := []byte("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
	s2 := []byte("oooooooooooooooooooooooooooooooo")
	s3 := []byte("wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww")

	keys := make([]*key.PrivateKey, 0, 2)
	for i := 0; i < 2; i++ {
		k, err := key.GeneratePrivateKey()
		if err != nil {
			t.Fatalf("Unable to generate RSA key: %v", err)
		}
		keys = append(keys, k)
	}

	ks := key.NewPrivateKeySet(
		[]*key.PrivateKey{keys[0], keys[1]}, time.Now().Add(time.Minute))

	tests := []struct {
		setSecrets [][]byte
		getSecrets [][]byte
		wantErr    bool
	}{
		{
			// same secrets used to encrypt, decrypt
			setSecrets: [][]byte{s1, s2},
			getSecrets: [][]byte{s1, s2},
		},
		{
			// setSecrets got rotated, but getSecrets didn't yet; the reader
			// still holds s2, which the writer now uses first.
			setSecrets: [][]byte{s2, s3},
			getSecrets: [][]byte{s1, s2},
		},
		{
			// getSecrets doesn't have s3
			setSecrets: [][]byte{s3},
			getSecrets: [][]byte{s1, s2},
			wantErr:    true,
		},
	}

	for i, tt := range tests {
		dbMap := connect(t)
		setRepo, err := db.NewPrivateKeySetRepo(dbMap, false, tt.setSecrets...)
		if err != nil {
			// Constant format string: err's text may itself contain '%'.
			t.Fatalf("case %d: unable to create set repo: %v", i, err)
		}

		getRepo, err := db.NewPrivateKeySetRepo(dbMap, false, tt.getSecrets...)
		if err != nil {
			t.Fatalf("case %d: unable to create get repo: %v", i, err)
		}

		if err := setRepo.Set(ks); err != nil {
			t.Fatalf("case %d: Unexpected error: %v", i, err)
		}

		got, err := getRepo.Get()
		if tt.wantErr {
			if err == nil {
				t.Errorf("case %d: want err, got nil", i)
			}
			continue
		}
		if err != nil {
			t.Fatalf("case %d: Unexpected error: %v", i, err)
		}

		if diff := pretty.Compare(ks, got); diff != "" {
			t.Fatalf("case %d: Retrieved incorrect KeySet: Compare(want,got): %v", i, diff)
		}
	}
}
Beispiel #4
0
// main runs dex-overlord. It:
//   - parses command-line flags, then applies environment overrides via
//     pflag.SetFlagsFromEnv with the "DEX_OVERLORD" prefix;
//   - opens the database and optionally runs migrations, retrying until
//     they succeed;
//   - waits until the private key set repository can be read with the
//     configured key secrets;
//   - serves the admin API while running the database garbage collector and
//     the private key rotator.
func main() {
	fs := flag.NewFlagSet("dex-overlord", flag.ExitOnError)

	keySecrets := pflag.NewBase64List(32)
	fs.Var(keySecrets, "key-secrets", "A comma-separated list of base64 encoded 32 byte strings used as symmetric keys used to encrypt/decrypt signing key data in DB. The first key is considered the active key and used for encryption, while the others are used to decrypt.")

	useOldFormat := fs.Bool("use-deprecated-secret-format", false, "In prior releases, the database used AES-CBC to encrypt keys. New deployments should use the default AES-GCM encryption.")

	dbURL := fs.String("db-url", "", "DSN-formatted database connection string")

	dbMigrate := fs.Bool("db-migrate", true, "perform database migrations when starting up overlord. This includes the initial DB objects creation.")

	keyPeriod := fs.Duration("key-period", 24*time.Hour, "length of time for-which a given key will be valid")
	gcInterval := fs.Duration("gc-interval", time.Hour, "length of time between garbage collection runs")

	adminListen := fs.String("admin-listen", "http://127.0.0.1:5557", "scheme, host and port for listening for administrative operation requests ")

	adminAPISecret := pflag.NewBase64(server.AdminAPISecretLength)
	fs.Var(adminAPISecret, "admin-api-secret", fmt.Sprintf("A base64-encoded %d byte string which is used to protect the Admin API.", server.AdminAPISecretLength))

	localConnectorID := fs.String("local-connector", "local", "ID of the local connector")
	logDebug := fs.Bool("log-debug", false, "log debug-level information")
	logTimestamps := fs.Bool("log-timestamps", false, "prefix log lines with timestamps")

	printVersion := fs.Bool("version", false, "Print the version and exit")

	if err := fs.Parse(os.Args[1:]); err != nil {
		fmt.Fprintln(os.Stderr, err.Error())
		os.Exit(1)
	}

	// Environment variables are applied after command-line parsing.
	if err := pflag.SetFlagsFromEnv(fs, "DEX_OVERLORD"); err != nil {
		fmt.Fprintln(os.Stderr, err.Error())
		os.Exit(1)
	}

	if *printVersion {
		fmt.Printf("dex version %s\ngo version %s\n", strings.TrimPrefix(version, "v"), strings.TrimPrefix(runtime.Version(), "go"))
		os.Exit(0)
	}

	if *logDebug {
		log.EnableDebug()
	}
	if *logTimestamps {
		log.EnableTimestamps()
	}

	// Only the host part of the URL is used as the listen address below.
	adminURL, err := url.Parse(*adminListen)
	if err != nil {
		log.Fatalf("Unable to use --admin-listen flag: %v", err)
	}

	if len(keySecrets.BytesSlice()) == 0 {
		log.Fatalf("Must specify at least one key secret")
	}

	dbCfg := db.Config{
		DSN:                *dbURL,
		MaxIdleConnections: 1,
		MaxOpenConnections: 1,
	}
	dbc, err := db.NewConnection(dbCfg)
	if err != nil {
		// Constant format string: err's text may itself contain '%'.
		log.Fatalf("Unable to initialize database connection: %v", err)
	}

	if *dbMigrate {
		// Retry migrations with exponential backoff (bounded by one minute
		// via ptime.ExpBackoff) until they succeed.
		var sleep time.Duration
		for {
			var err error
			var migrations int
			if migrations, err = db.MigrateToLatest(dbc); err == nil {
				log.Infof("Performed %d db migrations", migrations)
				break
			}
			sleep = ptime.ExpBackoff(sleep, time.Minute)
			log.Errorf("Unable to migrate database, retrying in %v: %v", sleep, err)
			time.Sleep(sleep)
		}
	}

	userRepo := db.NewUserRepo(dbc)
	pwiRepo := db.NewPasswordInfoRepo(dbc)
	connCfgRepo := db.NewConnectorConfigRepo(dbc)
	userManager := manager.NewUserManager(userRepo,
		pwiRepo, connCfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
	adminAPI := admin.NewAdminAPI(userManager, userRepo, pwiRepo, *localConnectorID)
	kRepo, err := db.NewPrivateKeySetRepo(dbc, *useOldFormat, keySecrets.BytesSlice()...)
	if err != nil {
		log.Fatalf("Unable to create PrivateKeySetRepo: %v", err)
	}

	// Wait until the key repository is readable.
	//   - Success and an empty repository (key.ErrorNoKeys) both proceed.
	//   - Keys that none of the secrets can decrypt are fatal.
	//   - Any other error is retried with backoff.
	var sleep time.Duration
	for {
		var done bool
		_, err := kRepo.Get()
		switch err {
		case nil:
			done = true
		case key.ErrorNoKeys:
			done = true
		case db.ErrorCannotDecryptKeys:
			log.Fatalf("Cannot decrypt keys using any of the given key secrets. The key secrets must be changed to include one that can decrypt the existing keys, or the existing keys must be deleted.")
		}

		if done {
			break
		}
		sleep = ptime.ExpBackoff(sleep, time.Minute)
		log.Errorf("Unable to get keys from repository, retrying in %v: %v", sleep, err)
		time.Sleep(sleep)
	}

	krot := key.NewPrivateKeyRotator(kRepo, *keyPeriod)
	s := server.NewAdminServer(adminAPI, krot, adminAPISecret.String())
	h := s.HTTPHandler()
	httpsrv := &http.Server{
		Addr:    adminURL.Host,
		Handler: h,
	}

	gc := db.NewGarbageCollector(dbc, *gcInterval)

	log.Infof("Binding to %s...", httpsrv.Addr)
	go func() {
		// A listener failure terminates the whole process.
		log.Fatal(httpsrv.ListenAndServe())
	}()

	gc.Run()
	// Block on the rotator's channel; NOTE(review): presumably this keeps the
	// process alive for as long as rotation runs — confirm krot.Run semantics.
	<-krot.Run()
}