func main() { test := flag.Bool("test", false, "whether tests neural network or not") modelFilename := flag.String("m", "nn.json", "model filename (*.json)") outFilename := flag.String("o", "nn.json", "outtput model filename (*.json)") learningRate := flag.Float64("learning_rate", 0.1, "Learning rate") epoches := flag.Int("epoch", 50000*10, "Epoches") numHiddenUnits := flag.Int("hidden_units", 100, "Number of hidden units") flag.Parse() if *test == true { Test(*modelFilename) return } trainingPath := "../data/train-images-idx3-ubyte" labelPath := "../data/train-labels-idx1-ubyte" file, err := os.Open(trainingPath) if err != nil { log.Fatal(err) } images, w, h := mnist.ReadMNISTImages(file) fmt.Println(len(images), w, h, w*h) lfile, lerr := os.Open(labelPath) if lerr != nil { log.Fatal(lerr) } labels := mnist.ReadMNISTLabels(lfile) // Convert image to data matrix data := mnist.NormalizePixel(mnist.PrepareX(images)) target := mnist.PrepareY(labels) // Setup Neural Network net := mlp3.NewNeuralNetwork(w*h, *numHiddenUnits, 10) option := mlp3.TrainingOption{ LearningRate: *learningRate, Epoches: *epoches, // the number of iterations in SGD Monitoring: true, } // Perform training start := time.Now() nerr := net.Train(data, target, option) if nerr != nil { log.Fatal(nerr) } elapsed := time.Now().Sub(start) fmt.Println(elapsed) oerr := net.Dump(*outFilename) if oerr != nil { log.Fatal(err) } fmt.Println("Parameters are dummped to", *outFilename) fmt.Println("Training finished!") }
// Classification test using MNIST dataset. func Test(filename string) { net, err := mlp3.Load(filename) if err != nil { log.Fatal(err) } testPath := "../data/t10k-images-idx3-ubyte" targetPath := "../data/t10k-labels-idx1-ubyte" file, err := os.Open(testPath) if err != nil { log.Fatal(err) } images, w, h := mnist.ReadMNISTImages(file) fmt.Println(len(images), w, h, w*h) lfile, lerr := os.Open(targetPath) if lerr != nil { log.Fatal(lerr) } labels := mnist.ReadMNISTLabels(lfile) // Convert image to data matrix data := mnist.NormalizePixel(mnist.PrepareX(images)) target := mnist.PrepareY(labels) result := nnet.Test(net, data) sum := 0.0 for i := range result { if result[i] == nnet.Argmax(target[i]) { sum += 1.0 } } fmt.Printf("Acc. %f (%d/%d)\n", sum/float64(len(result)), int(sum), len(result)) }
func main() { outFilename := flag.String("output", "nn.json", "Output filename (*.json)") modelFilename := flag.String("model", "", "Model filename (*.json)") learningRate := flag.Float64("learning_rate", 0.1, "Learning rate") epoches := flag.Int("epoch", 5, "Epoches") usePersistent := flag.Bool("persistent", false, "Persistent constrastive learning") orderOfGibbsSampling := flag.Int("order", 1, "Order of Gibbs sampling") orderOfDownSampling := flag.Int("down", 1, "Order of down sampling") miniBatchSize := flag.Int("size", 20, "Mini-batch size") l2 := flag.Bool("l2", false, "L2 regularization") numHiddenUnits := flag.Int("hidden_units", 100, "Number of hidden units") flag.Parse() trainingPath := "../data/train-images-idx3-ubyte" file, err := os.Open(trainingPath) if err != nil { log.Fatal(err) } images, w, h := mnist.ReadMNISTImages(file) // Convert image to data matrix data := mnist.PrepareX(images) data = mnist.NormalizePixel(mnist.DownSample(data, w, h, *orderOfDownSampling)) // w and h with down sampled data w, h = w/(*orderOfDownSampling), h/(*orderOfDownSampling) // Create RBM var r *rbm.RBM if *modelFilename != "" { r, err = rbm.Load(*modelFilename) if err != nil { log.Fatal(err) } fmt.Println("Load parameters from", *modelFilename) } else { numVisibleUnits := w * h r = rbm.New(numVisibleUnits, *numHiddenUnits) } // Training option := rbm.TrainingOption{ LearningRate: *learningRate, Epoches: *epoches, OrderOfGibbsSampling: *orderOfGibbsSampling, UsePersistent: *usePersistent, MiniBatchSize: *miniBatchSize, L2Regularization: *l2, RegularizationRate: 0.0001, Monitoring: true, } fmt.Println("Start training") start := time.Now() terr := r.Train(data, option) if terr != nil { log.Fatal(terr) } fmt.Println("Elapsed:", time.Now().Sub(start)) oerr := r.Dump(*outFilename) if oerr != nil { log.Fatal(oerr) } fmt.Println("Parameters are dumped to", *outFilename) fmt.Println("Training finished.") }