func (gdpe *gradientDescentParameterEstimator) Train(ds dataset.Dataset) error { if !ds.AllFeaturesFloats() { return gdeErrors.NewNonFloatFeaturesError() } if !ds.AllTargetsFloats() { return gdeErrors.NewNonFloatTargetError() } if ds.NumTargets() != 1 { return gdeErrors.NewInvalidNumberOfTargetsError(ds.NumTargets()) } if ds.NumFeatures() == 0 { return gdeErrors.NewNoFeaturesError() } gdpe.trainingSet = ds return nil }
func (regressor *linearRegressor) Train(trainingData dataset.Dataset) error { if !trainingData.AllFeaturesFloats() { return linearerrors.NewNonFloatFeaturesError() } if !trainingData.AllTargetsFloats() { return linearerrors.NewNonFloatTargetsError() } if trainingData.NumTargets() != 1 { return linearerrors.NewInvalidNumberOfTargetsError(trainingData.NumTargets()) } if trainingData.NumFeatures() == 0 { return linearerrors.NewNoFeaturesError() } estimator, err := gradientdescentestimator.NewGradientDescentParameterEstimator( defaultLearningRate, defaultPrecision, defaultMaxIterations, gradientdescentestimator.LinearModelLeastSquaresLossGradient, ) if err != nil { return linearerrors.NewEstimatorConstructionError(err) } err = estimator.Train(trainingData) if err != nil { return linearerrors.NewEstimatorTrainingError(err) } coefficients, err := estimator.Estimate(defaultInitialCoefficientEstimate(trainingData.NumFeatures())) if err != nil { return linearerrors.NewEstimatorEstimationError(err) } regressor.coefficients = coefficients return nil }
It("Returns true", func() { Ω(ds.AllTargetsFloats()).Should(BeFalse()) }) }) }) Describe("NumFeatures and NumTargets", func() { BeforeEach(func() { columnTypes, err := columntype.StringsToColumnTypes([]string{"1.0", "x", "x", "1.0", "x"}) Ω(err).ShouldNot(HaveOccurred()) ds = dataset.NewDataset([]int{1, 3}, []int{0, 2, 4}, columnTypes) }) It("Returns the correct number of features and targets", func() { Ω(ds.NumFeatures()).Should(Equal(2)) Ω(ds.NumTargets()).Should(Equal(3)) }) }) Describe("Adding, Counting, and Getting rows", func() { BeforeEach(func() { columnTypes, err := columntype.StringsToColumnTypes([]string{"1.0", "x", "x", "1.0", "x"}) Ω(err).ShouldNot(HaveOccurred()) ds = dataset.NewDataset([]int{1, 3}, []int{0, 2, 4}, columnTypes) }) Context("When the dataset is empty", func() { It("Has 0 rows", func() { Ω(ds.NumRows()).To(BeZero())