예제 #1
0
// findBestSplit is an exploratory helper: it dumps information about
// partition to standard output and returns nothing.
//
// It prints, in order: the initial delta sentinel, the class attributes, the
// dynamic type of partition, and the dynamic type of the candidate
// (non-class) attribute list. Then, for every attribute of partition
// (class attributes included), it prints the attribute's index, its dynamic
// type, the result of resolving it with GetAttribute, and every row value
// decoded with base.UnpackBytesToFloat.
//
// Attributes whose specification cannot be resolved are reported (the error
// is printed together with the spec) and their rows are skipped, instead of
// reading rows through an invalid specification.
func findBestSplit(partition base.FixedDataGrid) {
	// Same value the original assigned (math.MinInt64 converted to float64);
	// it is only printed, never updated.
	delta := float64(math.MinInt64)

	attrs := partition.AllAttributes()
	classAttrs := partition.AllClassAttributes()
	candidates := base.AttributeDifferenceReferences(attrs, classAttrs)

	fmt.Println(delta)
	fmt.Println(classAttrs)
	fmt.Println(reflect.TypeOf(partition))
	fmt.Println(reflect.TypeOf(candidates))

	// The row count is the same for every attribute, so read it once rather
	// than once per attribute.
	_, rows := partition.Size()

	for i, n := range attrs {
		fmt.Println(i)
		fmt.Println(reflect.TypeOf(n))

		// Resolve once and print both results; this produces the same line
		// as printing the two-valued call directly.
		attributeSpec, err := partition.GetAttribute(n)
		fmt.Println(attributeSpec, err)
		if err != nil {
			// The spec is not usable; do not read rows through it.
			continue
		}

		for j := 0; j < rows; j++ {
			fmt.Println(base.UnpackBytesToFloat(partition.Get(attributeSpec, j)))
		}
	}
}
예제 #2
0
파일: random.go 프로젝트: CTLife/golearn
// GenerateSplitRule picks up to r.Attributes distinct non-class attributes of
// f at random and returns the rule that r.internalRule selects from among
// them via GetSplitRuleFromSelection.
//
// The number of attributes drawn is capped at the number of non-class
// attributes available. Without the cap, asking for more attributes than
// exist would loop forever (no new distinct attribute can ever be drawn),
// and a grid with no non-class attributes would panic in rand.Intn(0).
// When r.Attributes does not exceed the number available, the sequence of
// random draws and the resulting selection order are unchanged.
func (r *RandomTreeRuleGenerator) GenerateSplitRule(f base.FixedDataGrid) *DecisionTreeRule {
	// Candidate pool: every attribute that is not a class attribute.
	allAttributes := base.AttributeDifferenceReferences(f.AllAttributes(), f.AllClassAttributes())

	// Never try to collect more distinct attributes than the pool holds.
	wanted := r.Attributes
	if wanted > len(allAttributes) {
		wanted = len(allAttributes)
	}

	// Rejection sampling: draw an index, keep the attribute only if it has
	// not been selected already.
	var consideredAttributes []base.Attribute
	for len(consideredAttributes) < wanted {
		candidate := allAttributes[rand.Intn(len(allAttributes))]

		duplicate := false
		for _, a := range consideredAttributes {
			if a.Equals(candidate) {
				duplicate = true
				break
			}
		}
		if !duplicate {
			consideredAttributes = append(consideredAttributes, candidate)
		}
	}

	return r.internalRule.GetSplitRuleFromSelection(consideredAttributes, f)
}
예제 #3
0
// GenerateSplitAttribute computes the set of non-class Attributes of f and
// returns whichever of them r.GetSplitAttributeFromSelection chooses (the
// one maximising information gain, per the original documentation).
//
// IMPORTANT: passing a base.Instances with no Attributes other than the class
// variable will panic()
func (r *InformationGainRuleGenerator) GenerateSplitAttribute(f base.FixedDataGrid) base.Attribute {
	// Everything in f that is not a class attribute is a split candidate.
	nonClass := base.AttributeDifferenceReferences(f.AllAttributes(), f.AllClassAttributes())
	return r.GetSplitAttributeFromSelection(nonClass, f)
}
예제 #4
0
파일: gini.go 프로젝트: CTLife/golearn
// GenerateSplitRule computes the set of non-class Attributes of f and returns
// the DecisionTreeRule that g.GetSplitRuleFromSelection derives from them.
//
// IMPORTANT: passing a base.Instances with no Attributes other than the class
// variable will panic()
func (g *GiniCoefficientRuleGenerator) GenerateSplitRule(f base.FixedDataGrid) *DecisionTreeRule {
	// Everything in f that is not a class attribute is a split candidate.
	nonClass := base.AttributeDifferenceReferences(f.AllAttributes(), f.AllClassAttributes())
	return g.GetSplitRuleFromSelection(nonClass, f)
}
예제 #5
0
파일: id3.go 프로젝트: tanduong/golearn
// Predict outputs a base.FixedDataGrid containing predictions from this tree.
//
// For every row of what, the tree is walked from d down to a node without
// children, and that node's Class is written into the prediction vector.
// At each internal node the row's value for the node's split attribute
// selects the child:
//   - for a *base.FloatAttribute the child key is "1" when the value is
//     greater than the rule's SplitVal and "0" otherwise;
//   - for any other attribute the child key is the value's string form.
//
// When no child exists for the computed key, a fallback child is chosen
// deterministically: the smallest child key greater than the computed key,
// or the largest child key if none is greater. (Previously the choice
// depended on Go's randomized map iteration order, so the same input could
// yield different predictions between runs.)
//
// Errors resolving the class attribute or a split attribute, and any error
// reported by MapOverRows, are returned to the caller rather than raised as
// panics; the returned grid is nil in that case.
func (d *DecisionTreeNode) Predict(what base.FixedDataGrid) (base.FixedDataGrid, error) {
	predictions := base.GeneratePredictionVector(what)
	classAttr := getClassAttr(predictions)
	classAttrSpec, err := predictions.GetAttribute(classAttr)
	if err != nil {
		return nil, err
	}
	predAttrs := base.AttributeDifferenceReferences(what.AllAttributes(), predictions.AllClassAttributes())
	predAttrSpecs := base.ResolveAttributes(what, predAttrs)
	err = what.MapOverRows(predAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
		cur := d
		// A node with a nil or empty Children map is treated as a leaf.
		for len(cur.Children) > 0 {
			splitVal := cur.SplitRule.SplitVal
			ats, err := what.GetAttribute(cur.SplitRule.SplitAttr)
			if err != nil {
				// Abort the iteration and surface the error to the caller.
				return false, err
			}

			var classVar string
			if _, ok := ats.GetAttribute().(*base.FloatAttribute); ok {
				// Numeric split: children are keyed "1" (above the
				// threshold) and "0" (at or below it).
				if base.UnpackBytesToFloat(what.Get(ats, rowNo)) > splitVal {
					classVar = "1"
				} else {
					classVar = "0"
				}
			} else {
				classVar = ats.GetAttribute().GetStringFromSysVal(what.Get(ats, rowNo))
			}

			next, ok := cur.Children[classVar]
			if !ok {
				// No child for this value. Choose the smallest key above
				// classVar, falling back to the largest key overall, so
				// the result does not depend on map iteration order.
				var above, highest string
				haveAbove, haveAny := false, false
				for c := range cur.Children {
					if !haveAny || c > highest {
						highest, haveAny = c, true
					}
					if c > classVar && (!haveAbove || c < above) {
						above, haveAbove = c, true
					}
				}
				if haveAbove {
					next = cur.Children[above]
				} else {
					next = cur.Children[highest]
				}
			}
			cur = next
		}
		predictions.Set(classAttrSpec, rowNo, classAttr.GetSysValFromString(cur.Class))
		return true, nil
	})
	if err != nil {
		return nil, err
	}
	return predictions, nil
}
예제 #6
0
파일: id3.go 프로젝트: JacobXie/golearn
// Predict outputs a base.FixedDataGrid containing predictions from this tree.
//
// For every row of what, the tree is walked from d down to a node without
// children, and that node's Class is written into the prediction vector.
// At each internal node the string form of the row's value for the node's
// SplitAttr selects the child. If the split attribute cannot be resolved in
// what, descent stops and the current node's Class is used.
//
// When no child exists for the row's value, a fallback child is chosen
// deterministically: the smallest child key greater than the value, or the
// largest child key if none is greater. (Previously the choice depended on
// Go's randomized map iteration order, so the same input could yield
// different predictions between runs.)
//
// Predict panics if the class attribute of the prediction vector cannot be
// resolved; the signature offers no error return.
func (d *DecisionTreeNode) Predict(what base.FixedDataGrid) base.FixedDataGrid {
	predictions := base.GeneratePredictionVector(what)
	classAttr := getClassAttr(predictions)
	classAttrSpec, err := predictions.GetAttribute(classAttr)
	if err != nil {
		panic(err)
	}
	predAttrs := base.AttributeDifferenceReferences(what.AllAttributes(), predictions.AllClassAttributes())
	predAttrSpecs := base.ResolveAttributes(what, predAttrs)
	what.MapOverRows(predAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
		cur := d
		// A node with a nil or empty Children map is treated as a leaf.
		for len(cur.Children) > 0 {
			ats, err := what.GetAttribute(cur.SplitAttr)
			if err != nil {
				// Split attribute is unavailable: stop here and predict
				// the current node's class.
				break
			}

			classVar := ats.GetAttribute().GetStringFromSysVal(what.Get(ats, rowNo))
			next, ok := cur.Children[classVar]
			if !ok {
				// No child for this value. Choose the smallest key above
				// classVar, falling back to the largest key overall, so
				// the result does not depend on map iteration order.
				var above, highest string
				haveAbove, haveAny := false, false
				for c := range cur.Children {
					if !haveAny || c > highest {
						highest, haveAny = c, true
					}
					if c > classVar && (!haveAbove || c < above) {
						above, haveAbove = c, true
					}
				}
				if haveAbove {
					next = cur.Children[above]
				} else {
					next = cur.Children[highest]
				}
			}
			cur = next
		}
		predictions.Set(classAttrSpec, rowNo, classAttr.GetSysValFromString(cur.Class))
		return true, nil
	})
	return predictions
}