From b8c4651decc0193f9c2089f5ab1ce0921bcddc2d Mon Sep 17 00:00:00 2001 From: Eugene Burkov Date: Wed, 20 Oct 2021 19:52:13 +0300 Subject: [PATCH] Pull request: 1558 enable dnsrewrites on disabled protection Merge in DNS/adguard-home from 1558-always-rewrite to master Squashed commit of the following: commit b8508b3b5fb688cad273a9259c09ccfc07948b2f Author: Eugene Burkov Date: Wed Oct 20 19:17:22 2021 +0300 all: imp log of changes commit 97e3649b670786a2936e368a9505faf52f8e8804 Author: Eugene Burkov Date: Mon Oct 18 13:18:15 2021 +0300 all: enable dnsrewrites on disabled protection --- CHANGELOG.md | 3 + internal/dnsforward/dns.go | 70 ++++++++-------- internal/dnsforward/dnsforward_test.go | 104 ++++++++++++++---------- internal/dnsforward/filter.go | 55 +++++++------ internal/filtering/filtering.go | 33 +++++--- internal/filtering/filtering_test.go | 16 ++-- internal/filtering/safebrowsing.go | 4 +- internal/filtering/safebrowsing_test.go | 22 ++--- internal/filtering/safesearch.go | 2 +- internal/home/controlfiltering.go | 1 + 10 files changed, 178 insertions(+), 132 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 99821349..aeb1ed2d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -46,6 +46,8 @@ and this project adheres to ### Changed +- `$dnsrewrite` rules and other DNS rewrites will now be applied even when the + protection is disabled ([#1558]). - DHCP gateway address, subnet mask, IP address range, and leases validations ([#3529]). - The `systemd` service script will now create the `/var/log` directory when it @@ -155,6 +157,7 @@ In this release, the schema version has changed from 10 to 12. - Go 1.15 support. [#1381]: https://github.com/AdguardTeam/AdGuardHome/issues/1381 +[#1558]: https://github.com/AdguardTeam/AdGuardHome/issues/1558 [#1691]: https://github.com/AdguardTeam/AdGuardHome/issues/1691 [#1898]: https://github.com/AdguardTeam/AdGuardHome/issues/1898 [#1992]: https://github.com/AdguardTeam/AdGuardHome/issues/1992 diff --git a/internal/dnsforward/dns.go b/internal/dnsforward/dns.go index 4f1b8d3a..a0aa4cb1 100644 --- a/internal/dnsforward/dns.go +++ b/internal/dnsforward/dns.go @@ -90,7 +90,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error { s.processRestrictLocal, s.processInternalIPAddrs, s.processClientID, - processFilteringBeforeRequest, + s.processFilteringBeforeRequest, s.processLocalPTR, s.processUpstream, processDNSSECAfterResponse, @@ -468,19 +468,18 @@ func (s *Server) processLocalPTR(ctx *dnsContext) (rc resultCode) { } // Apply filtering logic -func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) { - s := ctx.srv - d := ctx.proxyCtx - - if d.Res != nil { - return resultCodeSuccess // response is already set - nothing to do +func (s *Server) processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) { + if ctx.proxyCtx.Res != nil { + // Go on since the response is already set. + return resultCodeSuccess } s.serverLock.RLock() defer s.serverLock.RUnlock() - ctx.protectionEnabled = s.conf.ProtectionEnabled && s.dnsFilter != nil - if !ctx.protectionEnabled { + ctx.protectionEnabled = s.conf.ProtectionEnabled + + if s.dnsFilter == nil { return resultCodeSuccess } @@ -489,8 +488,7 @@ func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) { } var err error - ctx.result, err = s.filterDNSRequest(ctx) - if err != nil { + if ctx.result, err = s.filterDNSRequest(ctx); err != nil { ctx.err = err return resultCodeError @@ -608,48 +606,50 @@ func processDNSSECAfterResponse(ctx *dnsContext) (rc resultCode) { func processFilteringAfterResponse(ctx *dnsContext) (rc resultCode) { s := ctx.srv d := ctx.proxyCtx - res := ctx.result - var err error - switch res.Reason { - case filtering.Rewritten, + switch res := ctx.result; res.Reason { + case filtering.NotFilteredAllowList: + // Go on. + case + filtering.Rewritten, filtering.RewrittenRule: if len(ctx.origQuestion.Name) == 0 { - // origQuestion is set in case we get only CNAME without IP from rewrites table + // origQuestion is set in case we get only CNAME without IP from + // rewrites table. break } - d.Req.Question[0] = ctx.origQuestion - d.Res.Question[0] = ctx.origQuestion - - if len(d.Res.Answer) != 0 { - answer := []dns.RR{} - answer = append(answer, s.genAnswerCNAME(d.Req, res.CanonName)) - answer = append(answer, d.Res.Answer...) + d.Req.Question[0], d.Res.Question[0] = ctx.origQuestion, ctx.origQuestion + if len(d.Res.Answer) > 0 { + answer := append([]dns.RR{s.genAnswerCNAME(d.Req, res.CanonName)}, d.Res.Answer...) d.Res.Answer = answer } - - case filtering.NotFilteredAllowList: - // nothing - default: - if !ctx.protectionEnabled || // filters are disabled: there's nothing to check for - !ctx.responseFromUpstream { // only check response if it's from an upstream server + // Check the response only if the it's from an upstream. Don't check + // the response if the protection is disabled since dnsrewrite rules + // aren't applied to it anyway. + if !ctx.protectionEnabled || !ctx.responseFromUpstream || s.dnsFilter == nil { break } - origResp2 := d.Res - ctx.result, err = s.filterDNSResponse(ctx) + + origResp := d.Res + result, err := s.filterDNSResponse(ctx) if err != nil { ctx.err = err + return resultCodeError } - if ctx.result != nil { - ctx.origResp = origResp2 // matched by response - } else { - ctx.result = &filtering.Result{} + + if result != nil { + ctx.result = result + ctx.origResp = origResp } } + if ctx.result == nil { + ctx.result = &filtering.Result{} + } + return resultCodeSuccess } diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index 3ab76a40..a3aa68c6 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -909,6 +909,7 @@ func TestRewrite(t *testing.T) { }}, } f := filtering.New(c, nil) + f.SetEnabled(true) snd, err := aghnet.NewSubnetDetector() require.NoError(t, err) @@ -945,45 +946,56 @@ func TestRewrite(t *testing.T) { addr := s.dnsProxy.Addr(proxy.ProtoUDP) - req := createTestMessageWithType("test.com.", dns.TypeA) - reply, err := dns.Exchange(req, addr.String()) - require.NoError(t, err) + subTestFunc := func(t *testing.T) { + req := createTestMessageWithType("test.com.", dns.TypeA) + reply, eerr := dns.Exchange(req, addr.String()) + require.NoError(t, eerr) - require.Len(t, reply.Answer, 1) + require.Len(t, reply.Answer, 1) - a, ok := reply.Answer[0].(*dns.A) - require.True(t, ok) + a, ok := reply.Answer[0].(*dns.A) + require.True(t, ok) - assert.True(t, net.IP{1, 2, 3, 4}.Equal(a.A)) + assert.True(t, net.IP{1, 2, 3, 4}.Equal(a.A)) - req = createTestMessageWithType("test.com.", dns.TypeAAAA) - reply, err = dns.Exchange(req, addr.String()) - require.NoError(t, err) + req = createTestMessageWithType("test.com.", dns.TypeAAAA) + reply, eerr = dns.Exchange(req, addr.String()) + require.NoError(t, eerr) - assert.Empty(t, reply.Answer) + assert.Empty(t, reply.Answer) - req = createTestMessageWithType("alias.test.com.", dns.TypeA) - reply, err = dns.Exchange(req, addr.String()) - require.NoError(t, err) + req = createTestMessageWithType("alias.test.com.", dns.TypeA) + reply, eerr = dns.Exchange(req, addr.String()) + require.NoError(t, eerr) - require.Len(t, reply.Answer, 2) + require.Len(t, reply.Answer, 2) - assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target) - assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A)) + assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target) + assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A)) - req = createTestMessageWithType("my.alias.example.org.", dns.TypeA) - reply, err = dns.Exchange(req, addr.String()) - require.NoError(t, err) + req = createTestMessageWithType("my.alias.example.org.", dns.TypeA) + reply, eerr = dns.Exchange(req, addr.String()) + require.NoError(t, eerr) - // The original question is restored. - require.Len(t, reply.Question, 1) + // The original question is restored. + require.Len(t, reply.Question, 1) - assert.Equal(t, "my.alias.example.org.", reply.Question[0].Name) + assert.Equal(t, "my.alias.example.org.", reply.Question[0].Name) - require.Len(t, reply.Answer, 2) + require.Len(t, reply.Answer, 2) - assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target) - assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype) + assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target) + assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype) + } + + for _, protect := range []bool{true, false} { + val := protect + conf := s.getDNSConfig() + conf.ProtectionEnabled = &val + s.setConfig(conf) + + t.Run(fmt.Sprintf("protection_is_%t", val), subTestFunc) + } } func publicKey(priv interface{}) interface{} { @@ -1092,9 +1104,10 @@ func TestPTRResponseFromHosts(t *testing.T) { require.ErrorIs(t, hc.Close(), closeCalled) }) - c := filtering.Config{ + flt := filtering.New(&filtering.Config{ EtcHosts: hc, - } + }, nil) + flt.SetEnabled(true) var snd *aghnet.SubnetDetector snd, err = aghnet.NewSubnetDetector() @@ -1104,7 +1117,7 @@ func TestPTRResponseFromHosts(t *testing.T) { var s *Server s, err = NewServer(DNSCreateParams{ DHCPServer: &testDHCP{}, - DNSFilter: filtering.New(&c, nil), + DNSFilter: flt, SubnetDetector: snd, }) require.NoError(t, err) @@ -1112,32 +1125,41 @@ func TestPTRResponseFromHosts(t *testing.T) { s.conf.UDPListenAddrs = []*net.UDPAddr{{}} s.conf.TCPListenAddrs = []*net.TCPAddr{{}} s.conf.UpstreamDNS = []string{"127.0.0.1:53"} - s.conf.FilteringConfig.ProtectionEnabled = true err = s.Prepare(nil) require.NoError(t, err) err = s.Start() require.NoError(t, err) - t.Cleanup(func() { s.Close() }) - addr := s.dnsProxy.Addr(proxy.ProtoUDP) - req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR) + subTestFunc := func(t *testing.T) { + addr := s.dnsProxy.Addr(proxy.ProtoUDP) + req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR) - resp, err := dns.Exchange(req, addr.String()) - require.NoError(t, err) + resp, eerr := dns.Exchange(req, addr.String()) + require.NoError(t, eerr) - require.Lenf(t, resp.Answer, 1, "%#v", resp) + require.Len(t, resp.Answer, 1) - assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype) - assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name) + assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype) + assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name) - ptr, ok := resp.Answer[0].(*dns.PTR) - require.True(t, ok) - assert.Equal(t, "host.", ptr.Ptr) + ptr, ok := resp.Answer[0].(*dns.PTR) + require.True(t, ok) + assert.Equal(t, "host.", ptr.Ptr) + } + + for _, protect := range []bool{true, false} { + val := protect + conf := s.getDNSConfig() + conf.ProtectionEnabled = &val + s.setConfig(conf) + + t.Run(fmt.Sprintf("protection_is_%t", val), subTestFunc) + } } func TestNewServer(t *testing.T) { diff --git a/internal/dnsforward/filter.go b/internal/dnsforward/filter.go index 5edca948..7300b43c 100644 --- a/internal/dnsforward/filter.go +++ b/internal/dnsforward/filter.go @@ -52,6 +52,7 @@ func (s *Server) beforeRequestHandler( // the client's IP address and ID, if any, from ctx. func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *filtering.Settings { setts := s.dnsFilter.GetConfig() + setts.ProtectionEnabled = ctx.protectionEnabled if s.conf.FilterHandler != nil { ip, _ := netutil.IPAndPortFromAddr(ctx.proxyCtx.Addr) s.conf.FilterHandler(ip, ctx.clientID, &setts) @@ -65,32 +66,31 @@ func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *filtering.S func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) { d := ctx.proxyCtx req := d.Req - host := strings.TrimSuffix(req.Question[0].Name, ".") - res, err := s.dnsFilter.CheckHost(host, req.Question[0].Qtype, ctx.setts) - if err != nil { - // Return immediately if there's an error - return nil, fmt.Errorf("filtering failed to check host %q: %w", host, err) - } else if res.IsFiltered { - log.Tracef("Host %s is filtered, reason - %q, matched rule: %q", host, res.Reason, res.Rules[0].Text) + q := req.Question[0] + host := strings.TrimSuffix(q.Name, ".") + res, err := s.dnsFilter.CheckHost(host, q.Qtype, ctx.setts) + switch { + case err != nil: + return nil, fmt.Errorf("failed to check host %q: %w", host, err) + case res.IsFiltered: + log.Tracef("host %q is filtered, reason %q, rule: %q", host, res.Reason, res.Rules[0].Text) d.Res = s.genDNSFilterMessage(d, &res) - } else if res.Reason.In(filtering.Rewritten, filtering.RewrittenRule) && + case res.Reason.In(filtering.Rewritten, filtering.RewrittenRule) && res.CanonName != "" && - len(res.IPList) == 0 { - // Resolve the new canonical name, not the original host - // name. The original question is readded in - // processFilteringAfterResponse. - ctx.origQuestion = req.Question[0] + len(res.IPList) == 0: + // Resolve the new canonical name, not the original host name. The + // original question is readded in processFilteringAfterResponse. + ctx.origQuestion = q req.Question[0].Name = dns.Fqdn(res.CanonName) - } else if res.Reason == filtering.RewrittenAutoHosts && len(res.ReverseHosts) != 0 { + case res.Reason == filtering.RewrittenAutoHosts && len(res.ReverseHosts) != 0: resp := s.makeResponse(req) + hdr := dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypePTR, + Ttl: s.conf.BlockedResponseTTL, + Class: dns.ClassINET, + } for _, h := range res.ReverseHosts { - hdr := dns.RR_Header{ - Name: req.Question[0].Name, - Rrtype: dns.TypePTR, - Ttl: s.conf.BlockedResponseTTL, - Class: dns.ClassINET, - } - ptr := &dns.PTR{ Hdr: hdr, Ptr: h, @@ -100,7 +100,7 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) { } d.Res = resp - } else if res.Reason.In(filtering.Rewritten, filtering.RewrittenAutoHosts) { + case res.Reason.In(filtering.Rewritten, filtering.RewrittenAutoHosts): resp := s.makeResponse(req) name := host @@ -110,11 +110,12 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) { } for _, ip := range res.IPList { - if req.Question[0].Qtype == dns.TypeA { + switch q.Qtype { + case dns.TypeA: a := s.genAnswerA(req, ip.To4()) a.Hdr.Name = dns.Fqdn(name) resp.Answer = append(resp.Answer, a) - } else if req.Question[0].Qtype == dns.TypeAAAA { + case dns.TypeAAAA: a := s.genAnswerAAAA(req, ip) a.Hdr.Name = dns.Fqdn(name) resp.Answer = append(resp.Answer, a) @@ -122,9 +123,8 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) { } d.Res = resp - } else if res.Reason == filtering.RewrittenRule { - err = s.filterDNSRewrite(req, res, d) - if err != nil { + case res.Reason == filtering.RewrittenRule: + if err = s.filterDNSRewrite(req, res, d); err != nil { return nil, err } } @@ -179,6 +179,7 @@ func (s *Server) filterDNSResponse(ctx *dnsContext) (*filtering.Result, error) { continue } + host = strings.TrimSuffix(host, ".") res, err := s.checkHostRules(host, d.Req.Question[0].Qtype, ctx.setts) if err != nil { return nil, err diff --git a/internal/filtering/filtering.go b/internal/filtering/filtering.go index f7a8ebe4..87d0da15 100644 --- a/internal/filtering/filtering.go +++ b/internal/filtering/filtering.go @@ -38,6 +38,7 @@ type Settings struct { ServicesRules []ServiceEntry + ProtectionEnabled bool FilteringEnabled bool SafeSearchEnabled bool SafeBrowsingEnabled bool @@ -221,12 +222,13 @@ func (r Reason) String() string { } // In returns true if reasons include r. -func (r Reason) In(reasons ...Reason) bool { +func (r Reason) In(reasons ...Reason) (ok bool) { for _, reason := range reasons { if r == reason { return true } } + return false } @@ -245,7 +247,7 @@ func (d *DNSFilter) GetConfig() (s Settings) { defer d.confLock.RUnlock() return Settings{ - FilteringEnabled: atomic.LoadUint32(&d.Config.enabled) == 1, + FilteringEnabled: atomic.LoadUint32(&d.Config.enabled) != 0, SafeSearchEnabled: d.Config.SafeSearchEnabled, SafeBrowsingEnabled: d.Config.SafeBrowsingEnabled, ParentalEnabled: d.Config.ParentalEnabled, @@ -421,14 +423,16 @@ func (d *DNSFilter) CheckHost( // Sometimes clients try to resolve ".", which is a request to get root // servers. if host == "" { - return Result{Reason: NotFilteredNotFound}, nil + return Result{}, nil } host = strings.ToLower(host) - res = d.processRewrites(host, qtype) - if res.Reason == Rewritten { - return res, nil + if setts.FilteringEnabled { + res = d.processRewrites(host, qtype) + if res.Reason == Rewritten { + return res, nil + } } for _, hc := range d.hostCheckers { @@ -448,7 +452,7 @@ func (d *DNSFilter) CheckHost( // matchSysHosts tries to match the host against the operating system's hosts // database. func (d *DNSFilter) matchSysHosts(host string, qtype uint16, setts *Settings) (res Result, err error) { - if d.EtcHosts == nil { + if !setts.FilteringEnabled || d.EtcHosts == nil { return Result{}, nil } @@ -468,10 +472,8 @@ func (d *DNSFilter) matchSysHosts(host string, qtype uint16, setts *Settings) (r var ips []net.IP var revHosts []string - for _, nr := range dnsr { - dr := nr.DNSRewrite - if dr == nil { + if nr.DNSRewrite == nil { continue } @@ -553,6 +555,10 @@ func matchBlockedServicesRules( _ uint16, setts *Settings, ) (res Result, err error) { + if !setts.ProtectionEnabled { + return Result{}, nil + } + svcs := setts.ServicesRules if len(svcs) == 0 { return Result{}, nil @@ -784,7 +790,7 @@ func (d *DNSFilter) matchHost( // TODO(e.burkov): Inspect if the above is true. defer d.engineLock.RUnlock() - if d.filteringEngineAllow != nil { + if setts.ProtectionEnabled && d.filteringEngineAllow != nil { dnsres, ok := d.filteringEngineAllow.MatchRequest(ureq) if ok { return d.matchHostProcessAllowList(host, dnsres) @@ -810,6 +816,11 @@ func (d *DNSFilter) matchHost( return Result{}, nil } + if !setts.ProtectionEnabled { + // Don't check non-dnsrewrite filtering results. + return Result{}, nil + } + res = d.matchHostProcessDNSResult(qtype, dnsres) for _, r := range res.Rules { log.Debug( diff --git a/internal/filtering/filtering_test.go b/internal/filtering/filtering_test.go index 746e9ed0..b389b386 100644 --- a/internal/filtering/filtering_test.go +++ b/internal/filtering/filtering_test.go @@ -21,7 +21,9 @@ func TestMain(m *testing.M) { aghtest.DiscardLogOutput(m) } -var setts Settings +var setts = Settings{ + ProtectionEnabled: true, +} // Helpers. @@ -39,9 +41,9 @@ func purgeCaches() { func newForTest(c *Config, filters []Filter) *DNSFilter { setts = Settings{ - FilteringEnabled: true, + ProtectionEnabled: true, + FilteringEnabled: true, } - setts.FilteringEnabled = true if c != nil { c.SafeBrowsingCacheSize = 10000 c.ParentalCacheSize = 10000 @@ -797,7 +799,11 @@ func TestClientSettings(t *testing.T) { makeTester := func(tc testCase, before bool) func(t *testing.T) { return func(t *testing.T) { - r, _ := d.CheckHost(tc.host, dns.TypeA, &setts) + t.Helper() + + r, err := d.CheckHost(tc.host, dns.TypeA, &setts) + require.NoError(t, err) + if before { assert.True(t, r.IsFiltered) assert.Equal(t, tc.wantReason, r.Reason) @@ -808,7 +814,7 @@ func TestClientSettings(t *testing.T) { } // Check behaviour without any per-client settings, then apply per-client - // settings and check behaviour once again. + // settings and check behavior once again. for _, tc := range testCases { t.Run(tc.name, makeTester(tc, tc.before)) } diff --git a/internal/filtering/safebrowsing.go b/internal/filtering/safebrowsing.go index d535a39d..ec626315 100644 --- a/internal/filtering/safebrowsing.go +++ b/internal/filtering/safebrowsing.go @@ -306,7 +306,7 @@ func (d *DNSFilter) checkSafeBrowsing( _ uint16, setts *Settings, ) (res Result, err error) { - if !setts.SafeBrowsingEnabled { + if !setts.ProtectionEnabled || !setts.SafeBrowsingEnabled { return Result{}, nil } @@ -339,7 +339,7 @@ func (d *DNSFilter) checkParental( _ uint16, setts *Settings, ) (res Result, err error) { - if !setts.ParentalEnabled { + if !setts.ProtectionEnabled || !setts.ParentalEnabled { return Result{}, nil } diff --git a/internal/filtering/safebrowsing_test.go b/internal/filtering/safebrowsing_test.go index d513c0b2..c88576f1 100644 --- a/internal/filtering/safebrowsing_test.go +++ b/internal/filtering/safebrowsing_test.go @@ -117,6 +117,7 @@ func TestSBPC_checkErrorUpstream(t *testing.T) { d.SetParentalUpstream(ups) setts := &Settings{ + ProtectionEnabled: true, SafeBrowsingEnabled: true, ParentalEnabled: true, } @@ -135,35 +136,36 @@ func TestSBPC(t *testing.T) { const hostname = "example.org" setts := &Settings{ + ProtectionEnabled: true, SafeBrowsingEnabled: true, ParentalEnabled: true, } testCases := []struct { + testCache cache.Cache + testFunc func(host string, _ uint16, _ *Settings) (res Result, err error) name string block bool - testFunc func(host string, _ uint16, _ *Settings) (res Result, err error) - testCache cache.Cache }{{ + testCache: gctx.safebrowsingCache, + testFunc: d.checkSafeBrowsing, name: "sb_no_block", block: false, - testFunc: d.checkSafeBrowsing, - testCache: gctx.safebrowsingCache, }, { + testCache: gctx.safebrowsingCache, + testFunc: d.checkSafeBrowsing, name: "sb_block", block: true, - testFunc: d.checkSafeBrowsing, - testCache: gctx.safebrowsingCache, }, { + testCache: gctx.parentalCache, + testFunc: d.checkParental, name: "pc_no_block", block: false, - testFunc: d.checkParental, - testCache: gctx.parentalCache, }, { + testCache: gctx.parentalCache, + testFunc: d.checkParental, name: "pc_block", block: true, - testFunc: d.checkParental, - testCache: gctx.parentalCache, }} for _, tc := range testCases { diff --git a/internal/filtering/safesearch.go b/internal/filtering/safesearch.go index db9e9ee9..ff89b950 100644 --- a/internal/filtering/safesearch.go +++ b/internal/filtering/safesearch.go @@ -74,7 +74,7 @@ func (d *DNSFilter) checkSafeSearch( _ uint16, setts *Settings, ) (res Result, err error) { - if !setts.SafeSearchEnabled { + if !setts.ProtectionEnabled || !setts.SafeSearchEnabled { return Result{}, nil } diff --git a/internal/home/controlfiltering.go b/internal/home/controlfiltering.go index 550b4b87..ee19fd16 100644 --- a/internal/home/controlfiltering.go +++ b/internal/home/controlfiltering.go @@ -404,6 +404,7 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) { setts := Context.dnsFilter.GetConfig() setts.FilteringEnabled = true + setts.ProtectionEnabled = true Context.dnsFilter.ApplyBlockedServices(&setts, nil, true) result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts) if err != nil {