diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index c46c3e97..b780c9f2 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -33,6 +33,7 @@ const defaultSafebrowsingServer = "sb.adtidy.org" const defaultSafebrowsingURL = "http://%s/safebrowsing-lookup-hash.html?prefixes=%s" const defaultParentalServer = "pctrl.adguard.com" const defaultParentalURL = "http://%s/check-parental-control-hash?prefixes=%s&sensitivity=%d" +const maxDialCacheSize = 2 // the number of host names for safebrowsing and parental control // ErrInvalidSyntax is returned by AddRule when the rule is invalid var ErrInvalidSyntax = errors.New("dnsfilter: invalid rule syntax") @@ -157,6 +158,7 @@ const ( // these variables need to survive coredns reload var ( stats Stats + dialCache gcache.Cache // "host" -> "IP" cache for safebrowsing and parental control servers safebrowsingCache gcache.Cache parentalCache gcache.Cache safeSearchCache gcache.Cache @@ -972,10 +974,34 @@ func (d *Dnsfilter) matchHost(host string) (Result, error) { // lifecycle helper functions // +// Return TRUE if this host's IP should be cached +func (d *Dnsfilter) shouldBeInDialCache(host string) bool { + return host == d.safeBrowsingServer || + host == d.parentalServer +} + +// Search for an IP address by host name +func searchInDialCache(host string) string { + rawValue, err := dialCache.Get(host) + if err != nil { + return "" + } + + ip, _ := rawValue.(string) + log.Debug("Found in cache: %s -> %s", host, ip) + return ip +} + +// Add "hostname" -> "IP address" entry to cache +func addToDialCache(host, ip string) { + dialCache.Set(host, ip) + log.Debug("Added to cache: %s -> %s", host, ip) +} + type dialFunctionType func(ctx context.Context, network, addr string) (net.Conn, error) // Connect to a remote server resolving hostname using our own DNS server -func createCustomDialContext(resolverAddr string) dialFunctionType { +func (d *Dnsfilter) createCustomDialContext(resolverAddr string) dialFunctionType { return func(ctx context.Context, network, addr string) (net.Conn, error) { log.Tracef("network:%v addr:%v", network, addr) @@ -993,6 +1019,15 @@ func createCustomDialContext(resolverAddr string) dialFunctionType { return con, err } + cache := d.shouldBeInDialCache(host) + if cache { + ip := searchInDialCache(host) + if len(ip) != 0 { + addr = fmt.Sprintf("%s:%s", ip, port) + return dialer.DialContext(ctx, network, addr) + } + } + r := upstream.NewResolver(resolverAddr, 30*time.Second) addrs, e := r.LookupIPAddr(ctx, host) log.Tracef("LookupIPAddr: %s: %v", host, addrs) @@ -1011,6 +1046,11 @@ func createCustomDialContext(resolverAddr string) dialFunctionType { } continue } + + if cache { + addToDialCache(host, a.String()) + } + return con, err } return nil, firstErr @@ -1037,7 +1077,8 @@ func New(c *Config) *Dnsfilter { ExpectContinueTimeout: 1 * time.Second, } if c != nil && len(c.ResolverAddress) != 0 { - d.transport.DialContext = createCustomDialContext(c.ResolverAddress) + dialCache = gcache.New(maxDialCacheSize).LRU().Expiration(defaultCacheTime).Build() + d.transport.DialContext = d.createCustomDialContext(c.ResolverAddress) } d.client = http.Client{ Transport: d.transport,