func (ss *serverStream) RecvMsg(m interface{}) (err error) { defer func() { if ss.trInfo != nil { ss.mu.Lock() if ss.trInfo.tr != nil { if err == nil { ss.trInfo.tr.LazyLog(&payload{sent: false, msg: m}, true) } else if err != io.EOF { ss.trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) ss.trInfo.tr.SetError() } } ss.mu.Unlock() } }() var inPayload *stats.InPayload if stats.On() { inPayload = &stats.InPayload{} } if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize, inPayload); err != nil { if err == io.EOF { return err } if err == io.ErrUnexpectedEOF { err = Errorf(codes.Internal, io.ErrUnexpectedEOF.Error()) } return toRPCErr(err) } if inPayload != nil { stats.HandleRPC(ss.s.Context(), inPayload) } return nil }
// Close kicks off the shutdown process of the transport. This should be called // only once on a transport. Once it is called, the transport should not be // accessed any more. func (t *http2Client) Close() (err error) { t.mu.Lock() if t.state == closing { t.mu.Unlock() return } if t.state == reachable || t.state == draining { close(t.errorChan) } t.state = closing t.mu.Unlock() close(t.shutdownChan) err = t.conn.Close() t.mu.Lock() streams := t.activeStreams t.activeStreams = nil t.mu.Unlock() // Notify all active streams. for _, s := range streams { s.mu.Lock() if !s.headerDone { close(s.headerChan) s.headerDone = true } s.mu.Unlock() s.write(recvMsg{err: ErrConnClosing}) } if stats.On() { connEnd := &stats.ConnEnd{ Client: true, } stats.HandleConn(t.ctx, connEnd) } return }
func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options) error { var ( cbuf *bytes.Buffer outPayload *stats.OutPayload ) if cp != nil { cbuf = new(bytes.Buffer) } if stats.On() { outPayload = &stats.OutPayload{} } p, err := encode(s.opts.codec, msg, cp, cbuf, outPayload) if err != nil { // This typically indicates a fatal issue (e.g., memory // corruption or hardware faults) the application program // cannot handle. // // TODO(zhaoq): There exist other options also such as only closing the // faulty stream locally and remotely (Other streams can keep going). Find // the optimal option. grpclog.Fatalf("grpc: Server failed to encode response %v", err) } err = t.Write(stream, p, opts) if err == nil && outPayload != nil { outPayload.SentTime = time.Now() stats.HandleRPC(stream.Context(), outPayload) } return err }
// WriteStatus sends stream status to the client and terminates the stream. // There is no further I/O operations being able to perform on this stream. // TODO(zhaoq): Now it indicates the end of entire stream. Revisit if early // OK is adopted. func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error { var headersSent, hasHeader bool s.mu.Lock() if s.state == streamDone { s.mu.Unlock() return nil } if s.headerOk { headersSent = true } if s.header.Len() > 0 { hasHeader = true } s.mu.Unlock() if !headersSent && hasHeader { t.WriteHeader(s, nil) headersSent = true } if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil { return err } t.hBuf.Reset() if !headersSent { t.hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}) } t.hEnc.WriteField( hpack.HeaderField{ Name: "grpc-status", Value: strconv.Itoa(int(statusCode)), }) t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(statusDesc)}) // Attach the trailer metadata. for k, v := range s.trailer { // Clients don't tolerate reading restricted headers after some non restricted ones were sent. if isReservedHeader(k) { continue } for _, entry := range v { t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry}) } } bufLen := t.hBuf.Len() if err := t.writeHeaders(s, t.hBuf, true); err != nil { t.Close() return err } if stats.On() { outTrailer := &stats.OutTrailer{ WireLength: bufLen, } stats.HandleRPC(s.Context(), outTrailer) } t.closeStream(s) t.writableChan <- 0 return nil }
func (cs *clientStream) SendMsg(m interface{}) (err error) { if cs.tracing { cs.mu.Lock() if cs.trInfo.tr != nil { cs.trInfo.tr.LazyLog(&payload{sent: true, msg: m}, true) } cs.mu.Unlock() } // TODO Investigate how to signal the stats handling party. // generate error stats if err != nil && err != io.EOF? defer func() { if err != nil { cs.finish(err) } if err == nil { return } if err == io.EOF { // Specialize the process for server streaming. SendMesg is only called // once when creating the stream object. io.EOF needs to be skipped when // the rpc is early finished (before the stream object is created.). // TODO: It is probably better to move this into the generated code. if !cs.desc.ClientStreams && cs.desc.ServerStreams { err = nil } return } if _, ok := err.(transport.ConnectionError); !ok { cs.closeTransportStream(err) } err = toRPCErr(err) }() var outPayload *stats.OutPayload if stats.On() { outPayload = &stats.OutPayload{ Client: true, } } out, err := encode(cs.codec, m, cs.cp, cs.cbuf, outPayload) defer func() { if cs.cbuf != nil { cs.cbuf.Reset() } }() if err != nil { return Errorf(codes.Internal, "grpc: %v", err) } err = cs.t.Write(cs.s, out, &transport.Options{Last: false}) if err == nil && outPayload != nil { outPayload.SentTime = time.Now() stats.HandleRPC(cs.statsCtx, outPayload) } return err }
// WriteHeader sends the header metedata md back to the client. func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { s.mu.Lock() if s.headerOk || s.state == streamDone { s.mu.Unlock() return ErrIllegalHeaderWrite } s.headerOk = true if md.Len() > 0 { if s.header.Len() > 0 { s.header = metadata.Join(s.header, md) } else { s.header = md } } md = s.header s.mu.Unlock() if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil { return err } t.hBuf.Reset() t.hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}) if s.sendCompress != "" { t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress}) } for k, v := range md { if isReservedHeader(k) { // Clients don't tolerate reading restricted headers after some non restricted ones were sent. continue } for _, entry := range v { t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry}) } } bufLen := t.hBuf.Len() if err := t.writeHeaders(s, t.hBuf, false); err != nil { return err } if stats.On() { outHeader := &stats.OutHeader{ WireLength: bufLen, } stats.HandleRPC(s.Context(), outHeader) } t.writableChan <- 0 return nil }
// sendRequest writes out various information of an RPC such as Context and Message. func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) { stream, err := t.NewStream(ctx, callHdr) if err != nil { return nil, err } defer func() { if err != nil { // If err is connection error, t will be closed, no need to close stream here. if _, ok := err.(transport.ConnectionError); !ok { t.CloseStream(stream, err) } } }() var ( cbuf *bytes.Buffer outPayload *stats.OutPayload ) if compressor != nil { cbuf = new(bytes.Buffer) } if stats.On() { outPayload = &stats.OutPayload{ Client: true, } } outBuf, err := encode(codec, args, compressor, cbuf, outPayload) if err != nil { return nil, Errorf(codes.Internal, "grpc: %v", err) } err = t.Write(stream, outBuf, opts) if err == nil && outPayload != nil { outPayload.SentTime = time.Now() stats.Handle(ctx, outPayload) } // t.NewStream(...) could lead to an early rejection of the RPC (e.g., the service/method // does not exist.) so that t.Write could get io.EOF from wait(...). Leave the following // recvResponse to get the final status. if err != nil && err != io.EOF { return nil, err } // Sent successfully. return stream, nil }
// Close starts shutting down the http2Server transport. // TODO(zhaoq): Now the destruction is not blocked on any pending streams. This // could cause some resource issue. Revisit this later. func (t *http2Server) Close() (err error) { t.mu.Lock() if t.state == closing { t.mu.Unlock() return errors.New("transport: Close() was already called") } t.state = closing streams := t.activeStreams t.activeStreams = nil t.mu.Unlock() close(t.shutdownChan) err = t.conn.Close() // Cancel all active streams. for _, s := range streams { s.cancel() } if stats.On() { connEnd := &stats.ConnEnd{} stats.HandleConn(t.ctx, connEnd) } return }
func (ss *serverStream) SendMsg(m interface{}) (err error) { defer func() { if ss.trInfo != nil { ss.mu.Lock() if ss.trInfo.tr != nil { if err == nil { ss.trInfo.tr.LazyLog(&payload{sent: true, msg: m}, true) } else { ss.trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) ss.trInfo.tr.SetError() } } ss.mu.Unlock() } }() var outPayload *stats.OutPayload if stats.On() { outPayload = &stats.OutPayload{} } out, err := encode(ss.codec, m, ss.cp, ss.cbuf, outPayload) defer func() { if ss.cbuf != nil { ss.cbuf.Reset() } }() if err != nil { err = Errorf(codes.Internal, "grpc: %v", err) return err } if err := ss.t.Write(ss.s, out, &transport.Options{Last: false}); err != nil { return toRPCErr(err) } if outPayload != nil { outPayload.SentTime = time.Now() stats.HandleRPC(ss.s.Context(), outPayload) } return nil }
// recvResponse receives and parses an RPC response. // On error, it returns the error and indicates whether the call should be retried. // // TODO(zhaoq): Check whether the received message sequence is valid. // TODO ctx is used for stats collection and processing. It is the context passed from the application. func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) (err error) { // Try to acquire header metadata from the server if there is any. defer func() { if err != nil { if _, ok := err.(transport.ConnectionError); !ok { t.CloseStream(stream, err) } } }() c.headerMD, err = stream.Header() if err != nil { return } p := &parser{r: stream} var inPayload *stats.InPayload if stats.On() { inPayload = &stats.InPayload{ Client: true, } } for { if err = recv(p, dopts.codec, stream, dopts.dc, reply, math.MaxInt32, inPayload); err != nil { if err == io.EOF { break } return } } if inPayload != nil && err == io.EOF && stream.StatusCode() == codes.OK { // TODO in the current implementation, inTrailer may be handled before inPayload in some cases. // Fix the order if necessary. stats.Handle(ctx, inPayload) } c.trailerMD = stream.Trailer() return nil }
func TestStartStop(t *testing.T) { stats.RegisterRPCHandler(nil) stats.RegisterConnHandler(nil) stats.Start() if stats.On() { t.Fatalf("stats.Start() with nil handler, stats.On() = true, want false") } stats.RegisterRPCHandler(func(ctx context.Context, s stats.RPCStats) {}) stats.RegisterConnHandler(nil) stats.Start() if !stats.On() { t.Fatalf("stats.Start() with non-nil handler, stats.On() = false, want true") } stats.Stop() stats.RegisterRPCHandler(nil) stats.RegisterConnHandler(func(ctx context.Context, s stats.ConnStats) {}) stats.Start() if !stats.On() { t.Fatalf("stats.Start() with non-nil conn handler, stats.On() = false, want true") } stats.Stop() stats.RegisterRPCHandler(func(ctx context.Context, s stats.RPCStats) {}) stats.RegisterConnHandler(func(ctx context.Context, s stats.ConnStats) {}) if stats.On() { t.Fatalf("after stats.RegisterRPCHandler(), stats.On() = true, want false") } stats.Start() if !stats.On() { t.Fatalf("after stats.Start(_), stats.On() = false, want true") } stats.Stop() if stats.On() { t.Fatalf("after stats.Stop(), stats.On() = true, want false") } }
// newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is // returned if something goes wrong. func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err error) { framer := newFramer(conn) // Send initial settings as connection preface to client. var settings []http2.Setting // TODO(zhaoq): Have a better way to signal "no limit" because 0 is // permitted in the HTTP2 spec. maxStreams := config.MaxStreams if maxStreams == 0 { maxStreams = math.MaxUint32 } else { settings = append(settings, http2.Setting{ ID: http2.SettingMaxConcurrentStreams, Val: maxStreams, }) } if initialWindowSize != defaultWindowSize { settings = append(settings, http2.Setting{ ID: http2.SettingInitialWindowSize, Val: uint32(initialWindowSize)}) } if err := framer.writeSettings(true, settings...); err != nil { return nil, connectionErrorf(true, err, "transport: %v", err) } // Adjust the connection flow control window if needed. if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 { if err := framer.writeWindowUpdate(true, 0, delta); err != nil { return nil, connectionErrorf(true, err, "transport: %v", err) } } var buf bytes.Buffer t := &http2Server{ ctx: context.Background(), conn: conn, remoteAddr: conn.RemoteAddr(), localAddr: conn.LocalAddr(), authInfo: config.AuthInfo, framer: framer, hBuf: &buf, hEnc: hpack.NewEncoder(&buf), maxStreams: maxStreams, inTapHandle: config.InTapHandle, controlBuf: newRecvBuffer(), fc: &inFlow{limit: initialConnWindowSize}, sendQuotaPool: newQuotaPool(defaultWindowSize), state: reachable, writableChan: make(chan int, 1), shutdownChan: make(chan struct{}), activeStreams: make(map[uint32]*Stream), streamSendQuota: defaultWindowSize, } if stats.On() { t.ctx = stats.TagConn(t.ctx, &stats.ConnTagInfo{ RemoteAddr: t.remoteAddr, LocalAddr: t.localAddr, }) connBegin := &stats.ConnBegin{} stats.HandleConn(t.ctx, connBegin) } go t.controller() t.writableChan <- 0 return t, nil }
// invoke performs a single RPC to method on cc: it sends args and reads the
// response into reply.
//
// Flow, as implemented below:
//   - every CallOption's before hook runs first; its after hook runs when
//     invoke returns;
//   - when EnableTracing is set, a trace is created and finished on return;
//   - when stats are on, a Begin event is reported up front and an End event
//     (carrying the returned error) on return;
//   - the loop obtains a transport, sends the request and receives the
//     response. For a non-fail-fast call it retries when no transport is
//     available, on a transport.ConnectionError, or on
//     transport.ErrStreamDrain; a fail-fast call returns the error instead.
//
// The returned error is built from the stream's final status code and
// description via Errorf.
func invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (e error) {
	c := defaultCallInfo
	for _, o := range opts {
		if err := o.before(&c); err != nil {
			return toRPCErr(err)
		}
	}
	defer func() {
		for _, o := range opts {
			o.after(&c)
		}
	}()
	if EnableTracing {
		c.traceInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
		defer c.traceInfo.tr.Finish()
		c.traceInfo.firstLine.client = true
		if deadline, ok := ctx.Deadline(); ok {
			c.traceInfo.firstLine.deadline = deadline.Sub(time.Now())
		}
		c.traceInfo.tr.LazyLog(&c.traceInfo.firstLine, false)
		// TODO(dsymonds): Arrange for c.traceInfo.firstLine.remoteAddr to be set.
		// Registered after Finish, so this runs before it and can still log
		// the final error into the trace.
		defer func() {
			if e != nil {
				c.traceInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{e}}, true)
				c.traceInfo.tr.SetError()
			}
		}()
	}
	if stats.On() {
		begin := &stats.Begin{
			Client:    true,
			BeginTime: time.Now(),
			FailFast:  c.failFast,
		}
		stats.Handle(ctx, begin)
	}
	defer func() {
		if stats.On() {
			end := &stats.End{
				Client:  true,
				EndTime: time.Now(),
				Error:   e,
			}
			stats.Handle(ctx, end)
		}
	}()
	topts := &transport.Options{
		Last:  true,
		Delay: false,
	}
	// Each iteration is one attempt; `continue` retries with a fresh transport.
	for {
		var (
			err    error
			t      transport.ClientTransport
			stream *transport.Stream
			// Record the put handler from Balancer.Get(...). It is called once the
			// RPC has completed or failed.
			put func()
		)
		// TODO(zhaoq): Need a formal spec of fail-fast.
		callHdr := &transport.CallHdr{
			Host:   cc.authority,
			Method: method,
		}
		if cc.dopts.cp != nil {
			callHdr.SendCompress = cc.dopts.cp.Type()
		}
		gopts := BalancerGetOptions{
			BlockingWait: !c.failFast,
		}
		t, put, err = cc.getTransport(ctx, gopts)
		if err != nil {
			// TODO(zhaoq): Probably revisit the error handling.
			// An *rpcError is returned to the caller unchanged.
			if _, ok := err.(*rpcError); ok {
				return err
			}
			if err == errConnClosing || err == errConnUnavailable {
				if c.failFast {
					return Errorf(codes.Unavailable, "%v", err)
				}
				continue
			}
			// All the other errors are treated as Internal errors.
			return Errorf(codes.Internal, "%v", err)
		}
		if c.traceInfo.tr != nil {
			c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true)
		}
		stream, err = sendRequest(ctx, cc.dopts.codec, cc.dopts.cp, callHdr, t, args, topts)
		if err != nil {
			// Call put at most once per attempt; it is nil-ed after use.
			if put != nil {
				put()
				put = nil
			}
			// Retry a non-failfast RPC when
			// i) there is a connection error; or
			// ii) the server started to drain before this RPC was initiated.
			if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
				if c.failFast {
					return toRPCErr(err)
				}
				continue
			}
			return toRPCErr(err)
		}
		err = recvResponse(ctx, cc.dopts, t, &c, stream, reply)
		if err != nil {
			if put != nil {
				put()
				put = nil
			}
			// Same retry rule as for the send path above.
			if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
				if c.failFast {
					return toRPCErr(err)
				}
				continue
			}
			return toRPCErr(err)
		}
		if c.traceInfo.tr != nil {
			c.traceInfo.tr.LazyLog(&payload{sent: false, msg: reply}, true)
		}
		t.CloseStream(stream, nil)
		if put != nil {
			put()
			put = nil
		}
		// Translate the stream's final status into the returned error.
		return Errorf(stream.StatusCode(), "%s", stream.StatusDesc())
	}
}
func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc, trInfo *traceInfo) (err error) { if stats.On() { begin := &stats.Begin{ BeginTime: time.Now(), } stats.HandleRPC(stream.Context(), begin) } defer func() { if stats.On() { end := &stats.End{ EndTime: time.Now(), } if err != nil && err != io.EOF { end.Error = toRPCErr(err) } stats.HandleRPC(stream.Context(), end) } }() if s.opts.cp != nil { stream.SetSendCompress(s.opts.cp.Type()) } ss := &serverStream{ t: t, s: stream, p: &parser{r: stream}, codec: s.opts.codec, cp: s.opts.cp, dc: s.opts.dc, maxMsgSize: s.opts.maxMsgSize, trInfo: trInfo, } if ss.cp != nil { ss.cbuf = new(bytes.Buffer) } if trInfo != nil { trInfo.tr.LazyLog(&trInfo.firstLine, false) defer func() { ss.mu.Lock() if err != nil && err != io.EOF { ss.trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) ss.trInfo.tr.SetError() } ss.trInfo.tr.Finish() ss.trInfo.tr = nil ss.mu.Unlock() }() } var appErr error if s.opts.streamInt == nil { appErr = sd.Handler(srv.server, ss) } else { info := &StreamServerInfo{ FullMethod: stream.Method(), IsClientStream: sd.ClientStreams, IsServerStream: sd.ServerStreams, } appErr = s.opts.streamInt(srv.server, ss, info, sd.Handler) } if appErr != nil { if err, ok := appErr.(*rpcError); ok { ss.statusCode = err.code ss.statusDesc = err.desc } else if err, ok := appErr.(transport.StreamError); ok { ss.statusCode = err.Code ss.statusDesc = err.Desc } else { ss.statusCode = convertCode(appErr) ss.statusDesc = appErr.Error() } } if trInfo != nil { ss.mu.Lock() if ss.statusCode != codes.OK { ss.trInfo.tr.LazyLog(stringer(ss.statusDesc), true) ss.trInfo.tr.SetError() } else { ss.trInfo.tr.LazyLog(stringer("OK"), false) } ss.mu.Unlock() } errWrite := t.WriteStatus(ss.s, ss.statusCode, ss.statusDesc) if ss.statusCode != codes.OK { return Errorf(ss.statusCode, ss.statusDesc) } return errWrite }
// operateHeader takes action on the decoded headers. func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream), traceCtx func(context.Context, string) context.Context) (close bool) { buf := newRecvBuffer() s := &Stream{ id: frame.Header().StreamID, st: t, buf: buf, fc: &inFlow{limit: initialWindowSize}, } var state decodeState for _, hf := range frame.Fields { state.processHeaderField(hf) } if err := state.err; err != nil { if se, ok := err.(StreamError); ok { t.controlBuf.put(&resetStream{s.id, statusCodeConvTab[se.Code]}) } return } if frame.StreamEnded() { // s is just created by the caller. No lock needed. s.state = streamReadDone } s.recvCompress = state.encoding if state.timeoutSet { s.ctx, s.cancel = context.WithTimeout(t.ctx, state.timeout) } else { s.ctx, s.cancel = context.WithCancel(t.ctx) } pr := &peer.Peer{ Addr: t.remoteAddr, } // Attach Auth info if there is any. if t.authInfo != nil { pr.AuthInfo = t.authInfo } s.ctx = peer.NewContext(s.ctx, pr) // Cache the current stream to the context so that the server application // can find out. Required when the server wants to send some metadata // back to the client (unary call only). s.ctx = newContextWithStream(s.ctx, s) // Attach the received metadata to the context. if len(state.mdata) > 0 { s.ctx = metadata.NewContext(s.ctx, state.mdata) } s.dec = &recvBufferReader{ ctx: s.ctx, recv: s.buf, } s.recvCompress = state.encoding s.method = state.method if t.inTapHandle != nil { var err error info := &tap.Info{ FullMethodName: state.method, } s.ctx, err = t.inTapHandle(s.ctx, info) if err != nil { // TODO: Log the real error. t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream}) return } } t.mu.Lock() if t.state != reachable { t.mu.Unlock() return } if uint32(len(t.activeStreams)) >= t.maxStreams { t.mu.Unlock() t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream}) return } if s.id%2 != 1 || s.id <= t.maxStreamID { t.mu.Unlock() // illegal gRPC stream id. 
grpclog.Println("transport: http2Server.HandleStreams received an illegal stream id: ", s.id) return true } t.maxStreamID = s.id s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota)) t.activeStreams[s.id] = s t.mu.Unlock() s.windowHandler = func(n int) { t.updateWindow(s, uint32(n)) } s.ctx = traceCtx(s.ctx, s.method) if stats.On() { s.ctx = stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method}) inHeader := &stats.InHeader{ FullMethod: s.method, RemoteAddr: t.remoteAddr, LocalAddr: t.localAddr, Compression: s.recvCompress, WireLength: int(frame.Header().Length), } stats.HandleRPC(s.ctx, inHeader) } handle(s) return }
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
// and starts to receive messages on it. Non-nil error returns if construction
// fails.
//
// Steps, as implemented below: dial the address, run the transport
// credentials handshake when credentials are configured (which switches the
// scheme to "https"), build the transport, start the reader goroutine, write
// the client preface, the initial SETTINGS frame and, if needed, a
// connection-level WINDOW_UPDATE, start the controller goroutine, and report
// a ConnBegin stats event when stats are on.
func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (_ ClientTransport, err error) {
	scheme := "http"
	conn, err := dial(ctx, opts.Dialer, addr.Addr)
	if err != nil {
		// With FailOnNonTempDialError the temporariness of the dial error is
		// preserved; otherwise every dial failure is reported as temporary.
		if opts.FailOnNonTempDialError {
			return nil, connectionErrorf(isTemporary(err), err, "transport: %v", err)
		}
		return nil, connectionErrorf(true, err, "transport: %v", err)
	}
	// Any further errors will close the underlying connection
	// Note: conn is passed as an argument, so the deferred close acts on the
	// connection returned by dial even if conn is reassigned by the handshake.
	defer func(conn net.Conn) {
		if err != nil {
			conn.Close()
		}
	}(conn)
	var authInfo credentials.AuthInfo
	if creds := opts.TransportCredentials; creds != nil {
		scheme = "https"
		conn, authInfo, err = creds.ClientHandshake(ctx, addr.Addr, conn)
		if err != nil {
			// Credentials handshake errors are typically considered permanent
			// to avoid retrying on e.g. bad certificates.
			temp := isTemporary(err)
			return nil, connectionErrorf(temp, err, "transport: %v", err)
		}
	}
	// A caller-supplied user agent is placed in front of the default one.
	ua := primaryUA
	if opts.UserAgent != "" {
		ua = opts.UserAgent + " " + ua
	}
	// buf is shared by hBuf and the hpack encoder: encoded header fields
	// land in the buffer the transport reads from.
	var buf bytes.Buffer
	t := &http2Client{
		ctx:        ctx,
		target:     addr.Addr,
		userAgent:  ua,
		md:         addr.Metadata,
		conn:       conn,
		remoteAddr: conn.RemoteAddr(),
		localAddr:  conn.LocalAddr(),
		authInfo:   authInfo,
		// The client initiated stream id is odd starting from 1.
		nextID:          1,
		writableChan:    make(chan int, 1),
		shutdownChan:    make(chan struct{}),
		errorChan:       make(chan struct{}),
		goAway:          make(chan struct{}),
		framer:          newFramer(conn),
		hBuf:            &buf,
		hEnc:            hpack.NewEncoder(&buf),
		controlBuf:      newRecvBuffer(),
		fc:              &inFlow{limit: initialConnWindowSize},
		sendQuotaPool:   newQuotaPool(defaultWindowSize),
		scheme:          scheme,
		state:           reachable,
		activeStreams:   make(map[uint32]*Stream),
		creds:           opts.PerRPCCredentials,
		maxStreams:      math.MaxInt32,
		streamSendQuota: defaultWindowSize,
	}
	// Start the reader goroutine for incoming message. Each transport has
	// a dedicated goroutine which reads HTTP2 frame from network. Then it
	// dispatches the frame to the corresponding stream entity.
	go t.reader()
	// Send connection preface to server.
	n, err := t.conn.Write(clientPreface)
	if err != nil {
		t.Close()
		return nil, connectionErrorf(true, err, "transport: %v", err)
	}
	// A short write without an error is also treated as a failure
	// (err is nil on this path).
	if n != len(clientPreface) {
		t.Close()
		return nil, connectionErrorf(true, err, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface))
	}
	// Advertise the initial stream window only when it differs from the
	// default; otherwise send an empty SETTINGS frame.
	if initialWindowSize != defaultWindowSize {
		err = t.framer.writeSettings(true, http2.Setting{
			ID:  http2.SettingInitialWindowSize,
			Val: uint32(initialWindowSize),
		})
	} else {
		err = t.framer.writeSettings(true)
	}
	if err != nil {
		t.Close()
		return nil, connectionErrorf(true, err, "transport: %v", err)
	}
	// Adjust the connection flow control window if needed.
	if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
		if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil {
			t.Close()
			return nil, connectionErrorf(true, err, "transport: %v", err)
		}
	}
	go t.controller()
	// Make the transport writable.
	t.writableChan <- 0
	if stats.On() {
		t.ctx = stats.TagConn(t.ctx, &stats.ConnTagInfo{
			RemoteAddr: t.remoteAddr,
			LocalAddr:  t.localAddr,
		})
		connBegin := &stats.ConnBegin{
			Client: true,
		}
		stats.HandleConn(t.ctx, connBegin)
	}
	return t, nil
}
// NewStream creates a stream and registers it into the transport as an
// "active" stream.
//
// Steps, as implemented below:
//  1. collect per-RPC credential metadata for every entry in t.creds;
//  2. check the transport state (ErrStreamDrain when draining,
//     ErrConnClosing when not reachable or already torn down);
//  3. acquire stream quota (if a quota is in force) and the write token
//     from t.writableChan;
//  4. create and register the stream, HPACK-encode the request headers and
//     send them as one HEADERS frame followed by CONTINUATION frames when
//     they exceed http2MaxFrameLen;
//  5. report an OutHeader stats event when stats are on and hand the write
//     token back.
func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) {
	pr := &peer.Peer{
		Addr: t.remoteAddr,
	}
	// Attach Auth info if there is any.
	if t.authInfo != nil {
		pr.AuthInfo = t.authInfo
	}
	// Keep the caller's original context; it is stored on the stream as
	// clientStatsCtx and used for stats reporting below.
	userCtx := ctx
	ctx = peer.NewContext(ctx, pr)
	authData := make(map[string]string)
	for _, c := range t.creds {
		// Construct URI required to get auth request metadata.
		var port string
		if pos := strings.LastIndex(t.target, ":"); pos != -1 {
			// Omit port if it is the default one.
			if t.target[pos+1:] != "443" {
				port = ":" + t.target[pos+1:]
			}
		}
		pos := strings.LastIndex(callHdr.Method, "/")
		if pos == -1 {
			return nil, streamErrorf(codes.InvalidArgument, "transport: malformed method name: %q", callHdr.Method)
		}
		// The audience is the method path up to (not including) its last "/".
		audience := "https://" + callHdr.Host + port + callHdr.Method[:pos]
		data, err := c.GetRequestMetadata(ctx, audience)
		if err != nil {
			return nil, streamErrorf(codes.InvalidArgument, "transport: %v", err)
		}
		for k, v := range data {
			authData[k] = v
		}
	}
	t.mu.Lock()
	if t.activeStreams == nil {
		t.mu.Unlock()
		return nil, ErrConnClosing
	}
	if t.state == draining {
		t.mu.Unlock()
		return nil, ErrStreamDrain
	}
	if t.state != reachable {
		t.mu.Unlock()
		return nil, ErrConnClosing
	}
	// Snapshot whether a streams quota exists at this point; it is compared
	// with the state after the stream is registered (see reset below).
	checkStreamsQuota := t.streamsQuota != nil
	t.mu.Unlock()
	if checkStreamsQuota {
		sq, err := wait(ctx, nil, nil, t.shutdownChan, t.streamsQuota.acquire())
		if err != nil {
			return nil, err
		}
		// Returns the quota balance back.
		if sq > 1 {
			t.streamsQuota.add(sq - 1)
		}
	}
	// Acquire the transport's write token.
	if _, err := wait(ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
		// Return the quota back now because there is no stream returned to the caller.
		if _, ok := err.(StreamError); ok && checkStreamsQuota {
			t.streamsQuota.add(1)
		}
		return nil, err
	}
	// Re-check the state: it may have changed while waiting.
	t.mu.Lock()
	if t.state == draining {
		t.mu.Unlock()
		if checkStreamsQuota {
			t.streamsQuota.add(1)
		}
		// Need to make t writable again so that the rpc in flight can still proceed.
		t.writableChan <- 0
		return nil, ErrStreamDrain
	}
	if t.state != reachable {
		t.mu.Unlock()
		return nil, ErrConnClosing
	}
	s := t.newStream(ctx, callHdr)
	s.clientStatsCtx = userCtx
	t.activeStreams[s.id] = s
	// This stream is not counted when applySetings(...) initialize t.streamsQuota.
	// Reset t.streamsQuota to the right value.
	var reset bool
	if !checkStreamsQuota && t.streamsQuota != nil {
		reset = true
	}
	t.mu.Unlock()
	if reset {
		t.streamsQuota.add(-1)
	}
	// HPACK encodes various headers. Note that once WriteField(...) is
	// called, the corresponding headers/continuation frame has to be sent
	// because hpack.Encoder is stateful.
	t.hBuf.Reset()
	t.hEnc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"})
	t.hEnc.WriteField(hpack.HeaderField{Name: ":scheme", Value: t.scheme})
	t.hEnc.WriteField(hpack.HeaderField{Name: ":path", Value: callHdr.Method})
	t.hEnc.WriteField(hpack.HeaderField{Name: ":authority", Value: callHdr.Host})
	t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
	t.hEnc.WriteField(hpack.HeaderField{Name: "user-agent", Value: t.userAgent})
	t.hEnc.WriteField(hpack.HeaderField{Name: "te", Value: "trailers"})
	if callHdr.SendCompress != "" {
		t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
	}
	if dl, ok := ctx.Deadline(); ok {
		// Send out timeout regardless its value. The server can detect timeout context by itself.
		timeout := dl.Sub(time.Now())
		t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)})
	}
	for k, v := range authData {
		// Capital header names are illegal in HTTP/2.
		k = strings.ToLower(k)
		t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: v})
	}
	var (
		hasMD      bool
		endHeaders bool
	)
	// Metadata attached to the call context.
	if md, ok := metadata.FromContext(ctx); ok {
		hasMD = true
		for k, v := range md {
			// HTTP doesn't allow you to set pseudoheaders after non pseudoheaders were set.
			if isReservedHeader(k) {
				continue
			}
			for _, entry := range v {
				t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry})
			}
		}
	}
	// Metadata attached to the transport's target (t.md), when it is a *metadata.MD.
	if md, ok := t.md.(*metadata.MD); ok {
		for k, v := range *md {
			if isReservedHeader(k) {
				continue
			}
			for _, entry := range v {
				t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry})
			}
		}
	}
	first := true
	// Total encoded header size, recorded before the buffer is drained below.
	bufLen := t.hBuf.Len()
	// Sends the headers in a single batch even when they span multiple frames.
	for !endHeaders {
		size := t.hBuf.Len()
		if size > http2MaxFrameLen {
			size = http2MaxFrameLen
		} else {
			endHeaders = true
		}
		var flush bool
		if endHeaders && (hasMD || callHdr.Flush) {
			flush = true
		}
		if first {
			// Sends a HeadersFrame to server to start a new stream.
			p := http2.HeadersFrameParam{
				StreamID:      s.id,
				BlockFragment: t.hBuf.Next(size),
				EndStream:     false,
				EndHeaders:    endHeaders,
			}
			// Do a force flush for the buffered frames iff it is the last headers frame
			// and there is header metadata to be sent. Otherwise, there is no flushing until
			// the corresponding data frame is written.
			err = t.framer.writeHeaders(flush, p)
			first = false
		} else {
			// Sends Continuation frames for the leftover headers.
			err = t.framer.writeContinuation(flush, s.id, endHeaders, t.hBuf.Next(size))
		}
		if err != nil {
			t.notifyError(err)
			return nil, connectionErrorf(true, err, "transport: %v", err)
		}
	}
	if stats.On() {
		outHeader := &stats.OutHeader{
			Client:      true,
			WireLength:  bufLen,
			FullMethod:  callHdr.Method,
			RemoteAddr:  t.remoteAddr,
			LocalAddr:   t.localAddr,
			Compression: callHdr.SendCompress,
		}
		stats.HandleRPC(s.clientStatsCtx, outHeader)
	}
	// Hand the write token back.
	t.writableChan <- 0
	return s, nil
}
// operateHeaders takes action on the decoded headers. func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { s, ok := t.getStream(frame) if !ok { return } var state decodeState for _, hf := range frame.Fields { state.processHeaderField(hf) } if state.err != nil { s.mu.Lock() if !s.headerDone { close(s.headerChan) s.headerDone = true } s.mu.Unlock() s.write(recvMsg{err: state.err}) // Something wrong. Stops reading even when there is remaining. return } endStream := frame.StreamEnded() var isHeader bool defer func() { if stats.On() { if isHeader { inHeader := &stats.InHeader{ Client: true, WireLength: int(frame.Header().Length), } stats.HandleRPC(s.clientStatsCtx, inHeader) } else { inTrailer := &stats.InTrailer{ Client: true, WireLength: int(frame.Header().Length), } stats.HandleRPC(s.clientStatsCtx, inTrailer) } } }() s.mu.Lock() if !endStream { s.recvCompress = state.encoding } if !s.headerDone { if !endStream && len(state.mdata) > 0 { s.header = state.mdata } close(s.headerChan) s.headerDone = true isHeader = true } if !endStream || s.state == streamDone { s.mu.Unlock() return } if len(state.mdata) > 0 { s.trailer = state.mdata } s.statusCode = state.statusCode s.statusDesc = state.statusDesc close(s.done) s.state = streamDone s.mu.Unlock() s.write(recvMsg{err: io.EOF}) }
// RecvMsg reads the next message from the stream into m.
//
// It returns nil when a message was received, io.EOF when the stream ended
// with an OK status, and otherwise an error built from the stream status or
// converted with toRPCErr. Any non-nil result finishes the stream via
// cs.finish and, when stats are on, reports an End event.
//
// For client-streaming-only RPCs (ClientStreams set, ServerStreams not set)
// a second recv is issued after the first message; it is expected to yield
// io.EOF, and in that case nil is returned when the status is OK.
func (cs *clientStream) RecvMsg(m interface{}) (err error) {
	// Registered first, so it runs last: by then err holds the final result.
	defer func() {
		if err != nil && stats.On() {
			// Only generate End if err != nil.
			// If err == nil, it's not the last RecvMsg.
			// The last RecvMsg gets either an RPC error or io.EOF.
			end := &stats.End{
				Client:  true,
				EndTime: time.Now(),
			}
			if err != io.EOF {
				end.Error = toRPCErr(err)
			}
			stats.HandleRPC(cs.statsCtx, end)
		}
	}()
	var inPayload *stats.InPayload
	if stats.On() {
		inPayload = &stats.InPayload{
			Client: true,
		}
	}
	err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32, inPayload)
	defer func() {
		// err != nil indicates the termination of the stream.
		if err != nil {
			cs.finish(err)
		}
	}()
	if err == nil {
		if cs.tracing {
			cs.mu.Lock()
			if cs.trInfo.tr != nil {
				cs.trInfo.tr.LazyLog(&payload{sent: false, msg: m}, true)
			}
			cs.mu.Unlock()
		}
		if inPayload != nil {
			stats.HandleRPC(cs.statsCtx, inPayload)
		}
		// Anything other than a client-streaming-only RPC is done here;
		// err is nil, so nil is returned.
		if !cs.desc.ClientStreams || cs.desc.ServerStreams {
			return
		}
		// Special handling for client streaming rpc.
		// This recv expects EOF or errors, so we don't collect inPayload.
		err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32, nil)
		cs.closeTransportStream(err)
		if err == nil {
			return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
		}
		if err == io.EOF {
			if cs.s.StatusCode() == codes.OK {
				// The nil returned below means the deferred finish will
				// not fire, so finish the stream explicitly here.
				cs.finish(err)
				return nil
			}
			return Errorf(cs.s.StatusCode(), "%s", cs.s.StatusDesc())
		}
		return toRPCErr(err)
	}
	// The first recv failed. The transport stream is closed here unless the
	// error is a transport.ConnectionError.
	if _, ok := err.(transport.ConnectionError); !ok {
		cs.closeTransportStream(err)
	}
	if err == io.EOF {
		if cs.s.StatusCode() == codes.OK {
			// Returns io.EOF to indicate the end of the stream.
			return
		}
		return Errorf(cs.s.StatusCode(), "%s", cs.s.StatusDesc())
	}
	return toRPCErr(err)
}
func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, md *MethodDesc, trInfo *traceInfo) (err error) { if stats.On() { begin := &stats.Begin{ BeginTime: time.Now(), } stats.HandleRPC(stream.Context(), begin) } defer func() { if stats.On() { end := &stats.End{ EndTime: time.Now(), } if err != nil && err != io.EOF { end.Error = toRPCErr(err) } stats.HandleRPC(stream.Context(), end) } }() if trInfo != nil { defer trInfo.tr.Finish() trInfo.firstLine.client = false trInfo.tr.LazyLog(&trInfo.firstLine, false) defer func() { if err != nil && err != io.EOF { trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) trInfo.tr.SetError() } }() } if s.opts.cp != nil { // NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686. stream.SetSendCompress(s.opts.cp.Type()) } p := &parser{r: stream} for { pf, req, err := p.recvMsg(s.opts.maxMsgSize) if err == io.EOF { // The entire stream is done (for unary RPC only). return err } if err == io.ErrUnexpectedEOF { err = Errorf(codes.Internal, io.ErrUnexpectedEOF.Error()) } if err != nil { switch err := err.(type) { case *rpcError: if e := t.WriteStatus(stream, err.code, err.desc); e != nil { grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) } case transport.ConnectionError: // Nothing to do here. 
case transport.StreamError: if e := t.WriteStatus(stream, err.Code, err.Desc); e != nil { grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) } default: panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", err, err)) } return err } if err := checkRecvPayload(pf, stream.RecvCompress(), s.opts.dc); err != nil { switch err := err.(type) { case *rpcError: if e := t.WriteStatus(stream, err.code, err.desc); e != nil { grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) } return err default: if e := t.WriteStatus(stream, codes.Internal, err.Error()); e != nil { grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) } // TODO checkRecvPayload always return RPC error. Add a return here if necessary. } } var inPayload *stats.InPayload if stats.On() { inPayload = &stats.InPayload{ RecvTime: time.Now(), } } statusCode := codes.OK statusDesc := "" df := func(v interface{}) error { if inPayload != nil { inPayload.WireLength = len(req) } if pf == compressionMade { var err error req, err = s.opts.dc.Do(bytes.NewReader(req)) if err != nil { if err := t.WriteStatus(stream, codes.Internal, err.Error()); err != nil { grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err) } return Errorf(codes.Internal, err.Error()) } } if len(req) > s.opts.maxMsgSize { // TODO: Revisit the error code. Currently keep it consistent with // java implementation. 
statusCode = codes.Internal statusDesc = fmt.Sprintf("grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxMsgSize) } if err := s.opts.codec.Unmarshal(req, v); err != nil { return err } if inPayload != nil { inPayload.Payload = v inPayload.Data = req inPayload.Length = len(req) stats.HandleRPC(stream.Context(), inPayload) } if trInfo != nil { trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true) } return nil } reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt) if appErr != nil { if err, ok := appErr.(*rpcError); ok { statusCode = err.code statusDesc = err.desc } else { statusCode = convertCode(appErr) statusDesc = appErr.Error() } if trInfo != nil && statusCode != codes.OK { trInfo.tr.LazyLog(stringer(statusDesc), true) trInfo.tr.SetError() } if err := t.WriteStatus(stream, statusCode, statusDesc); err != nil { grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", err) } return Errorf(statusCode, statusDesc) } if trInfo != nil { trInfo.tr.LazyLog(stringer("OK"), false) } opts := &transport.Options{ Last: true, Delay: false, } if err := s.sendResponse(t, stream, reply, s.opts.cp, opts); err != nil { switch err := err.(type) { case transport.ConnectionError: // Nothing to do here. case transport.StreamError: statusCode = err.Code statusDesc = err.Desc default: statusCode = codes.Unknown statusDesc = err.Error() } return err } if trInfo != nil { trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true) } errWrite := t.WriteStatus(stream, statusCode, statusDesc) if statusCode != codes.OK { return Errorf(statusCode, statusDesc) } return errWrite } }
func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) { var ( t transport.ClientTransport s *transport.Stream put func() ) c := defaultCallInfo for _, o := range opts { if err := o.before(&c); err != nil { return nil, toRPCErr(err) } } callHdr := &transport.CallHdr{ Host: cc.authority, Method: method, Flush: desc.ServerStreams && desc.ClientStreams, } if cc.dopts.cp != nil { callHdr.SendCompress = cc.dopts.cp.Type() } var trInfo traceInfo if EnableTracing { trInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method) trInfo.firstLine.client = true if deadline, ok := ctx.Deadline(); ok { trInfo.firstLine.deadline = deadline.Sub(time.Now()) } trInfo.tr.LazyLog(&trInfo.firstLine, false) ctx = trace.NewContext(ctx, trInfo.tr) defer func() { if err != nil { // Need to call tr.finish() if error is returned. // Because tr will not be returned to caller. trInfo.tr.LazyPrintf("RPC: [%v]", err) trInfo.tr.SetError() trInfo.tr.Finish() } }() } if stats.On() { ctx = stats.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method}) begin := &stats.Begin{ Client: true, BeginTime: time.Now(), FailFast: c.failFast, } stats.HandleRPC(ctx, begin) } defer func() { if err != nil && stats.On() { // Only handle end stats if err != nil. end := &stats.End{ Client: true, Error: err, } stats.HandleRPC(ctx, end) } }() gopts := BalancerGetOptions{ BlockingWait: !c.failFast, } for { t, put, err = cc.getTransport(ctx, gopts) if err != nil { // TODO(zhaoq): Probably revisit the error handling. if _, ok := err.(*rpcError); ok { return nil, err } if err == errConnClosing || err == errConnUnavailable { if c.failFast { return nil, Errorf(codes.Unavailable, "%v", err) } continue } // All the other errors are treated as Internal errors. 
return nil, Errorf(codes.Internal, "%v", err) } s, err = t.NewStream(ctx, callHdr) if err != nil { if put != nil { put() put = nil } if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { if c.failFast { return nil, toRPCErr(err) } continue } return nil, toRPCErr(err) } break } cs := &clientStream{ opts: opts, c: c, desc: desc, codec: cc.dopts.codec, cp: cc.dopts.cp, dc: cc.dopts.dc, put: put, t: t, s: s, p: &parser{r: s}, tracing: EnableTracing, trInfo: trInfo, statsCtx: ctx, } if cc.dopts.cp != nil { cs.cbuf = new(bytes.Buffer) } // Listen on ctx.Done() to detect cancellation and s.Done() to detect normal termination // when there is no pending I/O operations on this stream. go func() { select { case <-t.Error(): // Incur transport error, simply exit. case <-s.Done(): // TODO: The trace of the RPC is terminated here when there is no pending // I/O, which is probably not the optimal solution. if s.StatusCode() == codes.OK { cs.finish(nil) } else { cs.finish(Errorf(s.StatusCode(), "%s", s.StatusDesc())) } cs.closeTransportStream(nil) case <-s.GoAway(): cs.finish(errConnDrain) cs.closeTransportStream(errConnDrain) case <-s.Context().Done(): err := s.Context().Err() cs.finish(err) cs.closeTransportStream(transport.ContextErr(err)) } }() return cs, nil }