diff --git a/config.go b/config.go index e44345b4..5265a218 100644 --- a/config.go +++ b/config.go @@ -56,7 +56,7 @@ type filter struct { LastUpdated time.Time `json:"last_updated" yaml:"-"` } -var defaultDNS = []string{"1.1.1.1", "1.0.0.1"} +var defaultDNS = []string{"tls://1.1.1.1", "tls://1.0.0.1"} // initialize to default values, will be changed later when reading config or parsing command line var config = configuration{ diff --git a/control.go b/control.go index a8bb391b..9508f0aa 100644 --- a/control.go +++ b/control.go @@ -506,16 +506,25 @@ func handleStatsTop(w http.ResponseWriter, r *http.Request) { } } +func httpError(w http.ResponseWriter, code int, format string, args ...interface{}) { + text := fmt.Sprintf(format, args...) + http.Error(w, text, code) +} + func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) { body, err := ioutil.ReadAll(r.Body) if err != nil { errortext := fmt.Sprintf("Failed to read request body: %s", err) log.Println(errortext) - http.Error(w, errortext, 400) + http.Error(w, errortext, http.StatusBadRequest) return } // if empty body -- user is asking for default servers - hosts := parseIPsOptionalPort(string(body)) + hosts, err := sanitiseDNSServers(string(body)) + if err != nil { + httpError(w, http.StatusBadRequest, "Invalid DNS servers were given: %s", err) + return + } if len(hosts) == 0 { config.CoreDNS.UpstreamDNS = defaultDNS } else { @@ -584,17 +593,10 @@ func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { } } -func checkDNS(host string) error { - host = appendPortIfMissing(host) - { - h, _, err := net.SplitHostPort(host) - if err != nil { - return err - } - ip := net.ParseIP(h) - if ip == nil { - return fmt.Errorf("Invalid DNS server field: %s", h) - } +func checkDNS(input string) error { + input, err := sanitizeDNSServer(input) + if err != nil { + return err } req := dns.Msg{} @@ -603,45 +605,91 @@ func checkDNS(host string) error { req.Question = []dns.Question{ {"google-public-dns-a.google.com.", dns.TypeA, dns.ClassINET}, } - resp, err := dns.Exchange(&req, host) - if err != nil { - return fmt.Errorf("Couldn't communicate with DNS server %s: %s", host, err) + + prefix, host := splitDNSServerPrefixServer(input) + + c := dns.Client{ + Timeout: time.Minute, } + switch prefix { + case "tls://": + c.Net = "tcp-tls" + } + + resp, rtt, err := c.Exchange(&req, host) + if err != nil { + return fmt.Errorf("Couldn't communicate with DNS server %s: %s", input, err) + } + trace("exchange with %s took %v", input, rtt) if len(resp.Answer) != 1 { - return fmt.Errorf("DNS server %s returned wrong answer", host) + return fmt.Errorf("DNS server %s returned wrong answer", input) } if t, ok := resp.Answer[0].(*dns.A); ok { if !net.IPv4(8, 8, 8, 8).Equal(t.A) { - return fmt.Errorf("DNS server %s returned wrong answer: %v", host, t.A) + return fmt.Errorf("DNS server %s returned wrong answer: %v", input, t.A) } } return nil } -func appendPortIfMissing(input string) string { +func sanitiseDNSServers(input string) ([]string, error) { + fields := strings.Fields(input) + hosts := []string{} + for _, field := range fields { + sanitized, err := sanitizeDNSServer(field) + if err != nil { + return hosts, err + } + hosts = append(hosts, sanitized) + } + return hosts, nil +} + +func getDNSServerPrefix(input string) string { + prefix := "" + switch { + case strings.HasPrefix(input, "dns://"): + prefix = "dns://" + case strings.HasPrefix(input, "tls://"): + prefix = "tls://" + } + return prefix +} + +func splitDNSServerPrefixServer(input string) (string, string) { + prefix := getDNSServerPrefix(input) + host := strings.TrimPrefix(input, prefix) + return prefix, host +} + +func sanitizeDNSServer(input string) (string, error) { + prefix, host := splitDNSServerPrefixServer(input) + host = appendPortIfMissing(prefix, host) + { + h, _, err := net.SplitHostPort(host) + if err != nil { + return "", err + } + ip := net.ParseIP(h) + if ip == nil { + return "", fmt.Errorf("Invalid DNS server field: %s", h) + } + } + return prefix + host, nil +} + +func appendPortIfMissing(prefix, input string) string { + port := "53" + switch prefix { + case "tls://": + port = "853" + } _, _, err := net.SplitHostPort(input) if err == nil { return input } - return net.JoinHostPort(input, "53") -} - -func parseIPsOptionalPort(input string) []string { - fields := strings.Fields(input) - hosts := []string{} - for _, field := range fields { - _, _, err := net.SplitHostPort(field) - if err != nil { - ip := net.ParseIP(field) - if ip == nil { - log.Printf("Invalid DNS server field: %s\n", field) - continue - } - } - hosts = append(hosts, field) - } - return hosts + return net.JoinHostPort(input, port) } func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { diff --git a/helpers.go b/helpers.go index f28b9ea7..d6997141 100644 --- a/helpers.go +++ b/helpers.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "os" "path" "runtime" "sort" @@ -259,5 +260,5 @@ func trace(format string, args ...interface{}) { if len(text) == 0 || text[len(text)-1] != '\n' { buf.WriteRune('\n') } - fmt.Print(buf.String()) + fmt.Fprint(os.Stderr, buf.String()) }