diff --git a/internal/dnsfilter/dnsfilter.go b/internal/dnsfilter/dnsfilter.go index 45ce5b8a..6ade9701 100644 --- a/internal/dnsfilter/dnsfilter.go +++ b/internal/dnsfilter/dnsfilter.go @@ -2,6 +2,7 @@ package dnsfilter import ( + "context" "fmt" "io/ioutil" "net" @@ -91,6 +92,12 @@ type filtersInitializerParams struct { blockFilters []Filter } +// Resolver is the interface for net.Resolver to simplify testing. +type Resolver interface { + // TODO(e.burkov): Replace with LookupIP after upgrading go to v1.15. + LookupIPAddr(ctx context.Context, host string) (ips []net.IPAddr, err error) +} + // DNSFilter matches hostnames and DNS requests against filtering rules. type DNSFilter struct { rulesStorage *filterlist.RuleStorage @@ -110,6 +117,11 @@ type DNSFilter struct { // Channel for passing data to filters-initializer goroutine filtersInitializerChan chan filtersInitializerParams filtersInitializerLock sync.Mutex + + // resolver only looks up the IP address of the host while safe search. + // + // TODO(e.burkov): Use upstream that configured in dnsforward instead. + resolver Resolver } // Filter represents a filter list @@ -805,7 +817,9 @@ func New(c *Config, blockFilters []Filter) *DNSFilter { } } - d := new(DNSFilter) + d := &DNSFilter{ + resolver: net.DefaultResolver, + } err := d.initSecurityServices() if err != nil { diff --git a/internal/dnsfilter/dnsfilter_test.go b/internal/dnsfilter/dnsfilter_test.go index abac1fb9..257953ad 100644 --- a/internal/dnsfilter/dnsfilter_test.go +++ b/internal/dnsfilter/dnsfilter_test.go @@ -2,11 +2,14 @@ package dnsfilter import ( "bytes" + "context" + "crypto/sha256" "fmt" "net" "testing" "github.com/AdguardTeam/AdGuardHome/internal/testutil" + "github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/urlfilter/rules" "github.com/miekg/dns" @@ -19,28 +22,21 @@ func TestMain(m *testing.M) { var setts RequestFilteringSettings -// HELPERS -// SAFE BROWSING -// SAFE SEARCH -// PARENTAL -// FILTERING -// BENCHMARKS - -// HELPERS +// Helpers. func purgeCaches() { - if gctx.safebrowsingCache != nil { - gctx.safebrowsingCache.Clear() - } - if gctx.parentalCache != nil { - gctx.parentalCache.Clear() - } - if gctx.safeSearchCache != nil { - gctx.safeSearchCache.Clear() + for _, c := range []cache.Cache{ + gctx.safebrowsingCache, + gctx.parentalCache, + gctx.safeSearchCache, + } { + if c != nil { + c.Clear() + } } } -func NewForTest(c *Config, filters []Filter) *DNSFilter { +func newForTest(c *Config, filters []Filter) *DNSFilter { setts = RequestFilteringSettings{} setts.FilteringEnabled = true if c != nil { @@ -59,48 +55,31 @@ func NewForTest(c *Config, filters []Filter) *DNSFilter { func (d *DNSFilter) checkMatch(t *testing.T, hostname string) { t.Helper() + res, err := d.CheckHost(hostname, dns.TypeA, &setts) - if err != nil { - t.Errorf("Error while matching host %s: %s", hostname, err) - } - if !res.IsFiltered { - t.Errorf("Expected hostname %s to match", hostname) - } + assert.Nilf(t, err, "Error while matching host %s: %s", hostname, err) + assert.Truef(t, res.IsFiltered, "Expected hostname %s to match", hostname) } func (d *DNSFilter) checkMatchIP(t *testing.T, hostname, ip string, qtype uint16) { t.Helper() res, err := d.CheckHost(hostname, qtype, &setts) - if err != nil { - t.Errorf("Error while matching host %s: %s", hostname, err) - } - - if !res.IsFiltered { - t.Errorf("Expected hostname %s to match", hostname) - } - - if len(res.Rules) == 0 { - t.Errorf("Expected result to have rules") - - return - } - - r := res.Rules[0] - if r.IP == nil || r.IP.String() != ip { - t.Errorf("Expected ip %s to match, actual: %v", ip, r.IP) + assert.Nilf(t, err, "Error while matching host %s: %s", hostname, err) + assert.Truef(t, res.IsFiltered, "Expected hostname %s to match", hostname) + if assert.NotEmpty(t, res.Rules, "Expected result to have rules") { + r := res.Rules[0] + assert.NotNilf(t, r.IP, "Expected ip %s to match, actual: %v", ip, r.IP) + assert.Equalf(t, ip, r.IP.String(), "Expected ip %s to match, actual: %v", ip, r.IP) } } func (d *DNSFilter) checkMatchEmpty(t *testing.T, hostname string) { t.Helper() + res, err := d.CheckHost(hostname, dns.TypeA, &setts) - if err != nil { - t.Errorf("Error while matching host %s: %s", hostname, err) - } - if res.IsFiltered { - t.Errorf("Expected hostname %s to not match", hostname) - } + assert.Nilf(t, err, "Error while matching host %s: %s", hostname, err) + assert.Falsef(t, res.IsFiltered, "Expected hostname %s to not match", hostname) } func TestEtcHostsMatching(t *testing.T) { @@ -117,18 +96,18 @@ func TestEtcHostsMatching(t *testing.T) { filters := []Filter{{ ID: 0, Data: []byte(text), }} - d := NewForTest(nil, filters) - defer d.Close() + d := newForTest(nil, filters) + t.Cleanup(d.Close) d.checkMatchIP(t, "google.com", addr, dns.TypeA) d.checkMatchIP(t, "www.google.com", addr, dns.TypeA) d.checkMatchEmpty(t, "subdomain.google.com") d.checkMatchEmpty(t, "example.org") - // IPv4 + // IPv4 match. d.checkMatchIP(t, "block.com", "0.0.0.0", dns.TypeA) - // ...but empty IPv6 + // Empty IPv6. res, err := d.CheckHost("block.com", dns.TypeAAAA, &setts) assert.Nil(t, err) assert.True(t, res.IsFiltered) @@ -137,10 +116,10 @@ func TestEtcHostsMatching(t *testing.T) { assert.Empty(t, res.Rules[0].IP) } - // IPv6 + // IPv6 match. d.checkMatchIP(t, "ipv6.com", addr6, dns.TypeAAAA) - // ...but empty IPv4 + // Empty IPv4. res, err = d.CheckHost("ipv6.com", dns.TypeA, &setts) assert.Nil(t, err) assert.True(t, res.IsFiltered) @@ -149,58 +128,67 @@ func TestEtcHostsMatching(t *testing.T) { assert.Empty(t, res.Rules[0].IP) } - // 2 IPv4 (return only the first one) + // Two IPv4, the first one returned. res, err = d.CheckHost("host2", dns.TypeA, &setts) assert.Nil(t, err) assert.True(t, res.IsFiltered) if assert.Len(t, res.Rules, 1) { - loopback4 := net.IP{0, 0, 0, 1} - assert.Equal(t, res.Rules[0].IP, loopback4) + assert.Equal(t, res.Rules[0].IP, net.IP{0, 0, 0, 1}) } - // ...and 1 IPv6 address + // One IPv6 address. res, err = d.CheckHost("host2", dns.TypeAAAA, &setts) assert.Nil(t, err) assert.True(t, res.IsFiltered) if assert.Len(t, res.Rules, 1) { - loopback6 := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} - assert.Equal(t, res.Rules[0].IP, loopback6) + assert.Equal(t, res.Rules[0].IP, net.IPv6loopback) } } -// SAFE BROWSING +// Safe Browsing. func TestSafeBrowsing(t *testing.T) { logOutput := &bytes.Buffer{} testutil.ReplaceLogWriter(t, logOutput) testutil.ReplaceLogLevel(t, log.DEBUG) - d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil) - defer d.Close() - d.checkMatch(t, "wmconvirus.narod.ru") + d := newForTest(&Config{SafeBrowsingEnabled: true}, nil) + t.Cleanup(d.Close) + matching := "wmconvirus.narod.ru" + d.safeBrowsingUpstream = &testSbUpstream{ + hostname: matching, + block: true, + } + d.checkMatch(t, matching) - assert.Contains(t, logOutput.String(), "SafeBrowsing lookup for wmconvirus.narod.ru") + assert.Contains(t, logOutput.String(), "SafeBrowsing lookup for "+matching) - d.checkMatch(t, "test.wmconvirus.narod.ru") + d.checkMatch(t, "test."+matching) d.checkMatchEmpty(t, "yandex.ru") d.checkMatchEmpty(t, "pornhub.com") - // test cached result + // Cached result. d.safeBrowsingServer = "127.0.0.1" - d.checkMatch(t, "wmconvirus.narod.ru") + d.checkMatch(t, matching) d.checkMatchEmpty(t, "pornhub.com") d.safeBrowsingServer = defaultSafebrowsingServer } func TestParallelSB(t *testing.T) { - d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil) - defer d.Close() + d := newForTest(&Config{SafeBrowsingEnabled: true}, nil) + t.Cleanup(d.Close) + matching := "wmconvirus.narod.ru" + d.safeBrowsingUpstream = &testSbUpstream{ + hostname: matching, + block: true, + } + t.Run("group", func(t *testing.T) { for i := 0; i < 100; i++ { t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) { t.Parallel() - d.checkMatch(t, "wmconvirus.narod.ru") - d.checkMatch(t, "test.wmconvirus.narod.ru") + d.checkMatch(t, matching) + d.checkMatch(t, "test."+matching) d.checkMatchEmpty(t, "yandex.ru") d.checkMatchEmpty(t, "pornhub.com") }) @@ -208,59 +196,81 @@ func TestParallelSB(t *testing.T) { }) } -// SAFE SEARCH +// Safe Search. func TestSafeSearch(t *testing.T) { - d := NewForTest(&Config{SafeSearchEnabled: true}, nil) - defer d.Close() + d := newForTest(&Config{SafeSearchEnabled: true}, nil) + t.Cleanup(d.Close) val, ok := d.SafeSearchDomain("www.google.com") - if !ok { - t.Errorf("Expected safesearch to find result for www.google.com") - } - if val != "forcesafesearch.google.com" { - t.Errorf("Expected safesearch for google.com to be forcesafesearch.google.com") - } + assert.True(t, ok, "Expected safesearch to find result for www.google.com") + assert.Equal(t, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com") } func TestCheckHostSafeSearchYandex(t *testing.T) { - d := NewForTest(&Config{SafeSearchEnabled: true}, nil) - defer d.Close() + d := newForTest(&Config{SafeSearchEnabled: true}, nil) + t.Cleanup(d.Close) - // Slice of yandex domains - yandex := []string{"yAndeX.ru", "YANdex.COM", "yandex.ua", "yandex.by", "yandex.kz", "www.yandex.com"} - - // Check host for each domain - for _, host := range yandex { + // Check host for each domain. + for _, host := range []string{ + "yAndeX.ru", + "YANdex.COM", + "yandex.ua", + "yandex.by", + "yandex.kz", + "www.yandex.com", + } { res, err := d.CheckHost(host, dns.TypeA, &setts) assert.Nil(t, err) assert.True(t, res.IsFiltered) if assert.Len(t, res.Rules, 1) { - assert.Equal(t, res.Rules[0].IP.String(), "213.180.193.56") + assert.Equal(t, res.Rules[0].IP, net.IPv4(213, 180, 193, 56)) } } } +// testResolver is a Resolver for tests. +type testResolver struct{} + +// LookupIP implements Resolver interface for *testResolver. +func (r *testResolver) LookupIPAddr(_ context.Context, host string) (ips []net.IPAddr, err error) { + hash := sha256.Sum256([]byte(host)) + addrs := []net.IPAddr{{ + IP: net.IP(hash[:4]), + Zone: "somezone", + }, { + IP: net.IP(hash[4:20]), + Zone: "somezone", + }} + return addrs, nil +} + func TestCheckHostSafeSearchGoogle(t *testing.T) { - d := NewForTest(&Config{SafeSearchEnabled: true}, nil) - defer d.Close() + d := newForTest(&Config{SafeSearchEnabled: true}, nil) + t.Cleanup(d.Close) + d.resolver = &testResolver{} - // Slice of google domains - googleDomains := []string{"www.google.com", "www.google.im", "www.google.co.in", "www.google.iq", "www.google.is", "www.google.it", "www.google.je"} - - // Check host for each domain - for _, host := range googleDomains { - res, err := d.CheckHost(host, dns.TypeA, &setts) - assert.Nil(t, err) - assert.True(t, res.IsFiltered) - if assert.Len(t, res.Rules, 1) { - assert.NotEqual(t, res.Rules[0].IP.String(), "0.0.0.0") - } + // Check host for each domain. + for _, host := range []string{ + "www.google.com", + "www.google.im", + "www.google.co.in", + "www.google.iq", + "www.google.is", + "www.google.it", + "www.google.je", + } { + t.Run(host, func(t *testing.T) { + res, err := d.CheckHost(host, dns.TypeA, &setts) + assert.Nil(t, err) + assert.True(t, res.IsFiltered) + assert.Len(t, res.Rules, 1) + }) } } func TestSafeSearchCacheYandex(t *testing.T) { - d := NewForTest(nil, nil) - defer d.Close() + d := newForTest(nil, nil) + t.Cleanup(d.Close) domain := "yandex.ru" // Check host with disabled safesearch. @@ -269,54 +279,55 @@ func TestSafeSearchCacheYandex(t *testing.T) { assert.False(t, res.IsFiltered) assert.Empty(t, res.Rules) - d = NewForTest(&Config{SafeSearchEnabled: true}, nil) - defer d.Close() + d = newForTest(&Config{SafeSearchEnabled: true}, nil) + t.Cleanup(d.Close) res, err = d.CheckHost(domain, dns.TypeA, &setts) - if err != nil { - t.Fatalf("CheckHost for safesearh domain %s failed cause %s", domain, err) - } + assert.Nilf(t, err, "CheckHost for safesearh domain %s failed cause %s", domain, err) - // For yandex we already know valid ip. + // For yandex we already know valid IP. if assert.Len(t, res.Rules, 1) { - assert.Equal(t, res.Rules[0].IP.String(), "213.180.193.56") + assert.Equal(t, res.Rules[0].IP, net.IPv4(213, 180, 193, 56)) } // Check cache. cachedValue, isFound := getCachedResult(gctx.safeSearchCache, domain) assert.True(t, isFound) if assert.Len(t, cachedValue.Rules, 1) { - assert.Equal(t, cachedValue.Rules[0].IP.String(), "213.180.193.56") + assert.Equal(t, cachedValue.Rules[0].IP, net.IPv4(213, 180, 193, 56)) } } func TestSafeSearchCacheGoogle(t *testing.T) { - d := NewForTest(nil, nil) - defer d.Close() + d := newForTest(nil, nil) + t.Cleanup(d.Close) + + resolver := &testResolver{} + d.resolver = resolver + domain := "www.google.ru" res, err := d.CheckHost(domain, dns.TypeA, &setts) assert.Nil(t, err) assert.False(t, res.IsFiltered) assert.Empty(t, res.Rules) - d = NewForTest(&Config{SafeSearchEnabled: true}, nil) - defer d.Close() + d = newForTest(&Config{SafeSearchEnabled: true}, nil) + t.Cleanup(d.Close) + d.resolver = resolver - // Let's lookup for safesearch domain + // Lookup for safesearch domain. safeDomain, ok := d.SafeSearchDomain(domain) - if !ok { - t.Fatalf("Failed to get safesearch domain for %s", domain) - } + assert.Truef(t, ok, "Failed to get safesearch domain for %s", domain) - ips, err := net.LookupIP(safeDomain) + ipAddrs, err := resolver.LookupIPAddr(context.Background(), safeDomain) if err != nil { t.Fatalf("Failed to lookup for %s", safeDomain) } - ip := ips[0] - for _, i := range ips { - if i.To4() != nil { - ip = i + ip := ipAddrs[0].IP + for _, ipAddr := range ipAddrs { + if ipAddr.IP.To4() != nil { + ip = ipAddr.IP break } } @@ -335,114 +346,324 @@ func TestSafeSearchCacheGoogle(t *testing.T) { } } -// PARENTAL +// Parental. func TestParentalControl(t *testing.T) { logOutput := &bytes.Buffer{} testutil.ReplaceLogWriter(t, logOutput) testutil.ReplaceLogLevel(t, log.DEBUG) - d := NewForTest(&Config{ParentalEnabled: true}, nil) - defer d.Close() - d.checkMatch(t, "pornhub.com") - assert.Contains(t, logOutput.String(), "Parental lookup for pornhub.com") - d.checkMatch(t, "www.pornhub.com") + d := newForTest(&Config{ParentalEnabled: true}, nil) + t.Cleanup(d.Close) + matching := "pornhub.com" + d.parentalUpstream = &testSbUpstream{ + hostname: matching, + block: true, + } + + d.checkMatch(t, matching) + assert.Contains(t, logOutput.String(), "Parental lookup for "+matching) + d.checkMatch(t, "www."+matching) d.checkMatchEmpty(t, "www.yandex.ru") d.checkMatchEmpty(t, "yandex.ru") d.checkMatchEmpty(t, "api.jquery.com") // test cached result d.parentalServer = "127.0.0.1" - d.checkMatch(t, "pornhub.com") + d.checkMatch(t, matching) d.checkMatchEmpty(t, "yandex.ru") d.parentalServer = defaultParentalServer } -// FILTERING - -const nl = "\n" - -const ( - blockingRules = `||example.org^` + nl - allowlistRules = `||example.org^` + nl + `@@||test.example.org` + nl - importantRules = `@@||example.org^` + nl + `||test.example.org^$important` + nl - regexRules = `/example\.org/` + nl + `@@||test.example.org^` + nl - maskRules = `test*.example.org^` + nl + `exam*.com` + nl - dnstypeRules = `||example.org^$dnstype=AAAA` + nl + `@@||test.example.org^` + nl -) - -var tests = []struct { - testname string - rules string - hostname string - isFiltered bool - reason Reason - dnsType uint16 -}{ - {"sanity", "||doubleclick.net^", "www.doubleclick.net", true, FilteredBlockList, dns.TypeA}, - {"sanity", "||doubleclick.net^", "nodoubleclick.net", false, NotFilteredNotFound, dns.TypeA}, - {"sanity", "||doubleclick.net^", "doubleclick.net.ru", false, NotFilteredNotFound, dns.TypeA}, - {"sanity", "||doubleclick.net^", "wmconvirus.narod.ru", false, NotFilteredNotFound, dns.TypeA}, - - {"blocking", blockingRules, "example.org", true, FilteredBlockList, dns.TypeA}, - {"blocking", blockingRules, "test.example.org", true, FilteredBlockList, dns.TypeA}, - {"blocking", blockingRules, "test.test.example.org", true, FilteredBlockList, dns.TypeA}, - {"blocking", blockingRules, "testexample.org", false, NotFilteredNotFound, dns.TypeA}, - {"blocking", blockingRules, "onemoreexample.org", false, NotFilteredNotFound, dns.TypeA}, - - {"allowlist", allowlistRules, "example.org", true, FilteredBlockList, dns.TypeA}, - {"allowlist", allowlistRules, "test.example.org", false, NotFilteredAllowList, dns.TypeA}, - {"allowlist", allowlistRules, "test.test.example.org", false, NotFilteredAllowList, dns.TypeA}, - {"allowlist", allowlistRules, "testexample.org", false, NotFilteredNotFound, dns.TypeA}, - {"allowlist", allowlistRules, "onemoreexample.org", false, NotFilteredNotFound, dns.TypeA}, - - {"important", importantRules, "example.org", false, NotFilteredAllowList, dns.TypeA}, - {"important", importantRules, "test.example.org", true, FilteredBlockList, dns.TypeA}, - {"important", importantRules, "test.test.example.org", true, FilteredBlockList, dns.TypeA}, - {"important", importantRules, "testexample.org", false, NotFilteredNotFound, dns.TypeA}, - {"important", importantRules, "onemoreexample.org", false, NotFilteredNotFound, dns.TypeA}, - - {"regex", regexRules, "example.org", true, FilteredBlockList, dns.TypeA}, - {"regex", regexRules, "test.example.org", false, NotFilteredAllowList, dns.TypeA}, - {"regex", regexRules, "test.test.example.org", false, NotFilteredAllowList, dns.TypeA}, - {"regex", regexRules, "testexample.org", true, FilteredBlockList, dns.TypeA}, - {"regex", regexRules, "onemoreexample.org", true, FilteredBlockList, dns.TypeA}, - - {"mask", maskRules, "test.example.org", true, FilteredBlockList, dns.TypeA}, - {"mask", maskRules, "test2.example.org", true, FilteredBlockList, dns.TypeA}, - {"mask", maskRules, "example.com", true, FilteredBlockList, dns.TypeA}, - {"mask", maskRules, "exampleeee.com", true, FilteredBlockList, dns.TypeA}, - {"mask", maskRules, "onemoreexamsite.com", true, FilteredBlockList, dns.TypeA}, - {"mask", maskRules, "example.org", false, NotFilteredNotFound, dns.TypeA}, - {"mask", maskRules, "testexample.org", false, NotFilteredNotFound, dns.TypeA}, - {"mask", maskRules, "example.co.uk", false, NotFilteredNotFound, dns.TypeA}, - - {"dnstype", dnstypeRules, "onemoreexample.org", false, NotFilteredNotFound, dns.TypeA}, - {"dnstype", dnstypeRules, "example.org", false, NotFilteredNotFound, dns.TypeA}, - {"dnstype", dnstypeRules, "example.org", true, FilteredBlockList, dns.TypeAAAA}, - {"dnstype", dnstypeRules, "test.example.org", false, NotFilteredAllowList, dns.TypeA}, - {"dnstype", dnstypeRules, "test.example.org", false, NotFilteredAllowList, dns.TypeAAAA}, -} +// Filtering. func TestMatching(t *testing.T) { - for _, test := range tests { - t.Run(fmt.Sprintf("%s-%s", test.testname, test.hostname), func(t *testing.T) { - filters := []Filter{{ - ID: 0, Data: []byte(test.rules), - }} - d := NewForTest(nil, filters) - defer d.Close() + const nl = "\n" + const ( + blockingRules = `||example.org^` + nl + allowlistRules = `||example.org^` + nl + `@@||test.example.org` + nl + importantRules = `@@||example.org^` + nl + `||test.example.org^$important` + nl + regexRules = `/example\.org/` + nl + `@@||test.example.org^` + nl + maskRules = `test*.example.org^` + nl + `exam*.com` + nl + dnstypeRules = `||example.org^$dnstype=AAAA` + nl + `@@||test.example.org^` + nl + ) + testCases := []struct { + name string + rules string + host string + wantIsFiltered bool + wantReason Reason + wantDNSType uint16 + }{{ + name: "sanity", + rules: "||doubleclick.net^", + host: "www.doubleclick.net", + wantIsFiltered: true, + wantReason: FilteredBlockList, + wantDNSType: dns.TypeA, + }, { + name: "sanity", + rules: "||doubleclick.net^", + host: "nodoubleclick.net", + wantIsFiltered: false, + wantReason: NotFilteredNotFound, + wantDNSType: dns.TypeA, + }, { + name: "sanity", + rules: "||doubleclick.net^", + host: "doubleclick.net.ru", + wantIsFiltered: false, + wantReason: NotFilteredNotFound, + wantDNSType: dns.TypeA, + }, { + name: "sanity", + rules: "||doubleclick.net^", + host: "wmconvirus.narod.ru", + wantIsFiltered: false, + wantReason: NotFilteredNotFound, + wantDNSType: dns.TypeA, + }, { + name: "blocking", + rules: blockingRules, + host: "example.org", + wantIsFiltered: true, + wantReason: FilteredBlockList, + wantDNSType: dns.TypeA, + }, { + name: "blocking", + rules: blockingRules, + host: "test.example.org", + wantIsFiltered: true, + wantReason: FilteredBlockList, + wantDNSType: dns.TypeA, + }, { + name: "blocking", + rules: blockingRules, + host: "test.test.example.org", + wantIsFiltered: true, + wantReason: FilteredBlockList, + wantDNSType: dns.TypeA, + }, { + name: "blocking", + rules: blockingRules, + host: "testexample.org", + wantIsFiltered: false, + wantReason: NotFilteredNotFound, + wantDNSType: dns.TypeA, + }, { + name: "blocking", + rules: blockingRules, + host: "onemoreexample.org", + wantIsFiltered: false, + wantReason: NotFilteredNotFound, + wantDNSType: dns.TypeA, + }, { + name: "allowlist", + rules: allowlistRules, + host: "example.org", + wantIsFiltered: true, + wantReason: FilteredBlockList, + wantDNSType: dns.TypeA, + }, { + name: "allowlist", + rules: allowlistRules, + host: "test.example.org", + wantIsFiltered: false, + wantReason: NotFilteredAllowList, + wantDNSType: dns.TypeA, + }, { + name: "allowlist", + rules: allowlistRules, + host: "test.test.example.org", + wantIsFiltered: false, + wantReason: NotFilteredAllowList, + wantDNSType: dns.TypeA, + }, { + name: "allowlist", + rules: allowlistRules, + host: "testexample.org", + wantIsFiltered: false, + wantReason: NotFilteredNotFound, + wantDNSType: dns.TypeA, + }, { + name: "allowlist", + rules: allowlistRules, + host: "onemoreexample.org", + wantIsFiltered: false, + wantReason: NotFilteredNotFound, + wantDNSType: dns.TypeA, + }, { + name: "important", + rules: importantRules, + host: "example.org", + wantIsFiltered: false, + wantReason: NotFilteredAllowList, + wantDNSType: dns.TypeA, + }, { + name: "important", + rules: importantRules, + host: "test.example.org", + wantIsFiltered: true, + wantReason: FilteredBlockList, + wantDNSType: dns.TypeA, + }, { + name: "important", + rules: importantRules, + host: "test.test.example.org", + wantIsFiltered: true, + wantReason: FilteredBlockList, + wantDNSType: dns.TypeA, + }, { + name: "important", + rules: importantRules, + host: "testexample.org", + wantIsFiltered: false, + wantReason: NotFilteredNotFound, + wantDNSType: dns.TypeA, + }, { + name: "important", + rules: importantRules, + host: "onemoreexample.org", + wantIsFiltered: false, + wantReason: NotFilteredNotFound, + wantDNSType: dns.TypeA, + }, { + name: "regex", + rules: regexRules, + host: "example.org", + wantIsFiltered: true, + wantReason: FilteredBlockList, + wantDNSType: dns.TypeA, + }, { + name: "regex", + rules: regexRules, + host: "test.example.org", + wantIsFiltered: false, + wantReason: NotFilteredAllowList, + wantDNSType: dns.TypeA, + }, { + name: "regex", + rules: regexRules, + host: "test.test.example.org", + wantIsFiltered: false, + wantReason: NotFilteredAllowList, + wantDNSType: dns.TypeA, + }, { + name: "regex", + rules: regexRules, + host: "testexample.org", + wantIsFiltered: true, + wantReason: FilteredBlockList, + wantDNSType: dns.TypeA, + }, { + name: "regex", + rules: regexRules, + host: "onemoreexample.org", + wantIsFiltered: true, + wantReason: FilteredBlockList, + wantDNSType: dns.TypeA, + }, { + name: "mask", + rules: maskRules, + host: "test.example.org", + wantIsFiltered: true, + wantReason: FilteredBlockList, + wantDNSType: dns.TypeA, + }, { + name: "mask", + rules: maskRules, + host: "test2.example.org", + wantIsFiltered: true, + wantReason: FilteredBlockList, + wantDNSType: dns.TypeA, + }, { + name: "mask", + rules: maskRules, + host: "example.com", + wantIsFiltered: true, + wantReason: FilteredBlockList, + wantDNSType: dns.TypeA, + }, { + name: "mask", + rules: maskRules, + host: "exampleeee.com", + wantIsFiltered: true, + wantReason: FilteredBlockList, + wantDNSType: dns.TypeA, + }, { + name: "mask", + rules: maskRules, + host: "onemoreexamsite.com", + wantIsFiltered: true, + wantReason: FilteredBlockList, + wantDNSType: dns.TypeA, + }, { + name: "mask", + rules: maskRules, + host: "example.org", + wantIsFiltered: false, + wantReason: NotFilteredNotFound, + wantDNSType: dns.TypeA, + }, { + name: "mask", + rules: maskRules, + host: "testexample.org", + wantIsFiltered: false, + wantReason: NotFilteredNotFound, + wantDNSType: dns.TypeA, + }, { + name: "mask", + rules: maskRules, + host: "example.co.uk", + wantIsFiltered: false, + wantReason: NotFilteredNotFound, + wantDNSType: dns.TypeA, + }, { + name: "dnstype", + rules: dnstypeRules, + host: "onemoreexample.org", + wantIsFiltered: false, + wantReason: NotFilteredNotFound, + wantDNSType: dns.TypeA, + }, { + name: "dnstype", + rules: dnstypeRules, + host: "example.org", + wantIsFiltered: false, + wantReason: NotFilteredNotFound, + wantDNSType: dns.TypeA, + }, { + name: "dnstype", + rules: dnstypeRules, + host: "example.org", + wantIsFiltered: true, + wantReason: FilteredBlockList, + wantDNSType: dns.TypeAAAA, + }, { + name: "dnstype", + rules: dnstypeRules, + host: "test.example.org", + wantIsFiltered: false, + wantReason: NotFilteredAllowList, + wantDNSType: dns.TypeA, + }, { + name: "dnstype", + rules: dnstypeRules, + host: "test.example.org", + wantIsFiltered: false, + wantReason: NotFilteredAllowList, + wantDNSType: dns.TypeAAAA, + }} + for _, tc := range testCases { + t.Run(fmt.Sprintf("%s-%s", tc.name, tc.host), func(t *testing.T) { + filters := []Filter{{ID: 0, Data: []byte(tc.rules)}} + d := newForTest(nil, filters) + t.Cleanup(d.Close) - res, err := d.CheckHost(test.hostname, test.dnsType, &setts) - if err != nil { - t.Errorf("Error while matching host %s: %s", test.hostname, err) - } - if res.IsFiltered != test.isFiltered { - t.Errorf("Hostname %s has wrong result (%v must be %v)", test.hostname, res.IsFiltered, test.isFiltered) - } - if res.Reason != test.reason { - t.Errorf("Hostname %s has wrong reason (%v must be %v)", test.hostname, res.Reason.String(), test.reason.String()) - } + res, err := d.CheckHost(tc.host, tc.wantDNSType, &setts) + assert.Nilf(t, err, "Error while matching host %s: %s", tc.host, err) + assert.Equalf(t, tc.wantIsFiltered, res.IsFiltered, "Hostname %s has wrong result (%v must be %v)", tc.host, res.IsFiltered, tc.wantIsFiltered) + assert.Equalf(t, tc.wantReason, res.Reason, "Hostname %s has wrong reason (%v must be %v)", tc.host, res.Reason, tc.wantReason) }) } } @@ -461,11 +682,11 @@ func TestWhitelist(t *testing.T) { whiteFilters := []Filter{{ ID: 0, Data: []byte(whiteRules), }} - d := NewForTest(nil, filters) + d := newForTest(nil, filters) d.SetFilters(filters, whiteFilters, false) - defer d.Close() + t.Cleanup(d.Close) - // matched by white filter + // Matched by white filter. res, err := d.CheckHost("host1", dns.TypeA, &setts) assert.Nil(t, err) assert.False(t, res.IsFiltered) @@ -474,7 +695,7 @@ func TestWhitelist(t *testing.T) { assert.Equal(t, "||host1^", res.Rules[0].Text) } - // not matched by white filter, but matched by block filter + // Not matched by white filter, but matched by block filter. res, err = d.CheckHost("host2", dns.TypeA, &setts) assert.Nil(t, err) assert.True(t, res.IsFiltered) @@ -484,7 +705,7 @@ func TestWhitelist(t *testing.T) { } } -// CLIENT SETTINGS +// Client Settings. func applyClientSettings(setts *RequestFilteringSettings) { setts.FilteringEnabled = false @@ -498,127 +719,131 @@ func applyClientSettings(setts *RequestFilteringSettings) { setts.ServicesRules = append(setts.ServicesRules, s) } -// Check behaviour without any per-client settings, -// then apply per-client settings and check behaviour once again func TestClientSettings(t *testing.T) { - var r Result - filters := []Filter{{ - ID: 0, Data: []byte("||example.org^\n"), + d := newForTest( + &Config{ + ParentalEnabled: true, + SafeBrowsingEnabled: false, + }, + []Filter{{ + ID: 0, Data: []byte("||example.org^\n"), + }}, + ) + t.Cleanup(d.Close) + d.parentalUpstream = &testSbUpstream{ + hostname: "pornhub.com", + block: true, + } + d.safeBrowsingUpstream = &testSbUpstream{ + hostname: "wmconvirus.narod.ru", + block: true, + } + + type testCase struct { + name string + host string + before bool + wantReason Reason + } + testCases := []testCase{{ + name: "filters", + host: "example.org", + before: true, + wantReason: FilteredBlockList, + }, { + name: "parental", + host: "pornhub.com", + before: true, + wantReason: FilteredParental, + }, { + name: "safebrowsing", + host: "wmconvirus.narod.ru", + before: false, + wantReason: FilteredSafeBrowsing, + }, { + name: "additional_rules", + host: "facebook.com", + before: false, + wantReason: FilteredBlockedService, }} - d := NewForTest(&Config{ParentalEnabled: true, SafeBrowsingEnabled: false}, filters) - defer d.Close() - // no client settings: - - // blocked by filters - r, _ = d.CheckHost("example.org", dns.TypeA, &setts) - if !r.IsFiltered || r.Reason != FilteredBlockList { - t.Fatalf("CheckHost FilteredBlockList") + makeTester := func(tc testCase, before bool) func(t *testing.T) { + return func(t *testing.T) { + r, _ := d.CheckHost(tc.host, dns.TypeA, &setts) + if before { + assert.True(t, r.IsFiltered) + assert.Equal(t, tc.wantReason, r.Reason) + } else { + assert.False(t, r.IsFiltered) + } + } } - // blocked by parental - r, _ = d.CheckHost("pornhub.com", dns.TypeA, &setts) - if !r.IsFiltered || r.Reason != FilteredParental { - t.Fatalf("CheckHost FilteredParental") + // Check behaviour without any per-client settings, then apply per-client + // settings and check behaviour once again. + for _, tc := range testCases { + t.Run(tc.name, makeTester(tc, tc.before)) } - // safesearch is disabled - r, _ = d.CheckHost("wmconvirus.narod.ru", dns.TypeA, &setts) - if r.IsFiltered { - t.Fatalf("CheckHost safesearch") - } - - // not blocked - r, _ = d.CheckHost("facebook.com", dns.TypeA, &setts) - assert.False(t, r.IsFiltered) - - // override client settings: applyClientSettings(&setts) - // override filtering settings - r, _ = d.CheckHost("example.org", dns.TypeA, &setts) - if r.IsFiltered { - t.Fatalf("CheckHost") + for _, tc := range testCases { + t.Run(tc.name, makeTester(tc, !tc.before)) } - - // override parental settings (force disable parental) - r, _ = d.CheckHost("pornhub.com", dns.TypeA, &setts) - if r.IsFiltered { - t.Fatalf("CheckHost") - } - - // override safesearch settings (force enable safesearch) - r, _ = d.CheckHost("wmconvirus.narod.ru", dns.TypeA, &setts) - if !r.IsFiltered || r.Reason != FilteredSafeBrowsing { - t.Fatalf("CheckHost FilteredSafeBrowsing") - } - - // blocked by additional rules - r, _ = d.CheckHost("facebook.com", dns.TypeA, &setts) - assert.True(t, r.IsFiltered) - assert.Equal(t, r.Reason, FilteredBlockedService) } -// BENCHMARKS +// Benchmarks. func BenchmarkSafeBrowsing(b *testing.B) { - d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil) - defer d.Close() + d := newForTest(&Config{SafeBrowsingEnabled: true}, nil) + b.Cleanup(d.Close) + blocked := "wmconvirus.narod.ru" + d.safeBrowsingUpstream = &testSbUpstream{ + hostname: blocked, + block: true, + } for n := 0; n < b.N; n++ { - hostname := "wmconvirus.narod.ru" - res, err := d.CheckHost(hostname, dns.TypeA, &setts) - if err != nil { - b.Errorf("Error while matching host %s: %s", hostname, err) - } - if !res.IsFiltered { - b.Errorf("Expected hostname %s to match", hostname) - } + res, err := d.CheckHost(blocked, dns.TypeA, &setts) + assert.Nilf(b, err, "Error while matching host %s: %s", blocked, err) + assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked) } } func BenchmarkSafeBrowsingParallel(b *testing.B) { - d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil) - defer d.Close() + d := newForTest(&Config{SafeBrowsingEnabled: true}, nil) + b.Cleanup(d.Close) + blocked := "wmconvirus.narod.ru" + d.safeBrowsingUpstream = &testSbUpstream{ + hostname: blocked, + block: true, + } b.RunParallel(func(pb *testing.PB) { for pb.Next() { - hostname := "wmconvirus.narod.ru" - res, err := d.CheckHost(hostname, dns.TypeA, &setts) - if err != nil { - b.Errorf("Error while matching host %s: %s", hostname, err) - } - if !res.IsFiltered { - b.Errorf("Expected hostname %s to match", hostname) - } + res, err := d.CheckHost(blocked, dns.TypeA, &setts) + assert.Nilf(b, err, "Error while matching host %s: %s", blocked, err) + assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked) } }) } func BenchmarkSafeSearch(b *testing.B) { - d := NewForTest(&Config{SafeSearchEnabled: true}, nil) - defer d.Close() + d := newForTest(&Config{SafeSearchEnabled: true}, nil) + b.Cleanup(d.Close) for n := 0; n < b.N; n++ { val, ok := d.SafeSearchDomain("www.google.com") - if !ok { - b.Errorf("Expected safesearch to find result for www.google.com") - } - if val != "forcesafesearch.google.com" { - b.Errorf("Expected safesearch for google.com to be forcesafesearch.google.com") - } + assert.True(b, ok, "Expected safesearch to find result for www.google.com") + assert.Equal(b, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com") } } func BenchmarkSafeSearchParallel(b *testing.B) { - d := NewForTest(&Config{SafeSearchEnabled: true}, nil) - defer d.Close() + d := newForTest(&Config{SafeSearchEnabled: true}, nil) + b.Cleanup(d.Close) b.RunParallel(func(pb *testing.PB) { for pb.Next() { val, ok := d.SafeSearchDomain("www.google.com") - if !ok { - b.Errorf("Expected safesearch to find result for www.google.com") - } - if val != "forcesafesearch.google.com" { - b.Errorf("Expected safesearch for google.com to be forcesafesearch.google.com") - } + assert.True(b, ok, "Expected safesearch to find result for www.google.com") + assert.Equal(b, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com") } }) } diff --git a/internal/dnsfilter/dnsrewrite_test.go b/internal/dnsfilter/dnsrewrite_test.go index 201de44f..c915d920 100644 --- a/internal/dnsfilter/dnsrewrite_test.go +++ b/internal/dnsfilter/dnsrewrite_test.go @@ -45,7 +45,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) { @@||disable-all^$dnsrewrite ` - f := NewForTest(nil, []Filter{{ID: 0, Data: []byte(text)}}) + f := newForTest(nil, []Filter{{ID: 0, Data: []byte(text)}}) setts := &RequestFilteringSettings{ FilteringEnabled: true, } diff --git a/internal/dnsfilter/rewrites_test.go b/internal/dnsfilter/rewrites_test.go index 12a59409..a56d2a48 100644 --- a/internal/dnsfilter/rewrites_test.go +++ b/internal/dnsfilter/rewrites_test.go @@ -9,7 +9,8 @@ import ( ) func TestRewrites(t *testing.T) { - d := DNSFilter{} + d := newForTest(nil, nil) + t.Cleanup(d.Close) // CNAME, A, AAAA d.Rewrites = []RewriteEntry{ {"somecname", "somehost.com", 0, nil}, @@ -104,7 +105,8 @@ func TestRewrites(t *testing.T) { } func TestRewritesLevels(t *testing.T) { - d := DNSFilter{} + d := newForTest(nil, nil) + t.Cleanup(d.Close) // exact host, wildcard L2, wildcard L3 d.Rewrites = []RewriteEntry{ {"host.com", "1.1.1.1", 0, nil}, @@ -133,7 +135,8 @@ func TestRewritesLevels(t *testing.T) { } func TestRewritesExceptionCNAME(t *testing.T) { - d := DNSFilter{} + d := newForTest(nil, nil) + t.Cleanup(d.Close) // wildcard; exception for a sub-domain d.Rewrites = []RewriteEntry{ {"*.host.com", "2.2.2.2", 0, nil}, @@ -153,7 +156,8 @@ func TestRewritesExceptionCNAME(t *testing.T) { } func TestRewritesExceptionWC(t *testing.T) { - d := DNSFilter{} + d := newForTest(nil, nil) + t.Cleanup(d.Close) // wildcard; exception for a sub-wildcard d.Rewrites = []RewriteEntry{ {"*.host.com", "2.2.2.2", 0, nil}, @@ -173,7 +177,8 @@ func TestRewritesExceptionWC(t *testing.T) { } func TestRewritesExceptionIP(t *testing.T) { - d := DNSFilter{} + d := newForTest(nil, nil) + t.Cleanup(d.Close) // exception for AAAA record d.Rewrites = []RewriteEntry{ {"host.com", "1.2.3.4", 0, nil}, diff --git a/internal/dnsfilter/safebrowsing_test.go b/internal/dnsfilter/safebrowsing_test.go index 21ea2277..a1f627c4 100644 --- a/internal/dnsfilter/safebrowsing_test.go +++ b/internal/dnsfilter/safebrowsing_test.go @@ -4,6 +4,7 @@ import ( "crypto/sha256" "encoding/hex" "strings" + "sync" "testing" "github.com/AdguardTeam/AdGuardHome/internal/agherr" @@ -121,8 +122,8 @@ func (teu *testErrUpstream) Address() string { } func TestSBPC_checkErrorUpstream(t *testing.T) { - d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil) - defer d.Close() + d := newForTest(&Config{SafeBrowsingEnabled: true}, nil) + t.Cleanup(d.Close) ups := &testErrUpstream{} @@ -142,11 +143,14 @@ type testSbUpstream struct { hostname string block bool requestsCount int + counterLock sync.RWMutex } // Exchange returns a message depending on the upstream settings (hostname, block) func (u *testSbUpstream) Exchange(r *dns.Msg) (*dns.Msg, error) { + u.counterLock.Lock() u.requestsCount++ + u.counterLock.Unlock() hash := sha256.Sum256([]byte(u.hostname)) prefix := hash[0:2] @@ -175,8 +179,8 @@ func (u *testSbUpstream) Address() string { } func TestSBPC_sbValidResponse(t *testing.T) { - d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil) - defer d.Close() + d := newForTest(&Config{SafeBrowsingEnabled: true}, nil) + t.Cleanup(d.Close) ups := &testSbUpstream{} d.safeBrowsingUpstream = ups @@ -213,8 +217,8 @@ func TestSBPC_sbValidResponse(t *testing.T) { } func TestSBPC_pcBlockedResponse(t *testing.T) { - d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil) - defer d.Close() + d := newForTest(&Config{SafeBrowsingEnabled: true}, nil) + t.Cleanup(d.Close) ups := &testSbUpstream{} d.safeBrowsingUpstream = ups diff --git a/internal/dnsfilter/safesearch.go b/internal/dnsfilter/safesearch.go index 958962db..bc72cc9e 100644 --- a/internal/dnsfilter/safesearch.go +++ b/internal/dnsfilter/safesearch.go @@ -2,6 +2,7 @@ package dnsfilter import ( "bytes" + "context" "encoding/binary" "encoding/gob" "encoding/json" @@ -101,15 +102,14 @@ func (d *DNSFilter) checkSafeSearch(host string) (Result, error) { return res, nil } - // TODO this address should be resolved with upstream that was configured in dnsforward - ips, err := net.LookupIP(safeHost) + ipAddrs, err := d.resolver.LookupIPAddr(context.Background(), safeHost) if err != nil { log.Tracef("SafeSearchDomain for %s was found but failed to lookup for %s cause %s", host, safeHost, err) return Result{}, err } - for _, ip := range ips { - if ipv4 := ip.To4(); ipv4 != nil { + for _, ipAddr := range ipAddrs { + if ipv4 := ipAddr.IP.To4(); ipv4 != nil { res.Rules[0].IP = ipv4 l := d.setCacheResult(gctx.safeSearchCache, host, res)