コード例 #1
0
ファイル: applyforest.go プロジェクト: lytics/CloudForest
// main applies a saved forest (-rfpred) to an AFM feature matrix (-fm).
//
// The ballot box is chosen as follows: -sum selects sum voting; otherwise mean
// voting is used when -mean is set or the forest target name starts with "N",
// unless -mode forces categorical voting; everything else votes categorically.
// When the forest's target is a feature of the data, its name, index and the
// tally error are printed.
//
// With -preds, one "label<TAB>prediction<TAB>actual" line is written per case
// ("NA" when the target is absent). Predictions are numeric (TallyNum plus the
// forest intercept, optionally passed through Expit with -expit) when -sum is
// set or the forest has a nonzero intercept, and categorical (Tally) otherwise.
// With -votes, per-category vote totals are written per case; this requires a
// categorical ballot box.
func main() {
	fm := flag.String("fm",
		"featurematrix.afm", "AFM formated feature matrix containing data.")
	rf := flag.String("rfpred",
		"rface.sf", "A predictor forest.")
	predfn := flag.String("preds",
		"", "The name of a file to write the predictions into.")
	votefn := flag.String("votes",
		"", "The name of a file to write categorical vote totals to.")
	var num bool
	flag.BoolVar(&num, "mean", false, "Force numeric (mean) voting.")
	var sum bool
	flag.BoolVar(&sum, "sum", false, "Force numeric sum voting (for gradient boosting etc).")
	var expit bool
	flag.BoolVar(&expit, "expit", false, "Expit (inverst logit) transform data (for gradient boosting classification).")
	var cat bool
	flag.BoolVar(&cat, "mode", false, "Force categorical (mode) voting.")

	flag.Parse()

	//Parse Data
	data, err := CloudForest.LoadAFM(*fm)
	if err != nil {
		log.Fatal(err)
	}
	// The ballot box is sized from the first feature; an empty matrix would
	// otherwise panic with an index out of range.
	if len(data.Data) == 0 {
		log.Fatalf("feature matrix %v contains no features", *fm)
	}
	nCases := data.Data[0].Length()

	forestfile, err := os.Open(*rf) // For read access.
	if err != nil {
		log.Fatal(err)
	}
	defer forestfile.Close()
	forestreader := CloudForest.NewForestReader(forestfile)
	forest, err := forestreader.ReadForest()
	if err != nil {
		log.Fatal(err)
	}

	var predfile *os.File
	if *predfn != "" {
		predfile, err = os.Create(*predfn)
		if err != nil {
			log.Fatal(err)
		}
		defer predfile.Close()
	}

	var bb CloudForest.VoteTallyer
	switch {
	case sum:
		bb = CloudForest.NewSumBallotBox(nCases)
	case !cat && (num || strings.HasPrefix(forest.Target, "N")):
		bb = CloudForest.NewNumBallotBox(nCases)
	default:
		bb = CloudForest.NewCatBallotBox(nCases)
	}

	for _, tree := range forest.Trees {
		tree.Vote(data, bb)
	}

	targeti, hasTarget := data.Map[forest.Target]
	if hasTarget {
		fmt.Printf("Target is %v in feature %v\n", forest.Target, targeti)
		er := bb.TallyError(data.Data[targeti])
		fmt.Printf("Error: %v\n", er)
	}

	if *predfn != "" {
		// numTallyer is satisfied by the numeric ballot boxes (sum and mean),
		// both of which expose a per-case numeric tally.
		type numTallyer interface {
			TallyNum(i int) float64
		}

		// Sum voting and forests with an intercept produce numeric
		// predictions; check once, up front, that the selected ballot box can
		// supply them instead of panicking on a type assertion per case.
		numeric := sum || forest.Intercept != 0.0
		var nt numTallyer
		if numeric {
			tallyer, ok := bb.(numTallyer)
			if !ok {
				log.Fatal("forest has a nonzero intercept but categorical (mode) voting was selected; numeric predictions need -mean or -sum")
			}
			nt = tallyer
		}

		fmt.Printf("Outputting label predicted actual tsv to %v\n", *predfn)
		for i, l := range data.CaseLabels {
			actual := "NA"
			if hasTarget {
				actual = data.Data[targeti].GetStr(i)
			}

			var result string
			if numeric {
				numresult := nt.TallyNum(i) + forest.Intercept
				if expit {
					numresult = CloudForest.Expit(numresult)
				}
				result = fmt.Sprintf("%v", numresult)
			} else {
				result = bb.Tally(i)
			}
			fmt.Fprintf(predfile, "%v\t%v\t%v\n", l, result, actual)
		}
	}

	//Not thread safe code!
	if *votefn != "" {
		// Vote totals only exist on a categorical ballot box; fail with a clear
		// message rather than panicking on the type assertion.
		cbb, ok := bb.(*CloudForest.CatBallotBox)
		if !ok {
			log.Fatal("-votes requires categorical (mode) voting but a numeric ballot box was selected")
		}
		fmt.Printf("Outputting vote totals to %v\n", *votefn)
		votefile, err := os.Create(*votefn)
		if err != nil {
			log.Fatal(err)
		}
		defer votefile.Close()

		// Header row: "." followed by one column per category label.
		fmt.Fprint(votefile, ".")
		for _, label := range cbb.CatMap.Back {
			fmt.Fprintf(votefile, "\t%v", label)
		}
		fmt.Fprint(votefile, "\n")

		// One row per case: its label, then its vote total for each category.
		for i, box := range cbb.Box {
			fmt.Fprintf(votefile, "%v", data.CaseLabels[i])
			for j := range cbb.CatMap.Back {
				fmt.Fprintf(votefile, "\t%v", box.Map[j])
			}
			fmt.Fprint(votefile, "\n")
		}
	}
}
コード例 #2
0
ファイル: leafcount.go プロジェクト: lytics/CloudForest
// main counts leaf co-occurrence over one or more forests (-rfpred, comma
// separated) applied to an AFM feature matrix (-fm).
//
// For every leaf of every tree, each ordered pair of cases in that leaf
// (including a case paired with itself) is added to a sparse case-by-case
// counter, written as tsv to -leaves. When -branches is set, the case-by-feature
// counter filled in by GetLeaves is written as tsv there. When -splits is set,
// a JSON object mapping each split feature name to the split values used on it
// (the numeric value, or the comma-joined left-hand categories) is written.
//
// Forest files are processed by -threads worker goroutines (at least one).
func main() {
	fm := flag.String("fm", "featurematrix.afm", "AFM formated feature matrix to use.")
	rf := flag.String("rfpred", "rface.sf", "A predictor forest.")
	outf := flag.String("leaves", "leaves.tsv", "a case by case sparse matrix of leaf co-occurrence in tsv format")
	boutf := flag.String("branches", "", "a case by feature sparse matrix of leaf co-occurrence in tsv format")
	soutf := flag.String("splits", "", "a file to write a json record of splite per feature")
	var threads int
	flag.IntVar(&threads, "threads", 1, "Parse seperate forests in n seperate threads.")

	flag.Parse()

	// With no workers nothing would ever be sent on resultChan and the
	// collection loop below would block forever.
	if threads < 1 {
		threads = 1
	}

	//Parse Data
	data, err := CloudForest.LoadAFM(*fm)
	if err != nil {
		log.Fatal(err)
	}
	if len(data.Data) == 0 {
		log.Fatalf("feature matrix %v contains no features", *fm)
	}

	log.Print("Data file ", len(data.Data), " by ", data.Data[0].Length())

	// NOTE(review): counts and caseFeatureCounts are shared by all workers;
	// this relies on CloudForest.SparseCounter being safe for concurrent use —
	// confirm before running with -threads > 1.
	counts := new(CloudForest.SparseCounter)
	var caseFeatureCounts *CloudForest.SparseCounter
	if *boutf != "" {
		caseFeatureCounts = new(CloudForest.SparseCounter)
	}

	files := strings.Split(*rf, ",")

	runtime.GOMAXPROCS(threads)

	fileChan := make(chan string)
	// Each worker sends back the splits it recorded for one forest file. Only
	// this goroutine merges them, so the shared splits map is never written
	// concurrently (plain Go maps are not safe for concurrent writes).
	resultChan := make(chan map[string][]string)

	go func() {
		for _, fn := range files {
			fileChan <- fn
		}
		// Closing ends the workers' range loops so they exit instead of
		// blocking forever once every file has been handed out.
		close(fileChan)
	}()

	for w := 0; w < threads; w++ {
		go func() {
			for fn := range fileChan {
				forestfile, err := os.Open(fn) // For read access.
				if err != nil {
					log.Fatal(err)
				}
				forestreader := CloudForest.NewForestReader(forestfile)
				forest, err := forestreader.ReadForest()
				// Close now: a defer here would only run when the worker
				// returns, leaking one descriptor per forest file.
				forestfile.Close()
				if err != nil {
					log.Fatal(err)
				}
				log.Print("Forest has ", len(forest.Trees), " trees ")

				fileSplits := make(map[string][]string)
				for t := range forest.Trees {
					fmt.Print(".")
					leaves := forest.Trees[t].GetLeaves(data, caseFeatureCounts)
					for _, leaf := range leaves {
						for _, a := range leaf.Cases {
							for _, b := range leaf.Cases {
								counts.Add(a, b, 1)
							}
						}
					}

					if *soutf != "" {
						forest.Trees[t].Root.Climb(func(n *CloudForest.Node) {
							if n.Splitter == nil {
								return
							}
							var split string
							if n.Splitter.Numerical {
								split = fmt.Sprintf("%v", n.Splitter.Value)
							} else {
								keys := make([]string, 0, len(n.Splitter.Left))
								for k := range n.Splitter.Left {
									keys = append(keys, k)
								}
								split = strings.Join(keys, ",")
							}
							name := n.Splitter.Feature
							fileSplits[name] = append(fileSplits[name], split)
						})
					}
				}
				resultChan <- fileSplits
			}
		}()
	}

	// Exactly one result arrives per forest file.
	splits := make(map[string][]string)
	for range files {
		for name, vals := range <-resultChan {
			splits[name] = append(splits[name], vals...)
		}
	}

	log.Print("Outputting Case Case  Co-Occurrence Counts")
	outfile, err := os.Create(*outf)
	if err != nil {
		log.Fatal(err)
	}
	defer outfile.Close()
	counts.WriteTsv(outfile)

	if *boutf != "" {
		log.Print("Outputting Case Feature Co-Occurrence Counts")
		boutfile, err := os.Create(*boutf)
		if err != nil {
			log.Fatal(err)
		}
		defer boutfile.Close()
		caseFeatureCounts.WriteTsv(boutfile)
	}

	if *soutf != "" {
		soutfile, err := os.Create(*soutf)
		if err != nil {
			log.Fatal(err)
		}
		defer soutfile.Close()
		if err := json.NewEncoder(soutfile).Encode(splits); err != nil {
			log.Fatal(err)
		}
	}
}