diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index f87ee219..0b4e5081 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -593,22 +593,28 @@ func (s *Server) updateStats(d *proxy.DNSContext, elapsed time.Duration, res dns s.stats.Update(e) } -// filterDNSRequest applies the dnsFilter and sets d.Res if the request was filtered -func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error) { - if !s.conf.ProtectionEnabled || s.dnsFilter == nil { - return &dnsfilter.Result{}, nil - } - +// getClientRequestFilteringSettings lookups client filtering settings +// using the client's IP address from the DNSContext +func (s *Server) getClientRequestFilteringSettings(d *proxy.DNSContext) *dnsfilter.RequestFilteringSettings { setts := s.dnsFilter.GetConfig() setts.FilteringEnabled = true if s.conf.FilterHandler != nil { clientAddr := ipFromAddr(d.Addr) s.conf.FilterHandler(clientAddr, &setts) } + return &setts +} +// filterDNSRequest applies the dnsFilter and sets d.Res if the request was filtered +func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error) { + if !s.conf.ProtectionEnabled || s.dnsFilter == nil { + return &dnsfilter.Result{}, nil + } + + setts := s.getClientRequestFilteringSettings(d) req := d.Req host := strings.TrimSuffix(req.Question[0].Name, ".") - res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, &setts) + res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, setts) if err != nil { // Return immediately if there's an error return nil, errorx.Decorate(err, "dnsfilter failed to check host '%s'", host) @@ -631,7 +637,6 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error a := s.genAAnswer(req, ip) a.Hdr.Name = dns.Fqdn(name) resp.Answer = append(resp.Answer, a) - } else if req.Question[0].Qtype == dns.TypeAAAA { a := s.genAAAAAnswer(req, ip) a.Hdr.Name = dns.Fqdn(name) @@ -675,9 +680,8 @@ func (s *Server) filterResponse(d *proxy.DNSContext) (*dnsfilter.Result, error) s.RUnlock() continue } - setts := dnsfilter.RequestFilteringSettings{} - setts.FilteringEnabled = true - res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, &setts) + setts := s.getClientRequestFilteringSettings(d) + res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, setts) s.RUnlock() if err != nil { diff --git a/dnsforward/dnsforward_test.go b/dnsforward/dnsforward_test.go index 88f7fb78..76f8f028 100644 --- a/dnsforward/dnsforward_test.go +++ b/dnsforward/dnsforward_test.go @@ -384,6 +384,30 @@ func TestBlockCNAME(t *testing.T) { _ = s.Stop() } +func TestClientRulesForCNAMEMatching(t *testing.T) { + s := createTestServer(t) + testUpstm := &testUpstream{testCNAMEs, testIPv4, nil} + s.conf.FilterHandler = func(clientAddr string, settings *dnsfilter.RequestFilteringSettings) { + settings.FilteringEnabled = false + } + err := s.startWithUpstream(testUpstm) + assert.Nil(t, err) + addr := s.dnsProxy.Addr(proxy.ProtoUDP) + + // 'badhost' has a canonical name 'null.example.org' which is blocked by filters: + // response is blocked + req := dns.Msg{} + req.Id = dns.Id() + req.Question = []dns.Question{ + {Name: "badhost.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + } + // However, in our case it should not be blocked + // as filtering is disabled on the client level + reply, err := dns.Exchange(&req, addr.String()) + assert.Nil(t, err) + assert.Equal(t, dns.RcodeSuccess, reply.Rcode) +} + func TestNullBlockedRequest(t *testing.T) { s := createTestServer(t) s.conf.FilteringConfig.BlockingMode = "null_ip" @@ -563,7 +587,11 @@ func TestBlockedBySafeBrowsing(t *testing.T) { } func createTestServer(t *testing.T) *Server { - rules := "||nxdomain.example.org^\n||null.example.org^\n127.0.0.1 host.example.org\n@@||whitelist.example.org^\n||127.0.0.255\n" + rules := `||nxdomain.example.org +||null.example.org^ +127.0.0.1 host.example.org +@@||whitelist.example.org^ +||127.0.0.255` filters := map[int]string{} filters[0] = rules c := dnsfilter.Config{}