func main() { var tree base.Classifier rand.Seed(time.Now().UTC().UnixNano()) // Load in the iris dataset iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true) if err != nil { panic(err) } // Discretise the iris dataset with Chi-Merge filt := filters.NewChiMergeFilter(iris, 0.99) filt.AddAllNumericAttributes() filt.Build() filt.Run(iris) // Create a 60-40 training-test split insts := base.InstancesTrainTestSplit(iris, 0.60) // // First up, use ID3 // tree = trees.NewID3DecisionTree(0.6) // (Parameter controls train-prune split.) // Train the ID3 tree tree.Fit(insts[0]) // Generate predictions predictions := tree.Predict(insts[1]) // Evaluate fmt.Println("ID3 Performance") cf := eval.GetConfusionMatrix(insts[1], predictions) fmt.Println(eval.GetSummary(cf)) // // Next up, Random Trees // // Consider two randomly-chosen attributes tree = trees.NewRandomTree(2) tree.Fit(insts[0]) predictions = tree.Predict(insts[1]) fmt.Println("RandomTree Performance") cf = eval.GetConfusionMatrix(insts[1], predictions) fmt.Println(eval.GetSummary(cf)) // // Finally, Random Forests // tree = ensemble.NewRandomForest(100, 3) tree.Fit(insts[0]) predictions = tree.Predict(insts[1]) fmt.Println("RandomForest Performance") cf = eval.GetConfusionMatrix(insts[1], predictions) fmt.Println(eval.GetSummary(cf)) }
// Fit builds the RandomForest on the specified instances func (f *RandomForest) Fit(on base.FixedDataGrid) { f.Model = new(meta.BaggedModel) f.Model.RandomFeatures = f.Features for i := 0; i < f.ForestSize; i++ { tree := trees.NewID3DecisionTree(0.00) f.Model.AddModel(tree) } f.Model.Fit(on) }
// Fit builds the RandomForest on the specified instances func (f *RandomForest) Fit(on base.FixedDataGrid) error { numNonClassAttributes := len(base.NonClassAttributes(on)) if numNonClassAttributes < f.Features { return errors.New(fmt.Sprintf( "Random forest with %d features cannot fit data grid with %d non-class attributes", f.Features, numNonClassAttributes, )) } f.Model = new(meta.BaggedModel) f.Model.RandomFeatures = f.Features for i := 0; i < f.ForestSize; i++ { tree := trees.NewID3DecisionTree(0.00) f.Model.AddModel(tree) } f.Model.Fit(on) return nil }
func main() { var tree base.Classifier rand.Seed(44111342) // Load in the iris dataset iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true) if err != nil { panic(err) } // Discretise the iris dataset with Chi-Merge filt := filters.NewChiMergeFilter(iris, 0.999) for _, a := range base.NonClassFloatAttributes(iris) { filt.AddAttribute(a) } filt.Train() irisf := base.NewLazilyFilteredInstances(iris, filt) // Create a 60-40 training-test split trainData, testData := base.InstancesTrainTestSplit(irisf, 0.60) // // First up, use ID3 // tree = trees.NewID3DecisionTree(0.6) // (Parameter controls train-prune split.) // Train the ID3 tree err = tree.Fit(trainData) if err != nil { panic(err) } // Generate predictions predictions, err := tree.Predict(testData) if err != nil { panic(err) } // Evaluate fmt.Println("ID3 Performance (information gain)") cf, err := evaluation.GetConfusionMatrix(testData, predictions) if err != nil { panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error())) } fmt.Println(evaluation.GetSummary(cf)) tree = trees.NewID3DecisionTreeFromRule(0.6, new(trees.InformationGainRatioRuleGenerator)) // (Parameter controls train-prune split.) // Train the ID3 tree err = tree.Fit(trainData) if err != nil { panic(err) } // Generate predictions predictions, err = tree.Predict(testData) if err != nil { panic(err) } // Evaluate fmt.Println("ID3 Performance (information gain ratio)") cf, err = evaluation.GetConfusionMatrix(testData, predictions) if err != nil { panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error())) } fmt.Println(evaluation.GetSummary(cf)) tree = trees.NewID3DecisionTreeFromRule(0.6, new(trees.GiniCoefficientRuleGenerator)) // (Parameter controls train-prune split.) // Train the ID3 tree err = tree.Fit(trainData) if err != nil { panic(err) } // Generate predictions predictions, err = tree.Predict(testData) if err != nil { panic(err) } // Evaluate fmt.Println("ID3 Performance (gini index generator)") cf, err = evaluation.GetConfusionMatrix(testData, predictions) if err != nil { panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error())) } fmt.Println(evaluation.GetSummary(cf)) // // Next up, Random Trees // // Consider two randomly-chosen attributes tree = trees.NewRandomTree(2) err = tree.Fit(testData) if err != nil { panic(err) } predictions, err = tree.Predict(testData) if err != nil { panic(err) } fmt.Println("RandomTree Performance") cf, err = evaluation.GetConfusionMatrix(testData, predictions) if err != nil { panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error())) } fmt.Println(evaluation.GetSummary(cf)) // // Finally, Random Forests // tree = ensemble.NewRandomForest(70, 3) err = tree.Fit(trainData) if err != nil { panic(err) } predictions, err = tree.Predict(testData) if err != nil { panic(err) } fmt.Println("RandomForest Performance") cf, err = evaluation.GetConfusionMatrix(testData, predictions) if err != nil { panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error())) } fmt.Println(evaluation.GetSummary(cf)) }