Example #1
File: nn.go  Project: garretraziel/nn
// updateMiniBatch performs one gradient-descent step with L2 regularization.
// It sums the backprop gradients over every item in batch; eta is the learning
// rate, lmbda the regularization parameter and n the size of the training set.
func (network NN) updateMiniBatch(batch []TrainItem, eta, lmbda float64, n int) {
	var err error
	// cxw and cxb accumulate the gradients of the cost over the mini-batch.
	cxw := make([]matrices.Matrix, len(network.weights))
	cxb := make([]matrices.Matrix, len(network.biases))
	for i, m := range network.weights {
		cxw[i] = matrices.InitMatrix(m.Rows(), m.Cols())
	}
	for i, m := range network.biases {
		cxb[i] = matrices.InitMatrix(m.Rows(), m.Cols())
	}

	// Sum the gradients returned by backprop for every item in the mini-batch.
	for _, item := range batch {
		nablaW, nablaB := network.backprop(item)
		for i, nabla := range nablaW {
			cxw[i], err = cxw[i].Add(nabla)
			if err != nil {
				panic(err)
			}
		}
		for i, nabla := range nablaB {
			cxb[i], err = cxb[i].Add(nabla)
			if err != nil {
				panic(err)
			}
		}
	}
	// w <- (1 - eta*lmbda/n)*w - (eta/len(batch))*sum(nablaW): the first factor
	// is the L2 weight decay, the second term the averaged gradient step.
	multByConst := matrices.Mult(eta / float64(len(batch)))
	regularization := matrices.Mult(1 - eta*lmbda/float64(n))
	for i, w := range cxw {
		reduced := w.Apply(multByConst)
		network.weights[i], err = network.weights[i].Apply(regularization).Sub(reduced)
		if err != nil {
			panic(err)
		}
	}
	// Biases are not regularized: b <- b - (eta/len(batch))*sum(nablaB).
	for i, b := range cxb {
		reduced := b.Apply(multByConst)
		network.biases[i], err = network.biases[i].Sub(reduced)
		if err != nil {
			panic(err)
		}
	}
}
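
For reference, the step applied above is the standard L2-regularized mini-batch SGD update: each weight is scaled by the decay factor 1 - eta*lmbda/n and then the averaged, eta-scaled gradient is subtracted; biases get the same gradient step without the decay. The following is a minimal, self-contained sketch of that rule on plain float64 slices; it does not use the matrices package, and all names in it (sgdStep, gradW, gradB) are illustrative only, not part of the nn project.

package main

import "fmt"

// sgdStep applies w <- (1 - eta*lmbda/n)*w - (eta/m)*gradSum to every weight
// and b <- b - (eta/m)*gradSum to every bias, where m is the mini-batch size.
func sgdStep(w, gradW, b, gradB []float64, eta, lmbda float64, m, n int) {
	decay := 1 - eta*lmbda/float64(n) // weight decay from L2 regularization
	scale := eta / float64(m)         // averages the accumulated gradients
	for i := range w {
		w[i] = decay*w[i] - scale*gradW[i]
	}
	for i := range b {
		b[i] -= scale * gradB[i]
	}
}

func main() {
	w := []float64{0.5, -0.3}
	b := []float64{0.1}
	gradW := []float64{0.2, 0.4} // gradients summed over a mini-batch of size 2
	gradB := []float64{0.6}
	sgdStep(w, gradW, b, gradB, 3.0, 0.1, 2, 100)
	fmt.Println(w, b) // weights decayed and moved against the gradient
}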
Example #2
File: nn.go  Project: garretraziel/nn
// backprop runs a forward and a backward pass for a single training item and
// returns the gradients of the cost with respect to the weights and biases.
func (network NN) backprop(item TrainItem) ([]matrices.Matrix, []matrices.Matrix) {
	nablaW := make([]matrices.Matrix, len(network.weights))
	nablaB := make([]matrices.Matrix, len(network.biases))
	for i, m := range network.weights {
		nablaW[i] = matrices.InitMatrix(m.Rows(), m.Cols())
	}
	for i, m := range network.biases {
		nablaB[i] = matrices.InitMatrix(m.Rows(), m.Cols())
	}

	// Forward pass: keep every weighted input z and every activation so the
	// backward pass can reuse them.
	activation := item.Values
	activations := make([]matrices.Matrix, len(network.weights)+1)
	activations[0] = activation
	zs := make([]matrices.Matrix, len(network.weights))

	for i := range network.weights {
		weights := network.weights[i]
		biases := network.biases[i]
		// Activations are row vectors, so they multiply the weights from the left.
		multiplied, err := activation.Dot(weights)
		if err != nil {
			panic(err)
		}
		z, err := multiplied.Add(biases)
		if err != nil {
			panic(err)
		}
		zs[i] = z
		activation = z.Sigmoid()
		activations[i+1] = activation
	}

	// Encode the expected label as a one-hot row vector.
	y, err := matrices.OneHotMatrix(1, item.Distinct, 0, int(item.Label))
	if err != nil {
		panic(err)
	}

	// old code with MSE
	// costDerivative, err := activations[len(activations) - 1].Sub(y)
	// if err != nil {
	//     panic(err)
	// }
	// delta, err := costDerivative.Mult(zs[len(zs) - 1].SigmoidPrime())
	// if err != nil {
	//     panic(err)
	// }

	// new code with cross-entropy: the sigmoid' factor cancels against the
	// cross-entropy derivative, so the output error is simply a - y
	delta, err := activations[len(activations)-1].Sub(y)
	if err != nil {
		panic(err)
	}
	nablaB[len(nablaB)-1] = delta
	// Gradient for the last weight layer: (previous activation)^T . delta.
	nablaW[len(nablaW)-1], err = activations[len(activations)-2].Transpose().Dot(delta)
	if err != nil {
		panic(err)
	}

	// Propagate the error backwards through the hidden layers:
	// delta_l = (delta_{l+1} . W_{l+1}^T) multiplied elementwise by sigmoid'(z_l).
	for l := 2; l < len(network.layers); l++ {
		z := zs[len(zs)-l]
		sp := z.SigmoidPrime()
		dotted, err := delta.Dot(network.weights[len(network.weights)-l+1].Transpose())
		if err != nil {
			panic(err)
		}
		delta, err = dotted.Mult(sp)
		if err != nil {
			panic(err)
		}
		nablaB[len(nablaB)-l] = delta
		nablaW[len(nablaW)-l], err = activations[len(activations)-l-1].Transpose().Dot(delta)
		if err != nil {
			panic(err)
		}
	}

	return nablaW, nablaB
}
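
The backward pass above follows the usual recursion in the row-vector convention this code appears to use (activations multiply the weights from the left): with cross-entropy cost and sigmoid outputs the output-layer error is a - y, and each earlier layer's error is the next layer's error times the transposed weights, multiplied elementwise by sigmoid'(z). The following is a minimal scalar sketch of that recursion for a toy 1-1-1 network; it is independent of the matrices package and all names in it are illustrative only.

package main

import (
	"fmt"
	"math"
)

func sigmoid(z float64) float64      { return 1 / (1 + math.Exp(-z)) }
func sigmoidPrime(z float64) float64 { s := sigmoid(z); return s * (1 - s) }

func main() {
	// Toy 1-1-1 network: one hidden unit, one output unit, scalar "matrices".
	x, y := 0.5, 1.0
	w1, b1 := 0.8, 0.1  // input -> hidden
	w2, b2 := -0.4, 0.2 // hidden -> output

	// Forward pass, mirroring the zs/activations bookkeeping above.
	z1 := x*w1 + b1
	a1 := sigmoid(z1)
	z2 := a1*w2 + b2
	a2 := sigmoid(z2)

	// Backward pass. With cross-entropy cost and a sigmoid output,
	// the output error is simply a2 - y (no sigmoidPrime factor).
	delta2 := a2 - y
	nablaW2 := a1 * delta2 // corresponds to activations[...-2]^T . delta
	nablaB2 := delta2

	// Hidden layer: propagate through w2 and reintroduce sigmoidPrime.
	delta1 := delta2 * w2 * sigmoidPrime(z1)
	nablaW1 := x * delta1
	nablaB1 := delta1

	fmt.Println(nablaW1, nablaB1, nablaW2, nablaB2)
}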