func (classifier *kNNClassifier) Classify(testRow row.Row) (slice.Slice, error) { trainingData := classifier.trainingData if trainingData == nil { return nil, knnerrors.NewUntrainedClassifierError() } numTestRowFeatures := testRow.NumFeatures() numTrainingDataFeatures := trainingData.NumFeatures() if numTestRowFeatures != numTrainingDataFeatures { return nil, knnerrors.NewRowLengthMismatchError(numTestRowFeatures, numTrainingDataFeatures) } testFeatures, ok := testRow.Features().(slice.FloatSlice) if !ok { return nil, knnerrors.NewNonFloatFeaturesTestRowError() } testFeatureValues := testFeatures.Values() nearestNeighbours := knnutilities.NewKNNTargetCollection(classifier.k) for i := 0; i < trainingData.NumRows(); i++ { trainingRow, _ := trainingData.Row(i) trainingFeatures, _ := trainingRow.Features().(slice.FloatSlice) trainingFeatureValues := trainingFeatures.Values() distance := knnutilities.Euclidean(testFeatureValues, trainingFeatureValues, nearestNeighbours.MaxDistance()) if distance < nearestNeighbours.MaxDistance() { nearestNeighbours.Insert(trainingRow.Target(), distance) } } return nearestNeighbours.Vote(), nil }
func (regressor *linearRegressor) Predict(testRow row.Row) (float64, error) { coefficients := regressor.coefficients if coefficients == nil { return 0, linearerrors.NewUntrainedRegressorError() } numTestRowFeatures := testRow.NumFeatures() numCoefficients := len(coefficients) if numCoefficients != numTestRowFeatures+1 { return 0, linearerrors.NewRowLengthMismatchError(numTestRowFeatures, numCoefficients) } testFeatures, ok := testRow.Features().(slice.FloatSlice) if !ok { return 0, linearerrors.NewNonFloatFeaturesTestRowError() } testFeatureValues := testFeatures.Values() result := coefficients[numCoefficients-1] for i, c := range coefficients[:numCoefficients-1] { result = result + c*testFeatureValues[i] } return result, nil }
Context("When getting a row", func() { It("Returns an error", func() { _, err := ds.Row(0) Ω(err).Should(HaveOccurred()) _, err = ds.Row(-1) Ω(err).Should(HaveOccurred()) _, err = ds.Row(1) Ω(err).Should(HaveOccurred()) }) }) Context("When adding a row", func() { var newRow row.Row var err error Context("When the row's length is incorrect", func() { BeforeEach(func() { err = ds.AddRowFromStrings([]string{"0.0", "hi", "mom", "94"}) }) It("Returns an error", func() { Ω(err).Should(HaveOccurred()) }) It("Has 0 rows", func() { Ω(ds.NumRows()).To(BeZero()) })