func (fft *FastFtrlTrainer) TrainImpl( model_file string, train_file string, line_cnt int, test_file string) error { if !fft.Init { fft.log4fft.Error("[FastFtrlTrainer-TrainImpl] Fast ftrl trainer restore error.") return errors.New("[FastFtrlTrainer-TrainImpl] Fast ftrl trainer restore error.") } fft.log4fft.Info(fmt.Sprintf( "[%s] params={alpha:%.2f, beta:%.2f, l1:%.2f, l2:%.2f, dropout:%.2f, epoch:%d}\n", fft.JobName, fft.ParamServer.Alpha, fft.ParamServer.Beta, fft.ParamServer.L1, fft.ParamServer.L2, fft.ParamServer.Dropout, fft.Epoch)) var solvers []solver.FtrlWorker = make([]solver.FtrlWorker, fft.NumThreads) for i := 0; i < fft.NumThreads; i++ { solvers[i].Initialize(&fft.ParamServer, fft.PusStep, fft.FetchStep) } predict_func := func(x util.Pvector) float64 { return fft.ParamServer.Predict(x) } var timer util.StopWatch timer.StartTimer() for iter := 0; iter < fft.Epoch; iter++ { var file_parser ParallelFileParser file_parser.OpenFile(train_file, fft.NumThreads) count := 0 var loss float64 = 0. var lock sync.Mutex worker_func := func(i int, c *sync.WaitGroup) { local_count := 0 var local_loss float64 = 0 for { flag, y, x := file_parser.ReadSampleMultiThread(i) if flag != nil { break } pred := solvers[i].Update(x, y, &fft.ParamServer) local_loss += calc_loss(y, pred) local_count++ if i == 0 && local_count%10000 == 0 { tmp_cnt := math.Min(float64(local_count*fft.NumThreads), float64(line_cnt)) fft.log4fft.Info(fmt.Sprintf("[%s] epoch=%d processed=[%.2f%%] time=[%.2f] train-loss=[%.6f]\n", fft.JobName, iter, float64(tmp_cnt*100)/float64(line_cnt), timer.StopTimer(), float64(local_loss)/float64(local_count))) } } lock.Lock() count += local_count loss += local_loss lock.Unlock() solvers[i].PushParam(&fft.ParamServer) defer c.Done() } if iter == 0 && util.UtilGreater(fft.BurnIn, float64(0)) { burn_in_cnt := int(fft.BurnIn * float64(line_cnt)) var local_loss float64 = 0 for i := 0; i < burn_in_cnt; i++ { //线程0做预热 flag, y, x := file_parser.ReadSample(0) if flag != nil { break } pred := fft.ParamServer.Update(x, y) local_loss += calc_loss(y, pred) if i%10000 == 0 { fft.log4fft.Info(fmt.Sprintf("[%s] burn-in processed=[%.2f%%] time=[%.2f] train-loss=[%.6f]\n", fft.JobName, float64((i+1)*100)/float64(line_cnt), timer.StopTimer(), float64(local_loss)/float64(i+1))) } } fft.log4fft.Info(fmt.Sprintf("[%s] burn-in processed=[%.2f%%] time=[%.2f] train-loss=[%.6f]\n", fft.JobName, float64(burn_in_cnt*100)/float64(line_cnt), timer.StopTimer(), float64(local_loss)/float64(burn_in_cnt))) if util.UtilFloat64Equal(fft.BurnIn, float64(1)) { continue } } for i := 0; i < fft.NumThreads; i++ { solvers[i].Reset(&fft.ParamServer) } util.UtilParallelRun(worker_func, fft.NumThreads) file_parser.CloseFile(fft.NumThreads) // f(w, // "[%s] epoch=%d processed=[%.2f%%] time=[%.2f] train-loss=[%.6f]\n", // fft.JobName, // iter, // float64(count*100)/float64(line_cnt), // timer.StopTimer(), // float64(loss)/float64(count)) if test_file != "" { eval_loss := evaluate_file(test_file, predict_func, fft.NumThreads) fft.log4fft.Info(fmt.Sprintf("[%s] validation-loss=[%f]\n", fft.JobName, float64(eval_loss))) } } return fft.ParamServer.SaveModel(model_file) }
func Run(argc int, argv []string) (string, error) { var job_name string var test_file string var model_file string var output_file string var threshold float64 log := util.GetLogger() if len(argv) == 5 { job_name = argv[0] test_file = argv[1] model_file = argv[2] output_file = argv[3] threshold, _ = strconv.ParseFloat(argv[4], 64) } else { print_usage(argc, argv) log.Error("[Predictor-Run] Input parameters error.") return fmt.Sprintf(errorjson, "[Predictor-Run] Input parameters error."), errors.New("[Predictor-Run] Input parameters error.") } if len(job_name) == 0 || len(test_file) == 0 || len(model_file) == 0 || len(output_file) == 0 { print_usage(argc, argv) log.Error("[Predictor-Run] Input parameters error.") return fmt.Sprintf(errorjson, "[Predictor-Run] Input parameters error."), errors.New("[Predictor-Run] Input parameters error.") } var model solver.LRModel model.Initialize(model_file) var wfp *os.File var err1 error exist := func(filename string) bool { var exist = true if _, err := os.Stat(filename); os.IsNotExist(err) { exist = false } return exist } if exist(output_file) { wfp, err1 = os.OpenFile(output_file, os.O_SYNC, 0666) } else { wfp, err1 = os.Create(output_file) } if err1 != nil { log.Error("[Predictor-Run] Open file error." + err1.Error()) return fmt.Sprintf(errorjson, err1.Error()), errors.New("[Predictor-Run] Open file error." + err1.Error()) } defer wfp.Close() cnt := 0 //样本总数 pcorrect := 0 //正样本预测正确数 pcnt := 0 //正样本总数 ncorrect := 0 //负样本预测正确数 var loss float64 = 0. var parser trainer.FileParser err := parser.OpenFile(test_file) if err != nil { log.Error("[Predictor-Run] Open file error." + err.Error()) return fmt.Sprintf(errorjson, err.Error()), errors.New("[Predictor-Run] Open file error." + err.Error()) } var pred_scores util.Dvector for { res, y, x := parser.ReadSample() if res != nil { break } pred := model.Predict(x) pred = math.Max(math.Min(pred, 1.-10e-15), 10e-15) wfp.WriteString(fmt.Sprintf("%f\n", pred)) pred_scores = append(pred_scores, util.DPair{pred, y}) cnt++ if util.UtilFloat64Equal(y, 1.0) { pcnt++ } var pred_label float64 = 0 if pred > threshold { pred_label = 1 } if util.UtilFloat64Equal(pred_label, y) { if util.UtilFloat64Equal(y, 1.0) { pcorrect++ } else { ncorrect++ } } pred = math.Max(math.Min(pred, 1.-10e-15), 10e-15) if y > 0 { loss += -math.Log(pred) } else { loss += -math.Log(1. - pred) } } auc := calc_auc(pred_scores) if auc < 0.5 { auc = 0.5 } if cnt > 0 { log.Info(fmt.Sprintf("[%s] Log-likelihood = %f\n", job_name, float64(loss)/float64(cnt))) log.Info(fmt.Sprintf("[%s] Precision = %.2f%% (%d/%d)\n", job_name, float64(pcorrect*100)/float64(cnt-pcnt-ncorrect+pcorrect), pcorrect, cnt-pcnt-ncorrect+pcorrect)) log.Info(fmt.Sprintf("[%s] Recall = %.2f%% (%d/%d)\n", job_name, float64(pcorrect*100)/float64(pcnt), pcorrect, pcnt)) log.Info(fmt.Sprintf("[%s] Accuracy = %.2f%% (%d/%d)\n", job_name, float64((pcorrect+ncorrect)*100)/float64(cnt), (pcorrect + ncorrect), cnt)) log.Info(fmt.Sprintf("[%s] AUC = %f\n", job_name, auc)) } parser.CloseFile() util.Write2File(output_file, fmt.Sprintf(" Log-likelihood = %f\n Precision = %f (%d/%d)\n Recall = %f (%d/%d)\n Accuracy = %f (%d/%d)\n AUC = %f\n", float64(loss)/float64(cnt), float64(pcorrect)/float64(cnt-pcnt-ncorrect+pcorrect), pcorrect, cnt-pcnt-ncorrect+pcorrect, float64(pcorrect)/float64(pcnt), pcorrect, pcnt, float64(pcorrect+ncorrect)/float64(cnt), pcorrect+ncorrect, cnt, auc)) return fmt.Sprintf(returnJson, job_name, fmt.Sprintf("Log-likelihood = %f", float64(loss)/float64(cnt)), fmt.Sprintf("Precision = %f (%d/%d)", float64(pcorrect)/float64(cnt-pcnt-ncorrect+pcorrect), pcorrect, cnt-pcnt-ncorrect+pcorrect), fmt.Sprintf("Recall = %f (%d/%d)", float64(pcorrect)/float64(pcnt), pcorrect, pcnt), fmt.Sprintf("Accuracy = %f (%d/%d)", float64((pcorrect+ncorrect))/float64(cnt), (pcorrect+ncorrect), cnt), fmt.Sprintf("AUC = %f", auc), output_file), nil }
/* * 离线模型请求串格式 * http://127.0.0.1:8080/offline?biz=[model name]&src=[hdfs/local]&dst=[redis&local&json] &alpha=[0.1]&beta=[0.1]&l1=[10]&l2=[10]&dropout=[0.1]&epoch=[2] &push=[push step]&fetch=[fetch step]&threads=[threads number] &train=[train file name]&test=[test file name]&debug=[off]&thd=[threshold] src:训练、测试数据源为hdfs/local dst:模型输出到redis、local和json train:训练数据完整路径 test:测试数据完整路径 */ func (lan *Lands) offlineServeHttp(w http.ResponseWriter, par *util.ModelParam) error { lan.log4goline.Info("[Lands-offlineServeHttp] Begin offline learning...") conn := lan.pool.Get() defer conn.Close() var base_path_off, model_path, train_path, test_path string var predict_path string //建立模型训练本地路径 base_path_off = lan.conf.DataPathBase + par.Biz + "/off/" timestamp := time.Now().Format(TimeFormatString) err := util.Mkdir(base_path_off + "/" + timestamp) if err != nil { lan.log4goline.Error("[Lands-offlineServeHttp] Offline model make local directory error." + err.Error()) return errors.New("[Lands-offlineServeHttp] Offline model make local directory error." + err.Error()) } base_path_on := lan.conf.DataPathBase + par.Biz + "/on/workspace" err = util.Mkdir(base_path_on) if err != nil { lan.log4goline.Error("[Lands-offlineServeHttp] Make online local workspace directory error." + err.Error()) return errors.New("[Lands-offlineServeHttp] Make online local workspace directory error." + err.Error()) } //挂载数据 lan.log4goline.Info("[Lands-offlineServeHttp] Mount training data.") if par.Src == "hdfs" { client, err := lan.createHdfsClient() if err != nil { lan.log4goline.Error("[Lands-offlineServeHttp] Create hdfs client error." + err.Error()) return errors.New("[Lands-offlineServeHttp] Create hdfs client error." + err.Error()) } model_path = base_path_off + "/" + timestamp + "/model.dat" train_path = base_path_off + "/" + timestamp + "/train.dat" test_path = base_path_off + "/" + timestamp + "/test.dat" predict_path = base_path_off + "/" + timestamp + "/predict.dat" client.GetMerge(par.Train, train_path, false) if err != nil { lan.log4goline.Error("[Lands-offlineServeHttp] Getmerge train data from hdfs to local error." + err.Error()) return errors.New("[Lands-offlineServeHttp] Getmerge train data from hdfs to local error." + err.Error()) } client.GetMerge(par.Test, test_path, false) if err != nil { lan.log4goline.Error("[Lands-offlineServeHttp] Getmerge test data from hdfs to local error." + err.Error()) return errors.New("[Lands-offlineServeHttp] Getmerge test data from hdfs to local error." + err.Error()) } } else if par.Src == "local" { model_path = base_path_off + "/" + timestamp + "/model.dat" train_path = base_path_off + "/" + timestamp + "/train.dat" test_path = base_path_off + "/" + timestamp + "/test.dat" predict_path = base_path_off + "/" + timestamp + "/predict.dat" err := util.CopyFile(train_path, par.Train) if err != nil { lan.log4goline.Error("[Lands-offlineServeHttp] Copy train data from local to local error." + err.Error()) return errors.New("[Lands-offlineServeHttp] Copy train data from local to local error." + err.Error()) } err = util.CopyFile(test_path, par.Test) if err != nil { lan.log4goline.Error("[Lands-offlineServeHttp] Copy test data from local to local error." + err.Error()) return errors.New("[Lands-offlineServeHttp] Copy test data from local to local error." + err.Error()) } else { lan.log4goline.Error("[Lands-offlineServeHttp] Training data source path error.") return errors.New("[Lands-offlineServeHttp] Training data source path error.") } } lan.log4goline.Info(fmt.Sprintf("[Lands-offlineServeHttp] Model path=%s,train path=%s, test path=%s\n", model_path, train_path, test_path)) //训练数据格式检查及转换 lan.log4goline.Info("[Lands-offlineServeHttp] Check training data.") _, err = lan.checkData(train_path) if err != nil { lan.log4goline.Error("[Lands-offlineServeHttp] Check train data from local to local error." + err.Error()) return errors.New("[Lands-offlineServeHttp] Check train data from local to local error." + err.Error()) } lan.log4goline.Info("[Lands-offlineServeHttp] Check testing data.") _, err = lan.checkData(test_path) if err != nil { lan.log4goline.Error("[Lands-offlineServeHttp] Check test data from local to local error." + err.Error()) return errors.New("[Lands-offlineServeHttp] Check test data from local to local error." + err.Error()) } //数据抽样 if !util.UtilFloat64Equal(par.Sample, 1.0) && !util.UtilFloat64Equal(par.Sample, -1.0) { lan.log4goline.Info("[Lands-offlineServeHttp] Training data sampling.") err = util.FileSampleWithRatio(train_path, par.Sample) if err != nil { lan.log4goline.Error("[Lands-offlineServeHttp] Train data sampling error." + err.Error()) return errors.New("[Lands-offlineServeHttp] Train data sampling error." + err.Error()) } } //模型训练 lan.log4goline.Info("[Lands-offlineServeHttp] Offline model training.") var fft trainer.FastFtrlTrainer fft.SetJobName(par.Biz + " offline " + timestamp) if !fft.Initialize(par.Epoch, par.Threads, true, 0, par.Push, par.Fetch) { lan.log4goline.Error("[Lands-offlineServeHttp] Initialize ftrl trainer error") return errors.New("[Lands-offlineServeHttp] Initialize ftrl trainer error.") } err = fft.Train(par.Alpha, par.Beta, par.L1, par.L2, par.Dropout, model_path, train_path, test_path) if err != nil { lan.log4goline.Error("[Lands-offlineServeHttp] Training model error." + err.Error()) return errors.New("[Lands-offlineServeHttp] Training model error." + err.Error()) } lan.log4goline.Info("[Lands-offlineServeHttp] Predict testing data.") _, err = predictor.Run(3, []string{par.Biz + " offline " + timestamp, test_path, model_path, predict_path, par.Threshold}) if err != nil { lan.log4goline.Error("[Lands-offlineServeHttp] Predicting model error." + err.Error()) return errors.New("[Lands-offlineServeHttp] Predicting model error." + err.Error()) } err = util.CopyFile(base_path_on+"/model.dat", base_path_off+"/"+timestamp+"/model.dat") if err != nil { lan.log4goline.Error("[Lands-offlineServeHttp] Copy model error." + err.Error()) return errors.New("[Lands-offlineServeHttp] Copy model error." + err.Error()) } //清理目录 lan.log4goline.Info("[Lands-offlineServeHttp] Clear local directory." + base_path_off) err = util.KeepLatestN(base_path_off, 5) if err != nil { lan.log4goline.Warn("[Lands-offlineServeHttp] Clear local file error." + err.Error()) errors.New("[Lands-offlineServeHttp] Clear local file error." + err.Error()) } m, err := fft.ParamServer.SaveEncodeModel() if err != nil { lan.log4goline.Error("[Lands-offlineServeHttp] Save model error." + err.Error()) return errors.New("[Lands-offlineServeHttp] Save model error." + err.Error()) } fmt.Fprintf(w, "%s", m) //模型存入redis lan.log4goline.Info("[Lands-offlineServeHttp] Write model to redis.") if par.Dst == "redis" { key := ModelPrefix + par.Biz conn.Send("SET", key, m) conn.Flush() _, err := conn.Receive() if err != nil { lan.log4goline.Error(fmt.Sprintln("[Lands-offlineServeHttp] Save model to redis error." + err.Error())) return errors.New("[Lands-offlineServeHttp] Save model to redis error." + err.Error()) } } lan.log4goline.Info("[Lands-offlineServeHttp] End offline learning.") return nil }