Exemplo n.º 1
0
// Recurse forwards req to the configured recursors, one after another, and
// relays the first reply that is obtained without an error. When every
// recursor fails (or none is configured) a SERVFAIL response is written.
func (d dnsAPI) Recurse(w dns.ResponseWriter, req *dns.Msg) {
	upstream := new(dns.Client)
	if isTCP(w.RemoteAddr()) {
		// The client reached us over TCP, so talk TCP to the recursor too.
		upstream.Net = "tcp"
	}

	for _, server := range d.Recursors {
		reply, _, err := upstream.Exchange(req, server)
		if err == nil {
			w.WriteMsg(reply)
			return
		}
		// This recursor failed; fall through to the next one.
	}

	// No recursor produced an answer: report SERVFAIL to the client.
	failure := new(dns.Msg)
	failure.RecursionAvailable = true
	failure.SetRcode(req, dns.RcodeServerFailure)
	w.WriteMsg(failure)
}
Exemplo n.º 2
0
// ServiceLookup answers a query for a name under d.Domain using the instances
// held in d.Store. After d.Domain is stripped from the lower-cased query
// name, the remaining labels select the kind of lookup:
//
//	<service>                    all instances of the service
//	_<service>._<proto>          RFC 2782 form; instances filtered by proto
//	<instance-id>.<service>._i   a single instance of the service
//	leader.<service>             the instance returned by d.Store.GetLeader
//
// Any other label shape, a service for which d.Store.Get returns nil, or a
// single-instance lookup that finds no instance is answered with NXDOMAIN.
// A request that carries no question is answered with FORMERR.
//
// For every request that has a question, the reply is written by the deferred
// function below, so each early return still produces a response. That
// function also supplies the SOA record: as the answer for successful SOA
// queries, and in the authority section whenever the answer section is empty.
func (d dnsAPI) ServiceLookup(w dns.ResponseWriter, req *dns.Msg) {
	if len(req.Question) == 0 {
		// Indexing Question[0] below would panic on a request without a
		// question section; reject it as malformed instead.
		malformed := &dns.Msg{}
		malformed.SetRcode(req, dns.RcodeFormatError)
		w.WriteMsg(malformed)
		return
	}

	qName := req.Question[0].Name
	qType := req.Question[0].Qtype
	// Normalise the name (fully qualified, lower case) and drop the served
	// domain so that only the service-specific labels remain.
	name := strings.TrimSuffix(strings.ToLower(dns.Fqdn(qName)), d.Domain)
	labels := dns.SplitDomainName(name)
	tcp := isTCP(w.RemoteAddr())

	res := &dns.Msg{}
	res.Authoritative = true
	res.RecursionAvailable = len(d.Recursors) > 0
	res.SetReply(req)
	defer func() {
		if res.Rcode == dns.RcodeSuccess && qType == dns.TypeSOA {
			// SOA answer if requested. Done at the end of the request to
			// ensure we didn't hit NXDOMAIN.
			res.Answer = []dns.RR{d.soaRecord()}
		}
		if len(res.Answer) == 0 {
			// Add authority section with SOA if the answer has no items
			res.Ns = []dns.RR{d.soaRecord()}
		}
		w.WriteMsg(res)
	}()

	// nxdomain marks the pending reply as NXDOMAIN; the deferred function
	// above still performs the actual write.
	nxdomain := func() { res.SetRcode(req, dns.RcodeNameError) }

	var service string
	var proto string
	var instanceID string
	var leader bool
	switch {
	case len(labels) == 1:
		// normal lookup
		service = labels[0]
	case len(labels) == 2 && strings.HasPrefix(labels[0], "_") && strings.HasPrefix(labels[1], "_"):
		// RFC 2782 request looks like _postgres._tcp
		service = labels[0][1:]
		proto = labels[1][1:]
	case len(labels) == 3 && labels[2] == "_i":
		// address lookup for instance in RFC 2782 SRV record
		service = labels[1]
		instanceID = labels[0]
	case len(labels) == 2 && labels[0] == "leader":
		// leader lookup
		leader = true
		service = labels[1]
	default:
		nxdomain()
		return
	}

	// Leader lookups go through d.Store.GetLeader below and do not need the
	// instance list.
	var instances []*discoverd.Instance
	if !leader {
		instances = d.Store.Get(service)
		if instances == nil {
			nxdomain()
			return
		}
	}

	if leader || instanceID != "" {
		// we're doing a lookup for a single instance
		var resInst *discoverd.Instance
		if leader {
			resInst = d.Store.GetLeader(service)
		} else {
			for _, inst := range instances {
				if inst.ID == instanceID {
					resInst = inst
					break
				}
			}
		}
		if resInst == nil {
			nxdomain()
			return
		}

		addr := parseAddr(resInst)
		if qType != dns.TypeA && qType != dns.TypeAAAA && qType != dns.TypeANY && qType != dns.TypeSRV ||
			addr.IPv4 == nil && qType == dns.TypeA ||
			addr.IPv6 == nil && qType == dns.TypeAAAA {
			// no results if we're looking up a record that doesn't match the
			// request type or the type is incorrect
			return
		}
		res.Answer = make([]dns.RR, 0, 2)
		if qType != dns.TypeSRV {
			// A, AAAA and ANY queries get the address record.
			res.Answer = append(res.Answer, addrRecord(qName, addr))
		}
		if qType == dns.TypeSRV || qType == dns.TypeANY {
			res.Answer = append(res.Answer, d.srvRecord(qName, service, addr, false))
		}
		if tcp && qType == dns.TypeSRV {
			// Over TCP, also ship the address record in the additional
			// section of an SRV reply.
			res.Extra = []dns.RR{addrRecord(qName, addr)}
		}
		return
	}

	if qType == dns.TypeSOA {
		// We don't need to do any more processing, as NXDOMAIN can't be reached
		// beyond this point, the SOA answer is added in the deferred function
		// above
		return
	}

	addrs := make([]*addrData, 0, len(instances))
	added := make(map[string]struct{}, len(instances))
	for _, inst := range instances {
		if proto != "" && inst.Proto != proto {
			continue
		}
		addr := parseAddr(inst)
		if _, ok := added[addr.String]; ok {
			continue
		}
		if addr.IPv4 == nil && qType == dns.TypeA || addr.IPv6 == nil && qType == dns.TypeAAAA {
			// Skip instance if we have an IPv6 address but want IPv4 or vice versa
			continue
		}
		if qType != dns.TypeSRV {
			// skip duplicate IPs if we're not doing an SRV lookup
			added[addr.String] = struct{}{}
		}
		addrs = append(addrs, addr)
	}
	if len(addrs) == 0 {
		// return empty response
		return
	}
	shuffle(addrs)

	// Truncate the response if we're using UDP
	if !tcp && len(addrs) > maxUDPRecords {
		addrs = addrs[:maxUDPRecords]
	}

	// Two passes so that all address records precede all SRV records in the
	// answer section (both kinds are present only for ANY queries).
	res.Answer = make([]dns.RR, 0, len(addrs)*2)
	for _, addr := range addrs {
		if qType == dns.TypeANY || qType == dns.TypeA || qType == dns.TypeAAAA {
			res.Answer = append(res.Answer, addrRecord(qName, addr))
		}
	}
	for _, addr := range addrs {
		if qType == dns.TypeANY || qType == dns.TypeSRV {
			res.Answer = append(res.Answer, d.srvRecord(qName, service, addr, true))
		}
	}

	if qType == dns.TypeSRV && tcp {
		// Add extra records mapping instance IDs to addresses
		res.Extra = make([]dns.RR, len(addrs))
		for i, addr := range addrs {
			res.Extra[i] = addrRecord(d.instanceDomain(service, addr.ID), addr)
		}
	}
}