示例#1
0
文件: nmf.go 项目: postfix/nmf
// Factors returns matrices W and H that are non-negative factors of V within the
// specified tolerance and computation limits given initial non-negative solutions Wo
// and Ho.
//
// The factors are refined in rounds, at most c.MaxIter of them. Each round first
// re-solves for W with H held fixed and then for H with the new W held fixed, both
// through nnlsSubproblem. Before each round the gradients are filtered so that only
// entries that are negative, or whose matching factor entry is positive, are kept;
// iteration stops once the norm of the filtered gradients falls below c.Tolerance
// times the norm of the initial gradients, or once more than c.Limit has elapsed
// since Factors was called.
//
// ok is the conjunction of the ok values reported by the two nnlsSubproblem calls of
// the last completed round. It is false when no round completes, for example when
// c.MaxIter is not positive or the stopping condition already holds on entry. In
// that case W and H are Wo and Ho themselves, not copies.
func Factors(V, Wo, Ho *mat64.Dense, c Config) (W, H *mat64.Dense, ok bool) {
	start := time.Now()

	W = Wo
	H = Ho

	var (
		wr, wc = W.Dims()
		hr, hc = H.Dims()

		tmp mat64.Dense
	)

	// Gradient with respect to W: gW = W*(H*H^T) - V*H^T.
	var vhT mat64.Dense
	gW := mat64.NewDense(wr, wc, nil)
	tmp.Mul(H, H.T())
	gW.Mul(W, &tmp)
	vhT.Mul(V, H.T())
	gW.Sub(gW, &vhT)

	// Gradient with respect to H: gH = (W^T*W)*H - W^T*V.
	var wTv mat64.Dense
	gH := mat64.NewDense(hr, hc, nil)
	tmp.Reset()
	tmp.Mul(W.T(), W)
	gH.Mul(&tmp, H)
	wTv.Mul(W.T(), V)
	gH.Sub(gH, &wTv)

	// Stack gW on top of gH^T so that one norm covers both initial gradients.
	var gHT, gWHT mat64.Dense
	gHT.Clone(gH.T())
	gWHT.Stack(gW, &gHT)

	// NOTE(review): order 2 is expected to select the Frobenius norm in mat64.Norm,
	// matching the element-wise sum of squares used in the loop below; confirm
	// against the vendored gonum version.
	grad := mat64.Norm(&gWHT, 2)
	tolW := math.Max(0.001, c.Tolerance) * grad
	tolH := tolW

	// The overall stopping threshold does not change between rounds.
	stopTol := c.Tolerance * grad

	var (
		okH  bool
		iter int
	)

	// The filters read W and H through the enclosing variables, so they always see
	// the factors of the current round.
	decFiltW := func(row, col int, v float64) float64 {
		// decFiltW is applied to gW, so v = gW.At(row, col).
		if v < 0 || W.At(row, col) > 0 {
			return v
		}
		return 0
	}

	decFiltH := func(row, col int, v float64) float64 {
		// decFiltH is applied to gH, so v = gH.At(row, col).
		if v < 0 || H.At(row, col) > 0 {
			return v
		}
		return 0
	}

	var vT, hT, wT mat64.Dense
	for i := 0; i < c.MaxIter; i++ {
		gW.Apply(decFiltW, gW)
		gH.Apply(decFiltH, gH)

		// Square root of the sum of squares over both filtered gradients.
		var proj float64
		for _, v := range gW.RawMatrix().Data {
			proj += v * v
		}
		for _, v := range gH.RawMatrix().Data {
			proj += v * v
		}
		proj = math.Sqrt(proj)
		if proj < stopTol || time.Since(start) > c.Limit {
			break
		}

		// Update W by solving the transposed problem, with V^T, H^T and W^T passed
		// in the argument positions that the H update below fills with V, W and H.
		vT.Clone(V.T())
		hT.Clone(H.T())
		wT.Clone(W.T())
		W, gW, iter, ok = nnlsSubproblem(&vT, &hT, &wT, tolW, c.MaxOuterSub, c.MaxInnerSub)
		if iter == 0 {
			// NOTE(review): iter is presumably the number of subproblem iterations
			// performed; zero then means tolW was already met, so it is tightened
			// tenfold for later rounds. Verify against nnlsSubproblem.
			tolW *= 0.1
		}

		// The subproblem result and its gradient are in the transposed
		// orientation; transpose both back.
		wT.Reset()
		wT.Clone(W.T())
		W = &wT

		var gWT mat64.Dense
		gWT.Clone(gW.T())
		*gW = gWT

		// Update H using the freshly updated W.
		H, gH, iter, okH = nnlsSubproblem(V, W, H, tolH, c.MaxOuterSub, c.MaxInnerSub)
		ok = ok && okH
		if iter == 0 {
			tolH *= 0.1
		}
	}

	return W, H, ok
}