// SplitDataset randomly partitions ds into a training subset and a test
// subset, returned in that order.
//
// The row indices are first shuffled with a generator built on source; each
// shuffled row is then assigned independently, going to the training subset
// when a fresh uniform draw in [0, 1) is below trainingRatio and to the test
// subset otherwise. trainingRatio is therefore an expected proportion rather
// than an exact one, and either subset may turn out empty. All randomness is
// drawn from source, so the same source state yields the same split.
//
// An error is returned, with both subsets nil, when trainingRatio is not a
// number within [0, 1] (NaN included) or when ds has no rows.
func SplitDataset(ds dataset.Dataset, trainingRatio float64, source rand.Source) (dataset.Dataset, dataset.Dataset, error) {
	// Deliberately a negated range check rather than "< 0 || > 1": every
	// ordered comparison with NaN is false, so the latter form would accept
	// NaN and then silently send every row to the test subset.
	if !(trainingRatio >= 0 && trainingRatio <= 1) {
		return nil, nil, fmt.Errorf("Unable to split dataset with invalid ratio %.2f", trainingRatio)
	}

	numRows := ds.NumRows()
	if numRows == 0 {
		return nil, nil, errors.New("Cannot split empty dataset")
	}

	// The sequence of draws from r (one Perm, then one Float64 per row) fixes
	// the outcome for a given source; keep it stable so seeded splits remain
	// reproducible.
	r := rand.New(source)
	perm := r.Perm(numRows)

	// Either subset can receive up to numRows rows, so both are pre-sized.
	trainingRowMap := make([]int, 0, numRows)
	testRowMap := make([]int, 0, numRows)

	for _, rowIndex := range perm {
		if r.Float64() < trainingRatio {
			trainingRowMap = append(trainingRowMap, rowIndex)
		} else {
			testRowMap = append(testRowMap, rowIndex)
		}
	}

	return dataset.NewSubset(ds, trainingRowMap), dataset.NewSubset(ds, testRowMap), nil
}
Exemple #2
0
// Train validates trainingData and, when it is acceptable, records it on the
// classifier.
//
// The checks run in a fixed order: a dataset whose features are not all
// floats is rejected first with a non-float-features error, then a dataset
// with zero rows is rejected with an empty-training-dataset error. On either
// failure the classifier's stored training data is left untouched. Returns
// nil on success.
func (classifier *kNNClassifier) Train(trainingData dataset.Dataset) error {
	switch {
	case !trainingData.AllFeaturesFloats():
		return knnerrors.NewNonFloatFeaturesTrainingSetError()
	case trainingData.NumRows() == 0:
		return knnerrors.NewEmptyTrainingDatasetError()
	}

	classifier.trainingData = trainingData
	return nil
}
			// Happy path for splitting: the original dataset is populated and,
			// per the Context name, the ratio is valid. originalSet, trainingSet,
			// testSet and err are not declared in this block; they come from the
			// enclosing scope — presumably assigned by setup there that performs
			// the split (confirm against the surrounding Describe).
			Context("when the dataset has rows and the ratio is valid", func() {
				// Expected target values for each subset, listed in row order.
				var expectedTrainingTargets, expectedTestTargets []slice.Slice

				BeforeEach(func() {
					// Twenty single-column rows whose values are the decimal
					// strings "0" through "19".
					for i := 0; i < 20; i++ {
						originalSet.AddRowFromStrings([]string{strconv.Itoa(i)})
					}

					// 15 training rows and 5 test rows, covering all 20 values
					// exactly once between them.
					// NOTE(review): these orderings look pinned to a fixed-seed
					// random source configured outside this block; they would
					// need regenerating if that seed or the order of random
					// draws in the split ever changes — confirm.
					expectedTrainingTargets = makeSingleFloatTargets(16, 13, 1, 2, 19, 6, 5, 14, 10, 11, 4, 15, 8, 3, 9)
					expectedTestTargets = makeSingleFloatTargets(12, 7, 17, 0, 18)
				})

				It("splits the dataset into two according to the given ratio", func() {
					Ω(err).ShouldNot(HaveOccurred())

					// The training subset must match its expectation in size
					// and, row by row, in target value.
					Ω(trainingSet.NumRows()).Should(Equal(len(expectedTrainingTargets)))
					for i := 0; i < trainingSet.NumRows(); i++ {
						trainingRow, rowErr := trainingSet.Row(i)
						Ω(rowErr).ShouldNot(HaveOccurred())
						Ω(trainingRow.Target().Equals(expectedTrainingTargets[i])).Should(BeTrue())
					}

					// Same size and per-row checks for the test subset.
					Ω(testSet.NumRows()).Should(Equal(len(expectedTestTargets)))
					for i := 0; i < testSet.NumRows(); i++ {
						testRow, rowErr := testSet.Row(i)
						Ω(rowErr).ShouldNot(HaveOccurred())
						Ω(testRow.Target().Equals(expectedTestTargets[i])).Should(BeTrue())
					}
				})
			})
		})
Exemple #4
0
			Ω(ds.NumFeatures()).Should(Equal(2))
			Ω(ds.NumTargets()).Should(Equal(3))
		})
	})

	Describe("Adding, Counting, and Getting rows", func() {
		// Builds a fresh ds for every spec in this Describe so that no rows
		// carry over from one spec to the next.
		BeforeEach(func() {
			// One sample string per column, five columns in total.
			// NOTE(review): presumably "1.0" stands for a float column and
			// "x" for a string column — confirm against
			// columntype.StringsToColumnTypes.
			typeSamples := []string{"1.0", "x", "x", "1.0", "x"}
			parsedTypes, parseErr := columntype.StringsToColumnTypes(typeSamples)
			Ω(parseErr).ShouldNot(HaveOccurred())

			ds = dataset.NewDataset([]int{1, 3}, []int{0, 2, 4}, parsedTypes)
		})

		Context("When the dataset is empty", func() {
			// No rows have been added to ds at this point, so it must report
			// a row count of zero.
			It("Has 0 rows", func() {
				Ω(ds.NumRows()).To(BeZero())
			})

			Context("When getting a row", func() {
				It("Returns an error", func() {
					// With no rows stored, every lookup must fail: index zero,
					// a negative index, and a positive index are each probed
					// in turn.
					for _, rowIndex := range []int{0, -1, 1} {
						_, err := ds.Row(rowIndex)
						Ω(err).Should(HaveOccurred())
					}
				})
			})