func (p *unmarshal) Generate(file *generator.FileDescriptor) { p.PluginImports = generator.NewPluginImports(p.Generator) p.atleastOne = false p.ioPkg = p.NewImport("io") p.mathPkg = p.NewImport("math") p.unsafePkg = p.NewImport("unsafe") fmtPkg := p.NewImport("fmt") protoPkg := p.NewImport("code.google.com/p/gogoprotobuf/proto") for _, message := range file.Messages() { ccTypeName := generator.CamelCaseSlice(message.TypeName()) if p.unsafe { if !gogoproto.IsUnsafeUnmarshaler(file.FileDescriptorProto, message.DescriptorProto) { continue } if gogoproto.IsUnmarshaler(file.FileDescriptorProto, message.DescriptorProto) { panic(fmt.Sprintf("unsafe_unmarshaler and unmarshaler enabled for %v", ccTypeName)) } } if !p.unsafe { if !gogoproto.IsUnmarshaler(file.FileDescriptorProto, message.DescriptorProto) { continue } if gogoproto.IsUnsafeUnmarshaler(file.FileDescriptorProto, message.DescriptorProto) { panic(fmt.Sprintf("unsafe_unmarshaler and unmarshaler enabled for %v", ccTypeName)) } } p.atleastOne = true p.P(`func (m *`, ccTypeName, `) Unmarshal(data []byte) error {`) p.In() p.P(`l := len(data)`) p.P(`index := 0`) p.P(`for index < l {`) p.In() p.P(`var wire uint64`) p.decodeVarint("wire", "uint64") p.P(`fieldNum := int32(wire >> 3)`) if len(message.Field) > 0 { p.P(`wireType := int(wire & 0x7)`) } p.P(`switch fieldNum {`) p.In() for _, field := range message.Field { fieldname := p.GetFieldName(message, field) packed := field.IsPacked() p.P(`case `, strconv.Itoa(int(field.GetNumber())), `:`) p.In() wireType := field.WireType() if packed { p.P(`if wireType == `, strconv.Itoa(proto.WireBytes), `{`) p.In() p.P(`var packedLen int`) p.decodeVarint("packedLen", "int") p.P(`postIndex := index + packedLen`) p.P(`if postIndex > l {`) p.In() p.P(`return `, p.ioPkg.Use(), `.ErrUnexpectedEOF`) p.Out() p.P(`}`) p.P(`for index < postIndex {`) p.In() p.field(field, fieldname) p.Out() p.P(`}`) p.Out() p.P(`} else if wireType == `, strconv.Itoa(wireType), `{`) p.In() p.field(field, fieldname) p.Out() p.P(`} else {`) p.In() p.P(`return ` + fmtPkg.Use() + `.Errorf("proto: wrong wireType = %d for field ` + fieldname + `", wireType)`) p.Out() p.P(`}`) } else { p.P(`if wireType != `, strconv.Itoa(wireType), `{`) p.In() p.P(`return ` + fmtPkg.Use() + `.Errorf("proto: wrong wireType = %d for field ` + fieldname + `", wireType)`) p.Out() p.P(`}`) p.field(field, fieldname) } } p.Out() p.P(`default:`) p.In() if message.DescriptorProto.HasExtension() { c := []string{} for _, erange := range message.GetExtensionRange() { c = append(c, `((fieldNum >= `+strconv.Itoa(int(erange.GetStart()))+") && (fieldNum<"+strconv.Itoa(int(erange.GetEnd()))+`))`) } p.P(`if `, strings.Join(c, "||"), `{`) p.In() p.P(`var sizeOfWire int`) p.P(`for {`) p.In() p.P(`sizeOfWire++`) p.P(`wire >>= 7`) p.P(`if wire == 0 {`) p.In() p.P(`break`) p.Out() p.P(`}`) p.Out() p.P(`}`) p.P(`index-=sizeOfWire`) p.P(`skippy, err := `, protoPkg.Use(), `.Skip(data[index:])`) p.P(`if err != nil {`) p.In() p.P(`return err`) p.Out() p.P(`}`) p.P(`if (index + skippy) > l {`) p.In() p.P(`return `, p.ioPkg.Use(), `.ErrUnexpectedEOF`) p.Out() p.P(`}`) if gogoproto.HasExtensionsMap(file.FileDescriptorProto, message.DescriptorProto) { p.P(`if m.XXX_extensions == nil {`) p.In() p.P(`m.XXX_extensions = make(map[int32]`, protoPkg.Use(), `.Extension)`) p.Out() p.P(`}`) p.P(`m.XXX_extensions[int32(fieldNum)] = `, protoPkg.Use(), `.NewExtension(data[index:index+skippy])`) } else { p.P(`m.XXX_extensions = append(m.XXX_extensions, data[index:index+skippy]...)`) } p.P(`index += skippy`) p.Out() p.P(`} else {`) p.In() } p.P(`var sizeOfWire int`) p.P(`for {`) p.In() p.P(`sizeOfWire++`) p.P(`wire >>= 7`) p.P(`if wire == 0 {`) p.In() p.P(`break`) p.Out() p.P(`}`) p.Out() p.P(`}`) p.P(`index-=sizeOfWire`) p.P(`skippy, err := `, protoPkg.Use(), `.Skip(data[index:])`) p.P(`if err != nil {`) p.In() p.P(`return err`) p.Out() p.P(`}`) p.P(`if (index + skippy) > l {`) p.In() p.P(`return `, p.ioPkg.Use(), `.ErrUnexpectedEOF`) p.Out() p.P(`}`) p.P(`m.XXX_unrecognized = append(m.XXX_unrecognized, data[index:index+skippy]...)`) p.P(`index += skippy`) p.Out() if message.DescriptorProto.HasExtension() { p.Out() p.P(`}`) } p.Out() p.P(`}`) p.Out() p.P(`}`) p.P(`return nil`) p.Out() p.P(`}`) } if !p.atleastOne { return } }
func (p *unmarshal) Generate(file *generator.FileDescriptor) { p.PluginImports = generator.NewPluginImports(p.Generator) p.atleastOne = false p.ioPkg = p.NewImport("io") p.unsafePkg = p.NewImport("unsafe") protoPkg := p.NewImport("code.google.com/p/gogoprotobuf/proto") for _, message := range file.Messages() { if !gogoproto.IsUnsafeUnmarshaler(file.FileDescriptorProto, message.DescriptorProto) { continue } ccTypeName := generator.CamelCaseSlice(message.TypeName()) if gogoproto.IsUnmarshaler(file.FileDescriptorProto, message.DescriptorProto) { panic(fmt.Sprintf("unsafe_unmarshaler and unmarshaler enabled for %v", ccTypeName)) } p.atleastOne = true p.P(`func (m *`, ccTypeName, `) Unmarshal(data []byte) error {`) p.In() p.P(`l := len(data)`) p.P(`index := 0`) p.P(`for index < l {`) p.In() p.P(`var wire uint64`) p.decodeVarint("wire", "uint64") p.P(`fieldNum := int32(wire >> 3)`) if len(message.Field) > 0 { p.P(`wireType := int(wire & 0x7)`) } p.P(`switch fieldNum {`) p.In() for _, field := range message.Field { fieldname := generator.CamelCase(*field.Name) repeated := field.IsRepeated() nullable := gogoproto.IsNullable(field) packed := field.IsPacked() p.P(`case `, strconv.Itoa(int(field.GetNumber())), `:`) p.In() wireType := field.WireType() if packed { p.P(`if wireType != `, strconv.Itoa(proto.WireBytes), `{`) p.In() p.P(`return proto.ErrWrongType`) p.Out() p.P(`}`) p.P(`var packedLen int`) p.decodeVarint("packedLen", "int") p.P(`postIndex := index + packedLen`) p.P(`if postIndex > l {`) p.In() p.P(`return `, p.ioPkg.Use(), `.ErrUnexpectedEOF`) p.Out() p.P(`}`) p.P(`for index < postIndex {`) p.In() } else { p.P(`if wireType != `, strconv.Itoa(wireType), `{`) p.In() p.P(`return proto.ErrWrongType`) p.Out() p.P(`}`) } switch *field.Type { case descriptor.FieldDescriptorProto_TYPE_DOUBLE: if repeated { p.P(`var v float64`) p.unsafeFixed64("v", "float64") p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`) } else if nullable { p.P(`var v float64`) p.unsafeFixed64("v", "float64") p.P(`m.`, fieldname, ` = &v`) } else { p.unsafeFixed64(`m.`+fieldname, "float64") } case descriptor.FieldDescriptorProto_TYPE_FLOAT: if repeated { p.P(`var v float32`) p.unsafeFixed32("v", "float32") p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`) } else if nullable { p.P(`var v float32`) p.unsafeFixed32("v", "float32") p.P(`m.`, fieldname, ` = &v`) } else { p.unsafeFixed32("m."+fieldname, "float32") } case descriptor.FieldDescriptorProto_TYPE_INT64: if repeated { p.P(`var v int64`) p.decodeVarint("v", "int64") p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`) } else if nullable { p.P(`var v int64`) p.decodeVarint("v", "int64") p.P(`m.`, fieldname, ` = &v`) } else { p.decodeVarint("m."+fieldname, "int64") } case descriptor.FieldDescriptorProto_TYPE_UINT64: if repeated { p.P(`var v uint64`) p.decodeVarint("v", "uint64") p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`) } else if nullable { p.P(`var v uint64`) p.decodeVarint("v", "uint64") p.P(`m.`, fieldname, ` = &v`) } else { p.decodeVarint("m."+fieldname, "uint64") } case descriptor.FieldDescriptorProto_TYPE_INT32: if repeated { p.P(`var v int32`) p.decodeVarint("v", "int32") p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`) } else if nullable { p.P(`var v int32`) p.decodeVarint("v", "int32") p.P(`m.`, fieldname, ` = &v`) } else { p.decodeVarint("m."+fieldname, "int32") } case descriptor.FieldDescriptorProto_TYPE_FIXED64: if repeated { p.P(`var v uint64`) p.unsafeFixed64("v", "uint64") p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`) } else if nullable { p.P(`var v uint64`) p.unsafeFixed64("v", "uint64") p.P(`m.`, fieldname, ` = &v`) } else { p.unsafeFixed64("m."+fieldname, "uint64") } case descriptor.FieldDescriptorProto_TYPE_FIXED32: if repeated { p.P(`var v uint32`) p.unsafeFixed32("v", "uint32") p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`) } else if nullable { p.P(`var v uint32`) p.unsafeFixed32("v", "uint32") p.P(`m.`, fieldname, ` = &v`) } else { p.unsafeFixed32("m."+fieldname, "uint32") } case descriptor.FieldDescriptorProto_TYPE_BOOL: if repeated { p.P(`var v int`) p.decodeVarint("v", "int") p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, bool(v != 0))`) } else if nullable { p.P(`var v int`) p.decodeVarint("v", "int") p.P(`b := bool(v != 0)`) p.P(`m.`, fieldname, ` = &b`) } else { p.P(`var v int`) p.decodeVarint("v", "int") p.P(`m.`, fieldname, ` = bool(v != 0)`) } case descriptor.FieldDescriptorProto_TYPE_STRING: p.P(`var stringLen uint64`) p.decodeVarint("stringLen", "uint64") p.P(`postIndex := index + int(stringLen)`) p.P(`if postIndex > l {`) p.In() p.P(`return `, p.ioPkg.Use(), `.ErrUnexpectedEOF`) p.Out() p.P(`}`) if repeated { p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, string(data[index:postIndex]))`) } else if nullable { p.P(`s := string(data[index:postIndex])`) p.P(`m.`, fieldname, ` = &s`) } else { p.P(`m.`, fieldname, ` = string(data[index:postIndex])`) } p.P(`index = postIndex`) case descriptor.FieldDescriptorProto_TYPE_GROUP: panic(fmt.Errorf("unmarshaler does not support group %v", fieldname)) case descriptor.FieldDescriptorProto_TYPE_MESSAGE: desc := p.ObjectNamed(field.GetTypeName()) msgname := p.TypeName(desc) msgnames := strings.Split(msgname, ".") typeName := msgnames[len(msgnames)-1] if gogoproto.IsEmbed(field) { fieldname = typeName } p.P(`var msglen int`) p.decodeVarint("msglen", "int") p.P(`postIndex := index + msglen`) p.P(`if postIndex > l {`) p.In() p.P(`return `, p.ioPkg.Use(), `.ErrUnexpectedEOF`) p.Out() p.P(`}`) if repeated { if nullable { p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, &`, msgname, `{})`) } else { p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, `, msgname, `{})`) } p.P(`m.`, fieldname, `[len(m.`, fieldname, `)-1].Unmarshal(data[index:postIndex])`) } else if nullable { p.P(`if m.`, fieldname, ` == nil {`) p.In() p.P(`m.`, fieldname, ` = &`, msgname, `{}`) p.Out() p.P(`}`) p.P(`if err := m.`, fieldname, `.Unmarshal(data[index:postIndex]); err != nil {`) p.In() p.P(`return err`) p.Out() p.P(`}`) } else { p.P(`if err := m.`, fieldname, `.Unmarshal(data[index:postIndex]); err != nil {`) p.In() p.P(`return err`) p.Out() p.P(`}`) } p.P(`index = postIndex`) case descriptor.FieldDescriptorProto_TYPE_BYTES: p.P(`var byteLen int`) p.decodeVarint("byteLen", "int") p.P(`postIndex := index + byteLen`) p.P(`if postIndex > l {`) p.In() p.P(`return `, p.ioPkg.Use(), `.ErrUnexpectedEOF`) p.Out() p.P(`}`) if !gogoproto.IsCustomType(field) { if repeated { p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, make([]byte, postIndex-index))`) p.P(`copy(m.`, fieldname, `[len(m.`, fieldname, `)-1], data[index:postIndex])`) } else if nullable { p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, data[index:postIndex]...)`) } else { p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, data[index:postIndex]...)`) } } else { _, ctyp, err := generator.GetCustomType(field) if err != nil { panic(err) } if repeated { p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, `, ctyp, `{})`) p.P(`m.`, fieldname, `[len(m.`, fieldname, `)-1].Unmarshal(data[index:postIndex])`) } else if nullable { p.P(`m.`, fieldname, ` = &`, ctyp, `{}`) p.P(`if err := m.`, fieldname, `.Unmarshal(data[index:postIndex]); err != nil {`) p.In() p.P(`return err`) p.Out() p.P(`}`) } else { p.P(`if err := m.`, fieldname, `.Unmarshal(data[index:postIndex]); err != nil {`) p.In() p.P(`return err`) p.Out() p.P(`}`) } } p.P(`index = postIndex`) case descriptor.FieldDescriptorProto_TYPE_UINT32: if repeated { p.P(`var v uint32`) p.decodeVarint("v", "uint32") p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`) } else if nullable { p.P(`var v uint32`) p.decodeVarint("v", "uint32") p.P(`m.`, fieldname, ` = &v`) } else { p.decodeVarint("m."+fieldname, "uint32") } case descriptor.FieldDescriptorProto_TYPE_ENUM: typName := p.TypeName(p.ObjectNamed(field.GetTypeName())) if repeated { p.P(`var v `, typName) p.decodeVarint("v", typName) p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`) } else if nullable { p.P(`var v `, typName) p.decodeVarint("v", typName) p.P(`m.`, fieldname, ` = &v`) } else { p.decodeVarint("m."+fieldname, typName) } case descriptor.FieldDescriptorProto_TYPE_SFIXED32: if repeated { p.P(`var v int32`) p.unsafeFixed32("v", "int32") p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`) } else if nullable { p.P(`var v int32`) p.unsafeFixed32("v", "int32") p.P(`m.`, fieldname, ` = &v`) } else { p.unsafeFixed32("m."+fieldname, "int32") } case descriptor.FieldDescriptorProto_TYPE_SFIXED64: if repeated { p.P(`var v int64`) p.unsafeFixed64("v", "int64") p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`) } else if nullable { p.P(`var v int64`) p.unsafeFixed64("v", "int64") p.P(`m.`, fieldname, ` = &v`) } else { p.unsafeFixed64("m."+fieldname, "int64") } case descriptor.FieldDescriptorProto_TYPE_SINT32: p.P(`var v int32`) p.decodeVarint("v", "int32") p.P(`v = int32((uint32(v) >> 1) ^ uint32(((v&1)<<31)>>31))`) if repeated { p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`) } else if nullable { p.P(`m.`, fieldname, ` = &v`) } else { p.P(`m.`, fieldname, ` = v`) } case descriptor.FieldDescriptorProto_TYPE_SINT64: p.P(`var v uint64`) p.decodeVarint("v", "uint64") p.P(`v = (v >> 1) ^ uint64((int64(v&1)<<63)>>63)`) if repeated { p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, int64(v))`) } else if nullable { p.P(`v2 := int64(v)`) p.P(`m.`, fieldname, ` = &v2`) } else { p.P(`m.`, fieldname, ` = int64(v)`) } default: panic("not implemented") } if packed { p.Out() p.P(`}`) } } p.Out() p.P(`default:`) p.In() if message.DescriptorProto.HasExtension() { c := []string{} for _, erange := range message.GetExtensionRange() { c = append(c, `((fieldNum >= `+strconv.Itoa(int(erange.GetStart()))+") && (fieldNum<"+strconv.Itoa(int(erange.GetEnd()))+`))`) } p.P(`if `, strings.Join(c, "||"), `{`) p.In() p.P(`var sizeOfWire int`) p.P(`for {`) p.In() p.P(`sizeOfWire++`) p.P(`wire >>= 7`) p.P(`if wire == 0 {`) p.In() p.P(`break`) p.Out() p.P(`}`) p.Out() p.P(`}`) p.P(`index-=sizeOfWire`) p.P(`skippy, err := `, protoPkg.Use(), `.Skip(data[index:])`) p.P(`if err != nil {`) p.In() p.P(`return err`) p.Out() p.P(`}`) p.P(`if m.XXX_extensions == nil {`) p.In() p.P(`m.XXX_extensions = make(map[int32]`, protoPkg.Use(), `.Extension)`) p.Out() p.P(`}`) p.P(`if (index + skippy) > l {`) p.In() p.P(`return `, p.ioPkg.Use(), `.ErrUnexpectedEOF`) p.Out() p.P(`}`) p.P(`m.XXX_extensions[int32(fieldNum)] = `, protoPkg.Use(), `.NewExtension(data[index:index+skippy])`) p.P(`index += skippy`) p.Out() p.P(`} else {`) p.In() } p.P(`var sizeOfWire int`) p.P(`for {`) p.In() p.P(`sizeOfWire++`) p.P(`wire >>= 7`) p.P(`if wire == 0 {`) p.In() p.P(`break`) p.Out() p.P(`}`) p.Out() p.P(`}`) p.P(`index-=sizeOfWire`) p.P(`skippy, err := `, protoPkg.Use(), `.Skip(data[index:])`) p.P(`if err != nil {`) p.In() p.P(`return err`) p.Out() p.P(`}`) p.P(`if (index + skippy) > l {`) p.In() p.P(`return `, p.ioPkg.Use(), `.ErrUnexpectedEOF`) p.Out() p.P(`}`) p.P(`m.XXX_unrecognized = append(m.XXX_unrecognized, data[index:index+skippy]...)`) p.P(`index += skippy`) p.Out() if message.DescriptorProto.HasExtension() { p.Out() p.P(`}`) } p.Out() p.P(`}`) p.Out() p.P(`}`) p.P(`return nil`) p.Out() p.P(`}`) } if !p.atleastOne { return } }