Example #1
func (trainer *MaxEntClassifierTrainer) Train(set data.Dataset) Model {
	// Make sure the training data describes a classification problem
	if !set.GetOptions().IsSupervisedLearning {
		log.Fatal("training data is not classification data")
	}

	// Create a new optimizer (named opt so it does not shadow the optimizer package)
	opt := optimizer.NewOptimizer(trainer.options.Optimizer)

	// Create the feature weight matrix
	featureDimension := set.GetOptions().FeatureDimension
	numLabels := set.GetOptions().NumLabels
	var weights *util.Matrix
	if set.GetOptions().FeatureIsSparse {
		weights = util.NewSparseMatrix(numLabels)
	} else {
		weights = util.NewMatrix(numLabels, featureDimension)
	}

	// Run the optimizer to obtain the optimized feature weights
	opt.OptimizeWeights(weights, MaxEntComputeInstanceDerivative, set)

	// Assemble the trained classifier
	classifier := new(MaxEntClassifier)
	classifier.Weights = weights
	classifier.NumLabels = numLabels
	classifier.FeatureDimension = featureDimension
	classifier.FeatureDictionary = set.GetFeatureDictionary()
	classifier.LabelDictionary = set.GetLabelDictionary()
	return classifier
}
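
Once Train returns, the Weights matrix is all that prediction needs. For context, here is a minimal self-contained sketch, not mlf's actual classification code and with all names hypothetical, of the standard maximum-entropy decision rule for a dense numLabels × featureDimension matrix: score each label by a dot product with the feature vector and take the argmax (softmax is monotonic, so the exponentials can be skipped when only the label is needed).

package main

import "fmt"

// predict applies a dense numLabels x featureDimension weight matrix:
// the predicted label is the row with the largest dot product.
func predict(weights [][]float64, features []float64) int {
	best, bestScore := 0, 0.0
	for label, row := range weights {
		score := 0.0
		for j, w := range row {
			score += w * features[j]
		}
		if label == 0 || score > bestScore {
			best, bestScore = label, score
		}
	}
	return best
}

func main() {
	// Two labels, three features: toy weights of the kind Train produces
	weights := [][]float64{
		{0.5, -1.2, 0.3},
		{-0.4, 0.9, 0.1},
	}
	fmt.Println(predict(weights, []float64{1, 0, 1})) // prints 0
}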
Example #2
File: rbm.go Project: sguzwf/mlf
func (rbm *RBM) Train(set data.Dataset) {
	featureDimension := set.GetOptions().FeatureDimension
	visibleDim := featureDimension
	hiddenDim := rbm.options.NumHiddenUnits + 1
	// Both layer sizes include an extra bias unit, hence the -1 below
	log.Printf("#visible = %d, #hidden = %d", featureDimension-1, hiddenDim-1)

	// Randomly initialize the weights
	rbm.lock.Lock()
	rbm.lock.weights = util.NewMatrix(hiddenDim, visibleDim)
	oldWeights := util.NewMatrix(hiddenDim, visibleDim)
	batchDerivative := util.NewMatrix(hiddenDim, visibleDim)
	for i := 0; i < hiddenDim; i++ {
		for j := 0; j < visibleDim; j++ {
			value := (rand.Float64()*2 - 1) * 0.01
			rbm.lock.weights.Set(i, j, value)
		}
	}
	rbm.lock.Unlock()

	// Start the worker goroutines
	ch := make(chan *data.Instance, rbm.options.Worker)
	out := make(chan *util.Matrix, rbm.options.Worker)
	for iWorker := 0; iWorker < rbm.options.Worker; iWorker++ {
		go rbm.derivativeWorker(ch, out, visibleDim, hiddenDim)
	}

	// Iterate until MaxIter is reached or the relative weight change drops
	// below Delta; a zero option value disables the corresponding check
	iteration := 0
	delta := 1.0
	for (rbm.options.MaxIter == 0 || iteration < rbm.options.MaxIter) &&
		(rbm.options.Delta == 0 || delta > rbm.options.Delta) {
		iteration++

		go rbm.feeder(set, ch)
		iBatch := 0
		batchDerivative.Clear()
		numInstances := set.NumInstances()
		for it := 0; it < numInstances; it++ {
			// Derivatives arrive from the workers out of order
			derivative := <-out
			batchDerivative.Increment(derivative, rbm.options.LearningRate)
			iBatch++

			if iBatch == rbm.options.BatchSize || it == numInstances-1 {
				rbm.lock.Lock()
				rbm.lock.weights.Increment(batchDerivative, 1.0)
				rbm.lock.Unlock()
				iBatch = 0
				batchDerivative.Clear()
			}
		}

		// Compute delta and |weight| for the convergence check
		rbm.lock.RLock()
		weightsNorm := rbm.lock.weights.Norm()
		// Reuse batchDerivative as scratch space holding weights - oldWeights
		batchDerivative.DeepCopy(rbm.lock.weights)
		batchDerivative.Increment(oldWeights, -1.0)
		derivativeNorm := batchDerivative.Norm()
		delta = derivativeNorm / weightsNorm
		log.Printf("iter = %d, delta = %f, |weight| = %f",
			iteration, delta, weightsNorm)
		oldWeights.DeepCopy(rbm.lock.weights)
		rbm.lock.RUnlock()
	}
}
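
The loop above combines a feeder goroutine that streams instances, a pool of derivative workers, mini-batch accumulation, and a relative convergence test delta = |W - W_old| / |W|. Below is a minimal self-contained sketch of that same pattern, not mlf code and with all names hypothetical, where the weight matrix is reduced to a single scalar for brevity; the RWMutex plays the role of rbm.lock.

package main

import (
	"fmt"
	"math"
	"sync"
)

// model guards its weight with an RWMutex, mirroring rbm.lock.
type model struct {
	mu     sync.RWMutex
	weight float64
}

func main() {
	const workers, numInstances, batchSize = 4, 100, 10
	const learningRate, deltaThreshold = 0.05, 1e-6

	m := &model{}
	in := make(chan float64, workers)
	out := make(chan float64, workers)
	for w := 0; w < workers; w++ {
		go func() {
			// Worker: per-instance derivative pulling weight toward the instance
			for x := range in {
				m.mu.RLock()
				g := x - m.weight
				m.mu.RUnlock()
				out <- g
			}
		}()
	}

	oldWeight, delta := 0.0, 1.0
	for iter := 1; delta > deltaThreshold; iter++ {
		go func() {
			// Feeder: stream the "dataset" (every instance equals 1.0)
			for i := 0; i < numInstances; i++ {
				in <- 1.0
			}
		}()
		batch, count := 0.0, 0
		for i := 0; i < numInstances; i++ {
			batch += learningRate * <-out // results arrive out of order
			count++
			if count == batchSize || i == numInstances-1 {
				m.mu.Lock()
				m.weight += batch // apply the accumulated mini-batch step
				m.mu.Unlock()
				batch, count = 0.0, 0
			}
		}
		m.mu.RLock()
		w := m.weight
		m.mu.RUnlock()
		delta = math.Abs(w-oldWeight) / math.Abs(w)
		oldWeight = w
		fmt.Printf("iter = %d, delta = %f, weight = %f\n", iter, delta, w)
	}
}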