コード例 #1
0
ファイル: stream_test.go プロジェクト: wilseypa/rphash-golang
// TestStreamingKMeansOnNumImagesData clusters the MNIST sample file with the
// streaming k-means clusterer and logs fit quality and runtime.
//
// Every row of ../demo/data/MNISTnumImages5000.txt is fed to a
// clusterer.NewKMeansStream configured for numClusters clusters. The test then
// logs the total and average squared distance from each row to its nearest
// returned centroid, plus the wall-clock time spent clustering. It fails if
// the stream does not return exactly numClusters centroids.
func TestStreamingKMeansOnNumImagesData(t *testing.T) {
	numClusters := 10
	lines, err := utils.ReadLines("../demo/data/MNISTnumImages5000.txt")
	if err != nil {
		// t.Fatalf fails only this test; panic would abort the whole binary.
		t.Fatalf("reading MNIST data: %v", err)
	}
	if len(lines) == 0 {
		// lines[0] below would otherwise panic with index out of range.
		t.Fatal("MNIST data file contains no rows")
	}
	// The row length of the first record sizes the clusterer.
	dimensionality := len(lines[0])
	data := utils.StringArrayToFloatArray(lines)

	start := time.Now()
	kmeansStream := clusterer.NewKMeansStream(numClusters, 10, dimensionality)
	for _, vector := range data {
		kmeansStream.AddDataPoint(vector)
	}
	result := kmeansStream.GetCentroids()
	// Named elapsed (not time) so the time package is not shadowed.
	elapsed := time.Since(start)

	totalSqDist := float64(0)
	for _, vector := range data {
		_, dist := utils.FindNearestDistance(vector, result)
		totalSqDist += dist * dist
	}

	t.Log("Total Square Distance: ", totalSqDist)
	t.Log("Average Square Distance: ", totalSqDist/float64(len(data)))
	t.Log("Runtime(seconds): ", elapsed.Seconds())

	if len(result) != numClusters {
		t.Errorf("streaming k-means returned %d clusters, want %d", len(result), numClusters)
	}
}
コード例 #2
0
ファイル: simple_test.go プロジェクト: wilseypa/rphash-golang
// TestRPHashSimpleOnNumImagesData clusters the MNIST sample file with the
// RPHash "simple" algorithm and logs fit quality and runtime.
//
// The whole data set is wrapped in a reader.NewSimpleArray for numClusters
// clusters and run through simple.NewSimple. Centroids are then read back
// from the RPHash object. The test logs the total and average squared
// distance from each row to its nearest centroid and the time spent
// clustering. It fails if the result does not hold exactly numClusters
// centroids.
func TestRPHashSimpleOnNumImagesData(t *testing.T) {
	numClusters := 10
	lines, err := utils.ReadLines("../demo/data/MNISTnumImages5000.txt")
	if err != nil {
		// t.Fatalf fails only this test; panic would abort the whole binary.
		t.Fatalf("reading MNIST data: %v", err)
	}
	data := utils.StringArrayToFloatArray(lines)

	start := time.Now()
	rphashObject := reader.NewSimpleArray(data, numClusters)
	simpleObject := simple.NewSimple(rphashObject)
	simpleObject.Run()
	// Run stores its centroids on the RPHash object, not on simpleObject.
	result := rphashObject.GetCentroids()
	// Named elapsed (not time) so the time package is not shadowed.
	elapsed := time.Since(start)

	totalSqDist := float64(0)
	for _, vector := range data {
		_, dist := utils.FindNearestDistance(vector, result)
		totalSqDist += dist * dist
	}

	t.Log("Total Square Distance: ", totalSqDist)
	t.Log("Average Square Distance: ", totalSqDist/float64(len(data)))
	t.Log("Runtime(seconds): ", elapsed.Seconds())

	if len(result) != numClusters {
		t.Errorf("RPHash Simple returned %d clusters, want %d", len(result), numClusters)
	}
}
コード例 #3
0
ファイル: main.go プロジェクト: wilseypa/rphash-golang
func main() {
	var rphashObject *reader.StreamObject
	var rphashStream *stream.Stream
	var centroids []types.Centroid
	t1 := time.Now()
	// Split the data into shards and send them to the Agents to work on.
	f.Source(func(out chan Vector) {
		records, err := utils.ReadLines(dataFilePath)
		if err != nil {
			panic(err)
		}
		// Convert the record to standard floating points.
		for i, record := range records {
			if i == 0 {
				// Create a new RPHash stream.
				rphashObject = reader.NewStreamObject(len(record), numClusters)
				rphashStream = stream.NewStream(rphashObject)
				rphashStream.RunCount = 1
			}
			data := make([]float64, len(record))
			for j, entry := range record {
				f, err := strconv.ParseFloat(entry, 64)
				f = parse.Normalize(f)
				if err != nil {
					panic(err)
				}
				data[j] = f
			}
			out <- Vector{Data: data}
		}
	}, numShards).Map(func(vec Vector) {
		centroids = append(centroids, rphashStream.AddVectorOnlineStep(vec.Data))
	}).Run()

	for _, cent := range centroids {
		rphashStream.CentroidCounter.Add(cent)
	}
	normalizedResults := rphashStream.GetCentroids()
	t2 := time.Now()
	log.Println("Time: ", t2.Sub(t1))

	denormalizedResults := make([][]float64, len(normalizedResults))
	for i, result := range normalizedResults {
		row := make([]float64, len(result))
		for j, dimension := range result {
			row[j] = parse.DeNormalize(dimension)
		}
		denormalizedResults[i] = row
	}
	labels := make([]string, len(denormalizedResults))
	xPlotValues := make([][]float64, len(denormalizedResults))
	yPlotValues := make([][]float64, len(denormalizedResults))
	for i, result := range denormalizedResults {
		xPlotValues[i] = make([]float64, len(result))
		yPlotValues[i] = make([]float64, len(result))
		for j, val := range result {
			xPlotValues[i][j] = float64(j)
			yPlotValues[i][j] = val
		}
		Paint(result, i)
		sI := strconv.FormatInt(int64(i), 16)
		labels[i] = "Digit " + sI + " (by Classifier Centroid)"
	}
	GeneratePlots(xPlotValues, yPlotValues, "High Dimension Handwritting Digits 0-9 Classification", "Dimension", "Strength of Visual Pixel Recognition (0-1000)", "plots/centroid-dimensions-", labels)
}