Example #1
// Prune eliminates branches which hurt accuracy
func (d *DecisionTreeNode) Prune(using base.FixedDataGrid) {
	// If you're a leaf, you're already pruned
	if d.Children == nil {
		return
	}
	if d.SplitRule == nil {
		return
	}

	// Recursively prune children of this node
	sub := base.DecomposeOnAttributeValues(using, d.SplitRule.SplitAttr)
	for k := range d.Children {
		if sub[k] == nil {
			continue
		}
		subH, subV := sub[k].Size()
		if subH == 0 || subV == 0 {
			continue
		}
		d.Children[k].Prune(sub[k])
	}

	// Get a baseline accuracy
	baselineAccuracy := computeAccuracy(d.Predict(using), using)

	// Speculatively remove the children and re-evaluate
	tmpChildren := d.Children
	d.Children = nil
	newAccuracy := computeAccuracy(d.Predict(using), using)

	// Keep the children removed if better, else restore
	if newAccuracy < baselineAccuracy {
		d.Children = tmpChildren
	}
}
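
A hedged usage sketch for the pruning routine above, assuming both functions live in golearn's trees package (github.com/sjwhitworth/golearn) and that base.ParseCSVToInstances, base.InstancesTrainTestSplit and trees.InformationGainRuleGenerator behave as in that library; the dataset path and split ratio are illustrative. The idea is to grow the tree on one portion of the data and prune it against a held-out portion, so a subtree is only removed when removing it does not hurt accuracy on unseen rows.

package main

import (
	"fmt"

	"github.com/sjwhitworth/golearn/base"
	"github.com/sjwhitworth/golearn/trees"
)

func main() {
	// Illustrative dataset path; any CSV with a categorical class column works.
	rawData, err := base.ParseCSVToInstances("datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}

	// Keep most rows for growing the tree; hold the rest out for pruning.
	trainData, pruneData := base.InstancesTrainTestSplit(rawData, 0.30)

	// Grow the tree with the information-gain rule generator, then prune it
	// bottom-up against the held-out rows.
	root := trees.InferID3Tree(trainData, new(trees.InformationGainRuleGenerator))
	root.Prune(pruneData)

	fmt.Println(root)
}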
Example #2
// InferID3Tree builds a decision tree using a RuleGenerator
// from a set of Instances (implements the ID3 algorithm)
func InferID3Tree(from base.FixedDataGrid, with RuleGenerator) *DecisionTreeNode {
	// Get the class distribution at this node
	classes := base.GetClassDistribution(from)
	// If there's only one class, return a DecisionTreeLeaf with
	// the only class available
	if len(classes) == 1 {
		maxClass := ""
		for i := range classes {
			maxClass = i
		}
		ret := &DecisionTreeNode{
			LeafNode,
			nil,
			classes,
			maxClass,
			getClassAttr(from),
			&DecisionTreeRule{nil, 0.0},
		}
		return ret
	}

	// Compute the majority class, in case we have to stop and return a leaf
	maxVal := 0
	maxClass := ""
	for i := range classes {
		if classes[i] > maxVal {
			maxClass = i
			maxVal = classes[i]
		}
	}

	// If there are no more Attributes left to split on,
	// return a DecisionTreeLeaf with the majority class
	cols, _ := from.Size()
	if cols == 2 {
		ret := &DecisionTreeNode{
			LeafNode,
			nil,
			classes,
			maxClass,
			getClassAttr(from),
			&DecisionTreeRule{nil, 0.0},
		}
		return ret
	}

	// Generate a return structure
	ret := &DecisionTreeNode{
		RuleNode,
		nil,
		classes,
		maxClass,
		getClassAttr(from),
		nil,
	}

	// Generate the splitting rule
	splitRule := with.GenerateSplitRule(from)
	if splitRule == nil {
		// Can't determine, just return what we have
		return ret
	}

	// Split the instances on this attribute's value (or threshold, if numeric)
	var splitInstances map[string]base.FixedDataGrid
	if _, ok := splitRule.SplitAttr.(*base.FloatAttribute); ok {
		splitInstances = base.DecomposeOnNumericAttributeThreshold(from,
			splitRule.SplitAttr, splitRule.SplitVal)
	} else {
		splitInstances = base.DecomposeOnAttributeValues(from, splitRule.SplitAttr)
	}
	// Create new children from these attributes
	ret.Children = make(map[string]*DecisionTreeNode)
	for k := range splitInstances {
		newInstances := splitInstances[k]
		ret.Children[k] = InferID3Tree(newInstances, with)
	}
	ret.SplitRule = splitRule
	return ret
}
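
To show where the RuleGenerator hook plugs in, here is a minimal sketch of a custom generator, again assuming the golearn trees package and its base.NonClassAttributes helper. The splitOnFirstAttribute type is hypothetical and does no information-gain scoring; it simply proposes the first remaining non-class attribute, to illustrate what InferID3Tree expects back from GenerateSplitRule.

package main

import (
	"fmt"

	"github.com/sjwhitworth/golearn/base"
	"github.com/sjwhitworth/golearn/trees"
)

// splitOnFirstAttribute is a deliberately naive rule generator intended for
// categorical data: it proposes the first non-class attribute still present
// and never scores candidates. It only illustrates the hook that
// InferID3Tree invokes via with.GenerateSplitRule(from).
type splitOnFirstAttribute struct{}

func (s *splitOnFirstAttribute) GenerateSplitRule(from base.FixedDataGrid) *trees.DecisionTreeRule {
	candidates := base.NonClassAttributes(from)
	if len(candidates) == 0 {
		// Nothing left to split on: InferID3Tree returns the node unchanged.
		return nil
	}
	return &trees.DecisionTreeRule{SplitAttr: candidates[0], SplitVal: 0.0}
}

func main() {
	// Illustrative path; a small categorical dataset keeps the tree readable.
	data, err := base.ParseCSVToInstances("datasets/tennis.csv", true)
	if err != nil {
		panic(err)
	}
	root := trees.InferID3Tree(data, &splitOnFirstAttribute{})
	fmt.Println(root)
}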