diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index b9292f5f..61e9bee0 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -19,6 +19,7 @@ import ( "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/urlfilter" "github.com/bluele/gcache" + "github.com/miekg/dns" "golang.org/x/net/publicsuffix" ) @@ -147,7 +148,7 @@ func (r Reason) Matched() bool { } // CheckHost tries to match host against rules, then safebrowsing and parental if they are enabled -func (d *Dnsfilter) CheckHost(host string) (Result, error) { +func (d *Dnsfilter) CheckHost(host string, qtype uint16) (Result, error) { // sometimes DNS clients will try to resolve ".", which is a request to get root servers if host == "" { return Result{Reason: NotFilteredNotFound}, nil @@ -159,7 +160,7 @@ func (d *Dnsfilter) CheckHost(host string) (Result, error) { } // try filter lists first - result, err := d.matchHost(host) + result, err := d.matchHost(host, qtype) if err != nil { return result, err } @@ -517,7 +518,7 @@ func (d *Dnsfilter) initFiltering(filters map[int]string) error { } // matchHost is a low-level way to check only if hostname is filtered by rules, skipping expensive safebrowsing and parental lookups -func (d *Dnsfilter) matchHost(host string) (Result, error) { +func (d *Dnsfilter) matchHost(host string, qtype uint16) (Result, error) { if d.filteringEngine == nil { return Result{}, nil } @@ -527,6 +528,8 @@ func (d *Dnsfilter) matchHost(host string) (Result, error) { return Result{}, nil } + log.Tracef("%d rules matched for host '%s'", len(rules), host) + for _, rule := range rules { log.Tracef("Found rule for host '%s': '%s' list_id: %d", @@ -548,8 +551,24 @@ func (d *Dnsfilter) matchHost(host string) (Result, error) { } else if hostRule, ok := rule.(*urlfilter.HostRule); ok { - res.IP = hostRule.IP - return res, nil + if qtype == dns.TypeA && hostRule.IP.To4() != nil { + // either IPv4 or IPv4-mapped IPv6 address + res.IP = hostRule.IP.To4() + return res, nil + + } else if qtype == dns.TypeAAAA { + ip4 := hostRule.IP.To4() + if ip4 == nil { + res.IP = hostRule.IP + return res, nil + } + if bytes.Equal(ip4, []byte{0, 0, 0, 0}) { + // send IP="::" response for a rule "0.0.0.0 blockdomain" + res.IP = net.IPv6zero + return res, nil + } + } + continue } else { log.Tracef("Rule type is unsupported: '%s' list_id: %d", diff --git a/dnsfilter/dnsfilter_test.go b/dnsfilter/dnsfilter_test.go index 04861e8c..4619310f 100644 --- a/dnsfilter/dnsfilter_test.go +++ b/dnsfilter/dnsfilter_test.go @@ -9,6 +9,8 @@ import ( "runtime" "testing" "time" + + "github.com/miekg/dns" ) // HELPERS @@ -50,7 +52,7 @@ func NewForTestFilters(filters map[int]string) *Dnsfilter { func (d *Dnsfilter) checkMatch(t *testing.T, hostname string) { t.Helper() - ret, err := d.CheckHost(hostname) + ret, err := d.CheckHost(hostname, dns.TypeA) if err != nil { t.Errorf("Error while matching host %s: %s", hostname, err) } @@ -59,9 +61,9 @@ func (d *Dnsfilter) checkMatch(t *testing.T, hostname string) { } } -func (d *Dnsfilter) checkMatchIP(t *testing.T, hostname string, ip string) { +func (d *Dnsfilter) checkMatchIP(t *testing.T, hostname string, ip string, qtype uint16) { t.Helper() - ret, err := d.CheckHost(hostname) + ret, err := d.CheckHost(hostname, qtype) if err != nil { t.Errorf("Error while matching host %s: %s", hostname, err) } @@ -75,7 +77,7 @@ func (d *Dnsfilter) checkMatchIP(t *testing.T, hostname string, ip string) { func (d *Dnsfilter) checkMatchEmpty(t *testing.T, hostname string) { t.Helper() - ret, err := d.CheckHost(hostname) + ret, err := d.CheckHost(hostname, dns.TypeA) if err != nil { t.Errorf("Error while matching host %s: %s", hostname, err) } @@ -86,16 +88,25 @@ func (d *Dnsfilter) checkMatchEmpty(t *testing.T, hostname string) { func TestEtcHostsMatching(t *testing.T) { addr := "216.239.38.120" - text := fmt.Sprintf(" %s google.com www.google.com # enforce google's safesearch ", addr) + addr6 := "::1" + text := fmt.Sprintf(" %s google.com www.google.com # enforce google's safesearch \n%s google.com\n0.0.0.0 block.com\n", + addr, addr6) filters := make(map[int]string) filters[0] = text d := NewForTestFilters(filters) defer d.Destroy() - d.checkMatchIP(t, "google.com", addr) - d.checkMatchIP(t, "www.google.com", addr) + d.checkMatchIP(t, "google.com", addr, dns.TypeA) + d.checkMatchIP(t, "www.google.com", addr, dns.TypeA) d.checkMatchEmpty(t, "subdomain.google.com") d.checkMatchEmpty(t, "example.org") + + // IPv6 address + d.checkMatchIP(t, "google.com", addr6, dns.TypeAAAA) + + // block both IPv4 and IPv6 + d.checkMatchIP(t, "block.com", "0.0.0.0", dns.TypeA) + d.checkMatchIP(t, "block.com", "::", dns.TypeAAAA) } // SAFE BROWSING @@ -201,7 +212,7 @@ func TestCheckHostSafeSearchYandex(t *testing.T) { // Check host for each domain for _, host := range yandex { - result, err := d.CheckHost(host) + result, err := d.CheckHost(host, dns.TypeA) if err != nil { t.Errorf("SafeSearch doesn't work for yandex domain `%s` cause %s", host, err) } @@ -224,7 +235,7 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) { // Check host for each domain for _, host := range googleDomains { - result, err := d.CheckHost(host) + result, err := d.CheckHost(host, dns.TypeA) if err != nil { t.Errorf("SafeSearch doesn't work for %s cause %s", host, err) } @@ -244,7 +255,7 @@ func TestSafeSearchCacheYandex(t *testing.T) { var err error // Check host with disabled safesearch - result, err = d.CheckHost(domain) + result, err = d.CheckHost(domain, dns.TypeA) if err != nil { t.Fatalf("Cannot check host due to %s", err) } @@ -254,7 +265,7 @@ func TestSafeSearchCacheYandex(t *testing.T) { // Enable safesearch d.SafeSearchEnabled = true - result, err = d.CheckHost(domain) + result, err = d.CheckHost(domain, dns.TypeA) if err != nil { t.Fatalf("CheckHost for safesearh domain %s failed cause %s", domain, err) } @@ -284,7 +295,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) { d := NewForTest() defer d.Destroy() domain := "www.google.ru" - result, err := d.CheckHost(domain) + result, err := d.CheckHost(domain, dns.TypeA) if err != nil { t.Fatalf("Cannot check host due to %s", err) } @@ -313,7 +324,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) { } } - result, err = d.CheckHost(domain) + result, err = d.CheckHost(domain, dns.TypeA) if err != nil { t.Fatalf("CheckHost for safesearh domain %s failed cause %s", domain, err) } @@ -430,7 +441,7 @@ func TestMatching(t *testing.T) { d := NewForTestFilters(filters) defer d.Destroy() - ret, err := d.CheckHost(test.hostname) + ret, err := d.CheckHost(test.hostname, dns.TypeA) if err != nil { t.Errorf("Error while matching host %s: %s", test.hostname, err) } @@ -452,7 +463,7 @@ func BenchmarkSafeBrowsing(b *testing.B) { d.SafeBrowsingEnabled = true for n := 0; n < b.N; n++ { hostname := "wmconvirus.narod.ru" - ret, err := d.CheckHost(hostname) + ret, err := d.CheckHost(hostname, dns.TypeA) if err != nil { b.Errorf("Error while matching host %s: %s", hostname, err) } @@ -469,7 +480,7 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { hostname := "wmconvirus.narod.ru" - ret, err := d.CheckHost(hostname) + ret, err := d.CheckHost(hostname, dns.TypeA) if err != nil { b.Errorf("Error while matching host %s: %s", hostname, err) } diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index 51c6d3f6..b728c90f 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -378,7 +378,7 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error var res dnsfilter.Result var err error - res, err = dnsFilter.CheckHost(host) + res, err = dnsFilter.CheckHost(host, d.Req.Question[0].Qtype) if err != nil { // Return immediately if there's an error return nil, errorx.Decorate(err, "dnsfilter failed to check host '%s'", host) @@ -394,7 +394,7 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error func (s *Server) genDNSFilterMessage(d *proxy.DNSContext, result *dnsfilter.Result) *dns.Msg { m := d.Req - if m.Question[0].Qtype != dns.TypeA { + if m.Question[0].Qtype != dns.TypeA && m.Question[0].Qtype != dns.TypeAAAA { return s.genNXDomain(m) } @@ -405,11 +405,25 @@ func (s *Server) genDNSFilterMessage(d *proxy.DNSContext, result *dnsfilter.Resu return s.genBlockedHost(m, parentalBlockHost, d) default: if result.IP != nil { - return s.genARecord(m, result.IP) + if m.Question[0].Qtype == dns.TypeA { + return s.genARecord(m, result.IP) + } else if m.Question[0].Qtype == dns.TypeAAAA { + return s.genAAAARecord(m, result.IP) + } + + // empty response + resp := dns.Msg{} + resp.SetReply(m) + return &resp } if s.conf.BlockingMode == "null_ip" { - return s.genARecord(m, net.IPv4zero) + switch m.Question[0].Qtype { + case dns.TypeA: + return s.genARecord(m, []byte{0, 0, 0, 0}) + case dns.TypeAAAA: + return s.genAAAARecord(m, net.IPv6zero) + } } return s.genNXDomain(m) @@ -426,15 +440,41 @@ func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg { func (s *Server) genARecord(request *dns.Msg, ip net.IP) *dns.Msg { resp := dns.Msg{} resp.SetReply(request) - answer, err := dns.NewRR(fmt.Sprintf("%s %d A %s", request.Question[0].Name, s.conf.BlockedResponseTTL, ip.String())) - if err != nil { - log.Printf("Couldn't generate A record for replacement host '%s': %s", ip.String(), err) - return s.genServerFailure(request) - } - resp.Answer = append(resp.Answer, answer) + resp.Answer = append(resp.Answer, s.genAAnswer(request, ip)) return &resp } +func (s *Server) genAAAARecord(request *dns.Msg, ip net.IP) *dns.Msg { + resp := dns.Msg{} + resp.SetReply(request) + resp.Answer = append(resp.Answer, s.genAAAAAnswer(request, ip)) + return &resp +} + +func (s *Server) genAAnswer(req *dns.Msg, ip net.IP) *dns.A { + answer := new(dns.A) + answer.Hdr = dns.RR_Header{ + Name: req.Question[0].Name, + Rrtype: dns.TypeA, + Ttl: s.conf.BlockedResponseTTL, + Class: dns.ClassINET, + } + answer.A = ip + return answer +} + +func (s *Server) genAAAAAnswer(req *dns.Msg, ip net.IP) *dns.AAAA { + answer := new(dns.AAAA) + answer.Hdr = dns.RR_Header{ + Name: req.Question[0].Name, + Rrtype: dns.TypeAAAA, + Ttl: s.conf.BlockedResponseTTL, + Class: dns.ClassINET, + } + answer.AAAA = ip + return answer +} + func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSContext) *dns.Msg { // look up the hostname, TODO: cache replReq := dns.Msg{}