// Async func (c *MDNSClient) SendQuery(name string, querytype uint16, insistent bool, responseCh chan<- *Response) { c.actionChan <- func() { query, found := c.inflight[name] if !found { m := new(dns.Msg) m.SetQuestion(name, querytype) m.RecursionDesired = false buf, err := m.Pack() if err != nil { responseCh <- &Response{err: err} close(responseCh) return } query = &inflightQuery{ name: name, id: m.Id, } if _, err = c.conn.WriteTo(buf, c.addr); err != nil { responseCh <- &Response{err: err} close(responseCh) return } c.inflight[name] = query } info := &responseInfo{ ch: responseCh, timeout: time.Now().Add(mDNSTimeout), insistent: insistent, } // Invariant on responseInfos: they are in ascending order of // timeout. Since we use a fixed interval from Now(), this // must be after all existing timeouts. query.responseInfos = append(query.responseInfos, info) } }
func (c *Cache) Insert(d *dns.Msg) { if *flaglog { log.Printf("fsk-shield: inserting " + toRadixKey(d)) } buf, _ := d.Pack() // Should always work c.Radix.Insert(toRadixKey(d), &Packet{d: buf[2:], ttl: time.Now().UTC()}) }
func (this *cache) Add(message *dns.Msg) error { if message.Question[0].Qtype == dns.TypeA && message.Question[0].Qclass == dns.ClassINET { byteMessage, err := message.Pack() if err != nil { return err } err = this.collection.SetObject(message.Question[0].Name, CacheRecord{Expiry: time.Now().Add(24 * time.Hour), Record: byteMessage}) for _, part := range message.Answer { switch part.(type) { case *dns.A: cacheErr := this.ipToHostname.SetObject(part.(*dns.A).A.String(), strings.TrimSuffix(message.Question[0].Name, ".")) if cacheErr != nil { log.Println("Warning: Error Adding/Updating Cache IP:" + part.(*dns.A).A.String() + " Hostname:" + strings.TrimSuffix(message.Question[0].Name, ".") + " Error:" + cacheErr.Error()) } case *dns.CNAME: //CNAME Don't contain the IP for reverse lookups. //log.Printf("CNAME: Type:%s Value:%v", reflect.TypeOf(part).Name(), part) //TODO: We should probably Add the CNAME as a hostname?? default: log.Printf("Type:%s Value:%v", reflect.TypeOf(part).Name(), part) } } return err } return nil }
func (c *Client) Discover(domain string, cb func(*dns.Msg)) { m := new(dns.Msg) m.SetQuestion(dns.Fqdn(domain), dns.TypePTR) m.RecursionDesired = true addr := &net.UDPAddr{ IP: net.ParseIP("224.0.0.251"), Port: 5353, } conn, err := net.ListenMulticastUDP("udp4", nil, addr) if err != nil { panic(err) } defer conn.Close() c.conn = conn out, err := m.Pack() if err != nil { panic(err) } _, err = conn.WriteToUDP(out, addr) if err != nil { panic(err) } c.handleReceiveMsg(domain, cb) }
func (s *MDNSServer) sendResponse(m *dns.Msg) error { buf, err := m.Pack() if err != nil { return err } _, err = s.sendconn.WriteTo(buf, ipv4Addr) return err }
func (conn *DNSUDPConn) WriteDNSToUDP(m *dns.Msg, addr *net.UDPAddr) error { out, err := m.Pack() if err != nil { return err } _, err = conn.WriteToUDP(out, addr) return err }
// encode an mdns msg and broadcast it on the wire func (c *connector) writeMessage(msg *dns.Msg, addr *net.UDPAddr) error { buf, err := msg.Pack() if err != nil { return err } _, err = c.WriteToUDP(buf, addr) return err }
func CompressIfLarge(m *dns.Msg) { bytes, err := m.Pack() if err != nil { return } if len(bytes) > 512 { // may not fit into UDP packet m.Compress = true // will be compressed in WriteMsg } }
func TestRFC2136ValidUpdatePacket(t *testing.T) { dns.HandleFunc(rfc2136TestZone, serverHandlerPassBackRequest) defer dns.HandleRemove(rfc2136TestZone) server, addrstr, err := runLocalDNSTestServer("127.0.0.1:0", false) if err != nil { t.Fatalf("Failed to start test server: %v", err) } defer server.Shutdown() rr := new(dns.TXT) rr.Hdr = dns.RR_Header{ Name: rfc2136TestFqdn, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: uint32(rfc2136TestTTL), } rr.Txt = []string{rfc2136TestValue} rrs := make([]dns.RR, 1) rrs[0] = rr m := new(dns.Msg) m.SetUpdate(dns.Fqdn(rfc2136TestZone)) m.Insert(rrs) expectstr := m.String() expect, err := m.Pack() if err != nil { t.Fatalf("Error packing expect msg: %v", err) } provider, err := NewDNSProviderRFC2136(addrstr, rfc2136TestZone, "", "") if err != nil { t.Fatalf("Expected NewDNSProviderRFC2136() to return no error but the error was -> %v", err) } if err := provider.Present(rfc2136TestDomain, "", "1234d=="); err != nil { t.Errorf("Expected Present() to return no error but the error was -> %v", err) } rcvMsg := <-reqChan rcvMsg.Id = m.Id actual, err := rcvMsg.Pack() if err != nil { t.Fatalf("Error packing actual msg: %v", err) } if !bytes.Equal(actual, expect) { tmp := new(dns.Msg) if err := tmp.Unpack(actual); err != nil { t.Fatalf("Error unpacking actual msg: %v", err) } t.Errorf("Expected msg:\n%s", expectstr) t.Errorf("Actual msg:\n%v", tmp) } }
// sendQuery is used to multicast a query out func (c *client) sendQuery(q *dns.Msg) error { buf, err := q.Pack() if err != nil { return err } if c.ipv4List != nil { c.ipv4List.WriteTo(buf, ipv4Addr) } if c.ipv6List != nil { c.ipv6List.WriteTo(buf, ipv6Addr) } return nil }
// Pack the dns.Msg and write to available connections (multicast) func (c *client) sendQuery(msg *dns.Msg) error { buf, err := msg.Pack() if err != nil { return err } if c.ipv4conn != nil { c.ipv4conn.WriteTo(buf, ipv4Addr) } if c.ipv6conn != nil { c.ipv6conn.WriteTo(buf, ipv6Addr) } return nil }
// sendQuery is used to multicast a query out func (c *client) sendQuery(q *dns.Msg, params *QueryParam) error { buf, err := q.Pack() if err != nil { return err } if c.ipv4List != nil { c.ipv4List.WriteTo(buf, params.IPv4mdns) } if c.ipv6List != nil { c.ipv6List.WriteTo(buf, params.IPv6mdns) } return nil }
// sendQuery is used to multicast a query out func (c *client) sendQuery(q *dns.Msg) error { buf, err := q.Pack() if err != nil { return err } if c.ipv4UnicastConn != nil { c.ipv4UnicastConn.WriteToUDP(buf, ipv4Addr) } if c.ipv6UnicastConn != nil { c.ipv6UnicastConn.WriteToUDP(buf, ipv6Addr) } return nil }
// multicastResponse us used to send a multicast response packet func (c *Server) multicastResponse(msg *dns.Msg) error { buf, err := msg.Pack() if err != nil { log.Println("Failed to pack message!") return err } if c.ipv4conn != nil { c.ipv4conn.WriteTo(buf, ipv4Addr) } if c.ipv6conn != nil { c.ipv6conn.WriteTo(buf, ipv6Addr) } return nil }
// unicastResponse is used to send a unicast response packet func (s *Server) unicastResponse(resp *dns.Msg, from net.Addr) error { buf, err := resp.Pack() if err != nil { return err } addr := from.(*net.UDPAddr) if addr.IP.To4() != nil { _, err = s.ipv4conn.WriteToUDP(buf, addr) return err } else { _, err = s.ipv6conn.WriteToUDP(buf, addr) return err } }
func TestRFC2136ValidUpdatePacket(t *testing.T) { acme.ClearFqdnCache() dns.HandleFunc(rfc2136TestZone, serverHandlerPassBackRequest) defer dns.HandleRemove(rfc2136TestZone) server, addrstr, err := runLocalDNSTestServer("127.0.0.1:0", false) if err != nil { t.Fatalf("Failed to start test server: %v", err) } defer server.Shutdown() txtRR, _ := dns.NewRR(fmt.Sprintf("%s %d IN TXT %s", rfc2136TestFqdn, rfc2136TestTTL, rfc2136TestValue)) rrs := []dns.RR{txtRR} m := new(dns.Msg) m.SetUpdate(rfc2136TestZone) m.RemoveRRset(rrs) m.Insert(rrs) expectstr := m.String() expect, err := m.Pack() if err != nil { t.Fatalf("Error packing expect msg: %v", err) } provider, err := NewDNSProvider(addrstr, "", "", "") if err != nil { t.Fatalf("Expected NewDNSProvider() to return no error but the error was -> %v", err) } if err := provider.Present(rfc2136TestDomain, "", "1234d=="); err != nil { t.Errorf("Expected Present() to return no error but the error was -> %v", err) } rcvMsg := <-reqChan rcvMsg.Id = m.Id actual, err := rcvMsg.Pack() if err != nil { t.Fatalf("Error packing actual msg: %v", err) } if !bytes.Equal(actual, expect) { tmp := new(dns.Msg) if err := tmp.Unpack(actual); err != nil { t.Fatalf("Error unpacking actual msg: %v", err) } t.Errorf("Expected msg:\n%s", expectstr) t.Errorf("Actual msg:\n%v", tmp) } }
// sendResponse is used to send a response packet func (s *Server) sendResponse(resp *dns.Msg, from net.Addr, unicast bool) error { // TODO(reddaly): Respect the unicast argument, and allow sending responses // over multicast. buf, err := resp.Pack() if err != nil { return err } // Determine the socket to send from addr := from.(*net.UDPAddr) if addr.IP.To4() != nil { _, err = s.ipv4List.WriteToUDP(buf, addr) } else { _, err = s.ipv6List.WriteToUDP(buf, addr) } return err }
func (g *grpcResolver) Query(r *dns.Msg, tr trace.Trace) (*dns.Msg, error) { buf, err := r.Pack() if err != nil { return nil, err } // Give our RPCs 2 second timeouts: DNS usually doesn't wait that long // anyway. ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() reply, err := g.client.Query(ctx, &pb.RawMsg{Data: buf}) if err != nil { return nil, err } m := &dns.Msg{} err = m.Unpack(reply.Data) return m, err }
// sendQuery is used to multicast a query out func (c *client) sendQuery(q *dns.Msg, iface *net.Interface) error { buf, err := q.Pack() if err != nil { return err } if c.ipv4UnicastConn != nil { p := ipv4.NewPacketConn(c.ipv4UnicastConn) if iface != nil { p.SetMulticastInterface(iface) } p.WriteTo(buf, nil, ipv4Addr) } if c.ipv6UnicastConn != nil { p := ipv6.NewPacketConn(c.ipv6UnicastConn) if iface != nil { p.SetMulticastInterface(iface) } p.WriteTo(buf, nil, ipv6Addr) } return nil }
/* func (this ClientProxy) CreateHTTPClient() { if this.start_TLS == false { this.client = &http.Client{} } else { if this.TLS_Path == "" { tr := &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, DisableCompression: true} this.client = &http.Client{Transport: tr} } else { pool := x509.NewCertPool() caCrt, err := ioutil.ReadFile(this.TLS_Path) if err != nil { _D("invalid cetificate path: %s", this.TLS_Path) return } pool.AppendCertsFromPEM(caCrt) tr := &http.Transport{ TLSClientConfig: &tls.Config{RootCAs: pool}, } this.client = &http.Client{Transport: tr} } } } */ func (this ClientProxy) ServeDNS(w dns.ResponseWriter, request *dns.Msg) { _LOG("get %s query from %s", request.Question[0].Name, w.RemoteAddr()) request_bytes, err := request.Pack() //I am not sure it is better to pack directly or using a pointer if err != nil { SRVFAIL(w, request) _D("error in packing request from %s for '%s', error message: %s", dns.ResponseWriter.RemoteAddr(w), request.Question[0].Name, err) return } ServerInput := this.SERVERS[rand.Intn(len(this.SERVERS))] ipaddress := net.ParseIP(ServerInput) var ServerInputurl string if this.start_TLS { //if it is TLS, use HTTPS if ipaddress.To4() != nil { ServerInputurl = "https://" + ServerInput } else { ServerInputurl = "https://[" + ServerInput + "]" } } else { if ipaddress.To4() != nil { ServerInputurl = "http://" + ServerInput } else { ServerInputurl = "http://[" + ServerInput + "]" } } postBytesReader := bytes.NewReader(request_bytes) ServerInputurl = ServerInputurl + "/proxy_dns" req, err := http.NewRequest("POST", ServerInputurl, postBytesReader) //need add random here in future if err != nil { SRVFAIL(w, request) _D("error in creating HTTP request from %s for '%s', error message: %s", dns.ResponseWriter.RemoteAddr(w), request.Question[0].Name, err) return } req.Header.Add("Host", ServerInput) req.Header.Add("Accept", "application/octet-stream") req.Header.Add("Content-Type", "application/octet-stream") if 
this.TransPro == UDPcode { req.Header.Add("Proxy-DNS-Transport", "UDP") } else if this.TransPro == TCPcode { req.Header.Add("Proxy-DNS-Transport", "TCP") } if this.start_TLS == false { //HTTP version tr := &http.Transport{ DisableKeepAlives: true, TLSNextProto: nil} this.client = &http.Client{Transport: tr} } else { //HTTPS version disabled certificate verification if this.TLS_Path == "" { tr := &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, DisableCompression: true} this.client = &http.Client{Transport: tr} } else { //HTTPS version allow the certificate manually pool := x509.NewCertPool() caCrt, err := ioutil.ReadFile(this.TLS_Path) if err != nil { _D("invalid cetificate path: %s", this.TLS_Path) return } pool.AppendCertsFromPEM(caCrt) tr := &http.Transport{ TLSClientConfig: &tls.Config{RootCAs: pool}, } this.client = &http.Client{Transport: tr} } } resp, err := this.client.Do(req) if err != nil { SRVFAIL(w, request) _D("error in HTTP post request for query from %s for '%s', error message: %s", dns.ResponseWriter.RemoteAddr(w), request.Question[0].Name, err) return } if resp.StatusCode >= 500 { SRVFAIL(w, request) _D("HTTP ERROR: %s", http.StatusText(resp.StatusCode)) } var requestBody []byte requestBody, err = ioutil.ReadAll(resp.Body) // nRead, err := resp.Body.Read(requestBody) if err != nil { // these need to be separate checks, otherwise you will get a nil-reference // when you print the error message below! SRVFAIL(w, request) _D("error in reading HTTP response for query from %s for '%s', error message: %s", dns.ResponseWriter.RemoteAddr(w), request.Question[0].Name, err) return } //I not sure whether I should return server fail directly //I just found there is a bug here. Body.Read can not read all the contents out, I don't know how to solve it. 
if len(requestBody) < (int)(resp.ContentLength) { SRVFAIL(w, request) _D("failure reading all HTTP content for query from %s for '%s' (%d of %d bytes read)", dns.ResponseWriter.RemoteAddr(w), request.Question[0].Name, len(requestBody), (int)(resp.ContentLength)) return } var DNSreponse dns.Msg err = DNSreponse.Unpack(requestBody) if err != nil { SRVFAIL(w, request) _D("error in packing HTTP response for query from %s for '%s', error message: %s", dns.ResponseWriter.RemoteAddr(w), request.Question[0].Name, err) return } err = w.WriteMsg(&DNSreponse) if err != nil { _D("error in sending DNS response back for query from %s for '%s', error message: %s", dns.ResponseWriter.RemoteAddr(w), request.Question[0].Name, err) return } }
// handleReflect answers a query with information about the client itself: an
// A or AAAA record holding the remote IP address and a TXT record holding the
// remote port and transport ("udp" or "tcp"). All records are owned by dom
// and carry a TTL of 0.
//
// Which record goes into the answer section depends on the question type:
// TXT questions get the TXT record as the answer and the address record as
// additional data; every other type except AXFR/IXFR gets the reverse.
// AXFR/IXFR questions are served as a zone transfer instead.
//
// dom, compress, printf and reflectHandled are package-level names defined
// outside this block.
// NOTE(review): reflectHandled is incremented without any synchronization
// visible here — confirm handlers are not run concurrently.
func handleReflect(w dns.ResponseWriter, r *dns.Msg) {
	// Count requests and report progress every 1000 of them.
	reflectHandled += 1
	if reflectHandled%1000 == 0 {
		fmt.Printf("Served %d reflections\n", reflectHandled)
	}
	var (
		v4  bool
		rr  dns.RR
		str string
		a   net.IP
	)
	m := new(dns.Msg)
	m.SetReply(r)
	m.Compress = *compress
	// Extract the client's IP and port. If the remote address is neither UDP
	// nor TCP, a stays nil and v4 stays false, so the AAAA branch below is
	// taken with a nil address.
	if ip, ok := w.RemoteAddr().(*net.UDPAddr); ok {
		str = "Port: " + strconv.Itoa(ip.Port) + " (udp)"
		a = ip.IP
		v4 = a.To4() != nil
	}
	if ip, ok := w.RemoteAddr().(*net.TCPAddr); ok {
		str = "Port: " + strconv.Itoa(ip.Port) + " (tcp)"
		a = ip.IP
		v4 = a.To4() != nil
	}

	// Build the address record matching the client's address family.
	if v4 {
		rr = new(dns.A)
		rr.(*dns.A).Hdr = dns.RR_Header{Name: dom, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 0}
		rr.(*dns.A).A = a.To4()
	} else {
		rr = new(dns.AAAA)
		rr.(*dns.AAAA).Hdr = dns.RR_Header{Name: dom, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 0}
		rr.(*dns.AAAA).AAAA = a
	}

	// TXT record describing the client's port and transport.
	t := new(dns.TXT)
	t.Hdr = dns.RR_Header{Name: dom, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 0}
	t.Txt = []string{str}

	switch r.Question[0].Qtype {
	case dns.TypeTXT:
		m.Answer = append(m.Answer, t)
		m.Extra = append(m.Extra, rr)
	default:
		// Unknown types are treated like an address query.
		fallthrough
	case dns.TypeAAAA, dns.TypeA:
		m.Answer = append(m.Answer, rr)
		m.Extra = append(m.Extra, t)
	case dns.TypeAXFR, dns.TypeIXFR:
		// Zone transfer: send a single envelope holding SOA, TXT, address
		// record, SOA. The channel is closed when this function returns.
		c := make(chan *dns.Envelope)
		tr := new(dns.Transfer)
		defer close(c)
		err := tr.Out(w, r, c)
		if err != nil {
			return
		}
		soa, _ := dns.NewRR(`whoami.miek.nl. 0 IN SOA linode.atoom.net. miek.miek.nl. 2009032802 21600 7200 604800 3600`)
		c <- &dns.Envelope{RR: []dns.RR{soa, t, rr, soa}}
		w.Hijack()
		// w.Close() // Client closes connection
		return
	}

	// If the request carried a TSIG record, sign the reply under the same
	// key name when the request's signature verified; otherwise print the
	// verification error.
	if r.IsTsig() != nil {
		if w.TsigStatus() == nil {
			m.SetTsig(r.Extra[len(r.Extra)-1].(*dns.TSIG).Hdr.Name, dns.HmacMD5, 300, time.Now().Unix())
		} else {
			println("Status", w.TsigStatus().Error())
		}
	}
	if *printf {
		fmt.Printf("%v\n", m.String())
	}
	// set TC when question is tc.miek.nl.
	if m.Question[0].Name == "tc.miek.nl." {
		m.Truncated = true
		// send half a message
		// NOTE(review): the Pack error is discarded; on failure buf is
		// presumably empty and an empty slice is written — confirm.
		buf, _ := m.Pack()
		w.Write(buf[:len(buf)/2])
		return
	}
	w.WriteMsg(m)
}
func route(w dns.ResponseWriter, req *dns.Msg) { keyP, err := getKey(req) if err != nil { failWithRcode(w, req, dns.RcodeRefused) return } if handleSpecialNames(w, req) { return } maxPayloadSize := getMaxPayloadSize(req) var resp *dns.Msg cacheValP, _ := cache.Get(*keyP) if cacheValP != nil { cacheVal := cacheValP.(CacheVal) remaining := -time.Since(cacheVal.ValidUntil) if remaining > 0 { resp = cacheVal.Response.Copy() resp.Id = req.Id resp.Question = req.Question } } if *debug { question := req.Question[0] cachedStr := "" if resp != nil { cachedStr = " (cached)" } log.Printf("%v\t%v %v%v\n", w.RemoteAddr(), question.Name, dns.TypeToString[question.Qtype], cachedStr) } if resp == nil { slipValue := atomic.LoadUint32(&slip) if slipValue > 0 && slipValue%2 == 0 { atomic.CompareAndSwapUint32(&slip, slipValue, slipValue+1) if slipValue%4 == 0 { sendTruncated(w, req.MsgHdr) } else { w.Close() } return } } if resp == nil { resp, err = resolve(req, keyP.DNSSEC) if err == nil { validUntil := time.Now().Add(getMinTTL(resp)) cache.Add(*keyP, CacheVal{ValidUntil: validUntil, Response: resp}) } else { if cacheValP == nil { w.Close() return } cacheVal := cacheValP.(CacheVal) resp = cacheVal.Response.Copy() resp.Id = req.Id resp.Question = req.Question } } packed, _ := resp.Pack() packedLen := len(packed) if uint16(packedLen) > maxPayloadSize { sendTruncated(w, resp.MsgHdr) } else { w.WriteMsg(resp) } }
// proxyServe handles one DNS request by answering from the cache or by
// forwarding it to the upstream servers listed in DNS.
//
// Flow:
//   - Messages flagged as responses are ignored.
//   - AAAA questions are dropped unless *ipv6 is set; if no question is
//     left, nothing is sent.
//   - The cache key is the MD5 (via toMd5) of the request's string form
//     with its Id temporarily set to 0, so the key does not depend on the
//     request Id.
//   - With ENCACHE set, a cached entry is unpacked, given the request's Id
//     and written back.
//   - Otherwise each entry of DNS (address, protocol) is tried in order
//     until one returns without error and with at least one answer. The
//     last reply is written to w and, with ENCACHE set, stored with the TTL
//     of its first answer.
//
// Any panic is recovered and printed. ENCACHE, DEBUG, DNS, conn, clientUDP,
// clientTCP, ipv6 and toMd5 are defined outside this block.
func proxyServe(w dns.ResponseWriter, req *dns.Msg) {
	var (
		key       string
		m         *dns.Msg
		err       error
		tried     bool
		data      []byte
		id        uint16
		query     []string
		questions []dns.Question
		used      string
	)
	// Keep a panic in this handler from propagating; just print it.
	defer func() {
		if err := recover(); err != nil {
			fmt.Println(err)
		}
	}()
	if req.MsgHdr.Response == true {
		// supposed responses sent to us are bogus
		return
	}
	// query holds a printable form of every question (used for logging);
	// questions holds the ones that will actually be forwarded.
	query = make([]string, len(req.Question))
	for i, q := range req.Question {
		if q.Qtype != dns.TypeAAAA || *ipv6 {
			questions = append(questions, q)
		}
		query[i] = fmt.Sprintf("(%s %s %s)", q.Name, dns.ClassToString[q.Qclass], dns.TypeToString[q.Qtype])
	}
	if len(questions) == 0 {
		return
	}
	req.Question = questions
	// Compute the cache key with the Id zeroed, then restore the Id.
	id = req.Id
	req.Id = 0
	key = toMd5(req.String())
	req.Id = id
	if ENCACHE {
		if reply, ok := conn.Get(key); ok {
			data, _ = reply.([]byte)
		}
		if data != nil && len(data) > 0 {
			m = &dns.Msg{}
			// NOTE(review): the Unpack error is not checked here.
			m.Unpack(data)
			m.Id = id
			err = w.WriteMsg(m)
			if DEBUG > 0 {
				log.Printf("id: %5d cache: HIT %v\n", id, query)
			}
			goto end
		} else {
			if DEBUG > 0 {
				log.Printf("id: %5d cache: MISS %v\n", id, query)
			}
		}
	}
	// Try the upstream servers in order. Inside this loop the local name
	// dns shadows the dns package.
	for i, parts := range DNS {
		dns := parts[0]
		proto := parts[1]
		tried = i > 0
		if DEBUG > 0 {
			if tried {
				log.Printf("id: %5d try: %v %s %s\n", id, query, dns, proto)
			} else {
				log.Printf("id: %5d resolve: %v %s %s\n", id, query, dns, proto)
			}
		}
		client := clientUDP
		if proto == "tcp" {
			client = clientTCP
		}
		m, _, err = client.Exchange(req, dns)
		if err == nil && len(m.Answer) > 0 {
			used = dns
			break
		}
	}
	// err and m reflect the last exchange attempted. If DNS is empty, m is
	// nil here and the m.Pack call below panics; the deferred recover
	// catches that.
	if err == nil {
		if DEBUG > 0 {
			if tried {
				if len(m.Answer) == 0 {
					log.Printf("id: %5d failed: %v\n", id, query)
				} else {
					log.Printf("id: %5d bingo: %v %s\n", id, query, used)
				}
			}
		}
		data, err = m.Pack()
		if err == nil {
			_, err = w.Write(data)
			if err == nil {
				if ENCACHE {
					// Store the reply with Id 0, then restore the
					// Id for the debug output below.
					m.Id = 0
					data, _ = m.Pack()
					ttl := 0
					if len(m.Answer) > 0 {
						ttl = int(m.Answer[0].Header().Ttl)
						if ttl < 0 {
							ttl = 0
						}
					}
					conn.Set(key, data, time.Second*time.Duration(ttl))
					m.Id = id
					if DEBUG > 0 {
						log.Printf("id: %5d cache: CACHED %v TTL %v\n", id, query, ttl)
					}
				}
			}
		}
	}
end:
	// Common exit: optional dump of request and reply, then error report.
	if DEBUG > 1 {
		fmt.Println(req)
		if m != nil {
			fmt.Println(m)
		}
	}
	if err != nil {
		log.Printf("id: %5d error: %v %s\n", id, query, err)
	}
	if DEBUG > 1 {
		fmt.Println("====================================================")
	}
}
func dnschanDnsQuery(servers map[net.UDPAddr]*dnsServer, domain string, RecordChan chan *DnsRecord, ExitChan chan int) { select { case <-ExitChan: return default: } myexitChan := make(chan int) defer func() { close(myexitChan) }() // 打开端口 conn, err := net.ListenUDP("udp", nil) if err != nil { log.Printf("为dns请求打开udp失败,%v", err) return } defer conn.Close() conn.SetDeadline(time.Now().Add(5 * time.Second)) // dns 请求 m := new(dns.Msg) m.SetQuestion(dns.Fqdn(domain), dns.TypeA) m.RecursionDesired = true mData, err := m.Pack() if err != nil { log.Printf("生成dns请求包失败,%v", err) return } // 另开一个线程发出查询 go func() { for k, _ := range servers { if _, err := conn.WriteToUDP(mData, k); err != nil { log.Printf("向%v发送dns请求失败,%v", k, err) } } }() // 等待关闭 go func() { select { case <-ExitChan: case <-myexitChan: } conn.Close() }() // 接收查询结果并输入到信道 buf := make([]byte, 1500) for { n, addr, err := conn.ReadFromUDP(buf) if err != nil { return } nbuf := buf[:n] r := new(dns.Msg) if err := r.Unpack(nbuf); err != nil { log.Printf("解 DNS 包失败,%v", err) continue } if r.Id != m.Id { log.Printf("错误的dns id。") continue } v := servers[*addr] if v == nil { log.Printf("未知的服务器 %v 回应,忽略。", *addr) continue } for _, a := range r.Answer { dnsA, err := a.(*dns.A) if err != nil || dnsA == nil { log.Printf("内部错误,a=%v,err=%v", dnsA, err) } select { case RecordChan <- &DnsRecord{ Ip: dnsA.A.String(), Credit: v.Credit, }: case <-ExitChan: return } } } }
func TestServerSimpleQuery(t *testing.T) { var ( testRecord1 = Record{"test.weave.local.", net.ParseIP("10.20.20.10"), 0, 0, 0} testRecord2 = Record{"test.weave.local.", net.ParseIP("10.20.20.20"), 0, 0, 0} testInAddr1 = "10.20.20.10.in-addr.arpa." ) InitDefaultLogging(testing.Verbose()) Info.Println("TestServerSimpleQuery starting") mzone := newMockedZoneWithRecords([]ZoneRecord{testRecord1, testRecord2}) mdnsServer, err := NewMDNSServer(mzone, true, DefaultLocalTTL) require.NoError(t, err) err = mdnsServer.Start(nil) require.NoError(t, err) defer mdnsServer.Stop() var receivedAddrs []net.IP receivedName := "" recvChan := make(chan interface{}) receivedCount := 0 // Implement a minimal listener for responses multicast, err := LinkLocalMulticastListener(nil) require.NoError(t, err) handleMDNS := func(w dns.ResponseWriter, r *dns.Msg) { // Only handle responses here if len(r.Answer) > 0 { t.Logf("Received %d answer(s)", len(r.Answer)) for _, answer := range r.Answer { recvChan <- answer } recvChan <- "ok" } } sendQuery := func(name string, querytype uint16) { receivedAddrs = make([]net.IP, 0) receivedName = "" receivedCount = 0 m := new(dns.Msg) m.SetQuestion(name, querytype) m.RecursionDesired = false buf, err := m.Pack() require.NoError(t, err) conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) require.NoError(t, err) Debug.Printf("Sending UDP packet to %s", ipv4Addr) _, err = conn.WriteTo(buf, ipv4Addr) require.NoError(t, err) Debug.Printf("Waiting for response") for { select { case x := <-recvChan: switch rr := x.(type) { case *dns.A: t.Logf("... A:\n%+v", rr) receivedAddrs = append(receivedAddrs, rr.A) receivedCount++ case *dns.PTR: t.Logf("... 
PTR:\n%+v", rr) receivedName = rr.Ptr receivedCount++ case string: return } case <-time.After(100 * time.Millisecond): Debug.Printf("Timeout while waiting for response") return } } } listener := &dns.Server{ Unsafe: true, PacketConn: multicast, Handler: dns.HandlerFunc(handleMDNS), ReadTimeout: 100 * time.Millisecond} go listener.ActivateAndServe() defer listener.Shutdown() time.Sleep(100 * time.Millisecond) // Allow for server to get going Debug.Printf("Checking that we get 2 IPs fo name '%s' [A]", testRecord1.Name()) sendQuery(testRecord1.Name(), dns.TypeA) if receivedCount != 2 { t.Fatalf("Unexpected result count %d for %s", receivedCount, testRecord1.Name()) } if !(receivedAddrs[0].Equal(testRecord1.IP()) || receivedAddrs[0].Equal(testRecord2.IP())) { t.Fatalf("Unexpected result %s for %s", receivedAddrs, testRecord1.Name()) } if !(receivedAddrs[1].Equal(testRecord1.IP()) || receivedAddrs[1].Equal(testRecord2.IP())) { t.Fatalf("Unexpected result %s for %s", receivedAddrs, testRecord1.Name()) } Debug.Printf("Checking that 'testfail.weave.' [A] gets no answers") sendQuery("testfail.weave.", dns.TypeA) if receivedCount != 0 { t.Fatalf("Unexpected result count %d for testfail.weave", receivedCount) } Debug.Printf("Checking that '%s' [PTR] gets one name", testInAddr1) sendQuery(testInAddr1, dns.TypePTR) if receivedCount != 1 { t.Fatalf("Expected an answer to %s, got %d answers", testInAddr1, receivedCount) } else if !(testRecord1.Name() == receivedName) { t.Fatalf("Expected answer %s to query for %s, got %s", testRecord1.Name(), testInAddr1, receivedName) } }