func TrainCmd(netPath, dirPath string) { log.Println("Loading samples...") images, width, height, err := LoadTrainingImages(dirPath) if err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } log.Println("Creating network...") var network neuralnet.Network networkData, err := ioutil.ReadFile(netPath) if err == nil { network, err = neuralnet.DeserializeNetwork(networkData) if err != nil { fmt.Fprintln(os.Stderr, "Failed to load network:", err) os.Exit(1) } log.Println("Loaded network from file.") } else { mean, stddev := sampleStatistics(images) convLayer := &neuralnet.ConvLayer{ FilterCount: FilterCount, FilterWidth: 4, FilterHeight: 4, Stride: 2, InputWidth: width, InputHeight: height, InputDepth: ImageDepth, } maxLayer := &neuralnet.MaxPoolingLayer{ XSpan: 3, YSpan: 3, InputWidth: convLayer.OutputWidth(), InputHeight: convLayer.OutputHeight(), InputDepth: convLayer.OutputDepth(), } convLayer1 := &neuralnet.ConvLayer{ FilterCount: FilterCount1, FilterWidth: 3, FilterHeight: 3, Stride: 2, InputWidth: maxLayer.OutputWidth(), InputHeight: maxLayer.OutputHeight(), InputDepth: maxLayer.InputDepth, } network = neuralnet.Network{ &neuralnet.RescaleLayer{ Bias: -mean, Scale: 1 / stddev, }, convLayer, neuralnet.HyperbolicTangent{}, maxLayer, neuralnet.HyperbolicTangent{}, convLayer1, neuralnet.HyperbolicTangent{}, &neuralnet.DenseLayer{ InputCount: convLayer1.OutputWidth() * convLayer1.OutputHeight() * convLayer1.OutputDepth(), OutputCount: HiddenSize, }, neuralnet.HyperbolicTangent{}, &neuralnet.DenseLayer{ InputCount: HiddenSize, OutputCount: len(images), }, &neuralnet.LogSoftmaxLayer{}, } network.Randomize() log.Println("Created new network.") } samples := neuralSamples(images) sgd.ShuffleSampleSet(samples) validationCount := int(ValidationFraction * float64(samples.Len())) validationSamples := samples.Subset(0, validationCount) trainingSamples := samples.Subset(validationCount, samples.Len()) costFunc := neuralnet.DotCost{} gradienter := &sgd.Adam{ Gradienter: &neuralnet.BatchRGradienter{ Learner: network.BatchLearner(), CostFunc: &neuralnet.RegularizingCost{ Variables: network.Parameters(), Penalty: Regularization, CostFunc: costFunc, }, }, } sgd.SGDInteractive(gradienter, trainingSamples, StepSize, BatchSize, func() bool { log.Printf("Costs: validation=%d/%d cost=%f", countCorrect(network, validationSamples), validationSamples.Len(), neuralnet.TotalCost(costFunc, network, trainingSamples)) return true }) data, _ := network.Serialize() if err := ioutil.WriteFile(netPath, data, 0755); err != nil { fmt.Fprintln(os.Stderr, "Failed to save:", err) os.Exit(1) } }
func Train(rnnFile, sampleDir string, stepSize float64) { log.Println("Loading samples...") samples, err := ReadSamples(sampleDir) if err != nil { fmt.Fprintln(os.Stderr, "Failed to read samples:", err) os.Exit(1) } var seqFunc *rnn.Bidirectional rnnData, err := ioutil.ReadFile(rnnFile) if err == nil { log.Println("Loaded network from file.") seqFunc, err = rnn.DeserializeBidirectional(rnnData) if err != nil { fmt.Fprintln(os.Stderr, "Failed to deserialize network:", err) os.Exit(1) } } else { log.Println("Created network.") seqFunc = createNetwork(samples) } crossLen := int(CrossRatio * float64(samples.Len())) log.Println("Using", samples.Len()-crossLen, "training and", crossLen, "validation samples...") // Always shuffle the samples in the same way. rand.Seed(123) sgd.ShuffleSampleSet(samples) validation := samples.Subset(0, crossLen) training := samples.Subset(crossLen, samples.Len()) gradienter := &sgd.Adam{ Gradienter: &ctc.RGradienter{ Learner: seqFunc, SeqFunc: seqFunc, MaxConcurrency: MaxConcurrency, MaxSubBatch: MaxSubBatch, }, } var epoch int toggleRegularization(seqFunc, true) sgd.SGDInteractive(gradienter, training, stepSize, BatchSize, func() bool { toggleRegularization(seqFunc, false) cost := ctc.TotalCost(seqFunc, training, CostBatchSize, MaxConcurrency) crossCost := ctc.TotalCost(seqFunc, validation, CostBatchSize, MaxConcurrency) toggleRegularization(seqFunc, true) log.Printf("Epoch %d: cost=%e cross=%e", epoch, cost, crossCost) epoch++ return true }) toggleRegularization(seqFunc, false) data, err := seqFunc.Serialize() if err != nil { fmt.Fprintln(os.Stderr, "Failed to serialize:", err) os.Exit(1) } if err := ioutil.WriteFile(rnnFile, data, 0755); err != nil { fmt.Fprintln(os.Stderr, "Failed to save:", err) os.Exit(1) } }