示例#1
0
文件: index.go 项目: tsandall/opa
func (ind *bindingIndex) getNode(val interface{}) *indexNode {
	hashCode := hash(val)
	head := ind.table[hashCode]
	for entry := head; entry != nil; entry = entry.next {
		if util.Compare(entry.key, val) == 0 {
			return entry
		}
	}
	return nil
}
示例#2
0
func reduceMax(x interface{}) (ast.Value, error) {
	switch x := x.(type) {
	case []interface{}:
		if len(x) == 0 {
			return nil, empty{}
		}
		var max interface{}
		for i := range x {
			if util.Compare(max, x[i]) <= 0 {
				max = x[i]
			}
		}
		return ast.InterfaceToValue(max)
	}
	return nil, fmt.Errorf("max: source must be array")
}
示例#3
0
func TestInit(t *testing.T) {
	ctx := context.Background()

	tmp1, err := ioutil.TempFile("", "docFile")
	if err != nil {
		panic(err)
	}
	defer os.Remove(tmp1.Name())

	doc1 := `{"foo": "bar", "x": {"y": {"z": [1]}}}`
	if _, err := tmp1.Write([]byte(doc1)); err != nil {
		panic(err)
	}
	if err := tmp1.Close(); err != nil {
		panic(err)
	}

	tmp2, err := ioutil.TempFile("", "policyFile")
	if err != nil {
		panic(err)
	}
	defer os.Remove(tmp2.Name())
	mod1 := `
	package a.b.c
	import data.foo
	p = true :- foo = "bar"
	p = true :- 1 = 2
	`
	if _, err := tmp2.Write([]byte(mod1)); err != nil {
		panic(err)
	}
	if err := tmp2.Close(); err != nil {
		panic(err)
	}

	tmp3, err := ioutil.TempDir("", "policyDir")
	if err != nil {
		panic(err)
	}

	defer os.RemoveAll(tmp3)

	tmp4 := filepath.Join(tmp3, "existingPolicy")

	err = ioutil.WriteFile(tmp4, []byte(`
	package a.b.c
	q = true :- false
	`), 0644)
	if err != nil {
		panic(err)
	}

	rt := Runtime{}

	err = rt.init(ctx, &Params{
		Paths:     []string{tmp1.Name(), tmp2.Name()},
		PolicyDir: tmp3,
	})

	if err != nil {
		t.Errorf("Unexpected error: %v", err)
		return
	}

	txn := storage.NewTransactionOrDie(ctx, rt.Store)

	node, err := rt.Store.Read(ctx, txn, storage.MustParsePath("/foo"))
	if util.Compare(node, "bar") != 0 || err != nil {
		t.Errorf("Expected %v but got %v (err: %v)", "bar", node, err)
		return
	}

	modules := rt.Store.ListPolicies(txn)
	expected := ast.MustParseModule(mod1)

	if !expected.Equal(modules[tmp2.Name()]) {
		t.Fatalf("Expected %v but got: %v", expected, modules[tmp2.Name()])
	}

}
示例#4
0
文件: compare.go 项目: tsandall/opa
// Compare returns an integer indicating whether two AST values are less than,
// equal to, or greater than each other.
//
// If a is less than b, the return value is negative. If a is greater than b,
// the return value is positive. If a is equal to b, the return value is zero.
//
// Different types are never equal to each other. For comparison purposes, types
// are sorted as follows:
//
// nil < Null < Boolean < Number < String < Var < Ref < Array < Object < Set <
// ArrayComprehension < Expr < Body < Rule < Import < Package < Module.
//
// Arrays and Refs are equal iff both a and b have the same length and all
// corresponding elements are equal. If one element is not equal, the return
// value is the same as for the first differing element. If all elements are
// equal but a and b have different lengths, the shorter is considered less than
// the other.
//
// Objects are considered equal iff both a and b have the same sorted (key,
// value) pairs and are of the same length. Other comparisons are consistent but
// not defined.
//
// Sets are considered equal iff the symmetric difference of a and b is empty.
// Other comparisons are consistent but not defined.
func Compare(a, b interface{}) int {

	if t, ok := a.(*Term); ok {
		if t == nil {
			return Compare(nil, b)
		}
		return Compare(t.Value, b)
	}

	if t, ok := b.(*Term); ok {
		if t == nil {
			return Compare(a, nil)
		}
		return Compare(a, t.Value)
	}

	if a == nil {
		if b == nil {
			return 0
		}
		return -1
	}
	if b == nil {
		return 1
	}

	sortA := sortOrder(a)
	sortB := sortOrder(b)

	if sortA < sortB {
		return -1
	} else if sortB < sortA {
		return 1
	}

	switch a := a.(type) {
	case Null:
		return 0
	case Boolean:
		b := b.(Boolean)
		if a.Equal(b) {
			return 0
		}
		if !a {
			return -1
		}
		return 1
	case Number:
		return util.Compare(json.Number(a), json.Number(b.(Number)))
	case String:
		b := b.(String)
		if a.Equal(b) {
			return 0
		}
		if a < b {
			return -1
		}
		return 1
	case Var:
		b := b.(Var)
		if a.Equal(b) {
			return 0
		}
		if a < b {
			return -1
		}
		return 1
	case Ref:
		b := b.(Ref)
		return termSliceCompare(a, b)
	case Array:
		b := b.(Array)
		return termSliceCompare(a, b)
	case Object:
		b := b.(Object)
		keysA := a.Keys()
		keysB := b.Keys()
		sort.Sort(termSlice(keysA))
		sort.Sort(termSlice(keysB))
		minLen := len(a)
		if len(b) < len(a) {
			minLen = len(b)
		}
		for i := 0; i < minLen; i++ {
			keysCmp := Compare(keysA[i], keysB[i])
			if keysCmp < 0 {
				return -1
			}
			if keysCmp > 0 {
				return 1
			}
			valA := a.Get(keysA[i])
			valB := b.Get(keysB[i])
			valCmp := Compare(valA, valB)
			if valCmp != 0 {
				return valCmp
			}
		}
		if len(a) < len(b) {
			return -1
		}
		if len(b) < len(a) {
			return 1
		}
		return 0
	case *Set:
		b := b.(*Set)
		sort.Sort(termSlice(*a))
		sort.Sort(termSlice(*b))
		return termSliceCompare(*a, *b)
	case *ArrayComprehension:
		b := b.(*ArrayComprehension)
		if cmp := Compare(a.Term, b.Term); cmp != 0 {
			return cmp
		}
		return Compare(a.Body, b.Body)
	case *Expr:
		b := b.(*Expr)
		return a.Compare(b)
	case Body:
		b := b.(Body)
		return a.Compare(b)
	case *Rule:
		b := b.(*Rule)
		return a.Compare(b)
	case *Import:
		b := b.(*Import)
		return a.Compare(b)
	case *Package:
		b := b.(*Package)
		return a.Compare(b)
	case *Module:
		b := b.(*Module)
		return a.Compare(b)
	}
	panic(fmt.Sprintf("illegal value: %T", a))
}