//
// Initialize will check crf inputs and set them to default values if they
// are invalid.
//
func (crf *Runtime) Initialize(samples tabula.ClasetInterface) error {
	if crf.NStage <= 0 {
		crf.NStage = DefStage
	}
	if crf.TPRate <= 0 || crf.TPRate >= 1 {
		crf.TPRate = DefTPRate
	}
	if crf.TNRate <= 0 || crf.TNRate >= 1 {
		crf.TNRate = DefTNRate
	}
	if crf.NTree <= 0 {
		crf.NTree = DefNumTree
	}
	if crf.PercentBoot <= 0 {
		crf.PercentBoot = DefPercentBoot
	}
	if crf.NRandomFeature <= 0 {
		// Set default value to square-root of number of features.
		ncol := samples.GetNColumn() - 1
		crf.NRandomFeature = int(math.Sqrt(float64(ncol)))
	}
	if crf.PerfFile == "" {
		crf.PerfFile = DefPerfFile
	}
	if crf.StatFile == "" {
		crf.StatFile = DefStatFile
	}

	crf.tnset = samples.Clone().(*tabula.Claset)

	return crf.Runtime.Initialize()
}
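// The standalone sketch below illustrates the same default-filling pattern in
// isolation: zero or out-of-range fields are replaced with defaults before the
// runtime is used. It is not the crf package itself; the struct, constants,
// and their values here are hypothetical stand-ins.
package main

import (
	"fmt"
	"math"
)

const (
	defStage       = 200
	defTPRate      = 0.9
	defTNRate      = 0.7
	defPercentBoot = 66
)

// miniRuntime is a hypothetical subset of the cascaded-forest runtime fields.
type miniRuntime struct {
	NStage         int
	TPRate         float64
	TNRate         float64
	PercentBoot    int
	NRandomFeature int
}

// initialize validates the inputs and falls back to defaults, mirroring the
// checks in Runtime.Initialize above; nfeature is the number of non-class
// columns in the sample set.
func (r *miniRuntime) initialize(nfeature int) {
	if r.NStage <= 0 {
		r.NStage = defStage
	}
	if r.TPRate <= 0 || r.TPRate >= 1 {
		r.TPRate = defTPRate
	}
	if r.TNRate <= 0 || r.TNRate >= 1 {
		r.TNRate = defTNRate
	}
	if r.PercentBoot <= 0 {
		r.PercentBoot = defPercentBoot
	}
	if r.NRandomFeature <= 0 {
		// Default: square-root of the number of features.
		r.NRandomFeature = int(math.Sqrt(float64(nfeature)))
	}
}

func main() {
	r := miniRuntime{TPRate: 1.5} // invalid rate, everything else zero
	r.initialize(16)
	fmt.Printf("%+v\n", r) // all fields now hold sane defaults
}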
//
// Initialize will check forest inputs and set them to default values if
// invalid.
//
// It will also calculate the number of random samples for each tree using,
//
//	number-of-sample * percentage-of-bootstrap
//
func (forest *Runtime) Initialize(samples tabula.ClasetInterface) error {
	if forest.NTree <= 0 {
		forest.NTree = DefNumTree
	}
	if forest.PercentBoot <= 0 {
		forest.PercentBoot = DefPercentBoot
	}
	if forest.NRandomFeature <= 0 {
		// Set default value to square-root of number of features.
		ncol := samples.GetNColumn() - 1
		forest.NRandomFeature = int(math.Sqrt(float64(ncol)))
	}
	if forest.OOBStatsFile == "" {
		forest.OOBStatsFile = DefOOBStatsFile
	}
	if forest.PerfFile == "" {
		forest.PerfFile = DefPerfFile
	}
	if forest.StatFile == "" {
		forest.StatFile = DefStatFile
	}

	forest.nSubsample = int(float32(samples.GetNRow()) *
		(float32(forest.PercentBoot) / 100.0))

	return forest.Runtime.Initialize()
}
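// A minimal standalone sketch of the bootstrap-size formula used above,
// nSubsample = number-of-sample * percentage-of-bootstrap / 100. The helper
// name and the sample numbers are illustrative only.
package main

import "fmt"

// bootstrapSize returns how many rows each tree samples, given the total
// number of rows and the bootstrap percentage.
func bootstrapSize(nrow, percentBoot int) int {
	return int(float32(nrow) * (float32(percentBoot) / 100.0))
}

func main() {
	// With 150 rows and the common 66% bootstrap, each tree trains on 99 rows.
	fmt.Println(bootstrapSize(150, 66)) // 99
}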
// SelectRandomFeature, if NRandomFeature is greater than zero, selects and
// computes gain on n random features instead of on all features.
func (runtime *Runtime) SelectRandomFeature(D tabula.ClasetInterface) {
	if runtime.NRandomFeature <= 0 {
		// All features selected.
		return
	}

	ncols := D.GetNColumn()

	// Count all features minus the class column.
	nfeature := ncols - 1
	if runtime.NRandomFeature >= nfeature {
		// Do nothing if the number of random features is equal to or
		// greater than the number of features in the dataset.
		return
	}

	// Exclude class index and parent node index.
	excludeIdx := []int{D.GetClassIndex()}
	cols := D.GetColumns()
	for x, col := range *cols {
		if (col.Flag & ColFlagParent) == ColFlagParent {
			excludeIdx = append(excludeIdx, x)
		} else {
			(*cols)[x].Flag |= ColFlagSkip
		}
	}

	// Select random features excluding features in `excludeIdx`.
	var pickedIdx []int
	for x := 0; x < runtime.NRandomFeature; x++ {
		idx := numerus.IntPickRandPositive(ncols, false, pickedIdx,
			excludeIdx)
		pickedIdx = append(pickedIdx, idx)

		// Remove the skip flag on the selected column.
		col := D.GetColumn(idx)
		col.Flag = col.Flag &^ ColFlagSkip
	}

	if DEBUG >= 1 {
		fmt.Println("[cart] selected random features:", pickedIdx)
		fmt.Println("[cart] selected columns        :", D.GetColumns())
	}
}
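// A standalone sketch of the same selection idea: pick n distinct column
// indices at random while excluding some (here, the class column), without
// relying on the numerus helper. All names below are illustrative.
package main

import (
	"fmt"
	"math/rand"
)

// pickRandomFeatures returns n distinct indices in [0, ncols) that are not in
// exclude. It assumes n is not larger than the number of allowed indices.
func pickRandomFeatures(ncols, n int, exclude []int) []int {
	excluded := make(map[int]bool, len(exclude))
	for _, idx := range exclude {
		excluded[idx] = true
	}

	// Collect the candidate indices, then shuffle and take the first n.
	var candidates []int
	for idx := 0; idx < ncols; idx++ {
		if !excluded[idx] {
			candidates = append(candidates, idx)
		}
	}
	rand.Shuffle(len(candidates), func(i, j int) {
		candidates[i], candidates[j] = candidates[j], candidates[i]
	})
	return candidates[:n]
}

func main() {
	// 10 columns, column 9 is the class; pick 3 random feature columns.
	fmt.Println(pickRandomFeatures(10, 3, []int{9}))
}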
/*
computeGain calculates the Gini index for each value in each attribute.
*/
func (runtime *Runtime) computeGain(D tabula.ClasetInterface) (
	gains []gini.Gini,
) {
	switch runtime.SplitMethod {
	case SplitMethodGini:
		// Create a gain slot for every column; the class column entry
		// is skipped in the loop below.
		gains = make([]gini.Gini, D.GetNColumn())
	}

	runtime.SelectRandomFeature(D)

	classVS := D.GetClassValueSpace()
	classIdx := D.GetClassIndex()
	classType := D.GetClassType()

	for x, col := range *D.GetColumns() {
		// Skip the class attribute.
		if x == classIdx {
			continue
		}

		// Skip columns flagged as parent.
		if (col.Flag & ColFlagParent) == ColFlagParent {
			gains[x].Skip = true
			continue
		}

		// Ignore columns flagged with skip.
		if (col.Flag & ColFlagSkip) == ColFlagSkip {
			gains[x].Skip = true
			continue
		}

		// Compute gain.
		if col.GetType() == tabula.TReal {
			attr := col.ToFloatSlice()

			if classType == tabula.TString {
				target := D.GetClassAsStrings()
				gains[x].ComputeContinu(&attr, &target, &classVS)
			} else {
				targetReal := D.GetClassAsReals()
				classVSReal := tekstus.StringsToFloat64(classVS)

				gains[x].ComputeContinuFloat(&attr, &targetReal,
					&classVSReal)
			}
		} else {
			attr := col.ToStringSlice()
			attrV := col.ValueSpace

			if DEBUG >= 2 {
				fmt.Println("[cart] attr :", attr)
				fmt.Println("[cart] attrV:", attrV)
			}

			target := D.GetClassAsStrings()
			gains[x].ComputeDiscrete(&attr, &attrV, &target,
				&classVS)
		}

		if DEBUG >= 2 {
			fmt.Println("[cart] gain :", gains[x])
		}
	}
	return
}
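// A standalone sketch of what the Gini computation measures for a discrete
// attribute: the impurity of the class distribution before and after
// partitioning by attribute value. It does not use the gini package API; all
// names are illustrative.
package main

import "fmt"

// giniIndex returns 1 - sum(p_i^2) over the class labels in target.
func giniIndex(target []string) float64 {
	counts := make(map[string]int)
	for _, c := range target {
		counts[c]++
	}
	n := float64(len(target))
	impurity := 1.0
	for _, cnt := range counts {
		p := float64(cnt) / n
		impurity -= p * p
	}
	return impurity
}

// giniGain returns the impurity reduction obtained by splitting target on the
// discrete attribute attr.
func giniGain(attr, target []string) float64 {
	parts := make(map[string][]string)
	for i, v := range attr {
		parts[v] = append(parts[v], target[i])
	}
	n := float64(len(target))
	weighted := 0.0
	for _, part := range parts {
		weighted += float64(len(part)) / n * giniIndex(part)
	}
	return giniIndex(target) - weighted
}

func main() {
	attr := []string{"sunny", "sunny", "rain", "rain"}
	target := []string{"no", "no", "yes", "yes"}
	// The attribute separates the classes perfectly, so the gain equals
	// the parent impurity, 0.5.
	fmt.Println(giniGain(attr, target))
}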