func (classifier *kNNClassifier) Classify(testRow row.Row) (slice.Slice, error) { trainingData := classifier.trainingData if trainingData == nil { return nil, knnerrors.NewUntrainedClassifierError() } numTestRowFeatures := testRow.NumFeatures() numTrainingDataFeatures := trainingData.NumFeatures() if numTestRowFeatures != numTrainingDataFeatures { return nil, knnerrors.NewRowLengthMismatchError(numTestRowFeatures, numTrainingDataFeatures) } testFeatures, ok := testRow.Features().(slice.FloatSlice) if !ok { return nil, knnerrors.NewNonFloatFeaturesTestRowError() } testFeatureValues := testFeatures.Values() nearestNeighbours := knnutilities.NewKNNTargetCollection(classifier.k) for i := 0; i < trainingData.NumRows(); i++ { trainingRow, _ := trainingData.Row(i) trainingFeatures, _ := trainingRow.Features().(slice.FloatSlice) trainingFeatureValues := trainingFeatures.Values() distance := knnutilities.Euclidean(testFeatureValues, trainingFeatureValues, nearestNeighbours.MaxDistance()) if distance < nearestNeighbours.MaxDistance() { nearestNeighbours.Insert(trainingRow.Target(), distance) } } return nearestNeighbours.Vote(), nil }
func (regressor *linearRegressor) Predict(testRow row.Row) (float64, error) { coefficients := regressor.coefficients if coefficients == nil { return 0, linearerrors.NewUntrainedRegressorError() } numTestRowFeatures := testRow.NumFeatures() numCoefficients := len(coefficients) if numCoefficients != numTestRowFeatures+1 { return 0, linearerrors.NewRowLengthMismatchError(numTestRowFeatures, numCoefficients) } testFeatures, ok := testRow.Features().(slice.FloatSlice) if !ok { return 0, linearerrors.NewNonFloatFeaturesTestRowError() } testFeatureValues := testFeatures.Values() result := coefficients[numCoefficients-1] for i, c := range coefficients[:numCoefficients-1] { result = result + c*testFeatureValues[i] } return result, nil }
}) It("Has 1 row", func() { Ω(ds.NumRows()).To(Equal(1)) }) Context("When getting a row", func() { Context("When the index is correct", func() { It("Consistently returns the correct row", func() { newRow, err = ds.Row(0) Ω(err).ShouldNot(HaveOccurred()) newRowAgain, err := ds.Row(0) Ω(err).ShouldNot(HaveOccurred()) Ω(newRow.Features().Equals(newRowAgain.Features())).Should(BeTrue()) Ω(newRow.Target().Equals(newRowAgain.Target())).Should(BeTrue()) }) }) Context("When the index is correct", func() { It("Returns an error", func() { _, err = ds.Row(-1) Ω(err).Should(HaveOccurred()) newRow, err = ds.Row(1) Ω(err).Should(HaveOccurred()) }) }) }) })