コード例 #1
0
ファイル: manager.go プロジェクト: alexmavr/docker
// becomeLeader starts the subsystems that are run on the leader.
//
// It first seeds the store with a default cluster object and a manager Node
// entry for this node (both creations are expected to fail when the objects
// already exist), then attempts to rotate the root CA key-encrypting-key, and
// finally constructs the leader-only components and launches the key manager,
// dispatcher, CA server, allocator, scheduler, task reaper and the replicated
// and global orchestrators, each in its own goroutine running on ctx.
// Failures are logged; nothing is returned to the caller.
func (m *Manager) becomeLeader(ctx context.Context) {
	s := m.RaftNode.MemoryStore()

	rootCA := m.config.SecurityConfig.RootCA()
	nodeID := m.config.SecurityConfig.ClientTLSCreds.NodeID()

	raftCfg := raft.DefaultRaftConfig()
	raftCfg.ElectionTick = uint32(m.RaftNode.Config.ElectionTick)
	raftCfg.HeartbeatTick = uint32(m.RaftNode.Config.HeartbeatTick)

	clusterID := m.config.SecurityConfig.ClientTLSCreds.Organization()

	initialCAConfig := ca.DefaultCAConfig()
	initialCAConfig.ExternalCAs = m.config.ExternalCAs

	if err := s.Update(func(tx store.Tx) error {
		// Add a default cluster object to the
		// store. Don't check the error because
		// we expect this to fail unless this
		// is a brand new cluster.
		store.CreateCluster(tx, defaultClusterObject(clusterID, initialCAConfig, raftCfg, rootCA))
		// Add Node entry for ourself, if one
		// doesn't exist already.
		store.CreateNode(tx, managerNode(nodeID))
		return nil
	}); err != nil {
		// The callback above always returns nil, so any error here is
		// reported by Update itself; previously it was silently dropped.
		log.G(ctx).WithError(err).Error("failed to initialize default cluster and node objects")
	}

	// Attempt to rotate the key-encrypting-key of the root CA key-material.
	err := m.rotateRootCAKEK(ctx, clusterID)
	if err != nil {
		log.G(ctx).WithError(err).Error("root key-encrypting-key rotation failed")
	}

	m.replicatedOrchestrator = orchestrator.NewReplicatedOrchestrator(s)
	m.globalOrchestrator = orchestrator.NewGlobalOrchestrator(s)
	m.taskReaper = orchestrator.NewTaskReaper(s)
	m.scheduler = scheduler.New(s)
	m.keyManager = keymanager.New(s, keymanager.DefaultConfig())

	// TODO(stevvooe): Allocate a context that can be used to
	// shutdown underlying manager processes when leadership is
	// lost.

	m.allocator, err = allocator.New(s)
	if err != nil {
		log.G(ctx).WithError(err).Error("failed to create allocator")
		// Never run an allocator whose construction failed: clearing it
		// makes the nil check below skip starting it.
		m.allocator = nil
	}

	// Each component is passed to its goroutine as an argument so the
	// goroutine keeps the instance that was current at launch time, even if
	// the corresponding Manager field is later reassigned.
	if m.keyManager != nil {
		go func(km *keymanager.KeyManager) {
			if err := km.Run(ctx); err != nil {
				log.G(ctx).WithError(err).Error("keymanager failed with an error")
			}
		}(m.keyManager)
	}

	go func(d *dispatcher.Dispatcher) {
		if err := d.Run(ctx); err != nil {
			log.G(ctx).WithError(err).Error("Dispatcher exited with an error")
		}
	}(m.Dispatcher)

	go func(server *ca.Server) {
		if err := server.Run(ctx); err != nil {
			log.G(ctx).WithError(err).Error("CA signer exited with an error")
		}
	}(m.caserver)

	// Start all sub-components in separate goroutines.
	// TODO(aluzzardi): This should have some kind of error handling so that
	// any component that goes down would bring the entire manager down.
	if m.allocator != nil {
		go func(alloc *allocator.Allocator) {
			if err := alloc.Run(ctx); err != nil {
				log.G(ctx).WithError(err).Error("allocator exited with an error")
			}
		}(m.allocator)
	}

	go func(sched *scheduler.Scheduler) {
		if err := sched.Run(ctx); err != nil {
			log.G(ctx).WithError(err).Error("scheduler exited with an error")
		}
	}(m.scheduler)

	go func(reaper *orchestrator.TaskReaper) {
		reaper.Run()
	}(m.taskReaper)

	go func(o *orchestrator.ReplicatedOrchestrator) {
		if err := o.Run(ctx); err != nil {
			log.G(ctx).WithError(err).Error("replicated orchestrator exited with an error")
		}
	}(m.replicatedOrchestrator)

	go func(o *orchestrator.GlobalOrchestrator) {
		if err := o.Run(ctx); err != nil {
			log.G(ctx).WithError(err).Error("global orchestrator exited with an error")
		}
	}(m.globalOrchestrator)
}
コード例 #2
0
ファイル: manager.go プロジェクト: maxim28/docker
// Run starts all manager sub-systems and the gRPC server at the configured
// address.
// The call never returns unless an error occurs or `Stop()` is called.
//
// Besides serving, Run watches raft leadership changes. On becoming leader it
// seeds the store with the default cluster object and this node's manager
// entry, then starts the leader-only components (key manager, dispatcher, CA
// server, allocator, scheduler, task reaper, orchestrators). On becoming a
// follower it stops them again.
//
// TODO(aluzzardi): /!\ This function is *way* too complex. /!\
// It needs to be split into smaller manageable functions.
func (m *Manager) Run(parent context.Context) error {
	ctx, ctxCancel := context.WithCancel(parent)
	defer ctxCancel()

	// Harakiri: cancel our own context once m.stopped is closed.
	go func() {
		select {
		case <-ctx.Done():
		case <-m.stopped:
			ctxCancel()
		}
	}()

	leadershipCh, cancel := m.RaftNode.SubscribeLeadership()
	defer cancel()

	go func() {
		for leadershipEvent := range leadershipCh {
			// Once stopped is closed, read out and discard every remaining
			// event WITHOUT taking the mutex, so this loop cannot starve the
			// Lock() that Run performs while shutting down.
			select {
			case <-m.stopped:
				continue
			default:
				// Not stopped; handle the event below.
			}
			// We're not stopping, so NOW acquire the mutex.
			m.mu.Lock()
			newState := leadershipEvent.(raft.LeadershipState)

			switch newState {
			case raft.IsLeader:
				s := m.RaftNode.MemoryStore()

				rootCA := m.config.SecurityConfig.RootCA()
				nodeID := m.config.SecurityConfig.ClientTLSCreds.NodeID()

				raftCfg := raft.DefaultRaftConfig()
				raftCfg.ElectionTick = uint32(m.RaftNode.Config.ElectionTick)
				raftCfg.HeartbeatTick = uint32(m.RaftNode.Config.HeartbeatTick)

				clusterID := m.config.SecurityConfig.ClientTLSCreds.Organization()

				initialCAConfig := ca.DefaultCAConfig()
				initialCAConfig.ExternalCAs = m.config.ExternalCAs

				if err := s.Update(func(tx store.Tx) error {
					// Add a default cluster object to the
					// store. Don't check the error because
					// we expect this to fail unless this
					// is a brand new cluster.
					store.CreateCluster(tx, &api.Cluster{
						ID: clusterID,
						Spec: api.ClusterSpec{
							Annotations: api.Annotations{
								Name: store.DefaultClusterName,
							},
							Orchestration: api.OrchestrationConfig{
								TaskHistoryRetentionLimit: defaultTaskHistoryRetentionLimit,
							},
							Dispatcher: api.DispatcherConfig{
								HeartbeatPeriod: ptypes.DurationProto(dispatcher.DefaultHeartBeatPeriod),
							},
							Raft:     raftCfg,
							CAConfig: initialCAConfig,
						},
						RootCA: api.RootCA{
							CAKey:      rootCA.Key,
							CACert:     rootCA.Cert,
							CACertHash: rootCA.Digest.String(),
							JoinTokens: api.JoinTokens{
								Worker:  ca.GenerateJoinToken(rootCA),
								Manager: ca.GenerateJoinToken(rootCA),
							},
						},
					})
					// Add Node entry for ourself, if one
					// doesn't exist already.
					store.CreateNode(tx, &api.Node{
						ID: nodeID,
						Certificate: api.Certificate{
							CN:   nodeID,
							Role: api.NodeRoleManager,
							Status: api.IssuanceStatus{
								State: api.IssuanceStateIssued,
							},
						},
						Spec: api.NodeSpec{
							Role:       api.NodeRoleManager,
							Membership: api.NodeMembershipAccepted,
						},
					})
					return nil
				}); err != nil {
					// The callback above always returns nil, so any error here
					// is reported by Update itself; previously it was dropped.
					log.G(ctx).WithError(err).Error("failed to initialize default cluster and node objects")
				}

				// Attempt to rotate the key-encrypting-key of the root CA key-material
				err := m.rotateRootCAKEK(ctx, clusterID)
				if err != nil {
					log.G(ctx).WithError(err).Error("root key-encrypting-key rotation failed")
				}

				m.replicatedOrchestrator = orchestrator.NewReplicatedOrchestrator(s)
				m.globalOrchestrator = orchestrator.NewGlobalOrchestrator(s)
				m.taskReaper = orchestrator.NewTaskReaper(s)
				m.scheduler = scheduler.New(s)
				m.keyManager = keymanager.New(s, keymanager.DefaultConfig())

				// TODO(stevvooe): Allocate a context that can be used to
				// shutdown underlying manager processes when leadership is
				// lost.

				m.allocator, err = allocator.New(s)
				if err != nil {
					log.G(ctx).WithError(err).Error("failed to create allocator")
					// Never run an allocator whose construction failed:
					// clearing it makes the nil checks skip it.
					m.allocator = nil
				}

				// Components are passed as arguments so each goroutine keeps
				// the instance current at launch, even after the follower
				// branch below resets the Manager fields to nil.
				if m.keyManager != nil {
					go func(km *keymanager.KeyManager) {
						if err := km.Run(ctx); err != nil {
							log.G(ctx).WithError(err).Error("keymanager failed with an error")
						}
					}(m.keyManager)
				}

				go func(d *dispatcher.Dispatcher) {
					if err := d.Run(ctx); err != nil {
						log.G(ctx).WithError(err).Error("Dispatcher exited with an error")
					}
				}(m.Dispatcher)

				go func(server *ca.Server) {
					if err := server.Run(ctx); err != nil {
						log.G(ctx).WithError(err).Error("CA signer exited with an error")
					}
				}(m.caserver)

				// Start all sub-components in separate goroutines.
				// TODO(aluzzardi): This should have some kind of error handling so that
				// any component that goes down would bring the entire manager down.

				if m.allocator != nil {
					go func(alloc *allocator.Allocator) {
						if err := alloc.Run(ctx); err != nil {
							log.G(ctx).WithError(err).Error("allocator exited with an error")
						}
					}(m.allocator)
				}

				go func(sched *scheduler.Scheduler) {
					if err := sched.Run(ctx); err != nil {
						log.G(ctx).WithError(err).Error("scheduler exited with an error")
					}
				}(m.scheduler)

				go func(reaper *orchestrator.TaskReaper) {
					reaper.Run()
				}(m.taskReaper)

				go func(o *orchestrator.ReplicatedOrchestrator) {
					if err := o.Run(ctx); err != nil {
						log.G(ctx).WithError(err).Error("replicated orchestrator exited with an error")
					}
				}(m.replicatedOrchestrator)

				go func(o *orchestrator.GlobalOrchestrator) {
					if err := o.Run(ctx); err != nil {
						log.G(ctx).WithError(err).Error("global orchestrator exited with an error")
					}
				}(m.globalOrchestrator)

			case raft.IsFollower:
				m.Dispatcher.Stop()
				m.caserver.Stop()

				// The components below are only set while this node is
				// leader (and are reset to nil here), so guard each of them
				// the same way instead of dereferencing some unconditionally.
				if m.allocator != nil {
					m.allocator.Stop()
					m.allocator = nil
				}

				if m.replicatedOrchestrator != nil {
					m.replicatedOrchestrator.Stop()
					m.replicatedOrchestrator = nil
				}

				if m.globalOrchestrator != nil {
					m.globalOrchestrator.Stop()
					m.globalOrchestrator = nil
				}

				if m.taskReaper != nil {
					m.taskReaper.Stop()
					m.taskReaper = nil
				}

				if m.scheduler != nil {
					m.scheduler.Stop()
					m.scheduler = nil
				}

				if m.keyManager != nil {
					m.keyManager.Stop()
					m.keyManager = nil
				}
			}
			m.mu.Unlock()
		}
	}()

	proxyOpts := []grpc.DialOption{
		grpc.WithTimeout(5 * time.Second),
		grpc.WithTransportCredentials(m.config.SecurityConfig.ClientTLSCreds),
	}

	cs := raftpicker.NewConnSelector(m.RaftNode, proxyOpts...)
	m.connSelector = cs

	// We need special connSelector for controlapi because it provides automatic
	// leader tracking.
	// Other APIs are using connSelector which errors out on leader change, but
	// allows to react quickly to reelections.
	controlAPIProxyOpts := []grpc.DialOption{
		grpc.WithBackoffMaxDelay(time.Second),
		grpc.WithTransportCredentials(m.config.SecurityConfig.ClientTLSCreds),
	}

	controlAPIConnSelector := hackpicker.NewConnSelector(m.RaftNode, controlAPIProxyOpts...)

	authorize := func(ctx context.Context, roles []string) error {
		// Authorize the remote roles, ensure they can only be forwarded by managers
		_, err := ca.AuthorizeForwardedRoleAndOrg(ctx, roles, []string{ca.ManagerRole}, m.config.SecurityConfig.ClientTLSCreds.Organization())
		return err
	}

	baseControlAPI := controlapi.NewServer(m.RaftNode.MemoryStore(), m.RaftNode, m.config.SecurityConfig.RootCA())
	healthServer := health.NewHealthServer()

	authenticatedControlAPI := api.NewAuthenticatedWrapperControlServer(baseControlAPI, authorize)
	authenticatedDispatcherAPI := api.NewAuthenticatedWrapperDispatcherServer(m.Dispatcher, authorize)
	authenticatedCAAPI := api.NewAuthenticatedWrapperCAServer(m.caserver, authorize)
	authenticatedNodeCAAPI := api.NewAuthenticatedWrapperNodeCAServer(m.caserver, authorize)
	authenticatedRaftAPI := api.NewAuthenticatedWrapperRaftServer(m.RaftNode, authorize)
	authenticatedHealthAPI := api.NewAuthenticatedWrapperHealthServer(healthServer, authorize)
	authenticatedRaftMembershipAPI := api.NewAuthenticatedWrapperRaftMembershipServer(m.RaftNode, authorize)

	proxyDispatcherAPI := api.NewRaftProxyDispatcherServer(authenticatedDispatcherAPI, cs, m.RaftNode, ca.WithMetadataForwardTLSInfo)
	proxyCAAPI := api.NewRaftProxyCAServer(authenticatedCAAPI, cs, m.RaftNode, ca.WithMetadataForwardTLSInfo)
	proxyNodeCAAPI := api.NewRaftProxyNodeCAServer(authenticatedNodeCAAPI, cs, m.RaftNode, ca.WithMetadataForwardTLSInfo)
	proxyRaftMembershipAPI := api.NewRaftProxyRaftMembershipServer(authenticatedRaftMembershipAPI, cs, m.RaftNode, ca.WithMetadataForwardTLSInfo)

	// localProxyControlAPI is a special kind of proxy. It is only wired up
	// to receive requests from a trusted local socket, and these requests
	// don't use TLS, therefore the requests it handles locally should
	// bypass authorization. When it proxies, it sends them as requests from
	// this manager rather than forwarded requests (it has no TLS
	// information to put in the metadata map).
	forwardAsOwnRequest := func(ctx context.Context) (context.Context, error) { return ctx, nil }
	localProxyControlAPI := api.NewRaftProxyControlServer(baseControlAPI, controlAPIConnSelector, m.RaftNode, forwardAsOwnRequest)

	// Everything registered on m.server should be an authenticated
	// wrapper, or a proxy wrapping an authenticated wrapper!
	api.RegisterCAServer(m.server, proxyCAAPI)
	api.RegisterNodeCAServer(m.server, proxyNodeCAAPI)
	api.RegisterRaftServer(m.server, authenticatedRaftAPI)
	api.RegisterHealthServer(m.server, authenticatedHealthAPI)
	api.RegisterRaftMembershipServer(m.server, proxyRaftMembershipAPI)
	api.RegisterControlServer(m.localserver, localProxyControlAPI)
	api.RegisterControlServer(m.server, authenticatedControlAPI)
	api.RegisterDispatcherServer(m.server, proxyDispatcherAPI)

	// One buffer slot per listener: every Serve goroutine below sends exactly
	// one result, and Run only receives the first, so a fixed capacity
	// smaller than the listener count would leave goroutines blocked forever.
	errServe := make(chan error, len(m.listeners))
	for proto, l := range m.listeners {
		go func(proto string, lis net.Listener) {
			ctx := log.WithLogger(ctx, log.G(ctx).WithFields(
				logrus.Fields{
					"proto": lis.Addr().Network(),
					"addr":  lis.Addr().String()}))
			if proto == "unix" {
				log.G(ctx).Info("Listening for local connections")
				// we need to disallow double closes because UnixListener.Close
				// can delete unix-socket file of newer listener. grpc calls
				// Close twice indeed: in Serve and in Stop.
				errServe <- m.localserver.Serve(&closeOnceListener{Listener: lis})
			} else {
				log.G(ctx).Info("Listening for connections")
				errServe <- m.server.Serve(lis)
			}
		}(proto, l)
	}

	// Set the raft server as serving for the health server
	healthServer.SetServingStatus("Raft", api.HealthCheckResponse_SERVING)

	if err := m.RaftNode.JoinAndStart(); err != nil {
		for _, lis := range m.listeners {
			lis.Close()
		}
		return fmt.Errorf("can't initialize raft node: %v", err)
	}

	close(m.started)

	go func() {
		err := m.RaftNode.Run(ctx)
		if err != nil {
			log.G(ctx).Error(err)
			m.Stop(ctx)
		}
	}()

	if err := raft.WaitForLeader(ctx, m.RaftNode); err != nil {
		m.server.Stop()
		return err
	}

	c, err := raft.WaitForCluster(ctx, m.RaftNode)
	if err != nil {
		m.server.Stop()
		return err
	}
	raftConfig := c.Spec.Raft

	if int(raftConfig.ElectionTick) != m.RaftNode.Config.ElectionTick {
		log.G(ctx).Warningf("election tick value (%ds) is different from the one defined in the cluster config (%vs), the cluster may be unstable", m.RaftNode.Config.ElectionTick, raftConfig.ElectionTick)
	}
	if int(raftConfig.HeartbeatTick) != m.RaftNode.Config.HeartbeatTick {
		log.G(ctx).Warningf("heartbeat tick value (%ds) is different from the one defined in the cluster config (%vs), the cluster may be unstable", m.RaftNode.Config.HeartbeatTick, raftConfig.HeartbeatTick)
	}

	// wait for an error in serving.
	err = <-errServe
	select {
	// Check whether stopped has been closed. If so, we're in the process of
	// stopping (or done), and that's why we got the error. If stopping is
	// deliberate, stopped will ALWAYS be closed before the error is
	// triggered, so this path will ALWAYS be taken if the stop was deliberate.
	case <-m.stopped:
		// Shutdown was requested, so do not return an error. But first, wait
		// to acquire the mutex to guarantee that stopping has finished: as
		// long as we acquire the mutex BEFORE we return, we know that
		// stopping is complete.
		m.mu.Lock()
		m.mu.Unlock()
		return nil
	// Otherwise, the error from errServe indicates that an error in serving
	// has actually occurred and this isn't a planned shutdown.
	default:
		return err
	}
}
コード例 #3
0
ファイル: storage_test.go プロジェクト: docker/swarmkit
// TestRaftEncryptionKeyRotationWait walks a single raft node through several
// data-encryption-key rotations. After each step it checks whether the
// rotation is reported as finished, which key combinations let a copy of the
// node start from its snapshot and WALs, and that all proposed values survive
// restarts.
func TestRaftEncryptionKeyRotationWait(t *testing.T) {
	t.Parallel()
	nodes := make(map[uint64]*raftutils.TestNode)
	var clockSource *fakeclock.FakeClock

	raftConfig := raft.DefaultRaftConfig()
	nodes[1], clockSource = raftutils.NewInitNode(t, tc, &raftConfig)
	defer raftutils.TeardownCluster(t, nodes)

	nodeIDs := []string{"id1", "id2", "id3"}
	values := make([]*api.Node, len(nodeIDs))

	// Seed the log with the three initial values.
	for i, id := range nodeIDs {
		v, err := raftutils.ProposeValue(t, nodes[1], DefaultProposalTime, id)
		require.NoError(t, err, "failed to propose value")
		values[i] = v
	}

	// proposeAndCheck records a new expected ID, proposes it through node 1,
	// records the resulting value and verifies that every node has it.
	proposeAndCheck := func(id string) {
		nodeIDs = append(nodeIDs, id)
		v, err := raftutils.ProposeValue(t, nodes[1], DefaultProposalTime, id)
		require.NoError(t, err, "failed to propose value")
		values = append(values, v)
		raftutils.CheckValuesOnNodes(t, clockSource, nodes, nodeIDs, values)
	}

	// queueRotation stages key as node 1's pending key and pokes the rotator.
	queueRotation := func(key string) {
		nodes[1].KeyRotator.QueuePendingKey([]byte(key))
		nodes[1].KeyRotator.RotationNotify() <- struct{}{}
	}

	// stopNode1 shuts down node 1's server and its raft instance.
	stopNode1 := func() {
		nodes[1].Server.Stop()
		nodes[1].ShutdownRaft()
	}

	// requireRestartFails copies node 1 with a key rotator holding keys and
	// asserts that the copy cannot join and start.
	requireRestartFails := func(keys raft.EncryptionKeys, msg string) {
		copied, copyCtx := raftutils.CopyNode(t, clockSource, nodes[1], false, raftutils.NewSimpleKeyRotator(keys))
		require.Error(t, copied.Node.JoinAndStart(copyCtx), msg)
	}

	snapDir := filepath.Join(nodes[1].StateDir, "snap-v3-encrypted")
	startingKeys := nodes[1].KeyRotator.GetKeys()

	// Rotating to key2 should trigger a snapshot, whose completion in turn
	// lets the rotator finish.
	queueRotation("key2")
	require.NoError(t, raftutils.PollFunc(clockSource, func() error {
		snapshots, err := storage.ListSnapshots(snapDir)
		switch {
		case err != nil:
			return err
		case len(snapshots) != 1:
			return fmt.Errorf("expected 1 snapshot, found %d on new node", len(snapshots))
		case nodes[1].KeyRotator.NeedsRotation():
			return fmt.Errorf("rotation never finished")
		}
		return nil
	}))
	raftutils.CheckValuesOnNodes(t, clockSource, nodes, nodeIDs, values)

	proposeAndCheck("id4")
	stopNode1()

	// With only the original keys, the snapshot cannot be read.
	requireRestartFails(startingKeys, "should not have been able to restart since we can't read snapshots")

	// Offering key2 merely as the pending key is enough to start.
	newKeys := startingKeys
	newKeys.PendingDEK = []byte("key2")
	nodes[1].KeyRotator = raftutils.NewSimpleKeyRotator(newKeys)
	nodes[1] = raftutils.RestartNode(t, clockSource, nodes[1], false)
	raftutils.WaitForCluster(t, clockSource, nodes)

	// Having joined, the node should already have finished rotating.
	require.False(t, nodes[1].KeyRotator.NeedsRotation())
	raftutils.CheckValuesOnNodes(t, clockSource, nodes, nodeIDs, values)

	// Replace the snapshot directory with a plain file so snapshotting
	// breaks, which should keep the next rotation from ever finishing.
	tempSnapDir := filepath.Join(nodes[1].StateDir, "snap-backup")
	require.NoError(t, os.Rename(snapDir, tempSnapDir))
	require.NoError(t, ioutil.WriteFile(snapDir, []byte("this is no longer a directory"), 0644))

	queueRotation("key3")
	time.Sleep(250 * time.Millisecond)
	require.True(t, nodes[1].KeyRotator.NeedsRotation())

	// This proposal makes sure some WAL entries are written with key3.
	proposeAndCheck("id5")
	stopNode1()

	// Put the real snapshot directory back.
	require.NoError(t, os.RemoveAll(snapDir))
	require.NoError(t, os.Rename(tempSnapDir, snapDir))

	// The WALs now mix key2 and key3 entries, so neither key works alone.
	requireRestartFails(raft.EncryptionKeys{CurrentDEK: []byte("key2")},
		"should not have been able to restart since we can't read all the WALs, even if we can read the snapshot")
	requireRestartFails(raft.EncryptionKeys{CurrentDEK: []byte("key3")},
		"should not have been able to restart since we can't read all the WALs, and also not the snapshot")

	restarted, ctx := raftutils.CopyNode(t, clockSource, nodes[1], false,
		raftutils.NewSimpleKeyRotator(raft.EncryptionKeys{
			CurrentDEK: []byte("key2"),
			PendingDEK: []byte("key3"),
		}))
	nodes[1] = restarted
	require.NoError(t, nodes[1].Node.JoinAndStart(ctx))

	// Loading succeeds, but rotation still waits on a fresh snapshot.
	snapshots, err := storage.ListSnapshots(snapDir)
	require.NoError(t, err)
	require.Len(t, snapshots, 1, "expected 1 snapshot")
	require.True(t, nodes[1].KeyRotator.NeedsRotation())
	currSnapshot := snapshots[0]

	// waitForNewSnapshot polls until exactly one snapshot exists and it
	// differs from currSnapshot (optionally also requiring the rotation to be
	// finished), then remembers it as the new currSnapshot.
	waitForNewSnapshot := func(mustFinishRotation bool) {
		require.NoError(t, raftutils.PollFunc(clockSource, func() error {
			snaps, err := storage.ListSnapshots(snapDir)
			if err != nil {
				return err
			}
			if len(snaps) != 1 {
				return fmt.Errorf("expected 1 snapshots, found %d on new node", len(snaps))
			}
			if snaps[0] == currSnapshot {
				return fmt.Errorf("new snapshot not done yet")
			}
			if mustFinishRotation && nodes[1].KeyRotator.NeedsRotation() {
				return fmt.Errorf("rotation never finished")
			}
			currSnapshot = snaps[0]
			return nil
		}))
	}

	// requireSnapshotUnchanged asserts that the only snapshot is still
	// currSnapshot.
	requireSnapshotUnchanged := func() {
		snaps, err := storage.ListSnapshots(snapDir)
		require.NoError(t, err)
		require.Len(t, snaps, 1)
		require.Equal(t, currSnapshot, snaps[0])
	}

	// Running the node should let it repair itself.
	go nodes[1].Node.Run(ctx)
	raftutils.WaitForCluster(t, clockSource, nodes)
	waitForNewSnapshot(true)
	raftutils.CheckValuesOnNodes(t, clockSource, nodes, nodeIDs, values)

	// When persisting the keys fails, the rotation stays pending until a
	// later snapshot.
	nodes[1].KeyRotator.SetUpdateFunc(func() error { return fmt.Errorf("nope!") })
	queueRotation("key4")
	waitForNewSnapshot(false)
	require.True(t, nodes[1].KeyRotator.NeedsRotation())

	// Once updates work again, a further proposal lets the rotation finish
	// without producing another snapshot.
	nodes[1].KeyRotator.SetUpdateFunc(nil)
	proposeAndCheck("id6")
	require.NoError(t, raftutils.PollFunc(clockSource, func() error {
		if nodes[1].KeyRotator.NeedsRotation() {
			return fmt.Errorf("rotation never finished")
		}
		return nil
	}))
	requireSnapshotUnchanged()

	// A spurious "needs rotation" with no pending key must not trigger a
	// snapshot either.
	fakeTrue := true
	nodes[1].KeyRotator.SetNeedsRotation(&fakeTrue)
	nodes[1].KeyRotator.RotationNotify() <- struct{}{}
	proposeAndCheck("id7")
	requireSnapshotUnchanged()

	// Restarting with key4 alone must recover every value.
	stopNode1()
	nodes[1].KeyRotator = raftutils.NewSimpleKeyRotator(raft.EncryptionKeys{
		CurrentDEK: []byte("key4"),
	})
	nodes[1] = raftutils.RestartNode(t, clockSource, nodes[1], false)
	raftutils.WaitForCluster(t, clockSource, nodes)
	raftutils.CheckValuesOnNodes(t, clockSource, nodes, nodeIDs, values)
}