Example #1
File: mlp.go Project: nickpoorman/nnet
// Train performs mini-batch SGD-based backpropagation to optimize the network.
func (d *MLP) Train(input [][]float64, target [][]float64, option TrainingOption) error {
	d.Option = option
	opt := nnet.BaseTrainingOption{
		Epoches:       d.Option.Epoches,
		MiniBatchSize: d.Option.MiniBatchSize,
		Monitoring:    d.Option.Monitoring,
	}
	s := nnet.NewTrainer(opt)
	return s.SupervisedMiniBatchTrain(d, input, target)
}
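For context, here is a minimal call-site sketch for MLP.Train, written as if it lived in the same package as MLP so no import path has to be guessed. The TrainingOption fields are the ones referenced in the snippet above; the constructor NewMLP and its signature are assumptions for illustration, not taken from the source.

// trainXOR is a hypothetical usage sketch for MLP.Train.
func trainXOR() error {
	// NewMLP and its (layer-sizes) signature are assumptions;
	// construct the network however the package actually provides.
	d := NewMLP(2, 3, 1)

	// Toy XOR data: four 2-d inputs with scalar targets.
	input := [][]float64{{0, 0}, {0, 1}, {1, 0}, {1, 1}}
	target := [][]float64{{0}, {1}, {1}, {0}}

	// Only the fields used by Train above are set here; other
	// fields (e.g. a learning rate) may exist but are not shown
	// in the snippet. "Epoches" is spelled as in the package.
	opt := TrainingOption{
		Epoches:       1000,
		MiniBatchSize: 2,
		Monitoring:    false,
	}
	return d.Train(input, target, opt)
}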
Example #2
File: gbrbm.go Project: nickpoorman/nnet
// Train performs the Contrastive Divergence learning algorithm to train the GBRBM.
// The algorithm is based on (mini-batch) Stochastic Gradient Ascent.
func (rbm *GBRBM) Train(data [][]float64, option TrainingOption) error {
	rbm.Option = option
	opt := nnet.BaseTrainingOption{
		Epoches:       rbm.Option.Epoches,
		MiniBatchSize: rbm.Option.MiniBatchSize,
		Monitoring:    rbm.Option.Monitoring,
	}

	// Persistent Contrastive Divergence: seed the persistent Gibbs
	// chains with a copy of the training data. Rows must be copied
	// individually; copy() on a [][]float64 only copies the row
	// slice headers, which would alias the training data.
	if rbm.Option.UsePersistent {
		rbm.PersistentVisibleUnits = nnet.MakeMatrix(len(data), len(data[0]))
		for i := range data {
			copy(rbm.PersistentVisibleUnits[i], data[i])
		}
	}

	s := nnet.NewTrainer(opt)
	return s.UnSupervisedMiniBatchTrain(rbm, data)
}
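Similarly, a hedged call-site sketch for GBRBM.Train, again written as if in the same package. UsePersistent toggles the persistent-chain branch shown above; the constructor NewGBRBM and its signature are assumptions, and the epoch and batch values are arbitrary.

// trainRBM is a hypothetical usage sketch for GBRBM.Train.
func trainRBM(data [][]float64) error {
	// NewGBRBM(visible, hidden) is a hypothetical constructor;
	// the visible-unit count must match the data width.
	rbm := NewGBRBM(len(data[0]), 64)

	opt := TrainingOption{
		Epoches:       50,
		MiniBatchSize: 10,
		Monitoring:    true,
		// Persistent CD: Gibbs chains survive across mini-batches
		// instead of restarting from the data each step.
		UsePersistent: true,
	}
	return rbm.Train(data, opt)
}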