Example #1
0
// NewTrainDataReader reads the header of a training-data stream from reader
// and returns a TrainDataReader for the data that follows.
//
// The header is the number of layers as a little-endian int64, followed by
// one little-endian int64 size per layer. The layer count must lie in
// [0, addr.NLayers], and every layer size must be non-negative. The total
// input size must also fit in an int. Layers beyond the stored count keep a
// size of zero.
//
// Errors from binary.Read are returned as-is (not wrapped).
func NewTrainDataReader(reader io.Reader) (*TrainDataReader, error) {
	var nLayers int64
	if err := binary.Read(reader, binary.LittleEndian, &nLayers); err != nil {
		return nil, err
	}

	// Range-check in the int64 domain. Converting to addr.Layer first could
	// wrap a negative or huge count into a value that passes the check.
	if nLayers < 0 || nLayers > int64(addr.NLayers) {
		return nil, fmt.Errorf("The data contains %d layers, but I know only %d.",
			nLayers, addr.NLayers)
	}

	var inputLayerSizes [addr.NLayers]int
	var inputSize int
	for i := addr.Layer(0); i < addr.Layer(nLayers); i++ {
		var layerSize64 int64
		if err := binary.Read(reader, binary.LittleEndian, &layerSize64); err != nil {
			return nil, err
		}

		if layerSize64 < 0 {
			return nil, fmt.Errorf("Input layer size should be 0 or larger, was: %d", layerSize64)
		}

		// Corrupt input must not make the total wrap or truncate. Otherwise
		// the allocation below would panic or be silently undersized. The
		// first test catches int64->int truncation on 32-bit platforms. The
		// second catches signed wrap-around of the running sum; both operands
		// are non-negative, so a smaller result means overflow.
		layerSize := int(layerSize64)
		if int64(layerSize) != layerSize64 || inputSize+layerSize < inputSize {
			return nil, fmt.Errorf("input size overflows int at layer %d (layer size %d)", i, layerSize64)
		}

		inputLayerSizes[i] = layerSize
		inputSize += layerSize
	}

	inputVector := newInputVector(make([]float32, inputSize), inputLayerSizes)

	return &TrainDataReader{
		reader:      reader,
		inputSize:   inputSize,
		inputVector: inputVector,
	}, nil
}
Example #2
0
// Normalize applies the matching per-layer normalizer to each layer of x
// whose flag in n.normalizedLayers is set. Layers whose flag is not set are
// skipped.
func (n *Normalizer) Normalize(x InputVector) {
	for idx, enabled := range n.normalizedLayers {
		if !enabled {
			continue
		}
		n.normalizers[idx].Normalize(x.Layer(addr.Layer(idx)))
	}
}
Example #3
0
// Process passes each layer of x whose flag in acc.normalizedLayers is set to
// that layer's accumulator. It stops at the first error an accumulator
// reports and returns it. A nil result means every flagged layer was
// processed.
func (acc *NormalizationAccumulator) Process(x InputVector) error {
	for layer, enabled := range acc.normalizedLayers {
		if !enabled {
			continue
		}
		if err := acc.accumulators[layer].Process(x.Layer(addr.Layer(layer))); err != nil {
			return err
		}
	}
	return nil
}
Example #4
0
// Write appends one training instance (label y plus inputVector) to the
// underlying writer.
//
// On the first successful call it also emits a header. The header is
// addr.NLayers as a little-endian int64, then the length of each layer of
// inputVector as a little-endian int64. If writing the header fails,
// iw.first stays set, so the next call tries to write the header again.
//
// Every call then writes y as a little-endian uint64, followed by the values
// returned by inputVector.All().
func (iw *TrainDataWriter) Write(y int, inputVector InputVector) error {
	w := iw.writer
	order := binary.LittleEndian

	if iw.first {
		// Header: layer count, then the size of every layer.
		if err := binary.Write(w, order, int64(addr.NLayers)); err != nil {
			return err
		}
		for layer := addr.Layer(0); layer < addr.NLayers; layer++ {
			size := int64(len(inputVector.Layer(layer)))
			if err := binary.Write(w, order, size); err != nil {
				return err
			}
		}
		iw.first = false
	}

	label := uint64(y)
	if err := binary.Write(w, order, label); err != nil {
		return err
	}
	return binary.Write(w, order, inputVector.All())
}