// SelectRandomFeature if NRandomFeature is greater than zero, select and // compute gain in n random features instead of in all features func (runtime *Runtime) SelectRandomFeature(D tabula.ClasetInterface) { if runtime.NRandomFeature <= 0 { // all features selected return } ncols := D.GetNColumn() // count all features minus class nfeature := ncols - 1 if runtime.NRandomFeature >= nfeature { // Do nothing if number of random feature equal or greater than // number of feature in dataset. return } // exclude class index and parent node index excludeIdx := []int{D.GetClassIndex()} cols := D.GetColumns() for x, col := range *cols { if (col.Flag & ColFlagParent) == ColFlagParent { excludeIdx = append(excludeIdx, x) } else { (*cols)[x].Flag |= ColFlagSkip } } // Select random features excluding feature in `excludeIdx`. var pickedIdx []int for x := 0; x < runtime.NRandomFeature; x++ { idx := numerus.IntPickRandPositive(ncols, false, pickedIdx, excludeIdx) pickedIdx = append(pickedIdx, idx) // Remove skip flag on selected column col := D.GetColumn(idx) col.Flag = col.Flag &^ ColFlagSkip } if DEBUG >= 1 { fmt.Println("[cart] selected random features:", pickedIdx) fmt.Println("[cart] selected columns :", D.GetColumns()) } }
/* computeGain calculate the gini index for each value in each attribute. */ func (runtime *Runtime) computeGain(D tabula.ClasetInterface) ( gains []gini.Gini, ) { switch runtime.SplitMethod { case SplitMethodGini: // create gains value for all attribute minus target class. gains = make([]gini.Gini, D.GetNColumn()) } runtime.SelectRandomFeature(D) classVS := D.GetClassValueSpace() classIdx := D.GetClassIndex() classType := D.GetClassType() for x, col := range *D.GetColumns() { // skip class attribute. if x == classIdx { continue } // skip column flagged with parent if (col.Flag & ColFlagParent) == ColFlagParent { gains[x].Skip = true continue } // ignore column flagged with skip if (col.Flag & ColFlagSkip) == ColFlagSkip { gains[x].Skip = true continue } // compute gain. if col.GetType() == tabula.TReal { attr := col.ToFloatSlice() if classType == tabula.TString { target := D.GetClassAsStrings() gains[x].ComputeContinu(&attr, &target, &classVS) } else { targetReal := D.GetClassAsReals() classVSReal := tekstus.StringsToFloat64( classVS) gains[x].ComputeContinuFloat(&attr, &targetReal, &classVSReal) } } else { attr := col.ToStringSlice() attrV := col.ValueSpace if DEBUG >= 2 { fmt.Println("[cart] attr :", attr) fmt.Println("[cart] attrV:", attrV) } target := D.GetClassAsStrings() gains[x].ComputeDiscrete(&attr, &attrV, &target, &classVS) } if DEBUG >= 2 { fmt.Println("[cart] gain :", gains[x]) } } return }