Exemplo n.º 1
0
Arquivo: rbm.go Projeto: sguzwf/mlf
func main() {
	flag.Parse()

	// Let the Go scheduler and the RBM trainer both use every logical CPU.
	cores := runtime.NumCPU()
	runtime.GOMAXPROCS(cores)

	// Load the training set from the LibSVM file. The second argument is
	// false here; presumably it selects a non-sparse representation — confirm
	// against contrib.LoadLibSVMDataset.
	trainingSet := contrib.LoadLibSVMDataset(*libsvm_file, false)

	// Collect the RBM hyper-parameters from the command-line flags.
	var options rbm.RBMOptions
	options.NumHiddenUnits = *hidden
	options.NumCD = *numCD
	options.Worker = cores
	options.LearningRate = *learning_rate
	options.MaxIter = *maxIter
	options.BatchSize = *batch_size
	options.Delta = *delta
	options.UseBinaryHiddenUnits = *useBinary

	// Build the machine, fit it to the data, and save it to the model path.
	machine := rbm.NewRBM(options)
	machine.Train(trainingSet)
	machine.Write(*model)
}
Exemplo n.º 2
0
func main() {
	flag.Parse()
	runtime.GOMAXPROCS(runtime.NumCPU())

	// Load the dataset to be transformed.
	set := contrib.LoadLibSVMDataset(*libsvm_file, false)

	// Load the previously trained RBM from the model file.
	machine := rbm.LoadRBM(*model)

	visibleDim := set.GetOptions().FeatureDimension
	// The hidden vector has NumHiddenUnits+1 entries; entry 0 is skipped
	// below (presumably the bias unit — confirm against the rbm package).
	hiddenDim := machine.GetOptions().NumHiddenUnits + 1

	// Visible features are emitted with 1-based indices 1..visibleDim. In
	// append mode, hidden unit i (i >= 1) must therefore be numbered
	// visibleDim+i so it starts strictly after the last visible index; the
	// old visibleDim+i-1 made hidden unit 1 collide with visible feature
	// visibleDim. Without -append, hidden unit i is numbered i.
	hiddenOffset := 0
	if *append {
		hiddenOffset = visibleDim
	}

	// Emit one LibSVM-style line per instance:
	// "<label> [visible features...] <hidden units...>".
	iter := set.CreateIterator()
	iter.Start()
	for !iter.End() {
		instance := iter.GetInstance()
		v := util.NewVector(visibleDim)

		content := fmt.Sprintf("%s", instance.Output.LabelString)

		// Copy the instance's features into the visible vector and, in
		// append mode, echo the non-zero ones.
		for i := 0; i < visibleDim; i++ {
			value := instance.Features.Get(i)
			v.Set(i, value)
			if value != 0.0 && *append {
				// NOTE(review): int() truncates non-integer feature values.
				content = fmt.Sprintf("%s %d:%d", content, i+1, int(value))
			}
		}

		h := machine.SampleHidden(v, *numCD, *useBinary)

		// Emit the non-zero hidden activations (entry 0 skipped).
		for i := 1; i < hiddenDim; i++ {
			value := h.Get(i)
			if value == 0.0 {
				continue
			}
			index := hiddenOffset + i
			if *useBinary {
				content = fmt.Sprintf("%s %d:%d", content, index, int(value))
			} else {
				content = fmt.Sprintf("%s %d:%.3f", content, index, value)
			}
		}

		fmt.Printf("%s\n", content)

		iter.Next()
	}
}
Exemplo n.º 3
0
func main() {
	flag.Parse()

	set := contrib.LoadLibSVMDataset(*input, true)
	iterator := set.CreateIterator()

	// The target endpoint is loop-invariant; build it once.
	endpoint := "http://" + *server + "/train"
	// NOTE(review): no Timeout is configured, so a stalled server blocks this
	// loop indefinitely; consider setting http.Client.Timeout.
	client := &http.Client{}

	// Replay the whole dataset against the server forever; this loop never
	// exits on its own.
	for {
		iterator.Start()
		for !iterator.End() {
			instance := iterator.GetInstance()

			// Outside "train" mode the expected output is stripped before
			// sending.
			if *mode != "train" {
				instance.Output = nil
			}

			postInstance(client, endpoint, instance)

			iterator.Next()
		}
	}
}

