Ejemplo n.º 1
0
// BatchR evaluates the gate on n input vectors packed end-to-end in in,
// propagating R-derivatives with respect to rv.
//
// Without a peephole, the packed input goes through the dense layer and
// then the activation. With a peephole, each packed vector is split in
// two. Its trailing len(l.Peephole.Vector) entries are multiplied (via
// autofunc.MulR) by the peephole weights. Its leading entries are batched
// through the dense layer. The two results are summed before the
// activation is applied.
func (l *lstmGate) BatchR(rv autofunc.RVector, in autofunc.RResult, n int) autofunc.RResult {
	if l.Peephole == nil {
		return l.Activation.ApplyR(rv, l.Dense.BatchR(rv, in, n))
	}
	return autofunc.PoolR(in, func(pooled autofunc.RResult) autofunc.RResult {
		stride := len(pooled.Output()) / n
		peepLen := len(l.Peephole.Vector)
		peepWeights := autofunc.NewRVariable(l.Peephole, rv)

		var denseParts, peepParts []autofunc.RResult
		for i := 0; i < n; i++ {
			vecStart, vecEnd := i*stride, (i+1)*stride
			// Everything before split feeds the dense layer; the tail
			// is the peephole input for this batch entry.
			split := vecEnd - peepLen
			denseParts = append(denseParts, autofunc.SliceR(pooled, vecStart, split))
			tail := autofunc.SliceR(pooled, split, vecEnd)
			peepParts = append(peepParts, autofunc.MulR(peepWeights, tail))
		}

		denseOut := l.Dense.BatchR(rv, autofunc.ConcatR(denseParts...), n)
		peepOut := autofunc.ConcatR(peepParts...)
		return l.Activation.ApplyR(rv, autofunc.AddR(peepOut, denseOut))
	})
}
Ejemplo n.º 2
0
// ApplyBlockR applies the block to a batch of inputs with R-operator
// support. Each in[i] is paired with the corresponding pooled state
// derived from s.
//
// The reset and update gates both see every input placed next to its
// previous state. The candidate value comes from the inputValue layer,
// which sees every input placed next to its reset-masked state. The new
// state is
//
//	u*prev + (1-u)*candidate
//
// where u is the update gate output.
func (g *GRU) ApplyBlockR(rv autofunc.RVector, s []RState, in []autofunc.RResult) BlockRResult {
	batch := len(in)
	stateVars, prevStates := PoolVecRStates(s)

	// Interleave inputs with previous states: in[0], prev[0], in[1], prev[1], ...
	var joined []autofunc.RResult
	for i := range prevStates {
		joined = append(joined, in[i], prevStates[i])
	}
	gateIn := autofunc.ConcatR(joined...)
	prevJoined := autofunc.ConcatR(prevStates...)

	reset := g.resetGate.BatchR(rv, gateIn, batch)
	update := g.updateGate.BatchR(rv, gateIn, batch)

	candidate := autofunc.PoolSplitR(batch, autofunc.MulR(reset, prevJoined),
		func(resetStates []autofunc.RResult) autofunc.RResult {
			var candIn []autofunc.RResult
			for i := range in {
				candIn = append(candIn, in[i], resetStates[i])
			}
			return g.inputValue.BatchR(rv, autofunc.ConcatR(candIn...), batch)
		})

	// update is used twice below, so it is pooled once.
	next := autofunc.PoolR(update, func(u autofunc.RResult) autofunc.RResult {
		oneMinusU := autofunc.AddScalerR(autofunc.ScaleR(u, -1), 1)
		kept := autofunc.MulR(u, prevJoined)
		fresh := autofunc.MulR(oneMinusU, candidate)
		return autofunc.AddR(kept, fresh)
	})

	return &gruRResult{
		InStates: stateVars,
		Output:   next,
	}
}
Ejemplo n.º 3
0
// ApplyBlockR is like ApplyBlock, but with support for
// the R operator.
//
// For each batch entry i, the gate input is in[i] followed by the
// previous output and the previous internal (cell) state taken from s[i].
// The new cell state is
//
//	rememberGate*lastState + inputValue*inputGate
//
// The new output is outputGate*tanh(newState). The output gate's input
// uses the new cell state in place of the previous one.
//
// The Variables wrapping each incoming state are collected into
// InternalPool and LastOutPool on the returned result. NOTE(review):
// presumably this lets lstmRResult's backward pass reach the input
// states; confirm against lstmRResult.
//
// Every element of s must be an lstmRState; otherwise the type assertion
// panics.
func (l *LSTM) ApplyBlockR(rv autofunc.RVector, s []RState, in []autofunc.RResult) BlockRResult {
	// Filled by the PoolAllR callback below, one Variable per batch entry.
	var internalPool, lastOutPool []*autofunc.Variable
	res := autofunc.PoolAllR(in, func(in []autofunc.RResult) autofunc.RResult {
		var lastOutPoolR []*autofunc.RVariable
		var weavedInputs []autofunc.RResult
		var internalResults []autofunc.RResult
		for i, sObj := range s {
			state := sObj.(lstmRState)
			internalVar := &autofunc.Variable{Vector: state.Internal}
			lastOutVar := &autofunc.Variable{Vector: state.Output}

			internalPool = append(internalPool, internalVar)
			lastOutPool = append(lastOutPool, lastOutVar)

			// Attach the incoming R-derivatives of the state to the
			// freshly created Variables.
			internalR := &autofunc.RVariable{
				Variable:   internalVar,
				ROutputVec: state.InternalR,
			}
			lastOutR := &autofunc.RVariable{
				Variable:   lastOutVar,
				ROutputVec: state.OutputR,
			}

			lastOutPoolR = append(lastOutPoolR, lastOutR)
			// Per-entry gate input layout: input, previous output,
			// previous cell state.
			weavedInputs = append(weavedInputs, in[i], lastOutR, internalR)
			internalResults = append(internalResults, internalR)
		}

		gateIn := autofunc.ConcatR(weavedInputs...)
		inValue := l.inputValue.BatchR(rv, gateIn, len(in))
		inGate := l.inputGate.BatchR(rv, gateIn, len(in))
		rememberGate := l.rememberGate.BatchR(rv, gateIn, len(in))

		lastState := autofunc.ConcatR(internalResults...)
		newState := autofunc.AddR(autofunc.MulR(rememberGate, lastState),
			autofunc.MulR(inValue, inGate))

		// newState has three consumers below (output-gate input, tanh,
		// and the returned concat), so it is pooled.
		return autofunc.PoolR(newState, func(newState autofunc.RResult) autofunc.RResult {
			// Rebuild the gate input with the new cell state in place of
			// the previous one for the output gate.
			var newWeaved []autofunc.RResult
			for i, state := range autofunc.SplitR(len(in), newState) {
				newWeaved = append(newWeaved, in[i], lastOutPoolR[i], state)
			}
			newGateIn := autofunc.ConcatR(newWeaved...)
			outGate := l.outputGate.BatchR(rv, newGateIn, len(in))
			outValues := neuralnet.HyperbolicTangent{}.ApplyR(rv, newState)
			// Layout: all new cell states first, then all outputs; this is
			// what splitLSTMOutput takes apart below.
			return autofunc.ConcatR(newState, autofunc.MulR(outGate, outValues))
		})
	})

	states, outs := splitLSTMOutput(len(in), res.Output())
	statesR, outsR := splitLSTMOutput(len(in), res.ROutput())
	return &lstmRResult{
		CellStates:   states,
		RCellStates:  statesR,
		OutputVecs:   outs,
		ROutputVecs:  outsR,
		InternalPool: internalPool,
		LastOutPool:  lastOutPool,
		JoinedOut:    res,
	}
}