func (t *transport) handlePing(ch *e3x.Channel) { var ( err error pkt *lob.Packet id string alive bool ) defer ch.Close() pkt, err = ch.ReadPacket() if err != nil { // log error // tracef("error: %s", err) return } id, _ = pkt.Header().GetString("vn") rpc := t.lookupRPC(id) if rpc == nil { alive = false } else { alive = true } pkt = &lob.Packet{} pkt.Header().SetBool("alive", alive) err = ch.WritePacket(pkt) if err != nil { // log error // tracef("error: %s", err) return } }
// traceReceivedPacket emits a "channel.rcv.packet" trace event for pkt,
// carrying the channel id, the packet id, and the packet header plus its
// base64-encoded body. It is a no-op while tracing is disabled.
// pkt must be non-nil.
func (c *Channel) traceReceivedPacket(pkt *lob.Packet) {
	if !tracer.Enabled {
		return
	}

	packetInfo := tracer.Info{
		"header": pkt.Header(),
		"body":   base64.StdEncoding.EncodeToString(pkt.Body(nil)),
	}

	tracer.Emit("channel.rcv.packet", tracer.Info{
		"channel_id": c.TID,
		"packet_id":  pkt.TID,
		"packet":     packetInfo,
	})
}
func (t *transport) handleGetPredecessor(ch *e3x.Channel) { var ( err error pkt *lob.Packet id string vnode *chord.Vnode res *completeVnode ) defer ch.Close() pkt, err = ch.ReadPacket() if err != nil { // log error // tracef("error: %s", err) return } id, _ = pkt.Header().GetString("vn") rpc := t.lookupRPC(id) if rpc == nil { // log // tracef("error: %s", "no RPC") return } vnode, err = rpc.GetPredecessor() if err != nil { // log // tracef("error: %s", err) return } res = t.completeVnode(vnode) err = json.NewEncoder(newStream(ch)).Encode(&res) if err != nil { // log // tracef("error: %s", err) return } if res != nil { // tracef("handle.GetPredecessor(Vnode(%q)) => Vnode(%q)", id, res.Id) } }
func (c *Channel) traceDroppedPacket(pkt *lob.Packet, reason string) { if tracer.Enabled { info := tracer.Info{ "channel_id": c.TID, "packet_id": pkt.TID, "reason": reason, } if pkt != nil { info["packet"] = tracer.Info{ "header": pkt.Header(), "body": base64.StdEncoding.EncodeToString(pkt.Body(nil)), } } tracer.Emit("channel.rcv.packet", info) } }
func (c *Channel) applyAckHeaders(pkt *lob.Packet) { if !c.reliable { return } if c.iSeq == cBlankSeq { // nothin to ack return } hdr := pkt.Header() if c.iSeq >= cInitialSeq { hdr.Ack, hdr.HasAck = c.iSeq, true } if l := c.buildMissList(); len(l) > 0 { hdr.Miss, hdr.HasMiss = l, true } c.iAckedSeq = c.iSeq }
// Request a nodes predecessor func (t *transport) GetPredecessor(vn *chord.Vnode) (*chord.Vnode, error) { var ( addr *e3x.Addr ch *e3x.Channel pkt *lob.Packet res *completeVnode err error ) addr = t.lookupAddr(hashname.H(vn.Host)) if addr == nil { return nil, e3x.ErrNoAddress } ch, err = t.e.Open(addr, "chord.predecessor.get", true) if err != nil { return nil, err } defer ch.Close() // ch.SetReadDeadline(time.Now().Add(30*time.Second)) // ch.SetWriteDeadline(time.Now().Add(30*time.Second)) pkt = &lob.Packet{} pkt.Header().SetString("vn", vn.String()) err = ch.WritePacket(pkt) if err != nil { return nil, err } err = json.NewDecoder(newStream(ch)).Decode(&res) if err != nil { return nil, err } if res != nil { // tracef("GetPredecessor(Vnode(%q)) => Vnode(%q)", vn.String(), res.Id) } return t.internalVnode(res), nil }
func (c *Channel) traceWrite(pkt *lob.Packet, p *Pipe) { if tracer.Enabled { info := tracer.Info{ "channel_id": c.TID, } if p != nil { info["path"] = p.RemoteAddr().String() } if pkt != nil { info["packet_id"] = pkt.TID info["packet"] = tracer.Info{ "header": pkt.Header(), "body": base64.StdEncoding.EncodeToString(pkt.Body(nil)), } } tracer.Emit("channel.write", info) } }
// Ping a Vnode, check for liveness func (t *transport) Ping(vn *chord.Vnode) (bool, error) { var ( addr *e3x.Addr ch *e3x.Channel pkt *lob.Packet alive bool err error ) addr = t.lookupAddr(hashname.H(vn.Host)) if addr == nil { return false, e3x.ErrNoAddress } ch, err = t.e.Open(addr, "chord.ping", true) if err != nil { return false, err } defer ch.Close() // ch.SetReadDeadline(time.Now().Add(30*time.Second)) // ch.SetWriteDeadline(time.Now().Add(30*time.Second)) pkt = &lob.Packet{} pkt.Header().SetString("vn", vn.String()) err = ch.WritePacket(pkt) if err != nil { return false, err } pkt, err = ch.ReadPacket() if err != nil { return false, err } alive, _ = pkt.Header().GetBool("alive") // tracef("Ping(Vnode(%q)) => %v", vn.String(), alive) return alive, nil }
func (c *Channel) traceWriteError(pkt *lob.Packet, p *Pipe, reason error) error { if tracer.Enabled { info := tracer.Info{ "channel_id": c.TID, "reason": reason.Error(), } if p != nil { info["path"] = p.RemoteAddr().String() } if pkt != nil { info["packet_id"] = pkt.TID info["packet"] = tracer.Info{ "header": pkt.Header(), "body": base64.StdEncoding.EncodeToString(pkt.Body(nil)), } } tracer.Emit("channel.write.error", info) } return reason }
func TestBridge(t *testing.T) { // given: // A <-> B exchange // B <-> R exchange // A x-x R no exchange // // when: // R --> B route token from A->B to B // A --x B block A from contacting B (while adding R's addresses to the exchange A->B) // // then: // A and B should still be able to communicate. assert := assert.New(t) var blacklist []net.Addr blacklistRule := func(src net.Addr) bool { t.Logf("FW(%v, src=%s)", blacklist, src) if len(blacklist) == 0 { return true } for _, addr := range blacklist { if transports.EqualAddr(addr, src) { return false } } return true } A, err := e3x.Open( e3x.Log(nil), e3x.Transport(udp.Config{}), Module(Config{})) assert.NoError(err) B, err := e3x.Open( e3x.Log(nil), e3x.Transport(fw.Config{Config: udp.Config{}, Allow: fw.RuleFunc(blacklistRule)}), Module(Config{})) assert.NoError(err) R, err := e3x.Open( e3x.Log(nil), e3x.Transport(udp.Config{}), Module(Config{})) assert.NoError(err) done := make(chan bool, 1) go func() { var ( pkt *lob.Packet err error n = 1 first = true ) defer func() { done <- true }() c, err := A.Listen("ping", true).AcceptChannel() defer c.Close() for ; n > 0; n-- { pkt, err = c.ReadPacket() if err != nil { t.Fatalf("ping: error: %s", err) return } if first { n, _ = pkt.Header().GetInt("n") first = false } t.Logf("RCV ping: %d", n) err = c.WritePacket(&lob.Packet{}) if err != nil { t.Fatalf("ping: error: %s", err) return } t.Logf("SND pong: %d", n) } }() Aident, err := A.LocalIdentity() assert.NoError(err) Bident, err := B.LocalIdentity() assert.NoError(err) { addr, err := transports.ResolveAddr("peer", string(R.LocalHashname())) assert.NoError(err) Bident = Bident.AddPathCandiate(addr) } log.Println("\x1B[31m------------------------------------------------\x1B[0m") // blacklist A blacklist = append(blacklist, Aident.Addresses()...) 
log.Println("\x1B[32mblacklist:\x1B[0m", blacklist) log.Println("\x1B[31m------------------------------------------------\x1B[0m") _, err = R.Dial(Bident) assert.NoError(err) _, err = R.Dial(Aident) assert.NoError(err) ABex, err := A.Dial(Bident) assert.NoError(err) log.Println("\x1B[31m------------------------------------------------\x1B[0m") log.Printf("ab-local-token = %x", ABex.LocalToken()) log.Printf("ab-remote-token = %x", ABex.RemoteToken()) log.Println("\x1B[31m------------------------------------------------\x1B[0m") { ch, err := B.Open(Aident, "ping", true) assert.NoError(err) for n := 10; n > 0; n-- { pkt := &lob.Packet{} pkt.Header().SetInt("n", n) err = ch.WritePacket(pkt) if err != nil { t.Fatalf("ping: error: %s", err) } t.Logf("SND ping: %d", n) _, err = ch.ReadPacket() if err != nil { t.Fatalf("ping: error: %s", err) } t.Logf("RCV pong: %d", n) } ch.Close() } <-done assert.NoError(A.Close()) assert.NoError(B.Close()) assert.NoError(R.Close()) }
func (s *state) DecryptPacket(pkt *lob.Packet) (*lob.Packet, error) { s.mtx.RLock() defer s.mtx.RUnlock() if !s.CanDecryptPacket() { return nil, cipherset.ErrInvalidState } if pkt == nil { return nil, nil } if !pkt.Header().IsZero() || pkt.BodyLen() < 16+4+4 { return nil, cipherset.ErrInvalidPacket } var ( nonce [16]byte bodyRaw []byte innerRaw []byte innerLen = pkt.BodyLen() - (16 + 4 + 4) body = bufpool.New() inner = bufpool.New().SetLen(innerLen) ) pkt.Body(body.SetLen(pkt.BodyLen()).RawBytes()[:0]) bodyRaw = body.RawBytes() innerRaw = inner.RawBytes() // compare token if !bytes.Equal(bodyRaw[:16], (*s.localToken)[:]) { inner.Free() body.Free() return nil, cipherset.ErrInvalidPacket } // copy nonce copy(nonce[:], bodyRaw[16:16+4]) { // verify hmac mac := bodyRaw[16+4+innerLen:] macKey := append(s.lineDecryptionKey, nonce[:4]...) h := hmac.New(sha256.New, macKey) h.Write(bodyRaw[16+4 : 16+4+innerLen]) if subtle.ConstantTimeCompare(mac, fold(h.Sum(nil), 4)) != 1 { inner.Free() body.Free() return nil, cipherset.ErrInvalidPacket } } { // decrypt inner aesBlock, err := aes.NewCipher(s.lineDecryptionKey) if err != nil { inner.Free() body.Free() return nil, err } aes := Cipher.NewCTR(aesBlock, nonce[:]) if aes == nil { inner.Free() body.Free() return nil, cipherset.ErrInvalidPacket } aes.XORKeyStream(innerRaw, bodyRaw[16+4:16+4+innerLen]) } innerPkt, err := lob.Decode(inner) if err != nil { inner.Free() body.Free() return nil, err } inner.Free() body.Free() return innerPkt, nil }
func (s *cipherTestSuite) TestPacketEncryption() { var ( assert = s.Assertions c = s.cipher ) var ( ka cipherset.Key kb cipherset.Key sa cipherset.State sb cipherset.State ha cipherset.Handshake hb cipherset.Handshake pkt *lob.Packet box []byte err error ok bool ) ka, err = c.GenerateKey() assert.NoError(err) kb, err = c.GenerateKey() assert.NoError(err) sa, err = c.NewState(ka) assert.NoError(err) sb, err = c.NewState(kb) assert.NoError(err) err = sa.SetRemoteKey(kb) assert.NoError(err) box, err = sa.EncryptHandshake(1, nil) assert.NoError(err) hb, err = c.DecryptHandshake(kb, box) assert.NoError(err) ok = sb.ApplyHandshake(hb) assert.True(ok) box, err = sb.EncryptHandshake(1, nil) assert.NoError(err) ha, err = c.DecryptHandshake(ka, box) assert.NoError(err) ok = sa.ApplyHandshake(ha) assert.True(ok) pkt = lob.New([]byte("Hello world!")) pkt.Header().SetInt("foo", 0xbeaf) pkt, err = sa.EncryptPacket(pkt) assert.NoError(err) assert.NotNil(pkt) assert.Nil(pkt.Header().Bytes) assert.True(pkt.Header().IsZero()) assert.NotEmpty(pkt.Body) pkt, err = sb.DecryptPacket(pkt) assert.NoError(err) assert.NotNil(pkt) assert.Nil(pkt.Header().Bytes) assert.Equal(&lob.Header{Extra: map[string]interface{}{"foo": 0xbeaf}}, pkt.Header()) assert.Equal([]byte("Hello world!"), pkt.Body(nil)) pkt = lob.New([]byte("Bye world!")) pkt.Header().SetInt("bar", 0xdead) pkt, err = sb.EncryptPacket(pkt) assert.NoError(err) assert.NotNil(pkt) assert.Nil(pkt.Header().Bytes) assert.True(pkt.Header().IsZero()) assert.NotEmpty(pkt.Body) pkt, err = sa.DecryptPacket(pkt) assert.NoError(err) assert.NotNil(pkt) assert.Nil(pkt.Header().Bytes) assert.Equal(&lob.Header{Extra: map[string]interface{}{"bar": 0xdead}}, pkt.Header()) assert.Equal([]byte("Bye world!"), pkt.Body(nil)) }
func (s *state) DecryptPacket(pkt *lob.Packet) (*lob.Packet, error) { s.mtx.RLock() defer s.mtx.RUnlock() if !s.CanDecryptPacket() { return nil, cipherset.ErrInvalidState } if pkt == nil { return nil, nil } if !pkt.Header().IsZero() || pkt.BodyLen() < lenToken+lenNonce { return nil, cipherset.ErrInvalidPacket } var ( nonce [lenNonce]byte bodyRaw []byte innerRaw []byte innerPkt *lob.Packet body = bufpool.New() inner = bufpool.New() ok bool ) pkt.Body(body.SetLen(pkt.BodyLen()).RawBytes()[:0]) bodyRaw = body.RawBytes() innerRaw = inner.RawBytes() // compare token if !bytes.Equal(bodyRaw[:lenToken], (*s.localToken)[:]) { inner.Free() body.Free() return nil, cipherset.ErrInvalidPacket } // copy nonce copy(nonce[:], bodyRaw[lenToken:lenToken+lenNonce]) // decrypt inner packet innerRaw, ok = box.OpenAfterPrecomputation( innerRaw[:0], bodyRaw[lenToken+lenNonce:], &nonce, s.lineDecryptionKey) if !ok { inner.Free() body.Free() return nil, cipherset.ErrInvalidPacket } inner.SetLen(len(innerRaw)) innerPkt, err := lob.Decode(inner) if err != nil { inner.Free() body.Free() return nil, err } inner.Free() body.Free() return innerPkt, nil }
// receivedPacket processes one inbound packet for this channel.
//
// On reliable channels it first applies the peer's "ack"/"miss" headers to
// the write buffer. It then queues the packet in c.readBuffer, sorted, and
// signals c.cndRead. Packets are dropped (traced with a reason and counted in
// statChannelRcvPktDrop) when any of these holds:
//   - the channel is broken;
//   - the packet has no seq;
//   - its seq was already read (seq <= c.iSeq);
//   - the read buffer holds cReadBufferSize entries;
//   - a packet with the same seq is already buffered.
//
// A packet that has an ack but no seq is not counted as a drop. It is still
// traced with reason "missing seq".
//
// c.mtx is taken on entry and released before tracing and stats updates.
func (c *Channel) receivedPacket(pkt *lob.Packet) {
	const (
		errBrokenChannel   = "broken channel"
		errMissingSeq      = "missing seq"
		errDuplicatePacket = "duplicate packet"
		errFullBuffer      = "full buffer"
	)

	c.mtx.Lock()

	if c.broken {
		c.mtx.Unlock()
		c.traceDroppedPacket(pkt, errBrokenChannel)
		statChannelRcvPktDrop.Add(1)
		return
	}

	var (
		hdr           = pkt.Header()
		seq, hasSeq   = hdr.Seq, hdr.HasSeq
		ack, hasAck   = hdr.Ack, hdr.HasAck
		miss, hasMiss = hdr.Miss, hdr.HasMiss
		end, hasEnd   = hdr.End, hdr.HasEnd
	)

	if !c.reliable {
		// unreliable channels (internally) emulate reliable channels:
		// every packet gets the next synthetic seq after the last buffered one.
		seq = c.iBufferedSeq + 1
		hasSeq = true

	} else {
		// determine what to drop from the write buffer
		if hasAck {
			// "inline" acks ride on a data packet (has seq); "ad hoc" acks don't.
			if hasSeq {
				statChannelRcvAckInline.Add(1)
			} else {
				statChannelRcvAckAdHoc.Add(1)
			}

			var (
				oldAck  = c.oAckedSeq
				changed bool
			)
			if c.oAckedSeq < ack {
				c.oAckedSeq = ack
				changed = true
			}
			// Release every write-buffer entry in (oldAck, ack].
			for i := oldAck + 1; i <= ack; i++ {
				if e := c.writeBuffer[i]; e != nil {
					e.pkt.Free()
				}
				delete(c.writeBuffer, i)
				changed = true
			}
			if len(c.writeBuffer) == 0 {
				c.needsResend = false
			}
			if changed {
				// Wake a writer waiting on buffer space, and a closer waiting
				// for the end exchange to complete.
				c.cndWrite.Signal()
				if c.deliveredEnd || c.receivedEnd {
					c.cndClose.Signal()
				}
			}

			if hasMiss {
				c.processMissingPackets(ack, miss)
			}
		}
	}

	if !hasSeq {
		// drop: is not a valid packet
		c.mtx.Unlock()
		c.traceDroppedPacket(pkt, errMissingSeq)
		// A pure ack packet carries no data, so it does not count as a drop.
		if !hasAck {
			statChannelRcvPktDrop.Add(1)
		}
		return
	}

	if c.reliable && c.iSeenSeq < seq {
		// record highest seen seq
		c.iSeenSeq = seq
	}

	if seq <= c.iSeq {
		// drop: the reader already read a packet with this seq
		c.mtx.Unlock()
		c.traceDroppedPacket(pkt, errDuplicatePacket)
		statChannelRcvPktDrop.Add(1)
		return
	}

	if len(c.readBuffer) >= cReadBufferSize {
		// drop: the read buffer is full
		c.mtx.Unlock()
		c.traceDroppedPacket(pkt, errFullBuffer)
		statChannelRcvPktDrop.Add(1)
		return
	}

	if c.readBuffer.IndexOf(seq) >= 0 {
		// drop: a packet with this seq is already buffered
		c.mtx.Unlock()
		c.traceDroppedPacket(pkt, errDuplicatePacket)
		statChannelRcvPktDrop.Add(1)
		return
	}

	// Track the highest buffered seq (used for unreliable seq synthesis above).
	if c.iBufferedSeq < seq {
		c.iBufferedSeq = seq
	}

	if end && hasEnd {
		c.receivedEnd = true
		c.deliverAck()
	}

	c.readBuffer = append(c.readBuffer, &readBufferEntry{pkt, seq, end})
	sort.Sort(c.readBuffer)
	c.cndRead.Signal()

	c.mtx.Unlock()

	c.traceReceivedPacket(pkt)
	statChannelRcvPkt.Add(1)
}
func (c *Channel) write(pkt *lob.Packet, p *Pipe) error { if pkt.TID == 0 { pkt.TID = tracer.NewID() } if c.broken { // When a channel is marked as broken the all writes // must return a BrokenChannelError. return c.traceWriteError(pkt, p, &BrokenChannelError{c.hashname, c.typ, c.id}) } if c.writeDeadlineReached { // When a channel reached a write deadline then all writes // must return a ErrTimeout. return c.traceWriteError(pkt, p, ErrTimeout) } if c.deliveredEnd { // When a channel sent a packet with the "end" header set // then all subsequent writes must return io.EOF return c.traceWriteError(pkt, p, io.EOF) } c.oSeq++ hdr := pkt.Header() hdr.C, hdr.HasC = c.id, true if c.reliable { hdr.Seq, hdr.HasSeq = c.oSeq, true } if !c.serverside && c.oSeq == cInitialSeq { hdr.Type, hdr.HasType = c.typ, true } end := hdr.HasEnd && hdr.End if end { c.deliveredEnd = true c.setCloseDeadline() } if c.reliable { if c.oSeq%30 == 0 || hdr.End { c.applyAckHeaders(pkt) } c.writeBuffer[c.oSeq] = &writeBufferEntry{pkt, end, time.Time{}, p} c.needsResend = false } err := c.x.deliverPacket(pkt, p) if err != nil { return c.traceWriteError(pkt, p, err) } statChannelSndPkt.Add(1) if pkt.Header().HasAck { statChannelSndAckInline.Add(1) } if c.oSeq == cInitialSeq && c.serverside { c.unsetOpenDeadline() } c.traceWrite(pkt, p) if !c.reliable { pkt.Free() } return nil }
func TestFloodReliable(t *testing.T) { if testing.Short() { t.Skip("this is a long running test.") } withTwoEndpoints(t, func(A, B *Endpoint) { A.setOptions(DisableLog()) B.setOptions(DisableLog()) var ( assert = assert.New(t) c *Channel ident *Identity pkt *lob.Packet err error ) go func() { c, err := A.Listen("flood", true).AcceptChannel() if assert.NoError(err) && assert.NotNil(c) { defer c.Close() pkt, err = c.ReadPacket() assert.NoError(err) assert.NotNil(pkt) for i := 0; i < 100000; i++ { pkt := lob.New(nil) pkt.Header().SetInt("flood_id", i) err = c.WritePacket(pkt) assert.NoError(err) } } }() ident, err = A.LocalIdentity() assert.NoError(err) c, err = B.Open(ident, "flood", true) assert.NoError(err) assert.NotNil(c) defer c.Close() err = c.WritePacket(lob.New(nil)) assert.NoError(err) lastID := -1 for { pkt, err = c.ReadPacket() if err == io.EOF { break } assert.NoError(err) assert.NotNil(pkt) if err != nil { break } if pkt != nil { id, _ := pkt.Header().GetInt("flood_id") assert.True(lastID < id) lastID = id } } }) }