diff --git a/dnsforward/upstream.go b/dnsforward/upstream.go index 3e0a61c0..b245e7f1 100644 --- a/dnsforward/upstream.go +++ b/dnsforward/upstream.go @@ -7,6 +7,7 @@ import ( "io/ioutil" "log" "math/rand" + "net" "net/http" "net/url" "strings" @@ -172,16 +173,31 @@ func GetUpstream(address string) (Upstream, error) { } switch url.Scheme { case "dns": - return &plainDNS{Address: address}, nil + if url.Port() == "" { + url.Host += ":53" + } + return &plainDNS{Address: url.String()}, nil case "tls": - return &dnsOverTLS{Address: address}, nil + if url.Port() == "" { + url.Host += ":853" + } + return &dnsOverTLS{Address: url.String()}, nil case "https": - return &dnsOverHTTPS{Address: address}, nil + return &dnsOverHTTPS{Address: url.String()}, nil default: - return &plainDNS{Address: address}, nil + // assume it's plain DNS + if url.Port() == "" { + url.Host += ":53" + } + return &plainDNS{Address: url.String()}, nil } } // we don't have scheme in the url, so it's just a plain DNS host:port + _, _, err := net.SplitHostPort(address) + if err != nil { + // doesn't have port, default to 53 + address = net.JoinHostPort(address, "53") + } return &plainDNS{Address: address}, nil }