Ejemplo n.º 1
0
// TestDNSTtlRR registers a single service with a TTL of 360 seconds and
// checks that every record in the answer section of the reply to test case 9
// carries a header TTL of 360.
func TestDNSTtlRR(t *testing.T) {
	s := newTestServerDNSSEC(t, false)
	defer s.Stop()

	serv := &msg.Service{Host: "10.0.0.2", Key: "ttl.skydns.test.", Ttl: 360}
	addService(t, s, serv.Key, time.Duration(serv.Ttl)*time.Second, serv)
	defer delService(t, s, serv.Key)

	c := new(dns.Client)

	tc := dnsTestCases[9] // TTL Test
	t.Logf("%v\n", tc)
	m := new(dns.Msg)
	m.SetQuestion(tc.Qname, tc.Qtype)
	if tc.dnssec {
		m.SetEdns0(4096, true)
	}
	resp, _, err := c.Exchange(m, "127.0.0.1:"+StrPort)
	if err != nil {
		// Abort here: resp is dereferenced below and may be nil after a
		// failed exchange. The deferred cleanups still run.
		t.Fatalf("failing: %s: %s\n", m.String(), err.Error())
	}
	t.Logf("%s\n", resp)

	for i, a := range resp.Answer {
		if a.Header().Ttl != 360 {
			t.Errorf("Answer %d should have a Header TTL of %d, but has %d", i, 360, a.Header().Ttl)
		}
	}
}
Ejemplo n.º 2
0
// TestDNSTtlRRset adds every entry of services with a TTL that starts at 60
// and grows by 60 per entry, then checks that all records in the answer to
// test case 9 report a header TTL of 360.
func TestDNSTtlRRset(t *testing.T) {
	srv := newTestServerDNSSEC(t, false)
	defer srv.Stop()

	var nextTTL uint32 = 60
	for _, svc := range services {
		addService(t, srv, svc.Key, uint64(nextTTL), svc)
		// Deferred, so the removal only happens when the test function returns.
		defer delService(t, srv, svc.Key)
		nextTTL += 60
	}

	client := new(dns.Client)
	testCase := dnsTestCases[9]
	t.Logf("%v\n", testCase)

	query := new(dns.Msg)
	query.SetQuestion(testCase.Qname, testCase.Qtype)
	if testCase.dnssec {
		query.SetEdns0(4096, true)
	}

	reply, _, err := client.Exchange(query, "127.0.0.1:"+StrPort)
	if err != nil {
		t.Fatalf("failing: %s: %s\n", query.String(), err.Error())
	}
	t.Logf("%s\n", reply)

	const wantTTL uint32 = 360
	for idx, rr := range reply.Answer {
		if got := rr.Header().Ttl; got != wantTTL {
			t.Errorf("Answer %d should have a Header TTL of %d, but has %d", idx, wantTTL, got)
		}
	}
}
Ejemplo n.º 3
0
// assertExchange sends a recursive query for name z and type ty to
// 127.0.0.1:testPort and checks the reply.
//
// When minAnswers and maxAnswers are both 0 the reply code must equal expErr;
// otherwise it must be dns.RcodeSuccess. A non-negative minAnswers or
// maxAnswers bounds the number of records in the answer section; a negative
// value disables that bound. The reply is returned to the caller.
func assertExchange(t *testing.T, z string, ty uint16, minAnswers int, maxAnswers int, expErr int) *dns.Msg {
	query := new(dns.Msg)
	query.RecursionDesired = true
	query.SetQuestion(z, ty)
	query.SetEdns0(testUDPBufSize, false) // we don't want to play with truncation here...

	client := &dns.Client{UDPSize: testUDPBufSize}
	reply, _, err := client.Exchange(query, fmt.Sprintf("127.0.0.1:%d", testPort))
	t.Logf("Response:\n%+v\n", reply)
	wt.AssertNoErr(t, err)

	// Only an "expect no answers at all" call checks for a specific reply code.
	wantRcode := dns.RcodeSuccess
	if minAnswers == 0 && maxAnswers == 0 {
		wantRcode = expErr
	}
	wt.AssertStatus(t, reply.Rcode, wantRcode, "DNS response code")

	got := len(reply.Answer)
	if minAnswers >= 0 && got < minAnswers {
		wt.Fatalf(t, "Number of answers >= %d", minAnswers)
	}
	if maxAnswers >= 0 && got > maxAnswers {
		wt.Fatalf(t, "Number of answers <= %d", maxAnswers)
	}
	return reply
}
Ejemplo n.º 4
0
// TestInFlightEDns0 handles two concurrent requests for the same question,
// one sent with an EDNS0 record and one without, and checks that only the
// reply to the first one carries EDNS0.
func TestInFlightEDns0(t *T) {
	withEdns := new(dns.Msg)
	withEdns.SetQuestion(testAnyDomain, dns.TypeA)
	withEdns.SetEdns0(4096, false)
	ednsWriter := getWriter()

	plain := new(dns.Msg)
	plain.SetQuestion(testAnyDomain, dns.TypeA)
	plainWriter := getWriter()

	go handleRequest(ednsWriter, withEdns)
	go handleRequest(plainWriter, plain)

	// Wait until a reply has arrived on each writer's channel.
	var ednsReply, plainReply *dns.Msg
	for !(ednsReply != nil && plainReply != nil) {
		select {
		case ednsReply = <-ednsWriter.ReplyCh:
		case plainReply = <-plainWriter.ReplyCh:
		}
	}

	//note: this test could be flaky since we're relying on google to return
	//edns0 response when we send one vs when we don't send one
	assert.NotNil(t, ednsReply.IsEdns0())
	assert.Nil(t, plainReply.IsEdns0())
}
Ejemplo n.º 5
0
// assertExchange sends a recursive query for name z and type ty to
// 127.0.0.1:port and checks the reply.
//
// port must not be 0. When minAnswers and maxAnswers are both 0 the reply
// code must equal expErr; otherwise it must be dns.RcodeSuccess. A
// non-negative minAnswers or maxAnswers bounds the number of records in the
// answer section. It returns the query that was sent and the reply received.
func assertExchange(t *testing.T, z string, ty uint16, port int, minAnswers int, maxAnswers int, expErr int) (*dns.Msg, *dns.Msg) {
	require.NotEqual(t, 0, port, "invalid DNS server port")

	client := new(dns.Client)
	client.UDPSize = testUDPBufSize

	query := new(dns.Msg)
	query.RecursionDesired = true
	query.SetQuestion(z, ty)
	query.SetEdns0(testUDPBufSize, false) // we don't want to play with truncation here...

	target := fmt.Sprintf("127.0.0.1:%d", port)
	reply, _, err := client.Exchange(query, target)
	t.Logf("Response from '%s':\n%+v\n", target, reply)
	if err != nil {
		t.Errorf("Error when querying DNS server at %s: %s", target, err)
	}
	require.NoError(t, err)

	// Only an "expect no answers at all" call checks for a specific reply code.
	expectedRcode := dns.RcodeSuccess
	if minAnswers == 0 && maxAnswers == 0 {
		expectedRcode = expErr
	}
	require.Equal(t, expectedRcode, reply.Rcode, "DNS response code")

	got := len(reply.Answer)
	if minAnswers >= 0 && got < minAnswers {
		require.FailNow(t, fmt.Sprintf("Number of answers >= %d", minAnswers))
	}
	if maxAnswers >= 0 && got > maxAnswers {
		require.FailNow(t, fmt.Sprintf("Number of answers <= %d", maxAnswers))
	}
	return query, reply
}
Ejemplo n.º 6
0
// dnsQuery sends a query for fqdn/rtype, cycling through nameservers until an
// exchange succeeds or len(nameservers)+1 attempts have been made. Each
// attempt goes out over UDP first and is repeated over TCP when the UDP
// exchange reports dns.ErrTruncated.
//
// Each nameserver should include a port, to facilitate testing where we talk
// to a mock dns server. nameservers must not be empty: the server is chosen
// with i%len(nameservers), which panics for an empty slice.
//
// The reply and error of the last attempt are returned.
func dnsQuery(fqdn string, rtype uint16, nameservers []string, recursive bool) (in *dns.Msg, err error) {
	query := new(dns.Msg)
	query.SetQuestion(fqdn, rtype)
	query.SetEdns0(4096, false)

	if !recursive {
		query.RecursionDesired = false
	}

	total := len(nameservers)
	// Attempts are numbered from 1, so the first server tried is
	// nameservers[1%total].
	for attempt := 1; attempt <= total+1; attempt++ {
		server := nameservers[attempt%total]

		udpClient := &dns.Client{Net: "udp", Timeout: DNSTimeout}
		in, _, err = udpClient.Exchange(query, server)

		if err == dns.ErrTruncated {
			// A successful TCP exchange resets err to nil.
			tcpClient := &dns.Client{Net: "tcp", Timeout: DNSTimeout}
			in, _, err = tcpClient.Exchange(query, server)
		}

		if err == nil {
			return in, nil
		}
	}
	return in, err
}
Ejemplo n.º 7
0
// SendQuery sends the DNS query described by request to request.Destination
// using a default dns.Client.
//
// The returned Result always carries a copy of the request. An unknown
// request.RecordType is reported both in result.Error and as the returned
// error. A failed exchange is reported in result.Error only; the returned
// error is nil in that case. On success one Answer (TTL, name, text form) is
// appended per record of the reply's answer section.
func SendQuery(request *Request) (result Result, err error) {
	log.Printf("Sending query: %s", request)
	result.Request = *request

	qtype, known := dns.StringToType[request.RecordType]
	if !known {
		result.Error = fmt.Sprintf("Invalid type: %s", request.RecordType)
		return result, errors.New(result.Error)
	}

	query := new(dns.Msg)
	if request.VerifySignature {
		log.Printf("SetEdns0 for %s", request.RecordName)
		query.SetEdns0(4096, true)
	}
	query.SetQuestion(request.RecordName, qtype)

	client := new(dns.Client)
	reply, rtt, exchangeErr := client.Exchange(query, request.Destination)
	// log.Printf("Answer: %s [%d] %s", reply, rtt, exchangeErr)

	result.Duration = rtt
	if exchangeErr != nil {
		result.Error = exchangeErr.Error()
		return result, nil
	}

	for _, rr := range reply.Answer {
		hdr := rr.Header()
		result.Answers = append(result.Answers, Answer{
			Ttl:    hdr.Ttl,
			Name:   hdr.Name,
			String: rr.String(),
		})
	}
	return result, nil
}
Ejemplo n.º 8
0
// searchServerIP asks one randomly chosen entry of DNSservers for the A
// (version 4) or AAAA (version 6) records of domain. A trailing dot is
// appended to domain and the query is sent with EDNS0 (UDP size 4096, DO bit
// set).
//
// It returns an error when DNSservers is empty, when only blank entries were
// drawn after four random picks, when version is neither 4 nor 6, or when the
// exchange itself fails.
func searchServerIP(domain string, version int, DNSservers []string) (answer *dns.Msg, err error) {
	// rand.Intn panics for n <= 0, so an empty list must be rejected up front.
	if len(DNSservers) == 0 {
		return nil, errors.New("DNSservers is empty")
	}

	// One initial pick plus up to three re-picks while the entry is blank.
	DNSserver := ""
	for i := 0; i < 4 && DNSserver == ""; i++ {
		DNSserver = DNSservers[rand.Intn(len(DNSservers))]
	}
	if DNSserver == "" {
		return nil, errors.New("DNSserver is an empty string")
	}

	var qtype uint16
	switch version {
	case 4:
		qtype = dns.TypeA
	case 6:
		qtype = dns.TypeAAAA
	default:
		return nil, errors.New("wrong parameter in version")
	}

	dnsRequest := new(dns.Msg)
	dnsRequest.SetQuestion(domain+".", qtype)
	dnsRequest.SetEdns0(4096, true)

	dnsClient := new(dns.Client)
	answer, _, err = dnsClient.Exchange(dnsRequest, DNSserver)
	if err != nil {
		return nil, err
	}
	return answer, nil
}
Ejemplo n.º 9
0
// TestDNSExpire adds a service with a TTL argument of 1, checks that a query
// for test case 0 succeeds, sleeps for two seconds and then checks that the
// same query is answered with a name error.
func TestDNSExpire(t *testing.T) {
	s := newTestServerDNSSEC(t)
	defer s.Stop()

	serv := services[0]
	addService(t, s, serv.key, 1, &Service{Host: serv.Host, Port: serv.Port})
	// defer delService(t, s, serv.key) // It will delete itself...magically

	c := new(dns.Client)
	tc := dnsTestCases[0]
	m := new(dns.Msg)
	m.SetQuestion(tc.Qname, tc.Qtype)
	if tc.dnssec {
		m.SetEdns0(4096, true)
	}
	resp, _, err := c.Exchange(m, "127.0.0.1:"+StrPort)
	if err != nil {
		t.Fatalf("failing: %s: %s\n", m.String(), err.Error())
	}
	if resp.Rcode != dns.RcodeSuccess {
		t.Logf("%v\n", resp)
		t.Fail()
	}
	// Sleep to let it expire.
	time.Sleep(2 * time.Second)
	resp, _, err = c.Exchange(m, "127.0.0.1:"+StrPort)
	if err != nil {
		// Abort here: resp is dereferenced below and may be nil after a
		// failed exchange.
		t.Fatalf("failing: %s\n", err.Error())
	}
	if resp.Rcode != dns.RcodeNameError {
		t.Logf("%v\n", resp)
		t.Fail()
	}
}
Ejemplo n.º 10
0
// localQuery resolves qname/qtype through the servers listed in conf.Servers
// and sends exactly one DNSreply on mychan.
//
// The query is sent with recursion desired and an EDNS0 record. Up to
// *maxTrials rounds are made; within a round the servers are tried in order.
// A server whose exchange fails with an error containing "timeout" is
// skipped and the next server is tried. Any other exchange failure, and any
// reply (successful or not), ends the search immediately. A reply whose rcode
// is not dns.RcodeSuccess is reported with result.err set to the rcode's
// name.
//
// If no server could be queried at all, the reply carries the error
// "No name server to answer the question".
func localQuery(mychan chan DNSreply, qname string, qtype uint16) {
	var result DNSreply
	var trials uint
	result.qname = qname
	result.qtype = qtype
	result.r = nil
	result.err = errors.New("No name server to answer the question")
	localm := new(dns.Msg)
	localm.Id = dns.Id()
	localm.RecursionDesired = true
	localm.Question = make([]dns.Question, 1)
	localm.SetEdns0(EDNSBUFFERSIZE, false) // Even if no EDNS requested, see #9 May be we should retry without it if timeout?
	localc := new(dns.Client)
	localc.ReadTimeout = timeout
	localm.Question[0] = dns.Question{Name: qname, Qtype: qtype, Qclass: dns.ClassINET}
Tests:
	for trials = 0; trials < uint(*maxTrials); trials++ {
		for _, server := range conf.Servers {
			result.nameserver = server
			// Brackets around the server address are necessary for IPv6 name servers
			r, rtt, err := localc.Exchange(localm, "["+server+"]:"+conf.Port) // Do not use net.JoinHostPort, see https://github.com/bortzmeyer/check-soa/commit/3e4edb13855d8c4016768796b2892aa83eda1933#commitcomment-2355543
			if r == nil {
				result.r = nil
				result.err = err
				if err == nil {
					// No reply and no error: report it rather than
					// dereferencing a nil error below.
					result.err = errors.New("No reply from the name server")
					break Tests
				}
				if strings.Contains(err.Error(), "timeout") {
					// Try another resolver. This used to break out of the
					// server loop, so only the first server was ever queried.
					continue
				}
				// We give in
				break Tests
			}
			result.rtt = rtt
			result.r = r
			if r.Rcode == dns.RcodeSuccess {
				// TODO: as a result, NODATA (NOERROR/ANSWER=0) are silently ignored (try "foo", for instance, the name exists but no IP address)
				// TODO: for rcodes like SERVFAIL, trying another resolver could make sense
				result.err = nil
			} else {
				// All the other codes are errors. Yes, it may
				// happens that one resolver returns REFUSED
				// and the others work but we do not handle
				// this case. TODO: delete the resolver from
				// the list and try another one
				result.err = errors.New(dns.RcodeToString[r.Rcode])
			}
			break Tests
		}
	}
	if *debug {
		fmt.Printf("DEBUG: end of DNS request \"%s\" / %d\n", qname, qtype)
	}
	mychan <- result
}
Ejemplo n.º 11
0
// exchangeOne queries one randomly chosen server from dnsResolver.servers for
// hostname/qtype and returns the reply and error of the final attempt.
//
// The query always carries EDNS0 with the DNSSEC OK bit set, in case
// validation isn't the resolver's default behaviour. An exchange that fails
// with a temporary *net.OpError is retried against the same server until
// dnsResolver.maxTries attempts have been made. Cancellation of ctx ends the
// call with ctx.Err(). Counters and timings are recorded on msgStats along
// the way.
func (dnsResolver *DNSResolverImpl) exchangeOne(ctx context.Context, hostname string, qtype uint16, msgStats metrics.Scope) (*dns.Msg, error) {
	query := new(dns.Msg)
	// Set question type
	query.SetQuestion(dns.Fqdn(hostname), qtype)
	// Set DNSSEC OK bit for resolver
	query.SetEdns0(4096, true)

	if len(dnsResolver.servers) < 1 {
		return nil, fmt.Errorf("Not configured with at least one DNS Server")
	}

	dnsResolver.stats.Inc("Rate", 1)

	// Randomly pick a server
	server := dnsResolver.servers[rand.Intn(len(dnsResolver.servers))]
	client := dnsResolver.dnsClient

	began := dnsResolver.clk.Now()
	msgStats.Inc("Calls", 1)
	defer func() {
		msgStats.TimingDuration("Latency", dnsResolver.clk.Now().Sub(began))
	}()

	for attempt := 1; ; attempt++ {
		msgStats.Inc("Tries", 1)

		// Capacity 1 lets the goroutine deliver its result and exit even when
		// nobody receives it because ctx was cancelled first.
		results := make(chan dnsResp, 1)
		go func() {
			rsp, rtt, err := client.Exchange(query, server)
			msgStats.TimingDuration("SingleTryLatency", rtt)
			results <- dnsResp{m: rsp, err: err}
		}()

		select {
		case <-ctx.Done():
			msgStats.Inc("Cancels", 1)
			msgStats.Inc("Errors", 1)
			return nil, ctx.Err()
		case res := <-results:
			if res.err == nil {
				msgStats.Inc("Successes", 1)
				return res.m, nil
			}
			msgStats.Inc("Errors", 1)
			opErr, isOpErr := res.err.(*net.OpError)
			if isOpErr && opErr.Temporary() {
				if attempt < dnsResolver.maxTries {
					continue
				}
				msgStats.Inc("RanOutOfTries", 1)
			}
			return res.m, res.err
		}
	}
}
Ejemplo n.º 12
0
// ednsFromRequest gives m an EDNS0 record (UDP size 4096) when req carries an
// OPT record in its additional section, copying the DO bit of the first OPT
// record found. m is left untouched when req has no OPT record.
func ednsFromRequest(req, m *dns.Msg) {
	for i := range req.Extra {
		rr := req.Extra[i]
		if rr.Header().Rrtype != dns.TypeOPT {
			continue
		}
		opt := rr.(*dns.OPT)
		m.SetEdns0(4096, opt.Do())
		break
	}
}
Ejemplo n.º 13
0
// refuseWithCode answers req with a reply whose rcode is code. For every OPT
// record in req's additional section, SetEdns0 is called on the reply with
// UDP size 4096 and that record's DO bit. The error returned by WriteMsg is
// not checked.
func (s *TrivialDnsServer) refuseWithCode(w dns.ResponseWriter, req *dns.Msg, code int) {
	reply := new(dns.Msg)
	for _, extra := range req.Extra {
		if extra.Header().Rrtype != dns.TypeOPT {
			continue
		}
		reply.SetEdns0(4096, extra.(*dns.OPT).Do())
	}
	reply.SetRcode(req, code)
	w.WriteMsg(reply)
}
Ejemplo n.º 14
0
// testLookupSRV starts a server from an inline Corefile using the kubernetes
// middleware and, for every entry of testdataLookupSRV, sends an SRV query
// and compares the number of SRV records and the total number of answer
// records with the expected counts. The test is skipped when
// checkKubernetesRunning reports false.
func testLookupSRV(t *testing.T) {
	if !checkKubernetesRunning() {
		t.Skip("Skipping Kubernetes Integration tests. Kubernetes is not running")
	}

	// Note: Use different port to avoid conflict with servers used in other tests.
	// NOTE(review): the outer ".:2054 {" block does not appear to be closed
	// in this Corefile — confirm whether the parser accepts that.
	coreFile :=
		`.:2054 {
    kubernetes coredns.local {
		endpoint http://localhost:8080
		namespaces demo
    }
`

	server, _, udp, err := Server(t, coreFile)
	if err != nil {
		// Fatalf, not Fatal: the message contains a format verb.
		t.Fatalf("Could not get server: %s", err)
	}
	defer server.Stop()

	log.SetOutput(ioutil.Discard)

	// TODO: Add checks for A records in additional section

	for _, testData := range testdataLookupSRV {
		t.Logf("[log] Testing query string: '%v'\n", testData.Query)
		dnsClient := new(dns.Client)
		dnsMessage := new(dns.Msg)

		dnsMessage.SetQuestion(testData.Query, dns.TypeSRV)
		dnsMessage.SetEdns0(4096, true)

		res, _, err := dnsClient.Exchange(dnsMessage, udp)
		if err != nil {
			t.Fatalf("Could not send query: %s", err)
		}
		// Count SRV records in the answer section
		srvRecordCount := 0
		for _, a := range res.Answer {
			fmt.Printf("RR: %v\n", a)
			if a.Header().Rrtype == dns.TypeSRV {
				srvRecordCount++
			}
		}

		if srvRecordCount != testData.SRVRecordCount {
			t.Errorf("Expected '%v' SRV records in response. Instead got '%v' SRV records. Test query string: '%v'", testData.SRVRecordCount, srvRecordCount, testData.Query)
		}
		if len(res.Answer) != testData.TotalAnswerCount {
			t.Errorf("Expected '%v' records in answer section. Instead got '%v' records in answer section. Test query string: '%v'", testData.TotalAnswerCount, len(res.Answer), testData.Query)
		}
	}
}
Ejemplo n.º 15
0
// resolve forwards req through resolveViaResolverThreads and returns the
// reply with compression enabled.
//
// Before forwarding, req is modified in place: every OPT record is dropped
// from its additional section and a fresh EDNS0 record is added with UDP size
// dns.DefaultMsgSize and the DO bit set to dnssec.
func resolve(req *dns.Msg, dnssec bool) (*dns.Msg, error) {
	kept := make([]dns.RR, 0, len(req.Extra))
	for i := range req.Extra {
		if rr := req.Extra[i]; rr.Header().Rrtype != dns.TypeOPT {
			kept = append(kept, rr)
		}
	}
	req.Extra = kept
	req.SetEdns0(dns.DefaultMsgSize, dnssec)

	answer, _, err := resolveViaResolverThreads(req)
	if err != nil {
		return nil, err
	}
	answer.Compress = true
	return answer, nil
}
Ejemplo n.º 16
0
// dnsQuery sends a query for fqdn/rtype to nameserver using dns.Exchange and
// repeats it over TCP when that exchange reports dns.ErrTruncated. The query
// carries EDNS0 with a UDP size of 4096; recursion is switched off when
// recursive is false.
//
// The nameserver should include a port, to facilitate testing where we talk
// to a mock dns server.
func dnsQuery(fqdn string, rtype uint16, nameserver string, recursive bool) (in *dns.Msg, err error) {
	query := new(dns.Msg)
	query.SetQuestion(fqdn, rtype)
	query.SetEdns0(4096, false)

	if !recursive {
		query.RecursionDesired = false
	}

	if in, err = dns.Exchange(query, nameserver); err != dns.ErrTruncated {
		return in, err
	}

	tcpClient := &dns.Client{Net: "tcp"}
	in, _, err = tcpClient.Exchange(query, nameserver)
	return in, err
}
Ejemplo n.º 17
0
// TestLookupBalanceRewriteCacheDnssec starts a server for example.org with
// the file, rewrite, dnssec and loadbalance directives and checks that an A
// query sent with the DO bit set is answered with at least one RRSIG record.
func TestLookupBalanceRewriteCacheDnssec(t *testing.T) {
	zoneFile, removeZone, err := test.TempFile(t, ".", exampleOrg)
	if err != nil {
		t.Fatalf("failed to created zone: %s", err)
	}
	defer removeZone()
	removeKey := createKeyFile(t)
	defer removeKey()

	corefile := `example.org:0 {
    file ` + zoneFile + `
    rewrite ANY HINFO
    dnssec {
        key file ` + base + `
    }
    loadbalance
}
`
	srv, _, udpAddr, err := Server(t, corefile)
	if err != nil {
		t.Errorf("Could get server to start: %s", err)
		return
	}
	defer srv.Stop()

	log.SetOutput(ioutil.Discard)

	client := new(dns.Client)
	query := new(dns.Msg)
	query.SetQuestion("example.org.", dns.TypeA)
	query.SetEdns0(4096, true)
	reply, _, err := client.Exchange(query, udpAddr)
	if err != nil {
		t.Fatalf("Could not send query: %s", err)
	}

	rrsigs := 0
	for i := range reply.Answer {
		if reply.Answer[i].Header().Rrtype == dns.TypeRRSIG {
			rrsigs++
		}
	}
	if rrsigs > 0 {
		return
	}
	t.Errorf("expected RRSIGs, got none")
	t.Logf("%v\n", reply)
}
Ejemplo n.º 18
0
Archivo: main.go Proyecto: mpdroog/dnsd
// ServeDNS handles one DNS request.
//
// A request without a question is answered with FORMERR. If the textual form
// of the first question ends in "A" and the name extracted from it is a key
// of Adlist, the request is answered with REFUSED. Every other request is
// forwarded to 8.8.8.8:53 and the upstream reply is written back unchanged;
// if forwarding fails the request is answered with REFUSED.
func (h *Handle) ServeDNS(w dns.ResponseWriter, req *dns.Msg) {
	//fmt.Printf("Req %+v\n", req)

	// Question[0] is indexed below; reject requests without a question
	// instead of panicking. The reply is built by hand so that it does not
	// depend on how SetRcode treats an empty question section.
	if len(req.Question) == 0 {
		m := new(dns.Msg)
		m.Id = req.Id
		m.Response = true
		m.Rcode = dns.RcodeFormatError
		w.WriteMsg(m)
		return
	}

	// Blacklist lookup
	domain := req.Question[0].String()
	if strings.HasSuffix(domain, "A") {
		// Drop the first character and everything from the last "." onwards.
		domain = domain[1:strings.LastIndex(domain, ".")]
		fmt.Printf("LOOKUP=%s\n", domain)
		if _, ok := Adlist[domain]; ok {
			// todo: now what?
			fmt.Printf("DROP=%s\n", domain)
			replyWithRcode(w, req, dns.RcodeRefused)
			return
		}
	}

	// Forward
	c := new(dns.Client)
	res, rtt, err := c.Exchange(req, "8.8.8.8:53")
	if err != nil {
		fmt.Printf("Lookup fail %s", err.Error())
		replyWithRcode(w, req, dns.RcodeRefused)
		return
	}

	fmt.Printf("%s: request took %s\n", w.RemoteAddr(), rtt)
	w.WriteMsg(res)
}

