// DeserializeLSTM deserializes an LSTM. func DeserializeLSTM(d []byte) (*LSTM, error) { slice, err := serializer.DeserializeSlice(d) if err != nil { return nil, err } if len(slice) != 6 { return nil, errors.New("invalid slice length in LSTM") } hiddenSize, ok := slice[0].(serializer.Int) inputValue, ok1 := slice[1].(*lstmGate) inputGate, ok2 := slice[2].(*lstmGate) rememberGate, ok3 := slice[3].(*lstmGate) outputGate, ok4 := slice[4].(*lstmGate) initStateData, ok5 := slice[5].(serializer.Bytes) if !ok || !ok1 || !ok2 || !ok3 || !ok4 || !ok5 { return nil, errors.New("invalid types in LSTM slice") } var initState autofunc.Variable if err := json.Unmarshal(initStateData, &initState); err != nil { return nil, err } return &LSTM{ hiddenSize: int(hiddenSize), inputValue: inputValue, inputGate: inputGate, rememberGate: rememberGate, outputGate: outputGate, initState: &initState, }, nil }
func deserializeLSTMGate(d []byte) (*lstmGate, error) { list, err := serializer.DeserializeSlice(d) if err != nil { return nil, err } if len(list) != 2 && len(list) != 3 { return nil, errors.New("invalid slice length for LSTM gate") } dense, ok := list[0].(*neuralnet.DenseLayer) activ, ok1 := list[1].(neuralnet.Layer) if !ok || !ok1 { return nil, errors.New("invalid types for LSTM gate slice") } res := &lstmGate{Dense: dense, Activation: activ} if len(list) == 3 { peephole, ok := list[2].(serializer.Bytes) if !ok { return nil, errors.New("invalid types for LSTM gate slice") } if err := json.Unmarshal(peephole, &res.Peephole); err != nil { return nil, fmt.Errorf("bad peephole data: %s", err) } } return res, nil }
// DeserializeGRU creates a GRU from some serialized // data about the GRU. func DeserializeGRU(d []byte) (*GRU, error) { slice, err := serializer.DeserializeSlice(d) if err != nil { return nil, err } if len(slice) != 5 { return nil, errors.New("invalid slice length in GRU") } hiddenSize, ok := slice[0].(serializer.Int) inputValue, ok1 := slice[1].(*lstmGate) resetGate, ok2 := slice[2].(*lstmGate) updateGate, ok3 := slice[3].(*lstmGate) initStateData, ok4 := slice[4].(serializer.Bytes) if !ok || !ok1 || !ok2 || !ok3 || !ok4 { return nil, errors.New("invalid types in GRU slice") } var initState autofunc.Variable if err := json.Unmarshal(initStateData, &initState); err != nil { return nil, errors.New("invalid init state in GRU slice") } return &GRU{ hiddenSize: int(hiddenSize), inputValue: inputValue, resetGate: resetGate, updateGate: updateGate, initState: &initState, }, nil }
// DeserializeStackedBlock deserializes a StackedBlock. func DeserializeStackedBlock(d []byte) (StackedBlock, error) { list, err := serializer.DeserializeSlice(d) if err != nil { return nil, err } res := make(StackedBlock, len(list)) for i, s := range list { var ok bool res[i], ok = s.(Block) if !ok { return nil, fmt.Errorf("layer %d (%T) is not Block", i, s) } } return res, nil }
// DeserializeBidirectional deserializes a previously // serialized Bidirectional instance. func DeserializeBidirectional(d []byte) (*Bidirectional, error) { slice, err := serializer.DeserializeSlice(d) if err != nil { return nil, err } if len(slice) != 3 { return nil, errors.New("invalid Bidirectional slice length") } s1, ok1 := slice[0].(seqfunc.RFunc) s2, ok2 := slice[1].(seqfunc.RFunc) s3, ok3 := slice[2].(seqfunc.RFunc) if !ok1 || !ok2 || !ok3 { return nil, errors.New("invalid Bidirectional slice types") } return &Bidirectional{s1, s2, s3}, nil }
func DeserializeNetwork(data []byte) (Network, error) { var res Network slice, err := serializer.DeserializeSlice(data) if err != nil { return nil, err } for _, x := range slice { if layer, ok := x.(Layer); ok { res = append(res, layer) } else { return nil, errors.New("slice element is not a Layer") } } return res, nil }
// DeserializeNetworkBlock deserializes a NetworkBlock. func DeserializeNetworkBlock(d []byte) (*NetworkBlock, error) { list, err := serializer.DeserializeSlice(d) if err != nil { return nil, err } else if len(list) != 3 { return nil, errors.New("bad network list length") } stateSize, ok := list[0].(serializer.Int) network, ok1 := list[1].(neuralnet.Network) initData, ok2 := list[2].(serializer.Bytes) if !ok || !ok1 || !ok2 { return nil, errors.New("bad types in network list") } var initState autofunc.Variable if err := json.Unmarshal(initData, &initState); err != nil { return nil, err } res := NewNetworkBlock(network, int(stateSize)) res.batcherBlock.Start = &initState return res, nil }