func (mq *msgQueue) doWork(ctx context.Context) { // allow ten minutes for connections // this includes looking them up in the dht // dialing them, and handshaking conctx, cancel := context.WithTimeout(ctx, time.Minute*10) defer cancel() err := mq.network.ConnectTo(conctx, mq.p) if err != nil { log.Infof("cant connect to peer %s: %s", mq.p, err) // TODO: cant connect, what now? return } // grab outgoing message mq.outlk.Lock() wlm := mq.out if wlm == nil || wlm.Empty() { mq.outlk.Unlock() return } mq.out = nil mq.outlk.Unlock() sendctx, cancel := context.WithTimeout(ctx, time.Minute*5) defer cancel() // send wantlist updates err = mq.network.SendMessage(sendctx, mq.p, wlm) if err != nil { log.Infof("bitswap send error: %s", err) // TODO: what do we do if this fails? return } }
func getFileCmd(nodes []int, file string) error { file = normalizePath(file) blocks, ok := files[file] if !ok { return fmt.Errorf("Tried to get file, '%s', which has not been added.\n", file) } var wg sync.WaitGroup // Get blocks and then Has them for _, node := range nodes { // remove blocks peer already has or nah? // I'm assuming that peers with the first block of the file have the whole file, // which i think is ok for the simulation, but i might have to change this later alreadyhas, err := peers[node].Blockstore().Has(files[file][0]) check(err) if alreadyhas { continue } wg.Add(1) go func(i int) { timer := recorder.NewTimer() ctx, cancel := context.WithTimeout(context.Background(), deadline) defer cancel() received, _ := peers[i].Exchange.GetBlocks(ctx, blocks) for j := 0; j < len(blocks); j++ { blockTimer := recorder.NewTimer() x := <-received if x == nil { wg.Done() return } recorder.EndBlockTime(blockTimer, peers[i].Peer.Pretty()) fmt.Println(i, x, j) ctx, cancel := context.WithTimeout(context.Background(), time.Second) err := peers[i].Exchange.HasBlock(ctx, x) if err != nil { fmt.Println("error when adding block", i, err) } cancel() } recorder.EndFileTime(timer, peers[i].Peer.Pretty(), file) // peers[i].Exchange.Close() wg.Done() }(node) } wg.Wait() testGet(nodes, file) return nil }
// TODO simplify this test. get to the _essence_! func TestSendToWantingPeer(t *testing.T) { if testing.Short() { t.SkipNow() } net := tn.VirtualNetwork(mockrouting.NewServer(), delay.Fixed(kNetworkDelay)) sg := NewTestSessionGenerator(net) defer sg.Close() bg := blocksutil.NewBlockGenerator() prev := rebroadcastDelay.Set(time.Second / 2) defer func() { rebroadcastDelay.Set(prev) }() peers := sg.Instances(2) peerA := peers[0] peerB := peers[1] t.Logf("Session %v\n", peerA.Peer) t.Logf("Session %v\n", peerB.Peer) timeout := time.Second waitTime := time.Second * 5 alpha := bg.Next() // peerA requests and waits for block alpha ctx, _ := context.WithTimeout(context.TODO(), waitTime) alphaPromise, err := peerA.Exchange.GetBlocks(ctx, []key.Key{alpha.Key()}) if err != nil { t.Fatal(err) } // peerB announces to the network that he has block alpha ctx, _ = context.WithTimeout(context.TODO(), timeout) err = peerB.Exchange.HasBlock(ctx, alpha) if err != nil { t.Fatal(err) } // At some point, peerA should get alpha (or timeout) blkrecvd, ok := <-alphaPromise if !ok { t.Fatal("context timed out and broke promise channel!") } if blkrecvd.Key() != alpha.Key() { t.Fatal("Wrong block!") } }
func TestDeadlineFractionCancel(t *testing.T) { ctx1, cancel1 := context.WithTimeout(context.Background(), 10*time.Millisecond) ctx2, cancel2 := WithDeadlineFraction(ctx1, 0.5) select { case <-ctx1.Done(): t.Fatal("ctx1 ended too early") case <-ctx2.Done(): t.Fatal("ctx2 ended too early") default: } cancel2() select { case <-ctx1.Done(): t.Fatal("ctx1 should NOT be cancelled") case <-ctx2.Done(): default: t.Fatal("ctx2 should be cancelled") } cancel1() select { case <-ctx1.Done(): case <-ctx2.Done(): default: t.Fatal("ctx1 should be cancelled") } }
// connects to providers for the given keys func (bs *Bitswap) providerConnector(parent context.Context) { defer log.Info("bitswap client worker shutting down...") for { log.Event(parent, "Bitswap.ProviderConnector.Loop") select { case req := <-bs.findKeys: keys := req.keys if len(keys) == 0 { log.Warning("Received batch request for zero blocks") continue } log.Event(parent, "Bitswap.ProviderConnector.Work", eventlog.LoggableMap{"Keys": keys}) // NB: Optimization. Assumes that providers of key[0] are likely to // be able to provide for all keys. This currently holds true in most // every situation. Later, this assumption may not hold as true. child, cancel := context.WithTimeout(req.ctx, providerRequestTimeout) providers := bs.network.FindProvidersAsync(child, keys[0], maxProvidersPerRequest) for p := range providers { go bs.network.ConnectTo(req.ctx, p) } cancel() case <-parent.Done(): return } } }
func TestFindPeer(t *testing.T) { // t.Skip("skipping test to debug another") if testing.Short() { t.SkipNow() } ctx := context.Background() _, peers, dhts := setupDHTS(ctx, 4, t) defer func() { for i := 0; i < 4; i++ { dhts[i].Close() dhts[i].host.Close() } }() connect(t, ctx, dhts[0], dhts[1]) connect(t, ctx, dhts[1], dhts[2]) connect(t, ctx, dhts[1], dhts[3]) ctxT, _ := context.WithTimeout(ctx, time.Second) p, err := dhts[0].FindPeer(ctxT, peers[2]) if err != nil { t.Fatal(err) } if p.ID == "" { t.Fatal("Failed to find peer.") } if p.ID != peers[2] { t.Fatal("Didnt find expected peer.") } }
func TestGetBlocksSequential(t *testing.T) { var servs = Mocks(t, 4) for _, s := range servs { defer s.Close() } bg := blocksutil.NewBlockGenerator() blks := bg.Blocks(50) var keys []key.Key for _, blk := range blks { keys = append(keys, blk.Key()) servs[0].AddBlock(blk) } t.Log("one instance at a time, get blocks concurrently") for i := 1; i < len(servs); i++ { ctx, _ := context.WithTimeout(context.TODO(), time.Second*50) out := servs[i].GetBlocks(ctx, keys) gotten := make(map[key.Key]*blocks.Block) for blk := range out { if _, ok := gotten[blk.Key()]; ok { t.Fatal("Got duplicate block!") } gotten[blk.Key()] = blk } if len(gotten) != len(blks) { t.Fatalf("Didnt get enough blocks back: %d/%d", len(gotten), len(blks)) } } }
// getNode returns the node for link. If it return an error, // stop processing. if it returns a nil node, just skip it. // // the error handling is a little complicated. func (t *traversal) getNode(link *mdag.Link) (*mdag.Node, error) { getNode := func(l *mdag.Link) (*mdag.Node, error) { ctx, cancel := context.WithTimeout(context.TODO(), time.Minute) defer cancel() next, err := l.GetNode(ctx, t.opts.DAG) if err != nil { return nil, err } skip, err := t.shouldSkip(next) if skip { next = nil } return next, err } next, err := getNode(link) if err != nil && t.opts.ErrFunc != nil { // attempt recovery. err = t.opts.ErrFunc(err) next = nil // skip regardless } return next, err }
func TestBasicBitswap(t *testing.T) { net := tn.VirtualNetwork(mockrouting.NewServer(), delay.Fixed(kNetworkDelay)) sg := NewTestSessionGenerator(net) defer sg.Close() bg := blocksutil.NewBlockGenerator() t.Log("Test a one node trying to get one block from another") instances := sg.Instances(2) blocks := bg.Blocks(1) err := instances[0].Exchange.HasBlock(context.TODO(), blocks[0]) if err != nil { t.Fatal(err) } ctx, _ := context.WithTimeout(context.TODO(), time.Second*5) blk, err := instances[1].Exchange.GetBlock(ctx, blocks[0].Key()) if err != nil { t.Fatal(err) } t.Log(blk) for _, inst := range instances { err := inst.Exchange.Close() if err != nil { t.Fatal(err) } } }
func (bs *Bitswap) connectToProviders(ctx context.Context, entries []wantlist.Entry) { ctx, cancel := context.WithCancel(ctx) defer cancel() // Get providers for all entries in wantlist (could take a while) wg := sync.WaitGroup{} for _, e := range entries { wg.Add(1) go func(k key.Key) { defer wg.Done() child, cancel := context.WithTimeout(ctx, providerRequestTimeout) defer cancel() providers := bs.network.FindProvidersAsync(child, k, maxProvidersPerRequest) for prov := range providers { go func(p peer.ID) { bs.network.ConnectTo(ctx, p) }(prov) } }(e.Key) } wg.Wait() // make sure all our children do finish. }
func TestGetBlockFromPeerAfterPeerAnnounces(t *testing.T) { net := tn.VirtualNetwork(mockrouting.NewServer(), delay.Fixed(kNetworkDelay)) block := blocks.NewBlock([]byte("block")) g := NewTestSessionGenerator(net) defer g.Close() peers := g.Instances(2) hasBlock := peers[0] defer hasBlock.Exchange.Close() if err := hasBlock.Exchange.HasBlock(context.Background(), block); err != nil { t.Fatal(err) } wantsBlock := peers[1] defer wantsBlock.Exchange.Close() ctx, _ := context.WithTimeout(context.Background(), time.Second) received, err := wantsBlock.Exchange.GetBlock(ctx, block.Key()) if err != nil { t.Log(err) t.Fatal("Expected to succeed") } if !bytes.Equal(block.Data, received.Data) { t.Fatal("Data doesn't match") } }
func TestValueGetSet(t *testing.T) { // t.Skip("skipping test to debug another") ctx := context.Background() dhtA := setupDHT(ctx, t) dhtB := setupDHT(ctx, t) defer dhtA.Close() defer dhtB.Close() defer dhtA.host.Close() defer dhtB.host.Close() vf := &record.ValidChecker{ Func: func(key.Key, []byte) error { return nil }, Sign: false, } dhtA.Validator["v"] = vf dhtB.Validator["v"] = vf connect(t, ctx, dhtA, dhtB) ctxT, _ := context.WithTimeout(ctx, time.Second) dhtA.PutValue(ctxT, "/v/hello", []byte("world")) ctxT, _ = context.WithTimeout(ctx, time.Second*2) val, err := dhtA.GetValue(ctxT, "/v/hello") if err != nil { t.Fatal(err) } if string(val) != "world" { t.Fatalf("Expected 'world' got '%s'", string(val)) } ctxT, _ = context.WithTimeout(ctx, time.Second*2) val, err = dhtB.GetValue(ctxT, "/v/hello") if err != nil { t.Fatal(err) } if string(val) != "world" { t.Fatalf("Expected 'world' got '%s'", string(val)) } }
// gatedDialAttempt is an attempt to dial a node. It is gated by the swarm's // dial synchronization systems: dialsync and dialbackoff. func (s *Swarm) gatedDialAttempt(ctx context.Context, p peer.ID) (*Conn, error) { var logdial = lgbl.Dial("swarm", s.LocalPeer(), p, nil, nil) defer log.EventBegin(ctx, "swarmDialAttemptSync", logdial).Done() // check if we already have an open connection first conn := s.bestConnectionToPeer(p) if conn != nil { return conn, nil } // check if there's an ongoing dial to this peer if ok, wait := s.dsync.Lock(p); ok { // ok, we have been charged to dial! let's do it. // if it succeeds, dial will add the conn to the swarm itself. defer log.EventBegin(ctx, "swarmDialAttemptStart", logdial).Done() ctxT, cancel := context.WithTimeout(ctx, s.dialT) conn, err := s.dial(ctxT, p) cancel() s.dsync.Unlock(p) log.Debugf("dial end %s", conn) if err != nil { log.Event(ctx, "swarmDialBackoffAdd", logdial) s.backf.AddBackoff(p) // let others know to backoff // ok, we failed. try again. (if loop is done, our error is output) return nil, fmt.Errorf("dial attempt failed: %s", err) } log.Event(ctx, "swarmDialBackoffClear", logdial) s.backf.Clear(p) // okay, no longer need to backoff return conn, nil } else { // we did not dial. we must wait for someone else to dial. // check whether we should backoff first... if s.backf.Backoff(p) { log.Event(ctx, "swarmDialBackoff", logdial) return nil, ErrDialBackoff } defer log.EventBegin(ctx, "swarmDialWait", logdial).Done() select { case <-wait: // wait for that other dial to finish. // see if it worked, OR we got an incoming dial in the meantime... conn := s.bestConnectionToPeer(p) if conn != nil { return conn, nil } return nil, ErrDialFailed case <-ctx.Done(): // or we may have to bail... return nil, ctx.Err() } } }
func TestCarryOnWhenDeadlineExpires(t *testing.T) { impossibleDeadline := time.Nanosecond fastExpiringCtx, _ := context.WithTimeout(context.Background(), impossibleDeadline) n := New() defer n.Shutdown() block := blocks.NewBlock([]byte("A Missed Connection")) blockChannel := n.Subscribe(fastExpiringCtx, block.Key()) assertBlockChannelNil(t, blockChannel) }
func TestFindPeersConnectedToPeer(t *testing.T) { t.Skip("not quite correct (see note)") if testing.Short() { t.SkipNow() } ctx := context.Background() _, peers, dhts := setupDHTS(ctx, 4, t) defer func() { for i := 0; i < 4; i++ { dhts[i].Close() dhts[i].host.Close() } }() // topology: // 0-1, 1-2, 1-3, 2-3 connect(t, ctx, dhts[0], dhts[1]) connect(t, ctx, dhts[1], dhts[2]) connect(t, ctx, dhts[1], dhts[3]) connect(t, ctx, dhts[2], dhts[3]) // fmt.Println("0 is", peers[0]) // fmt.Println("1 is", peers[1]) // fmt.Println("2 is", peers[2]) // fmt.Println("3 is", peers[3]) ctxT, _ := context.WithTimeout(ctx, time.Second) pchan, err := dhts[0].FindPeersConnectedToPeer(ctxT, peers[2]) if err != nil { t.Fatal(err) } // shouldFind := []peer.ID{peers[1], peers[3]} found := []peer.PeerInfo{} for nextp := range pchan { found = append(found, nextp) } // fmt.Printf("querying 0 (%s) FindPeersConnectedToPeer 2 (%s)\n", peers[0], peers[2]) // fmt.Println("should find 1, 3", shouldFind) // fmt.Println("found", found) // testPeerListsMatch(t, shouldFind, found) log.Warning("TestFindPeersConnectedToPeer is not quite correct") if len(found) == 0 { t.Fatal("didn't find any peers.") } }
// ExampleWithTimeout shows a context timeout winning a race against a
// longer timer.
func ExampleWithTimeout() {
	// Pass a context with a timeout to tell a blocking function that it
	// should abandon its work after the timeout elapses.
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	// Even though ctx expires on its own, the cancel function is still
	// called so that the context's resources are always released.
	defer cancel()

	select {
	case <-time.After(200 * time.Millisecond):
		fmt.Println("overslept")
	case <-ctx.Done():
		fmt.Println(ctx.Err()) // prints "context deadline exceeded"
	}
	// Output:
	// context deadline exceeded
}
func TestBootstrap(t *testing.T) { // t.Skip("skipping test to debug another") if testing.Short() { t.SkipNow() } ctx := context.Background() nDHTs := 30 _, _, dhts := setupDHTS(ctx, nDHTs, t) defer func() { for i := 0; i < nDHTs; i++ { dhts[i].Close() defer dhts[i].host.Close() } }() t.Logf("connecting %d dhts in a ring", nDHTs) for i := 0; i < nDHTs; i++ { connect(t, ctx, dhts[i], dhts[(i+1)%len(dhts)]) } <-time.After(100 * time.Millisecond) // bootstrap a few times until we get good tables. stop := make(chan struct{}) go func() { for { t.Logf("bootstrapping them so they find each other", nDHTs) ctxT, _ := context.WithTimeout(ctx, 5*time.Second) bootstrap(t, ctxT, dhts) select { case <-time.After(50 * time.Millisecond): continue // being explicit case <-stop: return } } }() waitForWellFormedTables(t, dhts, 7, 10, 20*time.Second) close(stop) if u.Debug { // the routing tables should be full now. let's inspect them. printRoutingTables(dhts) } }
// WithDeadlineFraction derives a Context whose timeout is the given fraction
// of the time left until ctx's deadline. When ctx has no deadline, or its
// deadline has already passed, the result is a plain cancellable child of
// ctx.
//
// This helps in sequential pipelines of work, where one might try options
// and fall back to others depending on the time available, or on a failure
// to respond. For example:
//
//	// getPicture returns a picture from our encrypted database
//	// we have a pipeline of multiple steps. we need to:
//	// - get the data from a database
//	// - decrypt it
//	// - apply many transforms
//	//
//	// we **know** that each step takes increasingly more time.
//	// The transforms are much more expensive than decryption, and
//	// decryption is more expensive than the database lookup.
//	// If our database takes too long (i.e. >0.2 of available time),
//	// there's no use in continuing.
//	func getPicture(ctx context.Context, key string) ([]byte, error) {
//		// fractional timeout contexts to the rescue!
//
//		// try the database with 0.2 of remaining time.
//		ctx1, cancel1 := ctxext.WithDeadlineFraction(ctx, 0.2)
//		defer cancel1()
//		val, err := db.Get(ctx1, key)
//		if err != nil {
//			return nil, err
//		}
//
//		// try decryption with 0.3 of remaining time.
//		ctx2, cancel2 := ctxext.WithDeadlineFraction(ctx, 0.3)
//		defer cancel2()
//		if val, err = decryptor.Decrypt(ctx2, val); err != nil {
//			return nil, err
//		}
//
//		// try transforms with all remaining time. hopefully it's enough!
//		return transformer.Transform(ctx, val)
//	}
func WithDeadlineFraction(ctx context.Context, fraction float64) (
	context.Context, context.CancelFunc) {

	if deadline, ok := ctx.Deadline(); ok {
		if remaining := deadline.Sub(time.Now()); remaining >= 0 {
			scaled := time.Duration(float64(remaining) * fraction)
			return context.WithTimeout(ctx, scaled)
		}
	}

	// no deadline, or it already passed...
	return context.WithCancel(ctx)
}
func (bs *Bitswap) ReceiveMessage(ctx context.Context, p peer.ID, incoming bsmsg.BitSwapMessage) { // This call records changes to wantlists, blocks received, // and number of bytes transfered. bs.engine.MessageReceived(p, incoming) // TODO: this is bad, and could be easily abused. // Should only track *useful* messages in ledger iblocks := incoming.Blocks() if len(iblocks) == 0 { return } // quickly send out cancels, reduces chances of duplicate block receives var keys []key.Key for _, block := range iblocks { if _, found := bs.wm.wl.Contains(block.Key()); !found { log.Info("received un-asked-for block: %s", block) continue } keys = append(keys, block.Key()) } bs.wm.CancelWants(keys) wg := sync.WaitGroup{} for _, block := range iblocks { wg.Add(1) go func(b *blocks.Block) { defer wg.Done() if err := bs.updateReceiveCounters(b.Key()); err != nil { return // ignore error, is either logged previously, or ErrAlreadyHaveBlock } k := b.Key() log.Event(ctx, "Bitswap.GetBlockRequest.End", &k) log.Debugf("got block %s from %s", b, p) hasBlockCtx, cancel := context.WithTimeout(ctx, hasBlockTimeout) defer cancel() if err := bs.HasBlock(hasBlockCtx, b); err != nil { log.Warningf("ReceiveMessage HasBlock error: %s", err) } }(block) } wg.Wait() }
func TestDeadlineFractionHalf(t *testing.T) { if os.Getenv("TRAVIS") == "true" { t.Skip("timeouts don't work reliably on travis") } ctx1, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) ctx2, _ := WithDeadlineFraction(ctx1, 0.5) select { case <-ctx1.Done(): t.Fatal("ctx1 ended too early") case <-ctx2.Done(): t.Fatal("ctx2 ended too early") default: } <-time.After(2 * time.Millisecond) select { case <-ctx1.Done(): t.Fatal("ctx1 ended too early") case <-ctx2.Done(): t.Fatal("ctx2 ended too early") default: } <-time.After(4 * time.Millisecond) select { case <-ctx1.Done(): t.Fatal("ctx1 ended too early") case <-ctx2.Done(): default: t.Fatal("ctx2 ended too late") } <-time.After(6 * time.Millisecond) select { case <-ctx1.Done(): default: t.Fatal("ctx1 ended too late") } }
func getCmd(nodes []int, block *blocks.Block) error { var wg sync.WaitGroup for _, node := range nodes { wg.Add(1) go func(i int) { ctx, cancel := context.WithTimeout(context.Background(), deadline) defer cancel() peers[i].Exchange.GetBlock(ctx, block.Key()) fmt.Printf("Gotem from node %d.\n", i) peers[i].Exchange.Close() wg.Done() }(node) } wg.Wait() return nil }
func TestBlocks(t *testing.T) { bstore := blockstore.NewBlockstore(dssync.MutexWrap(ds.NewMapDatastore())) bs, err := New(bstore, offline.Exchange(bstore)) if err != nil { t.Error("failed to construct block service", err) return } defer bs.Close() b := blocks.NewBlock([]byte("beep boop")) h := u.Hash([]byte("beep boop")) if !bytes.Equal(b.Multihash, h) { t.Error("Block Multihash and data multihash not equal") } if b.Key() != key.Key(h) { t.Error("Block key and data multihash key not equal") } k, err := bs.AddBlock(b) if err != nil { t.Error("failed to add block to BlockService", err) return } if k != b.Key() { t.Error("returned key is not equal to block key", err) } ctx, _ := context.WithTimeout(context.TODO(), time.Second*5) b2, err := bs.GetBlock(ctx, b.Key()) if err != nil { t.Error("failed to retrieve block from BlockService", err) return } if b.Key() != b2.Key() { t.Error("Block keys not equal.") } if !bytes.Equal(b.Data, b2.Data) { t.Error("Block data is not equal.") } }
func TestProviderForKeyButNetworkCannotFind(t *testing.T) { // TODO revisit this rs := mockrouting.NewServer() net := tn.VirtualNetwork(rs, delay.Fixed(kNetworkDelay)) g := NewTestSessionGenerator(net) defer g.Close() block := blocks.NewBlock([]byte("block")) pinfo := p2ptestutil.RandTestBogusIdentityOrFatal(t) rs.Client(pinfo).Provide(context.Background(), block.Key()) // but not on network solo := g.Next() defer solo.Exchange.Close() ctx, _ := context.WithTimeout(context.Background(), time.Nanosecond) _, err := solo.Exchange.GetBlock(ctx, block.Key()) if err != context.DeadlineExceeded { t.Fatal("Expected DeadlineExceeded error") } }
// this test is on the context tool itself, not our stuff. it's for sanity on ours.
//
// TestDeadline checks that a context with a 5ms timeout is not done right
// after creation and is done after 6ms have passed.
func TestDeadline(t *testing.T) {
	if os.Getenv("TRAVIS") == "true" {
		t.Skip("timeouts don't work reliably on travis")
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
	defer cancel() // release the timer on every exit path

	select {
	case <-ctx.Done():
		t.Fatal("ended too early")
	default:
	}

	time.Sleep(6 * time.Millisecond)

	select {
	case <-ctx.Done():
	default:
		t.Fatal("ended too late")
	}
}
func TestLayeredGet(t *testing.T) { // t.Skip("skipping test to debug another") if testing.Short() { t.SkipNow() } ctx := context.Background() _, _, dhts := setupDHTS(ctx, 4, t) defer func() { for i := 0; i < 4; i++ { dhts[i].Close() defer dhts[i].host.Close() } }() connect(t, ctx, dhts[0], dhts[1]) connect(t, ctx, dhts[1], dhts[2]) connect(t, ctx, dhts[1], dhts[3]) err := dhts[3].Provide(ctx, key.Key("/v/hello")) if err != nil { t.Fatal(err) } time.Sleep(time.Millisecond * 6) t.Log("interface was changed. GetValue should not use providers.") ctxT, _ := context.WithTimeout(ctx, time.Second) val, err := dhts[0].GetValue(ctxT, key.Key("/v/hello")) if err != routing.ErrNotFound { t.Error(err) } if string(val) == "world" { t.Error("should not get value.") } if len(val) > 0 && string(val) != "world" { t.Error("worse, there's a value and its not even the right one.") } }
func (bs *Bitswap) provideWorker(ctx context.Context, id int) { idmap := eventlog.LoggableMap{"ID": id} for { log.Event(ctx, "Bitswap.ProvideWorker.Loop", idmap) select { case k, ok := <-bs.provideKeys: log.Event(ctx, "Bitswap.ProvideWorker.Work", idmap, &k) if !ok { log.Debug("provideKeys channel closed") return } ctx, cancel := context.WithTimeout(ctx, provideTimeout) err := bs.network.Provide(ctx, k) if err != nil { log.Error(err) } cancel() case <-ctx.Done(): return } } }
// runHandshake performs initial communication over insecure channel to share // keys, IDs, and initiate communication, assigning all necessary params. // requires the duplex channel to be a msgio.ReadWriter (for framed messaging) func (s *secureSession) runHandshake() error { ctx, cancel := context.WithTimeout(s.ctx, HandshakeTimeout) // remove defer cancel() // ============================================================================= // step 1. Propose -- propose cipher suite + send pubkeys + nonce // Generate and send Hello packet. // Hello = (rand, PublicKey, Supported) nonceOut := make([]byte, nonceSize) _, err := rand.Read(nonceOut) if err != nil { return err } defer log.EventBegin(ctx, "secureHandshake", s).Done() s.local.permanentPubKey = s.localKey.GetPublic() myPubKeyBytes, err := s.local.permanentPubKey.Bytes() if err != nil { return err } proposeOut := new(pb.Propose) proposeOut.Rand = nonceOut proposeOut.Pubkey = myPubKeyBytes proposeOut.Exchanges = &SupportedExchanges proposeOut.Ciphers = &SupportedCiphers proposeOut.Hashes = &SupportedHashes // log.Debugf("1.0 Propose: nonce:%s exchanges:%s ciphers:%s hashes:%s", // nonceOut, SupportedExchanges, SupportedCiphers, SupportedHashes) // Send Propose packet (respects ctx) proposeOutBytes, err := writeMsgCtx(ctx, s.insecureM, proposeOut) if err != nil { return err } // Receive + Parse their Propose packet and generate an Exchange packet. 
proposeIn := new(pb.Propose) proposeInBytes, err := readMsgCtx(ctx, s.insecureM, proposeIn) if err != nil { return err } // log.Debugf("1.0.1 Propose recv: nonce:%s exchanges:%s ciphers:%s hashes:%s", // proposeIn.GetRand(), proposeIn.GetExchanges(), proposeIn.GetCiphers(), proposeIn.GetHashes()) // ============================================================================= // step 1.1 Identify -- get identity from their key // get remote identity s.remote.permanentPubKey, err = ci.UnmarshalPublicKey(proposeIn.GetPubkey()) if err != nil { return err } // get peer id s.remotePeer, err = peer.IDFromPublicKey(s.remote.permanentPubKey) if err != nil { return err } log.Debugf("1.1 Identify: %s Remote Peer Identified as %s", s.localPeer, s.remotePeer) // ============================================================================= // step 1.2 Selection -- select/agree on best encryption parameters // to determine order, use cmp(H(remote_pubkey||local_rand), H(local_pubkey||remote_rand)). oh1 := u.Hash(append(proposeIn.GetPubkey(), nonceOut...)) oh2 := u.Hash(append(myPubKeyBytes, proposeIn.GetRand()...)) order := bytes.Compare(oh1, oh2) if order == 0 { return ErrEcho // talking to self (same socket. must be reuseport + dialing self) } s.local.curveT, err = selectBest(order, SupportedExchanges, proposeIn.GetExchanges()) if err != nil { return err } s.local.cipherT, err = selectBest(order, SupportedCiphers, proposeIn.GetCiphers()) if err != nil { return err } s.local.hashT, err = selectBest(order, SupportedHashes, proposeIn.GetHashes()) if err != nil { return err } // we use the same params for both directions (must choose same curve) // WARNING: if they dont SelectBest the same way, this won't work... 
s.remote.curveT = s.local.curveT s.remote.cipherT = s.local.cipherT s.remote.hashT = s.local.hashT // log.Debugf("1.2 selection: exchange:%s cipher:%s hash:%s", // s.local.curveT, s.local.cipherT, s.local.hashT) // ============================================================================= // step 2. Exchange -- exchange (signed) ephemeral keys. verify signatures. // Generate EphemeralPubKey var genSharedKey ci.GenSharedKey s.local.ephemeralPubKey, genSharedKey, err = ci.GenerateEKeyPair(s.local.curveT) // Gather corpus to sign. selectionOut := new(bytes.Buffer) selectionOut.Write(proposeOutBytes) selectionOut.Write(proposeInBytes) selectionOut.Write(s.local.ephemeralPubKey) selectionOutBytes := selectionOut.Bytes() // log.Debugf("2.0 exchange: %v", selectionOutBytes) exchangeOut := new(pb.Exchange) exchangeOut.Epubkey = s.local.ephemeralPubKey exchangeOut.Signature, err = s.localKey.Sign(selectionOutBytes) if err != nil { return err } // Send Propose packet (respects ctx) if _, err := writeMsgCtx(ctx, s.insecureM, exchangeOut); err != nil { return err } // Receive + Parse their Exchange packet. exchangeIn := new(pb.Exchange) if _, err := readMsgCtx(ctx, s.insecureM, exchangeIn); err != nil { return err } // ============================================================================= // step 2.1. Verify -- verify their exchange packet is good. 
// get their ephemeral pub key s.remote.ephemeralPubKey = exchangeIn.GetEpubkey() selectionIn := new(bytes.Buffer) selectionIn.Write(proposeInBytes) selectionIn.Write(proposeOutBytes) selectionIn.Write(s.remote.ephemeralPubKey) selectionInBytes := selectionIn.Bytes() // log.Debugf("2.0.1 exchange recv: %v", selectionInBytes) // u.POut("Remote Peer Identified as %s\n", s.remote) sigOK, err := s.remote.permanentPubKey.Verify(selectionInBytes, exchangeIn.GetSignature()) if err != nil { // log.Error("2.1 Verify: failed: %s", err) return err } if !sigOK { err := errors.New("Bad signature!") // log.Error("2.1 Verify: failed: %s", err) return err } // log.Debugf("2.1 Verify: signature verified.") // ============================================================================= // step 2.2. Keys -- generate keys for mac + encryption // OK! seems like we're good to go. s.sharedSecret, err = genSharedKey(exchangeIn.GetEpubkey()) if err != nil { return err } // generate two sets of keys (stretching) k1, k2 := ci.KeyStretcher(s.local.cipherT, s.local.hashT, s.sharedSecret) // use random nonces to decide order. switch { case order > 0: // just break case order < 0: k1, k2 = k2, k1 // swap default: // we should've bailed before this. but if not, bail here. return ErrEcho } s.local.keys = k1 s.remote.keys = k2 // log.Debug("2.2 keys:\n\tshared: %v\n\tk1: %v\n\tk2: %v", // s.sharedSecret, s.local.keys, s.remote.keys) // ============================================================================= // step 2.3. MAC + Cipher -- prepare MAC + cipher if err := s.local.makeMacAndCipher(); err != nil { return err } if err := s.remote.makeMacAndCipher(); err != nil { return err } // log.Debug("2.3 mac + cipher.") // ============================================================================= // step 3. 
Finish -- send expected message to verify encryption works (send local nonce) // setup ETM ReadWriter w := NewETMWriter(s.insecure, s.local.cipher, s.local.mac) r := NewETMReader(s.insecure, s.remote.cipher, s.remote.mac) s.secure = msgio.Combine(w, r).(msgio.ReadWriteCloser) // log.Debug("3.0 finish. sending: %v", proposeIn.GetRand()) // send their Nonce. if _, err := s.secure.Write(proposeIn.GetRand()); err != nil { return fmt.Errorf("Failed to write Finish nonce: %s", err) } // read our Nonce nonceOut2 := make([]byte, len(nonceOut)) if _, err := io.ReadFull(s.secure, nonceOut2); err != nil { return fmt.Errorf("Failed to read Finish nonce: %s", err) } // log.Debug("3.0 finish.\n\texpect: %v\n\tactual: %v", nonceOut, nonceOut2) if !bytes.Equal(nonceOut, nonceOut2) { return fmt.Errorf("Failed to read our encrypted nonce: %s != %s", nonceOut2, nonceOut) } // Whew! ok, that's all folks. return nil }
func TestGetFailures(t *testing.T) { if testing.Short() { t.SkipNow() } ctx := context.Background() mn, err := mocknet.FullMeshConnected(ctx, 2) if err != nil { t.Fatal(err) } hosts := mn.Hosts() tsds := dssync.MutexWrap(ds.NewMapDatastore()) d := NewDHT(ctx, hosts[0], tsds) d.Update(ctx, hosts[1].ID()) // Reply with failures to every message hosts[1].SetStreamHandler(ProtocolDHT, func(s inet.Stream) { defer s.Close() io.Copy(ioutil.Discard, s) }) // This one should time out ctx1, _ := context.WithTimeout(context.Background(), 200*time.Millisecond) if _, err := d.GetValue(ctx1, key.Key("test")); err != nil { if merr, ok := err.(u.MultiErr); ok && len(merr) > 0 { err = merr[0] } if err != context.DeadlineExceeded && err != context.Canceled { t.Fatal("Got different error than we expected", err) } } else { t.Fatal("Did not get expected error!") } t.Log("Timeout test passed.") // Reply with failures to every message hosts[1].SetStreamHandler(ProtocolDHT, func(s inet.Stream) { defer s.Close() pbr := ggio.NewDelimitedReader(s, inet.MessageSizeMax) pbw := ggio.NewDelimitedWriter(s) pmes := new(pb.Message) if err := pbr.ReadMsg(pmes); err != nil { panic(err) } resp := &pb.Message{ Type: pmes.Type, } if err := pbw.WriteMsg(resp); err != nil { panic(err) } }) // This one should fail with NotFound. // long context timeout to ensure we dont end too early. // the dht should be exhausting its query and returning not found. // (was 3 seconds before which should be _plenty_ of time, but maybe // travis machines really have a hard time...) 
ctx2, _ := context.WithTimeout(context.Background(), 20*time.Second) _, err = d.GetValue(ctx2, key.Key("test")) if err != nil { if merr, ok := err.(u.MultiErr); ok && len(merr) > 0 { err = merr[0] } if err != routing.ErrNotFound { t.Fatalf("Expected ErrNotFound, got: %s", err) } } else { t.Fatal("expected error, got none.") } t.Log("ErrNotFound check passed!") // Now we test this DHT's handleGetValue failure { typ := pb.Message_GET_VALUE str := "hello" sk, err := d.getOwnPrivateKey() if err != nil { t.Fatal(err) } rec, err := record.MakePutRecord(sk, key.Key(str), []byte("blah"), true) if err != nil { t.Fatal(err) } req := pb.Message{ Type: &typ, Key: &str, Record: rec, } s, err := hosts[1].NewStream(ProtocolDHT, hosts[0].ID()) if err != nil { t.Fatal(err) } defer s.Close() pbr := ggio.NewDelimitedReader(s, inet.MessageSizeMax) pbw := ggio.NewDelimitedWriter(s) if err := pbw.WriteMsg(&req); err != nil { t.Fatal(err) } pmes := new(pb.Message) if err := pbr.ReadMsg(pmes); err != nil { t.Fatal(err) } if pmes.GetRecord() != nil { t.Fatal("shouldnt have value") } if pmes.GetProviderPeers() != nil { t.Fatal("shouldnt have provider peers") } } }
// If less than K nodes are in the entire network, it should fail when we make // a GET rpc and nobody has the value func TestLessThanKResponses(t *testing.T) { // t.Skip("skipping test to debug another") // t.Skip("skipping test because it makes a lot of output") ctx := context.Background() mn, err := mocknet.FullMeshConnected(ctx, 6) if err != nil { t.Fatal(err) } hosts := mn.Hosts() tsds := dssync.MutexWrap(ds.NewMapDatastore()) d := NewDHT(ctx, hosts[0], tsds) for i := 1; i < 5; i++ { d.Update(ctx, hosts[i].ID()) } // Reply with random peers to every message for _, host := range hosts { host := host // shadow loop var host.SetStreamHandler(ProtocolDHT, func(s inet.Stream) { defer s.Close() pbr := ggio.NewDelimitedReader(s, inet.MessageSizeMax) pbw := ggio.NewDelimitedWriter(s) pmes := new(pb.Message) if err := pbr.ReadMsg(pmes); err != nil { panic(err) } switch pmes.GetType() { case pb.Message_GET_VALUE: pi := host.Peerstore().PeerInfo(hosts[1].ID()) resp := &pb.Message{ Type: pmes.Type, CloserPeers: pb.PeerInfosToPBPeers(d.host.Network(), []peer.PeerInfo{pi}), } if err := pbw.WriteMsg(resp); err != nil { panic(err) } default: panic("Shouldnt recieve this.") } }) } ctx, _ = context.WithTimeout(ctx, time.Second*30) if _, err := d.GetValue(ctx, key.Key("hello")); err != nil { switch err { case routing.ErrNotFound: //Success! return case u.ErrTimeout: t.Fatal("Should not have gotten timeout!") default: t.Fatalf("Got unexpected error: %s", err) } } t.Fatal("Expected to recieve an error.") }
func TestNotFound(t *testing.T) { // t.Skip("skipping test to debug another") if testing.Short() { t.SkipNow() } ctx := context.Background() mn, err := mocknet.FullMeshConnected(ctx, 16) if err != nil { t.Fatal(err) } hosts := mn.Hosts() tsds := dssync.MutexWrap(ds.NewMapDatastore()) d := NewDHT(ctx, hosts[0], tsds) for _, p := range hosts { d.Update(ctx, p.ID()) } // Reply with random peers to every message for _, host := range hosts { host := host // shadow loop var host.SetStreamHandler(ProtocolDHT, func(s inet.Stream) { defer s.Close() pbr := ggio.NewDelimitedReader(s, inet.MessageSizeMax) pbw := ggio.NewDelimitedWriter(s) pmes := new(pb.Message) if err := pbr.ReadMsg(pmes); err != nil { panic(err) } switch pmes.GetType() { case pb.Message_GET_VALUE: resp := &pb.Message{Type: pmes.Type} ps := []peer.PeerInfo{} for i := 0; i < 7; i++ { p := hosts[rand.Intn(len(hosts))].ID() pi := host.Peerstore().PeerInfo(p) ps = append(ps, pi) } resp.CloserPeers = pb.PeerInfosToPBPeers(d.host.Network(), ps) if err := pbw.WriteMsg(resp); err != nil { panic(err) } default: panic("Shouldnt recieve this.") } }) } // long timeout to ensure timing is not at play. ctx, _ = context.WithTimeout(ctx, time.Second*20) v, err := d.GetValue(ctx, key.Key("hello")) log.Debugf("get value got %v", v) if err != nil { if merr, ok := err.(u.MultiErr); ok && len(merr) > 0 { err = merr[0] } switch err { case routing.ErrNotFound: //Success! return case u.ErrTimeout: t.Fatal("Should not have gotten timeout!") default: t.Fatalf("Got unexpected error: %s", err) } } t.Fatal("Expected to recieve an error.") }