Ejemplo n.º 1
0
func (g *Generator) GetFieldName(message *Descriptor, field *descriptor.FieldDescriptorProto) string {
	// GetFieldName returns the Go struct field name generated for field in
	// message. The base name is the CamelCased proto field name. When more than
	// one override applies, the strongest wins, in this order:
	//   1. oneof member -> CamelCased name of the enclosing oneof declaration
	//   2. gogoproto embed -> EmbedFieldName of the field's Go type
	//   3. gogoproto customname -> the custom name verbatim
	// If the chosen name is listed in methodNames, or is "Size" on a message
	// without the protosizer option, an underscore is appended to avoid a clash
	// with a generated method.

	// Evaluated up front, regardless of whether the field is embedded.
	// NOTE(review): GoType may have side effects; confirm before moving this
	// into the embed case.
	goTyp, _ := g.GoType(message, field)

	name := CamelCase(*field.Name)
	switch {
	case field.OneofIndex != nil:
		name = CamelCase(message.OneofDecl[int(*field.OneofIndex)].GetName())
	case gogoproto.IsEmbed(field):
		name = EmbedFieldName(goTyp)
	case gogoproto.IsCustomName(field):
		name = gogoproto.GetCustomName(field)
	}

	clash := name == "Size" && !gogoproto.IsProtoSizer(message.file, message.DescriptorProto)
	for _, reserved := range methodNames {
		if reserved == name {
			clash = true
			break
		}
	}
	if clash {
		return name + "_"
	}
	return name
}
Ejemplo n.º 2
0
// Generate emits, for every message in file that enables gogoproto's sizer
// option (method name "Size") or protosizer option ("ProtoSize"), a method that
// returns the message's encoded length in bytes. Map-entry messages are
// skipped. Every oneof member's wrapper type also gets such a method. If at
// least one method was written, the varint/zigzag size helpers are emitted too.
func (p *size) Generate(file *generator.FileDescriptor) {
	p.PluginImports = generator.NewPluginImports(p.Generator)
	p.atleastOne = false
	p.localName = generator.FileName(file)
	p.typesPkg = p.NewImport("github.com/maditya/protobuf/types")
	protoPkg := p.NewImport("github.com/maditya/protobuf/proto")
	if !gogoproto.ImportsGoGoProto(file.FileDescriptorProto) {
		protoPkg = p.NewImport("github.com/golang/protobuf/proto")
	}

	fileDesc := file.FileDescriptorProto
	for _, message := range file.Messages() {
		msgDesc := message.DescriptorProto

		var sizeName string
		switch {
		case gogoproto.IsSizer(fileDesc, msgDesc):
			sizeName = "Size"
		case gogoproto.IsProtoSizer(fileDesc, msgDesc):
			sizeName = "ProtoSize"
		default:
			continue
		}
		if msgDesc.GetOptions().GetMapEntry() {
			continue
		}
		p.atleastOne = true

		typeName := generator.CamelCaseSlice(message.TypeName())
		p.P(`func (m *`, typeName, `) `, sizeName, `() (n int) {`)
		p.In()
		p.P(`var l int`)
		p.P(`_ = l`)

		// All members of a oneof share one struct field, so size it only once.
		emittedOneofs := make(map[string]struct{})
		for _, field := range message.Field {
			if field.OneofIndex == nil {
				p.generateField(gogoproto.IsProto3(fileDesc), file, message, field, sizeName)
				continue
			}
			oneofField := p.GetFieldName(message, field)
			if _, done := emittedOneofs[oneofField]; done {
				continue
			}
			emittedOneofs[oneofField] = struct{}{}
			p.P(`if m.`, oneofField, ` != nil {`)
			p.In()
			p.P(`n+=m.`, oneofField, `.`, sizeName, `()`)
			p.Out()
			p.P(`}`)
		}

		if msgDesc.HasExtension() {
			if gogoproto.HasExtensionsMap(fileDesc, msgDesc) {
				p.P(`n += `, protoPkg.Use(), `.SizeOfInternalExtension(m)`)
			} else {
				p.P(`if m.XXX_extensions != nil {`)
				p.In()
				p.P(`n+=len(m.XXX_extensions)`)
				p.Out()
				p.P(`}`)
			}
		}
		if gogoproto.HasUnrecognized(fileDesc, msgDesc) {
			p.P(`if m.XXX_unrecognized != nil {`)
			p.In()
			p.P(`n+=len(m.XXX_unrecognized)`)
			p.Out()
			p.P(`}`)
		}
		p.P(`return n`)
		p.Out()
		p.P(`}`)
		p.P()

		// Size methods for the oneof wrapper types. The fields come from a
		// clone of the descriptor, presumably so that the vanity adjustment
		// below does not alter the original message descriptor.
		clone := proto.Clone(msgDesc).(*descriptor.DescriptorProto)
		for _, f := range clone.Field {
			if f.OneofIndex == nil {
				continue
			}
			wrapperName := p.OneOfTypeName(message, f)
			p.P(`func (m *`, wrapperName, `) `, sizeName, `() (n int) {`)
			p.In()
			p.P(`var l int`)
			p.P(`_ = l`)
			vanity.TurnOffNullableForNativeTypesWithoutDefaultsOnly(f)
			p.generateField(false, file, message, f, sizeName)
			p.P(`return n`)
			p.Out()
			p.P(`}`)
		}
	}

	if p.atleastOne {
		p.sizeVarint()
		p.sizeZigZag()
	}
}
Ejemplo n.º 3
0
// Generate writes test code for every non-map-entry message in file that has
// a Size or ProtoSize method:
//   - with testgen enabled, a Test<Type><SizeName> function that compares the
//     size method against the marshalled length and against proto.Size taken
//     before and after marshalling;
//   - with benchgen enabled, a Benchmark<Type><SizeName> function that calls
//     the size method over 1000 pre-populated messages.
// It reports whether any function was written.
func (p *test) Generate(imports generator.PluginImports, file *generator.FileDescriptor) bool {
	randPkg := imports.NewImport("math/rand")
	timePkg := imports.NewImport("time")
	testingPkg := imports.NewImport("testing")
	protoPkg := imports.NewImport("github.com/maditya/protobuf/proto")
	if !gogoproto.ImportsGoGoProto(file.FileDescriptorProto) {
		protoPkg = imports.NewImport("github.com/golang/protobuf/proto")
	}

	fd := file.FileDescriptorProto
	wrote := false
	for _, message := range file.Messages() {
		md := message.DescriptorProto
		typeName := generator.CamelCaseSlice(message.TypeName())

		var sizeName string
		switch {
		case gogoproto.IsSizer(fd, md):
			sizeName = "Size"
		case gogoproto.IsProtoSizer(fd, md):
			sizeName = "ProtoSize"
		default:
			continue
		}
		if md.GetOptions().GetMapEntry() {
			continue
		}

		if gogoproto.HasTestGen(fd, md) {
			wrote = true
			// Use() is called only on paths that emit code referencing the
			// package, presumably because it marks the import as needed.
			testingName := testingPkg.Use()
			timeName := timePkg.Use()
			randName := randPkg.Use()
			protoName := protoPkg.Use()

			p.P(`func Test`, typeName, sizeName, `(t *`, testingName, `.T) {`)
			p.In()
			p.P(`seed := `, timeName, `.Now().UnixNano()`)
			p.P(`popr := `, randName, `.New(`, randName, `.NewSource(seed))`)
			p.P(`p := NewPopulated`, typeName, `(popr, true)`)
			p.P(`size2 := `, protoName, `.Size(p)`)
			p.P(`data, err := `, protoName, `.Marshal(p)`)
			p.P(`if err != nil {`)
			p.In()
			p.P(`t.Fatalf("seed = %d, err = %v", seed, err)`)
			p.Out()
			p.P(`}`)
			p.P(`size := p.`, sizeName, `()`)
			p.P(`if len(data) != size {`)
			p.In()
			p.P(`t.Errorf("seed = %d, size %v != marshalled size %v", seed, size, len(data))`)
			p.Out()
			p.P(`}`)
			p.P(`if size2 != size {`)
			p.In()
			p.P(`t.Errorf("seed = %d, size %v != before marshal proto.Size %v", seed, size, size2)`)
			p.Out()
			p.P(`}`)
			p.P(`size3 := `, protoName, `.Size(p)`)
			p.P(`if size3 != size {`)
			p.In()
			p.P(`t.Errorf("seed = %d, size %v != after marshal proto.Size %v", seed, size, size3)`)
			p.Out()
			p.P(`}`)
			p.Out()
			p.P(`}`)
			p.P()
		}

		if gogoproto.HasBenchGen(fd, md) {
			wrote = true
			testingName := testingPkg.Use()
			randName := randPkg.Use()

			p.P(`func Benchmark`, typeName, sizeName, `(b *`, testingName, `.B) {`)
			p.In()
			p.P(`popr := `, randName, `.New(`, randName, `.NewSource(616))`)
			p.P(`total := 0`)
			p.P(`pops := make([]*`, typeName, `, 1000)`)
			p.P(`for i := 0; i < 1000; i++ {`)
			p.In()
			p.P(`pops[i] = NewPopulated`, typeName, `(popr, false)`)
			p.Out()
			p.P(`}`)
			p.P(`b.ResetTimer()`)
			p.P(`for i := 0; i < b.N; i++ {`)
			p.In()
			p.P(`total += pops[i%1000].`, sizeName, `()`)
			p.Out()
			p.P(`}`)
			p.P(`b.SetBytes(int64(total / b.N))`)
			p.Out()
			p.P(`}`)
			p.P()
		}
	}
	return wrote
}
Ejemplo n.º 4
0
// generateField emits the Go statements that serialize one field of message
// into the generated MarshalTo body. The emitted code writes into a byte slice
// named `data` at offset `i`, and on failure returns `0, err` (or a
// required-not-set error for missing required fields).
//
// proto3 selects proto3 semantics. Singular scalars are skipped when they equal
// their zero value, and repeated scalars are packed. numGen supplies unique
// suffixes for emitted temporaries (f<N>, data<N>, j<N>, x<N>, n<N>) so that
// many fields can be emitted into the same function without name clashes.
// When p.unsafe is set, fixed-width values are written with the unsafeFixed*
// helpers instead of encodeFixed*/callFixed*. For floats, this applies only if
// the field is not a casttype.
func (p *marshalto) generateField(proto3 bool, numGen NumGen, file *generator.FileDescriptor, message *generator.Descriptor, field *descriptor.FieldDescriptorProto) {
	fieldname := p.GetOneOfFieldName(message, field)
	nullable := gogoproto.IsNullable(field)
	repeated := field.IsRepeated()
	required := field.IsRequired()

	// protoSizer: the message exposes ProtoSize() instead of Size().
	protoSizer := gogoproto.IsProtoSizer(file.FileDescriptorProto, message.DescriptorProto)
	doNilCheck := gogoproto.NeedsNilCheck(proto3, field)
	// Open a guard around the whole field encoding:
	//   - required pointer fields return an error when nil;
	//   - repeated fields are skipped when empty;
	//   - fields needing a nil check are skipped when nil.
	// The matching closing brace is emitted at the end of this function.
	if required && nullable {
		p.P(`if m.`, fieldname, `== nil {`)
		p.In()
		// p.protoPkg is golang/protobuf's proto when the file does not import
		// gogo.proto (see Generate). Otherwise NewRequiredNotSetError is used
		// and receives the field name.
		if !gogoproto.ImportsGoGoProto(file.FileDescriptorProto) {
			p.P(`return 0, new(`, p.protoPkg.Use(), `.RequiredNotSetError)`)
		} else {
			p.P(`return 0, `, p.protoPkg.Use(), `.NewRequiredNotSetError("`, field.GetName(), `")`)
		}
		p.Out()
		p.P(`} else {`)
	} else if repeated {
		p.P(`if len(m.`, fieldname, `) > 0 {`)
		p.In()
	} else if doNilCheck {
		p.P(`if m.`, fieldname, ` != nil {`)
		p.In()
	}
	// Repeated scalars are packed when the descriptor says so, and always in
	// proto3. A packed field is written as one length-delimited record, so its
	// key uses the bytes wire type.
	packed := field.IsPacked() || (proto3 && field.IsRepeated() && generator.IsScalar(field))
	wireType := field.WireType()
	fieldNumber := field.GetNumber()
	if packed {
		wireType = proto.WireBytes
	}
	switch *field.Type {
	case descriptor.FieldDescriptorProto_TYPE_DOUBLE:
		// Safe path: convert with math.Float64bits and write 8 bytes.
		if !p.unsafe || gogoproto.IsCastType(field) {
			if packed {
				p.encodeKey(fieldNumber, wireType)
				p.callVarint(`len(m.`, fieldname, `) * 8`)
				p.P(`for _, num := range m.`, fieldname, ` {`)
				p.In()
				p.P(`f`, numGen.Next(), ` := `, p.mathPkg.Use(), `.Float64bits(float64(num))`)
				p.encodeFixed64("f" + numGen.Current())
				p.Out()
				p.P(`}`)
			} else if repeated {
				p.P(`for _, num := range m.`, fieldname, ` {`)
				p.In()
				p.encodeKey(fieldNumber, wireType)
				p.P(`f`, numGen.Next(), ` := `, p.mathPkg.Use(), `.Float64bits(float64(num))`)
				p.encodeFixed64("f" + numGen.Current())
				p.Out()
				p.P(`}`)
			} else if proto3 {
				p.P(`if m.`, fieldname, ` != 0 {`)
				p.In()
				p.encodeKey(fieldNumber, wireType)
				p.callFixed64(p.mathPkg.Use(), `.Float64bits(float64(m.`+fieldname, `))`)
				p.Out()
				p.P(`}`)
			} else if !nullable {
				p.encodeKey(fieldNumber, wireType)
				p.callFixed64(p.mathPkg.Use(), `.Float64bits(float64(m.`+fieldname, `))`)
			} else {
				p.encodeKey(fieldNumber, wireType)
				p.callFixed64(p.mathPkg.Use(), `.Float64bits(float64(*m.`+fieldname, `))`)
			}
		} else {
			// Unsafe path: the value is written via unsafeFixed64.
			if packed {
				p.encodeKey(fieldNumber, wireType)
				p.callVarint(`len(m.`, fieldname, `) * 8`)
				p.P(`for _, num := range m.`, fieldname, ` {`)
				p.In()
				p.unsafeFixed64("num", "float64")
				p.Out()
				p.P(`}`)
			} else if repeated {
				p.P(`for _, num := range m.`, fieldname, ` {`)
				p.In()
				p.encodeKey(fieldNumber, wireType)
				p.unsafeFixed64("num", "float64")
				p.Out()
				p.P(`}`)
			} else if proto3 {
				p.P(`if m.`, fieldname, ` != 0 {`)
				p.In()
				p.encodeKey(fieldNumber, wireType)
				p.unsafeFixed64(`m.`+fieldname, "float64")
				p.Out()
				p.P(`}`)
			} else if !nullable {
				p.encodeKey(fieldNumber, wireType)
				p.unsafeFixed64(`m.`+fieldname, "float64")
			} else {
				p.encodeKey(fieldNumber, wireType)
				p.unsafeFixed64(`*m.`+fieldname, `float64`)
			}
		}
	case descriptor.FieldDescriptorProto_TYPE_FLOAT:
		// Same structure as DOUBLE, but with 4-byte values.
		if !p.unsafe || gogoproto.IsCastType(field) {
			if packed {
				p.encodeKey(fieldNumber, wireType)
				p.callVarint(`len(m.`, fieldname, `) * 4`)
				p.P(`for _, num := range m.`, fieldname, ` {`)
				p.In()
				p.P(`f`, numGen.Next(), ` := `, p.mathPkg.Use(), `.Float32bits(float32(num))`)
				p.encodeFixed32("f" + numGen.Current())
				p.Out()
				p.P(`}`)
			} else if repeated {
				p.P(`for _, num := range m.`, fieldname, ` {`)
				p.In()
				p.encodeKey(fieldNumber, wireType)
				p.P(`f`, numGen.Next(), ` := `, p.mathPkg.Use(), `.Float32bits(float32(num))`)
				p.encodeFixed32("f" + numGen.Current())
				p.Out()
				p.P(`}`)
			} else if proto3 {
				p.P(`if m.`, fieldname, ` != 0 {`)
				p.In()
				p.encodeKey(fieldNumber, wireType)
				p.callFixed32(p.mathPkg.Use(), `.Float32bits(float32(m.`+fieldname, `))`)
				p.Out()
				p.P(`}`)
			} else if !nullable {
				p.encodeKey(fieldNumber, wireType)
				p.callFixed32(p.mathPkg.Use(), `.Float32bits(float32(m.`+fieldname, `))`)
			} else {
				p.encodeKey(fieldNumber, wireType)
				p.callFixed32(p.mathPkg.Use(), `.Float32bits(float32(*m.`+fieldname, `))`)
			}
		} else {
			if packed {
				p.encodeKey(fieldNumber, wireType)
				p.callVarint(`len(m.`, fieldname, `) * 4`)
				p.P(`for _, num := range m.`, fieldname, ` {`)
				p.In()
				p.unsafeFixed32("num", "float32")
				p.Out()
				p.P(`}`)
			} else if repeated {
				p.P(`for _, num := range m.`, fieldname, ` {`)
				p.In()
				p.encodeKey(fieldNumber, wireType)
				p.unsafeFixed32("num", "float32")
				p.Out()
				p.P(`}`)
			} else if proto3 {
				p.P(`if m.`, fieldname, ` != 0 {`)
				p.In()
				p.encodeKey(fieldNumber, wireType)
				p.unsafeFixed32(`m.`+fieldname, `float32`)
				p.Out()
				p.P(`}`)
			} else if !nullable {
				p.encodeKey(fieldNumber, wireType)
				p.unsafeFixed32(`m.`+fieldname, `float32`)
			} else {
				p.encodeKey(fieldNumber, wireType)
				p.unsafeFixed32(`*m.`+fieldname, "float32")
			}
		}
	case descriptor.FieldDescriptorProto_TYPE_INT64,
		descriptor.FieldDescriptorProto_TYPE_UINT64,
		descriptor.FieldDescriptorProto_TYPE_INT32,
		descriptor.FieldDescriptorProto_TYPE_UINT32,
		descriptor.FieldDescriptorProto_TYPE_ENUM:
		if packed {
			// The packed payload length is not known up front. The emitted code
			// therefore varint-encodes into a scratch buffer sized for the worst
			// case (10 bytes per value), then writes the length and copies the
			// used prefix.
			jvar := "j" + numGen.Next()
			p.P(`data`, numGen.Next(), ` := make([]byte, len(m.`, fieldname, `)*10)`)
			p.P(`var `, jvar, ` int`)
			// Signed types are converted to uint64 first; negative values are
			// sign-extended by the conversion.
			if *field.Type == descriptor.FieldDescriptorProto_TYPE_INT64 ||
				*field.Type == descriptor.FieldDescriptorProto_TYPE_INT32 {
				p.P(`for _, num1 := range m.`, fieldname, ` {`)
				p.In()
				p.P(`num := uint64(num1)`)
			} else {
				p.P(`for _, num := range m.`, fieldname, ` {`)
				p.In()
			}
			p.P(`for num >= 1<<7 {`)
			p.In()
			p.P(`data`, numGen.Current(), `[`, jvar, `] = uint8(uint64(num)&0x7f|0x80)`)
			p.P(`num >>= 7`)
			p.P(jvar, `++`)
			p.Out()
			p.P(`}`)
			p.P(`data`, numGen.Current(), `[`, jvar, `] = uint8(num)`)
			p.P(jvar, `++`)
			p.Out()
			p.P(`}`)
			p.encodeKey(fieldNumber, wireType)
			p.callVarint(jvar)
			p.P(`i += copy(data[i:], data`, numGen.Current(), `[:`, jvar, `])`)
		} else if repeated {
			p.P(`for _, num := range m.`, fieldname, ` {`)
			p.In()
			p.encodeKey(fieldNumber, wireType)
			p.callVarint("num")
			p.Out()
			p.P(`}`)
		} else if proto3 {
			p.P(`if m.`, fieldname, ` != 0 {`)
			p.In()
			p.encodeKey(fieldNumber, wireType)
			p.callVarint(`m.`, fieldname)
			p.Out()
			p.P(`}`)
		} else if !nullable {
			p.encodeKey(fieldNumber, wireType)
			p.callVarint(`m.`, fieldname)
		} else {
			p.encodeKey(fieldNumber, wireType)
			p.callVarint(`*m.`, fieldname)
		}
	case descriptor.FieldDescriptorProto_TYPE_FIXED64,
		descriptor.FieldDescriptorProto_TYPE_SFIXED64:
		if !p.unsafe {
			if packed {
				p.encodeKey(fieldNumber, wireType)
				p.callVarint(`len(m.`, fieldname, `) * 8`)
				p.P(`for _, num := range m.`, fieldname, ` {`)
				p.In()
				p.encodeFixed64("num")
				p.Out()
				p.P(`}`)
			} else if repeated {
				p.P(`for _, num := range m.`, fieldname, ` {`)
				p.In()
				p.encodeKey(fieldNumber, wireType)
				p.encodeFixed64("num")
				p.Out()
				p.P(`}`)
			} else if proto3 {
				p.P(`if m.`, fieldname, ` != 0 {`)
				p.In()
				p.encodeKey(fieldNumber, wireType)
				p.callFixed64("m." + fieldname)
				p.Out()
				p.P(`}`)
			} else if !nullable {
				p.encodeKey(fieldNumber, wireType)
				p.callFixed64("m." + fieldname)
			} else {
				p.encodeKey(fieldNumber, wireType)
				p.callFixed64("*m." + fieldname)
			}
		} else {
			// The unsafe helpers need the concrete Go type name of the value.
			typeName := "int64"
			if *field.Type == descriptor.FieldDescriptorProto_TYPE_FIXED64 {
				typeName = "uint64"
			}
			if packed {
				p.encodeKey(fieldNumber, wireType)
				p.callVarint(`len(m.`, fieldname, `) * 8`)
				p.P(`for _, num := range m.`, fieldname, ` {`)
				p.In()
				p.unsafeFixed64("num", typeName)
				p.Out()
				p.P(`}`)
			} else if repeated {
				p.P(`for _, num := range m.`, fieldname, ` {`)
				p.In()
				p.encodeKey(fieldNumber, wireType)
				p.unsafeFixed64("num", typeName)
				p.Out()
				p.P(`}`)
			} else if proto3 {
				p.P(`if m.`, fieldname, ` != 0 {`)
				p.In()
				p.encodeKey(fieldNumber, wireType)
				p.unsafeFixed64("m."+fieldname, typeName)
				p.Out()
				p.P(`}`)
			} else if !nullable {
				p.encodeKey(fieldNumber, wireType)
				p.unsafeFixed64("m."+fieldname, typeName)
			} else {
				p.encodeKey(fieldNumber, wireType)
				p.unsafeFixed64("*m."+fieldname, typeName)
			}
		}
	case descriptor.FieldDescriptorProto_TYPE_FIXED32,
		descriptor.FieldDescriptorProto_TYPE_SFIXED32:
		if !p.unsafe {
			if packed {
				p.encodeKey(fieldNumber, wireType)
				p.callVarint(`len(m.`, fieldname, `) * 4`)
				p.P(`for _, num := range m.`, fieldname, ` {`)
				p.In()
				p.encodeFixed32("num")
				p.Out()
				p.P(`}`)
			} else if repeated {
				p.P(`for _, num := range m.`, fieldname, ` {`)
				p.In()
				p.encodeKey(fieldNumber, wireType)
				p.encodeFixed32("num")
				p.Out()
				p.P(`}`)
			} else if proto3 {
				p.P(`if m.`, fieldname, ` != 0 {`)
				p.In()
				p.encodeKey(fieldNumber, wireType)
				p.callFixed32("m." + fieldname)
				p.Out()
				p.P(`}`)
			} else if !nullable {
				p.encodeKey(fieldNumber, wireType)
				p.callFixed32("m." + fieldname)
			} else {
				p.encodeKey(fieldNumber, wireType)
				p.callFixed32("*m." + fieldname)
			}
		} else {
			typeName := "int32"
			if *field.Type == descriptor.FieldDescriptorProto_TYPE_FIXED32 {
				typeName = "uint32"
			}
			if packed {
				p.encodeKey(fieldNumber, wireType)
				p.callVarint(`len(m.`, fieldname, `) * 4`)
				p.P(`for _, num := range m.`, fieldname, ` {`)
				p.In()
				p.unsafeFixed32("num", typeName)
				p.Out()
				p.P(`}`)
			} else if repeated {
				p.P(`for _, num := range m.`, fieldname, ` {`)
				p.In()
				p.encodeKey(fieldNumber, wireType)
				p.unsafeFixed32("num", typeName)
				p.Out()
				p.P(`}`)
			} else if proto3 {
				p.P(`if m.`, fieldname, ` != 0 {`)
				p.In()
				p.encodeKey(fieldNumber, wireType)
				p.unsafeFixed32("m."+fieldname, typeName)
				p.Out()
				p.P(`}`)
			} else if !nullable {
				p.encodeKey(fieldNumber, wireType)
				p.unsafeFixed32("m."+fieldname, typeName)
			} else {
				p.encodeKey(fieldNumber, wireType)
				p.unsafeFixed32("*m."+fieldname, typeName)
			}
		}
	case descriptor.FieldDescriptorProto_TYPE_BOOL:
		// Each bool is written as a single byte (0 or 1), so a packed payload
		// is exactly len(m.<field>) bytes long.
		if packed {
			p.encodeKey(fieldNumber, wireType)
			p.callVarint(`len(m.`, fieldname, `)`)
			p.P(`for _, b := range m.`, fieldname, ` {`)
			p.In()
			p.P(`if b {`)
			p.In()
			p.P(`data[i] = 1`)
			p.Out()
			p.P(`} else {`)
			p.In()
			p.P(`data[i] = 0`)
			p.Out()
			p.P(`}`)
			p.P(`i++`)
			p.Out()
			p.P(`}`)
		} else if repeated {
			p.P(`for _, b := range m.`, fieldname, ` {`)
			p.In()
			p.encodeKey(fieldNumber, wireType)
			p.P(`if b {`)
			p.In()
			p.P(`data[i] = 1`)
			p.Out()
			p.P(`} else {`)
			p.In()
			p.P(`data[i] = 0`)
			p.Out()
			p.P(`}`)
			p.P(`i++`)
			p.Out()
			p.P(`}`)
		} else if proto3 {
			p.P(`if m.`, fieldname, ` {`)
			p.In()
			p.encodeKey(fieldNumber, wireType)
			p.P(`if m.`, fieldname, ` {`)
			p.In()
			p.P(`data[i] = 1`)
			p.Out()
			p.P(`} else {`)
			p.In()
			p.P(`data[i] = 0`)
			p.Out()
			p.P(`}`)
			p.P(`i++`)
			p.Out()
			p.P(`}`)
		} else if !nullable {
			p.encodeKey(fieldNumber, wireType)
			p.P(`if m.`, fieldname, ` {`)
			p.In()
			p.P(`data[i] = 1`)
			p.Out()
			p.P(`} else {`)
			p.In()
			p.P(`data[i] = 0`)
			p.Out()
			p.P(`}`)
			p.P(`i++`)
		} else {
			p.encodeKey(fieldNumber, wireType)
			p.P(`if *m.`, fieldname, ` {`)
			p.In()
			p.P(`data[i] = 1`)
			p.Out()
			p.P(`} else {`)
			p.In()
			p.P(`data[i] = 0`)
			p.Out()
			p.P(`}`)
			p.P(`i++`)
		}
	case descriptor.FieldDescriptorProto_TYPE_STRING:
		// Strings are length-delimited: a varint length, then the raw bytes.
		if repeated {
			p.P(`for _, s := range m.`, fieldname, ` {`)
			p.In()
			p.encodeKey(fieldNumber, wireType)
			p.P(`l = len(s)`)
			p.encodeVarint("l")
			p.P(`i+=copy(data[i:], s)`)
			p.Out()
			p.P(`}`)
		} else if proto3 {
			p.P(`if len(m.`, fieldname, `) > 0 {`)
			p.In()
			p.encodeKey(fieldNumber, wireType)
			p.callVarint(`len(m.`, fieldname, `)`)
			p.P(`i+=copy(data[i:], m.`, fieldname, `)`)
			p.Out()
			p.P(`}`)
		} else if !nullable {
			p.encodeKey(fieldNumber, wireType)
			p.callVarint(`len(m.`, fieldname, `)`)
			p.P(`i+=copy(data[i:], m.`, fieldname, `)`)
		} else {
			p.encodeKey(fieldNumber, wireType)
			p.callVarint(`len(*m.`, fieldname, `)`)
			p.P(`i+=copy(data[i:], *m.`, fieldname, `)`)
		}
	case descriptor.FieldDescriptorProto_TYPE_GROUP:
		// Groups are not supported by this plugin; generation aborts.
		panic(fmt.Errorf("marshaler does not support group %v", fieldname))
	case descriptor.FieldDescriptorProto_TYPE_MESSAGE:
		if p.IsMap(field) {
			// Each map entry is written as a length-delimited entry message
			// with the key as field 1 and the value as field 2 (see the
			// encodeKey(1, ...) and encodeKey(2, ...) calls below).
			m := p.GoMapType(nil, field)
			keygoTyp, keywire := p.GoType(nil, m.KeyField)
			keygoAliasTyp, _ := p.GoType(nil, m.KeyAliasField)
			// keys may not be pointers
			keygoTyp = strings.Replace(keygoTyp, "*", "", 1)
			keygoAliasTyp = strings.Replace(keygoAliasTyp, "*", "", 1)
			keyCapTyp := generator.CamelCase(keygoTyp)
			valuegoTyp, valuewire := p.GoType(nil, m.ValueField)
			valuegoAliasTyp, _ := p.GoType(nil, m.ValueAliasField)
			// This assigns the function-level nullable. From here on it
			// describes the map value rather than the map field itself.
			nullable, valuegoTyp, valuegoAliasTyp = generator.GoMapValueTypes(field, m.ValueField, valuegoTyp, valuegoAliasTyp)
			keyKeySize := keySize(1, wireToType(keywire))
			valueKeySize := keySize(2, wireToType(valuewire))
			// With the stable marshaler option, the keys are collected and
			// sorted first, so entries are emitted in key order rather than in
			// Go's map iteration order.
			if gogoproto.IsStableMarshaler(file.FileDescriptorProto, message.DescriptorProto) {
				keysName := `keysFor` + fieldname
				p.P(keysName, ` := make([]`, keygoTyp, `, 0, len(m.`, fieldname, `))`)
				p.P(`for k, _ := range m.`, fieldname, ` {`)
				p.In()
				p.P(keysName, ` = append(`, keysName, `, `, keygoTyp, `(k))`)
				p.Out()
				p.P(`}`)
				p.P(p.sortKeysPkg.Use(), `.`, keyCapTyp, `s(`, keysName, `)`)
				p.P(`for _, k := range `, keysName, ` {`)
			} else {
				p.P(`for k, _ := range m.`, fieldname, ` {`)
			}
			p.In()
			p.encodeKey(fieldNumber, wireType)
			// sum collects the emitted expressions whose total is the encoded
			// length of the entry message (emitted below as mapSize). It starts
			// with the key's tag size.
			sum := []string{strconv.Itoa(keyKeySize)}
			switch m.KeyField.GetType() {
			case descriptor.FieldDescriptorProto_TYPE_DOUBLE,
				descriptor.FieldDescriptorProto_TYPE_FIXED64,
				descriptor.FieldDescriptorProto_TYPE_SFIXED64:
				sum = append(sum, `8`)
			case descriptor.FieldDescriptorProto_TYPE_FLOAT,
				descriptor.FieldDescriptorProto_TYPE_FIXED32,
				descriptor.FieldDescriptorProto_TYPE_SFIXED32:
				sum = append(sum, `4`)
			case descriptor.FieldDescriptorProto_TYPE_INT64,
				descriptor.FieldDescriptorProto_TYPE_UINT64,
				descriptor.FieldDescriptorProto_TYPE_UINT32,
				descriptor.FieldDescriptorProto_TYPE_ENUM,
				descriptor.FieldDescriptorProto_TYPE_INT32:
				sum = append(sum, `sov`+p.localName+`(uint64(k))`)
			case descriptor.FieldDescriptorProto_TYPE_BOOL:
				sum = append(sum, `1`)
			case descriptor.FieldDescriptorProto_TYPE_STRING,
				descriptor.FieldDescriptorProto_TYPE_BYTES:
				sum = append(sum, `len(k)+sov`+p.localName+`(uint64(len(k)))`)
			case descriptor.FieldDescriptorProto_TYPE_SINT32,
				descriptor.FieldDescriptorProto_TYPE_SINT64:
				sum = append(sum, `soz`+p.localName+`(uint64(k))`)
			}
			// In the stable path, k has the non-alias key type, so it is
			// converted back to the alias type to index the map.
			if gogoproto.IsStableMarshaler(file.FileDescriptorProto, message.DescriptorProto) {
				p.P(`v := m.`, fieldname, `[`, keygoAliasTyp, `(k)]`)
			} else {
				p.P(`v := m.`, fieldname, `[k]`)
			}
			accessor := `v`
			switch m.ValueField.GetType() {
			case descriptor.FieldDescriptorProto_TYPE_DOUBLE,
				descriptor.FieldDescriptorProto_TYPE_FIXED64,
				descriptor.FieldDescriptorProto_TYPE_SFIXED64:
				sum = append(sum, strconv.Itoa(valueKeySize))
				sum = append(sum, strconv.Itoa(8))
			case descriptor.FieldDescriptorProto_TYPE_FLOAT,
				descriptor.FieldDescriptorProto_TYPE_FIXED32,
				descriptor.FieldDescriptorProto_TYPE_SFIXED32:
				sum = append(sum, strconv.Itoa(valueKeySize))
				sum = append(sum, strconv.Itoa(4))
			case descriptor.FieldDescriptorProto_TYPE_INT64,
				descriptor.FieldDescriptorProto_TYPE_UINT64,
				descriptor.FieldDescriptorProto_TYPE_UINT32,
				descriptor.FieldDescriptorProto_TYPE_ENUM,
				descriptor.FieldDescriptorProto_TYPE_INT32:
				sum = append(sum, strconv.Itoa(valueKeySize))
				sum = append(sum, `sov`+p.localName+`(uint64(v))`)
			case descriptor.FieldDescriptorProto_TYPE_BOOL:
				sum = append(sum, strconv.Itoa(valueKeySize))
				sum = append(sum, `1`)
			case descriptor.FieldDescriptorProto_TYPE_STRING:
				sum = append(sum, strconv.Itoa(valueKeySize))
				sum = append(sum, `len(v)+sov`+p.localName+`(uint64(len(v)))`)
			case descriptor.FieldDescriptorProto_TYPE_BYTES:
				if gogoproto.IsCustomType(field) {
					p.P(`cSize := 0`)
					if gogoproto.IsNullable(field) {
						p.P(`if `, accessor, ` != nil {`)
						p.In()
					}
					// NOTE(review): this always calls Size(), even when
					// protoSizer is set. The non-map custom-bytes path below
					// calls ProtoSize() in that case. Confirm which is intended.
					p.P(`cSize = `, accessor, `.Size()`)
					p.P(`cSize += `, strconv.Itoa(valueKeySize), ` + sov`+p.localName+`(uint64(cSize))`)
					if gogoproto.IsNullable(field) {
						p.Out()
						p.P(`}`)
					}
					sum = append(sum, `cSize`)
				} else {
					// Plain bytes values contribute nothing when empty (proto3)
					// or nil (proto2); the value is then omitted below as well.
					p.P(`byteSize := 0`)
					if proto3 {
						p.P(`if len(v) > 0 {`)
					} else {
						p.P(`if v != nil {`)
					}
					p.In()
					p.P(`byteSize = `, strconv.Itoa(valueKeySize), ` + len(v)+sov`+p.localName+`(uint64(len(v)))`)
					p.Out()
					p.P(`}`)
					sum = append(sum, `byteSize`)
				}
			case descriptor.FieldDescriptorProto_TYPE_SINT32,
				descriptor.FieldDescriptorProto_TYPE_SINT64:
				sum = append(sum, strconv.Itoa(valueKeySize))
				sum = append(sum, `soz`+p.localName+`(uint64(v))`)
			case descriptor.FieldDescriptorProto_TYPE_MESSAGE:
				if valuegoTyp != valuegoAliasTyp &&
					!gogoproto.IsStdTime(field) &&
					!gogoproto.IsStdDuration(field) {
					if nullable {
						// cast back to the type that has the generated methods on it
						accessor = `((` + valuegoTyp + `)(` + accessor + `))`
					} else {
						accessor = `((*` + valuegoTyp + `)(&` + accessor + `))`
					}
				} else if !nullable {
					accessor = `(&v)`
				}
				p.P(`msgSize := 0`)
				p.P(`if `, accessor, ` != nil {`)
				p.In()
				if gogoproto.IsStdTime(field) {
					p.P(`msgSize = `, p.typesPkg.Use(), `.SizeOfStdTime(*`, accessor, `)`)
				} else if gogoproto.IsStdDuration(field) {
					p.P(`msgSize = `, p.typesPkg.Use(), `.SizeOfStdDuration(*`, accessor, `)`)
				} else if protoSizer {
					p.P(`msgSize = `, accessor, `.ProtoSize()`)
				} else {
					p.P(`msgSize = `, accessor, `.Size()`)
				}
				p.P(`msgSize += `, strconv.Itoa(valueKeySize), ` + sov`+p.localName+`(uint64(msgSize))`)
				p.Out()
				p.P(`}`)
				sum = append(sum, `msgSize`)
			}
			// Write the entry length, then the key (field 1).
			p.P(`mapSize := `, strings.Join(sum, " + "))
			p.callVarint("mapSize")
			p.encodeKey(1, wireToType(keywire))
			p.mapField(numGen, field, m.KeyField, "k", protoSizer)
			// Omit the value (field 2) when it is a nil message/custom type, or
			// when it is plain bytes that are empty (proto3) or nil (proto2).
			// This mirrors the size computation above.
			nullableMsg := nullable && (m.ValueField.GetType() == descriptor.FieldDescriptorProto_TYPE_MESSAGE ||
				gogoproto.IsCustomType(field) && m.ValueField.IsBytes())
			plainBytes := m.ValueField.IsBytes() && !gogoproto.IsCustomType(field)
			if nullableMsg {
				p.P(`if `, accessor, ` != nil { `)
				p.In()
			} else if plainBytes {
				if proto3 {
					p.P(`if len(`, accessor, `) > 0 {`)
				} else {
					p.P(`if `, accessor, ` != nil {`)
				}
				p.In()
			}
			p.encodeKey(2, wireToType(valuewire))
			p.mapField(numGen, field, m.ValueField, accessor, protoSizer)
			if nullableMsg || plainBytes {
				p.Out()
				p.P(`}`)
			}
			p.Out()
			p.P(`}`)
		} else if repeated {
			// Each element is written as key, varint size, then the element
			// marshalled in place into data[i:].
			p.P(`for _, msg := range m.`, fieldname, ` {`)
			p.In()
			p.encodeKey(fieldNumber, wireType)
			varName := "msg"
			// stdtime/stdduration fields are sized and marshalled through the
			// types package helpers rather than methods on the value.
			if gogoproto.IsStdTime(field) {
				if gogoproto.IsNullable(field) {
					varName = "*" + varName
				}
				p.callVarint(p.typesPkg.Use(), `.SizeOfStdTime(`, varName, `)`)
				p.P(`n, err := `, p.typesPkg.Use(), `.StdTimeMarshalTo(`, varName, `, data[i:])`)
			} else if gogoproto.IsStdDuration(field) {
				if gogoproto.IsNullable(field) {
					varName = "*" + varName
				}
				p.callVarint(p.typesPkg.Use(), `.SizeOfStdDuration(`, varName, `)`)
				p.P(`n, err := `, p.typesPkg.Use(), `.StdDurationMarshalTo(`, varName, `, data[i:])`)
			} else if protoSizer {
				p.callVarint(varName, ".ProtoSize()")
				p.P(`n, err := `, varName, `.MarshalTo(data[i:])`)
			} else {
				p.callVarint(varName, ".Size()")
				p.P(`n, err := `, varName, `.MarshalTo(data[i:])`)
			}
			p.P(`if err != nil {`)
			p.In()
			p.P(`return 0, err`)
			p.Out()
			p.P(`}`)
			p.P(`i+=n`)
			p.Out()
			p.P(`}`)
		} else {
			// Singular message. The byte count variable gets a numGen suffix
			// because this code is not wrapped in its own loop scope.
			p.encodeKey(fieldNumber, wireType)
			varName := `m.` + fieldname
			if gogoproto.IsStdTime(field) {
				if gogoproto.IsNullable(field) {
					varName = "*" + varName
				}
				p.callVarint(p.typesPkg.Use(), `.SizeOfStdTime(`, varName, `)`)
				p.P(`n`, numGen.Next(), `, err := `, p.typesPkg.Use(), `.StdTimeMarshalTo(`, varName, `, data[i:])`)
			} else if gogoproto.IsStdDuration(field) {
				if gogoproto.IsNullable(field) {
					varName = "*" + varName
				}
				p.callVarint(p.typesPkg.Use(), `.SizeOfStdDuration(`, varName, `)`)
				p.P(`n`, numGen.Next(), `, err := `, p.typesPkg.Use(), `.StdDurationMarshalTo(`, varName, `, data[i:])`)
			} else if protoSizer {
				p.callVarint(varName, `.ProtoSize()`)
				p.P(`n`, numGen.Next(), `, err := `, varName, `.MarshalTo(data[i:])`)
			} else {
				p.callVarint(varName, `.Size()`)
				p.P(`n`, numGen.Next(), `, err := `, varName, `.MarshalTo(data[i:])`)
			}
			p.P(`if err != nil {`)
			p.In()
			p.P(`return 0, err`)
			p.Out()
			p.P(`}`)
			p.P(`i+=n`, numGen.Current())
		}
	case descriptor.FieldDescriptorProto_TYPE_BYTES:
		// Plain bytes are copied directly. Custom types are sized and
		// marshalled through their own Size/ProtoSize and MarshalTo methods.
		if !gogoproto.IsCustomType(field) {
			if repeated {
				p.P(`for _, b := range m.`, fieldname, ` {`)
				p.In()
				p.encodeKey(fieldNumber, wireType)
				p.callVarint("len(b)")
				p.P(`i+=copy(data[i:], b)`)
				p.Out()
				p.P(`}`)
			} else if proto3 {
				p.P(`if len(m.`, fieldname, `) > 0 {`)
				p.In()
				p.encodeKey(fieldNumber, wireType)
				p.callVarint(`len(m.`, fieldname, `)`)
				p.P(`i+=copy(data[i:], m.`, fieldname, `)`)
				p.Out()
				p.P(`}`)
			} else {
				p.encodeKey(fieldNumber, wireType)
				p.callVarint(`len(m.`, fieldname, `)`)
				p.P(`i+=copy(data[i:], m.`, fieldname, `)`)
			}
		} else {
			if repeated {
				p.P(`for _, msg := range m.`, fieldname, ` {`)
				p.In()
				p.encodeKey(fieldNumber, wireType)
				if protoSizer {
					p.callVarint(`msg.ProtoSize()`)
				} else {
					p.callVarint(`msg.Size()`)
				}
				p.P(`n, err := msg.MarshalTo(data[i:])`)
				p.P(`if err != nil {`)
				p.In()
				p.P(`return 0, err`)
				p.Out()
				p.P(`}`)
				p.P(`i+=n`)
				p.Out()
				p.P(`}`)
			} else {
				p.encodeKey(fieldNumber, wireType)
				if protoSizer {
					p.callVarint(`m.`, fieldname, `.ProtoSize()`)
				} else {
					p.callVarint(`m.`, fieldname, `.Size()`)
				}
				p.P(`n`, numGen.Next(), `, err := m.`, fieldname, `.MarshalTo(data[i:])`)
				p.P(`if err != nil {`)
				p.In()
				p.P(`return 0, err`)
				p.Out()
				p.P(`}`)
				p.P(`i+=n`, numGen.Current())
			}
		}
	case descriptor.FieldDescriptorProto_TYPE_SINT32:
		// ZigZag encoding: (n << 1) ^ (n >> 31) maps signed values to unsigned
		// ones before varint encoding. The packed form uses a scratch buffer
		// with room for 5 bytes per value.
		if packed {
			datavar := "data" + numGen.Next()
			jvar := "j" + numGen.Next()
			p.P(datavar, ` := make([]byte, len(m.`, fieldname, ")*5)")
			p.P(`var `, jvar, ` int`)
			p.P(`for _, num := range m.`, fieldname, ` {`)
			p.In()
			xvar := "x" + numGen.Next()
			p.P(xvar, ` := (uint32(num) << 1) ^ uint32((num >> 31))`)
			p.P(`for `, xvar, ` >= 1<<7 {`)
			p.In()
			p.P(datavar, `[`, jvar, `] = uint8(uint64(`, xvar, `)&0x7f|0x80)`)
			p.P(jvar, `++`)
			p.P(xvar, ` >>= 7`)
			p.Out()
			p.P(`}`)
			p.P(datavar, `[`, jvar, `] = uint8(`, xvar, `)`)
			p.P(jvar, `++`)
			p.Out()
			p.P(`}`)
			p.encodeKey(fieldNumber, wireType)
			p.callVarint(jvar)
			p.P(`i+=copy(data[i:], `, datavar, `[:`, jvar, `])`)
		} else if repeated {
			p.P(`for _, num := range m.`, fieldname, ` {`)
			p.In()
			p.encodeKey(fieldNumber, wireType)
			p.P(`x`, numGen.Next(), ` := (uint32(num) << 1) ^ uint32((num >> 31))`)
			p.encodeVarint("x" + numGen.Current())
			p.Out()
			p.P(`}`)
		} else if proto3 {
			p.P(`if m.`, fieldname, ` != 0 {`)
			p.In()
			p.encodeKey(fieldNumber, wireType)
			p.callVarint(`(uint32(m.`, fieldname, `) << 1) ^ uint32((m.`, fieldname, ` >> 31))`)
			p.Out()
			p.P(`}`)
		} else if !nullable {
			p.encodeKey(fieldNumber, wireType)
			p.callVarint(`(uint32(m.`, fieldname, `) << 1) ^ uint32((m.`, fieldname, ` >> 31))`)
		} else {
			p.encodeKey(fieldNumber, wireType)
			p.callVarint(`(uint32(*m.`, fieldname, `) << 1) ^ uint32((*m.`, fieldname, ` >> 31))`)
		}
	case descriptor.FieldDescriptorProto_TYPE_SINT64:
		// 64-bit ZigZag: (n << 1) ^ (n >> 63). The packed scratch buffer has
		// room for 10 bytes per value.
		if packed {
			jvar := "j" + numGen.Next()
			xvar := "x" + numGen.Next()
			datavar := "data" + numGen.Next()
			p.P(`var `, jvar, ` int`)
			p.P(datavar, ` := make([]byte, len(m.`, fieldname, `)*10)`)
			p.P(`for _, num := range m.`, fieldname, ` {`)
			p.In()
			p.P(xvar, ` := (uint64(num) << 1) ^ uint64((num >> 63))`)
			p.P(`for `, xvar, ` >= 1<<7 {`)
			p.In()
			p.P(datavar, `[`, jvar, `] = uint8(uint64(`, xvar, `)&0x7f|0x80)`)
			p.P(jvar, `++`)
			p.P(xvar, ` >>= 7`)
			p.Out()
			p.P(`}`)
			p.P(datavar, `[`, jvar, `] = uint8(`, xvar, `)`)
			p.P(jvar, `++`)
			p.Out()
			p.P(`}`)
			p.encodeKey(fieldNumber, wireType)
			p.callVarint(jvar)
			p.P(`i+=copy(data[i:], `, datavar, `[:`, jvar, `])`)
		} else if repeated {
			p.P(`for _, num := range m.`, fieldname, ` {`)
			p.In()
			p.encodeKey(fieldNumber, wireType)
			p.P(`x`, numGen.Next(), ` := (uint64(num) << 1) ^ uint64((num >> 63))`)
			p.encodeVarint("x" + numGen.Current())
			p.Out()
			p.P(`}`)
		} else if proto3 {
			p.P(`if m.`, fieldname, ` != 0 {`)
			p.In()
			p.encodeKey(fieldNumber, wireType)
			p.callVarint(`(uint64(m.`, fieldname, `) << 1) ^ uint64((m.`, fieldname, ` >> 63))`)
			p.Out()
			p.P(`}`)
		} else if !nullable {
			p.encodeKey(fieldNumber, wireType)
			p.callVarint(`(uint64(m.`, fieldname, `) << 1) ^ uint64((m.`, fieldname, ` >> 63))`)
		} else {
			p.encodeKey(fieldNumber, wireType)
			p.callVarint(`(uint64(*m.`, fieldname, `) << 1) ^ uint64((*m.`, fieldname, ` >> 63))`)
		}
	default:
		panic("not implemented")
	}
	// Close the guard opened at the top of this function.
	if (required && nullable) || repeated || doNilCheck {
		p.Out()
		p.P(`}`)
	}
}
Ejemplo n.º 5
0
// Generate writes the marshaling code for file.
//
// It resets the plugin's per-file state and registers the imports the
// generated code may use. Then, for every message that is not a map entry and
// that enables marshaler (or unsafe_marshaler when p.unsafe is set), it emits:
//   - Marshal, which allocates Size() (or ProtoSize() for protosizer messages)
//     bytes and delegates to MarshalTo;
//   - MarshalTo, which writes the regular fields, any set oneof, extensions
//     and unrecognized bytes, and returns the number of bytes written;
//   - a MarshalTo method for each oneof wrapper type.
//
// When at least one message was handled, the file-local encodeFixed64,
// encodeFixed32 and encodeVarint helpers are emitted as well.
//
// Generate panics if a message enables both marshaler and unsafe_marshaler.
func (p *marshalto) Generate(file *generator.FileDescriptor) {
	numGen := NewNumGen()
	p.PluginImports = generator.NewPluginImports(p.Generator)
	p.atleastOne = false
	p.localName = generator.FileName(file)

	p.mathPkg = p.NewImport("math")
	p.sortKeysPkg = p.NewImport("github.com/maditya/protobuf/sortkeys")
	p.protoPkg = p.NewImport("github.com/maditya/protobuf/proto")
	if !gogoproto.ImportsGoGoProto(file.FileDescriptorProto) {
		p.protoPkg = p.NewImport("github.com/golang/protobuf/proto")
	}
	p.unsafePkg = p.NewImport("unsafe")
	p.errorsPkg = p.NewImport("errors")
	p.typesPkg = p.NewImport("github.com/maditya/protobuf/types")

	// returnOnErr emits `if err != nil { <stmt> }`.
	returnOnErr := func(stmt string) {
		p.P(`if err != nil {`)
		p.In()
		p.P(stmt)
		p.Out()
		p.P(`}`)
	}
	// copyIfSet emits code that appends the raw bytes held in m.<name>, if any.
	copyIfSet := func(name string) {
		p.P(`if m.`, name, ` != nil {`)
		p.In()
		p.P(`i+=copy(data[i:], m.`, name, `)`)
		p.Out()
		p.P(`}`)
	}

	for _, message := range file.Messages() {
		msgDesc := message.DescriptorProto
		if msgDesc.GetOptions().GetMapEntry() {
			continue
		}
		fileDesc := file.FileDescriptorProto
		ccTypeName := generator.CamelCaseSlice(message.TypeName())

		// Which option selects the message depends on the mode; the other
		// option must not be enabled at the same time.
		selects, conflicts := gogoproto.IsMarshaler, gogoproto.IsUnsafeMarshaler
		if p.unsafe {
			selects, conflicts = conflicts, selects
		}
		if !selects(fileDesc, msgDesc) {
			continue
		}
		if conflicts(fileDesc, msgDesc) {
			panic(fmt.Sprintf("unsafe_marshaler and marshalto enabled for %v", ccTypeName))
		}
		p.atleastOne = true

		// Marshal: size the buffer, then delegate to MarshalTo.
		sizeMethod := "Size"
		if gogoproto.IsProtoSizer(fileDesc, msgDesc) {
			sizeMethod = "ProtoSize"
		}
		p.P(`func (m *`, ccTypeName, `) Marshal() (data []byte, err error) {`)
		p.In()
		p.P(`size := m.`, sizeMethod, `()`)
		p.P(`data = make([]byte, size)`)
		p.P(`n, err := m.MarshalTo(data)`)
		returnOnErr(`return nil, err`)
		p.P(`return data[:n], nil`)
		p.Out()
		p.P(`}`)
		p.P(``)

		// MarshalTo: write every field into data, advancing i.
		p.P(`func (m *`, ccTypeName, `) MarshalTo(data []byte) (int, error) {`)
		p.In()
		p.P(`var i int`)
		p.P(`_ = i`)
		p.P(`var l int`)
		p.P(`_ = l`)
		// NOTE(review): fields itself is never ranged over. If orderFields is a
		// slice type (as in upstream gogo/protobuf), the conversion shares
		// message.Field's backing array. In that case this sort is what puts
		// the loop below in orderFields order. Confirm before changing.
		fields := orderFields(message.GetField())
		sort.Sort(fields)
		emittedOneofs := make(map[string]struct{})
		for _, field := range message.Field {
			if field.OneofIndex == nil {
				p.generateField(gogoproto.IsProto3(fileDesc), numGen, file, message, field)
				continue
			}
			// GetFieldName maps every member of a oneof to the oneof's own
			// name, so the shared struct field is marshaled only once.
			fieldname := p.GetFieldName(message, field)
			if _, done := emittedOneofs[fieldname]; done {
				continue
			}
			emittedOneofs[fieldname] = struct{}{}
			p.P(`if m.`, fieldname, ` != nil {`)
			p.In()
			p.P(`nn`, numGen.Next(), `, err := m.`, fieldname, `.MarshalTo(data[i:])`)
			returnOnErr(`return 0, err`)
			p.P(`i+=nn`, numGen.Current())
			p.Out()
			p.P(`}`)
		}
		if msgDesc.HasExtension() {
			if gogoproto.HasExtensionsMap(fileDesc, msgDesc) {
				p.P(`n, err := `, p.protoPkg.Use(), `.EncodeInternalExtension(m, data[i:])`)
				returnOnErr(`return 0, err`)
				p.P(`i+=n`)
			} else {
				copyIfSet(`XXX_extensions`)
			}
		}
		if gogoproto.HasUnrecognized(fileDesc, msgDesc) {
			copyIfSet(`XXX_unrecognized`)
		}

		p.P(`return i, nil`)
		p.Out()
		p.P(`}`)
		p.P()

		// One MarshalTo per oneof wrapper type. The field descriptors come from
		// a clone. Presumably this keeps the nullable adjustment made by vanity
		// below out of message's own descriptor.
		clone := proto.Clone(msgDesc).(*descriptor.DescriptorProto)
		for _, field := range clone.Field {
			if field.OneofIndex == nil {
				continue
			}
			p.P(`func (m *`, p.OneOfTypeName(message, field), `) MarshalTo(data []byte) (int, error) {`)
			p.In()
			p.P(`i := 0`)
			vanity.TurnOffNullableForNativeTypesWithoutDefaultsOnly(field)
			p.generateField(false, numGen, file, message, field)
			p.P(`return i, nil`)
			p.Out()
			p.P(`}`)
		}
	}

	if !p.atleastOne {
		return
	}

	// emitFixed writes encodeFixed<8*width><localName>. The generated helper
	// stores v at data[offset:] in little-endian byte order and returns the
	// offset just past the written bytes.
	emitFixed := func(width int) {
		bits := fmt.Sprint(8 * width)
		p.P(`func encodeFixed`, bits, p.localName, `(data []byte, offset int, v uint`, bits, `) int {`)
		p.In()
		p.P(`data[offset] = uint8(v)`)
		for k := 1; k < width; k++ {
			p.P(fmt.Sprintf("data[offset+%d] = uint8(v >> %d)", k, 8*k))
		}
		p.P(`return offset+`, fmt.Sprint(width))
		p.Out()
		p.P(`}`)
	}
	emitFixed(8)
	emitFixed(4)

	// encodeVarint<localName> writes v as a base-128 varint (7 payload bits
	// per byte, continuation bit set on all but the last byte) and returns
	// the offset just past it.
	p.P(`func encodeVarint`, p.localName, `(data []byte, offset int, v uint64) int {`)
	p.In()
	p.P(`for v >= 1<<7 {`)
	p.In()
	p.P(`data[offset] = uint8(v&0x7f|0x80)`)
	p.P(`v >>= 7`)
	p.P(`offset++`)
	p.Out()
	p.P(`}`)
	p.P(`data[offset] = uint8(v)`)
	p.P(`return offset+1`)
	p.Out()
	p.P(`}`)
}
Ejemplo n.º 6
0
// Generate writes generated tests and benchmarks for the messages in file and
// reports whether it wrote anything.
//
// For each message that is not a map entry:
//   - with testgen enabled it emits Test<Msg>Proto. That test does a
//     Marshal/Unmarshal round trip, compares with VerboseEqual (when enabled)
//     and Equal, and calls Unmarshal on randomly mutated bytes. When the
//     message has a marshaler or unsafe marshaler it also emits
//     Test<Msg>MarshalTo;
//   - with benchgen enabled it emits Benchmark<Msg>ProtoMarshal and
//     Benchmark<Msg>ProtoUnmarshal over 10000 populated messages.
//
// Package names are only resolved (Use) when code referencing them is emitted.
func (p *testProto) Generate(imports generator.PluginImports, file *generator.FileDescriptor) bool {
	emitted := false
	testingPkg := imports.NewImport("testing")
	randPkg := imports.NewImport("math/rand")
	timePkg := imports.NewImport("time")
	protoPkg := imports.NewImport("github.com/maditya/protobuf/proto")
	if !gogoproto.ImportsGoGoProto(file.FileDescriptorProto) {
		protoPkg = imports.NewImport("github.com/golang/protobuf/proto")
	}

	// block emits `head {`, then body one indentation level deeper, then `}`.
	block := func(head string, body func()) {
		p.P(head, ` {`)
		p.In()
		body()
		p.Out()
		p.P(`}`)
	}
	// guard emits `if cond {` around a single statement.
	guard := func(cond, stmt string) {
		block(`if `+cond, func() { p.P(stmt) })
	}
	// randomizeData emits a loop overwriting every byte of the generated
	// `data` slice with pseudo-random values.
	randomizeData := func() {
		block(`for i := range data`, func() { p.P(`data[i] = byte(popr.Intn(256))`) })
	}

	for _, message := range file.Messages() {
		if message.DescriptorProto.GetOptions().GetMapEntry() {
			continue
		}
		ccTypeName := generator.CamelCaseSlice(message.TypeName())
		fileDesc, msgDesc := file.FileDescriptorProto, message.DescriptorProto

		if gogoproto.HasTestGen(fileDesc, msgDesc) {
			emitted = true
			// Resolve package names in the order the generated tests first
			// reference them.
			testingName := testingPkg.Use()
			timeName := timePkg.Use()
			randName := randPkg.Use()
			protoName := protoPkg.Use()
			verbose := gogoproto.HasVerboseEqual(fileDesc, msgDesc)

			// openTest starts Test<Msg><suffix> with a seeded random source
			// and a populated message p.
			openTest := func(suffix string) {
				p.P(`func Test`, ccTypeName, suffix, `(t *`, testingName, `.T) {`)
				p.In()
				p.P(`seed := `, timeName, `.Now().UnixNano()`)
				p.P(`popr := `, randName, `.New(`, randName, `.NewSource(seed))`)
				p.P(`p := NewPopulated`, ccTypeName, `(popr, false)`)
			}
			// unmarshalMsg emits decoding of data into a fresh msg.
			unmarshalMsg := func() {
				p.P(`msg := &`, ccTypeName, `{}`)
				guard(`err := `+protoName+`.Unmarshal(data, msg); err != nil`, `t.Fatalf("seed = %d, err = %v", seed, err)`)
			}
			// compareMsg emits the VerboseEqual (if enabled) and Equal checks.
			compareMsg := func() {
				if verbose {
					guard(`err := p.VerboseEqual(msg); err != nil`, `t.Fatalf("seed = %d, %#v !VerboseProto %#v, since %v", seed, msg, p, err)`)
				}
				guard(`!p.Equal(msg)`, `t.Fatalf("seed = %d, %#v !Proto %#v", seed, msg, p)`)
			}
			// closeTest ends the current test function and adds a blank line.
			closeTest := func() {
				p.Out()
				p.P(`}`)
				p.P()
			}

			openTest(`Proto`)
			p.P(`data, err := `, protoName, `.Marshal(p)`)
			guard(`err != nil`, `t.Fatalf("seed = %d, err = %v", seed, err)`)
			unmarshalMsg()
			p.P(`littlefuzz := make([]byte, len(data))`)
			p.P(`copy(littlefuzz, data)`)
			randomizeData()
			compareMsg()
			block(`if len(littlefuzz) > 0`, func() {
				p.P(`fuzzamount := 100`)
				block(`for i := 0; i < fuzzamount; i++`, func() {
					p.P(`littlefuzz[popr.Intn(len(littlefuzz))] = byte(popr.Intn(256))`)
					p.P(`littlefuzz = append(littlefuzz, byte(popr.Intn(256)))`)
				})
				p.P(`// shouldn't panic`)
				p.P(`_ = `, protoName, `.Unmarshal(littlefuzz, msg)`)
			})
			closeTest()

			if gogoproto.IsMarshaler(fileDesc, msgDesc) || gogoproto.IsUnsafeMarshaler(fileDesc, msgDesc) {
				sizeMethod := "Size"
				if gogoproto.IsProtoSizer(fileDesc, msgDesc) {
					sizeMethod = "ProtoSize"
				}
				openTest(`MarshalTo`)
				p.P(`size := p.`, sizeMethod, `()`)
				p.P(`data := make([]byte, size)`)
				randomizeData()
				p.P(`_, err := p.MarshalTo(data)`)
				guard(`err != nil`, `t.Fatalf("seed = %d, err = %v", seed, err)`)
				unmarshalMsg()
				randomizeData()
				compareMsg()
				closeTest()
			}
		}

		if gogoproto.HasBenchGen(fileDesc, msgDesc) {
			emitted = true
			block(`func Benchmark`+ccTypeName+`ProtoMarshal(b *`+testingPkg.Use()+`.B)`, func() {
				p.P(`popr := `, randPkg.Use(), `.New(`, randPkg.Use(), `.NewSource(616))`)
				p.P(`total := 0`)
				p.P(`pops := make([]*`, ccTypeName, `, 10000)`)
				block(`for i := 0; i < 10000; i++`, func() {
					p.P(`pops[i] = NewPopulated`, ccTypeName, `(popr, false)`)
				})
				p.P(`b.ResetTimer()`)
				block(`for i := 0; i < b.N; i++`, func() {
					p.P(`data, err := `, protoPkg.Use(), `.Marshal(pops[i%10000])`)
					guard(`err != nil`, `panic(err)`)
					p.P(`total += len(data)`)
				})
				p.P(`b.SetBytes(int64(total / b.N))`)
			})
			p.P()

			block(`func Benchmark`+ccTypeName+`ProtoUnmarshal(b *`+testingPkg.Use()+`.B)`, func() {
				p.P(`popr := `, randPkg.Use(), `.New(`, randPkg.Use(), `.NewSource(616))`)
				p.P(`total := 0`)
				p.P(`datas := make([][]byte, 10000)`)
				block(`for i := 0; i < 10000; i++`, func() {
					p.P(`data, err := `, protoPkg.Use(), `.Marshal(NewPopulated`, ccTypeName, `(popr, false))`)
					guard(`err != nil`, `panic(err)`)
					p.P(`datas[i] = data`)
				})
				p.P(`msg := &`, ccTypeName, `{}`)
				p.P(`b.ResetTimer()`)
				block(`for i := 0; i < b.N; i++`, func() {
					p.P(`total += len(datas[i%10000])`)
					guard(`err := `+protoPkg.Use()+`.Unmarshal(datas[i%10000], msg); err != nil`, `panic(err)`)
				})
				p.P(`b.SetBytes(int64(total / b.N))`)
			})
			p.P()
		}
	}
	return emitted
}