Beispiel #1
0
Datei: rbm.go Projekt: sguzwf/mlf
// logistic returns the sigmoid 1/(1+e^-x) of an activation x formed by
// combining v with one slice of rbm.lock.weights:
//
//   - isRow == true:  x is util.VecDotProduct(v, weights.GetValues(index)),
//     presumably the dot product of v with row `index` (confirm in util).
//   - isRow == false: x is the sum over i in [0, weights.NumLabels()) of
//     v.Get(i) * weights.Get(i, index), i.e. v dotted with column `index`.
//
// It does not acquire rbm.lock itself; SampleHidden calls it while holding
// the read lock.
func (rbm *RBM) logistic(v *util.Vector, index int, isRow bool) (output float64) {
	var activation float64
	if !isRow {
		rows := rbm.lock.weights.NumLabels()
		for r := 0; r < rows; r++ {
			activation += v.Get(r) * rbm.lock.weights.Get(r, index)
		}
	} else {
		activation = util.VecDotProduct(v, rbm.lock.weights.GetValues(index))
	}
	return 1.0 / (1.0 + math.Exp(-activation))
}
Beispiel #2
0
Datei: rbm.go Projekt: sguzwf/mlf
// SampleHidden propagates v up to the hidden layer, then runs n
// reconstruction rounds (visible down-pass followed by hidden up-pass) and
// returns the final hidden vector.
//
// Both layers carry a bias unit at index 0 that is pinned to 1.0. The returned
// vector has rbm.options.NumHiddenUnits+1 entries. v.Get(0) is never read;
// only v.Get(1) .. v.Get(weights.NumValues()-1) are copied, so v is assumed to
// hold at least weights.NumValues() entries (TODO confirm against callers).
//
// When binary is true each hidden unit is set to rbm.bernoulli(p) of its
// activation probability p; otherwise p itself is stored. Reconstructed
// visible units always hold probabilities, never samples.
//
// The read lock on rbm.lock is held for the whole call.
func (rbm *RBM) SampleHidden(v *util.Vector, n int, binary bool) *util.Vector {
	rbm.lock.RLock()
	defer rbm.lock.RUnlock()

	numHidden := rbm.options.NumHiddenUnits + 1
	numVisible := rbm.lock.weights.NumValues()

	hidden := util.NewVector(numHidden)
	visible := util.NewVector(numVisible)
	hidden.Set(0, 1.0)
	visible.Set(0, 1.0)

	// Copy the non-bias inputs; index 0 stays the constant bias.
	for k := 1; k < numVisible; k++ {
		visible.Set(k, v.Get(k))
	}

	// upPass recomputes every non-bias hidden unit from the current visible layer.
	upPass := func() {
		for h := 1; h < numHidden; h++ {
			p := rbm.logistic(visible, h, true)
			if !binary {
				hidden.Set(h, p)
				continue
			}
			hidden.Set(h, rbm.bernoulli(p))
		}
	}

	upPass()

	// Exactly n reconstruction rounds run here (none when n <= 0).
	// NOTE(review): an earlier comment described this as "n-1" reconstructions;
	// confirm which count callers expect.
	for round := 0; round < n; round++ {
		for k := 1; k < numVisible; k++ {
			visible.Set(k, rbm.logistic(hidden, k, false))
		}
		upPass()
	}

	return hidden
}