コード例 #1
0
ファイル: normal_test.go プロジェクト: darrenmcc/stat
// TestMarginal checks MarginalNormal by sampling from the full distribution
// and comparing the sample mean and covariance of the selected dimensions with
// the analytic marginal. In failure messages "want" is the analytic marginal
// value and "got" is the sample estimate.
func TestMarginal(t *testing.T) {
	for _, test := range []struct {
		mu       []float64
		sigma    *mat64.SymDense
		marginal []int
	}{
		{
			mu:       []float64{2, 3, 4},
			sigma:    mat64.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}),
			marginal: []int{0},
		},
		{
			mu:       []float64{2, 3, 4},
			sigma:    mat64.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}),
			marginal: []int{0, 2},
		},
		{
			mu:       []float64{2, 3, 4, 5},
			sigma:    mat64.NewSymDense(4, []float64{2, 0.5, 3, 0.1, 0.5, 1, 0.6, 0.2, 3, 0.6, 10, 0.3, 0.1, 0.2, 0.3, 3}),
			marginal: []int{0, 3},
		},
	} {
		normal, ok := NewNormal(test.mu, test.sigma, nil)
		if !ok {
			t.Fatalf("Bad test, covariance matrix not positive definite")
		}
		marginal, ok := normal.MarginalNormal(test.marginal, nil)
		if !ok {
			t.Fatalf("Bad test, marginal matrix not positive definite")
		}
		dim := normal.Dim()
		nSamples := 1000000
		samps := mat64.NewDense(nSamples, dim, nil)
		for i := 0; i < nSamples; i++ {
			normal.Rand(samps.RawRowView(i))
		}
		estMean := make([]float64, dim)
		for i := range estMean {
			estMean[i] = stat.Mean(mat64.Col(nil, i, samps), nil)
		}
		// marginal.mu[i] corresponds to dimension test.marginal[i] of the
		// full distribution.
		for i, v := range test.marginal {
			if math.Abs(marginal.mu[i]-estMean[v]) > 1e-2 {
				t.Errorf("Mean mismatch: want: %v, got %v", marginal.mu[i], estMean[v])
			}
		}

		marginalCov := marginal.CovarianceMatrix(nil)
		estCov := stat.CovarianceMatrix(nil, samps, nil)
		for i, v1 := range test.marginal {
			for j, v2 := range test.marginal {
				c := marginalCov.At(i, j)
				ec := estCov.At(v1, v2)
				if math.Abs(c-ec) > 5e-2 {
					t.Errorf("Cov mismatch element i = %d, j = %d: want: %v, got %v", i, j, c, ec)
				}
			}
		}
	}
}
コード例 #2
0
// checkEntropy estimates the entropy of e as the sample average of -LogProb
// over x. It reports a test error when that estimate differs from e.Entropy()
// by more than tol, absolutely or relatively. i identifies the case in the
// error message.
func checkEntropy(t *testing.T, i int, x []float64, e entropyer, tol float64) {
	negLogProb := make([]float64, len(x))
	for j := range x {
		negLogProb[j] = -e.LogProb(x[j])
	}
	estimate := stat.Mean(negLogProb, nil)
	if floats.EqualWithinAbsOrRel(estimate, e.Entropy(), tol, tol) {
		return
	}
	t.Errorf("Entropy mismatch case %v: want: %v, got: %v", i, estimate, e.Entropy())
}
コード例 #3
0
ファイル: gPHydro.go プロジェクト: jgcarvalho/gPHydro
// Apply maps every residue (byte) of seq to its value in the scale. It returns
// both the raw per-residue values and their sliding-window averages.
//
// sw is the window width and should be a positive odd number, so that the
// window [i-sw/2, i+sw/2] is centred on residue i. The window is clamped to
// the sequence at both ends. Near the ends, or when sw is longer than the
// whole sequence, each average covers only the residues that exist.
// A residue missing from the scale contributes the map's zero value.
func (sc ScaleAA) Apply(seq string, sw int) ([]float64, []float64) {
	hydro := make([]float64, len(seq))
	for i := 0; i < len(seq); i++ {
		hydro[i] = sc[string(seq[i])]
	}

	half := sw / 2
	hydrosw := make([]float64, len(hydro))
	for i := range hydro {
		// Clamp both ends independently. A sequence shorter than the window
		// needs both clamps at once, and slicing past either end would panic.
		b, e := i-half, i+half+1
		if b < 0 {
			b = 0
		}
		if e > len(hydro) {
			e = len(hydro)
		}
		hydrosw[i] = stat.Mean(hydro[b:e], nil)
	}
	return hydro, hydrosw
}
コード例 #4
0
ファイル: normal_test.go プロジェクト: darrenmcc/stat
// TestNormRand draws samples from several bivariate normal distributions. It
// checks that the sample mean and covariance match the parameters each
// distribution was constructed with.
func TestNormRand(t *testing.T) {
	for _, test := range []struct {
		mean []float64
		cov  []float64
	}{
		{
			mean: []float64{0, 0},
			cov: []float64{
				1, 0,
				0, 1,
			},
		},
		{
			mean: []float64{0, 0},
			cov: []float64{
				1, 0.9,
				0.9, 1,
			},
		},
		{
			mean: []float64{6, 7},
			cov: []float64{
				5, 0.9,
				0.9, 2,
			},
		},
	} {
		dim := len(test.mean)
		cov := mat64.NewSymDense(dim, test.cov)
		n, ok := NewNormal(test.mean, cov, nil)
		if !ok {
			// n must not be used when construction fails; skip this case
			// instead of calling Rand on it.
			t.Errorf("bad covariance matrix")
			continue
		}

		nSamples := 1000000
		samps := mat64.NewDense(nSamples, dim, nil)
		for i := 0; i < nSamples; i++ {
			n.Rand(samps.RawRowView(i))
		}
		estMean := make([]float64, dim)
		for i := range estMean {
			estMean[i] = stat.Mean(mat64.Col(nil, i, samps), nil)
		}
		if !floats.EqualApprox(estMean, test.mean, 1e-2) {
			t.Errorf("Mean mismatch: want: %v, got %v", test.mean, estMean)
		}
		estCov := stat.CovarianceMatrix(nil, samps, nil)
		if !mat64.EqualApprox(estCov, cov, 1e-2) {
			t.Errorf("Cov mismatch: want: %v, got %v", cov, estCov)
		}
	}
}
コード例 #5
0
// checkSkewness estimates the skewness of the samples in x. It divides the
// third central moment about s.Mean() by s.StdDev() cubed. It reports a test
// error when the estimate differs from s.Skewness() by more than tol,
// absolutely or relatively. i identifies the case in the error message.
func checkSkewness(t *testing.T, i int, x []float64, s skewnesser, tol float64) {
	mean := s.Mean()
	std := s.StdDev()
	tmp := make([]float64, len(x))
	// j, not i: i is the case index reported below.
	for j, v := range x {
		tmp[j] = math.Pow(v-mean, 3)
	}
	mu3 := stat.Mean(tmp, nil)
	skewness := mu3 / math.Pow(std, 3)
	if !floats.EqualWithinAbsOrRel(skewness, s.Skewness(), tol, tol) {
		t.Errorf("Skewness mismatch case %v: want: %v, got: %v", i, skewness, s.Skewness())
	}
}
コード例 #6
0
// checkExKurtosis estimates the excess kurtosis of the samples in x. It
// divides the fourth central moment about e.Mean() by the square of
// stat.Variance(x, nil) and subtracts 3. It reports a test error when the
// estimate differs from e.ExKurtosis() by more than tol, absolutely or
// relatively. i identifies the case in the error message.
func checkExKurtosis(t *testing.T, i int, x []float64, e exKurtosiser, tol float64) {
	center := e.Mean()
	fourth := make([]float64, len(x))
	for k, sample := range x {
		fourth[k] = math.Pow(sample-center, 4)
	}
	sampleVar := stat.Variance(x, nil)
	mu4 := stat.Mean(fourth, nil)
	exKurt := mu4/(sampleVar*sampleVar) - 3
	if floats.EqualWithinAbsOrRel(exKurt, e.ExKurtosis(), tol, tol) {
		return
	}
	t.Errorf("ExKurtosis mismatch case %v: want: %v, got: %v", i, exKurt, e.ExKurtosis())
}
コード例 #7
0
ファイル: sample_test.go プロジェクト: sbinet/gonum-stat
// TestRejection checks Rejection by drawing from a Normal target through a
// wider Normal proposal. It compares the sample mean with the target mean.
func TestRejection(t *testing.T) {
	const trueMean = 3.0
	target := distuv.Normal{Mu: trueMean, Sigma: 2}
	proposal := distuv.Normal{Mu: 0, Sigma: 5}

	samples := make([]float64, 100000)
	Rejection(samples, target, proposal, 100, nil)
	if got := stat.Mean(samples, nil); math.Abs(got-trueMean) > 1e-2 {
		t.Errorf("Mean mismatch: Want %v, got %v", trueMean, got)
	}
}
コード例 #8
0
ファイル: sample_test.go プロジェクト: darrenmcc/stat
// TestImportance checks Importance by estimating the mean of a Normal target
// from samples drawn from a wider Normal proposal. Each sample is weighted by
// the importance weight that Importance writes.
func TestImportance(t *testing.T) {
	const trueMean = 3.0
	var (
		target   = dist.Normal{Mu: trueMean, Sigma: 2}
		proposal = dist.Normal{Mu: 0, Sigma: 5}
	)

	const nSamples = 100000
	x := make([]float64, nSamples)
	weights := make([]float64, nSamples)
	Importance(x, weights, target, proposal)

	if ev := stat.Mean(x, weights); math.Abs(ev-trueMean) > 1e-2 {
		t.Errorf("Mean mismatch: Want %v, got %v", trueMean, ev)
	}
}
コード例 #9
0
ファイル: sample_test.go プロジェクト: sbinet/gonum-stat
// TestMetropolisHastings checks MetropolisHastings by running a chain that
// targets a Normal. It discards a burn-in prefix and compares the mean of the
// remaining samples with the target mean.
func TestMetropolisHastings(t *testing.T) {
	const (
		trueMean = 3.0
		burnin   = 500
	)
	target := distuv.Normal{Mu: trueMean, Sigma: 2}
	proposal := condNorm{Sigma: 5}

	chain := make([]float64, 100000+burnin)
	MetropolisHastings(chain, 100, target, proposal, nil)

	// Only the post-burn-in part of the chain contributes to the estimate.
	ev := stat.Mean(chain[burnin:], nil)
	if math.Abs(ev-trueMean) > 1e-2 {
		t.Errorf("Mean mismatch: Want %v, got %v", trueMean, ev)
	}
}
コード例 #10
0
ファイル: exponential.go プロジェクト: darrenmcc/stat
// SuffStat computes the sufficient statistics of a set of samples to update
// the distribution. The sufficient statistic is stored in place in suffStat,
// and the effective number of samples is returned.
//
// The exponential distribution has one sufficient statistic, the rate
// estimate 1/mean, where mean is the (optionally weighted) sample mean.
//
// If weights is nil or empty, the weights are assumed to be 1 and the
// effective number of samples is len(samples); otherwise it is the sum of the
// weights. Panics if weights is non-empty and len(samples) != len(weights).
// Panics if len(suffStat) != 1.
func (Exponential) SuffStat(samples, weights, suffStat []float64) (nSamples float64) {
	switch {
	case len(weights) != 0 && len(samples) != len(weights):
		panic("dist: slice size mismatch")
	case len(suffStat) != 1:
		panic("exponential: wrong suffStat length")
	}

	nSamples = float64(len(samples))
	if len(weights) != 0 {
		nSamples = floats.Sum(weights)
	}

	suffStat[0] = 1 / stat.Mean(samples, weights)
	return nSamples
}
コード例 #11
0
ファイル: exponential.go プロジェクト: sbinet/gonum-stat
// SuffStat computes the sufficient statistics of a set of samples to update
// the distribution. The sufficient statistic is stored in place in suffStat,
// and the effective number of samples is returned.
//
// The exponential distribution has one sufficient statistic, the rate
// estimate 1/mean, where mean is the (optionally weighted) sample mean.
//
// If weights is nil or empty, the weights are assumed to be 1 and the
// effective number of samples is len(samples); otherwise it is the sum of the
// weights. Panics with badLength if weights is non-empty and
// len(samples) != len(weights). Panics with badSuffStat if len(suffStat) != 1.
func (Exponential) SuffStat(samples, weights, suffStat []float64) (nSamples float64) {
	unweighted := len(weights) == 0
	if !unweighted && len(samples) != len(weights) {
		panic(badLength)
	}
	if len(suffStat) != 1 {
		panic(badSuffStat)
	}

	if unweighted {
		nSamples = float64(len(samples))
	} else {
		nSamples = floats.Sum(weights)
	}

	rate := 1 / stat.Mean(samples, weights)
	suffStat[0] = rate
	return nSamples
}
コード例 #12
0
ファイル: sample_test.go プロジェクト: sbinet/gonum-stat
// TestRejection checks Rejection by sampling an axis-aligned Uniform target
// with random bounds through a broad Normal proposal. It compares the
// per-dimension sample means with the target's analytic mean.
func TestRejection(t *testing.T) {
	// Test by finding the expected value of a uniform.
	dim := 3
	bounds := make([]distmv.Bound, dim)
	for i := 0; i < dim; i++ {
		min := rand.NormFloat64()
		max := rand.NormFloat64()
		if min > max {
			min, max = max, min
		}
		bounds[i].Min = min
		bounds[i].Max = max
	}
	target := distmv.NewUniform(bounds, nil)
	mu := target.Mean(nil)

	muImp := make([]float64, dim)
	sigmaImp := mat64.NewSymDense(dim, nil)
	for i := 0; i < dim; i++ {
		sigmaImp.SetSym(i, i, 6)
	}
	proposal, ok := distmv.NewNormal(muImp, sigmaImp, nil)
	if !ok {
		t.Fatal("bad test, sigma not pos def")
	}

	nSamples := 1000
	batch := mat64.NewDense(nSamples, dim, nil)
	_, ok = Rejection(batch, target, proposal, 1000, nil)
	if !ok {
		// The batch is not usable, so the mean checks below would be meaningless.
		t.Fatal("Bad test, nan samples")
	}

	for i := 0; i < dim; i++ {
		// Rejection sampling yields unweighted draws, so the plain sample mean
		// estimates the target mean. A zero-filled weights slice would make
		// the weighted mean 0/0 = NaN, and a NaN comparison never fails.
		col := mat64.Col(nil, i, batch)
		ev := stat.Mean(col, nil)

		// The standard error of the mean of n draws from U(a, b) is
		// (b-a)/sqrt(12n). Allow five standard errors, but never less than 1e-2.
		width := bounds[i].Max - bounds[i].Min
		tol := math.Max(1e-2, 5*width/math.Sqrt(12*float64(nSamples)))
		if math.IsNaN(ev) || math.Abs(ev-mu[i]) > tol {
			t.Errorf("Mean mismatch: Want %v, got %v", mu[i], ev)
		}
	}
}
コード例 #13
0
ファイル: norm.go プロジェクト: cjslep/stat
// SuffStat computes the sufficient statistics of a set of samples to update
// the distribution. The sufficient statistics are stored in place, and the
// effective number of samples is returned.
//
// The normal distribution has two sufficient statistics, stored in order:
// the (optionally weighted) sample mean, and the square root of the
// uncorrected second central moment about that mean.
//
// If weights is nil or empty, the weights are assumed to be 1 and the
// effective number of samples is len(samples); otherwise it is the sum of the
// weights. Panics if weights is non-empty and len(samples) != len(weights).
// Panics if len(suffStat) != 2.
func (Normal) SuffStat(samples, weights, suffStat []float64) (nSamples float64) {
	weighted := len(weights) != 0
	if weighted && len(samples) != len(weights) {
		panic("dist: slice size mismatch")
	}
	if len(suffStat) != 2 {
		panic("dist: incorrect suffStat length")
	}

	if weighted {
		nSamples = floats.Sum(weights)
	} else {
		nSamples = float64(len(samples))
	}

	mean := stat.Mean(samples, weights)
	suffStat[0] = mean

	// stat.Moment is used instead of StdDev so that the spread estimate is
	// uncorrected (no n-1 denominator).
	secondMoment := stat.Moment(2, samples, mean, weights)
	suffStat[1] = math.Sqrt(secondMoment)
	return nSamples
}
コード例 #14
0
ファイル: sample_test.go プロジェクト: sbinet/gonum-stat
// compareNormal checks the rows of batch, weighted by weights, against want.
// The per-dimension mean must be within 1e-2 of want's mean, and the
// covariance matrix must be within 1.5e-1 of want's covariance. A nil weights
// slice is replaced by all-ones weights.
func compareNormal(t *testing.T, want *distmv.Normal, batch *mat64.Dense, weights []float64) {
	dim := want.Dim()
	wantMean := want.Mean(nil)
	wantCov := want.CovarianceMatrix(nil)

	if weights == nil {
		rows, _ := batch.Dims()
		weights = make([]float64, rows)
		for k := range weights {
			weights[k] = 1
		}
	}

	for d := 0; d < dim; d++ {
		got := stat.Mean(mat64.Col(nil, d, batch), weights)
		if math.Abs(got-wantMean[d]) > 1e-2 {
			t.Errorf("Mean mismatch: Want %v, got %v", wantMean[d], got)
		}
	}

	if gotCov := stat.CovarianceMatrix(nil, batch, weights); !mat64.EqualApprox(gotCov, wantCov, 1.5e-1) {
		t.Errorf("Covariance matrix mismatch")
	}
}
コード例 #15
0
ファイル: distribution_test.go プロジェクト: darrenmcc/stat
// testFullDist tests all of the functions of a fullDist against statistics
// estimated from 1e6 samples drawn with f.Rand. i identifies the test case in
// every error message. In messages, "want" is the sample estimate and "got" is
// the value reported by f.
func testFullDist(t *testing.T, f fullDist, i int) {
	tol := 1e-1
	const n = 1e6
	xs := make([]float64, n)
	for j := range xs {
		xs[j] = f.Rand()
	}
	sortedXs := make([]float64, n)
	copy(sortedXs, xs)
	sort.Float64s(sortedXs)
	tmp := make([]float64, n)

	// Mean check. When the estimate agrees, the exact mean is used for the
	// central moments below.
	mean := stat.Mean(xs, nil)
	if !floats.EqualWithinAbsOrRel(mean, f.Mean(), tol, tol) {
		t.Errorf("Mean mismatch case %v: want: %v, got: %v", i, mean, f.Mean())
	} else {
		mean = f.Mean()
	}

	// Median check.
	median := stat.Quantile(0.5, stat.Empirical, sortedXs, nil)
	if !floats.EqualWithinAbsOrRel(median, f.Median(), tol, tol) {
		t.Errorf("Median mismatch case %v: want: %v, got: %v", i, median, f.Median())
	}

	// Variance check.
	variance := stat.Variance(xs, nil)
	if !floats.EqualWithinAbsOrRel(variance, f.Variance(), tol, tol) {
		t.Errorf("Variance mismatch case %v: want: %v, got: %v", i, variance, f.Variance())
	} else {
		variance = f.Variance()
	}

	std := math.Sqrt(variance)
	if !floats.EqualWithinAbsOrRel(std, f.StdDev(), tol, tol) {
		t.Errorf("StdDev mismatch case %v: want: %v, got: %v", i, std, f.StdDev())
	} else {
		std = f.StdDev()
	}

	// Entropy check: the entropy is the expectation of -log p(x).
	for j, x := range xs {
		tmp[j] = -f.LogProb(x)
	}
	entropy := stat.Mean(tmp, nil)
	if !floats.EqualWithinAbsOrRel(entropy, f.Entropy(), tol, tol) {
		t.Errorf("Entropy mismatch case %v: want: %v, got: %v", i, entropy, f.Entropy())
	}

	// Excess Kurtosis check.
	for j, x := range xs {
		tmp[j] = math.Pow(x-mean, 4)
	}
	mu4 := stat.Mean(tmp, nil)
	kurtosis := mu4/(variance*variance) - 3
	if !floats.EqualWithinAbsOrRel(kurtosis, f.ExKurtosis(), tol, tol) {
		t.Errorf("ExKurtosis mismatch case %v: want: %v, got: %v", i, kurtosis, f.ExKurtosis())
	}

	// Skewness check.
	for j, x := range xs {
		tmp[j] = math.Pow(x-mean, 3)
	}
	mu3 := stat.Mean(tmp, nil)
	skewness := mu3 / math.Pow(std, 3)
	if !floats.EqualWithinAbsOrRel(skewness, f.Skewness(), tol, tol) {
		t.Errorf("Skewness mismatch case %v: want: %v, got: %v", i, skewness, f.Skewness())
	}

	// Quantile, CDF, and survival check.
	for _, p := range []float64{0.1, 0.25, 0.5, 0.75, 0.9} {
		x := f.Quantile(p)
		cdf := f.CDF(x)
		estCDF := stat.CDF(x, stat.Empirical, sortedXs, nil)
		if !floats.EqualWithinAbsOrRel(cdf, estCDF, tol, tol) {
			t.Errorf("CDF mismatch case %v at p = %v: want: %v, got: %v", i, p, estCDF, cdf)
		}
		if !floats.EqualWithinAbsOrRel(cdf, p, tol, tol) {
			t.Errorf("Quantile/CDF mismatch case %v: want: %v, got: %v", i, p, cdf)
		}
		if math.Abs(1-cdf-f.Survival(x)) > 1e-14 {
			t.Errorf("Survival/CDF mismatch case %v at p = %v: want: %v, got: %v", i, p, 1-cdf, f.Survival(x))
		}
	}

	// Prob and LogProb check.
	//
	// The support is split into m-1 bins of equal probability. For an
	// interior bin, the fraction of samples in the bin divided by the bin
	// width estimates the average density over the bin. That estimate is
	// compared with Prob at the bin's probability midpoint.
	//
	// The first and last bins are skipped for that comparison because they
	// may be unbounded. With 1e6 samples and 100 bins, each bin holds about
	// 1e4 samples, so the counting noise (~1%) stays well inside tol.
	m := 101
	bins := make([]float64, m)
	dividers := make([]float64, m)
	floats.Span(bins, 0, 1)
	for j, v := range bins {
		dividers[j] = f.Quantile(v)
	}
	counts := stat.Histogram(nil, dividers, sortedXs, nil)
	for j, c := range counts {
		at := f.Quantile((bins[j] + bins[j+1]) / 2)
		prob := f.Prob(at)
		if j > 0 && j < len(counts)-1 {
			estProb := c / n / (dividers[j+1] - dividers[j])
			if !floats.EqualWithinAbsOrRel(estProb, prob, tol, tol) {
				t.Errorf("Prob mismatch case %v at %v: want: %v, got: %v", i, at, estProb, prob)
				break
			}
		}
		if math.Abs(math.Log(prob)-f.LogProb(at)) > 1e-14 {
			t.Errorf("Prob and LogProb mismatch case %v at %v: want %v, got %v", i, at, math.Log(prob), f.LogProb(at))
			break
		}
	}
}
コード例 #16
0
// checkMean compares the sample mean of x with m.Mean(). It reports a test
// error when they differ by more than tol, absolutely or relatively. i
// identifies the case in the error message.
func checkMean(t *testing.T, i int, x []float64, m meaner, tol float64) {
	sampleMean := stat.Mean(x, nil)
	if floats.EqualWithinAbsOrRel(sampleMean, m.Mean(), tol, tol) {
		return
	}
	t.Errorf("Mean mismatch case %v: want: %v, got: %v", i, sampleMean, m.Mean())
}
コード例 #17
0
ファイル: gp.go プロジェクト: btracey/gaussproc
// AddBatch adds a set of training points to the GP. This call updates internal
// values needed for prediction, so it is more efficient to add samples
// as a batch.
//
// x holds one training input per row and y the corresponding outputs.
// AddBatch panics with badInOut if x and y have different numbers of rows,
// and with badInputLength if x does not have g.inputDim columns. It returns
// ErrSingular if the enlarged kernel matrix cannot be Cholesky-factorized.
//
// NOTE(review): on the ErrSingular path, g.inputs, g.outputs, g.mean, g.std
// and g.k have already been replaced with the enlarged data. g.cholK and
// g.sigInvY still describe the previous data. Confirm that callers do not
// keep using the GP after this error.
func (g *GP) AddBatch(x mat64.Matrix, y []float64) error {
	// Note: The outputs are stored scaled to have a mean of zero and a variance
	// of 1.

	// Verify input parameters
	rx, cx := x.Dims()
	ry := len(y)
	if rx != ry {
		panic(badInOut)
	}
	if cx != g.inputDim {
		panic(badInputLength)
	}
	// Number of training points already stored before this batch.
	nSamples := len(g.outputs)

	// Append the new data to the list of stored data.
	// Existing rows are copied into the top nSamples rows; the new rows fill
	// the bottom rx rows.
	inputs := mat64.NewDense(rx+nSamples, g.inputDim, nil)
	inputs.Copy(g.inputs)
	inputs.View(nSamples, 0, rx, g.inputDim).(*mat64.Dense).Copy(x)
	g.inputs = inputs
	// Rescale the output data to its original value, append the new data, and
	// then rescale to have mean 0 and variance of 1.
	for i, v := range g.outputs {
		g.outputs[i] = v*g.std + g.mean
	}
	g.outputs = append(g.outputs, y...)
	g.mean = stat.Mean(g.outputs, nil)
	// NOTE(review): if every output is equal, or there is only one output,
	// g.std may be 0 or NaN. The division below would then make the stored
	// outputs non-finite. Confirm that callers never hit this case.
	g.std = stat.StdDev(g.outputs, nil)
	for i, v := range g.outputs {
		g.outputs[i] = (v - g.mean) / g.std
	}

	// Add to the kernel matrix.
	// The old nSamples x nSamples block is copied unchanged; only entries that
	// involve a new point are computed below.
	k := mat64.NewSymDense(rx+nSamples, nil)
	k.CopySym(g.k)
	g.k = k
	// Compute the kernel with the new points and the old points
	for i := 0; i < nSamples; i++ {
		for j := nSamples; j < rx+nSamples; j++ {
			v := g.kernel.Distance(g.inputs.RawRowView(i), g.inputs.RawRowView(j))
			g.k.SetSym(i, j, v)
		}
	}

	// Compute the kernel with the new points and themselves
	for i := nSamples; i < rx+nSamples; i++ {
		for j := i; j < nSamples+rx; j++ {
			v := g.kernel.Distance(g.inputs.RawRowView(i), g.inputs.RawRowView(j))
			// The noise term is added only on the diagonal entries of the new
			// points.
			if i == j {
				v += g.noise
			}
			g.k.SetSym(i, j, v)
		}
	}
	// Cache necessary matrix results for computing predictions.
	var chol mat64.Cholesky
	ok := chol.Factorize(g.k)
	if !ok {
		return ErrSingular
	}
	g.cholK = &chol
	// Solve K * sigInvY = outputs using the Cholesky factor.
	g.sigInvY.Reset()
	v := mat64.NewVector(len(g.outputs), g.outputs)
	g.sigInvY.SolveCholeskyVec(g.cholK, v)
	return nil
}
コード例 #18
0
ファイル: normal_test.go プロジェクト: darrenmcc/stat
// TestConditionNormal checks ConditionNormal in three ways:
//   - against hand-computed results where observed and unobserved variables
//     are uncorrelated;
//   - against the closed-form bivariate update;
//   - against sample statistics of draws that land near the observed values.
func TestConditionNormal(t *testing.T) {
	// Uncorrelated values shouldn't influence the updated values.
	for _, test := range []struct {
		mu       []float64
		sigma    *mat64.SymDense
		observed []int
		values   []float64

		newMu    []float64
		newSigma *mat64.SymDense
	}{
		{
			mu:       []float64{2, 3},
			sigma:    mat64.NewSymDense(2, []float64{2, 0, 0, 5}),
			observed: []int{0},
			values:   []float64{10},

			newMu:    []float64{3},
			newSigma: mat64.NewSymDense(1, []float64{5}),
		},
		{
			mu:       []float64{2, 3},
			sigma:    mat64.NewSymDense(2, []float64{2, 0, 0, 5}),
			observed: []int{1},
			values:   []float64{10},

			newMu:    []float64{2},
			newSigma: mat64.NewSymDense(1, []float64{2}),
		},
		{
			mu:       []float64{2, 3, 4},
			sigma:    mat64.NewSymDense(3, []float64{2, 0, 0, 0, 5, 0, 0, 0, 10}),
			observed: []int{1},
			values:   []float64{10},

			newMu:    []float64{2, 4},
			newSigma: mat64.NewSymDense(2, []float64{2, 0, 0, 10}),
		},
		{
			mu:       []float64{2, 3, 4},
			sigma:    mat64.NewSymDense(3, []float64{2, 0, 0, 0, 5, 0, 0, 0, 10}),
			observed: []int{0, 1},
			values:   []float64{10, 15},

			newMu:    []float64{4},
			newSigma: mat64.NewSymDense(1, []float64{10}),
		},
		{
			mu:       []float64{2, 3, 4, 5},
			sigma:    mat64.NewSymDense(4, []float64{2, 0.5, 0, 0, 0.5, 5, 0, 0, 0, 0, 10, 2, 0, 0, 2, 3}),
			observed: []int{0, 1},
			values:   []float64{10, 15},

			newMu:    []float64{4, 5},
			newSigma: mat64.NewSymDense(2, []float64{10, 2, 2, 3}),
		},
	} {
		normal, ok := NewNormal(test.mu, test.sigma, nil)
		if !ok {
			t.Fatalf("Bad test, original sigma not positive definite")
		}
		newNormal, ok := normal.ConditionNormal(test.observed, test.values, nil)
		if !ok {
			t.Fatalf("Bad test, update failure")
		}

		if !floats.EqualApprox(test.newMu, newNormal.mu, 1e-12) {
			t.Errorf("Updated mean mismatch. Want %v, got %v.", test.newMu, newNormal.mu)
		}

		var sigma mat64.SymDense
		sigma.FromCholesky(&newNormal.chol)
		if !mat64.EqualApprox(test.newSigma, &sigma, 1e-12) {
			t.Errorf("Updated sigma mismatch\n.Want:\n% v\nGot:\n% v\n", test.newSigma, sigma)
		}
	}

	// Test bivariate case where the update rule is analytic
	for _, test := range []struct {
		mu    []float64
		std   []float64
		rho   float64
		value float64
	}{
		{
			mu:    []float64{2, 3},
			std:   []float64{3, 5},
			rho:   0.9,
			value: 1000,
		},
		{
			mu:    []float64{2, 3},
			std:   []float64{3, 5},
			rho:   -0.9,
			value: 1000,
		},
	} {
		std := test.std
		rho := test.rho
		sigma := mat64.NewSymDense(2, []float64{std[0] * std[0], std[0] * std[1] * rho, std[0] * std[1] * rho, std[1] * std[1]})
		normal, ok := NewNormal(test.mu, sigma, nil)
		if !ok {
			t.Fatalf("Bad test, original sigma not positive definite")
		}
		newNormal, ok := normal.ConditionNormal([]int{1}, []float64{test.value}, nil)
		if !ok {
			t.Fatalf("Bad test, update failed")
		}
		var newSigma mat64.SymDense
		newSigma.FromCholesky(&newNormal.chol)
		trueMean := test.mu[0] + rho*(std[0]/std[1])*(test.value-test.mu[1])
		if math.Abs(trueMean-newNormal.mu[0]) > 1e-14 {
			t.Errorf("Mean mismatch. Want %v, got %v", trueMean, newNormal.mu[0])
		}
		trueVar := (1 - rho*rho) * std[0] * std[0]
		if math.Abs(trueVar-newSigma.At(0, 0)) > 1e-14 {
			t.Errorf("Variance mismatch. Want %v, got %v", trueVar, newSigma.At(0, 0))
		}
	}

	// Test via sampling.
	for _, test := range []struct {
		mu         []float64
		sigma      *mat64.SymDense
		observed   []int
		unobserved []int
		value      []float64
	}{
		// The indices in unobserved must be in ascending order for this test.
		{
			mu:    []float64{2, 3, 4},
			sigma: mat64.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}),

			observed:   []int{0},
			unobserved: []int{1, 2},
			value:      []float64{1.9},
		},
		{
			mu:    []float64{2, 3, 4, 5},
			sigma: mat64.NewSymDense(4, []float64{2, 0.5, 3, 0.1, 0.5, 1, 0.6, 0.2, 3, 0.6, 10, 0.3, 0.1, 0.2, 0.3, 3}),

			observed:   []int{0, 3},
			unobserved: []int{1, 2},
			value:      []float64{1.9, 2.9},
		},
	} {
		totalSamp := 4000000
		var nSamp int
		samples := mat64.NewDense(totalSamp, len(test.mu), nil)
		normal, ok := NewNormal(test.mu, test.sigma, nil)
		if !ok {
			// normal must not be used when construction fails; skip the case.
			t.Errorf("bad test")
			continue
		}
		sample := make([]float64, len(test.mu))
		for i := 0; i < totalSamp; i++ {
			normal.Rand(sample)
			// Keep only draws whose observed coordinates all lie within 0.1 of
			// the conditioning values.
			isClose := true
			for j, v := range test.observed {
				if math.Abs(sample[v]-test.value[j]) > 1e-1 {
					isClose = false
					break
				}
			}
			if isClose {
				samples.SetRow(nSamp, sample)
				nSamp++
			}
		}

		if nSamp < 100 {
			t.Errorf("bad test, not enough samples")
			continue
		}
		samples = samples.View(0, 0, nSamp, len(test.mu)).(*mat64.Dense)

		// Compute mean and covariance matrix.
		estMean := make([]float64, len(test.mu))
		for i := range estMean {
			estMean[i] = stat.Mean(mat64.Col(nil, i, samples), nil)
		}
		estCov := stat.CovarianceMatrix(nil, samples, nil)

		// Compute update rule.
		newNormal, ok := normal.ConditionNormal(test.observed, test.value, nil)
		if !ok {
			t.Fatalf("Bad test, update failure")
		}

		var subEstMean []float64
		for _, v := range test.unobserved {
			subEstMean = append(subEstMean, estMean[v])
		}
		subEstCov := mat64.NewSymDense(len(test.unobserved), nil)
		for i := 0; i < len(test.unobserved); i++ {
			for j := i; j < len(test.unobserved); j++ {
				subEstCov.SetSym(i, j, estCov.At(test.unobserved[i], test.unobserved[j]))
			}
		}

		for i, v := range subEstMean {
			if math.Abs(newNormal.mu[i]-v) > 5e-2 {
				t.Errorf("Mean mismatch. Want %v, got %v.", newNormal.mu[i], v)
			}
		}
		var sigma mat64.SymDense
		sigma.FromCholesky(&newNormal.chol)
		if !mat64.EqualApprox(&sigma, subEstCov, 1e-1) {
			t.Errorf("Covariance mismatch. Want:\n%0.8v\nGot:\n%0.8v\n", subEstCov, sigma)
		}
	}
}