Example #1
0
// Generate writes a Size() method for every message in file that enables the
// gogoproto sizer option (map entry messages are skipped). When at least one
// method was written it also calls p.sizeVarint and p.sizeZigZag, which
// presumably emit the sov<localName>/soz<localName> helpers the generated code
// calls.
//
// For each field the generated Size() adds the key size plus the payload size
// (fixed width, varint, zig-zag varint or length-delimited). It then adds the
// size of any extensions and of XXX_unrecognized.
//
// Group fields are not supported and cause a panic, as does any field type not
// handled by the switch below.
func (p *size) Generate(file *generator.FileDescriptor) {
	p.PluginImports = generator.NewPluginImports(p.Generator)
	p.atleastOne = false
	p.localName = generator.FileName(file)
	protoPkg := p.NewImport("github.com/gogo/protobuf/proto")
	if !gogoproto.ImportsGoGoProto(file.FileDescriptorProto) {
		protoPkg = p.NewImport("github.com/golang/protobuf/proto")
	}
	for _, message := range file.Messages() {
		if !gogoproto.IsSizer(file.FileDescriptorProto, message.DescriptorProto) {
			continue
		}
		if message.DescriptorProto.GetOptions().GetMapEntry() {
			continue
		}
		p.atleastOne = true
		proto3 := gogoproto.IsProto3(file.FileDescriptorProto)

		ccTypeName := generator.CamelCaseSlice(message.TypeName())
		p.P(`func (m *`, ccTypeName, `) Size() (n int) {`)
		p.In()
		p.P(`var l int`)
		p.P(`_ = l`)
		for _, field := range message.Field {
			fieldname := p.GetFieldName(message, field)
			nullable := gogoproto.IsNullable(field)
			repeated := field.IsRepeated()
			isMessage := *field.Type == descriptor.FieldDescriptorProto_TYPE_MESSAGE
			isBytes := *field.Type == descriptor.FieldDescriptorProto_TYPE_BYTES
			// A singular field needs a nil guard when it may be nil:
			//   - nullable fields in proto2;
			//   - nullable message fields in proto3, which are still pointers,
			//     so an unset sub-message must not reach m.X.Size() below;
			//   - plain (non-custom) bytes fields.
			// Proto3 scalars are values and rely on the zero-value checks in the
			// switch instead. The same flag decides whether the guard is closed
			// after the switch.
			nilCheck := !repeated &&
				(((!proto3 || isMessage) && nullable) || (!gogoproto.IsCustomType(field) && isBytes))
			if repeated {
				p.P(`if len(m.`, fieldname, `) > 0 {`)
				p.In()
			} else if nilCheck {
				p.P(`if m.`, fieldname, ` != nil {`)
				p.In()
			}
			packed := field.IsPacked()
			_, wire := p.GoType(message, field)
			wireType := wireToType(wire)
			fieldNumber := field.GetNumber()
			if packed {
				// Packed repeated fields are written as one length-delimited record.
				wireType = proto.WireBytes
			}
			key := keySize(fieldNumber, wireType)
			switch *field.Type {
			case descriptor.FieldDescriptorProto_TYPE_DOUBLE,
				descriptor.FieldDescriptorProto_TYPE_FIXED64,
				descriptor.FieldDescriptorProto_TYPE_SFIXED64:
				if packed {
					p.P(`n+=`, strconv.Itoa(key), `+sov`, p.localName, `(uint64(len(m.`, fieldname, `)*8))`, `+len(m.`, fieldname, `)*8`)
				} else if repeated {
					p.P(`n+=`, strconv.Itoa(key+8), `*len(m.`, fieldname, `)`)
				} else if proto3 {
					p.P(`if m.`, fieldname, ` != 0 {`)
					p.In()
					p.P(`n+=`, strconv.Itoa(key+8))
					p.Out()
					p.P(`}`)
				} else {
					p.P(`n+=`, strconv.Itoa(key+8))
				}
			case descriptor.FieldDescriptorProto_TYPE_FLOAT,
				descriptor.FieldDescriptorProto_TYPE_FIXED32,
				descriptor.FieldDescriptorProto_TYPE_SFIXED32:
				if packed {
					p.P(`n+=`, strconv.Itoa(key), `+sov`, p.localName, `(uint64(len(m.`, fieldname, `)*4))`, `+len(m.`, fieldname, `)*4`)
				} else if repeated {
					p.P(`n+=`, strconv.Itoa(key+4), `*len(m.`, fieldname, `)`)
				} else if proto3 {
					p.P(`if m.`, fieldname, ` != 0 {`)
					p.In()
					p.P(`n+=`, strconv.Itoa(key+4))
					p.Out()
					p.P(`}`)
				} else {
					p.P(`n+=`, strconv.Itoa(key+4))
				}
			case descriptor.FieldDescriptorProto_TYPE_INT64,
				descriptor.FieldDescriptorProto_TYPE_UINT64,
				descriptor.FieldDescriptorProto_TYPE_UINT32,
				descriptor.FieldDescriptorProto_TYPE_ENUM,
				descriptor.FieldDescriptorProto_TYPE_INT32:
				if packed {
					p.P(`l = 0`)
					p.P(`for _, e := range m.`, fieldname, ` {`)
					p.In()
					p.P(`l+=sov`, p.localName, `(uint64(e))`)
					p.Out()
					p.P(`}`)
					p.P(`n+=`, strconv.Itoa(key), `+sov`, p.localName, `(uint64(l))+l`)
				} else if repeated {
					p.P(`for _, e := range m.`, fieldname, ` {`)
					p.In()
					p.P(`n+=`, strconv.Itoa(key), `+sov`, p.localName, `(uint64(e))`)
					p.Out()
					p.P(`}`)
				} else if proto3 {
					p.P(`if m.`, fieldname, ` != 0 {`)
					p.In()
					p.P(`n+=`, strconv.Itoa(key), `+sov`, p.localName, `(uint64(m.`, fieldname, `))`)
					p.Out()
					p.P(`}`)
				} else if nullable {
					p.P(`n+=`, strconv.Itoa(key), `+sov`, p.localName, `(uint64(*m.`, fieldname, `))`)
				} else {
					p.P(`n+=`, strconv.Itoa(key), `+sov`, p.localName, `(uint64(m.`, fieldname, `))`)
				}
			case descriptor.FieldDescriptorProto_TYPE_BOOL:
				if packed {
					p.P(`n+=`, strconv.Itoa(key), `+sov`, p.localName, `(uint64(len(m.`, fieldname, `)))`, `+len(m.`, fieldname, `)*1`)
				} else if repeated {
					p.P(`n+=`, strconv.Itoa(key+1), `*len(m.`, fieldname, `)`)
				} else if proto3 {
					p.P(`if m.`, fieldname, ` {`)
					p.In()
					p.P(`n+=`, strconv.Itoa(key+1))
					p.Out()
					p.P(`}`)
				} else {
					p.P(`n+=`, strconv.Itoa(key+1))
				}
			case descriptor.FieldDescriptorProto_TYPE_STRING:
				if repeated {
					p.P(`for _, s := range m.`, fieldname, ` { `)
					p.In()
					p.P(`l = len(s)`)
					p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
					p.Out()
					p.P(`}`)
				} else if proto3 {
					p.P(`l=len(m.`, fieldname, `)`)
					p.P(`if l > 0 {`)
					p.In()
					p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
					p.Out()
					p.P(`}`)
				} else if nullable {
					p.P(`l=len(*m.`, fieldname, `)`)
					p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
				} else {
					p.P(`l=len(m.`, fieldname, `)`)
					p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
				}
			case descriptor.FieldDescriptorProto_TYPE_GROUP:
				panic(fmt.Errorf("size does not support group %v", fieldname))
			case descriptor.FieldDescriptorProto_TYPE_MESSAGE:
				if generator.IsMap(file.FileDescriptorProto, field) {
					// Each map entry is sized as an embedded message with the key
					// as field 1 and the value as field 2.
					mapMsg := generator.GetMap(file.FileDescriptorProto, field)
					keyField, valueField := mapMsg.GetMapFields()
					_, keywire := p.GoType(nil, keyField)
					_, valuewire := p.GoType(nil, valueField)
					_, fieldwire := p.GoType(nil, field)
					fieldKeySize := keySize(field.GetNumber(), wireToType(fieldwire))
					keyKeySize := keySize(1, wireToType(keywire))
					valueKeySize := keySize(2, wireToType(valuewire))
					p.P(`for k, v := range m.`, fieldname, ` { `)
					p.In()
					p.P(`_ = k`)
					p.P(`_ = v`)
					sum := []string{strconv.Itoa(keyKeySize)}
					switch keyField.GetType() {
					case descriptor.FieldDescriptorProto_TYPE_DOUBLE,
						descriptor.FieldDescriptorProto_TYPE_FIXED64,
						descriptor.FieldDescriptorProto_TYPE_SFIXED64:
						sum = append(sum, `8`)
					case descriptor.FieldDescriptorProto_TYPE_FLOAT,
						descriptor.FieldDescriptorProto_TYPE_FIXED32,
						descriptor.FieldDescriptorProto_TYPE_SFIXED32:
						sum = append(sum, `4`)
					case descriptor.FieldDescriptorProto_TYPE_INT64,
						descriptor.FieldDescriptorProto_TYPE_UINT64,
						descriptor.FieldDescriptorProto_TYPE_UINT32,
						descriptor.FieldDescriptorProto_TYPE_ENUM,
						descriptor.FieldDescriptorProto_TYPE_INT32:
						sum = append(sum, `sov`+p.localName+`(uint64(k))`)
					case descriptor.FieldDescriptorProto_TYPE_BOOL:
						sum = append(sum, `1`)
					case descriptor.FieldDescriptorProto_TYPE_STRING,
						descriptor.FieldDescriptorProto_TYPE_BYTES:
						sum = append(sum, `len(k)+sov`+p.localName+`(uint64(len(k)))`)
					case descriptor.FieldDescriptorProto_TYPE_SINT32,
						descriptor.FieldDescriptorProto_TYPE_SINT64:
						sum = append(sum, `soz`+p.localName+`(uint64(k))`)
					}
					sum = append(sum, strconv.Itoa(valueKeySize))
					switch valueField.GetType() {
					case descriptor.FieldDescriptorProto_TYPE_DOUBLE,
						descriptor.FieldDescriptorProto_TYPE_FIXED64,
						descriptor.FieldDescriptorProto_TYPE_SFIXED64:
						sum = append(sum, strconv.Itoa(8))
					case descriptor.FieldDescriptorProto_TYPE_FLOAT,
						descriptor.FieldDescriptorProto_TYPE_FIXED32,
						descriptor.FieldDescriptorProto_TYPE_SFIXED32:
						sum = append(sum, strconv.Itoa(4))
					case descriptor.FieldDescriptorProto_TYPE_INT64,
						descriptor.FieldDescriptorProto_TYPE_UINT64,
						descriptor.FieldDescriptorProto_TYPE_UINT32,
						descriptor.FieldDescriptorProto_TYPE_ENUM,
						descriptor.FieldDescriptorProto_TYPE_INT32:
						sum = append(sum, `sov`+p.localName+`(uint64(v))`)
					case descriptor.FieldDescriptorProto_TYPE_BOOL:
						sum = append(sum, `1`)
					case descriptor.FieldDescriptorProto_TYPE_STRING,
						descriptor.FieldDescriptorProto_TYPE_BYTES:
						sum = append(sum, `len(v)+sov`+p.localName+`(uint64(len(v)))`)
					case descriptor.FieldDescriptorProto_TYPE_SINT32,
						descriptor.FieldDescriptorProto_TYPE_SINT64:
						sum = append(sum, `soz`+p.localName+`(uint64(v))`)
					case descriptor.FieldDescriptorProto_TYPE_MESSAGE:
						// NOTE(review): v.Size() is called without a nil check; a nil
						// message value in the map would panic here — confirm intended.
						p.P(`l=v.Size()`)
						sum = append(sum, `l+sov`+p.localName+`(uint64(l))`)
					}
					p.P(`mapEntrySize := `, strings.Join(sum, "+"))
					p.P(`n+=mapEntrySize+`, strconv.Itoa(fieldKeySize), `+sov`, p.localName, `(uint64(mapEntrySize))`)
					p.Out()
					p.P(`}`)
				} else if repeated {
					p.P(`for _, e := range m.`, fieldname, ` { `)
					p.In()
					p.P(`l=e.Size()`)
					p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
					p.Out()
					p.P(`}`)
				} else {
					// Nullable sub-messages are inside the nil guard opened above.
					p.P(`l=m.`, fieldname, `.Size()`)
					p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
				}
			case descriptor.FieldDescriptorProto_TYPE_BYTES:
				if !gogoproto.IsCustomType(field) {
					if repeated {
						p.P(`for _, b := range m.`, fieldname, ` { `)
						p.In()
						p.P(`l = len(b)`)
						p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
						p.Out()
						p.P(`}`)
					} else if proto3 {
						p.P(`l=len(m.`, fieldname, `)`)
						p.P(`if l > 0 {`)
						p.In()
						p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
						p.Out()
						p.P(`}`)
					} else {
						p.P(`l=len(m.`, fieldname, `)`)
						p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
					}
				} else {
					// Custom bytes types report their own encoded length via Size().
					if repeated {
						p.P(`for _, e := range m.`, fieldname, ` { `)
						p.In()
						p.P(`l=e.Size()`)
						p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
						p.Out()
						p.P(`}`)
					} else {
						p.P(`l=m.`, fieldname, `.Size()`)
						p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
					}
				}
			case descriptor.FieldDescriptorProto_TYPE_SINT32,
				descriptor.FieldDescriptorProto_TYPE_SINT64:
				if packed {
					p.P(`l = 0`)
					p.P(`for _, e := range m.`, fieldname, ` {`)
					p.In()
					p.P(`l+=soz`, p.localName, `(uint64(e))`)
					p.Out()
					p.P(`}`)
					p.P(`n+=`, strconv.Itoa(key), `+sov`, p.localName, `(uint64(l))+l`)
				} else if repeated {
					p.P(`for _, e := range m.`, fieldname, ` {`)
					p.In()
					p.P(`n+=`, strconv.Itoa(key), `+soz`, p.localName, `(uint64(e))`)
					p.Out()
					p.P(`}`)
				} else if proto3 {
					p.P(`if m.`, fieldname, ` != 0 {`)
					p.In()
					p.P(`n+=`, strconv.Itoa(key), `+soz`, p.localName, `(uint64(m.`, fieldname, `))`)
					p.Out()
					p.P(`}`)
				} else if nullable {
					p.P(`n+=`, strconv.Itoa(key), `+soz`, p.localName, `(uint64(*m.`, fieldname, `))`)
				} else {
					p.P(`n+=`, strconv.Itoa(key), `+soz`, p.localName, `(uint64(m.`, fieldname, `))`)
				}
			default:
				panic("not implemented")
			}
			if repeated || nilCheck {
				p.Out()
				p.P(`}`)
			}
		}
		if message.DescriptorProto.HasExtension() {
			p.P(`if m.XXX_extensions != nil {`)
			p.In()
			if gogoproto.HasExtensionsMap(file.FileDescriptorProto, message.DescriptorProto) {
				p.P(`n += `, protoPkg.Use(), `.SizeOfExtensionMap(m.XXX_extensions)`)
			} else {
				p.P(`n+=len(m.XXX_extensions)`)
			}
			p.Out()
			p.P(`}`)
		}
		if gogoproto.HasUnrecognized(file.FileDescriptorProto, message.DescriptorProto) {
			p.P(`if m.XXX_unrecognized != nil {`)
			p.In()
			p.P(`n+=len(m.XXX_unrecognized)`)
			p.Out()
			p.P(`}`)
		}
		p.P(`return n`)
		p.Out()
		p.P(`}`)
		p.P()
	}

	if !p.atleastOne {
		return
	}

	p.sizeVarint()
	p.sizeZigZag()

}
Example #2
0
// Generate writes a size method for every message in file that enables the
// gogoproto sizer option (method name "Size") or protosizer option (method
// name "ProtoSize"); map entry messages are skipped.
//
// Fields outside a oneof are sized inline via p.generateField. Each oneof is
// sized once, by calling the size method of its wrapper value when it is
// non-nil. Those wrapper methods are generated right after the message's own
// method. When at least one method was written, p.sizeVarint and p.sizeZigZag
// are called to emit the shared helpers.
func (p *size) Generate(file *generator.FileDescriptor) {
	p.PluginImports = generator.NewPluginImports(p.Generator)
	p.atleastOne = false
	p.localName = generator.FileName(file)
	p.typesPkg = p.NewImport("github.com/gogo/protobuf/types")
	protoPkg := p.NewImport("github.com/gogo/protobuf/proto")
	if !gogoproto.ImportsGoGoProto(file.FileDescriptorProto) {
		protoPkg = p.NewImport("github.com/golang/protobuf/proto")
	}
	// The file's syntax is the same for every message and field, so compute it once.
	proto3 := gogoproto.IsProto3(file.FileDescriptorProto)
	for _, message := range file.Messages() {
		sizeName := ""
		if gogoproto.IsSizer(file.FileDescriptorProto, message.DescriptorProto) {
			sizeName = "Size"
		} else if gogoproto.IsProtoSizer(file.FileDescriptorProto, message.DescriptorProto) {
			sizeName = "ProtoSize"
		} else {
			continue
		}
		if message.DescriptorProto.GetOptions().GetMapEntry() {
			continue
		}
		p.atleastOne = true
		ccTypeName := generator.CamelCaseSlice(message.TypeName())
		p.P(`func (m *`, ccTypeName, `) `, sizeName, `() (n int) {`)
		p.In()
		p.P(`var l int`)
		p.P(`_ = l`)
		// All members of one oneof map to the same struct field name; emit the
		// size call for that field only once.
		oneofs := make(map[string]struct{})
		for _, field := range message.Field {
			if field.OneofIndex == nil {
				p.generateField(proto3, file, message, field, sizeName)
				continue
			}
			fieldname := p.GetFieldName(message, field)
			if _, seen := oneofs[fieldname]; seen {
				continue
			}
			oneofs[fieldname] = struct{}{}
			p.P(`if m.`, fieldname, ` != nil {`)
			p.In()
			p.P(`n+=m.`, fieldname, `.`, sizeName, `()`)
			p.Out()
			p.P(`}`)
		}
		if message.DescriptorProto.HasExtension() {
			if gogoproto.HasExtensionsMap(file.FileDescriptorProto, message.DescriptorProto) {
				p.P(`n += `, protoPkg.Use(), `.SizeOfInternalExtension(m)`)
			} else {
				p.P(`if m.XXX_extensions != nil {`)
				p.In()
				p.P(`n+=len(m.XXX_extensions)`)
				p.Out()
				p.P(`}`)
			}
		}
		if gogoproto.HasUnrecognized(file.FileDescriptorProto, message.DescriptorProto) {
			p.P(`if m.XXX_unrecognized != nil {`)
			p.In()
			p.P(`n+=len(m.XXX_unrecognized)`)
			p.Out()
			p.P(`}`)
		}
		p.P(`return n`)
		p.Out()
		p.P(`}`)
		p.P()

		// Generate Size methods for the oneof wrapper types. The descriptor is
		// cloned because the vanity call below presumably rewrites field
		// options, and the original must stay untouched for other plugins.
		m := proto.Clone(message.DescriptorProto).(*descriptor.DescriptorProto)
		for _, f := range m.Field {
			if f.OneofIndex == nil {
				continue
			}
			ccTypeName := p.OneOfTypeName(message, f)
			p.P(`func (m *`, ccTypeName, `) `, sizeName, `() (n int) {`)
			p.In()
			p.P(`var l int`)
			p.P(`_ = l`)
			vanity.TurnOffNullableForNativeTypesWithoutDefaultsOnly(f)
			// proto3=false: presumably so that a set oneof member is sized even
			// when it holds its zero value — confirm against generateField.
			p.generateField(false, file, message, f, sizeName)
			p.P(`return n`)
			p.Out()
			p.P(`}`)
		}
	}

	if !p.atleastOne {
		return
	}

	p.sizeVarint()
	p.sizeZigZag()

}
Example #3
0
// Generate writes size tests and benchmarks for the messages in file. It
// handles every message that has a Size or ProtoSize method, skipping map
// entries.
//
// When the testgen option is set it emits Test<Type><SizeName>. That test
// marshals a populated message and checks that the size method agrees with
// the marshalled length and with proto.Size before and after marshalling.
//
// When the benchgen option is set it emits Benchmark<Type><SizeName>. That
// benchmark times the size method over 1000 pre-populated messages.
//
// The return value reports whether any function was written.
func (p *test) Generate(imports generator.PluginImports, file *generator.FileDescriptor) bool {
	wroteAny := false
	randPkg := imports.NewImport("math/rand")
	timePkg := imports.NewImport("time")
	testingPkg := imports.NewImport("testing")
	protoPkg := imports.NewImport("github.com/gogo/protobuf/proto")
	if !gogoproto.ImportsGoGoProto(file.FileDescriptorProto) {
		protoPkg = imports.NewImport("github.com/golang/protobuf/proto")
	}
	fileDesc := file.FileDescriptorProto
	for _, message := range file.Messages() {
		msgDesc := message.DescriptorProto
		var sizeName string
		switch {
		case gogoproto.IsSizer(fileDesc, msgDesc):
			sizeName = "Size"
		case gogoproto.IsProtoSizer(fileDesc, msgDesc):
			sizeName = "ProtoSize"
		default:
			continue
		}
		if msgDesc.GetOptions().GetMapEntry() {
			continue
		}
		typeName := generator.CamelCaseSlice(message.TypeName())

		if gogoproto.HasTestGen(fileDesc, msgDesc) {
			wroteAny = true
			// Package names are resolved (and marked used) only when code that
			// references them is emitted, in order of first reference.
			testingName := testingPkg.Use()
			timeName := timePkg.Use()
			randName := randPkg.Use()
			protoName := protoPkg.Use()

			p.P(`func Test`, typeName, sizeName, `(t *`, testingName, `.T) {`)
			p.In()
			p.P(`seed := `, timeName, `.Now().UnixNano()`)
			p.P(`popr := `, randName, `.New(`, randName, `.NewSource(seed))`)
			p.P(`p := NewPopulated`, typeName, `(popr, true)`)
			p.P(`size2 := `, protoName, `.Size(p)`)
			p.P(`dAtA, err := `, protoName, `.Marshal(p)`)
			p.P(`if err != nil {`)
			p.In()
			p.P(`t.Fatalf("seed = %d, err = %v", seed, err)`)
			p.Out()
			p.P(`}`)
			p.P(`size := p.`, sizeName, `()`)
			p.P(`if len(dAtA) != size {`)
			p.In()
			p.P(`t.Errorf("seed = %d, size %v != marshalled size %v", seed, size, len(dAtA))`)
			p.Out()
			p.P(`}`)
			p.P(`if size2 != size {`)
			p.In()
			p.P(`t.Errorf("seed = %d, size %v != before marshal proto.Size %v", seed, size, size2)`)
			p.Out()
			p.P(`}`)
			p.P(`size3 := `, protoName, `.Size(p)`)
			p.P(`if size3 != size {`)
			p.In()
			p.P(`t.Errorf("seed = %d, size %v != after marshal proto.Size %v", seed, size, size3)`)
			p.Out()
			p.P(`}`)
			p.Out()
			p.P(`}`)
			p.P()
		}

		if gogoproto.HasBenchGen(fileDesc, msgDesc) {
			wroteAny = true
			testingName := testingPkg.Use()
			randName := randPkg.Use()

			p.P(`func Benchmark`, typeName, sizeName, `(b *`, testingName, `.B) {`)
			p.In()
			p.P(`popr := `, randName, `.New(`, randName, `.NewSource(616))`)
			p.P(`total := 0`)
			p.P(`pops := make([]*`, typeName, `, 1000)`)
			p.P(`for i := 0; i < 1000; i++ {`)
			p.In()
			p.P(`pops[i] = NewPopulated`, typeName, `(popr, false)`)
			p.Out()
			p.P(`}`)
			p.P(`b.ResetTimer()`)
			p.P(`for i := 0; i < b.N; i++ {`)
			p.In()
			p.P(`total += pops[i%1000].`, sizeName, `()`)
			p.Out()
			p.P(`}`)
			p.P(`b.SetBytes(int64(total / b.N))`)
			p.Out()
			p.P(`}`)
			p.P()
		}
	}
	return wroteAny
}