diff --git a/config.go b/config.go index 1a0d18a4..a8534bc5 100644 --- a/config.go +++ b/config.go @@ -70,6 +70,7 @@ type coreDNSConfig struct { Pprof string `yaml:"-"` Cache string `yaml:"-"` Prometheus string `yaml:"-"` + BootstrapDNS string `yaml:"bootstrap_dns"` UpstreamDNS []string `yaml:"upstream_dns"` } @@ -100,6 +101,7 @@ var config = configuration{ SafeBrowsingEnabled: false, BlockedResponseTTL: 10, // in seconds QueryLogEnabled: true, + BootstrapDNS: "8.8.8.8:53", UpstreamDNS: defaultDNS, Cache: "cache", Prometheus: "prometheus :9153", @@ -253,7 +255,7 @@ const coreDNSConfigTemplate = `.:{{.Port}} { hosts { fallthrough } - {{if .UpstreamDNS}}upstream {{range .UpstreamDNS}}{{.}} {{end}} { bootstrap 8.8.8.8:53 }{{end}} + {{if .UpstreamDNS}}upstream {{range .UpstreamDNS}}{{.}} {{end}} { bootstrap {{.BootstrapDNS}} }{{end}} {{.Cache}} {{.Prometheus}} } diff --git a/control.go b/control.go index 378a2aec..238bc131 100644 --- a/control.go +++ b/control.go @@ -6,7 +6,6 @@ import ( "fmt" "io/ioutil" "log" - "net" "net/http" "os" "path/filepath" @@ -15,8 +14,9 @@ import ( "strings" "time" + "github.com/AdguardTeam/AdGuardHome/upstream" + corednsplugin "github.com/AdguardTeam/AdGuardHome/coredns_plugin" - "github.com/miekg/dns" "gopkg.in/asaskevich/govalidator.v4" ) @@ -81,6 +81,7 @@ func handleStatus(w http.ResponseWriter, r *http.Request) { "protection_enabled": config.CoreDNS.ProtectionEnabled, "querylog_enabled": config.CoreDNS.QueryLogEnabled, "running": isRunning(), + "bootstrap_dns": config.CoreDNS.BootstrapDNS, "upstream_dns": config.CoreDNS.UpstreamDNS, "version": VersionString, } @@ -140,11 +141,8 @@ func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) { return } // if empty body -- user is asking for default servers - hosts, err := sanitiseDNSServers(string(body)) - if err != nil { - httpError(w, http.StatusBadRequest, "Invalid DNS servers were given: %s", err) - return - } + hosts := strings.Fields(string(body)) + if len(hosts) == 0 { config.CoreDNS.UpstreamDNS = defaultDNS } else { @@ -214,104 +212,26 @@ func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { } func checkDNS(input string) error { - input, err := sanitizeDNSServer(input) + + u, err := upstream.NewUpstream(input, config.CoreDNS.BootstrapDNS) + if err != nil { return err } - req := dns.Msg{} - req.Id = dns.Id() - req.RecursionDesired = true - req.Question = []dns.Question{ - {Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, - } + alive, err := upstream.IsAlive(u) - prefix, host := splitDNSServerPrefixServer(input) - - c := dns.Client{ - Timeout: time.Minute, - } - switch prefix { - case "tls://": - c.Net = "tcp-tls" - } - - resp, rtt, err := c.Exchange(&req, host) if err != nil { return fmt.Errorf("couldn't communicate with DNS server %s: %s", input, err) } - trace("exchange with %s took %v", input, rtt) - if len(resp.Answer) != 1 { - return fmt.Errorf("DNS server %s returned wrong answer", input) - } - if t, ok := resp.Answer[0].(*dns.A); ok { - if !net.IPv4(8, 8, 8, 8).Equal(t.A) { - return fmt.Errorf("DNS server %s returned wrong answer: %v", input, t.A) - } + + if !alive { + return fmt.Errorf("DNS server has not passed the healthcheck: %s", input) } return nil } -func sanitiseDNSServers(input string) ([]string, error) { - fields := strings.Fields(input) - hosts := make([]string, 0) - for _, field := range fields { - sanitized, err := sanitizeDNSServer(field) - if err != nil { - return hosts, err - } - hosts = append(hosts, sanitized) - } - return hosts, nil -} - -func getDNSServerPrefix(input string) string { - prefix := "" - switch { - case strings.HasPrefix(input, "dns://"): - prefix = "dns://" - case strings.HasPrefix(input, "tls://"): - prefix = "tls://" - } - return prefix -} - -func splitDNSServerPrefixServer(input string) (string, string) { - prefix := getDNSServerPrefix(input) - host := strings.TrimPrefix(input, prefix) - return prefix, host -} - -func sanitizeDNSServer(input string) (string, error) { - prefix, host := splitDNSServerPrefixServer(input) - host = appendPortIfMissing(prefix, host) - { - h, _, err := net.SplitHostPort(host) - if err != nil { - return "", err - } - ip := net.ParseIP(h) - if ip == nil { - return "", fmt.Errorf("invalid DNS server field: %s", h) - } - } - return prefix + host, nil -} - -func appendPortIfMissing(prefix, input string) string { - port := "53" - switch prefix { - case "tls://": - port = "853" - } - _, _, err := net.SplitHostPort(input) - if err == nil { - return input - } - return net.JoinHostPort(input, port) -} - //noinspection GoUnusedParameter func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { now := time.Now() diff --git a/openapi.yaml b/openapi.yaml index 35e32a90..9fc585f2 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -41,6 +41,7 @@ paths: protection_enabled: true querylog_enabled: true running: true + bootstrap_dns: 8.8.8.8:53 upstream_dns: - 1.1.1.1 - 1.0.0.1 diff --git a/upstream/helpers.go b/upstream/helpers.go index 832d58b4..1313b8e0 100644 --- a/upstream/helpers.go +++ b/upstream/helpers.go @@ -93,7 +93,7 @@ func IsAlive(u Upstream) (bool, error) { // If we got a header, we're alright, basically only care about I/O errors 'n stuff. if err != nil && resp != nil { // Silly check, something sane came back. - if resp.Response || resp.Opcode == dns.OpcodeQuery { + if resp.Rcode != dns.RcodeServerFailure { err = nil } } diff --git a/upstream/persistent.go b/upstream/persistent.go index 5c28a10e..91cc9094 100644 --- a/upstream/persistent.go +++ b/upstream/persistent.go @@ -10,6 +10,8 @@ import ( "github.com/miekg/dns" ) +// Persistent connections cache -- almost similar to the same used in the CoreDNS forward plugin + const ( defaultExpire = 10 * time.Second minDialTimeout = 100 * time.Millisecond