コード例 #1
0
ファイル: diag.go プロジェクト: sguzwf/algorithm
/*
 * SolveDiag applies the inverse of a diagonal matrix to B:
 *
 *   B = B*diag(D).-1      when flags&(LEFT|RIGHT) == RIGHT
 *   B = diag(D).-1*B      when flags&(LEFT|RIGHT) == LEFT
 *
 * Nothing is returned; the result is written through sub-matrix views of B
 * (SubMatrix is assumed to share storage with B -- confirm against the
 * matrix package). Any other combination of the LEFT and RIGHT bits leaves
 * B untouched.
 *
 * Arguments:
 *   B     M-by-N matrix when RIGHT is selected, N-by-M matrix when LEFT is
 *         selected
 *
 *   D     N element column or row vector, or N-by-N matrix; for a matrix
 *         only the entries picked out by the diagonal view are read
 *
 *   flags Indicator bits, LEFT or RIGHT
 *
 * Entries of D are used as divisors as-is; no test for zero is made here.
 */
func SolveDiag(B, D *matrix.FloatMatrix, flags Flags) {
	var view, diagView matrix.FloatMatrix
	side := flags & (LEFT | RIGHT)

	if D.Cols() == 1 {
		// D carries the diagonal as a column vector.
		if side == LEFT {
			// Row i of B has to be divided by D[i]: take B one column at a
			// time and divide that column by D (Div is taken to be an
			// element-wise division -- confirm against the matrix package).
			for col := 0; col < B.Cols(); col++ {
				B.SubMatrix(&view, 0, col, B.Rows(), 1)
				view.Div(D)
			}
		} else if side == RIGHT {
			// Column k of B is divided by the k'th entry of D.
			for col := 0; col < B.Cols(); col++ {
				B.SubMatrix(&view, 0, col, B.Rows(), 1)
				view.Scale(1.0 / D.GetAt(col, 0))
			}
		}
		return
	}

	// D is a row vector or a full matrix. Either way the k'th diagonal entry
	// is read below as diag.GetAt(0, k).
	diag := D
	if D.Rows() != 1 {
		// Single-row view over D with a step of LeadingIndex()+1, which is
		// presumably what walks the main diagonal of D.
		D.SubMatrix(&diagView, 0, 0, 1, D.Cols(), D.LeadingIndex()+1)
		diag = &diagView
	}

	if side == LEFT {
		// Row k of B is divided by the k'th diagonal entry.
		for row := 0; row < B.Rows(); row++ {
			B.SubMatrix(&view, row, 0, 1, B.Cols())
			view.Scale(1.0 / diag.GetAt(0, row))
		}
	} else if side == RIGHT {
		// Column k of B is divided by the k'th diagonal entry.
		for col := 0; col < B.Cols(); col++ {
			B.SubMatrix(&view, 0, col, B.Rows(), 1)
			view.Scale(1.0 / diag.GetAt(0, col))
		}
	}
}