コード例 #1
0
ファイル: DecisionTree.go プロジェクト: neggert/decisiontree
// createDecisionNode recursively builds a regression tree node for the given
// samples and sends it on ch.
//
// data holds one row per sample and target[i] is the response for data[i];
// every row is assumed to have at least as many columns as data[0]. The node
// is a leaf predicting stats.Mean(target) when there are no rows, fewer than
// minSamples targets, stats.RMS(target) is zero, or no column yields a cut
// that puts samples on both sides. Otherwise the node splits on the
// column/cut pair with the lowest score reported by findOptimalCut (rows whose
// value is <= the cut go to the low child), and each child is built in its
// own goroutine, so the goroutine count grows with the size of the tree.
//
// Exactly one *DecisionTree is sent on ch.
func createDecisionNode(data [][]float64, target []float64, minSamples int, ch chan *DecisionTree) {
	sendLeaf := func() {
		ch <- &DecisionTree{stats.Mean(target), 0, 0., nil, nil}
	}

	// Ending conditions. The len(data) check keeps data[0] below in range.
	if len(data) == 0 || len(target) < minSamples || stats.RMS(target) == 0 {
		sendLeaf()
		return
	}

	// Find the best column to split on. The buffers are reused for every
	// column because findOptimalCut only returns scalars.
	bestRMS := 1e30
	bestCol, bestCut := 0, 0.
	found := false
	column := make([]float64, len(data))
	targetCopy := make([]float64, len(target))
	for c := range data[0] {
		// Gather column c from every row. (data[:][c] would select row c.)
		for r, row := range data {
			column[r] = row[c]
		}
		// Hand findOptimalCut a private copy so nothing it does to its
		// arguments can break the row/target alignment used below.
		copy(targetCopy, target)
		cut, rms := findOptimalCut(column, targetCopy)
		if rms < bestRMS {
			bestRMS, bestCol, bestCut, found = rms, c, cut, true
		}
	}
	if !found {
		// No column offers a usable cut; splitting anyway could recurse
		// on the same data forever.
		sendLeaf()
		return
	}

	lowData := make([][]float64, 0, len(data)/2)
	lowTarget := make([]float64, 0, len(data)/2)
	highData := make([][]float64, 0, len(data)/2)
	highTarget := make([]float64, 0, len(data)/2)
	for i, row := range data {
		if row[bestCol] <= bestCut {
			lowData = append(lowData, row)
			lowTarget = append(lowTarget, target[i])
		} else {
			highData = append(highData, row)
			highTarget = append(highTarget, target[i])
		}
	}

	// A cut that leaves one side empty would recurse on identical data forever.
	if len(lowData) == 0 || len(highData) == 0 {
		sendLeaf()
		return
	}

	chLow := make(chan *DecisionTree)
	chHigh := make(chan *DecisionTree)
	go createDecisionNode(lowData, lowTarget, minSamples, chLow)
	go createDecisionNode(highData, highTarget, minSamples, chHigh)

	// Each child sends exactly once, so receiving from both in turn waits
	// for the two subtrees; the children still run concurrently.
	treeLow, treeHigh := <-chLow, <-chHigh

	ch <- &DecisionTree{stats.Mean(target), bestCol, bestCut, treeLow, treeHigh}
}
コード例 #2
0
ファイル: DecisionTree.go プロジェクト: neggert/decisiontree
// findOptimalCut returns the cut on a single feature column that minimises
// stats.RMS(low targets) + stats.RMS(high targets), together with that sum.
//
// column[i] is the feature value of sample i and target[i] its response;
// target must be at least as long as column. Candidate cuts are midpoints
// between consecutive distinct values of the sorted column, so samples with
// equal values always land on the same side, matching a `value <= cut`
// partition. Duplicate values still contribute their targets to the score.
// Neither input slice is modified.
//
// If column has fewer than two distinct values, (0, 1e30) is returned;
// otherwise, if stats.RMS(target) is zero, (0, 0) is returned.
func findOptimalCut(column, target []float64) (bestCut, bestRMS float64) {
	// Pair each value with its target so both can be sorted together.
	// NOTE(review): assumes pairFloat64Collection orders ascending by sort_val.
	paired := make(pairFloat64Collection, len(column))
	for i := range column {
		paired[i].sort_val = column[i]
		paired[i].other_val = target[i]
	}
	sort.Sort(paired)

	n := len(paired)
	if n == 0 || paired[0].sort_val == paired[n-1].sort_val {
		// Fewer than two distinct values: no cut separates anything.
		return 0., 1e30
	}
	if stats.RMS(target) == 0. {
		return 0., 0.
	}

	// Copy into local slices rather than back into column/target, which
	// belong to the caller.
	sortedCol := make([]float64, n)
	sortedTarget := make([]float64, n)
	for i, p := range paired {
		sortedCol[i] = p.sort_val
		sortedTarget[i] = p.other_val
	}

	bestCut, bestRMS = 0., 1e30
	for i := 1; i < n; i++ {
		if sortedCol[i] == sortedCol[i-1] {
			// Not a boundary: equal values must not be split apart.
			continue
		}
		rms := stats.RMS(sortedTarget[:i]) + stats.RMS(sortedTarget[i:])
		if rms < bestRMS {
			bestRMS = rms
			bestCut = (sortedCol[i] + sortedCol[i-1]) / 2
		}
	}
	return bestCut, bestRMS
}