Pull request: 4238 response filtering
Merge in DNS/adguard-home from 4238-response-filtering to master Closes #4238. Squashed commit of the following: commit 2113f83c617a396a39f910bb8df939364fedf391 Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Thu Feb 3 21:04:25 2022 +0300 dnsforward: restore a bit commit f78607ed97892557c6bd6f3c3332f0bae01c1987 Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Thu Feb 3 20:52:45 2022 +0300 all: imp code, docs commit 646074ce141e8ac12a972f46d071389a2ce124e4 Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Thu Feb 3 20:37:05 2022 +0300 all: log changes commit 94556d810549370fc455bcf14537fa1d2783eed1 Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Thu Feb 3 20:30:57 2022 +0300 all: imp test upstream, cover resp filtering commit 63e7721822a049734a390c7d7ea6d8416a43c8b5 Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Tue Feb 1 21:58:08 2022 +0300 all: filter response by rrtype
This commit is contained in:
parent
0ee34534c6
commit
e783564084
|
@ -21,6 +21,8 @@ and this project adheres to
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
|
|
||||||
|
- Response filtering is now performed using the record types of the answer
|
||||||
|
section of messages as opposed to the type of the question ([#4238]).
|
||||||
- Instead of adding the build time information, the build scripts now use the
|
- Instead of adding the build time information, the build scripts now use the
|
||||||
standardized environment variable [`SOURCE_DATE_EPOCH`][repr] to add the date
|
standardized environment variable [`SOURCE_DATE_EPOCH`][repr] to add the date
|
||||||
of the commit from which the binary was built ([#4221]). This should simplify
|
of the commit from which the binary was built ([#4221]). This should simplify
|
||||||
|
@ -80,6 +82,7 @@ In this release, the schema version has changed from 12 to 13.
|
||||||
[#3057]: https://github.com/AdguardTeam/AdGuardHome/issues/3057
|
[#3057]: https://github.com/AdguardTeam/AdGuardHome/issues/3057
|
||||||
[#3367]: https://github.com/AdguardTeam/AdGuardHome/issues/3367
|
[#3367]: https://github.com/AdguardTeam/AdGuardHome/issues/3367
|
||||||
[#4221]: https://github.com/AdguardTeam/AdGuardHome/issues/4221
|
[#4221]: https://github.com/AdguardTeam/AdGuardHome/issues/4221
|
||||||
|
[#4238]: https://github.com/AdguardTeam/AdGuardHome/issues/4238
|
||||||
|
|
||||||
[repr]: https://reproducible-builds.org/docs/source-date-epoch/
|
[repr]: https://reproducible-builds.org/docs/source-date-epoch/
|
||||||
|
|
||||||
|
|
|
@ -20,17 +20,19 @@ func DiscardLogOutput(m *testing.M) {
|
||||||
|
|
||||||
// ReplaceLogWriter moves logger output to w and uses Cleanup method of t to
|
// ReplaceLogWriter moves logger output to w and uses Cleanup method of t to
|
||||||
// revert changes.
|
// revert changes.
|
||||||
func ReplaceLogWriter(t *testing.T, w io.Writer) {
|
func ReplaceLogWriter(t testing.TB, w io.Writer) {
|
||||||
stdWriter := log.Writer()
|
t.Helper()
|
||||||
t.Cleanup(func() {
|
|
||||||
log.SetOutput(stdWriter)
|
prev := log.Writer()
|
||||||
})
|
t.Cleanup(func() { log.SetOutput(prev) })
|
||||||
log.SetOutput(w)
|
log.SetOutput(w)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReplaceLogLevel sets logging level to l and uses Cleanup method of t to
|
// ReplaceLogLevel sets logging level to l and uses Cleanup method of t to
|
||||||
// revert changes.
|
// revert changes.
|
||||||
func ReplaceLogLevel(t *testing.T, l log.Level) {
|
func ReplaceLogLevel(t testing.TB, l log.Level) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
switch l {
|
switch l {
|
||||||
case log.INFO, log.DEBUG, log.ERROR:
|
case log.INFO, log.DEBUG, log.ERROR:
|
||||||
// Go on.
|
// Go on.
|
||||||
|
@ -38,9 +40,7 @@ func ReplaceLogLevel(t *testing.T, l log.Level) {
|
||||||
t.Fatalf("wrong l value (must be one of %v, %v, %v)", log.INFO, log.DEBUG, log.ERROR)
|
t.Fatalf("wrong l value (must be one of %v, %v, %v)", log.INFO, log.DEBUG, log.ERROR)
|
||||||
}
|
}
|
||||||
|
|
||||||
stdLevel := log.GetLevel()
|
prev := log.GetLevel()
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() { log.SetLevel(prev) })
|
||||||
log.SetLevel(stdLevel)
|
|
||||||
})
|
|
||||||
log.SetLevel(l)
|
log.SetLevel(l)
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,10 +11,10 @@ import (
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestUpstream is a mock of real upstream.
|
// Upstream is a mock implementation of upstream.Upstream.
|
||||||
type TestUpstream struct {
|
type Upstream struct {
|
||||||
// CName is a map of hostname to canonical name.
|
// CName is a map of hostname to canonical name.
|
||||||
CName map[string]string
|
CName map[string][]string
|
||||||
// IPv4 is a map of hostname to IPv4.
|
// IPv4 is a map of hostname to IPv4.
|
||||||
IPv4 map[string][]net.IP
|
IPv4 map[string][]net.IP
|
||||||
// IPv6 is a map of hostname to IPv6.
|
// IPv6 is a map of hostname to IPv6.
|
||||||
|
@ -25,78 +25,45 @@ type TestUpstream struct {
|
||||||
Addr string
|
Addr string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exchange implements upstream.Upstream interface for *TestUpstream.
|
// Exchange implements the upstream.Upstream interface for *Upstream.
|
||||||
//
|
//
|
||||||
// TODO(a.garipov): Split further into handlers.
|
// TODO(a.garipov): Split further into handlers.
|
||||||
func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
|
func (u *Upstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
|
||||||
resp = &dns.Msg{}
|
resp = new(dns.Msg).SetReply(m)
|
||||||
resp.SetReply(m)
|
|
||||||
|
|
||||||
if len(m.Question) == 0 {
|
if len(m.Question) == 0 {
|
||||||
return nil, fmt.Errorf("question should not be empty")
|
return nil, fmt.Errorf("question should not be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
name := m.Question[0].Name
|
q := m.Question[0]
|
||||||
|
name := q.Name
|
||||||
if cname, ok := u.CName[name]; ok {
|
for _, cname := range u.CName[name] {
|
||||||
ans := &dns.CNAME{
|
resp.Answer = append(resp.Answer, &dns.CNAME{
|
||||||
Hdr: dns.RR_Header{
|
Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeCNAME},
|
||||||
Name: name,
|
|
||||||
Rrtype: dns.TypeCNAME,
|
|
||||||
},
|
|
||||||
Target: cname,
|
Target: cname,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.Answer = append(resp.Answer, ans)
|
qtype := q.Qtype
|
||||||
}
|
|
||||||
|
|
||||||
rrType := m.Question[0].Qtype
|
|
||||||
hdr := dns.RR_Header{
|
hdr := dns.RR_Header{
|
||||||
Name: name,
|
Name: name,
|
||||||
Rrtype: rrType,
|
Rrtype: qtype,
|
||||||
}
|
}
|
||||||
|
|
||||||
var names []string
|
switch qtype {
|
||||||
var ips []net.IP
|
|
||||||
switch m.Question[0].Qtype {
|
|
||||||
case dns.TypeA:
|
case dns.TypeA:
|
||||||
ips = u.IPv4[name]
|
for _, ip := range u.IPv4[name] {
|
||||||
|
resp.Answer = append(resp.Answer, &dns.A{Hdr: hdr, A: ip})
|
||||||
|
}
|
||||||
case dns.TypeAAAA:
|
case dns.TypeAAAA:
|
||||||
ips = u.IPv6[name]
|
for _, ip := range u.IPv6[name] {
|
||||||
|
resp.Answer = append(resp.Answer, &dns.AAAA{Hdr: hdr, AAAA: ip})
|
||||||
|
}
|
||||||
case dns.TypePTR:
|
case dns.TypePTR:
|
||||||
names = u.Reverse[name]
|
for _, name := range u.Reverse[name] {
|
||||||
|
resp.Answer = append(resp.Answer, &dns.PTR{Hdr: hdr, Ptr: name})
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, ip := range ips {
|
|
||||||
var ans dns.RR
|
|
||||||
if rrType == dns.TypeA {
|
|
||||||
ans = &dns.A{
|
|
||||||
Hdr: hdr,
|
|
||||||
A: ip,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.Answer = append(resp.Answer, ans)
|
|
||||||
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
ans = &dns.AAAA{
|
|
||||||
Hdr: hdr,
|
|
||||||
AAAA: ip,
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.Answer = append(resp.Answer, ans)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, n := range names {
|
|
||||||
ans := &dns.PTR{
|
|
||||||
Hdr: hdr,
|
|
||||||
Ptr: n,
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.Answer = append(resp.Answer, ans)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(resp.Answer) == 0 {
|
if len(resp.Answer) == 0 {
|
||||||
resp.SetRcode(m, dns.RcodeNameError)
|
resp.SetRcode(m, dns.RcodeNameError)
|
||||||
}
|
}
|
||||||
|
@ -104,8 +71,8 @@ func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Address implements upstream.Upstream interface for *TestUpstream.
|
// Address implements upstream.Upstream interface for *Upstream.
|
||||||
func (u *TestUpstream) Address() string {
|
func (u *Upstream) Address() string {
|
||||||
return u.Addr
|
return u.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -613,9 +613,9 @@ func (s *Server) processFilteringAfterResponse(ctx *dnsContext) (rc resultCode)
|
||||||
d.Res.Answer = answer
|
d.Res.Answer = answer
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
// Check the response only if the it's from an upstream. Don't check
|
// Check the response only if it's from an upstream. Don't check the
|
||||||
// the response if the protection is disabled since dnsrewrite rules
|
// response if the protection is disabled since dnsrewrite rules aren't
|
||||||
// aren't applied to it anyway.
|
// applied to it anyway.
|
||||||
if !ctx.protectionEnabled || !ctx.responseFromUpstream || s.dnsFilter == nil {
|
if !ctx.protectionEnabled || !ctx.responseFromUpstream || s.dnsFilter == nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
|
@ -261,7 +261,7 @@ func TestServer_ProcessInternalHosts(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_ProcessRestrictLocal(t *testing.T) {
|
func TestServer_ProcessRestrictLocal(t *testing.T) {
|
||||||
ups := &aghtest.TestUpstream{
|
ups := &aghtest.Upstream{
|
||||||
Reverse: map[string][]string{
|
Reverse: map[string][]string{
|
||||||
"251.252.253.254.in-addr.arpa.": {"host1.example.net."},
|
"251.252.253.254.in-addr.arpa.": {"host1.example.net."},
|
||||||
"1.1.168.192.in-addr.arpa.": {"some.local-client."},
|
"1.1.168.192.in-addr.arpa.": {"some.local-client."},
|
||||||
|
@ -339,7 +339,7 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
|
||||||
s := createTestServer(t, &filtering.Config{}, ServerConfig{
|
s := createTestServer(t, &filtering.Config{}, ServerConfig{
|
||||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||||
}, &aghtest.TestUpstream{
|
}, &aghtest.Upstream{
|
||||||
Reverse: map[string][]string{
|
Reverse: map[string][]string{
|
||||||
reqAddr: {locDomain},
|
reqAddr: {locDomain},
|
||||||
},
|
},
|
||||||
|
|
|
@ -89,7 +89,7 @@ func createTestServer(
|
||||||
defer s.serverLock.Unlock()
|
defer s.serverLock.Unlock()
|
||||||
|
|
||||||
if localUps != nil {
|
if localUps != nil {
|
||||||
s.localResolvers.Config.UpstreamConfig.Upstreams = []upstream.Upstream{localUps}
|
s.localResolvers.UpstreamConfig.Upstreams = []upstream.Upstream{localUps}
|
||||||
s.conf.UsePrivateRDNS = true
|
s.conf.UsePrivateRDNS = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -247,7 +247,7 @@ func TestServer(t *testing.T) {
|
||||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||||
}, nil)
|
}, nil)
|
||||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||||
&aghtest.TestUpstream{
|
&aghtest.Upstream{
|
||||||
IPv4: map[string][]net.IP{
|
IPv4: map[string][]net.IP{
|
||||||
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
|
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
|
||||||
},
|
},
|
||||||
|
@ -316,7 +316,7 @@ func TestServerWithProtectionDisabled(t *testing.T) {
|
||||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||||
}, nil)
|
}, nil)
|
||||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||||
&aghtest.TestUpstream{
|
&aghtest.Upstream{
|
||||||
IPv4: map[string][]net.IP{
|
IPv4: map[string][]net.IP{
|
||||||
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
|
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
|
||||||
},
|
},
|
||||||
|
@ -339,7 +339,7 @@ func TestDoTServer(t *testing.T) {
|
||||||
TLSListenAddrs: []*net.TCPAddr{{}},
|
TLSListenAddrs: []*net.TCPAddr{{}},
|
||||||
})
|
})
|
||||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||||
&aghtest.TestUpstream{
|
&aghtest.Upstream{
|
||||||
IPv4: map[string][]net.IP{
|
IPv4: map[string][]net.IP{
|
||||||
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
|
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
|
||||||
},
|
},
|
||||||
|
@ -369,7 +369,7 @@ func TestDoQServer(t *testing.T) {
|
||||||
QUICListenAddrs: []*net.UDPAddr{{IP: net.IP{127, 0, 0, 1}}},
|
QUICListenAddrs: []*net.UDPAddr{{IP: net.IP{127, 0, 0, 1}}},
|
||||||
})
|
})
|
||||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||||
&aghtest.TestUpstream{
|
&aghtest.Upstream{
|
||||||
IPv4: map[string][]net.IP{
|
IPv4: map[string][]net.IP{
|
||||||
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
|
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
|
||||||
},
|
},
|
||||||
|
@ -413,7 +413,7 @@ func TestServerRace(t *testing.T) {
|
||||||
}
|
}
|
||||||
s := createTestServer(t, filterConf, forwardConf, nil)
|
s := createTestServer(t, filterConf, forwardConf, nil)
|
||||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||||
&aghtest.TestUpstream{
|
&aghtest.Upstream{
|
||||||
IPv4: map[string][]net.IP{
|
IPv4: map[string][]net.IP{
|
||||||
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
|
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
|
||||||
},
|
},
|
||||||
|
@ -552,7 +552,7 @@ func TestServerCustomClientUpstream(t *testing.T) {
|
||||||
}
|
}
|
||||||
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
|
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
|
||||||
s.conf.GetCustomUpstreamByClient = func(_ string) (conf *proxy.UpstreamConfig, err error) {
|
s.conf.GetCustomUpstreamByClient = func(_ string) (conf *proxy.UpstreamConfig, err error) {
|
||||||
ups := &aghtest.TestUpstream{
|
ups := &aghtest.Upstream{
|
||||||
IPv4: map[string][]net.IP{
|
IPv4: map[string][]net.IP{
|
||||||
"host.": {{192, 168, 0, 1}},
|
"host.": {{192, 168, 0, 1}},
|
||||||
},
|
},
|
||||||
|
@ -580,9 +580,9 @@ func TestServerCustomClientUpstream(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// testCNAMEs is a map of names and CNAMEs necessary for the TestUpstream work.
|
// testCNAMEs is a map of names and CNAMEs necessary for the TestUpstream work.
|
||||||
var testCNAMEs = map[string]string{
|
var testCNAMEs = map[string][]string{
|
||||||
"badhost.": "NULL.example.org.",
|
"badhost.": {"NULL.example.org."},
|
||||||
"whitelist.example.org.": "NULL.example.org.",
|
"whitelist.example.org.": {"NULL.example.org."},
|
||||||
}
|
}
|
||||||
|
|
||||||
// testIPv4 is a map of names and IPv4s necessary for the TestUpstream work.
|
// testIPv4 is a map of names and IPv4s necessary for the TestUpstream work.
|
||||||
|
@ -596,7 +596,7 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
|
||||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||||
}, nil)
|
}, nil)
|
||||||
testUpstm := &aghtest.TestUpstream{
|
testUpstm := &aghtest.Upstream{
|
||||||
CName: testCNAMEs,
|
CName: testCNAMEs,
|
||||||
IPv4: testIPv4,
|
IPv4: testIPv4,
|
||||||
IPv6: nil,
|
IPv6: nil,
|
||||||
|
@ -630,7 +630,7 @@ func TestBlockCNAME(t *testing.T) {
|
||||||
}
|
}
|
||||||
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
|
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
|
||||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||||
&aghtest.TestUpstream{
|
&aghtest.Upstream{
|
||||||
CName: testCNAMEs,
|
CName: testCNAMEs,
|
||||||
IPv4: testIPv4,
|
IPv4: testIPv4,
|
||||||
},
|
},
|
||||||
|
@ -640,14 +640,17 @@ func TestBlockCNAME(t *testing.T) {
|
||||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
|
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
|
name string
|
||||||
host string
|
host string
|
||||||
want bool
|
want bool
|
||||||
}{{
|
}{{
|
||||||
|
name: "block_request",
|
||||||
host: "badhost.",
|
host: "badhost.",
|
||||||
// 'badhost' has a canonical name 'NULL.example.org' which is
|
// 'badhost' has a canonical name 'NULL.example.org' which is
|
||||||
// blocked by filters: response is blocked.
|
// blocked by filters: response is blocked.
|
||||||
want: true,
|
want: true,
|
||||||
}, {
|
}, {
|
||||||
|
name: "allowed",
|
||||||
host: "whitelist.example.org.",
|
host: "whitelist.example.org.",
|
||||||
// 'whitelist.example.org' has a canonical name
|
// 'whitelist.example.org' has a canonical name
|
||||||
// 'NULL.example.org' which is blocked by filters
|
// 'NULL.example.org' which is blocked by filters
|
||||||
|
@ -655,6 +658,7 @@ func TestBlockCNAME(t *testing.T) {
|
||||||
// response isn't blocked.
|
// response isn't blocked.
|
||||||
want: false,
|
want: false,
|
||||||
}, {
|
}, {
|
||||||
|
name: "block_response",
|
||||||
host: "example.org.",
|
host: "example.org.",
|
||||||
// 'example.org' has a canonical name 'cname1' with IP
|
// 'example.org' has a canonical name 'cname1' with IP
|
||||||
// 127.0.0.255 which is blocked by filters: response is blocked.
|
// 127.0.0.255 which is blocked by filters: response is blocked.
|
||||||
|
@ -662,9 +666,9 @@ func TestBlockCNAME(t *testing.T) {
|
||||||
}}
|
}}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run("block_cname_"+tc.host, func(t *testing.T) {
|
|
||||||
req := createTestMessage(tc.host)
|
req := createTestMessage(tc.host)
|
||||||
|
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
reply, err := dns.Exchange(req, addr)
|
reply, err := dns.Exchange(req, addr)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
@ -674,7 +678,7 @@ func TestBlockCNAME(t *testing.T) {
|
||||||
|
|
||||||
ans := reply.Answer[0]
|
ans := reply.Answer[0]
|
||||||
a, ok := ans.(*dns.A)
|
a, ok := ans.(*dns.A)
|
||||||
require.Truef(t, ok, "got %T", ans)
|
require.True(t, ok)
|
||||||
|
|
||||||
assert.True(t, a.A.IsUnspecified())
|
assert.True(t, a.A.IsUnspecified())
|
||||||
}
|
}
|
||||||
|
@ -695,7 +699,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) {
|
||||||
}
|
}
|
||||||
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
|
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
|
||||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||||
&aghtest.TestUpstream{
|
&aghtest.Upstream{
|
||||||
CName: testCNAMEs,
|
CName: testCNAMEs,
|
||||||
IPv4: testIPv4,
|
IPv4: testIPv4,
|
||||||
},
|
},
|
||||||
|
@ -931,9 +935,9 @@ func TestRewrite(t *testing.T) {
|
||||||
}))
|
}))
|
||||||
|
|
||||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||||
&aghtest.TestUpstream{
|
&aghtest.Upstream{
|
||||||
CName: map[string]string{
|
CName: map[string][]string{
|
||||||
"example.org": "somename",
|
"example.org": {"somename"},
|
||||||
},
|
},
|
||||||
IPv4: map[string][]net.IP{
|
IPv4: map[string][]net.IP{
|
||||||
"example.org.": {{4, 3, 2, 1}},
|
"example.org.": {{4, 3, 2, 1}},
|
||||||
|
@ -1193,12 +1197,12 @@ func TestNewServer(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_Exchange(t *testing.T) {
|
func TestServer_Exchange(t *testing.T) {
|
||||||
extUpstream := &aghtest.TestUpstream{
|
extUpstream := &aghtest.Upstream{
|
||||||
Reverse: map[string][]string{
|
Reverse: map[string][]string{
|
||||||
"1.1.1.1.in-addr.arpa.": {"one.one.one.one"},
|
"1.1.1.1.in-addr.arpa.": {"one.one.one.one"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
locUpstream := &aghtest.TestUpstream{
|
locUpstream := &aghtest.Upstream{
|
||||||
Reverse: map[string][]string{
|
Reverse: map[string][]string{
|
||||||
"1.1.168.192.in-addr.arpa.": {"local.domain"},
|
"1.1.168.192.in-addr.arpa.": {"local.domain"},
|
||||||
"2.1.168.192.in-addr.arpa.": {},
|
"2.1.168.192.in-addr.arpa.": {},
|
||||||
|
|
|
@ -116,7 +116,7 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) {
|
||||||
|
|
||||||
// checkHostRules checks the host against filters. It is safe for concurrent
|
// checkHostRules checks the host against filters. It is safe for concurrent
|
||||||
// use.
|
// use.
|
||||||
func (s *Server) checkHostRules(host string, qtype uint16, setts *filtering.Settings) (
|
func (s *Server) checkHostRules(host string, rrtype uint16, setts *filtering.Settings) (
|
||||||
r *filtering.Result,
|
r *filtering.Result,
|
||||||
err error,
|
err error,
|
||||||
) {
|
) {
|
||||||
|
@ -128,7 +128,7 @@ func (s *Server) checkHostRules(host string, qtype uint16, setts *filtering.Sett
|
||||||
}
|
}
|
||||||
|
|
||||||
var res filtering.Result
|
var res filtering.Result
|
||||||
res, err = s.dnsFilter.CheckHostRules(host, qtype, setts)
|
res, err = s.dnsFilter.CheckHostRules(host, rrtype, setts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -136,33 +136,36 @@ func (s *Server) checkHostRules(host string, qtype uint16, setts *filtering.Sett
|
||||||
return &res, err
|
return &res, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// If response contains CNAME, A or AAAA records, we apply filtering to each
|
// filterDNSResponse checks each resource record of the response's answer
|
||||||
// canonical host name or IP address. If this is a match, we set a new response
|
// section from ctx and returns a non-nil res if at least one of canonnical
|
||||||
// in d.Res and return.
|
// names or IP addresses in it matches the filtering rules.
|
||||||
func (s *Server) filterDNSResponse(ctx *dnsContext) (*filtering.Result, error) {
|
func (s *Server) filterDNSResponse(ctx *dnsContext) (res *filtering.Result, err error) {
|
||||||
d := ctx.proxyCtx
|
d := ctx.proxyCtx
|
||||||
|
setts := ctx.setts
|
||||||
|
if !setts.FilteringEnabled {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
for _, a := range d.Res.Answer {
|
for _, a := range d.Res.Answer {
|
||||||
host := ""
|
host := ""
|
||||||
|
var rrtype uint16
|
||||||
switch v := a.(type) {
|
switch a := a.(type) {
|
||||||
case *dns.CNAME:
|
case *dns.CNAME:
|
||||||
log.Debug("DNSFwd: Checking CNAME %s for %s", v.Target, v.Hdr.Name)
|
host = strings.TrimSuffix(a.Target, ".")
|
||||||
host = strings.TrimSuffix(v.Target, ".")
|
rrtype = dns.TypeCNAME
|
||||||
|
|
||||||
case *dns.A:
|
case *dns.A:
|
||||||
host = v.A.String()
|
host = a.A.String()
|
||||||
log.Debug("DNSFwd: Checking record A (%s) for %s", host, v.Hdr.Name)
|
rrtype = dns.TypeA
|
||||||
|
|
||||||
case *dns.AAAA:
|
case *dns.AAAA:
|
||||||
host = v.AAAA.String()
|
host = a.AAAA.String()
|
||||||
log.Debug("DNSFwd: Checking record AAAA (%s) for %s", host, v.Hdr.Name)
|
rrtype = dns.TypeAAAA
|
||||||
|
|
||||||
default:
|
default:
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
host = strings.TrimSuffix(host, ".")
|
log.Debug("dnsforward: checking %s %s for %s", dns.Type(rrtype), host, a.Header().Name)
|
||||||
res, err := s.checkHostRules(host, d.Req.Question[0].Qtype, ctx.setts)
|
|
||||||
|
res, err = s.checkHostRules(host, rrtype, setts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else if res == nil {
|
} else if res == nil {
|
||||||
|
|
|
@ -0,0 +1,159 @@
|
||||||
|
package dnsforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
|
||||||
|
rules := `
|
||||||
|
||blocked.domain^
|
||||||
|
@@||allowed.domain^
|
||||||
|
||cname.specific^$dnstype=~CNAME
|
||||||
|
||0.0.0.1^$dnstype=~A
|
||||||
|
||::1^$dnstype=~AAAA
|
||||||
|
`
|
||||||
|
|
||||||
|
forwardConf := ServerConfig{
|
||||||
|
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||||
|
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||||
|
FilteringConfig: FilteringConfig{
|
||||||
|
ProtectionEnabled: true,
|
||||||
|
BlockingMode: BlockingModeDefault,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
filters := []filtering.Filter{{
|
||||||
|
ID: 0, Data: []byte(rules),
|
||||||
|
}}
|
||||||
|
|
||||||
|
f := filtering.New(&filtering.Config{}, filters)
|
||||||
|
f.SetEnabled(true)
|
||||||
|
|
||||||
|
snd, err := aghnet.NewSubnetDetector()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, snd)
|
||||||
|
|
||||||
|
s, err := NewServer(DNSCreateParams{
|
||||||
|
DHCPServer: &testDHCP{},
|
||||||
|
DNSFilter: f,
|
||||||
|
SubnetDetector: snd,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
s.conf = forwardConf
|
||||||
|
err = s.Prepare(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||||
|
&aghtest.Upstream{
|
||||||
|
CName: map[string][]string{
|
||||||
|
"cname.exception.": {"cname.specific."},
|
||||||
|
"should.block.": {"blocked.domain."},
|
||||||
|
"allowed.first.": {"allowed.domain.", "blocked.domain."},
|
||||||
|
"blocked.first.": {"blocked.domain.", "allowed.domain."},
|
||||||
|
},
|
||||||
|
IPv4: map[string][]net.IP{
|
||||||
|
"a.exception.": {{0, 0, 0, 1}},
|
||||||
|
},
|
||||||
|
IPv6: map[string][]net.IP{
|
||||||
|
"aaaa.exception.": {net.ParseIP("::1")},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
startDeferStop(t, s)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
req *dns.Msg
|
||||||
|
name string
|
||||||
|
wantAns []dns.RR
|
||||||
|
}{{
|
||||||
|
req: createTestMessage("cname.exception."),
|
||||||
|
name: "cname_exception",
|
||||||
|
wantAns: []dns.RR{&dns.CNAME{
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: "cname.exception.",
|
||||||
|
Rrtype: dns.TypeCNAME,
|
||||||
|
},
|
||||||
|
Target: "cname.specific.",
|
||||||
|
}},
|
||||||
|
}, {
|
||||||
|
req: createTestMessage("should.block."),
|
||||||
|
name: "blocked_by_cname",
|
||||||
|
wantAns: []dns.RR{&dns.A{
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: "should.block.",
|
||||||
|
Rrtype: dns.TypeA,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
},
|
||||||
|
A: netutil.IPv4Zero(),
|
||||||
|
}},
|
||||||
|
}, {
|
||||||
|
req: createTestMessage("a.exception."),
|
||||||
|
name: "a_exception",
|
||||||
|
wantAns: []dns.RR{&dns.A{
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: "a.exception.",
|
||||||
|
Rrtype: dns.TypeA,
|
||||||
|
},
|
||||||
|
A: net.IP{0, 0, 0, 1},
|
||||||
|
}},
|
||||||
|
}, {
|
||||||
|
req: createTestMessageWithType("aaaa.exception.", dns.TypeAAAA),
|
||||||
|
name: "aaaa_exception",
|
||||||
|
wantAns: []dns.RR{&dns.AAAA{
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: "aaaa.exception.",
|
||||||
|
Rrtype: dns.TypeAAAA,
|
||||||
|
},
|
||||||
|
AAAA: net.ParseIP("::1"),
|
||||||
|
}},
|
||||||
|
}, {
|
||||||
|
req: createTestMessage("allowed.first."),
|
||||||
|
name: "allowed_first",
|
||||||
|
wantAns: []dns.RR{&dns.A{
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: "allowed.first.",
|
||||||
|
Rrtype: dns.TypeA,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
},
|
||||||
|
A: netutil.IPv4Zero(),
|
||||||
|
}},
|
||||||
|
}, {
|
||||||
|
req: createTestMessage("blocked.first."),
|
||||||
|
name: "blocked_first",
|
||||||
|
wantAns: []dns.RR{&dns.A{
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: "blocked.first.",
|
||||||
|
Rrtype: dns.TypeA,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
},
|
||||||
|
A: netutil.IPv4Zero(),
|
||||||
|
}},
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
dctx := &proxy.DNSContext{
|
||||||
|
Proto: proxy.ProtoUDP,
|
||||||
|
Req: tc.req,
|
||||||
|
Addr: &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
err = s.handleDNSRequest(nil, dctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, dctx.Res)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.wantAns, dctx.Res.Answer)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -420,14 +420,8 @@ func (r Reason) Matched() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckHostRules tries to match the host against filtering rules only.
|
// CheckHostRules tries to match the host against filtering rules only.
|
||||||
func (d *DNSFilter) CheckHostRules(host string, qtype uint16, setts *Settings) (Result, error) {
|
func (d *DNSFilter) CheckHostRules(host string, rrtype uint16, setts *Settings) (Result, error) {
|
||||||
if !setts.FilteringEnabled {
|
return d.matchHost(strings.ToLower(host), rrtype, setts)
|
||||||
return Result{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
host = strings.ToLower(host)
|
|
||||||
|
|
||||||
return d.matchHost(host, qtype, setts)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckHost tries to match the host against filtering rules, then safebrowsing
|
// CheckHost tries to match the host against filtering rules, then safebrowsing
|
||||||
|
@ -726,8 +720,7 @@ func hostRulesToRules(netRules []*rules.HostRule) (res []rules.Rule) {
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
// matchHostProcessAllowList processes the allowlist logic of host
|
// matchHostProcessAllowList processes the allowlist logic of host matching.
|
||||||
// matching.
|
|
||||||
func (d *DNSFilter) matchHostProcessAllowList(
|
func (d *DNSFilter) matchHostProcessAllowList(
|
||||||
host string,
|
host string,
|
||||||
dnsres *urlfilter.DNSResult,
|
dnsres *urlfilter.DNSResult,
|
||||||
|
@ -798,11 +791,11 @@ func (d *DNSFilter) matchHostProcessDNSResult(
|
||||||
return Result{}
|
return Result{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// matchHost is a low-level way to check only if hostname is filtered by rules,
|
// matchHost is a low-level way to check only if host is filtered by rules,
|
||||||
// skipping expensive safebrowsing and parental lookups.
|
// skipping expensive safebrowsing and parental lookups.
|
||||||
func (d *DNSFilter) matchHost(
|
func (d *DNSFilter) matchHost(
|
||||||
host string,
|
host string,
|
||||||
qtype uint16,
|
rrtype uint16,
|
||||||
setts *Settings,
|
setts *Settings,
|
||||||
) (res Result, err error) {
|
) (res Result, err error) {
|
||||||
if !setts.FilteringEnabled {
|
if !setts.FilteringEnabled {
|
||||||
|
@ -815,7 +808,7 @@ func (d *DNSFilter) matchHost(
|
||||||
// TODO(e.burkov): Wait for urlfilter update to pass net.IP.
|
// TODO(e.burkov): Wait for urlfilter update to pass net.IP.
|
||||||
ClientIP: setts.ClientIP.String(),
|
ClientIP: setts.ClientIP.String(),
|
||||||
ClientName: setts.ClientName,
|
ClientName: setts.ClientName,
|
||||||
DNSType: qtype,
|
DNSType: rrtype,
|
||||||
}
|
}
|
||||||
|
|
||||||
d.engineLock.RLock()
|
d.engineLock.RLock()
|
||||||
|
@ -855,7 +848,7 @@ func (d *DNSFilter) matchHost(
|
||||||
return Result{}, nil
|
return Result{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
res = d.matchHostProcessDNSResult(qtype, dnsres)
|
res = d.matchHostProcessDNSResult(rrtype, dnsres)
|
||||||
for _, r := range res.Rules {
|
for _, r := range res.Rules {
|
||||||
log.Debug(
|
log.Debug(
|
||||||
"filtering: found rule %q for host %q, filter list id: %d",
|
"filtering: found rule %q for host %q, filter list id: %d",
|
||||||
|
|
|
@ -167,7 +167,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
|
||||||
w := &bytes.Buffer{}
|
w := &bytes.Buffer{}
|
||||||
aghtest.ReplaceLogWriter(t, w)
|
aghtest.ReplaceLogWriter(t, w)
|
||||||
|
|
||||||
locUpstream := &aghtest.TestUpstream{
|
locUpstream := &aghtest.Upstream{
|
||||||
Reverse: map[string][]string{
|
Reverse: map[string][]string{
|
||||||
"192.168.1.1": {"local.domain"},
|
"192.168.1.1": {"local.domain"},
|
||||||
"2a00:1450:400c:c06::93": {"ipv6.domain"},
|
"2a00:1450:400c:c06::93": {"ipv6.domain"},
|
||||||
|
|
Loading…
Reference in New Issue