314 lines
7.5 KiB
Go
314 lines
7.5 KiB
Go
package dnsforward
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"log"
|
|
"math/rand"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/jedisct1/go-dnsstamps"
|
|
|
|
"github.com/ameshkov/dnscrypt"
|
|
"github.com/joomcode/errorx"
|
|
"github.com/miekg/dns"
|
|
)
|
|
|
|
// defaultTimeout is the timeout used by the DNS clients created in this file.
const defaultTimeout = 10 * time.Second
|
|
|
|
type Upstream interface {
|
|
Exchange(m *dns.Msg) (*dns.Msg, error)
|
|
Address() string
|
|
}
|
|
|
|
//
// plain DNS
//
|
|
type plainDNS struct {
|
|
boot bootstrapper
|
|
preferTCP bool
|
|
}
|
|
|
|
var defaultUDPClient = dns.Client{
|
|
Timeout: defaultTimeout,
|
|
UDPSize: dns.MaxMsgSize,
|
|
}
|
|
|
|
var defaultTCPClient = dns.Client{
|
|
Net: "tcp",
|
|
UDPSize: dns.MaxMsgSize,
|
|
Timeout: defaultTimeout,
|
|
}
|
|
|
|
// Address returns the original address that we've put in initially, not resolved one
|
|
func (p *plainDNS) Address() string { return p.boot.address }
|
|
|
|
func (p *plainDNS) Exchange(m *dns.Msg) (*dns.Msg, error) {
|
|
addr, _, err := p.boot.get()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if p.preferTCP {
|
|
reply, _, err := defaultTCPClient.Exchange(m, addr)
|
|
return reply, err
|
|
}
|
|
|
|
reply, _, err := defaultUDPClient.Exchange(m, addr)
|
|
if err != nil && reply != nil && reply.Truncated {
|
|
log.Printf("Truncated message was received, retrying over TCP, question: %s", m.Question[0].String())
|
|
reply, _, err = defaultTCPClient.Exchange(m, addr)
|
|
}
|
|
|
|
return reply, err
|
|
}
|
|
|
|
//
// DNS-over-TLS
//
|
|
type dnsOverTLS struct {
|
|
boot bootstrapper
|
|
pool *TLSPool
|
|
|
|
sync.RWMutex // protects pool
|
|
}
|
|
|
|
// Address returns the address this upstream was configured with.
func (p *dnsOverTLS) Address() string {
	return p.boot.address
}
|
|
|
|
func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) {
|
|
var pool *TLSPool
|
|
p.RLock()
|
|
pool = p.pool
|
|
p.RUnlock()
|
|
if pool == nil {
|
|
p.Lock()
|
|
// lazy initialize it
|
|
p.pool = &TLSPool{boot: &p.boot}
|
|
p.Unlock()
|
|
}
|
|
|
|
p.RLock()
|
|
poolConn, err := p.pool.Get()
|
|
p.RUnlock()
|
|
if err != nil {
|
|
return nil, errorx.Decorate(err, "Failed to get a connection from TLSPool to %s", p.Address())
|
|
}
|
|
c := dns.Conn{Conn: poolConn}
|
|
err = c.WriteMsg(m)
|
|
if err != nil {
|
|
poolConn.Close()
|
|
return nil, errorx.Decorate(err, "Failed to send a request to %s", p.Address())
|
|
}
|
|
|
|
reply, err := c.ReadMsg()
|
|
if err != nil {
|
|
poolConn.Close()
|
|
return nil, errorx.Decorate(err, "Failed to read a request from %s", p.Address())
|
|
}
|
|
p.RLock()
|
|
p.pool.Put(poolConn)
|
|
p.RUnlock()
|
|
return reply, nil
|
|
}
|
|
|
|
//
// DNS-over-HTTPS
//
|
|
type dnsOverHTTPS struct {
|
|
boot bootstrapper
|
|
}
|
|
|
|
// Address returns the address this upstream was configured with.
func (p *dnsOverHTTPS) Address() string {
	return p.boot.address
}
|
|
|
|
func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (*dns.Msg, error) {
|
|
addr, tlsConfig, err := p.boot.get()
|
|
if err != nil {
|
|
return nil, errorx.Decorate(err, "Couldn't bootstrap %s", p.boot.address)
|
|
}
|
|
|
|
buf, err := m.Pack()
|
|
if err != nil {
|
|
return nil, errorx.Decorate(err, "Couldn't pack request msg")
|
|
}
|
|
bb := bytes.NewBuffer(buf)
|
|
|
|
// set up a custom request with custom URL
|
|
url, err := url.Parse(p.boot.address)
|
|
if err != nil {
|
|
return nil, errorx.Decorate(err, "Couldn't parse URL %s", p.boot.address)
|
|
}
|
|
req := http.Request{
|
|
Method: "POST",
|
|
URL: url,
|
|
Body: ioutil.NopCloser(bb),
|
|
Header: make(http.Header),
|
|
Host: url.Host,
|
|
}
|
|
url.Host = addr
|
|
req.Header.Set("Content-Type", "application/dns-message")
|
|
client := http.Client{
|
|
Transport: &http.Transport{TLSClientConfig: tlsConfig},
|
|
}
|
|
resp, err := client.Do(&req)
|
|
if resp != nil && resp.Body != nil {
|
|
defer resp.Body.Close()
|
|
}
|
|
if err != nil {
|
|
return nil, errorx.Decorate(err, "Couldn't do a POST request to '%s'", addr)
|
|
}
|
|
|
|
body, err := ioutil.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, errorx.Decorate(err, "Couldn't read body contents for '%s'", addr)
|
|
}
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("Got an unexpected HTTP status code %d from '%s'", resp.StatusCode, addr)
|
|
}
|
|
if len(body) == 0 {
|
|
return nil, fmt.Errorf("Got an unexpected empty body from '%s'", addr)
|
|
}
|
|
response := dns.Msg{}
|
|
err = response.Unpack(body)
|
|
if err != nil {
|
|
return nil, errorx.Decorate(err, "Couldn't unpack DNS response from '%s': body is %s", addr, string(body))
|
|
}
|
|
return &response, nil
|
|
}
|
|
|
|
//
// DNSCrypt
//
|
|
type dnsCrypt struct {
|
|
boot bootstrapper
|
|
client *dnscrypt.Client // DNSCrypt client properties
|
|
serverInfo *dnscrypt.ServerInfo // DNSCrypt server info
|
|
|
|
sync.RWMutex // protects DNSCrypt client
|
|
}
|
|
|
|
// Address returns the address this upstream was configured with.
func (p *dnsCrypt) Address() string {
	return p.boot.address
}
|
|
|
|
func (p *dnsCrypt) Exchange(m *dns.Msg) (*dns.Msg, error) {
|
|
|
|
var client *dnscrypt.Client
|
|
var serverInfo *dnscrypt.ServerInfo
|
|
|
|
p.RLock()
|
|
client = p.client
|
|
serverInfo = p.serverInfo
|
|
p.RUnlock()
|
|
|
|
now := uint32(time.Now().Unix())
|
|
if client == nil || serverInfo == nil || (serverInfo != nil && serverInfo.ServerCert.NotAfter < now) {
|
|
p.Lock()
|
|
|
|
// Using "udp" for DNSCrypt upstreams by default
|
|
client = &dnscrypt.Client{Timeout: defaultTimeout, AdjustPayloadSize: true}
|
|
si, _, err := client.Dial(p.boot.address)
|
|
|
|
if err != nil {
|
|
p.Unlock()
|
|
return nil, errorx.Decorate(err, "Failed to fetch certificate info from %s", p.Address())
|
|
}
|
|
|
|
p.client = client
|
|
p.serverInfo = si
|
|
serverInfo = si
|
|
p.Unlock()
|
|
}
|
|
|
|
reply, _, err := client.Exchange(m, serverInfo)
|
|
|
|
if err, ok := err.(net.Error); ok && err.Timeout() {
|
|
// If request times out, it is possible that the server configuration has been changed.
|
|
// It is safe to assume that the key was rotated (for instance, as it is described here: https://dnscrypt.pl/2017/02/26/how-key-rotation-is-automated/).
|
|
// We should re-fetch the server certificate info so that the new requests were not failing.
|
|
p.Lock()
|
|
p.client = nil
|
|
p.serverInfo = nil
|
|
p.Unlock()
|
|
}
|
|
|
|
return reply, err
|
|
}
|
|
|
|
func (s *Server) chooseUpstream() Upstream {
|
|
upstreams := s.Upstreams
|
|
if upstreams == nil {
|
|
upstreams = defaultValues.Upstreams
|
|
}
|
|
if len(upstreams) == 0 {
|
|
panic("SHOULD NOT HAPPEN: no default upstreams specified")
|
|
}
|
|
if len(upstreams) == 1 {
|
|
return upstreams[0]
|
|
}
|
|
n := rand.Intn(len(upstreams))
|
|
upstream := upstreams[n]
|
|
return upstream
|
|
}
|
|
|
|
func AddressToUpstream(address string, bootstrap string) (Upstream, error) {
|
|
if strings.Contains(address, "://") {
|
|
url, err := url.Parse(address)
|
|
if err != nil {
|
|
return nil, errorx.Decorate(err, "Failed to parse %s", address)
|
|
}
|
|
switch url.Scheme {
|
|
case "sdns":
|
|
stamp, err := dnsstamps.NewServerStampFromString(address)
|
|
if err != nil {
|
|
return nil, errorx.Decorate(err, "Failed to parse %s", address)
|
|
}
|
|
|
|
switch stamp.Proto {
|
|
case dnsstamps.StampProtoTypeDNSCrypt:
|
|
return &dnsCrypt{boot: toBoot(url.String(), bootstrap)}, nil
|
|
case dnsstamps.StampProtoTypeDoH:
|
|
return AddressToUpstream(fmt.Sprintf("https://%s%s", stamp.ProviderName, stamp.Path), bootstrap)
|
|
}
|
|
|
|
return nil, fmt.Errorf("Unsupported protocol %v in %s", stamp.Proto, address)
|
|
case "dns":
|
|
if url.Port() == "" {
|
|
url.Host += ":53"
|
|
}
|
|
return &plainDNS{boot: toBoot(url.Host, bootstrap)}, nil
|
|
case "tcp":
|
|
if url.Port() == "" {
|
|
url.Host += ":53"
|
|
}
|
|
return &plainDNS{boot: toBoot(url.Host, bootstrap), preferTCP: true}, nil
|
|
case "tls":
|
|
if url.Port() == "" {
|
|
url.Host += ":853"
|
|
}
|
|
return &dnsOverTLS{boot: toBoot(url.String(), bootstrap)}, nil
|
|
case "https":
|
|
if url.Port() == "" {
|
|
url.Host += ":443"
|
|
}
|
|
return &dnsOverHTTPS{boot: toBoot(url.String(), bootstrap)}, nil
|
|
default:
|
|
// assume it's plain DNS
|
|
if url.Port() == "" {
|
|
url.Host += ":53"
|
|
}
|
|
return &plainDNS{boot: toBoot(url.String(), bootstrap)}, nil
|
|
}
|
|
}
|
|
|
|
// we don't have scheme in the url, so it's just a plain DNS host:port
|
|
_, _, err := net.SplitHostPort(address)
|
|
if err != nil {
|
|
// doesn't have port, default to 53
|
|
address = net.JoinHostPort(address, "53")
|
|
}
|
|
return &plainDNS{boot: toBoot(address, bootstrap)}, nil
|
|
}
|