// LoadLibSVMDataset reads a libsvm-format file at path and returns an
// in-memory dataset. When usingSparseRepresentation is true, features are
// stored in Instance.NamedFeatures; otherwise they go into a dense feature
// vector whose index 0 holds the constant (bias) term.
func LoadLibSVMDataset(path string, usingSparseRepresentation bool) data.Dataset {
	log.Print("loading libsvm-format file ", path)

	content, err := ioutil.ReadFile(path)
	if err != nil {
		log.Fatalf("failed to open file \"%v\", error: %v\n", path, err)
	}
	lines := strings.Split(string(content), "\n")

	// First pass: collect the label set and the min/max feature IDs.
	minFeature := 10000
	maxFeature := 0
	labels := make(map[string]int)
	labelIndex := 0
	for _, l := range lines {
		if l == "" {
			continue
		}
		fields := strings.Split(l, " ")
		_, ok := labels[fields[0]]
		if !ok {
			labels[fields[0]] = labelIndex
			labelIndex++
		}
		for i := 1; i < len(fields); i++ {
			if fields[i] == "" {
				continue
			}
			fs := strings.Split(fields[i], ":")
			fid, _ := strconv.Atoi(fs[0])
			if fid > maxFeature {
				maxFeature = fid
			}
			if fid < minFeature {
				minFeature = fid
			}
		}
	}

	if minFeature == 0 || maxFeature < 2 {
		log.Fatal("invalid input file format")
	}
	log.Printf("number of features: %d", maxFeature)
	log.Printf("number of labels: %d", len(labels))

	// Second pass: build one Instance per non-empty line.
	set := data.NewInmemDataset()
	for _, l := range lines {
		if l == "" {
			continue
		}
		fields := strings.Split(l, " ")
		instance := new(data.Instance)
		instance.Output = &data.InstanceOutput{
			Label:       labels[fields[0]],
			LabelString: fields[0],
		}
		if usingSparseRepresentation {
			instance.NamedFeatures = make(map[string]float64)
		} else {
			instance.Features = util.NewVector(maxFeature + 1)
		}

		// Constant (bias) term occupies index 0 of the dense vector.
		if !usingSparseRepresentation {
			instance.Features.Set(0, 1)
		}

		for i := 1; i < len(fields); i++ {
			if fields[i] == "" {
				continue
			}
			fs := strings.Split(fields[i], ":")
			fid, _ := strconv.Atoi(fs[0])
			value, _ := strconv.ParseFloat(fs[1], 64)
			if usingSparseRepresentation {
				instance.NamedFeatures[fs[0]] = value
			} else {
				instance.Features.Set(fid, value)
			}
		}
		set.AddInstance(instance)
	}
	set.Finalize()

	log.Print("number of instances loaded: ", set.NumInstances())
	return set
}
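// Usage note (a minimal sketch; the file name "train.libsvm" and the sample
// lines below are hypothetical): each non-empty line of a libsvm-format file
// is "<label> <featureID>:<value> ...", with feature IDs starting at 1, since
// index 0 of the dense feature vector is reserved for the constant term, e.g.
//
//	1 1:0.5 3:1.2 4:0.3
//	0 2:1.0 3:0.8
//
//	// dense representation: util.Vector of length maxFeature+1
//	denseSet := LoadLibSVMDataset("train.libsvm", false)
//
//	// sparse representation: map[string]float64 keyed by feature ID
//	sparseSet := LoadLibSVMDataset("train.libsvm", true)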
func TestTrainWithNamedFeatures(t *testing.T) {
	// Build a small four-instance dataset using sparse (named) features.
	set := data.NewInmemDataset()

	instance1 := new(data.Instance)
	instance1.NamedFeatures = map[string]float64{
		"1": 1,
		"2": 1,
		"3": 1,
		"4": 3,
	}
	instance1.Output = &data.InstanceOutput{Label: 0}
	set.AddInstance(instance1)

	instance2 := new(data.Instance)
	instance2.NamedFeatures = map[string]float64{
		"1": 1,
		"2": 3,
		"3": 1,
		"4": 5,
	}
	instance2.Output = &data.InstanceOutput{Label: 0}
	set.AddInstance(instance2)

	instance3 := new(data.Instance)
	instance3.NamedFeatures = map[string]float64{
		"1": 1,
		"2": 3,
		"3": 4,
		"4": 7,
	}
	instance3.Output = &data.InstanceOutput{Label: 1}
	set.AddInstance(instance3)

	instance4 := new(data.Instance)
	instance4.NamedFeatures = map[string]float64{
		"1": 1,
		"2": 2,
		"3": 8,
		"4": 6,
	}
	instance4.Output = &data.InstanceOutput{Label: 1}
	set.AddInstance(instance4)

	set.Finalize()

	// Train the same dataset with gradient descent and with LBFGS.
	gdTrainerOptions := TrainerOptions{
		Optimizer: optimizer.OptimizerOptions{
			OptimizerName:         "gd",
			RegularizationScheme:  2,
			RegularizationFactor:  1,
			LearningRate:          0.1,
			ConvergingDeltaWeight: 1e-6,
			ConvergingSteps:       3,
			MaxIterations:         0,
			GDBatchSize:           0, // full-batch gradient descent
		},
	}
	gdTrainer := NewMaxEntClassifierTrainer(gdTrainerOptions)

	lbfgsTrainerOptions := TrainerOptions{
		Optimizer: optimizer.OptimizerOptions{
			OptimizerName:         "lbfgs",
			RegularizationScheme:  2,
			RegularizationFactor:  1,
			LearningRate:          1,
			ConvergingDeltaWeight: 1e-6,
			ConvergingSteps:       3,
			MaxIterations:         0,
		},
	}
	lbfgsTrainer := NewMaxEntClassifierTrainer(lbfgsTrainerOptions)
	lbfgsTrainer.Train(set)

	// Write the GD-trained model to disk, reload it, and check predictions
	// on the training instances.
	gdTrainer.Train(set).Write("test.mlf")
	model := LoadModel("test.mlf")

	util.Expect(t, "0", model.Predict(instance1).Label)
	util.Expect(t, "0", model.Predict(instance2).Label)
	util.Expect(t, "1", model.Predict(instance3).Label)
	util.Expect(t, "1", model.Predict(instance4).Label)
}