コード例 #1
0
ファイル: cross_validate.go プロジェクト: reginald1787/mlf
// CrossValidate runs N-fold cross-validation of trainer over set and returns
// the evaluation metrics averaged over all folds.
//
// For each fold index i in [0, folds) two views of set are built with
// data.NewSkipDataset:
//   - a training view using the bucket pattern {false, i}, {true, 1}, {false, folds-1-i};
//   - an evaluation view using the complementary pattern {true, i}, {false, 1}, {true, folds-1-i}.
//
// Presumably the bool in data.SkipBucket marks skipped segments, so training
// excludes fold i and evaluation sees only fold i (TODO confirm against the data package).
//
// A model is trained on the training view and scored on the evaluation view
// with evals.Evaluate. Each returned metric is summed across folds and finally
// divided by folds. If folds <= 0 no fold runs and the result has an empty
// (non-nil) Metrics map.
func CrossValidate(trainer supervised.Trainer, set data.Dataset,
	evals *Evaluators, folds int) (output Evaluation) {
	output.Metrics = make(map[string]float64)

	// splitAround builds the three-segment bucket pattern around fold i.
	// `outer` is used for the segments before and after the fold, and its
	// negation for the single-fold segment itself.
	splitAround := func(i int, outer bool) []data.SkipBucket {
		return []data.SkipBucket{
			{outer, i},
			{!outer, 1},
			{outer, folds - 1 - i},
		}
	}

	for fold := 0; fold < folds; fold++ {
		trainSet := data.NewSkipDataset(set, splitAround(fold, false))
		evalSet := data.NewSkipDataset(set, splitAround(fold, true))

		model := trainer.Train(trainSet)
		result := evals.Evaluate(model, evalSet)

		// Accumulate each metric's per-fold value.
		for name, value := range result.Metrics {
			output.Metrics[name] += value
		}
	}

	// Turn the accumulated sums into per-fold averages.
	n := float64(folds)
	for name := range output.Metrics {
		output.Metrics[name] /= n
	}
	return output
}
コード例 #2
0
ファイル: lbfgs.go プロジェクト: sguzwf/mlf
// OptimizeWeights runs the L-BFGS optimization loop, updating weights in place.
//
// Each iteration does the following:
//   - splits set across worker goroutines, one shard per worker built with
//     data.SkipBucket;
//   - sums the per-instance derivatives produced by derivative_func, each
//     scaled by 1/N where N = set.NumInstances();
//   - adds the regularization term from ComputeRegularization, also scaled by 1/N;
//   - obtains the step direction from opt.GetDeltaX and applies it, scaled by
//     the learning rate from NewLearningRate(opt.options).
//
// The loop stops when either of these holds:
//   - opt.options.MaxIterations is > 0 and has been reached;
//   - the relative weight change |dw|/|w| has stayed below
//     opt.options.ConvergingDeltaWeight for more than opt.options.ConvergingSteps
//     consecutive iterations.
//
// It calls log.Fatal (terminating the process) if the weight norm becomes NaN
// or infinite.
//
// The number of workers comes from the lbfgs_threads flag. Non-positive values
// fall back to runtime.NumCPU().
func (opt *lbfgsOptimizer) OptimizeWeights(
	weights *util.Matrix, derivative_func ComputeInstanceDerivativeFunc, set data.Dataset) {

	// Learning-rate calculator.
	learningRate := NewLearningRate(opt.options)

	// Aggregated derivative vector for the whole dataset.
	derivative := weights.Populate()

	step := 0
	convergingSteps := 0

	// Seed oldWeights with the initial weights so that the first iteration's
	// |dw| is the actual change. Without this, if Populate does not copy values,
	// the first |dw| would be |w| itself.
	oldWeights := weights.Populate()
	oldWeights.DeepCopy(weights)
	weightsDelta := weights.Populate()

	// Per-instance scaling factor. It is loop-invariant, so compute it once
	// instead of once per instance inside every worker.
	instanceScale := 1.0 / float64(set.NumInstances())

	// Worker count. A negative flag value would make the make() calls below
	// panic, so anything non-positive falls back to the CPU count.
	numLbfgsThreads := *lbfgs_threads
	if numLbfgsThreads <= 0 {
		numLbfgsThreads = runtime.NumCPU()
	}

	// Per-worker resources.
	// Worker i sees the pattern {skip i, take 1, skip rest} over set.
	workerSet := make([]data.Dataset, numLbfgsThreads)
	workerDerivative := make([]*util.Matrix, numLbfgsThreads)
	workerInstanceDerivative := make([]*util.Matrix, numLbfgsThreads)
	for iWorker := 0; iWorker < numLbfgsThreads; iWorker++ {
		workerBuckets := []data.SkipBucket{
			{true, iWorker},
			{false, 1},
			{true, numLbfgsThreads - 1 - iWorker},
		}
		workerSet[iWorker] = data.NewSkipDataset(set, workerBuckets)
		workerDerivative[iWorker] = weights.Populate()
		workerInstanceDerivative[iWorker] = weights.Populate()
	}

	// Completion channel. It is drained fully every iteration, so it can be
	// reused instead of being reallocated per step.
	workerChannel := make(chan int, numLbfgsThreads)

	log.Print("开始L-BFGS优化")
	for {
		if opt.options.MaxIterations > 0 && step >= opt.options.MaxIterations {
			break
		}
		step++

		// Launch workers. Each writes only to its own workerDerivative /
		// workerInstanceDerivative slot and only reads weights, which is not
		// modified until all workers have reported back below.
		for iWorker := 0; iWorker < numLbfgsThreads; iWorker++ {
			go func(iw int) {
				workerDerivative[iw].Clear()
				iterator := workerSet[iw].CreateIterator()
				iterator.Start()
				for !iterator.End() {
					instance := iterator.GetInstance()
					derivative_func(weights, instance, workerInstanceDerivative[iw])
					workerDerivative[iw].Increment(workerInstanceDerivative[iw], instanceScale)
					iterator.Next()
				}
				workerChannel <- iw
			}(iWorker)
		}

		// Wait for every worker, then sum their partial derivatives.
		for iWorker := 0; iWorker < numLbfgsThreads; iWorker++ {
			<-workerChannel
		}
		derivative.Clear()
		for _, wd := range workerDerivative {
			derivative.Increment(wd, 1)
		}

		// Add the regularization term, scaled by 1/N like the data term.
		derivative.Increment(ComputeRegularization(weights, opt.options), instanceScale)

		// Step direction for the weights.
		delta := opt.GetDeltaX(weights, derivative)

		// Apply the step, scaled by the learning rate.
		lr := learningRate.ComputeLearningRate(delta)
		weights.Increment(delta, lr)

		weightsDelta.WeightedSum(weights, oldWeights, 1, -1)
		oldWeights.DeepCopy(weights)
		weightsNorm := weights.Norm()
		weightsDeltaNorm := weightsDelta.Norm()

		// Relative change |dw|/|w|. When both norms are zero the weights did not
		// move, so report 0 instead of the NaN that 0/0 yields. A NaN would make
		// the convergence test below permanently false and, with
		// MaxIterations == 0, loop forever.
		relativeDelta := weightsDeltaNorm / weightsNorm
		if weightsNorm == 0 && weightsDeltaNorm == 0 {
			relativeDelta = 0
		}
		log.Printf("#%d |dw|/|w|=%f |w|=%f lr=%1.3g", step, relativeDelta, weightsNorm, lr)

		// Overflow check. An infinite norm is an overflow as well. Left
		// unchecked, it turns the ratio into Inf/Inf = NaN and convergence
		// becomes impossible.
		if math.IsNaN(weightsNorm) || math.IsInf(weightsNorm, 0) {
			log.Fatal("优化失败:不收敛")
		}

		// Convergence: the relative change must stay small for more than
		// ConvergingSteps consecutive iterations.
		if relativeDelta < opt.options.ConvergingDeltaWeight {
			convergingSteps++
			if convergingSteps > opt.options.ConvergingSteps {
				log.Printf("收敛")
				break
			}
		} else {
			convergingSteps = 0
		}
	}
}