Example #1
0
// Classify predicts the target for testRow by collecting the classifier.k
// nearest training rows (by Euclidean distance over float feature values)
// and returning the result of their vote.
//
// It returns a nil slice and a non-nil error when:
//   - the classifier has no training data (NewUntrainedClassifierError),
//   - testRow's feature count differs from the training data's
//     (NewRowLengthMismatchError),
//   - testRow's features are not a slice.FloatSlice
//     (NewNonFloatFeaturesTestRowError),
//   - a training row cannot be fetched (the error from trainingData.Row is
//     returned unchanged).
func (classifier *kNNClassifier) Classify(testRow row.Row) (slice.Slice, error) {
	trainingData := classifier.trainingData
	if trainingData == nil {
		return nil, knnerrors.NewUntrainedClassifierError()
	}

	numTestRowFeatures := testRow.NumFeatures()
	numTrainingDataFeatures := trainingData.NumFeatures()
	if numTestRowFeatures != numTrainingDataFeatures {
		return nil, knnerrors.NewRowLengthMismatchError(numTestRowFeatures, numTrainingDataFeatures)
	}

	testFeatures, ok := testRow.Features().(slice.FloatSlice)
	if !ok {
		return nil, knnerrors.NewNonFloatFeaturesTestRowError()
	}
	testFeatureValues := testFeatures.Values()

	nearestNeighbours := knnutilities.NewKNNTargetCollection(classifier.k)

	for i := 0; i < trainingData.NumRows(); i++ {
		// Surface a failed row lookup to the caller instead of carrying on
		// with whatever value accompanied the error.
		trainingRow, err := trainingData.Row(i)
		if err != nil {
			return nil, err
		}

		// NOTE(review): the assertion result is discarded; presumably training
		// already guarantees float features for every stored row — confirm.
		trainingFeatures, _ := trainingRow.Features().(slice.FloatSlice)
		trainingFeatureValues := trainingFeatures.Values()

		// The current cutoff is read once per iteration: it is passed to
		// Euclidean (presumably as an early-exit bound — confirm) and then
		// used to decide whether this row is admitted.
		maxDistance := nearestNeighbours.MaxDistance()
		distance := knnutilities.Euclidean(testFeatureValues, trainingFeatureValues, maxDistance)

		// Strict comparison: a row exactly at the cutoff is not inserted.
		if distance < maxDistance {
			nearestNeighbours.Insert(trainingRow.Target(), distance)
		}
	}

	return nearestNeighbours.Vote(), nil
}
Example #2
0
// Predict evaluates the trained linear model on testRow.
//
// The model must hold exactly one more coefficient than testRow has
// features: coefficient i multiplies feature i, and the final coefficient is
// added as a constant term. Errors are reported for an untrained regressor,
// for a coefficient/feature count mismatch, and for a test row whose
// features are not a slice.FloatSlice; in each case the returned value is 0.
func (regressor *linearRegressor) Predict(testRow row.Row) (float64, error) {
	weights := regressor.coefficients
	if weights == nil {
		return 0, linearerrors.NewUntrainedRegressorError()
	}

	featureCount := testRow.NumFeatures()
	if weightCount := len(weights); weightCount != featureCount+1 {
		return 0, linearerrors.NewRowLengthMismatchError(featureCount, weightCount)
	}

	floatFeatures, isFloat := testRow.Features().(slice.FloatSlice)
	if !isFloat {
		return 0, linearerrors.NewNonFloatFeaturesTestRowError()
	}
	values := floatFeatures.Values()

	// Start from the trailing constant term, then accumulate the weighted
	// features in index order. After the length check above, featureCount is
	// the index of the last coefficient.
	prediction := weights[featureCount]
	for i := 0; i < featureCount; i++ {
		prediction += weights[i] * values[i]
	}

	return prediction, nil
}
Example #3
0
					})

					// Checks that ds reports a row count of exactly one.
					It("Has 1 row", func() {
						rowCount := ds.NumRows()
						Ω(rowCount).Should(Equal(1))
					})

					// Row lookups: a valid index must yield the same row on every call,
					// and an out-of-range index must yield an error.
					Context("When getting a row", func() {
						Context("When the index is correct", func() {
							It("Consistently returns the correct row", func() {
								newRow, err = ds.Row(0)
								Ω(err).ShouldNot(HaveOccurred())

								newRowAgain, err := ds.Row(0)
								Ω(err).ShouldNot(HaveOccurred())

								// Two lookups of the same index must agree on both features and target.
								Ω(newRow.Features().Equals(newRowAgain.Features())).Should(BeTrue())
								Ω(newRow.Target().Equals(newRowAgain.Target())).Should(BeTrue())
							})
						})

						Context("When the index is incorrect", func() {
							It("Returns an error", func() {
								// A negative index is rejected.
								_, err = ds.Row(-1)
								Ω(err).Should(HaveOccurred())

								// ds is expected to hold a single row here, so index 1 is out of range.
								newRow, err = ds.Row(1)
								Ω(err).Should(HaveOccurred())
							})
						})
					})
				})