示例#1
0
文件: main.go 项目: philipz/ntm
// main evaluates a controller on randomly generated n-gram sequences and
// serves the collected runs over HTTP on :9000.
//
// For each run a probability table is drawn with ngram.GenProb, numSamples
// sequences are generated from it, and the mean loss over those sequences is
// recorded as the run's BitsPerSeq. The running mean is logged every 10
// samples.
func main() {
	flag.Parse()

	h1Size := 100
	numHeads := 1
	n := 128
	m := 20
	c := ntm.NewEmptyController1(1, 1, h1Size, numHeads, n, m)
	// weightsFromFile is defined elsewhere; presumably it loads trained
	// weights into c — confirm.
	weightsFromFile(c)

	const (
		numRuns    = 1
		numSamples = 100
	)
	runs := make([]Run, 0, numRuns)
	for i := 0; i < numRuns; i++ {
		prob := ngram.GenProb()
		var l float64
		// x, y and machines end up holding only the last sampled sequence;
		// that sequence is what the Run displays, while BitsPerSeq is the
		// mean loss over all numSamples sequences.
		var x, y [][]float64
		var machines []*ntm.NTM
		for j := 0; j < numSamples; j++ {
			x, y = ngram.GenSeq(prob)
			model := &ntm.LogisticModel{Y: y}
			machines = ntm.ForwardBackward(c, x, model)
			l += model.Loss(ntm.Predictions(machines))
			if (j+1)%10 == 0 {
				log.Printf("%d %d %f", i, j+1, l/float64(j+1))
			}
		}
		l /= numSamples

		runs = append(runs, Run{
			Conf:        RunConf{Prob: prob},
			BitsPerSeq:  l,
			X:           x,
			Y:           y,
			Predictions: ntm.Predictions(machines),
			HeadWeights: ntm.HeadWeights(machines),
		})
	}

	http.HandleFunc("/", root(runs))
	// ListenAndServe always returns a non-nil error. Exit non-zero instead of
	// letting main return success when the server cannot start (e.g. the
	// port is already in use).
	log.Fatal(http.ListenAndServe(":9000", nil))
}
示例#2
0
文件: main.go 项目: philipz/ntm
// main evaluates a controller on the copy task for several sequence lengths
// and serves the collected runs over HTTP on :9000.
//
// For each length in seqLens one sequence is generated with copytask.GenSeq,
// run forward/backward through the controller, and its loss (normalized by
// the number of target entries) is logged and recorded in a Run.
func main() {
	flag.Parse()
	vectorSize := 8
	h1Size := 100
	numHeads := 1
	n := 128
	m := 20
	// The input width is vectorSize+2; the two extra channels are presumably
	// the task's delimiter markers — confirm against copytask.GenSeq.
	c := ntm.NewEmptyController1(vectorSize+2, vectorSize, h1Size, numHeads, n, m)

	// The built-in copy silently truncates to the shorter slice, so weights
	// saved from a differently-sized controller would be loaded only
	// partially with no error. Refuse to continue on a size mismatch.
	weights := weightsFromFile()
	dst := c.WeightsVal()
	if len(weights) != len(dst) {
		log.Fatalf("weights size mismatch: file has %d values, controller expects %d", len(weights), len(dst))
	}
	copy(dst, weights)

	seqLens := []int{10, 20, 30, 50, 120}
	runs := make([]Run, 0, len(seqLens))
	for _, seql := range seqLens {
		x, y := copytask.GenSeq(seql, vectorSize)
		model := &ntm.LogisticModel{Y: y}
		machines := ntm.ForwardBackward(c, x, model)
		l := model.Loss(ntm.Predictions(machines))
		// Normalize the total loss by the number of target entries
		// (len(y) rows × len(y[0]) columns).
		bps := l / float64(len(y)*len(y[0]))
		log.Printf("sequence length: %d, loss: %f", seql, bps)

		runs = append(runs, Run{
			SeqLen:      seql,
			BitsPerSeq:  bps,
			X:           x,
			Y:           y,
			Predictions: ntm.Predictions(machines),
			HeadWeights: ntm.HeadWeights(machines),
		})
	}

	http.HandleFunc("/", root(runs))
	// ListenAndServe always returns a non-nil error. Exit non-zero instead of
	// letting main return success when the server cannot start (e.g. the
	// port is already in use).
	log.Fatal(http.ListenAndServe(":9000", nil))
}