示例#1
0
文件: test.go 项目: happyYao/gbdt-1
// main is an end-to-end check of the gbdt package:
//   - it trains a model on ./train.data and writes the serialized model to ./gbdt.model;
//   - it reloads the model from that serialized string and scores ./test.data;
//   - it prints the training time, the test load/scoring time, the AUC and the confusion table.
func main() {
	//defer profile.Start(profile.CPUProfile).Stop() //monitor program performance

	// Later steps read gbdt.Conf (e.g. GetTreecount below), so bail out instead
	// of printing a warning and then dereferencing a nil config.
	if gbdt.Conf == nil {
		fmt.Println("nil pointer")
		return
	}
	fmt.Println(gbdt.Conf)

	const (
		modelName         = "./gbdt.model"
		trainPath         = "./train.data" // alt: /opt/tmp/gbdt/data/train.data
		trainSampleNumber = 4584           // alt: 4186052
		testPath          = "./test.data"  // alt: /opt/tmp/gbdt/data/smalltest.data
		testSampleNumber  = 4584           // alt: 837210
	)

	// Training phase: load data, train, serialize and persist the model.
	trainStart := time.Now()
	dataset := &gbdt.DataSet{}
	dataset.LoadDataFromFile(trainPath, trainSampleNumber)
	g := gbdt.NewGBDT()
	g.Init(dataset)
	g.Train(dataset)
	model := g.Save()
	// A failed write is reported but not fatal: evaluation below uses the
	// in-memory serialized model, not the file.
	if err := ioutil.WriteFile(modelName, []byte(model), 0666); err != nil {
		fmt.Println(err)
	}
	fmt.Println("train time:", time.Since(trainStart))

	// Evaluation phase.
	predictStart := time.Now()
	testDataset := &gbdt.DataSet{}
	testDataset.LoadDataFromFile(testPath, testSampleNumber)

	// Reload the model from the string produced by Save above. To evaluate a
	// previously saved model instead, read it back from disk:
	//   data, err := ioutil.ReadFile(modelName)
	//   if err != nil { log.Fatal(err) }
	//   g = &gbdt.GBDT{}
	//   g.Load(string(data))
	g.Load(model)

	samples := testDataset.GetSamples()
	treeCount := gbdt.Conf.GetTreecount()
	auc := gbdt.NewAuc()
	for i := range samples {
		p := gbdt.LogitCtr(g.Predict(samples[i], treeCount))
		auc.Add(float64(p), float64(samples[i].GetWeight()), samples[i].GetLabel())
	}
	// Measured after scoring so the reported time actually includes prediction.
	fmt.Println("precision time:", time.Since(predictStart))
	fmt.Println("auc:", auc.CalculateAuc())
	auc.PrintConfusionTable()

	// Optional: print per-feature weights with human-readable names.
	//   featureData, err := ioutil.ReadFile("./feature.map")
	//   if err != nil { log.Fatal(err) }
	//   featureMap := gbdt.LoadFeatureMap(string(featureData))
	//   for _, fw := range g.GetFeatureWeight() {
	//   	fmt.Println(featureMap[fw.Key], ":", fw.Value)
	//   }
}
示例#2
0
文件: train.go 项目: postfix/gbdt
// main trains a GBDT model configured from command-line flags and saves it to
// -modelname. It then evaluates the model with EvalModel (AUC and log loss).
// Evaluation uses the training data, or the -testpath data when -istestset is
// set. The scores are written to -aucfile. Finally it logs the per-feature
// weights, plus the feature costs when -switch_feature_tune is set.
func main() {
	start := time.Now()

	var debug bool
	flag.BoolVar(&debug, "debug", false, "whether print training info")

	var trainSampleNumber int
	flag.IntVar(&trainSampleNumber, "trainrows", 0, "train sample number")

	var trainPath, modelName, aucFile string
	flag.StringVar(&trainPath, "trainpath", "/opt/tmp/search_rerank/train.data", "train data path")
	flag.StringVar(&modelName, "modelname", "/opt/tmp/search_rerank/gbdt.model", "model file")
	flag.StringVar(&aucFile, "aucfile", "/opt/tmp/search_rerank/train_auc", "training auc")

	// -feature_cost_file is still accepted but currently not read: loading
	// costs from file (gbdt.Conf.LoadFeatureCost) is disabled, and
	// -switch_feature_tune initializes costs via gbdt.Conf.InitFeatureCost.
	var featureCostFile, featureMapFile string
	flag.StringVar(&featureCostFile, "feature_cost_file", "./data/feature.cost", "feature init cost")
	flag.StringVar(&featureMapFile, "feature_map_file", "./data/feature.map", "feature map")

	var treeCount int
	flag.IntVar(&treeCount, "treecount", 100, "tree number")

	var featureNum int
	flag.IntVar(&featureNum, "feature_num", 45, "feature number")

	var depth, minLeafSize int
	flag.IntVar(&depth, "depth", 4, "tree depth")
	flag.IntVar(&minLeafSize, "min_leaf_size", 20000, "min leaf size")

	var featureSamplingRatio, dataSamplingRatio, shrinkage float64
	flag.Float64Var(&featureSamplingRatio, "feature_sampling_ratio", 0.7, "feature sampling ratio")
	flag.Float64Var(&dataSamplingRatio, "data_sampling_ratio", 0.6, "data sampling ratio")
	flag.Float64Var(&shrinkage, "shrinkage", 0.1, "step size")

	var testSampleNumber int
	flag.IntVar(&testSampleNumber, "testrows", 0, "test sample number")

	var testPath string
	flag.StringVar(&testPath, "testpath", "/opt/tmp/search_rerank/test.data", "test data path")

	var isTestSet bool
	flag.BoolVar(&isTestSet, "istestset", false, "whether use testset")

	var switchFeatureTune bool
	flag.BoolVar(&switchFeatureTune, "switch_feature_tune", false, "switch feature tune")

	flag.Parse()

	// Validate the global config before writing to it: a nil gbdt.Conf would
	// otherwise panic on the first field assignment below.
	if gbdt.Conf == nil {
		log.Println("nil pointer")
		os.Exit(1)
	}

	// The config stores ratios as float32; flags are parsed as float64.
	gbdt.Conf.Debug = debug
	gbdt.Conf.Tree_count = treeCount
	gbdt.Conf.Number_of_feature = featureNum
	gbdt.Conf.Max_depth = depth
	gbdt.Conf.Min_leaf_size = minLeafSize
	gbdt.Conf.Feature_sampling_ratio = float32(featureSamplingRatio)
	gbdt.Conf.Data_sampling_ratio = float32(dataSamplingRatio)
	gbdt.Conf.Shrinkage = float32(shrinkage)
	gbdt.Conf.Enable_feature_tunning = switchFeatureTune
	if gbdt.Conf.Enable_feature_tunning {
		gbdt.Conf.InitFeatureCost()
	}
	log.Println(gbdt.Conf)

	dataset := &gbdt.DataSet{}
	dataset.LoadDataFromFile(trainPath, trainSampleNumber)
	g := gbdt.NewGBDT()
	g.Init(dataset)
	g.Train(dataset)
	model := g.Save()
	if err := ioutil.WriteFile(modelName, []byte(model), 0666); err != nil {
		log.Println(err)
		os.Exit(1)
	}

	// Evaluate on the training data unless a separate test set was requested.
	evalSet := dataset
	if isTestSet {
		evalSet = &gbdt.DataSet{}
		evalSet.LoadDataFromFile(testPath, testSampleNumber)
	}
	aucScore, logloss := EvalModel(evalSet, g, treeCount)

	log.Println("auc:", aucScore)
	log.Println("logloss:", logloss)
	evalScore := fmt.Sprintf("auc:%v,logloss:%v", aucScore, logloss)
	if err := ioutil.WriteFile(aucFile, []byte(evalScore), 0666); err != nil {
		log.Println(err)
		os.Exit(1)
	}

	featureData, err := ioutil.ReadFile(featureMapFile)
	if err != nil {
		log.Fatal(err)
	}
	featureMap := gbdt.LoadFeatureMap(string(featureData))
	log.Println(featureMapFile, ":feature_map load done!")

	featureWeights := g.GetFeatureWeight()
	log.Println("feature weight:")
	for _, fw := range featureWeights {
		log.Println(featureMap[fw.Key], ":", fw.Value)
	}
	if switchFeatureTune {
		log.Println("feature cost:")
		for i := 0; i < len(gbdt.Conf.Feature_costs); i++ {
			log.Println(i, " ", featureMap[i], ":", gbdt.Conf.Feature_costs[i])
		}
	}

	log.Println("training time:", time.Since(start))
}