Merge: dnsforward: support IPv6
Close #735 * commit 'e2579c72bdcafed41d5be1250fb38aeda0a8184e': * dnsfilter: fix tests + dnsforward: support IPv6 (AAAA response) * dnsfilter: return the correct IP address (host rules)
This commit is contained in:
commit
b9df476c5d
|
@ -19,6 +19,7 @@ import (
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/urlfilter"
|
"github.com/AdguardTeam/urlfilter"
|
||||||
"github.com/bluele/gcache"
|
"github.com/bluele/gcache"
|
||||||
|
"github.com/miekg/dns"
|
||||||
"golang.org/x/net/publicsuffix"
|
"golang.org/x/net/publicsuffix"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -147,7 +148,7 @@ func (r Reason) Matched() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckHost tries to match host against rules, then safebrowsing and parental if they are enabled
|
// CheckHost tries to match host against rules, then safebrowsing and parental if they are enabled
|
||||||
func (d *Dnsfilter) CheckHost(host string) (Result, error) {
|
func (d *Dnsfilter) CheckHost(host string, qtype uint16) (Result, error) {
|
||||||
// sometimes DNS clients will try to resolve ".", which is a request to get root servers
|
// sometimes DNS clients will try to resolve ".", which is a request to get root servers
|
||||||
if host == "" {
|
if host == "" {
|
||||||
return Result{Reason: NotFilteredNotFound}, nil
|
return Result{Reason: NotFilteredNotFound}, nil
|
||||||
|
@ -159,7 +160,7 @@ func (d *Dnsfilter) CheckHost(host string) (Result, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// try filter lists first
|
// try filter lists first
|
||||||
result, err := d.matchHost(host)
|
result, err := d.matchHost(host, qtype)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
@ -517,7 +518,7 @@ func (d *Dnsfilter) initFiltering(filters map[int]string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// matchHost is a low-level way to check only if hostname is filtered by rules, skipping expensive safebrowsing and parental lookups
|
// matchHost is a low-level way to check only if hostname is filtered by rules, skipping expensive safebrowsing and parental lookups
|
||||||
func (d *Dnsfilter) matchHost(host string) (Result, error) {
|
func (d *Dnsfilter) matchHost(host string, qtype uint16) (Result, error) {
|
||||||
if d.filteringEngine == nil {
|
if d.filteringEngine == nil {
|
||||||
return Result{}, nil
|
return Result{}, nil
|
||||||
}
|
}
|
||||||
|
@ -527,6 +528,8 @@ func (d *Dnsfilter) matchHost(host string) (Result, error) {
|
||||||
return Result{}, nil
|
return Result{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Tracef("%d rules matched for host '%s'", len(rules), host)
|
||||||
|
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
|
|
||||||
log.Tracef("Found rule for host '%s': '%s' list_id: %d",
|
log.Tracef("Found rule for host '%s': '%s' list_id: %d",
|
||||||
|
@ -548,8 +551,24 @@ func (d *Dnsfilter) matchHost(host string) (Result, error) {
|
||||||
|
|
||||||
} else if hostRule, ok := rule.(*urlfilter.HostRule); ok {
|
} else if hostRule, ok := rule.(*urlfilter.HostRule); ok {
|
||||||
|
|
||||||
res.IP = hostRule.IP
|
if qtype == dns.TypeA && hostRule.IP.To4() != nil {
|
||||||
return res, nil
|
// either IPv4 or IPv4-mapped IPv6 address
|
||||||
|
res.IP = hostRule.IP.To4()
|
||||||
|
return res, nil
|
||||||
|
|
||||||
|
} else if qtype == dns.TypeAAAA {
|
||||||
|
ip4 := hostRule.IP.To4()
|
||||||
|
if ip4 == nil {
|
||||||
|
res.IP = hostRule.IP
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
if bytes.Equal(ip4, []byte{0, 0, 0, 0}) {
|
||||||
|
// send IP="::" response for a rule "0.0.0.0 blockdomain"
|
||||||
|
res.IP = net.IPv6zero
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
log.Tracef("Rule type is unsupported: '%s' list_id: %d",
|
log.Tracef("Rule type is unsupported: '%s' list_id: %d",
|
||||||
|
|
|
@ -9,6 +9,8 @@ import (
|
||||||
"runtime"
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
// HELPERS
|
// HELPERS
|
||||||
|
@ -50,7 +52,7 @@ func NewForTestFilters(filters map[int]string) *Dnsfilter {
|
||||||
|
|
||||||
func (d *Dnsfilter) checkMatch(t *testing.T, hostname string) {
|
func (d *Dnsfilter) checkMatch(t *testing.T, hostname string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
ret, err := d.CheckHost(hostname)
|
ret, err := d.CheckHost(hostname, dns.TypeA)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Error while matching host %s: %s", hostname, err)
|
t.Errorf("Error while matching host %s: %s", hostname, err)
|
||||||
}
|
}
|
||||||
|
@ -59,9 +61,9 @@ func (d *Dnsfilter) checkMatch(t *testing.T, hostname string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Dnsfilter) checkMatchIP(t *testing.T, hostname string, ip string) {
|
func (d *Dnsfilter) checkMatchIP(t *testing.T, hostname string, ip string, qtype uint16) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
ret, err := d.CheckHost(hostname)
|
ret, err := d.CheckHost(hostname, qtype)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Error while matching host %s: %s", hostname, err)
|
t.Errorf("Error while matching host %s: %s", hostname, err)
|
||||||
}
|
}
|
||||||
|
@ -75,7 +77,7 @@ func (d *Dnsfilter) checkMatchIP(t *testing.T, hostname string, ip string) {
|
||||||
|
|
||||||
func (d *Dnsfilter) checkMatchEmpty(t *testing.T, hostname string) {
|
func (d *Dnsfilter) checkMatchEmpty(t *testing.T, hostname string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
ret, err := d.CheckHost(hostname)
|
ret, err := d.CheckHost(hostname, dns.TypeA)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Error while matching host %s: %s", hostname, err)
|
t.Errorf("Error while matching host %s: %s", hostname, err)
|
||||||
}
|
}
|
||||||
|
@ -86,16 +88,25 @@ func (d *Dnsfilter) checkMatchEmpty(t *testing.T, hostname string) {
|
||||||
|
|
||||||
func TestEtcHostsMatching(t *testing.T) {
|
func TestEtcHostsMatching(t *testing.T) {
|
||||||
addr := "216.239.38.120"
|
addr := "216.239.38.120"
|
||||||
text := fmt.Sprintf(" %s google.com www.google.com # enforce google's safesearch ", addr)
|
addr6 := "::1"
|
||||||
|
text := fmt.Sprintf(" %s google.com www.google.com # enforce google's safesearch \n%s google.com\n0.0.0.0 block.com\n",
|
||||||
|
addr, addr6)
|
||||||
filters := make(map[int]string)
|
filters := make(map[int]string)
|
||||||
filters[0] = text
|
filters[0] = text
|
||||||
d := NewForTestFilters(filters)
|
d := NewForTestFilters(filters)
|
||||||
defer d.Destroy()
|
defer d.Destroy()
|
||||||
|
|
||||||
d.checkMatchIP(t, "google.com", addr)
|
d.checkMatchIP(t, "google.com", addr, dns.TypeA)
|
||||||
d.checkMatchIP(t, "www.google.com", addr)
|
d.checkMatchIP(t, "www.google.com", addr, dns.TypeA)
|
||||||
d.checkMatchEmpty(t, "subdomain.google.com")
|
d.checkMatchEmpty(t, "subdomain.google.com")
|
||||||
d.checkMatchEmpty(t, "example.org")
|
d.checkMatchEmpty(t, "example.org")
|
||||||
|
|
||||||
|
// IPv6 address
|
||||||
|
d.checkMatchIP(t, "google.com", addr6, dns.TypeAAAA)
|
||||||
|
|
||||||
|
// block both IPv4 and IPv6
|
||||||
|
d.checkMatchIP(t, "block.com", "0.0.0.0", dns.TypeA)
|
||||||
|
d.checkMatchIP(t, "block.com", "::", dns.TypeAAAA)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SAFE BROWSING
|
// SAFE BROWSING
|
||||||
|
@ -201,7 +212,7 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
|
||||||
|
|
||||||
// Check host for each domain
|
// Check host for each domain
|
||||||
for _, host := range yandex {
|
for _, host := range yandex {
|
||||||
result, err := d.CheckHost(host)
|
result, err := d.CheckHost(host, dns.TypeA)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("SafeSearch doesn't work for yandex domain `%s` cause %s", host, err)
|
t.Errorf("SafeSearch doesn't work for yandex domain `%s` cause %s", host, err)
|
||||||
}
|
}
|
||||||
|
@ -224,7 +235,7 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
|
||||||
|
|
||||||
// Check host for each domain
|
// Check host for each domain
|
||||||
for _, host := range googleDomains {
|
for _, host := range googleDomains {
|
||||||
result, err := d.CheckHost(host)
|
result, err := d.CheckHost(host, dns.TypeA)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("SafeSearch doesn't work for %s cause %s", host, err)
|
t.Errorf("SafeSearch doesn't work for %s cause %s", host, err)
|
||||||
}
|
}
|
||||||
|
@ -244,7 +255,7 @@ func TestSafeSearchCacheYandex(t *testing.T) {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// Check host with disabled safesearch
|
// Check host with disabled safesearch
|
||||||
result, err = d.CheckHost(domain)
|
result, err = d.CheckHost(domain, dns.TypeA)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Cannot check host due to %s", err)
|
t.Fatalf("Cannot check host due to %s", err)
|
||||||
}
|
}
|
||||||
|
@ -254,7 +265,7 @@ func TestSafeSearchCacheYandex(t *testing.T) {
|
||||||
|
|
||||||
// Enable safesearch
|
// Enable safesearch
|
||||||
d.SafeSearchEnabled = true
|
d.SafeSearchEnabled = true
|
||||||
result, err = d.CheckHost(domain)
|
result, err = d.CheckHost(domain, dns.TypeA)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("CheckHost for safesearh domain %s failed cause %s", domain, err)
|
t.Fatalf("CheckHost for safesearh domain %s failed cause %s", domain, err)
|
||||||
}
|
}
|
||||||
|
@ -284,7 +295,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||||
d := NewForTest()
|
d := NewForTest()
|
||||||
defer d.Destroy()
|
defer d.Destroy()
|
||||||
domain := "www.google.ru"
|
domain := "www.google.ru"
|
||||||
result, err := d.CheckHost(domain)
|
result, err := d.CheckHost(domain, dns.TypeA)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Cannot check host due to %s", err)
|
t.Fatalf("Cannot check host due to %s", err)
|
||||||
}
|
}
|
||||||
|
@ -313,7 +324,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err = d.CheckHost(domain)
|
result, err = d.CheckHost(domain, dns.TypeA)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("CheckHost for safesearh domain %s failed cause %s", domain, err)
|
t.Fatalf("CheckHost for safesearh domain %s failed cause %s", domain, err)
|
||||||
}
|
}
|
||||||
|
@ -430,7 +441,7 @@ func TestMatching(t *testing.T) {
|
||||||
d := NewForTestFilters(filters)
|
d := NewForTestFilters(filters)
|
||||||
defer d.Destroy()
|
defer d.Destroy()
|
||||||
|
|
||||||
ret, err := d.CheckHost(test.hostname)
|
ret, err := d.CheckHost(test.hostname, dns.TypeA)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Error while matching host %s: %s", test.hostname, err)
|
t.Errorf("Error while matching host %s: %s", test.hostname, err)
|
||||||
}
|
}
|
||||||
|
@ -452,7 +463,7 @@ func BenchmarkSafeBrowsing(b *testing.B) {
|
||||||
d.SafeBrowsingEnabled = true
|
d.SafeBrowsingEnabled = true
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
hostname := "wmconvirus.narod.ru"
|
hostname := "wmconvirus.narod.ru"
|
||||||
ret, err := d.CheckHost(hostname)
|
ret, err := d.CheckHost(hostname, dns.TypeA)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Errorf("Error while matching host %s: %s", hostname, err)
|
b.Errorf("Error while matching host %s: %s", hostname, err)
|
||||||
}
|
}
|
||||||
|
@ -469,7 +480,7 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) {
|
||||||
b.RunParallel(func(pb *testing.PB) {
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
for pb.Next() {
|
for pb.Next() {
|
||||||
hostname := "wmconvirus.narod.ru"
|
hostname := "wmconvirus.narod.ru"
|
||||||
ret, err := d.CheckHost(hostname)
|
ret, err := d.CheckHost(hostname, dns.TypeA)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Errorf("Error while matching host %s: %s", hostname, err)
|
b.Errorf("Error while matching host %s: %s", hostname, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -378,7 +378,7 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error
|
||||||
var res dnsfilter.Result
|
var res dnsfilter.Result
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
res, err = dnsFilter.CheckHost(host)
|
res, err = dnsFilter.CheckHost(host, d.Req.Question[0].Qtype)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Return immediately if there's an error
|
// Return immediately if there's an error
|
||||||
return nil, errorx.Decorate(err, "dnsfilter failed to check host '%s'", host)
|
return nil, errorx.Decorate(err, "dnsfilter failed to check host '%s'", host)
|
||||||
|
@ -394,7 +394,7 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error
|
||||||
func (s *Server) genDNSFilterMessage(d *proxy.DNSContext, result *dnsfilter.Result) *dns.Msg {
|
func (s *Server) genDNSFilterMessage(d *proxy.DNSContext, result *dnsfilter.Result) *dns.Msg {
|
||||||
m := d.Req
|
m := d.Req
|
||||||
|
|
||||||
if m.Question[0].Qtype != dns.TypeA {
|
if m.Question[0].Qtype != dns.TypeA && m.Question[0].Qtype != dns.TypeAAAA {
|
||||||
return s.genNXDomain(m)
|
return s.genNXDomain(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -405,11 +405,25 @@ func (s *Server) genDNSFilterMessage(d *proxy.DNSContext, result *dnsfilter.Resu
|
||||||
return s.genBlockedHost(m, parentalBlockHost, d)
|
return s.genBlockedHost(m, parentalBlockHost, d)
|
||||||
default:
|
default:
|
||||||
if result.IP != nil {
|
if result.IP != nil {
|
||||||
return s.genARecord(m, result.IP)
|
if m.Question[0].Qtype == dns.TypeA {
|
||||||
|
return s.genARecord(m, result.IP)
|
||||||
|
} else if m.Question[0].Qtype == dns.TypeAAAA {
|
||||||
|
return s.genAAAARecord(m, result.IP)
|
||||||
|
}
|
||||||
|
|
||||||
|
// empty response
|
||||||
|
resp := dns.Msg{}
|
||||||
|
resp.SetReply(m)
|
||||||
|
return &resp
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.conf.BlockingMode == "null_ip" {
|
if s.conf.BlockingMode == "null_ip" {
|
||||||
return s.genARecord(m, net.IPv4zero)
|
switch m.Question[0].Qtype {
|
||||||
|
case dns.TypeA:
|
||||||
|
return s.genARecord(m, []byte{0, 0, 0, 0})
|
||||||
|
case dns.TypeAAAA:
|
||||||
|
return s.genAAAARecord(m, net.IPv6zero)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.genNXDomain(m)
|
return s.genNXDomain(m)
|
||||||
|
@ -426,15 +440,41 @@ func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg {
|
||||||
func (s *Server) genARecord(request *dns.Msg, ip net.IP) *dns.Msg {
|
func (s *Server) genARecord(request *dns.Msg, ip net.IP) *dns.Msg {
|
||||||
resp := dns.Msg{}
|
resp := dns.Msg{}
|
||||||
resp.SetReply(request)
|
resp.SetReply(request)
|
||||||
answer, err := dns.NewRR(fmt.Sprintf("%s %d A %s", request.Question[0].Name, s.conf.BlockedResponseTTL, ip.String()))
|
resp.Answer = append(resp.Answer, s.genAAnswer(request, ip))
|
||||||
if err != nil {
|
|
||||||
log.Printf("Couldn't generate A record for replacement host '%s': %s", ip.String(), err)
|
|
||||||
return s.genServerFailure(request)
|
|
||||||
}
|
|
||||||
resp.Answer = append(resp.Answer, answer)
|
|
||||||
return &resp
|
return &resp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) genAAAARecord(request *dns.Msg, ip net.IP) *dns.Msg {
|
||||||
|
resp := dns.Msg{}
|
||||||
|
resp.SetReply(request)
|
||||||
|
resp.Answer = append(resp.Answer, s.genAAAAAnswer(request, ip))
|
||||||
|
return &resp
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) genAAnswer(req *dns.Msg, ip net.IP) *dns.A {
|
||||||
|
answer := new(dns.A)
|
||||||
|
answer.Hdr = dns.RR_Header{
|
||||||
|
Name: req.Question[0].Name,
|
||||||
|
Rrtype: dns.TypeA,
|
||||||
|
Ttl: s.conf.BlockedResponseTTL,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
}
|
||||||
|
answer.A = ip
|
||||||
|
return answer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) genAAAAAnswer(req *dns.Msg, ip net.IP) *dns.AAAA {
|
||||||
|
answer := new(dns.AAAA)
|
||||||
|
answer.Hdr = dns.RR_Header{
|
||||||
|
Name: req.Question[0].Name,
|
||||||
|
Rrtype: dns.TypeAAAA,
|
||||||
|
Ttl: s.conf.BlockedResponseTTL,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
}
|
||||||
|
answer.AAAA = ip
|
||||||
|
return answer
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSContext) *dns.Msg {
|
func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSContext) *dns.Msg {
|
||||||
// look up the hostname, TODO: cache
|
// look up the hostname, TODO: cache
|
||||||
replReq := dns.Msg{}
|
replReq := dns.Msg{}
|
||||||
|
|
Loading…
Reference in New Issue