func TestStopping(t *testing.T) { ctx := context.Background() // setup p1 := &TestProtocol{Pipe: msg.NewPipe(10)} p2 := &TestProtocol{Pipe: msg.NewPipe(10)} pid1 := pb.ProtocolID_Test pid2 := pb.ProtocolID_Identify mux1 := NewMuxer(ctx, ProtocolMap{ pid1: p1, pid2: p2, }) peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa") // peer2 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275bbbbbb") // test outgoing p1 for _, s := range []string{"foo1", "bar1", "baz1"} { p1.Outgoing <- msg.New(peer1, []byte(s)) testWrappedMsg(t, <-mux1.Outgoing, pid1, []byte(s)) } // test incoming p1 for _, s := range []string{"foo2", "bar2", "baz2"} { d, err := wrapData([]byte(s), pid1) if err != nil { t.Error(err) } mux1.Incoming <- msg.New(peer1, d) testMsg(t, <-p1.Incoming, []byte(s)) } mux1.Close() // waits // test outgoing p1 for _, s := range []string{"foo3", "bar3", "baz3"} { p1.Outgoing <- msg.New(peer1, []byte(s)) select { case m := <-mux1.Outgoing: t.Errorf("should not have received anything. Got: %v", string(m.Data())) case <-time.After(time.Millisecond): } } // test incoming p1 for _, s := range []string{"foo4", "bar4", "baz4"} { d, err := wrapData([]byte(s), pid1) if err != nil { t.Error(err) } mux1.Incoming <- msg.New(peer1, d) select { case <-p1.Incoming: t.Error("should not have received anything.") case <-time.After(time.Millisecond): } } }
// Handles the receiving + wrapping of messages, per conn. // Consider using reflect.Select with one goroutine instead of n. func (s *Swarm) fanInSingle(c conn.Conn) { // cleanup all data associated with this child Connection. defer func() { // remove it from the map. s.connsLock.Lock() delete(s.conns, c.RemotePeer().Key()) s.connsLock.Unlock() s.Children().Done() c.Children().Done() // child of Conn as well. }() i := 0 for { select { case <-s.Closing(): // Swarm closing return case <-c.Closing(): // Conn closing return case data, ok := <-c.In(): if !ok { log.Infof("%s in channel closed", c) return // channel closed. } i++ log.Debugf("%s received message from %s (%d)", s.local, c.RemotePeer(), i) s.Incoming <- msg.New(c.RemotePeer(), data) } } }
func TestServiceRequestTimeout(t *testing.T) { ctx, _ := context.WithTimeout(context.Background(), time.Millisecond) s1 := NewService(ctx, &ReverseHandler{}) s2 := NewService(ctx, &ReverseHandler{}) peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa") // patch services together go func() { for { <-time.After(time.Millisecond) select { case m := <-s1.GetPipe().Outgoing: s2.GetPipe().Incoming <- m case m := <-s2.GetPipe().Outgoing: s1.GetPipe().Incoming <- m case <-ctx.Done(): return } } }() m1 := msg.New(peer1, []byte("beep")) m2, err := s1.SendRequest(ctx, m1) if err == nil || m2 != nil { t.Error("should've timed out") } }
func TestServiceRequest(t *testing.T) { ctx := context.Background() s1 := NewService(ctx, &ReverseHandler{}) s2 := NewService(ctx, &ReverseHandler{}) peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa") // patch services together go func() { for { select { case m := <-s1.GetPipe().Outgoing: s2.GetPipe().Incoming <- m case m := <-s2.GetPipe().Outgoing: s1.GetPipe().Incoming <- m case <-ctx.Done(): return } } }() m1 := msg.New(peer1, []byte("beep")) m2, err := s1.SendRequest(ctx, m1) if err != nil { t.Error(err) } if !bytes.Equal(m2.Data(), []byte("peeb")) { t.Errorf("service handler data incorrect: %v != %v", m2.Data(), "oof") } }
func TestServiceHandler(t *testing.T) { ctx := context.Background() h := &ReverseHandler{} s := NewService(ctx, h) peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa") d, err := wrapData([]byte("beep"), nil) if err != nil { t.Error(err) } m1 := msg.New(peer1, d) s.GetPipe().Incoming <- m1 m2 := <-s.GetPipe().Outgoing d, rid, err := unwrapData(m2.Data()) if err != nil { t.Error(err) } if rid != nil { t.Error("RequestID should be nil") } if !bytes.Equal(d, []byte("peeb")) { t.Errorf("service handler data incorrect: %v != %v", d, "oof") } }
// handleIncomingMessage routes message to the appropriate protocol. func (m *Muxer) handleIncomingMessage(m1 msg.NetMessage) { defer m.Children().Done() m.bwiLock.Lock() // TODO: compensate for overhead m.bwIn += uint64(len(m1.Data())) m.bwiLock.Unlock() data, pid, err := unwrapData(m1.Data()) if err != nil { log.Errorf("muxer de-serializing error: %v", err) return } conn.ReleaseBuffer(m1.Data()) m2 := msg.New(m1.Peer(), data) proto, found := m.Protocols[pid] if !found { log.Errorf("muxer unknown protocol %v", pid) return } select { case proto.GetPipe().Incoming <- m2: case <-m.Closing(): return } }
func TestSimpleMuxer(t *testing.T) { ctx := context.Background() // setup p1 := &TestProtocol{Pipe: msg.NewPipe(10)} p2 := &TestProtocol{Pipe: msg.NewPipe(10)} pid1 := pb.ProtocolID_Test pid2 := pb.ProtocolID_Routing mux1 := NewMuxer(ctx, ProtocolMap{ pid1: p1, pid2: p2, }) peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa") // peer2 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275bbbbbb") // test outgoing p1 for _, s := range []string{"foo", "bar", "baz"} { p1.Outgoing <- msg.New(peer1, []byte(s)) testWrappedMsg(t, <-mux1.Outgoing, pid1, []byte(s)) } // test incoming p1 for _, s := range []string{"foo", "bar", "baz"} { d, err := wrapData([]byte(s), pid1) if err != nil { t.Error(err) } mux1.Incoming <- msg.New(peer1, d) testMsg(t, <-p1.Incoming, []byte(s)) } // test outgoing p2 for _, s := range []string{"foo", "bar", "baz"} { p2.Outgoing <- msg.New(peer1, []byte(s)) testWrappedMsg(t, <-mux1.Outgoing, pid2, []byte(s)) } // test incoming p2 for _, s := range []string{"foo", "bar", "baz"} { d, err := wrapData([]byte(s), pid2) if err != nil { t.Error(err) } mux1.Incoming <- msg.New(peer1, d) testMsg(t, <-p2.Incoming, []byte(s)) } }
func (t *ReverseHandler) HandleMessage(ctx context.Context, m msg.NetMessage) msg.NetMessage { d := m.Data() for i, j := 0, len(d)-1; i < j; i, j = i+1, j-1 { d[i], d[j] = d[j], d[i] } return msg.New(m.Peer(), d) }
// handleIncomingMessage unwraps a message delivered to the service and
// dispatches it:
//   - a request, or a message carrying no RequestID, goes to the service's
//     handler, and any non-nil reply is sent back out tagged with
//     rid.Response();
//   - a response goes to the pending request registered under
//     RequestKey(peer, rid), if one is still present.
//
// Decode failures, a missing handler and unknown request keys are logged and
// the message is dropped. It calls s.Children().Done() on return.
func (s *service) handleIncomingMessage(m msg.NetMessage) {
	defer s.Children().Done()

	// unwrap the incoming message
	data, rid, err := unwrapData(m.Data())
	if err != nil {
		log.Errorf("service de-serializing error: %v", err)
		return
	}

	m2 := msg.New(m.Peer(), data)

	// if it's a request (or has no RequestID), handle it
	if rid == nil || rid.IsRequest() {
		handler := s.GetHandler()
		if handler == nil {
			log.Errorf("service dropped msg: %v", m)
			return // no handler, drop it.
		}

		// should this be "go HandleMessage ... ?"
		// NOTE(review): the handler is called synchronously, so this call
		// does not return until the handler does.
		r1 := handler.HandleMessage(s.Context(), m2)

		// if handler gave us a response, send it back out!
		// NOTE(review): when rid is nil this calls Response() on a nil
		// RequestID; presumably Response is nil-safe. Confirm against its
		// definition.
		if r1 != nil {
			err := s.sendMessage(s.Context(), r1, rid.Response())
			if err != nil {
				log.Errorf("error sending response message: %v", err)
			}
		}
		return
	}

	// Otherwise, it is a response. handle it.
	// This check only logs; the message is still processed as a response.
	if !rid.IsResponse() {
		log.Errorf("RequestID should identify a response here.")
	}

	// Find the pending request this response answers.
	key := RequestKey(m.Peer().ID(), RequestID(rid))
	s.RequestsLock.RLock()
	r, found := s.Requests[key]
	s.RequestsLock.RUnlock()

	if !found {
		log.Errorf("no request key %v (timeout?)", []byte(key))
		return
	}

	// Hand the response to the waiting request, unless the service starts
	// closing first.
	select {
	case r.Response <- m2:
	case <-s.Closing():
	}
}
func pong(ctx context.Context, swarm *Swarm) { i := 0 for { select { case <-ctx.Done(): return case m1 := <-swarm.Incoming: if bytes.Equal(m1.Data(), []byte("ping")) { m2 := msg.New(m1.Peer(), []byte("pong")) i++ log.Debugf("%s pong %s (%d)", swarm.local, m1.Peer(), i) swarm.Outgoing <- m2 } } } }
// sendMessage sends a message out (actual leg work. SendMessage is to export w/o rid) func (s *service) sendMessage(ctx context.Context, m msg.NetMessage, rid RequestID) error { // serialize ServiceMessage wrapper data, err := wrapData(m.Data(), rid) if err != nil { return err } // log.Debug("Service send message [to = %s]", m.Peer()) // send message m2 := msg.New(m.Peer(), data) select { case s.Outgoing <- m2: case <-ctx.Done(): return ctx.Err() } return nil }
// handleOutgoingMessage wraps out a message and sends it out the func (m *Muxer) handleOutgoingMessage(pid pb.ProtocolID, m1 msg.NetMessage) { defer m.Children().Done() data, err := wrapData(m1.Data(), pid) if err != nil { log.Errorf("muxer serializing error: %v", err) return } m.bwoLock.Lock() // TODO: compensate for overhead // TODO(jbenet): switch this to a goroutine to prevent sync waiting. m.bwOut += uint64(len(data)) m.bwoLock.Unlock() m2 := msg.New(m1.Peer(), data) select { case m.GetPipe().Outgoing <- m2: case <-m.Closing(): return } }
func SubtestSwarm(t *testing.T, addrs []string, MsgNum int) { // t.Skip("skipping for another test") ctx := context.Background() swarms, peers := makeSwarms(ctx, t, addrs) // connect everyone { var wg sync.WaitGroup connect := func(s *Swarm, dst peer.Peer) { // copy for other peer cp, err := s.peers.Get(dst.ID()) if err != nil { t.Fatal(err) } cp.AddAddress(dst.Addresses()[0]) log.Info("SWARM TEST: %s dialing %s", s.local, dst) if _, err := s.Dial(cp); err != nil { t.Fatal("error swarm dialing to peer", err) } log.Info("SWARM TEST: %s connected to %s", s.local, dst) wg.Done() } log.Info("Connecting swarms simultaneously.") for _, s := range swarms { for _, p := range peers { if p != s.local { // don't connect to self. wg.Add(1) connect(s, p) } } } wg.Wait() } // ping/pong for _, s1 := range swarms { ctx, cancel := context.WithCancel(ctx) // setup all others to pong for _, s2 := range swarms { if s1 == s2 { continue } go pong(ctx, s2) } peers, err := s1.peers.All() if err != nil { t.Fatal(err) } for k := 0; k < MsgNum; k++ { for _, p := range *peers { log.Debugf("%s ping %s (%d)", s1.local, p, k) s1.Outgoing <- msg.New(p, []byte("ping")) } } got := map[u.Key]int{} for k := 0; k < (MsgNum * len(*peers)); k++ { log.Debugf("%s waiting for pong (%d)", s1.local, k) msg := <-s1.Incoming if string(msg.Data()) != "pong" { t.Error("unexpected conn output", msg.Data) } n, _ := got[msg.Peer().Key()] got[msg.Peer().Key()] = n + 1 } if len(*peers) != len(got) { t.Error("got less messages than sent") } for p, n := range got { if n != MsgNum { t.Error("peer did not get all msgs", p, n, "/", MsgNum) } } cancel() <-time.After(50 * time.Microsecond) } for _, s := range swarms { s.Close() } }
func TestSimultMuxer(t *testing.T) { if testing.Short() { t.SkipNow() } // run muxer ctx, cancel := context.WithCancel(context.Background()) // setup p1 := &TestProtocol{Pipe: msg.NewPipe(10)} p2 := &TestProtocol{Pipe: msg.NewPipe(10)} pid1 := pb.ProtocolID_Test pid2 := pb.ProtocolID_Identify mux1 := NewMuxer(ctx, ProtocolMap{ pid1: p1, pid2: p2, }) peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa") // peer2 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275bbbbbb") // counts total := 10000 speed := time.Microsecond * 1 counts := [2][2][2]int{} var countsLock sync.Mutex // run producers at every end sending incrementing messages produceOut := func(pid pb.ProtocolID, size int) { limiter := time.Tick(speed) for i := 0; i < size; i++ { <-limiter s := fmt.Sprintf("proto %v out %v", pid, i) m := msg.New(peer1, []byte(s)) mux1.Protocols[pid].GetPipe().Outgoing <- m countsLock.Lock() counts[pid][0][0]++ countsLock.Unlock() // log.Debug("sent %v", s) } } produceIn := func(pid pb.ProtocolID, size int) { limiter := time.Tick(speed) for i := 0; i < size; i++ { <-limiter s := fmt.Sprintf("proto %v in %v", pid, i) d, err := wrapData([]byte(s), pid) if err != nil { t.Error(err) } m := msg.New(peer1, d) mux1.Incoming <- m countsLock.Lock() counts[pid][1][0]++ countsLock.Unlock() // log.Debug("sent %v", s) } } consumeOut := func() { for { select { case m := <-mux1.Outgoing: data, pid, err := unwrapData(m.Data()) if err != nil { t.Error(err) } // log.Debug("got %v", string(data)) _ = data countsLock.Lock() counts[pid][1][1]++ countsLock.Unlock() case <-ctx.Done(): return } } } consumeIn := func(pid pb.ProtocolID) { for { select { case m := <-mux1.Protocols[pid].GetPipe().Incoming: countsLock.Lock() counts[pid][0][1]++ countsLock.Unlock() // log.Debug("got %v", string(m.Data())) _ = m case <-ctx.Done(): return } } } go produceOut(pid1, total) go produceOut(pid2, total) go produceIn(pid1, total) go produceIn(pid2, total) go consumeOut() go consumeIn(pid1) go 
consumeIn(pid2) limiter := time.Tick(speed) for { <-limiter countsLock.Lock() got := counts[0][0][0] + counts[0][0][1] + counts[0][1][0] + counts[0][1][1] + counts[1][0][0] + counts[1][0][1] + counts[1][1][0] + counts[1][1][1] countsLock.Unlock() if got == total*8 { cancel() return } } }