func main() { flag.Parse() runtime.GOMAXPROCS(runtime.NumCPU()) // 载入训练集 set := contrib.LoadLibSVMDataset(*libsvm_file, false) options := rbm.RBMOptions{ NumHiddenUnits: *hidden, NumCD: *numCD, Worker: runtime.NumCPU(), LearningRate: *learning_rate, MaxIter: *maxIter, BatchSize: *batch_size, Delta: *delta, UseBinaryHiddenUnits: *useBinary, } // 创建训练器 machine := rbm.NewRBM(options) machine.Train(set) machine.Write(*model) }
func main() { flag.Parse() runtime.GOMAXPROCS(runtime.NumCPU()) // 载入训练集 set := contrib.LoadLibSVMDataset(*libsvm_file, false) // 创建训练器 machine := rbm.LoadRBM(*model) visibleDim := set.GetOptions().FeatureDimension hiddenDim := machine.GetOptions().NumHiddenUnits + 1 iter := set.CreateIterator() iter.Start() for !iter.End() { instance := iter.GetInstance() v := util.NewVector(visibleDim) content := fmt.Sprintf("%s", instance.Output.LabelString) for i := 0; i < visibleDim; i++ { value := instance.Features.Get(i) v.Set(i, value) if value != 0.0 && *append { content = fmt.Sprintf("%s %d:%d", content, i+1, int(value)) } } h := machine.SampleHidden(v, *numCD, *useBinary) for i := 1; i < hiddenDim; i++ { value := h.Get(i) if value != 0.0 { if *append { if *useBinary { content = fmt.Sprintf("%s %d:%d", content, visibleDim+i-1, int(value)) } else { content = fmt.Sprintf("%s %d:%.3f", content, visibleDim+i-1, value) } } else { if *useBinary { content = fmt.Sprintf("%s %d:%d", content, i, int(value)) } else { content = fmt.Sprintf("%s %d:%.3f", content, i, value) } } } } fmt.Printf("%s\n", content) iter.Next() } }
func main() { flag.Parse() set := contrib.LoadLibSVMDataset(*input, true) iterator := set.CreateIterator() client := &http.Client{} for { iterator.Start() for !iterator.End() { instance := iterator.GetInstance() if *mode != "train" { instance.Output = nil } httpBody, errMarshal := json.Marshal(instance) if errMarshal != nil { log.Print("无法JSON串行化样本") } req, _ := http.NewRequest("POST", "http://"+*server+"/train", bytes.NewReader(httpBody)) req.Header.Set("Content-Type", "application/json") res, err := client.Do(req) io.Copy(ioutil.Discard, res.Body) if err != nil { log.Print("http请求失败, err=", err) } else { res.Body.Close() } iterator.Next() } } }
func main() { flag.Parse() runtime.GOMAXPROCS(runtime.NumCPU()) // 载入训练集 set := contrib.LoadLibSVMDataset(*libsvm_file, false) // 设置训练器参数 trainerOptions := supervised.TrainerOptions{ Optimizer: optimizer.OptimizerOptions{ OptimizerName: *opt, RegularizationScheme: *reg, RegularizationFactor: *reg_factor, LearningRate: *learning_rate, CharacteristicTime: *characteristic_time, ConvergingDeltaWeight: *delta, ConvergingSteps: 3, MaxIterations: *max_iter, GDBatchSize: *batch_size, }} // 创建训练器 trainer := supervised.NewMaxEntClassifierTrainer(trainerOptions) // 打开处理器profile文件 if *cpuprofile != "" { f, err := os.Create(*cpuprofile) if err != nil { log.Fatal(err) } pprof.StartCPUProfile(f) defer pprof.StopCPUProfile() } // 进行交叉评价 evaluators := eval.NewEvaluators([]eval.Evaluator{ &eval.PREvaluator{}, &eval.AccuracyEvaluator{}}) if *folds != 0 { result := eval.CrossValidate(trainer, set, evaluators, *folds) log.Print(*folds, "-folds 交叉评价:") log.Printf("精度 = %.2f %%", result.Metrics["precision"]*100) log.Printf("召回率 = %.2f %%", result.Metrics["recall"]*100) log.Printf("F1 = %.2f %%", result.Metrics["fscore"]*100) log.Printf("准确度 = %.2f %%", result.Metrics["accuracy"]*100) return } // 在全部数据上训练模型 model := trainer.Train(set) model.Write(*model_file) // 测试模型 if *test_file != "" { // 载入测试集 testSet := contrib.LoadLibSVMDataset(*test_file, false) // 在测试集上评价模型并输出结果 result := evaluators.Evaluate(model, testSet) log.Print("测试数据集评价:") log.Printf("精度 = %.2f %%", result.Metrics["precision"]*100) log.Printf("召回率 = %.2f %%", result.Metrics["recall"]*100) log.Printf("F1 = %.2f %%", result.Metrics["fscore"]*100) log.Printf("准确度 = %.2f %%", result.Metrics["accuracy"]*100) } }