* dnsforward: get per-client settings only once
+ dnsforward: add 'ProtectionEnabled = false' test
This commit is contained in:
parent
b3ddae7f85
commit
0ef8e5cdae
|
@ -445,7 +445,15 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
|
|||
// A better approach is for proxy.Stop() to wait until all its workers exit,
|
||||
// but this would require the Upstream interface to have Close() function
|
||||
// (to prevent from hanging while waiting for unresponsive DNS server to respond).
|
||||
res, err := s.filterDNSRequest(d)
|
||||
|
||||
var setts *dnsfilter.RequestFilteringSettings
|
||||
var err error
|
||||
res := &dnsfilter.Result{}
|
||||
protectionEnabled := s.conf.ProtectionEnabled && s.dnsFilter != nil
|
||||
if protectionEnabled {
|
||||
setts = s.getClientRequestFilteringSettings(d)
|
||||
res, err = s.filterDNSRequest(d, setts)
|
||||
}
|
||||
s.RUnlock()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -486,9 +494,9 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
|
|||
d.Res.Answer = answer
|
||||
}
|
||||
|
||||
} else if res.Reason != dnsfilter.NotFilteredWhiteList {
|
||||
} else if res.Reason != dnsfilter.NotFilteredWhiteList && protectionEnabled {
|
||||
origResp2 := d.Res
|
||||
res, err = s.filterDNSResponse(d)
|
||||
res, err = s.filterDNSResponse(d, setts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -602,12 +610,7 @@ func (s *Server) getClientRequestFilteringSettings(d *proxy.DNSContext) *dnsfilt
|
|||
}
|
||||
|
||||
// filterDNSRequest applies the dnsFilter and sets d.Res if the request was filtered
|
||||
func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error) {
|
||||
if !s.conf.ProtectionEnabled || s.dnsFilter == nil {
|
||||
return &dnsfilter.Result{}, nil
|
||||
}
|
||||
|
||||
setts := s.getClientRequestFilteringSettings(d)
|
||||
func (s *Server) filterDNSRequest(d *proxy.DNSContext, setts *dnsfilter.RequestFilteringSettings) (*dnsfilter.Result, error) {
|
||||
req := d.Req
|
||||
host := strings.TrimSuffix(req.Question[0].Name, ".")
|
||||
res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, setts)
|
||||
|
@ -648,7 +651,7 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error
|
|||
|
||||
// If response contains CNAME, A or AAAA records, we apply filtering to each canonical host name or IP address.
|
||||
// If this is a match, we set a new response in d.Res and return.
|
||||
func (s *Server) filterDNSResponse(d *proxy.DNSContext) (*dnsfilter.Result, error) {
|
||||
func (s *Server) filterDNSResponse(d *proxy.DNSContext, setts *dnsfilter.RequestFilteringSettings) (*dnsfilter.Result, error) {
|
||||
for _, a := range d.Res.Answer {
|
||||
host := ""
|
||||
|
||||
|
@ -676,7 +679,6 @@ func (s *Server) filterDNSResponse(d *proxy.DNSContext) (*dnsfilter.Result, erro
|
|||
s.RUnlock()
|
||||
continue
|
||||
}
|
||||
setts := s.getClientRequestFilteringSettings(d)
|
||||
res, err := s.dnsFilter.CheckHostRules(host, d.Req.Question[0].Qtype, setts)
|
||||
s.RUnlock()
|
||||
|
||||
|
|
|
@ -340,6 +340,22 @@ var testIPv4 = map[string][]net.IP{
|
|||
"example.org.": {{127, 0, 0, 255}},
|
||||
}
|
||||
|
||||
func TestBlockCNAMEProtectionEnabled(t *testing.T) {
|
||||
s := createTestServer(t)
|
||||
testUpstm := &testUpstream{testCNAMEs, testIPv4, nil}
|
||||
s.conf.ProtectionEnabled = false
|
||||
err := s.startWithUpstream(testUpstm)
|
||||
assert.True(t, err == nil)
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
|
||||
// 'badhost' has a canonical name 'null.example.org' which is blocked by filters:
|
||||
// but protection is disabled - response is NOT blocked
|
||||
req := createTestMessage("badhost.")
|
||||
reply, err := dns.Exchange(req, addr.String())
|
||||
assert.True(t, err == nil)
|
||||
assert.True(t, reply.Rcode == dns.RcodeSuccess)
|
||||
}
|
||||
|
||||
func TestBlockCNAME(t *testing.T) {
|
||||
s := createTestServer(t)
|
||||
testUpstm := &testUpstream{testCNAMEs, testIPv4, nil}
|
||||
|
|
Loading…
Reference in New Issue