示例#1
0
// TestRandomForest1 fits a BaggedModel of ten random trees on a
// Chi-Merge-discretised training split of the iris dataset, then prints the
// model, its predictions on the held-out split, the confusion matrix, the
// macro precision, the macro recall and the summary. It makes no assertions
// beyond the dataset loading successfully.
func TestRandomForest1(testEnv *testing.T) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		// Fail only this test; a panic would take down the whole test binary.
		testEnv.Fatal(err)
	}

	rand.Seed(time.Now().UnixNano())
	trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)

	// The filter is built from the complete dataset and then applied lazily
	// to both splits.
	filt := filters.NewChiMergeFilter(inst, 0.90)
	for _, a := range base.NonClassFloatAttributes(inst) {
		filt.AddAttribute(a)
	}
	filt.Train()
	trainDataf := base.NewLazilyFilteredInstances(trainData, filt)
	testDataf := base.NewLazilyFilteredInstances(testData, filt)

	rf := new(BaggedModel)
	for i := 0; i < 10; i++ {
		rf.AddModel(trees.NewRandomTree(2))
	}
	rf.Fit(trainDataf)
	fmt.Println(rf)

	predictions := rf.Predict(testDataf)
	fmt.Println(predictions)

	confusionMat := eval.GetConfusionMatrix(testDataf, predictions)
	fmt.Println(confusionMat)
	fmt.Println(eval.GetMacroPrecision(confusionMat))
	fmt.Println(eval.GetMacroRecall(confusionMat))
	fmt.Println(eval.GetSummary(confusionMat))
}
示例#2
0
// TestChiMergeFilter checks that a lazily filtered view of the iris dataset,
// built from a Chi-Merge filter over the first two attributes, exposes
// exactly one class attribute and that it is named "Species".
func TestChiMergeFilter(t *testing.T) {
	const wantClassName = "Species"

	Convey("Chi-Merge Filter", t, func() {
		// Reference: Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992
		//   http://sci2s.ugr.es/keel/pdf/algorithm/congreso/1992-Kerber-ChimErge-AAAI92.pdf
		iris, loadErr := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
		So(loadErr, ShouldBeNil)

		Convey("Create and train the filter", func() {
			chiMerge := NewChiMergeFilter(iris, 0.90)
			allAttrs := iris.AllAttributes()
			chiMerge.AddAttribute(allAttrs[0])
			chiMerge.AddAttribute(allAttrs[1])
			chiMerge.Train()

			Convey("Filter the dataset", func() {
				view := base.NewLazilyFilteredInstances(iris, chiMerge)
				classAttrs := view.AllClassAttributes()

				Convey("There should only be one class attribute", func() {
					So(len(classAttrs), ShouldEqual, 1)
				})

				title := fmt.Sprintf("The class attribute should be %s", wantClassName)
				Convey(title, func() {
					So(classAttrs[0].GetName(), ShouldEqual, wantClassName)
				})
			})
		})
	})
}
示例#3
0
// main loads the iris dataset, builds a Chi-Merge-discretised view of it,
// passes a training split of the unfiltered data to findBestSplit and then
// prints the classifier variable and the discretised view.
func main() {
	// Fixed seed; set before any other work, as in the original ordering.
	rand.Seed(44111342)

	// Never assigned in this function, so the Println below shows a nil interface.
	var tree base.Classifier

	// Read the iris dataset (the file has a header row).
	dataset, err := base.ParseCSVToInstances("/home/kralli/go/src/github.com/sjwhitworth/golearn/examples/datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}

	// Register every non-class float attribute with a Chi-Merge filter.
	chiMerge := filters.NewChiMergeFilter(dataset, 0.999)
	for _, attr := range base.NonClassFloatAttributes(dataset) {
		chiMerge.AddAttribute(attr)
	}
	chiMerge.Train()
	discretised := base.NewLazilyFilteredInstances(dataset, chiMerge)

	// 60-40 training-test split of the unfiltered data; the test part is unused.
	training, _ := base.InstancesTrainTestSplit(dataset, 0.60)

	findBestSplit(training)

	fmt.Println(tree)
	fmt.Println(discretised)
}
示例#4
0
// Predict issues predictions. Each class-specific classifier is expected
// to output a value between 0 (indicating that a given instance is not
// a given class) and 1 (indicating that the given instance is definitely
// that class). For each instance, the class with the highest value is chosen.
// When several underlying models output the same highest value, the lowest
// class value is kept. Outputs outside [0, 1] are still ranked against each
// other (a row whose outputs are all negative gets its least-negative class
// rather than always class 0), and NaN outputs are never selected.
//
// It returns the prediction vector generated from what, or the first error
// returned by an underlying classifier's Predict.
func (m *OneVsAllModel) Predict(what base.FixedDataGrid) (base.FixedDataGrid, error) {
	ret := base.GeneratePredictionVector(what)

	// Collect one prediction grid (and its class AttributeSpec) per class value.
	vecs := make([]base.FixedDataGrid, m.maxClassVal+1)
	specs := make([]base.AttributeSpec, m.maxClassVal+1)
	for i := uint64(0); i <= m.maxClassVal; i++ {
		f := m.filters[i]
		c := base.NewLazilyFilteredInstances(what, f)
		p, err := m.classifiers[i].Predict(c)
		if err != nil {
			return nil, err
		}
		vecs[i] = p
		specs[i] = base.ResolveAttributes(p, p.AllClassAttributes())[0]
	}

	_, rows := ret.Size()
	spec := base.ResolveAttributes(ret, ret.AllClassAttributes())[0]
	for row := 0; row < rows; row++ {
		class := uint64(0)
		best := 0.0
		seen := false // becomes true once a comparable (non-NaN) score is recorded
		for j := uint64(0); j <= m.maxClassVal; j++ {
			val := base.UnpackBytesToFloat(vecs[j].Get(specs[j], row))
			if val != val {
				// NaN compares false against everything; never let it win.
				continue
			}
			// Strict '>' keeps the earliest class on ties.
			if !seen || val > best {
				class, best, seen = j, val, true
			}
		}
		ret.Set(spec, row, base.PackU64ToBytes(class))
	}
	return ret, nil
}
示例#5
0
// TestBinaryFilterClassPreservation registers every attribute of the
// contrived dataset (class attribute included) with a binary-conversion
// filter and checks that the filtered view exposes the class attributes
// arbitraryClass_hi, arbitraryClass_there and arbitraryClass_world.
func TestBinaryFilterClassPreservation(t *testing.T) {
	Convey("Given a contrived dataset...", t, func() {
		// Load the fixture.
		data, err := base.ParseCSVToInstances("./binary_test.csv", true)
		So(err, ShouldEqual, nil)

		// Feed every attribute to the filter.
		binFilter := NewBinaryConvertFilter()
		for _, attr := range data.AllAttributes() {
			binFilter.AddAttribute(attr)
		}
		binFilter.Train()

		// Wrap the dataset in a lazily filtered view.
		filtered := base.NewLazilyFilteredInstances(data, binFilter)

		Convey("All the expected class Attributes should be present if discretised...", func() {
			expected := []string{
				"arbitraryClass_hi",
				"arbitraryClass_there",
				"arbitraryClass_world",
			}

			seen := make(map[string]bool)
			for _, name := range expected {
				seen[name] = false
			}
			for _, attr := range filtered.AllClassAttributes() {
				seen[attr.GetName()] = true
			}

			for _, name := range expected {
				So(seen[name], ShouldEqual, true)
			}
		})
	})
}
示例#6
0
// BenchmarkBaggingRandomForestPredict measures Predict on a BaggedModel of
// ten random trees fitted to the Chi-Merge-discretised iris dataset.
// Loading, filtering and fitting happen before the timer is reset, so only
// the Predict calls are timed.
func BenchmarkBaggingRandomForestPredict(b *testing.B) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		// Fatalf, not Fatal: the message contains a formatting verb.
		b.Fatalf("Unable to parse CSV to instances: %s", err.Error())
	}

	rand.Seed(time.Now().UnixNano())
	filt := filters.NewChiMergeFilter(inst, 0.90)
	for _, a := range base.NonClassFloatAttributes(inst) {
		filt.AddAttribute(a)
	}
	filt.Train()
	instf := base.NewLazilyFilteredInstances(inst, filt)

	rf := new(BaggedModel)
	for i := 0; i < 10; i++ {
		rf.AddModel(trees.NewRandomTree(2))
	}

	rf.Fit(instf)
	b.ResetTimer()
	// The framework picks b.N; a fixed iteration count makes ns/op meaningless.
	for i := 0; i < b.N; i++ {
		rf.Predict(instf)
	}
}
示例#7
0
// TestRandomTreeClassificationAfterDiscretisation splits the iris dataset,
// wraps both halves in a Chi-Merge-filtered view and hands them to
// verifyTreeClassification.
func TestRandomTreeClassificationAfterDiscretisation(t *testing.T) {
	Convey("Predictions on filtered data with a Random Tree", t, func() {
		iris, loadErr := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
		So(loadErr, ShouldBeNil)

		rawTrain, rawTest := base.InstancesTrainTestSplit(iris, 0.6)

		// The filter is constructed from the whole dataset, then applied to each split.
		chiMerge := filters.NewChiMergeFilter(iris, 0.9)
		for _, attr := range base.NonClassFloatAttributes(iris) {
			chiMerge.AddAttribute(attr)
		}
		chiMerge.Train()

		train := base.NewLazilyFilteredInstances(rawTrain, chiMerge)
		test := base.NewLazilyFilteredInstances(rawTest, chiMerge)
		verifyTreeClassification(train, test)
	})
}
示例#8
0
// convertToFloatInsts returns a lazily filtered view of X in which every
// attribute of X has been registered with a float-conversion filter.
func (m *MultiLayerNet) convertToFloatInsts(X base.FixedDataGrid) base.FixedDataGrid {
	toFloat := filters.NewFloatConvertFilter()
	for _, attr := range X.AllAttributes() {
		toFloat.AddAttribute(attr)
	}
	toFloat.Train()

	return base.NewLazilyFilteredInstances(X, toFloat)
}
示例#9
0
// convertToBinary returns a lazily filtered view of src in which every
// non-class attribute has been registered with a binary-conversion filter.
func convertToBinary(src base.FixedDataGrid) base.FixedDataGrid {
	toBinary := filters.NewBinaryConvertFilter()
	for _, attr := range base.NonClassAttributes(src) {
		toBinary.AddAttribute(attr)
	}
	toBinary.Train()

	return base.NewLazilyFilteredInstances(src, toBinary)
}
示例#10
0
// TestBaggedModelRandomForest fits a BaggedModel of ten random trees on a
// Chi-Merge-filtered training split of the iris dataset and asserts that its
// accuracy on the filtered held-out split exceeds 0.5.
func TestBaggedModelRandomForest(t *testing.T) {
	Convey("Given data", t, func() {
		iris, loadErr := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
		So(loadErr, ShouldBeNil)

		Convey("Splitting the data into training and test data", func() {
			rawTrain, rawTest := base.InstancesTrainTestSplit(iris, 0.6)

			Convey("Filtering the split datasets", func() {
				rand.Seed(time.Now().UnixNano())

				chiMerge := filters.NewChiMergeFilter(iris, 0.90)
				for _, attr := range base.NonClassFloatAttributes(iris) {
					chiMerge.AddAttribute(attr)
				}
				chiMerge.Train()

				train := base.NewLazilyFilteredInstances(rawTrain, chiMerge)
				test := base.NewLazilyFilteredInstances(rawTest, chiMerge)

				Convey("Fitting and Predicting with a Bagged Model of 10 Random Trees", func() {
					bagged := new(BaggedModel)
					for n := 0; n < 10; n++ {
						bagged.AddModel(trees.NewRandomTree(2))
					}

					bagged.Fit(train)
					predicted := bagged.Predict(test)

					matrix, cmErr := evaluation.GetConfusionMatrix(test, predicted)
					So(cmErr, ShouldBeNil)

					Convey("Predictions are somewhat accurate", func() {
						So(evaluation.GetAccuracy(matrix), ShouldBeGreaterThan, 0.5)
					})
				})
			})
		})
	})
}
示例#11
0
// TestBinning discretises the first attribute of the iris dataset with a
// BinningFilter created with the argument 10, and checks row by row that the
// string form of each filtered value matches the same attribute in the
// pre-binned reference file iris_binned.csv.
func TestBinning(t *testing.T) {
	Convey("Given some data and a reference", t, func() {
		// Read the data. Load failures are reported as assertion failures,
		// matching the error checks further down, rather than as panics.
		inst1, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
		So(err, ShouldEqual, nil)

		inst2, err := base.ParseCSVToInstances("../examples/datasets/iris_binned.csv", true)
		So(err, ShouldEqual, nil)

		// Construct the binning filter over the first attribute only.
		binAttr := inst1.AllAttributes()[0]
		filt := NewBinningFilter(inst1, 10)
		filt.AddAttribute(binAttr)
		filt.Train()
		inst1f := base.NewLazilyFilteredInstances(inst1, filt)

		// Retrieve the filtered view's version of the original Attribute,
		// matched by name.
		var cAttr base.Attribute
		for _, a := range inst1f.AllAttributes() {
			if a.GetName() == binAttr.GetName() {
				cAttr = a
			}
		}

		cAttrSpec, err := inst1f.GetAttribute(cAttr)
		So(err, ShouldEqual, nil)
		binAttrSpec, err := inst2.GetAttribute(binAttr)
		So(err, ShouldEqual, nil)

		// Compare the filtered values against the reference, row by row.
		Convey("Discretized version should match reference", func() {
			_, rows := inst1.Size()
			for i := 0; i < rows; i++ {
				val1 := inst1f.Get(cAttrSpec, i)
				val2 := inst2.Get(binAttrSpec, i)
				val1s := cAttr.GetStringFromSysVal(val1)
				val2s := binAttr.GetStringFromSysVal(val2)
				So(val1s, ShouldEqual, val2s)
			}
		})
	})
}
示例#12
0
// TestRandomForest covers two scenarios on the iris dataset: a forest built
// with NewRandomForest(10, 3) must reach an accuracy above 0.35 on a
// Chi-Merge-filtered split, and Fit must return an error when the forest is
// asked for one more feature than there are non-class attributes.
func TestRandomForest(t *testing.T) {
	Convey("Given a valid CSV file", t, func() {
		iris, loadErr := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
		So(loadErr, ShouldBeNil)

		Convey("When Chi-Merge filtering the data", func() {
			chiMerge := filters.NewChiMergeFilter(iris, 0.90)
			for _, attr := range base.NonClassFloatAttributes(iris) {
				chiMerge.AddAttribute(attr)
			}
			chiMerge.Train()
			filtered := base.NewLazilyFilteredInstances(iris, chiMerge)

			Convey("Splitting the data into test and training sets", func() {
				train, test := base.InstancesTrainTestSplit(filtered, 0.60)

				Convey("Fitting and predicting with a Random Forest", func() {
					forest := NewRandomForest(10, 3)
					fitErr := forest.Fit(train)
					So(fitErr, ShouldBeNil)

					predicted, predictErr := forest.Predict(test)
					So(predictErr, ShouldBeNil)

					matrix, cmErr := evaluation.GetConfusionMatrix(test, predicted)
					So(cmErr, ShouldBeNil)

					Convey("Predictions should be somewhat accurate", func() {
						So(evaluation.GetAccuracy(matrix), ShouldBeGreaterThan, 0.35)
					})
				})
			})
		})

		Convey("Fitting with a Random Forest with too many features compared to the data", func() {
			tooManyFeatures := len(base.NonClassAttributes(iris)) + 1
			forest := NewRandomForest(10, tooManyFeatures)
			fitErr := forest.Fit(iris)

			Convey("Should return an error", func() {
				So(fitErr, ShouldNotBeNil)
			})
		})
	})
}
示例#13
0
// Fit creates n filtered datasets (where n is the number of values
// a CategoricalAttribute can take) and uses them to train the
// underlying classifiers.
//
// It panics if using has no class Attribute, if any class Attribute is not
// a *base.CategoricalAttribute, or if the highest stored class value is 0
// (i.e. there is only one class). When several class Attributes are present,
// the last one returned by AllClassAttributes is used.
func (m *OneVsAllModel) Fit(using base.FixedDataGrid) {
	// Validate the class Attributes and keep the last one seen.
	var classAttr *base.CategoricalAttribute
	for _, a := range using.AllClassAttributes() {
		c, ok := a.(*base.CategoricalAttribute)
		if !ok {
			panic("Unsupported ClassAttribute type")
		}
		classAttr = c
	}
	if classAttr == nil {
		// Without this guard the GetValues call below dereferences nil.
		panic("Must have a class Attribute!")
	}
	attrs := m.generateAttributes(using)

	// Find the highest stored value
	val := uint64(0)
	classVals := classAttr.GetValues()
	for _, s := range classVals {
		cur := base.UnpackBytesToU64(classAttr.GetSysValFromString(s))
		if cur > val {
			val = cur
		}
	}
	if val == 0 {
		panic("Must have more than one class!")
	}
	m.maxClassVal = val

	// Create individual filtered instances for training, one per class value.
	// NOTE(review): indexing classVals by the stored value assumes stored
	// values are the positions of the strings in classVals — confirm.
	classFilters := make([]*oneVsAllFilter, val+1)
	classifiers := make([]base.Classifier, val+1)
	for i := uint64(0); i <= val; i++ {
		f := &oneVsAllFilter{
			attrs,
			classAttr,
			i,
		}
		classFilters[i] = f
		classifiers[i] = m.NewClassifierFunction(classVals[int(i)])
		classifiers[i].Fit(base.NewLazilyFilteredInstances(using, f))
	}

	m.filters = classFilters
	m.classifiers = classifiers
}
示例#14
0
// TestChiMerge4 applies a Chi-Merge filter to the first two attributes of
// the iris dataset, prints the filtered view, and fails unless that view has
// exactly one class attribute named "Species".
func TestChiMerge4(testEnv *testing.T) {
	// See http://sci2s.ugr.es/keel/pdf/algorithm/congreso/1992-Kerber-ChimErge-AAAI92.pdf
	//   Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		// Report through the testing framework instead of panicking, so the
		// remaining tests in the binary still run.
		testEnv.Fatal(err)
	}

	filt := NewChiMergeFilter(inst, 0.90)
	filt.AddAttribute(inst.AllAttributes()[0])
	filt.AddAttribute(inst.AllAttributes()[1])
	filt.Train()
	instf := base.NewLazilyFilteredInstances(inst, filt)
	fmt.Println(instf)
	fmt.Println(instf.String())

	clsAttrs := instf.AllClassAttributes()
	if len(clsAttrs) != 1 {
		testEnv.Fatalf("%d != %d", len(clsAttrs), 1)
	}
	if clsAttrs[0].GetName() != "Species" {
		testEnv.Fatal("Class Attribute wrong!")
	}
}
示例#15
0
// TestRandomForest1 fits a forest built with NewRandomForest(10, 3) on a
// training split of the Chi-Merge-discretised iris dataset and prints its
// predictions on the held-out split, the confusion matrix and the summary.
// It makes no assertions beyond the dataset loading successfully.
func TestRandomForest1(testEnv *testing.T) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		// Fail only this test; a panic would take down the whole test binary.
		testEnv.Fatal(err)
	}

	filt := filters.NewChiMergeFilter(inst, 0.90)
	for _, a := range base.NonClassFloatAttributes(inst) {
		filt.AddAttribute(a)
	}
	filt.Train()
	instf := base.NewLazilyFilteredInstances(inst, filt)

	// Here the split is taken from the already-filtered view.
	trainData, testData := base.InstancesTrainTestSplit(instf, 0.60)

	rf := NewRandomForest(10, 3)
	rf.Fit(trainData)
	predictions := rf.Predict(testData)
	fmt.Println(predictions)

	confusionMat := eval.GetConfusionMatrix(testData, predictions)
	fmt.Println(confusionMat)
	fmt.Println(eval.GetSummary(confusionMat))
}
示例#16
0
// TestRandomTreeClassification fits trees on a Chi-Merge-filtered training
// split of the iris dataset and asserts a minimum accuracy on the filtered
// held-out split for: InferID3Tree with a RandomTreeRuleGenerator,
// InferID3Tree with an InformationGainRuleGenerator, and NewRandomTree(2)
// both without pruning (all above 0.5) and after pruning (above 0.4).
func TestRandomTreeClassification(t *testing.T) {
	Convey("Predictions on filtered data with a Random Tree", t, func() {
		iris, loadErr := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
		So(loadErr, ShouldBeNil)

		rawTrain, rawTest := base.InstancesTrainTestSplit(iris, 0.6)

		// The filter is constructed from the whole dataset, then applied to each split.
		chiMerge := filters.NewChiMergeFilter(iris, 0.9)
		for _, attr := range base.NonClassFloatAttributes(iris) {
			chiMerge.AddAttribute(attr)
		}
		chiMerge.Train()
		train := base.NewLazilyFilteredInstances(rawTrain, chiMerge)
		test := base.NewLazilyFilteredInstances(rawTest, chiMerge)

		Convey("Using InferID3Tree to create the tree and do the fitting", func() {
			Convey("Using a RandomTreeRule", func() {
				rule := &RandomTreeRuleGenerator{Attributes: 2}
				tree := InferID3Tree(train, rule)

				Convey("Predicting with the tree", func() {
					predicted, predictErr := tree.Predict(test)
					So(predictErr, ShouldBeNil)

					matrix, cmErr := evaluation.GetConfusionMatrix(test, predicted)
					So(cmErr, ShouldBeNil)

					Convey("Predictions should be somewhat accurate", func() {
						So(evaluation.GetAccuracy(matrix), ShouldBeGreaterThan, 0.5)
					})
				})
			})

			Convey("Using a InformationGainRule", func() {
				rule := &InformationGainRuleGenerator{}
				tree := InferID3Tree(train, rule)

				Convey("Predicting with the tree", func() {
					predicted, predictErr := tree.Predict(test)
					So(predictErr, ShouldBeNil)

					matrix, cmErr := evaluation.GetConfusionMatrix(test, predicted)
					So(cmErr, ShouldBeNil)

					Convey("Predictions should be somewhat accurate", func() {
						So(evaluation.GetAccuracy(matrix), ShouldBeGreaterThan, 0.5)
					})
				})
			})
		})

		Convey("Using NewRandomTree to create the tree", func() {
			tree := NewRandomTree(2)

			Convey("Fitting with the tree", func() {
				fitErr := tree.Fit(train)
				So(fitErr, ShouldBeNil)

				Convey("Predicting with the tree, *without* pruning first", func() {
					predicted, predictErr := tree.Predict(test)
					So(predictErr, ShouldBeNil)

					matrix, cmErr := evaluation.GetConfusionMatrix(test, predicted)
					So(cmErr, ShouldBeNil)

					Convey("Predictions should be somewhat accurate", func() {
						So(evaluation.GetAccuracy(matrix), ShouldBeGreaterThan, 0.5)
					})
				})

				Convey("Predicting with the tree, pruning first", func() {
					// Pruning is driven by the same split that is evaluated below.
					tree.Prune(test)

					predicted, predictErr := tree.Predict(test)
					So(predictErr, ShouldBeNil)

					matrix, cmErr := evaluation.GetConfusionMatrix(test, predicted)
					So(cmErr, ShouldBeNil)

					Convey("Predictions should be somewhat accurate", func() {
						So(evaluation.GetAccuracy(matrix), ShouldBeGreaterThan, 0.4)
					})
				})
			})
		})
	})
}
示例#17
0
// TestFloatFilter runs the contrived dataset's non-class attributes through
// a float-conversion filter and checks the filtered view: non-class
// attributes are FloatAttributes, class attributes are preserved, the
// attribute names are exactly the expected set, and the first three rows of
// each converted attribute hold the expected packed float values.
func TestFloatFilter(t *testing.T) {
	// Every attribute name the filtered view is expected to expose.
	wantNames := []string{
		"floatAttr",
		"shouldBe1Binary",
		"shouldBe3Binary_stoicism",
		"shouldBe3Binary_heroism",
		"shouldBe3Binary_romanticism",
		"arbitraryClass",
	}
	// Expected values of rows 0..2 for each converted attribute. The class
	// attribute has no entry, so none of its values are checked.
	wantRows := map[string][3]float64{
		"floatAttr":                   {1.0, 1.0, 0.0},
		"shouldBe1Binary":             {0.0, 1.0, 1.0},
		"shouldBe3Binary_stoicism":    {1.0, 0.0, 0.0},
		"shouldBe3Binary_heroism":     {0.0, 1.0, 0.0},
		"shouldBe3Binary_romanticism": {0.0, 0.0, 1.0},
	}

	Convey("Given a contrived dataset...", t, func() {
		// Load the fixture.
		data, err := base.ParseCSVToInstances("./binary_test.csv", true)
		So(err, ShouldEqual, nil)

		// Register the non-class attributes with the filter.
		floatFilter := NewFloatConvertFilter()
		for _, attr := range base.NonClassAttributes(data) {
			floatFilter.AddAttribute(attr)
		}
		floatFilter.Train()

		// Wrap the dataset in a lazily filtered view.
		filtered := base.NewLazilyFilteredInstances(data, floatFilter)

		Convey("All the non-class Attributes should be floats...", func() {
			for _, attr := range base.NonClassAttributes(filtered) {
				_, isFloat := attr.(*base.FloatAttribute)
				So(isFloat, ShouldEqual, true)
			}
		})

		Convey("All the class Attributes should have survived...", func() {
			before := data.AllClassAttributes()
			after := filtered.AllClassAttributes()
			common := base.AttributeIntersect(before, after)
			So(len(common), ShouldEqual, len(before))
		})

		Convey("Attribute names should be correct...", func() {
			found := make(map[string]bool)
			for _, name := range wantNames {
				found[name] = false
			}
			// Every exposed name must be expected...
			for _, attr := range filtered.AllAttributes() {
				name := attr.GetName()
				_, expected := found[name]
				So(expected, ShouldBeTrue)

				found[name] = true
			}
			// ...and every expected name must have been exposed.
			for _, seen := range found {
				So(seen, ShouldEqual, true)
			}
		})

		Convey("All Attributes should be the correct type...", func() {
			for _, attr := range filtered.AllAttributes() {
				var ok bool
				if attr.GetName() == "arbitraryClass" {
					_, ok = attr.(*base.CategoricalAttribute)
				} else {
					_, ok = attr.(*base.FloatAttribute)
				}
				So(ok, ShouldEqual, true)
			}
		})

		Convey("FloatConversion should have worked", func() {
			// Index the filtered attributes by name.
			byName := make(map[string]base.Attribute)
			for _, attr := range filtered.AllAttributes() {
				byName[attr.GetName()] = attr
			}

			for name, attr := range byName {
				So(name, ShouldBeIn, wantNames)

				spec, err := filtered.GetAttribute(attr)
				So(err, ShouldEqual, nil)

				want, checked := wantRows[name]
				if !checked {
					continue
				}
				for row, v := range want {
					So(filtered.Get(spec, row), ShouldResemble, base.PackFloatToBytes(v))
				}
			}
		})
	})
}
示例#18
0
// TestBinaryFilter runs the contrived dataset's non-class attributes through
// a binary-conversion filter and checks the filtered view: non-class
// attributes are BinaryAttributes, class attributes are preserved, the
// attribute names are the expected set, and the first three rows of each
// converted attribute hold the expected single-byte values.
func TestBinaryFilter(t *testing.T) {

	Convey("Given a contrived dataset...", t, func() {

		// Read the contrived dataset
		inst, err := base.ParseCSVToInstances("./binary_test.csv", true)
		So(err, ShouldEqual, nil)

		// Add Attributes to the filter
		bFilt := NewBinaryConvertFilter()
		bAttrs := base.NonClassAttributes(inst)
		for _, a := range bAttrs {
			bFilt.AddAttribute(a)
		}
		bFilt.Train()

		// Construct a LazilyFilteredInstances to handle it
		instF := base.NewLazilyFilteredInstances(inst, bFilt)

		Convey("All the non-class Attributes should be binary...", func() {
			// Check that all the Attributes are the right type
			for _, a := range base.NonClassAttributes(instF) {
				_, ok := a.(*base.BinaryAttribute)
				So(ok, ShouldEqual, true)
			}
		})

		// Check that all the class Attributes made it
		Convey("All the class Attributes should have survived...", func() {
			origClassAttrs := inst.AllClassAttributes()
			newClassAttrs := instF.AllClassAttributes()
			intersectClassAttrs := base.AttributeIntersect(origClassAttrs, newClassAttrs)
			So(len(intersectClassAttrs), ShouldEqual, len(origClassAttrs))
		})

		// Check that the Attributes have the right names
		Convey("Attribute names should be correct...", func() {
			origNames := []string{"floatAttr", "shouldBe1Binary",
				"shouldBe3Binary_stoicism", "shouldBe3Binary_heroism",
				"shouldBe3Binary_romanticism", "arbitraryClass"}
			origMap := make(map[string]bool)
			for _, a := range origNames {
				origMap[a] = false
			}
			for _, a := range instF.AllAttributes() {
				name := a.GetName()
				if _, ok := origMap[name]; !ok {
					t.Error(fmt.Sprintf("Weird: %s", name))
				}
				origMap[name] = true
			}
			for a := range origMap {
				So(origMap[a], ShouldEqual, true)
			}
		})

		// Check that the Attributes have been discretised correctly
		Convey("Discretisation should have worked", func() {
			// Expected values of rows 0..2 for each converted attribute.
			wantRows := map[string][3]byte{
				"floatAttr":                   {1, 1, 0},
				"shouldBe1Binary":             {0, 1, 1},
				"shouldBe3Binary_stoicism":    {1, 0, 0},
				"shouldBe3Binary_heroism":     {0, 1, 0},
				"shouldBe3Binary_romanticism": {0, 0, 1},
			}

			// Build Attribute map
			attrMap := make(map[string]base.Attribute)
			for _, a := range instF.AllAttributes() {
				attrMap[a.GetName()] = a
			}

			for name, attr := range attrMap {
				// Retrieve AttributeSpec
				as, err := instF.GetAttribute(attr)
				So(err, ShouldEqual, nil)

				want, known := wantRows[name]
				switch {
				case known:
					for row, v := range want {
						So(instF.Get(as, row), ShouldResemble, []byte{v})
					}
				case name == "arbitraryClass":
					// The class attribute is expected; its values are not checked.
				default:
					// Errorf, not Error: the message contains a formatting verb.
					t.Errorf("Shouldn't have %s", name)
				}
			}
		})

	})

}
示例#19
0
文件: trees.go 项目: JacobXie/golearn
// main discretises the iris dataset with Chi-Merge, splits it 60-40 into
// training and test data, and prints a performance summary for three
// classifiers in turn: an ID3 decision tree, a random tree and a random
// forest. Every classifier is fitted on the training split and evaluated on
// the test split.
func main() {

	var tree base.Classifier

	rand.Seed(time.Now().UTC().UnixNano())

	// Load in the iris dataset
	iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}

	// Discretise the iris dataset with Chi-Merge
	filt := filters.NewChiMergeFilter(iris, 0.99)
	for _, a := range base.NonClassFloatAttributes(iris) {
		filt.AddAttribute(a)
	}
	filt.Train()
	irisf := base.NewLazilyFilteredInstances(iris, filt)

	// Create a 60-40 training-test split
	trainData, testData := base.InstancesTrainTestSplit(irisf, 0.60)

	// evaluate fits the current tree on the training split, predicts the
	// test split and prints the summary under the given heading.
	evaluate := func(heading string) {
		tree.Fit(trainData)
		predictions := tree.Predict(testData)
		fmt.Println(heading)
		cf := eval.GetConfusionMatrix(testData, predictions)
		fmt.Println(eval.GetSummary(cf))
	}

	//
	// First up, use ID3
	//
	tree = trees.NewID3DecisionTree(0.6)
	// (Parameter controls train-prune split.)
	evaluate("ID3 Performance")

	//
	// Next up, Random Trees
	//

	// Consider two randomly-chosen attributes. Fitted on the training split:
	// fitting on the test split would leak the evaluation data into the model.
	tree = trees.NewRandomTree(2)
	evaluate("RandomTree Performance")

	//
	// Finally, Random Forests
	//
	tree = ensemble.NewRandomForest(100, 3)
	evaluate("RandomForest Performance")
}
示例#20
0
文件: trees.go 项目: CTLife/golearn
// main discretises the iris dataset with Chi-Merge, splits it 60-40 into
// training and test data, and prints a performance summary for five
// classifiers in turn: ID3 with its default rule, ID3 with the information
// gain ratio rule, ID3 with the Gini coefficient rule, a random tree and a
// random forest. Every classifier is fitted on the training split and
// evaluated on the test split. Any error from fitting, predicting or
// building the confusion matrix panics.
func main() {

	var tree base.Classifier

	rand.Seed(44111342)

	// Load in the iris dataset
	iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}

	// Discretise the iris dataset with Chi-Merge
	filt := filters.NewChiMergeFilter(iris, 0.999)
	for _, a := range base.NonClassFloatAttributes(iris) {
		filt.AddAttribute(a)
	}
	filt.Train()
	irisf := base.NewLazilyFilteredInstances(iris, filt)

	// Create a 60-40 training-test split
	trainData, testData := base.InstancesTrainTestSplit(irisf, 0.60)

	// evaluate fits the current tree on the training split, predicts the
	// test split and prints the summary under the given heading.
	evaluate := func(heading string) {
		if err := tree.Fit(trainData); err != nil {
			panic(err)
		}
		predictions, err := tree.Predict(testData)
		if err != nil {
			panic(err)
		}
		fmt.Println(heading)
		cf, err := evaluation.GetConfusionMatrix(testData, predictions)
		if err != nil {
			panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
		}
		fmt.Println(evaluation.GetSummary(cf))
	}

	//
	// First up, use ID3
	//
	tree = trees.NewID3DecisionTree(0.6)
	// (Parameter controls train-prune split.)
	evaluate("ID3 Performance (information gain)")

	tree = trees.NewID3DecisionTreeFromRule(0.6, new(trees.InformationGainRatioRuleGenerator))
	// (Parameter controls train-prune split.)
	evaluate("ID3 Performance (information gain ratio)")

	tree = trees.NewID3DecisionTreeFromRule(0.6, new(trees.GiniCoefficientRuleGenerator))
	// (Parameter controls train-prune split.)
	evaluate("ID3 Performance (gini index generator)")

	//
	// Next up, Random Trees
	//

	// Consider two randomly-chosen attributes. Fitted on the training split:
	// fitting on the test split would leak the evaluation data into the model.
	tree = trees.NewRandomTree(2)
	evaluate("RandomTree Performance")

	//
	// Finally, Random Forests
	//
	tree = ensemble.NewRandomForest(70, 3)
	evaluate("RandomForest Performance")
}