Example #1
0
// search walks the tree looking for the node whose element matches qname.
//
// The walk is steered by Less(n.Elem, qname): zero is a match, a negative
// value continues in the left subtree and a positive value in the right
// subtree. Unless glue is true, the walk stops with Delegation at the first
// node holding NS records that it would otherwise step past on its way into a
// right subtree; with glue set it does *not* stop there but keeps descending
// in search of glue. For glue searches the qtype is expected to be A or AAAA;
// qtype itself is not consulted in this function.
//
// When nothing matches, the result is EmptyNonTerminal if qname has fewer
// labels than the last node visited, and NameError otherwise; the returned
// node is nil in both cases. Searching an empty tree (nil receiver) returns
// NameError rather than dereferencing a nil node.
func (n *Node) search(qname string, qtype uint16, glue bool) (*Node, Result) {
	if n == nil {
		// Empty tree: there is no last-visited node to compare labels with.
		return nil, NameError
	}

	old := n
	for n != nil {
		switch c := Less(n.Elem, qname); {
		case c == 0:
			return n, Found
		case c < 0:
			old = n
			n = n.Left
		default:
			if !glue && n.Elem.Types(dns.TypeNS) != nil {
				return n, Delegation
			}
			old = n
			n = n.Right
		}
	}

	// n is nil here; old is the last node visited before leaving the tree.
	if dns.CountLabel(qname) < dns.CountLabel(old.Elem.Name()) {
		return nil, EmptyNonTerminal
	}
	return nil, NameError
}
Example #2
0
// lookupNameservers returns the lower-cased authoritative nameservers for
// fqdn. When the NS query for a name yields no NS records, the left-most
// label is stripped and the parent domain is queried instead; the search
// gives up with an error once the parent would have fewer than two labels.
// Any error from the NS query itself is returned as-is.
func lookupNameservers(fqdn string) ([]string, error) {
	name := fqdn
	for {
		resp, err := dnsQuery(name, dns.TypeNS, recursiveNameserver, true)
		if err != nil {
			return nil, err
		}

		var nameservers []string
		for _, answer := range resp.Answer {
			record, isNS := answer.(*dns.NS)
			if !isNS {
				continue
			}
			nameservers = append(nameservers, strings.ToLower(record.Ns))
		}
		if len(nameservers) != 0 {
			return nameservers, nil
		}

		// No NS records at this level: drop the left-most label and retry
		// against the parent domain.
		cut, _ := dns.NextLabel(name, 0)
		parent := name[cut:]
		if dns.CountLabel(parent) < 2 {
			return nil, fmt.Errorf("Could not determine authoritative nameservers")
		}
		name = parent
	}
}
Example #3
0
// ServeDNSForward forwards req to one of the configured upstream nameservers,
// writes the outcome to w and returns the message that was written.
//
// A SERVFAIL built by s.ServerFailure is sent when recursion is disabled
// (NoRec), when no nameservers are configured, or when the question name has
// fewer labels than Ndots. Otherwise the request is exchanged over the same
// transport the client used (TCP or UDP), starting at nameserver 0, or at a
// nameserver derived from the request Id when NSRotate is set, and moving on
// to the next nameserver after every failed exchange. When all attempts fail
// the client is sent a SERVFAIL rather than being left without an answer.
func (s *server) ServeDNSForward(w dns.ResponseWriter, req *dns.Msg) *dns.Msg {

	if s.config.NoRec {
		m := s.ServerFailure(req)
		w.WriteMsg(m)
		return m
	}

	if len(s.config.Nameservers) == 0 || dns.CountLabel(req.Question[0].Name) < s.config.Ndots {
		if s.config.Verbose {
			if len(s.config.Nameservers) == 0 {
				logf("can not forward, no nameservers defined")
			} else {
				logf("can not forward, name too short (less than %d labels): `%s'", s.config.Ndots, req.Question[0].Name)
			}
		}
		m := s.ServerFailure(req)
		m.RecursionAvailable = true // this is still true
		w.WriteMsg(m)
		return m
	}

	tcp := isTCP(w)

	var (
		r   *dns.Msg
		err error
		try int
	)

	nsid := 0
	if s.config.NSRotate {
		// Use request Id for "random" nameserver selection.
		nsid = int(req.Id) % len(s.config.Nameservers)
	}
Redo:
	switch tcp {
	case false:
		r, _, err = s.dnsUDPclient.Exchange(req, s.config.Nameservers[nsid])
	case true:
		r, _, err = s.dnsTCPclient.Exchange(req, s.config.Nameservers[nsid])
	}
	if err == nil {
		r.Compress = true
		// Make sure the reply carries the Id of the client's request.
		r.Id = req.Id
		w.WriteMsg(r)
		return r
	}
	// Seen an error, this can only mean, "server not reached", try again
	// but only if we have not exhausted our nameservers.
	if try < len(s.config.Nameservers) {
		try++
		nsid = (nsid + 1) % len(s.config.Nameservers)
		goto Redo
	}

	logf("failure to forward request %q", err)
	m := s.ServerFailure(req)
	// Every other exit path answers the client; do so here as well, otherwise
	// the client only learns about the failure through its own timeout.
	w.WriteMsg(m)
	return m
}
Example #4
0
File: stub.go Project: CMGS/skydns
// UpdateStubZones rebuilds the stub zone forwarder table.
//
// Look in .../dns/stub/<domain>/xx for msg.Services. Loop through them,
// extract <domain> and add them as forwarders (ip:port-combos) for
// the stub zones. Only numeric (i.e. IP address) hosts are used, and a port
// of 0 is replaced by 53. Keys that are too short to contain a nameserver
// placeholder label followed by the local domain are skipped instead of
// causing an out-of-range slice panic. If the backend lookup fails, the
// failure is logged and the existing table is left untouched.
func (s *server) UpdateStubZones() {
	stubmap := make(map[string][]string)

	services, err := s.backend.Records("stub.dns."+s.config.Domain, false, net.IP{})
	if err != nil {
		logf("stub zone update failed: %s", err)
		return
	}

	// Number of right-most labels that belong to localDomain; the same for
	// every service, so compute it once.
	localLabels := dns.CountLabel(s.config.localDomain)

	for _, serv := range services {
		if serv.Port == 0 {
			serv.Port = 53
		}
		ip := net.ParseIP(serv.Host)
		if ip == nil {
			logf("stub zone non-address %s seen for: %s", serv.Key, serv.Host)
			continue
		}

		domain := msg.Domain(serv.Key)
		// Chop of left most label, because that is used as the nameserver place holder
		// and drop the right most labels that belong to localDomain.
		labels := dns.SplitDomainName(domain)
		end := len(labels) - localLabels
		if end < 1 {
			// labels[1:end] would be out of range for such a short key.
			logf("stub zone key %s has too few labels, ignoring", serv.Key)
			continue
		}
		domain = dns.Fqdn(strings.Join(labels[1:end], "."))

		// If the remaining name equals s.config.localDomain we ignore it.
		if domain == s.config.localDomain {
			logf("not adding stub zone for my own domain")
			continue
		}
		stubmap[domain] = append(stubmap[domain], net.JoinHostPort(serv.Host, strconv.Itoa(serv.Port)))
	}

	s.config.stub = &stubmap
}
Example #5
0
// saveDNSRR converts the given DNS records and stores every usable one in the
// resolver cache under its own name. It returns the records that were stored.
//
// Records that convertRR rejects are dropped. So are records whose name has
// fewer labels than qname while dns.CompareDomainName(qname, name) is below
// two; those are treated as potential cache poisoning. host is only
// referenced by a disabled diagnostic and is otherwise unused.
func (r *Resolver) saveDNSRR(host string, qname string, drrs []dns.RR) RRs {
	var saved RRs
	qlabels := dns.CountLabel(qname)
	for i := range drrs {
		rec, converted := convertRR(drrs[i])
		if !converted {
			continue
		}
		shorter := dns.CountLabel(rec.Name) < qlabels
		if shorter && dns.CompareDomainName(qname, rec.Name) < 2 {
			// Potential poisoning: a shorter, (nearly) unrelated name.
			continue
		}
		r.cache.add(rec.Name, rec)
		saved = append(saved, rec)
	}
	return saved
}
Example #6
0
// setDefaults fills in unset configuration values, normalizes the domain and,
// when a DNSSEC key file is configured, loads the key pair and the derived
// NSEC3 records into config.
//
// Defaults applied to zero values: minimum TTL 60, TTL 3600, priority 10,
// read timeout 2s, DNS address 127.0.0.1:53, domain "skydns.local." and
// hostmaster "hostmaster.<domain>". With no nameservers configured, the
// servers listed in /etc/resolv.conf are used. An error is returned when
// resolv.conf or the key file cannot be loaded, or when the DNSKEY owner name
// does not equal the domain.
func setDefaults(config *Config) error {
	// Numeric defaults.
	if config.MinTtl == 0 {
		config.MinTtl = 60
	}
	if config.Ttl == 0 {
		config.Ttl = 3600
	}
	if config.Priority == 0 {
		config.Priority = 10
	}
	if config.ReadTimeout == 0 {
		config.ReadTimeout = 2 * time.Second
	}

	// Address and naming defaults. The hostmaster default is derived from the
	// domain, so the domain default has to be in place first.
	if config.DnsAddr == "" {
		config.DnsAddr = "127.0.0.1:53"
	}
	if config.Domain == "" {
		config.Domain = "skydns.local."
	}
	if config.Hostmaster == "" {
		config.Hostmaster = "hostmaster." + config.Domain
	}
	// People probably don't know that SOA's email addresses cannot
	// contain @-signs, replace them with dots.
	config.Hostmaster = dns.Fqdn(strings.Replace(config.Hostmaster, "@", ".", -1))

	// Fall back to the system resolvers when none were configured.
	if len(config.Nameservers) == 0 {
		resolv, err := dns.ClientConfigFromFile("/etc/resolv.conf")
		if err != nil {
			return err
		}
		for _, host := range resolv.Servers {
			config.Nameservers = append(config.Nameservers, net.JoinHostPort(host, resolv.Port))
		}
	}

	config.Domain = dns.Fqdn(strings.ToLower(config.Domain))
	config.DomainLabels = dns.CountLabel(config.Domain)

	if config.DNSSEC == "" {
		return nil
	}

	// For some reason the + are replaced by spaces in etcd. Re-replace them.
	pub, priv, err := ParseKeyFile(strings.Replace(config.DNSSEC, " ", "+", -1))
	if err != nil {
		return err
	}
	if pub.Header().Name != dns.Fqdn(config.Domain) {
		return fmt.Errorf("ownername of DNSKEY must match SkyDNS domain")
	}
	pub.Header().Ttl = config.Ttl
	config.PubKey = pub
	config.KeyTag = pub.KeyTag()
	config.PrivKey = priv
	config.ClosestEncloser, config.DenyWildcard = newNSEC3CEandWildcard(config.Domain, config.Domain, config.MinTtl)
	return nil
}
Example #7
0
// ServeDNSForward forwards req to one of the configured upstream nameservers,
// writes the outcome to w and returns the message that was written.
//
// A SERVFAIL built by s.ServerFailure is sent when recursion is disabled
// (NoRec), when no nameservers are configured, or when the question name has
// fewer labels than Ndots. Otherwise the request is exchanged through
// exchangeWithRetry using the TCP or UDP client that matches the client's
// transport, starting at the nameserver chosen by s.randomNameserverID and
// moving on to the next nameserver after every failed exchange. When all
// attempts fail the client is sent a SERVFAIL rather than being left without
// an answer.
func (s *server) ServeDNSForward(w dns.ResponseWriter, req *dns.Msg) *dns.Msg {
	if s.config.NoRec {
		m := s.ServerFailure(req)
		w.WriteMsg(m)
		return m
	}

	if len(s.config.Nameservers) == 0 || dns.CountLabel(req.Question[0].Name) < s.config.Ndots {
		if s.config.Verbose {
			if len(s.config.Nameservers) == 0 {
				logf("can not forward, no nameservers defined")
			} else {
				logf("can not forward, name too short (less than %d labels): `%s'", s.config.Ndots, req.Question[0].Name)
			}
		}
		m := s.ServerFailure(req)
		m.RecursionAvailable = true // this is still true
		w.WriteMsg(m)
		return m
	}

	var (
		r   *dns.Msg
		err error
	)

	nsid := s.randomNameserverID(req.Id)
	try := 0
Redo:
	if isTCP(w) {
		r, err = exchangeWithRetry(s.dnsTCPclient, req, s.config.Nameservers[nsid])
	} else {
		r, err = exchangeWithRetry(s.dnsUDPclient, req, s.config.Nameservers[nsid])
	}
	if err == nil {
		r.Compress = true
		// Make sure the reply carries the Id of the client's request.
		r.Id = req.Id
		w.WriteMsg(r)
		return r
	}
	// Seen an error, this can only mean, "server not reached", try again
	// but only if we have not exhausted our nameservers.
	if try < len(s.config.Nameservers) {
		try++
		nsid = (nsid + 1) % len(s.config.Nameservers)
		goto Redo
	}

	logf("failure to forward request %q", err)
	m := s.ServerFailure(req)
	// Every other exit path answers the client; do so here as well, otherwise
	// the client only learns about the failure through its own timeout.
	w.WriteMsg(m)
	return m
}
Example #8
0
// NewServer returns a new Server built from the given cluster members,
// domain, listen addresses, data directory, timeouts, shared secret, upstream
// nameservers, round-robin flag and TLS key/certificate paths. The returned
// server has itself registered as DNS handler for "." and all HTTP API and
// raft routes installed on its router. The process is terminated through
// log.Fatal when dataDir does not exist.
func NewServer(members []string, domain string, dnsAddr string, httpAddr string, dataDir string, rt, wt time.Duration, secret string, nameservers []string, roundrobin bool, tlskey string, tlspem string) (s *Server) {
	s = new(Server)
	s.members = members
	s.domain = strings.ToLower(domain)
	s.domainLabels = dns.CountLabel(dns.Fqdn(domain))
	s.dnsAddr = dnsAddr
	s.httpAddr = httpAddr
	s.readTimeout = rt
	s.writeTimeout = wt
	s.router = mux.NewRouter()
	s.registry = registry.New()
	s.dataDir = dataDir
	s.dnsHandler = dns.NewServeMux()
	s.waiter = new(sync.WaitGroup)
	s.secret = secret
	s.nameservers = nameservers
	s.roundrobin = roundrobin
	s.tlskey = tlskey
	s.tlspem = tlspem

	// log.Fatal exits the process, so nothing below runs without a data directory.
	if _, statErr := os.Stat(s.dataDir); os.IsNotExist(statErr) {
		log.Fatal("Data directory does not exist: ", dataDir)
	}

	// DNS: the server answers every name itself.
	s.dnsHandler.Handle(".", s)

	// API routes for a single service, all wrapped by the auth handler.
	const servicePath = "/skydns/services/{uuid}"
	s.router.HandleFunc(servicePath, s.authHTTPWrapper(s.addServiceHTTPHandler)).Methods("PUT")
	s.router.HandleFunc(servicePath, s.authHTTPWrapper(s.getServiceHTTPHandler)).Methods("GET")
	s.router.HandleFunc(servicePath, s.authHTTPWrapper(s.removeServiceHTTPHandler)).Methods("DELETE")
	s.router.HandleFunc(servicePath, s.authHTTPWrapper(s.updateServiceHTTPHandler)).Methods("PATCH")

	s.router.HandleFunc("/skydns/callbacks/{uuid}", s.authHTTPWrapper(s.addCallbackHTTPHandler)).Methods("PUT")

	// External API routes: list all services, regions and environments.
	s.router.HandleFunc("/skydns/services/", s.authHTTPWrapper(s.getServicesHTTPHandler)).Methods("GET")
	s.router.HandleFunc("/skydns/regions/", s.authHTTPWrapper(s.getRegionsHTTPHandler)).Methods("GET")
	s.router.HandleFunc("/skydns/environments/", s.authHTTPWrapper(s.getEnvironmentsHTTPHandler)).Methods("GET")

	// Raft routes (not wrapped by the auth handler).
	s.router.HandleFunc("/raft/join", s.joinHandler).Methods("POST")
	return s
}
Example #9
0
File: zone.go Project: abh/geodns
// NewZone returns a Zone for the origin name with an empty label set, the
// origin's label count and the default options: a TTL of 120, at most 2
// hosts, the contact "hostmaster.<name>" and global, country and continent
// targeting.
func NewZone(name string) *Zone {
	z := &Zone{
		Labels:     make(labels),
		Origin:     name,
		LabelCount: dns.CountLabel(name),
	}

	// Default options.
	z.Options.Ttl = 120
	z.Options.MaxHosts = 2
	z.Options.Contact = "hostmaster." + name
	z.Options.Targeting = TargetGlobal + TargetCountry + TargetContinent

	return z
}
Example #10
0
// ServeDNSForward forwards req to one of the configured upstream nameservers
// and writes the response to w.
//
// When no nameservers are configured, or the question name has fewer labels
// than Ndots, a SERVFAIL with RecursionAvailable set is written instead.
// Otherwise the request is exchanged over the client's transport (TCP or
// UDP), starting at the nameserver selected by the request Id and moving on
// to the next one after every failed exchange. The first attempt plus up to
// len(Nameservers) retries are made before a SERVFAIL is written.
func (s *server) ServeDNSForward(w dns.ResponseWriter, req *dns.Msg) {
	StatsForwardCount.Inc(1)
	if len(s.config.Nameservers) == 0 || dns.CountLabel(req.Question[0].Name) < s.config.Ndots {
		s.config.log.Infof("no nameservers defined or name too short, can not forward")
		refusal := new(dns.Msg)
		refusal.SetReply(req)
		refusal.SetRcode(req, dns.RcodeServerFailure)
		refusal.Authoritative = false     // no matter what set to false
		refusal.RecursionAvailable = true // and this is still true
		w.WriteMsg(refusal)
		return
	}

	// Pick the upstream client matching the transport the request came in on.
	_, viaTCP := w.RemoteAddr().(*net.TCPAddr)

	// Use request Id for "random" nameserver selection.
	nsid := int(req.Id) % len(s.config.Nameservers)

	var err error
	for try := 0; ; try++ {
		var resp *dns.Msg
		if viaTCP {
			resp, _, err = s.dnsTCPclient.Exchange(req, s.config.Nameservers[nsid])
		} else {
			resp, _, err = s.dnsUDPclient.Exchange(req, s.config.Nameservers[nsid])
		}
		if err == nil {
			resp.Compress = true
			w.WriteMsg(resp)
			return
		}
		// Seen an error, this can only mean, "server not reached", try again
		// but only if we have not exhausted our nameservers.
		if try >= len(s.config.Nameservers) {
			break
		}
		nsid = (nsid + 1) % len(s.config.Nameservers)
	}

	s.config.log.Errorf("failure to forward request %q", err)
	failure := new(dns.Msg)
	failure.SetReply(req)
	failure.SetRcode(req, dns.RcodeServerFailure)
	w.WriteMsg(failure)
}
Example #11
0
// NewServer returns a new server for domain that listens on dnsAddr, uses
// nameservers as upstreams and talks to the etcd instance at etcdAddr. The
// domain is stored lower-cased and fully qualified, and the server is
// registered as the DNS handler for ".".
// TODO(miek): multiple etcdAddrs
func NewServer(domain, dnsAddr string, nameservers []string, etcdAddr string) *server {
	srv := new(server)
	srv.domain = dns.Fqdn(strings.ToLower(domain))
	srv.domainLabels = dns.CountLabel(dns.Fqdn(domain))
	srv.DnsAddr = dnsAddr
	srv.client = etcd.NewClient([]string{etcdAddr})
	srv.dnsHandler = dns.NewServeMux()
	srv.waiter = new(sync.WaitGroup)
	srv.nameservers = nameservers

	// DNS: the server answers every name itself.
	srv.dnsHandler.Handle(".", srv)
	return srv
}
Example #12
0
// setDefaults fills in unset configuration values, normalizes the domain and,
// when a DNSSEC key file is configured, loads the key pair into config.
//
// Defaults applied to zero values: minimum TTL 60, TTL 3600, priority 10,
// read timeout 2s, DNS address 127.0.0.1:53 and domain "skydns.local". With
// no nameservers configured, the servers listed in /etc/resolv.conf are used.
// An error is returned when resolv.conf or the key file cannot be loaded, or
// when the DNSKEY owner name does not equal the normalized domain.
func setDefaults(config *Config) error {
	// Numeric defaults.
	if config.MinTtl == 0 {
		config.MinTtl = 60
	}
	if config.Ttl == 0 {
		config.Ttl = 3600
	}
	if config.Priority == 0 {
		config.Priority = 10
	}
	if config.ReadTimeout == 0 {
		config.ReadTimeout = 2 * time.Second
	}

	// Address and naming defaults.
	if config.DnsAddr == "" {
		config.DnsAddr = "127.0.0.1:53"
	}
	if config.Domain == "" {
		config.Domain = "skydns.local"
	}

	// Fall back to the system resolvers when none were configured.
	if len(config.Nameservers) == 0 {
		resolv, err := dns.ClientConfigFromFile("/etc/resolv.conf")
		if err != nil {
			return err
		}
		for _, host := range resolv.Servers {
			config.Nameservers = append(config.Nameservers, net.JoinHostPort(host, resolv.Port))
		}
	}

	config.Domain = dns.Fqdn(strings.ToLower(config.Domain))
	config.DomainLabels = dns.CountLabel(config.Domain)

	if config.DNSSEC == "" {
		return nil
	}

	// For some reason the + are replaced by spaces in etcd. Re-replace them.
	pub, priv, err := ParseKeyFile(strings.Replace(config.DNSSEC, " ", "+", -1))
	if err != nil {
		return err
	}
	if pub.Header().Name != dns.Fqdn(config.Domain) {
		return fmt.Errorf("ownername of DNSKEY must match SkyDNS domain")
	}
	pub.Header().Ttl = config.Ttl
	config.PubKey = pub
	config.KeyTag = pub.KeyTag()
	config.PrivKey = priv
	return nil
}
Example #13
0
// getZoneForName returns the first configured zone that name is a subdomain
// of, together with the labels of name that lie within that zone (that is,
// name's labels with the zone's own labels removed from the right).
// For example, if "coredns.local" is a zone configured for the
// Kubernetes middleware, then getZoneForName("a.b.coredns.local")
// will return ("coredns.local", ["a", "b"]).
// When no zone matches, an empty zone and a nil slice are returned.
func (g Kubernetes) getZoneForName(name string) (string, []string) {
	for _, candidate := range g.Zones {
		if !dns.IsSubDomain(candidate, name) {
			continue
		}
		segments := dns.SplitDomainName(name)
		return candidate, segments[:len(segments)-dns.CountLabel(candidate)]
	}
	return "", nil
}
Example #14
0
// updateStubZones rebuilds the stub zone proxy table.
//
// Look in .../dns/stub/<zone>/xx for msg.Services. Loop through them,
// extract <zone> and add them as forwarders (ip:port-combos) for
// the stub zones. Only numeric (i.e. IP address) hosts are used, and a port
// of 0 is replaced by 53.
// Only the first zone configured on e is used for the lookup; nothing is done
// when no zone is configured or when the lookup fails. A key that is too
// short to hold the nameserver placeholder, the stub domain labels and a
// given zone is skipped for that zone instead of causing an out-of-range
// slice panic. The existing table is only replaced when at least one stub
// zone was found.
func (e *Etcd) updateStubZones() {
	if len(e.Zones) == 0 {
		// Without a zone there is nothing to look up (and e.Zones[0] would panic).
		return
	}
	zone := e.Zones[0]
	services, err := e.Records(stubDomain+"."+zone, false)
	if err != nil {
		return
	}

	stubmap := make(map[string]proxy.Proxy)
	// track the nameservers on a per domain basis, but allow a list on the domain.
	nameservers := map[string][]string{}

	for _, serv := range services {
		if serv.Port == 0 {
			serv.Port = 53
		}
		ip := net.ParseIP(serv.Host)
		if ip == nil {
			log.Printf("[WARNING] Non IP address stub nameserver: %s", serv.Host)
			continue
		}

		domain := msg.Domain(serv.Key)
		labels := dns.SplitDomainName(domain)

		// If the remaining name equals any of the zones we have, we ignore it.
		for _, z := range e.Zones {
			// Chop of left most label, because that is used as the nameserver place holder
			// and drop the right most labels that belong to zone.
			// We must *also* chop of dns.stub. which means cutting two more labels.
			end := len(labels) - dns.CountLabel(z) - 2
			if end < 1 {
				// The key has too few labels for this zone; labels[1:end]
				// would be out of range.
				continue
			}
			domain = dns.Fqdn(strings.Join(labels[1:end], "."))
			if domain == z {
				log.Printf("[WARNING] Skipping nameserver for domain we are authoritative for: %s", domain)
				continue
			}
			nameservers[domain] = append(nameservers[domain], net.JoinHostPort(serv.Host, strconv.Itoa(serv.Port)))
		}
	}
	for domain, nss := range nameservers {
		stubmap[domain] = proxy.New(nss)
	}
	// atomic swap (at least that's what we hope it is)
	if len(stubmap) > 0 {
		e.Stubmap = &stubmap
	}
}
Example #15
0
// Lookup resolves name n with type t through the recursive nameservers from
// the server's config, using the shared UDP client and an EDNS0 option with
// the given buffer size and DNSSEC flag.
//
// An error is returned when no nameservers are configured, when n has fewer
// labels than Ndots, when the answering nameserver replies with a non-success
// rcode, or when every exchange attempt fails. The starting nameserver is
// derived from the query Id; after a failed exchange the next nameserver is
// tried, for the first attempt plus up to len(Nameservers) retries. On
// success the TTLs of all answer and additional records are overwritten with
// RCacheTtl.
func (s *server) Lookup(n string, t, bufsize uint16, dnssec bool) (*dns.Msg, error) {
	StatsLookupCount.Inc(1)
	promExternalRequestCount.WithLabelValues("lookup").Inc()

	if len(s.config.Nameservers) == 0 {
		return nil, fmt.Errorf("no nameservers configured can not lookup name")
	}
	if dns.CountLabel(n) < s.config.Ndots {
		return nil, fmt.Errorf("name has fewer than %d labels", s.config.Ndots)
	}

	query := new(dns.Msg)
	query.SetQuestion(n, t)
	query.SetEdns0(bufsize, dnssec)

	nsid := int(query.Id) % len(s.config.Nameservers)
	for try := 0; ; try++ {
		resp, _, err := s.dnsUDPclient.Exchange(query, s.config.Nameservers[nsid])
		if err != nil {
			// Seen an error, this can only mean, "server not reached", try
			// again but only if we have not exhausted our nameservers.
			if try >= len(s.config.Nameservers) {
				break
			}
			nsid = (nsid + 1) % len(s.config.Nameservers)
			continue
		}

		if resp.Rcode != dns.RcodeSuccess {
			return nil, fmt.Errorf("rcode is not equal to success")
		}
		// Reset TTLs to rcache TTL to make some of the other code
		// and the tests not care about TTLs.
		for _, answer := range resp.Answer {
			answer.Header().Ttl = uint32(s.config.RCacheTtl)
		}
		for _, extra := range resp.Extra {
			extra.Header().Ttl = uint32(s.config.RCacheTtl)
		}
		return resp, nil
	}
	return nil, fmt.Errorf("failure to lookup name")
}
Example #16
0
// Lookup resolves name n with type t through the recursive nameservers from
// the server's config, using a freshly allocated UDP client whose read and
// write timeouts are twice the configured ReadTimeout, and an EDNS0 option
// with the given buffer size and DNSSEC flag.
//
// An error is returned when no nameservers are configured, when n has fewer
// labels than Ndots, when the answering nameserver replies with a non-success
// rcode, or when every exchange attempt fails. The starting nameserver is
// derived from the query Id; after a failed exchange the next nameserver is
// tried, for the first attempt plus up to len(Nameservers) retries. On
// success the TTLs of all answer and additional records are overwritten with
// RCacheTtl.
func (s *server) Lookup(n string, t, bufsize uint16, dnssec bool) (*dns.Msg, error) {
	StatsLookupCount.Inc(1)
	if len(s.config.Nameservers) == 0 {
		return nil, fmt.Errorf("no nameservers configured can not lookup name")
	}
	if dns.CountLabel(n) < s.config.Ndots {
		return nil, fmt.Errorf("name has fewer than %d labels", s.config.Ndots)
	}

	query := new(dns.Msg)
	query.SetQuestion(n, t)
	query.SetEdns0(bufsize, dnssec)

	// Move this to use s.udpClient/s.tcpClient code instead of allocating a new client for every query.
	client := new(dns.Client)
	client.Net = "udp"
	client.ReadTimeout = 2 * s.config.ReadTimeout
	client.WriteTimeout = 2 * s.config.ReadTimeout

	nsid := int(query.Id) % len(s.config.Nameservers)
	for try := 0; ; try++ {
		resp, _, err := client.Exchange(query, s.config.Nameservers[nsid])
		if err != nil {
			// Seen an error, this can only mean, "server not reached", try
			// again but only if we have not exhausted our nameservers.
			if try >= len(s.config.Nameservers) {
				break
			}
			nsid = (nsid + 1) % len(s.config.Nameservers)
			continue
		}

		if resp.Rcode != dns.RcodeSuccess {
			return nil, fmt.Errorf("rcode is not equal to success")
		}
		// Reset TTLs to rcache TTL to make some of the other code
		// and the tests not care about TTLs.
		for _, answer := range resp.Answer {
			answer.Header().Ttl = uint32(s.config.RCacheTtl)
		}
		for _, extra := range resp.Extra {
			extra.Header().Ttl = uint32(s.config.RCacheTtl)
		}
		return resp, nil
	}
	return nil, fmt.Errorf("failure to lookup name")
}
Example #17
0
// setDefaults fills in unset configuration values, normalizes the domain and,
// when a DNSSEC key file is configured, loads the key pair into config.
//
// Defaults applied to zero values: read and write timeouts of 2s, DNS address
// 127.0.0.1:53 and domain "skydns.local". With no nameservers configured, the
// servers listed in /etc/resolv.conf are used. The domain is lower-cased and
// made fully qualified before the DNSKEY owner name is checked against it,
// and that check ignores case, so a domain configured with upper-case letters
// is not rejected merely because of its spelling.
//
// An error is returned when resolv.conf or the key file cannot be loaded, or
// when the DNSKEY owner name does not match the domain.
func setDefaults(config *Config) error {
	if config.ReadTimeout == 0 {
		config.ReadTimeout = 2 * time.Second
	}
	if config.WriteTimeout == 0 {
		config.WriteTimeout = 2 * time.Second
	}
	if config.DnsAddr == "" {
		config.DnsAddr = "127.0.0.1:53"
	}
	if config.Domain == "" {
		config.Domain = "skydns.local"
	}

	if len(config.Nameservers) == 0 {
		c, err := dns.ClientConfigFromFile("/etc/resolv.conf")
		if err != nil {
			return err
		}
		for _, s := range c.Servers {
			config.Nameservers = append(config.Nameservers, net.JoinHostPort(s, c.Port))
		}
	}

	// Normalize first: the DNSSEC owner name check below must see the same
	// domain that the rest of the server will use.
	config.Domain = dns.Fqdn(strings.ToLower(config.Domain))
	config.DomainLabels = dns.CountLabel(config.Domain)

	if config.DNSSEC != "" {
		k, p, err := ParseKeyFile(config.DNSSEC)
		if err != nil {
			return err
		}
		// Domain names are case-insensitive, so compare them that way.
		if !strings.EqualFold(k.Header().Name, config.Domain) {
			return fmt.Errorf("ownername of DNSKEY must match SkyDNS domain")
		}
		config.PubKey = k
		config.KeyTag = k.KeyTag()
		config.PrivKey = p
	}
	return nil
}
Example #18
0
// getKey derives the cache key for req: the lower-cased question name, the
// question type and whether an OPT record in the additional section has its
// DO flag set (the last OPT record wins if there are several).
//
// An error is returned unless req carries exactly one question, of class
// INET, whose name has at least *minLabelsCount labels.
func getKey(req *dns.Msg) (*CacheKey, error) {
	if len(req.Question) != 1 {
		return nil, errors.New("Invalid number of questions")
	}
	q := req.Question[0]
	switch {
	case q.Qclass != dns.ClassINET:
		return nil, errors.New("Unsupported question class")
	case dns.CountLabel(q.Name) < *minLabelsCount:
		return nil, errors.New("Not enough labels")
	}

	wantDNSSEC := false
	for _, rr := range req.Extra {
		if rr.Header().Rrtype != dns.TypeOPT {
			continue
		}
		wantDNSSEC = rr.(*dns.OPT).Do()
	}

	return &CacheKey{
		Name:   strings.ToLower(q.Name),
		Qtype:  q.Qtype,
		DNSSEC: wantDNSSEC,
	}, nil
}
Example #19
0
// iterateParents tries to answer [qname, qtype] by walking from qname up
// through the names produced by parent(), resolving the NS records of each
// name and querying up to MaxNameservers of those nameservers in parallel.
//
// The first successful nameserver response is passed through resolveCNAMEs
// and returned. NXDOMAIN, whether from the NS resolution or from any
// nameserver, ends the search immediately with that error. Other errors from
// the NS resolution make the walk move on to the next parent. For NS queries
// the walk stops after the first parent whose nameservers were queried. A
// non-TLD name that reaches the root yields (nil, nil). If the walk runs out
// of parents, ErrNoResponse is returned.
//
// NOTE(review): parent() is defined elsewhere; it presumably returns the
// parent domain and false once there is none. Confirm against its definition.
func (r *Resolver) iterateParents(qname string, qtype string, depth int) (RRs, error) {
	// Both channels have capacity MaxNameservers, which is also the cap on
	// goroutines started per parent, and each parent's results are drained
	// before moving on, so a goroutine's send never blocks.
	chanRRs := make(chan RRs, MaxNameservers)
	chanErrs := make(chan error, MaxNameservers)
	for pname, ok := qname, true; ok; pname, ok = parent(pname) {
		if pname == qname && qtype == "NS" { // If we’re looking for [foo.com,NS], then skip to [com,NS]
			continue
		}

		// Only query TLDs against the root nameservers
		if pname == "." && dns.CountLabel(qname) != 1 {
			// fmt.Fprintf(os.Stderr, "Warning: non-TLD query at root: dig +norecurse %s %s\n", qname, qtype)
			return nil, nil
		}

		// Get nameservers
		nrrs, err := r.resolve(pname, "NS", depth)
		if err == NXDOMAIN {
			return nil, err
		}
		if err != nil {
			// Any other failure: try the next parent instead.
			continue
		}

		// Early out for specific queries
		// (the NS resolution above may already have put the answer in the cache).
		if len(nrrs) > 0 && qtype != "" {
			// rrs and err here are scoped to this block and shadow the outer err.
			rrs, err := r.cacheGet(qname, qtype)
			if err != nil {
				return nil, err
			}
			if len(rrs) > 0 {
				return rrs, nil
			}
		}

		// Query all nameservers in parallel
		// count tracks how many goroutines were started, capped at MaxNameservers.
		count := 0
		for _, nrr := range nrrs {
			if nrr.Type != "NS" {
				continue
			}

			// Each goroutine reports exactly once, on one of the two channels.
			go func(host string) {
				rrs, err := r.exchange(host, qname, qtype, depth)
				if err != nil {
					chanErrs <- err
				} else {
					chanRRs <- rrs
				}
			}(nrr.Value)

			count++
			if count >= MaxNameservers {
				break
			}
		}

		// Wait for first response
		// Errors other than NXDOMAIN are remembered in err and the wait
		// continues until every started goroutine has reported.
		for ; count > 0; count-- {
			select {
			case rrs := <-chanRRs:
				// Add the NS records that belong to qname itself to the answer.
				for _, nrr := range nrrs {
					if nrr.Name == qname {
						rrs = append(rrs, nrr)
					}
				}
				return r.resolveCNAMEs(qname, qtype, rrs, depth)
			case err = <-chanErrs:
				if err == NXDOMAIN {
					return nil, err
				}
			}
		}

		// NS queries naturally recurse, so stop further iteration
		// (err is the last nameserver error, or nil if none was queried).
		if qtype == "NS" {
			return nil, err
		}
	}

	return nil, ErrNoResponse
}
Example #20
0
// ServeDNSForward forwards req to one of the configured upstream nameservers,
// writes the outcome to w and returns the message that was written.
//
// A SERVFAIL is written when recursion is disabled (NoRec, with
// RecursionAvailable cleared), or when no nameservers are configured or the
// question name has fewer labels than Ndots (with RecursionAvailable set).
// Otherwise the request is exchanged over the client's transport (TCP or
// UDP), starting at the nameserver selected by the request Id and moving on
// to the next one after every failed exchange. The first attempt plus up to
// len(Nameservers) retries are made before a SERVFAIL is written.
func (s *server) ServeDNSForward(w dns.ResponseWriter, req *dns.Msg) *dns.Msg {
	StatsForwardCount.Inc(1)
	promExternalRequestCount.WithLabelValues("recursive").Inc()

	// refuse writes and returns a non-authoritative SERVFAIL with the given
	// RecursionAvailable flag.
	refuse := func(recursionAvailable bool) *dns.Msg {
		m := new(dns.Msg)
		m.SetReply(req)
		m.SetRcode(req, dns.RcodeServerFailure)
		m.Authoritative = false // no matter what set to false
		m.RecursionAvailable = recursionAvailable
		w.WriteMsg(m)
		return m
	}

	if s.config.NoRec {
		return refuse(false)
	}

	noUpstreams := len(s.config.Nameservers) == 0
	if noUpstreams || dns.CountLabel(req.Question[0].Name) < s.config.Ndots {
		switch {
		case !s.config.Verbose:
			// Stay quiet unless verbose logging is on.
		case noUpstreams:
			logf("can not forward, no nameservers defined")
		default:
			logf("can not forward, name too short (less than %d labels): `%s'", s.config.Ndots, req.Question[0].Name)
		}
		return refuse(true) // recursion is still available
	}

	viaTCP := isTCP(w)

	// Use request Id for "random" nameserver selection.
	nsid := int(req.Id) % len(s.config.Nameservers)

	var err error
	for try := 0; ; try++ {
		var resp *dns.Msg
		if viaTCP {
			resp, _, err = s.dnsTCPclient.Exchange(req, s.config.Nameservers[nsid])
		} else {
			resp, _, err = s.dnsUDPclient.Exchange(req, s.config.Nameservers[nsid])
		}
		if err == nil {
			resp.Compress = true
			resp.Id = req.Id
			w.WriteMsg(resp)
			return resp
		}
		// Seen an error, this can only mean, "server not reached", try again
		// but only if we have not exhausted our nameservers.
		if try >= len(s.config.Nameservers) {
			break
		}
		nsid = (nsid + 1) % len(s.config.Nameservers)
	}

	logf("failure to forward request %q", err)
	failure := new(dns.Msg)
	failure.SetReply(req)
	failure.SetRcode(req, dns.RcodeServerFailure)
	w.WriteMsg(failure)
	return failure
}
Example #21
0
// ServeDNSForward resolves a query by forwarding to a recursive nameserver,
// optionally retrying the name via the configured search domains. It writes
// exactly one reply to w and returns that reply.
//
// The query is refused (rcode REFUSED) when recursion is disabled, when no
// nameservers are configured, or when the name has fewer dots than FwdNdots
// and search is not enabled. Otherwise the order of attempts is:
//
//  1. an absolute (literal name) lookup first, if the name has at least Ndots
//     and at least FwdNdots dots;
//  2. a search lookup, if search is enabled with at least one search domain
//     and the absolute lookup did not fail with an error;
//  3. an absolute lookup, if none was done yet, the search lookup did not
//     fail with an error, and the name has at least FwdNdots dots.
//
// The first lookup that returns rcode SUCCESS is relayed to the client. If
// none does, the response of the absolute lookup is relayed when there is
// one; failing that, an empty reply carrying the rcode of the search lookup;
// and failing that, a SERVFAIL.
//
// NOTE(review): s.forwardQuery and s.forwardSearch are defined elsewhere;
// they are assumed to return a non-nil message whenever the error is nil.
func (s *server) ServeDNSForward(w dns.ResponseWriter, req *dns.Msg) *dns.Msg {
	name := req.Question[0].Name
	// The label count minus one is used as the number of dots in the name.
	nameDots := dns.CountLabel(name) - 1
	refuse := false

	switch {
	case s.config.NoRec:
		log.Debugf("[%d] Refusing query, recursion disabled", req.Id)
		refuse = true
	case len(s.config.Nameservers) == 0:
		log.Debugf("[%d] Refusing query, no nameservers configured", req.Id)
		refuse = true
	case nameDots < s.config.FwdNdots && !s.config.EnableSearch:
		log.Debugf("[%d] Refusing query, qname '%s' too short to forward", req.Id, name)
		refuse = true
	}

	if refuse {
		m := new(dns.Msg)
		m.SetRcode(req, dns.RcodeRefused)
		writeMsg(w, m)
		return m
	}

	// Only queries that are actually going to be forwarded are counted.
	StatsForwardCount.Inc(1)

	var searchEnabled, didAbsolute, didSearch bool
	var absoluteRes, searchRes *dns.Msg // responses from absolute/search lookups
	var absoluteErr, searchErr error    // errors from absolute/search lookups

	tcp := isTCP(w)

	// Searching needs both the switch and at least one search domain.
	if s.config.EnableSearch && len(s.config.SearchDomains) > 0 {
		searchEnabled = true
	}

	// If there are enough dots in the name, start with trying to
	// resolve the literal name
	if nameDots >= s.config.Ndots {
		if nameDots >= s.config.FwdNdots {
			log.Debugf("[%d] Doing initial absolute lookup", req.Id)
			absoluteRes, absoluteErr = s.forwardQuery(req, tcp)
			if absoluteErr != nil {
				log.Errorf("[%d] Error looking up literal qname '%s' with upstreams: %v", req.Id, name, absoluteErr)
			}

			if absoluteErr == nil && absoluteRes.Rcode == dns.RcodeSuccess {
				log.Debugf("[%d] Initial lookup yielded result. Response to client: %s",
					req.Id, dns.RcodeToString[absoluteRes.Rcode])
				absoluteRes.Compress = true
				absoluteRes.Id = req.Id
				writeMsg(w, absoluteRes)
				return absoluteRes
			}
			// Remember the attempt (even a failed one) so it is not repeated below.
			didAbsolute = true
		} else {
			log.Debugf("[%d] Not doing absolute lookup, name too short: '%s'", req.Id, name)
		}
	}

	// We do at least one level of search if search is enabled
	// and we didn't previously fail to query the upstreams
	if absoluteErr == nil && searchEnabled {
		log.Debugf("[%d] Doing search lookup", req.Id)
		searchRes, searchErr = s.forwardSearch(req, tcp)
		if searchErr != nil {
			log.Errorf("[%d] Error looking up qname '%s' with search: %v", req.Id, name, searchErr)
		}

		if searchErr == nil && searchRes.Rcode == dns.RcodeSuccess {
			log.Debugf("[%d] Search lookup yielded result. Response to client: %s",
				req.Id, dns.RcodeToString[searchRes.Rcode])
			searchRes.Compress = true
			searchRes.Id = req.Id
			writeMsg(w, searchRes)
			return searchRes
		}
		didSearch = true
	}

	// If we didn't yet do an absolute lookup then do that now
	// if there are enough dots in the name and searching did
	// not previously fail
	if searchErr == nil && !didAbsolute {
		if nameDots >= s.config.FwdNdots {
			log.Debugf("[%d] Doing absolute lookup", req.Id)
			absoluteRes, absoluteErr = s.forwardQuery(req, tcp)
			if absoluteErr != nil {
				log.Errorf("[%d] Error resolving literal qname '%s': %v", req.Id, name, absoluteErr)
			}

			if absoluteErr == nil && absoluteRes.Rcode == dns.RcodeSuccess {
				log.Debugf("[%d] Absolute lookup yielded result. Response to client: %s",
					req.Id, dns.RcodeToString[absoluteRes.Rcode])
				absoluteRes.Compress = true
				absoluteRes.Id = req.Id
				writeMsg(w, absoluteRes)
				return absoluteRes
			}
			didAbsolute = true
		} else {
			log.Debugf("[%d] Not doing absolute lookup, name too short: '%s'", req.Id, name)
		}
	}

	// If we got here, we failed to get a positive result for the query.
	// If we did an absolute query, return that result, otherwise return
	// a no-data response with the rcode from the last search we did.
	if didAbsolute && absoluteErr == nil {
		log.Debugf("[%d] Failed to resolve query. Returning response of absolute lookup: %s",
			req.Id, dns.RcodeToString[absoluteRes.Rcode])
		absoluteRes.Compress = true
		absoluteRes.Id = req.Id
		writeMsg(w, absoluteRes)
		return absoluteRes
	}

	if didSearch && searchErr == nil {
		log.Debugf("[%d] Failed to resolve query. Returning no-data response: %s",
			req.Id, dns.RcodeToString[searchRes.Rcode])
		// Only the rcode of the search response is passed on, in a fresh
		// message built from the original request.
		m := new(dns.Msg)
		m.SetRcode(req, searchRes.Rcode)
		writeMsg(w, m)
		return m
	}

	// If we got here, we either failed to forward the query or the qname was too
	// short to forward.
	log.Debugf("[%d] Error forwarding query. Returning SRVFAIL.", req.Id)
	m := new(dns.Msg)
	m.SetRcode(req, dns.RcodeServerFailure)
	writeMsg(w, m)
	return m
}