コード例 #1
0
ファイル: linsolve.go プロジェクト: reggo/train
// FeaturizeTrainable computes the featurized form of every row of inputs and
// returns the matrix holding the result.
//
// If featurizedInputs is nil, a new matrix with nSamples rows and
// t.NumFeatures() columns is allocated. Otherwise the supplied matrix is
// written to and returned; it is expected to be nSamples x nFeatures, which is
// not verified here.
//
// The rows are handed out as [start, end) ranges through common.ParallelFor.
// Each range call obtains its own featurizer from t.NewFeaturizer().
func FeaturizeTrainable(t Trainable, inputs common.RowMatrix, featurizedInputs *mat64.Dense) *mat64.Dense {
	nSamples, nDim := inputs.Dims()
	if featurizedInputs == nil {
		featurizedInputs = mat64.NewDense(nSamples, t.NumFeatures(), nil)
	}

	// General case: copy each input row into a scratch buffer owned by the
	// current range call before featurizing it.
	worker := func(lo, hi int) {
		feat := t.NewFeaturizer()
		scratch := make([]float64, nDim)
		for idx := lo; idx < hi; idx++ {
			inputs.Row(scratch, idx)
			feat.Featurize(scratch, featurizedInputs.RowView(idx))
		}
	}

	// When the input exposes RowView, use it directly and skip the copy.
	if viewer, ok := inputs.(mat64.RowViewer); ok {
		worker = func(lo, hi int) {
			feat := t.NewFeaturizer()
			for idx := lo; idx < hi; idx++ {
				feat.Featurize(viewer.RowView(idx), featurizedInputs.RowView(idx))
			}
		}
	}

	grain := common.GetGrainSize(nSamples, minGrain, maxGrain)
	common.ParallelFor(nSamples, grain, worker)
	return featurizedInputs
}
コード例 #2
0
ファイル: gradbased.go プロジェクト: reggo/train
// ObjGrad computes the objective value at parameters and stores the gradient
// in derivative, overwriting whatever it held before.
//
// The per-sample contribution (via g.lossDerivFunc over g.nTrain samples) and
// the regularizer contribution (via g.regularizer.LossDeriv) are evaluated
// concurrently. Partial results arrive on a channel and are summed here. Both
// the summed loss and the summed gradient, regularization included, are then
// divided by the number of training samples.
func (g *BatchGradBased) ObjGrad(parameters []float64, derivative []float64) (loss float64) {
	results := make(chan lossDerivStruct, 10)

	// Coordinator goroutine: starts both producers, waits for them to finish,
	// then closes the channel so the accumulation loop below terminates.
	go func() {
		var producers sync.WaitGroup
		producers.Add(2)

		// Data term. lossDerivFunc is handed the channel and presumably sends
		// its partial results on it — confirm against its definition.
		go func() {
			defer producers.Done()
			common.ParallelFor(g.nTrain, g.grainSize, func(start, end int) {
				g.lossDerivFunc(start, end, results, parameters)
			})
		}()

		// Regularization term, sent as a single partial result.
		go func() {
			defer producers.Done()
			regDeriv := make([]float64, g.nParameters)
			regLoss := g.regularizer.LossDeriv(parameters, regDeriv)
			results <- lossDerivStruct{loss: regLoss, deriv: regDeriv}
		}()

		producers.Wait()
		close(results)
	}()

	// Start the accumulation from zero.
	for i := range derivative {
		derivative[i] = 0
	}

	// Fold in partial results as they arrive; the loop ends when the
	// coordinator closes the channel.
	for partial := range results {
		loss += partial.loss
		floats.Add(derivative, partial.deriv)
	}

	// Normalize by the number of training samples.
	nTrain := float64(g.nTrain)
	loss /= nTrain
	floats.Scale(1/nTrain, derivative)
	return loss
}
コード例 #3
0
ファイル: scale.go プロジェクト: reggo/scale
// UnscaleData calls scaler.Unscale on the row view of every row of data,
// distributing row ranges through common.ParallelFor.
//
// Every row whose Unscale call fails is recorded as a SliceError carrying the
// row index. If any were recorded, they are returned together as an ErrorList;
// otherwise the result is nil. Because ranges may run concurrently, appends to
// the list are guarded by a mutex.
//
// TODO: Make this work better so that if there is an error somewhere data isn't changed
func UnscaleData(scaler Scaler, data *mat64.Dense) error {
	var (
		mu   sync.Mutex
		errs ErrorList
	)

	unscaleRange := func(lo, hi int) {
		for row := lo; row < hi; row++ {
			if err := scaler.Unscale(data.RowView(row)); err != nil {
				mu.Lock()
				errs = append(errs, &SliceError{Header: "scale", Idx: row, Err: err})
				mu.Unlock()
			}
		}
	}

	rows, _ := data.Dims()
	common.ParallelFor(rows, common.GetGrainSize(rows, 1, 500), unscaleRange)

	// Return a literal nil on success so the caller never sees a non-nil
	// error interface wrapping an empty list.
	if len(errs) == 0 {
		return nil
	}
	return errs
}
コード例 #4
0
ファイル: predict.go プロジェクト: reggo/predict
// BatchPredict runs a prediction for every row of inputs and stores the
// results in the corresponding rows of outputs.
//
// inputDim must equal the number of columns of inputs. If outputs is nil, a
// new nSamples x outputDim matrix is allocated; otherwise outputs must be
// nSamples x outputDim. On a dimension mismatch, the outputs argument is
// returned unchanged together with a non-nil error.
//
// Rows are processed as [start, end) ranges through common.ParallelFor using
// grainSize. Each range call obtains its own predictor from
// batch.NewPredictor(), so that predictors are not shared between ranges.
func BatchPredict(batch BatchPredictor, inputs common.RowMatrix, outputs common.MutableRowMatrix, inputDim, outputDim int, grainSize int) (common.MutableRowMatrix, error) {

	// TODO: Add in something about error

	// Check that the inputs and outputs are the right sizes
	nSamples, dimInputs := inputs.Dims()
	if inputDim != dimInputs {
		return outputs, errors.New("predict batch: input dimension mismatch")
	}

	if outputs == nil {
		outputs = mat64.NewDense(nSamples, outputDim, nil)
	} else {
		nOutputSamples, dimOutputs := outputs.Dims()
		if dimOutputs != outputDim {
			return outputs, errors.New("predict batch: output dimension mismatch")
		}
		if nSamples != nOutputSamples {
			return outputs, errors.New("predict batch: rows mismatch")
		}
	}

	// If the input and/or output is a RowViewer, save time by avoiding a copy.
	inputRVer, inputIsRowViewer := inputs.(mat64.RowViewer)
	outputRVer, outputIsRowViewer := outputs.(mat64.RowViewer)

	var f func(start, end int)

	// Pick the worker matching which sides support RowView. The four cases
	// are mutually exclusive and exhaustive. Both flags must be tested with
	// && in the first case: a comma-separated case list is an OR, which would
	// select this worker with one of the viewers nil and panic on RowView.
	switch {
	case inputIsRowViewer && outputIsRowViewer:
		f = func(start, end int) {
			p := batch.NewPredictor()
			for i := start; i < end; i++ {
				p.Predict(inputRVer.RowView(i), outputRVer.RowView(i))
			}
		}
	case inputIsRowViewer && !outputIsRowViewer:
		f = func(start, end int) {
			p := batch.NewPredictor()
			output := make([]float64, outputDim)
			for i := start; i < end; i++ {
				// Load the current row so Predict sees the existing contents,
				// then write the result back.
				outputs.Row(output, i)
				p.Predict(inputRVer.RowView(i), output)
				outputs.SetRow(i, output)
			}
		}
	case !inputIsRowViewer && outputIsRowViewer:
		f = func(start, end int) {
			p := batch.NewPredictor()
			input := make([]float64, inputDim)
			for i := start; i < end; i++ {
				inputs.Row(input, i)
				p.Predict(input, outputRVer.RowView(i))
			}
		}
	case !inputIsRowViewer && !outputIsRowViewer:
		f = func(start, end int) {
			p := batch.NewPredictor()
			input := make([]float64, inputDim)
			output := make([]float64, outputDim)
			for i := start; i < end; i++ {
				inputs.Row(input, i)
				outputs.Row(output, i)
				p.Predict(input, output)
				outputs.SetRow(i, output)
			}
		}
	}

	common.ParallelFor(nSamples, grainSize, f)
	return outputs, nil
}