示例#1
0
// TestG1 cross-checks this package's G1 against golang.org/x/crypto/bn256.G1.
//
// For each scalar k (0, 1, and a batch of random values in [0, p), where p is
// the package-level bound — presumably the field modulus, confirm), it
// compares:
//   - String output,
//   - ScalarBaseMult,
//   - Add with 3·G,
//   - Neg.
//
// It also verifies that Unmarshal followed by Marshal reproduces the original
// encoding. Points are considered equal when their Marshal encodings are
// byte-identical.
func TestG1(t *testing.T) {
	// numRandomTests is how many random scalars are checked after the fixed
	// cases k=0 and k=1.
	const numRandomTests = 100

	// cmp reports an error when got and want do not marshal to the same bytes.
	cmp := func(got *G1, want *bn256.G1) error {
		if gotB, wantB := got.Marshal(), want.Marshal(); !bytes.Equal(gotB, wantB) {
			return fmt.Errorf("Got %v want %v", got, want)
		}
		return nil
	}

	// onetest runs every check for a single scalar k and returns the first
	// failure it encounters, or nil if all checks pass.
	onetest := func(k *big.Int) error {
		var (
			got  = new(G1).ScalarBaseMult(k)
			gotB = got.Marshal() // captured before the point is used by later checks
			want = new(bn256.G1).ScalarBaseMult(k)
		)
		if g, w := got.String(), want.String(); g != w {
			// TODO: Minor implementation difference causes String for
			// golang.org/x/crypto/bn256.G1 to return (1, -2) for k=1,
			// while (1, 65000549695646603732796438742359905742825358107623003571877145026864184071781)
			// for this package. The two are identical since
			// (-2 mod p) == 65000549695646603732796438742359905742825358107623003571877145026864184071781
			// So, ignore that difference.
			if k.Cmp(big.NewInt(1)) == 0 {
				w = "bn256.G1(1, 65000549695646603732796438742359905742825358107623003571877145026864184071781)"
			}
			if g != w {
				return fmt.Errorf("k=%v: String: Got %q, want %q", k, g, w)
			}
		}
		if err := cmp(got, want); err != nil {
			return fmt.Errorf("k=%v: ScalarBaseMult: %v", k, err)
		}
		// Add a fixed second point (3·G) so that addition is exercised even
		// when k is 0.
		if err := cmp(
			new(G1).Add(got, new(G1).ScalarBaseMult(big.NewInt(3))),
			new(bn256.G1).Add(want, new(bn256.G1).ScalarBaseMult(big.NewInt(3))),
		); err != nil {
			return fmt.Errorf("k=%v: Add: %v", k, err)
		}
		if err := cmp(new(G1).Neg(got), new(bn256.G1).Neg(want)); err != nil {
			return fmt.Errorf("k=%v: Neg: %v", k, err)
		}
		// Unmarshal and Marshal again; the encoding must round-trip exactly.
		unmarshaled, ok := new(G1).Unmarshal(gotB)
		if !ok {
			return fmt.Errorf("k=%v: Unmarshal failed", k)
		}
		again := unmarshaled.Marshal()
		if !bytes.Equal(gotB, again) {
			return fmt.Errorf("k=%v: Unmarshal+Marshal: Got %v, want %v", k, again, gotB)
		}
		return nil
	}

	// Fixed edge cases: the identity (k=0) and the generator itself (k=1).
	if err := onetest(big.NewInt(0)); err != nil {
		t.Error(err)
	}
	if err := onetest(big.NewInt(1)); err != nil {
		t.Error(err)
	}
	for i := 0; i < numRandomTests; i++ {
		k, err := rand.Int(rand.Reader, p)
		if err != nil {
			t.Fatal(err)
		}
		if err := onetest(k); err != nil {
			t.Errorf("%v (random test #%d)", err, i)
		}
	}
}