Merge pull request #40 in DNS/adguard-dns from feature/347 to master
* commit 'ff86d6b7dc31e463651c11f02330630e35676e05': Set default servers to tls://1.1.1.1 and tls://1.0.0.1
This commit is contained in:
commit
de7b2d5e6b
|
@ -56,7 +56,7 @@ type filter struct {
|
||||||
LastUpdated time.Time `json:"last_updated" yaml:"-"`
|
LastUpdated time.Time `json:"last_updated" yaml:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var defaultDNS = []string{"1.1.1.1", "1.0.0.1"}
|
var defaultDNS = []string{"tls://1.1.1.1", "tls://1.0.0.1"}
|
||||||
|
|
||||||
// initialize to default values, will be changed later when reading config or parsing command line
|
// initialize to default values, will be changed later when reading config or parsing command line
|
||||||
var config = configuration{
|
var config = configuration{
|
||||||
|
|
118
control.go
118
control.go
|
@ -506,16 +506,25 @@ func handleStatsTop(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func httpError(w http.ResponseWriter, code int, format string, args ...interface{}) {
|
||||||
|
text := fmt.Sprintf(format, args...)
|
||||||
|
http.Error(w, text, code)
|
||||||
|
}
|
||||||
|
|
||||||
func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
body, err := ioutil.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errortext := fmt.Sprintf("Failed to read request body: %s", err)
|
errortext := fmt.Sprintf("Failed to read request body: %s", err)
|
||||||
log.Println(errortext)
|
log.Println(errortext)
|
||||||
http.Error(w, errortext, 400)
|
http.Error(w, errortext, http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// if empty body -- user is asking for default servers
|
// if empty body -- user is asking for default servers
|
||||||
hosts := parseIPsOptionalPort(string(body))
|
hosts, err := sanitiseDNSServers(string(body))
|
||||||
|
if err != nil {
|
||||||
|
httpError(w, http.StatusBadRequest, "Invalid DNS servers were given: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
if len(hosts) == 0 {
|
if len(hosts) == 0 {
|
||||||
config.CoreDNS.UpstreamDNS = defaultDNS
|
config.CoreDNS.UpstreamDNS = defaultDNS
|
||||||
} else {
|
} else {
|
||||||
|
@ -584,18 +593,11 @@ func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkDNS(host string) error {
|
func checkDNS(input string) error {
|
||||||
host = appendPortIfMissing(host)
|
input, err := sanitizeDNSServer(input)
|
||||||
{
|
|
||||||
h, _, err := net.SplitHostPort(host)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
ip := net.ParseIP(h)
|
|
||||||
if ip == nil {
|
|
||||||
return fmt.Errorf("Invalid DNS server field: %s", h)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
req := dns.Msg{}
|
req := dns.Msg{}
|
||||||
req.Id = dns.Id()
|
req.Id = dns.Id()
|
||||||
|
@ -603,45 +605,91 @@ func checkDNS(host string) error {
|
||||||
req.Question = []dns.Question{
|
req.Question = []dns.Question{
|
||||||
{"google-public-dns-a.google.com.", dns.TypeA, dns.ClassINET},
|
{"google-public-dns-a.google.com.", dns.TypeA, dns.ClassINET},
|
||||||
}
|
}
|
||||||
resp, err := dns.Exchange(&req, host)
|
|
||||||
if err != nil {
|
prefix, host := splitDNSServerPrefixServer(input)
|
||||||
return fmt.Errorf("Couldn't communicate with DNS server %s: %s", host, err)
|
|
||||||
|
c := dns.Client{
|
||||||
|
Timeout: time.Minute,
|
||||||
}
|
}
|
||||||
|
switch prefix {
|
||||||
|
case "tls://":
|
||||||
|
c.Net = "tcp-tls"
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, rtt, err := c.Exchange(&req, host)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Couldn't communicate with DNS server %s: %s", input, err)
|
||||||
|
}
|
||||||
|
trace("exchange with %s took %v", input, rtt)
|
||||||
if len(resp.Answer) != 1 {
|
if len(resp.Answer) != 1 {
|
||||||
return fmt.Errorf("DNS server %s returned wrong answer", host)
|
return fmt.Errorf("DNS server %s returned wrong answer", input)
|
||||||
}
|
}
|
||||||
if t, ok := resp.Answer[0].(*dns.A); ok {
|
if t, ok := resp.Answer[0].(*dns.A); ok {
|
||||||
if !net.IPv4(8, 8, 8, 8).Equal(t.A) {
|
if !net.IPv4(8, 8, 8, 8).Equal(t.A) {
|
||||||
return fmt.Errorf("DNS server %s returned wrong answer: %v", host, t.A)
|
return fmt.Errorf("DNS server %s returned wrong answer: %v", input, t.A)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func appendPortIfMissing(input string) string {
|
func sanitiseDNSServers(input string) ([]string, error) {
|
||||||
|
fields := strings.Fields(input)
|
||||||
|
hosts := []string{}
|
||||||
|
for _, field := range fields {
|
||||||
|
sanitized, err := sanitizeDNSServer(field)
|
||||||
|
if err != nil {
|
||||||
|
return hosts, err
|
||||||
|
}
|
||||||
|
hosts = append(hosts, sanitized)
|
||||||
|
}
|
||||||
|
return hosts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getDNSServerPrefix(input string) string {
|
||||||
|
prefix := ""
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(input, "dns://"):
|
||||||
|
prefix = "dns://"
|
||||||
|
case strings.HasPrefix(input, "tls://"):
|
||||||
|
prefix = "tls://"
|
||||||
|
}
|
||||||
|
return prefix
|
||||||
|
}
|
||||||
|
|
||||||
|
func splitDNSServerPrefixServer(input string) (string, string) {
|
||||||
|
prefix := getDNSServerPrefix(input)
|
||||||
|
host := strings.TrimPrefix(input, prefix)
|
||||||
|
return prefix, host
|
||||||
|
}
|
||||||
|
|
||||||
|
func sanitizeDNSServer(input string) (string, error) {
|
||||||
|
prefix, host := splitDNSServerPrefixServer(input)
|
||||||
|
host = appendPortIfMissing(prefix, host)
|
||||||
|
{
|
||||||
|
h, _, err := net.SplitHostPort(host)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
ip := net.ParseIP(h)
|
||||||
|
if ip == nil {
|
||||||
|
return "", fmt.Errorf("Invalid DNS server field: %s", h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return prefix + host, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendPortIfMissing(prefix, input string) string {
|
||||||
|
port := "53"
|
||||||
|
switch prefix {
|
||||||
|
case "tls://":
|
||||||
|
port = "853"
|
||||||
|
}
|
||||||
_, _, err := net.SplitHostPort(input)
|
_, _, err := net.SplitHostPort(input)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return input
|
return input
|
||||||
}
|
}
|
||||||
return net.JoinHostPort(input, "53")
|
return net.JoinHostPort(input, port)
|
||||||
}
|
|
||||||
|
|
||||||
func parseIPsOptionalPort(input string) []string {
|
|
||||||
fields := strings.Fields(input)
|
|
||||||
hosts := []string{}
|
|
||||||
for _, field := range fields {
|
|
||||||
_, _, err := net.SplitHostPort(field)
|
|
||||||
if err != nil {
|
|
||||||
ip := net.ParseIP(field)
|
|
||||||
if ip == nil {
|
|
||||||
log.Printf("Invalid DNS server field: %s\n", field)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
hosts = append(hosts, field)
|
|
||||||
}
|
|
||||||
return hosts
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
|
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sort"
|
"sort"
|
||||||
|
@ -259,5 +260,5 @@ func trace(format string, args ...interface{}) {
|
||||||
if len(text) == 0 || text[len(text)-1] != '\n' {
|
if len(text) == 0 || text[len(text)-1] != '\n' {
|
||||||
buf.WriteRune('\n')
|
buf.WriteRune('\n')
|
||||||
}
|
}
|
||||||
fmt.Print(buf.String())
|
fmt.Fprint(os.Stderr, buf.String())
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue