func (trainer *MaxEntClassifierTrainer) Train(set data.Dataset) Model { // 检查训练数据是否是分类问题 if !set.GetOptions().IsSupervisedLearning { log.Fatal("训练数据不是分类问题数据") } // 建立新的优化器 optimizer := optimizer.NewOptimizer(trainer.options.Optimizer) // 建立特征权重向量 featureDimension := set.GetOptions().FeatureDimension numLabels := set.GetOptions().NumLabels var weights *util.Matrix if set.GetOptions().FeatureIsSparse { weights = util.NewSparseMatrix(numLabels) } else { weights = util.NewMatrix(numLabels, featureDimension) } // 得到优化的特征权重向量 optimizer.OptimizeWeights(weights, MaxEntComputeInstanceDerivative, set) classifier := new(MaxEntClassifier) classifier.Weights = weights classifier.NumLabels = numLabels classifier.FeatureDimension = featureDimension classifier.FeatureDictionary = set.GetFeatureDictionary() classifier.LabelDictionary = set.GetLabelDictionary() return classifier }
// 从options中创建训练器 func NewOnlineSGDClassifier(options OnlineSGDClassifierOptions) *OnlineSGDClassifier { classifier := new(OnlineSGDClassifier) classifier.options = options if classifier.options.BatchSize <= 1 { classifier.options.BatchSize = 1 } classifier.weights = util.NewSparseMatrix(options.NumLabels - 1) classifier.derivative = util.NewSparseMatrix(options.NumLabels - 1) classifier.instanceDerivative = util.NewSparseMatrix(options.NumLabels - 1) classifier.evaluator = new(FrapEvaluator) classifier.evaluator.Init(options.NumInstancesForEvaluation) classifier.featureDictionary = dictionary.NewDictionary(1) classifier.labelDictionary = dictionary.NewDictionary(0) return classifier }
// 初始化优化结构体 // 为结构体中的向量分配新的内存,向量的长度可能发生变化。 func (opt *lbfgsOptimizer) initStruct(labels, features int, isSparse bool) { opt.labels = labels opt.x = make([]*util.Matrix, *lbfgs_history_size) opt.g = make([]*util.Matrix, *lbfgs_history_size) opt.s = make([]*util.Matrix, *lbfgs_history_size) opt.y = make([]*util.Matrix, *lbfgs_history_size) opt.ro = util.NewVector(*lbfgs_history_size) opt.alpha = util.NewVector(*lbfgs_history_size) opt.beta = util.NewVector(*lbfgs_history_size) if !isSparse { opt.q = util.NewMatrix(labels, features) opt.z = util.NewMatrix(labels, features) for i := 0; i < *lbfgs_history_size; i++ { opt.x[i] = util.NewMatrix(labels, features) opt.g[i] = util.NewMatrix(labels, features) opt.s[i] = util.NewMatrix(labels, features) opt.y[i] = util.NewMatrix(labels, features) } } else { opt.q = util.NewSparseMatrix(labels) opt.z = util.NewSparseMatrix(labels) for i := 0; i < *lbfgs_history_size; i++ { opt.x[i] = util.NewSparseMatrix(labels) opt.g[i] = util.NewSparseMatrix(labels) opt.s[i] = util.NewSparseMatrix(labels) opt.y[i] = util.NewSparseMatrix(labels) } } }