func findBestSplit(partition base.FixedDataGrid) { var delta float64 delta = math.MinInt64 attrs := partition.AllAttributes() classAttrs := partition.AllClassAttributes() candidates := base.AttributeDifferenceReferences(attrs, classAttrs) fmt.Println(delta) fmt.Println(classAttrs) fmt.Println(reflect.TypeOf(partition)) fmt.Println(reflect.TypeOf(candidates)) for i, n := range attrs { fmt.Println(i) //fmt.Println(partition) fmt.Println(reflect.TypeOf(n)) attributeSpec, _ := partition.GetAttribute(n) fmt.Println(partition.GetAttribute(n)) _, rows := partition.Size() for j := 0; j < rows; j++ { data := partition.Get(attributeSpec, j) fmt.Println(base.UnpackBytesToFloat(data)) } } }
// Predict outputs a base.Instances containing predictions from this tree func (d *DecisionTreeNode) Predict(what base.FixedDataGrid) (base.FixedDataGrid, error) { predictions := base.GeneratePredictionVector(what) classAttr := getClassAttr(predictions) classAttrSpec, err := predictions.GetAttribute(classAttr) if err != nil { panic(err) } predAttrs := base.AttributeDifferenceReferences(what.AllAttributes(), predictions.AllClassAttributes()) predAttrSpecs := base.ResolveAttributes(what, predAttrs) what.MapOverRows(predAttrSpecs, func(row [][]byte, rowNo int) (bool, error) { cur := d for { if cur.Children == nil { predictions.Set(classAttrSpec, rowNo, classAttr.GetSysValFromString(cur.Class)) break } else { splitVal := cur.SplitRule.SplitVal at := cur.SplitRule.SplitAttr ats, err := what.GetAttribute(at) if err != nil { //predictions.Set(classAttrSpec, rowNo, classAttr.GetSysValFromString(cur.Class)) //break panic(err) } var classVar string if _, ok := ats.GetAttribute().(*base.FloatAttribute); ok { // If it's a numeric Attribute (e.g. FloatAttribute) check that // the value of the current node is greater than the old one classVal := base.UnpackBytesToFloat(what.Get(ats, rowNo)) if classVal > splitVal { classVar = "1" } else { classVar = "0" } } else { classVar = ats.GetAttribute().GetStringFromSysVal(what.Get(ats, rowNo)) } if next, ok := cur.Children[classVar]; ok { cur = next } else { // Suspicious of this var bestChild string for c := range cur.Children { bestChild = c if c > classVar { break } } cur = cur.Children[bestChild] } } } return true, nil }) return predictions, nil }
// Predict outputs a base.Instances containing predictions from this tree func (d *DecisionTreeNode) Predict(what base.FixedDataGrid) base.FixedDataGrid { predictions := base.GeneratePredictionVector(what) classAttr := getClassAttr(predictions) classAttrSpec, err := predictions.GetAttribute(classAttr) if err != nil { panic(err) } predAttrs := base.AttributeDifferenceReferences(what.AllAttributes(), predictions.AllClassAttributes()) predAttrSpecs := base.ResolveAttributes(what, predAttrs) what.MapOverRows(predAttrSpecs, func(row [][]byte, rowNo int) (bool, error) { cur := d for { if cur.Children == nil { predictions.Set(classAttrSpec, rowNo, classAttr.GetSysValFromString(cur.Class)) break } else { at := cur.SplitAttr ats, err := what.GetAttribute(at) if err != nil { predictions.Set(classAttrSpec, rowNo, classAttr.GetSysValFromString(cur.Class)) break } classVar := ats.GetAttribute().GetStringFromSysVal(what.Get(ats, rowNo)) if next, ok := cur.Children[classVar]; ok { cur = next } else { var bestChild string for c := range cur.Children { bestChild = c if c > classVar { break } } cur = cur.Children[bestChild] } } } return true, nil }) return predictions }