// Forwards a DNS request to the specified nameservers. On success, the // original reply packet will be returned to the caller. On failure, a // new packet will be returned with `RCODE` set to `SERVFAIL`. // Even though the original `ResponseWriter` object is taken as an argument, // this function does not send a reply to the client. Instead, the // packet is returned for further processing by the caller. func getServerReply(w dns.ResponseWriter, req *dns.Msg) *dns.Msg { if *verbose { log.Print("Forwarding ", req.Question[0].Name, "/", dns.Type(req.Question[0].Qtype).String()) } // create a new DNS client client := &dns.Client{Net: "udp", ReadTimeout: 4 * time.Second, WriteTimeout: 4 * time.Second, SingleInflight: true} if _, tcp := w.RemoteAddr().(*net.TCPAddr); tcp { client.Net = "tcp" } var r *dns.Msg var err error // loop through the specified nameservers and forward them the request // the request ID is used as a starting point in order to introduce at least // some element of randomness, instead of always hitting the first nameserver for i := 0; i < len(nameservers); i++ { r, _, err = client.Exchange(req, nameservers[(int(req.Id)+i)%len(nameservers)]) if err == nil { r.Compress = true return r } } log.Print("Failed to forward request.", err) return getEmptyMsg(w, req, dns.RcodeServerFailure) }
func (p ReverseProxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg, extra []dns.RR) error { var ( reply *dns.Msg err error ) switch { case middleware.Proto(w) == "tcp": reply, err = middleware.Exchange(p.Client.TCP, r, p.Host) default: reply, err = middleware.Exchange(p.Client.UDP, r, p.Host) } if reply != nil && reply.Truncated { // Suppress proxy error for truncated responses err = nil } if err != nil { return err } reply.Compress = true reply.Id = r.Id w.WriteMsg(reply) return nil }
// ServeDNSForward forwards a request to a nameservers and returns the response. func (s *server) ServeDNSForward(w dns.ResponseWriter, req *dns.Msg) *dns.Msg { if s.config.NoRec { m := s.ServerFailure(req) w.WriteMsg(m) return m } if len(s.config.Nameservers) == 0 || dns.CountLabel(req.Question[0].Name) < s.config.Ndots { if s.config.Verbose { if len(s.config.Nameservers) == 0 { logf("can not forward, no nameservers defined") } else { logf("can not forward, name too short (less than %d labels): `%s'", s.config.Ndots, req.Question[0].Name) } } m := s.ServerFailure(req) m.RecursionAvailable = true // this is still true w.WriteMsg(m) return m } tcp := isTCP(w) var ( r *dns.Msg err error try int ) nsid := 0 if s.config.NSRotate { // Use request Id for "random" nameserver selection. nsid = int(req.Id) % len(s.config.Nameservers) } Redo: switch tcp { case false: r, _, err = s.dnsUDPclient.Exchange(req, s.config.Nameservers[nsid]) case true: r, _, err = s.dnsTCPclient.Exchange(req, s.config.Nameservers[nsid]) } if err == nil { r.Compress = true r.Id = req.Id w.WriteMsg(r) return r } // Seen an error, this can only mean, "server not reached", try again // but only if we have not exausted our nameservers. if try < len(s.config.Nameservers) { try++ nsid = (nsid + 1) % len(s.config.Nameservers) goto Redo } logf("failure to forward request %q", err) m := s.ServerFailure(req) return m }
// handlerFunc receives requests, looks up the result and returns what is found. func handlerFunc(res dns.ResponseWriter, req *dns.Msg) { message := new(dns.Msg) switch req.Opcode { case dns.OpcodeQuery: message.SetReply(req) message.Compress = false message.Answer = make([]dns.RR, 0) for _, question := range message.Question { answers := answerQuestion(strings.ToLower(question.Name), question.Qtype) if len(answers) > 0 { for i := range answers { message.Answer = append(message.Answer, answers[i]) } } else { // If there are no records, go back through and search for SOA records for _, question := range message.Question { answers := answerQuestion(strings.ToLower(question.Name), dns.TypeSOA) for i := range answers { message.Ns = append(message.Ns, answers[i]) } } } } if len(message.Answer) == 0 && len(message.Ns) == 0 { message.Rcode = dns.RcodeNameError } default: message = message.SetRcode(req, dns.RcodeNotImplemented) } res.WriteMsg(message) }
// handleDNS is a handler function to actualy perform the dns querey response func (c *CatchAll) handleDNS(w dns.ResponseWriter, r *dns.Msg) { defer w.Close() var rr dns.RR domainSpoof := r.Question[0].Name msgResp := new(dns.Msg) msgResp.SetReply(r) msgResp.Compress = false rr = new(dns.A) if c.SpoofDomain { rr.(*dns.A).Hdr = dns.RR_Header{Name: domainSpoof, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 0} } else { rr.(*dns.A).Hdr = dns.RR_Header{Name: c.Domain, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 0} } rr.(*dns.A).A = c.IP switch r.Question[0].Qtype { case dns.TypeA: msgResp.Answer = append(msgResp.Answer, rr) default: log.Warnf("Unknown dns type %T", r.Question[0].Qtype) return } w.WriteMsg(msgResp) }
// reply writes the given dns.Msg out to the given dns.ResponseWriter, // compressing the message first and truncating it accordingly. func reply(w dns.ResponseWriter, m *dns.Msg) { m.Compress = true // https://github.com/mesosphere/mesos-dns/issues/{170,173,174} if err := w.WriteMsg(truncate(m, isUDP(w))); err != nil { logging.Error.Println(err) } }
func handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) { if *debug { Log.Printf("handleRequest: message: %+v\n", r) } m := new(dns.Msg) m.SetReply(r) m.Compress = false switch r.Opcode { case dns.OpcodeQuery: parseQuery(m) case dns.OpcodeUpdate: for _, question := range r.Question { for _, rr := range r.Ns { updateRecord(rr, &question) } } } if r.IsTsig() != nil { if w.TsigStatus() == nil { m.SetTsig(r.Extra[len(r.Extra)-1].(*dns.TSIG).Hdr.Name, dns.HmacMD5, 300, time.Now().Unix()) } else { Log.Println("Status", w.TsigStatus().Error()) } } w.WriteMsg(m) }
// handleQuery is used to handle DNS queries in the configured domain func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) { q := req.Question[0] defer func(s time.Time) { metrics.MeasureSince([]string{"consul", "dns", "domain_query", d.agent.config.NodeName}, s) d.logger.Printf("[DEBUG] dns: request for %v (%v) from client %s (%s)", q, time.Now().Sub(s), resp.RemoteAddr().String(), resp.RemoteAddr().Network()) }(time.Now()) // Switch to TCP if the client is network := "udp" if _, ok := resp.RemoteAddr().(*net.TCPAddr); ok { network = "tcp" } // Setup the message response m := new(dns.Msg) m.SetReply(req) m.Compress = !d.config.DisableCompression m.Authoritative = true m.RecursionAvailable = (len(d.recursors) > 0) // Only add the SOA if requested if req.Question[0].Qtype == dns.TypeSOA { d.addSOA(d.domain, m) } // Dispatch the correct handler d.dispatch(network, req, m) // Write out the complete response if err := resp.WriteMsg(m); err != nil { d.logger.Printf("[WARN] dns: failed to respond: %v", err) } }
// ServeDNSStubForward forwards a request to a nameservers and returns the response. func (s *server) ServeDNSStubForward(w dns.ResponseWriter, req *dns.Msg, ns []string) *dns.Msg { // Check EDNS0 Stub option, if set drop the packet. option := req.IsEdns0() if option != nil { for _, o := range option.Option { if o.Option() == ednsStubCode && len(o.(*dns.EDNS0_LOCAL).Data) == 1 && o.(*dns.EDNS0_LOCAL).Data[0] == 1 { // Maybe log source IP here? logf("not fowarding stub request to another stub") return nil } } } tcp := isTCP(w) // Add a custom EDNS0 option to the packet, so we can detect loops // when 2 stubs are forwarding to each other. if option != nil { option.Option = append(option.Option, &dns.EDNS0_LOCAL{ednsStubCode, []byte{1}}) } else { req.Extra = append(req.Extra, ednsStub) } var ( r *dns.Msg err error try int ) // Use request Id for "random" nameserver selection. nsid := int(req.Id) % len(ns) Redo: switch tcp { case false: r, _, err = s.dnsUDPclient.Exchange(req, ns[nsid]) case true: r, _, err = s.dnsTCPclient.Exchange(req, ns[nsid]) } if err == nil { r.Compress = true r.Id = req.Id w.WriteMsg(r) return r } // Seen an error, this can only mean, "server not reached", try again // but only if we have not exausted our nameservers. if try < len(ns) { try++ nsid = (nsid + 1) % len(ns) goto Redo } logf("failure to forward stub request %q", err) m := s.ServerFailure(req) w.WriteMsg(m) return m }
func CompressIfLarge(m *dns.Msg) { bytes, err := m.Pack() if err != nil { return } if len(bytes) > 512 { // may not fit into UDP packet m.Compress = true // will be compressed in WriteMsg } }
func (s *Server) handle(w dns.ResponseWriter, request *dns.Msg) { // Always close the writer defer w.Close() // Capture starting time for measuring message response time var start time.Time start = time.Now() // Setup the default response var response *dns.Msg response = &dns.Msg{} response.SetReply(request) response.Compress = true // Lookup answers to any of the questions for _, question := range request.Question { // Capture starting time for measuring lookup var lookupStart time.Time lookupStart = time.Now() // Perform lookup for this question var records []dns.RR records = s.registry.Lookup(question.Name, question.Qtype, question.Qclass) // Capture ending and elapsed time var lookupElapsed time.Duration lookupElapsed = time.Since(lookupStart) // Append results to the response response.Answer = append(response.Answer, records...) // If StatsD is enabled, record some metrics if s.statsd != nil { var tags []string tags = []string{ fmt.Sprintf("name:%s", question.Name), fmt.Sprintf("qtype:%s", dns.TypeToString[question.Qtype]), fmt.Sprintf("qclass:%s", dns.ClassToString[question.Qclass]), } s.statsd.TimeInMilliseconds("lookup.time", lookupElapsed.Seconds()*1000.0, tags, 1) s.statsd.Histogram("lookup.answer", float64(len(records)), tags, 1) s.statsd.Count("request.question", 1, tags, 1) } } // Respond to the request w.WriteMsg(response) // Record any ending metrics if s.statsd != nil { var elapsed time.Duration elapsed = time.Since(start) s.statsd.TimeInMilliseconds("request.time", elapsed.Seconds()*1000.0, nil, 1) } }
// HandleForwarding forwards a request to the nameservers and returns the response func (s *Server) HandleForwarding(w mdns.ResponseWriter, r *mdns.Msg) (bool, error) { defer trace.End(trace.Begin(r.String())) var m *mdns.Msg var err error var try int if len(s.Nameservers) == 0 { log.Errorf("No nameservers defined, can not forward") return false, respServerFailure(w, r) } // which protocol are they talking tcp := false if _, ok := w.RemoteAddr().(*net.TCPAddr); ok { tcp = true } // Use request ID for "random" nameserver selection. nsid := int(r.Id) % len(s.Nameservers) Redo: nameserver := s.Nameservers[nsid] if i := strings.Index(nameserver, ":"); i < 0 { nameserver += ":53" } if tcp { m, _, err = s.tcpclient.Exchange(r, nameserver) } else { m, _, err = s.udpclient.Exchange(r, nameserver) } if err != nil { // Seen an error, this can only mean, "server not reached", try again but only if we have not exausted our nameservers. if try < len(s.Nameservers) { try++ nsid = (nsid + 1) % len(s.Nameservers) goto Redo } log.Errorf("Failure to forward request: %q", err) return false, respServerFailure(w, r) } // We have a response so cache it s.cache.Add(m) m.Compress = true if err := w.WriteMsg(m); err != nil { log.Errorf("Error writing response: %q", err) return true, err } return true, nil }
func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) { var ( resp *dns.Msg err error ) if query == nil || len(query.Question) == 0 { return } name := query.Question[0].Name if query.Question[0].Qtype == dns.TypeA { resp, err = r.handleIPv4Query(name, query) } else if query.Question[0].Qtype == dns.TypePTR { resp, err = r.handlePTRQuery(name, query) } if err != nil { log.Error(err) return } if resp == nil { if len(r.extDNS) == 0 { return } num := maxExtDNS if len(r.extDNS) < maxExtDNS { num = len(r.extDNS) } for i := 0; i < num; i++ { log.Debugf("Querying ext dns %s:%s for %s[%d]", w.LocalAddr().Network(), r.extDNS[i], name, query.Question[0].Qtype) c := &dns.Client{Net: w.LocalAddr().Network()} addr := fmt.Sprintf("%s:%d", r.extDNS[i], 53) resp, _, err = c.Exchange(query, addr) if err == nil { resp.Compress = true break } log.Errorf("external resolution failed, %s", err) } if resp == nil { return } } err = w.WriteMsg(resp) if err != nil { log.Errorf("error writing resolver resp, %s", err) } }
// ServeDNSForward forwards a request to a nameservers and returns the response. func (s *server) ServeDNSForward(w dns.ResponseWriter, req *dns.Msg) *dns.Msg { if s.config.NoRec { m := s.ServerFailure(req) w.WriteMsg(m) return m } if len(s.config.Nameservers) == 0 || dns.CountLabel(req.Question[0].Name) < s.config.Ndots { if s.config.Verbose { if len(s.config.Nameservers) == 0 { logf("can not forward, no nameservers defined") } else { logf("can not forward, name too short (less than %d labels): `%s'", s.config.Ndots, req.Question[0].Name) } } m := s.ServerFailure(req) m.RecursionAvailable = true // this is still true w.WriteMsg(m) return m } var ( r *dns.Msg err error ) nsid := s.randomNameserverID(req.Id) try := 0 Redo: if isTCP(w) { r, err = exchangeWithRetry(s.dnsTCPclient, req, s.config.Nameservers[nsid]) } else { r, err = exchangeWithRetry(s.dnsUDPclient, req, s.config.Nameservers[nsid]) } if err == nil { r.Compress = true r.Id = req.Id w.WriteMsg(r) return r } // Seen an error, this can only mean, "server not reached", try again // but only if we have not exausted our nameservers. if try < len(s.config.Nameservers) { try++ nsid = (nsid + 1) % len(s.config.Nameservers) goto Redo } logf("failure to forward request %q", err) m := s.ServerFailure(req) return m }
func handleForwardingRaw(nameservers []string, req *dns.Msg, remote net.Addr) *dns.Msg { if len(nameservers) == 0 { log.Printf("no nameservers defined, can not forward\n") m := new(dns.Msg) m.SetReply(req) m.SetRcode(req, dns.RcodeServerFailure) m.Authoritative = false // no matter what set to false m.RecursionAvailable = true // and this is still true return m } tcp := false if _, ok := remote.(*net.TCPAddr); ok { tcp = true } var ( r *dns.Msg err error try int ) // Use request Id for "random" nameserver selection. nsid := int(req.Id) % len(nameservers) dnsClient := &dns.Client{Net: "udp", ReadTimeout: 4 * time.Second, WriteTimeout: 4 * time.Second, SingleInflight: true} if tcp { dnsClient.Net = "tcp" } Redo: nameserver := nameservers[nsid] if i := strings.Index(nameserver, ":"); i < 0 { nameserver += ":53" } r, _, err = dnsClient.Exchange(req, nameserver) if err == nil { r.Compress = true return r } // Seen an error, this can only mean, "server not reached", try again // but only if we have not exausted our nameservers. if try < len(nameservers) { try++ nsid = (nsid + 1) % len(nameservers) goto Redo } log.Printf("failure to forward request %q\n", err) m := new(dns.Msg) m.SetReply(req) m.SetRcode(req, dns.RcodeServerFailure) return m }
// trimUDPResponse makes sure a UDP response is not longer than allowed by RFC // 1035. Enforce an arbitrary limit that can be further ratcheted down by // config, and then make sure the response doesn't exceed 512 bytes. Any extra // records will be trimmed along with answers. func trimUDPResponse(config *DNSConfig, resp *dns.Msg) (trimmed bool) { numAnswers := len(resp.Answer) hasExtra := len(resp.Extra) > 0 // We avoid some function calls and allocations by only handling the // extra data when necessary. var index map[string]dns.RR if hasExtra { index = make(map[string]dns.RR, len(resp.Extra)) indexRRs(resp.Extra, index) } // This cuts UDP responses to a useful but limited number of responses. maxAnswers := lib.MinInt(maxUDPAnswerLimit, config.UDPAnswerLimit) if numAnswers > maxAnswers { resp.Answer = resp.Answer[:maxAnswers] if hasExtra { syncExtra(index, resp) } } // This enforces the hard limit of 512 bytes per the RFC. Note that we // temporarily switch to uncompressed so that we limit to a response // that will not exceed 512 bytes uncompressed, which is more // conservative and will allow our responses to be compliant even if // some downstream server uncompresses them. compress := resp.Compress resp.Compress = false for len(resp.Answer) > 0 && resp.Len() > 512 { resp.Answer = resp.Answer[:len(resp.Answer)-1] if hasExtra { syncExtra(index, resp) } } resp.Compress = compress return len(resp.Answer) < numAnswers }
// ServeDNSReverse is the handler for DNS requests for the reverse zone. If nothing is found // locally the request is forwarded to the forwarder for resolution. func (s *server) ServeDNSReverse(w dns.ResponseWriter, req *dns.Msg) *dns.Msg { m := new(dns.Msg) m.SetReply(req) m.Compress = true m.Authoritative = false m.RecursionAvailable = true if records, err := s.PTRRecords(req.Question[0]); err == nil && len(records) > 0 { m.Answer = records writeMsg(w, m) return m } // Always forward if not found locally. return s.ServeDNSForward(w, req) }
// ServeDNSForward forwards a request to a nameservers and returns the response. func (s *server) ServeDNSForward(w dns.ResponseWriter, req *dns.Msg) { StatsForwardCount.Inc(1) if len(s.config.Nameservers) == 0 || dns.CountLabel(req.Question[0].Name) < s.config.Ndots { s.config.log.Infof("no nameservers defined or name too short, can not forward") m := new(dns.Msg) m.SetReply(req) m.SetRcode(req, dns.RcodeServerFailure) m.Authoritative = false // no matter what set to false m.RecursionAvailable = true // and this is still true w.WriteMsg(m) return } tcp := false if _, ok := w.RemoteAddr().(*net.TCPAddr); ok { tcp = true } var ( r *dns.Msg err error try int ) // Use request Id for "random" nameserver selection. nsid := int(req.Id) % len(s.config.Nameservers) Redo: switch tcp { case false: r, _, err = s.dnsUDPclient.Exchange(req, s.config.Nameservers[nsid]) case true: r, _, err = s.dnsTCPclient.Exchange(req, s.config.Nameservers[nsid]) } if err == nil { r.Compress = true w.WriteMsg(r) return } // Seen an error, this can only mean, "server not reached", try again // but only if we have not exausted our nameservers. if try < len(s.config.Nameservers) { try++ nsid = (nsid + 1) % len(s.config.Nameservers) goto Redo } s.config.log.Errorf("failure to forward request %q", err) m := new(dns.Msg) m.SetReply(req) m.SetRcode(req, dns.RcodeServerFailure) w.WriteMsg(m) }
func Respond(w dns.ResponseWriter, req *dns.Msg, records []dns.RR) { m := new(dns.Msg) m.SetReply(req) m.Authoritative = true m.RecursionAvailable = true m.Compress = true m.Answer = records // Figure out the max response size bufsize := uint16(512) tcp := isTcp(w) if o := req.IsEdns0(); o != nil { bufsize = o.UDPSize() } if tcp { bufsize = dns.MaxMsgSize - 1 } else if bufsize < 512 { bufsize = 512 } if m.Len() > dns.MaxMsgSize { fqdn := dns.Fqdn(req.Question[0].Name) log.WithFields(log.Fields{"fqdn": fqdn}).Debug("Response too big, dropping Extra") m.Extra = nil if m.Len() > dns.MaxMsgSize { log.WithFields(log.Fields{"fqdn": fqdn}).Debug("Response still too big") m := new(dns.Msg) m.SetRcode(m, dns.RcodeServerFailure) } } if m.Len() > int(bufsize) && !tcp { log.Debug("Too big 1") m.Extra = nil if m.Len() > int(bufsize) { log.Debug("Too big 2") m.Answer = nil m.Truncated = true } } err := w.WriteMsg(m) if err != nil { log.Warn("Failed to return reply: ", err, m.Len()) } }
// handleRecurse is used to handle recursive DNS queries func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) { q := req.Question[0] network := "udp" defer func(s time.Time) { d.logger.Printf("[DEBUG] dns: request for %v (%s) (%v) from client %s (%s)", q, network, time.Now().Sub(s), resp.RemoteAddr().String(), resp.RemoteAddr().Network()) }(time.Now()) // Switch to TCP if the client is if _, ok := resp.RemoteAddr().(*net.TCPAddr); ok { network = "tcp" } // Recursively resolve c := &dns.Client{Net: network, Timeout: d.config.RecursorTimeout} var r *dns.Msg var rtt time.Duration var err error for _, recursor := range d.recursors { r, rtt, err = c.Exchange(req, recursor) if err == nil { // Compress the response; we don't know if the incoming // response was compressed or not, so by not compressing // we might generate an invalid packet on the way out. r.Compress = !d.config.DisableCompression // Forward the response d.logger.Printf("[DEBUG] dns: recurse RTT for %v (%v)", q, rtt) if err := resp.WriteMsg(r); err != nil { d.logger.Printf("[WARN] dns: failed to respond: %v", err) } return } d.logger.Printf("[ERR] dns: recurse failed: %v", err) } // If all resolvers fail, return a SERVFAIL message d.logger.Printf("[ERR] dns: all resolvers failed for %v from client %s (%s)", q, resp.RemoteAddr().String(), resp.RemoteAddr().Network()) m := &dns.Msg{} m.SetReply(req) m.Compress = !d.config.DisableCompression m.RecursionAvailable = true m.SetRcode(req, dns.RcodeServerFailure) resp.WriteMsg(m) }
// ServeDNSReverse is the handler for DNS requests for the reverse zone. If nothing is found // locally the request is forwarded to the forwarder for resolution. func (s *server) ServeDNSReverse(w dns.ResponseWriter, req *dns.Msg) { m := new(dns.Msg) m.SetReply(req) m.Compress = true m.Authoritative = false // Set to false, because I don't know what to do wrt DNSSEC. m.RecursionAvailable = true var err error if m.Answer, err = s.PTRRecords(req.Question[0]); err == nil { // TODO(miek): Reverse DNSSEC. We should sign this, but requires a key....and more // Probably not worth the hassle? if err := w.WriteMsg(m); err != nil { s.config.log.Errorf("failure to return reply %q", err) } } // Always forward if not found locally. s.ServeDNSForward(w, req) }
// toMsg turns i into a message, it tailers to reply to m. func (i *item) toMsg(m *dns.Msg) *dns.Msg { m1 := new(dns.Msg) m1.SetReply(m) m1.Authoritative = i.Authoritative m1.AuthenticatedData = i.AuthenticatedData m1.RecursionAvailable = i.RecursionAvailable m1.Compress = true m1.Answer = i.Answer m1.Ns = i.Ns m1.Extra = i.Extra ttl := int(i.origTtl) - int(time.Now().UTC().Sub(i.stored).Seconds()) if ttl < baseTtl { ttl = baseTtl } setCap(m1, uint32(ttl)) return m1 }
func handleRequest(w dns.ResponseWriter, r *dns.Msg) { var rr dns.RR fmt.Println(r.Question[0].Name) m := new(dns.Msg) m.SetReply(r) m.Compress = *compress rrstr, err := getRRStr(r.Question[0]) if err { m.SetRcode(r, dns.RcodeNameError) } else { rr, _ = dns.NewRR(rrstr) m.Answer = append(m.Answer, rr) if *printf { fmt.Printf("%v\n", m.String()) } } w.WriteMsg(m) }
func (dom Domain) addHandler(tld string) { fqdn := dom.Name if tld != "" { fqdn = dom.Name + "." + tld } FqdnParts, _ := dns.IsDomainName(fqdn) fmt.Printf("Adding: %v - %v nums\n", fqdn, FqdnParts) // Handle dns requests if it is really a fqdn if dns.IsFqdn(fqdn) { dns.HandleFunc(fqdn, func(w dns.ResponseWriter, req *dns.Msg) { m := new(dns.Msg) m.SetReply(req) m.Compress = true for i, q := range req.Question { fmt.Printf("Requested: %s, Type: %v\n", req.Question[i].Name, req.Question[i].Qtype) switch q.Qtype { case 1: fmt.Printf("Adding a record %v with ip %v", q.Name, dom.A.Ip) m.Answer = append(m.Answer, NewA(q.Name, dom.A.Ip, uint32(dom.A.Ttl))) case 15: fmt.Printf("Adding a record %v with ip %v", q.Name, dom.A.Ip) m.Answer = append(m.Answer, NewMX(q.Name, dom.Mx.Content, dom.Mx.Priority, uint32(dom.Mx.Ttl))) } } w.WriteMsg(m) }) } // add subdomains.. for _, d := range dom.Domains { d.addHandler(fqdn) } }
func dnsHandler(w dns.ResponseWriter, r *dns.Msg) { defer w.Close() m := new(dns.Msg) m.SetReply(r) m.Compress = false for _, q := range r.Question { fmt.Printf("dns-srv: Query -- [%s] %s\n", q.Name, dns.TypeToString[q.Qtype]) switch q.Qtype { case dns.TypeA: record := new(dns.A) record.Hdr = dns.RR_Header{ Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 0, } record.A = net.ParseIP("127.0.0.1") m.Answer = append(m.Answer, record) case dns.TypeMX: record := new(dns.MX) record.Hdr = dns.RR_Header{ Name: q.Name, Rrtype: dns.TypeMX, Class: dns.ClassINET, Ttl: 0, } record.Mx = "mail." + q.Name record.Preference = 10 m.Answer = append(m.Answer, record) } } w.WriteMsg(m) return }
// base handler for dns server func dnsHandler(w dns.ResponseWriter, request *dns.Msg) { response := new(dns.Msg) response.SetReply(request) response.Compress = false switch request.Opcode { case dns.OpcodeQuery: for _, q := range response.Question { if readRR, e := getRecord(q.Name, q.Qtype); e == nil { rr := readRR.(dns.RR) if rr.Header().Name == q.Name { response.Answer = append(response.Answer, rr) } } } case dns.OpcodeUpdate: if request.IsTsig() != nil && w.TsigStatus() == nil { for _, question := range request.Question { for _, rr := range request.Ns { updateRecord(rr, &question) } } } else { log.Println("droping update without tsig or with bad sig") } } if request.IsTsig() != nil { if w.TsigStatus() == nil { response.SetTsig(request.Extra[len(request.Extra)-1].(*dns.TSIG).Hdr.Name, dns.HmacMD5, 300, time.Now().Unix()) } else { log.Println("Status: ", w.TsigStatus().Error()) } } w.WriteMsg(response) }
func (d dnsAPI) Recurse(w dns.ResponseWriter, req *dns.Msg) { var client dns.Client if isTCP(w.RemoteAddr()) { client.Net = "tcp" } for _, recursor := range d.Recursors { req.Compress = true res, _, err := client.Exchange(req, recursor) if err != nil { continue } res.Compress = true w.WriteMsg(res) return } // Return SERVFAIL res := &dns.Msg{} res.RecursionAvailable = true res.SetRcode(req, dns.RcodeServerFailure) w.WriteMsg(res) }
func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) { var ( extConn net.Conn resp *dns.Msg err error ) if query == nil || len(query.Question) == 0 { return } name := query.Question[0].Name switch query.Question[0].Qtype { case dns.TypeA: resp, err = r.handleIPQuery(name, query, types.IPv4) case dns.TypeAAAA: resp, err = r.handleIPQuery(name, query, types.IPv6) case dns.TypePTR: resp, err = r.handlePTRQuery(name, query) case dns.TypeSRV: resp, err = r.handleSRVQuery(name, query) } if err != nil { logrus.Error(err) return } if resp == nil { // If the backend doesn't support proxying dns request // fail the response if !r.proxyDNS { resp = new(dns.Msg) resp.SetRcode(query, dns.RcodeServerFailure) w.WriteMsg(resp) return } // If the user sets ndots > 0 explicitly and the query is // in the root domain don't forward it out. We will return // failure and let the client retry with the search domain // attached switch query.Question[0].Qtype { case dns.TypeA: fallthrough case dns.TypeAAAA: if r.backend.NdotsSet() && !strings.Contains(strings.TrimSuffix(name, "."), ".") { resp = createRespMsg(query) } } } proto := w.LocalAddr().Network() maxSize := 0 if proto == "tcp" { maxSize = dns.MaxMsgSize - 1 } else if proto == "udp" { optRR := query.IsEdns0() if optRR != nil { maxSize = int(optRR.UDPSize()) } if maxSize < defaultRespSize { maxSize = defaultRespSize } } if resp != nil { if resp.Len() > maxSize { truncateResp(resp, maxSize, proto == "tcp") } } else { for i := 0; i < maxExtDNS; i++ { extDNS := &r.extDNSList[i] if extDNS.ipStr == "" { break } extConnect := func() { addr := fmt.Sprintf("%s:%d", extDNS.ipStr, 53) extConn, err = net.DialTimeout(proto, addr, extIOTimeout) } if extDNS.hostLoopback { extConnect() } else { execErr := r.backend.ExecFunc(extConnect) if execErr != nil { logrus.Warn(execErr) continue } } if err != nil { logrus.Warnf("Connect failed: %s", err) continue } logrus.Debugf("Query %s[%d] from %s, forwarding to %s:%s", name, query.Question[0].Qtype, extConn.LocalAddr().String(), proto, extDNS.ipStr) // Timeout has to be set for every IO operation. extConn.SetDeadline(time.Now().Add(extIOTimeout)) co := &dns.Conn{ Conn: extConn, UDPSize: uint16(maxSize), } defer co.Close() // limits the number of outstanding concurrent queries. if r.forwardQueryStart() == false { old := r.tStamp r.tStamp = time.Now() if r.tStamp.Sub(old) > logInterval { logrus.Errorf("More than %v concurrent queries from %s", maxConcurrent, extConn.LocalAddr().String()) } continue } err = co.WriteMsg(query) if err != nil { r.forwardQueryEnd() logrus.Debugf("Send to DNS server failed, %s", err) continue } resp, err = co.ReadMsg() // Truncated DNS replies should be sent to the client so that the // client can retry over TCP if err != nil && err != dns.ErrTruncated { r.forwardQueryEnd() logrus.Debugf("Read from DNS server failed, %s", err) continue } r.forwardQueryEnd() if resp != nil { for _, rr := range resp.Answer { h := rr.Header() switch h.Rrtype { case dns.TypeA: ip := rr.(*dns.A).A r.backend.HandleQueryResp(h.Name, ip) case dns.TypeAAAA: ip := rr.(*dns.AAAA).AAAA r.backend.HandleQueryResp(h.Name, ip) } } } resp.Compress = true break } if resp == nil { return } } if err = w.WriteMsg(resp); err != nil { logrus.Errorf("error writing resolver resp, %s", err) } }
// ServeDNSForward forwards a request to a nameservers and returns the response. func (s *server) ServeDNSForward(w dns.ResponseWriter, req *dns.Msg) *dns.Msg { StatsForwardCount.Inc(1) promExternalRequestCount.WithLabelValues("recursive").Inc() if s.config.NoRec { m := new(dns.Msg) m.SetReply(req) m.SetRcode(req, dns.RcodeServerFailure) m.Authoritative = false m.RecursionAvailable = false w.WriteMsg(m) return m } if len(s.config.Nameservers) == 0 || dns.CountLabel(req.Question[0].Name) < s.config.Ndots { if s.config.Verbose { if len(s.config.Nameservers) == 0 { logf("can not forward, no nameservers defined") } else { logf("can not forward, name too short (less than %d labels): `%s'", s.config.Ndots, req.Question[0].Name) } } m := new(dns.Msg) m.SetReply(req) m.SetRcode(req, dns.RcodeServerFailure) m.Authoritative = false // no matter what set to false m.RecursionAvailable = true // and this is still true w.WriteMsg(m) return m } tcp := isTCP(w) var ( r *dns.Msg err error try int ) // Use request Id for "random" nameserver selection. nsid := int(req.Id) % len(s.config.Nameservers) Redo: switch tcp { case false: r, _, err = s.dnsUDPclient.Exchange(req, s.config.Nameservers[nsid]) case true: r, _, err = s.dnsTCPclient.Exchange(req, s.config.Nameservers[nsid]) } if err == nil { r.Compress = true r.Id = req.Id w.WriteMsg(r) return r } // Seen an error, this can only mean, "server not reached", try again // but only if we have not exausted our nameservers. if try < len(s.config.Nameservers) { try++ nsid = (nsid + 1) % len(s.config.Nameservers) goto Redo } logf("failure to forward request %q", err) m := new(dns.Msg) m.SetReply(req) m.SetRcode(req, dns.RcodeServerFailure) w.WriteMsg(m) return m }
func TestRecursiveCompress(t *testing.T) { const ( hostname = "foo.example." maxSize = 512 ) // Construct a response that is >512 when uncompressed, <512 when compressed response := dns.Msg{} response.Authoritative = true response.Answer = []dns.RR{} header := dns.RR_Header{ Name: hostname, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 10, } for response.Len() <= maxSize { ip := address.Address(rand.Uint32()).IP4() response.Answer = append(response.Answer, &dns.A{Hdr: header, A: ip}) } response.Compress = true require.True(t, response.Len() <= maxSize) // A dns server that returns the above response var gotRequest = make(chan struct{}, 1) handleRecursive := func(w dns.ResponseWriter, req *dns.Msg) { gotRequest <- struct{}{} require.Equal(t, req.Question[0].Name, hostname) response.SetReply(req) err := w.WriteMsg(&response) require.Nil(t, err) } mux := dns.NewServeMux() mux.HandleFunc(topDomain, handleRecursive) udpListener, err := net.ListenPacket("udp", "0.0.0.0:0") require.Nil(t, err) udpServer := &dns.Server{PacketConn: udpListener, Handler: mux} udpServerPort := udpListener.LocalAddr().(*net.UDPAddr).Port go udpServer.ActivateAndServe() defer udpServer.Shutdown() // The weavedns server, pointed at the above server dnsserver, _, udpPort, _ := startServer(t, &dns.ClientConfig{ Servers: []string{"127.0.0.1"}, Port: strconv.Itoa(udpServerPort), Ndots: 1, Timeout: 5, Attempts: 2, }) defer dnsserver.Stop() // Now do lookup, check its what we expected. // NB this doesn't really test golang's resolver behaves correctly, as I can't see // a way to point golangs resolver at a specific hosts. req := new(dns.Msg) req.Id = dns.Id() req.RecursionDesired = true req.Question = make([]dns.Question, 1) req.Question[0] = dns.Question{ Name: hostname, Qtype: dns.TypeA, Qclass: dns.ClassINET, } c := new(dns.Client) res, _, err := c.Exchange(req, fmt.Sprintf("127.0.0.1:%d", udpPort)) require.Nil(t, err) require.True(t, len(gotRequest) > 0) require.True(t, res.Len() > maxSize) }