func (rbm *RBM) logistic(v *util.Vector, index int, isRow bool) (output float64) { output = 0.0 if isRow { output = util.VecDotProduct(v, rbm.lock.weights.GetValues(index)) } else { for i := 0; i < rbm.lock.weights.NumLabels(); i++ { output += v.Get(i) * rbm.lock.weights.Get(i, index) } } output = 1.0 / (1 + math.Exp(-output)) return }
// 输入和输出都有 bias 项 func (rbm *RBM) SampleHidden(v *util.Vector, n int, binary bool) *util.Vector { rbm.lock.RLock() defer rbm.lock.RUnlock() hiddenDim := rbm.options.NumHiddenUnits + 1 visibleDim := rbm.lock.weights.NumValues() hiddenUnits := util.NewVector(hiddenDim) visibleUnits := util.NewVector(visibleDim) hiddenUnits.Set(0, 1.0) visibleUnits.Set(0, 1.0) for j := 1; j < visibleDim; j++ { visibleUnits.Set(j, v.Get(j)) } // 更新 hidden units for i := 1; i < hiddenDim; i++ { prob := rbm.logistic(visibleUnits, i, true) if binary { hiddenUnits.Set(i, rbm.bernoulli(prob)) } else { hiddenUnits.Set(i, prob) } } // reconstruct n-1 次 for nCD := 0; nCD < n; nCD++ { for j := 1; j < visibleDim; j++ { var prob float64 prob = rbm.logistic(hiddenUnits, j, false) visibleUnits.Set(j, prob) } for i := 1; i < hiddenDim; i++ { prob := rbm.logistic(visibleUnits, i, true) if binary { hiddenUnits.Set(i, rbm.bernoulli(prob)) } else { hiddenUnits.Set(i, prob) } } } return hiddenUnits }