* (dnsforward): fix reverse lookups from /etc/hosts

There was a bug with empty PTR responses for IPs that are in the hosts
file

Closes: #2085
This commit is contained in:
Andrey Meshkov 2020-09-11 11:53:36 +03:00
parent caee4b86fa
commit cb8afde629
4 changed files with 125 additions and 80 deletions

View File

@ -9,13 +9,17 @@ import (
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io/ioutil"
"math/big" "math/big"
"net" "net"
"os"
"sort" "sort"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/AdGuardHome/dhcpd" "github.com/AdguardTeam/AdGuardHome/dhcpd"
"github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
@ -664,17 +668,17 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
func TestRewrite(t *testing.T) { func TestRewrite(t *testing.T) {
c := dnsfilter.Config{} c := dnsfilter.Config{}
c.Rewrites = []dnsfilter.RewriteEntry{ c.Rewrites = []dnsfilter.RewriteEntry{
dnsfilter.RewriteEntry{ {
Domain: "test.com", Domain: "test.com",
Answer: "1.2.3.4", Answer: "1.2.3.4",
Type: dns.TypeA, Type: dns.TypeA,
}, },
dnsfilter.RewriteEntry{ {
Domain: "alias.test.com", Domain: "alias.test.com",
Answer: "test.com", Answer: "test.com",
Type: dns.TypeCNAME, Type: dns.TypeCNAME,
}, },
dnsfilter.RewriteEntry{ {
Domain: "my.alias.example.org", Domain: "my.alias.example.org",
Answer: "example.org", Answer: "example.org",
Type: dns.TypeCNAME, Type: dns.TypeCNAME,
@ -1066,7 +1070,7 @@ func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {
return return
} }
func TestPTRResponse(t *testing.T) { func TestPTRResponseFromDHCPLeases(t *testing.T) {
dhcp := &testDHCP{} dhcp := &testDHCP{}
c := dnsfilter.Config{} c := dnsfilter.Config{}
@ -1094,3 +1098,45 @@ func TestPTRResponse(t *testing.T) {
s.Close() s.Close()
} }
func TestPTRResponseFromHosts(t *testing.T) {
c := dnsfilter.Config{
AutoHosts: &util.AutoHosts{},
}
// Prepare test hosts file
hf, _ := ioutil.TempFile("", "")
defer func() { _ = os.Remove(hf.Name()) }()
defer hf.Close()
_, _ = hf.WriteString(" 127.0.0.1 host # comment \n")
_, _ = hf.WriteString(" ::1 localhost#comment \n")
// Init auto hosts
c.AutoHosts.Init(hf.Name())
defer c.AutoHosts.Close()
f := dnsfilter.New(&c, nil)
s := NewServer(DNSCreateParams{DNSFilter: f})
s.conf.UDPListenAddr = &net.UDPAddr{Port: 0}
s.conf.TCPListenAddr = &net.TCPAddr{Port: 0}
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
s.conf.FilteringConfig.ProtectionEnabled = true
err := s.Prepare(nil)
assert.True(t, err == nil)
assert.Nil(t, s.Start())
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
req := createTestMessage("1.0.0.127.in-addr.arpa.")
req.Question[0].Qtype = dns.TypePTR
resp, err := dns.Exchange(req, addr.String())
assert.Nil(t, err)
assert.Equal(t, 1, len(resp.Answer))
assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)
ptr := resp.Answer[0].(*dns.PTR)
assert.Equal(t, "host.", ptr.Ptr)
s.Close()
}

View File

@ -51,7 +51,7 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) {
return nil, errorx.Decorate(err, "dnsfilter failed to check host '%s'", host) return nil, errorx.Decorate(err, "dnsfilter failed to check host '%s'", host)
} else if res.IsFiltered { } else if res.IsFiltered {
// log.Tracef("Host %s is filtered, reason - '%s', matched rule: '%s'", host, res.Reason, res.Rule) log.Tracef("Host %s is filtered, reason - '%s', matched rule: '%s'", host, res.Reason, res.Rule)
d.Res = s.genDNSFilterMessage(d, &res) d.Res = s.genDNSFilterMessage(d, &res)
} else if res.Reason == dnsfilter.ReasonRewrite && len(res.CanonName) != 0 && len(res.IPList) == 0 { } else if res.Reason == dnsfilter.ReasonRewrite && len(res.CanonName) != 0 && len(res.IPList) == 0 {
@ -59,6 +59,19 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) {
// resolve canonical name, not the original host name // resolve canonical name, not the original host name
d.Req.Question[0].Name = dns.Fqdn(res.CanonName) d.Req.Question[0].Name = dns.Fqdn(res.CanonName)
} else if res.Reason == dnsfilter.RewriteEtcHosts && len(res.ReverseHost) != 0 {
resp := s.makeResponse(req)
ptr := &dns.PTR{}
ptr.Hdr = dns.RR_Header{
Name: req.Question[0].Name,
Rrtype: dns.TypePTR,
Ttl: s.conf.BlockedResponseTTL,
Class: dns.ClassINET,
}
ptr.Ptr = res.ReverseHost
resp.Answer = append(resp.Answer, ptr)
d.Res = resp
} else if res.Reason == dnsfilter.ReasonRewrite || res.Reason == dnsfilter.RewriteEtcHosts { } else if res.Reason == dnsfilter.ReasonRewrite || res.Reason == dnsfilter.RewriteEtcHosts {
resp := s.makeResponse(req) resp := s.makeResponse(req)
@ -81,20 +94,6 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) {
} }
d.Res = resp d.Res = resp
} else if res.Reason == dnsfilter.RewriteEtcHosts && len(res.ReverseHost) != 0 {
resp := s.makeResponse(req)
ptr := &dns.PTR{}
ptr.Hdr = dns.RR_Header{
Name: req.Question[0].Name,
Rrtype: dns.TypePTR,
Ttl: s.conf.BlockedResponseTTL,
Class: dns.ClassINET,
}
ptr.Ptr = res.ReverseHost
resp.Answer = append(resp.Answer, ptr)
d.Res = resp
} }
return &res, err return &res, err

View File

@ -10,9 +10,10 @@ import (
"strings" "strings"
"sync" "sync"
"github.com/miekg/dns"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/fsnotify/fsnotify" "github.com/fsnotify/fsnotify"
"github.com/miekg/dns"
) )
type onChangedT func() type onChangedT func()
@ -62,6 +63,9 @@ func (a *AutoHosts) Init(hostsFn string) {
a.hostsDirs = append(a.hostsDirs, "/tmp/hosts") // OpenWRT: "/tmp/hosts/dhcp.cfg01411c" a.hostsDirs = append(a.hostsDirs, "/tmp/hosts") // OpenWRT: "/tmp/hosts/dhcp.cfg01411c"
} }
// Load hosts initially
a.updateHosts()
var err error var err error
a.watcher, err = fsnotify.NewWatcher() a.watcher, err = fsnotify.NewWatcher()
if err != nil { if err != nil {
@ -102,6 +106,62 @@ func (a *AutoHosts) Close() {
} }
} }
// Process - get the list of IP addresses for the hostname
// Return nil if not found
func (a *AutoHosts) Process(host string, qtype uint16) []net.IP {
if qtype == dns.TypePTR {
return nil
}
var ipsCopy []net.IP
a.lock.Lock()
ips, _ := a.table[host]
if len(ips) != 0 {
ipsCopy = make([]net.IP, len(ips))
copy(ipsCopy, ips)
}
a.lock.Unlock()
log.Debug("AutoHosts: answer: %s -> %v", host, ipsCopy)
return ipsCopy
}
// ProcessReverse - process PTR request
// Return "" if not found or an error occurred
func (a *AutoHosts) ProcessReverse(addr string, qtype uint16) string {
if qtype != dns.TypePTR {
return ""
}
ipReal := DNSUnreverseAddr(addr)
if ipReal == nil {
return "" // invalid IP in question
}
ipStr := ipReal.String()
a.lock.Lock()
host := a.tableReverse[ipStr]
a.lock.Unlock()
if len(host) == 0 {
return "" // not found
}
log.Debug("AutoHosts: reverse-lookup: %s -> %s", addr, host)
return host
}
// List - get "IP -> hostname" table. Thread-safe.
func (a *AutoHosts) List() map[string]string {
table := make(map[string]string)
a.lock.Lock()
for k, v := range a.tableReverse {
table[k] = v
}
a.lock.Unlock()
return table
}
// update table // update table
func (a *AutoHosts) updateTable(table map[string][]net.IP, host string, ipAddr net.IP) { func (a *AutoHosts) updateTable(table map[string][]net.IP, host string, ipAddr net.IP) {
ips, ok := table[host] ips, ok := table[host]
@ -275,59 +335,3 @@ func (a *AutoHosts) updateHosts() {
a.notify() a.notify()
} }
// Process - get the list of IP addresses for the hostname
// Return nil if not found
func (a *AutoHosts) Process(host string, qtype uint16) []net.IP {
if qtype == dns.TypePTR {
return nil
}
var ipsCopy []net.IP
a.lock.Lock()
ips, _ := a.table[host]
if len(ips) != 0 {
ipsCopy = make([]net.IP, len(ips))
copy(ipsCopy, ips)
}
a.lock.Unlock()
log.Debug("AutoHosts: answer: %s -> %v", host, ipsCopy)
return ipsCopy
}
// ProcessReverse - process PTR request
// Return "" if not found or an error occurred
func (a *AutoHosts) ProcessReverse(addr string, qtype uint16) string {
if qtype != dns.TypePTR {
return ""
}
ipReal := DNSUnreverseAddr(addr)
if ipReal == nil {
return "" // invalid IP in question
}
ipStr := ipReal.String()
a.lock.Lock()
host := a.tableReverse[ipStr]
a.lock.Unlock()
if len(host) == 0 {
return "" // not found
}
log.Debug("AutoHosts: reverse-lookup: %s -> %s", addr, host)
return host
}
// List - get "IP -> hostname" table. Thread-safe.
func (a *AutoHosts) List() map[string]string {
table := make(map[string]string)
a.lock.Lock()
for k, v := range a.tableReverse {
table[k] = v
}
a.lock.Unlock()
return table
}

View File

@ -34,9 +34,6 @@ func TestAutoHostsResolution(t *testing.T) {
ah.Init(f.Name()) ah.Init(f.Name())
// Update from the hosts file
ah.updateHosts()
// Existing host // Existing host
ips := ah.Process("localhost", dns.TypeA) ips := ah.Process("localhost", dns.TypeA)
assert.NotNil(t, ips) assert.NotNil(t, ips)
@ -79,7 +76,6 @@ func TestAutoHostsFSNotify(t *testing.T) {
// Init // Init
_, _ = f.WriteString(" 127.0.0.1 host localhost \n") _, _ = f.WriteString(" 127.0.0.1 host localhost \n")
ah.Init(f.Name()) ah.Init(f.Name())
ah.updateHosts()
// Unknown host // Unknown host
ips := ah.Process("newhost", dns.TypeA) ips := ah.Process("newhost", dns.TypeA)