func validateRestartPolicy(rp *api.RestartPolicy) error { if rp == nil { return nil } if rp.Delay != nil { delay, err := gogotypes.DurationFromProto(rp.Delay) if err != nil { return err } if delay < 0 { return grpc.Errorf(codes.InvalidArgument, "TaskSpec: restart-delay cannot be negative") } } if rp.Window != nil { win, err := gogotypes.DurationFromProto(rp.Window) if err != nil { return err } if win < 0 { return grpc.Errorf(codes.InvalidArgument, "TaskSpec: restart-window cannot be negative") } } return nil }
func printClusterSummary(cluster *api.Cluster) { w := tabwriter.NewWriter(os.Stdout, 8, 8, 8, ' ', 0) defer w.Flush() common.FprintfIfNotEmpty(w, "ID\t: %s\n", cluster.ID) common.FprintfIfNotEmpty(w, "Name\t: %s\n", cluster.Spec.Annotations.Name) fmt.Fprintf(w, "Orchestration settings:\n") fmt.Fprintf(w, " Task history entries: %d\n", cluster.Spec.Orchestration.TaskHistoryRetentionLimit) heartbeatPeriod, err := gogotypes.DurationFromProto(cluster.Spec.Dispatcher.HeartbeatPeriod) if err == nil { fmt.Fprintf(w, "Dispatcher settings:\n") fmt.Fprintf(w, " Dispatcher heartbeat period: %s\n", heartbeatPeriod.String()) } fmt.Fprintf(w, "Certificate Authority settings:\n") if cluster.Spec.CAConfig.NodeCertExpiry != nil { clusterDuration, err := gogotypes.DurationFromProto(cluster.Spec.CAConfig.NodeCertExpiry) if err != nil { fmt.Fprintf(w, " Certificate Validity Duration: [ERROR PARSING DURATION]\n") } else { fmt.Fprintf(w, " Certificate Validity Duration: %s\n", clusterDuration.String()) } } if len(cluster.Spec.CAConfig.ExternalCAs) > 0 { fmt.Fprintf(w, " External CAs:\n") for _, ca := range cluster.Spec.CAConfig.ExternalCAs { fmt.Fprintf(w, " %s: %s\n", ca.Protocol, ca.URL) } } fmt.Fprintln(w, " Join Tokens:") fmt.Fprintln(w, " Worker:", cluster.RootCA.JoinTokens.Worker) fmt.Fprintln(w, " Manager:", cluster.RootCA.JoinTokens.Manager) if cluster.Spec.TaskDefaults.LogDriver != nil { fmt.Fprintf(w, "Default Log Driver\t: %s\n", cluster.Spec.TaskDefaults.LogDriver.Name) var keys []string if len(cluster.Spec.TaskDefaults.LogDriver.Options) != 0 { for k := range cluster.Spec.TaskDefaults.LogDriver.Options { keys = append(keys, k) } sort.Strings(keys) for _, k := range keys { v := cluster.Spec.TaskDefaults.LogDriver.Options[k] if v != "" { fmt.Fprintf(w, " %s\t: %s\n", k, v) } else { fmt.Fprintf(w, " %s\t\n", k) } } } } }
func (c *containerConfig) healthcheck() *enginecontainer.HealthConfig { hcSpec := c.spec().Healthcheck if hcSpec == nil { return nil } interval, _ := gogotypes.DurationFromProto(hcSpec.Interval) timeout, _ := gogotypes.DurationFromProto(hcSpec.Timeout) return &enginecontainer.HealthConfig{ Test: hcSpec.Test, Interval: interval, Timeout: timeout, Retries: int(hcSpec.Retries), } }
// updateCluster is called when there are cluster changes, and it ensures that the local RootCA is // always aware of changes in clusterExpiry and the Root CA key material func (s *Server) updateCluster(ctx context.Context, cluster *api.Cluster) { s.mu.Lock() s.joinTokens = cluster.RootCA.JoinTokens.Copy() s.mu.Unlock() var err error // If the cluster has a RootCA, let's try to update our SecurityConfig to reflect the latest values rCA := cluster.RootCA if len(rCA.CACert) != 0 && len(rCA.CAKey) != 0 { expiry := DefaultNodeCertExpiration if cluster.Spec.CAConfig.NodeCertExpiry != nil { // NodeCertExpiry exists, let's try to parse the duration out of it clusterExpiry, err := gogotypes.DurationFromProto(cluster.Spec.CAConfig.NodeCertExpiry) if err != nil { log.G(ctx).WithFields(logrus.Fields{ "cluster.id": cluster.ID, "method": "(*Server).updateCluster", }).WithError(err).Warn("failed to parse certificate expiration, using default") } else { // We were able to successfully parse the expiration out of the cluster. expiry = clusterExpiry } } else { // NodeCertExpiry seems to be nil log.G(ctx).WithFields(logrus.Fields{ "cluster.id": cluster.ID, "method": "(*Server).updateCluster", }).WithError(err).Warn("failed to parse certificate expiration, using default") } // Attempt to update our local RootCA with the new parameters err = s.securityConfig.UpdateRootCA(rCA.CACert, rCA.CAKey, expiry) if err != nil { log.G(ctx).WithFields(logrus.Fields{ "cluster.id": cluster.ID, "method": "(*Server).updateCluster", }).WithError(err).Error("updating Root CA failed") } else { log.G(ctx).WithFields(logrus.Fields{ "cluster.id": cluster.ID, "method": "(*Server).updateCluster", }).Debugf("Root CA updated successfully") } } // Update our security config with the list of External CA URLs // from the new cluster state. // TODO(aaronl): In the future, this will be abstracted with an // ExternalCA interface that has different implementations for // different CA types. 
At the moment, only CFSSL is supported. var cfsslURLs []string for _, ca := range cluster.Spec.CAConfig.ExternalCAs { if ca.Protocol == api.ExternalCA_CAProtocolCFSSL { cfsslURLs = append(cfsslURLs, ca.URL) } } s.securityConfig.externalCA.UpdateURLs(cfsslURLs...) }
func (c *containerAdapter) shutdown(ctx context.Context) error { // Default stop grace period to 10s. stopgrace := 10 * time.Second spec := c.container.spec() if spec.StopGracePeriod != nil { stopgrace, _ = gogotypes.DurationFromProto(spec.StopGracePeriod) } return c.client.ContainerStop(ctx, c.container.name(), &stopgrace) }
func validateClusterSpec(spec *api.ClusterSpec) error { if spec == nil { return grpc.Errorf(codes.InvalidArgument, errInvalidArgument.Error()) } // Validate that expiry time being provided is valid, and over our minimum if spec.CAConfig.NodeCertExpiry != nil { expiry, err := gogotypes.DurationFromProto(spec.CAConfig.NodeCertExpiry) if err != nil { return grpc.Errorf(codes.InvalidArgument, errInvalidArgument.Error()) } if expiry < ca.MinNodeCertExpiration { return grpc.Errorf(codes.InvalidArgument, "minimum certificate expiry time is: %s", ca.MinNodeCertExpiration) } } // Validate that AcceptancePolicies only include Secrets that are bcrypted // TODO(diogo): Add a global list of acceptace algorithms. We only support bcrypt for now. if len(spec.AcceptancePolicy.Policies) > 0 { for _, policy := range spec.AcceptancePolicy.Policies { if policy.Secret != nil && strings.ToLower(policy.Secret.Alg) != "bcrypt" { return grpc.Errorf(codes.InvalidArgument, "hashing algorithm is not supported: %s", policy.Secret.Alg) } } } // Validate that heartbeatPeriod time being provided is valid if spec.Dispatcher.HeartbeatPeriod != nil { heartbeatPeriod, err := gogotypes.DurationFromProto(spec.Dispatcher.HeartbeatPeriod) if err != nil { return grpc.Errorf(codes.InvalidArgument, errInvalidArgument.Error()) } if heartbeatPeriod < 0 { return grpc.Errorf(codes.InvalidArgument, "heartbeat time period cannot be a negative duration") } } return nil }
// unmarshalValue converts/copies a value into the target.
// prop may be nil.
//
// It decodes the raw JSON in inputValue into target, handling each kind of
// target in turn:
//   - pointers, by allocating and recursing;
//   - protobuf well-known types (wrappers, Duration, Timestamp; Any is rejected);
//   - time.Time and time.Duration fields;
//   - enums given by name as a JSON string;
//   - nested messages, including oneof fields and an unknown-field check;
//   - slices and []byte;
//   - maps;
//   - 64-bit integers encoded as JSON strings.
// Everything else is delegated to encoding/json.
//
// The first error encountered is returned. target may already be partially
// populated at that point.
func (u *Unmarshaler) unmarshalValue(target reflect.Value, inputValue json.RawMessage, prop *proto.Properties) error {
	targetType := target.Type()

	// Allocate memory for pointer fields.
	if targetType.Kind() == reflect.Ptr {
		target.Set(reflect.New(targetType.Elem()))
		return u.unmarshalValue(target.Elem(), inputValue, prop)
	}

	// Handle well-known types.
	if wkt, ok := target.Addr().Interface().(isWkt); ok {
		switch wkt.XXX_WellKnownType() {
		case "DoubleValue", "FloatValue", "Int64Value", "UInt64Value",
			"Int32Value", "UInt32Value", "BoolValue", "StringValue", "BytesValue":
			// "Wrappers use the same representation in JSON
			//  as the wrapped primitive type, except that null is allowed."
			// encoding/json will turn JSON `null` into Go `nil`,
			// so we don't have to do any extra work.
			return u.unmarshalValue(target.Field(0), inputValue, prop)
		case "Any":
			return fmt.Errorf("unmarshaling Any not supported yet")
		case "Duration":
			unq, err := strconv.Unquote(string(inputValue))
			if err != nil {
				return err
			}
			// Parsed with time.ParseDuration, so any Go duration string
			// (e.g. "1.5s", "2h45m") is accepted, not only the "<n>s" form.
			d, err := time.ParseDuration(unq)
			if err != nil {
				return fmt.Errorf("bad Duration: %v", err)
			}
			// Split into whole seconds (field 0) and leftover nanoseconds
			// (field 1). Go's / and % truncate toward zero, so both parts
			// carry the sign of d.
			ns := d.Nanoseconds()
			s := ns / 1e9
			ns %= 1e9
			target.Field(0).SetInt(s)
			target.Field(1).SetInt(ns)
			return nil
		case "Timestamp":
			unq, err := strconv.Unquote(string(inputValue))
			if err != nil {
				return err
			}
			t, err := time.Parse(time.RFC3339Nano, unq)
			if err != nil {
				return fmt.Errorf("bad Timestamp: %v", err)
			}
			// Field 0 receives Unix seconds, field 1 the nanosecond offset.
			target.Field(0).SetInt(int64(t.Unix()))
			target.Field(1).SetInt(int64(t.Nanosecond()))
			return nil
		}
	}

	// Native time.Time fields: decode via a types.Timestamp, then convert.
	if t, ok := target.Addr().Interface().(*time.Time); ok {
		ts := &types.Timestamp{}
		if err := u.unmarshalValue(reflect.ValueOf(ts).Elem(), inputValue, prop); err != nil {
			return err
		}
		tt, err := types.TimestampFromProto(ts)
		if err != nil {
			return err
		}
		*t = tt
		return nil
	}

	// Native time.Duration fields: decode via a types.Duration, then convert.
	if d, ok := target.Addr().Interface().(*time.Duration); ok {
		dur := &types.Duration{}
		if err := u.unmarshalValue(reflect.ValueOf(dur).Elem(), inputValue, prop); err != nil {
			return err
		}
		dd, err := types.DurationFromProto(dur)
		if err != nil {
			return err
		}
		*d = dd
		return nil
	}

	// Handle enums, which have an underlying type of int32,
	// and may appear as strings.
	// The case of an enum appearing as a number is handled
	// at the bottom of this function.
	// NOTE(review): inputValue[0] would panic on an empty RawMessage;
	// presumably callers never pass one — confirm.
	if inputValue[0] == '"' && prop != nil && prop.Enum != "" {
		vmap := proto.EnumValueMap(prop.Enum)
		// Don't need to do unquoting; valid enum names
		// are from a limited character set.
		s := inputValue[1 : len(inputValue)-1]
		n, ok := vmap[string(s)]
		if !ok {
			return fmt.Errorf("unknown value %q for enum %s", s, prop.Enum)
		}
		if target.Kind() == reflect.Ptr { // proto2
			target.Set(reflect.New(targetType.Elem()))
			target = target.Elem()
		}
		target.SetInt(int64(n))
		return nil
	}

	// Handle nested messages.
	if targetType.Kind() == reflect.Struct {
		var jsonFields map[string]json.RawMessage
		if err := json.Unmarshal(inputValue, &jsonFields); err != nil {
			return err
		}

		// consumeField looks up and removes the JSON entry for prop, so that
		// whatever remains in jsonFields afterwards is unconsumed.
		consumeField := func(prop *proto.Properties) (json.RawMessage, bool) {
			// Be liberal in what names we accept; both orig_name and camelName are okay.
			fieldNames := acceptedJSONFieldNames(prop)

			vOrig, okOrig := jsonFields[fieldNames.orig]
			vCamel, okCamel := jsonFields[fieldNames.camel]
			if !okOrig && !okCamel {
				return nil, false
			}
			// If, for some reason, both are present in the data, favour the camelName.
			var raw json.RawMessage
			if okOrig {
				raw = vOrig
				delete(jsonFields, fieldNames.orig)
			}
			if okCamel {
				raw = vCamel
				delete(jsonFields, fieldNames.camel)
			}
			return raw, true
		}

		sprops := proto.GetProperties(targetType)
		for i := 0; i < target.NumField(); i++ {
			ft := target.Type().Field(i)
			// XXX_-prefixed fields are skipped and never filled from JSON.
			if strings.HasPrefix(ft.Name, "XXX_") {
				continue
			}

			valueForField, ok := consumeField(sprops.Prop[i])
			if !ok {
				continue
			}

			if err := u.unmarshalValue(target.Field(i), valueForField, sprops.Prop[i]); err != nil {
				return err
			}
		}
		// Check for any oneof fields.
		if len(jsonFields) > 0 {
			for _, oop := range sprops.OneofTypes {
				raw, ok := consumeField(oop.Prop)
				if !ok {
					continue
				}
				// Allocate the oneof wrapper, store it in the interface field,
				// then decode into the wrapper's first field.
				nv := reflect.New(oop.Type.Elem())
				target.Field(oop.Field).Set(nv)
				if err := u.unmarshalValue(nv.Elem().Field(0), raw, oop.Prop); err != nil {
					return err
				}
			}
		}
		if !u.AllowUnknownFields && len(jsonFields) > 0 {
			// Pick any field to be the scapegoat.
			// (Map iteration order is random, so which one is reported varies.)
			var f string
			for fname := range jsonFields {
				f = fname
				break
			}
			return fmt.Errorf("unknown field %q in %v", f, targetType)
		}
		return nil
	}

	// Handle arrays
	if targetType.Kind() == reflect.Slice {
		if targetType.Elem().Kind() == reflect.Uint8 {
			outRef := reflect.New(targetType)
			outVal := outRef.Interface()
			//CustomType with underlying type []byte
			if _, ok := outVal.(interface {
				UnmarshalJSON([]byte) error
			}); ok {
				if err := json.Unmarshal(inputValue, outVal); err != nil {
					return err
				}
				target.Set(outRef.Elem())
				return nil
			}
			// Special case for encoded bytes. Pre-go1.5 doesn't support unmarshalling
			// strings into aliased []byte types.
			// https://github.com/golang/go/commit/4302fd0409da5e4f1d71471a6770dacdc3301197
			// https://github.com/golang/go/commit/c60707b14d6be26bf4213114d13070bff00d0b0a
			var out []byte
			if err := json.Unmarshal(inputValue, &out); err != nil {
				return err
			}
			target.SetBytes(out)
			return nil
		}

		var slc []json.RawMessage
		if err := json.Unmarshal(inputValue, &slc); err != nil {
			return err
		}
		// NOTE(review): this local shadows the builtin len for the rest of
		// this if-block only.
		len := len(slc)
		target.Set(reflect.MakeSlice(targetType, len, len))
		for i := 0; i < len; i++ {
			if err := u.unmarshalValue(target.Index(i), slc[i], prop); err != nil {
				return err
			}
		}
		return nil
	}

	// Handle maps (whose keys are always strings)
	if targetType.Kind() == reflect.Map {
		var mp map[string]json.RawMessage
		if err := json.Unmarshal(inputValue, &mp); err != nil {
			return err
		}
		target.Set(reflect.MakeMap(targetType))
		// keyprop/valprop are currently always nil (see TODO below).
		var keyprop, valprop *proto.Properties
		if prop != nil {
			// These could still be nil if the protobuf metadata is broken somehow.
			// TODO: This won't work because the fields are unexported.
			// We should probably just reparse them.
			//keyprop, valprop = prop.mkeyprop, prop.mvalprop
		}
		for ks, raw := range mp {
			// Unmarshal map key. The core json library already decoded the key into a
			// string, so we handle that specially. Other types were quoted post-serialization.
			var k reflect.Value
			if targetType.Key().Kind() == reflect.String {
				k = reflect.ValueOf(ks)
			} else {
				k = reflect.New(targetType.Key()).Elem()
				if err := u.unmarshalValue(k, json.RawMessage(ks), keyprop); err != nil {
					return err
				}
			}

			// Named string key types need an explicit conversion.
			if !k.Type().AssignableTo(targetType.Key()) {
				k = k.Convert(targetType.Key())
			}

			// Unmarshal map value.
			v := reflect.New(targetType.Elem()).Elem()
			if err := u.unmarshalValue(v, raw, valprop); err != nil {
				return err
			}
			target.SetMapIndex(k, v)
		}
		return nil
	}

	// 64-bit integers can be encoded as strings. In this case we drop
	// the quotes and proceed as normal.
	isNum := targetType.Kind() == reflect.Int64 || targetType.Kind() == reflect.Uint64
	if isNum && strings.HasPrefix(string(inputValue), `"`) {
		inputValue = inputValue[1 : len(inputValue)-1]
	}

	// Use the encoding/json for parsing other value types.
	return json.Unmarshal(inputValue, target.Addr().Interface())
}
func (r *Orchestrator) initTasks(ctx context.Context, readTx store.ReadTx) error { tasks, err := store.FindTasks(readTx, store.All) if err != nil { return err } for _, t := range tasks { if t.NodeID != "" { n := store.GetNode(readTx, t.NodeID) if invalidNode(n) && t.Status.State <= api.TaskStateRunning && t.DesiredState <= api.TaskStateRunning { r.restartTasks[t.ID] = struct{}{} } } } _, err = r.store.Batch(func(batch *store.Batch) error { for _, t := range tasks { if t.ServiceID == "" { continue } // TODO(aluzzardi): We should NOT retrieve the service here. service := store.GetService(readTx, t.ServiceID) if service == nil { // Service was deleted err := batch.Update(func(tx store.Tx) error { return store.DeleteTask(tx, t.ID) }) if err != nil { log.G(ctx).WithError(err).Error("failed to set task desired state to dead") } continue } // TODO(aluzzardi): This is shady. We should have a more generic condition. if t.DesiredState != api.TaskStateReady || !orchestrator.IsReplicatedService(service) { continue } restartDelay := orchestrator.DefaultRestartDelay if t.Spec.Restart != nil && t.Spec.Restart.Delay != nil { var err error restartDelay, err = gogotypes.DurationFromProto(t.Spec.Restart.Delay) if err != nil { log.G(ctx).WithError(err).Error("invalid restart delay") restartDelay = orchestrator.DefaultRestartDelay } } if restartDelay != 0 { timestamp, err := gogotypes.TimestampFromProto(t.Status.Timestamp) if err == nil { restartTime := timestamp.Add(restartDelay) calculatedRestartDelay := restartTime.Sub(time.Now()) if calculatedRestartDelay < restartDelay { restartDelay = calculatedRestartDelay } if restartDelay > 0 { _ = batch.Update(func(tx store.Tx) error { t := store.GetTask(tx, t.ID) // TODO(aluzzardi): This is shady as well. We should have a more generic condition. 
if t == nil || t.DesiredState != api.TaskStateReady { return nil } r.restarts.DelayStart(ctx, tx, nil, t.ID, restartDelay, true) return nil }) continue } } else { log.G(ctx).WithError(err).Error("invalid status timestamp") } } // Start now err := batch.Update(func(tx store.Tx) error { return r.restarts.StartNow(tx, t.ID) }) if err != nil { log.G(ctx).WithError(err).WithField("task.id", t.ID).Error("moving task out of delayed state failed") } } return nil }) return err }
// Run runs dispatcher tasks which should be run on leader dispatcher. // Dispatcher can be stopped with cancelling ctx or calling Stop(). func (d *Dispatcher) Run(ctx context.Context) error { d.mu.Lock() if d.isRunning() { d.mu.Unlock() return errors.New("dispatcher is already running") } ctx = log.WithModule(ctx, "dispatcher") if err := d.markNodesUnknown(ctx); err != nil { log.G(ctx).Errorf(`failed to move all nodes to "unknown" state: %v`, err) } configWatcher, cancel, err := store.ViewAndWatch( d.store, func(readTx store.ReadTx) error { clusters, err := store.FindClusters(readTx, store.ByName(store.DefaultClusterName)) if err != nil { return err } if err == nil && len(clusters) == 1 { heartbeatPeriod, err := gogotypes.DurationFromProto(clusters[0].Spec.Dispatcher.HeartbeatPeriod) if err == nil && heartbeatPeriod > 0 { d.config.HeartbeatPeriod = heartbeatPeriod } if clusters[0].NetworkBootstrapKeys != nil { d.networkBootstrapKeys = clusters[0].NetworkBootstrapKeys } } return nil }, state.EventUpdateCluster{}, ) if err != nil { d.mu.Unlock() return err } // set queues here to guarantee that Close will close them d.mgrQueue = watch.NewQueue() d.keyMgrQueue = watch.NewQueue() peerWatcher, peerCancel := d.cluster.SubscribePeers() defer peerCancel() d.lastSeenManagers = getWeightedPeers(d.cluster) defer cancel() d.ctx, d.cancel = context.WithCancel(ctx) ctx = d.ctx d.wg.Add(1) defer d.wg.Done() d.mu.Unlock() publishManagers := func(peers []*api.Peer) { var mgrs []*api.WeightedPeer for _, p := range peers { mgrs = append(mgrs, &api.WeightedPeer{ Peer: p, Weight: remotes.DefaultObservationWeight, }) } d.mu.Lock() d.lastSeenManagers = mgrs d.mu.Unlock() d.mgrQueue.Publish(mgrs) } batchTimer := time.NewTimer(maxBatchInterval) defer batchTimer.Stop() for { select { case ev := <-peerWatcher: publishManagers(ev.([]*api.Peer)) case <-d.processUpdatesTrigger: d.processUpdates(ctx) batchTimer.Reset(maxBatchInterval) case <-batchTimer.C: d.processUpdates(ctx) 
batchTimer.Reset(maxBatchInterval) case v := <-configWatcher: cluster := v.(state.EventUpdateCluster) d.mu.Lock() if cluster.Cluster.Spec.Dispatcher.HeartbeatPeriod != nil { // ignore error, since Spec has passed validation before heartbeatPeriod, _ := gogotypes.DurationFromProto(cluster.Cluster.Spec.Dispatcher.HeartbeatPeriod) if heartbeatPeriod != d.config.HeartbeatPeriod { // only call d.nodes.updatePeriod when heartbeatPeriod changes d.config.HeartbeatPeriod = heartbeatPeriod d.nodes.updatePeriod(d.config.HeartbeatPeriod, d.config.HeartbeatEpsilon, d.config.GracePeriodMultiplier) } } d.networkBootstrapKeys = cluster.Cluster.NetworkBootstrapKeys d.mu.Unlock() d.keyMgrQueue.Publish(cluster.Cluster.NetworkBootstrapKeys) case <-ctx.Done(): return nil } } }
// Run starts the update and returns only once its complete or cancelled. func (u *Updater) Run(ctx context.Context, slots []orchestrator.Slot) { defer close(u.doneChan) service := u.newService // If the update is in a PAUSED state, we should not do anything. if service.UpdateStatus != nil && (service.UpdateStatus.State == api.UpdateStatus_PAUSED || service.UpdateStatus.State == api.UpdateStatus_ROLLBACK_PAUSED) { return } var dirtySlots []orchestrator.Slot for _, slot := range slots { if u.isSlotDirty(slot) { dirtySlots = append(dirtySlots, slot) } } // Abort immediately if all tasks are clean. if len(dirtySlots) == 0 { if service.UpdateStatus != nil && (service.UpdateStatus.State == api.UpdateStatus_UPDATING || service.UpdateStatus.State == api.UpdateStatus_ROLLBACK_STARTED) { u.completeUpdate(ctx, service.ID) } return } // If there's no update in progress, we are starting one. if service.UpdateStatus == nil { u.startUpdate(ctx, service.ID) } parallelism := 0 if service.Spec.Update != nil { parallelism = int(service.Spec.Update.Parallelism) } if parallelism == 0 { // TODO(aluzzardi): We could try to optimize unlimited parallelism by performing updates in a single // goroutine using a batch transaction. parallelism = len(dirtySlots) } // Start the workers. 
slotQueue := make(chan orchestrator.Slot) wg := sync.WaitGroup{} wg.Add(parallelism) for i := 0; i < parallelism; i++ { go func() { u.worker(ctx, slotQueue) wg.Done() }() } failureAction := api.UpdateConfig_PAUSE allowedFailureFraction := float32(0) monitoringPeriod := defaultMonitor if service.Spec.Update != nil { failureAction = service.Spec.Update.FailureAction allowedFailureFraction = service.Spec.Update.MaxFailureRatio if service.Spec.Update.Monitor != nil { var err error monitoringPeriod, err = gogotypes.DurationFromProto(service.Spec.Update.Monitor) if err != nil { monitoringPeriod = defaultMonitor } } } var failedTaskWatch chan events.Event if failureAction != api.UpdateConfig_CONTINUE { var cancelWatch func() failedTaskWatch, cancelWatch = state.Watch( u.store.WatchQueue(), state.EventUpdateTask{ Task: &api.Task{ServiceID: service.ID, Status: api.TaskStatus{State: api.TaskStateRunning}}, Checks: []state.TaskCheckFunc{state.TaskCheckServiceID, state.TaskCheckStateGreaterThan}, }, ) defer cancelWatch() } stopped := false failedTasks := make(map[string]struct{}) totalFailures := 0 failureTriggersAction := func(failedTask *api.Task) bool { // Ignore tasks we have already seen as failures. if _, found := failedTasks[failedTask.ID]; found { return false } // If this failed/completed task is one that we // created as part of this update, we should // follow the failure action. 
u.updatedTasksMu.Lock() startedAt, found := u.updatedTasks[failedTask.ID] u.updatedTasksMu.Unlock() if found && (startedAt.IsZero() || time.Since(startedAt) <= monitoringPeriod) { failedTasks[failedTask.ID] = struct{}{} totalFailures++ if float32(totalFailures)/float32(len(dirtySlots)) > allowedFailureFraction { switch failureAction { case api.UpdateConfig_PAUSE: stopped = true message := fmt.Sprintf("update paused due to failure or early termination of task %s", failedTask.ID) u.pauseUpdate(ctx, service.ID, message) return true case api.UpdateConfig_ROLLBACK: // Never roll back a rollback if service.UpdateStatus != nil && service.UpdateStatus.State == api.UpdateStatus_ROLLBACK_STARTED { message := fmt.Sprintf("rollback paused due to failure or early termination of task %s", failedTask.ID) u.pauseUpdate(ctx, service.ID, message) return true } stopped = true message := fmt.Sprintf("update rolled back due to failure or early termination of task %s", failedTask.ID) u.rollbackUpdate(ctx, service.ID, message) return true } } } return false } slotsLoop: for _, slot := range dirtySlots { retryLoop: for { // Wait for a worker to pick up the task or abort the update, whichever comes first. select { case <-u.stopChan: stopped = true break slotsLoop case ev := <-failedTaskWatch: if failureTriggersAction(ev.(state.EventUpdateTask).Task) { break slotsLoop } case slotQueue <- slot: break retryLoop } } } close(slotQueue) wg.Wait() if !stopped { // Keep watching for task failures for one more monitoringPeriod, // before declaring the update complete. doneMonitoring := time.After(monitoringPeriod) monitorLoop: for { select { case <-u.stopChan: stopped = true break monitorLoop case <-doneMonitoring: break monitorLoop case ev := <-failedTaskWatch: if failureTriggersAction(ev.(state.EventUpdateTask).Task) { break monitorLoop } } } } // TODO(aaronl): Potentially roll back the service if not enough tasks // have reached RUNNING by this point. 
if !stopped { u.completeUpdate(ctx, service.ID) } }
func (r *Supervisor) shouldRestart(ctx context.Context, t *api.Task, service *api.Service) bool { // TODO(aluzzardi): This function should not depend on `service`. condition := orchestrator.RestartCondition(t) if condition != api.RestartOnAny && (condition != api.RestartOnFailure || t.Status.State == api.TaskStateCompleted) { return false } if t.Spec.Restart == nil || t.Spec.Restart.MaxAttempts == 0 { return true } instanceTuple := instanceTuple{ instance: t.Slot, serviceID: t.ServiceID, } // Instance is not meaningful for "global" tasks, so they need to be // indexed by NodeID. if orchestrator.IsGlobalService(service) { instanceTuple.nodeID = t.NodeID } r.mu.Lock() defer r.mu.Unlock() restartInfo := r.history[instanceTuple] if restartInfo == nil { return true } if t.Spec.Restart.Window == nil || (t.Spec.Restart.Window.Seconds == 0 && t.Spec.Restart.Window.Nanos == 0) { return restartInfo.totalRestarts < t.Spec.Restart.MaxAttempts } if restartInfo.restartedInstances == nil { return true } window, err := gogotypes.DurationFromProto(t.Spec.Restart.Window) if err != nil { log.G(ctx).WithError(err).Error("invalid restart lookback window") return restartInfo.totalRestarts < t.Spec.Restart.MaxAttempts } lookback := time.Now().Add(-window) var next *list.Element for e := restartInfo.restartedInstances.Front(); e != nil; e = next { next = e.Next() if e.Value.(restartedInstance).timestamp.After(lookback) { break } restartInfo.restartedInstances.Remove(e) } numRestarts := uint64(restartInfo.restartedInstances.Len()) if numRestarts == 0 { restartInfo.restartedInstances = nil } return numRestarts < t.Spec.Restart.MaxAttempts }
// Restart initiates a new task to replace t if appropriate under the service's // restart policy. func (r *Supervisor) Restart(ctx context.Context, tx store.Tx, cluster *api.Cluster, service *api.Service, t api.Task) error { // TODO(aluzzardi): This function should not depend on `service`. // Is the old task still in the process of restarting? If so, wait for // its restart delay to elapse, to avoid tight restart loops (for // example, when the image doesn't exist). r.mu.Lock() oldDelay, ok := r.delays[t.ID] if ok { if !oldDelay.waiter { oldDelay.waiter = true go r.waitRestart(ctx, oldDelay, cluster, t.ID) } r.mu.Unlock() return nil } r.mu.Unlock() // Sanity check: was the task shut down already by a separate call to // Restart? If so, we must avoid restarting it, because this will create // an extra task. This should never happen unless there is a bug. if t.DesiredState > api.TaskStateRunning { return errors.New("Restart called on task that was already shut down") } t.DesiredState = api.TaskStateShutdown err := store.UpdateTask(tx, &t) if err != nil { log.G(ctx).WithError(err).Errorf("failed to set task desired state to dead") return err } if !r.shouldRestart(ctx, &t, service) { return nil } var restartTask *api.Task if orchestrator.IsReplicatedService(service) { restartTask = orchestrator.NewTask(cluster, service, t.Slot, "") } else if orchestrator.IsGlobalService(service) { restartTask = orchestrator.NewTask(cluster, service, 0, t.NodeID) } else { log.G(ctx).Error("service not supported by restart supervisor") return nil } n := store.GetNode(tx, t.NodeID) restartTask.DesiredState = api.TaskStateReady var restartDelay time.Duration // Restart delay is not applied to drained nodes if n == nil || n.Spec.Availability != api.NodeAvailabilityDrain { if t.Spec.Restart != nil && t.Spec.Restart.Delay != nil { var err error restartDelay, err = gogotypes.DurationFromProto(t.Spec.Restart.Delay) if err != nil { log.G(ctx).WithError(err).Error("invalid restart delay; using 
default") restartDelay = orchestrator.DefaultRestartDelay } } else { restartDelay = orchestrator.DefaultRestartDelay } } waitStop := true // Normally we wait for the old task to stop running, but we skip this // if the old task is already dead or the node it's assigned to is down. if (n != nil && n.Status.State == api.NodeStatus_DOWN) || t.Status.State > api.TaskStateRunning { waitStop = false } if err := store.CreateTask(tx, restartTask); err != nil { log.G(ctx).WithError(err).WithField("task.id", restartTask.ID).Error("task create failed") return err } r.recordRestartHistory(restartTask) r.DelayStart(ctx, tx, &t, restartTask.ID, restartDelay, waitStop) return nil }