diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index e58b0350..c46c3e97 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -162,8 +162,6 @@ var ( safeSearchCache gcache.Cache ) -var resolverAddr string // DNS server address - // Result holds state of hostname check type Result struct { IsFiltered bool `json:",omitempty"` // True if the host name is filtered @@ -185,6 +183,10 @@ func (d *Dnsfilter) CheckHost(host string) (Result, error) { return Result{Reason: NotFilteredNotFound}, nil } host = strings.ToLower(host) + // prevent recursion + if host == d.parentalServer || host == d.safeBrowsingServer { + return Result{}, nil + } // try filter lists first result, err := d.matchHost(host) @@ -674,10 +676,6 @@ func (d *Dnsfilter) checkSafeBrowsing(host string) (Result, error) { defer timer.LogElapsed("SafeBrowsing HTTP lookup for %s", host) } - // prevent recursion -- checking the host of safebrowsing server makes no sense - if host == d.safeBrowsingServer { - return Result{}, nil - } format := func(hashparam string) string { url := fmt.Sprintf(defaultSafebrowsingURL, d.safeBrowsingServer, hashparam) return url @@ -720,10 +718,6 @@ func (d *Dnsfilter) checkParental(host string) (Result, error) { defer timer.LogElapsed("Parental HTTP lookup for %s", host) } - // prevent recursion -- checking the host of parental safety server makes no sense - if host == d.parentalServer { - return Result{}, nil - } format := func(hashparam string) string { url := fmt.Sprintf(defaultParentalURL, d.parentalServer, hashparam, d.ParentalSensitivity) return url @@ -978,45 +972,49 @@ func (d *Dnsfilter) matchHost(host string) (Result, error) { // lifecycle helper functions // +type dialFunctionType func(ctx context.Context, network, addr string) (net.Conn, error) + // Connect to a remote server resolving hostname using our own DNS server -func customDialContext(ctx context.Context, network, addr string) (net.Conn, error) { - log.Tracef("network:%v addr:%v", network, addr) +func createCustomDialContext(resolverAddr string) dialFunctionType { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + log.Tracef("network:%v addr:%v", network, addr) - host, port, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } - - dialer := &net.Dialer{ - Timeout: time.Minute * 5, - } - - if net.ParseIP(host) != nil { - con, err := dialer.DialContext(ctx, network, addr) - return con, err - } - - r := upstream.NewResolver(resolverAddr, 30*time.Second) - addrs, e := r.LookupIPAddr(ctx, host) - log.Tracef("LookupIPAddr: %s: %v", host, addrs) - if e != nil { - return nil, e - } - - var firstErr error - firstErr = nil - for _, a := range addrs { - addr = fmt.Sprintf("%s:%s", a.String(), port) - con, err := dialer.DialContext(ctx, network, addr) + host, port, err := net.SplitHostPort(addr) if err != nil { - if firstErr == nil { - firstErr = err - } - continue + return nil, err } - return con, err + + dialer := &net.Dialer{ + Timeout: time.Minute * 5, + } + + if net.ParseIP(host) != nil { + con, err := dialer.DialContext(ctx, network, addr) + return con, err + } + + r := upstream.NewResolver(resolverAddr, 30*time.Second) + addrs, e := r.LookupIPAddr(ctx, host) + log.Tracef("LookupIPAddr: %s: %v", host, addrs) + if e != nil { + return nil, e + } + + var firstErr error + firstErr = nil + for _, a := range addrs { + addr = fmt.Sprintf("%s:%s", a.String(), port) + con, err := dialer.DialContext(ctx, network, addr) + if err != nil { + if firstErr == nil { + firstErr = err + } + continue + } + return con, err + } + return nil, firstErr } - return nil, firstErr } // New creates properly initialized DNS Filter that is ready to be used @@ -1039,8 +1037,7 @@ func New(c *Config) *Dnsfilter { ExpectContinueTimeout: 1 * time.Second, } if c != nil && len(c.ResolverAddress) != 0 { - resolverAddr = c.ResolverAddress - d.transport.DialContext = customDialContext + d.transport.DialContext = createCustomDialContext(c.ResolverAddress) } d.client = http.Client{ Transport: d.transport,