Ejemplo n.º 1
0
Archivo: main.go Proyecto: philipz/ntm
// vecFromString encodes s as an input vector of length g.InputSize().
//
// If s has an entry in g.Dataset.Chars, the element at that index is set to 1
// and every other element stays 0. If s has no entry, the all-zero vector is
// returned.
func vecFromString(s string, g *poem.Generator) []float64 {
	vec := make([]float64, g.InputSize())
	if idx, found := g.Dataset.Chars[s]; found {
		vec[idx] = 1
	}
	return vec
}
Ejemplo n.º 2
0
Archivo: main.go Proyecto: philipz/ntm
// sample draws an index from output by cumulative-sum (inverse-CDF) sampling.
// It returns that index together with a vector of length gen.InputSize() that
// is 1 at the index and 0 elsewhere.
//
// output is assumed to hold non-negative weights summing to about 1 (TODO
// confirm against the controller's output layer). A floating-point sum of such
// weights can fall just short of the uniform draw r. Outputs that are not
// exactly normalized can fall well short of it. In either case no prefix sum
// reaches r. Rather than silently defaulting to index 0, the last index that
// carried positive weight is chosen. An empty output still yields index 0.
func sample(output []float64, gen *poem.Generator) ([]float64, int) {
	r := rand.Float64()

	characterIndex := -1
	lastPositive := 0 // fallback if the draw exceeds the total mass
	var sum float64
	for i, v := range output {
		if v > 0 {
			lastPositive = i
		}
		sum += v
		if sum >= r {
			characterIndex = i
			break
		}
	}
	if characterIndex < 0 {
		// r landed beyond the accumulated mass: pick the tail of the
		// distribution, not index 0.
		characterIndex = lastPositive
	}

	input := make([]float64, gen.InputSize())
	input[characterIndex] = 1
	return input, characterIndex
}
Ejemplo n.º 3
0
Archivo: main.go Proyecto: philipz/ntm
// predict runs a fresh NTM built from controller c over the poem template shi
// and returns the outputs accumulated by forward.
//
// shi is a list of lines, and each line is a list of entries (presumably one
// character each — TODO confirm against the dataset used by vecFromString).
// A non-empty entry is a fixed constraint. An empty entry marks a position the
// model fills in by sampling.
//
// The run makes two passes over shi:
//  1. Constraint pass: every entry is fed through vecFromString, with
//     gen.Linefeed() after each line. Then gen.EndOfPoem() is fed, followed
//     by one all-zero input of length gen.InputSize().
//  2. Generation pass: non-empty entries are fed again unchanged. Empty
//     entries are replaced by a sample drawn from the most recent output.
//     At the position between two lines, the next input is also sampled
//     from the model rather than forced to gen.Linefeed().
//
// forward is defined elsewhere in this file. Presumably it advances machine by
// one step and appends that step's output to output — confirm there.
func predict(c ntm.Controller, shi [][]string, gen *poem.Generator) [][]float64 {
	machine := ntm.MakeEmptyNTM(c)

	// Constraint pass: feed the poem constraints into the NTM.
	// numChar counts one step per entry plus one linefeed step per line. It
	// excludes the EndOfPoem and zero-input steps below.
	numChar := 0
	output := make([][]float64, 0)
	for _, line := range shi {
		for _, s := range line {
			numChar += 1
			input := vecFromString(s, gen)
			machine, output = forward(machine, input, output)
		}
		numChar += 1
		input := gen.Linefeed()
		machine, output = forward(machine, input, output)
	}
	// Mark the end of the constraint pass, then feed one all-zero input.
	input := gen.EndOfPoem()
	machine, output = forward(machine, input, output)

	input = make([]float64, gen.InputSize())
	machine, output = forward(machine, input, output)

	// Generation pass: follow the predictions of the NTM.
	// i starts at 1 and increments once per entry and once per between-line
	// step.
	i := 1
	for _, line := range shi {
		for _, s := range line {
			if s != "" {
				// Fixed constraint: feed it unchanged.
				input = vecFromString(s, gen)
			} else {
				// Open slot: sample from the latest output; the sampled index
				// itself is discarded.
				input, _ = sample(output[len(output)-1], gen)
			}
			machine, output = forward(machine, input, output)
			i++
		}

		// i first reaches numChar right after the last line's entries, so
		// no between-line step is sampled after the final line.
		if i >= numChar {
			break
		}
		// Between lines: let the model choose the next input.
		input, _ = sample(output[len(output)-1], gen)
		machine, output = forward(machine, input, output)
		i++
	}

	return output
}
Ejemplo n.º 4
0
Archivo: main.go Proyecto: philipz/ntm
// showPrediction logs the per-step predictions in pred and the poem they
// produce.
//
// pred holds one output vector per NTM step, as returned by predict. Steps
// with index < len(pred)/2+1 are treated as the constraint-feeding phase and
// get no character. From index len(pred)/2+1 onward, each step is paired with
// one slot of oripoem, laid out line by line with a poem.CharLinefeed slot
// after every line. This assumes pred is long enough to hold every oripoem
// slot (TODO confirm); otherwise indexing panics.
//
// The character shown for each step in the second half is chosen as follows:
//   - If the matching oripoem entry is non-empty, that constraint is used.
//   - Otherwise, the first candidate from gen.SortOutput(pred[i]) is used
//     (presumably the most probable first). Candidates are skipped if they
//     are poem.CharUnknown, poem.CharLinefeed, empty, or already chosen for
//     an earlier step.
//
// NOTE(review): constraint characters at *later* positions are not excluded
// from the duplicate check. This matches the original behavior.
//
// The function logs the top five candidates of every step (fewer if the
// output has fewer entries), then the generated poem.
func showPrediction(pred [][]float64, gen *poem.Generator, oripoem [][]string) {
	if len(pred) == 0 {
		return // nothing to show; slicing ps below would otherwise panic
	}
	start := len(pred)/2 + 1

	// Lay the poem constraints out on the same index axis as pred.
	poema := make([]string, len(pred))
	pos := start
	for _, line := range oripoem {
		for _, c := range line {
			if c != "" {
				poema[pos] = c
			}
			pos++
		}
		poema[pos] = poem.CharLinefeed
		pos++
	}

	// Pick one character per step. The set `used` replaces a rescan of ps per
	// candidate: it holds every character already assigned to an earlier step.
	const topN = 5
	ps := make([]string, len(pred))
	res := make([][]poem.Char, len(pred))
	used := make(map[string]bool)
	for i, p := range pred {
		sorted := gen.SortOutput(p)
		switch {
		case i < start:
			// Constraint-feeding phase: no character is shown.
		case poema[i] != "":
			ps[i] = poema[i]
			used[ps[i]] = true
		default:
			for _, c := range sorted {
				// "" is never picked: unassigned slots in ps are "", so the
				// original duplicate scan always treated it as taken.
				if c.S == "" || c.S == poem.CharUnknown || c.S == poem.CharLinefeed || used[c.S] {
					continue
				}
				ps[i] = c.S
				used[c.S] = true
				break
			}
		}

		top := sorted
		if len(top) > topN {
			top = top[:topN]
		}
		res[i] = top
	}

	// Print the probability densities.
	for i, chars := range res {
		log.Printf("%s -> %v", ps[i], chars)
		if i == len(res)/2 {
			log.Printf("-------------")
		}
	}

	// Print the final generated poem.
	buf := []byte("\n")
	for _, c := range ps[start:] {
		if c == poem.CharLinefeed {
			buf = append(buf, '\n')
		} else {
			buf = append(buf, c...)
		}
	}
	// Print, not Printf: generated text may contain '%' and must not be
	// interpreted as a format string.
	log.Print(string(buf))
}