// Tests that trusted peers and can connect above max peer caps. func TestServerTrustedPeers(t *testing.T) { defer testlog(t).detach() // Create a trusted peer to accept connections from key := newkey() trusted := &discover.Node{ ID: discover.PubkeyID(&key.PublicKey), } // Create a test server with limited connection slots started := make(chan *Peer) server := &Server{ ListenAddr: "127.0.0.1:0", PrivateKey: newkey(), MaxPeers: 3, NoDial: true, TrustedNodes: []*discover.Node{trusted}, newPeerHook: func(p *Peer) { started <- p }, } if err := server.Start(); err != nil { t.Fatal(err) } defer server.Stop() // Fill up all the slots on the server dialer := &net.Dialer{Deadline: time.Now().Add(3 * time.Second)} for i := 0; i < server.MaxPeers; i++ { // Establish a new connection conn, err := dialer.Dial("tcp", server.ListenAddr) if err != nil { t.Fatalf("conn %d: dial error: %v", i, err) } defer conn.Close() // Run the handshakes just like a real peer would, and wait for completion key := newkey() shake := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)} if _, err = setupConn(conn, key, shake, server.Self(), keepalways); err != nil { t.Fatalf("conn %d: unexpected error: %v", i, err) } <-started } // Dial from the trusted peer, ensure connection is accepted conn, err := dialer.Dial("tcp", server.ListenAddr) if err != nil { t.Fatalf("trusted node: dial error: %v", err) } defer conn.Close() shake := &protoHandshake{Version: baseProtocolVersion, ID: trusted.ID} if _, err = setupConn(conn, key, shake, server.Self(), keepalways); err != nil { t.Fatalf("trusted node: unexpected error: %v", err) } select { case <-started: // Ok, trusted peer accepted case <-time.After(100 * time.Millisecond): t.Fatalf("trusted node timeout") } }
// Self returns the local node's endpoint information. func (srv *Server) Self() *discover.Node { srv.lock.Lock() defer srv.lock.Unlock() // If the server's not running, return an empty node if !srv.running { return &discover.Node{IP: net.ParseIP("0.0.0.0")} } // If the node is running but discovery is off, manually assemble the node infos if srv.ntab == nil { // Inbound connections disabled, use zero address if srv.listener == nil { return &discover.Node{IP: net.ParseIP("0.0.0.0"), ID: discover.PubkeyID(&srv.PrivateKey.PublicKey)} } // Otherwise inject the listener address too addr := srv.listener.Addr().(*net.TCPAddr) return &discover.Node{ ID: discover.PubkeyID(&srv.PrivateKey.PublicKey), IP: addr.IP, TCP: uint16(addr.Port), } } // Otherwise return the live node infos return srv.ntab.Self() }
func testEncHandshake(token []byte) error { type result struct { side string s secrets err error } var ( prv0, _ = crypto.GenerateKey() prv1, _ = crypto.GenerateKey() rw0, rw1 = net.Pipe() output = make(chan result) ) go func() { r := result{side: "initiator"} defer func() { output <- r }() pub1s := discover.PubkeyID(&prv1.PublicKey) r.s, r.err = initiatorEncHandshake(rw0, prv0, pub1s, token) if r.err != nil { return } id1 := discover.PubkeyID(&prv1.PublicKey) if r.s.RemoteID != id1 { r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id1) } }() go func() { r := result{side: "receiver"} defer func() { output <- r }() r.s, r.err = receiverEncHandshake(rw1, prv1, token) if r.err != nil { return } id0 := discover.PubkeyID(&prv0.PublicKey) if r.s.RemoteID != id0 { r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id0) } }() // wait for results from both sides r1, r2 := <-output, <-output if r1.err != nil { return fmt.Errorf("%s side error: %v", r1.side, r1.err) } if r2.err != nil { return fmt.Errorf("%s side error: %v", r2.side, r2.err) } // don't compare remote node IDs r1.s.RemoteID, r2.s.RemoteID = discover.NodeID{}, discover.NodeID{} // flip MACs on one of them so they compare equal r1.s.EgressMAC, r1.s.IngressMAC = r1.s.IngressMAC, r1.s.EgressMAC if !reflect.DeepEqual(r1.s, r2.s) { return fmt.Errorf("secrets mismatch:\n t1: %#v\n t2: %#v", r1.s, r2.s) } return nil }
func TestServerMaxPendingAccepts(t *testing.T) { defer testlog(t).detach() // Start a test server and a peer sink for synchronization started := make(chan *Peer) server := &Server{ ListenAddr: "127.0.0.1:0", PrivateKey: newkey(), MaxPeers: 10, MaxPendingPeers: 1, NoDial: true, newPeerHook: func(p *Peer) { started <- p }, } if err := server.Start(); err != nil { t.Fatal("failed to start test server: %v", err) } defer server.Stop() // Try and connect to the server on multiple threads concurrently conns := make([]net.Conn, 2) for i := 0; i < 2; i++ { dialer := &net.Dialer{Deadline: time.Now().Add(3 * time.Second)} conn, err := dialer.Dial("tcp", server.ListenAddr) if err != nil { t.Fatalf("failed to dial server: %v", err) } conns[i] = conn } // Check that a handshake on the second doesn't pass go func() { key := newkey() shake := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)} if _, err := setupConn(conns[1], key, shake, server.Self(), keepalways); err != nil { t.Fatalf("failed to run handshake: %v", err) } }() select { case <-started: t.Fatalf("handshake on second connection accepted") case <-time.After(time.Second): } // Shake on first, check that both go through go func() { key := newkey() shake := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)} if _, err := setupConn(conns[0], key, shake, server.Self(), keepalways); err != nil { t.Fatalf("failed to run handshake: %v", err) } }() for i := 0; i < 2; i++ { select { case <-started: case <-time.After(time.Second): t.Fatalf("peer %d: handshake timeout", i) } } }
func TestSetupConn(t *testing.T) { prv0, _ := crypto.GenerateKey() prv1, _ := crypto.GenerateKey() node0 := &discover.Node{ ID: discover.PubkeyID(&prv0.PublicKey), IP: net.IP{1, 2, 3, 4}, TCP: 33, } node1 := &discover.Node{ ID: discover.PubkeyID(&prv1.PublicKey), IP: net.IP{5, 6, 7, 8}, TCP: 44, } hs0 := &protoHandshake{ Version: baseProtocolVersion, ID: node0.ID, Caps: []Cap{{"a", 0}, {"b", 2}}, } hs1 := &protoHandshake{ Version: baseProtocolVersion, ID: node1.ID, Caps: []Cap{{"c", 1}, {"d", 3}}, } fd0, fd1 := net.Pipe() done := make(chan struct{}) keepalways := func(discover.NodeID) bool { return true } go func() { defer close(done) conn0, err := setupConn(fd0, prv0, hs0, node1, keepalways) if err != nil { t.Errorf("outbound side error: %v", err) return } if conn0.ID != node1.ID { t.Errorf("outbound conn id mismatch: got %v, want %v", conn0.ID, node1.ID) } if !reflect.DeepEqual(conn0.Caps, hs1.Caps) { t.Errorf("outbound caps mismatch: got %v, want %v", conn0.Caps, hs1.Caps) } }() conn1, err := setupConn(fd1, prv1, hs1, nil, keepalways) if err != nil { t.Fatalf("inbound side error: %v", err) } if conn1.ID != node0.ID { t.Errorf("inbound conn id mismatch: got %v, want %v", conn1.ID, node0.ID) } if !reflect.DeepEqual(conn1.Caps, hs0.Caps) { t.Errorf("inbound caps mismatch: got %v, want %v", conn1.Caps, hs0.Caps) } <-done }
// This test checks that connections are disconnected // just after the encryption handshake when the server is // at capacity. // // It also serves as a light-weight integration test. func TestServerDisconnectAtCap(t *testing.T) { defer testlog(t).detach() started := make(chan *Peer) srv := &Server{ ListenAddr: "127.0.0.1:0", PrivateKey: newkey(), MaxPeers: 10, NoDial: true, // This hook signals that the peer was actually started. We // need to wait for the peer to be started before dialing the // next connection to get a deterministic peer count. newPeerHook: func(p *Peer) { started <- p }, } if err := srv.Start(); err != nil { t.Fatal(err) } defer srv.Stop() nconns := srv.MaxPeers + 1 dialer := &net.Dialer{Deadline: time.Now().Add(3 * time.Second)} for i := 0; i < nconns; i++ { conn, err := dialer.Dial("tcp", srv.ListenAddr) if err != nil { t.Fatalf("conn %d: dial error: %v", i, err) } // Close the connection when the test ends, before // shutting down the server. defer conn.Close() // Run the handshakes just like a real peer would. key := newkey() hs := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)} _, err = setupConn(conn, key, hs, srv.Self(), keepalways) if i == nconns-1 { // When handling the last connection, the server should // disconnect immediately instead of running the protocol // handshake. if err != DiscTooManyPeers { t.Errorf("conn %d: got error %q, expected %q", i, err, DiscTooManyPeers) } } else { // For all earlier connections, the handshake should go through. if err != nil { t.Fatalf("conn %d: unexpected error: %v", i, err) } // Wait for runPeer to be started. <-started } } }
// Start starts running the server. // Servers can not be re-used after stopping. func (srv *Server) Start() (err error) { srv.lock.Lock() defer srv.lock.Unlock() if srv.running { return errors.New("server already running") } srv.running = true glog.V(logger.Info).Infoln("Starting Server") // static fields if srv.PrivateKey == nil { return fmt.Errorf("Server.PrivateKey must be set to a non-nil key") } if srv.newTransport == nil { srv.newTransport = newRLPX } if srv.Dialer == nil { srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout} } srv.quit = make(chan struct{}) srv.addpeer = make(chan *conn) srv.delpeer = make(chan *Peer) srv.posthandshake = make(chan *conn) srv.addstatic = make(chan *discover.Node) srv.peerOp = make(chan peerOpFunc) srv.peerOpDone = make(chan struct{}) // node table if srv.Discovery { ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase) if err != nil { return err } if err := ntab.SetFallbackNodes(srv.BootstrapNodes); err != nil { return err } srv.ntab = ntab } dynPeers := (srv.MaxPeers + 1) / 2 if !srv.Discovery { dynPeers = 0 } dialer := newDialState(srv.StaticNodes, srv.ntab, dynPeers) // handshake srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: discover.PubkeyID(&srv.PrivateKey.PublicKey)} for _, p := range srv.Protocols { srv.ourHandshake.Caps = append(srv.ourHandshake.Caps, p.cap()) } // listen/dial if srv.ListenAddr != "" { if err := srv.startListening(); err != nil { return err } } if srv.NoDial && srv.ListenAddr == "" { glog.V(logger.Warn).Infoln("I will be kind-of useless, neither dialing nor listening.") } srv.loopWG.Add(1) go srv.run(dialer) srv.running = true return nil }
func testEncHandshake(token []byte) error { type result struct { side string id discover.NodeID err error } var ( prv0, _ = crypto.GenerateKey() prv1, _ = crypto.GenerateKey() fd0, fd1 = net.Pipe() c0, c1 = newRLPX(fd0).(*rlpx), newRLPX(fd1).(*rlpx) output = make(chan result) ) go func() { r := result{side: "initiator"} defer func() { output <- r }() defer fd0.Close() dest := &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey)} r.id, r.err = c0.doEncHandshake(prv0, dest) if r.err != nil { return } id1 := discover.PubkeyID(&prv1.PublicKey) if r.id != id1 { r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id1) } }() go func() { r := result{side: "receiver"} defer func() { output <- r }() defer fd1.Close() r.id, r.err = c1.doEncHandshake(prv1, nil) if r.err != nil { return } id0 := discover.PubkeyID(&prv0.PublicKey) if r.id != id0 { r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id0) } }() // wait for results from both sides r1, r2 := <-output, <-output if r1.err != nil { return fmt.Errorf("%s side error: %v", r1.side, r1.err) } if r2.err != nil { return fmt.Errorf("%s side error: %v", r2.side, r2.err) } // compare derived secrets if !reflect.DeepEqual(c0.rw.egressMAC, c1.rw.ingressMAC) { return fmt.Errorf("egress mac mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.egressMAC, c1.rw.ingressMAC) } if !reflect.DeepEqual(c0.rw.ingressMAC, c1.rw.egressMAC) { return fmt.Errorf("ingress mac mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.ingressMAC, c1.rw.egressMAC) } if !reflect.DeepEqual(c0.rw.enc, c1.rw.enc) { return fmt.Errorf("enc cipher mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.enc, c1.rw.enc) } if !reflect.DeepEqual(c0.rw.dec, c1.rw.dec) { return fmt.Errorf("dec cipher mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.dec, c1.rw.dec) } return nil }
func TestProtocolHandshake(t *testing.T) { var ( prv0, _ = crypto.GenerateKey() node0 = &discover.Node{ID: discover.PubkeyID(&prv0.PublicKey), IP: net.IP{1, 2, 3, 4}, TCP: 33} hs0 = &protoHandshake{Version: 3, ID: node0.ID, Caps: []Cap{{"a", 0}, {"b", 2}}} prv1, _ = crypto.GenerateKey() node1 = &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey), IP: net.IP{5, 6, 7, 8}, TCP: 44} hs1 = &protoHandshake{Version: 3, ID: node1.ID, Caps: []Cap{{"c", 1}, {"d", 3}}} fd0, fd1 = net.Pipe() wg sync.WaitGroup ) wg.Add(2) go func() { defer wg.Done() defer fd1.Close() rlpx := newRLPX(fd0) remid, err := rlpx.doEncHandshake(prv0, node1) if err != nil { t.Errorf("dial side enc handshake failed: %v", err) return } if remid != node1.ID { t.Errorf("dial side remote id mismatch: got %v, want %v", remid, node1.ID) return } phs, err := rlpx.doProtoHandshake(hs0) if err != nil { t.Errorf("dial side proto handshake error: %v", err) return } phs.Rest = nil if !reflect.DeepEqual(phs, hs1) { t.Errorf("dial side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs1)) return } rlpx.close(DiscQuitting) }() go func() { defer wg.Done() defer fd1.Close() rlpx := newRLPX(fd1) remid, err := rlpx.doEncHandshake(prv1, nil) if err != nil { t.Errorf("listen side enc handshake failed: %v", err) return } if remid != node0.ID { t.Errorf("listen side remote id mismatch: got %v, want %v", remid, node0.ID) return } phs, err := rlpx.doProtoHandshake(hs1) if err != nil { t.Errorf("listen side proto handshake error: %v", err) return } phs.Rest = nil if !reflect.DeepEqual(phs, hs0) { t.Errorf("listen side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs0)) return } if err := ExpectMsg(rlpx, discMsg, []DiscReason{DiscQuitting}); err != nil { t.Errorf("error receiving disconnect: %v", err) } }() wg.Wait() }
// Tests that a failed dial will temporarily throttle a peer. func TestServerMaxPendingDials(t *testing.T) { defer testlog(t).detach() // Start a simple test server server := &Server{ ListenAddr: "127.0.0.1:0", PrivateKey: newkey(), MaxPeers: 10, MaxPendingPeers: 1, } if err := server.Start(); err != nil { t.Fatal("failed to start test server: %v", err) } defer server.Stop() // Simulate two separate remote peers peers := make(chan *discover.Node, 2) conns := make(chan net.Conn, 2) for i := 0; i < 2; i++ { listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("listener %d: failed to setup: %v", i, err) } defer listener.Close() addr := listener.Addr().(*net.TCPAddr) peers <- &discover.Node{ ID: discover.PubkeyID(&newkey().PublicKey), IP: addr.IP, TCP: uint16(addr.Port), } go func() { conn, err := listener.Accept() if err == nil { conns <- conn } }() } // Request a dial for both peers go func() { for i := 0; i < 2; i++ { server.staticDial <- <-peers // hack piggybacking the static implementation } }() // Make sure only one outbound connection goes through var conn net.Conn select { case conn = <-conns: case <-time.After(100 * time.Millisecond): t.Fatalf("first dial timeout") } select { case conn = <-conns: t.Fatalf("second dial completed prematurely") case <-time.After(100 * time.Millisecond): } // Finish the first dial, check the second conn.Close() select { case conn = <-conns: conn.Close() case <-time.After(100 * time.Millisecond): t.Fatalf("second dial timeout") } }
// Tests that static peers are (re)connected, and done so even above max peers. func TestServerStaticPeers(t *testing.T) { defer testlog(t).detach() // Create a test server with limited connection slots started := make(chan *Peer) server := &Server{ ListenAddr: "127.0.0.1:0", PrivateKey: newkey(), MaxPeers: 3, newPeerHook: func(p *Peer) { started <- p }, staticCycle: time.Second, } if err := server.Start(); err != nil { t.Fatal(err) } defer server.Stop() // Fill up all the slots on the server dialer := &net.Dialer{Deadline: time.Now().Add(3 * time.Second)} for i := 0; i < server.MaxPeers; i++ { // Establish a new connection conn, err := dialer.Dial("tcp", server.ListenAddr) if err != nil { t.Fatalf("conn %d: dial error: %v", i, err) } defer conn.Close() // Run the handshakes just like a real peer would, and wait for completion key := newkey() shake := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)} if _, err = setupConn(conn, key, shake, server.Self(), keepalways); err != nil { t.Fatalf("conn %d: unexpected error: %v", i, err) } <-started } // Open a TCP listener to accept static connections listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("failed to setup listener: %v", err) } defer listener.Close() connected := make(chan net.Conn) go func() { for i := 0; i < 3; i++ { conn, err := listener.Accept() if err == nil { connected <- conn } } }() // Inject a static node and wait for a remote dial, then redial, then nothing addr := listener.Addr().(*net.TCPAddr) static := &discover.Node{ ID: discover.PubkeyID(&newkey().PublicKey), IP: addr.IP, TCP: uint16(addr.Port), } server.AddPeer(static) select { case conn := <-connected: // Close the first connection, expect redial conn.Close() case <-time.After(2 * server.staticCycle): t.Fatalf("remote dial timeout") } select { case conn := <-connected: // Keep the second connection, don't expect redial defer conn.Close() case <-time.After(2 * server.staticCycle): t.Fatalf("remote re-dial timeout") } select { case <-time.After(2 * server.staticCycle): // Timeout as no dial occurred case <-connected: t.Fatalf("connected node dialed") } }
func TestServerSetupConn(t *testing.T) { id := randomID() srvkey := newkey() srvid := discover.PubkeyID(&srvkey.PublicKey) tests := []struct { dontstart bool tt *setupTransport flags connFlag dialDest *discover.Node wantCloseErr error wantCalls string }{ { dontstart: true, tt: &setupTransport{id: id}, wantCalls: "close,", wantCloseErr: errServerStopped, }, { tt: &setupTransport{id: id, encHandshakeErr: errors.New("read error")}, flags: inboundConn, wantCalls: "doEncHandshake,close,", wantCloseErr: errors.New("read error"), }, { tt: &setupTransport{id: id}, dialDest: &discover.Node{ID: randomID()}, flags: dynDialedConn, wantCalls: "doEncHandshake,close,", wantCloseErr: DiscUnexpectedIdentity, }, { tt: &setupTransport{id: id, phs: &protoHandshake{ID: randomID()}}, dialDest: &discover.Node{ID: id}, flags: dynDialedConn, wantCalls: "doEncHandshake,doProtoHandshake,close,", wantCloseErr: DiscUnexpectedIdentity, }, { tt: &setupTransport{id: id, protoHandshakeErr: errors.New("foo")}, dialDest: &discover.Node{ID: id}, flags: dynDialedConn, wantCalls: "doEncHandshake,doProtoHandshake,close,", wantCloseErr: errors.New("foo"), }, { tt: &setupTransport{id: srvid, phs: &protoHandshake{ID: srvid}}, flags: inboundConn, wantCalls: "doEncHandshake,close,", wantCloseErr: DiscSelf, }, { tt: &setupTransport{id: id, phs: &protoHandshake{ID: id}}, flags: inboundConn, wantCalls: "doEncHandshake,doProtoHandshake,close,", wantCloseErr: DiscUselessPeer, }, } for i, test := range tests { srv := &Server{ Config: Config{ PrivateKey: srvkey, MaxPeers: 10, NoDial: true, Protocols: []Protocol{discard}, }, newTransport: func(fd net.Conn) transport { return test.tt }, } if !test.dontstart { if err := srv.Start(); err != nil { t.Fatalf("couldn't start server: %v", err) } } p1, _ := net.Pipe() srv.setupConn(p1, test.flags, test.dialDest) if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) { t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr) } if test.tt.calls != test.wantCalls { t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls) } } }