// WithProcessClosing returns a context.Context derived from ctx that
// is cancelled as p is Closing (after: <-p.Closing()). It is simply:
//
//   func WithProcessClosing(ctx context.Context, p goprocess.Process) context.Context {
//     ctx, cancel := context.WithCancel(ctx)
//     go func() {
//       <-p.Closing()
//       cancel()
//     }()
//     return ctx
//   }
//
func WithProcessClosing(ctx context.Context, p goprocess.Process) context.Context {
	ctx, cancel := context.WithCancel(ctx)
	go func() {
		<-p.Closing()
		cancel()
	}()
	return ctx
}
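// A minimal usage sketch, not part of goprocess: a worker loop that stops
// once its goprocess begins closing. The `tasks` channel and the loop body
// are hypothetical stand-ins.
// (assumes imports: context, fmt, github.com/jbenet/goprocess)
func worker(p goprocess.Process, tasks <-chan string) {
	ctx := WithProcessClosing(context.Background(), p)
	for {
		select {
		case <-ctx.Done(): // unblocks once <-p.Closing() fires
			return
		case task := <-tasks:
			fmt.Println("handling", task) // hypothetical work
		}
	}
}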
func TestCancelAfterRequest(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())

	resp, err := doRequest(ctx)
	if err != nil {
		t.Fatal(err)
	}

	// Cancel before reading the body.
	// Request.Body should still be readable after the context is canceled.
	cancel()

	b, err := ioutil.ReadAll(resp.Body)
	if err != nil || string(b) != requestBody {
		t.Fatalf("could not read body: %q %v", b, err)
	}
}
func TestCancel(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	go func() {
		time.Sleep(requestDuration / 2)
		cancel()
	}()

	resp, err := doRequest(ctx)
	if resp != nil || err == nil {
		t.Fatalf("expected error, didn't get one. resp: %v", resp)
	}
	if err != ctx.Err() {
		t.Fatalf("expected error from context but got: %v", err)
	}
}
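// The two tests above rely on a doRequest helper and the requestBody /
// requestDuration values, which are not shown here. A hypothetical sketch of
// what they might look like, assuming the pre-Go-1.7
// golang.org/x/net/context/ctxhttp package (whose Get returns ctx.Err()
// when the context is cancelled mid-request):
// (assumes imports: io, net/http, net/http/httptest, time,
//  golang.org/x/net/context/ctxhttp)
const requestBody = "ok"
const requestDuration = 100 * time.Millisecond

// a test server slow enough for an in-flight cancellation to land first
var testServer = httptest.NewServer(http.HandlerFunc(
	func(w http.ResponseWriter, r *http.Request) {
		time.Sleep(requestDuration)
		io.WriteString(w, requestBody)
	}))

func doRequest(ctx context.Context) (*http.Response, error) {
	return ctxhttp.Get(ctx, http.DefaultClient, testServer.URL)
}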
func TestSecureHandshakeFailsWithWrongKeys(t *testing.T) {
	// t.Skip("Skipping in favor of another test")
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	c1, c2, p1, p2 := setupSingleConn(t, ctx)

	done := make(chan error)
	// swap the keys: each side handshakes with the other peer's identity.
	go secureHandshake(t, ctx, p2.PrivKey, c1, done)
	go secureHandshake(t, ctx, p1.PrivKey, c2, done)

	for i := 0; i < 2; i++ {
		if err := <-done; err == nil {
			t.Fatal("wrong keys should've errored out.")
		}
	}
}
func TestClose(t *testing.T) {
	// t.Skip("Skipping in favor of another test")
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	c1, c2, _, _ := setupSingleConn(t, ctx)

	testOneSendRecv(t, c1, c2)
	testOneSendRecv(t, c2, c1)

	c1.Close()
	testNotOneSendRecv(t, c1, c2)

	c2.Close()
	testNotOneSendRecv(t, c2, c1)
	testNotOneSendRecv(t, c1, c2)
}
func testPing(t *testing.T, ps *PingService, p peer.ID) {
	pctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	ts, err := ps.Ping(pctx, p)
	if err != nil {
		t.Fatal(err)
	}

	for i := 0; i < 5; i++ {
		select {
		case took := <-ts:
			t.Log("ping took: ", took)
		case <-time.After(time.Second * 4):
			t.Fatal("failed to receive ping")
		}
	}
}
func TestSecureCancelHandshake(t *testing.T) {
	// t.Skip("Skipping in favor of another test")
	ctx, cancel := context.WithCancel(context.Background())
	c1, c2, p1, p2 := setupSingleConn(t, ctx)

	done := make(chan error)
	go secureHandshake(t, ctx, p1.PrivKey, c1, done)
	time.Sleep(time.Millisecond)
	cancel() // cancel ctx
	go secureHandshake(t, ctx, p2.PrivKey, c2, done)

	for i := 0; i < 2; i++ {
		if err := <-done; err == nil {
			t.Error("cancel should've errored out")
		}
	}
}
func newSecureSession(ctx context.Context, local peer.ID, key ci.PrivKey, insecure io.ReadWriteCloser) (*secureSession, error) {
	s := &secureSession{localPeer: local, localKey: key}

	switch {
	case s.localPeer == "":
		return nil, errors.New("no local id provided")
	case s.localKey == nil:
		return nil, errors.New("no local private key provided")
	case !s.localPeer.MatchesPrivateKey(s.localKey):
		return nil, fmt.Errorf("peer.ID does not match PrivateKey")
	case insecure == nil:
		return nil, fmt.Errorf("insecure ReadWriter is nil")
	}

	// derive the session context only after validation succeeds, so early
	// returns don't leak the cancel func.
	s.ctx, s.cancel = context.WithCancel(ctx)

	s.insecure = insecure
	s.insecureM = msgio.NewReadWriter(insecure)
	return s, nil
}
func TestPing(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	h1 := netutil.GenHostSwarm(t, ctx)
	h2 := netutil.GenHostSwarm(t, ctx)

	err := h1.Connect(ctx, peer.PeerInfo{
		ID:    h2.ID(),
		Addrs: h2.Addrs(),
	})
	if err != nil {
		t.Fatal(err)
	}

	ps1 := NewPingService(h1)
	ps2 := NewPingService(h2)

	testPing(t, ps1, h2.ID())
	testPing(t, ps2, h1.ID())
}
func (s *Swarm) dialAddrs(ctx context.Context, p peer.ID, remoteAddrs []ma.Multiaddr) (conn.Conn, error) {

	// sort addresses so preferred addresses are dialed sooner
	sort.Sort(AddrList(remoteAddrs))

	// try to connect to one of the peer's known addresses.
	// we dial concurrently to each of the addresses, which:
	// * makes the process faster overall
	// * attempts to get the fastest connection available.
	// * mitigates the waste of trying bad addresses
	log.Debugf("%s swarm dialing %s %s", s.local, p, remoteAddrs)

	ctx, cancel := context.WithCancel(ctx)
	defer cancel() // cancel work when we exit func

	conns := make(chan conn.Conn)
	errs := make(chan error, len(remoteAddrs))

	// dialSingleAddr is used in the rate-limited async dialing below.
	dialSingleAddr := func(addr ma.Multiaddr) {
		// rebind chans in scope so we can nil them out easily
		connsout := conns
		errsout := errs

		connC, err := s.dialAddr(ctx, p, addr)
		if err != nil {
			connsout = nil
		} else if connC == nil {
			// NOTE: this really should never happen
			log.Errorf("failed to dial %s %s and got no error!", p, addr)
			err = fmt.Errorf("failed to dial %s %s", p, addr)
			connsout = nil
		} else {
			errsout = nil
		}

		// check parent still wants our results
		select {
		case <-ctx.Done():
			if connC != nil {
				connC.Close()
			}
		case errsout <- err:
		case connsout <- connC:
		}
	}

	// dial in a goroutine so the result loop below can return early
	// once a connection is found.
	go func() {
		limiter := make(chan struct{}, 8)
		for _, addr := range remoteAddrs {
			// returns whatever ratelimiting is acceptable for workerAddr.
			// may not rate limit at all.
			rl := s.addrDialRateLimit(addr)
			select {
			case <-ctx.Done(): // our context was cancelled
				return
			case rl <- struct{}{}:
				// take the token, move on
			}

			select {
			case <-ctx.Done(): // our context was cancelled
				return
			case limiter <- struct{}{}:
				// take the token, move on
			}

			go func(rlc <-chan struct{}, a ma.Multiaddr) {
				dialSingleAddr(a)
				<-limiter
				<-rlc
			}(rl, addr)
		}
	}()

	// wait for the results.
	exitErr := fmt.Errorf("failed to dial %s", p)
	for range remoteAddrs {
		select {
		case exitErr = <-errs:
			// log.Debug("dial error: ", exitErr)
		case connC := <-conns:
			// take the first + return asap
			return connC, nil
		case <-ctx.Done():
			// context cancelled: return directly
			// (a bare break would only exit the select, not the loop)
			return nil, ctx.Err()
		}
	}
	return nil, exitErr
}
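// A stripped-down, hypothetical sketch (not part of Swarm) of the two-level
// token pattern dialAddrs uses above: a global limiter channel bounds total
// concurrency while a second channel bounds one class of work. `process`
// stands in for the real dial.
// (assumes import: context)
func fanOut(ctx context.Context, jobs []string, perClass chan struct{}, process func(string)) {
	limiter := make(chan struct{}, 8) // global concurrency bound
	for _, j := range jobs {
		select {
		case <-ctx.Done():
			return
		case perClass <- struct{}{}: // acquire per-class token
		}
		select {
		case <-ctx.Done():
			<-perClass // release before bailing
			return
		case limiter <- struct{}{}: // acquire global token
		}
		go func(j string) {
			defer func() { <-limiter; <-perClass }() // release both tokens
			process(j)
		}(j)
	}
}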
func TestSecureCloseLeak(t *testing.T) {
	// t.Skip("Skipping in favor of another test")
	if testing.Short() {
		t.SkipNow()
	}
	if travis.IsRunning() {
		t.Skip("this doesn't work well on travis")
	}

	runPair := func(c1, c2 Conn, num int) {
		mc1 := msgioWrap(c1)
		mc2 := msgioWrap(c2)

		log.Debugf("runPair %d", num)

		for i := 0; i < num; i++ {
			log.Debugf("runPair iteration %d", i)

			b1 := []byte("beep")
			mc1.WriteMsg(b1)
			b2, err := mc2.ReadMsg()
			if err != nil {
				panic(err)
			}
			if !bytes.Equal(b1, b2) {
				panic("bytes not equal")
			}

			b2 = []byte("beep")
			mc2.WriteMsg(b2)
			b1, err = mc1.ReadMsg()
			if err != nil {
				panic(err)
			}
			if !bytes.Equal(b1, b2) {
				panic("bytes not equal")
			}

			time.Sleep(time.Microsecond * 5)
		}
	}

	var cons = 5
	var msgs = 50
	log.Debugf("Running %d connections * %d msgs.\n", cons, msgs)

	var wg sync.WaitGroup
	for i := 0; i < cons; i++ {
		wg.Add(1)

		ctx, cancel := context.WithCancel(context.Background())
		c1, c2, _, _ := setupSecureConn(t, ctx)
		go func(c1, c2 Conn) {
			defer func() {
				c1.Close()
				c2.Close()
				cancel()
				wg.Done()
			}()
			runPair(c1, c2, msgs)
		}(c1, c2)
	}

	log.Debugf("Waiting...")
	wg.Wait()

	// done!
	time.Sleep(time.Millisecond * 150)
	ngr := runtime.NumGoroutine()
	if ngr > 25 {
		// panic("uncomment me to debug")
		t.Fatal("leaking goroutines:", ngr)
	}
}
func TestCloseLeak(t *testing.T) {
	// t.Skip("Skipping in favor of another test")
	if testing.Short() {
		t.SkipNow()
	}
	if travis.IsRunning() {
		t.Skip("this doesn't work well on travis")
	}

	var wg sync.WaitGroup
	runPair := func(num int) {
		ctx, cancel := context.WithCancel(context.Background())
		c1, c2, _, _ := setupSingleConn(t, ctx)

		mc1 := msgioWrap(c1)
		mc2 := msgioWrap(c2)

		for i := 0; i < num; i++ {
			b1 := []byte(fmt.Sprintf("beep%d", i))
			mc1.WriteMsg(b1)
			b2, err := mc2.ReadMsg()
			if err != nil {
				panic(err)
			}
			if !bytes.Equal(b1, b2) {
				panic(fmt.Errorf("bytes not equal: %s != %s", b1, b2))
			}

			b2 = []byte(fmt.Sprintf("boop%d", i))
			mc2.WriteMsg(b2)
			b1, err = mc1.ReadMsg()
			if err != nil {
				panic(err)
			}
			if !bytes.Equal(b1, b2) {
				panic(fmt.Errorf("bytes not equal: %s != %s", b1, b2))
			}

			<-time.After(time.Microsecond * 5)
		}

		c1.Close()
		c2.Close()
		cancel() // close the listener
		wg.Done()
	}

	var cons = 5
	var msgs = 50
	log.Debugf("Running %d connections * %d msgs.\n", cons, msgs)
	for i := 0; i < cons; i++ {
		wg.Add(1)
		go runPair(msgs)
	}

	log.Debugf("Waiting...\n")
	wg.Wait()

	// done!
	time.Sleep(time.Millisecond * 150)
	ngr := runtime.NumGoroutine()
	if ngr > 25 {
		// note, this is really inaccurate
		// panic("uncomment me to debug")
		t.Fatal("leaking goroutines:", ngr)
	}
}
func SubtestSwarm(t *testing.T, SwarmNum int, MsgNum int) {
	// t.Skip("skipping for another test")

	ctx := context.Background()
	swarms := makeSwarms(ctx, t, SwarmNum)

	// connect everyone
	connectSwarms(t, ctx, swarms)

	// ping/pong
	for _, s1 := range swarms {
		log.Debugf("-------------------------------------------------------")
		log.Debugf("%s ping pong round", s1.local)
		log.Debugf("-------------------------------------------------------")

		_, cancel := context.WithCancel(ctx)
		got := map[peer.ID]int{}
		errChan := make(chan error, MsgNum*len(swarms))
		streamChan := make(chan *Stream, MsgNum)

		// send out "ping" x MsgNum to every peer
		go func() {
			defer close(streamChan)

			var wg sync.WaitGroup
			send := func(p peer.ID) {
				defer wg.Done()

				// first, one stream per peer (nice)
				stream, err := s1.NewStreamWithPeer(p)
				if err != nil {
					errChan <- err
					return
				}

				// send out ping!
				for k := 0; k < MsgNum; k++ { // with k messages
					msg := "ping"
					log.Debugf("%s %s %s (%d)", s1.local, msg, p, k)
					if _, err := stream.Write([]byte(msg)); err != nil {
						errChan <- err
						continue
					}
				}

				// read it later
				streamChan <- stream
			}

			for _, s2 := range swarms {
				if s2.local == s1.local {
					continue // dont send to self...
				}

				wg.Add(1)
				go send(s2.local)
			}
			wg.Wait()
		}()

		// receive "pong" x MsgNum from every peer
		go func() {
			defer close(errChan)

			count := 0
			countShouldBe := MsgNum * (len(swarms) - 1)
			for stream := range streamChan { // one per peer
				defer stream.Close()

				// get peer on the other side
				p := stream.Conn().RemotePeer()

				// receive pings
				msgCount := 0
				msg := make([]byte, 4)
				for k := 0; k < MsgNum; k++ { // with k messages
					// read from the stream
					if _, err := stream.Read(msg); err != nil {
						errChan <- err
						continue
					}

					if string(msg) != "pong" {
						errChan <- fmt.Errorf("unexpected message: %s", msg)
						continue
					}

					log.Debugf("%s %s %s (%d)", s1.local, msg, p, k)
					msgCount++
				}

				got[p] = msgCount
				count += msgCount
			}

			if count != countShouldBe {
				errChan <- fmt.Errorf("count mismatch: %d != %d", count, countShouldBe)
			}
		}()

		// check any errors (blocks till consumer is done)
		for err := range errChan {
			if err != nil {
				t.Error(err.Error())
			}
		}

		log.Debugf("%s got pongs", s1.local)
		if (len(swarms) - 1) != len(got) {
			t.Errorf("got pongs from %d peers, expected %d", len(got), len(swarms)-1)
		}

		for p, n := range got {
			if n != MsgNum {
				t.Error("peer did not get all msgs", p, n, "/", MsgNum)
			}
		}

		cancel()
		<-time.After(10 * time.Millisecond)
	}

	for _, s := range swarms {
		s.Close()
	}
}
func testDialerCloseEarly(t *testing.T, secure bool) {
	// t.Skip("Skipping in favor of another test")

	p1 := tu.RandPeerNetParamsOrFatal(t)
	p2 := tu.RandPeerNetParamsOrFatal(t)

	key1 := p1.PrivKey
	if !secure {
		key1 = nil
		t.Log("testing insecurely")
	} else {
		t.Log("testing securely")
	}

	ctx, cancel := context.WithCancel(context.Background())
	l1, err := Listen(ctx, p1.Addr, p1.ID, key1)
	if err != nil {
		t.Fatal(err)
	}
	p1.Addr = l1.Multiaddr() // Addr has been determined by kernel.

	// lol nesting
	d2 := &Dialer{
		LocalPeer: p2.ID,
		// PrivateKey: key2, -- dont give it key. we'll just close the conn.
	}
	d2.AddDialer(dialer(t, p2.Addr))

	errs := make(chan error, 100)
	done := make(chan struct{}, 1)
	gotclosed := make(chan struct{}, 1)
	go func() {
		defer func() { done <- struct{}{} }()

		c, err := l1.Accept()
		if err != nil {
			if strings.Contains(err.Error(), "closed") {
				gotclosed <- struct{}{}
				return
			}
			errs <- err
			return // c is nil here; writing to it would panic.
		}

		if _, err := c.Write([]byte("hello")); err != nil {
			gotclosed <- struct{}{}
			return
		}

		errs <- fmt.Errorf("wrote to conn")
	}()

	c, err := d2.Dial(ctx, p1.Addr, p1.ID)
	if err != nil {
		t.Fatal(err)
	}
	c.Close() // close it early.

	readerrs := func() {
		for {
			select {
			case e := <-errs:
				t.Error(e)
			default:
				return
			}
		}
	}
	readerrs()

	l1.Close()
	<-done
	cancel()
	readerrs()
	close(errs)

	select {
	case <-gotclosed:
	default:
		t.Error("did not get closed")
	}
}
func testDialer(t *testing.T, secure bool) {
	// t.Skip("Skipping in favor of another test")

	p1 := tu.RandPeerNetParamsOrFatal(t)
	p2 := tu.RandPeerNetParamsOrFatal(t)

	key1 := p1.PrivKey
	key2 := p2.PrivKey
	if !secure {
		key1 = nil
		key2 = nil
		t.Log("testing insecurely")
	} else {
		t.Log("testing securely")
	}

	ctx, cancel := context.WithCancel(context.Background())
	l1, err := Listen(ctx, p1.Addr, p1.ID, key1)
	if err != nil {
		t.Fatal(err)
	}
	p1.Addr = l1.Multiaddr() // Addr has been determined by kernel.

	d2 := &Dialer{
		LocalPeer:  p2.ID,
		PrivateKey: key2,
	}
	d2.AddDialer(dialer(t, p2.Addr))

	go echoListen(ctx, l1)

	c, err := d2.Dial(ctx, p1.Addr, p1.ID)
	if err != nil {
		t.Fatal("error dialing peer", err)
	}

	// fmt.Println("sending")
	mc := msgioWrap(c)
	mc.WriteMsg([]byte("beep"))
	mc.WriteMsg([]byte("boop"))

	out, err := mc.ReadMsg()
	if err != nil {
		t.Fatal(err)
	}

	// fmt.Println("recving", string(out))
	data := string(out)
	if data != "beep" {
		t.Error("unexpected conn output", data)
	}

	out, err = mc.ReadMsg()
	if err != nil {
		t.Fatal(err)
	}

	data = string(out)
	if string(out) != "boop" {
		t.Error("unexpected conn output", data)
	}

	// fmt.Println("closing")
	c.Close()
	l1.Close()
	cancel()
}