// Prune eliminates branches which hurt accuracy
func (d *DecisionTreeNode) Prune(using base.FixedDataGrid) {
	// If you're a leaf, you're already pruned
	if d.Children == nil {
		return
	}
	// Nothing to prune if this node never acquired a split rule
	if d.SplitRule == nil || d.SplitRule.SplitAttr == nil {
		return
	}

	// Recursively prune children of this node
	sub := base.DecomposeOnAttributeValues(using, d.SplitRule.SplitAttr)
	for k := range d.Children {
		if sub[k] == nil {
			continue
		}
		subH, subV := sub[k].Size()
		if subH == 0 || subV == 0 {
			continue
		}
		d.Children[k].Prune(sub[k])
	}

	// Get a baseline accuracy
	baselineAccuracy := computeAccuracy(d.Predict(using), using)

	// Speculatively remove the children and re-evaluate
	tmpChildren := d.Children
	d.Children = nil
	newAccuracy := computeAccuracy(d.Predict(using), using)

	// Keep the children removed if accuracy is no worse, else restore them
	if newAccuracy < baselineAccuracy {
		d.Children = tmpChildren
	}
}
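
// examplePruneOnHoldOut is an illustrative sketch, not part of the original
// source. Prune performs reduced-error pruning, so it should normally be given
// a hold-out split that the tree was not grown from. computeAccuracy is the
// unexported helper used above; its float64 return type is assumed here.
func examplePruneOnHoldOut(root *DecisionTreeNode, holdOut base.FixedDataGrid) (before, after float64) {
	// Accuracy of the unpruned tree on the hold-out split
	before = computeAccuracy(root.Predict(holdOut), holdOut)
	// Collapse every subtree whose removal does not reduce hold-out accuracy
	root.Prune(holdOut)
	after = computeAccuracy(root.Predict(holdOut), holdOut)
	return before, after
}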
// InferID3Tree builds a decision tree using a RuleGenerator
// from a set of Instances (implements the ID3 algorithm)
func InferID3Tree(from base.FixedDataGrid, with RuleGenerator) *DecisionTreeNode {
	// Count the number of classes at this node
	classes := base.GetClassDistribution(from)

	// If there's only one class, return a DecisionTreeLeaf with
	// the only class available
	if len(classes) == 1 {
		maxClass := ""
		for i := range classes {
			maxClass = i
		}
		ret := &DecisionTreeNode{
			LeafNode,
			nil,
			classes,
			maxClass,
			getClassAttr(from),
			&DecisionTreeRule{nil, 0.0},
		}
		return ret
	}

	// Compute the majority class, used whenever we have to stop splitting
	maxVal := 0
	maxClass := ""
	for i := range classes {
		if classes[i] > maxVal {
			maxClass = i
			maxVal = classes[i]
		}
	}

	// If there are no more Attributes left to split on,
	// return a DecisionTreeLeaf with the majority class
	cols, _ := from.Size()
	if cols == 2 {
		ret := &DecisionTreeNode{
			LeafNode,
			nil,
			classes,
			maxClass,
			getClassAttr(from),
			&DecisionTreeRule{nil, 0.0},
		}
		return ret
	}

	// Generate a return structure
	ret := &DecisionTreeNode{
		RuleNode,
		nil,
		classes,
		maxClass,
		getClassAttr(from),
		nil,
	}

	// Generate the splitting rule
	splitRule := with.GenerateSplitRule(from)
	if splitRule == nil {
		// Can't determine, just return what we have
		return ret
	}

	// Split the instances based on the chosen attribute's value
	var splitInstances map[string]base.FixedDataGrid
	if _, ok := splitRule.SplitAttr.(*base.FloatAttribute); ok {
		splitInstances = base.DecomposeOnNumericAttributeThreshold(from,
			splitRule.SplitAttr, splitRule.SplitVal)
	} else {
		splitInstances = base.DecomposeOnAttributeValues(from, splitRule.SplitAttr)
	}

	// Create new children from these splits
	ret.Children = make(map[string]*DecisionTreeNode)
	for k := range splitInstances {
		newInstances := splitInstances[k]
		ret.Children[k] = InferID3Tree(newInstances, with)
	}
	ret.SplitRule = splitRule
	return ret
}
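
// exampleGrowAndClassify is an illustrative sketch, not part of the original
// source. It grows a tree with InferID3Tree and classifies unseen rows.
// InformationGainRuleGenerator is assumed to be a RuleGenerator provided
// elsewhere in this package; any other RuleGenerator implementation could be
// substituted.
func exampleGrowAndClassify(trainData, testData base.FixedDataGrid) base.FixedDataGrid {
	// Choose each split by information gain (assumed generator)
	root := InferID3Tree(trainData, new(InformationGainRuleGenerator))
	// Predict uses the single-value signature relied on by Prune above
	return root.Predict(testData)
}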