func (s *server) RoundRobin(rrs []dns.RR) { if !s.config.RoundRobin { return } // If we have more than 1 CNAME don't touch the packet, because some stub resolver (=glibc) // can't deal with the returned packet if the CNAMEs need to be accesses in the reverse order. cname := 0 for _, r := range rrs { if r.Header().Rrtype == dns.TypeCNAME { cname++ if cname > 1 { return } } } switch l := len(rrs); l { case 2: if dns.Id()%2 == 0 { rrs[0], rrs[1] = rrs[1], rrs[0] } default: for j := 0; j < l*(int(dns.Id())%4+1); j++ { q := int(dns.Id()) % l p := int(dns.Id()) % l if q == p { p = (p + 1) % l } rrs[q], rrs[p] = rrs[p], rrs[q] } } }
func (self *DnsResolver) lookupHost(host string, triesLeft int) ([]net.IP, error) { m1 := new(dns.Msg) m1.Id = dns.Id() m1.RecursionDesired = true m1.Question = make([]dns.Question, 1) m1.Question[0] = dns.Question{dns.Fqdn(host), dns.TypeA, dns.ClassINET} in, err := dns.Exchange(m1, self.Servers[self.r.Intn(len(self.Servers))]) result := []net.IP{} if err != nil { if strings.HasSuffix(err.Error(), "i/o timeout") && triesLeft > 0 { triesLeft -= 1 return self.lookupHost(host, triesLeft) } else { return result, err } } if in != nil && in.Rcode != dns.RcodeSuccess { return result, errors.New(dns.RcodeToString[in.Rcode]) } for _, record := range in.Answer { if t, ok := record.(*dns.A); ok { result = append(result, t.A) } } return result, err }
func (h *handler) handleRecursive(w dns.ResponseWriter, req *dns.Msg) { h.ns.debugf("recursive request: %+v", *req) // Resolve unqualified names locally if len(req.Question) == 1 { hostname := dns.Fqdn(req.Question[0].Name) if strings.Count(hostname, ".") == 1 { h.handleLocal(w, req) return } } upstreamConfig, err := h.upstream.Config() if err != nil { h.ns.errorf("unable to read upstream config: %s", err) } for _, server := range upstreamConfig.Servers { reqCopy := req.Copy() reqCopy.Id = dns.Id() response, _, err := h.client.Exchange(reqCopy, fmt.Sprintf("%s:%s", server, upstreamConfig.Port)) if (err != nil && err != dns.ErrTruncated) || response == nil { h.ns.debugf("error trying %s: %v", server, err) continue } response.Id = req.Id if h.responseTooBig(req, response) { response.Compress = true } h.respond(w, response) return } h.respond(w, h.makeErrorResponse(req, dns.RcodeServerFailure)) }
func TestDNSHandlerMultiQuestions(t *testing.T) { var ( h = newDNSHandler(&testStore{}, "tt", dns.Fqdn("test.glimpse.io")) m = &dns.Msg{} w = &testWriter{} ) m.Id = dns.Id() m.RecursionDesired = true m.Question = make([]dns.Question, 3) for i := range m.Question { m.Question[i] = dns.Question{ Name: "foo.bar.baz.", Qtype: dns.TypeA, Qclass: dns.ClassINET, } } h.ServeDNS(w, m) r := w.msg if want, have := dns.RcodeNotImplemented, r.Rcode; want != have { t.Errorf( "want rcode %s, have %s", dns.RcodeToString[want], dns.RcodeToString[have], ) } }
func TestStress(t *testing.T) { domains := []string{"www.google.com.", "www.isc.org.", "www.outlook.com.", "miek.nl.", "doesnotexist.miek.nl."} l := len(domains) max := 8 procs := runtime.GOMAXPROCS(max) wg := new(sync.WaitGroup) wg.Add(max) u := New() defer u.Destroy() if err := u.ResolvConf("/etc/resolv.conf"); err != nil { return } for i := 0; i < max; i++ { go func() { for i := 0; i < 100; i++ { d := domains[int(dns.Id())%l] r, err := u.Resolve(d, dns.TypeA, dns.ClassINET) if err != nil { t.Log("failure to resolve: " + d) continue } if !r.HaveData && d != "doesnotexist.miek.nl." { t.Log("no data when resolving: " + d) continue } } wg.Done() }() } wg.Wait() runtime.GOMAXPROCS(procs) }
// Uses github.com/miekg/dns package, in order to invoke one query. func (r *Resolver) queryDNSServer(dnsServer, domainname, rrType string, edns bool) (*dns.Msg, error) { fqdn := dns.Fqdn(domainname) r.dnsQueryMsg.Id = dns.Id() r.dnsQueryMsg.SetQuestion(fqdn, dns.StringToType[rrType]) dnsServerSocket := dnsServer + ":" + DNSPORT dnsResponseMsg, err := dns.Exchange(r.dnsQueryMsg, dnsServerSocket) if err != nil { return nil, errors.New("dns.Exchange() failed") } if r.dnsQueryMsg.Id != dnsResponseMsg.Id { log.Printf("DNS msgID mismatch: Request-ID(%d), Response-ID(%d)", r.dnsQueryMsg.Id, dnsResponseMsg.Id) return nil, errors.New("DNS Msg-ID mismatch") } if dnsResponseMsg.MsgHdr.Truncated { if r.dnsClient.Net == "tcp" { return nil, errors.New("Received invalid truncated Msg over TCP") //fmt.Errorf("Got truncated message on tcp") } if edns { r.dnsClient.Net = "tcp" } return r.queryDNSServer(dnsServer, domainname, rrType, !edns) } return dnsResponseMsg, nil }
func (d *DNSServer) handleRecursive(client *dns.Client, defaultMaxResponseSize int) func(dns.ResponseWriter, *dns.Msg) { return func(w dns.ResponseWriter, req *dns.Msg) { d.ns.debugf("recursive request: %+v", *req) // Resolve unqualified names locally if len(req.Question) == 1 && req.Question[0].Qtype == dns.TypeA { hostname := dns.Fqdn(req.Question[0].Name) if strings.Count(hostname, ".") == 1 { d.handleLocal(defaultMaxResponseSize)(w, req) return } } for _, server := range d.upstream.Servers { reqCopy := req.Copy() reqCopy.Id = dns.Id() response, _, err := client.Exchange(reqCopy, fmt.Sprintf("%s:%s", server, d.upstream.Port)) if err != nil || response == nil { d.ns.debugf("error trying %s: %v", server, err) continue } d.ns.debugf("response: %+v", response) response.Id = req.Id if err := w.WriteMsg(response); err != nil { d.ns.infof("error responding: %v", err) } return } d.errorResponse(req, dns.RcodeServerFailure, w) } }
func DirectedQuery(servers []string, rd bool, q dns.Question, ctx context.Context) (*dns.Msg, error) { cl := dns.Client{ Net: "tcp", } m := &dns.Msg{ MsgHdr: dns.MsgHdr{ Id: dns.Id(), RecursionDesired: rd, }, Compress: true, Question: []dns.Question{q}, } m = m.SetEdns0(4096, false) type txResult struct { Response *dns.Msg Err error } maxTries := len(servers) if maxTries < 3 { maxTries = 3 } var mainErr error for i := 0; i < maxTries; i++ { s := servers[i%len(servers)] host, port, err := denet.FuzzySplitHostPort(s) if err != nil { return nil, err } if port == "" { port = "53" } txResultChan := make(chan txResult, 1) go func() { r, _, err := cl.Exchange(m, net.JoinHostPort(host, port)) txResultChan <- txResult{r, err} }() select { case <-ctx.Done(): return nil, ctx.Err() case txResult := <-txResultChan: if txResult.Err == nil { return txResult.Response, nil } mainErr = txResult.Err } } return nil, mainErr }
func (s *server) ServeDNS(w dns.ResponseWriter, req *dns.Msg) { q := req.Question[0] name := strings.ToLower(q.Name) if q.Qtype == dns.TypeIXFR || q.Qtype == dns.TypeAXFR { m := new(dns.Msg) m.SetRcode(req, dns.RcodeNotImplemented) w.WriteMsg(m) return } allServers, err := s.router.Match(name) if err != nil || len(allServers) == 0 { m := new(dns.Msg) m.SetRcode(req, dns.RcodeServerFailure) w.WriteMsg(m) return } serv := allServers[int(dns.Id())%len(allServers)] log.Printf("routing %s to %s", name, serv) c := new(dns.Client) ret, _, err := c.Exchange(req, serv) // serv has the port if err != nil { m := new(dns.Msg) m.SetRcode(req, dns.RcodeServerFailure) w.WriteMsg(m) return } w.WriteMsg(ret) }
func newMsg(host string, qClass uint16) *dns.Msg { m1 := new(dns.Msg) m1.Id = dns.Id() m1.RecursionDesired = true m1.Question = make([]dns.Question, 1) m1.Question[0] = dns.Question{host, qClass, dns.ClassINET} return m1 }
func localQuery(mychan chan DNSreply, qname string, qtype uint16) { var result DNSreply var trials uint result.qname = qname result.qtype = qtype result.r = nil result.err = errors.New("No name server to answer the question") localm := new(dns.Msg) localm.Id = dns.Id() localm.RecursionDesired = true localm.Question = make([]dns.Question, 1) localm.SetEdns0(EDNSBUFFERSIZE, false) // Even if no EDNS requested, see #9 May be we should retry without it if timeout? localc := new(dns.Client) localc.ReadTimeout = timeout localm.Question[0] = dns.Question{qname, qtype, dns.ClassINET} Tests: for trials = 0; trials < uint(*maxTrials); trials++ { Resolvers: for serverIndex := range conf.Servers { server := conf.Servers[serverIndex] result.nameserver = server // Brackets around the server address are necessary for IPv6 name servers r, rtt, err := localc.Exchange(localm, "["+server+"]:"+conf.Port) // Do not use net.JoinHostPort, see https://github.com/bortzmeyer/check-soa/commit/3e4edb13855d8c4016768796b2892aa83eda1933#commitcomment-2355543 if r == nil { result.r = nil result.err = err if strings.Contains(err.Error(), "timeout") { // Try another resolver break Resolvers } else { // We give in break Tests } } else { result.rtt = rtt if r.Rcode == dns.RcodeSuccess { // TODO: as a result, NODATA (NOERROR/ANSWER=0) are silently ignored (try "foo", for instance, the name exists but no IP address) // TODO: for rcodes like SERVFAIL, trying another resolver could make sense result.r = r result.err = nil break Tests } else { // All the other codes are errors. Yes, it may // happens that one resolver returns REFUSED // and the others work but we do not handle // this case. TODO: delete the resolver from // the list and try another one result.r = r result.err = errors.New(dns.RcodeToString[r.Rcode]) break Tests } } } } if *debug { fmt.Printf("DEBUG: end of DNS request \"%s\" / %d\n", qname, qtype) } mychan <- result }
func check(c *dns.Client, m *dns.Msg, addr string) bool { m.Id = dns.Id() in, _, err := c.Exchange(m, addr) if err != nil { return false } if in.Rcode != dns.RcodeSuccess { return false } return true }
func DnsGetDoaminIP(domain string) (string, error) { m := new(dns.Msg) m.Id = dns.Id() m.SetQuestion(dns.Fqdn(domain), dns.TypeA) m.RecursionDesired = true res, err := dnsQuery(m) if nil != err { return "", err } return pickIP(res), nil }
func (d *DnsDomain) Test() bool { if !(*Domain)(d).Test() { return false } fqdn := d.Name if strings.HasPrefix(fqdn, "*.") { fqdn = "a" + fqdn[1:] } if !strings.HasSuffix(fqdn, ".") { fqdn = fqdn + "." } any_ok := false d.DNS = make([]*DnsRecords, 0, len(DNS_servers)) for name, addr := range DNS_servers { records := new(DnsRecords) records.Server = name records.NS = addr d.DNS = append(d.DNS, records) req := new(dns.Msg) req.Id = dns.Id() req.RecursionDesired = true req.Question = []dns.Question{ dns.Question{fqdn, dns.TypeA, dns.ClassINET}, } resp, err := dns_client.Exchange(req, addr) if err != nil { records.Status = 900 records.Message = err.Error() continue } records.IPs = make([]string, 0, len(resp.Answer)) for _, rr := range resp.Answer { switch a := rr.(type) { case *dns.RR_A: records.IPs = append(records.IPs, a.A.String()) } } if len(records.IPs) > 0 { any_ok = true } else { records.Status = 900 records.Message = "No records" } } return any_ok }
func roundRobin(in []dns.RR) []dns.RR { cname := []dns.RR{} address := []dns.RR{} rest := []dns.RR{} for _, r := range in { switch r.Header().Rrtype { case dns.TypeCNAME: cname = append(cname, r) case dns.TypeA, dns.TypeAAAA: address = append(address, r) default: rest = append(rest, r) } } switch l := len(address); l { case 0, 1: break case 2: if dns.Id()%2 == 0 { address[0], address[1] = address[1], address[0] } default: for j := 0; j < l*(int(dns.Id())%4+1); j++ { q := int(dns.Id()) % l p := int(dns.Id()) % l if q == p { p = (p + 1) % l } address[q], address[p] = address[p], address[q] } } out := append(cname, rest...) out = append(out, address...) return out }
func (resolver ConsulDnsAddressResolver) Resolve(service string) (string, error) { m1 := new(dns.Msg) m1.Id = dns.Id() m1.RecursionDesired = true m1.SetQuestion(service+".service.consul.", dns.TypeA) c := new(dns.Client) in, _, err := c.Exchange(m1, resolver.ServerAddress) if err != nil { log.Fatal(err) } if len(in.Answer) > 0 { log.Println(in.Answer) return in.Answer[0].(*dns.A).A.String(), nil } return "", errors.New("Could not resolve service address") }
func MakeDnsFrame(host string, t uint16, streamid uint16) (req *dns.Msg, f Frame, err error) { log.Debug("make a dns query for %s.", host) req = new(dns.Msg) req.Id = dns.Id() req.SetQuestion(dns.Fqdn(host), t) req.RecursionDesired = true b, err := req.Pack() if err != nil { return } f = NewFrameDns(streamid, b) return }
func findSoaNs(domain string) (string, string, string) { var cname string var soa string var ns string add := func(c, s, n string) { cname += c soa += s ns += n return } cname += domain + "," m1 := new(dns.Msg) m1.Id = dns.Id() m1.RecursionDesired = true m1.Question = make([]dns.Question, 1) m1.Question[0] = dns.Question{domain, dns.TypeSOA, dns.ClassINET} in, _ := dns.Exchange(m1, (cf.Servers[1] + ":53")) rrList := [...][]dns.RR{in.Answer, in.Ns, in.Extra} for _, rr := range rrList { for i := len(rr) - 1; i >= 0; i-- { switch rr[i].Header().Rrtype { case dns.TypeCNAME: temp_cname := rr[i].(*dns.CNAME) add(findSoaNs(temp_cname.Target)) // fmt.Println( "temp_cname:" , temp_cname ) return cname, soa, ns break case dns.TypeNS: temp_ns := rr[i].(*dns.NS) ns += temp_ns.Ns + "," // + "|" + fmt.Sprint( temp_ns.Hdr.Ttl ) + "," // fmt.Println( "temp_ns:" , temp_ns ) break case dns.TypeSOA: temp_soa := rr[i].(*dns.SOA) soa += temp_soa.Ns + "," // + "|" + fmt.Sprint( temp_soa.Hdr.Ttl ) + "," // fmt.Println( "temp_soa:" , temp_soa ) break } } } return cname, soa, ns }
func TestDockerClientError(t *testing.T) { listener, dockerClient, _, client := setup(t, ".") dockerClient.inspectContainer = func(string) (*docker.Container, error) { return nil, fmt.Errorf("error") } msg := &dns.Msg{} msg.Id = dns.Id() msg.RecursionDesired = true msg.Question = []dns.Question{ {Name: "api.docker.", Qclass: dns.ClassINET, Qtype: dns.TypeA}, } _, _, err := client.Exchange(msg, listener.Addr().String()) if err != nil { t.Fatal(err) } listener.Close() }
/* * makeMessage() - construct DNS message structure */ func makeMessage(c *Context, qname, qtype, qclass string, ext Extension) *dns.Msg { m := new(dns.Msg) m.Id = dns.Id() if c.restype == RESOLUTION_STUB { m.RecursionDesired = true } else { m.RecursionDesired = false } if c.adflag { m.AuthenticatedData = true } if c.cdflag { m.CheckingDisabled = true } if ext["dnssec_return_status"] || ext["dnssec_return_only_secure"] || ext["dnssec_return_validation_chain"] { opt := new(dns.OPT) opt.Hdr.Name = "." opt.Hdr.Rrtype = dns.TypeOPT opt.SetDo() m.Extra = append(m.Extra, opt) } m.Question = make([]dns.Question, 1) qtype_int, ok := dns.StringToType[strings.ToUpper(qtype)] if !ok { fmt.Printf("%s: Unrecognized query type.\n", qtype) return nil } qclass_int, ok := dns.StringToClass[strings.ToUpper(qclass)] if !ok { fmt.Printf("%s: Unrecognized query class.\n", qclass) return nil } m.Question[0] = dns.Question{qname, qtype_int, qclass_int} return m }
// dnsLookup is used whenever we need to conduct a DNS query over a given TCP connection func DnsLookup(addr string, conn net.Conn) (*DnsResponse, error) { //log.Printf("Doing a DNS lookup on %s", addr) dnsResponse := &DnsResponse{ records: make([]DNSRecord, 0), } // create the connection to the DNS server dnsConn := &dns.Conn{Conn: conn} defer dnsConn.Close() m := new(dns.Msg) m.Id = dns.Id() // set the question section in the dns query // Fqdn returns the fully qualified domain name m.SetQuestion(dns.Fqdn(addr), dns.TypeA) m.RecursionDesired = true dnsConn.WriteMsg(m) response, err := dnsConn.ReadMsg() if err != nil { log.Printf("Could not process DNS response: %v", err) return nil, err } now := time.Now() // iterate over RRs containing the DNS answer for _, answer := range response.Answer { if a, ok := answer.(*dns.A); ok { // append the result to our list of records // the A records in the RDATA section of the DNS answer // contains the actual IP address dnsResponse.records = append(dnsResponse.records, DNSRecord{ IP: a.A, ExpireAt: now.Add(time.Duration(a.Hdr.Ttl) * time.Second), }) //log.Printf("###TTL:%d", a.Hdr.Ttl) } } return dnsResponse, nil }
func handleStaticRequest(config DomainConfig, w dns.ResponseWriter, r *dns.Msg) { m := new(dns.Msg) m.SetReply(r) defer w.WriteMsg(m) for _, q := range r.Question { for _, addr := range config.addrs { if addr.To4() != nil { // "If ip is not an IPv4 address, To4 returns nil." dnsAppend(q, m, &dns.A{A: addr}) } else { dnsAppend(q, m, &dns.AAAA{AAAA: addr}) } } for _, cname := range config.cnames { dnsAppend(q, m, &dns.CNAME{Target: cname}) if r.RecursionDesired && len(config.Nameservers) > 0 { recR := &dns.Msg{ MsgHdr: dns.MsgHdr{ Id: dns.Id(), }, Question: []dns.Question{ {Name: cname, Qtype: q.Qtype, Qclass: q.Qclass}, }, } recM := handleForwardingRaw(config.Nameservers, recR, w.RemoteAddr()) for _, rr := range recM.Answer { dnsAppend(q, m, rr) } for _, rr := range recM.Extra { dnsAppend(q, m, rr) } } } for _, txt := range config.txts { dnsAppend(q, m, &dns.TXT{Txt: txt}) } } }
// NewResolver returns an initialized Resolver struct. func NewResolver(config *Config) *Resolver { msg := &dns.Msg{} msg.Id = dns.Id() msg.RecursionDesired = true msg.SetQuestion("", dns.TypeANY) if config.edns { msg = handleEDNS(msg) } return &Resolver{ config.dnsServers, config.rrTypes, config.edns, msg, &dns.Client{}, &Result{ res: make(map[interface{}][]string), }, } }
// Get a single metric from dnsmasq. Returns the numeric value of the // metric. func (mc *metricsClient) getSingleMetric(name string) (int64, error) { msg := new(dns.Msg) msg.Id = dns.Id() msg.RecursionDesired = false msg.Question = make([]dns.Question, 1) msg.Question[0] = dns.Question{ Name: name, Qtype: dns.TypeTXT, Qclass: dns.ClassCHAOS, } in, _, err := mc.dnsClient.Exchange(msg, mc.addrPort) if err != nil { return 0, err } if len(in.Answer) != 1 { return 0, fmt.Errorf("Invalid number of Answer records for %s: %d", name, len(in.Answer)) } if t, ok := in.Answer[0].(*dns.TXT); ok { glog.V(4).Infof("Got valid TXT response %+v for %s", t, name) if len(t.Txt) != 1 { return 0, fmt.Errorf("Invalid number of TXT records for %s: %d", name, len(t.Txt)) } value, err := strconv.ParseInt(t.Txt[0], 10, 64) if err != nil { return 0, err } return value, nil } return 0, fmt.Errorf("missing txt record for %s", name) }
func (r Resolver) ResolveA(addr string) (net.IP, bool) { // log.Println("Looking up A for " + addr) m1 := new(dns.Msg) m1.Id = dns.Id() m1.RecursionDesired = true m1.Question = make([]dns.Question, 1) m1.Question[0] = dns.Question{addr, dns.TypeA, dns.ClassINET} c := new(dns.Client) in, _, err := c.Exchange(m1, "10.1.3.254:53") if err != nil { log.Println("Failed on c.Exchange call: " + err.Error()) return nil, false } if a, ok := in.Answer[0].(*dns.A); ok { // log.Println("I found an answer for " + addr) return a.A, true } // log.Println("Nothing found for " + addr) return nil, false }
func (s *server) AddressRecords(q dns.Question) (records []dns.RR, err error) { name := strings.ToLower(q.Name) if name == "master."+s.config.Domain || name == s.config.Domain { for _, m := range s.client.GetCluster() { u, e := url.Parse(m) if e != nil { continue } h, _, e := net.SplitHostPort(u.Host) if e != nil { continue } ip := net.ParseIP(h) switch { case ip.To4() != nil && q.Qtype == dns.TypeA: records = append(records, &dns.A{Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: s.Ttl}, A: ip.To4()}) case ip.To4() == nil && q.Qtype == dns.TypeAAAA: records = append(records, &dns.AAAA{Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: s.Ttl}, AAAA: ip.To16()}) } } return } r, err := s.client.Get(path(name), false, true) if err != nil { println(err.Error()) return nil, err } var serv *Service if !r.Node.Dir { // single element if err := json.Unmarshal([]byte(r.Node.Value), &serv); err != nil { log.Printf("error: Failure to parse value: %q", err) return nil, err } ip := net.ParseIP(serv.Host) ttl := uint32(r.Node.TTL) if ttl == 0 { ttl = s.Ttl } switch { case ip == nil: case ip.To4() != nil && q.Qtype == dns.TypeA: a := new(dns.A) a.Hdr = dns.RR_Header{Name: q.Name, Rrtype: q.Qtype, Class: dns.ClassINET, Ttl: ttl} a.A = ip.To4() records = append(records, a) case ip.To4() == nil && q.Qtype == dns.TypeAAAA: aaaa := new(dns.AAAA) aaaa.Hdr = dns.RR_Header{Name: q.Name, Rrtype: q.Qtype, Class: dns.ClassINET, Ttl: ttl} aaaa.AAAA = ip.To16() records = append(records, aaaa) } return records, nil } for _, serv := range s.loopNodes(&r.Node.Nodes) { ip := net.ParseIP(serv.Host) switch { case ip == nil: case ip.To4() != nil && q.Qtype == dns.TypeA: a := new(dns.A) a.Hdr = dns.RR_Header{Name: q.Name, Rrtype: q.Qtype, Class: dns.ClassINET, Ttl: uint32(r.Node.TTL)} a.A = ip.To4() records = append(records, a) case ip.To4() == nil && q.Qtype == dns.TypeAAAA: aaaa := new(dns.AAAA) aaaa.Hdr = dns.RR_Header{Name: q.Name, 
Rrtype: q.Qtype, Class: dns.ClassINET, Ttl: uint32(r.Node.TTL)} aaaa.AAAA = ip.To16() records = append(records, aaaa) } } if s.config.RoundRobin { switch l := len(records); l { case 2: if dns.Id()%2 == 0 { records[0], records[1] = records[1], records[0] } default: // Do a minimum of l swap, maximum of 4l swaps for j := 0; j < l*(int(dns.Id())%4+1); j++ { q := int(dns.Id()) % l p := int(dns.Id()) % l if q == p { p = (p + 1) % l } records[q], records[p] = records[p], records[q] } } } return records, nil }
// TestDNS starts a test master with its DNS server and checks, over the
// wire, how cluster names resolve:
//   - the kubernetes service name, and the DNS port mapping of that service;
//   - headless services (A and SRV queries);
//   - pod-by-IP names;
//   - recursive forwarding of an external name.
func TestDNS(t *testing.T) {
	testutil.RequireEtcd(t)
	defer testutil.DumpEtcdOnFailure(t)
	masterConfig, clientFile, err := testserver.StartTestMaster()
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Expected address of the kubernetes service endpoints: the default
	// local IPv4, or loopback when no default IPv4 exists.
	localAddr := ""
	if ip, err := cmdutil.DefaultLocalIP4(); err == nil {
		localAddr = ip.String()
	} else if err == cmdutil.ErrorNoDefaultIP {
		localAddr = "127.0.0.1"
	} else if err != nil {
		t.Fatalf("Unable to find a local IP address: %v", err)
	}

	localIP := net.ParseIP(localAddr)
	// Filled in by the first poll below; referenced by the test table later.
	var masterIP net.IP

	// verify service DNS entry is visible
	// Polls every 50ms until exactly one A answer arrives; transport errors
	// and answer-count mismatches just trigger another poll.
	stop := make(chan struct{})
	waitutil.Until(func() {
		m1 := &dns.Msg{
			MsgHdr:   dns.MsgHdr{Id: dns.Id(), RecursionDesired: false},
			Question: []dns.Question{{"kubernetes.default.svc.cluster.local.", dns.TypeA, dns.ClassINET}},
		}
		in, err := dns.Exchange(m1, masterConfig.DNSConfig.BindAddress)
		if err != nil {
			t.Logf("unexpected error: %v", err)
			return
		}
		if len(in.Answer) != 1 {
			t.Logf("unexpected answer: %#v", in)
			return
		}
		if a, ok := in.Answer[0].(*dns.A); ok {
			if a.A == nil {
				t.Fatalf("expected an A record with an IP: %#v", a)
			}
			masterIP = a.A
		} else {
			t.Fatalf("expected an A record: %#v", in)
		}
		t.Log(in)
		close(stop)
	}, 50*time.Millisecond, stop)

	client, err := testutil.GetClusterAdminKubeClient(clientFile)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Verify kubernetes service port is 53 and target port is the
	// configured masterConfig.DNSConfig.BindAddress port.
	_, dnsPortString, err := net.SplitHostPort(masterConfig.DNSConfig.BindAddress)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	dnsPort, err := strconv.Atoi(dnsPortString)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	kubernetesService, err := client.Services(kapi.NamespaceDefault).Get("kubernetes")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	found := false
	for _, port := range kubernetesService.Spec.Ports {
		if port.Port == 53 && port.TargetPort.IntVal == int32(dnsPort) && port.Protocol == kapi.ProtocolTCP {
			found = true
		}
	}
	if !found {
		t.Fatalf("did not find DNS port in kubernetes service: %#v", kubernetesService)
	}

	// Create headless service "headless" (endpoint 172.0.0.1:2345). Service
	// creation is retried every 100ms while it is rejected as Forbidden.
	// Presumably this waits for authorization to become ready — confirm.
	for {
		if _, err := client.Services(kapi.NamespaceDefault).Create(&kapi.Service{
			ObjectMeta: kapi.ObjectMeta{
				Name: "headless",
			},
			Spec: kapi.ServiceSpec{
				ClusterIP: kapi.ClusterIPNone,
				Ports:     []kapi.ServicePort{{Port: 443}},
			},
		}); err != nil {
			if errors.IsForbidden(err) {
				t.Logf("forbidden, sleeping: %v", err)
				time.Sleep(100 * time.Millisecond)
				continue
			}
			t.Fatalf("unexpected error: %v", err)
		}
		if _, err := client.Endpoints(kapi.NamespaceDefault).Create(&kapi.Endpoints{
			ObjectMeta: kapi.ObjectMeta{
				Name: "headless",
			},
			Subsets: []kapi.EndpointSubset{{
				Addresses: []kapi.EndpointAddress{{IP: "172.0.0.1"}},
				Ports: []kapi.EndpointPort{
					{Port: 2345},
				},
			}},
		}); err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		break
	}
	headlessIP := net.ParseIP("172.0.0.1")
	headlessIPHash := getHash(headlessIP.String())

	// Second headless service: one endpoint (172.0.0.2) with two named ports.
	if _, err := client.Services(kapi.NamespaceDefault).Create(&kapi.Service{
		ObjectMeta: kapi.ObjectMeta{
			Name: "headless2",
		},
		Spec: kapi.ServiceSpec{
			ClusterIP: kapi.ClusterIPNone,
			Ports:     []kapi.ServicePort{{Port: 443}},
		},
	}); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if _, err := client.Endpoints(kapi.NamespaceDefault).Create(&kapi.Endpoints{
		ObjectMeta: kapi.ObjectMeta{
			Name: "headless2",
		},
		Subsets: []kapi.EndpointSubset{{
			Addresses: []kapi.EndpointAddress{{IP: "172.0.0.2"}},
			Ports: []kapi.EndpointPort{
				{Port: 2345, Name: "other"},
				{Port: 2346, Name: "http"},
			},
		}},
	}); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	headless2IP := net.ParseIP("172.0.0.2")
	precannedIP := net.ParseIP("10.2.4.50")
	headless2IPHash := getHash(headless2IP.String())

	// Each case queries dnsQuestionName: type SRV when srv is set, otherwise
	// type A. Answers must match expect (A) or srv (SRV). With
	// recursionExpected, any non-empty answer is accepted. The retry field
	// is not read anywhere in this test.
	tests := []struct {
		dnsQuestionName   string
		recursionExpected bool
		retry             bool
		expect            []*net.IP
		srv               []*dns.SRV
	}{
		{ // wildcard resolution of a service works
			dnsQuestionName: "foo.kubernetes.default.svc.cluster.local.",
			expect:          []*net.IP{&masterIP},
		},
		{ // resolving endpoints of a service works
			dnsQuestionName: "_endpoints.kubernetes.default.svc.cluster.local.",
			expect:          []*net.IP{&localIP},
		},
		{ // openshift override works
			dnsQuestionName: "openshift.default.svc.cluster.local.",
			expect:          []*net.IP{&masterIP},
		},
		{ // pod by IP
			dnsQuestionName: "10-2-4-50.default.pod.cluster.local.",
			expect:          []*net.IP{&precannedIP},
		},
		{ // headless service
			dnsQuestionName: "headless.default.svc.cluster.local.",
			expect:          []*net.IP{&headlessIP},
		},
		{ // specific port of a headless service
			dnsQuestionName: "unknown-port-2345.e1.headless.default.svc.cluster.local.",
			expect:          []*net.IP{&headlessIP},
		},
		{ // SRV record for that service
			dnsQuestionName: "headless.default.svc.cluster.local.",
			srv: []*dns.SRV{
				{
					Target: headlessIPHash + "._unknown-port-2345._tcp.headless.default.svc.cluster.local.",
					Port:   2345,
				},
			},
		},
		{ // the SRV record resolves to the IP
			dnsQuestionName: "unknown-port-2345.e1.headless.default.svc.cluster.local.",
			expect:          []*net.IP{&headlessIP},
		},
		{ // headless 2 service
			dnsQuestionName: "headless2.default.svc.cluster.local.",
			expect:          []*net.IP{&headless2IP},
		},
		{ // SRV records for that service
			dnsQuestionName: "headless2.default.svc.cluster.local.",
			srv: []*dns.SRV{
				{
					Target: headless2IPHash + "._http._tcp.headless2.default.svc.cluster.local.",
					Port:   2346,
				},
				{
					Target: headless2IPHash + "._other._tcp.headless2.default.svc.cluster.local.",
					Port:   2345,
				},
			},
		},
		{ // the SRV record resolves to the IP
			dnsQuestionName: "other.e1.headless2.default.svc.cluster.local.",
			expect:          []*net.IP{&headless2IP},
		},
		{
			dnsQuestionName:   "www.google.com.",
			recursionExpected: true,
		},
	}
	for i, tc := range tests {
		qType := dns.TypeA
		if tc.srv != nil {
			qType = dns.TypeSRV
		}
		m1 := &dns.Msg{
			MsgHdr:   dns.MsgHdr{Id: dns.Id(), RecursionDesired: tc.recursionExpected},
			Question: []dns.Question{{tc.dnsQuestionName, qType, dns.ClassINET}},
		}
		// Poll every 50ms. The case fails after 100 polls, or after more than
		// 10 exchanges slower than 500ms. Transport errors and answer-count
		// mismatches just trigger another poll.
		ch := make(chan struct{})
		count := 0
		failedLatency := 0
		waitutil.Until(func() {
			count++
			if count > 100 {
				t.Errorf("%d: failed after max iterations", i)
				close(ch)
				return
			}
			before := time.Now()
			in, err := dns.Exchange(m1, masterConfig.DNSConfig.BindAddress)
			if err != nil {
				return
			}
			after := time.Now()
			delta := after.Sub(before)
			if delta > 500*time.Millisecond {
				failedLatency++
				if failedLatency > 10 {
					t.Errorf("%d: failed after 10 requests took longer than 500ms", i)
					close(ch)
				}
				return
			}
			switch {
			case tc.srv != nil:
				if len(in.Answer) != len(tc.srv) {
					t.Logf("%d: incorrect number of answers: %#v", i, in)
					return
				}
			case tc.recursionExpected:
				if len(in.Answer) == 0 {
					t.Errorf("%d: expected forward resolution: %#v", i, in)
				}
				close(ch)
				return
			default:
				if len(in.Answer) != len(tc.expect) {
					t.Logf("%d: did not resolve or unexpected forward resolution: %#v", i, in)
					return
				}
			}
			// Every answer must match one of the expected A addresses or
			// SRV (port, target) pairs.
			for _, answer := range in.Answer {
				switch a := answer.(type) {
				case *dns.A:
					matches := false
					if a.A != nil {
						for _, expect := range tc.expect {
							if a.A.String() == expect.String() {
								matches = true
								break
							}
						}
					}
					if !matches {
						t.Errorf("%d: A record does not match any expected answer for %q: %v", i, tc.dnsQuestionName, a.A)
					}
				case *dns.SRV:
					matches := false
					for _, expect := range tc.srv {
						if expect.Port == a.Port && expect.Target == a.Target {
							matches = true
							break
						}
					}
					if !matches {
						t.Errorf("%d: SRV record does not match any expected answer %q: %#v", i, tc.dnsQuestionName, a)
					}
				default:
					t.Errorf("%d: expected an A or SRV record %q: %#v", i, tc.dnsQuestionName, in)
				}
			}
			t.Log(in)
			close(ch)
		}, 50*time.Millisecond, ch)
	}
}
func TestRecursiveCompress(t *testing.T) { const ( hostname = "foo.example." maxSize = 512 ) // Construct a response that is >512 when uncompressed, <512 when compressed response := dns.Msg{} response.Authoritative = true response.Answer = []dns.RR{} header := dns.RR_Header{ Name: hostname, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 10, } for response.Len() <= maxSize { ip := address.Address(rand.Uint32()).IP4() response.Answer = append(response.Answer, &dns.A{Hdr: header, A: ip}) } response.Compress = true require.True(t, response.Len() <= maxSize) // A dns server that returns the above response var gotRequest = make(chan struct{}, 1) handleRecursive := func(w dns.ResponseWriter, req *dns.Msg) { gotRequest <- struct{}{} require.Equal(t, req.Question[0].Name, hostname) response.SetReply(req) err := w.WriteMsg(&response) require.Nil(t, err) } mux := dns.NewServeMux() mux.HandleFunc(topDomain, handleRecursive) udpListener, err := net.ListenPacket("udp", "0.0.0.0:0") require.Nil(t, err) udpServer := &dns.Server{PacketConn: udpListener, Handler: mux} udpServerPort := udpListener.LocalAddr().(*net.UDPAddr).Port go udpServer.ActivateAndServe() defer udpServer.Shutdown() // The weavedns server, pointed at the above server dnsserver, _, udpPort, _ := startServer(t, &dns.ClientConfig{ Servers: []string{"127.0.0.1"}, Port: strconv.Itoa(udpServerPort), Ndots: 1, Timeout: 5, Attempts: 2, }) defer dnsserver.Stop() // Now do lookup, check its what we expected. // NB this doesn't really test golang's resolver behaves correctly, as I can't see // a way to point golangs resolver at a specific hosts. req := new(dns.Msg) req.Id = dns.Id() req.RecursionDesired = true req.Question = make([]dns.Question, 1) req.Question[0] = dns.Question{ Name: hostname, Qtype: dns.TypeA, Qclass: dns.ClassINET, } c := new(dns.Client) res, _, err := c.Exchange(req, fmt.Sprintf("127.0.0.1:%d", udpPort)) require.Nil(t, err) require.True(t, len(gotRequest) > 0) require.True(t, res.Len() > maxSize) }
// dedup will de-duplicate a message on a per section basis. // Multiple identical (same name, class, type and rdata) RRs will be coalesced into one. func (s *server) dedup(m *dns.Msg) *dns.Msg { // Answer section ma := make(map[string]dns.RR) for _, a := range m.Answer { // Or use Pack()... Think this function also could be placed in go dns. s1 := a.Header().Name s1 += strconv.Itoa(int(a.Header().Class)) s1 += strconv.Itoa(int(a.Header().Rrtype)) // there can only be one CNAME for an ownername if a.Header().Rrtype == dns.TypeCNAME { if _, ok := ma[s1]; ok { // already exist, randomly overwrite if roundrobin is true // Note: even with roundrobin *off* this depends on the // order we get the names. if s.config.RoundRobin && dns.Id()%2 == 0 { ma[s1] = a continue } } ma[s1] = a continue } for i := 1; i <= dns.NumField(a); i++ { s1 += dns.Field(a, i) } ma[s1] = a } // Only is our map is smaller than the #RR in the answer section we should reset the RRs // in the section it self if len(ma) < len(m.Answer) { i := 0 for _, v := range ma { m.Answer[i] = v i++ } m.Answer = m.Answer[:len(ma)] } // Additional section me := make(map[string]dns.RR) for _, e := range m.Extra { s1 := e.Header().Name s1 += strconv.Itoa(int(e.Header().Class)) s1 += strconv.Itoa(int(e.Header().Rrtype)) // there can only be one CNAME for an ownername if e.Header().Rrtype == dns.TypeCNAME { if _, ok := me[s1]; ok { // already exist, randomly overwrite if roundrobin is true if s.config.RoundRobin && dns.Id()%2 == 0 { me[s1] = e continue } } me[s1] = e continue } for i := 1; i <= dns.NumField(e); i++ { s1 += dns.Field(e, i) } me[s1] = e } if len(me) < len(m.Extra) { i := 0 for _, v := range me { m.Extra[i] = v i++ } m.Extra = m.Extra[:len(me)] } return m }
// ServeDNS is the handler for DNS requests, responsible for parsing DNS request, possibly forwarding // it to a real dns server and returning a response. func (s *server) ServeDNS(w dns.ResponseWriter, req *dns.Msg) { m := new(dns.Msg) m.SetReply(req) m.Authoritative = true m.RecursionAvailable = true m.Compress = true bufsize := uint16(512) dnssec := false tcp := false start := time.Now() if req.Question[0].Qtype == dns.TypeANY { m.Authoritative = false m.Rcode = dns.RcodeRefused m.RecursionAvailable = false m.RecursionDesired = false m.Compress = false // if write fails don't care w.WriteMsg(m) promErrorCount.WithLabelValues("refused").Inc() return } if o := req.IsEdns0(); o != nil { bufsize = o.UDPSize() dnssec = o.Do() } if bufsize < 512 { bufsize = 512 } // with TCP we can send 64K if tcp = isTCP(w); tcp { bufsize = dns.MaxMsgSize - 1 promRequestCount.WithLabelValues("tcp").Inc() } else { promRequestCount.WithLabelValues("udp").Inc() } StatsRequestCount.Inc(1) if dnssec { StatsDnssecOkCount.Inc(1) promDnssecOkCount.Inc() } defer func() { promCacheSize.WithLabelValues("response").Set(float64(s.rcache.Size())) }() // Check cache first. key := cache.QuestionKey(req.Question[0], dnssec) m1, exp, hit := s.rcache.Search(key) if hit { // Cache hit! \o/ if time.Since(exp) < 0 { m1.Id = m.Id m1.Compress = true m1.Truncated = false if dnssec { // The key for DNS/DNSSEC in cache is different, no // need to do Denial/Sign here. //if s.config.PubKey != nil { //s.Denial(m1) // not needed for cache hits //s.Sign(m1, bufsize) //} } if m1.Len() > int(bufsize) && !tcp { promErrorCount.WithLabelValues("truncated").Inc() m1.Truncated = true } // Still round-robin even with hits from the cache. // Only shuffle A and AAAA records with each other. 
if req.Question[0].Qtype == dns.TypeA || req.Question[0].Qtype == dns.TypeAAAA { s.RoundRobin(m1.Answer) } if err := w.WriteMsg(m1); err != nil { log.Printf("skydns: failure to return reply %q", err) } metricSizeAndDuration(m1, start, tcp) return } // Expired! /o\ s.rcache.Remove(key) } q := req.Question[0] name := strings.ToLower(q.Name) if s.config.Verbose { log.Printf("skydns: received DNS Request for %q from %q with type %d", q.Name, w.RemoteAddr(), q.Qtype) } for zone, ns := range *s.config.stub { if strings.HasSuffix(name, zone) { resp := s.ServeDNSStubForward(w, req, ns) metricSizeAndDuration(resp, start, tcp) return } } // If the qname is local.dns.skydns.local. and s.config.Local != "", substitute that name. if s.config.Local != "" && name == s.config.localDomain { name = s.config.Local } if q.Qtype == dns.TypePTR && strings.HasSuffix(name, ".in-addr.arpa.") || strings.HasSuffix(name, ".ip6.arpa.") { resp := s.ServeDNSReverse(w, req) metricSizeAndDuration(resp, start, tcp) return } if q.Qclass != dns.ClassCHAOS && !strings.HasSuffix(name, s.config.Domain) { if s.config.Verbose { log.Printf("skydns: %q is not sub of %q, forwarding...", name, s.config.Domain) } resp := s.ServeDNSForward(w, req) metricSizeAndDuration(resp, start, tcp) return } promCacheMiss.WithLabelValues("response").Inc() defer func() { if m.Rcode == dns.RcodeServerFailure { if err := w.WriteMsg(m); err != nil { log.Printf("skydns: failure to return reply %q", err) } return } // Set TTL to the minimum of the RRset and dedup the message, i.e. // remove identical RRs. 
m = s.dedup(m) minttl := s.config.Ttl if len(m.Answer) > 1 { for _, r := range m.Answer { if r.Header().Ttl < minttl { minttl = r.Header().Ttl } } for _, r := range m.Answer { r.Header().Ttl = minttl } } if !m.Truncated { s.rcache.InsertMessage(cache.QuestionKey(req.Question[0], dnssec), m) } if dnssec { if s.config.PubKey != nil { m.AuthenticatedData = true s.Denial(m) s.Sign(m, bufsize) } } if m.Len() > dns.MaxMsgSize { log.Printf("skydns: overflowing maximum message size: %d, dropping additional section", m.Len()) m.Extra = nil // Drop entire additional section to see if this helps. if m.Len() > dns.MaxMsgSize { // *Still* too large. log.Printf("skydns: still overflowing maximum message size: %d", m.Len()) promErrorCount.WithLabelValues("overflow").Inc() m1 := new(dns.Msg) // Use smaller msg to signal failure. m1.SetRcode(m, dns.RcodeServerFailure) if err := w.WriteMsg(m1); err != nil { log.Printf("skydns: failure to return reply %q", err) } metricSizeAndDuration(m1, start, tcp) return } } if m.Len() > int(bufsize) && !tcp { m.Extra = nil // As above, drop entire additional section. if m.Len() > int(bufsize) { promErrorCount.WithLabelValues("truncated").Inc() m.Truncated = true } } if err := w.WriteMsg(m); err != nil { log.Printf("skydns: failure to return reply %q %d", err, m.Len()) } metricSizeAndDuration(m, start, tcp) }() if name == s.config.Domain { if q.Qtype == dns.TypeSOA { m.Answer = []dns.RR{s.NewSOA()} return } if q.Qtype == dns.TypeDNSKEY { if s.config.PubKey != nil { m.Answer = []dns.RR{s.config.PubKey} return } } } if q.Qclass == dns.ClassCHAOS { if q.Qtype == dns.TypeTXT { switch name { case "authors.bind.": fallthrough case s.config.Domain: hdr := dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassCHAOS, Ttl: 0} authors := []string{"Erik St. 
Martin", "Brian Ketelsen", "Miek Gieben", "Michael Crosby"} for _, a := range authors { m.Answer = append(m.Answer, &dns.TXT{Hdr: hdr, Txt: []string{a}}) } for j := 0; j < len(authors)*(int(dns.Id())%4+1); j++ { q := int(dns.Id()) % len(authors) p := int(dns.Id()) % len(authors) if q == p { p = (p + 1) % len(authors) } m.Answer[q], m.Answer[p] = m.Answer[p], m.Answer[q] } return case "version.bind.": fallthrough case "version.server.": hdr := dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassCHAOS, Ttl: 0} m.Answer = []dns.RR{&dns.TXT{Hdr: hdr, Txt: []string{Version}}} return case "hostname.bind.": fallthrough case "id.server.": // TODO(miek): machine name to return hdr := dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassCHAOS, Ttl: 0} m.Answer = []dns.RR{&dns.TXT{Hdr: hdr, Txt: []string{"localhost"}}} return } } // still here, fail m.SetReply(req) m.SetRcode(req, dns.RcodeServerFailure) return } switch q.Qtype { case dns.TypeNS: if name != s.config.Domain { log.Printf("skydns: %q unmatch default domain", name) break } // Lookup s.config.DnsDomain records, extra, err := s.NSRecords(q, s.config.dnsDomain) if err != nil { if e, ok := err.(*etcd.EtcdError); ok { if e.ErrorCode == 100 { s.NameError(m, req) return } } } m.Answer = append(m.Answer, records...) m.Extra = append(m.Extra, extra...) case dns.TypeA, dns.TypeAAAA: records, err := s.AddressRecords(q, name, nil, bufsize, dnssec, false) if err != nil { if e, ok := err.(*etcd.EtcdError); ok { if e.ErrorCode == 100 { s.NameError(m, req) return } } } m.Answer = append(m.Answer, records...) case dns.TypeTXT: records, err := s.TXTRecords(q, name) if err != nil { if e, ok := err.(*etcd.EtcdError); ok { if e.ErrorCode == 100 { s.NameError(m, req) return } } } m.Answer = append(m.Answer, records...) 
case dns.TypeCNAME: records, err := s.CNAMERecords(q, name) if err != nil { if e, ok := err.(*etcd.EtcdError); ok { if e.ErrorCode == 100 { s.NameError(m, req) return } } } m.Answer = append(m.Answer, records...) case dns.TypeMX: records, extra, err := s.MXRecords(q, name, bufsize, dnssec) if err != nil { if e, ok := err.(*etcd.EtcdError); ok { if e.ErrorCode == 100 { s.NameError(m, req) return } } } m.Answer = append(m.Answer, records...) m.Extra = append(m.Extra, extra...) default: fallthrough // also catch other types, so that they return NODATA case dns.TypeSRV: records, extra, err := s.SRVRecords(q, name, bufsize, dnssec) if err != nil { if e, ok := err.(*etcd.EtcdError); ok { if e.ErrorCode == 100 { s.NameError(m, req) return } } if q.Qtype == dns.TypeSRV { // Otherwise NODATA s.ServerFailure(m, req) return } } // if we are here again, check the types, because an answer may only // be given for SRV. All other types should return NODATA, the // NXDOMAIN part is handled in the above code. TODO(miek): yes this // can be done in a more elegant manor. if q.Qtype == dns.TypeSRV { m.Answer = append(m.Answer, records...) m.Extra = append(m.Extra, extra...) } } if len(m.Answer) == 0 { // NODATA response StatsNoDataCount.Inc(1) m.Ns = []dns.RR{s.NewSOA()} m.Ns[0].Header().Ttl = s.config.MinTtl } }