// Example No. 1
// batchLearn loads every example from train into memory, binarizes each
// image, and trains a batch Softmax model on the full set via model.Learn.
//
// Each pixel strictly greater than 167 becomes 1 and every other pixel 0;
// each label is stored as a float64 target. Loading progress is shown on a
// progress bar, and loading and training times are printed to stdout.
//
// The model is built with base.StochasticGA, and int(Epochs) is passed to
// linear.NewSoftmax (the training message reports it as the epoch count).
// batchLearn panics if training returns an error; otherwise it returns the
// trained model.
func batchLearn(train *mnist.Set) *linear.Softmax {
	size := train.Count()
	// The final length is exactly size, so allocate once.
	x := make([][]float64, 0, size)
	y := make([]float64, 0, size)

	bar := newBar(int64(size))
	bar.Start()

	start := time.Now()
	fmt.Printf("Starting Loading of < Batch Softmax > at %v\n", start)

	for idx := 0; idx < size; idx++ {
		// Fetch inside the loop so row idx always holds train.Get(idx).
		// Prefetching before the loop and advancing at the bottom stores
		// example 0 twice, never stores the last example, and calls Get(0)
		// even on an empty set.
		pixels, label := train.Get(idx)

		inp := make([]float64, len(pixels))
		for i, px := range pixels {
			// Binarize; to scale into [0, 1] instead use float64(px) / 255.
			if px > 167 {
				inp[i] = 1
			}
		}

		x = append(x, inp)
		y = append(y, float64(label))

		bar.Increment()
	}

	bar.FinishPrint(fmt.Sprintf("Loaded %v examples into < Batch Softmax >\n\tTook %v", len(x), time.Since(start)))

	model := linear.NewSoftmax(base.StochasticGA, 1e-5, 10, 10, int(Epochs), x, y)

	start = time.Now()
	fmt.Printf("Starting Training of < Batch Softmax > at %v over %v epochs\n", start, Epochs)

	if err := model.Learn(); err != nil {
		panic(fmt.Sprintf("Error learning on < Batch Softmax >!\n\t%v", err))
	}

	// Read the clock once so the printed stop time and duration agree.
	end := time.Now()
	fmt.Printf("Stopped Training < Batch Softmax > at %v\n\tTraining Time: %v\n", end, end.Sub(start))

	return model
}
// Example No. 2
// onlineLearn trains a Softmax model incrementally by streaming the examples
// of train to model.OnlineLearn, repeating the whole set Epochs times.
//
// A producer goroutine binarizes each image (pixels strictly greater than
// 200 become 1, all others 0), sends it with its label on a buffered stream,
// and closes the stream after the last epoch. The calling goroutine prints
// every non-nil error received on the errors channel and returns the model
// once that channel is closed. Loading and training times are printed to
// stdout.
func onlineLearn(train *mnist.Set) *linear.Softmax {
	stream := make(chan base.Datapoint, 1000)
	errors := make(chan error, 200)

	start := time.Now()
	fmt.Printf("Starting Loading/Training of < Online Softmax > at %v\n", start)

	// No in-memory training set is supplied (nil, nil); data arrives on stream.
	// NOTE(review): 784 is presumably the feature count (28*28 pixels) —
	// confirm against linear.NewSoftmax.
	model := linear.NewSoftmax(base.StochasticGA, 1e-5, 10, 10, 0, nil, nil, 784)

	go model.OnlineLearn(errors, stream, func(theta [][]float64) {})

	// Producer: push data onto the stream while this goroutine drains errors.
	go func() {
		// Close the stream once everything is sent so the consumer sees the
		// end of the data.
		defer close(stream)

		size := train.Count()
		bar := newBar(int64(size) * Epochs)
		bar.Start()

		count := 0
		for epoch := int64(0); epoch < Epochs; epoch++ {
			for count = 0; count < size; count++ {
				// Fetch at the top of the loop so each epoch sends examples
				// 0..size-1 exactly once and in order. Prefetching and
				// advancing at the bottom duplicates one example per epoch,
				// drops another, and calls Get(0) on an empty set.
				pixels, label := train.Get(count)

				x := make([]float64, len(pixels))
				for i, px := range pixels {
					// NOTE(review): batchLearn binarizes at 167; confirm the
					// different threshold here is intentional.
					if px > 200 {
						x[i] = 1
					}
				}

				stream <- base.Datapoint{
					X: x,
					Y: []float64{float64(label)},
				}

				bar.Increment()
			}
		}

		bar.FinishPrint(fmt.Sprintf("Loaded %v examples onto the data stream of < Online Softmax >\n\tRepeated through %v epochs", count, Epochs))
	}()

	// Ranging ends only when errors is closed — presumably by OnlineLearn
	// after the stream is exhausted; confirm against linear.Softmax.OnlineLearn.
	for err := range errors {
		if err != nil {
			fmt.Printf("Error encountered when training!\n\t%v\n", err)
		}
	}

	// Training has finished. Read the clock once so the printed stop time
	// and duration agree.
	end := time.Now()
	fmt.Printf("Stopped Loading/Training of < Online Softmax > at %v\n\tTraining Time: %v\n", end, end.Sub(start))

	return model
}