Ejemplo n.º 1
0
// TestAcceptStream checks that a StreamSyn frame written on the remote side of
// a fake connection pair is surfaced by Session.Accept as a stream carrying
// the id that was set on the frame.
func TestAcceptStream(t *testing.T) {
	t.Parallel()

	const wantId = 300

	local, remote := newFakeConnPair()

	// don't need the remote output
	remote.Discard()

	// true for a client session
	s := NewSession(local, NewStream, true)
	defer s.Close()

	f := frame.NewWStreamSyn()
	if err := f.Set(wantId, 0, 0, false); err != nil {
		t.Fatalf("Error building syn frame: %v", err)
	}

	// send the frame into the session; fail fast instead of waiting for the
	// timeout below if the write itself did not succeed
	trans := frame.NewBasicTransport(remote)
	if err := trans.WriteFrame(f); err != nil {
		t.Fatalf("Error writing syn frame: %v", err)
	}

	// done is closed (not sent on) so the goroutine can always finish, even if
	// the test has already given up waiting on the timeout branch.
	done := make(chan struct{})
	go func() {
		defer close(done)

		// wait for accept
		str, err := s.Accept()
		if err != nil {
			t.Errorf("Error accepting stream: %v", err)
			return
		}

		if str.Id() != frame.StreamId(wantId) {
			t.Errorf("Stream has wrong id. Expected %d, got %d", wantId, str.Id())
		}
	}()

	select {
	case <-time.After(time.Second):
		t.Fatalf("Timed out!")
	case <-done:
	}
}
Ejemplo n.º 2
0
// GoAway writes a GoAway frame to the transport carrying the last remote
// stream id seen by this session together with errorCode and debug.
//
// Only the first call gets past the local goneAway flag; every later call
// returns an error without touching the transport. If building or writing
// the frame fails, the session is killed via die with frame.InternalError and
// the underlying error is returned.
func (s *Session) GoAway(errorCode frame.ErrorCode, debug []byte) (err error) {
	// flip the flag from 0 to 1 exactly once; losing the swap means a GoAway
	// was already sent
	firstCall := atomic.CompareAndSwapInt32(&s.local.goneAway, 0, 1)
	if !firstCall {
		return fmt.Errorf("Already sent GoAway!")
	}

	// the write lock is held for the rest of the call
	s.wr.Lock()
	defer s.wr.Unlock()

	goAway := frame.NewWGoAway()
	lastRemote := frame.StreamId(atomic.LoadUint32(&s.remote.lastId))

	// populate the frame, and only write it if that succeeded
	err = goAway.Set(lastRemote, errorCode, debug)
	if err == nil {
		err = s.transport.WriteFrame(goAway)
	}

	// either failure is handled the same way
	if err != nil {
		s.die(frame.InternalError, err)
	}
	return err
}
Ejemplo n.º 3
0
// OpenStream allocates the next local stream id, registers a new stream under
// it and sends a StreamSyn frame announcing it to the remote side.
//
// It returns an error immediately if the remote has signalled that it is
// going away. If populating or writing the syn frame fails, the session is
// killed via die with frame.InternalError and a nil stream is returned along
// with the error.
func (s *Session) OpenStream(priority frame.StreamPriority, relatedStreamId frame.StreamId, fin bool, info []byte) (ret IStream, err error) {
	// refuse to open anything once the remote has gone away
	if remoteGone := atomic.LoadInt32(&s.remote.goneAway) == 1; remoteGone {
		return nil, fmt.Errorf("Failed to create stream, remote has gone away.")
	}

	// Allocating the id and sending the syn must happen under one lock,
	// otherwise two goroutines could interleave like this:
	//
	//   goroutine1          goroutine2
	//   - inc stream id
	//                       - inc stream id
	//                       - send streamsyn
	//   - send streamsyn
	s.wr.Lock()
	defer s.wr.Unlock()

	// local ids advance in steps of two
	id := frame.StreamId(atomic.AddUint32(&s.local.lastId, 2))

	// build the stream and register it in the stream map
	opened := s.newStream(id, relatedStreamId, priority, info, fin, false, s.defaultWindowSize, s)
	s.streams.Set(id, opened)

	// fill in the shared syn frame, and only write it if that succeeded
	err = s.syn.Set(id, relatedStreamId, priority, fin, info)
	if err == nil {
		err = s.transport.WriteFrame(s.syn)
	}

	// either failure takes the whole session down
	if err != nil {
		s.die(frame.InternalError, err)
		return nil, err
	}

	return opened, nil
}
Ejemplo n.º 4
0
// OpenStream converts its arguments to the frame package's types, forwards the
// call to the wrapped ISession and wraps the resulting stream in a
// streamAdaptor.
//
// When the underlying call fails, a nil Stream is returned together with the
// error, so callers never receive an adaptor wrapping a nil stream.
func (a *sessionAdaptor) OpenStream(priority StreamPriority, related StreamId, fin bool, info []byte) (Stream, error) {
	str, err := a.ISession.OpenStream(frame.StreamPriority(priority), frame.StreamId(related), fin, info)
	if err != nil {
		return nil, err
	}
	return &streamAdaptor{str}, nil
}
Ejemplo n.º 5
0
// handleFrame dispatches one received frame according to its concrete type:
//
//   - RStreamSyn:    validates the id, creates and registers a new stream and
//     hands it to the accept channel (or refuses it with a reset
//     frame if this side has sent GoAway).
//   - RStreamData:   delivered to the target stream, or answered with a
//     StreamClosed reset frame if no such stream exists.
//   - RStreamRst,
//     RStreamWndInc: delivered to the target stream if it exists, otherwise
//     ignored.
//   - RGoAway:       records that the remote is going away and closes locally
//     opened streams above the remote's last handled id.
//   - RPing:         answered with an ack ping unless it is itself an ack.
//
// Any other frame type kills the session with frame.ProtocolError.
func (s *Session) handleFrame(rf frame.RFrame) {
	switch f := rf.(type) {
	case *frame.RStreamSyn:
		// if we're going away, refuse new streams
		if atomic.LoadInt32(&s.local.goneAway) == 1 {
			rstF := frame.NewWStreamRst()
			rstF.Set(f.StreamId(), frame.RefusedStream)
			// written from a separate goroutine so frame handling is not
			// blocked on the write
			go s.writeFrame(rstF, time.Time{})
			return
		}

		// remote ids must be strictly increasing; an id equal to the last one
		// is rejected as well
		if f.StreamId() <= frame.StreamId(atomic.LoadUint32(&s.remote.lastId)) {
			s.die(frame.ProtocolError, fmt.Errorf("Stream id %d is less than last remote id.", f.StreamId()))
			return
		}

		// the remote must not use ids from the local id space
		if s.isLocal(f.StreamId()) {
			s.die(frame.ProtocolError, fmt.Errorf("Stream id has wrong parity for remote endpoint: %d", f.StreamId()))
			return
		}

		// update last remote id
		atomic.StoreUint32(&s.remote.lastId, uint32(f.StreamId()))

		// make the new stream
		str := s.newStream(f.StreamId(), f.RelatedStreamId(), f.StreamPriority(), f.StreamInfo(), false, f.Fin(), s.defaultWindowSize, s)

		// add it to the stream map
		s.streams.Set(f.StreamId(), str)

		// put the new stream on the accept channel
		s.accept <- str

	case *frame.RStreamData:
		if str := s.getStream(f.StreamId()); str != nil {
			str.handleStreamData(f)
			return
		}

		// DATA frames on closed connections are just stream-level errors
		fRst := frame.NewWStreamRst()
		if err := fRst.Set(f.StreamId(), frame.StreamClosed); err != nil {
			s.die(frame.InternalError, err)
			// the reset frame was never populated and the session is being
			// killed, so there is nothing valid left to write
			return
		}

		s.wr.Lock()
		defer s.wr.Unlock()

		// NOTE(review): the write error is dropped here; confirm that a
		// best-effort reset is intended before escalating it.
		s.transport.WriteFrame(fRst)
		return

	case *frame.RStreamRst:
		// delegate to the stream to handle these frames
		if str := s.getStream(f.StreamId()); str != nil {
			str.handleStreamRst(f)
		}

	case *frame.RStreamWndInc:
		// delegate to the stream to handle these frames
		if str := s.getStream(f.StreamId()); str != nil {
			str.handleStreamWndInc(f)
		}

	case *frame.RGoAway:
		atomic.StoreInt32(&s.remote.goneAway, 1)
		s.remoteDebug = f.Debug()

		lastId := f.LastStreamId()
		s.streams.Each(func(id frame.StreamId, str stream) {
			// close all streams that we opened above the last handled id
			if s.isLocal(str.Id()) && str.Id() > lastId {
				str.closeWith(fmt.Errorf("Remote is going away"))
			}
		})

	case *frame.RPing:
		// only answer requests; an ack is not acknowledged again
		if !f.Ack() {
			pingF := frame.NewWPing()
			pingF.Set(f.StreamId(), f.Body(), true)
			go s.writeFrame(pingF, time.Time{})
		}

	default:
		s.die(frame.ProtocolError, fmt.Errorf("Unrecognized frame type: %v", reflect.TypeOf(f)))
		return
	}
}