Pull request: all: sup many ips in host rules
Closes #1381. Squashed commit of the following: commit 44965f74d8becb27173cafc533e56e1cde484b59 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Thu Jun 17 13:54:44 2021 +0300 dnsforward: imp code, docs commit 23736cf9b407b668faade19b61739536caafff03 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Wed Jun 16 20:33:59 2021 +0300 dnsforward: rm todo commit ff086756160d72f3cdbe662862dd1fb447ac5ff3 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Wed Jun 16 20:32:01 2021 +0300 all: sup many ips in host rules
This commit is contained in:
parent
84e71e912e
commit
7547d3a422
|
@ -34,6 +34,8 @@ and this project adheres to
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
|
|
||||||
|
- When /etc/hosts-type rules have several IPs for one host, all IPs are now
|
||||||
|
returned instead of only the first one ([#1381]).
|
||||||
- The setting `rlimit_nofile` is now in the `os` block of the configuration
|
- The setting `rlimit_nofile` is now in the `os` block of the configuration
|
||||||
file, together with the new `group` and `user` settings ([#2763]).
|
file, together with the new `group` and `user` settings ([#2763]).
|
||||||
- Permissions on filter files are now `0o644` instead of `0o600` ([#3198]).
|
- Permissions on filter files are now `0o644` instead of `0o600` ([#3198]).
|
||||||
|
@ -63,6 +65,7 @@ released by then.
|
||||||
|
|
||||||
- Go 1.15 support.
|
- Go 1.15 support.
|
||||||
|
|
||||||
|
[#1381]: https://github.com/AdguardTeam/AdGuardHome/issues/1381
|
||||||
[#1691]: https://github.com/AdguardTeam/AdGuardHome/issues/1691
|
[#1691]: https://github.com/AdguardTeam/AdGuardHome/issues/1691
|
||||||
[#2280]: https://github.com/AdguardTeam/AdGuardHome/issues/2280
|
[#2280]: https://github.com/AdguardTeam/AdGuardHome/issues/2280
|
||||||
[#2439]: https://github.com/AdguardTeam/AdGuardHome/issues/2439
|
[#2439]: https://github.com/AdguardTeam/AdGuardHome/issues/2439
|
||||||
|
|
|
@ -21,6 +21,30 @@ import (
|
||||||
"github.com/ameshkov/dnscrypt/v2"
|
"github.com/ameshkov/dnscrypt/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// BlockingMode is an enum of all allowed blocking modes.
|
||||||
|
type BlockingMode string
|
||||||
|
|
||||||
|
// Allowed blocking modes.
|
||||||
|
const (
|
||||||
|
// BlockingModeCustomIP means respond with a custom IP address.
|
||||||
|
BlockingModeCustomIP BlockingMode = "custom_ip"
|
||||||
|
|
||||||
|
// BlockingModeDefault is the same as BlockingModeNullIP for
|
||||||
|
// Adblock-style rules, but responds with the IP address specified in
|
||||||
|
// the rule when blocked by an `/etc/hosts`-style rule.
|
||||||
|
BlockingModeDefault BlockingMode = "default"
|
||||||
|
|
||||||
|
// BlockingModeNullIP means respond with a zero IP address: "0.0.0.0"
|
||||||
|
// for A requests and "::" for AAAA ones.
|
||||||
|
BlockingModeNullIP BlockingMode = "null_ip"
|
||||||
|
|
||||||
|
// BlockingModeNXDOMAIN means respond with the NXDOMAIN code.
|
||||||
|
BlockingModeNXDOMAIN BlockingMode = "nxdomain"
|
||||||
|
|
||||||
|
// BlockingModeREFUSED means respond with the REFUSED code.
|
||||||
|
BlockingModeREFUSED BlockingMode = "refused"
|
||||||
|
)
|
||||||
|
|
||||||
// FilteringConfig represents the DNS filtering configuration of AdGuard Home
|
// FilteringConfig represents the DNS filtering configuration of AdGuard Home
|
||||||
// The zero FilteringConfig is empty and ready for use.
|
// The zero FilteringConfig is empty and ready for use.
|
||||||
type FilteringConfig struct {
|
type FilteringConfig struct {
|
||||||
|
@ -39,7 +63,7 @@ type FilteringConfig struct {
|
||||||
// --
|
// --
|
||||||
|
|
||||||
ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of filtering features
|
ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of filtering features
|
||||||
BlockingMode string `yaml:"blocking_mode"` // mode how to answer filtered requests
|
BlockingMode BlockingMode `yaml:"blocking_mode"` // mode how to answer filtered requests
|
||||||
BlockingIPv4 net.IP `yaml:"blocking_ipv4"` // IP address to be returned for a blocked A request
|
BlockingIPv4 net.IP `yaml:"blocking_ipv4"` // IP address to be returned for a blocked A request
|
||||||
BlockingIPv6 net.IP `yaml:"blocking_ipv6"` // IP address to be returned for a blocked AAAA request
|
BlockingIPv6 net.IP `yaml:"blocking_ipv6"` // IP address to be returned for a blocked AAAA request
|
||||||
BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` // if 0, then default is used (3600)
|
BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` // if 0, then default is used (3600)
|
||||||
|
|
|
@ -520,6 +520,7 @@ func TestBlockedRequest(t *testing.T) {
|
||||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||||
FilteringConfig: FilteringConfig{
|
FilteringConfig: FilteringConfig{
|
||||||
ProtectionEnabled: true,
|
ProtectionEnabled: true,
|
||||||
|
BlockingMode: BlockingModeDefault,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
|
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
|
||||||
|
@ -622,6 +623,7 @@ func TestBlockCNAME(t *testing.T) {
|
||||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||||
FilteringConfig: FilteringConfig{
|
FilteringConfig: FilteringConfig{
|
||||||
ProtectionEnabled: true,
|
ProtectionEnabled: true,
|
||||||
|
BlockingMode: BlockingModeDefault,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
|
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
|
||||||
|
@ -724,7 +726,7 @@ func TestNullBlockedRequest(t *testing.T) {
|
||||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||||
FilteringConfig: FilteringConfig{
|
FilteringConfig: FilteringConfig{
|
||||||
ProtectionEnabled: true,
|
ProtectionEnabled: true,
|
||||||
BlockingMode: "null_ip",
|
BlockingMode: BlockingModeNullIP,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
|
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
|
||||||
|
@ -777,7 +779,7 @@ func TestBlockedCustomIP(t *testing.T) {
|
||||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||||
FilteringConfig: FilteringConfig{
|
FilteringConfig: FilteringConfig{
|
||||||
ProtectionEnabled: true,
|
ProtectionEnabled: true,
|
||||||
BlockingMode: "custom_ip",
|
BlockingMode: BlockingModeCustomIP,
|
||||||
BlockingIPv4: nil,
|
BlockingIPv4: nil,
|
||||||
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
||||||
},
|
},
|
||||||
|
@ -827,6 +829,7 @@ func TestBlockedByHosts(t *testing.T) {
|
||||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||||
FilteringConfig: FilteringConfig{
|
FilteringConfig: FilteringConfig{
|
||||||
ProtectionEnabled: true,
|
ProtectionEnabled: true,
|
||||||
|
BlockingMode: BlockingModeDefault,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,7 @@ type dnsConfig struct {
|
||||||
|
|
||||||
ProtectionEnabled *bool `json:"protection_enabled"`
|
ProtectionEnabled *bool `json:"protection_enabled"`
|
||||||
RateLimit *uint32 `json:"ratelimit"`
|
RateLimit *uint32 `json:"ratelimit"`
|
||||||
BlockingMode *string `json:"blocking_mode"`
|
BlockingMode *BlockingMode `json:"blocking_mode"`
|
||||||
BlockingIPv4 net.IP `json:"blocking_ipv4"`
|
BlockingIPv4 net.IP `json:"blocking_ipv4"`
|
||||||
BlockingIPv6 net.IP `json:"blocking_ipv6"`
|
BlockingIPv6 net.IP `json:"blocking_ipv6"`
|
||||||
EDNSCSEnabled *bool `json:"edns_cs_enabled"`
|
EDNSCSEnabled *bool `json:"edns_cs_enabled"`
|
||||||
|
@ -126,27 +126,17 @@ func (req *dnsConfig) checkBlockingMode() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
bm := *req.BlockingMode
|
switch bm := *req.BlockingMode; bm {
|
||||||
if bm == "custom_ip" {
|
case BlockingModeDefault,
|
||||||
if req.BlockingIPv4.To4() == nil {
|
BlockingModeREFUSED,
|
||||||
return false
|
BlockingModeNXDOMAIN,
|
||||||
}
|
BlockingModeNullIP:
|
||||||
|
|
||||||
return req.BlockingIPv6 != nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, valid := range []string{
|
|
||||||
"default",
|
|
||||||
"refused",
|
|
||||||
"nxdomain",
|
|
||||||
"null_ip",
|
|
||||||
} {
|
|
||||||
if bm == valid {
|
|
||||||
return true
|
return true
|
||||||
}
|
case BlockingModeCustomIP:
|
||||||
}
|
return req.BlockingIPv4.To4() != nil && req.BlockingIPv6 != nil
|
||||||
|
default:
|
||||||
return false
|
return false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (req *dnsConfig) checkUpstreamsMode() bool {
|
func (req *dnsConfig) checkUpstreamsMode() bool {
|
||||||
|
|
|
@ -26,17 +26,29 @@ func (s *Server) makeResponse(req *dns.Msg) (resp *dns.Msg) {
|
||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ipsFromRules extracts non-IP addresses from the filtering result rules.
|
||||||
|
func ipsFromRules(resRules []*filtering.ResultRule) (ips []net.IP) {
|
||||||
|
for _, r := range resRules {
|
||||||
|
if r.IP != nil {
|
||||||
|
ips = append(ips, r.IP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ips
|
||||||
|
}
|
||||||
|
|
||||||
// genDNSFilterMessage generates a DNS message corresponding to the filtering result
|
// genDNSFilterMessage generates a DNS message corresponding to the filtering result
|
||||||
func (s *Server) genDNSFilterMessage(d *proxy.DNSContext, result *filtering.Result) *dns.Msg {
|
func (s *Server) genDNSFilterMessage(d *proxy.DNSContext, result *filtering.Result) *dns.Msg {
|
||||||
m := d.Req
|
m := d.Req
|
||||||
|
|
||||||
if m.Question[0].Qtype != dns.TypeA && m.Question[0].Qtype != dns.TypeAAAA {
|
if m.Question[0].Qtype != dns.TypeA && m.Question[0].Qtype != dns.TypeAAAA {
|
||||||
if s.conf.BlockingMode == "null_ip" {
|
if s.conf.BlockingMode == BlockingModeNullIP {
|
||||||
return s.makeResponse(m)
|
return s.makeResponse(m)
|
||||||
}
|
}
|
||||||
return s.genNXDomain(m)
|
return s.genNXDomain(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ips := ipsFromRules(result.Rules)
|
||||||
switch result.Reason {
|
switch result.Reason {
|
||||||
case filtering.FilteredSafeBrowsing:
|
case filtering.FilteredSafeBrowsing:
|
||||||
return s.genBlockedHost(m, s.conf.SafeBrowsingBlockHost, d)
|
return s.genBlockedHost(m, s.conf.SafeBrowsingBlockHost, d)
|
||||||
|
@ -46,42 +58,45 @@ func (s *Server) genDNSFilterMessage(d *proxy.DNSContext, result *filtering.Resu
|
||||||
// If the query was filtered by "Safe search", filtering also must return
|
// If the query was filtered by "Safe search", filtering also must return
|
||||||
// the IP address that must be used in response.
|
// the IP address that must be used in response.
|
||||||
// In this case regardless of the filtering method, we should return it
|
// In this case regardless of the filtering method, we should return it
|
||||||
if result.Reason == filtering.FilteredSafeSearch &&
|
if result.Reason == filtering.FilteredSafeSearch && len(ips) > 0 {
|
||||||
len(result.Rules) > 0 &&
|
return s.genResponseWithIPs(m, ips)
|
||||||
result.Rules[0].IP != nil {
|
|
||||||
return s.genResponseWithIP(m, result.Rules[0].IP)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.conf.BlockingMode == "null_ip" {
|
switch s.conf.BlockingMode {
|
||||||
// it means that we should return 0.0.0.0 or :: for any blocked request
|
case BlockingModeCustomIP:
|
||||||
return s.makeResponseNullIP(m)
|
|
||||||
} else if s.conf.BlockingMode == "custom_ip" {
|
|
||||||
// means that we should return custom IP for any blocked request
|
|
||||||
|
|
||||||
switch m.Question[0].Qtype {
|
switch m.Question[0].Qtype {
|
||||||
case dns.TypeA:
|
case dns.TypeA:
|
||||||
return s.genARecord(m, s.conf.BlockingIPv4)
|
return s.genARecord(m, s.conf.BlockingIPv4)
|
||||||
case dns.TypeAAAA:
|
case dns.TypeAAAA:
|
||||||
return s.genAAAARecord(m, s.conf.BlockingIPv6)
|
return s.genAAAARecord(m, s.conf.BlockingIPv6)
|
||||||
|
default:
|
||||||
|
// Generally shouldn't happen, since the types
|
||||||
|
// are checked above.
|
||||||
|
log.Error(
|
||||||
|
"dns: invalid msg type %d for blocking mode %s",
|
||||||
|
m.Question[0].Qtype,
|
||||||
|
s.conf.BlockingMode,
|
||||||
|
)
|
||||||
|
|
||||||
|
return s.makeResponse(m)
|
||||||
}
|
}
|
||||||
} else if s.conf.BlockingMode == "nxdomain" {
|
case BlockingModeDefault:
|
||||||
// means that we should return NXDOMAIN for any blocked request
|
if len(ips) > 0 {
|
||||||
|
return s.genResponseWithIPs(m, ips)
|
||||||
return s.genNXDomain(m)
|
|
||||||
} else if s.conf.BlockingMode == "refused" {
|
|
||||||
// means that we should return NXDOMAIN for any blocked request
|
|
||||||
|
|
||||||
return s.makeResponseREFUSED(m)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Default blocking mode
|
|
||||||
// If there's an IP specified in the rule, return it
|
|
||||||
// For host-type rules, return null IP
|
|
||||||
if len(result.Rules) > 0 && result.Rules[0].IP != nil {
|
|
||||||
return s.genResponseWithIP(m, result.Rules[0].IP)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.makeResponseNullIP(m)
|
return s.makeResponseNullIP(m)
|
||||||
|
case BlockingModeNullIP:
|
||||||
|
return s.makeResponseNullIP(m)
|
||||||
|
case BlockingModeNXDOMAIN:
|
||||||
|
return s.genNXDomain(m)
|
||||||
|
case BlockingModeREFUSED:
|
||||||
|
return s.makeResponseREFUSED(m)
|
||||||
|
default:
|
||||||
|
log.Error("dns: invalid blocking mode %q", s.conf.BlockingMode)
|
||||||
|
|
||||||
|
return s.makeResponse(m)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -166,35 +181,60 @@ func (s *Server) genAnswerTXT(req *dns.Msg, strs []string) (ans *dns.TXT) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// generate DNS response message with an IP address
|
// genResponseWithIPs generates a DNS response message with the provided IP
|
||||||
func (s *Server) genResponseWithIP(req *dns.Msg, ip net.IP) *dns.Msg {
|
// addresses and an appropriate resource record type. If any of the IPs cannot
|
||||||
if req.Question[0].Qtype == dns.TypeA && ip.To4() != nil {
|
// be converted to the correct protocol, genResponseWithIPs returns an empty
|
||||||
return s.genARecord(req, ip.To4())
|
// response.
|
||||||
} else if req.Question[0].Qtype == dns.TypeAAAA &&
|
func (s *Server) genResponseWithIPs(req *dns.Msg, ips []net.IP) (resp *dns.Msg) {
|
||||||
len(ip) == net.IPv6len && ip.To4() == nil {
|
var ans []dns.RR
|
||||||
return s.genAAAARecord(req, ip)
|
switch req.Question[0].Qtype {
|
||||||
|
case dns.TypeA:
|
||||||
|
for _, ip := range ips {
|
||||||
|
if ip4 := ip.To4(); ip4 == nil {
|
||||||
|
ans = nil
|
||||||
|
|
||||||
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
// empty response
|
ans = append(ans, s.genAnswerA(req, ip))
|
||||||
resp := s.makeResponse(req)
|
}
|
||||||
|
case dns.TypeAAAA:
|
||||||
|
for _, ip := range ips {
|
||||||
|
ans = append(ans, s.genAnswerAAAA(req, ip.To16()))
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// Go on and return an empty response.
|
||||||
|
}
|
||||||
|
|
||||||
|
resp = s.makeResponse(req)
|
||||||
|
resp.Answer = ans
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
// Respond with 0.0.0.0 for A, :: for AAAA, empty response for other types
|
// makeResponseNullIP creates a response with 0.0.0.0 for A requests, :: for
|
||||||
func (s *Server) makeResponseNullIP(req *dns.Msg) *dns.Msg {
|
// AAAA requests, and an empty response for other types.
|
||||||
if req.Question[0].Qtype == dns.TypeA {
|
func (s *Server) makeResponseNullIP(req *dns.Msg) (resp *dns.Msg) {
|
||||||
return s.genARecord(req, []byte{0, 0, 0, 0})
|
// Respond with the corresponding zero IP type as opposed to simply
|
||||||
} else if req.Question[0].Qtype == dns.TypeAAAA {
|
// using one or the other in both cases, because the IPv4 zero IP is
|
||||||
return s.genAAAARecord(req, net.IPv6zero)
|
// converted to a IPV6-mapped IPv4 address, while the IPv6 zero IP is
|
||||||
|
// converted into an empty slice instead of the zero IPv4.
|
||||||
|
switch req.Question[0].Qtype {
|
||||||
|
case dns.TypeA:
|
||||||
|
resp = s.genResponseWithIPs(req, []net.IP{{0, 0, 0, 0}})
|
||||||
|
case dns.TypeAAAA:
|
||||||
|
resp = s.genResponseWithIPs(req, []net.IP{net.IPv6zero})
|
||||||
|
default:
|
||||||
|
resp = s.makeResponse(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.makeResponse(req)
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSContext) *dns.Msg {
|
func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSContext) *dns.Msg {
|
||||||
ip := net.ParseIP(newAddr)
|
ip := net.ParseIP(newAddr)
|
||||||
if ip != nil {
|
if ip != nil {
|
||||||
return s.genResponseWithIP(request, ip)
|
return s.genResponseWithIPs(request, []net.IP{ip})
|
||||||
}
|
}
|
||||||
|
|
||||||
// look up the hostname, TODO: cache
|
// look up the hostname, TODO: cache
|
||||||
|
|
|
@ -344,13 +344,13 @@ var gctx dnsFilterContext
|
||||||
|
|
||||||
// ResultRule contains information about applied rules.
|
// ResultRule contains information about applied rules.
|
||||||
type ResultRule struct {
|
type ResultRule struct {
|
||||||
// FilterListID is the ID of the rule's filter list.
|
|
||||||
FilterListID int64 `json:",omitempty"`
|
|
||||||
// Text is the text of the rule.
|
// Text is the text of the rule.
|
||||||
Text string `json:",omitempty"`
|
Text string `json:",omitempty"`
|
||||||
// IP is the host IP. It is nil unless the rule uses the
|
// IP is the host IP. It is nil unless the rule uses the
|
||||||
// /etc/hosts syntax or the reason is FilteredSafeSearch.
|
// /etc/hosts syntax or the reason is FilteredSafeSearch.
|
||||||
IP net.IP `json:",omitempty"`
|
IP net.IP `json:",omitempty"`
|
||||||
|
// FilterListID is the ID of the rule's filter list.
|
||||||
|
FilterListID int64 `json:",omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Result contains the result of a request check.
|
// Result contains the result of a request check.
|
||||||
|
@ -657,26 +657,43 @@ func (d *DNSFilter) initFiltering(allowFilters, blockFilters []Filter) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// matchHostProcessAllowList processes the allowlist logic of host
|
// hostRules is a helper that converts a slice of host rules into a slice of the
|
||||||
// matching.
|
// rules.Rule interface values.
|
||||||
func (d *DNSFilter) matchHostProcessAllowList(host string, dnsres urlfilter.DNSResult) (res Result, err error) {
|
func hostRulesToRules(netRules []*rules.HostRule) (res []rules.Rule) {
|
||||||
var rule rules.Rule
|
if netRules == nil {
|
||||||
if dnsres.NetworkRule != nil {
|
return nil
|
||||||
rule = dnsres.NetworkRule
|
|
||||||
} else if len(dnsres.HostRulesV4) > 0 {
|
|
||||||
rule = dnsres.HostRulesV4[0]
|
|
||||||
} else if len(dnsres.HostRulesV6) > 0 {
|
|
||||||
rule = dnsres.HostRulesV6[0]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if rule == nil {
|
res = make([]rules.Rule, len(netRules))
|
||||||
|
for i, nr := range netRules {
|
||||||
|
res[i] = nr
|
||||||
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchHostProcessAllowList processes the allowlist logic of host
|
||||||
|
// matching.
|
||||||
|
func (d *DNSFilter) matchHostProcessAllowList(
|
||||||
|
host string,
|
||||||
|
dnsres urlfilter.DNSResult,
|
||||||
|
) (res Result, err error) {
|
||||||
|
var matchedRules []rules.Rule
|
||||||
|
if dnsres.NetworkRule != nil {
|
||||||
|
matchedRules = []rules.Rule{dnsres.NetworkRule}
|
||||||
|
} else if len(dnsres.HostRulesV4) > 0 {
|
||||||
|
matchedRules = hostRulesToRules(dnsres.HostRulesV4)
|
||||||
|
} else if len(dnsres.HostRulesV6) > 0 {
|
||||||
|
matchedRules = hostRulesToRules(dnsres.HostRulesV6)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(matchedRules) == 0 {
|
||||||
return Result{}, fmt.Errorf("invalid dns result: rules are empty")
|
return Result{}, fmt.Errorf("invalid dns result: rules are empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug("Filtering: found allowlist rule for host %q: %q list_id: %d",
|
log.Debug("filtering: allowlist rules for host %q: %+v", host, matchedRules)
|
||||||
host, rule.Text(), rule.GetFilterListID())
|
|
||||||
|
|
||||||
return makeResult(rule, NotFilteredAllowList), nil
|
return makeResult(matchedRules, NotFilteredAllowList), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// matchHostProcessDNSResult processes the matched DNS filtering result.
|
// matchHostProcessDNSResult processes the matched DNS filtering result.
|
||||||
|
@ -690,21 +707,23 @@ func (d *DNSFilter) matchHostProcessDNSResult(
|
||||||
reason = NotFilteredAllowList
|
reason = NotFilteredAllowList
|
||||||
}
|
}
|
||||||
|
|
||||||
return makeResult(dnsres.NetworkRule, reason)
|
return makeResult([]rules.Rule{dnsres.NetworkRule}, reason)
|
||||||
}
|
}
|
||||||
|
|
||||||
if qtype == dns.TypeA && dnsres.HostRulesV4 != nil {
|
if qtype == dns.TypeA && dnsres.HostRulesV4 != nil {
|
||||||
rule := dnsres.HostRulesV4[0]
|
res = makeResult(hostRulesToRules(dnsres.HostRulesV4), FilteredBlockList)
|
||||||
res = makeResult(rule, FilteredBlockList)
|
for i, hr := range dnsres.HostRulesV4 {
|
||||||
res.Rules[0].IP = rule.IP.To4()
|
res.Rules[i].IP = hr.IP.To4()
|
||||||
|
}
|
||||||
|
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
if qtype == dns.TypeAAAA && dnsres.HostRulesV6 != nil {
|
if qtype == dns.TypeAAAA && dnsres.HostRulesV6 != nil {
|
||||||
rule := dnsres.HostRulesV6[0]
|
res = makeResult(hostRulesToRules(dnsres.HostRulesV6), FilteredBlockList)
|
||||||
res = makeResult(rule, FilteredBlockList)
|
for i, hr := range dnsres.HostRulesV6 {
|
||||||
res.Rules[0].IP = rule.IP.To16()
|
res.Rules[i].IP = hr.IP.To16()
|
||||||
|
}
|
||||||
|
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
@ -712,17 +731,14 @@ func (d *DNSFilter) matchHostProcessDNSResult(
|
||||||
if dnsres.HostRulesV4 != nil || dnsres.HostRulesV6 != nil {
|
if dnsres.HostRulesV4 != nil || dnsres.HostRulesV6 != nil {
|
||||||
// Question type doesn't match the host rules. Return the first
|
// Question type doesn't match the host rules. Return the first
|
||||||
// matched host rule, but without an IP address.
|
// matched host rule, but without an IP address.
|
||||||
var rule rules.Rule
|
var matchedRules []rules.Rule
|
||||||
if dnsres.HostRulesV4 != nil {
|
if dnsres.HostRulesV4 != nil {
|
||||||
rule = dnsres.HostRulesV4[0]
|
matchedRules = []rules.Rule{dnsres.HostRulesV4[0]}
|
||||||
} else if dnsres.HostRulesV6 != nil {
|
} else if dnsres.HostRulesV6 != nil {
|
||||||
rule = dnsres.HostRulesV6[0]
|
matchedRules = []rules.Rule{dnsres.HostRulesV6[0]}
|
||||||
}
|
}
|
||||||
|
|
||||||
res = makeResult(rule, FilteredBlockList)
|
return makeResult(matchedRules, FilteredBlockList)
|
||||||
res.Rules[0].IP = net.IP{}
|
|
||||||
|
|
||||||
return res
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return Result{}
|
return Result{}
|
||||||
|
@ -780,8 +796,7 @@ func (d *DNSFilter) matchHost(
|
||||||
}
|
}
|
||||||
|
|
||||||
res = d.matchHostProcessDNSResult(qtype, dnsres)
|
res = d.matchHostProcessDNSResult(qtype, dnsres)
|
||||||
if len(res.Rules) > 0 {
|
for _, r := range res.Rules {
|
||||||
r := res.Rules[0]
|
|
||||||
log.Debug(
|
log.Debug(
|
||||||
"filtering: found rule %q for host %q, filter list id: %d",
|
"filtering: found rule %q for host %q, filter list id: %d",
|
||||||
r.Text,
|
r.Text,
|
||||||
|
@ -794,20 +809,20 @@ func (d *DNSFilter) matchHost(
|
||||||
}
|
}
|
||||||
|
|
||||||
// makeResult returns a properly constructed Result.
|
// makeResult returns a properly constructed Result.
|
||||||
func makeResult(rule rules.Rule, reason Reason) Result {
|
func makeResult(matchedRules []rules.Rule, reason Reason) (res Result) {
|
||||||
res := Result{
|
resRules := make([]*ResultRule, len(matchedRules))
|
||||||
|
for i, mr := range matchedRules {
|
||||||
|
resRules[i] = &ResultRule{
|
||||||
|
FilterListID: int64(mr.GetFilterListID()),
|
||||||
|
Text: mr.Text(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Result{
|
||||||
|
IsFiltered: reason == FilteredBlockList,
|
||||||
Reason: reason,
|
Reason: reason,
|
||||||
Rules: []*ResultRule{{
|
Rules: resRules,
|
||||||
FilterListID: int64(rule.GetFilterListID()),
|
|
||||||
Text: rule.Text(),
|
|
||||||
}},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if reason == FilteredBlockList {
|
|
||||||
res.IsFiltered = true
|
|
||||||
}
|
|
||||||
|
|
||||||
return res
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// InitModule manually initializes blocked services map.
|
// InitModule manually initializes blocked services map.
|
||||||
|
|
|
@ -60,29 +60,33 @@ func (d *DNSFilter) checkMatch(t *testing.T, hostname string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
res, err := d.CheckHost(hostname, dns.TypeA, &setts)
|
res, err := d.CheckHost(hostname, dns.TypeA, &setts)
|
||||||
require.Nilf(t, err, "Error while matching host %s: %s", hostname, err)
|
require.NoErrorf(t, err, "host %q", hostname)
|
||||||
assert.Truef(t, res.IsFiltered, "Expected hostname %s to match", hostname)
|
|
||||||
|
assert.Truef(t, res.IsFiltered, "host %q", hostname)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DNSFilter) checkMatchIP(t *testing.T, hostname, ip string, qtype uint16) {
|
func (d *DNSFilter) checkMatchIP(t *testing.T, hostname, ip string, qtype uint16) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
res, err := d.CheckHost(hostname, qtype, &setts)
|
res, err := d.CheckHost(hostname, qtype, &setts)
|
||||||
require.Nilf(t, err, "Error while matching host %s: %s", hostname, err)
|
require.NoErrorf(t, err, "host %q", hostname, err)
|
||||||
assert.Truef(t, res.IsFiltered, "Expected hostname %s to match", hostname)
|
require.NotEmpty(t, res.Rules, "host %q", hostname)
|
||||||
|
|
||||||
|
assert.Truef(t, res.IsFiltered, "host %q", hostname)
|
||||||
|
|
||||||
require.NotEmpty(t, res.Rules, "Expected result to have rules")
|
|
||||||
r := res.Rules[0]
|
r := res.Rules[0]
|
||||||
require.NotNilf(t, r.IP, "Expected ip %s to match, actual: %v", ip, r.IP)
|
require.NotNilf(t, r.IP, "Expected ip %s to match, actual: %v", ip, r.IP)
|
||||||
assert.Equalf(t, ip, r.IP.String(), "Expected ip %s to match, actual: %v", ip, r.IP)
|
|
||||||
|
assert.Equalf(t, ip, r.IP.String(), "host %q", hostname)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DNSFilter) checkMatchEmpty(t *testing.T, hostname string) {
|
func (d *DNSFilter) checkMatchEmpty(t *testing.T, hostname string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
res, err := d.CheckHost(hostname, dns.TypeA, &setts)
|
res, err := d.CheckHost(hostname, dns.TypeA, &setts)
|
||||||
require.Nilf(t, err, "Error while matching host %s: %s", hostname, err)
|
require.NoErrorf(t, err, "host %q", hostname)
|
||||||
assert.Falsef(t, res.IsFiltered, "Expected hostname %s to not match", hostname)
|
|
||||||
|
assert.Falsef(t, res.IsFiltered, "host %q", hostname)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEtcHostsMatching(t *testing.T) {
|
func TestEtcHostsMatching(t *testing.T) {
|
||||||
|
@ -112,10 +116,12 @@ func TestEtcHostsMatching(t *testing.T) {
|
||||||
|
|
||||||
// Empty IPv6.
|
// Empty IPv6.
|
||||||
res, err := d.CheckHost("block.com", dns.TypeAAAA, &setts)
|
res, err := d.CheckHost("block.com", dns.TypeAAAA, &setts)
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, res.IsFiltered)
|
assert.True(t, res.IsFiltered)
|
||||||
|
|
||||||
require.Len(t, res.Rules, 1)
|
require.Len(t, res.Rules, 1)
|
||||||
|
|
||||||
assert.Equal(t, "0.0.0.0 block.com", res.Rules[0].Text)
|
assert.Equal(t, "0.0.0.0 block.com", res.Rules[0].Text)
|
||||||
assert.Empty(t, res.Rules[0].IP)
|
assert.Empty(t, res.Rules[0].IP)
|
||||||
|
|
||||||
|
@ -124,27 +130,34 @@ func TestEtcHostsMatching(t *testing.T) {
|
||||||
|
|
||||||
// Empty IPv4.
|
// Empty IPv4.
|
||||||
res, err = d.CheckHost("ipv6.com", dns.TypeA, &setts)
|
res, err = d.CheckHost("ipv6.com", dns.TypeA, &setts)
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, res.IsFiltered)
|
assert.True(t, res.IsFiltered)
|
||||||
|
|
||||||
require.Len(t, res.Rules, 1)
|
require.Len(t, res.Rules, 1)
|
||||||
|
|
||||||
assert.Equal(t, "::1 ipv6.com", res.Rules[0].Text)
|
assert.Equal(t, "::1 ipv6.com", res.Rules[0].Text)
|
||||||
assert.Empty(t, res.Rules[0].IP)
|
assert.Empty(t, res.Rules[0].IP)
|
||||||
|
|
||||||
// Two IPv4, the first one returned.
|
// Two IPv4, both must be returned.
|
||||||
res, err = d.CheckHost("host2", dns.TypeA, &setts)
|
res, err = d.CheckHost("host2", dns.TypeA, &setts)
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, res.IsFiltered)
|
assert.True(t, res.IsFiltered)
|
||||||
|
|
||||||
require.Len(t, res.Rules, 1)
|
require.Len(t, res.Rules, 2)
|
||||||
|
|
||||||
assert.Equal(t, res.Rules[0].IP, net.IP{0, 0, 0, 1})
|
assert.Equal(t, res.Rules[0].IP, net.IP{0, 0, 0, 1})
|
||||||
|
assert.Equal(t, res.Rules[1].IP, net.IP{0, 0, 0, 2})
|
||||||
|
|
||||||
// One IPv6 address.
|
// One IPv6 address.
|
||||||
res, err = d.CheckHost("host2", dns.TypeAAAA, &setts)
|
res, err = d.CheckHost("host2", dns.TypeAAAA, &setts)
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, res.IsFiltered)
|
assert.True(t, res.IsFiltered)
|
||||||
|
|
||||||
require.Len(t, res.Rules, 1)
|
require.Len(t, res.Rules, 1)
|
||||||
|
|
||||||
assert.Equal(t, res.Rules[0].IP, net.IPv6loopback)
|
assert.Equal(t, res.Rules[0].IP, net.IPv6loopback)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -205,8 +218,9 @@ func TestSafeSearch(t *testing.T) {
|
||||||
d := newForTest(&Config{SafeSearchEnabled: true}, nil)
|
d := newForTest(&Config{SafeSearchEnabled: true}, nil)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
val, ok := d.SafeSearchDomain("www.google.com")
|
val, ok := d.SafeSearchDomain("www.google.com")
|
||||||
require.True(t, ok, "Expected safesearch to find result for www.google.com")
|
require.True(t, ok)
|
||||||
assert.Equal(t, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com")
|
|
||||||
|
assert.Equal(t, "forcesafesearch.google.com", val)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCheckHostSafeSearchYandex(t *testing.T) {
|
func TestCheckHostSafeSearchYandex(t *testing.T) {
|
||||||
|
@ -226,10 +240,12 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
|
||||||
} {
|
} {
|
||||||
t.Run(strings.ToLower(host), func(t *testing.T) {
|
t.Run(strings.ToLower(host), func(t *testing.T) {
|
||||||
res, err := d.CheckHost(host, dns.TypeA, &setts)
|
res, err := d.CheckHost(host, dns.TypeA, &setts)
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, res.IsFiltered)
|
assert.True(t, res.IsFiltered)
|
||||||
|
|
||||||
require.Len(t, res.Rules, 1)
|
require.Len(t, res.Rules, 1)
|
||||||
|
|
||||||
assert.Equal(t, yandexIP, res.Rules[0].IP)
|
assert.Equal(t, yandexIP, res.Rules[0].IP)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -257,9 +273,12 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
|
||||||
} {
|
} {
|
||||||
t.Run(host, func(t *testing.T) {
|
t.Run(host, func(t *testing.T) {
|
||||||
res, err := d.CheckHost(host, dns.TypeA, &setts)
|
res, err := d.CheckHost(host, dns.TypeA, &setts)
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, res.IsFiltered)
|
assert.True(t, res.IsFiltered)
|
||||||
|
|
||||||
require.Len(t, res.Rules, 1)
|
require.Len(t, res.Rules, 1)
|
||||||
|
|
||||||
assert.Equal(t, ip, res.Rules[0].IP)
|
assert.Equal(t, ip, res.Rules[0].IP)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -272,8 +291,10 @@ func TestSafeSearchCacheYandex(t *testing.T) {
|
||||||
|
|
||||||
// Check host with disabled safesearch.
|
// Check host with disabled safesearch.
|
||||||
res, err := d.CheckHost(domain, dns.TypeA, &setts)
|
res, err := d.CheckHost(domain, dns.TypeA, &setts)
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.False(t, res.IsFiltered)
|
assert.False(t, res.IsFiltered)
|
||||||
|
|
||||||
require.Empty(t, res.Rules)
|
require.Empty(t, res.Rules)
|
||||||
|
|
||||||
yandexIP := net.IPv4(213, 180, 193, 56)
|
yandexIP := net.IPv4(213, 180, 193, 56)
|
||||||
|
@ -282,7 +303,7 @@ func TestSafeSearchCacheYandex(t *testing.T) {
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
|
|
||||||
res, err = d.CheckHost(domain, dns.TypeA, &setts)
|
res, err = d.CheckHost(domain, dns.TypeA, &setts)
|
||||||
require.Nilf(t, err, "CheckHost for safesearh domain %s failed cause %s", domain, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// For yandex we already know valid IP.
|
// For yandex we already know valid IP.
|
||||||
require.Len(t, res.Rules, 1)
|
require.Len(t, res.Rules, 1)
|
||||||
|
@ -292,6 +313,7 @@ func TestSafeSearchCacheYandex(t *testing.T) {
|
||||||
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, domain)
|
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, domain)
|
||||||
require.True(t, isFound)
|
require.True(t, isFound)
|
||||||
require.Len(t, cachedValue.Rules, 1)
|
require.Len(t, cachedValue.Rules, 1)
|
||||||
|
|
||||||
assert.Equal(t, cachedValue.Rules[0].IP, yandexIP)
|
assert.Equal(t, cachedValue.Rules[0].IP, yandexIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -304,8 +326,10 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||||
|
|
||||||
const domain = "www.google.ru"
|
const domain = "www.google.ru"
|
||||||
res, err := d.CheckHost(domain, dns.TypeA, &setts)
|
res, err := d.CheckHost(domain, dns.TypeA, &setts)
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.False(t, res.IsFiltered)
|
assert.False(t, res.IsFiltered)
|
||||||
|
|
||||||
require.Empty(t, res.Rules)
|
require.Empty(t, res.Rules)
|
||||||
|
|
||||||
d = newForTest(&Config{SafeSearchEnabled: true}, nil)
|
d = newForTest(&Config{SafeSearchEnabled: true}, nil)
|
||||||
|
@ -314,10 +338,10 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||||
|
|
||||||
// Lookup for safesearch domain.
|
// Lookup for safesearch domain.
|
||||||
safeDomain, ok := d.SafeSearchDomain(domain)
|
safeDomain, ok := d.SafeSearchDomain(domain)
|
||||||
require.Truef(t, ok, "Failed to get safesearch domain for %s", domain)
|
require.True(t, ok)
|
||||||
|
|
||||||
ips, err := resolver.LookupIP(context.Background(), "ip", safeDomain)
|
ips, err := resolver.LookupIP(context.Background(), "ip", safeDomain)
|
||||||
require.Nilf(t, err, "Failed to lookup for %s", safeDomain)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var ip net.IP
|
var ip net.IP
|
||||||
for _, foundIP := range ips {
|
for _, foundIP := range ips {
|
||||||
|
@ -329,14 +353,16 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err = d.CheckHost(domain, dns.TypeA, &setts)
|
res, err = d.CheckHost(domain, dns.TypeA, &setts)
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
require.Len(t, res.Rules, 1)
|
require.Len(t, res.Rules, 1)
|
||||||
|
|
||||||
assert.True(t, res.Rules[0].IP.Equal(ip))
|
assert.True(t, res.Rules[0].IP.Equal(ip))
|
||||||
|
|
||||||
// Check cache.
|
// Check cache.
|
||||||
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, domain)
|
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, domain)
|
||||||
require.True(t, isFound)
|
require.True(t, isFound)
|
||||||
require.Len(t, cachedValue.Rules, 1)
|
require.Len(t, cachedValue.Rules, 1)
|
||||||
|
|
||||||
assert.True(t, cachedValue.Rules[0].IP.Equal(ip))
|
assert.True(t, cachedValue.Rules[0].IP.Equal(ip))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -357,6 +383,7 @@ func TestParentalControl(t *testing.T) {
|
||||||
|
|
||||||
d.checkMatch(t, matching)
|
d.checkMatch(t, matching)
|
||||||
require.Contains(t, logOutput.String(), "Parental lookup for "+matching)
|
require.Contains(t, logOutput.String(), "Parental lookup for "+matching)
|
||||||
|
|
||||||
d.checkMatch(t, "www."+matching)
|
d.checkMatch(t, "www."+matching)
|
||||||
d.checkMatchEmpty(t, "www.yandex.ru")
|
d.checkMatchEmpty(t, "www.yandex.ru")
|
||||||
d.checkMatchEmpty(t, "yandex.ru")
|
d.checkMatchEmpty(t, "yandex.ru")
|
||||||
|
@ -654,7 +681,8 @@ func TestMatching(t *testing.T) {
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
|
|
||||||
res, err := d.CheckHost(tc.host, tc.wantDNSType, &setts)
|
res, err := d.CheckHost(tc.host, tc.wantDNSType, &setts)
|
||||||
require.Nilf(t, err, "Error while matching host %s: %s", tc.host, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equalf(t, tc.wantIsFiltered, res.IsFiltered, "Hostname %s has wrong result (%v must be %v)", tc.host, res.IsFiltered, tc.wantIsFiltered)
|
assert.Equalf(t, tc.wantIsFiltered, res.IsFiltered, "Hostname %s has wrong result (%v must be %v)", tc.host, res.IsFiltered, tc.wantIsFiltered)
|
||||||
assert.Equalf(t, tc.wantReason, res.Reason, "Hostname %s has wrong reason (%v must be %v)", tc.host, res.Reason, tc.wantReason)
|
assert.Equalf(t, tc.wantReason, res.Reason, "Hostname %s has wrong reason (%v must be %v)", tc.host, res.Reason, tc.wantReason)
|
||||||
})
|
})
|
||||||
|
@ -677,23 +705,31 @@ func TestWhitelist(t *testing.T) {
|
||||||
}}
|
}}
|
||||||
d := newForTest(nil, filters)
|
d := newForTest(nil, filters)
|
||||||
|
|
||||||
require.Nil(t, d.SetFilters(filters, whiteFilters, false))
|
err := d.SetFilters(filters, whiteFilters, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
|
|
||||||
// Matched by white filter.
|
// Matched by white filter.
|
||||||
res, err := d.CheckHost("host1", dns.TypeA, &setts)
|
res, err := d.CheckHost("host1", dns.TypeA, &setts)
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.False(t, res.IsFiltered)
|
assert.False(t, res.IsFiltered)
|
||||||
assert.Equal(t, res.Reason, NotFilteredAllowList)
|
assert.Equal(t, res.Reason, NotFilteredAllowList)
|
||||||
|
|
||||||
require.Len(t, res.Rules, 1)
|
require.Len(t, res.Rules, 1)
|
||||||
|
|
||||||
assert.Equal(t, "||host1^", res.Rules[0].Text)
|
assert.Equal(t, "||host1^", res.Rules[0].Text)
|
||||||
|
|
||||||
// Not matched by white filter, but matched by block filter.
|
// Not matched by white filter, but matched by block filter.
|
||||||
res, err = d.CheckHost("host2", dns.TypeA, &setts)
|
res, err = d.CheckHost("host2", dns.TypeA, &setts)
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, res.IsFiltered)
|
assert.True(t, res.IsFiltered)
|
||||||
assert.Equal(t, res.Reason, FilteredBlockList)
|
assert.Equal(t, res.Reason, FilteredBlockList)
|
||||||
|
|
||||||
require.Len(t, res.Rules, 1)
|
require.Len(t, res.Rules, 1)
|
||||||
|
|
||||||
assert.Equal(t, "||host2^", res.Rules[0].Text)
|
assert.Equal(t, "||host2^", res.Rules[0].Text)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -796,7 +832,8 @@ func BenchmarkSafeBrowsing(b *testing.B) {
|
||||||
})
|
})
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
res, err := d.CheckHost(blocked, dns.TypeA, &setts)
|
res, err := d.CheckHost(blocked, dns.TypeA, &setts)
|
||||||
require.Nilf(b, err, "Error while matching host %s: %s", blocked, err)
|
require.NoError(b, err)
|
||||||
|
|
||||||
assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked)
|
assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -812,7 +849,8 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) {
|
||||||
b.RunParallel(func(pb *testing.PB) {
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
for pb.Next() {
|
for pb.Next() {
|
||||||
res, err := d.CheckHost(blocked, dns.TypeA, &setts)
|
res, err := d.CheckHost(blocked, dns.TypeA, &setts)
|
||||||
require.Nilf(b, err, "Error while matching host %s: %s", blocked, err)
|
require.NoError(b, err)
|
||||||
|
|
||||||
assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked)
|
assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -823,7 +861,8 @@ func BenchmarkSafeSearch(b *testing.B) {
|
||||||
b.Cleanup(d.Close)
|
b.Cleanup(d.Close)
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
val, ok := d.SafeSearchDomain("www.google.com")
|
val, ok := d.SafeSearchDomain("www.google.com")
|
||||||
require.True(b, ok, "Expected safesearch to find result for www.google.com")
|
require.True(b, ok)
|
||||||
|
|
||||||
assert.Equal(b, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com")
|
assert.Equal(b, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -834,7 +873,8 @@ func BenchmarkSafeSearchParallel(b *testing.B) {
|
||||||
b.RunParallel(func(pb *testing.PB) {
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
for pb.Next() {
|
for pb.Next() {
|
||||||
val, ok := d.SafeSearchDomain("www.google.com")
|
val, ok := d.SafeSearchDomain("www.google.com")
|
||||||
require.True(b, ok, "Expected safesearch to find result for www.google.com")
|
require.True(b, ok)
|
||||||
|
|
||||||
assert.Equal(b, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com")
|
assert.Equal(b, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
Loading…
Reference in New Issue