func main() { log.Println("Creating training samples...") samples := trainingSamples() attrs := trainingAttrs() log.Println("Training forest...") forest := idtrees.BuildForest(ForestSize, samples, attrs, TrainingSize, 75, func(s []idtrees.Sample, a []idtrees.Attr) *idtrees.Tree { return idtrees.ID3(s, a, 0) }) log.Println("Running classifications...") hist := mnist.LoadTestingDataSet().CorrectnessHistogram(func(data []float64) int { sample := newImageSample(mnist.Sample{Intensities: data}) res := forest.Classify(sample) var maxVal float64 var maxClass int for class, x := range res { if x > maxVal { maxVal = x maxClass = class.(int) } } return maxClass }) log.Println("Results:", hist) }
func main() { if len(os.Args) != 2 { fmt.Fprintln(os.Stderr, "Usage: idtrees <data.csv>") fmt.Fprintln(os.Stderr, "") fmt.Fprintln(os.Stderr, " The first row of the input CSV file specifies field names.") fmt.Fprintln(os.Stderr, " Fields with names starting with _ are ignored.") fmt.Fprintln(os.Stderr, " The field whose name begins with * is identified by the tree.") fmt.Fprintln(os.Stderr, "") os.Exit(1) } log.Println("Reading CSV file...") f, err := os.Open(os.Args[1]) if err != nil { fmt.Fprintln(os.Stderr, "Error opening file:", err) os.Exit(1) } defer f.Close() samples, keys, err := ReadCSV(f) if err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } log.Println("Generating tree...") tree := idtrees.ID3(samples, keys, 0) log.Println("Printing tree...") fmt.Println(tree) }