func main() { fmt.Printf("Loading MNIST dataset into memory...\n") train, test, err := mnist.Load("./") if err != nil { panic(fmt.Sprintf("Error loading training set!\n\t%v\n\n", err)) } fmt.Printf("MNIST dataset loaded!\n\t%v Training Examples\n\t%v Test Examples\n", len(train.Images), len(test.Images)) //online := onlineLearn(train) //eval(online, test, "Online Softmax") batch := batchLearn(train) eval(batch, test, "Batch Softmax") }
func main() { flag.Parse() set, _, err := GoMNIST.Load(*data) //_, set, err := GoMNIST.Load(*data) if err != nil { log.Fatal("无法载入数据") } log.Printf("#images = %d", len(set.Images)) for i := 0; i < len(set.Images); i++ { content := fmt.Sprintf("%d 1:0", set.Labels[i]) image := set.Images[i] for index, p := range image { if p != 0 { content = fmt.Sprintf("%s %d:%0.3f", content, index+1, float32(p)/255.) } } fmt.Printf("%s %d:0\n", content, len(image)) if i%1000 == 0 { log.Printf("已处理 %d 条记录", i) } } }
func main() { flag.Parse() if *cpuProfileFlag != "" { f, err := os.Create(*cpuProfileFlag) if err != nil { log.Fatal(err) } pprof.StartCPUProfile(f) defer pprof.StopCPUProfile() } rand.Seed(time.Now().UTC().UnixNano()) // Set up neural network. var neuralNetwork *neural.Network var trainingExamples []neural.Datapoint var testingExamples []neural.Datapoint if len(*mnistFlag) > 0 { train, test, err := GoMNIST.Load(*mnistFlag) if err != nil { log.Fatal(err) } for i := 0; i < train.Count(); i++ { var datapoint neural.Datapoint image, label := train.Get(i) datapoint.Values = append(datapoint.Values, float64(label)) for _, pixel := range image { datapoint.Features = append(datapoint.Features, float64(pixel)) } trainingExamples = append(trainingExamples, datapoint) } for i := 0; i < test.Count(); i++ { var datapoint neural.Datapoint image, label := test.Get(i) datapoint.Values = append(datapoint.Values, float64(label)) for _, pixel := range image { datapoint.Features = append(datapoint.Features, float64(pixel)) } testingExamples = append(testingExamples, datapoint) } } else { trainingExamples = ReadDatapointsOrDie(*trainingExamplesFlag) testingExamples = ReadDatapointsOrDie(*testingExamplesFlag) } fmt.Printf("Finished loading data!\n") byteNetwork, err := ioutil.ReadFile(*serializedNetworkFlag) if err != nil { log.Fatal(err) } neuralNetwork = new(neural.Network) neuralNetwork.Deserialize(byteNetwork) // If synapse weights aren't specified, randomize them. if neuralNetwork.Layers[0].Weight.At(0, 0) == 0 { neuralNetwork.RandomizeSynapses() } fmt.Printf("Finished creating the network!\n") // Train the model. learningConfiguration := neural.LearningConfiguration{ Epochs: proto.Int32(int32(*trainingIterationsFlag)), Rate: proto.Float64(*learningRateFlag), Decay: proto.Float64(*weightDecayFlag), BatchSize: proto.Int32(int32(*batchSizeFlag)), ErrorName: neural.ErrorName(neural.ErrorName_value[*errorNameFlag]).Enum(), } neural.Train(neuralNetwork, trainingExamples, learningConfiguration) // Test & output model: fmt.Printf("Training error: %v\nTesting error: %v\n", neural.Evaluate(*neuralNetwork, trainingExamples), neural.Evaluate(*neuralNetwork, testingExamples)) if len(*serializedNetworkOutFlag) > 0 { ioutil.WriteFile(*serializedNetworkOutFlag, neuralNetwork.Serialize(), 0777) } }