diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index dcecbccc..ea514d00 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -128,14 +128,15 @@ const ( FilteredSafeSearch ) -// these variables need to survive coredns reload -var ( +type dnsfContext struct { stats Stats dialCache gcache.Cache // "host" -> "IP" cache for safebrowsing and parental control servers safebrowsingCache gcache.Cache parentalCache gcache.Cache safeSearchCache gcache.Cache -) +} + +var gctx dnsfContext // global dnsfilter context // Result holds state of hostname check type Result struct { @@ -298,14 +299,10 @@ func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) { defer timer.LogElapsed("SafeSearch HTTP lookup for %s", host) } - if safeSearchCache == nil { - safeSearchCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build() - } - // Check cache. Return cached result if it was found - cachedValue, isFound, err := getCachedReason(safeSearchCache, host) + cachedValue, isFound, err := getCachedReason(gctx.safeSearchCache, host) if isFound { - atomic.AddUint64(&stats.Safesearch.CacheHits, 1) + atomic.AddUint64(&gctx.stats.Safesearch.CacheHits, 1) log.Tracef("%s: found in SafeSearch cache", host) return cachedValue, nil } @@ -322,7 +319,7 @@ func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) { res := Result{IsFiltered: true, Reason: FilteredSafeSearch} if ip := net.ParseIP(safeHost); ip != nil { res.IP = ip - err = safeSearchCache.Set(host, res) + err = gctx.safeSearchCache.Set(host, res) if err != nil { return Result{}, nil } @@ -349,7 +346,7 @@ func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) { } // Cache result - err = safeSearchCache.Set(host, res) + err = gctx.safeSearchCache.Set(host, res) if err != nil { return Result{}, nil } @@ -395,10 +392,7 @@ func (d *Dnsfilter) checkSafeBrowsing(host string) (Result, error) { } return result, nil } - if safebrowsingCache == nil { - safebrowsingCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build() - } - result, err := d.lookupCommon(host, &stats.Safebrowsing, safebrowsingCache, true, format, handleBody) + result, err := d.lookupCommon(host, &gctx.stats.Safebrowsing, gctx.safebrowsingCache, true, format, handleBody) return result, err } @@ -450,10 +444,7 @@ func (d *Dnsfilter) checkParental(host string) (Result, error) { } return result, nil } - if parentalCache == nil { - parentalCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build() - } - result, err := d.lookupCommon(host, &stats.Parental, parentalCache, false, format, handleBody) + result, err := d.lookupCommon(host, &gctx.stats.Parental, gctx.parentalCache, false, format, handleBody) return result, err } @@ -620,7 +611,7 @@ func (d *Dnsfilter) shouldBeInDialCache(host string) bool { // Search for an IP address by host name func searchInDialCache(host string) string { - rawValue, err := dialCache.Get(host) + rawValue, err := gctx.dialCache.Get(host) if err != nil { return "" } @@ -632,7 +623,7 @@ func searchInDialCache(host string) string { // Add "hostname" -> "IP address" entry to cache func addToDialCache(host, ip string) { - err := dialCache.Set(host, ip) + err := gctx.dialCache.Set(host, ip) if err != nil { log.Debug("dialCache.Set: %s", err) } @@ -701,6 +692,23 @@ func (d *Dnsfilter) createCustomDialContext(resolverAddr string) dialFunctionTyp // New creates properly initialized DNS Filter that is ready to be used func New(c *Config, filters map[int]string) *Dnsfilter { + + if c != nil { + // initialize objects only once + if c.SafeBrowsingEnabled && gctx.safebrowsingCache == nil { + gctx.safebrowsingCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build() + } + if c.SafeSearchEnabled && gctx.safeSearchCache == nil { + gctx.safeSearchCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build() + } + if c.ParentalEnabled && gctx.parentalCache == nil { + gctx.parentalCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build() + } + if len(c.ResolverAddress) != 0 && gctx.dialCache == nil { + gctx.dialCache = gcache.New(maxDialCacheSize).LRU().Expiration(defaultCacheTime).Build() + } + } + d := new(Dnsfilter) // Customize the Transport to have larger connection pool, @@ -714,7 +722,6 @@ func New(c *Config, filters map[int]string) *Dnsfilter { ExpectContinueTimeout: 1 * time.Second, } if c != nil && len(c.ResolverAddress) != 0 { - dialCache = gcache.New(maxDialCacheSize).LRU().Expiration(defaultCacheTime).Build() d.transport.DialContext = d.createCustomDialContext(c.ResolverAddress) } d.client = http.Client{ @@ -790,5 +797,5 @@ func (d *Dnsfilter) SafeSearchDomain(host string) (string, bool) { // GetStats return dns filtering stats since startup func (d *Dnsfilter) GetStats() Stats { - return stats + return gctx.stats }