// printScore logs how well the network classifies the given data set,
// including a per-digit correctness histogram.
func printScore(prefix string, n neuralnet.Network, d mnist.DataSet) {
	classifier := func(v []float64) int {
		r := n.Apply(&autofunc.Variable{Vector: v})
		return networkOutput(r)
	}
	correctCount := d.NumCorrect(classifier)
	histogram := d.CorrectnessHistogram(classifier)
	log.Printf("%s: %d/%d - %s", prefix, correctCount, len(d.Samples), histogram)
}
// trainClassifier runs SGD on the classifier network until the user
// interrupts with ctrl+C, logging training and cross-validation scores
// along the way.
func trainClassifier(n neuralnet.Network, d mnist.DataSet) {
	log.Println("Training classifier (ctrl+C to finish)...")

	// Notify the user on the first interrupt; after signal.Stop, a second
	// ctrl+C terminates the program as usual.
	killChan := make(chan struct{})
	go func() {
		c := make(chan os.Signal, 1)
		signal.Notify(c, os.Interrupt)
		<-c
		signal.Stop(c)
		fmt.Println("\nCaught interrupt. Ctrl+C again to terminate.")
		close(killChan)
	}()

	// Convert the data set into input/output vector pairs for training.
	inputs := make([]linalg.Vector, len(d.Samples))
	outputs := make([]linalg.Vector, len(d.Samples))
	for i, x := range d.IntensityVectors() {
		inputs[i] = x
	}
	for i, x := range d.LabelVectors() {
		outputs[i] = x
	}
	samples := neuralnet.VectorSampleSet(inputs, outputs)

	batcher := &neuralnet.BatchRGradienter{
		Learner:  n.BatchLearner(),
		CostFunc: neuralnet.MeanSquaredCost{},
	}

	crossValidation := mnist.LoadTestingDataSet()
	sgd.SGDInteractive(batcher, samples, ClassifierStepSize, ClassifierBatchSize, func() bool {
		printScore("Training", n, d)
		printScore("Cross", n, crossValidation)
		return true
	})
}
// dataSetSamples converts an MNIST data set into an sgd.SampleSet of
// intensity/label vector pairs.
func dataSetSamples(d mnist.DataSet) sgd.SampleSet {
	labelVecs := d.LabelVectors()
	inputVecs := d.IntensityVectors()
	return neuralnet.VectorSampleSet(vecVec(inputVecs), vecVec(labelVecs))
}