コード例 #1
0
ファイル: main.go プロジェクト: unixpickle/weakai
func main() {
	// Trains the deep RBM stack built by buildLayers on binarized MNIST
	// training images, then writes a grid of reconstructed test images to
	// output.png.
	//
	// Training happens in two passes over the first 1000 binarized samples:
	// first with BigStepSize/BigEpochs, then with SmallStepSize/SmallEpochs,
	// reusing the same trainer and layers.
	trainingSet := mnist.LoadTrainingDataSet()

	// A pixel is "on" only when its intensity is strictly greater than 0.5.
	binarized := make([][]bool, len(trainingSet.Samples))
	for idx, digit := range trainingSet.Samples {
		bits := make([]bool, len(digit.Intensities))
		for px, intensity := range digit.Intensities {
			bits[px] = intensity > 0.5
		}
		binarized[idx] = bits
	}

	layers := buildLayers()
	trainer := rbm.Trainer{
		GibbsSteps: GibbsSteps,
		StepSize:   BigStepSize,
		Epochs:     BigEpochs,
		BatchSize:  runtime.GOMAXPROCS(0),
	}

	log.Println("Training...")
	trainSubset := binarized[:1000]
	trainer.TrainDeep(layers, trainSubset)

	// Second pass with the smaller step size and epoch count.
	trainer.StepSize, trainer.Epochs = SmallStepSize, SmallEpochs
	trainer.TrainDeep(layers, trainSubset)
	log.Println("Generating outputs...")

	testingSet := mnist.LoadTestingDataSet()

	reconstructFn := func(img []float64) []float64 {
		return reconstruct(layers, img)
	}
	mnist.SaveReconstructionGrid("output.png", reconstructFn, testingSet,
		ReconstructionGridSize, ReconstructionGridSize)
}
コード例 #2
0
ファイル: main.go プロジェクト: unixpickle/weakai
// trainingSamples loads the MNIST training set and wraps every image with
// newImageSample, returning the samples in the data set's original order.
func trainingSamples() []idtrees.Sample {
	dataSet := mnist.LoadTrainingDataSet()
	out := make([]idtrees.Sample, 0, len(dataSet.Samples))
	for _, img := range dataSet.Samples {
		out = append(out, newImageSample(img))
	}
	return out
}
コード例 #3
0
ファイル: main.go プロジェクト: unixpickle/weakai
func main() {
	// Trains the network from createNet on the MNIST training set using
	// batched gradients with a mean-squared cost, scaled by RMSProp and driven
	// by sgd.SGDInteractive. Whenever SGDInteractive invokes the callback, the
	// network's score on the MNIST testing set (labelled "Cross") is printed.
	// The callback always returns true, so stopping is left to SGDInteractive.
	trainSet := mnist.LoadTrainingDataSet()
	validationSet := mnist.LoadTestingDataSet()

	network := createNet(trainSet)

	samples := dataSetSamples(trainSet)
	optimizer := &sgd.RMSProp{
		Gradienter: &neuralnet.BatchRGradienter{
			Learner:  network.BatchLearner(),
			CostFunc: neuralnet.MeanSquaredCost{},
		},
	}

	onRound := func() bool {
		log.Println("Printing score...")
		printScore("Cross", network, validationSet)
		log.Println("Running training round...")
		return true // never asks SGDInteractive to stop
	}

	sgd.SGDInteractive(optimizer, samples, StepSize, BatchSize, onRound)
}
コード例 #4
0
ファイル: main.go プロジェクト: unixpickle/weakai
func main() {
	// Usage: mnist_classify <classifier_out.json>
	//
	// Builds a classifier via pretrainedClassifier from binarySamples of the
	// MNIST training images, trains it with trainClassifier on the full
	// training set, and writes classifier.Serialize()'s output to the path
	// given as the only argument. Every failure is reported on stderr and the
	// process exits with status 1.
	if len(os.Args) != 2 {
		fmt.Fprintln(os.Stderr, "Usage: mnist_classify <classifier_out.json>")
		os.Exit(1)
	}

	outputFile := os.Args[1]

	training := mnist.LoadTrainingDataSet()

	binSamples := binarySamples(training.Samples)
	classifier := pretrainedClassifier(binSamples)

	trainClassifier(classifier, training)

	// The serialization error must be checked. Writing whatever Serialize
	// returned on failure would leave a bad output file and still exit 0.
	data, err := classifier.Serialize()
	if err != nil {
		fmt.Fprintln(os.Stderr, "serialize classifier:", err)
		os.Exit(1)
	}

	// 0644: the output is a JSON data file, so it should not be executable.
	if err = ioutil.WriteFile(outputFile, data, 0644); err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
}