func (ind *bindingIndex) getNode(val interface{}) *indexNode { hashCode := hash(val) head := ind.table[hashCode] for entry := head; entry != nil; entry = entry.next { if util.Compare(entry.key, val) == 0 { return entry } } return nil }
func reduceMax(x interface{}) (ast.Value, error) { switch x := x.(type) { case []interface{}: if len(x) == 0 { return nil, empty{} } var max interface{} for i := range x { if util.Compare(max, x[i]) <= 0 { max = x[i] } } return ast.InterfaceToValue(max) } return nil, fmt.Errorf("max: source must be array") }
func TestInit(t *testing.T) { ctx := context.Background() tmp1, err := ioutil.TempFile("", "docFile") if err != nil { panic(err) } defer os.Remove(tmp1.Name()) doc1 := `{"foo": "bar", "x": {"y": {"z": [1]}}}` if _, err := tmp1.Write([]byte(doc1)); err != nil { panic(err) } if err := tmp1.Close(); err != nil { panic(err) } tmp2, err := ioutil.TempFile("", "policyFile") if err != nil { panic(err) } defer os.Remove(tmp2.Name()) mod1 := ` package a.b.c import data.foo p = true :- foo = "bar" p = true :- 1 = 2 ` if _, err := tmp2.Write([]byte(mod1)); err != nil { panic(err) } if err := tmp2.Close(); err != nil { panic(err) } tmp3, err := ioutil.TempDir("", "policyDir") if err != nil { panic(err) } defer os.RemoveAll(tmp3) tmp4 := filepath.Join(tmp3, "existingPolicy") err = ioutil.WriteFile(tmp4, []byte(` package a.b.c q = true :- false `), 0644) if err != nil { panic(err) } rt := Runtime{} err = rt.init(ctx, &Params{ Paths: []string{tmp1.Name(), tmp2.Name()}, PolicyDir: tmp3, }) if err != nil { t.Errorf("Unexpected error: %v", err) return } txn := storage.NewTransactionOrDie(ctx, rt.Store) node, err := rt.Store.Read(ctx, txn, storage.MustParsePath("/foo")) if util.Compare(node, "bar") != 0 || err != nil { t.Errorf("Expected %v but got %v (err: %v)", "bar", node, err) return } modules := rt.Store.ListPolicies(txn) expected := ast.MustParseModule(mod1) if !expected.Equal(modules[tmp2.Name()]) { t.Fatalf("Expected %v but got: %v", expected, modules[tmp2.Name()]) } }
// Compare returns an integer indicating whether two AST values are less than, // equal to, or greater than each other. // // If a is less than b, the return value is negative. If a is greater than b, // the return value is positive. If a is equal to b, the return value is zero. // // Different types are never equal to each other. For comparison purposes, types // are sorted as follows: // // nil < Null < Boolean < Number < String < Var < Ref < Array < Object < Set < // ArrayComprehension < Expr < Body < Rule < Import < Package < Module. // // Arrays and Refs are equal iff both a and b have the same length and all // corresponding elements are equal. If one element is not equal, the return // value is the same as for the first differing element. If all elements are // equal but a and b have different lengths, the shorter is considered less than // the other. // // Objects are considered equal iff both a and b have the same sorted (key, // value) pairs and are of the same length. Other comparisons are consistent but // not defined. // // Sets are considered equal iff the symmetric difference of a and b is empty. // Other comparisons are consistent but not defined. func Compare(a, b interface{}) int { if t, ok := a.(*Term); ok { if t == nil { return Compare(nil, b) } return Compare(t.Value, b) } if t, ok := b.(*Term); ok { if t == nil { return Compare(a, nil) } return Compare(a, t.Value) } if a == nil { if b == nil { return 0 } return -1 } if b == nil { return 1 } sortA := sortOrder(a) sortB := sortOrder(b) if sortA < sortB { return -1 } else if sortB < sortA { return 1 } switch a := a.(type) { case Null: return 0 case Boolean: b := b.(Boolean) if a.Equal(b) { return 0 } if !a { return -1 } return 1 case Number: return util.Compare(json.Number(a), json.Number(b.(Number))) case String: b := b.(String) if a.Equal(b) { return 0 } if a < b { return -1 } return 1 case Var: b := b.(Var) if a.Equal(b) { return 0 } if a < b { return -1 } return 1 case Ref: b := b.(Ref) return termSliceCompare(a, b) case Array: b := b.(Array) return termSliceCompare(a, b) case Object: b := b.(Object) keysA := a.Keys() keysB := b.Keys() sort.Sort(termSlice(keysA)) sort.Sort(termSlice(keysB)) minLen := len(a) if len(b) < len(a) { minLen = len(b) } for i := 0; i < minLen; i++ { keysCmp := Compare(keysA[i], keysB[i]) if keysCmp < 0 { return -1 } if keysCmp > 0 { return 1 } valA := a.Get(keysA[i]) valB := b.Get(keysB[i]) valCmp := Compare(valA, valB) if valCmp != 0 { return valCmp } } if len(a) < len(b) { return -1 } if len(b) < len(a) { return 1 } return 0 case *Set: b := b.(*Set) sort.Sort(termSlice(*a)) sort.Sort(termSlice(*b)) return termSliceCompare(*a, *b) case *ArrayComprehension: b := b.(*ArrayComprehension) if cmp := Compare(a.Term, b.Term); cmp != 0 { return cmp } return Compare(a.Body, b.Body) case *Expr: b := b.(*Expr) return a.Compare(b) case Body: b := b.(Body) return a.Compare(b) case *Rule: b := b.(*Rule) return a.Compare(b) case *Import: b := b.(*Import) return a.Compare(b) case *Package: b := b.(*Package) return a.Compare(b) case *Module: b := b.(*Module) return a.Compare(b) } panic(fmt.Sprintf("illegal value: %T", a)) }