badguardhome/dnsforward/dnsforward.go

472 lines
12 KiB
Go
Raw Normal View History

2018-11-28 12:40:56 +00:00
package dnsforward
import (
"fmt"
"log"
"net"
"reflect"
"sync"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/joomcode/errorx"
"github.com/miekg/dns"
)
// Server is the main way to start a DNS server
// Example:
// s := dnsforward.Server{}
// err := s.Start(nil) // will start a DNS server listening on default port 53, in a goroutine
// err := s.Reconfigure(ServerConfig{UDPListenAddr: &net.UDPAddr{Port: 53535}}) // will reconfigure running DNS server to listen on UDP port 53535
// err := s.Stop() // will stop listening on port 53535 and cancel all goroutines
// err := s.Start(nil) // will start listening again, on port 53535, in a goroutine
//
// The zero Server is empty and ready for use.
type Server struct {
udpListen *net.UDPConn
dnsFilter *dnsfilter.Dnsfilter
cache cache
sync.RWMutex
ServerConfig
}
// The zero ServerConfig is empty and ready for use.
type ServerConfig struct {
UDPListenAddr *net.UDPAddr // if nil, then default is is used (port 53 on *)
BlockedTTL uint32 // if 0, then default is used (3600)
Upstreams []Upstream
Filters []Filter
}
var defaultValues = ServerConfig{
UDPListenAddr: &net.UDPAddr{Port: 53},
BlockedTTL: 3600,
Upstreams: []Upstream{
//// dns over HTTPS
// &dnsOverHTTPS{Address: "https://1.1.1.1/dns-query"},
// &dnsOverHTTPS{Address: "https://dns.google.com/experimental"},
// &dnsOverHTTPS{Address: "https://doh.cleanbrowsing.org/doh/security-filter/"},
// &dnsOverHTTPS{Address: "https://dns10.quad9.net/dns-query"},
// &dnsOverHTTPS{Address: "https://doh.powerdns.org"},
// &dnsOverHTTPS{Address: "https://doh.securedns.eu/dns-query"},
//// dns over TLS
// &dnsOverTLS{Address: "tls://8.8.8.8:853"},
// &dnsOverTLS{Address: "tls://8.8.4.4:853"},
&dnsOverTLS{Address: "tls://1.1.1.1:853"},
&dnsOverTLS{Address: "tls://1.0.0.1:853"},
//// plainDNS
// &plainDNS{Address: "8.8.8.8:53"},
// &plainDNS{Address: "8.8.4.4:53"},
// &plainDNS{Address: "1.1.1.1:53"},
// &plainDNS{Address: "1.0.0.1:53"},
},
}
type Filter struct {
ID int64
Rules []string
}
//
// packet loop
//
func (s *Server) packetLoop() {
log.Printf("Entering packet handle loop")
b := make([]byte, dns.MaxMsgSize)
for {
s.RLock()
conn := s.udpListen
s.RUnlock()
if conn == nil {
log.Printf("udp socket has disappeared, exiting loop")
break
}
n, addr, err := conn.ReadFrom(b)
// documentation says to handle the packet even if err occurs, so do that first
if n > 0 {
// make a copy of all bytes because ReadFrom() will overwrite contents of b on next call
// we need the contents to survive the call because we're handling them in goroutine
p := make([]byte, n)
copy(p, b)
go s.handlePacket(p, addr, conn) // ignore errors
}
if err != nil {
if isConnClosed(err) {
log.Printf("ReadFrom() returned because we're reading from a closed connection, exiting loop")
break
}
log.Printf("Got error when reading from udp listen: %s", err)
}
}
}
//
// Control functions
//
func (s *Server) Start(config *ServerConfig) error {
s.Lock()
defer s.Unlock()
if config != nil {
s.ServerConfig = *config
}
// TODO: handle being called Start() second time after Stop()
if s.udpListen == nil {
log.Printf("Creating UDP socket")
var err error
addr := s.UDPListenAddr
if addr == nil {
addr = defaultValues.UDPListenAddr
}
s.udpListen, err = net.ListenUDP("udp", addr)
if err != nil {
return errorx.Decorate(err, "Couldn't listen to UDP socket")
}
log.Println(s.udpListen.LocalAddr(), s.UDPListenAddr)
}
if s.dnsFilter == nil {
log.Printf("Creating dnsfilter")
s.dnsFilter = dnsfilter.New()
}
go s.packetLoop()
return nil
}
func (s *Server) Stop() error {
s.Lock()
defer s.Unlock()
if s.udpListen != nil {
err := s.udpListen.Close()
if err != nil {
return errorx.Decorate(err, "Couldn't close UDP listening socket")
}
s.udpListen = nil
}
return nil
}
//
// Server reconfigure
//
func (s *Server) reconfigureListenAddr(new ServerConfig) error {
oldAddr := s.UDPListenAddr
if oldAddr == nil {
oldAddr = defaultValues.UDPListenAddr
}
newAddr := new.UDPListenAddr
if newAddr == nil {
newAddr = defaultValues.UDPListenAddr
}
if newAddr.Port == 0 {
return errorx.IllegalArgument.New("new port cannot be 0")
}
if reflect.DeepEqual(oldAddr, newAddr) {
// do nothing, the addresses are exactly the same
log.Printf("Not going to rebind because addresses are same: %v -> %v", oldAddr, newAddr)
return nil
}
// rebind, using a strategy:
// * if ports are different, bind new first, then close old
// * if ports are same, close old first, then bind new
var newListen *net.UDPConn
var err error
if oldAddr.Port != newAddr.Port {
log.Printf("Rebinding -- ports are different so bind first then close")
newListen, err = net.ListenUDP("udp", newAddr)
if err != nil {
return errorx.Decorate(err, "Couldn't bind to %v", newAddr)
}
s.Lock()
2018-11-28 12:40:56 +00:00
if s.udpListen != nil {
err = s.udpListen.Close()
}
s.Unlock()
if err != nil {
return errorx.Decorate(err, "Couldn't close UDP listening socket")
2018-11-28 12:40:56 +00:00
}
} else {
log.Printf("Rebinding -- ports are same so close first then bind")
s.Lock()
2018-11-28 12:40:56 +00:00
if s.udpListen != nil {
err = s.udpListen.Close()
}
s.Unlock()
if err != nil {
return errorx.Decorate(err, "Couldn't close UDP listening socket")
2018-11-28 12:40:56 +00:00
}
newListen, err = net.ListenUDP("udp", newAddr)
if err != nil {
return errorx.Decorate(err, "Couldn't bind to %v", newAddr)
}
}
s.Lock()
s.udpListen = newListen
s.UDPListenAddr = new.UDPListenAddr
s.Unlock()
log.Println(s.udpListen.LocalAddr(), s.UDPListenAddr)
go s.packetLoop() // the old one has quit, use new one
return nil
}
func (s *Server) reconfigureBlockedTTL(new ServerConfig) {
newVal := new.BlockedTTL
if newVal == 0 {
newVal = defaultValues.BlockedTTL
}
oldVal := s.BlockedTTL
if oldVal == 0 {
oldVal = defaultValues.BlockedTTL
}
if newVal != oldVal {
s.BlockedTTL = new.BlockedTTL
}
}
func (s *Server) reconfigureUpstreams(new ServerConfig) {
newVal := new.Upstreams
if len(newVal) == 0 {
newVal = defaultValues.Upstreams
}
oldVal := s.Upstreams
if len(oldVal) == 0 {
oldVal = defaultValues.Upstreams
}
if reflect.DeepEqual(newVal, oldVal) {
// they're exactly the same, do nothing
return
}
s.Upstreams = new.Upstreams
}
func (s *Server) reconfigureFilters(new ServerConfig) {
newFilters := new.Filters
if len(newFilters) == 0 {
newFilters = defaultValues.Filters
}
oldFilters := s.Filters
if len(oldFilters) == 0 {
oldFilters = defaultValues.Filters
}
if reflect.DeepEqual(newFilters, oldFilters) {
// they're exactly the same, do nothing
return
}
dnsFilter := dnsfilter.New()
for _, f := range newFilters {
for _, rule := range f.Rules {
err := dnsFilter.AddRule(rule, f.ID)
if err == dnsfilter.ErrAlreadyExists || err == dnsfilter.ErrInvalidSyntax {
continue
}
if err != nil {
log.Printf("Cannot add rule %s: %s", rule, err)
// Just ignore invalid rules
continue
}
}
}
s.Lock()
oldDnsFilter := s.dnsFilter
s.dnsFilter = dnsFilter
s.Unlock()
oldDnsFilter.Destroy()
}
func (s *Server) Reconfigure(new ServerConfig) error {
s.reconfigureBlockedTTL(new)
s.reconfigureUpstreams(new)
s.reconfigureFilters(new)
err := s.reconfigureListenAddr(new)
if err != nil {
return errorx.Decorate(err, "Couldn't reconfigure to new listening address %+v", new.UDPListenAddr)
}
return nil
}
//
// packet handling functions
//
func (s *Server) handlePacket(p []byte, addr net.Addr, conn *net.UDPConn) {
// log.Printf("Got packet %d bytes from %s: %v", len(p), addr, p)
msg := dns.Msg{}
err := msg.Unpack(p)
if err != nil {
log.Printf("failed to unpack DNS packet: %s", err)
return
}
//
// DNS packet byte format is valid
//
// any errors below here require a response to client
// log.Printf("Unpacked: %v", msg.String())
if len(msg.Question) != 1 {
log.Printf("Got invalid number of questions: %v", len(msg.Question))
err := s.respondWithServerFailure(&msg, addr, conn)
if err != nil {
log.Printf("Couldn't respond to UDP packet: %s", err)
return
}
}
{
val, ok := s.cache.Get(&msg)
if ok && val != nil {
err = s.respond(val, addr, conn)
if err != nil {
if isConnClosed(err) {
// ignore this error, the connection was closed and that's ok
return
}
log.Printf("Couldn't respond to UDP packet: %s", err)
return
}
return
}
}
host := msg.Question[0].Name
res, err := s.dnsFilter.CheckHost(host)
if err != nil {
log.Printf("dnsfilter failed to check host '%s': %s", host, err)
err := s.respondWithServerFailure(&msg, addr, conn)
if err != nil {
log.Printf("Couldn't respond to UDP packet: %s", err)
return
}
} else if res.IsFiltered {
log.Printf("Host %s is filtered, reason - '%s', matched rule: '%s'", host, res.Reason, res.Rule)
err := s.respondWithNXDomain(&msg, addr, conn)
if err != nil {
log.Printf("Couldn't respond to UDP packet: %s", err)
return
}
}
// TODO: replace with single-socket implementation
// TODO: replace 8.8.8.8:53 with configurable upstreams
// TODO: support DoH, DoT and TCP
upstream := s.chooseUpstream()
reply, err := upstream.Exchange(&msg)
if err != nil {
log.Printf("talking to upstream failed for host '%s': %s", host, err)
err := s.respondWithServerFailure(&msg, addr, conn)
if err != nil {
if isConnClosed(err) {
// ignore this error, the connection was closed and that's ok
return
}
log.Printf("Couldn't respond to UDP packet with server failure: %s", err)
return
}
return
}
if reply == nil {
log.Printf("SHOULD NOT HAPPEN upstream returned empty message for host '%s'. Request is %v", host, msg.String())
err := s.respondWithServerFailure(&msg, addr, conn)
if err != nil {
log.Printf("Couldn't respond to UDP packet with should not happen: %s", err)
return
}
return
}
err = s.respond(reply, addr, conn)
if err != nil {
if isConnClosed(err) {
// ignore this error, the connection was closed and that's ok
return
}
log.Printf("Couldn't respond to UDP packet: %s", err)
return
}
s.cache.Set(reply)
}
//
// packet sending functions
//
func (s *Server) respond(resp *dns.Msg, addr net.Addr, conn *net.UDPConn) error {
// log.Printf("Replying to %s with %s", addr, resp)
resp.Compress = true
bytes, err := resp.Pack()
if err != nil {
return errorx.Decorate(err, "Couldn't convert message into wire format")
}
n, err := conn.WriteTo(bytes, addr)
if n == 0 && isConnClosed(err) {
return err
}
if n != len(bytes) {
return fmt.Errorf("WriteTo() returned with %d != %d", n, len(bytes))
}
if err != nil {
return errorx.Decorate(err, "WriteTo() returned error")
}
return nil
}
func (s *Server) respondWithServerFailure(request *dns.Msg, addr net.Addr, conn *net.UDPConn) error {
resp := dns.Msg{}
resp.SetRcode(request, dns.RcodeServerFailure)
return s.respond(&resp, addr, conn)
}
func (s *Server) respondWithNXDomain(request *dns.Msg, addr net.Addr, conn *net.UDPConn) error {
resp := dns.Msg{}
resp.SetRcode(request, dns.RcodeNameError)
resp.Ns = s.genSOA(request)
return s.respond(&resp, addr, conn)
}
func (s *Server) genSOA(request *dns.Msg) []dns.RR {
zone := ""
if len(request.Question) > 0 {
zone = request.Question[0].Name
}
soa := dns.SOA{
// values copied from verisign's nonexistent .com domain
// their exact values are not important in our use case because they are used for domain transfers between primary/secondary DNS servers
Refresh: 1800,
Retry: 900,
Expire: 604800,
Minttl: 86400,
// copied from AdGuard DNS
Ns: "fake-for-negative-caching.adguard.com.",
Serial: 100500,
// rest is request-specific
Hdr: dns.RR_Header{
Name: zone,
Rrtype: dns.TypeSOA,
Ttl: s.BlockedTTL,
Class: dns.ClassINET,
},
Mbox: "hostmaster.", // zone will be appended later if it's not empty or "."
}
if soa.Hdr.Ttl == 0 {
soa.Hdr.Ttl = defaultValues.BlockedTTL
}
if len(zone) > 0 && zone[0] != '.' {
soa.Mbox += zone
}
return []dns.RR{&soa}
}