func TestDataAfterRst(t *testing.T) { local, remote := newFakeConnPair() _ = NewSession(local, NewStream, false) trans := frame.NewBasicTransport(remote) // make sure that we get an RST STREAM_CLOSED done := make(chan int) go func() { defer func() { done <- 1 }() f, err := trans.ReadFrame() if err != nil { t.Errorf("Failed to read frame sent from session: %v", err) return } fr, ok := f.(*frame.RStreamRst) if !ok { t.Errorf("Frame is not STREAM_RST: %v", f) return } if fr.ErrorCode() != frame.StreamClosed { t.Errorf("Error code on STREAM_RST is not STREAM_CLOSED. Got %d, expected %d", fr.ErrorCode(), frame.StreamClosed) return } }() fSyn := frame.NewWStreamSyn() if err := fSyn.Set(301, 0, 0, false); err != nil { t.Fatalf("Failed to make syn frame: %v", err) } if err := trans.WriteFrame(fSyn); err != nil { t.Fatalf("Failed to send syn: %v", err) } fRst := frame.NewWStreamRst() if err := fRst.Set(301, frame.Cancel); err != nil { t.Fatal("Failed to make rst frame: %v", err) } if err := trans.WriteFrame(fRst); err != nil { t.Fatalf("Failed to write rst frame: %v", err) } fData := frame.NewWStreamData() if err := fData.Set(301, []byte{0xa, 0xFF}, false); err != nil { t.Fatalf("Failed to set data frame") } trans.WriteFrame(fData) <-done }
func (s *Stream) resetWith(errorCode frame.ErrorCode, resetErr error) { // only ever send one reset if !atomic.CompareAndSwapUint32(&s.sentRst, 0, 1) { return } // close the stream s.closeWithAndRemoveLater(resetErr) // make the reset frame rst := frame.NewWStreamRst() if err := rst.Set(s.id, errorCode); err != nil { s.die(frame.InternalError, err) } // need write lock to make sure no data frames get sent after we send the reset s.writer.Lock() defer s.writer.Unlock() // send it if err := s.session.writeFrame(rst, zeroTime); err != nil { s.die(frame.InternalError, err) } }
func (s *Session) handleFrame(rf frame.RFrame) { switch f := rf.(type) { case *frame.RStreamSyn: // if we're going away, refuse new streams if atomic.LoadInt32(&s.local.goneAway) == 1 { rstF := frame.NewWStreamRst() rstF.Set(f.StreamId(), frame.RefusedStream) go s.writeFrame(rstF, time.Time{}) return } if f.StreamId() <= frame.StreamId(atomic.LoadUint32(&s.remote.lastId)) { s.die(frame.ProtocolError, fmt.Errorf("Stream id %d is less than last remote id.", f.StreamId())) return } if s.isLocal(f.StreamId()) { s.die(frame.ProtocolError, fmt.Errorf("Stream id has wrong parity for remote endpoint: %d", f.StreamId())) return } // update last remote id atomic.StoreUint32(&s.remote.lastId, uint32(f.StreamId())) // make the new stream str := s.newStream(f.StreamId(), f.RelatedStreamId(), f.StreamPriority(), f.StreamInfo(), false, f.Fin(), s.defaultWindowSize, s) // add it to the stream map s.streams.Set(f.StreamId(), str) // put the new stream on the accept channel s.accept <- str case *frame.RStreamData: if str := s.getStream(f.StreamId()); str != nil { str.handleStreamData(f) } else { // DATA frames on closed connections are just stream-level errors fRst := frame.NewWStreamRst() if err := fRst.Set(f.StreamId(), frame.StreamClosed); err != nil { s.die(frame.InternalError, err) } s.wr.Lock() defer s.wr.Unlock() s.transport.WriteFrame(fRst) return } case *frame.RStreamRst: // delegate to the stream to handle these frames if str := s.getStream(f.StreamId()); str != nil { str.handleStreamRst(f) } case *frame.RStreamWndInc: // delegate to the stream to handle these frames if str := s.getStream(f.StreamId()); str != nil { str.handleStreamWndInc(f) } case *frame.RGoAway: atomic.StoreInt32(&s.remote.goneAway, 1) s.remoteDebug = f.Debug() lastId := f.LastStreamId() s.streams.Each(func(id frame.StreamId, str stream) { // close all streams that we opened above the last handled id if s.isLocal(str.Id()) && str.Id() > lastId { str.closeWith(fmt.Errorf("Remote is going away")) } }) case *frame.RPing: if !f.Ack() { pingF := frame.NewWPing() pingF.Set(f.StreamId(), f.Body(), true) go s.writeFrame(pingF, time.Time{}) } default: s.die(frame.ProtocolError, fmt.Errorf("Unrecognized frame type: %v", reflect.TypeOf(f))) return } }