func TestDNSTtlRR(t *testing.T) { s := newTestServerDNSSEC(t, false) defer s.Stop() serv := &msg.Service{Host: "10.0.0.2", Key: "ttl.skydns.test.", Ttl: 360} addService(t, s, serv.Key, time.Duration(serv.Ttl)*time.Second, serv) defer delService(t, s, serv.Key) c := new(dns.Client) tc := dnsTestCases[9] // TTL Test t.Logf("%v\n", tc) m := new(dns.Msg) m.SetQuestion(tc.Qname, tc.Qtype) if tc.dnssec == true { m.SetEdns0(4096, true) } resp, _, err := c.Exchange(m, "127.0.0.1:"+StrPort) if err != nil { t.Errorf("failing: %s: %s\n", m.String(), err.Error()) } t.Logf("%s\n", resp) for i, a := range resp.Answer { if a.Header().Ttl != 360 { t.Errorf("Answer %d should have a Header TTL of %d, but has %d", i, 360, a.Header().Ttl) } } }
func TestDNSTtlRRset(t *testing.T) { s := newTestServerDNSSEC(t, false) defer s.Stop() ttl := uint32(60) for _, serv := range services { addService(t, s, serv.Key, uint64(ttl), serv) defer delService(t, s, serv.Key) ttl += 60 } c := new(dns.Client) tc := dnsTestCases[9] t.Logf("%v\n", tc) m := new(dns.Msg) m.SetQuestion(tc.Qname, tc.Qtype) if tc.dnssec == true { m.SetEdns0(4096, true) } resp, _, err := c.Exchange(m, "127.0.0.1:"+StrPort) if err != nil { t.Fatalf("failing: %s: %s\n", m.String(), err.Error()) } t.Logf("%s\n", resp) ttl = 360 for i, a := range resp.Answer { if a.Header().Ttl != ttl { t.Errorf("Answer %d should have a Header TTL of %d, but has %d", i, ttl, a.Header().Ttl) } } }
// perform a DNS query and assert the reply code, number or answers, etc func assertExchange(t *testing.T, z string, ty uint16, minAnswers int, maxAnswers int, expErr int) *dns.Msg { c := new(dns.Client) c.UDPSize = testUDPBufSize m := new(dns.Msg) m.RecursionDesired = true m.SetQuestion(z, ty) m.SetEdns0(testUDPBufSize, false) // we don't want to play with truncation here... r, _, err := c.Exchange(m, fmt.Sprintf("127.0.0.1:%d", testPort)) t.Logf("Response:\n%+v\n", r) wt.AssertNoErr(t, err) if minAnswers == 0 && maxAnswers == 0 { wt.AssertStatus(t, r.Rcode, expErr, "DNS response code") } else { wt.AssertStatus(t, r.Rcode, dns.RcodeSuccess, "DNS response code") } answers := len(r.Answer) if minAnswers >= 0 && answers < minAnswers { wt.Fatalf(t, "Number of answers >= %d", minAnswers) } if maxAnswers >= 0 && answers > maxAnswers { wt.Fatalf(t, "Number of answers <= %d", maxAnswers) } return r }
func TestInFlightEDns0(t *T) { m1 := new(dns.Msg) m1.SetQuestion(testAnyDomain, dns.TypeA) m1.SetEdns0(4096, false) w1 := getWriter() m2 := new(dns.Msg) m2.SetQuestion(testAnyDomain, dns.TypeA) w2 := getWriter() go func() { handleRequest(w1, m1) }() go func() { handleRequest(w2, m2) }() var r1 *dns.Msg var r2 *dns.Msg for r1 == nil || r2 == nil { select { case r1 = <-w1.ReplyCh: case r2 = <-w2.ReplyCh: } } //note: this test could be flaky since we're relying on google to return //edns0 response when we send one vs when we don't send one assert.NotNil(t, r1.IsEdns0()) assert.Nil(t, r2.IsEdns0()) }
// Perform a DNS query and assert the reply code, number or answers, etc func assertExchange(t *testing.T, z string, ty uint16, port int, minAnswers int, maxAnswers int, expErr int) (*dns.Msg, *dns.Msg) { require.NotEqual(t, 0, port, "invalid DNS server port") c := &dns.Client{ UDPSize: testUDPBufSize, } m := new(dns.Msg) m.RecursionDesired = true m.SetQuestion(z, ty) m.SetEdns0(testUDPBufSize, false) // we don't want to play with truncation here... lstAddr := fmt.Sprintf("127.0.0.1:%d", port) r, _, err := c.Exchange(m, lstAddr) t.Logf("Response from '%s':\n%+v\n", lstAddr, r) if err != nil { t.Errorf("Error when querying DNS server at %s: %s", lstAddr, err) } require.NoError(t, err) if minAnswers == 0 && maxAnswers == 0 { require.Equal(t, expErr, r.Rcode, "DNS response code") } else { require.Equal(t, dns.RcodeSuccess, r.Rcode, "DNS response code") } answers := len(r.Answer) if minAnswers >= 0 && answers < minAnswers { require.FailNow(t, fmt.Sprintf("Number of answers >= %d", minAnswers)) } if maxAnswers >= 0 && answers > maxAnswers { require.FailNow(t, fmt.Sprintf("Number of answers <= %d", maxAnswers)) } return m, r }
// dnsQuery will query a nameserver, iterating through the supplied servers as it retries // The nameserver should include a port, to facilitate testing where we talk to a mock dns server. func dnsQuery(fqdn string, rtype uint16, nameservers []string, recursive bool) (in *dns.Msg, err error) { m := new(dns.Msg) m.SetQuestion(fqdn, rtype) m.SetEdns0(4096, false) if !recursive { m.RecursionDesired = false } // Will retry the request based on the number of servers (n+1) for i := 1; i <= len(nameservers)+1; i++ { ns := nameservers[i%len(nameservers)] udp := &dns.Client{Net: "udp", Timeout: DNSTimeout} in, _, err = udp.Exchange(m, ns) if err == dns.ErrTruncated { tcp := &dns.Client{Net: "tcp", Timeout: DNSTimeout} // If the TCP request suceeds, the err will reset to nil in, _, err = tcp.Exchange(m, ns) } if err == nil { break } } return }
// Send a DNS query via UDP, configured by a Request object. If successful, // stores response details in Result object, otherwise, returns Result object // with an error string. func SendQuery(request *Request) (result Result, err error) { log.Printf("Sending query: %s", request) result.Request = *request record_type, ok := dns.StringToType[request.RecordType] if !ok { result.Error = fmt.Sprintf("Invalid type: %s", request.RecordType) return result, errors.New(result.Error) } m := new(dns.Msg) if request.VerifySignature == true { log.Printf("SetEdns0 for %s", request.RecordName) m.SetEdns0(4096, true) } m.SetQuestion(request.RecordName, record_type) c := new(dns.Client) in, rtt, err := c.Exchange(m, request.Destination) // log.Printf("Answer: %s [%d] %s", in, rtt, err) result.Duration = rtt if err != nil { result.Error = err.Error() } else { for _, rr := range in.Answer { answer := Answer{ Ttl: rr.Header().Ttl, Name: rr.Header().Name, String: rr.String(), } result.Answers = append(result.Answers, answer) } } return result, nil }
func searchServerIP(domain string, version int, DNSservers []string) (answer *dns.Msg, err error) { DNSserver := DNSservers[rand.Intn(len(DNSservers))] for i := 1; i <= 3; i++ { if DNSserver == "" { DNSserver = DNSservers[rand.Intn(len(DNSservers))] } } if DNSserver == "" { return nil, errors.New("DNSserver is an empty string") } dnsRequest := new(dns.Msg) if dnsRequest == nil { return nil, errors.New("Can not new dnsRequest") } dnsClient := new(dns.Client) if dnsClient == nil { return nil, errors.New("Can not new dnsClient") } if version == 4 { dnsRequest.SetQuestion(domain+".", dns.TypeA) } else if version == 6 { dnsRequest.SetQuestion(domain+".", dns.TypeAAAA) } else { return nil, errors.New("wrong parameter in version") } dnsRequest.SetEdns0(4096, true) answer, _, err = dnsClient.Exchange(dnsRequest, DNSserver) if err != nil { return nil, err } return answer, nil }
func TestDNSExpire(t *testing.T) { s := newTestServerDNSSEC(t) defer s.Stop() serv := services[0] addService(t, s, serv.key, 1, &Service{Host: serv.Host, Port: serv.Port}) // defer delService(t, s, serv.key) // It will delete itself...magically c := new(dns.Client) tc := dnsTestCases[0] m := new(dns.Msg) m.SetQuestion(tc.Qname, tc.Qtype) if tc.dnssec == true { m.SetEdns0(4096, true) } resp, _, err := c.Exchange(m, "127.0.0.1:"+StrPort) if err != nil { t.Fatalf("failing: %s: %s\n", m.String(), err.Error()) } if resp.Rcode != dns.RcodeSuccess { t.Logf("%v\n", resp) t.Fail() } // Sleep to let it expire. time.Sleep(2 * time.Second) resp, _, err = c.Exchange(m, "127.0.0.1:"+StrPort) if err != nil { t.Errorf("failing: %s\n", err.Error()) } if resp.Rcode != dns.RcodeNameError { t.Logf("%v\n", resp) t.Fail() } }
func localQuery(mychan chan DNSreply, qname string, qtype uint16) { var result DNSreply var trials uint result.qname = qname result.qtype = qtype result.r = nil result.err = errors.New("No name server to answer the question") localm := new(dns.Msg) localm.Id = dns.Id() localm.RecursionDesired = true localm.Question = make([]dns.Question, 1) localm.SetEdns0(EDNSBUFFERSIZE, false) // Even if no EDNS requested, see #9 May be we should retry without it if timeout? localc := new(dns.Client) localc.ReadTimeout = timeout localm.Question[0] = dns.Question{qname, qtype, dns.ClassINET} Tests: for trials = 0; trials < uint(*maxTrials); trials++ { Resolvers: for serverIndex := range conf.Servers { server := conf.Servers[serverIndex] result.nameserver = server // Brackets around the server address are necessary for IPv6 name servers r, rtt, err := localc.Exchange(localm, "["+server+"]:"+conf.Port) // Do not use net.JoinHostPort, see https://github.com/bortzmeyer/check-soa/commit/3e4edb13855d8c4016768796b2892aa83eda1933#commitcomment-2355543 if r == nil { result.r = nil result.err = err if strings.Contains(err.Error(), "timeout") { // Try another resolver break Resolvers } else { // We give in break Tests } } else { result.rtt = rtt if r.Rcode == dns.RcodeSuccess { // TODO: as a result, NODATA (NOERROR/ANSWER=0) are silently ignored (try "foo", for instance, the name exists but no IP address) // TODO: for rcodes like SERVFAIL, trying another resolver could make sense result.r = r result.err = nil break Tests } else { // All the other codes are errors. Yes, it may // happens that one resolver returns REFUSED // and the others work but we do not handle // this case. TODO: delete the resolver from // the list and try another one result.r = r result.err = errors.New(dns.RcodeToString[r.Rcode]) break Tests } } } } if *debug { fmt.Printf("DEBUG: end of DNS request \"%s\" / %d\n", qname, qtype) } mychan <- result }
// exchangeOne performs a single DNS exchange with a randomly chosen server // out of the server list, returning the response, time, and error (if any). // This method sets the DNSSEC OK bit on the message to true before sending // it to the resolver in case validation isn't the resolvers default behaviour. func (dnsResolver *DNSResolverImpl) exchangeOne(ctx context.Context, hostname string, qtype uint16, msgStats metrics.Scope) (*dns.Msg, error) { m := new(dns.Msg) // Set question type m.SetQuestion(dns.Fqdn(hostname), qtype) // Set DNSSEC OK bit for resolver m.SetEdns0(4096, true) if len(dnsResolver.servers) < 1 { return nil, fmt.Errorf("Not configured with at least one DNS Server") } dnsResolver.stats.Inc("Rate", 1) // Randomly pick a server chosenServer := dnsResolver.servers[rand.Intn(len(dnsResolver.servers))] client := dnsResolver.dnsClient tries := 1 start := dnsResolver.clk.Now() msgStats.Inc("Calls", 1) defer func() { msgStats.TimingDuration("Latency", dnsResolver.clk.Now().Sub(start)) }() for { msgStats.Inc("Tries", 1) ch := make(chan dnsResp, 1) go func() { rsp, rtt, err := client.Exchange(m, chosenServer) msgStats.TimingDuration("SingleTryLatency", rtt) ch <- dnsResp{m: rsp, err: err} }() select { case <-ctx.Done(): msgStats.Inc("Cancels", 1) msgStats.Inc("Errors", 1) return nil, ctx.Err() case r := <-ch: if r.err != nil { msgStats.Inc("Errors", 1) operr, ok := r.err.(*net.OpError) isRetryable := ok && operr.Temporary() hasRetriesLeft := tries < dnsResolver.maxTries if isRetryable && hasRetriesLeft { tries++ continue } else if isRetryable && !hasRetriesLeft { msgStats.Inc("RanOutOfTries", 1) } } else { msgStats.Inc("Successes", 1) } return r.m, r.err } } }
// Create skeleton edns opt RR from the query and // add it to the message m func ednsFromRequest(req, m *dns.Msg) { for _, r := range req.Extra { if r.Header().Rrtype == dns.TypeOPT { m.SetEdns0(4096, r.(*dns.OPT).Do()) return } } return }
func (self *TrivialDnsServer) refuseWithCode(w dns.ResponseWriter, req *dns.Msg, code int) { m := new(dns.Msg) for _, r := range req.Extra { if r.Header().Rrtype == dns.TypeOPT { m.SetEdns0(4096, r.(*dns.OPT).Do()) } } m.SetRcode(req, code) w.WriteMsg(m) }
func testLookupSRV(t *testing.T) { if !checkKubernetesRunning() { t.Skip("Skipping Kubernetes Integration tests. Kubernetes is not running") } // Note: Use different port to avoid conflict with servers used in other tests. coreFile := `.:2054 { kubernetes coredns.local { endpoint http://localhost:8080 namespaces demo } ` server, _, udp, err := Server(t, coreFile) if err != nil { t.Fatal("Could not get server: %s", err) } defer server.Stop() log.SetOutput(ioutil.Discard) // TODO: Add checks for A records in additional section for _, testData := range testdataLookupSRV { t.Logf("[log] Testing query string: '%v'\n", testData.Query) dnsClient := new(dns.Client) dnsMessage := new(dns.Msg) dnsMessage.SetQuestion(testData.Query, dns.TypeSRV) dnsMessage.SetEdns0(4096, true) res, _, err := dnsClient.Exchange(dnsMessage, udp) if err != nil { t.Fatal("Could not send query: %s", err) } // Count SRV records in the answer section srvRecordCount := 0 for _, a := range res.Answer { fmt.Printf("RR: %v\n", a) if a.Header().Rrtype == dns.TypeSRV { srvRecordCount++ } } if srvRecordCount != testData.SRVRecordCount { t.Errorf("Expected '%v' SRV records in response. Instead got '%v' SRV records. Test query string: '%v'", testData.SRVRecordCount, srvRecordCount, testData.Query) } if len(res.Answer) != testData.TotalAnswerCount { t.Errorf("Expected '%v' records in answer section. Instead got '%v' records in answer section. Test query string: '%v'", testData.TotalAnswerCount, len(res.Answer), testData.Query) } } }
func resolve(req *dns.Msg, dnssec bool) (*dns.Msg, error) { extra2 := []dns.RR{} for _, extra := range req.Extra { if extra.Header().Rrtype != dns.TypeOPT { extra2 = append(extra2, extra) } } req.Extra = extra2 req.SetEdns0(dns.DefaultMsgSize, dnssec) resolved, _, err := resolveViaResolverThreads(req) if err != nil { return nil, err } resolved.Compress = true return resolved, nil }
// dnsQuery sends a DNS query to the given nameserver. // The nameserver should include a port, to facilitate testing where we talk to a mock dns server. func dnsQuery(fqdn string, rtype uint16, nameserver string, recursive bool) (in *dns.Msg, err error) { m := new(dns.Msg) m.SetQuestion(fqdn, rtype) m.SetEdns0(4096, false) if !recursive { m.RecursionDesired = false } in, err = dns.Exchange(m, nameserver) if err == dns.ErrTruncated { tcp := &dns.Client{Net: "tcp"} in, _, err = tcp.Exchange(m, nameserver) } return }
func TestLookupBalanceRewriteCacheDnssec(t *testing.T) { name, rm, err := test.TempFile(t, ".", exampleOrg) if err != nil { t.Fatalf("failed to created zone: %s", err) } defer rm() rm1 := createKeyFile(t) defer rm1() corefile := `example.org:0 { file ` + name + ` rewrite ANY HINFO dnssec { key file ` + base + ` } loadbalance } ` ex, _, udp, err := Server(t, corefile) if err != nil { t.Errorf("Could get server to start: %s", err) return } defer ex.Stop() log.SetOutput(ioutil.Discard) c := new(dns.Client) m := new(dns.Msg) m.SetQuestion("example.org.", dns.TypeA) m.SetEdns0(4096, true) res, _, err := c.Exchange(m, udp) if err != nil { t.Fatalf("Could not send query: %s", err) } sig := 0 for _, a := range res.Answer { if a.Header().Rrtype == dns.TypeRRSIG { sig++ } } if sig == 0 { t.Errorf("expected RRSIGs, got none") t.Logf("%v\n", res) } }
func (h *Handle) ServeDNS(w dns.ResponseWriter, req *dns.Msg) { //fmt.Printf("Req %+v\n", req) // Blacklist lookup domain := req.Question[0].String() if strings.HasSuffix(domain, "A") { domain = domain[1:strings.LastIndex(domain, ".")] fmt.Printf("LOOKUP=%s\n", domain) if _, ok := Adlist[domain]; ok { // todo: now what? fmt.Printf("DROP=%s\n", domain) m := new(dns.Msg) for _, r := range req.Extra { if r.Header().Rrtype == dns.TypeOPT { m.SetEdns0(4096, r.(*dns.OPT).Do()) } } m.SetRcode(req, dns.RcodeRefused) w.WriteMsg(m) return } } // Forward c := new(dns.Client) res, rtt, err := c.Exchange(req, "8.8.8.8:53") if err != nil { fmt.Printf("Lookup fail %s", err.Error()) m := new(dns.Msg) for _, r := range req.Extra { if r.Header().Rrtype == dns.TypeOPT { m.SetEdns0(4096, r.(*dns.OPT).Do()) } } m.SetRcode(req, dns.RcodeRefused) w.WriteMsg(m) return } fmt.Printf("%s: request took %s\n", w.RemoteAddr(), rtt) w.WriteMsg(res) }
func main() { if len(os.Args) != 2 { fmt.Printf("%s NAME\n", os.Args[0]) os.Exit(1) } name := os.Args[1] conf, _ := dns.ClientConfigFromFile("/etc/resolv.conf") client := new(dns.Client) message := new(dns.Msg) message.Question = make([]dns.Question, 1) message.Question[0] = dns.Question{BASE, dns.TypeTXT, dns.ClassINET} message.SetEdns0(4096, true) message.RecursionDesired = true reply, _, err := client.Exchange(message, conf.Servers[0]+":"+conf.Port) if err != nil { fmt.Printf("Cannot get info for %s: %s\n", BASE, err) os.Exit(1) } if reply.Rcode != dns.RcodeSuccess { fmt.Printf("Bad answer from the resolver: %v\n", reply.Rcode) os.Exit(1) } if len(reply.Answer) == 0 { fmt.Printf("Zero answer for %s\n", BASE) os.Exit(1) } toReporter := make(chan instanceQuery) urls := 0 for _, rr := range reply.Answer { switch rr.(type) { case *dns.TXT: url := rr.(*dns.TXT).Txt[0] urls++ go queryOne(toReporter, url, name) } // Otherwise, ignore it. Probably a DNSSEC signature } fromReporter := make(chan string) go reporter(toReporter, fromReporter, urls) <-fromReporter }
// Lookup looks up name,type using the recursive nameserver defines // in the server's config. If none defined it returns an error. func (s *server) Lookup(n string, t, bufsize uint16, dnssec bool) (*dns.Msg, error) { StatsLookupCount.Inc(1) promExternalRequestCount.WithLabelValues("lookup").Inc() if len(s.config.Nameservers) == 0 { return nil, fmt.Errorf("no nameservers configured can not lookup name") } if dns.CountLabel(n) < s.config.Ndots { return nil, fmt.Errorf("name has fewer than %d labels", s.config.Ndots) } m := new(dns.Msg) m.SetQuestion(n, t) m.SetEdns0(bufsize, dnssec) nsid := int(m.Id) % len(s.config.Nameservers) try := 0 Redo: r, _, err := s.dnsUDPclient.Exchange(m, s.config.Nameservers[nsid]) if err == nil { if r.Rcode != dns.RcodeSuccess { return nil, fmt.Errorf("rcode is not equal to success") } // Reset TTLs to rcache TTL to make some of the other code // and the tests not care about TTLs for _, rr := range r.Answer { rr.Header().Ttl = uint32(s.config.RCacheTtl) } for _, rr := range r.Extra { rr.Header().Ttl = uint32(s.config.RCacheTtl) } return r, nil } // Seen an error, this can only mean, "server not reached", try again // but only if we have not exausted our nameservers. if try < len(s.config.Nameservers) { try++ nsid = (nsid + 1) % len(s.config.Nameservers) goto Redo } return nil, fmt.Errorf("failure to lookup name") }
// Lookup looks up name,type using the recursive nameserver defines // in the server's config. If none defined it returns an error func (s *server) Lookup(n string, t, bufsize uint16, dnssec bool) (*dns.Msg, error) { StatsLookupCount.Inc(1) if len(s.config.Nameservers) == 0 { return nil, fmt.Errorf("no nameservers configured can not lookup name") } if dns.CountLabel(n) < s.config.Ndots { return nil, fmt.Errorf("name has fewer than %d labels", s.config.Ndots) } m := new(dns.Msg) m.SetQuestion(n, t) m.SetEdns0(bufsize, dnssec) c := &dns.Client{Net: "udp", ReadTimeout: 2 * s.config.ReadTimeout, WriteTimeout: 2 * s.config.ReadTimeout} nsid := int(m.Id) % len(s.config.Nameservers) try := 0 Redo: // Move this to use s.udpClient/s.tcpClient code instead of allocating a new client for every query. r, _, err := c.Exchange(m, s.config.Nameservers[nsid]) if err == nil { if r.Rcode != dns.RcodeSuccess { return nil, fmt.Errorf("rcode is not equal to success") } // Reset TTLs to rcache TTL to make some of the other code // and the tests not care about TTLs for _, rr := range r.Answer { rr.Header().Ttl = uint32(s.config.RCacheTtl) } for _, rr := range r.Extra { rr.Header().Ttl = uint32(s.config.RCacheTtl) } return r, nil } // Seen an error, this can only mean, "server not reached", try again // but only if we have not exausted our nameservers. if try < len(s.config.Nameservers) { try++ nsid = (nsid + 1) % len(s.config.Nameservers) goto Redo } return nil, fmt.Errorf("failure to lookup name") }
func BenchmarkDNSSECSingleNoCache(b *testing.B) { b.StopTimer() t := new(testing.T) s := newTestServerDNSSEC(t, false) defer s.Stop() serv := services[0] addService(t, s, serv.Key, 0, serv) defer delService(t, s, serv.Key) c := new(dns.Client) tc := dnsTestCases[0] m := new(dns.Msg) m.SetQuestion(tc.Qname, tc.Qtype) m.SetEdns0(4096, true) b.StartTimer() for i := 0; i < b.N; i++ { c.Exchange(m, "127.0.0.1:"+StrPort) } }
// Get the key from the DNS (uses the local resolver) and return them. // If nothing is found we return nil func getKey(name string, keytag uint16, server string, tcp bool) *dns.DNSKEY { c := new(dns.Client) if tcp { c.Net = "tcp" } m := new(dns.Msg) m.SetQuestion(name, dns.TypeDNSKEY) m.SetEdns0(4096, true) r, _, err := c.Exchange(m, server) if err != nil { return nil } for _, k := range r.Answer { if k1, ok := k.(*dns.DNSKEY); ok { if k1.KeyTag() == keytag { return k1 } } } return nil }
// Retrieve the DNSKEY records of a zone and convert them // to DS records for SHA1, SHA256 and SHA384. func ExampleDS() { config, _ := dns.ClientConfigFromFile("/etc/resolv.conf") c := new(dns.Client) m := new(dns.Msg) zone := "miek.nl" m.SetQuestion(dns.Fqdn(zone), dns.TypeDNSKEY) m.SetEdns0(4096, true) r, _, err := c.Exchange(m, config.Servers[0]+":"+config.Port) if err != nil { return } if r.Rcode != dns.RcodeSuccess { return } for _, k := range r.Answer { if key, ok := k.(*dns.DNSKEY); ok { for _, alg := range []uint8{dns.SHA1, dns.SHA256, dns.SHA384} { fmt.Printf("%s; %d\n", key.ToDS(alg).String(), key.Flags) } } } }
func (t *tester) processName(name string) { r := &result{} msg := new(dns.Msg) msg.SetEdns0(4096, true) msg.SetQuestion(dns.Fqdn(name), dns.TypeA) wg := new(sync.WaitGroup) r.Started = time.Now() wg.Add(1) go func() { defer wg.Done() defer func() { r.NormalTook = time.Since(r.Started) }() resp, _, err := t.client.Exchange(msg, t.normalResolver) if err != nil { r.NormalError = err.Error() } else if resp.Rcode == dns.RcodeServerFailure { r.NormalError = fmt.Sprintf("%s", dns.RcodeToString[resp.Rcode]) } }() wg.Add(1) go func() { defer wg.Done() defer func() { r.TorTook = time.Since(r.Started) }() resp, _, err := t.client.Exchange(msg, t.torResolver) if err != nil { r.TorError = err.Error() } else if resp.Rcode == dns.RcodeServerFailure { r.TorError = fmt.Sprintf("%s", dns.RcodeToString[resp.Rcode]) } }() wg.Wait() if r.NormalError != "" && r.NormalError == r.TorError { return } t.results <- r }
// serve answers a single request req for zone z on w.
//
// It updates the global and per-zone metrics, determines the client address
// (from an EDNS client subnet option if the request carries one with an
// address, otherwise from the remote address of w), asks the zone's targeting
// options for targets, and looks up the label. Unknown labels are answered
// with NXDOMAIN plus the zone SOA, except for the special first labels
// "_status" and "_country". Known labels are answered with the records
// chosen by the label's Picker, or with the SOA in the authority section when
// no records were picked. When a query logger is configured, one entry is
// filled in along the way and written when the function returns.
//
// NOTE(review): req.Question[0] is read without a length check; presumably
// the caller guarantees a question is present — confirm.
func (srv *Server) serve(w dns.ResponseWriter, req *dns.Msg, z *Zone) {

	qname := req.Question[0].Name
	qtype := req.Question[0].Qtype

	var qle *querylog.Entry

	if srv.queryLogger != nil {
		qle = &querylog.Entry{
			Time:   time.Now().UnixNano(),
			Origin: z.Origin,
			Name:   qname,
			Qtype:  qtype,
		}
		// Deferred so the entry is written with whatever the code below
		// has filled in by the time serve returns.
		defer srv.queryLogger.Write(qle)
	}

	logPrintf("[zone %s] incoming %s %s (id %d) from %s\n", z.Origin, qname,
		dns.TypeToString[qtype], req.Id, w.RemoteAddr())

	// Global meter
	metrics.Get("queries").(metrics.Meter).Mark(1)

	// Zone meter
	z.Metrics.Queries.Mark(1)

	logPrintln("Got request", req)

	label := getQuestionName(z, req)

	z.Metrics.LabelStats.Add(label)

	// IP that's talking to us (not EDNS CLIENT SUBNET). It is copied so it
	// does not share memory with the address held by w. It stays nil when
	// the remote address is neither a UDP nor a TCP address.
	var realIP net.IP

	if addr, ok := w.RemoteAddr().(*net.UDPAddr); ok {
		realIP = make(net.IP, len(addr.IP))
		copy(realIP, addr.IP)
	} else if addr, ok := w.RemoteAddr().(*net.TCPAddr); ok {
		realIP = make(net.IP, len(addr.IP))
		copy(realIP, addr.IP)
	}
	if qle != nil {
		qle.RemoteAddr = realIP.String()
	}

	z.Metrics.ClientStats.Add(realIP.String())

	var ip net.IP // EDNS or real IP
	var edns *dns.EDNS0_SUBNET
	var opt_rr *dns.OPT

	// Scan the additional section for an OPT record and its options.
	for _, extra := range req.Extra {

		switch extra.(type) {
		case *dns.OPT:
			for _, o := range extra.(*dns.OPT).Option {
				// Only assigned when the OPT record has at least one option.
				opt_rr = extra.(*dns.OPT)
				switch e := o.(type) {
				case *dns.EDNS0_NSID:
					// do stuff with e.Nsid
				case *dns.EDNS0_SUBNET:
					z.Metrics.EdnsQueries.Mark(1)
					logPrintln("Got edns", e.Address, e.Family, e.SourceNetmask, e.SourceScope)
					if e.Address != nil {
						edns = e
						ip = e.Address

						if qle != nil {
							qle.HasECS = true
							qle.ClientAddr = fmt.Sprintf("%s/%d", ip, e.SourceNetmask)
						}
					}
				}
			}
		}
	}

	if len(ip) == 0 { // no edns subnet
		ip = realIP
		if qle != nil {
			// The prefix length is the full length of the address in bits.
			qle.ClientAddr = fmt.Sprintf("%s/%d", ip, len(ip)*8)
		}
	}

	targets, netmask := z.Options.Targeting.GetTargets(ip)

	if qle != nil {
		qle.Targets = targets
	}

	m := new(dns.Msg)

	if qle != nil {
		// Record the final rcode and answer count on every return path,
		// including the early returns below.
		defer func() {
			qle.Rcode = m.Rcode
			qle.Answers = len(m.Answer)
		}()
	}

	m.SetReply(req)
	// NOTE(review): this inspects the reply m, not the request. Unless
	// SetReply copies the OPT record over from req, this branch never runs —
	// confirm whether req.IsEdns0() was intended. Mind that the subnet
	// handling below appends the request's OPT record to m.Extra itself.
	if e := m.IsEdns0(); e != nil {
		m.SetEdns0(4096, e.Do())
	}
	m.Authoritative = true

	// TODO: set scope to 0 if there are no alternate responses
	if edns != nil {
		if edns.Family != 0 {
			if netmask < 16 {
				netmask = 16
			}
			// edns points into the request's OPT record, so the scope is
			// written there and that same record is attached to the reply.
			edns.SourceScope = uint8(netmask)
			m.Extra = append(m.Extra, opt_rr)
		}
	}

	labels, labelQtype := z.findLabels(label, targets, qTypes{dns.TypeMF, dns.TypeCNAME, qtype})
	if labelQtype == 0 {
		labelQtype = qtype
	}

	if labels == nil {

		// Debug labels are served to everyone unless the private-debug flag
		// is set, in which case only loopback clients may see them.
		permitDebug := !*flagPrivateDebug || (realIP != nil && realIP.IsLoopback())

		firstLabel := (strings.Split(label, "."))[0]

		if qle != nil {
			qle.LabelName = firstLabel
		}

		if permitDebug && firstLabel == "_status" {
			if qtype == dns.TypeANY || qtype == dns.TypeTXT {
				m.Answer = statusRR(label + "." + z.Origin + ".")
			} else {
				m.Ns = append(m.Ns, z.SoaRR())
			}
			m.Authoritative = true
			w.WriteMsg(m)
			return
		}

		if firstLabel == "_country" {
			if qtype == dns.TypeANY || qtype == dns.TypeTXT {
				h := dns.RR_Header{Ttl: 1, Class: dns.ClassINET, Rrtype: dns.TypeTXT}
				h.Name = label + "." + z.Origin + "."

				// TXT strings: remote address, effective client IP, the
				// targets, the netmask, then server ID and server IP.
				txt := []string{
					w.RemoteAddr().String(),
					ip.String(),
				}

				// These shadow the outer targets/netmask for this branch.
				targets, netmask := z.Options.Targeting.GetTargets(ip)
				txt = append(txt, strings.Join(targets, " "))
				txt = append(txt, fmt.Sprintf("/%d", netmask), serverID, serverIP)

				m.Answer = []dns.RR{&dns.TXT{Hdr: h,
					Txt: txt,
				}}
			} else {
				m.Ns = append(m.Ns, z.SoaRR())
			}

			m.Authoritative = true

			w.WriteMsg(m)
			return
		}

		// return NXDOMAIN
		m.SetRcode(req, dns.RcodeNameError)
		m.Authoritative = true

		m.Ns = []dns.RR{z.SoaRR()}

		w.WriteMsg(m)
		return
	}

	if servers := labels.Picker(labelQtype, labels.MaxHosts); servers != nil {
		var rrs []dns.RR
		for _, record := range servers {
			// Copy before renaming so the zone's own record is not modified.
			rr := dns.Copy(record.RR)
			rr.Header().Name = qname
			rrs = append(rrs, rr)
		}
		m.Answer = rrs
	}

	if len(m.Answer) == 0 {
		// Return a SOA so the NOERROR answer gets cached
		m.Ns = append(m.Ns, z.SoaRR())
	}

	logPrintln(m)

	if qle != nil {
		qle.LabelName = labels.Label
		qle.Answers = len(m.Answer)
		qle.Rcode = m.Rcode
	}
	err := w.WriteMsg(m)
	if err != nil {
		// if Pack'ing fails the Write fails. Return SERVFAIL.
		log.Println("Error writing packet", m)
		dns.HandleFailed(w, req)
	}
	return
}
func TestTCPDNSServer(t *testing.T) { setupForTest(t) const ( numAnswers = 512 nonLocalName = "weave.works." ) InitDefaultLogging(testing.Verbose()) Info.Println("TestTCPDNSServer starting") zone, err := NewZoneDb(ZoneConfig{}) require.NoError(t, err) err = zone.Start() require.NoError(t, err) defer zone.Stop() // generate a list of `numAnswers` IP addresses var addrs []ZoneRecord bs := make([]byte, 4) for i := 0; i < numAnswers; i++ { binary.LittleEndian.PutUint32(bs, uint32(i)) ip := net.IPv4(bs[0], bs[1], bs[2], bs[3]) addrs = append(addrs, ZoneRecord(Record{"", ip, 0, 0, 0})) } // handler for the fallback server: it will just return a very long response fallbackUDPHandler := func(w dns.ResponseWriter, req *dns.Msg) { if len(req.Question) == 0 { return // ignore empty queries (sent when shutting down the server) } maxLen := getMaxReplyLen(req, protUDP) t.Logf("Fallback UDP server got asked: returning %d answers", numAnswers) q := req.Question[0] m := makeAddressReply(req, &q, addrs, DefaultLocalTTL) mLen := m.Len() m.SetEdns0(uint16(maxLen), false) if mLen > maxLen { t.Logf("... 
truncated response (%d > %d)", mLen, maxLen) m.Truncated = true } w.WriteMsg(m) } fallbackTCPHandler := func(w dns.ResponseWriter, req *dns.Msg) { if len(req.Question) == 0 { return // ignore empty queries (sent when shutting down the server) } t.Logf("Fallback TCP server got asked: returning %d answers", numAnswers) q := req.Question[0] m := makeAddressReply(req, &q, addrs, DefaultLocalTTL) w.WriteMsg(m) } t.Logf("Running a DNS fallback server with UDP") fallback, err := newMockedFallback(fallbackUDPHandler, fallbackTCPHandler) require.NoError(t, err) fallback.Start() defer fallback.Stop() t.Logf("Creating a WeaveDNS server instance, falling back to 127.0.0.1:%d", fallback.Port) srv, err := NewDNSServer(DNSServerConfig{ Zone: zone, UpstreamCfg: fallback.CliConfig, CacheDisabled: true, ListenReadTimeout: testSocketTimeout, }) require.NoError(t, err) err = srv.Start() require.NoError(t, err) go srv.ActivateAndServe() defer srv.Stop() time.Sleep(100 * time.Millisecond) // Allow sever goroutine to start testPort, err := srv.GetPort() require.NoError(t, err) require.NotEqual(t, 0, testPort, "listen port") dnsAddr := fmt.Sprintf("127.0.0.1:%d", testPort) t.Logf("Creating a UDP and a TCP client") uc := new(dns.Client) uc.UDPSize = minUDPSize tc := new(dns.Client) tc.Net = "tcp" t.Logf("Creating DNS query message") m := new(dns.Msg) m.RecursionDesired = true m.SetQuestion(nonLocalName, dns.TypeA) t.Logf("Checking the fallback server at %s returns a truncated response with UDP", fallback.Addr) r, _, err := uc.Exchange(m, fallback.Addr) t.Logf("Got response from fallback server (UDP) with %d answers", len(r.Answer)) t.Logf("Response:\n%+v\n", r) require.NoError(t, err) require.True(t, r.MsgHdr.Truncated, "DNS truncated reponse flag") require.NotEqual(t, numAnswers, len(r.Answer), "number of answers (UDP)") t.Logf("Checking the WeaveDNS server at %s returns a truncated reponse with UDP", dnsAddr) r, _, err = uc.Exchange(m, dnsAddr) t.Logf("UDP Response:\n%+v\n", r) 
require.NoError(t, err) require.NotNil(t, r, "response") t.Logf("%d answers", len(r.Answer)) require.True(t, r.MsgHdr.Truncated, "DNS truncated reponse flag") require.NotEqual(t, numAnswers, len(r.Answer), "number of answers (UDP)") t.Logf("Checking the WeaveDNS server at %s does not return a truncated reponse with TCP", dnsAddr) r, _, err = tc.Exchange(m, dnsAddr) t.Logf("TCP Response:\n%+v\n", r) require.NoError(t, err) require.NotNil(t, r, "response") t.Logf("%d answers", len(r.Answer)) require.False(t, r.MsgHdr.Truncated, "DNS truncated response flag") require.Equal(t, numAnswers, len(r.Answer), "number of answers (TCP)") t.Logf("Checking the WeaveDNS server at %s does not return a truncated reponse with UDP with a bigger buffer", dnsAddr) m.SetEdns0(testUDPBufSize, false) r, _, err = uc.Exchange(m, dnsAddr) t.Logf("UDP-large Response:\n%+v\n", r) require.NoError(t, err) require.NotNil(t, r, "response") t.Logf("%d answers", len(r.Answer)) require.NoError(t, err) require.False(t, r.MsgHdr.Truncated, "DNS truncated response flag") require.Equal(t, numAnswers, len(r.Answer), "number of answers (UDP-long)") }
/* For fragmentation, we use a naive algorithm. We use the same header for every fragment, and include the same EDNS0 section in every additional section. We add one RR at a time, until our fragment is larger than 512 bytes, then we remove the last RR so that it fits in the 512 byte size limit. If we discover that one of the fragments ends up with 0 RR in it (for example because a single RR is too big), then we return a single truncated response instead of the set of fragments. We could perhaps make the process of building fragments faster by bisecting the set of RR that we include in an answer. So, if we have 8 RR we could try all, then if that is too big, 4 RR, and if that fits then 6 RR, until an optimal set of RR is found. We could also possibly produce a smaller set of responses by optimizing how we combine RR. Just taking account the various sizes is the same as the bin packing problem, which is NP-hard: https://en.wikipedia.org/wiki/Bin_packing_problem While some non-optimal but reasonable heuristics exist, in the case of DNS we would have to use some sophisticated algorithm to also consider name compression. 
*/ func frag(reply *dns.Msg) []dns.Msg { // create a return value all_frags := []dns.Msg{} HasEdns0 := true // get each RR section and save a copy out remaining_answer := make([]dns.RR, len(reply.Answer)) copy(remaining_answer, reply.Answer) remaining_ns := make([]dns.RR, len(reply.Ns)) copy(remaining_ns, reply.Ns) remaining_extra := make([]dns.RR, len(reply.Extra)) copy(remaining_extra, reply.Extra) // if we don't have EDNS0 in the packet, add it now if reply.IsEdns0() == nil { reply.SetEdns0(512, false) } // the EDNS option for later use var edns0_rr dns.RR = nil // remove the EDNS0 option from our additional ("extra") section // (we will include it separately on every fragment) for ofs, r := range remaining_extra { // found the EDNS option if r.Header().Rrtype == dns.TypeOPT { // save the EDNS option edns0_rr = r // remove from the set of extra RR remaining_extra = append(remaining_extra[0:ofs], remaining_extra[ofs+1:]...) // in principle we should only have one EDNS0 section break } } if edns0_rr == nil { log.Printf("Server reply missing EDNS0 option") return []dns.Msg{} //HasEdns0 = false } // now build fragments for { // make a shallow copy of our reply packet, and prepare space for our RR frag := *reply frag.Answer = []dns.RR{} frag.Ns = []dns.RR{} frag.Extra = []dns.RR{} // add our custom EDNS0 option (needed in every fragment) local_opt := new(dns.EDNS0_LOCAL) local_opt.Code = dns.EDNS0LOCALSTART + 1 local_opt.Data = []byte{0, 0} if HasEdns0 == true { edns0_rr_copy := dns.Copy(edns0_rr) edns0_rr_copy.(*dns.OPT).Option = append(edns0_rr_copy.(*dns.OPT).Option, local_opt) frag.Extra = append(frag.Extra, edns0_rr_copy) } //if HasEdns0 == false { // frag.Extra = append(frag.Extra, local_opt) //} // add as many RR to the answer as we can for len(remaining_answer) > 0 { frag.Answer = append(frag.Answer, remaining_answer[0]) if frag.Len() <= 512 { // if the new answer fits, then remove it from our remaining list remaining_answer = remaining_answer[1:] } else { // 
otherwise we are full, remove it from our fragment and stop frag.Answer = frag.Answer[0 : len(frag.Answer)-1] break } } for len(remaining_ns) > 0 { frag.Ns = append(frag.Ns, remaining_ns[0]) if frag.Len() <= 512 { // if the new answer fits, then remove it from our remaining list remaining_ns = remaining_ns[1:] } else { // otherwise we are full, remove it from our fragment and stop frag.Ns = frag.Ns[0 : len(frag.Ns)-1] break } } for len(remaining_extra) > 0 { frag.Extra = append(frag.Extra, remaining_extra[0]) if frag.Len() <= 512 { // if the new answer fits, then remove it from our remaining list remaining_extra = remaining_extra[1:] } else { // otherwise we are full, remove it from our fragment and stop frag.Extra = frag.Extra[0 : len(frag.Extra)-1] break } } // check to see if we didn't manage to add any RR if (len(frag.Answer) == 0) && (len(frag.Ns) == 0) && (len(frag.Extra) == 1) { // TODO: test this :) // return a single truncated fragment without any RR frag.MsgHdr.Truncated = true frag.Extra = []dns.RR{} return []dns.Msg{frag} } // add to our list of fragments all_frags = append(all_frags, frag) // if we have finished all remaining sections, we are done if (len(remaining_answer) == 0) && (len(remaining_ns) == 0) && (len(remaining_extra) == 0) { break } } // fix up our fragments so they have the correct sequence and length values for n, frag := range all_frags { frag_edns0 := frag.IsEdns0() for _, opt := range frag_edns0.Option { if opt.Option() == dns.EDNS0LOCALSTART+1 { opt.(*dns.EDNS0_LOCAL).Data = []byte{byte(len(all_frags)), byte(n)} } } } // return our fragments return all_frags }
func TestDNS(t *testing.T) { s := newTestServerDNSSEC(t, false) defer s.Stop() for _, serv := range services { addService(t, s, serv.Key, 0, serv) defer delService(t, s, serv.Key) } c := new(dns.Client) for _, tc := range dnsTestCases { m := new(dns.Msg) m.SetQuestion(tc.Qname, tc.Qtype) if tc.dnssec { m.SetEdns0(4096, true) } if tc.chaos { m.Question[0].Qclass = dns.ClassCHAOS } resp, _, err := c.Exchange(m, "127.0.0.1:"+StrPort) if err != nil { // try twice, be more resilent against remote lookups // timing out. resp, _, err = c.Exchange(m, "127.0.0.1:"+StrPort) if err != nil { t.Fatalf("failing: %s: %s\n", m.String(), err.Error()) } } sort.Sort(rrSet(resp.Answer)) sort.Sort(rrSet(resp.Ns)) sort.Sort(rrSet(resp.Extra)) fatal := false defer func() { if fatal { t.Logf("question: %s\n", m.Question[0].String()) t.Logf("%s\n", resp) } }() if resp.Rcode != tc.Rcode { fatal = true t.Fatalf("rcode is %q, expected %q", dns.RcodeToString[resp.Rcode], dns.RcodeToString[tc.Rcode]) } if len(resp.Answer) != len(tc.Answer) { fatal = true t.Fatalf("answer for %q contained %d results, %d expected", tc.Qname, len(resp.Answer), len(tc.Answer)) } for i, a := range resp.Answer { if a.Header().Name != tc.Answer[i].Header().Name { fatal = true t.Fatalf("answer %d should have a Header Name of %q, but has %q", i, tc.Answer[i].Header().Name, a.Header().Name) } if a.Header().Ttl != tc.Answer[i].Header().Ttl { fatal = true t.Fatalf("Answer %d should have a Header TTL of %d, but has %d", i, tc.Answer[i].Header().Ttl, a.Header().Ttl) } if a.Header().Rrtype != tc.Answer[i].Header().Rrtype { fatal = true t.Fatalf("answer %d should have a header response type of %d, but has %d", i, tc.Answer[i].Header().Rrtype, a.Header().Rrtype) } switch x := a.(type) { case *dns.SRV: if x.Priority != tc.Answer[i].(*dns.SRV).Priority { fatal = true t.Fatalf("answer %d should have a Priority of %d, but has %d", i, tc.Answer[i].(*dns.SRV).Priority, x.Priority) } if x.Weight != tc.Answer[i].(*dns.SRV).Weight { 
fatal = true t.Fatalf("answer %d should have a Weight of %d, but has %d", i, tc.Answer[i].(*dns.SRV).Weight, x.Weight) } if x.Port != tc.Answer[i].(*dns.SRV).Port { fatal = true t.Fatalf("answer %d should have a Port of %d, but has %d", i, tc.Answer[i].(*dns.SRV).Port, x.Port) } if x.Target != tc.Answer[i].(*dns.SRV).Target { fatal = true t.Fatalf("answer %d should have a Target of %q, but has %q", i, tc.Answer[i].(*dns.SRV).Target, x.Target) } case *dns.A: if x.A.String() != tc.Answer[i].(*dns.A).A.String() { fatal = true t.Fatalf("answer %d should have a Address of %q, but has %q", i, tc.Answer[i].(*dns.A).A.String(), x.A.String()) } case *dns.AAAA: if x.AAAA.String() != tc.Answer[i].(*dns.AAAA).AAAA.String() { fatal = true t.Fatalf("answer %d should have a Address of %q, but has %q", i, tc.Answer[i].(*dns.AAAA).AAAA.String(), x.AAAA.String()) } case *dns.TXT: for j, txt := range x.Txt { if txt != tc.Answer[i].(*dns.TXT).Txt[j] { fatal = true t.Fatalf("answer %d should have a Txt of %q, but has %q", i, tc.Answer[i].(*dns.TXT).Txt[j], txt) } } case *dns.DNSKEY: tt := tc.Answer[i].(*dns.DNSKEY) if x.Flags != tt.Flags { fatal = true t.Fatalf("DNSKEY flags should be %q, but is %q", x.Flags, tt.Flags) } if x.Protocol != tt.Protocol { fatal = true t.Fatalf("DNSKEY protocol should be %q, but is %q", x.Protocol, tt.Protocol) } if x.Algorithm != tt.Algorithm { fatal = true t.Fatalf("DNSKEY algorithm should be %q, but is %q", x.Algorithm, tt.Algorithm) } case *dns.RRSIG: tt := tc.Answer[i].(*dns.RRSIG) if x.TypeCovered != tt.TypeCovered { fatal = true t.Fatalf("RRSIG type-covered should be %d, but is %d", x.TypeCovered, tt.TypeCovered) } if x.Algorithm != tt.Algorithm { fatal = true t.Fatalf("RRSIG algorithm should be %d, but is %d", x.Algorithm, tt.Algorithm) } if x.Labels != tt.Labels { fatal = true t.Fatalf("RRSIG label should be %d, but is %d", x.Labels, tt.Labels) } if x.OrigTtl != tt.OrigTtl { fatal = true t.Fatalf("RRSIG orig-ttl should be %d, but is %d", x.OrigTtl, 
tt.OrigTtl) } if x.KeyTag != tt.KeyTag { fatal = true t.Fatalf("RRSIG key-tag should be %d, but is %d", x.KeyTag, tt.KeyTag) } if x.SignerName != tt.SignerName { fatal = true t.Fatalf("RRSIG signer-name should be %q, but is %q", x.SignerName, tt.SignerName) } case *dns.SOA: tt := tc.Answer[i].(*dns.SOA) if x.Ns != tt.Ns { fatal = true t.Fatalf("SOA nameserver should be %q, but is %q", x.Ns, tt.Ns) } case *dns.PTR: tt := tc.Answer[i].(*dns.PTR) if x.Ptr != tt.Ptr { fatal = true t.Fatalf("PTR ptr should be %q, but is %q", x.Ptr, tt.Ptr) } case *dns.CNAME: tt := tc.Answer[i].(*dns.CNAME) if x.Target != tt.Target { fatal = true t.Fatalf("CNAME target should be %q, but is %q", x.Target, tt.Target) } case *dns.MX: tt := tc.Answer[i].(*dns.MX) if x.Mx != tt.Mx { t.Fatalf("MX Mx should be %q, but is %q", x.Mx, tt.Mx) } if x.Preference != tt.Preference { t.Fatalf("MX Preference should be %q, but is %q", x.Preference, tt.Preference) } } } if len(resp.Ns) != len(tc.Ns) { fatal = true t.Fatalf("authority for %q contained %d results, %d expected", tc.Qname, len(resp.Ns), len(tc.Ns)) } for i, n := range resp.Ns { switch x := n.(type) { case *dns.SOA: tt := tc.Ns[i].(*dns.SOA) if x.Ns != tt.Ns { fatal = true t.Fatalf("SOA nameserver should be %q, but is %q", x.Ns, tt.Ns) } case *dns.NS: tt := tc.Ns[i].(*dns.NS) if x.Ns != tt.Ns { fatal = true t.Fatalf("NS nameserver should be %q, but is %q", x.Ns, tt.Ns) } case *dns.NSEC3: tt := tc.Ns[i].(*dns.NSEC3) if x.NextDomain != tt.NextDomain { fatal = true t.Fatalf("NSEC3 nextdomain should be %q, but is %q", x.NextDomain, tt.NextDomain) } if x.Hdr.Name != tt.Hdr.Name { fatal = true t.Fatalf("NSEC3 ownername should be %q, but is %q", x.Hdr.Name, tt.Hdr.Name) } for j, y := range x.TypeBitMap { if y != tt.TypeBitMap[j] { fatal = true t.Fatalf("NSEC3 bitmap should have %q, but is %q", dns.TypeToString[y], dns.TypeToString[tt.TypeBitMap[j]]) } } } } if len(resp.Extra) != len(tc.Extra) { fatal = true t.Fatalf("additional for %q contained %d 
results, %d expected", tc.Qname, len(resp.Extra), len(tc.Extra)) } for i, e := range resp.Extra { switch x := e.(type) { case *dns.A: if x.A.String() != tc.Extra[i].(*dns.A).A.String() { fatal = true t.Fatalf("extra %d should have a address of %q, but has %q", i, tc.Extra[i].(*dns.A).A.String(), x.A.String()) } case *dns.AAAA: if x.AAAA.String() != tc.Extra[i].(*dns.AAAA).AAAA.String() { fatal = true t.Fatalf("extra %d should have a address of %q, but has %q", i, tc.Extra[i].(*dns.AAAA).AAAA.String(), x.AAAA.String()) } case *dns.CNAME: tt := tc.Extra[i].(*dns.CNAME) if x.Target != tt.Target { // Super super gross hack. if x.Target == "a.ipaddr.skydns.test." && tt.Target == "b.ipaddr.skydns.test." { // These records are randomly choosen, either one is OK. continue } fatal = true t.Fatalf("CNAME target should be %q, but is %q", x.Target, tt.Target) } } } } }
func TestDNSStubForward(t *testing.T) { s := newTestServer(t, false) defer s.Stop() c := new(dns.Client) m := new(dns.Msg) stubEx := &msg.Service{ // IP address of a.iana-servers.net. Host: "199.43.132.53", Key: "a.example.com.stub.dns.skydns.test.", } stubBroken := &msg.Service{ Host: "127.0.0.1", Port: 5454, Key: "b.example.org.stub.dns.skydns.test.", } stubLoop := &msg.Service{ Host: "127.0.0.1", Port: Port, Key: "b.example.net.stub.dns.skydns.test.", } addService(t, s, stubEx.Key, 0, stubEx) defer delService(t, s, stubEx.Key) addService(t, s, stubBroken.Key, 0, stubBroken) defer delService(t, s, stubBroken.Key) addService(t, s, stubLoop.Key, 0, stubLoop) defer delService(t, s, stubLoop.Key) s.UpdateStubZones() m.SetQuestion("www.example.com.", dns.TypeA) resp, _, err := c.Exchange(m, "127.0.0.1:"+StrPort) if err != nil { // try twice resp, _, err = c.Exchange(m, "127.0.0.1:"+StrPort) if err != nil { t.Fatal(err) } } if len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess { t.Fatal("answer expected to have A records or rcode not equal to RcodeSuccess") } // The main diff. here is that we expect the AA bit to be set, because we directly // queried the authoritative servers. if resp.Authoritative != true { t.Fatal("answer expected to have AA bit set") } // This should fail. m.SetQuestion("www.example.org.", dns.TypeA) resp, _, err = c.Exchange(m, "127.0.0.1:"+StrPort) if len(resp.Answer) != 0 || resp.Rcode != dns.RcodeServerFailure { t.Fatal("answer expected to fail for example.org") } // This should really fail with a timeout. 
m.SetQuestion("www.example.net.", dns.TypeA) resp, _, err = c.Exchange(m, "127.0.0.1:"+StrPort) if err == nil { t.Fatal("answer expected to fail for example.net") } else { t.Logf("succesfully failing %s", err) } // Packet with EDNS0 m.SetEdns0(4096, true) resp, _, err = c.Exchange(m, "127.0.0.1:"+StrPort) if err == nil { t.Fatal("answer expected to fail for example.net") } else { t.Logf("succesfully failing %s", err) } // Now start another SkyDNS instance on a different port, // add a stubservice for it and check if the forwarding is // actually working. oldStrPort := StrPort s1 := newTestServer(t, false) defer s1.Stop() s1.config.Domain = "skydns.com." // Add forwarding IP for internal.skydns.com. Use Port to point to server s. stubForward := &msg.Service{ Host: "127.0.0.1", Port: Port, Key: "b.internal.skydns.com.stub.dns.skydns.test.", } addService(t, s, stubForward.Key, 0, stubForward) defer delService(t, s, stubForward.Key) s.UpdateStubZones() // Add an answer for this in our "new" server. stubReply := &msg.Service{ Host: "127.1.1.1", Key: "www.internal.skydns.com.", } addService(t, s1, stubReply.Key, 0, stubReply) defer delService(t, s1, stubReply.Key) m = new(dns.Msg) m.SetQuestion("www.internal.skydns.com.", dns.TypeA) resp, _, err = c.Exchange(m, "127.0.0.1:"+oldStrPort) if err != nil { t.Fatalf("failed to forward %s", err) } if resp.Answer[0].(*dns.A).A.String() != "127.1.1.1" { t.Fatalf("failed to get correct reply") } // Adding an in baliwick internal domain forward. s2 := newTestServer(t, false) defer s2.Stop() s2.config.Domain = "internal.skydns.net." // Add forwarding IP for internal.skydns.net. Use Port to point to server s. stubForward1 := &msg.Service{ Host: "127.0.0.1", Port: Port, Key: "b.internal.skydns.net.stub.dns.skydns.test.", } addService(t, s, stubForward1.Key, 0, stubForward1) defer delService(t, s, stubForward1.Key) s.UpdateStubZones() // Add an answer for this in our "new" server. 
stubReply1 := &msg.Service{ Host: "127.10.10.10", Key: "www.internal.skydns.net.", } addService(t, s2, stubReply1.Key, 0, stubReply1) defer delService(t, s2, stubReply1.Key) m = new(dns.Msg) m.SetQuestion("www.internal.skydns.net.", dns.TypeA) resp, _, err = c.Exchange(m, "127.0.0.1:"+oldStrPort) if err != nil { t.Fatalf("failed to forward %s", err) } if resp.Answer[0].(*dns.A).A.String() != "127.10.10.10" { t.Fatalf("failed to get correct reply") } }