Merge: + dnsfilter: cache IP addresses of safebrowsing and parental control servers

Close #745

* commit 'd918e5b418de232d95ba1e3d642dca00664f0304':
  use maxDialCacheSize constant
  rename functions and container
  + dnsfilter: cache IP addresses of safebrowsing and parental control servers
This commit is contained in:
Simon Zolin 2019-05-15 14:01:01 +03:00
commit f20cb65189
1 changed files with 43 additions and 2 deletions

View File

@ -33,6 +33,7 @@ const defaultSafebrowsingServer = "sb.adtidy.org"
const defaultSafebrowsingURL = "http://%s/safebrowsing-lookup-hash.html?prefixes=%s" const defaultSafebrowsingURL = "http://%s/safebrowsing-lookup-hash.html?prefixes=%s"
const defaultParentalServer = "pctrl.adguard.com" const defaultParentalServer = "pctrl.adguard.com"
const defaultParentalURL = "http://%s/check-parental-control-hash?prefixes=%s&sensitivity=%d" const defaultParentalURL = "http://%s/check-parental-control-hash?prefixes=%s&sensitivity=%d"
const maxDialCacheSize = 2 // the number of host names for safebrowsing and parental control
// ErrInvalidSyntax is returned by AddRule when the rule is invalid // ErrInvalidSyntax is returned by AddRule when the rule is invalid
var ErrInvalidSyntax = errors.New("dnsfilter: invalid rule syntax") var ErrInvalidSyntax = errors.New("dnsfilter: invalid rule syntax")
@ -157,6 +158,7 @@ const (
// these variables need to survive coredns reload // these variables need to survive coredns reload
var ( var (
stats Stats stats Stats
dialCache gcache.Cache // "host" -> "IP" cache for safebrowsing and parental control servers
safebrowsingCache gcache.Cache safebrowsingCache gcache.Cache
parentalCache gcache.Cache parentalCache gcache.Cache
safeSearchCache gcache.Cache safeSearchCache gcache.Cache
@ -972,10 +974,34 @@ func (d *Dnsfilter) matchHost(host string) (Result, error) {
// lifecycle helper functions // lifecycle helper functions
// //
// Return TRUE if this host's IP should be cached
func (d *Dnsfilter) shouldBeInDialCache(host string) bool {
return host == d.safeBrowsingServer ||
host == d.parentalServer
}
// Search for an IP address by host name
func searchInDialCache(host string) string {
rawValue, err := dialCache.Get(host)
if err != nil {
return ""
}
ip, _ := rawValue.(string)
log.Debug("Found in cache: %s -> %s", host, ip)
return ip
}
// Add "hostname" -> "IP address" entry to cache
func addToDialCache(host, ip string) {
dialCache.Set(host, ip)
log.Debug("Added to cache: %s -> %s", host, ip)
}
type dialFunctionType func(ctx context.Context, network, addr string) (net.Conn, error) type dialFunctionType func(ctx context.Context, network, addr string) (net.Conn, error)
// Connect to a remote server resolving hostname using our own DNS server // Connect to a remote server resolving hostname using our own DNS server
func createCustomDialContext(resolverAddr string) dialFunctionType { func (d *Dnsfilter) createCustomDialContext(resolverAddr string) dialFunctionType {
return func(ctx context.Context, network, addr string) (net.Conn, error) { return func(ctx context.Context, network, addr string) (net.Conn, error) {
log.Tracef("network:%v addr:%v", network, addr) log.Tracef("network:%v addr:%v", network, addr)
@ -993,6 +1019,15 @@ func createCustomDialContext(resolverAddr string) dialFunctionType {
return con, err return con, err
} }
cache := d.shouldBeInDialCache(host)
if cache {
ip := searchInDialCache(host)
if len(ip) != 0 {
addr = fmt.Sprintf("%s:%s", ip, port)
return dialer.DialContext(ctx, network, addr)
}
}
r := upstream.NewResolver(resolverAddr, 30*time.Second) r := upstream.NewResolver(resolverAddr, 30*time.Second)
addrs, e := r.LookupIPAddr(ctx, host) addrs, e := r.LookupIPAddr(ctx, host)
log.Tracef("LookupIPAddr: %s: %v", host, addrs) log.Tracef("LookupIPAddr: %s: %v", host, addrs)
@ -1011,6 +1046,11 @@ func createCustomDialContext(resolverAddr string) dialFunctionType {
} }
continue continue
} }
if cache {
addToDialCache(host, a.String())
}
return con, err return con, err
} }
return nil, firstErr return nil, firstErr
@ -1037,7 +1077,8 @@ func New(c *Config) *Dnsfilter {
ExpectContinueTimeout: 1 * time.Second, ExpectContinueTimeout: 1 * time.Second,
} }
if c != nil && len(c.ResolverAddress) != 0 { if c != nil && len(c.ResolverAddress) != 0 {
d.transport.DialContext = createCustomDialContext(c.ResolverAddress) dialCache = gcache.New(maxDialCacheSize).LRU().Expiration(defaultCacheTime).Build()
d.transport.DialContext = d.createCustomDialContext(c.ResolverAddress)
} }
d.client = http.Client{ d.client = http.Client{
Transport: d.transport, Transport: d.transport,