예제 #1
0
파일: cart.go 프로젝트: shuLhan/go-mining
// SelectRandomFeature if NRandomFeature is greater than zero, select and
// compute gain in n random features instead of in all features.
//
// When random selection applies, two things happen:
//   - Every column of D that is not flagged with ColFlagParent is first
//     marked with ColFlagSkip.
//   - Up to NRandomFeature columns are picked at random and have their
//     ColFlagSkip cleared again. The class column and parent-flagged
//     columns are never picked.
//
// computeGain later ignores columns that still carry ColFlagSkip.
//
// Nothing is changed when NRandomFeature is zero or negative, or when it is
// equal to or greater than the number of non-class columns in D.
func (runtime *Runtime) SelectRandomFeature(D tabula.ClasetInterface) {
	if runtime.NRandomFeature <= 0 {
		// all features selected
		return
	}

	ncols := D.GetNColumn()

	// count all features minus class
	nfeature := ncols - 1
	if runtime.NRandomFeature >= nfeature {
		// Do nothing if number of random feature equal or greater than
		// number of feature in dataset.
		return
	}

	// Exclude class index and parent node index from the random pick and
	// mark every non-parent column as skipped for now. At the same time,
	// count how many columns are actually eligible to be picked.
	classIdx := D.GetClassIndex()
	excludeIdx := []int{classIdx}
	ncandidate := 0
	cols := D.GetColumns()
	for x, col := range *cols {
		if (col.Flag & ColFlagParent) == ColFlagParent {
			excludeIdx = append(excludeIdx, x)
			continue
		}
		(*cols)[x].Flag |= ColFlagSkip
		if x != classIdx {
			ncandidate++
		}
	}

	// The guard above subtracts only the class column, but parent columns
	// are excluded from the pick too. There may therefore be fewer
	// candidates than NRandomFeature.
	// NOTE(review): IntPickRandPositive with dup=false presumably retries
	// until it draws an index that is neither excluded nor already picked.
	// Asking it for more indices than there are candidates would then
	// never return, so cap the number of picks.
	npick := runtime.NRandomFeature
	if npick > ncandidate {
		npick = ncandidate
	}

	// Select random features excluding feature in `excludeIdx`.
	pickedIdx := make([]int, 0, npick)
	for x := 0; x < npick; x++ {
		idx := numerus.IntPickRandPositive(ncols, false, pickedIdx,
			excludeIdx)
		pickedIdx = append(pickedIdx, idx)

		// Remove skip flag on selected column
		col := D.GetColumn(idx)
		col.Flag &^= ColFlagSkip
	}

	if DEBUG >= 1 {
		fmt.Println("[cart] selected random features:", pickedIdx)
		fmt.Println("[cart] selected columns        :", D.GetColumns())
	}
}
예제 #2
0
파일: cart.go 프로젝트: shuLhan/go-mining
/*
computeGain calculate the gini index for each value in each attribute.

The returned slice has one entry per column of D and is indexed by column
position. SelectRandomFeature is called first, so when NRandomFeature > 0
only a random subset of columns is evaluated.

  - The class column's entry is left at its zero value.
  - Columns flagged with ColFlagParent or ColFlagSkip get Skip set to true
    and no gain computed.
  - Real-valued (tabula.TReal) columns are evaluated as continuous
    attributes.
  - All other columns are evaluated as discrete attributes.
*/
func (runtime *Runtime) computeGain(D tabula.ClasetInterface) (
	gains []gini.Gini,
) {
	// Only Gini gain is computed below, whatever SplitMethod says.
	// Allocate one entry per column unconditionally. The class column is
	// included so that gains can be indexed by column position.
	// Allocating unconditionally means a SplitMethod other than
	// SplitMethodGini cannot leave gains nil and make the gains[x] writes
	// below panic.
	gains = make([]gini.Gini, D.GetNColumn())

	runtime.SelectRandomFeature(D)

	classVS := D.GetClassValueSpace()
	classIdx := D.GetClassIndex()
	classType := D.GetClassType()

	for x, col := range *D.GetColumns() {
		// skip class attribute.
		if x == classIdx {
			continue
		}

		// Skip columns used for the split at the parent node and
		// columns not selected by SelectRandomFeature.
		if (col.Flag&ColFlagParent) == ColFlagParent ||
			(col.Flag&ColFlagSkip) == ColFlagSkip {
			gains[x].Skip = true
			continue
		}

		// compute gain.
		if col.GetType() == tabula.TReal {
			attr := col.ToFloatSlice()

			if classType == tabula.TString {
				target := D.GetClassAsStrings()
				gains[x].ComputeContinu(&attr, &target,
					&classVS)
			} else {
				targetReal := D.GetClassAsReals()
				classVSReal := tekstus.StringsToFloat64(
					classVS)

				gains[x].ComputeContinuFloat(&attr,
					&targetReal, &classVSReal)
			}
		} else {
			attr := col.ToStringSlice()
			attrV := col.ValueSpace

			if DEBUG >= 2 {
				fmt.Println("[cart] attr :", attr)
				fmt.Println("[cart] attrV:", attrV)
			}

			target := D.GetClassAsStrings()
			gains[x].ComputeDiscrete(&attr, &attrV, &target,
				&classVS)
		}

		if DEBUG >= 2 {
			fmt.Println("[cart] gain :", gains[x])
		}
	}
	return gains
}
예제 #3
0
파일: cart.go 프로젝트: shuLhan/go-mining
/*
splitTreeByGain calculate the gain in all dataset, and split into two node:
left and right.

The (sub)tree for dataset D is built recursively:

  - if D has no rows, a leaf labeled with D's majority class is returned;
  - if all rows of D are in a single class, a leaf labeled with that class
    is returned;
  - if the attribute with the maximum Gini gain has a gain of zero, a leaf
    labeled with D's majority class is returned;
  - otherwise D is sorted by the max-gain attribute and split into a left
    and a right dataset by its split value. In both halves that attribute
    is flagged with ColFlagParent and all other column flags are reset.
    Both halves are then split recursively.

Return node with the split information. If splitting D or one of the
recursive calls fails, the node built so far is returned together with the
error.
*/
func (runtime *Runtime) splitTreeByGain(D tabula.ClasetInterface) (
	node *binary.BTNode,
	e error,
) {
	node = &binary.BTNode{}

	D.RecountMajorMinor()

	// if dataset is empty return node labeled with majority classes in
	// dataset.
	nrow := D.GetNRow()

	if nrow <= 0 {
		if DEBUG >= 2 {
			fmt.Printf("[cart] empty dataset (%s) : %v\n",
				D.MajorityClass(), D)
		}

		node.Value = NodeValue{
			IsLeaf: true,
			Class:  D.MajorityClass(),
			Size:   0,
		}
		return node, nil
	}

	// if all dataset is in the same class, return node as leaf with class
	// is set to that class.
	single, name := D.IsInSingleClass()
	if single {
		if DEBUG >= 2 {
			fmt.Printf("[cart] in single class (%s): %v\n", name,
				D.GetColumns())
		}

		node.Value = NodeValue{
			IsLeaf: true,
			Class:  name,
			Size:   nrow,
		}
		return node, nil
	}

	if DEBUG >= 2 {
		fmt.Println("[cart] D:", D)
	}

	// calculate the Gini gain for each attribute.
	gains := runtime.computeGain(D)

	// get attribute with maximum Gini gain.
	maxGainIdx := gini.FindMaxGain(&gains)
	maxGain := gains[maxGainIdx]

	// if maxgain value is 0, use majority class as node and terminate
	// the process
	if maxGain.GetMaxGainValue() == 0 {
		if DEBUG >= 2 {
			fmt.Println("[cart] max gain 0 with target",
				D.GetClassAsStrings(),
				" and majority class is ", D.MajorityClass())
		}

		// The leaf still covers every row of D, so record nrow as its
		// size, just like the single-class leaf above.
		node.Value = NodeValue{
			IsLeaf: true,
			Class:  D.MajorityClass(),
			Size:   nrow,
		}
		return node, nil
	}

	// using the sorted index in maxGain, sort all field in dataset
	tabula.SortColumnsByIndex(D, maxGain.SortedIndex)

	if DEBUG >= 2 {
		fmt.Println("[cart] maxgain:", maxGain)
	}

	// Now we have the attribute with max gain in maxGainIdx, and its gain
	// and partition value in maxGain. Split the dataset based on the type
	// of the max-gain attribute:
	//   - if it is continuous, split using its numeric value;
	//   - if it is discrete, split using a subset (partition) of its
	//     nominal values.
	var splitV interface{}

	if maxGain.IsContinu {
		splitV = maxGain.GetMaxPartGainValue()
	} else {
		attrPartV := maxGain.GetMaxPartGainValue()
		attrSubV := attrPartV.(tekstus.ListStrings)
		splitV = attrSubV[0].Normalize()
	}

	if DEBUG >= 2 {
		fmt.Println("[cart] maxgainindex:", maxGainIdx)
		fmt.Println("[cart] split v:", splitV)
	}

	node.Value = NodeValue{
		SplitAttrName: D.GetColumn(maxGainIdx).GetName(),
		IsLeaf:        false,
		IsContinu:     maxGain.IsContinu,
		Size:          nrow,
		SplitAttrIdx:  maxGainIdx,
		SplitV:        splitV,
	}

	dsL, dsR, e := tabula.SplitRowsByValue(D, maxGainIdx, splitV)
	if e != nil {
		return node, e
	}

	// Report an unexpected dataset type as an error instead of
	// panicking on a failed type assertion.
	splitL, ok := dsL.(tabula.ClasetInterface)
	if !ok {
		return node, fmt.Errorf(
			"cart: left split has type %T, want tabula.ClasetInterface",
			dsL)
	}
	splitR, ok := dsR.(tabula.ClasetInterface)
	if !ok {
		return node, fmt.Errorf(
			"cart: right split has type %T, want tabula.ClasetInterface",
			dsR)
	}

	// Set the flag to parent in the attribute referenced by maxGainIdx,
	// so it will not be computed again in the next round. Every other
	// column flag is cleared, including ColFlagSkip set by
	// SelectRandomFeature.
	for _, split := range []tabula.ClasetInterface{splitL, splitR} {
		cols := split.GetColumns()
		for x := range *cols {
			if x == maxGainIdx {
				(*cols)[x].Flag = ColFlagParent
			} else {
				(*cols)[x].Flag = 0
			}
		}
	}

	nodeLeft, e := runtime.splitTreeByGain(splitL)
	if e != nil {
		return node, e
	}

	nodeRight, e := runtime.splitTreeByGain(splitR)
	if e != nil {
		return node, e
	}

	node.SetLeft(nodeLeft)
	node.SetRight(nodeRight)

	return node, nil
}