package mockingbird_test

import (
	"github.com/lazywei/liblinear"
	. "github.com/lazywei/mockingbird"

	. "github.com/onsi/ginkgo"
	. "github.com/onsi/gomega"
)

var _ = Describe("Naive Bayes", func() {
	X, y := liblinear.ReadLibsvm("test_fixture/test_samples.libsvm", false)

	/* _, nFeatures := X.Dims() */
	/* X, _ = X.View(0, 0, 3, nFeatures).(*mat64.Dense) */
	/* y, _ = y.View(0, 0, 3, 1).(*mat64.Dense) */

	nb := NewNaiveBayes()

	Describe("Fit", func() {

		nb.Fit(X, y)
		tokensTotal, langsTotal, langsCount, tokensTotalPerLang, tokenCountPerLang := nb.GetParams()

		It("should count tokens and languages", func() {
			Expect(tokensTotal).To(Equal(238))
			Expect(langsTotal).To(Equal(22))
		})

		It("should count samples for each languages", func() {
			Expect(langsCount).To(Equal(map[int]int{
Example #2
0
// main dispatches on the subcommand returned by kingpin.Parse:
//
//   - "train" reads a libsvm sample file and fits either a naive Bayes model
//     (solver 0) or a logistic regression model (solver 1), writing the result
//     under the training output directory.
//   - "predict" reads libsvm test data, loads a previously saved model for the
//     chosen solver, and dumps the predicted labels.
//   - "collectRosetta" and "convertLibsvm" delegate to the corresponding
//     helpers.
//
// File-system failures while creating the output directory, writing the naive
// Bayes model, or reading it back terminate the program via log.Fatal.
func main() {
	switch kingpin.Parse() {
	case "train":
		X, y := liblinear.ReadLibsvm(*trainSample, false)

		switch *trainSolver {
		case 0:
			nb := mb.NewNaiveBayes()
			nb.Fit(X, y)

			// Fail loudly if the output directory cannot be created rather
			// than letting the subsequent write report a less direct error.
			if err := os.MkdirAll(*trainOutput, 0755); err != nil {
				log.Fatal(err)
			}

			modelPath := filepath.Join(*trainOutput, "naive_bayes.gob")
			if err := ioutil.WriteFile(modelPath, []byte(nb.ToGob()), 0644); err != nil {
				log.Fatal(err)
			}
		case 1:
			lr := mb.NewLogisticRegression()
			lr.Fit(X, y)

			// Mirror the naive Bayes path: make sure the output directory
			// exists before asking the model to save itself into it.
			if err := os.MkdirAll(*trainOutput, 0755); err != nil {
				log.Fatal(err)
			}
			lr.SaveModel(filepath.Join(*trainOutput, "lr.model"))

		default:
			// Printf (not Println) so the %v verb is actually expanded.
			fmt.Printf("Unsupported Solver Type: %v\n", *trainSolver)
		}

	case "predict":
		fmt.Println("Data Loading ...")
		X, _ := liblinear.ReadLibsvm(*predictTestData, false)
		fmt.Println("Data Loaded")

		switch *predictSolver {
		case 0:
			fmt.Println("Model Loading ...")
			gobStr, err := ioutil.ReadFile(*predictModel)
			if err != nil {
				log.Fatal(err)
			}
			// Only report success once the read is known to have succeeded.
			fmt.Println("Model Loaded")

			fmt.Println("Model Initiating ...")
			nb := mb.NewNaiveBayesFromGob(string(gobStr))
			fmt.Println("Model Initiated ...")

			labels := []int{}
			for _, y := range nb.Predict(X) {
				labels = append(labels, y.Label)
			}
			spew.Dump(labels)
		case 1:
			lr := mb.NewLogisticRegressionFromModel(*predictModel)
			labels := []int{}
			for _, y := range lr.Predict(X) {
				labels = append(labels, y.Label)
			}
			spew.Dump(labels)

		default:
			// Printf (not Println) so the %v verb is actually expanded.
			fmt.Printf("Unsupported Solver Type: %v\n", *predictSolver)
		}

	case "collectRosetta":
		CollectRosetta(*rosettaRootPath, *rosettaDestPath)

	case "convertLibsvm":
		// A bag-of-words path, when supplied, selects the BOW-aware converter.
		if *libsvmBowPath != "" {
			ConvertLibsvmWithBow(*libsvmSamplePath, *libsvmOutputDirPath, *libsvmBowPath)
		} else {
			ConvertLibsvm(*libsvmSamplePath, *libsvmOutputDirPath)
		}
	}
}