Esempio n. 1
0
// validateRestartPolicy verifies that the optional Delay and Window fields of
// a restart policy convert to non-negative durations.
//
// A nil policy, or a policy with neither field set, is accepted. An error
// returned by gogotypes.DurationFromProto is passed through unchanged; a
// negative duration is reported as a gRPC error with code InvalidArgument.
func validateRestartPolicy(rp *api.RestartPolicy) error {
	if rp == nil {
		return nil
	}

	// The delay is checked before the window, so a policy that is invalid in
	// both fields reports the delay problem.
	if rp.Delay != nil {
		switch restartDelay, convErr := gogotypes.DurationFromProto(rp.Delay); {
		case convErr != nil:
			return convErr
		case restartDelay < 0:
			return grpc.Errorf(codes.InvalidArgument, "TaskSpec: restart-delay cannot be negative")
		}
	}

	if rp.Window != nil {
		switch restartWindow, convErr := gogotypes.DurationFromProto(rp.Window); {
		case convErr != nil:
			return convErr
		case restartWindow < 0:
			return grpc.Errorf(codes.InvalidArgument, "TaskSpec: restart-window cannot be negative")
		}
	}

	return nil
}
Esempio n. 2
0
// printClusterSummary writes a human-readable summary of cluster to standard
// output through a tabwriter, which is flushed when the function returns.
//
// The summary lists the cluster ID and name, orchestration and dispatcher
// settings, certificate authority settings (including join tokens) and, when
// one is configured, the default log driver with its options sorted by name.
func printClusterSummary(cluster *api.Cluster) {
	out := tabwriter.NewWriter(os.Stdout, 8, 8, 8, ' ', 0)
	defer out.Flush()

	common.FprintfIfNotEmpty(out, "ID\t: %s\n", cluster.ID)
	common.FprintfIfNotEmpty(out, "Name\t: %s\n", cluster.Spec.Annotations.Name)
	fmt.Fprintf(out, "Orchestration settings:\n")
	fmt.Fprintf(out, "  Task history entries: %d\n", cluster.Spec.Orchestration.TaskHistoryRetentionLimit)

	// The dispatcher section is left out entirely when the heartbeat period
	// cannot be converted.
	if period, err := gogotypes.DurationFromProto(cluster.Spec.Dispatcher.HeartbeatPeriod); err == nil {
		fmt.Fprintf(out, "Dispatcher settings:\n")
		fmt.Fprintf(out, "  Dispatcher heartbeat period: %s\n", period.String())
	}

	fmt.Fprintf(out, "Certificate Authority settings:\n")
	if expiry := cluster.Spec.CAConfig.NodeCertExpiry; expiry != nil {
		if validity, err := gogotypes.DurationFromProto(expiry); err == nil {
			fmt.Fprintf(out, "  Certificate Validity Duration: %s\n", validity.String())
		} else {
			fmt.Fprintf(out, "  Certificate Validity Duration: [ERROR PARSING DURATION]\n")
		}
	}
	if externalCAs := cluster.Spec.CAConfig.ExternalCAs; len(externalCAs) > 0 {
		fmt.Fprintf(out, "  External CAs:\n")
		for _, externalCA := range externalCAs {
			fmt.Fprintf(out, "    %s: %s\n", externalCA.Protocol, externalCA.URL)
		}
	}

	fmt.Fprintln(out, "  Join Tokens:")
	fmt.Fprintln(out, "    Worker:", cluster.RootCA.JoinTokens.Worker)
	fmt.Fprintln(out, "    Manager:", cluster.RootCA.JoinTokens.Manager)

	driver := cluster.Spec.TaskDefaults.LogDriver
	if driver == nil {
		return
	}
	fmt.Fprintf(out, "Default Log Driver\t: %s\n", driver.Name)

	// Map iteration order is random, so sort the option names to keep the
	// output stable between runs.
	optionNames := make([]string, 0, len(driver.Options))
	for name := range driver.Options {
		optionNames = append(optionNames, name)
	}
	sort.Strings(optionNames)

	for _, name := range optionNames {
		if value := driver.Options[name]; value == "" {
			fmt.Fprintf(out, "  %s\t\n", name)
		} else {
			fmt.Fprintf(out, "  %s\t: %s\n", name, value)
		}
	}
}
Esempio n. 3
0
// healthcheck translates the Healthcheck section of the container spec into
// an engine HealthConfig. It returns nil when the spec has no health check.
func (c *containerConfig) healthcheck() *enginecontainer.HealthConfig {
	hc := c.spec().Healthcheck
	if hc == nil {
		return nil
	}

	cfg := &enginecontainer.HealthConfig{
		Test:    hc.Test,
		Retries: int(hc.Retries),
	}
	// Conversion errors are ignored: whatever duration DurationFromProto
	// returns alongside an error is stored as-is.
	cfg.Interval, _ = gogotypes.DurationFromProto(hc.Interval)
	cfg.Timeout, _ = gogotypes.DurationFromProto(hc.Timeout)
	return cfg
}
Esempio n. 4
0
// updateCluster is called when there are cluster changes, and it ensures that the local RootCA is
// always aware of changes in clusterExpiry and the Root CA key material.
//
// It performs three steps:
//  1. replaces s.joinTokens with a copy of the cluster's join tokens (under s.mu);
//  2. if the cluster carries both a CA certificate and a CA key, updates the
//     security config's root CA, using the cluster's NodeCertExpiry when it is
//     set and parseable and DefaultNodeCertExpiration otherwise;
//  3. passes the URLs of all CFSSL external CAs to the security config's
//     external CA.
//
// Failures are logged and never returned.
func (s *Server) updateCluster(ctx context.Context, cluster *api.Cluster) {
	s.mu.Lock()
	s.joinTokens = cluster.RootCA.JoinTokens.Copy()
	s.mu.Unlock()

	// Every log entry emitted below carries the same identifying fields.
	logger := log.G(ctx).WithFields(logrus.Fields{
		"cluster.id": cluster.ID,
		"method":     "(*Server).updateCluster",
	})

	// If the cluster has a RootCA, let's try to update our SecurityConfig to reflect the latest values
	rCA := cluster.RootCA
	if len(rCA.CACert) != 0 && len(rCA.CAKey) != 0 {
		expiry := DefaultNodeCertExpiration
		if cluster.Spec.CAConfig.NodeCertExpiry != nil {
			// NodeCertExpiry exists, let's try to parse the duration out of it
			clusterExpiry, err := gogotypes.DurationFromProto(cluster.Spec.CAConfig.NodeCertExpiry)
			if err != nil {
				logger.WithError(err).Warn("failed to parse certificate expiration, using default")
			} else {
				// We were able to successfully parse the expiration out of the cluster.
				expiry = clusterExpiry
			}
		} else {
			// NodeCertExpiry is nil: nothing was parsed and there is no error
			// to attach, so say what actually happened.
			logger.Warn("no certificate expiration specified, using default")
		}

		// Attempt to update our local RootCA with the new parameters
		if err := s.securityConfig.UpdateRootCA(rCA.CACert, rCA.CAKey, expiry); err != nil {
			logger.WithError(err).Error("updating Root CA failed")
		} else {
			logger.Debugf("Root CA updated successfully")
		}
	}

	// Update our security config with the list of External CA URLs
	// from the new cluster state.

	// TODO(aaronl): In the future, this will be abstracted with an
	// ExternalCA interface that has different implementations for
	// different CA types. At the moment, only CFSSL is supported.
	var cfsslURLs []string
	for _, ca := range cluster.Spec.CAConfig.ExternalCAs {
		if ca.Protocol == api.ExternalCA_CAProtocolCFSSL {
			cfsslURLs = append(cfsslURLs, ca.URL)
		}
	}

	s.securityConfig.externalCA.UpdateURLs(cfsslURLs...)
}
Esempio n. 5
0
// shutdown asks the engine client to stop the container, allowing it the
// spec's StopGracePeriod to exit, or 10 seconds when the spec sets none or
// sets a value that cannot be converted to a time.Duration.
//
// It returns whatever error ContainerStop returns.
func (c *containerAdapter) shutdown(ctx context.Context) error {
	// Default stop grace period to 10s.
	stopgrace := 10 * time.Second
	if period := c.container.spec().StopGracePeriod; period != nil {
		// Only override the default when the conversion succeeds. The value
		// returned alongside an error is not a meaningful duration, and using
		// it would silently replace the default grace period.
		if converted, err := gogotypes.DurationFromProto(period); err == nil {
			stopgrace = converted
		}
	}
	return c.client.ContainerStop(ctx, c.container.name(), &stopgrace)
}
Esempio n. 6
0
// validateClusterSpec checks a cluster spec before it is accepted.
//
// It returns a gRPC error with code InvalidArgument when:
//   - spec is nil;
//   - CAConfig.NodeCertExpiry is set but cannot be converted to a duration, or
//     is shorter than ca.MinNodeCertExpiration;
//   - an acceptance policy carries a secret whose algorithm, lower-cased, is
//     not "bcrypt";
//   - Dispatcher.HeartbeatPeriod is set but cannot be converted to a duration,
//     or is negative.
//
// It returns nil otherwise.
func validateClusterSpec(spec *api.ClusterSpec) error {
	// The error text is passed as an argument to a constant "%s" format rather
	// than used as the format string, so a '%' in the message can never be
	// misinterpreted as a formatting verb.
	if spec == nil {
		return grpc.Errorf(codes.InvalidArgument, "%s", errInvalidArgument.Error())
	}

	// Validate that expiry time being provided is valid, and over our minimum
	if spec.CAConfig.NodeCertExpiry != nil {
		expiry, err := gogotypes.DurationFromProto(spec.CAConfig.NodeCertExpiry)
		if err != nil {
			return grpc.Errorf(codes.InvalidArgument, "%s", errInvalidArgument.Error())
		}
		if expiry < ca.MinNodeCertExpiration {
			return grpc.Errorf(codes.InvalidArgument, "minimum certificate expiry time is: %s", ca.MinNodeCertExpiration)
		}
	}

	// Validate that AcceptancePolicies only include Secrets that are bcrypted
	// TODO(diogo): Add a global list of acceptance algorithms. We only support bcrypt for now.
	for _, policy := range spec.AcceptancePolicy.Policies {
		if policy.Secret != nil && strings.ToLower(policy.Secret.Alg) != "bcrypt" {
			return grpc.Errorf(codes.InvalidArgument, "hashing algorithm is not supported: %s", policy.Secret.Alg)
		}
	}

	// Validate that heartbeatPeriod time being provided is valid
	if spec.Dispatcher.HeartbeatPeriod != nil {
		heartbeatPeriod, err := gogotypes.DurationFromProto(spec.Dispatcher.HeartbeatPeriod)
		if err != nil {
			return grpc.Errorf(codes.InvalidArgument, "%s", errInvalidArgument.Error())
		}
		if heartbeatPeriod < 0 {
			return grpc.Errorf(codes.InvalidArgument, "heartbeat time period cannot be a negative duration")
		}
	}

	return nil
}
Esempio n. 7
0
// unmarshalValue converts/copies a value into the target.
// prop may be nil.
//
// inputValue is the raw JSON for a single value. The function dispatches on
// the target's type, in this order: pointers, well-known types, time.Time and
// time.Duration, enums given as strings, structs (nested messages), slices,
// maps, and finally everything else, which is handed to encoding/json.
// It recurses for nested values and returns the first error encountered.
func (u *Unmarshaler) unmarshalValue(target reflect.Value, inputValue json.RawMessage, prop *proto.Properties) error {
	targetType := target.Type()

	// Allocate memory for pointer fields.
	if targetType.Kind() == reflect.Ptr {
		target.Set(reflect.New(targetType.Elem()))
		return u.unmarshalValue(target.Elem(), inputValue, prop)
	}

	// Handle well-known types.
	// A type is treated as well-known when its pointer implements isWkt.
	if wkt, ok := target.Addr().Interface().(isWkt); ok {
		switch wkt.XXX_WellKnownType() {
		case "DoubleValue", "FloatValue", "Int64Value", "UInt64Value",
			"Int32Value", "UInt32Value", "BoolValue", "StringValue", "BytesValue":
			// "Wrappers use the same representation in JSON
			//  as the wrapped primitive type, except that null is allowed."
			// encoding/json will turn JSON `null` into Go `nil`,
			// so we don't have to do any extra work.
			return u.unmarshalValue(target.Field(0), inputValue, prop)
		case "Any":
			return fmt.Errorf("unmarshaling Any not supported yet")
		case "Duration":
			// The JSON form is a quoted string accepted by time.ParseDuration.
			unq, err := strconv.Unquote(string(inputValue))
			if err != nil {
				return err
			}
			d, err := time.ParseDuration(unq)
			if err != nil {
				return fmt.Errorf("bad Duration: %v", err)
			}
			// Split the total nanoseconds into whole seconds (field 0) and
			// the remaining nanoseconds (field 1).
			ns := d.Nanoseconds()
			s := ns / 1e9
			ns %= 1e9
			target.Field(0).SetInt(s)
			target.Field(1).SetInt(ns)
			return nil
		case "Timestamp":
			// The JSON form is a quoted RFC 3339 string.
			unq, err := strconv.Unquote(string(inputValue))
			if err != nil {
				return err
			}
			t, err := time.Parse(time.RFC3339Nano, unq)
			if err != nil {
				return fmt.Errorf("bad Timestamp: %v", err)
			}
			// Field 0 receives the Unix seconds, field 1 the nanosecond part.
			target.Field(0).SetInt(int64(t.Unix()))
			target.Field(1).SetInt(int64(t.Nanosecond()))
			return nil
		}
	}

	// A plain time.Time target is decoded as a Timestamp message and then
	// converted.
	if t, ok := target.Addr().Interface().(*time.Time); ok {
		ts := &types.Timestamp{}
		if err := u.unmarshalValue(reflect.ValueOf(ts).Elem(), inputValue, prop); err != nil {
			return err
		}
		tt, err := types.TimestampFromProto(ts)
		if err != nil {
			return err
		}
		*t = tt
		return nil
	}

	// Likewise, a plain time.Duration target is decoded as a Duration message
	// and then converted.
	if d, ok := target.Addr().Interface().(*time.Duration); ok {
		dur := &types.Duration{}
		if err := u.unmarshalValue(reflect.ValueOf(dur).Elem(), inputValue, prop); err != nil {
			return err
		}
		dd, err := types.DurationFromProto(dur)
		if err != nil {
			return err
		}
		*d = dd
		return nil
	}

	// Handle enums, which have an underlying type of int32,
	// and may appear as strings.
	// The case of an enum appearing as a number is handled
	// at the bottom of this function.
	// NOTE(review): inputValue[0] is read without a length check, so this
	// assumes inputValue is never empty — confirm against callers.
	if inputValue[0] == '"' && prop != nil && prop.Enum != "" {
		vmap := proto.EnumValueMap(prop.Enum)
		// Don't need to do unquoting; valid enum names
		// are from a limited character set.
		s := inputValue[1 : len(inputValue)-1]
		n, ok := vmap[string(s)]
		if !ok {
			return fmt.Errorf("unknown value %q for enum %s", s, prop.Enum)
		}
		// NOTE(review): pointer targets already returned at the top of this
		// function, so this branch looks unreachable — confirm before removing.
		if target.Kind() == reflect.Ptr { // proto2
			target.Set(reflect.New(targetType.Elem()))
			target = target.Elem()
		}
		target.SetInt(int64(n))
		return nil
	}

	// Handle nested messages.
	if targetType.Kind() == reflect.Struct {
		var jsonFields map[string]json.RawMessage
		if err := json.Unmarshal(inputValue, &jsonFields); err != nil {
			return err
		}

		// consumeField looks up the JSON value for a field under either of its
		// accepted names and removes the matched entries from jsonFields, so
		// that whatever remains afterwards is an unknown field.
		consumeField := func(prop *proto.Properties) (json.RawMessage, bool) {
			// Be liberal in what names we accept; both orig_name and camelName are okay.
			fieldNames := acceptedJSONFieldNames(prop)

			vOrig, okOrig := jsonFields[fieldNames.orig]
			vCamel, okCamel := jsonFields[fieldNames.camel]
			if !okOrig && !okCamel {
				return nil, false
			}
			// If, for some reason, both are present in the data, favour the camelName.
			var raw json.RawMessage
			if okOrig {
				raw = vOrig
				delete(jsonFields, fieldNames.orig)
			}
			if okCamel {
				raw = vCamel
				delete(jsonFields, fieldNames.camel)
			}
			return raw, true
		}

		sprops := proto.GetProperties(targetType)
		for i := 0; i < target.NumField(); i++ {
			ft := target.Type().Field(i)
			// Fields whose names start with "XXX_" are skipped.
			if strings.HasPrefix(ft.Name, "XXX_") {
				continue
			}
			valueForField, ok := consumeField(sprops.Prop[i])
			if !ok {
				continue
			}

			if err := u.unmarshalValue(target.Field(i), valueForField, sprops.Prop[i]); err != nil {
				return err
			}
		}
		// Check for any oneof fields.
		if len(jsonFields) > 0 {
			for _, oop := range sprops.OneofTypes {
				raw, ok := consumeField(oop.Prop)
				if !ok {
					continue
				}
				// Allocate the oneof wrapper, store it in the oneof field and
				// decode the JSON into the wrapper's first field.
				nv := reflect.New(oop.Type.Elem())
				target.Field(oop.Field).Set(nv)
				if err := u.unmarshalValue(nv.Elem().Field(0), raw, oop.Prop); err != nil {
					return err
				}
			}
		}
		// Anything still left in jsonFields matched no regular or oneof field.
		if !u.AllowUnknownFields && len(jsonFields) > 0 {
			// Pick any field to be the scapegoat.
			var f string
			for fname := range jsonFields {
				f = fname
				break
			}
			return fmt.Errorf("unknown field %q in %v", f, targetType)
		}
		return nil
	}

	// Handle arrays
	if targetType.Kind() == reflect.Slice {
		// Slices of bytes are decoded as a single value rather than element
		// by element.
		if targetType.Elem().Kind() == reflect.Uint8 {
			outRef := reflect.New(targetType)
			outVal := outRef.Interface()
			//CustomType with underlying type []byte
			if _, ok := outVal.(interface {
				UnmarshalJSON([]byte) error
			}); ok {
				if err := json.Unmarshal(inputValue, outVal); err != nil {
					return err
				}
				target.Set(outRef.Elem())
				return nil
			}
			// Special case for encoded bytes. Pre-go1.5 doesn't support unmarshalling
			// strings into aliased []byte types.
			// https://github.com/golang/go/commit/4302fd0409da5e4f1d71471a6770dacdc3301197
			// https://github.com/golang/go/commit/c60707b14d6be26bf4213114d13070bff00d0b0a
			var out []byte
			if err := json.Unmarshal(inputValue, &out); err != nil {
				return err
			}
			target.SetBytes(out)
			return nil
		}

		// Any other slice: split the JSON array into raw elements and decode
		// each one into a freshly allocated slice of the same length.
		var slc []json.RawMessage
		if err := json.Unmarshal(inputValue, &slc); err != nil {
			return err
		}
		// Note: this local shadows the builtin len for the rest of this block.
		len := len(slc)
		target.Set(reflect.MakeSlice(targetType, len, len))
		for i := 0; i < len; i++ {
			if err := u.unmarshalValue(target.Index(i), slc[i], prop); err != nil {
				return err
			}
		}
		return nil
	}

	// Handle maps (whose keys are always strings)
	if targetType.Kind() == reflect.Map {
		var mp map[string]json.RawMessage
		if err := json.Unmarshal(inputValue, &mp); err != nil {
			return err
		}
		target.Set(reflect.MakeMap(targetType))
		// keyprop and valprop are never assigned (the only assignment is
		// commented out below), so keys and values are decoded with nil
		// properties.
		var keyprop, valprop *proto.Properties
		if prop != nil {
			// These could still be nil if the protobuf metadata is broken somehow.
			// TODO: This won't work because the fields are unexported.
			// We should probably just reparse them.
			//keyprop, valprop = prop.mkeyprop, prop.mvalprop
		}
		for ks, raw := range mp {
			// Unmarshal map key. The core json library already decoded the key into a
			// string, so we handle that specially. Other types were quoted post-serialization.
			var k reflect.Value
			if targetType.Key().Kind() == reflect.String {
				k = reflect.ValueOf(ks)
			} else {
				k = reflect.New(targetType.Key()).Elem()
				if err := u.unmarshalValue(k, json.RawMessage(ks), keyprop); err != nil {
					return err
				}
			}

			// Convert the key when its type is not directly assignable to the
			// map's key type.
			if !k.Type().AssignableTo(targetType.Key()) {
				k = k.Convert(targetType.Key())
			}

			// Unmarshal map value.
			v := reflect.New(targetType.Elem()).Elem()
			if err := u.unmarshalValue(v, raw, valprop); err != nil {
				return err
			}
			target.SetMapIndex(k, v)
		}
		return nil
	}

	// 64-bit integers can be encoded as strings. In this case we drop
	// the quotes and proceed as normal.
	isNum := targetType.Kind() == reflect.Int64 || targetType.Kind() == reflect.Uint64
	if isNum && strings.HasPrefix(string(inputValue), `"`) {
		inputValue = inputValue[1 : len(inputValue)-1]
	}

	// Use the encoding/json for parsing other value types.
	return json.Unmarshal(inputValue, target.Addr().Interface())
}
Esempio n. 8
0
// initTasks reconciles the tasks visible in readTx and returns the error, if
// any, from listing the tasks or from the store batch.
//
// First, every task that is assigned to an invalid node and whose observed and
// desired states are both at most RUNNING is recorded in r.restartTasks.
//
// Then, in a single store batch, for every task that belongs to a service:
//   - if the service no longer exists, the task is deleted;
//   - if the task's desired state is not READY, or the service is not
//     replicated, the task is left alone;
//   - otherwise the task is started: after the remainder of its restart delay
//     (measured from the task's status timestamp) when that remainder is
//     positive, and immediately when it is not.
//
// Failures of individual batch updates are logged and do not abort the batch.
func (r *Orchestrator) initTasks(ctx context.Context, readTx store.ReadTx) error {
	tasks, err := store.FindTasks(readTx, store.All)
	if err != nil {
		return err
	}
	for _, t := range tasks {
		if t.NodeID != "" {
			n := store.GetNode(readTx, t.NodeID)
			if invalidNode(n) && t.Status.State <= api.TaskStateRunning && t.DesiredState <= api.TaskStateRunning {
				r.restartTasks[t.ID] = struct{}{}
			}
		}
	}

	_, err = r.store.Batch(func(batch *store.Batch) error {
		for _, t := range tasks {
			if t.ServiceID == "" {
				continue
			}

			// TODO(aluzzardi): We should NOT retrieve the service here.
			service := store.GetService(readTx, t.ServiceID)
			if service == nil {
				// Service was deleted
				err := batch.Update(func(tx store.Tx) error {
					return store.DeleteTask(tx, t.ID)
				})
				if err != nil {
					// This update deletes the task, so the message says so.
					log.G(ctx).WithError(err).WithField("task.id", t.ID).Error("failed to delete task of removed service")
				}
				continue
			}
			// TODO(aluzzardi): This is shady. We should have a more generic condition.
			if t.DesiredState != api.TaskStateReady || !orchestrator.IsReplicatedService(service) {
				continue
			}
			restartDelay := orchestrator.DefaultRestartDelay
			if t.Spec.Restart != nil && t.Spec.Restart.Delay != nil {
				var err error
				restartDelay, err = gogotypes.DurationFromProto(t.Spec.Restart.Delay)
				if err != nil {
					log.G(ctx).WithError(err).Error("invalid restart delay")
					restartDelay = orchestrator.DefaultRestartDelay
				}
			}
			if restartDelay != 0 {
				timestamp, err := gogotypes.TimestampFromProto(t.Status.Timestamp)
				if err == nil {
					// Only wait for the part of the delay that has not
					// already elapsed since the status timestamp.
					restartTime := timestamp.Add(restartDelay)
					calculatedRestartDelay := restartTime.Sub(time.Now())
					if calculatedRestartDelay < restartDelay {
						restartDelay = calculatedRestartDelay
					}
					if restartDelay > 0 {
						updateErr := batch.Update(func(tx store.Tx) error {
							// Re-read the task inside the transaction; it may
							// have changed or disappeared in the meantime.
							current := store.GetTask(tx, t.ID)
							// TODO(aluzzardi): This is shady as well. We should have a more generic condition.
							if current == nil || current.DesiredState != api.TaskStateReady {
								return nil
							}
							r.restarts.DelayStart(ctx, tx, nil, current.ID, restartDelay, true)
							return nil
						})
						if updateErr != nil {
							// Report the failure instead of discarding it; the
							// task is still skipped, as before.
							log.G(ctx).WithError(updateErr).WithField("task.id", t.ID).Error("scheduling delayed task start failed")
						}
						continue
					}
				} else {
					log.G(ctx).WithError(err).Error("invalid status timestamp")
				}
			}

			// Start now
			err := batch.Update(func(tx store.Tx) error {
				return r.restarts.StartNow(tx, t.ID)
			})
			if err != nil {
				log.G(ctx).WithError(err).WithField("task.id", t.ID).Error("moving task out of delayed state failed")
			}
		}
		return nil
	})

	return err
}
Esempio n. 9
0
// Run runs dispatcher tasks which should be run on leader dispatcher.
// Dispatcher can be stopped with cancelling ctx or calling Stop().
//
// Run returns an error if the dispatcher is already running or if the cluster
// configuration watch cannot be set up. Otherwise it blocks in an event loop
// and returns nil once the dispatcher's context is done.
func (d *Dispatcher) Run(ctx context.Context) error {
	// d.mu is held from here until setup is complete; every early return
	// below unlocks it first.
	d.mu.Lock()
	if d.isRunning() {
		d.mu.Unlock()
		return errors.New("dispatcher is already running")
	}
	ctx = log.WithModule(ctx, "dispatcher")
	// A failure to mark nodes as unknown is logged but does not stop startup.
	if err := d.markNodesUnknown(ctx); err != nil {
		log.G(ctx).Errorf(`failed to move all nodes to "unknown" state: %v`, err)
	}
	// Read the current cluster configuration and subscribe to cluster update
	// events from the store.
	configWatcher, cancel, err := store.ViewAndWatch(
		d.store,
		func(readTx store.ReadTx) error {
			clusters, err := store.FindClusters(readTx, store.ByName(store.DefaultClusterName))
			if err != nil {
				return err
			}
			// err is always nil at this point because of the return above.
			if err == nil && len(clusters) == 1 {
				// The stored heartbeat period is applied only when it
				// converts cleanly and is positive.
				heartbeatPeriod, err := gogotypes.DurationFromProto(clusters[0].Spec.Dispatcher.HeartbeatPeriod)
				if err == nil && heartbeatPeriod > 0 {
					d.config.HeartbeatPeriod = heartbeatPeriod
				}
				if clusters[0].NetworkBootstrapKeys != nil {
					d.networkBootstrapKeys = clusters[0].NetworkBootstrapKeys
				}
			}
			return nil
		},
		state.EventUpdateCluster{},
	)
	if err != nil {
		d.mu.Unlock()
		return err
	}
	// set queues here to guarantee that Close will close them
	d.mgrQueue = watch.NewQueue()
	d.keyMgrQueue = watch.NewQueue()

	peerWatcher, peerCancel := d.cluster.SubscribePeers()
	defer peerCancel()
	d.lastSeenManagers = getWeightedPeers(d.cluster)

	defer cancel()
	// From here on the loop uses the dispatcher's own cancellable context.
	d.ctx, d.cancel = context.WithCancel(ctx)
	ctx = d.ctx
	d.wg.Add(1)
	defer d.wg.Done()
	d.mu.Unlock()

	// publishManagers wraps each peer with the default observation weight,
	// records the list as d.lastSeenManagers and publishes it on d.mgrQueue.
	publishManagers := func(peers []*api.Peer) {
		var mgrs []*api.WeightedPeer
		for _, p := range peers {
			mgrs = append(mgrs, &api.WeightedPeer{
				Peer:   p,
				Weight: remotes.DefaultObservationWeight,
			})
		}
		d.mu.Lock()
		d.lastSeenManagers = mgrs
		d.mu.Unlock()
		d.mgrQueue.Publish(mgrs)
	}

	batchTimer := time.NewTimer(maxBatchInterval)
	defer batchTimer.Stop()

	for {
		select {
		case ev := <-peerWatcher:
			publishManagers(ev.([]*api.Peer))
		case <-d.processUpdatesTrigger:
			// Updates are processed either when explicitly triggered or when
			// the batch timer fires; both paths restart the timer.
			d.processUpdates(ctx)
			batchTimer.Reset(maxBatchInterval)
		case <-batchTimer.C:
			d.processUpdates(ctx)
			batchTimer.Reset(maxBatchInterval)
		case v := <-configWatcher:
			cluster := v.(state.EventUpdateCluster)
			d.mu.Lock()
			if cluster.Cluster.Spec.Dispatcher.HeartbeatPeriod != nil {
				// ignore error, since Spec has passed validation before
				// NOTE(review): unlike the initial read above, this path does
				// not check that the period is positive — confirm intended.
				heartbeatPeriod, _ := gogotypes.DurationFromProto(cluster.Cluster.Spec.Dispatcher.HeartbeatPeriod)
				if heartbeatPeriod != d.config.HeartbeatPeriod {
					// only call d.nodes.updatePeriod when heartbeatPeriod changes
					d.config.HeartbeatPeriod = heartbeatPeriod
					d.nodes.updatePeriod(d.config.HeartbeatPeriod, d.config.HeartbeatEpsilon, d.config.GracePeriodMultiplier)
				}
			}
			d.networkBootstrapKeys = cluster.Cluster.NetworkBootstrapKeys
			d.mu.Unlock()
			d.keyMgrQueue.Publish(cluster.Cluster.NetworkBootstrapKeys)
		case <-ctx.Done():
			return nil
		}
	}
}
Esempio n. 10
0
// Run starts the update and returns only once it is complete or cancelled.
// u.doneChan is closed when Run returns.
//
// Run does nothing when the service's update is in a paused state. Otherwise
// it collects the dirty slots among slots and hands them to a pool of
// workers whose size is the service's update parallelism (or the number of
// dirty slots when the parallelism is 0). Unless the configured failure
// action is CONTINUE, it also watches for failures of tasks created by this
// update; once the failure ratio exceeds the allowed fraction, the update is
// paused or rolled back according to the failure action.
//
// After all slots have been handed out, Run keeps watching for failures for
// one more monitoring period and then marks the update complete, unless the
// update was stopped, paused or rolled back in the meantime.
func (u *Updater) Run(ctx context.Context, slots []orchestrator.Slot) {
	defer close(u.doneChan)

	service := u.newService

	// If the update is in a PAUSED state, we should not do anything.
	if service.UpdateStatus != nil &&
		(service.UpdateStatus.State == api.UpdateStatus_PAUSED ||
			service.UpdateStatus.State == api.UpdateStatus_ROLLBACK_PAUSED) {
		return
	}

	var dirtySlots []orchestrator.Slot
	for _, slot := range slots {
		if u.isSlotDirty(slot) {
			dirtySlots = append(dirtySlots, slot)
		}
	}
	// Abort immediately if all tasks are clean.
	if len(dirtySlots) == 0 {
		if service.UpdateStatus != nil &&
			(service.UpdateStatus.State == api.UpdateStatus_UPDATING ||
				service.UpdateStatus.State == api.UpdateStatus_ROLLBACK_STARTED) {
			u.completeUpdate(ctx, service.ID)
		}
		return
	}

	// If there's no update in progress, we are starting one.
	if service.UpdateStatus == nil {
		u.startUpdate(ctx, service.ID)
	}

	parallelism := 0
	if service.Spec.Update != nil {
		parallelism = int(service.Spec.Update.Parallelism)
	}
	if parallelism == 0 {
		// TODO(aluzzardi): We could try to optimize unlimited parallelism by performing updates in a single
		// goroutine using a batch transaction.
		parallelism = len(dirtySlots)
	}

	// Start the workers.
	slotQueue := make(chan orchestrator.Slot)
	wg := sync.WaitGroup{}
	wg.Add(parallelism)
	for i := 0; i < parallelism; i++ {
		go func() {
			u.worker(ctx, slotQueue)
			wg.Done()
		}()
	}

	// Defaults used when the service spec has no update configuration.
	failureAction := api.UpdateConfig_PAUSE
	allowedFailureFraction := float32(0)
	monitoringPeriod := defaultMonitor

	if service.Spec.Update != nil {
		failureAction = service.Spec.Update.FailureAction
		allowedFailureFraction = service.Spec.Update.MaxFailureRatio

		if service.Spec.Update.Monitor != nil {
			var err error
			monitoringPeriod, err = gogotypes.DurationFromProto(service.Spec.Update.Monitor)
			if err != nil {
				monitoringPeriod = defaultMonitor
			}
		}
	}

	// failedTaskWatch stays nil when the failure action is CONTINUE; a
	// receive from a nil channel never proceeds, so the corresponding select
	// cases below are simply never taken.
	var failedTaskWatch chan events.Event

	if failureAction != api.UpdateConfig_CONTINUE {
		var cancelWatch func()
		failedTaskWatch, cancelWatch = state.Watch(
			u.store.WatchQueue(),
			state.EventUpdateTask{
				Task:   &api.Task{ServiceID: service.ID, Status: api.TaskStatus{State: api.TaskStateRunning}},
				Checks: []state.TaskCheckFunc{state.TaskCheckServiceID, state.TaskCheckStateGreaterThan},
			},
		)
		defer cancelWatch()
	}

	// stopped records that the update must not be marked complete: it was
	// stopped through u.stopChan, paused, or rolled back.
	stopped := false
	failedTasks := make(map[string]struct{})
	totalFailures := 0

	// failureTriggersAction accounts for a failed task and reports whether
	// the failure action was carried out, in which case the caller must stop
	// processing.
	failureTriggersAction := func(failedTask *api.Task) bool {
		// Ignore tasks we have already seen as failures.
		if _, found := failedTasks[failedTask.ID]; found {
			return false
		}

		// If this failed/completed task is one that we
		// created as part of this update, we should
		// follow the failure action.
		u.updatedTasksMu.Lock()
		startedAt, found := u.updatedTasks[failedTask.ID]
		u.updatedTasksMu.Unlock()

		if found && (startedAt.IsZero() || time.Since(startedAt) <= monitoringPeriod) {
			failedTasks[failedTask.ID] = struct{}{}
			totalFailures++
			if float32(totalFailures)/float32(len(dirtySlots)) > allowedFailureFraction {
				switch failureAction {
				case api.UpdateConfig_PAUSE:
					stopped = true
					message := fmt.Sprintf("update paused due to failure or early termination of task %s", failedTask.ID)
					u.pauseUpdate(ctx, service.ID, message)
					return true
				case api.UpdateConfig_ROLLBACK:
					// Both outcomes below end this run, so mark it stopped up
					// front. Without this, pausing a rollback would fall
					// through to the final monitoring phase and then to
					// completeUpdate, marking the just-paused update complete.
					stopped = true
					// Never roll back a rollback
					if service.UpdateStatus != nil && service.UpdateStatus.State == api.UpdateStatus_ROLLBACK_STARTED {
						message := fmt.Sprintf("rollback paused due to failure or early termination of task %s", failedTask.ID)
						u.pauseUpdate(ctx, service.ID, message)
						return true
					}
					message := fmt.Sprintf("update rolled back due to failure or early termination of task %s", failedTask.ID)
					u.rollbackUpdate(ctx, service.ID, message)
					return true
				}
			}
		}

		return false
	}

slotsLoop:
	for _, slot := range dirtySlots {
	retryLoop:
		for {
			// Wait for a worker to pick up the task or abort the update, whichever comes first.
			select {
			case <-u.stopChan:
				stopped = true
				break slotsLoop
			case ev := <-failedTaskWatch:
				if failureTriggersAction(ev.(state.EventUpdateTask).Task) {
					break slotsLoop
				}
			case slotQueue <- slot:
				break retryLoop
			}
		}
	}

	// Closing the queue lets the workers finish; wait for all of them.
	close(slotQueue)
	wg.Wait()

	if !stopped {
		// Keep watching for task failures for one more monitoringPeriod,
		// before declaring the update complete.
		doneMonitoring := time.After(monitoringPeriod)
	monitorLoop:
		for {
			select {
			case <-u.stopChan:
				stopped = true
				break monitorLoop
			case <-doneMonitoring:
				break monitorLoop
			case ev := <-failedTaskWatch:
				if failureTriggersAction(ev.(state.EventUpdateTask).Task) {
					break monitorLoop
				}
			}
		}
	}

	// TODO(aaronl): Potentially roll back the service if not enough tasks
	// have reached RUNNING by this point.

	if !stopped {
		u.completeUpdate(ctx, service.ID)
	}
}
Esempio n. 11
0
// shouldRestart reports whether task t may be restarted under its restart
// policy.
//
// The restart condition must be "any", or "on failure" with a task that did
// not complete. With no policy, or a policy with MaxAttempts of 0, restarts
// are unlimited. Otherwise the recorded restart history for the task's
// instance is consulted: without a restart window the total number of
// restarts is compared with MaxAttempts; with a window, only the restarts
// recorded inside the window count, and older records are dropped from the
// history as a side effect.
//
// r.mu is held while the history is read and pruned.
func (r *Supervisor) shouldRestart(ctx context.Context, t *api.Task, service *api.Service) bool {
	// TODO(aluzzardi): This function should not depend on `service`.

	switch condition := orchestrator.RestartCondition(t); {
	case condition == api.RestartOnAny:
		// Always eligible.
	case condition == api.RestartOnFailure && t.Status.State != api.TaskStateCompleted:
		// Eligible because the task did not complete.
	default:
		return false
	}

	policy := t.Spec.Restart
	if policy == nil || policy.MaxAttempts == 0 {
		return true
	}

	key := instanceTuple{
		instance:  t.Slot,
		serviceID: t.ServiceID,
	}
	// Instance is not meaningful for "global" tasks, so they need to be
	// indexed by NodeID.
	if orchestrator.IsGlobalService(service) {
		key.nodeID = t.NodeID
	}

	r.mu.Lock()
	defer r.mu.Unlock()

	info := r.history[key]
	if info == nil {
		return true
	}

	// With no window (or a zero one), every restart ever recorded counts.
	if w := policy.Window; w == nil || (w.Seconds == 0 && w.Nanos == 0) {
		return info.totalRestarts < policy.MaxAttempts
	}

	if info.restartedInstances == nil {
		return true
	}

	window, err := gogotypes.DurationFromProto(policy.Window)
	if err != nil {
		log.G(ctx).WithError(err).Error("invalid restart lookback window")
		return info.totalRestarts < policy.MaxAttempts
	}
	cutoff := time.Now().Add(-window)

	// Drop records from the front of the list until the first one that is
	// newer than the cutoff.
	recent := info.restartedInstances
	var oldest *list.Element
	for oldest = recent.Front(); oldest != nil; oldest = recent.Front() {
		if oldest.Value.(restartedInstance).timestamp.After(cutoff) {
			break
		}
		recent.Remove(oldest)
	}

	remaining := uint64(recent.Len())
	if remaining == 0 {
		info.restartedInstances = nil
	}

	return remaining < policy.MaxAttempts
}
Esempio n. 12
0
// Restart initiates a new task to replace t if appropriate under the service's
// restart policy.
//
// If t is itself still waiting out a restart delay, Restart only makes sure a
// waiter goroutine exists for it and returns nil. Otherwise t's desired state
// is set to SHUTDOWN and, if shouldRestart allows it, a replacement task is
// created in desired state READY, recorded in the restart history and handed
// to DelayStart.
//
// It returns an error when t was already shut down, or when updating t or
// creating the replacement task in the store fails.
func (r *Supervisor) Restart(ctx context.Context, tx store.Tx, cluster *api.Cluster, service *api.Service, t api.Task) error {
	// TODO(aluzzardi): This function should not depend on `service`.

	// Is the old task still in the process of restarting? If so, wait for
	// its restart delay to elapse, to avoid tight restart loops (for
	// example, when the image doesn't exist).
	r.mu.Lock()
	pending, restarting := r.delays[t.ID]
	if restarting && !pending.waiter {
		pending.waiter = true
		go r.waitRestart(ctx, pending, cluster, t.ID)
	}
	r.mu.Unlock()
	if restarting {
		return nil
	}

	// Sanity check: was the task shut down already by a separate call to
	// Restart? If so, we must avoid restarting it, because this will create
	// an extra task. This should never happen unless there is a bug.
	if t.DesiredState > api.TaskStateRunning {
		return errors.New("Restart called on task that was already shut down")
	}

	t.DesiredState = api.TaskStateShutdown
	if err := store.UpdateTask(tx, &t); err != nil {
		log.G(ctx).WithError(err).Errorf("failed to set task desired state to dead")
		return err
	}

	if !r.shouldRestart(ctx, &t, service) {
		return nil
	}

	// Replicated services reuse the slot; global services reuse the node.
	var restartTask *api.Task
	switch {
	case orchestrator.IsReplicatedService(service):
		restartTask = orchestrator.NewTask(cluster, service, t.Slot, "")
	case orchestrator.IsGlobalService(service):
		restartTask = orchestrator.NewTask(cluster, service, 0, t.NodeID)
	default:
		log.G(ctx).Error("service not supported by restart supervisor")
		return nil
	}

	n := store.GetNode(tx, t.NodeID)

	restartTask.DesiredState = api.TaskStateReady

	// Restart delay is not applied to drained nodes
	var restartDelay time.Duration
	drained := n != nil && n.Spec.Availability == api.NodeAvailabilityDrain
	if !drained {
		restartDelay = orchestrator.DefaultRestartDelay
		if t.Spec.Restart != nil && t.Spec.Restart.Delay != nil {
			configured, err := gogotypes.DurationFromProto(t.Spec.Restart.Delay)
			if err != nil {
				log.G(ctx).WithError(err).Error("invalid restart delay; using default")
			} else {
				restartDelay = configured
			}
		}
	}

	// Normally we wait for the old task to stop running, but we skip this
	// if the old task is already dead or the node it's assigned to is down.
	nodeDown := n != nil && n.Status.State == api.NodeStatus_DOWN
	waitStop := !nodeDown && t.Status.State <= api.TaskStateRunning

	if err := store.CreateTask(tx, restartTask); err != nil {
		log.G(ctx).WithError(err).WithField("task.id", restartTask.ID).Error("task create failed")
		return err
	}

	r.recordRestartHistory(restartTask)

	r.DelayStart(ctx, tx, &t, restartTask.ID, restartDelay, waitStop)
	return nil
}