diff --git a/CHANGELOG.md b/CHANGELOG.md index 3b1ad28e..66f3b1f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,8 @@ and this project adheres to ### Changed +- When /etc/hosts-type rules have several IPs for one host, all IPs are now + returned instead of only the first one ([#1381]). - The setting `rlimit_nofile` is now in the `os` block of the configuration file, together with the new `group` and `user` settings ([#2763]). - Permissions on filter files are now `0o644` instead of `0o600` ([#3198]). @@ -63,6 +65,7 @@ released by then. - Go 1.15 support. +[#1381]: https://github.com/AdguardTeam/AdGuardHome/issues/1381 [#1691]: https://github.com/AdguardTeam/AdGuardHome/issues/1691 [#2280]: https://github.com/AdguardTeam/AdGuardHome/issues/2280 [#2439]: https://github.com/AdguardTeam/AdGuardHome/issues/2439 diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index 905e5cec..3b28c236 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -21,6 +21,30 @@ import ( "github.com/ameshkov/dnscrypt/v2" ) +// BlockingMode is an enum of all allowed blocking modes. +type BlockingMode string + +// Allowed blocking modes. +const ( + // BlockingModeCustomIP means respond with a custom IP address. + BlockingModeCustomIP BlockingMode = "custom_ip" + + // BlockingModeDefault is the same as BlockingModeNullIP for + // Adblock-style rules, but responds with the IP address specified in + // the rule when blocked by an `/etc/hosts`-style rule. + BlockingModeDefault BlockingMode = "default" + + // BlockingModeNullIP means respond with a zero IP address: "0.0.0.0" + // for A requests and "::" for AAAA ones. + BlockingModeNullIP BlockingMode = "null_ip" + + // BlockingModeNXDOMAIN means respond with the NXDOMAIN code. + BlockingModeNXDOMAIN BlockingMode = "nxdomain" + + // BlockingModeREFUSED means respond with the REFUSED code. + BlockingModeREFUSED BlockingMode = "refused" +) + // FilteringConfig represents the DNS filtering configuration of AdGuard Home // The zero FilteringConfig is empty and ready for use. type FilteringConfig struct { @@ -38,11 +62,11 @@ type FilteringConfig struct { // Protection configuration // -- - ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of filtering features - BlockingMode string `yaml:"blocking_mode"` // mode how to answer filtered requests - BlockingIPv4 net.IP `yaml:"blocking_ipv4"` // IP address to be returned for a blocked A request - BlockingIPv6 net.IP `yaml:"blocking_ipv6"` // IP address to be returned for a blocked AAAA request - BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` // if 0, then default is used (3600) + ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of filtering features + BlockingMode BlockingMode `yaml:"blocking_mode"` // mode how to answer filtered requests + BlockingIPv4 net.IP `yaml:"blocking_ipv4"` // IP address to be returned for a blocked A request + BlockingIPv6 net.IP `yaml:"blocking_ipv6"` // IP address to be returned for a blocked AAAA request + BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` // if 0, then default is used (3600) // IP (or domain name) which is used to respond to DNS requests blocked by parental control or safe-browsing ParentalBlockHost string `yaml:"parental_block_host"` diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index 9bc4f53e..a7ed4fc4 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -520,6 +520,7 @@ func TestBlockedRequest(t *testing.T) { TCPListenAddrs: []*net.TCPAddr{{}}, FilteringConfig: FilteringConfig{ ProtectionEnabled: true, + BlockingMode: BlockingModeDefault, }, } s := createTestServer(t, &filtering.Config{}, forwardConf, nil) @@ -622,6 +623,7 @@ func TestBlockCNAME(t *testing.T) { TCPListenAddrs: []*net.TCPAddr{{}}, FilteringConfig: FilteringConfig{ ProtectionEnabled: true, + BlockingMode: BlockingModeDefault, }, } s := createTestServer(t, &filtering.Config{}, forwardConf, nil) @@ -724,7 +726,7 @@ func TestNullBlockedRequest(t *testing.T) { TCPListenAddrs: []*net.TCPAddr{{}}, FilteringConfig: FilteringConfig{ ProtectionEnabled: true, - BlockingMode: "null_ip", + BlockingMode: BlockingModeNullIP, }, } s := createTestServer(t, &filtering.Config{}, forwardConf, nil) @@ -777,7 +779,7 @@ func TestBlockedCustomIP(t *testing.T) { TCPListenAddrs: []*net.TCPAddr{{}}, FilteringConfig: FilteringConfig{ ProtectionEnabled: true, - BlockingMode: "custom_ip", + BlockingMode: BlockingModeCustomIP, BlockingIPv4: nil, UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"}, }, @@ -827,6 +829,7 @@ func TestBlockedByHosts(t *testing.T) { TCPListenAddrs: []*net.TCPAddr{{}}, FilteringConfig: FilteringConfig{ ProtectionEnabled: true, + BlockingMode: BlockingModeDefault, }, } diff --git a/internal/dnsforward/http.go b/internal/dnsforward/http.go index 06baa9a9..3e18f1ad 100644 --- a/internal/dnsforward/http.go +++ b/internal/dnsforward/http.go @@ -29,21 +29,21 @@ type dnsConfig struct { UpstreamsFile *string `json:"upstream_dns_file"` Bootstraps *[]string `json:"bootstrap_dns"` - ProtectionEnabled *bool `json:"protection_enabled"` - RateLimit *uint32 `json:"ratelimit"` - BlockingMode *string `json:"blocking_mode"` - BlockingIPv4 net.IP `json:"blocking_ipv4"` - BlockingIPv6 net.IP `json:"blocking_ipv6"` - EDNSCSEnabled *bool `json:"edns_cs_enabled"` - DNSSECEnabled *bool `json:"dnssec_enabled"` - DisableIPv6 *bool `json:"disable_ipv6"` - UpstreamMode *string `json:"upstream_mode"` - CacheSize *uint32 `json:"cache_size"` - CacheMinTTL *uint32 `json:"cache_ttl_min"` - CacheMaxTTL *uint32 `json:"cache_ttl_max"` - ResolveClients *bool `json:"resolve_clients"` - UsePrivateRDNS *bool `json:"use_private_ptr_resolvers"` - LocalPTRUpstreams *[]string `json:"local_ptr_upstreams"` + ProtectionEnabled *bool `json:"protection_enabled"` + RateLimit *uint32 `json:"ratelimit"` + BlockingMode *BlockingMode `json:"blocking_mode"` + BlockingIPv4 net.IP `json:"blocking_ipv4"` + BlockingIPv6 net.IP `json:"blocking_ipv6"` + EDNSCSEnabled *bool `json:"edns_cs_enabled"` + DNSSECEnabled *bool `json:"dnssec_enabled"` + DisableIPv6 *bool `json:"disable_ipv6"` + UpstreamMode *string `json:"upstream_mode"` + CacheSize *uint32 `json:"cache_size"` + CacheMinTTL *uint32 `json:"cache_ttl_min"` + CacheMaxTTL *uint32 `json:"cache_ttl_max"` + ResolveClients *bool `json:"resolve_clients"` + UsePrivateRDNS *bool `json:"use_private_ptr_resolvers"` + LocalPTRUpstreams *[]string `json:"local_ptr_upstreams"` } func (s *Server) getDNSConfig() dnsConfig { @@ -126,27 +126,17 @@ func (req *dnsConfig) checkBlockingMode() bool { return true } - bm := *req.BlockingMode - if bm == "custom_ip" { - if req.BlockingIPv4.To4() == nil { - return false - } - - return req.BlockingIPv6 != nil + switch bm := *req.BlockingMode; bm { + case BlockingModeDefault, + BlockingModeREFUSED, + BlockingModeNXDOMAIN, + BlockingModeNullIP: + return true + case BlockingModeCustomIP: + return req.BlockingIPv4.To4() != nil && req.BlockingIPv6 != nil + default: + return false } - - for _, valid := range []string{ - "default", - "refused", - "nxdomain", - "null_ip", - } { - if bm == valid { - return true - } - } - - return false } func (req *dnsConfig) checkUpstreamsMode() bool { diff --git a/internal/dnsforward/msg.go b/internal/dnsforward/msg.go index 617604f3..42307619 100644 --- a/internal/dnsforward/msg.go +++ b/internal/dnsforward/msg.go @@ -26,17 +26,29 @@ func (s *Server) makeResponse(req *dns.Msg) (resp *dns.Msg) { return resp } +// ipsFromRules extracts non-IP addresses from the filtering result rules. +func ipsFromRules(resRules []*filtering.ResultRule) (ips []net.IP) { + for _, r := range resRules { + if r.IP != nil { + ips = append(ips, r.IP) + } + } + + return ips +} + // genDNSFilterMessage generates a DNS message corresponding to the filtering result func (s *Server) genDNSFilterMessage(d *proxy.DNSContext, result *filtering.Result) *dns.Msg { m := d.Req if m.Question[0].Qtype != dns.TypeA && m.Question[0].Qtype != dns.TypeAAAA { - if s.conf.BlockingMode == "null_ip" { + if s.conf.BlockingMode == BlockingModeNullIP { return s.makeResponse(m) } return s.genNXDomain(m) } + ips := ipsFromRules(result.Rules) switch result.Reason { case filtering.FilteredSafeBrowsing: return s.genBlockedHost(m, s.conf.SafeBrowsingBlockHost, d) @@ -46,42 +58,45 @@ func (s *Server) genDNSFilterMessage(d *proxy.DNSContext, result *filtering.Resu // If the query was filtered by "Safe search", filtering also must return // the IP address that must be used in response. // In this case regardless of the filtering method, we should return it - if result.Reason == filtering.FilteredSafeSearch && - len(result.Rules) > 0 && - result.Rules[0].IP != nil { - return s.genResponseWithIP(m, result.Rules[0].IP) + if result.Reason == filtering.FilteredSafeSearch && len(ips) > 0 { + return s.genResponseWithIPs(m, ips) } - if s.conf.BlockingMode == "null_ip" { - // it means that we should return 0.0.0.0 or :: for any blocked request - return s.makeResponseNullIP(m) - } else if s.conf.BlockingMode == "custom_ip" { - // means that we should return custom IP for any blocked request - + switch s.conf.BlockingMode { + case BlockingModeCustomIP: switch m.Question[0].Qtype { case dns.TypeA: return s.genARecord(m, s.conf.BlockingIPv4) case dns.TypeAAAA: return s.genAAAARecord(m, s.conf.BlockingIPv6) + default: + // Generally shouldn't happen, since the types + // are checked above. + log.Error( + "dns: invalid msg type %d for blocking mode %s", + m.Question[0].Qtype, + s.conf.BlockingMode, + ) + + return s.makeResponse(m) + } + case BlockingModeDefault: + if len(ips) > 0 { + return s.genResponseWithIPs(m, ips) } - } else if s.conf.BlockingMode == "nxdomain" { - // means that we should return NXDOMAIN for any blocked request + return s.makeResponseNullIP(m) + case BlockingModeNullIP: + return s.makeResponseNullIP(m) + case BlockingModeNXDOMAIN: return s.genNXDomain(m) - } else if s.conf.BlockingMode == "refused" { - // means that we should return NXDOMAIN for any blocked request - + case BlockingModeREFUSED: return s.makeResponseREFUSED(m) - } + default: + log.Error("dns: invalid blocking mode %q", s.conf.BlockingMode) - // Default blocking mode - // If there's an IP specified in the rule, return it - // For host-type rules, return null IP - if len(result.Rules) > 0 && result.Rules[0].IP != nil { - return s.genResponseWithIP(m, result.Rules[0].IP) + return s.makeResponse(m) } - - return s.makeResponseNullIP(m) } } @@ -166,35 +181,60 @@ func (s *Server) genAnswerTXT(req *dns.Msg, strs []string) (ans *dns.TXT) { } } -// generate DNS response message with an IP address -func (s *Server) genResponseWithIP(req *dns.Msg, ip net.IP) *dns.Msg { - if req.Question[0].Qtype == dns.TypeA && ip.To4() != nil { - return s.genARecord(req, ip.To4()) - } else if req.Question[0].Qtype == dns.TypeAAAA && - len(ip) == net.IPv6len && ip.To4() == nil { - return s.genAAAARecord(req, ip) +// genResponseWithIPs generates a DNS response message with the provided IP +// addresses and an appropriate resource record type. If any of the IPs cannot +// be converted to the correct protocol, genResponseWithIPs returns an empty +// response. +func (s *Server) genResponseWithIPs(req *dns.Msg, ips []net.IP) (resp *dns.Msg) { + var ans []dns.RR + switch req.Question[0].Qtype { + case dns.TypeA: + for _, ip := range ips { + if ip4 := ip.To4(); ip4 == nil { + ans = nil + + break + } + + ans = append(ans, s.genAnswerA(req, ip)) + } + case dns.TypeAAAA: + for _, ip := range ips { + ans = append(ans, s.genAnswerAAAA(req, ip.To16())) + } + default: + // Go on and return an empty response. } - // empty response - resp := s.makeResponse(req) + resp = s.makeResponse(req) + resp.Answer = ans + return resp } -// Respond with 0.0.0.0 for A, :: for AAAA, empty response for other types -func (s *Server) makeResponseNullIP(req *dns.Msg) *dns.Msg { - if req.Question[0].Qtype == dns.TypeA { - return s.genARecord(req, []byte{0, 0, 0, 0}) - } else if req.Question[0].Qtype == dns.TypeAAAA { - return s.genAAAARecord(req, net.IPv6zero) +// makeResponseNullIP creates a response with 0.0.0.0 for A requests, :: for +// AAAA requests, and an empty response for other types. +func (s *Server) makeResponseNullIP(req *dns.Msg) (resp *dns.Msg) { + // Respond with the corresponding zero IP type as opposed to simply + // using one or the other in both cases, because the IPv4 zero IP is + // converted to a IPV6-mapped IPv4 address, while the IPv6 zero IP is + // converted into an empty slice instead of the zero IPv4. + switch req.Question[0].Qtype { + case dns.TypeA: + resp = s.genResponseWithIPs(req, []net.IP{{0, 0, 0, 0}}) + case dns.TypeAAAA: + resp = s.genResponseWithIPs(req, []net.IP{net.IPv6zero}) + default: + resp = s.makeResponse(req) } - return s.makeResponse(req) + return resp } func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSContext) *dns.Msg { ip := net.ParseIP(newAddr) if ip != nil { - return s.genResponseWithIP(request, ip) + return s.genResponseWithIPs(request, []net.IP{ip}) } // look up the hostname, TODO: cache diff --git a/internal/filtering/filtering.go b/internal/filtering/filtering.go index 14df5ef5..d1ce6f94 100644 --- a/internal/filtering/filtering.go +++ b/internal/filtering/filtering.go @@ -344,13 +344,13 @@ var gctx dnsFilterContext // ResultRule contains information about applied rules. type ResultRule struct { - // FilterListID is the ID of the rule's filter list. - FilterListID int64 `json:",omitempty"` // Text is the text of the rule. Text string `json:",omitempty"` // IP is the host IP. It is nil unless the rule uses the // /etc/hosts syntax or the reason is FilteredSafeSearch. IP net.IP `json:",omitempty"` + // FilterListID is the ID of the rule's filter list. + FilterListID int64 `json:",omitempty"` } // Result contains the result of a request check. @@ -657,26 +657,43 @@ func (d *DNSFilter) initFiltering(allowFilters, blockFilters []Filter) error { return nil } -// matchHostProcessAllowList processes the allowlist logic of host -// matching. -func (d *DNSFilter) matchHostProcessAllowList(host string, dnsres urlfilter.DNSResult) (res Result, err error) { - var rule rules.Rule - if dnsres.NetworkRule != nil { - rule = dnsres.NetworkRule - } else if len(dnsres.HostRulesV4) > 0 { - rule = dnsres.HostRulesV4[0] - } else if len(dnsres.HostRulesV6) > 0 { - rule = dnsres.HostRulesV6[0] +// hostRules is a helper that converts a slice of host rules into a slice of the +// rules.Rule interface values. +func hostRulesToRules(netRules []*rules.HostRule) (res []rules.Rule) { + if netRules == nil { + return nil } - if rule == nil { + res = make([]rules.Rule, len(netRules)) + for i, nr := range netRules { + res[i] = nr + } + + return res +} + +// matchHostProcessAllowList processes the allowlist logic of host +// matching. +func (d *DNSFilter) matchHostProcessAllowList( + host string, + dnsres urlfilter.DNSResult, +) (res Result, err error) { + var matchedRules []rules.Rule + if dnsres.NetworkRule != nil { + matchedRules = []rules.Rule{dnsres.NetworkRule} + } else if len(dnsres.HostRulesV4) > 0 { + matchedRules = hostRulesToRules(dnsres.HostRulesV4) + } else if len(dnsres.HostRulesV6) > 0 { + matchedRules = hostRulesToRules(dnsres.HostRulesV6) + } + + if len(matchedRules) == 0 { return Result{}, fmt.Errorf("invalid dns result: rules are empty") } - log.Debug("Filtering: found allowlist rule for host %q: %q list_id: %d", - host, rule.Text(), rule.GetFilterListID()) + log.Debug("filtering: allowlist rules for host %q: %+v", host, matchedRules) - return makeResult(rule, NotFilteredAllowList), nil + return makeResult(matchedRules, NotFilteredAllowList), nil } // matchHostProcessDNSResult processes the matched DNS filtering result. @@ -690,21 +707,23 @@ func (d *DNSFilter) matchHostProcessDNSResult( reason = NotFilteredAllowList } - return makeResult(dnsres.NetworkRule, reason) + return makeResult([]rules.Rule{dnsres.NetworkRule}, reason) } if qtype == dns.TypeA && dnsres.HostRulesV4 != nil { - rule := dnsres.HostRulesV4[0] - res = makeResult(rule, FilteredBlockList) - res.Rules[0].IP = rule.IP.To4() + res = makeResult(hostRulesToRules(dnsres.HostRulesV4), FilteredBlockList) + for i, hr := range dnsres.HostRulesV4 { + res.Rules[i].IP = hr.IP.To4() + } return res } if qtype == dns.TypeAAAA && dnsres.HostRulesV6 != nil { - rule := dnsres.HostRulesV6[0] - res = makeResult(rule, FilteredBlockList) - res.Rules[0].IP = rule.IP.To16() + res = makeResult(hostRulesToRules(dnsres.HostRulesV6), FilteredBlockList) + for i, hr := range dnsres.HostRulesV6 { + res.Rules[i].IP = hr.IP.To16() + } return res } @@ -712,17 +731,14 @@ func (d *DNSFilter) matchHostProcessDNSResult( if dnsres.HostRulesV4 != nil || dnsres.HostRulesV6 != nil { // Question type doesn't match the host rules. Return the first // matched host rule, but without an IP address. - var rule rules.Rule + var matchedRules []rules.Rule if dnsres.HostRulesV4 != nil { - rule = dnsres.HostRulesV4[0] + matchedRules = []rules.Rule{dnsres.HostRulesV4[0]} } else if dnsres.HostRulesV6 != nil { - rule = dnsres.HostRulesV6[0] + matchedRules = []rules.Rule{dnsres.HostRulesV6[0]} } - res = makeResult(rule, FilteredBlockList) - res.Rules[0].IP = net.IP{} - - return res + return makeResult(matchedRules, FilteredBlockList) } return Result{} @@ -780,8 +796,7 @@ func (d *DNSFilter) matchHost( } res = d.matchHostProcessDNSResult(qtype, dnsres) - if len(res.Rules) > 0 { - r := res.Rules[0] + for _, r := range res.Rules { log.Debug( "filtering: found rule %q for host %q, filter list id: %d", r.Text, @@ -794,20 +809,20 @@ func (d *DNSFilter) matchHost( } // makeResult returns a properly constructed Result. -func makeResult(rule rules.Rule, reason Reason) Result { - res := Result{ - Reason: reason, - Rules: []*ResultRule{{ - FilterListID: int64(rule.GetFilterListID()), - Text: rule.Text(), - }}, +func makeResult(matchedRules []rules.Rule, reason Reason) (res Result) { + resRules := make([]*ResultRule, len(matchedRules)) + for i, mr := range matchedRules { + resRules[i] = &ResultRule{ + FilterListID: int64(mr.GetFilterListID()), + Text: mr.Text(), + } } - if reason == FilteredBlockList { - res.IsFiltered = true + return Result{ + IsFiltered: reason == FilteredBlockList, + Reason: reason, + Rules: resRules, } - - return res } // InitModule manually initializes blocked services map. diff --git a/internal/filtering/filtering_test.go b/internal/filtering/filtering_test.go index b63e1055..746e9ed0 100644 --- a/internal/filtering/filtering_test.go +++ b/internal/filtering/filtering_test.go @@ -60,29 +60,33 @@ func (d *DNSFilter) checkMatch(t *testing.T, hostname string) { t.Helper() res, err := d.CheckHost(hostname, dns.TypeA, &setts) - require.Nilf(t, err, "Error while matching host %s: %s", hostname, err) - assert.Truef(t, res.IsFiltered, "Expected hostname %s to match", hostname) + require.NoErrorf(t, err, "host %q", hostname) + + assert.Truef(t, res.IsFiltered, "host %q", hostname) } func (d *DNSFilter) checkMatchIP(t *testing.T, hostname, ip string, qtype uint16) { t.Helper() res, err := d.CheckHost(hostname, qtype, &setts) - require.Nilf(t, err, "Error while matching host %s: %s", hostname, err) - assert.Truef(t, res.IsFiltered, "Expected hostname %s to match", hostname) + require.NoErrorf(t, err, "host %q", hostname, err) + require.NotEmpty(t, res.Rules, "host %q", hostname) + + assert.Truef(t, res.IsFiltered, "host %q", hostname) - require.NotEmpty(t, res.Rules, "Expected result to have rules") r := res.Rules[0] require.NotNilf(t, r.IP, "Expected ip %s to match, actual: %v", ip, r.IP) - assert.Equalf(t, ip, r.IP.String(), "Expected ip %s to match, actual: %v", ip, r.IP) + + assert.Equalf(t, ip, r.IP.String(), "host %q", hostname) } func (d *DNSFilter) checkMatchEmpty(t *testing.T, hostname string) { t.Helper() res, err := d.CheckHost(hostname, dns.TypeA, &setts) - require.Nilf(t, err, "Error while matching host %s: %s", hostname, err) - assert.Falsef(t, res.IsFiltered, "Expected hostname %s to not match", hostname) + require.NoErrorf(t, err, "host %q", hostname) + + assert.Falsef(t, res.IsFiltered, "host %q", hostname) } func TestEtcHostsMatching(t *testing.T) { @@ -112,10 +116,12 @@ func TestEtcHostsMatching(t *testing.T) { // Empty IPv6. res, err := d.CheckHost("block.com", dns.TypeAAAA, &setts) - require.Nil(t, err) + require.NoError(t, err) + assert.True(t, res.IsFiltered) require.Len(t, res.Rules, 1) + assert.Equal(t, "0.0.0.0 block.com", res.Rules[0].Text) assert.Empty(t, res.Rules[0].IP) @@ -124,27 +130,34 @@ func TestEtcHostsMatching(t *testing.T) { // Empty IPv4. res, err = d.CheckHost("ipv6.com", dns.TypeA, &setts) - require.Nil(t, err) + require.NoError(t, err) + assert.True(t, res.IsFiltered) require.Len(t, res.Rules, 1) + assert.Equal(t, "::1 ipv6.com", res.Rules[0].Text) assert.Empty(t, res.Rules[0].IP) - // Two IPv4, the first one returned. + // Two IPv4, both must be returned. res, err = d.CheckHost("host2", dns.TypeA, &setts) - require.Nil(t, err) + require.NoError(t, err) + assert.True(t, res.IsFiltered) - require.Len(t, res.Rules, 1) + require.Len(t, res.Rules, 2) + assert.Equal(t, res.Rules[0].IP, net.IP{0, 0, 0, 1}) + assert.Equal(t, res.Rules[1].IP, net.IP{0, 0, 0, 2}) // One IPv6 address. res, err = d.CheckHost("host2", dns.TypeAAAA, &setts) - require.Nil(t, err) + require.NoError(t, err) + assert.True(t, res.IsFiltered) require.Len(t, res.Rules, 1) + assert.Equal(t, res.Rules[0].IP, net.IPv6loopback) } @@ -205,8 +218,9 @@ func TestSafeSearch(t *testing.T) { d := newForTest(&Config{SafeSearchEnabled: true}, nil) t.Cleanup(d.Close) val, ok := d.SafeSearchDomain("www.google.com") - require.True(t, ok, "Expected safesearch to find result for www.google.com") - assert.Equal(t, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com") + require.True(t, ok) + + assert.Equal(t, "forcesafesearch.google.com", val) } func TestCheckHostSafeSearchYandex(t *testing.T) { @@ -226,10 +240,12 @@ func TestCheckHostSafeSearchYandex(t *testing.T) { } { t.Run(strings.ToLower(host), func(t *testing.T) { res, err := d.CheckHost(host, dns.TypeA, &setts) - require.Nil(t, err) + require.NoError(t, err) + assert.True(t, res.IsFiltered) require.Len(t, res.Rules, 1) + assert.Equal(t, yandexIP, res.Rules[0].IP) }) } @@ -257,9 +273,12 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) { } { t.Run(host, func(t *testing.T) { res, err := d.CheckHost(host, dns.TypeA, &setts) - require.Nil(t, err) + require.NoError(t, err) + assert.True(t, res.IsFiltered) + require.Len(t, res.Rules, 1) + assert.Equal(t, ip, res.Rules[0].IP) }) } @@ -272,8 +291,10 @@ func TestSafeSearchCacheYandex(t *testing.T) { // Check host with disabled safesearch. res, err := d.CheckHost(domain, dns.TypeA, &setts) - require.Nil(t, err) + require.NoError(t, err) + assert.False(t, res.IsFiltered) + require.Empty(t, res.Rules) yandexIP := net.IPv4(213, 180, 193, 56) @@ -282,7 +303,7 @@ func TestSafeSearchCacheYandex(t *testing.T) { t.Cleanup(d.Close) res, err = d.CheckHost(domain, dns.TypeA, &setts) - require.Nilf(t, err, "CheckHost for safesearh domain %s failed cause %s", domain, err) + require.NoError(t, err) // For yandex we already know valid IP. require.Len(t, res.Rules, 1) @@ -292,6 +313,7 @@ func TestSafeSearchCacheYandex(t *testing.T) { cachedValue, isFound := getCachedResult(gctx.safeSearchCache, domain) require.True(t, isFound) require.Len(t, cachedValue.Rules, 1) + assert.Equal(t, cachedValue.Rules[0].IP, yandexIP) } @@ -304,8 +326,10 @@ func TestSafeSearchCacheGoogle(t *testing.T) { const domain = "www.google.ru" res, err := d.CheckHost(domain, dns.TypeA, &setts) - require.Nil(t, err) + require.NoError(t, err) + assert.False(t, res.IsFiltered) + require.Empty(t, res.Rules) d = newForTest(&Config{SafeSearchEnabled: true}, nil) @@ -314,10 +338,10 @@ func TestSafeSearchCacheGoogle(t *testing.T) { // Lookup for safesearch domain. safeDomain, ok := d.SafeSearchDomain(domain) - require.Truef(t, ok, "Failed to get safesearch domain for %s", domain) + require.True(t, ok) ips, err := resolver.LookupIP(context.Background(), "ip", safeDomain) - require.Nilf(t, err, "Failed to lookup for %s", safeDomain) + require.NoError(t, err) var ip net.IP for _, foundIP := range ips { @@ -329,14 +353,16 @@ func TestSafeSearchCacheGoogle(t *testing.T) { } res, err = d.CheckHost(domain, dns.TypeA, &setts) - require.Nil(t, err) + require.NoError(t, err) require.Len(t, res.Rules, 1) + assert.True(t, res.Rules[0].IP.Equal(ip)) // Check cache. cachedValue, isFound := getCachedResult(gctx.safeSearchCache, domain) require.True(t, isFound) require.Len(t, cachedValue.Rules, 1) + assert.True(t, cachedValue.Rules[0].IP.Equal(ip)) } @@ -357,6 +383,7 @@ func TestParentalControl(t *testing.T) { d.checkMatch(t, matching) require.Contains(t, logOutput.String(), "Parental lookup for "+matching) + d.checkMatch(t, "www."+matching) d.checkMatchEmpty(t, "www.yandex.ru") d.checkMatchEmpty(t, "yandex.ru") @@ -654,7 +681,8 @@ func TestMatching(t *testing.T) { t.Cleanup(d.Close) res, err := d.CheckHost(tc.host, tc.wantDNSType, &setts) - require.Nilf(t, err, "Error while matching host %s: %s", tc.host, err) + require.NoError(t, err) + assert.Equalf(t, tc.wantIsFiltered, res.IsFiltered, "Hostname %s has wrong result (%v must be %v)", tc.host, res.IsFiltered, tc.wantIsFiltered) assert.Equalf(t, tc.wantReason, res.Reason, "Hostname %s has wrong reason (%v must be %v)", tc.host, res.Reason, tc.wantReason) }) @@ -677,23 +705,31 @@ func TestWhitelist(t *testing.T) { }} d := newForTest(nil, filters) - require.Nil(t, d.SetFilters(filters, whiteFilters, false)) + err := d.SetFilters(filters, whiteFilters, false) + require.NoError(t, err) + t.Cleanup(d.Close) // Matched by white filter. res, err := d.CheckHost("host1", dns.TypeA, &setts) - require.Nil(t, err) + require.NoError(t, err) + assert.False(t, res.IsFiltered) assert.Equal(t, res.Reason, NotFilteredAllowList) + require.Len(t, res.Rules, 1) + assert.Equal(t, "||host1^", res.Rules[0].Text) // Not matched by white filter, but matched by block filter. res, err = d.CheckHost("host2", dns.TypeA, &setts) - require.Nil(t, err) + require.NoError(t, err) + assert.True(t, res.IsFiltered) assert.Equal(t, res.Reason, FilteredBlockList) + require.Len(t, res.Rules, 1) + assert.Equal(t, "||host2^", res.Rules[0].Text) } @@ -796,7 +832,8 @@ func BenchmarkSafeBrowsing(b *testing.B) { }) for n := 0; n < b.N; n++ { res, err := d.CheckHost(blocked, dns.TypeA, &setts) - require.Nilf(b, err, "Error while matching host %s: %s", blocked, err) + require.NoError(b, err) + assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked) } } @@ -812,7 +849,8 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { res, err := d.CheckHost(blocked, dns.TypeA, &setts) - require.Nilf(b, err, "Error while matching host %s: %s", blocked, err) + require.NoError(b, err) + assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked) } }) @@ -823,7 +861,8 @@ func BenchmarkSafeSearch(b *testing.B) { b.Cleanup(d.Close) for n := 0; n < b.N; n++ { val, ok := d.SafeSearchDomain("www.google.com") - require.True(b, ok, "Expected safesearch to find result for www.google.com") + require.True(b, ok) + assert.Equal(b, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com") } } @@ -834,7 +873,8 @@ func BenchmarkSafeSearchParallel(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { val, ok := d.SafeSearchDomain("www.google.com") - require.True(b, ok, "Expected safesearch to find result for www.google.com") + require.True(b, ok) + assert.Equal(b, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com") } })