func sgdOnSequences(f *rnn.Bidirectional, s []seqtoseq.Sample) { gradient := autofunc.NewGradient(f.Parameters()) for _, x := range s { inRes := seqfunc.ConstResult([][]linalg.Vector{x.Inputs}) output := f.ApplySeqs(inRes) upstreamGrad := make([]linalg.Vector, len(x.Outputs)) for i, o := range x.Outputs { upstreamGrad[i] = o.Copy().Scale(-1) } output.PropagateGradient([][]linalg.Vector{upstreamGrad}, gradient) } for _, vec := range gradient { for i, x := range vec { if x > 0 { vec[i] = 1 } else { vec[i] = -1 } } } gradient.AddToVars(-StepSize) }
func Train(rnnFile, sampleDir string, stepSize float64) { log.Println("Loading samples...") samples, err := ReadSamples(sampleDir) if err != nil { fmt.Fprintln(os.Stderr, "Failed to read samples:", err) os.Exit(1) } var seqFunc *rnn.Bidirectional rnnData, err := ioutil.ReadFile(rnnFile) if err == nil { log.Println("Loaded network from file.") seqFunc, err = rnn.DeserializeBidirectional(rnnData) if err != nil { fmt.Fprintln(os.Stderr, "Failed to deserialize network:", err) os.Exit(1) } } else { log.Println("Created network.") seqFunc = createNetwork(samples) } crossLen := int(CrossRatio * float64(samples.Len())) log.Println("Using", samples.Len()-crossLen, "training and", crossLen, "validation samples...") // Always shuffle the samples in the same way. rand.Seed(123) sgd.ShuffleSampleSet(samples) validation := samples.Subset(0, crossLen) training := samples.Subset(crossLen, samples.Len()) gradienter := &sgd.Adam{ Gradienter: &ctc.RGradienter{ Learner: seqFunc, SeqFunc: seqFunc, MaxConcurrency: MaxConcurrency, MaxSubBatch: MaxSubBatch, }, } var epoch int toggleRegularization(seqFunc, true) sgd.SGDInteractive(gradienter, training, stepSize, BatchSize, func() bool { toggleRegularization(seqFunc, false) cost := ctc.TotalCost(seqFunc, training, CostBatchSize, MaxConcurrency) crossCost := ctc.TotalCost(seqFunc, validation, CostBatchSize, MaxConcurrency) toggleRegularization(seqFunc, true) log.Printf("Epoch %d: cost=%e cross=%e", epoch, cost, crossCost) epoch++ return true }) toggleRegularization(seqFunc, false) data, err := seqFunc.Serialize() if err != nil { fmt.Fprintln(os.Stderr, "Failed to serialize:", err) os.Exit(1) } if err := ioutil.WriteFile(rnnFile, data, 0755); err != nil { fmt.Fprintln(os.Stderr, "Failed to save:", err) os.Exit(1) } }