Pull request: 3842 ptr filtering

Merge in DNS/adguard-home from 3842-fix-ptr-restrict to master

Updates #3842.

Squashed commit of the following:

commit 77bbec41c5238f8fcb0d2bb8d11910d1ac521fcd
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Mon Nov 15 17:34:14 2021 +0300

    dnsforward: imp docs

commit c637276b5a53f5301387b7dc3035e265d0bc9418
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Mon Nov 15 15:41:39 2021 +0300

    dnsforward: fix local ptr blocking
This commit is contained in:
Eugene Burkov 2021-11-15 17:42:10 +03:00
parent 4f257a1cfc
commit 9c8e087544
3 changed files with 39 additions and 43 deletions

View File

@ -16,9 +16,6 @@ import (
// To transfer information between modules // To transfer information between modules
type dnsContext struct { type dnsContext struct {
// TODO(a.garipov): Remove this and rewrite processors to be methods of
// *Server instead.
srv *Server
proxyCtx *proxy.DNSContext proxyCtx *proxy.DNSContext
// setts are the filtering settings for the client. // setts are the filtering settings for the client.
setts *filtering.Settings setts *filtering.Settings
@ -28,7 +25,8 @@ type dnsContext struct {
// response is modified by filters. // response is modified by filters.
origResp *dns.Msg origResp *dns.Msg
// unreversedReqIP stores an IP address obtained from PTR request if it // unreversedReqIP stores an IP address obtained from PTR request if it
// was successfully parsed. // parsed successfully and belongs to one of locally-served IP ranges as per
// RFC 6303.
unreversedReqIP net.IP unreversedReqIP net.IP
// err is the error returned from a processing function. // err is the error returned from a processing function.
err error err error
@ -69,7 +67,6 @@ const (
// handleDNSRequest filters the incoming DNS requests and writes them to the query log // handleDNSRequest filters the incoming DNS requests and writes them to the query log
func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error { func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
ctx := &dnsContext{ ctx := &dnsContext{
srv: s,
proxyCtx: d, proxyCtx: d,
result: &filtering.Result{}, result: &filtering.Result{},
startTime: time.Now(), startTime: time.Now(),
@ -84,7 +81,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
// appropriate handler. // appropriate handler.
mods := []modProcessFunc{ mods := []modProcessFunc{
s.processRecursion, s.processRecursion,
processInitial, s.processInitial,
s.processDetermineLocal, s.processDetermineLocal,
s.processInternalHosts, s.processInternalHosts,
s.processRestrictLocal, s.processRestrictLocal,
@ -93,10 +90,10 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
s.processFilteringBeforeRequest, s.processFilteringBeforeRequest,
s.processLocalPTR, s.processLocalPTR,
s.processUpstream, s.processUpstream,
processDNSSECAfterResponse, s.processDNSSECAfterResponse,
processFilteringAfterResponse, s.processFilteringAfterResponse,
s.ipset.process, s.ipset.process,
processQueryLogsAndStats, s.processQueryLogsAndStats,
} }
for _, process := range mods { for _, process := range mods {
r := process(ctx) r := process(ctx)
@ -135,8 +132,7 @@ func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) {
} }
// Perform initial checks; process WHOIS & rDNS // Perform initial checks; process WHOIS & rDNS
func processInitial(ctx *dnsContext) (rc resultCode) { func (s *Server) processInitial(ctx *dnsContext) (rc resultCode) {
s := ctx.srv
d := ctx.proxyCtx d := ctx.proxyCtx
if s.conf.AAAADisabled && d.Req.Question[0].Qtype == dns.TypeAAAA { if s.conf.AAAADisabled && d.Req.Question[0].Qtype == dns.TypeAAAA {
_ = proxy.CheckDisabledAAAARequest(d, true) _ = proxy.CheckDisabledAAAARequest(d, true)
@ -155,6 +151,9 @@ func processInitial(ctx *dnsContext) (rc resultCode) {
return resultCodeFinish return resultCodeFinish
} }
ctx.protectionEnabled = s.conf.ProtectionEnabled
ctx.setts = s.getClientRequestFilteringSettings(ctx)
return resultCodeSuccess return resultCodeSuccess
} }
@ -339,10 +338,16 @@ func (s *Server) processRestrictLocal(ctx *dnsContext) (rc resultCode) {
} }
// Restrict an access to local addresses for external clients. We also // Restrict an access to local addresses for external clients. We also
// assume that all the DHCP leases we give are locally-served or at // assume that all the DHCP leases we give are locally-served or at least
// least don't need to be inaccessible externally. // don't need to be inaccessible externally.
if s.subnetDetector.IsLocallyServedNetwork(ip) && !ctx.isLocalClient { if !s.subnetDetector.IsLocallyServedNetwork(ip) {
log.Debug("dns: %q requests for internal ip", d.Addr) log.Debug("dns: addr %s is not from locally-served network", ip)
return resultCodeSuccess
}
if !ctx.isLocalClient {
log.Debug("dns: %q requests an internal ip", d.Addr)
d.Res = s.genNXDomain(req) d.Res = s.genNXDomain(req)
// Do not even put into query log. // Do not even put into query log.
@ -352,13 +357,13 @@ func (s *Server) processRestrictLocal(ctx *dnsContext) (rc resultCode) {
// Do not perform unreversing ever again. // Do not perform unreversing ever again.
ctx.unreversedReqIP = ip ctx.unreversedReqIP = ip
// Disable redundant filtering. // There is no need to filter request from external addresses since this
filterSetts := s.getClientRequestFilteringSettings(ctx) // code is only executed when the request is for locally-served ARPA
filterSetts.ParentalEnabled = false // hostname so disable redundant filters.
filterSetts.SafeBrowsingEnabled = false ctx.setts.ParentalEnabled = false
filterSetts.SafeSearchEnabled = false ctx.setts.SafeBrowsingEnabled = false
filterSetts.ServicesRules = nil ctx.setts.SafeSearchEnabled = false
ctx.setts = filterSetts ctx.setts.ServicesRules = nil
// Nothing to restrict. // Nothing to restrict.
return resultCodeSuccess return resultCodeSuccess
@ -475,16 +480,10 @@ func (s *Server) processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode)
s.serverLock.RLock() s.serverLock.RLock()
defer s.serverLock.RUnlock() defer s.serverLock.RUnlock()
ctx.protectionEnabled = s.conf.ProtectionEnabled
if s.dnsFilter == nil { if s.dnsFilter == nil {
return resultCodeSuccess return resultCodeSuccess
} }
if ctx.setts == nil {
ctx.setts = s.getClientRequestFilteringSettings(ctx)
}
var err error var err error
if ctx.result, err = s.filterDNSRequest(ctx); err != nil { if ctx.result, err = s.filterDNSRequest(ctx); err != nil {
ctx.err = err ctx.err = err
@ -555,11 +554,11 @@ func (s *Server) processUpstream(ctx *dnsContext) (rc resultCode) {
} }
// Process DNSSEC after response from upstream server // Process DNSSEC after response from upstream server
func processDNSSECAfterResponse(ctx *dnsContext) (rc resultCode) { func (s *Server) processDNSSECAfterResponse(ctx *dnsContext) (rc resultCode) {
d := ctx.proxyCtx d := ctx.proxyCtx
if !ctx.responseFromUpstream || // don't process response if it's not from upstream servers // Don't process response if it's not from upstream servers.
!ctx.srv.conf.EnableDNSSEC { if !ctx.responseFromUpstream || !s.conf.EnableDNSSEC {
return resultCodeSuccess return resultCodeSuccess
} }
@ -601,8 +600,7 @@ func processDNSSECAfterResponse(ctx *dnsContext) (rc resultCode) {
} }
// Apply filtering logic after we have received response from upstream servers // Apply filtering logic after we have received response from upstream servers
func processFilteringAfterResponse(ctx *dnsContext) (rc resultCode) { func (s *Server) processFilteringAfterResponse(ctx *dnsContext) (rc resultCode) {
s := ctx.srv
d := ctx.proxyCtx d := ctx.proxyCtx
switch res := ctx.result; res.Reason { switch res := ctx.result; res.Reason {

View File

@ -13,9 +13,8 @@ import (
) )
// Write Stats data and logs // Write Stats data and logs
func processQueryLogsAndStats(ctx *dnsContext) (rc resultCode) { func (s *Server) processQueryLogsAndStats(ctx *dnsContext) (rc resultCode) {
elapsed := time.Since(ctx.startTime) elapsed := time.Since(ctx.startTime)
s := ctx.srv
pctx := ctx.proxyCtx pctx := ctx.proxyCtx
shouldLog := true shouldLog := true

View File

@ -160,6 +160,12 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
for _, tc := range testCases { for _, tc := range testCases {
ql := &testQueryLog{}
st := &testStats{}
srv := &Server{
queryLog: ql,
stats: st,
}
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
req := &dns.Msg{ req := &dns.Msg{
Question: []dns.Question{{ Question: []dns.Question{{
@ -173,14 +179,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
Addr: tc.addr, Addr: tc.addr,
Upstream: ups, Upstream: ups,
} }
ql := &testQueryLog{}
st := &testStats{}
dctx := &dnsContext{ dctx := &dnsContext{
srv: &Server{
queryLog: ql,
stats: st,
},
proxyCtx: pctx, proxyCtx: pctx,
startTime: time.Now(), startTime: time.Now(),
result: &filtering.Result{ result: &filtering.Result{
@ -189,7 +188,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
clientID: tc.clientID, clientID: tc.clientID,
} }
code := processQueryLogsAndStats(dctx) code := srv.processQueryLogsAndStats(dctx)
assert.Equal(t, tc.wantCode, code) assert.Equal(t, tc.wantCode, code)
assert.Equal(t, tc.wantLogProto, ql.lastParams.ClientProto) assert.Equal(t, tc.wantLogProto, ql.lastParams.ClientProto)
assert.Equal(t, tc.wantStatClient, st.lastEntry.Client) assert.Equal(t, tc.wantStatClient, st.lastEntry.Client)