func forwardQuestion(q *dns.Question, forwarders []string) []dns.RR { //qType := dns.Type(q.Qtype).String() // query type //log.Printf("[Forwarder Lookup [%s] [%s]]\n", q.Name, qType) myReq := new(dns.Msg) myReq.SetQuestion(q.Name, q.Qtype) if len(forwarders) == 0 { // we have no upstreams, so we'll just not use any } else if strings.TrimSpace(forwarders[0]) == "!" { // we've been told explicitly to not pass anything along to any upsteams } else { c := new(dns.Client) for _, server := range forwarders { c.Net = "udp" m, _, err := c.Exchange(myReq, strings.TrimSpace(server)) if m != nil && m.MsgHdr.Truncated { c.Net = "tcp" m, _, err = c.Exchange(myReq, strings.TrimSpace(server)) } // FIXME: Cache misses. And cache hits, too. if err != nil { //log.Printf("[Forwarder Lookup [%s] [%s] failed: [%s]]\n", q.Name, qType, err) log.Println(err) } else { //log.Printf("[Forwarder Lookup [%s] [%s] success]\n", q.Name, qType) return m.Answer } } } return nil }
// Returns an anonymous function configured to resolve DNS // queries with a specific set of remote servers. func ServerHandler(addresses []string) handler { randGen := rand.New(rand.NewSource(time.Now().UnixNano())) // This is the actual handler return func(w dns.ResponseWriter, req *dns.Msg) { nameserver := addresses[randGen.Intn(len(addresses))] var protocol string switch t := w.RemoteAddr().(type) { default: log.Printf("ERROR: Unsupported protocol %T\n", t) return case *net.UDPAddr: protocol = "udp" case *net.TCPAddr: protocol = "tcp" } for _, q := range req.Question { log.Printf("Incoming request #%v: %s %s %v - using %s\n", req.Id, dns.ClassToString[q.Qclass], dns.TypeToString[q.Qtype], q.Name, nameserver) } c := new(dns.Client) c.Net = protocol resp, rtt, err := c.Exchange(req, nameserver) Redo: switch { case err != nil: log.Printf("ERROR: %s\n", err.Error()) sendFailure(w, req) return case req.Id != resp.Id: log.Printf("ERROR: Id mismatch: %v != %v\n", req.Id, resp.Id) sendFailure(w, req) return case resp.MsgHdr.Truncated && protocol != "tcp": log.Printf("WARNING: Truncated answer for request %v, retrying TCP\n", req.Id) c.Net = "tcp" resp, rtt, err = c.Exchange(req, nameserver) goto Redo } log.Printf("Request #%v: %.3d µs, server: %s(%s), size: %d bytes\n", resp.Id, rtt/1e3, nameserver, c.Net, resp.Len()) w.WriteMsg(resp) } // end of handler }
// shouldTransfer checks the primaries of zone, retrieves the SOA record, checks the current serial // and the remote serial and will return true if the remote one is higher than the locally configured one. func (z *Zone) shouldTransfer() (bool, error) { c := new(dns.Client) c.Net = "tcp" // do this query over TCP to minimize spoofing m := new(dns.Msg) m.SetQuestion(z.origin, dns.TypeSOA) var Err error serial := -1 Transfer: for _, tr := range z.TransferFrom { Err = nil ret, err := middleware.Exchange(c, m, tr) if err != nil || ret.Rcode != dns.RcodeSuccess { Err = err continue } for _, a := range ret.Answer { if a.Header().Rrtype == dns.TypeSOA { serial = int(a.(*dns.SOA).Serial) break Transfer } } } if serial == -1 { return false, Err } return less(z.Apex.SOA.Serial, uint32(serial)), Err }
func (this Server) DoDNSquery(m dns.Msg, TransProString string, server []string, timeout time.Duration) (*dns.Msg, error) { dnsClient := new(dns.Client) if dnsClient == nil { return nil, errors.New("Cannot create DNS client") } dnsClient.ReadTimeout = timeout dnsClient.WriteTimeout = timeout if TransProString != "TCP" && TransProString != "UDP" { return nil, errors.New("TransProString run") } dnsClient.Net = strings.ToLower(TransProString) ServerStr := server[rand.Intn(len(server))] ServerAddr := net.ParseIP(ServerStr) if ServerAddr.To16() != nil { ServerStr = "[" + ServerStr + "]:" + this.Port } else if ServerAddr.To4() != nil { ServerStr = ServerStr + this.Port } else { return nil, errors.New("invalid Server Address") } dnsResponse, _, err := dnsClient.Exchange(&m, ServerStr) if err != nil { return nil, err } return dnsResponse, nil }
// Perform DNS resolution func resolve(w http.ResponseWriter, r *http.Request, server string, domain string, querytype uint16) { m := new(dns.Msg) m.SetQuestion(domain, querytype) m.MsgHdr.RecursionDesired = true w.Header().Set("Content-Type", "application/json") w.Header().Set("Access-Control-Allow-Origin", "*") c := new(dns.Client) Redo: if in, _, err := c.Exchange(m, server); err == nil { // Second return value is RTT, not used for now if in.MsgHdr.Truncated { c.Net = "tcp" goto Redo } switch in.MsgHdr.Rcode { case dns.RcodeServerFailure: error(w, 500, 502, "The name server encountered an internal failure while processing this request (SERVFAIL)") case dns.RcodeNameError: error(w, 500, 503, "Some name that ought to exist, does not exist (NXDOMAIN)") case dns.RcodeRefused: error(w, 500, 505, "The name server refuses to perform the specified operation for policy or security reasons (REFUSED)") default: jsonify(w, r, in.Question, in.Answer, in.Ns, in.Extra) } } else { error(w, 500, 501, "DNS server could not be reached") } }
func lookup(msg *dns.Msg, client *dns.Client, server string, edns bool) (*dns.Msg, error) { if edns { opt := &dns.OPT{ Hdr: dns.RR_Header{ Name: ".", Rrtype: dns.TypeOPT, }, } opt.SetUDPSize(dns.DefaultMsgSize) msg.Extra = append(msg.Extra, opt) } response, _, err := client.Exchange(msg, server) if err != nil { return nil, err } if msg.Id != response.Id { return nil, fmt.Errorf("DNS ID mismatch, request: %d, response: %d", msg.Id, response.Id) } if response.MsgHdr.Truncated { if client.Net == "tcp" { return nil, fmt.Errorf("Got truncated message on tcp") } if edns { // Truncated even though EDNS is used client.Net = "tcp" } return lookup(msg, client, server, !edns) } return response, nil }
// NewDNSResolverImpl constructs a new DNS resolver object that utilizes the // provided list of DNS servers for resolution. func NewDNSResolverImpl( readTimeout time.Duration, servers []string, caaSERVFAILExceptions map[string]bool, stats metrics.Scope, clk clock.Clock, maxTries int, ) *DNSResolverImpl { // TODO(jmhodges): make constructor use an Option func pattern dnsClient := new(dns.Client) // Set timeout for underlying net.Conn dnsClient.ReadTimeout = readTimeout dnsClient.Net = "tcp" return &DNSResolverImpl{ dnsClient: dnsClient, servers: servers, allowRestrictedAddresses: false, caaSERVFAILExceptions: caaSERVFAILExceptions, maxTries: maxTries, clk: clk, stats: stats, txtStats: stats.NewScope("TXT"), aStats: stats.NewScope("A"), aaaaStats: stats.NewScope("AAAA"), caaStats: stats.NewScope("CAA"), mxStats: stats.NewScope("MX"), } }
func lookup(name string, queryType uint16, client *dns.Client, servAddr string, suffix string, edns bool) (*dns.Msg, error) { msg := &dns.Msg{} lname := strings.Join([]string{name, suffix}, ".") msg.SetQuestion(dns.Fqdn(lname), queryType) if edns { msg.SetEdns0(dns.DefaultMsgSize, false) } response, _, err := client.Exchange(msg, servAddr) if err == dns.ErrTruncated { if client.Net == "tcp" { return nil, fmt.Errorf("got truncated message on TCP (64kiB limit exceeded?)") } if edns { // Truncated even though EDNS is used client.Net = "tcp" } return lookup(name, queryType, client, servAddr, suffix, !edns) } if err != nil { return nil, err } if msg.Id != response.Id { return nil, fmt.Errorf("DNS ID mismatch, request: %d, response: %d", msg.Id, response.Id) } return response, nil }
func TestDNSForward(t *testing.T) { s := newTestServer("", "", "8.8.8.8:53") defer s.Stop() c := new(dns.Client) m := new(dns.Msg) m.SetQuestion("www.example.com.", dns.TypeA) resp, _, err := c.Exchange(m, "localhost:"+StrPort) if err != nil { t.Fatal(err) } if len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess { t.Fatal("Answer expected to have A records or rcode not equal to RcodeSuccess") } // TCP c.Net = "tcp" resp, _, err = c.Exchange(m, "localhost:"+StrPort) if err != nil { t.Fatal(err) } if len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess { t.Fatal("Answer expected to have A records or rcode not equal to RcodeSuccess") } // TODO(miek): DNSSEC DO query }
func TestDNSForwardLocal(t *testing.T) { forwardServer := newTestServer(t, false) service := &msg.Service{ Host: "199.43.132.53", Key: "a.skydns.test.", } addService(t, forwardServer, service.Key, 0, service) defer delService(t, forwardServer, service.Key) s := newTestServer(t, false) s.config.Nameservers = []string{forwardServer.config.DnsAddr} s.config.ForwardLocal = true defer s.Stop() c := new(dns.Client) m := new(dns.Msg) m.SetQuestion("a.skydns.test.", dns.TypeA) resp, _, err := c.Exchange(m, "127.0.0.1:"+StrPort) if err != nil { // try twice resp, _, err = c.Exchange(m, "127.0.0.1:"+StrPort) if err != nil { t.Fatal(err) } } if len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess { t.Fatal("answer expected to have A records or rcode not equal to RcodeSuccess") } // TCP c.Net = "tcp" resp, _, err = c.Exchange(m, "127.0.0.1:"+StrPort) if err != nil { t.Fatal(err) } if len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess { t.Fatal("answer expected to have A records or rcode not equal to RcodeSuccess") } m.SetQuestion("bbbbb.skydns.test.", dns.TypeA) resp, _, err = c.Exchange(m, "127.0.0.1:"+StrPort) if err != nil { t.Fatal(err) } if len(resp.Answer) != 0 { t.Fatal("answer expected to have zero A records") } // disable recursion and check s.config.NoRec = true m.SetQuestion("a.skydns.test.", dns.TypeA) resp, _, err = c.Exchange(m, "127.0.0.1:"+StrPort) if err != nil { t.Fatal(err) } if resp.Rcode != dns.RcodeServerFailure { t.Fatal("answer expected to have rcode equal to RcodeFailure") } }
func (r *Resolver) exchange(m *dns.Msg, a string, original []string) (res *dns.Msg, err error) { question := m.Question[0] sort.Strings(original) key := cachekey{question, fmt.Sprintf("%v", original)} if r.Debug { log.Println("KEY: ", key) } rt, ok := r.cache.Get(key) if ok { if r.Debug { log.Println("Cache HIT") } r1 := rt.(*dns.Msg) res = r1.Copy() return } if r.Debug { log.Println("Cache MISS") log.Println("QUERY: ", question.Name, "via", a) } res, err = dns.Exchange(m, a) if err != nil { if r.Debug { log.Println(err) } return } //Retry in case it was truncated if res.Truncated { if r.Debug { log.Println("truncated, retrying with tcp") } cl := new(dns.Client) cl.Net = "tcp" res, _, err = cl.Exchange(m, a) if err != nil { if r.Debug { log.Println(err) } return } } if r.Debug { log.Println(res) } if res.Rcode != dns.RcodeSuccess { return } if r.Debug { log.Println("Inserting into cache") } r.cache.Add(key, res) return }
/* * sendRequest() - send a DNS query */ func sendRequest(ctx *Context, m *dns.Msg, transport string, timeout time.Duration) (response *dns.Msg, rtt time.Duration, err error) { c := new(dns.Client) c.Timeout = timeout c.Net = transport // "udp" or "tcp" response, rtt, err = c.Exchange(m, addressString(ctx.server, ctx.port)) return }
func queryRemote(r *dns.Msg, ch chan *dns.Msg) { c := new(dns.Client) c.Net = "tcp" in, _, err := c.Exchange(r, *remoteDNS) if err == nil { ch <- in } else { logger.Printf("Failed to query, %q", err) } }
// tcpLookupIP is a helper to initiate a TCP-based DNS lookup for the given host. // The built-in Go resolver will do a UDP lookup first, and will only use TCP if // the response has the truncate bit set, which isn't common on DNS servers like // Consul's. By doing the TCP lookup directly, we get the best chance for the // largest list of hosts to join. Since joins are relatively rare events, it's ok // to do this rather expensive operation. func (m *Memberlist) tcpLookupIP(host string, defaultPort uint16) ([]ipPort, error) { // Don't attempt any TCP lookups against non-fully qualified domain // names, since those will likely come from the resolv.conf file. if !strings.Contains(host, ".") { return nil, nil } // Make sure the domain name is terminated with a dot (we know there's // at least one character at this point). dn := host if dn[len(dn)-1] != '.' { dn = dn + "." } // See if we can find a server to try. cc, err := dns.ClientConfigFromFile(m.config.DNSConfigPath) if err != nil { return nil, err } if len(cc.Servers) > 0 { // We support host:port in the DNS config, but need to add the // default port if one is not supplied. server := cc.Servers[0] if !hasPort(server) { server = net.JoinHostPort(server, cc.Port) } // Do the lookup. c := new(dns.Client) c.Net = "tcp" msg := new(dns.Msg) msg.SetQuestion(dn, dns.TypeANY) in, _, err := c.Exchange(msg, server) if err != nil { return nil, err } // Handle any IPs we get back that we can attempt to join. var ips []ipPort for _, r := range in.Answer { switch rr := r.(type) { case (*dns.A): ips = append(ips, ipPort{rr.A, defaultPort}) case (*dns.AAAA): ips = append(ips, ipPort{rr.AAAA, defaultPort}) case (*dns.CNAME): m.logger.Printf("[DEBUG] memberlist: Ignoring CNAME RR in TCP-first answer for '%s'", host) } } return ips, nil } return nil, nil }
func doQuery(c dns.Client, m dns.Msg, ds, dp string, queryType uint16, close chan struct{}) *dns.Msg { // r := &dns.Msg{} // var ee error //fmt.Println(utils.GetDebugLine(), " doQuery: ", " m.Question: ", m.Question, // " ds: ", ds, " dp: ", dp, " queryType ", queryType) utils.ServerLogger.Debug(" doQuery: m.Question: %v ds: %s dp: %s queryType: %v", m.Question, ds, dp, queryType) select { case <-close: return nil default: for l := 0; l < 3; l++ { r, _, ee := c.Exchange(&m, ds+":"+dp) if (ee != nil) || (r == nil) || (r.Answer == nil) { utils.ServerLogger.Error(" doQuery: retry: %s times error: %s", strconv.Itoa(l), ee.Error()) if (queryType == dns.TypeA) || (queryType == dns.TypeCNAME) { if strings.Contains(ee.Error(), "connection refused") { if c.Net == TCP { c.Net = UDP } } else if (ee == dns.ErrTruncated) && queryType == dns.TypeA { utils.ServerLogger.Error(" doQuery: response truncated: %v", r) // m.SetEdns0(4096,false) // m.SetQuestion(dns.Fqdn(domainName),dns.TypeCNAME) c.Net = TCP } else { if c.Net == TCP { c.Net = UDP } else { c.Net = TCP } } } } else { return r } } } return nil }
func main() { c := new(dns.Client) c.Net = "tcp" m := new(dns.Msg) m.SetQuestion("direct1.demo.direct-test.com.", dns.TypeCERT) //m.SetQuestion("kryptiq.direct-ci.com.", dns.TypeCERT) in, _, err := c.Exchange(m, "8.8.8.8:53") if err != nil { fmt.Println("ERROR: ", err) return } //fmt.Println("MsgHdr: ", in.MsgHdr) if in.MsgHdr.Rcode != 0 { fmt.Println("ERROR from DNS server: ", dns.RcodeToString[in.MsgHdr.Rcode]) return } //fmt.Println("Length of answer: ", len(in.Answer)) if rr, ok := in.Answer[0].(*dns.CERT); ok { //fmt.Println(rr.Type, rr.KeyTag, rr.Algorithm, rr.Certificate) asn, err := base64.StdEncoding.DecodeString(rr.Certificate) if err != nil { fmt.Println("Error b64 decoding: ", err) return } cert, err := x509.ParseCertificate(asn) if err != nil { fmt.Println("Error decoding cert: ", err) return } //fmt.Println(cert.Subject) //fmt.Println(cert.DNSNames) //fmt.Println(cert.EmailAddresses) //fmt.Println("--- Subject ---") //fmt.Println(cert.Subject.SerialNumber) /*for _, v := range cert.Subject.Names { fmt.Println(v) } fmt.Println("--- Issuer ---") for _, v := range cert.Issuer.Names { fmt.Println(v) }*/ fmt.Println("Subject: ", GetNameString(cert.Subject.Names)) fmt.Println("Issuer: ", GetNameString(cert.Issuer.Names)) } }
// Returns true if domain has a Name Server associated func queryNS(domain string, dnsServers []string, proto string) (int, error) { c := new(dns.Client) c.ReadTimeout = time.Duration(2 * time.Second) c.WriteTimeout = time.Duration(2 * time.Second) c.Net = proto m := new(dns.Msg) m.RecursionDesired = true dnsServer := dnsServers[rand.Intn(len(dnsServers))] m.SetQuestion(dns.Fqdn(domain), dns.TypeNS) in, _, err := c.Exchange(m, dnsServer+":53") if err == nil { return in.Rcode, err } return dns.RcodeRefused, err }
func fakeMsg(dom string, rrHeader uint16, proto string, serverPort int) (*dns.Msg, error) { qc := uint16(dns.ClassINET) c := new(dns.Client) c.Net = proto m := new(dns.Msg) m.Question = make([]dns.Question, 1) m.Question[0] = dns.Question{ Name: dns.Fqdn(dom), Qtype: rrHeader, Qclass: qc, } m.RecursionDesired = true in, _, err := c.Exchange(m, "127.0.0.1:"+strconv.Itoa(serverPort)) return in, err }
func (s *Server) recurse(w dns.ResponseWriter, req *dns.Msg) { if s.recurseTo == "" { dns.HandleFailed(w, req) return } c := new(dns.Client) in, _, err := c.Exchange(req, s.recurseTo) if err == nil { if in.MsgHdr.Truncated { c.Net = "tcp" in, _, err = c.Exchange(req, s.recurseTo) } w.WriteMsg(in) return } log.Warnf("Recursive error: %+v", err) dns.HandleFailed(w, req) }
func (s *server) HealthCheck() { c := new(dns.Client) c.Net = "tcp" m := new(dns.Msg) m.Question = make([]dns.Question, 1) m.Question[0] = dns.Question{HealthQuery, dns.TypeTXT, dns.ClassCHAOS} // doing this in the loop is not the best idea for _, serv := range s.router.Servers() { if !check(c, m, serv) { // do it again if !check(c, m, serv) { log.Printf("healthcheck failed for %s", serv) s.router.RemoveServer(serv) } } } }
// Get the key from the DNS (uses the local resolver) and return them. // If nothing is found we return nil func getKey(name string, keytag uint16, server string, tcp bool) *dns.DNSKEY { c := new(dns.Client) if tcp { c.Net = "tcp" } m := new(dns.Msg) m.SetQuestion(name, dns.TypeDNSKEY) m.SetEdns0(4096, true) r, _, err := c.Exchange(m, server) if err != nil { return nil } for _, k := range r.Answer { if k1, ok := k.(*dns.DNSKEY); ok { if k1.KeyTag() == keytag { return k1 } } } return nil }
func TestDNSForward(t *testing.T) { s := newTestServer(t, false) defer s.Stop() c := new(dns.Client) m := new(dns.Msg) m.SetQuestion("www.example.com.", dns.TypeA) resp, _, err := c.Exchange(m, "127.0.0.1:"+StrPort) if err != nil { // try twice resp, _, err = c.Exchange(m, "127.0.0.1:"+StrPort) if err != nil { t.Fatal(err) } } if len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess { t.Fatal("answer expected to have A records or rcode not equal to RcodeSuccess") } // TCP c.Net = "tcp" resp, _, err = c.Exchange(m, "127.0.0.1:"+StrPort) if err != nil { t.Fatal(err) } if len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess { t.Fatal("answer expected to have A records or rcode not equal to RcodeSuccess") } // disable recursion and check s.config.NoRec = true m.SetQuestion("www.example.com.", dns.TypeA) resp, _, err = c.Exchange(m, "127.0.0.1:"+StrPort) if err != nil { t.Fatal(err) } if resp.Rcode != dns.RcodeServerFailure { t.Fatal("answer expected to have rcode equal to RcodeFailure") } }
func lookup(name string, queryType uint16, client *dns.Client, servAddr string, suffix string, edns bool) (*dns.Msg, error) { msg := &dns.Msg{} lname := strings.Join([]string{name, suffix}, ".") msg.SetQuestion(dns.Fqdn(lname), queryType) if edns { opt := &dns.OPT{ Hdr: dns.RR_Header{ Name: ".", Rrtype: dns.TypeOPT, }, } opt.SetUDPSize(dns.DefaultMsgSize) msg.Extra = append(msg.Extra, opt) } response, _, err := client.Exchange(msg, servAddr) if err != nil { return nil, err } if msg.Id != response.Id { return nil, fmt.Errorf("DNS ID mismatch, request: %d, response: %d", msg.Id, response.Id) } if response.MsgHdr.Truncated { if client.Net == "tcp" { return nil, fmt.Errorf("Got truncated message on tcp") } if edns { // Truncated even though EDNS is used client.Net = "tcp" } return lookup(name, queryType, client, servAddr, suffix, !edns) } return response, nil }
func query(dom string) { nameserver := "127.0.0.1:8053" qt := dns.TypeA qc := uint16(dns.ClassINET) c := new(dns.Client) c.Net = "udp" m := new(dns.Msg) m.Question = make([]dns.Question, 1) m.Question[0] = dns.Question{ Name: dns.Fqdn(dom), Qtype: qt, Qclass: qc, } _, _, err := c.Exchange(m, nameserver) if err != nil { fmt.Println(err) } }
// defaultExtResolver queries other nameserver, potentially recurses; callers should probably be invoking extResolver // instead since that's the pluggable entrypoint into external resolution. func (res *Resolver) defaultExtResolver(r *dns.Msg, nameserver string, proto string, cnt int) (*dns.Msg, error) { var in *dns.Msg var err error c := new(dns.Client) c.Net = proto var t time.Duration = 5 * 1e9 if res.config.Timeout != 0 { t = time.Duration(int64(res.config.Timeout * 1e9)) } c.DialTimeout = t c.ReadTimeout = t c.WriteTimeout = t in, _, err = c.Exchange(r, nameserver) if err != nil { return in, err } // recurse if (in != nil) && (len(in.Answer) == 0) && (!in.MsgHdr.Authoritative) && (len(in.Ns) > 0) && (err != nil) { if cnt == recurseCnt { logging.CurLog.NonMesosRecursed.Inc() } if cnt > 0 { if soa, ok := (in.Ns[0]).(*dns.SOA); ok { return res.defaultExtResolver(r, net.JoinHostPort(soa.Ns, "53"), proto, cnt-1) } } } return in, err }
func TestDNS_Recurse(t *testing.T) { dir, srv := makeDNSServer(t) defer os.RemoveAll(dir) defer srv.agent.Shutdown() m := new(dns.Msg) m.SetQuestion("apple.com.", dns.TypeANY) c := new(dns.Client) c.Net = "tcp" addr, _ := srv.agent.config.ClientListener(srv.agent.config.Ports.DNS) in, _, err := c.Exchange(m, addr.String()) if err != nil { t.Fatalf("err: %v", err) } if len(in.Answer) == 0 { t.Fatalf("Bad: %#v", in) } if in.Rcode != dns.RcodeSuccess { t.Fatalf("Bad: %#v", in) } }
// ReadZoneXfr reads a zone from an axfr. func (c *Config) ReadZoneXfr(origin, master string) error { client := new(dns.Client) client.Net = "tcp" m := new(dns.Msg) m.SetAxfr(origin) z := dns.NewZone(origin) t, e := client.TransferIn(m, master) if e == nil { for r := range t { if r.Error == nil { // Loop answer section for _, rr := range r.RR { z.Insert(rr) } } } c.Zones[origin] = z dns.HandleFunc(origin, func(w dns.ResponseWriter, req *dns.Msg) { serve(w, req, c.Zones[origin]) }) return nil } return e }
func (d dnsAPI) Recurse(w dns.ResponseWriter, req *dns.Msg) { var client dns.Client if isTCP(w.RemoteAddr()) { client.Net = "tcp" } for _, recursor := range d.Recursors { req.Compress = true res, _, err := client.Exchange(req, recursor) if err != nil { continue } res.Compress = true w.WriteMsg(res) return } // Return SERVFAIL res := &dns.Msg{} res.RecursionAvailable = true res.SetRcode(req, dns.RcodeServerFailure) w.WriteMsg(res) }
func TestTCPDNSServer(t *testing.T) { setupForTest(t) const ( numAnswers = 512 nonLocalName = "weave.works." ) InitDefaultLogging(testing.Verbose()) Info.Println("TestTCPDNSServer starting") zone, err := NewZoneDb(ZoneConfig{}) require.NoError(t, err) err = zone.Start() require.NoError(t, err) defer zone.Stop() // generate a list of `numAnswers` IP addresses var addrs []ZoneRecord bs := make([]byte, 4) for i := 0; i < numAnswers; i++ { binary.LittleEndian.PutUint32(bs, uint32(i)) ip := net.IPv4(bs[0], bs[1], bs[2], bs[3]) addrs = append(addrs, ZoneRecord(Record{"", ip, 0, 0, 0})) } // handler for the fallback server: it will just return a very long response fallbackUDPHandler := func(w dns.ResponseWriter, req *dns.Msg) { if len(req.Question) == 0 { return // ignore empty queries (sent when shutting down the server) } maxLen := getMaxReplyLen(req, protUDP) t.Logf("Fallback UDP server got asked: returning %d answers", numAnswers) q := req.Question[0] m := makeAddressReply(req, &q, addrs, DefaultLocalTTL) mLen := m.Len() m.SetEdns0(uint16(maxLen), false) if mLen > maxLen { t.Logf("... truncated response (%d > %d)", mLen, maxLen) m.Truncated = true } w.WriteMsg(m) } fallbackTCPHandler := func(w dns.ResponseWriter, req *dns.Msg) { if len(req.Question) == 0 { return // ignore empty queries (sent when shutting down the server) } t.Logf("Fallback TCP server got asked: returning %d answers", numAnswers) q := req.Question[0] m := makeAddressReply(req, &q, addrs, DefaultLocalTTL) w.WriteMsg(m) } t.Logf("Running a DNS fallback server with UDP") fallback, err := newMockedFallback(fallbackUDPHandler, fallbackTCPHandler) require.NoError(t, err) fallback.Start() defer fallback.Stop() t.Logf("Creating a WeaveDNS server instance, falling back to 127.0.0.1:%d", fallback.Port) srv, err := NewDNSServer(DNSServerConfig{ Zone: zone, UpstreamCfg: fallback.CliConfig, CacheDisabled: true, ListenReadTimeout: testSocketTimeout, }) require.NoError(t, err) err = srv.Start() require.NoError(t, err) go srv.ActivateAndServe() defer srv.Stop() time.Sleep(100 * time.Millisecond) // Allow sever goroutine to start testPort, err := srv.GetPort() require.NoError(t, err) require.NotEqual(t, 0, testPort, "listen port") dnsAddr := fmt.Sprintf("127.0.0.1:%d", testPort) t.Logf("Creating a UDP and a TCP client") uc := new(dns.Client) uc.UDPSize = minUDPSize tc := new(dns.Client) tc.Net = "tcp" t.Logf("Creating DNS query message") m := new(dns.Msg) m.RecursionDesired = true m.SetQuestion(nonLocalName, dns.TypeA) t.Logf("Checking the fallback server at %s returns a truncated response with UDP", fallback.Addr) r, _, err := uc.Exchange(m, fallback.Addr) t.Logf("Got response from fallback server (UDP) with %d answers", len(r.Answer)) t.Logf("Response:\n%+v\n", r) require.NoError(t, err) require.True(t, r.MsgHdr.Truncated, "DNS truncated reponse flag") require.NotEqual(t, numAnswers, len(r.Answer), "number of answers (UDP)") t.Logf("Checking the WeaveDNS server at %s returns a truncated reponse with UDP", dnsAddr) r, _, err = uc.Exchange(m, dnsAddr) t.Logf("UDP Response:\n%+v\n", r) require.NoError(t, err) require.NotNil(t, r, "response") t.Logf("%d answers", len(r.Answer)) require.True(t, r.MsgHdr.Truncated, "DNS truncated reponse flag") require.NotEqual(t, numAnswers, len(r.Answer), "number of answers (UDP)") t.Logf("Checking the WeaveDNS server at %s does not return a truncated reponse with TCP", dnsAddr) r, _, err = tc.Exchange(m, dnsAddr) t.Logf("TCP Response:\n%+v\n", r) require.NoError(t, err) require.NotNil(t, r, "response") t.Logf("%d answers", len(r.Answer)) require.False(t, r.MsgHdr.Truncated, "DNS truncated response flag") require.Equal(t, numAnswers, len(r.Answer), "number of answers (TCP)") t.Logf("Checking the WeaveDNS server at %s does not return a truncated reponse with UDP with a bigger buffer", dnsAddr) m.SetEdns0(testUDPBufSize, false) r, _, err = uc.Exchange(m, dnsAddr) t.Logf("UDP-large Response:\n%+v\n", r) require.NoError(t, err) require.NotNil(t, r, "response") t.Logf("%d answers", len(r.Answer)) require.NoError(t, err) require.False(t, r.MsgHdr.Truncated, "DNS truncated response flag") require.Equal(t, numAnswers, len(r.Answer), "number of answers (UDP-long)") }
func main() { short = flag.Bool("short", false, "abbreviate long DNSSEC records") dnssec := flag.Bool("dnssec", false, "request DNSSEC records") query := flag.Bool("question", false, "show question") check := flag.Bool("check", false, "check internal DNSSEC consistency") raw := flag.Bool("raw", false, "do not strip 'http://' from the qname") six := flag.Bool("6", false, "use IPv6 only") four := flag.Bool("4", false, "use IPv4 only") anchor := flag.String("anchor", "", "use the DNSKEY in this file for interal DNSSEC consistency") tsig := flag.String("tsig", "", "request tsig with key: [hmac:]name:key") port := flag.Int("port", 53, "port number to use") aa := flag.Bool("aa", false, "set AA flag in query") ad := flag.Bool("ad", false, "set AD flag in query") cd := flag.Bool("cd", false, "set CD flag in query") rd := flag.Bool("rd", true, "set RD flag in query") fallback := flag.Bool("fallback", false, "fallback to 4096 bytes bufsize and after that TCP") tcp := flag.Bool("tcp", false, "TCP mode") nsid := flag.Bool("nsid", false, "set edns nsid option") client := flag.String("client", "", "set edns client-subnet option") //serial := flag.Int("serial", 0, "perform an IXFR with this serial") flag.Usage = func() { fmt.Fprintf(os.Stderr, "Usage: %s [options] [@server] [qtype] [qclass] [name ...]\n", os.Args[0]) flag.PrintDefaults() } qtype := uint16(0) qclass := uint16(dns.ClassINET) var qname []string flag.Parse() if *anchor != "" { f, err := os.Open(*anchor) if err != nil { fmt.Fprintf(os.Stderr, "Failure to open %s: %s\n", *anchor, err.Error()) } r, err := dns.ReadRR(f, *anchor) if err != nil { fmt.Fprintf(os.Stderr, "Failure to read an RR from %s: %s\n", *anchor, err.Error()) } if k, ok := r.(*dns.DNSKEY); !ok { fmt.Fprintf(os.Stderr, "No DNSKEY read from %s\n", *anchor) } else { dnskey = k } } var nameserver string Flags: for i := 0; i < flag.NArg(); i++ { // If it starts with @ it is a nameserver if flag.Arg(i)[0] == '@' { nameserver = flag.Arg(i) continue Flags } // First class, then type, to make ANY queries possible // And if it looks like type, it is a type if k, ok := dns.StringToType[strings.ToUpper(flag.Arg(i))]; ok { qtype = k continue Flags } // If it looks like a class, it is a class if k, ok := dns.StringToClass[strings.ToUpper(flag.Arg(i))]; ok { qclass = k continue Flags } // If it starts with TYPExxx it is unknown rr if strings.HasPrefix(flag.Arg(i), "TYPE") { i, e := strconv.Atoi(string([]byte(flag.Arg(i))[4:])) if e == nil { qtype = uint16(i) continue Flags } } // Anything else is a qname qname = append(qname, flag.Arg(i)) } if len(qname) == 0 { qname = make([]string, 1) qname[0] = "." qtype = dns.TypeNS } if qtype == 0 { qtype = dns.TypeA } if len(nameserver) == 0 { conf, err := dns.ClientConfigFromFile("/etc/resolv.conf") if err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(2) } nameserver = "@" + conf.Servers[0] } nameserver = string([]byte(nameserver)[1:]) // chop off @ // if the nameserver is from /etc/resolv.conf the [ and ] are already // added, thereby breaking net.ParseIP. Check for this and don't // fully qualify such a name if nameserver[0] == '[' && nameserver[len(nameserver)-1] == ']' { nameserver = nameserver[1 : len(nameserver)-1] } if i := net.ParseIP(nameserver); i != nil { nameserver = net.JoinHostPort(nameserver, strconv.Itoa(*port)) } else { nameserver = dns.Fqdn(nameserver) + ":" + strconv.Itoa(*port) } c := new(dns.Client) if *tcp { c.Net = "tcp" if *four { c.Net = "tcp4" } if *six { c.Net = "tcp6" } } else { c.Net = "udp" if *four { c.Net = "udp4" } if *six { c.Net = "udp6" } } m := new(dns.Msg) m.MsgHdr.Authoritative = *aa m.MsgHdr.AuthenticatedData = *ad m.MsgHdr.CheckingDisabled = *cd m.MsgHdr.RecursionDesired = *rd m.Question = make([]dns.Question, 1) if *dnssec || *nsid || *client != "" { o := new(dns.OPT) o.Hdr.Name = "." o.Hdr.Rrtype = dns.TypeOPT if *dnssec { o.SetDo() o.SetUDPSize(dns.DefaultMsgSize) } if *nsid { e := new(dns.EDNS0_NSID) e.Code = dns.EDNS0NSID o.Option = append(o.Option, e) // NSD will not return nsid when the udp message size is too small o.SetUDPSize(dns.DefaultMsgSize) } if *client != "" { e := new(dns.EDNS0_SUBNET) e.Code = dns.EDNS0SUBNET e.SourceScope = 0 e.Address = net.ParseIP(*client) if e.Address == nil { fmt.Fprintf(os.Stderr, "Failure to parse IP address: %s\n", *client) return } e.Family = 1 // IP4 e.SourceNetmask = net.IPv4len * 8 if e.Address.To4() == nil { e.Family = 2 // IP6 e.SourceNetmask = net.IPv6len * 8 } o.Option = append(o.Option, e) } m.Extra = append(m.Extra, o) } for _, v := range qname { if !*raw && strings.HasPrefix(v, "http://") { v = v[7:] if v[len(v)-1] == '/' { v = v[:len(v)-1] } } m.Question[0] = dns.Question{dns.Fqdn(v), qtype, qclass} m.Id = dns.Id() // Add tsig if *tsig != "" { if algo, name, secret, ok := tsigKeyParse(*tsig); ok { m.SetTsig(name, algo, 300, time.Now().Unix()) c.TsigSecret = map[string]string{name: secret} } else { fmt.Fprintf(os.Stderr, "TSIG key data error\n") return } } if *query { fmt.Printf("%s", m.String()) fmt.Printf("\n;; size: %d bytes\n\n", m.Len()) } if qtype == dns.TypeAXFR { c.Net = "tcp" doXfr(c, m, nameserver) continue } if qtype == dns.TypeIXFR { doXfr(c, m, nameserver) continue } r, rtt, e := c.Exchange(m, nameserver) Redo: if e != nil { fmt.Printf(";; %s\n", e.Error()) continue } if r.Id != m.Id { fmt.Fprintf(os.Stderr, "Id mismatch\n") return } if r.MsgHdr.Truncated && *fallback { if c.Net != "tcp" { if !*dnssec { fmt.Printf(";; Truncated, trying %d bytes bufsize\n", dns.DefaultMsgSize) o := new(dns.OPT) o.Hdr.Name = "." o.Hdr.Rrtype = dns.TypeOPT o.SetUDPSize(dns.DefaultMsgSize) m.Extra = append(m.Extra, o) r, rtt, e = c.Exchange(m, nameserver) *dnssec = true goto Redo } else { // First EDNS, then TCP fmt.Printf(";; Truncated, trying TCP\n") c.Net = "tcp" r, rtt, e = c.Exchange(m, nameserver) goto Redo } } } if r.MsgHdr.Truncated && !*fallback { fmt.Printf(";; Truncated\n") } if *check { sigCheck(r, nameserver, *tcp) } if *short { r = shortMsg(r) } fmt.Printf("%v", r) fmt.Printf("\n;; query time: %.3d µs, server: %s(%s), size: %d bytes\n", rtt/1e3, nameserver, c.Net, r.Len()) } }