Pull request: 4238 response filtering

Merge in DNS/adguard-home from 4238-response-filtering to master

Closes #4238.

Squashed commit of the following:

commit 2113f83c617a396a39f910bb8df939364fedf391
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Feb 3 21:04:25 2022 +0300

    dnsforward: restore a bit

commit f78607ed97892557c6bd6f3c3332f0bae01c1987
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Feb 3 20:52:45 2022 +0300

    all: imp code, docs

commit 646074ce141e8ac12a972f46d071389a2ce124e4
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Feb 3 20:37:05 2022 +0300

    all: log changes

commit 94556d810549370fc455bcf14537fa1d2783eed1
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Feb 3 20:30:57 2022 +0300

    all: imp test upstream, cover resp filtering

commit 63e7721822a049734a390c7d7ea6d8416a43c8b5
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Feb 1 21:58:08 2022 +0300

    all: filter response by rrtype
This commit is contained in:
Eugene Burkov 2022-02-03 21:19:32 +03:00
parent 0ee34534c6
commit e783564084
10 changed files with 257 additions and 128 deletions

View File

@ -21,6 +21,8 @@ and this project adheres to
### Changed ### Changed
- Response filtering is now performed using the record types of the answer
section of messages as opposed to the type of the question ([#4238]).
- Instead of adding the build time information, the build scripts now use the - Instead of adding the build time information, the build scripts now use the
standardized environment variable [`SOURCE_DATE_EPOCH`][repr] to add the date standardized environment variable [`SOURCE_DATE_EPOCH`][repr] to add the date
of the commit from which the binary was built ([#4221]). This should simplify of the commit from which the binary was built ([#4221]). This should simplify
@ -80,6 +82,7 @@ In this release, the schema version has changed from 12 to 13.
[#3057]: https://github.com/AdguardTeam/AdGuardHome/issues/3057 [#3057]: https://github.com/AdguardTeam/AdGuardHome/issues/3057
[#3367]: https://github.com/AdguardTeam/AdGuardHome/issues/3367 [#3367]: https://github.com/AdguardTeam/AdGuardHome/issues/3367
[#4221]: https://github.com/AdguardTeam/AdGuardHome/issues/4221 [#4221]: https://github.com/AdguardTeam/AdGuardHome/issues/4221
[#4238]: https://github.com/AdguardTeam/AdGuardHome/issues/4238
[repr]: https://reproducible-builds.org/docs/source-date-epoch/ [repr]: https://reproducible-builds.org/docs/source-date-epoch/

View File

@ -20,17 +20,19 @@ func DiscardLogOutput(m *testing.M) {
// ReplaceLogWriter moves logger output to w and uses Cleanup method of t to // ReplaceLogWriter moves logger output to w and uses Cleanup method of t to
// revert changes. // revert changes.
func ReplaceLogWriter(t *testing.T, w io.Writer) { func ReplaceLogWriter(t testing.TB, w io.Writer) {
stdWriter := log.Writer() t.Helper()
t.Cleanup(func() {
log.SetOutput(stdWriter) prev := log.Writer()
}) t.Cleanup(func() { log.SetOutput(prev) })
log.SetOutput(w) log.SetOutput(w)
} }
// ReplaceLogLevel sets logging level to l and uses Cleanup method of t to // ReplaceLogLevel sets logging level to l and uses Cleanup method of t to
// revert changes. // revert changes.
func ReplaceLogLevel(t *testing.T, l log.Level) { func ReplaceLogLevel(t testing.TB, l log.Level) {
t.Helper()
switch l { switch l {
case log.INFO, log.DEBUG, log.ERROR: case log.INFO, log.DEBUG, log.ERROR:
// Go on. // Go on.
@ -38,9 +40,7 @@ func ReplaceLogLevel(t *testing.T, l log.Level) {
t.Fatalf("wrong l value (must be one of %v, %v, %v)", log.INFO, log.DEBUG, log.ERROR) t.Fatalf("wrong l value (must be one of %v, %v, %v)", log.INFO, log.DEBUG, log.ERROR)
} }
stdLevel := log.GetLevel() prev := log.GetLevel()
t.Cleanup(func() { t.Cleanup(func() { log.SetLevel(prev) })
log.SetLevel(stdLevel)
})
log.SetLevel(l) log.SetLevel(l)
} }

View File

@ -11,10 +11,10 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
) )
// TestUpstream is a mock of real upstream. // Upstream is a mock implementation of upstream.Upstream.
type TestUpstream struct { type Upstream struct {
// CName is a map of hostname to canonical name. // CName is a map of hostname to canonical name.
CName map[string]string CName map[string][]string
// IPv4 is a map of hostname to IPv4. // IPv4 is a map of hostname to IPv4.
IPv4 map[string][]net.IP IPv4 map[string][]net.IP
// IPv6 is a map of hostname to IPv6. // IPv6 is a map of hostname to IPv6.
@ -25,78 +25,45 @@ type TestUpstream struct {
Addr string Addr string
} }
// Exchange implements upstream.Upstream interface for *TestUpstream. // Exchange implements the upstream.Upstream interface for *Upstream.
// //
// TODO(a.garipov): Split further into handlers. // TODO(a.garipov): Split further into handlers.
func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { func (u *Upstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
resp = &dns.Msg{} resp = new(dns.Msg).SetReply(m)
resp.SetReply(m)
if len(m.Question) == 0 { if len(m.Question) == 0 {
return nil, fmt.Errorf("question should not be empty") return nil, fmt.Errorf("question should not be empty")
} }
name := m.Question[0].Name q := m.Question[0]
name := q.Name
if cname, ok := u.CName[name]; ok { for _, cname := range u.CName[name] {
ans := &dns.CNAME{ resp.Answer = append(resp.Answer, &dns.CNAME{
Hdr: dns.RR_Header{ Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeCNAME},
Name: name,
Rrtype: dns.TypeCNAME,
},
Target: cname, Target: cname,
})
} }
resp.Answer = append(resp.Answer, ans) qtype := q.Qtype
}
rrType := m.Question[0].Qtype
hdr := dns.RR_Header{ hdr := dns.RR_Header{
Name: name, Name: name,
Rrtype: rrType, Rrtype: qtype,
} }
var names []string switch qtype {
var ips []net.IP
switch m.Question[0].Qtype {
case dns.TypeA: case dns.TypeA:
ips = u.IPv4[name] for _, ip := range u.IPv4[name] {
resp.Answer = append(resp.Answer, &dns.A{Hdr: hdr, A: ip})
}
case dns.TypeAAAA: case dns.TypeAAAA:
ips = u.IPv6[name] for _, ip := range u.IPv6[name] {
resp.Answer = append(resp.Answer, &dns.AAAA{Hdr: hdr, AAAA: ip})
}
case dns.TypePTR: case dns.TypePTR:
names = u.Reverse[name] for _, name := range u.Reverse[name] {
resp.Answer = append(resp.Answer, &dns.PTR{Hdr: hdr, Ptr: name})
} }
for _, ip := range ips {
var ans dns.RR
if rrType == dns.TypeA {
ans = &dns.A{
Hdr: hdr,
A: ip,
} }
resp.Answer = append(resp.Answer, ans)
continue
}
ans = &dns.AAAA{
Hdr: hdr,
AAAA: ip,
}
resp.Answer = append(resp.Answer, ans)
}
for _, n := range names {
ans := &dns.PTR{
Hdr: hdr,
Ptr: n,
}
resp.Answer = append(resp.Answer, ans)
}
if len(resp.Answer) == 0 { if len(resp.Answer) == 0 {
resp.SetRcode(m, dns.RcodeNameError) resp.SetRcode(m, dns.RcodeNameError)
} }
@ -104,8 +71,8 @@ func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
return resp, nil return resp, nil
} }
// Address implements upstream.Upstream interface for *TestUpstream. // Address implements upstream.Upstream interface for *Upstream.
func (u *TestUpstream) Address() string { func (u *Upstream) Address() string {
return u.Addr return u.Addr
} }

View File

@ -613,9 +613,9 @@ func (s *Server) processFilteringAfterResponse(ctx *dnsContext) (rc resultCode)
d.Res.Answer = answer d.Res.Answer = answer
} }
default: default:
// Check the response only if the it's from an upstream. Don't check // Check the response only if it's from an upstream. Don't check the
// the response if the protection is disabled since dnsrewrite rules // response if the protection is disabled since dnsrewrite rules aren't
// aren't applied to it anyway. // applied to it anyway.
if !ctx.protectionEnabled || !ctx.responseFromUpstream || s.dnsFilter == nil { if !ctx.protectionEnabled || !ctx.responseFromUpstream || s.dnsFilter == nil {
break break
} }

View File

@ -261,7 +261,7 @@ func TestServer_ProcessInternalHosts(t *testing.T) {
} }
func TestServer_ProcessRestrictLocal(t *testing.T) { func TestServer_ProcessRestrictLocal(t *testing.T) {
ups := &aghtest.TestUpstream{ ups := &aghtest.Upstream{
Reverse: map[string][]string{ Reverse: map[string][]string{
"251.252.253.254.in-addr.arpa.": {"host1.example.net."}, "251.252.253.254.in-addr.arpa.": {"host1.example.net."},
"1.1.168.192.in-addr.arpa.": {"some.local-client."}, "1.1.168.192.in-addr.arpa.": {"some.local-client."},
@ -339,7 +339,7 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
s := createTestServer(t, &filtering.Config{}, ServerConfig{ s := createTestServer(t, &filtering.Config{}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}}, UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}},
}, &aghtest.TestUpstream{ }, &aghtest.Upstream{
Reverse: map[string][]string{ Reverse: map[string][]string{
reqAddr: {locDomain}, reqAddr: {locDomain},
}, },

View File

@ -89,7 +89,7 @@ func createTestServer(
defer s.serverLock.Unlock() defer s.serverLock.Unlock()
if localUps != nil { if localUps != nil {
s.localResolvers.Config.UpstreamConfig.Upstreams = []upstream.Upstream{localUps} s.localResolvers.UpstreamConfig.Upstreams = []upstream.Upstream{localUps}
s.conf.UsePrivateRDNS = true s.conf.UsePrivateRDNS = true
} }
@ -247,7 +247,7 @@ func TestServer(t *testing.T) {
TCPListenAddrs: []*net.TCPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}},
}, nil) }, nil)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.TestUpstream{ &aghtest.Upstream{
IPv4: map[string][]net.IP{ IPv4: map[string][]net.IP{
"google-public-dns-a.google.com.": {{8, 8, 8, 8}}, "google-public-dns-a.google.com.": {{8, 8, 8, 8}},
}, },
@ -316,7 +316,7 @@ func TestServerWithProtectionDisabled(t *testing.T) {
TCPListenAddrs: []*net.TCPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}},
}, nil) }, nil)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.TestUpstream{ &aghtest.Upstream{
IPv4: map[string][]net.IP{ IPv4: map[string][]net.IP{
"google-public-dns-a.google.com.": {{8, 8, 8, 8}}, "google-public-dns-a.google.com.": {{8, 8, 8, 8}},
}, },
@ -339,7 +339,7 @@ func TestDoTServer(t *testing.T) {
TLSListenAddrs: []*net.TCPAddr{{}}, TLSListenAddrs: []*net.TCPAddr{{}},
}) })
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.TestUpstream{ &aghtest.Upstream{
IPv4: map[string][]net.IP{ IPv4: map[string][]net.IP{
"google-public-dns-a.google.com.": {{8, 8, 8, 8}}, "google-public-dns-a.google.com.": {{8, 8, 8, 8}},
}, },
@ -369,7 +369,7 @@ func TestDoQServer(t *testing.T) {
QUICListenAddrs: []*net.UDPAddr{{IP: net.IP{127, 0, 0, 1}}}, QUICListenAddrs: []*net.UDPAddr{{IP: net.IP{127, 0, 0, 1}}},
}) })
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.TestUpstream{ &aghtest.Upstream{
IPv4: map[string][]net.IP{ IPv4: map[string][]net.IP{
"google-public-dns-a.google.com.": {{8, 8, 8, 8}}, "google-public-dns-a.google.com.": {{8, 8, 8, 8}},
}, },
@ -413,7 +413,7 @@ func TestServerRace(t *testing.T) {
} }
s := createTestServer(t, filterConf, forwardConf, nil) s := createTestServer(t, filterConf, forwardConf, nil)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.TestUpstream{ &aghtest.Upstream{
IPv4: map[string][]net.IP{ IPv4: map[string][]net.IP{
"google-public-dns-a.google.com.": {{8, 8, 8, 8}}, "google-public-dns-a.google.com.": {{8, 8, 8, 8}},
}, },
@ -552,7 +552,7 @@ func TestServerCustomClientUpstream(t *testing.T) {
} }
s := createTestServer(t, &filtering.Config{}, forwardConf, nil) s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
s.conf.GetCustomUpstreamByClient = func(_ string) (conf *proxy.UpstreamConfig, err error) { s.conf.GetCustomUpstreamByClient = func(_ string) (conf *proxy.UpstreamConfig, err error) {
ups := &aghtest.TestUpstream{ ups := &aghtest.Upstream{
IPv4: map[string][]net.IP{ IPv4: map[string][]net.IP{
"host.": {{192, 168, 0, 1}}, "host.": {{192, 168, 0, 1}},
}, },
@ -580,9 +580,9 @@ func TestServerCustomClientUpstream(t *testing.T) {
} }
// testCNAMEs is a map of names and CNAMEs necessary for the TestUpstream work. // testCNAMEs is a map of names and CNAMEs necessary for the TestUpstream work.
var testCNAMEs = map[string]string{ var testCNAMEs = map[string][]string{
"badhost.": "NULL.example.org.", "badhost.": {"NULL.example.org."},
"whitelist.example.org.": "NULL.example.org.", "whitelist.example.org.": {"NULL.example.org."},
} }
// testIPv4 is a map of names and IPv4s necessary for the TestUpstream work. // testIPv4 is a map of names and IPv4s necessary for the TestUpstream work.
@ -596,7 +596,7 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
UDPListenAddrs: []*net.UDPAddr{{}}, UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}},
}, nil) }, nil)
testUpstm := &aghtest.TestUpstream{ testUpstm := &aghtest.Upstream{
CName: testCNAMEs, CName: testCNAMEs,
IPv4: testIPv4, IPv4: testIPv4,
IPv6: nil, IPv6: nil,
@ -630,7 +630,7 @@ func TestBlockCNAME(t *testing.T) {
} }
s := createTestServer(t, &filtering.Config{}, forwardConf, nil) s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.TestUpstream{ &aghtest.Upstream{
CName: testCNAMEs, CName: testCNAMEs,
IPv4: testIPv4, IPv4: testIPv4,
}, },
@ -640,14 +640,17 @@ func TestBlockCNAME(t *testing.T) {
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String() addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
testCases := []struct { testCases := []struct {
name string
host string host string
want bool want bool
}{{ }{{
name: "block_request",
host: "badhost.", host: "badhost.",
// 'badhost' has a canonical name 'NULL.example.org' which is // 'badhost' has a canonical name 'NULL.example.org' which is
// blocked by filters: response is blocked. // blocked by filters: response is blocked.
want: true, want: true,
}, { }, {
name: "allowed",
host: "whitelist.example.org.", host: "whitelist.example.org.",
// 'whitelist.example.org' has a canonical name // 'whitelist.example.org' has a canonical name
// 'NULL.example.org' which is blocked by filters // 'NULL.example.org' which is blocked by filters
@ -655,6 +658,7 @@ func TestBlockCNAME(t *testing.T) {
// response isn't blocked. // response isn't blocked.
want: false, want: false,
}, { }, {
name: "block_response",
host: "example.org.", host: "example.org.",
// 'example.org' has a canonical name 'cname1' with IP // 'example.org' has a canonical name 'cname1' with IP
// 127.0.0.255 which is blocked by filters: response is blocked. // 127.0.0.255 which is blocked by filters: response is blocked.
@ -662,9 +666,9 @@ func TestBlockCNAME(t *testing.T) {
}} }}
for _, tc := range testCases { for _, tc := range testCases {
t.Run("block_cname_"+tc.host, func(t *testing.T) {
req := createTestMessage(tc.host) req := createTestMessage(tc.host)
t.Run(tc.name, func(t *testing.T) {
reply, err := dns.Exchange(req, addr) reply, err := dns.Exchange(req, addr)
require.NoError(t, err) require.NoError(t, err)
@ -674,7 +678,7 @@ func TestBlockCNAME(t *testing.T) {
ans := reply.Answer[0] ans := reply.Answer[0]
a, ok := ans.(*dns.A) a, ok := ans.(*dns.A)
require.Truef(t, ok, "got %T", ans) require.True(t, ok)
assert.True(t, a.A.IsUnspecified()) assert.True(t, a.A.IsUnspecified())
} }
@ -695,7 +699,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) {
} }
s := createTestServer(t, &filtering.Config{}, forwardConf, nil) s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.TestUpstream{ &aghtest.Upstream{
CName: testCNAMEs, CName: testCNAMEs,
IPv4: testIPv4, IPv4: testIPv4,
}, },
@ -931,9 +935,9 @@ func TestRewrite(t *testing.T) {
})) }))
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.TestUpstream{ &aghtest.Upstream{
CName: map[string]string{ CName: map[string][]string{
"example.org": "somename", "example.org": {"somename"},
}, },
IPv4: map[string][]net.IP{ IPv4: map[string][]net.IP{
"example.org.": {{4, 3, 2, 1}}, "example.org.": {{4, 3, 2, 1}},
@ -1193,12 +1197,12 @@ func TestNewServer(t *testing.T) {
} }
func TestServer_Exchange(t *testing.T) { func TestServer_Exchange(t *testing.T) {
extUpstream := &aghtest.TestUpstream{ extUpstream := &aghtest.Upstream{
Reverse: map[string][]string{ Reverse: map[string][]string{
"1.1.1.1.in-addr.arpa.": {"one.one.one.one"}, "1.1.1.1.in-addr.arpa.": {"one.one.one.one"},
}, },
} }
locUpstream := &aghtest.TestUpstream{ locUpstream := &aghtest.Upstream{
Reverse: map[string][]string{ Reverse: map[string][]string{
"1.1.168.192.in-addr.arpa.": {"local.domain"}, "1.1.168.192.in-addr.arpa.": {"local.domain"},
"2.1.168.192.in-addr.arpa.": {}, "2.1.168.192.in-addr.arpa.": {},

View File

@ -116,7 +116,7 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) {
// checkHostRules checks the host against filters. It is safe for concurrent // checkHostRules checks the host against filters. It is safe for concurrent
// use. // use.
func (s *Server) checkHostRules(host string, qtype uint16, setts *filtering.Settings) ( func (s *Server) checkHostRules(host string, rrtype uint16, setts *filtering.Settings) (
r *filtering.Result, r *filtering.Result,
err error, err error,
) { ) {
@ -128,7 +128,7 @@ func (s *Server) checkHostRules(host string, qtype uint16, setts *filtering.Sett
} }
var res filtering.Result var res filtering.Result
res, err = s.dnsFilter.CheckHostRules(host, qtype, setts) res, err = s.dnsFilter.CheckHostRules(host, rrtype, setts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -136,33 +136,36 @@ func (s *Server) checkHostRules(host string, qtype uint16, setts *filtering.Sett
return &res, err return &res, err
} }
// If response contains CNAME, A or AAAA records, we apply filtering to each // filterDNSResponse checks each resource record of the response's answer
// canonical host name or IP address. If this is a match, we set a new response // section from ctx and returns a non-nil res if at least one of canonnical
// in d.Res and return. // names or IP addresses in it matches the filtering rules.
func (s *Server) filterDNSResponse(ctx *dnsContext) (*filtering.Result, error) { func (s *Server) filterDNSResponse(ctx *dnsContext) (res *filtering.Result, err error) {
d := ctx.proxyCtx d := ctx.proxyCtx
setts := ctx.setts
if !setts.FilteringEnabled {
return nil, nil
}
for _, a := range d.Res.Answer { for _, a := range d.Res.Answer {
host := "" host := ""
var rrtype uint16
switch v := a.(type) { switch a := a.(type) {
case *dns.CNAME: case *dns.CNAME:
log.Debug("DNSFwd: Checking CNAME %s for %s", v.Target, v.Hdr.Name) host = strings.TrimSuffix(a.Target, ".")
host = strings.TrimSuffix(v.Target, ".") rrtype = dns.TypeCNAME
case *dns.A: case *dns.A:
host = v.A.String() host = a.A.String()
log.Debug("DNSFwd: Checking record A (%s) for %s", host, v.Hdr.Name) rrtype = dns.TypeA
case *dns.AAAA: case *dns.AAAA:
host = v.AAAA.String() host = a.AAAA.String()
log.Debug("DNSFwd: Checking record AAAA (%s) for %s", host, v.Hdr.Name) rrtype = dns.TypeAAAA
default: default:
continue continue
} }
host = strings.TrimSuffix(host, ".") log.Debug("dnsforward: checking %s %s for %s", dns.Type(rrtype), host, a.Header().Name)
res, err := s.checkHostRules(host, d.Req.Question[0].Qtype, ctx.setts)
res, err = s.checkHostRules(host, rrtype, setts)
if err != nil { if err != nil {
return nil, err return nil, err
} else if res == nil { } else if res == nil {

View File

@ -0,0 +1,159 @@
package dnsforward
import (
"net"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
rules := `
||blocked.domain^
@@||allowed.domain^
||cname.specific^$dnstype=~CNAME
||0.0.0.1^$dnstype=~A
||::1^$dnstype=~AAAA
`
forwardConf := ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
FilteringConfig: FilteringConfig{
ProtectionEnabled: true,
BlockingMode: BlockingModeDefault,
},
}
filters := []filtering.Filter{{
ID: 0, Data: []byte(rules),
}}
f := filtering.New(&filtering.Config{}, filters)
f.SetEnabled(true)
snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err)
require.NotNil(t, snd)
s, err := NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DNSFilter: f,
SubnetDetector: snd,
})
require.NoError(t, err)
s.conf = forwardConf
err = s.Prepare(nil)
require.NoError(t, err)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.Upstream{
CName: map[string][]string{
"cname.exception.": {"cname.specific."},
"should.block.": {"blocked.domain."},
"allowed.first.": {"allowed.domain.", "blocked.domain."},
"blocked.first.": {"blocked.domain.", "allowed.domain."},
},
IPv4: map[string][]net.IP{
"a.exception.": {{0, 0, 0, 1}},
},
IPv6: map[string][]net.IP{
"aaaa.exception.": {net.ParseIP("::1")},
},
},
}
startDeferStop(t, s)
testCases := []struct {
req *dns.Msg
name string
wantAns []dns.RR
}{{
req: createTestMessage("cname.exception."),
name: "cname_exception",
wantAns: []dns.RR{&dns.CNAME{
Hdr: dns.RR_Header{
Name: "cname.exception.",
Rrtype: dns.TypeCNAME,
},
Target: "cname.specific.",
}},
}, {
req: createTestMessage("should.block."),
name: "blocked_by_cname",
wantAns: []dns.RR{&dns.A{
Hdr: dns.RR_Header{
Name: "should.block.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
},
A: netutil.IPv4Zero(),
}},
}, {
req: createTestMessage("a.exception."),
name: "a_exception",
wantAns: []dns.RR{&dns.A{
Hdr: dns.RR_Header{
Name: "a.exception.",
Rrtype: dns.TypeA,
},
A: net.IP{0, 0, 0, 1},
}},
}, {
req: createTestMessageWithType("aaaa.exception.", dns.TypeAAAA),
name: "aaaa_exception",
wantAns: []dns.RR{&dns.AAAA{
Hdr: dns.RR_Header{
Name: "aaaa.exception.",
Rrtype: dns.TypeAAAA,
},
AAAA: net.ParseIP("::1"),
}},
}, {
req: createTestMessage("allowed.first."),
name: "allowed_first",
wantAns: []dns.RR{&dns.A{
Hdr: dns.RR_Header{
Name: "allowed.first.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
},
A: netutil.IPv4Zero(),
}},
}, {
req: createTestMessage("blocked.first."),
name: "blocked_first",
wantAns: []dns.RR{&dns.A{
Hdr: dns.RR_Header{
Name: "blocked.first.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
},
A: netutil.IPv4Zero(),
}},
}}
for _, tc := range testCases {
dctx := &proxy.DNSContext{
Proto: proxy.ProtoUDP,
Req: tc.req,
Addr: &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: 1},
}
t.Run(tc.name, func(t *testing.T) {
err = s.handleDNSRequest(nil, dctx)
require.NoError(t, err)
require.NotNil(t, dctx.Res)
assert.Equal(t, tc.wantAns, dctx.Res.Answer)
})
}
}

View File

@ -420,14 +420,8 @@ func (r Reason) Matched() bool {
} }
// CheckHostRules tries to match the host against filtering rules only. // CheckHostRules tries to match the host against filtering rules only.
func (d *DNSFilter) CheckHostRules(host string, qtype uint16, setts *Settings) (Result, error) { func (d *DNSFilter) CheckHostRules(host string, rrtype uint16, setts *Settings) (Result, error) {
if !setts.FilteringEnabled { return d.matchHost(strings.ToLower(host), rrtype, setts)
return Result{}, nil
}
host = strings.ToLower(host)
return d.matchHost(host, qtype, setts)
} }
// CheckHost tries to match the host against filtering rules, then safebrowsing // CheckHost tries to match the host against filtering rules, then safebrowsing
@ -726,8 +720,7 @@ func hostRulesToRules(netRules []*rules.HostRule) (res []rules.Rule) {
return res return res
} }
// matchHostProcessAllowList processes the allowlist logic of host // matchHostProcessAllowList processes the allowlist logic of host matching.
// matching.
func (d *DNSFilter) matchHostProcessAllowList( func (d *DNSFilter) matchHostProcessAllowList(
host string, host string,
dnsres *urlfilter.DNSResult, dnsres *urlfilter.DNSResult,
@ -798,11 +791,11 @@ func (d *DNSFilter) matchHostProcessDNSResult(
return Result{} return Result{}
} }
// matchHost is a low-level way to check only if hostname is filtered by rules, // matchHost is a low-level way to check only if host is filtered by rules,
// skipping expensive safebrowsing and parental lookups. // skipping expensive safebrowsing and parental lookups.
func (d *DNSFilter) matchHost( func (d *DNSFilter) matchHost(
host string, host string,
qtype uint16, rrtype uint16,
setts *Settings, setts *Settings,
) (res Result, err error) { ) (res Result, err error) {
if !setts.FilteringEnabled { if !setts.FilteringEnabled {
@ -815,7 +808,7 @@ func (d *DNSFilter) matchHost(
// TODO(e.burkov): Wait for urlfilter update to pass net.IP. // TODO(e.burkov): Wait for urlfilter update to pass net.IP.
ClientIP: setts.ClientIP.String(), ClientIP: setts.ClientIP.String(),
ClientName: setts.ClientName, ClientName: setts.ClientName,
DNSType: qtype, DNSType: rrtype,
} }
d.engineLock.RLock() d.engineLock.RLock()
@ -855,7 +848,7 @@ func (d *DNSFilter) matchHost(
return Result{}, nil return Result{}, nil
} }
res = d.matchHostProcessDNSResult(qtype, dnsres) res = d.matchHostProcessDNSResult(rrtype, dnsres)
for _, r := range res.Rules { for _, r := range res.Rules {
log.Debug( log.Debug(
"filtering: found rule %q for host %q, filter list id: %d", "filtering: found rule %q for host %q, filter list id: %d",

View File

@ -167,7 +167,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
w := &bytes.Buffer{} w := &bytes.Buffer{}
aghtest.ReplaceLogWriter(t, w) aghtest.ReplaceLogWriter(t, w)
locUpstream := &aghtest.TestUpstream{ locUpstream := &aghtest.Upstream{
Reverse: map[string][]string{ Reverse: map[string][]string{
"192.168.1.1": {"local.domain"}, "192.168.1.1": {"local.domain"},
"2a00:1450:400c:c06::93": {"ipv6.domain"}, "2a00:1450:400c:c06::93": {"ipv6.domain"},