Ejemplo n.º 1
0
Archivo: main.go Proyecto: philipz/ntm
func main() {
	// Poem-completion demo: load the encoded Tang-poetry corpus, build an NTM
	// controller, install its weights with assignWeights, then let predict
	// complete a poem from the first character of each line and print the
	// result with showPrediction.
	flag.Parse()

	gen, err := poem.NewGenerator("data/quantangshi3000.int")
	if err != nil {
		log.Fatalf("%v", err)
	}

	// Controller dimensions. They are assumed to match the weights that
	// assignWeights installs — TODO confirm against the training run.
	const (
		hiddenUnits = 512
		headCount   = 8
		memN        = 128
		memM        = 32
	)
	ctrl := ntm.NewEmptyController1(gen.InputSize(), gen.OutputSize(), hiddenUnits, headCount, memN, memM)
	assignWeights(ctrl)

	// Prompt: one row per poem line, lineWidth slots per row. Only the first
	// slot of each row is given; the empty strings are presumably the slots
	// predict fills in.
	//
	// Other prompts that have been tried (leading characters, line width):
	//   "阿扁无罪" (5), "九二共识" (7), "十四日罢免蔡正元" (5), "全力支持柯匹文哲" (7).
	const lineWidth = 5
	leading := []rune("红春愿此")
	prompt := make([][]string, len(leading))
	for row, ch := range leading {
		prompt[row] = make([]string, lineWidth)
		prompt[row][0] = string(ch)
	}

	// Fixed seed so repeated runs produce the same poem.
	rand.Seed(15)
	pred := predict(ctrl, prompt, gen)
	showPrediction(pred, gen, prompt)
}
Ejemplo n.º 2
0
Archivo: main.go Proyecto: philipz/ntm
func main() {
	// Evaluate saved NTM weights on the n-gram task.
	//   1. Draw an n-gram distribution.
	//   2. Average the loss over sampletimes sequences drawn from it.
	//   3. Serve the last sample's inputs, targets, predictions and head
	//      weights over HTTP on :9000 through root.
	flag.Parse()

	h1Size := 100
	numHeads := 1
	n := 128
	m := 20
	c := ntm.NewEmptyController1(1, 1, h1Size, numHeads, n, m)
	weightsFromFile(c)

	runs := make([]Run, 0)
	for i := 0; i < 1; i++ {
		prob := ngram.GenProb()
		var l float64
		var x, y [][]float64
		var machines []*ntm.NTM
		sampletimes := 100
		for j := 0; j < sampletimes; j++ {
			x, y = ngram.GenSeq(prob)
			model := &ntm.LogisticModel{Y: y}
			machines = ntm.ForwardBackward(c, x, model)
			l += model.Loss(ntm.Predictions(machines))
			if (j+1)%10 == 0 {
				// Running mean of the loss over the samples seen so far.
				log.Printf("%d %d %f", i, j+1, l/float64(j+1))
			}
		}
		l = l / float64(sampletimes)

		// Only the final sample's sequences and machine states are kept for display.
		r := Run{
			Conf:        RunConf{Prob: prob},
			BitsPerSeq:  l,
			X:           x,
			Y:           y,
			Predictions: ntm.Predictions(machines),
			HeadWeights: ntm.HeadWeights(machines),
		}
		runs = append(runs, r)
		//log.Printf("x: %v", x)
		//log.Printf("y: %v", y)
		//log.Printf("predictions: %s", ntm.Sprint2(ntm.Predictions(machines)))
	}

	http.HandleFunc("/", root(runs))
	// ListenAndServe only returns on failure, and always with a non-nil error.
	// Exit non-zero so a failed bind (e.g. port already in use) is not
	// reported as success.
	log.Fatal(http.ListenAndServe(":9000", nil))
}
Ejemplo n.º 3
0
Archivo: main.go Proyecto: philipz/ntm
func main() {
	// Evaluate saved NTM weights on the copy task at several sequence lengths.
	// Each run logs the per-bit loss and is kept for display; the runs are then
	// served over HTTP on :9000 through root.
	flag.Parse()
	vectorSize := 8
	h1Size := 100
	numHeads := 1
	n := 128
	m := 20
	c := ntm.NewEmptyController1(vectorSize+2, vectorSize, h1Size, numHeads, n, m)

	// copy alone would silently load a truncated or partial weight vector if
	// the saved weights were produced by a controller of a different size.
	weights := c.WeightsVal()
	saved := weightsFromFile()
	if len(saved) != len(weights) {
		log.Fatalf("weights file has %d values, controller expects %d", len(saved), len(weights))
	}
	copy(weights, saved)

	seqLens := []int{10, 20, 30, 50, 120}
	runs := make([]Run, 0, len(seqLens))
	for _, seql := range seqLens {
		x, y := copytask.GenSeq(seql, vectorSize)
		model := &ntm.LogisticModel{Y: y}
		machines := ntm.ForwardBackward(c, x, model)
		l := model.Loss(ntm.Predictions(machines))
		// Normalise by the number of target bits (rows x columns of y).
		bps := l / float64(len(y)*len(y[0]))
		log.Printf("sequence length: %d, loss: %f", seql, bps)

		r := Run{
			SeqLen:      seql,
			BitsPerSeq:  bps,
			X:           x,
			Y:           y,
			Predictions: ntm.Predictions(machines),
			HeadWeights: ntm.HeadWeights(machines),
		}
		runs = append(runs, r)
		//log.Printf("x: %v", x)
		//log.Printf("y: %v", y)
		//log.Printf("predictions: %s", ntm.Sprint2(ntm.Predictions(machines)))
	}

	http.HandleFunc("/", root(runs))
	// ListenAndServe only returns on failure, and always with a non-nil error.
	// Exit non-zero so a failed bind is not reported as success.
	log.Fatal(http.ListenAndServe(":9000", nil))
}
Ejemplo n.º 4
0
Archivo: main.go Proyecto: philipz/ntm
func main() {
	// Train an NTM controller with RMSProp on the repeat-copy task ("bt"
	// generator), forever. The /Weights, /Loss and /PrintDebug endpoints on
	// port 8096 forward requests over channels that are presumably serviced by
	// handleHTTP inside the training loop.
	flag.Parse()
	if *cpuprofile != "" {
		f, err := os.Create(*cpuprofile)
		if err != nil {
			log.Fatal(err)
		}
		if err := pprof.StartCPUProfile(f); err != nil {
			log.Fatal(err)
		}
		// NOTE(review): the training loop below never returns, so this deferred
		// call only runs if main is changed to exit normally. Killing the
		// process leaves the profile unflushed.
		defer pprof.StopCPUProfile()
	}

	http.HandleFunc("/Weights", func(w http.ResponseWriter, r *http.Request) {
		reply := make(chan []byte)
		weightsChan <- reply
		if _, err := w.Write(<-reply); err != nil {
			log.Printf("writing /Weights response: %v", err)
		}
	})
	http.HandleFunc("/Loss", func(w http.ResponseWriter, r *http.Request) {
		reply := make(chan []float64)
		lossChan <- reply
		if err := json.NewEncoder(w).Encode(<-reply); err != nil {
			log.Printf("encoding /Loss response: %v", err)
		}
	})
	http.HandleFunc("/PrintDebug", func(w http.ResponseWriter, r *http.Request) {
		printDebugChan <- struct{}{}
	})
	port := 8096
	go func() {
		log.Printf("Listening on port %d", port)
		if err := http.ListenAndServe(fmt.Sprintf(":%d", port), nil); err != nil {
			log.Fatalf("%v", err)
		}
	}()

	var seed int64 = 16
	rand.Seed(seed)

	// One tiny sample, used only to size the controller's input and output.
	genFunc := "bt"
	x, y := repeatcopy.G[genFunc](1, 1)
	h1Size := 100
	numHeads := 2
	n := 128
	m := 20
	c := ntm.NewEmptyController1(len(x[0]), len(y[0]), h1Size, numHeads, n, m)
	// WeightsVal is written in place, so this initialises the controller
	// uniformly in [-0.5, 0.5).
	weights := c.WeightsVal()
	for i := range weights {
		weights[i] = rand.Float64() - 0.5
	}

	// Kept non-nil, presumably so /Loss encodes [] rather than null before the
	// first report.
	losses := make([]float64, 0)
	doPrint := false

	rmsp := ntm.NewRMSProp(c)
	log.Printf("genFunc: %s, seed: %d, numweights: %d, numHeads: %d", genFunc, seed, len(c.WeightsVal()), c.NumHeads())
	for i := 1; ; i++ {
		x, y := repeatcopy.G[genFunc](rand.Intn(10)+1, rand.Intn(10)+1)
		model := &ntm.LogisticModel{Y: y}
		machines := rmsp.Train(x, model, 0.95, 0.5, 1e-3, 1e-3)
		l := model.Loss(ntm.Predictions(machines))
		if i%1000 == 0 {
			bpc := l / float64(len(y)*len(y[0]))
			losses = append(losses, bpc)
			log.Printf("%d, bpc: %f, seq length: %d", i, bpc, len(y))
		}

		handleHTTP(c, losses, &doPrint)

		if i%1000 == 0 && doPrint {
			printDebug(y, machines)
		}
	}
}
Ejemplo n.º 5
0
Archivo: main.go Proyecto: philipz/ntm
func main() {
	// Train an NTM controller with RMSProp on the Tang-poetry corpus, forever,
	// using a multinomial output model. Mean bits-per-character over
	// reportEvery iterations is appended to losses and logged. The /Weights,
	// /Loss and /PrintDebug endpoints on port 8085 forward requests over
	// channels that are presumably serviced by handleHTTP inside the loop.
	flag.Parse()
	blas64.Use(cgo.Implementation{})

	if *cpuprofile != "" {
		f, err := os.Create(*cpuprofile)
		if err != nil {
			log.Fatal(err)
		}
		if err := pprof.StartCPUProfile(f); err != nil {
			log.Fatal(err)
		}
		// NOTE(review): the training loop below never returns, so this deferred
		// call only runs if main is changed to exit normally. Killing the
		// process leaves the profile unflushed.
		defer pprof.StopCPUProfile()
	}

	http.HandleFunc("/Weights", func(w http.ResponseWriter, r *http.Request) {
		reply := make(chan []byte)
		weightsChan <- reply
		if _, err := w.Write(<-reply); err != nil {
			log.Printf("writing /Weights response: %v", err)
		}
	})
	http.HandleFunc("/Loss", func(w http.ResponseWriter, r *http.Request) {
		reply := make(chan []float64)
		lossChan <- reply
		if err := json.NewEncoder(w).Encode(<-reply); err != nil {
			log.Printf("encoding /Loss response: %v", err)
		}
	})
	http.HandleFunc("/PrintDebug", func(w http.ResponseWriter, r *http.Request) {
		printDebugChan <- struct{}{}
	})
	port := 8085
	go func() {
		log.Printf("Listening on port %d", port)
		if err := http.ListenAndServe(fmt.Sprintf(":%d", port), nil); err != nil {
			log.Fatalf("%v", err)
		}
	}()

	var seed int64 = 5
	rand.Seed(seed)
	log.Printf("seed: %d", seed)

	gen, err := poem.NewGenerator("data/quantangshi3000.int")
	if err != nil {
		log.Fatalf("%v", err)
	}
	h1Size := 512
	numHeads := 8
	n := 128
	m := 32
	c := ntm.NewEmptyController1(gen.InputSize(), gen.OutputSize(), h1Size, numHeads, n, m)
	// WeightsVal is written in place, so this initialises the controller
	// uniformly in [-0.5, 0.5).
	weights := c.WeightsVal()
	for i := range weights {
		weights[i] = rand.Float64() - 0.5
	}

	// Kept non-nil, presumably so /Loss encodes [] rather than null before the
	// first report.
	losses := make([]float64, 0)
	doPrint := false

	rmsp := ntm.NewRMSProp(c)
	log.Printf("numweights: %d", len(c.WeightsVal()))

	// Number of iterations whose bits-per-character are averaged per report.
	const reportEvery = 100
	var bpcSum float64
	for i := 1; ; i++ {
		x, y := gen.GenSeq()
		machines := rmsp.Train(x, &ntm.MultinomialModel{Y: y}, 0.95, 0.5, 1e-3, 1e-3)

		// Only the part of the sequence after index numChar is scored.
		// NOTE(review): y[numChar+1:] holds len(y)-numChar-1 elements, which is
		// numChar-1 when len(y) is even; confirm numChar is the intended divisor.
		numChar := len(y) / 2
		l := (&ntm.MultinomialModel{Y: y[numChar+1:]}).Loss(ntm.Predictions(machines[numChar+1:]))
		bpcSum += l / float64(numChar)

		if i%reportEvery == 0 {
			avgBPC := bpcSum / float64(reportEvery)
			bpcSum = 0
			losses = append(losses, avgBPC)
			log.Printf("%d, bpc: %f, seq length: %d", i, avgBPC, len(y))
		}

		handleHTTP(c, losses, &doPrint)

		if i%10 == 0 && doPrint {
			printDebug(y, machines)
		}
	}
}
Ejemplo n.º 6
0
Archivo: main.go Proyecto: philipz/ntm
func main() {
	// Train an NTM controller with RMSProp on the n-gram task, forever. Every
	// 1000 iterations, the mean loss over samn sequences from a fresh n-gram
	// distribution is appended to losses and logged. The /Weights, /Loss and
	// /PrintDebug endpoints on port 8087 forward requests over channels that
	// are presumably serviced by handleHTTP inside the loop.
	flag.Parse()
	if *cpuprofile != "" {
		f, err := os.Create(*cpuprofile)
		if err != nil {
			log.Fatal(err)
		}
		if err := pprof.StartCPUProfile(f); err != nil {
			log.Fatal(err)
		}
		// NOTE(review): the training loop below never returns, so this deferred
		// call only runs if main is changed to exit normally. Killing the
		// process leaves the profile unflushed.
		defer pprof.StopCPUProfile()
	}

	http.HandleFunc("/Weights", func(w http.ResponseWriter, r *http.Request) {
		reply := make(chan []byte)
		weightsChan <- reply
		if _, err := w.Write(<-reply); err != nil {
			log.Printf("writing /Weights response: %v", err)
		}
	})
	http.HandleFunc("/Loss", func(w http.ResponseWriter, r *http.Request) {
		reply := make(chan []float64)
		lossChan <- reply
		if err := json.NewEncoder(w).Encode(<-reply); err != nil {
			log.Printf("encoding /Loss response: %v", err)
		}
	})
	http.HandleFunc("/PrintDebug", func(w http.ResponseWriter, r *http.Request) {
		printDebugChan <- struct{}{}
	})
	port := 8087
	go func() {
		log.Printf("Listening on port %d", port)
		if err := http.ListenAndServe(fmt.Sprintf(":%d", port), nil); err != nil {
			log.Fatalf("%v", err)
		}
	}()

	var seed int64 = 7
	rand.Seed(seed)

	h1Size := 100
	numHeads := 1
	n := 128
	m := 20
	c := ntm.NewEmptyController1(1, 1, h1Size, numHeads, n, m)
	// WeightsVal is written in place, so this initialises the controller
	// uniformly in [-0.5, 0.5).
	weights := c.WeightsVal()
	for i := range weights {
		weights[i] = rand.Float64() - 0.5
	}

	// Kept non-nil, presumably so /Loss encodes [] rather than null before the
	// first report.
	losses := make([]float64, 0)
	doPrint := false

	rmsp := ntm.NewRMSProp(c)
	log.Printf("seed: %d, numweights: %d, numHeads: %d", seed, len(c.WeightsVal()), c.NumHeads())
	for i := 1; ; i++ {
		x, y := ngram.GenSeq(ngram.GenProb())
		machines := rmsp.Train(x, &ntm.LogisticModel{Y: y}, 0.95, 0.5, 1e-3, 1e-3)

		if i%1000 == 0 {
			// The evaluation assigns to x, y and machines (= rather than :=).
			// The printDebug call below therefore shows the last evaluation
			// sample, not the training sample.
			prob := ngram.GenProb()
			var l float64
			samn := 100
			for j := 0; j < samn; j++ {
				x, y = ngram.GenSeq(prob)
				model := &ntm.LogisticModel{Y: y}
				machines = ntm.ForwardBackward(c, x, model)
				l += model.Loss(ntm.Predictions(machines))
			}
			l = l / float64(samn)
			losses = append(losses, l)
			log.Printf("%d, bits-per-seq: %f", i, l)
		}

		handleHTTP(c, losses, &doPrint)

		if i%1000 == 0 && doPrint {
			printDebug(x, y, machines)
		}
	}
}