コード例 #1
0
ファイル: misc.go プロジェクト: sguzwf/algorithm
// updateScaling updates, in place, the scaling matrices stored in W and the
// scaled variable lmbda from the iterates s and z. s and z double as
// workspace and are overwritten.
//
// NOTE(review): the structure and the derivation comments match CVXOPT's
// misc.update_scaling (Nesterov-Todd scaling update); confirm against
// upstream before changing any of the numerics.
//
// The vectors are laid out block by block:
//
//   - nonlinear and 'l' blocks: the first mnl+ml entries, where mnl is the
//     length of W["dnl"][0] (0 when absent) and ml the length of W["d"][0];
//   - 'q' blocks: one segment per W["v"][k], as long as that vector;
//   - 's' blocks: one segment per W["r"][k]; with m = rows of r[k] the
//     segment takes m*m entries in s and z but only m entries in lmbda.
//
// W must hold at least one matrix under "d", "di" and "beta" (and under
// "dnli" when "dnl" is present), and "rti" must be at least as long as "r":
// these are indexed unconditionally and a missing entry panics.
//
// The error results of the blas/lapack calls are not inspected, so the
// returned error is always nil.
func updateScaling(W *sets.FloatMatrixSet, lmbda, s, z *matrix.FloatMatrix) (err error) {
	// Nonlinear and 'l' blocks
	//
	//     d     := d .* sqrt( s ./ z )
	//     lmbda := lmbda .* sqrt(s) .* sqrt(z)
	dnlset := W.At("dnl")
	dnliset := W.At("dnli")
	dset := W.At("d")
	diset := W.At("di")
	beta := W.At("beta")[0]

	mnl := 0
	if len(dnlset) > 0 {
		mnl = dnlset[0].NumElements()
	}
	ml := dset[0].NumElements()
	m := mnl + ml

	// s[:m] := sqrt(s[:m])
	stmp := matrix.FloatVector(s.FloatArray()[:m])
	stmp.Apply(math.Sqrt)
	s.SetIndexesFromArray(stmp.FloatArray(), matrix.MakeIndexSet(0, m, 1)...)

	// z[:m] := sqrt(z[:m])
	ztmp := matrix.FloatVector(z.FloatArray()[:m])
	ztmp.Apply(math.Sqrt)
	z.SetIndexesFromArray(ztmp.FloatArray(), matrix.MakeIndexSet(0, m, 1)...)

	// d := d .* s ./ z, with s and z already replaced by their square roots.
	// With k = 0 and lda = 1 the band routines treat s and z as diagonal
	// matrices: tbmv multiplies by diag(s), tbsv solves with diag(z).
	if len(dnlset) > 0 {
		blas.TbmvFloat(s, dnlset[0], &la_.IOpt{"n", mnl}, &la_.IOpt{"k", 0}, &la_.IOpt{"lda", 1})
		blas.TbsvFloat(z, dnlset[0], &la_.IOpt{"n", mnl}, &la_.IOpt{"k", 0}, &la_.IOpt{"lda", 1})
		// dnli := dnl, then inverted in place.
		matrix.Set(dnliset[0], dnlset[0])
		dnliset[0].Inv()
	}
	blas.TbmvFloat(s, dset[0], &la_.IOpt{"n", ml},
		&la_.IOpt{"k", 0}, &la_.IOpt{"lda", 1}, &la_.IOpt{"offseta", mnl})
	blas.TbsvFloat(z, dset[0], &la_.IOpt{"n", ml},
		&la_.IOpt{"k", 0}, &la_.IOpt{"lda", 1}, &la_.IOpt{"offseta", mnl})
	// di := d, then inverted in place.
	matrix.Set(diset[0], dset[0])
	diset[0].Inv()

	// lmbda := s .* z
	blas.CopyFloat(s, lmbda, &la_.IOpt{"n", m})
	blas.TbmvFloat(z, lmbda, &la_.IOpt{"n", m}, &la_.IOpt{"k", 0}, &la_.IOpt{"lda", 1})

	// 'q' blocks.
	// Let st and zt be the new variables in the old scaling:
	//
	//     st = s_k,   zt = z_k
	//
	// and a = sqrt(st' * J * st),  b = sqrt(zt' * J * zt).
	//
	// 1. Compute the hyperbolic Householder transformation 2*q*q' - J
	//    that maps st/a to zt/b.
	//
	//        c = sqrt( (1 + st'*zt/(a*b)) / 2 )
	//        q = (st/a + J*zt/b) / (2*c).
	//
	//    The new scaling point is
	//
	//        wk := betak * sqrt(a/b) * (2*v[k]*v[k]' - J) * q
	//
	//    with betak = W['beta'][k].
	//
	// 2. The scaled variable:
	//
	//        lambda_k0 = sqrt(a*b) * c
	//        lambda_k1 = sqrt(a*b) * ( (2vk*vk' - J) * (-d*q + u/2) )_1
	//
	//    where
	//
	//        u = st/a - J*zt/b
	//        d = ( vk0 * (vk'*u) - u0/2 ) / (2*vk0 *(vk'*q) - q0 + 1).
	//
	// 3. Update scaling
	//
	//        v[k] := wk^1/2
	//              = 1 / sqrt(2*(wk0 + 1)) * (wk + e).
	//        beta[k] *=  sqrt(a/b)

	ind := m // offset of the current 'q' block in s, z and lmbda
	for k, v := range W.At("v") {
		mk := v.NumElements()

		// a = sqrt( sk' * J * sk ) = sqrt( st' * J * st )
		// s := s / a = st / a
		aa := jnrm2(s, mk, ind)
		blas.ScalFloat(s, 1.0/aa, &la_.IOpt{"n", mk}, &la_.IOpt{"offset", ind})

		// b = sqrt( zk' * J * zk ) = sqrt( zt' * J * zt )
		// z := z / b = zt / b
		bb := jnrm2(z, mk, ind)
		blas.ScalFloat(z, 1.0/bb, &la_.IOpt{"n", mk}, &la_.IOpt{"offset", ind})

		// c = sqrt( ( 1 + (st'*zt) / (a*b) ) / 2 )
		cc := blas.DotFloat(s, z, &la_.IOpt{"offsetx", ind}, &la_.IOpt{"offsety", ind},
			&la_.IOpt{"n", mk})
		cc = math.Sqrt((1.0 + cc) / 2.0)

		// vs = v' * st / a
		vs := blas.DotFloat(v, s, &la_.IOpt{"offsety", ind}, &la_.IOpt{"n", mk})

		// vz = v' * J * zt / b
		vz := jdot(v, z, mk, 0, ind)

		// vq = v' * q  where q = (st/a + J * zt/b) / (2 * c)
		vq := (vs + vz) / 2.0 / cc

		// vu = v' * u  where u = st/a - J * zt/b
		vu := vs - vz

		// lambda_k0 = c
		lmbda.SetIndex(ind, cc)

		// wk0 = 2 * vk0 * (vk' * q) - q0
		wk0 := 2.0*v.GetIndex(0)*vq - (s.GetIndex(ind)+z.GetIndex(ind))/2.0/cc

		// d = (v[0] * (vk' * u) - u0/2) / (wk0 + 1)
		dd := (v.GetIndex(0)*vu - s.GetIndex(ind)/2.0 + z.GetIndex(ind)/2.0) / (wk0 + 1.0)

		// lambda_k1 = 2 * v_k1 * vk' * (-d*q + u/2) - d*q1 + u1/2
		blas.CopyFloat(v, lmbda, &la_.IOpt{"offsetx", 1}, &la_.IOpt{"offsety", ind + 1},
			&la_.IOpt{"n", mk - 1})
		blas.ScalFloat(lmbda, (2.0 * (-dd*vq + 0.5*vu)),
			&la_.IOpt{"offsetx", ind + 1}, &la_.IOpt{"offsety", ind + 1}, &la_.IOpt{"n", mk - 1})
		blas.AxpyFloat(s, lmbda, 0.5*(1.0-dd/cc),
			&la_.IOpt{"offsetx", ind + 1}, &la_.IOpt{"offsety", ind + 1}, &la_.IOpt{"n", mk - 1})
		blas.AxpyFloat(z, lmbda, 0.5*(1.0+dd/cc),
			&la_.IOpt{"offsetx", ind + 1}, &la_.IOpt{"offsety", ind + 1}, &la_.IOpt{"n", mk - 1})

		// Scale so that sqrt(lambda_k' * J * lambda_k) = sqrt(aa*bb).
		blas.ScalFloat(lmbda, math.Sqrt(aa*bb), &la_.IOpt{"offset", ind}, &la_.IOpt{"n", mk})

		// v := (2*v*v' - J) * q
		//    = 2 * (v'*q) * v - (J * st/a + zt/b) / (2*c)
		blas.ScalFloat(v, 2.0*vq)
		v.SetIndex(0, v.GetIndex(0)-(s.GetIndex(ind)/2.0/cc))
		blas.AxpyFloat(s, v, 0.5/cc, &la_.IOpt{"offsetx", ind + 1}, &la_.IOpt{"offsety", 1},
			&la_.IOpt{"n", mk - 1})
		blas.AxpyFloat(z, v, -0.5/cc, &la_.IOpt{"offsetx", ind}, &la_.IOpt{"n", mk})

		// v := v^{1/2} = 1/sqrt(2 * (v0 + 1)) * (v + e)
		v0 := v.GetIndex(0) + 1.0
		v.SetIndex(0, v0)
		blas.ScalFloat(v, 1.0/math.Sqrt(2.0*v0))

		// beta[k] *= ( aa / bb )**1/2
		bk := beta.GetIndex(k)
		beta.SetIndex(k, bk*math.Sqrt(aa/bb))

		ind += mk
	}

	// 's' blocks
	//
	// Let st, zt be the updated variables in the old scaling:
	//
	//     st = Ls * Ls', zt = Lz * Lz'.
	//
	// where Ls and Lz are the 's' components of s, z.
	//
	// 1.  SVD Lz'*Ls = Uk * lambda_k^+ * Vk'.
	//
	// 2.  New scaling is
	//
	//         r[k]   := r[k] * Ls * Vk * diag(lambda_k^+)^{-1/2}
	//         rti[k] := rti[k] * Lz * Uk * diag(lambda_k^+)^{-1/2}.
	//

	rset := W.At("r")
	rtiset := W.At("rti")

	// One scratch buffer, large enough for the biggest 's' block.
	maxr := 0
	for _, r := range rset {
		if r.Rows() > maxr {
			maxr = r.Rows()
		}
	}
	work := matrix.FloatZeros(maxr*maxr, 1)

	// After the 'q' loop ind equals mnl + ml + the total length of the 'q'
	// blocks, i.e. the start of the first 's' block in all three vectors.
	// ind walks lmbda (m entries per block); ind2 walks s and z (m*m entries
	// per block).
	ind2 := ind

	for k, r := range rset {
		rti := rtiset[k]
		mk := r.Rows()

		// r := r*sk = r*Ls
		blas.GemmFloat(r, s, work, 1.0, 0.0, &la_.IOpt{"m", mk}, &la_.IOpt{"n", mk},
			&la_.IOpt{"k", mk}, &la_.IOpt{"ldb", mk}, &la_.IOpt{"ldc", mk},
			&la_.IOpt{"offsetb", ind2})
		blas.CopyFloat(work, r, &la_.IOpt{"n", mk * mk})

		// rti := rti*zk = rti*Lz
		blas.GemmFloat(rti, z, work, 1.0, 0.0, &la_.IOpt{"m", mk}, &la_.IOpt{"n", mk},
			&la_.IOpt{"k", mk}, &la_.IOpt{"ldb", mk}, &la_.IOpt{"ldc", mk},
			&la_.IOpt{"offsetb", ind2})
		blas.CopyFloat(work, rti, &la_.IOpt{"n", mk * mk})

		// SVD Lz'*Ls = U * lmbds^+ * V'; store U in sk and V' in zk.
		blas.GemmFloat(z, s, work, 1.0, 0.0, la_.OptTransA, &la_.IOpt{"m", mk},
			&la_.IOpt{"n", mk}, &la_.IOpt{"k", mk}, &la_.IOpt{"lda", mk}, &la_.IOpt{"ldb", mk},
			&la_.IOpt{"ldc", mk}, &la_.IOpt{"offseta", ind2}, &la_.IOpt{"offsetb", ind2})

		// U = s, Vt = z; the singular values go to lmbda at offset ind.
		lapack.GesvdFloat(work, lmbda, s, z, la_.OptJobuAll, la_.OptJobvtAll,
			&la_.IOpt{"m", mk}, &la_.IOpt{"n", mk}, &la_.IOpt{"lda", mk}, &la_.IOpt{"ldu", mk},
			&la_.IOpt{"ldvt", mk}, &la_.IOpt{"offsets", ind}, &la_.IOpt{"offsetu", ind2},
			&la_.IOpt{"offsetvt", ind2})

		// r := r*V
		blas.GemmFloat(r, z, work, 1.0, 0.0, la_.OptTransB, &la_.IOpt{"m", mk},
			&la_.IOpt{"n", mk}, &la_.IOpt{"k", mk}, &la_.IOpt{"ldb", mk}, &la_.IOpt{"ldc", mk},
			&la_.IOpt{"offsetb", ind2})
		blas.CopyFloat(work, r, &la_.IOpt{"n", mk * mk})

		// rti := rti*U
		blas.GemmFloat(rti, s, work, 1.0, 0.0, &la_.IOpt{"m", mk}, &la_.IOpt{"n", mk},
			&la_.IOpt{"k", mk}, &la_.IOpt{"ldb", mk}, &la_.IOpt{"ldc", mk},
			&la_.IOpt{"offsetb", ind2})
		blas.CopyFloat(work, rti, &la_.IOpt{"n", mk * mk})

		// r := r*lambda^{-1/2}; rti := rti*lambda^{-1/2}
		// Each pass scales the mk entries starting at offset mk*i of r and rti.
		for i := 0; i < mk; i++ {
			a := 1.0 / math.Sqrt(lmbda.GetIndex(ind+i))
			blas.ScalFloat(r, a, &la_.IOpt{"n", mk}, &la_.IOpt{"offset", mk * i})
			blas.ScalFloat(rti, a, &la_.IOpt{"n", mk}, &la_.IOpt{"offset", mk * i})
		}
		ind += mk
		ind2 += mk * mk
	}

	return nil
}