* rDNS: refactor

This commit is contained in:
Simon Zolin 2019-09-19 18:27:13 +03:00
parent fd592b4dc8
commit 192b58b9d9
2 changed files with 42 additions and 34 deletions

View File

@ -5,27 +5,19 @@ import (
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
"sync"
"github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/dnsforward" "github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/AdguardTeam/AdGuardHome/querylog" "github.com/AdguardTeam/AdGuardHome/querylog"
"github.com/AdguardTeam/AdGuardHome/stats" "github.com/AdguardTeam/AdGuardHome/stats"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/joomcode/errorx" "github.com/joomcode/errorx"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
type dnsContext struct { type dnsContext struct {
rdnsChannel chan string // pass data from DNS request handling thread to rDNS thread rdns *RDNS
// contains IP addresses of clients to be resolved by rDNS
// if IP address couldn't be resolved, it stays here forever to prevent further attempts to resolve the same IP
rdnsIP map[string]bool
rdnsLock sync.Mutex // synchronize access to rdnsIP
upstream upstream.Upstream // Upstream object for our own DNS server
whois *Whois whois *Whois
} }
@ -57,7 +49,7 @@ func initDNSServer(baseDir string) {
config.auth = InitAuth(sessFilename, config.Users) config.auth = InitAuth(sessFilename, config.Users)
config.Users = nil config.Users = nil
initRDNS() config.dnsctx.rdns = InitRDNS(&config.clients)
config.dnsctx.whois = initWhois(&config.clients) config.dnsctx.whois = initWhois(&config.clients)
initFiltering() initFiltering()
} }
@ -133,7 +125,7 @@ func onDNSRequest(d *proxy.DNSContext) {
ipAddr := net.ParseIP(ip) ipAddr := net.ParseIP(ip)
if !ipAddr.IsLoopback() { if !ipAddr.IsLoopback() {
beginAsyncRDNS(ip) config.dnsctx.rdns.Begin(ip)
} }
if isPublicIP(ipAddr) { if isPublicIP(ipAddr) {
config.dnsctx.whois.Begin(ip) config.dnsctx.whois.Begin(ip)

View File

@ -3,6 +3,7 @@ package home
import ( import (
"fmt" "fmt"
"strings" "strings"
"sync"
"time" "time"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
@ -14,7 +15,21 @@ const (
rdnsTimeout = 3 * time.Second // max time to wait for rDNS response rdnsTimeout = 3 * time.Second // max time to wait for rDNS response
) )
func initRDNS() { // RDNS - module context
type RDNS struct {
clients *clientsContainer
ipChannel chan string // pass data from DNS request handling thread to rDNS thread
// contains IP addresses of clients to be resolved by rDNS
// if IP address couldn't be resolved, it stays here forever to prevent further attempts to resolve the same IP
ips map[string]bool
lock sync.Mutex // synchronize access to 'ips'
upstream upstream.Upstream // Upstream object for our own DNS server
}
// InitRDNS - create module context
func InitRDNS(clients *clientsContainer) *RDNS {
r := RDNS{}
r.clients = clients
var err error var err error
bindhost := config.DNS.BindHost bindhost := config.DNS.BindHost
@ -26,35 +41,36 @@ func initRDNS() {
opts := upstream.Options{ opts := upstream.Options{
Timeout: rdnsTimeout, Timeout: rdnsTimeout,
} }
config.dnsctx.upstream, err = upstream.AddressToUpstream(resolverAddress, opts) r.upstream, err = upstream.AddressToUpstream(resolverAddress, opts)
if err != nil { if err != nil {
log.Error("upstream.AddressToUpstream: %s", err) log.Error("upstream.AddressToUpstream: %s", err)
return nil
}
r.ips = make(map[string]bool)
r.ipChannel = make(chan string, 256)
go r.workerLoop()
return &r
}
// Begin - add IP address to rDNS queue
func (r *RDNS) Begin(ip string) {
if r.clients.Exists(ip, ClientSourceRDNS) {
return return
} }
config.dnsctx.rdnsIP = make(map[string]bool) // add IP to ips, if not exists
config.dnsctx.rdnsChannel = make(chan string, 256) r.lock.Lock()
go asyncRDNSLoop() defer r.lock.Unlock()
} _, ok := r.ips[ip]
// Add IP address to the rDNS queue
func beginAsyncRDNS(ip string) {
if config.clients.Exists(ip, ClientSourceRDNS) {
return
}
// add IP to rdnsIP, if not exists
config.dnsctx.rdnsLock.Lock()
defer config.dnsctx.rdnsLock.Unlock()
_, ok := config.dnsctx.rdnsIP[ip]
if ok { if ok {
return return
} }
config.dnsctx.rdnsIP[ip] = true r.ips[ip] = true
log.Tracef("Adding %s for rDNS resolve", ip) log.Tracef("Adding %s for rDNS resolve", ip)
select { select {
case config.dnsctx.rdnsChannel <- ip: case r.ipChannel <- ip:
// //
default: default:
log.Tracef("rDNS queue is full") log.Tracef("rDNS queue is full")
@ -62,7 +78,7 @@ func beginAsyncRDNS(ip string) {
} }
// Use rDNS to get hostname by IP address // Use rDNS to get hostname by IP address
func resolveRDNS(ip string) string { func (r *RDNS) resolve(ip string) string {
log.Tracef("Resolving host for %s", ip) log.Tracef("Resolving host for %s", ip)
req := dns.Msg{} req := dns.Msg{}
@ -81,7 +97,7 @@ func resolveRDNS(ip string) string {
return "" return ""
} }
resp, err := config.dnsctx.upstream.Exchange(&req) resp, err := r.upstream.Exchange(&req)
if err != nil { if err != nil {
log.Debug("Error while making an rDNS lookup for %s: %s", ip, err) log.Debug("Error while making an rDNS lookup for %s: %s", ip, err)
return "" return ""
@ -106,19 +122,19 @@ func resolveRDNS(ip string) string {
// Wait for a signal and then synchronously resolve hostname by IP address // Wait for a signal and then synchronously resolve hostname by IP address
// Add the hostname:IP pair to "Clients" array // Add the hostname:IP pair to "Clients" array
func asyncRDNSLoop() { func (r *RDNS) workerLoop() {
for { for {
var ip string var ip string
ip = <-config.dnsctx.rdnsChannel ip = <-r.ipChannel
host := resolveRDNS(ip) host := r.resolve(ip)
if len(host) == 0 { if len(host) == 0 {
continue continue
} }
config.dnsctx.rdnsLock.Lock() r.lock.Lock()
delete(config.dnsctx.rdnsIP, ip) delete(r.ips, ip)
config.dnsctx.rdnsLock.Unlock() r.lock.Unlock()
_, _ = config.clients.AddHost(ip, host, ClientSourceRDNS) _, _ = config.clients.AddHost(ip, host, ClientSourceRDNS)
} }