Beispiel #1
0
// TestGetExtensionStability checks that two consecutive GetExtension calls on
// the same message return the identical interface value, both for a message
// populated in memory and for one decoded from its wire encoding.
func TestGetExtensionStability(t *testing.T) {
	// check fetches E_Ext_More twice and reports whether both calls yielded
	// the same value (compared with ==).
	check := func(m *pb.MyMessage) bool {
		ext1, err := proto.GetExtension(m, pb.E_Ext_More)
		if err != nil {
			t.Fatalf("GetExtension() failed: %s", err)
		}
		ext2, err := proto.GetExtension(m, pb.E_Ext_More)
		if err != nil {
			t.Fatalf("GetExtension() failed: %s", err)
		}
		return ext1 == ext2
	}
	msg := &pb.MyMessage{Count: proto.Int32(4)}
	ext0 := &pb.Ext{}
	if err := proto.SetExtension(msg, pb.E_Ext_More, ext0); err != nil {
		// Report the actual error rather than the extension value.
		t.Fatalf("SetExtension() failed: %v", err)
	}
	if !check(msg) {
		t.Errorf("GetExtension() not stable before marshaling")
	}
	bb, err := proto.Marshal(msg)
	if err != nil {
		t.Fatalf("Marshal() failed: %s", err)
	}
	msg1 := &pb.MyMessage{}
	err = proto.Unmarshal(bb, msg1)
	if err != nil {
		t.Fatalf("Unmarshal() failed: %s", err)
	}
	if !check(msg1) {
		t.Errorf("GetExtension() not stable after unmarshaling")
	}
}
Beispiel #2
0
// Take requests n tokens from bucket t for key k and blocks until the
// response registered under the request's ID arrives on its channel in
// client.PendingRequests.
//
// takeResponse is set from the response's TakeResponse extension when that
// extension holds a *limitd.TakeResponse. err is non-nil if the request could
// not be encoded or written, or if the extension could not be read.
func (client *Client) Take(t string, k string, n int32) (response *limitd.Response, takeResponse *limitd.TakeResponse, err error) {
	requestID := uniuri.New()

	request := &limitd.Request{
		Id:     proto.String(requestID),
		Method: limitd.Request_TAKE.Enum(),
		Type:   proto.String(t),
		Key:    proto.String(k),
		Count:  proto.Int32(n),
	}

	// Encode before registering so an encoding failure leaves no pending entry.
	data, err := proto.Marshal(request)
	if err != nil {
		return nil, nil, err
	}
	// Frame the message: varint length prefix followed by the encoded bytes.
	data = append(proto.EncodeVarint(uint64(len(data))), data...)

	// Register before writing so a fast reply always finds its channel.
	responseChan := make(chan *limitd.Response)
	client.PendingRequests[requestID] = responseChan

	if _, err = client.Conn.Write(data); err != nil {
		// A request that was never sent gets no reply; waiting on
		// responseChan would block forever.
		delete(client.PendingRequests, requestID)
		return nil, nil, err
	}

	response = <-responseChan
	ext, err := proto.GetExtension(response, limitd.E_TakeResponse_Response)
	if err != nil {
		return response, nil, err
	}

	if tr, ok := ext.(*limitd.TakeResponse); ok {
		takeResponse = tr
	}

	return response, takeResponse, nil
}
Beispiel #3
0
// SetPackageNames sets the package name for this run.
// The package name must agree across all files being generated.
//
// For each file, a string-valued javascript_package option takes precedence
// over the proto package declaration. Its presence is detected via extension
// field number 50000 (presumably E_JavascriptPackage's number — confirm). Each
// file's chosen name is recorded in uniquePackageName and passed to
// registerUniquePackageName. The first file to be generated supplies
// g.packageName.
func (g *Generator) SetPackageNames() {
	for _, f := range g.allFiles {
		pkgName := ""
		if f.Options != nil {
			extMap := f.GetOptions().ExtensionMap()
			if _, ok := extMap[50000]; ok {
				itf, err := proto.GetExtension(
					f.GetOptions(), javascript_package.E_JavascriptPackage)
				// comma-ok: a malformed option falls back to the proto
				// package instead of panicking.
				if s, ok := itf.(*string); err == nil && ok && s != nil {
					pkgName = *s
				}
			}
		}
		if pkgName == "" {
			pkgName = f.GetPackage()
		}

		uniquePackageName[f.FileDescriptorProto] = pkgName
		registerUniquePackageName(pkgName, f)

		// Guard the index: with no files to generate there is no primary file.
		if len(g.genFiles) > 0 && f == g.genFiles[0] {
			g.packageName = pkgName
		}
	}
}
Beispiel #4
0
// fillTreeWithMethod records the method described by proto in tree.methods
// under the key "<key>.<method name>" and returns that new entry.
//
// The input/output message types are linked when already present in
// tree.messages, and the streaming flags are copied from the descriptor. For
// the gateway.E_Http option, one endpoint is appended for the primary rule
// and one per additional binding, skipping any for which newEndpoint returns
// nil.
func fillTreeWithMethod(tree *tree, key string, proto *descriptor.MethodDescriptorProto, loc string, locs map[string]*descriptor.SourceCodeInfo_Location) *method {
	key = fmt.Sprintf("%s.%s", key, proto.GetName())
	m := &method{key: key, comment: getComment(loc, locs), MethodDescriptorProto: proto}
	tree.methods[key] = m

	if in, found := tree.messages[proto.GetInputType()]; found {
		m.input = in
	}
	if proto.GetClientStreaming() {
		m.inputStream = true
	}
	if out, found := tree.messages[proto.GetOutputType()]; found {
		m.output = out
	}
	if proto.GetServerStreaming() {
		m.outputStream = true
	}

	if proto.Options == nil || !protobuf.HasExtension(proto.Options, gateway.E_Http) {
		return m
	}
	raw, err := protobuf.GetExtension(proto.Options, gateway.E_Http)
	if err != nil {
		return m
	}
	rule, isRule := raw.(*gateway.HttpRule)
	if !isRule {
		return m
	}
	// The primary rule comes first, followed by each additional binding.
	for _, binding := range append([]*gateway.HttpRule{rule}, rule.AdditionalBindings...) {
		if ep := newEndpoint(binding); ep != nil {
			m.endpoints = append(m.endpoints, ep)
		}
	}
	return m
}
Beispiel #5
0
// newServerContactReqChan builds the server side of an incoming contact
// request channel from the peer's OpenChannel message.
//
// It reads the ContactRequest extension and copies the nickname and message
// text into ch.reqData, along with conn.hostname. An error is returned if the
// extension is missing or malformed, or if either text field exceeds its
// maximum length. The limits are compared against len(), i.e. bytes.
func newServerContactReqChan(conn *ricochetConn, msg *packet.OpenChannel) (*contactReqChan, error) {
	ch := new(contactReqChan)
	ch.conn = conn
	ch.chanID = (uint16)(msg.GetChannelIdentifier())
	ch.reqData = new(ContactRequest)

	ext, err := proto.GetExtension(msg, packet.E_ContactRequest)
	if err != nil {
		return nil, err
	}
	// The message comes from the peer. comma-ok turns an unexpected value
	// (including a typed nil, which `ext == nil` would miss) into an error
	// instead of a panic.
	req, ok := ext.(*packet.ContactRequest)
	if !ok || req == nil {
		return nil, fmt.Errorf("server: missing ContactRequest extension")
	}
	ch.reqData.Hostname = conn.hostname
	ch.reqData.MyNickname = req.GetNickname()
	if len(ch.reqData.MyNickname) > ContactReqNicknameMaxCharacters {
		return nil, fmt.Errorf("server: ContactRequest nickname too long")
	}
	ch.reqData.Message = req.GetMessageText()
	if len(ch.reqData.Message) > ContactReqMessageMaxCharacters {
		return nil, fmt.Errorf("server: ContactRequest message too long")
	}
	return ch, nil
}
Beispiel #6
0
// TestExtensionsRoundTrip exercises the Has/Set/Get/ClearExtension cycle on
// E_Ext_More:
//   - setting twice must overwrite;
//   - clearing must make GetExtension report ErrMissingExtension;
//   - foreign (E_X215) or wrongly-typed values must be rejected.
func TestExtensionsRoundTrip(t *testing.T) {
	msg := &pb.MyMessage{}
	ext1 := &pb.Ext{
		Data: proto.String("hi"),
	}
	ext2 := &pb.Ext{
		Data: proto.String("there"),
	}
	exists := proto.HasExtension(msg, pb.E_Ext_More)
	if exists {
		t.Error("Extension More present unexpectedly")
	}
	if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
		t.Error(err)
	}
	if err := proto.SetExtension(msg, pb.E_Ext_More, ext2); err != nil {
		t.Error(err)
	}
	e, err := proto.GetExtension(msg, pb.E_Ext_More)
	if err != nil {
		t.Error(err)
	}
	x, ok := e.(*pb.Ext)
	if !ok {
		t.Errorf("e has type %T, expected testdata.Ext", e)
	} else if x.GetData() != "there" {
		// GetData avoids a nil dereference if Data was never set.
		t.Errorf("SetExtension failed to overwrite, got %+v, not 'there'", x)
	}
	proto.ClearExtension(msg, pb.E_Ext_More)
	if _, err = proto.GetExtension(msg, pb.E_Ext_More); err != proto.ErrMissingExtension {
		// Report the error actually returned, not the earlier extension value.
		t.Errorf("got %v, expected ErrMissingExtension", err)
	}
	if _, err := proto.GetExtension(msg, pb.E_X215); err == nil {
		t.Error("expected bad extension error, got nil")
	}
	if err := proto.SetExtension(msg, pb.E_X215, 12); err == nil {
		t.Error("expected extension err")
	}
	if err := proto.SetExtension(msg, pb.E_Ext_More, 12); err == nil {
		t.Error("expected some sort of type mismatch error, got nil")
	}
}
Beispiel #7
0
// int64AsNumber reports whether field carries the jstype option set to
// JS_NUMBER. Presence is detected via extension field number 50001 (presumably
// E_Jstype's number — confirm). Fields without options, without the
// extension, or with an unreadable value report false.
func int64AsNumber(field *descriptor.FieldDescriptorProto) bool {
	var int64Encoding int64_encoding.Int64Encoding
	if field.Options != nil {
		extMap := field.GetOptions().ExtensionMap()
		if _, ok := extMap[50001]; ok {
			itf, err := proto.GetExtension(field.GetOptions(), int64_encoding.E_Jstype)
			// comma-ok plus nil check: a malformed option yields the zero
			// encoding instead of a panic.
			if enc, ok := itf.(*int64_encoding.Int64Encoding); err == nil && ok && enc != nil {
				int64Encoding = *enc
			}
		}
	}
	return int64Encoding == int64_encoding.Int64Encoding_JS_NUMBER
}
Beispiel #8
0
// newServerAuthHSChan builds the server side of an incoming AuthHiddenService
// channel from the peer's OpenChannel message.
//
// It stores the client cookie from the ClientCookie extension. An error is
// returned if the extension is missing, not a byte slice, or not exactly
// authHiddenServiceCookieSize bytes long.
func newServerAuthHSChan(conn *ricochetConn, msg *packet.OpenChannel) (*authHSChan, error) {
	ch := new(authHSChan)
	ch.conn = conn
	ch.chanID = (uint16)(msg.GetChannelIdentifier())

	ext, err := proto.GetExtension(msg, packet.E_ClientCookie)
	if err != nil {
		return nil, err
	}
	// The cookie comes from the peer; comma-ok rejects an unexpected value
	// instead of panicking.
	cookie, ok := ext.([]byte)
	if !ok || len(cookie) != authHiddenServiceCookieSize {
		return nil, fmt.Errorf("invalid AuthHiddenService client_cookie")
	}
	ch.clientCookie = cookie
	return ch, nil
}
// convertFile generates one "<package path>/<table name>.schema" JSON file
// for every message in file that carries a non-empty E_TableName option.
// Messages without options or without that option are skipped.
//
// An error is returned if file's package cannot be resolved, if the option is
// unreadable, of the wrong type, or empty, or if a message fails to convert or
// encode.
func convertFile(file *descriptor.FileDescriptorProto) ([]*plugin.CodeGeneratorResponse_File, error) {
	name := path.Base(file.GetName())
	pkg, ok := globalPkg.relativelyLookupPackage(file.GetPackage())
	if !ok {
		return nil, fmt.Errorf("no such package found: %s", file.GetPackage())
	}

	response := []*plugin.CodeGeneratorResponse_File{}
	for _, msg := range file.GetMessageType() {
		options := msg.GetOptions()
		if options == nil {
			continue
		}
		if !proto.HasExtension(options, E_TableName) {
			continue
		}
		optionValue, err := proto.GetExtension(options, E_TableName)
		if err != nil {
			return nil, err
		}
		// comma-ok: report a malformed option instead of panicking.
		namePtr, ok := optionValue.(*string)
		if !ok {
			return nil, fmt.Errorf("table name option of %s has unexpected type %T", msg.GetName(), optionValue)
		}
		if namePtr == nil || len(*namePtr) == 0 {
			return nil, fmt.Errorf("table name of %s cannot be empty", msg.GetName())
		}
		tableName := *namePtr

		glog.V(2).Info("Generating schema for a message type ", msg.GetName())
		schema, err := convertMessageType(pkg, msg)
		if err != nil {
			glog.Errorf("Failed to convert %s: %v", name, err)
			return nil, err
		}

		jsonSchema, err := json.Marshal(schema)
		if err != nil {
			// Errorf: glog.Error (Sprint) puts no space between a string and
			// the error that follows it.
			glog.Errorf("Failed to encode schema for %s: %v", msg.GetName(), err)
			return nil, err
		}

		resFile := &plugin.CodeGeneratorResponse_File{
			Name:    proto.String(fmt.Sprintf("%s/%s.schema", strings.Replace(file.GetPackage(), ".", "/", -1), tableName)),
			Content: proto.String(string(jsonSchema)),
		}
		response = append(response, resFile)
	}

	return response, nil
}
Beispiel #10
0
// extractAPIOptions returns the HTTP rule stored in meth's options under the
// options.E_Http extension.
//
// It returns (nil, nil) when meth has no options or the extension is absent,
// and an error when the extension cannot be read or does not hold an
// *options.HttpRule.
func extractAPIOptions(meth *descriptor.MethodDescriptorProto) (*options.HttpRule, error) {
	if meth.Options == nil || !proto.HasExtension(meth.Options, options.E_Http) {
		return nil, nil
	}
	raw, err := proto.GetExtension(meth.Options, options.E_Http)
	if err != nil {
		return nil, err
	}
	if rule, isRule := raw.(*options.HttpRule); isRule {
		return rule, nil
	}
	return nil, fmt.Errorf("extension is %T; want an HttpRule", raw)
}
Beispiel #11
0
// onChannelResult handles the server's ChannelResult for an outbound contact
// request channel. It marks the channel open, then passes the Response
// extension to ch.onResponse.
//
// An error is returned on the server side, when no open was pending, or when
// the response extension is missing or malformed.
func (ch *contactReqChan) onChannelResult(msg *packet.ChannelResult) error {
	if ch.conn.isServer {
		return fmt.Errorf("opened contact req channel to client")
	}
	if ch.state != chanStateOpening {
		return fmt.Errorf("received spurious ContactRequest ChannelResult")
	}
	ch.state = chanStateOpen

	// If this routine was called, the channel WAS opened, without incident.
	// Extract the response, and take action accordingly.
	ext, err := proto.GetExtension(msg, packet.E_Response)
	if err != nil {
		return err
	}
	// The result comes from the peer; comma-ok rejects an unexpected value
	// instead of panicking.
	resp, ok := ext.(*packet.ContactRequestResponse)
	if !ok || resp == nil {
		return fmt.Errorf("invalid ContactRequest ChannelResult response")
	}
	return ch.onResponse(resp)
}
Beispiel #12
0
// onChannelResult handles the server's ChannelResult for an outbound
// AuthHiddenService channel. It stores the server cookie, builds the proof
// with ch.calculateProof, signs it (RSA PKCS#1 v1.5 over SHA-256) with the
// endpoint's private key, and sends it together with the DER-encoded public
// key.
//
// An error is returned in any of these cases:
//   - this is the server side;
//   - no open was pending;
//   - the server cookie is missing, not a byte slice, or of the wrong length;
//   - encoding, signing, or sending fails.
func (ch *authHSChan) onChannelResult(msg *packet.ChannelResult) error {
	if ch.conn.isServer {
		return fmt.Errorf("opened auth channel to client")
	}
	if ch.state != chanStateOpening {
		return fmt.Errorf("received spurious AuthHiddenService ChannelResult")
	}
	ch.state = chanStateOpen

	// If this routine was called, the channel WAS opened, without incident.
	// Extract the server cookie, and send the proof.
	ext, err := proto.GetExtension(msg, packet.E_ServerCookie)
	if err != nil {
		return err
	}
	// The cookie comes from the peer; comma-ok rejects an unexpected value
	// instead of panicking.
	cookie, ok := ext.([]byte)
	if !ok {
		return fmt.Errorf("invalid AuthHiddenService server_cookie")
	}
	ch.serverCookie = cookie
	if len(ch.serverCookie) != authHiddenServiceCookieSize {
		return fmt.Errorf("invalid AuthHiddenService server_cookie")
	}

	// Encode the public key to DER.
	pkDER, err := pkcs1.EncodePublicKeyDER(&ch.conn.endpoint.privateKey.PublicKey)
	if err != nil {
		return err
	}

	// Calculate the proof.
	proof := ch.calculateProof(ch.conn.endpoint.hostname, ch.conn.hostname)

	// Sign the proof.
	sig, err := rsa.SignPKCS1v15(rand.Reader, ch.conn.endpoint.privateKey, crypto.SHA256, proof)
	if err != nil {
		return err
	}

	return ch.sendProof(pkDER, sig)
}
Beispiel #13
0
// Extract extracts a compilation from the specified extra action info.
//
// It reads the SpawnInfo extension from info and, when c.Mnemonic is set,
// requires info's mnemonic to match it. It then builds a compilation unit:
//   - required inputs are the SpawnInfo input files accepted by
//     toolArgs.wantInput, read concurrently through c.readFile;
//   - the first output file becomes the output key;
//   - every environment variable except PATH is recorded.
//
// Read failures are returned as an error; the first failing input, in input
// order, is reported.
func (c *Config) Extract(ctx context.Context, info *eapb.ExtraActionInfo) (*kindex.Compilation, error) {
	si, err := proto.GetExtension(info, eapb.E_SpawnInfo_SpawnInfo)
	if err != nil {
		return nil, fmt.Errorf("extra action does not have SpawnInfo: %v", err)
	}
	// comma-ok: report an unexpected extension value instead of panicking.
	spawnInfo, ok := si.(*eapb.SpawnInfo)
	if !ok || spawnInfo == nil {
		return nil, fmt.Errorf("extra action SpawnInfo has unexpected type %T", si)
	}

	// Verify that the mnemonic is what we expect.
	if m := info.GetMnemonic(); m != c.Mnemonic && c.Mnemonic != "" {
		return nil, fmt.Errorf("mnemonic does not match %q ≠ %q", m, c.Mnemonic)
	}

	// Construct the basic compilation.
	toolArgs := extractToolArgs(spawnInfo.Argument)
	log.Printf("Extracting compilation for %q", info.GetOwner())
	cu := &kindex.Compilation{
		Proto: &apb.CompilationUnit{
			VName: &spb.VName{
				Language:  govname.Language,
				Corpus:    c.Corpus,
				Signature: info.GetOwner(),
			},
			Argument:         toolArgs.fullArgs,
			SourceFile:       toolArgs.sources,
			WorkingDirectory: c.Root,
			Environment: []*apb.CompilationUnit_Env{{
				Name:  "GOROOT",
				Value: toolArgs.goRoot,
			}},
		},
	}

	// Load and populate file contents and required inputs.  Do this in two
	// passes: First scan the inputs and filter out which ones we actually want
	// to keep; then load their contents concurrently.
	var wantPaths []string
	for _, in := range spawnInfo.InputFile {
		if toolArgs.wantInput(in) {
			wantPaths = append(wantPaths, in)
			cu.Files = append(cu.Files, nil)
			cu.Proto.RequiredInput = append(cu.Proto.RequiredInput, nil)
		}
	}

	// Fetch concurrently. Each element of the proto slices (and of readErrs)
	// is accessed by a single goroutine corresponding to its index. Failures
	// are collected and returned rather than exiting the process from inside
	// a worker goroutine.
	log.Printf("Reading file contents for %d required inputs", len(wantPaths))
	start := time.Now()
	readErrs := make([]error, len(wantPaths))
	var wg sync.WaitGroup
	for i, path := range wantPaths {
		i, path := i, path
		wg.Add(1)
		go func() {
			defer wg.Done()
			fd, err := c.readFile(ctx, path)
			if err != nil {
				readErrs[i] = fmt.Errorf("unable to read input %q: %v", path, err)
				return
			}
			cu.Files[i] = fd
			cu.Proto.RequiredInput[i] = c.fileDataToInfo(fd)
		}()
	}
	wg.Wait()
	for _, err := range readErrs {
		if err != nil {
			return nil, err
		}
	}
	log.Printf("Finished reading required inputs [%v elapsed]", time.Since(start))

	// Set the output path.  Although the SpawnInfo has room for multiple
	// outputs, we expect only one to be set in practice.  It's harmless if
	// there are more, though, so don't fail for that.
	if len(spawnInfo.OutputFile) > 0 {
		cu.Proto.OutputKey = spawnInfo.OutputFile[0]
	}

	// Capture environment variables.
	for _, evar := range spawnInfo.Variable {
		if evar.GetName() == "PATH" {
			// TODO(fromberger): Perhaps whitelist or blacklist which
			// environment variables to capture here.
			continue
		}
		cu.Proto.Environment = append(cu.Proto.Environment, &apb.CompilationUnit_Env{
			Name:  evar.GetName(),
			Value: evar.GetValue(),
		})
	}

	return cu, nil
}
Beispiel #14
0
// processConnection starts a blocking process loop which continually waits for
// new messages to arrive from the connection and uses the given RicochetService
// to process them.
//
// The loop returns once oc is marked closed or a packet cannot be received;
// service.OnDisconnect runs on exit. Empty packets signal channel closure.
// Channel 0 carries control packets (OpenChannel / ChannelResult). Other
// channels are dispatched by the channel type oc reports for them.
func (r *Ricochet) processConnection(oc *OpenConnection, service RicochetService) {
	service.OnConnect(oc)
	defer service.OnDisconnect(oc)

	for {
		if oc.Closed {
			return
		}

		packet, err := r.rni.RecvRicochetPacket(oc.conn)
		if err != nil {
			oc.Close()
			return
		}

		if len(packet.Data) == 0 {
			service.OnChannelClosed(oc, packet.Channel)
			continue
		}

		if packet.Channel == 0 {

			res := new(Protocol_Data_Control.Packet)
			err := proto.Unmarshal(packet.Data[:], res)

			if err != nil {
				service.OnGenericError(oc, packet.Channel)
				continue
			}

			if res.GetOpenChannel() != nil {
				opm := res.GetOpenChannel()

				if oc.GetChannelType(opm.GetChannelIdentifier()) != "none" {
					// Channel is already in use.
					service.OnBadUsageError(oc, opm.GetChannelIdentifier())
					continue
				}

				// If I am a Client, the server can only open even numbered channels
				if oc.Client && opm.GetChannelIdentifier()%2 != 0 {
					service.OnBadUsageError(oc, opm.GetChannelIdentifier())
					continue
				}

				// If I am a Server, the client can only open odd numbered channels
				if !oc.Client && opm.GetChannelIdentifier()%2 != 1 {
					service.OnBadUsageError(oc, opm.GetChannelIdentifier())
					continue
				}

				switch opm.GetChannelType() {
				case "im.ricochet.auth.hidden-service":
					if oc.Client {
						// Servers are authed by default and can't auth with hidden-service
						service.OnBadUsageError(oc, opm.GetChannelIdentifier())
					} else if oc.IsAuthed {
						// Can't auth if already authed
						service.OnBadUsageError(oc, opm.GetChannelIdentifier())
					} else if oc.HasChannel("im.ricochet.auth.hidden-service") {
						// Can't open more than 1 auth channel
						service.OnBadUsageError(oc, opm.GetChannelIdentifier())
					} else {
						clientCookie, err := proto.GetExtension(opm, Protocol_Data_AuthHiddenService.E_ClientCookie)
						// The cookie is peer-supplied: check its type (comma-ok,
						// no panic) and require exactly 16 bytes rather than
						// silently truncating or zero-padding it.
						cookie, ok := clientCookie.([]byte)
						clientCookieB := [16]byte{}
						if err == nil && ok && len(cookie) == len(clientCookieB) {
							copy(clientCookieB[:], cookie)
							service.OnAuthenticationRequest(oc, opm.GetChannelIdentifier(), clientCookieB)
						} else {
							// Must include a well-formed Client Cookie
							service.OnBadUsageError(oc, opm.GetChannelIdentifier())
						}
					}
				case "im.ricochet.chat":
					if !oc.IsAuthed {
						// Can't open chat channel if not authorized
						service.OnUnauthorizedError(oc, opm.GetChannelIdentifier())
					} else if !service.IsKnownContact(oc.OtherHostname) {
						// Can't open chat channel if not a known contact
						service.OnUnauthorizedError(oc, opm.GetChannelIdentifier())
					} else {
						service.OnOpenChannelRequest(oc, opm.GetChannelIdentifier(), "im.ricochet.chat")
					}
				case "im.ricochet.contact.request":
					if oc.Client {
						// Servers are not allowed to send contact requests
						service.OnBadUsageError(oc, opm.GetChannelIdentifier())
					} else if !oc.IsAuthed {
						// Can't open a contact channel if not authed
						service.OnUnauthorizedError(oc, opm.GetChannelIdentifier())
					} else if oc.HasChannel("im.ricochet.contact.request") {
						// Only 1 contact channel is allowed to be open at a time
						service.OnBadUsageError(oc, opm.GetChannelIdentifier())
					} else {
						contactRequestI, err := proto.GetExtension(opm, Protocol_Data_ContactRequest.E_ContactRequest)
						if err == nil {
							contactRequest, check := contactRequestI.(*Protocol_Data_ContactRequest.ContactRequest)
							if check {
								service.OnContactRequest(oc, opm.GetChannelIdentifier(), contactRequest.GetNickname(), contactRequest.GetMessageText())
								break
							}
						}
						service.OnBadUsageError(oc, opm.GetChannelIdentifier())
					}
				default:
					service.OnUnknownTypeError(oc, opm.GetChannelIdentifier())
				}
			} else if res.GetChannelResult() != nil {
				crm := res.GetChannelResult()
				if crm.GetOpened() {
					switch oc.GetChannelType(crm.GetChannelIdentifier()) {
					case "im.ricochet.auth.hidden-service":
						serverCookie, err := proto.GetExtension(crm, Protocol_Data_AuthHiddenService.E_ServerCookie)
						// Same validation as the client cookie: correct type
						// and exactly 16 bytes.
						cookie, ok := serverCookie.([]byte)
						serverCookieB := [16]byte{}
						if err == nil && ok && len(cookie) == len(serverCookieB) {
							copy(serverCookieB[:], cookie)
							service.OnAuthenticationChallenge(oc, crm.GetChannelIdentifier(), serverCookieB)
						} else {
							service.OnBadUsageError(oc, crm.GetChannelIdentifier())
						}
					case "im.ricochet.chat":
						service.OnOpenChannelRequestSuccess(oc, crm.GetChannelIdentifier())
					case "im.ricochet.contact.request":
						responseI, err := proto.GetExtension(crm, Protocol_Data_ContactRequest.E_Response)
						if err == nil {
							response, check := responseI.(*Protocol_Data_ContactRequest.Response)
							if check {
								service.OnContactRequestAck(oc, crm.GetChannelIdentifier(), response.GetStatus().String())
								break
							}
						}
						service.OnBadUsageError(oc, crm.GetChannelIdentifier())
					default:
						service.OnBadUsageError(oc, crm.GetChannelIdentifier())
					}
				} else {
					if oc.GetChannelType(crm.GetChannelIdentifier()) != "none" {
						service.OnFailedChannelOpen(oc, crm.GetChannelIdentifier(), crm.GetCommonError().String())
					} else {
						oc.CloseChannel(crm.GetChannelIdentifier())
					}
				}
			} else {
				// Unknown Message
				oc.CloseChannel(packet.Channel)
			}
		} else if oc.GetChannelType(packet.Channel) == "im.ricochet.auth.hidden-service" {
			res := new(Protocol_Data_AuthHiddenService.Packet)
			err := proto.Unmarshal(packet.Data[:], res)

			if err != nil {
				oc.CloseChannel(packet.Channel)
				continue
			}

			if res.GetProof() != nil && !oc.Client { // Only Clients Send Proofs
				service.OnAuthenticationProof(oc, packet.Channel, res.GetProof().GetPublicKey(), res.GetProof().GetSignature(), service.IsKnownContact(oc.OtherHostname))
			} else if res.GetResult() != nil && oc.Client { // Only Servers Send Results
				service.OnAuthenticationResult(oc, packet.Channel, res.GetResult().GetAccepted(), res.GetResult().GetIsKnownContact())
			} else {
				// If neither of the above are satisfied we just close the connection
				oc.Close()
			}

		} else if oc.GetChannelType(packet.Channel) == "im.ricochet.chat" {

			// NOTE: These auth checks should be redundant, however they
			// are included here for defense-in-depth if for some reason
			// a previously authed connection becomes untrusted / not known and
			// the state is not cleaned up.
			if !oc.IsAuthed {
				// Can't send chat messages if not authorized
				service.OnUnauthorizedError(oc, packet.Channel)
			} else if !service.IsKnownContact(oc.OtherHostname) {
				// Can't send chat message if not a known contact
				service.OnUnauthorizedError(oc, packet.Channel)
			} else {
				res := new(Protocol_Data_Chat.Packet)
				err := proto.Unmarshal(packet.Data[:], res)

				if err != nil {
					oc.CloseChannel(packet.Channel)
					continue
				}

				if res.GetChatMessage() != nil {
					service.OnChatMessage(oc, packet.Channel, int32(res.GetChatMessage().GetMessageId()), res.GetChatMessage().GetMessageText())
				} else if res.GetChatAcknowledge() != nil {
					// Read the ID from the acknowledgement itself; ChatMessage is
					// nil in this branch, so its getter would always return 0.
					service.OnChatMessageAck(oc, packet.Channel, int32(res.GetChatAcknowledge().GetMessageId()))
				} else {
					// If neither of the above are satisfied we just close the connection
					oc.Close()
				}
			}
		} else if oc.GetChannelType(packet.Channel) == "im.ricochet.contact.request" {

			// NOTE: These auth checks should be redundant, however they
			// are included here for defense-in-depth if for some reason
			// a previously authed connection becomes untrusted / not known and
			// the state is not cleaned up.
			if !oc.Client {
				// Clients are not allowed to send contact request responses
				service.OnBadUsageError(oc, packet.Channel)
			} else if !oc.IsAuthed {
				// Can't send a contact request if not authed
				service.OnBadUsageError(oc, packet.Channel)
			} else {
				res := new(Protocol_Data_ContactRequest.Response)
				err := proto.Unmarshal(packet.Data[:], res)
				log.Printf("%v", res)
				if err != nil {
					oc.CloseChannel(packet.Channel)
					continue
				}
				service.OnContactRequestAck(oc, packet.Channel, res.GetStatus().String())
			}
		} else if oc.GetChannelType(packet.Channel) == "none" {
			// Invalid Channel Assignment
			oc.CloseChannel(packet.Channel)
		} else {
			oc.Close()
		}
	}
}
Beispiel #15
0
// main implements a protoc plugin: it reads a CodeGeneratorRequest from stdin
// and always writes a CodeGeneratorResponse to stdout.
//
// Parameters of the form "M<proto file>=<go import path>" map proto files to
// Go import paths. A "<name>.pb.kit.go" file is rendered from fileTemplate for
// every proto file that has at least one service with at least one method
// carrying options. A method's google_api.E_Http option sets its HTTP verb,
// path, path arguments and (for PUT/POST) body field.
func main() {
	msg := plugin.CodeGeneratorRequest{}
	buff, err := ioutil.ReadAll(os.Stdin)
	if err != nil {
		panic(err)
	}

	if err := proto.Unmarshal(buff, &msg); err != nil {
		panic(err)
	}

	ret := &plugin.CodeGeneratorResponse{}
	// Emit the response on every return path, including after ret.Error is set.
	defer func() {
		out, err := proto.Marshal(ret)
		if err != nil {
			panic(err)
		}
		if _, err := os.Stdout.Write(out); err != nil {
			panic(err)
		}
	}()

	param := msg.GetParameter()
	imports := map[string]string{}
	sources := map[string]string{}

	for _, p := range strings.Split(param, ",") {
		if len(p) == 0 {
			continue
		}

		if p[0] == 'M' {
			// SplitN keeps any '=' inside the import path. The length check
			// avoids an index panic on a mapping with no '='.
			parts := strings.SplitN(p[1:], "=", 2)
			if len(parts) != 2 {
				fmt.Fprintf(os.Stderr, "Ignoring malformed import mapping %q.\n", p)
				continue
			}

			imports[parts[0]] = parts[1]
		}
	}

	messages := map[string]message{}

	for _, file := range msg.GetProtoFile() {
		for _, mt := range file.GetMessageType() {
			m := message{}
			for _, f := range mt.GetField() {
				m.Fields = append(m.Fields, field{
					ProtoName: f.GetName(),
					GoName:    goise(f.GetName()),
				})
			}

			messages["."+file.GetPackage()+"."+mt.GetName()] = m
			sources["."+file.GetPackage()+"."+mt.GetName()] = file.GetName()
		}
	}

	for _, file := range msg.GetProtoFile() {
		services := map[string]service{}
		// GetGoPackage is nil-safe. An options block without go_package must
		// not produce an empty "package" clause.
		goPackage := "main"
		if gp := file.GetOptions().GetGoPackage(); gp != "" {
			goPackage = gp
		}

		for _, svc := range file.GetService() {
			s := service{
				GoName: goise(svc.GetName()),
			}

			for _, meth := range svc.GetMethod() {
				m := method{
					GoName:      goise(meth.GetName()),
					GoInputType: goise(meth.GetInputType()),
					Input:       messages[meth.GetInputType()],
					InputType:   meth.GetInputType(),
				}

				// Methods without options are skipped entirely.
				if meth.GetOptions() == nil {
					continue
				}

				if tmp, err := proto.GetExtension(meth.GetOptions(), google_api.E_Http); err == nil {
					// comma-ok: a malformed option is ignored, not a panic.
					if http, ok := tmp.(*google_api.HttpRule); ok && http != nil {
						if http.Get != "" {
							m.PathArgs = parsePath(messages, meth, http.Get)
							m.Path = http.Get
							m.Method = "GET"
						}

						if http.Put != "" {
							m.PathArgs = parsePath(messages, meth, http.Put)
							m.Path = http.Put

							m.Method = "PUT"
						}

						if http.Post != "" {
							m.PathArgs = parsePath(messages, meth, http.Post)
							m.Path = http.Post

							m.Method = "POST"
						}

						if http.Delete != "" {
							m.PathArgs = parsePath(messages, meth, http.Delete)
							m.Path = http.Delete

							m.Method = "DELETE"
						}

						if http.Body != "" {
							if m.Method == "PUT" || m.Method == "POST" {
								m.Body = http.Body
								m.GoBodyName = goise(http.Body)
							} else {
								log.Printf("WARN: Got http.body on non-put, non-post method.")
							}
						}
					}
				}

				s.Methods = append(s.Methods, m)
			}

			if len(s.Methods) > 0 {
				services[svc.GetName()] = s
			}
		}

		if len(services) > 0 {
			fname := strings.Replace(file.GetName(), ".proto", ".pb.kit.go", 1)
			buff := bytes.NewBuffer(nil)
			imps := map[string]string{}
			impUsages := map[string]int{}

			for _, svc := range services {
				for _, meth := range svc.Methods {
					src := sources[meth.InputType]
					if src == file.GetName() {
						continue
					}

					impUsages[src] = impUsages[src] + 1
				}
			}

			for k := range impUsages {
				if v, ok := imports[k]; ok {
					imps[k] = v
					continue
				}

				fmt.Fprintf(os.Stderr, "Import %q is unknown.\n", k)
			}

			// Report template failures through the plugin protocol instead
			// of emitting a truncated file.
			if err := fileTemplate.Execute(buff, struct {
				Package  string
				Imports  map[string]string
				Services map[string]service
			}{goPackage, imps, services}); err != nil {
				ret.Error = proto.String(fmt.Sprintf("generating %s: %v", fname, err))
				return
			}
			data := buff.String()

			ret.File = append(ret.File, &plugin.CodeGeneratorResponse_File{
				Name:    &fname,
				Content: &data,
			})
		}
	}
}
Beispiel #16
0
// marshalObject writes a struct to the Writer.
//
// The output is a brace-delimited object. Regular fields come first, in
// struct declaration order; fields whose Go name starts with "XXX_" and nil
// reference-kind values are skipped, and oneof wrappers are unwrapped to
// their single inner field. Registered proto2 extensions follow, in ascending
// field-number order, named "[<extension name>]". Unregistered extensions are
// skipped. A newline follows "{", and a newline plus indent precedes "}",
// when m.Indent is non-empty.
func (m *Marshaler) marshalObject(out *errWriter, v proto.Message, indent string) error {
	out.write("{")
	if m.Indent != "" {
		out.write("\n")
	}

	structVal := reflect.ValueOf(v).Elem()
	structType := structVal.Type()
	wroteAny := false
	for idx := 0; idx < structVal.NumField(); idx++ {
		fieldVal := structVal.Field(idx)
		fieldDesc := structType.Field(idx)
		if strings.HasPrefix(fieldDesc.Name, "XXX_") {
			continue
		}

		// TODO: proto3 objects should have default values omitted.

		// Only reference kinds may be tested with IsNil; it panics otherwise.
		switch fieldVal.Kind() {
		case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
			if fieldVal.IsNil() {
				continue
			}
		}

		// A oneof field is an interface holding &T{realValue}; unwrap it.
		if fieldDesc.Tag.Get("protobuf_oneof") != "" {
			wrapper := fieldVal.Elem().Elem() // interface -> *T -> T
			fieldVal = wrapper.Field(0)
			fieldDesc = wrapper.Type().Field(0)
		}
		props := jsonProperties(fieldDesc)
		if wroteAny {
			m.writeSep(out)
		}
		if err := m.marshalField(out, props, fieldVal, indent); err != nil {
			return err
		}
		wroteAny = true
	}

	// Handle proto2 extensions.
	if ep, isExtendable := v.(extendableProto); isExtendable {
		registered := proto.RegisteredExtensions(v)
		present := ep.ExtensionMap()
		// Visit field numbers in ascending order so output is stable.
		fieldNums := make([]int32, 0, len(present))
		for num := range present {
			fieldNums = append(fieldNums, num)
		}
		sort.Sort(int32Slice(fieldNums))
		for _, num := range fieldNums {
			desc := registered[num]
			if desc == nil {
				// unknown extension
				continue
			}
			extVal, err := proto.GetExtension(ep, desc)
			if err != nil {
				return err
			}
			var props proto.Properties
			props.Parse(desc.Tag)
			props.OrigName = fmt.Sprintf("[%s]", desc.Name)
			if wroteAny {
				m.writeSep(out)
			}
			if err := m.marshalField(out, &props, reflect.ValueOf(extVal), indent); err != nil {
				return err
			}
			wroteAny = true
		}
	}

	if m.Indent != "" {
		out.write("\n")
		out.write(indent)
	}
	out.write("}")
	return out.err
}
Beispiel #17
0
// marshalObject writes a struct to the Writer.
//
// Well-known wrapper types are written as their wrapped value. Duration and
// Timestamp are written as JSON strings. Every other message becomes a
// brace-delimited object:
//   - regular fields come first, in struct declaration order; "XXX_" fields
//     and nil reference-kind values are skipped;
//   - zero scalars are skipped unless m.EmitDefaults is set;
//   - oneof wrappers are unwrapped to their single inner field;
//   - registered proto2 extensions follow, in ascending field-number order,
//     named "[<extension name>]".
func (m *Marshaler) marshalObject(out *errWriter, v proto.Message, indent string) error {
	s := reflect.ValueOf(v).Elem()

	// Handle well-known types.
	type wkt interface {
		XXX_WellKnownType() string
	}
	if wkt, ok := v.(wkt); ok {
		switch wkt.XXX_WellKnownType() {
		case "DoubleValue", "FloatValue", "Int64Value", "UInt64Value",
			"Int32Value", "UInt32Value", "BoolValue", "StringValue", "BytesValue":
			// "Wrappers use the same representation in JSON
			//  as the wrapped primitive type, ..."
			sprop := proto.GetProperties(s.Type())
			return m.marshalValue(out, sprop.Prop[0], s.Field(0), indent)
		case "Duration":
			// "Generated output always contains 3, 6, or 9 fractional digits,
			//  depending on required precision."
			// Format seconds and nanos as integers. Going through
			// time.Duration/float64 overflows beyond ~292 years and loses
			// nanosecond precision beyond ~104 days.
			secs, nanos := s.Field(0).Int(), s.Field(1).Int()
			// Fold any whole seconds held in nanos into secs, then give both
			// the same sign so the value prints as "<secs>.<9 digits>".
			secs += nanos / 1e9
			nanos %= 1e9
			if secs > 0 && nanos < 0 {
				secs--
				nanos += 1e9
			} else if secs < 0 && nanos > 0 {
				secs++
				nanos -= 1e9
			}
			sign := ""
			if secs < 0 || nanos < 0 {
				sign = "-"
				secs, nanos = -secs, -nanos
			}
			x := fmt.Sprintf("%s%d.%09d", sign, secs, nanos)
			x = strings.TrimSuffix(x, "000")
			x = strings.TrimSuffix(x, "000")
			out.write(`"`)
			out.write(x)
			out.write(`s"`)
			return out.err
		case "Timestamp":
			// "RFC 3339, where generated output will always be Z-normalized
			//  and uses 3, 6 or 9 fractional digits."
			secs, nanos := s.Field(0).Int(), s.Field(1).Int()
			t := time.Unix(secs, nanos).UTC()
			// time.RFC3339Nano isn't exactly right (we need to get 3/6/9 fractional digits).
			x := t.Format("2006-01-02T15:04:05.000000000")
			x = strings.TrimSuffix(x, "000")
			x = strings.TrimSuffix(x, "000")
			out.write(`"`)
			out.write(x)
			out.write(`Z"`)
			return out.err
		}
	}

	out.write("{")
	if m.Indent != "" {
		out.write("\n")
	}

	firstField := true
	for i := 0; i < s.NumField(); i++ {
		value := s.Field(i)
		valueField := s.Type().Field(i)
		if strings.HasPrefix(valueField.Name, "XXX_") {
			continue
		}

		// IsNil will panic on most value kinds.
		switch value.Kind() {
		case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
			if value.IsNil() {
				continue
			}
		}

		if !m.EmitDefaults {
			switch value.Kind() {
			case reflect.Bool:
				if !value.Bool() {
					continue
				}
			case reflect.Int32, reflect.Int64:
				if value.Int() == 0 {
					continue
				}
			case reflect.Uint32, reflect.Uint64:
				if value.Uint() == 0 {
					continue
				}
			case reflect.Float32, reflect.Float64:
				if value.Float() == 0 {
					continue
				}
			case reflect.String:
				if value.Len() == 0 {
					continue
				}
			}
		}

		// Oneof fields need special handling.
		if valueField.Tag.Get("protobuf_oneof") != "" {
			// value is an interface containing &T{real_value}.
			sv := value.Elem().Elem() // interface -> *T -> T
			value = sv.Field(0)
			valueField = sv.Type().Field(0)
		}
		prop := jsonProperties(valueField, m.OrigName)
		if !firstField {
			m.writeSep(out)
		}
		if err := m.marshalField(out, prop, value, indent); err != nil {
			return err
		}
		firstField = false
	}

	// Handle proto2 extensions.
	if ep, ok := v.(extendableProto); ok {
		extensions := proto.RegisteredExtensions(v)
		extensionMap := ep.ExtensionMap()
		// Sort extensions for stable output.
		ids := make([]int32, 0, len(extensionMap))
		for id := range extensionMap {
			ids = append(ids, id)
		}
		sort.Sort(int32Slice(ids))
		for _, id := range ids {
			desc := extensions[id]
			if desc == nil {
				// unknown extension
				continue
			}
			ext, extErr := proto.GetExtension(ep, desc)
			if extErr != nil {
				return extErr
			}
			value := reflect.ValueOf(ext)
			var prop proto.Properties
			prop.Parse(desc.Tag)
			prop.JSONName = fmt.Sprintf("[%s]", desc.Name)
			if !firstField {
				m.writeSep(out)
			}
			if err := m.marshalField(out, &prop, value, indent); err != nil {
				return err
			}
			firstField = false
		}

	}

	if m.Indent != "" {
		out.write("\n")
		out.write(indent)
	}
	out.write("}")
	return out.err
}
Beispiel #18
0
// TestMarshalUnmarshalRepeatedExtension checks that a slice stored in the
// repeated message extension pb.E_RComplex on a pb.OtherMessage comes back
// deeply equal after a proto.Marshal / proto.Unmarshal round trip.
func TestMarshalUnmarshalRepeatedExtension(t *testing.T) {
	// Each case is the slice of elements stored in the repeated extension.
	tests := []struct {
		name string
		ext  []*pb.ComplexExtension
	}{
		{
			"two fields",
			[]*pb.ComplexExtension{
				{First: proto.Int32(7)},
				{Second: proto.Int32(11)},
			},
		},
		{
			"repeated field",
			[]*pb.ComplexExtension{
				{Third: []int32{1000}},
				{Third: []int32{2000}},
			},
		},
		{
			"two fields and repeated field",
			[]*pb.ComplexExtension{
				{Third: []int32{1000}},
				{First: proto.Int32(9)},
				{Second: proto.Int32(21)},
				{Third: []int32{2000}},
			},
		},
	}
	for _, test := range tests {
		// Marshal message with a repeated extension.
		msg1 := new(pb.OtherMessage)
		if err := proto.SetExtension(msg1, pb.E_RComplex, test.ext); err != nil {
			t.Fatalf("[%s] Error setting extension: %v", test.name, err)
		}
		b, err := proto.Marshal(msg1)
		if err != nil {
			t.Fatalf("[%s] Error marshaling message: %v", test.name, err)
		}

		// Unmarshal into a fresh message and read the extension back.
		msg2 := new(pb.OtherMessage)
		if err := proto.Unmarshal(b, msg2); err != nil {
			t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
		}
		e, err := proto.GetExtension(msg2, pb.E_RComplex)
		if err != nil {
			t.Fatalf("[%s] Error getting extension: %v", test.name, err)
		}
		// Two-value assertion: a value of the wrong type fails the test
		// with a message instead of panicking.
		ext, ok := e.([]*pb.ComplexExtension)
		if !ok || ext == nil {
			t.Fatalf("[%s] Invalid extension: %T", test.name, e)
		}
		if !reflect.DeepEqual(ext, test.ext) {
			t.Errorf("[%s] Wrong value for ComplexExtension: got: %v want: %v", test.name, ext, test.ext)
		}
	}
}
Beispiel #19
0
// TestGetExtensionDefaults checks what proto.GetExtension reports for each
// scalar extension of pb.DefaultsMessage in three states:
//   - before the extension is set,
//   - after SetExtension,
//   - after ClearExtension.
//
// While unset, an extension without a declared default must fail with
// proto.ErrMissingExtension. An extension with a default must report that
// default instead.
func TestGetExtensionDefaults(t *testing.T) {
	var setFloat64 float64 = 1
	var setFloat32 float32 = 2
	var setInt32 int32 = 3
	var setInt64 int64 = 4
	var setUint32 uint32 = 5
	var setUint64 uint64 = 6
	var setBool = true
	var setBool2 = false
	var setString = "Goodnight string"
	var setBytes = []byte("Goodnight bytes")
	var setEnum = pb.DefaultsMessage_TWO

	type testcase struct {
		ext  *proto.ExtensionDesc // Extension we are testing.
		want interface{}          // Value to set, and then expect back from GetExtension.
		def  interface{}          // Expected value while unset (initially and after ClearExtension); nil means GetExtension must fail with ErrMissingExtension.
	}
	tests := []testcase{
		{pb.E_NoDefaultDouble, setFloat64, nil},
		{pb.E_NoDefaultFloat, setFloat32, nil},
		{pb.E_NoDefaultInt32, setInt32, nil},
		{pb.E_NoDefaultInt64, setInt64, nil},
		{pb.E_NoDefaultUint32, setUint32, nil},
		{pb.E_NoDefaultUint64, setUint64, nil},
		{pb.E_NoDefaultSint32, setInt32, nil},
		{pb.E_NoDefaultSint64, setInt64, nil},
		{pb.E_NoDefaultFixed32, setUint32, nil},
		{pb.E_NoDefaultFixed64, setUint64, nil},
		{pb.E_NoDefaultSfixed32, setInt32, nil},
		{pb.E_NoDefaultSfixed64, setInt64, nil},
		{pb.E_NoDefaultBool, setBool, nil},
		{pb.E_NoDefaultBool, setBool2, nil},
		{pb.E_NoDefaultString, setString, nil},
		{pb.E_NoDefaultBytes, setBytes, nil},
		{pb.E_NoDefaultEnum, setEnum, nil},
		{pb.E_DefaultDouble, setFloat64, float64(3.1415)},
		{pb.E_DefaultFloat, setFloat32, float32(3.14)},
		{pb.E_DefaultInt32, setInt32, int32(42)},
		{pb.E_DefaultInt64, setInt64, int64(43)},
		{pb.E_DefaultUint32, setUint32, uint32(44)},
		{pb.E_DefaultUint64, setUint64, uint64(45)},
		{pb.E_DefaultSint32, setInt32, int32(46)},
		{pb.E_DefaultSint64, setInt64, int64(47)},
		{pb.E_DefaultFixed32, setUint32, uint32(48)},
		{pb.E_DefaultFixed64, setUint64, uint64(49)},
		{pb.E_DefaultSfixed32, setInt32, int32(50)},
		{pb.E_DefaultSfixed64, setInt64, int64(51)},
		{pb.E_DefaultBool, setBool, true},
		{pb.E_DefaultBool, setBool2, true},
		{pb.E_DefaultString, setString, "Hello, string"},
		{pb.E_DefaultBytes, setBytes, []byte("Hello, bytes")},
		{pb.E_DefaultEnum, setEnum, pb.DefaultsMessage_ONE},
	}

	// checkVal returns nil when GetExtension(msg, test.ext) yields valWant.
	// A nil valWant means the extension is expected to be absent, i.e.
	// GetExtension must return proto.ErrMissingExtension.
	checkVal := func(test testcase, msg *pb.DefaultsMessage, valWant interface{}) error {
		val, err := proto.GetExtension(msg, test.ext)
		if err != nil {
			if valWant != nil {
				return fmt.Errorf("GetExtension(): %s", err)
			}
			if want := proto.ErrMissingExtension; err != want {
				return fmt.Errorf("Unexpected error: got %v, want %v", err, want)
			}
			return nil
		}

		// All proto2 extension values are either a pointer to a value or a slice of values.
		ty := reflect.TypeOf(val)
		tyWant := reflect.TypeOf(test.ext.ExtensionType)
		if got, want := ty, tyWant; got != want {
			return fmt.Errorf("unexpected reflect.TypeOf(): got %v want %v", got, want)
		}
		tye := ty.Elem()
		tyeWant := tyWant.Elem()
		if got, want := tye, tyeWant; got != want {
			return fmt.Errorf("unexpected reflect.TypeOf().Elem(): got %v want %v", got, want)
		}

		// Check the name of the type of the value.
		// If it is an enum it will be type int32 with the name of the enum.
		// (Previously this compared tye.Name() with itself and could never fail.)
		if got, want := tye.Name(), tyeWant.Name(); got != want {
			return fmt.Errorf("unexpected reflect.TypeOf().Elem().Name(): got %v want %v", got, want)
		}

		// Check that value is what we expect.
		// If we have a pointer in val, get the value it points to.
		valExp := val
		if ty.Kind() == reflect.Ptr {
			valExp = reflect.ValueOf(val).Elem().Interface()
		}
		if got, want := valExp, valWant; !reflect.DeepEqual(got, want) {
			return fmt.Errorf("unexpected reflect.DeepEqual(): got %v want %v", got, want)
		}

		return nil
	}

	// setTo converts test.want into the form SetExtension expects for
	// test.ext. When ExtensionType is a pointer type, it returns a freshly
	// allocated pointer to a copy of test.want. Otherwise (e.g. []byte) it
	// returns test.want unchanged.
	setTo := func(test testcase) interface{} {
		if typ := reflect.TypeOf(test.ext.ExtensionType); typ.Kind() == reflect.Ptr {
			ptr := reflect.New(typ.Elem())
			ptr.Elem().Set(reflect.ValueOf(test.want))
			return ptr.Interface()
		}
		return test.want
	}

	for _, test := range tests {
		msg := &pb.DefaultsMessage{}
		name := test.ext.Name

		// Check the initial value.
		if err := checkVal(test, msg, test.def); err != nil {
			t.Errorf("%s: %v", name, err)
		}

		// Set the per-type value and check value.
		name = fmt.Sprintf("%s (set to %T %v)", name, test.want, test.want)
		if err := proto.SetExtension(msg, test.ext, setTo(test)); err != nil {
			t.Errorf("%s: SetExtension(): %v", name, err)
			continue
		}
		if err := checkVal(test, msg, test.want); err != nil {
			t.Errorf("%s: %v", name, err)
			continue
		}

		// Clear the extension and check it is back to its unset value.
		name += " (cleared)"
		proto.ClearExtension(msg, test.ext)
		if err := checkVal(test, msg, test.def); err != nil {
			t.Errorf("%s: %v", name, err)
		}
	}
}
Beispiel #20
0
// TestUnmarshalRepeatingNonRepeatedExtension checks wire data containing
// several occurrences of the same non-repeated message extension
// (pb.E_Complex). proto.Unmarshal must merge them into one value equal to
// proto.Merge applied to each occurrence in order.
func TestUnmarshalRepeatingNonRepeatedExtension(t *testing.T) {
	// We may see multiple instances of the same extension in the wire
	// format. For example, the proto compiler may encode custom options in
	// this way. Here, we verify that we merge the extensions together.
	tests := []struct {
		name string
		ext  []*pb.ComplexExtension
	}{
		{
			"two fields",
			[]*pb.ComplexExtension{
				{First: proto.Int32(7)},
				{Second: proto.Int32(11)},
			},
		},
		{
			"repeated field",
			[]*pb.ComplexExtension{
				{Third: []int32{1000}},
				{Third: []int32{2000}},
			},
		},
		{
			"two fields and repeated field",
			[]*pb.ComplexExtension{
				{Third: []int32{1000}},
				{First: proto.Int32(9)},
				{Second: proto.Int32(21)},
				{Third: []int32{2000}},
			},
		},
	}
	for _, test := range tests {
		var buf bytes.Buffer
		// Build the expected message behind a pointer so no message struct
		// is ever copied by value (the String method and any no-copy
		// internals live on the pointer).
		want := new(pb.ComplexExtension)

		// Generate a serialized representation of a repeated extension
		// by catenating bytes together.
		for i, e := range test.ext {
			// Merge to create the wanted proto.
			proto.Merge(want, e)

			// Serialize a message carrying just this occurrence.
			msg := new(pb.OtherMessage)
			if err := proto.SetExtension(msg, pb.E_Complex, e); err != nil {
				t.Fatalf("[%s] Error setting extension %d: %v", test.name, i, err)
			}
			b, err := proto.Marshal(msg)
			if err != nil {
				t.Fatalf("[%s] Error marshaling message %d: %v", test.name, i, err)
			}
			buf.Write(b)
		}

		// Unmarshal and read the merged proto.
		msg2 := new(pb.OtherMessage)
		if err := proto.Unmarshal(buf.Bytes(), msg2); err != nil {
			t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
		}
		e, err := proto.GetExtension(msg2, pb.E_Complex)
		if err != nil {
			t.Fatalf("[%s] Error getting extension: %v", test.name, err)
		}
		// Two-value assertion: a value of the wrong type fails the test
		// with a message instead of panicking.
		ext, ok := e.(*pb.ComplexExtension)
		if !ok || ext == nil {
			t.Fatalf("[%s] Invalid extension: %T", test.name, e)
		}
		// reflect.DeepEqual follows pointers, so this compares the pointed-to
		// messages exactly as comparing *ext with a value would.
		if !reflect.DeepEqual(ext, want) {
			t.Errorf("[%s] Wrong value for ComplexExtension: got: %v want: %v", test.name, ext, want)
		}
	}
}