Upstream plugin prototype
This commit is contained in:
parent
19e30dbccc
commit
484c0ceaff
42
control.go
42
control.go
@ -134,9 +134,9 @@ func httpError(w http.ResponseWriter, code int, format string, args ...interface
|
||||
func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
errortext := fmt.Sprintf("Failed to read request body: %s", err)
|
||||
log.Println(errortext)
|
||||
http.Error(w, errortext, http.StatusBadRequest)
|
||||
errorText := fmt.Sprintf("Failed to read request body: %s", err)
|
||||
log.Println(errorText)
|
||||
http.Error(w, errorText, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
// if empty body -- user is asking for default servers
|
||||
@ -153,34 +153,34 @@ func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
err = writeAllConfigs()
|
||||
if err != nil {
|
||||
errortext := fmt.Sprintf("Couldn't write config file: %s", err)
|
||||
log.Println(errortext)
|
||||
http.Error(w, errortext, http.StatusInternalServerError)
|
||||
errorText := fmt.Sprintf("Couldn't write config file: %s", err)
|
||||
log.Println(errorText)
|
||||
http.Error(w, errorText, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
tellCoreDNSToReload()
|
||||
_, err = fmt.Fprintf(w, "OK %d servers\n", len(hosts))
|
||||
if err != nil {
|
||||
errortext := fmt.Sprintf("Couldn't write body: %s", err)
|
||||
log.Println(errortext)
|
||||
http.Error(w, errortext, http.StatusInternalServerError)
|
||||
errorText := fmt.Sprintf("Couldn't write body: %s", err)
|
||||
log.Println(errorText)
|
||||
http.Error(w, errorText, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
errortext := fmt.Sprintf("Failed to read request body: %s", err)
|
||||
log.Println(errortext)
|
||||
http.Error(w, errortext, 400)
|
||||
errorText := fmt.Sprintf("Failed to read request body: %s", err)
|
||||
log.Println(errorText)
|
||||
http.Error(w, errorText, 400)
|
||||
return
|
||||
}
|
||||
hosts := strings.Fields(string(body))
|
||||
|
||||
if len(hosts) == 0 {
|
||||
errortext := fmt.Sprintf("No servers specified")
|
||||
log.Println(errortext)
|
||||
http.Error(w, errortext, http.StatusBadRequest)
|
||||
errorText := fmt.Sprintf("No servers specified")
|
||||
log.Println(errorText)
|
||||
http.Error(w, errorText, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
@ -198,18 +198,18 @@ func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
jsonVal, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
errortext := fmt.Sprintf("Unable to marshal status json: %s", err)
|
||||
log.Println(errortext)
|
||||
http.Error(w, errortext, http.StatusInternalServerError)
|
||||
errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
|
||||
log.Println(errorText)
|
||||
http.Error(w, errorText, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err = w.Write(jsonVal)
|
||||
if err != nil {
|
||||
errortext := fmt.Sprintf("Couldn't write body: %s", err)
|
||||
log.Println(errortext)
|
||||
http.Error(w, errortext, http.StatusInternalServerError)
|
||||
errorText := fmt.Sprintf("Couldn't write body: %s", err)
|
||||
log.Println(errorText)
|
||||
http.Error(w, errorText, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
|
36
upstream/dns_upstream.go
Normal file
36
upstream/dns_upstream.go
Normal file
@ -0,0 +1,36 @@
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/net/context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DnsUpstream is a very simple upstream implementation for plain DNS
|
||||
type DnsUpstream struct {
|
||||
nameServer string // IP:port
|
||||
timeout time.Duration // Max read and write timeout
|
||||
}
|
||||
|
||||
// NewDnsUpstream creates a new plain-DNS upstream
|
||||
func NewDnsUpstream(nameServer string) (Upstream, error) {
|
||||
return &DnsUpstream{nameServer: nameServer, timeout: defaultTimeout}, nil
|
||||
}
|
||||
|
||||
// Exchange provides an implementation for the Upstream interface
|
||||
func (u *DnsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
|
||||
|
||||
dnsClient := &dns.Client{
|
||||
ReadTimeout: u.timeout,
|
||||
WriteTimeout: u.timeout,
|
||||
}
|
||||
|
||||
resp, _, err := dnsClient.Exchange(query, u.nameServer)
|
||||
|
||||
if err != nil {
|
||||
resp = &dns.Msg{}
|
||||
resp.SetRcode(resp, dns.RcodeServerFailure)
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
109
upstream/https_upstream.go
Normal file
109
upstream/https_upstream.go
Normal file
@ -0,0 +1,109 @@
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/net/context"
|
||||
"golang.org/x/net/http2"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
const (
|
||||
dnsMessageContentType = "application/dns-message"
|
||||
)
|
||||
|
||||
// HttpsUpstream is the upstream implementation for DNS-over-HTTPS
|
||||
type HttpsUpstream struct {
|
||||
client *http.Client
|
||||
endpoint *url.URL
|
||||
}
|
||||
|
||||
// NewHttpsUpstream creates a new DNS-over-HTTPS upstream from hostname
|
||||
func NewHttpsUpstream(endpoint string) (Upstream, error) {
|
||||
u, err := url.Parse(endpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Update TLS and HTTP client configuration
|
||||
tlsConfig := &tls.Config{ServerName: u.Hostname()}
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: tlsConfig,
|
||||
DisableCompression: true,
|
||||
MaxIdleConns: 1,
|
||||
}
|
||||
http2.ConfigureTransport(transport)
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: defaultTimeout,
|
||||
Transport: transport,
|
||||
}
|
||||
|
||||
return &HttpsUpstream{client: client, endpoint: u}, nil
|
||||
}
|
||||
|
||||
// Exchange provides an implementation for the Upstream interface
|
||||
func (u *HttpsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
|
||||
queryBuf, err := query.Pack()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to pack DNS query")
|
||||
}
|
||||
|
||||
// No content negotiation for now, use DNS wire format
|
||||
buf, backendErr := u.exchangeWireformat(queryBuf)
|
||||
if backendErr == nil {
|
||||
response := &dns.Msg{}
|
||||
if err := response.Unpack(buf); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to unpack DNS response from body")
|
||||
}
|
||||
|
||||
response.Id = query.Id
|
||||
return response, nil
|
||||
}
|
||||
|
||||
log.Printf("failed to connect to an HTTPS backend %q due to %s", u.endpoint, backendErr)
|
||||
return nil, backendErr
|
||||
}
|
||||
|
||||
// Perform message exchange with the default UDP wireformat defined in current draft
|
||||
// https://tools.ietf.org/html/draft-ietf-doh-dns-over-https-10
|
||||
func (u *HttpsUpstream) exchangeWireformat(msg []byte) ([]byte, error) {
|
||||
req, err := http.NewRequest("POST", u.endpoint.String(), bytes.NewBuffer(msg))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create an HTTPS request")
|
||||
}
|
||||
|
||||
req.Header.Add("Content-Type", dnsMessageContentType)
|
||||
req.Header.Add("Accept", dnsMessageContentType)
|
||||
req.Host = u.endpoint.Hostname()
|
||||
|
||||
resp, err := u.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to perform an HTTPS request")
|
||||
}
|
||||
|
||||
// Check response status code
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("returned status code %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType != dnsMessageContentType {
|
||||
return nil, fmt.Errorf("return wrong content type %s", contentType)
|
||||
}
|
||||
|
||||
// Read application/dns-message response from the body
|
||||
buf, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to read the response body")
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
47
upstream/tls_upstream.go
Normal file
47
upstream/tls_upstream.go
Normal file
@ -0,0 +1,47 @@
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/net/context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TODO: Use persistent connection here
|
||||
|
||||
// DnsOverTlsUpstream is the upstream implementation for plain DNS-over-TLS
|
||||
type DnsOverTlsUpstream struct {
|
||||
endpoint string
|
||||
tlsServerName string
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// NewHttpsUpstream creates a new DNS-over-TLS upstream from the endpoint address and TLS server name
|
||||
func NewDnsOverTlsUpstream(endpoint string, tlsServerName string) (Upstream, error) {
|
||||
return &DnsOverTlsUpstream{
|
||||
endpoint: endpoint,
|
||||
tlsServerName: tlsServerName,
|
||||
timeout: defaultTimeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Exchange provides an implementation for the Upstream interface
|
||||
func (u *DnsOverTlsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
|
||||
|
||||
dnsClient := &dns.Client{
|
||||
Net: "tcp-tls",
|
||||
ReadTimeout: u.timeout,
|
||||
WriteTimeout: u.timeout,
|
||||
TLSConfig: new(tls.Config),
|
||||
}
|
||||
dnsClient.TLSConfig.ServerName = u.tlsServerName
|
||||
|
||||
resp, _, err := dnsClient.Exchange(query, u.endpoint)
|
||||
|
||||
if err != nil {
|
||||
resp = &dns.Msg{}
|
||||
resp.SetRcode(resp, dns.RcodeServerFailure)
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
43
upstream/upstream.go
Normal file
43
upstream/upstream.go
Normal file
@ -0,0 +1,43 @@
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"github.com/coredns/coredns/plugin"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/net/context"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// Upstream is a simplified interface for proxy destination
|
||||
type Upstream interface {
|
||||
Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error)
|
||||
}
|
||||
|
||||
// UpstreamPlugin is a simplified DNS proxy using a generic upstream interface
|
||||
type UpstreamPlugin struct {
|
||||
Upstreams []Upstream
|
||||
Next plugin.Handler
|
||||
}
|
||||
|
||||
// ServeDNS implements interface for CoreDNS plugin
|
||||
func (p UpstreamPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
var reply *dns.Msg
|
||||
var backendErr error
|
||||
|
||||
for _, upstream := range p.Upstreams {
|
||||
reply, backendErr = upstream.Exchange(ctx, r)
|
||||
if backendErr == nil {
|
||||
w.WriteMsg(reply)
|
||||
return 0, nil
|
||||
}
|
||||
}
|
||||
|
||||
return dns.RcodeServerFailure, errors.Wrap(backendErr, "failed to contact any of the upstreams")
|
||||
}
|
||||
|
||||
// Name implements interface for CoreDNS plugin
|
||||
func (p UpstreamPlugin) Name() string { return "upstream" }
|
86
upstream/upstream_test.go
Normal file
86
upstream/upstream_test.go
Normal file
@ -0,0 +1,86 @@
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
"log"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDnsUpstream(t *testing.T) {
|
||||
|
||||
u, err := NewDnsUpstream("8.8.8.8:53")
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("cannot create a DNS upstream")
|
||||
}
|
||||
|
||||
testUpstream(t, u)
|
||||
}
|
||||
|
||||
func TestHttpsUpstream(t *testing.T) {
|
||||
|
||||
testCases := []string{
|
||||
"https://cloudflare-dns.com/dns-query",
|
||||
"https://dns.google.com/experimental",
|
||||
"https://doh.cleanbrowsing.org/doh/security-filter/",
|
||||
}
|
||||
|
||||
for _, url := range testCases {
|
||||
u, err := NewHttpsUpstream(url)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("cannot create a DNS-over-HTTPS upstream")
|
||||
}
|
||||
|
||||
testUpstream(t, u)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDnsOverTlsUpstream(t *testing.T) {
|
||||
|
||||
var tests = []struct {
|
||||
endpoint string
|
||||
tlsServerName string
|
||||
}{
|
||||
{"1.1.1.1:853", ""},
|
||||
{"8.8.8.8:853", ""},
|
||||
{"185.228.168.10:853", "security-filter-dns.cleanbrowsing.org"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
u, err := NewDnsOverTlsUpstream(test.endpoint, test.tlsServerName)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("cannot create a DNS-over-TLS upstream")
|
||||
}
|
||||
|
||||
testUpstream(t, u)
|
||||
}
|
||||
}
|
||||
|
||||
func testUpstream(t *testing.T, u Upstream) {
|
||||
req := dns.Msg{}
|
||||
req.Id = dns.Id()
|
||||
req.RecursionDesired = true
|
||||
req.Question = []dns.Question{
|
||||
{Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
||||
}
|
||||
|
||||
resp, err := u.Exchange(nil, &req)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("error while making an upstream request: %s", err)
|
||||
}
|
||||
|
||||
if len(resp.Answer) != 1 {
|
||||
t.Errorf("no answer section in the response")
|
||||
}
|
||||
if answer, ok := resp.Answer[0].(*dns.A); ok {
|
||||
if !net.IPv4(8, 8, 8, 8).Equal(answer.A) {
|
||||
t.Errorf("wrong IP in the response: %v", answer.A)
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("response: %v", resp)
|
||||
}
|
Loading…
Reference in New Issue
Block a user