Merge: Don't use the system default resolver implicitly

Close #647

* commit 'a1b18776678ee894cd0b558cf8683845056a2dfa':
  + parental, safesearch: use our own DNS resolver instead of system default
  + control: use our own DNS resolver instead of system default
This commit is contained in:
Simon Zolin 2019-04-22 11:28:43 +03:00
commit 00ba63341b
4 changed files with 107 additions and 5 deletions

View File

@ -35,8 +35,13 @@ var protocols = []string{"tls://", "https://", "tcp://", "sdns://"}
const versionCheckURL = "https://adguardteam.github.io/AdGuardHome/version.json" const versionCheckURL = "https://adguardteam.github.io/AdGuardHome/version.json"
const versionCheckPeriod = time.Hour * 8 const versionCheckPeriod = time.Hour * 8
var transport = &http.Transport{
DialContext: customDialContext,
}
var client = &http.Client{ var client = &http.Client{
Timeout: time.Minute * 5, Timeout: time.Minute * 5,
Transport: transport,
} }
var controlLock sync.Mutex var controlLock sync.Mutex

1
dns.go
View File

@ -50,6 +50,7 @@ func generateServerConfig() dnsforward.ServerConfig {
FilteringConfig: config.DNS.FilteringConfig, FilteringConfig: config.DNS.FilteringConfig,
Filters: filters, Filters: filters,
} }
newconfig.ResolverAddress = fmt.Sprintf("%s:%d", config.DNS.BindHost, config.DNS.Port)
if config.TLS.Enabled { if config.TLS.Enabled {
newconfig.TLSConfig = config.TLS.TLSConfig newconfig.TLSConfig = config.TLS.TLSConfig

View File

@ -3,6 +3,7 @@ package dnsfilter
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"crypto/sha256" "crypto/sha256"
"encoding/json" "encoding/json"
"errors" "errors"
@ -16,6 +17,7 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/bluele/gcache" "github.com/bluele/gcache"
"golang.org/x/net/publicsuffix" "golang.org/x/net/publicsuffix"
@ -49,6 +51,7 @@ type Config struct {
ParentalEnabled bool `yaml:"parental_enabled"` ParentalEnabled bool `yaml:"parental_enabled"`
SafeSearchEnabled bool `yaml:"safesearch_enabled"` SafeSearchEnabled bool `yaml:"safesearch_enabled"`
SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"` SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"`
ResolverAddress string // DNS server address
} }
type privateConfig struct { type privateConfig struct {
@ -159,6 +162,8 @@ var (
safeSearchCache gcache.Cache safeSearchCache gcache.Cache
) )
var resolverAddr string // DNS server address
// Result holds state of hostname check // Result holds state of hostname check
type Result struct { type Result struct {
IsFiltered bool `json:",omitempty"` // True if the host name is filtered IsFiltered bool `json:",omitempty"` // True if the host name is filtered
@ -971,6 +976,47 @@ func (d *Dnsfilter) matchHost(host string) (Result, error) {
// lifecycle helper functions // lifecycle helper functions
// //
// Connect to a remote server resolving hostname using our own DNS server
func customDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
log.Tracef("network:%v addr:%v", network, addr)
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
dialer := &net.Dialer{
Timeout: time.Minute * 5,
}
if net.ParseIP(host) != nil {
con, err := dialer.DialContext(ctx, network, addr)
return con, err
}
r := upstream.NewResolver(resolverAddr, 30*time.Second)
addrs, e := r.LookupIPAddr(ctx, host)
log.Tracef("LookupIPAddr: %s: %v", host, addrs)
if e != nil {
return nil, e
}
var firstErr error
firstErr = nil
for _, a := range addrs {
addr = fmt.Sprintf("%s:%s", a.String(), port)
con, err := dialer.DialContext(ctx, network, addr)
if err != nil {
if firstErr == nil {
firstErr = err
}
continue
}
return con, err
}
return nil, firstErr
}
// New creates properly initialized DNS Filter that is ready to be used // New creates properly initialized DNS Filter that is ready to be used
func New(c *Config) *Dnsfilter { func New(c *Config) *Dnsfilter {
d := new(Dnsfilter) d := new(Dnsfilter)
@ -990,6 +1036,10 @@ func New(c *Config) *Dnsfilter {
TLSHandshakeTimeout: 10 * time.Second, TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second, ExpectContinueTimeout: 1 * time.Second,
} }
if len(c.ResolverAddress) != 0 {
resolverAddr = c.ResolverAddress
d.transport.DialContext = customDialContext
}
d.client = http.Client{ d.client = http.Client{
Transport: d.transport, Transport: d.transport,
Timeout: defaultHTTPTimeout, Timeout: defaultHTTPTimeout,

View File

@ -2,6 +2,7 @@ package main
import ( import (
"bufio" "bufio"
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -14,7 +15,10 @@ import (
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/log"
"github.com/joomcode/errorx" "github.com/joomcode/errorx"
) )
@ -300,6 +304,48 @@ func checkPacketPortAvailable(host string, port int) error {
return err return err
} }
// Connect to a remote server resolving hostname using our own DNS server
func customDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
log.Tracef("network:%v addr:%v", network, addr)
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
dialer := &net.Dialer{
Timeout: time.Minute * 5,
}
if net.ParseIP(host) != nil {
con, err := dialer.DialContext(ctx, network, addr)
return con, err
}
resolverAddr := fmt.Sprintf("%s:%d", config.DNS.BindHost, config.DNS.Port)
r := upstream.NewResolver(resolverAddr, 30*time.Second)
addrs, e := r.LookupIPAddr(ctx, host)
log.Tracef("LookupIPAddr: %s: %v", host, addrs)
if e != nil {
return nil, e
}
var firstErr error
firstErr = nil
for _, a := range addrs {
addr = fmt.Sprintf("%s:%s", a.String(), port)
con, err := dialer.DialContext(ctx, network, addr)
if err != nil {
if firstErr == nil {
firstErr = err
}
continue
}
return con, err
}
return nil, firstErr
}
// --------------------- // ---------------------
// debug logging helpers // debug logging helpers
// --------------------- // ---------------------