// Run applies a trained BinningFilter to a set of Instances, // discretising any numeric attributes added. // // IMPORTANT: Run discretises in-place, so make sure to take // a copy if the original instances are still needed // // IMPORTANT: This function panic()s if the filter has not been // trained. Call Build() before running this function // // IMPORTANT: Call Build() after adding any additional attributes. // Otherwise, the training structure will be out of date from // the values expected and could cause a panic. func (b *BinningFilter) Run(on *base.Instances) { if !b.trained { panic("Call Build() beforehand") } for attr := range b.Attributes { minVal := b.MinVals[attr] maxVal := b.MaxVals[attr] disc := 0 // Casts to float32 to replicate a floating point precision error delta := float32(maxVal - minVal) delta /= float32(b.BinCount) for i := 0; i < on.Rows; i++ { val := on.Get(i, attr) if val <= minVal { disc = 0 } else { disc = int(math.Floor(float64(float32(val-minVal) / delta))) if disc >= b.BinCount { disc = b.BinCount - 1 } } on.Set(i, attr, float64(disc)) } newAttribute := new(base.CategoricalAttribute) newAttribute.SetName(on.GetAttr(attr).GetName()) for i := 0; i < b.BinCount; i++ { newAttribute.GetSysValFromString(fmt.Sprintf("%d", i)) } on.ReplaceAttr(attr, newAttribute) } }
// Run discretises the set of Instances `on' // // IMPORTANT: ChiMergeFilter discretises in place. func (c *ChiMergeFilter) Run(on *base.Instances) { if !c._Trained { panic("Call Build() beforehand") } for attr := range c.Tables { table := c.Tables[attr] for i := 0; i < on.Rows; i++ { val := on.Get(i, attr) dis := 0 for j, k := range table { if k.Value < val { dis = j continue } break } on.Set(i, attr, float64(dis)) } newAttribute := new(base.CategoricalAttribute) newAttribute.SetName(on.GetAttr(attr).GetName()) for _, k := range table { newAttribute.GetSysValFromString(fmt.Sprintf("%f", k.Value)) } on.ReplaceAttr(attr, newAttribute) } }
// GetAttributesAfterFiltering gets a list of before/after // Attributes as base.FilteredAttributes func (c *ChiMergeFilter) GetAttributesAfterFiltering() []base.FilteredAttribute { oldAttrs := c.train.AllAttributes() ret := make([]base.FilteredAttribute, len(oldAttrs)) for i, a := range oldAttrs { if c.attrs[a] { retAttr := new(base.CategoricalAttribute) retAttr.SetName(a.GetName()) for _, k := range c.tables[a] { retAttr.GetSysValFromString(fmt.Sprintf("%f", k.Value)) } ret[i] = base.FilteredAttribute{a, retAttr} } else { ret[i] = base.FilteredAttribute{a, a} } } return ret }
// Fit creates n filtered datasets (where n is the number of values // a CategoricalAttribute can take) and uses them to train the // underlying classifiers. func (m *OneVsAllModel) Fit(using base.FixedDataGrid) { var classAttr *base.CategoricalAttribute // Do some validation classAttrs := using.AllClassAttributes() for _, a := range classAttrs { if c, ok := a.(*base.CategoricalAttribute); !ok { panic("Unsupported ClassAttribute type") } else { classAttr = c } } attrs := m.generateAttributes(using) // Find the highest stored value val := uint64(0) classVals := classAttr.GetValues() for _, s := range classVals { cur := base.UnpackBytesToU64(classAttr.GetSysValFromString(s)) if cur > val { val = cur } } if val == 0 { panic("Must have more than one class!") } m.maxClassVal = val // Create individual filtered instances for training filters := make([]*oneVsAllFilter, val+1) classifiers := make([]base.Classifier, val+1) for i := uint64(0); i <= val; i++ { f := &oneVsAllFilter{ attrs, classAttr, i, } filters[i] = f classifiers[i] = m.NewClassifierFunction(classVals[int(i)]) classifiers[i].Fit(base.NewLazilyFilteredInstances(using, f)) } m.filters = filters m.classifiers = classifiers }
// GetAttributesAfterFiltering gets a list of before/after // Attributes as base.FilteredAttributes func (b *BinningFilter) GetAttributesAfterFiltering() []base.FilteredAttribute { oldAttrs := b.train.AllAttributes() ret := make([]base.FilteredAttribute, len(oldAttrs)) for i, a := range oldAttrs { if b.attrs[a] { retAttr := new(base.CategoricalAttribute) minVal := b.minVals[a] maxVal := b.maxVals[a] delta := float64(maxVal-minVal) / float64(b.bins) retAttr.SetName(a.GetName()) for i := 0; i <= b.bins; i++ { floatVal := float64(i)*delta + minVal fmtStr := fmt.Sprintf("%%.%df", a.(*base.FloatAttribute).Precision) binVal := fmt.Sprintf(fmtStr, floatVal) retAttr.GetSysValFromString(binVal) } ret[i] = base.FilteredAttribute{a, retAttr} } else { ret[i] = base.FilteredAttribute{a, a} } } return ret }