예제 #1
0
파일: main.go 프로젝트: unixpickle/weakai
// main trains a forest of ID3 decision trees on the training samples and
// logs the correctness histogram obtained by classifying the MNIST testing
// data set with that forest.
func main() {
	log.Println("Creating training samples...")
	samples := trainingSamples()
	attrs := trainingAttrs()

	log.Println("Training forest...")
	// NOTE(review): the meaning of the literal 75 is defined by
	// idtrees.BuildForest (not visible here) — confirm and consider naming it.
	forest := idtrees.BuildForest(ForestSize, samples, attrs, TrainingSize, 75,
		func(s []idtrees.Sample, a []idtrees.Attr) *idtrees.Tree {
			return idtrees.ID3(s, a, 0)
		})

	log.Println("Running classifications...")
	hist := mnist.LoadTestingDataSet().CorrectnessHistogram(func(data []float64) int {
		sample := newImageSample(mnist.Sample{Intensities: data})
		res := forest.Classify(sample)

		// Pick the class with the highest score. Map iteration order is
		// unspecified, so equal scores are broken toward the smallest class
		// label; otherwise ties would be resolved differently on each run.
		maxClass := 0
		found := false
		var maxVal float64
		for class, x := range res {
			label := class.(int)
			if !found || x > maxVal || (x == maxVal && label < maxClass) {
				found = true
				maxVal = x
				maxClass = label
			}
		}
		return maxClass
	})
	log.Println("Results:", hist)
}
예제 #2
0
파일: main.go 프로젝트: unixpickle/weakai
// main reads the CSV file named by the single command-line argument, builds
// an ID3 decision tree from its samples, and prints the tree to stdout.
//
// Usage: idtrees <data.csv>
//
// On a usage, open, or parse error it writes a message to stderr and exits
// with status 1.
func main() {
	if len(os.Args) != 2 {
		fmt.Fprintln(os.Stderr, "Usage: idtrees <data.csv>")
		fmt.Fprintln(os.Stderr, "")
		fmt.Fprintln(os.Stderr, "  The first row of the input CSV file specifies field names.")
		fmt.Fprintln(os.Stderr, "  Fields with names starting with _ are ignored.")
		fmt.Fprintln(os.Stderr, "  The field whose name begins with * is identified by the tree.")
		fmt.Fprintln(os.Stderr, "")
		os.Exit(1)
	}

	log.Println("Reading CSV file...")

	f, err := os.Open(os.Args[1])
	if err != nil {
		fmt.Fprintln(os.Stderr, "Error opening file:", err)
		os.Exit(1)
	}

	// Close explicitly instead of deferring: os.Exit below would skip a
	// deferred Close, and the file is not needed once it has been parsed.
	samples, keys, err := ReadCSV(f)
	f.Close() // opened read-only, so a Close error cannot lose data
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}

	log.Println("Generating tree...")
	tree := idtrees.ID3(samples, keys, 0)
	log.Println("Printing tree...")

	fmt.Println(tree)
}