Beispiel #1
0
// updateOpenFlow sets up the OpenFlow tables to get packets from containers into the
// OVN controlled bridge.  The OpenFlow tables are organized as follows.
//
//     - Table 0 checks for packets destined to the IP address of a label with MAC
//     0A:00:00:00:00:00 (obtained by OVN faking out ARP) and uses the OF multipath
//     action to balance packets across n links, where n is the number of
//     containers implementing the label.  The result is stored in NXM_NX_REG0.
//     This is done using a symmetric l3/4 hash, so transport connections should
//     remain intact.
//
//     - Table 1 reads NXM_NX_REG0 and changes the destination MAC address to one
//     of the MACs of the containers that implement the label.
//
// XXX: The multipath action doesn't perform well.  We should migrate away from it,
// choosing datapath recirculation instead.
func updateOpenFlow(dk docker.Client, odb ovsdb.Ovsdb, containers []db.Container,
	labels []db.Label, connections []db.Connection) {
	desired, err := generateTargetOpenFlow(dk, odb, containers, labels, connections)
	if err != nil {
		log.WithError(err).Error("failed to get target OpenFlow flows")
		return
	}

	installed, err := generateCurrentOpenFlow(dk)
	if err != nil {
		log.WithError(err).Error("failed to get current OpenFlow flows")
		return
	}

	// Flows present on both sides are left alone.  Unmatched installed flows
	// are deleted, and unmatched desired flows are added.
	_, stale, missing := join.HashJoin(installed, desired, nil, nil)

	for _, flow := range stale {
		delErr := deleteOFRule(dk, flow.(OFRule))
		if delErr != nil {
			log.WithError(delErr).Error("error deleting OpenFlow flow")
		}
	}

	for _, flow := range missing {
		addErr := addOFRule(dk, flow.(OFRule))
		if addErr != nil {
			log.WithError(addErr).Error("error adding OpenFlow flow")
		}
	}
}
Beispiel #2
0
// updatePlacements makes the placement rows in `view` agree with the placements
// queried from `spec`.  Two placements are considered the same when both their
// TargetLabel and Rule are equal.
func updatePlacements(view db.Database, spec stitch.Stitch) {
	type placementKey struct {
		tl   string
		rule db.PlacementRule
	}
	keyOf := func(val interface{}) interface{} {
		p := val.(db.Placement)
		return placementKey{tl: p.TargetLabel, rule: p.Rule}
	}

	desired := toDBPlacements(spec.QueryPlacements())
	existing := db.PlacementSlice(view.SelectFromPlacement(nil))
	_, missing, stale := join.HashJoin(desired, existing, keyOf, keyOf)

	// Insert a fresh row for every desired placement the database lacks.
	for _, intf := range missing {
		want := intf.(db.Placement)

		row := view.InsertPlacement()
		row.TargetLabel = want.TargetLabel
		row.Rule = want.Rule
		view.Commit(row)
	}

	// Drop the rows that are no longer part of the spec.
	for _, intf := range stale {
		view.Remove(intf.(db.Placement))
	}
}
Beispiel #3
0
// updatePlacements makes the placement rows in `view` agree with spec.Placements.
// Rows are compared on every field except the ID.
func updatePlacements(view db.Database, spec stitch.Stitch) {
	// The join key is the placement itself with the ID zeroed out.
	withoutID := func(val interface{}) interface{} {
		placement := val.(db.Placement)
		placement.ID = 0
		return placement
	}

	var desired db.PlacementSlice
	for _, sp := range spec.Placements {
		converted := db.Placement{
			TargetLabel: sp.TargetLabel,
			Exclusive:   sp.Exclusive,
			OtherLabel:  sp.OtherLabel,
			Provider:    sp.Provider,
			Size:        sp.Size,
			Region:      sp.Region,
		}
		desired = append(desired, converted)
	}

	existing := db.PlacementSlice(view.SelectFromPlacement(nil))
	_, missing, stale := join.HashJoin(desired, existing, withoutID, withoutID)

	// A new row is inserted only to obtain its ID; the desired placement is
	// then committed under that ID.
	for _, intf := range missing {
		row := intf.(db.Placement)
		row.ID = view.InsertPlacement().ID
		view.Commit(row)
	}

	for _, intf := range stale {
		view.Remove(intf.(db.Placement))
	}
}
Beispiel #4
0
// syncDir makes the entries of `dir`, and the corresponding keys under `path` in
// `store`, match `idsArg`.  Keys of `dir` that are not in `idsArg` are deleted
// from the store and from `dir`.  IDs that are not yet in `dir` get a directory
// created in the store and an empty entry in `dir`.
func syncDir(store Store, dir directory, path string, idsArg []string) {
	_, staleKeys, newIDs := join.HashJoin(join.StringSlice(dir.keys()),
		join.StringSlice(idsArg), nil, nil)

	// Etcd failure leads to a bunch of useless errors.  Therefore only the most
	// recent failure is remembered, and it is logged once at the end.
	var lastFailure string

	for _, elem := range staleKeys {
		name := elem.(string)
		target := path + "/" + name
		if err := store.Delete(target); err != nil {
			lastFailure = fmt.Sprintf("Failed to delete %s: %s", target, err)
		}
		// The local entry is dropped whether or not the store delete worked.
		delete(dir, name)
	}

	for _, elem := range newIDs {
		name := elem.(string)
		if _, exists := dir[name]; exists {
			continue
		}

		target := path + "/" + name
		err := store.Mkdir(target)
		if err != nil {
			lastFailure = fmt.Sprintf("Failed to create dir %s: %s", target, err)
			continue
		}
		dir[name] = map[string]string{}
	}

	if lastFailure != "" {
		log.Error(lastFailure)
	}
}
Beispiel #5
0
// updateConnections makes the connection rows in `view` agree with
// spec.Connections.  Connections are matched on From, To, MinPort and MaxPort.
func updateConnections(view db.Database, spec stitch.Stitch) {
	// Database rows are projected onto stitch.Connection so that they can be
	// compared directly against the connections from the spec.
	asStitch := func(val interface{}) interface{} {
		row := val.(db.Connection)
		return stitch.Connection{
			From:    row.From,
			To:      row.To,
			MinPort: row.MinPort,
			MaxPort: row.MaxPort,
		}
	}

	desired := stitch.ConnectionSlice(spec.Connections)
	existing := db.ConnectionSlice(view.SelectFromConnection(nil))
	matched, missing, stale := join.HashJoin(desired, existing, nil, asStitch)

	for _, intf := range stale {
		view.Remove(intf.(db.Connection))
	}

	// Pair every missing connection with a freshly inserted row so that it is
	// filled in by the loop below together with the already matched ones.
	for _, want := range missing {
		fresh := join.Pair{L: want, R: view.InsertConnection()}
		matched = append(matched, fresh)
	}

	for _, pair := range matched {
		want := pair.L.(stitch.Connection)
		row := pair.R.(db.Connection)

		row.From, row.To = want.From, want.To
		row.MinPort, row.MaxPort = want.MinPort, want.MaxPort
		view.Commit(row)
	}
}
Beispiel #6
0
// updateDBLabels makes the label rows in `view` agree with the labels carried by
// the containers in `etcdData`.  `ipMap` maps a container's StitchID (as a decimal
// string) to that container's IP.
func updateDBLabels(view db.Database, etcdData storeData, ipMap map[string]string) {
	// Collect, per label: the set of label names, the IP to use for labels that
	// are not multi-host, and the IPs of all containers carrying the label.
	names := map[string]struct{}{}
	singleHostIPs := map[string]string{}
	memberIPs := map[string][]string{}
	for _, c := range etcdData.containers {
		ip := ipMap[strconv.Itoa(c.StitchID)]
		for _, name := range c.Labels {
			names[name] = struct{}{}
			if _, multi := etcdData.multiHost[name]; !multi {
				singleHostIPs[name] = ip
			}

			// The ordering of IPs between function calls will be consistent
			// because the containers are sorted by their StitchIDs when
			// inserted into etcd.
			memberIPs[name] = append(memberIPs[name], ip)
		}
	}

	wanted := join.StringSlice{}
	for name := range names {
		wanted = append(wanted, name)
	}

	byLabel := func(val interface{}) interface{} {
		return val.(db.Label).Label
	}
	existing := db.LabelSlice(view.SelectFromLabel(nil))
	matched, stale, missing := join.HashJoin(existing, wanted, byLabel, nil)

	for _, intf := range stale {
		view.Remove(intf.(db.Label))
	}

	// New labels get a fresh row and are then filled in like the matched ones.
	for _, name := range missing {
		matched = append(matched, join.Pair{L: view.InsertLabel(), R: name})
	}

	for _, pair := range matched {
		row := pair.L.(db.Label)
		row.Label = pair.R.(string)
		if hostIP, multi := etcdData.multiHost[row.Label]; multi {
			row.IP = hostIP
			row.MultiHost = true
		} else {
			row.IP = singleHostIPs[row.Label]
			row.MultiHost = false
		}
		row.ContainerIPs = memberIPs[row.Label]

		view.Commit(row)
	}
}
Beispiel #7
0
// syncACLs makes the ACLs on lSwitch agree with `connections`.  The expected set
// is a priority 0 "drop" rule matching "ip", plus a priority 1 "allow" rule for
// every connection that does not involve the public internet label.  ACLs are
// compared on their Core.
func syncACLs(ovsdbClient ovsdb.Client, connections []db.Connection) {
	current, err := ovsdbClient.ListACLs(lSwitch)
	if err != nil {
		log.WithError(err).Error("Failed to list ACLs")
		return
	}

	dropAll := ovsdb.ACL{
		Core: ovsdb.ACLCore{
			Action:   "drop",
			Match:    "ip",
			Priority: 0,
		},
	}
	desired := directedACLs(dropAll)

	for _, conn := range connections {
		touchesPublic := conn.From == stitch.PublicInternetLabel ||
			conn.To == stitch.PublicInternetLabel
		if touchesPublic {
			continue
		}

		allow := ovsdb.ACL{
			Core: ovsdb.ACLCore{
				Action:   "allow",
				Match:    matchString(conn),
				Priority: 1,
			},
		}
		desired = append(desired, directedACLs(allow)...)
	}

	coreOf := func(intf interface{}) interface{} {
		return intf.(ovsdb.ACL).Core
	}
	_, missing, stale := join.HashJoin(ovsdbACLSlice(desired),
		ovsdbACLSlice(current), coreOf, coreOf)

	for _, intf := range stale {
		delErr := ovsdbClient.DeleteACL(lSwitch, intf.(ovsdb.ACL))
		if delErr != nil {
			log.WithError(delErr).Warn("Error deleting ACL")
		}
	}

	for _, intf := range missing {
		core := intf.(ovsdb.ACL).Core
		addErr := ovsdbClient.CreateACL(lSwitch, core.Direction,
			core.Priority, core.Match, core.Action)
		if addErr != nil {
			log.WithError(addErr).Warn("Error adding ACL")
		}
	}
}
Beispiel #8
0
// checkACLs runs syncACLs with `connections` against `client` and then verifies
// that the ACLs listed on lSwitch are exactly `exp`.  ACLs are compared on their
// Core.  A failure to list the ACLs is reported as a test error instead of being
// silently ignored.
func checkACLs(t *testing.T, client ovsdb.Client,
	connections []db.Connection, exp []ovsdb.ACL) {

	syncACLs(client, connections)

	actual, err := client.ListACLs(lSwitch)
	if err != nil {
		t.Errorf("Failed to list ACLs: %s", err)
		return
	}

	ovsdbKey := func(ovsdbIntf interface{}) interface{} {
		return ovsdbIntf.(ovsdb.ACL).Core
	}

	// Any element left unmatched on either side means the two sets differ.
	_, left, right := join.HashJoin(ovsdbACLSlice(actual), ovsdbACLSlice(exp),
		ovsdbKey, ovsdbKey)
	if len(left) != 0 || len(right) != 0 {
		t.Errorf("Wrong ACLs: expected %v, got %v.", exp, actual)
	}
}
Beispiel #9
0
// syncAddressSets makes the address sets on lSwitch agree with `labels`.  Every
// label other than the public internet label is expected to have an address set
// named addressSetName(label) holding the label's container IPs and its own IP.
// Address sets are compared on their name and their sorted addresses; a set that
// differs in either is deleted and created anew.
func syncAddressSets(ovsdbClient ovsdb.Client, labels []db.Label) {
	ovsdbAddresses, err := ovsdbClient.ListAddressSets(lSwitch)
	if err != nil {
		log.WithError(err).Error("Failed to list address sets")
		return
	}

	var expAddressSets []ovsdb.AddressSet
	for _, l := range labels {
		if l.Label == stitch.PublicInternetLabel {
			continue
		}

		// Build the address list in a fresh slice.  Appending directly to
		// l.ContainerIPs could write into a backing array that is shared
		// with the caller's label data.
		addrs := make([]string, 0, len(l.ContainerIPs)+1)
		addrs = append(addrs, l.ContainerIPs...)
		addrs = append(addrs, l.IP)

		expAddressSets = append(expAddressSets,
			ovsdb.AddressSet{
				Name:      addressSetName(l.Label),
				Addresses: unique(addrs),
			},
		)
	}

	ovsdbKey := func(intf interface{}) interface{} {
		addrSet := intf.(ovsdb.AddressSet)
		// OVSDB returns the addresses in a non-deterministic order, so we
		// sort them.  The sort is in place, so it also reorders the
		// addresses of the address set being keyed.
		sort.Strings(addrSet.Addresses)
		return addressSetKey{
			name:      addrSet.Name,
			addresses: strings.Join(addrSet.Addresses, " "),
		}
	}
	_, toCreate, toDelete := join.HashJoin(addressSlice(expAddressSets),
		addressSlice(ovsdbAddresses), ovsdbKey, ovsdbKey)

	for _, intf := range toDelete {
		addr := intf.(ovsdb.AddressSet)
		if err := ovsdbClient.DeleteAddressSet(lSwitch, addr.Name); err != nil {
			log.WithError(err).Warn("Error deleting address set")
		}
	}

	for _, intf := range toCreate {
		addr := intf.(ovsdb.AddressSet)
		if err := ovsdbClient.CreateAddressSet(
			lSwitch, addr.Name, addr.Addresses); err != nil {
			log.WithError(err).Warn("Error adding address set")
		}
	}
}
Beispiel #10
0
// updatePorts reconciles the current ports with the ports generated for
// `containers`, matching them on name and bridge.
// There are certain exceptions, as certain ports will never be deleted.
func updatePorts(odb ovsdb.Ovsdb, containers []db.Container) {
	// An Open vSwitch patch port is referred to as a "port".

	desired := generateTargetPorts(containers)
	existing, err := generateCurrentPorts(odb)
	if err != nil {
		log.WithError(err).Error("failed to generate current openflow ports")
		return
	}

	type portKey struct {
		name, bridge string
	}
	keyOf := func(val interface{}) interface{} {
		port := val.(ovsPort)
		return portKey{name: port.name, bridge: port.bridge}
	}

	matched, stale, missing := join.HashJoin(existing, desired, keyOf, keyOf)

	for _, intf := range stale {
		port := intf.(ovsPort)
		// The "bridge" port for the bridge should never be deleted
		if port.name == port.bridge {
			continue
		}
		if delErr := delPort(odb, port); delErr != nil {
			log.WithError(delErr).Error("failed to delete openflow port")
		}
	}

	for _, intf := range missing {
		if addErr := addPort(odb, intf.(ovsPort)); addErr != nil {
			log.WithError(addErr).Error("failed to add openflow port")
		}
	}

	// Ports present on both sides are handed to modPort as (current, target).
	for _, pair := range matched {
		current, target := pair.L.(ovsPort), pair.R.(ovsPort)
		if modErr := modPort(odb, current, target); modErr != nil {
			log.WithError(modErr).Error("failed to modify openflow port")
		}
	}
}
Beispiel #11
0
// test pings every IP in tester.allIPs from `container` and compares the results
// against what the connection and label maps say should be reachable.
//
// It returns the targets that were expected to be reachable but whose ping
// result did not match (unreachable), and the targets that were expected to be
// unreachable but whose ping result did not match (unauthorized).
func (tester networkTester) test(container db.Container) (
	unreachable []string, unauthorized []string) {

	// We should be able to ping ourselves.
	expReachable := map[string]struct{}{
		container.IP: {},
	}
	for _, label := range container.Labels {
		for _, toLabelName := range tester.connectionMap[label] {
			toLabel := tester.labelMap[toLabelName]
			// Record the container IPs and the label IP separately rather
			// than appending to toLabel.ContainerIPs, which could write
			// into a backing array shared with tester.labelMap.
			for _, ip := range toLabel.ContainerIPs {
				expReachable[ip] = struct{}{}
			}
			expReachable[toLabel.IP] = struct{}{}
		}
		// We can ping our overarching label, but not other containers within
		// the label. E.g. 1.yellow.q can ping yellow.q (but not 2.yellow.q).
		expReachable[tester.labelMap[label].IP] = struct{}{}
	}

	var expPings []pingResult
	for _, ip := range tester.allIPs {
		_, reachable := expReachable[ip]
		expPings = append(expPings, pingResult{
			target:    ip,
			reachable: reachable,
		})
	}
	pingResults := tester.pingAll(container)

	// Expected results without an identical actual result are failures.
	_, failures, _ := join.HashJoin(pingSlice(expPings), pingSlice(pingResults),
		nil, nil)

	for _, badIntf := range failures {
		bad := badIntf.(pingResult)
		if bad.reachable {
			unreachable = append(unreachable, bad.target)
		} else {
			unauthorized = append(unauthorized, bad.target)
		}
	}

	return unreachable, unauthorized
}
Beispiel #12
0
// updateIPs makes the IPs of device `dev` in `namespace` match `targetIPs`, given
// that it currently has `currIPs`.  IPs only in `currIPs` are deleted first, then
// IPs only in `targetIPs` are added.  The first error encountered is returned
// immediately, leaving the remaining IPs untouched.
func updateIPs(namespace string, dev string, currIPs []string,
	targetIPs []string) error {

	current := join.StringSlice(currIPs)
	target := join.StringSlice(targetIPs)
	_, stale, missing := join.HashJoin(current, target, nil, nil)

	for _, addr := range stale {
		err := delIP(namespace, addr.(string), dev)
		if err != nil {
			return err
		}
	}

	for _, addr := range missing {
		err := addIP(namespace, addr.(string), dev)
		if err != nil {
			return err
		}
	}

	return nil
}
Beispiel #13
0
// checkAddressSet runs syncAddressSets with `labels` against `client` and then
// verifies that the address sets listed on lSwitch are exactly `exp`.  Address
// sets are compared on their name and their sorted addresses.  A failure to list
// the address sets is reported as a test error instead of being silently ignored.
func checkAddressSet(t *testing.T, client ovsdb.Client,
	labels []db.Label, exp []ovsdb.AddressSet) {

	syncAddressSets(client, labels)

	actual, err := client.ListAddressSets(lSwitch)
	if err != nil {
		t.Errorf("Failed to list address sets: %s", err)
		return
	}

	ovsdbKey := func(intf interface{}) interface{} {
		addrSet := intf.(ovsdb.AddressSet)
		// OVSDB returns the addresses in a non-deterministic order, so we
		// sort them.
		sort.Strings(addrSet.Addresses)
		return addressSetKey{
			name:      addrSet.Name,
			addresses: strings.Join(addrSet.Addresses, " "),
		}
	}

	// Any element left unmatched on either side means the two sets differ.
	_, lefts, rights := join.HashJoin(addressSlice(actual), addressSlice(exp),
		ovsdbKey, ovsdbKey)
	if len(lefts) != 0 || len(rights) != 0 {
		t.Errorf("Wrong address sets: expected %v, got %v.", exp, actual)
	}
}
Beispiel #14
0
// runAppTransact makes the container rows in `view` agree with the docker
// containers in `dkcsArgs`, matching a row's DockerID against a docker
// container's ID.  Rows without a docker container are removed and their docker
// IDs are returned.  Docker containers without a row get a new one.  Every
// remaining row is refreshed from its docker container.
func (sv *supervisor) runAppTransact(view db.Database,
	dkcsArgs []docker.Container) []string {

	rowID := func(val interface{}) interface{} {
		return val.(db.Container).DockerID
	}
	dockerID := func(val interface{}) interface{} {
		return val.(docker.Container).ID
	}

	rows := db.ContainerSlice(view.SelectFromContainer(nil))
	running := docker.ContainerSlice(dkcsArgs)
	matched, orphans, untracked := join.HashJoin(rows, running, rowID, dockerID)

	var tearDowns []string
	for _, intf := range orphans {
		row := intf.(db.Container)

		tearDowns = append(tearDowns, row.DockerID)
		view.Remove(row)
	}

	// Untracked docker containers are paired with a fresh row so that the loop
	// below fills them in along with the already matched ones.
	for _, dkc := range untracked {
		matched = append(matched, join.Pair{L: view.InsertContainer(), R: dkc})
	}

	for _, pair := range matched {
		row := pair.L.(db.Container)
		dkc := pair.R.(docker.Container)

		// The command is the executable path followed by its arguments.
		command := append([]string{dkc.Path}, dkc.Args...)

		row.DockerID = dkc.ID
		row.Pid = dkc.Pid
		row.Image = dkc.Image
		row.Command = command
		view.Commit(row)
	}

	return tearDowns
}
Beispiel #15
0
// updateRoutes makes the routes inside the network namespace of each container
// match a fixed target: a route for 10.0.0.0/8 and a default route via gatewayIP,
// both on innerVeth.  A container whose current routes cannot be read is skipped.
func updateRoutes(containers []db.Container) {
	desired := routeSlice{
		{
			ip:        "10.0.0.0/8",
			dev:       innerVeth,
			isDefault: false,
		},
		{
			ip:        gatewayIP,
			dev:       innerVeth,
			isDefault: true,
		},
	}

	for _, c := range containers {
		namespace := networkNS(c.DockerID)

		existing, err := generateCurrentRoutes(namespace)
		if err != nil {
			log.WithError(err).Error("failed to get current ip routes")
			continue
		}

		_, stale, missing := join.HashJoin(existing, desired, nil, nil)

		for _, intf := range stale {
			delErr := deleteRoute(namespace, intf.(route))
			if delErr != nil {
				log.WithError(delErr).Error("error deleting route")
			}
		}

		for _, intf := range missing {
			addErr := addRoute(namespace, intf.(route))
			if addErr != nil {
				log.WithError(addErr).Error("error adding route")
			}
		}
	}
}
Beispiel #16
0
// updateVeths reconciles the current veths with the veths generated for
// `containers`, matching them by name.
func updateVeths(containers []db.Container) {
	// A virtual ethernet link that links the host and container is a "veth".
	//
	// The ends of the veth have different config options like mtu, etc.
	// However if you delete one side, both will be deleted.

	desired := generateTargetVeths(containers)
	existing, err := generateCurrentVeths(containers)
	if err != nil {
		log.WithError(err).Error("failed to get veths")
		return
	}

	byName := func(val interface{}) interface{} {
		return val.(netdev).name
	}

	matched, stale, missing := join.HashJoin(existing, desired, byName, byName)

	for _, intf := range stale {
		if delErr := delVeth(intf.(netdev)); delErr != nil {
			log.WithError(delErr).Error("failed to delete veth")
		}
	}

	for _, intf := range missing {
		if addErr := addVeth(intf.(netdev)); addErr != nil {
			log.WithError(addErr).Error("failed to add veth")
		}
	}

	// Veths present on both sides are handed to modVeth as (current, target).
	for _, pair := range matched {
		current, target := pair.L.(netdev), pair.R.(netdev)
		if modErr := modVeth(current, target); modErr != nil {
			log.WithError(modErr).Error("failed to modify veth")
		}
	}
}
Beispiel #17
0
// readLabelTransact makes the label rows in `view` agree with the keys of `dir`.
// Each row takes its IP from the "IP" entry of the label's directory, and is
// marked MultiHost exactly when that directory has a "MultiHost" entry.
func readLabelTransact(view db.Database, dir directory) {
	byLabel := func(val interface{}) interface{} {
		return val.(db.Label).Label
	}

	existing := db.LabelSlice(view.SelectFromLabel(nil))
	names := join.StringSlice(dir.keys())
	matched, stale, missing := join.HashJoin(existing, names, byLabel, nil)

	for _, intf := range stale {
		view.Remove(intf.(db.Label))
	}

	// New labels get a fresh row and are then filled in like the matched ones.
	for _, name := range missing {
		matched = append(matched, join.Pair{L: view.InsertLabel(), R: name})
	}

	for _, pair := range matched {
		row := pair.L.(db.Label)
		row.Label = pair.R.(string)

		entry := dir[row.Label]
		row.IP = entry["IP"]
		// Only the presence of the key matters, not its value.
		_, row.MultiHost = entry["MultiHost"]
		view.Commit(row)
	}
}
Beispiel #18
0
// updateNamespaces reconciles the current namespaces with the ones needed by
// `containers`, matching them by namespace name.
// If a namespace in the path is detected as invalid and conflicts with
// a namespace that should exist, it's removed and replaced.
func updateNamespaces(containers []db.Container) {
	// A symbolic link in the netns path is considered a "namespace".
	// The actual namespace is elsewhere but we link them all into the
	// canonical location and manage them there.
	//
	// We keep all our namespaces in /var/run/netns/

	var desired nsInfoSlice
	for _, c := range containers {
		info := nsInfo{ns: networkNS(c.DockerID), pid: c.Pid}
		desired = append(desired, info)
	}

	existing, err := generateCurrentNamespaces()
	if err != nil {
		log.WithError(err).Error("failed to get namespaces")
		return
	}

	byName := func(val interface{}) interface{} {
		return val.(nsInfo).ns
	}

	_, stale, missing := join.HashJoin(existing, desired, byName, byName)

	for _, intf := range stale {
		delErr := delNS(intf.(nsInfo))
		if delErr != nil {
			log.WithError(delErr).Error("error deleting namespace")
		}
	}

	for _, intf := range missing {
		addErr := addNS(intf.(nsInfo))
		if addErr != nil {
			log.WithError(addErr).Error("error adding namespace")
		}
	}
}
Beispiel #19
0
// updateNAT reconciles the current NAT rules with the rules generated for
// `containers` and `connections`.  Rules only present currently are deleted, and
// rules only present in the target are added.
func updateNAT(containers []db.Container, connections []db.Connection) {
	desired := generateTargetNatRules(containers, connections)
	existing, err := generateCurrentNatRules()
	if err != nil {
		log.WithError(err).Error("failed to get NAT rules")
		return
	}

	_, stale, missing := join.HashJoin(existing, desired, nil, nil)

	for _, intf := range stale {
		delErr := deleteNatRule(intf.(ipRule))
		if delErr != nil {
			log.WithError(delErr).Error("failed to delete ip rule")
		}
	}

	for _, intf := range missing {
		addErr := addNatRule(intf.(ipRule))
		if addErr != nil {
			log.WithError(addErr).Error("failed to add ip rule")
		}
	}
}
Beispiel #20
0
// syncACLs returns the permissions that need to be removed and added in order
// for the cloud ACLs to match the policy.
// Every permission in rangesToAdd is guaranteed to have exactly one item in its
// IpRanges slice.  foundGroup reports whether `current` contains a group pair
// whose group ID is desiredGroupID; group pairs with any other ID are scheduled
// for removal.
func syncACLs(desiredACLs []acl.ACL, desiredGroupID string,
	current []*ec2.IpPermission) (rangesToAdd []*ec2.IpPermission, foundGroup bool,
	toRemove []*ec2.IpPermission) {

	// newRangeRule builds a permission that covers a single CIDR range.
	newRangeRule := func(proto string, from, to int64,
		cidr string) *ec2.IpPermission {
		return &ec2.IpPermission{
			FromPort: aws.Int64(from),
			ToPort:   aws.Int64(to),
			IpRanges: []*ec2.IpRange{
				{CidrIp: aws.String(cidr)},
			},
			IpProtocol: aws.String(proto),
		}
	}

	// Split the current permissions into rules holding one IP range each, so
	// that they can be compared one-to-one against the desired rules.
	var existing []*ec2.IpPermission
	for _, perm := range current {
		for _, ipRange := range perm.IpRanges {
			single := &ec2.IpPermission{
				IpProtocol: perm.IpProtocol,
				FromPort:   perm.FromPort,
				ToPort:     perm.ToPort,
				IpRanges:   []*ec2.IpRange{ipRange},
			}
			existing = append(existing, single)
		}

		for _, pair := range perm.UserIdGroupPairs {
			if *pair.GroupId == desiredGroupID {
				foundGroup = true
				continue
			}
			toRemove = append(toRemove, &ec2.IpPermission{
				UserIdGroupPairs: []*ec2.UserIdGroupPair{pair},
			})
		}
	}

	// Each desired ACL expands to a tcp and a udp rule over its port range,
	// plus an icmp rule with both ports set to -1.
	var desired []*ec2.IpPermission
	for _, rule := range desiredACLs {
		minPort, maxPort := int64(rule.MinPort), int64(rule.MaxPort)
		desired = append(desired,
			newRangeRule("tcp", minPort, maxPort, rule.CidrIP),
			newRangeRule("udp", minPort, maxPort, rule.CidrIP),
			newRangeRule("icmp", -1, -1, rule.CidrIP))
	}

	_, missing, stale := join.HashJoin(ipPermSlice(desired),
		ipPermSlice(existing), permToACLKey, permToACLKey)
	for _, intf := range missing {
		rangesToAdd = append(rangesToAdd, intf.(*ec2.IpPermission))
	}
	for _, intf := range stale {
		toRemove = append(toRemove, intf.(*ec2.IpPermission))
	}

	return rangesToAdd, foundGroup, toRemove
}
Beispiel #21
0
// TestInterfaces exercises CreateInterface, ListInterfaces, DeleteInterface and
// ModifyInterface against the fake OVSDB client.
func TestInterfaces(t *testing.T) {
	ovsdbClient := NewFakeOvsdbClient()

	// Create new switch.
	lswitch1 := "test-switch-1"
	err := ovsdbClient.CreateLogicalSwitch(lswitch1)
	assert.Nil(t, err)

	// Interfaces are compared on name, bridge and OpenFlow port.  The key
	// dereferences OFPort, so every interface passed to it must have a non-nil
	// OFPort.
	key := func(val interface{}) interface{} {
		iface := val.(Interface)
		return struct {
			name, bridge string
			ofport       int
		}{
			name:   iface.Name,
			bridge: iface.Bridge,
			ofport: *iface.OFPort,
		}
	}

	// checkCorrectness asserts that every interface listed by the client has a
	// match among the locally expected interfaces.
	checkCorrectness := func(ovsdbIfaces []Interface, localIfaces ...Interface) {
		pair, _, _ := join.HashJoin(InterfaceSlice(ovsdbIfaces),
			InterfaceSlice(localIfaces), key, key)
		assert.Equal(t, len(pair), len(ovsdbIfaces))
	}

	// Create a new Bridge. In quilt this is usually done in supervisor.
	_, err = ovsdbClient.transact("Open_vSwitch", ovs.Operation{
		Op:    "insert",
		Table: "Bridge",
		Row:   map[string]interface{}{"name": lswitch1},
	})
	assert.Nil(t, err)

	// Ovsdb mock uses defaultOFPort as the ofport created for each interface.
	expectedOFPort := int(defaultOFPort)

	// Create one interface.
	iface1 := Interface{
		Name:   "iface1",
		Bridge: lswitch1,
		OFPort: &expectedOFPort,
	}

	err = ovsdbClient.CreateInterface(iface1.Bridge, iface1.Name)
	assert.Nil(t, err)

	ifaces, err := ovsdbClient.ListInterfaces()
	assert.Nil(t, err)
	assert.Equal(t, 1, len(ifaces))

	checkCorrectness(ifaces, iface1)

	// Now create a new switch and bridge. Attach one new interface to them.

	// Create new switch.
	lswitch2 := "test-switch-2"
	err = ovsdbClient.CreateLogicalSwitch(lswitch2)
	assert.Nil(t, err)

	// Create a new Bridge.
	_, err = ovsdbClient.transact("Open_vSwitch", ovs.Operation{
		Op:    "insert",
		Table: "Bridge",
		Row:   map[string]interface{}{"name": lswitch2},
	})
	assert.Nil(t, err)

	// Create a new interface.
	iface2 := Interface{
		Name:   "iface2",
		Bridge: lswitch2,
		OFPort: &expectedOFPort,
	}

	err = ovsdbClient.CreateInterface(iface2.Bridge, iface2.Name)
	assert.Nil(t, err)

	ifaces, err = ovsdbClient.ListInterfaces()
	assert.Nil(t, err)
	assert.Equal(t, 2, len(ifaces))

	checkCorrectness(ifaces, iface1, iface2)

	// From here on, use the interfaces as returned by the client.
	// NOTE(review): presumably the listed values carry the identifiers that
	// DeleteInterface needs, and the list order matches creation order --
	// confirm against the fake client.
	iface1, iface2 = ifaces[0], ifaces[1]

	// Delete interface 1.
	err = ovsdbClient.DeleteInterface(iface1)
	assert.Nil(t, err)

	ifaces, err = ovsdbClient.ListInterfaces()
	assert.Nil(t, err)

	assert.Equal(t, 1, len(ifaces))

	checkCorrectness(ifaces, iface2)

	// Delete interface 2.
	err = ovsdbClient.DeleteInterface(iface2)
	assert.Nil(t, err)

	ifaces, err = ovsdbClient.ListInterfaces()
	assert.Nil(t, err)
	assert.Zero(t, len(ifaces))

	// Test ModifyInterface. We do this by creating an interface and then
	// modifying it to be of type patch, with a peer and an attached MAC address.
	iface := Interface{
		Name:        "test-modify-iface",
		Peer:        "lolz",
		AttachedMAC: "00:00:00:00:00:00",
		Bridge:      lswitch1,
		Type:        "patch",
	}

	err = ovsdbClient.CreateInterface(iface.Bridge, iface.Name)
	assert.Nil(t, err)

	err = ovsdbClient.ModifyInterface(iface)
	assert.Nil(t, err)

	ifaces, err = ovsdbClient.ListInterfaces()
	assert.Nil(t, err)

	// Copy over the fields that are not set locally, so that the comparison
	// below only checks the fields given to ModifyInterface.
	ovsdbIface := ifaces[0]
	iface.uuid = ovsdbIface.uuid
	iface.portUUID = ovsdbIface.portUUID
	iface.OFPort = ovsdbIface.OFPort
	assert.Equal(t, iface, ovsdbIface)
}
Beispiel #22
0
// TestAddressSets exercises CreateAddressSet, ListAddressSets and
// DeleteAddressSet against the fake OVSDB client.
func TestAddressSets(t *testing.T) {
	ovsdbClient := NewFakeOvsdbClient()

	// Address sets are compared on their name and their sorted addresses.
	key := func(intf interface{}) interface{} {
		addrSet := intf.(AddressSet)
		// OVSDB returns the addresses in a non-deterministic order, so we
		// sort them.
		sort.Strings(addrSet.Addresses)
		return addressSetKey{
			name:      addrSet.Name,
			addresses: strings.Join(addrSet.Addresses, " "),
		}
	}

	// checkCorrectness asserts that every expected address set has a match
	// among the address sets listed by the client.
	checkCorrectness := func(ovsdbAddrSets []AddressSet, expAddrSets ...AddressSet) {
		pair, _, _ := join.HashJoin(addressSlice(ovsdbAddrSets),
			addressSlice(expAddrSets), key, key)
		assert.Equal(t, len(pair), len(expAddrSets))
	}

	// Create new switch.
	lswitch := "test-switch"
	err := ovsdbClient.CreateLogicalSwitch(lswitch)
	assert.Nil(t, err)

	// Create one Address Set.
	addrSet1 := AddressSet{
		Name:      "red",
		Addresses: []string{"foo", "bar"},
	}

	err = ovsdbClient.CreateAddressSet(lswitch, addrSet1.Name, addrSet1.Addresses)
	assert.Nil(t, err)

	// It should now have one address set to be listed.
	ovsdbAddrSets, err := ovsdbClient.ListAddressSets(lswitch)
	assert.Nil(t, err)
	assert.Equal(t, 1, len(ovsdbAddrSets))

	checkCorrectness(ovsdbAddrSets, addrSet1)

	// Create one more address set.
	addrSet2 := AddressSet{
		Name:      "blue",
		Addresses: []string{"bar", "baz"},
	}

	err = ovsdbClient.CreateAddressSet(lswitch, addrSet2.Name, addrSet2.Addresses)
	assert.Nil(t, err)

	// It should now have two address sets to be listed.
	ovsdbAddrSets, err = ovsdbClient.ListAddressSets(lswitch)
	assert.Nil(t, err)
	assert.Equal(t, 2, len(ovsdbAddrSets))

	checkCorrectness(ovsdbAddrSets, addrSet1, addrSet2)

	// Delete the first address set.
	err = ovsdbClient.DeleteAddressSet(lswitch, addrSet1.Name)
	assert.Nil(t, err)

	// It should now have only one address set to be listed.
	ovsdbAddrSets, err = ovsdbClient.ListAddressSets(lswitch)
	assert.Nil(t, err)
	assert.Equal(t, 1, len(ovsdbAddrSets))

	checkCorrectness(ovsdbAddrSets, addrSet2)

	// Delete the other address set.
	err = ovsdbClient.DeleteAddressSet(lswitch, addrSet2.Name)
	assert.Nil(t, err)

	// It should now have no address sets to be listed.
	ovsdbAddrSets, err = ovsdbClient.ListAddressSets(lswitch)
	assert.Nil(t, err)
	assert.Zero(t, len(ovsdbAddrSets))
}
Beispiel #23
0
// TestEngine feeds a series of programs through updateStitch against a single
// database connection, and after each one checks the masters and workers
// returned by selectMachines.
func TestEngine(t *testing.T) {
	// Common prefix for most of the programs below: it creates the deployment
	// with one admin ACL entry and defines the machine the programs replicate.
	pre := `var deployment = createDeployment({
		namespace: "namespace",
		adminACL: ["1.2.3.4/32"],
	});
	var baseMachine = new Machine({provider: "Amazon", size: "m4.large"});`
	conn := db.New()

	code := pre + `deployment.deploy(baseMachine.asMaster().replicate(2));
		deployment.deploy(baseMachine.asWorker().replicate(3));`

	updateStitch(t, conn, prog(t, code))
	acl, err := selectACL(conn)
	assert.Nil(t, err)
	assert.Equal(t, 1, len(acl.Admin))

	masters, workers := selectMachines(conn)
	assert.Equal(t, 2, len(masters))
	assert.Equal(t, 3, len(workers))

	/* Verify that masters and workers increase. */
	code = pre + `deployment.deploy(baseMachine.asMaster().replicate(4));
		deployment.deploy(baseMachine.asWorker().replicate(5));`

	updateStitch(t, conn, prog(t, code))
	masters, workers = selectMachines(conn)
	assert.Equal(t, 4, len(masters))
	assert.Equal(t, 5, len(workers))

	/* Verify that external writes stick around. */
	conn.Transact(func(view db.Database) error {
		masters := view.SelectFromMachine(func(m db.Machine) bool {
			return m.Role == db.Master
		})
		workers := view.SelectFromMachine(func(m db.Machine) bool {
			return m.Role == db.Worker
		})

		for _, master := range masters {
			master.CloudID = "1"
			master.PublicIP = "2"
			master.PrivateIP = "3"
			view.Commit(master)
		}

		for _, worker := range workers {
			worker.CloudID = "1"
			worker.PublicIP = "2"
			worker.PrivateIP = "3"
			view.Commit(worker)
		}

		return nil
	})

	/* Also verify that masters and workers decrease properly. */
	code = pre + `deployment.deploy(baseMachine.asMaster());
		deployment.deploy(baseMachine.asWorker());`
	updateStitch(t, conn, prog(t, code))

	masters, workers = selectMachines(conn)

	// The surviving machines must still carry the externally written fields.
	assert.Equal(t, 1, len(masters))
	assert.Equal(t, "1", masters[0].CloudID)
	assert.Equal(t, "2", masters[0].PublicIP)
	assert.Equal(t, "3", masters[0].PrivateIP)

	assert.Equal(t, 1, len(workers))
	assert.Equal(t, "1", workers[0].CloudID)
	assert.Equal(t, "2", workers[0].PublicIP)
	assert.Equal(t, "3", workers[0].PrivateIP)

	/* Empty Namespace does nothing. */
	code = pre + `deployment.namespace = "";
		deployment.deploy(baseMachine.asMaster());
		deployment.deploy(baseMachine.asWorker());`
	updateStitch(t, conn, prog(t, code))
	masters, workers = selectMachines(conn)

	assert.Equal(t, 1, len(masters))
	assert.Equal(t, "1", masters[0].CloudID)
	assert.Equal(t, "2", masters[0].PublicIP)
	assert.Equal(t, "3", masters[0].PrivateIP)

	assert.Equal(t, 1, len(workers))
	assert.Equal(t, "1", workers[0].CloudID)
	assert.Equal(t, "2", workers[0].PublicIP)
	assert.Equal(t, "3", workers[0].PrivateIP)

	/* Verify things go to zero: only a worker is deployed here, and neither
	 * masters nor workers are expected afterwards. */
	code = pre + `deployment.deploy(baseMachine.asWorker())`
	updateStitch(t, conn, prog(t, code))
	masters, workers = selectMachines(conn)
	assert.Zero(t, len(masters))
	assert.Zero(t, len(workers))

	// This function checks whether there is a one-to-one mapping for each machine
	// in `slice` to a provider in `providers`.
	providersInSlice := func(slice db.MachineSlice, providers db.ProviderSlice) bool {
		lKey := func(left interface{}) interface{} {
			return left.(db.Machine).Provider
		}
		rKey := func(right interface{}) interface{} {
			return right.(db.Provider)
		}
		_, l, r := join.HashJoin(slice, providers, lKey, rKey)
		return len(l) == 0 && len(r) == 0
	}

	/* Test mixed providers. */
	// NOTE(review): unlike the programs above, this one is not prefixed with
	// `pre` -- confirm that `deployment` is still defined for it.
	code = `deployment.deploy([
		new Machine({provider: "Amazon", size: "m4.large", role: "Master"}),
		new Machine({provider: "Vagrant", size: "v.large", role: "Master"}),
		new Machine({provider: "Amazon", size: "m4.large", role: "Worker"}),
		new Machine({provider: "Google", size: "g.large", role: "Worker"})]);`
	updateStitch(t, conn, prog(t, code))
	masters, workers = selectMachines(conn)
	assert.True(t, providersInSlice(masters,
		db.ProviderSlice{db.Amazon, db.Vagrant}))
	assert.True(t, providersInSlice(workers, db.ProviderSlice{db.Amazon, db.Google}))

	/* Test that machines with different providers don't match. */
	code = `deployment.deploy([
		new Machine({provider: "Amazon", size: "m4.large", role: "Master"}),
		new Machine({provider: "Amazon", size: "m4.large", role: "Worker"})]);`
	updateStitch(t, conn, prog(t, code))
	masters, _ = selectMachines(conn)
	assert.True(t, providersInSlice(masters, db.ProviderSlice{db.Amazon}))
}
Beispiel #24
0
// The leader of the cluster is responsible for properly configuring OVN northd for
// container networking.  This simply means creating a logical port for each container
// and label.  The specialized OpenFlow rules Quilt requires are managed by the workers
// individually.
func runMaster(conn db.Conn) {
	var leader, init bool
	var labels []db.Label
	var containers []db.Container
	var connections []db.Connection
	conn.Transact(func(view db.Database) error {
		init = checkSupervisorInit(view)
		leader = view.EtcdLeader()

		labels = view.SelectFromLabel(func(label db.Label) bool {
			return label.IP != ""
		})

		containers = view.SelectFromContainer(func(dbc db.Container) bool {
			return dbc.Mac != "" && dbc.IP != ""
		})

		connections = view.SelectFromConnection(nil)
		return nil
	})

	if !init || !leader {
		return
	}

	var dbData []dbport
	for _, l := range labels {
		if l.MultiHost {
			dbData = append(dbData, dbport{
				bridge: lSwitch,
				ip:     l.IP,
				mac:    labelMac,
			})
		}
	}
	for _, c := range containers {
		dbData = append(dbData, dbport{bridge: lSwitch, ip: c.IP, mac: c.Mac})
	}

	ovsdbClient, err := ovsdb.Open()
	if err != nil {
		log.WithError(err).Error("Failed to connect to OVSDB.")
		return
	}
	defer ovsdbClient.Close()

	ovsdbClient.CreateLogicalSwitch(lSwitch)
	lports, err := ovsdbClient.ListLogicalPorts(lSwitch)
	if err != nil {
		log.WithError(err).Error("Failed to list OVN ports.")
		return
	}

	portKey := func(val interface{}) interface{} {
		port := val.(ovsdb.LPort)
		return fmt.Sprintf("bridge:%s\nname:%s", port.Bridge, port.Name)
	}

	dbKey := func(val interface{}) interface{} {
		dbPort := val.(dbport)
		return fmt.Sprintf("bridge:%s\nname:%s", dbPort.bridge, dbPort.ip)
	}

	_, ovsps, dbps := join.HashJoin(ovsdb.LPortSlice(lports), dbslice(dbData),
		portKey, dbKey)

	for _, dbp := range dbps {
		lport := dbp.(dbport)
		log.WithField("IP", lport.ip).Info("New logical port.")
		err := ovsdbClient.CreateLogicalPort(lport.bridge, lport.ip, lport.mac,
			lport.ip)
		if err != nil {
			log.WithError(err).Warnf("Failed to create port %s.", lport.ip)
		}
	}

	for _, ovsp := range ovsps {
		lport := ovsp.(ovsdb.LPort)
		log.Infof("Delete logical port %s.", lport.Name)
		if err := ovsdbClient.DeleteLogicalPort(lSwitch, lport); err != nil {
			log.WithError(err).Warn("Failed to delete logical port.")
		}
	}

	updateACLs(ovsdbClient, connections, labels)
}
Beispiel #25
0
// syncSecurityRules reconciles the security group named securityGroupName with
// localRules.  Rules are matched on source/destination port range, source and
// destination address prefix, and direction.  Cloud rules without a local match
// are deleted first; local rules without a cloud match are then created, each
// with the lowest priority in [100, 4096] not held by a remaining cloud rule.
// The first error from the Azure client is returned, as is an error when the
// priority range is exhausted.
func (clst *azureCluster) syncSecurityRules(securityGroupName string,
	localRules securityRuleSlice, cloudRules securityRuleSlice) error {
	key := func(val interface{}) interface{} {
		props := val.(network.SecurityRule).Properties
		return struct {
			sourcePortRange          string
			sourceAddressPrefix      string
			destinationPortRange     string
			destinationAddressPrefix string
			direction                network.SecurityRuleDirection
		}{
			sourcePortRange:          *props.SourcePortRange,
			sourceAddressPrefix:      *props.SourceAddressPrefix,
			destinationPortRange:     *props.DestinationPortRange,
			destinationAddressPrefix: *props.DestinationAddressPrefix,
			direction:                props.Direction,
		}
	}

	_, toCreate, toDelete := join.HashJoin(localRules, cloudRules, key, key)

	// Priorities currently occupied in the cloud.
	taken := make(map[int32]struct{})
	for _, existing := range cloudRules {
		taken[*existing.Properties.Priority] = struct{}{}
	}

	cancel := make(chan struct{})
	for _, d := range toDelete {
		doomed := d.(network.SecurityRule)
		// A deleted rule frees its priority for reuse below.
		delete(taken, *doomed.Properties.Priority)
		_, err := clst.azureClient.securityRuleDelete(resourceGroupName,
			securityGroupName, *doomed.Name, cancel)
		if err != nil {
			return err
		}
	}

	// Each security rule is required to be assigned one unique priority number
	// between 100 and 4096.  All slots are reserved before any rule is created.
	assigned := make([]int32, 0, len(toCreate))
	next := int32(100)
	for range toCreate {
		for {
			if next > 4096 {
				return errors.New("max number of security rules reached")
			}
			if _, used := taken[next]; !used {
				break
			}
			next++
		}
		assigned = append(assigned, next)
		next++
	}

	for i := range toCreate {
		fresh := toCreate[i].(network.SecurityRule)
		fresh.Properties.Priority = &assigned[i]
		_, err := clst.azureClient.securityRuleCreate(resourceGroupName,
			securityGroupName, *fresh.Name, fresh, cancel)
		if err != nil {
			return err
		}
	}
	return nil
}
Beispiel #26
0
// TestEngine feeds a sequence of specs to UpdatePolicy against a single database
// and checks the machine rows after each step: initial creation, scale up,
// preservation of externally written fields across a scale down, a spec without a
// Namespace, scale to zero, and specs mixing several providers.
func TestEngine(t *testing.T) {
	// Named `dumper` so that the spew package is not shadowed.
	dumper := spew.NewDefaultConfig()
	dumper.MaxDepth = 2

	conn := db.New()

	// machinesByRole returns the masters and the workers visible in view.
	machinesByRole := func(view db.Database) (db.MachineSlice, db.MachineSlice) {
		masters := view.SelectFromMachine(func(m db.Machine) bool {
			return m.Role == db.Master
		})
		workers := view.SelectFromMachine(func(m db.Machine) bool {
			return m.Role == db.Worker
		})
		return masters, workers
	}

	// inspect runs verify on the current masters and workers inside a single
	// transaction and records any error it returns as a test failure.
	inspect := func(verify func(masters, workers db.MachineSlice) error) {
		err := conn.Transact(func(view db.Database) error {
			return verify(machinesByRole(view))
		})
		if err != nil {
			t.Error(err.Error())
		}
	}

	// checkCounts fails the test unless exactly the given number of masters and
	// workers exist.
	checkCounts := func(wantMasters, wantWorkers int) {
		inspect(func(masters, workers db.MachineSlice) error {
			if len(masters) != wantMasters {
				return fmt.Errorf("bad masters: %s", dumper.Sdump(masters))
			}

			if len(workers) != wantWorkers {
				return fmt.Errorf("bad workers: %s", dumper.Sdump(workers))
			}
			return nil
		})
	}

	// keptWrites reports whether machines holds exactly one machine that still
	// carries the fields written by the "external writes" step below.
	keptWrites := func(machines db.MachineSlice) bool {
		return len(machines) == 1 && machines[0].CloudID == "1" &&
			machines[0].PublicIP == "2" && machines[0].PrivateIP == "3"
	}

	// checkKeptWrites fails the test unless one master and one worker remain,
	// both with their externally written fields intact.
	checkKeptWrites := func() {
		inspect(func(masters, workers db.MachineSlice) error {
			if !keptWrites(masters) {
				return fmt.Errorf("bad masters: %s", dumper.Sdump(masters))
			}

			if !keptWrites(workers) {
				return fmt.Errorf("bad workers: %s", dumper.Sdump(workers))
			}
			return nil
		})
	}

	code := `
(define Namespace "Namespace")
(define MasterCount 2)
(define WorkerCount 3)
(makeList MasterCount (machine (provider "Amazon") (size "m4.large") (role "Master")))
(makeList WorkerCount (machine (provider "Amazon") (size "m4.large") (role "Worker")))
(define AdminACL (list "1.2.3.4/32"))`

	UpdatePolicy(conn, prog(t, code))
	err := conn.Transact(func(view db.Database) error {
		cluster, err := view.GetCluster()
		if err != nil {
			return err
		}
		if len(cluster.ACLs) != 1 {
			return fmt.Errorf("bad cluster: %s", dumper.Sdump(cluster))
		}
		return nil
	})
	if err != nil {
		t.Error(err.Error())
	}
	checkCounts(2, 3)

	/* Verify master increase. */
	code = `
(define Namespace "Namespace")
(define MasterCount 4)
(define WorkerCount 5)
(makeList MasterCount (machine (provider "Amazon") (size "m4.large") (role "Master")))
(makeList WorkerCount (machine (provider "Amazon") (size "m4.large") (role "Worker")))
(define AdminACL (list "1.2.3.4/32"))`

	UpdatePolicy(conn, prog(t, code))
	checkCounts(4, 5)

	/* Verify that external writes stick around. */
	err = conn.Transact(func(view db.Database) error {
		masters, workers := machinesByRole(view)
		for _, machines := range []db.MachineSlice{masters, workers} {
			for _, m := range machines {
				m.CloudID = "1"
				m.PublicIP = "2"
				m.PrivateIP = "3"
				view.Commit(m)
			}
		}
		return nil
	})
	// Previously this error was overwritten without ever being inspected.
	if err != nil {
		t.Error(err.Error())
	}

	/* Also verify that masters and workers decrease properly. */
	code = `
(define Namespace "Namespace")
(define MasterCount 1)
(define WorkerCount 1)
(makeList MasterCount (machine (provider "Amazon") (size "m4.large") (role "Master")))
(makeList WorkerCount (machine (provider "Amazon") (size "m4.large") (role "Worker")))
(define AdminACL (list "1.2.3.4/32"))`
	UpdatePolicy(conn, prog(t, code))
	checkKeptWrites()

	/* Empty Namespace does nothing. */
	code = `
(define MasterCount 1)
(define WorkerCount 1)
(makeList MasterCount (machine (provider "Amazon") (size "m4.large") (role "Master")))
(makeList WorkerCount (machine (provider "Amazon") (size "m4.large") (role "Worker")))
(define AdminACL (list "1.2.3.4/32"))`
	UpdatePolicy(conn, prog(t, code))
	checkKeptWrites()

	/* Verify things go to zero. */
	code = `
(define Namespace "Namespace")
(define MasterCount 0)
(define WorkerCount 1)
(makeList MasterCount (machine (provider "Amazon") (size "m4.large") (role "Master")))
(makeList WorkerCount (machine (provider "Amazon") (size "m4.large") (role "Worker")))
(define AdminACL (list "1.2.3.4/32"))`
	UpdatePolicy(conn, prog(t, code))
	checkCounts(0, 0)

	// This function checks whether there is a one-to-one mapping for each machine
	// in `slice` to a provider in `providers`.
	providersInSlice := func(slice db.MachineSlice, providers db.ProviderSlice) bool {
		lKey := func(left interface{}) interface{} {
			return left.(db.Machine).Provider
		}
		rKey := func(right interface{}) interface{} {
			return right.(db.Provider)
		}
		_, l, r := join.HashJoin(slice, providers, lKey, rKey)
		return len(l) == 0 && len(r) == 0
	}

	/* Test mixed providers. */
	code = `
	(define Namespace "Namespace")
	(list (machine (provider "Amazon") (size "m4.large") (role "Master"))
	      (machine (provider "Vagrant") (size "v.large") (role "Master")))
	(list (machine (provider "Azure") (size "a.large") (role "Worker"))
	      (machine (provider "Google") (size "g.large") (role "Worker")))
	(define AdminACL (list "1.2.3.4/32"))`
	UpdatePolicy(conn, prog(t, code))
	inspect(func(masters, workers db.MachineSlice) error {
		if !providersInSlice(masters, db.ProviderSlice{db.Amazon, db.Vagrant}) {
			return fmt.Errorf("bad masters: %s", dumper.Sdump(masters))
		}

		if !providersInSlice(workers, db.ProviderSlice{db.Azure, db.Google}) {
			return fmt.Errorf("bad workers: %s", dumper.Sdump(workers))
		}
		return nil
	})

	/* Test that machines with different providers don't match. */
	code = `
	(define Namespace "Namespace")
	(list (machine (provider "Amazon") (size "m4.large") (role "Master"))
	      (machine (provider "Azure") (size "a.large") (role "Master")))
	(list (machine (provider "Amazon") (size "m4.large") (role "Worker")))
	(define AdminACL (list "1.2.3.4/32"))`
	UpdatePolicy(conn, prog(t, code))
	inspect(func(masters, _ db.MachineSlice) error {
		if !providersInSlice(masters, db.ProviderSlice{db.Amazon, db.Azure}) {
			return fmt.Errorf("bad masters: %s", dumper.Sdump(masters))
		}
		return nil
	})
}
Beispiel #27
0
// TestSyncSecurityRules runs syncSecurityRules against a fake Azure cluster
// through four scenarios (initial rules, additional rules, duplicated local rules,
// and a partly replaced rule set) and checks after each that the cloud rules and
// their priorities match what is expected.
func TestSyncSecurityRules(t *testing.T) {
	fakeClst := newFakeAzureCluster()
	cancel := make(chan struct{})
	nsg := network.SecurityGroup{
		Name:     stringPtr("test-nsg"),
		ID:       stringPtr("test-nsg"),
		Location: stringPtr("location1"),
		Tags:     &map[string]*string{nsTag: &fakeClst.namespace},
	}
	fakeClst.azureClient.securityGroupCreate(resourceGroupName, *nsg.Name, nsg,
		cancel)

	// pushRules syncs *local into the security group, refreshes *cloud from the
	// fake client, and returns the priorities of the refreshed cloud rules.
	pushRules := func(local *[]network.SecurityRule,
		cloud *[]network.SecurityRule) ([]int32, error) {
		err := fakeClst.syncSecurityRules(*nsg.Name, *local, *cloud)
		if err != nil {
			return nil, err
		}

		listed, err := fakeClst.azureClient.securityRuleList(resourceGroupName,
			*nsg.Name)
		if err != nil {
			return nil, err
		}
		*cloud = *listed.Value

		priorities := []int32{}
		for _, rule := range *cloud {
			priorities = append(priorities, *rule.Properties.Priority)
		}
		return priorities, nil
	}

	// verify checks that the priority sets agree and that the local and cloud
	// rules describe the same traffic.
	verify := func(local []network.SecurityRule, wantPriorities []int32,
		cloud []network.SecurityRule, gotPriorities []int32) {
		identity := func(val interface{}) interface{} {
			return val.(int32)
		}

		matched, onlyWant, onlyGot := join.HashJoin(int32Slice(wantPriorities),
			int32Slice(gotPriorities), identity, identity)
		if len(matched) != len(cloud) || len(onlyWant) != 0 || len(onlyGot) != 0 {
			t.Log(matched, onlyWant, onlyGot)
			t.Error("Error setting security rule priorities.")
		}

		ruleKey := func(val interface{}) interface{} {
			props := val.(network.SecurityRule).Properties
			return struct {
				sourcePortRange      string
				sourceAddressPrefix  string
				destinationPortRange string
				destAddressPrefix    string
				direction            network.SecurityRuleDirection
			}{
				sourcePortRange:      *props.SourcePortRange,
				sourceAddressPrefix:  *props.SourceAddressPrefix,
				destinationPortRange: *props.DestinationPortRange,
				destAddressPrefix:    *props.DestinationAddressPrefix,
				direction:            props.Direction,
			}
		}

		same, onlyLocal, onlyCloud := join.HashJoin(securityRuleSlice(local),
			securityRuleSlice(cloud), ruleKey, ruleKey)
		if len(same) != len(cloud) || len(onlyLocal) != 0 || len(onlyCloud) != 0 {
			t.Error("Error setting security rules.")
		}
	}

	// newRule builds an inbound allow-all-ports rule whose ID and Name are both
	// id, restricted to the given source and destination address prefixes.
	newRule := func(id, srcPrefix, dstPrefix string) network.SecurityRule {
		return network.SecurityRule{
			ID:   stringPtr(id),
			Name: stringPtr(id),
			Properties: &network.SecurityRulePropertiesFormat{
				Protocol:                 network.Asterisk,
				SourcePortRange:          stringPtr("*"),
				SourceAddressPrefix:      stringPtr(srcPrefix),
				DestinationPortRange:     stringPtr("*"),
				DestinationAddressPrefix: stringPtr(dstPrefix),
				Access:                   network.Allow,
				Direction:                network.Inbound,
			},
		}
	}

	var localRules []network.SecurityRule
	cloudRules := []network.SecurityRule{}

	// syncAndVerify pushes the current localRules and checks the outcome
	// against the wanted priorities.
	syncAndVerify := func(wantPriorities []int32) {
		gotPriorities, err := pushRules(&localRules, &cloudRules)
		if err != nil {
			t.Error(err)
		}
		verify(localRules, wantPriorities, cloudRules, gotPriorities)
	}

	// Initially add two rules.
	rule1 := newRule("1", "10.0.0.1", "*")
	rule2 := newRule("2", "10.0.0.2", "*")
	localRules = []network.SecurityRule{rule1, rule2}
	syncAndVerify([]int32{100, 101})

	// Add two more rules.
	rule3 := newRule("3", "*", "10.0.0.3")
	rule4 := newRule("4", "*", "10.0.0.4")
	localRules = append(localRules, rule3, rule4)
	syncAndVerify([]int32{100, 101, 102, 103})

	// Add duplicate rules.
	localRules = append(localRules, rule3, rule4)
	syncAndVerify([]int32{100, 101, 102, 103})

	// Keep rule1, and add two new rules.
	rule5 := newRule("5", "1.2.3.4", "*")
	rule6 := newRule("6", "5.6.7.8", "*")
	localRules = []network.SecurityRule{rule1, rule5, rule6}
	syncAndVerify([]int32{100, 101, 102})
}
Beispiel #28
0
// SetACLs adds and removes acls in `clst` so that it conforms to `acls`.
//
// The current firewalls are listed and parsed into ACLs, then joined against the
// requested ones.  Requested ACLs are carried over as-is; ACLs present only in the
// cloud are replaced by an entry with an empty CidrIP for the same port range.
// The result is grouped by port range, and for each group whose firewall source
// ranges differ, the firewall is deleted when the group has no CIDRs and patched
// otherwise.  Each operation is waited on before moving to the next group; the
// first error encountered is returned.
func (clst *Cluster) SetACLs(acls []acl.ACL) error {
	firewalls, err := service.Firewalls.List(clst.projID).Do()
	if err != nil {
		return err
	}

	current := clst.parseACLs(firewalls.Items)
	matched, missing, stale := join.HashJoin(acl.Slice(acls), acl.Slice(current),
		nil, nil)

	var desired []acl.ACL
	for _, m := range missing {
		desired = append(desired, m.(acl.ACL))
	}
	for _, m := range matched {
		desired = append(desired, m.L.(acl.ACL))
	}
	for _, s := range stale {
		old := s.(acl.ACL)
		desired = append(desired, acl.ACL{
			MinPort: old.MinPort,
			MaxPort: old.MaxPort,
			CidrIP:  "", // Remove all currently allowed IPs.
		})
	}

	for ports, cidrIPs := range groupACLsByPorts(desired) {
		fw, err := clst.getCreateFirewall(ports.MinPort, ports.MaxPort)
		if err != nil {
			return err
		}

		// Nothing to do when the firewall already allows exactly these ranges.
		if reflect.DeepEqual(fw.SourceRanges, cidrIPs) {
			continue
		}

		entry := log.WithField("ports",
			fmt.Sprintf("%d-%d", ports.MinPort, ports.MaxPort))

		var op *compute.Operation
		if len(cidrIPs) != 0 {
			entry.WithField("CidrIPs", cidrIPs).Debug("Google: Setting ACLs")
			op, err = clst.firewallPatch(fw.Name, cidrIPs)
		} else {
			entry.Debug("Google: Deleting firewall")
			op, err = clst.firewallDelete(fw.Name)
		}
		if err != nil {
			return err
		}

		err = clst.operationWait([]*compute.Operation{op}, global)
		if err != nil {
			return err
		}
	}

	return nil
}
Beispiel #29
0
// TestACLs creates two ACLs on a fake OVSDB client, deletes them one at a time,
// and checks after every step that ListACLs reports exactly the expected rules.
func TestACLs(t *testing.T) {
	ovsdbClient := NewFakeOvsdbClient()

	// ACLs are compared on their Core only.
	key := func(val interface{}) interface{} {
		return val.(ACL).Core
	}

	// checkCorrectness asserts that every local ACL has a match among the ACLs
	// listed from OVSDB.
	checkCorrectness := func(ovsdbACLs []ACL, localACLs ...ACL) {
		pair, _, _ := join.HashJoin(ACLSlice(ovsdbACLs), ACLSlice(localACLs),
			key, key)
		assert.Equal(t, len(localACLs), len(pair))
	}

	// Create new switch.
	lswitch := "test-switch"
	err := ovsdbClient.CreateLogicalSwitch(lswitch)
	assert.Nil(t, err)

	// Create one ACL rule.
	localCore1 := ACLCore{
		Priority:  1,
		Direction: "from-lport",
		Match:     "0.0.0.0",
		Action:    "allow",
	}

	localACL1 := ACL{
		Core: localCore1,
		Log:  false,
	}

	err = ovsdbClient.CreateACL(lswitch, localCore1.Direction, localCore1.Priority,
		localCore1.Match, localCore1.Action)
	assert.Nil(t, err)

	// It should now have one ACL entry to be listed.
	ovsdbACLs, err := ovsdbClient.ListACLs(lswitch)
	assert.Nil(t, err)
	if !assert.Equal(t, 1, len(ovsdbACLs)) {
		// Stop here: indexing the slice below would panic.
		t.FailNow()
	}

	ovsdbACL1 := ovsdbACLs[0]

	checkCorrectness(ovsdbACLs, localACL1)

	// Create one more ACL rule.
	localCore2 := ACLCore{
		Priority:  2,
		Direction: "from-lport",
		Match:     "0.0.0.1",
		Action:    "drop",
	}
	localACL2 := ACL{
		Core: localCore2,
		Log:  false,
	}

	err = ovsdbClient.CreateACL(lswitch, localCore2.Direction, localCore2.Priority,
		localCore2.Match, localCore2.Action)
	assert.Nil(t, err)

	// It should now have two ACL entries to be listed.
	ovsdbACLs, err = ovsdbClient.ListACLs(lswitch)
	assert.Nil(t, err)
	assert.Equal(t, 2, len(ovsdbACLs))

	checkCorrectness(ovsdbACLs, localACL1, localACL2)

	// Delete the first ACL rule.
	err = ovsdbClient.DeleteACL(lswitch, ovsdbACL1)
	assert.Nil(t, err)

	// It should now have only one ACL entry to be listed.
	ovsdbACLs, err = ovsdbClient.ListACLs(lswitch)
	assert.Nil(t, err)
	if !assert.Equal(t, 1, len(ovsdbACLs)) {
		// Check the length before indexing so a wrong count fails cleanly.
		t.FailNow()
	}

	ovsdbACL2 := ovsdbACLs[0]

	checkCorrectness(ovsdbACLs, localACL2)

	// Delete the other ACL rule.
	err = ovsdbClient.DeleteACL(lswitch, ovsdbACL2)
	assert.Nil(t, err)

	// It should now have no ACL entries to be listed.
	ovsdbACLs, err = ovsdbClient.ListACLs(lswitch)
	assert.Nil(t, err)
	assert.Zero(t, len(ovsdbACLs))
}