Example #1
0
// TestCacheEntries checks the cache's per-entry behaviour: misses before any
// Put(), TTL decay driven by the mock clock, expiry, Remove() (including a
// repeated Remove() and a Remove() between Get() and Put()), replacement of an
// entry by a later Put(), and entries stored with CacheNoLocalReplies.
func TestCacheEntries(t *testing.T) {
	InitDefaultLogging(testing.Verbose())
	Info.Println("TestCacheEntries starting")

	Info.Println("Checking cache consistency")

	const cacheLen = 128
	clk := clock.NewMock()

	l, err := NewCache(cacheLen, clk)
	wt.AssertNoErr(t, err)

	// getNothing asserts that Get() for msg succeeds but finds no cached reply.
	getNothing := func(msg *dns.Msg, what string) {
		resp, err := l.Get(msg, minUDPSize)
		wt.AssertNoErr(t, err)
		if resp != nil {
			t.Logf("Got\n%s", resp)
			t.Fatalf("ERROR: Did not expect a response from %s", what)
		}
	}

	// getA asserts that Get() for msg returns a reply whose first answer is an
	// A record, and returns that record.
	getA := func(msg *dns.Msg, what string) *dns.A {
		resp, err := l.Get(msg, minUDPSize)
		wt.AssertNoErr(t, err)
		wt.AssertTrue(t, resp != nil, "response from "+what)
		// Fail with a message rather than panicking on the index below.
		wt.AssertTrue(t, len(resp.Answer) > 0, "at least one answer in the response from "+what)
		t.Logf("Received '%s'", resp.Answer[0])
		wt.AssertType(t, resp.Answer[0], (*dns.A)(nil), "DNS record")
		return resp.Answer[0].(*dns.A)
	}

	// putA caches, for msg, an address reply built from a single record
	// (name, ip), using the same Put() arguments throughout this test.
	putA := func(msg *dns.Msg, name, ip string) {
		records := []ZoneRecord{Record{name, net.ParseIP(ip), 0, 0, 0}}
		reply := makeAddressReply(msg, &msg.Question[0], records)
		l.Put(msg, reply, nullTTL, 0)
	}

	questionMsg := new(dns.Msg)
	questionMsg.SetQuestion("some.name", dns.TypeA)
	questionMsg.RecursionDesired = true
	question := &questionMsg.Question[0]

	t.Logf("Trying to get a name")
	getNothing(questionMsg, "Get() yet")
	t.Logf("Trying to get it again")
	getNothing(questionMsg, "Get() yet")

	t.Logf("Inserting the reply")
	putA(questionMsg, "some.name", "10.0.1.1")

	t.Logf("Checking we can Get() the reply now")
	ttlGet1 := getA(questionMsg, "Get()").Header().Ttl

	clk.Add(time.Second)
	t.Logf("Checking that a second Get(), after 1 second, gets the same result, but with reduced TTL")
	ttlGet2 := getA(questionMsg, "a second Get()").Header().Ttl
	wt.AssertEqualInt(t, int(ttlGet1-ttlGet2), 1, "TTL difference")

	clk.Add(time.Duration(localTTL) * time.Second)
	t.Logf("Checking that a third Get(), after %v seconds, gets no result", localTTL)
	getNothing(questionMsg, "the third Get()")

	t.Logf("Checking that a Remove() results in Get() returning nothing")
	putA(questionMsg, "some.name", "10.0.9.9")
	lenBefore := l.Len()
	l.Remove(question)
	wt.AssertEqualInt(t, l.Len(), lenBefore-1, "cache length")
	l.Remove(question) // do it again: should have no effect...
	wt.AssertEqualInt(t, l.Len(), lenBefore-1, "cache length")
	getNothing(questionMsg, "the Get() after a Remove()")

	t.Logf("Inserting two replies for the same query")
	putA(questionMsg, "some.name", "10.0.1.2")
	clk.Add(time.Second)
	putA(questionMsg, "some.name", "10.0.1.3")

	t.Logf("Checking we get the last one...")
	rr := getA(questionMsg, "the Get()")
	wt.AssertEqualString(t, rr.A.String(), "10.0.1.3", "IP address")
	wt.AssertEqualInt(t, int(rr.Header().Ttl), int(localTTL), "TTL")

	clk.Add(time.Duration(localTTL-1) * time.Second)
	rr = getA(questionMsg, "the Get()")
	wt.AssertEqualString(t, rr.A.String(), "10.0.1.3", "IP address")
	wt.AssertEqualInt(t, int(rr.Header().Ttl), 1, "TTL")

	t.Logf("Checking we get empty replies when they are expired...")
	lenBefore = l.Len()
	clk.Add(time.Duration(localTTL) * time.Second)
	getNothing(questionMsg, "the Get()")
	wt.AssertEqualInt(t, l.Len(), lenBefore-1, "cache length (after getting an expired entry)")

	questionMsg2 := new(dns.Msg)
	questionMsg2.SetQuestion("some.other.name", dns.TypeA)
	questionMsg2.RecursionDesired = true
	question2 := &questionMsg2.Question[0]

	t.Logf("Trying to Get() a name")
	getNothing(questionMsg2, "Get() yet")

	t.Logf("Checking that a Remove() between Get() and Put() does not break things")
	l.Remove(question2)
	putA(questionMsg2, "some.name", "10.0.9.9")
	resp, err := l.Get(questionMsg2, minUDPSize)
	wt.AssertNoErr(t, err)
	wt.AssertNotNil(t, resp, "response from Get()")

	questionMsg3 := new(dns.Msg)
	questionMsg3.SetQuestion("some.other.name", dns.TypeA)
	questionMsg3.RecursionDesired = true
	question3 := &questionMsg3.Question[0]

	t.Logf("Checking that an entry with CacheNoLocalReplies returns an error")
	l.Put(questionMsg3, nil, nullTTL, CacheNoLocalReplies)
	resp, err = l.Get(questionMsg3, minUDPSize)
	wt.AssertNil(t, resp, "Get() response with CacheNoLocalReplies")
	wt.AssertNotNil(t, err, "Get() error with CacheNoLocalReplies")

	clk.Add(time.Second * time.Duration(negLocalTTL+1))
	// %v rather than %f: negLocalTTL is used as an integer constant above.
	t.Logf("Checking that we get an expired response after %v seconds", negLocalTTL)
	resp, err = l.Get(questionMsg3, minUDPSize)
	wt.AssertNil(t, resp, "expired Get() response with CacheNoLocalReplies")
	wt.AssertNil(t, err, "expired Get() error with CacheNoLocalReplies")

	l.Remove(question3)
	t.Logf("Checking that Put&Get with CacheNoLocalReplies with a Remove in the middle returns nothing")
	l.Put(questionMsg3, nil, nullTTL, CacheNoLocalReplies)
	l.Remove(question3)
	resp, err = l.Get(questionMsg3, minUDPSize)
	wt.AssertNil(t, resp, "Get() response with CacheNoLocalReplies")
	wt.AssertNil(t, err, "Get() error with CacheNoLocalReplies")
}
Example #2
0
// TestUDPDNSServer starts a DNS server over a zone holding one local record,
// with a second local server acting as its upstream (fallback). It checks
// answers for local names, reverse lookups, and non-local queries that must
// be forwarded to the fallback server.
func TestUDPDNSServer(t *testing.T) {
	setupForTest(t)
	const (
		containerID     = "foobar"
		successTestName = "test1.weave.local."
		failTestName    = "test2.weave.local."
		nonLocalName    = "weave.works."
		testAddr1       = "10.2.2.1"
	)
	testCIDR1 := testAddr1 + "/24"

	InitDefaultLogging(testing.Verbose())
	zone := NewZoneDb(DefaultLocalDomain)
	ip, _, err := net.ParseCIDR(testCIDR1)
	wt.AssertNoErr(t, err)
	zone.AddRecord(containerID, successTestName, ip)

	// fallbackHandler plays the upstream server. It answers the specific
	// non-local queries below and sends an empty reply for anything else.
	fallbackHandler := func(w dns.ResponseWriter, req *dns.Msg) {
		m := new(dns.Msg)
		m.SetReply(req)
		if len(req.Question) == 1 {
			q := req.Question[0]
			header := func(rrtype uint16) dns.RR_Header {
				return dns.RR_Header{Name: q.Name, Rrtype: rrtype, Class: dns.ClassINET, Ttl: 0}
			}
			switch {
			case q.Name == nonLocalName && q.Qtype == dns.TypeMX:
				m.Answer = []dns.RR{&dns.MX{Hdr: header(dns.TypeMX), Mx: "mail." + nonLocalName}}
			case q.Name == nonLocalName && q.Qtype == dns.TypeANY:
				// Each target name is at least len("mailn."+nonLocalName) bytes,
				// so the target names alone total more than 512 bytes.
				m.Answer = make([]dns.RR, 512/len("mailn."+nonLocalName)+1)
				for i := range m.Answer {
					m.Answer[i] = &dns.MX{Hdr: header(dns.TypeMX), Mx: fmt.Sprintf("mail%d.%s", i, nonLocalName)}
				}
			case q.Name == testRDNSnonlocal && q.Qtype == dns.TypePTR:
				m.Answer = []dns.RR{&dns.PTR{Hdr: header(dns.TypePTR), Ptr: "ns1.google.com."}}
			case q.Name == testRDNSfail && q.Qtype == dns.TypePTR:
				m.Rcode = dns.RcodeNameError
			}
		}
		// The write error is deliberately ignored: this runs on the server's
		// goroutine, where the test must not call t.Fatal.
		w.WriteMsg(m)
	}

	// Run another DNS server for fallback
	s, fallbackAddr, err := runLocalUDPServer(t, "127.0.0.1:0", fallbackHandler)
	wt.AssertNoErr(t, err)
	defer s.Shutdown()

	_, fallbackPort, err := net.SplitHostPort(fallbackAddr)
	wt.AssertNoErr(t, err)

	config := &dns.ClientConfig{Servers: []string{"127.0.0.1"}, Port: fallbackPort}
	srv, err := NewDNSServer(DNSServerConfig{UpstreamCfg: config, Port: testPort}, zone, nil)
	wt.AssertNoErr(t, err)
	defer srv.Stop()
	go srv.Start()
	time.Sleep(100 * time.Millisecond) // Allow server goroutine to start

	var r *dns.Msg

	r = assertExchange(t, successTestName, dns.TypeA, 1, 1, 0)
	wt.AssertType(t, r.Answer[0], (*dns.A)(nil), "DNS record")
	wt.AssertEqualString(t, r.Answer[0].(*dns.A).A.String(), testAddr1, "IP address")

	assertExchange(t, failTestName, dns.TypeA, 0, 0, dns.RcodeNameError)

	r = assertExchange(t, testRDNSsuccess, dns.TypePTR, 1, 1, 0)
	wt.AssertType(t, r.Answer[0], (*dns.PTR)(nil), "DNS record")
	wt.AssertEqualString(t, r.Answer[0].(*dns.PTR).Ptr, successTestName, "hostname")

	assertExchange(t, testRDNSfail, dns.TypePTR, 0, 0, dns.RcodeNameError)

	// This should fail because we don't have information about MX records
	assertExchange(t, successTestName, dns.TypeMX, 0, 0, dns.RcodeNameError)

	// This non-local query for an MX record should succeed by being
	// passed on to the fallback server
	assertExchange(t, nonLocalName, dns.TypeMX, 1, -1, 0)

	// Now ask a query that we expect to return a lot of data.
	assertExchange(t, nonLocalName, dns.TypeANY, 5, -1, 0)

	assertExchange(t, testRDNSnonlocal, dns.TypePTR, 1, -1, 0)

	// Not testing MDNS functionality of server here (yet), since it
	// needs two servers, each listening on its own address
}