diff --git a/home/dns.go b/home/dns.go index 3b90ddfa..1712df9c 100644 --- a/home/dns.go +++ b/home/dns.go @@ -5,27 +5,19 @@ import ( "net" "os" "path/filepath" - "sync" "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/dnsforward" "github.com/AdguardTeam/AdGuardHome/querylog" "github.com/AdguardTeam/AdGuardHome/stats" "github.com/AdguardTeam/dnsproxy/proxy" - "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/log" "github.com/joomcode/errorx" "github.com/miekg/dns" ) type dnsContext struct { - rdnsChannel chan string // pass data from DNS request handling thread to rDNS thread - // contains IP addresses of clients to be resolved by rDNS - // if IP address couldn't be resolved, it stays here forever to prevent further attempts to resolve the same IP - rdnsIP map[string]bool - rdnsLock sync.Mutex // synchronize access to rdnsIP - upstream upstream.Upstream // Upstream object for our own DNS server - + rdns *RDNS whois *Whois } @@ -57,7 +49,7 @@ func initDNSServer(baseDir string) { config.auth = InitAuth(sessFilename, config.Users) config.Users = nil - initRDNS() + config.dnsctx.rdns = InitRDNS(&config.clients) config.dnsctx.whois = initWhois(&config.clients) initFiltering() } @@ -133,7 +125,7 @@ func onDNSRequest(d *proxy.DNSContext) { ipAddr := net.ParseIP(ip) if !ipAddr.IsLoopback() { - beginAsyncRDNS(ip) + config.dnsctx.rdns.Begin(ip) } if isPublicIP(ipAddr) { config.dnsctx.whois.Begin(ip) diff --git a/home/rdns.go b/home/rdns.go index 048dcde1..c8a39974 100644 --- a/home/rdns.go +++ b/home/rdns.go @@ -3,6 +3,7 @@ package home import ( "fmt" "strings" + "sync" "time" "github.com/AdguardTeam/dnsproxy/upstream" @@ -14,7 +15,21 @@ const ( rdnsTimeout = 3 * time.Second // max time to wait for rDNS response ) -func initRDNS() { +// RDNS - module context +type RDNS struct { + clients *clientsContainer + ipChannel chan string // pass data from DNS request handling thread to rDNS thread + // contains IP addresses of clients to be resolved by rDNS + // if IP address couldn't be resolved, it stays here forever to prevent further attempts to resolve the same IP + ips map[string]bool + lock sync.Mutex // synchronize access to 'ips' + upstream upstream.Upstream // Upstream object for our own DNS server +} + +// InitRDNS - create module context +func InitRDNS(clients *clientsContainer) *RDNS { + r := RDNS{} + r.clients = clients var err error bindhost := config.DNS.BindHost @@ -26,35 +41,36 @@ func initRDNS() { opts := upstream.Options{ Timeout: rdnsTimeout, } - config.dnsctx.upstream, err = upstream.AddressToUpstream(resolverAddress, opts) + r.upstream, err = upstream.AddressToUpstream(resolverAddress, opts) if err != nil { log.Error("upstream.AddressToUpstream: %s", err) - return + return nil } - config.dnsctx.rdnsIP = make(map[string]bool) - config.dnsctx.rdnsChannel = make(chan string, 256) - go asyncRDNSLoop() + r.ips = make(map[string]bool) + r.ipChannel = make(chan string, 256) + go r.workerLoop() + return &r } -// Add IP address to the rDNS queue -func beginAsyncRDNS(ip string) { - if config.clients.Exists(ip, ClientSourceRDNS) { +// Begin - add IP address to rDNS queue +func (r *RDNS) Begin(ip string) { + if r.clients.Exists(ip, ClientSourceRDNS) { return } - // add IP to rdnsIP, if not exists - config.dnsctx.rdnsLock.Lock() - defer config.dnsctx.rdnsLock.Unlock() - _, ok := config.dnsctx.rdnsIP[ip] + // add IP to ips, if not exists + r.lock.Lock() + defer r.lock.Unlock() + _, ok := r.ips[ip] if ok { return } - config.dnsctx.rdnsIP[ip] = true + r.ips[ip] = true log.Tracef("Adding %s for rDNS resolve", ip) select { - case config.dnsctx.rdnsChannel <- ip: + case r.ipChannel <- ip: // default: log.Tracef("rDNS queue is full") @@ -62,7 +78,7 @@ func beginAsyncRDNS(ip string) { } // Use rDNS to get hostname by IP address -func resolveRDNS(ip string) string { +func (r *RDNS) resolve(ip string) string { log.Tracef("Resolving host for %s", ip) req := dns.Msg{} @@ -81,7 +97,7 @@ func resolveRDNS(ip string) string { return "" } - resp, err := config.dnsctx.upstream.Exchange(&req) + resp, err := r.upstream.Exchange(&req) if err != nil { log.Debug("Error while making an rDNS lookup for %s: %s", ip, err) return "" @@ -106,19 +122,19 @@ func resolveRDNS(ip string) string { // Wait for a signal and then synchronously resolve hostname by IP address // Add the hostname:IP pair to "Clients" array -func asyncRDNSLoop() { +func (r *RDNS) workerLoop() { for { var ip string - ip = <-config.dnsctx.rdnsChannel + ip = <-r.ipChannel - host := resolveRDNS(ip) + host := r.resolve(ip) if len(host) == 0 { continue } - config.dnsctx.rdnsLock.Lock() - delete(config.dnsctx.rdnsIP, ip) - config.dnsctx.rdnsLock.Unlock() + r.lock.Lock() + delete(r.ips, ip) + r.lock.Unlock() _, _ = config.clients.AddHost(ip, host, ClientSourceRDNS) }