/* splitTreeByGain calculate the gain in all dataset, and split into two node: left and right. Return node with the split information. */ func (runtime *Runtime) splitTreeByGain(D tabula.ClasetInterface) ( node *binary.BTNode, e error, ) { node = &binary.BTNode{} D.RecountMajorMinor() // if dataset is empty return node labeled with majority classes in // dataset. nrow := D.GetNRow() if nrow <= 0 { if DEBUG >= 2 { fmt.Printf("[cart] empty dataset (%s) : %v\n", D.MajorityClass(), D) } node.Value = NodeValue{ IsLeaf: true, Class: D.MajorityClass(), Size: 0, } return node, nil } // if all dataset is in the same class, return node as leaf with class // is set to that class. single, name := D.IsInSingleClass() if single { if DEBUG >= 2 { fmt.Printf("[cart] in single class (%s): %v\n", name, D.GetColumns()) } node.Value = NodeValue{ IsLeaf: true, Class: name, Size: nrow, } return node, nil } if DEBUG >= 2 { fmt.Println("[cart] D:", D) } // calculate the Gini gain for each attribute. gains := runtime.computeGain(D) // get attribute with maximum Gini gain. MaxGainIdx := gini.FindMaxGain(&gains) MaxGain := gains[MaxGainIdx] // if maxgain value is 0, use majority class as node and terminate // the process if MaxGain.GetMaxGainValue() == 0 { if DEBUG >= 2 { fmt.Println("[cart] max gain 0 with target", D.GetClassAsStrings(), " and majority class is ", D.MajorityClass()) } node.Value = NodeValue{ IsLeaf: true, Class: D.MajorityClass(), Size: 0, } return node, nil } // using the sorted index in MaxGain, sort all field in dataset tabula.SortColumnsByIndex(D, MaxGain.SortedIndex) if DEBUG >= 2 { fmt.Println("[cart] maxgain:", MaxGain) } // Now that we have attribute with max gain in MaxGainIdx, and their // gain dan partition value in Gains[MaxGainIdx] and // GetMaxPartValue(), we split the dataset based on type of max-gain // attribute. // If its continuous, split the attribute using numeric value. // If its discrete, split the attribute using subset (partition) of // nominal values. var splitV interface{} if MaxGain.IsContinu { splitV = MaxGain.GetMaxPartGainValue() } else { attrPartV := MaxGain.GetMaxPartGainValue() attrSubV := attrPartV.(tekstus.ListStrings) splitV = attrSubV[0].Normalize() } if DEBUG >= 2 { fmt.Println("[cart] maxgainindex:", MaxGainIdx) fmt.Println("[cart] split v:", splitV) } node.Value = NodeValue{ SplitAttrName: D.GetColumn(MaxGainIdx).GetName(), IsLeaf: false, IsContinu: MaxGain.IsContinu, Size: nrow, SplitAttrIdx: MaxGainIdx, SplitV: splitV, } dsL, dsR, e := tabula.SplitRowsByValue(D, MaxGainIdx, splitV) if e != nil { return node, e } splitL := dsL.(tabula.ClasetInterface) splitR := dsR.(tabula.ClasetInterface) // Set the flag to parent in attribute referenced by // MaxGainIdx, so it will not computed again in the next round. cols := splitL.GetColumns() for x := range *cols { if x == MaxGainIdx { (*cols)[x].Flag = ColFlagParent } else { (*cols)[x].Flag = 0 } } cols = splitR.GetColumns() for x := range *cols { if x == MaxGainIdx { (*cols)[x].Flag = ColFlagParent } else { (*cols)[x].Flag = 0 } } nodeLeft, e := runtime.splitTreeByGain(splitL) if e != nil { return node, e } nodeRight, e := runtime.splitTreeByGain(splitR) if e != nil { return node, e } node.SetLeft(nodeLeft) node.SetRight(nodeRight) return node, nil }