示例#1
0
// Generate writes an Unmarshal method for every message in file that is
// selected for this plugin's mode.
//
// With p.unsafe set, a message is selected when gogoproto.IsUnsafeUnmarshaler
// reports true for it; otherwise when gogoproto.IsUnmarshaler does. If the
// other helper also reports true for a selected message, Generate panics.
// p.atleastOne records whether any method was written.
//
// The emitted method loops over data: it decodes a varint key, splits it into
// a field number and a wire type, and switches on the field number. Known
// fields are decoded by whatever p.field emits. Any other field is measured
// with the proto package's Skip and copied into XXX_extensions (when its
// number falls inside one of the message's extension ranges) or into
// XXX_unrecognized.
func (p *unmarshal) Generate(file *generator.FileDescriptor) {
	p.PluginImports = generator.NewPluginImports(p.Generator)
	p.atleastOne = false

	p.ioPkg = p.NewImport("io")
	p.mathPkg = p.NewImport("math")
	p.unsafePkg = p.NewImport("unsafe")
	fmtPkg := p.NewImport("fmt")
	protoPkg := p.NewImport("code.google.com/p/gogoprotobuf/proto")

	// closeBlock ends the generated block currently open: dedent, then "}".
	closeBlock := func() {
		p.Out()
		p.P(`}`)
	}

	// emitEOFGuard emits "if <cond> { return io.ErrUnexpectedEOF }".
	emitEOFGuard := func(cond string) {
		p.P(`if `, cond, ` {`)
		p.In()
		p.P(`return `, p.ioPkg.Use(), `.ErrUnexpectedEOF`)
		closeBlock()
	}

	// emitWrongWireType emits the error return used when a known field shows
	// up with a wire type the generated code does not accept.
	emitWrongWireType := func(name string) {
		p.P(`return `, fmtPkg.Use(), `.Errorf("proto: wrong wireType = %d for field `, name, `", wireType)`)
	}

	// emitRewindAndSkip emits code that counts how many bytes the key varint
	// occupied, moves index back to the start of the key, asks Skip for the
	// length of the whole field, and bounds-checks that length against l.
	emitRewindAndSkip := func() {
		p.P(`var sizeOfWire int`)
		p.P(`for {`)
		p.In()
		p.P(`sizeOfWire++`)
		p.P(`wire >>= 7`)
		p.P(`if wire == 0 {`)
		p.In()
		p.P(`break`)
		closeBlock()
		closeBlock()
		p.P(`index-=sizeOfWire`)
		p.P(`skippy, err := `, protoPkg.Use(), `.Skip(data[index:])`)
		p.P(`if err != nil {`)
		p.In()
		p.P(`return err`)
		closeBlock()
		emitEOFGuard(`(index + skippy) > l`)
	}

	// Exactly one of the two options selects a message in a given mode; the
	// other one is a conflict.
	selected, conflicting := gogoproto.IsUnmarshaler, gogoproto.IsUnsafeUnmarshaler
	if p.unsafe {
		selected, conflicting = conflicting, selected
	}

	for _, msg := range file.Messages() {
		goName := generator.CamelCaseSlice(msg.TypeName())
		if !selected(file.FileDescriptorProto, msg.DescriptorProto) {
			continue
		}
		if conflicting(file.FileDescriptorProto, msg.DescriptorProto) {
			panic(fmt.Sprintf("unsafe_unmarshaler and unmarshaler enabled for %v", goName))
		}
		p.atleastOne = true

		p.P(`func (m *`, goName, `) Unmarshal(data []byte) error {`)
		p.In()
		p.P(`l := len(data)`)
		p.P(`index := 0`)
		p.P(`for index < l {`)
		p.In()
		p.P(`var wire uint64`)
		p.decodeVarint("wire", "uint64")
		p.P(`fieldNum := int32(wire >> 3)`)
		// wireType is only read by the per-field cases, so it is declared
		// only when there is at least one of them.
		if len(msg.Field) > 0 {
			p.P(`wireType := int(wire & 0x7)`)
		}
		p.P(`switch fieldNum {`)
		p.In()
		for _, fld := range msg.Field {
			name := p.GetFieldName(msg, fld)

			isPacked := fld.IsPacked()
			p.P(`case `, strconv.Itoa(int(fld.GetNumber())), `:`)
			p.In()
			want := strconv.Itoa(fld.WireType())

			if !isPacked {
				p.P(`if wireType != `, want, `{`)
				p.In()
				emitWrongWireType(name)
				closeBlock()
				p.field(fld, name)
				continue
			}

			// Packed field: accept the length-delimited packed form, then
			// fall back to a single element in the field's own wire type.
			p.P(`if wireType == `, strconv.Itoa(proto.WireBytes), `{`)
			p.In()
			p.P(`var packedLen int`)
			p.decodeVarint("packedLen", "int")
			p.P(`postIndex := index + packedLen`)
			emitEOFGuard(`postIndex > l`)
			p.P(`for index < postIndex {`)
			p.In()
			p.field(fld, name)
			closeBlock()
			p.Out()
			p.P(`} else if wireType == `, want, `{`)
			p.In()
			p.field(fld, name)
			p.Out()
			p.P(`} else {`)
			p.In()
			emitWrongWireType(name)
			closeBlock()
		}
		p.Out()
		p.P(`default:`)
		p.In()

		extendable := msg.DescriptorProto.HasExtension()
		if extendable {
			ranges := msg.GetExtensionRange()
			conds := make([]string, 0, len(ranges))
			for _, r := range ranges {
				lo := strconv.Itoa(int(r.GetStart()))
				hi := strconv.Itoa(int(r.GetEnd()))
				conds = append(conds, `((fieldNum >= `+lo+`) && (fieldNum<`+hi+`))`)
			}
			p.P(`if `, strings.Join(conds, "||"), `{`)
			p.In()
			emitRewindAndSkip()
			if gogoproto.HasExtensionsMap(file.FileDescriptorProto, msg.DescriptorProto) {
				p.P(`if m.XXX_extensions == nil {`)
				p.In()
				p.P(`m.XXX_extensions = make(map[int32]`, protoPkg.Use(), `.Extension)`)
				closeBlock()
				p.P(`m.XXX_extensions[int32(fieldNum)] = `, protoPkg.Use(), `.NewExtension(data[index:index+skippy])`)
			} else {
				p.P(`m.XXX_extensions = append(m.XXX_extensions, data[index:index+skippy]...)`)
			}
			p.P(`index += skippy`)
			p.Out()
			p.P(`} else {`)
			p.In()
		}

		emitRewindAndSkip()
		p.P(`m.XXX_unrecognized = append(m.XXX_unrecognized, data[index:index+skippy]...)`)
		p.P(`index += skippy`)
		p.Out()
		if extendable {
			closeBlock() // closes the "else" opened for extendable messages
		}
		closeBlock() // switch fieldNum
		closeBlock() // for index < l
		p.P(`return nil`)
		closeBlock() // func
	}
}
示例#2
0
// Generate writes an Unmarshal method for every message in file for which
// gogoproto.IsUnsafeUnmarshaler reports true.
//
// It panics when gogoproto.IsUnmarshaler is also true for such a message,
// when a field is a group, when a custom type cannot be resolved, or when a
// field type has no case in the switch below. p.atleastOne records whether
// any method was written.
//
// The emitted method loops over data: it decodes a varint key, splits it into
// a field number and a wire type, and switches on the field number. Each known
// field first checks its wire type and returns the proto package's
// ErrWrongType on mismatch, then decodes the value according to the field's
// descriptor type. Repeated fields append, nullable fields store a pointer,
// and all other fields are assigned directly. Errors from nested Unmarshal
// calls are always returned to the caller. Fields without a case are measured
// with the proto package's Skip and copied into XXX_extensions (when the
// number falls inside one of the message's extension ranges) or into
// XXX_unrecognized.
func (p *unmarshal) Generate(file *generator.FileDescriptor) {
	p.PluginImports = generator.NewPluginImports(p.Generator)
	p.atleastOne = false

	p.ioPkg = p.NewImport("io")
	p.unsafePkg = p.NewImport("unsafe")
	protoPkg := p.NewImport("code.google.com/p/gogoprotobuf/proto")

	// closeBlock ends the generated block currently open: dedent, then "}".
	closeBlock := func() {
		p.Out()
		p.P(`}`)
	}

	// emitEOFGuard emits "if <cond> { return io.ErrUnexpectedEOF }".
	emitEOFGuard := func(cond string) {
		p.P(`if `, cond, ` {`)
		p.In()
		p.P(`return `, p.ioPkg.Use(), `.ErrUnexpectedEOF`)
		closeBlock()
	}

	// emitWireTypeCheck emits the guard that rejects a field whose wire type
	// differs from want. The error is qualified through protoPkg, like every
	// other proto reference here, instead of a hard-coded "proto." prefix.
	emitWireTypeCheck := func(want int) {
		p.P(`if wireType != `, strconv.Itoa(want), `{`)
		p.In()
		p.P(`return `, protoPkg.Use(), `.ErrWrongType`)
		closeBlock()
	}

	// emitLengthPrefix emits the decoding of a varint length into lenVar
	// (declared with type lenType), the computation of postIndex, and the
	// bounds check of postIndex against l.
	emitLengthPrefix := func(lenVar, lenType string) {
		p.P(`var `, lenVar, ` `, lenType)
		p.decodeVarint(lenVar, lenType)
		lenExpr := lenVar
		if lenType != "int" {
			lenExpr = `int(` + lenVar + `)`
		}
		p.P(`postIndex := index + `, lenExpr)
		emitEOFGuard(`postIndex > l`)
	}

	// emitUnmarshalInto emits a checked call of target.Unmarshal on the bytes
	// between index and postIndex; the generated code returns any error.
	emitUnmarshalInto := func(target string) {
		p.P(`if err := `, target, `.Unmarshal(data[index:postIndex]); err != nil {`)
		p.In()
		p.P(`return err`)
		closeBlock()
	}

	// emitRewindAndSkip emits code that counts how many bytes the key varint
	// occupied, moves index back to the start of the key, and asks Skip for
	// the length of the whole field, returning Skip's error if any.
	emitRewindAndSkip := func() {
		p.P(`var sizeOfWire int`)
		p.P(`for {`)
		p.In()
		p.P(`sizeOfWire++`)
		p.P(`wire >>= 7`)
		p.P(`if wire == 0 {`)
		p.In()
		p.P(`break`)
		closeBlock()
		closeBlock()
		p.P(`index-=sizeOfWire`)
		p.P(`skippy, err := `, protoPkg.Use(), `.Skip(data[index:])`)
		p.P(`if err != nil {`)
		p.In()
		p.P(`return err`)
		closeBlock()
	}

	for _, message := range file.Messages() {
		if !gogoproto.IsUnsafeUnmarshaler(file.FileDescriptorProto, message.DescriptorProto) {
			continue
		}
		ccTypeName := generator.CamelCaseSlice(message.TypeName())
		if gogoproto.IsUnmarshaler(file.FileDescriptorProto, message.DescriptorProto) {
			panic(fmt.Sprintf("unsafe_unmarshaler and unmarshaler enabled for %v", ccTypeName))
		}
		p.atleastOne = true

		p.P(`func (m *`, ccTypeName, `) Unmarshal(data []byte) error {`)
		p.In()
		p.P(`l := len(data)`)
		p.P(`index := 0`)
		p.P(`for index < l {`)
		p.In()
		p.P(`var wire uint64`)
		p.decodeVarint("wire", "uint64")
		p.P(`fieldNum := int32(wire >> 3)`)
		// wireType is only read by the per-field cases, so it is declared
		// only when there is at least one of them.
		if len(message.Field) > 0 {
			p.P(`wireType := int(wire & 0x7)`)
		}
		p.P(`switch fieldNum {`)
		p.In()
		for _, field := range message.Field {
			fieldname := generator.CamelCase(*field.Name)
			repeated := field.IsRepeated()
			nullable := gogoproto.IsNullable(field)
			packed := field.IsPacked()

			// store emits the assignment of the generated expression expr to
			// the field. For a nullable field expr must be addressable; when
			// it is not, tmp names a temporary the value is copied into first.
			store := func(expr, tmp string) {
				switch {
				case repeated:
					p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, `, expr, `)`)
				case nullable:
					if tmp != "" {
						p.P(tmp, ` := `, expr)
						expr = tmp
					}
					p.P(`m.`, fieldname, ` = &`, expr)
				default:
					p.P(`m.`, fieldname, ` = `, expr)
				}
			}

			// scalar emits the decoding of one value of Go type goType using
			// decode. Plain fields are decoded in place; repeated and nullable
			// fields go through a local "v".
			scalar := func(goType string, decode func(varName, typeName string)) {
				if !repeated && !nullable {
					decode("m."+fieldname, goType)
					return
				}
				p.P(`var v `, goType)
				decode("v", goType)
				store("v", "")
			}

			p.P(`case `, strconv.Itoa(int(field.GetNumber())), `:`)
			p.In()
			wantWireType := field.WireType()
			if packed {
				// Packed fields arrive as one length-delimited run of values.
				wantWireType = proto.WireBytes
			}
			emitWireTypeCheck(wantWireType)
			if packed {
				emitLengthPrefix("packedLen", "int")
				p.P(`for index < postIndex {`)
				p.In()
			}

			switch *field.Type {
			case descriptor.FieldDescriptorProto_TYPE_DOUBLE:
				scalar("float64", p.unsafeFixed64)
			case descriptor.FieldDescriptorProto_TYPE_FLOAT:
				scalar("float32", p.unsafeFixed32)
			case descriptor.FieldDescriptorProto_TYPE_INT64:
				scalar("int64", p.decodeVarint)
			case descriptor.FieldDescriptorProto_TYPE_UINT64:
				scalar("uint64", p.decodeVarint)
			case descriptor.FieldDescriptorProto_TYPE_INT32:
				scalar("int32", p.decodeVarint)
			case descriptor.FieldDescriptorProto_TYPE_FIXED64:
				scalar("uint64", p.unsafeFixed64)
			case descriptor.FieldDescriptorProto_TYPE_FIXED32:
				scalar("uint32", p.unsafeFixed32)
			case descriptor.FieldDescriptorProto_TYPE_BOOL:
				p.P(`var v int`)
				p.decodeVarint("v", "int")
				store(`bool(v != 0)`, "b")
			case descriptor.FieldDescriptorProto_TYPE_STRING:
				emitLengthPrefix("stringLen", "uint64")
				store(`string(data[index:postIndex])`, "s")
				p.P(`index = postIndex`)
			case descriptor.FieldDescriptorProto_TYPE_GROUP:
				panic(fmt.Errorf("unmarshaler does not support group %v", fieldname))
			case descriptor.FieldDescriptorProto_TYPE_MESSAGE:
				desc := p.ObjectNamed(field.GetTypeName())
				msgname := p.TypeName(desc)
				msgnames := strings.Split(msgname, ".")
				typeName := msgnames[len(msgnames)-1]
				if gogoproto.IsEmbed(field) {
					// An embedded message is addressed by its type name.
					fieldname = typeName
				}
				emitLengthPrefix("msglen", "int")
				switch {
				case repeated:
					if nullable {
						p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, &`, msgname, `{})`)
					} else {
						p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, `, msgname, `{})`)
					}
					// Decode into the element just appended and propagate its
					// error instead of dropping it.
					emitUnmarshalInto(`m.` + fieldname + `[len(m.` + fieldname + `)-1]`)
				case nullable:
					p.P(`if m.`, fieldname, ` == nil {`)
					p.In()
					p.P(`m.`, fieldname, ` = &`, msgname, `{}`)
					closeBlock()
					emitUnmarshalInto(`m.` + fieldname)
				default:
					emitUnmarshalInto(`m.` + fieldname)
				}
				p.P(`index = postIndex`)
			case descriptor.FieldDescriptorProto_TYPE_BYTES:
				emitLengthPrefix("byteLen", "int")
				if !gogoproto.IsCustomType(field) {
					if repeated {
						p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, make([]byte, postIndex-index))`)
						p.P(`copy(m.`, fieldname, `[len(m.`, fieldname, `)-1], data[index:postIndex])`)
					} else {
						// NOTE(review): this appends to whatever the field
						// already holds rather than replacing it - confirm
						// callers always decode into a reset message.
						p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, data[index:postIndex]...)`)
					}
				} else {
					_, ctyp, err := generator.GetCustomType(field)
					if err != nil {
						panic(err)
					}
					switch {
					case repeated:
						p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, `, ctyp, `{})`)
						// Same as for nested messages: the element's Unmarshal
						// error is returned, not ignored.
						emitUnmarshalInto(`m.` + fieldname + `[len(m.` + fieldname + `)-1]`)
					case nullable:
						p.P(`m.`, fieldname, ` = &`, ctyp, `{}`)
						emitUnmarshalInto(`m.` + fieldname)
					default:
						emitUnmarshalInto(`m.` + fieldname)
					}
				}
				p.P(`index = postIndex`)
			case descriptor.FieldDescriptorProto_TYPE_UINT32:
				scalar("uint32", p.decodeVarint)
			case descriptor.FieldDescriptorProto_TYPE_ENUM:
				scalar(p.TypeName(p.ObjectNamed(field.GetTypeName())), p.decodeVarint)
			case descriptor.FieldDescriptorProto_TYPE_SFIXED32:
				scalar("int32", p.unsafeFixed32)
			case descriptor.FieldDescriptorProto_TYPE_SFIXED64:
				scalar("int64", p.unsafeFixed64)
			case descriptor.FieldDescriptorProto_TYPE_SINT32:
				p.P(`var v int32`)
				p.decodeVarint("v", "int32")
				// Zigzag decoding, performed in the generated code.
				p.P(`v = int32((uint32(v) >> 1) ^ uint32(((v&1)<<31)>>31))`)
				store("v", "")
			case descriptor.FieldDescriptorProto_TYPE_SINT64:
				p.P(`var v uint64`)
				p.decodeVarint("v", "uint64")
				// Zigzag decoding, performed in the generated code.
				p.P(`v = (v >> 1) ^ uint64((int64(v&1)<<63)>>63)`)
				store(`int64(v)`, "v2")
			default:
				panic("not implemented")
			}
			if packed {
				closeBlock() // for index < postIndex
			}
		}
		p.Out()
		p.P(`default:`)
		p.In()

		extendable := message.DescriptorProto.HasExtension()
		if extendable {
			c := []string{}
			for _, erange := range message.GetExtensionRange() {
				c = append(c, `((fieldNum >= `+strconv.Itoa(int(erange.GetStart()))+") && (fieldNum<"+strconv.Itoa(int(erange.GetEnd()))+`))`)
			}
			p.P(`if `, strings.Join(c, "||"), `{`)
			p.In()
			emitRewindAndSkip()
			p.P(`if m.XXX_extensions == nil {`)
			p.In()
			p.P(`m.XXX_extensions = make(map[int32]`, protoPkg.Use(), `.Extension)`)
			closeBlock()
			emitEOFGuard(`(index + skippy) > l`)
			p.P(`m.XXX_extensions[int32(fieldNum)] = `, protoPkg.Use(), `.NewExtension(data[index:index+skippy])`)
			p.P(`index += skippy`)
			p.Out()
			p.P(`} else {`)
			p.In()
		}

		emitRewindAndSkip()
		emitEOFGuard(`(index + skippy) > l`)
		p.P(`m.XXX_unrecognized = append(m.XXX_unrecognized, data[index:index+skippy]...)`)
		p.P(`index += skippy`)
		p.Out()
		if extendable {
			closeBlock() // closes the "else" opened for extendable messages
		}
		closeBlock() // switch fieldNum
		closeBlock() // for index < l
		p.P(`return nil`)
		closeBlock() // func
	}
}