示例#1
0
文件: knn.go 项目: npbool/golearn
// Predict returns a classification for vector, based on the labels of its K
// nearest training rows, using the KNN algorithm.
// See http://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm.
//
// The distance from vector to every row of KNN.Data is measured with the
// metric named by KNN.DistanceFunc ("euclidean" or "manhattan"). Each of the K
// closest rows casts one vote for its entry in KNN.Labels. The result is the
// first label returned by util.SortStringMap over the vote counts.
//
// K is capped at the number of training rows. If K <= 0 or there are no
// rows, no vote takes place and "" is returned. Predict panics if
// KNN.DistanceFunc names an unsupported metric.
func (KNN *KNNClassifier) Predict(vector []float64, K int) string {
	convertedVector := util.FloatsToMatrix(vector)
	rows, _ := KNN.Data.Dims()

	// Distance from vector to each training row, keyed by row index.
	distances := make(map[int]float64, rows)
	switch KNN.DistanceFunc {
	case "euclidean":
		euclidean := pairwiseMetrics.NewEuclidean()
		for i := 0; i < rows; i++ {
			rowMat := util.FloatsToMatrix(KNN.Data.RowView(i))
			distances[i] = euclidean.Distance(rowMat, convertedVector)
		}
	case "manhattan":
		manhattan := pairwiseMetrics.NewManhattan()
		for i := 0; i < rows; i++ {
			rowMat := util.FloatsToMatrix(KNN.Data.RowView(i))
			distances[i] = manhattan.Distance(rowMat, convertedVector)
		}
	default:
		// Fail loudly instead of voting over an empty distance table.
		panic("knn: unsupported distance function \"" + KNN.DistanceFunc + "\"")
	}

	// NOTE(review): assumes util.SortIntMap returns the row indices ordered by
	// ascending distance (nearest first) — confirm against util.
	sorted := util.SortIntMap(distances)

	// Never ask for more neighbours than there are rows; slicing past the end
	// would panic.
	if K > len(sorted) {
		K = len(sorted)
	}
	if K <= 0 {
		return ""
	}

	// Count one vote per neighbour. A missing key reads as 0, so ++ is enough.
	votes := make(map[string]int)
	for _, idx := range sorted[:K] {
		votes[KNN.Labels[idx]]++
	}

	// NOTE(review): assumes util.SortStringMap orders labels by descending
	// vote count, so index 0 is the majority label — confirm against util.
	return util.SortStringMap(votes)[0]
}
示例#2
0
文件: knn.go 项目: 24hours/golearn
// PredictOne returns a classification for vector, based on the classes of its
// KNN.NearestNeighbours nearest training rows, using the KNN algorithm.
// See http://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm.
//
// The distance from vector to the attribute values of every row of
// KNN.TrainingData (excluding the class attribute) is measured with the metric
// named by KNN.DistanceFunc ("euclidean" or "manhattan"). Each of the closest
// rows casts one vote for its class. The result is the first label returned
// by util.SortStringMap over the vote counts.
//
// The neighbour count is capped at the number of training rows. If it is
// <= 0 or there are no rows, no vote takes place and "" is returned.
// PredictOne panics if KNN.DistanceFunc names an unsupported metric.
func (KNN *KNNClassifier) PredictOne(vector []float64) string {
	rows := KNN.TrainingData.Rows
	convertedVector := util.FloatsToMatrix(vector)

	// Distance from vector to each training row, keyed by row index.
	distances := make(map[int]float64, rows)
	switch KNN.DistanceFunc {
	case "euclidean":
		euclidean := pairwiseMetrics.NewEuclidean()
		for i := 0; i < rows; i++ {
			rowMat := util.FloatsToMatrix(KNN.TrainingData.GetRowVectorWithoutClass(i))
			distances[i] = euclidean.Distance(rowMat, convertedVector)
		}
	case "manhattan":
		manhattan := pairwiseMetrics.NewManhattan()
		for i := 0; i < rows; i++ {
			rowMat := util.FloatsToMatrix(KNN.TrainingData.GetRowVectorWithoutClass(i))
			distances[i] = manhattan.Distance(rowMat, convertedVector)
		}
	default:
		// Fail loudly instead of voting over an empty distance table.
		panic("knn: unsupported distance function \"" + KNN.DistanceFunc + "\"")
	}

	// NOTE(review): assumes util.SortIntMap returns the row indices ordered by
	// ascending distance (nearest first) — confirm against util.
	sorted := util.SortIntMap(distances)

	// Never ask for more neighbours than there are rows; slicing past the end
	// would panic.
	k := KNN.NearestNeighbours
	if k > len(sorted) {
		k = len(sorted)
	}
	if k <= 0 {
		return ""
	}

	// Count one vote per neighbour. A missing key reads as 0, so ++ is enough.
	votes := make(map[string]int)
	for _, idx := range sorted[:k] {
		votes[KNN.TrainingData.GetClass(idx)]++
	}

	// NOTE(review): assumes util.SortStringMap orders labels by descending
	// vote count, so index 0 is the majority label — confirm against util.
	return util.SortStringMap(votes)[0]
}