func (l *lstmGate) BatchR(rv autofunc.RVector, in autofunc.RResult, n int) autofunc.RResult { if l.Peephole == nil { return l.Activation.ApplyR(rv, l.Dense.BatchR(rv, in, n)) } return autofunc.PoolR(in, func(in autofunc.RResult) autofunc.RResult { vecSize := len(in.Output()) / n var weightedInputs []autofunc.RResult var peepholed []autofunc.RResult peephole := autofunc.NewRVariable(l.Peephole, rv) for i := 0; i < n; i++ { start := vecSize * i weightedEnd := start + vecSize - len(l.Peephole.Vector) weightedInputs = append(weightedInputs, autofunc.SliceR(in, start, weightedEnd)) peepholeMe := autofunc.SliceR(in, weightedEnd, (i+1)*vecSize) peepholed = append(peepholed, autofunc.MulR(peephole, peepholeMe)) } weighted := l.Dense.BatchR(rv, autofunc.ConcatR(weightedInputs...), n) joinedPeep := autofunc.ConcatR(peepholed...) return l.Activation.ApplyR(rv, autofunc.AddR(joinedPeep, weighted)) }) }
// ApplyBlockR applies the block to an input. func (g *GRU) ApplyBlockR(rv autofunc.RVector, s []RState, in []autofunc.RResult) BlockRResult { stateVars, stateRes := PoolVecRStates(s) var gateInputs []autofunc.RResult for i, x := range stateRes { gateInputs = append(gateInputs, in[i], x) } n := len(in) gateInput := autofunc.ConcatR(gateInputs...) stateIn := autofunc.ConcatR(stateRes...) resetMask := g.resetGate.BatchR(rv, gateInput, n) updateMask := g.updateGate.BatchR(rv, gateInput, n) maskedByReset := autofunc.MulR(resetMask, stateIn) inputValue := autofunc.PoolSplitR(n, maskedByReset, func(newStates []autofunc.RResult) autofunc.RResult { var newGateInputs []autofunc.RResult for i, input := range in { newGateInputs = append(newGateInputs, input, newStates[i]) } newIn := autofunc.ConcatR(newGateInputs...) return g.inputValue.BatchR(rv, newIn, n) }) newState := autofunc.PoolR(updateMask, func(umask autofunc.RResult) autofunc.RResult { updateComplement := autofunc.AddScalerR(autofunc.ScaleR(umask, -1), 1) return autofunc.AddR(autofunc.MulR(umask, stateIn), autofunc.MulR(updateComplement, inputValue)) }) return &gruRResult{ InStates: stateVars, Output: newState, } }
// ApplyBlockR is like ApplyBlock, but with support for // the R operator. func (l *LSTM) ApplyBlockR(rv autofunc.RVector, s []RState, in []autofunc.RResult) BlockRResult { var internalPool, lastOutPool []*autofunc.Variable res := autofunc.PoolAllR(in, func(in []autofunc.RResult) autofunc.RResult { var lastOutPoolR []*autofunc.RVariable var weavedInputs []autofunc.RResult var internalResults []autofunc.RResult for i, sObj := range s { state := sObj.(lstmRState) internalVar := &autofunc.Variable{Vector: state.Internal} lastOutVar := &autofunc.Variable{Vector: state.Output} internalPool = append(internalPool, internalVar) lastOutPool = append(lastOutPool, lastOutVar) internalR := &autofunc.RVariable{ Variable: internalVar, ROutputVec: state.InternalR, } lastOutR := &autofunc.RVariable{ Variable: lastOutVar, ROutputVec: state.OutputR, } lastOutPoolR = append(lastOutPoolR, lastOutR) weavedInputs = append(weavedInputs, in[i], lastOutR, internalR) internalResults = append(internalResults, internalR) } gateIn := autofunc.ConcatR(weavedInputs...) inValue := l.inputValue.BatchR(rv, gateIn, len(in)) inGate := l.inputGate.BatchR(rv, gateIn, len(in)) rememberGate := l.rememberGate.BatchR(rv, gateIn, len(in)) lastState := autofunc.ConcatR(internalResults...) newState := autofunc.AddR(autofunc.MulR(rememberGate, lastState), autofunc.MulR(inValue, inGate)) return autofunc.PoolR(newState, func(newState autofunc.RResult) autofunc.RResult { var newWeaved []autofunc.RResult for i, state := range autofunc.SplitR(len(in), newState) { newWeaved = append(newWeaved, in[i], lastOutPoolR[i], state) } newGateIn := autofunc.ConcatR(newWeaved...) outGate := l.outputGate.BatchR(rv, newGateIn, len(in)) outValues := neuralnet.HyperbolicTangent{}.ApplyR(rv, newState) return autofunc.ConcatR(newState, autofunc.MulR(outGate, outValues)) }) }) states, outs := splitLSTMOutput(len(in), res.Output()) statesR, outsR := splitLSTMOutput(len(in), res.ROutput()) return &lstmRResult{ CellStates: states, RCellStates: statesR, OutputVecs: outs, ROutputVecs: outsR, InternalPool: internalPool, LastOutPool: lastOutPool, JoinedOut: res, } }