// Ensure the muxer closes connections that don't have a registered header byte. func TestMux_Listen_ErrUnregisteredHandler(t *testing.T) { // Open single listener on random port. ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatal(err) } defer ln.Close() // Write log output to a buffer to verify. var buf bytes.Buffer // Mux listener. m := mux.New(ln) m.Timeout = 1 * time.Second m.LogOutput = &buf if testing.Verbose() { m.LogOutput = io.MultiWriter(m.LogOutput, os.Stderr) } var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() m.Serve() }() // Send message to listener. conn, err := net.Dial("tcp", ln.Addr().String()) if err != nil { t.Fatal(err) } defer conn.Close() // Write unregistered header byte. if _, err := conn.Write([]byte{'\x80'}); err != nil { t.Fatal(err) } // Connection should close immediately. if _, err := ioutil.ReadAll(conn); err != nil { t.Fatalf("unexpected error: %s", err) } // Close connection and wait for server to finish. ln.Close() wg.Wait() // Verify error was logged. time.Sleep(100 * time.Millisecond) if s := buf.String(); !strings.Contains(s, `unregistered header byte: 0x80`) { t.Fatalf("unexpected log output:\n\n%s", s) } }
// Ensure two handlers cannot be registered for the same header byte. func TestMux_Listen_ErrAlreadyRegistered(t *testing.T) { defer func() { if r := recover(); r != `header byte already registered: 0x05` { t.Fatalf("unexpected recover: %#v", r) } }() // Register two listeners with the same header byte. mux := mux.New(nil) mux.Listen([]byte{'\x05'}) mux.Listen([]byte{'\x05'}) }
// Ensure the muxer can split a listener's connections across multiple listeners. func TestMux_Listen(t *testing.T) { // Open single listener on random port. ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatal(err) } defer ln.Close() // Create muxer for listener. m := mux.New(ln) m.Timeout = 1 * time.Second if testing.Verbose() { m.LogOutput = os.Stderr } // Create listeners and begin serving mux. m.Listen([]byte{'\x00'}) subln := m.Listen([]byte{'G', 'P', 'D'}) go m.Serve() // Send message to listener. go func() { conn, err := net.Dial("tcp", ln.Addr().String()) if err != nil { t.Fatal(err) } defer conn.Close() // Write data & close. if _, err := conn.Write([]byte("GET")); err != nil { t.Fatal(err) } else if err = conn.Close(); err != nil { t.Fatal(err) } }() // Receive connection on appropriate listener. conn, err := subln.Accept() if err != nil { t.Fatal(err) } defer conn.Close() // Read message. if buf, err := ioutil.ReadAll(conn); err != nil { t.Fatal(err) } else if string(buf) != "GET" { t.Fatalf("unexpected message: %q", string(buf)) } else if err = conn.Close(); err != nil { t.Fatal(err) } }
// Mux multiplexes a listener into two listeners. func Mux(ln net.Listener) (storeLn, httpLn net.Listener) { m := mux.New(ln) // The store listens to everything prefixed with its header byte. storeLn = m.Listen([]byte{storeHdr}) // HTTP listens to all methods: CONNECT, DELETE, GET, HEAD, OPTIONS, POST, PUT, TRACE. httpLn = m.Listen([]byte{'C', 'D', 'G', 'H', 'O', 'P', 'T'}) go m.Serve() return }
// Run executes the program. func (m *Main) Run(args ...string) error { // Create logger. m.logger = log.New(m.Stdout, "", log.LstdFlags) // Parse command line flags. opt, err := m.ParseFlags(args...) if err != nil { return err } // Set up advertised address and default peer set. m.advertiseAddr = MergeHostPort(opt.Host, opt.Addr) if len(opt.Peers) == 0 { opt.Peers = []string{m.advertiseAddr} } // Create a slice of peers with their HTTP address set instead. httpPeers, err := SetPortSlice(opt.Peers, opt.Addr) if err != nil { return fmt.Errorf("set port slice: %s", err) } m.peers = httpPeers // Initialise the default client using the peer list os.Setenv("DISCOVERD", strings.Join(opt.Peers, ",")) discoverd.DefaultClient = discoverd.NewClient() // if there is a discoverd process already running on this // address perform a deployment by starting a proxy DNS server // and shutting down the old discoverd job var deploy *dd.Deployment var targetLogIndex dt.TargetLogIndex target := fmt.Sprintf("http://%s:1111", opt.Host) m.logger.Println("checking for existing discoverd process at", target) if err := discoverd.NewClientWithURL(target).Ping(target); err == nil { m.logger.Println("discoverd responding at", target, "taking over") deploy, err = dd.NewDeployment("discoverd") if err != nil { return err } m.logger.Println("Created deployment") if err := deploy.MarkPerforming(m.advertiseAddr, 60); err != nil { return err } m.logger.Println("marked", m.advertiseAddr, "as performing in deployent") addr, resolvers := waitHostDNSConfig() if opt.DNSAddr != "" { addr = opt.DNSAddr } if len(opt.Recursors) > 0 { resolvers = opt.Recursors } m.logger.Println("starting proxy DNS server") if err := m.openDNSServer(addr, resolvers); err != nil { return fmt.Errorf("Failed to start DNS server: %s", err) } m.logger.Printf("discoverd listening for DNS on %s", addr) targetLogIndex, err = discoverd.NewClientWithURL(target).Shutdown(target) if err != nil { return err } // Sleep for 2x the election timeout. // This is to work around an issue with hashicorp/raft that can allow us to be elected with // no log entries, hence truncating the log and losing all data! time.Sleep(2 * time.Second) } else { m.logger.Println("failed to contact existing discoverd server, starting up without takeover") m.logger.Println("err:", err) } // Open listener. ln, err := net.Listen("tcp4", opt.Addr) if err != nil { return err } m.ln = keepalive.Listener(ln) // Open mux m.mux = mux.New(m.ln) go m.mux.Serve() m.dataDir = opt.DataDir // if the advertise addr is not in the peer list we are proxying proxying := true for _, addr := range m.peers { if addr == m.advertiseAddr { proxying = false break } } if proxying { // Notify user that we're proxying if the store wasn't initialized. m.logger.Println("advertised address not in peer set, joining as proxy") } else { // Open store if we are not proxying. if err := m.openStore(); err != nil { return fmt.Errorf("Failed to open store: %s", err) } } // Wait for the store to catchup before switching to local store if we are doing a deployment if m.store != nil && targetLogIndex.LastIndex > 0 { for m.store.LastIndex() < targetLogIndex.LastIndex { m.logger.Println("Waiting for store to catchup, current:", m.store.LastIndex(), "target:", targetLogIndex.LastIndex) time.Sleep(100 * time.Millisecond) } } // If we already started the DNS server as part of a deployment above, // and we have an initialized store, just switch from the proxy store // to the initialized store. // // Else if we have a DNS address, start a DNS server right away. // // Otherwise wait for the host network to come up and then start a DNS // server. if m.dnsServer != nil && m.store != nil { m.dnsServer.SetStore(m.store) } else if opt.DNSAddr != "" { if err := m.openDNSServer(opt.DNSAddr, opt.Recursors); err != nil { return fmt.Errorf("Failed to start DNS server: %s", err) } m.logger.Printf("discoverd listening for DNS on %s", opt.DNSAddr) } else if opt.WaitNetDNS { go func() { addr, resolvers := waitHostDNSConfig() m.mu.Lock() if err := m.openDNSServer(addr, resolvers); err != nil { log.Fatalf("Failed to start DNS server: %s", err) } m.mu.Unlock() m.logger.Printf("discoverd listening for DNS on %s", addr) // Notify webhook. if opt.Notify != "" { m.Notify(opt.Notify, addr) } }() } if err := m.openHTTPServer(); err != nil { return fmt.Errorf("Failed to start HTTP server: %s", err) } if deploy != nil { if err := deploy.MarkDone(m.advertiseAddr); err != nil { return err } m.logger.Println("marked", m.advertiseAddr, "as done in deployment") } // Notify user that the servers are listening. m.logger.Printf("discoverd listening for HTTP on %s", opt.Addr) // Wait for leadership. if err := m.waitForLeader(IndefiniteTimeout); err != nil { return err } // Notify URL that discoverd is running. httpAddr := ln.Addr().String() host, port, _ := net.SplitHostPort(httpAddr) if host == "0.0.0.0" { httpAddr = net.JoinHostPort(os.Getenv("EXTERNAL_IP"), port) } m.Notify(opt.Notify, opt.DNSAddr) go func() { for { hb, err := discoverd.AddServiceAndRegister("discoverd", httpAddr) if err != nil { m.logger.Println("failed to register service/instance, retrying in 5 seconds:", err) time.Sleep(5 * time.Second) continue } m.mu.Lock() m.hb = hb m.mu.Unlock() break } }() return nil }