func (classifier *kNNClassifier) Classify(testRow row.Row) (slice.Slice, error) { trainingData := classifier.trainingData if trainingData == nil { return nil, knnerrors.NewUntrainedClassifierError() } numTestRowFeatures := testRow.NumFeatures() numTrainingDataFeatures := trainingData.NumFeatures() if numTestRowFeatures != numTrainingDataFeatures { return nil, knnerrors.NewRowLengthMismatchError(numTestRowFeatures, numTrainingDataFeatures) } testFeatures, ok := testRow.Features().(slice.FloatSlice) if !ok { return nil, knnerrors.NewNonFloatFeaturesTestRowError() } testFeatureValues := testFeatures.Values() nearestNeighbours := knnutilities.NewKNNTargetCollection(classifier.k) for i := 0; i < trainingData.NumRows(); i++ { trainingRow, _ := trainingData.Row(i) trainingFeatures, _ := trainingRow.Features().(slice.FloatSlice) trainingFeatureValues := trainingFeatures.Values() distance := knnutilities.Euclidean(testFeatureValues, trainingFeatureValues, nearestNeighbours.MaxDistance()) if distance < nearestNeighbours.MaxDistance() { nearestNeighbours.Insert(trainingRow.Target(), distance) } } return nearestNeighbours.Vote(), nil }
func GradientDescent( initialGuess []float64, learningRate, precision float64, maxIterations int, gradient func([]float64) ([]float64, error), ) ([]float64, error) { if len(initialGuess) == 0 { return nil, errors.New("initialGuess cannot be empty") } oldResult := make([]float64, len(initialGuess)) newResult := make([]float64, len(initialGuess)) copy(oldResult, initialGuess) for i := 0; i < maxIterations; i++ { gradientAtOldResult, err := gradient(oldResult) if err != nil { return nil, err } newResult = vectorutilities.Add(oldResult, vectorutilities.Scale(-learningRate, gradientAtOldResult)) if (knnutilities.Euclidean(newResult, oldResult, precision)) < precision*precision { return newResult, nil } else { oldResult = newResult } } return newResult, nil }
package knnutilities_test import ( "github.com/amitkgupta/goodlearn/classifier/knn/knnutilities" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) var _ = Describe("Euclidean", func() { x := []float64{1, 2, -3.14} y := []float64{-5, 6, 2.718} squareEuclideanDistance := 86.316164 Context("When the square of the Euclidean distance is less than the bailout", func() { var bailout float64 = 90 It("Returns the square distance", func() { Ω(knnutilities.Euclidean(x, y, bailout)).Should(BeNumerically("~", squareEuclideanDistance, 0.001)) }) }) Context("When the square of the Euclidean distance is greater than or equal to the bailout", func() { var bailout float64 = 80 It("Returns the bailout", func() { Ω(knnutilities.Euclidean(x, y, bailout)).Should(Equal(bailout)) }) }) })