コード例 #1
0
// TrainImpl trains the FTRL model held in fft.ParamServer on train_file for
// fft.Epoch epochs and finally saves it to model_file.
//
// Parameters:
//   - model_file: destination passed to fft.ParamServer.SaveModel.
//   - train_file: training data, opened once per epoch and read by
//     fft.NumThreads workers in parallel.
//   - line_cnt: number of samples in train_file; used for progress
//     percentages and to size the burn-in phase (fft.BurnIn * line_cnt).
//   - test_file: optional validation file; when non-empty it is evaluated
//     after every trained epoch and the validation loss is logged.
//
// It returns an error if the trainer is not initialised (fft.Init is false);
// otherwise it returns the result of SaveModel.
func (fft *FastFtrlTrainer) TrainImpl(
	model_file string,
	train_file string,
	line_cnt int,
	test_file string) error {

	if !fft.Init {
		fft.log4fft.Error("[FastFtrlTrainer-TrainImpl] Fast ftrl trainer restore error.")
		return errors.New("[FastFtrlTrainer-TrainImpl] Fast ftrl trainer restore error.")
	}

	fft.log4fft.Info(fmt.Sprintf(
		"[%s] params={alpha:%.2f, beta:%.2f, l1:%.2f, l2:%.2f, dropout:%.2f, epoch:%d}\n",
		fft.JobName,
		fft.ParamServer.Alpha,
		fft.ParamServer.Beta,
		fft.ParamServer.L1,
		fft.ParamServer.L2,
		fft.ParamServer.Dropout,
		fft.Epoch))

	// One worker per thread, all bound to the shared parameter server.
	// NOTE(review): PusStep/FetchStep presumably control how often a worker
	// pushes to / fetches from the server — confirm in solver.FtrlWorker.
	solvers := make([]solver.FtrlWorker, fft.NumThreads)
	for i := range solvers {
		solvers[i].Initialize(&fft.ParamServer, fft.PusStep, fft.FetchStep)
	}

	predict_func := func(x util.Pvector) float64 {
		return fft.ParamServer.Predict(x)
	}

	var timer util.StopWatch
	timer.StartTimer()
	for iter := 0; iter < fft.Epoch; iter++ {
		var file_parser ParallelFileParser
		file_parser.OpenFile(train_file, fft.NumThreads)

		// Epoch totals; every worker adds its local sums under lock.
		count := 0
		var loss float64
		var lock sync.Mutex

		// worker_func trains on worker i's share of the file. Only worker 0
		// logs progress, extrapolating the total from its own sample count.
		worker_func := func(i int, c *sync.WaitGroup) {
			defer c.Done()

			local_count := 0
			var local_loss float64
			for {
				flag, y, x := file_parser.ReadSampleMultiThread(i)
				if flag != nil {
					// A non-nil flag ends this worker's input.
					break
				}

				pred := solvers[i].Update(x, y, &fft.ParamServer)
				local_loss += calc_loss(y, pred)
				local_count++

				if i == 0 && local_count%10000 == 0 {
					tmp_cnt := math.Min(float64(local_count*fft.NumThreads), float64(line_cnt))
					fft.log4fft.Info(fmt.Sprintf("[%s] epoch=%d processed=[%.2f%%] time=[%.2f] train-loss=[%.6f]\n",
						fft.JobName,
						iter,
						tmp_cnt*100/float64(line_cnt),
						timer.StopTimer(),
						local_loss/float64(local_count)))
				}
			}

			lock.Lock()
			count += local_count
			loss += local_loss
			lock.Unlock()

			// Hand this worker's pending state back to the parameter server.
			solvers[i].PushParam(&fft.ParamServer)
		}

		if iter == 0 && util.UtilGreater(fft.BurnIn, float64(0)) {
			burn_in_cnt := int(fft.BurnIn * float64(line_cnt))
			var burn_loss float64
			burned := 0
			for i := 0; i < burn_in_cnt; i++ {
				// Burn-in is single-threaded: it reads through parser slot 0
				// and updates the parameter server directly.
				flag, y, x := file_parser.ReadSample(0)
				if flag != nil {
					break
				}

				pred := fft.ParamServer.Update(x, y)
				burn_loss += calc_loss(y, pred)
				burned++
				if i%10000 == 0 {
					fft.log4fft.Info(fmt.Sprintf("[%s] burn-in processed=[%.2f%%] time=[%.2f] train-loss=[%.6f]\n",
						fft.JobName,
						float64(burned*100)/float64(line_cnt),
						timer.StopTimer(),
						burn_loss/float64(burned)))
				}
			}

			// Report against the samples actually consumed: the file may end
			// before burn_in_cnt, and burn_in_cnt itself may be 0.
			burn_avg := 0.0
			if burned > 0 {
				burn_avg = burn_loss / float64(burned)
			}
			fft.log4fft.Info(fmt.Sprintf("[%s] burn-in processed=[%.2f%%] time=[%.2f] train-loss=[%.6f]\n",
				fft.JobName,
				float64(burned*100)/float64(line_cnt),
				timer.StopTimer(),
				burn_avg))

			if util.UtilFloat64Equal(fft.BurnIn, float64(1)) {
				// The whole epoch went to burn-in; release the parser before
				// skipping to the next epoch so it is not leaked.
				file_parser.CloseFile(fft.NumThreads)
				continue
			}
		}

		for i := range solvers {
			solvers[i].Reset(&fft.ParamServer)
		}

		util.UtilParallelRun(worker_func, fft.NumThreads)

		file_parser.CloseFile(fft.NumThreads)

		// Epoch summary from the totals the workers accumulated.
		if count > 0 {
			fft.log4fft.Info(fmt.Sprintf("[%s] epoch=%d processed=[%.2f%%] time=[%.2f] train-loss=[%.6f]\n",
				fft.JobName,
				iter,
				float64(count*100)/float64(line_cnt),
				timer.StopTimer(),
				loss/float64(count)))
		}

		if test_file != "" {
			eval_loss := evaluate_file(test_file, predict_func, fft.NumThreads)
			fft.log4fft.Info(fmt.Sprintf("[%s] validation-loss=[%f]\n", fft.JobName, float64(eval_loss)))
		}
	}

	return fft.ParamServer.SaveModel(model_file)
}
コード例 #2
0
ファイル: ftrl_predict.go プロジェクト: vivounicorn/goline
// Run scores a test file with a saved LR model and reports evaluation metrics.
//
// argv must hold exactly five entries:
//   - argv[0]: job name, used as a log prefix and in the returned JSON
//   - argv[1]: test data file
//   - argv[2]: model file loaded into a solver.LRModel
//   - argv[3]: output file; it is truncated (or created) and receives one
//     clamped prediction per line, after which util.Write2File is called on
//     it with the metrics summary
//   - argv[4]: decision threshold; a prediction strictly greater than it is
//     labelled 1. An empty string means 0.
//
// argc is only forwarded to print_usage.
//
// On success it returns returnJson filled with the metrics and output path.
// On failure it returns errorjson filled with a message, plus a non-nil error.
// Ratios whose denominator is zero (no samples, no positives, no predicted
// positives) are reported as 0.
func Run(argc int, argv []string) (string, error) {
	log := util.GetLogger()

	// fail logs msg and returns payload wrapped in errorjson, plus msg as an error.
	fail := func(payload, msg string) (string, error) {
		log.Error(msg)
		return fmt.Sprintf(errorjson, payload), errors.New(msg)
	}

	const param_err = "[Predictor-Run] Input parameters error."
	if len(argv) != 5 {
		print_usage(argc, argv)
		return fail(param_err, param_err)
	}

	job_name, test_file, model_file, output_file := argv[0], argv[1], argv[2], argv[3]
	if job_name == "" || test_file == "" || model_file == "" || output_file == "" {
		print_usage(argc, argv)
		return fail(param_err, param_err)
	}

	// An empty threshold keeps the default of 0. Anything else must parse,
	// rather than silently falling back to 0.
	var threshold float64
	if argv[4] != "" {
		t, err := strconv.ParseFloat(argv[4], 64)
		if err != nil {
			print_usage(argc, argv)
			return fail(param_err, param_err+err.Error())
		}
		threshold = t
	}

	var model solver.LRModel
	// NOTE(review): any result of Initialize is not checked here — confirm it
	// cannot fail silently on a missing/corrupt model file.
	model.Initialize(model_file)

	// Open the output file write-only, creating or truncating it. (Opening
	// with os.O_SYNC alone would be read-only and make every write fail.)
	wfp, err := os.Create(output_file)
	if err != nil {
		return fail(err.Error(), "[Predictor-Run] Open file error."+err.Error())
	}
	defer wfp.Close()

	var parser trainer.FileParser
	if err := parser.OpenFile(test_file); err != nil {
		return fail(err.Error(), "[Predictor-Run] Open file error."+err.Error())
	}
	defer parser.CloseFile()

	var (
		cnt         int // total samples
		pcnt        int // samples labelled positive (y == 1)
		pcorrect    int // positives predicted correctly
		ncorrect    int // negatives predicted correctly
		loss        float64
		pred_scores util.Dvector
	)

	for {
		res, y, x := parser.ReadSample()
		if res != nil {
			break
		}

		// Clamp away from 0 and 1 so the log-loss below stays finite.
		pred := math.Max(math.Min(model.Predict(x), 1.-10e-15), 10e-15)
		if _, err := fmt.Fprintf(wfp, "%f\n", pred); err != nil {
			return fail(err.Error(), "[Predictor-Run] Write file error."+err.Error())
		}

		pred_scores = append(pred_scores, util.DPair{pred, y})

		cnt++
		is_pos := util.UtilFloat64Equal(y, 1.0)
		if is_pos {
			pcnt++
		}

		var pred_label float64
		if pred > threshold {
			pred_label = 1
		}
		if util.UtilFloat64Equal(pred_label, y) {
			if is_pos {
				pcorrect++
			} else {
				ncorrect++
			}
		}

		if y > 0 {
			loss -= math.Log(pred)
		} else {
			loss -= math.Log(1. - pred)
		}
	}

	// AUC is floored at 0.5.
	auc := calc_auc(pred_scores)
	if auc < 0.5 {
		auc = 0.5
	}

	ratio := func(num, den int) float64 {
		if den == 0 {
			return 0
		}
		return float64(num) / float64(den)
	}

	// Predicted positives = true positives + false positives, where
	// false positives = (cnt - pcnt) negatives minus the ncorrect ones.
	pred_pos := cnt - pcnt - ncorrect + pcorrect
	avg_loss := 0.0
	if cnt > 0 {
		avg_loss = loss / float64(cnt)
	}
	precision := ratio(pcorrect, pred_pos)
	recall := ratio(pcorrect, pcnt)
	accuracy := ratio(pcorrect+ncorrect, cnt)

	if cnt > 0 {
		log.Info(fmt.Sprintf("[%s] Log-likelihood = %f\n", job_name, avg_loss))
		log.Info(fmt.Sprintf("[%s] Precision = %.2f%% (%d/%d)\n", job_name,
			precision*100, pcorrect, pred_pos))
		log.Info(fmt.Sprintf("[%s] Recall = %.2f%% (%d/%d)\n", job_name,
			recall*100, pcorrect, pcnt))
		log.Info(fmt.Sprintf("[%s] Accuracy = %.2f%% (%d/%d)\n", job_name,
			accuracy*100, pcorrect+ncorrect, cnt))
		log.Info(fmt.Sprintf("[%s] AUC = %f\n", job_name, auc))
	}

	util.Write2File(output_file, fmt.Sprintf(" Log-likelihood = %f\n Precision = %f (%d/%d)\n Recall = %f (%d/%d)\n Accuracy = %f (%d/%d)\n AUC = %f\n",
		avg_loss,
		precision, pcorrect, pred_pos,
		recall, pcorrect, pcnt,
		accuracy, pcorrect+ncorrect, cnt,
		auc))

	return fmt.Sprintf(returnJson,
		job_name,
		fmt.Sprintf("Log-likelihood = %f", avg_loss),
		fmt.Sprintf("Precision = %f (%d/%d)", precision, pcorrect, pred_pos),
		fmt.Sprintf("Recall = %f (%d/%d)", recall, pcorrect, pcnt),
		fmt.Sprintf("Accuracy = %f (%d/%d)", accuracy, pcorrect+ncorrect, cnt),
		fmt.Sprintf("AUC = %f", auc),
		output_file), nil
}
コード例 #3
0
ファイル: lands.go プロジェクト: vivounicorn/goline
// offlineServeHttp runs one offline FTRL training job described by par and
// writes the encoded model to w.
//
// Offline model request format:
//
//	http://127.0.0.1:8080/offline?biz=[model name]&src=[hdfs/local]&dst=[redis&local&json]
//	    &alpha=[0.1]&beta=[0.1]&l1=[10]&l2=[10]&dropout=[0.1]&epoch=[2]
//	    &push=[push step]&fetch=[fetch step]&threads=[threads number]
//	    &train=[train file name]&test=[test file name]&debug=[off]&thd=[threshold]
//
//	src:   source of the training/testing data, "hdfs" or "local"
//	dst:   model output destination(s): redis, local and json
//	train: full path of the training data
//	test:  full path of the testing data
//
// Steps:
//  1. Create <DataPathBase><biz>/off/<timestamp> and <DataPathBase><biz>/on/workspace.
//  2. Fetch the train/test data into the timestamp directory (HDFS getmerge
//     or local copy). Any other src is rejected.
//  3. Validate both files with checkData. Sample the training file by
//     par.Sample unless it is 1 or -1.
//  4. Train (writing model.dat), then score the test data with
//     predictor.Run (writing predict.dat).
//  5. Copy model.dat into the online workspace and prune old offline runs
//     (best effort).
//  6. Write the encoded model to w. When dst is "redis", also SET it under
//     ModelPrefix+biz.
//
// Every failure is logged and returned as an error.
func (lan *Lands) offlineServeHttp(w http.ResponseWriter, par *util.ModelParam) error {
	lan.log4goline.Info("[Lands-offlineServeHttp] Begin offline learning...")

	// fail logs msg as an error and returns it as an error value.
	fail := func(msg string) error {
		lan.log4goline.Error(msg)
		return errors.New(msg)
	}

	// Local working directory for this training run.
	base_path_off := lan.conf.DataPathBase + par.Biz + "/off/"
	timestamp := time.Now().Format(TimeFormatString)
	work_dir := base_path_off + "/" + timestamp
	if err := util.Mkdir(work_dir); err != nil {
		return fail("[Lands-offlineServeHttp] Offline model make local directory error." + err.Error())
	}

	base_path_on := lan.conf.DataPathBase + par.Biz + "/on/workspace"
	if err := util.Mkdir(base_path_on); err != nil {
		return fail("[Lands-offlineServeHttp] Make online local workspace directory error." + err.Error())
	}

	model_path := work_dir + "/model.dat"
	train_path := work_dir + "/train.dat"
	test_path := work_dir + "/test.dat"
	predict_path := work_dir + "/predict.dat"

	// Mount the data.
	lan.log4goline.Info("[Lands-offlineServeHttp] Mount training data.")
	switch par.Src {
	case "hdfs":
		client, err := lan.createHdfsClient()
		if err != nil {
			return fail("[Lands-offlineServeHttp] Create hdfs client error." + err.Error())
		}
		// NOTE(review): assumes GetMerge returns (bool, error), as gowfs
		// FsShell.GetMerge does — confirm against createHdfsClient's type.
		if _, err := client.GetMerge(par.Train, train_path, false); err != nil {
			return fail("[Lands-offlineServeHttp] Getmerge train data from hdfs to local error." + err.Error())
		}
		if _, err := client.GetMerge(par.Test, test_path, false); err != nil {
			return fail("[Lands-offlineServeHttp] Getmerge test data from hdfs to local error." + err.Error())
		}
	case "local":
		if err := util.CopyFile(train_path, par.Train); err != nil {
			return fail("[Lands-offlineServeHttp] Copy train data from local to local error." + err.Error())
		}
		if err := util.CopyFile(test_path, par.Test); err != nil {
			return fail("[Lands-offlineServeHttp] Copy test data from local to local error." + err.Error())
		}
	default:
		return fail("[Lands-offlineServeHttp] Training data source path error.")
	}

	lan.log4goline.Info(fmt.Sprintf("[Lands-offlineServeHttp] Model path=%s,train path=%s, test path=%s\n",
		model_path,
		train_path,
		test_path))

	// Check (and convert) the data format.
	lan.log4goline.Info("[Lands-offlineServeHttp] Check training data.")
	if _, err := lan.checkData(train_path); err != nil {
		return fail("[Lands-offlineServeHttp] Check train data from local to local error." + err.Error())
	}
	lan.log4goline.Info("[Lands-offlineServeHttp] Check testing data.")
	if _, err := lan.checkData(test_path); err != nil {
		return fail("[Lands-offlineServeHttp] Check test data from local to local error." + err.Error())
	}

	// Sample the training data; skipped when par.Sample is 1 or -1.
	if !util.UtilFloat64Equal(par.Sample, 1.0) && !util.UtilFloat64Equal(par.Sample, -1.0) {
		lan.log4goline.Info("[Lands-offlineServeHttp] Training data sampling.")
		if err := util.FileSampleWithRatio(train_path, par.Sample); err != nil {
			return fail("[Lands-offlineServeHttp] Train data sampling error." + err.Error())
		}
	}

	// Train the model.
	lan.log4goline.Info("[Lands-offlineServeHttp] Offline model training.")
	job_name := par.Biz + " offline " + timestamp
	var fft trainer.FastFtrlTrainer
	fft.SetJobName(job_name)
	if !fft.Initialize(par.Epoch, par.Threads, true, 0, par.Push, par.Fetch) {
		return fail("[Lands-offlineServeHttp] Initialize ftrl trainer error.")
	}

	if err := fft.Train(par.Alpha, par.Beta, par.L1, par.L2, par.Dropout, model_path,
		train_path, test_path); err != nil {
		return fail("[Lands-offlineServeHttp] Training model error." + err.Error())
	}

	lan.log4goline.Info("[Lands-offlineServeHttp] Predict testing data.")
	predict_args := []string{job_name, test_path, model_path, predict_path, par.Threshold}
	if _, err := predictor.Run(len(predict_args), predict_args); err != nil {
		return fail("[Lands-offlineServeHttp] Predicting model error." + err.Error())
	}

	// Publish the new model into the online workspace.
	if err := util.CopyFile(base_path_on+"/model.dat", model_path); err != nil {
		return fail("[Lands-offlineServeHttp] Copy model error." + err.Error())
	}

	// Clean up old offline runs. This is best effort: a failure is only logged
	// as a warning. NOTE(review): KeepLatestN presumably keeps the 5 newest
	// entries under base_path_off — confirm in util.
	lan.log4goline.Info("[Lands-offlineServeHttp] Clear local directory." + base_path_off)
	if err := util.KeepLatestN(base_path_off, 5); err != nil {
		lan.log4goline.Warn("[Lands-offlineServeHttp] Clear local file error." + err.Error())
	}

	m, err := fft.ParamServer.SaveEncodeModel()
	if err != nil {
		return fail("[Lands-offlineServeHttp] Save model error." + err.Error())
	}

	fmt.Fprintf(w, "%s", m)

	// Store the model in redis.
	if par.Dst == "redis" {
		lan.log4goline.Info("[Lands-offlineServeHttp] Write model to redis.")
		// Take a pooled connection only now, instead of holding one idle for
		// the whole (potentially long) training run. Do sends, flushes and
		// reads the reply, surfacing an error from any of those steps.
		conn := lan.pool.Get()
		defer conn.Close()
		if _, err := conn.Do("SET", ModelPrefix+par.Biz, m); err != nil {
			return fail("[Lands-offlineServeHttp] Save model to redis error." + err.Error())
		}
	}

	lan.log4goline.Info("[Lands-offlineServeHttp] End offline learning.")
	return nil
}