Ejemplo n.º 1
0
// StdDevBatch predicts the standard deviation at a set of locations of x.
//
// Each row of x is one query location and must have g.inputDim columns,
// otherwise StdDevBatch panics with badInputLength. Results are written into
// std, which is allocated when nil and must otherwise have exactly one entry
// per row of x (panics with badStorage). The filled slice is returned.
//
// For a single query point x_* the predictive variance is
//
//	var_* = k(x_*, x_*) - k_*^T K^-1 k_*
//
// where k_* is the vector of kernel values between the training inputs and
// x_*, and K^-1 is applied through the Cholesky factorization g.cholK.
// Evaluating this for all r query points at once would build the full r×r
// predictive covariance, but only its diagonal is needed, so each diagonal
// entry is computed independently. The square roots are finally scaled by
// g.std so they are in the same units as the outputs.
func (g *GP) StdDevBatch(std []float64, x mat64.Matrix) []float64 {
	r, c := x.Dims()
	if c != g.inputDim {
		panic(badInputLength)
	}
	if std == nil {
		std = make([]float64, r)
	}
	if len(std) != r {
		panic(badStorage)
	}

	// Column i of kStar is paired with row i of x below; column i of tmp is
	// then K^-1 applied to that column.
	kStar := g.formKStar(x)
	var tmp mat64.Dense
	tmp.SolveCholesky(g.cholK, kStar)

	// Seed std[i] with k(x_*, x_*), then subtract k_*^T K^-1 k_* one row at a time.
	var tmp2 mat64.Vector
	row := make([]float64, c)
	for i := range std {
		for k := 0; k < c; k++ {
			row[k] = x.At(i, k)
		}
		std[i] = g.kernel.Distance(row, row)
		tmp2.MulVec(kStar.ColView(i).T(), tmp.ColView(i))
		// A 1×n row times an n-vector must give a 1×1 result. Either
		// dimension being wrong is an error, so the test uses ||; with &&
		// the check could never fire, because a Vector always has one column.
		if rt, ct := tmp2.Dims(); rt != 1 || ct != 1 {
			panic("bad size")
		}
		std[i] -= tmp2.At(0, 0)
		// NOTE(review): floating-point roundoff can push the variance slightly
		// below zero, in which case math.Sqrt returns NaN. Consider clamping.
		std[i] = math.Sqrt(std[i])
	}
	// Scale the standard deviations to be in the same units as y.
	floats.Scale(g.std, std)
	return std
}
Ejemplo n.º 2
0
// StdDev predicts the standard deviation of the function at x.
//
// x must have length g.inputDim, otherwise StdDev panics with badInputLength.
// The predictive variance at the query point x_* is
//
//	var_* = k(x_*, x_*) - k_*^T K^-1 k_*
//
// where entry j of k_* is the kernel value between training input j and x_*,
// and K^-1 is applied through the Cholesky factorization g.cholK. The square
// root of var_* is scaled by g.std so the result is in the units of the outputs.
func (g *GP) StdDev(x []float64) float64 {
	if len(x) != g.inputDim {
		panic(badInputLength)
	}

	// Kernel vector between every training input and the query point.
	nTrain := len(g.outputs)
	kVec := mat64.NewVector(nTrain, nil)
	for j := 0; j < nTrain; j++ {
		kVec.SetVec(j, g.kernel.Distance(g.inputs.RawRowView(j), x))
	}
	prior := g.kernel.Distance(x, x)

	// alpha = K^-1 k_*, then reduce = k_*^T alpha, which must be a 1×1 product.
	var alpha mat64.Vector
	alpha.SolveCholeskyVec(g.cholK, kVec)
	var reduce mat64.Vector
	reduce.MulVec(kVec.T(), &alpha)
	if nr, nc := reduce.Dims(); nr != 1 || nc != 1 {
		panic("bad size")
	}

	variance := prior - reduce.At(0, 0)
	return math.Sqrt(variance) * g.std
}