diff --git a/README.md b/README.md index 8464a7dd..b7b43446 100644 --- a/README.md +++ b/README.md @@ -106,8 +106,10 @@ Settings are stored in [YAML format](https://en.wikipedia.org/wiki/YAML), possib * `parental_enabled` — Parental control-based DNS requests filtering * `parental_sensitivity` — Age group for parental control-based filtering, must be either 3, 10, 13 or 17 * `querylog_enabled` — Query logging (also used to calculate top 50 clients, blocked domains and requested domains for statistic purposes) + * `bootstrap_dns` — DNS server used for initial hostnames resolution in case if upstream is DoH or DoT with a hostname * `upstream_dns` — List of upstream DNS servers * `filters` — List of filters, each filter has the following values: + * `ID` - filter ID (must be unique) * `url` — URL pointing to the filter contents (filtering rules) * `enabled` — Current filter's status (enabled/disabled) * `user_rules` — User-specified filtering rules diff --git a/config.go b/config.go index bf63c0ee..a8534bc5 100644 --- a/config.go +++ b/config.go @@ -70,6 +70,7 @@ type coreDNSConfig struct { Pprof string `yaml:"-"` Cache string `yaml:"-"` Prometheus string `yaml:"-"` + BootstrapDNS string `yaml:"bootstrap_dns"` UpstreamDNS []string `yaml:"upstream_dns"` } @@ -100,6 +101,7 @@ var config = configuration{ SafeBrowsingEnabled: false, BlockedResponseTTL: 10, // in seconds QueryLogEnabled: true, + BootstrapDNS: "8.8.8.8:53", UpstreamDNS: defaultDNS, Cache: "cache", Prometheus: "prometheus :9153", @@ -253,7 +255,7 @@ const coreDNSConfigTemplate = `.:{{.Port}} { hosts { fallthrough } - {{if .UpstreamDNS}}forward . {{range .UpstreamDNS}}{{.}} {{end}}{{end}} + {{if .UpstreamDNS}}upstream {{range .UpstreamDNS}}{{.}} {{end}} { bootstrap {{.BootstrapDNS}} }{{end}} {{.Cache}} {{.Prometheus}} } diff --git a/control.go b/control.go index 1e1084e8..63a387b0 100644 --- a/control.go +++ b/control.go @@ -6,7 +6,6 @@ import ( "fmt" "io/ioutil" "log" - "net" "net/http" "os" "path/filepath" @@ -15,8 +14,9 @@ import ( "strings" "time" + "github.com/AdguardTeam/AdGuardHome/upstream" + corednsplugin "github.com/AdguardTeam/AdGuardHome/coredns_plugin" - "github.com/miekg/dns" "gopkg.in/asaskevich/govalidator.v4" ) @@ -81,6 +81,7 @@ func handleStatus(w http.ResponseWriter, r *http.Request) { "protection_enabled": config.CoreDNS.ProtectionEnabled, "querylog_enabled": config.CoreDNS.QueryLogEnabled, "running": isRunning(), + "bootstrap_dns": config.CoreDNS.BootstrapDNS, "upstream_dns": config.CoreDNS.UpstreamDNS, "version": VersionString, } @@ -134,17 +135,14 @@ func httpError(w http.ResponseWriter, code int, format string, args ...interface func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) { body, err := ioutil.ReadAll(r.Body) if err != nil { - errortext := fmt.Sprintf("Failed to read request body: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusBadRequest) + errorText := fmt.Sprintf("Failed to read request body: %s", err) + log.Println(errorText) + http.Error(w, errorText, http.StatusBadRequest) return } // if empty body -- user is asking for default servers - hosts, err := sanitiseDNSServers(string(body)) - if err != nil { - httpError(w, http.StatusBadRequest, "Invalid DNS servers were given: %s", err) - return - } + hosts := strings.Fields(string(body)) + if len(hosts) == 0 { config.CoreDNS.UpstreamDNS = defaultDNS } else { @@ -153,34 +151,34 @@ func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) { err = writeAllConfigs() if err != nil { - errortext := fmt.Sprintf("Couldn't write config file: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) + errorText := fmt.Sprintf("Couldn't write config file: %s", err) + log.Println(errorText) + http.Error(w, errorText, http.StatusInternalServerError) return } tellCoreDNSToReload() _, err = fmt.Fprintf(w, "OK %d servers\n", len(hosts)) if err != nil { - errortext := fmt.Sprintf("Couldn't write body: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) + errorText := fmt.Sprintf("Couldn't write body: %s", err) + log.Println(errorText) + http.Error(w, errorText, http.StatusInternalServerError) } } func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { body, err := ioutil.ReadAll(r.Body) if err != nil { - errortext := fmt.Sprintf("Failed to read request body: %s", err) - log.Println(errortext) - http.Error(w, errortext, 400) + errorText := fmt.Sprintf("Failed to read request body: %s", err) + log.Println(errorText) + http.Error(w, errorText, 400) return } hosts := strings.Fields(string(body)) if len(hosts) == 0 { - errortext := fmt.Sprintf("No servers specified") - log.Println(errortext) - http.Error(w, errortext, http.StatusBadRequest) + errorText := fmt.Sprintf("No servers specified") + log.Println(errorText) + http.Error(w, errorText, http.StatusBadRequest) return } @@ -198,120 +196,43 @@ func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { jsonVal, err := json.Marshal(result) if err != nil { - errortext := fmt.Sprintf("Unable to marshal status json: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) + errorText := fmt.Sprintf("Unable to marshal status json: %s", err) + log.Println(errorText) + http.Error(w, errorText, http.StatusInternalServerError) return } w.Header().Set("Content-Type", "application/json") _, err = w.Write(jsonVal) if err != nil { - errortext := fmt.Sprintf("Couldn't write body: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) + errorText := fmt.Sprintf("Couldn't write body: %s", err) + log.Println(errorText) + http.Error(w, errorText, http.StatusInternalServerError) } } func checkDNS(input string) error { - input, err := sanitizeDNSServer(input) + + u, err := upstream.NewUpstream(input, config.CoreDNS.BootstrapDNS) + if err != nil { return err } + defer u.Close() - req := dns.Msg{} - req.Id = dns.Id() - req.RecursionDesired = true - req.Question = []dns.Question{ - {Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, - } + alive, err := upstream.IsAlive(u) - prefix, host := splitDNSServerPrefixServer(input) - - c := dns.Client{ - Timeout: time.Minute, - } - switch prefix { - case "tls://": - c.Net = "tcp-tls" - } - - resp, rtt, err := c.Exchange(&req, host) if err != nil { return fmt.Errorf("couldn't communicate with DNS server %s: %s", input, err) } - trace("exchange with %s took %v", input, rtt) - if len(resp.Answer) != 1 { - return fmt.Errorf("DNS server %s returned wrong answer", input) - } - if t, ok := resp.Answer[0].(*dns.A); ok { - if !net.IPv4(8, 8, 8, 8).Equal(t.A) { - return fmt.Errorf("DNS server %s returned wrong answer: %v", input, t.A) - } + + if !alive { + return fmt.Errorf("DNS server has not passed the healthcheck: %s", input) } return nil } -func sanitiseDNSServers(input string) ([]string, error) { - fields := strings.Fields(input) - hosts := make([]string, 0) - for _, field := range fields { - sanitized, err := sanitizeDNSServer(field) - if err != nil { - return hosts, err - } - hosts = append(hosts, sanitized) - } - return hosts, nil -} - -func getDNSServerPrefix(input string) string { - prefix := "" - switch { - case strings.HasPrefix(input, "dns://"): - prefix = "dns://" - case strings.HasPrefix(input, "tls://"): - prefix = "tls://" - } - return prefix -} - -func splitDNSServerPrefixServer(input string) (string, string) { - prefix := getDNSServerPrefix(input) - host := strings.TrimPrefix(input, prefix) - return prefix, host -} - -func sanitizeDNSServer(input string) (string, error) { - prefix, host := splitDNSServerPrefixServer(input) - host = appendPortIfMissing(prefix, host) - { - h, _, err := net.SplitHostPort(host) - if err != nil { - return "", err - } - ip := net.ParseIP(h) - if ip == nil { - return "", fmt.Errorf("invalid DNS server field: %s", h) - } - } - return prefix + host, nil -} - -func appendPortIfMissing(prefix, input string) string { - port := "53" - switch prefix { - case "tls://": - port = "853" - } - _, _, err := net.SplitHostPort(input) - if err == nil { - return input - } - return net.JoinHostPort(input, port) -} - //noinspection GoUnusedParameter func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { now := time.Now() diff --git a/coredns.go b/coredns.go index 5dbe01b4..a21fb986 100644 --- a/coredns.go +++ b/coredns.go @@ -8,6 +8,7 @@ import ( "sync" // Include all plugins. _ "github.com/AdguardTeam/AdGuardHome/coredns_plugin" + _ "github.com/AdguardTeam/AdGuardHome/upstream" "github.com/coredns/coredns/core/dnsserver" "github.com/coredns/coredns/coremain" _ "github.com/coredns/coredns/plugin/auto" @@ -79,6 +80,7 @@ var directives = []string{ "loop", "forward", "proxy", + "upstream", "erratic", "whoami", "on", diff --git a/openapi.yaml b/openapi.yaml index 35e32a90..9fc585f2 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -41,6 +41,7 @@ paths: protection_enabled: true querylog_enabled: true running: true + bootstrap_dns: 8.8.8.8:53 upstream_dns: - 1.1.1.1 - 1.0.0.1 diff --git a/upstream/dns_upstream.go b/upstream/dns_upstream.go new file mode 100644 index 00000000..2902ca2e --- /dev/null +++ b/upstream/dns_upstream.go @@ -0,0 +1,109 @@ +package upstream + +import ( + "crypto/tls" + "time" + + "github.com/miekg/dns" + "golang.org/x/net/context" +) + +// DnsUpstream is a very simple upstream implementation for plain DNS +type DnsUpstream struct { + endpoint string // IP:port + timeout time.Duration // Max read and write timeout + proto string // Protocol (tcp, tcp-tls, or udp) + transport *Transport // Persistent connections cache +} + +// NewDnsUpstream creates a new DNS upstream +func NewDnsUpstream(endpoint string, proto string, tlsServerName string) (Upstream, error) { + + u := &DnsUpstream{ + endpoint: endpoint, + timeout: defaultTimeout, + proto: proto, + } + + var tlsConfig *tls.Config + + if proto == "tcp-tls" { + tlsConfig = new(tls.Config) + tlsConfig.ServerName = tlsServerName + } + + // Initialize the connections cache + u.transport = NewTransport(endpoint) + u.transport.tlsConfig = tlsConfig + u.transport.Start() + + return u, nil +} + +// Exchange provides an implementation for the Upstream interface +func (u *DnsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) { + + resp, err := u.exchange(u.proto, query) + + // Retry over TCP if response is truncated + if err == dns.ErrTruncated && u.proto == "udp" { + resp, err = u.exchange("tcp", query) + } else if err == dns.ErrTruncated && resp != nil { + // Reassemble something to be sent to client + m := new(dns.Msg) + m.SetReply(query) + m.Truncated = true + m.Authoritative = true + m.Rcode = dns.RcodeSuccess + return m, nil + } + + if err != nil { + resp = &dns.Msg{} + resp.SetRcode(resp, dns.RcodeServerFailure) + } + + return resp, err +} + +// Clear resources +func (u *DnsUpstream) Close() error { + + // Close active connections + u.transport.Stop() + return nil +} + +// Performs a synchronous query. It sends the message m via the conn +// c and waits for a reply. The conn c is not closed. +func (u *DnsUpstream) exchange(proto string, query *dns.Msg) (r *dns.Msg, err error) { + + // Establish a connection if needed (or reuse cached) + conn, err := u.transport.Dial(proto) + if err != nil { + return nil, err + } + + // Write the request with a timeout + conn.SetWriteDeadline(time.Now().Add(u.timeout)) + if err = conn.WriteMsg(query); err != nil { + conn.Close() // Not giving it back + return nil, err + } + + // Write response with a timeout + conn.SetReadDeadline(time.Now().Add(u.timeout)) + r, err = conn.ReadMsg() + if err != nil { + conn.Close() // Not giving it back + } else if err == nil && r.Id != query.Id { + err = dns.ErrId + conn.Close() // Not giving it back + } + + if err == nil { + // Return it back to the connections cache if there were no errors + u.transport.Yield(conn) + } + return r, err +} diff --git a/upstream/helpers.go b/upstream/helpers.go new file mode 100644 index 00000000..209da533 --- /dev/null +++ b/upstream/helpers.go @@ -0,0 +1,101 @@ +package upstream + +import ( + "net" + "strings" + + "github.com/miekg/dns" + "golang.org/x/net/context" +) + +// Detects the upstream type from the specified url and creates a proper Upstream object +func NewUpstream(url string, bootstrap string) (Upstream, error) { + + proto := "udp" + prefix := "" + + switch { + case strings.HasPrefix(url, "tcp://"): + proto = "tcp" + prefix = "tcp://" + case strings.HasPrefix(url, "tls://"): + proto = "tcp-tls" + prefix = "tls://" + case strings.HasPrefix(url, "https://"): + return NewHttpsUpstream(url, bootstrap) + } + + hostname := strings.TrimPrefix(url, prefix) + + host, port, err := net.SplitHostPort(hostname) + if err != nil { + // Set port depending on the protocol + switch proto { + case "udp": + port = "53" + case "tcp": + port = "53" + case "tcp-tls": + port = "853" + } + + // Set host = hostname + host = hostname + } + + // Try to resolve the host address (or check if it's an IP address) + bootstrapResolver := CreateResolver(bootstrap) + ips, err := bootstrapResolver.LookupIPAddr(context.Background(), host) + + if err != nil || len(ips) == 0 { + return nil, err + } + + addr := ips[0].String() + endpoint := net.JoinHostPort(addr, port) + tlsServerName := "" + + if proto == "tcp-tls" && host != addr { + // Check if we need to specify TLS server name + tlsServerName = host + } + + return NewDnsUpstream(endpoint, proto, tlsServerName) +} + +func CreateResolver(bootstrap string) *net.Resolver { + + bootstrapResolver := net.DefaultResolver + + if bootstrap != "" { + bootstrapResolver = &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, network, bootstrap) + }, + } + } + + return bootstrapResolver +} + +// Performs a simple health-check of the specified upstream +func IsAlive(u Upstream) (bool, error) { + + // Using ipv4only.arpa. domain as it is a part of DNS64 RFC and it should exist everywhere + ping := new(dns.Msg) + ping.SetQuestion("ipv4only.arpa.", dns.TypeA) + + resp, err := u.Exchange(context.Background(), ping) + + // If we got a header, we're alright, basically only care about I/O errors 'n stuff. + if err != nil && resp != nil { + // Silly check, something sane came back. + if resp.Rcode != dns.RcodeServerFailure { + err = nil + } + } + + return err == nil, err +} diff --git a/upstream/https_upstream.go b/upstream/https_upstream.go new file mode 100644 index 00000000..d7d7bdde --- /dev/null +++ b/upstream/https_upstream.go @@ -0,0 +1,128 @@ +package upstream + +import ( + "bytes" + "crypto/tls" + "fmt" + "io/ioutil" + "log" + "net" + "net/http" + "net/url" + "time" + + "github.com/miekg/dns" + "github.com/pkg/errors" + "golang.org/x/net/context" + "golang.org/x/net/http2" +) + +const ( + dnsMessageContentType = "application/dns-message" + defaultKeepAlive = 30 * time.Second +) + +// HttpsUpstream is the upstream implementation for DNS-over-HTTPS +type HttpsUpstream struct { + client *http.Client + endpoint *url.URL +} + +// NewHttpsUpstream creates a new DNS-over-HTTPS upstream from the specified url +func NewHttpsUpstream(endpoint string, bootstrap string) (Upstream, error) { + u, err := url.Parse(endpoint) + if err != nil { + return nil, err + } + + // Initialize bootstrap resolver + bootstrapResolver := CreateResolver(bootstrap) + dialer := &net.Dialer{ + Timeout: defaultTimeout, + KeepAlive: defaultKeepAlive, + DualStack: true, + Resolver: bootstrapResolver, + } + + // Update TLS and HTTP client configuration + tlsConfig := &tls.Config{ServerName: u.Hostname()} + transport := &http.Transport{ + TLSClientConfig: tlsConfig, + DisableCompression: true, + MaxIdleConns: 1, + DialContext: dialer.DialContext, + } + http2.ConfigureTransport(transport) + + client := &http.Client{ + Timeout: defaultTimeout, + Transport: transport, + } + + return &HttpsUpstream{client: client, endpoint: u}, nil +} + +// Exchange provides an implementation for the Upstream interface +func (u *HttpsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) { + queryBuf, err := query.Pack() + if err != nil { + return nil, errors.Wrap(err, "failed to pack DNS query") + } + + // No content negotiation for now, use DNS wire format + buf, backendErr := u.exchangeWireformat(queryBuf) + if backendErr == nil { + response := &dns.Msg{} + if err := response.Unpack(buf); err != nil { + return nil, errors.Wrap(err, "failed to unpack DNS response from body") + } + + response.Id = query.Id + return response, nil + } + + log.Printf("failed to connect to an HTTPS backend %q due to %s", u.endpoint, backendErr) + return nil, backendErr +} + +// Perform message exchange with the default UDP wireformat defined in current draft +// https://tools.ietf.org/html/draft-ietf-doh-dns-over-https-10 +func (u *HttpsUpstream) exchangeWireformat(msg []byte) ([]byte, error) { + req, err := http.NewRequest("POST", u.endpoint.String(), bytes.NewBuffer(msg)) + if err != nil { + return nil, errors.Wrap(err, "failed to create an HTTPS request") + } + + req.Header.Add("Content-Type", dnsMessageContentType) + req.Header.Add("Accept", dnsMessageContentType) + req.Host = u.endpoint.Hostname() + + resp, err := u.client.Do(req) + if err != nil { + return nil, errors.Wrap(err, "failed to perform an HTTPS request") + } + + // Check response status code + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("returned status code %d", resp.StatusCode) + } + + contentType := resp.Header.Get("Content-Type") + if contentType != dnsMessageContentType { + return nil, fmt.Errorf("return wrong content type %s", contentType) + } + + // Read application/dns-message response from the body + buf, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, errors.Wrap(err, "failed to read the response body") + } + + return buf, nil +} + +// Clear resources +func (u *HttpsUpstream) Close() error { + return nil +} diff --git a/upstream/persistent.go b/upstream/persistent.go new file mode 100644 index 00000000..91cc9094 --- /dev/null +++ b/upstream/persistent.go @@ -0,0 +1,210 @@ +package upstream + +import ( + "crypto/tls" + "net" + "sort" + "sync/atomic" + "time" + + "github.com/miekg/dns" +) + +// Persistent connections cache -- almost similar to the same used in the CoreDNS forward plugin + +const ( + defaultExpire = 10 * time.Second + minDialTimeout = 100 * time.Millisecond + maxDialTimeout = 30 * time.Second + defaultDialTimeout = 30 * time.Second + cumulativeAvgWeight = 4 +) + +// a persistConn hold the dns.Conn and the last used time. +type persistConn struct { + c *dns.Conn + used time.Time +} + +// Transport hold the persistent cache. +type Transport struct { + avgDialTime int64 // kind of average time of dial time + conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls. + expire time.Duration // After this duration a connection is expired. + addr string + tlsConfig *tls.Config + + dial chan string + yield chan *dns.Conn + ret chan *dns.Conn + stop chan bool +} + +// Dial dials the address configured in transport, potentially reusing a connection or creating a new one. +func (t *Transport) Dial(proto string) (*dns.Conn, error) { + // If tls has been configured; use it. + if t.tlsConfig != nil { + proto = "tcp-tls" + } + + t.dial <- proto + c := <-t.ret + + if c != nil { + return c, nil + } + + reqTime := time.Now() + timeout := t.dialTimeout() + if proto == "tcp-tls" { + conn, err := dns.DialTimeoutWithTLS(proto, t.addr, t.tlsConfig, timeout) + t.updateDialTimeout(time.Since(reqTime)) + return conn, err + } + conn, err := dns.DialTimeout(proto, t.addr, timeout) + t.updateDialTimeout(time.Since(reqTime)) + return conn, err +} + +// Yield return the connection to transport for reuse. +func (t *Transport) Yield(c *dns.Conn) { t.yield <- c } + +// Start starts the transport's connection manager. +func (t *Transport) Start() { go t.connManager() } + +// Stop stops the transport's connection manager. +func (t *Transport) Stop() { close(t.stop) } + +// SetExpire sets the connection expire time in transport. +func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire } + +// SetTLSConfig sets the TLS config in transport. +func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg } + +func NewTransport(addr string) *Transport { + t := &Transport{ + avgDialTime: int64(defaultDialTimeout / 2), + conns: make(map[string][]*persistConn), + expire: defaultExpire, + addr: addr, + dial: make(chan string), + yield: make(chan *dns.Conn), + ret: make(chan *dns.Conn), + stop: make(chan bool), + } + return t +} + +func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) { + dt := time.Duration(atomic.LoadInt64(currentAvg)) + atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight) +} + +func (t *Transport) dialTimeout() time.Duration { + return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout) +} + +func (t *Transport) updateDialTimeout(newDialTime time.Duration) { + averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight) +} + +// limitTimeout is a utility function to auto-tune timeout values +// average observed time is moved towards the last observed delay moderated by a weight +// next timeout to use will be the double of the computed average, limited by min and max frame. +func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration { + rt := time.Duration(atomic.LoadInt64(currentAvg)) + if rt < minValue { + return minValue + } + if rt < maxValue/2 { + return 2 * rt + } + return maxValue +} + +// connManagers manages the persistent connection cache for UDP and TCP. +func (t *Transport) connManager() { + ticker := time.NewTicker(t.expire) +Wait: + for { + select { + case proto := <-t.dial: + // take the last used conn - complexity O(1) + if stack := t.conns[proto]; len(stack) > 0 { + pc := stack[len(stack)-1] + if time.Since(pc.used) < t.expire { + // Found one, remove from pool and return this conn. + t.conns[proto] = stack[:len(stack)-1] + t.ret <- pc.c + continue Wait + } + // clear entire cache if the last conn is expired + t.conns[proto] = nil + // now, the connections being passed to closeConns() are not reachable from + // transport methods anymore. So, it's safe to close them in a separate goroutine + go closeConns(stack) + } + + t.ret <- nil + + case conn := <-t.yield: + + // no proto here, infer from config and conn + if _, ok := conn.Conn.(*net.UDPConn); ok { + t.conns["udp"] = append(t.conns["udp"], &persistConn{conn, time.Now()}) + continue Wait + } + + if t.tlsConfig == nil { + t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn, time.Now()}) + continue Wait + } + + t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn, time.Now()}) + + case <-ticker.C: + t.cleanup(false) + + case <-t.stop: + t.cleanup(true) + close(t.ret) + return + } + } +} + +// closeConns closes connections. +func closeConns(conns []*persistConn) { + for _, pc := range conns { + pc.c.Close() + } +} + +// cleanup removes connections from cache. +func (t *Transport) cleanup(all bool) { + staleTime := time.Now().Add(-t.expire) + for proto, stack := range t.conns { + if len(stack) == 0 { + continue + } + if all { + t.conns[proto] = nil + // now, the connections being passed to closeConns() are not reachable from + // transport methods anymore. So, it's safe to close them in a separate goroutine + go closeConns(stack) + continue + } + if stack[0].used.After(staleTime) { + continue + } + + // connections in stack are sorted by "used" + good := sort.Search(len(stack), func(i int) bool { + return stack[i].used.After(staleTime) + }) + t.conns[proto] = stack[good:] + // now, the connections being passed to closeConns() are not reachable from + // transport methods anymore. So, it's safe to close them in a separate goroutine + go closeConns(stack[:good]) + } +} diff --git a/upstream/setup.go b/upstream/setup.go new file mode 100644 index 00000000..56f5da27 --- /dev/null +++ b/upstream/setup.go @@ -0,0 +1,84 @@ +package upstream + +import ( + "log" + + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/mholt/caddy" +) + +func init() { + caddy.RegisterPlugin("upstream", caddy.Plugin{ + ServerType: "dns", + Action: setup, + }) +} + +// Read the configuration and initialize upstreams +func setup(c *caddy.Controller) error { + + p, err := setupPlugin(c) + if err != nil { + return err + } + config := dnsserver.GetConfig(c) + config.AddPlugin(func(next plugin.Handler) plugin.Handler { + p.Next = next + return p + }) + + c.OnShutdown(p.onShutdown) + return nil +} + +// Read the configuration +func setupPlugin(c *caddy.Controller) (*UpstreamPlugin, error) { + + p := New() + + log.Println("Initializing the Upstream plugin") + + bootstrap := "" + upstreamUrls := []string{} + for c.Next() { + args := c.RemainingArgs() + if len(args) > 0 { + upstreamUrls = append(upstreamUrls, args...) + } + for c.NextBlock() { + switch c.Val() { + case "bootstrap": + if !c.NextArg() { + return nil, c.ArgErr() + } + bootstrap = c.Val() + } + } + } + + for _, url := range upstreamUrls { + u, err := NewUpstream(url, bootstrap) + if err != nil { + log.Printf("Cannot initialize upstream %s", url) + return nil, err + } + + p.Upstreams = append(p.Upstreams, u) + } + + return p, nil +} + +func (p *UpstreamPlugin) onShutdown() error { + for i := range p.Upstreams { + + u := p.Upstreams[i] + err := u.Close() + if err != nil { + log.Printf("Error while closing the upstream: %s", err) + } + } + + return nil +} diff --git a/upstream/setup_test.go b/upstream/setup_test.go new file mode 100644 index 00000000..b3918932 --- /dev/null +++ b/upstream/setup_test.go @@ -0,0 +1,30 @@ +package upstream + +import ( + "testing" + + "github.com/mholt/caddy" +) + +func TestSetup(t *testing.T) { + + var tests = []struct { + config string + }{ + {`upstream 8.8.8.8`}, + {`upstream 8.8.8.8 { + bootstrap 8.8.8.8:53 +}`}, + {`upstream tls://1.1.1.1 8.8.8.8 { + bootstrap 1.1.1.1 +}`}, + } + + for _, test := range tests { + c := caddy.NewTestController("dns", test.config) + err := setup(c) + if err != nil { + t.Fatalf("Test failed") + } + } +} diff --git a/upstream/upstream.go b/upstream/upstream.go new file mode 100644 index 00000000..faef224e --- /dev/null +++ b/upstream/upstream.go @@ -0,0 +1,57 @@ +package upstream + +import ( + "time" + + "github.com/coredns/coredns/plugin" + "github.com/miekg/dns" + "github.com/pkg/errors" + "golang.org/x/net/context" +) + +const ( + defaultTimeout = 5 * time.Second +) + +// Upstream is a simplified interface for proxy destination +type Upstream interface { + Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) + Close() error +} + +// UpstreamPlugin is a simplified DNS proxy using a generic upstream interface +type UpstreamPlugin struct { + Upstreams []Upstream + Next plugin.Handler +} + +// Initialize the upstream plugin +func New() *UpstreamPlugin { + p := &UpstreamPlugin{ + Upstreams: []Upstream{}, + } + + return p +} + +// ServeDNS implements interface for CoreDNS plugin +func (p *UpstreamPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + var reply *dns.Msg + var backendErr error + + for i := range p.Upstreams { + upstream := p.Upstreams[i] + reply, backendErr = upstream.Exchange(ctx, r) + if backendErr == nil { + w.WriteMsg(reply) + return 0, nil + } + } + + return dns.RcodeServerFailure, errors.Wrap(backendErr, "failed to contact any of the upstreams") +} + +// Name implements interface for CoreDNS plugin +func (p *UpstreamPlugin) Name() string { + return "upstream" +} diff --git a/upstream/upstream_test.go b/upstream/upstream_test.go new file mode 100644 index 00000000..7ce5690f --- /dev/null +++ b/upstream/upstream_test.go @@ -0,0 +1,194 @@ +package upstream + +import ( + "net" + "testing" + + "github.com/miekg/dns" + "golang.org/x/net/context" +) + +func TestDnsUpstreamIsAlive(t *testing.T) { + + var tests = []struct { + url string + bootstrap string + }{ + {"8.8.8.8:53", "8.8.8.8:53"}, + {"1.1.1.1", ""}, + {"tcp://1.1.1.1:53", ""}, + {"176.103.130.130:5353", ""}, + } + + for _, test := range tests { + u, err := NewUpstream(test.url, test.bootstrap) + + if err != nil { + t.Errorf("cannot create a DNS upstream") + } + + testUpstreamIsAlive(t, u) + } +} + +func TestHttpsUpstreamIsAlive(t *testing.T) { + + var tests = []struct { + url string + bootstrap string + }{ + {"https://cloudflare-dns.com/dns-query", "8.8.8.8:53"}, + {"https://dns.google.com/experimental", "8.8.8.8:53"}, + {"https://doh.cleanbrowsing.org/doh/security-filter/", ""}, + } + + for _, test := range tests { + u, err := NewUpstream(test.url, test.bootstrap) + + if err != nil { + t.Errorf("cannot create a DNS-over-HTTPS upstream") + } + + testUpstreamIsAlive(t, u) + } +} + +func TestDnsOverTlsIsAlive(t *testing.T) { + + var tests = []struct { + url string + bootstrap string + }{ + {"tls://1.1.1.1", ""}, + {"tls://9.9.9.9:853", ""}, + {"tls://security-filter-dns.cleanbrowsing.org", "8.8.8.8:53"}, + {"tls://adult-filter-dns.cleanbrowsing.org:853", "8.8.8.8:53"}, + } + + for _, test := range tests { + u, err := NewUpstream(test.url, test.bootstrap) + + if err != nil { + t.Errorf("cannot create a DNS-over-TLS upstream") + } + + testUpstreamIsAlive(t, u) + } +} + +func TestDnsUpstream(t *testing.T) { + + var tests = []struct { + url string + bootstrap string + }{ + {"8.8.8.8:53", "8.8.8.8:53"}, + {"1.1.1.1", ""}, + {"tcp://1.1.1.1:53", ""}, + {"176.103.130.130:5353", ""}, + } + + for _, test := range tests { + u, err := NewUpstream(test.url, test.bootstrap) + + if err != nil { + t.Errorf("cannot create a DNS upstream") + } + + testUpstream(t, u) + } +} + +func TestHttpsUpstream(t *testing.T) { + + var tests = []struct { + url string + bootstrap string + }{ + {"https://cloudflare-dns.com/dns-query", "8.8.8.8:53"}, + {"https://dns.google.com/experimental", "8.8.8.8:53"}, + {"https://doh.cleanbrowsing.org/doh/security-filter/", ""}, + } + + for _, test := range tests { + u, err := NewUpstream(test.url, test.bootstrap) + + if err != nil { + t.Errorf("cannot create a DNS-over-HTTPS upstream") + } + + testUpstream(t, u) + } +} + +func TestDnsOverTlsUpstream(t *testing.T) { + + var tests = []struct { + url string + bootstrap string + }{ + {"tls://1.1.1.1", ""}, + {"tls://9.9.9.9:853", ""}, + {"tls://security-filter-dns.cleanbrowsing.org", "8.8.8.8:53"}, + {"tls://adult-filter-dns.cleanbrowsing.org:853", "8.8.8.8:53"}, + } + + for _, test := range tests { + u, err := NewUpstream(test.url, test.bootstrap) + + if err != nil { + t.Errorf("cannot create a DNS-over-TLS upstream") + } + + testUpstream(t, u) + } +} + +func testUpstreamIsAlive(t *testing.T, u Upstream) { + alive, err := IsAlive(u) + if !alive || err != nil { + t.Errorf("Upstream is not alive") + } + + u.Close() +} + +func testUpstream(t *testing.T, u Upstream) { + + var tests = []struct { + name string + expected net.IP + }{ + {"google-public-dns-a.google.com.", net.IPv4(8, 8, 8, 8)}, + {"google-public-dns-b.google.com.", net.IPv4(8, 8, 4, 4)}, + } + + for _, test := range tests { + req := dns.Msg{} + req.Id = dns.Id() + req.RecursionDesired = true + req.Question = []dns.Question{ + {Name: test.name, Qtype: dns.TypeA, Qclass: dns.ClassINET}, + } + + resp, err := u.Exchange(context.Background(), &req) + + if err != nil { + t.Errorf("error while making an upstream request: %s", err) + } + + if len(resp.Answer) != 1 { + t.Errorf("no answer section in the response") + } + if answer, ok := resp.Answer[0].(*dns.A); ok { + if !test.expected.Equal(answer.A) { + t.Errorf("wrong IP in the response: %v", answer.A) + } + } + } + + err := u.Close() + if err != nil { + t.Errorf("Error while closing the upstream: %s", err) + } +}