Beispiel #1
0
// SendQuery asynchronously registers interest in responses to an mDNS query
// for name. The work is queued on c.actionChan and runs there.
//
// If no query for name is already in flight, a question of type querytype
// (with RD cleared) is packed and written to c.addr. On a pack or write
// failure the error is delivered on responseCh and responseCh is closed.
// Otherwise responseCh is added to the in-flight query's listeners with a
// deadline of mDNSTimeout from now.
//
// NOTE(review): when a query for name is already in flight, querytype is not
// compared against it — the caller joins the existing query regardless.
func (c *MDNSClient) SendQuery(name string, querytype uint16, insistent bool, responseCh chan<- *Response) {
	c.actionChan <- func() {
		// reportFailure hands the error to the caller and signals that no
		// further responses will arrive on this channel.
		reportFailure := func(err error) {
			responseCh <- &Response{err: err}
			close(responseCh)
		}

		query, found := c.inflight[name]
		if !found {
			msg := new(dns.Msg)
			msg.SetQuestion(name, querytype)
			msg.RecursionDesired = false
			packed, err := msg.Pack()
			if err != nil {
				reportFailure(err)
				return
			}
			if _, err := c.conn.WriteTo(packed, c.addr); err != nil {
				reportFailure(err)
				return
			}
			query = &inflightQuery{name: name, id: msg.Id}
			c.inflight[name] = query
		}

		// Invariant: query.responseInfos stays sorted by ascending timeout.
		// Every entry uses the same fixed interval from Now(), so appending
		// keeps that order.
		query.responseInfos = append(query.responseInfos, &responseInfo{
			ch:        responseCh,
			timeout:   time.Now().Add(mDNSTimeout),
			insistent: insistent,
		})
	}
}
Beispiel #2
0
// Insert packs d and stores it in c.Radix under toRadixKey(d), stamped with
// the current UTC time.
//
// The first two bytes of the DNS wire format are the message ID and are not
// stored (presumably so the cached packet can be re-stamped with the ID of a
// later query — confirm with the reader side). If d cannot be packed, the
// message is logged and not cached. Previously the ignored Pack error left buf
// nil, and buf[2:] panicked.
func (c *Cache) Insert(d *dns.Msg) {
	key := toRadixKey(d)
	if *flaglog {
		// Constant format string: the key may contain '%' characters.
		log.Printf("fsk-shield: inserting %s", key)
	}
	buf, err := d.Pack()
	if err != nil {
		log.Printf("fsk-shield: not caching %s: %v", key, err)
		return
	}
	c.Radix.Insert(key, &Packet{d: buf[2:], ttl: time.Now().UTC()})
}
Beispiel #3
0
// Add caches message if its first question is an A/IN question. For every A
// record in the answer section it also records a reverse mapping from the IP
// to the queried name (without the trailing dot). The cached entry expires
// after 24 hours.
//
// Messages with no question, or with any other question type or class, are
// ignored and nil is returned. A pack error is returned immediately. The error
// from storing the packed message is returned after the reverse mappings are
// processed; failures to store a reverse mapping are only logged.
func (this *cache) Add(message *dns.Msg) error {
	// Indexing Question[0] would panic on a message without questions.
	if len(message.Question) == 0 {
		return nil
	}
	question := message.Question[0]
	if question.Qtype != dns.TypeA || question.Qclass != dns.ClassINET {
		return nil
	}

	byteMessage, err := message.Pack()
	if err != nil {
		return err
	}

	err = this.collection.SetObject(question.Name, CacheRecord{Expiry: time.Now().Add(24 * time.Hour), Record: byteMessage})
	hostname := strings.TrimSuffix(question.Name, ".")
	for _, part := range message.Answer {
		switch rr := part.(type) {
		case *dns.A:
			ip := rr.A.String()
			if cacheErr := this.ipToHostname.SetObject(ip, hostname); cacheErr != nil {
				log.Println("Warning: Error Adding/Updating Cache IP:" + ip + " Hostname:" + hostname + " Error:" + cacheErr.Error())
			}
		case *dns.CNAME:
			// CNAMEs don't contain the IP needed for reverse lookups.
			// TODO: We should probably add the CNAME as a hostname??
		default:
			// reflect.Type.Name() is "" for pointer types (every RR here is a
			// pointer), so log the full type string instead, e.g. "*dns.AAAA".
			log.Printf("Type:%s Value:%v", reflect.TypeOf(part).String(), part)
		}
	}
	return err
}
Beispiel #4
0
// Discover sends one PTR query for domain to the IPv4 mDNS group
// (224.0.0.251:5353) and then calls c.handleReceiveMsg(domain, cb).
//
// The multicast socket is stored in c.conn for the duration of that call and
// is closed when Discover returns. Discover panics if the socket cannot be
// opened, or if the query cannot be packed or sent.
func (c *Client) Discover(domain string, cb func(*dns.Msg)) {
	m := new(dns.Msg)
	m.SetQuestion(dns.Fqdn(domain), dns.TypePTR)
	// RFC 6762 §18.6: the RD bit SHOULD be zero in multicast DNS messages.
	m.RecursionDesired = false

	group := &net.UDPAddr{
		IP:   net.ParseIP("224.0.0.251"),
		Port: 5353,
	}

	conn, err := net.ListenMulticastUDP("udp4", nil, group)
	if err != nil {
		panic(err)
	}
	defer conn.Close()
	c.conn = conn

	out, err := m.Pack()
	if err != nil {
		panic(err)
	}
	if _, err = conn.WriteToUDP(out, group); err != nil {
		panic(err)
	}

	c.handleReceiveMsg(domain, cb)
}
Beispiel #5
0
// sendResponse packs m and writes it over s.sendconn to ipv4Addr (presumably
// the IPv4 mDNS multicast group). It returns the pack or write error, if any.
func (s *MDNSServer) sendResponse(m *dns.Msg) error {
	wire, packErr := m.Pack()
	if packErr != nil {
		return packErr
	}
	if _, writeErr := s.sendconn.WriteTo(wire, ipv4Addr); writeErr != nil {
		return writeErr
	}
	return nil
}
Beispiel #6
0
// WriteDNSToUDP serializes m to wire format and sends it to addr via
// conn.WriteToUDP. The byte count is discarded; the pack or write error, if
// any, is returned.
func (conn *DNSUDPConn) WriteDNSToUDP(m *dns.Msg, addr *net.UDPAddr) error {
	packet, err := m.Pack()
	if err == nil {
		_, err = conn.WriteToUDP(packet, addr)
	}
	return err
}
Beispiel #7
0
// writeMessage encodes an mDNS message and sends it to addr through the
// connector's WriteToUDP. It returns the encoding or write error, if any.
func (c *connector) writeMessage(msg *dns.Msg, addr *net.UDPAddr) error {
	encoded, err := msg.Pack()
	if err != nil {
		return err
	}
	if _, err := c.WriteToUDP(encoded, addr); err != nil {
		return err
	}
	return nil
}
Beispiel #8
0
// CompressIfLarge turns on name compression for m when its packed size
// exceeds 512 bytes, the classic DNS-over-UDP payload limit (RFC 1035 §4.2.1).
// The flag only takes effect when m is packed again (e.g. by WriteMsg).
//
// If m is already marked for compression, nothing is done and the measuring
// Pack is skipped. If m cannot be packed, it is left unchanged.
func CompressIfLarge(m *dns.Msg) {
	const maxUDPPayload = 512
	if m.Compress {
		return
	}
	// Named "packed" rather than "bytes" so the bytes package is not shadowed.
	packed, err := m.Pack()
	if err != nil {
		return
	}
	if len(packed) > maxUDPPayload { // may not fit into a UDP packet
		m.Compress = true
	}
}
// TestRFC2136ValidUpdatePacket checks that Present sends the expected
// RFC 2136 UPDATE. After the IDs are aligned, the message received by the
// local test server must match a hand-built one byte for byte.
func TestRFC2136ValidUpdatePacket(t *testing.T) {
	dns.HandleFunc(rfc2136TestZone, serverHandlerPassBackRequest)
	defer dns.HandleRemove(rfc2136TestZone)

	server, addrstr, err := runLocalDNSTestServer("127.0.0.1:0", false)
	if err != nil {
		t.Fatalf("Failed to start test server: %v", err)
	}
	defer server.Shutdown()

	rr := new(dns.TXT)
	rr.Hdr = dns.RR_Header{
		Name:   rfc2136TestFqdn,
		Rrtype: dns.TypeTXT,
		Class:  dns.ClassINET,
		Ttl:    uint32(rfc2136TestTTL),
	}
	rr.Txt = []string{rfc2136TestValue}
	rrs := []dns.RR{rr}
	m := new(dns.Msg)
	m.SetUpdate(dns.Fqdn(rfc2136TestZone))
	m.Insert(rrs)
	expectstr := m.String()
	expect, err := m.Pack()
	if err != nil {
		t.Fatalf("Error packing expect msg: %v", err)
	}

	provider, err := NewDNSProviderRFC2136(addrstr, rfc2136TestZone, "", "")
	if err != nil {
		t.Fatalf("Expected NewDNSProviderRFC2136() to return no error but the error was -> %v", err)
	}

	if err := provider.Present(rfc2136TestDomain, "", "1234d=="); err != nil {
		// Fatal rather than Error: if Present failed, nothing may ever be
		// sent on reqChan and the receive below would block the test.
		t.Fatalf("Expected Present() to return no error but the error was -> %v", err)
	}

	// NOTE(review): assumes serverHandlerPassBackRequest forwards each
	// request it receives on reqChan — confirm in the test helpers.
	rcvMsg := <-reqChan
	// IDs differ per message; copy ours over so only the content is compared.
	rcvMsg.Id = m.Id
	actual, err := rcvMsg.Pack()
	if err != nil {
		t.Fatalf("Error packing actual msg: %v", err)
	}

	if !bytes.Equal(actual, expect) {
		tmp := new(dns.Msg)
		if err := tmp.Unpack(actual); err != nil {
			t.Fatalf("Error unpacking actual msg: %v", err)
		}
		t.Errorf("Expected msg:\n%s", expectstr)
		t.Errorf("Actual msg:\n%v", tmp)
	}
}
Beispiel #10
0
// sendQuery packs q and multicasts it on whichever listeners are set:
// c.ipv4List sends to ipv4Addr and c.ipv6List sends to ipv6Addr.
// Write errors are discarded; only a pack error is returned.
func (c *client) sendQuery(q *dns.Msg) error {
	payload, err := q.Pack()
	if err != nil {
		return err
	}
	if v4 := c.ipv4List; v4 != nil {
		v4.WriteTo(payload, ipv4Addr)
	}
	if v6 := c.ipv6List; v6 != nil {
		v6.WriteTo(payload, ipv6Addr)
	}
	return nil
}
Beispiel #11
0
// sendQuery packs msg and writes it to every available multicast
// connection: c.ipv4conn sends to ipv4Addr and c.ipv6conn sends to ipv6Addr.
// Write errors are discarded; only a pack error is returned.
func (c *client) sendQuery(msg *dns.Msg) error {
	buf, err := msg.Pack()
	if err == nil {
		if c.ipv4conn != nil {
			c.ipv4conn.WriteTo(buf, ipv4Addr)
		}
		if c.ipv6conn != nil {
			c.ipv6conn.WriteTo(buf, ipv6Addr)
		}
	}
	return err
}
Beispiel #12
0
// sendQuery packs q and multicasts it on whichever listeners are set, using
// the group addresses from params: c.ipv4List sends to params.IPv4mdns and
// c.ipv6List sends to params.IPv6mdns. Write errors are discarded; only a
// pack error is returned.
func (c *client) sendQuery(q *dns.Msg, params *QueryParam) error {
	wire, err := q.Pack()
	if err != nil {
		return err
	}
	// params is only dereferenced for a listener that exists.
	if v4 := c.ipv4List; v4 != nil {
		v4.WriteTo(wire, params.IPv4mdns)
	}
	if v6 := c.ipv6List; v6 != nil {
		v6.WriteTo(wire, params.IPv6mdns)
	}
	return nil
}
Beispiel #13
0
// sendQuery packs q and sends it to the multicast groups through whichever
// unicast sockets are open: c.ipv4UnicastConn sends to ipv4Addr and
// c.ipv6UnicastConn sends to ipv6Addr. Write errors are discarded; only a
// pack error is returned.
func (c *client) sendQuery(q *dns.Msg) error {
	packet, err := q.Pack()
	if err != nil {
		return err
	}
	if v4 := c.ipv4UnicastConn; v4 != nil {
		v4.WriteToUDP(packet, ipv4Addr)
	}
	if v6 := c.ipv6UnicastConn; v6 != nil {
		v6.WriteToUDP(packet, ipv6Addr)
	}
	return nil
}
Beispiel #14
0
// multicastResponse packs msg and multicasts it on whichever connections are
// set: c.ipv4conn sends to ipv4Addr and c.ipv6conn sends to ipv6Addr. A pack
// failure is logged and returned; write errors are discarded.
func (c *Server) multicastResponse(msg *dns.Msg) error {
	packed, packErr := msg.Pack()
	if packErr != nil {
		log.Println("Failed to pack message!")
		return packErr
	}
	if v4 := c.ipv4conn; v4 != nil {
		v4.WriteTo(packed, ipv4Addr)
	}
	if v6 := c.ipv6conn; v6 != nil {
		v6.WriteTo(packed, ipv6Addr)
	}
	return nil
}
Beispiel #15
0
// unicastResponse packs resp and sends it directly to from. It uses
// s.ipv4conn when from has an IPv4 address and s.ipv6conn otherwise.
//
// from must be a non-nil *net.UDPAddr. Any other address type is reported as
// a *net.AddrError instead of panicking on the type assertion.
func (s *Server) unicastResponse(resp *dns.Msg, from net.Addr) error {
	buf, err := resp.Pack()
	if err != nil {
		return err
	}
	addr, ok := from.(*net.UDPAddr)
	if !ok || addr == nil {
		desc := "<nil>"
		if from != nil {
			desc = from.String()
		}
		return &net.AddrError{Err: "unicast response requires a *net.UDPAddr", Addr: desc}
	}
	if addr.IP.To4() != nil {
		_, err = s.ipv4conn.WriteToUDP(buf, addr)
		return err
	}
	_, err = s.ipv6conn.WriteToUDP(buf, addr)
	return err
}
Beispiel #16
0
// TestRFC2136ValidUpdatePacket checks that Present sends the expected
// RFC 2136 UPDATE: remove the TXT RRset, then insert the new TXT record.
// After the IDs are aligned, the message received by the local test server
// must match a hand-built one byte for byte.
func TestRFC2136ValidUpdatePacket(t *testing.T) {
	acme.ClearFqdnCache()
	dns.HandleFunc(rfc2136TestZone, serverHandlerPassBackRequest)
	defer dns.HandleRemove(rfc2136TestZone)

	server, addrstr, err := runLocalDNSTestServer("127.0.0.1:0", false)
	if err != nil {
		t.Fatalf("Failed to start test server: %v", err)
	}
	defer server.Shutdown()

	txtRR, err := dns.NewRR(fmt.Sprintf("%s %d IN TXT %s", rfc2136TestFqdn, rfc2136TestTTL, rfc2136TestValue))
	if err != nil {
		// An ignored parse error would put a nil RR into the message and
		// panic later instead of reporting the real cause.
		t.Fatalf("Error building expected TXT record: %v", err)
	}
	rrs := []dns.RR{txtRR}
	m := new(dns.Msg)
	m.SetUpdate(rfc2136TestZone)
	m.RemoveRRset(rrs)
	m.Insert(rrs)
	expectstr := m.String()
	expect, err := m.Pack()
	if err != nil {
		t.Fatalf("Error packing expect msg: %v", err)
	}

	provider, err := NewDNSProvider(addrstr, "", "", "")
	if err != nil {
		t.Fatalf("Expected NewDNSProvider() to return no error but the error was -> %v", err)
	}

	if err := provider.Present(rfc2136TestDomain, "", "1234d=="); err != nil {
		// Fatal rather than Error: if Present failed, nothing may ever be
		// sent on reqChan and the receive below would block the test.
		t.Fatalf("Expected Present() to return no error but the error was -> %v", err)
	}

	// NOTE(review): assumes serverHandlerPassBackRequest forwards each
	// request it receives on reqChan — confirm in the test helpers.
	rcvMsg := <-reqChan
	// IDs differ per message; copy ours over so only the content is compared.
	rcvMsg.Id = m.Id
	actual, err := rcvMsg.Pack()
	if err != nil {
		t.Fatalf("Error packing actual msg: %v", err)
	}

	if !bytes.Equal(actual, expect) {
		tmp := new(dns.Msg)
		if err := tmp.Unpack(actual); err != nil {
			t.Fatalf("Error unpacking actual msg: %v", err)
		}
		t.Errorf("Expected msg:\n%s", expectstr)
		t.Errorf("Actual msg:\n%v", tmp)
	}
}
Beispiel #17
0
// sendResponse packs resp and sends it back to from. It uses the IPv4
// listener when from has an IPv4 address and the IPv6 listener otherwise.
//
// The unicast argument is currently ignored (see TODO): responses always go
// directly to from. from must be a non-nil *net.UDPAddr. Any other address
// type is reported as a *net.AddrError instead of panicking on the type
// assertion.
func (s *Server) sendResponse(resp *dns.Msg, from net.Addr, unicast bool) error {
	// TODO(reddaly): Respect the unicast argument, and allow sending responses
	// over multicast.
	buf, err := resp.Pack()
	if err != nil {
		return err
	}

	// Determine the socket to send from.
	addr, ok := from.(*net.UDPAddr)
	if !ok || addr == nil {
		desc := "<nil>"
		if from != nil {
			desc = from.String()
		}
		return &net.AddrError{Err: "response destination is not a *net.UDPAddr", Addr: desc}
	}
	if addr.IP.To4() != nil {
		_, err = s.ipv4List.WriteToUDP(buf, addr)
	} else {
		_, err = s.ipv6List.WriteToUDP(buf, addr)
	}
	return err
}
Beispiel #18
0
// Query forwards r to the gRPC backend as a raw wire-format message and
// unpacks the reply. Each RPC gets a 2-second deadline. tr is currently
// unused.
//
// If unpacking the reply fails, the (possibly partially filled) message is
// returned together with the unpack error.
func (g *grpcResolver) Query(r *dns.Msg, tr trace.Trace) (*dns.Msg, error) {
	packed, err := r.Pack()
	if err != nil {
		return nil, err
	}

	// DNS clients rarely wait longer than this, so a longer RPC deadline
	// would not help.
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	reply, err := g.client.Query(ctx, &pb.RawMsg{Data: packed})
	if err != nil {
		return nil, err
	}

	answer := new(dns.Msg)
	return answer, answer.Unpack(reply.Data)
}
Beispiel #19
0
// sendQuery packs q and multicasts it through whichever unicast sockets are
// open: IPv4 to ipv4Addr and IPv6 to ipv6Addr. When iface is non-nil, it is
// selected as the outgoing multicast interface first. Errors from
// SetMulticastInterface and WriteTo are discarded; only a pack error is
// returned.
//
// NOTE(review): SetMulticastInterface changes an option on the underlying
// socket, so a later call with iface == nil presumably keeps the previously
// selected interface — confirm whether that is intended.
func (c *client) sendQuery(q *dns.Msg, iface *net.Interface) error {
	payload, err := q.Pack()
	if err != nil {
		return err
	}
	if v4 := c.ipv4UnicastConn; v4 != nil {
		pc := ipv4.NewPacketConn(v4)
		if iface != nil {
			pc.SetMulticastInterface(iface)
		}
		pc.WriteTo(payload, nil, ipv4Addr)
	}
	if v6 := c.ipv6UnicastConn; v6 != nil {
		pc := ipv6.NewPacketConn(v6)
		if iface != nil {
			pc.SetMulticastInterface(iface)
		}
		pc.WriteTo(payload, nil, ipv6Addr)
	}
	return nil
}
/*
func (this ClientProxy) CreateHTTPClient() {
	if this.start_TLS == false {
		this.client = &http.Client{}
	} else {
		if this.TLS_Path == "" {
			tr := &http.Transport{
				TLSClientConfig:    &tls.Config{InsecureSkipVerify: true},
				DisableCompression: true}
			this.client = &http.Client{Transport: tr}
		} else {
			pool := x509.NewCertPool()
			caCrt, err := ioutil.ReadFile(this.TLS_Path)
			if err != nil {
				_D("invalid cetificate path: %s", this.TLS_Path)
				return
			}
			pool.AppendCertsFromPEM(caCrt)
			tr := &http.Transport{
				TLSClientConfig: &tls.Config{RootCAs: pool},
			}
			this.client = &http.Client{Transport: tr}
		}
	}
}
*/
// ServeDNS proxies a DNS query over HTTP(S). The packed request is POSTed to
// "/proxy_dns" on a randomly chosen entry of this.SERVERS, and the body of the
// HTTP response is unpacked and written back to the DNS client.
//
// HTTPS is used when this.start_TLS is set. Certificate verification is then
// disabled if this.TLS_Path is empty; otherwise the CA certificates in
// TLS_Path are trusted. Every failure answers the client via SRVFAIL and logs
// the reason with _D. The exceptions are a query with no question, which is
// dropped, and a failure to send the final reply, which is only logged.
func (this ClientProxy) ServeDNS(w dns.ResponseWriter, request *dns.Msg) {
	// Every log line below reads request.Question[0]; a query without a
	// question would panic there, so it is dropped up front.
	if len(request.Question) == 0 {
		_D("dropping query without a question from %s", w.RemoteAddr())
		return
	}
	_LOG("get %s query from %s", request.Question[0].Name, w.RemoteAddr())
	request_bytes, err := request.Pack()
	if err != nil {
		SRVFAIL(w, request)
		_D("error in packing request from %s for '%s', error message: %s",
			w.RemoteAddr(), request.Question[0].Name, err)
		return
	}

	// rand.Intn panics on 0, so an empty server list must be caught here.
	if len(this.SERVERS) == 0 {
		SRVFAIL(w, request)
		_D("no proxy servers configured for query from %s for '%s'",
			w.RemoteAddr(), request.Question[0].Name)
		return
	}
	ServerInput := this.SERVERS[rand.Intn(len(this.SERVERS))]

	scheme := "http://"
	if this.start_TLS { // if it is TLS, use HTTPS
		scheme = "https://"
	}
	host := ServerInput
	// Only IPv6 literals need brackets in a URL. IPv4 addresses and anything
	// that is not a bare IP (host names, "ip:port") are used as-is.
	if ip := net.ParseIP(ServerInput); ip != nil && ip.To4() == nil {
		host = "[" + ServerInput + "]"
	}
	ServerInputurl := scheme + host + "/proxy_dns"

	req, err := http.NewRequest("POST", ServerInputurl, bytes.NewReader(request_bytes)) //need add random here in future
	if err != nil {
		SRVFAIL(w, request)
		_D("error in creating HTTP request from %s for '%s', error message: %s",
			w.RemoteAddr(), request.Question[0].Name, err)
		return
	}
	// NOTE(review): net/http does not send a "Host" entry from req.Header on
	// client requests (it uses req.Host / the URL host instead), so this line
	// has no effect on the wire.
	req.Header.Add("Host", ServerInput)
	req.Header.Add("Accept", "application/octet-stream")
	req.Header.Add("Content-Type", "application/octet-stream")
	switch this.TransPro {
	case UDPcode:
		req.Header.Add("Proxy-DNS-Transport", "UDP")
	case TCPcode:
		req.Header.Add("Proxy-DNS-Transport", "TCP")
	}

	// ServeDNS has a value receiver, so a client assigned to this.client
	// would only live for this call anyway; keep it in a local instead.
	var client *http.Client
	switch {
	case !this.start_TLS:
		// HTTP version
		client = &http.Client{Transport: &http.Transport{
			DisableKeepAlives: true,
			TLSNextProto:      nil,
		}}
	case this.TLS_Path == "":
		// HTTPS with certificate verification disabled.
		// NOTE(review): InsecureSkipVerify means the upstream is not
		// authenticated; anyone on the path can impersonate it.
		client = &http.Client{Transport: &http.Transport{
			TLSClientConfig:    &tls.Config{InsecureSkipVerify: true},
			DisableCompression: true,
		}}
	default:
		// HTTPS trusting the CA certificate(s) in TLS_Path.
		pool := x509.NewCertPool()
		caCrt, err := ioutil.ReadFile(this.TLS_Path)
		if err != nil {
			// Answer the client instead of leaving it without any reply.
			SRVFAIL(w, request)
			_D("invalid cetificate path: %s", this.TLS_Path)
			return
		}
		pool.AppendCertsFromPEM(caCrt)
		client = &http.Client{Transport: &http.Transport{
			TLSClientConfig: &tls.Config{RootCAs: pool},
		}}
	}

	resp, err := client.Do(req)
	if err != nil {
		SRVFAIL(w, request)
		_D("error in HTTP post request for query from %s for '%s', error message: %s",
			w.RemoteAddr(), request.Question[0].Name, err)
		return
	}
	// The body must always be closed, or the connection and its buffers leak.
	defer resp.Body.Close()

	if resp.StatusCode >= 500 {
		SRVFAIL(w, request)
		_D("HTTP ERROR: %s", http.StatusText(resp.StatusCode))
		// Stop here: the client already got SRVFAIL and must not get a
		// second reply.
		return
	}

	responseBody, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		SRVFAIL(w, request)
		_D("error in reading HTTP response for query from %s for '%s', error message: %s",
			w.RemoteAddr(), request.Question[0].Name, err)
		return
	}
	// ReadAll reads to EOF, so fewer bytes than the declared Content-Length
	// means the body was cut short. (ContentLength is -1 when unknown, so
	// this check then never triggers.)
	if len(responseBody) < int(resp.ContentLength) {
		SRVFAIL(w, request)
		_D("failure reading all HTTP content for query from %s for '%s' (%d of %d bytes read)",
			w.RemoteAddr(), request.Question[0].Name,
			len(responseBody), int(resp.ContentLength))
		return
	}

	var reply dns.Msg
	if err = reply.Unpack(responseBody); err != nil {
		SRVFAIL(w, request)
		_D("error in unpacking HTTP response for query from %s for '%s', error message: %s",
			w.RemoteAddr(), request.Question[0].Name, err)
		return
	}
	if err = w.WriteMsg(&reply); err != nil {
		_D("error in sending DNS response back for query from %s for '%s', error message: %s",
			w.RemoteAddr(), request.Question[0].Name, err)
		return
	}
}
Beispiel #21
0
// handleReflect answers a query with the client's own address. The reply
// holds an A record (IPv4 client) or AAAA record (otherwise) for dom, plus a
// TXT record for dom giving the client's source port and transport.
//
// How the records are placed depends on the question type:
//   - TXT: the TXT record is the answer and the address record goes in Extra.
//   - AXFR/IXFR: a zone transfer of SOA, TXT, address record, SOA is sent.
//   - Anything else: the address record is the answer and TXT goes in Extra.
//
// TSIG-signed requests are signed in the reply when w reports no TSIG error.
// Queries for "tc.miek.nl." get the TC bit set and only the first half of the
// packed reply (presumably to exercise truncation handling in clients).
// Queries without a question are dropped.
func handleReflect(w dns.ResponseWriter, r *dns.Msg) {
	// NOTE(review): reflectHandled is updated without synchronisation; if
	// handlers run concurrently this is a data race — confirm.
	reflectHandled++
	if reflectHandled%1000 == 0 {
		fmt.Printf("Served %d reflections\n", reflectHandled)
	}
	// Everything below indexes r.Question[0] (and m.Question[0] after
	// SetReply); a query without a question would panic.
	if len(r.Question) == 0 {
		return
	}
	var (
		v4  bool
		rr  dns.RR
		str string
		a   net.IP
	)
	m := new(dns.Msg)
	m.SetReply(r)
	m.Compress = *compress
	if ip, ok := w.RemoteAddr().(*net.UDPAddr); ok {
		str = "Port: " + strconv.Itoa(ip.Port) + " (udp)"
		a = ip.IP
		v4 = a.To4() != nil
	}
	if ip, ok := w.RemoteAddr().(*net.TCPAddr); ok {
		str = "Port: " + strconv.Itoa(ip.Port) + " (tcp)"
		a = ip.IP
		v4 = a.To4() != nil
	}

	if v4 {
		rr = new(dns.A)
		rr.(*dns.A).Hdr = dns.RR_Header{Name: dom, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 0}
		rr.(*dns.A).A = a.To4()
	} else {
		rr = new(dns.AAAA)
		rr.(*dns.AAAA).Hdr = dns.RR_Header{Name: dom, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 0}
		rr.(*dns.AAAA).AAAA = a
	}

	t := new(dns.TXT)
	t.Hdr = dns.RR_Header{Name: dom, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 0}
	t.Txt = []string{str}

	switch r.Question[0].Qtype {
	case dns.TypeTXT:
		m.Answer = append(m.Answer, t)
		m.Extra = append(m.Extra, rr)
	default:
		// Unknown types are answered like A/AAAA questions.
		fallthrough
	case dns.TypeAAAA, dns.TypeA:
		m.Answer = append(m.Answer, rr)
		m.Extra = append(m.Extra, t)

	case dns.TypeAXFR, dns.TypeIXFR:
		c := make(chan *dns.Envelope)
		tr := new(dns.Transfer)
		defer close(c)
		err := tr.Out(w, r, c)
		if err != nil {
			return
		}
		soa, _ := dns.NewRR(`whoami.miek.nl. 0 IN SOA linode.atoom.net. miek.miek.nl. 2009032802 21600 7200 604800 3600`)
		c <- &dns.Envelope{RR: []dns.RR{soa, t, rr, soa}}
		w.Hijack()
		// w.Close() // Client closes connection
		return

	}

	// IsTsig returns the trailing TSIG record of the request, if any.
	if tsig := r.IsTsig(); tsig != nil {
		if w.TsigStatus() == nil {
			m.SetTsig(tsig.Hdr.Name, dns.HmacMD5, 300, time.Now().Unix())
		} else {
			println("Status", w.TsigStatus().Error())
		}
	}
	if *printf {
		fmt.Printf("%v\n", m.String())
	}
	// set TC when question is tc.miek.nl.
	if m.Question[0].Name == "tc.miek.nl." {
		m.Truncated = true
		// send half a message
		buf, _ := m.Pack()
		w.Write(buf[:len(buf)/2])
		return
	}
	w.WriteMsg(m)
}
Beispiel #22
0
// route answers a single client query. Steps:
//  1. Queries for which getKey fails are refused.
//  2. Queries claimed by handleSpecialNames are left to it.
//  3. A cached response that is still valid is copied, given the request's ID
//     and question, and served.
//  4. Otherwise, when the shared slip counter is non-zero and even, the
//     counter is bumped and the client gets a truncated reply (counter
//     divisible by 4) or a closed connection, without resolving.
//  5. Otherwise the query is resolved and cached until now + getMinTTL. On a
//     resolution error, a stale cached response is served if one exists;
//     if none exists, the connection is closed.
//
// A reply whose packed size exceeds getMaxPayloadSize(req) is replaced by a
// truncated reply built from the response header.
func route(w dns.ResponseWriter, req *dns.Msg) {
	keyP, err := getKey(req)
	if err != nil {
		failWithRcode(w, req, dns.RcodeRefused)
		return
	}
	if handleSpecialNames(w, req) {
		return
	}
	maxPayloadSize := getMaxPayloadSize(req)
	var resp *dns.Msg
	cacheValP, _ := cache.Get(*keyP)
	if cacheValP != nil {
		cacheVal := cacheValP.(CacheVal)
		remaining := -time.Since(cacheVal.ValidUntil)
		if remaining > 0 {
			resp = cacheVal.Response.Copy()
			resp.Id = req.Id
			resp.Question = req.Question
		}
	}
	if *debug {
		question := req.Question[0]
		cachedStr := ""
		if resp != nil {
			cachedStr = " (cached)"
		}
		log.Printf("%v\t%v %v%v\n", w.RemoteAddr(), question.Name, dns.TypeToString[question.Qtype], cachedStr)
	}
	if resp == nil {
		slipValue := atomic.LoadUint32(&slip)
		if slipValue > 0 && slipValue%2 == 0 {
			atomic.CompareAndSwapUint32(&slip, slipValue, slipValue+1)
			if slipValue%4 == 0 {
				sendTruncated(w, req.MsgHdr)
			} else {
				w.Close()
			}
			return
		}
	}
	if resp == nil {
		resp, err = resolve(req, keyP.DNSSEC)
		if err == nil {
			validUntil := time.Now().Add(getMinTTL(resp))
			cache.Add(*keyP, CacheVal{ValidUntil: validUntil, Response: resp})
		} else {
			if cacheValP == nil {
				w.Close()
				return
			}
			// Resolution failed: serve the stale cached answer.
			cacheVal := cacheValP.(CacheVal)
			resp = cacheVal.Response.Copy()
			resp.Id = req.Id
			resp.Question = req.Question
		}
	}
	// A pack failure yields an empty buffer here; WriteMsg below then hits
	// the same failure when it packs the message itself.
	packed, _ := resp.Pack()
	// Compare as int: uint16(len(packed)) would wrap for messages of 64 KiB
	// or more and let an oversized reply bypass truncation.
	if len(packed) > int(maxPayloadSize) {
		sendTruncated(w, resp.MsgHdr)
	} else {
		w.WriteMsg(resp)
	}
}
Beispiel #23
0
// proxyServe forwards a client query to the configured upstream resolvers and
// relays the reply, optionally caching it.
//
// Behavior:
//   - Requests with the Response bit set are ignored.
//   - AAAA questions are removed unless *ipv6 is set; if no question remains,
//     the request is ignored.
//   - With ENCACHE, a cached packed reply keyed by the MD5 of the request
//     (ID zeroed) is served on a hit.
//   - Otherwise each DNS entry (assumed to be a [server, proto] pair) is tried
//     in order until one returns at least one answer. proto "tcp" selects
//     clientTCP; anything else uses clientUDP.
//   - The last reply obtained is written back. With ENCACHE it is also cached
//     for the TTL of its first answer (0 when there are no answers).
//
// Errors are logged; a recovered panic is printed to stdout.
func proxyServe(w dns.ResponseWriter, req *dns.Msg) {
	var (
		key       string
		m         *dns.Msg
		err       error
		tried     bool
		data      []byte
		id        uint16
		query     []string
		questions []dns.Question
		used      string
	)

	// Keep a panic in one query from taking down the server.
	defer func() {
		if err := recover(); err != nil {
			fmt.Println(err)
		}
	}()

	if req.Response { // supposed responses sent to us are bogus
		return
	}

	query = make([]string, len(req.Question))

	for i, q := range req.Question {
		if q.Qtype != dns.TypeAAAA || *ipv6 {
			questions = append(questions, q)
		}
		query[i] = fmt.Sprintf("(%s %s %s)", q.Name, dns.ClassToString[q.Qclass], dns.TypeToString[q.Qtype])
	}

	if len(questions) == 0 {
		return
	}

	req.Question = questions

	// The cache key must not depend on the client-chosen ID, so hash the
	// request with the ID zeroed and then restore it.
	id = req.Id
	req.Id = 0
	key = toMd5(req.String())
	req.Id = id

	if ENCACHE {
		if reply, ok := conn.Get(key); ok {
			data, _ = reply.([]byte)
		}
		if len(data) > 0 {
			m = &dns.Msg{}
			m.Unpack(data)
			m.Id = id
			err = w.WriteMsg(m)

			if DEBUG > 0 {
				log.Printf("id: %5d cache: HIT %v\n", id, query)
			}

			goto end
		}
		if DEBUG > 0 {
			log.Printf("id: %5d cache: MISS %v\n", id, query)
		}
	}

	for i, parts := range DNS {
		// Named "server" so the dns package is not shadowed inside the loop.
		server := parts[0]
		proto := parts[1]
		tried = i > 0
		if DEBUG > 0 {
			if tried {
				log.Printf("id: %5d try: %v %s %s\n", id, query, server, proto)
			} else {
				log.Printf("id: %5d resolve: %v %s %s\n", id, query, server, proto)
			}
		}
		client := clientUDP
		if proto == "tcp" {
			client = clientTCP
		}
		m, _, err = client.Exchange(req, server)
		if err == nil && len(m.Answer) > 0 {
			used = server
			break
		}
	}

	// With no upstream servers configured the loop never runs and m stays
	// nil; report that instead of dereferencing m below.
	if m == nil && err == nil {
		err = fmt.Errorf("no upstream DNS servers configured")
	}

	if err == nil {
		if DEBUG > 0 {
			if tried {
				if len(m.Answer) == 0 {
					log.Printf("id: %5d failed: %v\n", id, query)
				} else {
					log.Printf("id: %5d bingo: %v %s\n", id, query, used)
				}
			}
		}
		data, err = m.Pack()
		if err == nil {
			_, err = w.Write(data)

			if err == nil {
				if ENCACHE {
					// Cache the ID-free form; restore the ID afterwards.
					m.Id = 0
					data, _ = m.Pack()
					ttl := 0
					if len(m.Answer) > 0 {
						ttl = int(m.Answer[0].Header().Ttl)
						if ttl < 0 {
							ttl = 0
						}
					}
					conn.Set(key, data, time.Second*time.Duration(ttl))
					m.Id = id
					if DEBUG > 0 {
						log.Printf("id: %5d cache: CACHED %v TTL %v\n", id, query, ttl)
					}
				}
			}
		}
	}

end:
	if DEBUG > 1 {
		fmt.Println(req)
		if m != nil {
			fmt.Println(m)
		}
	}
	if err != nil {
		log.Printf("id: %5d error: %v %s\n", id, query, err)
	}

	if DEBUG > 1 {
		fmt.Println("====================================================")
	}
}
Beispiel #24
0
// dnschanDnsQuery sends one A query for domain (RD set) to every server in
// servers, all from a single ephemeral UDP socket. Each IPv4 address found in
// the answers is sent to RecordChan, tagged with the answering server's
// Credit.
//
// It returns immediately if ExitChan is already readable. Otherwise it stops
// when ExitChan becomes readable, when the socket's 5-second deadline expires,
// or on any other read error. Replies that fail to unpack, carry a different
// ID, or come from an address not in servers are logged and skipped. Answer
// records other than A (e.g. CNAME) are skipped.
func dnschanDnsQuery(servers map[net.UDPAddr]*dnsServer, domain string, RecordChan chan *DnsRecord, ExitChan chan int) {
	select {
	case <-ExitChan:
		return
	default:
	}
	myexitChan := make(chan int)
	defer close(myexitChan)

	// Open a local UDP port for the queries.
	conn, err := net.ListenUDP("udp", nil)
	if err != nil {
		log.Printf("为dns请求打开udp失败,%v", err)
		return
	}
	defer conn.Close()

	conn.SetDeadline(time.Now().Add(5 * time.Second))

	// Build the DNS query.
	m := new(dns.Msg)
	m.SetQuestion(dns.Fqdn(domain), dns.TypeA)
	m.RecursionDesired = true
	mData, err := m.Pack()
	if err != nil {
		log.Printf("生成dns请求包失败,%v", err)
		return
	}

	// Send the queries from a separate goroutine.
	go func() {
		for k := range servers {
			if _, err := conn.WriteToUDP(mData, k); err != nil {
				log.Printf("向%v发送dns请求失败,%v", k, err)
			}
		}
	}()

	// Close the socket when the caller asks us to exit or this function
	// returns; closing it also unblocks the read loop below.
	go func() {
		select {
		case <-ExitChan:
		case <-myexitChan:
		}
		conn.Close()
	}()

	// Receive replies and forward the results to RecordChan.
	buf := make([]byte, 1500)
	for {
		n, addr, err := conn.ReadFromUDP(buf)
		if err != nil {
			return
		}
		nbuf := buf[:n]

		r := new(dns.Msg)
		if err := r.Unpack(nbuf); err != nil {
			log.Printf("解 DNS 包失败,%v", err)
			continue
		}

		if r.Id != m.Id {
			log.Printf("错误的dns id。")
			continue
		}

		v := servers[*addr]
		if v == nil {
			log.Printf("未知的服务器 %v 回应,忽略。", *addr)
			continue
		}

		for _, answer := range r.Answer {
			// Only A records carry an IPv4 address. A comma-ok assertion
			// yields a bool, and dereferencing a failed assertion's nil
			// *dns.A would panic, so skip anything else.
			dnsA, ok := answer.(*dns.A)
			if !ok || dnsA == nil {
				continue
			}
			select {
			case RecordChan <- &DnsRecord{
				Ip:     dnsA.A.String(),
				Credit: v.Credit,
			}:
			case <-ExitChan:
				return
			}
		}
	}
}
Beispiel #25
0
// TestServerSimpleQuery starts an MDNSServer over a mocked zone, multicasts
// queries from a throwaway UDP socket, and collects the answers through a
// minimal multicast listener. It checks three cases:
//   - an A query for a name with two records returns both IPs;
//   - an A query for an unknown name returns nothing;
//   - a PTR query for one record's in-addr.arpa name returns that name.
func TestServerSimpleQuery(t *testing.T) {
	var (
		testRecord1 = Record{"test.weave.local.", net.ParseIP("10.20.20.10"), 0, 0, 0}
		testRecord2 = Record{"test.weave.local.", net.ParseIP("10.20.20.20"), 0, 0, 0}
		testInAddr1 = "10.20.20.10.in-addr.arpa."
	)

	InitDefaultLogging(testing.Verbose())
	Info.Println("TestServerSimpleQuery starting")

	mzone := newMockedZoneWithRecords([]ZoneRecord{testRecord1, testRecord2})
	mdnsServer, err := NewMDNSServer(mzone, true, DefaultLocalTTL)
	require.NoError(t, err)
	err = mdnsServer.Start(nil)
	require.NoError(t, err)
	defer mdnsServer.Stop()

	var receivedAddrs []net.IP
	receivedName := ""
	recvChan := make(chan interface{})
	receivedCount := 0

	// Implement a minimal listener for responses
	multicast, err := LinkLocalMulticastListener(nil)
	require.NoError(t, err)

	// NOTE(review): recvChan is unbuffered, so a response arriving after
	// sendQuery has timed out blocks this handler until the next sendQuery
	// reads it (or forever at the end of the test).
	handleMDNS := func(w dns.ResponseWriter, r *dns.Msg) {
		// Only handle responses here
		if len(r.Answer) > 0 {
			t.Logf("Received %d answer(s)", len(r.Answer))
			for _, answer := range r.Answer {
				recvChan <- answer
			}
			recvChan <- "ok"
		}
	}

	sendQuery := func(name string, querytype uint16) {
		receivedAddrs = make([]net.IP, 0)
		receivedName = ""
		receivedCount = 0

		m := new(dns.Msg)
		m.SetQuestion(name, querytype)
		m.RecursionDesired = false
		buf, err := m.Pack()
		require.NoError(t, err)
		conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
		require.NoError(t, err)
		// Release the per-query socket; it was previously leaked.
		defer conn.Close()
		Debug.Printf("Sending UDP packet to %s", ipv4Addr)
		_, err = conn.WriteTo(buf, ipv4Addr)
		require.NoError(t, err)

		Debug.Printf("Waiting for response")
		for {
			select {
			case x := <-recvChan:
				switch rr := x.(type) {
				case *dns.A:
					t.Logf("... A:\n%+v", rr)
					receivedAddrs = append(receivedAddrs, rr.A)
					receivedCount++
				case *dns.PTR:
					t.Logf("... PTR:\n%+v", rr)
					receivedName = rr.Ptr
					receivedCount++
				case string:
					return
				}
			case <-time.After(100 * time.Millisecond):
				Debug.Printf("Timeout while waiting for response")
				return
			}
		}
	}

	listener := &dns.Server{
		Unsafe:      true,
		PacketConn:  multicast,
		Handler:     dns.HandlerFunc(handleMDNS),
		ReadTimeout: 100 * time.Millisecond}
	go listener.ActivateAndServe()
	defer listener.Shutdown()

	time.Sleep(100 * time.Millisecond) // Allow for server to get going

	Debug.Printf("Checking that we get 2 IPs fo name '%s' [A]", testRecord1.Name())
	sendQuery(testRecord1.Name(), dns.TypeA)
	// receivedCount also counts PTR records, so check the address slice too
	// before indexing it.
	if receivedCount != 2 || len(receivedAddrs) != 2 {
		t.Fatalf("Unexpected result count %d for %s", receivedCount, testRecord1.Name())
	}
	if !(receivedAddrs[0].Equal(testRecord1.IP()) || receivedAddrs[0].Equal(testRecord2.IP())) {
		t.Fatalf("Unexpected result %s for %s", receivedAddrs, testRecord1.Name())
	}
	if !(receivedAddrs[1].Equal(testRecord1.IP()) || receivedAddrs[1].Equal(testRecord2.IP())) {
		t.Fatalf("Unexpected result %s for %s", receivedAddrs, testRecord1.Name())
	}

	Debug.Printf("Checking that 'testfail.weave.' [A] gets no answers")
	sendQuery("testfail.weave.", dns.TypeA)
	if receivedCount != 0 {
		t.Fatalf("Unexpected result count %d for testfail.weave", receivedCount)
	}

	Debug.Printf("Checking that '%s' [PTR] gets one name", testInAddr1)
	sendQuery(testInAddr1, dns.TypePTR)
	if receivedCount != 1 {
		t.Fatalf("Expected an answer to %s, got %d answers", testInAddr1, receivedCount)
	} else if testRecord1.Name() != receivedName {
		t.Fatalf("Expected answer %s to query for %s, got %s", testRecord1.Name(), testInAddr1, receivedName)
	}
}