Beispiel #1
0
// NewRescaledFastForestEvaluator returns an evaluator for the forest f that
// applies the rescaling declared by f.GetRescaling() to the output of the
// underlying forest evaluator:
//
//   - Rescaling_NONE: the forest output is returned unchanged.
//   - Rescaling_AVERAGING: the forest output is divided by the number of trees.
//     For a forest without trees this is a division by zero, so the result is
//     not a finite number.
//   - Rescaling_LOG_ODDS: the forest output x is mapped to 1 / (1 + exp(-2x)).
//
// An error is returned if newFastTreeEvaluator fails for any tree of f, or if
// the rescaling method is not one of the values listed above.
func NewRescaledFastForestEvaluator(f *pb.Forest) (Evaluator, error) {
	trees := f.GetTrees()
	e := &fastForestEvaluator{
		trees: make([]Evaluator, 0, len(trees)),
	}

	for _, t := range trees {
		evaluator, err := newFastTreeEvaluator(t)
		if err != nil {
			return nil, err
		}
		e.trees = append(e.trees, evaluator)
	}

	switch rescaling := f.GetRescaling(); rescaling {
	case pb.Rescaling_NONE:
		return e, nil
	case pb.Rescaling_AVERAGING:
		return EvaluatorFunc(func(features []float64) float64 {
			return e.Evaluate(features) / float64(len(e.trees))
		}), nil
	case pb.Rescaling_LOG_ODDS:
		return EvaluatorFunc(func(features []float64) float64 {
			return 1.0 / (1.0 + math.Exp(-2.0*e.Evaluate(features)))
		}), nil
	default:
		// Format the enum value itself; passing the method value
		// f.GetRescaling would print a function address instead.
		return nil, fmt.Errorf("unknown rescaling method: %v", rescaling)
	}
}
// LearningCurve computes the progressive learning curve of the forest f on
// the examples e. One epoch result is produced per tree of f: the k-th result
// (counting from one) is computed by computeEpochResult using an evaluator
// built from the first k trees of f together with f's rescaling method.
//
// If an evaluator cannot be built for one of the prefixes, the error is
// reported through glog.Fatal.
func LearningCurve(f *pb.Forest, e Examples) *pb.TrainingResults {
	trees := f.GetTrees()
	tr := &pb.TrainingResults{
		EpochResults: make([]*pb.EpochResult, 0, len(trees)),
	}

	for i := range trees {
		// The prefix must include tree i: slicing with [:i] would evaluate an
		// empty forest for the first result and never evaluate the full one.
		evaluator, err := NewRescaledFastForestEvaluator(&pb.Forest{
			Trees:     trees[:i+1],
			Rescaling: f.GetRescaling().Enum(),
		})
		if err != nil {
			glog.Fatal(err)
		}
		er := computeEpochResult(evaluator, e)
		tr.EpochResults = append(tr.EpochResults, &er)
	}
	return tr
}