func SplitDataset(ds dataset.Dataset, trainingRatio float64, source rand.Source) (dataset.Dataset, dataset.Dataset, error) { if trainingRatio < 0 || trainingRatio > 1 { return nil, nil, fmt.Errorf("Unable to split dataset with invalid ratio %.2f", trainingRatio) } numRows := ds.NumRows() if numRows == 0 { return nil, nil, errors.New("Cannot split empty dataset") } r := rand.New(source) perm := r.Perm(numRows) trainingRowMap := make([]int, 0, numRows) testRowMap := make([]int, 0, numRows) for _, rowIndex := range perm { if r.Float64() < trainingRatio { trainingRowMap = append(trainingRowMap, rowIndex) } else { testRowMap = append(testRowMap, rowIndex) } } return dataset.NewSubset(ds, trainingRowMap), dataset.NewSubset(ds, testRowMap), nil }
func (classifier *kNNClassifier) Train(trainingData dataset.Dataset) error { if !trainingData.AllFeaturesFloats() { return knnerrors.NewNonFloatFeaturesTrainingSetError() } if trainingData.NumRows() == 0 { return knnerrors.NewEmptyTrainingDatasetError() } classifier.trainingData = trainingData return nil }
// Spec for a successful split of a populated dataset.
//
// The split results (trainingSet, testSet, err) and originalSet are declared
// and produced outside this view — presumably by a JustBeforeEach that calls
// SplitDataset with a seeded random source and a fixed ratio; confirm against
// the enclosing Describe.
//
// NOTE(review): the expected target orderings below encode the exact sequence
// of values drawn from that seeded source. Any change to how the splitter
// consumes randomness (draw order, number of draws) requires regenerating them.
Context("when the dataset has rows and the ratio is valid", func() {
	var expectedTrainingTargets, expectedTestTargets []slice.Slice

	BeforeEach(func() {
		// Add 20 rows whose single column holds the values 0..19 (as strings).
		for i := 0; i < 20; i++ {
			originalSet.AddRowFromStrings([]string{strconv.Itoa(i)})
		}
		// 15 rows are expected in the training subset and 5 in the test subset,
		// listed in the order the subsets are expected to return them.
		expectedTrainingTargets = makeSingleFloatTargets(16, 13, 1, 2, 19, 6, 5, 14, 10, 11, 4, 15, 8, 3, 9)
		expectedTestTargets = makeSingleFloatTargets(12, 7, 17, 0, 18)
	})

	It("splits the dataset into two according to the given ratio", func() {
		Ω(err).ShouldNot(HaveOccurred())

		// Training subset: same size and same targets, position by position.
		Ω(trainingSet.NumRows()).Should(Equal(len(expectedTrainingTargets)))
		for i := 0; i < trainingSet.NumRows(); i++ {
			trainingRow, rowErr := trainingSet.Row(i)
			Ω(rowErr).ShouldNot(HaveOccurred())
			Ω(trainingRow.Target().Equals(expectedTrainingTargets[i])).Should(BeTrue())
		}

		// Test subset: same size and same targets, position by position.
		Ω(testSet.NumRows()).Should(Equal(len(expectedTestTargets)))
		for i := 0; i < testSet.NumRows(); i++ {
			testRow, rowErr := testSet.Row(i)
			Ω(rowErr).ShouldNot(HaveOccurred())
			Ω(testRow.Target().Equals(expectedTestTargets[i])).Should(BeTrue())
		}
	})
})
}) // closes an enclosing container whose opening is outside this view
Ω(ds.NumFeatures()).Should(Equal(2)) Ω(ds.NumTargets()).Should(Equal(3)) }) }) Describe("Adding, Counting, and Getting rows", func() { BeforeEach(func() { columnTypes, err := columntype.StringsToColumnTypes([]string{"1.0", "x", "x", "1.0", "x"}) Ω(err).ShouldNot(HaveOccurred()) ds = dataset.NewDataset([]int{1, 3}, []int{0, 2, 4}, columnTypes) }) Context("When the dataset is empty", func() { It("Has 0 rows", func() { Ω(ds.NumRows()).To(BeZero()) }) Context("When getting a row", func() { It("Returns an error", func() { _, err := ds.Row(0) Ω(err).Should(HaveOccurred()) _, err = ds.Row(-1) Ω(err).Should(HaveOccurred()) _, err = ds.Row(1) Ω(err).Should(HaveOccurred()) }) })