func evaluate_stream(stream []string, func_predict func(x util.Pvector) float64, num_threads int) float64 { var parser StreamParser parser.Open(stream) count := 0 var loss float64 = 0 var lock sync.Mutex var predict_worker = func(i int, c *sync.WaitGroup) { local_count := 0 var local_loss float64 = 0 for { res, local_y, local_x := parser.ReadSampleMultiThread() if res != nil { break } local_loss += calc_loss(local_y, func_predict(local_x)) local_count++ } lock.Lock() count += local_count loss += local_loss lock.Unlock() defer c.Done() } util.UtilParallelRun(predict_worker, num_threads) parser.Close() if count > 0 { loss = loss / float64(count) } return loss }
// TrainImpl trains the FTRL model over train_file for fft.Epoch passes using
// fft.NumThreads parallel workers, each owning a solver.FtrlWorker that
// periodically synchronizes with the shared parameter server (every
// PusStep/FetchStep updates — see FtrlWorker.Initialize). After each epoch
// the model is optionally evaluated on test_file; the final model is saved
// to model_file.
//
// line_cnt is the pre-computed number of training lines; it is used only for
// progress percentages and for sizing the burn-in pass, not for termination.
func (fft *FastFtrlTrainer) TrainImpl(
	model_file string,
	train_file string,
	line_cnt int,
	test_file string) error {
	if !fft.Init {
		fft.log4fft.Error("[FastFtrlTrainer-TrainImpl] Fast ftrl trainer restore error.")
		return errors.New("[FastFtrlTrainer-TrainImpl] Fast ftrl trainer restore error.")
	}
	fft.log4fft.Info(fmt.Sprintf(
		"[%s] params={alpha:%.2f, beta:%.2f, l1:%.2f, l2:%.2f, dropout:%.2f, epoch:%d}\n",
		fft.JobName,
		fft.ParamServer.Alpha,
		fft.ParamServer.Beta,
		fft.ParamServer.L1,
		fft.ParamServer.L2,
		fft.ParamServer.Dropout,
		fft.Epoch))
	// One local FTRL worker per thread; workers only touch solvers[i] for
	// their own index i, so the slice needs no locking.
	var solvers []solver.FtrlWorker = make([]solver.FtrlWorker, fft.NumThreads)
	for i := 0; i < fft.NumThreads; i++ {
		solvers[i].Initialize(&fft.ParamServer, fft.PusStep, fft.FetchStep)
	}
	predict_func := func(x util.Pvector) float64 {
		return fft.ParamServer.Predict(x)
	}
	var timer util.StopWatch
	timer.StartTimer()
	for iter := 0; iter < fft.Epoch; iter++ {
		var file_parser ParallelFileParser
		file_parser.OpenFile(train_file, fft.NumThreads)
		count := 0
		var loss float64 = 0.
		var lock sync.Mutex
		worker_func := func(i int, c *sync.WaitGroup) {
			local_count := 0
			var local_loss float64 = 0
			for {
				flag, y, x := file_parser.ReadSampleMultiThread(i)
				if flag != nil {
					// Parser exhausted (or read error); this worker is done.
					break
				}
				pred := solvers[i].Update(x, y, &fft.ParamServer)
				local_loss += calc_loss(y, pred)
				local_count++
				// Worker 0 logs an estimated progress line every 10k samples;
				// local_count*NumThreads approximates the global count.
				if i == 0 && local_count%10000 == 0 {
					tmp_cnt := math.Min(float64(local_count*fft.NumThreads), float64(line_cnt))
					fft.log4fft.Info(fmt.Sprintf("[%s] epoch=%d processed=[%.2f%%] time=[%.2f] train-loss=[%.6f]\n",
						fft.JobName,
						iter,
						float64(tmp_cnt*100)/float64(line_cnt),
						timer.StopTimer(),
						float64(local_loss)/float64(local_count)))
				}
			}
			// Merge worker-local tallies into the shared epoch totals.
			lock.Lock()
			count += local_count
			loss += local_loss
			lock.Unlock()
			// Flush this worker's accumulated weights to the parameter server.
			solvers[i].PushParam(&fft.ParamServer)
			// NOTE(review): deferring at the tail still runs Done() at closure
			// exit; placing `defer c.Done()` first would be conventional.
			defer c.Done()
		}
		// First epoch only: optionally warm the parameter server single-threaded
		// on the first BurnIn fraction of the file before parallel training.
		if iter == 0 && util.UtilGreater(fft.BurnIn, float64(0)) {
			burn_in_cnt := int(fft.BurnIn * float64(line_cnt))
			var local_loss float64 = 0
			for i := 0; i < burn_in_cnt; i++ {
				// thread 0 performs the burn-in pass
				flag, y, x := file_parser.ReadSample(0)
				if flag != nil {
					break
				}
				pred := fft.ParamServer.Update(x, y)
				local_loss += calc_loss(y, pred)
				if i%10000 == 0 {
					fft.log4fft.Info(fmt.Sprintf("[%s] burn-in processed=[%.2f%%] time=[%.2f] train-loss=[%.6f]\n",
						fft.JobName,
						float64((i+1)*100)/float64(line_cnt),
						timer.StopTimer(),
						float64(local_loss)/float64(i+1)))
				}
			}
			// NOTE(review): if burn_in_cnt is 0 this logs NaN (0/0); harmless
			// but worth guarding — confirm with callers of BurnIn.
			fft.log4fft.Info(fmt.Sprintf("[%s] burn-in processed=[%.2f%%] time=[%.2f] train-loss=[%.6f]\n",
				fft.JobName,
				float64(burn_in_cnt*100)/float64(line_cnt),
				timer.StopTimer(),
				float64(local_loss)/float64(burn_in_cnt)))
			// BurnIn == 1 means the whole file was consumed by burn-in:
			// skip the parallel pass for this epoch.
			if util.UtilFloat64Equal(fft.BurnIn, float64(1)) {
				continue
			}
		}
		// Re-sync every local worker with the (possibly burned-in) server
		// before the parallel pass.
		for i := 0; i < fft.NumThreads; i++ {
			solvers[i].Reset(&fft.ParamServer)
		}
		util.UtilParallelRun(worker_func, fft.NumThreads)
		file_parser.CloseFile(fft.NumThreads)
		// Epoch-summary logging was disabled; kept for reference:
		// f(w,
		// "[%s] epoch=%d processed=[%.2f%%] time=[%.2f] train-loss=[%.6f]\n",
		// fft.JobName,
		// iter,
		// float64(count*100)/float64(line_cnt),
		// timer.StopTimer(),
		// float64(loss)/float64(count))
		if test_file != "" {
			eval_loss := evaluate_file(test_file,
				predict_func, fft.NumThreads)
			fft.log4fft.Info(fmt.Sprintf("[%s] validation-loss=[%f]\n",
				fft.JobName, float64(eval_loss)))
		}
	}
	return fft.ParamServer.SaveModel(model_file)
}
func (lft *LockFreeFtrlTrainer) TrainImpl( model_file string, train_file string, line_cnt int, test_file string) error { if !lft.Init { lft.log.Error("[LockFreeFtrlTrainer-TrainImpl] Fast ftrl trainer restore error.") return errors.New("[LockFreeFtrlTrainer-TrainImpl] Fast ftrl trainer restore error.") } lft.log.Info(fmt.Sprintf("[%s] params={alpha:%.2f, beta:%.2f, l1:%.2f, l2:%.2f, dropout:%.2f, epoch:%d}\n", lft.JobName, lft.Solver.Alpha, lft.Solver.Beta, lft.Solver.L1, lft.Solver.L2, lft.Solver.Dropout, lft.Epoch)) predict_func := func(x util.Pvector) float64 { return lft.Solver.Predict(x) } var timer util.StopWatch timer.StartTimer() for iter := 0; iter < lft.Epoch; iter++ { var file_parser FileParser file_parser.OpenFile(train_file) count := 0 var loss float64 = 0 var lock sync.Mutex worker_func := func(i int, c *sync.WaitGroup) { local_count := 0 var local_loss float64 = 0 for { flag, y, x := file_parser.ReadSampleMultiThread() if flag != nil { break } pred := lft.Solver.Update(x, y) local_loss += calc_loss(y, pred) local_count++ if i == 0 && local_count%10000 == 0 { tmp_cnt := math.Min(float64(local_count*lft.NumThreads), float64(line_cnt)) lft.log.Info(fmt.Sprintf("[%s] epoch=%d processed=[%.2f%%] time=[%.2f] train-loss=[%.6f]\n", lft.JobName, iter, float64(tmp_cnt*100)/float64(line_cnt), timer.StopTimer(), float64(local_loss)/float64(local_count))) } } lock.Lock() count += local_count loss += local_loss lock.Unlock() defer c.Done() } util.UtilParallelRun(worker_func, lft.NumThreads) file_parser.CloseFile() lft.log.Info(fmt.Sprintf("[%s] epoch=%d processed=[%.2f%%] time=[%.2f] train-loss=[%.6f]\n", lft.JobName, iter, float64(count*100)/float64(line_cnt), timer.StopTimer(), float64(loss)/float64(count))) if test_file != "" { eval_loss := evaluate_file(test_file, predict_func, 0) lft.log.Info(fmt.Sprintf("[%s] validation-loss=[%f]\n", lft.JobName, float64(eval_loss))) } } return lft.Solver.SaveModel(model_file) }
// TrainBatch restores an FTRL solver from its JSON encoding (encodemodel)
// and retrains it in place over the given instances for lft.Epoch passes,
// using lft.NumThreads lock-free workers. Each epoch logs the training loss
// and a validation loss computed over the same instances. Returns an error
// if instances is empty or the encoded model cannot be decoded; the updated
// model stays in lft.Solver (nothing is persisted here).
func (lft *LockFreeFtrlTrainer) TrainBatch(
	encodemodel string,
	instances []string) error {
	line_cnt := len(instances)
	if line_cnt == 0 {
		lft.log.Error("[LockFreeFtrlTrainer-TrainBatch] No model retrained.")
		return errors.New("[LockFreeFtrlTrainer-TrainBatch] No model retrained.")
	}
	// Restore the solver state from its JSON encoding before retraining.
	var fls solver.FtrlSolver
	err := json.Unmarshal([]byte(encodemodel), &fls)
	if err != nil {
		lft.log.Error("[LockFreeFtrlTrainer-TrainBatch]" + err.Error())
		return errors.New("[LockFreeFtrlTrainer-TrainBatch]" + err.Error())
	}
	lft.Solver = fls
	lft.log.Info(fmt.Sprintf("[%s] params={alpha:%.2f, beta:%.2f, l1:%.2f, l2:%.2f, dropout:%.2f, epoch:%d}\n",
		lft.JobName, lft.Solver.Alpha, lft.Solver.Beta, lft.Solver.L1, lft.Solver.L2, lft.Solver.Dropout, lft.Epoch))
	predict_func := func(x util.Pvector) float64 {
		return lft.Solver.Predict(x)
	}
	var timer util.StopWatch
	timer.StartTimer()
	for iter := 0; iter < lft.Epoch; iter++ {
		var stream_parser StreamParser
		stream_parser.Open(instances)
		count := 0
		var loss float64 = 0
		var lock sync.Mutex
		worker_func := func(i int, c *sync.WaitGroup) {
			local_count := 0
			var local_loss float64 = 0
			for {
				flag, y, x := stream_parser.ReadSampleMultiThread()
				if flag != nil {
					// Stream exhausted; this worker is done.
					break
				}
				// Lock-free: all workers update the one shared solver.
				pred := lft.Solver.Update(x, y)
				local_loss += calc_loss(y, pred)
				local_count++
				// Worker 0 logs estimated progress every 10k samples.
				if i == 0 && local_count%10000 == 0 {
					tmp_cnt := math.Min(float64(local_count*lft.NumThreads), float64(line_cnt))
					lft.log.Info(fmt.Sprintf("[%s] epoch=%d processed=[%.2f%%] time=[%.2f] train-loss=[%.6f]\n",
						lft.JobName,
						iter,
						float64(tmp_cnt*100)/float64(line_cnt),
						timer.StopTimer(),
						float64(local_loss)/float64(local_count)))
				}
			}
			// Merge worker-local tallies into the shared epoch totals.
			lock.Lock()
			count += local_count
			loss += local_loss
			lock.Unlock()
			// NOTE(review): deferring at the tail still runs Done() at closure
			// exit; `defer c.Done()` first would be conventional.
			defer c.Done()
		}
		util.UtilParallelRun(worker_func, lft.NumThreads)
		stream_parser.Close()
		lft.log.Info(fmt.Sprintf("[%s] epoch=%d processed=[%.2f%%] time=[%.2f] train-loss=[%.6f]\n",
			lft.JobName,
			iter,
			float64(count*100)/float64(line_cnt),
			timer.StopTimer(),
			float64(loss)/float64(count)))
		// Validation on the training instances; num_threads=0 is passed —
		// presumably UtilParallelRun treats 0 as a default; verify.
		eval_loss :=
			evaluate_stream(instances, predict_func, 0)
		lft.log.Info(fmt.Sprintf("[%s] validation-loss=[%f]\n",
			lft.JobName, float64(eval_loss)))
	}
	return nil
}
func read_problem_info( train_file string, read_cache bool, num_threads int) (int, int, error) { feat_num := 0 line_cnt := 0 log := util.GetLogger() var lock sync.Mutex var parser FileParser var errall error read_from_cache := func(path string) error { fs, err := os.Open(path) defer fs.Close() if err != nil { return err } bfRd := bufio.NewReader(fs) line, err := bfRd.ReadString('\n') if err != nil { return err } var res []string = s.Split(line, " ") if len(res) != 2 { log.Error("[read_problem_info] File format error.") return errors.New("[read_problem_info] File format error.") } feat_num, errall = strconv.Atoi(res[0]) if errall != nil { log.Error("[read_problem_info] Label format error." + errall.Error()) return errors.New("[read_problem_info] Label format error." + errall.Error()) } line_cnt, errall = strconv.Atoi(res[1]) if errall != nil { log.Error("[read_problem_info] Feature format error." + errall.Error()) return errors.New("[read_problem_info] Feature format error." + errall.Error()) } return nil } exist := func(filename string) bool { var exist = true if _, err := os.Stat(filename); os.IsNotExist(err) { exist = false } return exist } write_to_cache := func(filename string) error { var f *os.File var err1 error if exist(filename) { f, err1 = os.OpenFile(filename, os.O_WRONLY, 0666) } else { f, err1 = os.Create(filename) } if err1 != nil { return err1 } defer f.Close() wireteString := string(feat_num) + " " + string(line_cnt) + "\n" _, err1 = io.WriteString(f, wireteString) if err1 != nil { return err1 } return nil } read_problem_worker := func(i int, c *sync.WaitGroup) { local_max_feat := 0 local_count := 0 for { flag, _, local_x := parser.ReadSampleMultiThread() if flag != nil { break } for i := 0; i < len(local_x); i++ { if local_x[i].Index+1 > local_max_feat { local_max_feat = local_x[i].Index + 1 } } local_count++ } lock.Lock() line_cnt += local_count lock.Unlock() if local_max_feat > feat_num { feat_num = local_max_feat } defer c.Done() } cache_file := 
string(train_file) + ".cache" cache_exists := exist(cache_file) if read_cache && cache_exists { read_from_cache(cache_file) } else { parser.OpenFile(train_file) util.UtilParallelRun(read_problem_worker, num_threads) parser.CloseFile() } log.Info(fmt.Sprintf("[read_problem_info] Instances=[%d] features=[%d]\n", line_cnt, feat_num)) if read_cache && !cache_exists { write_to_cache(cache_file) } return feat_num, line_cnt, nil }