예제 #1
0
파일: trainer.go 프로젝트: postfix/hmm-1
func main() {
	flagAddr := flag.String("addr", ":6060", "Listening address")
	flagCorpus := flag.String("corpus", "", "Corpus file in JSON format")
	flagStates := flag.Int("states", 2, "Number of hidden states")
	flagIter := flag.Int("iter", 20, "Number of EM iterations")
	flagModel := flag.String("model", "", "Model file in JSON format")
	flagLL := flag.String("logl", "", "Log-likelihood file")
	flagPProf := flag.Bool("pprof", false, "Output pprof file")
	flagParallel := flag.Bool("parallel", true, "Run multi-threading")
	flag.Parse()

	go func() {
		log.Println(http.ListenAndServe(*flagAddr, nil))
	}()

	var corpus []*core.Instance
	if f, e := os.Open(*flagCorpus); e != nil {
		log.Fatalf("Cannot open %s: %v", *flagCorpus, e)
	} else {
		defer f.Close()
		corpus = loader.LoadJSON(f)
		// Infer unexported fileds of Instance.
		for i, _ := range corpus {
			corpus[i].Index()
		}
	}

	C := core.EstimateC(corpus)
	baseline := core.Init(*flagStates, C, corpus, rand.New(rand.NewSource(99)))

	f := core.CreateFileOrStdout(*flagLL)
	if f != os.Stdout {
		defer f.Close()
	}

	if *flagPProf {
		defer profile.Start(profile.CPUProfile).Stop()
	}

	if *flagParallel {
		runtime.GOMAXPROCS(runtime.NumCPU())
	}

	model := core.Train(corpus, *flagStates, C, *flagIter, baseline, f)
	core.SaveModel(model, *flagModel)
}
예제 #2
0
func TestGenerateAndLoad(t *testing.T) {
	dir, e := ioutil.TempDir("", "converter_test")
	if e != nil {
		t.Fatalf("Cannot create temp dir: %v", e)
	}
	defer os.RemoveAll(dir)

	*flagCSV = buildnrun.Pkg(path.Join(kCSVDir, kCSVFile))
	*flagCorpus = path.Join(dir, "corpus.json")

	Run(new(PlainFeatureGenerator))

	corpus := loader.LoadJSON(α(os.Open(*flagCorpus)).(io.Reader))
	fmt.Printf("Loaded %d instances.\nThe first one:%v\nThe last one:%v\n",
		len(corpus), corpus[0], corpus[len(corpus)-1])

	if len(corpus) != 5376 {
		t.Errorf("Expecting %d instances, got %d", 5376, len(corpus))
	}
}