func (s *session) heartbeat(ctx context.Context) error { log.G(ctx).Debugf("(*session).heartbeat") client := api.NewDispatcherClient(s.conn) heartbeat := time.NewTimer(1) // send out a heartbeat right away defer heartbeat.Stop() for { select { case <-heartbeat.C: heartbeatCtx, cancel := context.WithTimeout(ctx, dispatcherRPCTimeout) resp, err := client.Heartbeat(heartbeatCtx, &api.HeartbeatRequest{ SessionID: s.sessionID, }) cancel() if err != nil { if grpc.Code(err) == codes.NotFound { err = errNodeNotRegistered } return err } period, err := ptypes.Duration(&resp.Period) if err != nil { return err } heartbeat.Reset(period) case <-s.closed: return errSessionClosed case <-ctx.Done(): return ctx.Err() } } }
// start begins the session and returns the first SessionMessage. func (s *session) start(ctx context.Context) error { log.G(ctx).Debugf("(*session).start") client := api.NewDispatcherClient(s.agent.config.Conn) description, err := s.agent.config.Executor.Describe(ctx) if err != nil { log.G(ctx).WithError(err).WithField("executor", s.agent.config.Executor). Errorf("node description unavailable") return err } // Override hostname if s.agent.config.Hostname != "" { description.Hostname = s.agent.config.Hostname } stream, err := client.Session(ctx, &api.SessionRequest{ Description: description, }) if err != nil { return err } msg, err := stream.Recv() if err != nil { return err } s.sessionID = msg.SessionID s.session = stream return s.handleSessionMessage(ctx, msg) }
// start begins the session and returns the first SessionMessage. func (s *session) start(ctx context.Context) error { log.G(ctx).Debugf("(*session).start") description, err := s.agent.config.Executor.Describe(ctx) if err != nil { log.G(ctx).WithError(err).WithField("executor", s.agent.config.Executor). Errorf("node description unavailable") return err } // Override hostname if s.agent.config.Hostname != "" { description.Hostname = s.agent.config.Hostname } errChan := make(chan error, 1) var ( msg *api.SessionMessage stream api.Dispatcher_SessionClient ) // Note: we don't defer cancellation of this context, because the // streaming RPC is used after this function returned. We only cancel // it in the timeout case to make sure the goroutine completes. sessionCtx, cancelSession := context.WithCancel(ctx) // Need to run Session in a goroutine since there's no way to set a // timeout for an individual Recv call in a stream. go func() { client := api.NewDispatcherClient(s.conn) stream, err = client.Session(sessionCtx, &api.SessionRequest{ Description: description, SessionID: s.sessionID, }) if err != nil { errChan <- err return } msg, err = stream.Recv() errChan <- err }() select { case err := <-errChan: if err != nil { return err } case <-time.After(dispatcherRPCTimeout): cancelSession() return errors.New("session initiation timed out") } s.sessionID = msg.SessionID s.session = stream return s.handleSessionMessage(ctx, msg) }
// sendTaskStatus uses the current session to send the status of a single task. func (s *session) sendTaskStatus(ctx context.Context, taskID string, status *api.TaskStatus) error { client := api.NewDispatcherClient(s.conn) if _, err := client.UpdateTaskStatus(ctx, &api.UpdateTaskStatusRequest{ SessionID: s.sessionID, Updates: []*api.UpdateTaskStatusRequest_TaskStatusUpdate{ { TaskID: taskID, Status: status, }, }, }); err != nil { // TODO(stevvooe): Dispatcher should not return this error. Status // reports for unknown tasks should be ignored. if grpc.Code(err) == codes.NotFound { return errTaskUnknown } return err } return nil }
func (s *session) sendTaskStatuses(ctx context.Context, updates ...*api.UpdateTaskStatusRequest_TaskStatusUpdate) ([]*api.UpdateTaskStatusRequest_TaskStatusUpdate, error) { if len(updates) < 1 { return nil, nil } const batchSize = 1024 select { case <-s.registered: select { case <-s.closed: return updates, ErrClosed default: } case <-s.closed: return updates, ErrClosed case <-ctx.Done(): return updates, ctx.Err() } client := api.NewDispatcherClient(s.conn) n := batchSize if len(updates) < n { n = len(updates) } if _, err := client.UpdateTaskStatus(ctx, &api.UpdateTaskStatusRequest{ SessionID: s.sessionID, Updates: updates[:n], }); err != nil { log.G(ctx).WithError(err).Errorf("failed sending task status batch size of %d", len(updates[:n])) return updates, err } return updates[n:], nil }
func (s *session) watch(ctx context.Context) error { log.G(ctx).Debugf("(*session).watch") client := api.NewDispatcherClient(s.conn) watch, err := client.Tasks(ctx, &api.TasksRequest{ SessionID: s.sessionID}) if err != nil { return err } for { resp, err := watch.Recv() if err != nil { return err } select { case s.tasks <- resp: case <-s.closed: return errSessionClosed case <-ctx.Done(): return ctx.Err() } } }
// TestManager starts a manager listening on a unix socket and a loopback TCP
// port, then checks over TCP that:
//   - an agent certified by the test CA reaches the dispatcher (and is told
//     it is not registered);
//   - an agent from a different organization is refused;
//   - a TLS connection presenting no client certificate is refused by the
//     Dispatcher, Control and RaftMembership services.
func TestManager(t *testing.T) {
	ctx := context.TODO()
	// NOTE(review): this local shadows the store package for the rest of the
	// function; the memory store is only checked for non-nil and is not
	// passed to the manager below.
	store := store.NewMemoryStore(nil)
	assert.NotNil(t, store)

	// Reserve a unique path for the unix socket: create a temp file, then
	// remove it so net.Listen can create the socket there.
	temp, err := ioutil.TempFile("", "test-socket")
	assert.NoError(t, err)
	assert.NoError(t, temp.Close())
	assert.NoError(t, os.Remove(temp.Name()))

	defer os.RemoveAll(temp.Name())

	lunix, err := net.Listen("unix", temp.Name())
	assert.NoError(t, err)
	// Port 0 lets the OS pick a free loopback port.
	ltcp, err := net.Listen("tcp", "127.0.0.1:0")
	assert.NoError(t, err)

	stateDir, err := ioutil.TempDir("", "test-raft")
	assert.NoError(t, err)
	defer os.RemoveAll(stateDir)

	tc := testutils.NewTestCA(t)
	defer tc.Stop()

	// Certificates for: an agent in the CA's organization, an agent in a
	// different organization, and the manager itself.
	agentSecurityConfig, err := tc.NewNodeConfig(ca.AgentRole)
	assert.NoError(t, err)
	agentDiffOrgSecurityConfig, err := tc.NewNodeConfigOrg(ca.AgentRole, "another-org")
	assert.NoError(t, err)
	managerSecurityConfig, err := tc.NewNodeConfig(ca.ManagerRole)
	assert.NoError(t, err)

	m, err := manager.New(&manager.Config{
		ProtoListener:  map[string]net.Listener{"unix": lunix, "tcp": ltcp},
		StateDir:       stateDir,
		SecurityConfig: managerSecurityConfig,
	})
	assert.NoError(t, err)
	assert.NotNil(t, m)

	// Receives the result of m.Run; drained at the end of the test.
	done := make(chan error)
	defer close(done)
	go func() {
		done <- m.Run(ctx)
	}()

	opts := []grpc.DialOption{
		grpc.WithTimeout(10 * time.Second),
		grpc.WithTransportCredentials(agentSecurityConfig.ClientTLSCreds),
	}

	conn, err := grpc.Dial(ltcp.Addr().String(), opts...)
	assert.NoError(t, err)
	defer func() {
		assert.NoError(t, conn.Close())
	}()

	// We have to send a dummy request to verify if the connection is actually up.
	client := api.NewDispatcherClient(conn)
	_, err = client.Heartbeat(context.Background(), &api.HeartbeatRequest{})
	assert.Equal(t, dispatcher.ErrNodeNotRegistered.Error(), grpc.ErrorDesc(err))

	// Try to have a client in a different org access this manager
	opts = []grpc.DialOption{
		grpc.WithTimeout(10 * time.Second),
		grpc.WithTransportCredentials(agentDiffOrgSecurityConfig.ClientTLSCreds),
	}

	conn2, err := grpc.Dial(ltcp.Addr().String(), opts...)
	assert.NoError(t, err)
	defer func() {
		assert.NoError(t, conn2.Close())
	}()

	// We have to send a dummy request to verify if the connection is actually up.
	client = api.NewDispatcherClient(conn2)
	_, err = client.Heartbeat(context.Background(), &api.HeartbeatRequest{})
	assert.Contains(t, grpc.ErrorDesc(err), "Permission denied: unauthorized peer role: rpc error: code = 7 desc = Permission denied: remote certificate not part of organization")

	// Verify that requests to the various GRPC services running on TCP
	// are rejected if they don't have certs.
	opts = []grpc.DialOption{
		grpc.WithTimeout(10 * time.Second),
		grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{InsecureSkipVerify: true})),
	}

	noCertConn, err := grpc.Dial(ltcp.Addr().String(), opts...)
	assert.NoError(t, err)
	defer func() {
		assert.NoError(t, noCertConn.Close())
	}()

	client = api.NewDispatcherClient(noCertConn)
	_, err = client.Heartbeat(context.Background(), &api.HeartbeatRequest{})
	assert.EqualError(t, err, "rpc error: code = 7 desc = Permission denied: unauthorized peer role: rpc error: code = 7 desc = no client certificates in request")

	controlClient := api.NewControlClient(noCertConn)
	_, err = controlClient.ListNodes(context.Background(), &api.ListNodesRequest{})
	assert.EqualError(t, err, "rpc error: code = 7 desc = Permission denied: unauthorized peer role: rpc error: code = 7 desc = no client certificates in request")

	raftClient := api.NewRaftMembershipClient(noCertConn)
	_, err = raftClient.Join(context.Background(), &api.JoinRequest{})
	assert.EqualError(t, err, "rpc error: code = 7 desc = Permission denied: unauthorized peer role: rpc error: code = 7 desc = no client certificates in request")

	m.Stop(ctx)

	// After stopping we MAY receive an error from ListenAndServe if
	// all this happened before WaitForLeader completed, so don't check the
	// error.
	<-done
}
func (s *session) watch(ctx context.Context) error { log := log.G(ctx).WithFields(logrus.Fields{"method": "(*session).watch"}) log.Debugf("") var ( resp *api.AssignmentsMessage assignmentWatch api.Dispatcher_AssignmentsClient tasksWatch api.Dispatcher_TasksClient streamReference string tasksFallback bool err error ) client := api.NewDispatcherClient(s.conn) for { // If this is the first time we're running the loop, or there was a reference mismatch // attempt to get the assignmentWatch if assignmentWatch == nil && !tasksFallback { assignmentWatch, err = client.Assignments(ctx, &api.AssignmentsRequest{SessionID: s.sessionID}) if err != nil { return err } } // We have an assignmentWatch, let's try to receive an AssignmentMessage if assignmentWatch != nil { // If we get a code = 12 desc = unknown method Assignments, try to use tasks resp, err = assignmentWatch.Recv() if err != nil { if grpc.Code(err) != codes.Unimplemented { return err } tasksFallback = true assignmentWatch = nil log.WithError(err).Infof("falling back to Tasks") } } // This code is here for backwards compatibility (so that newer clients can use the // older method Tasks) if tasksWatch == nil && tasksFallback { tasksWatch, err = client.Tasks(ctx, &api.TasksRequest{SessionID: s.sessionID}) if err != nil { return err } } if tasksWatch != nil { // When falling back to Tasks because of an old managers, we wrap the tasks in assignments. 
var taskResp *api.TasksMessage var assignmentChanges []*api.AssignmentChange taskResp, err = tasksWatch.Recv() if err != nil { return err } for _, t := range taskResp.Tasks { taskChange := &api.AssignmentChange{ Assignment: &api.Assignment{ Item: &api.Assignment_Task{ Task: t, }, }, Action: api.AssignmentChange_AssignmentActionUpdate, } assignmentChanges = append(assignmentChanges, taskChange) } resp = &api.AssignmentsMessage{Type: api.AssignmentsMessage_COMPLETE, Changes: assignmentChanges} } // If there seems to be a gap in the stream, let's break out of the inner for and // re-sync (by calling Assignments again). if streamReference != "" && streamReference != resp.AppliesTo { assignmentWatch = nil } else { streamReference = resp.ResultsIn } select { case s.assignments <- resp: case <-s.closed: return errSessionClosed case <-ctx.Done(): return ctx.Err() } } }
func TestManager(t *testing.T) { ctx := context.Background() temp, err := ioutil.TempFile("", "test-socket") assert.NoError(t, err) assert.NoError(t, temp.Close()) assert.NoError(t, os.Remove(temp.Name())) defer os.RemoveAll(temp.Name()) stateDir, err := ioutil.TempDir("", "test-raft") assert.NoError(t, err) defer os.RemoveAll(stateDir) tc := testutils.NewTestCA(t, func(p ca.CertPaths) *ca.KeyReadWriter { return ca.NewKeyReadWriter(p, []byte("kek"), nil) }) defer tc.Stop() agentSecurityConfig, err := tc.NewNodeConfig(ca.WorkerRole) assert.NoError(t, err) agentDiffOrgSecurityConfig, err := tc.NewNodeConfigOrg(ca.WorkerRole, "another-org") assert.NoError(t, err) managerSecurityConfig, err := tc.NewNodeConfig(ca.ManagerRole) assert.NoError(t, err) m, err := New(&Config{ RemoteAPI: RemoteAddrs{ListenAddr: "127.0.0.1:0"}, ControlAPI: temp.Name(), StateDir: stateDir, SecurityConfig: managerSecurityConfig, AutoLockManagers: true, UnlockKey: []byte("kek"), }) assert.NoError(t, err) assert.NotNil(t, m) tcpAddr := m.Addr() done := make(chan error) defer close(done) go func() { done <- m.Run(ctx) }() opts := []grpc.DialOption{ grpc.WithTimeout(10 * time.Second), grpc.WithTransportCredentials(agentSecurityConfig.ClientTLSCreds), } conn, err := grpc.Dial(tcpAddr, opts...) assert.NoError(t, err) defer func() { assert.NoError(t, conn.Close()) }() // We have to send a dummy request to verify if the connection is actually up. client := api.NewDispatcherClient(conn) _, err = client.Heartbeat(ctx, &api.HeartbeatRequest{}) assert.Equal(t, dispatcher.ErrNodeNotRegistered.Error(), grpc.ErrorDesc(err)) _, err = client.Session(ctx, &api.SessionRequest{}) assert.NoError(t, err) // Try to have a client in a different org access this manager opts = []grpc.DialOption{ grpc.WithTimeout(10 * time.Second), grpc.WithTransportCredentials(agentDiffOrgSecurityConfig.ClientTLSCreds), } conn2, err := grpc.Dial(tcpAddr, opts...) 
assert.NoError(t, err) defer func() { assert.NoError(t, conn2.Close()) }() client = api.NewDispatcherClient(conn2) _, err = client.Heartbeat(context.Background(), &api.HeartbeatRequest{}) assert.Contains(t, grpc.ErrorDesc(err), "Permission denied: unauthorized peer role: rpc error: code = 7 desc = Permission denied: remote certificate not part of organization") // Verify that requests to the various GRPC services running on TCP // are rejected if they don't have certs. opts = []grpc.DialOption{ grpc.WithTimeout(10 * time.Second), grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{InsecureSkipVerify: true})), } noCertConn, err := grpc.Dial(tcpAddr, opts...) assert.NoError(t, err) defer func() { assert.NoError(t, noCertConn.Close()) }() client = api.NewDispatcherClient(noCertConn) _, err = client.Heartbeat(context.Background(), &api.HeartbeatRequest{}) assert.EqualError(t, err, "rpc error: code = 7 desc = Permission denied: unauthorized peer role: rpc error: code = 7 desc = no client certificates in request") controlClient := api.NewControlClient(noCertConn) _, err = controlClient.ListNodes(context.Background(), &api.ListNodesRequest{}) assert.EqualError(t, err, "rpc error: code = 7 desc = Permission denied: unauthorized peer role: rpc error: code = 7 desc = no client certificates in request") raftClient := api.NewRaftMembershipClient(noCertConn) _, err = raftClient.Join(context.Background(), &api.JoinRequest{}) assert.EqualError(t, err, "rpc error: code = 7 desc = Permission denied: unauthorized peer role: rpc error: code = 7 desc = no client certificates in request") opts = []grpc.DialOption{ grpc.WithTimeout(10 * time.Second), grpc.WithTransportCredentials(managerSecurityConfig.ClientTLSCreds), } controlConn, err := grpc.Dial(tcpAddr, opts...) 
assert.NoError(t, err) defer func() { assert.NoError(t, controlConn.Close()) }() // check that the kek is added to the config var cluster api.Cluster m.raftNode.MemoryStore().View(func(tx store.ReadTx) { clusters, err := store.FindClusters(tx, store.All) require.NoError(t, err) require.Len(t, clusters, 1) cluster = *clusters[0] }) require.NotNil(t, cluster) require.Len(t, cluster.UnlockKeys, 1) require.Equal(t, &api.EncryptionKey{ Subsystem: ca.ManagerRole, Key: []byte("kek"), }, cluster.UnlockKeys[0]) // Test removal of the agent node agentID := agentSecurityConfig.ClientTLSCreds.NodeID() assert.NoError(t, m.raftNode.MemoryStore().Update(func(tx store.Tx) error { return store.CreateNode(tx, &api.Node{ ID: agentID, Certificate: api.Certificate{ Role: api.NodeRoleWorker, CN: agentID, }, }, ) })) controlClient = api.NewControlClient(controlConn) _, err = controlClient.RemoveNode(context.Background(), &api.RemoveNodeRequest{ NodeID: agentID, Force: true, }, ) assert.NoError(t, err) client = api.NewDispatcherClient(conn) _, err = client.Heartbeat(context.Background(), &api.HeartbeatRequest{}) assert.Contains(t, grpc.ErrorDesc(err), "removed from swarm") m.Stop(ctx) // After stopping we should MAY receive an error from ListenAndServe if // all this happened before WaitForLeader completed, so don't check the // error. <-done }
func startDispatcher(c *Config) (*grpcDispatcher, error) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { return nil, err } tca := testutils.NewTestCA(nil, testutils.AcceptancePolicy(true, true, "")) agentSecurityConfig1, err := tca.NewNodeConfig(ca.AgentRole) if err != nil { return nil, err } agentSecurityConfig2, err := tca.NewNodeConfig(ca.AgentRole) if err != nil { return nil, err } managerSecurityConfig, err := tca.NewNodeConfig(ca.ManagerRole) if err != nil { return nil, err } serverOpts := []grpc.ServerOption{grpc.Creds(managerSecurityConfig.ServerTLSCreds)} s := grpc.NewServer(serverOpts...) tc := &testCluster{addr: l.Addr().String(), store: tca.MemoryStore} d := New(tc, c) authorize := func(ctx context.Context, roles []string) error { _, err := ca.AuthorizeForwardedRoleAndOrg(ctx, roles, []string{ca.ManagerRole}, tca.Organization) return err } authenticatedDispatcherAPI := api.NewAuthenticatedWrapperDispatcherServer(d, authorize) api.RegisterDispatcherServer(s, authenticatedDispatcherAPI) go func() { // Serve will always return an error (even when properly stopped). // Explicitly ignore it. _ = s.Serve(l) }() go d.Run(context.Background()) if err := raftutils.PollFuncWithTimeout(nil, func() error { d.mu.Lock() defer d.mu.Unlock() if !d.isRunning() { return fmt.Errorf("dispatcher is not running") } return nil }, 5*time.Second); err != nil { return nil, err } clientOpts := []grpc.DialOption{grpc.WithTimeout(10 * time.Second)} clientOpts1 := append(clientOpts, grpc.WithTransportCredentials(agentSecurityConfig1.ClientTLSCreds)) clientOpts2 := append(clientOpts, grpc.WithTransportCredentials(agentSecurityConfig2.ClientTLSCreds)) clientOpts3 := append(clientOpts, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{InsecureSkipVerify: true}))) conn1, err := grpc.Dial(l.Addr().String(), clientOpts1...) if err != nil { return nil, err } conn2, err := grpc.Dial(l.Addr().String(), clientOpts2...) 
if err != nil { return nil, err } conn3, err := grpc.Dial(l.Addr().String(), clientOpts3...) if err != nil { return nil, err } clients := []api.DispatcherClient{api.NewDispatcherClient(conn1), api.NewDispatcherClient(conn2), api.NewDispatcherClient(conn3)} securityConfigs := []*ca.SecurityConfig{agentSecurityConfig1, agentSecurityConfig2, managerSecurityConfig} conns := []*grpc.ClientConn{conn1, conn2, conn3} return &grpcDispatcher{ Clients: clients, SecurityConfigs: securityConfigs, Store: tc.MemoryStore(), dispatcherServer: d, conns: conns, grpcServer: s, testCA: tca, }, nil }