Ejemplo n.º 1
0
// ServeDNS logs the request with the first rule whose NameScope matches the
// queried name, after passing it down the middleware chain through a
// ResponseRecorder. If the next handler reports an error rcode (> 0) without
// writing a reply, an error reply is produced here (via l.ErrorFunc when set)
// so that the log line reflects what was actually sent; the returned rcode is
// then 0. Requests matching no rule are passed on without logging.
func (l Logger) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
	state := middleware.State{W: w, Req: r}
	for _, rule := range l.Rules {
		if !middleware.Name(rule.NameScope).Matches(state.Name()) {
			continue
		}

		responseRecorder := middleware.NewResponseRecorder(w)
		rcode, err := l.Next.ServeDNS(ctx, responseRecorder, r)

		if rcode > 0 {
			// There was an error up the chain, but no response has been written yet.
			// The error must be handled here so the log entry will record the response size.
			if l.ErrorFunc != nil {
				l.ErrorFunc(responseRecorder, r, rcode)
			} else {
				rc := middleware.RcodeToString(rcode)

				answer := new(dns.Msg)
				answer.SetRcode(r, rcode)
				state.SizeAndDo(answer)

				metrics.Report(state, metrics.Dropped, rc, answer.Len(), time.Now())
				// Write through the recorder rather than w directly, so the
				// rcode and size of this reply are captured for the log line.
				responseRecorder.WriteMsg(answer)
			}
			rcode = 0
		}
		rep := middleware.NewReplacer(r, responseRecorder, CommonLogEmptyValue)
		rule.Log.Println(rep.Replace(rule.Format))
		return rcode, err
	}
	return l.Next.ServeDNS(ctx, w, r)
}
Ejemplo n.º 2
0
// setReply stores reply in the entry and marks the entry as resolved with the
// given flags and put time.
//
// validUntil is set to now+ttl seconds when ttl != nullTTL; otherwise, when
// reply is non-nil, to now plus the smallest TTL found in reply.Answer. If ttl
// is nullTTL and reply is nil, validUntil keeps its previous value.
// NOTE(review): a reply with an empty Answer section leaves the minimum at
// math.MaxUint32 seconds (~136 years) — confirm this is intended.
//
// A deep copy of reply is stored, so later changes the caller makes to its
// message (or to its record slices) cannot alter the cached entry.
//
// It returns true if the entry's validUntil time changed.
func (e *cacheEntry) setReply(reply *dns.Msg, ttl int, flags uint8, now time.Time) bool {
	var prevValidUntil time.Time
	if e.Status == stResolved {
		if reply != nil {
			Debug.Printf("[cache msgid %d] replacing response in cache", reply.MsgHdr.Id)
		}
		prevValidUntil = e.validUntil
	}

	e.Status = stResolved
	e.Flags = flags
	e.putTime = now

	if ttl != nullTTL {
		e.validUntil = now.Add(time.Second * time.Duration(ttl))
	} else if reply != nil {
		// calculate the validUntil from the reply TTL
		var minTTL uint32 = math.MaxUint32
		for _, rr := range reply.Answer {
			ttl := rr.Header().Ttl
			if ttl < minTTL {
				minTTL = ttl // TODO: improve the minTTL calculation (maybe we should skip some RRs)
			}
		}
		e.validUntil = now.Add(time.Second * time.Duration(minTTL))
	}

	if reply != nil {
		// Copy so the cache does not share Answer/Ns/Extra slices with the caller.
		e.reply = *reply.Copy()
		e.ReplyLen = reply.Len()
	}

	// Equal compares instants; == would also compare location and monotonic
	// clock readings.
	return !prevValidUntil.Equal(e.validUntil)
}
Ejemplo n.º 3
0
// ReportErrorCount increments the error counter for sys according to what is
// wrong with resp. A truncated response is counted first, then one larger than
// dns.MaxMsgSize, then SERVFAIL, REFUSED and NXDOMAIN response codes. Anything
// else is not counted. It is a no-op when resp is nil or the counter has not
// been set up.
func ReportErrorCount(resp *dns.Msg, sys System) {
	if errorCount == nil || resp == nil {
		return
	}

	var kind string
	switch {
	case resp.Truncated:
		kind = string(Truncated)
	case resp.Len() > dns.MaxMsgSize:
		kind = string(Overflow)
	case resp.Rcode == dns.RcodeServerFailure:
		kind = string(Fail)
	case resp.Rcode == dns.RcodeRefused:
		kind = string(Refused)
	case resp.Rcode == dns.RcodeNameError:
		kind = string(Nxdomain)
	default:
		// NODATA is not counted here; whether it should be is an open question.
		return
	}
	errorCount.WithLabelValues(string(sys), kind).Inc()
}
Ejemplo n.º 4
0
// WriteMsg remembers res as the most recently written message, records its
// rcode, adds its length to the running size total and then forwards it to
// the wrapped ResponseWriter, returning that writer's error.
//
// A handler may write several messages for one request (AXFR, for instance),
// so the size accumulates across calls while rcode and msg always describe
// the last message.
func (r *ResponseRecorder) WriteMsg(res *dns.Msg) error {
	r.size += res.Len()
	r.rcode, r.msg = res.Rcode, res
	return r.ResponseWriter.WriteMsg(res)
}
Ejemplo n.º 5
0
// Sign signs a message m, it takes care of negative or nodata responses as
// well by synthesising NSEC3 records. It will also cache the signatures, using
// a hash of the signed data as a key.
// We also fake the origin TTL in the signature, because we don't want to
// throw away signatures when services decide to have longer TTL. So we just
// set the origTTL to 60.
// TODO(miek): revisit origTTL
//
// Only RRsets under s.config.Domain are signed; existing RRSIGs and OPT
// records are skipped. When bufsize lies in 512..4096 the TC bit is set if
// the signed message exceeds it. Finally an OPT record with the DO bit and a
// 4096-byte UDP size is appended to the additional section.
func (s *server) Sign(m *dns.Msg, bufsize uint16) {
	now := time.Now().UTC()
	incep := uint32(now.Add(-3 * time.Hour).Unix())     // 2+1 hours, be sure to catch daylight saving time and such
	expir := uint32(now.Add(7 * 24 * time.Hour).Unix()) // sign for a week

	defer func() {
		promCacheSize.WithLabelValues("signature").Set(float64(s.scache.Size()))
	}()

	// signSection returns section with an RRSIG appended for every RRset in
	// it that belongs to s.config.Domain. RRSIG and OPT records are skipped.
	// The RRsets are collected before anything is appended, so new signatures
	// are never signed themselves. Sets that fail to sign are left unsigned.
	signSection := func(section []dns.RR) []dns.RR {
		for _, r := range rrSets(section) {
			hdr := r[0].Header()
			if hdr.Rrtype == dns.TypeRRSIG || hdr.Rrtype == dns.TypeOPT {
				continue
			}
			if !dns.IsSubDomain(s.config.Domain, hdr.Name) {
				continue
			}
			if sig, err := s.signSet(r, now, incep, expir); err == nil {
				section = append(section, sig)
			}
		}
		return section
	}
	m.Answer = signSection(m.Answer)
	m.Ns = signSection(m.Ns)
	m.Extra = signSection(m.Extra)

	// Both bounds must hold (`&&`): larger buffers, such as the 64K size used
	// for TCP, skip the TC check. Only count messages that actually get TC.
	if bufsize >= 512 && bufsize <= 4096 {
		m.Truncated = m.Len() > int(bufsize)
		if m.Truncated {
			promErrorCount.WithLabelValues("truncated").Inc()
		}
	}
	o := new(dns.OPT)
	o.Hdr.Name = "."
	o.Hdr.Rrtype = dns.TypeOPT
	o.SetDo()
	o.SetUDPSize(4096) // TODO(miek): echo client
	m.Extra = append(m.Extra, o)
}
Ejemplo n.º 6
0
// DefaultErrorFunc answers r on w with an otherwise empty reply carrying
// rcode. The reply is passed through State.SizeAndDo, reported to the metrics
// as dropped, and then written; any write error is ignored.
func DefaultErrorFunc(w dns.ResponseWriter, r *dns.Msg, rcode int) {
	reply := new(dns.Msg)
	reply.SetRcode(r, rcode)

	st := middleware.State{W: w, Req: r}
	st.SizeAndDo(reply)

	metrics.Report(st, metrics.Dropped, middleware.RcodeToString(rcode), reply.Len(), time.Now())
	w.WriteMsg(reply)
}
Ejemplo n.º 7
0
// truncateResp trims resp so that it fits in maxSize bytes. For UDP (isTCP
// false) the TC bit is set unconditionally; for TCP it is left untouched.
// Answer RRs are dropped from the end one at a time until the message fits or
// no answers remain. Other sections are never trimmed, so the result can still
// exceed maxSize when they alone are too large.
func truncateResp(resp *dns.Msg, maxSize int, isTCP bool) {
	if !isTCP {
		resp.Truncated = true
	}

	// Stop once the Answer section is empty: slicing further would panic.
	for len(resp.Answer) > 0 && resp.Len() > maxSize {
		resp.Answer = resp.Answer[:len(resp.Answer)-1]
	}
}
Ejemplo n.º 8
0
// metricSizeAndDuration records, under a "udp" or "tcp" label, how long the
// request has taken since start (in seconds) and the size of resp in bytes
// (0 when resp is nil).
func metricSizeAndDuration(resp *dns.Msg, start time.Time, tcp bool) {
	proto := "udp"
	if tcp {
		proto = "tcp"
	}

	var size float64
	if resp != nil {
		size = float64(resp.Len())
	}

	elapsed := float64(time.Since(start)) / float64(time.Second)
	promRequestDuration.WithLabelValues(proto).Observe(elapsed)
	promResponseSize.WithLabelValues(proto).Observe(size)
}
Ejemplo n.º 9
0
// ReportDuration observes, for system sys, the time elapsed since start (in
// seconds) and the size of resp in bytes (0 when resp is nil). It does nothing
// unless both the duration and the size metrics have been set up.
func ReportDuration(resp *dns.Msg, start time.Time, sys System) {
	if requestDuration == nil || responseSize == nil {
		return
	}

	var size float64
	if resp != nil {
		size = float64(resp.Len())
	}

	label := string(sys)
	elapsed := float64(time.Since(start)) / float64(time.Second)
	requestDuration.WithLabelValues(label).Observe(elapsed)
	responseSize.WithLabelValues(label).Observe(size)
}
Ejemplo n.º 10
0
// Respond answers req on w with records. The reply is authoritative, has
// recursion available and uses name compression.
//
// Size handling: a reply larger than dns.MaxMsgSize loses its Extra section
// and, if still too large, is replaced by a SERVFAIL. Over UDP, a reply larger
// than the client's buffer (its EDNS0 UDP size, at least 512) loses its Extra
// section and then, if needed, all answers, with the TC bit set. Write errors
// are logged, not returned.
func Respond(w dns.ResponseWriter, req *dns.Msg, records []dns.RR) {
	m := new(dns.Msg)
	m.SetReply(req)
	m.Authoritative = true
	m.RecursionAvailable = true
	m.Compress = true
	m.Answer = records

	// Figure out the max response size
	bufsize := uint16(512)
	tcp := isTcp(w)

	if o := req.IsEdns0(); o != nil {
		bufsize = o.UDPSize()
	}

	if tcp {
		bufsize = dns.MaxMsgSize - 1
	} else if bufsize < 512 {
		bufsize = 512
	}

	if m.Len() > dns.MaxMsgSize {
		fqdn := ""
		if len(req.Question) > 0 {
			fqdn = dns.Fqdn(req.Question[0].Name)
		}
		log.WithFields(log.Fields{"fqdn": fqdn}).Debug("Response too big, dropping Extra")
		m.Extra = nil
		if m.Len() > dns.MaxMsgSize {
			log.WithFields(log.Fields{"fqdn": fqdn}).Debug("Response still too big")
			// Assign to the outer m (do not redeclare it) and derive the
			// SERVFAIL from req, so it really replaces the oversized reply.
			m = new(dns.Msg)
			m.SetRcode(req, dns.RcodeServerFailure)
		}
	}

	if m.Len() > int(bufsize) && !tcp {
		log.Debug("Too big 1")
		m.Extra = nil
		if m.Len() > int(bufsize) {
			log.Debug("Too big 2")
			m.Answer = nil
			m.Truncated = true
		}
	}

	err := w.WriteMsg(m)
	if err != nil {
		log.Warn("Failed to return reply: ", err, m.Len())
	}
}
Ejemplo n.º 11
0
// TestFit checks that Fit trims an oversized answer section so that the
// message fits in the requested size, and that it reports doing so.
func TestFit(t *testing.T) {
	m := new(dns.Msg)
	m.SetQuestion("miek.nl", dns.TypeA)

	rr, err := dns.NewRR("www.miek.nl. IN SRV 10 10 8080 blaat.miek.nl.")
	if err != nil {
		t.Fatalf("failed to parse test RR: %v", err)
	}
	for i := 0; i < 101; i++ {
		m.Answer = append(m.Answer, rr)
	}
	// Uncompressed length is now 4424. Try trimming this to 1927.
	m, trimmed := Fit(m, 1927, true)

	if !trimmed {
		t.Errorf("expected Fit to report that answers were dropped")
	}
	if m.Len() > 1927 {
		t.Fatalf("failed to fit message, expected <= %d, got %d", 1927, m.Len())
	}
}
Ejemplo n.º 12
0
// trimAnswers keeps a UDP response within the limits of RFC 1035. It first
// caps the Answer section at maxServiceResponses records, then drops answers
// from the end until the whole message is at most 512 bytes or no answers are
// left. It reports whether any answers were removed.
func trimAnswers(resp *dns.Msg) (trimmed bool) {
	before := len(resp.Answer)

	// Keep a useful but bounded number of answers.
	if before > maxServiceResponses {
		resp.Answer = resp.Answer[:maxServiceResponses]
	}

	// Hard 512-byte UDP limit from the RFC.
	for len(resp.Answer) > 0 {
		if resp.Len() <= 512 {
			break
		}
		resp.Answer = resp.Answer[:len(resp.Answer)-1]
	}

	return before > len(resp.Answer)
}
Ejemplo n.º 13
0
// trimUDPAnswers keeps a UDP response within the limits of RFC 1035. The
// Answer section is first capped at the smaller of maxUDPAnswerLimit and
// config.UDPAnswerLimit, so config can only lower the compiled-in limit. Then
// answers are dropped from the end until the message is at most 512 bytes or
// no answers are left. It reports whether any answers were removed.
func trimUDPAnswers(config *DNSConfig, resp *dns.Msg) (trimmed bool) {
	original := len(resp.Answer)

	limit := lib.MinInt(maxUDPAnswerLimit, config.UDPAnswerLimit)
	if original > limit {
		resp.Answer = resp.Answer[:limit]
	}

	// Hard 512-byte UDP limit from the RFC: shed answers from the tail.
	for len(resp.Answer) > 0 {
		if resp.Len() <= 512 {
			break
		}
		resp.Answer = resp.Answer[:len(resp.Answer)-1]
	}

	return original > len(resp.Answer)
}
Ejemplo n.º 14
0
// truncateResp trims resp so that it fits in maxSize bytes. For UDP (isTCP
// false) the TC bit is set unconditionally. Answer RRs are removed from the
// end one at a time. For SRV queries, one Extra RR is also removed per answer;
// this presumably assumes the additional section holds one record per SRV
// target (NOTE(review): confirm that pairing). Trimming stops when the message
// fits or the Answer section is empty, so the result may still exceed maxSize.
func truncateResp(resp *dns.Msg, maxSize int, isTCP bool) {
	if !isTCP {
		resp.Truncated = true
	}

	srv := len(resp.Question) > 0 && resp.Question[0].Qtype == dns.TypeSRV
	// Stop once the Answer section is empty: slicing further would panic.
	for len(resp.Answer) > 0 && resp.Len() > maxSize {
		resp.Answer = resp.Answer[:len(resp.Answer)-1]

		if srv && len(resp.Extra) > 0 {
			resp.Extra = resp.Extra[:len(resp.Extra)-1]
		}
	}
}
Ejemplo n.º 15
0
// setReply records reply as the resolved answer for this entry.
//
// If both flags and the entry already carry CacheNoLocalReplies, nothing is
// changed and false is returned. Otherwise validUntil comes from ttl when
// ttl != nullTTL, else from the smallest TTL in reply.Answer; if neither is
// available a warning is logged and validUntil keeps its old value. The entry
// is then marked resolved with flags and put time now, and a copy of reply
// (when non-nil) is stored.
// NOTE(review): an empty Answer section leaves the minimum at math.MaxUint32
// seconds (~136 years) — confirm this is intended.
//
// It returns true if the entry's validUntil time changed.
func (e *cacheEntry) setReply(reply *dns.Msg, ttl int, flags uint8, now time.Time) bool {
	var prevValidUntil time.Time
	if e.Status == stResolved {
		if reply != nil {
			Log.Debugf("[cache msgid %d] replacing response in cache", reply.MsgHdr.Id)
		}
		prevValidUntil = e.validUntil
	}

	// Never overwrite an entry that already holds a noLocalReplies reply.
	if flags&CacheNoLocalReplies != 0 && e.Flags&CacheNoLocalReplies != 0 {
		return false
	}

	switch {
	case ttl != nullTTL:
		e.validUntil = now.Add(time.Duration(ttl) * time.Second)
	case reply != nil:
		// Use the lowest TTL among the answers.
		// TODO: improve the minTTL calculation (maybe we should skip some RRs)
		lowest := uint32(math.MaxUint32)
		for _, rr := range reply.Answer {
			if t := rr.Header().Ttl; t < lowest {
				lowest = t
			}
		}
		e.validUntil = now.Add(time.Duration(lowest) * time.Second)
	default:
		Log.Warningf("[cache] no valid TTL could be calculated")
	}

	e.Status, e.Flags, e.putTime = stResolved, flags, now

	if reply != nil {
		e.reply = *reply.Copy()
		e.ReplyLen = reply.Len()
	}

	return prevValidUntil != e.validUntil
}
Ejemplo n.º 16
0
// Scrub shrinks reply so that it fits the client's buffer size (s.Size()). A
// reply that already fits is returned untouched with ScrubIgnored. Otherwise
// the additional section is removed, SizeAndDo is applied to the reply again,
// and, if it still does not fit, the TC bit is set; both paths return
// ScrubDone. Note that the TC bit is set regardless of protocol, so even a TCP
// message will get the bit; the client should then retry with pigeons.
// TODO(referral).
func (s *State) Scrub(reply *dns.Msg) (*dns.Msg, Result) {
	limit := s.Size()
	if reply.Len() <= limit {
		return reply, ScrubIgnored
	}

	// TODO(miek): check for delegation
	// If not delegation, drop additional section.
	reply.Extra = nil
	s.SizeAndDo(reply)

	if reply.Len() > limit {
		// Still?!! does not fit.
		reply.Truncated = true
	}
	return reply, ScrubDone
}
Ejemplo n.º 17
0
// truncate removes answers until the given dns.Msg fits the permitted
// length of the given transmission channel and sets the TC bit.
// See https://tools.ietf.org/html/rfc1035#section-4.2.1
//
// The limit is dns.MaxMsgSize for TCP. For UDP it is the UDP size in the
// message's own OPT record if present, else dns.MinMsgSize. A message over the
// limit gets the TC bit and loses its whole Extra section (including any OPT
// record). If that is not enough, the longest prefix of the Answer section
// that fits is kept. The same message is returned.
func truncate(m *dns.Msg, udp bool) *dns.Msg {
	limit := dns.MinMsgSize
	if !udp {
		limit = dns.MaxMsgSize
	} else if opt := m.IsEdns0(); opt != nil {
		limit = int(opt.UDPSize())
	}

	m.Truncated = m.Len() > limit
	if !m.Truncated {
		return m
	}

	m.Extra = nil // Drop all extra records first
	if m.Len() <= limit {
		return m
	}

	// Binary search for the longest answer prefix that fits.
	// Invariant: answers[:lo] fits (or is empty) and answers[:hi] does not.
	answers := m.Answer
	lo, hi := 0, len(answers)
	for hi-lo > 1 {
		mid := lo + (hi-lo)/2
		m.Answer = answers[:mid]
		if m.Len() <= limit {
			lo = mid
		} else {
			hi = mid
		}
	}
	m.Answer = answers[:lo]
	return m
}
Ejemplo n.º 18
0
// sign signs a message m, it takes care of negative or nodata responses as
// well by synthesising NSEC3 records. It will also cache the signatures, using
// a hash of the signed data as a key.
// We also fake the origin TTL in the signature, because we don't want to
// throw away signatures when services decide to have longer TTL. So we just
// set the origTTL to 60.
// TODO(miek): revisit origTTL
//
// Existing RRSIG records and OPT pseudo-records are not signed. When bufsize
// lies in 512..4096 the TC bit is set if the signed message exceeds it.
// Finally an OPT record with the DO bit and a 4096-byte UDP size is appended.
func (s *server) sign(m *dns.Msg, bufsize uint16) {
	now := time.Now().UTC()
	incep := uint32(now.Add(-3 * time.Hour).Unix())     // 2+1 hours, be sure to catch daylight saving time and such
	expir := uint32(now.Add(7 * 24 * time.Hour).Unix()) // sign for a week

	// signSection returns section with an RRSIG appended for each RRset in it.
	// OPT is a per-message pseudo-record rather than data, so it is skipped
	// along with existing RRSIGs. RRsets are collected before anything is
	// appended, so new signatures are never signed themselves. Sets that fail
	// to sign are left unsigned.
	signSection := func(section []dns.RR) []dns.RR {
		for _, r := range rrSets(section) {
			if t := r[0].Header().Rrtype; t == dns.TypeRRSIG || t == dns.TypeOPT {
				continue
			}
			if sig, err := s.signSet(r, now, incep, expir); err == nil {
				section = append(section, sig)
			}
		}
		return section
	}
	m.Answer = signSection(m.Answer)
	m.Ns = signSection(m.Ns)
	m.Extra = signSection(m.Extra)

	// Both bounds must hold (`&&`): larger buffers, such as a 64K TCP size,
	// skip the TC check.
	if bufsize >= 512 && bufsize <= 4096 {
		m.Truncated = m.Len() > int(bufsize)
	}
	o := new(dns.OPT)
	o.Hdr.Name = "."
	o.Hdr.Rrtype = dns.TypeOPT
	o.SetDo()
	o.SetUDPSize(4096) // TODO(miek): echo client
	m.Extra = append(m.Extra, o)
}
Ejemplo n.º 19
0
// truncateResponse drops trailing answers from response until it fits in
// maxSize bytes, setting the TC bit if anything was removed. Responses with
// at most one answer, or a non-positive maxSize, are left alone.
func truncateResponse(response *dns.Msg, maxSize int) {
	n := len(response.Answer)
	if maxSize <= 0 || n <= 1 {
		return
	}

	// Keep the original slice header: response.Answer is reassigned while probing.
	all := response.Answer

	// firstTooBig is the smallest k for which the first k+1 answers exceed
	// maxSize, or n if the whole section fits.
	firstTooBig := sort.Search(n, func(k int) bool {
		response.Answer = all[:k+1]
		return response.Len() > maxSize
	})

	if firstTooBig == n {
		// Everything fits: restore the full answer section.
		response.Answer = all
		return
	}
	response.Answer = all[:firstTooBig]
	response.Truncated = true
}
Ejemplo n.º 20
0
// trimUDPResponse keeps a UDP response within the limits of RFC 1035. The
// Answer section is first capped at the smaller of maxUDPAnswerLimit and
// config.UDPAnswerLimit. Answers are then dropped from the end until the
// message, measured uncompressed, is at most 512 bytes. Whenever answers are
// removed and the response has extra records, syncExtra is called with an
// index built from the original Extra section, so extras are trimmed along
// with answers. resp.Compress is restored before returning. It reports
// whether any answers were removed.
func trimUDPResponse(config *DNSConfig, resp *dns.Msg) (trimmed bool) {
	before := len(resp.Answer)
	withExtra := len(resp.Extra) > 0

	// Only build the Extra index when there is something to keep in sync;
	// this avoids function calls and allocations in the common case.
	var index map[string]dns.RR
	if withExtra {
		index = make(map[string]dns.RR, len(resp.Extra))
		indexRRs(resp.Extra, index)
	}

	// Keep a useful but bounded number of answers.
	limit := lib.MinInt(maxUDPAnswerLimit, config.UDPAnswerLimit)
	if before > limit {
		resp.Answer = resp.Answer[:limit]
		if withExtra {
			syncExtra(index, resp)
		}
	}

	// Measure uncompressed while enforcing the 512-byte limit. This is more
	// conservative and keeps the response compliant even if a downstream
	// server uncompresses it.
	savedCompress := resp.Compress
	resp.Compress = false
	for len(resp.Answer) > 0 {
		if resp.Len() <= 512 {
			break
		}
		resp.Answer = resp.Answer[:len(resp.Answer)-1]
		if withExtra {
			syncExtra(index, resp)
		}
	}
	resp.Compress = savedCompress

	return before > len(resp.Answer)
}
Ejemplo n.º 21
0
// Fit will make m fit the size. If a message is larger than size then the
// entire additional section is dropped. If it still does not fit, answer RRs
// are dropped so that the longest prefix of the Answer section that fits is
// kept (possibly none); this happens for both transports. For udp
// (tcp == false) the TC bit is also set. The returned bool is true when
// answers had to be dropped.
func Fit(m *dns.Msg, size int, tcp bool) (*dns.Msg, bool) {
	if m.Len() > size {
		// Check for OPT Records at the end and keep those. TODO(miek)
		m.Extra = nil
	}
	if m.Len() <= size {
		return m, false
	}

	// With TCP setting TC does not mean anything.
	if !tcp {
		m.Truncated = true
		// fall through here, so we at least return a message that can
		// fit the udp buffer.
	}

	// Additional section is gone; binary search for the longest answer prefix
	// that fits. Invariant: answers[:lo] fits (or is empty), answers[:hi] does not.
	answers := m.Answer
	lo, hi := 0, len(answers)
	for hi-lo > 1 {
		mid := lo + (hi-lo)/2
		m.Answer = answers[:mid]
		if m.Len() <= size {
			lo = mid
		} else {
			hi = mid
		}
	}
	m.Answer = answers[:lo]
	return m, true
}
Ejemplo n.º 22
0
// Respond writes m as the reply to req on w, shrinking it first if needed.
//
// The size limit is dns.MaxMsgSize-1 over TCP and the client's EDNS0 UDP
// buffer size (at least 512) over UDP. An oversized reply first loses its
// Extra section. If it is still too large, UDP clients get a reply with no
// answers and the TC bit set, while TCP clients get a SERVFAIL. Write errors
// are logged, not returned.
func Respond(w dns.ResponseWriter, req *dns.Msg, m *dns.Msg) {
	// Figure out the max response size
	bufsize := uint16(512)
	tcp := isTcp(w)

	if o := req.IsEdns0(); o != nil {
		bufsize = o.UDPSize()
	}

	if tcp {
		bufsize = dns.MaxMsgSize - 1
	} else if bufsize < 512 {
		bufsize = 512
	}

	// Make sure the payload fits the buffer size. If the message is too large we strip the Extra section.
	// If it's still too large we return a truncated message for UDP queries and ServerFailure for TCP queries.
	if m.Len() > int(bufsize) {
		fqdn := ""
		if len(req.Question) > 0 {
			fqdn = dns.Fqdn(req.Question[0].Name)
		}
		// Only the additional section is dropped; the authority section is kept.
		log.WithFields(log.Fields{"fqdn": fqdn}).Debug("Response too big, dropping Extra")
		m.Extra = nil
		if m.Len() > int(bufsize) {
			if tcp {
				log.WithFields(log.Fields{"fqdn": fqdn}).Debug("Response still too big, return ServerFailure")
				m = new(dns.Msg)
				m.SetRcode(req, dns.RcodeServerFailure)
			} else {
				log.WithFields(log.Fields{"fqdn": fqdn}).Debug("Response still too big, return truncated message")
				m.Answer = nil
				m.Truncated = true
			}
		}
	}

	err := w.WriteMsg(m)
	if err != nil {
		log.Warn("Failed to return reply: ", err, m.Len())
	}
}
Ejemplo n.º 23
0
Archivo: dns.go Proyecto: n054/weave
// responseTooBig reports whether all of the following hold: a positive
// h.maxResponseSize is configured, response has more than one answer, and the
// encoded response is longer than h.getMaxResponseSize(req).
func (h *handler) responseTooBig(req, response *dns.Msg) bool {
	if h.maxResponseSize <= 0 || len(response.Answer) <= 1 {
		return false
	}
	return response.Len() > h.getMaxResponseSize(req)
}
Ejemplo n.º 24
0
// TestRecursiveCompress checks that a recursive answer which only fits in
// maxSize bytes thanks to name compression is relayed by weavedns without
// being trimmed. The reply received by the client is still larger than
// maxSize when measured uncompressed.
func TestRecursiveCompress(t *testing.T) {
	const (
		hostname = "foo.example."
		maxSize  = 512
	)

	// Construct a response that is >512 when uncompressed, <512 when compressed
	response := dns.Msg{}
	response.Authoritative = true
	response.Answer = []dns.RR{}
	header := dns.RR_Header{
		Name:   hostname,
		Rrtype: dns.TypeA,
		Class:  dns.ClassINET,
		Ttl:    10,
	}
	for response.Len() <= maxSize {
		ip := address.Address(rand.Uint32()).IP4()
		response.Answer = append(response.Answer, &dns.A{Hdr: header, A: ip})
	}
	response.Compress = true
	require.True(t, response.Len() <= maxSize)

	// A dns server that returns the above response. The handler runs on a
	// server goroutine, so it must not call require (FailNow is only valid on
	// the test goroutine). It reports the queried name instead; the send is
	// non-blocking so a retried query cannot wedge the handler.
	gotRequest := make(chan string, 1)
	handleRecursive := func(w dns.ResponseWriter, req *dns.Msg) {
		select {
		case gotRequest <- req.Question[0].Name:
		default:
		}
		response.SetReply(req)
		if err := w.WriteMsg(&response); err != nil {
			t.Errorf("upstream server failed to write response: %v", err)
		}
	}
	mux := dns.NewServeMux()
	mux.HandleFunc(topDomain, handleRecursive)
	udpListener, err := net.ListenPacket("udp", "0.0.0.0:0")
	require.Nil(t, err)
	udpServer := &dns.Server{PacketConn: udpListener, Handler: mux}
	udpServerPort := udpListener.LocalAddr().(*net.UDPAddr).Port
	go udpServer.ActivateAndServe()
	defer udpServer.Shutdown()

	// The weavedns server, pointed at the above server
	dnsserver, _, udpPort, _ := startServer(t, &dns.ClientConfig{
		Servers:  []string{"127.0.0.1"},
		Port:     strconv.Itoa(udpServerPort),
		Ndots:    1,
		Timeout:  5,
		Attempts: 2,
	})
	defer dnsserver.Stop()

	// Now do lookup, check it's what we expected.
	// NB this doesn't really test golang's resolver behaves correctly, as I can't see
	// a way to point golangs resolver at a specific hosts.
	req := new(dns.Msg)
	req.Id = dns.Id()
	req.RecursionDesired = true
	req.Question = make([]dns.Question, 1)
	req.Question[0] = dns.Question{
		Name:   hostname,
		Qtype:  dns.TypeA,
		Qclass: dns.ClassINET,
	}
	c := new(dns.Client)
	res, _, err := c.Exchange(req, fmt.Sprintf("127.0.0.1:%d", udpPort))
	require.Nil(t, err)
	// The upstream handler sends before replying, so the name is available by now.
	select {
	case name := <-gotRequest:
		require.Equal(t, hostname, name)
	default:
		t.Fatal("upstream server did not receive the recursive query")
	}
	require.True(t, res.Len() > maxSize)
}
Ejemplo n.º 25
0
// ServeDNS is the handler for DNS requests, responsible for parsing DNS request, possibly forwarding
// it to a real dns server and returning a response.
//
// Outline:
//   - queries without a question get FORMERR and ANY queries are refused;
//   - fresh cache hits are served directly;
//   - stub zones, PTR queries under in-addr.arpa./ip6.arpa. and names outside
//     s.config.Domain are delegated to the stub, reverse and forward handlers;
//   - everything else is answered locally. A deferred function dedups, caches,
//     optionally signs, size-limits and writes the reply.
func (s *server) ServeDNS(w dns.ResponseWriter, req *dns.Msg) {
	m := new(dns.Msg)
	m.SetReply(req)
	m.Authoritative = true
	m.RecursionAvailable = true
	m.Compress = true
	bufsize := uint16(512)
	dnssec := false
	tcp := false
	start := time.Now()

	// Every path below reads req.Question[0]; reject question-less queries
	// instead of panicking on them.
	if len(req.Question) == 0 {
		m.SetRcode(req, dns.RcodeFormatError)
		// if write fails don't care
		w.WriteMsg(m)
		return
	}

	if req.Question[0].Qtype == dns.TypeANY {
		m.Authoritative = false
		m.Rcode = dns.RcodeRefused
		m.RecursionAvailable = false
		m.RecursionDesired = false
		m.Compress = false
		// if write fails don't care
		w.WriteMsg(m)

		promErrorCount.WithLabelValues("refused").Inc()
		return
	}

	if o := req.IsEdns0(); o != nil {
		bufsize = o.UDPSize()
		dnssec = o.Do()
	}
	if bufsize < 512 {
		bufsize = 512
	}
	// with TCP we can send 64K
	if tcp = isTCP(w); tcp {
		bufsize = dns.MaxMsgSize - 1
		promRequestCount.WithLabelValues("tcp").Inc()
	} else {
		promRequestCount.WithLabelValues("udp").Inc()
	}

	StatsRequestCount.Inc(1)

	if dnssec {
		StatsDnssecOkCount.Inc(1)
		promDnssecOkCount.Inc()
	}

	defer func() {
		promCacheSize.WithLabelValues("response").Set(float64(s.rcache.Size()))
	}()

	// Check cache first.
	key := cache.QuestionKey(req.Question[0], dnssec)
	m1, exp, hit := s.rcache.Search(key)
	if hit {
		// Cache hit! \o/
		if time.Since(exp) < 0 {
			m1.Id = m.Id
			m1.Compress = true
			m1.Truncated = false

			if dnssec {
				// The key for DNS/DNSSEC in cache is different, no
				// need to do Denial/Sign here.
				//if s.config.PubKey != nil {
				//s.Denial(m1) // not needed for cache hits
				//s.Sign(m1, bufsize)
				//}
			}
			if m1.Len() > int(bufsize) && !tcp {
				promErrorCount.WithLabelValues("truncated").Inc()
				m1.Truncated = true
			}
			// Still round-robin even with hits from the cache.
			// Only shuffle A and AAAA records with each other.
			if req.Question[0].Qtype == dns.TypeA || req.Question[0].Qtype == dns.TypeAAAA {
				s.RoundRobin(m1.Answer)
			}

			if err := w.WriteMsg(m1); err != nil {
				log.Printf("skydns: failure to return reply %q", err)
			}
			metricSizeAndDuration(m1, start, tcp)
			return
		}
		// Expired! /o\
		s.rcache.Remove(key)
	}

	q := req.Question[0]
	name := strings.ToLower(q.Name)

	if s.config.Verbose {
		log.Printf("skydns: received DNS Request for %q from %q with type %d", q.Name, w.RemoteAddr(), q.Qtype)
	}

	for zone, ns := range *s.config.stub {
		if strings.HasSuffix(name, zone) {
			resp := s.ServeDNSStubForward(w, req, ns)
			metricSizeAndDuration(resp, start, tcp)
			return
		}
	}

	// If the qname is local.dns.skydns.local. and s.config.Local != "", substitute that name.
	if s.config.Local != "" && name == s.config.localDomain {
		name = s.config.Local
	}

	// Only PTR queries are reverse lookups. The parentheses matter: && binds
	// tighter than ||, so without them any query type under ip6.arpa. would
	// be routed to the reverse handler.
	if q.Qtype == dns.TypePTR && (strings.HasSuffix(name, ".in-addr.arpa.") || strings.HasSuffix(name, ".ip6.arpa.")) {
		resp := s.ServeDNSReverse(w, req)
		metricSizeAndDuration(resp, start, tcp)
		return
	}

	if q.Qclass != dns.ClassCHAOS && !strings.HasSuffix(name, s.config.Domain) {
		if s.config.Verbose {
			log.Printf("skydns: %q is not sub of %q, forwarding...", name, s.config.Domain)
		}

		resp := s.ServeDNSForward(w, req)
		metricSizeAndDuration(resp, start, tcp)
		return
	}

	promCacheMiss.WithLabelValues("response").Inc()

	defer func() {
		if m.Rcode == dns.RcodeServerFailure {
			if err := w.WriteMsg(m); err != nil {
				log.Printf("skydns: failure to return reply %q", err)
			}
			return
		}
		// Set TTL to the minimum of the RRset and dedup the message, i.e.
		// remove identical RRs.
		m = s.dedup(m)

		minttl := s.config.Ttl
		if len(m.Answer) > 1 {
			for _, r := range m.Answer {
				if r.Header().Ttl < minttl {
					minttl = r.Header().Ttl
				}
			}
			for _, r := range m.Answer {
				r.Header().Ttl = minttl
			}
		}

		if !m.Truncated {
			s.rcache.InsertMessage(cache.QuestionKey(req.Question[0], dnssec), m)
		}

		if dnssec {
			if s.config.PubKey != nil {
				m.AuthenticatedData = true
				s.Denial(m)
				s.Sign(m, bufsize)
			}
		}

		if m.Len() > dns.MaxMsgSize {
			log.Printf("skydns: overflowing maximum message size: %d, dropping additional section", m.Len())
			m.Extra = nil // Drop entire additional section to see if this helps.

			if m.Len() > dns.MaxMsgSize {
				// *Still* too large.
				log.Printf("skydns: still overflowing maximum message size: %d", m.Len())
				promErrorCount.WithLabelValues("overflow").Inc()
				m1 := new(dns.Msg) // Use smaller msg to signal failure.
				// Build the failure from the original request, like the other
				// error paths in this handler.
				m1.SetRcode(req, dns.RcodeServerFailure)
				if err := w.WriteMsg(m1); err != nil {
					log.Printf("skydns: failure to return reply %q", err)
				}
				metricSizeAndDuration(m1, start, tcp)
				return
			}
		}

		if m.Len() > int(bufsize) && !tcp {
			m.Extra = nil // As above, drop entire additional section.
			if m.Len() > int(bufsize) {
				promErrorCount.WithLabelValues("truncated").Inc()
				m.Truncated = true
			}
		}

		if err := w.WriteMsg(m); err != nil {
			log.Printf("skydns: failure to return reply %q %d", err, m.Len())
		}
		metricSizeAndDuration(m, start, tcp)
	}()

	if name == s.config.Domain {
		if q.Qtype == dns.TypeSOA {
			m.Answer = []dns.RR{s.NewSOA()}
			return
		}
		if q.Qtype == dns.TypeDNSKEY {
			if s.config.PubKey != nil {
				m.Answer = []dns.RR{s.config.PubKey}
				return
			}
		}
	}
	if q.Qclass == dns.ClassCHAOS {
		if q.Qtype == dns.TypeTXT {
			switch name {
			case "authors.bind.":
				fallthrough
			case s.config.Domain:
				hdr := dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassCHAOS, Ttl: 0}
				authors := []string{"Erik St. Martin", "Brian Ketelsen", "Miek Gieben", "Michael Crosby"}
				for _, a := range authors {
					m.Answer = append(m.Answer, &dns.TXT{Hdr: hdr, Txt: []string{a}})
				}
				for j := 0; j < len(authors)*(int(dns.Id())%4+1); j++ {
					q := int(dns.Id()) % len(authors)
					p := int(dns.Id()) % len(authors)
					if q == p {
						p = (p + 1) % len(authors)
					}
					m.Answer[q], m.Answer[p] = m.Answer[p], m.Answer[q]
				}
				return
			case "version.bind.":
				fallthrough
			case "version.server.":
				hdr := dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassCHAOS, Ttl: 0}
				m.Answer = []dns.RR{&dns.TXT{Hdr: hdr, Txt: []string{Version}}}
				return
			case "hostname.bind.":
				fallthrough
			case "id.server.":
				// TODO(miek): machine name to return
				hdr := dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassCHAOS, Ttl: 0}
				m.Answer = []dns.RR{&dns.TXT{Hdr: hdr, Txt: []string{"localhost"}}}
				return
			}
		}
		// still here, fail
		m.SetReply(req)
		m.SetRcode(req, dns.RcodeServerFailure)
		return
	}

	switch q.Qtype {
	case dns.TypeNS:
		if name != s.config.Domain {
			log.Printf("skydns: %q unmatch default domain", name)
			break
		}
		// Lookup s.config.DnsDomain
		records, extra, err := s.NSRecords(q, s.config.dnsDomain)
		if err != nil {
			if e, ok := err.(*etcd.EtcdError); ok {
				if e.ErrorCode == 100 {
					s.NameError(m, req)
					return
				}
			}
		}
		m.Answer = append(m.Answer, records...)
		m.Extra = append(m.Extra, extra...)
	case dns.TypeA, dns.TypeAAAA:
		records, err := s.AddressRecords(q, name, nil, bufsize, dnssec, false)
		if err != nil {
			if e, ok := err.(*etcd.EtcdError); ok {
				if e.ErrorCode == 100 {
					s.NameError(m, req)
					return
				}
			}
		}
		m.Answer = append(m.Answer, records...)
	case dns.TypeTXT:
		records, err := s.TXTRecords(q, name)
		if err != nil {
			if e, ok := err.(*etcd.EtcdError); ok {
				if e.ErrorCode == 100 {
					s.NameError(m, req)
					return
				}
			}
		}
		m.Answer = append(m.Answer, records...)
	case dns.TypeCNAME:
		records, err := s.CNAMERecords(q, name)
		if err != nil {
			if e, ok := err.(*etcd.EtcdError); ok {
				if e.ErrorCode == 100 {
					s.NameError(m, req)
					return
				}
			}
		}
		m.Answer = append(m.Answer, records...)
	case dns.TypeMX:
		records, extra, err := s.MXRecords(q, name, bufsize, dnssec)
		if err != nil {
			if e, ok := err.(*etcd.EtcdError); ok {
				if e.ErrorCode == 100 {
					s.NameError(m, req)
					return
				}
			}
		}
		m.Answer = append(m.Answer, records...)
		m.Extra = append(m.Extra, extra...)
	default:
		fallthrough // also catch other types, so that they return NODATA
	case dns.TypeSRV:
		records, extra, err := s.SRVRecords(q, name, bufsize, dnssec)
		if err != nil {
			if e, ok := err.(*etcd.EtcdError); ok {
				if e.ErrorCode == 100 {
					s.NameError(m, req)
					return
				}
			}
			if q.Qtype == dns.TypeSRV { // Otherwise NODATA
				s.ServerFailure(m, req)
				return
			}
		}
		// if we are here again, check the types, because an answer may only
		// be given for SRV. All other types should return NODATA, the
		// NXDOMAIN part is handled in the above code. TODO(miek): yes this
		// can be done in a more elegant manor.
		if q.Qtype == dns.TypeSRV {
			m.Answer = append(m.Answer, records...)
			m.Extra = append(m.Extra, extra...)
		}
	}

	if len(m.Answer) == 0 { // NODATA response
		StatsNoDataCount.Inc(1)
		m.Ns = []dns.RR{s.NewSOA()}
		m.Ns[0].Header().Ttl = s.config.MinTtl
	}
}
Ejemplo n.º 26
0
// truncate sets the TC bit on m when udp is true and m is longer than
// dns.MinMsgSize (the 512-byte UDP limit of RFC 1035); in every other case
// the bit is cleared. No records are removed. m itself is returned.
// See https://tools.ietf.org/html/rfc1035#section-4.2.1
func truncate(m *dns.Msg, udp bool) *dns.Msg {
	tooLong := false
	if udp {
		tooLong = m.Len() > dns.MinMsgSize
	}
	m.Truncated = tooLong
	return m
}
Ejemplo n.º 27
0
// ServeDNS is the handler for DNS requests, responsible for parsing DNS request, possibly forwarding
// it to a real dns server and returning a response.
//
// Flow:
//   - a request without a question gets FORMERR; ANY queries are refused.
//   - a non-expired entry in s.rcache (keyed on the question and the DO bit)
//     is written back directly, with A/AAAA answers shuffled when
//     s.config.RoundRobin is set.
//   - PTR queries under in-addr.arpa. or ip6.arpa. go to ServeDNSReverse;
//     non-CHAOS names outside s.config.Domain go to ServeDNSForward.
//   - everything else is answered here. A deferred function then writes the
//     reply: SERVFAIL replies are written as-is; otherwise answer TTLs are
//     clamped to their minimum, the reply is cached, optionally DNSSEC
//     signed, and marked truncated when it exceeds bufsize over UDP.
func (s *server) ServeDNS(w dns.ResponseWriter, req *dns.Msg) {
	// Every path below indexes req.Question[0]; a question-less request
	// would panic, so answer it with FORMERR instead.
	if len(req.Question) == 0 {
		fe := new(dns.Msg)
		fe.SetRcodeFormatError(req)
		// if write fails don't care
		w.WriteMsg(fe)
		return
	}

	m := new(dns.Msg)
	m.SetReply(req)
	m.Authoritative = true
	m.RecursionAvailable = true
	m.Compress = true
	bufsize := uint16(512)
	dnssec := false
	tcp := false

	if req.Question[0].Qtype == dns.TypeANY {
		m.Authoritative = false
		m.Rcode = dns.RcodeRefused
		m.RecursionAvailable = false
		m.RecursionDesired = false
		m.Compress = false
		// if write fails don't care
		w.WriteMsg(m)
		return
	}

	if o := req.IsEdns0(); o != nil {
		bufsize = o.UDPSize()
		dnssec = o.Do()
	}
	if bufsize < 512 {
		bufsize = 512
	}
	// with TCP we can send 64K
	if _, ok := w.RemoteAddr().(*net.TCPAddr); ok {
		bufsize = dns.MaxMsgSize - 1
		tcp = true
	}
	// Check cache first.
	key := cache.QuestionKey(req.Question[0], dnssec)
	m1, exp, hit := s.rcache.Search(key)
	if hit {
		// Cache hit! \o/
		if time.Since(exp) < 0 {
			m1.Id = m.Id
			if dnssec {
				StatsDnssecOkCount.Inc(1)
				// The key for DNS/DNSSEC in cache is different, no
				// need to do Denial/Sign here.
			}
			if m1.Len() > int(bufsize) && !tcp {
				m1.Truncated = true
			}
			// Still round-robin even with hits from the cache.
			if s.config.RoundRobin && (req.Question[0].Qtype == dns.TypeA ||
				req.Question[0].Qtype == dns.TypeAAAA) {
				switch l := len(m1.Answer); l {
				case 2:
					if dns.Id()%2 == 0 {
						m1.Answer[0], m1.Answer[1] = m1.Answer[1], m1.Answer[0]
					}
				default:
					// Do a minimum of l swap, maximum of 4l swaps
					for j := 0; j < l*(int(dns.Id())%4+1); j++ {
						q := int(dns.Id()) % l
						p := int(dns.Id()) % l
						if q == p {
							p = (p + 1) % l
						}
						m1.Answer[q], m1.Answer[p] = m1.Answer[p], m1.Answer[q]
					}
				}
			}
			if err := w.WriteMsg(m1); err != nil {
				s.config.log.Errorf("failure to return reply %q", err)
			}
			return
		}
		// Expired! /o\
		s.rcache.Remove(key)
	}

	q := req.Question[0]
	name := strings.ToLower(q.Name)
	StatsRequestCount.Inc(1)
	if verbose {
		s.config.log.Infof("received DNS Request for %q from %q with type %d", q.Name, w.RemoteAddr(), q.Qtype)
	}
	// If the qname is local.dns.skydns.local. and s.config.Local != "", substitute that name.
	if s.config.Local != "" && name == s.config.localDomain {
		name = s.config.Local
	}

	// Only PTR queries are reverse lookups. The suffix test must be grouped:
	// without parentheses `&&` binds tighter than `||`, which sent every
	// *.ip6.arpa. query, whatever its type, to the reverse handler.
	reverseZone := strings.HasSuffix(name, ".in-addr.arpa.") || strings.HasSuffix(name, ".ip6.arpa.")
	if q.Qtype == dns.TypePTR && reverseZone {
		s.ServeDNSReverse(w, req)
		return
	}

	if q.Qclass != dns.ClassCHAOS && !strings.HasSuffix(name, s.config.Domain) {
		s.ServeDNSForward(w, req)
		return
	}

	// From here on every return path sends m through this deferred writer.
	defer func() {
		if m.Rcode == dns.RcodeServerFailure {
			if err := w.WriteMsg(m); err != nil {
				s.config.log.Errorf("failure to return reply %q", err)
			}
			return
		}
		// Set TTL to the minimum of the RRset.
		minttl := s.config.Ttl
		if len(m.Answer) > 1 {
			for _, r := range m.Answer {
				if r.Header().Ttl < minttl {
					minttl = r.Header().Ttl
				}
			}
			for _, r := range m.Answer {
				r.Header().Ttl = minttl
			}
		}

		// Cached before signing: the cache key already distinguishes DNSSEC
		// from plain queries.
		s.rcache.InsertMessage(cache.QuestionKey(req.Question[0], dnssec), m)

		if dnssec {
			StatsDnssecOkCount.Inc(1)
			if s.config.PubKey != nil {
				m.AuthenticatedData = true
				s.Denial(m)
				s.Sign(m, bufsize)
			}
		}
		if m.Len() > int(bufsize) && !tcp {
			// TODO(miek): this is a little brain dead, better is to not add
			// RRs in the message in the first place.
			m.Truncated = true
		}
		if err := w.WriteMsg(m); err != nil {
			s.config.log.Errorf("failure to return reply %q", err)
		}
	}()

	if name == s.config.Domain {
		if q.Qtype == dns.TypeSOA {
			m.Answer = []dns.RR{s.NewSOA()}
			return
		}
		if q.Qtype == dns.TypeDNSKEY {
			if s.config.PubKey != nil {
				m.Answer = []dns.RR{s.config.PubKey}
				return
			}
		}
	}
	if q.Qclass == dns.ClassCHAOS {
		if q.Qtype == dns.TypeTXT {
			switch name {
			case "authors.bind.", s.config.Domain:
				hdr := dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassCHAOS, Ttl: 0}
				authors := []string{"Erik St. Martin", "Brian Ketelsen", "Miek Gieben", "Michael Crosby"}
				for _, a := range authors {
					m.Answer = append(m.Answer, &dns.TXT{Hdr: hdr, Txt: []string{a}})
				}
				for j := 0; j < len(authors)*(int(dns.Id())%4+1); j++ {
					q := int(dns.Id()) % len(authors)
					p := int(dns.Id()) % len(authors)
					if q == p {
						p = (p + 1) % len(authors)
					}
					m.Answer[q], m.Answer[p] = m.Answer[p], m.Answer[q]
				}
				return
			case "version.bind.", "version.server.":
				hdr := dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassCHAOS, Ttl: 0}
				m.Answer = []dns.RR{&dns.TXT{Hdr: hdr, Txt: []string{Version}}}
				return
			case "hostname.bind.", "id.server.":
				// TODO(miek): machine name to return
				hdr := dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassCHAOS, Ttl: 0}
				m.Answer = []dns.RR{&dns.TXT{Hdr: hdr, Txt: []string{"localhost"}}}
				return
			}
		}
		// still here, fail (SetRcode also re-applies SetReply)
		m.SetRcode(req, dns.RcodeServerFailure)
		return
	}

	switch q.Qtype {
	case dns.TypeNS:
		if name != s.config.Domain {
			break
		}
		// Lookup s.config.DnsDomain
		records, extra, err := s.NSRecords(q, s.config.dnsDomain)
		if err != nil {
			if e, ok := err.(*etcd.EtcdError); ok {
				if e.ErrorCode == 100 {
					s.NameError(m, req)
					return
				}
			}
		}
		m.Answer = append(m.Answer, records...)
		m.Extra = append(m.Extra, extra...)
	case dns.TypeA, dns.TypeAAAA:
		records, err := s.AddressRecords(q, name, nil)
		if err != nil {
			if e, ok := err.(*etcd.EtcdError); ok {
				if e.ErrorCode == 100 {
					s.NameError(m, req)
					return
				}
			}
			if err.Error() == "incomplete CNAME chain" {
				// We can not complete the CNAME internally, *iff* there is a
				// external name in the set, take it, and try to resolve it externally.
				if len(records) == 0 {
					s.NameError(m, req)
					return
				}
				target := ""
				for _, r := range records {
					if v, ok := r.(*dns.CNAME); ok {
						if !dns.IsSubDomain(s.config.Domain, v.Target) {
							target = v.Target
							break
						}
					}
				}
				if target == "" {
					s.config.log.Warningf("incomplete CNAME chain for %s", name)
					s.NoDataError(m, req)
					return
				}
				ext, lookupErr := s.Lookup(target, req.Question[0].Qtype, bufsize, dnssec)
				if lookupErr != nil {
					// Log the external lookup failure, not the earlier
					// "incomplete CNAME chain" error.
					s.config.log.Errorf("%q", lookupErr)
					s.NoDataError(m, req)
					return
				}
				records = append(records, ext.Answer...)
			}
		}
		m.Answer = append(m.Answer, records...)
	case dns.TypeCNAME:
		records, err := s.CNAMERecords(q, name)
		if err != nil {
			if e, ok := err.(*etcd.EtcdError); ok {
				if e.ErrorCode == 100 {
					s.NameError(m, req)
					return
				}
			}
		}
		m.Answer = append(m.Answer, records...)
	default:
		fallthrough // also catch other types, so that they return NODATA
	case dns.TypeSRV, dns.TypeANY: // ANY was already refused above
		records, extra, err := s.SRVRecords(q, name, bufsize, dnssec)
		if err != nil {
			if e, ok := err.(*etcd.EtcdError); ok {
				if e.ErrorCode == 100 {
					s.NameError(m, req)
					return
				}
			}
		}
		// if we are here again, check the types, because an answer may only
		// be given for SRV or ANY. All other types should return NODATA, the
		// NXDOMAIN part is handled in the above code. TODO(miek): yes this
		// can be done in a more elegant manner.
		if q.Qtype == dns.TypeSRV || q.Qtype == dns.TypeANY {
			m.Answer = append(m.Answer, records...)
			m.Extra = append(m.Extra, extra...)
		}
	}

	if len(m.Answer) == 0 { // NODATA response
		StatsNoDataCount.Inc(1)
		m.Ns = []dns.RR{s.NewSOA()}
		m.Ns[0].Header().Ttl = s.config.MinTtl
	}
}
Ejemplo n.º 28
0
// sign appends an RRSIG to m for every RRset in its Answer and Ns sections.
// RRsets that are themselves RRSIGs are skipped, and RRsets whose signing
// fails are left unsigned. Afterwards an OPT RR with the DO bit and a
// 4096-byte UDP size is appended to the additional section.
//
// Signatures are cached under cache.key(rrset) and reused while
// ValidityPeriod(now-24h) holds; concurrent signing of the same RRset is
// collapsed with inflight.Do. New signatures have an inception two hours in
// the past and expire after a week. Signatures over NSEC3 RRsets get
// s.config.MinTtl as both their TTL and original TTL.
//
// NOTE(review): an earlier version of this comment said the original TTL is
// faked to 60 so that longer service TTLs don't invalidate cached
// signatures. If so, that happens in NewRRSIG, which is not visible here.
//
// When bufsize is within 512..4096 (presumably a UDP buffer size) the TC bit
// is updated to reflect whether the signed message still fits.
func (s *server) sign(m *dns.Msg, bufsize uint16) {
	now := time.Now().UTC()
	incep := uint32(now.Add(-2 * time.Hour).Unix())     // 2 hours, be sure to catch daylight saving time and such
	expir := uint32(now.Add(7 * 24 * time.Hour).Unix()) // sign for a week

	// signSection returns rrs with an RRSIG appended for each RRset in it.
	// The RRsets are computed once up front, so appended signatures are not
	// themselves revisited.
	signSection := func(rrs []dns.RR) []dns.RR {
		for _, set := range rrSets(rrs) {
			if set[0].Header().Rrtype == dns.TypeRRSIG {
				continue
			}
			key := cache.key(set)
			if cached := cache.search(key); cached != nil {
				if cached.ValidityPeriod(now.Add(-24 * time.Hour)) {
					// Hand out a copy, as the miss path does, so the
					// cached signature is never shared with a reply.
					rrs = append(rrs, dns.Copy(cached))
					continue
				}
				cache.remove(key)
			}
			sig, err, shared := inflight.Do(key, func() (*dns.RRSIG, error) {
				sig1 := s.NewRRSIG(incep, expir)
				if set[0].Header().Rrtype == dns.TypeNSEC3 {
					sig1.OrigTtl = s.config.MinTtl
					sig1.Header().Ttl = s.config.MinTtl
				}
				e := sig1.Sign(s.config.PrivKey, set)
				if e != nil {
					log.Printf("failed to sign: %s\n", e.Error())
				}
				return sig1, e
			})
			if err != nil {
				continue
			}
			if !shared {
				// is it possible to miss this, due the the c.dups > 0 in Do()? TODO(miek)
				cache.insert(key, sig)
			}
			rrs = append(rrs, dns.Copy(sig))
		}
		return rrs
	}

	m.Answer = signSection(m.Answer)
	m.Ns = signSection(m.Ns)

	// TODO(miek): Forget the additional section for now
	// The range check previously used `||`, which is true for every uint16.
	if bufsize >= 512 && bufsize <= 4096 {
		m.Truncated = m.Len() > int(bufsize)
	}
	o := new(dns.OPT)
	o.Hdr.Name = "."
	o.Hdr.Rrtype = dns.TypeOPT
	o.SetDo()
	o.SetUDPSize(4096)
	m.Extra = append(m.Extra, o)
}
Ejemplo n.º 29
0
// ServeDNS answers a DNS query received by the embedded resolver.
//
// A, AAAA, PTR and SRV questions are first offered to the local handlers
// (handleIPQuery, handlePTRQuery, handleSRVQuery). A handler error is logged
// and no reply is sent. If no local response was produced:
//   - when the backend does not proxy DNS, SERVFAIL is returned;
//   - single-label A/AAAA names get a locally built reply (createRespMsg)
//     when the backend reports ndots as explicitly set, so the client
//     retries with its search domain attached;
//   - otherwise the query is forwarded to each configured external server
//     in turn until one of them replies.
//
// A local response larger than the transport allows is shrunk with
// truncateResp before being written.
func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
	var (
		extConn net.Conn
		resp    *dns.Msg
		err     error
	)

	if query == nil || len(query.Question) == 0 {
		return
	}
	name := query.Question[0].Name

	switch query.Question[0].Qtype {
	case dns.TypeA:
		resp, err = r.handleIPQuery(name, query, types.IPv4)
	case dns.TypeAAAA:
		resp, err = r.handleIPQuery(name, query, types.IPv6)
	case dns.TypePTR:
		resp, err = r.handlePTRQuery(name, query)
	case dns.TypeSRV:
		resp, err = r.handleSRVQuery(name, query)
	}

	if err != nil {
		logrus.Error(err)
		return
	}

	if resp == nil {
		// If the backend doesn't support proxying dns request
		// fail the response
		if !r.proxyDNS {
			resp = new(dns.Msg)
			resp.SetRcode(query, dns.RcodeServerFailure)
			w.WriteMsg(resp)
			return
		}

		// If the user sets ndots > 0 explicitly and the query is
		// in the root domain don't forward it out. We will return
		// failure and let the client retry with the search domain
		// attached
		switch query.Question[0].Qtype {
		case dns.TypeA, dns.TypeAAAA:
			if r.backend.NdotsSet() && !strings.Contains(strings.TrimSuffix(name, "."), ".") {
				resp = createRespMsg(query)
			}
		}
	}

	proto := w.LocalAddr().Network()
	maxSize := 0
	if proto == "tcp" {
		maxSize = dns.MaxMsgSize - 1
	} else if proto == "udp" {
		optRR := query.IsEdns0()
		if optRR != nil {
			maxSize = int(optRR.UDPSize())
		}
		if maxSize < defaultRespSize {
			maxSize = defaultRespSize
		}
	}

	if resp != nil {
		if resp.Len() > maxSize {
			truncateResp(resp, maxSize, proto == "tcp")
		}
	} else {
		for i := 0; i < maxExtDNS; i++ {
			extDNS := &r.extDNSList[i]
			if extDNS.ipStr == "" {
				break
			}
			extConnect := func() {
				addr := fmt.Sprintf("%s:%d", extDNS.ipStr, 53)
				extConn, err = net.DialTimeout(proto, addr, extIOTimeout)
			}

			if extDNS.hostLoopback {
				extConnect()
			} else {
				execErr := r.backend.ExecFunc(extConnect)
				if execErr != nil {
					logrus.Warn(execErr)
					continue
				}
			}
			if err != nil {
				logrus.Warnf("Connect failed: %s", err)
				continue
			}
			logrus.Debugf("Query %s[%d] from %s, forwarding to %s:%s", name, query.Question[0].Qtype,
				extConn.LocalAddr().String(), proto, extDNS.ipStr)

			// Timeout has to be set for every IO operation.
			extConn.SetDeadline(time.Now().Add(extIOTimeout))
			co := &dns.Conn{
				Conn:    extConn,
				UDPSize: uint16(maxSize),
			}
			// The connection is closed explicitly on every path below: a
			// defer here would keep each attempted connection open until
			// ServeDNS returns.

			// limits the number of outstanding concurrent queries.
			if !r.forwardQueryStart() {
				old := r.tStamp
				r.tStamp = time.Now()
				if r.tStamp.Sub(old) > logInterval {
					logrus.Errorf("More than %v concurrent queries from %s", maxConcurrent, extConn.LocalAddr().String())
				}
				co.Close()
				continue
			}

			err = co.WriteMsg(query)
			if err != nil {
				r.forwardQueryEnd()
				co.Close()
				logrus.Debugf("Send to DNS server failed, %s", err)
				continue
			}

			resp, err = co.ReadMsg()
			r.forwardQueryEnd()
			co.Close()
			// Truncated DNS replies should be sent to the client so that the
			// client can retry over TCP
			if err != nil && err != dns.ErrTruncated {
				logrus.Debugf("Read from DNS server failed, %s", err)
				continue
			}
			if resp != nil {
				for _, rr := range resp.Answer {
					h := rr.Header()
					switch h.Rrtype {
					case dns.TypeA:
						ip := rr.(*dns.A).A
						r.backend.HandleQueryResp(h.Name, ip)
					case dns.TypeAAAA:
						ip := rr.(*dns.AAAA).AAAA
						r.backend.HandleQueryResp(h.Name, ip)
					}
				}
				// Only touch resp once it is known to be non-nil.
				resp.Compress = true
			}
			break
		}
		if resp == nil {
			return
		}
	}

	if err = w.WriteMsg(resp); err != nil {
		logrus.Errorf("error writing resolver resp, %s", err)
	}
}
Ejemplo n.º 30
0
Archivo: q.go Proyecto: raybejjani/dns
// main is a small dig-like query tool:
//
//	q [options] [@server] [qtype] [qclass] [name ...]
//
// Each name is queried in turn against the given server. The default server
// is the first nameserver in /etc/resolv.conf. The reply is printed with
// its round-trip time. AXFR and IXFR queries are handed to doXfr. With
// -fallback, a truncated UDP reply is retried first with a larger EDNS0
// buffer and then over TCP.
func main() {
	short = flag.Bool("short", false, "abbreviate long DNSSEC records")
	dnssec := flag.Bool("dnssec", false, "request DNSSEC records")
	query := flag.Bool("question", false, "show question")
	check := flag.Bool("check", false, "check internal DNSSEC consistency")
	raw := flag.Bool("raw", false, "do not strip 'http://' from the qname")
	six := flag.Bool("6", false, "use IPv6 only")
	four := flag.Bool("4", false, "use IPv4 only")
	anchor := flag.String("anchor", "", "use the DNSKEY in this file for interal DNSSEC consistency")
	tsig := flag.String("tsig", "", "request tsig with key: [hmac:]name:key")
	port := flag.Int("port", 53, "port number to use")
	aa := flag.Bool("aa", false, "set AA flag in query")
	ad := flag.Bool("ad", false, "set AD flag in query")
	cd := flag.Bool("cd", false, "set CD flag in query")
	rd := flag.Bool("rd", true, "set RD flag in query")
	fallback := flag.Bool("fallback", false, "fallback to 4096 bytes bufsize and after that TCP")
	tcp := flag.Bool("tcp", false, "TCP mode")
	nsid := flag.Bool("nsid", false, "set edns nsid option")
	client := flag.String("client", "", "set edns client-subnet option")
	//serial := flag.Int("serial", 0, "perform an IXFR with this serial")
	flag.Usage = func() {
		fmt.Fprintf(os.Stderr, "Usage: %s [options] [@server] [qtype] [qclass] [name ...]\n", os.Args[0])
		flag.PrintDefaults()
	}

	qtype := uint16(0)
	qclass := uint16(dns.ClassINET)
	var qname []string

	flag.Parse()
	if *anchor != "" {
		// An unusable anchor is reported but not fatal; the queries still run.
		if f, err := os.Open(*anchor); err != nil {
			fmt.Fprintf(os.Stderr, "Failure to open %s: %s\n", *anchor, err.Error())
		} else {
			r, err := dns.ReadRR(f, *anchor)
			f.Close()
			if err != nil {
				fmt.Fprintf(os.Stderr, "Failure to read an RR from %s: %s\n", *anchor, err.Error())
			}
			if k, ok := r.(*dns.DNSKEY); !ok {
				fmt.Fprintf(os.Stderr, "No DNSKEY read from %s\n", *anchor)
			} else {
				dnskey = k
			}
		}
	}

	var nameserver string

	for _, arg := range flag.Args() {
		// If it starts with @ it is a nameserver (HasPrefix is safe for "").
		if strings.HasPrefix(arg, "@") {
			nameserver = arg
			continue
		}
		// Types are checked before classes, so a word that names both
		// (e.g. ANY) is taken as the query type.
		if k, ok := dns.StringToType[strings.ToUpper(arg)]; ok {
			qtype = k
			continue
		}
		// If it looks like a class, it is a class
		if k, ok := dns.StringToClass[strings.ToUpper(arg)]; ok {
			qclass = k
			continue
		}
		// If it is TYPEnnn it is an unknown RR type; nnn must fit in 16 bits.
		if strings.HasPrefix(arg, "TYPE") {
			if t, e := strconv.ParseUint(arg[len("TYPE"):], 10, 16); e == nil {
				qtype = uint16(t)
				continue
			}
		}

		// Anything else is a qname
		qname = append(qname, arg)
	}
	if len(qname) == 0 {
		qname = []string{"."}
		qtype = dns.TypeNS
	}
	if qtype == 0 {
		qtype = dns.TypeA
	}

	if nameserver == "" {
		conf, err := dns.ClientConfigFromFile("/etc/resolv.conf")
		if err != nil {
			fmt.Fprintln(os.Stderr, err)
			os.Exit(2)
		}
		if len(conf.Servers) == 0 {
			fmt.Fprintln(os.Stderr, "no nameserver found in /etc/resolv.conf")
			os.Exit(2)
		}
		nameserver = "@" + conf.Servers[0]
	}

	nameserver = strings.TrimPrefix(nameserver, "@") // chop off @
	if nameserver == "" {
		fmt.Fprintln(os.Stderr, "no nameserver given after @")
		os.Exit(2)
	}
	// if the nameserver is from /etc/resolv.conf the [ and ] are already
	// added, thereby breaking net.ParseIP. Check for this and don't
	// fully qualify such a name
	if len(nameserver) >= 2 && strings.HasPrefix(nameserver, "[") && strings.HasSuffix(nameserver, "]") {
		nameserver = nameserver[1 : len(nameserver)-1]
	}
	if net.ParseIP(nameserver) != nil {
		nameserver = net.JoinHostPort(nameserver, strconv.Itoa(*port))
	} else {
		nameserver = dns.Fqdn(nameserver) + ":" + strconv.Itoa(*port)
	}

	c := new(dns.Client)
	c.Net = "udp"
	if *tcp {
		c.Net = "tcp"
	}
	// -6 wins when both -4 and -6 are given.
	switch {
	case *six:
		c.Net += "6"
	case *four:
		c.Net += "4"
	}

	m := new(dns.Msg)
	m.MsgHdr.Authoritative = *aa
	m.MsgHdr.AuthenticatedData = *ad
	m.MsgHdr.CheckingDisabled = *cd
	m.MsgHdr.RecursionDesired = *rd
	m.Question = make([]dns.Question, 1)

	if *dnssec || *nsid || *client != "" {
		o := new(dns.OPT)
		o.Hdr.Name = "."
		o.Hdr.Rrtype = dns.TypeOPT
		if *dnssec {
			o.SetDo()
			o.SetUDPSize(dns.DefaultMsgSize)
		}
		if *nsid {
			e := new(dns.EDNS0_NSID)
			e.Code = dns.EDNS0NSID
			o.Option = append(o.Option, e)
			// NSD will not return nsid when the udp message size is too small
			o.SetUDPSize(dns.DefaultMsgSize)
		}
		if *client != "" {
			e := new(dns.EDNS0_SUBNET)
			e.Code = dns.EDNS0SUBNET
			e.SourceScope = 0
			e.Address = net.ParseIP(*client)
			if e.Address == nil {
				fmt.Fprintf(os.Stderr, "Failure to parse IP address: %s\n", *client)
				return
			}
			e.Family = 1 // IP4
			e.SourceNetmask = net.IPv4len * 8
			if e.Address.To4() == nil {
				e.Family = 2 // IP6
				e.SourceNetmask = net.IPv6len * 8
			}
			o.Option = append(o.Option, e)
		}
		m.Extra = append(m.Extra, o)
	}

	for _, v := range qname {
		if !*raw && strings.HasPrefix(v, "http://") {
			// Strip the scheme and one trailing slash; unlike indexing
			// v[len(v)-1], this is safe when nothing follows "http://".
			v = strings.TrimSuffix(strings.TrimPrefix(v, "http://"), "/")
		}

		m.Question[0] = dns.Question{Name: dns.Fqdn(v), Qtype: qtype, Qclass: qclass}
		m.Id = dns.Id()
		// Add tsig
		if *tsig != "" {
			if algo, name, secret, ok := tsigKeyParse(*tsig); ok {
				m.SetTsig(name, algo, 300, time.Now().Unix())
				c.TsigSecret = map[string]string{name: secret}
			} else {
				fmt.Fprintf(os.Stderr, "TSIG key data error\n")
				return
			}
		}
		if *query {
			fmt.Printf("%s", m.String())
			fmt.Printf("\n;; size: %d bytes\n\n", m.Len())
		}
		if qtype == dns.TypeAXFR {
			c.Net = "tcp"
			doXfr(c, m, nameserver)
			continue
		}
		if qtype == dns.TypeIXFR {
			doXfr(c, m, nameserver)
			continue
		}
		r, rtt, e := c.Exchange(m, nameserver)
	Redo:
		if e != nil {
			fmt.Printf(";; %s\n", e.Error())
			continue
		}
		if r.Id != m.Id {
			fmt.Fprintf(os.Stderr, "Id mismatch\n")
			return
		}
		// tcp, tcp4 and tcp6 are all stream transports: nothing to fall back to.
		if r.MsgHdr.Truncated && *fallback && !strings.HasPrefix(c.Net, "tcp") {
			if !*dnssec {
				fmt.Printf(";; Truncated, trying %d bytes bufsize\n", dns.DefaultMsgSize)
				// Enlarge an OPT RR already added for -nsid/-client instead
				// of appending a second one; two OPT RRs make the query
				// malformed (RFC 6891, section 6.1.1).
				if o := m.IsEdns0(); o != nil {
					o.SetUDPSize(dns.DefaultMsgSize)
				} else {
					o := new(dns.OPT)
					o.Hdr.Name = "."
					o.Hdr.Rrtype = dns.TypeOPT
					o.SetUDPSize(dns.DefaultMsgSize)
					m.Extra = append(m.Extra, o)
				}
				r, rtt, e = c.Exchange(m, nameserver)
				// Record that EDNS0 was tried, so the next truncation goes to TCP.
				*dnssec = true
				goto Redo
			}
			// First EDNS, then TCP; keep any -4/-6 address family.
			fmt.Printf(";; Truncated, trying TCP\n")
			c.Net = "tcp" + strings.TrimPrefix(c.Net, "udp")
			r, rtt, e = c.Exchange(m, nameserver)
			goto Redo
		}
		if r.MsgHdr.Truncated && !*fallback {
			fmt.Printf(";; Truncated\n")
		}
		if *check {
			sigCheck(r, nameserver, *tcp)
		}
		if *short {
			r = shortMsg(r)
		}

		fmt.Printf("%v", r)
		fmt.Printf("\n;; query time: %.3d µs, server: %s(%s), size: %d bytes\n", rtt/1e3, nameserver, c.Net, r.Len())
	}
}