func batchLearn(train *mnist.Set) *linear.Softmax { var err error x := [][]float64{} y := []float64{} size := train.Count() bytes, label := train.Get(0) bar := newBar(int64(size)) bar.Start() count := 0 now := time.Now() fmt.Printf("Starting Loading of < Batch Softmax > at %v\n", now) for count = 0; count < size; count++ { inp := make([]float64, len(bytes)) out := float64(label) for i := range bytes { //x[i] = float64(bytes[i]) / 255 if bytes[i] > uint8(167) { inp[i] = 1 } } x = append(x, inp) y = append(y, out) bytes, label = train.Get(count) bar.Increment() } bar.FinishPrint(fmt.Sprintf("Loaded %v examples into < Batch Softmax >\n\tTook %v", count, time.Now().Sub(now))) model := linear.NewSoftmax(base.StochasticGA, 1e-5, 10, 10, int(Epochs), x, y) now = time.Now() fmt.Printf("Starting Training of < Batch Softmax > at %v over %v epochs\n", now, Epochs) err = model.Learn() if err != nil { panic(fmt.Sprintf("Error learning on < Batch Softmax >!\n\t%v", err)) } fmt.Printf("Stopped Training < Batch Softmax > at %v\n\tTraining Time: %v\n", time.Now(), time.Now().Sub(now)) return model }
func onlineLearn(train *mnist.Set) *linear.Softmax { var err error stream := make(chan base.Datapoint, 1000) errors := make(chan error, 200) now := time.Now() fmt.Printf("Starting Loading/Training of < Online Softmax > at %v\n", now) model := linear.NewSoftmax(base.StochasticGA, 1e-5, 10, 10, 0, nil, nil, 784) go model.OnlineLearn(errors, stream, func(theta [][]float64) {}) // push data onto the stream while waiting for errors go func() { size := train.Count() bytes, label := train.Get(0) bar := newBar(int64(size) * Epochs) bar.Start() count := 0 for iter := int64(0); iter < Epochs; iter++ { for count = 0; count < size; count++ { x := make([]float64, len(bytes)) y := []float64{float64(label)} for i := range bytes { //x[i] = float64(bytes[i]) / 255 if bytes[i] > uint8(200) { x[i] = 1 } } stream <- base.Datapoint{ X: x, Y: y, } bytes, label = train.Get(count) bar.Increment() } } bar.FinishPrint(fmt.Sprintf("Loaded %v examples onto the data stream of < Online Softmax >\n\tRepeated through %v epochs", count, Epochs)) close(stream) }() var more bool for { err, more = <-errors if err != nil { fmt.Printf("Error encountered when training!\n\t%v\n", err) } if !more { break } } // now the model is trained! Test it! fmt.Printf("Stopped Loading/Training of < Online Softmax > at %v\n\tTraining Time: %v\n", time.Now(), time.Now().Sub(now)) return model }