Merge: dnsfilter: prevent recursion when both parental control and safebrowsing are enabled
Close #732 * commit 'c4e67690f4fcceb055cbea73610b5974855db96f': * dnsfilter: don't use global variable for custom resolver function - dnsfilter: prevent recursion when both parental control and safebrowsing are enabled
This commit is contained in:
commit
e1bb89c393
|
@ -162,8 +162,6 @@ var (
|
|||
safeSearchCache gcache.Cache
|
||||
)
|
||||
|
||||
var resolverAddr string // DNS server address
|
||||
|
||||
// Result holds state of hostname check
|
||||
type Result struct {
|
||||
IsFiltered bool `json:",omitempty"` // True if the host name is filtered
|
||||
|
@ -185,6 +183,10 @@ func (d *Dnsfilter) CheckHost(host string) (Result, error) {
|
|||
return Result{Reason: NotFilteredNotFound}, nil
|
||||
}
|
||||
host = strings.ToLower(host)
|
||||
// prevent recursion
|
||||
if host == d.parentalServer || host == d.safeBrowsingServer {
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
// try filter lists first
|
||||
result, err := d.matchHost(host)
|
||||
|
@ -674,10 +676,6 @@ func (d *Dnsfilter) checkSafeBrowsing(host string) (Result, error) {
|
|||
defer timer.LogElapsed("SafeBrowsing HTTP lookup for %s", host)
|
||||
}
|
||||
|
||||
// prevent recursion -- checking the host of safebrowsing server makes no sense
|
||||
if host == d.safeBrowsingServer {
|
||||
return Result{}, nil
|
||||
}
|
||||
format := func(hashparam string) string {
|
||||
url := fmt.Sprintf(defaultSafebrowsingURL, d.safeBrowsingServer, hashparam)
|
||||
return url
|
||||
|
@ -720,10 +718,6 @@ func (d *Dnsfilter) checkParental(host string) (Result, error) {
|
|||
defer timer.LogElapsed("Parental HTTP lookup for %s", host)
|
||||
}
|
||||
|
||||
// prevent recursion -- checking the host of parental safety server makes no sense
|
||||
if host == d.parentalServer {
|
||||
return Result{}, nil
|
||||
}
|
||||
format := func(hashparam string) string {
|
||||
url := fmt.Sprintf(defaultParentalURL, d.parentalServer, hashparam, d.ParentalSensitivity)
|
||||
return url
|
||||
|
@ -978,45 +972,49 @@ func (d *Dnsfilter) matchHost(host string) (Result, error) {
|
|||
// lifecycle helper functions
|
||||
//
|
||||
|
||||
type dialFunctionType func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
|
||||
// Connect to a remote server resolving hostname using our own DNS server
|
||||
func customDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
log.Tracef("network:%v addr:%v", network, addr)
|
||||
func createCustomDialContext(resolverAddr string) dialFunctionType {
|
||||
return func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
log.Tracef("network:%v addr:%v", network, addr)
|
||||
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Timeout: time.Minute * 5,
|
||||
}
|
||||
|
||||
if net.ParseIP(host) != nil {
|
||||
con, err := dialer.DialContext(ctx, network, addr)
|
||||
return con, err
|
||||
}
|
||||
|
||||
r := upstream.NewResolver(resolverAddr, 30*time.Second)
|
||||
addrs, e := r.LookupIPAddr(ctx, host)
|
||||
log.Tracef("LookupIPAddr: %s: %v", host, addrs)
|
||||
if e != nil {
|
||||
return nil, e
|
||||
}
|
||||
|
||||
var firstErr error
|
||||
firstErr = nil
|
||||
for _, a := range addrs {
|
||||
addr = fmt.Sprintf("%s:%s", a.String(), port)
|
||||
con, err := dialer.DialContext(ctx, network, addr)
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
continue
|
||||
return nil, err
|
||||
}
|
||||
return con, err
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Timeout: time.Minute * 5,
|
||||
}
|
||||
|
||||
if net.ParseIP(host) != nil {
|
||||
con, err := dialer.DialContext(ctx, network, addr)
|
||||
return con, err
|
||||
}
|
||||
|
||||
r := upstream.NewResolver(resolverAddr, 30*time.Second)
|
||||
addrs, e := r.LookupIPAddr(ctx, host)
|
||||
log.Tracef("LookupIPAddr: %s: %v", host, addrs)
|
||||
if e != nil {
|
||||
return nil, e
|
||||
}
|
||||
|
||||
var firstErr error
|
||||
firstErr = nil
|
||||
for _, a := range addrs {
|
||||
addr = fmt.Sprintf("%s:%s", a.String(), port)
|
||||
con, err := dialer.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
continue
|
||||
}
|
||||
return con, err
|
||||
}
|
||||
return nil, firstErr
|
||||
}
|
||||
return nil, firstErr
|
||||
}
|
||||
|
||||
// New creates properly initialized DNS Filter that is ready to be used
|
||||
|
@ -1039,8 +1037,7 @@ func New(c *Config) *Dnsfilter {
|
|||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
if c != nil && len(c.ResolverAddress) != 0 {
|
||||
resolverAddr = c.ResolverAddress
|
||||
d.transport.DialContext = customDialContext
|
||||
d.transport.DialContext = createCustomDialContext(c.ResolverAddress)
|
||||
}
|
||||
d.client = http.Client{
|
||||
Transport: d.transport,
|
||||
|
|
Loading…
Reference in New Issue