func TestGetExtensionStability(t *testing.T) { check := func(m *pb.MyMessage) bool { ext1, err := proto.GetExtension(m, pb.E_Ext_More) if err != nil { t.Fatalf("GetExtension() failed: %s", err) } ext2, err := proto.GetExtension(m, pb.E_Ext_More) if err != nil { t.Fatalf("GetExtension() failed: %s", err) } return ext1 == ext2 } msg := &pb.MyMessage{Count: proto.Int32(4)} ext0 := &pb.Ext{} if err := proto.SetExtension(msg, pb.E_Ext_More, ext0); err != nil { t.Fatalf("Could not set ext1: %s", ext0) } if !check(msg) { t.Errorf("GetExtension() not stable before marshaling") } bb, err := proto.Marshal(msg) if err != nil { t.Fatalf("Marshal() failed: %s", err) } msg1 := &pb.MyMessage{} err = proto.Unmarshal(bb, msg1) if err != nil { t.Fatalf("Unmarshal() failed: %s", err) } if !check(msg1) { t.Errorf("GetExtension() not stable after unmarshaling") } }
// Take n tokens from bucket t, key k func (client *Client) Take(t string, k string, n int32) (response *limitd.Response, takeResponse *limitd.TakeResponse, err error) { requestID := uniuri.New() request := &limitd.Request{ Id: proto.String(requestID), Method: limitd.Request_TAKE.Enum(), Type: proto.String(t), Key: proto.String(k), Count: proto.Int32(n), } // goprotobuf.EncodeVarint followed by proto.Marshal responseChan := make(chan *limitd.Response) client.PendingRequests[requestID] = responseChan data, _ := proto.Marshal(request) data = append(proto.EncodeVarint(uint64(len(data))), data...) client.Conn.Write(data) response = <-responseChan takeR, err := proto.GetExtension(response, limitd.E_TakeResponse_Response) if err != nil { return } if takeResponseCasted, ok := takeR.(*limitd.TakeResponse); ok { takeResponse = takeResponseCasted } return }
// SetPackageNames sets the package name for this run. // The package name must agree across all files being generated. func (g *Generator) SetPackageNames() { for _, f := range g.allFiles { pkgName := "" if f.Options != nil { extMap := f.GetOptions().ExtensionMap() if _, ok := extMap[50000]; ok { itf, err := proto.GetExtension( f.GetOptions(), javascript_package.E_JavascriptPackage) if err == nil { pkgName = *itf.(*string) } } } if pkgName == "" { pkgName = f.GetPackage() } uniquePackageName[f.FileDescriptorProto] = pkgName registerUniquePackageName(pkgName, f) if f == g.genFiles[0] { g.packageName = pkgName } } }
func fillTreeWithMethod(tree *tree, key string, proto *descriptor.MethodDescriptorProto, loc string, locs map[string]*descriptor.SourceCodeInfo_Location) *method { key = fmt.Sprintf("%s.%s", key, proto.GetName()) tree.methods[key] = &method{key: key, comment: getComment(loc, locs), MethodDescriptorProto: proto} if input, ok := tree.messages[proto.GetInputType()]; ok { tree.methods[key].input = input } if proto.GetClientStreaming() { tree.methods[key].inputStream = true } if output, ok := tree.messages[proto.GetOutputType()]; ok { tree.methods[key].output = output } if proto.GetServerStreaming() { tree.methods[key].outputStream = true } if proto.Options != nil && protobuf.HasExtension(proto.Options, gateway.E_Http) { ext, err := protobuf.GetExtension(proto.Options, gateway.E_Http) if err == nil { if opts, ok := ext.(*gateway.HttpRule); ok { if endpoint := newEndpoint(opts); endpoint != nil { tree.methods[key].endpoints = append(tree.methods[key].endpoints, endpoint) } for _, opts := range opts.AdditionalBindings { if endpoint := newEndpoint(opts); endpoint != nil { tree.methods[key].endpoints = append(tree.methods[key].endpoints, endpoint) } } } } } return tree.methods[key] }
func newServerContactReqChan(conn *ricochetConn, msg *packet.OpenChannel) (*contactReqChan, error) { ch := new(contactReqChan) ch.conn = conn ch.chanID = (uint16)(msg.GetChannelIdentifier()) ch.reqData = new(ContactRequest) ext, err := proto.GetExtension(msg, packet.E_ContactRequest) if err != nil { return nil, err } if ext == nil { return nil, fmt.Errorf("server: missing ContactRequest extension") } req := ext.(*packet.ContactRequest) ch.reqData.Hostname = conn.hostname ch.reqData.MyNickname = req.GetNickname() if len(ch.reqData.MyNickname) > ContactReqNicknameMaxCharacters { return nil, fmt.Errorf("server: ContactRequest nickname too long") } ch.reqData.Message = req.GetMessageText() if len(ch.reqData.Message) > ContactReqMessageMaxCharacters { return nil, fmt.Errorf("server: ContactRequest message too long") } return ch, nil }
func TestExtensionsRoundTrip(t *testing.T) { msg := &pb.MyMessage{} ext1 := &pb.Ext{ Data: proto.String("hi"), } ext2 := &pb.Ext{ Data: proto.String("there"), } exists := proto.HasExtension(msg, pb.E_Ext_More) if exists { t.Error("Extension More present unexpectedly") } if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil { t.Error(err) } if err := proto.SetExtension(msg, pb.E_Ext_More, ext2); err != nil { t.Error(err) } e, err := proto.GetExtension(msg, pb.E_Ext_More) if err != nil { t.Error(err) } x, ok := e.(*pb.Ext) if !ok { t.Errorf("e has type %T, expected testdata.Ext", e) } else if *x.Data != "there" { t.Errorf("SetExtension failed to overwrite, got %+v, not 'there'", x) } proto.ClearExtension(msg, pb.E_Ext_More) if _, err = proto.GetExtension(msg, pb.E_Ext_More); err != proto.ErrMissingExtension { t.Errorf("got %v, expected ErrMissingExtension", e) } if _, err := proto.GetExtension(msg, pb.E_X215); err == nil { t.Error("expected bad extension error, got nil") } if err := proto.SetExtension(msg, pb.E_X215, 12); err == nil { t.Error("expected extension err") } if err := proto.SetExtension(msg, pb.E_Ext_More, 12); err == nil { t.Error("expected some sort of type mismatch error, got nil") } }
func int64AsNumber(field *descriptor.FieldDescriptorProto) bool { var int64Encoding int64_encoding.Int64Encoding if field.Options != nil { extMap := field.GetOptions().ExtensionMap() if _, ok := extMap[50001]; ok { itf, err := proto.GetExtension(field.GetOptions(), int64_encoding.E_Jstype) if err == nil { int64Encoding = *itf.(*int64_encoding.Int64Encoding) } } } return int64Encoding == int64_encoding.Int64Encoding_JS_NUMBER }
func newServerAuthHSChan(conn *ricochetConn, msg *packet.OpenChannel) (*authHSChan, error) { ch := new(authHSChan) ch.conn = conn ch.chanID = (uint16)(msg.GetChannelIdentifier()) ext, err := proto.GetExtension(msg, packet.E_ClientCookie) if err != nil { return nil, err } ch.clientCookie = ext.([]byte) if len(ch.clientCookie) != authHiddenServiceCookieSize { return nil, fmt.Errorf("invalid AuthHiddenService client_cookie") } return ch, nil }
func convertFile(file *descriptor.FileDescriptorProto) ([]*plugin.CodeGeneratorResponse_File, error) { name := path.Base(file.GetName()) pkg, ok := globalPkg.relativelyLookupPackage(file.GetPackage()) if !ok { return nil, fmt.Errorf("no such package found: %s", file.GetPackage()) } response := []*plugin.CodeGeneratorResponse_File{} for _, msg := range file.GetMessageType() { options := msg.GetOptions() if options == nil { continue } if !proto.HasExtension(options, E_TableName) { continue } optionValue, err := proto.GetExtension(options, E_TableName) if err != nil { return nil, err } tableName := *optionValue.(*string) if len(tableName) == 0 { return nil, fmt.Errorf("table name of %s cannot be empty", msg.GetName()) } glog.V(2).Info("Generating schema for a message type ", msg.GetName()) schema, err := convertMessageType(pkg, msg) if err != nil { glog.Errorf("Failed to convert %s: %v", name, err) return nil, err } jsonSchema, err := json.Marshal(schema) if err != nil { glog.Error("Failed to encode schema", err) return nil, err } resFile := &plugin.CodeGeneratorResponse_File{ Name: proto.String(fmt.Sprintf("%s/%s.schema", strings.Replace(file.GetPackage(), ".", "/", -1), tableName)), Content: proto.String(string(jsonSchema)), } response = append(response, resFile) } return response, nil }
func extractAPIOptions(meth *descriptor.MethodDescriptorProto) (*options.HttpRule, error) { if meth.Options == nil { return nil, nil } if !proto.HasExtension(meth.Options, options.E_Http) { return nil, nil } ext, err := proto.GetExtension(meth.Options, options.E_Http) if err != nil { return nil, err } opts, ok := ext.(*options.HttpRule) if !ok { return nil, fmt.Errorf("extension is %T; want an HttpRule", ext) } return opts, nil }
func (ch *contactReqChan) onChannelResult(msg *packet.ChannelResult) error { if ch.conn.isServer { return fmt.Errorf("opened contact req channel to client") } if ch.state != chanStateOpening { return fmt.Errorf("received spurious ContactRequest ChannelResult") } ch.state = chanStateOpen // If this routine was called, the channel WAS opened, without incident. // Extract the response, and take action accordingly. ext, err := proto.GetExtension(msg, packet.E_Response) if err != nil { return err } resp := ext.(*packet.ContactRequestResponse) return ch.onResponse(resp) }
func (ch *authHSChan) onChannelResult(msg *packet.ChannelResult) error { if ch.conn.isServer { return fmt.Errorf("opened auth channel to client") } if ch.state != chanStateOpening { return fmt.Errorf("received spurious AuthHiddenService ChannelResult") } ch.state = chanStateOpen // If this routine was called, the channel WAS opened, without incident. // Extract the server cookie, and send the proof. ext, err := proto.GetExtension(msg, packet.E_ServerCookie) if err != nil { return err } ch.serverCookie = ext.([]byte) if len(ch.serverCookie) != authHiddenServiceCookieSize { return fmt.Errorf("invalid AuthHiddenService server_cookie") } // Encode the public key to DER. pkDER, err := pkcs1.EncodePublicKeyDER(&ch.conn.endpoint.privateKey.PublicKey) if err != nil { return err } // Calculate the proof. proof := ch.calculateProof(ch.conn.endpoint.hostname, ch.conn.hostname) // Sign the proof. sig, err := rsa.SignPKCS1v15(rand.Reader, ch.conn.endpoint.privateKey, crypto.SHA256, proof) if err != nil { return err } return ch.sendProof(pkDER, sig) }
// Extract extracts a compilation from the specified extra action info. func (c *Config) Extract(ctx context.Context, info *eapb.ExtraActionInfo) (*kindex.Compilation, error) { si, err := proto.GetExtension(info, eapb.E_SpawnInfo_SpawnInfo) if err != nil { return nil, fmt.Errorf("extra action does not have SpawnInfo: %v", err) } spawnInfo := si.(*eapb.SpawnInfo) // Verify that the mnemonic is what we expect. if m := info.GetMnemonic(); m != c.Mnemonic && c.Mnemonic != "" { return nil, fmt.Errorf("mnemonic does not match %q ≠ %q", m, c.Mnemonic) } // Construct the basic compilation. toolArgs := extractToolArgs(spawnInfo.Argument) log.Printf("Extracting compilation for %q", info.GetOwner()) cu := &kindex.Compilation{ Proto: &apb.CompilationUnit{ VName: &spb.VName{ Language: govname.Language, Corpus: c.Corpus, Signature: info.GetOwner(), }, Argument: toolArgs.fullArgs, SourceFile: toolArgs.sources, WorkingDirectory: c.Root, Environment: []*apb.CompilationUnit_Env{{ Name: "GOROOT", Value: toolArgs.goRoot, }}, }, } // Load and populate file contents and required inputs. Do this in two // passes: First scan the inputs and filter out which ones we actually want // to keep; then load their contents concurrently. var wantPaths []string for _, in := range spawnInfo.InputFile { if toolArgs.wantInput(in) { wantPaths = append(wantPaths, in) cu.Files = append(cu.Files, nil) cu.Proto.RequiredInput = append(cu.Proto.RequiredInput, nil) } } // Fetch concurrently. Each element of the proto slices is accessed by a // single goroutine corresponding to its index. 
log.Printf("Reading file contents for %d required inputs", len(wantPaths)) start := time.Now() var wg sync.WaitGroup for i, path := range wantPaths { i, path := i, path wg.Add(1) go func() { defer wg.Done() fd, err := c.readFile(ctx, path) if err != nil { log.Fatalf("Unable to read input %q: %v", path, err) } cu.Files[i] = fd cu.Proto.RequiredInput[i] = c.fileDataToInfo(fd) }() } wg.Wait() log.Printf("Finished reading required inputs [%v elapsed]", time.Since(start)) // Set the output path. Although the SpawnInfo has room for multiple // outputs, we expect only one to be set in practice. It's harmless if // there are more, though, so don't fail for that. for _, out := range spawnInfo.OutputFile { cu.Proto.OutputKey = out break } // Capture environment variables. for _, evar := range spawnInfo.Variable { if evar.GetName() == "PATH" { // TODO(fromberger): Perhaps whitelist or blacklist which // environment variables to capture here. continue } cu.Proto.Environment = append(cu.Proto.Environment, &apb.CompilationUnit_Env{ Name: evar.GetName(), Value: evar.GetValue(), }) } return cu, nil }
// ProcessConnection starts a blocking process loop which continually waits for // new messages to arrive from the connection and uses the given RicochetService // to process them. func (r *Ricochet) processConnection(oc *OpenConnection, service RicochetService) { service.OnConnect(oc) defer service.OnDisconnect(oc) for { if oc.Closed { return } packet, err := r.rni.RecvRicochetPacket(oc.conn) if err != nil { oc.Close() return } if len(packet.Data) == 0 { service.OnChannelClosed(oc, packet.Channel) continue } if packet.Channel == 0 { res := new(Protocol_Data_Control.Packet) err := proto.Unmarshal(packet.Data[:], res) if err != nil { service.OnGenericError(oc, packet.Channel) continue } if res.GetOpenChannel() != nil { opm := res.GetOpenChannel() if oc.GetChannelType(opm.GetChannelIdentifier()) != "none" { // Channel is already in use. service.OnBadUsageError(oc, opm.GetChannelIdentifier()) continue } // If I am a Client, the server can only open even numbered channels if oc.Client && opm.GetChannelIdentifier()%2 != 0 { service.OnBadUsageError(oc, opm.GetChannelIdentifier()) continue } // If I am a Server, the client can only open odd numbered channels if !oc.Client && opm.GetChannelIdentifier()%2 != 1 { service.OnBadUsageError(oc, opm.GetChannelIdentifier()) continue } switch opm.GetChannelType() { case "im.ricochet.auth.hidden-service": if oc.Client { // Servers are authed by default and can't auth with hidden-service service.OnBadUsageError(oc, opm.GetChannelIdentifier()) } else if oc.IsAuthed { // Can't auth if already authed service.OnBadUsageError(oc, opm.GetChannelIdentifier()) } else if oc.HasChannel("im.ricochet.auth.hidden-service") { // Can't open more than 1 auth channel service.OnBadUsageError(oc, opm.GetChannelIdentifier()) } else { clientCookie, err := proto.GetExtension(opm, Protocol_Data_AuthHiddenService.E_ClientCookie) if err == nil { clientCookieB := [16]byte{} copy(clientCookieB[:], clientCookie.([]byte)[:]) service.OnAuthenticationRequest(oc, 
opm.GetChannelIdentifier(), clientCookieB) } else { // Must include Client Cookie service.OnBadUsageError(oc, opm.GetChannelIdentifier()) } } case "im.ricochet.chat": if !oc.IsAuthed { // Can't open chat channel if not authorized service.OnUnauthorizedError(oc, opm.GetChannelIdentifier()) } else if !service.IsKnownContact(oc.OtherHostname) { // Can't open chat channel if not a known contact service.OnUnauthorizedError(oc, opm.GetChannelIdentifier()) } else { service.OnOpenChannelRequest(oc, opm.GetChannelIdentifier(), "im.ricochet.chat") } case "im.ricochet.contact.request": if oc.Client { // Servers are not allowed to send contact requests service.OnBadUsageError(oc, opm.GetChannelIdentifier()) } else if !oc.IsAuthed { // Can't open a contact channel if not authed service.OnUnauthorizedError(oc, opm.GetChannelIdentifier()) } else if oc.HasChannel("im.ricochet.contact.request") { // Only 1 contact channel is allowed to be open at a time service.OnBadUsageError(oc, opm.GetChannelIdentifier()) } else { contactRequestI, err := proto.GetExtension(opm, Protocol_Data_ContactRequest.E_ContactRequest) if err == nil { contactRequest, check := contactRequestI.(*Protocol_Data_ContactRequest.ContactRequest) if check { service.OnContactRequest(oc, opm.GetChannelIdentifier(), contactRequest.GetNickname(), contactRequest.GetMessageText()) break } } service.OnBadUsageError(oc, opm.GetChannelIdentifier()) } default: service.OnUnknownTypeError(oc, opm.GetChannelIdentifier()) } } else if res.GetChannelResult() != nil { crm := res.GetChannelResult() if crm.GetOpened() { switch oc.GetChannelType(crm.GetChannelIdentifier()) { case "im.ricochet.auth.hidden-service": serverCookie, err := proto.GetExtension(crm, Protocol_Data_AuthHiddenService.E_ServerCookie) if err == nil { serverCookieB := [16]byte{} copy(serverCookieB[:], serverCookie.([]byte)[:]) service.OnAuthenticationChallenge(oc, crm.GetChannelIdentifier(), serverCookieB) } else { service.OnBadUsageError(oc, 
crm.GetChannelIdentifier()) } case "im.ricochet.chat": service.OnOpenChannelRequestSuccess(oc, crm.GetChannelIdentifier()) case "im.ricochet.contact.request": responseI, err := proto.GetExtension(res.GetChannelResult(), Protocol_Data_ContactRequest.E_Response) if err == nil { response, check := responseI.(*Protocol_Data_ContactRequest.Response) if check { service.OnContactRequestAck(oc, crm.GetChannelIdentifier(), response.GetStatus().String()) break } } service.OnBadUsageError(oc, crm.GetChannelIdentifier()) default: service.OnBadUsageError(oc, crm.GetChannelIdentifier()) } } else { if oc.GetChannelType(crm.GetChannelIdentifier()) != "none" { service.OnFailedChannelOpen(oc, crm.GetChannelIdentifier(), crm.GetCommonError().String()) } else { oc.CloseChannel(crm.GetChannelIdentifier()) } } } else { // Unknown Message oc.CloseChannel(packet.Channel) } } else if oc.GetChannelType(packet.Channel) == "im.ricochet.auth.hidden-service" { res := new(Protocol_Data_AuthHiddenService.Packet) err := proto.Unmarshal(packet.Data[:], res) if err != nil { oc.CloseChannel(packet.Channel) continue } if res.GetProof() != nil && !oc.Client { // Only Clients Send Proofs service.OnAuthenticationProof(oc, packet.Channel, res.GetProof().GetPublicKey(), res.GetProof().GetSignature(), service.IsKnownContact(oc.OtherHostname)) } else if res.GetResult() != nil && oc.Client { // Only Servers Send Results service.OnAuthenticationResult(oc, packet.Channel, res.GetResult().GetAccepted(), res.GetResult().GetIsKnownContact()) } else { // If neither of the above are satisfied we just close the connection oc.Close() } } else if oc.GetChannelType(packet.Channel) == "im.ricochet.chat" { // NOTE: These auth checks should be redundant, however they // are included here for defense-in-depth if for some reason // a previously authed connection becomes untrusted / not known and // the state is not cleaned up. 
if !oc.IsAuthed { // Can't send chat messages if not authorized service.OnUnauthorizedError(oc, packet.Channel) } else if !service.IsKnownContact(oc.OtherHostname) { // Can't send chat message if not a known contact service.OnUnauthorizedError(oc, packet.Channel) } else { res := new(Protocol_Data_Chat.Packet) err := proto.Unmarshal(packet.Data[:], res) if err != nil { oc.CloseChannel(packet.Channel) continue } if res.GetChatMessage() != nil { service.OnChatMessage(oc, packet.Channel, int32(res.GetChatMessage().GetMessageId()), res.GetChatMessage().GetMessageText()) } else if res.GetChatAcknowledge() != nil { service.OnChatMessageAck(oc, packet.Channel, int32(res.GetChatMessage().GetMessageId())) } else { // If neither of the above are satisfied we just close the connection oc.Close() } } } else if oc.GetChannelType(packet.Channel) == "im.ricochet.contact.request" { // NOTE: These auth checks should be redundant, however they // are included here for defense-in-depth if for some reason // a previously authed connection becomes untrusted / not known and // the state is not cleaned up. if !oc.Client { // Clients are not allowed to send contact request responses service.OnBadUsageError(oc, packet.Channel) } else if !oc.IsAuthed { // Can't send a contact request if not authed service.OnBadUsageError(oc, packet.Channel) } else { res := new(Protocol_Data_ContactRequest.Response) err := proto.Unmarshal(packet.Data[:], res) log.Printf("%v", res) if err != nil { oc.CloseChannel(packet.Channel) continue } service.OnContactRequestAck(oc, packet.Channel, res.GetStatus().String()) } } else if oc.GetChannelType(packet.Channel) == "none" { // Invalid Channel Assignment oc.CloseChannel(packet.Channel) } else { oc.Close() } } }
func main() { msg := plugin.CodeGeneratorRequest{} buff, err := ioutil.ReadAll(os.Stdin) if err != nil { panic(err) } if err := proto.Unmarshal(buff, &msg); err != nil { panic(err) } ret := &plugin.CodeGeneratorResponse{} defer func() { buff, _ := proto.Marshal(ret) os.Stdout.Write(buff) }() param := msg.GetParameter() imports := map[string]string{} sources := map[string]string{} for _, p := range strings.Split(param, ",") { if len(p) == 0 { continue } if p[0] == 'M' { parts := strings.Split(p[1:], "=") imports[parts[0]] = parts[1] } } messages := map[string]message{} for _, file := range msg.GetProtoFile() { for _, msg := range file.GetMessageType() { m := message{} for _, f := range msg.GetField() { m.Fields = append(m.Fields, field{ ProtoName: f.GetName(), GoName: goise(f.GetName()), }) } messages["."+file.GetPackage()+"."+msg.GetName()] = m sources["."+file.GetPackage()+"."+msg.GetName()] = file.GetName() } } for _, file := range msg.GetProtoFile() { services := map[string]service{} goPackage := "main" if file.GetOptions() != nil { goPackage = file.GetOptions().GetGoPackage() } for _, svc := range file.GetService() { s := service{ GoName: goise(svc.GetName()), } for _, meth := range svc.GetMethod() { m := method{ GoName: goise(meth.GetName()), GoInputType: goise(meth.GetInputType()), Input: messages[meth.GetInputType()], InputType: meth.GetInputType(), } if meth.GetOptions() == nil { continue } if tmp, err := proto.GetExtension(meth.GetOptions(), google_api.E_Http); err == nil { http := tmp.(*google_api.HttpRule) if http.Get != "" { m.PathArgs = parsePath(messages, meth, http.Get) m.Path = http.Get m.Method = "GET" } if http.Put != "" { m.PathArgs = parsePath(messages, meth, http.Put) m.Path = http.Put m.Method = "PUT" } if http.Post != "" { m.PathArgs = parsePath(messages, meth, http.Post) m.Path = http.Post m.Method = "POST" } if http.Delete != "" { m.PathArgs = parsePath(messages, meth, http.Delete) m.Path = http.Delete m.Method = "DELETE" } if http.Body != 
"" { if m.Method == "PUT" || m.Method == "POST" { m.Body = http.Body m.GoBodyName = goise(http.Body) } else { log.Printf("WARN: Got http.body on non-put, non-post method.") } } } s.Methods = append(s.Methods, m) } if len(s.Methods) > 0 { services[svc.GetName()] = s } } if len(services) > 0 { fname := strings.Replace(file.GetName(), ".proto", ".pb.kit.go", 1) buff := bytes.NewBuffer(nil) imps := map[string]string{} impUsages := map[string]int{} for _, svc := range services { for _, meth := range svc.Methods { src := sources[meth.InputType] if src == file.GetName() { continue } impUsages[src] = impUsages[src] + 1 } } for k := range impUsages { if v, ok := imports[k]; ok { imps[k] = v continue } fmt.Fprintf(os.Stderr, "Import %q is unknown.", k) } fileTemplate.Execute(buff, struct { Package string Imports map[string]string Services map[string]service }{goPackage, imps, services}) data := buff.String() ret.File = append(ret.File, &plugin.CodeGeneratorResponse_File{ Name: &fname, Content: &data, }) } } }
// marshalObject writes a struct to the Writer.
//
// It writes a JSON object with one member per field of the struct v points
// to. Fields whose names start with "XXX_" are skipped, as are nil chan,
// func, interface, map, pointer and slice fields. The object is followed by
// any proto2 extensions of v that appear in proto.RegisteredExtensions(v), in
// ascending field-number order. When m.Indent is non-empty, a newline follows
// "{", and a newline plus indent precede "}". It returns early with any
// error from marshalField or GetExtension, and otherwise returns out.err.
func (m *Marshaler) marshalObject(out *errWriter, v proto.Message, indent string) error {
	out.write("{")
	if m.Indent != "" {
		out.write("\n")
	}

	s := reflect.ValueOf(v).Elem()
	firstField := true
	for i := 0; i < s.NumField(); i++ {
		value := s.Field(i)
		valueField := s.Type().Field(i)
		if strings.HasPrefix(valueField.Name, "XXX_") {
			continue
		}

		// TODO: proto3 objects should have default values omitted.

		// IsNil will panic on most value kinds.
		switch value.Kind() {
		case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
			if value.IsNil() {
				continue
			}
		}

		// Oneof fields need special handling.
		if valueField.Tag.Get("protobuf_oneof") != "" {
			// value is an interface containing &T{real_value}.
			sv := value.Elem().Elem() // interface -> *T -> T
			value = sv.Field(0)
			valueField = sv.Type().Field(0)
		}
		prop := jsonProperties(valueField)
		// Separator goes before every member except the first one written.
		if !firstField {
			m.writeSep(out)
		}
		if err := m.marshalField(out, prop, value, indent); err != nil {
			return err
		}
		firstField = false
	}

	// Handle proto2 extensions.
	if ep, ok := v.(extendableProto); ok {
		extensions := proto.RegisteredExtensions(v)
		extensionMap := ep.ExtensionMap()
		// Sort extensions for stable output.
		ids := make([]int32, 0, len(extensionMap))
		for id := range extensionMap {
			ids = append(ids, id)
		}
		sort.Sort(int32Slice(ids))
		for _, id := range ids {
			desc := extensions[id]
			if desc == nil {
				// unknown extension
				continue
			}
			ext, extErr := proto.GetExtension(ep, desc)
			if extErr != nil {
				return extErr
			}
			value := reflect.ValueOf(ext)
			var prop proto.Properties
			prop.Parse(desc.Tag)
			// NOTE(review): presumably marshalField uses OrigName as the JSON
			// key, giving "[<extension name>]" — confirm.
			prop.OrigName = fmt.Sprintf("[%s]", desc.Name)
			if !firstField {
				m.writeSep(out)
			}
			if err := m.marshalField(out, &prop, value, indent); err != nil {
				return err
			}
			firstField = false
		}

	}

	if m.Indent != "" {
		out.write("\n")
		out.write(indent)
	}
	out.write("}")
	return out.err
}
// marshalObject writes a struct to the Writer. func (m *Marshaler) marshalObject(out *errWriter, v proto.Message, indent string) error { s := reflect.ValueOf(v).Elem() // Handle well-known types. type wkt interface { XXX_WellKnownType() string } if wkt, ok := v.(wkt); ok { switch wkt.XXX_WellKnownType() { case "DoubleValue", "FloatValue", "Int64Value", "UInt64Value", "Int32Value", "UInt32Value", "BoolValue", "StringValue", "BytesValue": // "Wrappers use the same representation in JSON // as the wrapped primitive type, ..." sprop := proto.GetProperties(s.Type()) return m.marshalValue(out, sprop.Prop[0], s.Field(0), indent) case "Duration": // "Generated output always contains 3, 6, or 9 fractional digits, // depending on required precision." s, ns := s.Field(0).Int(), s.Field(1).Int() d := time.Duration(s)*time.Second + time.Duration(ns)*time.Nanosecond x := fmt.Sprintf("%.9f", d.Seconds()) x = strings.TrimSuffix(x, "000") x = strings.TrimSuffix(x, "000") out.write(`"`) out.write(x) out.write(`s"`) return out.err case "Timestamp": // "RFC 3339, where generated output will always be Z-normalized // and uses 3, 6 or 9 fractional digits." s, ns := s.Field(0).Int(), s.Field(1).Int() t := time.Unix(s, ns).UTC() // time.RFC3339Nano isn't exactly right (we need to get 3/6/9 fractional digits). x := t.Format("2006-01-02T15:04:05.000000000") x = strings.TrimSuffix(x, "000") x = strings.TrimSuffix(x, "000") out.write(`"`) out.write(x) out.write(`Z"`) return out.err } } out.write("{") if m.Indent != "" { out.write("\n") } firstField := true for i := 0; i < s.NumField(); i++ { value := s.Field(i) valueField := s.Type().Field(i) if strings.HasPrefix(valueField.Name, "XXX_") { continue } // IsNil will panic on most value kinds. 
switch value.Kind() { case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: if value.IsNil() { continue } } if !m.EmitDefaults { switch value.Kind() { case reflect.Bool: if !value.Bool() { continue } case reflect.Int32, reflect.Int64: if value.Int() == 0 { continue } case reflect.Uint32, reflect.Uint64: if value.Uint() == 0 { continue } case reflect.Float32, reflect.Float64: if value.Float() == 0 { continue } case reflect.String: if value.Len() == 0 { continue } } } // Oneof fields need special handling. if valueField.Tag.Get("protobuf_oneof") != "" { // value is an interface containing &T{real_value}. sv := value.Elem().Elem() // interface -> *T -> T value = sv.Field(0) valueField = sv.Type().Field(0) } prop := jsonProperties(valueField, m.OrigName) if !firstField { m.writeSep(out) } if err := m.marshalField(out, prop, value, indent); err != nil { return err } firstField = false } // Handle proto2 extensions. if ep, ok := v.(extendableProto); ok { extensions := proto.RegisteredExtensions(v) extensionMap := ep.ExtensionMap() // Sort extensions for stable output. ids := make([]int32, 0, len(extensionMap)) for id := range extensionMap { ids = append(ids, id) } sort.Sort(int32Slice(ids)) for _, id := range ids { desc := extensions[id] if desc == nil { // unknown extension continue } ext, extErr := proto.GetExtension(ep, desc) if extErr != nil { return extErr } value := reflect.ValueOf(ext) var prop proto.Properties prop.Parse(desc.Tag) prop.JSONName = fmt.Sprintf("[%s]", desc.Name) if !firstField { m.writeSep(out) } if err := m.marshalField(out, &prop, value, indent); err != nil { return err } firstField = false } } if m.Indent != "" { out.write("\n") out.write(indent) } out.write("}") return out.err }
func TestMarshalUnmarshalRepeatedExtension(t *testing.T) { // Add a repeated extension to the result. tests := []struct { name string ext []*pb.ComplexExtension }{ { "two fields", []*pb.ComplexExtension{ {First: proto.Int32(7)}, {Second: proto.Int32(11)}, }, }, { "repeated field", []*pb.ComplexExtension{ {Third: []int32{1000}}, {Third: []int32{2000}}, }, }, { "two fields and repeated field", []*pb.ComplexExtension{ {Third: []int32{1000}}, {First: proto.Int32(9)}, {Second: proto.Int32(21)}, {Third: []int32{2000}}, }, }, } for _, test := range tests { // Marshal message with a repeated extension. msg1 := new(pb.OtherMessage) err := proto.SetExtension(msg1, pb.E_RComplex, test.ext) if err != nil { t.Fatalf("[%s] Error setting extension: %v", test.name, err) } b, err := proto.Marshal(msg1) if err != nil { t.Fatalf("[%s] Error marshaling message: %v", test.name, err) } // Unmarshal and read the merged proto. msg2 := new(pb.OtherMessage) err = proto.Unmarshal(b, msg2) if err != nil { t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err) } e, err := proto.GetExtension(msg2, pb.E_RComplex) if err != nil { t.Fatalf("[%s] Error getting extension: %v", test.name, err) } ext := e.([]*pb.ComplexExtension) if ext == nil { t.Fatalf("[%s] Invalid extension", test.name) } if !reflect.DeepEqual(ext, test.ext) { t.Errorf("[%s] Wrong value for ComplexExtension: got: %v want: %v\n", test.name, ext, test.ext) } } }
func TestGetExtensionDefaults(t *testing.T) { var setFloat64 float64 = 1 var setFloat32 float32 = 2 var setInt32 int32 = 3 var setInt64 int64 = 4 var setUint32 uint32 = 5 var setUint64 uint64 = 6 var setBool = true var setBool2 = false var setString = "Goodnight string" var setBytes = []byte("Goodnight bytes") var setEnum = pb.DefaultsMessage_TWO type testcase struct { ext *proto.ExtensionDesc // Extension we are testing. want interface{} // Expected value of extension, or nil (meaning that GetExtension will fail). def interface{} // Expected value of extension after ClearExtension(). } tests := []testcase{ {pb.E_NoDefaultDouble, setFloat64, nil}, {pb.E_NoDefaultFloat, setFloat32, nil}, {pb.E_NoDefaultInt32, setInt32, nil}, {pb.E_NoDefaultInt64, setInt64, nil}, {pb.E_NoDefaultUint32, setUint32, nil}, {pb.E_NoDefaultUint64, setUint64, nil}, {pb.E_NoDefaultSint32, setInt32, nil}, {pb.E_NoDefaultSint64, setInt64, nil}, {pb.E_NoDefaultFixed32, setUint32, nil}, {pb.E_NoDefaultFixed64, setUint64, nil}, {pb.E_NoDefaultSfixed32, setInt32, nil}, {pb.E_NoDefaultSfixed64, setInt64, nil}, {pb.E_NoDefaultBool, setBool, nil}, {pb.E_NoDefaultBool, setBool2, nil}, {pb.E_NoDefaultString, setString, nil}, {pb.E_NoDefaultBytes, setBytes, nil}, {pb.E_NoDefaultEnum, setEnum, nil}, {pb.E_DefaultDouble, setFloat64, float64(3.1415)}, {pb.E_DefaultFloat, setFloat32, float32(3.14)}, {pb.E_DefaultInt32, setInt32, int32(42)}, {pb.E_DefaultInt64, setInt64, int64(43)}, {pb.E_DefaultUint32, setUint32, uint32(44)}, {pb.E_DefaultUint64, setUint64, uint64(45)}, {pb.E_DefaultSint32, setInt32, int32(46)}, {pb.E_DefaultSint64, setInt64, int64(47)}, {pb.E_DefaultFixed32, setUint32, uint32(48)}, {pb.E_DefaultFixed64, setUint64, uint64(49)}, {pb.E_DefaultSfixed32, setInt32, int32(50)}, {pb.E_DefaultSfixed64, setInt64, int64(51)}, {pb.E_DefaultBool, setBool, true}, {pb.E_DefaultBool, setBool2, true}, {pb.E_DefaultString, setString, "Hello, string"}, {pb.E_DefaultBytes, setBytes, []byte("Hello, bytes")}, 
{pb.E_DefaultEnum, setEnum, pb.DefaultsMessage_ONE}, } checkVal := func(test testcase, msg *pb.DefaultsMessage, valWant interface{}) error { val, err := proto.GetExtension(msg, test.ext) if err != nil { if valWant != nil { return fmt.Errorf("GetExtension(): %s", err) } if want := proto.ErrMissingExtension; err != want { return fmt.Errorf("Unexpected error: got %v, want %v", err, want) } return nil } // All proto2 extension values are either a pointer to a value or a slice of values. ty := reflect.TypeOf(val) tyWant := reflect.TypeOf(test.ext.ExtensionType) if got, want := ty, tyWant; got != want { return fmt.Errorf("unexpected reflect.TypeOf(): got %v want %v", got, want) } tye := ty.Elem() tyeWant := tyWant.Elem() if got, want := tye, tyeWant; got != want { return fmt.Errorf("unexpected reflect.TypeOf().Elem(): got %v want %v", got, want) } // Check the name of the type of the value. // If it is an enum it will be type int32 with the name of the enum. if got, want := tye.Name(), tye.Name(); got != want { return fmt.Errorf("unexpected reflect.TypeOf().Elem().Name(): got %v want %v", got, want) } // Check that value is what we expect. // If we have a pointer in val, get the value it points to. valExp := val if ty.Kind() == reflect.Ptr { valExp = reflect.ValueOf(val).Elem().Interface() } if got, want := valExp, valWant; !reflect.DeepEqual(got, want) { return fmt.Errorf("unexpected reflect.DeepEqual(): got %v want %v", got, want) } return nil } setTo := func(test testcase) interface{} { setTo := reflect.ValueOf(test.want) if typ := reflect.TypeOf(test.ext.ExtensionType); typ.Kind() == reflect.Ptr { setTo = reflect.New(typ).Elem() setTo.Set(reflect.New(setTo.Type().Elem())) setTo.Elem().Set(reflect.ValueOf(test.want)) } return setTo.Interface() } for _, test := range tests { msg := &pb.DefaultsMessage{} name := test.ext.Name // Check the initial value. 
if err := checkVal(test, msg, test.def); err != nil { t.Errorf("%s: %v", name, err) } // Set the per-type value and check value. name = fmt.Sprintf("%s (set to %T %v)", name, test.want, test.want) if err := proto.SetExtension(msg, test.ext, setTo(test)); err != nil { t.Errorf("%s: SetExtension(): %v", name, err) continue } if err := checkVal(test, msg, test.want); err != nil { t.Errorf("%s: %v", name, err) continue } // Set and check the value. name += " (cleared)" proto.ClearExtension(msg, test.ext) if err := checkVal(test, msg, test.def); err != nil { t.Errorf("%s: %v", name, err) } } }
func TestUnmarshalRepeatingNonRepeatedExtension(t *testing.T) { // We may see multiple instances of the same extension in the wire // format. For example, the proto compiler may encode custom options in // this way. Here, we verify that we merge the extensions together. tests := []struct { name string ext []*pb.ComplexExtension }{ { "two fields", []*pb.ComplexExtension{ {First: proto.Int32(7)}, {Second: proto.Int32(11)}, }, }, { "repeated field", []*pb.ComplexExtension{ {Third: []int32{1000}}, {Third: []int32{2000}}, }, }, { "two fields and repeated field", []*pb.ComplexExtension{ {Third: []int32{1000}}, {First: proto.Int32(9)}, {Second: proto.Int32(21)}, {Third: []int32{2000}}, }, }, } for _, test := range tests { var buf bytes.Buffer var want pb.ComplexExtension // Generate a serialized representation of a repeated extension // by catenating bytes together. for i, e := range test.ext { // Merge to create the wanted proto. proto.Merge(&want, e) // serialize the message msg := new(pb.OtherMessage) err := proto.SetExtension(msg, pb.E_Complex, e) if err != nil { t.Fatalf("[%s] Error setting extension %d: %v", test.name, i, err) } b, err := proto.Marshal(msg) if err != nil { t.Fatalf("[%s] Error marshaling message %d: %v", test.name, i, err) } buf.Write(b) } // Unmarshal and read the merged proto. msg2 := new(pb.OtherMessage) err := proto.Unmarshal(buf.Bytes(), msg2) if err != nil { t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err) } e, err := proto.GetExtension(msg2, pb.E_Complex) if err != nil { t.Fatalf("[%s] Error getting extension: %v", test.name, err) } ext := e.(*pb.ComplexExtension) if ext == nil { t.Fatalf("[%s] Invalid extension", test.name) } if !reflect.DeepEqual(*ext, want) { t.Errorf("[%s] Wrong value for ComplexExtension: got: %s want: %s\n", test.name, ext, want) } } }