コード例 #1
0
ファイル: conv_layer_test.go プロジェクト: unixpickle/weakai
// TestConvLayerSerialize checks that a Network containing a ConvLayer
// survives a Serialize/deserialize round trip with its first two filters'
// data and its biases preserved (within 1e-6).
//
// The expected values below are presumably the parameters installed by
// convLayerTestInfo — TODO confirm against that helper if it changes.
func TestConvLayerSerialize(t *testing.T) {
	layer, _, _ := convLayerTestInfo()
	data, err := layer.Serialize()
	if err != nil {
		t.Fatal(err)
	}
	dataType := layer.SerializerType()

	l, err := serializer.GetDeserializer(dataType)(data)
	if err != nil {
		t.Fatal(err)
	}
	decodedNet, ok := l.(Network)
	if !ok {
		t.Fatalf("decoded layer was not a Network, but rather a %T", l)
	}

	// Fail cleanly (rather than panic) if the decoded structure is not the
	// shape this test expects.
	if len(decodedNet) == 0 {
		t.Fatal("decoded network has no layers")
	}
	convLayer, ok := decodedNet[0].(*ConvLayer)
	if !ok {
		t.Fatalf("decoded layer 0 was not a *ConvLayer, but rather a %T", decodedNet[0])
	}
	if len(convLayer.Filters) < 2 {
		t.Fatalf("expected at least 2 filters but got %d", len(convLayer.Filters))
	}

	expLists := [][]float64{
		{
			0.348, 0.299, 0.946, 0.806,
			0.101, 0.705, 0.821, 0.819,
			0.106, 0.348, 0.285, 0.133,
		},
		{
			0.293, 0.494, 0.148, 0.758,
			0.901, 0.050, 0.415, 0.892,
			0.736, 0.458, 0.465, 0.167,
		},
		{0.333, -0.255},
	}
	actualLists := [][]float64{
		convLayer.Filters[0].Data,
		convLayer.Filters[1].Data,
		convLayer.Biases.Vector,
	}

	for i, expected := range expLists {
		actual := actualLists[i]
		// A length mismatch would otherwise either panic on indexing or let
		// extra decoded values go unchecked.
		if len(actual) != len(expected) {
			t.Errorf("list %d: expected %d values but got %d", i, len(expected), len(actual))
			continue
		}
		for j, v := range expected {
			if math.Abs(actual[j]-v) > 1e-6 {
				t.Errorf("list %d does not match: got %v, expected %v", i, actual, expected)
				break
			}
		}
	}
}
コード例 #2
0
// TestMaxPoolingSerialize checks that a MaxPoolingLayer survives a
// Serialize/deserialize round trip with every configuration field intact.
func TestMaxPoolingSerialize(t *testing.T) {
	layer := &MaxPoolingLayer{3, 3, 10, 11, 2}
	encoded, err := layer.Serialize()
	if err != nil {
		t.Fatal(err)
	}
	layerType := layer.SerializerType()
	decoded, err := serializer.GetDeserializer(layerType)(encoded)
	if err != nil {
		t.Fatal(err)
	}
	// Bind to a new name so the original layer remains available for the
	// field-by-field comparison below.
	decodedLayer, ok := decoded.(*MaxPoolingLayer)
	if !ok {
		t.Fatalf("expected *MaxPoolingLayer but got %T", decoded)
	}
	// NOTE(review): relies on MaxPoolingLayer being a comparable struct; the
	// positional scalar literal above suggests it holds only scalar fields.
	if *decodedLayer != *layer {
		t.Errorf("decoded layer %+v does not match original %+v", *decodedLayer, *layer)
	}
}
コード例 #3
0
ファイル: dense_layer_test.go プロジェクト: unixpickle/weakai
// TestDenseSerialize checks that a DenseLayer can be reconstructed by its
// registered deserializer from both its own Serialize output and from plain
// json.Marshal output, with weights and biases preserved exactly.
func TestDenseSerialize(t *testing.T) {
	network, _, _ := denseTestInfo()
	layer, ok := network[0].(*DenseLayer)
	if !ok {
		t.Fatalf("test network layer 0 was not a *DenseLayer, but a %T", network[0])
	}

	normalEncoded, err := layer.Serialize()
	if err != nil {
		t.Fatal(err)
	}
	jsonEncoded, err := json.Marshal(layer)
	if err != nil {
		t.Fatal(err)
	}
	layerType := layer.SerializerType()

	// Expected weights and biases, presumably those installed by
	// denseTestInfo — TODO confirm against that helper if it changes.
	expLists := [][]float64{
		{1, 2, 3, -3, 2, -1},
		{-6, 9},
	}

	for i, encoded := range [][]byte{normalEncoded, jsonEncoded} {
		decoded, err := serializer.GetDeserializer(layerType)(encoded)
		if err != nil {
			t.Fatalf("%d: %v", i, err)
		}

		decodedLayer, ok := decoded.(*DenseLayer)
		if !ok {
			t.Fatalf("%d: decoded layer was not a *DenseLayer, but a %T", i, decoded)
		}

		actualLists := [][]float64{
			decodedLayer.Weights.Data.Vector,
			decodedLayer.Biases.Var.Vector,
		}

		for k, expected := range expLists {
			actual := actualLists[k]
			// Guard against index panics and unchecked trailing values.
			if len(actual) != len(expected) {
				t.Errorf("%d: list %d: expected %d values but got %d", i, k, len(expected), len(actual))
				continue
			}
			for j, v := range expected {
				if actual[j] != v {
					t.Errorf("%d: list %d does not match: got %v, expected %v", i, k, actual, expected)
					break
				}
			}
		}
	}
}
コード例 #4
0
ファイル: network_test.go プロジェクト: unixpickle/weakai
// TestNetworkSerialize checks that a two-layer Network (a DenseLayer followed
// by a Sigmoid) survives a Serialize/deserialize round trip with the same
// number of layers and the same concrete layer types, in order.
func TestNetworkSerialize(t *testing.T) {
	network := Network{
		&DenseLayer{InputCount: 3, OutputCount: 2},
		&Sigmoid{},
	}
	network.Randomize()

	encoded, err := network.Serialize()
	if err != nil {
		t.Fatal(err)
	}
	layerType := network.SerializerType()

	decoded, err := serializer.GetDeserializer(layerType)(encoded)
	if err != nil {
		t.Fatal(err)
	}

	// Network is asserted as a value type, so the message names Network,
	// not *Network.
	decodedNet, ok := decoded.(Network)
	if !ok {
		t.Fatalf("expected Network but got %T", decoded)
	}

	if len(network) != len(decodedNet) {
		t.Fatalf("expected %d layers but got %d", len(network), len(decodedNet))
	}

	if _, ok := decodedNet[0].(*DenseLayer); !ok {
		t.Fatalf("expected *DenseLayer but got %T", decodedNet[0])
	}

	if _, ok := decodedNet[1].(*Sigmoid); !ok {
		t.Fatalf("expected *Sigmoid but got %T", decodedNet[1])
	}
}