예제 #1
0
// main trains a three-layer perceptron (mlp3) on the MNIST training set and
// dumps the learned parameters to the JSON file named by -o. When -test is
// set, it instead evaluates the model named by -m with Test and returns.
// Any I/O, training or dump failure is fatal.
func main() {
	test := flag.Bool("test", false, "whether tests neural network or not")
	modelFilename := flag.String("m", "nn.json", "model filename (*.json)")
	outFilename := flag.String("o", "nn.json", "output model filename (*.json)")
	learningRate := flag.Float64("learning_rate", 0.1, "Learning rate")
	epoches := flag.Int("epoch", 50000*10, "Epoches")
	numHiddenUnits := flag.Int("hidden_units", 100, "Number of hidden units")

	flag.Parse()
	if *test {
		Test(*modelFilename)
		return
	}

	trainingPath := "../data/train-images-idx3-ubyte"
	labelPath := "../data/train-labels-idx1-ubyte"

	file, err := os.Open(trainingPath)
	if err != nil {
		log.Fatal(err)
	}
	defer file.Close()

	images, w, h := mnist.ReadMNISTImages(file)
	fmt.Println(len(images), w, h, w*h)

	lfile, err := os.Open(labelPath)
	if err != nil {
		log.Fatal(err)
	}
	defer lfile.Close()
	labels := mnist.ReadMNISTLabels(lfile)

	// Convert image to data matrix
	data := mnist.NormalizePixel(mnist.PrepareX(images))
	target := mnist.PrepareY(labels)

	// Setup Neural Network: one input per pixel, 10 outputs.
	net := mlp3.NewNeuralNetwork(w*h, *numHiddenUnits, 10)
	option := mlp3.TrainingOption{
		LearningRate: *learningRate,
		Epoches:      *epoches, // the number of iterations in SGD
		Monitoring:   true,
	}

	// Perform training
	start := time.Now()
	if err := net.Train(data, target, option); err != nil {
		log.Fatal(err)
	}
	fmt.Println(time.Since(start))

	// Report the Dump error itself; previously the stale (nil) open error
	// was logged here, hiding the real failure.
	if err := net.Dump(*outFilename); err != nil {
		log.Fatal(err)
	}
	fmt.Println("Parameters are dumped to", *outFilename)
	fmt.Println("Training finished!")
}
예제 #2
0
// Test evaluates the mlp3 model stored in filename on the MNIST test set
// (../data/t10k-images-idx3-ubyte and ../data/t10k-labels-idx1-ubyte) and
// prints the classification accuracy as "Acc. <ratio> (<correct>/<total>)".
// Failing to load the model, open the data files, or obtain a prediction for
// every test label is fatal.
func Test(filename string) {
	net, err := mlp3.Load(filename)
	if err != nil {
		log.Fatal(err)
	}

	testPath := "../data/t10k-images-idx3-ubyte"
	targetPath := "../data/t10k-labels-idx1-ubyte"

	file, err := os.Open(testPath)
	if err != nil {
		log.Fatal(err)
	}
	defer file.Close()
	images, w, h := mnist.ReadMNISTImages(file)
	fmt.Println(len(images), w, h, w*h)

	lfile, err := os.Open(targetPath)
	if err != nil {
		log.Fatal(err)
	}
	defer lfile.Close()
	labels := mnist.ReadMNISTLabels(lfile)

	// Convert image to data matrix
	data := mnist.NormalizePixel(mnist.PrepareX(images))
	target := mnist.PrepareY(labels)

	result := nnet.Test(net, data)

	// Guard the accuracy computation: an empty result would yield NaN, and a
	// length mismatch would index target out of range (or skip samples).
	if len(result) == 0 {
		log.Fatal("no predictions produced; accuracy is undefined")
	}
	if len(result) != len(target) {
		log.Fatalf("prediction count %d does not match label count %d",
			len(result), len(target))
	}

	correct := 0
	for i := range result {
		// The expected class is the argmax of the target row
		// (presumably a one-hot encoding from PrepareY).
		if result[i] == nnet.Argmax(target[i]) {
			correct++
		}
	}
	fmt.Printf("Acc. %f (%d/%d)\n", float64(correct)/float64(len(result)),
		correct, len(result))
}
예제 #3
0
// main trains a restricted Boltzmann machine (rbm) on the MNIST training
// images, optionally down-sampled by the -down factor, and dumps the learned
// parameters to the JSON file named by -output. If -model is set, training
// starts from the parameters loaded from that file; otherwise a new RBM is
// created with one visible unit per (down-sampled) pixel. Any failure is fatal.
func main() {
	outFilename := flag.String("output", "nn.json", "Output filename (*.json)")
	modelFilename := flag.String("model", "", "Model filename (*.json)")
	learningRate := flag.Float64("learning_rate", 0.1, "Learning rate")
	epoches := flag.Int("epoch", 5, "Epoches")
	usePersistent := flag.Bool("persistent", false, "Persistent contrastive learning")
	orderOfGibbsSampling := flag.Int("order", 1, "Order of Gibbs sampling")
	orderOfDownSampling := flag.Int("down", 1, "Order of down sampling")
	miniBatchSize := flag.Int("size", 20, "Mini-batch size")
	l2 := flag.Bool("l2", false, "L2 regularization")
	numHiddenUnits := flag.Int("hidden_units", 100, "Number of hidden units")
	flag.Parse()

	// The down-sampling factor divides the image dimensions below; zero would
	// panic with an integer division by zero and negatives make no sense.
	if *orderOfDownSampling < 1 {
		log.Fatalf("invalid -down %d: order of down sampling must be >= 1", *orderOfDownSampling)
	}

	trainingPath := "../data/train-images-idx3-ubyte"
	file, err := os.Open(trainingPath)
	if err != nil {
		log.Fatal(err)
	}
	defer file.Close()
	images, w, h := mnist.ReadMNISTImages(file)

	// Convert image to data matrix
	data := mnist.PrepareX(images)
	data = mnist.NormalizePixel(mnist.DownSample(data, w, h, *orderOfDownSampling))

	// w and h with down sampled data.
	// NOTE(review): integer division truncates; this assumes DownSample emits
	// (w/order)*(h/order) pixels per image — confirm for non-divisible sizes.
	w, h = w/(*orderOfDownSampling), h/(*orderOfDownSampling)

	// Create RBM
	var r *rbm.RBM
	if *modelFilename != "" {
		r, err = rbm.Load(*modelFilename)
		if err != nil {
			log.Fatal(err)
		}
		fmt.Println("Load parameters from", *modelFilename)
	} else {
		numVisibleUnits := w * h
		r = rbm.New(numVisibleUnits, *numHiddenUnits)
	}

	// Training
	option := rbm.TrainingOption{
		LearningRate:         *learningRate,
		Epoches:              *epoches,
		OrderOfGibbsSampling: *orderOfGibbsSampling,
		UsePersistent:        *usePersistent,
		MiniBatchSize:        *miniBatchSize,
		L2Regularization:     *l2,
		RegularizationRate:   0.0001,
		Monitoring:           true,
	}

	fmt.Println("Start training")
	start := time.Now()
	if err := r.Train(data, option); err != nil {
		log.Fatal(err)
	}
	fmt.Println("Elapsed:", time.Since(start))

	if err := r.Dump(*outFilename); err != nil {
		log.Fatal(err)
	}
	fmt.Println("Parameters are dumped to", *outFilename)
	fmt.Println("Training finished.")
}