func (network NN) updateMiniBatch(batch []TrainItem, eta, lmbda float64, n int) { var err error cxw := make([]matrices.Matrix, len(network.weights)) cxb := make([]matrices.Matrix, len(network.biases)) for i, m := range network.weights { cxw[i] = matrices.InitMatrix(m.Rows(), m.Cols()) } for i, m := range network.biases { cxb[i] = matrices.InitMatrix(m.Rows(), m.Cols()) } for _, item := range batch { nablaW, nablaB := network.backprop(item) for i, nabla := range nablaW { cxw[i], err = cxw[i].Add(nabla) if err != nil { panic(err) } } for i, nabla := range nablaB { cxb[i], err = cxb[i].Add(nabla) if err != nil { panic(err) } } } multByConst := matrices.Mult(eta / float64(len(batch))) for i, w := range cxw { regularization := matrices.Mult(1 - eta*lmbda/float64(n)) reduced := w.Apply(multByConst) network.weights[i], err = network.weights[i].Apply(regularization).Sub(reduced) if err != nil { panic(err) } } for i, b := range cxb { reduced := b.Apply(multByConst) network.biases[i], err = network.biases[i].Sub(reduced) if err != nil { panic(err) } } }
// backprop computes the gradient of the cost for a single training item via
// backpropagation, returning per-layer weight gradients (nablaW) and bias
// gradients (nablaB), ordered from first to last layer. Shape errors from the
// matrix package are treated as programmer bugs and panic.
//
// NOTE(review): activation.Dot(weights) with a 1-row one-hot target suggests
// activations are row vectors (1 x layerSize) — confirm against matrices API.
func (network NN) backprop(item TrainItem) ([]matrices.Matrix, []matrices.Matrix) {
	// Zero-initialized gradient holders matching each layer's dimensions.
	nablaW := make([]matrices.Matrix, len(network.weights))
	nablaB := make([]matrices.Matrix, len(network.biases))
	for i, m := range network.weights {
		nablaW[i] = matrices.InitMatrix(m.Rows(), m.Cols())
	}
	for i, m := range network.biases {
		nablaB[i] = matrices.InitMatrix(m.Rows(), m.Cols())
	}

	// Forward pass: record every weighted input z and every activation.
	// activations has one extra entry because it includes the input itself.
	activation := item.Values
	activations := make([]matrices.Matrix, len(network.weights)+1)
	activations[0] = activation
	zs := make([]matrices.Matrix, len(network.weights))
	for i := range network.weights {
		weights := network.weights[i]
		biases := network.biases[i]
		multiplied, err := activation.Dot(weights)
		if err != nil {
			panic(err)
		}
		// z = a·W + b, stored for use in the backward pass.
		z, err := multiplied.Add(biases)
		if err != nil {
			panic(err)
		}
		zs[i] = z
		activation = z.Sigmoid()
		activations[i+1] = activation
	}

	// One-hot target row vector for this item's label.
	y, err := matrices.OneHotMatrix(1, item.Distinct, 0, int(item.Label))
	if err != nil {
		panic(err)
	}

	// old code with MSE
	// costDerivative, err := activations[len(activations) - 1].Sub(y)
	// if err != nil {
	//     panic(err)
	// }
	// delta, err := costDerivative.Mult(zs[len(zs) - 1].SigmoidPrime())
	// if err != nil {
	//     panic(err)
	// }

	// new code with cross-entropy: for cross-entropy cost with sigmoid output,
	// the output-layer delta is simply (a - y) — no SigmoidPrime factor.
	delta, err := activations[len(activations)-1].Sub(y)
	if err != nil {
		panic(err)
	}
	nablaB[len(nablaB)-1] = delta
	// Output-layer weight gradient: (previous activation)^T · delta.
	nablaW[len(nablaW)-1], err = activations[len(activations)-2].Transpose().Dot(delta)
	if err != nil {
		panic(err)
	}

	// Backward pass through the hidden layers. l counts back from the output:
	// l=2 is the last hidden layer, indexing with len(...)-l from the end.
	for l := 2; l < len(network.layers); l++ {
		z := zs[len(zs)-l]
		sp := z.SigmoidPrime()
		// Propagate delta backwards: delta_l = (delta_{l+1} · W_{l+1}^T) ⊙ σ'(z_l).
		dotted, err := delta.Dot(network.weights[len(network.weights)-l+1].Transpose())
		if err != nil {
			panic(err)
		}
		delta, err = dotted.Mult(sp)
		if err != nil {
			panic(err)
		}
		nablaB[len(nablaB)-l] = delta
		nablaW[len(nablaW)-l], err = activations[len(activations)-l-1].Transpose().Dot(delta)
		if err != nil {
			panic(err)
		}
	}
	return nablaW, nablaB
}