Merge branch 'i18n' into i18n_wip
This commit is contained in:
commit
8d209773b3
|
@ -106,8 +106,10 @@ Settings are stored in [YAML format](https://en.wikipedia.org/wiki/YAML), possib
|
||||||
* `parental_enabled` — Parental control-based DNS requests filtering
|
* `parental_enabled` — Parental control-based DNS requests filtering
|
||||||
* `parental_sensitivity` — Age group for parental control-based filtering, must be either 3, 10, 13 or 17
|
* `parental_sensitivity` — Age group for parental control-based filtering, must be either 3, 10, 13 or 17
|
||||||
* `querylog_enabled` — Query logging (also used to calculate top 50 clients, blocked domains and requested domains for statistic purposes)
|
* `querylog_enabled` — Query logging (also used to calculate top 50 clients, blocked domains and requested domains for statistic purposes)
|
||||||
|
* `bootstrap_dns` — DNS server used for initial hostnames resolution in case if upstream is DoH or DoT with a hostname
|
||||||
* `upstream_dns` — List of upstream DNS servers
|
* `upstream_dns` — List of upstream DNS servers
|
||||||
* `filters` — List of filters, each filter has the following values:
|
* `filters` — List of filters, each filter has the following values:
|
||||||
|
* `ID` - filter ID (must be unique)
|
||||||
* `url` — URL pointing to the filter contents (filtering rules)
|
* `url` — URL pointing to the filter contents (filtering rules)
|
||||||
* `enabled` — Current filter's status (enabled/disabled)
|
* `enabled` — Current filter's status (enabled/disabled)
|
||||||
* `user_rules` — User-specified filtering rules
|
* `user_rules` — User-specified filtering rules
|
||||||
|
|
|
@ -70,6 +70,7 @@ type coreDNSConfig struct {
|
||||||
Pprof string `yaml:"-"`
|
Pprof string `yaml:"-"`
|
||||||
Cache string `yaml:"-"`
|
Cache string `yaml:"-"`
|
||||||
Prometheus string `yaml:"-"`
|
Prometheus string `yaml:"-"`
|
||||||
|
BootstrapDNS string `yaml:"bootstrap_dns"`
|
||||||
UpstreamDNS []string `yaml:"upstream_dns"`
|
UpstreamDNS []string `yaml:"upstream_dns"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -100,6 +101,7 @@ var config = configuration{
|
||||||
SafeBrowsingEnabled: false,
|
SafeBrowsingEnabled: false,
|
||||||
BlockedResponseTTL: 10, // in seconds
|
BlockedResponseTTL: 10, // in seconds
|
||||||
QueryLogEnabled: true,
|
QueryLogEnabled: true,
|
||||||
|
BootstrapDNS: "8.8.8.8:53",
|
||||||
UpstreamDNS: defaultDNS,
|
UpstreamDNS: defaultDNS,
|
||||||
Cache: "cache",
|
Cache: "cache",
|
||||||
Prometheus: "prometheus :9153",
|
Prometheus: "prometheus :9153",
|
||||||
|
@ -253,7 +255,7 @@ const coreDNSConfigTemplate = `.:{{.Port}} {
|
||||||
hosts {
|
hosts {
|
||||||
fallthrough
|
fallthrough
|
||||||
}
|
}
|
||||||
{{if .UpstreamDNS}}forward . {{range .UpstreamDNS}}{{.}} {{end}}{{end}}
|
{{if .UpstreamDNS}}upstream {{range .UpstreamDNS}}{{.}} {{end}} { bootstrap {{.BootstrapDNS}} }{{end}}
|
||||||
{{.Cache}}
|
{{.Cache}}
|
||||||
{{.Prometheus}}
|
{{.Prometheus}}
|
||||||
}
|
}
|
||||||
|
|
147
control.go
147
control.go
|
@ -6,7 +6,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
@ -15,8 +14,9 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/upstream"
|
||||||
|
|
||||||
corednsplugin "github.com/AdguardTeam/AdGuardHome/coredns_plugin"
|
corednsplugin "github.com/AdguardTeam/AdGuardHome/coredns_plugin"
|
||||||
"github.com/miekg/dns"
|
|
||||||
"gopkg.in/asaskevich/govalidator.v4"
|
"gopkg.in/asaskevich/govalidator.v4"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -81,6 +81,7 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||||
"protection_enabled": config.CoreDNS.ProtectionEnabled,
|
"protection_enabled": config.CoreDNS.ProtectionEnabled,
|
||||||
"querylog_enabled": config.CoreDNS.QueryLogEnabled,
|
"querylog_enabled": config.CoreDNS.QueryLogEnabled,
|
||||||
"running": isRunning(),
|
"running": isRunning(),
|
||||||
|
"bootstrap_dns": config.CoreDNS.BootstrapDNS,
|
||||||
"upstream_dns": config.CoreDNS.UpstreamDNS,
|
"upstream_dns": config.CoreDNS.UpstreamDNS,
|
||||||
"version": VersionString,
|
"version": VersionString,
|
||||||
}
|
}
|
||||||
|
@ -134,17 +135,14 @@ func httpError(w http.ResponseWriter, code int, format string, args ...interface
|
||||||
func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
body, err := ioutil.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errortext := fmt.Sprintf("Failed to read request body: %s", err)
|
errorText := fmt.Sprintf("Failed to read request body: %s", err)
|
||||||
log.Println(errortext)
|
log.Println(errorText)
|
||||||
http.Error(w, errortext, http.StatusBadRequest)
|
http.Error(w, errorText, http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// if empty body -- user is asking for default servers
|
// if empty body -- user is asking for default servers
|
||||||
hosts, err := sanitiseDNSServers(string(body))
|
hosts := strings.Fields(string(body))
|
||||||
if err != nil {
|
|
||||||
httpError(w, http.StatusBadRequest, "Invalid DNS servers were given: %s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(hosts) == 0 {
|
if len(hosts) == 0 {
|
||||||
config.CoreDNS.UpstreamDNS = defaultDNS
|
config.CoreDNS.UpstreamDNS = defaultDNS
|
||||||
} else {
|
} else {
|
||||||
|
@ -153,34 +151,34 @@ func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
err = writeAllConfigs()
|
err = writeAllConfigs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errortext := fmt.Sprintf("Couldn't write config file: %s", err)
|
errorText := fmt.Sprintf("Couldn't write config file: %s", err)
|
||||||
log.Println(errortext)
|
log.Println(errorText)
|
||||||
http.Error(w, errortext, http.StatusInternalServerError)
|
http.Error(w, errorText, http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tellCoreDNSToReload()
|
tellCoreDNSToReload()
|
||||||
_, err = fmt.Fprintf(w, "OK %d servers\n", len(hosts))
|
_, err = fmt.Fprintf(w, "OK %d servers\n", len(hosts))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errortext := fmt.Sprintf("Couldn't write body: %s", err)
|
errorText := fmt.Sprintf("Couldn't write body: %s", err)
|
||||||
log.Println(errortext)
|
log.Println(errorText)
|
||||||
http.Error(w, errortext, http.StatusInternalServerError)
|
http.Error(w, errorText, http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
body, err := ioutil.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errortext := fmt.Sprintf("Failed to read request body: %s", err)
|
errorText := fmt.Sprintf("Failed to read request body: %s", err)
|
||||||
log.Println(errortext)
|
log.Println(errorText)
|
||||||
http.Error(w, errortext, 400)
|
http.Error(w, errorText, 400)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
hosts := strings.Fields(string(body))
|
hosts := strings.Fields(string(body))
|
||||||
|
|
||||||
if len(hosts) == 0 {
|
if len(hosts) == 0 {
|
||||||
errortext := fmt.Sprintf("No servers specified")
|
errorText := fmt.Sprintf("No servers specified")
|
||||||
log.Println(errortext)
|
log.Println(errorText)
|
||||||
http.Error(w, errortext, http.StatusBadRequest)
|
http.Error(w, errorText, http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -198,120 +196,43 @@ func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
jsonVal, err := json.Marshal(result)
|
jsonVal, err := json.Marshal(result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errortext := fmt.Sprintf("Unable to marshal status json: %s", err)
|
errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
|
||||||
log.Println(errortext)
|
log.Println(errorText)
|
||||||
http.Error(w, errortext, http.StatusInternalServerError)
|
http.Error(w, errorText, http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
_, err = w.Write(jsonVal)
|
_, err = w.Write(jsonVal)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errortext := fmt.Sprintf("Couldn't write body: %s", err)
|
errorText := fmt.Sprintf("Couldn't write body: %s", err)
|
||||||
log.Println(errortext)
|
log.Println(errorText)
|
||||||
http.Error(w, errortext, http.StatusInternalServerError)
|
http.Error(w, errorText, http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkDNS(input string) error {
|
func checkDNS(input string) error {
|
||||||
input, err := sanitizeDNSServer(input)
|
|
||||||
|
u, err := upstream.NewUpstream(input, config.CoreDNS.BootstrapDNS)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
defer u.Close()
|
||||||
|
|
||||||
req := dns.Msg{}
|
alive, err := upstream.IsAlive(u)
|
||||||
req.Id = dns.Id()
|
|
||||||
req.RecursionDesired = true
|
|
||||||
req.Question = []dns.Question{
|
|
||||||
{Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
|
||||||
}
|
|
||||||
|
|
||||||
prefix, host := splitDNSServerPrefixServer(input)
|
|
||||||
|
|
||||||
c := dns.Client{
|
|
||||||
Timeout: time.Minute,
|
|
||||||
}
|
|
||||||
switch prefix {
|
|
||||||
case "tls://":
|
|
||||||
c.Net = "tcp-tls"
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, rtt, err := c.Exchange(&req, host)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("couldn't communicate with DNS server %s: %s", input, err)
|
return fmt.Errorf("couldn't communicate with DNS server %s: %s", input, err)
|
||||||
}
|
}
|
||||||
trace("exchange with %s took %v", input, rtt)
|
|
||||||
if len(resp.Answer) != 1 {
|
if !alive {
|
||||||
return fmt.Errorf("DNS server %s returned wrong answer", input)
|
return fmt.Errorf("DNS server has not passed the healthcheck: %s", input)
|
||||||
}
|
|
||||||
if t, ok := resp.Answer[0].(*dns.A); ok {
|
|
||||||
if !net.IPv4(8, 8, 8, 8).Equal(t.A) {
|
|
||||||
return fmt.Errorf("DNS server %s returned wrong answer: %v", input, t.A)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func sanitiseDNSServers(input string) ([]string, error) {
|
|
||||||
fields := strings.Fields(input)
|
|
||||||
hosts := make([]string, 0)
|
|
||||||
for _, field := range fields {
|
|
||||||
sanitized, err := sanitizeDNSServer(field)
|
|
||||||
if err != nil {
|
|
||||||
return hosts, err
|
|
||||||
}
|
|
||||||
hosts = append(hosts, sanitized)
|
|
||||||
}
|
|
||||||
return hosts, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getDNSServerPrefix(input string) string {
|
|
||||||
prefix := ""
|
|
||||||
switch {
|
|
||||||
case strings.HasPrefix(input, "dns://"):
|
|
||||||
prefix = "dns://"
|
|
||||||
case strings.HasPrefix(input, "tls://"):
|
|
||||||
prefix = "tls://"
|
|
||||||
}
|
|
||||||
return prefix
|
|
||||||
}
|
|
||||||
|
|
||||||
func splitDNSServerPrefixServer(input string) (string, string) {
|
|
||||||
prefix := getDNSServerPrefix(input)
|
|
||||||
host := strings.TrimPrefix(input, prefix)
|
|
||||||
return prefix, host
|
|
||||||
}
|
|
||||||
|
|
||||||
func sanitizeDNSServer(input string) (string, error) {
|
|
||||||
prefix, host := splitDNSServerPrefixServer(input)
|
|
||||||
host = appendPortIfMissing(prefix, host)
|
|
||||||
{
|
|
||||||
h, _, err := net.SplitHostPort(host)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
ip := net.ParseIP(h)
|
|
||||||
if ip == nil {
|
|
||||||
return "", fmt.Errorf("invalid DNS server field: %s", h)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return prefix + host, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func appendPortIfMissing(prefix, input string) string {
|
|
||||||
port := "53"
|
|
||||||
switch prefix {
|
|
||||||
case "tls://":
|
|
||||||
port = "853"
|
|
||||||
}
|
|
||||||
_, _, err := net.SplitHostPort(input)
|
|
||||||
if err == nil {
|
|
||||||
return input
|
|
||||||
}
|
|
||||||
return net.JoinHostPort(input, port)
|
|
||||||
}
|
|
||||||
|
|
||||||
//noinspection GoUnusedParameter
|
//noinspection GoUnusedParameter
|
||||||
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
|
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"sync" // Include all plugins.
|
"sync" // Include all plugins.
|
||||||
|
|
||||||
_ "github.com/AdguardTeam/AdGuardHome/coredns_plugin"
|
_ "github.com/AdguardTeam/AdGuardHome/coredns_plugin"
|
||||||
|
_ "github.com/AdguardTeam/AdGuardHome/upstream"
|
||||||
"github.com/coredns/coredns/core/dnsserver"
|
"github.com/coredns/coredns/core/dnsserver"
|
||||||
"github.com/coredns/coredns/coremain"
|
"github.com/coredns/coredns/coremain"
|
||||||
_ "github.com/coredns/coredns/plugin/auto"
|
_ "github.com/coredns/coredns/plugin/auto"
|
||||||
|
@ -79,6 +80,7 @@ var directives = []string{
|
||||||
"loop",
|
"loop",
|
||||||
"forward",
|
"forward",
|
||||||
"proxy",
|
"proxy",
|
||||||
|
"upstream",
|
||||||
"erratic",
|
"erratic",
|
||||||
"whoami",
|
"whoami",
|
||||||
"on",
|
"on",
|
||||||
|
|
|
@ -41,6 +41,7 @@ paths:
|
||||||
protection_enabled: true
|
protection_enabled: true
|
||||||
querylog_enabled: true
|
querylog_enabled: true
|
||||||
running: true
|
running: true
|
||||||
|
bootstrap_dns: 8.8.8.8:53
|
||||||
upstream_dns:
|
upstream_dns:
|
||||||
- 1.1.1.1
|
- 1.1.1.1
|
||||||
- 1.0.0.1
|
- 1.0.0.1
|
||||||
|
|
|
@ -0,0 +1,109 @@
|
||||||
|
package upstream
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DnsUpstream is a very simple upstream implementation for plain DNS
|
||||||
|
type DnsUpstream struct {
|
||||||
|
endpoint string // IP:port
|
||||||
|
timeout time.Duration // Max read and write timeout
|
||||||
|
proto string // Protocol (tcp, tcp-tls, or udp)
|
||||||
|
transport *Transport // Persistent connections cache
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDnsUpstream creates a new DNS upstream
|
||||||
|
func NewDnsUpstream(endpoint string, proto string, tlsServerName string) (Upstream, error) {
|
||||||
|
|
||||||
|
u := &DnsUpstream{
|
||||||
|
endpoint: endpoint,
|
||||||
|
timeout: defaultTimeout,
|
||||||
|
proto: proto,
|
||||||
|
}
|
||||||
|
|
||||||
|
var tlsConfig *tls.Config
|
||||||
|
|
||||||
|
if proto == "tcp-tls" {
|
||||||
|
tlsConfig = new(tls.Config)
|
||||||
|
tlsConfig.ServerName = tlsServerName
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize the connections cache
|
||||||
|
u.transport = NewTransport(endpoint)
|
||||||
|
u.transport.tlsConfig = tlsConfig
|
||||||
|
u.transport.Start()
|
||||||
|
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exchange provides an implementation for the Upstream interface
|
||||||
|
func (u *DnsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
|
||||||
|
|
||||||
|
resp, err := u.exchange(u.proto, query)
|
||||||
|
|
||||||
|
// Retry over TCP if response is truncated
|
||||||
|
if err == dns.ErrTruncated && u.proto == "udp" {
|
||||||
|
resp, err = u.exchange("tcp", query)
|
||||||
|
} else if err == dns.ErrTruncated && resp != nil {
|
||||||
|
// Reassemble something to be sent to client
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetReply(query)
|
||||||
|
m.Truncated = true
|
||||||
|
m.Authoritative = true
|
||||||
|
m.Rcode = dns.RcodeSuccess
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
resp = &dns.Msg{}
|
||||||
|
resp.SetRcode(resp, dns.RcodeServerFailure)
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear resources
|
||||||
|
func (u *DnsUpstream) Close() error {
|
||||||
|
|
||||||
|
// Close active connections
|
||||||
|
u.transport.Stop()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Performs a synchronous query. It sends the message m via the conn
|
||||||
|
// c and waits for a reply. The conn c is not closed.
|
||||||
|
func (u *DnsUpstream) exchange(proto string, query *dns.Msg) (r *dns.Msg, err error) {
|
||||||
|
|
||||||
|
// Establish a connection if needed (or reuse cached)
|
||||||
|
conn, err := u.transport.Dial(proto)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write the request with a timeout
|
||||||
|
conn.SetWriteDeadline(time.Now().Add(u.timeout))
|
||||||
|
if err = conn.WriteMsg(query); err != nil {
|
||||||
|
conn.Close() // Not giving it back
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write response with a timeout
|
||||||
|
conn.SetReadDeadline(time.Now().Add(u.timeout))
|
||||||
|
r, err = conn.ReadMsg()
|
||||||
|
if err != nil {
|
||||||
|
conn.Close() // Not giving it back
|
||||||
|
} else if err == nil && r.Id != query.Id {
|
||||||
|
err = dns.ErrId
|
||||||
|
conn.Close() // Not giving it back
|
||||||
|
}
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
// Return it back to the connections cache if there were no errors
|
||||||
|
u.transport.Yield(conn)
|
||||||
|
}
|
||||||
|
return r, err
|
||||||
|
}
|
|
@ -0,0 +1,101 @@
|
||||||
|
package upstream
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Detects the upstream type from the specified url and creates a proper Upstream object
|
||||||
|
func NewUpstream(url string, bootstrap string) (Upstream, error) {
|
||||||
|
|
||||||
|
proto := "udp"
|
||||||
|
prefix := ""
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(url, "tcp://"):
|
||||||
|
proto = "tcp"
|
||||||
|
prefix = "tcp://"
|
||||||
|
case strings.HasPrefix(url, "tls://"):
|
||||||
|
proto = "tcp-tls"
|
||||||
|
prefix = "tls://"
|
||||||
|
case strings.HasPrefix(url, "https://"):
|
||||||
|
return NewHttpsUpstream(url, bootstrap)
|
||||||
|
}
|
||||||
|
|
||||||
|
hostname := strings.TrimPrefix(url, prefix)
|
||||||
|
|
||||||
|
host, port, err := net.SplitHostPort(hostname)
|
||||||
|
if err != nil {
|
||||||
|
// Set port depending on the protocol
|
||||||
|
switch proto {
|
||||||
|
case "udp":
|
||||||
|
port = "53"
|
||||||
|
case "tcp":
|
||||||
|
port = "53"
|
||||||
|
case "tcp-tls":
|
||||||
|
port = "853"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set host = hostname
|
||||||
|
host = hostname
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to resolve the host address (or check if it's an IP address)
|
||||||
|
bootstrapResolver := CreateResolver(bootstrap)
|
||||||
|
ips, err := bootstrapResolver.LookupIPAddr(context.Background(), host)
|
||||||
|
|
||||||
|
if err != nil || len(ips) == 0 {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := ips[0].String()
|
||||||
|
endpoint := net.JoinHostPort(addr, port)
|
||||||
|
tlsServerName := ""
|
||||||
|
|
||||||
|
if proto == "tcp-tls" && host != addr {
|
||||||
|
// Check if we need to specify TLS server name
|
||||||
|
tlsServerName = host
|
||||||
|
}
|
||||||
|
|
||||||
|
return NewDnsUpstream(endpoint, proto, tlsServerName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateResolver(bootstrap string) *net.Resolver {
|
||||||
|
|
||||||
|
bootstrapResolver := net.DefaultResolver
|
||||||
|
|
||||||
|
if bootstrap != "" {
|
||||||
|
bootstrapResolver = &net.Resolver{
|
||||||
|
PreferGo: true,
|
||||||
|
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
var d net.Dialer
|
||||||
|
return d.DialContext(ctx, network, bootstrap)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return bootstrapResolver
|
||||||
|
}
|
||||||
|
|
||||||
|
// Performs a simple health-check of the specified upstream
|
||||||
|
func IsAlive(u Upstream) (bool, error) {
|
||||||
|
|
||||||
|
// Using ipv4only.arpa. domain as it is a part of DNS64 RFC and it should exist everywhere
|
||||||
|
ping := new(dns.Msg)
|
||||||
|
ping.SetQuestion("ipv4only.arpa.", dns.TypeA)
|
||||||
|
|
||||||
|
resp, err := u.Exchange(context.Background(), ping)
|
||||||
|
|
||||||
|
// If we got a header, we're alright, basically only care about I/O errors 'n stuff.
|
||||||
|
if err != nil && resp != nil {
|
||||||
|
// Silly check, something sane came back.
|
||||||
|
if resp.Rcode != dns.RcodeServerFailure {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return err == nil, err
|
||||||
|
}
|
|
@ -0,0 +1,128 @@
|
||||||
|
package upstream
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
dnsMessageContentType = "application/dns-message"
|
||||||
|
defaultKeepAlive = 30 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// HttpsUpstream is the upstream implementation for DNS-over-HTTPS
|
||||||
|
type HttpsUpstream struct {
|
||||||
|
client *http.Client
|
||||||
|
endpoint *url.URL
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHttpsUpstream creates a new DNS-over-HTTPS upstream from the specified url
|
||||||
|
func NewHttpsUpstream(endpoint string, bootstrap string) (Upstream, error) {
|
||||||
|
u, err := url.Parse(endpoint)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize bootstrap resolver
|
||||||
|
bootstrapResolver := CreateResolver(bootstrap)
|
||||||
|
dialer := &net.Dialer{
|
||||||
|
Timeout: defaultTimeout,
|
||||||
|
KeepAlive: defaultKeepAlive,
|
||||||
|
DualStack: true,
|
||||||
|
Resolver: bootstrapResolver,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update TLS and HTTP client configuration
|
||||||
|
tlsConfig := &tls.Config{ServerName: u.Hostname()}
|
||||||
|
transport := &http.Transport{
|
||||||
|
TLSClientConfig: tlsConfig,
|
||||||
|
DisableCompression: true,
|
||||||
|
MaxIdleConns: 1,
|
||||||
|
DialContext: dialer.DialContext,
|
||||||
|
}
|
||||||
|
http2.ConfigureTransport(transport)
|
||||||
|
|
||||||
|
client := &http.Client{
|
||||||
|
Timeout: defaultTimeout,
|
||||||
|
Transport: transport,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &HttpsUpstream{client: client, endpoint: u}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exchange provides an implementation for the Upstream interface
|
||||||
|
func (u *HttpsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
|
||||||
|
queryBuf, err := query.Pack()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "failed to pack DNS query")
|
||||||
|
}
|
||||||
|
|
||||||
|
// No content negotiation for now, use DNS wire format
|
||||||
|
buf, backendErr := u.exchangeWireformat(queryBuf)
|
||||||
|
if backendErr == nil {
|
||||||
|
response := &dns.Msg{}
|
||||||
|
if err := response.Unpack(buf); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "failed to unpack DNS response from body")
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Id = query.Id
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("failed to connect to an HTTPS backend %q due to %s", u.endpoint, backendErr)
|
||||||
|
return nil, backendErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform message exchange with the default UDP wireformat defined in current draft
|
||||||
|
// https://tools.ietf.org/html/draft-ietf-doh-dns-over-https-10
|
||||||
|
func (u *HttpsUpstream) exchangeWireformat(msg []byte) ([]byte, error) {
|
||||||
|
req, err := http.NewRequest("POST", u.endpoint.String(), bytes.NewBuffer(msg))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "failed to create an HTTPS request")
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Add("Content-Type", dnsMessageContentType)
|
||||||
|
req.Header.Add("Accept", dnsMessageContentType)
|
||||||
|
req.Host = u.endpoint.Hostname()
|
||||||
|
|
||||||
|
resp, err := u.client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "failed to perform an HTTPS request")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check response status code
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("returned status code %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
contentType := resp.Header.Get("Content-Type")
|
||||||
|
if contentType != dnsMessageContentType {
|
||||||
|
return nil, fmt.Errorf("return wrong content type %s", contentType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read application/dns-message response from the body
|
||||||
|
buf, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "failed to read the response body")
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear resources
|
||||||
|
func (u *HttpsUpstream) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,210 @@
|
||||||
|
package upstream
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"sort"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Persistent connections cache -- almost similar to the same used in the CoreDNS forward plugin
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultExpire = 10 * time.Second
|
||||||
|
minDialTimeout = 100 * time.Millisecond
|
||||||
|
maxDialTimeout = 30 * time.Second
|
||||||
|
defaultDialTimeout = 30 * time.Second
|
||||||
|
cumulativeAvgWeight = 4
|
||||||
|
)
|
||||||
|
|
||||||
|
// a persistConn hold the dns.Conn and the last used time.
|
||||||
|
type persistConn struct {
|
||||||
|
c *dns.Conn
|
||||||
|
used time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transport hold the persistent cache.
|
||||||
|
type Transport struct {
|
||||||
|
avgDialTime int64 // kind of average time of dial time
|
||||||
|
conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls.
|
||||||
|
expire time.Duration // After this duration a connection is expired.
|
||||||
|
addr string
|
||||||
|
tlsConfig *tls.Config
|
||||||
|
|
||||||
|
dial chan string
|
||||||
|
yield chan *dns.Conn
|
||||||
|
ret chan *dns.Conn
|
||||||
|
stop chan bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial dials the address configured in transport, potentially reusing a connection or creating a new one.
|
||||||
|
func (t *Transport) Dial(proto string) (*dns.Conn, error) {
|
||||||
|
// If tls has been configured; use it.
|
||||||
|
if t.tlsConfig != nil {
|
||||||
|
proto = "tcp-tls"
|
||||||
|
}
|
||||||
|
|
||||||
|
t.dial <- proto
|
||||||
|
c := <-t.ret
|
||||||
|
|
||||||
|
if c != nil {
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
reqTime := time.Now()
|
||||||
|
timeout := t.dialTimeout()
|
||||||
|
if proto == "tcp-tls" {
|
||||||
|
conn, err := dns.DialTimeoutWithTLS(proto, t.addr, t.tlsConfig, timeout)
|
||||||
|
t.updateDialTimeout(time.Since(reqTime))
|
||||||
|
return conn, err
|
||||||
|
}
|
||||||
|
conn, err := dns.DialTimeout(proto, t.addr, timeout)
|
||||||
|
t.updateDialTimeout(time.Since(reqTime))
|
||||||
|
return conn, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Yield return the connection to transport for reuse.
|
||||||
|
func (t *Transport) Yield(c *dns.Conn) { t.yield <- c }
|
||||||
|
|
||||||
|
// Start starts the transport's connection manager.
|
||||||
|
func (t *Transport) Start() { go t.connManager() }
|
||||||
|
|
||||||
|
// Stop stops the transport's connection manager.
|
||||||
|
func (t *Transport) Stop() { close(t.stop) }
|
||||||
|
|
||||||
|
// SetExpire sets the connection expire time in transport.
|
||||||
|
func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire }
|
||||||
|
|
||||||
|
// SetTLSConfig sets the TLS config in transport.
|
||||||
|
func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg }
|
||||||
|
|
||||||
|
func NewTransport(addr string) *Transport {
|
||||||
|
t := &Transport{
|
||||||
|
avgDialTime: int64(defaultDialTimeout / 2),
|
||||||
|
conns: make(map[string][]*persistConn),
|
||||||
|
expire: defaultExpire,
|
||||||
|
addr: addr,
|
||||||
|
dial: make(chan string),
|
||||||
|
yield: make(chan *dns.Conn),
|
||||||
|
ret: make(chan *dns.Conn),
|
||||||
|
stop: make(chan bool),
|
||||||
|
}
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) {
|
||||||
|
dt := time.Duration(atomic.LoadInt64(currentAvg))
|
||||||
|
atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Transport) dialTimeout() time.Duration {
|
||||||
|
return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Transport) updateDialTimeout(newDialTime time.Duration) {
|
||||||
|
averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight)
|
||||||
|
}
|
||||||
|
|
||||||
|
// limitTimeout is a utility function to auto-tune timeout values
|
||||||
|
// average observed time is moved towards the last observed delay moderated by a weight
|
||||||
|
// next timeout to use will be the double of the computed average, limited by min and max frame.
|
||||||
|
func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration {
|
||||||
|
rt := time.Duration(atomic.LoadInt64(currentAvg))
|
||||||
|
if rt < minValue {
|
||||||
|
return minValue
|
||||||
|
}
|
||||||
|
if rt < maxValue/2 {
|
||||||
|
return 2 * rt
|
||||||
|
}
|
||||||
|
return maxValue
|
||||||
|
}
|
||||||
|
|
||||||
|
// connManagers manages the persistent connection cache for UDP and TCP.
|
||||||
|
func (t *Transport) connManager() {
|
||||||
|
ticker := time.NewTicker(t.expire)
|
||||||
|
Wait:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case proto := <-t.dial:
|
||||||
|
// take the last used conn - complexity O(1)
|
||||||
|
if stack := t.conns[proto]; len(stack) > 0 {
|
||||||
|
pc := stack[len(stack)-1]
|
||||||
|
if time.Since(pc.used) < t.expire {
|
||||||
|
// Found one, remove from pool and return this conn.
|
||||||
|
t.conns[proto] = stack[:len(stack)-1]
|
||||||
|
t.ret <- pc.c
|
||||||
|
continue Wait
|
||||||
|
}
|
||||||
|
// clear entire cache if the last conn is expired
|
||||||
|
t.conns[proto] = nil
|
||||||
|
// now, the connections being passed to closeConns() are not reachable from
|
||||||
|
// transport methods anymore. So, it's safe to close them in a separate goroutine
|
||||||
|
go closeConns(stack)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.ret <- nil
|
||||||
|
|
||||||
|
case conn := <-t.yield:
|
||||||
|
|
||||||
|
// no proto here, infer from config and conn
|
||||||
|
if _, ok := conn.Conn.(*net.UDPConn); ok {
|
||||||
|
t.conns["udp"] = append(t.conns["udp"], &persistConn{conn, time.Now()})
|
||||||
|
continue Wait
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.tlsConfig == nil {
|
||||||
|
t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn, time.Now()})
|
||||||
|
continue Wait
|
||||||
|
}
|
||||||
|
|
||||||
|
t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn, time.Now()})
|
||||||
|
|
||||||
|
case <-ticker.C:
|
||||||
|
t.cleanup(false)
|
||||||
|
|
||||||
|
case <-t.stop:
|
||||||
|
t.cleanup(true)
|
||||||
|
close(t.ret)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// closeConns closes connections.
|
||||||
|
func closeConns(conns []*persistConn) {
|
||||||
|
for _, pc := range conns {
|
||||||
|
pc.c.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanup removes connections from cache.
|
||||||
|
func (t *Transport) cleanup(all bool) {
|
||||||
|
staleTime := time.Now().Add(-t.expire)
|
||||||
|
for proto, stack := range t.conns {
|
||||||
|
if len(stack) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if all {
|
||||||
|
t.conns[proto] = nil
|
||||||
|
// now, the connections being passed to closeConns() are not reachable from
|
||||||
|
// transport methods anymore. So, it's safe to close them in a separate goroutine
|
||||||
|
go closeConns(stack)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if stack[0].used.After(staleTime) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// connections in stack are sorted by "used"
|
||||||
|
good := sort.Search(len(stack), func(i int) bool {
|
||||||
|
return stack[i].used.After(staleTime)
|
||||||
|
})
|
||||||
|
t.conns[proto] = stack[good:]
|
||||||
|
// now, the connections being passed to closeConns() are not reachable from
|
||||||
|
// transport methods anymore. So, it's safe to close them in a separate goroutine
|
||||||
|
go closeConns(stack[:good])
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,84 @@
|
||||||
|
package upstream
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/coredns/coredns/core/dnsserver"
|
||||||
|
"github.com/coredns/coredns/plugin"
|
||||||
|
"github.com/mholt/caddy"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
caddy.RegisterPlugin("upstream", caddy.Plugin{
|
||||||
|
ServerType: "dns",
|
||||||
|
Action: setup,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the configuration and initialize upstreams
|
||||||
|
func setup(c *caddy.Controller) error {
|
||||||
|
|
||||||
|
p, err := setupPlugin(c)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
config := dnsserver.GetConfig(c)
|
||||||
|
config.AddPlugin(func(next plugin.Handler) plugin.Handler {
|
||||||
|
p.Next = next
|
||||||
|
return p
|
||||||
|
})
|
||||||
|
|
||||||
|
c.OnShutdown(p.onShutdown)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the configuration
|
||||||
|
func setupPlugin(c *caddy.Controller) (*UpstreamPlugin, error) {
|
||||||
|
|
||||||
|
p := New()
|
||||||
|
|
||||||
|
log.Println("Initializing the Upstream plugin")
|
||||||
|
|
||||||
|
bootstrap := ""
|
||||||
|
upstreamUrls := []string{}
|
||||||
|
for c.Next() {
|
||||||
|
args := c.RemainingArgs()
|
||||||
|
if len(args) > 0 {
|
||||||
|
upstreamUrls = append(upstreamUrls, args...)
|
||||||
|
}
|
||||||
|
for c.NextBlock() {
|
||||||
|
switch c.Val() {
|
||||||
|
case "bootstrap":
|
||||||
|
if !c.NextArg() {
|
||||||
|
return nil, c.ArgErr()
|
||||||
|
}
|
||||||
|
bootstrap = c.Val()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, url := range upstreamUrls {
|
||||||
|
u, err := NewUpstream(url, bootstrap)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Cannot initialize upstream %s", url)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
p.Upstreams = append(p.Upstreams, u)
|
||||||
|
}
|
||||||
|
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *UpstreamPlugin) onShutdown() error {
|
||||||
|
for i := range p.Upstreams {
|
||||||
|
|
||||||
|
u := p.Upstreams[i]
|
||||||
|
err := u.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error while closing the upstream: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,30 @@
|
||||||
|
package upstream
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/mholt/caddy"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSetup(t *testing.T) {
|
||||||
|
|
||||||
|
var tests = []struct {
|
||||||
|
config string
|
||||||
|
}{
|
||||||
|
{`upstream 8.8.8.8`},
|
||||||
|
{`upstream 8.8.8.8 {
|
||||||
|
bootstrap 8.8.8.8:53
|
||||||
|
}`},
|
||||||
|
{`upstream tls://1.1.1.1 8.8.8.8 {
|
||||||
|
bootstrap 1.1.1.1
|
||||||
|
}`},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
c := caddy.NewTestController("dns", test.config)
|
||||||
|
err := setup(c)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Test failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,57 @@
|
||||||
|
package upstream
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/coredns/coredns/plugin"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultTimeout = 5 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// Upstream is a simplified interface for proxy destination
|
||||||
|
type Upstream interface {
|
||||||
|
Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error)
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamPlugin is a simplified DNS proxy using a generic upstream interface
|
||||||
|
type UpstreamPlugin struct {
|
||||||
|
Upstreams []Upstream
|
||||||
|
Next plugin.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize the upstream plugin
|
||||||
|
func New() *UpstreamPlugin {
|
||||||
|
p := &UpstreamPlugin{
|
||||||
|
Upstreams: []Upstream{},
|
||||||
|
}
|
||||||
|
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeDNS implements interface for CoreDNS plugin
|
||||||
|
func (p *UpstreamPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
|
var reply *dns.Msg
|
||||||
|
var backendErr error
|
||||||
|
|
||||||
|
for i := range p.Upstreams {
|
||||||
|
upstream := p.Upstreams[i]
|
||||||
|
reply, backendErr = upstream.Exchange(ctx, r)
|
||||||
|
if backendErr == nil {
|
||||||
|
w.WriteMsg(reply)
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return dns.RcodeServerFailure, errors.Wrap(backendErr, "failed to contact any of the upstreams")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name implements interface for CoreDNS plugin
|
||||||
|
func (p *UpstreamPlugin) Name() string {
|
||||||
|
return "upstream"
|
||||||
|
}
|
|
@ -0,0 +1,194 @@
|
||||||
|
package upstream
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDnsUpstreamIsAlive(t *testing.T) {
|
||||||
|
|
||||||
|
var tests = []struct {
|
||||||
|
url string
|
||||||
|
bootstrap string
|
||||||
|
}{
|
||||||
|
{"8.8.8.8:53", "8.8.8.8:53"},
|
||||||
|
{"1.1.1.1", ""},
|
||||||
|
{"tcp://1.1.1.1:53", ""},
|
||||||
|
{"176.103.130.130:5353", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
u, err := NewUpstream(test.url, test.bootstrap)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("cannot create a DNS upstream")
|
||||||
|
}
|
||||||
|
|
||||||
|
testUpstreamIsAlive(t, u)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHttpsUpstreamIsAlive(t *testing.T) {
|
||||||
|
|
||||||
|
var tests = []struct {
|
||||||
|
url string
|
||||||
|
bootstrap string
|
||||||
|
}{
|
||||||
|
{"https://cloudflare-dns.com/dns-query", "8.8.8.8:53"},
|
||||||
|
{"https://dns.google.com/experimental", "8.8.8.8:53"},
|
||||||
|
{"https://doh.cleanbrowsing.org/doh/security-filter/", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
u, err := NewUpstream(test.url, test.bootstrap)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("cannot create a DNS-over-HTTPS upstream")
|
||||||
|
}
|
||||||
|
|
||||||
|
testUpstreamIsAlive(t, u)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDnsOverTlsIsAlive(t *testing.T) {
|
||||||
|
|
||||||
|
var tests = []struct {
|
||||||
|
url string
|
||||||
|
bootstrap string
|
||||||
|
}{
|
||||||
|
{"tls://1.1.1.1", ""},
|
||||||
|
{"tls://9.9.9.9:853", ""},
|
||||||
|
{"tls://security-filter-dns.cleanbrowsing.org", "8.8.8.8:53"},
|
||||||
|
{"tls://adult-filter-dns.cleanbrowsing.org:853", "8.8.8.8:53"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
u, err := NewUpstream(test.url, test.bootstrap)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("cannot create a DNS-over-TLS upstream")
|
||||||
|
}
|
||||||
|
|
||||||
|
testUpstreamIsAlive(t, u)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDnsUpstream(t *testing.T) {
|
||||||
|
|
||||||
|
var tests = []struct {
|
||||||
|
url string
|
||||||
|
bootstrap string
|
||||||
|
}{
|
||||||
|
{"8.8.8.8:53", "8.8.8.8:53"},
|
||||||
|
{"1.1.1.1", ""},
|
||||||
|
{"tcp://1.1.1.1:53", ""},
|
||||||
|
{"176.103.130.130:5353", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
u, err := NewUpstream(test.url, test.bootstrap)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("cannot create a DNS upstream")
|
||||||
|
}
|
||||||
|
|
||||||
|
testUpstream(t, u)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHttpsUpstream(t *testing.T) {
|
||||||
|
|
||||||
|
var tests = []struct {
|
||||||
|
url string
|
||||||
|
bootstrap string
|
||||||
|
}{
|
||||||
|
{"https://cloudflare-dns.com/dns-query", "8.8.8.8:53"},
|
||||||
|
{"https://dns.google.com/experimental", "8.8.8.8:53"},
|
||||||
|
{"https://doh.cleanbrowsing.org/doh/security-filter/", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
u, err := NewUpstream(test.url, test.bootstrap)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("cannot create a DNS-over-HTTPS upstream")
|
||||||
|
}
|
||||||
|
|
||||||
|
testUpstream(t, u)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDnsOverTlsUpstream(t *testing.T) {
|
||||||
|
|
||||||
|
var tests = []struct {
|
||||||
|
url string
|
||||||
|
bootstrap string
|
||||||
|
}{
|
||||||
|
{"tls://1.1.1.1", ""},
|
||||||
|
{"tls://9.9.9.9:853", ""},
|
||||||
|
{"tls://security-filter-dns.cleanbrowsing.org", "8.8.8.8:53"},
|
||||||
|
{"tls://adult-filter-dns.cleanbrowsing.org:853", "8.8.8.8:53"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
u, err := NewUpstream(test.url, test.bootstrap)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("cannot create a DNS-over-TLS upstream")
|
||||||
|
}
|
||||||
|
|
||||||
|
testUpstream(t, u)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testUpstreamIsAlive(t *testing.T, u Upstream) {
|
||||||
|
alive, err := IsAlive(u)
|
||||||
|
if !alive || err != nil {
|
||||||
|
t.Errorf("Upstream is not alive")
|
||||||
|
}
|
||||||
|
|
||||||
|
u.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func testUpstream(t *testing.T, u Upstream) {
|
||||||
|
|
||||||
|
var tests = []struct {
|
||||||
|
name string
|
||||||
|
expected net.IP
|
||||||
|
}{
|
||||||
|
{"google-public-dns-a.google.com.", net.IPv4(8, 8, 8, 8)},
|
||||||
|
{"google-public-dns-b.google.com.", net.IPv4(8, 8, 4, 4)},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
req := dns.Msg{}
|
||||||
|
req.Id = dns.Id()
|
||||||
|
req.RecursionDesired = true
|
||||||
|
req.Question = []dns.Question{
|
||||||
|
{Name: test.name, Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := u.Exchange(context.Background(), &req)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error while making an upstream request: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(resp.Answer) != 1 {
|
||||||
|
t.Errorf("no answer section in the response")
|
||||||
|
}
|
||||||
|
if answer, ok := resp.Answer[0].(*dns.A); ok {
|
||||||
|
if !test.expected.Equal(answer.A) {
|
||||||
|
t.Errorf("wrong IP in the response: %v", answer.A)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err := u.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error while closing the upstream: %s", err)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue