Esempio n. 1
0
// DeserializeLSTM reconstructs an LSTM from data produced by
// serializing one.
//
// The data must decode (via serializer.DeserializeSlice) to exactly
// six elements, in this order: the hidden size (serializer.Int), the
// input-value, input, remember, and output gates (*lstmGate each),
// and the JSON-encoded initial state (serializer.Bytes), which is
// unmarshaled into an autofunc.Variable.
//
// Errors from DeserializeSlice and json.Unmarshal are returned as-is;
// a wrong element count or element type yields a descriptive error.
func DeserializeLSTM(d []byte) (*LSTM, error) {
	parts, err := serializer.DeserializeSlice(d)
	if err != nil {
		return nil, err
	}
	if len(parts) != 6 {
		return nil, errors.New("invalid slice length in LSTM")
	}

	size, sizeOK := parts[0].(serializer.Int)
	inVal, inValOK := parts[1].(*lstmGate)
	inGate, inGateOK := parts[2].(*lstmGate)
	remGate, remGateOK := parts[3].(*lstmGate)
	outGate, outGateOK := parts[4].(*lstmGate)
	stateJSON, stateOK := parts[5].(serializer.Bytes)
	allOK := sizeOK && inValOK && inGateOK && remGateOK && outGateOK && stateOK
	if !allOK {
		return nil, errors.New("invalid types in LSTM slice")
	}

	state := new(autofunc.Variable)
	if err := json.Unmarshal(stateJSON, state); err != nil {
		return nil, err
	}

	lstm := &LSTM{
		hiddenSize:   int(size),
		inputValue:   inVal,
		inputGate:    inGate,
		rememberGate: remGate,
		outputGate:   outGate,
		initState:    state,
	}
	return lstm, nil
}
Esempio n. 2
0
// deserializeLSTMGate decodes an lstmGate from serialized data.
//
// The data must decode to a slice of two or three elements: a
// *neuralnet.DenseLayer, a neuralnet.Layer used as the activation,
// and, optionally, JSON-encoded peephole data (serializer.Bytes) that
// is unmarshaled into the gate's Peephole field.
//
// A JSON failure on the peephole data is wrapped with %w so callers
// can still inspect the underlying error with errors.Is / errors.As.
func deserializeLSTMGate(d []byte) (*lstmGate, error) {
	list, err := serializer.DeserializeSlice(d)
	if err != nil {
		return nil, err
	}
	if len(list) != 2 && len(list) != 3 {
		return nil, errors.New("invalid slice length for LSTM gate")
	}
	dense, ok := list[0].(*neuralnet.DenseLayer)
	activ, ok1 := list[1].(neuralnet.Layer)
	if !ok || !ok1 {
		return nil, errors.New("invalid types for LSTM gate slice")
	}
	res := &lstmGate{Dense: dense, Activation: activ}
	if len(list) == 3 {
		// The peephole element is optional; when present it must be raw bytes.
		peephole, isBytes := list[2].(serializer.Bytes)
		if !isBytes {
			return nil, errors.New("invalid types for LSTM gate slice")
		}
		if err := json.Unmarshal(peephole, &res.Peephole); err != nil {
			// %w yields the same message text as %s but preserves the error chain.
			return nil, fmt.Errorf("bad peephole data: %w", err)
		}
	}
	return res, nil
}
Esempio n. 3
0
// DeserializeGRU rebuilds a GRU from data previously produced by
// serializing a GRU.
//
// The data must decode to exactly five elements: the hidden size
// (serializer.Int), the input-value, reset, and update gates
// (*lstmGate each), and the JSON-encoded initial state
// (serializer.Bytes). Errors from DeserializeSlice are returned as-is.
func DeserializeGRU(d []byte) (*GRU, error) {
	fields, err := serializer.DeserializeSlice(d)
	switch {
	case err != nil:
		return nil, err
	case len(fields) != 5:
		return nil, errors.New("invalid slice length in GRU")
	}

	size, okSize := fields[0].(serializer.Int)
	valueGate, okValue := fields[1].(*lstmGate)
	reset, okReset := fields[2].(*lstmGate)
	update, okUpdate := fields[3].(*lstmGate)
	stateBytes, okState := fields[4].(serializer.Bytes)
	if !(okSize && okValue && okReset && okUpdate && okState) {
		return nil, errors.New("invalid types in GRU slice")
	}

	state := new(autofunc.Variable)
	// NOTE(review): the underlying JSON error is discarded here and
	// replaced by a generic message; consider wrapping it instead.
	if json.Unmarshal(stateBytes, state) != nil {
		return nil, errors.New("invalid init state in GRU slice")
	}

	return &GRU{
		hiddenSize: int(size),
		inputValue: valueGate,
		resetGate:  reset,
		updateGate: update,
		initState:  state,
	}, nil
}
Esempio n. 4
0
// DeserializeStackedBlock decodes a StackedBlock from serialized data.
//
// Every element of the decoded slice must implement Block, and the
// result keeps the serialized order. The first element that is not a
// Block causes an error naming its index and dynamic type.
func DeserializeStackedBlock(d []byte) (StackedBlock, error) {
	items, err := serializer.DeserializeSlice(d)
	if err != nil {
		return nil, err
	}
	blocks := make(StackedBlock, len(items))
	for idx := range items {
		b, isBlock := items[idx].(Block)
		if !isBlock {
			return nil, fmt.Errorf("layer %d (%T) is not Block", idx, items[idx])
		}
		blocks[idx] = b
	}
	return blocks, nil
}
Esempio n. 5
0
// DeserializeBidirectional decodes a Bidirectional from data produced
// by serializing one.
//
// The data must decode to exactly three seqfunc.RFunc values. They
// fill the Bidirectional's three fields in serialized order.
func DeserializeBidirectional(d []byte) (*Bidirectional, error) {
	parts, err := serializer.DeserializeSlice(d)
	if err != nil {
		return nil, err
	}
	if len(parts) != 3 {
		return nil, errors.New("invalid Bidirectional slice length")
	}
	var funcs [3]seqfunc.RFunc
	for i, p := range parts {
		f, ok := p.(seqfunc.RFunc)
		if !ok {
			return nil, errors.New("invalid Bidirectional slice types")
		}
		funcs[i] = f
	}
	return &Bidirectional{funcs[0], funcs[1], funcs[2]}, nil
}
Esempio n. 6
0
// DeserializeNetwork decodes a Network from serialized data.
//
// Each element of the decoded slice must implement Layer; the layers
// are returned in their serialized order. If the decoded slice is
// empty, the returned Network is nil. Errors from
// serializer.DeserializeSlice are returned unchanged.
func DeserializeNetwork(data []byte) (Network, error) {
	slice, err := serializer.DeserializeSlice(data)
	if err != nil {
		return nil, err
	}

	// Start from a nil Network (not make) so an empty input still
	// yields nil, as before.
	var res Network
	for _, x := range slice {
		layer, ok := x.(Layer)
		if !ok {
			return nil, errors.New("slice element is not a Layer")
		}
		res = append(res, layer)
	}
	return res, nil
}
Esempio n. 7
0
// DeserializeNetworkBlock decodes a NetworkBlock from serialized data.
//
// The data must decode to three elements: the state size
// (serializer.Int), the neuralnet.Network, and the JSON-encoded
// initial state (serializer.Bytes). The block is built with
// NewNetworkBlock, and its batcherBlock.Start is then set to the
// decoded initial state. Errors from DeserializeSlice and
// json.Unmarshal are returned as-is.
func DeserializeNetworkBlock(d []byte) (*NetworkBlock, error) {
	items, err := serializer.DeserializeSlice(d)
	if err != nil {
		return nil, err
	}
	if len(items) != 3 {
		return nil, errors.New("bad network list length")
	}

	size, sizeOK := items[0].(serializer.Int)
	inner, innerOK := items[1].(neuralnet.Network)
	startJSON, startOK := items[2].(serializer.Bytes)
	if !sizeOK || !innerOK || !startOK {
		return nil, errors.New("bad types in network list")
	}

	start := new(autofunc.Variable)
	if err := json.Unmarshal(startJSON, start); err != nil {
		return nil, err
	}

	block := NewNetworkBlock(inner, int(size))
	block.batcherBlock.Start = start
	return block, nil
}