// interceptRequest gets called upon each DNS request, and we determine // whether we want to deal with it or not func interceptRequest(w dns.ResponseWriter, r *dns.Msg) { m := r.Copy() err := error(nil) defer w.Close() if len(r.Question) < 1 { return } // Hijack the response to point to us instead if matchesCriteria(r.Question[0].Name) { fmt.Println("Matches!", r.Question[0].Name) if !(inAddressCache(r.Question[0].Name)) { updateAddressCache(r) } m = hijackResponse(r) // Pass it upstream, return the answer } else { //fmt.Println("Passing on ", r.Question[0].Name) m, err = upstreamLookup(r) if err != nil { fmt.Println("Error when passing request through upstream - network problem?") // in this instance, our response (m) has no answer } } w.WriteMsg(m) }
func (h *handler) handleRecursive(w dns.ResponseWriter, req *dns.Msg) { h.ns.debugf("recursive request: %+v", *req) // Resolve unqualified names locally if len(req.Question) == 1 { hostname := dns.Fqdn(req.Question[0].Name) if strings.Count(hostname, ".") == 1 { h.handleLocal(w, req) return } } upstreamConfig, err := h.upstream.Config() if err != nil { h.ns.errorf("unable to read upstream config: %s", err) } for _, server := range upstreamConfig.Servers { reqCopy := req.Copy() reqCopy.Id = dns.Id() response, _, err := h.client.Exchange(reqCopy, fmt.Sprintf("%s:%s", server, upstreamConfig.Port)) if (err != nil && err != dns.ErrTruncated) || response == nil { h.ns.debugf("error trying %s: %v", server, err) continue } response.Id = req.Id if h.responseTooBig(req, response) { response.Compress = true } h.respond(w, response) return } h.respond(w, h.makeErrorResponse(req, dns.RcodeServerFailure)) }
// InsertMessage inserts a message in the Cache. We will cache it for ttl seconds, which // should be a small (60...300) integer. func (c *Cache) InsertMessage(s string, msg *dns.Msg) { if c.capacity <= 0 { return } c.Lock() if _, ok := c.m[s]; !ok { c.m[s] = &elem{time.Now().UTC().Add(c.ttl), msg.Copy()} } c.EvictRandom() c.Unlock() }
// set the reply for the entry // returns True if the entry has changed the validUntil time func (e *cacheEntry) setReply(reply *dns.Msg, ttl int, flags uint8, now time.Time) bool { var prevValidUntil time.Time if e.Status == stResolved { if reply != nil { Debug.Printf("[cache msgid %d] replacing response in cache", reply.MsgHdr.Id) } prevValidUntil = e.validUntil } // make sure we do not overwrite noLocalReplies entries if flags&CacheNoLocalReplies != 0 { if e.Flags&CacheNoLocalReplies != 0 { return false } } if ttl != nullTTL { e.validUntil = now.Add(time.Second * time.Duration(ttl)) } else if reply != nil { // calculate the validUntil from the reply TTL var minTTL uint32 = math.MaxUint32 for _, rr := range reply.Answer { ttl := rr.Header().Ttl if ttl < minTTL { minTTL = ttl // TODO: improve the minTTL calculation (maybe we should skip some RRs) } } e.validUntil = now.Add(time.Second * time.Duration(minTTL)) } else { Warning.Printf("[cache] no valid TTL could be calculated") } e.Status = stResolved e.Flags = flags e.putTime = now if reply != nil { e.reply = *reply.Copy() e.ReplyLen = reply.Len() } return (prevValidUntil != e.validUntil) }
// Add adds dns.Msg to the cache func (c *Cache) Add(msg *mdns.Msg) { c.Lock() defer c.Unlock() if len(c.m) >= c.capacity { // pick a random key and remove it for k := range c.m { delete(c.m, k) break } } key := generateKey(msg.Question[0]) if _, ok := c.m[key]; !ok { c.m[key] = &Item{ Expiration: time.Now().UTC().Add(c.ttl), Msg: msg.Copy(), } } }
// updateAddressCache updates the dns map of real names to ip addresses // used to relay traffic to the intended destination. func updateAddressCache(r *dns.Msg) { u := r.Copy() u, err := upstreamLookup(u) if err != nil { fmt.Println("Network error: Cannnot lookup upstream!", err.Error()) return } // find all A records, save the name and address, replace the address and ttl names := []string{} ip := "" // grab all names (and IP from A record), for updating (sans dot at end) for _, answer := range u.Answer { names = append(names, answer.Header().Name) if dns.TypeToString[answer.Header().Rrtype] == "A" { found := addressRegex.FindString(answer.String()) if found != "" { ip = found } else { panic(r.String()) } } } if ip == "" { fmt.Println("Cannot updated map, no IP found") fmt.Println(u.String()) return } updateMutex.Lock() // Update both with . at the end and without for _, name := range names { addressMap[strings.ToLower(name)] = ip addressMap[strings.ToLower(name[:len(name)-1])] = ip fmt.Println("Updated address map with ", name, ip) } updateMutex.Unlock() }
// forwardSearch resolves a query by suffixing with search paths func (s *server) forwardSearch(req *dns.Msg, tcp bool) (*dns.Msg, error) { var r *dns.Msg var nodata *dns.Msg // stores the copy of a NODATA reply var searchName string // stores the current name suffixed with search domain var err error var didSearch bool name := req.Question[0].Name // original qname reqCopy := req.Copy() for _, domain := range s.config.SearchDomains { if strings.HasSuffix(name, domain) { continue } searchName = strings.ToLower(appendDomain(name, domain)) reqCopy.Question[0] = dns.Question{searchName, reqCopy.Question[0].Qtype, reqCopy.Question[0].Qclass} didSearch = true r, err = s.forwardQuery(reqCopy, tcp) if err != nil { // No server currently available, give up break } switch r.Rcode { case dns.RcodeSuccess: // In case of NO_DATA keep searching, otherwise a wildcard entry // could keep us from finding the answer higher in the search list if len(r.Answer) == 0 && !r.MsgHdr.Truncated { nodata = r.Copy() continue } case dns.RcodeNameError: fallthrough case dns.RcodeServerFailure: // try next search element if any continue } // anything else implies that we are done searching break } if !didSearch { m := new(dns.Msg) m.SetRcode(req, dns.RcodeNameError) return m, nil } if err == nil { if r.Rcode == dns.RcodeSuccess { if len(r.Answer) > 0 { cname := new(dns.CNAME) cname.Hdr = dns.RR_Header{Name: name, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: 360} cname.Target = searchName answers := []dns.RR{cname} for _, rr := range r.Answer { answers = append(answers, rr) } r.Answer = answers } // If we ever got a NODATA, return this instead of a negative result } else if nodata != nil { r = nodata } // Restore original question r.Question[0] = req.Question[0] } if err != nil && nodata != nil { r = nodata r.Question[0] = req.Question[0] err = nil } return r, err }