/*
splitTreeByGain calculates the gain for each attribute in the dataset and
splits the dataset into two nodes: left and right. It returns the node with
the split information.
*/
func (runtime *Runtime) splitTreeByGain(D tabula.ClasetInterface) (
	node *binary.BTNode, e error,
) {
	node = &binary.BTNode{}

	D.RecountMajorMinor()

	// If the dataset is empty, return a leaf node labeled with the
	// majority class of the dataset.
	nrow := D.GetNRow()
	if nrow <= 0 {
		if DEBUG >= 2 {
			fmt.Printf("[cart] empty dataset (%s) : %v\n",
				D.MajorityClass(), D)
		}

		node.Value = NodeValue{
			IsLeaf: true,
			Class:  D.MajorityClass(),
			Size:   0,
		}
		return node, nil
	}

	// If all rows in the dataset belong to the same class, return a leaf
	// node with its class set to that class.
	single, name := D.IsInSingleClass()
	if single {
		if DEBUG >= 2 {
			fmt.Printf("[cart] in single class (%s): %v\n", name,
				D.GetColumns())
		}

		node.Value = NodeValue{
			IsLeaf: true,
			Class:  name,
			Size:   nrow,
		}
		return node, nil
	}

	if DEBUG >= 2 {
		fmt.Println("[cart] D:", D)
	}

	// Calculate the Gini gain for each attribute.
	gains := runtime.computeGain(D)

	// Get the attribute with the maximum Gini gain.
	MaxGainIdx := gini.FindMaxGain(&gains)
	MaxGain := gains[MaxGainIdx]

	// If the maximum gain is 0, use the majority class as the node and
	// terminate the process.
	if MaxGain.GetMaxGainValue() == 0 {
		if DEBUG >= 2 {
			fmt.Println("[cart] max gain 0 with target",
				D.GetClassAsStrings(),
				" and majority class is ", D.MajorityClass())
		}

		node.Value = NodeValue{
			IsLeaf: true,
			Class:  D.MajorityClass(),
			Size:   0,
		}
		return node, nil
	}

	// Using the sorted index in MaxGain, sort all columns in the dataset.
	tabula.SortColumnsByIndex(D, MaxGain.SortedIndex)

	if DEBUG >= 2 {
		fmt.Println("[cart] maxgain:", MaxGain)
	}

	// Now that we have the attribute with the maximum gain in MaxGainIdx,
	// and its gain and partition value in gains[MaxGainIdx] and
	// GetMaxPartGainValue(), we split the dataset based on the type of
	// the max-gain attribute.
	// If it is continuous, split the attribute using a numeric value.
	// If it is discrete, split the attribute using a subset (partition)
	// of nominal values.
	var splitV interface{}

	if MaxGain.IsContinu {
		splitV = MaxGain.GetMaxPartGainValue()
	} else {
		attrPartV := MaxGain.GetMaxPartGainValue()
		attrSubV := attrPartV.(tekstus.ListStrings)
		splitV = attrSubV[0].Normalize()
	}

	if DEBUG >= 2 {
		fmt.Println("[cart] maxgainindex:", MaxGainIdx)
		fmt.Println("[cart] split v:", splitV)
	}

	node.Value = NodeValue{
		SplitAttrName: D.GetColumn(MaxGainIdx).GetName(),
		IsLeaf:        false,
		IsContinu:     MaxGain.IsContinu,
		Size:          nrow,
		SplitAttrIdx:  MaxGainIdx,
		SplitV:        splitV,
	}

	dsL, dsR, e := tabula.SplitRowsByValue(D, MaxGainIdx, splitV)
	if e != nil {
		return node, e
	}

	splitL := dsL.(tabula.ClasetInterface)
	splitR := dsR.(tabula.ClasetInterface)

	// Set the parent flag on the attribute referenced by MaxGainIdx, so
	// its gain will not be computed again in the next round.
	cols := splitL.GetColumns()
	for x := range *cols {
		if x == MaxGainIdx {
			(*cols)[x].Flag = ColFlagParent
		} else {
			(*cols)[x].Flag = 0
		}
	}

	cols = splitR.GetColumns()
	for x := range *cols {
		if x == MaxGainIdx {
			(*cols)[x].Flag = ColFlagParent
		} else {
			(*cols)[x].Flag = 0
		}
	}

	nodeLeft, e := runtime.splitTreeByGain(splitL)
	if e != nil {
		return node, e
	}

	nodeRight, e := runtime.splitTreeByGain(splitR)
	if e != nil {
		return node, e
	}

	node.SetLeft(nodeLeft)
	node.SetRight(nodeRight)

	return node, nil
}
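
// The sketch below is illustrative only and is not part of the original
// package: it shows how splitTreeByGain could be driven to grow a complete
// tree from a class-set. The buildExample name is an assumption; only
// splitTreeByGain and the identifiers it already uses (Runtime, tabula,
// binary, fmt, DEBUG) come from the code above.
func (runtime *Runtime) buildExample(D tabula.ClasetInterface) (
	root *binary.BTNode, e error,
) {
	// Grow the whole tree recursively; the returned root carries the
	// split attribute and split value of the first (best) split, and
	// every leaf below it carries a predicted class.
	root, e = runtime.splitTreeByGain(D)
	if e != nil {
		return nil, e
	}

	if DEBUG >= 2 {
		fmt.Println("[cart] tree root:", root)
	}

	return root, nil
}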
//
// createForest will create and return a forest and run the training
// `samples` on it.
//
// Algorithm,
// (1) Initialize the forest.
// (2) For 0 to the maximum number of trees in the forest,
// (2.1) grow one tree until it succeeds.
// (2.2) If the tree's TP-rate and TN-rate are greater than the thresholds,
// stop growing.
// (3) Calculate the weight.
// (4) TODO: Move the true negatives out of the samples. The collection of
// true negatives will be used again to test the model, and after testing
// the samples with false positives will be moved back into the training
// samples.
// (5) Refill the samples with false positives.
//
func (crf *Runtime) createForest(samples tabula.ClasetInterface) (
	forest *rf.Runtime, e error,
) {
	var cm *classifier.CM
	var stat *classifier.Stat

	fmt.Println(tag, "Forest samples:", samples)

	// (1)
	forest = &rf.Runtime{
		Runtime: classifier.Runtime{
			RunOOB: true,
		},
		NTree:          crf.NTree,
		NRandomFeature: crf.NRandomFeature,
	}

	e = forest.Initialize(samples)
	if e != nil {
		return nil, e
	}

	// (2)
	for t := 0; t < crf.NTree; t++ {
		if DEBUG >= 2 {
			fmt.Println(tag, "Tree #", t)
		}

		// (2.1)
		for {
			cm, stat, e = forest.GrowTree(samples)
			if e == nil {
				break
			}
		}

		// (2.2)
		if stat.TPRate > crf.TPRate && stat.TNRate > crf.TNRate {
			break
		}
	}

	e = forest.Finalize()
	if e != nil {
		return nil, e
	}

	// (3)
	crf.computeWeight(stat)

	if DEBUG >= 1 {
		fmt.Println(tag, "Weight:", stat.FMeasure)
	}

	// (4)
	crf.deleteTrueNegative(samples, cm)

	// (5)
	crf.runTPSet(samples)

	samples.RecountMajorMinor()

	return forest, nil
}
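
// Hypothetical sketch, not part of the original source: one way a cascade
// could call createForest stage by stage, training each new stage on the
// samples that step (5) above refilled with false positives. The
// buildStagesExample name and the nStage parameter are assumptions for
// illustration; only createForest, Runtime, tabula, fmt, and tag come from
// the code above.
func (crf *Runtime) buildStagesExample(samples tabula.ClasetInterface,
	nStage int,
) error {
	for s := 0; s < nStage; s++ {
		forest, e := crf.createForest(samples)
		if e != nil {
			return e
		}

		// Each stage yields its own weighted forest; by the time
		// createForest returns, the samples have already been
		// recounted and refilled for the next stage.
		fmt.Println(tag, "stage", s, "forest:", forest)
	}
	return nil
}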