From cb8afde629f272be79782b8e5bee4908ae97eabc Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Fri, 11 Sep 2020 11:53:36 +0300 Subject: [PATCH] * (dnsforward): fix reverse lookups from /etc/hosts There was a bug with empty PTR responses for IPs that are in the hosts file Closes: #2085 --- dnsforward/dnsforward_test.go | 54 ++++++++++++++-- dnsforward/filter.go | 29 ++++----- util/auto_hosts.go | 118 ++++++++++++++++++---------------- util/auto_hosts_test.go | 4 -- 4 files changed, 125 insertions(+), 80 deletions(-) diff --git a/dnsforward/dnsforward_test.go b/dnsforward/dnsforward_test.go index 664a5c1b..2fb89075 100644 --- a/dnsforward/dnsforward_test.go +++ b/dnsforward/dnsforward_test.go @@ -9,13 +9,17 @@ import ( "crypto/x509/pkix" "encoding/pem" "fmt" + "io/ioutil" "math/big" "net" + "os" "sort" "sync" "testing" "time" + "github.com/AdguardTeam/AdGuardHome/util" + "github.com/AdguardTeam/AdGuardHome/dhcpd" "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/dnsproxy/proxy" @@ -664,17 +668,17 @@ func TestBlockedBySafeBrowsing(t *testing.T) { func TestRewrite(t *testing.T) { c := dnsfilter.Config{} c.Rewrites = []dnsfilter.RewriteEntry{ - dnsfilter.RewriteEntry{ + { Domain: "test.com", Answer: "1.2.3.4", Type: dns.TypeA, }, - dnsfilter.RewriteEntry{ + { Domain: "alias.test.com", Answer: "test.com", Type: dns.TypeCNAME, }, - dnsfilter.RewriteEntry{ + { Domain: "my.alias.example.org", Answer: "example.org", Type: dns.TypeCNAME, @@ -1066,7 +1070,7 @@ func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) { return } -func TestPTRResponse(t *testing.T) { +func TestPTRResponseFromDHCPLeases(t *testing.T) { dhcp := &testDHCP{} c := dnsfilter.Config{} @@ -1094,3 +1098,45 @@ func TestPTRResponse(t *testing.T) { s.Close() } + +func TestPTRResponseFromHosts(t *testing.T) { + c := dnsfilter.Config{ + AutoHosts: &util.AutoHosts{}, + } + + // Prepare test hosts file + hf, _ := ioutil.TempFile("", "") + defer func() { _ = os.Remove(hf.Name()) }() + defer hf.Close() + + _, _ = hf.WriteString(" 127.0.0.1 host # comment \n") + _, _ = hf.WriteString(" ::1 localhost#comment \n") + + // Init auto hosts + c.AutoHosts.Init(hf.Name()) + defer c.AutoHosts.Close() + + f := dnsfilter.New(&c, nil) + s := NewServer(DNSCreateParams{DNSFilter: f}) + s.conf.UDPListenAddr = &net.UDPAddr{Port: 0} + s.conf.TCPListenAddr = &net.TCPAddr{Port: 0} + s.conf.UpstreamDNS = []string{"127.0.0.1:53"} + s.conf.FilteringConfig.ProtectionEnabled = true + err := s.Prepare(nil) + assert.True(t, err == nil) + assert.Nil(t, s.Start()) + + addr := s.dnsProxy.Addr(proxy.ProtoUDP) + req := createTestMessage("1.0.0.127.in-addr.arpa.") + req.Question[0].Qtype = dns.TypePTR + + resp, err := dns.Exchange(req, addr.String()) + assert.Nil(t, err) + assert.Equal(t, 1, len(resp.Answer)) + assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype) + assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name) + ptr := resp.Answer[0].(*dns.PTR) + assert.Equal(t, "host.", ptr.Ptr) + + s.Close() +} diff --git a/dnsforward/filter.go b/dnsforward/filter.go index 068c5112..4ef3979d 100644 --- a/dnsforward/filter.go +++ b/dnsforward/filter.go @@ -51,7 +51,7 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) { return nil, errorx.Decorate(err, "dnsfilter failed to check host '%s'", host) } else if res.IsFiltered { - // log.Tracef("Host %s is filtered, reason - '%s', matched rule: '%s'", host, res.Reason, res.Rule) + log.Tracef("Host %s is filtered, reason - '%s', matched rule: '%s'", host, res.Reason, res.Rule) d.Res = s.genDNSFilterMessage(d, &res) } else if res.Reason == dnsfilter.ReasonRewrite && len(res.CanonName) != 0 && len(res.IPList) == 0 { @@ -59,6 +59,19 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) { // resolve canonical name, not the original host name d.Req.Question[0].Name = dns.Fqdn(res.CanonName) + } else if res.Reason == dnsfilter.RewriteEtcHosts && len(res.ReverseHost) != 0 { + + resp := s.makeResponse(req) + ptr := &dns.PTR{} + ptr.Hdr = dns.RR_Header{ + Name: req.Question[0].Name, + Rrtype: dns.TypePTR, + Ttl: s.conf.BlockedResponseTTL, + Class: dns.ClassINET, + } + ptr.Ptr = res.ReverseHost + resp.Answer = append(resp.Answer, ptr) + d.Res = resp } else if res.Reason == dnsfilter.ReasonRewrite || res.Reason == dnsfilter.RewriteEtcHosts { resp := s.makeResponse(req) @@ -81,20 +94,6 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) { } d.Res = resp - - } else if res.Reason == dnsfilter.RewriteEtcHosts && len(res.ReverseHost) != 0 { - - resp := s.makeResponse(req) - ptr := &dns.PTR{} - ptr.Hdr = dns.RR_Header{ - Name: req.Question[0].Name, - Rrtype: dns.TypePTR, - Ttl: s.conf.BlockedResponseTTL, - Class: dns.ClassINET, - } - ptr.Ptr = res.ReverseHost - resp.Answer = append(resp.Answer, ptr) - d.Res = resp } return &res, err diff --git a/util/auto_hosts.go b/util/auto_hosts.go index b12acd81..1cec9c1a 100644 --- a/util/auto_hosts.go +++ b/util/auto_hosts.go @@ -10,9 +10,10 @@ import ( "strings" "sync" + "github.com/miekg/dns" + "github.com/AdguardTeam/golibs/log" "github.com/fsnotify/fsnotify" - "github.com/miekg/dns" ) type onChangedT func() @@ -62,6 +63,9 @@ func (a *AutoHosts) Init(hostsFn string) { a.hostsDirs = append(a.hostsDirs, "/tmp/hosts") // OpenWRT: "/tmp/hosts/dhcp.cfg01411c" } + // Load hosts initially + a.updateHosts() + var err error a.watcher, err = fsnotify.NewWatcher() if err != nil { @@ -102,6 +106,62 @@ func (a *AutoHosts) Close() { } } +// Process - get the list of IP addresses for the hostname +// Return nil if not found +func (a *AutoHosts) Process(host string, qtype uint16) []net.IP { + if qtype == dns.TypePTR { + return nil + } + + var ipsCopy []net.IP + a.lock.Lock() + ips, _ := a.table[host] + if len(ips) != 0 { + ipsCopy = make([]net.IP, len(ips)) + copy(ipsCopy, ips) + } + a.lock.Unlock() + + log.Debug("AutoHosts: answer: %s -> %v", host, ipsCopy) + return ipsCopy +} + +// ProcessReverse - process PTR request +// Return "" if not found or an error occurred +func (a *AutoHosts) ProcessReverse(addr string, qtype uint16) string { + if qtype != dns.TypePTR { + return "" + } + + ipReal := DNSUnreverseAddr(addr) + if ipReal == nil { + return "" // invalid IP in question + } + ipStr := ipReal.String() + + a.lock.Lock() + host := a.tableReverse[ipStr] + a.lock.Unlock() + + if len(host) == 0 { + return "" // not found + } + + log.Debug("AutoHosts: reverse-lookup: %s -> %s", addr, host) + return host +} + +// List - get "IP -> hostname" table. Thread-safe. +func (a *AutoHosts) List() map[string]string { + table := make(map[string]string) + a.lock.Lock() + for k, v := range a.tableReverse { + table[k] = v + } + a.lock.Unlock() + return table +} + // update table func (a *AutoHosts) updateTable(table map[string][]net.IP, host string, ipAddr net.IP) { ips, ok := table[host] @@ -275,59 +335,3 @@ func (a *AutoHosts) updateHosts() { a.notify() } - -// Process - get the list of IP addresses for the hostname -// Return nil if not found -func (a *AutoHosts) Process(host string, qtype uint16) []net.IP { - if qtype == dns.TypePTR { - return nil - } - - var ipsCopy []net.IP - a.lock.Lock() - ips, _ := a.table[host] - if len(ips) != 0 { - ipsCopy = make([]net.IP, len(ips)) - copy(ipsCopy, ips) - } - a.lock.Unlock() - - log.Debug("AutoHosts: answer: %s -> %v", host, ipsCopy) - return ipsCopy -} - -// ProcessReverse - process PTR request -// Return "" if not found or an error occurred -func (a *AutoHosts) ProcessReverse(addr string, qtype uint16) string { - if qtype != dns.TypePTR { - return "" - } - - ipReal := DNSUnreverseAddr(addr) - if ipReal == nil { - return "" // invalid IP in question - } - ipStr := ipReal.String() - - a.lock.Lock() - host := a.tableReverse[ipStr] - a.lock.Unlock() - - if len(host) == 0 { - return "" // not found - } - - log.Debug("AutoHosts: reverse-lookup: %s -> %s", addr, host) - return host -} - -// List - get "IP -> hostname" table. Thread-safe. -func (a *AutoHosts) List() map[string]string { - table := make(map[string]string) - a.lock.Lock() - for k, v := range a.tableReverse { - table[k] = v - } - a.lock.Unlock() - return table -} diff --git a/util/auto_hosts_test.go b/util/auto_hosts_test.go index ea2e43ad..efd94b99 100644 --- a/util/auto_hosts_test.go +++ b/util/auto_hosts_test.go @@ -34,9 +34,6 @@ func TestAutoHostsResolution(t *testing.T) { ah.Init(f.Name()) - // Update from the hosts file - ah.updateHosts() - // Existing host ips := ah.Process("localhost", dns.TypeA) assert.NotNil(t, ips) @@ -79,7 +76,6 @@ func TestAutoHostsFSNotify(t *testing.T) { // Init _, _ = f.WriteString(" 127.0.0.1 host localhost \n") ah.Init(f.Name()) - ah.updateHosts() // Unknown host ips := ah.Process("newhost", dns.TypeA)