Example #1
0
// Evaluate runs m over every instance of set and counts how often each
// (annotated label, predicted label) pair occurs.
//
// Each count is stored in the returned Metrics map under the name
// "confusion:M/N", where M is the instance's annotated label and N is the
// label predicted by m.
func (e *ConfusionMatrixEvaluator) Evaluate(m supervised.Model, set data.Dataset) (result Evaluation) {
	counts := make(map[string]float64)

	it := set.CreateIterator()
	for it.Start(); !it.End(); it.Next() {
		sample := it.GetInstance()
		predicted := m.Predict(sample)

		// One map entry per distinct actual/predicted combination.
		key := fmt.Sprintf("confusion:%d/%d", sample.Output.Label, predicted.Label)
		counts[key] += 1
	}

	result.Metrics = counts
	return result
}
Example #2
0
// Evaluate computes precision/recall statistics of m over set for a binary
// classification problem. Label 0 is treated as the negative class and any
// other label as the positive class.
//
// The returned Metrics map contains:
//   - "tp", "fp", "tn", "fn": the raw confusion counts
//   - "precision": tp / (tp + fp)
//   - "recall":    tp / (tp + fn)
//   - "fscore":    2 * precision * recall / (precision + recall)
//
// A ratio whose denominator is zero (for example an empty dataset, or a model
// that never predicts the positive class) is reported as 0 instead of NaN.
//
// The process is terminated through log.Fatal when an instance carries a
// label greater than 1, since the evaluator is only meaningful for two classes.
func (e *PREvaluator) Evaluate(m supervised.Model, set data.Dataset) (result Evaluation) {
	var (
		tp int // true-positive
		tn int // true-negative
		fp int // false-positive
		fn int // false-negative
	)

	iter := set.CreateIterator()
	for iter.Start(); !iter.End(); iter.Next() {
		instance := iter.GetInstance()
		// Binary labels are 0 and 1; the previous "> 2" check let label 2 through.
		if instance.Output.Label > 1 {
			log.Fatal("调用PREvaluator但不是二分类问题")
		}

		predictedPositive := m.Predict(instance).Label != 0
		actualPositive := instance.Output.Label != 0
		switch {
		case predictedPositive && actualPositive:
			tp++
		case predictedPositive:
			fp++
		case actualPositive:
			fn++
		default:
			tn++
		}
	}

	// ratio avoids the NaN that floating-point 0/0 would otherwise produce.
	ratio := func(num, den float64) float64 {
		if den == 0 {
			return 0
		}
		return num / den
	}

	precision := ratio(float64(tp), float64(tp+fp))
	recall := ratio(float64(tp), float64(tp+fn))

	result.Metrics = make(map[string]float64)
	result.Metrics["precision"] = precision
	result.Metrics["recall"] = recall
	result.Metrics["tp"] = float64(tp)
	result.Metrics["fp"] = float64(fp)
	result.Metrics["tn"] = float64(tn)
	result.Metrics["fn"] = float64(fn)
	result.Metrics["fscore"] = ratio(2*precision*recall, precision+recall)

	return result
}
Example #3
0
// Evaluate runs m over every instance of set and reports the fraction of
// instances whose predicted label equals the annotated label.
//
// The returned Metrics map has a single entry, "accuracy", in the range
// [0, 1]. When set yields no instances the accuracy is reported as 0 rather
// than the NaN that 0/0 would produce.
func (e *AccuracyEvaluator) Evaluate(m supervised.Model, set data.Dataset) (result Evaluation) {
	correct, total := 0, 0

	iter := set.CreateIterator()
	for iter.Start(); !iter.End(); iter.Next() {
		instance := iter.GetInstance()
		if m.Predict(instance).Label == instance.Output.Label {
			correct++
		}
		total++
	}

	// Guard the division: an empty dataset must not yield NaN.
	accuracy := 0.0
	if total > 0 {
		accuracy = float64(correct) / float64(total)
	}

	result.Metrics = make(map[string]float64)
	result.Metrics["accuracy"] = accuracy

	return result
}