예제 #1
0
파일: knn.go 프로젝트: c4e8ece0/goodlearn
// Classify predicts a target for testRow from the training rows closest to it
// (distance computed by knnutilities.Euclidean over the float feature values),
// returning the result of the neighbour collection's Vote().
//
// It returns an error if the classifier has not been trained, if testRow has a
// different number of features than the training data, if testRow's features
// are not a slice.FloatSlice, or if a training row cannot be read.
func (classifier *kNNClassifier) Classify(testRow row.Row) (slice.Slice, error) {
	trainingData := classifier.trainingData
	if trainingData == nil {
		return nil, knnerrors.NewUntrainedClassifierError()
	}

	numTestRowFeatures := testRow.NumFeatures()
	numTrainingDataFeatures := trainingData.NumFeatures()
	if numTestRowFeatures != numTrainingDataFeatures {
		return nil, knnerrors.NewRowLengthMismatchError(numTestRowFeatures, numTrainingDataFeatures)
	}

	testFeatures, ok := testRow.Features().(slice.FloatSlice)
	if !ok {
		return nil, knnerrors.NewNonFloatFeaturesTestRowError()
	}
	testFeatureValues := testFeatures.Values()

	nearestNeighbours := knnutilities.NewKNNTargetCollection(classifier.k)

	numRows := trainingData.NumRows()
	for i := 0; i < numRows; i++ {
		trainingRow, err := trainingData.Row(i)
		if err != nil {
			return nil, err
		}

		// NOTE(review): the comma-ok result is ignored on the assumption that
		// training features were validated as floats when the classifier was
		// trained — confirm against Train.
		trainingFeatures, _ := trainingRow.Features().(slice.FloatSlice)

		// The current worst kept distance doubles as Euclidean's bailout: once
		// the squared distance reaches it, Euclidean returns the bailout itself,
		// which then fails the strict comparison below and the row is skipped.
		maxDistance := nearestNeighbours.MaxDistance()
		distance := knnutilities.Euclidean(testFeatureValues, trainingFeatures.Values(), maxDistance)
		if distance < maxDistance {
			nearestNeighbours.Insert(trainingRow.Target(), distance)
		}
	}

	return nearestNeighbours.Vote(), nil
}
예제 #2
0
// GradientDescent searches for a minimum by repeatedly stepping from the
// current point against the gradient: next = current - learningRate*gradient(current).
//
// It stops and returns next as soon as the Euclidean distance between two
// successive iterates is less than precision. Otherwise it returns the latest
// iterate after maxIterations steps (a copy of initialGuess when
// maxIterations <= 0). The caller's initialGuess slice is never modified or
// returned directly.
//
// It returns an error if initialGuess is empty, or the first error reported
// by the gradient function.
func GradientDescent(
	initialGuess []float64,
	learningRate, precision float64,
	maxIterations int,
	gradient func([]float64) ([]float64, error),
) ([]float64, error) {
	if len(initialGuess) == 0 {
		return nil, errors.New("initialGuess cannot be empty")
	}

	current := make([]float64, len(initialGuess))
	copy(current, initialGuess)

	// knnutilities.Euclidean works with *squared* distances: it returns the
	// squared distance, or the bailout once the squared distance reaches it.
	// The bailout and the convergence threshold must therefore both be
	// precision squared. Using plain precision as the bailout falsely reports
	// convergence whenever precision > 1.
	squaredPrecision := precision * precision

	for i := 0; i < maxIterations; i++ {
		gradientAtCurrent, err := gradient(current)
		if err != nil {
			return nil, err
		}

		next := vectorutilities.Add(current, vectorutilities.Scale(-learningRate, gradientAtCurrent))

		if knnutilities.Euclidean(next, current, squaredPrecision) < squaredPrecision {
			return next, nil
		}
		current = next
	}

	return current, nil
}
예제 #3
0
package knnutilities_test

import (
	"github.com/amitkgupta/goodlearn/classifier/knn/knnutilities"

	. "github.com/onsi/ginkgo"
	. "github.com/onsi/gomega"
)

// Specs for knnutilities.Euclidean. It yields the squared Euclidean distance
// between two vectors, unless that squared distance reaches the supplied
// bailout, in which case it yields the bailout itself.
var _ = Describe("Euclidean", func() {
	var (
		x = []float64{1, 2, -3.14}
		y = []float64{-5, 6, 2.718}

		// Component differences are 6, -4 and -5.858:
		// 36 + 16 + 34.316164 = 86.316164.
		squareEuclideanDistance = 86.316164
	)

	Context("When the square of the Euclidean distance is less than the bailout", func() {
		bailout := 90.0

		It("Returns the square distance", func() {
			actual := knnutilities.Euclidean(x, y, bailout)
			Expect(actual).To(BeNumerically("~", squareEuclideanDistance, 0.001))
		})
	})

	Context("When the square of the Euclidean distance is greater than or equal to the bailout", func() {
		bailout := 80.0

		It("Returns the bailout", func() {
			actual := knnutilities.Euclidean(x, y, bailout)
			Expect(actual).To(Equal(bailout))
		})
	})
})