func splitVectors(in linalg.Vector, n int) []linalg.Vector { res := autofunc.Split(n, &autofunc.Variable{Vector: in}) resList := make([]linalg.Vector, len(res)) for i, x := range res { resList[i] = x.Output() } return resList }
// ApplyBlock applies the LSTM to a batch of inputs. func (l *LSTM) ApplyBlock(s []State, in []autofunc.Result) BlockResult { var internalPool, lastOutPool []*autofunc.Variable res := autofunc.PoolAll(in, func(in []autofunc.Result) autofunc.Result { var weavedInputs []autofunc.Result var internalResults []autofunc.Result for i, sObj := range s { state := sObj.(lstmState) internalVar := &autofunc.Variable{Vector: state.Internal} lastOutVar := &autofunc.Variable{Vector: state.Output} internalPool = append(internalPool, internalVar) lastOutPool = append(lastOutPool, lastOutVar) weavedInputs = append(weavedInputs, in[i], lastOutVar, internalVar) internalResults = append(internalResults, internalVar) } gateIn := autofunc.Concat(weavedInputs...) inValue := l.inputValue.Batch(gateIn, len(in)) inGate := l.inputGate.Batch(gateIn, len(in)) rememberGate := l.rememberGate.Batch(gateIn, len(in)) lastState := autofunc.Concat(internalResults...) newState := autofunc.Add(autofunc.Mul(rememberGate, lastState), autofunc.Mul(inValue, inGate)) return autofunc.Pool(newState, func(newState autofunc.Result) autofunc.Result { var newWeaved []autofunc.Result for i, state := range autofunc.Split(len(in), newState) { newWeaved = append(newWeaved, in[i], lastOutPool[i], state) } newGateIn := autofunc.Concat(newWeaved...) outGate := l.outputGate.Batch(newGateIn, len(in)) outValues := neuralnet.HyperbolicTangent{}.Apply(newState) return autofunc.Concat(newState, autofunc.Mul(outGate, outValues)) }) }) states, outs := splitLSTMOutput(len(in), res.Output()) return &lstmResult{ CellStates: states, OutputVecs: outs, InternalPool: internalPool, LastOutPool: lastOutPool, JoinedOut: res, } }