func TestTagRPCCtx(t *testing.T) { defer stats.RegisterRPCTagger(nil) ctx1 := context.Background() stats.RegisterRPCTagger(nil) ctx2 := stats.TagRPC(ctx1, nil) if ctx2 != ctx1 { t.Fatalf("nil rpc ctx tagger should not modify context, got %v; want %v", ctx2, ctx1) } stats.RegisterRPCTagger(func(ctx context.Context, info *stats.RPCTagInfo) context.Context { return context.WithValue(ctx, rpcCtxKey{}, "rpcctxvalue") }) ctx3 := stats.TagRPC(ctx1, nil) if v, ok := ctx3.Value(rpcCtxKey{}).(string); !ok || v != "rpcctxvalue" { t.Fatalf("got context %v; want %v", ctx3, context.WithValue(ctx1, rpcCtxKey{}, "rpcctxvalue")) } }
// operateHeader takes action on the decoded headers. func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream), traceCtx func(context.Context, string) context.Context) (close bool) { buf := newRecvBuffer() s := &Stream{ id: frame.Header().StreamID, st: t, buf: buf, fc: &inFlow{limit: initialWindowSize}, } var state decodeState for _, hf := range frame.Fields { state.processHeaderField(hf) } if err := state.err; err != nil { if se, ok := err.(StreamError); ok { t.controlBuf.put(&resetStream{s.id, statusCodeConvTab[se.Code]}) } return } if frame.StreamEnded() { // s is just created by the caller. No lock needed. s.state = streamReadDone } s.recvCompress = state.encoding if state.timeoutSet { s.ctx, s.cancel = context.WithTimeout(t.ctx, state.timeout) } else { s.ctx, s.cancel = context.WithCancel(t.ctx) } pr := &peer.Peer{ Addr: t.remoteAddr, } // Attach Auth info if there is any. if t.authInfo != nil { pr.AuthInfo = t.authInfo } s.ctx = peer.NewContext(s.ctx, pr) // Cache the current stream to the context so that the server application // can find out. Required when the server wants to send some metadata // back to the client (unary call only). s.ctx = newContextWithStream(s.ctx, s) // Attach the received metadata to the context. if len(state.mdata) > 0 { s.ctx = metadata.NewContext(s.ctx, state.mdata) } s.dec = &recvBufferReader{ ctx: s.ctx, recv: s.buf, } s.recvCompress = state.encoding s.method = state.method if t.inTapHandle != nil { var err error info := &tap.Info{ FullMethodName: state.method, } s.ctx, err = t.inTapHandle(s.ctx, info) if err != nil { // TODO: Log the real error. t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream}) return } } t.mu.Lock() if t.state != reachable { t.mu.Unlock() return } if uint32(len(t.activeStreams)) >= t.maxStreams { t.mu.Unlock() t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream}) return } if s.id%2 != 1 || s.id <= t.maxStreamID { t.mu.Unlock() // illegal gRPC stream id. grpclog.Println("transport: http2Server.HandleStreams received an illegal stream id: ", s.id) return true } t.maxStreamID = s.id s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota)) t.activeStreams[s.id] = s t.mu.Unlock() s.windowHandler = func(n int) { t.updateWindow(s, uint32(n)) } s.ctx = traceCtx(s.ctx, s.method) if stats.On() { s.ctx = stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method}) inHeader := &stats.InHeader{ FullMethod: s.method, RemoteAddr: t.remoteAddr, LocalAddr: t.localAddr, Compression: s.recvCompress, WireLength: int(frame.Header().Length), } stats.HandleRPC(s.ctx, inHeader) } handle(s) return }
func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) { var ( t transport.ClientTransport s *transport.Stream put func() ) c := defaultCallInfo for _, o := range opts { if err := o.before(&c); err != nil { return nil, toRPCErr(err) } } callHdr := &transport.CallHdr{ Host: cc.authority, Method: method, Flush: desc.ServerStreams && desc.ClientStreams, } if cc.dopts.cp != nil { callHdr.SendCompress = cc.dopts.cp.Type() } var trInfo traceInfo if EnableTracing { trInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method) trInfo.firstLine.client = true if deadline, ok := ctx.Deadline(); ok { trInfo.firstLine.deadline = deadline.Sub(time.Now()) } trInfo.tr.LazyLog(&trInfo.firstLine, false) ctx = trace.NewContext(ctx, trInfo.tr) defer func() { if err != nil { // Need to call tr.finish() if error is returned. // Because tr will not be returned to caller. trInfo.tr.LazyPrintf("RPC: [%v]", err) trInfo.tr.SetError() trInfo.tr.Finish() } }() } if stats.On() { ctx = stats.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method}) begin := &stats.Begin{ Client: true, BeginTime: time.Now(), FailFast: c.failFast, } stats.HandleRPC(ctx, begin) } defer func() { if err != nil && stats.On() { // Only handle end stats if err != nil. end := &stats.End{ Client: true, Error: err, } stats.HandleRPC(ctx, end) } }() gopts := BalancerGetOptions{ BlockingWait: !c.failFast, } for { t, put, err = cc.getTransport(ctx, gopts) if err != nil { // TODO(zhaoq): Probably revisit the error handling. if _, ok := err.(*rpcError); ok { return nil, err } if err == errConnClosing || err == errConnUnavailable { if c.failFast { return nil, Errorf(codes.Unavailable, "%v", err) } continue } // All the other errors are treated as Internal errors. return nil, Errorf(codes.Internal, "%v", err) } s, err = t.NewStream(ctx, callHdr) if err != nil { if put != nil { put() put = nil } if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { if c.failFast { return nil, toRPCErr(err) } continue } return nil, toRPCErr(err) } break } cs := &clientStream{ opts: opts, c: c, desc: desc, codec: cc.dopts.codec, cp: cc.dopts.cp, dc: cc.dopts.dc, put: put, t: t, s: s, p: &parser{r: s}, tracing: EnableTracing, trInfo: trInfo, statsCtx: ctx, } if cc.dopts.cp != nil { cs.cbuf = new(bytes.Buffer) } // Listen on ctx.Done() to detect cancellation and s.Done() to detect normal termination // when there is no pending I/O operations on this stream. go func() { select { case <-t.Error(): // Incur transport error, simply exit. case <-s.Done(): // TODO: The trace of the RPC is terminated here when there is no pending // I/O, which is probably not the optimal solution. if s.StatusCode() == codes.OK { cs.finish(nil) } else { cs.finish(Errorf(s.StatusCode(), "%s", s.StatusDesc())) } cs.closeTransportStream(nil) case <-s.GoAway(): cs.finish(errConnDrain) cs.closeTransportStream(errConnDrain) case <-s.Context().Done(): err := s.Context().Err() cs.finish(err) cs.closeTransportStream(transport.ContextErr(err)) } }() return cs, nil }
func invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (e error) { c := defaultCallInfo for _, o := range opts { if err := o.before(&c); err != nil { return toRPCErr(err) } } defer func() { for _, o := range opts { o.after(&c) } }() if EnableTracing { c.traceInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method) defer c.traceInfo.tr.Finish() c.traceInfo.firstLine.client = true if deadline, ok := ctx.Deadline(); ok { c.traceInfo.firstLine.deadline = deadline.Sub(time.Now()) } c.traceInfo.tr.LazyLog(&c.traceInfo.firstLine, false) // TODO(dsymonds): Arrange for c.traceInfo.firstLine.remoteAddr to be set. defer func() { if e != nil { c.traceInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{e}}, true) c.traceInfo.tr.SetError() } }() } if stats.On() { ctx = stats.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method}) begin := &stats.Begin{ Client: true, BeginTime: time.Now(), FailFast: c.failFast, } stats.HandleRPC(ctx, begin) } defer func() { if stats.On() { end := &stats.End{ Client: true, EndTime: time.Now(), Error: e, } stats.HandleRPC(ctx, end) } }() topts := &transport.Options{ Last: true, Delay: false, } for { var ( err error t transport.ClientTransport stream *transport.Stream // Record the put handler from Balancer.Get(...). It is called once the // RPC has completed or failed. put func() ) // TODO(zhaoq): Need a formal spec of fail-fast. callHdr := &transport.CallHdr{ Host: cc.authority, Method: method, } if cc.dopts.cp != nil { callHdr.SendCompress = cc.dopts.cp.Type() } gopts := BalancerGetOptions{ BlockingWait: !c.failFast, } t, put, err = cc.getTransport(ctx, gopts) if err != nil { // TODO(zhaoq): Probably revisit the error handling. if _, ok := err.(*rpcError); ok { return err } if err == errConnClosing || err == errConnUnavailable { if c.failFast { return Errorf(codes.Unavailable, "%v", err) } continue } // All the other errors are treated as Internal errors. return Errorf(codes.Internal, "%v", err) } if c.traceInfo.tr != nil { c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true) } stream, err = sendRequest(ctx, cc.dopts.codec, cc.dopts.cp, callHdr, t, args, topts) if err != nil { if put != nil { put() put = nil } // Retry a non-failfast RPC when // i) there is a connection error; or // ii) the server started to drain before this RPC was initiated. if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { if c.failFast { return toRPCErr(err) } continue } return toRPCErr(err) } err = recvResponse(ctx, cc.dopts, t, &c, stream, reply) if err != nil { if put != nil { put() put = nil } if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { if c.failFast { return toRPCErr(err) } continue } return toRPCErr(err) } if c.traceInfo.tr != nil { c.traceInfo.tr.LazyLog(&payload{sent: false, msg: reply}, true) } t.CloseStream(stream, nil) if put != nil { put() put = nil } return Errorf(stream.StatusCode(), "%s", stream.StatusDesc()) } }