示例#1
0
文件: main.go 项目: jlund3/ford
// main fits an LDA topic model to the Newsgroups corpus with 100 Gibbs
// sweeps, prints the accuracy of a naive classifier built on features derived
// from the corpus and the sampled topic assignments (lda.Z), and then prints a
// 10-item summary of every topic.
func main() {
	// NOTE(review): presumably seeds the global RNG from the wall clock, so
	// runs are not reproducible — confirm in util.
	util.SeedNow()

	corpus := load.Newsgroups.Import()
	// NOTE(review): the numeric arguments look like (topics, alpha, beta) —
	// confirm against topic.NewLDA.
	lda := topic.NewLDA(corpus, 20, .1, .01)
	graphical.RunGibbs(lda, 100)

	// NOTE(review): SplitRand(.8) presumably yields a random 80/20
	// train/test split that eval.Naive consumes — confirm in eval.
	labels := eval.NewLabelCorpusWordFeature(corpus, lda.Z)
	fmt.Println("Accuracy:", eval.Naive(labels.SplitRand(.8)))
	for z := 0; z < lda.T; z++ {
		fmt.Printf("%d: %s\n", z, lda.TopicSummary(z, 10))
	}
}
示例#2
0
文件: main.go 项目: jlund3/ford
// main runs a simulated-annealing experiment for an ITM topic model on the
// Newsgroups corpus. It logs per-sweep statistics to a fresh temporary file
// under scratch/output/annealing/.
//
// The experiment:
//   - loads the corpus and the constraint strings from
//     data/constraints/newsgroups.txt;
//   - draws a seed, builds a train/test partition of the documents with it,
//     and reseeds with the same value so the sampling that follows is
//     reproducible from the seed recorded in the output file;
//   - runs 100 plain Gibbs sweeps, then adds every constraint;
//   - anneals through a fixed temperature schedule, 20 sweeps per stage.
//
// Each data line holds: the sweep index within its stage, cumulative annealed
// sampling time in seconds, the convergence checker's value, naive and wabbit
// accuracies for two feature sets, and the model posterior. Stages are
// separated by "# temp" comment lines. Failure to create the output file
// panics.
func main() {
	corpus := load.Newsgroups.Import()
	_, constraints := load.GetConstraints("data/constraints/newsgroups.txt")

	util.SeedNow()
	seed := rand.Int63()

	rand.Seed(seed)
	trainPart, testPart := util.CreateSplit(corpus.M, .8)

	// Reseed so the sampler's random stream depends only on the logged seed,
	// not on how many draws CreateSplit happened to consume.
	rand.Seed(seed)

	name := "annealing"
	dirpath := "scratch/output/" + name + "/"
	util.EnsureDir(dirpath)
	out, err := ioutil.TempFile(dirpath, "")
	if err != nil {
		panic(err)
	}
	defer out.Close()

	fmt.Fprintln(out, "# seed ", seed)
	// NOTE(review): the last two accuracy columns come from
	// eval.NewLabeledCorpusFeature(corpus, itm.Z) yet are labelled "w-…";
	// confirm the header matches that feature set.
	fmt.Fprintln(out, "# iter duration changes wz-naive wz-wabbit w-naive w-wabbit post")

	// NOTE(review): the numeric arguments presumably are (topics, alpha,
	// beta, …) — confirm against topic.NewITM.
	itm := topic.NewITM(corpus, 20, .1, .01, 100)
	// Unconstrained burn-in before any constraints are applied.
	for i := 0; i < 100; i++ {
		itm.Gibbs()
	}

	for _, constraint := range constraints {
		itm.AddConstraintString(constraint)
	}

	checker := topic.NewConvergenceCheck(itm.Z)
	// Total time spent inside AnnealedGibbs across all stages; evaluation
	// work is deliberately excluded.
	var duration time.Duration

	// epoch runs iters annealed Gibbs sweeps at temperature temp and writes
	// one stats line per sweep.
	epoch := func(iters int, temp float64) {
		// Fprintf, not Fprintln: Fprintln does not interpret %f and would
		// write the literal "# temp %f" followed by the value.
		fmt.Fprintf(out, "# temp %f\n", temp)
		for i := 0; i < iters; i++ {
			start := time.Now()
			itm.AnnealedGibbs(temp)
			duration += time.Since(start)

			wzLabeled := eval.NewLabelCorpusWordFeature(corpus, itm.Z)
			wzTrain, wzTest := wzLabeled.Split(trainPart, testPart)

			zLabeled := eval.NewLabeledCorpusFeature(corpus, itm.Z)
			zTrain, zTest := zLabeled.Split(trainPart, testPart)

			stats := []string{
				fmt.Sprintf("%d", i),
				fmt.Sprintf("%f", duration.Seconds()),
				fmt.Sprintf("%d", checker.Check()),

				fmt.Sprintf("%f", eval.Naive(wzTrain, wzTest)),
				fmt.Sprintf("%f", eval.Wabbit(wzTrain, wzTest)),

				fmt.Sprintf("%f", eval.Naive(zTrain, zTest)),
				fmt.Sprintf("%f", eval.Wabbit(zTrain, zTest)),

				fmt.Sprintf("%f", itm.Posterior()),
			}

			fmt.Fprintln(out, strings.Join(stats, " "))
		}
	}

	// Cooling schedule, hottest first. Every stage runs the same number of
	// sweeps.
	const sweepsPerStage = 20
	schedule := []float64{10, 5, 2.5, 2, 1.75, 1.5, 1.25, 1.1, 1, .75, .5, .25, .1, .01}
	for _, temp := range schedule {
		epoch(sweepsPerStage, temp)
	}
}