Ejemplo n.º 1
0
// Mul takes the matrix product of a and b, placing the result in the receiver.
// It panics with ErrShape if the number of columns of a differs from the
// number of rows of b.
//
// The implementation dispatches on the concrete kinds of the (untransposed)
// operands so that a BLAS routine can be used where one applies, and falls
// back to element-wise access through At otherwise.
//
// See the Muler interface for more information.
func (m *Dense) Mul(a, b Matrix) {
	ar, ac := a.Dims()
	br, bc := b.Dims()

	if ac != br {
		panic(ErrShape)
	}

	// aU and bU are the operands used for the type dispatch below; aTrans and
	// bTrans select blas.Trans for them.
	// NOTE(review): presumably untranspose strips a transpose wrapper and
	// reports whether it did so — confirm against its definition.
	aU, aTrans := untranspose(a)
	bU, bTrans := untranspose(b)
	m.reuseAs(ar, bc)
	// When the receiver is the same object as one of the operands, continue
	// with the matrix returned by isolatedWorkspace instead. restore is
	// deferred, so it runs on every return path below.
	// NOTE(review): presumably restore moves the result back into the
	// original receiver — confirm against isolatedWorkspace.
	var restore func()
	if m == aU {
		m, restore = m.isolatedWorkspace(aU)
		defer restore()
	} else if m == bU {
		m, restore = m.isolatedWorkspace(bU)
		defer restore()
	}
	// Translate the transpose flags into the BLAS constants.
	aT := blas.NoTrans
	if aTrans {
		aT = blas.Trans
	}
	bT := blas.NoTrans
	if bTrans {
		bT = blas.Trans
	}

	// Some of the cases do not have a transpose option, so create
	// temporary memory.
	// C = A^T * B = (B^T * A)^T
	// C^T = B^T * A.

	// Cases where a is backed by a general (dense) matrix.
	if aU, ok := aU.(RawMatrixer); ok {
		amat := aU.RawMatrix()
		// general * general: a single Gemm handles every transpose combination.
		if bU, ok := bU.(RawMatrixer); ok {
			bmat := bU.RawMatrix()
			blas64.Gemm(aT, bT, 1, amat, bmat, 0, m.mat)
			return
		}
		// general * symmetric.
		if bU, ok := bU.(RawSymmetricer); ok {
			bmat := bU.RawSymmetric()
			if aTrans {
				// Symm has no transpose option for the general operand, so
				// form C^T = B * aU in an ac×ar workspace and copy its
				// transpose into m.
				c := getWorkspace(ac, ar, false)
				blas64.Symm(blas.Left, 1, bmat, amat, 0, c.mat)
				strictCopy(m, c.T())
				putWorkspace(c)
				return
			}
			blas64.Symm(blas.Right, 1, bmat, amat, 0, m.mat)
			return
		}
		// general * triangular.
		if bU, ok := bU.(RawTriangular); ok {
			// Trmm updates in place, so copy aU first.
			bmat := bU.RawTriangular()
			if aTrans {
				// Form C^T = B^T * aU in a workspace: copy aU into c, apply
				// the triangular factor from the left with the transpose
				// flag of b flipped, then copy the transpose of c into m.
				c := getWorkspace(ac, ar, false)
				var tmp Dense
				tmp.SetRawMatrix(aU.RawMatrix())
				c.Copy(&tmp)
				// This bT shadows the outer one and holds the flipped flag.
				bT := blas.Trans
				if bTrans {
					bT = blas.NoTrans
				}
				blas64.Trmm(blas.Left, bT, 1, bmat, c.mat)
				strictCopy(m, c.T())
				putWorkspace(c)
				return
			}
			// a is not transposed: copy it into m and multiply in place
			// from the right.
			m.Copy(a)
			blas64.Trmm(blas.Right, bT, 1, bmat, m.mat)
			return
		}
		// general * vector.
		if bU, ok := bU.(*Vector); ok {
			bvec := bU.RawVector()
			if bTrans {
				// {ar,1} x {1,bc}, which is not a vector.
				// Instead, construct B as a General.
				// The vector increment serves as the row stride of the
				// bc×1 view; Gemm then applies bT (blas.Trans here).
				bmat := blas64.General{
					Rows:   bc,
					Cols:   1,
					Stride: bvec.Inc,
					Data:   bvec.Data,
				}
				blas64.Gemm(aT, bT, 1, amat, bmat, 0, m.mat)
				return
			}
			// Treat the data of m as a vector whose increment is the row
			// stride of m, and compute the matrix-vector product into it.
			cvec := blas64.Vector{
				Inc:  m.mat.Stride,
				Data: m.mat.Data,
			}
			blas64.Gemv(aT, 1, amat, bvec, 0, cvec)
			return
		}
	}
	// Cases where b is backed by a general (dense) matrix and a is not.
	if bU, ok := bU.(RawMatrixer); ok {
		bmat := bU.RawMatrix()
		// symmetric * general.
		if aU, ok := aU.(RawSymmetricer); ok {
			amat := aU.RawSymmetric()
			if bTrans {
				// Form C^T = bU * A in a bc×br workspace and copy its
				// transpose into m.
				c := getWorkspace(bc, br, false)
				blas64.Symm(blas.Right, 1, amat, bmat, 0, c.mat)
				strictCopy(m, c.T())
				putWorkspace(c)
				return
			}
			blas64.Symm(blas.Left, 1, amat, bmat, 0, m.mat)
			return
		}
		// triangular * general.
		if aU, ok := aU.(RawTriangular); ok {
			// Trmm updates in place, so copy bU first.
			amat := aU.RawTriangular()
			if bTrans {
				// Form C^T = bU * A^T in a workspace: copy bU into c, apply
				// the triangular factor from the right with the transpose
				// flag of a flipped, then copy the transpose of c into m.
				c := getWorkspace(bc, br, false)
				var tmp Dense
				tmp.SetRawMatrix(bU.RawMatrix())
				c.Copy(&tmp)
				// This aT shadows the outer one and holds the flipped flag.
				aT := blas.Trans
				if aTrans {
					aT = blas.NoTrans
				}
				blas64.Trmm(blas.Right, aT, 1, amat, c.mat)
				strictCopy(m, c.T())
				putWorkspace(c)
				return
			}
			// b is not transposed: copy it into m and multiply in place
			// from the left.
			m.Copy(b)
			blas64.Trmm(blas.Left, aT, 1, amat, m.mat)
			return
		}
		// vector * general.
		if aU, ok := aU.(*Vector); ok {
			avec := aU.RawVector()
			if aTrans {
				// {1,ac} x {ac, bc}
				// Transpose B so that the vector is on the right.
				// The result is written through a unit-increment vector
				// view of m's data.
				cvec := blas64.Vector{
					Inc:  1,
					Data: m.mat.Data,
				}
				// This bT shadows the outer one and holds the flipped flag.
				bT := blas.Trans
				if bTrans {
					bT = blas.NoTrans
				}
				blas64.Gemv(bT, 1, bmat, avec, 0, cvec)
				return
			}
			// {ar,1} x {1,bc} which is not a vector result.
			// Instead, construct A as a General.
			// The vector increment serves as the row stride of the ar×1 view.
			amat := blas64.General{
				Rows:   ar,
				Cols:   1,
				Stride: avec.Inc,
				Data:   avec.Data,
			}
			blas64.Gemm(aT, bT, 1, amat, bmat, 0, m.mat)
			return
		}
	}

	// Both operands can hand out whole rows and columns: compute every
	// element of the product as a dot product of length ac. row and col are
	// scratch buffers passed to the Row/Col calls. A transposed operand
	// swaps the roles of Row and Col on its untransposed form.
	if aU, ok := aU.(Vectorer); ok {
		if bU, ok := bU.(Vectorer); ok {
			row := make([]float64, ac)
			col := make([]float64, br)
			if aTrans {
				if bTrans {
					// A^T * B^T: column r of aU with row c of bU.
					for r := 0; r < ar; r++ {
						dataTmp := m.mat.Data[r*m.mat.Stride : r*m.mat.Stride+bc]
						for c := 0; c < bc; c++ {
							dataTmp[c] = blas64.Dot(ac,
								blas64.Vector{Inc: 1, Data: aU.Col(row, r)},
								blas64.Vector{Inc: 1, Data: bU.Row(col, c)},
							)
						}
					}
					return
				}
				// A^T * B: column r of aU with column c of bU.
				// TODO(jonlawlor): determine if (b*a)' is more efficient
				for r := 0; r < ar; r++ {
					dataTmp := m.mat.Data[r*m.mat.Stride : r*m.mat.Stride+bc]
					for c := 0; c < bc; c++ {
						dataTmp[c] = blas64.Dot(ac,
							blas64.Vector{Inc: 1, Data: aU.Col(row, r)},
							blas64.Vector{Inc: 1, Data: bU.Col(col, c)},
						)
					}
				}
				return
			}
			if bTrans {
				// A * B^T: row r of aU with row c of bU.
				for r := 0; r < ar; r++ {
					dataTmp := m.mat.Data[r*m.mat.Stride : r*m.mat.Stride+bc]
					for c := 0; c < bc; c++ {
						dataTmp[c] = blas64.Dot(ac,
							blas64.Vector{Inc: 1, Data: aU.Row(row, r)},
							blas64.Vector{Inc: 1, Data: bU.Row(col, c)},
						)
					}
				}
				return
			}
			// A * B: row r of aU with column c of bU.
			for r := 0; r < ar; r++ {
				dataTmp := m.mat.Data[r*m.mat.Stride : r*m.mat.Stride+bc]
				for c := 0; c < bc; c++ {
					dataTmp[c] = blas64.Dot(ac,
						blas64.Vector{Inc: 1, Data: aU.Row(row, r)},
						blas64.Vector{Inc: 1, Data: bU.Col(col, c)},
					)
				}
			}
			return
		}
	}

	// Fallback for any other Matrix implementations: cache row r of a, then
	// accumulate each element of the product through At on the original
	// (possibly transposed) operands.
	row := make([]float64, ac)
	for r := 0; r < ar; r++ {
		for i := range row {
			row[i] = a.At(r, i)
		}
		for c := 0; c < bc; c++ {
			var v float64
			for i, e := range row {
				v += e * b.At(i, c)
			}
			m.mat.Data[r*m.mat.Stride+c] = v
		}
	}
}
Ejemplo n.º 2
0
// Dlahr2Test checks an implementation of Dlahr2 in two stages.
//
// In the first stage, for a table of (n, k, nb) sizes and several extra
// strides, Dlahr2 is run on a random matrix A with the outputs T, Y and tau
// created by the nan* helpers. The test then verifies that nothing was
// written outside the matrices, that only the block A[k:,:nb] changed, that
// every element of tau was assigned, that Y equals A*V*T, and that the Q
// constructed from the reflectors agrees with I - V*T*V^T.
//
// In the second stage, Dlahr2 is run on the inputs stored in
// testdata/dlahr2data.json.gz and its outputs A, T, Y and tau are compared
// against the stored reference values. A failure to open, decompress or
// decode that file aborts the test instead of silently skipping this stage.
func Dlahr2Test(t *testing.T, impl Dlahr2er) {
	// tol is the absolute tolerance used for all floating-point comparisons.
	const tol = 1e-14

	rnd := rand.New(rand.NewSource(1))
	for _, test := range []struct {
		n, k, nb int
	}{
		{3, 0, 3},
		{3, 1, 2},
		{3, 1, 1},

		{5, 0, 5},
		{5, 1, 4},
		{5, 1, 3},
		{5, 1, 2},
		{5, 1, 1},
		{5, 2, 3},
		{5, 2, 2},
		{5, 2, 1},
		{5, 3, 2},
		{5, 3, 1},

		{7, 3, 4},
		{7, 3, 3},
		{7, 3, 2},
		{7, 3, 1},

		{10, 0, 10},
		{10, 1, 9},
		{10, 1, 5},
		{10, 1, 1},
		{10, 5, 5},
		{10, 5, 3},
		{10, 5, 1},
	} {
		for cas := 0; cas < 100; cas++ {
			for _, extraStride := range []int{0, 1, 10} {
				n := test.n
				k := test.k
				nb := test.nb

				a := randomGeneral(n, n-k+1, n-k+1+extraStride, rnd)
				// Keep an independent copy of A; the struct copy alone
				// would share the backing slice.
				aCopy := a
				aCopy.Data = make([]float64, len(a.Data))
				copy(aCopy.Data, a.Data)
				tmat := nanTriangular(blas.Upper, nb, nb+extraStride)
				y := nanGeneral(n, nb, nb+extraStride)
				tau := nanSlice(nb)

				impl.Dlahr2(n, k, nb, a.Data, a.Stride, tau, tmat.Data, tmat.Stride, y.Data, y.Stride)

				prefix := fmt.Sprintf("Case n=%v, k=%v, nb=%v, ldex=%v", n, k, nb, extraStride)

				if !generalOutsideAllNaN(a) {
					t.Errorf("%v: out-of-range write to A\n%v", prefix, a.Data)
				}
				if !triangularOutsideAllNaN(tmat) {
					t.Errorf("%v: out-of-range write to T\n%v", prefix, tmat.Data)
				}
				if !generalOutsideAllNaN(y) {
					t.Errorf("%v: out-of-range write to Y\n%v", prefix, y.Data)
				}

				// Check that A[:k,:] and A[:,nb:] blocks were not modified.
				for i := 0; i < n; i++ {
					for j := 0; j < n-k+1; j++ {
						if i >= k && j < nb {
							continue
						}
						if a.Data[i*a.Stride+j] != aCopy.Data[i*aCopy.Stride+j] {
							t.Errorf("%v: unexpected write to A[%v,%v]", prefix, i, j)
						}
					}
				}

				// Check that all elements of tau were assigned.
				for i, v := range tau {
					if math.IsNaN(v) {
						t.Errorf("%v: tau[%v] not assigned", prefix, i)
					}
				}

				// Extract V from a.
				v := blas64.General{
					Rows:   n - k + 1,
					Cols:   nb,
					Stride: nb,
					Data:   make([]float64, (n-k+1)*nb),
				}
				for j := 0; j < v.Cols; j++ {
					// Column j has an implicit 1 in row j+1; the entries
					// below it are read from a.
					v.Data[(j+1)*v.Stride+j] = 1
					for i := j + 2; i < v.Rows; i++ {
						v.Data[i*v.Stride+j] = a.Data[(i+k-1)*a.Stride+j]
					}
				}

				// VT = V.
				vt := v
				vt.Data = make([]float64, len(v.Data))
				copy(vt.Data, v.Data)
				// VT = V * T.
				blas64.Trmm(blas.Right, blas.NoTrans, 1, tmat, vt)
				// YWant = A * V * T.
				ywant := blas64.General{
					Rows:   n,
					Cols:   nb,
					Stride: nb,
					Data:   make([]float64, n*nb),
				}
				blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aCopy, vt, 0, ywant)

				// Compare Y and YWant.
				for i := 0; i < n; i++ {
					for j := 0; j < nb; j++ {
						diff := math.Abs(ywant.Data[i*ywant.Stride+j] - y.Data[i*y.Stride+j])
						// NaN compares false against everything, so an
						// element of Y that was left as NaN must be
						// detected explicitly.
						if math.IsNaN(diff) || diff > tol {
							t.Errorf("%v: unexpected Y[%v,%v], diff=%v", prefix, i, j, diff)
						}
					}
				}

				// Construct Q directly from the first nb columns of a.
				q := constructQ("QR", n-k, nb, a.Data[k*a.Stride:], a.Stride, tau)
				if !isOrthonormal(q) {
					t.Errorf("%v: Q is not orthogonal", prefix)
				}
				// Construct Q as the product Q = I - V*T*V^T.
				qwant := blas64.General{
					Rows:   n - k + 1,
					Cols:   n - k + 1,
					Stride: n - k + 1,
					Data:   make([]float64, (n-k+1)*(n-k+1)),
				}
				for i := 0; i < qwant.Rows; i++ {
					qwant.Data[i*qwant.Stride+i] = 1
				}
				blas64.Gemm(blas.NoTrans, blas.Trans, -1, vt, v, 1, qwant)
				if !isOrthonormal(qwant) {
					t.Errorf("%v: Q = I - V*T*V^T is not orthogonal", prefix)
				}

				// Compare Q and QWant. Note that since Q is
				// (n-k)×(n-k) and QWant is (n-k+1)×(n-k+1), we
				// ignore the first row and column of QWant.
				for i := 0; i < n-k; i++ {
					for j := 0; j < n-k; j++ {
						diff := math.Abs(q.Data[i*q.Stride+j] - qwant.Data[(i+1)*qwant.Stride+j+1])
						// As above, a NaN difference must count as a failure.
						if math.IsNaN(diff) || diff > tol {
							t.Errorf("%v: unexpected Q[%v,%v], diff=%v", prefix, i, j, diff)
						}
					}
				}
			}
		}
	}

	// Go runs tests from the source directory, so unfortunately we need to
	// include the "../testlapack" part.
	file, err := os.Open(filepath.FromSlash("../testlapack/testdata/dlahr2data.json.gz"))
	if err != nil {
		log.Fatal(err)
	}
	defer file.Close()
	r, err := gzip.NewReader(file)
	if err != nil {
		log.Fatal(err)
	}
	defer r.Close()

	// A decode failure must abort the test: otherwise tests stays empty and
	// the reference comparison below passes without checking anything.
	var tests []Dlahr2test
	if err := json.NewDecoder(r).Decode(&tests); err != nil {
		log.Fatal(err)
	}
	for _, test := range tests {
		tau := make([]float64, len(test.TauWant))
		for _, ldex := range []int{0, 1, 20} {
			n := test.N
			k := test.K
			nb := test.NB

			// Lay the reference input out with the padded leading dimension.
			lda := n - k + 1 + ldex
			a := make([]float64, (n-1)*lda+n-k+1)
			copyMatrix(n, n-k+1, a, lda, test.A)

			ldt := nb + ldex
			tmat := make([]float64, (nb-1)*ldt+nb)

			ldy := nb + ldex
			y := make([]float64, (n-1)*ldy+nb)

			impl.Dlahr2(n, k, nb, a, lda, tau, tmat, ldt, y, ldy)

			prefix := fmt.Sprintf("Case n=%v, k=%v, nb=%v, ldex=%v", n, k, nb, ldex)
			if !equalApprox(n, n-k+1, a, lda, test.AWant, tol) {
				t.Errorf("%v: unexpected matrix A\n got=%v\nwant=%v", prefix, a, test.AWant)
			}
			if !equalApproxTriangular(true, nb, tmat, ldt, test.TWant, tol) {
				t.Errorf("%v: unexpected matrix T\n got=%v\nwant=%v", prefix, tmat, test.TWant)
			}
			if !equalApprox(n, nb, y, ldy, test.YWant, tol) {
				t.Errorf("%v: unexpected matrix Y\n got=%v\nwant=%v", prefix, y, test.YWant)
			}
			if !floats.EqualApprox(tau, test.TauWant, tol) {
				t.Errorf("%v: unexpected slice tau\n got=%v\nwant=%v", prefix, tau, test.TauWant)
			}
		}
	}
}