Added factory method for creating DNS upstreams
This commit is contained in:
parent
a6022fc198
commit
9bc4bf66ed
@ -1,6 +1,84 @@
|
|||||||
package upstream
|
package upstream
|
||||||
|
|
||||||
import "github.com/miekg/dns"
|
import (
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Detects the upstream type from the specified url and creates a proper Upstream object
|
||||||
|
func NewUpstream(url string, bootstrap string) (Upstream, error) {
|
||||||
|
|
||||||
|
proto := "udp"
|
||||||
|
prefix := ""
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(url, "tcp://"):
|
||||||
|
proto = "tcp"
|
||||||
|
prefix = "tcp://"
|
||||||
|
case strings.HasPrefix(url, "tls://"):
|
||||||
|
proto = "tcp-tls"
|
||||||
|
prefix = "tls://"
|
||||||
|
case strings.HasPrefix(url, "https://"):
|
||||||
|
return NewHttpsUpstream(url, bootstrap)
|
||||||
|
}
|
||||||
|
|
||||||
|
hostname := strings.TrimPrefix(url, prefix)
|
||||||
|
|
||||||
|
host, port, err := net.SplitHostPort(hostname)
|
||||||
|
if err != nil {
|
||||||
|
// Set port depending on the protocol
|
||||||
|
switch proto {
|
||||||
|
case "udp":
|
||||||
|
port = "53"
|
||||||
|
case "tcp":
|
||||||
|
port = "53"
|
||||||
|
case "tcp-tls":
|
||||||
|
port = "853"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set host = hostname
|
||||||
|
host = hostname
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to resolve the host address (or check if it's an IP address)
|
||||||
|
bootstrapResolver := CreateResolver(bootstrap)
|
||||||
|
ips, err := bootstrapResolver.LookupIPAddr(context.Background(), host)
|
||||||
|
|
||||||
|
if err != nil || len(ips) == 0 {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := ips[0].String()
|
||||||
|
endpoint := net.JoinHostPort(addr, port)
|
||||||
|
tlsServerName := ""
|
||||||
|
|
||||||
|
if proto == "tcp-tls" && host != addr {
|
||||||
|
// Check if we need to specify TLS server name
|
||||||
|
tlsServerName = host
|
||||||
|
}
|
||||||
|
|
||||||
|
return NewDnsUpstream(endpoint, proto, tlsServerName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateResolver(bootstrap string) *net.Resolver {
|
||||||
|
|
||||||
|
bootstrapResolver := net.DefaultResolver
|
||||||
|
|
||||||
|
if bootstrap != "" {
|
||||||
|
bootstrapResolver = &net.Resolver{
|
||||||
|
PreferGo: true,
|
||||||
|
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
var d net.Dialer
|
||||||
|
conn, err := d.DialContext(ctx, network, bootstrap)
|
||||||
|
return conn, err
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return bootstrapResolver
|
||||||
|
}
|
||||||
|
|
||||||
// Performs a simple health-check of the specified upstream
|
// Performs a simple health-check of the specified upstream
|
||||||
func IsAlive(u Upstream) (bool, error) {
|
func IsAlive(u Upstream) (bool, error) {
|
||||||
|
@ -27,7 +27,7 @@ type HttpsUpstream struct {
|
|||||||
endpoint *url.URL
|
endpoint *url.URL
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHttpsUpstream creates a new DNS-over-HTTPS upstream from hostname
|
// NewHttpsUpstream creates a new DNS-over-HTTPS upstream from the specified url
|
||||||
func NewHttpsUpstream(endpoint string, bootstrap string) (Upstream, error) {
|
func NewHttpsUpstream(endpoint string, bootstrap string) (Upstream, error) {
|
||||||
u, err := url.Parse(endpoint)
|
u, err := url.Parse(endpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -35,18 +35,7 @@ func NewHttpsUpstream(endpoint string, bootstrap string) (Upstream, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Initialize bootstrap resolver
|
// Initialize bootstrap resolver
|
||||||
bootstrapResolver := net.DefaultResolver
|
bootstrapResolver := CreateResolver(bootstrap)
|
||||||
if bootstrap != "" {
|
|
||||||
bootstrapResolver = &net.Resolver{
|
|
||||||
PreferGo: true,
|
|
||||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
|
||||||
var d net.Dialer
|
|
||||||
conn, err := d.DialContext(ctx, network, bootstrap)
|
|
||||||
return conn, err
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
dialer := &net.Dialer{
|
dialer := &net.Dialer{
|
||||||
Timeout: defaultTimeout,
|
Timeout: defaultTimeout,
|
||||||
KeepAlive: defaultKeepAlive,
|
KeepAlive: defaultKeepAlive,
|
||||||
|
@ -9,16 +9,17 @@ import (
|
|||||||
func TestDnsUpstreamIsAlive(t *testing.T) {
|
func TestDnsUpstreamIsAlive(t *testing.T) {
|
||||||
|
|
||||||
var tests = []struct {
|
var tests = []struct {
|
||||||
endpoint string
|
url string
|
||||||
proto string
|
bootstrap string
|
||||||
}{
|
}{
|
||||||
{"8.8.8.8:53", "udp"},
|
{"8.8.8.8:53", "8.8.8.8:53"},
|
||||||
{"8.8.8.8:53", "tcp"},
|
{"1.1.1.1", ""},
|
||||||
{"1.1.1.1:53", "udp"},
|
{"tcp://1.1.1.1:53", ""},
|
||||||
|
{"176.103.130.130:5353", ""},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
u, err := NewDnsUpstream(test.endpoint, test.proto, "")
|
u, err := NewUpstream(test.url, test.bootstrap)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("cannot create a DNS upstream")
|
t.Errorf("cannot create a DNS upstream")
|
||||||
@ -36,11 +37,11 @@ func TestHttpsUpstreamIsAlive(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
{"https://cloudflare-dns.com/dns-query", "8.8.8.8:53"},
|
{"https://cloudflare-dns.com/dns-query", "8.8.8.8:53"},
|
||||||
{"https://dns.google.com/experimental", "8.8.8.8:53"},
|
{"https://dns.google.com/experimental", "8.8.8.8:53"},
|
||||||
{"https://doh.cleanbrowsing.org/doh/security-filter/", ""}, // TODO: status 201??
|
{"https://doh.cleanbrowsing.org/doh/security-filter/", ""},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
u, err := NewHttpsUpstream(test.url, test.bootstrap)
|
u, err := NewUpstream(test.url, test.bootstrap)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("cannot create a DNS-over-HTTPS upstream")
|
t.Errorf("cannot create a DNS-over-HTTPS upstream")
|
||||||
@ -53,16 +54,17 @@ func TestHttpsUpstreamIsAlive(t *testing.T) {
|
|||||||
func TestDnsOverTlsIsAlive(t *testing.T) {
|
func TestDnsOverTlsIsAlive(t *testing.T) {
|
||||||
|
|
||||||
var tests = []struct {
|
var tests = []struct {
|
||||||
endpoint string
|
url string
|
||||||
tlsServerName string
|
bootstrap string
|
||||||
}{
|
}{
|
||||||
{"1.1.1.1:853", ""},
|
{"tls://1.1.1.1", ""},
|
||||||
{"9.9.9.9:853", ""},
|
{"tls://9.9.9.9:853", ""},
|
||||||
{"185.228.168.10:853", "security-filter-dns.cleanbrowsing.org"},
|
{"tls://security-filter-dns.cleanbrowsing.org", "8.8.8.8:53"},
|
||||||
|
{"tls://adult-filter-dns.cleanbrowsing.org:853", "8.8.8.8:53"},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
u, err := NewDnsUpstream(test.endpoint, "tcp-tls", test.tlsServerName)
|
u, err := NewUpstream(test.url, test.bootstrap)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("cannot create a DNS-over-TLS upstream")
|
t.Errorf("cannot create a DNS-over-TLS upstream")
|
||||||
@ -75,16 +77,17 @@ func TestDnsOverTlsIsAlive(t *testing.T) {
|
|||||||
func TestDnsUpstream(t *testing.T) {
|
func TestDnsUpstream(t *testing.T) {
|
||||||
|
|
||||||
var tests = []struct {
|
var tests = []struct {
|
||||||
endpoint string
|
url string
|
||||||
proto string
|
bootstrap string
|
||||||
}{
|
}{
|
||||||
{"8.8.8.8:53", "udp"},
|
{"8.8.8.8:53", "8.8.8.8:53"},
|
||||||
{"8.8.8.8:53", "tcp"},
|
{"1.1.1.1", ""},
|
||||||
{"1.1.1.1:53", "udp"},
|
{"tcp://1.1.1.1:53", ""},
|
||||||
|
{"176.103.130.130:5353", ""},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
u, err := NewDnsUpstream(test.endpoint, test.proto, "")
|
u, err := NewUpstream(test.url, test.bootstrap)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("cannot create a DNS upstream")
|
t.Errorf("cannot create a DNS upstream")
|
||||||
@ -106,7 +109,7 @@ func TestHttpsUpstream(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
u, err := NewHttpsUpstream(test.url, test.bootstrap)
|
u, err := NewUpstream(test.url, test.bootstrap)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("cannot create a DNS-over-HTTPS upstream")
|
t.Errorf("cannot create a DNS-over-HTTPS upstream")
|
||||||
@ -119,16 +122,17 @@ func TestHttpsUpstream(t *testing.T) {
|
|||||||
func TestDnsOverTlsUpstream(t *testing.T) {
|
func TestDnsOverTlsUpstream(t *testing.T) {
|
||||||
|
|
||||||
var tests = []struct {
|
var tests = []struct {
|
||||||
endpoint string
|
url string
|
||||||
tlsServerName string
|
bootstrap string
|
||||||
}{
|
}{
|
||||||
{"1.1.1.1:853", ""},
|
{"tls://1.1.1.1", ""},
|
||||||
{"9.9.9.9:853", ""},
|
{"tls://9.9.9.9:853", ""},
|
||||||
{"185.228.168.10:853", "security-filter-dns.cleanbrowsing.org"},
|
{"tls://security-filter-dns.cleanbrowsing.org", "8.8.8.8:53"},
|
||||||
|
{"tls://adult-filter-dns.cleanbrowsing.org:853", "8.8.8.8:53"},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
u, err := NewDnsUpstream(test.endpoint, "tcp-tls", test.tlsServerName)
|
u, err := NewUpstream(test.url, test.bootstrap)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("cannot create a DNS-over-TLS upstream")
|
t.Errorf("cannot create a DNS-over-TLS upstream")
|
||||||
|
Loading…
Reference in New Issue
Block a user