// Leave asks a member of the raft to remove
// us from the raft cluster. This method is called
// by a member that wishes to give up its raft
// membership, addressed to an active member of the raft.
func (n *Node) Leave(ctx context.Context, req *api.LeaveRequest) (*api.LeaveResponse, error) {
	nodeInfo, err := ca.RemoteNode(ctx)
	if err != nil {
		return nil, err
	}

	fields := logrus.Fields{
		"node.id": nodeInfo.NodeID,
		"method":  "(*Node).Leave",
	}
	if nodeInfo.ForwardedBy != nil {
		fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
	}
	log.G(ctx).WithFields(fields).Debugf("")

	// can't stop the raft node while an async RPC is in progress
	n.stopMu.RLock()
	defer n.stopMu.RUnlock()

	if !n.IsMember() {
		return nil, ErrNoRaftMember
	}

	if !n.isLeader() {
		return nil, ErrLostLeadership
	}

	err = n.RemoveMember(ctx, req.Node.RaftID)
	if err != nil {
		return nil, err
	}

	return &api.LeaveResponse{}, nil
}
// ResolveAddress returns the address for reaching a given node ID.
func (n *Node) ResolveAddress(ctx context.Context, msg *api.ResolveAddressRequest) (*api.ResolveAddressResponse, error) {
	if !n.IsMember() {
		return nil, ErrNoRaftMember
	}

	nodeInfo, err := ca.RemoteNode(ctx)
	if err != nil {
		return nil, err
	}

	fields := logrus.Fields{
		"node.id": nodeInfo.NodeID,
		"method":  "(*Node).ResolveAddress",
		"raft_id": fmt.Sprintf("%x", n.Config.ID),
	}
	if nodeInfo.ForwardedBy != nil {
		fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
	}
	log.G(ctx).WithFields(fields).Debug("")

	member := n.cluster.GetMember(msg.RaftID)
	if member == nil {
		return nil, grpc.Errorf(codes.NotFound, "member %x not found", msg.RaftID)
	}
	return &api.ResolveAddressResponse{Addr: member.Addr}, nil
}
// Leave asks a member of the raft to remove
// us from the raft cluster. This method is called
// by a member that wishes to give up its raft
// membership, addressed to an active member of the raft.
func (n *Node) Leave(ctx context.Context, req *api.LeaveRequest) (*api.LeaveResponse, error) {
	nodeInfo, err := ca.RemoteNode(ctx)
	if err != nil {
		return nil, err
	}

	ctx, cancel := n.WithContext(ctx)
	defer cancel()

	fields := logrus.Fields{
		"node.id": nodeInfo.NodeID,
		"method":  "(*Node).Leave",
		"raft_id": fmt.Sprintf("%x", n.Config.ID),
	}
	if nodeInfo.ForwardedBy != nil {
		fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
	}
	log.G(ctx).WithFields(fields).Debug("")

	if err := n.removeMember(ctx, req.Node.RaftID); err != nil {
		return nil, err
	}

	return &api.LeaveResponse{}, nil
}
// DetachNetwork allows the node to request the release of
// the resources associated with the network attachment.
// - Returns `InvalidArgument` if attachment ID is not provided.
// - Returns `NotFound` if the attachment is not found.
// - Returns an error if the deletion fails.
func (ra *ResourceAllocator) DetachNetwork(ctx context.Context, request *api.DetachNetworkRequest) (*api.DetachNetworkResponse, error) {
	if request.AttachmentID == "" {
		return nil, grpc.Errorf(codes.InvalidArgument, errInvalidArgument.Error())
	}

	nodeInfo, err := ca.RemoteNode(ctx)
	if err != nil {
		return nil, err
	}

	if err := ra.store.Update(func(tx store.Tx) error {
		t := store.GetTask(tx, request.AttachmentID)
		if t == nil {
			return grpc.Errorf(codes.NotFound, "attachment %s not found", request.AttachmentID)
		}
		if t.NodeID != nodeInfo.NodeID {
			return grpc.Errorf(codes.PermissionDenied, "attachment %s doesn't belong to this node", request.AttachmentID)
		}

		return store.DeleteTask(tx, request.AttachmentID)
	}); err != nil {
		return nil, err
	}

	return &api.DetachNetworkResponse{}, nil
}
// PublishLogs publishes log messages for a given subscription
func (lb *LogBroker) PublishLogs(stream api.LogBroker_PublishLogsServer) error {
	remote, err := ca.RemoteNode(stream.Context())
	if err != nil {
		return err
	}

	for {
		log, err := stream.Recv()
		if err == io.EOF {
			return stream.SendAndClose(&api.PublishLogsResponse{})
		}
		if err != nil {
			return err
		}

		if log.SubscriptionID == "" {
			return grpc.Errorf(codes.InvalidArgument, "missing subscription ID")
		}

		// Make sure logs are emitted using the right Node ID to avoid impersonation.
		for _, msg := range log.Messages {
			if msg.Context.NodeID != remote.NodeID {
				return grpc.Errorf(codes.PermissionDenied, "invalid NodeID: expected=%s;received=%s", remote.NodeID, msg.Context.NodeID)
			}
		}

		lb.publish(log)
	}
}
// Join asks a member of the raft to propose
// a configuration change and add us as a member, thus
// beginning the log replication process. This method
// is called from an aspiring member to an existing member.
func (n *Node) Join(ctx context.Context, req *api.JoinRequest) (*api.JoinResponse, error) {
	nodeInfo, err := ca.RemoteNode(ctx)
	if err != nil {
		return nil, err
	}

	fields := logrus.Fields{
		"node.id": nodeInfo.NodeID,
		"method":  "(*Node).Join",
	}
	if nodeInfo.ForwardedBy != nil {
		fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
	}
	log := log.G(ctx).WithFields(fields)

	// can't stop the raft node while an async RPC is in progress
	n.stopMu.RLock()
	defer n.stopMu.RUnlock()

	n.membershipLock.Lock()
	defer n.membershipLock.Unlock()

	if n.Node == nil {
		log.WithError(ErrStopped).Errorf(ErrStopped.Error())
		return nil, ErrStopped
	}

	if !n.IsLeader() {
		return nil, ErrLostLeadership
	}

	// Find a unique ID for the joining member.
	var raftID uint64
	for {
		raftID = uint64(rand.Int63()) + 1
		if n.cluster.GetMember(raftID) == nil && !n.cluster.IsIDRemoved(raftID) {
			break
		}
	}

	err = n.addMember(ctx, req.Addr, raftID, nodeInfo.NodeID)
	if err != nil {
		log.WithError(err).Errorf("failed to add member")
		return nil, err
	}

	var nodes []*api.RaftMember
	for _, node := range n.cluster.Members() {
		nodes = append(nodes, &api.RaftMember{
			RaftID: node.RaftID,
			NodeID: node.NodeID,
			Addr:   node.Addr,
		})
	}
	log.Debugf("node joined")

	return &api.JoinResponse{Members: nodes, RaftID: raftID}, nil
}
// Heartbeat is the heartbeat method for nodes. It returns a new TTL in the response.
// A node should send its next heartbeat before now + TTL elapses; otherwise it will
// be deregistered from the dispatcher and its status will be updated to NodeStatus_DOWN.
func (d *Dispatcher) Heartbeat(ctx context.Context, r *api.HeartbeatRequest) (*api.HeartbeatResponse, error) {
	nodeInfo, err := ca.RemoteNode(ctx)
	if err != nil {
		return nil, err
	}

	period, err := d.nodes.Heartbeat(nodeInfo.NodeID, r.SessionID)
	return &api.HeartbeatResponse{Period: *ptypes.DurationProto(period)}, err
}
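// The sketch below is illustrative only and not part of the surrounding code:
// a minimal, standalone heartbeat loop showing how an agent-side caller might
// consume the Period returned above. runHeartbeat and sendHeartbeat are
// hypothetical names standing in for the real client-side RPC plumbing.
package main

import (
	"context"
	"fmt"
	"time"
)

// runHeartbeat keeps calling sendHeartbeat and then waits somewhat less than
// the TTL the server handed back, so the next beat arrives before expiry.
func runHeartbeat(ctx context.Context, sendHeartbeat func(context.Context) (time.Duration, error)) error {
	for {
		period, err := sendHeartbeat(ctx)
		if err != nil {
			return err // a real agent would re-register and retry here
		}
		select {
		case <-time.After(period * 4 / 5): // leave headroom before the TTL runs out
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	// Fake server: grants a 500ms TTL on every heartbeat.
	err := runHeartbeat(ctx, func(context.Context) (time.Duration, error) {
		fmt.Println("heartbeat sent")
		return 500 * time.Millisecond, nil
	})
	fmt.Println("heartbeat loop ended:", err)
}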
// AttachNetwork allows the node to request the resource
// allocation needed for a network attachment on the specific node.
// - Returns `InvalidArgument` if the Spec is malformed.
// - Returns `NotFound` if the Network is not found.
// - Returns `PermissionDenied` if the Network is not manually attachable.
// - Returns an error if the creation fails.
func (ra *ResourceAllocator) AttachNetwork(ctx context.Context, request *api.AttachNetworkRequest) (*api.AttachNetworkResponse, error) {
	nodeInfo, err := ca.RemoteNode(ctx)
	if err != nil {
		return nil, err
	}

	var network *api.Network
	ra.store.View(func(tx store.ReadTx) {
		network = store.GetNetwork(tx, request.Config.Target)
		if network == nil {
			if networks, err := store.FindNetworks(tx, store.ByName(request.Config.Target)); err == nil && len(networks) == 1 {
				network = networks[0]
			}
		}
	})
	if network == nil {
		return nil, grpc.Errorf(codes.NotFound, "network %s not found", request.Config.Target)
	}

	if !network.Spec.Attachable {
		return nil, grpc.Errorf(codes.PermissionDenied, "network %s not manually attachable", request.Config.Target)
	}

	t := &api.Task{
		ID:     identity.NewID(),
		NodeID: nodeInfo.NodeID,
		Spec: api.TaskSpec{
			Runtime: &api.TaskSpec_Attachment{
				Attachment: &api.NetworkAttachmentSpec{
					ContainerID: request.ContainerID,
				},
			},
			Networks: []*api.NetworkAttachmentConfig{
				{
					Target:    network.ID,
					Addresses: request.Config.Addresses,
				},
			},
		},
		Status: api.TaskStatus{
			State:     api.TaskStateNew,
			Timestamp: ptypes.MustTimestampProto(time.Now()),
			Message:   "created",
		},
		DesiredState: api.TaskStateRunning,
		// TODO: Add Network attachment.
	}

	if err := ra.store.Update(func(tx store.Tx) error {
		return store.CreateTask(tx, t)
	}); err != nil {
		return nil, err
	}

	return &api.AttachNetworkResponse{AttachmentID: t.ID}, nil
}
// nodeIPFromContext returns the node's IP address, taken from the context of a gRPC call.
func nodeIPFromContext(ctx context.Context) (string, error) {
	nodeInfo, err := ca.RemoteNode(ctx)
	if err != nil {
		return "", err
	}
	addr, _, err := net.SplitHostPort(nodeInfo.RemoteAddr)
	if err != nil {
		return "", errors.Wrap(err, "unable to get ip from addr:port")
	}
	return addr, nil
}
// ListenSubscriptions returns a stream of matching subscriptions for the current node
func (lb *LogBroker) ListenSubscriptions(request *api.ListenSubscriptionsRequest, stream api.LogBroker_ListenSubscriptionsServer) error {
	remote, err := ca.RemoteNode(stream.Context())
	if err != nil {
		return err
	}

	log := log.G(stream.Context()).WithFields(
		logrus.Fields{
			"method": "(*LogBroker).ListenSubscriptions",
			"node":   remote.NodeID,
		},
	)
	subscriptions, subscriptionCh, subscriptionCancel := lb.watchSubscriptions()
	defer subscriptionCancel()

	log.Debug("node registered")

	// Start by sending down all active subscriptions.
	for _, subscription := range subscriptions {
		select {
		case <-stream.Context().Done():
			return stream.Context().Err()
		case <-lb.pctx.Done():
			return nil
		default:
		}

		if err := stream.Send(subscription); err != nil {
			log.Error(err)
			return err
		}
	}

	// Send down new subscriptions.
	// TODO(aluzzardi): We should filter by relevant tasks for this node rather
	for {
		select {
		case v := <-subscriptionCh:
			subscription := v.(*api.SubscriptionMessage)
			if err := stream.Send(subscription); err != nil {
				log.Error(err)
				return err
			}
		case <-stream.Context().Done():
			return stream.Context().Err()
		case <-lb.pctx.Done():
			return nil
		}
	}
}
// Join asks a member of the raft to propose
// a configuration change and add us as a member, thus
// beginning the log replication process. This method
// is called from an aspiring member to an existing member.
func (n *Node) Join(ctx context.Context, req *api.JoinRequest) (*api.JoinResponse, error) {
	nodeInfo, err := ca.RemoteNode(ctx)
	if err != nil {
		return nil, err
	}

	fields := logrus.Fields{
		"node.id": nodeInfo.NodeID,
		"method":  "(*Node).Join",
	}
	if nodeInfo.ForwardedBy != nil {
		fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
	}
	log := log.G(ctx).WithFields(fields)

	raftID, err := identity.ParseNodeID(nodeInfo.NodeID)
	if err != nil {
		return nil, err
	}

	// can't stop the raft node while an async RPC is in progress
	n.stopMu.RLock()
	defer n.stopMu.RUnlock()

	if n.Node == nil {
		log.WithError(ErrStopped).Errorf(ErrStopped.Error())
		return nil, ErrStopped
	}

	// We submit a configuration change only if the node was not registered yet
	if n.cluster.GetMember(raftID) == nil {
		err = n.addMember(ctx, req.Addr, raftID)
		if err != nil {
			log.WithError(err).Errorf("failed to add member")
			return nil, err
		}
	}

	var nodes []*api.RaftMember
	for _, node := range n.cluster.Members() {
		nodes = append(nodes, &api.RaftMember{
			RaftID: node.RaftID,
			Addr:   node.Addr,
		})
	}

	log.Debugf("node joined")

	return &api.JoinResponse{Members: nodes}, nil
}
// PublishLogs publishes log messages for a given subscription
func (lb *LogBroker) PublishLogs(stream api.LogBroker_PublishLogsServer) (err error) {
	remote, err := ca.RemoteNode(stream.Context())
	if err != nil {
		return err
	}

	var currentSubscription *subscription
	defer func() {
		if currentSubscription != nil {
			currentSubscription.Done(remote.NodeID, err)
		}
	}()

	for {
		log, err := stream.Recv()
		if err == io.EOF {
			return stream.SendAndClose(&api.PublishLogsResponse{})
		}
		if err != nil {
			return err
		}

		if log.SubscriptionID == "" {
			return grpc.Errorf(codes.InvalidArgument, "missing subscription ID")
		}

		if currentSubscription == nil {
			currentSubscription = lb.getSubscription(log.SubscriptionID)
			if currentSubscription == nil {
				return grpc.Errorf(codes.NotFound, "unknown subscription ID")
			}
		} else {
			if log.SubscriptionID != currentSubscription.message.ID {
				return grpc.Errorf(codes.InvalidArgument, "different subscription IDs in the same session")
			}
		}

		// Make sure logs are emitted using the right Node ID to avoid impersonation.
		for _, msg := range log.Messages {
			if msg.Context.NodeID != remote.NodeID {
				return grpc.Errorf(codes.PermissionDenied, "invalid NodeID: expected=%s;received=%s", remote.NodeID, msg.Context.NodeID)
			}
		}

		lb.publish(log)
	}
}
// ResolveAddress returns the address for reaching a given node ID.
func (n *Node) ResolveAddress(ctx context.Context, msg *api.ResolveAddressRequest) (*api.ResolveAddressResponse, error) {
	nodeInfo, err := ca.RemoteNode(ctx)
	if err != nil {
		return nil, err
	}

	fields := logrus.Fields{
		"node.id": nodeInfo.NodeID,
		"method":  "(*Node).ResolveAddress",
	}
	if nodeInfo.ForwardedBy != nil {
		fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
	}
	log.G(ctx).WithFields(fields).Debugf("")

	member := n.cluster.GetMember(msg.RaftID)
	if member == nil {
		return nil, grpc.Errorf(codes.NotFound, "member %s not found", identity.FormatNodeID(msg.RaftID))
	}
	return &api.ResolveAddressResponse{Addr: member.Addr}, nil
}
// Join asks a member of the raft to propose
// a configuration change and add us as a member, thus
// beginning the log replication process. This method
// is called from an aspiring member to an existing member.
func (n *Node) Join(ctx context.Context, req *api.JoinRequest) (*api.JoinResponse, error) {
	nodeInfo, err := ca.RemoteNode(ctx)
	if err != nil {
		return nil, err
	}

	fields := logrus.Fields{
		"node.id": nodeInfo.NodeID,
		"method":  "(*Node).Join",
		"raft_id": fmt.Sprintf("%x", n.Config.ID),
	}
	if nodeInfo.ForwardedBy != nil {
		fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
	}
	log := log.G(ctx).WithFields(fields)
	log.Debug("")

	// can't stop the raft node while an async RPC is in progress
	n.stopMu.RLock()
	defer n.stopMu.RUnlock()

	n.membershipLock.Lock()
	defer n.membershipLock.Unlock()

	if !n.IsMember() {
		return nil, ErrNoRaftMember
	}

	if !n.isLeader() {
		return nil, ErrLostLeadership
	}

	// Find a unique ID for the joining member.
	var raftID uint64
	for {
		raftID = uint64(rand.Int63()) + 1
		if n.cluster.GetMember(raftID) == nil && !n.cluster.IsIDRemoved(raftID) {
			break
		}
	}

	remoteAddr := req.Addr

	// If the joining node sent an address like 0.0.0.0:4242, automatically
	// determine its actual address based on the GRPC connection. This
	// avoids the need for a prospective member to know its own address.
	requestHost, requestPort, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		return nil, fmt.Errorf("invalid address %s in raft join request", remoteAddr)
	}

	requestIP := net.ParseIP(requestHost)
	if requestIP != nil && requestIP.IsUnspecified() {
		remoteHost, _, err := net.SplitHostPort(nodeInfo.RemoteAddr)
		if err != nil {
			return nil, err
		}
		remoteAddr = net.JoinHostPort(remoteHost, requestPort)
	}

	// We do not bother submitting a configuration change for the
	// new member if we can't contact it back using its address
	if err := n.checkHealth(ctx, remoteAddr, 5*time.Second); err != nil {
		return nil, err
	}

	err = n.addMember(ctx, remoteAddr, raftID, nodeInfo.NodeID)
	if err != nil {
		log.WithError(err).Errorf("failed to add member %x", raftID)
		return nil, err
	}

	var nodes []*api.RaftMember
	for _, node := range n.cluster.Members() {
		nodes = append(nodes, &api.RaftMember{
			RaftID: node.RaftID,
			NodeID: node.NodeID,
			Addr:   node.Addr,
		})
	}
	log.Debugf("node joined")

	return &api.JoinResponse{Members: nodes, RaftID: raftID}, nil
}
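// Illustrative only, not part of the surrounding code: the 0.0.0.0 rewrite in
// Join above relies only on the standard net package, so the substitution
// logic can be sketched (and tested) in isolation. resolveJoinAddr is a
// hypothetical helper name.
package main

import (
	"fmt"
	"net"
)

// resolveJoinAddr rewrites an advertised address whose host is unspecified
// (0.0.0.0 or ::) to use the host observed on the connection instead,
// keeping the advertised port.
func resolveJoinAddr(advertised, observedRemote string) (string, error) {
	host, port, err := net.SplitHostPort(advertised)
	if err != nil {
		return "", fmt.Errorf("invalid address %s: %v", advertised, err)
	}
	if ip := net.ParseIP(host); ip != nil && ip.IsUnspecified() {
		remoteHost, _, err := net.SplitHostPort(observedRemote)
		if err != nil {
			return "", err
		}
		return net.JoinHostPort(remoteHost, port), nil
	}
	return advertised, nil
}

func main() {
	addr, _ := resolveJoinAddr("0.0.0.0:4242", "192.168.1.7:53412")
	fmt.Println(addr) // 192.168.1.7:4242

	addr, _ = resolveJoinAddr("10.0.0.5:4242", "192.168.1.7:53412")
	fmt.Println(addr) // 10.0.0.5:4242 (already specific, left alone)
}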
// Session is a stream which controls the agent connection.
// Each message contains a list of backup Managers with weights. There is also
// a special boolean field, Disconnect, which if true indicates that the node
// should reconnect to another Manager immediately.
func (d *Dispatcher) Session(r *api.SessionRequest, stream api.Dispatcher_SessionServer) error {
	ctx := stream.Context()
	nodeInfo, err := ca.RemoteNode(ctx)
	if err != nil {
		return err
	}
	nodeID := nodeInfo.NodeID

	if err := d.isRunningLocked(); err != nil {
		return err
	}

	// register the node.
	sessionID, err := d.register(stream.Context(), nodeID, r.Description)
	if err != nil {
		return err
	}

	fields := logrus.Fields{
		"node.id":      nodeID,
		"node.session": sessionID,
		"method":       "(*Dispatcher).Session",
	}
	if nodeInfo.ForwardedBy != nil {
		fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
	}
	log := log.G(ctx).WithFields(fields)

	var nodeObj *api.Node
	nodeUpdates, cancel, err := store.ViewAndWatch(d.store, func(readTx store.ReadTx) error {
		nodeObj = store.GetNode(readTx, nodeID)
		return nil
	}, state.EventUpdateNode{Node: &api.Node{ID: nodeID},
		Checks: []state.NodeCheckFunc{state.NodeCheckID}},
	)
	if cancel != nil {
		defer cancel()
	}

	if err != nil {
		log.WithError(err).Error("ViewAndWatch Node failed")
	}

	if _, err = d.nodes.GetWithSession(nodeID, sessionID); err != nil {
		return err
	}

	if err := stream.Send(&api.SessionMessage{
		SessionID:            sessionID,
		Node:                 nodeObj,
		Managers:             d.getManagers(),
		NetworkBootstrapKeys: d.networkBootstrapKeys,
	}); err != nil {
		return err
	}

	managerUpdates, mgrCancel := d.mgrQueue.Watch()
	defer mgrCancel()
	keyMgrUpdates, keyMgrCancel := d.keyMgrQueue.Watch()
	defer keyMgrCancel()

	// disconnectNode is a helper that forcibly shuts down the connection
	disconnectNode := func() error {
		// force disconnect by shutting down the stream.
		transportStream, ok := transport.StreamFromContext(stream.Context())
		if ok {
			// if we have the transport stream, we can signal a disconnect
			// in the client.
			if err := transportStream.ServerTransport().Close(); err != nil {
				log.WithError(err).Error("session end")
			}
		}

		nodeStatus := api.NodeStatus{State: api.NodeStatus_DISCONNECTED, Message: "node is currently trying to find new manager"}
		if err := d.nodeRemove(nodeID, nodeStatus); err != nil {
			log.WithError(err).Error("failed to remove node")
		}
		// still return an abort if the transport closure was ineffective.
		return grpc.Errorf(codes.Aborted, "node must disconnect")
	}

	for {
		// After each message send, we need to check that the node's sessionID
		// hasn't changed. If it has, we close the stream and make the node
		// re-register.
		node, err := d.nodes.GetWithSession(nodeID, sessionID)
		if err != nil {
			return err
		}

		var mgrs []*api.WeightedPeer

		var disconnect bool

		select {
		case ev := <-managerUpdates:
			mgrs = ev.([]*api.WeightedPeer)
		case ev := <-nodeUpdates:
			nodeObj = ev.(state.EventUpdateNode).Node
		case <-stream.Context().Done():
			return stream.Context().Err()
		case <-node.Disconnect:
			disconnect = true
		case <-d.ctx.Done():
			disconnect = true
		case <-keyMgrUpdates:
		}
		if mgrs == nil {
			mgrs = d.getManagers()
		}

		if err := stream.Send(&api.SessionMessage{
			SessionID:            sessionID,
			Node:                 nodeObj,
			Managers:             mgrs,
			NetworkBootstrapKeys: d.networkBootstrapKeys,
		}); err != nil {
			return err
		}
		if disconnect {
			return disconnectNode()
		}
	}
}
// ListenSubscriptions returns a stream of matching subscriptions for the current node
func (lb *LogBroker) ListenSubscriptions(request *api.ListenSubscriptionsRequest, stream api.LogBroker_ListenSubscriptionsServer) error {
	remote, err := ca.RemoteNode(stream.Context())
	if err != nil {
		return err
	}

	lb.nodeConnected(remote.NodeID)
	defer lb.nodeDisconnected(remote.NodeID)

	log := log.G(stream.Context()).WithFields(
		logrus.Fields{
			"method": "(*LogBroker).ListenSubscriptions",
			"node":   remote.NodeID,
		},
	)
	subscriptions, subscriptionCh, subscriptionCancel := lb.watchSubscriptions(remote.NodeID)
	defer subscriptionCancel()

	log.Debug("node registered")

	activeSubscriptions := make(map[string]*subscription)
	defer func() {
		// If the worker quits, mark all active subscriptions as finished.
		for _, subscription := range activeSubscriptions {
			subscription.Done(remote.NodeID, fmt.Errorf("node %s disconnected unexpectedly", remote.NodeID))
		}
	}()

	// Start by sending down all active subscriptions.
	for _, subscription := range subscriptions {
		select {
		case <-stream.Context().Done():
			return stream.Context().Err()
		case <-lb.pctx.Done():
			return nil
		default:
		}

		if err := stream.Send(subscription.message); err != nil {
			log.Error(err)
			return err
		}
		activeSubscriptions[subscription.message.ID] = subscription
	}

	// Send down new subscriptions.
	for {
		select {
		case v := <-subscriptionCh:
			subscription := v.(*subscription)

			if subscription.message.Close {
				log.WithField("subscription.id", subscription.message.ID).Debug("subscription closed")
				delete(activeSubscriptions, subscription.message.ID)
			} else {
				// Avoid sending down the same subscription multiple times
				if _, ok := activeSubscriptions[subscription.message.ID]; ok {
					continue
				}
				activeSubscriptions[subscription.message.ID] = subscription
				log.WithField("subscription.id", subscription.message.ID).Debug("subscription added")
			}
			if err := stream.Send(subscription.message); err != nil {
				log.Error(err)
				return err
			}
		case <-stream.Context().Done():
			return stream.Context().Err()
		case <-lb.pctx.Done():
			return nil
		}
	}
}
// Tasks is a stream of task states for a node. Each message contains the full
// list of tasks which should be run on the node; if a task is not present in
// that list, it should be terminated.
func (d *Dispatcher) Tasks(r *api.TasksRequest, stream api.Dispatcher_TasksServer) error {
	nodeInfo, err := ca.RemoteNode(stream.Context())
	if err != nil {
		return err
	}
	nodeID := nodeInfo.NodeID

	if err := d.isRunningLocked(); err != nil {
		return err
	}

	fields := logrus.Fields{
		"node.id":      nodeID,
		"node.session": r.SessionID,
		"method":       "(*Dispatcher).Tasks",
	}
	if nodeInfo.ForwardedBy != nil {
		fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
	}
	log.G(stream.Context()).WithFields(fields).Debugf("")

	if _, err = d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
		return err
	}

	tasksMap := make(map[string]*api.Task)
	nodeTasks, cancel, err := store.ViewAndWatch(
		d.store,
		func(readTx store.ReadTx) error {
			tasks, err := store.FindTasks(readTx, store.ByNodeID(nodeID))
			if err != nil {
				return err
			}
			for _, t := range tasks {
				tasksMap[t.ID] = t
			}
			return nil
		},
		state.EventCreateTask{Task: &api.Task{NodeID: nodeID},
			Checks: []state.TaskCheckFunc{state.TaskCheckNodeID}},
		state.EventUpdateTask{Task: &api.Task{NodeID: nodeID},
			Checks: []state.TaskCheckFunc{state.TaskCheckNodeID}},
		state.EventDeleteTask{Task: &api.Task{NodeID: nodeID},
			Checks: []state.TaskCheckFunc{state.TaskCheckNodeID}},
	)
	if err != nil {
		return err
	}
	defer cancel()

	for {
		if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
			return err
		}

		var tasks []*api.Task
		for _, t := range tasksMap {
			// dispatcher only sends tasks that have been assigned to a node
			if t != nil && t.Status.State >= api.TaskStateAssigned {
				tasks = append(tasks, t)
			}
		}

		if err := stream.Send(&api.TasksMessage{Tasks: tasks}); err != nil {
			return err
		}

		select {
		case event := <-nodeTasks:
			switch v := event.(type) {
			case state.EventCreateTask:
				tasksMap[v.Task.ID] = v.Task
			case state.EventUpdateTask:
				tasksMap[v.Task.ID] = v.Task
			case state.EventDeleteTask:
				delete(tasksMap, v.Task.ID)
			}
		case <-stream.Context().Done():
			return stream.Context().Err()
		case <-d.ctx.Done():
			return d.ctx.Err()
		}
	}
}
// UpdateTaskStatus updates the status of a task. A node should send such
// updates on every status change of its tasks.
func (d *Dispatcher) UpdateTaskStatus(ctx context.Context, r *api.UpdateTaskStatusRequest) (*api.UpdateTaskStatusResponse, error) {
	nodeInfo, err := ca.RemoteNode(ctx)
	if err != nil {
		return nil, err
	}
	nodeID := nodeInfo.NodeID
	fields := logrus.Fields{
		"node.id":      nodeID,
		"node.session": r.SessionID,
		"method":       "(*Dispatcher).UpdateTaskStatus",
	}
	if nodeInfo.ForwardedBy != nil {
		fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
	}
	log := log.G(ctx).WithFields(fields)

	if err := d.isRunningLocked(); err != nil {
		return nil, err
	}

	if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
		return nil, err
	}

	// Validate task updates
	for _, u := range r.Updates {
		if u.Status == nil {
			log.WithField("task.id", u.TaskID).Warn("task report has nil status")
			continue
		}

		var t *api.Task
		d.store.View(func(tx store.ReadTx) {
			t = store.GetTask(tx, u.TaskID)
		})
		if t == nil {
			log.WithField("task.id", u.TaskID).Warn("cannot find target task in store")
			continue
		}

		if t.NodeID != nodeID {
			err := grpc.Errorf(codes.PermissionDenied, "cannot update a task not assigned to this node")
			log.WithField("task.id", u.TaskID).Error(err)
			return nil, err
		}
	}

	d.taskUpdatesLock.Lock()
	// Enqueue task updates
	for _, u := range r.Updates {
		if u.Status == nil {
			continue
		}
		d.taskUpdates[u.TaskID] = u.Status
	}

	numUpdates := len(d.taskUpdates)
	d.taskUpdatesLock.Unlock()

	if numUpdates >= maxBatchItems {
		d.processTaskUpdatesTrigger <- struct{}{}
	}
	return nil, nil
}
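// Illustrative only, not part of the surrounding code: a small, standalone
// sketch of the enqueue-then-trigger pattern used above, where updates are
// accumulated under a lock and a worker is poked once the pending set reaches
// a batch limit. statusQueue and its fields are hypothetical names; this
// sketch uses a buffered, non-blocking trigger, which is one possible variant.
package main

import (
	"fmt"
	"sync"
)

type statusQueue struct {
	mu         sync.Mutex
	pending    map[string]string // taskID -> latest reported status
	trigger    chan struct{}
	batchLimit int
}

func newStatusQueue(limit int) *statusQueue {
	return &statusQueue{
		pending:    make(map[string]string),
		trigger:    make(chan struct{}, 1), // buffered so enqueuing never blocks
		batchLimit: limit,
	}
}

func (q *statusQueue) enqueue(taskID, status string) {
	q.mu.Lock()
	q.pending[taskID] = status // later updates for the same task overwrite earlier ones
	n := len(q.pending)
	q.mu.Unlock()

	if n >= q.batchLimit {
		select {
		case q.trigger <- struct{}{}: // wake the processor
		default: // a wake-up is already pending
		}
	}
}

func main() {
	q := newStatusQueue(2)
	q.enqueue("task1", "running")
	q.enqueue("task2", "failed")

	select {
	case <-q.trigger:
		fmt.Println("batch ready:", len(q.pending), "updates")
	default:
		fmt.Println("no batch yet")
	}
}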
// Assignments is a stream of assignments for a node. Each message contains
// either the full list of tasks and secrets for the node, or an incremental update.
func (d *Dispatcher) Assignments(r *api.AssignmentsRequest, stream api.Dispatcher_AssignmentsServer) error {
	nodeInfo, err := ca.RemoteNode(stream.Context())
	if err != nil {
		return err
	}
	nodeID := nodeInfo.NodeID

	dctx, err := d.isRunningLocked()
	if err != nil {
		return err
	}

	fields := logrus.Fields{
		"node.id":      nodeID,
		"node.session": r.SessionID,
		"method":       "(*Dispatcher).Assignments",
	}
	if nodeInfo.ForwardedBy != nil {
		fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
	}
	log := log.G(stream.Context()).WithFields(fields)
	log.Debugf("")

	if _, err = d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
		return err
	}

	var (
		sequence  int64
		appliesTo string
		initial   api.AssignmentsMessage
	)
	tasksMap := make(map[string]*api.Task)
	tasksUsingSecret := make(map[string]map[string]struct{})

	sendMessage := func(msg api.AssignmentsMessage, assignmentType api.AssignmentsMessage_Type) error {
		sequence++
		msg.AppliesTo = appliesTo
		msg.ResultsIn = strconv.FormatInt(sequence, 10)
		appliesTo = msg.ResultsIn
		msg.Type = assignmentType

		if err := stream.Send(&msg); err != nil {
			return err
		}
		return nil
	}

	// returns a slice of new secrets to send down
	addSecretsForTask := func(readTx store.ReadTx, t *api.Task) []*api.Secret {
		container := t.Spec.GetContainer()
		if container == nil {
			return nil
		}
		var newSecrets []*api.Secret
		for _, secretRef := range container.Secrets {
			// Empty ID prefix will return all secrets. Bail if there is no SecretID
			if secretRef.SecretID == "" {
				log.Debugf("invalid secret reference")
				continue
			}
			secretID := secretRef.SecretID
			log := log.WithFields(logrus.Fields{
				"secret.id":   secretID,
				"secret.name": secretRef.SecretName,
			})

			if len(tasksUsingSecret[secretID]) == 0 {
				tasksUsingSecret[secretID] = make(map[string]struct{})

				secrets, err := store.FindSecrets(readTx, store.ByIDPrefix(secretID))
				if err != nil {
					log.WithError(err).Errorf("error retrieving secret")
					continue
				}
				if len(secrets) != 1 {
					log.Debugf("secret not found")
					continue
				}

				// If the secret was found and there was one result
				// (there should never be more than one because of the
				// uniqueness constraint), add this secret to our
				// initial set that we send down.
				newSecrets = append(newSecrets, secrets[0])
			}
			tasksUsingSecret[secretID][t.ID] = struct{}{}
		}

		return newSecrets
	}

	// TODO(aaronl): Also send node secrets that should be exposed to
	// this node.
	nodeTasks, cancel, err := store.ViewAndWatch(
		d.store,
		func(readTx store.ReadTx) error {
			tasks, err := store.FindTasks(readTx, store.ByNodeID(nodeID))
			if err != nil {
				return err
			}

			for _, t := range tasks {
				// We only care about tasks that are ASSIGNED or
				// higher. If the state is below ASSIGNED, the
				// task may not meet the constraints for this
				// node, so we have to be careful about sending
				// secrets associated with it.
				if t.Status.State < api.TaskStateAssigned {
					continue
				}

				tasksMap[t.ID] = t
				taskChange := &api.AssignmentChange{
					Assignment: &api.Assignment{
						Item: &api.Assignment_Task{
							Task: t,
						},
					},
					Action: api.AssignmentChange_AssignmentActionUpdate,
				}
				initial.Changes = append(initial.Changes, taskChange)
				// Only send secrets down if these tasks are in < RUNNING
				if t.Status.State <= api.TaskStateRunning {
					newSecrets := addSecretsForTask(readTx, t)
					for _, secret := range newSecrets {
						secretChange := &api.AssignmentChange{
							Assignment: &api.Assignment{
								Item: &api.Assignment_Secret{
									Secret: secret,
								},
							},
							Action: api.AssignmentChange_AssignmentActionUpdate,
						}

						initial.Changes = append(initial.Changes, secretChange)
					}
				}
			}
			return nil
		},
		state.EventUpdateTask{Task: &api.Task{NodeID: nodeID},
			Checks: []state.TaskCheckFunc{state.TaskCheckNodeID}},
		state.EventDeleteTask{Task: &api.Task{NodeID: nodeID},
			Checks: []state.TaskCheckFunc{state.TaskCheckNodeID}},
		state.EventUpdateSecret{},
		state.EventDeleteSecret{},
	)
	if err != nil {
		return err
	}
	defer cancel()

	if err := sendMessage(initial, api.AssignmentsMessage_COMPLETE); err != nil {
		return err
	}

	for {
		// Check for session expiration
		if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
			return err
		}

		// bursty events should be processed in batches and sent out together
		var (
			update          api.AssignmentsMessage
			modificationCnt int
			batchingTimer   *time.Timer
			batchingTimeout <-chan time.Time
			updateTasks     = make(map[string]*api.Task)
			updateSecrets   = make(map[string]*api.Secret)
			removeTasks     = make(map[string]struct{})
			removeSecrets   = make(map[string]struct{})
		)

		oneModification := func() {
			modificationCnt++

			if batchingTimer != nil {
				batchingTimer.Reset(batchingWaitTime)
			} else {
				batchingTimer = time.NewTimer(batchingWaitTime)
				batchingTimeout = batchingTimer.C
			}
		}

		// Release the secret references held by this task
		releaseSecretsForTask := func(t *api.Task) bool {
			var modified bool
			container := t.Spec.GetContainer()
			if container == nil {
				return modified
			}

			for _, secretRef := range container.Secrets {
				secretID := secretRef.SecretID
				delete(tasksUsingSecret[secretID], t.ID)
				if len(tasksUsingSecret[secretID]) == 0 {
					// No tasks are using the secret anymore
					delete(tasksUsingSecret, secretID)
					removeSecrets[secretID] = struct{}{}
					modified = true
				}
			}

			return modified
		}

		// The batching loop waits for 50 ms after the most recent
		// change, or until modificationBatchLimit is reached. The
		// worst case latency is modificationBatchLimit * batchingWaitTime,
		// which is 10 seconds.
	batchingLoop:
		for modificationCnt < modificationBatchLimit {
			select {
			case event := <-nodeTasks:
				switch v := event.(type) {
				// We don't monitor EventCreateTask because tasks are
				// never created in the ASSIGNED state. First tasks are
				// created by the orchestrator, then the scheduler moves
				// them to ASSIGNED. If this ever changes, we will need
				// to monitor task creations as well.
				case state.EventUpdateTask:
					// We only care about tasks that are ASSIGNED or
					// higher.
					if v.Task.Status.State < api.TaskStateAssigned {
						continue
					}

					if oldTask, exists := tasksMap[v.Task.ID]; exists {
						// States ASSIGNED and below are set by the orchestrator/scheduler,
						// not the agent, so tasks in these states need to be sent to the
						// agent even if nothing else has changed.
						if equality.TasksEqualStable(oldTask, v.Task) && v.Task.Status.State > api.TaskStateAssigned {
							// this update should not trigger a task change for the agent
							tasksMap[v.Task.ID] = v.Task
							// If this task got updated to a final state, let's release
							// the secrets that are being used by the task
							if v.Task.Status.State > api.TaskStateRunning {
								// If releasing the secrets caused a secret to be
								// removed from an agent, mark one modification
								if releaseSecretsForTask(v.Task) {
									oneModification()
								}
							}
							continue
						}
					} else if v.Task.Status.State <= api.TaskStateRunning {
						// If this task wasn't part of the assignment set before, and it's <= RUNNING,
						// add the secrets it references to the secrets assignment.
						// Task states > RUNNING are worker-reported only; tasks are never created
						// in a > RUNNING state.
						var newSecrets []*api.Secret
						d.store.View(func(readTx store.ReadTx) {
							newSecrets = addSecretsForTask(readTx, v.Task)
						})
						for _, secret := range newSecrets {
							updateSecrets[secret.ID] = secret
						}
					}
					tasksMap[v.Task.ID] = v.Task
					updateTasks[v.Task.ID] = v.Task

					oneModification()
				case state.EventDeleteTask:
					if _, exists := tasksMap[v.Task.ID]; !exists {
						continue
					}

					removeTasks[v.Task.ID] = struct{}{}

					delete(tasksMap, v.Task.ID)

					// Release the secrets being used by this task.
					// Ignoring the return here. We will always mark
					// this as a modification, since a task is being
					// removed.
					releaseSecretsForTask(v.Task)

					oneModification()
				// TODO(aaronl): For node secrets, we'll need to handle
				// EventCreateSecret.
				case state.EventUpdateSecret:
					if _, exists := tasksUsingSecret[v.Secret.ID]; !exists {
						continue
					}
					log.Debugf("Secret %s (ID: %s) was updated though it was still referenced by one or more tasks",
						v.Secret.Spec.Annotations.Name, v.Secret.ID)
				case state.EventDeleteSecret:
					if _, exists := tasksUsingSecret[v.Secret.ID]; !exists {
						continue
					}
					log.Debugf("Secret %s (ID: %s) was deleted though it was still referenced by one or more tasks",
						v.Secret.Spec.Annotations.Name, v.Secret.ID)
				}
			case <-batchingTimeout:
				break batchingLoop
			case <-stream.Context().Done():
				return stream.Context().Err()
			case <-dctx.Done():
				return dctx.Err()
			}
		}

		if batchingTimer != nil {
			batchingTimer.Stop()
		}

		if modificationCnt > 0 {
			for id, task := range updateTasks {
				if _, ok := removeTasks[id]; !ok {
					taskChange := &api.AssignmentChange{
						Assignment: &api.Assignment{
							Item: &api.Assignment_Task{
								Task: task,
							},
						},
						Action: api.AssignmentChange_AssignmentActionUpdate,
					}

					update.Changes = append(update.Changes, taskChange)
				}
			}

			for id, secret := range updateSecrets {
				// If, due to multiple updates, this secret is no longer in use,
				// don't send it down.
				if len(tasksUsingSecret[id]) == 0 {
					// delete this secret from the set of secrets to be updated
					// so that removeSecrets reflects the current list
					delete(updateSecrets, id)
					continue
				}
				secretChange := &api.AssignmentChange{
					Assignment: &api.Assignment{
						Item: &api.Assignment_Secret{
							Secret: secret,
						},
					},
					Action: api.AssignmentChange_AssignmentActionUpdate,
				}

				update.Changes = append(update.Changes, secretChange)
			}

			for id := range removeTasks {
				taskChange := &api.AssignmentChange{
					Assignment: &api.Assignment{
						Item: &api.Assignment_Task{
							Task: &api.Task{ID: id},
						},
					},
					Action: api.AssignmentChange_AssignmentActionRemove,
				}

				update.Changes = append(update.Changes, taskChange)
			}

			for id := range removeSecrets {
				// If this secret is also being sent in the updated set,
				// don't also add it to the removed set
				if _, ok := updateSecrets[id]; ok {
					continue
				}

				secretChange := &api.AssignmentChange{
					Assignment: &api.Assignment{
						Item: &api.Assignment_Secret{
							Secret: &api.Secret{ID: id},
						},
					},
					Action: api.AssignmentChange_AssignmentActionRemove,
				}

				update.Changes = append(update.Changes, secretChange)
			}

			if err := sendMessage(update, api.AssignmentsMessage_INCREMENTAL); err != nil {
				return err
			}
		}
	}
}
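// Illustrative only, not part of the surrounding code: a standalone sketch of
// the batching-timer idea used in the Assignments/Tasks loops above, where a
// quiet-period timer restarts on every event and the batch is flushed when it
// fires or a size limit is hit. collectBatch is a hypothetical helper name.
package main

import (
	"fmt"
	"time"
)

// collectBatch drains events from ch into a batch, flushing once no event has
// arrived for `quiet`, or once the batch reaches `limit` entries.
func collectBatch(ch <-chan string, quiet time.Duration, limit int) []string {
	var (
		batch []string
		timer *time.Timer
		wait  <-chan time.Time // nil until the first event, so this case blocks
	)
	for len(batch) < limit {
		select {
		case ev, ok := <-ch:
			if !ok {
				return batch
			}
			batch = append(batch, ev)
			// Restart the quiet-period timer on every event.
			if timer != nil {
				timer.Reset(quiet)
			} else {
				timer = time.NewTimer(quiet)
				wait = timer.C
			}
		case <-wait:
			return batch // no events for `quiet`; flush what we have
		}
	}
	return batch
}

func main() {
	ch := make(chan string, 8)
	for _, ev := range []string{"t1 updated", "t2 updated", "t3 removed"} {
		ch <- ev
	}
	fmt.Println(collectBatch(ch, 50*time.Millisecond, 100))
}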
// Tasks is a stream of task states for a node. Each message contains the full
// list of tasks which should be run on the node; if a task is not present in
// that list, it should be terminated.
func (d *Dispatcher) Tasks(r *api.TasksRequest, stream api.Dispatcher_TasksServer) error {
	nodeInfo, err := ca.RemoteNode(stream.Context())
	if err != nil {
		return err
	}
	nodeID := nodeInfo.NodeID

	dctx, err := d.isRunningLocked()
	if err != nil {
		return err
	}

	fields := logrus.Fields{
		"node.id":      nodeID,
		"node.session": r.SessionID,
		"method":       "(*Dispatcher).Tasks",
	}
	if nodeInfo.ForwardedBy != nil {
		fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
	}
	log.G(stream.Context()).WithFields(fields).Debugf("")

	if _, err = d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
		return err
	}

	tasksMap := make(map[string]*api.Task)
	nodeTasks, cancel, err := store.ViewAndWatch(
		d.store,
		func(readTx store.ReadTx) error {
			tasks, err := store.FindTasks(readTx, store.ByNodeID(nodeID))
			if err != nil {
				return err
			}
			for _, t := range tasks {
				tasksMap[t.ID] = t
			}
			return nil
		},
		state.EventCreateTask{Task: &api.Task{NodeID: nodeID},
			Checks: []state.TaskCheckFunc{state.TaskCheckNodeID}},
		state.EventUpdateTask{Task: &api.Task{NodeID: nodeID},
			Checks: []state.TaskCheckFunc{state.TaskCheckNodeID}},
		state.EventDeleteTask{Task: &api.Task{NodeID: nodeID},
			Checks: []state.TaskCheckFunc{state.TaskCheckNodeID}},
	)
	if err != nil {
		return err
	}
	defer cancel()

	for {
		if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
			return err
		}

		var tasks []*api.Task
		for _, t := range tasksMap {
			// dispatcher only sends tasks that have been assigned to a node
			if t != nil && t.Status.State >= api.TaskStateAssigned {
				tasks = append(tasks, t)
			}
		}

		if err := stream.Send(&api.TasksMessage{Tasks: tasks}); err != nil {
			return err
		}

		// bursty events should be processed in batches and sent out as a snapshot
		var (
			modificationCnt int
			batchingTimer   *time.Timer
			batchingTimeout <-chan time.Time
		)

	batchingLoop:
		for modificationCnt < modificationBatchLimit {
			select {
			case event := <-nodeTasks:
				switch v := event.(type) {
				case state.EventCreateTask:
					tasksMap[v.Task.ID] = v.Task
					modificationCnt++
				case state.EventUpdateTask:
					if oldTask, exists := tasksMap[v.Task.ID]; exists {
						// States ASSIGNED and below are set by the orchestrator/scheduler,
						// not the agent, so tasks in these states need to be sent to the
						// agent even if nothing else has changed.
						if equality.TasksEqualStable(oldTask, v.Task) && v.Task.Status.State > api.TaskStateAssigned {
							// this update should not trigger action at agent
							tasksMap[v.Task.ID] = v.Task
							continue
						}
					}
					tasksMap[v.Task.ID] = v.Task
					modificationCnt++
				case state.EventDeleteTask:
					delete(tasksMap, v.Task.ID)
					modificationCnt++
				}
				if batchingTimer != nil {
					batchingTimer.Reset(batchingWaitTime)
				} else {
					batchingTimer = time.NewTimer(batchingWaitTime)
					batchingTimeout = batchingTimer.C
				}
			case <-batchingTimeout:
				break batchingLoop
			case <-stream.Context().Done():
				return stream.Context().Err()
			case <-dctx.Done():
				return dctx.Err()
			}
		}

		if batchingTimer != nil {
			batchingTimer.Stop()
		}
	}
}
// Tasks is a stream of task states for a node. Each message contains the full
// list of tasks which should be run on the node; if a task is not present in
// that list, it should be terminated.
func (d *Dispatcher) Tasks(r *api.TasksRequest, stream api.Dispatcher_TasksServer) error {
	nodeInfo, err := ca.RemoteNode(stream.Context())
	if err != nil {
		return err
	}
	nodeID := nodeInfo.NodeID

	if err := d.isRunningLocked(); err != nil {
		return err
	}

	fields := logrus.Fields{
		"node.id":      nodeID,
		"node.session": r.SessionID,
		"method":       "(*Dispatcher).Tasks",
	}
	if nodeInfo.ForwardedBy != nil {
		fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
	}
	log.G(stream.Context()).WithFields(fields).Debugf("")

	if _, err = d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
		return err
	}

	tasksMap := make(map[string]*api.Task)
	nodeTasks, cancel, err := store.ViewAndWatch(
		d.store,
		func(readTx store.ReadTx) error {
			tasks, err := store.FindTasks(readTx, store.ByNodeID(nodeID))
			if err != nil {
				return err
			}
			for _, t := range tasks {
				tasksMap[t.ID] = t
			}
			return nil
		},
		state.EventCreateTask{Task: &api.Task{NodeID: nodeID},
			Checks: []state.TaskCheckFunc{state.TaskCheckNodeID}},
		state.EventUpdateTask{Task: &api.Task{NodeID: nodeID},
			Checks: []state.TaskCheckFunc{state.TaskCheckNodeID}},
		state.EventDeleteTask{Task: &api.Task{NodeID: nodeID},
			Checks: []state.TaskCheckFunc{state.TaskCheckNodeID}},
	)
	if err != nil {
		return err
	}
	defer cancel()

	for {
		if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
			return err
		}

		var tasks []*api.Task
		for _, t := range tasksMap {
			// dispatcher only sends tasks that have been assigned to a node
			if t != nil && t.Status.State >= api.TaskStateAssigned {
				tasks = append(tasks, t)
			}
		}

		if err := stream.Send(&api.TasksMessage{Tasks: tasks}); err != nil {
			return err
		}

		// bursty events should be processed in batches and sent out as a snapshot
		const modificationBatchLimit = 200
		const eventPausedGap = 50 * time.Millisecond
		var modificationCnt int
		// eventPaused is true when there have been modifications
		// but the next event has not arrived within eventPausedGap
		eventPaused := false

		for modificationCnt < modificationBatchLimit && !eventPaused {
			select {
			case event := <-nodeTasks:
				switch v := event.(type) {
				case state.EventCreateTask:
					tasksMap[v.Task.ID] = v.Task
					modificationCnt++
				case state.EventUpdateTask:
					if oldTask, exists := tasksMap[v.Task.ID]; exists {
						if equality.TasksEqualStable(oldTask, v.Task) {
							// this update should not trigger action at agent
							tasksMap[v.Task.ID] = v.Task
							continue
						}
					}
					tasksMap[v.Task.ID] = v.Task
					modificationCnt++
				case state.EventDeleteTask:
					delete(tasksMap, v.Task.ID)
					modificationCnt++
				}
			case <-time.After(eventPausedGap):
				if modificationCnt > 0 {
					eventPaused = true
				}
			case <-stream.Context().Done():
				return stream.Context().Err()
			case <-d.ctx.Done():
				return d.ctx.Err()
			}
		}
	}
}
// Assignments is a stream of assignments for a node. Each message contains
// either the full list of tasks and secrets for the node, or an incremental update.
func (d *Dispatcher) Assignments(r *api.AssignmentsRequest, stream api.Dispatcher_AssignmentsServer) error {
	nodeInfo, err := ca.RemoteNode(stream.Context())
	if err != nil {
		return err
	}
	nodeID := nodeInfo.NodeID

	if err := d.isRunningLocked(); err != nil {
		return err
	}

	fields := logrus.Fields{
		"node.id":      nodeID,
		"node.session": r.SessionID,
		"method":       "(*Dispatcher).Assignments",
	}
	if nodeInfo.ForwardedBy != nil {
		fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
	}
	log := log.G(stream.Context()).WithFields(fields)
	log.Debugf("")

	if _, err = d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
		return err
	}

	var (
		sequence  int64
		appliesTo string
		initial   api.AssignmentsMessage
	)
	tasksMap := make(map[string]*api.Task)

	sendMessage := func(msg api.AssignmentsMessage, assignmentType api.AssignmentsMessage_Type) error {
		sequence++
		msg.AppliesTo = appliesTo
		msg.ResultsIn = strconv.FormatInt(sequence, 10)
		appliesTo = msg.ResultsIn
		msg.Type = assignmentType

		if err := stream.Send(&msg); err != nil {
			return err
		}
		return nil
	}

	// TODO(aaronl): Also send node secrets that should be exposed to
	// this node.
	nodeTasks, cancel, err := store.ViewAndWatch(
		d.store,
		func(readTx store.ReadTx) error {
			tasks, err := store.FindTasks(readTx, store.ByNodeID(nodeID))
			if err != nil {
				return err
			}

			for _, t := range tasks {
				// We only care about tasks that are ASSIGNED or
				// higher. If the state is below ASSIGNED, the
				// task may not meet the constraints for this
				// node, so we have to be careful about sending
				// secrets associated with it.
				if t.Status.State < api.TaskStateAssigned {
					continue
				}

				tasksMap[t.ID] = t
				initial.UpdateTasks = append(initial.UpdateTasks, t)
			}
			return nil
		},
		state.EventUpdateTask{Task: &api.Task{NodeID: nodeID},
			Checks: []state.TaskCheckFunc{state.TaskCheckNodeID}},
		state.EventDeleteTask{Task: &api.Task{NodeID: nodeID},
			Checks: []state.TaskCheckFunc{state.TaskCheckNodeID}},
	)
	if err != nil {
		return err
	}
	defer cancel()

	if err := sendMessage(initial, api.AssignmentsMessage_COMPLETE); err != nil {
		return err
	}

	for {
		// Check for session expiration
		if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
			return err
		}

		// bursty events should be processed in batches and sent out together
		var (
			update          api.AssignmentsMessage
			modificationCnt int
			batchingTimer   *time.Timer
			batchingTimeout <-chan time.Time
			updateTasks     = make(map[string]*api.Task)
			removeTasks     = make(map[string]struct{})
		)

		oneModification := func() {
			modificationCnt++

			if batchingTimer != nil {
				batchingTimer.Reset(batchingWaitTime)
			} else {
				batchingTimer = time.NewTimer(batchingWaitTime)
				batchingTimeout = batchingTimer.C
			}
		}

		// The batching loop waits for 50 ms after the most recent
		// change, or until modificationBatchLimit is reached. The
		// worst case latency is modificationBatchLimit * batchingWaitTime,
		// which is 10 seconds.
	batchingLoop:
		for modificationCnt < modificationBatchLimit {
			select {
			case event := <-nodeTasks:
				switch v := event.(type) {
				// We don't monitor EventCreateTask because tasks are
				// never created in the ASSIGNED state. First tasks are
				// created by the orchestrator, then the scheduler moves
				// them to ASSIGNED. If this ever changes, we will need
				// to monitor task creations as well.
				case state.EventUpdateTask:
					// We only care about tasks that are ASSIGNED or
					// higher.
					if v.Task.Status.State < api.TaskStateAssigned {
						continue
					}

					if oldTask, exists := tasksMap[v.Task.ID]; exists {
						// States ASSIGNED and below are set by the orchestrator/scheduler,
						// not the agent, so tasks in these states need to be sent to the
						// agent even if nothing else has changed.
						if equality.TasksEqualStable(oldTask, v.Task) && v.Task.Status.State > api.TaskStateAssigned {
							// this update should not trigger a task change for the agent
							tasksMap[v.Task.ID] = v.Task
							continue
						}
					}
					tasksMap[v.Task.ID] = v.Task
					updateTasks[v.Task.ID] = v.Task

					oneModification()
				case state.EventDeleteTask:
					if _, exists := tasksMap[v.Task.ID]; !exists {
						continue
					}

					removeTasks[v.Task.ID] = struct{}{}

					delete(tasksMap, v.Task.ID)

					oneModification()
				}
			case <-batchingTimeout:
				break batchingLoop
			case <-stream.Context().Done():
				return stream.Context().Err()
			case <-d.ctx.Done():
				return d.ctx.Err()
			}
		}

		if batchingTimer != nil {
			batchingTimer.Stop()
		}

		if modificationCnt > 0 {
			for id, task := range updateTasks {
				if _, ok := removeTasks[id]; !ok {
					update.UpdateTasks = append(update.UpdateTasks, task)
				}
			}
			for id := range removeTasks {
				update.RemoveTasks = append(update.RemoveTasks, id)
			}
			if err := sendMessage(update, api.AssignmentsMessage_INCREMENTAL); err != nil {
				return err
			}
		}
	}
}