* dnsfilter: use a single global context object

This commit is contained in:
Simon Zolin 2019-06-24 19:00:03 +03:00
parent f1e6a30931
commit 2307f55715
1 changed files with 30 additions and 23 deletions

View File

@ -128,14 +128,15 @@ const (
FilteredSafeSearch FilteredSafeSearch
) )
// these variables need to survive coredns reload type dnsfContext struct {
var (
stats Stats stats Stats
dialCache gcache.Cache // "host" -> "IP" cache for safebrowsing and parental control servers dialCache gcache.Cache // "host" -> "IP" cache for safebrowsing and parental control servers
safebrowsingCache gcache.Cache safebrowsingCache gcache.Cache
parentalCache gcache.Cache parentalCache gcache.Cache
safeSearchCache gcache.Cache safeSearchCache gcache.Cache
) }
var gctx dnsfContext // global dnsfilter context
// Result holds state of hostname check // Result holds state of hostname check
type Result struct { type Result struct {
@ -298,14 +299,10 @@ func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) {
defer timer.LogElapsed("SafeSearch HTTP lookup for %s", host) defer timer.LogElapsed("SafeSearch HTTP lookup for %s", host)
} }
if safeSearchCache == nil {
safeSearchCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
}
// Check cache. Return cached result if it was found // Check cache. Return cached result if it was found
cachedValue, isFound, err := getCachedReason(safeSearchCache, host) cachedValue, isFound, err := getCachedReason(gctx.safeSearchCache, host)
if isFound { if isFound {
atomic.AddUint64(&stats.Safesearch.CacheHits, 1) atomic.AddUint64(&gctx.stats.Safesearch.CacheHits, 1)
log.Tracef("%s: found in SafeSearch cache", host) log.Tracef("%s: found in SafeSearch cache", host)
return cachedValue, nil return cachedValue, nil
} }
@ -322,7 +319,7 @@ func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) {
res := Result{IsFiltered: true, Reason: FilteredSafeSearch} res := Result{IsFiltered: true, Reason: FilteredSafeSearch}
if ip := net.ParseIP(safeHost); ip != nil { if ip := net.ParseIP(safeHost); ip != nil {
res.IP = ip res.IP = ip
err = safeSearchCache.Set(host, res) err = gctx.safeSearchCache.Set(host, res)
if err != nil { if err != nil {
return Result{}, nil return Result{}, nil
} }
@ -349,7 +346,7 @@ func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) {
} }
// Cache result // Cache result
err = safeSearchCache.Set(host, res) err = gctx.safeSearchCache.Set(host, res)
if err != nil { if err != nil {
return Result{}, nil return Result{}, nil
} }
@ -395,10 +392,7 @@ func (d *Dnsfilter) checkSafeBrowsing(host string) (Result, error) {
} }
return result, nil return result, nil
} }
if safebrowsingCache == nil { result, err := d.lookupCommon(host, &gctx.stats.Safebrowsing, gctx.safebrowsingCache, true, format, handleBody)
safebrowsingCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
}
result, err := d.lookupCommon(host, &stats.Safebrowsing, safebrowsingCache, true, format, handleBody)
return result, err return result, err
} }
@ -450,10 +444,7 @@ func (d *Dnsfilter) checkParental(host string) (Result, error) {
} }
return result, nil return result, nil
} }
if parentalCache == nil { result, err := d.lookupCommon(host, &gctx.stats.Parental, gctx.parentalCache, false, format, handleBody)
parentalCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
}
result, err := d.lookupCommon(host, &stats.Parental, parentalCache, false, format, handleBody)
return result, err return result, err
} }
@ -620,7 +611,7 @@ func (d *Dnsfilter) shouldBeInDialCache(host string) bool {
// Search for an IP address by host name // Search for an IP address by host name
func searchInDialCache(host string) string { func searchInDialCache(host string) string {
rawValue, err := dialCache.Get(host) rawValue, err := gctx.dialCache.Get(host)
if err != nil { if err != nil {
return "" return ""
} }
@ -632,7 +623,7 @@ func searchInDialCache(host string) string {
// Add "hostname" -> "IP address" entry to cache // Add "hostname" -> "IP address" entry to cache
func addToDialCache(host, ip string) { func addToDialCache(host, ip string) {
err := dialCache.Set(host, ip) err := gctx.dialCache.Set(host, ip)
if err != nil { if err != nil {
log.Debug("dialCache.Set: %s", err) log.Debug("dialCache.Set: %s", err)
} }
@ -701,6 +692,23 @@ func (d *Dnsfilter) createCustomDialContext(resolverAddr string) dialFunctionTyp
// New creates properly initialized DNS Filter that is ready to be used // New creates properly initialized DNS Filter that is ready to be used
func New(c *Config, filters map[int]string) *Dnsfilter { func New(c *Config, filters map[int]string) *Dnsfilter {
if c != nil {
// initialize objects only once
if c.SafeBrowsingEnabled && gctx.safebrowsingCache == nil {
gctx.safebrowsingCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
}
if c.SafeSearchEnabled && gctx.safeSearchCache == nil {
gctx.safeSearchCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
}
if c.ParentalEnabled && gctx.parentalCache == nil {
gctx.parentalCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
}
if len(c.ResolverAddress) != 0 && gctx.dialCache == nil {
gctx.dialCache = gcache.New(maxDialCacheSize).LRU().Expiration(defaultCacheTime).Build()
}
}
d := new(Dnsfilter) d := new(Dnsfilter)
// Customize the Transport to have larger connection pool, // Customize the Transport to have larger connection pool,
@ -714,7 +722,6 @@ func New(c *Config, filters map[int]string) *Dnsfilter {
ExpectContinueTimeout: 1 * time.Second, ExpectContinueTimeout: 1 * time.Second,
} }
if c != nil && len(c.ResolverAddress) != 0 { if c != nil && len(c.ResolverAddress) != 0 {
dialCache = gcache.New(maxDialCacheSize).LRU().Expiration(defaultCacheTime).Build()
d.transport.DialContext = d.createCustomDialContext(c.ResolverAddress) d.transport.DialContext = d.createCustomDialContext(c.ResolverAddress)
} }
d.client = http.Client{ d.client = http.Client{
@ -790,5 +797,5 @@ func (d *Dnsfilter) SafeSearchDomain(host string) (string, bool) {
// GetStats return dns filtering stats since startup // GetStats return dns filtering stats since startup
func (d *Dnsfilter) GetStats() Stats { func (d *Dnsfilter) GetStats() Stats {
return stats return gctx.stats
} }