コード例 #1
0
ファイル: knn.go プロジェクト: c4e8ece0/goodlearn
// Classify predicts a target for testRow by voting among the nearest
// training rows, measured with knnutilities.Euclidean.
//
// It returns an error when the classifier has no training data, when testRow
// has a different number of features than the training data, when testRow's
// features are not a slice.FloatSlice, or when a training row cannot be read.
func (classifier *kNNClassifier) Classify(testRow row.Row) (slice.Slice, error) {
	trainingData := classifier.trainingData
	if trainingData == nil {
		return nil, knnerrors.NewUntrainedClassifierError()
	}

	numTestRowFeatures := testRow.NumFeatures()
	numTrainingDataFeatures := trainingData.NumFeatures()
	if numTestRowFeatures != numTrainingDataFeatures {
		return nil, knnerrors.NewRowLengthMismatchError(numTestRowFeatures, numTrainingDataFeatures)
	}

	testFeatures, ok := testRow.Features().(slice.FloatSlice)
	if !ok {
		return nil, knnerrors.NewNonFloatFeaturesTestRowError()
	}
	testFeatureValues := testFeatures.Values()

	// Collection of candidate neighbours, created with the classifier's k
	// (presumably it keeps at most k entries — confirm in knnutilities).
	nearestNeighbours := knnutilities.NewKNNTargetCollection(classifier.k)

	for i := 0; i < trainingData.NumRows(); i++ {
		trainingRow, err := trainingData.Row(i)
		if err != nil {
			// Previously ignored; a failed read would leave trainingRow unusable.
			return nil, err
		}

		// NOTE(review): assumes training data was validated to hold float
		// features, so this assertion always succeeds — confirm against Train.
		trainingFeatures, _ := trainingRow.Features().(slice.FloatSlice)
		trainingFeatureValues := trainingFeatures.Values()

		// A row is admitted only if it is strictly closer than the current
		// maximum distance in the collection. The same bound is handed to
		// Euclidean, presumably so it can stop summing early once exceeded.
		maxDistance := nearestNeighbours.MaxDistance()
		distance := knnutilities.Euclidean(testFeatureValues, trainingFeatureValues, maxDistance)
		if distance < maxDistance {
			nearestNeighbours.Insert(trainingRow.Target(), distance)
		}
	}

	return nearestNeighbours.Vote(), nil
}
コード例 #2
0
ファイル: linear.go プロジェクト: c4e8ece0/goodlearn
// Predict evaluates the fitted linear model on testRow. The result is the
// last stored coefficient (the intercept) plus the sum of each remaining
// coefficient multiplied by the matching float feature of testRow.
//
// It returns an error when the regressor has no coefficients, when the number
// of coefficients is not exactly one more than testRow's feature count, or
// when testRow's features are not a slice.FloatSlice.
func (regressor *linearRegressor) Predict(testRow row.Row) (float64, error) {
	weights := regressor.coefficients
	if weights == nil {
		return 0, linearerrors.NewUntrainedRegressorError()
	}

	featureCount, weightCount := testRow.NumFeatures(), len(weights)
	// One weight per feature, plus a trailing intercept term.
	if featureCount+1 != weightCount {
		return 0, linearerrors.NewRowLengthMismatchError(featureCount, weightCount)
	}

	floatFeatures, isFloat := testRow.Features().(slice.FloatSlice)
	if !isFloat {
		return 0, linearerrors.NewNonFloatFeaturesTestRowError()
	}
	x := floatFeatures.Values()

	interceptIdx := weightCount - 1
	prediction := weights[interceptIdx]
	for j := 0; j < interceptIdx; j++ {
		prediction += weights[j] * x[j]
	}
	return prediction, nil
}
コード例 #3
0
ファイル: dataset_test.go プロジェクト: c4e8ece0/goodlearn
			Context("When getting a row", func() {
				It("Returns an error", func() {
					// Row must reject every probed index here, including 0;
					// presumably ds has no rows in this context — confirm in
					// the enclosing BeforeEach.
					for _, index := range []int{0, -1, 1} {
						_, err := ds.Row(index)
						Ω(err).Should(HaveOccurred())
					}
				})
			})

			Context("When adding a row", func() {
				var newRow row.Row
				var err error

				Context("When the row's length is incorrect", func() {
					BeforeEach(func() {
						err = ds.AddRowFromStrings([]string{"0.0", "hi", "mom", "94"})
					})

					It("Returns an error", func() {
						Ω(err).Should(HaveOccurred())
					})

					It("Has 0 rows", func() {
						Ω(ds.NumRows()).To(BeZero())
					})