// 从options中创建训练器 func NewOnlineSGDClassifier(options OnlineSGDClassifierOptions) *OnlineSGDClassifier { classifier := new(OnlineSGDClassifier) classifier.options = options if classifier.options.BatchSize <= 1 { classifier.options.BatchSize = 1 } classifier.weights = util.NewSparseMatrix(options.NumLabels - 1) classifier.derivative = util.NewSparseMatrix(options.NumLabels - 1) classifier.instanceDerivative = util.NewSparseMatrix(options.NumLabels - 1) classifier.evaluator = new(FrapEvaluator) classifier.evaluator.Init(options.NumInstancesForEvaluation) classifier.featureDictionary = dictionary.NewDictionary(1) classifier.labelDictionary = dictionary.NewDictionary(0) return classifier }
// 向数据集中添加一个样本 // 成功添加则返回true,否则返回false func (set *inmemDataset) AddInstance(instance *Instance) bool { set.CheckFinalized(false) // 添加第一条样本时确定数据集的一些性质 if len(set.instances) == 0 { if instance.NamedFeatures != nil { set.useFeatureDict = true set.featureDict = dictionary.NewDictionary(1) // 特征ID从0开始 ConvertNamedFeatures(instance, set.featureDict) } if instance.Features.IsSparse() { set.options.FeatureIsSparse = true set.options.FeatureDimension = 0 } else { set.options.FeatureIsSparse = false set.options.FeatureDimension = len(instance.Features.Keys()) } if instance.Output == nil { set.options.IsSupervisedLearning = false } else { set.options.IsSupervisedLearning = true if instance.Output.LabelString != "" { set.useLabelDict = true set.labelDict = dictionary.NewDictionary(0) instance.Output.Label = set.labelDict.GetIdFromName(instance.Output.LabelString) } } } else { // 否则检查后续数据样本类型是否一致 if instance.NamedFeatures != nil { ConvertNamedFeatures(instance, set.featureDict) if !set.useFeatureDict { log.Print("数据集不使用特征词典而添加的样本使用NamedFeatures") return false } } else { if set.useFeatureDict { log.Print("数据集使用特征词典而添加的样本不使用NamedFeatures") return false } } if set.options.FeatureIsSparse { if !instance.Features.IsSparse() { log.Print("数据集使用稀疏特征而添加的样本不稀疏") return false } } else { if instance.Features.IsSparse() { log.Print("数据集使用稠密特征而添加的样本稀疏") return false } if set.options.FeatureDimension != len(instance.Features.Keys()) { log.Print("数据集特征数和添加样本的特征数不同") return false } } if instance.Output == nil { if set.options.IsSupervisedLearning { log.Print("数据集为监督式而添加样本为非监督式数据") return false } } else { if !set.options.IsSupervisedLearning { log.Print("数据集为非监督式而添加样本为监督式数据") return false } if instance.Output.LabelString != "" { if !set.useLabelDict { log.Print("数据集不使用标注词典而添加的样本使用LabelString") return false } } else { if set.useLabelDict { log.Print("数据集使用标注词典而添加的样本不使用LabelString") return false } } } } if set.options.IsSupervisedLearning { if instance.Output.LabelString != "" { instance.Output.Label = set.labelDict.GetIdFromName(instance.Output.LabelString) } if instance.Output.Label < 0 { log.Println("样本标注值不在合法范围") return false } if instance.Output.Label >= set.options.NumLabels { set.options.NumLabels = instance.Output.Label + 1 } } set.instances = append(set.instances, instance) return true }