コード例 #1
0
ファイル: bagging.go プロジェクト: jwmu/golearn
// generateTrainingAttrs chooses the base.Attributes used to train the model
// with the given index, stores the choice in b.selectedAttributes[model]
// (while holding b.lock) and returns it.
//
// When b.RandomFeatures is 0, every non-class Attribute of from is chosen in
// column order. Otherwise distinct non-class Attributes are drawn uniformly at
// random until b.RandomFeatures of them have been collected or from has no
// further candidates, so asking for more features than exist yields all of
// them instead of looping forever. The class Attribute is always appended
// exactly once as the final element.
func (b *BaggedModel) generateTrainingAttrs(model int, from *base.Instances) []base.Attribute {
	ret := make([]base.Attribute, 0, from.Cols+1)
	if b.RandomFeatures == 0 {
		for j := 0; j < from.Cols; j++ {
			// The class Attribute is appended below; copying it here as
			// well would put it into the result twice.
			if j == from.ClassIndex {
				continue
			}
			ret = append(ret, from.GetAttr(j))
		}
	} else {
		// Visiting the columns in a random permutation is a uniform draw
		// without replacement that is guaranteed to terminate, unlike
		// rejection sampling, which spins forever once every candidate has
		// already been taken.
		for _, attrIndex := range rand.Perm(from.Cols) {
			if len(ret) >= b.RandomFeatures {
				break
			}
			if attrIndex == from.ClassIndex {
				continue
			}
			attr := from.GetAttr(attrIndex)
			// Distinct columns may still carry Attributes that compare
			// equal; keep only the first of each.
			matched := false
			for _, a := range ret {
				if a.Equals(attr) {
					matched = true
					break
				}
			}
			if !matched {
				ret = append(ret, attr)
			}
		}
	}
	ret = append(ret, from.GetClassAttr())

	b.lock.Lock()
	b.selectedAttributes[model] = ret
	b.lock.Unlock()
	return ret
}
コード例 #2
0
ファイル: id3.go プロジェクト: 24hours/golearn
// Predict outputs a base.Instances containing predictions from this tree.
//
// The result has one column (the class Attribute of what) and what.Rows rows.
// For each row the tree is walked from d: at every node that has children,
// the row's value for the node's SplitAttr (rendered as a string) selects the
// child to descend into. The Class of the node at which the walk stops is
// written as that row's prediction.
//
// The walk stops at a node that has no children, or at a node for whose
// SplitAttr what.GetAttrIndex returns -1. If no child is keyed by the row's
// value, the walk continues into the child with the smallest key greater than
// that value, or, when no key is greater, the child with the largest key.
// This fallback is deterministic and does not depend on map iteration order.
func (d *DecisionTreeNode) Predict(what *base.Instances) *base.Instances {
	outputAttrs := []base.Attribute{what.GetClassAttr()}
	predictions := base.NewInstances(outputAttrs, what.Rows)
	for i := 0; i < what.Rows; i++ {
		cur := d
		for {
			// Testing the length instead of nil also treats an allocated but
			// empty map as a leaf; descending from such a node would yield a
			// nil child and a nil dereference on the next iteration.
			if len(cur.Children) == 0 {
				predictions.SetAttrStr(i, 0, cur.Class)
				break
			}
			at := cur.SplitAttr
			j := what.GetAttrIndex(at)
			if j == -1 {
				// The split Attribute could not be located in the input, so
				// the walk cannot go deeper; use this node's Class.
				predictions.SetAttrStr(i, 0, cur.Class)
				break
			}
			classVar := at.GetStringFromSysVal(what.Get(i, j))
			if next, ok := cur.Children[classVar]; ok {
				cur = next
				continue
			}
			// No child matches the value exactly. Scan all keys once and
			// track both fallback candidates so that the choice is the same
			// on every run regardless of the order the map is iterated in.
			var above, largest string
			haveAbove, haveAny := false, false
			for c := range cur.Children {
				if !haveAny || c > largest {
					largest, haveAny = c, true
				}
				if c > classVar && (!haveAbove || c < above) {
					above, haveAbove = c, true
				}
			}
			if haveAbove {
				cur = cur.Children[above]
			} else {
				cur = cur.Children[largest]
			}
		}
	}
	return predictions
}