Merge pull request #40 in DNS/adguard-dns from feature/347 to master

* commit 'ff86d6b7dc31e463651c11f02330630e35676e05':
  Set default servers to tls://1.1.1.1 and tls://1.0.0.1
This commit is contained in:
Eugene Bujak 2018-09-26 18:00:34 +03:00
commit de7b2d5e6b
3 changed files with 88 additions and 39 deletions

View File

@ -56,7 +56,7 @@ type filter struct {
LastUpdated time.Time `json:"last_updated" yaml:"-"` LastUpdated time.Time `json:"last_updated" yaml:"-"`
} }
var defaultDNS = []string{"1.1.1.1", "1.0.0.1"} var defaultDNS = []string{"tls://1.1.1.1", "tls://1.0.0.1"}
// initialize to default values, will be changed later when reading config or parsing command line // initialize to default values, will be changed later when reading config or parsing command line
var config = configuration{ var config = configuration{

View File

@ -506,16 +506,25 @@ func handleStatsTop(w http.ResponseWriter, r *http.Request) {
} }
} }
func httpError(w http.ResponseWriter, code int, format string, args ...interface{}) {
text := fmt.Sprintf(format, args...)
http.Error(w, text, code)
}
func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) { func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Failed to read request body: %s", err) errortext := fmt.Sprintf("Failed to read request body: %s", err)
log.Println(errortext) log.Println(errortext)
http.Error(w, errortext, 400) http.Error(w, errortext, http.StatusBadRequest)
return return
} }
// if empty body -- user is asking for default servers // if empty body -- user is asking for default servers
hosts := parseIPsOptionalPort(string(body)) hosts, err := sanitiseDNSServers(string(body))
if err != nil {
httpError(w, http.StatusBadRequest, "Invalid DNS servers were given: %s", err)
return
}
if len(hosts) == 0 { if len(hosts) == 0 {
config.CoreDNS.UpstreamDNS = defaultDNS config.CoreDNS.UpstreamDNS = defaultDNS
} else { } else {
@ -584,18 +593,11 @@ func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
} }
} }
func checkDNS(host string) error { func checkDNS(input string) error {
host = appendPortIfMissing(host) input, err := sanitizeDNSServer(input)
{
h, _, err := net.SplitHostPort(host)
if err != nil { if err != nil {
return err return err
} }
ip := net.ParseIP(h)
if ip == nil {
return fmt.Errorf("Invalid DNS server field: %s", h)
}
}
req := dns.Msg{} req := dns.Msg{}
req.Id = dns.Id() req.Id = dns.Id()
@ -603,45 +605,91 @@ func checkDNS(host string) error {
req.Question = []dns.Question{ req.Question = []dns.Question{
{"google-public-dns-a.google.com.", dns.TypeA, dns.ClassINET}, {"google-public-dns-a.google.com.", dns.TypeA, dns.ClassINET},
} }
resp, err := dns.Exchange(&req, host)
if err != nil { prefix, host := splitDNSServerPrefixServer(input)
return fmt.Errorf("Couldn't communicate with DNS server %s: %s", host, err)
c := dns.Client{
Timeout: time.Minute,
} }
switch prefix {
case "tls://":
c.Net = "tcp-tls"
}
resp, rtt, err := c.Exchange(&req, host)
if err != nil {
return fmt.Errorf("Couldn't communicate with DNS server %s: %s", input, err)
}
trace("exchange with %s took %v", input, rtt)
if len(resp.Answer) != 1 { if len(resp.Answer) != 1 {
return fmt.Errorf("DNS server %s returned wrong answer", host) return fmt.Errorf("DNS server %s returned wrong answer", input)
} }
if t, ok := resp.Answer[0].(*dns.A); ok { if t, ok := resp.Answer[0].(*dns.A); ok {
if !net.IPv4(8, 8, 8, 8).Equal(t.A) { if !net.IPv4(8, 8, 8, 8).Equal(t.A) {
return fmt.Errorf("DNS server %s returned wrong answer: %v", host, t.A) return fmt.Errorf("DNS server %s returned wrong answer: %v", input, t.A)
} }
} }
return nil return nil
} }
func appendPortIfMissing(input string) string { func sanitiseDNSServers(input string) ([]string, error) {
fields := strings.Fields(input)
hosts := []string{}
for _, field := range fields {
sanitized, err := sanitizeDNSServer(field)
if err != nil {
return hosts, err
}
hosts = append(hosts, sanitized)
}
return hosts, nil
}
func getDNSServerPrefix(input string) string {
prefix := ""
switch {
case strings.HasPrefix(input, "dns://"):
prefix = "dns://"
case strings.HasPrefix(input, "tls://"):
prefix = "tls://"
}
return prefix
}
func splitDNSServerPrefixServer(input string) (string, string) {
prefix := getDNSServerPrefix(input)
host := strings.TrimPrefix(input, prefix)
return prefix, host
}
func sanitizeDNSServer(input string) (string, error) {
prefix, host := splitDNSServerPrefixServer(input)
host = appendPortIfMissing(prefix, host)
{
h, _, err := net.SplitHostPort(host)
if err != nil {
return "", err
}
ip := net.ParseIP(h)
if ip == nil {
return "", fmt.Errorf("Invalid DNS server field: %s", h)
}
}
return prefix + host, nil
}
func appendPortIfMissing(prefix, input string) string {
port := "53"
switch prefix {
case "tls://":
port = "853"
}
_, _, err := net.SplitHostPort(input) _, _, err := net.SplitHostPort(input)
if err == nil { if err == nil {
return input return input
} }
return net.JoinHostPort(input, "53") return net.JoinHostPort(input, port)
}
func parseIPsOptionalPort(input string) []string {
fields := strings.Fields(input)
hosts := []string{}
for _, field := range fields {
_, _, err := net.SplitHostPort(field)
if err != nil {
ip := net.ParseIP(field)
if ip == nil {
log.Printf("Invalid DNS server field: %s\n", field)
continue
}
}
hosts = append(hosts, field)
}
return hosts
} }
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {

View File

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"os"
"path" "path"
"runtime" "runtime"
"sort" "sort"
@ -259,5 +260,5 @@ func trace(format string, args ...interface{}) {
if len(text) == 0 || text[len(text)-1] != '\n' { if len(text) == 0 || text[len(text)-1] != '\n' {
buf.WriteRune('\n') buf.WriteRune('\n')
} }
fmt.Print(buf.String()) fmt.Fprint(os.Stderr, buf.String())
} }