// NewRescaledFastForestEvaluator returns an evalator for a tree // that automatically corrects for various scaling factors required // for a given evaluation func NewRescaledFastForestEvaluator(f *pb.Forest) (Evaluator, error) { e := &fastForestEvaluator{ trees: make([]Evaluator, 0, len(f.GetTrees())), } for _, t := range f.GetTrees() { evaluator, err := newFastTreeEvaluator(t) if err != nil { return nil, err } e.trees = append(e.trees, evaluator) } switch f.GetRescaling() { case pb.Rescaling_NONE: return e, nil case pb.Rescaling_AVERAGING: return EvaluatorFunc(func(features []float64) float64 { return e.Evaluate(features) / float64(len(e.trees)) }), nil case pb.Rescaling_LOG_ODDS: return EvaluatorFunc(func(features []float64) float64 { return 1.0 / (1.0 + math.Exp(-2.0*e.Evaluate(features))) }), nil } return nil, fmt.Errorf("unknown rescaling method: %v", f.GetRescaling) }
// LearningCurve computes the progressive learning curve after each epoch on the // given examples func LearningCurve(f *pb.Forest, e Examples) *pb.TrainingResults { tr := &pb.TrainingResults{ EpochResults: make([]*pb.EpochResult, 0, len(f.GetTrees())), } for i := range f.GetTrees() { evaluator, err := NewRescaledFastForestEvaluator(&pb.Forest{ Trees: f.GetTrees()[:i], Rescaling: f.GetRescaling().Enum(), }) if err != nil { glog.Fatal(err) } er := computeEpochResult(evaluator, e) tr.EpochResults = append(tr.EpochResults, &er) } return tr }