* dnsfilter: don't use global variable for custom resolver function

This commit is contained in:
Simon Zolin 2019-04-24 12:49:12 +03:00
parent f6023b395e
commit c4e67690f4
1 changed files with 39 additions and 38 deletions

View File

@ -162,8 +162,6 @@ var (
safeSearchCache gcache.Cache safeSearchCache gcache.Cache
) )
var resolverAddr string // DNS server address
// Result holds state of hostname check // Result holds state of hostname check
type Result struct { type Result struct {
IsFiltered bool `json:",omitempty"` // True if the host name is filtered IsFiltered bool `json:",omitempty"` // True if the host name is filtered
@ -974,45 +972,49 @@ func (d *Dnsfilter) matchHost(host string) (Result, error) {
// lifecycle helper functions // lifecycle helper functions
// //
type dialFunctionType func(ctx context.Context, network, addr string) (net.Conn, error)
// Connect to a remote server resolving hostname using our own DNS server // Connect to a remote server resolving hostname using our own DNS server
func customDialContext(ctx context.Context, network, addr string) (net.Conn, error) { func createCustomDialContext(resolverAddr string) dialFunctionType {
log.Tracef("network:%v addr:%v", network, addr) return func(ctx context.Context, network, addr string) (net.Conn, error) {
log.Tracef("network:%v addr:%v", network, addr)
host, port, err := net.SplitHostPort(addr) host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
dialer := &net.Dialer{
Timeout: time.Minute * 5,
}
if net.ParseIP(host) != nil {
con, err := dialer.DialContext(ctx, network, addr)
return con, err
}
r := upstream.NewResolver(resolverAddr, 30*time.Second)
addrs, e := r.LookupIPAddr(ctx, host)
log.Tracef("LookupIPAddr: %s: %v", host, addrs)
if e != nil {
return nil, e
}
var firstErr error
firstErr = nil
for _, a := range addrs {
addr = fmt.Sprintf("%s:%s", a.String(), port)
con, err := dialer.DialContext(ctx, network, addr)
if err != nil { if err != nil {
if firstErr == nil { return nil, err
firstErr = err
}
continue
} }
return con, err
dialer := &net.Dialer{
Timeout: time.Minute * 5,
}
if net.ParseIP(host) != nil {
con, err := dialer.DialContext(ctx, network, addr)
return con, err
}
r := upstream.NewResolver(resolverAddr, 30*time.Second)
addrs, e := r.LookupIPAddr(ctx, host)
log.Tracef("LookupIPAddr: %s: %v", host, addrs)
if e != nil {
return nil, e
}
var firstErr error
firstErr = nil
for _, a := range addrs {
addr = fmt.Sprintf("%s:%s", a.String(), port)
con, err := dialer.DialContext(ctx, network, addr)
if err != nil {
if firstErr == nil {
firstErr = err
}
continue
}
return con, err
}
return nil, firstErr
} }
return nil, firstErr
} }
// New creates properly initialized DNS Filter that is ready to be used // New creates properly initialized DNS Filter that is ready to be used
@ -1035,8 +1037,7 @@ func New(c *Config) *Dnsfilter {
ExpectContinueTimeout: 1 * time.Second, ExpectContinueTimeout: 1 * time.Second,
} }
if c != nil && len(c.ResolverAddress) != 0 { if c != nil && len(c.ResolverAddress) != 0 {
resolverAddr = c.ResolverAddress d.transport.DialContext = createCustomDialContext(c.ResolverAddress)
d.transport.DialContext = customDialContext
} }
d.client = http.Client{ d.client = http.Client{
Transport: d.transport, Transport: d.transport,