package mockingbird_test import ( "github.com/lazywei/liblinear" . "github.com/lazywei/mockingbird" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) var _ = Describe("Naive Bayes", func() { X, y := liblinear.ReadLibsvm("test_fixture/test_samples.libsvm", false) /* _, nFeatures := X.Dims() */ /* X, _ = X.View(0, 0, 3, nFeatures).(*mat64.Dense) */ /* y, _ = y.View(0, 0, 3, 1).(*mat64.Dense) */ nb := NewNaiveBayes() Describe("Fit", func() { nb.Fit(X, y) tokensTotal, langsTotal, langsCount, tokensTotalPerLang, tokenCountPerLang := nb.GetParams() It("should count tokens and languages", func() { Expect(tokensTotal).To(Equal(238)) Expect(langsTotal).To(Equal(22)) }) It("should count samples for each languages", func() { Expect(langsCount).To(Equal(map[int]int{
func main() { switch kingpin.Parse() { case "train": X, y := liblinear.ReadLibsvm(*trainSample, false) switch *trainSolver { case 0: nb := mb.NewNaiveBayes() nb.Fit(X, y) os.MkdirAll(*trainOutput, 0755) err := ioutil.WriteFile( filepath.Join(*trainOutput, "naive_bayes.gob"), []byte(nb.ToGob()), 0644, ) if err != nil { log.Fatal(err) } case 1: lr := mb.NewLogisticRegression() lr.Fit(X, y) lr.SaveModel(filepath.Join(*trainOutput, "lr.model")) default: fmt.Println("Unsupported Solver Type: %v", *trainSolver) } case "predict": fmt.Println("Data Loading ...") X, _ := liblinear.ReadLibsvm(*predictTestData, false) fmt.Println("Data Loaded") switch *predictSolver { case 0: fmt.Println("Model Loading ...") gobStr, err := ioutil.ReadFile(*predictModel) fmt.Println("Model Loaded") if err != nil { log.Fatal(err) } fmt.Println("Model Initiating ...") nb := mb.NewNaiveBayesFromGob(string(gobStr)) fmt.Println("Model Initiated ...") labels := []int{} for _, y := range nb.Predict(X) { labels = append(labels, y.Label) } spew.Dump(labels) case 1: lr := mb.NewLogisticRegressionFromModel(*predictModel) labels := []int{} for _, y := range lr.Predict(X) { labels = append(labels, y.Label) } spew.Dump(labels) default: fmt.Println("Unsupported Solver Type: %v", *predictSolver) } case "collectRosetta": CollectRosetta(*rosettaRootPath, *rosettaDestPath) case "convertLibsvm": if *libsvmBowPath != "" { ConvertLibsvmWithBow(*libsvmSamplePath, *libsvmOutputDirPath, *libsvmBowPath) } else { ConvertLibsvm(*libsvmSamplePath, *libsvmOutputDirPath) } } }