diff --git a/home/rdns.go b/home/rdns.go index c8a39974..abdea47f 100644 --- a/home/rdns.go +++ b/home/rdns.go @@ -1,12 +1,13 @@ package home import ( + "encoding/binary" "fmt" "strings" - "sync" "time" "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/log" "github.com/miekg/dns" ) @@ -18,12 +19,14 @@ const ( // RDNS - module context type RDNS struct { clients *clientsContainer - ipChannel chan string // pass data from DNS request handling thread to rDNS thread - // contains IP addresses of clients to be resolved by rDNS - // if IP address couldn't be resolved, it stays here forever to prevent further attempts to resolve the same IP - ips map[string]bool - lock sync.Mutex // synchronize access to 'ips' - upstream upstream.Upstream // Upstream object for our own DNS server + ipChannel chan string // pass data from DNS request handling thread to rDNS thread + upstream upstream.Upstream // Upstream object for our own DNS server + + // Contains IP addresses of clients to be resolved by rDNS + // If IP address is resolved, it stays here while it's inside Clients. + // If it's removed from Clients, this IP address will be resolved once again. + // If IP address couldn't be resolved, it stays here for some time to prevent further attempts to resolve the same IP. + ipAddrs cache.Cache } // InitRDNS - create module context @@ -47,7 +50,11 @@ func InitRDNS(clients *clientsContainer) *RDNS { return nil } - r.ips = make(map[string]bool) + cconf := cache.Config{} + cconf.EnableLRU = true + cconf.MaxCount = 10000 + r.ipAddrs = cache.New(cconf) + r.ipChannel = make(chan string, 256) go r.workerLoop() return &r @@ -55,25 +62,30 @@ func InitRDNS(clients *clientsContainer) *RDNS { // Begin - add IP address to rDNS queue func (r *RDNS) Begin(ip string) { + now := uint64(time.Now().Unix()) + expire := r.ipAddrs.Get([]byte(ip)) + if len(expire) != 0 { + exp := binary.BigEndian.Uint64(expire) + if exp > now { + return + } + // TTL expired + } + expire = make([]byte, 8) + const ttl = 12 * 60 * 60 + binary.BigEndian.PutUint64(expire, now+ttl) + _ = r.ipAddrs.Set([]byte(ip), expire) + if r.clients.Exists(ip, ClientSourceRDNS) { return } - // add IP to ips, if not exists - r.lock.Lock() - defer r.lock.Unlock() - _, ok := r.ips[ip] - if ok { - return - } - r.ips[ip] = true - - log.Tracef("Adding %s for rDNS resolve", ip) + log.Tracef("rDNS: adding %s", ip) select { case r.ipChannel <- ip: // default: - log.Tracef("rDNS queue is full") + log.Tracef("rDNS: queue is full") } } @@ -132,10 +144,6 @@ func (r *RDNS) workerLoop() { continue } - r.lock.Lock() - delete(r.ips, ip) - r.lock.Unlock() - _, _ = config.clients.AddHost(ip, host, ClientSourceRDNS) } } diff --git a/home/whois.go b/home/whois.go index 1f7eafba..f7c90a2f 100644 --- a/home/whois.go +++ b/home/whois.go @@ -1,13 +1,14 @@ package home import ( + "encoding/binary" "fmt" "io/ioutil" "net" "strings" - "sync" "time" + "github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/log" ) @@ -20,10 +21,13 @@ const ( // Whois - module context type Whois struct { clients *clientsContainer - ips map[string]bool - lock sync.Mutex ipChan chan string timeoutMsec uint + + // Contains IP addresses of clients + // An active IP address is resolved once again after it expires. + // If IP address couldn't be resolved, it stays here for some time to prevent further attempts to resolve the same IP. + ipAddrs cache.Cache } // Create module context @@ -31,7 +35,12 @@ func initWhois(clients *clientsContainer) *Whois { w := Whois{} w.timeoutMsec = 5000 w.clients = clients - w.ips = make(map[string]bool) + + cconf := cache.Config{} + cconf.EnableLRU = true + cconf.MaxCount = 10000 + w.ipAddrs = cache.New(cconf) + w.ipChan = make(chan string, 255) go w.workerLoop() return &w @@ -186,14 +195,19 @@ func (w *Whois) process(ip string) [][]string { // Begin - begin requesting WHOIS info func (w *Whois) Begin(ip string) { - w.lock.Lock() - _, found := w.ips[ip] - if found { - w.lock.Unlock() - return + now := uint64(time.Now().Unix()) + expire := w.ipAddrs.Get([]byte(ip)) + if len(expire) != 0 { + exp := binary.BigEndian.Uint64(expire) + if exp > now { + return + } + // TTL expired } - w.ips[ip] = true - w.lock.Unlock() + expire = make([]byte, 8) + const ttl = 12 * 60 * 60 + binary.BigEndian.PutUint64(expire, now+ttl) + _ = w.ipAddrs.Set([]byte(ip), expire) log.Debug("Whois: adding %s", ip) select {