// param base.IFixedDataGrid // return base.IFixedDataGrid func (p *AveragePerceptron) Predict(what base.FixedDataGrid) base.FixedDataGrid { if !p.trained { panic("Cannot call Predict on an untrained AveragePerceptron") } data := processData(what) allAttrs := base.CheckCompatible(what, p.TrainingData) if allAttrs == nil { // Don't have the same Attributes return nil } // Remove the Attributes which aren't numeric allNumericAttrs := make([]base.Attribute, 0) for _, a := range allAttrs { if fAttr, ok := a.(*base.FloatAttribute); ok { allNumericAttrs = append(allNumericAttrs, fAttr) } } ret := base.GeneratePredictionVector(what) classAttr := ret.AllClassAttributes()[0] classSpec, err := ret.GetAttribute(classAttr) if err != nil { panic(err) } for i, datum := range data { result := p.score(datum) if result > 0.0 { ret.Set(classSpec, i, base.PackU64ToBytes(1)) } else { ret.Set(classSpec, 1, []byte{0, 0, 0, 0, 0, 0, 0, 0}) } } return ret }
// Predict returns a classification for the vector, based on a vector input, using the KNN algorithm. func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid { // Check what distance function we are using var distanceFunc pairwise.PairwiseDistanceFunc switch KNN.DistanceFunc { case "euclidean": distanceFunc = pairwise.NewEuclidean() case "manhattan": distanceFunc = pairwise.NewManhattan() default: panic("unsupported distance function") } // Check Compatibility allAttrs := base.CheckCompatible(what, KNN.TrainingData) if allAttrs == nil { // Don't have the same Attributes return nil } // Remove the Attributes which aren't numeric allNumericAttrs := make([]base.Attribute, 0) for _, a := range allAttrs { if fAttr, ok := a.(*base.FloatAttribute); ok { allNumericAttrs = append(allNumericAttrs, fAttr) } } // Generate return vector ret := base.GeneratePredictionVector(what) // Resolve Attribute specifications for both whatAttrSpecs := base.ResolveAttributes(what, allNumericAttrs) trainAttrSpecs := base.ResolveAttributes(KNN.TrainingData, allNumericAttrs) // Reserve storage for most the most similar items distances := make(map[int]float64) // Reserve storage for voting map maxmap := make(map[string]int) // Reserve storage for row computations trainRowBuf := make([]float64, len(allNumericAttrs)) predRowBuf := make([]float64, len(allNumericAttrs)) // Iterate over all outer rows what.MapOverRows(whatAttrSpecs, func(predRow [][]byte, predRowNo int) (bool, error) { // Read the float values out for i, _ := range allNumericAttrs { predRowBuf[i] = base.UnpackBytesToFloat(predRow[i]) } predMat := utilities.FloatsToMatrix(predRowBuf) // Find the closest match in the training data KNN.TrainingData.MapOverRows(trainAttrSpecs, func(trainRow [][]byte, srcRowNo int) (bool, error) { // Read the float values out for i, _ := range allNumericAttrs { trainRowBuf[i] = base.UnpackBytesToFloat(trainRow[i]) } // Compute the distance trainMat := utilities.FloatsToMatrix(trainRowBuf) distances[srcRowNo] = distanceFunc.Distance(predMat, trainMat) return true, nil }) sorted := utilities.SortIntMap(distances) values := sorted[:KNN.NearestNeighbours] // Reset maxMap for a := range maxmap { maxmap[a] = 0 } // Refresh maxMap for _, elem := range values { label := base.GetClass(KNN.TrainingData, elem) if _, ok := maxmap[label]; ok { maxmap[label]++ } else { maxmap[label] = 1 } } // Sort the maxMap var maxClass string maxVal := -1 for a := range maxmap { if maxmap[a] > maxVal { maxVal = maxmap[a] maxClass = a } } base.SetClass(ret, predRowNo, maxClass) return true, nil }) return ret }
// Predict returns a classification for the vector, based on a vector input, using the KNN algorithm. func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid { // Check what distance function we are using var distanceFunc pairwise.PairwiseDistanceFunc switch KNN.DistanceFunc { case "euclidean": distanceFunc = pairwise.NewEuclidean() case "manhattan": distanceFunc = pairwise.NewManhattan() default: panic("unsupported distance function") } // Check Compatibility allAttrs := base.CheckCompatible(what, KNN.TrainingData) if allAttrs == nil { // Don't have the same Attributes return nil } // Use optimised version if permitted if KNN.AllowOptimisations { if KNN.DistanceFunc == "euclidean" { if KNN.canUseOptimisations(what) { return KNN.optimisedEuclideanPredict(what.(*base.DenseInstances)) } } } fmt.Println("Optimisations are switched off") // Remove the Attributes which aren't numeric allNumericAttrs := make([]base.Attribute, 0) for _, a := range allAttrs { if fAttr, ok := a.(*base.FloatAttribute); ok { allNumericAttrs = append(allNumericAttrs, fAttr) } } // Generate return vector ret := base.GeneratePredictionVector(what) // Resolve Attribute specifications for both whatAttrSpecs := base.ResolveAttributes(what, allNumericAttrs) trainAttrSpecs := base.ResolveAttributes(KNN.TrainingData, allNumericAttrs) // Reserve storage for most the most similar items distances := make(map[int]float64) // Reserve storage for voting map maxmap := make(map[string]int) // Reserve storage for row computations trainRowBuf := make([]float64, len(allNumericAttrs)) predRowBuf := make([]float64, len(allNumericAttrs)) _, maxRow := what.Size() curRow := 0 // Iterate over all outer rows what.MapOverRows(whatAttrSpecs, func(predRow [][]byte, predRowNo int) (bool, error) { if (curRow%1) == 0 && curRow > 0 { fmt.Printf("KNN: %.2f %% done\n", float64(curRow)*100.0/float64(maxRow)) } curRow++ // Read the float values out for i, _ := range allNumericAttrs { predRowBuf[i] = base.UnpackBytesToFloat(predRow[i]) } predMat := utilities.FloatsToMatrix(predRowBuf) // Find the closest match in the training data KNN.TrainingData.MapOverRows(trainAttrSpecs, func(trainRow [][]byte, srcRowNo int) (bool, error) { // Read the float values out for i, _ := range allNumericAttrs { trainRowBuf[i] = base.UnpackBytesToFloat(trainRow[i]) } // Compute the distance trainMat := utilities.FloatsToMatrix(trainRowBuf) distances[srcRowNo] = distanceFunc.Distance(predMat, trainMat) return true, nil }) sorted := utilities.SortIntMap(distances) values := sorted[:KNN.NearestNeighbours] maxClass := KNN.vote(maxmap, values) base.SetClass(ret, predRowNo, maxClass) return true, nil }) return ret }