示例#1
0
// StreamingOutputCall sends one StreamingOutputCallResponse per entry in the
// request's ResponseParameters. Before each send it sleeps for the entry's
// IntervalUs microseconds (when positive). For testing purposes, the call
// fails with codes.DataLoss if the client attached any metadata at all.
func (s *testServer) StreamingOutputCall(args *testpb.StreamingOutputCallRequest, stream testpb.TestService_StreamingOutputCallServer) error {
	if md, ok := metadata.FromContext(stream.Context()); ok && len(md) > 0 {
		return grpc.Errorf(codes.DataLoss, "got extra metadata")
	}
	respType := args.GetResponseType()
	for _, param := range args.GetResponseParameters() {
		if delay := param.GetIntervalUs(); delay > 0 {
			time.Sleep(time.Microsecond * time.Duration(delay))
		}
		p, err := newPayload(respType, param.GetSize())
		if err != nil {
			return err
		}
		resp := &testpb.StreamingOutputCallResponse{Payload: p}
		if err := stream.Send(resp); err != nil {
			return err
		}
	}
	return nil
}
示例#2
0
// EmptyCall replies with an empty message.
//
// If the incoming context carries metadata, that metadata must contain a
// "user-agent" key; otherwise the call fails with codes.DataLoss. Only the
// presence of "user-agent" is checked here; other keys are not inspected.
// Every user-agent value is echoed back to the client in the response header
// under the key "ua". A failure to send that header is returned to the caller
// rather than silently dropped.
func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
	if md, ok := metadata.FromContext(ctx); ok {
		// For testing purposes, require the user agent set by the client
		// application to be present.
		uas, found := md["user-agent"]
		if !found {
			return nil, grpc.Errorf(codes.DataLoss, "got extra metadata")
		}
		pairs := make([]string, 0, 2*len(uas))
		for _, ua := range uas {
			pairs = append(pairs, "ua", ua)
		}
		header := metadata.Pairs(pairs...)
		// Report header-send failures, consistent with how UnaryCall treats
		// the error from grpc.SendHeader.
		if err := grpc.SendHeader(ctx, header); err != nil {
			return nil, fmt.Errorf("grpc.SendHeader(%v, %v) = %v, want %v", ctx, header, err, nil)
		}
	}
	return new(testpb.Empty), nil
}
示例#3
0
// UnaryCall handles a single request in the following order:
//   - It echoes any incoming metadata back as both the response header and trailer.
//   - When s.security is non-empty, it requires TLS auth info whose auth type
//     equals s.security and whose server name is "x.test.youtube.com".
//   - It sleeps for one second to simulate service latency.
//   - It replies with a payload of the requested type and size.
func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
	if md, ok := metadata.FromContext(ctx); ok {
		if err := grpc.SendHeader(ctx, md); err != nil {
			return nil, fmt.Errorf("grpc.SendHeader(%v, %v) = %v, want %v", ctx, md, err, nil)
		}
		grpc.SetTrailer(ctx, md)
	}
	if s.security != "" {
		// Verify the auth info attached to the context.
		authInfo, ok := credentials.FromContext(ctx)
		if !ok {
			return nil, fmt.Errorf("Failed to get AuthInfo from ctx.")
		}
		tlsInfo, isTLS := authInfo.(credentials.TLSInfo)
		if !isTLS {
			return nil, fmt.Errorf("Unknown AuthInfo type")
		}
		if got := tlsInfo.AuthType(); got != s.security {
			return nil, fmt.Errorf("Wrong auth type: got %q, want %q", got, s.security)
		}
		if name := tlsInfo.State.ServerName; name != "x.test.youtube.com" {
			return nil, fmt.Errorf("Unknown server name %q", name)
		}
	}

	// Pretend the service needs time to do its work.
	time.Sleep(time.Second)

	p, err := newPayload(in.GetResponseType(), in.GetResponseSize())
	if err != nil {
		return nil, err
	}
	resp := &testpb.SimpleResponse{Payload: p}
	return resp, nil
}
示例#4
0
// FullDuplexCall first echoes any incoming metadata back as the stream's
// header and trailer. It then serves requests until the client stops sending.
// For each request it sends one StreamingOutputCallResponse per
// ResponseParameters entry, sleeping IntervalUs microseconds (when positive)
// before each send. io.EOF from Recv ends the call successfully; any other
// receive, payload, or send error is returned.
func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
	if md, ok := metadata.FromContext(stream.Context()); ok {
		if err := stream.SendHeader(md); err != nil {
			return fmt.Errorf("%v.SendHeader(%v) = %v, want %v", stream, md, err, nil)
		}
		stream.SetTrailer(md)
	}
	for {
		req, err := stream.Recv()
		switch {
		case err == io.EOF:
			// The client is done sending.
			return nil
		case err != nil:
			return err
		}
		respType := req.GetResponseType()
		for _, param := range req.GetResponseParameters() {
			if delay := param.GetIntervalUs(); delay > 0 {
				time.Sleep(time.Microsecond * time.Duration(delay))
			}
			p, err := newPayload(respType, param.GetSize())
			if err != nil {
				return err
			}
			resp := &testpb.StreamingOutputCallResponse{Payload: p}
			if err := stream.Send(resp); err != nil {
				return err
			}
		}
	}
}
示例#5
0
// NewStream creates a stream and registers it in the transport as an
// "active" stream.
//
// The steps are:
//   - Fail fast with a deadline error if ctx's deadline has already passed.
//   - Collect per-RPC auth metadata from every entry in t.authCreds.
//   - Acquire one unit of stream quota (when t.streamsQuota is set).
//   - Wait on t.writableChan (which appears to act as a write token; it is
//     refilled on success).
//   - Send the request headers as a HEADERS frame plus any CONTINUATION frames.
//
// ErrConnClosing is returned if the transport is not reachable.
func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) {
	// Record the timeout value on the context.
	var timeout time.Duration
	if dl, ok := ctx.Deadline(); ok {
		timeout = dl.Sub(time.Now())
		if timeout <= 0 {
			return nil, ContextErr(context.DeadlineExceeded)
		}
	}
	// Attach Auth info if there is any.
	if t.authInfo != nil {
		ctx = credentials.NewContext(ctx, t.authInfo)
	}
	authData := make(map[string]string)
	if len(t.authCreds) > 0 {
		// Construct the URI (audience) required to get auth request metadata.
		// It depends only on the target and the call, so build it once rather
		// than once per credential.
		var port string
		if pos := strings.LastIndex(t.target, ":"); pos != -1 {
			// Omit port if it is the default one.
			if t.target[pos+1:] != "443" {
				port = ":" + t.target[pos+1:]
			}
		}
		pos := strings.LastIndex(callHdr.Method, "/")
		if pos == -1 {
			return nil, StreamErrorf(codes.InvalidArgument, "transport: malformed method name: %q", callHdr.Method)
		}
		audience := "https://" + callHdr.Host + port + callHdr.Method[:pos]
		for _, c := range t.authCreds {
			data, err := c.GetRequestMetadata(ctx, audience)
			if err != nil {
				return nil, StreamErrorf(codes.InvalidArgument, "transport: %v", err)
			}
			for k, v := range data {
				authData[k] = v
			}
		}
	}
	t.mu.Lock()
	if t.state != reachable {
		t.mu.Unlock()
		return nil, ErrConnClosing
	}
	checkStreamsQuota := t.streamsQuota != nil
	t.mu.Unlock()
	if checkStreamsQuota {
		sq, err := wait(ctx, t.shutdownChan, t.streamsQuota.acquire())
		if err != nil {
			return nil, err
		}
		// Returns the quota balance back.
		if sq > 1 {
			t.streamsQuota.add(sq - 1)
		}
	}
	if _, err := wait(ctx, t.shutdownChan, t.writableChan); err != nil {
		// No stream is handed to the caller, so CloseStream will never run for
		// it. Return the quota unit acquired above now, or it leaks.
		if checkStreamsQuota {
			t.streamsQuota.add(1)
		}
		return nil, err
	}
	t.mu.Lock()
	if t.state != reachable {
		t.mu.Unlock()
		return nil, ErrConnClosing
	}
	s := t.newStream(ctx, callHdr)
	t.activeStreams[s.id] = s

	// This stream is not counted when applySettings(...) initializes
	// t.streamsQuota. Reset t.streamsQuota to the right value.
	var reset bool
	if !checkStreamsQuota && t.streamsQuota != nil {
		reset = true
	}
	t.mu.Unlock()
	if reset {
		t.streamsQuota.reset(-1)
	}

	// HPACK encodes various headers. Note that once WriteField(...) is
	// called, the corresponding headers/continuation frame has to be sent
	// because hpack.Encoder is stateful.
	t.hBuf.Reset()
	t.hEnc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"})
	t.hEnc.WriteField(hpack.HeaderField{Name: ":scheme", Value: t.scheme})
	t.hEnc.WriteField(hpack.HeaderField{Name: ":path", Value: callHdr.Method})
	t.hEnc.WriteField(hpack.HeaderField{Name: ":authority", Value: callHdr.Host})
	t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
	t.hEnc.WriteField(hpack.HeaderField{Name: "user-agent", Value: t.userAgent})
	t.hEnc.WriteField(hpack.HeaderField{Name: "te", Value: "trailers"})

	if timeout > 0 {
		t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: timeoutEncode(timeout)})
	}
	for k, v := range authData {
		t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: v})
	}
	var (
		hasMD      bool
		endHeaders bool
	)
	if md, ok := metadata.FromContext(ctx); ok {
		hasMD = true
		for k, v := range md {
			for _, entry := range v {
				t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry})
			}
		}
	}
	first := true
	// Sends the headers in a single batch even when they span multiple frames.
	for !endHeaders {
		size := t.hBuf.Len()
		if size > http2MaxFrameLen {
			size = http2MaxFrameLen
		} else {
			endHeaders = true
		}
		if first {
			// Sends a HeadersFrame to server to start a new stream.
			p := http2.HeadersFrameParam{
				StreamID:      s.id,
				BlockFragment: t.hBuf.Next(size),
				EndStream:     false,
				EndHeaders:    endHeaders,
			}
			// Do a force flush for the buffered frames iff it is the last headers frame
			// and there is header metadata to be sent. Otherwise, there is no flushing
			// until the corresponding data frame is written.
			err = t.framer.writeHeaders(hasMD && endHeaders, p)
			first = false
		} else {
			// Sends Continuation frames for the leftover headers.
			err = t.framer.writeContinuation(hasMD && endHeaders, s.id, endHeaders, t.hBuf.Next(size))
		}
		if err != nil {
			t.notifyError(err)
			return nil, ConnectionErrorf("transport: %v", err)
		}
	}
	t.writableChan <- 0
	return s, nil
}