Exemplo n.º 1
0
// Sample generates rows(batch) samples with the Metropolis Hastings sample
// generation method and stores them in batch. The chain is started from a
// private copy of m.Initial, so m.Initial is NOT updated during the call.
//
// Before any sample is kept, m.BurnIn samples are generated and only the last
// of them is retained as the current chain location. A zero m.Rate is treated
// as 1. With a rate of 1, batch is handed to MetropolisHastings in one call;
// otherwise row 0 is a single step taken after burn-in and every later row is
// the last of rate further steps.
//
// The number of columns in batch must equal len(m.Initial), otherwise Sample
// will panic.
func (m MetropolisHastingser) Sample(batch *mat64.Dense) {
	nSamples, dim := batch.Dims()
	if dim != len(m.Initial) {
		panic("metropolishastings: length mismatch")
	}

	// A zero rate means "keep every step".
	stride := m.Rate
	if stride == 0 {
		stride = 1
	}

	// Choose the scratch storage so that burn-in needs as few calls to
	// MetropolisHastings as possible: reuse batch when it has at least stride
	// rows, otherwise allocate a stride-row matrix that can also serve the
	// thinning loop below.
	scratch, scratchRows := batch, nSamples
	if stride > nSamples {
		scratch, scratchRows = mat64.NewDense(stride, dim, nil), stride
	}

	// Work on a copy of the starting location so the receiver's slice is
	// never written to.
	loc := make([]float64, dim)
	copy(loc, m.Initial)

	// Burn-in: generate the requested number of samples in chunks of at most
	// scratchRows, carrying the last row of each chunk forward as the new
	// chain location.
	for left := m.BurnIn; left != 0; {
		n := min(scratchRows, left)
		MetropolisHastings(scratch.View(0, 0, n, dim).(*mat64.Dense), loc, m.Target, m.Proposal, m.Src)
		copy(loc, scratch.RawRowView(n-1))
		left -= n
	}

	// No thinning requested: the whole batch is produced by one call.
	if stride == 1 {
		MetropolisHastings(batch, loc, m.Target, m.Proposal, m.Src)
		return
	}

	// The thinning loop writes into batch while reading from scratch, so the
	// scratch storage must not alias batch from here on.
	if scratch == batch {
		scratch = mat64.NewDense(stride, dim, nil)
	}

	// The first kept sample is a single step from the post-burn-in location.
	MetropolisHastings(batch.View(0, 0, 1, dim).(*mat64.Dense), loc, m.Target, m.Proposal, m.Src)
	copy(loc, batch.RawRowView(0))

	// Each remaining sample is the final row of a fresh run of stride steps,
	// which also becomes the starting location for the next run.
	for row := 1; row < nSamples; row++ {
		MetropolisHastings(scratch, loc, m.Target, m.Proposal, m.Src)
		kept := scratch.RawRowView(stride - 1)
		batch.SetRow(row, kept)
		copy(loc, kept)
	}
}