// TestGetExtensionStability checks that reading extension E_FieldB1 twice
// yields equal values, both on a freshly built message and on a copy that has
// been through a Marshal/Unmarshal round trip.
//
// See another version of this test in proto/extensions_test.go.
func TestGetExtensionStability(t *testing.T) {
	// check fetches E_FieldB1 twice from m and reports whether both reads
	// compare equal via NinOptNative.Equal.
	check := func(m *NoExtensionsMap) bool {
		ext1, err := proto.GetExtension(m, E_FieldB1)
		if err != nil {
			t.Fatalf("GetExtension() failed: %s", err)
		}
		ext2, err := proto.GetExtension(m, E_FieldB1)
		if err != nil {
			t.Fatalf("GetExtension() failed: %s", err)
		}
		return ext1.(*NinOptNative).Equal(ext2)
	}
	msg := &NoExtensionsMap{Field1: proto.Int64(2)}
	ext0 := &NinOptNative{Field1: proto.Float64(1)}
	if err := proto.SetExtension(msg, E_FieldB1, ext0); err != nil {
		// Report the actual failure cause, not the value we tried to set.
		t.Fatalf("Could not set extension: %s", err)
	}
	if !check(msg) {
		t.Errorf("GetExtension() not stable before marshaling")
	}
	bb, err := proto.Marshal(msg)
	if err != nil {
		t.Fatalf("Marshal() failed: %s", err)
	}
	msg1 := &NoExtensionsMap{}
	err = proto.Unmarshal(bb, msg1)
	if err != nil {
		t.Fatalf("Unmarshal() failed: %s", err)
	}
	if !check(msg1) {
		t.Errorf("GetExtension() not stable after unmarshaling")
	}
}
// TestGetExtensionStability checks that two consecutive reads of
// pb.E_Ext_More return the identical value (pointer equality), both before
// marshaling and after a Marshal/Unmarshal round trip.
func TestGetExtensionStability(t *testing.T) {
	// check reads pb.E_Ext_More twice and reports whether both reads return
	// the same interface value.
	check := func(m *pb.MyMessage) bool {
		ext1, err := proto.GetExtension(m, pb.E_Ext_More)
		if err != nil {
			t.Fatalf("GetExtension() failed: %s", err)
		}
		ext2, err := proto.GetExtension(m, pb.E_Ext_More)
		if err != nil {
			t.Fatalf("GetExtension() failed: %s", err)
		}
		return ext1 == ext2
	}
	msg := &pb.MyMessage{Count: proto.Int32(4)}
	ext0 := &pb.Ext{}
	if err := proto.SetExtension(msg, pb.E_Ext_More, ext0); err != nil {
		// Report the actual failure cause, not the value we tried to set.
		t.Fatalf("Could not set extension: %s", err)
	}
	if !check(msg) {
		t.Errorf("GetExtension() not stable before marshaling")
	}
	bb, err := proto.Marshal(msg)
	if err != nil {
		t.Fatalf("Marshal() failed: %s", err)
	}
	msg1 := &pb.MyMessage{}
	err = proto.Unmarshal(bb, msg1)
	if err != nil {
		t.Fatalf("Unmarshal() failed: %s", err)
	}
	if !check(msg1) {
		t.Errorf("GetExtension() not stable after unmarshaling")
	}
}
// check verifies that extension ext is set on m. It then checks that both
// proto.GetExtension and proto.GetUnsafeExtension return a non-nil *float64
// equal to fieldA. Finally, it checks that proto.ClearExtension removes the
// extension. Any failure stops the calling test via t.Fatalf rather than
// panicking, so the rest of the test binary keeps running.
func check(t *testing.T, m extendable, fieldA float64, ext *proto.ExtensionDesc) {
	if !proto.HasExtension(m, ext) {
		t.Fatalf("expected extension to be set")
	}

	// Lookup through the extension descriptor.
	fieldA2Interface, err := proto.GetExtension(m, ext)
	if err != nil {
		t.Fatalf("GetExtension() failed: %v", err)
	}
	fieldA2, ok := fieldA2Interface.(*float64)
	if !ok || fieldA2 == nil {
		t.Fatalf("GetExtension() returned %T(%v), want non-nil *float64", fieldA2Interface, fieldA2Interface)
	}
	if fieldA != *fieldA2 {
		t.Fatalf("Expected %f got %f", fieldA, *fieldA2)
	}

	// Lookup through the raw field number.
	fieldA3Interface, err := proto.GetUnsafeExtension(m, ext.Field)
	if err != nil {
		t.Fatalf("GetUnsafeExtension() failed: %v", err)
	}
	fieldA3, ok := fieldA3Interface.(*float64)
	if !ok || fieldA3 == nil {
		t.Fatalf("GetUnsafeExtension() returned %T(%v), want non-nil *float64", fieldA3Interface, fieldA3Interface)
	}
	if fieldA != *fieldA3 {
		t.Fatalf("Expected %f got %f", fieldA, *fieldA3)
	}

	proto.ClearExtension(m, ext)
	if proto.HasExtension(m, ext) {
		t.Fatalf("expected extension to be cleared")
	}
}
Beispiel #4
0
// GetEnumValueCustomName returns the E_EnumvalueCustomname option of an enum
// value, or "" when the value has no options, the lookup fails, or the
// option holds a nil string.
func GetEnumValueCustomName(field *google_protobuf.EnumValueDescriptorProto) string {
	if field.Options == nil {
		return ""
	}
	v, err := proto.GetExtension(field.Options, E_EnumvalueCustomname)
	if err != nil {
		return ""
	}
	if name := v.(*string); name != nil {
		return *name
	}
	return ""
}
Beispiel #5
0
// GetMoreTags returns the E_Moretags option of a field as a *string.
// It yields nil when the field has no options or the lookup fails.
func GetMoreTags(field *google_protobuf.FieldDescriptorProto) *string {
	if field.Options == nil {
		return nil
	}
	v, err := proto.GetExtension(field.Options, E_Moretags)
	if err != nil {
		return nil
	}
	// A nil *string here is returned as-is, which is the same as "no tags".
	return v.(*string)
}
Beispiel #6
0
// GetCastType returns the E_Casttype option of a field, or "" when the field
// has no options, the lookup fails, or the option holds a nil string.
func GetCastType(field *google_protobuf.FieldDescriptorProto) string {
	if field.Options == nil {
		return ""
	}
	v, err := proto.GetExtension(field.Options, E_Casttype)
	if err != nil {
		return ""
	}
	castType := v.(*string)
	if castType == nil {
		return ""
	}
	return *castType
}
Beispiel #7
0
// GetModel returns the E_Model option attached to a message descriptor, or
// nil when the message has no options or the option is not present. Lookup
// errors are deliberately ignored and treated as "no model".
func GetModel(msg *generator.Descriptor) *ModelDescriptor {
	if msg.Options == nil {
		return nil
	}
	ext, _ := proto.GetExtension(msg.Options, E_Model)
	if ext == nil {
		return nil
	}
	return ext.(*ModelDescriptor)
}
Beispiel #8
0
// GetJoin returns the E_Join option attached to a field, or nil when the
// field has no options or the option is not present. Lookup errors are
// deliberately ignored and treated as "no join".
func GetJoin(field *pb.FieldDescriptorProto) *JoinDescriptor {
	if field.Options == nil {
		return nil
	}
	ext, _ := proto.GetExtension(field.Options, E_Join)
	if ext == nil {
		return nil
	}
	return ext.(*JoinDescriptor)
}
Beispiel #9
0
// GetHTTPRule returns the E_Http option of an RPC method, or nil when the
// method has no options or the option is absent (lookup errors are ignored).
func GetHTTPRule(method *pb.MethodDescriptorProto) *HttpRule {
	if method.Options != nil {
		if ext, _ := proto.GetExtension(method.Options, E_Http); ext != nil {
			return ext.(*HttpRule)
		}
	}
	return nil
}
// TestExtensionsRoundTrip exercises the extension API on pb.E_Ext_More. It
// checks that:
//   - a second SetExtension overwrites the first;
//   - ClearExtension makes GetExtension report proto.ErrMissingExtension;
//   - an unrelated descriptor (pb.E_X215) and an ill-typed value are rejected.
func TestExtensionsRoundTrip(t *testing.T) {
	msg := &pb.MyMessage{}
	ext1 := &pb.Ext{
		Data: proto.String("hi"),
	}
	ext2 := &pb.Ext{
		Data: proto.String("there"),
	}
	exists := proto.HasExtension(msg, pb.E_Ext_More)
	if exists {
		t.Error("Extension More present unexpectedly")
	}
	if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
		t.Error(err)
	}
	if err := proto.SetExtension(msg, pb.E_Ext_More, ext2); err != nil {
		t.Error(err)
	}
	e, err := proto.GetExtension(msg, pb.E_Ext_More)
	if err != nil {
		t.Error(err)
	}
	x, ok := e.(*pb.Ext)
	if !ok {
		t.Errorf("e has type %T, expected testdata.Ext", e)
	} else if *x.Data != "there" {
		t.Errorf("SetExtension failed to overwrite, got %+v, not 'there'", x)
	}
	proto.ClearExtension(msg, pb.E_Ext_More)
	// Report the error GetExtension actually returned.
	if _, err = proto.GetExtension(msg, pb.E_Ext_More); err != proto.ErrMissingExtension {
		t.Errorf("got %v, expected ErrMissingExtension", err)
	}
	// The following calls are all expected to fail. E_X215 presumably does
	// not extend MyMessage, and 12 is not a valid E_Ext_More value.
	if _, err := proto.GetExtension(msg, pb.E_X215); err == nil {
		t.Error("expected bad extension error, got nil")
	}
	if err := proto.SetExtension(msg, pb.E_X215, 12); err == nil {
		t.Error("expected extension err")
	}
	if err := proto.SetExtension(msg, pb.E_Ext_More, 12); err == nil {
		t.Error("expected some sort of type mismatch error, got nil")
	}
}
// GetMinLength returns the E_MinLength option of a field and whether it was
// set. A nil field, missing options, a failed lookup, or a nil value all
// yield (0, false).
func GetMinLength(field *pb.FieldDescriptorProto) (uint32, bool) {
	if field != nil && field.Options != nil {
		ext, _ := proto.GetExtension(field.Options, E_MinLength)
		if minLen, _ := ext.(*uint32); minLen != nil {
			return *minLen, true
		}
	}
	return 0, false
}
// GetPattern returns the E_Pattern option of a field and whether it was set.
// A nil field, missing options, a failed lookup, or a nil value all yield
// ("", false).
func GetPattern(field *pb.FieldDescriptorProto) (string, bool) {
	if field != nil && field.Options != nil {
		ext, _ := proto.GetExtension(field.Options, E_Pattern)
		if pattern, _ := ext.(*string); pattern != nil {
			return *pattern, true
		}
	}
	return "", false
}
Beispiel #13
0
// getCastValue reads the gogoproto castvalue option of a field and splits it
// with splitCPackageType into a package name and a type name.
//
// It returns empty strings and a nil error when the field has no options or
// the option holds a nil string. It returns empty strings and the lookup
// error when proto.GetExtension fails.
func getCastValue(field *descriptor.FieldDescriptorProto) (packageName string, typ string, err error) {
	if field.Options == nil {
		return "", "", nil
	}
	var ext interface{}
	ext, err = proto.GetExtension(field.Options, gogoproto.E_Castvalue)
	if err != nil {
		return "", "", err
	}
	castValue := ext.(*string)
	if castValue == nil {
		return "", "", nil
	}
	packageName, typ = splitCPackageType(*castValue)
	return packageName, typ, nil
}
Beispiel #14
0
// GetColumn returns the E_Column descriptor for a field. It returns nil for
// fields that carry an E_Join option instead, and for columns marked Ignore.
//
// Side effects: a nil field.Options is replaced with an empty FieldOptions.
// When no column option exists (and the field is not a join), a zero-valued
// ColumnDescriptor is attached to the options via proto.SetExtension. The
// SetExtension error is not checked.
func GetColumn(field *pb.FieldDescriptorProto) *ColumnDescriptor {
	if field.Options == nil {
		field.Options = &pb.FieldOptions{}
	}
	opts := field.Options

	ext, _ := proto.GetExtension(opts, E_Column)
	if ext == nil {
		// Join fields are not treated as columns.
		if join, _ := proto.GetExtension(opts, E_Join); join != nil {
			return nil
		}
		ext = &ColumnDescriptor{}
		proto.SetExtension(opts, E_Column, ext)
	}

	col := ext.(*ColumnDescriptor)
	if col.Ignore {
		return nil
	}
	return col
}
Beispiel #15
0
// GetScope returns the Scope of a method's E_Authz rule and whether it is
// non-empty. Missing options, a missing rule, or an empty scope all yield
// ("", false).
func GetScope(method *descriptor.MethodDescriptorProto) (string, bool) {
	if method.Options == nil {
		return "", false
	}
	ext, _ := proto.GetExtension(method.Options, E_Authz)
	if ext == nil {
		return "", false
	}
	if scope := ext.(*AuthzRule).Scope; scope != "" {
		return scope, true
	}
	return "", false
}
Beispiel #16
0
// FileHasBoolExtension reports whether the file's options hold a non-nil
// *bool for the given extension. Only presence is checked; the stored
// true/false value is not inspected. It returns false when the file has no
// options or the lookup fails.
func FileHasBoolExtension(file *descriptor.FileDescriptorProto, extension *proto.ExtensionDesc) bool {
	if file.Options == nil {
		return false
	}
	value, err := proto.GetExtension(file.Options, extension)
	return err == nil && value != nil && value.(*bool) != nil
}
Beispiel #17
0
// MessageHasBoolExtension reports whether the message's options hold a
// non-nil *bool for the given extension. Only presence is checked; the
// stored true/false value is not inspected. It returns false when the
// message has no options or the lookup fails.
func MessageHasBoolExtension(msg *descriptor.DescriptorProto, extension *proto.ExtensionDesc) bool {
	if msg.Options == nil {
		return false
	}
	value, err := proto.GetExtension(msg.Options, extension)
	return err == nil && value != nil && value.(*bool) != nil
}
Beispiel #18
0
// FieldHasStringExtension reports whether the field's options hold a non-nil
// *string for the given extension. It returns false when the field has no
// options or the lookup fails.
func FieldHasStringExtension(field *pb.FieldDescriptorProto, extension *proto.ExtensionDesc) bool {
	if field.Options == nil {
		return false
	}
	value, err := proto.GetExtension(field.Options, extension)
	return err == nil && value != nil && value.(*string) != nil
}
Beispiel #19
0
// EnumHasBoolExtension reports whether the enum's options hold a non-nil
// *bool for the given extension. Only presence is checked; the stored
// true/false value is not inspected. It returns false when the enum has no
// options or the lookup fails.
func EnumHasBoolExtension(enum *descriptor.EnumDescriptorProto, extension *proto.ExtensionDesc) bool {
	if enum.Options == nil {
		return false
	}
	value, err := proto.GetExtension(enum.Options, extension)
	return err == nil && value != nil && value.(*bool) != nil
}
Beispiel #20
0
// filterAPIs returns one *API for every method in methods that carries an
// E_Http option. Methods without the option are skipped.
//
// svcIndex is used to build each API's descIndexPath, of the form
// "6,<svcIndex>,2,<methodPosition>". Streaming and unary methods are counted
// separately: each API's index is its position among methods of the same kind.
// This is presumably the index into the generated service descriptor's
// stream/method tables named by descName; confirm against the generator.
func filterAPIs(service *pb.ServiceDescriptorProto, methods []*pb.MethodDescriptorProto, svcIndex int) []*API {
	apis := make([]*API, 0, len(methods))
	path := fmt.Sprintf("6,%d", svcIndex) // 6 means service.
	descName := "_" + service.GetName() + "_serviceDesc"

	var methodIdx, streamIdx int
	for i, method := range methods {
		stream := method.GetClientStreaming() || method.GetServerStreaming()

		// Advance the per-kind counter for every method, including ones that
		// are skipped below, so indices stay aligned with the method order.
		var index int
		if stream {
			index = streamIdx
			streamIdx++
		} else {
			index = methodIdx
			methodIdx++
		}

		// A method without options cannot carry an HTTP rule. Skip the lookup
		// rather than hand a nil options message to proto.GetExtension.
		if method.Options == nil {
			continue
		}
		v, _ := proto.GetExtension(method.Options, E_Http)
		info, _ := v.(*HttpRule)
		if info == nil {
			continue
		}

		apis = append(apis, &API{
			service:       service,
			method:        method,
			desc:          info,
			descIndexPath: fmt.Sprintf("%s,2,%d", path, i), // 2 means method in a service.
			descName:      descName,
			stream:        stream,
			index:         index,
		})
	}

	return apis
}
Beispiel #21
0
// marshalObject writes the JSON object encoding of v to out.
//
// v must be a non-nil pointer to a struct: the code dereferences it with
// reflect.Value.Elem and iterates its fields, which panics for other kinds.
// indent is the indentation prefix of the current nesting level. It is written
// before the closing brace when m.Indent is set.
//
// Regular fields are emitted in struct declaration order. They are followed by
// any registered proto2 extensions, in ascending field-number order. The first
// error from marshalField or proto.GetExtension is returned immediately.
// Otherwise the writer's accumulated error (out.err) is returned.
func (m *Marshaler) marshalObject(out *errWriter, v proto.Message, indent string) error {
	out.write("{")
	if m.Indent != "" {
		out.write("\n")
	}

	s := reflect.ValueOf(v).Elem()
	firstField := true
	for i := 0; i < s.NumField(); i++ {
		value := s.Field(i)
		valueField := s.Type().Field(i)
		// Fields named XXX_* are never serialized.
		if strings.HasPrefix(valueField.Name, "XXX_") {
			continue
		}

		// IsNil will panic on most value kinds.
		// Nil pointers, maps, slices, funcs, chans and interfaces (for example
		// an unset oneof) are omitted entirely.
		switch value.Kind() {
		case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
			if value.IsNil() {
				continue
			}
		}

		// Unless EmitDefaults is set, skip scalar fields that hold their zero
		// value (false, 0, 0.0, "").
		if !m.EmitDefaults {
			switch value.Kind() {
			case reflect.Bool:
				if !value.Bool() {
					continue
				}
			case reflect.Int32, reflect.Int64:
				if value.Int() == 0 {
					continue
				}
			case reflect.Uint32, reflect.Uint64:
				if value.Uint() == 0 {
					continue
				}
			case reflect.Float32, reflect.Float64:
				if value.Float() == 0 {
					continue
				}
			case reflect.String:
				if value.Len() == 0 {
					continue
				}
			}
		}

		// Oneof fields need special handling.
		if valueField.Tag.Get("protobuf_oneof") != "" {
			// value is an interface containing &T{real_value}.
			sv := value.Elem().Elem() // interface -> *T -> T
			value = sv.Field(0)
			valueField = sv.Type().Field(0)
		}
		prop := jsonProperties(valueField)
		// Separators go between fields, never before the first one.
		if !firstField {
			m.writeSep(out)
		}
		// If the map value is a cast type, it may not implement proto.Message, therefore
		// allow the struct tag to declare the underlying message type. Instead of changing
		// the signatures of the child types (and because prop.mvalue is not public), use
		// CustomType as a passer.
		if value.Kind() == reflect.Map {
			if tag := valueField.Tag.Get("protobuf"); tag != "" {
				// NOTE: the loop variable v shadows the message parameter v
				// inside this loop only.
				for _, v := range strings.Split(tag, ",") {
					if !strings.HasPrefix(v, "castvaluetype=") {
						continue
					}
					v = strings.TrimPrefix(v, "castvaluetype=")
					prop.CustomType = v
					break
				}
			}
		}
		if err := m.marshalField(out, prop, value, indent); err != nil {
			return err
		}
		firstField = false
	}

	// Handle proto2 extensions.
	// Only ids present in the message's extension map AND registered for
	// this message type (proto.RegisteredExtensions) are written. Other ids
	// are skipped silently.
	if ep, ok := v.(extendableProto); ok {
		extensions := proto.RegisteredExtensions(v)
		extensionMap := ep.ExtensionMap()
		// Sort extensions for stable output.
		ids := make([]int32, 0, len(extensionMap))
		for id := range extensionMap {
			ids = append(ids, id)
		}
		sort.Sort(int32Slice(ids))
		for _, id := range ids {
			desc := extensions[id]
			if desc == nil {
				// unknown extension
				continue
			}
			ext, extErr := proto.GetExtension(ep, desc)
			if extErr != nil {
				return extErr
			}
			value := reflect.ValueOf(ext)
			var prop proto.Properties
			prop.Parse(desc.Tag)
			// Extension keys are written as the bracketed descriptor name,
			// i.e. "[<desc.Name>]".
			prop.OrigName = fmt.Sprintf("[%s]", desc.Name)
			if !firstField {
				m.writeSep(out)
			}
			if err := m.marshalField(out, &prop, value, indent); err != nil {
				return err
			}
			firstField = false
		}

	}

	if m.Indent != "" {
		out.write("\n")
		out.write(indent)
	}
	out.write("}")
	return out.err
}
// TestMarshalUnmarshalRepeatedExtension checks that the repeated extension
// pb.E_RComplex survives a Marshal/Unmarshal round trip unchanged. It covers
// elements with scalar fields, repeated fields, and a mix of both.
func TestMarshalUnmarshalRepeatedExtension(t *testing.T) {
	// Add a repeated extension to the result.
	tests := []struct {
		name string
		ext  []*pb.ComplexExtension
	}{
		{
			"two fields",
			[]*pb.ComplexExtension{
				{First: proto.Int32(7)},
				{Second: proto.Int32(11)},
			},
		},
		{
			"repeated field",
			[]*pb.ComplexExtension{
				{Third: []int32{1000}},
				{Third: []int32{2000}},
			},
		},
		{
			"two fields and repeated field",
			[]*pb.ComplexExtension{
				{Third: []int32{1000}},
				{First: proto.Int32(9)},
				{Second: proto.Int32(21)},
				{Third: []int32{2000}},
			},
		},
	}
	for _, test := range tests {
		// Marshal message with a repeated extension.
		msg1 := new(pb.OtherMessage)
		err := proto.SetExtension(msg1, pb.E_RComplex, test.ext)
		if err != nil {
			t.Fatalf("[%s] Error setting extension: %v", test.name, err)
		}
		b, err := proto.Marshal(msg1)
		if err != nil {
			t.Fatalf("[%s] Error marshaling message: %v", test.name, err)
		}

		// Unmarshal and read the merged proto.
		msg2 := new(pb.OtherMessage)
		err = proto.Unmarshal(b, msg2)
		if err != nil {
			t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
		}
		e, err := proto.GetExtension(msg2, pb.E_RComplex)
		if err != nil {
			t.Fatalf("[%s] Error getting extension: %v", test.name, err)
		}
		// Checked assertion: an unexpected type fails the test with a
		// diagnostic instead of panicking the whole test binary.
		ext, ok := e.([]*pb.ComplexExtension)
		if !ok || ext == nil {
			t.Fatalf("[%s] Invalid extension: got %T", test.name, e)
		}
		if !reflect.DeepEqual(ext, test.ext) {
			t.Errorf("[%s] Wrong value for ComplexExtension: got: %v want: %v\n", test.name, ext, test.ext)
		}
	}
}
// TestUnmarshalRepeatingNonRepeatedExtension checks how Unmarshal handles
// several wire-format occurrences of the singular extension pb.E_Complex. Each
// occurrence is marshaled separately and the bytes are concatenated. The
// unmarshaled value must equal proto.Merge of all occurrences in order.
func TestUnmarshalRepeatingNonRepeatedExtension(t *testing.T) {
	// We may see multiple instances of the same extension in the wire
	// format. For example, the proto compiler may encode custom options in
	// this way. Here, we verify that we merge the extensions together.
	tests := []struct {
		name string
		ext  []*pb.ComplexExtension
	}{
		{
			"two fields",
			[]*pb.ComplexExtension{
				{First: proto.Int32(7)},
				{Second: proto.Int32(11)},
			},
		},
		{
			"repeated field",
			[]*pb.ComplexExtension{
				{Third: []int32{1000}},
				{Third: []int32{2000}},
			},
		},
		{
			"two fields and repeated field",
			[]*pb.ComplexExtension{
				{Third: []int32{1000}},
				{First: proto.Int32(9)},
				{Second: proto.Int32(21)},
				{Third: []int32{2000}},
			},
		},
	}
	for _, test := range tests {
		var buf bytes.Buffer
		var want pb.ComplexExtension

		// Generate a serialized representation of a repeated extension
		// by catenating bytes together.
		for i, e := range test.ext {
			// Merge to create the wanted proto.
			proto.Merge(&want, e)

			// serialize the message
			msg := new(pb.OtherMessage)
			err := proto.SetExtension(msg, pb.E_Complex, e)
			if err != nil {
				t.Fatalf("[%s] Error setting extension %d: %v", test.name, i, err)
			}
			b, err := proto.Marshal(msg)
			if err != nil {
				t.Fatalf("[%s] Error marshaling message %d: %v", test.name, i, err)
			}
			buf.Write(b)
		}

		// Unmarshal and read the merged proto.
		msg2 := new(pb.OtherMessage)
		err := proto.Unmarshal(buf.Bytes(), msg2)
		if err != nil {
			t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
		}
		e, err := proto.GetExtension(msg2, pb.E_Complex)
		if err != nil {
			t.Fatalf("[%s] Error getting extension: %v", test.name, err)
		}
		// Checked assertion: an unexpected type fails the test with a
		// diagnostic instead of panicking the whole test binary.
		ext, ok := e.(*pb.ComplexExtension)
		if !ok || ext == nil {
			t.Fatalf("[%s] Invalid extension: got %T", test.name, e)
		}
		if !reflect.DeepEqual(*ext, want) {
			t.Errorf("[%s] Wrong value for ComplexExtension: got: %v want: %v\n", test.name, ext, want)
		}
	}
}
Beispiel #24
0
// fieldToSchema converts a protobuf field descriptor into a JSON-schema
// fragment.
//
// It returns the schema definition and the name of any type the field refers
// to (the field's type name without its leading "."; "" for scalars). Fields
// whose options set limbo.E_HideInSwagger to true produce (nil, "").
//
// String fields pick up the optional format/pattern/minLength/maxLength
// options. Repeated fields are wrapped in an "array" schema, with optional
// minItems/maxItems. It panics on field types it does not handle.
func (g *jsonschema) fieldToSchema(field *pb.FieldDescriptorProto) (map[string]interface{}, string) {
	if field.Options != nil {
		v, _ := proto.GetExtension(field.Options, limbo.E_HideInSwagger)
		hidePtr, _ := v.(*bool)
		if hidePtr != nil && *hidePtr {
			return nil, ""
		}
	}

	var (
		def map[string]interface{}
		dep string
	)

	switch field.GetType() {

	case pb.FieldDescriptorProto_TYPE_BOOL:
		def = map[string]interface{}{
			"type": "boolean",
		}

	case pb.FieldDescriptorProto_TYPE_FLOAT:
		def = map[string]interface{}{
			"type":   "number",
			"format": "float",
		}

	case pb.FieldDescriptorProto_TYPE_DOUBLE:
		def = map[string]interface{}{
			"type":   "number",
			"format": "double",
		}

	case pb.FieldDescriptorProto_TYPE_FIXED32,
		pb.FieldDescriptorProto_TYPE_FIXED64,
		pb.FieldDescriptorProto_TYPE_UINT32,
		pb.FieldDescriptorProto_TYPE_UINT64:
		def = map[string]interface{}{
			"type": "integer",
		}

	case pb.FieldDescriptorProto_TYPE_INT32,
		pb.FieldDescriptorProto_TYPE_SFIXED32,
		pb.FieldDescriptorProto_TYPE_SINT32:
		def = map[string]interface{}{
			"type":   "integer",
			"format": "int32",
		}

	case pb.FieldDescriptorProto_TYPE_INT64,
		pb.FieldDescriptorProto_TYPE_SFIXED64,
		pb.FieldDescriptorProto_TYPE_SINT64:
		def = map[string]interface{}{
			"type":   "integer",
			"format": "int64",
		}

	case pb.FieldDescriptorProto_TYPE_STRING:
		def = map[string]interface{}{
			"type": "string",
		}
		if x, ok := limbo.GetFormat(field); ok {
			def["format"] = x
		}
		if x, ok := limbo.GetPattern(field); ok {
			def["pattern"] = x
		}
		if x, ok := limbo.GetMinLength(field); ok {
			def["minLength"] = x
		}
		if x, ok := limbo.GetMaxLength(field); ok {
			def["maxLength"] = x
		}

	case pb.FieldDescriptorProto_TYPE_BYTES:
		def = map[string]interface{}{
			"type":   "string",
			"format": "base64",
		}

	case pb.FieldDescriptorProto_TYPE_ENUM,
		pb.FieldDescriptorProto_TYPE_MESSAGE:
		// Enums and messages are both emitted as references to their own
		// definition, which the caller must also produce.
		dep = strings.TrimPrefix(field.GetTypeName(), ".")
		def = map[string]interface{}{
			"$ref": dep,
		}

	default:
		panic("unsupported " + field.GetType().String())

	}

	if field.IsRepeated() {
		def = map[string]interface{}{
			"type":  "array",
			"items": def,
		}
		if x, ok := limbo.GetMinItems(field); ok {
			def["minItems"] = x
		}
		if x, ok := limbo.GetMaxItems(field); ok {
			def["maxItems"] = x
		}
	}

	return def, dep
}
// TestGetExtensionDefaults checks GetExtension for every scalar extension
// type in pb.DefaultsMessage, in three states:
//   - unset: the declared default, or ErrMissingExtension when the extension
//     has no default;
//   - after SetExtension: the value that was set;
//   - after ClearExtension: back to the unset result.
//
// It also checks that the returned value has the extension's declared Go type.
func TestGetExtensionDefaults(t *testing.T) {
	var setFloat64 float64 = 1
	var setFloat32 float32 = 2
	var setInt32 int32 = 3
	var setInt64 int64 = 4
	var setUint32 uint32 = 5
	var setUint64 uint64 = 6
	var setBool = true
	var setBool2 = false
	var setString = "Goodnight string"
	var setBytes = []byte("Goodnight bytes")
	var setEnum = pb.DefaultsMessage_TWO

	type testcase struct {
		ext  *proto.ExtensionDesc // Extension we are testing.
		want interface{}          // Expected value of extension, or nil (meaning that GetExtension will fail).
		def  interface{}          // Expected value of extension after ClearExtension().
	}
	tests := []testcase{
		{pb.E_NoDefaultDouble, setFloat64, nil},
		{pb.E_NoDefaultFloat, setFloat32, nil},
		{pb.E_NoDefaultInt32, setInt32, nil},
		{pb.E_NoDefaultInt64, setInt64, nil},
		{pb.E_NoDefaultUint32, setUint32, nil},
		{pb.E_NoDefaultUint64, setUint64, nil},
		{pb.E_NoDefaultSint32, setInt32, nil},
		{pb.E_NoDefaultSint64, setInt64, nil},
		{pb.E_NoDefaultFixed32, setUint32, nil},
		{pb.E_NoDefaultFixed64, setUint64, nil},
		{pb.E_NoDefaultSfixed32, setInt32, nil},
		{pb.E_NoDefaultSfixed64, setInt64, nil},
		{pb.E_NoDefaultBool, setBool, nil},
		{pb.E_NoDefaultBool, setBool2, nil},
		{pb.E_NoDefaultString, setString, nil},
		{pb.E_NoDefaultBytes, setBytes, nil},
		{pb.E_NoDefaultEnum, setEnum, nil},
		{pb.E_DefaultDouble, setFloat64, float64(3.1415)},
		{pb.E_DefaultFloat, setFloat32, float32(3.14)},
		{pb.E_DefaultInt32, setInt32, int32(42)},
		{pb.E_DefaultInt64, setInt64, int64(43)},
		{pb.E_DefaultUint32, setUint32, uint32(44)},
		{pb.E_DefaultUint64, setUint64, uint64(45)},
		{pb.E_DefaultSint32, setInt32, int32(46)},
		{pb.E_DefaultSint64, setInt64, int64(47)},
		{pb.E_DefaultFixed32, setUint32, uint32(48)},
		{pb.E_DefaultFixed64, setUint64, uint64(49)},
		{pb.E_DefaultSfixed32, setInt32, int32(50)},
		{pb.E_DefaultSfixed64, setInt64, int64(51)},
		{pb.E_DefaultBool, setBool, true},
		{pb.E_DefaultBool, setBool2, true},
		{pb.E_DefaultString, setString, "Hello, string"},
		{pb.E_DefaultBytes, setBytes, []byte("Hello, bytes")},
		{pb.E_DefaultEnum, setEnum, pb.DefaultsMessage_ONE},
	}

	// checkVal reads test.ext from msg and compares it against valWant. A nil
	// valWant means GetExtension must fail with ErrMissingExtension.
	checkVal := func(test testcase, msg *pb.DefaultsMessage, valWant interface{}) error {
		val, err := proto.GetExtension(msg, test.ext)
		if err != nil {
			if valWant != nil {
				return fmt.Errorf("GetExtension(): %s", err)
			}
			if want := proto.ErrMissingExtension; err != want {
				return fmt.Errorf("Unexpected error: got %v, want %v", err, want)
			}
			return nil
		}

		// All proto2 extension values are either a pointer to a value or a slice of values.
		ty := reflect.TypeOf(val)
		tyWant := reflect.TypeOf(test.ext.ExtensionType)
		if got, want := ty, tyWant; got != want {
			return fmt.Errorf("unexpected reflect.TypeOf(): got %v want %v", got, want)
		}
		tye := ty.Elem()
		tyeWant := tyWant.Elem()
		if got, want := tye, tyeWant; got != want {
			return fmt.Errorf("unexpected reflect.TypeOf().Elem(): got %v want %v", got, want)
		}

		// Check the name of the type of the value.
		// If it is an enum it will be type int32 with the name of the enum.
		// Compare against the expected element type, not against tye itself.
		if got, want := tye.Name(), tyeWant.Name(); got != want {
			return fmt.Errorf("unexpected reflect.TypeOf().Elem().Name(): got %v want %v", got, want)
		}

		// Check that value is what we expect.
		// If we have a pointer in val, get the value it points to.
		valExp := val
		if ty.Kind() == reflect.Ptr {
			valExp = reflect.ValueOf(val).Elem().Interface()
		}
		if got, want := valExp, valWant; !reflect.DeepEqual(got, want) {
			return fmt.Errorf("unexpected reflect.DeepEqual(): got %v want %v", got, want)
		}

		return nil
	}

	// setTo converts test.want into the Go type SetExtension expects for
	// test.ext. Pointer-typed extensions get a freshly allocated pointer to a
	// copy of the value; other types (slices) are passed through unchanged.
	setTo := func(test testcase) interface{} {
		v := reflect.ValueOf(test.want)
		if typ := reflect.TypeOf(test.ext.ExtensionType); typ.Kind() == reflect.Ptr {
			v = reflect.New(typ).Elem()
			v.Set(reflect.New(v.Type().Elem()))
			v.Elem().Set(reflect.ValueOf(test.want))
		}
		return v.Interface()
	}

	for _, test := range tests {
		msg := &pb.DefaultsMessage{}
		name := test.ext.Name

		// Check the initial value.
		if err := checkVal(test, msg, test.def); err != nil {
			t.Errorf("%s: %v", name, err)
		}

		// Set the per-type value and check value.
		name = fmt.Sprintf("%s (set to %T %v)", name, test.want, test.want)
		if err := proto.SetExtension(msg, test.ext, setTo(test)); err != nil {
			t.Errorf("%s: SetExtension(): %v", name, err)
			continue
		}
		if err := checkVal(test, msg, test.want); err != nil {
			t.Errorf("%s: %v", name, err)
			continue
		}

		// Clear and check that the value reverts to the unset result.
		name += " (cleared)"
		proto.ClearExtension(msg, test.ext)
		if err := checkVal(test, msg, test.def); err != nil {
			t.Errorf("%s: %v", name, err)
		}
	}
}
Beispiel #26
0
// findMethods collects the methods of service that end up with an
// authentication rule (E_Authn) or an authorization rule (E_Authz).
//
// The service-level E_DefaultAuthn and E_DefaultAuthz options act as
// defaults. When a default exists, a method without its own rule is given an
// empty rule. Every resulting rule is passed through the default's Inherit and
// then SetDefaults.
//
// Methods that end up with neither rule are omitted. Each returned
// authMethod's Name is "/<package>.<service>/<method>".
func (g *svcauth) findMethods(file *generator.FileDescriptor, service *pb.ServiceDescriptorProto) []*authMethod {
	methods := make([]*authMethod, 0, len(service.Method))

	var (
		defaultAuthnInfo *AuthnRule
		defaultAuthzInfo *AuthzRule
	)

	// Both defaults live on the service options. Lookup errors are ignored,
	// which leaves the corresponding default nil.
	if service.Options != nil {
		v, _ := proto.GetExtension(service.Options, E_DefaultAuthn)
		defaultAuthnInfo, _ = v.(*AuthnRule)

		v, _ = proto.GetExtension(service.Options, E_DefaultAuthz)
		defaultAuthzInfo, _ = v.(*AuthzRule)
	}

	for _, method := range service.Method {
		// Read the per-method rules. A method without options has none, so
		// skip the lookup rather than pass nil options to proto.GetExtension.
		var authnExt, authzExt interface{}
		if method.Options != nil {
			authnExt, _ = proto.GetExtension(method.Options, E_Authn)
			authzExt, _ = proto.GetExtension(method.Options, E_Authz)
		}

		// authn
		authnInfo, _ := authnExt.(*AuthnRule)
		if authnInfo == nil && defaultAuthnInfo != nil {
			authnInfo = &AuthnRule{}
		}
		if authnInfo != nil {
			// NOTE(review): when the method has its own rule but the service
			// has no default, Inherit is invoked on a nil *AuthnRule.
			// Presumably Inherit tolerates a nil receiver; confirm.
			authnInfo = defaultAuthnInfo.Inherit(authnInfo)
			authnInfo.SetDefaults()
		}

		// authz
		authzInfo, _ := authzExt.(*AuthzRule)
		if authzInfo == nil && defaultAuthzInfo != nil {
			authzInfo = &AuthzRule{}
		}
		if authzInfo != nil {
			// NOTE(review): same nil-receiver assumption as for authn above.
			authzInfo = defaultAuthzInfo.Inherit(authzInfo)
			authzInfo.SetDefaults()
		}

		if authnInfo == nil && authzInfo == nil {
			continue
		}

		methods = append(methods, &authMethod{
			file:    file,
			service: service,
			method:  method,
			Authn:   authnInfo,
			Authz:   authzInfo,
			Name:    fmt.Sprintf("/%s.%s/%s", file.GetPackage(), service.GetName(), method.GetName()),
		})
	}

	return methods
}