// postInstance JSON-encodes instance and POSTs it to endpoint with client.
// Every failure is logged and the instance is skipped, so one bad request
// does not stop the sender. On success the response body is drained and
// closed so the client can reuse the connection.
func postInstance(client *http.Client, endpoint string, instance interface{}) {
	httpBody, err := json.Marshal(instance)
	if err != nil {
		// Skip the instance instead of POSTing an empty body.
		log.Print("无法JSON串行化样本, err=", err)
		return
	}

	req, err := http.NewRequest("POST", endpoint, bytes.NewReader(httpBody))
	if err != nil {
		// A malformed endpoint (e.g. bad -server value) yields a nil req.
		log.Print("无法创建http请求, err=", err)
		return
	}
	req.Header.Set("Content-Type", "application/json")

	res, err := client.Do(req)
	if err != nil {
		// res is nil when Do fails; it must not be dereferenced.
		log.Print("http请求失败, err=", err)
		return
	}
	defer res.Body.Close()

	io.Copy(ioutil.Discard, res.Body)
	if res.StatusCode < 200 || res.StatusCode > 299 {
		log.Print("http请求返回错误状态: ", res.Status)
	}
}
Exemplo n.º 4
0
func main() {
	flag.Parse()
	runtime.GOMAXPROCS(runtime.NumCPU())

	// Load the training set.
	set := contrib.LoadLibSVMDataset(*libsvm_file, false)

	// Configure the trainer from the command-line flags.
	trainerOptions := supervised.TrainerOptions{
		Optimizer: optimizer.OptimizerOptions{
			OptimizerName:         *opt,
			RegularizationScheme:  *reg,
			RegularizationFactor:  *reg_factor,
			LearningRate:          *learning_rate,
			CharacteristicTime:    *characteristic_time,
			ConvergingDeltaWeight: *delta,
			ConvergingSteps:       3, // hard-coded; not exposed as a flag
			MaxIterations:         *max_iter,
			GDBatchSize:           *batch_size,
		}}

	// Create the maximum-entropy classifier trainer.
	trainer := supervised.NewMaxEntClassifierTrainer(trainerOptions)

	// Optionally record a CPU profile for the rest of the run.
	if *cpuprofile != "" {
		f, err := os.Create(*cpuprofile)
		if err != nil {
			log.Fatal(err)
		}
		// Deferred calls run LIFO: StopCPUProfile (registered below) flushes
		// the profile before this Close runs.
		defer f.Close()
		if err := pprof.StartCPUProfile(f); err != nil {
			log.Fatal(err)
		}
		defer pprof.StopCPUProfile()
	}

	// With -folds set, only run k-fold cross-validation and report it.
	evaluators := eval.NewEvaluators([]eval.Evaluator{
		&eval.PREvaluator{}, &eval.AccuracyEvaluator{}})
	if *folds != 0 {
		result := eval.CrossValidate(trainer, set, evaluators, *folds)
		log.Print(*folds, "-folds 交叉评价:")
		log.Printf("精度   =  %.2f %%", result.Metrics["precision"]*100)
		log.Printf("召回率 =  %.2f %%", result.Metrics["recall"]*100)
		log.Printf("F1     =  %.2f %%", result.Metrics["fscore"]*100)
		log.Printf("准确度 =  %.2f %%", result.Metrics["accuracy"]*100)
		return
	}

	// Train on the full dataset and save the model.
	model := trainer.Train(set)
	model.Write(*model_file)

	// Optionally evaluate the trained model on a separate test set.
	if *test_file != "" {
		testSet := contrib.LoadLibSVMDataset(*test_file, false)

		result := evaluators.Evaluate(model, testSet)
		log.Print("测试数据集评价:")
		log.Printf("精度   =  %.2f %%", result.Metrics["precision"]*100)
		log.Printf("召回率 =  %.2f %%", result.Metrics["recall"]*100)
		log.Printf("F1     =  %.2f %%", result.Metrics["fscore"]*100)
		log.Printf("准确度 =  %.2f %%", result.Metrics["accuracy"]*100)
	}
}