func TestCART(t *testing.T) { fds := "../../testdata/iris/iris.dsv" ds := tabula.Claset{} _, e := dsv.SimpleRead(fds, &ds) if nil != e { t.Fatal(e) } fmt.Println("[cart_test] class index:", ds.GetClassIndex()) // copy target to be compared later. targetv := ds.GetClassAsStrings() assert(t, NRows, ds.GetNRow(), true) // Build CART tree. CART, e := cart.New(&ds, cart.SplitMethodGini, 0) if e != nil { t.Fatal(e) } fmt.Println("[cart_test] CART Tree:\n", CART) // Create test set testset := tabula.Claset{} _, e = dsv.SimpleRead(fds, &testset) if nil != e { t.Fatal(e) } testset.GetClassColumn().ClearValues() // Classifiy test set e = CART.ClassifySet(&testset) if nil != e { t.Fatal(e) } assert(t, targetv, testset.GetClassAsStrings(), true) }
func TestReaderWithClaset(t *testing.T) { fcfg := "testdata/claset.dsv" claset := tabula.Claset{} _, e := dsv.NewReader(fcfg, &claset) if e != nil { t.Fatal(e) } assert(t, 3, claset.GetClassIndex(), true) claset.SetMajorityClass("regular") claset.SetMinorityClass("vandalism") clone := claset.Clone().(tabula.ClasetInterface) assert(t, 3, clone.GetClassIndex(), true) assert(t, "regular", clone.MajorityClass(), true) assert(t, "vandalism", clone.MinorityClass(), true) }
func main() { defer un(trace("cart")) flag.Parse() if len(flag.Args()) <= 0 { usage() os.Exit(1) } fcfg := flag.Arg(0) // Parsing config file and check command parameter values. cartrt, e := createCart(fcfg) if e != nil { panic(e) } // Get dataset dataset := tabula.Claset{} _, e = dsv.SimpleRead(fcfg, &dataset) if e != nil { panic(e) } if DEBUG >= 1 { fmt.Printf("[cart] Class index: %v\n", dataset.GetClassIndex()) } e = cartrt.Build(&dataset) if e != nil { panic(e) } if DEBUG >= 1 { fmt.Println("[cart] CART tree:\n", cartrt) } }
func getSamples() (train, test tabula.ClasetInterface) { samples := tabula.Claset{} _, e := dsv.SimpleRead(SampleDsvFile, &samples) if nil != e { log.Fatal(e) } if !DoTest { return &samples, nil } ntrain := int(float32(samples.Len()) * (float32(NBootstrap) / 100.0)) bag, oob, _, _ := tabula.RandomPickRows(&samples, ntrain, false) train = bag.(tabula.ClasetInterface) test = oob.(tabula.ClasetInterface) train.SetClassIndex(samples.GetClassIndex()) test.SetClassIndex(samples.GetClassIndex()) return train, test }