From bebdc1b5bc8bcefe4b31743033d45d098a81b0ff Mon Sep 17 00:00:00 2001 From: Simon Zolin Date: Wed, 17 Apr 2019 14:53:52 +0300 Subject: [PATCH] + control: use our own DNS resolver instead of system default --- control.go | 7 ++++++- helpers.go | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/control.go b/control.go index fe81eaa2..8ded7372 100644 --- a/control.go +++ b/control.go @@ -35,8 +35,13 @@ var protocols = []string{"tls://", "https://", "tcp://", "sdns://"} const versionCheckURL = "https://adguardteam.github.io/AdGuardHome/version.json" const versionCheckPeriod = time.Hour * 8 +var transport = &http.Transport{ + DialContext: customDialContext, +} + var client = &http.Client{ - Timeout: time.Minute * 5, + Timeout: time.Minute * 5, + Transport: transport, } var controlLock sync.Mutex diff --git a/helpers.go b/helpers.go index a304634e..184789d4 100644 --- a/helpers.go +++ b/helpers.go @@ -2,6 +2,7 @@ package main import ( "bufio" + "context" "errors" "fmt" "io" @@ -14,7 +15,10 @@ import ( "runtime" "strconv" "strings" + "time" + "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/log" "github.com/joomcode/errorx" ) @@ -300,6 +304,48 @@ func checkPacketPortAvailable(host string, port int) error { return err } +// Connect to a remote server resolving hostname using our own DNS server +func customDialContext(ctx context.Context, network, addr string) (net.Conn, error) { + log.Tracef("network:%v addr:%v", network, addr) + + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + dialer := &net.Dialer{ + Timeout: time.Minute * 5, + } + + if net.ParseIP(host) != nil { + con, err := dialer.DialContext(ctx, network, addr) + return con, err + } + + resolverAddr := fmt.Sprintf("%s:%d", config.DNS.BindHost, config.DNS.Port) + r := upstream.NewResolver(resolverAddr, 30*time.Second) + addrs, e := r.LookupIPAddr(ctx, host) + log.Tracef("LookupIPAddr: %s: %v", host, addrs) + if e != nil { + return nil, e + } + + var firstErr error + firstErr = nil + for _, a := range addrs { + addr = fmt.Sprintf("%s:%s", a.String(), port) + con, err := dialer.DialContext(ctx, network, addr) + if err != nil { + if firstErr == nil { + firstErr = err + } + continue + } + return con, err + } + return nil, firstErr +} + // --------------------- // debug logging helpers // ---------------------