func main() { // Load and parse the data from csv files fmt.Println("Loading data...") trainData, err := base.ParseCSVToInstances("data/mnist_train.csv", true) if err != nil { panic(err) } testData, err := base.ParseCSVToInstances("data/mnist_test.csv", true) if err != nil { panic(err) } // Create a new linear SVC with some good default values classifier, err := linear_models.NewLinearSVC("l1", "l2", true, 1.0, 1e-4) if err != nil { panic(err) } // Don't output information on each iteration base.Silent() // Train the linear SVC fmt.Println("Training...") classifier.Fit(trainData) // Make predictions for the test data fmt.Println("Predicting...") predictions, err := classifier.Predict(testData) if err != nil { panic(err) } // Get a confusion matrix and print out some accuracy stats for our predictions confusionMat, err := evaluation.GetConfusionMatrix(testData, predictions) if err != nil { panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error())) } fmt.Println(evaluation.GetSummary(confusionMat)) }
func TestOneVsAllModel(t *testing.T) { classifierFunc := func(c string) base.Classifier { m, err := linear_models.NewLinearSVC("l1", "l2", true, 1.0, 1e-4) if err != nil { panic(err) } return m } Convey("Given data", t, func() { inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) So(err, ShouldBeNil) X, Y := base.InstancesTrainTestSplit(inst, 0.4) m := NewOneVsAllModel(classifierFunc) m.Fit(X) Convey("The maximum class index should be 2", func() { So(m.maxClassVal, ShouldEqual, 2) }) Convey("There should be three of everything...", func() { So(len(m.filters), ShouldEqual, 3) So(len(m.classifiers), ShouldEqual, 3) }) Convey("Predictions should work...", func() { predictions, err := m.Predict(Y) So(err, ShouldEqual, nil) cf, err := evaluation.GetConfusionMatrix(Y, predictions) So(err, ShouldEqual, nil) fmt.Println(evaluation.GetAccuracy(cf)) fmt.Println(evaluation.GetSummary(cf)) }) }) }