Add dns forwarding server library
This commit is contained in:
parent
39eccc62b1
commit
5a548be16c
|
@ -0,0 +1,225 @@
|
||||||
|
package dnsforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"log"
|
||||||
|
"math"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
type item struct {
|
||||||
|
m *dns.Msg
|
||||||
|
when time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type cache struct {
|
||||||
|
items map[string]item
|
||||||
|
|
||||||
|
sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cache) Get(request *dns.Msg) (*dns.Msg, bool) {
|
||||||
|
if request == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
ok, key := key(request)
|
||||||
|
if !ok {
|
||||||
|
log.Printf("Get(): key returned !ok")
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
c.RLock()
|
||||||
|
item, ok := c.items[key]
|
||||||
|
c.RUnlock()
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
// get item's TTL
|
||||||
|
ttl := findLowestTTL(item.m)
|
||||||
|
// zero TTL? delete and don't serve it
|
||||||
|
if ttl == 0 {
|
||||||
|
c.Lock()
|
||||||
|
delete(c.items, key)
|
||||||
|
c.Unlock()
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
// too much time has passed? delete and don't serve it
|
||||||
|
if time.Since(item.when) >= time.Duration(ttl)*time.Second {
|
||||||
|
c.Lock()
|
||||||
|
delete(c.items, key)
|
||||||
|
c.Unlock()
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
response := item.fromItem(request)
|
||||||
|
return response, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cache) Set(m *dns.Msg) {
|
||||||
|
if m == nil {
|
||||||
|
return // no-op
|
||||||
|
}
|
||||||
|
if !isRequestCacheable(m) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !isResponseCacheable(m) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ok, key := key(m)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
i := toItem(m)
|
||||||
|
|
||||||
|
c.Lock()
|
||||||
|
if c.items == nil {
|
||||||
|
c.items = map[string]item{}
|
||||||
|
}
|
||||||
|
c.items[key] = i
|
||||||
|
c.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// check only request fields
|
||||||
|
func isRequestCacheable(m *dns.Msg) bool {
|
||||||
|
// truncated messages aren't valid
|
||||||
|
if m.Truncated {
|
||||||
|
log.Printf("Refusing to cache truncated message")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// if has wrong number of questions, also don't cache
|
||||||
|
if len(m.Question) != 1 {
|
||||||
|
log.Printf("Refusing to cache message with wrong number of questions")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// only OK or NXdomain replies are cached
|
||||||
|
switch m.Rcode {
|
||||||
|
case dns.RcodeSuccess:
|
||||||
|
case dns.RcodeNameError: // that's an NXDomain
|
||||||
|
case dns.RcodeServerFailure:
|
||||||
|
return false // quietly refuse, don't log
|
||||||
|
default:
|
||||||
|
log.Printf("%s: Refusing to cache message with rcode: %s", m.Question[0].Name, dns.RcodeToString[m.Rcode])
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func isResponseCacheable(m *dns.Msg) bool {
|
||||||
|
ttl := findLowestTTL(m)
|
||||||
|
if ttl == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func findLowestTTL(m *dns.Msg) uint32 {
|
||||||
|
var ttl uint32 = math.MaxUint32
|
||||||
|
found := false
|
||||||
|
|
||||||
|
if m.Answer != nil {
|
||||||
|
for _, r := range m.Answer {
|
||||||
|
if r.Header().Ttl < ttl {
|
||||||
|
ttl = r.Header().Ttl
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.Ns != nil {
|
||||||
|
for _, r := range m.Ns {
|
||||||
|
if r.Header().Ttl < ttl {
|
||||||
|
ttl = r.Header().Ttl
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.Extra != nil {
|
||||||
|
for _, r := range m.Extra {
|
||||||
|
if r.Header().Rrtype == dns.TypeOPT {
|
||||||
|
continue // OPT records use TTL for other purposes
|
||||||
|
}
|
||||||
|
if r.Header().Ttl < ttl {
|
||||||
|
ttl = r.Header().Ttl
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if found == false {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return ttl
|
||||||
|
}
|
||||||
|
|
||||||
|
// key is binary little endian in sequence:
|
||||||
|
// uint16(qtype) then uint16(qclass) then name
|
||||||
|
func key(m *dns.Msg) (bool, string) {
|
||||||
|
if len(m.Question) != 1 {
|
||||||
|
log.Printf("got msg with len(m.Question) != 1: %d", len(m.Question))
|
||||||
|
return false, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
bb := strings.Builder{}
|
||||||
|
b := make([]byte, 2)
|
||||||
|
binary.LittleEndian.PutUint16(b, m.Question[0].Qtype)
|
||||||
|
bb.Write(b)
|
||||||
|
binary.LittleEndian.PutUint16(b, m.Question[0].Qclass)
|
||||||
|
bb.Write(b)
|
||||||
|
name := strings.ToLower(m.Question[0].Name)
|
||||||
|
bb.WriteString(name)
|
||||||
|
return true, bb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func toItem(m *dns.Msg) item {
|
||||||
|
return item{
|
||||||
|
m: m,
|
||||||
|
when: time.Now(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *item) fromItem(request *dns.Msg) *dns.Msg {
|
||||||
|
response := &dns.Msg{}
|
||||||
|
response.SetReply(request)
|
||||||
|
|
||||||
|
response.Authoritative = false
|
||||||
|
response.AuthenticatedData = i.m.AuthenticatedData
|
||||||
|
response.RecursionAvailable = i.m.RecursionAvailable
|
||||||
|
response.Rcode = i.m.Rcode
|
||||||
|
|
||||||
|
ttl := findLowestTTL(i.m)
|
||||||
|
timeleft := math.Round(float64(ttl) - time.Since(i.when).Seconds())
|
||||||
|
var newttl uint32
|
||||||
|
if timeleft > 0 {
|
||||||
|
newttl = uint32(timeleft)
|
||||||
|
}
|
||||||
|
for _, r := range i.m.Answer {
|
||||||
|
answer := dns.Copy(r)
|
||||||
|
answer.Header().Ttl = newttl
|
||||||
|
response.Answer = append(response.Answer, answer)
|
||||||
|
}
|
||||||
|
for _, r := range i.m.Ns {
|
||||||
|
ns := dns.Copy(r)
|
||||||
|
ns.Header().Ttl = newttl
|
||||||
|
response.Ns = append(response.Ns, ns)
|
||||||
|
}
|
||||||
|
for _, r := range i.m.Extra {
|
||||||
|
// don't return OPT records as these are hop-by-hop
|
||||||
|
if r.Header().Rrtype == dns.TypeOPT {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
extra := dns.Copy(r)
|
||||||
|
extra.Header().Ttl = newttl
|
||||||
|
response.Extra = append(response.Extra, extra)
|
||||||
|
}
|
||||||
|
return response
|
||||||
|
}
|
|
@ -0,0 +1,144 @@
|
||||||
|
package dnsforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/go-test/deep"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
func RR(rr string) dns.RR {
|
||||||
|
r, err := dns.NewRR(rr)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// deepEqual is same as deep.Equal, except:
|
||||||
|
// * ignores Id when comparing
|
||||||
|
// * question names are not case sensetive
|
||||||
|
func deepEqualMsg(left *dns.Msg, right *dns.Msg) []string {
|
||||||
|
temp := *left
|
||||||
|
temp.Id = right.Id
|
||||||
|
for i := range left.Question {
|
||||||
|
left.Question[i].Name = strings.ToLower(left.Question[i].Name)
|
||||||
|
}
|
||||||
|
for i := range right.Question {
|
||||||
|
right.Question[i].Name = strings.ToLower(right.Question[i].Name)
|
||||||
|
}
|
||||||
|
return deep.Equal(&temp, right)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCacheSanity(t *testing.T) {
|
||||||
|
cache := cache{}
|
||||||
|
request := dns.Msg{}
|
||||||
|
request.SetQuestion("google.com.", dns.TypeA)
|
||||||
|
_, ok := cache.Get(&request)
|
||||||
|
if ok {
|
||||||
|
t.Fatal("empty cache replied with positive response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type tests struct {
|
||||||
|
cache []testEntry
|
||||||
|
cases []testCase
|
||||||
|
}
|
||||||
|
|
||||||
|
type testEntry struct {
|
||||||
|
q string
|
||||||
|
t uint16
|
||||||
|
a []dns.RR
|
||||||
|
}
|
||||||
|
|
||||||
|
type testCase struct {
|
||||||
|
q string
|
||||||
|
t uint16
|
||||||
|
a []dns.RR
|
||||||
|
ok bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCache(t *testing.T) {
|
||||||
|
tests := tests{
|
||||||
|
cache: []testEntry{
|
||||||
|
{q: "google.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}},
|
||||||
|
},
|
||||||
|
cases: []testCase{
|
||||||
|
{q: "google.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true},
|
||||||
|
{q: "google.com.", t: dns.TypeMX, ok: false},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
runTests(t, tests)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCacheMixedCase(t *testing.T) {
|
||||||
|
tests := tests{
|
||||||
|
cache: []testEntry{
|
||||||
|
{q: "gOOgle.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}},
|
||||||
|
},
|
||||||
|
cases: []testCase{
|
||||||
|
{q: "gOOgle.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true},
|
||||||
|
{q: "google.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true},
|
||||||
|
{q: "GOOGLE.COM.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true},
|
||||||
|
{q: "gOOgle.com.", t: dns.TypeMX, ok: false},
|
||||||
|
{q: "google.com.", t: dns.TypeMX, ok: false},
|
||||||
|
{q: "GOOGLE.COM.", t: dns.TypeMX, ok: false},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
runTests(t, tests)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestZeroTTL(t *testing.T) {
|
||||||
|
tests := tests{
|
||||||
|
cache: []testEntry{
|
||||||
|
{q: "gOOgle.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 0 IN A 8.8.8.8")}},
|
||||||
|
},
|
||||||
|
cases: []testCase{
|
||||||
|
{q: "google.com.", t: dns.TypeA, ok: false},
|
||||||
|
{q: "google.com.", t: dns.TypeA, ok: false},
|
||||||
|
{q: "google.com.", t: dns.TypeA, ok: false},
|
||||||
|
{q: "google.com.", t: dns.TypeMX, ok: false},
|
||||||
|
{q: "google.com.", t: dns.TypeMX, ok: false},
|
||||||
|
{q: "google.com.", t: dns.TypeMX, ok: false},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
runTests(t, tests)
|
||||||
|
}
|
||||||
|
|
||||||
|
func runTests(t *testing.T, tests tests) {
|
||||||
|
t.Helper()
|
||||||
|
cache := cache{}
|
||||||
|
for _, tc := range tests.cache {
|
||||||
|
reply := dns.Msg{}
|
||||||
|
reply.SetQuestion(tc.q, tc.t)
|
||||||
|
reply.Response = true
|
||||||
|
reply.Answer = tc.a
|
||||||
|
cache.Set(&reply)
|
||||||
|
}
|
||||||
|
for _, tc := range tests.cases {
|
||||||
|
request := dns.Msg{}
|
||||||
|
request.SetQuestion(tc.q, tc.t)
|
||||||
|
val, ok := cache.Get(&request)
|
||||||
|
if diff := deep.Equal(ok, tc.ok); diff != nil {
|
||||||
|
t.Error(diff)
|
||||||
|
}
|
||||||
|
if tc.a != nil {
|
||||||
|
if ok == false {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
reply := dns.Msg{}
|
||||||
|
reply.SetQuestion(tc.q, tc.t)
|
||||||
|
reply.Response = true
|
||||||
|
reply.Answer = tc.a
|
||||||
|
cache.Set(&reply)
|
||||||
|
if diff := deepEqualMsg(val, &reply); diff != nil {
|
||||||
|
t.Error(diff)
|
||||||
|
} else {
|
||||||
|
if diff := deep.Equal(val, reply); diff == nil {
|
||||||
|
t.Error("different message ID were not caught")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,467 @@
|
||||||
|
package dnsforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"reflect"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||||
|
"github.com/joomcode/errorx"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Server is the main way to start a DNS server
|
||||||
|
// Example:
|
||||||
|
// s := dnsforward.Server{}
|
||||||
|
// err := s.Start(nil) // will start a DNS server listening on default port 53, in a goroutine
|
||||||
|
// err := s.Reconfigure(ServerConfig{UDPListenAddr: &net.UDPAddr{Port: 53535}}) // will reconfigure running DNS server to listen on UDP port 53535
|
||||||
|
// err := s.Stop() // will stop listening on port 53535 and cancel all goroutines
|
||||||
|
// err := s.Start(nil) // will start listening again, on port 53535, in a goroutine
|
||||||
|
//
|
||||||
|
// The zero Server is empty and ready for use.
|
||||||
|
type Server struct {
|
||||||
|
udpListen *net.UDPConn
|
||||||
|
|
||||||
|
dnsFilter *dnsfilter.Dnsfilter
|
||||||
|
|
||||||
|
cache cache
|
||||||
|
|
||||||
|
sync.RWMutex
|
||||||
|
ServerConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// The zero ServerConfig is empty and ready for use.
|
||||||
|
type ServerConfig struct {
|
||||||
|
UDPListenAddr *net.UDPAddr // if nil, then default is is used (port 53 on *)
|
||||||
|
BlockedTTL uint32 // if 0, then default is used (3600)
|
||||||
|
Upstreams []Upstream
|
||||||
|
Filters []Filter
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultValues = ServerConfig{
|
||||||
|
UDPListenAddr: &net.UDPAddr{Port: 53},
|
||||||
|
BlockedTTL: 3600,
|
||||||
|
Upstreams: []Upstream{
|
||||||
|
//// dns over HTTPS
|
||||||
|
// &dnsOverHTTPS{Address: "https://1.1.1.1/dns-query"},
|
||||||
|
// &dnsOverHTTPS{Address: "https://dns.google.com/experimental"},
|
||||||
|
// &dnsOverHTTPS{Address: "https://doh.cleanbrowsing.org/doh/security-filter/"},
|
||||||
|
// &dnsOverHTTPS{Address: "https://dns10.quad9.net/dns-query"},
|
||||||
|
// &dnsOverHTTPS{Address: "https://doh.powerdns.org"},
|
||||||
|
// &dnsOverHTTPS{Address: "https://doh.securedns.eu/dns-query"},
|
||||||
|
|
||||||
|
//// dns over TLS
|
||||||
|
// &dnsOverTLS{Address: "tls://8.8.8.8:853"},
|
||||||
|
// &dnsOverTLS{Address: "tls://8.8.4.4:853"},
|
||||||
|
&dnsOverTLS{Address: "tls://1.1.1.1:853"},
|
||||||
|
&dnsOverTLS{Address: "tls://1.0.0.1:853"},
|
||||||
|
|
||||||
|
//// plainDNS
|
||||||
|
// &plainDNS{Address: "8.8.8.8:53"},
|
||||||
|
// &plainDNS{Address: "8.8.4.4:53"},
|
||||||
|
// &plainDNS{Address: "1.1.1.1:53"},
|
||||||
|
// &plainDNS{Address: "1.0.0.1:53"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
type Filter struct {
|
||||||
|
ID int64
|
||||||
|
Rules []string
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// packet loop
|
||||||
|
//
|
||||||
|
func (s *Server) packetLoop() {
|
||||||
|
log.Printf("Entering packet handle loop")
|
||||||
|
b := make([]byte, dns.MaxMsgSize)
|
||||||
|
for {
|
||||||
|
s.RLock()
|
||||||
|
conn := s.udpListen
|
||||||
|
s.RUnlock()
|
||||||
|
if conn == nil {
|
||||||
|
log.Printf("udp socket has disappeared, exiting loop")
|
||||||
|
break
|
||||||
|
}
|
||||||
|
n, addr, err := conn.ReadFrom(b)
|
||||||
|
// documentation says to handle the packet even if err occurs, so do that first
|
||||||
|
if n > 0 {
|
||||||
|
// make a copy of all bytes because ReadFrom() will overwrite contents of b on next call
|
||||||
|
// we need the contents to survive the call because we're handling them in goroutine
|
||||||
|
p := make([]byte, n)
|
||||||
|
copy(p, b)
|
||||||
|
go s.handlePacket(p, addr, conn) // ignore errors
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
if isConnClosed(err) {
|
||||||
|
log.Printf("ReadFrom() returned because we're reading from a closed connection, exiting loop")
|
||||||
|
break
|
||||||
|
}
|
||||||
|
log.Printf("Got error when reading from udp listen: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Control functions
|
||||||
|
//
|
||||||
|
|
||||||
|
func (s *Server) Start(config *ServerConfig) error {
|
||||||
|
s.Lock()
|
||||||
|
defer s.Unlock()
|
||||||
|
if config != nil {
|
||||||
|
s.ServerConfig = *config
|
||||||
|
}
|
||||||
|
// TODO: handle being called Start() second time after Stop()
|
||||||
|
if s.udpListen == nil {
|
||||||
|
log.Printf("Creating UDP socket")
|
||||||
|
var err error
|
||||||
|
addr := s.UDPListenAddr
|
||||||
|
if addr == nil {
|
||||||
|
addr = defaultValues.UDPListenAddr
|
||||||
|
}
|
||||||
|
s.udpListen, err = net.ListenUDP("udp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return errorx.Decorate(err, "Couldn't listen to UDP socket")
|
||||||
|
}
|
||||||
|
log.Println(s.udpListen.LocalAddr(), s.UDPListenAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.dnsFilter == nil {
|
||||||
|
log.Printf("Creating dnsfilter")
|
||||||
|
s.dnsFilter = dnsfilter.New()
|
||||||
|
}
|
||||||
|
|
||||||
|
go s.packetLoop()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) Stop() error {
|
||||||
|
s.Lock()
|
||||||
|
defer s.Unlock()
|
||||||
|
if s.udpListen != nil {
|
||||||
|
err := s.udpListen.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorx.Decorate(err, "Couldn't close UDP listening socket")
|
||||||
|
}
|
||||||
|
s.udpListen = nil
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Server reconfigure
|
||||||
|
//
|
||||||
|
|
||||||
|
func (s *Server) reconfigureListenAddr(new ServerConfig) error {
|
||||||
|
oldAddr := s.UDPListenAddr
|
||||||
|
if oldAddr == nil {
|
||||||
|
oldAddr = defaultValues.UDPListenAddr
|
||||||
|
}
|
||||||
|
newAddr := new.UDPListenAddr
|
||||||
|
if newAddr == nil {
|
||||||
|
newAddr = defaultValues.UDPListenAddr
|
||||||
|
}
|
||||||
|
if newAddr.Port == 0 {
|
||||||
|
return errorx.IllegalArgument.New("new port cannot be 0")
|
||||||
|
}
|
||||||
|
if reflect.DeepEqual(oldAddr, newAddr) {
|
||||||
|
// do nothing, the addresses are exactly the same
|
||||||
|
log.Printf("Not going to rebind because addresses are same: %v -> %v", oldAddr, newAddr)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// rebind, using a strategy:
|
||||||
|
// * if ports are different, bind new first, then close old
|
||||||
|
// * if ports are same, close old first, then bind new
|
||||||
|
var newListen *net.UDPConn
|
||||||
|
var err error
|
||||||
|
if oldAddr.Port != newAddr.Port {
|
||||||
|
log.Printf("Rebinding -- ports are different so bind first then close")
|
||||||
|
newListen, err = net.ListenUDP("udp", newAddr)
|
||||||
|
if err != nil {
|
||||||
|
return errorx.Decorate(err, "Couldn't bind to %v", newAddr)
|
||||||
|
}
|
||||||
|
if s.udpListen != nil {
|
||||||
|
err := s.udpListen.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorx.Decorate(err, "Couldn't close UDP listening socket")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Printf("Rebinding -- ports are same so close first then bind")
|
||||||
|
if s.udpListen != nil {
|
||||||
|
err := s.udpListen.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorx.Decorate(err, "Couldn't close UDP listening socket")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
newListen, err = net.ListenUDP("udp", newAddr)
|
||||||
|
if err != nil {
|
||||||
|
return errorx.Decorate(err, "Couldn't bind to %v", newAddr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.Lock()
|
||||||
|
s.udpListen = newListen
|
||||||
|
s.UDPListenAddr = new.UDPListenAddr
|
||||||
|
s.Unlock()
|
||||||
|
log.Println(s.udpListen.LocalAddr(), s.UDPListenAddr)
|
||||||
|
|
||||||
|
go s.packetLoop() // the old one has quit, use new one
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) reconfigureBlockedTTL(new ServerConfig) {
|
||||||
|
newVal := new.BlockedTTL
|
||||||
|
if newVal == 0 {
|
||||||
|
newVal = defaultValues.BlockedTTL
|
||||||
|
}
|
||||||
|
oldVal := s.BlockedTTL
|
||||||
|
if oldVal == 0 {
|
||||||
|
oldVal = defaultValues.BlockedTTL
|
||||||
|
}
|
||||||
|
if newVal != oldVal {
|
||||||
|
s.BlockedTTL = new.BlockedTTL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) reconfigureUpstreams(new ServerConfig) {
|
||||||
|
newVal := new.Upstreams
|
||||||
|
if len(newVal) == 0 {
|
||||||
|
newVal = defaultValues.Upstreams
|
||||||
|
}
|
||||||
|
oldVal := s.Upstreams
|
||||||
|
if len(oldVal) == 0 {
|
||||||
|
oldVal = defaultValues.Upstreams
|
||||||
|
}
|
||||||
|
if reflect.DeepEqual(newVal, oldVal) {
|
||||||
|
// they're exactly the same, do nothing
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.Upstreams = new.Upstreams
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) reconfigureFilters(new ServerConfig) {
|
||||||
|
newFilters := new.Filters
|
||||||
|
if len(newFilters) == 0 {
|
||||||
|
newFilters = defaultValues.Filters
|
||||||
|
}
|
||||||
|
oldFilters := s.Filters
|
||||||
|
if len(oldFilters) == 0 {
|
||||||
|
oldFilters = defaultValues.Filters
|
||||||
|
}
|
||||||
|
if reflect.DeepEqual(newFilters, oldFilters) {
|
||||||
|
// they're exactly the same, do nothing
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsFilter := dnsfilter.New()
|
||||||
|
for _, f := range newFilters {
|
||||||
|
for _, rule := range f.Rules {
|
||||||
|
err := dnsFilter.AddRule(rule, f.ID)
|
||||||
|
if err == dnsfilter.ErrAlreadyExists || err == dnsfilter.ErrInvalidSyntax {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Cannot add rule %s: %s", rule, err)
|
||||||
|
// Just ignore invalid rules
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Lock()
|
||||||
|
oldDnsFilter := s.dnsFilter
|
||||||
|
s.dnsFilter = dnsFilter
|
||||||
|
s.Unlock()
|
||||||
|
|
||||||
|
oldDnsFilter.Destroy()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) Reconfigure(new ServerConfig) error {
|
||||||
|
s.reconfigureBlockedTTL(new)
|
||||||
|
s.reconfigureUpstreams(new)
|
||||||
|
s.reconfigureFilters(new)
|
||||||
|
|
||||||
|
err := s.reconfigureListenAddr(new)
|
||||||
|
if err != nil {
|
||||||
|
return errorx.Decorate(err, "Couldn't reconfigure to new listening address %+v", new.UDPListenAddr)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// packet handling functions
|
||||||
|
//
|
||||||
|
|
||||||
|
func (s *Server) handlePacket(p []byte, addr net.Addr, conn *net.UDPConn) {
|
||||||
|
// log.Printf("Got packet %d bytes from %s: %v", len(p), addr, p)
|
||||||
|
msg := dns.Msg{}
|
||||||
|
err := msg.Unpack(p)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to unpack DNS packet: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// DNS packet byte format is valid
|
||||||
|
//
|
||||||
|
// any errors below here require a response to client
|
||||||
|
// log.Printf("Unpacked: %v", msg.String())
|
||||||
|
if len(msg.Question) != 1 {
|
||||||
|
log.Printf("Got invalid number of questions: %v", len(msg.Question))
|
||||||
|
err := s.respondWithServerFailure(&msg, addr, conn)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Couldn't respond to UDP packet: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
val, ok := s.cache.Get(&msg)
|
||||||
|
if ok && val != nil {
|
||||||
|
err = s.respond(val, addr, conn)
|
||||||
|
if err != nil {
|
||||||
|
if isConnClosed(err) {
|
||||||
|
// ignore this error, the connection was closed and that's ok
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("Couldn't respond to UDP packet: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
host := msg.Question[0].Name
|
||||||
|
res, err := s.dnsFilter.CheckHost(host)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("dnsfilter failed to check host '%s': %s", host, err)
|
||||||
|
err := s.respondWithServerFailure(&msg, addr, conn)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Couldn't respond to UDP packet: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else if res.IsFiltered {
|
||||||
|
log.Printf("Host %s is filtered, reason - '%s', matched rule: '%s'", host, res.Reason, res.Rule)
|
||||||
|
err := s.respondWithNXDomain(&msg, addr, conn)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Couldn't respond to UDP packet: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: replace with single-socket implementation
|
||||||
|
// TODO: replace 8.8.8.8:53 with configurable upstreams
|
||||||
|
// TODO: support DoH, DoT and TCP
|
||||||
|
upstream := s.chooseUpstream()
|
||||||
|
reply, err := upstream.Exchange(&msg)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("talking to upstream failed for host '%s': %s", host, err)
|
||||||
|
err := s.respondWithServerFailure(&msg, addr, conn)
|
||||||
|
if err != nil {
|
||||||
|
if isConnClosed(err) {
|
||||||
|
// ignore this error, the connection was closed and that's ok
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("Couldn't respond to UDP packet with server failure: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if reply == nil {
|
||||||
|
log.Printf("SHOULD NOT HAPPEN upstream returned empty message for host '%s'. Request is %v", host, msg.String())
|
||||||
|
err := s.respondWithServerFailure(&msg, addr, conn)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Couldn't respond to UDP packet with should not happen: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.respond(reply, addr, conn)
|
||||||
|
if err != nil {
|
||||||
|
if isConnClosed(err) {
|
||||||
|
// ignore this error, the connection was closed and that's ok
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("Couldn't respond to UDP packet: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.cache.Set(reply)
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// packet sending functions
|
||||||
|
//
|
||||||
|
|
||||||
|
func (s *Server) respond(resp *dns.Msg, addr net.Addr, conn *net.UDPConn) error {
|
||||||
|
// log.Printf("Replying to %s with %s", addr, resp)
|
||||||
|
resp.Compress = true
|
||||||
|
bytes, err := resp.Pack()
|
||||||
|
if err != nil {
|
||||||
|
return errorx.Decorate(err, "Couldn't convert message into wire format")
|
||||||
|
}
|
||||||
|
n, err := conn.WriteTo(bytes, addr)
|
||||||
|
if n == 0 && isConnClosed(err) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if n != len(bytes) {
|
||||||
|
return fmt.Errorf("WriteTo() returned with %d != %d", n, len(bytes))
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return errorx.Decorate(err, "WriteTo() returned error")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) respondWithServerFailure(request *dns.Msg, addr net.Addr, conn *net.UDPConn) error {
|
||||||
|
resp := dns.Msg{}
|
||||||
|
resp.SetRcode(request, dns.RcodeServerFailure)
|
||||||
|
return s.respond(&resp, addr, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) respondWithNXDomain(request *dns.Msg, addr net.Addr, conn *net.UDPConn) error {
|
||||||
|
resp := dns.Msg{}
|
||||||
|
resp.SetRcode(request, dns.RcodeNameError)
|
||||||
|
resp.Ns = s.genSOA(request)
|
||||||
|
return s.respond(&resp, addr, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) genSOA(request *dns.Msg) []dns.RR {
|
||||||
|
zone := ""
|
||||||
|
if len(request.Question) > 0 {
|
||||||
|
zone = request.Question[0].Name
|
||||||
|
}
|
||||||
|
|
||||||
|
soa := dns.SOA{
|
||||||
|
// values copied from verisign's nonexistent .com domain
|
||||||
|
// their exact values are not important in our use case because they are used for domain transfers between primary/secondary DNS servers
|
||||||
|
Refresh: 1800,
|
||||||
|
Retry: 900,
|
||||||
|
Expire: 604800,
|
||||||
|
Minttl: 86400,
|
||||||
|
// copied from AdGuard DNS
|
||||||
|
Ns: "fake-for-negative-caching.adguard.com.",
|
||||||
|
Serial: 100500,
|
||||||
|
// rest is request-specific
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: zone,
|
||||||
|
Rrtype: dns.TypeSOA,
|
||||||
|
Ttl: s.BlockedTTL,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
},
|
||||||
|
Mbox: "hostmaster.", // zone will be appended later if it's not empty or "."
|
||||||
|
}
|
||||||
|
if soa.Hdr.Ttl == 0 {
|
||||||
|
soa.Hdr.Ttl = defaultValues.BlockedTTL
|
||||||
|
}
|
||||||
|
if len(zone) > 0 && zone[0] != '.' {
|
||||||
|
soa.Mbox += zone
|
||||||
|
}
|
||||||
|
return []dns.RR{&soa}
|
||||||
|
}
|
|
@ -0,0 +1,43 @@
|
||||||
|
package dnsforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"path"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func isConnClosed(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
nerr, ok := err.(*net.OpError)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(nerr.Err.Error(), "use of closed network connection") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------
|
||||||
|
// debug logging helpers
|
||||||
|
// ---------------------
|
||||||
|
func trace(format string, args ...interface{}) {
|
||||||
|
pc := make([]uintptr, 10) // at least 1 entry needed
|
||||||
|
runtime.Callers(2, pc)
|
||||||
|
f := runtime.FuncForPC(pc[0])
|
||||||
|
var buf strings.Builder
|
||||||
|
buf.WriteString(fmt.Sprintf("%s(): ", path.Base(f.Name())))
|
||||||
|
text := fmt.Sprintf(format, args...)
|
||||||
|
buf.WriteString(text)
|
||||||
|
if len(text) == 0 || text[len(text)-1] != '\n' {
|
||||||
|
buf.WriteRune('\n')
|
||||||
|
}
|
||||||
|
fmt.Fprint(os.Stderr, buf.String())
|
||||||
|
}
|
|
@ -0,0 +1 @@
|
||||||
|
/standalone
|
|
@ -0,0 +1,51 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
_ "net/http/pprof"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"runtime"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||||
|
)
|
||||||
|
|
||||||
|
//
|
||||||
|
// main function
|
||||||
|
//
|
||||||
|
func main() {
|
||||||
|
go func() {
|
||||||
|
log.Println(http.ListenAndServe("localhost:6060", nil))
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
for range time.Tick(time.Second) {
|
||||||
|
log.Printf("goroutines = %d", runtime.NumGoroutine())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
s := dnsforward.Server{}
|
||||||
|
err := s.Start(nil)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
err = s.Stop()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
err = s.Start(&dnsforward.ServerConfig{UDPListenAddr: &net.UDPAddr{Port: 53535}})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
err = s.Reconfigure(dnsforward.ServerConfig{UDPListenAddr: &net.UDPAddr{Port: 53, IP: net.ParseIP("0.0.0.0")}})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
log.Printf("Now serving DNS")
|
||||||
|
signal_channel := make(chan os.Signal)
|
||||||
|
signal.Notify(signal_channel, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
<-signal_channel
|
||||||
|
}
|
|
@ -0,0 +1,187 @@
|
||||||
|
package dnsforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"log"
|
||||||
|
"math/rand"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/joomcode/errorx"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
const defaultTimeout = time.Second * 10
|
||||||
|
|
||||||
|
type Upstream interface {
|
||||||
|
Exchange(m *dns.Msg) (*dns.Msg, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// plain DNS
|
||||||
|
//
|
||||||
|
type plainDNS struct {
|
||||||
|
Address string
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultUDPClient = dns.Client{
|
||||||
|
Timeout: defaultTimeout,
|
||||||
|
UDPSize: dns.MaxMsgSize,
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultTCPClient = dns.Client{
|
||||||
|
Net: "tcp",
|
||||||
|
UDPSize: dns.MaxMsgSize,
|
||||||
|
Timeout: defaultTimeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *plainDNS) Exchange(m *dns.Msg) (*dns.Msg, error) {
|
||||||
|
reply, _, err := defaultUDPClient.Exchange(m, p.Address)
|
||||||
|
if err != nil && reply != nil && reply.Truncated {
|
||||||
|
log.Printf("Truncated message was received, retrying over TCP, question: %s", m.Question[0].String())
|
||||||
|
reply, _, err = defaultTCPClient.Exchange(m, p.Address)
|
||||||
|
}
|
||||||
|
return reply, err
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// DNS-over-TLS
|
||||||
|
//
|
||||||
|
type dnsOverTLS struct {
|
||||||
|
Address string
|
||||||
|
pool *TLSPool
|
||||||
|
|
||||||
|
sync.RWMutex // protects pool
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultTLSClient = dns.Client{
|
||||||
|
Net: "tcp-tls",
|
||||||
|
Timeout: defaultTimeout,
|
||||||
|
UDPSize: dns.MaxMsgSize,
|
||||||
|
TLSConfig: &tls.Config{},
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) {
|
||||||
|
var pool *TLSPool
|
||||||
|
p.RLock()
|
||||||
|
pool = p.pool
|
||||||
|
p.RUnlock()
|
||||||
|
if pool == nil {
|
||||||
|
p.Lock()
|
||||||
|
// lazy initialize it
|
||||||
|
p.pool = &TLSPool{Address: p.Address}
|
||||||
|
p.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
p.RLock()
|
||||||
|
poolConn, err := p.pool.Get()
|
||||||
|
p.RUnlock()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errorx.Decorate(err, "Failed to get a connection from TLSPool to %s", p.Address)
|
||||||
|
}
|
||||||
|
c := dns.Conn{Conn: poolConn}
|
||||||
|
err = c.WriteMsg(m)
|
||||||
|
if err != nil {
|
||||||
|
poolConn.Close()
|
||||||
|
return nil, errorx.Decorate(err, "Failed to send a request to %s", p.Address)
|
||||||
|
}
|
||||||
|
|
||||||
|
reply, err := c.ReadMsg()
|
||||||
|
if err != nil {
|
||||||
|
poolConn.Close()
|
||||||
|
return nil, errorx.Decorate(err, "Failed to read a request from %s", p.Address)
|
||||||
|
}
|
||||||
|
p.RLock()
|
||||||
|
p.pool.Put(poolConn)
|
||||||
|
p.RUnlock()
|
||||||
|
return reply, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// DNS-over-https
|
||||||
|
//
|
||||||
|
type dnsOverHTTPS struct {
|
||||||
|
Address string
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultHTTPSTransport = http.Transport{}
|
||||||
|
|
||||||
|
var defaultHTTPSClient = http.Client{
|
||||||
|
Transport: &defaultHTTPSTransport,
|
||||||
|
Timeout: defaultTimeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (*dns.Msg, error) {
|
||||||
|
buf, err := m.Pack()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errorx.Decorate(err, "Couldn't pack request msg")
|
||||||
|
}
|
||||||
|
bb := bytes.NewBuffer(buf)
|
||||||
|
resp, err := http.Post(p.Address, "application/dns-message", bb)
|
||||||
|
if resp != nil && resp.Body != nil {
|
||||||
|
defer resp.Body.Close()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, errorx.Decorate(err, "Couldn't do a POST request to '%s'", p.Address)
|
||||||
|
}
|
||||||
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errorx.Decorate(err, "Couldn't read body contents for '%s'", p.Address)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("Got an unexpected HTTP status code %d from '%s'", resp.StatusCode, p.Address)
|
||||||
|
}
|
||||||
|
if len(body) == 0 {
|
||||||
|
return nil, fmt.Errorf("Got an unexpected empty body from '%s'", p.Address)
|
||||||
|
}
|
||||||
|
response := dns.Msg{}
|
||||||
|
err = response.Unpack(body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errorx.Decorate(err, "Couldn't unpack DNS response from '%s': body is %s", p.Address, string(body))
|
||||||
|
}
|
||||||
|
return &response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) chooseUpstream() Upstream {
|
||||||
|
upstreams := s.Upstreams
|
||||||
|
if upstreams == nil {
|
||||||
|
upstreams = defaultValues.Upstreams
|
||||||
|
}
|
||||||
|
if len(upstreams) == 0 {
|
||||||
|
panic("SHOULD NOT HAPPEN: no default upstreams specified")
|
||||||
|
}
|
||||||
|
if len(upstreams) == 1 {
|
||||||
|
return upstreams[0]
|
||||||
|
}
|
||||||
|
n := rand.Intn(len(upstreams))
|
||||||
|
upstream := upstreams[n]
|
||||||
|
return upstream
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetUpstream(address string) (Upstream, error) {
|
||||||
|
if strings.Contains(address, "://") {
|
||||||
|
url, err := url.Parse(address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errorx.Decorate(err, "Failed to parse %s", address)
|
||||||
|
}
|
||||||
|
switch url.Scheme {
|
||||||
|
case "dns":
|
||||||
|
return &plainDNS{Address: address}, nil
|
||||||
|
case "tls":
|
||||||
|
return &dnsOverTLS{Address: address}, nil
|
||||||
|
case "https":
|
||||||
|
return &dnsOverHTTPS{Address: address}, nil
|
||||||
|
default:
|
||||||
|
return &plainDNS{Address: address}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// we don't have scheme in the url, so it's just a plain DNS host:port
|
||||||
|
return &plainDNS{Address: address}, nil
|
||||||
|
}
|
|
@ -0,0 +1,98 @@
|
||||||
|
package dnsforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/joomcode/errorx"
|
||||||
|
)
|
||||||
|
|
||||||
|
// upstream TLS pool.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
// pool := TLSPool{Address: "tls://1.1.1.1:853"}
|
||||||
|
// netConn, err := pool.Get()
|
||||||
|
// if err != nil {panic(err)}
|
||||||
|
// c := dns.Conn{Conn: netConn}
|
||||||
|
// q := dns.Msg{}
|
||||||
|
// q.SetQuestion("google.com.", dns.TypeA)
|
||||||
|
// log.Println(q)
|
||||||
|
// err = c.WriteMsg(&q)
|
||||||
|
// if err != nil {panic(err)}
|
||||||
|
// r, err := c.ReadMsg()
|
||||||
|
// if err != nil {panic(err)}
|
||||||
|
// log.Println(r)
|
||||||
|
// pool.Put(c.Conn)
|
||||||
|
type TLSPool struct {
|
||||||
|
Address string
|
||||||
|
parsedAddress *url.URL
|
||||||
|
parsedAddressMutex sync.RWMutex
|
||||||
|
|
||||||
|
conns []net.Conn
|
||||||
|
sync.Mutex // protects conns
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *TLSPool) getHost() (string, error) {
|
||||||
|
n.parsedAddressMutex.RLock()
|
||||||
|
if n.parsedAddress != nil {
|
||||||
|
n.parsedAddressMutex.RUnlock()
|
||||||
|
return n.parsedAddress.Host, nil
|
||||||
|
}
|
||||||
|
n.parsedAddressMutex.RUnlock()
|
||||||
|
|
||||||
|
n.parsedAddressMutex.Lock()
|
||||||
|
defer n.parsedAddressMutex.Unlock()
|
||||||
|
url, err := url.Parse(n.Address)
|
||||||
|
if err != nil {
|
||||||
|
return "", errorx.Decorate(err, "Failed to parse %s", n.Address)
|
||||||
|
}
|
||||||
|
if url.Scheme != "tls" {
|
||||||
|
return "", fmt.Errorf("TLSPool only supports TLS")
|
||||||
|
}
|
||||||
|
n.parsedAddress = url
|
||||||
|
return n.parsedAddress.Host, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *TLSPool) Get() (net.Conn, error) {
|
||||||
|
host, err := n.getHost()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// get the connection from the slice inside the lock
|
||||||
|
var c net.Conn
|
||||||
|
n.Lock()
|
||||||
|
num := len(n.conns)
|
||||||
|
if num > 0 {
|
||||||
|
last := num - 1
|
||||||
|
c = n.conns[last]
|
||||||
|
n.conns = n.conns[:last]
|
||||||
|
}
|
||||||
|
n.Unlock()
|
||||||
|
|
||||||
|
// if we got connection from the slice, return it
|
||||||
|
if c != nil {
|
||||||
|
// log.Printf("Returning existing connection to %s", host)
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// we'll need a new connection, dial now
|
||||||
|
// log.Printf("Dialing to %s", host)
|
||||||
|
conn, err := tls.Dial("tcp", host, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errorx.Decorate(err, "Failed to connect to %s", host)
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *TLSPool) Put(c net.Conn) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
n.Lock()
|
||||||
|
n.conns = append(n.conns, c)
|
||||||
|
n.Unlock()
|
||||||
|
}
|
Loading…
Reference in New Issue