Merge: * dns: refactor
Squashed commit of the following:
commit e9469266cafa3df537b5a4d5e28ca51db8289a34
Merge: 17cf6d60 e7e946fa
Author: Simon Zolin <s.zolin@adguard.com>
Date: Tue Jan 21 13:04:30 2020 +0300
Merge remote-tracking branch 'origin/master' into refactor
commit 17cf6d60d11602df3837316119ba8828f41a95df
Author: Simon Zolin <s.zolin@adguard.com>
Date: Mon Jan 20 15:25:43 2020 +0300
minor
commit 7b79462ebbeb743a10417bd28ceb70262ff9fa5c
Author: Simon Zolin <s.zolin@adguard.com>
Date: Fri Jan 17 17:50:09 2020 +0300
minor
commit d8b175c7eda36005c0277e7876f0f0a55a661b05
Author: Simon Zolin <s.zolin@adguard.com>
Date: Fri Jan 17 15:30:37 2020 +0300
minor
commit 93370aa32aa560d42fc67c95fd13f027ddc01b94
Author: Simon Zolin <s.zolin@adguard.com>
Date: Fri Jan 17 14:28:14 2020 +0300
* dns: refactor
. introduce a local context object
. move filtering, upstream logic, stats, querylog code to separate functions
This commit is contained in:
parent
e7e946faa6
commit
3f7e2f7241
|
@ -425,14 +425,33 @@ func (s *Server) beforeRequestHandler(p *proxy.Proxy, d *proxy.DNSContext) (bool
|
|||
return true, nil
|
||||
}
|
||||
|
||||
// handleDNSRequest filters the incoming DNS requests and writes them to the query log
|
||||
// nolint (gocyclo)
|
||||
func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
|
||||
start := time.Now()
|
||||
// To transfer information between modules
|
||||
type dnsContext struct {
|
||||
srv *Server
|
||||
proxyCtx *proxy.DNSContext
|
||||
setts *dnsfilter.RequestFilteringSettings // filtering settings for this client
|
||||
startTime time.Time
|
||||
result *dnsfilter.Result
|
||||
origResp *dns.Msg // response received from upstream servers. Set when response is modified by filtering
|
||||
origQuestion dns.Question // question received from client. Set when Rewrites are used.
|
||||
err error // error returned from the module
|
||||
protectionEnabled bool // filtering is enabled, dnsfilter object is ready
|
||||
responseFromUpstream bool // response is received from upstream servers
|
||||
}
|
||||
|
||||
const (
|
||||
resultDone = iota // module has completed its job, continue
|
||||
resultFinish // module has completed its job, exit normally
|
||||
resultError // an error occurred, exit with an error
|
||||
)
|
||||
|
||||
// Perform initial checks; process WHOIS & rDNS
|
||||
func processInitial(ctx *dnsContext) int {
|
||||
s := ctx.srv
|
||||
d := ctx.proxyCtx
|
||||
if s.conf.AAAADisabled && d.Req.Question[0].Qtype == dns.TypeAAAA {
|
||||
_ = proxy.CheckDisabledAAAARequest(d, true)
|
||||
return nil
|
||||
return resultFinish
|
||||
}
|
||||
|
||||
if s.conf.OnDNSRequest != nil {
|
||||
|
@ -443,10 +462,17 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
|
|||
if (d.Req.Question[0].Qtype == dns.TypeA || d.Req.Question[0].Qtype == dns.TypeAAAA) &&
|
||||
d.Req.Question[0].Name == "use-application-dns.net." {
|
||||
d.Res = s.genNXDomain(d.Req)
|
||||
return nil
|
||||
return resultFinish
|
||||
}
|
||||
|
||||
// use dnsfilter before cache -- changed settings or filters would require cache invalidation otherwise
|
||||
return resultDone
|
||||
}
|
||||
|
||||
// Apply filtering logic
|
||||
func processFilteringBeforeRequest(ctx *dnsContext) int {
|
||||
s := ctx.srv
|
||||
d := ctx.proxyCtx
|
||||
|
||||
s.RLock()
|
||||
// Synchronize access to s.dnsFilter so it won't be suddenly uninitialized while in use.
|
||||
// This could happen after proxy server has been stopped, but its workers are not yet exited.
|
||||
|
@ -455,72 +481,94 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
|
|||
// but this would require the Upstream interface to have Close() function
|
||||
// (to prevent from hanging while waiting for unresponsive DNS server to respond).
|
||||
|
||||
var setts *dnsfilter.RequestFilteringSettings
|
||||
var err error
|
||||
res := &dnsfilter.Result{}
|
||||
protectionEnabled := s.conf.ProtectionEnabled && s.dnsFilter != nil
|
||||
if protectionEnabled {
|
||||
setts = s.getClientRequestFilteringSettings(d)
|
||||
res, err = s.filterDNSRequest(d, setts)
|
||||
ctx.protectionEnabled = s.conf.ProtectionEnabled && s.dnsFilter != nil
|
||||
if ctx.protectionEnabled {
|
||||
ctx.setts = s.getClientRequestFilteringSettings(d)
|
||||
ctx.result, err = s.filterDNSRequest(ctx)
|
||||
}
|
||||
s.RUnlock()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var origResp *dns.Msg
|
||||
if d.Res == nil {
|
||||
answer := []dns.RR{}
|
||||
originalQuestion := d.Req.Question[0]
|
||||
|
||||
if res.Reason == dnsfilter.ReasonRewrite && len(res.CanonName) != 0 {
|
||||
answer = append(answer, s.genCNAMEAnswer(d.Req, res.CanonName))
|
||||
// resolve canonical name, not the original host name
|
||||
d.Req.Question[0].Name = dns.Fqdn(res.CanonName)
|
||||
}
|
||||
|
||||
if d.Addr != nil && s.conf.GetUpstreamsByClient != nil {
|
||||
clientIP := ipFromAddr(d.Addr)
|
||||
upstreams := s.conf.GetUpstreamsByClient(clientIP)
|
||||
if len(upstreams) > 0 {
|
||||
log.Debug("Using custom upstreams for %s", clientIP)
|
||||
d.Upstreams = upstreams
|
||||
}
|
||||
}
|
||||
|
||||
// request was not filtered so let it be processed further
|
||||
err = p.Resolve(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if res.Reason == dnsfilter.ReasonRewrite && len(res.CanonName) != 0 {
|
||||
d.Req.Question[0] = originalQuestion
|
||||
d.Res.Question[0] = originalQuestion
|
||||
|
||||
if len(d.Res.Answer) != 0 {
|
||||
answer = append(answer, d.Res.Answer...) // host -> IP
|
||||
d.Res.Answer = answer
|
||||
}
|
||||
|
||||
} else if res.Reason != dnsfilter.NotFilteredWhiteList && protectionEnabled {
|
||||
origResp2 := d.Res
|
||||
res, err = s.filterDNSResponse(d, setts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if res != nil {
|
||||
origResp = origResp2 // matched by response
|
||||
} else {
|
||||
res = &dnsfilter.Result{}
|
||||
}
|
||||
}
|
||||
ctx.err = err
|
||||
return resultError
|
||||
}
|
||||
return resultDone
|
||||
}
|
||||
|
||||
// Pass request to upstream servers; process the response
|
||||
func processUpstream(ctx *dnsContext) int {
|
||||
s := ctx.srv
|
||||
d := ctx.proxyCtx
|
||||
if d.Res != nil {
|
||||
d.Res.Compress = true // some devices require DNS message compression
|
||||
return resultDone // response is already set - nothing to do
|
||||
}
|
||||
|
||||
if d.Addr != nil && s.conf.GetUpstreamsByClient != nil {
|
||||
clientIP := ipFromAddr(d.Addr)
|
||||
upstreams := s.conf.GetUpstreamsByClient(clientIP)
|
||||
if len(upstreams) > 0 {
|
||||
log.Debug("Using custom upstreams for %s", clientIP)
|
||||
d.Upstreams = upstreams
|
||||
}
|
||||
}
|
||||
|
||||
// request was not filtered so let it be processed further
|
||||
err := s.dnsProxy.Resolve(d)
|
||||
if err != nil {
|
||||
ctx.err = err
|
||||
return resultError
|
||||
}
|
||||
|
||||
ctx.responseFromUpstream = true
|
||||
return resultDone
|
||||
}
|
||||
|
||||
// Apply filtering logic after we have received response from upstream servers
|
||||
func processFilteringAfterResponse(ctx *dnsContext) int {
|
||||
s := ctx.srv
|
||||
d := ctx.proxyCtx
|
||||
res := ctx.result
|
||||
var err error
|
||||
|
||||
if !ctx.responseFromUpstream {
|
||||
return resultDone // don't process response if it's not from upstream servers
|
||||
}
|
||||
|
||||
if res.Reason == dnsfilter.ReasonRewrite && len(res.CanonName) != 0 {
|
||||
d.Req.Question[0] = ctx.origQuestion
|
||||
d.Res.Question[0] = ctx.origQuestion
|
||||
|
||||
if len(d.Res.Answer) != 0 {
|
||||
answer := []dns.RR{}
|
||||
answer = append(answer, s.genCNAMEAnswer(d.Req, res.CanonName))
|
||||
answer = append(answer, d.Res.Answer...) // host -> IP
|
||||
d.Res.Answer = answer
|
||||
}
|
||||
|
||||
} else if res.Reason != dnsfilter.NotFilteredWhiteList && ctx.protectionEnabled {
|
||||
origResp2 := d.Res
|
||||
ctx.result, err = s.filterDNSResponse(ctx)
|
||||
if err != nil {
|
||||
ctx.err = err
|
||||
return resultError
|
||||
}
|
||||
if ctx.result != nil {
|
||||
ctx.origResp = origResp2 // matched by response
|
||||
} else {
|
||||
ctx.result = &dnsfilter.Result{}
|
||||
}
|
||||
}
|
||||
|
||||
return resultDone
|
||||
}
|
||||
|
||||
// Write Stats data and logs
|
||||
func processQueryLogsAndStats(ctx *dnsContext) int {
|
||||
elapsed := time.Since(ctx.startTime)
|
||||
s := ctx.srv
|
||||
d := ctx.proxyCtx
|
||||
|
||||
shouldLog := true
|
||||
msg := d.Req
|
||||
|
||||
|
@ -529,7 +577,6 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
|
|||
shouldLog = false
|
||||
}
|
||||
|
||||
elapsed := time.Since(start)
|
||||
s.RLock()
|
||||
// Synchronize access to s.queryLog and s.stats so they won't be suddenly uninitialized while in use.
|
||||
// This can happen after proxy server has been stopped, but its workers haven't yet exited.
|
||||
|
@ -537,8 +584,8 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
|
|||
p := querylog.AddParams{
|
||||
Question: msg,
|
||||
Answer: d.Res,
|
||||
OrigAnswer: origResp,
|
||||
Result: res,
|
||||
OrigAnswer: ctx.origResp,
|
||||
Result: ctx.result,
|
||||
Elapsed: elapsed,
|
||||
ClientIP: getIP(d.Addr),
|
||||
}
|
||||
|
@ -548,9 +595,41 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
|
|||
s.queryLog.Add(p)
|
||||
}
|
||||
|
||||
s.updateStats(d, elapsed, *res)
|
||||
s.updateStats(d, elapsed, *ctx.result)
|
||||
s.RUnlock()
|
||||
|
||||
return resultDone
|
||||
}
|
||||
|
||||
// handleDNSRequest filters the incoming DNS requests and writes them to the query log
|
||||
// nolint (gocyclo)
|
||||
func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
|
||||
ctx := &dnsContext{srv: s, proxyCtx: d}
|
||||
ctx.result = &dnsfilter.Result{}
|
||||
ctx.startTime = time.Now()
|
||||
|
||||
type modProcessFunc func(ctx *dnsContext) int
|
||||
mods := []modProcessFunc{
|
||||
processInitial,
|
||||
processFilteringBeforeRequest,
|
||||
processUpstream,
|
||||
processFilteringAfterResponse,
|
||||
}
|
||||
for _, process := range mods {
|
||||
r := process(ctx)
|
||||
switch r {
|
||||
case resultFinish:
|
||||
return nil
|
||||
case resultError:
|
||||
return ctx.err
|
||||
}
|
||||
}
|
||||
|
||||
if d.Res != nil {
|
||||
d.Res.Compress = true // some devices require DNS message compression
|
||||
}
|
||||
|
||||
_ = processQueryLogsAndStats(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -619,10 +698,11 @@ func (s *Server) getClientRequestFilteringSettings(d *proxy.DNSContext) *dnsfilt
|
|||
}
|
||||
|
||||
// filterDNSRequest applies the dnsFilter and sets d.Res if the request was filtered
|
||||
func (s *Server) filterDNSRequest(d *proxy.DNSContext, setts *dnsfilter.RequestFilteringSettings) (*dnsfilter.Result, error) {
|
||||
func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) {
|
||||
d := ctx.proxyCtx
|
||||
req := d.Req
|
||||
host := strings.TrimSuffix(req.Question[0].Name, ".")
|
||||
res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, setts)
|
||||
res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, ctx.setts)
|
||||
if err != nil {
|
||||
// Return immediately if there's an error
|
||||
return nil, errorx.Decorate(err, "dnsfilter failed to check host '%s'", host)
|
||||
|
@ -653,6 +733,11 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext, setts *dnsfilter.RequestF
|
|||
}
|
||||
|
||||
d.Res = resp
|
||||
|
||||
} else if res.Reason == dnsfilter.ReasonRewrite && len(res.CanonName) != 0 {
|
||||
ctx.origQuestion = d.Req.Question[0]
|
||||
// resolve canonical name, not the original host name
|
||||
d.Req.Question[0].Name = dns.Fqdn(res.CanonName)
|
||||
}
|
||||
|
||||
return &res, err
|
||||
|
@ -660,7 +745,8 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext, setts *dnsfilter.RequestF
|
|||
|
||||
// If response contains CNAME, A or AAAA records, we apply filtering to each canonical host name or IP address.
|
||||
// If this is a match, we set a new response in d.Res and return.
|
||||
func (s *Server) filterDNSResponse(d *proxy.DNSContext, setts *dnsfilter.RequestFilteringSettings) (*dnsfilter.Result, error) {
|
||||
func (s *Server) filterDNSResponse(ctx *dnsContext) (*dnsfilter.Result, error) {
|
||||
d := ctx.proxyCtx
|
||||
for _, a := range d.Res.Answer {
|
||||
host := ""
|
||||
|
||||
|
@ -688,7 +774,7 @@ func (s *Server) filterDNSResponse(d *proxy.DNSContext, setts *dnsfilter.Request
|
|||
s.RUnlock()
|
||||
continue
|
||||
}
|
||||
res, err := s.dnsFilter.CheckHostRules(host, d.Req.Question[0].Qtype, setts)
|
||||
res, err := s.dnsFilter.CheckHostRules(host, d.Req.Question[0].Qtype, ctx.setts)
|
||||
s.RUnlock()
|
||||
|
||||
if err != nil {
|
||||
|
|
Loading…
Reference in New Issue