func createDecisionNode(data [][]float64, target []float64, minSamples int, ch chan *DecisionTree) { // ending conditions if (len(target) < minSamples) || (stats.RMS(target) == 0) { ch <- &DecisionTree{stats.Mean(target), 0, 0., nil, nil} return } // Find the best variable to split on bestRMS := 1e30 var bestCut float64 var bestCol int cut := 0. rms := 0. for i := 0; i < len(data[0]); i++ { // make sure there's some variance in the column cut, rms = findOptimalCut(data[:][i], target) if rms < bestRMS { bestRMS = rms bestCol = i bestCut = cut } } lowData := make([][]float64, 0, len(data)/2) lowTarget := make([]float64, 0, len(data)/2) highData := make([][]float64, 0, len(data)/2) highTarget := make([]float64, 0, len(data)/2) for i, row := range data { if row[bestCol] <= bestCut { lowData = append(lowData, row) lowTarget = append(lowTarget, target[i]) } else { highData = append(highData, row) highTarget = append(highTarget, target[i]) } } chLow := make(chan *DecisionTree) chHigh := make(chan *DecisionTree) go createDecisionNode(lowData, lowTarget, minSamples, chLow) go createDecisionNode(highData, highTarget, minSamples, chHigh) var treeLow, treeHigh *DecisionTree for i := 0; i < 2; i++ { select { case t := <-chLow: treeLow = t case t := <-chHigh: treeHigh = t } } node := DecisionTree{stats.Mean(target), bestCol, bestCut, treeLow, treeHigh} ch <- &node }
// Given a column of data, find the cutoff that gives the // smallest sum of RMS func findOptimalCut(column, target []float64) (bestCut, bestRMS float64) { // join the column and the targets together paired := make(pairFloat64Collection, len(column)) for i := 0; i < len(column); i++ { paired[i].sort_val = column[i] paired[i].other_val = target[i] } // sort the indices by the column values sort.Sort(paired) // eliminate duplicates paired = Dedupe(paired) if len(paired) <= 1 { bestCut = 0. bestRMS = 1e30 return bestCut, bestRMS } else if stats.RMS(target) == 0. { bestCut = 0. bestRMS = 0. return bestCut, bestRMS } // extract it back into two different arrays which are now sorted and deduplicated for i, pair := range paired { column[i] = pair.sort_val target[i] = pair.other_val } // now loop through cuts to find the best one bestRMS, bestCut = 1.e30, 0. var rms float64 for i := 1; i < len(column); i++ { rms = stats.RMS(target[:i]) + stats.RMS(target[i:]) if rms < bestRMS { bestRMS = rms bestCut = (column[i] + column[i-1]) / 2 } } return bestCut, bestRMS }