示例#1
0
文件: main.go 项目: unixpickle/weakai
// main loads the MNIST training set, thresholds every intensity at 0.5 to get
// boolean samples, and trains the layers returned by buildLayers on the first
// 1000 samples twice: first with BigStepSize/BigEpochs, then with
// SmallStepSize/SmallEpochs. It finishes by saving a grid of reconstructions
// of the testing set to output.png.
func main() {
	dataSet := mnist.LoadTrainingDataSet()

	// A pixel is "on" only when its intensity is strictly above 0.5.
	binary := make([][]bool, len(dataSet.Samples))
	for idx := range dataSet.Samples {
		pixels := dataSet.Samples[idx].Intensities
		row := make([]bool, len(pixels))
		for p, intensity := range pixels {
			row[p] = intensity > 0.5
		}
		binary[idx] = row
	}

	net := buildLayers()

	var trainer rbm.Trainer
	trainer.GibbsSteps = GibbsSteps
	trainer.StepSize = BigStepSize
	trainer.Epochs = BigEpochs
	trainer.BatchSize = runtime.GOMAXPROCS(0)

	log.Println("Training...")
	// Both passes use the same leading 1000 samples.
	subset := binary[:1000]
	trainer.TrainDeep(net, subset)

	// Second pass reuses the trainer with the small step size and epoch count.
	trainer.StepSize = SmallStepSize
	trainer.Epochs = SmallEpochs
	trainer.TrainDeep(net, subset)
	log.Println("Generating outputs...")

	testSet := mnist.LoadTestingDataSet()

	// rebuild runs one image through the trained layers via reconstruct.
	rebuild := func(img []float64) []float64 {
		return reconstruct(net, img)
	}
	mnist.SaveReconstructionGrid("output.png", rebuild, testSet,
		ReconstructionGridSize, ReconstructionGridSize)
}
示例#2
0
文件: main.go 项目: unixpickle/weakai
// main builds a forest from the training samples and attributes, classifies
// the MNIST testing data set with it, and logs the resulting correctness
// histogram.
func main() {
	log.Println("Creating training samples...")
	samples := trainingSamples()
	attrs := trainingAttrs()

	// growTree is the tree builder handed to BuildForest; it delegates to ID3
	// with a final argument of 0.
	growTree := func(s []idtrees.Sample, a []idtrees.Attr) *idtrees.Tree {
		return idtrees.ID3(s, a, 0)
	}

	log.Println("Training forest...")
	forest := idtrees.BuildForest(ForestSize, samples, attrs, TrainingSize, 75, growTree)

	// classify returns the class with the largest positive value in the
	// forest's output, or 0 when no value exceeds zero.
	classify := func(data []float64) int {
		votes := forest.Classify(newImageSample(mnist.Sample{Intensities: data}))

		var (
			bestScore float64
			bestClass int
		)
		for label, score := range votes {
			if score > bestScore {
				bestScore, bestClass = score, label.(int)
			}
		}
		return bestClass
	}

	log.Println("Running classifications...")
	testSet := mnist.LoadTestingDataSet()
	hist := testSet.CorrectnessHistogram(classify)
	log.Println("Results:", hist)
}
示例#3
0
文件: main.go 项目: unixpickle/weakai
// main creates a network from the MNIST training set and trains it with
// sgd.SGDInteractive, using an RMSProp wrapper around a batch gradienter with
// a mean-squared cost. The status callback prints the score on the testing
// set and always returns true.
func main() {
	trainSet := mnist.LoadTrainingDataSet()
	testSet := mnist.LoadTestingDataSet()

	network := createNet(trainSet)
	samples := dataSetSamples(trainSet)

	optimizer := &sgd.RMSProp{
		Gradienter: &neuralnet.BatchRGradienter{
			Learner:  network.BatchLearner(),
			CostFunc: neuralnet.MeanSquaredCost{},
		},
	}

	// report is the status callback for SGDInteractive.
	// NOTE(review): presumably invoked once per training round — confirm
	// against the sgd package.
	report := func() bool {
		log.Println("Printing score...")
		printScore("Cross", network, testSet)
		log.Println("Running training round...")
		return true
	}

	sgd.SGDInteractive(optimizer, samples, StepSize, BatchSize, report)
}
示例#4
0
文件: main.go 项目: unixpickle/weakai
// trainClassifier trains n on the samples in d with sgd.SGDInteractive using
// a mean-squared cost, printing the training and cross-validation scores from
// the status callback.
//
// The first Ctrl+C is caught by a helper goroutine, which restores default
// signal handling (so a second Ctrl+C terminates the process) and closes
// killChan. The status callback returns false once killChan is closed.
// NOTE(review): this assumes SGDInteractive stops when the callback returns
// false — confirm against the sgd package.
func trainClassifier(n neuralnet.Network, d mnist.DataSet) {
	log.Println("Training classifier (ctrl+C to finish)...")

	killChan := make(chan struct{})

	// done lets the signal goroutine exit, and unregister its handler, when
	// training returns without ever seeing an interrupt.
	done := make(chan struct{})
	defer close(done)

	go func() {
		c := make(chan os.Signal, 1)
		signal.Notify(c, os.Interrupt)
		select {
		case <-c:
		case <-done:
			signal.Stop(c)
			return
		}
		// Stop before printing so a second Ctrl+C gets default handling.
		signal.Stop(c)
		fmt.Println("\nCaught interrupt. Ctrl+C again to terminate.")
		close(killChan)
	}()

	// Convert the data set's intensity and label vectors element by element
	// into linalg.Vector slices.
	inputs := make([]linalg.Vector, len(d.Samples))
	outputs := make([]linalg.Vector, len(d.Samples))
	for i, x := range d.IntensityVectors() {
		inputs[i] = x
	}
	for i, x := range d.LabelVectors() {
		outputs[i] = x
	}
	samples := neuralnet.VectorSampleSet(inputs, outputs)
	batcher := &neuralnet.BatchRGradienter{
		Learner:  n.BatchLearner(),
		CostFunc: neuralnet.MeanSquaredCost{},
	}

	crossValidation := mnist.LoadTestingDataSet()

	sgd.SGDInteractive(batcher, samples, ClassifierStepSize,
		ClassifierBatchSize, func() bool {
			// Honor the interrupt: previously killChan was closed but never
			// read, so this callback could not request a stop.
			select {
			case <-killChan:
				return false
			default:
			}
			printScore("Training", n, d)
			printScore("Cross", n, crossValidation)
			return true
		})
}