Added bootstrap DNS to the config file

DNS healthcheck now uses the upstream package methods
This commit is contained in:
Andrey Meshkov 2018-11-06 00:47:59 +03:00
parent 7f018234f6
commit 451922b858
5 changed files with 19 additions and 94 deletions

View File

@ -70,6 +70,7 @@ type coreDNSConfig struct {
Pprof string `yaml:"-"` Pprof string `yaml:"-"`
Cache string `yaml:"-"` Cache string `yaml:"-"`
Prometheus string `yaml:"-"` Prometheus string `yaml:"-"`
BootstrapDNS string `yaml:"bootstrap_dns"`
UpstreamDNS []string `yaml:"upstream_dns"` UpstreamDNS []string `yaml:"upstream_dns"`
} }
@ -100,6 +101,7 @@ var config = configuration{
SafeBrowsingEnabled: false, SafeBrowsingEnabled: false,
BlockedResponseTTL: 10, // in seconds BlockedResponseTTL: 10, // in seconds
QueryLogEnabled: true, QueryLogEnabled: true,
BootstrapDNS: "8.8.8.8:53",
UpstreamDNS: defaultDNS, UpstreamDNS: defaultDNS,
Cache: "cache", Cache: "cache",
Prometheus: "prometheus :9153", Prometheus: "prometheus :9153",
@ -253,7 +255,7 @@ const coreDNSConfigTemplate = `.:{{.Port}} {
hosts { hosts {
fallthrough fallthrough
} }
{{if .UpstreamDNS}}upstream {{range .UpstreamDNS}}{{.}} {{end}} { bootstrap 8.8.8.8:53 }{{end}} {{if .UpstreamDNS}}upstream {{range .UpstreamDNS}}{{.}} {{end}} { bootstrap {{.BootstrapDNS}} }{{end}}
{{.Cache}} {{.Cache}}
{{.Prometheus}} {{.Prometheus}}
} }

View File

@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log" "log"
"net"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
@ -15,8 +14,9 @@ import (
"strings" "strings"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/upstream"
corednsplugin "github.com/AdguardTeam/AdGuardHome/coredns_plugin" corednsplugin "github.com/AdguardTeam/AdGuardHome/coredns_plugin"
"github.com/miekg/dns"
"gopkg.in/asaskevich/govalidator.v4" "gopkg.in/asaskevich/govalidator.v4"
) )
@ -81,6 +81,7 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
"protection_enabled": config.CoreDNS.ProtectionEnabled, "protection_enabled": config.CoreDNS.ProtectionEnabled,
"querylog_enabled": config.CoreDNS.QueryLogEnabled, "querylog_enabled": config.CoreDNS.QueryLogEnabled,
"running": isRunning(), "running": isRunning(),
"bootstrap_dns": config.CoreDNS.BootstrapDNS,
"upstream_dns": config.CoreDNS.UpstreamDNS, "upstream_dns": config.CoreDNS.UpstreamDNS,
"version": VersionString, "version": VersionString,
} }
@ -140,11 +141,8 @@ func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
return return
} }
// if empty body -- user is asking for default servers // if empty body -- user is asking for default servers
hosts, err := sanitiseDNSServers(string(body)) hosts := strings.Fields(string(body))
if err != nil {
httpError(w, http.StatusBadRequest, "Invalid DNS servers were given: %s", err)
return
}
if len(hosts) == 0 { if len(hosts) == 0 {
config.CoreDNS.UpstreamDNS = defaultDNS config.CoreDNS.UpstreamDNS = defaultDNS
} else { } else {
@ -214,104 +212,26 @@ func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
} }
func checkDNS(input string) error { func checkDNS(input string) error {
input, err := sanitizeDNSServer(input)
u, err := upstream.NewUpstream(input, config.CoreDNS.BootstrapDNS)
if err != nil { if err != nil {
return err return err
} }
req := dns.Msg{} alive, err := upstream.IsAlive(u)
req.Id = dns.Id()
req.RecursionDesired = true
req.Question = []dns.Question{
{Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
}
prefix, host := splitDNSServerPrefixServer(input)
c := dns.Client{
Timeout: time.Minute,
}
switch prefix {
case "tls://":
c.Net = "tcp-tls"
}
resp, rtt, err := c.Exchange(&req, host)
if err != nil { if err != nil {
return fmt.Errorf("couldn't communicate with DNS server %s: %s", input, err) return fmt.Errorf("couldn't communicate with DNS server %s: %s", input, err)
} }
trace("exchange with %s took %v", input, rtt)
if len(resp.Answer) != 1 { if !alive {
return fmt.Errorf("DNS server %s returned wrong answer", input) return fmt.Errorf("DNS server has not passed the healthcheck: %s", input)
}
if t, ok := resp.Answer[0].(*dns.A); ok {
if !net.IPv4(8, 8, 8, 8).Equal(t.A) {
return fmt.Errorf("DNS server %s returned wrong answer: %v", input, t.A)
}
} }
return nil return nil
} }
func sanitiseDNSServers(input string) ([]string, error) {
fields := strings.Fields(input)
hosts := make([]string, 0)
for _, field := range fields {
sanitized, err := sanitizeDNSServer(field)
if err != nil {
return hosts, err
}
hosts = append(hosts, sanitized)
}
return hosts, nil
}
func getDNSServerPrefix(input string) string {
prefix := ""
switch {
case strings.HasPrefix(input, "dns://"):
prefix = "dns://"
case strings.HasPrefix(input, "tls://"):
prefix = "tls://"
}
return prefix
}
func splitDNSServerPrefixServer(input string) (string, string) {
prefix := getDNSServerPrefix(input)
host := strings.TrimPrefix(input, prefix)
return prefix, host
}
func sanitizeDNSServer(input string) (string, error) {
prefix, host := splitDNSServerPrefixServer(input)
host = appendPortIfMissing(prefix, host)
{
h, _, err := net.SplitHostPort(host)
if err != nil {
return "", err
}
ip := net.ParseIP(h)
if ip == nil {
return "", fmt.Errorf("invalid DNS server field: %s", h)
}
}
return prefix + host, nil
}
func appendPortIfMissing(prefix, input string) string {
port := "53"
switch prefix {
case "tls://":
port = "853"
}
_, _, err := net.SplitHostPort(input)
if err == nil {
return input
}
return net.JoinHostPort(input, port)
}
//noinspection GoUnusedParameter //noinspection GoUnusedParameter
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
now := time.Now() now := time.Now()

View File

@ -41,6 +41,7 @@ paths:
protection_enabled: true protection_enabled: true
querylog_enabled: true querylog_enabled: true
running: true running: true
bootstrap_dns: 8.8.8.8:53
upstream_dns: upstream_dns:
- 1.1.1.1 - 1.1.1.1
- 1.0.0.1 - 1.0.0.1

View File

@ -93,7 +93,7 @@ func IsAlive(u Upstream) (bool, error) {
// If we got a header, we're alright, basically only care about I/O errors 'n stuff. // If we got a header, we're alright, basically only care about I/O errors 'n stuff.
if err != nil && resp != nil { if err != nil && resp != nil {
// Silly check, something sane came back. // Silly check, something sane came back.
if resp.Response || resp.Opcode == dns.OpcodeQuery { if resp.Rcode != dns.RcodeServerFailure {
err = nil err = nil
} }
} }

View File

@ -10,6 +10,8 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
) )
// Persistent connections cache -- almost similar to the same used in the CoreDNS forward plugin
const ( const (
defaultExpire = 10 * time.Second defaultExpire = 10 * time.Second
minDialTimeout = 100 * time.Millisecond minDialTimeout = 100 * time.Millisecond