func TestInFlightEDns0(t *T) { m1 := new(dns.Msg) m1.SetQuestion(testAnyDomain, dns.TypeA) m1.SetEdns0(4096, false) w1 := getWriter() m2 := new(dns.Msg) m2.SetQuestion(testAnyDomain, dns.TypeA) w2 := getWriter() go func() { handleRequest(w1, m1) }() go func() { handleRequest(w2, m2) }() var r1 *dns.Msg var r2 *dns.Msg for r1 == nil || r2 == nil { select { case r1 = <-w1.ReplyCh: case r2 = <-w2.ReplyCh: } } //note: this test could be flaky since we're relying on google to return //edns0 response when we send one vs when we don't send one assert.NotNil(t, r1.IsEdns0()) assert.Nil(t, r2.IsEdns0()) }
// truncate removes answers until the given dns.Msg fits the permitted // length of the given transmission channel and sets the TC bit. // See https://tools.ietf.org/html/rfc1035#section-4.2.1 func truncate(m *dns.Msg, udp bool) *dns.Msg { max := dns.MinMsgSize if !udp { max = dns.MaxMsgSize } else if opt := m.IsEdns0(); opt != nil { max = int(opt.UDPSize()) } m.Truncated = m.Len() > max if !m.Truncated { return m } m.Extra = nil // Drop all extra records first if m.Len() < max { return m } answers := m.Answer[:] left, right := 0, len(m.Answer) for { if left == right { break } mid := (left + right) / 2 m.Answer = answers[:mid] if m.Len() < max { left = mid + 1 continue } right = mid } return m }
// Adds the DNS message data to the supplied MapStr. func addDnsToMapStr(m common.MapStr, dns *mkdns.Msg, authority bool, additional bool) { m["id"] = dns.Id m["op_code"] = dnsOpCodeToString(dns.Opcode) m["flags"] = common.MapStr{ "authoritative": dns.Authoritative, "truncated_response": dns.Truncated, "recursion_desired": dns.RecursionDesired, "recursion_available": dns.RecursionAvailable, "authentic_data": dns.AuthenticatedData, // [RFC4035] "checking_disabled": dns.CheckingDisabled, // [RFC4035] } m["response_code"] = dnsResponseCodeToString(dns.Rcode) if len(dns.Question) > 0 { q := dns.Question[0] qMapStr := common.MapStr{ "name": q.Name, "type": dnsTypeToString(q.Qtype), "class": dnsClassToString(q.Qclass), } m["question"] = qMapStr eTLDPlusOne, err := publicsuffix.EffectiveTLDPlusOne(strings.TrimRight(q.Name, ".")) if err == nil { qMapStr["etld_plus_one"] = eTLDPlusOne + "." } } rrOPT := dns.IsEdns0() if rrOPT != nil { m["opt"] = optToMapStr(rrOPT) } m["answers_count"] = len(dns.Answer) if len(dns.Answer) > 0 { m["answers"] = rrsToMapStrs(dns.Answer) } m["authorities_count"] = len(dns.Ns) if authority && len(dns.Ns) > 0 { m["authorities"] = rrsToMapStrs(dns.Ns) } if rrOPT != nil { m["additionals_count"] = len(dns.Extra) - 1 } else { m["additionals_count"] = len(dns.Extra) } if additional && len(dns.Extra) > 0 { rrsMapStrs := rrsToMapStrs(dns.Extra) // We do not want OPT RR to appear in the 'additional' section, // that's why rrsMapStrs could be empty even though len(dns.Extra) > 0 if len(rrsMapStrs) > 0 { m["additionals"] = rrsMapStrs } } }
// ServeDNSStubForward forwards a request to a nameservers and returns the response. func (s *server) ServeDNSStubForward(w dns.ResponseWriter, req *dns.Msg, ns []string) *dns.Msg { // Check EDNS0 Stub option, if set drop the packet. option := req.IsEdns0() if option != nil { for _, o := range option.Option { if o.Option() == ednsStubCode && len(o.(*dns.EDNS0_LOCAL).Data) == 1 && o.(*dns.EDNS0_LOCAL).Data[0] == 1 { // Maybe log source IP here? logf("not fowarding stub request to another stub") return nil } } } tcp := isTCP(w) // Add a custom EDNS0 option to the packet, so we can detect loops // when 2 stubs are forwarding to each other. if option != nil { option.Option = append(option.Option, &dns.EDNS0_LOCAL{ednsStubCode, []byte{1}}) } else { req.Extra = append(req.Extra, ednsStub) } var ( r *dns.Msg err error try int ) // Use request Id for "random" nameserver selection. nsid := int(req.Id) % len(ns) Redo: switch tcp { case false: r, _, err = s.dnsUDPclient.Exchange(req, ns[nsid]) case true: r, _, err = s.dnsTCPclient.Exchange(req, ns[nsid]) } if err == nil { r.Compress = true r.Id = req.Id w.WriteMsg(r) return r } // Seen an error, this can only mean, "server not reached", try again // but only if we have not exausted our nameservers. if try < len(ns) { try++ nsid = (nsid + 1) % len(ns) goto Redo } logf("failure to forward stub request %q", err) m := s.ServerFailure(req) w.WriteMsg(m) return m }
// get the maximum UDP-reply length func getMaxReplyLen(r *dns.Msg, proto dnsProtocol) int { maxLen := minUDPSize if proto == protTCP { maxLen = maxUDPSize } else if opt := r.IsEdns0(); opt != nil { maxLen = int(opt.UDPSize()) } return maxLen }
func getMaxPayloadSize(req *dns.Msg) uint16 { opt := req.IsEdns0() if opt == nil { return dns.MinMsgSize } maxPayloadSize := opt.UDPSize() if maxPayloadSize < dns.MinMsgSize { maxPayloadSize = dns.MinMsgSize } return maxPayloadSize }
// addStubEdns0 adds our special option to the message's OPT record. func addStubEdns0(m *dns.Msg) *dns.Msg { option := m.IsEdns0() // Add a custom EDNS0 option to the packet, so we can detect loops when 2 stubs are forwarding to each other. if option != nil { option.Option = append(option.Option, &dns.EDNS0_LOCAL{Code: ednsStubCode, Data: []byte{1}}) return m } m.Extra = append(m.Extra, ednsStub) return m }
// hasStubEdns0 checks if the message is carrying our special edns0 zero option. func hasStubEdns0(m *dns.Msg) bool { option := m.IsEdns0() if option == nil { return false } for _, o := range option.Option { if o.Option() == ednsStubCode && len(o.(*dns.EDNS0_LOCAL).Data) == 1 && o.(*dns.EDNS0_LOCAL).Data[0] == 1 { return true } } return false }
func Respond(w dns.ResponseWriter, req *dns.Msg, records []dns.RR) { m := new(dns.Msg) m.SetReply(req) m.Authoritative = true m.RecursionAvailable = true m.Compress = true m.Answer = records // Figure out the max response size bufsize := uint16(512) tcp := isTcp(w) if o := req.IsEdns0(); o != nil { bufsize = o.UDPSize() } if tcp { bufsize = dns.MaxMsgSize - 1 } else if bufsize < 512 { bufsize = 512 } if m.Len() > dns.MaxMsgSize { fqdn := dns.Fqdn(req.Question[0].Name) log.WithFields(log.Fields{"fqdn": fqdn}).Debug("Response too big, dropping Extra") m.Extra = nil if m.Len() > dns.MaxMsgSize { log.WithFields(log.Fields{"fqdn": fqdn}).Debug("Response still too big") m := new(dns.Msg) m.SetRcode(m, dns.RcodeServerFailure) } } if m.Len() > int(bufsize) && !tcp { log.Debug("Too big 1") m.Extra = nil if m.Len() > int(bufsize) { log.Debug("Too big 2") m.Answer = nil m.Truncated = true } } err := w.WriteMsg(m) if err != nil { log.Warn("Failed to return reply: ", err, m.Len()) } }
// our ServeDNS interface, which gets invoked on every DNS message func (this ServerProxy) ServeDNS(w dns.ResponseWriter, request *dns.Msg) { // see if we have our groovy custom EDNS0 option client_supports_appfrag := false opt := request.IsEdns0() if opt != nil { for ofs, e := range opt.Option { if e.Option() == dns.EDNS0LOCALSTART { _D("%s QID:%d found EDNS0LOCALSTART", w.RemoteAddr(), request.Id) client_supports_appfrag = true // go ahead and use the maximum UDP size for the local communication // with our server opt.SetUDPSize(65535) // remove the fragmentation option opt.Option = append(opt.Option[0:ofs], opt.Option[ofs+1:]...) // in principle we should only have one of these options break } } } // proxy the query c := new(dns.Client) c.ReadTimeout = this.timeout c.WriteTimeout = this.timeout response, rtt, err := c.Exchange(request, this.SERVERS[rand.Intn(this.s_len)]) if err != nil { _D("%s QID:%d error proxying query: %s", w.RemoteAddr(), request.Id, err) this.SRVFAIL(w, request) return } _D("%s QID:%d request took %s", w.RemoteAddr(), request.Id, rtt) // if the client does not support fragmentation, we just send the response back and finish if !client_supports_appfrag { _D("%s QID:%d sending raw response to client", w.RemoteAddr(), request.Id) w.WriteMsg(response) return } // otherwise lets get our fragments all_frags := frag(response) // send our fragments for n, frag := range all_frags { _D("%s QID:%d sending fragment %d", w.RemoteAddr(), request.Id, n) w.WriteMsg(&frag) } }
// extract out the total fragments and sequence number from the EDNS0 informaton in a packet func get_fragment_info(msg *dns.Msg) (num_frags int, sequence_num int) { num_frags = -1 sequence_num = -1 resp_edns0 := msg.IsEdns0() if resp_edns0 != nil { for _, opt := range resp_edns0.Option { if opt.Option() == dns.EDNS0LOCALSTART+1 { num_frags = int(opt.(*dns.EDNS0_LOCAL).Data[0]) sequence_num = int(opt.(*dns.EDNS0_LOCAL).Data[1]) // we only expect this option to be here once break } } } return num_frags, sequence_num }
func getMsgKey(r *dns.Msg) string { k := "" for _, q := range r.Question { k += getQuestionKey(q) } //RFC 1035 defines this as the max //since some clients (read: go) don't support edns0, this is the default //if they do send edns0 we'll raise it to what they sent limit := dns.MinMsgSize opt := r.IsEdns0() if opt != nil { limit = int(opt.UDPSize()) llog.Debug("received edns0 limit", llog.KV{"limit": limit}) } k += strconv.Itoa(limit) return k }
// SizeAndDo adds an OPT record that the reflects the intent from state. // The returned bool indicated if an record was found and normalised. func (s *State) SizeAndDo(m *dns.Msg) bool { o := s.Req.IsEdns0() // TODO(miek): speed this up if o == nil { return false } o.Hdr.Name = "." o.Hdr.Rrtype = dns.TypeOPT o.SetVersion(0) if mo := m.IsEdns0(); mo != nil { mo.Hdr.Name = "." mo.Hdr.Rrtype = dns.TypeOPT mo.SetVersion(0) mo.SetUDPSize(o.UDPSize()) if o.Do() { mo.SetDo() } return true } m.Extra = append(m.Extra, o) return true }
func Respond(w dns.ResponseWriter, req *dns.Msg, m *dns.Msg) { // Figure out the max response size bufsize := uint16(512) tcp := isTcp(w) if o := req.IsEdns0(); o != nil { bufsize = o.UDPSize() } if tcp { bufsize = dns.MaxMsgSize - 1 } else if bufsize < 512 { bufsize = 512 } // Make sure the payload fits the buffer size. If the message is too large we strip the Extra section. // If it's still too large we return a truncated message for UDP queries and ServerFailure for TCP queries. if m.Len() > int(bufsize) { fqdn := dns.Fqdn(req.Question[0].Name) log.WithFields(log.Fields{"fqdn": fqdn}).Debug("Response too big, dropping Authority and Extra") m.Extra = nil if m.Len() > int(bufsize) { if tcp { log.WithFields(log.Fields{"fqdn": fqdn}).Debug("Response still too big, return ServerFailure") m = new(dns.Msg) m.SetRcode(req, dns.RcodeServerFailure) } else { log.WithFields(log.Fields{"fqdn": fqdn}).Debug("Response still too big, return truncated message") m.Answer = nil m.Truncated = true } } } err := w.WriteMsg(m) if err != nil { log.Warn("Failed to return reply: ", err, m.Len()) } }
// Edns0Version checks the EDNS version in the request. If error // is nil everything is OK and we can invoke the middleware. If non-nil, the // returned Msg is valid to be returned to the client (and should). For some // reason this response should not contain a question RR in the question section. func Edns0Version(req *dns.Msg) (*dns.Msg, error) { opt := req.IsEdns0() if opt == nil { return nil, nil } if opt.Version() == 0 { return nil, nil } m := new(dns.Msg) m.SetReply(req) // zero out question section, wtf. m.Question = nil o := new(dns.OPT) o.Hdr.Name = "." o.Hdr.Rrtype = dns.TypeOPT o.SetVersion(0) o.SetExtendedRcode(dns.RcodeBadVers) m.Extra = []dns.RR{o} return m, errors.New("EDNS0 BADVERS") }
// Classify classifies a message, it returns the MessageType. func Classify(m *dns.Msg) (MsgType, *dns.OPT) { opt := m.IsEdns0() if len(m.Answer) > 0 && m.Rcode == dns.RcodeSuccess { return Success, opt } soa := false ns := 0 for _, r := range m.Ns { if r.Header().Rrtype == dns.TypeSOA { soa = true continue } if r.Header().Rrtype == dns.TypeNS { ns++ } } // Check length of different sections, and drop stuff that is just to large? TODO(miek). if soa && m.Rcode == dns.RcodeSuccess { return NoData, opt } if soa && m.Rcode == dns.RcodeNameError { return NameError, opt } if ns > 0 && ns == len(m.Ns) && m.Rcode == dns.RcodeSuccess { return Delegation, opt } if m.Rcode == dns.RcodeSuccess { return Success, opt } return OtherError, opt }
// ServeDNS is the handler for DNS requests, responsible for parsing DNS request, possibly forwarding // it to a real dns server and returning a response. func (s *server) ServeDNS(w dns.ResponseWriter, req *dns.Msg) { q := req.Question[0] name := strings.ToLower(q.Name) if !strings.HasSuffix(name, s.config.Domain) { s.ServeDNSForward(w, req) return } m := new(dns.Msg) m.SetReply(req) m.Authoritative = true m.RecursionAvailable = true m.Answer = make([]dns.RR, 0, 10) defer func() { // Set TTL to the minimum of the RRset. minttl := s.config.Ttl if len(m.Answer) > 1 { for _, r := range m.Answer { if r.Header().Ttl < minttl { minttl = r.Header().Ttl } } for _, r := range m.Answer { r.Header().Ttl = minttl } } // Check if we need to do DNSSEC and sign the reply. if s.config.PubKey != nil { if opt := req.IsEdns0(); opt != nil && opt.Do() { s.Denial(m) s.sign(m, opt.UDPSize()) } } w.WriteMsg(m) }() if name == s.config.Domain { switch q.Qtype { case dns.TypeDNSKEY: if s.config.PubKey != nil { m.Answer = append(m.Answer, s.config.PubKey) return } case dns.TypeSOA: m.Answer = []dns.RR{s.NewSOA()} return } } if q.Qtype == dns.TypeA || q.Qtype == dns.TypeAAAA { records, err := s.AddressRecords(q) if err != nil { if e, ok := err.(*etcd.EtcdError); ok { if e.ErrorCode == 100 { m.SetRcode(req, dns.RcodeNameError) m.Ns = []dns.RR{s.NewSOA()} return } } } m.Answer = append(m.Answer, records...) } if q.Qtype == dns.TypeSRV || q.Qtype == dns.TypeANY { records, extra, err := s.SRVRecords(q) if err != nil { if e, ok := err.(*etcd.EtcdError); ok { if e.ErrorCode == 100 { m.SetRcode(req, dns.RcodeNameError) m.Ns = []dns.RR{s.NewSOA()} return } } } m.Answer = append(m.Answer, records...) m.Extra = append(m.Extra, extra...) } if len(m.Answer) == 0 { // NODATA response m.Ns = []dns.RR{s.NewSOA()} } }
func (srv *Server) serve(w dns.ResponseWriter, req *dns.Msg, z *Zone) { qname := req.Question[0].Name qtype := req.Question[0].Qtype var qle *querylog.Entry if srv.queryLogger != nil { qle = &querylog.Entry{ Time: time.Now().UnixNano(), Origin: z.Origin, Name: qname, Qtype: qtype, } defer srv.queryLogger.Write(qle) } logPrintf("[zone %s] incoming %s %s (id %d) from %s\n", z.Origin, qname, dns.TypeToString[qtype], req.Id, w.RemoteAddr()) // Global meter metrics.Get("queries").(metrics.Meter).Mark(1) // Zone meter z.Metrics.Queries.Mark(1) logPrintln("Got request", req) label := getQuestionName(z, req) z.Metrics.LabelStats.Add(label) // IP that's talking to us (not EDNS CLIENT SUBNET) var realIP net.IP if addr, ok := w.RemoteAddr().(*net.UDPAddr); ok { realIP = make(net.IP, len(addr.IP)) copy(realIP, addr.IP) } else if addr, ok := w.RemoteAddr().(*net.TCPAddr); ok { realIP = make(net.IP, len(addr.IP)) copy(realIP, addr.IP) } if qle != nil { qle.RemoteAddr = realIP.String() } z.Metrics.ClientStats.Add(realIP.String()) var ip net.IP // EDNS or real IP var edns *dns.EDNS0_SUBNET var opt_rr *dns.OPT for _, extra := range req.Extra { switch extra.(type) { case *dns.OPT: for _, o := range extra.(*dns.OPT).Option { opt_rr = extra.(*dns.OPT) switch e := o.(type) { case *dns.EDNS0_NSID: // do stuff with e.Nsid case *dns.EDNS0_SUBNET: z.Metrics.EdnsQueries.Mark(1) logPrintln("Got edns", e.Address, e.Family, e.SourceNetmask, e.SourceScope) if e.Address != nil { edns = e ip = e.Address if qle != nil { qle.HasECS = true qle.ClientAddr = fmt.Sprintf("%s/%d", ip, e.SourceNetmask) } } } } } } if len(ip) == 0 { // no edns subnet ip = realIP if qle != nil { qle.ClientAddr = fmt.Sprintf("%s/%d", ip, len(ip)*8) } } targets, netmask := z.Options.Targeting.GetTargets(ip) if qle != nil { qle.Targets = targets } m := new(dns.Msg) if qle != nil { defer func() { qle.Rcode = m.Rcode qle.Answers = len(m.Answer) }() } m.SetReply(req) if e := m.IsEdns0(); e != nil { m.SetEdns0(4096, e.Do()) } 
m.Authoritative = true // TODO: set scope to 0 if there are no alternate responses if edns != nil { if edns.Family != 0 { if netmask < 16 { netmask = 16 } edns.SourceScope = uint8(netmask) m.Extra = append(m.Extra, opt_rr) } } labels, labelQtype := z.findLabels(label, targets, qTypes{dns.TypeMF, dns.TypeCNAME, qtype}) if labelQtype == 0 { labelQtype = qtype } if labels == nil { permitDebug := !*flagPrivateDebug || (realIP != nil && realIP.IsLoopback()) firstLabel := (strings.Split(label, "."))[0] if qle != nil { qle.LabelName = firstLabel } if permitDebug && firstLabel == "_status" { if qtype == dns.TypeANY || qtype == dns.TypeTXT { m.Answer = statusRR(label + "." + z.Origin + ".") } else { m.Ns = append(m.Ns, z.SoaRR()) } m.Authoritative = true w.WriteMsg(m) return } if firstLabel == "_country" { if qtype == dns.TypeANY || qtype == dns.TypeTXT { h := dns.RR_Header{Ttl: 1, Class: dns.ClassINET, Rrtype: dns.TypeTXT} h.Name = label + "." + z.Origin + "." txt := []string{ w.RemoteAddr().String(), ip.String(), } targets, netmask := z.Options.Targeting.GetTargets(ip) txt = append(txt, strings.Join(targets, " ")) txt = append(txt, fmt.Sprintf("/%d", netmask), serverID, serverIP) m.Answer = []dns.RR{&dns.TXT{Hdr: h, Txt: txt, }} } else { m.Ns = append(m.Ns, z.SoaRR()) } m.Authoritative = true w.WriteMsg(m) return } // return NXDOMAIN m.SetRcode(req, dns.RcodeNameError) m.Authoritative = true m.Ns = []dns.RR{z.SoaRR()} w.WriteMsg(m) return } if servers := labels.Picker(labelQtype, labels.MaxHosts); servers != nil { var rrs []dns.RR for _, record := range servers { rr := dns.Copy(record.RR) rr.Header().Name = qname rrs = append(rrs, rr) } m.Answer = rrs } if len(m.Answer) == 0 { // Return a SOA so the NOERROR answer gets cached m.Ns = append(m.Ns, z.SoaRR()) } logPrintln(m) if qle != nil { qle.LabelName = labels.Label qle.Answers = len(m.Answer) qle.Rcode = m.Rcode } err := w.WriteMsg(m) if err != nil { // if Pack'ing fails the Write fails. Return SERVFAIL. 
log.Println("Error writing packet", m) dns.HandleFailed(w, req) } return }
func getMaxResponseSize(req *dns.Msg, defaultMaxResponseSize int) int { if opt := req.IsEdns0(); opt != nil { return int(opt.UDPSize()) } return defaultMaxResponseSize }
// ServeDNS is the handler for DNS requests, responsible for parsing DNS request, possibly forwarding // it to a real dns server and returning a response. func (s *server) ServeDNS(w dns.ResponseWriter, req *dns.Msg) { m := new(dns.Msg) m.SetReply(req) m.Authoritative = true m.RecursionAvailable = true m.Compress = true bufsize := uint16(512) dnssec := false tcp := false start := time.Now() if req.Question[0].Qtype == dns.TypeANY { m.Authoritative = false m.Rcode = dns.RcodeRefused m.RecursionAvailable = false m.RecursionDesired = false m.Compress = false // if write fails don't care w.WriteMsg(m) promErrorCount.WithLabelValues("refused").Inc() return } if o := req.IsEdns0(); o != nil { bufsize = o.UDPSize() dnssec = o.Do() } if bufsize < 512 { bufsize = 512 } // with TCP we can send 64K if tcp = isTCP(w); tcp { bufsize = dns.MaxMsgSize - 1 promRequestCount.WithLabelValues("tcp").Inc() } else { promRequestCount.WithLabelValues("udp").Inc() } StatsRequestCount.Inc(1) if dnssec { StatsDnssecOkCount.Inc(1) promDnssecOkCount.Inc() } defer func() { promCacheSize.WithLabelValues("response").Set(float64(s.rcache.Size())) }() // Check cache first. key := cache.QuestionKey(req.Question[0], dnssec) m1, exp, hit := s.rcache.Search(key) if hit { // Cache hit! \o/ if time.Since(exp) < 0 { m1.Id = m.Id m1.Compress = true m1.Truncated = false if dnssec { // The key for DNS/DNSSEC in cache is different, no // need to do Denial/Sign here. //if s.config.PubKey != nil { //s.Denial(m1) // not needed for cache hits //s.Sign(m1, bufsize) //} } if m1.Len() > int(bufsize) && !tcp { promErrorCount.WithLabelValues("truncated").Inc() m1.Truncated = true } // Still round-robin even with hits from the cache. // Only shuffle A and AAAA records with each other. 
if req.Question[0].Qtype == dns.TypeA || req.Question[0].Qtype == dns.TypeAAAA { s.RoundRobin(m1.Answer) } if err := w.WriteMsg(m1); err != nil { log.Printf("skydns: failure to return reply %q", err) } metricSizeAndDuration(m1, start, tcp) return } // Expired! /o\ s.rcache.Remove(key) } q := req.Question[0] name := strings.ToLower(q.Name) if s.config.Verbose { log.Printf("skydns: received DNS Request for %q from %q with type %d", q.Name, w.RemoteAddr(), q.Qtype) } for zone, ns := range *s.config.stub { if strings.HasSuffix(name, zone) { resp := s.ServeDNSStubForward(w, req, ns) metricSizeAndDuration(resp, start, tcp) return } } // If the qname is local.dns.skydns.local. and s.config.Local != "", substitute that name. if s.config.Local != "" && name == s.config.localDomain { name = s.config.Local } if q.Qtype == dns.TypePTR && strings.HasSuffix(name, ".in-addr.arpa.") || strings.HasSuffix(name, ".ip6.arpa.") { resp := s.ServeDNSReverse(w, req) metricSizeAndDuration(resp, start, tcp) return } if q.Qclass != dns.ClassCHAOS && !strings.HasSuffix(name, s.config.Domain) { if s.config.Verbose { log.Printf("skydns: %q is not sub of %q, forwarding...", name, s.config.Domain) } resp := s.ServeDNSForward(w, req) metricSizeAndDuration(resp, start, tcp) return } promCacheMiss.WithLabelValues("response").Inc() defer func() { if m.Rcode == dns.RcodeServerFailure { if err := w.WriteMsg(m); err != nil { log.Printf("skydns: failure to return reply %q", err) } return } // Set TTL to the minimum of the RRset and dedup the message, i.e. // remove identical RRs. 
m = s.dedup(m) minttl := s.config.Ttl if len(m.Answer) > 1 { for _, r := range m.Answer { if r.Header().Ttl < minttl { minttl = r.Header().Ttl } } for _, r := range m.Answer { r.Header().Ttl = minttl } } if !m.Truncated { s.rcache.InsertMessage(cache.QuestionKey(req.Question[0], dnssec), m) } if dnssec { if s.config.PubKey != nil { m.AuthenticatedData = true s.Denial(m) s.Sign(m, bufsize) } } if m.Len() > dns.MaxMsgSize { log.Printf("skydns: overflowing maximum message size: %d, dropping additional section", m.Len()) m.Extra = nil // Drop entire additional section to see if this helps. if m.Len() > dns.MaxMsgSize { // *Still* too large. log.Printf("skydns: still overflowing maximum message size: %d", m.Len()) promErrorCount.WithLabelValues("overflow").Inc() m1 := new(dns.Msg) // Use smaller msg to signal failure. m1.SetRcode(m, dns.RcodeServerFailure) if err := w.WriteMsg(m1); err != nil { log.Printf("skydns: failure to return reply %q", err) } metricSizeAndDuration(m1, start, tcp) return } } if m.Len() > int(bufsize) && !tcp { m.Extra = nil // As above, drop entire additional section. if m.Len() > int(bufsize) { promErrorCount.WithLabelValues("truncated").Inc() m.Truncated = true } } if err := w.WriteMsg(m); err != nil { log.Printf("skydns: failure to return reply %q %d", err, m.Len()) } metricSizeAndDuration(m, start, tcp) }() if name == s.config.Domain { if q.Qtype == dns.TypeSOA { m.Answer = []dns.RR{s.NewSOA()} return } if q.Qtype == dns.TypeDNSKEY { if s.config.PubKey != nil { m.Answer = []dns.RR{s.config.PubKey} return } } } if q.Qclass == dns.ClassCHAOS { if q.Qtype == dns.TypeTXT { switch name { case "authors.bind.": fallthrough case s.config.Domain: hdr := dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassCHAOS, Ttl: 0} authors := []string{"Erik St. 
Martin", "Brian Ketelsen", "Miek Gieben", "Michael Crosby"} for _, a := range authors { m.Answer = append(m.Answer, &dns.TXT{Hdr: hdr, Txt: []string{a}}) } for j := 0; j < len(authors)*(int(dns.Id())%4+1); j++ { q := int(dns.Id()) % len(authors) p := int(dns.Id()) % len(authors) if q == p { p = (p + 1) % len(authors) } m.Answer[q], m.Answer[p] = m.Answer[p], m.Answer[q] } return case "version.bind.": fallthrough case "version.server.": hdr := dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassCHAOS, Ttl: 0} m.Answer = []dns.RR{&dns.TXT{Hdr: hdr, Txt: []string{Version}}} return case "hostname.bind.": fallthrough case "id.server.": // TODO(miek): machine name to return hdr := dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassCHAOS, Ttl: 0} m.Answer = []dns.RR{&dns.TXT{Hdr: hdr, Txt: []string{"localhost"}}} return } } // still here, fail m.SetReply(req) m.SetRcode(req, dns.RcodeServerFailure) return } switch q.Qtype { case dns.TypeNS: if name != s.config.Domain { log.Printf("skydns: %q unmatch default domain", name) break } // Lookup s.config.DnsDomain records, extra, err := s.NSRecords(q, s.config.dnsDomain) if err != nil { if e, ok := err.(*etcd.EtcdError); ok { if e.ErrorCode == 100 { s.NameError(m, req) return } } } m.Answer = append(m.Answer, records...) m.Extra = append(m.Extra, extra...) case dns.TypeA, dns.TypeAAAA: records, err := s.AddressRecords(q, name, nil, bufsize, dnssec, false) if err != nil { if e, ok := err.(*etcd.EtcdError); ok { if e.ErrorCode == 100 { s.NameError(m, req) return } } } m.Answer = append(m.Answer, records...) case dns.TypeTXT: records, err := s.TXTRecords(q, name) if err != nil { if e, ok := err.(*etcd.EtcdError); ok { if e.ErrorCode == 100 { s.NameError(m, req) return } } } m.Answer = append(m.Answer, records...) 
case dns.TypeCNAME: records, err := s.CNAMERecords(q, name) if err != nil { if e, ok := err.(*etcd.EtcdError); ok { if e.ErrorCode == 100 { s.NameError(m, req) return } } } m.Answer = append(m.Answer, records...) case dns.TypeMX: records, extra, err := s.MXRecords(q, name, bufsize, dnssec) if err != nil { if e, ok := err.(*etcd.EtcdError); ok { if e.ErrorCode == 100 { s.NameError(m, req) return } } } m.Answer = append(m.Answer, records...) m.Extra = append(m.Extra, extra...) default: fallthrough // also catch other types, so that they return NODATA case dns.TypeSRV: records, extra, err := s.SRVRecords(q, name, bufsize, dnssec) if err != nil { if e, ok := err.(*etcd.EtcdError); ok { if e.ErrorCode == 100 { s.NameError(m, req) return } } if q.Qtype == dns.TypeSRV { // Otherwise NODATA s.ServerFailure(m, req) return } } // if we are here again, check the types, because an answer may only // be given for SRV. All other types should return NODATA, the // NXDOMAIN part is handled in the above code. TODO(miek): yes this // can be done in a more elegant manor. if q.Qtype == dns.TypeSRV { m.Answer = append(m.Answer, records...) m.Extra = append(m.Extra, extra...) } } if len(m.Answer) == 0 { // NODATA response StatsNoDataCount.Inc(1) m.Ns = []dns.RR{s.NewSOA()} m.Ns[0].Header().Ttl = s.config.MinTtl } }
/* For fragmentation, we use a naive algorithm. We use the same header for every fragment, and include the same EDNS0 section in every additional section. We add one RR at a time, until our fragment is larger than 512 bytes, then we remove the last RR so that it fits in the 512 byte size limit. If we discover that one of the fragments ends up with 0 RR in it (for example because a single RR is too big), then we return a single truncated response instead of the set of fragments. We could perhaps make the process of building fragments faster by bisecting the set of RR that we include in an answer. So, if we have 8 RR we could try all, then if that is too big, 4 RR, and if that fits then 6 RR, until an optimal set of RR is found. We could also possibly produce a smaller set of responses by optimizing how we combine RR. Just taking account the various sizes is the same as the bin packing problem, which is NP-hard: https://en.wikipedia.org/wiki/Bin_packing_problem While some non-optimal but reasonable heuristics exist, in the case of DNS we would have to use some sophisticated algorithm to also consider name compression. 
*/ func frag(reply *dns.Msg) []dns.Msg { // create a return value all_frags := []dns.Msg{} HasEdns0 := true // get each RR section and save a copy out remaining_answer := make([]dns.RR, len(reply.Answer)) copy(remaining_answer, reply.Answer) remaining_ns := make([]dns.RR, len(reply.Ns)) copy(remaining_ns, reply.Ns) remaining_extra := make([]dns.RR, len(reply.Extra)) copy(remaining_extra, reply.Extra) // if we don't have EDNS0 in the packet, add it now if reply.IsEdns0() == nil { reply.SetEdns0(512, false) } // the EDNS option for later use var edns0_rr dns.RR = nil // remove the EDNS0 option from our additional ("extra") section // (we will include it separately on every fragment) for ofs, r := range remaining_extra { // found the EDNS option if r.Header().Rrtype == dns.TypeOPT { // save the EDNS option edns0_rr = r // remove from the set of extra RR remaining_extra = append(remaining_extra[0:ofs], remaining_extra[ofs+1:]...) // in principle we should only have one EDNS0 section break } } if edns0_rr == nil { log.Printf("Server reply missing EDNS0 option") return []dns.Msg{} //HasEdns0 = false } // now build fragments for { // make a shallow copy of our reply packet, and prepare space for our RR frag := *reply frag.Answer = []dns.RR{} frag.Ns = []dns.RR{} frag.Extra = []dns.RR{} // add our custom EDNS0 option (needed in every fragment) local_opt := new(dns.EDNS0_LOCAL) local_opt.Code = dns.EDNS0LOCALSTART + 1 local_opt.Data = []byte{0, 0} if HasEdns0 == true { edns0_rr_copy := dns.Copy(edns0_rr) edns0_rr_copy.(*dns.OPT).Option = append(edns0_rr_copy.(*dns.OPT).Option, local_opt) frag.Extra = append(frag.Extra, edns0_rr_copy) } //if HasEdns0 == false { // frag.Extra = append(frag.Extra, local_opt) //} // add as many RR to the answer as we can for len(remaining_answer) > 0 { frag.Answer = append(frag.Answer, remaining_answer[0]) if frag.Len() <= 512 { // if the new answer fits, then remove it from our remaining list remaining_answer = remaining_answer[1:] } else { // 
otherwise we are full, remove it from our fragment and stop frag.Answer = frag.Answer[0 : len(frag.Answer)-1] break } } for len(remaining_ns) > 0 { frag.Ns = append(frag.Ns, remaining_ns[0]) if frag.Len() <= 512 { // if the new answer fits, then remove it from our remaining list remaining_ns = remaining_ns[1:] } else { // otherwise we are full, remove it from our fragment and stop frag.Ns = frag.Ns[0 : len(frag.Ns)-1] break } } for len(remaining_extra) > 0 { frag.Extra = append(frag.Extra, remaining_extra[0]) if frag.Len() <= 512 { // if the new answer fits, then remove it from our remaining list remaining_extra = remaining_extra[1:] } else { // otherwise we are full, remove it from our fragment and stop frag.Extra = frag.Extra[0 : len(frag.Extra)-1] break } } // check to see if we didn't manage to add any RR if (len(frag.Answer) == 0) && (len(frag.Ns) == 0) && (len(frag.Extra) == 1) { // TODO: test this :) // return a single truncated fragment without any RR frag.MsgHdr.Truncated = true frag.Extra = []dns.RR{} return []dns.Msg{frag} } // add to our list of fragments all_frags = append(all_frags, frag) // if we have finished all remaining sections, we are done if (len(remaining_answer) == 0) && (len(remaining_ns) == 0) && (len(remaining_extra) == 0) { break } } // fix up our fragments so they have the correct sequence and length values for n, frag := range all_frags { frag_edns0 := frag.IsEdns0() for _, opt := range frag_edns0.Option { if opt.Option() == dns.EDNS0LOCALSTART+1 { opt.(*dns.EDNS0_LOCAL).Data = []byte{byte(len(all_frags)), byte(n)} } } } // return our fragments return all_frags }
// ServeDNS answers a query addressed to the embedded resolver.
//
// A and PTR questions are first passed to handleIPv4Query / handlePTRQuery.
// A handler error is logged and the query is dropped without a reply. If no
// local response was produced (which includes every other query type), the
// query is forwarded to the entries of r.extDNSList in order. Forwarding
// stops at the first entry with an empty ipStr, or at the first server that
// answers. Queries without a question are dropped without a reply.
func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
	var (
		extConn net.Conn
		resp    *dns.Msg
		err     error
	)

	if query == nil || len(query.Question) == 0 {
		return
	}
	name := query.Question[0].Name
	if query.Question[0].Qtype == dns.TypeA {
		resp, err = r.handleIPv4Query(name, query)
	} else if query.Question[0].Qtype == dns.TypePTR {
		resp, err = r.handlePTRQuery(name, query)
	}

	if err != nil {
		log.Error(err)
		return
	}

	// Largest reply the client can receive. Over TCP this is
	// dns.MaxMsgSize-1. Over UDP it is the client's EDNS0 payload size, but
	// never less than defaultRespSize.
	// NOTE(review): for any other network type maxSize stays 0, so every
	// local response would be passed to truncateResp.
	proto := w.LocalAddr().Network()
	maxSize := 0
	if proto == "tcp" {
		maxSize = dns.MaxMsgSize - 1
	} else if proto == "udp" {
		optRR := query.IsEdns0()
		if optRR != nil {
			maxSize = int(optRR.UDPSize())
		}
		if maxSize < defaultRespSize {
			maxSize = defaultRespSize
		}
	}

	if resp != nil {
		// Local answer: trim it to what the transport can carry.
		if resp.Len() > maxSize {
			truncateResp(resp, maxSize, proto == "tcp")
		}
	} else {
		// No local answer: try the external servers in order.
		for i := 0; i < maxExtDNS; i++ {
			extDNS := &r.extDNSList[i]
			if extDNS.ipStr == "" {
				break
			}
			log.Debugf("Querying ext dns %s:%s for %s[%d]", proto, extDNS.ipStr, name, query.Question[0].Qtype)

			// extConnect dials the server on port 53 over the client's
			// transport and stores the result in the enclosing extConn/err.
			// It is run through r.sb.execFunc (presumably inside the
			// sandbox's network namespace; confirm against execFunc).
			extConnect := func() {
				addr := fmt.Sprintf("%s:%d", extDNS.ipStr, 53)
				extConn, err = net.DialTimeout(proto, addr, extIOTimeout)
			}

			// For udp clients the connection is persisted to be reused for
			// further queries. Accessing extDNS.extConn can race between
			// goroutines, hence the connection setup is done in a Once block
			// and extConn is fetched again afterwards.
			// NOTE(review): the read of extDNS.extConn on the next line is
			// itself unsynchronized with the write inside Once.Do. Also, if
			// the first UDP dial fails, extDNS.extConn stays nil and this
			// server is skipped for UDP from then on.
			extConn = extDNS.extConn
			if extConn == nil || proto == "tcp" {
				if proto == "udp" {
					extDNS.extOnce.Do(func() {
						r.sb.execFunc(extConnect)
						extDNS.extConn = extConn
					})
					extConn = extDNS.extConn
				} else {
					r.sb.execFunc(extConnect)
				}
				if err != nil {
					log.Debugf("Connect failed, %s", err)
					continue
				}
			}
			// If two goroutines are executing in parallel, one will block on
			// the Once.Do. If connecting to the external server failed, it
			// ends up with a nil err but extConn also being nil.
			if extConn == nil {
				continue
			}

			// Timeout has to be set for every IO operation.
			extConn.SetDeadline(time.Now().Add(extIOTimeout))
			co := &dns.Conn{Conn: extConn}
			// Only TCP connections are closed; the UDP one is shared.
			// NOTE(review): this defer runs when ServeDNS returns, not at the
			// end of the iteration. TCP connections to servers that failed
			// stay open until then.
			defer func() {
				if proto == "tcp" {
					co.Close()
				}
			}()
			err = co.WriteMsg(query)
			if err != nil {
				log.Debugf("Send to DNS server failed, %s", err)
				continue
			}

			// NOTE(review): if ReadMsg ever returns a message together with
			// an error, resp keeps it. Should every later server fail, that
			// message is still written below.
			resp, err = co.ReadMsg()
			if err != nil {
				log.Debugf("Read from DNS server failed, %s", err)
				continue
			}

			resp.Compress = true
			break
		}

		// No server answered: send nothing.
		if resp == nil {
			return
		}
	}

	err = w.WriteMsg(resp)
	if err != nil {
		log.Errorf("error writing resolver resp, %s", err)
	}
}
// ServeDNS is the handler for DNS requests, responsible for parsing DNS request, possibly forwarding // it to a real dns server and returning a response. func (s *server) ServeDNS(w dns.ResponseWriter, req *dns.Msg) { m := new(dns.Msg) m.SetReply(req) m.Authoritative = true m.RecursionAvailable = true m.Compress = true bufsize := uint16(512) dnssec := false tcp := false if req.Question[0].Qtype == dns.TypeANY { m.Authoritative = false m.Rcode = dns.RcodeRefused m.RecursionAvailable = false m.RecursionDesired = false m.Compress = false // if write fails don't care w.WriteMsg(m) return } if o := req.IsEdns0(); o != nil { bufsize = o.UDPSize() dnssec = o.Do() } if bufsize < 512 { bufsize = 512 } // with TCP we can send 64K if _, ok := w.RemoteAddr().(*net.TCPAddr); ok { bufsize = dns.MaxMsgSize - 1 tcp = true } // Check cache first. key := cache.QuestionKey(req.Question[0], dnssec) m1, exp, hit := s.rcache.Search(key) if hit { // Cache hit! \o/ if time.Since(exp) < 0 { m1.Id = m.Id if dnssec { StatsDnssecOkCount.Inc(1) // The key for DNS/DNSSEC in cache is different, no // need to do Denial/Sign here. //if s.config.PubKey != nil { //s.Denial(m1) // not needed for cache hits //s.Sign(m1, bufsize) //} } if m1.Len() > int(bufsize) && !tcp { m1.Truncated = true } // Still round-robin even with hits from the cache. if s.config.RoundRobin && (req.Question[0].Qtype == dns.TypeA || req.Question[0].Qtype == dns.TypeAAAA) { switch l := len(m1.Answer); l { case 2: if dns.Id()%2 == 0 { m1.Answer[0], m1.Answer[1] = m1.Answer[1], m1.Answer[0] } default: // Do a minimum of l swap, maximum of 4l swaps for j := 0; j < l*(int(dns.Id())%4+1); j++ { q := int(dns.Id()) % l p := int(dns.Id()) % l if q == p { p = (p + 1) % l } m1.Answer[q], m1.Answer[p] = m1.Answer[p], m1.Answer[q] } } } if err := w.WriteMsg(m1); err != nil { s.config.log.Errorf("failure to return reply %q", err) } return } // Expired! 
/o\ s.rcache.Remove(key) } q := req.Question[0] name := strings.ToLower(q.Name) StatsRequestCount.Inc(1) if verbose { s.config.log.Infof("received DNS Request for %q from %q with type %d", q.Name, w.RemoteAddr(), q.Qtype) } // If the qname is local.dns.skydns.local. and s.config.Local != "", substitute that name. if s.config.Local != "" && name == s.config.localDomain { name = s.config.Local } if q.Qtype == dns.TypePTR && strings.HasSuffix(name, ".in-addr.arpa.") || strings.HasSuffix(name, ".ip6.arpa.") { s.ServeDNSReverse(w, req) return } if q.Qclass != dns.ClassCHAOS && !strings.HasSuffix(name, s.config.Domain) { s.ServeDNSForward(w, req) return } defer func() { if m.Rcode == dns.RcodeServerFailure { if err := w.WriteMsg(m); err != nil { s.config.log.Errorf("failure to return reply %q", err) } return } // Set TTL to the minimum of the RRset. minttl := s.config.Ttl if len(m.Answer) > 1 { for _, r := range m.Answer { if r.Header().Ttl < minttl { minttl = r.Header().Ttl } } for _, r := range m.Answer { r.Header().Ttl = minttl } } s.rcache.InsertMessage(cache.QuestionKey(req.Question[0], dnssec), m) if dnssec { StatsDnssecOkCount.Inc(1) if s.config.PubKey != nil { m.AuthenticatedData = true s.Denial(m) s.Sign(m, bufsize) } } if m.Len() > int(bufsize) && !tcp { // TODO(miek): this is a little brain dead, better is to not add // RRs in the message in the first place. m.Truncated = true } if err := w.WriteMsg(m); err != nil { s.config.log.Errorf("failure to return reply %q", err) } }() if name == s.config.Domain { if q.Qtype == dns.TypeSOA { m.Answer = []dns.RR{s.NewSOA()} return } if q.Qtype == dns.TypeDNSKEY { if s.config.PubKey != nil { m.Answer = []dns.RR{s.config.PubKey} return } } } if q.Qclass == dns.ClassCHAOS { if q.Qtype == dns.TypeTXT { switch name { case "authors.bind.": fallthrough case s.config.Domain: hdr := dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassCHAOS, Ttl: 0} authors := []string{"Erik St. 
Martin", "Brian Ketelsen", "Miek Gieben", "Michael Crosby"} for _, a := range authors { m.Answer = append(m.Answer, &dns.TXT{Hdr: hdr, Txt: []string{a}}) } for j := 0; j < len(authors)*(int(dns.Id())%4+1); j++ { q := int(dns.Id()) % len(authors) p := int(dns.Id()) % len(authors) if q == p { p = (p + 1) % len(authors) } m.Answer[q], m.Answer[p] = m.Answer[p], m.Answer[q] } return case "version.bind.": fallthrough case "version.server.": hdr := dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassCHAOS, Ttl: 0} m.Answer = []dns.RR{&dns.TXT{Hdr: hdr, Txt: []string{Version}}} return case "hostname.bind.": fallthrough case "id.server.": // TODO(miek): machine name to return hdr := dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassCHAOS, Ttl: 0} m.Answer = []dns.RR{&dns.TXT{Hdr: hdr, Txt: []string{"localhost"}}} return } } // still here, fail m.SetReply(req) m.SetRcode(req, dns.RcodeServerFailure) return } switch q.Qtype { case dns.TypeNS: if name != s.config.Domain { break } // Lookup s.config.DnsDomain records, extra, err := s.NSRecords(q, s.config.dnsDomain) if err != nil { if e, ok := err.(*etcd.EtcdError); ok { if e.ErrorCode == 100 { s.NameError(m, req) return } } } m.Answer = append(m.Answer, records...) m.Extra = append(m.Extra, extra...) case dns.TypeA, dns.TypeAAAA: records, err := s.AddressRecords(q, name, nil) if err != nil { if e, ok := err.(*etcd.EtcdError); ok { if e.ErrorCode == 100 { s.NameError(m, req) return } } if err.Error() == "incomplete CNAME chain" { // We can not complete the CNAME internally, *iff* there is a // external name in the set, take it, and try to resolve it externally. 
if len(records) == 0 { s.NameError(m, req) return } target := "" for _, r := range records { if v, ok := r.(*dns.CNAME); ok { if !dns.IsSubDomain(s.config.Domain, v.Target) { target = v.Target break } } } if target == "" { s.config.log.Warningf("incomplete CNAME chain for %s", name) s.NoDataError(m, req) return } m1, e1 := s.Lookup(target, req.Question[0].Qtype, bufsize, dnssec) if e1 != nil { s.config.log.Errorf("%q", err) s.NoDataError(m, req) return } records = append(records, m1.Answer...) } } m.Answer = append(m.Answer, records...) case dns.TypeCNAME: records, err := s.CNAMERecords(q, name) if err != nil { if e, ok := err.(*etcd.EtcdError); ok { if e.ErrorCode == 100 { s.NameError(m, req) return } } } m.Answer = append(m.Answer, records...) default: fallthrough // also catch other types, so that they return NODATA case dns.TypeSRV, dns.TypeANY: records, extra, err := s.SRVRecords(q, name, bufsize, dnssec) if err != nil { if e, ok := err.(*etcd.EtcdError); ok { if e.ErrorCode == 100 { s.NameError(m, req) return } } } // if we are here again, check the types, because an answer may only // be given for SRV or ANY. All other types should return NODATA, the // NXDOMAIN part is handled in the above code. TODO(miek): yes this // can be done in a more elegant manor. if q.Qtype == dns.TypeSRV || q.Qtype == dns.TypeANY { m.Answer = append(m.Answer, records...) m.Extra = append(m.Extra, extra...) } } if len(m.Answer) == 0 { // NODATA response StatsNoDataCount.Inc(1) m.Ns = []dns.RR{s.NewSOA()} m.Ns[0].Header().Ttl = s.config.MinTtl } }
// ServeDNS answers a query addressed to the embedded resolver.
//
// A, AAAA, PTR and SRV questions are first passed to the local handlers
// (handleIPQuery / handlePTRQuery / handleSRVQuery). A handler error is
// logged and the query is dropped without a reply. If no local response was
// produced, the query is forwarded to the entries of r.extDNSList in order.
// Forwarding stops at the first entry with an empty ipStr, or at the first
// server that answers. Queries without a question are dropped without a
// reply.
func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
	var (
		extConn net.Conn
		resp    *dns.Msg
		err     error
	)

	if query == nil || len(query.Question) == 0 {
		return
	}
	name := query.Question[0].Name

	switch query.Question[0].Qtype {
	case dns.TypeA:
		resp, err = r.handleIPQuery(name, query, types.IPv4)
	case dns.TypeAAAA:
		resp, err = r.handleIPQuery(name, query, types.IPv6)
	case dns.TypePTR:
		resp, err = r.handlePTRQuery(name, query)
	case dns.TypeSRV:
		resp, err = r.handleSRVQuery(name, query)
	}

	if err != nil {
		log.Error(err)
		return
	}

	// Largest reply the client can receive. Over TCP this is
	// dns.MaxMsgSize-1. Over UDP it is the client's EDNS0 payload size, but
	// never less than defaultRespSize.
	// NOTE(review): for any other network type maxSize stays 0.
	proto := w.LocalAddr().Network()
	maxSize := 0
	if proto == "tcp" {
		maxSize = dns.MaxMsgSize - 1
	} else if proto == "udp" {
		optRR := query.IsEdns0()
		if optRR != nil {
			maxSize = int(optRR.UDPSize())
		}
		if maxSize < defaultRespSize {
			maxSize = defaultRespSize
		}
	}

	if resp != nil {
		// Local answer: trim it to what the transport can carry.
		if resp.Len() > maxSize {
			truncateResp(resp, maxSize, proto == "tcp")
		}
	} else {
		// No local answer: try the external servers in order.
		for i := 0; i < maxExtDNS; i++ {
			extDNS := &r.extDNSList[i]
			if extDNS.ipStr == "" {
				break
			}
			// extConnect dials the server on port 53 over the client's
			// transport and stores the result in the enclosing extConn/err.
			extConnect := func() {
				addr := fmt.Sprintf("%s:%d", extDNS.ipStr, 53)
				extConn, err = net.DialTimeout(proto, addr, extIOTimeout)
			}

			// A fresh connection is dialed for every query (presumably inside
			// the sandbox's namespace via execFunc; confirm).
			r.sb.execFunc(extConnect)
			if err != nil {
				log.Debugf("Connect failed, %s", err)
				continue
			}
			log.Debugf("Query %s[%d] from %s, forwarding to %s:%s", name, query.Question[0].Qtype,
				extConn.LocalAddr().String(), proto, extDNS.ipStr)

			// Timeout has to be set for every IO operation.
			extConn.SetDeadline(time.Now().Add(extIOTimeout))
			co := &dns.Conn{Conn: extConn}
			// NOTE(review): this defer runs when ServeDNS returns, not at the
			// end of the iteration. Connections to servers that failed stay
			// open until then.
			defer co.Close()

			// limits the number of outstanding concurrent queries.
			// When the limit is hit, the error is logged at most once per
			// logInterval and the next server is tried.
			// NOTE(review): r.tStamp is read and written here without a lock;
			// if ServeDNS runs concurrently this is a data race. Confirm.
			if r.forwardQueryStart() == false {
				old := r.tStamp
				r.tStamp = time.Now()
				if r.tStamp.Sub(old) > logInterval {
					log.Errorf("More than %v concurrent queries from %s", maxConcurrent, extConn.LocalAddr().String())
				}
				continue
			}

			// Every successful forwardQueryStart is paired with exactly one
			// forwardQueryEnd on each path below.
			err = co.WriteMsg(query)
			if err != nil {
				r.forwardQueryEnd()
				log.Debugf("Send to DNS server failed, %s", err)
				continue
			}

			resp, err = co.ReadMsg()
			// Truncated DNS replies should be sent to the client so that the
			// client can retry over TCP
			if err != nil && err != dns.ErrTruncated {
				r.forwardQueryEnd()
				log.Debugf("Read from DNS server failed, %s", err)
				continue
			}

			r.forwardQueryEnd()
			resp.Compress = true
			break
		}

		// No server answered: send nothing.
		if resp == nil {
			return
		}
	}

	if err = w.WriteMsg(resp); err != nil {
		log.Errorf("error writing resolver resp, %s", err)
	}
}
// ServeDNS is the handler for DNS requests, responsible for parsing DNS request, possibly forwarding // it to a real dns server and returning a response. func (s *server) ServeDNS(w dns.ResponseWriter, req *dns.Msg) { startTime := time.Now() defer func() { elapsed := time.Since(startTime) log.Debugf("[%d] Response time: %s", req.Id, elapsed) }() m := new(dns.Msg) m.SetReply(req) m.Authoritative = false m.RecursionAvailable = true m.Compress = true bufsize := uint16(512) dnssec := false tcp := false local := true q := req.Question[0] name := strings.ToLower(q.Name) /* if q.Qtype == dns.TypeANY { m.Authoritative = false m.Rcode = dns.RcodeRefused m.RecursionAvailable = false m.RecursionDesired = false m.Compress = false // if write fails don't care w.WriteMsg(m) return }*/ if o := req.IsEdns0(); o != nil { bufsize = o.UDPSize() dnssec = o.Do() } if bufsize < 512 { bufsize = 512 } // with TCP we can send 64K if tcp = isTCP(w); tcp { bufsize = dns.MaxMsgSize - 1 } StatsRequestCount.Inc(1) if dnssec { StatsDnssecOkCount.Inc(1) } log.Debugf("[%d] Got query for '%s %s' from %s", req.Id, dns.TypeToString[q.Qtype], q.Name, w.RemoteAddr().String()) // Check cache first. m1 := s.rcache.Hit(q, dnssec, tcp, m.Id) if m1 != nil { log.Debugf("[%d] Found cached response for this query", req.Id) if tcp { if _, overflow := Fit(m1, dns.MaxMsgSize, tcp); overflow { msgFail := new(dns.Msg) s.ServerFailure(msgFail, req) w.WriteMsg(msgFail) return } } else { // Overflow with udp always results in TC. 
Fit(m1, int(bufsize), tcp) } if q.Qtype == dns.TypeA || q.Qtype == dns.TypeAAAA { s.RoundRobin(m1.Answer) } if err := w.WriteMsg(m1); err != nil { log.Errorf("Failed to return reply %q", err) } StatsCacheHit.Inc(1) return } StatsCacheMiss.Inc(1) defer func() { if local { if m.Rcode == dns.RcodeServerFailure { if err := w.WriteMsg(m); err != nil { log.Errorf("Failed to return reply %q", err) } return } if tcp { if _, overflow := Fit(m, dns.MaxMsgSize, tcp); overflow { msgFail := new(dns.Msg) s.ServerFailure(msgFail, req) w.WriteMsg(msgFail) return } } else { Fit(m, int(bufsize), tcp) } s.rcache.InsertMessage(cache.Key(q, dnssec, tcp), m) if err := w.WriteMsg(m); err != nil { log.Errorf("Failed to return reply %q", err) } } }() // Check hosts records before forwarding the query if q.Qtype == dns.TypeA || q.Qtype == dns.TypeAAAA || q.Qtype == dns.TypeANY { records, err := s.AddressRecords(q, name) if err != nil { log.Errorf("Error looking up hostsfile records: %s", err) } if len(records) > 0 { log.Debugf("[%d] Found name in hostsfile records", req.Id) m.Answer = append(m.Answer, records...) 
return } } if q.Qtype == dns.TypePTR && strings.HasSuffix(name, ".in-addr.arpa.") || strings.HasSuffix(name, ".ip6.arpa.") { local = false resp := s.ServeDNSReverse(w, req) if resp != nil { s.rcache.InsertMessage(cache.Key(q, dnssec, tcp), resp) } return } if q.Qclass == dns.ClassCHAOS { m.Authoritative = true if q.Qtype == dns.TypeTXT { switch name { case "version.bind.": fallthrough case "version.server.": hdr := dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassCHAOS, Ttl: 0} m.Answer = []dns.RR{&dns.TXT{Hdr: hdr, Txt: []string{s.version}}} return case "hostname.bind.": fallthrough case "id.server.": // TODO(miek): machine name to return hdr := dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassCHAOS, Ttl: 0} m.Answer = []dns.RR{&dns.TXT{Hdr: hdr, Txt: []string{"localhost"}}} return } } // still here, fail m.SetReply(req) m.SetRcode(req, dns.RcodeServerFailure) return } // Forward all other queries local = false resp := s.ServeDNSForward(w, req) if resp != nil { s.rcache.InsertMessage(cache.Key(q, dnssec, tcp), resp) } }
func (h *handler) getMaxResponseSize(req *dns.Msg) int { if opt := req.IsEdns0(); opt != nil { return int(opt.UDPSize()) } return h.maxResponseSize }
// ServeDNS is the handler for DNS requests, responsible for parsing DNS request, possibly forwarding // it to a real dns server and returning a response. func (s *server) ServeDNS(w dns.ResponseWriter, req *dns.Msg) { //stats.RequestCount.Inc(1) q := req.Question[0] name := strings.ToLower(q.Name) log.Printf("Received DNS Request for %q from %q with type %d", q.Name, w.RemoteAddr(), q.Qtype) if !strings.HasSuffix(name, s.config.Domain) { s.ServeDNSForward(w, req) return } m := new(dns.Msg) m.SetReply(req) m.Authoritative = true m.RecursionAvailable = true m.Answer = make([]dns.RR, 0, 10) defer func() { // Check if we need to do DNSSEC and sign the reply. if s.config.PubKey != nil { if opt := req.IsEdns0(); opt != nil && opt.Do() { s.nsec(m) s.sign(m, opt.UDPSize()) } } w.WriteMsg(m) }() if name == s.config.Domain { switch q.Qtype { case dns.TypeDNSKEY: if s.config.PubKey != nil { m.Answer = append(m.Answer, s.config.PubKey) return } case dns.TypeSOA: m.Answer = []dns.RR{s.SOA()} return } } if q.Qtype == dns.TypeA || q.Qtype == dns.TypeAAAA { records, err := s.AddressRecords(q) if err != nil { m.SetRcode(req, dns.RcodeNameError) m.Ns = []dns.RR{s.SOA()} return } m.Answer = append(m.Answer, records...) } if q.Qtype == dns.TypeSRV || q.Qtype == dns.TypeANY { records, extra, err := s.SRVRecords(q) if err != nil { // NODATA } m.Answer = append(m.Answer, records...) m.Extra = append(m.Extra, extra...) } // FIXME(miek): uh, NXDOMAIN or NODATA? if len(m.Answer) == 0 { // We are authoritative for this name, but it does not exist: NXDOMAIN m.SetRcode(req, dns.RcodeNameError) m.Ns = []dns.RR{s.SOA()} return } if len(m.Answer) == 0 { // Send back a NODATA response m.Ns = []dns.RR{s.SOA()} } }
// ServeDNS is the handler for DNS requests, responsible for parsing DNS request, possibly forwarding // it to a real dns server and returning a response. func (s *Server) ServeDNS(w dns.ResponseWriter, req *dns.Msg) { stats.RequestCount.Inc(1) q := req.Question[0] // Ensure we lowercase question so that proper matching against anchor domain takes place q.Name = strings.ToLower(q.Name) log.Printf("Received DNS Request for %q from %q with type %d", q.Name, w.RemoteAddr(), q.Qtype) // If the query does not fall in our s.domain, forward it if !strings.HasSuffix(q.Name, dns.Fqdn(s.domain)) { s.ServeDNSForward(w, req) return } m := new(dns.Msg) m.SetReply(req) m.Authoritative = true m.RecursionAvailable = true m.Answer = make([]dns.RR, 0, 10) defer func() { // Check if we need to do DNSSEC and sign the reply if s.PublicKey() != nil { if opt := req.IsEdns0(); opt != nil && opt.Do() { s.nsec(m) s.sign(m, opt.UDPSize()) } } w.WriteMsg(m) }() if q.Name == dns.Fqdn(s.domain) { switch q.Qtype { case dns.TypeDNSKEY: if s.PublicKey() != nil { m.Answer = append(m.Answer, s.PublicKey()) return } case dns.TypeSOA: m.Answer = s.createSOA() return } } if q.Qtype == dns.TypeA || q.Qtype == dns.TypeAAAA { records, err := s.getARecords(q) if err != nil { m.SetRcode(req, dns.RcodeNameError) m.Ns = s.createSOA() return } if s.roundrobin { switch l := uint16(len(records)); l { case 1: case 2: if dns.Id()%2 == 0 { records[0], records[1] = records[1], records[0] } default: // Do a minimum of l swap, maximum of 4l swaps for j := 0; j < int(l*(dns.Id()%4+1)); j++ { q := dns.Id() % l p := dns.Id() % l if q == p { p = (p + 1) % l } records[q], records[p] = records[p], records[q] } } } m.Answer = append(m.Answer, records...) 
} records, extra, err := s.getSRVRecords(q) if err != nil && len(m.Answer) == 0 { // We are authoritative for this name, but it does not exist: NXDOMAIN m.SetRcode(req, dns.RcodeNameError) m.Ns = s.createSOA() return } if q.Qtype == dns.TypeANY || q.Qtype == dns.TypeSRV { m.Answer = append(m.Answer, records...) m.Extra = append(m.Extra, extra...) } if len(m.Answer) == 0 { // Send back a NODATA response m.Ns = s.createSOA() } }
// ServeDNS answers a query addressed to the embedded resolver.
//
// A, AAAA, PTR and SRV questions are first passed to the local handlers
// (handleIPQuery / handlePTRQuery / handleSRVQuery). A handler error is
// logged and the query is dropped without a reply. When no local response
// was produced:
//   - if r.proxyDNS is false, a SERVFAIL reply is written;
//   - for an A/AAAA query of a single-label name while r.backend.NdotsSet()
//     is true, the reply built by createRespMsg is used instead of
//     forwarding;
//   - otherwise the query is forwarded to the entries of r.extDNSList in
//     order, stopping at the first empty ipStr, until one answers. Every A
//     and AAAA record in a forwarded answer is reported to
//     r.backend.HandleQueryResp.
//
// Queries without a question are dropped without a reply.
func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
	var (
		extConn net.Conn
		resp    *dns.Msg
		err     error
	)

	if query == nil || len(query.Question) == 0 {
		return
	}
	name := query.Question[0].Name

	switch query.Question[0].Qtype {
	case dns.TypeA:
		resp, err = r.handleIPQuery(name, query, types.IPv4)
	case dns.TypeAAAA:
		resp, err = r.handleIPQuery(name, query, types.IPv6)
	case dns.TypePTR:
		resp, err = r.handlePTRQuery(name, query)
	case dns.TypeSRV:
		resp, err = r.handleSRVQuery(name, query)
	}

	if err != nil {
		logrus.Error(err)
		return
	}

	if resp == nil {
		// If the backend doesn't support proxying dns request
		// fail the response
		// NOTE(review): the WriteMsg error is ignored on this path.
		if !r.proxyDNS {
			resp = new(dns.Msg)
			resp.SetRcode(query, dns.RcodeServerFailure)
			w.WriteMsg(resp)
			return
		}

		// If the user sets ndots > 0 explicitly and the query is
		// in the root domain don't forward it out. We will return
		// failure and let the client retry with the search domain
		// attached
		switch query.Question[0].Qtype {
		case dns.TypeA:
			fallthrough
		case dns.TypeAAAA:
			if r.backend.NdotsSet() && !strings.Contains(strings.TrimSuffix(name, "."), ".") {
				resp = createRespMsg(query)
			}
		}
	}

	// Largest reply the client can receive. Over TCP this is
	// dns.MaxMsgSize-1. Over UDP it is the client's EDNS0 payload size, but
	// never less than defaultRespSize.
	// NOTE(review): for any other network type maxSize stays 0.
	proto := w.LocalAddr().Network()
	maxSize := 0
	if proto == "tcp" {
		maxSize = dns.MaxMsgSize - 1
	} else if proto == "udp" {
		optRR := query.IsEdns0()
		if optRR != nil {
			maxSize = int(optRR.UDPSize())
		}
		if maxSize < defaultRespSize {
			maxSize = defaultRespSize
		}
	}

	if resp != nil {
		// Local answer: trim it to what the transport can carry.
		if resp.Len() > maxSize {
			truncateResp(resp, maxSize, proto == "tcp")
		}
	} else {
		// No local answer: try the external servers in order.
		for i := 0; i < maxExtDNS; i++ {
			extDNS := &r.extDNSList[i]
			if extDNS.ipStr == "" {
				break
			}
			// extConnect dials the server on port 53 over the client's
			// transport and stores the result in the enclosing extConn/err.
			extConnect := func() {
				addr := fmt.Sprintf("%s:%d", extDNS.ipStr, 53)
				extConn, err = net.DialTimeout(proto, addr, extIOTimeout)
			}

			// Servers flagged hostLoopback are dialed directly; all others
			// are dialed through r.backend.ExecFunc.
			if extDNS.hostLoopback {
				extConnect()
			} else {
				execErr := r.backend.ExecFunc(extConnect)
				if execErr != nil {
					logrus.Warn(execErr)
					continue
				}
			}
			if err != nil {
				logrus.Warnf("Connect failed: %s", err)
				continue
			}
			logrus.Debugf("Query %s[%d] from %s, forwarding to %s:%s", name, query.Question[0].Qtype,
				extConn.LocalAddr().String(), proto, extDNS.ipStr)

			// Timeout has to be set for every IO operation.
			extConn.SetDeadline(time.Now().Add(extIOTimeout))
			co := &dns.Conn{
				Conn:    extConn,
				UDPSize: uint16(maxSize),
			}
			// NOTE(review): this defer runs when ServeDNS returns, not at the
			// end of the iteration. Connections to servers that failed stay
			// open until then.
			defer co.Close()

			// limits the number of outstanding concurrent queries.
			// When the limit is hit, the error is logged at most once per
			// logInterval and the next server is tried.
			// NOTE(review): r.tStamp is read and written without a lock; if
			// ServeDNS runs concurrently this is a data race. Confirm.
			if r.forwardQueryStart() == false {
				old := r.tStamp
				r.tStamp = time.Now()
				if r.tStamp.Sub(old) > logInterval {
					logrus.Errorf("More than %v concurrent queries from %s", maxConcurrent, extConn.LocalAddr().String())
				}
				continue
			}

			// Every successful forwardQueryStart is paired with exactly one
			// forwardQueryEnd on each path below.
			err = co.WriteMsg(query)
			if err != nil {
				r.forwardQueryEnd()
				logrus.Debugf("Send to DNS server failed, %s", err)
				continue
			}

			resp, err = co.ReadMsg()
			// Truncated DNS replies should be sent to the client so that the
			// client can retry over TCP
			if err != nil && err != dns.ErrTruncated {
				r.forwardQueryEnd()
				logrus.Debugf("Read from DNS server failed, %s", err)
				continue
			}
			r.forwardQueryEnd()
			// Report every A/AAAA record in the upstream answer to the backend.
			// NOTE(review): this nil check does not guard the
			// resp.Compress assignment that follows it.
			if resp != nil {
				for _, rr := range resp.Answer {
					h := rr.Header()
					switch h.Rrtype {
					case dns.TypeA:
						ip := rr.(*dns.A).A
						r.backend.HandleQueryResp(h.Name, ip)
					case dns.TypeAAAA:
						ip := rr.(*dns.AAAA).AAAA
						r.backend.HandleQueryResp(h.Name, ip)
					}
				}
			}
			resp.Compress = true
			break
		}
		// No server answered: send nothing.
		if resp == nil {
			return
		}
	}

	if err = w.WriteMsg(resp); err != nil {
		logrus.Errorf("error writing resolver resp, %s", err)
	}
}
// ServeDNS is the handler for DNS requests, responsible for parsing DNS request, possibly forwarding // it to a real dns server and returning a response. func (s *server) ServeDNS(w dns.ResponseWriter, req *dns.Msg) { m := new(dns.Msg) m.SetReply(req) m.Authoritative = true m.RecursionAvailable = true m.Compress = true bufsize := uint16(512) dnssec := false tcp := false start := time.Now() q := req.Question[0] name := strings.ToLower(q.Name) if q.Qtype == dns.TypeANY { m.Authoritative = false m.Rcode = dns.RcodeRefused m.RecursionAvailable = false m.RecursionDesired = false m.Compress = false // if write fails don't care w.WriteMsg(m) promErrorCount.WithLabelValues("refused").Inc() return } if o := req.IsEdns0(); o != nil { bufsize = o.UDPSize() dnssec = o.Do() } if bufsize < 512 { bufsize = 512 } // with TCP we can send 64K if tcp = isTCP(w); tcp { bufsize = dns.MaxMsgSize - 1 promRequestCount.WithLabelValues("tcp").Inc() } else { promRequestCount.WithLabelValues("udp").Inc() } StatsRequestCount.Inc(1) if dnssec { StatsDnssecOkCount.Inc(1) promDnssecOkCount.Inc() } if s.config.Verbose { logf("received DNS Request for %q from %q with type %d", q.Name, w.RemoteAddr(), q.Qtype) } // Check cache first. m1 := s.rcache.Hit(q, dnssec, tcp, m.Id) if m1 != nil { if tcp { if _, overflow := Fit(m1, dns.MaxMsgSize, tcp); overflow { promErrorCount.WithLabelValues("overflow").Inc() msgFail := new(dns.Msg) s.ServerFailure(msgFail, req) w.WriteMsg(msgFail) return } } else { // Overflow with udp always results in TC. Fit(m1, int(bufsize), tcp) if m1.Truncated { promErrorCount.WithLabelValues("truncated").Inc() } } // Still round-robin even with hits from the cache. // Only shuffle A and AAAA records with each other. 
if q.Qtype == dns.TypeA || q.Qtype == dns.TypeAAAA { s.RoundRobin(m1.Answer) } if err := w.WriteMsg(m1); err != nil { logf("failure to return reply %q", err) } metricSizeAndDuration(m1, start, tcp) return } for zone, ns := range *s.config.stub { if strings.HasSuffix(name, zone) { resp := s.ServeDNSStubForward(w, req, ns) if resp != nil { s.rcache.InsertMessage(cache.Key(q, dnssec, tcp), resp) metricSizeAndDuration(resp, start, tcp) } return } } // If the qname is local.ds.skydns.local. and s.config.Local != "", substitute that name. if s.config.Local != "" && name == s.config.localDomain { name = s.config.Local } if q.Qtype == dns.TypePTR && strings.HasSuffix(name, ".in-addr.arpa.") || strings.HasSuffix(name, ".ip6.arpa.") { resp := s.ServeDNSReverse(w, req) if resp != nil { s.rcache.InsertMessage(cache.Key(q, dnssec, tcp), resp) metricSizeAndDuration(resp, start, tcp) } return } if q.Qclass != dns.ClassCHAOS && !strings.HasSuffix(name, s.config.Domain) { resp := s.ServeDNSForward(w, req) if resp != nil { s.rcache.InsertMessage(cache.Key(q, dnssec, tcp), resp) metricSizeAndDuration(resp, start, tcp) } return } promCacheMiss.WithLabelValues("response").Inc() defer func() { if m.Rcode == dns.RcodeServerFailure { if err := w.WriteMsg(m); err != nil { logf("failure to return reply %q", err) } return } // Set TTL to the minimum of the RRset and dedup the message, i.e. remove identical RRs. 
m = s.dedup(m) minttl := s.config.Ttl if len(m.Answer) > 1 { for _, r := range m.Answer { if r.Header().Ttl < minttl { minttl = r.Header().Ttl } } for _, r := range m.Answer { r.Header().Ttl = minttl } } if dnssec { if s.config.PubKey != nil { m.AuthenticatedData = true s.Denial(m) s.Sign(m, bufsize) } } if tcp { if _, overflow := Fit(m, dns.MaxMsgSize, tcp); overflow { msgFail := new(dns.Msg) s.ServerFailure(msgFail, req) w.WriteMsg(msgFail) return } } else { Fit(m, int(bufsize), tcp) if m.Truncated { promErrorCount.WithLabelValues("truncated").Inc() } } s.rcache.InsertMessage(cache.Key(q, dnssec, tcp), m) if err := w.WriteMsg(m); err != nil { logf("failure to return reply %q", err) } metricSizeAndDuration(m, start, tcp) }() if name == s.config.Domain { if q.Qtype == dns.TypeSOA { m.Answer = []dns.RR{s.NewSOA()} return } if q.Qtype == dns.TypeDNSKEY { if s.config.PubKey != nil { m.Answer = []dns.RR{s.config.PubKey} return } } } if q.Qclass == dns.ClassCHAOS { if q.Qtype == dns.TypeTXT { switch name { case "authors.bind.": fallthrough case s.config.Domain: hdr := dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassCHAOS, Ttl: 0} authors := []string{"Erik St. 
Martin", "Brian Ketelsen", "Miek Gieben", "Michael Crosby"} for _, a := range authors { m.Answer = append(m.Answer, &dns.TXT{Hdr: hdr, Txt: []string{a}}) } for j := 0; j < len(authors)*(int(dns.Id())%4+1); j++ { q := int(dns.Id()) % len(authors) p := int(dns.Id()) % len(authors) if q == p { p = (p + 1) % len(authors) } m.Answer[q], m.Answer[p] = m.Answer[p], m.Answer[q] } return case "version.bind.": fallthrough case "version.server.": hdr := dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassCHAOS, Ttl: 0} m.Answer = []dns.RR{&dns.TXT{Hdr: hdr, Txt: []string{Version}}} return case "hostname.bind.": fallthrough case "id.server.": // TODO(miek): machine name to return hdr := dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassCHAOS, Ttl: 0} m.Answer = []dns.RR{&dns.TXT{Hdr: hdr, Txt: []string{"localhost"}}} return } } // still here, fail m.SetReply(req) m.SetRcode(req, dns.RcodeServerFailure) return } switch q.Qtype { case dns.TypeNS: if name != s.config.Domain { break } // Lookup s.config.DnsDomain records, extra, err := s.NSRecords(q, s.config.dnsDomain) if isEtcdNameError(err, s) { s.NameError(m, req) return } m.Answer = append(m.Answer, records...) m.Extra = append(m.Extra, extra...) case dns.TypeA, dns.TypeAAAA: records, err := s.AddressRecords(q, name, nil, bufsize, dnssec, false) if isEtcdNameError(err, s) { s.NameError(m, req) return } m.Answer = append(m.Answer, records...) case dns.TypeTXT: records, err := s.TXTRecords(q, name) if isEtcdNameError(err, s) { s.NameError(m, req) return } m.Answer = append(m.Answer, records...) case dns.TypeCNAME: records, err := s.CNAMERecords(q, name) if isEtcdNameError(err, s) { s.NameError(m, req) return } m.Answer = append(m.Answer, records...) case dns.TypeMX: records, extra, err := s.MXRecords(q, name, bufsize, dnssec) if isEtcdNameError(err, s) { s.NameError(m, req) return } m.Answer = append(m.Answer, records...) m.Extra = append(m.Extra, extra...) 
default: fallthrough // also catch other types, so that they return NODATA case dns.TypeSRV: records, extra, err := s.SRVRecords(q, name, bufsize, dnssec) if err != nil { if isEtcdNameError(err, s) { s.NameError(m, req) return } logf("got error from backend: %s", err) if q.Qtype == dns.TypeSRV { // Otherwise NODATA s.ServerFailure(m, req) return } } // if we are here again, check the types, because an answer may only // be given for SRV. All other types should return NODATA, the // NXDOMAIN part is handled in the above code. TODO(miek): yes this // can be done in a more elegant manor. if q.Qtype == dns.TypeSRV { m.Answer = append(m.Answer, records...) m.Extra = append(m.Extra, extra...) } } if len(m.Answer) == 0 { // NODATA response StatsNoDataCount.Inc(1) m.Ns = []dns.RR{s.NewSOA()} m.Ns[0].Header().Ttl = s.config.MinTtl } }