diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index c46c3e97..3396fb7c 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -157,6 +157,7 @@ const ( // these variables need to survive coredns reload var ( stats Stats + securityCache gcache.Cache // "host" -> "IP" cache for safebrowsing and parental control servers safebrowsingCache gcache.Cache parentalCache gcache.Cache safeSearchCache gcache.Cache @@ -972,10 +973,34 @@ func (d *Dnsfilter) matchHost(host string) (Result, error) { // lifecycle helper functions // +// Return TRUE if this host's IP should be cached +func (d *Dnsfilter) shouldCache(host string) bool { + return host == d.safeBrowsingServer || + host == d.parentalServer +} + +// Search for an IP address by host name +func searchInCache(host string) string { + rawValue, err := securityCache.Get(host) + if err != nil { + return "" + } + + ip, _ := rawValue.(string) + log.Debug("Found in cache: %s -> %s", host, ip) + return ip +} + +// Add "hostname" -> "IP address" entry to cache +func addToCache(host, ip string) { + securityCache.Set(host, ip) + log.Debug("Added to cache: %s -> %s", host, ip) +} + type dialFunctionType func(ctx context.Context, network, addr string) (net.Conn, error) // Connect to a remote server resolving hostname using our own DNS server -func createCustomDialContext(resolverAddr string) dialFunctionType { +func (d *Dnsfilter) createCustomDialContext(resolverAddr string) dialFunctionType { return func(ctx context.Context, network, addr string) (net.Conn, error) { log.Tracef("network:%v addr:%v", network, addr) @@ -993,6 +1018,15 @@ func createCustomDialContext(resolverAddr string) dialFunctionType { return con, err } + cache := d.shouldCache(host) + if cache { + ip := searchInCache(host) + if len(ip) != 0 { + addr = fmt.Sprintf("%s:%s", ip, port) + return dialer.DialContext(ctx, network, addr) + } + } + r := upstream.NewResolver(resolverAddr, 30*time.Second) addrs, e := r.LookupIPAddr(ctx, host) log.Tracef("LookupIPAddr: %s: %v", host, addrs) @@ -1011,6 +1045,11 @@ func createCustomDialContext(resolverAddr string) dialFunctionType { } continue } + + if cache { + addToCache(host, a.String()) + } + return con, err } return nil, firstErr @@ -1037,7 +1076,8 @@ func New(c *Config) *Dnsfilter { ExpectContinueTimeout: 1 * time.Second, } if c != nil && len(c.ResolverAddress) != 0 { - d.transport.DialContext = createCustomDialContext(c.ResolverAddress) + securityCache = gcache.New(2).LRU().Expiration(defaultCacheTime).Build() + d.transport.DialContext = d.createCustomDialContext(c.ResolverAddress) } d.client = http.Client{ Transport: d.transport,