Exemplo n.º 1
0
// splitVectors breaks in into n pieces using autofunc.Split and returns the
// output vector of each piece, in order.
//
// The vector is wrapped in a throwaway autofunc.Variable only because
// autofunc.Split operates on autofunc.Result values. How pieces are sized (and
// what happens when len(in) is not a multiple of n) is decided by
// autofunc.Split — TODO confirm against its documentation.
func splitVectors(in linalg.Vector, n int) []linalg.Vector {
	source := &autofunc.Variable{Vector: in}
	pieces := autofunc.Split(n, source)

	vectors := make([]linalg.Vector, 0, len(pieces))
	for _, piece := range pieces {
		vectors = append(vectors, piece.Output())
	}
	return vectors
}
Exemplo n.º 2
0
// ApplyBlock applies the LSTM to a batch of inputs.
//
// Every element of s must hold an lstmState; the unchecked type assertion
// panics otherwise. The state at index i is combined with in[i], so s is
// assumed to be at least as long as in — TODO confirm callers guarantee
// len(s) == len(in).
//
// For each sequence, the previous internal state and previous output are
// wrapped in new autofunc.Variables. Those variables are returned in the
// result's InternalPool and LastOutPool fields.
//
// The graph built inside autofunc.PoolAll / autofunc.Pool is:
//
//	gateIn    = concat_i(in[i], lastOut[i], internal[i])
//	cellState = rememberGate(gateIn)*concat_i(internal[i])
//	            + inputValue(gateIn)*inputGate(gateIn)
//	outGateIn = concat_i(in[i], lastOut[i], cellState[i])
//	joined    = concat(cellState, outputGate(outGateIn)*tanh(cellState))
//
// splitLSTMOutput then splits joined into the CellStates and OutputVecs of the
// returned result.
func (l *LSTM) ApplyBlock(s []State, in []autofunc.Result) BlockResult {
	var internalPool, lastOutPool []*autofunc.Variable

	joined := autofunc.PoolAll(in, func(pooledIn []autofunc.Result) autofunc.Result {
		batchSize := len(pooledIn)

		var gateParts, prevInternals []autofunc.Result
		for idx := range s {
			st := s[idx].(lstmState)
			prevInternal := &autofunc.Variable{Vector: st.Internal}
			prevOut := &autofunc.Variable{Vector: st.Output}

			internalPool = append(internalPool, prevInternal)
			lastOutPool = append(lastOutPool, prevOut)

			// Gate inputs are interleaved per sequence: input, last output,
			// last internal state.
			gateParts = append(gateParts, pooledIn[idx], prevOut, prevInternal)
			prevInternals = append(prevInternals, prevInternal)
		}

		gateIn := autofunc.Concat(gateParts...)
		candidate := l.inputValue.Batch(gateIn, batchSize)
		admit := l.inputGate.Batch(gateIn, batchSize)
		keep := l.rememberGate.Batch(gateIn, batchSize)

		prevState := autofunc.Concat(prevInternals...)
		retained := autofunc.Mul(keep, prevState)
		admitted := autofunc.Mul(candidate, admit)
		cellState := autofunc.Add(retained, admitted)

		return autofunc.Pool(cellState, func(pooledState autofunc.Result) autofunc.Result {
			// The output gate sees the freshly computed cell state in place of
			// the previous one.
			var outGateParts []autofunc.Result
			for idx, seqState := range autofunc.Split(batchSize, pooledState) {
				outGateParts = append(outGateParts, pooledIn[idx], lastOutPool[idx], seqState)
			}
			outGateIn := autofunc.Concat(outGateParts...)
			outGate := l.outputGate.Batch(outGateIn, batchSize)
			squashed := neuralnet.HyperbolicTangent{}.Apply(pooledState)
			return autofunc.Concat(pooledState, autofunc.Mul(outGate, squashed))
		})
	})

	cellStates, outputs := splitLSTMOutput(len(in), joined.Output())
	return &lstmResult{
		CellStates:   cellStates,
		OutputVecs:   outputs,
		InternalPool: internalPool,
		LastOutPool:  lastOutPool,
		JoinedOut:    joined,
	}
}