Пример #1
0
// TestCART builds a CART tree from the iris dataset, reads the same file a
// second time, clears the class column of that second set, classifies it with
// the tree, and compares the predicted classes against the classes read from
// the first set.
//
// NOTE(review): `assert` and `NRows` are defined outside this block; the final
// boolean argument presumably means "values must be equal" -- confirm.
func TestCART(t *testing.T) {
	irisPath := "../../testdata/iris/iris.dsv"

	// Load the training set.
	trainset := tabula.Claset{}
	if _, err := dsv.SimpleRead(irisPath, &trainset); err != nil {
		t.Fatal(err)
	}

	fmt.Println("[cart_test] class index:", trainset.GetClassIndex())

	// Remember the known classes so the predictions can be checked later.
	expected := trainset.GetClassAsStrings()

	assert(t, NRows, trainset.GetNRow(), true)

	// Grow the tree using the Gini split method.
	tree, err := cart.New(&trainset, cart.SplitMethodGini, 0)
	if err != nil {
		t.Fatal(err)
	}

	fmt.Println("[cart_test] CART Tree:\n", tree)

	// Load the same file again as the test set and wipe its class column,
	// so every class value has to come from the classifier.
	unlabeled := tabula.Claset{}
	if _, err := dsv.SimpleRead(irisPath, &unlabeled); err != nil {
		t.Fatal(err)
	}

	unlabeled.GetClassColumn().ClearValues()

	// Classify the test set.
	if err := tree.ClassifySet(&unlabeled); err != nil {
		t.Fatal(err)
	}

	assert(t, expected, unlabeled.GetClassAsStrings(), true)
}
Пример #2
0
/*
GrowTree builds one new tree and adds it to the forest.

It returns the confusion matrix produced by the out-of-bag (OOB) run (nil when
forest.RunOOB is false), the statistic recorded for the new tree, and the error
returned by WriteOOBStat. If the CART tree cannot be built, it returns
(nil, nil, error) and the forest is left without the new tree.

Algorithm,

(1) Pick random rows from samples (with replacement, per the original note),
keeping the rows that were not picked as the OOB set.
(2) Build a tree from the picked rows using CART with the Gini split method.
(3) Add the tree to the forest.
(4) Save the indices of the picked rows for calculating error rate later.
(5) If RunOOB is set, classify the OOB set with the forest and store the
resulting confusion matrix.
(6) If RunOOB is set, compute the statistic values from the confusion matrix;
then update the total statistic and write the OOB statistic.
*/
func (forest *Runtime) GrowTree(samples tabula.ClasetInterface) (
	cm *classifier.CM, stat *classifier.Stat, e error,
) {
	// The statistic ID is the position the new tree will take in the forest.
	stat = &classifier.Stat{}
	stat.ID = int64(len(forest.trees))
	stat.Start()

	// (1)
	dataset := samples.(tabula.DatasetInterface)
	bag, oob, bagIdx, oobIdx := tabula.RandomPickRows(dataset,
		forest.nSubsample, true)

	bagset := bag.(tabula.ClasetInterface)

	if DEBUG >= 2 {
		bagset.RecountMajorMinor()
		fmt.Println(tag, "Bagging:", bagset)
	}

	// (2)
	tree, e := cart.New(bagset, cart.SplitMethodGini, forest.NRandomFeature)
	if e != nil {
		return nil, nil, e
	}

	// (3)
	forest.AddCartTree(*tree)

	// (4)
	forest.AddBagIndex(bagIdx)

	// (5)
	if forest.RunOOB {
		oobset := oob.(tabula.ClasetInterface)

		// Only the confusion matrix is needed here; the other two results
		// of ClassifySet are discarded.
		_, cm, _ = forest.ClassifySet(oobset, oobIdx)

		forest.AddOOBCM(cm)
	}

	stat.End()

	if forest.RunOOB && DEBUG >= 3 {
		fmt.Println(tag, "Elapsed time (s):", stat.ElapsedTime)
	}

	forest.AddStat(stat)

	// (6)
	if forest.RunOOB {
		forest.ComputeStatFromCM(stat, cm)

		if DEBUG >= 2 {
			fmt.Println(tag, "OOB stat:", stat)
		}
	}

	forest.ComputeStatTotal(stat)

	e = forest.WriteOOBStat(stat)

	return cm, stat, e
}