// getServerReply forwards a DNS request to the configured nameservers. On
// success, the reply packet received from the nameserver is returned to the
// caller. On failure, the packet produced by getEmptyMsg for
// `dns.RcodeServerFailure` is returned instead.
// Even though the original `ResponseWriter` object is taken as an argument,
// this function does not send a reply to the client. Instead, the
// packet is returned for further processing by the caller.
func getServerReply(w dns.ResponseWriter, req *dns.Msg) *dns.Msg {
	// Only look at the question section when there is one, so that a query
	// without questions cannot crash the verbose logging.
	if *verbose && len(req.Question) > 0 {
		log.Print("Forwarding ", req.Question[0].Name, "/", dns.Type(req.Question[0].Qtype).String())
	}

	// Create a new DNS client that talks to the nameserver over the same
	// transport the client used to reach us.
	client := &dns.Client{Net: "udp", ReadTimeout: 4 * time.Second, WriteTimeout: 4 * time.Second, SingleInflight: true}
	if _, tcp := w.RemoteAddr().(*net.TCPAddr); tcp {
		client.Net = "tcp"
	}

	var err error

	// Loop through the specified nameservers and forward them the request.
	// The request ID is used as a starting point in order to introduce at
	// least some element of randomness, instead of always hitting the first
	// nameserver.
	for i := 0; i < len(nameservers); i++ {
		var r *dns.Msg
		r, _, err = client.Exchange(req, nameservers[(int(req.Id)+i)%len(nameservers)])
		if err == nil {
			r.Compress = true
			return r
		}
	}

	// Without any nameserver the loop never ran and err is still nil, so
	// report that case explicitly instead of logging "<nil>".
	if len(nameservers) == 0 {
		log.Print("Failed to forward request: no nameservers configured")
	} else {
		log.Print("Failed to forward request: ", err)
	}
	return getEmptyMsg(w, req, dns.RcodeServerFailure)
}
Exemple #2
0
// ServeDNS relays r to p.Host using the TCP or UDP client that matches the
// transport reported by middleware.Proto(w), and writes the upstream reply
// back to w with compression enabled and the request's Id restored.
//
// An exchange error is returned to the caller, except when the upstream
// reply is marked as truncated: such a reply is forwarded as-is and the
// error is dropped.
func (p ReverseProxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg, extra []dns.RR) error {
	var (
		answer *dns.Msg
		exErr  error
	)

	if middleware.Proto(w) == "tcp" {
		answer, exErr = middleware.Exchange(p.Client.TCP, r, p.Host)
	} else {
		answer, exErr = middleware.Exchange(p.Client.UDP, r, p.Host)
	}

	// Suppress proxy error for truncated responses
	truncated := answer != nil && answer.Truncated
	if exErr != nil && !truncated {
		return exErr
	}

	answer.Compress = true
	answer.Id = r.Id
	w.WriteMsg(answer)
	return nil
}
Exemple #3
0
// ServeDNSForward forwards a request to the configured nameservers and
// returns the message that was sent back to the client.
//
// The message built by s.ServerFailure is written to w and returned when
// recursion is disabled (s.config.NoRec), when no nameservers are configured,
// when the queried name has fewer labels than s.config.Ndots, or when every
// nameserver returned an error. Otherwise the first successful upstream reply
// is written to w (compressed, carrying the request's Id) and returned.
func (s *server) ServeDNSForward(w dns.ResponseWriter, req *dns.Msg) *dns.Msg {
	if s.config.NoRec {
		m := s.ServerFailure(req)
		w.WriteMsg(m)
		return m
	}

	servers := s.config.Nameservers
	if len(servers) == 0 || dns.CountLabel(req.Question[0].Name) < s.config.Ndots {
		if s.config.Verbose {
			if len(servers) == 0 {
				logf("can not forward, no nameservers defined")
			} else {
				logf("can not forward, name too short (less than %d labels): `%s'", s.config.Ndots, req.Question[0].Name)
			}
		}
		m := s.ServerFailure(req)
		m.RecursionAvailable = true // this is still true
		w.WriteMsg(m)
		return m
	}

	tcp := isTCP(w)

	nsid := 0
	if s.config.NSRotate {
		// Use request Id for "random" nameserver selection.
		nsid = int(req.Id) % len(servers)
	}

	var (
		r   *dns.Msg
		err error
	)

	// Try every nameserver exactly once, starting at nsid and wrapping
	// around the list.
	for try := 0; try < len(servers); try++ {
		if tcp {
			r, _, err = s.dnsTCPclient.Exchange(req, servers[nsid])
		} else {
			r, _, err = s.dnsUDPclient.Exchange(req, servers[nsid])
		}
		if err == nil {
			r.Compress = true
			r.Id = req.Id
			w.WriteMsg(r)
			return r
		}
		// Seen an error, presumably "server not reached": move on to the
		// next nameserver until the list is exhausted.
		nsid = (nsid + 1) % len(servers)
	}

	// Every nameserver failed. Answer the client like the other failure
	// paths above do, instead of leaving it without any reply.
	logf("failure to forward request %q", err)
	m := s.ServerFailure(req)
	w.WriteMsg(m)
	return m
}
Exemple #4
0
// handlerFunc receives requests, looks up the result and returns what is found.
//
// Only the query opcode is served; any other opcode is answered with
// RcodeNotImplemented. For a query, each question of the reply is resolved
// through answerQuestion. When that yields nothing, SOA records for the
// questions are collected into the authority section instead. A reply with
// neither answers nor authority records gets RcodeNameError.
func handlerFunc(res dns.ResponseWriter, req *dns.Msg) {
	message := new(dns.Msg)

	if req.Opcode != dns.OpcodeQuery {
		message = message.SetRcode(req, dns.RcodeNotImplemented)
		res.WriteMsg(message)
		return
	}

	message.SetReply(req)
	message.Compress = false
	message.Answer = make([]dns.RR, 0)

	for _, q := range message.Question {
		found := answerQuestion(strings.ToLower(q.Name), q.Qtype)
		if len(found) > 0 {
			for _, rr := range found {
				message.Answer = append(message.Answer, rr)
			}
			continue
		}

		// If there are no records, go back through and search for SOA records
		for _, soaQ := range message.Question {
			authority := answerQuestion(strings.ToLower(soaQ.Name), dns.TypeSOA)
			for _, rr := range authority {
				message.Ns = append(message.Ns, rr)
			}
		}
	}

	nothingFound := len(message.Answer) == 0 && len(message.Ns) == 0
	if nothingFound {
		message.Rcode = dns.RcodeNameError
	}

	res.WriteMsg(message)
}
// handleDNS is a handler function that performs the DNS query response.
//
// Every A query is answered with a single A record pointing at c.IP (TTL 0).
// The record is named after the queried domain when c.SpoofDomain is set and
// after c.Domain otherwise. Queries of any other type are logged and left
// unanswered, and a request without a question is ignored. The writer is
// closed when the handler returns.
func (c *CatchAll) handleDNS(w dns.ResponseWriter, r *dns.Msg) {
	defer w.Close()

	// A request without a question cannot be answered; indexing it would panic.
	if len(r.Question) == 0 {
		return
	}
	q := r.Question[0]

	if q.Qtype != dns.TypeA {
		// Log the numeric query type; %T would only print its Go type.
		log.Warnf("Unknown dns type %d", q.Qtype)
		return
	}

	name := c.Domain
	if c.SpoofDomain {
		name = q.Name
	}

	rr := &dns.A{
		Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 0},
		A:   c.IP,
	}

	msgResp := new(dns.Msg)
	msgResp.SetReply(r)
	msgResp.Compress = false
	msgResp.Answer = append(msgResp.Answer, rr)

	w.WriteMsg(msgResp)
}
Exemple #6
0
// reply sends m through w. Compression is switched on first, then the
// message is passed through truncate (told whether w is a UDP writer) and the
// result is written out. A write failure is only logged.
func reply(w dns.ResponseWriter, m *dns.Msg) {
	// https://github.com/mesosphere/mesos-dns/issues/{170,173,174}
	m.Compress = true

	out := truncate(m, isUDP(w))
	err := w.WriteMsg(out)
	if err == nil {
		return
	}
	logging.Error.Println(err)
}
Exemple #7
0
// handleDNSRequest answers a DNS request with an uncompressed reply.
//
// Queries are handed to parseQuery together with the prepared reply; updates
// pass every record of the authority section to updateRecord, once per
// question. When the request carries a TSIG record and the writer reports no
// TSIG error, the reply is signed under the same key name; otherwise the TSIG
// status is logged.
func handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
	if *debug {
		Log.Printf("handleRequest: message: %+v\n", r)
	}

	m := new(dns.Msg)
	m.SetReply(r)
	m.Compress = false

	if r.Opcode == dns.OpcodeQuery {
		parseQuery(m)
	} else if r.Opcode == dns.OpcodeUpdate {
		for _, question := range r.Question {
			for _, rr := range r.Ns {
				updateRecord(rr, &question)
			}
		}
	}

	if r.IsTsig() != nil {
		if status := w.TsigStatus(); status != nil {
			Log.Println("Status", status.Error())
		} else {
			// The request ends with a TSIG record here; reuse its key name.
			keyName := r.Extra[len(r.Extra)-1].(*dns.TSIG).Hdr.Name
			m.SetTsig(keyName, dns.HmacMD5, 300, time.Now().Unix())
		}
	}

	w.WriteMsg(m)
}
Exemple #8
0
// handleQuery is used to handle DNS queries in the configured domain.
//
// It prepares an authoritative reply to req, adds the SOA record when the
// first question asks for it, lets d.dispatch fill in the rest and writes the
// reply to resp. On return the elapsed time is recorded as a metric and the
// request is logged at debug level.
func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
	q := req.Question[0]
	defer func(began time.Time) {
		metrics.MeasureSince([]string{"consul", "dns", "domain_query", d.agent.config.NodeName}, began)
		d.logger.Printf("[DEBUG] dns: request for %v (%v) from client %s (%s)",
			q, time.Since(began), resp.RemoteAddr().String(),
			resp.RemoteAddr().Network())
	}(time.Now())

	// Use the same transport the client used.
	network := "udp"
	switch resp.RemoteAddr().(type) {
	case *net.TCPAddr:
		network = "tcp"
	}

	// Prepare the reply skeleton.
	reply := new(dns.Msg)
	reply.SetReply(req)
	reply.Compress = !d.config.DisableCompression
	reply.Authoritative = true
	reply.RecursionAvailable = len(d.recursors) > 0

	// The SOA record is only included on explicit request.
	if q.Qtype == dns.TypeSOA {
		d.addSOA(d.domain, reply)
	}

	// Let the matching handler fill in the reply.
	d.dispatch(network, req, reply)

	// Send the finished reply to the client.
	err := resp.WriteMsg(reply)
	if err != nil {
		d.logger.Printf("[WARN] dns: failed to respond: %v", err)
	}
}
Exemple #9
0
// ServeDNSStubForward forwards a request to the nameservers in ns and returns
// the message that was sent back to the client.
//
// A request that already carries the EDNS0 stub marker (option ednsStubCode
// with the single data byte 1) is dropped and nil is returned, so that two
// stubs forwarding to each other do not loop. Otherwise the marker is added
// to the request before it is forwarded. The first successful upstream reply
// is written to w and returned. When ns is empty or every nameserver fails,
// the message built by s.ServerFailure is written and returned instead.
func (s *server) ServeDNSStubForward(w dns.ResponseWriter, req *dns.Msg, ns []string) *dns.Msg {
	// Check EDNS0 Stub option, if set drop the packet.
	option := req.IsEdns0()
	if option != nil {
		for _, o := range option.Option {
			// Checked assertion: an option that is not an EDNS0_LOCAL is
			// skipped instead of panicking.
			local, ok := o.(*dns.EDNS0_LOCAL)
			if ok && o.Option() == ednsStubCode && len(local.Data) == 1 && local.Data[0] == 1 {
				// Maybe log source IP here?
				logf("not fowarding stub request to another stub")
				return nil
			}
		}
	}

	// Without nameservers the selection below would divide by zero.
	if len(ns) == 0 {
		logf("can not forward stub request, no nameservers defined")
		m := s.ServerFailure(req)
		w.WriteMsg(m)
		return m
	}

	tcp := isTCP(w)

	// Add a custom EDNS0 option to the packet, so we can detect loops
	// when 2 stubs are forwarding to each other.
	if option != nil {
		option.Option = append(option.Option, &dns.EDNS0_LOCAL{Code: ednsStubCode, Data: []byte{1}})
	} else {
		req.Extra = append(req.Extra, ednsStub)
	}

	var (
		r   *dns.Msg
		err error
	)

	// Use request Id for "random" nameserver selection.
	nsid := int(req.Id) % len(ns)

	// Try every nameserver exactly once, starting at nsid and wrapping
	// around the list.
	for try := 0; try < len(ns); try++ {
		if tcp {
			r, _, err = s.dnsTCPclient.Exchange(req, ns[nsid])
		} else {
			r, _, err = s.dnsUDPclient.Exchange(req, ns[nsid])
		}
		if err == nil {
			r.Compress = true
			r.Id = req.Id
			w.WriteMsg(r)
			return r
		}
		// Seen an error, presumably "server not reached": move on to the
		// next nameserver until the list is exhausted.
		nsid = (nsid + 1) % len(ns)
	}

	logf("failure to forward stub request %q", err)
	m := s.ServerFailure(req)
	w.WriteMsg(m)
	return m
}
Exemple #10
0
// CompressIfLarge turns on compression for m when its packed form is longer
// than 512 bytes and therefore may not fit into a UDP packet. The message is
// left untouched when it cannot be packed.
func CompressIfLarge(m *dns.Msg) {
	packed, err := m.Pack()
	if err == nil && len(packed) > 512 {
		// The actual compression happens later, in WriteMsg.
		m.Compress = true
	}
}
Exemple #11
0
// handle answers a DNS request with the records that s.registry returns for
// each of its questions. The reply is compressed and always written, even
// when no records were found. When a StatsD client is configured, per-lookup
// and per-request timings and counters are reported. The writer is closed
// when the handler returns.
func (s *Server) handle(w dns.ResponseWriter, request *dns.Msg) {
	// The writer is closed no matter how the handler exits.
	defer w.Close()

	// Reference point for the overall response time.
	start := time.Now()

	// Begin with an empty, compressed reply to the request.
	response := &dns.Msg{}
	response.SetReply(request)
	response.Compress = true

	// Resolve every question and collect the answers.
	for _, question := range request.Question {
		lookupStart := time.Now()
		records := s.registry.Lookup(question.Name, question.Qtype, question.Qclass)
		lookupElapsed := time.Since(lookupStart)

		response.Answer = append(response.Answer, records...)

		// Metrics are optional.
		if s.statsd == nil {
			continue
		}

		tags := []string{
			fmt.Sprintf("name:%s", question.Name),
			fmt.Sprintf("qtype:%s", dns.TypeToString[question.Qtype]),
			fmt.Sprintf("qclass:%s", dns.ClassToString[question.Qclass]),
		}
		s.statsd.TimeInMilliseconds("lookup.time", lookupElapsed.Seconds()*1000.0, tags, 1)
		s.statsd.Histogram("lookup.answer", float64(len(records)), tags, 1)
		s.statsd.Count("request.question", 1, tags, 1)
	}

	// Send the reply.
	w.WriteMsg(response)

	// Report how long the whole request took.
	if s.statsd != nil {
		elapsed := time.Since(start)
		s.statsd.TimeInMilliseconds("request.time", elapsed.Seconds()*1000.0, nil, 1)
	}
}
Exemple #12
0
// HandleForwarding forwards a request to the nameservers and relays the
// response to the client.
//
// It returns (true, nil) when an upstream reply was obtained and written to
// w, and (true, err) when writing that reply failed. When no nameservers are
// defined or every nameserver returned an error, it returns false together
// with the result of respServerFailure. A successful upstream reply is added
// to s.cache before it is written.
func (s *Server) HandleForwarding(w mdns.ResponseWriter, r *mdns.Msg) (bool, error) {
	defer trace.End(trace.Begin(r.String()))

	if len(s.Nameservers) == 0 {
		log.Errorf("No nameservers defined, can not forward")
		return false, respServerFailure(w, r)
	}

	// which protocol are they talking
	_, tcp := w.RemoteAddr().(*net.TCPAddr)

	// Use request ID for "random" nameserver selection.
	nsid := int(r.Id) % len(s.Nameservers)

	var (
		m   *mdns.Msg
		err error
	)

	// Try every nameserver exactly once, starting at nsid and wrapping
	// around the list.
	for try := 0; try < len(s.Nameservers); try++ {
		nameserver := s.Nameservers[nsid]
		// Entries without an explicit port get the default DNS port.
		if !strings.Contains(nameserver, ":") {
			nameserver += ":53"
		}

		if tcp {
			m, _, err = s.tcpclient.Exchange(r, nameserver)
		} else {
			m, _, err = s.udpclient.Exchange(r, nameserver)
		}
		if err == nil {
			break
		}
		// Seen an error, presumably "server not reached": move on to the
		// next nameserver until the list is exhausted.
		nsid = (nsid + 1) % len(s.Nameservers)
	}

	if err != nil {
		log.Errorf("Failure to forward request: %q", err)
		return false, respServerFailure(w, r)
	}

	// We have a response so cache it
	s.cache.Add(m)

	m.Compress = true
	if werr := w.WriteMsg(m); werr != nil {
		log.Errorf("Error writing response: %q", werr)
		return true, werr
	}
	return true, nil
}
Exemple #13
0
// ServeDNS answers a DNS query. A and PTR questions are first offered to the
// resolver's own handlers. When those produce no reply, the query is sent to
// the external DNS servers in r.extDNS (at most maxExtDNS of them, in order)
// and the first reply obtained without error is used.
//
// Nothing is written when the query is nil or has no question, when the
// local handler reports an error, or when no reply could be obtained.
func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
	if query == nil || len(query.Question) == 0 {
		return
	}

	var (
		resp *dns.Msg
		err  error
	)

	name := query.Question[0].Name
	switch query.Question[0].Qtype {
	case dns.TypeA:
		resp, err = r.handleIPv4Query(name, query)
	case dns.TypePTR:
		resp, err = r.handlePTRQuery(name, query)
	}

	if err != nil {
		log.Error(err)
		return
	}

	if resp == nil {
		if len(r.extDNS) == 0 {
			return
		}

		// Never consult more than maxExtDNS external servers.
		limit := len(r.extDNS)
		if limit > maxExtDNS {
			limit = maxExtDNS
		}

		for idx := 0; idx < limit; idx++ {
			log.Debugf("Querying ext dns %s:%s for %s[%d]", w.LocalAddr().Network(), r.extDNS[idx], name, query.Question[0].Qtype)

			// Use the same network the query arrived on.
			client := &dns.Client{Net: w.LocalAddr().Network()}
			target := fmt.Sprintf("%s:%d", r.extDNS[idx], 53)

			resp, _, err = client.Exchange(query, target)
			if err != nil {
				log.Errorf("external resolution failed, %s", err)
				continue
			}
			resp.Compress = true
			break
		}

		if resp == nil {
			return
		}
	}

	if err = w.WriteMsg(resp); err != nil {
		log.Errorf("error writing resolver resp, %s", err)
	}
}
Exemple #14
0
// ServeDNSForward forwards a request to the configured nameservers and
// returns the message that was sent back to the client.
//
// The message built by s.ServerFailure is written to w and returned when
// recursion is disabled (s.config.NoRec), when no nameservers are configured,
// when the queried name has fewer labels than s.config.Ndots, or when every
// nameserver returned an error. Otherwise the first successful upstream reply
// is written to w (compressed, carrying the request's Id) and returned.
func (s *server) ServeDNSForward(w dns.ResponseWriter, req *dns.Msg) *dns.Msg {
	if s.config.NoRec {
		m := s.ServerFailure(req)
		w.WriteMsg(m)
		return m
	}

	if len(s.config.Nameservers) == 0 || dns.CountLabel(req.Question[0].Name) < s.config.Ndots {
		if s.config.Verbose {
			if len(s.config.Nameservers) == 0 {
				logf("can not forward, no nameservers defined")
			} else {
				logf("can not forward, name too short (less than %d labels): `%s'", s.config.Ndots, req.Question[0].Name)
			}
		}
		m := s.ServerFailure(req)
		m.RecursionAvailable = true // this is still true
		w.WriteMsg(m)
		return m
	}

	var (
		r   *dns.Msg
		err error
	)

	// The transport does not change between attempts.
	tcp := isTCP(w)
	nsid := s.randomNameserverID(req.Id)

	// Try every nameserver exactly once, starting at nsid and wrapping
	// around the list.
	for try := 0; try < len(s.config.Nameservers); try++ {
		if tcp {
			r, err = exchangeWithRetry(s.dnsTCPclient, req, s.config.Nameservers[nsid])
		} else {
			r, err = exchangeWithRetry(s.dnsUDPclient, req, s.config.Nameservers[nsid])
		}
		if err == nil {
			r.Compress = true
			r.Id = req.Id
			w.WriteMsg(r)
			return r
		}
		// Seen an error, presumably "server not reached": move on to the
		// next nameserver until the list is exhausted.
		nsid = (nsid + 1) % len(s.config.Nameservers)
	}

	// Every nameserver failed. Answer the client like the other failure
	// paths above do, instead of leaving it without any reply.
	logf("failure to forward request %q", err)
	m := s.ServerFailure(req)
	w.WriteMsg(m)
	return m
}
Exemple #15
0
// handleForwardingRaw forwards req to the given nameservers and returns the
// reply without writing it anywhere.
//
// The transport is chosen from remote: TCP when it is a *net.TCPAddr, UDP
// otherwise. Nameservers without an explicit port are contacted on port 53.
// The first successful reply is returned with compression enabled. When
// nameservers is empty or every nameserver returns an error, a SERVFAIL reply
// to req is returned instead.
func handleForwardingRaw(nameservers []string, req *dns.Msg, remote net.Addr) *dns.Msg {
	if len(nameservers) == 0 {
		log.Printf("no nameservers defined, can not forward\n")
		m := new(dns.Msg)
		m.SetReply(req)
		m.SetRcode(req, dns.RcodeServerFailure)
		m.Authoritative = false     // no matter what set to false
		m.RecursionAvailable = true // and this is still true
		return m
	}

	dnsClient := &dns.Client{Net: "udp", ReadTimeout: 4 * time.Second, WriteTimeout: 4 * time.Second, SingleInflight: true}
	if _, tcp := remote.(*net.TCPAddr); tcp {
		dnsClient.Net = "tcp"
	}

	// Use request Id for "random" nameserver selection.
	nsid := int(req.Id) % len(nameservers)

	var err error

	// Try every nameserver exactly once, starting at nsid and wrapping
	// around the list.
	for try := 0; try < len(nameservers); try++ {
		nameserver := nameservers[nsid]
		// Entries without an explicit port get the default DNS port.
		if !strings.Contains(nameserver, ":") {
			nameserver += ":53"
		}

		var r *dns.Msg
		r, _, err = dnsClient.Exchange(req, nameserver)
		if err == nil {
			r.Compress = true
			return r
		}
		// Seen an error, presumably "server not reached": move on to the
		// next nameserver until the list is exhausted.
		nsid = (nsid + 1) % len(nameservers)
	}

	log.Printf("failure to forward request %q\n", err)
	m := new(dns.Msg)
	m.SetReply(req)
	m.SetRcode(req, dns.RcodeServerFailure)
	return m
}
Exemple #16
0
// trimUDPResponse keeps a UDP response within the size allowed by RFC 1035.
// The answer section is first cut to the smaller of maxUDPAnswerLimit and
// config.UDPAnswerLimit, and then shortened one record at a time until the
// uncompressed message is no longer than 512 bytes. Whenever answers are
// dropped and the response has extra records, syncExtra is called to bring
// the extra section in line with the remaining answers. It reports whether
// any answer was removed.
func trimUDPResponse(config *DNSConfig, resp *dns.Msg) (trimmed bool) {
	originalCount := len(resp.Answer)

	// The index over the extra records is only built when there are extra
	// records, which saves function calls and allocations otherwise.
	syncNeeded := len(resp.Extra) > 0
	var extraIndex map[string]dns.RR
	if syncNeeded {
		extraIndex = make(map[string]dns.RR, len(resp.Extra))
		indexRRs(resp.Extra, extraIndex)
	}

	// First cap the number of answers at a useful but limited count.
	if limit := lib.MinInt(maxUDPAnswerLimit, config.UDPAnswerLimit); originalCount > limit {
		resp.Answer = resp.Answer[:limit]
		if syncNeeded {
			syncExtra(extraIndex, resp)
		}
	}

	// Then enforce the hard 512 byte limit of the RFC. The size is measured
	// with compression switched off, which is the more conservative choice:
	// the response stays compliant even if some downstream server
	// uncompresses it. The caller's compression setting is restored after.
	savedCompress := resp.Compress
	resp.Compress = false
	for len(resp.Answer) > 0 && resp.Len() > 512 {
		last := len(resp.Answer) - 1
		resp.Answer = resp.Answer[:last]
		if syncNeeded {
			syncExtra(extraIndex, resp)
		}
	}
	resp.Compress = savedCompress

	trimmed = len(resp.Answer) < originalCount
	return trimmed
}
Exemple #17
0
// ServeDNSReverse handles DNS requests for the reverse zone. PTR records
// found locally are sent to the client in a compressed, non-authoritative
// reply, which is also returned. When the lookup fails or finds nothing, the
// request is passed on to ServeDNSForward and its result is returned.
func (s *server) ServeDNSReverse(w dns.ResponseWriter, req *dns.Msg) *dns.Msg {
	records, err := s.PTRRecords(req.Question[0])
	if err != nil || len(records) == 0 {
		// Always forward if not found locally.
		return s.ServeDNSForward(w, req)
	}

	reply := new(dns.Msg)
	reply.SetReply(req)
	reply.Compress = true
	reply.Authoritative = false
	reply.RecursionAvailable = true
	reply.Answer = records

	writeMsg(w, reply)
	return reply
}
// ServeDNSForward forwards a request to the configured nameservers and
// writes the response to w.
//
// A SERVFAIL reply is written when no nameservers are configured, when the
// queried name has fewer labels than s.config.Ndots, or when every
// nameserver returned an error. Otherwise the first successful upstream
// reply is written with compression enabled. Every call increments
// StatsForwardCount.
func (s *server) ServeDNSForward(w dns.ResponseWriter, req *dns.Msg) {
	StatsForwardCount.Inc(1)
	if len(s.config.Nameservers) == 0 || dns.CountLabel(req.Question[0].Name) < s.config.Ndots {
		s.config.log.Infof("no nameservers defined or name too short, can not forward")
		m := new(dns.Msg)
		m.SetReply(req)
		m.SetRcode(req, dns.RcodeServerFailure)
		m.Authoritative = false     // no matter what set to false
		m.RecursionAvailable = true // and this is still true
		w.WriteMsg(m)
		return
	}

	// Use the transport the client used.
	_, tcp := w.RemoteAddr().(*net.TCPAddr)

	var (
		r   *dns.Msg
		err error
	)

	// Use request Id for "random" nameserver selection.
	nsid := int(req.Id) % len(s.config.Nameservers)

	// Try every nameserver exactly once, starting at nsid and wrapping
	// around the list.
	for try := 0; try < len(s.config.Nameservers); try++ {
		if tcp {
			r, _, err = s.dnsTCPclient.Exchange(req, s.config.Nameservers[nsid])
		} else {
			r, _, err = s.dnsUDPclient.Exchange(req, s.config.Nameservers[nsid])
		}
		if err == nil {
			r.Compress = true
			w.WriteMsg(r)
			return
		}
		// Seen an error, presumably "server not reached": move on to the
		// next nameserver until the list is exhausted.
		nsid = (nsid + 1) % len(s.config.Nameservers)
	}

	s.config.log.Errorf("failure to forward request %q", err)
	m := new(dns.Msg)
	m.SetReply(req)
	m.SetRcode(req, dns.RcodeServerFailure)
	w.WriteMsg(m)
}
Exemple #19
0
// Respond builds an authoritative, compressed reply to req that carries
// records as its answer section and writes it to w.
//
// The reply is fitted to the client's limits: a message longer than
// dns.MaxMsgSize first loses its extra section and, if still too long, is
// replaced by a SERVFAIL reply. Over UDP the reply must additionally fit the
// buffer size advertised through EDNS0 (at least 512 bytes). If it does not,
// the extra section is dropped and, failing that, the answers are removed and
// the reply is marked as truncated. A write failure is only logged.
func Respond(w dns.ResponseWriter, req *dns.Msg, records []dns.RR) {
	m := new(dns.Msg)
	m.SetReply(req)
	m.Authoritative = true
	m.RecursionAvailable = true
	m.Compress = true
	m.Answer = records

	// Figure out the max response size
	bufsize := uint16(512)
	tcp := isTcp(w)

	if o := req.IsEdns0(); o != nil {
		bufsize = o.UDPSize()
	}

	if tcp {
		bufsize = dns.MaxMsgSize - 1
	} else if bufsize < 512 {
		bufsize = 512
	}

	if m.Len() > dns.MaxMsgSize {
		fqdn := dns.Fqdn(req.Question[0].Name)
		log.WithFields(log.Fields{"fqdn": fqdn}).Debug("Response too big, dropping Extra")
		m.Extra = nil
		if m.Len() > dns.MaxMsgSize {
			log.WithFields(log.Fields{"fqdn": fqdn}).Debug("Response still too big")
			// Replace the reply itself (plain assignment, not a new
			// shadowing variable) so that the SERVFAIL is what gets sent,
			// and build it from the request.
			m = new(dns.Msg)
			m.SetRcode(req, dns.RcodeServerFailure)
		}
	}

	if m.Len() > int(bufsize) && !tcp {
		log.Debug("Too big 1")
		m.Extra = nil
		if m.Len() > int(bufsize) {
			log.Debug("Too big 2")
			m.Answer = nil
			m.Truncated = true
		}
	}

	err := w.WriteMsg(m)
	if err != nil {
		log.Warn("Failed to return reply: ", err, m.Len())
	}

}
Exemple #20
0
// handleRecurse is used to handle recursive DNS queries.
//
// The request is sent to each configured recursor in turn, over the same
// transport the client used, until one of them answers. That answer is
// relayed to the client. When every recursor fails, a SERVFAIL reply is sent
// instead. The request is logged at debug level on return.
func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) {
	q := req.Question[0]
	network := "udp"
	defer func(began time.Time) {
		d.logger.Printf("[DEBUG] dns: request for %v (%s) (%v) from client %s (%s)",
			q, network, time.Since(began), resp.RemoteAddr().String(),
			resp.RemoteAddr().Network())
	}(time.Now())

	// Follow the client's choice of transport.
	if _, overTCP := resp.RemoteAddr().(*net.TCPAddr); overTCP {
		network = "tcp"
	}

	// Ask the recursors one after another.
	client := &dns.Client{Net: network, Timeout: d.config.RecursorTimeout}
	for _, recursor := range d.recursors {
		answer, rtt, err := client.Exchange(req, recursor)
		if err != nil {
			d.logger.Printf("[ERR] dns: recurse failed: %v", err)
			continue
		}

		// Compress the response; we don't know if the incoming
		// response was compressed or not, so by not compressing
		// we might generate an invalid packet on the way out.
		answer.Compress = !d.config.DisableCompression

		// Relay the answer to the client.
		d.logger.Printf("[DEBUG] dns: recurse RTT for %v (%v)", q, rtt)
		if werr := resp.WriteMsg(answer); werr != nil {
			d.logger.Printf("[WARN] dns: failed to respond: %v", werr)
		}
		return
	}

	// None of the recursors answered: reply with SERVFAIL.
	d.logger.Printf("[ERR] dns: all resolvers failed for %v from client %s (%s)",
		q, resp.RemoteAddr().String(), resp.RemoteAddr().Network())
	failure := &dns.Msg{}
	failure.SetReply(req)
	failure.Compress = !d.config.DisableCompression
	failure.RecursionAvailable = true
	failure.SetRcode(req, dns.RcodeServerFailure)
	resp.WriteMsg(failure)
}
// ServeDNSReverse is the handler for DNS requests for the reverse zone. When
// the local PTR lookup succeeds, its result is written to w in a compressed,
// non-authoritative reply and the request is done. Only when the lookup
// returns an error is the request handed to the forwarder for resolution.
func (s *server) ServeDNSReverse(w dns.ResponseWriter, req *dns.Msg) {
	m := new(dns.Msg)
	m.SetReply(req)
	m.Compress = true
	m.Authoritative = false // Set to false, because I don't know what to do wrt DNSSEC.
	m.RecursionAvailable = true
	var err error
	if m.Answer, err = s.PTRRecords(req.Question[0]); err == nil {
		// TODO(miek): Reverse DNSSEC. We should sign this, but requires a key....and more
		// Probably not worth the hassle?
		if err := w.WriteMsg(m); err != nil {
			s.config.log.Errorf("failure to return reply %q", err)
		}
		// The client has been answered; forwarding as well would send it
		// a second reply for the same request.
		return
	}
	// Always forward if not found locally.
	s.ServeDNSForward(w, req)
}
Exemple #22
0
// toMsg builds a reply to m from the item i: the header flags and the
// answer, authority and extra sections are taken from i, and compression is
// enabled. The TTL handed to setCap is i.origTtl minus the number of seconds
// since the item was stored, but never less than baseTtl.
func (i *item) toMsg(m *dns.Msg) *dns.Msg {
	// Seconds that have passed since the item was stored.
	age := int(time.Now().UTC().Sub(i.stored).Seconds())
	remaining := int(i.origTtl) - age
	if remaining < baseTtl {
		remaining = baseTtl
	}

	reply := new(dns.Msg)
	reply.SetReply(m)

	reply.Authoritative = i.Authoritative
	reply.AuthenticatedData = i.AuthenticatedData
	reply.RecursionAvailable = i.RecursionAvailable
	reply.Compress = true

	reply.Answer, reply.Ns, reply.Extra = i.Answer, i.Ns, i.Extra

	setCap(reply, uint32(remaining))
	return reply
}
Exemple #23
0
// handleRequest answers a DNS request with the single record whose textual
// form getRRStr returns for the first question.
//
// The queried name is printed to standard output. When getRRStr reports a
// failure the reply carries RcodeNameError. When the returned text cannot be
// parsed into a record the reply carries RcodeServerFailure. Compression of
// the reply follows the *compress flag, and with *printf set the finished
// reply is printed as well.
func handleRequest(w dns.ResponseWriter, r *dns.Msg) {
	fmt.Println(r.Question[0].Name)
	m := new(dns.Msg)
	m.SetReply(r)
	m.Compress = *compress

	rrstr, failed := getRRStr(r.Question[0])
	if failed {
		m.SetRcode(r, dns.RcodeNameError)
		w.WriteMsg(m)
		return
	}

	// Never put an unparsable (nil) record into the answer section; report
	// a server failure to the client instead.
	rr, err := dns.NewRR(rrstr)
	if err != nil || rr == nil {
		m.SetRcode(r, dns.RcodeServerFailure)
		w.WriteMsg(m)
		return
	}

	m.Answer = append(m.Answer, rr)
	if *printf {
		fmt.Printf("%v\n", m.String())
	}
	w.WriteMsg(m)
}
Exemple #24
0
// addHandler registers a DNS handler for the domain and, recursively, for
// all of its subdomains.
//
// The handled name is dom.Name, with "." and tld appended when tld is not
// empty. A handler is only registered when that name is fully qualified. It
// answers A questions with the domain's A record and MX questions with its MX
// record, and ignores questions of any other type. Progress is printed to
// standard output.
func (dom Domain) addHandler(tld string) {
	fqdn := dom.Name
	if tld != "" {
		fqdn = dom.Name + "." + tld
	}

	// Only the label count is of interest here, for the message below.
	labels, _ := dns.IsDomainName(fqdn)

	fmt.Printf("Adding: %v - %v nums\n", fqdn, labels)

	// Handle dns requests if it is really a fqdn
	if dns.IsFqdn(fqdn) {
		dns.HandleFunc(fqdn, func(w dns.ResponseWriter, req *dns.Msg) {
			m := new(dns.Msg)
			m.SetReply(req)
			m.Compress = true

			for _, q := range req.Question {
				fmt.Printf("Requested: %s, Type: %v\n", q.Name, q.Qtype)

				switch q.Qtype {
				case dns.TypeA:
					fmt.Printf("Adding a record %v with ip %v\n", q.Name, dom.A.Ip)
					m.Answer = append(m.Answer, NewA(q.Name, dom.A.Ip, uint32(dom.A.Ttl)))
				case dns.TypeMX:
					// Report the MX data that is actually being added.
					fmt.Printf("Adding mx record %v with content %v\n", q.Name, dom.Mx.Content)
					m.Answer = append(m.Answer, NewMX(q.Name, dom.Mx.Content, dom.Mx.Priority, uint32(dom.Mx.Ttl)))
				}
			}

			w.WriteMsg(m)
		})
	}

	// add subdomains..
	for _, d := range dom.Domains {
		d.addHandler(fqdn)
	}
}
Exemple #25
0
// dnsHandler answers every A question with 127.0.0.1 and every MX question
// with the host "mail." + the queried name at preference 10, all with a TTL
// of 0. Questions of other types get no record. Each question is printed to
// standard output, and the writer is closed when the handler returns.
func dnsHandler(w dns.ResponseWriter, r *dns.Msg) {
	defer w.Close()

	reply := new(dns.Msg)
	reply.SetReply(r)
	reply.Compress = false

	for _, q := range r.Question {
		fmt.Printf("dns-srv: Query -- [%s] %s\n", q.Name, dns.TypeToString[q.Qtype])

		// Header shared by both record kinds; only the type differs.
		hdr := dns.RR_Header{Name: q.Name, Class: dns.ClassINET, Ttl: 0}

		if q.Qtype == dns.TypeA {
			hdr.Rrtype = dns.TypeA
			reply.Answer = append(reply.Answer, &dns.A{
				Hdr: hdr,
				A:   net.ParseIP("127.0.0.1"),
			})
		} else if q.Qtype == dns.TypeMX {
			hdr.Rrtype = dns.TypeMX
			reply.Answer = append(reply.Answer, &dns.MX{
				Hdr:        hdr,
				Preference: 10,
				Mx:         "mail." + q.Name,
			})
		}
	}

	w.WriteMsg(reply)
}
Exemple #26
0
// dnsHandler is the base handler of the DNS server.
//
// Queries are answered with the records that getRecord finds for each
// question, provided the record's owner name equals the queried name.
// Updates are only applied when the request carries a TSIG record and the
// writer reports no TSIG error; each record of the authority section is then
// passed to updateRecord, once per question. A request that carries a TSIG
// record gets a signed reply unless the TSIG check failed, in which case the
// status is logged.
func dnsHandler(w dns.ResponseWriter, request *dns.Msg) {
	response := new(dns.Msg)
	response.SetReply(request)
	response.Compress = false

	switch request.Opcode {
	case dns.OpcodeQuery:
		for _, q := range response.Question {
			readRR, e := getRecord(q.Name, q.Qtype)
			if e != nil {
				continue
			}
			if rr := readRR.(dns.RR); rr.Header().Name == q.Name {
				response.Answer = append(response.Answer, rr)
			}
		}
	case dns.OpcodeUpdate:
		authorized := request.IsTsig() != nil && w.TsigStatus() == nil
		if !authorized {
			log.Println("droping update without tsig or with bad sig")
			break
		}
		for _, question := range request.Question {
			for _, rr := range request.Ns {
				updateRecord(rr, &question)
			}
		}
	}

	if request.IsTsig() != nil {
		if status := w.TsigStatus(); status != nil {
			log.Println("Status: ", status.Error())
		} else {
			// The request ends with a TSIG record here; reuse its key name.
			keyName := request.Extra[len(request.Extra)-1].(*dns.TSIG).Hdr.Name
			response.SetTsig(keyName, dns.HmacMD5, 300, time.Now().Unix())
		}
	}
	w.WriteMsg(response)
}
Exemple #27
0
// Recurse passes req to the recursors in d.Recursors, in order, and relays
// the first reply obtained without error to w with compression enabled. TCP
// is used towards the recursors when isTCP reports true for the client's
// address. When no recursor answers, a SERVFAIL reply is written.
func (d dnsAPI) Recurse(w dns.ResponseWriter, req *dns.Msg) {
	client := dns.Client{}
	if isTCP(w.RemoteAddr()) {
		client.Net = "tcp"
	}

	for _, recursor := range d.Recursors {
		req.Compress = true
		if answer, _, err := client.Exchange(req, recursor); err == nil {
			answer.Compress = true
			w.WriteMsg(answer)
			return
		}
	}

	// Return SERVFAIL
	failure := new(dns.Msg)
	failure.RecursionAvailable = true
	failure.SetRcode(req, dns.RcodeServerFailure)
	w.WriteMsg(failure)
}
Exemple #28
0
// ServeDNS answers a DNS query, either from the resolver's own handlers or
// by forwarding it to one of the external DNS servers in r.extDNSList.
//
// A, AAAA, PTR and SRV questions are first offered to the matching local
// handler. When that produces no reply:
//   - with proxying disabled (r.proxyDNS is false) a SERVFAIL is written;
//   - an A/AAAA query for a name without dots, while r.backend.NdotsSet()
//     reports true, gets the reply built by createRespMsg and is not
//     forwarded;
//   - otherwise the query is sent to the external servers in order, and the
//     first reply that can be read (truncated replies included) is relayed.
//
// Nothing is written when the query is nil or has no question, when the
// local handler reports an error, or when no external server delivered a
// reply.
func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
	var (
		extConn net.Conn
		resp    *dns.Msg
		err     error
	)

	if query == nil || len(query.Question) == 0 {
		return
	}
	name := query.Question[0].Name

	switch query.Question[0].Qtype {
	case dns.TypeA:
		resp, err = r.handleIPQuery(name, query, types.IPv4)
	case dns.TypeAAAA:
		resp, err = r.handleIPQuery(name, query, types.IPv6)
	case dns.TypePTR:
		resp, err = r.handlePTRQuery(name, query)
	case dns.TypeSRV:
		resp, err = r.handleSRVQuery(name, query)
	}

	if err != nil {
		logrus.Error(err)
		return
	}

	if resp == nil {
		// If the backend doesn't support proxying dns request
		// fail the response
		if !r.proxyDNS {
			resp = new(dns.Msg)
			resp.SetRcode(query, dns.RcodeServerFailure)
			w.WriteMsg(resp)
			return
		}

		// If the user sets ndots > 0 explicitly and the query is
		// in the root domain don't forward it out. We will return
		// failure and let the client retry with the search domain
		// attached
		switch query.Question[0].Qtype {
		case dns.TypeA, dns.TypeAAAA:
			if r.backend.NdotsSet() && !strings.Contains(strings.TrimSuffix(name, "."), ".") {
				resp = createRespMsg(query)
			}
		}
	}

	// Work out the largest reply the client can take: nearly the full
	// message size over TCP, the EDNS0 advertised size (but at least
	// defaultRespSize) over UDP.
	proto := w.LocalAddr().Network()
	maxSize := 0
	if proto == "tcp" {
		maxSize = dns.MaxMsgSize - 1
	} else if proto == "udp" {
		optRR := query.IsEdns0()
		if optRR != nil {
			maxSize = int(optRR.UDPSize())
		}
		if maxSize < defaultRespSize {
			maxSize = defaultRespSize
		}
	}

	if resp != nil {
		if resp.Len() > maxSize {
			truncateResp(resp, maxSize, proto == "tcp")
		}
	} else {
		for i := 0; i < maxExtDNS; i++ {
			extDNS := &r.extDNSList[i]
			// An empty entry marks the end of the configured servers.
			if extDNS.ipStr == "" {
				break
			}
			// The connection is opened from a closure so that it can be run
			// through r.backend.ExecFunc; it reports through extConn and err.
			extConnect := func() {
				addr := fmt.Sprintf("%s:%d", extDNS.ipStr, 53)
				extConn, err = net.DialTimeout(proto, addr, extIOTimeout)
			}

			if extDNS.hostLoopback {
				extConnect()
			} else {
				execErr := r.backend.ExecFunc(extConnect)
				if execErr != nil {
					logrus.Warn(execErr)
					continue
				}
			}
			if err != nil {
				logrus.Warnf("Connect failed: %s", err)
				continue
			}
			logrus.Debugf("Query %s[%d] from %s, forwarding to %s:%s", name, query.Question[0].Qtype,
				extConn.LocalAddr().String(), proto, extDNS.ipStr)

			// Timeout has to be set for every IO operation.
			extConn.SetDeadline(time.Now().Add(extIOTimeout))
			co := &dns.Conn{
				Conn:    extConn,
				UDPSize: uint16(maxSize),
			}
			// Note: deferred inside the loop, so the connections stay open
			// until ServeDNS returns; there are at most maxExtDNS of them.
			defer co.Close()

			// limits the number of outstanding concurrent queries.
			if !r.forwardQueryStart() {
				old := r.tStamp
				r.tStamp = time.Now()
				if r.tStamp.Sub(old) > logInterval {
					logrus.Errorf("More than %v concurrent queries from %s", maxConcurrent, extConn.LocalAddr().String())
				}
				continue
			}

			err = co.WriteMsg(query)
			if err != nil {
				r.forwardQueryEnd()
				logrus.Debugf("Send to DNS server failed, %s", err)
				continue
			}

			resp, err = co.ReadMsg()
			// Truncated DNS replies should be sent to the client so that the
			// client can retry over TCP
			if err != nil && err != dns.ErrTruncated {
				r.forwardQueryEnd()
				logrus.Debugf("Read from DNS server failed, %s", err)
				continue
			}
			r.forwardQueryEnd()

			// Without a decoded message there is nothing to relay, and
			// nothing that could safely be dereferenced below: try the next
			// server.
			if resp == nil {
				continue
			}

			// Let the backend know about every address that was resolved.
			for _, rr := range resp.Answer {
				h := rr.Header()
				switch h.Rrtype {
				case dns.TypeA:
					ip := rr.(*dns.A).A
					r.backend.HandleQueryResp(h.Name, ip)
				case dns.TypeAAAA:
					ip := rr.(*dns.AAAA).AAAA
					r.backend.HandleQueryResp(h.Name, ip)
				}
			}
			resp.Compress = true
			break
		}
		if resp == nil {
			return
		}
	}

	if err = w.WriteMsg(resp); err != nil {
		logrus.Errorf("error writing resolver resp, %s", err)
	}
}
Exemple #29
0
// ServeDNSForward forwards a request to the configured nameservers and
// returns the message that was written to the client.
//
// A SERVFAIL reply is written and returned when recursion is disabled
// (s.config.NoRec), when no nameservers are configured, when the queried
// name has fewer labels than s.config.Ndots, or when every nameserver
// returned an error. Otherwise the first successful upstream reply is
// written (compressed, carrying the request's Id) and returned. Every call
// increments StatsForwardCount and the "recursive" external request counter.
func (s *server) ServeDNSForward(w dns.ResponseWriter, req *dns.Msg) *dns.Msg {
	StatsForwardCount.Inc(1)
	promExternalRequestCount.WithLabelValues("recursive").Inc()

	if s.config.NoRec {
		m := new(dns.Msg)
		m.SetReply(req)
		m.SetRcode(req, dns.RcodeServerFailure)
		m.Authoritative = false
		m.RecursionAvailable = false
		w.WriteMsg(m)
		return m
	}

	if len(s.config.Nameservers) == 0 || dns.CountLabel(req.Question[0].Name) < s.config.Ndots {
		if s.config.Verbose {
			if len(s.config.Nameservers) == 0 {
				logf("can not forward, no nameservers defined")
			} else {
				logf("can not forward, name too short (less than %d labels): `%s'", s.config.Ndots, req.Question[0].Name)
			}
		}
		m := new(dns.Msg)
		m.SetReply(req)
		m.SetRcode(req, dns.RcodeServerFailure)
		m.Authoritative = false     // no matter what set to false
		m.RecursionAvailable = true // and this is still true
		w.WriteMsg(m)
		return m
	}

	tcp := isTCP(w)

	var (
		r   *dns.Msg
		err error
	)

	// Use request Id for "random" nameserver selection.
	nsid := int(req.Id) % len(s.config.Nameservers)

	// Try every nameserver exactly once, starting at nsid and wrapping
	// around the list.
	for try := 0; try < len(s.config.Nameservers); try++ {
		if tcp {
			r, _, err = s.dnsTCPclient.Exchange(req, s.config.Nameservers[nsid])
		} else {
			r, _, err = s.dnsUDPclient.Exchange(req, s.config.Nameservers[nsid])
		}
		if err == nil {
			r.Compress = true
			r.Id = req.Id
			w.WriteMsg(r)
			return r
		}
		// Seen an error, presumably "server not reached": move on to the
		// next nameserver until the list is exhausted.
		nsid = (nsid + 1) % len(s.config.Nameservers)
	}

	logf("failure to forward request %q", err)
	m := new(dns.Msg)
	m.SetReply(req)
	m.SetRcode(req, dns.RcodeServerFailure)
	w.WriteMsg(m)
	return m
}
Exemple #30
0
// TestRecursiveCompress starts an upstream DNS server whose reply only fits
// into 512 bytes when compressed, points the server under test at it, and
// checks that a lookup through the server under test reaches the upstream
// and returns a reply whose length exceeds 512 bytes.
func TestRecursiveCompress(t *testing.T) {
	const (
		hostname = "foo.example."
		maxSize  = 512
	)

	// Build a reply that is larger than maxSize uncompressed and at most
	// maxSize once compression is enabled.
	rrHeader := dns.RR_Header{
		Name:   hostname,
		Rrtype: dns.TypeA,
		Class:  dns.ClassINET,
		Ttl:    10,
	}
	response := dns.Msg{Answer: []dns.RR{}}
	response.Authoritative = true
	for response.Len() <= maxSize {
		record := &dns.A{Hdr: rrHeader, A: address.Address(rand.Uint32()).IP4()}
		response.Answer = append(response.Answer, record)
	}
	response.Compress = true
	require.True(t, response.Len() <= maxSize)

	// Upstream DNS server that always answers with the reply built above
	// and signals every request it sees.
	gotRequest := make(chan struct{}, 1)
	handleRecursive := func(w dns.ResponseWriter, req *dns.Msg) {
		gotRequest <- struct{}{}
		require.Equal(t, req.Question[0].Name, hostname)
		response.SetReply(req)
		require.Nil(t, w.WriteMsg(&response))
	}
	mux := dns.NewServeMux()
	mux.HandleFunc(topDomain, handleRecursive)

	udpListener, err := net.ListenPacket("udp", "0.0.0.0:0")
	require.Nil(t, err)
	udpServerPort := udpListener.LocalAddr().(*net.UDPAddr).Port
	udpServer := &dns.Server{PacketConn: udpListener, Handler: mux}
	go udpServer.ActivateAndServe()
	defer udpServer.Shutdown()

	// The weavedns server, using the upstream server above as its resolver.
	upstreamConfig := &dns.ClientConfig{
		Servers:  []string{"127.0.0.1"},
		Port:     strconv.Itoa(udpServerPort),
		Ndots:    1,
		Timeout:  5,
		Attempts: 2,
	}
	dnsserver, _, udpPort, _ := startServer(t, upstreamConfig)
	defer dnsserver.Stop()

	// Perform the lookup and check the outcome.
	// NB this doesn't really test that golang's resolver behaves correctly,
	// as there seems to be no way to point golang's resolver at a specific
	// host.
	req := &dns.Msg{
		Question: []dns.Question{{
			Name:   hostname,
			Qtype:  dns.TypeA,
			Qclass: dns.ClassINET,
		}},
	}
	req.Id = dns.Id()
	req.RecursionDesired = true

	client := new(dns.Client)
	res, _, err := client.Exchange(req, fmt.Sprintf("127.0.0.1:%d", udpPort))
	require.Nil(t, err)
	require.True(t, len(gotRequest) > 0)
	require.True(t, res.Len() > maxSize)
}