func syncACLs(ovsdbClient ovsdb.Client, connections []db.Connection) { ovsdbACLs, err := ovsdbClient.ListACLs(lSwitch) if err != nil { log.WithError(err).Error("Failed to list ACLs") return } expACLs := directedACLs(ovsdb.ACL{ Core: ovsdb.ACLCore{ Action: "drop", Match: "ip", Priority: 0, }, }) for _, conn := range connections { if conn.From == stitch.PublicInternetLabel || conn.To == stitch.PublicInternetLabel { continue } expACLs = append(expACLs, directedACLs( ovsdb.ACL{ Core: ovsdb.ACLCore{ Action: "allow", Match: matchString(conn), Priority: 1, }, })...) } ovsdbKey := func(ovsdbIntf interface{}) interface{} { return ovsdbIntf.(ovsdb.ACL).Core } _, toCreate, toDelete := join.HashJoin(ovsdbACLSlice(expACLs), ovsdbACLSlice(ovsdbACLs), ovsdbKey, ovsdbKey) for _, acl := range toDelete { if err := ovsdbClient.DeleteACL(lSwitch, acl.(ovsdb.ACL)); err != nil { log.WithError(err).Warn("Error deleting ACL") } } for _, intf := range toCreate { acl := intf.(ovsdb.ACL).Core if err := ovsdbClient.CreateACL(lSwitch, acl.Direction, acl.Priority, acl.Match, acl.Action); err != nil { log.WithError(err).Warn("Error adding ACL") } } }
func checkACLs(t *testing.T, client ovsdb.Client, connections []db.Connection, exp []ovsdb.ACL) { syncACLs(client, connections) actual, _ := client.ListACLs(lSwitch) ovsdbKey := func(ovsdbIntf interface{}) interface{} { return ovsdbIntf.(ovsdb.ACL).Core } if _, left, right := join.HashJoin(ovsdbACLSlice(actual), ovsdbACLSlice(exp), ovsdbKey, ovsdbKey); len(left) != 0 || len(right) != 0 { t.Errorf("Wrong ACLs: expected %v, got %v.", exp, actual) } }