func main() { flag.Parse() trainData := &pb.TrainingData{} if err := parseToProto(*trainDataPath, trainData); err != nil { glog.Fatal(err) } glog.Infof( "Loaded %v training examples, %v test examples", len(trainData.GetTrain()), len(trainData.GetTest())) config := &pb.ForestConfig{} if err := parseToProto(*configPath, config); err != nil { glog.Fatal(err) } glog.Infof("Loaded forest config %+v", config) generator, err := dt.NewForestGenerator(config) if err != nil { glog.Fatal(err) } forest := generator.ConstructForest(trainData.GetTrain()) learningCurve := dt.LearningCurve(forest, trainData.GetTest()) glog.Infof("Learning curve: %+v", learningCurve) serializedForest, err := json.MarshalIndent(forest, "", " ") if err != nil { glog.Fatal(err) } os.Stdout.Write(serializedForest) }
func (m *MongoTrainer) runTraining(task *trainingTask) error { dataSource, err := NewDataSource(task.row.GetDataSourceConfig(), m.Collection.Database.Session) if err != nil { return err } trainingData, err := dataSource.GetTrainingData() if err != nil { return err } generator, err := dt.NewForestGenerator(task.row.GetForestConfig()) if err != nil { return err } task.row.Forest = generator.ConstructForest(trainingData.GetTrain()) task.row.TrainingResults = dt.LearningCurve(task.row.Forest, trainingData.GetTest()) return nil }