* dnsfilter: use a single global context object
This commit is contained in:
parent
f1e6a30931
commit
2307f55715
|
@ -128,14 +128,15 @@ const (
|
||||||
FilteredSafeSearch
|
FilteredSafeSearch
|
||||||
)
|
)
|
||||||
|
|
||||||
// these variables need to survive coredns reload
|
type dnsfContext struct {
|
||||||
var (
|
|
||||||
stats Stats
|
stats Stats
|
||||||
dialCache gcache.Cache // "host" -> "IP" cache for safebrowsing and parental control servers
|
dialCache gcache.Cache // "host" -> "IP" cache for safebrowsing and parental control servers
|
||||||
safebrowsingCache gcache.Cache
|
safebrowsingCache gcache.Cache
|
||||||
parentalCache gcache.Cache
|
parentalCache gcache.Cache
|
||||||
safeSearchCache gcache.Cache
|
safeSearchCache gcache.Cache
|
||||||
)
|
}
|
||||||
|
|
||||||
|
var gctx dnsfContext // global dnsfilter context
|
||||||
|
|
||||||
// Result holds state of hostname check
|
// Result holds state of hostname check
|
||||||
type Result struct {
|
type Result struct {
|
||||||
|
@ -298,14 +299,10 @@ func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) {
|
||||||
defer timer.LogElapsed("SafeSearch HTTP lookup for %s", host)
|
defer timer.LogElapsed("SafeSearch HTTP lookup for %s", host)
|
||||||
}
|
}
|
||||||
|
|
||||||
if safeSearchCache == nil {
|
|
||||||
safeSearchCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check cache. Return cached result if it was found
|
// Check cache. Return cached result if it was found
|
||||||
cachedValue, isFound, err := getCachedReason(safeSearchCache, host)
|
cachedValue, isFound, err := getCachedReason(gctx.safeSearchCache, host)
|
||||||
if isFound {
|
if isFound {
|
||||||
atomic.AddUint64(&stats.Safesearch.CacheHits, 1)
|
atomic.AddUint64(&gctx.stats.Safesearch.CacheHits, 1)
|
||||||
log.Tracef("%s: found in SafeSearch cache", host)
|
log.Tracef("%s: found in SafeSearch cache", host)
|
||||||
return cachedValue, nil
|
return cachedValue, nil
|
||||||
}
|
}
|
||||||
|
@ -322,7 +319,7 @@ func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) {
|
||||||
res := Result{IsFiltered: true, Reason: FilteredSafeSearch}
|
res := Result{IsFiltered: true, Reason: FilteredSafeSearch}
|
||||||
if ip := net.ParseIP(safeHost); ip != nil {
|
if ip := net.ParseIP(safeHost); ip != nil {
|
||||||
res.IP = ip
|
res.IP = ip
|
||||||
err = safeSearchCache.Set(host, res)
|
err = gctx.safeSearchCache.Set(host, res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Result{}, nil
|
return Result{}, nil
|
||||||
}
|
}
|
||||||
|
@ -349,7 +346,7 @@ func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cache result
|
// Cache result
|
||||||
err = safeSearchCache.Set(host, res)
|
err = gctx.safeSearchCache.Set(host, res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Result{}, nil
|
return Result{}, nil
|
||||||
}
|
}
|
||||||
|
@ -395,10 +392,7 @@ func (d *Dnsfilter) checkSafeBrowsing(host string) (Result, error) {
|
||||||
}
|
}
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
if safebrowsingCache == nil {
|
result, err := d.lookupCommon(host, &gctx.stats.Safebrowsing, gctx.safebrowsingCache, true, format, handleBody)
|
||||||
safebrowsingCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
|
|
||||||
}
|
|
||||||
result, err := d.lookupCommon(host, &stats.Safebrowsing, safebrowsingCache, true, format, handleBody)
|
|
||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -450,10 +444,7 @@ func (d *Dnsfilter) checkParental(host string) (Result, error) {
|
||||||
}
|
}
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
if parentalCache == nil {
|
result, err := d.lookupCommon(host, &gctx.stats.Parental, gctx.parentalCache, false, format, handleBody)
|
||||||
parentalCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
|
|
||||||
}
|
|
||||||
result, err := d.lookupCommon(host, &stats.Parental, parentalCache, false, format, handleBody)
|
|
||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -620,7 +611,7 @@ func (d *Dnsfilter) shouldBeInDialCache(host string) bool {
|
||||||
|
|
||||||
// Search for an IP address by host name
|
// Search for an IP address by host name
|
||||||
func searchInDialCache(host string) string {
|
func searchInDialCache(host string) string {
|
||||||
rawValue, err := dialCache.Get(host)
|
rawValue, err := gctx.dialCache.Get(host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
@ -632,7 +623,7 @@ func searchInDialCache(host string) string {
|
||||||
|
|
||||||
// Add "hostname" -> "IP address" entry to cache
|
// Add "hostname" -> "IP address" entry to cache
|
||||||
func addToDialCache(host, ip string) {
|
func addToDialCache(host, ip string) {
|
||||||
err := dialCache.Set(host, ip)
|
err := gctx.dialCache.Set(host, ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debug("dialCache.Set: %s", err)
|
log.Debug("dialCache.Set: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -701,6 +692,23 @@ func (d *Dnsfilter) createCustomDialContext(resolverAddr string) dialFunctionTyp
|
||||||
|
|
||||||
// New creates properly initialized DNS Filter that is ready to be used
|
// New creates properly initialized DNS Filter that is ready to be used
|
||||||
func New(c *Config, filters map[int]string) *Dnsfilter {
|
func New(c *Config, filters map[int]string) *Dnsfilter {
|
||||||
|
|
||||||
|
if c != nil {
|
||||||
|
// initialize objects only once
|
||||||
|
if c.SafeBrowsingEnabled && gctx.safebrowsingCache == nil {
|
||||||
|
gctx.safebrowsingCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
|
||||||
|
}
|
||||||
|
if c.SafeSearchEnabled && gctx.safeSearchCache == nil {
|
||||||
|
gctx.safeSearchCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
|
||||||
|
}
|
||||||
|
if c.ParentalEnabled && gctx.parentalCache == nil {
|
||||||
|
gctx.parentalCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
|
||||||
|
}
|
||||||
|
if len(c.ResolverAddress) != 0 && gctx.dialCache == nil {
|
||||||
|
gctx.dialCache = gcache.New(maxDialCacheSize).LRU().Expiration(defaultCacheTime).Build()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
d := new(Dnsfilter)
|
d := new(Dnsfilter)
|
||||||
|
|
||||||
// Customize the Transport to have larger connection pool,
|
// Customize the Transport to have larger connection pool,
|
||||||
|
@ -714,7 +722,6 @@ func New(c *Config, filters map[int]string) *Dnsfilter {
|
||||||
ExpectContinueTimeout: 1 * time.Second,
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
}
|
}
|
||||||
if c != nil && len(c.ResolverAddress) != 0 {
|
if c != nil && len(c.ResolverAddress) != 0 {
|
||||||
dialCache = gcache.New(maxDialCacheSize).LRU().Expiration(defaultCacheTime).Build()
|
|
||||||
d.transport.DialContext = d.createCustomDialContext(c.ResolverAddress)
|
d.transport.DialContext = d.createCustomDialContext(c.ResolverAddress)
|
||||||
}
|
}
|
||||||
d.client = http.Client{
|
d.client = http.Client{
|
||||||
|
@ -790,5 +797,5 @@ func (d *Dnsfilter) SafeSearchDomain(host string) (string, bool) {
|
||||||
|
|
||||||
// GetStats return dns filtering stats since startup
|
// GetStats return dns filtering stats since startup
|
||||||
func (d *Dnsfilter) GetStats() Stats {
|
func (d *Dnsfilter) GetStats() Stats {
|
||||||
return stats
|
return gctx.stats
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue