From e783564084c8396fa0f216103e339d9c6991981a Mon Sep 17 00:00:00 2001 From: Eugene Burkov Date: Thu, 3 Feb 2022 21:19:32 +0300 Subject: [PATCH] Pull request: 4238 response filtering Merge in DNS/adguard-home from 4238-response-filtering to master Closes #4238. Squashed commit of the following: commit 2113f83c617a396a39f910bb8df939364fedf391 Author: Eugene Burkov Date: Thu Feb 3 21:04:25 2022 +0300 dnsforward: restore a bit commit f78607ed97892557c6bd6f3c3332f0bae01c1987 Author: Eugene Burkov Date: Thu Feb 3 20:52:45 2022 +0300 all: imp code, docs commit 646074ce141e8ac12a972f46d071389a2ce124e4 Author: Eugene Burkov Date: Thu Feb 3 20:37:05 2022 +0300 all: log changes commit 94556d810549370fc455bcf14537fa1d2783eed1 Author: Eugene Burkov Date: Thu Feb 3 20:30:57 2022 +0300 all: imp test upstream, cover resp filtering commit 63e7721822a049734a390c7d7ea6d8416a43c8b5 Author: Eugene Burkov Date: Tue Feb 1 21:58:08 2022 +0300 all: filter response by rrtype --- CHANGELOG.md | 3 + internal/aghtest/aghtest.go | 20 ++-- internal/aghtest/upstream.go | 83 ++++--------- internal/dnsforward/dns.go | 6 +- internal/dnsforward/dns_test.go | 4 +- internal/dnsforward/dnsforward_test.go | 46 +++---- internal/dnsforward/filter.go | 41 ++++--- internal/dnsforward/filter_test.go | 159 +++++++++++++++++++++++++ internal/filtering/filtering.go | 21 ++-- internal/home/rdns_test.go | 2 +- 10 files changed, 257 insertions(+), 128 deletions(-) create mode 100644 internal/dnsforward/filter_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 5ad22454..1f00e1d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,8 @@ and this project adheres to ### Changed +- Response filtering is now performed using the record types of the answer + section of messages as opposed to the type of the question ([#4238]). - Instead of adding the build time information, the build scripts now use the standardized environment variable [`SOURCE_DATE_EPOCH`][repr] to add the date of the commit from which the binary was built ([#4221]). This should simplify @@ -80,6 +82,7 @@ In this release, the schema version has changed from 12 to 13. [#3057]: https://github.com/AdguardTeam/AdGuardHome/issues/3057 [#3367]: https://github.com/AdguardTeam/AdGuardHome/issues/3367 [#4221]: https://github.com/AdguardTeam/AdGuardHome/issues/4221 +[#4238]: https://github.com/AdguardTeam/AdGuardHome/issues/4238 [repr]: https://reproducible-builds.org/docs/source-date-epoch/ diff --git a/internal/aghtest/aghtest.go b/internal/aghtest/aghtest.go index 0e7f600c..878ef178 100644 --- a/internal/aghtest/aghtest.go +++ b/internal/aghtest/aghtest.go @@ -20,17 +20,19 @@ func DiscardLogOutput(m *testing.M) { // ReplaceLogWriter moves logger output to w and uses Cleanup method of t to // revert changes. -func ReplaceLogWriter(t *testing.T, w io.Writer) { - stdWriter := log.Writer() - t.Cleanup(func() { - log.SetOutput(stdWriter) - }) +func ReplaceLogWriter(t testing.TB, w io.Writer) { + t.Helper() + + prev := log.Writer() + t.Cleanup(func() { log.SetOutput(prev) }) log.SetOutput(w) } // ReplaceLogLevel sets logging level to l and uses Cleanup method of t to // revert changes. -func ReplaceLogLevel(t *testing.T, l log.Level) { +func ReplaceLogLevel(t testing.TB, l log.Level) { + t.Helper() + switch l { case log.INFO, log.DEBUG, log.ERROR: // Go on. @@ -38,9 +40,7 @@ func ReplaceLogLevel(t *testing.T, l log.Level) { t.Fatalf("wrong l value (must be one of %v, %v, %v)", log.INFO, log.DEBUG, log.ERROR) } - stdLevel := log.GetLevel() - t.Cleanup(func() { - log.SetLevel(stdLevel) - }) + prev := log.GetLevel() + t.Cleanup(func() { log.SetLevel(prev) }) log.SetLevel(l) } diff --git a/internal/aghtest/upstream.go b/internal/aghtest/upstream.go index aa364310..95d8f5ad 100644 --- a/internal/aghtest/upstream.go +++ b/internal/aghtest/upstream.go @@ -11,10 +11,10 @@ import ( "github.com/miekg/dns" ) -// TestUpstream is a mock of real upstream. -type TestUpstream struct { +// Upstream is a mock implementation of upstream.Upstream. +type Upstream struct { // CName is a map of hostname to canonical name. - CName map[string]string + CName map[string][]string // IPv4 is a map of hostname to IPv4. IPv4 map[string][]net.IP // IPv6 is a map of hostname to IPv6. @@ -25,78 +25,45 @@ type TestUpstream struct { Addr string } -// Exchange implements upstream.Upstream interface for *TestUpstream. +// Exchange implements the upstream.Upstream interface for *Upstream. // // TODO(a.garipov): Split further into handlers. -func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { - resp = &dns.Msg{} - resp.SetReply(m) +func (u *Upstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { + resp = new(dns.Msg).SetReply(m) if len(m.Question) == 0 { return nil, fmt.Errorf("question should not be empty") } - name := m.Question[0].Name - - if cname, ok := u.CName[name]; ok { - ans := &dns.CNAME{ - Hdr: dns.RR_Header{ - Name: name, - Rrtype: dns.TypeCNAME, - }, + q := m.Question[0] + name := q.Name + for _, cname := range u.CName[name] { + resp.Answer = append(resp.Answer, &dns.CNAME{ + Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeCNAME}, Target: cname, - } - - resp.Answer = append(resp.Answer, ans) + }) } - rrType := m.Question[0].Qtype + qtype := q.Qtype hdr := dns.RR_Header{ Name: name, - Rrtype: rrType, + Rrtype: qtype, } - var names []string - var ips []net.IP - switch m.Question[0].Qtype { + switch qtype { case dns.TypeA: - ips = u.IPv4[name] + for _, ip := range u.IPv4[name] { + resp.Answer = append(resp.Answer, &dns.A{Hdr: hdr, A: ip}) + } case dns.TypeAAAA: - ips = u.IPv6[name] + for _, ip := range u.IPv6[name] { + resp.Answer = append(resp.Answer, &dns.AAAA{Hdr: hdr, AAAA: ip}) + } case dns.TypePTR: - names = u.Reverse[name] - } - - for _, ip := range ips { - var ans dns.RR - if rrType == dns.TypeA { - ans = &dns.A{ - Hdr: hdr, - A: ip, - } - - resp.Answer = append(resp.Answer, ans) - - continue + for _, name := range u.Reverse[name] { + resp.Answer = append(resp.Answer, &dns.PTR{Hdr: hdr, Ptr: name}) } - - ans = &dns.AAAA{ - Hdr: hdr, - AAAA: ip, - } - - resp.Answer = append(resp.Answer, ans) } - - for _, n := range names { - ans := &dns.PTR{ - Hdr: hdr, - Ptr: n, - } - - resp.Answer = append(resp.Answer, ans) - } - if len(resp.Answer) == 0 { resp.SetRcode(m, dns.RcodeNameError) } @@ -104,8 +71,8 @@ func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { return resp, nil } -// Address implements upstream.Upstream interface for *TestUpstream. -func (u *TestUpstream) Address() string { +// Address implements upstream.Upstream interface for *Upstream. +func (u *Upstream) Address() string { return u.Addr } diff --git a/internal/dnsforward/dns.go b/internal/dnsforward/dns.go index c11171d9..87cb5194 100644 --- a/internal/dnsforward/dns.go +++ b/internal/dnsforward/dns.go @@ -613,9 +613,9 @@ func (s *Server) processFilteringAfterResponse(ctx *dnsContext) (rc resultCode) d.Res.Answer = answer } default: - // Check the response only if the it's from an upstream. Don't check - // the response if the protection is disabled since dnsrewrite rules - // aren't applied to it anyway. + // Check the response only if it's from an upstream. Don't check the + // response if the protection is disabled since dnsrewrite rules aren't + // applied to it anyway. if !ctx.protectionEnabled || !ctx.responseFromUpstream || s.dnsFilter == nil { break } diff --git a/internal/dnsforward/dns_test.go b/internal/dnsforward/dns_test.go index edf54f51..4fc87ccf 100644 --- a/internal/dnsforward/dns_test.go +++ b/internal/dnsforward/dns_test.go @@ -261,7 +261,7 @@ func TestServer_ProcessInternalHosts(t *testing.T) { } func TestServer_ProcessRestrictLocal(t *testing.T) { - ups := &aghtest.TestUpstream{ + ups := &aghtest.Upstream{ Reverse: map[string][]string{ "251.252.253.254.in-addr.arpa.": {"host1.example.net."}, "1.1.168.192.in-addr.arpa.": {"some.local-client."}, @@ -339,7 +339,7 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) { s := createTestServer(t, &filtering.Config{}, ServerConfig{ UDPListenAddrs: []*net.UDPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}}, - }, &aghtest.TestUpstream{ + }, &aghtest.Upstream{ Reverse: map[string][]string{ reqAddr: {locDomain}, }, diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index d0191c85..bc90d760 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -89,7 +89,7 @@ func createTestServer( defer s.serverLock.Unlock() if localUps != nil { - s.localResolvers.Config.UpstreamConfig.Upstreams = []upstream.Upstream{localUps} + s.localResolvers.UpstreamConfig.Upstreams = []upstream.Upstream{localUps} s.conf.UsePrivateRDNS = true } @@ -247,7 +247,7 @@ func TestServer(t *testing.T) { TCPListenAddrs: []*net.TCPAddr{{}}, }, nil) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &aghtest.TestUpstream{ + &aghtest.Upstream{ IPv4: map[string][]net.IP{ "google-public-dns-a.google.com.": {{8, 8, 8, 8}}, }, @@ -316,7 +316,7 @@ func TestServerWithProtectionDisabled(t *testing.T) { TCPListenAddrs: []*net.TCPAddr{{}}, }, nil) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &aghtest.TestUpstream{ + &aghtest.Upstream{ IPv4: map[string][]net.IP{ "google-public-dns-a.google.com.": {{8, 8, 8, 8}}, }, @@ -339,7 +339,7 @@ func TestDoTServer(t *testing.T) { TLSListenAddrs: []*net.TCPAddr{{}}, }) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &aghtest.TestUpstream{ + &aghtest.Upstream{ IPv4: map[string][]net.IP{ "google-public-dns-a.google.com.": {{8, 8, 8, 8}}, }, @@ -369,7 +369,7 @@ func TestDoQServer(t *testing.T) { QUICListenAddrs: []*net.UDPAddr{{IP: net.IP{127, 0, 0, 1}}}, }) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &aghtest.TestUpstream{ + &aghtest.Upstream{ IPv4: map[string][]net.IP{ "google-public-dns-a.google.com.": {{8, 8, 8, 8}}, }, @@ -413,7 +413,7 @@ func TestServerRace(t *testing.T) { } s := createTestServer(t, filterConf, forwardConf, nil) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &aghtest.TestUpstream{ + &aghtest.Upstream{ IPv4: map[string][]net.IP{ "google-public-dns-a.google.com.": {{8, 8, 8, 8}}, }, @@ -552,7 +552,7 @@ func TestServerCustomClientUpstream(t *testing.T) { } s := createTestServer(t, &filtering.Config{}, forwardConf, nil) s.conf.GetCustomUpstreamByClient = func(_ string) (conf *proxy.UpstreamConfig, err error) { - ups := &aghtest.TestUpstream{ + ups := &aghtest.Upstream{ IPv4: map[string][]net.IP{ "host.": {{192, 168, 0, 1}}, }, @@ -580,9 +580,9 @@ func TestServerCustomClientUpstream(t *testing.T) { } // testCNAMEs is a map of names and CNAMEs necessary for the TestUpstream work. -var testCNAMEs = map[string]string{ - "badhost.": "NULL.example.org.", - "whitelist.example.org.": "NULL.example.org.", +var testCNAMEs = map[string][]string{ + "badhost.": {"NULL.example.org."}, + "whitelist.example.org.": {"NULL.example.org."}, } // testIPv4 is a map of names and IPv4s necessary for the TestUpstream work. @@ -596,7 +596,7 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) { UDPListenAddrs: []*net.UDPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}}, }, nil) - testUpstm := &aghtest.TestUpstream{ + testUpstm := &aghtest.Upstream{ CName: testCNAMEs, IPv4: testIPv4, IPv6: nil, @@ -630,7 +630,7 @@ func TestBlockCNAME(t *testing.T) { } s := createTestServer(t, &filtering.Config{}, forwardConf, nil) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &aghtest.TestUpstream{ + &aghtest.Upstream{ CName: testCNAMEs, IPv4: testIPv4, }, @@ -640,14 +640,17 @@ func TestBlockCNAME(t *testing.T) { addr := s.dnsProxy.Addr(proxy.ProtoUDP).String() testCases := []struct { + name string host string want bool }{{ + name: "block_request", host: "badhost.", // 'badhost' has a canonical name 'NULL.example.org' which is // blocked by filters: response is blocked. want: true, }, { + name: "allowed", host: "whitelist.example.org.", // 'whitelist.example.org' has a canonical name // 'NULL.example.org' which is blocked by filters @@ -655,6 +658,7 @@ func TestBlockCNAME(t *testing.T) { // response isn't blocked. want: false, }, { + name: "block_response", host: "example.org.", // 'example.org' has a canonical name 'cname1' with IP // 127.0.0.255 which is blocked by filters: response is blocked. @@ -662,9 +666,9 @@ func TestBlockCNAME(t *testing.T) { }} for _, tc := range testCases { - t.Run("block_cname_"+tc.host, func(t *testing.T) { - req := createTestMessage(tc.host) + req := createTestMessage(tc.host) + t.Run(tc.name, func(t *testing.T) { reply, err := dns.Exchange(req, addr) require.NoError(t, err) @@ -674,7 +678,7 @@ func TestBlockCNAME(t *testing.T) { ans := reply.Answer[0] a, ok := ans.(*dns.A) - require.Truef(t, ok, "got %T", ans) + require.True(t, ok) assert.True(t, a.A.IsUnspecified()) } @@ -695,7 +699,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) { } s := createTestServer(t, &filtering.Config{}, forwardConf, nil) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &aghtest.TestUpstream{ + &aghtest.Upstream{ CName: testCNAMEs, IPv4: testIPv4, }, @@ -931,9 +935,9 @@ func TestRewrite(t *testing.T) { })) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &aghtest.TestUpstream{ - CName: map[string]string{ - "example.org": "somename", + &aghtest.Upstream{ + CName: map[string][]string{ + "example.org": {"somename"}, }, IPv4: map[string][]net.IP{ "example.org.": {{4, 3, 2, 1}}, @@ -1193,12 +1197,12 @@ func TestNewServer(t *testing.T) { } func TestServer_Exchange(t *testing.T) { - extUpstream := &aghtest.TestUpstream{ + extUpstream := &aghtest.Upstream{ Reverse: map[string][]string{ "1.1.1.1.in-addr.arpa.": {"one.one.one.one"}, }, } - locUpstream := &aghtest.TestUpstream{ + locUpstream := &aghtest.Upstream{ Reverse: map[string][]string{ "1.1.168.192.in-addr.arpa.": {"local.domain"}, "2.1.168.192.in-addr.arpa.": {}, diff --git a/internal/dnsforward/filter.go b/internal/dnsforward/filter.go index 471b463e..18f12797 100644 --- a/internal/dnsforward/filter.go +++ b/internal/dnsforward/filter.go @@ -116,7 +116,7 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) { // checkHostRules checks the host against filters. It is safe for concurrent // use. -func (s *Server) checkHostRules(host string, qtype uint16, setts *filtering.Settings) ( +func (s *Server) checkHostRules(host string, rrtype uint16, setts *filtering.Settings) ( r *filtering.Result, err error, ) { @@ -128,7 +128,7 @@ func (s *Server) checkHostRules(host string, qtype uint16, setts *filtering.Sett } var res filtering.Result - res, err = s.dnsFilter.CheckHostRules(host, qtype, setts) + res, err = s.dnsFilter.CheckHostRules(host, rrtype, setts) if err != nil { return nil, err } @@ -136,33 +136,36 @@ func (s *Server) checkHostRules(host string, qtype uint16, setts *filtering.Sett return &res, err } -// If response contains CNAME, A or AAAA records, we apply filtering to each -// canonical host name or IP address. If this is a match, we set a new response -// in d.Res and return. -func (s *Server) filterDNSResponse(ctx *dnsContext) (*filtering.Result, error) { +// filterDNSResponse checks each resource record of the response's answer +// section from ctx and returns a non-nil res if at least one of canonnical +// names or IP addresses in it matches the filtering rules. +func (s *Server) filterDNSResponse(ctx *dnsContext) (res *filtering.Result, err error) { d := ctx.proxyCtx + setts := ctx.setts + if !setts.FilteringEnabled { + return nil, nil + } + for _, a := range d.Res.Answer { host := "" - - switch v := a.(type) { + var rrtype uint16 + switch a := a.(type) { case *dns.CNAME: - log.Debug("DNSFwd: Checking CNAME %s for %s", v.Target, v.Hdr.Name) - host = strings.TrimSuffix(v.Target, ".") - + host = strings.TrimSuffix(a.Target, ".") + rrtype = dns.TypeCNAME case *dns.A: - host = v.A.String() - log.Debug("DNSFwd: Checking record A (%s) for %s", host, v.Hdr.Name) - + host = a.A.String() + rrtype = dns.TypeA case *dns.AAAA: - host = v.AAAA.String() - log.Debug("DNSFwd: Checking record AAAA (%s) for %s", host, v.Hdr.Name) - + host = a.AAAA.String() + rrtype = dns.TypeAAAA default: continue } - host = strings.TrimSuffix(host, ".") - res, err := s.checkHostRules(host, d.Req.Question[0].Qtype, ctx.setts) + log.Debug("dnsforward: checking %s %s for %s", dns.Type(rrtype), host, a.Header().Name) + + res, err = s.checkHostRules(host, rrtype, setts) if err != nil { return nil, err } else if res == nil { diff --git a/internal/dnsforward/filter_test.go b/internal/dnsforward/filter_test.go new file mode 100644 index 00000000..84570bce --- /dev/null +++ b/internal/dnsforward/filter_test.go @@ -0,0 +1,159 @@ +package dnsforward + +import ( + "net" + "testing" + + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/AdGuardHome/internal/filtering" + "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/netutil" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHandleDNSRequest_filterDNSResponse(t *testing.T) { + rules := ` +||blocked.domain^ +@@||allowed.domain^ +||cname.specific^$dnstype=~CNAME +||0.0.0.1^$dnstype=~A +||::1^$dnstype=~AAAA +` + + forwardConf := ServerConfig{ + UDPListenAddrs: []*net.UDPAddr{{}}, + TCPListenAddrs: []*net.TCPAddr{{}}, + FilteringConfig: FilteringConfig{ + ProtectionEnabled: true, + BlockingMode: BlockingModeDefault, + }, + } + filters := []filtering.Filter{{ + ID: 0, Data: []byte(rules), + }} + + f := filtering.New(&filtering.Config{}, filters) + f.SetEnabled(true) + + snd, err := aghnet.NewSubnetDetector() + require.NoError(t, err) + require.NotNil(t, snd) + + s, err := NewServer(DNSCreateParams{ + DHCPServer: &testDHCP{}, + DNSFilter: f, + SubnetDetector: snd, + }) + require.NoError(t, err) + + s.conf = forwardConf + err = s.Prepare(nil) + require.NoError(t, err) + + s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ + &aghtest.Upstream{ + CName: map[string][]string{ + "cname.exception.": {"cname.specific."}, + "should.block.": {"blocked.domain."}, + "allowed.first.": {"allowed.domain.", "blocked.domain."}, + "blocked.first.": {"blocked.domain.", "allowed.domain."}, + }, + IPv4: map[string][]net.IP{ + "a.exception.": {{0, 0, 0, 1}}, + }, + IPv6: map[string][]net.IP{ + "aaaa.exception.": {net.ParseIP("::1")}, + }, + }, + } + startDeferStop(t, s) + + testCases := []struct { + req *dns.Msg + name string + wantAns []dns.RR + }{{ + req: createTestMessage("cname.exception."), + name: "cname_exception", + wantAns: []dns.RR{&dns.CNAME{ + Hdr: dns.RR_Header{ + Name: "cname.exception.", + Rrtype: dns.TypeCNAME, + }, + Target: "cname.specific.", + }}, + }, { + req: createTestMessage("should.block."), + name: "blocked_by_cname", + wantAns: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{ + Name: "should.block.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: netutil.IPv4Zero(), + }}, + }, { + req: createTestMessage("a.exception."), + name: "a_exception", + wantAns: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{ + Name: "a.exception.", + Rrtype: dns.TypeA, + }, + A: net.IP{0, 0, 0, 1}, + }}, + }, { + req: createTestMessageWithType("aaaa.exception.", dns.TypeAAAA), + name: "aaaa_exception", + wantAns: []dns.RR{&dns.AAAA{ + Hdr: dns.RR_Header{ + Name: "aaaa.exception.", + Rrtype: dns.TypeAAAA, + }, + AAAA: net.ParseIP("::1"), + }}, + }, { + req: createTestMessage("allowed.first."), + name: "allowed_first", + wantAns: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{ + Name: "allowed.first.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: netutil.IPv4Zero(), + }}, + }, { + req: createTestMessage("blocked.first."), + name: "blocked_first", + wantAns: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{ + Name: "blocked.first.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: netutil.IPv4Zero(), + }}, + }} + + for _, tc := range testCases { + dctx := &proxy.DNSContext{ + Proto: proxy.ProtoUDP, + Req: tc.req, + Addr: &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: 1}, + } + + t.Run(tc.name, func(t *testing.T) { + err = s.handleDNSRequest(nil, dctx) + require.NoError(t, err) + require.NotNil(t, dctx.Res) + + assert.Equal(t, tc.wantAns, dctx.Res.Answer) + }) + } +} diff --git a/internal/filtering/filtering.go b/internal/filtering/filtering.go index 0a8bcd6e..9bedeef7 100644 --- a/internal/filtering/filtering.go +++ b/internal/filtering/filtering.go @@ -420,14 +420,8 @@ func (r Reason) Matched() bool { } // CheckHostRules tries to match the host against filtering rules only. -func (d *DNSFilter) CheckHostRules(host string, qtype uint16, setts *Settings) (Result, error) { - if !setts.FilteringEnabled { - return Result{}, nil - } - - host = strings.ToLower(host) - - return d.matchHost(host, qtype, setts) +func (d *DNSFilter) CheckHostRules(host string, rrtype uint16, setts *Settings) (Result, error) { + return d.matchHost(strings.ToLower(host), rrtype, setts) } // CheckHost tries to match the host against filtering rules, then safebrowsing @@ -726,8 +720,7 @@ func hostRulesToRules(netRules []*rules.HostRule) (res []rules.Rule) { return res } -// matchHostProcessAllowList processes the allowlist logic of host -// matching. +// matchHostProcessAllowList processes the allowlist logic of host matching. func (d *DNSFilter) matchHostProcessAllowList( host string, dnsres *urlfilter.DNSResult, @@ -798,11 +791,11 @@ func (d *DNSFilter) matchHostProcessDNSResult( return Result{} } -// matchHost is a low-level way to check only if hostname is filtered by rules, +// matchHost is a low-level way to check only if host is filtered by rules, // skipping expensive safebrowsing and parental lookups. func (d *DNSFilter) matchHost( host string, - qtype uint16, + rrtype uint16, setts *Settings, ) (res Result, err error) { if !setts.FilteringEnabled { @@ -815,7 +808,7 @@ func (d *DNSFilter) matchHost( // TODO(e.burkov): Wait for urlfilter update to pass net.IP. ClientIP: setts.ClientIP.String(), ClientName: setts.ClientName, - DNSType: qtype, + DNSType: rrtype, } d.engineLock.RLock() @@ -855,7 +848,7 @@ func (d *DNSFilter) matchHost( return Result{}, nil } - res = d.matchHostProcessDNSResult(qtype, dnsres) + res = d.matchHostProcessDNSResult(rrtype, dnsres) for _, r := range res.Rules { log.Debug( "filtering: found rule %q for host %q, filter list id: %d", diff --git a/internal/home/rdns_test.go b/internal/home/rdns_test.go index 202f9f5f..08f4f013 100644 --- a/internal/home/rdns_test.go +++ b/internal/home/rdns_test.go @@ -167,7 +167,7 @@ func TestRDNS_WorkerLoop(t *testing.T) { w := &bytes.Buffer{} aghtest.ReplaceLogWriter(t, w) - locUpstream := &aghtest.TestUpstream{ + locUpstream := &aghtest.Upstream{ Reverse: map[string][]string{ "192.168.1.1": {"local.domain"}, "2a00:1450:400c:c06::93": {"ipv6.domain"},