badguardhome/dnsforward/upstream.go
2018-12-06 17:19:57 +03:00

240 lines
5.4 KiB
Go

package dnsforward
import (
"bytes"
"fmt"
"io/ioutil"
"log"
"math/rand"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/joomcode/errorx"
"github.com/miekg/dns"
)
const defaultTimeout = time.Second * 10
type Upstream interface {
Exchange(m *dns.Msg) (*dns.Msg, error)
Address() string
}
//
// plain DNS
//
type plainDNS struct {
boot bootstrapper
preferTCP bool
}
var defaultUDPClient = dns.Client{
Timeout: defaultTimeout,
UDPSize: dns.MaxMsgSize,
}
var defaultTCPClient = dns.Client{
Net: "tcp",
UDPSize: dns.MaxMsgSize,
Timeout: defaultTimeout,
}
// Address returns the original address that we've put in initially, not resolved one
func (p *plainDNS) Address() string { return p.boot.address }
func (p *plainDNS) Exchange(m *dns.Msg) (*dns.Msg, error) {
addr, _, err := p.boot.get()
if err != nil {
return nil, err
}
if p.preferTCP {
reply, _, err := defaultTCPClient.Exchange(m, addr)
return reply, err
}
reply, _, err := defaultUDPClient.Exchange(m, addr)
if err != nil && reply != nil && reply.Truncated {
log.Printf("Truncated message was received, retrying over TCP, question: %s", m.Question[0].String())
reply, _, err = defaultTCPClient.Exchange(m, addr)
}
return reply, err
}
//
// DNS-over-TLS
//
type dnsOverTLS struct {
boot bootstrapper
pool *TLSPool
sync.RWMutex // protects pool
}
func (p *dnsOverTLS) Address() string { return p.boot.address }
func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) {
var pool *TLSPool
p.RLock()
pool = p.pool
p.RUnlock()
if pool == nil {
p.Lock()
// lazy initialize it
p.pool = &TLSPool{boot: &p.boot}
p.Unlock()
}
p.RLock()
poolConn, err := p.pool.Get()
p.RUnlock()
if err != nil {
return nil, errorx.Decorate(err, "Failed to get a connection from TLSPool to %s", p.Address())
}
c := dns.Conn{Conn: poolConn}
err = c.WriteMsg(m)
if err != nil {
poolConn.Close()
return nil, errorx.Decorate(err, "Failed to send a request to %s", p.Address())
}
reply, err := c.ReadMsg()
if err != nil {
poolConn.Close()
return nil, errorx.Decorate(err, "Failed to read a request from %s", p.Address())
}
p.RLock()
p.pool.Put(poolConn)
p.RUnlock()
return reply, nil
}
//
// DNS-over-https
//
type dnsOverHTTPS struct {
boot bootstrapper
}
func (p *dnsOverHTTPS) Address() string { return p.boot.address }
func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (*dns.Msg, error) {
addr, tlsConfig, err := p.boot.get()
if err != nil {
return nil, errorx.Decorate(err, "Couldn't bootstrap %s", p.boot.address)
}
buf, err := m.Pack()
if err != nil {
return nil, errorx.Decorate(err, "Couldn't pack request msg")
}
bb := bytes.NewBuffer(buf)
// set up a custom request with custom URL
url, err := url.Parse(p.boot.address)
if err != nil {
return nil, errorx.Decorate(err, "Couldn't parse URL %s", p.boot.address)
}
req := http.Request{
Method: "POST",
URL: url,
Body: ioutil.NopCloser(bb),
Header: make(http.Header),
Host: url.Host,
}
url.Host = addr
req.Header.Set("Content-Type", "application/dns-message")
client := http.Client{
Transport: &http.Transport{TLSClientConfig: tlsConfig},
}
resp, err := client.Do(&req)
if resp != nil && resp.Body != nil {
defer resp.Body.Close()
}
if err != nil {
return nil, errorx.Decorate(err, "Couldn't do a POST request to '%s'", addr)
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, errorx.Decorate(err, "Couldn't read body contents for '%s'", addr)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("Got an unexpected HTTP status code %d from '%s'", resp.StatusCode, addr)
}
if len(body) == 0 {
return nil, fmt.Errorf("Got an unexpected empty body from '%s'", addr)
}
response := dns.Msg{}
err = response.Unpack(body)
if err != nil {
return nil, errorx.Decorate(err, "Couldn't unpack DNS response from '%s': body is %s", addr, string(body))
}
return &response, nil
}
func (s *Server) chooseUpstream() Upstream {
upstreams := s.Upstreams
if upstreams == nil {
upstreams = defaultValues.Upstreams
}
if len(upstreams) == 0 {
panic("SHOULD NOT HAPPEN: no default upstreams specified")
}
if len(upstreams) == 1 {
return upstreams[0]
}
n := rand.Intn(len(upstreams))
upstream := upstreams[n]
return upstream
}
func AddressToUpstream(address string, bootstrap string) (Upstream, error) {
if strings.Contains(address, "://") {
url, err := url.Parse(address)
if err != nil {
return nil, errorx.Decorate(err, "Failed to parse %s", address)
}
switch url.Scheme {
case "dns":
if url.Port() == "" {
url.Host += ":53"
}
return &plainDNS{boot: toBoot(url.Host, bootstrap)}, nil
case "tcp":
if url.Port() == "" {
url.Host += ":53"
}
return &plainDNS{boot: toBoot(url.Host, bootstrap), preferTCP: true}, nil
case "tls":
if url.Port() == "" {
url.Host += ":853"
}
return &dnsOverTLS{boot: toBoot(url.String(), bootstrap)}, nil
case "https":
if url.Port() == "" {
url.Host += ":443"
}
return &dnsOverHTTPS{boot: toBoot(url.String(), bootstrap)}, nil
default:
// assume it's plain DNS
if url.Port() == "" {
url.Host += ":53"
}
return &plainDNS{boot: toBoot(url.String(), bootstrap)}, nil
}
}
// we don't have scheme in the url, so it's just a plain DNS host:port
_, _, err := net.SplitHostPort(address)
if err != nil {
// doesn't have port, default to 53
address = net.JoinHostPort(address, "53")
}
return &plainDNS{boot: toBoot(address, bootstrap)}, nil
}