コード例 #1
0
ファイル: average_test.go プロジェクト: CTLife/golearn
func TestPredict(t *testing.T) {
	// TestPredict trains an AveragePerceptron on one split of the
	// house-votes-84 dataset and requires at least 65% accuracy on the other.
	a := NewAveragePerceptron(10, 1.2, 0.5, 0.3)
	if a == nil {
		// Every later step dereferences a, so stop instead of panicking.
		t.Fatalf("Unable to create average perceptron")
	}

	absPath, err := filepath.Abs("../examples/datasets/house-votes-84.csv")
	if err != nil {
		t.Fatalf("Unable to resolve dataset path: %s", err)
	}
	rawData, err := base.ParseCSVToInstances(absPath, true)
	if err != nil {
		// rawData is unusable when parsing fails; splitting it would panic.
		t.Fatalf("Unable to parse %s: %s", absPath, err)
	}

	trainData, testData := base.InstancesTrainTestSplit(rawData, 0.5)
	a.Fit(trainData)
	if !a.trained {
		// Predicting with an untrained model tells us nothing.
		t.Fatalf("Perceptron was not trained")
	}

	predictions := a.Predict(testData)
	cf, err := evaluation.GetConfusionMatrix(testData, predictions)
	if err != nil {
		t.Fatalf("Couldn't get confusion matrix: %s", err)
	}
	fmt.Println(evaluation.GetSummary(cf))
	if acc := evaluation.GetAccuracy(cf); acc < 0.65 {
		t.Errorf("Perceptron not trained correctly: accuracy %.3f < 0.65", acc)
	}
}
コード例 #2
0
ファイル: knn_bench_test.go プロジェクト: CTLife/golearn
func BenchmarkKNNWithNoOpts(b *testing.B) {
	// BenchmarkKNNWithNoOpts times KNN prediction (euclidean, k=1) on the
	// readMnist split with the classifier's optimisations disabled.

	// Load and fit outside the timed region so only prediction is measured.
	train, test := readMnist()
	cls := NewKnnClassifier("euclidean", 1)
	cls.AllowOptimisations = false
	cls.Fit(train)
	b.ResetTimer()

	// The benchmark must run the measured operation b.N times; otherwise the
	// ns/op reported by the testing package is meaningless.
	predictions := cls.Predict(test)
	for i := 1; i < b.N; i++ {
		predictions = cls.Predict(test)
	}
	b.StopTimer()

	c, err := evaluation.GetConfusionMatrix(test, predictions)
	if err != nil {
		b.Fatalf("Couldn't get confusion matrix: %s", err)
	}
	fmt.Println(evaluation.GetSummary(c))
	fmt.Println(evaluation.GetAccuracy(c))
}
コード例 #3
0
ファイル: multisvc_test.go プロジェクト: CTLife/golearn
func TestMultiSVMUnweighted(t *testing.T) {
	// TestMultiSVMUnweighted fits a MultiLinearSVC ("l1" loss, "l2" penalty,
	// dual=true, C=1.0, eps=1e-4, no class weights) on one split of
	// articles.csv and requires >70% accuracy on the other split.
	Convey("Loading data...", t, func() {
		inst, err := base.ParseCSVToInstances("../examples/datasets/articles.csv", false)
		So(err, ShouldBeNil)
		X, Y := base.InstancesTrainTestSplit(inst, 0.4)

		m := NewMultiLinearSVC("l1", "l2", true, 1.0, 1e-4, nil)
		m.Fit(X)

		Convey("Predictions should work...", func() {
			predictions, err := m.Predict(Y)
			// Check the prediction error before err is reused below.
			So(err, ShouldBeNil)
			cf, err := evaluation.GetConfusionMatrix(Y, predictions)
			So(err, ShouldBeNil)
			So(evaluation.GetAccuracy(cf), ShouldBeGreaterThan, 0.70)
		})
	})
}
コード例 #4
0
ファイル: randomforest_test.go プロジェクト: CTLife/golearn
func TestRandomForest(t *testing.T) {
	// TestRandomForest fits a 10-tree, 3-feature RandomForest on Chi-Merge
	// discretised iris data and checks it beats a low accuracy bar, and that
	// asking for more features than the data has makes Fit return an error.
	Convey("Given a valid CSV file", t, func() {
		inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
		So(err, ShouldBeNil)

		Convey("Splitting the data into test and training sets", func() {
			trainData, testData := base.InstancesTrainTestSplit(inst, 0.60)

			Convey("When Chi-Merge filtering the training data", func() {
				// Train the discretiser on the training rows only: training it
				// on every row lets test labels shape the bin boundaries and
				// inflates the measured accuracy.
				filt := filters.NewChiMergeFilter(trainData, 0.90)
				for _, a := range base.NonClassFloatAttributes(trainData) {
					filt.AddAttribute(a)
				}
				filt.Train()
				trainDataf := base.NewLazilyFilteredInstances(trainData, filt)
				testDataf := base.NewLazilyFilteredInstances(testData, filt)

				Convey("Fitting and predicting with a Random Forest", func() {
					rf := NewRandomForest(10, 3)
					So(rf.Fit(trainDataf), ShouldBeNil)

					predictions, err := rf.Predict(testDataf)
					So(err, ShouldBeNil)

					confusionMat, err := evaluation.GetConfusionMatrix(testDataf, predictions)
					So(err, ShouldBeNil)

					Convey("Predictions should be somewhat accurate", func() {
						So(evaluation.GetAccuracy(confusionMat), ShouldBeGreaterThan, 0.35)
					})
				})
			})
		})

		Convey("Fitting with a Random Forest with too many features compared to the data", func() {
			rf := NewRandomForest(10, len(base.NonClassAttributes(inst))+1)
			// Keep the Fit result local instead of overwriting the outer err.
			fitErr := rf.Fit(inst)

			Convey("Should return an error", func() {
				So(fitErr, ShouldNotBeNil)
			})
		})
	})
}
コード例 #5
0
ファイル: multisvc_test.go プロジェクト: CTLife/golearn
func TestMultiSVMWeighted(t *testing.T) {
	// TestMultiSVMWeighted fits a class-weighted MultiLinearSVC ("l1" loss,
	// "l2" penalty, dual=true, C=0.62, eps=1e-4) on one split of articles.csv
	// and requires >70% accuracy on the other split.
	Convey("Loading data...", t, func() {
		// Per-class weights handed to NewMultiLinearSVC.
		weights := map[string]float64{
			"Finance":  0.1739,
			"Tech":     0.0750,
			"Politics": 0.4928,
		}

		inst, err := base.ParseCSVToInstances("../examples/datasets/articles.csv", false)
		So(err, ShouldBeNil)
		X, Y := base.InstancesTrainTestSplit(inst, 0.4)

		m := NewMultiLinearSVC("l1", "l2", true, 0.62, 1e-4, weights)
		m.Fit(X)

		Convey("Predictions should work...", func() {
			predictions, err := m.Predict(Y)
			// Check the prediction error before err is reused below.
			So(err, ShouldBeNil)
			cf, err := evaluation.GetConfusionMatrix(Y, predictions)
			So(err, ShouldBeNil)
			So(evaluation.GetAccuracy(cf), ShouldBeGreaterThan, 0.70)
		})
	})
}
コード例 #6
0
ファイル: bagging_test.go プロジェクト: GeekFreaker/golearn
func TestBaggedModelRandomForest(t *testing.T) {
	// TestBaggedModelRandomForest fits a BaggedModel of ten random trees (two
	// attributes each) on Chi-Merge discretised iris data and requires >50%
	// accuracy on the held-out split.
	Convey("Given data", t, func() {
		inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
		So(err, ShouldBeNil)

		Convey("Splitting the data into training and test data", func() {
			trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)

			Convey("Filtering the split datasets", func() {
				// The seed changes every run; log it so a failing run can be
				// replayed by seeding with the same value.
				seed := time.Now().UnixNano()
				t.Logf("TestBaggedModelRandomForest: random seed %d", seed)
				rand.Seed(seed)

				// Train the discretiser on the training rows only so the test
				// rows' labels cannot influence the bin boundaries.
				filt := filters.NewChiMergeFilter(trainData, 0.90)
				for _, a := range base.NonClassFloatAttributes(trainData) {
					filt.AddAttribute(a)
				}
				filt.Train()
				trainDataf := base.NewLazilyFilteredInstances(trainData, filt)
				testDataf := base.NewLazilyFilteredInstances(testData, filt)

				Convey("Fitting and Predicting with a Bagged Model of 10 Random Trees", func() {
					rf := new(BaggedModel)
					for i := 0; i < 10; i++ {
						rf.AddModel(trees.NewRandomTree(2))
					}

					rf.Fit(trainDataf)
					predictions := rf.Predict(testDataf)

					confusionMat, err := evaluation.GetConfusionMatrix(testDataf, predictions)
					So(err, ShouldBeNil)

					Convey("Predictions are somewhat accurate", func() {
						So(evaluation.GetAccuracy(confusionMat), ShouldBeGreaterThan, 0.5)
					})
				})
			})
		})
	})
}
コード例 #7
0
ファイル: one_v_all_test.go プロジェクト: CTLife/golearn
func TestOneVsAllModel(t *testing.T) {
	// TestOneVsAllModel wraps LinearSVC ("l1" loss, "l2" penalty, dual=true,
	// C=1.0, eps=1e-4) in a OneVsAllModel over iris, checks the per-class
	// bookkeeping after Fit, and checks that prediction and scoring succeed.

	// newSVC is the per-class factory; the class name is ignored because every
	// sub-problem receives an identically configured SVC.
	newSVC := func(class string) base.Classifier {
		svc, err := linear_models.NewLinearSVC("l1", "l2", true, 1.0, 1e-4)
		if err == nil {
			return svc
		}
		panic(err)
	}

	Convey("Given data", t, func() {
		data, loadErr := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
		So(loadErr, ShouldBeNil)

		train, test := base.InstancesTrainTestSplit(data, 0.4)

		model := NewOneVsAllModel(newSVC)
		model.Fit(train)

		Convey("The maximum class index should be 2", func() {
			So(model.maxClassVal, ShouldEqual, 2)
		})

		Convey("There should be three of everything...", func() {
			So(len(model.filters), ShouldEqual, 3)
			So(len(model.classifiers), ShouldEqual, 3)
		})

		Convey("Predictions should work...", func() {
			predicted, predictErr := model.Predict(test)
			So(predictErr, ShouldEqual, nil)
			matrix, matrixErr := evaluation.GetConfusionMatrix(test, predicted)
			So(matrixErr, ShouldEqual, nil)
			fmt.Println(evaluation.GetAccuracy(matrix))
			fmt.Println(evaluation.GetSummary(matrix))
		})
	})
}
コード例 #8
0
ファイル: tree_test.go プロジェクト: CTLife/golearn
func verifyTreeClassification(trainData, testData base.FixedDataGrid) {
	// verifyTreeClassification registers GoConvey specs that build trees on
	// trainData (InferID3Tree with four rule generators, and NewRandomTree)
	// and check each tree's accuracy on testData. Its Convey calls pass no
	// *testing.T, so presumably it must be called inside an outer Convey.

	// Fixed seed so the randomised rule generators behave reproducibly.
	rand.Seed(44414515)

	// checkPredictions asserts that prediction and confusion-matrix building
	// both succeeded, then registers a nested spec requiring accuracy on
	// testData strictly greater than minAccuracy.
	checkPredictions := func(predicted base.FixedDataGrid, predictErr error, minAccuracy float64) {
		So(predictErr, ShouldBeNil)

		matrix, matrixErr := evaluation.GetConfusionMatrix(testData, predicted)
		So(matrixErr, ShouldBeNil)

		Convey("Predictions should be somewhat accurate", func() {
			So(evaluation.GetAccuracy(matrix), ShouldBeGreaterThan, minAccuracy)
		})
	}

	Convey("Using InferID3Tree to create the tree and do the fitting", func() {
		Convey("Using a RandomTreeRule", func() {
			ruleGen := new(RandomTreeRuleGenerator)
			ruleGen.Attributes = 2
			tree := InferID3Tree(trainData, ruleGen)

			Convey("Predicting with the tree", func() {
				predicted, err := tree.Predict(testData)
				checkPredictions(predicted, err, 0.5)
			})
		})

		Convey("Using a InformationGainRule", func() {
			tree := InferID3Tree(trainData, new(InformationGainRuleGenerator))

			Convey("Predicting with the tree", func() {
				predicted, err := tree.Predict(testData)
				checkPredictions(predicted, err, 0.5)
			})
		})
		Convey("Using a GiniCoefficientRuleGenerator", func() {
			tree := InferID3Tree(trainData, new(GiniCoefficientRuleGenerator))
			Convey("Predicting with the tree", func() {
				predicted, err := tree.Predict(testData)
				checkPredictions(predicted, err, 0.5)
			})
		})
		Convey("Using a InformationGainRatioRuleGenerator", func() {
			tree := InferID3Tree(trainData, new(InformationGainRatioRuleGenerator))
			Convey("Predicting with the tree", func() {
				predicted, err := tree.Predict(testData)
				checkPredictions(predicted, err, 0.5)
			})
		})

	})

	Convey("Using NewRandomTree to create the tree", func() {
		randomTree := NewRandomTree(2)

		Convey("Fitting with the tree", func() {
			So(randomTree.Fit(trainData), ShouldBeNil)

			Convey("Predicting with the tree, *without* pruning first", func() {
				predicted, err := randomTree.Predict(testData)
				checkPredictions(predicted, err, 0.5)
			})

			Convey("Predicting with the tree, pruning first", func() {
				// NOTE(review): pruning and scoring both use testData, so this
				// accuracy is optimistic; the lower 0.4 bar reflects that.
				randomTree.Prune(testData)

				predicted, err := randomTree.Predict(testData)
				checkPredictions(predicted, err, 0.4)
			})
		})
	})
}
コード例 #9
0
ファイル: id3.go プロジェクト: tanduong/golearn
// computeAccuracy is a helper for Prune(). It builds a confusion matrix of
// predictions against the reference data in from and returns
// evaluation.GetAccuracy of that matrix.
//
// NOTE(review): the error from GetConfusionMatrix is discarded, so whatever
// matrix accompanies an error is scored as-is; confirm what GetAccuracy
// yields for that value before relying on it in pruning decisions.
func computeAccuracy(predictions base.FixedDataGrid, from base.FixedDataGrid) float64 {
	matrix, _ := evaluation.GetConfusionMatrix(from, predictions)
	accuracy := evaluation.GetAccuracy(matrix)
	return accuracy
}
コード例 #10
0
ファイル: id3.go プロジェクト: 24hours/golearn
// computeAccuracy is a helper for Prune(). It scores predictions against the
// reference instances in from by building a confusion matrix with
// eval.GetConfusionMatrix and returning eval.GetAccuracy of it.
func computeAccuracy(predictions *base.Instances, from *base.Instances) float64 {
	return eval.GetAccuracy(eval.GetConfusionMatrix(from, predictions))
}
コード例 #11
0
ファイル: tree_test.go プロジェクト: GeekFreaker/golearn
func TestRandomTreeClassification(t *testing.T) {
	// TestRandomTreeClassification discretises iris with a Chi-Merge filter and
	// checks that ID3 trees (random and information-gain rules) and a
	// RandomTree classify the held-out rows above the stated accuracy bars.
	Convey("Predictions on filtered data with a Random Tree", t, func() {
		instances, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
		So(err, ShouldBeNil)

		trainData, testData := base.InstancesTrainTestSplit(instances, 0.6)

		// Train the Chi-Merge discretiser on the training rows only. Training
		// it on all instances lets the test rows' labels shape the bin
		// boundaries, leaking test information into the model.
		filter := filters.NewChiMergeFilter(trainData, 0.9)
		for _, a := range base.NonClassFloatAttributes(trainData) {
			filter.AddAttribute(a)
		}
		filter.Train()
		filteredTrainData := base.NewLazilyFilteredInstances(trainData, filter)
		filteredTestData := base.NewLazilyFilteredInstances(testData, filter)

		// checkPredictions asserts that prediction and confusion-matrix
		// building both succeeded, then registers a nested spec requiring
		// accuracy on filteredTestData strictly greater than minAccuracy.
		checkPredictions := func(predictions base.FixedDataGrid, predictErr error, minAccuracy float64) {
			So(predictErr, ShouldBeNil)

			confusionMatrix, cmErr := evaluation.GetConfusionMatrix(filteredTestData, predictions)
			So(cmErr, ShouldBeNil)

			Convey("Predictions should be somewhat accurate", func() {
				So(evaluation.GetAccuracy(confusionMatrix), ShouldBeGreaterThan, minAccuracy)
			})
		}

		Convey("Using InferID3Tree to create the tree and do the fitting", func() {
			Convey("Using a RandomTreeRule", func() {
				randomTreeRuleGenerator := new(RandomTreeRuleGenerator)
				randomTreeRuleGenerator.Attributes = 2
				root := InferID3Tree(filteredTrainData, randomTreeRuleGenerator)

				Convey("Predicting with the tree", func() {
					predictions, err := root.Predict(filteredTestData)
					checkPredictions(predictions, err, 0.5)
				})
			})

			Convey("Using a InformationGainRule", func() {
				root := InferID3Tree(filteredTrainData, new(InformationGainRuleGenerator))

				Convey("Predicting with the tree", func() {
					predictions, err := root.Predict(filteredTestData)
					checkPredictions(predictions, err, 0.5)
				})
			})
		})

		Convey("Using NewRandomTree to create the tree", func() {
			root := NewRandomTree(2)

			Convey("Fitting with the tree", func() {
				// Assert on Fit's result directly rather than overwriting the
				// outer err shared by the enclosing spec.
				So(root.Fit(filteredTrainData), ShouldBeNil)

				Convey("Predicting with the tree, *without* pruning first", func() {
					predictions, err := root.Predict(filteredTestData)
					checkPredictions(predictions, err, 0.5)
				})

				Convey("Predicting with the tree, pruning first", func() {
					// NOTE(review): pruning and scoring both use
					// filteredTestData, so this accuracy is optimistic; a
					// separate validation split would be a fairer check.
					root.Prune(filteredTestData)

					predictions, err := root.Predict(filteredTestData)
					checkPredictions(predictions, err, 0.4)
				})
			})
		})
	})
}