func TestRandomForest1(testEnv *testing.T) { inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) if err != nil { panic(err) } rand.Seed(time.Now().UnixNano()) trainData, testData := base.InstancesTrainTestSplit(inst, 0.6) filt := filters.NewChiMergeFilter(inst, 0.90) for _, a := range base.NonClassFloatAttributes(inst) { filt.AddAttribute(a) } filt.Train() trainDataf := base.NewLazilyFilteredInstances(trainData, filt) testDataf := base.NewLazilyFilteredInstances(testData, filt) rf := new(BaggedModel) for i := 0; i < 10; i++ { rf.AddModel(trees.NewRandomTree(2)) } rf.Fit(trainDataf) fmt.Println(rf) predictions := rf.Predict(testDataf) fmt.Println(predictions) confusionMat := eval.GetConfusionMatrix(testDataf, predictions) fmt.Println(confusionMat) fmt.Println(eval.GetMacroPrecision(confusionMat)) fmt.Println(eval.GetMacroRecall(confusionMat)) fmt.Println(eval.GetSummary(confusionMat)) }
func TestChiMergeFilter(t *testing.T) { Convey("Chi-Merge Filter", t, func() { // See http://sci2s.ugr.es/keel/pdf/algorithm/congreso/1992-Kerber-ChimErge-AAAI92.pdf // Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992 instances, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) So(err, ShouldBeNil) Convey("Create and train the filter", func() { filter := NewChiMergeFilter(instances, 0.90) filter.AddAttribute(instances.AllAttributes()[0]) filter.AddAttribute(instances.AllAttributes()[1]) filter.Train() Convey("Filter the dataset", func() { filteredInstances := base.NewLazilyFilteredInstances(instances, filter) classAttributes := filteredInstances.AllClassAttributes() Convey("There should only be one class attribute", func() { So(len(classAttributes), ShouldEqual, 1) }) expectedClassAttribute := "Species" Convey(fmt.Sprintf("The class attribute should be %s", expectedClassAttribute), func() { So(classAttributes[0].GetName(), ShouldEqual, expectedClassAttribute) }) }) }) }) }
func main() { var tree base.Classifier rand.Seed(44111342) // Load in the iris dataset iris, err := base.ParseCSVToInstances("/home/kralli/go/src/github.com/sjwhitworth/golearn/examples/datasets/iris_headers.csv", true) if err != nil { panic(err) } // Discretise the iris dataset with Chi-Merge filt := filters.NewChiMergeFilter(iris, 0.999) for _, a := range base.NonClassFloatAttributes(iris) { filt.AddAttribute(a) } filt.Train() irisf := base.NewLazilyFilteredInstances(iris, filt) // Create a 60-40 training-test split //testData trainData, _ := base.InstancesTrainTestSplit(iris, 0.60) findBestSplit(trainData) //fmt.Println(trainData) //fmt.Println(testData) fmt.Println(tree) fmt.Println(irisf) }
// Predict issues predictions. Each class-specific classifier is expected // to output a value between 0 (indicating that a given instance is not // a given class) and 1 (indicating that the given instance is definitely // that class). For each instance, the class with the highest value is chosen. // The result is undefined if several underlying models output the same value. func (m *OneVsAllModel) Predict(what base.FixedDataGrid) (base.FixedDataGrid, error) { ret := base.GeneratePredictionVector(what) vecs := make([]base.FixedDataGrid, m.maxClassVal+1) specs := make([]base.AttributeSpec, m.maxClassVal+1) for i := uint64(0); i <= m.maxClassVal; i++ { f := m.filters[i] c := base.NewLazilyFilteredInstances(what, f) p, err := m.classifiers[i].Predict(c) if err != nil { return nil, err } vecs[i] = p specs[i] = base.ResolveAttributes(p, p.AllClassAttributes())[0] } _, rows := ret.Size() spec := base.ResolveAttributes(ret, ret.AllClassAttributes())[0] for i := 0; i < rows; i++ { class := uint64(0) best := 0.0 for j := uint64(0); j <= m.maxClassVal; j++ { val := base.UnpackBytesToFloat(vecs[j].Get(specs[j], i)) if val > best { class = j best = val } } ret.Set(spec, i, base.PackU64ToBytes(class)) } return ret, nil }
func TestBinaryFilterClassPreservation(t *testing.T) { Convey("Given a contrived dataset...", t, func() { // Read the contrived dataset inst, err := base.ParseCSVToInstances("./binary_test.csv", true) So(err, ShouldEqual, nil) // Add all Attributes to the filter bFilt := NewBinaryConvertFilter() bAttrs := inst.AllAttributes() for _, a := range bAttrs { bFilt.AddAttribute(a) } bFilt.Train() // Construct a LazilyFilteredInstances to handle it instF := base.NewLazilyFilteredInstances(inst, bFilt) Convey("All the expected class Attributes should be present if discretised...", func() { attrMap := make(map[string]bool) attrMap["arbitraryClass_hi"] = false attrMap["arbitraryClass_there"] = false attrMap["arbitraryClass_world"] = false for _, a := range instF.AllClassAttributes() { attrMap[a.GetName()] = true } So(attrMap["arbitraryClass_hi"], ShouldEqual, true) So(attrMap["arbitraryClass_there"], ShouldEqual, true) So(attrMap["arbitraryClass_world"], ShouldEqual, true) }) }) }
func BenchmarkBaggingRandomForestPredict(t *testing.B) { inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) if err != nil { t.Fatal("Unable to parse CSV to instances: %s", err.Error()) } rand.Seed(time.Now().UnixNano()) filt := filters.NewChiMergeFilter(inst, 0.90) for _, a := range base.NonClassFloatAttributes(inst) { filt.AddAttribute(a) } filt.Train() instf := base.NewLazilyFilteredInstances(inst, filt) rf := new(BaggedModel) for i := 0; i < 10; i++ { rf.AddModel(trees.NewRandomTree(2)) } rf.Fit(instf) t.ResetTimer() for i := 0; i < 20; i++ { rf.Predict(instf) } }
func TestRandomTreeClassificationAfterDiscretisation(t *testing.T) { Convey("Predictions on filtered data with a Random Tree", t, func() { instances, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) So(err, ShouldBeNil) trainData, testData := base.InstancesTrainTestSplit(instances, 0.6) filter := filters.NewChiMergeFilter(instances, 0.9) for _, a := range base.NonClassFloatAttributes(instances) { filter.AddAttribute(a) } filter.Train() filteredTrainData := base.NewLazilyFilteredInstances(trainData, filter) filteredTestData := base.NewLazilyFilteredInstances(testData, filter) verifyTreeClassification(filteredTrainData, filteredTestData) }) }
func (m *MultiLayerNet) convertToFloatInsts(X base.FixedDataGrid) base.FixedDataGrid { // Make sure everything's a FloatAttribute fFilt := filters.NewFloatConvertFilter() for _, a := range X.AllAttributes() { fFilt.AddAttribute(a) } fFilt.Train() insts := base.NewLazilyFilteredInstances(X, fFilt) return insts }
func convertToBinary(src base.FixedDataGrid) base.FixedDataGrid { // Convert to binary b := filters.NewBinaryConvertFilter() attrs := base.NonClassAttributes(src) for _, a := range attrs { b.AddAttribute(a) } b.Train() ret := base.NewLazilyFilteredInstances(src, b) return ret }
func TestBaggedModelRandomForest(t *testing.T) { Convey("Given data", t, func() { inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) So(err, ShouldBeNil) Convey("Splitting the data into training and test data", func() { trainData, testData := base.InstancesTrainTestSplit(inst, 0.6) Convey("Filtering the split datasets", func() { rand.Seed(time.Now().UnixNano()) filt := filters.NewChiMergeFilter(inst, 0.90) for _, a := range base.NonClassFloatAttributes(inst) { filt.AddAttribute(a) } filt.Train() trainDataf := base.NewLazilyFilteredInstances(trainData, filt) testDataf := base.NewLazilyFilteredInstances(testData, filt) Convey("Fitting and Predicting with a Bagged Model of 10 Random Trees", func() { rf := new(BaggedModel) for i := 0; i < 10; i++ { rf.AddModel(trees.NewRandomTree(2)) } rf.Fit(trainDataf) predictions := rf.Predict(testDataf) confusionMat, err := evaluation.GetConfusionMatrix(testDataf, predictions) So(err, ShouldBeNil) Convey("Predictions are somewhat accurate", func() { So(evaluation.GetAccuracy(confusionMat), ShouldBeGreaterThan, 0.5) }) }) }) }) }) }
func TestBinning(t *testing.T) { Convey("Given some data and a reference", t, func() { // Read the data inst1, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) if err != nil { panic(err) } inst2, err := base.ParseCSVToInstances("../examples/datasets/iris_binned.csv", true) if err != nil { panic(err) } // // Construct the binning filter binAttr := inst1.AllAttributes()[0] filt := NewBinningFilter(inst1, 10) filt.AddAttribute(binAttr) filt.Train() inst1f := base.NewLazilyFilteredInstances(inst1, filt) // Retrieve the categorical version of the original Attribute var cAttr base.Attribute for _, a := range inst1f.AllAttributes() { if a.GetName() == binAttr.GetName() { cAttr = a } } cAttrSpec, err := inst1f.GetAttribute(cAttr) So(err, ShouldEqual, nil) binAttrSpec, err := inst2.GetAttribute(binAttr) So(err, ShouldEqual, nil) // // Create the LazilyFilteredInstances // and check the values Convey("Discretized version should match reference", func() { _, rows := inst1.Size() for i := 0; i < rows; i++ { val1 := inst1f.Get(cAttrSpec, i) val2 := inst2.Get(binAttrSpec, i) val1s := cAttr.GetStringFromSysVal(val1) val2s := binAttr.GetStringFromSysVal(val2) So(val1s, ShouldEqual, val2s) } }) }) }
func TestRandomForest(t *testing.T) { Convey("Given a valid CSV file", t, func() { inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) So(err, ShouldBeNil) Convey("When Chi-Merge filtering the data", func() { filt := filters.NewChiMergeFilter(inst, 0.90) for _, a := range base.NonClassFloatAttributes(inst) { filt.AddAttribute(a) } filt.Train() instf := base.NewLazilyFilteredInstances(inst, filt) Convey("Splitting the data into test and training sets", func() { trainData, testData := base.InstancesTrainTestSplit(instf, 0.60) Convey("Fitting and predicting with a Random Forest", func() { rf := NewRandomForest(10, 3) err = rf.Fit(trainData) So(err, ShouldBeNil) predictions, err := rf.Predict(testData) So(err, ShouldBeNil) confusionMat, err := evaluation.GetConfusionMatrix(testData, predictions) So(err, ShouldBeNil) Convey("Predictions should be somewhat accurate", func() { So(evaluation.GetAccuracy(confusionMat), ShouldBeGreaterThan, 0.35) }) }) }) }) Convey("Fitting with a Random Forest with too many features compared to the data", func() { rf := NewRandomForest(10, len(base.NonClassAttributes(inst))+1) err = rf.Fit(inst) Convey("Should return an error", func() { So(err, ShouldNotBeNil) }) }) }) }
// Fit creates n filtered datasets (where n is the number of values // a CategoricalAttribute can take) and uses them to train the // underlying classifiers. func (m *OneVsAllModel) Fit(using base.FixedDataGrid) { var classAttr *base.CategoricalAttribute // Do some validation classAttrs := using.AllClassAttributes() for _, a := range classAttrs { if c, ok := a.(*base.CategoricalAttribute); !ok { panic("Unsupported ClassAttribute type") } else { classAttr = c } } attrs := m.generateAttributes(using) // Find the highest stored value val := uint64(0) classVals := classAttr.GetValues() for _, s := range classVals { cur := base.UnpackBytesToU64(classAttr.GetSysValFromString(s)) if cur > val { val = cur } } if val == 0 { panic("Must have more than one class!") } m.maxClassVal = val // Create individual filtered instances for training filters := make([]*oneVsAllFilter, val+1) classifiers := make([]base.Classifier, val+1) for i := uint64(0); i <= val; i++ { f := &oneVsAllFilter{ attrs, classAttr, i, } filters[i] = f classifiers[i] = m.NewClassifierFunction(classVals[int(i)]) classifiers[i].Fit(base.NewLazilyFilteredInstances(using, f)) } m.filters = filters m.classifiers = classifiers }
func TestChiMerge4(testEnv *testing.T) { // See http://sci2s.ugr.es/keel/pdf/algorithm/congreso/1992-Kerber-ChimErge-AAAI92.pdf // Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992 inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) if err != nil { panic(err) } filt := NewChiMergeFilter(inst, 0.90) filt.AddAttribute(inst.AllAttributes()[0]) filt.AddAttribute(inst.AllAttributes()[1]) filt.Train() instf := base.NewLazilyFilteredInstances(inst, filt) fmt.Println(instf) fmt.Println(instf.String()) clsAttrs := instf.AllClassAttributes() if len(clsAttrs) != 1 { panic(fmt.Sprintf("%d != %d", len(clsAttrs), 1)) } if clsAttrs[0].GetName() != "Species" { panic("Class Attribute wrong!") } }
func TestRandomForest1(testEnv *testing.T) { inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) if err != nil { panic(err) } filt := filters.NewChiMergeFilter(inst, 0.90) for _, a := range base.NonClassFloatAttributes(inst) { filt.AddAttribute(a) } filt.Train() instf := base.NewLazilyFilteredInstances(inst, filt) trainData, testData := base.InstancesTrainTestSplit(instf, 0.60) rf := NewRandomForest(10, 3) rf.Fit(trainData) predictions := rf.Predict(testData) fmt.Println(predictions) confusionMat := eval.GetConfusionMatrix(testData, predictions) fmt.Println(confusionMat) fmt.Println(eval.GetSummary(confusionMat)) }
func TestRandomTreeClassification(t *testing.T) { Convey("Predictions on filtered data with a Random Tree", t, func() { instances, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) So(err, ShouldBeNil) trainData, testData := base.InstancesTrainTestSplit(instances, 0.6) filter := filters.NewChiMergeFilter(instances, 0.9) for _, a := range base.NonClassFloatAttributes(instances) { filter.AddAttribute(a) } filter.Train() filteredTrainData := base.NewLazilyFilteredInstances(trainData, filter) filteredTestData := base.NewLazilyFilteredInstances(testData, filter) Convey("Using InferID3Tree to create the tree and do the fitting", func() { Convey("Using a RandomTreeRule", func() { randomTreeRuleGenerator := new(RandomTreeRuleGenerator) randomTreeRuleGenerator.Attributes = 2 root := InferID3Tree(filteredTrainData, randomTreeRuleGenerator) Convey("Predicting with the tree", func() { predictions, err := root.Predict(filteredTestData) So(err, ShouldBeNil) confusionMatrix, err := evaluation.GetConfusionMatrix(filteredTestData, predictions) So(err, ShouldBeNil) Convey("Predictions should be somewhat accurate", func() { So(evaluation.GetAccuracy(confusionMatrix), ShouldBeGreaterThan, 0.5) }) }) }) Convey("Using a InformationGainRule", func() { informationGainRuleGenerator := new(InformationGainRuleGenerator) root := InferID3Tree(filteredTrainData, informationGainRuleGenerator) Convey("Predicting with the tree", func() { predictions, err := root.Predict(filteredTestData) So(err, ShouldBeNil) confusionMatrix, err := evaluation.GetConfusionMatrix(filteredTestData, predictions) So(err, ShouldBeNil) Convey("Predictions should be somewhat accurate", func() { So(evaluation.GetAccuracy(confusionMatrix), ShouldBeGreaterThan, 0.5) }) }) }) }) Convey("Using NewRandomTree to create the tree", func() { root := NewRandomTree(2) Convey("Fitting with the tree", func() { err = root.Fit(filteredTrainData) So(err, ShouldBeNil) Convey("Predicting with the tree, *without* pruning first", func() { predictions, err := root.Predict(filteredTestData) So(err, ShouldBeNil) confusionMatrix, err := evaluation.GetConfusionMatrix(filteredTestData, predictions) So(err, ShouldBeNil) Convey("Predictions should be somewhat accurate", func() { So(evaluation.GetAccuracy(confusionMatrix), ShouldBeGreaterThan, 0.5) }) }) Convey("Predicting with the tree, pruning first", func() { root.Prune(filteredTestData) predictions, err := root.Predict(filteredTestData) So(err, ShouldBeNil) confusionMatrix, err := evaluation.GetConfusionMatrix(filteredTestData, predictions) So(err, ShouldBeNil) Convey("Predictions should be somewhat accurate", func() { So(evaluation.GetAccuracy(confusionMatrix), ShouldBeGreaterThan, 0.4) }) }) }) }) }) }
func TestFloatFilter(t *testing.T) { Convey("Given a contrived dataset...", t, func() { // Read the contrived dataset inst, err := base.ParseCSVToInstances("./binary_test.csv", true) So(err, ShouldEqual, nil) // Add Attributes to the filter bFilt := NewFloatConvertFilter() bAttrs := base.NonClassAttributes(inst) for _, a := range bAttrs { bFilt.AddAttribute(a) } bFilt.Train() // Construct a LazilyFilteredInstances to handle it instF := base.NewLazilyFilteredInstances(inst, bFilt) Convey("All the non-class Attributes should be floats...", func() { // Check that all the Attributes are the right type for _, a := range base.NonClassAttributes(instF) { _, ok := a.(*base.FloatAttribute) So(ok, ShouldEqual, true) } }) // Check that all the class Attributes made it Convey("All the class Attributes should have survived...", func() { origClassAttrs := inst.AllClassAttributes() newClassAttrs := instF.AllClassAttributes() intersectClassAttrs := base.AttributeIntersect(origClassAttrs, newClassAttrs) So(len(intersectClassAttrs), ShouldEqual, len(origClassAttrs)) }) // Check that the Attributes have the right names Convey("Attribute names should be correct...", func() { origNames := []string{"floatAttr", "shouldBe1Binary", "shouldBe3Binary_stoicism", "shouldBe3Binary_heroism", "shouldBe3Binary_romanticism", "arbitraryClass"} origMap := make(map[string]bool) for _, a := range origNames { origMap[a] = false } for _, a := range instF.AllAttributes() { name := a.GetName() _, ok := origMap[name] So(ok, ShouldBeTrue) origMap[name] = true } for a := range origMap { So(origMap[a], ShouldEqual, true) } }) Convey("All Attributes should be the correct type...", func() { for _, a := range instF.AllAttributes() { if a.GetName() == "arbitraryClass" { _, ok := a.(*base.CategoricalAttribute) So(ok, ShouldEqual, true) } else { _, ok := a.(*base.FloatAttribute) So(ok, ShouldEqual, true) } } }) // Check that the Attributes have been discretised correctly Convey("FloatConversion should have worked", func() { // Build Attribute map attrMap := make(map[string]base.Attribute) for _, a := range instF.AllAttributes() { attrMap[a.GetName()] = a } // For each attribute for name := range attrMap { So(name, ShouldBeIn, []string{ "floatAttr", "shouldBe1Binary", "shouldBe3Binary_stoicism", "shouldBe3Binary_heroism", "shouldBe3Binary_romanticism", "arbitraryClass", }) attr := attrMap[name] as, err := instF.GetAttribute(attr) So(err, ShouldEqual, nil) if name == "floatAttr" { So(instF.Get(as, 0), ShouldResemble, base.PackFloatToBytes(1.0)) So(instF.Get(as, 1), ShouldResemble, base.PackFloatToBytes(1.0)) So(instF.Get(as, 2), ShouldResemble, base.PackFloatToBytes(0.0)) } else if name == "shouldBe1Binary" { So(instF.Get(as, 0), ShouldResemble, base.PackFloatToBytes(0.0)) So(instF.Get(as, 1), ShouldResemble, base.PackFloatToBytes(1.0)) So(instF.Get(as, 2), ShouldResemble, base.PackFloatToBytes(1.0)) } else if name == "shouldBe3Binary_stoicism" { So(instF.Get(as, 0), ShouldResemble, base.PackFloatToBytes(1.0)) So(instF.Get(as, 1), ShouldResemble, base.PackFloatToBytes(0.0)) So(instF.Get(as, 2), ShouldResemble, base.PackFloatToBytes(0.0)) } else if name == "shouldBe3Binary_heroism" { So(instF.Get(as, 0), ShouldResemble, base.PackFloatToBytes(0.0)) So(instF.Get(as, 1), ShouldResemble, base.PackFloatToBytes(1.0)) So(instF.Get(as, 2), ShouldResemble, base.PackFloatToBytes(0.0)) } else if name == "shouldBe3Binary_romanticism" { So(instF.Get(as, 0), ShouldResemble, base.PackFloatToBytes(0.0)) So(instF.Get(as, 1), ShouldResemble, base.PackFloatToBytes(0.0)) So(instF.Get(as, 2), ShouldResemble, base.PackFloatToBytes(1.0)) } else if name == "arbitraryClass" { } } }) }) }
func TestBinaryFilter(t *testing.T) { Convey("Given a contrived dataset...", t, func() { // Read the contrived dataset inst, err := base.ParseCSVToInstances("./binary_test.csv", true) So(err, ShouldEqual, nil) // Add Attributes to the filter bFilt := NewBinaryConvertFilter() bAttrs := base.NonClassAttributes(inst) for _, a := range bAttrs { bFilt.AddAttribute(a) } bFilt.Train() // Construct a LazilyFilteredInstances to handle it instF := base.NewLazilyFilteredInstances(inst, bFilt) Convey("All the non-class Attributes should be binary...", func() { // Check that all the Attributes are the right type for _, a := range base.NonClassAttributes(instF) { _, ok := a.(*base.BinaryAttribute) So(ok, ShouldEqual, true) } }) // Check that all the class Attributes made it Convey("All the class Attributes should have survived...", func() { origClassAttrs := inst.AllClassAttributes() newClassAttrs := instF.AllClassAttributes() intersectClassAttrs := base.AttributeIntersect(origClassAttrs, newClassAttrs) So(len(intersectClassAttrs), ShouldEqual, len(origClassAttrs)) }) // Check that the Attributes have the right names Convey("Attribute names should be correct...", func() { origNames := []string{"floatAttr", "shouldBe1Binary", "shouldBe3Binary_stoicism", "shouldBe3Binary_heroism", "shouldBe3Binary_romanticism", "arbitraryClass"} origMap := make(map[string]bool) for _, a := range origNames { origMap[a] = false } for _, a := range instF.AllAttributes() { name := a.GetName() _, ok := origMap[name] if !ok { t.Error(fmt.Sprintf("Weird: %s", name)) } origMap[name] = true } for a := range origMap { So(origMap[a], ShouldEqual, true) } }) // Check that the Attributes have been discretised correctly Convey("Discretisation should have worked", func() { // Build Attribute map attrMap := make(map[string]base.Attribute) for _, a := range instF.AllAttributes() { attrMap[a.GetName()] = a } // For each attribute for name := range attrMap { attr := attrMap[name] // Retrieve AttributeSpec as, err := instF.GetAttribute(attr) So(err, ShouldEqual, nil) if name == "floatAttr" { So(instF.Get(as, 0), ShouldResemble, []byte{1}) So(instF.Get(as, 1), ShouldResemble, []byte{1}) So(instF.Get(as, 2), ShouldResemble, []byte{0}) } else if name == "shouldBe1Binary" { So(instF.Get(as, 0), ShouldResemble, []byte{0}) So(instF.Get(as, 1), ShouldResemble, []byte{1}) So(instF.Get(as, 2), ShouldResemble, []byte{1}) } else if name == "shouldBe3Binary_stoicism" { So(instF.Get(as, 0), ShouldResemble, []byte{1}) So(instF.Get(as, 1), ShouldResemble, []byte{0}) So(instF.Get(as, 2), ShouldResemble, []byte{0}) } else if name == "shouldBe3Binary_heroism" { So(instF.Get(as, 0), ShouldResemble, []byte{0}) So(instF.Get(as, 1), ShouldResemble, []byte{1}) So(instF.Get(as, 2), ShouldResemble, []byte{0}) } else if name == "shouldBe3Binary_romanticism" { So(instF.Get(as, 0), ShouldResemble, []byte{0}) So(instF.Get(as, 1), ShouldResemble, []byte{0}) So(instF.Get(as, 2), ShouldResemble, []byte{1}) } else if name == "arbitraryClass" { } else { t.Error("Shouldn't have %s", name) } } }) }) }
func main() { var tree base.Classifier rand.Seed(time.Now().UTC().UnixNano()) // Load in the iris dataset iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true) if err != nil { panic(err) } // Discretise the iris dataset with Chi-Merge filt := filters.NewChiMergeFilter(iris, 0.99) for _, a := range base.NonClassFloatAttributes(iris) { filt.AddAttribute(a) } filt.Train() irisf := base.NewLazilyFilteredInstances(iris, filt) // Create a 60-40 training-test split trainData, testData := base.InstancesTrainTestSplit(irisf, 0.60) // // First up, use ID3 // tree = trees.NewID3DecisionTree(0.6) // (Parameter controls train-prune split.) // Train the ID3 tree tree.Fit(trainData) // Generate predictions predictions := tree.Predict(testData) // Evaluate fmt.Println("ID3 Performance") cf := eval.GetConfusionMatrix(testData, predictions) fmt.Println(eval.GetSummary(cf)) // // Next up, Random Trees // // Consider two randomly-chosen attributes tree = trees.NewRandomTree(2) tree.Fit(testData) predictions = tree.Predict(testData) fmt.Println("RandomTree Performance") cf = eval.GetConfusionMatrix(testData, predictions) fmt.Println(eval.GetSummary(cf)) // // Finally, Random Forests // tree = ensemble.NewRandomForest(100, 3) tree.Fit(trainData) predictions = tree.Predict(testData) fmt.Println("RandomForest Performance") cf = eval.GetConfusionMatrix(testData, predictions) fmt.Println(eval.GetSummary(cf)) }
func main() { var tree base.Classifier rand.Seed(44111342) // Load in the iris dataset iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true) if err != nil { panic(err) } // Discretise the iris dataset with Chi-Merge filt := filters.NewChiMergeFilter(iris, 0.999) for _, a := range base.NonClassFloatAttributes(iris) { filt.AddAttribute(a) } filt.Train() irisf := base.NewLazilyFilteredInstances(iris, filt) // Create a 60-40 training-test split trainData, testData := base.InstancesTrainTestSplit(irisf, 0.60) // // First up, use ID3 // tree = trees.NewID3DecisionTree(0.6) // (Parameter controls train-prune split.) // Train the ID3 tree err = tree.Fit(trainData) if err != nil { panic(err) } // Generate predictions predictions, err := tree.Predict(testData) if err != nil { panic(err) } // Evaluate fmt.Println("ID3 Performance (information gain)") cf, err := evaluation.GetConfusionMatrix(testData, predictions) if err != nil { panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error())) } fmt.Println(evaluation.GetSummary(cf)) tree = trees.NewID3DecisionTreeFromRule(0.6, new(trees.InformationGainRatioRuleGenerator)) // (Parameter controls train-prune split.) // Train the ID3 tree err = tree.Fit(trainData) if err != nil { panic(err) } // Generate predictions predictions, err = tree.Predict(testData) if err != nil { panic(err) } // Evaluate fmt.Println("ID3 Performance (information gain ratio)") cf, err = evaluation.GetConfusionMatrix(testData, predictions) if err != nil { panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error())) } fmt.Println(evaluation.GetSummary(cf)) tree = trees.NewID3DecisionTreeFromRule(0.6, new(trees.GiniCoefficientRuleGenerator)) // (Parameter controls train-prune split.) // Train the ID3 tree err = tree.Fit(trainData) if err != nil { panic(err) } // Generate predictions predictions, err = tree.Predict(testData) if err != nil { panic(err) } // Evaluate fmt.Println("ID3 Performance (gini index generator)") cf, err = evaluation.GetConfusionMatrix(testData, predictions) if err != nil { panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error())) } fmt.Println(evaluation.GetSummary(cf)) // // Next up, Random Trees // // Consider two randomly-chosen attributes tree = trees.NewRandomTree(2) err = tree.Fit(testData) if err != nil { panic(err) } predictions, err = tree.Predict(testData) if err != nil { panic(err) } fmt.Println("RandomTree Performance") cf, err = evaluation.GetConfusionMatrix(testData, predictions) if err != nil { panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error())) } fmt.Println(evaluation.GetSummary(cf)) // // Finally, Random Forests // tree = ensemble.NewRandomForest(70, 3) err = tree.Fit(trainData) if err != nil { panic(err) } predictions, err = tree.Predict(testData) if err != nil { panic(err) } fmt.Println("RandomForest Performance") cf, err = evaluation.GetConfusionMatrix(testData, predictions) if err != nil { panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error())) } fmt.Println(evaluation.GetSummary(cf)) }