コード例 #1
0
// main loads the MNIST dataset from the current directory ("./"), trains a
// batch softmax model on the training split and evaluates it on the test
// split under the label "Batch Softmax".
func main() {

	fmt.Printf("Loading MNIST dataset into memory...\n")

	// mnist.Load returns both splits at once, so a failure may come from
	// either the training or the test files; the message must not blame
	// only the training set.
	train, test, err := mnist.Load("./")
	if err != nil {
		panic(fmt.Sprintf("Error loading MNIST dataset!\n\t%v\n\n", err))
	}

	fmt.Printf("MNIST dataset loaded!\n\t%v Training Examples\n\t%v Test Examples\n", len(train.Images), len(test.Images))

	// Online (per-example) training is currently disabled; uncomment these
	// two lines to evaluate it alongside the batch learner.
	//online := onlineLearn(train)

	//eval(online, test, "Online Softmax")

	batch := batchLearn(train)

	eval(batch, test, "Batch Softmax")
}
コード例 #2
0
ファイル: mnist.go プロジェクト: sguzwf/mlf
// main converts one GoMNIST split into a sparse, libsvm-like text format on
// stdout. Each image becomes one line of the form
//
//	<label> 1:0 <i>:<v> ... <N>:0
//
// where <i> is the 1-based pixel index, <v> is the pixel divided by 255
// printed with three decimals, zero pixels are omitted, and N is the number
// of pixels in the image. Progress is logged every 1000 records.
//
// NOTE(review): the leading "1:0" and trailing "<N>:0" entries collide with
// pixel indices 1 and N; if the first or last pixel were non-zero that index
// would appear twice. Presumably harmless for MNIST (border pixels are 0) —
// confirm before reusing on other data.
func main() {
	flag.Parse()
	// GoMNIST.Load returns (train, test, err); the training split is exported.
	// Swap in the commented line to export the test split instead.
	set, _, err := GoMNIST.Load(*data)
	//_, set, err := GoMNIST.Load(*data)
	if err != nil {
		// Message means "unable to load data"; include the underlying cause.
		log.Fatalf("无法载入数据: %v", err)
	}
	log.Printf("#images = %d", len(set.Images))

	// line is reused across images. Appending each feature keeps the cost
	// linear in the line length, instead of re-formatting the whole prefix
	// for every pixel.
	var line []byte
	for i, image := range set.Images {
		line = append(line[:0], fmt.Sprintf("%d 1:0", set.Labels[i])...)
		for index, p := range image {
			if p != 0 {
				line = append(line, fmt.Sprintf(" %d:%0.3f", index+1, float32(p)/255.)...)
			}
		}
		fmt.Printf("%s %d:0\n", line, len(image))
		if i%1000 == 0 {
			// Message means "processed %d records".
			log.Printf("已处理 %d 条记录", i)
		}
	}
}
コード例 #3
0
ファイル: cmdline.go プロジェクト: evilrobot69/NeuralGo
// main trains a neural network configured by command-line flags and prints
// its error on the training and testing examples.
//
// Examples come either from a GoMNIST directory (when the MNIST flag is
// non-empty) or from the datapoint files named by the training/testing
// example flags. The network is deserialized from the serialized-network
// file; if its first weight is zero the synapses are randomized. After
// training, the network is optionally serialized to the output flag's path.
// An optional CPU profile is written when the CPU-profile flag is set.
func main() {
	flag.Parse()
	if *cpuProfileFlag != "" {
		f, err := os.Create(*cpuProfileFlag)
		if err != nil {
			log.Fatal(err)
		}
		defer f.Close()
		if err := pprof.StartCPUProfile(f); err != nil {
			log.Fatal(err)
		}
		// Deferred calls run LIFO: the profile is flushed before f is closed.
		// log.Fatal exits without running defers, so fatal paths below lose
		// the profile.
		defer pprof.StopCPUProfile()
	}

	rand.Seed(time.Now().UTC().UnixNano())

	// Load the examples.
	var trainingExamples []neural.Datapoint
	var testingExamples []neural.Datapoint
	if len(*mnistFlag) > 0 {
		train, test, err := GoMNIST.Load(*mnistFlag)
		if err != nil {
			log.Fatal(err)
		}
		// Each image becomes one datapoint: the label is the single target
		// value and every pixel is one feature.
		trainingExamples = make([]neural.Datapoint, 0, train.Count())
		for i := 0; i < train.Count(); i++ {
			var datapoint neural.Datapoint
			image, label := train.Get(i)
			datapoint.Values = append(datapoint.Values, float64(label))
			datapoint.Features = make([]float64, 0, len(image))
			for _, pixel := range image {
				datapoint.Features = append(datapoint.Features, float64(pixel))
			}
			trainingExamples = append(trainingExamples, datapoint)
		}
		testingExamples = make([]neural.Datapoint, 0, test.Count())
		for i := 0; i < test.Count(); i++ {
			var datapoint neural.Datapoint
			image, label := test.Get(i)
			datapoint.Values = append(datapoint.Values, float64(label))
			datapoint.Features = make([]float64, 0, len(image))
			for _, pixel := range image {
				datapoint.Features = append(datapoint.Features, float64(pixel))
			}
			testingExamples = append(testingExamples, datapoint)
		}
	} else {
		trainingExamples = ReadDatapointsOrDie(*trainingExamplesFlag)
		testingExamples = ReadDatapointsOrDie(*testingExamplesFlag)
	}
	fmt.Printf("Finished loading data!\n")

	// Set up the neural network from its serialized description.
	byteNetwork, err := ioutil.ReadFile(*serializedNetworkFlag)
	if err != nil {
		log.Fatal(err)
	}
	neuralNetwork := new(neural.Network)
	// NOTE(review): any result of Deserialize is not checked here — confirm
	// whether it reports malformed input.
	neuralNetwork.Deserialize(byteNetwork)
	// Fail clearly instead of with an index-out-of-range panic below.
	if len(neuralNetwork.Layers) == 0 {
		log.Fatal("serialized network has no layers")
	}
	// If synapse weights aren't specified (first weight is zero), randomize them.
	if neuralNetwork.Layers[0].Weight.At(0, 0) == 0 {
		neuralNetwork.RandomizeSynapses()
	}
	fmt.Printf("Finished creating the network!\n")

	// Train the model. An unknown error name falls back to enum value 0, as
	// before, but no longer silently.
	errorName, ok := neural.ErrorName_value[*errorNameFlag]
	if !ok {
		log.Printf("unknown error name %q; falling back to %v", *errorNameFlag, neural.ErrorName(errorName))
	}
	learningConfiguration := neural.LearningConfiguration{
		Epochs:    proto.Int32(int32(*trainingIterationsFlag)),
		Rate:      proto.Float64(*learningRateFlag),
		Decay:     proto.Float64(*weightDecayFlag),
		BatchSize: proto.Int32(int32(*batchSizeFlag)),
		ErrorName: neural.ErrorName(errorName).Enum(),
	}
	neural.Train(neuralNetwork, trainingExamples, learningConfiguration)

	// Test & output model.
	fmt.Printf("Training error: %v\nTesting error: %v\n",
		neural.Evaluate(*neuralNetwork, trainingExamples),
		neural.Evaluate(*neuralNetwork, testingExamples))
	if len(*serializedNetworkOutFlag) > 0 {
		// The serialized network is a data file: readable, not executable
		// or world-writable.
		if err := ioutil.WriteFile(*serializedNetworkOutFlag, neuralNetwork.Serialize(), 0644); err != nil {
			log.Fatal(err)
		}
	}
}