// Run applies a trained BinningFilter to a set of Instances, // discretising any numeric attributes added. // // IMPORTANT: Run discretises in-place, so make sure to take // a copy if the original instances are still needed // // IMPORTANT: This function panic()s if the filter has not been // trained. Call Build() before running this function // // IMPORTANT: Call Build() after adding any additional attributes. // Otherwise, the training structure will be out of date from // the values expected and could cause a panic. func (b *BinningFilter) Run(on *base.Instances) { if !b.trained { panic("Call Build() beforehand") } for attr := range b.Attributes { minVal := b.MinVals[attr] maxVal := b.MaxVals[attr] disc := 0 // Casts to float32 to replicate a floating point precision error delta := float32(maxVal - minVal) delta /= float32(b.BinCount) for i := 0; i < on.Rows; i++ { val := on.Get(i, attr) if val <= minVal { disc = 0 } else { disc = int(math.Floor(float64(float32(val-minVal) / delta))) if disc >= b.BinCount { disc = b.BinCount - 1 } } on.Set(i, attr, float64(disc)) } newAttribute := new(base.CategoricalAttribute) newAttribute.SetName(on.GetAttr(attr).GetName()) for i := 0; i < b.BinCount; i++ { newAttribute.GetSysValFromString(fmt.Sprintf("%d", i)) } on.ReplaceAttr(attr, newAttribute) } }
// Predict outputs a base.Instances containing predictions from this tree func (d *DecisionTreeNode) Predict(what *base.Instances) *base.Instances { outputAttrs := make([]base.Attribute, 1) outputAttrs[0] = what.GetClassAttr() predictions := base.NewInstances(outputAttrs, what.Rows) for i := 0; i < what.Rows; i++ { cur := d for { if cur.Children == nil { predictions.SetAttrStr(i, 0, cur.Class) break } else { at := cur.SplitAttr j := what.GetAttrIndex(at) if j == -1 { predictions.SetAttrStr(i, 0, cur.Class) break } classVar := at.GetStringFromSysVal(what.Get(i, j)) if next, ok := cur.Children[classVar]; ok { cur = next } else { var bestChild string for c := range cur.Children { bestChild = c if c > classVar { break } } cur = cur.Children[bestChild] } } } } return predictions }
func convertInstancesToLabelVec(X *base.Instances) []float64 { labelVec := make([]float64, X.Rows) for i := 0; i < X.Rows; i++ { labelVec[i] = X.Get(i, X.ClassIndex) } return labelVec }
// Run discretises the set of Instances `on' // // IMPORTANT: ChiMergeFilter discretises in place. func (c *ChiMergeFilter) Run(on *base.Instances) { if !c._Trained { panic("Call Build() beforehand") } for attr := range c.Tables { table := c.Tables[attr] for i := 0; i < on.Rows; i++ { val := on.Get(i, attr) dis := 0 for j, k := range table { if k.Value < val { dis = j continue } break } on.Set(i, attr, float64(dis)) } newAttribute := new(base.CategoricalAttribute) newAttribute.SetName(on.GetAttr(attr).GetName()) for _, k := range table { newAttribute.GetSysValFromString(fmt.Sprintf("%f", k.Value)) } on.ReplaceAttr(attr, newAttribute) } }
func ChiMBuildFrequencyTable(attr int, inst *base.Instances) []*FrequencyTableEntry { ret := make([]*FrequencyTableEntry, 0) var attribute *base.FloatAttribute attribute, ok := inst.GetAttr(attr).(*base.FloatAttribute) if !ok { panic("only use Chi-M on numeric stuff") } for i := 0; i < inst.Rows; i++ { value := inst.Get(i, attr) valueConv := attribute.GetUsrVal(value) class := inst.GetClass(i) // Search the frequency table for the value found := false for _, entry := range ret { if entry.Value == valueConv { found = true entry.Frequency[class] += 1 } } if !found { newEntry := &FrequencyTableEntry{ valueConv, make(map[string]int), } newEntry.Frequency[class] = 1 ret = append(ret, newEntry) } } return ret }
func (lr *LinearRegression) Fit(inst *base.Instances) error { if inst.Rows < inst.GetAttributeCount() { return NotEnoughDataError } // Split into two matrices, observed results (dependent variable y) // and the explanatory variables (X) - see http://en.wikipedia.org/wiki/Linear_regression observed := mat64.NewDense(inst.Rows, 1, nil) explVariables := mat64.NewDense(inst.Rows, inst.GetAttributeCount(), nil) for i := 0; i < inst.Rows; i++ { observed.Set(i, 0, inst.Get(i, inst.ClassIndex)) // Set observed data for j := 0; j < inst.GetAttributeCount(); j++ { if j == 0 { // Set intercepts to 1.0 // Could / should be done better: http://www.theanalysisfactor.com/interpret-the-intercept/ explVariables.Set(i, 0, 1.0) } else { explVariables.Set(i, j, inst.Get(i, j-1)) } } } n := inst.GetAttributeCount() qr := mat64.QR(explVariables) q := qr.Q() reg := qr.R() var transposed, qty mat64.Dense transposed.TCopy(q) qty.Mul(&transposed, observed) regressionCoefficients := make([]float64, n) for i := n - 1; i >= 0; i-- { regressionCoefficients[i] = qty.At(i, 0) for j := i + 1; j < n; j++ { regressionCoefficients[i] -= regressionCoefficients[j] * reg.At(i, j) } regressionCoefficients[i] /= reg.At(i, i) } lr.disturbance = regressionCoefficients[0] lr.regressionCoefficients = regressionCoefficients[1:] lr.fitted = true return nil }
func (lr *LogisticRegression) Predict(X *base.Instances) *base.Instances { ret := X.GeneratePredictionVector() row := make([]float64, X.Cols-1) for i := 0; i < X.Rows; i++ { rowCounter := 0 for j := 0; j < X.Cols; j++ { if j != X.ClassIndex { row[rowCounter] = X.Get(i, j) rowCounter++ } } fmt.Println(Predict(lr.model, row), row) ret.Set(i, 0, Predict(lr.model, row)) } return ret }
func convertInstancesToProblemVec(X *base.Instances) [][]float64 { problemVec := make([][]float64, X.Rows) for i := 0; i < X.Rows; i++ { problemVecCounter := 0 problemVec[i] = make([]float64, X.Cols-1) for j := 0; j < X.Cols; j++ { if j == X.ClassIndex { continue } problemVec[i][problemVecCounter] = X.Get(i, j) problemVecCounter++ } } fmt.Println(problemVec, X) return problemVec }
func (lr *LinearRegression) Predict(X *base.Instances) (*base.Instances, error) { if !lr.fitted { return nil, NoTrainingDataError } ret := X.GeneratePredictionVector() for i := 0; i < X.Rows; i++ { var prediction float64 = lr.disturbance for j := 0; j < X.Cols; j++ { if j != X.ClassIndex { prediction += X.Get(i, j) * lr.regressionCoefficients[j] } } ret.Set(i, 0, prediction) } return ret, nil }