Example 1
func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options) error {
	var (
		cbuf       *bytes.Buffer
		outPayload *stats.OutPayload
	)
	if cp != nil {
		cbuf = new(bytes.Buffer)
	}
	if stats.On() {
		outPayload = &stats.OutPayload{}
	}
	p, err := encode(s.opts.codec, msg, cp, cbuf, outPayload)
	if err != nil {
		// This typically indicates a fatal issue (e.g., memory
		// corruption or hardware faults) that the application program
		// cannot handle.
		//
		// TODO(zhaoq): There exist other options also such as only closing the
		// faulty stream locally and remotely (Other streams can keep going). Find
		// the optimal option.
		grpclog.Fatalf("grpc: Server failed to encode response %v", err)
	}
	err = t.Write(stream, p, opts)
	if err == nil && outPayload != nil {
		outPayload.SentTime = time.Now()
		stats.HandleRPC(stream.Context(), outPayload)
	}
	return err
}
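
Every example in this set follows the same instrumentation pattern: gate on stats.On(), fill in one of the stats payload types, and hand it to stats.HandleRPC. A minimal consumer of that data might look like the sketch below; the RegisterRPCHandler and Start registration calls are assumptions about this era's package-level stats API (verify against your grpc-go version), while the types and fields are the ones these examples populate.

// Minimal sketch of a stats consumer, assuming the package-level
// registration API of this grpc-go vintage (RegisterRPCHandler and Start
// are assumptions; check your stats package).
func enableRPCStats() {
	stats.RegisterRPCHandler(func(ctx context.Context, s stats.RPCStats) {
		switch st := s.(type) {
		case *stats.OutPayload:
			log.Printf("sent %d wire bytes at %v", st.WireLength, st.SentTime)
		case *stats.InPayload:
			log.Printf("received %d wire bytes", st.WireLength)
		case *stats.End:
			log.Printf("RPC finished: err=%v", st.Error)
		}
	})
	stats.Start() // assumed toggle; afterwards stats.On() reports true
}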
Example 2
func (ss *serverStream) RecvMsg(m interface{}) (err error) {
	defer func() {
		if ss.trInfo != nil {
			ss.mu.Lock()
			if ss.trInfo.tr != nil {
				if err == nil {
					ss.trInfo.tr.LazyLog(&payload{sent: false, msg: m}, true)
				} else if err != io.EOF {
					ss.trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
					ss.trInfo.tr.SetError()
				}
			}
			ss.mu.Unlock()
		}
	}()
	var inPayload *stats.InPayload
	if stats.On() {
		inPayload = &stats.InPayload{}
	}
	if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize, inPayload); err != nil {
		if err == io.EOF {
			return err
		}
		if err == io.ErrUnexpectedEOF {
			err = Errorf(codes.Internal, io.ErrUnexpectedEOF.Error())
		}
		return toRPCErr(err)
	}
	if inPayload != nil {
		stats.HandleRPC(ss.s.Context(), inPayload)
	}
	return nil
}
Example 3
// WriteStatus sends the stream status to the client and terminates the stream.
// No further I/O operations can be performed on this stream.
// TODO(zhaoq): Now it indicates the end of entire stream. Revisit if early
// OK is adopted.
func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error {
	var headersSent, hasHeader bool
	s.mu.Lock()
	if s.state == streamDone {
		s.mu.Unlock()
		return nil
	}
	if s.headerOk {
		headersSent = true
	}
	if s.header.Len() > 0 {
		hasHeader = true
	}
	s.mu.Unlock()

	if !headersSent && hasHeader {
		t.WriteHeader(s, nil)
		headersSent = true
	}

	if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
		return err
	}
	t.hBuf.Reset()
	if !headersSent {
		t.hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
		t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
	}
	t.hEnc.WriteField(
		hpack.HeaderField{
			Name:  "grpc-status",
			Value: strconv.Itoa(int(statusCode)),
		})
	t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(statusDesc)})
	// Attach the trailer metadata.
	for k, v := range s.trailer {
		// Clients don't tolerate reading restricted headers after some non-restricted ones were sent.
		if isReservedHeader(k) {
			continue
		}
		for _, entry := range v {
			t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry})
		}
	}
	bufLen := t.hBuf.Len()
	if err := t.writeHeaders(s, t.hBuf, true); err != nil {
		t.Close()
		return err
	}
	if stats.On() {
		outTrailer := &stats.OutTrailer{
			WireLength: bufLen,
		}
		stats.HandleRPC(s.Context(), outTrailer)
	}
	t.closeStream(s)
	t.writableChan <- 0
	return nil
}
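
WriteStatus depends on encodeGrpcMessage to make the status description safe as an HTTP/2 header value. Per the gRPC HTTP/2 wire spec, '%' and bytes outside the printable ASCII range are percent-encoded; the sketch below is a rough re-implementation for illustration, not the exact grpc-go helper.

// Rough sketch of grpc-message percent-encoding per the gRPC HTTP/2 spec;
// the real encodeGrpcMessage may differ in edge cases.
func encodeGrpcMessageSketch(msg string) string {
	var buf bytes.Buffer
	for i := 0; i < len(msg); i++ {
		c := msg[i]
		if c >= 0x20 && c <= 0x7E && c != '%' {
			buf.WriteByte(c) // printable ASCII passes through unchanged
		} else {
			fmt.Fprintf(&buf, "%%%02X", c) // everything else is %XX-escaped
		}
	}
	return buf.String()
}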
Example 4
func (cs *clientStream) SendMsg(m interface{}) (err error) {
	if cs.tracing {
		cs.mu.Lock()
		if cs.trInfo.tr != nil {
			cs.trInfo.tr.LazyLog(&payload{sent: true, msg: m}, true)
		}
		cs.mu.Unlock()
	}
	// TODO: Investigate how to signal the stats handling party.
	// Generate error stats if err != nil && err != io.EOF?
	defer func() {
		if err != nil {
			cs.finish(err)
		}
		if err == nil {
			return
		}
		if err == io.EOF {
			// Specialize the process for server streaming. SendMsg is only called
			// once when creating the stream object. io.EOF needs to be skipped when
			// the RPC finishes early (before the stream object is created).
			// TODO: It is probably better to move this into the generated code.
			if !cs.desc.ClientStreams && cs.desc.ServerStreams {
				err = nil
			}
			return
		}
		if _, ok := err.(transport.ConnectionError); !ok {
			cs.closeTransportStream(err)
		}
		err = toRPCErr(err)
	}()
	var outPayload *stats.OutPayload
	if stats.On() {
		outPayload = &stats.OutPayload{
			Client: true,
		}
	}
	out, err := encode(cs.codec, m, cs.cp, cs.cbuf, outPayload)
	defer func() {
		if cs.cbuf != nil {
			cs.cbuf.Reset()
		}
	}()
	if err != nil {
		return Errorf(codes.Internal, "grpc: %v", err)
	}
	err = cs.t.Write(cs.s, out, &transport.Options{Last: false})
	if err == nil && outPayload != nil {
		outPayload.SentTime = time.Now()
		stats.HandleRPC(cs.statsCtx, outPayload)
	}
	return err
}
Example 5
// WriteHeader sends the header metadata md back to the client.
func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
	s.mu.Lock()
	if s.headerOk || s.state == streamDone {
		s.mu.Unlock()
		return ErrIllegalHeaderWrite
	}
	s.headerOk = true
	if md.Len() > 0 {
		if s.header.Len() > 0 {
			s.header = metadata.Join(s.header, md)
		} else {
			s.header = md
		}
	}
	md = s.header
	s.mu.Unlock()
	if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
		return err
	}
	t.hBuf.Reset()
	t.hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
	t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
	if s.sendCompress != "" {
		t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress})
	}
	for k, v := range md {
		if isReservedHeader(k) {
			// Clients don't tolerate reading restricted headers after some non-restricted ones were sent.
			continue
		}
		for _, entry := range v {
			t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry})
		}
	}
	bufLen := t.hBuf.Len()
	if err := t.writeHeaders(s, t.hBuf, false); err != nil {
		return err
	}
	if stats.On() {
		outHeader := &stats.OutHeader{
			WireLength: bufLen,
		}
		stats.HandleRPC(s.Context(), outHeader)
	}
	t.writableChan <- 0
	return nil
}
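
Application code normally reaches WriteHeader through grpc.SendHeader, which finds the stream cached in the handler's context (see the newContextWithStream comment in Example 9). A hedged usage sketch, with placeholder service and message types:

// Hedged usage sketch: grpc.SendHeader flushes metadata through the
// transport's WriteHeader shown above. server, pb.Request, and pb.Reply
// are placeholder names.
func (s *server) Get(ctx context.Context, req *pb.Request) (*pb.Reply, error) {
	md := metadata.Pairs("cache-status", "hit")
	if err := grpc.SendHeader(ctx, md); err != nil {
		return nil, err // fails once headers have already been written
	}
	return &pb.Reply{}, nil
}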
Example 6
// sendRequest writes out the request information of an RPC, such as its Context and message.
func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) {
	stream, err := t.NewStream(ctx, callHdr)
	if err != nil {
		return nil, err
	}
	defer func() {
		if err != nil {
			// If err is a connection error, t will be closed; no need to close the stream here.
			if _, ok := err.(transport.ConnectionError); !ok {
				t.CloseStream(stream, err)
			}
		}
	}()
	var (
		cbuf       *bytes.Buffer
		outPayload *stats.OutPayload
	)
	if compressor != nil {
		cbuf = new(bytes.Buffer)
	}
	if stats.On() {
		outPayload = &stats.OutPayload{
			Client: true,
		}
	}
	outBuf, err := encode(codec, args, compressor, cbuf, outPayload)
	if err != nil {
		return nil, Errorf(codes.Internal, "grpc: %v", err)
	}
	err = t.Write(stream, outBuf, opts)
	if err == nil && outPayload != nil {
		outPayload.SentTime = time.Now()
		stats.HandleRPC(ctx, outPayload)
	}
	// t.NewStream(...) could lead to an early rejection of the RPC (e.g., the service/method
	// does not exist), in which case t.Write could get io.EOF from wait(...). Leave the
	// following recvResponse to get the final status.
	if err != nil && err != io.EOF {
		return nil, err
	}
	// Sent successfully.
	return stream, nil
}
Example 7
func (ss *serverStream) SendMsg(m interface{}) (err error) {
	defer func() {
		if ss.trInfo != nil {
			ss.mu.Lock()
			if ss.trInfo.tr != nil {
				if err == nil {
					ss.trInfo.tr.LazyLog(&payload{sent: true, msg: m}, true)
				} else {
					ss.trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
					ss.trInfo.tr.SetError()
				}
			}
			ss.mu.Unlock()
		}
	}()
	var outPayload *stats.OutPayload
	if stats.On() {
		outPayload = &stats.OutPayload{}
	}
	out, err := encode(ss.codec, m, ss.cp, ss.cbuf, outPayload)
	defer func() {
		if ss.cbuf != nil {
			ss.cbuf.Reset()
		}
	}()
	if err != nil {
		err = Errorf(codes.Internal, "grpc: %v", err)
		return err
	}
	if err := ss.t.Write(ss.s, out, &transport.Options{Last: false}); err != nil {
		return toRPCErr(err)
	}
	if outPayload != nil {
		outPayload.SentTime = time.Now()
		stats.HandleRPC(ss.s.Context(), outPayload)
	}
	return nil
}
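
For context, here is a hedged sketch of the application-level code that drives serverStream.SendMsg and RecvMsg (Examples 2 and 7) through a generated bidi-streaming stub; the Echo service and pb package are illustrative names only.

// Hypothetical bidirectional-streaming handler. stream.Recv/Send are thin
// generated wrappers over serverStream.RecvMsg/SendMsg above.
func (s *echoServer) Chat(stream pb.Echo_ChatServer) error {
	for {
		in, err := stream.Recv()
		if err == io.EOF {
			return nil // client closed its send direction; finish with OK
		}
		if err != nil {
			return err
		}
		if err := stream.Send(in); err != nil {
			return err
		}
	}
}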
Example 8
// recvResponse receives and parses an RPC response.
// On error, it returns the error so the caller can decide whether the call should be retried.
//
// TODO(zhaoq): Check whether the received message sequence is valid.
// TODO ctx is used for stats collection and processing. It is the context passed from the application.
func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) (err error) {
	// Try to acquire header metadata from the server if there is any.
	defer func() {
		if err != nil {
			if _, ok := err.(transport.ConnectionError); !ok {
				t.CloseStream(stream, err)
			}
		}
	}()
	c.headerMD, err = stream.Header()
	if err != nil {
		return
	}
	p := &parser{r: stream}
	var inPayload *stats.InPayload
	if stats.On() {
		inPayload = &stats.InPayload{
			Client: true,
		}
	}
	for {
		if err = recv(p, dopts.codec, stream, dopts.dc, reply, math.MaxInt32, inPayload); err != nil {
			if err == io.EOF {
				break
			}
			return
		}
	}
	if inPayload != nil && err == io.EOF && stream.StatusCode() == codes.OK {
		// TODO in the current implementation, inTrailer may be handled before inPayload in some cases.
		// Fix the order if necessary.
		stats.HandleRPC(ctx, inPayload)
	}
	c.trailerMD = stream.Trailer()
	return nil
}
Example 9
// operateHeaders takes action on the decoded headers.
func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream), traceCtx func(context.Context, string) context.Context) (close bool) {
	buf := newRecvBuffer()
	s := &Stream{
		id:  frame.Header().StreamID,
		st:  t,
		buf: buf,
		fc:  &inFlow{limit: initialWindowSize},
	}

	var state decodeState
	for _, hf := range frame.Fields {
		state.processHeaderField(hf)
	}
	if err := state.err; err != nil {
		if se, ok := err.(StreamError); ok {
			t.controlBuf.put(&resetStream{s.id, statusCodeConvTab[se.Code]})
		}
		return
	}

	if frame.StreamEnded() {
		// s is just created by the caller. No lock needed.
		s.state = streamReadDone
	}
	s.recvCompress = state.encoding
	if state.timeoutSet {
		s.ctx, s.cancel = context.WithTimeout(t.ctx, state.timeout)
	} else {
		s.ctx, s.cancel = context.WithCancel(t.ctx)
	}
	pr := &peer.Peer{
		Addr: t.remoteAddr,
	}
	// Attach Auth info if there is any.
	if t.authInfo != nil {
		pr.AuthInfo = t.authInfo
	}
	s.ctx = peer.NewContext(s.ctx, pr)
	// Cache the current stream in the context so that the server application
	// can retrieve it. This is required when the server wants to send some
	// metadata back to the client (unary call only).
	s.ctx = newContextWithStream(s.ctx, s)
	// Attach the received metadata to the context.
	if len(state.mdata) > 0 {
		s.ctx = metadata.NewContext(s.ctx, state.mdata)
	}

	s.dec = &recvBufferReader{
		ctx:  s.ctx,
		recv: s.buf,
	}
	s.method = state.method
	if t.inTapHandle != nil {
		var err error
		info := &tap.Info{
			FullMethodName: state.method,
		}
		s.ctx, err = t.inTapHandle(s.ctx, info)
		if err != nil {
			// TODO: Log the real error.
			t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream})
			return
		}
	}
	t.mu.Lock()
	if t.state != reachable {
		t.mu.Unlock()
		return
	}
	if uint32(len(t.activeStreams)) >= t.maxStreams {
		t.mu.Unlock()
		t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream})
		return
	}
	if s.id%2 != 1 || s.id <= t.maxStreamID {
		t.mu.Unlock()
		// illegal gRPC stream id.
		grpclog.Println("transport: http2Server.HandleStreams received an illegal stream id: ", s.id)
		return true
	}
	t.maxStreamID = s.id
	s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota))
	t.activeStreams[s.id] = s
	t.mu.Unlock()
	s.windowHandler = func(n int) {
		t.updateWindow(s, uint32(n))
	}
	s.ctx = traceCtx(s.ctx, s.method)
	if stats.On() {
		s.ctx = stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
		inHeader := &stats.InHeader{
			FullMethod:  s.method,
			RemoteAddr:  t.remoteAddr,
			LocalAddr:   t.localAddr,
			Compression: s.recvCompress,
			WireLength:  int(frame.Header().Length),
		}
		stats.HandleRPC(s.ctx, inHeader)
	}
	handle(s)
	return
}
Example 10
func (cs *clientStream) RecvMsg(m interface{}) (err error) {
	defer func() {
		if err != nil && stats.On() {
			// Only generate End if err != nil.
			// If err == nil, it's not the last RecvMsg.
			// The last RecvMsg gets either an RPC error or io.EOF.
			end := &stats.End{
				Client:  true,
				EndTime: time.Now(),
			}
			if err != io.EOF {
				end.Error = toRPCErr(err)
			}
			stats.HandleRPC(cs.statsCtx, end)
		}
	}()
	var inPayload *stats.InPayload
	if stats.On() {
		inPayload = &stats.InPayload{
			Client: true,
		}
	}
	err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32, inPayload)
	defer func() {
		// err != nil indicates the termination of the stream.
		if err != nil {
			cs.finish(err)
		}
	}()
	if err == nil {
		if cs.tracing {
			cs.mu.Lock()
			if cs.trInfo.tr != nil {
				cs.trInfo.tr.LazyLog(&payload{sent: false, msg: m}, true)
			}
			cs.mu.Unlock()
		}
		if inPayload != nil {
			stats.HandleRPC(cs.statsCtx, inPayload)
		}
		if !cs.desc.ClientStreams || cs.desc.ServerStreams {
			return
		}
		// Special handling for client streaming rpc.
		// This recv expects EOF or errors, so we don't collect inPayload.
		err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32, nil)
		cs.closeTransportStream(err)
		if err == nil {
			return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
		}
		if err == io.EOF {
			if cs.s.StatusCode() == codes.OK {
				cs.finish(err)
				return nil
			}
			return Errorf(cs.s.StatusCode(), "%s", cs.s.StatusDesc())
		}
		return toRPCErr(err)
	}
	if _, ok := err.(transport.ConnectionError); !ok {
		cs.closeTransportStream(err)
	}
	if err == io.EOF {
		if cs.s.StatusCode() == codes.OK {
			// Returns io.EOF to indicate the end of the stream.
			return
		}
		return Errorf(cs.s.StatusCode(), "%s", cs.s.StatusDesc())
	}
	return toRPCErr(err)
}
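
The "special handling for client streaming" branch above is what the generated CloseAndRecv exercises: it performs one RecvMsg and then expects io.EOF on a second receive. A hedged sketch of the calling side, with placeholder Uploader names:

// Hypothetical client-streaming caller. CloseAndRecv triggers the extra
// recv in clientStream.RecvMsg that verifies exactly one server reply.
func upload(stream pb.Uploader_UploadClient, chunks [][]byte) (*pb.Summary, error) {
	for _, c := range chunks {
		if err := stream.Send(&pb.Chunk{Data: c}); err != nil {
			return nil, err
		}
	}
	return stream.CloseAndRecv()
}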
Example 11
func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
	var (
		t   transport.ClientTransport
		s   *transport.Stream
		put func()
	)
	c := defaultCallInfo
	for _, o := range opts {
		if err := o.before(&c); err != nil {
			return nil, toRPCErr(err)
		}
	}
	callHdr := &transport.CallHdr{
		Host:   cc.authority,
		Method: method,
		Flush:  desc.ServerStreams && desc.ClientStreams,
	}
	if cc.dopts.cp != nil {
		callHdr.SendCompress = cc.dopts.cp.Type()
	}
	var trInfo traceInfo
	if EnableTracing {
		trInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
		trInfo.firstLine.client = true
		if deadline, ok := ctx.Deadline(); ok {
			trInfo.firstLine.deadline = deadline.Sub(time.Now())
		}
		trInfo.tr.LazyLog(&trInfo.firstLine, false)
		ctx = trace.NewContext(ctx, trInfo.tr)
		defer func() {
			if err != nil {
				// Need to call tr.Finish() if an error is returned,
				// because tr will not be returned to the caller.
				trInfo.tr.LazyPrintf("RPC: [%v]", err)
				trInfo.tr.SetError()
				trInfo.tr.Finish()
			}
		}()
	}
	if stats.On() {
		ctx = stats.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method})
		begin := &stats.Begin{
			Client:    true,
			BeginTime: time.Now(),
			FailFast:  c.failFast,
		}
		stats.HandleRPC(ctx, begin)
	}
	defer func() {
		if err != nil && stats.On() {
			// Only handle end stats if err != nil.
			end := &stats.End{
				Client: true,
				Error:  err,
			}
			stats.HandleRPC(ctx, end)
		}
	}()
	gopts := BalancerGetOptions{
		BlockingWait: !c.failFast,
	}
	for {
		t, put, err = cc.getTransport(ctx, gopts)
		if err != nil {
			// TODO(zhaoq): Probably revisit the error handling.
			if _, ok := err.(*rpcError); ok {
				return nil, err
			}
			if err == errConnClosing || err == errConnUnavailable {
				if c.failFast {
					return nil, Errorf(codes.Unavailable, "%v", err)
				}
				continue
			}
			// All the other errors are treated as Internal errors.
			return nil, Errorf(codes.Internal, "%v", err)
		}

		s, err = t.NewStream(ctx, callHdr)
		if err != nil {
			if put != nil {
				put()
				put = nil
			}
			if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
				if c.failFast {
					return nil, toRPCErr(err)
				}
				continue
			}
			return nil, toRPCErr(err)
		}
		break
	}
	cs := &clientStream{
		opts:  opts,
		c:     c,
		desc:  desc,
		codec: cc.dopts.codec,
		cp:    cc.dopts.cp,
		dc:    cc.dopts.dc,

		put: put,
		t:   t,
		s:   s,
		p:   &parser{r: s},

		tracing: EnableTracing,
		trInfo:  trInfo,

		statsCtx: ctx,
	}
	if cc.dopts.cp != nil {
		cs.cbuf = new(bytes.Buffer)
	}
	// Listen on ctx.Done() to detect cancellation and s.Done() to detect normal termination
	// when there are no pending I/O operations on this stream.
	go func() {
		select {
		case <-t.Error():
			// A transport error occurred; simply exit.
		case <-s.Done():
			// TODO: The trace of the RPC is terminated here when there is no pending
			// I/O, which is probably not the optimal solution.
			if s.StatusCode() == codes.OK {
				cs.finish(nil)
			} else {
				cs.finish(Errorf(s.StatusCode(), "%s", s.StatusDesc()))
			}
			cs.closeTransportStream(nil)
		case <-s.GoAway():
			cs.finish(errConnDrain)
			cs.closeTransportStream(errConnDrain)
		case <-s.Context().Done():
			err := s.Context().Err()
			cs.finish(err)
			cs.closeTransportStream(transport.ContextErr(err))
		}
	}()
	return cs, nil
}
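
newClientStream is reached through the exported grpc.NewClientStream; a hedged usage sketch (the stream name and method path are placeholders):

// Hedged usage sketch: open a bidi stream by hand, as generated code does.
func openChat(ctx context.Context, cc *grpc.ClientConn) (grpc.ClientStream, error) {
	desc := &grpc.StreamDesc{
		StreamName:    "Chat", // placeholder
		ClientStreams: true,
		ServerStreams: true,
	}
	return grpc.NewClientStream(ctx, desc, cc, "/example.Echo/Chat")
}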
Example 12
// operateHeaders takes action on the decoded headers.
func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
	s, ok := t.getStream(frame)
	if !ok {
		return
	}
	var state decodeState
	for _, hf := range frame.Fields {
		state.processHeaderField(hf)
	}
	if state.err != nil {
		s.mu.Lock()
		if !s.headerDone {
			close(s.headerChan)
			s.headerDone = true
		}
		s.mu.Unlock()
		s.write(recvMsg{err: state.err})
		// Something went wrong; stop reading even if there is data remaining.
		return
	}

	endStream := frame.StreamEnded()
	var isHeader bool
	defer func() {
		if stats.On() {
			if isHeader {
				inHeader := &stats.InHeader{
					Client:     true,
					WireLength: int(frame.Header().Length),
				}
				stats.HandleRPC(s.clientStatsCtx, inHeader)
			} else {
				inTrailer := &stats.InTrailer{
					Client:     true,
					WireLength: int(frame.Header().Length),
				}
				stats.HandleRPC(s.clientStatsCtx, inTrailer)
			}
		}
	}()

	s.mu.Lock()
	if !endStream {
		s.recvCompress = state.encoding
	}
	if !s.headerDone {
		if !endStream && len(state.mdata) > 0 {
			s.header = state.mdata
		}
		close(s.headerChan)
		s.headerDone = true
		isHeader = true
	}
	if !endStream || s.state == streamDone {
		s.mu.Unlock()
		return
	}

	if len(state.mdata) > 0 {
		s.trailer = state.mdata
	}
	s.statusCode = state.statusCode
	s.statusDesc = state.statusDesc
	close(s.done)
	s.state = streamDone
	s.mu.Unlock()
	s.write(recvMsg{err: io.EOF})
}
Example 13
// NewStream creates a stream and registers it with the transport as an
// "active" stream.
func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) {
	pr := &peer.Peer{
		Addr: t.remoteAddr,
	}
	// Attach Auth info if there is any.
	if t.authInfo != nil {
		pr.AuthInfo = t.authInfo
	}
	userCtx := ctx
	ctx = peer.NewContext(ctx, pr)
	authData := make(map[string]string)
	for _, c := range t.creds {
		// Construct URI required to get auth request metadata.
		var port string
		if pos := strings.LastIndex(t.target, ":"); pos != -1 {
			// Omit port if it is the default one.
			if t.target[pos+1:] != "443" {
				port = ":" + t.target[pos+1:]
			}
		}
		pos := strings.LastIndex(callHdr.Method, "/")
		if pos == -1 {
			return nil, streamErrorf(codes.InvalidArgument, "transport: malformed method name: %q", callHdr.Method)
		}
		audience := "https://" + callHdr.Host + port + callHdr.Method[:pos]
		data, err := c.GetRequestMetadata(ctx, audience)
		if err != nil {
			return nil, streamErrorf(codes.InvalidArgument, "transport: %v", err)
		}
		for k, v := range data {
			authData[k] = v
		}
	}
	t.mu.Lock()
	if t.activeStreams == nil {
		t.mu.Unlock()
		return nil, ErrConnClosing
	}
	if t.state == draining {
		t.mu.Unlock()
		return nil, ErrStreamDrain
	}
	if t.state != reachable {
		t.mu.Unlock()
		return nil, ErrConnClosing
	}
	checkStreamsQuota := t.streamsQuota != nil
	t.mu.Unlock()
	if checkStreamsQuota {
		sq, err := wait(ctx, nil, nil, t.shutdownChan, t.streamsQuota.acquire())
		if err != nil {
			return nil, err
		}
		// Returns the quota balance back.
		if sq > 1 {
			t.streamsQuota.add(sq - 1)
		}
	}
	if _, err := wait(ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
		// Return the quota back now because there is no stream returned to the caller.
		if _, ok := err.(StreamError); ok && checkStreamsQuota {
			t.streamsQuota.add(1)
		}
		return nil, err
	}
	t.mu.Lock()
	if t.state == draining {
		t.mu.Unlock()
		if checkStreamsQuota {
			t.streamsQuota.add(1)
		}
		// Need to make t writable again so that RPCs in flight can still proceed.
		t.writableChan <- 0
		return nil, ErrStreamDrain
	}
	if t.state != reachable {
		t.mu.Unlock()
		return nil, ErrConnClosing
	}
	s := t.newStream(ctx, callHdr)
	s.clientStatsCtx = userCtx
	t.activeStreams[s.id] = s

	// This stream is not counted when applySettings(...) initializes t.streamsQuota.
	// Reset t.streamsQuota to the right value.
	var reset bool
	if !checkStreamsQuota && t.streamsQuota != nil {
		reset = true
	}
	t.mu.Unlock()
	if reset {
		t.streamsQuota.add(-1)
	}

	// HPACK encodes various headers. Note that once WriteField(...) is
	// called, the corresponding headers/continuation frame has to be sent
	// because hpack.Encoder is stateful.
	t.hBuf.Reset()
	t.hEnc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"})
	t.hEnc.WriteField(hpack.HeaderField{Name: ":scheme", Value: t.scheme})
	t.hEnc.WriteField(hpack.HeaderField{Name: ":path", Value: callHdr.Method})
	t.hEnc.WriteField(hpack.HeaderField{Name: ":authority", Value: callHdr.Host})
	t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
	t.hEnc.WriteField(hpack.HeaderField{Name: "user-agent", Value: t.userAgent})
	t.hEnc.WriteField(hpack.HeaderField{Name: "te", Value: "trailers"})

	if callHdr.SendCompress != "" {
		t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
	}
	if dl, ok := ctx.Deadline(); ok {
		// Send out the timeout regardless of its value. The server can detect a timed-out context by itself.
		timeout := dl.Sub(time.Now())
		t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)})
	}

	for k, v := range authData {
		// Capital header names are illegal in HTTP/2.
		k = strings.ToLower(k)
		t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: v})
	}
	var (
		hasMD      bool
		endHeaders bool
	)
	if md, ok := metadata.FromContext(ctx); ok {
		hasMD = true
		for k, v := range md {
			// HTTP doesn't allow you to set pseudo-headers after non-pseudo-headers were set.
			if isReservedHeader(k) {
				continue
			}
			for _, entry := range v {
				t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry})
			}
		}
	}
	if md, ok := t.md.(*metadata.MD); ok {
		for k, v := range *md {
			if isReservedHeader(k) {
				continue
			}
			for _, entry := range v {
				t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry})
			}
		}
	}
	first := true
	bufLen := t.hBuf.Len()
	// Sends the headers in a single batch even when they span multiple frames.
	for !endHeaders {
		size := t.hBuf.Len()
		if size > http2MaxFrameLen {
			size = http2MaxFrameLen
		} else {
			endHeaders = true
		}
		var flush bool
		if endHeaders && (hasMD || callHdr.Flush) {
			flush = true
		}
		if first {
			// Sends a HeadersFrame to server to start a new stream.
			p := http2.HeadersFrameParam{
				StreamID:      s.id,
				BlockFragment: t.hBuf.Next(size),
				EndStream:     false,
				EndHeaders:    endHeaders,
			}
			// Do a force flush of the buffered frames iff it is the last headers frame
			// and there is header metadata to be sent. Otherwise, there is no flushing
			// until the corresponding data frame is written.
			err = t.framer.writeHeaders(flush, p)
			first = false
		} else {
			// Sends Continuation frames for the leftover headers.
			err = t.framer.writeContinuation(flush, s.id, endHeaders, t.hBuf.Next(size))
		}
		if err != nil {
			t.notifyError(err)
			return nil, connectionErrorf(true, err, "transport: %v", err)
		}
	}
	if stats.On() {
		outHeader := &stats.OutHeader{
			Client:      true,
			WireLength:  bufLen,
			FullMethod:  callHdr.Method,
			RemoteAddr:  t.remoteAddr,
			LocalAddr:   t.localAddr,
			Compression: callHdr.SendCompress,
		}
		stats.HandleRPC(s.clientStatsCtx, outHeader)
	}
	t.writableChan <- 0
	return s, nil
}
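
The grpc-timeout header written above must follow the gRPC wire format: an integer of at most 8 digits followed by a one-letter unit (n, u, m, S, M, H). The real encodeTimeout picks the coarsest unit that keeps the value within 8 digits; the sketch below covers the common cases only.

// Rough sketch of grpc-timeout encoding; not the exact grpc-go helper.
func encodeTimeoutSketch(d time.Duration) string {
	const maxDigits = int64(99999999) // 8 decimal digits
	if ns := d.Nanoseconds(); ns <= maxDigits {
		return strconv.FormatInt(ns, 10) + "n"
	}
	if us := d.Nanoseconds() / 1000; us <= maxDigits {
		return strconv.FormatInt(us, 10) + "u"
	}
	if ms := d.Nanoseconds() / 1000000; ms <= maxDigits {
		return strconv.FormatInt(ms, 10) + "m"
	}
	return strconv.FormatInt(int64(d.Seconds()), 10) + "S"
}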
Example 14
func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc, trInfo *traceInfo) (err error) {
	if stats.On() {
		begin := &stats.Begin{
			BeginTime: time.Now(),
		}
		stats.HandleRPC(stream.Context(), begin)
	}
	defer func() {
		if stats.On() {
			end := &stats.End{
				EndTime: time.Now(),
			}
			if err != nil && err != io.EOF {
				end.Error = toRPCErr(err)
			}
			stats.HandleRPC(stream.Context(), end)
		}
	}()
	if s.opts.cp != nil {
		stream.SetSendCompress(s.opts.cp.Type())
	}
	ss := &serverStream{
		t:          t,
		s:          stream,
		p:          &parser{r: stream},
		codec:      s.opts.codec,
		cp:         s.opts.cp,
		dc:         s.opts.dc,
		maxMsgSize: s.opts.maxMsgSize,
		trInfo:     trInfo,
	}
	if ss.cp != nil {
		ss.cbuf = new(bytes.Buffer)
	}
	if trInfo != nil {
		trInfo.tr.LazyLog(&trInfo.firstLine, false)
		defer func() {
			ss.mu.Lock()
			if err != nil && err != io.EOF {
				ss.trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
				ss.trInfo.tr.SetError()
			}
			ss.trInfo.tr.Finish()
			ss.trInfo.tr = nil
			ss.mu.Unlock()
		}()
	}
	var appErr error
	if s.opts.streamInt == nil {
		appErr = sd.Handler(srv.server, ss)
	} else {
		info := &StreamServerInfo{
			FullMethod:     stream.Method(),
			IsClientStream: sd.ClientStreams,
			IsServerStream: sd.ServerStreams,
		}
		appErr = s.opts.streamInt(srv.server, ss, info, sd.Handler)
	}
	if appErr != nil {
		if err, ok := appErr.(*rpcError); ok {
			ss.statusCode = err.code
			ss.statusDesc = err.desc
		} else if err, ok := appErr.(transport.StreamError); ok {
			ss.statusCode = err.Code
			ss.statusDesc = err.Desc
		} else {
			ss.statusCode = convertCode(appErr)
			ss.statusDesc = appErr.Error()
		}
	}
	if trInfo != nil {
		ss.mu.Lock()
		if ss.statusCode != codes.OK {
			ss.trInfo.tr.LazyLog(stringer(ss.statusDesc), true)
			ss.trInfo.tr.SetError()
		} else {
			ss.trInfo.tr.LazyLog(stringer("OK"), false)
		}
		ss.mu.Unlock()
	}
	errWrite := t.WriteStatus(ss.s, ss.statusCode, ss.statusDesc)
	if ss.statusCode != codes.OK {
		return Errorf(ss.statusCode, ss.statusDesc)
	}
	return errWrite
}
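
The streamInt hook consumed above has the StreamServerInterceptor shape. A hedged logging sketch (the logged fields are illustrative):

// Hedged sketch of a stream server interceptor, installed via the
// grpc.StreamInterceptor server option in this era.
func loggingStreamInterceptor(srv interface{}, ss grpc.ServerStream,
	info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
	start := time.Now()
	err := handler(srv, ss)
	log.Printf("stream %s took %v, err=%v", info.FullMethod, time.Since(start), err)
	return err
}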
Example 15
func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, md *MethodDesc, trInfo *traceInfo) (err error) {
	if stats.On() {
		begin := &stats.Begin{
			BeginTime: time.Now(),
		}
		stats.HandleRPC(stream.Context(), begin)
	}
	defer func() {
		if stats.On() {
			end := &stats.End{
				EndTime: time.Now(),
			}
			if err != nil && err != io.EOF {
				end.Error = toRPCErr(err)
			}
			stats.HandleRPC(stream.Context(), end)
		}
	}()
	if trInfo != nil {
		defer trInfo.tr.Finish()
		trInfo.firstLine.client = false
		trInfo.tr.LazyLog(&trInfo.firstLine, false)
		defer func() {
			if err != nil && err != io.EOF {
				trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
				trInfo.tr.SetError()
			}
		}()
	}
	if s.opts.cp != nil {
		// NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686.
		stream.SetSendCompress(s.opts.cp.Type())
	}
	p := &parser{r: stream}
	for {
		pf, req, err := p.recvMsg(s.opts.maxMsgSize)
		if err == io.EOF {
			// The entire stream is done (for unary RPC only).
			return err
		}
		if err == io.ErrUnexpectedEOF {
			err = Errorf(codes.Internal, io.ErrUnexpectedEOF.Error())
		}
		if err != nil {
			switch err := err.(type) {
			case *rpcError:
				if e := t.WriteStatus(stream, err.code, err.desc); e != nil {
					grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
				}
			case transport.ConnectionError:
				// Nothing to do here.
			case transport.StreamError:
				if e := t.WriteStatus(stream, err.Code, err.Desc); e != nil {
					grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
				}
			default:
				panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", err, err))
			}
			return err
		}

		if err := checkRecvPayload(pf, stream.RecvCompress(), s.opts.dc); err != nil {
			switch err := err.(type) {
			case *rpcError:
				if e := t.WriteStatus(stream, err.code, err.desc); e != nil {
					grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
				}
				return err
			default:
				if e := t.WriteStatus(stream, codes.Internal, err.Error()); e != nil {
					grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
				}
				// TODO: checkRecvPayload always returns an RPC error. Add a return here if necessary.
			}
		}
		var inPayload *stats.InPayload
		if stats.On() {
			inPayload = &stats.InPayload{
				RecvTime: time.Now(),
			}
		}
		statusCode := codes.OK
		statusDesc := ""
		df := func(v interface{}) error {
			if inPayload != nil {
				inPayload.WireLength = len(req)
			}
			if pf == compressionMade {
				var err error
				req, err = s.opts.dc.Do(bytes.NewReader(req))
				if err != nil {
					if err := t.WriteStatus(stream, codes.Internal, err.Error()); err != nil {
						grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
					}
					return Errorf(codes.Internal, err.Error())
				}
			}
			if len(req) > s.opts.maxMsgSize {
				// TODO: Revisit the error code. Currently keep it consistent with the
				// Java implementation.
				statusCode = codes.Internal
				statusDesc = fmt.Sprintf("grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxMsgSize)
			}
			if err := s.opts.codec.Unmarshal(req, v); err != nil {
				return err
			}
			if inPayload != nil {
				inPayload.Payload = v
				inPayload.Data = req
				inPayload.Length = len(req)
				stats.HandleRPC(stream.Context(), inPayload)
			}
			if trInfo != nil {
				trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true)
			}
			return nil
		}
		reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt)
		if appErr != nil {
			if err, ok := appErr.(*rpcError); ok {
				statusCode = err.code
				statusDesc = err.desc
			} else {
				statusCode = convertCode(appErr)
				statusDesc = appErr.Error()
			}
			if trInfo != nil && statusCode != codes.OK {
				trInfo.tr.LazyLog(stringer(statusDesc), true)
				trInfo.tr.SetError()
			}
			if err := t.WriteStatus(stream, statusCode, statusDesc); err != nil {
				grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", err)
			}
			return Errorf(statusCode, statusDesc)
		}
		if trInfo != nil {
			trInfo.tr.LazyLog(stringer("OK"), false)
		}
		opts := &transport.Options{
			Last:  true,
			Delay: false,
		}
		if err := s.sendResponse(t, stream, reply, s.opts.cp, opts); err != nil {
			switch err := err.(type) {
			case transport.ConnectionError:
				// Nothing to do here.
			case transport.StreamError:
				statusCode = err.Code
				statusDesc = err.Desc
			default:
				statusCode = codes.Unknown
				statusDesc = err.Error()
			}
			return err
		}
		if trInfo != nil {
			trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true)
		}
		errWrite := t.WriteStatus(stream, statusCode, statusDesc)
		if statusCode != codes.OK {
			return Errorf(statusCode, statusDesc)
		}
		return errWrite
	}
}
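
Similarly, s.opts.unaryInt in processUnaryRPC has the UnaryServerInterceptor shape; a hedged logging sketch:

// Hedged sketch of a unary server interceptor, installed via the
// grpc.UnaryInterceptor server option in this era.
func loggingUnaryInterceptor(ctx context.Context, req interface{},
	info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
	start := time.Now()
	resp, err := handler(ctx, req)
	log.Printf("unary %s took %v, err=%v", info.FullMethod, time.Since(start), err)
	return resp, err
}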
Example 16
func invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (e error) {
	c := defaultCallInfo
	for _, o := range opts {
		if err := o.before(&c); err != nil {
			return toRPCErr(err)
		}
	}
	defer func() {
		for _, o := range opts {
			o.after(&c)
		}
	}()
	if EnableTracing {
		c.traceInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
		defer c.traceInfo.tr.Finish()
		c.traceInfo.firstLine.client = true
		if deadline, ok := ctx.Deadline(); ok {
			c.traceInfo.firstLine.deadline = deadline.Sub(time.Now())
		}
		c.traceInfo.tr.LazyLog(&c.traceInfo.firstLine, false)
		// TODO(dsymonds): Arrange for c.traceInfo.firstLine.remoteAddr to be set.
		defer func() {
			if e != nil {
				c.traceInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{e}}, true)
				c.traceInfo.tr.SetError()
			}
		}()
	}
	if stats.On() {
		ctx = stats.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method})
		begin := &stats.Begin{
			Client:    true,
			BeginTime: time.Now(),
			FailFast:  c.failFast,
		}
		stats.HandleRPC(ctx, begin)
	}
	defer func() {
		if stats.On() {
			end := &stats.End{
				Client:  true,
				EndTime: time.Now(),
				Error:   e,
			}
			stats.HandleRPC(ctx, end)
		}
	}()
	topts := &transport.Options{
		Last:  true,
		Delay: false,
	}
	for {
		var (
			err    error
			t      transport.ClientTransport
			stream *transport.Stream
			// Record the put handler from Balancer.Get(...). It is called once the
			// RPC has completed or failed.
			put func()
		)
		// TODO(zhaoq): Need a formal spec of fail-fast.
		callHdr := &transport.CallHdr{
			Host:   cc.authority,
			Method: method,
		}
		if cc.dopts.cp != nil {
			callHdr.SendCompress = cc.dopts.cp.Type()
		}
		gopts := BalancerGetOptions{
			BlockingWait: !c.failFast,
		}
		t, put, err = cc.getTransport(ctx, gopts)
		if err != nil {
			// TODO(zhaoq): Probably revisit the error handling.
			if _, ok := err.(*rpcError); ok {
				return err
			}
			if err == errConnClosing || err == errConnUnavailable {
				if c.failFast {
					return Errorf(codes.Unavailable, "%v", err)
				}
				continue
			}
			// All the other errors are treated as Internal errors.
			return Errorf(codes.Internal, "%v", err)
		}
		if c.traceInfo.tr != nil {
			c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true)
		}
		stream, err = sendRequest(ctx, cc.dopts.codec, cc.dopts.cp, callHdr, t, args, topts)
		if err != nil {
			if put != nil {
				put()
				put = nil
			}
			// Retry a non-failfast RPC when
			// i) there is a connection error; or
			// ii) the server started to drain before this RPC was initiated.
			if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
				if c.failFast {
					return toRPCErr(err)
				}
				continue
			}
			return toRPCErr(err)
		}
		err = recvResponse(ctx, cc.dopts, t, &c, stream, reply)
		if err != nil {
			if put != nil {
				put()
				put = nil
			}
			if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
				if c.failFast {
					return toRPCErr(err)
				}
				continue
			}
			return toRPCErr(err)
		}
		if c.traceInfo.tr != nil {
			c.traceInfo.tr.LazyLog(&payload{sent: false, msg: reply}, true)
		}
		t.CloseStream(stream, nil)
		if put != nil {
			put()
			put = nil
		}
		return Errorf(stream.StatusCode(), "%s", stream.StatusDesc())
	}
}
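
invoke is wrapped by the exported grpc.Invoke; a hedged end-to-end usage sketch, with placeholder address, method path, and message types:

// Hedged usage sketch: call a unary method without generated stubs.
func callSayHello(ctx context.Context) error {
	conn, err := grpc.Dial("localhost:50051", grpc.WithInsecure())
	if err != nil {
		return err
	}
	defer conn.Close()
	req, reply := &pb.HelloRequest{Name: "world"}, new(pb.HelloReply)
	return grpc.Invoke(ctx, "/example.Greeter/SayHello", req, reply, conn)
}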