コード例 #1
0
ファイル: online_sgd.go プロジェクト: reginald1787/mlf
// 从options中创建训练器
func NewOnlineSGDClassifier(options OnlineSGDClassifierOptions) *OnlineSGDClassifier {
	classifier := new(OnlineSGDClassifier)
	classifier.options = options
	if classifier.options.BatchSize <= 1 {
		classifier.options.BatchSize = 1
	}
	classifier.weights = util.NewSparseMatrix(options.NumLabels - 1)
	classifier.derivative = util.NewSparseMatrix(options.NumLabels - 1)
	classifier.instanceDerivative = util.NewSparseMatrix(options.NumLabels - 1)
	classifier.evaluator = new(FrapEvaluator)
	classifier.evaluator.Init(options.NumInstancesForEvaluation)
	classifier.featureDictionary = dictionary.NewDictionary(1)
	classifier.labelDictionary = dictionary.NewDictionary(0)

	return classifier
}
コード例 #2
0
ファイル: inmem_dataset.go プロジェクト: reginald1787/mlf
// 向数据集中添加一个样本
// 成功添加则返回true,否则返回false
func (set *inmemDataset) AddInstance(instance *Instance) bool {
	set.CheckFinalized(false)

	// 添加第一条样本时确定数据集的一些性质
	if len(set.instances) == 0 {
		if instance.NamedFeatures != nil {
			set.useFeatureDict = true
			set.featureDict = dictionary.NewDictionary(1) // 特征ID从0开始
			ConvertNamedFeatures(instance, set.featureDict)
		}

		if instance.Features.IsSparse() {
			set.options.FeatureIsSparse = true
			set.options.FeatureDimension = 0
		} else {
			set.options.FeatureIsSparse = false
			set.options.FeatureDimension = len(instance.Features.Keys())
		}

		if instance.Output == nil {
			set.options.IsSupervisedLearning = false
		} else {
			set.options.IsSupervisedLearning = true
			if instance.Output.LabelString != "" {
				set.useLabelDict = true
				set.labelDict = dictionary.NewDictionary(0)
				instance.Output.Label =
					set.labelDict.GetIdFromName(instance.Output.LabelString)
			}
		}
	} else {
		// 否则检查后续数据样本类型是否一致
		if instance.NamedFeatures != nil {
			ConvertNamedFeatures(instance, set.featureDict)
			if !set.useFeatureDict {
				log.Print("数据集不使用特征词典而添加的样本使用NamedFeatures")
				return false
			}
		} else {
			if set.useFeatureDict {
				log.Print("数据集使用特征词典而添加的样本不使用NamedFeatures")
				return false
			}
		}

		if set.options.FeatureIsSparse {
			if !instance.Features.IsSparse() {
				log.Print("数据集使用稀疏特征而添加的样本不稀疏")
				return false
			}
		} else {
			if instance.Features.IsSparse() {
				log.Print("数据集使用稠密特征而添加的样本稀疏")
				return false
			}

			if set.options.FeatureDimension != len(instance.Features.Keys()) {
				log.Print("数据集特征数和添加样本的特征数不同")
				return false
			}
		}

		if instance.Output == nil {
			if set.options.IsSupervisedLearning {
				log.Print("数据集为监督式而添加样本为非监督式数据")
				return false
			}
		} else {
			if !set.options.IsSupervisedLearning {
				log.Print("数据集为非监督式而添加样本为监督式数据")
				return false
			}

			if instance.Output.LabelString != "" {
				if !set.useLabelDict {
					log.Print("数据集不使用标注词典而添加的样本使用LabelString")
					return false
				}
			} else {
				if set.useLabelDict {
					log.Print("数据集使用标注词典而添加的样本不使用LabelString")
					return false
				}
			}
		}
	}

	if set.options.IsSupervisedLearning {
		if instance.Output.LabelString != "" {
			instance.Output.Label =
				set.labelDict.GetIdFromName(instance.Output.LabelString)
		}

		if instance.Output.Label < 0 {
			log.Println("样本标注值不在合法范围")
			return false
		}

		if instance.Output.Label >= set.options.NumLabels {
			set.options.NumLabels = instance.Output.Label + 1
		}
	}

	set.instances = append(set.instances, instance)
	return true
}