diff --git a/upstream/helpers.go b/upstream/helpers.go index 6e9cc30a..0e698698 100644 --- a/upstream/helpers.go +++ b/upstream/helpers.go @@ -1,6 +1,84 @@ package upstream -import "github.com/miekg/dns" +import ( + "github.com/miekg/dns" + "golang.org/x/net/context" + "net" + "strings" +) + +// Detects the upstream type from the specified url and creates a proper Upstream object +func NewUpstream(url string, bootstrap string) (Upstream, error) { + + proto := "udp" + prefix := "" + + switch { + case strings.HasPrefix(url, "tcp://"): + proto = "tcp" + prefix = "tcp://" + case strings.HasPrefix(url, "tls://"): + proto = "tcp-tls" + prefix = "tls://" + case strings.HasPrefix(url, "https://"): + return NewHttpsUpstream(url, bootstrap) + } + + hostname := strings.TrimPrefix(url, prefix) + + host, port, err := net.SplitHostPort(hostname) + if err != nil { + // Set port depending on the protocol + switch proto { + case "udp": + port = "53" + case "tcp": + port = "53" + case "tcp-tls": + port = "853" + } + + // Set host = hostname + host = hostname + } + + // Try to resolve the host address (or check if it's an IP address) + bootstrapResolver := CreateResolver(bootstrap) + ips, err := bootstrapResolver.LookupIPAddr(context.Background(), host) + + if err != nil || len(ips) == 0 { + return nil, err + } + + addr := ips[0].String() + endpoint := net.JoinHostPort(addr, port) + tlsServerName := "" + + if proto == "tcp-tls" && host != addr { + // Check if we need to specify TLS server name + tlsServerName = host + } + + return NewDnsUpstream(endpoint, proto, tlsServerName) +} + +func CreateResolver(bootstrap string) *net.Resolver { + + bootstrapResolver := net.DefaultResolver + + if bootstrap != "" { + bootstrapResolver = &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + var d net.Dialer + conn, err := d.DialContext(ctx, network, bootstrap) + return conn, err + }, + } + } + + return bootstrapResolver +} // Performs a simple health-check of the specified upstream func IsAlive(u Upstream) (bool, error) { diff --git a/upstream/https_upstream.go b/upstream/https_upstream.go index 906eadf2..ae705699 100644 --- a/upstream/https_upstream.go +++ b/upstream/https_upstream.go @@ -27,7 +27,7 @@ type HttpsUpstream struct { endpoint *url.URL } -// NewHttpsUpstream creates a new DNS-over-HTTPS upstream from hostname +// NewHttpsUpstream creates a new DNS-over-HTTPS upstream from the specified url func NewHttpsUpstream(endpoint string, bootstrap string) (Upstream, error) { u, err := url.Parse(endpoint) if err != nil { @@ -35,18 +35,7 @@ func NewHttpsUpstream(endpoint string, bootstrap string) (Upstream, error) { } // Initialize bootstrap resolver - bootstrapResolver := net.DefaultResolver - if bootstrap != "" { - bootstrapResolver = &net.Resolver{ - PreferGo: true, - Dial: func(ctx context.Context, network, address string) (net.Conn, error) { - var d net.Dialer - conn, err := d.DialContext(ctx, network, bootstrap) - return conn, err - }, - } - } - + bootstrapResolver := CreateResolver(bootstrap) dialer := &net.Dialer{ Timeout: defaultTimeout, KeepAlive: defaultKeepAlive, diff --git a/upstream/upstream_test.go b/upstream/upstream_test.go index f612fc6e..1b3235fe 100644 --- a/upstream/upstream_test.go +++ b/upstream/upstream_test.go @@ -9,16 +9,17 @@ import ( func TestDnsUpstreamIsAlive(t *testing.T) { var tests = []struct { - endpoint string - proto string + url string + bootstrap string }{ - {"8.8.8.8:53", "udp"}, - {"8.8.8.8:53", "tcp"}, - {"1.1.1.1:53", "udp"}, + {"8.8.8.8:53", "8.8.8.8:53"}, + {"1.1.1.1", ""}, + {"tcp://1.1.1.1:53", ""}, + {"176.103.130.130:5353", ""}, } for _, test := range tests { - u, err := NewDnsUpstream(test.endpoint, test.proto, "") + u, err := NewUpstream(test.url, test.bootstrap) if err != nil { t.Errorf("cannot create a DNS upstream") @@ -36,11 +37,11 @@ func TestHttpsUpstreamIsAlive(t *testing.T) { }{ {"https://cloudflare-dns.com/dns-query", "8.8.8.8:53"}, {"https://dns.google.com/experimental", "8.8.8.8:53"}, - {"https://doh.cleanbrowsing.org/doh/security-filter/", ""}, // TODO: status 201?? + {"https://doh.cleanbrowsing.org/doh/security-filter/", ""}, } for _, test := range tests { - u, err := NewHttpsUpstream(test.url, test.bootstrap) + u, err := NewUpstream(test.url, test.bootstrap) if err != nil { t.Errorf("cannot create a DNS-over-HTTPS upstream") @@ -53,16 +54,17 @@ func TestHttpsUpstreamIsAlive(t *testing.T) { func TestDnsOverTlsIsAlive(t *testing.T) { var tests = []struct { - endpoint string - tlsServerName string + url string + bootstrap string }{ - {"1.1.1.1:853", ""}, - {"9.9.9.9:853", ""}, - {"185.228.168.10:853", "security-filter-dns.cleanbrowsing.org"}, + {"tls://1.1.1.1", ""}, + {"tls://9.9.9.9:853", ""}, + {"tls://security-filter-dns.cleanbrowsing.org", "8.8.8.8:53"}, + {"tls://adult-filter-dns.cleanbrowsing.org:853", "8.8.8.8:53"}, } for _, test := range tests { - u, err := NewDnsUpstream(test.endpoint, "tcp-tls", test.tlsServerName) + u, err := NewUpstream(test.url, test.bootstrap) if err != nil { t.Errorf("cannot create a DNS-over-TLS upstream") @@ -75,16 +77,17 @@ func TestDnsOverTlsIsAlive(t *testing.T) { func TestDnsUpstream(t *testing.T) { var tests = []struct { - endpoint string - proto string + url string + bootstrap string }{ - {"8.8.8.8:53", "udp"}, - {"8.8.8.8:53", "tcp"}, - {"1.1.1.1:53", "udp"}, + {"8.8.8.8:53", "8.8.8.8:53"}, + {"1.1.1.1", ""}, + {"tcp://1.1.1.1:53", ""}, + {"176.103.130.130:5353", ""}, } for _, test := range tests { - u, err := NewDnsUpstream(test.endpoint, test.proto, "") + u, err := NewUpstream(test.url, test.bootstrap) if err != nil { t.Errorf("cannot create a DNS upstream") @@ -106,7 +109,7 @@ func TestHttpsUpstream(t *testing.T) { } for _, test := range tests { - u, err := NewHttpsUpstream(test.url, test.bootstrap) + u, err := NewUpstream(test.url, test.bootstrap) if err != nil { t.Errorf("cannot create a DNS-over-HTTPS upstream") @@ -119,16 +122,17 @@ func TestHttpsUpstream(t *testing.T) { func TestDnsOverTlsUpstream(t *testing.T) { var tests = []struct { - endpoint string - tlsServerName string + url string + bootstrap string }{ - {"1.1.1.1:853", ""}, - {"9.9.9.9:853", ""}, - {"185.228.168.10:853", "security-filter-dns.cleanbrowsing.org"}, + {"tls://1.1.1.1", ""}, + {"tls://9.9.9.9:853", ""}, + {"tls://security-filter-dns.cleanbrowsing.org", "8.8.8.8:53"}, + {"tls://adult-filter-dns.cleanbrowsing.org:853", "8.8.8.8:53"}, } for _, test := range tests { - u, err := NewDnsUpstream(test.endpoint, "tcp-tls", test.tlsServerName) + u, err := NewUpstream(test.url, test.bootstrap) if err != nil { t.Errorf("cannot create a DNS-over-TLS upstream")