//See another version of this test in proto/extensions_test.go func TestGetExtensionStability(t *testing.T) { check := func(m *NoExtensionsMap) bool { ext1, err := proto.GetExtension(m, E_FieldB1) if err != nil { t.Fatalf("GetExtension() failed: %s", err) } ext2, err := proto.GetExtension(m, E_FieldB1) if err != nil { t.Fatalf("GetExtension() failed: %s", err) } return ext1.(*NinOptNative).Equal(ext2) } msg := &NoExtensionsMap{Field1: proto.Int64(2)} ext0 := &NinOptNative{Field1: proto.Float64(1)} if err := proto.SetExtension(msg, E_FieldB1, ext0); err != nil { t.Fatalf("Could not set ext1: %s", ext0) } if !check(msg) { t.Errorf("GetExtension() not stable before marshaling") } bb, err := proto.Marshal(msg) if err != nil { t.Fatalf("Marshal() failed: %s", err) } msg1 := &NoExtensionsMap{} err = proto.Unmarshal(bb, msg1) if err != nil { t.Fatalf("Unmarshal() failed: %s", err) } if !check(msg1) { t.Errorf("GetExtension() not stable after unmarshaling") } }
func TestGetExtensionStability(t *testing.T) { check := func(m *pb.MyMessage) bool { ext1, err := proto.GetExtension(m, pb.E_Ext_More) if err != nil { t.Fatalf("GetExtension() failed: %s", err) } ext2, err := proto.GetExtension(m, pb.E_Ext_More) if err != nil { t.Fatalf("GetExtension() failed: %s", err) } return ext1 == ext2 } msg := &pb.MyMessage{Count: proto.Int32(4)} ext0 := &pb.Ext{} if err := proto.SetExtension(msg, pb.E_Ext_More, ext0); err != nil { t.Fatalf("Could not set ext1: %s", ext0) } if !check(msg) { t.Errorf("GetExtension() not stable before marshaling") } bb, err := proto.Marshal(msg) if err != nil { t.Fatalf("Marshal() failed: %s", err) } msg1 := &pb.MyMessage{} err = proto.Unmarshal(bb, msg1) if err != nil { t.Fatalf("Unmarshal() failed: %s", err) } if !check(msg1) { t.Errorf("GetExtension() not stable after unmarshaling") } }
func check(t *testing.T, m extendable, fieldA float64, ext *proto.ExtensionDesc) { if !proto.HasExtension(m, ext) { t.Fatalf("expected extension to be set") } fieldA2Interface, err := proto.GetExtension(m, ext) if err != nil { panic(err) } fieldA2 := fieldA2Interface.(*float64) if fieldA != *fieldA2 { t.Fatalf("Expected %f got %f", fieldA, *fieldA2) } fieldA3Interface, err := proto.GetUnsafeExtension(m, ext.Field) if err != nil { panic(err) } fieldA3 := fieldA3Interface.(*float64) if fieldA != *fieldA3 { t.Fatalf("Expected %f got %f", fieldA, *fieldA3) } proto.ClearExtension(m, ext) if proto.HasExtension(m, ext) { t.Fatalf("expected extension to be cleared") } }
func GetJsonTag(field *google_protobuf.FieldDescriptorProto) *string { if field.Options != nil { v, err := proto.GetExtension(field.Options, E_Jsontag) if err == nil && v.(*string) != nil { return (v.(*string)) } } return nil }
func GetCustomName(field *google_protobuf.FieldDescriptorProto) string { if field.Options != nil { v, err := proto.GetExtension(field.Options, E_Customname) if err == nil && v.(*string) != nil { return *(v.(*string)) } } return "" }
func TestExtensionsRoundTrip(t *testing.T) { msg := &pb.MyMessage{} ext1 := &pb.Ext{ Data: proto.String("hi"), } ext2 := &pb.Ext{ Data: proto.String("there"), } exists := proto.HasExtension(msg, pb.E_Ext_More) if exists { t.Error("Extension More present unexpectedly") } if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil { t.Error(err) } if err := proto.SetExtension(msg, pb.E_Ext_More, ext2); err != nil { t.Error(err) } e, err := proto.GetExtension(msg, pb.E_Ext_More) if err != nil { t.Error(err) } x, ok := e.(*pb.Ext) if !ok { t.Errorf("e has type %T, expected testdata.Ext", e) } else if *x.Data != "there" { t.Errorf("SetExtension failed to overwrite, got %+v, not 'there'", x) } proto.ClearExtension(msg, pb.E_Ext_More) if _, err = proto.GetExtension(msg, pb.E_Ext_More); err != proto.ErrMissingExtension { t.Errorf("got %v, expected ErrMissingExtension", e) } if _, err := proto.GetExtension(msg, pb.E_X215); err == nil { t.Error("expected bad extension error, got nil") } if err := proto.SetExtension(msg, pb.E_X215, 12); err == nil { t.Error("expected extension err") } if err := proto.SetExtension(msg, pb.E_Ext_More, 12); err == nil { t.Error("expected some sort of type mismatch error, got nil") } }
func getCastType(field *descriptor.FieldDescriptorProto) (packageName string, typ string, err error) { if field.Options != nil { v, err := proto.GetExtension(field.Options, gogoproto.E_Casttype) if err == nil && v.(*string) != nil { ctype := *(v.(*string)) packageName, typ = splitCPackageType(ctype) return packageName, typ, nil } } return "", "", err }
func FileHasBoolExtension(file *descriptor.FileDescriptorProto, extension *proto.ExtensionDesc) bool { if file.Options == nil { return false } value, err := proto.GetExtension(file.Options, extension) if err != nil { return false } if value == nil { return false } if value.(*bool) == nil { return false } return true }
func EnumHasBoolExtension(enum *descriptor.EnumDescriptorProto, extension *proto.ExtensionDesc) bool { if enum.Options == nil { return false } value, err := proto.GetExtension(enum.Options, extension) if err != nil { return false } if value == nil { return false } if value.(*bool) == nil { return false } return true }
func MessageHasBoolExtension(msg *descriptor.DescriptorProto, extension *proto.ExtensionDesc) bool { if msg.Options == nil { return false } value, err := proto.GetExtension(msg.Options, extension) if err != nil { return false } if value == nil { return false } if value.(*bool) == nil { return false } return true }
func TestGetExtensionDefaults(t *testing.T) { var setFloat64 float64 = 1 var setFloat32 float32 = 2 var setInt32 int32 = 3 var setInt64 int64 = 4 var setUint32 uint32 = 5 var setUint64 uint64 = 6 var setBool = true var setBool2 = false var setString = "Goodnight string" var setBytes = []byte("Goodnight bytes") var setEnum = pb.DefaultsMessage_TWO type testcase struct { ext *proto.ExtensionDesc // Extension we are testing. want interface{} // Expected value of extension, or nil (meaning that GetExtension will fail). def interface{} // Expected value of extension after ClearExtension(). } tests := []testcase{ {pb.E_NoDefaultDouble, setFloat64, nil}, {pb.E_NoDefaultFloat, setFloat32, nil}, {pb.E_NoDefaultInt32, setInt32, nil}, {pb.E_NoDefaultInt64, setInt64, nil}, {pb.E_NoDefaultUint32, setUint32, nil}, {pb.E_NoDefaultUint64, setUint64, nil}, {pb.E_NoDefaultSint32, setInt32, nil}, {pb.E_NoDefaultSint64, setInt64, nil}, {pb.E_NoDefaultFixed32, setUint32, nil}, {pb.E_NoDefaultFixed64, setUint64, nil}, {pb.E_NoDefaultSfixed32, setInt32, nil}, {pb.E_NoDefaultSfixed64, setInt64, nil}, {pb.E_NoDefaultBool, setBool, nil}, {pb.E_NoDefaultBool, setBool2, nil}, {pb.E_NoDefaultString, setString, nil}, {pb.E_NoDefaultBytes, setBytes, nil}, {pb.E_NoDefaultEnum, setEnum, nil}, {pb.E_DefaultDouble, setFloat64, float64(3.1415)}, {pb.E_DefaultFloat, setFloat32, float32(3.14)}, {pb.E_DefaultInt32, setInt32, int32(42)}, {pb.E_DefaultInt64, setInt64, int64(43)}, {pb.E_DefaultUint32, setUint32, uint32(44)}, {pb.E_DefaultUint64, setUint64, uint64(45)}, {pb.E_DefaultSint32, setInt32, int32(46)}, {pb.E_DefaultSint64, setInt64, int64(47)}, {pb.E_DefaultFixed32, setUint32, uint32(48)}, {pb.E_DefaultFixed64, setUint64, uint64(49)}, {pb.E_DefaultSfixed32, setInt32, int32(50)}, {pb.E_DefaultSfixed64, setInt64, int64(51)}, {pb.E_DefaultBool, setBool, true}, {pb.E_DefaultBool, setBool2, true}, {pb.E_DefaultString, setString, "Hello, string"}, {pb.E_DefaultBytes, setBytes, []byte("Hello, bytes")}, {pb.E_DefaultEnum, setEnum, pb.DefaultsMessage_ONE}, } checkVal := func(test testcase, msg *pb.DefaultsMessage, valWant interface{}) error { val, err := proto.GetExtension(msg, test.ext) if err != nil { if valWant != nil { return fmt.Errorf("GetExtension(): %s", err) } if want := proto.ErrMissingExtension; err != want { return fmt.Errorf("Unexpected error: got %v, want %v", err, want) } return nil } // All proto2 extension values are either a pointer to a value or a slice of values. ty := reflect.TypeOf(val) tyWant := reflect.TypeOf(test.ext.ExtensionType) if got, want := ty, tyWant; got != want { return fmt.Errorf("unexpected reflect.TypeOf(): got %v want %v", got, want) } tye := ty.Elem() tyeWant := tyWant.Elem() if got, want := tye, tyeWant; got != want { return fmt.Errorf("unexpected reflect.TypeOf().Elem(): got %v want %v", got, want) } // Check the name of the type of the value. // If it is an enum it will be type int32 with the name of the enum. if got, want := tye.Name(), tye.Name(); got != want { return fmt.Errorf("unexpected reflect.TypeOf().Elem().Name(): got %v want %v", got, want) } // Check that value is what we expect. // If we have a pointer in val, get the value it points to. valExp := val if ty.Kind() == reflect.Ptr { valExp = reflect.ValueOf(val).Elem().Interface() } if got, want := valExp, valWant; !reflect.DeepEqual(got, want) { return fmt.Errorf("unexpected reflect.DeepEqual(): got %v want %v", got, want) } return nil } setTo := func(test testcase) interface{} { setTo := reflect.ValueOf(test.want) if typ := reflect.TypeOf(test.ext.ExtensionType); typ.Kind() == reflect.Ptr { setTo = reflect.New(typ).Elem() setTo.Set(reflect.New(setTo.Type().Elem())) setTo.Elem().Set(reflect.ValueOf(test.want)) } return setTo.Interface() } for _, test := range tests { msg := &pb.DefaultsMessage{} name := test.ext.Name // Check the initial value. if err := checkVal(test, msg, test.def); err != nil { t.Errorf("%s: %v", name, err) } // Set the per-type value and check value. name = fmt.Sprintf("%s (set to %T %v)", name, test.want, test.want) if err := proto.SetExtension(msg, test.ext, setTo(test)); err != nil { t.Errorf("%s: SetExtension(): %v", name, err) continue } if err := checkVal(test, msg, test.want); err != nil { t.Errorf("%s: %v", name, err) continue } // Set and check the value. name += " (cleared)" proto.ClearExtension(msg, test.ext) if err := checkVal(test, msg, test.def); err != nil { t.Errorf("%s: %v", name, err) } } }