func (gdpe *gradientDescentParameterEstimator) Train(ds dataset.Dataset) error { if !ds.AllFeaturesFloats() { return gdeErrors.NewNonFloatFeaturesError() } if !ds.AllTargetsFloats() { return gdeErrors.NewNonFloatTargetError() } if ds.NumTargets() != 1 { return gdeErrors.NewInvalidNumberOfTargetsError(ds.NumTargets()) } if ds.NumFeatures() == 0 { return gdeErrors.NewNoFeaturesError() } gdpe.trainingSet = ds return nil }
func (regressor *linearRegressor) Train(trainingData dataset.Dataset) error { if !trainingData.AllFeaturesFloats() { return linearerrors.NewNonFloatFeaturesError() } if !trainingData.AllTargetsFloats() { return linearerrors.NewNonFloatTargetsError() } if trainingData.NumTargets() != 1 { return linearerrors.NewInvalidNumberOfTargetsError(trainingData.NumTargets()) } if trainingData.NumFeatures() == 0 { return linearerrors.NewNoFeaturesError() } estimator, err := gradientdescentestimator.NewGradientDescentParameterEstimator( defaultLearningRate, defaultPrecision, defaultMaxIterations, gradientdescentestimator.LinearModelLeastSquaresLossGradient, ) if err != nil { return linearerrors.NewEstimatorConstructionError(err) } err = estimator.Train(trainingData) if err != nil { return linearerrors.NewEstimatorTrainingError(err) } coefficients, err := estimator.Estimate(defaultInitialCoefficientEstimate(trainingData.NumFeatures())) if err != nil { return linearerrors.NewEstimatorEstimationError(err) } regressor.coefficients = coefficients return nil }
Ω(ds.AllFeaturesFloats()).Should(BeFalse()) }) }) }) Describe("AllTargetsFloats", func() { Context("When all targets are floats", func() { BeforeEach(func() { columnTypes, err := columntype.StringsToColumnTypes([]string{"1.0", "1.0"}) Ω(err).ShouldNot(HaveOccurred()) ds = dataset.NewDataset([]int{}, []int{0, 1}, columnTypes) }) It("Returns true", func() { Ω(ds.AllTargetsFloats()).Should(BeTrue()) }) }) Context("When not all targets are floats", func() { BeforeEach(func() { columnTypes, err := columntype.StringsToColumnTypes([]string{"x", "1.0"}) Ω(err).ShouldNot(HaveOccurred()) ds = dataset.NewDataset([]int{}, []int{0, 1}, columnTypes) }) It("Returns true", func() { Ω(ds.AllTargetsFloats()).Should(BeFalse()) }) })