// replyWithRcode writes a reply to req whose rcode is code. For every OPT
// record in req's additional section, SetEdns0 is called on the reply with
// UDP size 4096 and that record's DO bit.
func replyWithRcode(w dns.ResponseWriter, req *dns.Msg, code int) {
	m := new(dns.Msg)
	for _, r := range req.Extra {
		if r.Header().Rrtype == dns.TypeOPT {
			m.SetEdns0(4096, r.(*dns.OPT).Do())
		}
	}
	m.SetRcode(req, code)
	w.WriteMsg(m)
}
Ejemplo n.º 19
0
// main takes exactly one argument, NAME. It asks the first name server
// listed in /etc/resolv.conf for the TXT records of BASE, starts one queryOne
// goroutine per TXT record (passing the record's first string and NAME), and
// waits until the reporter goroutine signals completion.
//
// The program exits with status 1 on a usage error, when the resolver
// configuration cannot be read or lists no server, when the query fails,
// when the rcode is not success, or when the answer section is empty.
func main() {
	if len(os.Args) != 2 {
		fmt.Printf("%s NAME\n", os.Args[0])
		os.Exit(1)
	}
	name := os.Args[1]
	conf, err := dns.ClientConfigFromFile("/etc/resolv.conf")
	if err != nil {
		// conf is dereferenced below, so a read failure must stop here.
		fmt.Printf("Cannot read /etc/resolv.conf: %s\n", err)
		os.Exit(1)
	}
	if len(conf.Servers) == 0 {
		fmt.Printf("No name server found in /etc/resolv.conf\n")
		os.Exit(1)
	}
	client := new(dns.Client)
	message := new(dns.Msg)
	message.Question = []dns.Question{{Name: BASE, Qtype: dns.TypeTXT, Qclass: dns.ClassINET}}
	message.SetEdns0(4096, true)
	message.RecursionDesired = true
	reply, _, err := client.Exchange(message, conf.Servers[0]+":"+conf.Port)
	if err != nil {
		fmt.Printf("Cannot get info for %s: %s\n", BASE, err)
		os.Exit(1)
	}
	if reply.Rcode != dns.RcodeSuccess {
		fmt.Printf("Bad answer from the resolver: %v\n", reply.Rcode)
		os.Exit(1)
	}
	if len(reply.Answer) == 0 {
		fmt.Printf("Zero answer for %s\n", BASE)
		os.Exit(1)
	}
	toReporter := make(chan instanceQuery)
	urls := 0
	for _, rr := range reply.Answer {
		// Anything that is not a TXT record is ignored. Probably a DNSSEC
		// signature. TXT records without any string are skipped as well.
		txt, ok := rr.(*dns.TXT)
		if !ok || len(txt.Txt) == 0 {
			continue
		}
		urls++
		go queryOne(toReporter, txt.Txt[0], name)
	}
	fromReporter := make(chan string)
	go reporter(toReporter, fromReporter, urls)
	<-fromReporter

}
Ejemplo n.º 20
0
// Lookup looks up name n with type t using the recursive nameservers defined
// in the server's config, sending the query with EDNS0 (bufsize, dnssec)
// through s.dnsUDPclient.
//
// It returns an error when no nameserver is configured, when n has fewer
// than s.config.Ndots labels, when a reply has an rcode other than success,
// or when every attempt failed. The first nameserver is chosen from the
// message id; after a failed exchange the next one is tried, for up to
// len(Nameservers)+1 attempts in total. On success the TTL of every record in
// the answer and additional sections is overwritten with s.config.RCacheTtl.
func (s *server) Lookup(n string, t, bufsize uint16, dnssec bool) (*dns.Msg, error) {
	StatsLookupCount.Inc(1)
	promExternalRequestCount.WithLabelValues("lookup").Inc()

	if len(s.config.Nameservers) == 0 {
		return nil, fmt.Errorf("no nameservers configured can not lookup name")
	}
	if dns.CountLabel(n) < s.config.Ndots {
		return nil, fmt.Errorf("name has fewer than %d labels", s.config.Ndots)
	}

	query := new(dns.Msg)
	query.SetQuestion(n, t)
	query.SetEdns0(bufsize, dnssec)

	idx := int(query.Id) % len(s.config.Nameservers)
	for attempt := 0; ; attempt++ {
		reply, _, err := s.dnsUDPclient.Exchange(query, s.config.Nameservers[idx])
		if err == nil {
			if reply.Rcode != dns.RcodeSuccess {
				return nil, fmt.Errorf("rcode is not equal to success")
			}
			// Reset TTLs to rcache TTL to make some of the other code
			// and the tests not care about TTLs
			for _, rr := range reply.Answer {
				rr.Header().Ttl = uint32(s.config.RCacheTtl)
			}
			for _, rr := range reply.Extra {
				rr.Header().Ttl = uint32(s.config.RCacheTtl)
			}
			return reply, nil
		}
		// Seen an error, this can only mean, "server not reached", try again
		// but only if we have not exhausted our nameservers.
		if attempt >= len(s.config.Nameservers) {
			break
		}
		idx = (idx + 1) % len(s.config.Nameservers)
	}
	return nil, fmt.Errorf("failure to lookup name")
}
Ejemplo n.º 21
0
// Lookup looks up name n with type t using the recursive nameservers defined
// in the server's config, sending the query with EDNS0 (bufsize, dnssec) over
// UDP with read and write timeouts of twice s.config.ReadTimeout.
//
// It returns an error when no nameserver is configured, when n has fewer
// than s.config.Ndots labels, when a reply has an rcode other than success,
// or when every attempt failed. The first nameserver is chosen from the
// message id; after a failed exchange the next one is tried, for up to
// len(Nameservers)+1 attempts in total. On success the TTL of every record in
// the answer and additional sections is overwritten with s.config.RCacheTtl.
func (s *server) Lookup(n string, t, bufsize uint16, dnssec bool) (*dns.Msg, error) {
	StatsLookupCount.Inc(1)
	if len(s.config.Nameservers) == 0 {
		return nil, fmt.Errorf("no nameservers configured can not lookup name")
	}
	if dns.CountLabel(n) < s.config.Ndots {
		return nil, fmt.Errorf("name has fewer than %d labels", s.config.Ndots)
	}

	request := new(dns.Msg)
	request.SetQuestion(n, t)
	request.SetEdns0(bufsize, dnssec)

	// Move this to use s.udpClient/s.tcpClient code instead of allocating a new client for every query.
	udp := &dns.Client{Net: "udp", ReadTimeout: 2 * s.config.ReadTimeout, WriteTimeout: 2 * s.config.ReadTimeout}

	current := int(request.Id) % len(s.config.Nameservers)
	failures := 0
	for {
		response, _, err := udp.Exchange(request, s.config.Nameservers[current])
		if err != nil {
			// Seen an error, this can only mean, "server not reached", try again
			// but only if we have not exhausted our nameservers.
			if failures >= len(s.config.Nameservers) {
				return nil, fmt.Errorf("failure to lookup name")
			}
			failures++
			current = (current + 1) % len(s.config.Nameservers)
			continue
		}
		if response.Rcode != dns.RcodeSuccess {
			return nil, fmt.Errorf("rcode is not equal to success")
		}
		// Reset TTLs to rcache TTL to make some of the other code
		// and the tests not care about TTLs
		for _, rr := range response.Answer {
			rr.Header().Ttl = uint32(s.config.RCacheTtl)
		}
		for _, rr := range response.Extra {
			rr.Header().Ttl = uint32(s.config.RCacheTtl)
		}
		return response, nil
	}
}
Ejemplo n.º 22
0
// BenchmarkDNSSECSingleNoCache measures repeated exchanges of the query from
// test case 0, sent with the DO bit set, against a test server holding a
// single service. Server setup happens with the timer stopped; the results
// of the exchanges are not checked.
func BenchmarkDNSSECSingleNoCache(b *testing.B) {
	b.StopTimer()

	// The setup helpers take a *testing.T, so a throwaway one is passed.
	t := new(testing.T)
	srv := newTestServerDNSSEC(t, false)
	defer srv.Stop()

	svc := services[0]
	addService(t, srv, svc.Key, 0, svc)
	defer delService(t, srv, svc.Key)

	client := new(dns.Client)
	testCase := dnsTestCases[0]
	query := new(dns.Msg)
	query.SetQuestion(testCase.Qname, testCase.Qtype)
	query.SetEdns0(4096, true)

	b.StartTimer()
	for remaining := b.N; remaining > 0; remaining-- {
		client.Exchange(query, "127.0.0.1:"+StrPort)
	}
}
Ejemplo n.º 23
0
Archivo: q.go Proyecto: raybejjani/dns
// getKey queries server for the DNSKEY records of name (over TCP when tcp is
// true, with the DO bit set) and returns the first key in the answer section
// whose key tag equals keytag. It returns nil when the exchange fails or no
// such key is found.
func getKey(name string, keytag uint16, server string, tcp bool) *dns.DNSKEY {
	client := new(dns.Client)
	if tcp {
		client.Net = "tcp"
	}

	query := new(dns.Msg)
	query.SetQuestion(name, dns.TypeDNSKEY)
	query.SetEdns0(4096, true)

	reply, _, err := client.Exchange(query, server)
	if err != nil {
		return nil
	}
	for _, rr := range reply.Answer {
		key, isKey := rr.(*dns.DNSKEY)
		if isKey && key.KeyTag() == keytag {
			return key
		}
	}
	return nil
}
Ejemplo n.º 24
0
// ExampleDS retrieves the DNSKEY records of a zone from the first name server
// listed in /etc/resolv.conf and prints each key converted to a DS record for
// SHA1, SHA256 and SHA384, followed by the key's flags.
//
// It returns silently when the resolver configuration cannot be read or
// lists no server, when the exchange fails, or when the rcode is not success.
func ExampleDS() {
	config, err := dns.ClientConfigFromFile("/etc/resolv.conf")
	// config is dereferenced and indexed below; bail out if that is not safe.
	if err != nil || len(config.Servers) == 0 {
		return
	}
	c := new(dns.Client)
	m := new(dns.Msg)
	zone := "miek.nl"
	m.SetQuestion(dns.Fqdn(zone), dns.TypeDNSKEY)
	m.SetEdns0(4096, true)
	r, _, err := c.Exchange(m, config.Servers[0]+":"+config.Port)
	if err != nil {
		return
	}
	if r.Rcode != dns.RcodeSuccess {
		return
	}
	for _, k := range r.Answer {
		if key, ok := k.(*dns.DNSKEY); ok {
			for _, alg := range []uint8{dns.SHA1, dns.SHA256, dns.SHA384} {
				fmt.Printf("%s; %d\n", key.ToDS(alg).String(), key.Flags)
			}
		}
	}
}
Ejemplo n.º 25
0
// processName sends the same A query for name (with the DO bit set) to
// t.normalResolver and t.torResolver concurrently and records, per resolver,
// the elapsed time and an error text. The error text is the exchange error
// or, for a SERVFAIL reply, the rcode's name.
//
// The result is sent on t.results unless both resolvers failed with the same
// non-empty error text.
func (t *tester) processName(name string) {
	res := &result{}
	query := new(dns.Msg)
	query.SetEdns0(4096, true)
	query.SetQuestion(dns.Fqdn(name), dns.TypeA)

	var wg sync.WaitGroup
	res.Started = time.Now()
	wg.Add(2)

	go func() {
		defer wg.Done()
		defer func() { res.NormalTook = time.Since(res.Started) }()
		resp, _, err := t.client.Exchange(query, t.normalResolver)
		switch {
		case err != nil:
			res.NormalError = err.Error()
		case resp.Rcode == dns.RcodeServerFailure:
			res.NormalError = fmt.Sprintf("%s", dns.RcodeToString[resp.Rcode])
		}
	}()

	go func() {
		defer wg.Done()
		defer func() { res.TorTook = time.Since(res.Started) }()
		resp, _, err := t.client.Exchange(query, t.torResolver)
		switch {
		case err != nil:
			res.TorError = err.Error()
		case resp.Rcode == dns.RcodeServerFailure:
			res.TorError = fmt.Sprintf("%s", dns.RcodeToString[resp.Rcode])
		}
	}()

	wg.Wait()

	// Identical failures on both paths are not reported.
	if res.NormalError == "" || res.NormalError != res.TorError {
		t.results <- res
	}
}
Ejemplo n.º 26
0
Archivo: serve.go Proyecto: abh/geodns
// serve answers the query req for zone z and writes the reply to w.
//
// It logs the request, updates the global and per-zone metrics, determines
// the client address (the EDNS client subnet address when one is present,
// otherwise the remote address of w), asks z.Options.Targeting for the
// targets of that address and looks up matching labels with z.findLabels.
//
// When no labels match, the first label of the name selects one of the
// special replies "_status" (only if debugging is permitted) or "_country";
// anything else is answered with NXDOMAIN and the zone's SOA in the authority
// section. When labels match, the records chosen by labels.Picker are
// returned with their owner name set to the query name; an empty answer gets
// the zone's SOA in the authority section.
//
// If srv.queryLogger is set, a querylog.Entry is filled in along the way and
// written when the function returns.
//
// The function indexes req.Question[0] unconditionally, so req must carry at
// least one question.
func (srv *Server) serve(w dns.ResponseWriter, req *dns.Msg, z *Zone) {

	qname := req.Question[0].Name
	qtype := req.Question[0].Qtype

	var qle *querylog.Entry

	if srv.queryLogger != nil {
		qle = &querylog.Entry{
			Time:   time.Now().UnixNano(),
			Origin: z.Origin,
			Name:   qname,
			Qtype:  qtype,
		}
		// Deferred so the entry is written with whatever fields were filled in
		// by the time serve returns.
		defer srv.queryLogger.Write(qle)
	}

	logPrintf("[zone %s] incoming  %s %s (id %d) from %s\n", z.Origin, qname,
		dns.TypeToString[qtype], req.Id, w.RemoteAddr())

	// Global meter
	metrics.Get("queries").(metrics.Meter).Mark(1)

	// Zone meter
	z.Metrics.Queries.Mark(1)

	logPrintln("Got request", req)

	label := getQuestionName(z, req)

	z.Metrics.LabelStats.Add(label)

	// IP that's talking to us (not EDNS CLIENT SUBNET)
	var realIP net.IP

	// Copy the address so later use does not alias the slice owned by w's
	// remote address. realIP stays nil for any other address type.
	if addr, ok := w.RemoteAddr().(*net.UDPAddr); ok {
		realIP = make(net.IP, len(addr.IP))
		copy(realIP, addr.IP)
	} else if addr, ok := w.RemoteAddr().(*net.TCPAddr); ok {
		realIP = make(net.IP, len(addr.IP))
		copy(realIP, addr.IP)
	}
	if qle != nil {
		qle.RemoteAddr = realIP.String()
	}

	z.Metrics.ClientStats.Add(realIP.String())

	var ip net.IP // EDNS or real IP
	var edns *dns.EDNS0_SUBNET
	var opt_rr *dns.OPT

	// Scan the additional section for an OPT record and pick up a client
	// subnet option if it has an address.
	for _, extra := range req.Extra {

		switch extra.(type) {
		case *dns.OPT:
			for _, o := range extra.(*dns.OPT).Option {
				// opt_rr is only assigned when the OPT record has at least one
				// option, since the assignment sits inside the option loop.
				opt_rr = extra.(*dns.OPT)
				switch e := o.(type) {
				case *dns.EDNS0_NSID:
					// do stuff with e.Nsid
				case *dns.EDNS0_SUBNET:
					z.Metrics.EdnsQueries.Mark(1)
					logPrintln("Got edns", e.Address, e.Family, e.SourceNetmask, e.SourceScope)
					if e.Address != nil {
						edns = e
						ip = e.Address

						if qle != nil {
							qle.HasECS = true
							qle.ClientAddr = fmt.Sprintf("%s/%d", ip, e.SourceNetmask)
						}
					}
				}
			}
		}
	}

	if len(ip) == 0 { // no edns subnet
		ip = realIP
		if qle != nil {
			qle.ClientAddr = fmt.Sprintf("%s/%d", ip, len(ip)*8)
		}
	}

	targets, netmask := z.Options.Targeting.GetTargets(ip)

	if qle != nil {
		qle.Targets = targets
	}

	m := new(dns.Msg)

	if qle != nil {
		// Runs before the deferred queryLogger.Write above (defers run in
		// reverse order), so the entry records the final rcode and answer
		// count on every return path.
		defer func() {
			qle.Rcode = m.Rcode
			qle.Answers = len(m.Answer)
		}()
	}

	m.SetReply(req)
	// NOTE(review): this inspects the freshly built reply m, not req. Unless
	// SetReply copies the request's OPT record into m this branch never runs —
	// confirm whether req.IsEdns0() was intended.
	if e := m.IsEdns0(); e != nil {
		m.SetEdns0(4096, e.Do())
	}
	m.Authoritative = true

	// TODO: set scope to 0 if there are no alternate responses
	if edns != nil {
		if edns.Family != 0 {
			// The scope reported back is never narrower than /16.
			if netmask < 16 {
				netmask = 16
			}
			edns.SourceScope = uint8(netmask)
			// edns points into opt_rr's option list, so the request's OPT
			// record is echoed back with the updated scope.
			m.Extra = append(m.Extra, opt_rr)
		}
	}

	labels, labelQtype := z.findLabels(label, targets, qTypes{dns.TypeMF, dns.TypeCNAME, qtype})
	if labelQtype == 0 {
		labelQtype = qtype
	}

	if labels == nil {

		// Debug labels are served to everyone unless *flagPrivateDebug is set,
		// in which case only loopback clients may see them.
		permitDebug := !*flagPrivateDebug || (realIP != nil && realIP.IsLoopback())

		firstLabel := (strings.Split(label, "."))[0]

		if qle != nil {
			qle.LabelName = firstLabel
		}

		if permitDebug && firstLabel == "_status" {
			if qtype == dns.TypeANY || qtype == dns.TypeTXT {
				m.Answer = statusRR(label + "." + z.Origin + ".")
			} else {
				m.Ns = append(m.Ns, z.SoaRR())
			}
			m.Authoritative = true
			w.WriteMsg(m)
			return
		}

		if firstLabel == "_country" {
			if qtype == dns.TypeANY || qtype == dns.TypeTXT {
				h := dns.RR_Header{Ttl: 1, Class: dns.ClassINET, Rrtype: dns.TypeTXT}
				h.Name = label + "." + z.Origin + "."

				txt := []string{
					w.RemoteAddr().String(),
					ip.String(),
				}

				// Shadows the outer targets/netmask with a fresh lookup for ip.
				targets, netmask := z.Options.Targeting.GetTargets(ip)
				txt = append(txt, strings.Join(targets, " "))
				txt = append(txt, fmt.Sprintf("/%d", netmask), serverID, serverIP)

				m.Answer = []dns.RR{&dns.TXT{Hdr: h,
					Txt: txt,
				}}
			} else {
				m.Ns = append(m.Ns, z.SoaRR())
			}

			m.Authoritative = true

			w.WriteMsg(m)
			return
		}

		// return NXDOMAIN
		m.SetRcode(req, dns.RcodeNameError)
		m.Authoritative = true

		m.Ns = []dns.RR{z.SoaRR()}

		w.WriteMsg(m)
		return
	}

	if servers := labels.Picker(labelQtype, labels.MaxHosts); servers != nil {
		var rrs []dns.RR
		for _, record := range servers {
			// Work on a copy so that rewriting the owner name does not modify
			// the record held by the zone.
			rr := dns.Copy(record.RR)
			rr.Header().Name = qname
			rrs = append(rrs, rr)
		}
		m.Answer = rrs
	}

	if len(m.Answer) == 0 {
		// Return a SOA so the NOERROR answer gets cached
		m.Ns = append(m.Ns, z.SoaRR())
	}

	logPrintln(m)

	if qle != nil {
		qle.LabelName = labels.Label
		qle.Answers = len(m.Answer)
		qle.Rcode = m.Rcode
	}
	err := w.WriteMsg(m)
	if err != nil {
		// if Pack'ing fails the Write fails. Return SERVFAIL.
		log.Println("Error writing packet", m)
		dns.HandleFailed(w, req)
	}
	return
}
Ejemplo n.º 27
0
// TestTCPDNSServer checks how a WeaveDNS server relays an oversized reply
// from its fallback server.
//
// A mocked fallback server answers every question with numAnswers A records,
// marking the UDP reply as truncated when it exceeds the maximum reply
// length for the request. The test then checks that:
//   - the fallback server itself returns a truncated reply over UDP;
//   - the WeaveDNS server returns a truncated reply over UDP;
//   - the WeaveDNS server returns all answers, untruncated, over TCP;
//   - the WeaveDNS server returns all answers, untruncated, over UDP once
//     the query advertises a buffer of testUDPBufSize via EDNS0.
func TestTCPDNSServer(t *testing.T) {
	setupForTest(t)

	const (
		numAnswers   = 512
		nonLocalName = "weave.works."
	)

	InitDefaultLogging(testing.Verbose())
	Info.Println("TestTCPDNSServer starting")

	zone, err := NewZoneDb(ZoneConfig{})
	require.NoError(t, err)
	err = zone.Start()
	require.NoError(t, err)
	defer zone.Stop()

	// generate a list of `numAnswers` IP addresses
	var addrs []ZoneRecord
	bs := make([]byte, 4)
	for i := 0; i < numAnswers; i++ {
		binary.LittleEndian.PutUint32(bs, uint32(i))
		ip := net.IPv4(bs[0], bs[1], bs[2], bs[3])
		addrs = append(addrs, ZoneRecord(Record{"", ip, 0, 0, 0}))
	}

	// handler for the fallback server: it will just return a very long response
	fallbackUDPHandler := func(w dns.ResponseWriter, req *dns.Msg) {
		if len(req.Question) == 0 {
			return // ignore empty queries (sent when shutting down the server)
		}
		maxLen := getMaxReplyLen(req, protUDP)
		t.Logf("Fallback UDP server got asked: returning %d answers", numAnswers)
		q := req.Question[0]
		m := makeAddressReply(req, &q, addrs, DefaultLocalTTL)
		// The length is measured before the OPT record is added.
		mLen := m.Len()
		m.SetEdns0(uint16(maxLen), false)

		if mLen > maxLen {
			t.Logf("... truncated response (%d > %d)", mLen, maxLen)
			m.Truncated = true
		}
		w.WriteMsg(m)
	}
	fallbackTCPHandler := func(w dns.ResponseWriter, req *dns.Msg) {
		if len(req.Question) == 0 {
			return // ignore empty queries (sent when shutting down the server)
		}
		t.Logf("Fallback TCP server got asked: returning %d answers", numAnswers)
		q := req.Question[0]
		m := makeAddressReply(req, &q, addrs, DefaultLocalTTL)
		w.WriteMsg(m)
	}

	t.Logf("Running a DNS fallback server with UDP")
	fallback, err := newMockedFallback(fallbackUDPHandler, fallbackTCPHandler)
	require.NoError(t, err)
	fallback.Start()
	defer fallback.Stop()

	t.Logf("Creating a WeaveDNS server instance, falling back to 127.0.0.1:%d", fallback.Port)
	srv, err := NewDNSServer(DNSServerConfig{
		Zone:              zone,
		UpstreamCfg:       fallback.CliConfig,
		CacheDisabled:     true,
		ListenReadTimeout: testSocketTimeout,
	})
	require.NoError(t, err)
	err = srv.Start()
	require.NoError(t, err)
	go srv.ActivateAndServe()
	defer srv.Stop()
	time.Sleep(100 * time.Millisecond) // Allow server goroutine to start

	testPort, err := srv.GetPort()
	require.NoError(t, err)
	require.NotEqual(t, 0, testPort, "listen port")
	dnsAddr := fmt.Sprintf("127.0.0.1:%d", testPort)

	t.Logf("Creating a UDP and a TCP client")
	uc := new(dns.Client)
	uc.UDPSize = minUDPSize
	tc := new(dns.Client)
	tc.Net = "tcp"

	t.Logf("Creating DNS query message")
	m := new(dns.Msg)
	m.RecursionDesired = true
	m.SetQuestion(nonLocalName, dns.TypeA)

	t.Logf("Checking the fallback server at %s returns a truncated response with UDP", fallback.Addr)
	r, _, err := uc.Exchange(m, fallback.Addr)
	// Check the exchange before touching r: it may be nil after a failure.
	require.NoError(t, err)
	require.NotNil(t, r, "response")
	t.Logf("Got response from fallback server (UDP) with %d answers", len(r.Answer))
	t.Logf("Response:\n%+v\n", r)
	require.True(t, r.MsgHdr.Truncated, "DNS truncated reponse flag")
	require.NotEqual(t, numAnswers, len(r.Answer), "number of answers (UDP)")

	t.Logf("Checking the WeaveDNS server at %s returns a truncated reponse with UDP", dnsAddr)
	r, _, err = uc.Exchange(m, dnsAddr)
	t.Logf("UDP Response:\n%+v\n", r)
	require.NoError(t, err)
	require.NotNil(t, r, "response")
	t.Logf("%d answers", len(r.Answer))
	require.True(t, r.MsgHdr.Truncated, "DNS truncated reponse flag")
	require.NotEqual(t, numAnswers, len(r.Answer), "number of answers (UDP)")

	t.Logf("Checking the WeaveDNS server at %s does not return a truncated reponse with TCP", dnsAddr)
	r, _, err = tc.Exchange(m, dnsAddr)
	t.Logf("TCP Response:\n%+v\n", r)
	require.NoError(t, err)
	require.NotNil(t, r, "response")
	t.Logf("%d answers", len(r.Answer))
	require.False(t, r.MsgHdr.Truncated, "DNS truncated response flag")
	require.Equal(t, numAnswers, len(r.Answer), "number of answers (TCP)")

	t.Logf("Checking the WeaveDNS server at %s does not return a truncated reponse with UDP with a bigger buffer", dnsAddr)
	m.SetEdns0(testUDPBufSize, false)
	r, _, err = uc.Exchange(m, dnsAddr)
	t.Logf("UDP-large Response:\n%+v\n", r)
	require.NoError(t, err)
	require.NotNil(t, r, "response")
	t.Logf("%d answers", len(r.Answer))
	require.False(t, r.MsgHdr.Truncated, "DNS truncated response flag")
	require.Equal(t, numAnswers, len(r.Answer), "number of answers (UDP-long)")
}
// frag splits reply into a series of DNS messages ("fragments") that each
// fit in 512 bytes, using a naive algorithm.
//
// Every fragment reuses the header of reply and carries a copy of its EDNS0
// OPT record in the additional section. That copy gets one extra local
// option (code dns.EDNS0LOCALSTART+1) whose two data bytes are the total
// number of fragments and the index of this fragment.
//
// We add one RR at a time, until our fragment is larger than 512 bytes,
// then we remove the last RR so that it fits in the 512 byte size limit.
//
// If we discover that one of the fragments ends up with 0 RR in it while
// records are still waiting (for example because a single RR is too big),
// then we return a single truncated response, without any RR, instead of
// the set of fragments. A reply that has no RR to begin with is returned as
// one ordinary, non-truncated fragment.
//
// If reply has no EDNS0 OPT record, one is added to reply itself (UDP size
// 512, DO bit clear) before fragmenting, so reply is modified in that case.
// An empty slice is returned if no OPT record can be found afterwards.
//
// We could perhaps make the process of building fragments faster by
// bisecting the set of RR that we include in an answer. So, if we have 8
// RR we could try all, then if that is too big, 4 RR, and if that fits
// then 6 RR, until an optimal set of RR is found.
//
// We could also possibly produce a smaller set of responses by
// optimizing how we combine RR. Just taking into account the various sizes
// is the same as the bin packing problem, which is NP-hard:
//
//	https://en.wikipedia.org/wiki/Bin_packing_problem
//
// While some non-optimal but reasonable heuristics exist, in the case of
// DNS we would have to use some sophisticated algorithm to also consider
// name compression.
func frag(reply *dns.Msg) []dns.Msg {
	// no fragment may be larger than this many bytes
	const maxFragSize = 512

	// If we don't have EDNS0 in the packet, add it now. This has to happen
	// before the sections are copied below, otherwise the new OPT record is
	// not part of the copy and the reply is rejected as lacking EDNS0.
	if reply.IsEdns0() == nil {
		reply.SetEdns0(maxFragSize, false)
	}

	// work on copies of the RR sections, so that consuming them does not
	// touch the slices held by reply
	pendingAnswer := append([]dns.RR(nil), reply.Answer...)
	pendingNs := append([]dns.RR(nil), reply.Ns...)

	// Split the OPT record off the additional ("extra") section; it is
	// included separately in every fragment. In principle there is only one
	// OPT record; any further one is treated like an ordinary RR.
	var edns0 dns.RR
	pendingExtra := make([]dns.RR, 0, len(reply.Extra))
	for _, rr := range reply.Extra {
		if edns0 == nil && rr.Header().Rrtype == dns.TypeOPT {
			edns0 = rr
			continue
		}
		pendingExtra = append(pendingExtra, rr)
	}
	if edns0 == nil {
		log.Printf("Server reply missing EDNS0 option")
		return []dns.Msg{}
	}

	// fill moves RRs from the front of pending into *section (which must be
	// one of the sections of msg) for as long as msg still fits in
	// maxFragSize, and returns the RRs that were not moved.
	fill := func(msg *dns.Msg, section *[]dns.RR, pending []dns.RR) []dns.RR {
		for len(pending) > 0 {
			*section = append(*section, pending[0])
			if msg.Len() > maxFragSize {
				// we are full: take the RR back out and stop
				*section = (*section)[:len(*section)-1]
				break
			}
			pending = pending[1:]
		}
		return pending
	}

	// now build fragments
	frags := []dns.Msg{}
	for {
		// shallow copy of the reply packet, with fresh RR sections
		msg := *reply
		msg.Answer = []dns.RR{}
		msg.Ns = []dns.RR{}

		// our custom EDNS0 option (needed in every fragment); its data is a
		// placeholder until the total number of fragments is known
		marker := &dns.EDNS0_LOCAL{
			Code: dns.EDNS0LOCALSTART + 1,
			Data: []byte{0, 0},
		}
		opt := dns.Copy(edns0).(*dns.OPT)
		opt.Option = append(opt.Option, marker)
		msg.Extra = []dns.RR{opt}

		// add as many RR to each section as we can
		pendingAnswer = fill(&msg, &msg.Answer, pendingAnswer)
		pendingNs = fill(&msg, &msg.Ns, pendingNs)
		pendingExtra = fill(&msg, &msg.Extra, pendingExtra)

		// the OPT record is always present in Extra and does not count
		added := len(msg.Answer) + len(msg.Ns) + len(msg.Extra) - 1
		left := len(pendingAnswer) + len(pendingNs) + len(pendingExtra)

		// Records are waiting but not even one of them fits: give up and
		// return a single truncated message without any RR.
		// TODO: test this :)
		if added == 0 && left > 0 {
			msg.MsgHdr.Truncated = true
			msg.Extra = []dns.RR{}
			return []dns.Msg{msg}
		}

		frags = append(frags, msg)
		// if we have finished all remaining sections, we are done
		if left == 0 {
			break
		}
	}

	// Fix up our fragments so they have the correct count and sequence
	// values. The option is reached through pointers, so the fragments
	// stored in frags see the update. Both values are stored in a single
	// byte each.
	for n := range frags {
		for _, o := range frags[n].IsEdns0().Option {
			if o.Option() == dns.EDNS0LOCALSTART+1 {
				o.(*dns.EDNS0_LOCAL).Data = []byte{byte(len(frags)), byte(n)}
			}
		}
	}

	return frags
}
Ejemplo n.º 29
0
// TestDNS sends the question of every entry in dnsTestCases to a test server
// populated with the package-level services, and compares the response with
// the expectations stored in the test case: the rcode, the number of records
// in each section and, depending on the record type, selected fields of each
// record. The sections of the response are sorted before comparing.
//
// On the first mismatch the question and the complete response are logged
// and the test is aborted.
func TestDNS(t *testing.T) {
	s := newTestServerDNSSEC(t, false)
	defer s.Stop()

	for _, serv := range services {
		addService(t, s, serv.Key, 0, serv)
		// deliberately deferred inside the loop: the services have to stay
		// registered until the whole test is over
		defer delService(t, s, serv.Key)
	}
	c := new(dns.Client)
	for _, tc := range dnsTestCases {
		m := new(dns.Msg)
		m.SetQuestion(tc.Qname, tc.Qtype)
		if tc.dnssec {
			m.SetEdns0(4096, true)
		}
		if tc.chaos {
			m.Question[0].Qclass = dns.ClassCHAOS
		}
		resp, _, err := c.Exchange(m, "127.0.0.1:"+StrPort)
		if err != nil {
			// try twice, be more resilient against remote lookups
			// timing out.
			resp, _, err = c.Exchange(m, "127.0.0.1:"+StrPort)
			if err != nil {
				t.Fatalf("failing: %s: %s\n", m.String(), err.Error())
			}
		}
		sort.Sort(rrSet(resp.Answer))
		sort.Sort(rrSet(resp.Ns))
		sort.Sort(rrSet(resp.Extra))

		// fail logs the question and the response of the current test case
		// and then aborts the test with the given message. Every message
		// names the expected value first and the received value second.
		fail := func(format string, args ...interface{}) {
			t.Logf("question: %s\n", m.Question[0].String())
			t.Logf("%s\n", resp)
			t.Fatalf(format, args...)
		}

		if resp.Rcode != tc.Rcode {
			fail("rcode is %q, expected %q", dns.RcodeToString[resp.Rcode], dns.RcodeToString[tc.Rcode])
		}

		// answer section
		if len(resp.Answer) != len(tc.Answer) {
			fail("answer for %q contained %d results, %d expected", tc.Qname, len(resp.Answer), len(tc.Answer))
		}
		for i, a := range resp.Answer {
			want := tc.Answer[i]
			if a.Header().Name != want.Header().Name {
				fail("answer %d should have a Header Name of %q, but has %q", i, want.Header().Name, a.Header().Name)
			}
			if a.Header().Ttl != want.Header().Ttl {
				fail("Answer %d should have a Header TTL of %d, but has %d", i, want.Header().Ttl, a.Header().Ttl)
			}
			if a.Header().Rrtype != want.Header().Rrtype {
				fail("answer %d should have a header response type of %d, but has %d", i, want.Header().Rrtype, a.Header().Rrtype)
			}
			switch x := a.(type) {
			case *dns.SRV:
				tt := want.(*dns.SRV)
				if x.Priority != tt.Priority {
					fail("answer %d should have a Priority of %d, but has %d", i, tt.Priority, x.Priority)
				}
				if x.Weight != tt.Weight {
					fail("answer %d should have a Weight of %d, but has %d", i, tt.Weight, x.Weight)
				}
				if x.Port != tt.Port {
					fail("answer %d should have a Port of %d, but has %d", i, tt.Port, x.Port)
				}
				if x.Target != tt.Target {
					fail("answer %d should have a Target of %q, but has %q", i, tt.Target, x.Target)
				}
			case *dns.A:
				tt := want.(*dns.A)
				if x.A.String() != tt.A.String() {
					fail("answer %d should have a Address of %q, but has %q", i, tt.A.String(), x.A.String())
				}
			case *dns.AAAA:
				tt := want.(*dns.AAAA)
				if x.AAAA.String() != tt.AAAA.String() {
					fail("answer %d should have a Address of %q, but has %q", i, tt.AAAA.String(), x.AAAA.String())
				}
			case *dns.TXT:
				tt := want.(*dns.TXT)
				for j, txt := range x.Txt {
					// report surplus strings instead of running off the
					// end of the expected slice
					if j >= len(tt.Txt) {
						fail("answer %d has %d Txt strings, %d expected", i, len(x.Txt), len(tt.Txt))
					}
					if txt != tt.Txt[j] {
						fail("answer %d should have a Txt of %q, but has %q", i, tt.Txt[j], txt)
					}
				}
			case *dns.DNSKEY:
				tt := want.(*dns.DNSKEY)
				if x.Flags != tt.Flags {
					fail("DNSKEY flags should be %d, but is %d", tt.Flags, x.Flags)
				}
				if x.Protocol != tt.Protocol {
					fail("DNSKEY protocol should be %d, but is %d", tt.Protocol, x.Protocol)
				}
				if x.Algorithm != tt.Algorithm {
					fail("DNSKEY algorithm should be %d, but is %d", tt.Algorithm, x.Algorithm)
				}
			case *dns.RRSIG:
				tt := want.(*dns.RRSIG)
				if x.TypeCovered != tt.TypeCovered {
					fail("RRSIG type-covered should be %d, but is %d", tt.TypeCovered, x.TypeCovered)
				}
				if x.Algorithm != tt.Algorithm {
					fail("RRSIG algorithm should be %d, but is %d", tt.Algorithm, x.Algorithm)
				}
				if x.Labels != tt.Labels {
					fail("RRSIG label should be %d, but is %d", tt.Labels, x.Labels)
				}
				if x.OrigTtl != tt.OrigTtl {
					fail("RRSIG orig-ttl should be %d, but is %d", tt.OrigTtl, x.OrigTtl)
				}
				if x.KeyTag != tt.KeyTag {
					fail("RRSIG key-tag should be %d, but is %d", tt.KeyTag, x.KeyTag)
				}
				if x.SignerName != tt.SignerName {
					fail("RRSIG signer-name should be %q, but is %q", tt.SignerName, x.SignerName)
				}
			case *dns.SOA:
				tt := want.(*dns.SOA)
				if x.Ns != tt.Ns {
					fail("SOA nameserver should be %q, but is %q", tt.Ns, x.Ns)
				}
			case *dns.PTR:
				tt := want.(*dns.PTR)
				if x.Ptr != tt.Ptr {
					fail("PTR ptr should be %q, but is %q", tt.Ptr, x.Ptr)
				}
			case *dns.CNAME:
				tt := want.(*dns.CNAME)
				if x.Target != tt.Target {
					fail("CNAME target should be %q, but is %q", tt.Target, x.Target)
				}
			case *dns.MX:
				tt := want.(*dns.MX)
				if x.Mx != tt.Mx {
					fail("MX Mx should be %q, but is %q", tt.Mx, x.Mx)
				}
				if x.Preference != tt.Preference {
					fail("MX Preference should be %d, but is %d", tt.Preference, x.Preference)
				}
			}
		}

		// authority section
		if len(resp.Ns) != len(tc.Ns) {
			fail("authority for %q contained %d results, %d expected", tc.Qname, len(resp.Ns), len(tc.Ns))
		}
		for i, n := range resp.Ns {
			switch x := n.(type) {
			case *dns.SOA:
				tt := tc.Ns[i].(*dns.SOA)
				if x.Ns != tt.Ns {
					fail("SOA nameserver should be %q, but is %q", tt.Ns, x.Ns)
				}
			case *dns.NS:
				tt := tc.Ns[i].(*dns.NS)
				if x.Ns != tt.Ns {
					fail("NS nameserver should be %q, but is %q", tt.Ns, x.Ns)
				}
			case *dns.NSEC3:
				tt := tc.Ns[i].(*dns.NSEC3)
				if x.NextDomain != tt.NextDomain {
					fail("NSEC3 nextdomain should be %q, but is %q", tt.NextDomain, x.NextDomain)
				}
				if x.Hdr.Name != tt.Hdr.Name {
					fail("NSEC3 ownername should be %q, but is %q", tt.Hdr.Name, x.Hdr.Name)
				}
				for j, y := range x.TypeBitMap {
					// report surplus types instead of running off the end
					// of the expected bitmap
					if j >= len(tt.TypeBitMap) {
						fail("NSEC3 bitmap has %d types, %d expected", len(x.TypeBitMap), len(tt.TypeBitMap))
					}
					if y != tt.TypeBitMap[j] {
						fail("NSEC3 bitmap should have %q, but is %q", dns.TypeToString[tt.TypeBitMap[j]], dns.TypeToString[y])
					}
				}
			}
		}

		// additional section
		if len(resp.Extra) != len(tc.Extra) {
			fail("additional for %q contained %d results, %d expected", tc.Qname, len(resp.Extra), len(tc.Extra))
		}
		for i, e := range resp.Extra {
			switch x := e.(type) {
			case *dns.A:
				tt := tc.Extra[i].(*dns.A)
				if x.A.String() != tt.A.String() {
					fail("extra %d should have a address of %q, but has %q", i, tt.A.String(), x.A.String())
				}
			case *dns.AAAA:
				tt := tc.Extra[i].(*dns.AAAA)
				if x.AAAA.String() != tt.AAAA.String() {
					fail("extra %d should have a address of %q, but has %q", i, tt.AAAA.String(), x.AAAA.String())
				}
			case *dns.CNAME:
				tt := tc.Extra[i].(*dns.CNAME)
				if x.Target != tt.Target {
					// Super super gross hack.
					if x.Target == "a.ipaddr.skydns.test." && tt.Target == "b.ipaddr.skydns.test." {
						// These records are randomly chosen, either one is OK.
						continue
					}
					fail("CNAME target should be %q, but is %q", tt.Target, x.Target)
				}
			}
		}
	}
}
Ejemplo n.º 30
0
// TestDNSStubForward exercises forwarding through stub zones. It registers
// three stub services on a test server (an external name server, an address
// on port 5454 that is expected not to answer, and one pointing at Port) and
// checks the outcome of a query for a name below each of them. It then
// starts two more test servers for other domains, registers stub services
// for those domains on the first server, and checks that queries sent to the
// first server's port return the A records stored in the new servers.
//
// NOTE(review): the first query relies on the hard-coded external address
// below being reachable from the machine running the test.
func TestDNSStubForward(t *testing.T) {
	s := newTestServer(t, false)
	defer s.Stop()

	c := new(dns.Client)
	m := new(dns.Msg)

	stubEx := &msg.Service{
		// IP address of a.iana-servers.net.
		Host: "199.43.132.53", Key: "a.example.com.stub.dns.skydns.test.",
	}
	stubBroken := &msg.Service{
		Host: "127.0.0.1", Port: 5454, Key: "b.example.org.stub.dns.skydns.test.",
	}
	stubLoop := &msg.Service{
		Host: "127.0.0.1", Port: Port, Key: "b.example.net.stub.dns.skydns.test.",
	}
	addService(t, s, stubEx.Key, 0, stubEx)
	defer delService(t, s, stubEx.Key)
	addService(t, s, stubBroken.Key, 0, stubBroken)
	defer delService(t, s, stubBroken.Key)
	addService(t, s, stubLoop.Key, 0, stubLoop)
	defer delService(t, s, stubLoop.Key)

	s.UpdateStubZones()

	// checkFirstA aborts the test unless the first answer of r is an A
	// record holding the address want.
	checkFirstA := func(r *dns.Msg, want string) {
		if len(r.Answer) == 0 {
			t.Fatalf("failed to get correct reply: no answers, expected A %s", want)
		}
		a, ok := r.Answer[0].(*dns.A)
		if !ok {
			t.Fatalf("failed to get correct reply: first answer is a %T, expected A %s", r.Answer[0], want)
		}
		if a.A.String() != want {
			t.Fatalf("failed to get correct reply: got A %s, expected %s", a.A.String(), want)
		}
	}

	m.SetQuestion("www.example.com.", dns.TypeA)
	resp, _, err := c.Exchange(m, "127.0.0.1:"+StrPort)
	if err != nil {
		// try twice
		resp, _, err = c.Exchange(m, "127.0.0.1:"+StrPort)
		if err != nil {
			t.Fatal(err)
		}
	}
	if len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess {
		t.Fatal("answer expected to have A records or rcode not equal to RcodeSuccess")
	}
	// The main diff. here is that we expect the AA bit to be set, because we directly
	// queried the authoritative servers.
	if !resp.Authoritative {
		t.Fatal("answer expected to have AA bit set")
	}

	// This should fail.
	m.SetQuestion("www.example.org.", dns.TypeA)
	resp, _, err = c.Exchange(m, "127.0.0.1:"+StrPort)
	if resp == nil {
		// without a response there is nothing to inspect below
		t.Fatalf("no response for example.org: %v", err)
	}
	if len(resp.Answer) != 0 || resp.Rcode != dns.RcodeServerFailure {
		t.Fatal("answer expected to fail for example.org")
	}

	// This should really fail with a timeout.
	m.SetQuestion("www.example.net.", dns.TypeA)
	_, _, err = c.Exchange(m, "127.0.0.1:"+StrPort)
	if err == nil {
		t.Fatal("answer expected to fail for example.net")
	}
	t.Logf("succesfully failing %s", err)

	// Packet with EDNS0
	m.SetEdns0(4096, true)
	_, _, err = c.Exchange(m, "127.0.0.1:"+StrPort)
	if err == nil {
		t.Fatal("answer expected to fail for example.net")
	}
	t.Logf("succesfully failing %s", err)

	// Now start another SkyDNS instance on a different port,
	// add a stubservice for it and check if the forwarding is
	// actually working.
	oldStrPort := StrPort

	s1 := newTestServer(t, false)
	defer s1.Stop()
	s1.config.Domain = "skydns.com."

	// Add forwarding IP for internal.skydns.com. Use Port to point to server s.
	stubForward := &msg.Service{
		Host: "127.0.0.1", Port: Port, Key: "b.internal.skydns.com.stub.dns.skydns.test.",
	}
	addService(t, s, stubForward.Key, 0, stubForward)
	defer delService(t, s, stubForward.Key)
	s.UpdateStubZones()

	// Add an answer for this in our "new" server.
	stubReply := &msg.Service{
		Host: "127.1.1.1", Key: "www.internal.skydns.com.",
	}
	addService(t, s1, stubReply.Key, 0, stubReply)
	defer delService(t, s1, stubReply.Key)

	m = new(dns.Msg)
	m.SetQuestion("www.internal.skydns.com.", dns.TypeA)
	resp, _, err = c.Exchange(m, "127.0.0.1:"+oldStrPort)
	if err != nil {
		t.Fatalf("failed to forward %s", err)
	}
	checkFirstA(resp, "127.1.1.1")

	// Adding an in bailiwick internal domain forward.
	s2 := newTestServer(t, false)
	defer s2.Stop()
	s2.config.Domain = "internal.skydns.net."

	// Add forwarding IP for internal.skydns.net. Use Port to point to server s.
	stubForward1 := &msg.Service{
		Host: "127.0.0.1", Port: Port, Key: "b.internal.skydns.net.stub.dns.skydns.test.",
	}
	addService(t, s, stubForward1.Key, 0, stubForward1)
	defer delService(t, s, stubForward1.Key)
	s.UpdateStubZones()

	// Add an answer for this in our "new" server.
	stubReply1 := &msg.Service{
		Host: "127.10.10.10", Key: "www.internal.skydns.net.",
	}
	addService(t, s2, stubReply1.Key, 0, stubReply1)
	defer delService(t, s2, stubReply1.Key)

	m = new(dns.Msg)
	m.SetQuestion("www.internal.skydns.net.", dns.TypeA)
	resp, _, err = c.Exchange(m, "127.0.0.1:"+oldStrPort)
	if err != nil {
		t.Fatalf("failed to forward %s", err)
	}
	checkFirstA(resp, "127.10.10.10")
}