From 484c0ceaff2837db49f0ab679954511a9c916650 Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Thu, 1 Nov 2018 14:45:32 +0300 Subject: [PATCH] Upstream plugin prototype --- control.go | 42 +++++++------- upstream/dns_upstream.go | 36 ++++++++++++ upstream/https_upstream.go | 109 +++++++++++++++++++++++++++++++++++++ upstream/tls_upstream.go | 47 ++++++++++++++++ upstream/upstream.go | 43 +++++++++++++++ upstream/upstream_test.go | 86 +++++++++++++++++++++++++++++ 6 files changed, 342 insertions(+), 21 deletions(-) create mode 100644 upstream/dns_upstream.go create mode 100644 upstream/https_upstream.go create mode 100644 upstream/tls_upstream.go create mode 100644 upstream/upstream.go create mode 100644 upstream/upstream_test.go diff --git a/control.go b/control.go index 1e1084e8..378a2aec 100644 --- a/control.go +++ b/control.go @@ -134,9 +134,9 @@ func httpError(w http.ResponseWriter, code int, format string, args ...interface func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) { body, err := ioutil.ReadAll(r.Body) if err != nil { - errortext := fmt.Sprintf("Failed to read request body: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusBadRequest) + errorText := fmt.Sprintf("Failed to read request body: %s", err) + log.Println(errorText) + http.Error(w, errorText, http.StatusBadRequest) return } // if empty body -- user is asking for default servers @@ -153,34 +153,34 @@ func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) { err = writeAllConfigs() if err != nil { - errortext := fmt.Sprintf("Couldn't write config file: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) + errorText := fmt.Sprintf("Couldn't write config file: %s", err) + log.Println(errorText) + http.Error(w, errorText, http.StatusInternalServerError) return } tellCoreDNSToReload() _, err = fmt.Fprintf(w, "OK %d servers\n", len(hosts)) if err != nil { - errortext := fmt.Sprintf("Couldn't write body: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) + errorText := fmt.Sprintf("Couldn't write body: %s", err) + log.Println(errorText) + http.Error(w, errorText, http.StatusInternalServerError) } } func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { body, err := ioutil.ReadAll(r.Body) if err != nil { - errortext := fmt.Sprintf("Failed to read request body: %s", err) - log.Println(errortext) - http.Error(w, errortext, 400) + errorText := fmt.Sprintf("Failed to read request body: %s", err) + log.Println(errorText) + http.Error(w, errorText, 400) return } hosts := strings.Fields(string(body)) if len(hosts) == 0 { - errortext := fmt.Sprintf("No servers specified") - log.Println(errortext) - http.Error(w, errortext, http.StatusBadRequest) + errorText := fmt.Sprintf("No servers specified") + log.Println(errorText) + http.Error(w, errorText, http.StatusBadRequest) return } @@ -198,18 +198,18 @@ func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { jsonVal, err := json.Marshal(result) if err != nil { - errortext := fmt.Sprintf("Unable to marshal status json: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) + errorText := fmt.Sprintf("Unable to marshal status json: %s", err) + log.Println(errorText) + http.Error(w, errorText, http.StatusInternalServerError) return } w.Header().Set("Content-Type", "application/json") _, err = w.Write(jsonVal) if err != nil { - errortext := fmt.Sprintf("Couldn't write body: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) + errorText := fmt.Sprintf("Couldn't write body: %s", err) + log.Println(errorText) + http.Error(w, errorText, http.StatusInternalServerError) } } diff --git a/upstream/dns_upstream.go b/upstream/dns_upstream.go new file mode 100644 index 00000000..779c059e --- /dev/null +++ b/upstream/dns_upstream.go @@ -0,0 +1,36 @@ +package upstream + +import ( + "github.com/miekg/dns" + "golang.org/x/net/context" + "time" +) + +// DnsUpstream is a very simple upstream implementation for plain DNS +type DnsUpstream struct { + nameServer string // IP:port + timeout time.Duration // Max read and write timeout +} + +// NewDnsUpstream creates a new plain-DNS upstream +func NewDnsUpstream(nameServer string) (Upstream, error) { + return &DnsUpstream{nameServer: nameServer, timeout: defaultTimeout}, nil +} + +// Exchange provides an implementation for the Upstream interface +func (u *DnsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) { + + dnsClient := &dns.Client{ + ReadTimeout: u.timeout, + WriteTimeout: u.timeout, + } + + resp, _, err := dnsClient.Exchange(query, u.nameServer) + + if err != nil { + resp = &dns.Msg{} + resp.SetRcode(resp, dns.RcodeServerFailure) + } + + return resp, err +} diff --git a/upstream/https_upstream.go b/upstream/https_upstream.go new file mode 100644 index 00000000..61c7a397 --- /dev/null +++ b/upstream/https_upstream.go @@ -0,0 +1,109 @@ +package upstream + +import ( + "bytes" + "crypto/tls" + "fmt" + "github.com/miekg/dns" + "github.com/pkg/errors" + "golang.org/x/net/context" + "golang.org/x/net/http2" + "io/ioutil" + "log" + "net/http" + "net/url" +) + +const ( + dnsMessageContentType = "application/dns-message" +) + +// HttpsUpstream is the upstream implementation for DNS-over-HTTPS +type HttpsUpstream struct { + client *http.Client + endpoint *url.URL +} + +// NewHttpsUpstream creates a new DNS-over-HTTPS upstream from hostname +func NewHttpsUpstream(endpoint string) (Upstream, error) { + u, err := url.Parse(endpoint) + if err != nil { + return nil, err + } + + // Update TLS and HTTP client configuration + tlsConfig := &tls.Config{ServerName: u.Hostname()} + transport := &http.Transport{ + TLSClientConfig: tlsConfig, + DisableCompression: true, + MaxIdleConns: 1, + } + http2.ConfigureTransport(transport) + + client := &http.Client{ + Timeout: defaultTimeout, + Transport: transport, + } + + return &HttpsUpstream{client: client, endpoint: u}, nil +} + +// Exchange provides an implementation for the Upstream interface +func (u *HttpsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) { + queryBuf, err := query.Pack() + if err != nil { + return nil, errors.Wrap(err, "failed to pack DNS query") + } + + // No content negotiation for now, use DNS wire format + buf, backendErr := u.exchangeWireformat(queryBuf) + if backendErr == nil { + response := &dns.Msg{} + if err := response.Unpack(buf); err != nil { + return nil, errors.Wrap(err, "failed to unpack DNS response from body") + } + + response.Id = query.Id + return response, nil + } + + log.Printf("failed to connect to an HTTPS backend %q due to %s", u.endpoint, backendErr) + return nil, backendErr +} + +// Perform message exchange with the default UDP wireformat defined in current draft +// https://tools.ietf.org/html/draft-ietf-doh-dns-over-https-10 +func (u *HttpsUpstream) exchangeWireformat(msg []byte) ([]byte, error) { + req, err := http.NewRequest("POST", u.endpoint.String(), bytes.NewBuffer(msg)) + if err != nil { + return nil, errors.Wrap(err, "failed to create an HTTPS request") + } + + req.Header.Add("Content-Type", dnsMessageContentType) + req.Header.Add("Accept", dnsMessageContentType) + req.Host = u.endpoint.Hostname() + + resp, err := u.client.Do(req) + if err != nil { + return nil, errors.Wrap(err, "failed to perform an HTTPS request") + } + + // Check response status code + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("returned status code %d", resp.StatusCode) + } + + contentType := resp.Header.Get("Content-Type") + if contentType != dnsMessageContentType { + return nil, fmt.Errorf("return wrong content type %s", contentType) + } + + // Read application/dns-message response from the body + buf, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, errors.Wrap(err, "failed to read the response body") + } + + return buf, nil +} diff --git a/upstream/tls_upstream.go b/upstream/tls_upstream.go new file mode 100644 index 00000000..aed55829 --- /dev/null +++ b/upstream/tls_upstream.go @@ -0,0 +1,47 @@ +package upstream + +import ( + "crypto/tls" + "github.com/miekg/dns" + "golang.org/x/net/context" + "time" +) + +// TODO: Use persistent connection here + +// DnsOverTlsUpstream is the upstream implementation for plain DNS-over-TLS +type DnsOverTlsUpstream struct { + endpoint string + tlsServerName string + timeout time.Duration +} + +// NewHttpsUpstream creates a new DNS-over-TLS upstream from the endpoint address and TLS server name +func NewDnsOverTlsUpstream(endpoint string, tlsServerName string) (Upstream, error) { + return &DnsOverTlsUpstream{ + endpoint: endpoint, + tlsServerName: tlsServerName, + timeout: defaultTimeout, + }, nil +} + +// Exchange provides an implementation for the Upstream interface +func (u *DnsOverTlsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) { + + dnsClient := &dns.Client{ + Net: "tcp-tls", + ReadTimeout: u.timeout, + WriteTimeout: u.timeout, + TLSConfig: new(tls.Config), + } + dnsClient.TLSConfig.ServerName = u.tlsServerName + + resp, _, err := dnsClient.Exchange(query, u.endpoint) + + if err != nil { + resp = &dns.Msg{} + resp.SetRcode(resp, dns.RcodeServerFailure) + } + + return resp, err +} diff --git a/upstream/upstream.go b/upstream/upstream.go new file mode 100644 index 00000000..6d2570c5 --- /dev/null +++ b/upstream/upstream.go @@ -0,0 +1,43 @@ +package upstream + +import ( + "github.com/coredns/coredns/plugin" + "github.com/miekg/dns" + "github.com/pkg/errors" + "golang.org/x/net/context" + "time" +) + +const ( + defaultTimeout = 5 * time.Second +) + +// Upstream is a simplified interface for proxy destination +type Upstream interface { + Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) +} + +// UpstreamPlugin is a simplified DNS proxy using a generic upstream interface +type UpstreamPlugin struct { + Upstreams []Upstream + Next plugin.Handler +} + +// ServeDNS implements interface for CoreDNS plugin +func (p UpstreamPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + var reply *dns.Msg + var backendErr error + + for _, upstream := range p.Upstreams { + reply, backendErr = upstream.Exchange(ctx, r) + if backendErr == nil { + w.WriteMsg(reply) + return 0, nil + } + } + + return dns.RcodeServerFailure, errors.Wrap(backendErr, "failed to contact any of the upstreams") +} + +// Name implements interface for CoreDNS plugin +func (p UpstreamPlugin) Name() string { return "upstream" } diff --git a/upstream/upstream_test.go b/upstream/upstream_test.go new file mode 100644 index 00000000..ca0df859 --- /dev/null +++ b/upstream/upstream_test.go @@ -0,0 +1,86 @@ +package upstream + +import ( + "github.com/miekg/dns" + "log" + "net" + "testing" +) + +func TestDnsUpstream(t *testing.T) { + + u, err := NewDnsUpstream("8.8.8.8:53") + + if err != nil { + t.Errorf("cannot create a DNS upstream") + } + + testUpstream(t, u) +} + +func TestHttpsUpstream(t *testing.T) { + + testCases := []string{ + "https://cloudflare-dns.com/dns-query", + "https://dns.google.com/experimental", + "https://doh.cleanbrowsing.org/doh/security-filter/", + } + + for _, url := range testCases { + u, err := NewHttpsUpstream(url) + + if err != nil { + t.Errorf("cannot create a DNS-over-HTTPS upstream") + } + + testUpstream(t, u) + } +} + +func TestDnsOverTlsUpstream(t *testing.T) { + + var tests = []struct { + endpoint string + tlsServerName string + }{ + {"1.1.1.1:853", ""}, + {"8.8.8.8:853", ""}, + {"185.228.168.10:853", "security-filter-dns.cleanbrowsing.org"}, + } + + for _, test := range tests { + u, err := NewDnsOverTlsUpstream(test.endpoint, test.tlsServerName) + + if err != nil { + t.Errorf("cannot create a DNS-over-TLS upstream") + } + + testUpstream(t, u) + } +} + +func testUpstream(t *testing.T, u Upstream) { + req := dns.Msg{} + req.Id = dns.Id() + req.RecursionDesired = true + req.Question = []dns.Question{ + {Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + } + + resp, err := u.Exchange(nil, &req) + + if err != nil { + t.Errorf("error while making an upstream request: %s", err) + } + + if len(resp.Answer) != 1 { + t.Errorf("no answer section in the response") + } + if answer, ok := resp.Answer[0].(*dns.A); ok { + if !net.IPv4(8, 8, 8, 8).Equal(answer.A) { + t.Errorf("wrong IP in the response: %v", answer.A) + } + } + + log.Printf("response: %v", resp) +}