func main() {
	flag.Parse()
	trainData := &pb.TrainingData{}
	if err := parseToProto(*trainDataPath, trainData); err != nil {
		glog.Fatal(err)
	}

	glog.Infof(
		"Loaded %v training examples, %v test examples",
		len(trainData.GetTrain()),
		len(trainData.GetTest()))

	config := &pb.ForestConfig{}
	if err := parseToProto(*configPath, config); err != nil {
		glog.Fatal(err)
	}
	glog.Infof("Loaded forest config %+v", config)

	generator, err := dt.NewForestGenerator(config)
	if err != nil {
		glog.Fatal(err)
	}
	forest := generator.ConstructForest(trainData.GetTrain())
	learningCurve := dt.LearningCurve(forest, trainData.GetTest())

	glog.Infof("Learning curve: %+v", learningCurve)

	serializedForest, err := json.MarshalIndent(forest, "", "  ")
	if err != nil {
		glog.Fatal(err)
	}

	os.Stdout.Write(serializedForest)
}
Esempio n. 2
0
// runTraining executes one training task end to end.
//
// It opens a data source from the task row's data-source config and the
// session at m.Collection.Database.Session, fetches the training data, builds
// a forest generator from the row's forest config, and then stores the
// constructed forest and its learning curve (computed on the test examples)
// back on task.row.
//
// The first error from opening the data source, fetching the data, or creating
// the generator is returned unchanged; in that case task.row is not modified.
func (m *MongoTrainer) runTraining(task *trainingTask) error {
	sourceConfig := task.row.GetDataSourceConfig()
	source, err := NewDataSource(sourceConfig, m.Collection.Database.Session)
	if err != nil {
		return err
	}

	data, err := source.GetTrainingData()
	if err != nil {
		return err
	}

	forestGen, err := dt.NewForestGenerator(task.row.GetForestConfig())
	if err != nil {
		return err
	}

	// Write results through task.row directly so they land on the task itself.
	trained := forestGen.ConstructForest(data.GetTrain())
	task.row.Forest = trained
	task.row.TrainingResults = dt.LearningCurve(trained, data.GetTest())
	return nil
}