コード例 #1
0
ファイル: knn.go プロジェクト: c4e8ece0/goodlearn
// Classify picks a target for testRow by collecting up to classifier.k
// nearest training rows and returning the result of their vote.
//
// It returns an error, and a nil slice, when:
//   - the classifier has no training data (untrained classifier error);
//   - testRow's feature count differs from the training data's
//     (row length mismatch error);
//   - testRow's features are not a slice.FloatSlice (non-float features
//     test row error).
func (classifier *kNNClassifier) Classify(testRow row.Row) (slice.Slice, error) {
	dataset := classifier.trainingData
	if dataset == nil {
		return nil, knnerrors.NewUntrainedClassifierError()
	}

	if got, want := testRow.NumFeatures(), dataset.NumFeatures(); got != want {
		return nil, knnerrors.NewRowLengthMismatchError(got, want)
	}

	query, isFloat := testRow.Features().(slice.FloatSlice)
	if !isFloat {
		return nil, knnerrors.NewNonFloatFeaturesTestRowError()
	}
	queryValues := query.Values()

	neighbours := knnutilities.NewKNNTargetCollection(classifier.k)

	for idx := 0; idx < dataset.NumRows(); idx++ {
		// The second results of Row and of the type assertion are discarded
		// on purpose here; presumably training already validated the data.
		// NOTE(review): confirm that the training step guarantees this.
		candidate, _ := dataset.Row(idx)
		candidateFeatures, _ := candidate.Features().(slice.FloatSlice)

		// The collection's current MaxDistance is handed to Euclidean as its
		// third argument (presumably an early-exit bound; verify in
		// knnutilities). Only strictly closer candidates are inserted.
		if distance := knnutilities.Euclidean(
			queryValues,
			candidateFeatures.Values(),
			neighbours.MaxDistance(),
		); distance < neighbours.MaxDistance() {
			neighbours.Insert(candidate.Target(), distance)
		}
	}

	return neighbours.Vote(), nil
}
コード例 #2
0
	"github.com/amitkgupta/goodlearn/data/slice"

	"math"

	. "github.com/onsi/ginkgo"
	. "github.com/onsi/gomega"
)

var _ = Describe("SortedTargetCollection", func() {
	Describe("Insert and MaxDistance", func() {
		var stc knnutilities.SortedTargetCollection
		var target slice.Slice
		var err error

		BeforeEach(func() {
			stc = knnutilities.NewKNNTargetCollection(2)
			target, err = slice.SliceFromRawValues(true, []int{}, []columntype.ColumnType{}, []float64{})
			Ω(err).ShouldNot(HaveOccurred())
		})

		Context("Before the collection is full", func() {
			It("The MaxDistance should be +Inf", func() {
				Ω(stc.MaxDistance()).Should(Equal(math.MaxFloat64))

				stc.Insert(target, 1.0)
				Ω(stc.MaxDistance()).Should(Equal(math.MaxFloat64))
			})
		})

		Context("When the collection is full", func() {
			initialMax := 3.0