From 5a548be16c1069e52b3343f04be8bce8ee607fa9 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 28 Nov 2018 15:40:56 +0300 Subject: [PATCH 01/58] Add dns forwarding server library --- dnsforward/cache.go | 225 ++++++++++++++ dnsforward/cache_test.go | 144 +++++++++ dnsforward/dnsforward.go | 467 ++++++++++++++++++++++++++++ dnsforward/helpers.go | 43 +++ dnsforward/standalone/.gitignore | 1 + dnsforward/standalone/standalone.go | 51 +++ dnsforward/upstream.go | 187 +++++++++++ dnsforward/upstream_pool.go | 98 ++++++ 8 files changed, 1216 insertions(+) create mode 100644 dnsforward/cache.go create mode 100644 dnsforward/cache_test.go create mode 100644 dnsforward/dnsforward.go create mode 100644 dnsforward/helpers.go create mode 100644 dnsforward/standalone/.gitignore create mode 100644 dnsforward/standalone/standalone.go create mode 100644 dnsforward/upstream.go create mode 100644 dnsforward/upstream_pool.go diff --git a/dnsforward/cache.go b/dnsforward/cache.go new file mode 100644 index 00000000..568f284c --- /dev/null +++ b/dnsforward/cache.go @@ -0,0 +1,225 @@ +package dnsforward + +import ( + "encoding/binary" + "log" + "math" + "strings" + "sync" + "time" + + "github.com/miekg/dns" +) + +type item struct { + m *dns.Msg + when time.Time +} + +type cache struct { + items map[string]item + + sync.RWMutex +} + +func (c *cache) Get(request *dns.Msg) (*dns.Msg, bool) { + if request == nil { + return nil, false + } + ok, key := key(request) + if !ok { + log.Printf("Get(): key returned !ok") + return nil, false + } + + c.RLock() + item, ok := c.items[key] + c.RUnlock() + if !ok { + return nil, false + } + // get item's TTL + ttl := findLowestTTL(item.m) + // zero TTL? delete and don't serve it + if ttl == 0 { + c.Lock() + delete(c.items, key) + c.Unlock() + return nil, false + } + // too much time has passed? delete and don't serve it + if time.Since(item.when) >= time.Duration(ttl)*time.Second { + c.Lock() + delete(c.items, key) + c.Unlock() + return nil, false + } + response := item.fromItem(request) + return response, true +} + +func (c *cache) Set(m *dns.Msg) { + if m == nil { + return // no-op + } + if !isRequestCacheable(m) { + return + } + if !isResponseCacheable(m) { + return + } + ok, key := key(m) + if !ok { + return + } + + i := toItem(m) + + c.Lock() + if c.items == nil { + c.items = map[string]item{} + } + c.items[key] = i + c.Unlock() +} + +// check only request fields +func isRequestCacheable(m *dns.Msg) bool { + // truncated messages aren't valid + if m.Truncated { + log.Printf("Refusing to cache truncated message") + return false + } + + // if has wrong number of questions, also don't cache + if len(m.Question) != 1 { + log.Printf("Refusing to cache message with wrong number of questions") + return false + } + + // only OK or NXdomain replies are cached + switch m.Rcode { + case dns.RcodeSuccess: + case dns.RcodeNameError: // that's an NXDomain + case dns.RcodeServerFailure: + return false // quietly refuse, don't log + default: + log.Printf("%s: Refusing to cache message with rcode: %s", m.Question[0].Name, dns.RcodeToString[m.Rcode]) + return false + } + + return true +} + +func isResponseCacheable(m *dns.Msg) bool { + ttl := findLowestTTL(m) + if ttl == 0 { + return false + } + + return true +} + +func findLowestTTL(m *dns.Msg) uint32 { + var ttl uint32 = math.MaxUint32 + found := false + + if m.Answer != nil { + for _, r := range m.Answer { + if r.Header().Ttl < ttl { + ttl = r.Header().Ttl + found = true + } + } + } + + if m.Ns != nil { + for _, r := range m.Ns { + if r.Header().Ttl < ttl { + ttl = r.Header().Ttl + found = true + } + } + } + + if m.Extra != nil { + for _, r := range m.Extra { + if r.Header().Rrtype == dns.TypeOPT { + continue // OPT records use TTL for other purposes + } + if r.Header().Ttl < ttl { + ttl = r.Header().Ttl + found = true + } + } + } + + if found == false { + return 0 + } + + return ttl +} + +// key is binary little endian in sequence: +// uint16(qtype) then uint16(qclass) then name +func key(m *dns.Msg) (bool, string) { + if len(m.Question) != 1 { + log.Printf("got msg with len(m.Question) != 1: %d", len(m.Question)) + return false, "" + } + + bb := strings.Builder{} + b := make([]byte, 2) + binary.LittleEndian.PutUint16(b, m.Question[0].Qtype) + bb.Write(b) + binary.LittleEndian.PutUint16(b, m.Question[0].Qclass) + bb.Write(b) + name := strings.ToLower(m.Question[0].Name) + bb.WriteString(name) + return true, bb.String() +} + +func toItem(m *dns.Msg) item { + return item{ + m: m, + when: time.Now(), + } +} + +func (i *item) fromItem(request *dns.Msg) *dns.Msg { + response := &dns.Msg{} + response.SetReply(request) + + response.Authoritative = false + response.AuthenticatedData = i.m.AuthenticatedData + response.RecursionAvailable = i.m.RecursionAvailable + response.Rcode = i.m.Rcode + + ttl := findLowestTTL(i.m) + timeleft := math.Round(float64(ttl) - time.Since(i.when).Seconds()) + var newttl uint32 + if timeleft > 0 { + newttl = uint32(timeleft) + } + for _, r := range i.m.Answer { + answer := dns.Copy(r) + answer.Header().Ttl = newttl + response.Answer = append(response.Answer, answer) + } + for _, r := range i.m.Ns { + ns := dns.Copy(r) + ns.Header().Ttl = newttl + response.Ns = append(response.Ns, ns) + } + for _, r := range i.m.Extra { + // don't return OPT records as these are hop-by-hop + if r.Header().Rrtype == dns.TypeOPT { + continue + } + extra := dns.Copy(r) + extra.Header().Ttl = newttl + response.Extra = append(response.Extra, extra) + } + return response +} diff --git a/dnsforward/cache_test.go b/dnsforward/cache_test.go new file mode 100644 index 00000000..c9f4577e --- /dev/null +++ b/dnsforward/cache_test.go @@ -0,0 +1,144 @@ +package dnsforward + +import ( + "strings" + "testing" + + "github.com/go-test/deep" + "github.com/miekg/dns" +) + +func RR(rr string) dns.RR { + r, err := dns.NewRR(rr) + if err != nil { + panic(err) + } + return r +} + +// deepEqual is same as deep.Equal, except: +// * ignores Id when comparing +// * question names are not case sensetive +func deepEqualMsg(left *dns.Msg, right *dns.Msg) []string { + temp := *left + temp.Id = right.Id + for i := range left.Question { + left.Question[i].Name = strings.ToLower(left.Question[i].Name) + } + for i := range right.Question { + right.Question[i].Name = strings.ToLower(right.Question[i].Name) + } + return deep.Equal(&temp, right) +} + +func TestCacheSanity(t *testing.T) { + cache := cache{} + request := dns.Msg{} + request.SetQuestion("google.com.", dns.TypeA) + _, ok := cache.Get(&request) + if ok { + t.Fatal("empty cache replied with positive response") + } +} + +type tests struct { + cache []testEntry + cases []testCase +} + +type testEntry struct { + q string + t uint16 + a []dns.RR +} + +type testCase struct { + q string + t uint16 + a []dns.RR + ok bool +} + +func TestCache(t *testing.T) { + tests := tests{ + cache: []testEntry{ + {q: "google.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}}, + }, + cases: []testCase{ + {q: "google.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true}, + {q: "google.com.", t: dns.TypeMX, ok: false}, + }, + } + runTests(t, tests) +} + +func TestCacheMixedCase(t *testing.T) { + tests := tests{ + cache: []testEntry{ + {q: "gOOgle.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}}, + }, + cases: []testCase{ + {q: "gOOgle.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true}, + {q: "google.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true}, + {q: "GOOGLE.COM.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true}, + {q: "gOOgle.com.", t: dns.TypeMX, ok: false}, + {q: "google.com.", t: dns.TypeMX, ok: false}, + {q: "GOOGLE.COM.", t: dns.TypeMX, ok: false}, + }, + } + runTests(t, tests) +} + +func TestZeroTTL(t *testing.T) { + tests := tests{ + cache: []testEntry{ + {q: "gOOgle.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 0 IN A 8.8.8.8")}}, + }, + cases: []testCase{ + {q: "google.com.", t: dns.TypeA, ok: false}, + {q: "google.com.", t: dns.TypeA, ok: false}, + {q: "google.com.", t: dns.TypeA, ok: false}, + {q: "google.com.", t: dns.TypeMX, ok: false}, + {q: "google.com.", t: dns.TypeMX, ok: false}, + {q: "google.com.", t: dns.TypeMX, ok: false}, + }, + } + runTests(t, tests) +} + +func runTests(t *testing.T, tests tests) { + t.Helper() + cache := cache{} + for _, tc := range tests.cache { + reply := dns.Msg{} + reply.SetQuestion(tc.q, tc.t) + reply.Response = true + reply.Answer = tc.a + cache.Set(&reply) + } + for _, tc := range tests.cases { + request := dns.Msg{} + request.SetQuestion(tc.q, tc.t) + val, ok := cache.Get(&request) + if diff := deep.Equal(ok, tc.ok); diff != nil { + t.Error(diff) + } + if tc.a != nil { + if ok == false { + continue + } + reply := dns.Msg{} + reply.SetQuestion(tc.q, tc.t) + reply.Response = true + reply.Answer = tc.a + cache.Set(&reply) + if diff := deepEqualMsg(val, &reply); diff != nil { + t.Error(diff) + } else { + if diff := deep.Equal(val, reply); diff == nil { + t.Error("different message ID were not caught") + } + } + } + } +} diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go new file mode 100644 index 00000000..546c4eae --- /dev/null +++ b/dnsforward/dnsforward.go @@ -0,0 +1,467 @@ +package dnsforward + +import ( + "fmt" + "log" + "net" + "reflect" + "sync" + + "github.com/AdguardTeam/AdGuardHome/dnsfilter" + "github.com/joomcode/errorx" + "github.com/miekg/dns" +) + +// Server is the main way to start a DNS server +// Example: +// s := dnsforward.Server{} +// err := s.Start(nil) // will start a DNS server listening on default port 53, in a goroutine +// err := s.Reconfigure(ServerConfig{UDPListenAddr: &net.UDPAddr{Port: 53535}}) // will reconfigure running DNS server to listen on UDP port 53535 +// err := s.Stop() // will stop listening on port 53535 and cancel all goroutines +// err := s.Start(nil) // will start listening again, on port 53535, in a goroutine +// +// The zero Server is empty and ready for use. +type Server struct { + udpListen *net.UDPConn + + dnsFilter *dnsfilter.Dnsfilter + + cache cache + + sync.RWMutex + ServerConfig +} + +// The zero ServerConfig is empty and ready for use. +type ServerConfig struct { + UDPListenAddr *net.UDPAddr // if nil, then default is is used (port 53 on *) + BlockedTTL uint32 // if 0, then default is used (3600) + Upstreams []Upstream + Filters []Filter +} + +var defaultValues = ServerConfig{ + UDPListenAddr: &net.UDPAddr{Port: 53}, + BlockedTTL: 3600, + Upstreams: []Upstream{ + //// dns over HTTPS + // &dnsOverHTTPS{Address: "https://1.1.1.1/dns-query"}, + // &dnsOverHTTPS{Address: "https://dns.google.com/experimental"}, + // &dnsOverHTTPS{Address: "https://doh.cleanbrowsing.org/doh/security-filter/"}, + // &dnsOverHTTPS{Address: "https://dns10.quad9.net/dns-query"}, + // &dnsOverHTTPS{Address: "https://doh.powerdns.org"}, + // &dnsOverHTTPS{Address: "https://doh.securedns.eu/dns-query"}, + + //// dns over TLS + // &dnsOverTLS{Address: "tls://8.8.8.8:853"}, + // &dnsOverTLS{Address: "tls://8.8.4.4:853"}, + &dnsOverTLS{Address: "tls://1.1.1.1:853"}, + &dnsOverTLS{Address: "tls://1.0.0.1:853"}, + + //// plainDNS + // &plainDNS{Address: "8.8.8.8:53"}, + // &plainDNS{Address: "8.8.4.4:53"}, + // &plainDNS{Address: "1.1.1.1:53"}, + // &plainDNS{Address: "1.0.0.1:53"}, + }, +} + +type Filter struct { + ID int64 + Rules []string +} + +// +// packet loop +// +func (s *Server) packetLoop() { + log.Printf("Entering packet handle loop") + b := make([]byte, dns.MaxMsgSize) + for { + s.RLock() + conn := s.udpListen + s.RUnlock() + if conn == nil { + log.Printf("udp socket has disappeared, exiting loop") + break + } + n, addr, err := conn.ReadFrom(b) + // documentation says to handle the packet even if err occurs, so do that first + if n > 0 { + // make a copy of all bytes because ReadFrom() will overwrite contents of b on next call + // we need the contents to survive the call because we're handling them in goroutine + p := make([]byte, n) + copy(p, b) + go s.handlePacket(p, addr, conn) // ignore errors + } + if err != nil { + if isConnClosed(err) { + log.Printf("ReadFrom() returned because we're reading from a closed connection, exiting loop") + break + } + log.Printf("Got error when reading from udp listen: %s", err) + } + } +} + +// +// Control functions +// + +func (s *Server) Start(config *ServerConfig) error { + s.Lock() + defer s.Unlock() + if config != nil { + s.ServerConfig = *config + } + // TODO: handle being called Start() second time after Stop() + if s.udpListen == nil { + log.Printf("Creating UDP socket") + var err error + addr := s.UDPListenAddr + if addr == nil { + addr = defaultValues.UDPListenAddr + } + s.udpListen, err = net.ListenUDP("udp", addr) + if err != nil { + return errorx.Decorate(err, "Couldn't listen to UDP socket") + } + log.Println(s.udpListen.LocalAddr(), s.UDPListenAddr) + } + + if s.dnsFilter == nil { + log.Printf("Creating dnsfilter") + s.dnsFilter = dnsfilter.New() + } + + go s.packetLoop() + + return nil +} + +func (s *Server) Stop() error { + s.Lock() + defer s.Unlock() + if s.udpListen != nil { + err := s.udpListen.Close() + if err != nil { + return errorx.Decorate(err, "Couldn't close UDP listening socket") + } + s.udpListen = nil + } + return nil +} + +// +// Server reconfigure +// + +func (s *Server) reconfigureListenAddr(new ServerConfig) error { + oldAddr := s.UDPListenAddr + if oldAddr == nil { + oldAddr = defaultValues.UDPListenAddr + } + newAddr := new.UDPListenAddr + if newAddr == nil { + newAddr = defaultValues.UDPListenAddr + } + if newAddr.Port == 0 { + return errorx.IllegalArgument.New("new port cannot be 0") + } + if reflect.DeepEqual(oldAddr, newAddr) { + // do nothing, the addresses are exactly the same + log.Printf("Not going to rebind because addresses are same: %v -> %v", oldAddr, newAddr) + return nil + } + + // rebind, using a strategy: + // * if ports are different, bind new first, then close old + // * if ports are same, close old first, then bind new + var newListen *net.UDPConn + var err error + if oldAddr.Port != newAddr.Port { + log.Printf("Rebinding -- ports are different so bind first then close") + newListen, err = net.ListenUDP("udp", newAddr) + if err != nil { + return errorx.Decorate(err, "Couldn't bind to %v", newAddr) + } + if s.udpListen != nil { + err := s.udpListen.Close() + if err != nil { + return errorx.Decorate(err, "Couldn't close UDP listening socket") + } + } + } else { + log.Printf("Rebinding -- ports are same so close first then bind") + if s.udpListen != nil { + err := s.udpListen.Close() + if err != nil { + return errorx.Decorate(err, "Couldn't close UDP listening socket") + } + } + newListen, err = net.ListenUDP("udp", newAddr) + if err != nil { + return errorx.Decorate(err, "Couldn't bind to %v", newAddr) + } + } + s.Lock() + s.udpListen = newListen + s.UDPListenAddr = new.UDPListenAddr + s.Unlock() + log.Println(s.udpListen.LocalAddr(), s.UDPListenAddr) + + go s.packetLoop() // the old one has quit, use new one + + return nil +} + +func (s *Server) reconfigureBlockedTTL(new ServerConfig) { + newVal := new.BlockedTTL + if newVal == 0 { + newVal = defaultValues.BlockedTTL + } + oldVal := s.BlockedTTL + if oldVal == 0 { + oldVal = defaultValues.BlockedTTL + } + if newVal != oldVal { + s.BlockedTTL = new.BlockedTTL + } +} + +func (s *Server) reconfigureUpstreams(new ServerConfig) { + newVal := new.Upstreams + if len(newVal) == 0 { + newVal = defaultValues.Upstreams + } + oldVal := s.Upstreams + if len(oldVal) == 0 { + oldVal = defaultValues.Upstreams + } + if reflect.DeepEqual(newVal, oldVal) { + // they're exactly the same, do nothing + return + } + s.Upstreams = new.Upstreams +} + +func (s *Server) reconfigureFilters(new ServerConfig) { + newFilters := new.Filters + if len(newFilters) == 0 { + newFilters = defaultValues.Filters + } + oldFilters := s.Filters + if len(oldFilters) == 0 { + oldFilters = defaultValues.Filters + } + if reflect.DeepEqual(newFilters, oldFilters) { + // they're exactly the same, do nothing + return + } + + dnsFilter := dnsfilter.New() + for _, f := range newFilters { + for _, rule := range f.Rules { + err := dnsFilter.AddRule(rule, f.ID) + if err == dnsfilter.ErrAlreadyExists || err == dnsfilter.ErrInvalidSyntax { + continue + } + if err != nil { + log.Printf("Cannot add rule %s: %s", rule, err) + // Just ignore invalid rules + continue + } + } + } + + s.Lock() + oldDnsFilter := s.dnsFilter + s.dnsFilter = dnsFilter + s.Unlock() + + oldDnsFilter.Destroy() +} + +func (s *Server) Reconfigure(new ServerConfig) error { + s.reconfigureBlockedTTL(new) + s.reconfigureUpstreams(new) + s.reconfigureFilters(new) + + err := s.reconfigureListenAddr(new) + if err != nil { + return errorx.Decorate(err, "Couldn't reconfigure to new listening address %+v", new.UDPListenAddr) + } + return nil +} + +// +// packet handling functions +// + +func (s *Server) handlePacket(p []byte, addr net.Addr, conn *net.UDPConn) { + // log.Printf("Got packet %d bytes from %s: %v", len(p), addr, p) + msg := dns.Msg{} + err := msg.Unpack(p) + if err != nil { + log.Printf("failed to unpack DNS packet: %s", err) + return + } + + // + // DNS packet byte format is valid + // + // any errors below here require a response to client + // log.Printf("Unpacked: %v", msg.String()) + if len(msg.Question) != 1 { + log.Printf("Got invalid number of questions: %v", len(msg.Question)) + err := s.respondWithServerFailure(&msg, addr, conn) + if err != nil { + log.Printf("Couldn't respond to UDP packet: %s", err) + return + } + } + + { + val, ok := s.cache.Get(&msg) + if ok && val != nil { + err = s.respond(val, addr, conn) + if err != nil { + if isConnClosed(err) { + // ignore this error, the connection was closed and that's ok + return + } + log.Printf("Couldn't respond to UDP packet: %s", err) + return + } + return + } + } + host := msg.Question[0].Name + res, err := s.dnsFilter.CheckHost(host) + if err != nil { + log.Printf("dnsfilter failed to check host '%s': %s", host, err) + err := s.respondWithServerFailure(&msg, addr, conn) + if err != nil { + log.Printf("Couldn't respond to UDP packet: %s", err) + return + } + } else if res.IsFiltered { + log.Printf("Host %s is filtered, reason - '%s', matched rule: '%s'", host, res.Reason, res.Rule) + err := s.respondWithNXDomain(&msg, addr, conn) + if err != nil { + log.Printf("Couldn't respond to UDP packet: %s", err) + return + } + } + + // TODO: replace with single-socket implementation + // TODO: replace 8.8.8.8:53 with configurable upstreams + // TODO: support DoH, DoT and TCP + upstream := s.chooseUpstream() + reply, err := upstream.Exchange(&msg) + if err != nil { + log.Printf("talking to upstream failed for host '%s': %s", host, err) + err := s.respondWithServerFailure(&msg, addr, conn) + if err != nil { + if isConnClosed(err) { + // ignore this error, the connection was closed and that's ok + return + } + log.Printf("Couldn't respond to UDP packet with server failure: %s", err) + return + } + return + } + if reply == nil { + log.Printf("SHOULD NOT HAPPEN upstream returned empty message for host '%s'. Request is %v", host, msg.String()) + err := s.respondWithServerFailure(&msg, addr, conn) + if err != nil { + log.Printf("Couldn't respond to UDP packet with should not happen: %s", err) + return + } + return + } + + err = s.respond(reply, addr, conn) + if err != nil { + if isConnClosed(err) { + // ignore this error, the connection was closed and that's ok + return + } + log.Printf("Couldn't respond to UDP packet: %s", err) + return + } + + s.cache.Set(reply) +} + +// +// packet sending functions +// + +func (s *Server) respond(resp *dns.Msg, addr net.Addr, conn *net.UDPConn) error { + // log.Printf("Replying to %s with %s", addr, resp) + resp.Compress = true + bytes, err := resp.Pack() + if err != nil { + return errorx.Decorate(err, "Couldn't convert message into wire format") + } + n, err := conn.WriteTo(bytes, addr) + if n == 0 && isConnClosed(err) { + return err + } + if n != len(bytes) { + return fmt.Errorf("WriteTo() returned with %d != %d", n, len(bytes)) + } + if err != nil { + return errorx.Decorate(err, "WriteTo() returned error") + } + return nil +} + +func (s *Server) respondWithServerFailure(request *dns.Msg, addr net.Addr, conn *net.UDPConn) error { + resp := dns.Msg{} + resp.SetRcode(request, dns.RcodeServerFailure) + return s.respond(&resp, addr, conn) +} + +func (s *Server) respondWithNXDomain(request *dns.Msg, addr net.Addr, conn *net.UDPConn) error { + resp := dns.Msg{} + resp.SetRcode(request, dns.RcodeNameError) + resp.Ns = s.genSOA(request) + return s.respond(&resp, addr, conn) +} + +func (s *Server) genSOA(request *dns.Msg) []dns.RR { + zone := "" + if len(request.Question) > 0 { + zone = request.Question[0].Name + } + + soa := dns.SOA{ + // values copied from verisign's nonexistent .com domain + // their exact values are not important in our use case because they are used for domain transfers between primary/secondary DNS servers + Refresh: 1800, + Retry: 900, + Expire: 604800, + Minttl: 86400, + // copied from AdGuard DNS + Ns: "fake-for-negative-caching.adguard.com.", + Serial: 100500, + // rest is request-specific + Hdr: dns.RR_Header{ + Name: zone, + Rrtype: dns.TypeSOA, + Ttl: s.BlockedTTL, + Class: dns.ClassINET, + }, + Mbox: "hostmaster.", // zone will be appended later if it's not empty or "." + } + if soa.Hdr.Ttl == 0 { + soa.Hdr.Ttl = defaultValues.BlockedTTL + } + if len(zone) > 0 && zone[0] != '.' { + soa.Mbox += zone + } + return []dns.RR{&soa} +} diff --git a/dnsforward/helpers.go b/dnsforward/helpers.go new file mode 100644 index 00000000..339023a0 --- /dev/null +++ b/dnsforward/helpers.go @@ -0,0 +1,43 @@ +package dnsforward + +import ( + "fmt" + "net" + "os" + "path" + "runtime" + "strings" +) + +func isConnClosed(err error) bool { + if err == nil { + return false + } + nerr, ok := err.(*net.OpError) + if !ok { + return false + } + + if strings.Contains(nerr.Err.Error(), "use of closed network connection") { + return true + } + + return false +} + +// --------------------- +// debug logging helpers +// --------------------- +func trace(format string, args ...interface{}) { + pc := make([]uintptr, 10) // at least 1 entry needed + runtime.Callers(2, pc) + f := runtime.FuncForPC(pc[0]) + var buf strings.Builder + buf.WriteString(fmt.Sprintf("%s(): ", path.Base(f.Name()))) + text := fmt.Sprintf(format, args...) + buf.WriteString(text) + if len(text) == 0 || text[len(text)-1] != '\n' { + buf.WriteRune('\n') + } + fmt.Fprint(os.Stderr, buf.String()) +} diff --git a/dnsforward/standalone/.gitignore b/dnsforward/standalone/.gitignore new file mode 100644 index 00000000..5f81988c --- /dev/null +++ b/dnsforward/standalone/.gitignore @@ -0,0 +1 @@ +/standalone \ No newline at end of file diff --git a/dnsforward/standalone/standalone.go b/dnsforward/standalone/standalone.go new file mode 100644 index 00000000..ae3e6d13 --- /dev/null +++ b/dnsforward/standalone/standalone.go @@ -0,0 +1,51 @@ +package main + +import ( + "log" + "net" + "net/http" + _ "net/http/pprof" + "os" + "os/signal" + "runtime" + "syscall" + "time" + + "github.com/AdguardTeam/AdGuardHome/dnsforward" +) + +// +// main function +// +func main() { + go func() { + log.Println(http.ListenAndServe("localhost:6060", nil)) + }() + go func() { + for range time.Tick(time.Second) { + log.Printf("goroutines = %d", runtime.NumGoroutine()) + } + }() + s := dnsforward.Server{} + err := s.Start(nil) + if err != nil { + panic(err) + } + time.Sleep(time.Second) + err = s.Stop() + if err != nil { + panic(err) + } + err = s.Start(&dnsforward.ServerConfig{UDPListenAddr: &net.UDPAddr{Port: 53535}}) + if err != nil { + panic(err) + } + err = s.Reconfigure(dnsforward.ServerConfig{UDPListenAddr: &net.UDPAddr{Port: 53, IP: net.ParseIP("0.0.0.0")}}) + if err != nil { + panic(err) + } + log.Printf("Now serving DNS") + signal_channel := make(chan os.Signal) + signal.Notify(signal_channel, syscall.SIGINT, syscall.SIGTERM) + <-signal_channel +} diff --git a/dnsforward/upstream.go b/dnsforward/upstream.go new file mode 100644 index 00000000..3e0a61c0 --- /dev/null +++ b/dnsforward/upstream.go @@ -0,0 +1,187 @@ +package dnsforward + +import ( + "bytes" + "crypto/tls" + "fmt" + "io/ioutil" + "log" + "math/rand" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/joomcode/errorx" + "github.com/miekg/dns" +) + +const defaultTimeout = time.Second * 10 + +type Upstream interface { + Exchange(m *dns.Msg) (*dns.Msg, error) +} + +// +// plain DNS +// +type plainDNS struct { + Address string +} + +var defaultUDPClient = dns.Client{ + Timeout: defaultTimeout, + UDPSize: dns.MaxMsgSize, +} + +var defaultTCPClient = dns.Client{ + Net: "tcp", + UDPSize: dns.MaxMsgSize, + Timeout: defaultTimeout, +} + +func (p *plainDNS) Exchange(m *dns.Msg) (*dns.Msg, error) { + reply, _, err := defaultUDPClient.Exchange(m, p.Address) + if err != nil && reply != nil && reply.Truncated { + log.Printf("Truncated message was received, retrying over TCP, question: %s", m.Question[0].String()) + reply, _, err = defaultTCPClient.Exchange(m, p.Address) + } + return reply, err +} + +// +// DNS-over-TLS +// +type dnsOverTLS struct { + Address string + pool *TLSPool + + sync.RWMutex // protects pool +} + +var defaultTLSClient = dns.Client{ + Net: "tcp-tls", + Timeout: defaultTimeout, + UDPSize: dns.MaxMsgSize, + TLSConfig: &tls.Config{}, +} + +func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) { + var pool *TLSPool + p.RLock() + pool = p.pool + p.RUnlock() + if pool == nil { + p.Lock() + // lazy initialize it + p.pool = &TLSPool{Address: p.Address} + p.Unlock() + } + + p.RLock() + poolConn, err := p.pool.Get() + p.RUnlock() + if err != nil { + return nil, errorx.Decorate(err, "Failed to get a connection from TLSPool to %s", p.Address) + } + c := dns.Conn{Conn: poolConn} + err = c.WriteMsg(m) + if err != nil { + poolConn.Close() + return nil, errorx.Decorate(err, "Failed to send a request to %s", p.Address) + } + + reply, err := c.ReadMsg() + if err != nil { + poolConn.Close() + return nil, errorx.Decorate(err, "Failed to read a request from %s", p.Address) + } + p.RLock() + p.pool.Put(poolConn) + p.RUnlock() + return reply, nil +} + +// +// DNS-over-https +// +type dnsOverHTTPS struct { + Address string +} + +var defaultHTTPSTransport = http.Transport{} + +var defaultHTTPSClient = http.Client{ + Transport: &defaultHTTPSTransport, + Timeout: defaultTimeout, +} + +func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (*dns.Msg, error) { + buf, err := m.Pack() + if err != nil { + return nil, errorx.Decorate(err, "Couldn't pack request msg") + } + bb := bytes.NewBuffer(buf) + resp, err := http.Post(p.Address, "application/dns-message", bb) + if resp != nil && resp.Body != nil { + defer resp.Body.Close() + } + if err != nil { + return nil, errorx.Decorate(err, "Couldn't do a POST request to '%s'", p.Address) + } + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, errorx.Decorate(err, "Couldn't read body contents for '%s'", p.Address) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Got an unexpected HTTP status code %d from '%s'", resp.StatusCode, p.Address) + } + if len(body) == 0 { + return nil, fmt.Errorf("Got an unexpected empty body from '%s'", p.Address) + } + response := dns.Msg{} + err = response.Unpack(body) + if err != nil { + return nil, errorx.Decorate(err, "Couldn't unpack DNS response from '%s': body is %s", p.Address, string(body)) + } + return &response, nil +} + +func (s *Server) chooseUpstream() Upstream { + upstreams := s.Upstreams + if upstreams == nil { + upstreams = defaultValues.Upstreams + } + if len(upstreams) == 0 { + panic("SHOULD NOT HAPPEN: no default upstreams specified") + } + if len(upstreams) == 1 { + return upstreams[0] + } + n := rand.Intn(len(upstreams)) + upstream := upstreams[n] + return upstream +} + +func GetUpstream(address string) (Upstream, error) { + if strings.Contains(address, "://") { + url, err := url.Parse(address) + if err != nil { + return nil, errorx.Decorate(err, "Failed to parse %s", address) + } + switch url.Scheme { + case "dns": + return &plainDNS{Address: address}, nil + case "tls": + return &dnsOverTLS{Address: address}, nil + case "https": + return &dnsOverHTTPS{Address: address}, nil + default: + return &plainDNS{Address: address}, nil + } + } + + // we don't have scheme in the url, so it's just a plain DNS host:port + return &plainDNS{Address: address}, nil +} diff --git a/dnsforward/upstream_pool.go b/dnsforward/upstream_pool.go new file mode 100644 index 00000000..9756d54f --- /dev/null +++ b/dnsforward/upstream_pool.go @@ -0,0 +1,98 @@ +package dnsforward + +import ( + "crypto/tls" + "fmt" + "net" + "net/url" + "sync" + + "github.com/joomcode/errorx" +) + +// upstream TLS pool. +// +// Example: +// pool := TLSPool{Address: "tls://1.1.1.1:853"} +// netConn, err := pool.Get() +// if err != nil {panic(err)} +// c := dns.Conn{Conn: netConn} +// q := dns.Msg{} +// q.SetQuestion("google.com.", dns.TypeA) +// log.Println(q) +// err = c.WriteMsg(&q) +// if err != nil {panic(err)} +// r, err := c.ReadMsg() +// if err != nil {panic(err)} +// log.Println(r) +// pool.Put(c.Conn) +type TLSPool struct { + Address string + parsedAddress *url.URL + parsedAddressMutex sync.RWMutex + + conns []net.Conn + sync.Mutex // protects conns +} + +func (n *TLSPool) getHost() (string, error) { + n.parsedAddressMutex.RLock() + if n.parsedAddress != nil { + n.parsedAddressMutex.RUnlock() + return n.parsedAddress.Host, nil + } + n.parsedAddressMutex.RUnlock() + + n.parsedAddressMutex.Lock() + defer n.parsedAddressMutex.Unlock() + url, err := url.Parse(n.Address) + if err != nil { + return "", errorx.Decorate(err, "Failed to parse %s", n.Address) + } + if url.Scheme != "tls" { + return "", fmt.Errorf("TLSPool only supports TLS") + } + n.parsedAddress = url + return n.parsedAddress.Host, nil +} + +func (n *TLSPool) Get() (net.Conn, error) { + host, err := n.getHost() + if err != nil { + return nil, err + } + + // get the connection from the slice inside the lock + var c net.Conn + n.Lock() + num := len(n.conns) + if num > 0 { + last := num - 1 + c = n.conns[last] + n.conns = n.conns[:last] + } + n.Unlock() + + // if we got connection from the slice, return it + if c != nil { + // log.Printf("Returning existing connection to %s", host) + return c, nil + } + + // we'll need a new connection, dial now + // log.Printf("Dialing to %s", host) + conn, err := tls.Dial("tcp", host, nil) + if err != nil { + return nil, errorx.Decorate(err, "Failed to connect to %s", host) + } + return conn, nil +} + +func (n *TLSPool) Put(c net.Conn) { + if c == nil { + return + } + n.Lock() + n.conns = append(n.conns, c) + n.Unlock() +} From ea1353422f557eaa371f7a430c588a4e7297ed14 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 28 Nov 2018 16:05:24 +0300 Subject: [PATCH 02/58] User rules -- hold them as a slice of strings, which is how dns forwarding server will expect them. --- app.go | 2 +- config.go | 19 +++++-------------- control.go | 22 +++++++++++----------- 3 files changed, 17 insertions(+), 26 deletions(-) diff --git a/app.go b/app.go index 85cb96ae..f958c7b4 100644 --- a/app.go +++ b/app.go @@ -149,7 +149,7 @@ func main() { log.Printf("Couldn't load filter %d contents due to %s", filter.ID, err) // clear LastUpdated so it gets fetched right away } - if len(filter.Contents) == 0 { + if len(filter.Rules) == 0 { filter.LastUpdated = time.Time{} } } diff --git a/config.go b/config.go index d141706b..59cff6bc 100644 --- a/config.go +++ b/config.go @@ -76,7 +76,7 @@ type filter struct { LastUpdated time.Time `json:"lastUpdated,omitempty" yaml:"last_updated,omitempty"` ID int64 `json:"id"` // auto-assigned when filter is added (see nextFilterID), json by default keeps ID uppercase but we need lowercase - Contents []byte `json:"-" yaml:"-"` // not in yaml or json + Rules []string `json:"-" yaml:"-"` // not in yaml or json } var defaultDNS = []string{"tls://1.1.1.1", "tls://1.0.0.1"} @@ -112,21 +112,12 @@ var config = configuration{ // Creates a helper object for working with the user rules func userFilter() filter { - // TODO: This should be calculated when UserRules are set - var contents []byte - for _, rule := range config.UserRules { - contents = append(contents, []byte(rule)...) - contents = append(contents, '\n') - } - - userFilter := filter{ + return filter{ // User filter always has constant ID=0 - ID: userFilterID, - Contents: contents, - Enabled: true, + ID: userFilterID, + Rules: config.UserRules, + Enabled: true, } - - return userFilter } // Loads configuration from the YAML file diff --git a/control.go b/control.go index 3e4dcabf..d2011cb0 100644 --- a/control.go +++ b/control.go @@ -1,7 +1,6 @@ package main import ( - "bytes" "encoding/json" "fmt" "io/ioutil" @@ -9,6 +8,7 @@ import ( "net/http" "os" "path/filepath" + "reflect" "regexp" "strconv" "strings" @@ -577,7 +577,7 @@ func refreshFiltersIfNeccessary(force bool) int { } // A helper function that parses filter contents and returns a number of rules and a filter name (if there's any) -func parseFilterContents(contents []byte) (int, string) { +func parseFilterContents(contents []byte) (int, string, []string) { lines := strings.Split(string(contents), "\n") rulesCount := 0 name := "" @@ -596,7 +596,7 @@ func parseFilterContents(contents []byte) (int, string) { } } - return rulesCount, name + return rulesCount, name, lines } // Checks for filters updates @@ -645,21 +645,21 @@ func (filter *filter) update(force bool) (bool, error) { } // Extract filter name and count number of rules - rulesCount, filterName := parseFilterContents(body) + rulesCount, filterName, rules := parseFilterContents(body) if filterName != "" { filter.Name = filterName } // Check if the filter has been really changed - if bytes.Equal(filter.Contents, body) { + if reflect.DeepEqual(filter.Rules, rules) { log.Printf("The filter %d text has not changed", filter.ID) return false, nil } log.Printf("Filter %d has been updated: %d bytes, %d rules", filter.ID, len(body), rulesCount) filter.RulesCount = rulesCount - filter.Contents = body + filter.Rules = rules return true, nil } @@ -668,8 +668,9 @@ func (filter *filter) update(force bool) (bool, error) { func (filter *filter) save() error { filterFilePath := filter.Path() log.Printf("Saving filter %d contents to: %s", filter.ID, filterFilePath) + body := []byte(strings.Join(filter.Rules, "\n")) - return safeWriteFile(filterFilePath, filter.Contents) + return safeWriteFile(filterFilePath, body) } // loads filter contents from the file in dataDir @@ -692,12 +693,11 @@ func (filter *filter) load() error { return err } - log.Printf("Filter %d length is %d", filter.ID, len(filterFileContents)) - filter.Contents = filterFileContents + log.Printf("File %s, id %d, length %d", filterFilePath, filter.ID, len(filterFileContents)) + rulesCount, _, rules := parseFilterContents(filterFileContents) - // Now extract the rules count - rulesCount, _ := parseFilterContents(filter.Contents) filter.RulesCount = rulesCount + filter.Rules = rules return nil } From 639b34c7d1d92ed52bbc9f6a51a2a1778ece1a16 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 28 Nov 2018 16:26:33 +0300 Subject: [PATCH 03/58] dnsforward library -- Fix race conditions found by -race --- dnsforward/dnsforward.go | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index 546c4eae..c7559b42 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -185,19 +185,23 @@ func (s *Server) reconfigureListenAddr(new ServerConfig) error { if err != nil { return errorx.Decorate(err, "Couldn't bind to %v", newAddr) } + s.Lock() if s.udpListen != nil { - err := s.udpListen.Close() - if err != nil { - return errorx.Decorate(err, "Couldn't close UDP listening socket") - } + err = s.udpListen.Close() + } + s.Unlock() + if err != nil { + return errorx.Decorate(err, "Couldn't close UDP listening socket") } } else { log.Printf("Rebinding -- ports are same so close first then bind") + s.Lock() if s.udpListen != nil { - err := s.udpListen.Close() - if err != nil { - return errorx.Decorate(err, "Couldn't close UDP listening socket") - } + err = s.udpListen.Close() + } + s.Unlock() + if err != nil { + return errorx.Decorate(err, "Couldn't close UDP listening socket") } newListen, err = net.ListenUDP("udp", newAddr) if err != nil { From 7944f23d9521963a954df2767321411169ae170b Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 28 Nov 2018 16:28:28 +0300 Subject: [PATCH 04/58] dnsforward library -- consistently nullify and close listening socket when we're done with it. --- dnsforward/dnsforward.go | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index c7559b42..0e3686af 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -97,6 +97,16 @@ func (s *Server) packetLoop() { if err != nil { if isConnClosed(err) { log.Printf("ReadFrom() returned because we're reading from a closed connection, exiting loop") + var err error + s.Lock() + if s.udpListen != nil { + err = s.udpListen.Close() + s.udpListen = nil + } + s.Unlock() + if err != nil { + log.Printf("Failed to close udp connection while exiting loop: %s", err) + } break } log.Printf("Got error when reading from udp listen: %s", err) @@ -124,6 +134,7 @@ func (s *Server) Start(config *ServerConfig) error { } s.udpListen, err = net.ListenUDP("udp", addr) if err != nil { + s.udpListen = nil return errorx.Decorate(err, "Couldn't listen to UDP socket") } log.Println(s.udpListen.LocalAddr(), s.UDPListenAddr) @@ -144,10 +155,10 @@ func (s *Server) Stop() error { defer s.Unlock() if s.udpListen != nil { err := s.udpListen.Close() + s.udpListen = nil if err != nil { return errorx.Decorate(err, "Couldn't close UDP listening socket") } - s.udpListen = nil } return nil } @@ -188,6 +199,7 @@ func (s *Server) reconfigureListenAddr(new ServerConfig) error { s.Lock() if s.udpListen != nil { err = s.udpListen.Close() + s.udpListen = nil } s.Unlock() if err != nil { @@ -198,6 +210,7 @@ func (s *Server) reconfigureListenAddr(new ServerConfig) error { s.Lock() if s.udpListen != nil { err = s.udpListen.Close() + s.udpListen = nil } s.Unlock() if err != nil { From 584f441141923873cda7833aba153c793cfbf24b Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 28 Nov 2018 16:43:50 +0300 Subject: [PATCH 05/58] dnsforward library -- introduce IsRunning() --- dnsforward/dnsforward.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index 0e3686af..e72869f2 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -163,6 +163,16 @@ func (s *Server) Stop() error { return nil } +func (s *Server) IsRunning() bool { + s.RLock() + isRunning := true + if s.udpListen == nil { + isRunning = false + } + s.RUnlock() + return isRunning +} + // // Server reconfigure // From a904f85e61478fa82138609dcfd6409e40255177 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 28 Nov 2018 16:44:19 +0300 Subject: [PATCH 06/58] dnsforward library -- default to plain DNS for high-performance testing. --- dnsforward/dnsforward.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index e72869f2..e4ef070b 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -55,14 +55,14 @@ var defaultValues = ServerConfig{ //// dns over TLS // &dnsOverTLS{Address: "tls://8.8.8.8:853"}, // &dnsOverTLS{Address: "tls://8.8.4.4:853"}, - &dnsOverTLS{Address: "tls://1.1.1.1:853"}, - &dnsOverTLS{Address: "tls://1.0.0.1:853"}, + // &dnsOverTLS{Address: "tls://1.1.1.1:853"}, + // &dnsOverTLS{Address: "tls://1.0.0.1:853"}, //// plainDNS - // &plainDNS{Address: "8.8.8.8:53"}, - // &plainDNS{Address: "8.8.4.4:53"}, - // &plainDNS{Address: "1.1.1.1:53"}, - // &plainDNS{Address: "1.0.0.1:53"}, + &plainDNS{Address: "8.8.8.8:53"}, + &plainDNS{Address: "8.8.4.4:53"}, + &plainDNS{Address: "1.1.1.1:53"}, + &plainDNS{Address: "1.0.0.1:53"}, }, } From feabc21864dd8d08b2ceed359169f542ac866e92 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 28 Nov 2018 16:45:30 +0300 Subject: [PATCH 07/58] Unplug coreDNS and plug dnsforward library. --- config.go | 97 +---------------------------------- coredns.go | 145 ++++++++++++----------------------------------------- go.mod | 1 + go.sum | 2 + 4 files changed, 38 insertions(+), 207 deletions(-) diff --git a/config.go b/config.go index 59cff6bc..5d114538 100644 --- a/config.go +++ b/config.go @@ -1,14 +1,11 @@ package main import ( - "bytes" "io/ioutil" "log" "os" "path/filepath" - "regexp" "sync" - "text/template" "time" "gopkg.in/yaml.v2" @@ -56,7 +53,7 @@ type coreDNSConfig struct { SafeSearchEnabled bool `yaml:"safesearch_enabled"` ParentalEnabled bool `yaml:"parental_enabled"` ParentalSensitivity int `yaml:"parental_sensitivity"` - BlockedResponseTTL int `yaml:"blocked_response_ttl"` + BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` QueryLogEnabled bool `yaml:"querylog_enabled"` Ratelimit int `yaml:"ratelimit"` RefuseAny bool `yaml:"refuse_any"` @@ -188,98 +185,8 @@ func (c *configuration) write() error { return nil } -// -------------- -// coredns config -// -------------- -func writeCoreDNSConfig() error { - coreFile := filepath.Join(config.ourBinaryDir, config.CoreDNS.coreFile) - log.Printf("Writing DNS config: %s", coreFile) - configText, err := generateCoreDNSConfigText() - if err != nil { - log.Printf("Couldn't generate DNS config: %s", err) - return err - } - err = safeWriteFile(coreFile, []byte(configText)) - if err != nil { - log.Printf("Couldn't save DNS config: %s", err) - return err - } - return nil -} - func writeAllConfigs() error { - err := config.write() - if err != nil { - log.Printf("Couldn't write our config: %s", err) - return err - } - err = writeCoreDNSConfig() - if err != nil { - log.Printf("Couldn't write DNS config: %s", err) - return err - } - return nil -} - -const coreDNSConfigTemplate = `.:{{.Port}} { - {{if .ProtectionEnabled}}dnsfilter { - {{if .SafeBrowsingEnabled}}safebrowsing{{end}} - {{if .ParentalEnabled}}parental {{.ParentalSensitivity}}{{end}} - {{if .SafeSearchEnabled}}safesearch{{end}} - {{if .QueryLogEnabled}}querylog{{end}} - blocked_ttl {{.BlockedResponseTTL}} - {{if .FilteringEnabled}}{{range .Filters}}{{if and .Enabled .Contents}} - filter {{.ID}} "{{.Path}}" - {{end}}{{end}}{{end}} - }{{end}} - {{.Pprof}} - {{if .RefuseAny}}refuseany{{end}} - {{if gt .Ratelimit 0}}ratelimit {{.Ratelimit}}{{end}} - hosts { - fallthrough - } - {{if .UpstreamDNS}}upstream {{range .UpstreamDNS}}{{.}} {{end}} { bootstrap {{.BootstrapDNS}} }{{end}} - {{.Cache}} - {{.Prometheus}} -} -` - -var removeEmptyLines = regexp.MustCompile("([\t ]*\n)+") - -// generate CoreDNS config text -func generateCoreDNSConfigText() (string, error) { - t, err := template.New("config").Parse(coreDNSConfigTemplate) - if err != nil { - log.Printf("Couldn't generate DNS config: %s", err) - return "", err - } - - var configBytes bytes.Buffer - temporaryConfig := config.CoreDNS - - // generate temporary filter list, needed to put userfilter in coredns config - filters := []filter{} - - // first of all, append the user filter - userFilter := userFilter() - - filters = append(filters, userFilter) - - // then go through other filters - filters = append(filters, config.Filters...) - temporaryConfig.Filters = filters - - // run the template - err = t.Execute(&configBytes, &temporaryConfig) - if err != nil { - log.Printf("Couldn't generate DNS config: %s", err) - return "", err - } - configText := configBytes.String() - - // remove empty lines from generated config - configText = removeEmptyLines.ReplaceAllString(configText, "\n") - return configText, nil + return config.write() } // Set the next filter ID to max(filter.ID) + 1 diff --git a/coredns.go b/coredns.go index 376e6210..df91b1a6 100644 --- a/coredns.go +++ b/coredns.go @@ -3,130 +3,51 @@ package main import ( "fmt" "log" - "os" - "path/filepath" - "sync" // Include all plugins. + "net" - _ "github.com/AdguardTeam/AdGuardHome/coredns_plugin" - _ "github.com/AdguardTeam/AdGuardHome/coredns_plugin/ratelimit" - _ "github.com/AdguardTeam/AdGuardHome/coredns_plugin/refuseany" - _ "github.com/AdguardTeam/AdGuardHome/upstream" - "github.com/coredns/coredns/core/dnsserver" - "github.com/coredns/coredns/coremain" - _ "github.com/coredns/coredns/plugin/auto" - _ "github.com/coredns/coredns/plugin/autopath" - _ "github.com/coredns/coredns/plugin/bind" - _ "github.com/coredns/coredns/plugin/cache" - _ "github.com/coredns/coredns/plugin/chaos" - _ "github.com/coredns/coredns/plugin/debug" - _ "github.com/coredns/coredns/plugin/dnssec" - _ "github.com/coredns/coredns/plugin/dnstap" - _ "github.com/coredns/coredns/plugin/erratic" - _ "github.com/coredns/coredns/plugin/errors" - _ "github.com/coredns/coredns/plugin/file" - _ "github.com/coredns/coredns/plugin/forward" - _ "github.com/coredns/coredns/plugin/health" - _ "github.com/coredns/coredns/plugin/hosts" - _ "github.com/coredns/coredns/plugin/loadbalance" - _ "github.com/coredns/coredns/plugin/log" - _ "github.com/coredns/coredns/plugin/loop" - _ "github.com/coredns/coredns/plugin/metadata" - _ "github.com/coredns/coredns/plugin/metrics" - _ "github.com/coredns/coredns/plugin/nsid" - _ "github.com/coredns/coredns/plugin/pprof" - _ "github.com/coredns/coredns/plugin/proxy" - _ "github.com/coredns/coredns/plugin/reload" - _ "github.com/coredns/coredns/plugin/rewrite" - _ "github.com/coredns/coredns/plugin/root" - _ "github.com/coredns/coredns/plugin/secondary" - _ "github.com/coredns/coredns/plugin/template" - _ "github.com/coredns/coredns/plugin/tls" - _ "github.com/coredns/coredns/plugin/whoami" - _ "github.com/mholt/caddy/onevent" + "github.com/AdguardTeam/AdGuardHome/dnsforward" + "github.com/joomcode/errorx" ) -// Directives are registered in the order they should be -// executed. -// -// Ordering is VERY important. Every plugin will -// feel the effects of all other plugin below -// (after) them during a request, but they must not -// care what plugin above them are doing. - -var directives = []string{ - "metadata", - "tls", - "reload", - "nsid", - "root", - "bind", - "debug", - "health", - "pprof", - "prometheus", - "errors", - "log", - "refuseany", - "ratelimit", - "dnsfilter", - "dnstap", - "chaos", - "loadbalance", - "cache", - "rewrite", - "dnssec", - "autopath", - "template", - "hosts", - "file", - "auto", - "secondary", - "loop", - "forward", - "proxy", - "upstream", - "erratic", - "whoami", - "on", -} - -func init() { - dnsserver.Directives = directives -} - -var ( - isCoreDNSRunningLock sync.Mutex - isCoreDNSRunning = false -) +var dnsServer = dnsforward.Server{} func isRunning() bool { - isCoreDNSRunningLock.Lock() - value := isCoreDNSRunning - isCoreDNSRunningLock.Unlock() - return value + return dnsServer.IsRunning() } func startDNSServer() error { - isCoreDNSRunningLock.Lock() - if isCoreDNSRunning { - isCoreDNSRunningLock.Unlock() + if isRunning() { return fmt.Errorf("Unable to start coreDNS: Already running") } - isCoreDNSRunning = true - isCoreDNSRunningLock.Unlock() - configpath := filepath.Join(config.ourBinaryDir, config.CoreDNS.coreFile) - os.Args = os.Args[:1] - os.Args = append(os.Args, "-conf") - os.Args = append(os.Args, configpath) - - err := writeCoreDNSConfig() - if err != nil { - errortext := fmt.Errorf("Unable to write coredns config: %s", err) - log.Println(errortext) - return errortext + filters := []dnsforward.Filter{} + for _, filter := range config.Filters { + filters = append(filters, dnsforward.Filter{ + ID: filter.ID, + Rules: filter.Rules, + }) + } + + newconfig := dnsforward.ServerConfig{ + UDPListenAddr: &net.UDPAddr{Port: config.CoreDNS.Port}, + BlockedTTL: config.CoreDNS.BlockedResponseTTL, + Filters: filters, + } + + for _, u := range config.CoreDNS.UpstreamDNS { + upstream, err := dnsforward.GetUpstream(u) + if err != nil { + log.Printf("Couldn't get upstream: %s", err) + // continue, just ignore the upstream + continue + } + newconfig.Upstreams = append(newconfig.Upstreams, upstream) + } + + err := dnsServer.Start(&newconfig) + if err != nil { + return errorx.Decorate(err, "Couldn't start forwarding DNS server") } - go coremain.Run() return nil } diff --git a/go.mod b/go.mod index dae96b71..1b8d78e6 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/gobuffalo/packr v1.19.0 github.com/google/uuid v1.0.0 // indirect github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 // indirect + github.com/joomcode/errorx v0.1.0 github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect github.com/mholt/caddy v0.11.0 github.com/miekg/dns v1.0.15 diff --git a/go.sum b/go.sum index 06efaa9e..4ecb93be 100644 --- a/go.sum +++ b/go.sum @@ -41,6 +41,8 @@ github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645/go github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/joho/godotenv v1.3.0 h1:Zjp+RcGpHhGlrMbJzXTrZZPrWj+1vfm90La1wgB6Bhc= github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= +github.com/joomcode/errorx v0.1.0 h1:QmJMiI1DE1UFje2aI1ZWO/VMT5a32qBoXUclGOt8vsc= +github.com/joomcode/errorx v0.1.0/go.mod h1:kgco15ekB6cs+4Xjzo7SPeXzx38PbJzBwbnu9qfVNHQ= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/markbates/oncer v0.0.0-20181014194634-05fccaae8fc4 h1:Mlji5gkcpzkqTROyE4ZxZ8hN7osunMb2RuGVrbvMvCc= github.com/markbates/oncer v0.0.0-20181014194634-05fccaae8fc4/go.mod h1:Ld9puTsIW75CHf65OeIOkyKbteujpZVXDpWK6YGZbxE= From 4eb122e973b38cc42facaa4deb12736f1fa64864 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 28 Nov 2018 17:29:48 +0300 Subject: [PATCH 08/58] Avoid duplication of fields in filter struct. --- config.go | 18 ++++++++++-------- dnsforward/dnsforward.go | 4 ++-- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/config.go b/config.go index 5d114538..11692215 100644 --- a/config.go +++ b/config.go @@ -8,6 +8,7 @@ import ( "sync" "time" + "github.com/AdguardTeam/AdGuardHome/dnsforward" "gopkg.in/yaml.v2" ) @@ -71,9 +72,8 @@ type filter struct { Name string `json:"name" yaml:"name"` RulesCount int `json:"rulesCount" yaml:"-"` LastUpdated time.Time `json:"lastUpdated,omitempty" yaml:"last_updated,omitempty"` - ID int64 `json:"id"` // auto-assigned when filter is added (see nextFilterID), json by default keeps ID uppercase but we need lowercase - Rules []string `json:"-" yaml:"-"` // not in yaml or json + dnsforward.Filter `yaml:",inline"` } var defaultDNS = []string{"tls://1.1.1.1", "tls://1.0.0.1"} @@ -100,10 +100,10 @@ var config = configuration{ Prometheus: "prometheus :9153", }, Filters: []filter{ - {ID: 1, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard Simplified Domain Names filter"}, - {ID: 2, Enabled: false, URL: "https://adaway.org/hosts.txt", Name: "AdAway"}, - {ID: 3, Enabled: false, URL: "https://hosts-file.net/ad_servers.txt", Name: "hpHosts - Ad and Tracking servers only"}, - {ID: 4, Enabled: false, URL: "http://www.malwaredomainlist.com/hostslist/hosts.txt", Name: "MalwareDomainList.com Hosts List"}, + {Filter: dnsforward.Filter{ID: 1}, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard Simplified Domain Names filter"}, + {Filter: dnsforward.Filter{ID: 2}, Enabled: false, URL: "https://adaway.org/hosts.txt", Name: "AdAway"}, + {Filter: dnsforward.Filter{ID: 3}, Enabled: false, URL: "https://hosts-file.net/ad_servers.txt", Name: "hpHosts - Ad and Tracking servers only"}, + {Filter: dnsforward.Filter{ID: 4}, Enabled: false, URL: "http://www.malwaredomainlist.com/hostslist/hosts.txt", Name: "MalwareDomainList.com Hosts List"}, }, } @@ -111,9 +111,11 @@ var config = configuration{ func userFilter() filter { return filter{ // User filter always has constant ID=0 - ID: userFilterID, - Rules: config.UserRules, Enabled: true, + Filter: dnsforward.Filter{ + ID: userFilterID, + Rules: config.UserRules, + }, } } diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index e4ef070b..7e3083fd 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -67,8 +67,8 @@ var defaultValues = ServerConfig{ } type Filter struct { - ID int64 - Rules []string + ID int64 `json:"id"` // auto-assigned when filter is added (see nextFilterID), json by default keeps ID uppercase but we need lowercase + Rules []string `json:"-" yaml:"-"` // not in yaml or json } // From e4a35647060d29f41519b396b9f963d58477fe1d Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 28 Nov 2018 17:55:01 +0300 Subject: [PATCH 09/58] Fix a logical race that wasn't detectable by -race -- we were closing a connection that was already reestablished. --- dnsforward/dnsforward.go | 51 ++++++++++++++++++++++++++++++++-------- dnsforward/helpers.go | 7 ++++++ 2 files changed, 48 insertions(+), 10 deletions(-) diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index 7e3083fd..1bc91453 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -32,6 +32,46 @@ type Server struct { ServerConfig } +// uncomment this block to have tracing of locks +/* +func (s *Server) Lock() { + pc := make([]uintptr, 10) // at least 1 entry needed + runtime.Callers(2, pc) + f := runtime.FuncForPC(pc[0]) + file, line := f.FileLine(pc[0]) + fmt.Fprintf(os.Stderr, "%s:%d %s() -> Lock() -> in progress\n", path.Base(file), line, path.Base(f.Name())) + s.RWMutex.Lock() + fmt.Fprintf(os.Stderr, "%s:%d %s() -> Lock() -> done\n", path.Base(file), line, path.Base(f.Name())) +} +func (s *Server) RLock() { + pc := make([]uintptr, 10) // at least 1 entry needed + runtime.Callers(2, pc) + f := runtime.FuncForPC(pc[0]) + file, line := f.FileLine(pc[0]) + fmt.Fprintf(os.Stderr, "%s:%d %s() -> RLock() -> in progress\n", path.Base(file), line, path.Base(f.Name())) + s.RWMutex.RLock() + fmt.Fprintf(os.Stderr, "%s:%d %s() -> RLock() -> done\n", path.Base(file), line, path.Base(f.Name())) +} +func (s *Server) Unlock() { + pc := make([]uintptr, 10) // at least 1 entry needed + runtime.Callers(2, pc) + f := runtime.FuncForPC(pc[0]) + file, line := f.FileLine(pc[0]) + fmt.Fprintf(os.Stderr, "%s:%d %s() -> Unlock() -> in progress\n", path.Base(file), line, path.Base(f.Name())) + s.RWMutex.Unlock() + fmt.Fprintf(os.Stderr, "%s:%d %s() -> Unlock() -> done\n", path.Base(file), line, path.Base(f.Name())) +} +func (s *Server) RUnlock() { + pc := make([]uintptr, 10) // at least 1 entry needed + runtime.Callers(2, pc) + f := runtime.FuncForPC(pc[0]) + file, line := f.FileLine(pc[0]) + fmt.Fprintf(os.Stderr, "%s:%d %s() -> RUnlock() -> in progress\n", path.Base(file), line, path.Base(f.Name())) + s.RWMutex.RUnlock() + fmt.Fprintf(os.Stderr, "%s:%d %s() -> RUnlock() -> done\n", path.Base(file), line, path.Base(f.Name())) +} +*/ + // The zero ServerConfig is empty and ready for use. type ServerConfig struct { UDPListenAddr *net.UDPAddr // if nil, then default is is used (port 53 on *) @@ -97,16 +137,7 @@ func (s *Server) packetLoop() { if err != nil { if isConnClosed(err) { log.Printf("ReadFrom() returned because we're reading from a closed connection, exiting loop") - var err error - s.Lock() - if s.udpListen != nil { - err = s.udpListen.Close() - s.udpListen = nil - } - s.Unlock() - if err != nil { - log.Printf("Failed to close udp connection while exiting loop: %s", err) - } + // don't try to nullify s.udpListen here, because s.udpListen could be already re-bound to listen break } log.Printf("Got error when reading from udp listen: %s", err) diff --git a/dnsforward/helpers.go b/dnsforward/helpers.go index 339023a0..52b65c87 100644 --- a/dnsforward/helpers.go +++ b/dnsforward/helpers.go @@ -28,6 +28,13 @@ func isConnClosed(err error) bool { // --------------------- // debug logging helpers // --------------------- +func _Func() string { + pc := make([]uintptr, 10) // at least 1 entry needed + runtime.Callers(2, pc) + f := runtime.FuncForPC(pc[0]) + return path.Base(f.Name()) +} + func trace(format string, args ...interface{}) { pc := make([]uintptr, 10) // at least 1 entry needed runtime.Callers(2, pc) From 7120f551c8ab2a5c75d654930246cc9811338bfb Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 28 Nov 2018 18:14:12 +0300 Subject: [PATCH 10/58] dnsforward -- rename BlockedTTL to BlockedResponseTTL to be in line with app's config variable. --- coredns.go | 6 +++--- dnsforward/dnsforward.go | 30 +++++++++++++++--------------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/coredns.go b/coredns.go index df91b1a6..b612c8d0 100644 --- a/coredns.go +++ b/coredns.go @@ -29,9 +29,9 @@ func startDNSServer() error { } newconfig := dnsforward.ServerConfig{ - UDPListenAddr: &net.UDPAddr{Port: config.CoreDNS.Port}, - BlockedTTL: config.CoreDNS.BlockedResponseTTL, - Filters: filters, + UDPListenAddr: &net.UDPAddr{Port: config.CoreDNS.Port}, + BlockedResponseTTL: config.CoreDNS.BlockedResponseTTL, + Filters: filters, } for _, u := range config.CoreDNS.UpstreamDNS { diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index 1bc91453..45a66c53 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -74,15 +74,15 @@ func (s *Server) RUnlock() { // The zero ServerConfig is empty and ready for use. type ServerConfig struct { - UDPListenAddr *net.UDPAddr // if nil, then default is is used (port 53 on *) - BlockedTTL uint32 // if 0, then default is used (3600) - Upstreams []Upstream - Filters []Filter + UDPListenAddr *net.UDPAddr // if nil, then default is is used (port 53 on *) + BlockedResponseTTL uint32 // if 0, then default is used (3600) + Upstreams []Upstream + Filters []Filter } var defaultValues = ServerConfig{ - UDPListenAddr: &net.UDPAddr{Port: 53}, - BlockedTTL: 3600, + UDPListenAddr: &net.UDPAddr{Port: 53}, + BlockedResponseTTL: 3600, Upstreams: []Upstream{ //// dns over HTTPS // &dnsOverHTTPS{Address: "https://1.1.1.1/dns-query"}, @@ -273,17 +273,17 @@ func (s *Server) reconfigureListenAddr(new ServerConfig) error { return nil } -func (s *Server) reconfigureBlockedTTL(new ServerConfig) { - newVal := new.BlockedTTL +func (s *Server) reconfigureBlockedResponseTTL(new ServerConfig) { + newVal := new.BlockedResponseTTL if newVal == 0 { - newVal = defaultValues.BlockedTTL + newVal = defaultValues.BlockedResponseTTL } - oldVal := s.BlockedTTL + oldVal := s.BlockedResponseTTL if oldVal == 0 { - oldVal = defaultValues.BlockedTTL + oldVal = defaultValues.BlockedResponseTTL } if newVal != oldVal { - s.BlockedTTL = new.BlockedTTL + s.BlockedResponseTTL = new.BlockedResponseTTL } } @@ -341,7 +341,7 @@ func (s *Server) reconfigureFilters(new ServerConfig) { } func (s *Server) Reconfigure(new ServerConfig) error { - s.reconfigureBlockedTTL(new) + s.reconfigureBlockedResponseTTL(new) s.reconfigureUpstreams(new) s.reconfigureFilters(new) @@ -510,13 +510,13 @@ func (s *Server) genSOA(request *dns.Msg) []dns.RR { Hdr: dns.RR_Header{ Name: zone, Rrtype: dns.TypeSOA, - Ttl: s.BlockedTTL, + Ttl: s.BlockedResponseTTL, Class: dns.ClassINET, }, Mbox: "hostmaster.", // zone will be appended later if it's not empty or "." } if soa.Hdr.Ttl == 0 { - soa.Hdr.Ttl = defaultValues.BlockedTTL + soa.Hdr.Ttl = defaultValues.BlockedResponseTTL } if len(zone) > 0 && zone[0] != '.' { soa.Mbox += zone From 8316d39b4282599681a827a50d6541c5cb4987ed Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 28 Nov 2018 18:24:04 +0300 Subject: [PATCH 11/58] Move filtering setting fields from main app to dnsforward. --- config.go | 64 +++++++++++++++++++--------------------- dnsforward/dnsforward.go | 12 ++++++++ 2 files changed, 43 insertions(+), 33 deletions(-) diff --git a/config.go b/config.go index 11692215..46de5f3b 100644 --- a/config.go +++ b/config.go @@ -44,25 +44,21 @@ type configuration struct { // field ordering is important -- yaml fields will mirror ordering from here type coreDNSConfig struct { - binaryFile string - coreFile string - Filters []filter `yaml:"-"` - Port int `yaml:"port"` - ProtectionEnabled bool `yaml:"protection_enabled"` - FilteringEnabled bool `yaml:"filtering_enabled"` - SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"` - SafeSearchEnabled bool `yaml:"safesearch_enabled"` - ParentalEnabled bool `yaml:"parental_enabled"` - ParentalSensitivity int `yaml:"parental_sensitivity"` - BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` - QueryLogEnabled bool `yaml:"querylog_enabled"` - Ratelimit int `yaml:"ratelimit"` - RefuseAny bool `yaml:"refuse_any"` - Pprof string `yaml:"-"` - Cache string `yaml:"-"` - Prometheus string `yaml:"-"` - BootstrapDNS string `yaml:"bootstrap_dns"` - UpstreamDNS []string `yaml:"upstream_dns"` + binaryFile string + coreFile string + Filters []filter `yaml:"-"` + Port int `yaml:"port"` + + dnsforward.FilteringConfig `yaml:",inline"` + + QueryLogEnabled bool `yaml:"querylog_enabled"` + Ratelimit int `yaml:"ratelimit"` + RefuseAny bool `yaml:"refuse_any"` + Pprof string `yaml:"-"` + Cache string `yaml:"-"` + Prometheus string `yaml:"-"` + BootstrapDNS string `yaml:"bootstrap_dns"` + UpstreamDNS []string `yaml:"upstream_dns"` } // field ordering is important -- yaml fields will mirror ordering from here @@ -84,20 +80,22 @@ var config = configuration{ BindPort: 3000, BindHost: "127.0.0.1", CoreDNS: coreDNSConfig{ - Port: 53, - binaryFile: "coredns", // only filename, no path - coreFile: "Corefile", // only filename, no path - ProtectionEnabled: true, - FilteringEnabled: true, - SafeBrowsingEnabled: false, - BlockedResponseTTL: 10, // in seconds - QueryLogEnabled: true, - Ratelimit: 20, - RefuseAny: true, - BootstrapDNS: "8.8.8.8:53", - UpstreamDNS: defaultDNS, - Cache: "cache", - Prometheus: "prometheus :9153", + Port: 53, + binaryFile: "coredns", // only filename, no path + coreFile: "Corefile", // only filename, no path + FilteringConfig: dnsforward.FilteringConfig{ + ProtectionEnabled: true, + FilteringEnabled: true, + SafeBrowsingEnabled: false, + BlockedResponseTTL: 10, // in seconds + }, + QueryLogEnabled: true, + Ratelimit: 20, + RefuseAny: true, + BootstrapDNS: "8.8.8.8:53", + UpstreamDNS: defaultDNS, + Cache: "cache", + Prometheus: "prometheus :9153", }, Filters: []filter{ {Filter: dnsforward.Filter{ID: 1}, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard Simplified Domain Names filter"}, diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index 45a66c53..c087d17e 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -72,12 +72,24 @@ func (s *Server) RUnlock() { } */ +type FilteringConfig struct { + ProtectionEnabled bool `yaml:"protection_enabled"` + FilteringEnabled bool `yaml:"filtering_enabled"` + SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"` + SafeSearchEnabled bool `yaml:"safesearch_enabled"` + ParentalEnabled bool `yaml:"parental_enabled"` + ParentalSensitivity int `yaml:"parental_sensitivity"` + BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` +} + // The zero ServerConfig is empty and ready for use. type ServerConfig struct { UDPListenAddr *net.UDPAddr // if nil, then default is is used (port 53 on *) BlockedResponseTTL uint32 // if 0, then default is used (3600) Upstreams []Upstream Filters []Filter + + FilteringConfig } var defaultValues = ServerConfig{ From 9d87ae95e60756588b66172052125a5b63c8c292 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 28 Nov 2018 18:40:20 +0300 Subject: [PATCH 12/58] dnsforward -- if given addresses without ports, assign default ports --- dnsforward/upstream.go | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/dnsforward/upstream.go b/dnsforward/upstream.go index 3e0a61c0..b245e7f1 100644 --- a/dnsforward/upstream.go +++ b/dnsforward/upstream.go @@ -7,6 +7,7 @@ import ( "io/ioutil" "log" "math/rand" + "net" "net/http" "net/url" "strings" @@ -172,16 +173,31 @@ func GetUpstream(address string) (Upstream, error) { } switch url.Scheme { case "dns": - return &plainDNS{Address: address}, nil + if url.Port() == "" { + url.Host += ":53" + } + return &plainDNS{Address: url.String()}, nil case "tls": - return &dnsOverTLS{Address: address}, nil + if url.Port() == "" { + url.Host += ":853" + } + return &dnsOverTLS{Address: url.String()}, nil case "https": - return &dnsOverHTTPS{Address: address}, nil + return &dnsOverHTTPS{Address: url.String()}, nil default: - return &plainDNS{Address: address}, nil + // assume it's plain DNS + if url.Port() == "" { + url.Host += ":53" + } + return &plainDNS{Address: url.String()}, nil } } // we don't have scheme in the url, so it's just a plain DNS host:port + _, _, err := net.SplitHostPort(address) + if err != nil { + // doesn't have port, default to 53 + address = net.JoinHostPort(address, "53") + } return &plainDNS{Address: address}, nil } From 9c4b791621344edb7ed36e4cc6aa88bed9876681 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 28 Nov 2018 18:57:20 +0300 Subject: [PATCH 13/58] coredns reload -> dnsServer.Reconfigure() --- control.go | 12 ++++-------- coredns.go | 27 ++++++++++++++++++++++----- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/control.go b/control.go index d2011cb0..25c059b0 100644 --- a/control.go +++ b/control.go @@ -38,17 +38,13 @@ var client = &http.Client{ // ------------------- // coredns run control // ------------------- -func tellCoreDNSToReload() { - corednsplugin.Reload <- true -} - func writeAllConfigsAndReloadCoreDNS() error { err := writeAllConfigs() if err != nil { log.Printf("Couldn't write all configs: %s", err) return err } - tellCoreDNSToReload() + reconfigureDNSServer() return nil } @@ -155,7 +151,7 @@ func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) { http.Error(w, errorText, http.StatusInternalServerError) return } - tellCoreDNSToReload() + reconfigureDNSServer() _, err = fmt.Fprintf(w, "OK %d servers\n", len(hosts)) if err != nil { errorText := fmt.Sprintf("Couldn't write body: %s", err) @@ -386,7 +382,7 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) { return } - tellCoreDNSToReload() + reconfigureDNSServer() _, err = fmt.Fprintf(w, "OK %d rules\n", filter.RulesCount) if err != nil { @@ -571,7 +567,7 @@ func refreshFiltersIfNeccessary(force bool) int { config.Unlock() if updateCount > 0 { - tellCoreDNSToReload() + reconfigureDNSServer() } return updateCount } diff --git a/coredns.go b/coredns.go index b612c8d0..348b4241 100644 --- a/coredns.go +++ b/coredns.go @@ -15,11 +15,7 @@ func isRunning() bool { return dnsServer.IsRunning() } -func startDNSServer() error { - if isRunning() { - return fmt.Errorf("Unable to start coreDNS: Already running") - } - +func generateServerConfig() dnsforward.ServerConfig { filters := []dnsforward.Filter{} for _, filter := range config.Filters { filters = append(filters, dnsforward.Filter{ @@ -43,7 +39,15 @@ func startDNSServer() error { } newconfig.Upstreams = append(newconfig.Upstreams, upstream) } + return newconfig +} +func startDNSServer() error { + if isRunning() { + return fmt.Errorf("Unable to start forwarding DNS server: Already running") + } + + newconfig := generateServerConfig() err := dnsServer.Start(&newconfig) if err != nil { return errorx.Decorate(err, "Couldn't start forwarding DNS server") @@ -51,3 +55,16 @@ func startDNSServer() error { return nil } + +func reconfigureDNSServer() error { + if !isRunning() { + return fmt.Errorf("Refusing to reconfigure forwarding DNS server: not running") + } + + err := dnsServer.Reconfigure(generateServerConfig()) + if err != nil { + return errorx.Decorate(err, "Couldn't start forwarding DNS server") + } + + return nil +} From d27fd0488d4d415271e5233bf9d7a371790df1ee Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 28 Nov 2018 20:14:54 +0300 Subject: [PATCH 14/58] Move filter-related variables, types and methods to filter.go --- config.go | 42 --------- control.go | 181 -------------------------------------- filter.go | 251 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 251 insertions(+), 223 deletions(-) create mode 100644 filter.go diff --git a/config.go b/config.go index 46de5f3b..21a75b42 100644 --- a/config.go +++ b/config.go @@ -6,7 +6,6 @@ import ( "os" "path/filepath" "sync" - "time" "github.com/AdguardTeam/AdGuardHome/dnsforward" "gopkg.in/yaml.v2" @@ -16,12 +15,8 @@ const ( currentSchemaVersion = 1 // used for upgrading from old configs to new config dataDir = "data" // data storage filterDir = "filters" // cache location for downloaded filters, it's under DataDir - userFilterID = 0 // special filter ID, always 0 ) -// Just a counter that we use for incrementing the filter ID -var nextFilterID int64 = time.Now().Unix() - // configuration is loaded from YAML // field ordering is important -- yaml fields will mirror ordering from here type configuration struct { @@ -61,17 +56,6 @@ type coreDNSConfig struct { UpstreamDNS []string `yaml:"upstream_dns"` } -// field ordering is important -- yaml fields will mirror ordering from here -type filter struct { - Enabled bool `json:"enabled"` - URL string `json:"url"` - Name string `json:"name" yaml:"name"` - RulesCount int `json:"rulesCount" yaml:"-"` - LastUpdated time.Time `json:"lastUpdated,omitempty" yaml:"last_updated,omitempty"` - - dnsforward.Filter `yaml:",inline"` -} - var defaultDNS = []string{"tls://1.1.1.1", "tls://1.0.0.1"} // initialize to default values, will be changed later when reading config or parsing command line @@ -105,18 +89,6 @@ var config = configuration{ }, } -// Creates a helper object for working with the user rules -func userFilter() filter { - return filter{ - // User filter always has constant ID=0 - Enabled: true, - Filter: dnsforward.Filter{ - ID: userFilterID, - Rules: config.UserRules, - }, - } -} - // Loads configuration from the YAML file func parseConfig() error { configFile := filepath.Join(config.ourBinaryDir, config.ourConfigFilename) @@ -189,17 +161,3 @@ func writeAllConfigs() error { return config.write() } -// Set the next filter ID to max(filter.ID) + 1 -func updateUniqueFilterID(filters []filter) { - for _, filter := range filters { - if nextFilterID < filter.ID { - nextFilterID = filter.ID + 1 - } - } -} - -func assignUniqueFilterID() int64 { - value := nextFilterID - nextFilterID += 1 - return value -} diff --git a/control.go b/control.go index 25c059b0..f41655c3 100644 --- a/control.go +++ b/control.go @@ -7,9 +7,6 @@ import ( "log" "net/http" "os" - "path/filepath" - "reflect" - "regexp" "strconv" "strings" "time" @@ -22,8 +19,6 @@ import ( const updatePeriod = time.Minute * 30 -var filterTitleRegexp = regexp.MustCompile(`^! Title: +(.*)$`) - // cached version.json to avoid hammering github.io for each page reload var versionCheckJSON []byte var versionCheckLastTime time.Time @@ -527,182 +522,6 @@ func handleFilteringRefresh(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "OK %d filters updated\n", updated) } -// Sets up a timer that will be checking for filters updates periodically -func periodicallyRefreshFilters() { - for range time.Tick(time.Minute) { - refreshFiltersIfNeccessary(false) - } -} - -// Checks filters updates if necessary -// If force is true, it ignores the filter.LastUpdated field value -func refreshFiltersIfNeccessary(force bool) int { - config.Lock() - - // fetch URLs - updateCount := 0 - for i := range config.Filters { - filter := &config.Filters[i] // otherwise we will be operating on a copy - - if filter.ID == 0 { // protect against users modifying the yaml and removing the ID - filter.ID = assignUniqueFilterID() - } - - updated, err := filter.update(force) - if err != nil { - log.Printf("Failed to update filter %s: %s\n", filter.URL, err) - continue - } - if updated { - // Saving it to the filters dir now - err = filter.save() - if err != nil { - log.Printf("Failed to save the updated filter %d: %s", filter.ID, err) - continue - } - - updateCount++ - } - } - config.Unlock() - - if updateCount > 0 { - reconfigureDNSServer() - } - return updateCount -} - -// A helper function that parses filter contents and returns a number of rules and a filter name (if there's any) -func parseFilterContents(contents []byte) (int, string, []string) { - lines := strings.Split(string(contents), "\n") - rulesCount := 0 - name := "" - seenTitle := false - - // Count lines in the filter - for _, line := range lines { - line = strings.TrimSpace(line) - if len(line) > 0 && line[0] == '!' { - if m := filterTitleRegexp.FindAllStringSubmatch(line, -1); len(m) > 0 && len(m[0]) >= 2 && !seenTitle { - name = m[0][1] - seenTitle = true - } - } else if len(line) != 0 { - rulesCount++ - } - } - - return rulesCount, name, lines -} - -// Checks for filters updates -// If "force" is true -- does not check the filter's LastUpdated field -// Call "save" to persist the filter contents -func (filter *filter) update(force bool) (bool, error) { - if filter.ID == 0 { // protect against users deleting the ID - filter.ID = assignUniqueFilterID() - } - if !filter.Enabled { - return false, nil - } - if !force && time.Since(filter.LastUpdated) <= updatePeriod { - return false, nil - } - - log.Printf("Downloading update for filter %d from %s", filter.ID, filter.URL) - - // use the same update period for failed filter downloads to avoid flooding with requests - filter.LastUpdated = time.Now() - - resp, err := client.Get(filter.URL) - if resp != nil && resp.Body != nil { - defer resp.Body.Close() - } - if err != nil { - log.Printf("Couldn't request filter from URL %s, skipping: %s", filter.URL, err) - return false, err - } - - if resp.StatusCode != 200 { - log.Printf("Got status code %d from URL %s, skipping", resp.StatusCode, filter.URL) - return false, fmt.Errorf("got status code != 200: %d", resp.StatusCode) - } - - contentType := strings.ToLower(resp.Header.Get("content-type")) - if !strings.HasPrefix(contentType, "text/plain") { - log.Printf("Non-text response %s from %s, skipping", contentType, filter.URL) - return false, fmt.Errorf("non-text response %s", contentType) - } - - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - log.Printf("Couldn't fetch filter contents from URL %s, skipping: %s", filter.URL, err) - return false, err - } - - // Extract filter name and count number of rules - rulesCount, filterName, rules := parseFilterContents(body) - - if filterName != "" { - filter.Name = filterName - } - - // Check if the filter has been really changed - if reflect.DeepEqual(filter.Rules, rules) { - log.Printf("The filter %d text has not changed", filter.ID) - return false, nil - } - - log.Printf("Filter %d has been updated: %d bytes, %d rules", filter.ID, len(body), rulesCount) - filter.RulesCount = rulesCount - filter.Rules = rules - - return true, nil -} - -// saves filter contents to the file in dataDir -func (filter *filter) save() error { - filterFilePath := filter.Path() - log.Printf("Saving filter %d contents to: %s", filter.ID, filterFilePath) - body := []byte(strings.Join(filter.Rules, "\n")) - - return safeWriteFile(filterFilePath, body) -} - -// loads filter contents from the file in dataDir -func (filter *filter) load() error { - if !filter.Enabled { - // No need to load a filter that is not enabled - return nil - } - - filterFilePath := filter.Path() - log.Printf("Loading filter %d contents to: %s", filter.ID, filterFilePath) - - if _, err := os.Stat(filterFilePath); os.IsNotExist(err) { - // do nothing, file doesn't exist - return err - } - - filterFileContents, err := ioutil.ReadFile(filterFilePath) - if err != nil { - return err - } - - log.Printf("File %s, id %d, length %d", filterFilePath, filter.ID, len(filterFileContents)) - rulesCount, _, rules := parseFilterContents(filterFileContents) - - filter.RulesCount = rulesCount - filter.Rules = rules - - return nil -} - -// Path to the filter contents -func (filter *filter) Path() string { - return filepath.Join(config.ourBinaryDir, dataDir, filterDir, strconv.FormatInt(filter.ID, 10)+".txt") -} - // ------------ // safebrowsing // ------------ diff --git a/filter.go b/filter.go new file mode 100644 index 00000000..46e3a376 --- /dev/null +++ b/filter.go @@ -0,0 +1,251 @@ +package main + +import ( + "fmt" + "io/ioutil" + "log" + "os" + "path/filepath" + "reflect" + "regexp" + "strconv" + "strings" + "time" + + "github.com/AdguardTeam/AdGuardHome/dnsforward" +) + +var ( + nextFilterID = time.Now().Unix() // semi-stable way to generate an unique ID + filterTitleRegexp = regexp.MustCompile(`^! Title: +(.*)$`) +) + +// field ordering is important -- yaml fields will mirror ordering from here +type filter struct { + Enabled bool `json:"enabled"` + URL string `json:"url"` + Name string `json:"name" yaml:"name"` + RulesCount int `json:"rulesCount" yaml:"-"` + LastUpdated time.Time `json:"lastUpdated,omitempty" yaml:"last_updated,omitempty"` + + dnsforward.Filter `yaml:",inline"` +} + +// Creates a helper object for working with the user rules +func userFilter() filter { + return filter{ + // User filter always has constant ID=0 + Enabled: true, + Filter: dnsforward.Filter{ + Rules: config.UserRules, + }, + } +} + +func deduplicateFilters() { + // Deduplicate filters + i := 0 // output index, used for deletion later + urls := map[string]bool{} + for _, filter := range config.Filters { + if _, ok := urls[filter.URL]; !ok { + // we didn't see it before, keep it + urls[filter.URL] = true // remember the URL + config.Filters[i] = filter + i++ + } + } + + // all entries we want to keep are at front, delete the rest + config.Filters = config.Filters[:i] +} + +// Set the next filter ID to max(filter.ID) + 1 +func updateUniqueFilterID(filters []filter) { + for _, filter := range filters { + if nextFilterID < filter.ID { + nextFilterID = filter.ID + 1 + } + } +} + +func assignUniqueFilterID() int64 { + value := nextFilterID + nextFilterID += 1 + return value +} + +// Sets up a timer that will be checking for filters updates periodically +func periodicallyRefreshFilters() { + for range time.Tick(time.Minute) { + refreshFiltersIfNeccessary(false) + } +} + +// Checks filters updates if necessary +// If force is true, it ignores the filter.LastUpdated field value +func refreshFiltersIfNeccessary(force bool) int { + config.Lock() + + // fetch URLs + updateCount := 0 + for i := range config.Filters { + filter := &config.Filters[i] // otherwise we will be operating on a copy + + if filter.ID == 0 { // protect against users modifying the yaml and removing the ID + filter.ID = assignUniqueFilterID() + } + + updated, err := filter.update(force) + if err != nil { + log.Printf("Failed to update filter %s: %s\n", filter.URL, err) + continue + } + if updated { + // Saving it to the filters dir now + err = filter.save() + if err != nil { + log.Printf("Failed to save the updated filter %d: %s", filter.ID, err) + continue + } + + updateCount++ + } + } + config.Unlock() + + if updateCount > 0 { + reconfigureDNSServer() + } + return updateCount +} + +// A helper function that parses filter contents and returns a number of rules and a filter name (if there's any) +func parseFilterContents(contents []byte) (int, string, []string) { + lines := strings.Split(string(contents), "\n") + rulesCount := 0 + name := "" + seenTitle := false + + // Count lines in the filter + for _, line := range lines { + line = strings.TrimSpace(line) + if len(line) > 0 && line[0] == '!' { + if m := filterTitleRegexp.FindAllStringSubmatch(line, -1); len(m) > 0 && len(m[0]) >= 2 && !seenTitle { + name = m[0][1] + seenTitle = true + } + } else if len(line) != 0 { + rulesCount++ + } + } + + return rulesCount, name, lines +} + +// Checks for filters updates +// If "force" is true -- does not check the filter's LastUpdated field +// Call "save" to persist the filter contents +func (filter *filter) update(force bool) (bool, error) { + if filter.ID == 0 { // protect against users deleting the ID + filter.ID = assignUniqueFilterID() + } + if !filter.Enabled { + return false, nil + } + if !force && time.Since(filter.LastUpdated) <= updatePeriod { + return false, nil + } + + log.Printf("Downloading update for filter %d from %s", filter.ID, filter.URL) + + // use the same update period for failed filter downloads to avoid flooding with requests + filter.LastUpdated = time.Now() + + resp, err := client.Get(filter.URL) + if resp != nil && resp.Body != nil { + defer resp.Body.Close() + } + if err != nil { + log.Printf("Couldn't request filter from URL %s, skipping: %s", filter.URL, err) + return false, err + } + + if resp.StatusCode != 200 { + log.Printf("Got status code %d from URL %s, skipping", resp.StatusCode, filter.URL) + return false, fmt.Errorf("got status code != 200: %d", resp.StatusCode) + } + + contentType := strings.ToLower(resp.Header.Get("content-type")) + if !strings.HasPrefix(contentType, "text/plain") { + log.Printf("Non-text response %s from %s, skipping", contentType, filter.URL) + return false, fmt.Errorf("non-text response %s", contentType) + } + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + log.Printf("Couldn't fetch filter contents from URL %s, skipping: %s", filter.URL, err) + return false, err + } + + // Extract filter name and count number of rules + rulesCount, filterName, rules := parseFilterContents(body) + + if filterName != "" { + filter.Name = filterName + } + + // Check if the filter has been really changed + if reflect.DeepEqual(filter.Rules, rules) { + log.Printf("The filter %d text has not changed", filter.ID) + return false, nil + } + + log.Printf("Filter %d has been updated: %d bytes, %d rules", filter.ID, len(body), rulesCount) + filter.RulesCount = rulesCount + filter.Rules = rules + + return true, nil +} + +// saves filter contents to the file in dataDir +func (filter *filter) save() error { + filterFilePath := filter.Path() + log.Printf("Saving filter %d contents to: %s", filter.ID, filterFilePath) + body := []byte(strings.Join(filter.Rules, "\n")) + + return safeWriteFile(filterFilePath, body) +} + +// loads filter contents from the file in dataDir +func (filter *filter) load() error { + if !filter.Enabled { + // No need to load a filter that is not enabled + return nil + } + + filterFilePath := filter.Path() + log.Printf("Loading filter %d contents to: %s", filter.ID, filterFilePath) + + if _, err := os.Stat(filterFilePath); os.IsNotExist(err) { + // do nothing, file doesn't exist + return err + } + + filterFileContents, err := ioutil.ReadFile(filterFilePath) + if err != nil { + return err + } + + log.Printf("File %s, id %d, length %d", filterFilePath, filter.ID, len(filterFileContents)) + rulesCount, _, rules := parseFilterContents(filterFileContents) + + filter.RulesCount = rulesCount + filter.Rules = rules + + return nil +} + +// Path to the filter contents +func (filter *filter) Path() string { + return filepath.Join(config.ourBinaryDir, dataDir, filterDir, strconv.FormatInt(filter.ID, 10)+".txt") +} From 0d1478b635b2a01d77290d3672586ee382fe4b90 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 28 Nov 2018 20:15:18 +0300 Subject: [PATCH 15/58] Remove unused struct field --- config.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/config.go b/config.go index 21a75b42..bec31db2 100644 --- a/config.go +++ b/config.go @@ -41,8 +41,7 @@ type configuration struct { type coreDNSConfig struct { binaryFile string coreFile string - Filters []filter `yaml:"-"` - Port int `yaml:"port"` + Port int `yaml:"port"` dnsforward.FilteringConfig `yaml:",inline"` From 31f77af534a0e584480d31dcdc0d8ae763ded29a Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 28 Nov 2018 20:15:32 +0300 Subject: [PATCH 16/58] Move user filter saving into writeAllConfigs() --- config.go | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/config.go b/config.go index bec31db2..38bee5ed 100644 --- a/config.go +++ b/config.go @@ -146,6 +146,16 @@ func (c *configuration) write() error { return err } + return nil +} + +func writeAllConfigs() error { + err := config.write() + if err != nil { + log.Printf("Couldn't write config: %s", err) + return err + } + userFilter := userFilter() err = userFilter.save() if err != nil { @@ -155,8 +165,3 @@ func (c *configuration) write() error { return nil } - -func writeAllConfigs() error { - return config.write() -} - From e0286ee85d14222919f1fc130b78c086aca3b593 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 28 Nov 2018 20:19:09 +0300 Subject: [PATCH 17/58] Don't forget to give user filter to dns forwarding server --- coredns.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/coredns.go b/coredns.go index 348b4241..ce46a965 100644 --- a/coredns.go +++ b/coredns.go @@ -17,6 +17,11 @@ func isRunning() bool { func generateServerConfig() dnsforward.ServerConfig { filters := []dnsforward.Filter{} + userFilter := userFilter() + filters = append(filters, dnsforward.Filter{ + ID: userFilter.ID, + Rules: userFilter.Rules, + }) for _, filter := range config.Filters { filters = append(filters, dnsforward.Filter{ ID: filter.ID, From 9089122b56a7ce5ea788370492293e75ff88d1a1 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Thu, 29 Nov 2018 19:40:48 +0300 Subject: [PATCH 18/58] Compress the packed static js and css to save some space. --- Makefile | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 038602fc..d9753d3f 100644 --- a/Makefile +++ b/Makefile @@ -20,8 +20,11 @@ $(STATIC): $(JSFILES) client/node_modules npm --prefix client run build-prod $(TARGET): $(STATIC) *.go coredns_plugin/*.go dnsfilter/*.go - GOPATH=$(GOPATH) GOOS=$(NATIVE_GOOS) GOARCH=$(NATIVE_GOARCH) GO111MODULE=off go get -v github.com/gobuffalo/packr/... - GOPATH=$(GOPATH) PATH=$(GOPATH)/bin:$(PATH) packr build -ldflags="-X main.VersionString=$(GIT_VERSION)" -asmflags="-trimpath=$(PWD)" -gcflags="-trimpath=$(PWD)" -o $(TARGET) + go get -d . + GOOS=$(NATIVE_GOOS) GOARCH=$(NATIVE_GOARCH) GO111MODULE=off go get -v github.com/gobuffalo/packr/... + PATH=$(GOPATH)/bin:$(PATH) packr -z + CGO_ENABLED=0 go build -ldflags="-s -w -X main.VersionString=$(GIT_VERSION)" -asmflags="-trimpath=$(PWD)" -gcflags="-trimpath=$(PWD)" -o $(TARGET) + PATH=$(GOPATH)/bin:$(PATH) packr clean clean: $(MAKE) cleanfast From a6e0a1745417eae47850cb544405d3b439f79899 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Fri, 30 Nov 2018 13:17:34 +0300 Subject: [PATCH 19/58] dnsforward -- trim dot in the end of hostname, dnsfilter does not expect it there --- dnsforward/dnsforward.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index c087d17e..d9dfae5c 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -5,6 +5,7 @@ import ( "log" "net" "reflect" + "strings" "sync" "github.com/AdguardTeam/AdGuardHome/dnsfilter" @@ -406,7 +407,7 @@ func (s *Server) handlePacket(p []byte, addr net.Addr, conn *net.UDPConn) { return } } - host := msg.Question[0].Name + host := strings.TrimSuffix(msg.Question[0].Name, ".") res, err := s.dnsFilter.CheckHost(host) if err != nil { log.Printf("dnsfilter failed to check host '%s': %s", host, err) From 87c54ebd4c51dae349e8630448ac0a6ad6210efd Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Fri, 30 Nov 2018 13:24:42 +0300 Subject: [PATCH 20/58] Move Filter definition from dnsforward to dnsfilter, it belongs there. --- config.go | 9 +++++---- coredns.go | 7 ++++--- dnsfilter/dnsfilter.go | 5 +++++ dnsforward/dnsforward.go | 7 +------ filter.go | 6 +++--- 5 files changed, 18 insertions(+), 16 deletions(-) diff --git a/config.go b/config.go index 38bee5ed..f79c4684 100644 --- a/config.go +++ b/config.go @@ -7,6 +7,7 @@ import ( "path/filepath" "sync" + "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/dnsforward" "gopkg.in/yaml.v2" ) @@ -81,10 +82,10 @@ var config = configuration{ Prometheus: "prometheus :9153", }, Filters: []filter{ - {Filter: dnsforward.Filter{ID: 1}, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard Simplified Domain Names filter"}, - {Filter: dnsforward.Filter{ID: 2}, Enabled: false, URL: "https://adaway.org/hosts.txt", Name: "AdAway"}, - {Filter: dnsforward.Filter{ID: 3}, Enabled: false, URL: "https://hosts-file.net/ad_servers.txt", Name: "hpHosts - Ad and Tracking servers only"}, - {Filter: dnsforward.Filter{ID: 4}, Enabled: false, URL: "http://www.malwaredomainlist.com/hostslist/hosts.txt", Name: "MalwareDomainList.com Hosts List"}, + {Filter: dnsfilter.Filter{ID: 1}, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard Simplified Domain Names filter"}, + {Filter: dnsfilter.Filter{ID: 2}, Enabled: false, URL: "https://adaway.org/hosts.txt", Name: "AdAway"}, + {Filter: dnsfilter.Filter{ID: 3}, Enabled: false, URL: "https://hosts-file.net/ad_servers.txt", Name: "hpHosts - Ad and Tracking servers only"}, + {Filter: dnsfilter.Filter{ID: 4}, Enabled: false, URL: "http://www.malwaredomainlist.com/hostslist/hosts.txt", Name: "MalwareDomainList.com Hosts List"}, }, } diff --git a/coredns.go b/coredns.go index ce46a965..05199868 100644 --- a/coredns.go +++ b/coredns.go @@ -5,6 +5,7 @@ import ( "log" "net" + "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/dnsforward" "github.com/joomcode/errorx" ) @@ -16,14 +17,14 @@ func isRunning() bool { } func generateServerConfig() dnsforward.ServerConfig { - filters := []dnsforward.Filter{} + filters := []dnsfilter.Filter{} userFilter := userFilter() - filters = append(filters, dnsforward.Filter{ + filters = append(filters, dnsfilter.Filter{ ID: userFilter.ID, Rules: userFilter.Rules, }) for _, filter := range config.Filters { - filters = append(filters, dnsforward.Filter{ + filters = append(filters, dnsfilter.Filter{ ID: filter.ID, Rules: filter.Rules, }) diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index 3153f69f..f96139e4 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -113,6 +113,11 @@ type Dnsfilter struct { config config } +type Filter struct { + ID int64 `json:"id"` // auto-assigned when filter is added (see nextFilterID), json by default keeps ID uppercase but we need lowercase + Rules []string `json:"-" yaml:"-"` // not in yaml or json +} + //go:generate stringer -type=Reason // Reason holds an enum detailing why it was filtered or not filtered diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index d9dfae5c..7a2716e9 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -88,7 +88,7 @@ type ServerConfig struct { UDPListenAddr *net.UDPAddr // if nil, then default is is used (port 53 on *) BlockedResponseTTL uint32 // if 0, then default is used (3600) Upstreams []Upstream - Filters []Filter + Filters []dnsfilter.Filter FilteringConfig } @@ -119,11 +119,6 @@ var defaultValues = ServerConfig{ }, } -type Filter struct { - ID int64 `json:"id"` // auto-assigned when filter is added (see nextFilterID), json by default keeps ID uppercase but we need lowercase - Rules []string `json:"-" yaml:"-"` // not in yaml or json -} - // // packet loop // diff --git a/filter.go b/filter.go index 46e3a376..6a8a15c8 100644 --- a/filter.go +++ b/filter.go @@ -12,7 +12,7 @@ import ( "strings" "time" - "github.com/AdguardTeam/AdGuardHome/dnsforward" + "github.com/AdguardTeam/AdGuardHome/dnsfilter" ) var ( @@ -28,7 +28,7 @@ type filter struct { RulesCount int `json:"rulesCount" yaml:"-"` LastUpdated time.Time `json:"lastUpdated,omitempty" yaml:"last_updated,omitempty"` - dnsforward.Filter `yaml:",inline"` + dnsfilter.Filter `yaml:",inline"` } // Creates a helper object for working with the user rules @@ -36,7 +36,7 @@ func userFilter() filter { return filter{ // User filter always has constant ID=0 Enabled: true, - Filter: dnsforward.Filter{ + Filter: dnsfilter.Filter{ Rules: config.UserRules, }, } From ce615e185552f51dd58d815aa6770d8def90ea7d Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Fri, 30 Nov 2018 13:32:51 +0300 Subject: [PATCH 21/58] dnsfilter -- Get rid of accessors. --- coredns_plugin/coredns_plugin.go | 11 +++--- dnsfilter/dnsfilter.go | 65 ++++++++++++++------------------ 2 files changed, 34 insertions(+), 42 deletions(-) diff --git a/coredns_plugin/coredns_plugin.go b/coredns_plugin/coredns_plugin.go index f3a946dd..356902c6 100644 --- a/coredns_plugin/coredns_plugin.go +++ b/coredns_plugin/coredns_plugin.go @@ -88,7 +88,7 @@ func setupPlugin(c *caddy.Controller) (*plug, error) { switch blockValue { case "safebrowsing": log.Println("Browsing security service is enabled") - p.d.EnableSafeBrowsing() + p.d.SafeBrowsingEnabled = true if c.NextArg() { if len(c.Val()) == 0 { return nil, c.ArgErr() @@ -97,7 +97,7 @@ func setupPlugin(c *caddy.Controller) (*plug, error) { } case "safesearch": log.Println("Safe search is enabled") - p.d.EnableSafeSearch() + p.d.SafeSearchEnabled = true case "parental": if !c.NextArg() { return nil, c.ArgErr() @@ -108,10 +108,11 @@ func setupPlugin(c *caddy.Controller) (*plug, error) { } log.Println("Parental control is enabled") - err = p.d.EnableParental(sensitivity) - if err != nil { - return nil, c.ArgErr() + if !dnsfilter.IsParentalSensitivityValid(sensitivity) { + return nil, dnsfilter.ErrInvalidParental } + p.d.ParentalEnabled = true + p.d.ParentalSensitivity = sensitivity if c.NextArg() { if len(c.Val()) == 0 { return nil, c.ArgErr() diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index f96139e4..e0958cda 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -46,13 +46,17 @@ const shortcutLength = 6 // used for rule search optimization, 6 hits the sweet const enableFastLookup = true // flag for debugging, must be true in production for faster performance const enableDelayedCompilation = true // flag for debugging, must be true in production for faster performance -type config struct { - parentalServer string - parentalSensitivity int // must be either 3, 10, 13 or 17 - parentalEnabled bool - safeSearchEnabled bool - safeBrowsingEnabled bool - safeBrowsingServer string +// Config allows you to configure DNS filtering with New() or just change variables directly. +type Config struct { + ParentalSensitivity int `yaml:"parental_sensitivity"` // must be either 3, 10, 13 or 17 + ParentalEnabled bool `yaml:"parental_enabled"` + SafeSearchEnabled bool `yaml:"safesearch_enabled"` + SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"` +} + +type privateConfig struct { + parentalServer string // access via methods + safeBrowsingServer string // access via methods } type rule struct { @@ -110,7 +114,8 @@ type Dnsfilter struct { client http.Client // handle for http client -- single instance as recommended by docs transport *http.Transport // handle for http transport used by http client - config config + Config // for direct access by library users, even a = assignment + privateConfig } type Filter struct { @@ -176,7 +181,7 @@ func (d *Dnsfilter) CheckHost(host string) (Result, error) { } // check safebrowsing if no match - if d.config.safeBrowsingEnabled { + if d.SafeBrowsingEnabled { result, err = d.checkSafeBrowsing(host) if err != nil { // failed to do HTTP lookup -- treat it as if we got empty response, but don't save cache @@ -189,7 +194,7 @@ func (d *Dnsfilter) CheckHost(host string) (Result, error) { } // check parental if no match - if d.config.parentalEnabled { + if d.ParentalEnabled { result, err = d.checkParental(host) if err != nil { // failed to do HTTP lookup -- treat it as if we got empty response, but don't save cache @@ -574,11 +579,11 @@ func hostnameToHashParam(host string, addslash bool) (string, map[string]bool) { func (d *Dnsfilter) checkSafeBrowsing(host string) (Result, error) { // prevent recursion -- checking the host of safebrowsing server makes no sense - if host == d.config.safeBrowsingServer { + if host == d.safeBrowsingServer { return Result{}, nil } format := func(hashparam string) string { - url := fmt.Sprintf(defaultSafebrowsingURL, d.config.safeBrowsingServer, hashparam) + url := fmt.Sprintf(defaultSafebrowsingURL, d.safeBrowsingServer, hashparam) return url } handleBody := func(body []byte, hashes map[string]bool) (Result, error) { @@ -615,11 +620,11 @@ func (d *Dnsfilter) checkSafeBrowsing(host string) (Result, error) { func (d *Dnsfilter) checkParental(host string) (Result, error) { // prevent recursion -- checking the host of parental safety server makes no sense - if host == d.config.parentalServer { + if host == d.parentalServer { return Result{}, nil } format := func(hashparam string) string { - url := fmt.Sprintf(defaultParentalURL, d.config.parentalServer, hashparam, d.config.parentalSensitivity) + url := fmt.Sprintf(defaultParentalURL, d.parentalServer, hashparam, d.ParentalSensitivity) return url } handleBody := func(body []byte, hashes map[string]bool) (Result, error) { @@ -872,8 +877,8 @@ func New() *Dnsfilter { Transport: d.transport, Timeout: defaultHTTPTimeout, } - d.config.safeBrowsingServer = defaultSafebrowsingServer - d.config.parentalServer = defaultParentalServer + d.safeBrowsingServer = defaultSafebrowsingServer + d.parentalServer = defaultParentalServer return d } @@ -890,35 +895,21 @@ func (d *Dnsfilter) Destroy() { // config manipulation helpers // -// EnableSafeBrowsing turns on checking hostnames in malware/phishing database -func (d *Dnsfilter) EnableSafeBrowsing() { - d.config.safeBrowsingEnabled = true -} - -// EnableParental turns on checking hostnames for containing adult content -func (d *Dnsfilter) EnableParental(sensitivity int) error { +// IsParentalSensitivityValid checks if sensitivity is valid value +func IsParentalSensitivityValid(sensitivity int) bool { switch sensitivity { case 3, 10, 13, 17: - d.config.parentalSensitivity = sensitivity - d.config.parentalEnabled = true - return nil - default: - return ErrInvalidParental + return true } -} - -// EnableSafeSearch turns on enforcing safesearch in search engines -// only used in coredns plugin and requires caller to use SafeSearchDomain() -func (d *Dnsfilter) EnableSafeSearch() { - d.config.safeSearchEnabled = true + return false } // SetSafeBrowsingServer lets you optionally change hostname of safesearch lookup func (d *Dnsfilter) SetSafeBrowsingServer(host string) { if len(host) == 0 { - d.config.safeBrowsingServer = defaultSafebrowsingServer + d.safeBrowsingServer = defaultSafebrowsingServer } else { - d.config.safeBrowsingServer = host + d.safeBrowsingServer = host } } @@ -934,7 +925,7 @@ func (d *Dnsfilter) ResetHTTPTimeout() { // SafeSearchDomain returns replacement address for search engine func (d *Dnsfilter) SafeSearchDomain(host string) (string, bool) { - if d.config.safeSearchEnabled { + if d.SafeSearchEnabled { val, ok := safeSearchDomains[host] return val, ok } From 057db71f3b4ee10fc7674479c2f1f940f6a98e5c Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Fri, 30 Nov 2018 13:35:22 +0300 Subject: [PATCH 22/58] Get rid of duplicate variable definitions --- config.go | 7 +++---- coredns.go | 6 +++--- dnsforward/dnsforward.go | 23 ++++++++++------------- 3 files changed, 16 insertions(+), 20 deletions(-) diff --git a/config.go b/config.go index f79c4684..f9907feb 100644 --- a/config.go +++ b/config.go @@ -68,10 +68,9 @@ var config = configuration{ binaryFile: "coredns", // only filename, no path coreFile: "Corefile", // only filename, no path FilteringConfig: dnsforward.FilteringConfig{ - ProtectionEnabled: true, - FilteringEnabled: true, - SafeBrowsingEnabled: false, - BlockedResponseTTL: 10, // in seconds + ProtectionEnabled: true, // whether or not use any of dnsfilter features + FilteringEnabled: true, // whether or not use filter lists + BlockedResponseTTL: 10, // in seconds }, QueryLogEnabled: true, Ratelimit: 20, diff --git a/coredns.go b/coredns.go index 05199868..f6940abe 100644 --- a/coredns.go +++ b/coredns.go @@ -31,9 +31,9 @@ func generateServerConfig() dnsforward.ServerConfig { } newconfig := dnsforward.ServerConfig{ - UDPListenAddr: &net.UDPAddr{Port: config.CoreDNS.Port}, - BlockedResponseTTL: config.CoreDNS.BlockedResponseTTL, - Filters: filters, + UDPListenAddr: &net.UDPAddr{Port: config.CoreDNS.Port}, + FilteringConfig: config.CoreDNS.FilteringConfig, + Filters: filters, } for _, u := range config.CoreDNS.UpstreamDNS { diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index 7a2716e9..bf7566a9 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -74,28 +74,25 @@ func (s *Server) RUnlock() { */ type FilteringConfig struct { - ProtectionEnabled bool `yaml:"protection_enabled"` - FilteringEnabled bool `yaml:"filtering_enabled"` - SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"` - SafeSearchEnabled bool `yaml:"safesearch_enabled"` - ParentalEnabled bool `yaml:"parental_enabled"` - ParentalSensitivity int `yaml:"parental_sensitivity"` - BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` + ProtectionEnabled bool `yaml:"protection_enabled"` + FilteringEnabled bool `yaml:"filtering_enabled"` + BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` // if 0, then default is used (3600) + + dnsfilter.Config `yaml:",inline"` } // The zero ServerConfig is empty and ready for use. type ServerConfig struct { - UDPListenAddr *net.UDPAddr // if nil, then default is is used (port 53 on *) - BlockedResponseTTL uint32 // if 0, then default is used (3600) - Upstreams []Upstream - Filters []dnsfilter.Filter + UDPListenAddr *net.UDPAddr // if nil, then default is is used (port 53 on *) + Upstreams []Upstream + Filters []dnsfilter.Filter FilteringConfig } var defaultValues = ServerConfig{ - UDPListenAddr: &net.UDPAddr{Port: 53}, - BlockedResponseTTL: 3600, + UDPListenAddr: &net.UDPAddr{Port: 53}, + FilteringConfig: FilteringConfig{BlockedResponseTTL: 3600}, Upstreams: []Upstream{ //// dns over HTTPS // &dnsOverHTTPS{Address: "https://1.1.1.1/dns-query"}, From 3ecc0ee24bdfffde74fec6a3c2f6601b8efc481d Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Fri, 30 Nov 2018 13:43:01 +0300 Subject: [PATCH 23/58] Makefile -- don't dictate to go what binary to build, so it can append .exe if building for windows. --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index d9753d3f..4ff25301 100644 --- a/Makefile +++ b/Makefile @@ -23,7 +23,7 @@ $(TARGET): $(STATIC) *.go coredns_plugin/*.go dnsfilter/*.go go get -d . GOOS=$(NATIVE_GOOS) GOARCH=$(NATIVE_GOARCH) GO111MODULE=off go get -v github.com/gobuffalo/packr/... PATH=$(GOPATH)/bin:$(PATH) packr -z - CGO_ENABLED=0 go build -ldflags="-s -w -X main.VersionString=$(GIT_VERSION)" -asmflags="-trimpath=$(PWD)" -gcflags="-trimpath=$(PWD)" -o $(TARGET) + CGO_ENABLED=0 go build -ldflags="-s -w -X main.VersionString=$(GIT_VERSION)" -asmflags="-trimpath=$(PWD)" -gcflags="-trimpath=$(PWD)" PATH=$(GOPATH)/bin:$(PATH) packr clean clean: From e26837d9e87223a477fd8beb3151d98eeb52c5cd Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Fri, 30 Nov 2018 13:47:26 +0300 Subject: [PATCH 24/58] dnsfilter -- Add parameter to New() to supply optional initial config. --- coredns_plugin/coredns_plugin.go | 2 +- dnsfilter/dnsfilter.go | 5 ++++- dnsforward/dnsforward.go | 4 ++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/coredns_plugin/coredns_plugin.go b/coredns_plugin/coredns_plugin.go index 356902c6..8d302fd3 100644 --- a/coredns_plugin/coredns_plugin.go +++ b/coredns_plugin/coredns_plugin.go @@ -77,7 +77,7 @@ func setupPlugin(c *caddy.Controller) (*plug, error) { // create new Plugin and copy default values p := &plug{ settings: defaultPluginSettings, - d: dnsfilter.New(), + d: dnsfilter.New(nil), } log.Println("Initializing the CoreDNS plugin") diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index e0958cda..7fd01be3 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -856,7 +856,7 @@ func (d *Dnsfilter) matchHost(host string) (Result, error) { // // New creates properly initialized DNS Filter that is ready to be used -func New() *Dnsfilter { +func New(c *Config) *Dnsfilter { d := new(Dnsfilter) d.storage = make(map[string]bool) @@ -879,6 +879,9 @@ func New() *Dnsfilter { } d.safeBrowsingServer = defaultSafebrowsingServer d.parentalServer = defaultParentalServer + if c != nil { + d.Config = *c + } return d } diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index bf7566a9..3a2de123 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -178,7 +178,7 @@ func (s *Server) Start(config *ServerConfig) error { if s.dnsFilter == nil { log.Printf("Creating dnsfilter") - s.dnsFilter = dnsfilter.New() + s.dnsFilter = dnsfilter.New(nil) } go s.packetLoop() @@ -322,7 +322,7 @@ func (s *Server) reconfigureFilters(new ServerConfig) { return } - dnsFilter := dnsfilter.New() + dnsFilter := dnsfilter.New(&new.Config) // sets safebrowsing, safesearch and parental for _, f := range newFilters { for _, rule := range f.Rules { err := dnsFilter.AddRule(rule, f.ID) From 7d2e39ed52c4813b811fbc9117a238d813f96694 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Fri, 30 Nov 2018 13:48:53 +0300 Subject: [PATCH 25/58] dnsfilter -- Add a convinience function to add all rules from all filters. --- dnsfilter/dnsfilter.go | 18 ++++++++++++++++++ dnsforward/dnsforward.go | 16 ++++------------ 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index 7fd01be3..2c78a92b 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -737,6 +737,24 @@ func (d *Dnsfilter) lookupCommon(host string, lookupstats *LookupStats, cache gc // Adding rule and matching against the rules // +// AddRules is a convinience function to add an array of filters in one call +func (d *Dnsfilter) AddRules(filters []Filter) error { + for _, f := range filters { + for _, rule := range f.Rules { + err := d.AddRule(rule, f.ID) + if err == ErrAlreadyExists || err == ErrInvalidSyntax { + continue + } + if err != nil { + log.Printf("Cannot add rule %s: %s", rule, err) + // Just ignore invalid rules + continue + } + } + } + return nil +} + // AddRule adds a rule, checking if it is a valid rule first and if it wasn't added already func (d *Dnsfilter) AddRule(input string, filterListID int64) error { input = strings.TrimSpace(input) diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index 3a2de123..7d15d44a 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -323,18 +323,10 @@ func (s *Server) reconfigureFilters(new ServerConfig) { } dnsFilter := dnsfilter.New(&new.Config) // sets safebrowsing, safesearch and parental - for _, f := range newFilters { - for _, rule := range f.Rules { - err := dnsFilter.AddRule(rule, f.ID) - if err == dnsfilter.ErrAlreadyExists || err == dnsfilter.ErrInvalidSyntax { - continue - } - if err != nil { - log.Printf("Cannot add rule %s: %s", rule, err) - // Just ignore invalid rules - continue - } - } + + // add rules only if they are enabled + if new.FilteringEnabled { + dnsFilter.AddRules(newFilters) } s.Lock() From 2281b60ebbce1e68a804e7a01ccb7fca54617b22 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Fri, 30 Nov 2018 13:49:54 +0300 Subject: [PATCH 26/58] dnsfilter -- add trace() --- dnsfilter/helpers.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/dnsfilter/helpers.go b/dnsfilter/helpers.go index 68d4ba26..8152f402 100644 --- a/dnsfilter/helpers.go +++ b/dnsfilter/helpers.go @@ -1,6 +1,10 @@ package dnsfilter import ( + "fmt" + "os" + "path" + "runtime" "strings" "sync/atomic" ) @@ -58,3 +62,17 @@ func updateMax(valuePtr *int64, maxPtr *int64) { // swapping failed because value has changed after reading, try again } } + +func trace(format string, args ...interface{}) { + pc := make([]uintptr, 10) // at least 1 entry needed + runtime.Callers(2, pc) + f := runtime.FuncForPC(pc[0]) + var buf strings.Builder + buf.WriteString(fmt.Sprintf("%s(): ", path.Base(f.Name()))) + text := fmt.Sprintf(format, args...) + buf.WriteString(text) + if len(text) == 0 || text[len(text)-1] != '\n' { + buf.WriteRune('\n') + } + fmt.Fprint(os.Stderr, buf.String()) +} From 0502ef6cc7e52b9cdeb42e9462ddbb5294183a0f Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Fri, 30 Nov 2018 13:51:15 +0300 Subject: [PATCH 27/58] dnsforward -- initialize all dnsfilter settings at start and reconfigure --- dnsforward/dnsforward.go | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index 7d15d44a..cde93571 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -178,7 +178,12 @@ func (s *Server) Start(config *ServerConfig) error { if s.dnsFilter == nil { log.Printf("Creating dnsfilter") - s.dnsFilter = dnsfilter.New(nil) + dnsFilterConfig := &config.Config + s.dnsFilter = dnsfilter.New(dnsFilterConfig) + // add rules only if they are enabled + if s.FilteringEnabled { + s.dnsFilter.AddRules(s.Filters) + } } go s.packetLoop() @@ -308,7 +313,7 @@ func (s *Server) reconfigureUpstreams(new ServerConfig) { s.Upstreams = new.Upstreams } -func (s *Server) reconfigureFilters(new ServerConfig) { +func (s *Server) reconfigureFiltering(new ServerConfig) { newFilters := new.Filters if len(newFilters) == 0 { newFilters = defaultValues.Filters @@ -317,11 +322,22 @@ func (s *Server) reconfigureFilters(new ServerConfig) { if len(oldFilters) == 0 { oldFilters = defaultValues.Filters } - if reflect.DeepEqual(newFilters, oldFilters) { - // they're exactly the same, do nothing + + needUpdate := false + if !reflect.DeepEqual(newFilters, oldFilters) { + needUpdate = true + } + + if !reflect.DeepEqual(new.FilteringConfig, s.FilteringConfig) { + needUpdate = true + } + + if !needUpdate { + // nothing to do, everything is same return } + // TODO: instead of creating new dnsfilter, change existing one's settings and filters dnsFilter := dnsfilter.New(&new.Config) // sets safebrowsing, safesearch and parental // add rules only if they are enabled @@ -332,6 +348,7 @@ func (s *Server) reconfigureFilters(new ServerConfig) { s.Lock() oldDnsFilter := s.dnsFilter s.dnsFilter = dnsFilter + s.FilteringConfig = new.FilteringConfig s.Unlock() oldDnsFilter.Destroy() @@ -340,7 +357,7 @@ func (s *Server) reconfigureFilters(new ServerConfig) { func (s *Server) Reconfigure(new ServerConfig) error { s.reconfigureBlockedResponseTTL(new) s.reconfigureUpstreams(new) - s.reconfigureFilters(new) + s.reconfigureFiltering(new) err := s.reconfigureListenAddr(new) if err != nil { From 0a4627f4f0b9bd63168f4a708b050486fbb6b1db Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Fri, 30 Nov 2018 13:51:54 +0300 Subject: [PATCH 28/58] Fix engrish --- filter.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/filter.go b/filter.go index 6a8a15c8..1150d292 100644 --- a/filter.go +++ b/filter.go @@ -196,7 +196,7 @@ func (filter *filter) update(force bool) (bool, error) { // Check if the filter has been really changed if reflect.DeepEqual(filter.Rules, rules) { - log.Printf("The filter %d text has not changed", filter.ID) + log.Printf("Filter #%d at URL %s hasn't changed, not updating it", filter.ID, filter.URL) return false, nil } From be4b65fdcae719741b0df92cd87fe398b706114e Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Fri, 30 Nov 2018 14:04:17 +0300 Subject: [PATCH 29/58] dnsforward -- use dnsfilter before cache -- changed settings or filters would require cache invalidation otherwise --- dnsforward/dnsforward.go | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index cde93571..f60939c0 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -393,21 +393,7 @@ func (s *Server) handlePacket(p []byte, addr net.Addr, conn *net.UDPConn) { } } - { - val, ok := s.cache.Get(&msg) - if ok && val != nil { - err = s.respond(val, addr, conn) - if err != nil { - if isConnClosed(err) { - // ignore this error, the connection was closed and that's ok - return - } - log.Printf("Couldn't respond to UDP packet: %s", err) - return - } - return - } - } + // use dnsfilter before cache -- changed settings or filters would require cache invalidation otherwise host := strings.TrimSuffix(msg.Question[0].Name, ".") res, err := s.dnsFilter.CheckHost(host) if err != nil { @@ -426,9 +412,23 @@ func (s *Server) handlePacket(p []byte, addr net.Addr, conn *net.UDPConn) { } } + { + val, ok := s.cache.Get(&msg) + if ok && val != nil { + err = s.respond(val, addr, conn) + if err != nil { + if isConnClosed(err) { + // ignore this error, the connection was closed and that's ok + return + } + log.Printf("Couldn't respond to UDP packet: %s", err) + return + } + return + } + } + // TODO: replace with single-socket implementation - // TODO: replace 8.8.8.8:53 with configurable upstreams - // TODO: support DoH, DoT and TCP upstream := s.chooseUpstream() reply, err := upstream.Exchange(&msg) if err != nil { From 09fb539875da90951d6a663aa3e758145a642745 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Mon, 3 Dec 2018 12:51:37 +0300 Subject: [PATCH 30/58] Simplify two lines into one line. --- dnsforward/dnsforward.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index f60939c0..e5d2eb55 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -178,8 +178,7 @@ func (s *Server) Start(config *ServerConfig) error { if s.dnsFilter == nil { log.Printf("Creating dnsfilter") - dnsFilterConfig := &config.Config - s.dnsFilter = dnsfilter.New(dnsFilterConfig) + s.dnsFilter = dnsfilter.New(&s.Config) // add rules only if they are enabled if s.FilteringEnabled { s.dnsFilter.AddRules(s.Filters) From 8396dc2fdb90711ba3394b975347a262afe22ae7 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 5 Dec 2018 12:52:23 +0300 Subject: [PATCH 31/58] Update docs for formatting in godoc. --- dnsforward/dnsforward.go | 13 +++++++------ dnsforward/upstream_pool.go | 28 ++++++++++++++-------------- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index e5d2eb55..155255a8 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -13,13 +13,14 @@ import ( "github.com/miekg/dns" ) -// Server is the main way to start a DNS server +// Server is the main way to start a DNS server. +// // Example: -// s := dnsforward.Server{} -// err := s.Start(nil) // will start a DNS server listening on default port 53, in a goroutine -// err := s.Reconfigure(ServerConfig{UDPListenAddr: &net.UDPAddr{Port: 53535}}) // will reconfigure running DNS server to listen on UDP port 53535 -// err := s.Stop() // will stop listening on port 53535 and cancel all goroutines -// err := s.Start(nil) // will start listening again, on port 53535, in a goroutine +// s := dnsforward.Server{} +// err := s.Start(nil) // will start a DNS server listening on default port 53, in a goroutine +// err := s.Reconfigure(ServerConfig{UDPListenAddr: &net.UDPAddr{Port: 53535}}) // will reconfigure running DNS server to listen on UDP port 53535 +// err := s.Stop() // will stop listening on port 53535 and cancel all goroutines +// err := s.Start(nil) // will start listening again, on port 53535, in a goroutine // // The zero Server is empty and ready for use. type Server struct { diff --git a/dnsforward/upstream_pool.go b/dnsforward/upstream_pool.go index 9756d54f..f944e695 100644 --- a/dnsforward/upstream_pool.go +++ b/dnsforward/upstream_pool.go @@ -10,22 +10,22 @@ import ( "github.com/joomcode/errorx" ) -// upstream TLS pool. +// Upstream TLS pool. // // Example: -// pool := TLSPool{Address: "tls://1.1.1.1:853"} -// netConn, err := pool.Get() -// if err != nil {panic(err)} -// c := dns.Conn{Conn: netConn} -// q := dns.Msg{} -// q.SetQuestion("google.com.", dns.TypeA) -// log.Println(q) -// err = c.WriteMsg(&q) -// if err != nil {panic(err)} -// r, err := c.ReadMsg() -// if err != nil {panic(err)} -// log.Println(r) -// pool.Put(c.Conn) +// pool := TLSPool{Address: "tls://1.1.1.1:853"} +// netConn, err := pool.Get() +// if err != nil {panic(err)} +// c := dns.Conn{Conn: netConn} +// q := dns.Msg{} +// q.SetQuestion("google.com.", dns.TypeA) +// log.Println(q) +// err = c.WriteMsg(&q) +// if err != nil {panic(err)} +// r, err := c.ReadMsg() +// if err != nil {panic(err)} +// log.Println(r) +// pool.Put(c.Conn) type TLSPool struct { Address string parsedAddress *url.URL From e5d2f883ac13b195ece8e00aa386f61a64c5c669 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 5 Dec 2018 12:57:14 +0300 Subject: [PATCH 32/58] dnsforward -- Make Upstream interface give access to Address field. --- dnsforward/dnsforward.go | 28 ++++++++++++------------ dnsforward/upstream.go | 47 +++++++++++++++++++++++----------------- 2 files changed, 41 insertions(+), 34 deletions(-) diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index 155255a8..e362ee75 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -96,24 +96,24 @@ var defaultValues = ServerConfig{ FilteringConfig: FilteringConfig{BlockedResponseTTL: 3600}, Upstreams: []Upstream{ //// dns over HTTPS - // &dnsOverHTTPS{Address: "https://1.1.1.1/dns-query"}, - // &dnsOverHTTPS{Address: "https://dns.google.com/experimental"}, - // &dnsOverHTTPS{Address: "https://doh.cleanbrowsing.org/doh/security-filter/"}, - // &dnsOverHTTPS{Address: "https://dns10.quad9.net/dns-query"}, - // &dnsOverHTTPS{Address: "https://doh.powerdns.org"}, - // &dnsOverHTTPS{Address: "https://doh.securedns.eu/dns-query"}, + // &dnsOverHTTPS{address: "https://1.1.1.1/dns-query"}, + // &dnsOverHTTPS{address: "https://dns.google.com/experimental"}, + // &dnsOverHTTPS{address: "https://doh.cleanbrowsing.org/doh/security-filter/"}, + // &dnsOverHTTPS{address: "https://dns10.quad9.net/dns-query"}, + // &dnsOverHTTPS{address: "https://doh.powerdns.org"}, + // &dnsOverHTTPS{address: "https://doh.securedns.eu/dns-query"}, //// dns over TLS - // &dnsOverTLS{Address: "tls://8.8.8.8:853"}, - // &dnsOverTLS{Address: "tls://8.8.4.4:853"}, - // &dnsOverTLS{Address: "tls://1.1.1.1:853"}, - // &dnsOverTLS{Address: "tls://1.0.0.1:853"}, + // &dnsOverTLS{address: "tls://8.8.8.8:853"}, + // &dnsOverTLS{address: "tls://8.8.4.4:853"}, + // &dnsOverTLS{address: "tls://1.1.1.1:853"}, + // &dnsOverTLS{address: "tls://1.0.0.1:853"}, //// plainDNS - &plainDNS{Address: "8.8.8.8:53"}, - &plainDNS{Address: "8.8.4.4:53"}, - &plainDNS{Address: "1.1.1.1:53"}, - &plainDNS{Address: "1.0.0.1:53"}, + &plainDNS{address: "8.8.8.8:53"}, + &plainDNS{address: "8.8.4.4:53"}, + &plainDNS{address: "1.1.1.1:53"}, + &plainDNS{address: "1.0.0.1:53"}, }, } diff --git a/dnsforward/upstream.go b/dnsforward/upstream.go index b245e7f1..1ca27242 100644 --- a/dnsforward/upstream.go +++ b/dnsforward/upstream.go @@ -22,13 +22,14 @@ const defaultTimeout = time.Second * 10 type Upstream interface { Exchange(m *dns.Msg) (*dns.Msg, error) + Address() string } // // plain DNS // type plainDNS struct { - Address string + address string } var defaultUDPClient = dns.Client{ @@ -42,11 +43,13 @@ var defaultTCPClient = dns.Client{ Timeout: defaultTimeout, } +func (p *plainDNS) Address() string { return p.address } + func (p *plainDNS) Exchange(m *dns.Msg) (*dns.Msg, error) { - reply, _, err := defaultUDPClient.Exchange(m, p.Address) + reply, _, err := defaultUDPClient.Exchange(m, p.address) if err != nil && reply != nil && reply.Truncated { log.Printf("Truncated message was received, retrying over TCP, question: %s", m.Question[0].String()) - reply, _, err = defaultTCPClient.Exchange(m, p.Address) + reply, _, err = defaultTCPClient.Exchange(m, p.address) } return reply, err } @@ -55,7 +58,7 @@ func (p *plainDNS) Exchange(m *dns.Msg) (*dns.Msg, error) { // DNS-over-TLS // type dnsOverTLS struct { - Address string + address string pool *TLSPool sync.RWMutex // protects pool @@ -68,6 +71,8 @@ var defaultTLSClient = dns.Client{ TLSConfig: &tls.Config{}, } +func (p *dnsOverTLS) Address() string { return p.address } + func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) { var pool *TLSPool p.RLock() @@ -76,7 +81,7 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) { if pool == nil { p.Lock() // lazy initialize it - p.pool = &TLSPool{Address: p.Address} + p.pool = &TLSPool{Address: p.address} p.Unlock() } @@ -84,19 +89,19 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) { poolConn, err := p.pool.Get() p.RUnlock() if err != nil { - return nil, errorx.Decorate(err, "Failed to get a connection from TLSPool to %s", p.Address) + return nil, errorx.Decorate(err, "Failed to get a connection from TLSPool to %s", p.address) } c := dns.Conn{Conn: poolConn} err = c.WriteMsg(m) if err != nil { poolConn.Close() - return nil, errorx.Decorate(err, "Failed to send a request to %s", p.Address) + return nil, errorx.Decorate(err, "Failed to send a request to %s", p.address) } reply, err := c.ReadMsg() if err != nil { poolConn.Close() - return nil, errorx.Decorate(err, "Failed to read a request from %s", p.Address) + return nil, errorx.Decorate(err, "Failed to read a request from %s", p.address) } p.RLock() p.pool.Put(poolConn) @@ -108,7 +113,7 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) { // DNS-over-https // type dnsOverHTTPS struct { - Address string + address string } var defaultHTTPSTransport = http.Transport{} @@ -118,33 +123,35 @@ var defaultHTTPSClient = http.Client{ Timeout: defaultTimeout, } +func (p *dnsOverHTTPS) Address() string { return p.address } + func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (*dns.Msg, error) { buf, err := m.Pack() if err != nil { return nil, errorx.Decorate(err, "Couldn't pack request msg") } bb := bytes.NewBuffer(buf) - resp, err := http.Post(p.Address, "application/dns-message", bb) + resp, err := http.Post(p.address, "application/dns-message", bb) if resp != nil && resp.Body != nil { defer resp.Body.Close() } if err != nil { - return nil, errorx.Decorate(err, "Couldn't do a POST request to '%s'", p.Address) + return nil, errorx.Decorate(err, "Couldn't do a POST request to '%s'", p.address) } body, err := ioutil.ReadAll(resp.Body) if err != nil { - return nil, errorx.Decorate(err, "Couldn't read body contents for '%s'", p.Address) + return nil, errorx.Decorate(err, "Couldn't read body contents for '%s'", p.address) } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("Got an unexpected HTTP status code %d from '%s'", resp.StatusCode, p.Address) + return nil, fmt.Errorf("Got an unexpected HTTP status code %d from '%s'", resp.StatusCode, p.address) } if len(body) == 0 { - return nil, fmt.Errorf("Got an unexpected empty body from '%s'", p.Address) + return nil, fmt.Errorf("Got an unexpected empty body from '%s'", p.address) } response := dns.Msg{} err = response.Unpack(body) if err != nil { - return nil, errorx.Decorate(err, "Couldn't unpack DNS response from '%s': body is %s", p.Address, string(body)) + return nil, errorx.Decorate(err, "Couldn't unpack DNS response from '%s': body is %s", p.address, string(body)) } return &response, nil } @@ -176,20 +183,20 @@ func GetUpstream(address string) (Upstream, error) { if url.Port() == "" { url.Host += ":53" } - return &plainDNS{Address: url.String()}, nil + return &plainDNS{address: url.String()}, nil case "tls": if url.Port() == "" { url.Host += ":853" } - return &dnsOverTLS{Address: url.String()}, nil + return &dnsOverTLS{address: url.String()}, nil case "https": - return &dnsOverHTTPS{Address: url.String()}, nil + return &dnsOverHTTPS{address: url.String()}, nil default: // assume it's plain DNS if url.Port() == "" { url.Host += ":53" } - return &plainDNS{Address: url.String()}, nil + return &plainDNS{address: url.String()}, nil } } @@ -199,5 +206,5 @@ func GetUpstream(address string) (Upstream, error) { // doesn't have port, default to 53 address = net.JoinHostPort(address, "53") } - return &plainDNS{Address: address}, nil + return &plainDNS{address: address}, nil } From 052f9757623e2f100d25dd5b94fae8d236f2b688 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 5 Dec 2018 14:03:41 +0300 Subject: [PATCH 33/58] dnsforward -- Move querylog from coredns plugin, a more complex migration with proper API took too long so a simple move was used instead to save time. --- dnsforward/dnsforward.go | 132 ++++++------ dnsforward/querylog.go | 224 ++++++++++++++++++++ dnsforward/querylog_file.go | 253 +++++++++++++++++++++++ dnsforward/querylog_top.go | 365 +++++++++++++++++++++++++++++++++ dnsforward/stats.go | 398 ++++++++++++++++++++++++++++++++++++ 5 files changed, 1305 insertions(+), 67 deletions(-) create mode 100644 dnsforward/querylog.go create mode 100644 dnsforward/querylog_file.go create mode 100644 dnsforward/querylog_top.go create mode 100644 dnsforward/stats.go diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index e362ee75..a5a7193a 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -7,6 +7,7 @@ import ( "reflect" "strings" "sync" + "time" "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/joomcode/errorx" @@ -186,6 +187,19 @@ func (s *Server) Start(config *ServerConfig) error { } } + log.Printf("Loading stats from querylog") + err := fillStatsFromQueryLog() + if err != nil { + log.Printf("Failed to load stats from querylog: %s", err) + return err + } + + once.Do(func() { + go periodicQueryLogRotate() + go periodicHourlyTopRotate() + go statsRotator() + }) + go s.packetLoop() return nil @@ -346,12 +360,12 @@ func (s *Server) reconfigureFiltering(new ServerConfig) { } s.Lock() - oldDnsFilter := s.dnsFilter + oldDNSFilter := s.dnsFilter s.dnsFilter = dnsFilter s.FilteringConfig = new.FilteringConfig s.Unlock() - oldDnsFilter.Destroy() + oldDNSFilter.Destroy() } func (s *Server) Reconfigure(new ServerConfig) error { @@ -370,15 +384,13 @@ func (s *Server) Reconfigure(new ServerConfig) error { // packet handling functions // -func (s *Server) handlePacket(p []byte, addr net.Addr, conn *net.UDPConn) { +// handlePacketInternal processes the incoming packet bytes and returns with an optional response packet. +// +// If an empty dns.Msg is returned, do not try to send anything back to client, otherwise send contents of dns.Msg. +// +// If an error is returned, log it, don't try to generate data based on that error. +func (s *Server) handlePacketInternal(msg *dns.Msg, addr net.Addr, conn *net.UDPConn) (*dns.Msg, *dnsfilter.Result, Upstream, error) { // log.Printf("Got packet %d bytes from %s: %v", len(p), addr, p) - msg := dns.Msg{} - err := msg.Unpack(p) - if err != nil { - log.Printf("failed to unpack DNS packet: %s", err) - return - } - // // DNS packet byte format is valid // @@ -386,11 +398,7 @@ func (s *Server) handlePacket(p []byte, addr net.Addr, conn *net.UDPConn) { // log.Printf("Unpacked: %v", msg.String()) if len(msg.Question) != 1 { log.Printf("Got invalid number of questions: %v", len(msg.Question)) - err := s.respondWithServerFailure(&msg, addr, conn) - if err != nil { - log.Printf("Couldn't respond to UDP packet: %s", err) - return - } + return s.genServerFailure(msg), nil, nil, nil } // use dnsfilter before cache -- changed settings or filters would require cache invalidation otherwise @@ -398,73 +406,61 @@ func (s *Server) handlePacket(p []byte, addr net.Addr, conn *net.UDPConn) { res, err := s.dnsFilter.CheckHost(host) if err != nil { log.Printf("dnsfilter failed to check host '%s': %s", host, err) - err := s.respondWithServerFailure(&msg, addr, conn) - if err != nil { - log.Printf("Couldn't respond to UDP packet: %s", err) - return - } + return s.genServerFailure(msg), &res, nil, err } else if res.IsFiltered { log.Printf("Host %s is filtered, reason - '%s', matched rule: '%s'", host, res.Reason, res.Rule) - err := s.respondWithNXDomain(&msg, addr, conn) - if err != nil { - log.Printf("Couldn't respond to UDP packet: %s", err) - return - } + return s.genNXDomain(msg), &res, nil, nil } { - val, ok := s.cache.Get(&msg) + val, ok := s.cache.Get(msg) if ok && val != nil { - err = s.respond(val, addr, conn) - if err != nil { - if isConnClosed(err) { - // ignore this error, the connection was closed and that's ok - return - } - log.Printf("Couldn't respond to UDP packet: %s", err) - return - } - return + return val, &res, nil, nil } } // TODO: replace with single-socket implementation upstream := s.chooseUpstream() - reply, err := upstream.Exchange(&msg) + reply, err := upstream.Exchange(msg) if err != nil { log.Printf("talking to upstream failed for host '%s': %s", host, err) - err := s.respondWithServerFailure(&msg, addr, conn) - if err != nil { - if isConnClosed(err) { - // ignore this error, the connection was closed and that's ok - return - } - log.Printf("Couldn't respond to UDP packet with server failure: %s", err) - return - } - return + return s.genServerFailure(msg), &res, upstream, err } if reply == nil { log.Printf("SHOULD NOT HAPPEN upstream returned empty message for host '%s'. Request is %v", host, msg.String()) - err := s.respondWithServerFailure(&msg, addr, conn) - if err != nil { - log.Printf("Couldn't respond to UDP packet with should not happen: %s", err) - return - } - return - } - - err = s.respond(reply, addr, conn) - if err != nil { - if isConnClosed(err) { - // ignore this error, the connection was closed and that's ok - return - } - log.Printf("Couldn't respond to UDP packet: %s", err) - return + return s.genServerFailure(msg), &res, upstream, nil } s.cache.Set(reply) + + return reply, &res, upstream, nil +} + +func (s *Server) handlePacket(p []byte, addr net.Addr, conn *net.UDPConn) { + start := time.Now() + + msg := &dns.Msg{} + err := msg.Unpack(p) + if err != nil { + log.Printf("got invalid DNS packet: %s", err) + return // do nothing + } + + reply, result, upstream, err := s.handlePacketInternal(msg, addr, conn) + if reply != nil { + rerr := s.respond(reply, addr, conn) + if rerr != nil { + log.Printf("Couldn't respond to UDP packet: %s", err) + } + } + + // query logging and stats counters + elapsed := time.Since(start) + upstreamAddr := "" + if upstream != nil { + upstreamAddr = upstream.Address() + } + logRequest(msg, reply, result, elapsed, addr.String(), upstreamAddr) } // @@ -491,17 +487,17 @@ func (s *Server) respond(resp *dns.Msg, addr net.Addr, conn *net.UDPConn) error return nil } -func (s *Server) respondWithServerFailure(request *dns.Msg, addr net.Addr, conn *net.UDPConn) error { +func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg { resp := dns.Msg{} resp.SetRcode(request, dns.RcodeServerFailure) - return s.respond(&resp, addr, conn) + return &resp } -func (s *Server) respondWithNXDomain(request *dns.Msg, addr net.Addr, conn *net.UDPConn) error { +func (s *Server) genNXDomain(request *dns.Msg) *dns.Msg { resp := dns.Msg{} resp.SetRcode(request, dns.RcodeNameError) resp.Ns = s.genSOA(request) - return s.respond(&resp, addr, conn) + return &resp } func (s *Server) genSOA(request *dns.Msg) []dns.RR { @@ -537,3 +533,5 @@ func (s *Server) genSOA(request *dns.Msg) []dns.RR { } return []dns.RR{&soa} } + +var once sync.Once diff --git a/dnsforward/querylog.go b/dnsforward/querylog.go new file mode 100644 index 00000000..65af8b74 --- /dev/null +++ b/dnsforward/querylog.go @@ -0,0 +1,224 @@ +package dnsforward + +import ( + "encoding/json" + "fmt" + "log" + "net/http" + "strconv" + "strings" + "sync" + "time" + + "github.com/AdguardTeam/AdGuardHome/dnsfilter" + "github.com/miekg/dns" +) + +const ( + logBufferCap = 5000 // maximum capacity of logBuffer before it's flushed to disk + queryLogTimeLimit = time.Hour * 24 // how far in the past we care about querylogs + queryLogRotationPeriod = time.Hour * 24 // rotate the log every 24 hours + queryLogFileName = "querylog.json" // .gz added during compression + queryLogSize = 5000 // maximum API response for /querylog + queryLogTopSize = 500 // Keep in memory only top N values +) + +var ( + logBufferLock sync.RWMutex + logBuffer []*logEntry + + queryLogCache []*logEntry + queryLogLock sync.RWMutex +) + +type logEntry struct { + Question []byte + Answer []byte `json:",omitempty"` // sometimes empty answers happen like binerdunt.top or rev2.globalrootservers.net + Result *dnsfilter.Result + Time time.Time + Elapsed time.Duration + IP string + Upstream string `json:",omitempty"` // if empty, means it was cached +} + +func logRequest(question *dns.Msg, answer *dns.Msg, result *dnsfilter.Result, elapsed time.Duration, ip string, upstream string) { + var q []byte + var a []byte + var err error + + if question != nil { + q, err = question.Pack() + if err != nil { + log.Printf("failed to pack question for querylog: %s", err) + return + } + } + if answer != nil { + a, err = answer.Pack() + if err != nil { + log.Printf("failed to pack answer for querylog: %s", err) + return + } + } + + now := time.Now() + entry := logEntry{ + Question: q, + Answer: a, + Result: result, + Time: now, + Elapsed: elapsed, + IP: ip, + Upstream: upstream, + } + var flushBuffer []*logEntry + + logBufferLock.Lock() + logBuffer = append(logBuffer, &entry) + if len(logBuffer) >= logBufferCap { + flushBuffer = logBuffer + logBuffer = nil + } + logBufferLock.Unlock() + queryLogLock.Lock() + queryLogCache = append(queryLogCache, &entry) + if len(queryLogCache) > queryLogSize { + toremove := len(queryLogCache) - queryLogSize + queryLogCache = queryLogCache[toremove:] + } + queryLogLock.Unlock() + + // add it to running top + err = runningTop.addEntry(&entry, question, now) + if err != nil { + log.Printf("Failed to add entry to running top: %s", err) + // don't do failure, just log + } + + incrementCounters(&entry) + + // if buffer needs to be flushed to disk, do it now + if len(flushBuffer) > 0 { + // write to file + // do it in separate goroutine -- we are stalling DNS response this whole time + go flushToFile(flushBuffer) + } +} + +func HandleQueryLog(w http.ResponseWriter, r *http.Request) { + queryLogLock.RLock() + values := make([]*logEntry, len(queryLogCache)) + copy(values, queryLogCache) + queryLogLock.RUnlock() + + // reverse it so that newest is first + for left, right := 0, len(values)-1; left < right; left, right = left+1, right-1 { + values[left], values[right] = values[right], values[left] + } + + var data = []map[string]interface{}{} + for _, entry := range values { + var q *dns.Msg + var a *dns.Msg + + if len(entry.Question) > 0 { + q = new(dns.Msg) + if err := q.Unpack(entry.Question); err != nil { + // ignore, log and move on + log.Printf("Failed to unpack dns message question: %s", err) + q = nil + } + } + if len(entry.Answer) > 0 { + a = new(dns.Msg) + if err := a.Unpack(entry.Answer); err != nil { + // ignore, log and move on + log.Printf("Failed to unpack dns message question: %s", err) + a = nil + } + } + + jsonEntry := map[string]interface{}{ + "reason": entry.Result.Reason.String(), + "elapsedMs": strconv.FormatFloat(entry.Elapsed.Seconds()*1000, 'f', -1, 64), + "time": entry.Time.Format(time.RFC3339), + "client": entry.IP, + } + if q != nil { + jsonEntry["question"] = map[string]interface{}{ + "host": strings.ToLower(strings.TrimSuffix(q.Question[0].Name, ".")), + "type": dns.Type(q.Question[0].Qtype).String(), + "class": dns.Class(q.Question[0].Qclass).String(), + } + } + + if a != nil { + jsonEntry["status"] = dns.RcodeToString[a.Rcode] + } + if len(entry.Result.Rule) > 0 { + jsonEntry["rule"] = entry.Result.Rule + jsonEntry["filterId"] = entry.Result.FilterID + } + + if a != nil && len(a.Answer) > 0 { + var answers = []map[string]interface{}{} + for _, k := range a.Answer { + header := k.Header() + answer := map[string]interface{}{ + "type": dns.TypeToString[header.Rrtype], + "ttl": header.Ttl, + } + // try most common record types + switch v := k.(type) { + case *dns.A: + answer["value"] = v.A + case *dns.AAAA: + answer["value"] = v.AAAA + case *dns.MX: + answer["value"] = fmt.Sprintf("%v %v", v.Preference, v.Mx) + case *dns.CNAME: + answer["value"] = v.Target + case *dns.NS: + answer["value"] = v.Ns + case *dns.SPF: + answer["value"] = v.Txt + case *dns.TXT: + answer["value"] = v.Txt + case *dns.PTR: + answer["value"] = v.Ptr + case *dns.SOA: + answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v", v.Ns, v.Mbox, v.Serial, v.Refresh, v.Retry, v.Expire, v.Minttl) + case *dns.CAA: + answer["value"] = fmt.Sprintf("%v %v \"%v\"", v.Flag, v.Tag, v.Value) + case *dns.HINFO: + answer["value"] = fmt.Sprintf("\"%v\" \"%v\"", v.Cpu, v.Os) + case *dns.RRSIG: + answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v %v %v", dns.TypeToString[v.TypeCovered], v.Algorithm, v.Labels, v.OrigTtl, v.Expiration, v.Inception, v.KeyTag, v.SignerName, v.Signature) + default: + // type unknown, marshall it as-is + answer["value"] = v + } + answers = append(answers, answer) + } + jsonEntry["answer"] = answers + } + + data = append(data, jsonEntry) + } + + jsonVal, err := json.Marshal(data) + if err != nil { + errorText := fmt.Sprintf("Couldn't marshal data into json: %s", err) + log.Println(errorText) + http.Error(w, errorText, http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(jsonVal) + if err != nil { + errorText := fmt.Sprintf("Unable to write response json: %s", err) + log.Println(errorText) + http.Error(w, errorText, http.StatusInternalServerError) + } +} diff --git a/dnsforward/querylog_file.go b/dnsforward/querylog_file.go new file mode 100644 index 00000000..9ea8ef95 --- /dev/null +++ b/dnsforward/querylog_file.go @@ -0,0 +1,253 @@ +package dnsforward + +import ( + "bytes" + "compress/gzip" + "encoding/json" + "fmt" + "log" + "os" + "sync" + "time" + + "github.com/go-test/deep" +) + +var ( + fileWriteLock sync.Mutex +) + +const enableGzip = false + +func flushToFile(buffer []*logEntry) error { + if len(buffer) == 0 { + return nil + } + start := time.Now() + + var b bytes.Buffer + e := json.NewEncoder(&b) + for _, entry := range buffer { + err := e.Encode(entry) + if err != nil { + log.Printf("Failed to marshal entry: %s", err) + return err + } + } + + elapsed := time.Since(start) + log.Printf("%d elements serialized via json in %v: %d kB, %v/entry, %v/entry", len(buffer), elapsed, b.Len()/1024, float64(b.Len())/float64(len(buffer)), elapsed/time.Duration(len(buffer))) + + err := checkBuffer(buffer, b) + if err != nil { + log.Printf("failed to check buffer: %s", err) + return err + } + + var zb bytes.Buffer + filename := queryLogFileName + + // gzip enabled? + if enableGzip { + filename += ".gz" + + zw := gzip.NewWriter(&zb) + zw.Name = queryLogFileName + zw.ModTime = time.Now() + + _, err = zw.Write(b.Bytes()) + if err != nil { + log.Printf("Couldn't compress to gzip: %s", err) + zw.Close() + return err + } + + if err = zw.Close(); err != nil { + log.Printf("Couldn't close gzip writer: %s", err) + return err + } + } else { + zb = b + } + + fileWriteLock.Lock() + defer fileWriteLock.Unlock() + f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644) + if err != nil { + log.Printf("failed to create file \"%s\": %s", filename, err) + return err + } + defer f.Close() + + n, err := f.Write(zb.Bytes()) + if err != nil { + log.Printf("Couldn't write to file: %s", err) + return err + } + + log.Printf("ok \"%s\": %v bytes written", filename, n) + + return nil +} + +func checkBuffer(buffer []*logEntry, b bytes.Buffer) error { + l := len(buffer) + d := json.NewDecoder(&b) + + i := 0 + for d.More() { + entry := &logEntry{} + err := d.Decode(entry) + if err != nil { + log.Printf("Failed to decode: %s", err) + return err + } + if diff := deep.Equal(entry, buffer[i]); diff != nil { + log.Printf("decoded buffer differs: %s", diff) + return fmt.Errorf("decoded buffer differs: %s", diff) + } + i++ + } + if i != l { + err := fmt.Errorf("check fail: %d vs %d entries", l, i) + log.Print(err) + return err + } + log.Printf("check ok: %d entries", i) + + return nil +} + +func rotateQueryLog() error { + from := queryLogFileName + to := queryLogFileName + ".1" + + if enableGzip { + from = queryLogFileName + ".gz" + to = queryLogFileName + ".gz.1" + } + + if _, err := os.Stat(from); os.IsNotExist(err) { + // do nothing, file doesn't exist + return nil + } + + err := os.Rename(from, to) + if err != nil { + log.Printf("Failed to rename querylog: %s", err) + return err + } + + log.Printf("Rotated from %s to %s successfully", from, to) + + return nil +} + +func periodicQueryLogRotate() { + for range time.Tick(queryLogRotationPeriod) { + err := rotateQueryLog() + if err != nil { + log.Printf("Failed to rotate querylog: %s", err) + // do nothing, continue rotating + } + } +} + +func genericLoader(onEntry func(entry *logEntry) error, needMore func() bool, timeWindow time.Duration) error { + now := time.Now() + // read from querylog files, try newest file first + files := []string{} + + if enableGzip { + files = []string{ + queryLogFileName + ".gz", + queryLogFileName + ".gz.1", + } + } else { + files = []string{ + queryLogFileName, + queryLogFileName + ".1", + } + } + + // read from all files + for _, file := range files { + if !needMore() { + break + } + if _, err := os.Stat(file); os.IsNotExist(err) { + // do nothing, file doesn't exist + continue + } + + f, err := os.Open(file) + if err != nil { + log.Printf("Failed to open file \"%s\": %s", file, err) + // try next file + continue + } + defer f.Close() + + var d *json.Decoder + + if enableGzip { + trace("Creating gzip reader") + zr, err := gzip.NewReader(f) + if err != nil { + log.Printf("Failed to create gzip reader: %s", err) + continue + } + defer zr.Close() + + trace("Creating json decoder") + d = json.NewDecoder(zr) + } else { + d = json.NewDecoder(f) + } + + i := 0 + over := 0 + max := 10000 * time.Second + var sum time.Duration + // entries on file are in oldest->newest order + // we want maxLen newest + for d.More() { + if !needMore() { + break + } + var entry logEntry + err := d.Decode(&entry) + if err != nil { + log.Printf("Failed to decode: %s", err) + // next entry can be fine, try more + continue + } + + if now.Sub(entry.Time) > timeWindow { + // trace("skipping entry") // debug logging + continue + } + + if entry.Elapsed > max { + over++ + } else { + sum += entry.Elapsed + } + + i++ + err = onEntry(&entry) + if err != nil { + return err + } + } + elapsed := time.Since(now) + var perunit time.Duration + var avg time.Duration + if i > 0 { + perunit = elapsed / time.Duration(i) + avg = sum / time.Duration(i) + } + log.Printf("file \"%s\": read %d entries in %v, %v/entry, %v over %v, %v avg", file, i, elapsed, perunit, over, max, avg) + } + return nil +} diff --git a/dnsforward/querylog_top.go b/dnsforward/querylog_top.go new file mode 100644 index 00000000..b78dea79 --- /dev/null +++ b/dnsforward/querylog_top.go @@ -0,0 +1,365 @@ +package dnsforward + +import ( + "bytes" + "fmt" + "log" + "net/http" + "os" + "path" + "runtime" + "sort" + "strconv" + "strings" + "sync" + "time" + + "github.com/bluele/gcache" + "github.com/miekg/dns" +) + +type hourTop struct { + domains gcache.Cache + blocked gcache.Cache + clients gcache.Cache + + mutex sync.RWMutex +} + +func (top *hourTop) init() { + top.domains = gcache.New(queryLogTopSize).LRU().Build() + top.blocked = gcache.New(queryLogTopSize).LRU().Build() + top.clients = gcache.New(queryLogTopSize).LRU().Build() +} + +type dayTop struct { + hours []*hourTop + hoursLock sync.RWMutex // writelock this lock ONLY WHEN rotating or intializing hours! + + loaded bool + loadedLock sync.Mutex +} + +var runningTop dayTop + +func init() { + runningTop.hoursWriteLock() + for i := 0; i < 24; i++ { + hour := hourTop{} + hour.init() + runningTop.hours = append(runningTop.hours, &hour) + } + runningTop.hoursWriteUnlock() +} + +func rotateHourlyTop() { + log.Printf("Rotating hourly top") + hour := &hourTop{} + hour.init() + runningTop.hoursWriteLock() + runningTop.hours = append([]*hourTop{hour}, runningTop.hours...) + runningTop.hours = runningTop.hours[:24] + runningTop.hoursWriteUnlock() +} + +func periodicHourlyTopRotate() { + t := time.Hour + for range time.Tick(t) { + rotateHourlyTop() + } +} + +func (top *hourTop) incrementValue(key string, cache gcache.Cache) error { + top.Lock() + defer top.Unlock() + ivalue, err := cache.Get(key) + if err == gcache.KeyNotFoundError { + // we just set it and we're done + err = cache.Set(key, 1) + if err != nil { + log.Printf("Failed to set hourly top value: %s", err) + return err + } + return nil + } + + if err != nil { + log.Printf("gcache encountered an error during get: %s", err) + return err + } + + cachedValue, ok := ivalue.(int) + if !ok { + err = fmt.Errorf("SHOULD NOT HAPPEN: gcache has non-int as value: %v", ivalue) + log.Println(err) + return err + } + + err = cache.Set(key, cachedValue+1) + if err != nil { + log.Printf("Failed to set hourly top value: %s", err) + return err + } + return nil +} + +func (top *hourTop) incrementDomains(key string) error { + return top.incrementValue(key, top.domains) +} + +func (top *hourTop) incrementBlocked(key string) error { + return top.incrementValue(key, top.blocked) +} + +func (top *hourTop) incrementClients(key string) error { + return top.incrementValue(key, top.clients) +} + +// if does not exist -- return 0 +func (top *hourTop) lockedGetValue(key string, cache gcache.Cache) (int, error) { + ivalue, err := cache.Get(key) + if err == gcache.KeyNotFoundError { + return 0, nil + } + + if err != nil { + log.Printf("gcache encountered an error during get: %s", err) + return 0, err + } + + value, ok := ivalue.(int) + if !ok { + err := fmt.Errorf("SHOULD NOT HAPPEN: gcache has non-int as value: %v", ivalue) + log.Println(err) + return 0, err + } + + return value, nil +} + +func (top *hourTop) lockedGetDomains(key string) (int, error) { + return top.lockedGetValue(key, top.domains) +} + +func (top *hourTop) lockedGetBlocked(key string) (int, error) { + return top.lockedGetValue(key, top.blocked) +} + +func (top *hourTop) lockedGetClients(key string) (int, error) { + return top.lockedGetValue(key, top.clients) +} + +func (r *dayTop) addEntry(entry *logEntry, q *dns.Msg, now time.Time) error { + // figure out which hour bucket it belongs to + hour := int(now.Sub(entry.Time).Hours()) + if hour >= 24 { + log.Printf("t %v is >24 hours ago, ignoring", entry.Time) + return nil + } + + hostname := strings.ToLower(strings.TrimSuffix(q.Question[0].Name, ".")) + + // get value, if not set, crate one + runningTop.hoursReadLock() + defer runningTop.hoursReadUnlock() + err := runningTop.hours[hour].incrementDomains(hostname) + if err != nil { + log.Printf("Failed to increment value: %s", err) + return err + } + + if entry.Result.IsFiltered { + err := runningTop.hours[hour].incrementBlocked(hostname) + if err != nil { + log.Printf("Failed to increment value: %s", err) + return err + } + } + + if len(entry.IP) > 0 { + err := runningTop.hours[hour].incrementClients(entry.IP) + if err != nil { + log.Printf("Failed to increment value: %s", err) + return err + } + } + + return nil +} + +func fillStatsFromQueryLog() error { + now := time.Now() + runningTop.loadedWriteLock() + defer runningTop.loadedWriteUnlock() + if runningTop.loaded { + return nil + } + onEntry := func(entry *logEntry) error { + if len(entry.Question) == 0 { + log.Printf("entry question is absent, skipping") + return nil + } + + if entry.Time.After(now) { + log.Printf("t %v vs %v is in the future, ignoring", entry.Time, now) + return nil + } + + q := new(dns.Msg) + if err := q.Unpack(entry.Question); err != nil { + log.Printf("failed to unpack dns message question: %s", err) + return err + } + + if len(q.Question) != 1 { + log.Printf("malformed dns message, has no questions, skipping") + return nil + } + + err := runningTop.addEntry(entry, q, now) + if err != nil { + log.Printf("Failed to add entry to running top: %s", err) + return err + } + + queryLogLock.Lock() + queryLogCache = append(queryLogCache, entry) + if len(queryLogCache) > queryLogSize { + toremove := len(queryLogCache) - queryLogSize + queryLogCache = queryLogCache[toremove:] + } + queryLogLock.Unlock() + + incrementCounters(entry) + + return nil + } + + needMore := func() bool { return true } + err := genericLoader(onEntry, needMore, queryLogTimeLimit) + if err != nil { + log.Printf("Failed to load entries from querylog: %s", err) + return err + } + + runningTop.loaded = true + + return nil +} + +func HandleStatsTop(w http.ResponseWriter, r *http.Request) { + domains := map[string]int{} + blocked := map[string]int{} + clients := map[string]int{} + + do := func(keys []interface{}, getter func(key string) (int, error), result map[string]int) { + for _, ikey := range keys { + key, ok := ikey.(string) + if !ok { + continue + } + value, err := getter(key) + if err != nil { + log.Printf("Failed to get top domains value for %v: %s", key, err) + return + } + result[key] += value + } + } + + runningTop.hoursReadLock() + for hour := 0; hour < 24; hour++ { + runningTop.hours[hour].RLock() + do(runningTop.hours[hour].domains.Keys(), runningTop.hours[hour].lockedGetDomains, domains) + do(runningTop.hours[hour].blocked.Keys(), runningTop.hours[hour].lockedGetBlocked, blocked) + do(runningTop.hours[hour].clients.Keys(), runningTop.hours[hour].lockedGetClients, clients) + runningTop.hours[hour].RUnlock() + } + runningTop.hoursReadUnlock() + + // use manual json marshalling because we want maps to be sorted by value + json := bytes.Buffer{} + json.WriteString("{\n") + + gen := func(json *bytes.Buffer, name string, top map[string]int, addComma bool) { + json.WriteString(" ") + json.WriteString(fmt.Sprintf("%q", name)) + json.WriteString(": {\n") + sorted := sortByValue(top) + // no more than 50 entries + if len(sorted) > 50 { + sorted = sorted[:50] + } + for i, key := range sorted { + json.WriteString(" ") + json.WriteString(fmt.Sprintf("%q", key)) + json.WriteString(": ") + json.WriteString(strconv.Itoa(top[key])) + if i+1 != len(sorted) { + json.WriteByte(',') + } + json.WriteByte('\n') + } + json.WriteString(" }") + if addComma { + json.WriteByte(',') + } + json.WriteByte('\n') + } + gen(&json, "top_queried_domains", domains, true) + gen(&json, "top_blocked_domains", blocked, true) + gen(&json, "top_clients", clients, true) + json.WriteString(" \"stats_period\": \"24 hours\"\n") + json.WriteString("}\n") + + w.Header().Set("Content-Type", "application/json") + _, err := w.Write(json.Bytes()) + if err != nil { + errortext := fmt.Sprintf("Couldn't write body: %s", err) + log.Println(errortext) + http.Error(w, errortext, http.StatusInternalServerError) + } +} + +// helper function for querylog API +func sortByValue(m map[string]int) []string { + type kv struct { + k string + v int + } + var ss []kv + for k, v := range m { + ss = append(ss, kv{k, v}) + } + sort.Slice(ss, func(l, r int) bool { + return ss[l].v > ss[r].v + }) + + sorted := []string{} + for _, v := range ss { + sorted = append(sorted, v.k) + } + return sorted +} + +func (d *dayTop) hoursWriteLock() { tracelock(); d.hoursLock.Lock() } +func (d *dayTop) hoursWriteUnlock() { tracelock(); d.hoursLock.Unlock() } +func (d *dayTop) hoursReadLock() { tracelock(); d.hoursLock.RLock() } +func (d *dayTop) hoursReadUnlock() { tracelock(); d.hoursLock.RUnlock() } +func (d *dayTop) loadedWriteLock() { tracelock(); d.loadedLock.Lock() } +func (d *dayTop) loadedWriteUnlock() { tracelock(); d.loadedLock.Unlock() } + +func (h *hourTop) Lock() { tracelock(); h.mutex.Lock() } +func (h *hourTop) RLock() { tracelock(); h.mutex.RLock() } +func (h *hourTop) RUnlock() { tracelock(); h.mutex.RUnlock() } +func (h *hourTop) Unlock() { tracelock(); h.mutex.Unlock() } + +func tracelock() { + if false { // not commented out to make code checked during compilation + pc := make([]uintptr, 10) // at least 1 entry needed + runtime.Callers(2, pc) + f := path.Base(runtime.FuncForPC(pc[1]).Name()) + lockf := path.Base(runtime.FuncForPC(pc[0]).Name()) + fmt.Fprintf(os.Stderr, "%s(): %s\n", f, lockf) + } +} diff --git a/dnsforward/stats.go b/dnsforward/stats.go new file mode 100644 index 00000000..9809e6e7 --- /dev/null +++ b/dnsforward/stats.go @@ -0,0 +1,398 @@ +package dnsforward + +import ( + "encoding/json" + "fmt" + "log" + "net/http" + "sync" + "time" + + "github.com/AdguardTeam/AdGuardHome/dnsfilter" +) + +var ( + requests = newDNSCounter("requests_total") + filtered = newDNSCounter("filtered_total") + filteredLists = newDNSCounter("filtered_lists_total") + filteredSafebrowsing = newDNSCounter("filtered_safebrowsing_total") + filteredParental = newDNSCounter("filtered_parental_total") + filteredInvalid = newDNSCounter("filtered_invalid_total") + whitelisted = newDNSCounter("whitelisted_total") + safesearch = newDNSCounter("safesearch_total") + errorsTotal = newDNSCounter("errors_total") + elapsedTime = newDNSHistogram("request_duration") +) + +// entries for single time period (for example all per-second entries) +type statsEntries map[string][statsHistoryElements]float64 + +// how far back to keep the stats +const statsHistoryElements = 60 + 1 // +1 for calculating delta + +// each periodic stat is a map of arrays +type periodicStats struct { + Entries statsEntries + period time.Duration // how long one entry lasts + LastRotate time.Time // last time this data was rotated + + sync.RWMutex +} + +type stats struct { + PerSecond periodicStats + PerMinute periodicStats + PerHour periodicStats + PerDay periodicStats +} + +// per-second/per-minute/per-hour/per-day stats +var statistics stats + +func initPeriodicStats(periodic *periodicStats, period time.Duration) { + periodic.Entries = statsEntries{} + periodic.LastRotate = time.Now() + periodic.period = period +} + +func init() { + purgeStats() +} + +func purgeStats() { + initPeriodicStats(&statistics.PerSecond, time.Second) + initPeriodicStats(&statistics.PerMinute, time.Minute) + initPeriodicStats(&statistics.PerHour, time.Hour) + initPeriodicStats(&statistics.PerDay, time.Hour*24) +} + +func (p *periodicStats) Inc(name string, when time.Time) { + // calculate how many periods ago this happened + elapsed := int64(time.Since(when) / p.period) + // trace("%s: %v as %v -> [%v]", name, time.Since(when), p.period, elapsed) + if elapsed >= statsHistoryElements { + return // outside of our timeframe + } + p.Lock() + currentValues := p.Entries[name] + currentValues[elapsed]++ + p.Entries[name] = currentValues + p.Unlock() +} + +func (p *periodicStats) Observe(name string, when time.Time, value float64) { + // calculate how many periods ago this happened + elapsed := int64(time.Since(when) / p.period) + // trace("%s: %v as %v -> [%v]", name, time.Since(when), p.period, elapsed) + if elapsed >= statsHistoryElements { + return // outside of our timeframe + } + p.Lock() + { + countname := name + "_count" + currentValues := p.Entries[countname] + value := currentValues[elapsed] + // trace("Will change p.Entries[%s][%d] from %v to %v", countname, elapsed, value, value+1) + value += 1 + currentValues[elapsed] = value + p.Entries[countname] = currentValues + } + { + totalname := name + "_sum" + currentValues := p.Entries[totalname] + currentValues[elapsed] += value + p.Entries[totalname] = currentValues + } + p.Unlock() +} + +func (p *periodicStats) statsRotate(now time.Time) { + p.Lock() + rotations := int64(now.Sub(p.LastRotate) / p.period) + if rotations > statsHistoryElements { + rotations = statsHistoryElements + } + // calculate how many times we should rotate + for r := int64(0); r < rotations; r++ { + for key, values := range p.Entries { + newValues := [statsHistoryElements]float64{} + for i := 1; i < len(values); i++ { + newValues[i] = values[i-1] + } + p.Entries[key] = newValues + } + } + if rotations > 0 { + p.LastRotate = now + } + p.Unlock() +} + +func statsRotator() { + for range time.Tick(time.Second) { + now := time.Now() + statistics.PerSecond.statsRotate(now) + statistics.PerMinute.statsRotate(now) + statistics.PerHour.statsRotate(now) + statistics.PerDay.statsRotate(now) + } +} + +// counter that wraps around prometheus Counter but also adds to periodic stats +type counter struct { + name string // used as key in periodic stats + value int64 +} + +func newDNSCounter(name string) *counter { + // trace("called") + return &counter{ + name: name, + } +} + +func (c *counter) IncWithTime(when time.Time) { + statistics.PerSecond.Inc(c.name, when) + statistics.PerMinute.Inc(c.name, when) + statistics.PerHour.Inc(c.name, when) + statistics.PerDay.Inc(c.name, when) + c.value++ +} + +func (c *counter) Inc() { + c.IncWithTime(time.Now()) +} + +type histogram struct { + name string // used as key in periodic stats + count int64 + total float64 +} + +func newDNSHistogram(name string) *histogram { + return &histogram{ + name: name, + } +} + +func (h *histogram) ObserveWithTime(value float64, when time.Time) { + statistics.PerSecond.Observe(h.name, when, value) + statistics.PerMinute.Observe(h.name, when, value) + statistics.PerHour.Observe(h.name, when, value) + statistics.PerDay.Observe(h.name, when, value) + h.count++ + h.total += value +} + +func (h *histogram) Observe(value float64) { + h.ObserveWithTime(value, time.Now()) +} + +// ----- +// stats +// ----- +func incrementCounters(entry *logEntry) { + requests.IncWithTime(entry.Time) + if entry.Result.IsFiltered { + filtered.IncWithTime(entry.Time) + } + switch entry.Result.Reason { + case dnsfilter.NotFilteredWhiteList: + whitelisted.IncWithTime(entry.Time) + case dnsfilter.NotFilteredError: + errorsTotal.IncWithTime(entry.Time) + case dnsfilter.FilteredBlackList: + filteredLists.IncWithTime(entry.Time) + case dnsfilter.FilteredSafeBrowsing: + filteredSafebrowsing.IncWithTime(entry.Time) + case dnsfilter.FilteredParental: + filteredParental.IncWithTime(entry.Time) + case dnsfilter.FilteredInvalid: + // do nothing + case dnsfilter.FilteredSafeSearch: + safesearch.IncWithTime(entry.Time) + } + elapsedTime.ObserveWithTime(entry.Elapsed.Seconds(), entry.Time) +} + +func HandleStats(w http.ResponseWriter, r *http.Request) { + const numHours = 24 + histrical := generateMapFromStats(&statistics.PerHour, 0, numHours) + // sum them up + summed := map[string]interface{}{} + for key, values := range histrical { + summedValue := 0.0 + floats, ok := values.([]float64) + if !ok { + continue + } + for _, v := range floats { + summedValue += v + } + summed[key] = summedValue + } + // don't forget to divide by number of elements in returned slice + if val, ok := summed["avg_processing_time"]; ok { + if flval, flok := val.(float64); flok { + flval /= numHours + summed["avg_processing_time"] = flval + } + } + + summed["stats_period"] = "24 hours" + + json, err := json.Marshal(summed) + if err != nil { + errortext := fmt.Sprintf("Unable to marshal status json: %s", err) + log.Println(errortext) + http.Error(w, errortext, 500) + return + } + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(json) + if err != nil { + errortext := fmt.Sprintf("Unable to write response json: %s", err) + log.Println(errortext) + http.Error(w, errortext, 500) + return + } +} + +func generateMapFromStats(stats *periodicStats, start int, end int) map[string]interface{} { + // clamp + start = clamp(start, 0, statsHistoryElements) + end = clamp(end, 0, statsHistoryElements) + + avgProcessingTime := make([]float64, 0) + + count := getReversedSlice(stats.Entries[elapsedTime.name+"_count"], start, end) + sum := getReversedSlice(stats.Entries[elapsedTime.name+"_sum"], start, end) + for i := 0; i < len(count); i++ { + var avg float64 + if count[i] != 0 { + avg = sum[i] / count[i] + avg *= 1000 + } + avgProcessingTime = append(avgProcessingTime, avg) + } + + result := map[string]interface{}{ + "dns_queries": getReversedSlice(stats.Entries[requests.name], start, end), + "blocked_filtering": getReversedSlice(stats.Entries[filtered.name], start, end), + "replaced_safebrowsing": getReversedSlice(stats.Entries[filteredSafebrowsing.name], start, end), + "replaced_safesearch": getReversedSlice(stats.Entries[safesearch.name], start, end), + "replaced_parental": getReversedSlice(stats.Entries[filteredParental.name], start, end), + "avg_processing_time": avgProcessingTime, + } + return result +} + +func HandleStatsHistory(w http.ResponseWriter, r *http.Request) { + // handle time unit and prepare our time window size + now := time.Now() + timeUnitString := r.URL.Query().Get("time_unit") + var stats *periodicStats + var timeUnit time.Duration + switch timeUnitString { + case "seconds": + timeUnit = time.Second + stats = &statistics.PerSecond + case "minutes": + timeUnit = time.Minute + stats = &statistics.PerMinute + case "hours": + timeUnit = time.Hour + stats = &statistics.PerHour + case "days": + timeUnit = time.Hour * 24 + stats = &statistics.PerDay + default: + http.Error(w, "Must specify valid time_unit parameter", 400) + return + } + + // parse start and end time + startTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("start_time")) + if err != nil { + errortext := fmt.Sprintf("Must specify valid start_time parameter: %s", err) + log.Println(errortext) + http.Error(w, errortext, 400) + return + } + endTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("end_time")) + if err != nil { + errortext := fmt.Sprintf("Must specify valid end_time parameter: %s", err) + log.Println(errortext) + http.Error(w, errortext, 400) + return + } + + // check if start and time times are within supported time range + timeRange := timeUnit * statsHistoryElements + if startTime.Add(timeRange).Before(now) { + http.Error(w, "start_time parameter is outside of supported range", 501) + return + } + if endTime.Add(timeRange).Before(now) { + http.Error(w, "end_time parameter is outside of supported range", 501) + return + } + + // calculate start and end of our array + // basically it's how many hours/minutes/etc have passed since now + start := int(now.Sub(endTime) / timeUnit) + end := int(now.Sub(startTime) / timeUnit) + + // swap them around if they're inverted + if start > end { + start, end = end, start + } + + data := generateMapFromStats(stats, start, end) + json, err := json.Marshal(data) + if err != nil { + errortext := fmt.Sprintf("Unable to marshal status json: %s", err) + log.Println(errortext) + http.Error(w, errortext, 500) + return + } + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(json) + if err != nil { + errortext := fmt.Sprintf("Unable to write response json: %s", err) + log.Println(errortext) + http.Error(w, errortext, 500) + return + } +} + +func HandleStatsReset(w http.ResponseWriter, r *http.Request) { + purgeStats() + _, err := fmt.Fprintf(w, "OK\n") + if err != nil { + errortext := fmt.Sprintf("Couldn't write body: %s", err) + log.Println(errortext) + http.Error(w, errortext, http.StatusInternalServerError) + } +} + +func clamp(value, low, high int) int { + if value < low { + return low + } + if value > high { + return high + } + return value +} + +// -------------------------- +// helper functions for stats +// -------------------------- +func getReversedSlice(input [statsHistoryElements]float64, start int, end int) []float64 { + output := make([]float64, 0) + for i := start; i <= end; i++ { + output = append([]float64{input[i]}, output...) + } + return output +} From e35762074016c270d6df2613017ed5a08bfa0749 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 5 Dec 2018 14:21:25 +0300 Subject: [PATCH 34/58] Plug correct stats handler functions. --- control.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/control.go b/control.go index f41655c3..9b57ed06 100644 --- a/control.go +++ b/control.go @@ -11,9 +11,9 @@ import ( "strings" "time" + "github.com/AdguardTeam/AdGuardHome/dnsforward" "github.com/AdguardTeam/AdGuardHome/upstream" - corednsplugin "github.com/AdguardTeam/AdGuardHome/coredns_plugin" "gopkg.in/asaskevich/govalidator.v4" ) @@ -676,17 +676,17 @@ func registerControlHandlers() { http.HandleFunc("/control/status", optionalAuth(ensureGET(handleStatus))) http.HandleFunc("/control/enable_protection", optionalAuth(ensurePOST(handleProtectionEnable))) http.HandleFunc("/control/disable_protection", optionalAuth(ensurePOST(handleProtectionDisable))) - http.HandleFunc("/control/querylog", optionalAuth(ensureGET(corednsplugin.HandleQueryLog))) + http.HandleFunc("/control/querylog", optionalAuth(ensureGET(dnsforward.HandleQueryLog))) http.HandleFunc("/control/querylog_enable", optionalAuth(ensurePOST(handleQueryLogEnable))) http.HandleFunc("/control/querylog_disable", optionalAuth(ensurePOST(handleQueryLogDisable))) http.HandleFunc("/control/set_upstream_dns", optionalAuth(ensurePOST(handleSetUpstreamDNS))) http.HandleFunc("/control/test_upstream_dns", optionalAuth(ensurePOST(handleTestUpstreamDNS))) http.HandleFunc("/control/i18n/change_language", optionalAuth(ensurePOST(handleI18nChangeLanguage))) http.HandleFunc("/control/i18n/current_language", optionalAuth(ensureGET(handleI18nCurrentLanguage))) - http.HandleFunc("/control/stats_top", optionalAuth(ensureGET(corednsplugin.HandleStatsTop))) - http.HandleFunc("/control/stats", optionalAuth(ensureGET(corednsplugin.HandleStats))) - http.HandleFunc("/control/stats_history", optionalAuth(ensureGET(corednsplugin.HandleStatsHistory))) - http.HandleFunc("/control/stats_reset", optionalAuth(ensurePOST(corednsplugin.HandleStatsReset))) + http.HandleFunc("/control/stats_top", optionalAuth(ensureGET(dnsforward.HandleStatsTop))) + http.HandleFunc("/control/stats", optionalAuth(ensureGET(dnsforward.HandleStats))) + http.HandleFunc("/control/stats_history", optionalAuth(ensureGET(dnsforward.HandleStatsHistory))) + http.HandleFunc("/control/stats_reset", optionalAuth(ensurePOST(dnsforward.HandleStatsReset))) http.HandleFunc("/control/version.json", optionalAuth(handleGetVersionJSON)) http.HandleFunc("/control/filtering/enable", optionalAuth(ensurePOST(handleFilteringEnable))) http.HandleFunc("/control/filtering/disable", optionalAuth(ensurePOST(handleFilteringDisable))) From 9b43e07d7f31edff63e76bbad879e015ab3df21b Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 5 Dec 2018 14:21:48 +0300 Subject: [PATCH 35/58] dnsforward -- flush querylog to file on server stop --- dnsforward/dnsforward.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index a5a7193a..508011fc 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -215,6 +215,18 @@ func (s *Server) Stop() error { return errorx.Decorate(err, "Couldn't close UDP listening socket") } } + + // flush remainder to file + logBufferLock.Lock() + flushBuffer := logBuffer + logBuffer = nil + logBufferLock.Unlock() + err := flushToFile(flushBuffer) + if err != nil { + log.Printf("Saving querylog to file failed: %s", err) + return err + } + return nil } From b0149972cca4a40d24cef3b1571bfa0537ccb9fe Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 5 Dec 2018 14:22:09 +0300 Subject: [PATCH 36/58] dnsforward -- give only ip address to querylog, without port --- dnsforward/dnsforward.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index 508011fc..4b731d45 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -472,7 +472,11 @@ func (s *Server) handlePacket(p []byte, addr net.Addr, conn *net.UDPConn) { if upstream != nil { upstreamAddr = upstream.Address() } - logRequest(msg, reply, result, elapsed, addr.String(), upstreamAddr) + host, _, err := net.SplitHostPort(addr.String()) + if err != nil { + log.Printf("Failed to split %v into host/port: %s", addr, err) + } + logRequest(msg, reply, result, elapsed, host, upstreamAddr) } // From 7ddc71006ba78dcbe7c3a6e7cdd133f2d5602943 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 5 Dec 2018 15:36:18 +0300 Subject: [PATCH 37/58] stop DNS server properly when interrupted with ctrl+c, SIGTERM, SIGHUP or SIGQUIT --- app.go | 20 ++++++++++++++++---- coredns.go | 13 +++++++++++++ 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/app.go b/app.go index f958c7b4..10469245 100644 --- a/app.go +++ b/app.go @@ -7,8 +7,10 @@ import ( "net" "net/http" "os" + "os/signal" "path/filepath" "strconv" + "syscall" "time" "github.com/gobuffalo/packr" @@ -164,10 +166,13 @@ func main() { } }() - // Eat all args so that coredns can start happily - if len(os.Args) > 1 { - os.Args = os.Args[:1] - } + signal_channel := make(chan os.Signal) + signal.Notify(signal_channel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT) + go func() { + <-signal_channel + cleanup() + os.Exit(0) + }() // Save the updated config err := config.write() @@ -192,6 +197,13 @@ func main() { log.Fatal(http.ListenAndServe(address, nil)) } +func cleanup() { + err := stopDNSServer() + if err != nil { + log.Printf("Couldn't stop DNS server: %s", err) + } +} + func getInput() (string, error) { scanner := bufio.NewScanner(os.Stdin) scanner.Scan() diff --git a/coredns.go b/coredns.go index f6940abe..119b2371 100644 --- a/coredns.go +++ b/coredns.go @@ -74,3 +74,16 @@ func reconfigureDNSServer() error { return nil } + +func stopDNSServer() error { + if !isRunning() { + return fmt.Errorf("Refusing to stop forwarding DNS server: not running") + } + + err := dnsServer.Stop() + if err != nil { + return errorx.Decorate(err, "Couldn't stop forwarding DNS server") + } + + return nil +} From 15f0dee71913f6174bfa923fc3a99062f615d859 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 5 Dec 2018 16:25:33 +0300 Subject: [PATCH 38/58] readme -- Cleanup, mention that coredns was removed. --- README.md | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 17cec80c..37ee39d3 100644 --- a/README.md +++ b/README.md @@ -151,7 +151,15 @@ cd AdGuardHome make ``` -## How to update translations +## Contributing + +You are welcome to fork this repository, make your changes and submit a pull request — https://github.com/AdguardTeam/AdGuardHome/pulls + +### How to update translations + +If you want to help with AdGuard Home translations, please learn more about translating AdGuard products here: https://kb.adguard.com/en/general/adguard-translations + +Here is a direct link to AdGuard Home project: http://translate.adguard.com/collaboration/project?id=153384 Before updating translations you need to install dependencies: ``` @@ -181,14 +189,6 @@ node upload.js node download.js ``` -## Contributing - -You are welcome to fork this repository, make your changes and submit a pull request — https://github.com/AdguardTeam/AdGuardHome/pulls - -If you want to help with AdGuard Home translations, please learn more about translating AdGuard products here: https://kb.adguard.com/en/general/adguard-translations - -Here is a direct link to AdGuard Home project: http://translate.adguard.com/collaboration/project?id=153384 - ## Reporting issues If you run into any problem or have a suggestion, head to [this page](https://github.com/AdguardTeam/AdGuardHome/issues) and click on the `New issue` button. @@ -198,7 +198,6 @@ If you run into any problem or have a suggestion, head to [this page](https://gi This software wouldn't have been possible without: * [Go](https://golang.org/dl/) and it's libraries: - * [CoreDNS](https://coredns.io) * [packr](https://github.com/gobuffalo/packr) * [gcache](https://github.com/bluele/gcache) * [miekg's dns](https://github.com/miekg/dns) @@ -209,4 +208,6 @@ This software wouldn't have been possible without: * And many more node.js packages. * [whotracks.me data](https://github.com/cliqz-oss/whotracks.me) +You might have seen coredns mentioned here before, but we've stopped using it in AdGuardHome. While we still use it on our servers, it seemed like an overkill and impeded with Home features that we wanted to implement. + For a full list of all node.js packages in use, please take a look at [client/package.json](https://github.com/AdguardTeam/AdGuardHome/blob/master/client/package.json) file. From 478ce03386b92c265f63742b44c3b456e534d593 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 5 Dec 2018 18:47:03 +0300 Subject: [PATCH 39/58] dnsforward -- implement ratelimit and refuseany --- config.go | 27 ++++++------ dnsforward/dnsforward.go | 66 +++++++++++++++++++++++------ dnsforward/ratelimit.go | 80 ++++++++++++++++++++++++++++++++++++ dnsforward/ratelimit_test.go | 42 +++++++++++++++++++ 4 files changed, 187 insertions(+), 28 deletions(-) create mode 100644 dnsforward/ratelimit.go create mode 100644 dnsforward/ratelimit_test.go diff --git a/config.go b/config.go index f9907feb..52efdfd9 100644 --- a/config.go +++ b/config.go @@ -46,14 +46,11 @@ type coreDNSConfig struct { dnsforward.FilteringConfig `yaml:",inline"` - QueryLogEnabled bool `yaml:"querylog_enabled"` - Ratelimit int `yaml:"ratelimit"` - RefuseAny bool `yaml:"refuse_any"` - Pprof string `yaml:"-"` - Cache string `yaml:"-"` - Prometheus string `yaml:"-"` - BootstrapDNS string `yaml:"bootstrap_dns"` - UpstreamDNS []string `yaml:"upstream_dns"` + Pprof string `yaml:"-"` + Cache string `yaml:"-"` + Prometheus string `yaml:"-"` + BootstrapDNS string `yaml:"bootstrap_dns"` + UpstreamDNS []string `yaml:"upstream_dns"` } var defaultDNS = []string{"tls://1.1.1.1", "tls://1.0.0.1"} @@ -71,14 +68,14 @@ var config = configuration{ ProtectionEnabled: true, // whether or not use any of dnsfilter features FilteringEnabled: true, // whether or not use filter lists BlockedResponseTTL: 10, // in seconds + QueryLogEnabled: true, + Ratelimit: 20, + RefuseAny: true, }, - QueryLogEnabled: true, - Ratelimit: 20, - RefuseAny: true, - BootstrapDNS: "8.8.8.8:53", - UpstreamDNS: defaultDNS, - Cache: "cache", - Prometheus: "prometheus :9153", + BootstrapDNS: "8.8.8.8:53", + UpstreamDNS: defaultDNS, + Cache: "cache", + Prometheus: "prometheus :9153", }, Filters: []filter{ {Filter: dnsfilter.Filter{ID: 1}, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard Simplified Domain Names filter"}, diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index 4b731d45..bee85d3a 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -12,6 +12,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/joomcode/errorx" "github.com/miekg/dns" + gocache "github.com/patrickmn/go-cache" ) // Server is the main way to start a DNS server. @@ -31,6 +32,8 @@ type Server struct { cache cache + ratelimitBuckets *gocache.Cache // where the ratelimiters are stored, per IP + sync.RWMutex ServerConfig } @@ -76,9 +79,13 @@ func (s *Server) RUnlock() { */ type FilteringConfig struct { - ProtectionEnabled bool `yaml:"protection_enabled"` - FilteringEnabled bool `yaml:"filtering_enabled"` - BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` // if 0, then default is used (3600) + ProtectionEnabled bool `yaml:"protection_enabled"` + FilteringEnabled bool `yaml:"filtering_enabled"` + BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` // if 0, then default is used (3600) + QueryLogEnabled bool `yaml:"querylog_enabled"` + Ratelimit int `yaml:"ratelimit"` + RatelimitWhitelist []string `yaml:"ratelimit_whitelist"` + RefuseAny bool `yaml:"refuse_any"` dnsfilter.Config `yaml:",inline"` } @@ -92,6 +99,7 @@ type ServerConfig struct { FilteringConfig } +// if any of ServerConfig values are zero, then default values from below are used var defaultValues = ServerConfig{ UDPListenAddr: &net.UDPAddr{Port: 53}, FilteringConfig: FilteringConfig{BlockedResponseTTL: 3600}, @@ -413,6 +421,10 @@ func (s *Server) handlePacketInternal(msg *dns.Msg, addr net.Addr, conn *net.UDP return s.genServerFailure(msg), nil, nil, nil } + if msg.Question[0].Qtype == dns.TypeANY && s.RefuseAny { + return s.genNotImpl(msg), nil, nil, nil + } + // use dnsfilter before cache -- changed settings or filters would require cache invalidation otherwise host := strings.TrimSuffix(msg.Question[0].Name, ".") res, err := s.dnsFilter.CheckHost(host) @@ -450,16 +462,36 @@ func (s *Server) handlePacketInternal(msg *dns.Msg, addr net.Addr, conn *net.UDP func (s *Server) handlePacket(p []byte, addr net.Addr, conn *net.UDPConn) { start := time.Now() + ip, _, err := net.SplitHostPort(addr.String()) + if err != nil { + log.Printf("Failed to split %v into host/port: %s", addr, err) + // not a fatal error, move on + } + + // ratelimit based on IP only, protects CPU cycles and outbound connections + if s.isRatelimited(ip) { + // log.Printf("Ratelimiting %s based on IP only", ip) + return // do nothing, don't reply, we got ratelimited + } msg := &dns.Msg{} - err := msg.Unpack(p) + err = msg.Unpack(p) if err != nil { log.Printf("got invalid DNS packet: %s", err) return // do nothing } reply, result, upstream, err := s.handlePacketInternal(msg, addr, conn) + if reply != nil { + // ratelimit based on reply size now + replysize := reply.Len() + if s.isRatelimitedForReply(ip, replysize) { + log.Printf("Ratelimiting %s based on IP and size %d", ip, replysize) + return // do nothing, don't reply, we got ratelimited + } + + // we're good to respond rerr := s.respond(reply, addr, conn) if rerr != nil { log.Printf("Couldn't respond to UDP packet: %s", err) @@ -467,16 +499,14 @@ func (s *Server) handlePacket(p []byte, addr net.Addr, conn *net.UDPConn) { } // query logging and stats counters - elapsed := time.Since(start) - upstreamAddr := "" - if upstream != nil { - upstreamAddr = upstream.Address() + if s.QueryLogEnabled { + elapsed := time.Since(start) + upstreamAddr := "" + if upstream != nil { + upstreamAddr = upstream.Address() + } + logRequest(msg, reply, result, elapsed, ip, upstreamAddr) } - host, _, err := net.SplitHostPort(addr.String()) - if err != nil { - log.Printf("Failed to split %v into host/port: %s", addr, err) - } - logRequest(msg, reply, result, elapsed, host, upstreamAddr) } // @@ -506,12 +536,22 @@ func (s *Server) respond(resp *dns.Msg, addr net.Addr, conn *net.UDPConn) error func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg { resp := dns.Msg{} resp.SetRcode(request, dns.RcodeServerFailure) + resp.RecursionAvailable = true + return &resp +} + +func (s *Server) genNotImpl(request *dns.Msg) *dns.Msg { + resp := dns.Msg{} + resp.SetRcode(request, dns.RcodeNotImplemented) + resp.RecursionAvailable = true + resp.SetEdns0(1452, false) // NOTIMPL without EDNS is treated as 'we don't support EDNS', so explicitly set it return &resp } func (s *Server) genNXDomain(request *dns.Msg) *dns.Msg { resp := dns.Msg{} resp.SetRcode(request, dns.RcodeNameError) + resp.RecursionAvailable = true resp.Ns = s.genSOA(request) return &resp } diff --git a/dnsforward/ratelimit.go b/dnsforward/ratelimit.go new file mode 100644 index 00000000..9ea8d216 --- /dev/null +++ b/dnsforward/ratelimit.go @@ -0,0 +1,80 @@ +package dnsforward + +import ( + "log" + "sort" + "time" + + "github.com/beefsack/go-rate" + gocache "github.com/patrickmn/go-cache" +) + +func (s *Server) limiterForIP(ip string) interface{} { + if s.ratelimitBuckets == nil { + s.ratelimitBuckets = gocache.New(time.Hour, time.Hour) + } + + // check if ratelimiter for that IP already exists, if not, create + value, found := s.ratelimitBuckets.Get(ip) + if !found { + value = rate.New(s.Ratelimit, time.Second) + s.ratelimitBuckets.Set(ip, value, time.Hour) + } + + return value +} + +func (s *Server) isRatelimited(ip string) bool { + if s.Ratelimit == 0 { // 0 -- disabled + return false + } + if len(s.RatelimitWhitelist) > 0 { + i := sort.SearchStrings(s.RatelimitWhitelist, ip) + + if i < len(s.RatelimitWhitelist) && s.RatelimitWhitelist[i] == ip { + // found, don't ratelimit + return false + } + } + + value := s.limiterForIP(ip) + rl, ok := value.(*rate.RateLimiter) + if !ok { + log.Println("SHOULD NOT HAPPEN: non-bool entry found in safebrowsing lookup cache") + return false + } + + allow, _ := rl.Try() + return !allow +} + +func (s *Server) isRatelimitedForReply(ip string, size int) bool { + if s.Ratelimit == 0 { // 0 -- disabled + return false + } + if len(s.RatelimitWhitelist) > 0 { + i := sort.SearchStrings(s.RatelimitWhitelist, ip) + + if i < len(s.RatelimitWhitelist) && s.RatelimitWhitelist[i] == ip { + // found, don't ratelimit + return false + } + } + + value := s.limiterForIP(ip) + rl, ok := value.(*rate.RateLimiter) + if !ok { + log.Println("SHOULD NOT HAPPEN: non-bool entry found in safebrowsing lookup cache") + return false + } + + // For large UDP responses we try more times, effectively limiting per bandwidth + // The exact number of times depends on the response size + for i := 0; i < size/1000; i++ { + allow, _ := rl.Try() + if !allow { // not allowed -> ratelimited + return true + } + } + return false +} diff --git a/dnsforward/ratelimit_test.go b/dnsforward/ratelimit_test.go new file mode 100644 index 00000000..ed6f5ce9 --- /dev/null +++ b/dnsforward/ratelimit_test.go @@ -0,0 +1,42 @@ +package dnsforward + +import ( + "testing" +) + +func TestRatelimiting(t *testing.T) { + // rate limit is 1 per sec + p := Server{} + p.Ratelimit = 1 + + limited := p.isRatelimited("127.0.0.1") + + if limited { + t.Fatal("First request must have been allowed") + } + + limited = p.isRatelimited("127.0.0.1") + + if !limited { + t.Fatal("Second request must have been ratelimited") + } +} + +func TestWhitelist(t *testing.T) { + // rate limit is 1 per sec with whitelist + p := Server{} + p.Ratelimit = 1 + p.RatelimitWhitelist = []string{"127.0.0.1", "127.0.0.2", "127.0.0.125"} + + limited := p.isRatelimited("127.0.0.1") + + if limited { + t.Fatal("First request must have been allowed") + } + + limited = p.isRatelimited("127.0.0.1") + + if limited { + t.Fatal("Second request must have been allowed due to whitelist") + } +} From f6942213c8c57f5721512fb220cd553415237df6 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 5 Dec 2018 19:17:17 +0300 Subject: [PATCH 40/58] Use dnsforward for checking if upstream DNS server is working. --- control.go | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/control.go b/control.go index 9b57ed06..0d57816d 100644 --- a/control.go +++ b/control.go @@ -5,6 +5,7 @@ import ( "fmt" "io/ioutil" "log" + "net" "net/http" "os" "strconv" @@ -12,7 +13,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/dnsforward" - "github.com/AdguardTeam/AdGuardHome/upstream" + "github.com/miekg/dns" "gopkg.in/asaskevich/govalidator.v4" ) @@ -202,23 +203,32 @@ func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { } func checkDNS(input string) error { - u, err := upstream.NewUpstream(input, config.CoreDNS.BootstrapDNS) - + log.Printf("Checking if DNS %s works...", input) + u, err := dnsforward.GetUpstream(input) if err != nil { - return err + return fmt.Errorf("Failed to choose upstream for %s: %s", input, err) } - defer u.Close() - - alive, err := upstream.IsAlive(u) + req := dns.Msg{} + req.Id = dns.Id() + req.RecursionDesired = true + req.Question = []dns.Question{ + {Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + } + reply, err := u.Exchange(&req) if err != nil { return fmt.Errorf("couldn't communicate with DNS server %s: %s", input, err) } - - if !alive { - return fmt.Errorf("DNS server has not passed the healthcheck: %s", input) + if len(reply.Answer) != 1 { + return fmt.Errorf("DNS server %s returned wrong answer", input) + } + if t, ok := reply.Answer[0].(*dns.A); ok { + if !net.IPv4(8, 8, 8, 8).Equal(t.A) { + return fmt.Errorf("DNS server %s returned wrong answer: %v", input, t.A) + } } + log.Printf("DNS %s works OK", input) return nil } From b84786631026b78e3d04ed3ab42555aabbc9b88c Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 5 Dec 2018 19:18:58 +0300 Subject: [PATCH 41/58] Remove unused code. Goodbye CoreDNS. --- coredns_plugin/coredns_plugin.go | 558 --------------------- coredns_plugin/coredns_plugin_test.go | 131 ----- coredns_plugin/coredns_stats.go | 410 --------------- coredns_plugin/querylog.go | 239 --------- coredns_plugin/querylog_file.go | 291 ----------- coredns_plugin/querylog_top.go | 386 -------------- coredns_plugin/ratelimit/ratelimit.go | 182 ------- coredns_plugin/ratelimit/ratelimit_test.go | 80 --- coredns_plugin/refuseany/refuseany.go | 91 ---- coredns_plugin/reload.go | 36 -- go.mod | 16 - go.sum | 50 -- upstream/dns_upstream.go | 105 ---- upstream/helpers.go | 98 ---- upstream/https_upstream.go | 128 ----- upstream/persistent.go | 210 -------- upstream/setup.go | 81 --- upstream/setup_test.go | 29 -- upstream/upstream.go | 57 --- upstream/upstream_test.go | 187 ------- 20 files changed, 3365 deletions(-) delete mode 100644 coredns_plugin/coredns_plugin.go delete mode 100644 coredns_plugin/coredns_plugin_test.go delete mode 100644 coredns_plugin/coredns_stats.go delete mode 100644 coredns_plugin/querylog.go delete mode 100644 coredns_plugin/querylog_file.go delete mode 100644 coredns_plugin/querylog_top.go delete mode 100644 coredns_plugin/ratelimit/ratelimit.go delete mode 100644 coredns_plugin/ratelimit/ratelimit_test.go delete mode 100644 coredns_plugin/refuseany/refuseany.go delete mode 100644 coredns_plugin/reload.go delete mode 100644 upstream/dns_upstream.go delete mode 100644 upstream/helpers.go delete mode 100644 upstream/https_upstream.go delete mode 100644 upstream/persistent.go delete mode 100644 upstream/setup.go delete mode 100644 upstream/setup_test.go delete mode 100644 upstream/upstream.go delete mode 100644 upstream/upstream_test.go diff --git a/coredns_plugin/coredns_plugin.go b/coredns_plugin/coredns_plugin.go deleted file mode 100644 index 8d302fd3..00000000 --- a/coredns_plugin/coredns_plugin.go +++ /dev/null @@ -1,558 +0,0 @@ -package dnsfilter - -import ( - "bufio" - "errors" - "fmt" - "log" - "net" - "os" - "strconv" - "strings" - "sync" - "time" - - "github.com/AdguardTeam/AdGuardHome/dnsfilter" - "github.com/coredns/coredns/core/dnsserver" - "github.com/coredns/coredns/plugin" - "github.com/coredns/coredns/plugin/metrics" - "github.com/coredns/coredns/plugin/pkg/dnstest" - "github.com/coredns/coredns/plugin/pkg/upstream" - "github.com/coredns/coredns/request" - "github.com/mholt/caddy" - "github.com/miekg/dns" - "github.com/prometheus/client_golang/prometheus" - "golang.org/x/net/context" -) - -var defaultSOA = &dns.SOA{ - // values copied from verisign's nonexistent .com domain - // their exact values are not important in our use case because they are used for domain transfers between primary/secondary DNS servers - Refresh: 1800, - Retry: 900, - Expire: 604800, - Minttl: 86400, -} - -func init() { - caddy.RegisterPlugin("dnsfilter", caddy.Plugin{ - ServerType: "dns", - Action: setup, - }) -} - -type plugFilter struct { - ID int64 - Path string -} - -type plugSettings struct { - SafeBrowsingBlockHost string - ParentalBlockHost string - QueryLogEnabled bool - BlockedTTL uint32 // in seconds, default 3600 - Filters []plugFilter -} - -type plug struct { - d *dnsfilter.Dnsfilter - Next plugin.Handler - upstream upstream.Upstream - settings plugSettings - - sync.RWMutex -} - -var defaultPluginSettings = plugSettings{ - SafeBrowsingBlockHost: "safebrowsing.block.dns.adguard.com", - ParentalBlockHost: "family.block.dns.adguard.com", - BlockedTTL: 3600, // in seconds - Filters: make([]plugFilter, 0), -} - -// -// coredns handling functions -// -func setupPlugin(c *caddy.Controller) (*plug, error) { - // create new Plugin and copy default values - p := &plug{ - settings: defaultPluginSettings, - d: dnsfilter.New(nil), - } - - log.Println("Initializing the CoreDNS plugin") - - for c.Next() { - for c.NextBlock() { - blockValue := c.Val() - switch blockValue { - case "safebrowsing": - log.Println("Browsing security service is enabled") - p.d.SafeBrowsingEnabled = true - if c.NextArg() { - if len(c.Val()) == 0 { - return nil, c.ArgErr() - } - p.d.SetSafeBrowsingServer(c.Val()) - } - case "safesearch": - log.Println("Safe search is enabled") - p.d.SafeSearchEnabled = true - case "parental": - if !c.NextArg() { - return nil, c.ArgErr() - } - sensitivity, err := strconv.Atoi(c.Val()) - if err != nil { - return nil, c.ArgErr() - } - - log.Println("Parental control is enabled") - if !dnsfilter.IsParentalSensitivityValid(sensitivity) { - return nil, dnsfilter.ErrInvalidParental - } - p.d.ParentalEnabled = true - p.d.ParentalSensitivity = sensitivity - if c.NextArg() { - if len(c.Val()) == 0 { - return nil, c.ArgErr() - } - p.settings.ParentalBlockHost = c.Val() - } - case "blocked_ttl": - if !c.NextArg() { - return nil, c.ArgErr() - } - blockedTtl, err := strconv.ParseUint(c.Val(), 10, 32) - if err != nil { - return nil, c.ArgErr() - } - log.Printf("Blocked request TTL is %d", blockedTtl) - p.settings.BlockedTTL = uint32(blockedTtl) - case "querylog": - log.Println("Query log is enabled") - p.settings.QueryLogEnabled = true - case "filter": - if !c.NextArg() { - return nil, c.ArgErr() - } - - filterId, err := strconv.ParseInt(c.Val(), 10, 64) - if err != nil { - return nil, c.ArgErr() - } - if !c.NextArg() { - return nil, c.ArgErr() - } - filterPath := c.Val() - - // Initialize filter and add it to the list - p.settings.Filters = append(p.settings.Filters, plugFilter{ - ID: filterId, - Path: filterPath, - }) - } - } - } - - for _, filter := range p.settings.Filters { - log.Printf("Loading rules from %s", filter.Path) - - file, err := os.Open(filter.Path) - if err != nil { - return nil, err - } - defer file.Close() - - count := 0 - scanner := bufio.NewScanner(file) - for scanner.Scan() { - text := scanner.Text() - - err = p.d.AddRule(text, filter.ID) - if err == dnsfilter.ErrAlreadyExists || err == dnsfilter.ErrInvalidSyntax { - continue - } - if err != nil { - log.Printf("Cannot add rule %s: %s", text, err) - // Just ignore invalid rules - continue - } - count++ - } - log.Printf("Added %d rules from filter ID=%d", count, filter.ID) - - if err = scanner.Err(); err != nil { - return nil, err - } - } - - log.Printf("Loading stats from querylog") - err := fillStatsFromQueryLog() - if err != nil { - log.Printf("Failed to load stats from querylog: %s", err) - return nil, err - } - - if p.settings.QueryLogEnabled { - onceQueryLog.Do(func() { - go periodicQueryLogRotate() - go periodicHourlyTopRotate() - go statsRotator() - }) - } - - onceHook.Do(func() { - caddy.RegisterEventHook("dnsfilter-reload", hook) - }) - - p.upstream, err = upstream.New(nil) - if err != nil { - return nil, err - } - - return p, nil -} - -func setup(c *caddy.Controller) error { - p, err := setupPlugin(c) - if err != nil { - return err - } - config := dnsserver.GetConfig(c) - config.AddPlugin(func(next plugin.Handler) plugin.Handler { - p.Next = next - return p - }) - - c.OnStartup(func() error { - m := dnsserver.GetConfig(c).Handler("prometheus") - if m == nil { - return nil - } - if x, ok := m.(*metrics.Metrics); ok { - x.MustRegister(requests) - x.MustRegister(filtered) - x.MustRegister(filteredLists) - x.MustRegister(filteredSafebrowsing) - x.MustRegister(filteredParental) - x.MustRegister(whitelisted) - x.MustRegister(safesearch) - x.MustRegister(errorsTotal) - x.MustRegister(elapsedTime) - x.MustRegister(p) - } - return nil - }) - c.OnShutdown(p.onShutdown) - c.OnFinalShutdown(p.onFinalShutdown) - - return nil -} - -func (p *plug) onShutdown() error { - p.Lock() - p.d.Destroy() - p.d = nil - p.Unlock() - return nil -} - -func (p *plug) onFinalShutdown() error { - logBufferLock.Lock() - err := flushToFile(logBuffer) - if err != nil { - log.Printf("failed to flush to file: %s", err) - return err - } - logBufferLock.Unlock() - return nil -} - -type statsFunc func(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType) - -func doDesc(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType) { - realch, ok := ch.(chan<- *prometheus.Desc) - if !ok { - log.Printf("Couldn't convert ch to chan<- *prometheus.Desc\n") - return - } - realch <- prometheus.NewDesc(name, text, nil, nil) -} - -func doMetric(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType) { - realch, ok := ch.(chan<- prometheus.Metric) - if !ok { - log.Printf("Couldn't convert ch to chan<- prometheus.Metric\n") - return - } - desc := prometheus.NewDesc(name, text, nil, nil) - realch <- prometheus.MustNewConstMetric(desc, valueType, value) -} - -func gen(ch interface{}, doFunc statsFunc, name string, text string, value float64, valueType prometheus.ValueType) { - doFunc(ch, name, text, value, valueType) -} - -func doStatsLookup(ch interface{}, doFunc statsFunc, name string, lookupstats *dnsfilter.LookupStats) { - gen(ch, doFunc, fmt.Sprintf("coredns_dnsfilter_%s_requests", name), fmt.Sprintf("Number of %s HTTP requests that were sent", name), float64(lookupstats.Requests), prometheus.CounterValue) - gen(ch, doFunc, fmt.Sprintf("coredns_dnsfilter_%s_cachehits", name), fmt.Sprintf("Number of %s lookups that didn't need HTTP requests", name), float64(lookupstats.CacheHits), prometheus.CounterValue) - gen(ch, doFunc, fmt.Sprintf("coredns_dnsfilter_%s_pending", name), fmt.Sprintf("Number of currently pending %s HTTP requests", name), float64(lookupstats.Pending), prometheus.GaugeValue) - gen(ch, doFunc, fmt.Sprintf("coredns_dnsfilter_%s_pending_max", name), fmt.Sprintf("Maximum number of pending %s HTTP requests", name), float64(lookupstats.PendingMax), prometheus.GaugeValue) -} - -func (p *plug) doStats(ch interface{}, doFunc statsFunc) { - p.RLock() - stats := p.d.GetStats() - doStatsLookup(ch, doFunc, "safebrowsing", &stats.Safebrowsing) - doStatsLookup(ch, doFunc, "parental", &stats.Parental) - p.RUnlock() -} - -// Describe is called by prometheus handler to know stat types -func (p *plug) Describe(ch chan<- *prometheus.Desc) { - p.doStats(ch, doDesc) -} - -// Collect is called by prometheus handler to collect stats -func (p *plug) Collect(ch chan<- prometheus.Metric) { - p.doStats(ch, doMetric) -} - -func (p *plug) replaceHostWithValAndReply(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, host string, val string, question dns.Question) (int, error) { - // check if it's a domain name or IP address - addr := net.ParseIP(val) - var records []dns.RR - // log.Println("Will give", val, "instead of", host) // debug logging - if addr != nil { - // this is an IP address, return it - result, err := dns.NewRR(fmt.Sprintf("%s %d A %s", host, p.settings.BlockedTTL, val)) - if err != nil { - log.Printf("Got error %s\n", err) - return dns.RcodeServerFailure, fmt.Errorf("plugin/dnsfilter: %s", err) - } - records = append(records, result) - } else { - // this is a domain name, need to look it up - req := new(dns.Msg) - req.SetQuestion(dns.Fqdn(val), question.Qtype) - req.RecursionDesired = true - reqstate := request.Request{W: w, Req: req, Context: ctx} - result, err := p.upstream.Lookup(reqstate, dns.Fqdn(val), reqstate.QType()) - if err != nil { - log.Printf("Got error %s\n", err) - return dns.RcodeServerFailure, fmt.Errorf("plugin/dnsfilter: %s", err) - } - if result != nil { - for _, answer := range result.Answer { - answer.Header().Name = question.Name - } - records = result.Answer - } - } - m := new(dns.Msg) - m.SetReply(r) - m.Authoritative, m.RecursionAvailable, m.Compress = true, true, true - m.Answer = append(m.Answer, records...) - state := request.Request{W: w, Req: r, Context: ctx} - state.SizeAndDo(m) - err := state.W.WriteMsg(m) - if err != nil { - log.Printf("Got error %s\n", err) - return dns.RcodeServerFailure, fmt.Errorf("plugin/dnsfilter: %s", err) - } - return dns.RcodeSuccess, nil -} - -// generate SOA record that makes DNS clients cache NXdomain results -// the only value that is important is TTL in header, other values like refresh, retry, expire and minttl are irrelevant -func (p *plug) genSOA(r *dns.Msg) []dns.RR { - zone := r.Question[0].Name - header := dns.RR_Header{Name: zone, Rrtype: dns.TypeSOA, Ttl: p.settings.BlockedTTL, Class: dns.ClassINET} - - Mbox := "hostmaster." - if zone[0] != '.' { - Mbox += zone - } - Ns := "fake-for-negative-caching.adguard.com." - - soa := *defaultSOA - soa.Hdr = header - soa.Mbox = Mbox - soa.Ns = Ns - soa.Serial = 100500 // faster than uint32(time.Now().Unix()) - return []dns.RR{&soa} -} - -func (p *plug) writeNXdomain(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - state := request.Request{W: w, Req: r, Context: ctx} - m := new(dns.Msg) - m.SetRcode(state.Req, dns.RcodeNameError) - m.Authoritative, m.RecursionAvailable, m.Compress = true, true, true - m.Ns = p.genSOA(r) - - state.SizeAndDo(m) - err := state.W.WriteMsg(m) - if err != nil { - log.Printf("Got error %s\n", err) - return dns.RcodeServerFailure, err - } - return dns.RcodeNameError, nil -} - -func (p *plug) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, dnsfilter.Result, error) { - if len(r.Question) != 1 { - // google DNS, bind and others do the same - return dns.RcodeFormatError, dnsfilter.Result{}, fmt.Errorf("got a DNS request with more than one Question") - } - for _, question := range r.Question { - host := strings.ToLower(strings.TrimSuffix(question.Name, ".")) - // is it a safesearch domain? - p.RLock() - if val, ok := p.d.SafeSearchDomain(host); ok { - rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question) - if err != nil { - p.RUnlock() - return rcode, dnsfilter.Result{}, err - } - p.RUnlock() - return rcode, dnsfilter.Result{Reason: dnsfilter.FilteredSafeSearch}, err - } - p.RUnlock() - - // needs to be filtered instead - p.RLock() - result, err := p.d.CheckHost(host) - if err != nil { - log.Printf("plugin/dnsfilter: %s\n", err) - p.RUnlock() - return dns.RcodeServerFailure, dnsfilter.Result{}, fmt.Errorf("plugin/dnsfilter: %s", err) - } - p.RUnlock() - - if result.IsFiltered { - switch result.Reason { - case dnsfilter.FilteredSafeBrowsing: - // return cname safebrowsing.block.dns.adguard.com - val := p.settings.SafeBrowsingBlockHost - rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question) - if err != nil { - return rcode, dnsfilter.Result{}, err - } - return rcode, result, err - case dnsfilter.FilteredParental: - // return cname family.block.dns.adguard.com - val := p.settings.ParentalBlockHost - rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question) - if err != nil { - return rcode, dnsfilter.Result{}, err - } - return rcode, result, err - case dnsfilter.FilteredBlackList: - - if result.Ip == nil { - // return NXDomain - rcode, err := p.writeNXdomain(ctx, w, r) - if err != nil { - return rcode, dnsfilter.Result{}, err - } - return rcode, result, err - } else { - // This is a hosts-syntax rule - rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, result.Ip.String(), question) - if err != nil { - return rcode, dnsfilter.Result{}, err - } - return rcode, result, err - } - case dnsfilter.FilteredInvalid: - // return NXdomain - rcode, err := p.writeNXdomain(ctx, w, r) - if err != nil { - return rcode, dnsfilter.Result{}, err - } - return rcode, result, err - default: - log.Printf("SHOULD NOT HAPPEN -- got unknown reason for filtering host \"%s\": %v, %+v", host, result.Reason, result) - } - } else { - switch result.Reason { - case dnsfilter.NotFilteredWhiteList: - rcode, err := plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r) - return rcode, result, err - case dnsfilter.NotFilteredNotFound: - // do nothing, pass through to lower code - default: - log.Printf("SHOULD NOT HAPPEN -- got unknown reason for not filtering host \"%s\": %v, %+v", host, result.Reason, result) - } - } - } - rcode, err := plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r) - return rcode, dnsfilter.Result{}, err -} - -// ServeDNS handles the DNS request and refuses if it's in filterlists -func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - start := time.Now() - requests.Inc() - state := request.Request{W: w, Req: r} - ip := state.IP() - - // capture the written answer - rrw := dnstest.NewRecorder(w) - rcode, result, err := p.serveDNSInternal(ctx, rrw, r) - if rcode > 0 { - // actually send the answer if we have one - answer := new(dns.Msg) - answer.SetRcode(r, rcode) - state.SizeAndDo(answer) - err = w.WriteMsg(answer) - if err != nil { - return dns.RcodeServerFailure, err - } - } - - // increment counters - switch { - case err != nil: - errorsTotal.Inc() - case result.Reason == dnsfilter.FilteredBlackList: - filtered.Inc() - filteredLists.Inc() - case result.Reason == dnsfilter.FilteredSafeBrowsing: - filtered.Inc() - filteredSafebrowsing.Inc() - case result.Reason == dnsfilter.FilteredParental: - filtered.Inc() - filteredParental.Inc() - case result.Reason == dnsfilter.FilteredInvalid: - filtered.Inc() - filteredInvalid.Inc() - case result.Reason == dnsfilter.FilteredSafeSearch: - // the request was passsed through but not filtered, don't increment filtered - safesearch.Inc() - case result.Reason == dnsfilter.NotFilteredWhiteList: - whitelisted.Inc() - case result.Reason == dnsfilter.NotFilteredNotFound: - // do nothing - case result.Reason == dnsfilter.NotFilteredError: - text := "SHOULD NOT HAPPEN: got DNSFILTER_NOTFILTERED_ERROR without err != nil!" - log.Println(text) - err = errors.New(text) - rcode = dns.RcodeServerFailure - } - - // log - elapsed := time.Since(start) - elapsedTime.Observe(elapsed.Seconds()) - if p.settings.QueryLogEnabled { - logRequest(r, rrw.Msg, result, time.Since(start), ip) - } - return rcode, err -} - -// Name returns name of the plugin as seen in Corefile and plugin.cfg -func (p *plug) Name() string { return "dnsfilter" } - -var onceHook sync.Once -var onceQueryLog sync.Once diff --git a/coredns_plugin/coredns_plugin_test.go b/coredns_plugin/coredns_plugin_test.go deleted file mode 100644 index 1733fd6f..00000000 --- a/coredns_plugin/coredns_plugin_test.go +++ /dev/null @@ -1,131 +0,0 @@ -package dnsfilter - -import ( - "context" - "fmt" - "io/ioutil" - "net" - "os" - "testing" - - "github.com/coredns/coredns/plugin" - "github.com/coredns/coredns/plugin/pkg/dnstest" - "github.com/coredns/coredns/plugin/test" - "github.com/mholt/caddy" - "github.com/miekg/dns" -) - -func TestSetup(t *testing.T) { - for i, testcase := range []struct { - config string - failing bool - }{ - {`dnsfilter`, false}, - {`dnsfilter { - filter 0 /dev/nonexistent/abcdef - }`, true}, - {`dnsfilter { - filter 0 ../tests/dns.txt - }`, false}, - {`dnsfilter { - safebrowsing - filter 0 ../tests/dns.txt - }`, false}, - {`dnsfilter { - parental - filter 0 ../tests/dns.txt - }`, true}, - } { - c := caddy.NewTestController("dns", testcase.config) - err := setup(c) - if err != nil { - if !testcase.failing { - t.Fatalf("Test #%d expected no errors, but got: %v", i, err) - } - continue - } - if testcase.failing { - t.Fatalf("Test #%d expected to fail but it didn't", i) - } - } -} - -func TestEtcHostsFilter(t *testing.T) { - text := []byte("127.0.0.1 doubleclick.net\n" + "127.0.0.1 example.org example.net www.example.org www.example.net") - tmpfile, err := ioutil.TempFile("", "") - if err != nil { - t.Fatal(err) - } - if _, err = tmpfile.Write(text); err != nil { - t.Fatal(err) - } - if err = tmpfile.Close(); err != nil { - t.Fatal(err) - } - - defer os.Remove(tmpfile.Name()) - - configText := fmt.Sprintf("dnsfilter {\nfilter 0 %s\n}", tmpfile.Name()) - c := caddy.NewTestController("dns", configText) - p, err := setupPlugin(c) - if err != nil { - t.Fatal(err) - } - - p.Next = zeroTTLBackend() - - ctx := context.TODO() - - for _, testcase := range []struct { - host string - filtered bool - }{ - {"www.doubleclick.net", false}, - {"doubleclick.net", true}, - {"www2.example.org", false}, - {"www2.example.net", false}, - {"test.www.example.org", false}, - {"test.www.example.net", false}, - {"example.org", true}, - {"example.net", true}, - {"www.example.org", true}, - {"www.example.net", true}, - } { - req := new(dns.Msg) - req.SetQuestion(testcase.host+".", dns.TypeA) - - resp := test.ResponseWriter{} - rrw := dnstest.NewRecorder(&resp) - rcode, err := p.ServeDNS(ctx, rrw, req) - if err != nil { - t.Fatalf("ServeDNS returned error: %s", err) - } - if rcode != rrw.Rcode { - t.Fatalf("ServeDNS return value for host %s has rcode %d that does not match captured rcode %d", testcase.host, rcode, rrw.Rcode) - } - A, ok := rrw.Msg.Answer[0].(*dns.A) - if !ok { - t.Fatalf("Host %s expected to have result A", testcase.host) - } - ip := net.IPv4(127, 0, 0, 1) - filtered := ip.Equal(A.A) - if testcase.filtered && testcase.filtered != filtered { - t.Fatalf("Host %s expected to be filtered, instead it is not filtered", testcase.host) - } - if !testcase.filtered && testcase.filtered != filtered { - t.Fatalf("Host %s expected to be not filtered, instead it is filtered", testcase.host) - } - } -} - -func zeroTTLBackend() plugin.Handler { - return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - m := new(dns.Msg) - m.SetReply(r) - m.Response, m.RecursionAvailable = true, true - - m.Answer = []dns.RR{test.A("example.org. 0 IN A 127.0.0.53")} - w.WriteMsg(m) - return dns.RcodeSuccess, nil - }) -} diff --git a/coredns_plugin/coredns_stats.go b/coredns_plugin/coredns_stats.go deleted file mode 100644 index b138911e..00000000 --- a/coredns_plugin/coredns_stats.go +++ /dev/null @@ -1,410 +0,0 @@ -package dnsfilter - -import ( - "encoding/json" - "fmt" - "log" - "net/http" - "sync" - "time" - - "github.com/coredns/coredns/plugin" - "github.com/prometheus/client_golang/prometheus" -) - -var ( - requests = newDNSCounter("requests_total", "Count of requests seen by dnsfilter.") - filtered = newDNSCounter("filtered_total", "Count of requests filtered by dnsfilter.") - filteredLists = newDNSCounter("filtered_lists_total", "Count of requests filtered by dnsfilter using lists.") - filteredSafebrowsing = newDNSCounter("filtered_safebrowsing_total", "Count of requests filtered by dnsfilter using safebrowsing.") - filteredParental = newDNSCounter("filtered_parental_total", "Count of requests filtered by dnsfilter using parental.") - filteredInvalid = newDNSCounter("filtered_invalid_total", "Count of requests filtered by dnsfilter because they were invalid.") - whitelisted = newDNSCounter("whitelisted_total", "Count of requests not filtered by dnsfilter because they are whitelisted.") - safesearch = newDNSCounter("safesearch_total", "Count of requests replaced by dnsfilter safesearch.") - errorsTotal = newDNSCounter("errors_total", "Count of requests that dnsfilter couldn't process because of transitive errors.") - elapsedTime = newDNSHistogram("request_duration", "Histogram of the time (in seconds) each request took.") -) - -// entries for single time period (for example all per-second entries) -type statsEntries map[string][statsHistoryElements]float64 - -// how far back to keep the stats -const statsHistoryElements = 60 + 1 // +1 for calculating delta - -// each periodic stat is a map of arrays -type periodicStats struct { - Entries statsEntries - period time.Duration // how long one entry lasts - LastRotate time.Time // last time this data was rotated - - sync.RWMutex -} - -type stats struct { - PerSecond periodicStats - PerMinute periodicStats - PerHour periodicStats - PerDay periodicStats -} - -// per-second/per-minute/per-hour/per-day stats -var statistics stats - -func initPeriodicStats(periodic *periodicStats, period time.Duration) { - periodic.Entries = statsEntries{} - periodic.LastRotate = time.Now() - periodic.period = period -} - -func init() { - purgeStats() -} - -func purgeStats() { - initPeriodicStats(&statistics.PerSecond, time.Second) - initPeriodicStats(&statistics.PerMinute, time.Minute) - initPeriodicStats(&statistics.PerHour, time.Hour) - initPeriodicStats(&statistics.PerDay, time.Hour*24) -} - -func (p *periodicStats) Inc(name string, when time.Time) { - // calculate how many periods ago this happened - elapsed := int64(time.Since(when) / p.period) - // trace("%s: %v as %v -> [%v]", name, time.Since(when), p.period, elapsed) - if elapsed >= statsHistoryElements { - return // outside of our timeframe - } - p.Lock() - currentValues := p.Entries[name] - currentValues[elapsed]++ - p.Entries[name] = currentValues - p.Unlock() -} - -func (p *periodicStats) Observe(name string, when time.Time, value float64) { - // calculate how many periods ago this happened - elapsed := int64(time.Since(when) / p.period) - // trace("%s: %v as %v -> [%v]", name, time.Since(when), p.period, elapsed) - if elapsed >= statsHistoryElements { - return // outside of our timeframe - } - p.Lock() - { - countname := name + "_count" - currentValues := p.Entries[countname] - value := currentValues[elapsed] - // trace("Will change p.Entries[%s][%d] from %v to %v", countname, elapsed, value, value+1) - value += 1 - currentValues[elapsed] = value - p.Entries[countname] = currentValues - } - { - totalname := name + "_sum" - currentValues := p.Entries[totalname] - currentValues[elapsed] += value - p.Entries[totalname] = currentValues - } - p.Unlock() -} - -func (p *periodicStats) statsRotate(now time.Time) { - p.Lock() - rotations := int64(now.Sub(p.LastRotate) / p.period) - if rotations > statsHistoryElements { - rotations = statsHistoryElements - } - // calculate how many times we should rotate - for r := int64(0); r < rotations; r++ { - for key, values := range p.Entries { - newValues := [statsHistoryElements]float64{} - for i := 1; i < len(values); i++ { - newValues[i] = values[i-1] - } - p.Entries[key] = newValues - } - } - if rotations > 0 { - p.LastRotate = now - } - p.Unlock() -} - -func statsRotator() { - for range time.Tick(time.Second) { - now := time.Now() - statistics.PerSecond.statsRotate(now) - statistics.PerMinute.statsRotate(now) - statistics.PerHour.statsRotate(now) - statistics.PerDay.statsRotate(now) - } -} - -// counter that wraps around prometheus Counter but also adds to periodic stats -type counter struct { - name string // used as key in periodic stats - value int64 - prom prometheus.Counter -} - -func newDNSCounter(name string, help string) *counter { - // trace("called") - c := &counter{} - c.prom = prometheus.NewCounter(prometheus.CounterOpts{ - Namespace: plugin.Namespace, - Subsystem: "dnsfilter", - Name: name, - Help: help, - }) - c.name = name - - return c -} - -func (c *counter) IncWithTime(when time.Time) { - statistics.PerSecond.Inc(c.name, when) - statistics.PerMinute.Inc(c.name, when) - statistics.PerHour.Inc(c.name, when) - statistics.PerDay.Inc(c.name, when) - c.value++ - c.prom.Inc() -} - -func (c *counter) Inc() { - c.IncWithTime(time.Now()) -} - -func (c *counter) Describe(ch chan<- *prometheus.Desc) { - c.prom.Describe(ch) -} - -func (c *counter) Collect(ch chan<- prometheus.Metric) { - c.prom.Collect(ch) -} - -type histogram struct { - name string // used as key in periodic stats - count int64 - total float64 - prom prometheus.Histogram -} - -func newDNSHistogram(name string, help string) *histogram { - // trace("called") - h := &histogram{} - h.prom = prometheus.NewHistogram(prometheus.HistogramOpts{ - Namespace: plugin.Namespace, - Subsystem: "dnsfilter", - Name: name, - Help: help, - }) - h.name = name - - return h -} - -func (h *histogram) ObserveWithTime(value float64, when time.Time) { - statistics.PerSecond.Observe(h.name, when, value) - statistics.PerMinute.Observe(h.name, when, value) - statistics.PerHour.Observe(h.name, when, value) - statistics.PerDay.Observe(h.name, when, value) - h.count++ - h.total += value - h.prom.Observe(value) -} - -func (h *histogram) Observe(value float64) { - h.ObserveWithTime(value, time.Now()) -} - -func (h *histogram) Describe(ch chan<- *prometheus.Desc) { - h.prom.Describe(ch) -} - -func (h *histogram) Collect(ch chan<- prometheus.Metric) { - h.prom.Collect(ch) -} - -// ----- -// stats -// ----- -func HandleStats(w http.ResponseWriter, r *http.Request) { - const numHours = 24 - histrical := generateMapFromStats(&statistics.PerHour, 0, numHours) - // sum them up - summed := map[string]interface{}{} - for key, values := range histrical { - summedValue := 0.0 - floats, ok := values.([]float64) - if !ok { - continue - } - for _, v := range floats { - summedValue += v - } - summed[key] = summedValue - } - // don't forget to divide by number of elements in returned slice - if val, ok := summed["avg_processing_time"]; ok { - if flval, flok := val.(float64); flok { - flval /= numHours - summed["avg_processing_time"] = flval - } - } - - summed["stats_period"] = "24 hours" - - json, err := json.Marshal(summed) - if err != nil { - errortext := fmt.Sprintf("Unable to marshal status json: %s", err) - log.Println(errortext) - http.Error(w, errortext, 500) - return - } - w.Header().Set("Content-Type", "application/json") - _, err = w.Write(json) - if err != nil { - errortext := fmt.Sprintf("Unable to write response json: %s", err) - log.Println(errortext) - http.Error(w, errortext, 500) - return - } -} - -func generateMapFromStats(stats *periodicStats, start int, end int) map[string]interface{} { - // clamp - start = clamp(start, 0, statsHistoryElements) - end = clamp(end, 0, statsHistoryElements) - - avgProcessingTime := make([]float64, 0) - - count := getReversedSlice(stats.Entries[elapsedTime.name+"_count"], start, end) - sum := getReversedSlice(stats.Entries[elapsedTime.name+"_sum"], start, end) - for i := 0; i < len(count); i++ { - var avg float64 - if count[i] != 0 { - avg = sum[i] / count[i] - avg *= 1000 - } - avgProcessingTime = append(avgProcessingTime, avg) - } - - result := map[string]interface{}{ - "dns_queries": getReversedSlice(stats.Entries[requests.name], start, end), - "blocked_filtering": getReversedSlice(stats.Entries[filtered.name], start, end), - "replaced_safebrowsing": getReversedSlice(stats.Entries[filteredSafebrowsing.name], start, end), - "replaced_safesearch": getReversedSlice(stats.Entries[safesearch.name], start, end), - "replaced_parental": getReversedSlice(stats.Entries[filteredParental.name], start, end), - "avg_processing_time": avgProcessingTime, - } - return result -} - -func HandleStatsHistory(w http.ResponseWriter, r *http.Request) { - // handle time unit and prepare our time window size - now := time.Now() - timeUnitString := r.URL.Query().Get("time_unit") - var stats *periodicStats - var timeUnit time.Duration - switch timeUnitString { - case "seconds": - timeUnit = time.Second - stats = &statistics.PerSecond - case "minutes": - timeUnit = time.Minute - stats = &statistics.PerMinute - case "hours": - timeUnit = time.Hour - stats = &statistics.PerHour - case "days": - timeUnit = time.Hour * 24 - stats = &statistics.PerDay - default: - http.Error(w, "Must specify valid time_unit parameter", 400) - return - } - - // parse start and end time - startTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("start_time")) - if err != nil { - errortext := fmt.Sprintf("Must specify valid start_time parameter: %s", err) - log.Println(errortext) - http.Error(w, errortext, 400) - return - } - endTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("end_time")) - if err != nil { - errortext := fmt.Sprintf("Must specify valid end_time parameter: %s", err) - log.Println(errortext) - http.Error(w, errortext, 400) - return - } - - // check if start and time times are within supported time range - timeRange := timeUnit * statsHistoryElements - if startTime.Add(timeRange).Before(now) { - http.Error(w, "start_time parameter is outside of supported range", 501) - return - } - if endTime.Add(timeRange).Before(now) { - http.Error(w, "end_time parameter is outside of supported range", 501) - return - } - - // calculate start and end of our array - // basically it's how many hours/minutes/etc have passed since now - start := int(now.Sub(endTime) / timeUnit) - end := int(now.Sub(startTime) / timeUnit) - - // swap them around if they're inverted - if start > end { - start, end = end, start - } - - data := generateMapFromStats(stats, start, end) - json, err := json.Marshal(data) - if err != nil { - errortext := fmt.Sprintf("Unable to marshal status json: %s", err) - log.Println(errortext) - http.Error(w, errortext, 500) - return - } - w.Header().Set("Content-Type", "application/json") - _, err = w.Write(json) - if err != nil { - errortext := fmt.Sprintf("Unable to write response json: %s", err) - log.Println(errortext) - http.Error(w, errortext, 500) - return - } -} - -func HandleStatsReset(w http.ResponseWriter, r *http.Request) { - purgeStats() - _, err := fmt.Fprintf(w, "OK\n") - if err != nil { - errortext := fmt.Sprintf("Couldn't write body: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) - } -} - -func clamp(value, low, high int) int { - if value < low { - return low - } - if value > high { - return high - } - return value -} - -// -------------------------- -// helper functions for stats -// -------------------------- -func getReversedSlice(input [statsHistoryElements]float64, start int, end int) []float64 { - output := make([]float64, 0) - for i := start; i <= end; i++ { - output = append([]float64{input[i]}, output...) - } - return output -} diff --git a/coredns_plugin/querylog.go b/coredns_plugin/querylog.go deleted file mode 100644 index 92ba2d1d..00000000 --- a/coredns_plugin/querylog.go +++ /dev/null @@ -1,239 +0,0 @@ -package dnsfilter - -import ( - "encoding/json" - "fmt" - "log" - "net/http" - "os" - "path" - "runtime" - "strconv" - "strings" - "sync" - "time" - - "github.com/AdguardTeam/AdGuardHome/dnsfilter" - "github.com/coredns/coredns/plugin/pkg/response" - "github.com/miekg/dns" -) - -const ( - logBufferCap = 5000 // maximum capacity of logBuffer before it's flushed to disk - queryLogTimeLimit = time.Hour * 24 // how far in the past we care about querylogs - queryLogRotationPeriod = time.Hour * 24 // rotate the log every 24 hours - queryLogFileName = "querylog.json" // .gz added during compression - queryLogSize = 5000 // maximum API response for /querylog - queryLogTopSize = 500 // Keep in memory only top N values -) - -var ( - logBufferLock sync.RWMutex - logBuffer []*logEntry - - queryLogCache []*logEntry - queryLogLock sync.RWMutex -) - -type logEntry struct { - Question []byte - Answer []byte `json:",omitempty"` // sometimes empty answers happen like binerdunt.top or rev2.globalrootservers.net - Result dnsfilter.Result - Time time.Time - Elapsed time.Duration - IP string -} - -func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, elapsed time.Duration, ip string) { - var q []byte - var a []byte - var err error - - if question != nil { - q, err = question.Pack() - if err != nil { - log.Printf("failed to pack question for querylog: %s", err) - return - } - } - if answer != nil { - a, err = answer.Pack() - if err != nil { - log.Printf("failed to pack answer for querylog: %s", err) - return - } - } - - now := time.Now() - entry := logEntry{ - Question: q, - Answer: a, - Result: result, - Time: now, - Elapsed: elapsed, - IP: ip, - } - var flushBuffer []*logEntry - - logBufferLock.Lock() - logBuffer = append(logBuffer, &entry) - if len(logBuffer) >= logBufferCap { - flushBuffer = logBuffer - logBuffer = nil - } - logBufferLock.Unlock() - queryLogLock.Lock() - queryLogCache = append(queryLogCache, &entry) - if len(queryLogCache) > queryLogSize { - toremove := len(queryLogCache) - queryLogSize - queryLogCache = queryLogCache[toremove:] - } - queryLogLock.Unlock() - - // add it to running top - err = runningTop.addEntry(&entry, question, now) - if err != nil { - log.Printf("Failed to add entry to running top: %s", err) - // don't do failure, just log - } - - // if buffer needs to be flushed to disk, do it now - if len(flushBuffer) > 0 { - // write to file - // do it in separate goroutine -- we are stalling DNS response this whole time - go flushToFile(flushBuffer) - } -} - -func HandleQueryLog(w http.ResponseWriter, r *http.Request) { - queryLogLock.RLock() - values := make([]*logEntry, len(queryLogCache)) - copy(values, queryLogCache) - queryLogLock.RUnlock() - - // reverse it so that newest is first - for left, right := 0, len(values)-1; left < right; left, right = left+1, right-1 { - values[left], values[right] = values[right], values[left] - } - - var data = []map[string]interface{}{} - for _, entry := range values { - var q *dns.Msg - var a *dns.Msg - - if len(entry.Question) > 0 { - q = new(dns.Msg) - if err := q.Unpack(entry.Question); err != nil { - // ignore, log and move on - log.Printf("Failed to unpack dns message question: %s", err) - q = nil - } - } - if len(entry.Answer) > 0 { - a = new(dns.Msg) - if err := a.Unpack(entry.Answer); err != nil { - // ignore, log and move on - log.Printf("Failed to unpack dns message question: %s", err) - a = nil - } - } - - jsonEntry := map[string]interface{}{ - "reason": entry.Result.Reason.String(), - "elapsedMs": strconv.FormatFloat(entry.Elapsed.Seconds()*1000, 'f', -1, 64), - "time": entry.Time.Format(time.RFC3339), - "client": entry.IP, - } - if q != nil { - jsonEntry["question"] = map[string]interface{}{ - "host": strings.ToLower(strings.TrimSuffix(q.Question[0].Name, ".")), - "type": dns.Type(q.Question[0].Qtype).String(), - "class": dns.Class(q.Question[0].Qclass).String(), - } - } - - if a != nil { - status, _ := response.Typify(a, time.Now().UTC()) - jsonEntry["status"] = status.String() - } - if len(entry.Result.Rule) > 0 { - jsonEntry["rule"] = entry.Result.Rule - jsonEntry["filterId"] = entry.Result.FilterID - } - - if a != nil && len(a.Answer) > 0 { - var answers = []map[string]interface{}{} - for _, k := range a.Answer { - header := k.Header() - answer := map[string]interface{}{ - "type": dns.TypeToString[header.Rrtype], - "ttl": header.Ttl, - } - // try most common record types - switch v := k.(type) { - case *dns.A: - answer["value"] = v.A - case *dns.AAAA: - answer["value"] = v.AAAA - case *dns.MX: - answer["value"] = fmt.Sprintf("%v %v", v.Preference, v.Mx) - case *dns.CNAME: - answer["value"] = v.Target - case *dns.NS: - answer["value"] = v.Ns - case *dns.SPF: - answer["value"] = v.Txt - case *dns.TXT: - answer["value"] = v.Txt - case *dns.PTR: - answer["value"] = v.Ptr - case *dns.SOA: - answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v", v.Ns, v.Mbox, v.Serial, v.Refresh, v.Retry, v.Expire, v.Minttl) - case *dns.CAA: - answer["value"] = fmt.Sprintf("%v %v \"%v\"", v.Flag, v.Tag, v.Value) - case *dns.HINFO: - answer["value"] = fmt.Sprintf("\"%v\" \"%v\"", v.Cpu, v.Os) - case *dns.RRSIG: - answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v %v %v", dns.TypeToString[v.TypeCovered], v.Algorithm, v.Labels, v.OrigTtl, v.Expiration, v.Inception, v.KeyTag, v.SignerName, v.Signature) - default: - // type unknown, marshall it as-is - answer["value"] = v - } - answers = append(answers, answer) - } - jsonEntry["answer"] = answers - } - - data = append(data, jsonEntry) - } - - jsonVal, err := json.Marshal(data) - if err != nil { - errorText := fmt.Sprintf("Couldn't marshal data into json: %s", err) - log.Println(errorText) - http.Error(w, errorText, http.StatusInternalServerError) - return - } - - w.Header().Set("Content-Type", "application/json") - _, err = w.Write(jsonVal) - if err != nil { - errorText := fmt.Sprintf("Unable to write response json: %s", err) - log.Println(errorText) - http.Error(w, errorText, http.StatusInternalServerError) - } -} - -func trace(format string, args ...interface{}) { - pc := make([]uintptr, 10) // at least 1 entry needed - runtime.Callers(2, pc) - f := runtime.FuncForPC(pc[0]) - var buf strings.Builder - buf.WriteString(fmt.Sprintf("%s(): ", path.Base(f.Name()))) - text := fmt.Sprintf(format, args...) - buf.WriteString(text) - if len(text) == 0 || text[len(text)-1] != '\n' { - buf.WriteRune('\n') - } - fmt.Fprint(os.Stderr, buf.String()) -} diff --git a/coredns_plugin/querylog_file.go b/coredns_plugin/querylog_file.go deleted file mode 100644 index a36812c2..00000000 --- a/coredns_plugin/querylog_file.go +++ /dev/null @@ -1,291 +0,0 @@ -package dnsfilter - -import ( - "bytes" - "compress/gzip" - "encoding/json" - "fmt" - "log" - "os" - "sync" - "time" - - "github.com/go-test/deep" -) - -var ( - fileWriteLock sync.Mutex -) - -const enableGzip = false - -func flushToFile(buffer []*logEntry) error { - if len(buffer) == 0 { - return nil - } - start := time.Now() - - var b bytes.Buffer - e := json.NewEncoder(&b) - for _, entry := range buffer { - err := e.Encode(entry) - if err != nil { - log.Printf("Failed to marshal entry: %s", err) - return err - } - } - - elapsed := time.Since(start) - log.Printf("%d elements serialized via json in %v: %d kB, %v/entry, %v/entry", len(buffer), elapsed, b.Len()/1024, float64(b.Len())/float64(len(buffer)), elapsed/time.Duration(len(buffer))) - - err := checkBuffer(buffer, b) - if err != nil { - log.Printf("failed to check buffer: %s", err) - return err - } - - var zb bytes.Buffer - filename := queryLogFileName - - // gzip enabled? - if enableGzip { - filename += ".gz" - - zw := gzip.NewWriter(&zb) - zw.Name = queryLogFileName - zw.ModTime = time.Now() - - _, err = zw.Write(b.Bytes()) - if err != nil { - log.Printf("Couldn't compress to gzip: %s", err) - zw.Close() - return err - } - - if err = zw.Close(); err != nil { - log.Printf("Couldn't close gzip writer: %s", err) - return err - } - } else { - zb = b - } - - fileWriteLock.Lock() - defer fileWriteLock.Unlock() - f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644) - if err != nil { - log.Printf("failed to create file \"%s\": %s", filename, err) - return err - } - defer f.Close() - - n, err := f.Write(zb.Bytes()) - if err != nil { - log.Printf("Couldn't write to file: %s", err) - return err - } - - log.Printf("ok \"%s\": %v bytes written", filename, n) - - return nil -} - -func checkBuffer(buffer []*logEntry, b bytes.Buffer) error { - l := len(buffer) - d := json.NewDecoder(&b) - - i := 0 - for d.More() { - entry := &logEntry{} - err := d.Decode(entry) - if err != nil { - log.Printf("Failed to decode: %s", err) - return err - } - if diff := deep.Equal(entry, buffer[i]); diff != nil { - log.Printf("decoded buffer differs: %s", diff) - return fmt.Errorf("decoded buffer differs: %s", diff) - } - i++ - } - if i != l { - err := fmt.Errorf("check fail: %d vs %d entries", l, i) - log.Print(err) - return err - } - log.Printf("check ok: %d entries", i) - - return nil -} - -func rotateQueryLog() error { - from := queryLogFileName - to := queryLogFileName + ".1" - - if enableGzip { - from = queryLogFileName + ".gz" - to = queryLogFileName + ".gz.1" - } - - if _, err := os.Stat(from); os.IsNotExist(err) { - // do nothing, file doesn't exist - return nil - } - - err := os.Rename(from, to) - if err != nil { - log.Printf("Failed to rename querylog: %s", err) - return err - } - - log.Printf("Rotated from %s to %s successfully", from, to) - - return nil -} - -func periodicQueryLogRotate() { - for range time.Tick(queryLogRotationPeriod) { - err := rotateQueryLog() - if err != nil { - log.Printf("Failed to rotate querylog: %s", err) - // do nothing, continue rotating - } - } -} - -func genericLoader(onEntry func(entry *logEntry) error, needMore func() bool, timeWindow time.Duration) error { - now := time.Now() - // read from querylog files, try newest file first - files := []string{} - - if enableGzip { - files = []string{ - queryLogFileName + ".gz", - queryLogFileName + ".gz.1", - } - } else { - files = []string{ - queryLogFileName, - queryLogFileName + ".1", - } - } - - // read from all files - for _, file := range files { - if !needMore() { - break - } - if _, err := os.Stat(file); os.IsNotExist(err) { - // do nothing, file doesn't exist - continue - } - - f, err := os.Open(file) - if err != nil { - log.Printf("Failed to open file \"%s\": %s", file, err) - // try next file - continue - } - defer f.Close() - - var d *json.Decoder - - if enableGzip { - trace("Creating gzip reader") - zr, err := gzip.NewReader(f) - if err != nil { - log.Printf("Failed to create gzip reader: %s", err) - continue - } - defer zr.Close() - - trace("Creating json decoder") - d = json.NewDecoder(zr) - } else { - d = json.NewDecoder(f) - } - - i := 0 - over := 0 - max := 10000 * time.Second - var sum time.Duration - // entries on file are in oldest->newest order - // we want maxLen newest - for d.More() { - if !needMore() { - break - } - var entry logEntry - err := d.Decode(&entry) - if err != nil { - log.Printf("Failed to decode: %s", err) - // next entry can be fine, try more - continue - } - - if now.Sub(entry.Time) > timeWindow { - // trace("skipping entry") // debug logging - continue - } - - if entry.Elapsed > max { - over++ - } else { - sum += entry.Elapsed - } - - i++ - err = onEntry(&entry) - if err != nil { - return err - } - } - elapsed := time.Since(now) - var perunit time.Duration - var avg time.Duration - if i > 0 { - perunit = elapsed / time.Duration(i) - avg = sum / time.Duration(i) - } - log.Printf("file \"%s\": read %d entries in %v, %v/entry, %v over %v, %v avg", file, i, elapsed, perunit, over, max, avg) - } - return nil -} - -func appendFromLogFile(values []*logEntry, maxLen int, timeWindow time.Duration) []*logEntry { - a := []*logEntry{} - - onEntry := func(entry *logEntry) error { - a = append(a, entry) - if len(a) > maxLen { - toskip := len(a) - maxLen - a = a[toskip:] - } - return nil - } - - needMore := func() bool { - return true - } - - err := genericLoader(onEntry, needMore, timeWindow) - if err != nil { - log.Printf("Failed to load entries from querylog: %s", err) - return values - } - - // now that we've read all eligible entries, reverse the slice to make it go from newest->oldest - for left, right := 0, len(a)-1; left < right; left, right = left+1, right-1 { - a[left], a[right] = a[right], a[left] - } - - // append it to values - values = append(values, a...) - - // then cut off of it is bigger than maxLen - if len(values) > maxLen { - values = values[:maxLen] - } - - return values -} diff --git a/coredns_plugin/querylog_top.go b/coredns_plugin/querylog_top.go deleted file mode 100644 index d4cc6e0d..00000000 --- a/coredns_plugin/querylog_top.go +++ /dev/null @@ -1,386 +0,0 @@ -package dnsfilter - -import ( - "bytes" - "fmt" - "log" - "net/http" - "os" - "path" - "runtime" - "sort" - "strconv" - "strings" - "sync" - "time" - - "github.com/AdguardTeam/AdGuardHome/dnsfilter" - "github.com/bluele/gcache" - "github.com/miekg/dns" -) - -type hourTop struct { - domains gcache.Cache - blocked gcache.Cache - clients gcache.Cache - - mutex sync.RWMutex -} - -func (top *hourTop) init() { - top.domains = gcache.New(queryLogTopSize).LRU().Build() - top.blocked = gcache.New(queryLogTopSize).LRU().Build() - top.clients = gcache.New(queryLogTopSize).LRU().Build() -} - -type dayTop struct { - hours []*hourTop - hoursLock sync.RWMutex // writelock this lock ONLY WHEN rotating or intializing hours! - - loaded bool - loadedLock sync.Mutex -} - -var runningTop dayTop - -func init() { - runningTop.hoursWriteLock() - for i := 0; i < 24; i++ { - hour := hourTop{} - hour.init() - runningTop.hours = append(runningTop.hours, &hour) - } - runningTop.hoursWriteUnlock() -} - -func rotateHourlyTop() { - log.Printf("Rotating hourly top") - hour := &hourTop{} - hour.init() - runningTop.hoursWriteLock() - runningTop.hours = append([]*hourTop{hour}, runningTop.hours...) - runningTop.hours = runningTop.hours[:24] - runningTop.hoursWriteUnlock() -} - -func periodicHourlyTopRotate() { - t := time.Hour - for range time.Tick(t) { - rotateHourlyTop() - } -} - -func (top *hourTop) incrementValue(key string, cache gcache.Cache) error { - top.Lock() - defer top.Unlock() - ivalue, err := cache.Get(key) - if err == gcache.KeyNotFoundError { - // we just set it and we're done - err = cache.Set(key, 1) - if err != nil { - log.Printf("Failed to set hourly top value: %s", err) - return err - } - return nil - } - - if err != nil { - log.Printf("gcache encountered an error during get: %s", err) - return err - } - - cachedValue, ok := ivalue.(int) - if !ok { - err = fmt.Errorf("SHOULD NOT HAPPEN: gcache has non-int as value: %v", ivalue) - log.Println(err) - return err - } - - err = cache.Set(key, cachedValue+1) - if err != nil { - log.Printf("Failed to set hourly top value: %s", err) - return err - } - return nil -} - -func (top *hourTop) incrementDomains(key string) error { - return top.incrementValue(key, top.domains) -} - -func (top *hourTop) incrementBlocked(key string) error { - return top.incrementValue(key, top.blocked) -} - -func (top *hourTop) incrementClients(key string) error { - return top.incrementValue(key, top.clients) -} - -// if does not exist -- return 0 -func (top *hourTop) lockedGetValue(key string, cache gcache.Cache) (int, error) { - ivalue, err := cache.Get(key) - if err == gcache.KeyNotFoundError { - return 0, nil - } - - if err != nil { - log.Printf("gcache encountered an error during get: %s", err) - return 0, err - } - - value, ok := ivalue.(int) - if !ok { - err := fmt.Errorf("SHOULD NOT HAPPEN: gcache has non-int as value: %v", ivalue) - log.Println(err) - return 0, err - } - - return value, nil -} - -func (top *hourTop) lockedGetDomains(key string) (int, error) { - return top.lockedGetValue(key, top.domains) -} - -func (top *hourTop) lockedGetBlocked(key string) (int, error) { - return top.lockedGetValue(key, top.blocked) -} - -func (top *hourTop) lockedGetClients(key string) (int, error) { - return top.lockedGetValue(key, top.clients) -} - -func (r *dayTop) addEntry(entry *logEntry, q *dns.Msg, now time.Time) error { - // figure out which hour bucket it belongs to - hour := int(now.Sub(entry.Time).Hours()) - if hour >= 24 { - log.Printf("t %v is >24 hours ago, ignoring", entry.Time) - return nil - } - - hostname := strings.ToLower(strings.TrimSuffix(q.Question[0].Name, ".")) - - // get value, if not set, crate one - runningTop.hoursReadLock() - defer runningTop.hoursReadUnlock() - err := runningTop.hours[hour].incrementDomains(hostname) - if err != nil { - log.Printf("Failed to increment value: %s", err) - return err - } - - if entry.Result.IsFiltered { - err := runningTop.hours[hour].incrementBlocked(hostname) - if err != nil { - log.Printf("Failed to increment value: %s", err) - return err - } - } - - if len(entry.IP) > 0 { - err := runningTop.hours[hour].incrementClients(entry.IP) - if err != nil { - log.Printf("Failed to increment value: %s", err) - return err - } - } - - return nil -} - -func fillStatsFromQueryLog() error { - now := time.Now() - runningTop.loadedWriteLock() - defer runningTop.loadedWriteUnlock() - if runningTop.loaded { - return nil - } - onEntry := func(entry *logEntry) error { - if len(entry.Question) == 0 { - log.Printf("entry question is absent, skipping") - return nil - } - - if entry.Time.After(now) { - log.Printf("t %v vs %v is in the future, ignoring", entry.Time, now) - return nil - } - - q := new(dns.Msg) - if err := q.Unpack(entry.Question); err != nil { - log.Printf("failed to unpack dns message question: %s", err) - return err - } - - if len(q.Question) != 1 { - log.Printf("malformed dns message, has no questions, skipping") - return nil - } - - err := runningTop.addEntry(entry, q, now) - if err != nil { - log.Printf("Failed to add entry to running top: %s", err) - return err - } - - queryLogLock.Lock() - queryLogCache = append(queryLogCache, entry) - if len(queryLogCache) > queryLogSize { - toremove := len(queryLogCache) - queryLogSize - queryLogCache = queryLogCache[toremove:] - } - queryLogLock.Unlock() - - requests.IncWithTime(entry.Time) - if entry.Result.IsFiltered { - filtered.IncWithTime(entry.Time) - } - switch entry.Result.Reason { - case dnsfilter.NotFilteredWhiteList: - whitelisted.IncWithTime(entry.Time) - case dnsfilter.NotFilteredError: - errorsTotal.IncWithTime(entry.Time) - case dnsfilter.FilteredBlackList: - filteredLists.IncWithTime(entry.Time) - case dnsfilter.FilteredSafeBrowsing: - filteredSafebrowsing.IncWithTime(entry.Time) - case dnsfilter.FilteredParental: - filteredParental.IncWithTime(entry.Time) - case dnsfilter.FilteredInvalid: - // do nothing - case dnsfilter.FilteredSafeSearch: - safesearch.IncWithTime(entry.Time) - } - elapsedTime.ObserveWithTime(entry.Elapsed.Seconds(), entry.Time) - - return nil - } - - needMore := func() bool { return true } - err := genericLoader(onEntry, needMore, queryLogTimeLimit) - if err != nil { - log.Printf("Failed to load entries from querylog: %s", err) - return err - } - - runningTop.loaded = true - - return nil -} - -func HandleStatsTop(w http.ResponseWriter, r *http.Request) { - domains := map[string]int{} - blocked := map[string]int{} - clients := map[string]int{} - - do := func(keys []interface{}, getter func(key string) (int, error), result map[string]int) { - for _, ikey := range keys { - key, ok := ikey.(string) - if !ok { - continue - } - value, err := getter(key) - if err != nil { - log.Printf("Failed to get top domains value for %v: %s", key, err) - return - } - result[key] += value - } - } - - runningTop.hoursReadLock() - for hour := 0; hour < 24; hour++ { - runningTop.hours[hour].RLock() - do(runningTop.hours[hour].domains.Keys(), runningTop.hours[hour].lockedGetDomains, domains) - do(runningTop.hours[hour].blocked.Keys(), runningTop.hours[hour].lockedGetBlocked, blocked) - do(runningTop.hours[hour].clients.Keys(), runningTop.hours[hour].lockedGetClients, clients) - runningTop.hours[hour].RUnlock() - } - runningTop.hoursReadUnlock() - - // use manual json marshalling because we want maps to be sorted by value - json := bytes.Buffer{} - json.WriteString("{\n") - - gen := func(json *bytes.Buffer, name string, top map[string]int, addComma bool) { - json.WriteString(" ") - json.WriteString(fmt.Sprintf("%q", name)) - json.WriteString(": {\n") - sorted := sortByValue(top) - // no more than 50 entries - if len(sorted) > 50 { - sorted = sorted[:50] - } - for i, key := range sorted { - json.WriteString(" ") - json.WriteString(fmt.Sprintf("%q", key)) - json.WriteString(": ") - json.WriteString(strconv.Itoa(top[key])) - if i+1 != len(sorted) { - json.WriteByte(',') - } - json.WriteByte('\n') - } - json.WriteString(" }") - if addComma { - json.WriteByte(',') - } - json.WriteByte('\n') - } - gen(&json, "top_queried_domains", domains, true) - gen(&json, "top_blocked_domains", blocked, true) - gen(&json, "top_clients", clients, true) - json.WriteString(" \"stats_period\": \"24 hours\"\n") - json.WriteString("}\n") - - w.Header().Set("Content-Type", "application/json") - _, err := w.Write(json.Bytes()) - if err != nil { - errortext := fmt.Sprintf("Couldn't write body: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) - } -} - -// helper function for querylog API -func sortByValue(m map[string]int) []string { - type kv struct { - k string - v int - } - var ss []kv - for k, v := range m { - ss = append(ss, kv{k, v}) - } - sort.Slice(ss, func(l, r int) bool { - return ss[l].v > ss[r].v - }) - - sorted := []string{} - for _, v := range ss { - sorted = append(sorted, v.k) - } - return sorted -} - -func (d *dayTop) hoursWriteLock() { tracelock(); d.hoursLock.Lock() } -func (d *dayTop) hoursWriteUnlock() { tracelock(); d.hoursLock.Unlock() } -func (d *dayTop) hoursReadLock() { tracelock(); d.hoursLock.RLock() } -func (d *dayTop) hoursReadUnlock() { tracelock(); d.hoursLock.RUnlock() } -func (d *dayTop) loadedWriteLock() { tracelock(); d.loadedLock.Lock() } -func (d *dayTop) loadedWriteUnlock() { tracelock(); d.loadedLock.Unlock() } - -func (h *hourTop) Lock() { tracelock(); h.mutex.Lock() } -func (h *hourTop) RLock() { tracelock(); h.mutex.RLock() } -func (h *hourTop) RUnlock() { tracelock(); h.mutex.RUnlock() } -func (h *hourTop) Unlock() { tracelock(); h.mutex.Unlock() } - -func tracelock() { - if false { // not commented out to make code checked during compilation - pc := make([]uintptr, 10) // at least 1 entry needed - runtime.Callers(2, pc) - f := path.Base(runtime.FuncForPC(pc[1]).Name()) - lockf := path.Base(runtime.FuncForPC(pc[0]).Name()) - fmt.Fprintf(os.Stderr, "%s(): %s\n", f, lockf) - } -} diff --git a/coredns_plugin/ratelimit/ratelimit.go b/coredns_plugin/ratelimit/ratelimit.go deleted file mode 100644 index 8d3eeecc..00000000 --- a/coredns_plugin/ratelimit/ratelimit.go +++ /dev/null @@ -1,182 +0,0 @@ -package ratelimit - -import ( - "errors" - "log" - "sort" - "strconv" - "time" - - // ratelimiting and per-ip buckets - "github.com/beefsack/go-rate" - "github.com/patrickmn/go-cache" - - // coredns plugin - "github.com/coredns/coredns/core/dnsserver" - "github.com/coredns/coredns/plugin" - "github.com/coredns/coredns/plugin/metrics" - "github.com/coredns/coredns/plugin/pkg/dnstest" - "github.com/coredns/coredns/request" - "github.com/mholt/caddy" - "github.com/miekg/dns" - "github.com/prometheus/client_golang/prometheus" - "golang.org/x/net/context" -) - -const defaultRatelimit = 30 -const defaultResponseSize = 1000 - -var ( - tokenBuckets = cache.New(time.Hour, time.Hour) -) - -// ServeDNS handles the DNS request and refuses if it's an beyind specified ratelimit -func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - state := request.Request{W: w, Req: r} - ip := state.IP() - allow, err := p.allowRequest(ip) - if err != nil { - return 0, err - } - if !allow { - ratelimited.Inc() - return 0, nil - } - - // Record response to get status code and size of the reply. - rw := dnstest.NewRecorder(w) - status, err := plugin.NextOrFailure(p.Name(), p.Next, ctx, rw, r) - - size := rw.Len - - if size > defaultResponseSize && state.Proto() == "udp" { - // For large UDP responses we call allowRequest more times - // The exact number of times depends on the response size - for i := 0; i < size/defaultResponseSize; i++ { - p.allowRequest(ip) - } - } - - return status, err -} - -func (p *plug) allowRequest(ip string) (bool, error) { - if len(p.whitelist) > 0 { - i := sort.SearchStrings(p.whitelist, ip) - - if i < len(p.whitelist) && p.whitelist[i] == ip { - return true, nil - } - } - - if _, found := tokenBuckets.Get(ip); !found { - tokenBuckets.Set(ip, rate.New(p.ratelimit, time.Second), time.Hour) - } - - value, found := tokenBuckets.Get(ip) - if !found { - // should not happen since we've just inserted it - text := "SHOULD NOT HAPPEN: just-inserted ratelimiter disappeared" - log.Println(text) - err := errors.New(text) - return true, err - } - - rl, ok := value.(*rate.RateLimiter) - if !ok { - text := "SHOULD NOT HAPPEN: non-bool entry found in safebrowsing lookup cache" - log.Println(text) - err := errors.New(text) - return true, err - } - - allow, _ := rl.Try() - return allow, nil -} - -// -// helper functions -// -func init() { - caddy.RegisterPlugin("ratelimit", caddy.Plugin{ - ServerType: "dns", - Action: setup, - }) -} - -type plug struct { - Next plugin.Handler - - // configuration for creating above - ratelimit int // in requests per second per IP - whitelist []string // a list of whitelisted IP addresses -} - -func setupPlugin(c *caddy.Controller) (*plug, error) { - p := &plug{ratelimit: defaultRatelimit} - - for c.Next() { - args := c.RemainingArgs() - if len(args) > 0 { - ratelimit, err := strconv.Atoi(args[0]) - if err != nil { - return nil, c.ArgErr() - } - p.ratelimit = ratelimit - } - for c.NextBlock() { - switch c.Val() { - case "whitelist": - p.whitelist = c.RemainingArgs() - - if len(p.whitelist) > 0 { - sort.Strings(p.whitelist) - } - } - } - } - - return p, nil -} - -func setup(c *caddy.Controller) error { - p, err := setupPlugin(c) - if err != nil { - return err - } - - config := dnsserver.GetConfig(c) - config.AddPlugin(func(next plugin.Handler) plugin.Handler { - p.Next = next - return p - }) - - c.OnStartup(func() error { - m := dnsserver.GetConfig(c).Handler("prometheus") - if m == nil { - return nil - } - if x, ok := m.(*metrics.Metrics); ok { - x.MustRegister(ratelimited) - } - return nil - }) - - return nil -} - -func newDNSCounter(name string, help string) prometheus.Counter { - return prometheus.NewCounter(prometheus.CounterOpts{ - Namespace: plugin.Namespace, - Subsystem: "ratelimit", - Name: name, - Help: help, - }) -} - -var ( - ratelimited = newDNSCounter("dropped_total", "Count of requests that have been dropped because of rate limit") -) - -// Name returns name of the plugin as seen in Corefile and plugin.cfg -func (p *plug) Name() string { return "ratelimit" } diff --git a/coredns_plugin/ratelimit/ratelimit_test.go b/coredns_plugin/ratelimit/ratelimit_test.go deleted file mode 100644 index b426f2eb..00000000 --- a/coredns_plugin/ratelimit/ratelimit_test.go +++ /dev/null @@ -1,80 +0,0 @@ -package ratelimit - -import ( - "testing" - - "github.com/mholt/caddy" -) - -func TestSetup(t *testing.T) { - for i, testcase := range []struct { - config string - failing bool - }{ - {`ratelimit`, false}, - {`ratelimit 100`, false}, - {`ratelimit { - whitelist 127.0.0.1 - }`, false}, - {`ratelimit 50 { - whitelist 127.0.0.1 176.103.130.130 - }`, false}, - {`ratelimit test`, true}, - } { - c := caddy.NewTestController("dns", testcase.config) - err := setup(c) - if err != nil { - if !testcase.failing { - t.Fatalf("Test #%d expected no errors, but got: %v", i, err) - } - continue - } - if testcase.failing { - t.Fatalf("Test #%d expected to fail but it didn't", i) - } - } -} - -func TestRatelimiting(t *testing.T) { - // rate limit is 1 per sec - c := caddy.NewTestController("dns", `ratelimit 1`) - p, err := setupPlugin(c) - - if err != nil { - t.Fatal("Failed to initialize the plugin") - } - - allowed, err := p.allowRequest("127.0.0.1") - - if err != nil || !allowed { - t.Fatal("First request must have been allowed") - } - - allowed, err = p.allowRequest("127.0.0.1") - - if err != nil || allowed { - t.Fatal("Second request must have been ratelimited") - } -} - -func TestWhitelist(t *testing.T) { - // rate limit is 1 per sec - c := caddy.NewTestController("dns", `ratelimit 1 { whitelist 127.0.0.2 127.0.0.1 127.0.0.125 }`) - p, err := setupPlugin(c) - - if err != nil { - t.Fatal("Failed to initialize the plugin") - } - - allowed, err := p.allowRequest("127.0.0.1") - - if err != nil || !allowed { - t.Fatal("First request must have been allowed") - } - - allowed, err = p.allowRequest("127.0.0.1") - - if err != nil || !allowed { - t.Fatal("Second request must have been allowed due to whitelist") - } -} diff --git a/coredns_plugin/refuseany/refuseany.go b/coredns_plugin/refuseany/refuseany.go deleted file mode 100644 index 92d5d508..00000000 --- a/coredns_plugin/refuseany/refuseany.go +++ /dev/null @@ -1,91 +0,0 @@ -package refuseany - -import ( - "fmt" - "log" - - "github.com/coredns/coredns/core/dnsserver" - "github.com/coredns/coredns/plugin" - "github.com/coredns/coredns/plugin/metrics" - "github.com/coredns/coredns/request" - "github.com/mholt/caddy" - "github.com/miekg/dns" - "github.com/prometheus/client_golang/prometheus" - "golang.org/x/net/context" -) - -type plug struct { - Next plugin.Handler -} - -// ServeDNS handles the DNS request and refuses if it's an ANY request -func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - if len(r.Question) != 1 { - // google DNS, bind and others do the same - return dns.RcodeFormatError, fmt.Errorf("Got DNS request with != 1 questions") - } - - q := r.Question[0] - if q.Qtype == dns.TypeANY { - state := request.Request{W: w, Req: r, Context: ctx} - rcode := dns.RcodeNotImplemented - - m := new(dns.Msg) - m.SetRcode(r, rcode) - state.SizeAndDo(m) - err := state.W.WriteMsg(m) - if err != nil { - log.Printf("Got error %s\n", err) - return dns.RcodeServerFailure, err - } - return rcode, nil - } - - return plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r) -} - -func init() { - caddy.RegisterPlugin("refuseany", caddy.Plugin{ - ServerType: "dns", - Action: setup, - }) -} - -func setup(c *caddy.Controller) error { - p := &plug{} - config := dnsserver.GetConfig(c) - - config.AddPlugin(func(next plugin.Handler) plugin.Handler { - p.Next = next - return p - }) - - c.OnStartup(func() error { - m := dnsserver.GetConfig(c).Handler("prometheus") - if m == nil { - return nil - } - if x, ok := m.(*metrics.Metrics); ok { - x.MustRegister(ratelimited) - } - return nil - }) - - return nil -} - -func newDNSCounter(name string, help string) prometheus.Counter { - return prometheus.NewCounter(prometheus.CounterOpts{ - Namespace: plugin.Namespace, - Subsystem: "refuseany", - Name: name, - Help: help, - }) -} - -var ( - ratelimited = newDNSCounter("refusedany_total", "Count of ANY requests that have been dropped") -) - -// Name returns name of the plugin as seen in Corefile and plugin.cfg -func (p *plug) Name() string { return "refuseany" } diff --git a/coredns_plugin/reload.go b/coredns_plugin/reload.go deleted file mode 100644 index 880a3acc..00000000 --- a/coredns_plugin/reload.go +++ /dev/null @@ -1,36 +0,0 @@ -package dnsfilter - -import ( - "log" - - "github.com/mholt/caddy" -) - -var Reload = make(chan bool) - -func hook(event caddy.EventName, info interface{}) error { - if event != caddy.InstanceStartupEvent { - return nil - } - - // this should be an instance. ok to panic if not - instance := info.(*caddy.Instance) - - go func() { - for range Reload { - corefile, err := caddy.LoadCaddyfile(instance.Caddyfile().ServerType()) - if err != nil { - continue - } - _, err = instance.Restart(corefile) - if err != nil { - log.Printf("Corefile changed but reload failed: %s", err) - continue - } - // hook will be called again from new instance - return - } - }() - - return nil -} diff --git a/go.mod b/go.mod index 1b8d78e6..166e3cce 100644 --- a/go.mod +++ b/go.mod @@ -3,35 +3,19 @@ module github.com/AdguardTeam/AdGuardHome require ( github.com/StackExchange/wmi v0.0.0-20180725035823-b12b22c5341f // indirect github.com/beefsack/go-rate v0.0.0-20180408011153-efa7637bb9b6 - github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 // indirect github.com/bluele/gcache v0.0.0-20171010155617-472614239ac7 - github.com/coredns/coredns v1.2.6 - github.com/dnstap/golang-dnstap v0.0.0-20170829151710-2cf77a2b5e11 // indirect - github.com/farsightsec/golang-framestream v0.0.0-20181102145529-8a0cb8ba8710 // indirect - github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect github.com/go-ole/go-ole v1.2.1 // indirect github.com/go-test/deep v1.0.1 github.com/gobuffalo/packr v1.19.0 - github.com/google/uuid v1.0.0 // indirect - github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 // indirect github.com/joomcode/errorx v0.1.0 - github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect - github.com/mholt/caddy v0.11.0 github.com/miekg/dns v1.0.15 - github.com/opentracing/opentracing-go v1.0.2 // indirect github.com/patrickmn/go-cache v2.1.0+incompatible - github.com/pkg/errors v0.8.0 - github.com/prometheus/client_golang v0.9.0-pre1 - github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910 // indirect - github.com/prometheus/common v0.0.0-20181109100915-0b1957f9d949 // indirect - github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d // indirect github.com/shirou/gopsutil v2.18.10+incompatible github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4 // indirect go.uber.org/goleak v0.10.0 golang.org/x/crypto v0.0.0-20181106171534-e4dc69e5b2fd golang.org/x/net v0.0.0-20181108082009-03003ca0c849 golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8 // indirect - google.golang.org/grpc v1.16.0 // indirect gopkg.in/asaskevich/govalidator.v4 v4.0.0-20160518190739-766470278477 gopkg.in/yaml.v2 v2.2.1 ) diff --git a/go.sum b/go.sum index 4ecb93be..af10df24 100644 --- a/go.sum +++ b/go.sum @@ -1,23 +1,11 @@ -cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/StackExchange/wmi v0.0.0-20180725035823-b12b22c5341f h1:5ZfJxyXo8KyX8DgGXC5B7ILL8y51fci/qYz2B4j8iLY= github.com/StackExchange/wmi v0.0.0-20180725035823-b12b22c5341f/go.mod h1:3eOhrUMpNV+6aFIbp5/iudMxNCF27Vw2OZgy4xEx0Fg= github.com/beefsack/go-rate v0.0.0-20180408011153-efa7637bb9b6 h1:KXlsf+qt/X5ttPGEjR0tPH1xaWWoKBEg9Q1THAj2h3I= github.com/beefsack/go-rate v0.0.0-20180408011153-efa7637bb9b6/go.mod h1:6YNgTHLutezwnBvyneBbwvB8C82y3dcoOj5EQJIdGXA= -github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 h1:xJ4a3vCFaGF/jqvzLMYoU8P317H5OQ+Via4RmuPwCS0= -github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/bluele/gcache v0.0.0-20171010155617-472614239ac7 h1:NpQ+gkFOH27AyDypSCJ/LdsIi/b4rdnEb1N5+IpFfYs= github.com/bluele/gcache v0.0.0-20171010155617-472614239ac7/go.mod h1:8c4/i2VlovMO2gBnHGQPN5EJw+H0lx1u/5p+cgsXtCk= -github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/coredns/coredns v1.2.6 h1:QIAOkBqVE44Zx0ttrFqgE5YhCEn64XPIngU60JyuTGM= -github.com/coredns/coredns v1.2.6/go.mod h1:zASH/MVDgR6XZTbxvOnsZfffS+31vg6Ackf/wo1+AM0= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dnstap/golang-dnstap v0.0.0-20170829151710-2cf77a2b5e11 h1:m8nX8hsUghn853BJ5qB0lX+VvS6LTJPksWyILFZRYN4= -github.com/dnstap/golang-dnstap v0.0.0-20170829151710-2cf77a2b5e11/go.mod h1:s1PfVYYVmTMgCSPtho4LKBDecEHJWtiVDPNv78Z985U= -github.com/farsightsec/golang-framestream v0.0.0-20181102145529-8a0cb8ba8710 h1:QdyRyGZWLEvJG5Kw3VcVJvhXJ5tZ1MkRgqpJOEZSySM= -github.com/farsightsec/golang-framestream v0.0.0-20181102145529-8a0cb8ba8710/go.mod h1:eNde4IQyEiA5br02AouhEHCu3p3UzrCdFR4LuQHklMI= -github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 h1:BHsljHzVlRcyQhjrss6TZTdY2VfCqZPbv5k3iBFa2ZQ= -github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= github.com/go-ole/go-ole v1.2.1 h1:2lOsA72HgjxAuMlKpFiCbHTvu44PIVkZ5hqm3RSdI/E= github.com/go-ole/go-ole v1.2.1/go.mod h1:7FAglXiTm7HKlQRDeOQ6ZNUHidzCWXuZWq/1dTyBNF8= github.com/go-test/deep v1.0.1 h1:UQhStjbkDClarlmv0am7OXXO4/GaPdCGiUiMTvi28sg= @@ -28,46 +16,21 @@ github.com/gobuffalo/packd v0.0.0-20181031195726-c82734870264 h1:roWyi0eEdiFreSq github.com/gobuffalo/packd v0.0.0-20181031195726-c82734870264/go.mod h1:Yf2toFaISlyQrr5TfO3h6DB9pl9mZRmyvBGQb/aQ/pI= github.com/gobuffalo/packr v1.19.0 h1:3UDmBDxesCOPF8iZdMDBBWKfkBoYujIMIZePnobqIUI= github.com/gobuffalo/packr v1.19.0/go.mod h1:MstrNkfCQhd5o+Ct4IJ0skWlxN8emOq8DsoT1G98VIU= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= -github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/google/uuid v1.0.0 h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA= -github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 h1:MJG/KsmcqMwFAkh8mTnAwhyKoB+sTAnY4CACC110tbU= -github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645/go.mod h1:6iZfnjpejD4L/4DwD7NryNaJyCQdzwWwH2MWhCA90Kw= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/joho/godotenv v1.3.0 h1:Zjp+RcGpHhGlrMbJzXTrZZPrWj+1vfm90La1wgB6Bhc= github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= github.com/joomcode/errorx v0.1.0 h1:QmJMiI1DE1UFje2aI1ZWO/VMT5a32qBoXUclGOt8vsc= github.com/joomcode/errorx v0.1.0/go.mod h1:kgco15ekB6cs+4Xjzo7SPeXzx38PbJzBwbnu9qfVNHQ= -github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/markbates/oncer v0.0.0-20181014194634-05fccaae8fc4 h1:Mlji5gkcpzkqTROyE4ZxZ8hN7osunMb2RuGVrbvMvCc= github.com/markbates/oncer v0.0.0-20181014194634-05fccaae8fc4/go.mod h1:Ld9puTsIW75CHf65OeIOkyKbteujpZVXDpWK6YGZbxE= -github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= -github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/mholt/caddy v0.11.0 h1:cuhEyR7So/SBBRiAaiRBe9BoccDu6uveIPuM9FMMavg= -github.com/mholt/caddy v0.11.0/go.mod h1:Wb1PlT4DAYSqOEd03MsqkdkXnTxA8v9pKjdpxbqM1kY= github.com/miekg/dns v1.0.15 h1:9+UupePBQCG6zf1q/bGmTO1vumoG13jsrbWOSX1W6Tw= github.com/miekg/dns v1.0.15/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= -github.com/opentracing/opentracing-go v1.0.2 h1:3jA2P6O1F9UOrWVpwrIo17pu01KWvNWg4X946/Y5Zwg= -github.com/opentracing/opentracing-go v1.0.2/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_golang v0.9.0-pre1 h1:AWTOhsOI9qxeirTuA0A4By/1Es1+y9EcCGY6bBZ2fhM= -github.com/prometheus/client_golang v0.9.0-pre1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= -github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910 h1:idejC8f05m9MGOsuEi1ATq9shN03HrxNkD/luQvxCv8= -github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= -github.com/prometheus/common v0.0.0-20181109100915-0b1957f9d949 h1:MVbUQq1a49hMEISI29UcAUjywT3FyvDwx5up90OvVa4= -github.com/prometheus/common v0.0.0-20181109100915-0b1957f9d949/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= -github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d h1:GoAlyOgbOEIFdaDqxJVlbOQ1DtGmZWs/Qau0hIlk+WQ= -github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/shirou/gopsutil v2.18.10+incompatible h1:cy84jW6EVRPa5g9HAHrlbxMSIjBhDSX0OFYyMYminYs= github.com/shirou/gopsutil v2.18.10+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4 h1:udFKJ0aHUL60LboW/A+DfgoHVedieIzIXE8uylPue0U= @@ -82,29 +45,16 @@ go.uber.org/goleak v0.10.0 h1:G3eWbSNIskeRqtsN/1uI5B+eP73y3JUuBsv9AZjehb4= go.uber.org/goleak v0.10.0/go.mod h1:VCZuO8V8mFPlL0F5J5GK1rtHV3DrFcQ1R8ryq7FK0aI= golang.org/x/crypto v0.0.0-20181106171534-e4dc69e5b2fd h1:VtIkGDhk0ph3t+THbvXHfMZ8QHgsBO39Nh52+74pq7w= golang.org/x/crypto v0.0.0-20181106171534-e4dc69e5b2fd/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181102091132-c10e9556a7bc/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181108082009-03003ca0c849 h1:FSqE2GGG7wzsYUsWiQ8MZrvEd1EOyU3NCF0AW3Wtltg= golang.org/x/net v0.0.0-20181108082009-03003ca0c849/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8 h1:YoY1wS6JYVRpIfFngRf2HHo9R9dAne3xbkGOQ5rJXjU= golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8 h1:Nw54tB0rB7hY/N0NQvRW8DG4Yk3Q6T9cu9RcFQDu1tc= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/grpc v1.16.0 h1:dz5IJGuC2BB7qXR5AyHNwAUBhZscK2xVez7mznh72sY= -google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= gopkg.in/asaskevich/govalidator.v4 v4.0.0-20160518190739-766470278477 h1:5xUJw+lg4zao9W4HIDzlFbMYgSgtvNVHh00MEHvbGpQ= gopkg.in/asaskevich/govalidator.v4 v4.0.0-20160518190739-766470278477/go.mod h1:QDV1vrFSrowdoOba0UM8VJPUZONT7dnfdLsM+GG53Z8= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/upstream/dns_upstream.go b/upstream/dns_upstream.go deleted file mode 100644 index 171f6362..00000000 --- a/upstream/dns_upstream.go +++ /dev/null @@ -1,105 +0,0 @@ -package upstream - -import ( - "crypto/tls" - "time" - - "github.com/miekg/dns" - "golang.org/x/net/context" -) - -// DnsUpstream is a very simple upstream implementation for plain DNS -type DnsUpstream struct { - endpoint string // IP:port - timeout time.Duration // Max read and write timeout - proto string // Protocol (tcp, tcp-tls, or udp) - transport *Transport // Persistent connections cache -} - -// NewDnsUpstream creates a new DNS upstream -func NewDnsUpstream(endpoint string, proto string, tlsServerName string) (Upstream, error) { - u := &DnsUpstream{ - endpoint: endpoint, - timeout: defaultTimeout, - proto: proto, - } - - var tlsConfig *tls.Config - - if proto == "tcp-tls" { - tlsConfig = new(tls.Config) - tlsConfig.ServerName = tlsServerName - } - - // Initialize the connections cache - u.transport = NewTransport(endpoint) - u.transport.tlsConfig = tlsConfig - u.transport.Start() - - return u, nil -} - -// Exchange provides an implementation for the Upstream interface -func (u *DnsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) { - resp, err := u.exchange(u.proto, query) - - // Retry over TCP if response is truncated - if err == dns.ErrTruncated && u.proto == "udp" { - resp, err = u.exchange("tcp", query) - } else if err == dns.ErrTruncated && resp != nil { - // Reassemble something to be sent to client - m := new(dns.Msg) - m.SetReply(query) - m.Truncated = true - m.Authoritative = true - m.Rcode = dns.RcodeSuccess - return m, nil - } - - if err != nil { - resp = &dns.Msg{} - resp.SetRcode(resp, dns.RcodeServerFailure) - } - - return resp, err -} - -// Clear resources -func (u *DnsUpstream) Close() error { - // Close active connections - u.transport.Stop() - return nil -} - -// Performs a synchronous query. It sends the message m via the conn -// c and waits for a reply. The conn c is not closed. -func (u *DnsUpstream) exchange(proto string, query *dns.Msg) (r *dns.Msg, err error) { - // Establish a connection if needed (or reuse cached) - conn, err := u.transport.Dial(proto) - if err != nil { - return nil, err - } - - // Write the request with a timeout - conn.SetWriteDeadline(time.Now().Add(u.timeout)) - if err = conn.WriteMsg(query); err != nil { - conn.Close() // Not giving it back - return nil, err - } - - // Write response with a timeout - conn.SetReadDeadline(time.Now().Add(u.timeout)) - r, err = conn.ReadMsg() - if err != nil { - conn.Close() // Not giving it back - } else if err == nil && r.Id != query.Id { - err = dns.ErrId - conn.Close() // Not giving it back - } - - if err == nil { - // Return it back to the connections cache if there were no errors - u.transport.Yield(conn) - } - return r, err -} diff --git a/upstream/helpers.go b/upstream/helpers.go deleted file mode 100644 index 520a7a8b..00000000 --- a/upstream/helpers.go +++ /dev/null @@ -1,98 +0,0 @@ -package upstream - -import ( - "net" - "strings" - - "github.com/miekg/dns" - "golang.org/x/net/context" -) - -// Detects the upstream type from the specified url and creates a proper Upstream object -func NewUpstream(url string, bootstrap string) (Upstream, error) { - proto := "udp" - prefix := "" - - switch { - case strings.HasPrefix(url, "tcp://"): - proto = "tcp" - prefix = "tcp://" - case strings.HasPrefix(url, "tls://"): - proto = "tcp-tls" - prefix = "tls://" - case strings.HasPrefix(url, "https://"): - return NewHttpsUpstream(url, bootstrap) - } - - hostname := strings.TrimPrefix(url, prefix) - - host, port, err := net.SplitHostPort(hostname) - if err != nil { - // Set port depending on the protocol - switch proto { - case "udp": - port = "53" - case "tcp": - port = "53" - case "tcp-tls": - port = "853" - } - - // Set host = hostname - host = hostname - } - - // Try to resolve the host address (or check if it's an IP address) - bootstrapResolver := CreateResolver(bootstrap) - ips, err := bootstrapResolver.LookupIPAddr(context.Background(), host) - - if err != nil || len(ips) == 0 { - return nil, err - } - - addr := ips[0].String() - endpoint := net.JoinHostPort(addr, port) - tlsServerName := "" - - if proto == "tcp-tls" && host != addr { - // Check if we need to specify TLS server name - tlsServerName = host - } - - return NewDnsUpstream(endpoint, proto, tlsServerName) -} - -func CreateResolver(bootstrap string) *net.Resolver { - bootstrapResolver := net.DefaultResolver - - if bootstrap != "" { - bootstrapResolver = &net.Resolver{ - PreferGo: true, - Dial: func(ctx context.Context, network, address string) (net.Conn, error) { - var d net.Dialer - return d.DialContext(ctx, network, bootstrap) - }, - } - } - - return bootstrapResolver -} - -// Performs a simple health-check of the specified upstream -func IsAlive(u Upstream) (bool, error) { - // Using ipv4only.arpa. domain as it is a part of DNS64 RFC and it should exist everywhere - ping := new(dns.Msg) - ping.SetQuestion("ipv4only.arpa.", dns.TypeA) - - resp, err := u.Exchange(context.Background(), ping) - - // If we got a header, we're alright, basically only care about I/O errors 'n stuff. - if err != nil && resp != nil { - // Silly check, something sane came back. - if resp.Rcode != dns.RcodeServerFailure { - err = nil - } - } - - return err == nil, err -} diff --git a/upstream/https_upstream.go b/upstream/https_upstream.go deleted file mode 100644 index d7d7bdde..00000000 --- a/upstream/https_upstream.go +++ /dev/null @@ -1,128 +0,0 @@ -package upstream - -import ( - "bytes" - "crypto/tls" - "fmt" - "io/ioutil" - "log" - "net" - "net/http" - "net/url" - "time" - - "github.com/miekg/dns" - "github.com/pkg/errors" - "golang.org/x/net/context" - "golang.org/x/net/http2" -) - -const ( - dnsMessageContentType = "application/dns-message" - defaultKeepAlive = 30 * time.Second -) - -// HttpsUpstream is the upstream implementation for DNS-over-HTTPS -type HttpsUpstream struct { - client *http.Client - endpoint *url.URL -} - -// NewHttpsUpstream creates a new DNS-over-HTTPS upstream from the specified url -func NewHttpsUpstream(endpoint string, bootstrap string) (Upstream, error) { - u, err := url.Parse(endpoint) - if err != nil { - return nil, err - } - - // Initialize bootstrap resolver - bootstrapResolver := CreateResolver(bootstrap) - dialer := &net.Dialer{ - Timeout: defaultTimeout, - KeepAlive: defaultKeepAlive, - DualStack: true, - Resolver: bootstrapResolver, - } - - // Update TLS and HTTP client configuration - tlsConfig := &tls.Config{ServerName: u.Hostname()} - transport := &http.Transport{ - TLSClientConfig: tlsConfig, - DisableCompression: true, - MaxIdleConns: 1, - DialContext: dialer.DialContext, - } - http2.ConfigureTransport(transport) - - client := &http.Client{ - Timeout: defaultTimeout, - Transport: transport, - } - - return &HttpsUpstream{client: client, endpoint: u}, nil -} - -// Exchange provides an implementation for the Upstream interface -func (u *HttpsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) { - queryBuf, err := query.Pack() - if err != nil { - return nil, errors.Wrap(err, "failed to pack DNS query") - } - - // No content negotiation for now, use DNS wire format - buf, backendErr := u.exchangeWireformat(queryBuf) - if backendErr == nil { - response := &dns.Msg{} - if err := response.Unpack(buf); err != nil { - return nil, errors.Wrap(err, "failed to unpack DNS response from body") - } - - response.Id = query.Id - return response, nil - } - - log.Printf("failed to connect to an HTTPS backend %q due to %s", u.endpoint, backendErr) - return nil, backendErr -} - -// Perform message exchange with the default UDP wireformat defined in current draft -// https://tools.ietf.org/html/draft-ietf-doh-dns-over-https-10 -func (u *HttpsUpstream) exchangeWireformat(msg []byte) ([]byte, error) { - req, err := http.NewRequest("POST", u.endpoint.String(), bytes.NewBuffer(msg)) - if err != nil { - return nil, errors.Wrap(err, "failed to create an HTTPS request") - } - - req.Header.Add("Content-Type", dnsMessageContentType) - req.Header.Add("Accept", dnsMessageContentType) - req.Host = u.endpoint.Hostname() - - resp, err := u.client.Do(req) - if err != nil { - return nil, errors.Wrap(err, "failed to perform an HTTPS request") - } - - // Check response status code - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("returned status code %d", resp.StatusCode) - } - - contentType := resp.Header.Get("Content-Type") - if contentType != dnsMessageContentType { - return nil, fmt.Errorf("return wrong content type %s", contentType) - } - - // Read application/dns-message response from the body - buf, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, errors.Wrap(err, "failed to read the response body") - } - - return buf, nil -} - -// Clear resources -func (u *HttpsUpstream) Close() error { - return nil -} diff --git a/upstream/persistent.go b/upstream/persistent.go deleted file mode 100644 index 91cc9094..00000000 --- a/upstream/persistent.go +++ /dev/null @@ -1,210 +0,0 @@ -package upstream - -import ( - "crypto/tls" - "net" - "sort" - "sync/atomic" - "time" - - "github.com/miekg/dns" -) - -// Persistent connections cache -- almost similar to the same used in the CoreDNS forward plugin - -const ( - defaultExpire = 10 * time.Second - minDialTimeout = 100 * time.Millisecond - maxDialTimeout = 30 * time.Second - defaultDialTimeout = 30 * time.Second - cumulativeAvgWeight = 4 -) - -// a persistConn hold the dns.Conn and the last used time. -type persistConn struct { - c *dns.Conn - used time.Time -} - -// Transport hold the persistent cache. -type Transport struct { - avgDialTime int64 // kind of average time of dial time - conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls. - expire time.Duration // After this duration a connection is expired. - addr string - tlsConfig *tls.Config - - dial chan string - yield chan *dns.Conn - ret chan *dns.Conn - stop chan bool -} - -// Dial dials the address configured in transport, potentially reusing a connection or creating a new one. -func (t *Transport) Dial(proto string) (*dns.Conn, error) { - // If tls has been configured; use it. - if t.tlsConfig != nil { - proto = "tcp-tls" - } - - t.dial <- proto - c := <-t.ret - - if c != nil { - return c, nil - } - - reqTime := time.Now() - timeout := t.dialTimeout() - if proto == "tcp-tls" { - conn, err := dns.DialTimeoutWithTLS(proto, t.addr, t.tlsConfig, timeout) - t.updateDialTimeout(time.Since(reqTime)) - return conn, err - } - conn, err := dns.DialTimeout(proto, t.addr, timeout) - t.updateDialTimeout(time.Since(reqTime)) - return conn, err -} - -// Yield return the connection to transport for reuse. -func (t *Transport) Yield(c *dns.Conn) { t.yield <- c } - -// Start starts the transport's connection manager. -func (t *Transport) Start() { go t.connManager() } - -// Stop stops the transport's connection manager. -func (t *Transport) Stop() { close(t.stop) } - -// SetExpire sets the connection expire time in transport. -func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire } - -// SetTLSConfig sets the TLS config in transport. -func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg } - -func NewTransport(addr string) *Transport { - t := &Transport{ - avgDialTime: int64(defaultDialTimeout / 2), - conns: make(map[string][]*persistConn), - expire: defaultExpire, - addr: addr, - dial: make(chan string), - yield: make(chan *dns.Conn), - ret: make(chan *dns.Conn), - stop: make(chan bool), - } - return t -} - -func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) { - dt := time.Duration(atomic.LoadInt64(currentAvg)) - atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight) -} - -func (t *Transport) dialTimeout() time.Duration { - return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout) -} - -func (t *Transport) updateDialTimeout(newDialTime time.Duration) { - averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight) -} - -// limitTimeout is a utility function to auto-tune timeout values -// average observed time is moved towards the last observed delay moderated by a weight -// next timeout to use will be the double of the computed average, limited by min and max frame. -func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration { - rt := time.Duration(atomic.LoadInt64(currentAvg)) - if rt < minValue { - return minValue - } - if rt < maxValue/2 { - return 2 * rt - } - return maxValue -} - -// connManagers manages the persistent connection cache for UDP and TCP. -func (t *Transport) connManager() { - ticker := time.NewTicker(t.expire) -Wait: - for { - select { - case proto := <-t.dial: - // take the last used conn - complexity O(1) - if stack := t.conns[proto]; len(stack) > 0 { - pc := stack[len(stack)-1] - if time.Since(pc.used) < t.expire { - // Found one, remove from pool and return this conn. - t.conns[proto] = stack[:len(stack)-1] - t.ret <- pc.c - continue Wait - } - // clear entire cache if the last conn is expired - t.conns[proto] = nil - // now, the connections being passed to closeConns() are not reachable from - // transport methods anymore. So, it's safe to close them in a separate goroutine - go closeConns(stack) - } - - t.ret <- nil - - case conn := <-t.yield: - - // no proto here, infer from config and conn - if _, ok := conn.Conn.(*net.UDPConn); ok { - t.conns["udp"] = append(t.conns["udp"], &persistConn{conn, time.Now()}) - continue Wait - } - - if t.tlsConfig == nil { - t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn, time.Now()}) - continue Wait - } - - t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn, time.Now()}) - - case <-ticker.C: - t.cleanup(false) - - case <-t.stop: - t.cleanup(true) - close(t.ret) - return - } - } -} - -// closeConns closes connections. -func closeConns(conns []*persistConn) { - for _, pc := range conns { - pc.c.Close() - } -} - -// cleanup removes connections from cache. -func (t *Transport) cleanup(all bool) { - staleTime := time.Now().Add(-t.expire) - for proto, stack := range t.conns { - if len(stack) == 0 { - continue - } - if all { - t.conns[proto] = nil - // now, the connections being passed to closeConns() are not reachable from - // transport methods anymore. So, it's safe to close them in a separate goroutine - go closeConns(stack) - continue - } - if stack[0].used.After(staleTime) { - continue - } - - // connections in stack are sorted by "used" - good := sort.Search(len(stack), func(i int) bool { - return stack[i].used.After(staleTime) - }) - t.conns[proto] = stack[good:] - // now, the connections being passed to closeConns() are not reachable from - // transport methods anymore. So, it's safe to close them in a separate goroutine - go closeConns(stack[:good]) - } -} diff --git a/upstream/setup.go b/upstream/setup.go deleted file mode 100644 index 4aed6bcf..00000000 --- a/upstream/setup.go +++ /dev/null @@ -1,81 +0,0 @@ -package upstream - -import ( - "log" - - "github.com/coredns/coredns/core/dnsserver" - "github.com/coredns/coredns/plugin" - "github.com/mholt/caddy" -) - -func init() { - caddy.RegisterPlugin("upstream", caddy.Plugin{ - ServerType: "dns", - Action: setup, - }) -} - -// Read the configuration and initialize upstreams -func setup(c *caddy.Controller) error { - p, err := setupPlugin(c) - if err != nil { - return err - } - config := dnsserver.GetConfig(c) - config.AddPlugin(func(next plugin.Handler) plugin.Handler { - p.Next = next - return p - }) - - c.OnShutdown(p.onShutdown) - return nil -} - -// Read the configuration -func setupPlugin(c *caddy.Controller) (*UpstreamPlugin, error) { - p := New() - - log.Println("Initializing the Upstream plugin") - - bootstrap := "" - upstreamUrls := []string{} - for c.Next() { - args := c.RemainingArgs() - if len(args) > 0 { - upstreamUrls = append(upstreamUrls, args...) - } - for c.NextBlock() { - switch c.Val() { - case "bootstrap": - if !c.NextArg() { - return nil, c.ArgErr() - } - bootstrap = c.Val() - } - } - } - - for _, url := range upstreamUrls { - u, err := NewUpstream(url, bootstrap) - if err != nil { - log.Printf("Cannot initialize upstream %s", url) - return nil, err - } - - p.Upstreams = append(p.Upstreams, u) - } - - return p, nil -} - -func (p *UpstreamPlugin) onShutdown() error { - for i := range p.Upstreams { - u := p.Upstreams[i] - err := u.Close() - if err != nil { - log.Printf("Error while closing the upstream: %s", err) - } - } - - return nil -} diff --git a/upstream/setup_test.go b/upstream/setup_test.go deleted file mode 100644 index 82b8ab5c..00000000 --- a/upstream/setup_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package upstream - -import ( - "testing" - - "github.com/mholt/caddy" -) - -func TestSetup(t *testing.T) { - var tests = []struct { - config string - }{ - {`upstream 8.8.8.8`}, - {`upstream 8.8.8.8 { - bootstrap 8.8.8.8:53 -}`}, - {`upstream tls://1.1.1.1 8.8.8.8 { - bootstrap 1.1.1.1 -}`}, - } - - for _, test := range tests { - c := caddy.NewTestController("dns", test.config) - err := setup(c) - if err != nil { - t.Fatalf("Test failed") - } - } -} diff --git a/upstream/upstream.go b/upstream/upstream.go deleted file mode 100644 index faef224e..00000000 --- a/upstream/upstream.go +++ /dev/null @@ -1,57 +0,0 @@ -package upstream - -import ( - "time" - - "github.com/coredns/coredns/plugin" - "github.com/miekg/dns" - "github.com/pkg/errors" - "golang.org/x/net/context" -) - -const ( - defaultTimeout = 5 * time.Second -) - -// Upstream is a simplified interface for proxy destination -type Upstream interface { - Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) - Close() error -} - -// UpstreamPlugin is a simplified DNS proxy using a generic upstream interface -type UpstreamPlugin struct { - Upstreams []Upstream - Next plugin.Handler -} - -// Initialize the upstream plugin -func New() *UpstreamPlugin { - p := &UpstreamPlugin{ - Upstreams: []Upstream{}, - } - - return p -} - -// ServeDNS implements interface for CoreDNS plugin -func (p *UpstreamPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - var reply *dns.Msg - var backendErr error - - for i := range p.Upstreams { - upstream := p.Upstreams[i] - reply, backendErr = upstream.Exchange(ctx, r) - if backendErr == nil { - w.WriteMsg(reply) - return 0, nil - } - } - - return dns.RcodeServerFailure, errors.Wrap(backendErr, "failed to contact any of the upstreams") -} - -// Name implements interface for CoreDNS plugin -func (p *UpstreamPlugin) Name() string { - return "upstream" -} diff --git a/upstream/upstream_test.go b/upstream/upstream_test.go deleted file mode 100644 index 9221e6f5..00000000 --- a/upstream/upstream_test.go +++ /dev/null @@ -1,187 +0,0 @@ -package upstream - -import ( - "net" - "testing" - - "github.com/miekg/dns" - "golang.org/x/net/context" -) - -func TestDnsUpstreamIsAlive(t *testing.T) { - var tests = []struct { - url string - bootstrap string - }{ - {"8.8.8.8:53", "8.8.8.8:53"}, - {"1.1.1.1", ""}, - {"tcp://1.1.1.1:53", ""}, - {"176.103.130.130:5353", ""}, - } - - for _, test := range tests { - u, err := NewUpstream(test.url, test.bootstrap) - - if err != nil { - t.Errorf("cannot create a DNS upstream") - } - - testUpstreamIsAlive(t, u) - } -} - -func TestHttpsUpstreamIsAlive(t *testing.T) { - var tests = []struct { - url string - bootstrap string - }{ - {"https://cloudflare-dns.com/dns-query", "8.8.8.8:53"}, - {"https://dns.google.com/experimental", "8.8.8.8:53"}, - {"https://doh.cleanbrowsing.org/doh/security-filter/", ""}, - } - - for _, test := range tests { - u, err := NewUpstream(test.url, test.bootstrap) - - if err != nil { - t.Errorf("cannot create a DNS-over-HTTPS upstream") - } - - testUpstreamIsAlive(t, u) - } -} - -func TestDnsOverTlsIsAlive(t *testing.T) { - var tests = []struct { - url string - bootstrap string - }{ - {"tls://1.1.1.1", ""}, - {"tls://9.9.9.9:853", ""}, - {"tls://security-filter-dns.cleanbrowsing.org", "8.8.8.8:53"}, - {"tls://adult-filter-dns.cleanbrowsing.org:853", "8.8.8.8:53"}, - } - - for _, test := range tests { - u, err := NewUpstream(test.url, test.bootstrap) - - if err != nil { - t.Errorf("cannot create a DNS-over-TLS upstream") - } - - testUpstreamIsAlive(t, u) - } -} - -func TestDnsUpstream(t *testing.T) { - var tests = []struct { - url string - bootstrap string - }{ - {"8.8.8.8:53", "8.8.8.8:53"}, - {"1.1.1.1", ""}, - {"tcp://1.1.1.1:53", ""}, - {"176.103.130.130:5353", ""}, - } - - for _, test := range tests { - u, err := NewUpstream(test.url, test.bootstrap) - - if err != nil { - t.Errorf("cannot create a DNS upstream") - } - - testUpstream(t, u) - } -} - -func TestHttpsUpstream(t *testing.T) { - var tests = []struct { - url string - bootstrap string - }{ - {"https://cloudflare-dns.com/dns-query", "8.8.8.8:53"}, - {"https://dns.google.com/experimental", "8.8.8.8:53"}, - {"https://doh.cleanbrowsing.org/doh/security-filter/", ""}, - } - - for _, test := range tests { - u, err := NewUpstream(test.url, test.bootstrap) - - if err != nil { - t.Errorf("cannot create a DNS-over-HTTPS upstream") - } - - testUpstream(t, u) - } -} - -func TestDnsOverTlsUpstream(t *testing.T) { - var tests = []struct { - url string - bootstrap string - }{ - {"tls://1.1.1.1", ""}, - {"tls://9.9.9.9:853", ""}, - {"tls://security-filter-dns.cleanbrowsing.org", "8.8.8.8:53"}, - {"tls://adult-filter-dns.cleanbrowsing.org:853", "8.8.8.8:53"}, - } - - for _, test := range tests { - u, err := NewUpstream(test.url, test.bootstrap) - - if err != nil { - t.Errorf("cannot create a DNS-over-TLS upstream") - } - - testUpstream(t, u) - } -} - -func testUpstreamIsAlive(t *testing.T, u Upstream) { - alive, err := IsAlive(u) - if !alive || err != nil { - t.Errorf("Upstream is not alive") - } - - u.Close() -} - -func testUpstream(t *testing.T, u Upstream) { - var tests = []struct { - name string - expected net.IP - }{ - {"google-public-dns-a.google.com.", net.IPv4(8, 8, 8, 8)}, - {"google-public-dns-b.google.com.", net.IPv4(8, 8, 4, 4)}, - } - - for _, test := range tests { - req := dns.Msg{} - req.Id = dns.Id() - req.RecursionDesired = true - req.Question = []dns.Question{ - {Name: test.name, Qtype: dns.TypeA, Qclass: dns.ClassINET}, - } - - resp, err := u.Exchange(context.Background(), &req) - - if err != nil { - t.Fatalf("error while making an upstream request: %s", err) - } - - if len(resp.Answer) != 1 { - t.Fatalf("no answer section in the response") - } - if answer, ok := resp.Answer[0].(*dns.A); ok { - if !test.expected.Equal(answer.A) { - t.Errorf("wrong IP in the response: %v", answer.A) - } - } - } - - err := u.Close() - if err != nil { - t.Errorf("Error while closing the upstream: %s", err) - } -} From 2745223dbfa4a0263ae224a76032d56b2dd5aaf6 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 5 Dec 2018 19:26:54 +0300 Subject: [PATCH 42/58] Remove old entries from .gitignore --- .gitignore | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index db6d4a1d..20902175 100644 --- a/.gitignore +++ b/.gitignore @@ -1,15 +1,11 @@ .DS_Store -.vscode -.idea -debug +/.vscode +/.idea /AdGuardHome /AdGuardHome.yaml /data/ /build/ /client/node_modules/ -/coredns -/Corefile -/dnsfilter.txt /querylog.json /querylog.json.1 /scripts/translations/node_modules From bb6c596b22cda0a163320516ef38fb2e49ff3c0b Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 5 Dec 2018 20:13:35 +0300 Subject: [PATCH 43/58] dnsforward -- add a simple test that launches a server and queries well-known value through it --- dnsforward/dnsforward_test.go | 49 +++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 dnsforward/dnsforward_test.go diff --git a/dnsforward/dnsforward_test.go b/dnsforward/dnsforward_test.go new file mode 100644 index 00000000..26dabb4b --- /dev/null +++ b/dnsforward/dnsforward_test.go @@ -0,0 +1,49 @@ +package dnsforward + +import ( + "net" + "testing" + + "github.com/miekg/dns" +) + +func TestServer(t *testing.T) { + s := Server{} + s.UDPListenAddr = &net.UDPAddr{Port: 0} + err := s.Start(nil) + if err != nil { + t.Fatalf("Failed to start server: %s", err) + } + if s.udpListen == nil { + t.Fatal("Started server has nil udpListen") + } + + // server is running, send a message + addr := s.udpListen.LocalAddr() + req := dns.Msg{} + req.Id = dns.Id() + req.RecursionDesired = true + req.Question = []dns.Question{ + {Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + } + + reply, err := dns.Exchange(&req, addr.String()) + if err != nil { + t.Fatalf("Couldn't talk to server %s: %s", addr, err) + } + if len(reply.Answer) != 1 { + t.Fatalf("DNS server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer)) + } + if a, ok := reply.Answer[0].(*dns.A); ok { + if !net.IPv4(8, 8, 8, 8).Equal(a.A) { + t.Fatalf("DNS server %s returned wrong answer instead of 8.8.8.8: %v", addr, a.A) + } + } else { + t.Fatalf("DNS server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0]) + } + + err = s.Stop() + if err != nil { + t.Fatalf("DNS server %s failed to stop: %s", addr, err) + } +} From e31905864b23a25ae1797dbc081f36f66d904ac9 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 5 Dec 2018 20:29:00 +0300 Subject: [PATCH 44/58] Get rid of mentions of CoreDNS in code except for upgrading and in readme. Add config upgrade. --- Makefile | 2 +- README.md | 6 ++--- config.go | 38 ++++++++++++------------------- control.go | 65 +++++++++++++++++++++++++++--------------------------- coredns.go | 6 ++--- upgrade.go | 49 +++++++++++++++++++++++++++++++++++----- 6 files changed, 98 insertions(+), 68 deletions(-) diff --git a/Makefile b/Makefile index 4ff25301..9e9ae505 100644 --- a/Makefile +++ b/Makefile @@ -19,7 +19,7 @@ client/node_modules: client/package.json client/package-lock.json $(STATIC): $(JSFILES) client/node_modules npm --prefix client run build-prod -$(TARGET): $(STATIC) *.go coredns_plugin/*.go dnsfilter/*.go +$(TARGET): $(STATIC) *.go dnsfilter/*.go dnsforward/*.go go get -d . GOOS=$(NATIVE_GOOS) GOARCH=$(NATIVE_GOARCH) GO111MODULE=off go get -v github.com/gobuffalo/packr/... PATH=$(GOPATH)/bin:$(PATH) packr -z diff --git a/README.md b/README.md index 37ee39d3..4096ad93 100644 --- a/README.md +++ b/README.md @@ -90,7 +90,7 @@ Now open the browser and navigate to http://localhost:3000/ to control your AdGu You can run AdGuard Home without superuser privileges, but you need to instruct it to use a different port rather than 53. You can do that by editing `AdGuardHome.yaml` and finding these two lines: ```yaml -coredns: +dns: port: 53 ``` @@ -108,7 +108,7 @@ Settings are stored in [YAML format](https://en.wikipedia.org/wiki/YAML), possib * `bind_port` — Web interface IP port to listen on * `auth_name` — Web interface optional authorization username * `auth_pass` — Web interface optional authorization password - * `coredns` — CoreDNS configuration section + * `dns` — DNS configuration section * `port` — DNS server port to listen on * `filtering_enabled` — Filtering of DNS requests based on filter lists * `safebrowsing_enabled` — Filtering of DNS requests based on safebrowsing @@ -208,6 +208,6 @@ This software wouldn't have been possible without: * And many more node.js packages. * [whotracks.me data](https://github.com/cliqz-oss/whotracks.me) -You might have seen coredns mentioned here before, but we've stopped using it in AdGuardHome. While we still use it on our servers, it seemed like an overkill and impeded with Home features that we wanted to implement. +You might have seen that coredns was mentioned here before — we've stopped using it in AdGuardHome. While we still use it on our servers, it seemed like an overkill and it impeded with Home features that we wanted to implement. For a full list of all node.js packages in use, please take a look at [client/package.json](https://github.com/AdguardTeam/AdGuardHome/blob/master/client/package.json) file. diff --git a/config.go b/config.go index 52efdfd9..3bf6fbe4 100644 --- a/config.go +++ b/config.go @@ -13,9 +13,8 @@ import ( ) const ( - currentSchemaVersion = 1 // used for upgrading from old configs to new config - dataDir = "data" // data storage - filterDir = "filters" // cache location for downloaded filters, it's under DataDir + dataDir = "data" // data storage + filterDir = "filters" // cache location for downloaded filters, it's under DataDir ) // configuration is loaded from YAML @@ -24,14 +23,14 @@ type configuration struct { ourConfigFilename string // Config filename (can be overriden via the command line arguments) ourBinaryDir string // Location of our directory, used to protect against CWD being somewhere else - BindHost string `yaml:"bind_host"` - BindPort int `yaml:"bind_port"` - AuthName string `yaml:"auth_name"` - AuthPass string `yaml:"auth_pass"` - Language string `yaml:"language"` // two-letter ISO 639-1 language code - CoreDNS coreDNSConfig `yaml:"coredns"` - Filters []filter `yaml:"filters"` - UserRules []string `yaml:"user_rules,omitempty"` + BindHost string `yaml:"bind_host"` + BindPort int `yaml:"bind_port"` + AuthName string `yaml:"auth_name"` + AuthPass string `yaml:"auth_pass"` + Language string `yaml:"language"` // two-letter ISO 639-1 language code + DNS dnsConfig `yaml:"dns"` + Filters []filter `yaml:"filters"` + UserRules []string `yaml:"user_rules,omitempty"` sync.RWMutex `yaml:"-"` @@ -39,16 +38,11 @@ type configuration struct { } // field ordering is important -- yaml fields will mirror ordering from here -type coreDNSConfig struct { - binaryFile string - coreFile string - Port int `yaml:"port"` +type dnsConfig struct { + Port int `yaml:"port"` dnsforward.FilteringConfig `yaml:",inline"` - Pprof string `yaml:"-"` - Cache string `yaml:"-"` - Prometheus string `yaml:"-"` BootstrapDNS string `yaml:"bootstrap_dns"` UpstreamDNS []string `yaml:"upstream_dns"` } @@ -60,10 +54,8 @@ var config = configuration{ ourConfigFilename: "AdGuardHome.yaml", BindPort: 3000, BindHost: "127.0.0.1", - CoreDNS: coreDNSConfig{ - Port: 53, - binaryFile: "coredns", // only filename, no path - coreFile: "Corefile", // only filename, no path + DNS: dnsConfig{ + Port: 53, FilteringConfig: dnsforward.FilteringConfig{ ProtectionEnabled: true, // whether or not use any of dnsfilter features FilteringEnabled: true, // whether or not use filter lists @@ -74,8 +66,6 @@ var config = configuration{ }, BootstrapDNS: "8.8.8.8:53", UpstreamDNS: defaultDNS, - Cache: "cache", - Prometheus: "prometheus :9153", }, Filters: []filter{ {Filter: dnsfilter.Filter{ID: 1}, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard Simplified Domain Names filter"}, diff --git a/control.go b/control.go index 0d57816d..49869ddc 100644 --- a/control.go +++ b/control.go @@ -32,9 +32,9 @@ var client = &http.Client{ } // ------------------- -// coredns run control +// dns run control // ------------------- -func writeAllConfigsAndReloadCoreDNS() error { +func writeAllConfigsAndReloadDNS() error { err := writeAllConfigs() if err != nil { log.Printf("Couldn't write all configs: %s", err) @@ -45,7 +45,7 @@ func writeAllConfigsAndReloadCoreDNS() error { } func httpUpdateConfigReloadDNSReturnOK(w http.ResponseWriter, r *http.Request) { - err := writeAllConfigsAndReloadCoreDNS() + err := writeAllConfigsAndReloadDNS() if err != nil { errortext := fmt.Sprintf("Couldn't write config file: %s", err) log.Println(errortext) @@ -67,12 +67,12 @@ func returnOK(w http.ResponseWriter, r *http.Request) { func handleStatus(w http.ResponseWriter, r *http.Request) { data := map[string]interface{}{ "dns_address": config.BindHost, - "dns_port": config.CoreDNS.Port, - "protection_enabled": config.CoreDNS.ProtectionEnabled, - "querylog_enabled": config.CoreDNS.QueryLogEnabled, + "dns_port": config.DNS.Port, + "protection_enabled": config.DNS.ProtectionEnabled, + "querylog_enabled": config.DNS.QueryLogEnabled, "running": isRunning(), - "bootstrap_dns": config.CoreDNS.BootstrapDNS, - "upstream_dns": config.CoreDNS.UpstreamDNS, + "bootstrap_dns": config.DNS.BootstrapDNS, + "upstream_dns": config.DNS.UpstreamDNS, "version": VersionString, "language": config.Language, } @@ -95,12 +95,12 @@ func handleStatus(w http.ResponseWriter, r *http.Request) { } func handleProtectionEnable(w http.ResponseWriter, r *http.Request) { - config.CoreDNS.ProtectionEnabled = true + config.DNS.ProtectionEnabled = true httpUpdateConfigReloadDNSReturnOK(w, r) } func handleProtectionDisable(w http.ResponseWriter, r *http.Request) { - config.CoreDNS.ProtectionEnabled = false + config.DNS.ProtectionEnabled = false httpUpdateConfigReloadDNSReturnOK(w, r) } @@ -108,12 +108,12 @@ func handleProtectionDisable(w http.ResponseWriter, r *http.Request) { // stats // ----- func handleQueryLogEnable(w http.ResponseWriter, r *http.Request) { - config.CoreDNS.QueryLogEnabled = true + config.DNS.QueryLogEnabled = true httpUpdateConfigReloadDNSReturnOK(w, r) } func handleQueryLogDisable(w http.ResponseWriter, r *http.Request) { - config.CoreDNS.QueryLogEnabled = false + config.DNS.QueryLogEnabled = false httpUpdateConfigReloadDNSReturnOK(w, r) } @@ -135,9 +135,9 @@ func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) { hosts := strings.Fields(string(body)) if len(hosts) == 0 { - config.CoreDNS.UpstreamDNS = defaultDNS + config.DNS.UpstreamDNS = defaultDNS } else { - config.CoreDNS.UpstreamDNS = hosts + config.DNS.UpstreamDNS = hosts } err = writeAllConfigs() @@ -243,7 +243,7 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { resp, err := client.Get(versionCheckURL) if err != nil { - errortext := fmt.Sprintf("Couldn't get querylog from coredns: %T %s\n", err, err) + errortext := fmt.Sprintf("Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err) log.Println(errortext) http.Error(w, errortext, http.StatusBadGateway) return @@ -255,7 +255,7 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { // read the body entirely body, err := ioutil.ReadAll(resp.Body) if err != nil { - errortext := fmt.Sprintf("Couldn't read response body: %s", err) + errortext := fmt.Sprintf("Couldn't read response body from %s: %s", versionCheckURL, err) log.Println(errortext) http.Error(w, errortext, http.StatusBadGateway) return @@ -278,18 +278,18 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { // --------- func handleFilteringEnable(w http.ResponseWriter, r *http.Request) { - config.CoreDNS.FilteringEnabled = true + config.DNS.FilteringEnabled = true httpUpdateConfigReloadDNSReturnOK(w, r) } func handleFilteringDisable(w http.ResponseWriter, r *http.Request) { - config.CoreDNS.FilteringEnabled = false + config.DNS.FilteringEnabled = false httpUpdateConfigReloadDNSReturnOK(w, r) } func handleFilteringStatus(w http.ResponseWriter, r *http.Request) { data := map[string]interface{}{ - "enabled": config.CoreDNS.FilteringEnabled, + "enabled": config.DNS.FilteringEnabled, } config.RLock() @@ -377,7 +377,8 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) { return } - // URL is deemed valid, append it to filters, update config, write new filter file and tell coredns to reload it + // URL is deemed valid, append it to filters, update config, write new filter file and tell dns to reload it + // TODO: since we directly feed filters in-memory, revisit if writing configs is always neccessary config.Filters = append(config.Filters, filter) err = writeAllConfigs() if err != nil { @@ -537,18 +538,18 @@ func handleFilteringRefresh(w http.ResponseWriter, r *http.Request) { // ------------ func handleSafeBrowsingEnable(w http.ResponseWriter, r *http.Request) { - config.CoreDNS.SafeBrowsingEnabled = true + config.DNS.SafeBrowsingEnabled = true httpUpdateConfigReloadDNSReturnOK(w, r) } func handleSafeBrowsingDisable(w http.ResponseWriter, r *http.Request) { - config.CoreDNS.SafeBrowsingEnabled = false + config.DNS.SafeBrowsingEnabled = false httpUpdateConfigReloadDNSReturnOK(w, r) } func handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Request) { data := map[string]interface{}{ - "enabled": config.CoreDNS.SafeBrowsingEnabled, + "enabled": config.DNS.SafeBrowsingEnabled, } jsonVal, err := json.Marshal(data) if err != nil { @@ -611,22 +612,22 @@ func handleParentalEnable(w http.ResponseWriter, r *http.Request) { http.Error(w, "Sensitivity must be set to valid value", 400) return } - config.CoreDNS.ParentalSensitivity = i - config.CoreDNS.ParentalEnabled = true + config.DNS.ParentalSensitivity = i + config.DNS.ParentalEnabled = true httpUpdateConfigReloadDNSReturnOK(w, r) } func handleParentalDisable(w http.ResponseWriter, r *http.Request) { - config.CoreDNS.ParentalEnabled = false + config.DNS.ParentalEnabled = false httpUpdateConfigReloadDNSReturnOK(w, r) } func handleParentalStatus(w http.ResponseWriter, r *http.Request) { data := map[string]interface{}{ - "enabled": config.CoreDNS.ParentalEnabled, + "enabled": config.DNS.ParentalEnabled, } - if config.CoreDNS.ParentalEnabled { - data["sensitivity"] = config.CoreDNS.ParentalSensitivity + if config.DNS.ParentalEnabled { + data["sensitivity"] = config.DNS.ParentalSensitivity } jsonVal, err := json.Marshal(data) if err != nil { @@ -651,18 +652,18 @@ func handleParentalStatus(w http.ResponseWriter, r *http.Request) { // ------------ func handleSafeSearchEnable(w http.ResponseWriter, r *http.Request) { - config.CoreDNS.SafeSearchEnabled = true + config.DNS.SafeSearchEnabled = true httpUpdateConfigReloadDNSReturnOK(w, r) } func handleSafeSearchDisable(w http.ResponseWriter, r *http.Request) { - config.CoreDNS.SafeSearchEnabled = false + config.DNS.SafeSearchEnabled = false httpUpdateConfigReloadDNSReturnOK(w, r) } func handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) { data := map[string]interface{}{ - "enabled": config.CoreDNS.SafeSearchEnabled, + "enabled": config.DNS.SafeSearchEnabled, } jsonVal, err := json.Marshal(data) if err != nil { diff --git a/coredns.go b/coredns.go index 119b2371..250c9e37 100644 --- a/coredns.go +++ b/coredns.go @@ -31,12 +31,12 @@ func generateServerConfig() dnsforward.ServerConfig { } newconfig := dnsforward.ServerConfig{ - UDPListenAddr: &net.UDPAddr{Port: config.CoreDNS.Port}, - FilteringConfig: config.CoreDNS.FilteringConfig, + UDPListenAddr: &net.UDPAddr{Port: config.DNS.Port}, + FilteringConfig: config.DNS.FilteringConfig, Filters: filters, } - for _, u := range config.CoreDNS.UpstreamDNS { + for _, u := range config.DNS.UpstreamDNS { upstream, err := dnsforward.GetUpstream(u) if err != nil { log.Printf("Couldn't get upstream: %s", err) diff --git a/upgrade.go b/upgrade.go index 4154ee03..1b8c6e34 100644 --- a/upgrade.go +++ b/upgrade.go @@ -10,6 +10,8 @@ import ( "gopkg.in/yaml.v2" ) +const currentSchemaVersion = 2 // used for upgrading from old configs to new config + // Performs necessary upgrade operations if needed func upgradeConfig() error { // read a config file into an interface map, so we can manipulate values without losing any @@ -57,7 +59,12 @@ func upgradeConfig() error { func upgradeConfigSchema(oldVersion int, diskConfig *map[string]interface{}) error { switch oldVersion { case 0: - err := upgradeSchema0to1(diskConfig) + err := upgradeSchema0to2(diskConfig) + if err != nil { + return err + } + case 1: + err := upgradeSchema1to2(diskConfig) if err != nil { return err } @@ -83,14 +90,13 @@ func upgradeConfigSchema(oldVersion int, diskConfig *map[string]interface{}) err return nil } +// The first schema upgrade: +// No more "dnsfilter.txt", filters are now kept in data/filters/ func upgradeSchema0to1(diskConfig *map[string]interface{}) error { log.Printf("%s(): called", _Func()) - // The first schema upgrade: - // No more "dnsfilter.txt", filters are now kept in data/filters/ dnsFilterPath := filepath.Join(config.ourBinaryDir, "dnsfilter.txt") - _, err := os.Stat(dnsFilterPath) - if !os.IsNotExist(err) { + if _, err := os.Stat(dnsFilterPath); !os.IsNotExist(err) { log.Printf("Deleting %s as we don't need it anymore", dnsFilterPath) err = os.Remove(dnsFilterPath) if err != nil { @@ -103,3 +109,36 @@ func upgradeSchema0to1(diskConfig *map[string]interface{}) error { return nil } + +// Second schema upgrade: +// coredns is now dns in config +// delete 'Corefile', since we don't use that anymore +func upgradeSchema1to2(diskConfig *map[string]interface{}) error { + log.Printf("%s(): called", _Func()) + + coreFilePath := filepath.Join(config.ourBinaryDir, "Corefile") + if _, err := os.Stat(coreFilePath); !os.IsNotExist(err) { + log.Printf("Deleting %s as we don't need it anymore", coreFilePath) + err = os.Remove(coreFilePath) + if err != nil { + log.Printf("Cannot remove %s due to %s", coreFilePath, err) + // not fatal, move on + } + } + + (*diskConfig)["dns"] = (*diskConfig)["coredns"] + delete((*diskConfig), "coredns") + (*diskConfig)["schema_version"] = 2 + + return nil +} + +// jump two schemas at once -- this time we just do it sequentially +func upgradeSchema0to2(diskConfig *map[string]interface{}) error { + err := upgradeSchema0to1(diskConfig) + if err != nil { + return err + } + + return upgradeSchema1to2(diskConfig) +} From 609523a59c5ca635665e7b4445acbcdfbfd8cf08 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 5 Dec 2018 21:09:37 +0300 Subject: [PATCH 45/58] Don't omit empty user rules in configfile -- otherwise users might not be able to find that it's customizable in configfile. --- config.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config.go b/config.go index 3bf6fbe4..89822d3f 100644 --- a/config.go +++ b/config.go @@ -30,7 +30,7 @@ type configuration struct { Language string `yaml:"language"` // two-letter ISO 639-1 language code DNS dnsConfig `yaml:"dns"` Filters []filter `yaml:"filters"` - UserRules []string `yaml:"user_rules,omitempty"` + UserRules []string `yaml:"user_rules"` sync.RWMutex `yaml:"-"` From 2a1059107a0a9eb7f9ffcaec3b96a60b98bc1591 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 5 Dec 2018 21:33:07 +0300 Subject: [PATCH 46/58] dnsforward -- add upstream tests. --- dnsforward/upstream_test.go | 84 +++++++++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 dnsforward/upstream_test.go diff --git a/dnsforward/upstream_test.go b/dnsforward/upstream_test.go new file mode 100644 index 00000000..975c5035 --- /dev/null +++ b/dnsforward/upstream_test.go @@ -0,0 +1,84 @@ +package dnsforward + +import ( + "net" + "testing" + + "github.com/miekg/dns" +) + +func TestUpstreamDNS(t *testing.T) { + upstreams := []string{ + "8.8.8.8:53", + "1.1.1.1", + "tcp://1.1.1.1:53", + "176.103.130.130:5353", + } + for _, input := range upstreams { + u, err := GetUpstream(input) + if err != nil { + t.Fatalf("Failed to choose upstream for %s: %s", input, err) + } + + checkUpstream(t, u, input) + } +} + +func TestUpstreamTLS(t *testing.T) { + upstreams := []string{ + "tls://1.1.1.1", + "tls://9.9.9.9:853", + "tls://security-filter-dns.cleanbrowsing.org", + "tls://adult-filter-dns.cleanbrowsing.org:853", + } + for _, input := range upstreams { + u, err := GetUpstream(input) + if err != nil { + t.Fatalf("Failed to choose upstream for %s: %s", input, err) + } + + checkUpstream(t, u, input) + } +} + +func TestUpstreamHTTPS(t *testing.T) { + upstreams := []string{ + "https://cloudflare-dns.com/dns-query", + "https://dns.google.com/experimental", + "https://doh.cleanbrowsing.org/doh/security-filter/", + } + for _, input := range upstreams { + u, err := GetUpstream(input) + if err != nil { + t.Fatalf("Failed to choose upstream for %s: %s", input, err) + } + + checkUpstream(t, u, input) + } +} + +func checkUpstream(t *testing.T, u Upstream, addr string) { + t.Helper() + + req := dns.Msg{} + req.Id = dns.Id() + req.RecursionDesired = true + req.Question = []dns.Question{ + {Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + } + + reply, err := u.Exchange(&req) + if err != nil { + t.Fatalf("Couldn't talk to upstream %s: %s", addr, err) + } + if len(reply.Answer) != 1 { + t.Fatalf("DNS upstream %s returned reply with wrong number of answers - %d", addr, len(reply.Answer)) + } + if a, ok := reply.Answer[0].(*dns.A); ok { + if !net.IPv4(8, 8, 8, 8).Equal(a.A) { + t.Fatalf("DNS upstream %s returned wrong answer instead of 8.8.8.8: %v", addr, a.A) + } + } else { + t.Fatalf("DNS upstream %s returned wrong answer type instead of A: %v", addr, reply.Answer[0]) + } +} From ff1c19cac5bdc15d1d773a50ad86aa2d3ab7bca2 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 5 Dec 2018 21:33:32 +0300 Subject: [PATCH 47/58] dnsforward -- support tcp:// schema --- dnsforward/upstream.go | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/dnsforward/upstream.go b/dnsforward/upstream.go index 1ca27242..5d672f5a 100644 --- a/dnsforward/upstream.go +++ b/dnsforward/upstream.go @@ -29,7 +29,8 @@ type Upstream interface { // plain DNS // type plainDNS struct { - address string + address string + preferTCP bool } var defaultUDPClient = dns.Client{ @@ -46,6 +47,11 @@ var defaultTCPClient = dns.Client{ func (p *plainDNS) Address() string { return p.address } func (p *plainDNS) Exchange(m *dns.Msg) (*dns.Msg, error) { + if p.preferTCP { + reply, _, err := defaultTCPClient.Exchange(m, p.address) + return reply, err + } + reply, _, err := defaultUDPClient.Exchange(m, p.address) if err != nil && reply != nil && reply.Truncated { log.Printf("Truncated message was received, retrying over TCP, question: %s", m.Question[0].String()) @@ -183,7 +189,12 @@ func GetUpstream(address string) (Upstream, error) { if url.Port() == "" { url.Host += ":53" } - return &plainDNS{address: url.String()}, nil + return &plainDNS{address: url.Host}, nil + case "tcp": + if url.Port() == "" { + url.Host += ":53" + } + return &plainDNS{address: url.Host, preferTCP: true}, nil case "tls": if url.Port() == "" { url.Host += ":853" From 0f5dd661f5a3dd0da4b60a0916a81466c4816d99 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Thu, 6 Dec 2018 00:22:20 +0300 Subject: [PATCH 48/58] Add support for bootstrapping upstream DNS servers by hostname. --- README.md | 2 +- config.go | 7 +-- control.go | 2 +- coredns.go | 2 +- dnsforward/bootstrap.go | 107 ++++++++++++++++++++++++++++++++++++ dnsforward/dnsforward.go | 29 +++++----- dnsforward/upstream.go | 87 ++++++++++++++++++++--------- dnsforward/upstream_pool.go | 46 ++++------------ dnsforward/upstream_test.go | 102 +++++++++++++++++++--------------- 9 files changed, 256 insertions(+), 128 deletions(-) create mode 100644 dnsforward/bootstrap.go diff --git a/README.md b/README.md index 4096ad93..b9916574 100644 --- a/README.md +++ b/README.md @@ -116,7 +116,7 @@ Settings are stored in [YAML format](https://en.wikipedia.org/wiki/YAML), possib * `parental_enabled` — Parental control-based DNS requests filtering * `parental_sensitivity` — Age group for parental control-based filtering, must be either 3, 10, 13 or 17 * `querylog_enabled` — Query logging (also used to calculate top 50 clients, blocked domains and requested domains for statistic purposes) - * `bootstrap_dns` — DNS server used for initial hostnames resolution in case if upstream is DoH or DoT with a hostname + * `bootstrap_dns` — DNS server used for initial hostname resolution in case if upstream server name is a hostname * `upstream_dns` — List of upstream DNS servers * `filters` — List of filters, each filter has the following values: * `ID` - filter ID (must be unique) diff --git a/config.go b/config.go index 89822d3f..8facc405 100644 --- a/config.go +++ b/config.go @@ -43,8 +43,7 @@ type dnsConfig struct { dnsforward.FilteringConfig `yaml:",inline"` - BootstrapDNS string `yaml:"bootstrap_dns"` - UpstreamDNS []string `yaml:"upstream_dns"` + UpstreamDNS []string `yaml:"upstream_dns"` } var defaultDNS = []string{"tls://1.1.1.1", "tls://1.0.0.1"} @@ -63,9 +62,9 @@ var config = configuration{ QueryLogEnabled: true, Ratelimit: 20, RefuseAny: true, + BootstrapDNS: "8.8.8.8:53", }, - BootstrapDNS: "8.8.8.8:53", - UpstreamDNS: defaultDNS, + UpstreamDNS: defaultDNS, }, Filters: []filter{ {Filter: dnsfilter.Filter{ID: 1}, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard Simplified Domain Names filter"}, diff --git a/control.go b/control.go index 49869ddc..2674585c 100644 --- a/control.go +++ b/control.go @@ -204,7 +204,7 @@ func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { func checkDNS(input string) error { log.Printf("Checking if DNS %s works...", input) - u, err := dnsforward.GetUpstream(input) + u, err := dnsforward.AddressToUpstream(input, "") if err != nil { return fmt.Errorf("Failed to choose upstream for %s: %s", input, err) } diff --git a/coredns.go b/coredns.go index 250c9e37..42894336 100644 --- a/coredns.go +++ b/coredns.go @@ -37,7 +37,7 @@ func generateServerConfig() dnsforward.ServerConfig { } for _, u := range config.DNS.UpstreamDNS { - upstream, err := dnsforward.GetUpstream(u) + upstream, err := dnsforward.AddressToUpstream(u, config.DNS.BootstrapDNS) if err != nil { log.Printf("Couldn't get upstream: %s", err) // continue, just ignore the upstream diff --git a/dnsforward/bootstrap.go b/dnsforward/bootstrap.go new file mode 100644 index 00000000..2d263871 --- /dev/null +++ b/dnsforward/bootstrap.go @@ -0,0 +1,107 @@ +package dnsforward + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/url" + "strings" + "sync" + + "github.com/joomcode/errorx" +) + +type bootstrapper struct { + address string // in form of "tls://one.one.one.one:853" + resolver *net.Resolver // resolver to use to resolve hostname, if neccessary + resolved string // in form "IP:port" + resolvedConfig *tls.Config + sync.Mutex +} + +func toBoot(address, bootstrapAddr string) bootstrapper { + var resolver *net.Resolver + if bootstrapAddr != "" { + resolver = &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + d := net.Dialer{} + return d.DialContext(ctx, network, bootstrapAddr) + }, + } + } + return bootstrapper{ + address: address, + resolver: resolver, + } +} + +// will get usable IP address from Address field, and caches the result +func (n *bootstrapper) get() (string, *tls.Config, error) { + // TODO: RLock() here but atomically upgrade to Lock() if fast path doesn't work + n.Lock() + if n.resolved != "" { // fast path + retval, tlsconfig := n.resolved, n.resolvedConfig + n.Unlock() + return retval, tlsconfig, nil + } + + // + // slow path + // + + defer n.Unlock() + + justHostPort := n.address + if strings.Contains(n.address, "://") { + url, err := url.Parse(n.address) + if err != nil { + return "", nil, errorx.Decorate(err, "Failed to parse %s", n.address) + } + + justHostPort = url.Host + } + + // convert host to IP if neccessary, we know that it's scheme://hostname:port/ + + // get a host without port + host, port, err := net.SplitHostPort(justHostPort) + if err != nil { + return "", nil, fmt.Errorf("bootstrapper requires port in address %s", n.address) + } + + // if it's an IP + ip := net.ParseIP(host) + if ip != nil { + n.resolved = justHostPort + return n.resolved, nil, nil + } + + // + // if it's a hostname + // + + resolver := n.resolver // no need to check for nil resolver -- documented that nil is default resolver + addrs, err := resolver.LookupIPAddr(context.TODO(), host) + if err != nil { + return "", nil, errorx.Decorate(err, "Failed to lookup %s", host) + } + for _, addr := range addrs { + // TODO: support ipv6, support multiple ipv4 + if addr.IP.To4() == nil { + continue + } + ip = addr.IP + break + } + + if ip == nil { + // couldn't find any suitable IP address + return "", nil, fmt.Errorf("Couldn't find any suitable IP address for host %s", host) + } + + n.resolved = net.JoinHostPort(ip.String(), port) + n.resolvedConfig = &tls.Config{ServerName: host} + return n.resolved, n.resolvedConfig, nil +} diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index bee85d3a..404bbfb3 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -86,6 +86,7 @@ type FilteringConfig struct { Ratelimit int `yaml:"ratelimit"` RatelimitWhitelist []string `yaml:"ratelimit_whitelist"` RefuseAny bool `yaml:"refuse_any"` + BootstrapDNS string `yaml:"bootstrap_dns"` dnsfilter.Config `yaml:",inline"` } @@ -105,24 +106,24 @@ var defaultValues = ServerConfig{ FilteringConfig: FilteringConfig{BlockedResponseTTL: 3600}, Upstreams: []Upstream{ //// dns over HTTPS - // &dnsOverHTTPS{address: "https://1.1.1.1/dns-query"}, - // &dnsOverHTTPS{address: "https://dns.google.com/experimental"}, - // &dnsOverHTTPS{address: "https://doh.cleanbrowsing.org/doh/security-filter/"}, - // &dnsOverHTTPS{address: "https://dns10.quad9.net/dns-query"}, - // &dnsOverHTTPS{address: "https://doh.powerdns.org"}, - // &dnsOverHTTPS{address: "https://doh.securedns.eu/dns-query"}, + // &dnsOverHTTPS{boot: toBoot("https://1.1.1.1/dns-query", "")}, + // &dnsOverHTTPS{boot: toBoot("https://dns.google.com/experimental", "")}, + // &dnsOverHTTPS{boot: toBoot("https://doh.cleanbrowsing.org/doh/security-filter/", "")}, + // &dnsOverHTTPS{boot: toBoot("https://dns10.quad9.net/dns-query", "")}, + // &dnsOverHTTPS{boot: toBoot("https://doh.powerdns.org", "")}, + // &dnsOverHTTPS{boot: toBoot("https://doh.securedns.eu/dns-query", "")}, //// dns over TLS - // &dnsOverTLS{address: "tls://8.8.8.8:853"}, - // &dnsOverTLS{address: "tls://8.8.4.4:853"}, - // &dnsOverTLS{address: "tls://1.1.1.1:853"}, - // &dnsOverTLS{address: "tls://1.0.0.1:853"}, + // &dnsOverTLS{boot: toBoot("tls://8.8.8.8:853", "")}, + // &dnsOverTLS{boot: toBoot("tls://8.8.4.4:853", "")}, + // &dnsOverTLS{boot: toBoot("tls://1.1.1.1:853", "")}, + // &dnsOverTLS{boot: toBoot("tls://1.0.0.1:853", "")}, //// plainDNS - &plainDNS{address: "8.8.8.8:53"}, - &plainDNS{address: "8.8.4.4:53"}, - &plainDNS{address: "1.1.1.1:53"}, - &plainDNS{address: "1.0.0.1:53"}, + &plainDNS{boot: toBoot("8.8.8.8:53", "")}, + &plainDNS{boot: toBoot("8.8.4.4:53", "")}, + &plainDNS{boot: toBoot("1.1.1.1:53", "")}, + &plainDNS{boot: toBoot("1.0.0.1:53", "")}, }, } diff --git a/dnsforward/upstream.go b/dnsforward/upstream.go index 5d672f5a..99142929 100644 --- a/dnsforward/upstream.go +++ b/dnsforward/upstream.go @@ -29,7 +29,7 @@ type Upstream interface { // plain DNS // type plainDNS struct { - address string + boot bootstrapper preferTCP bool } @@ -44,19 +44,25 @@ var defaultTCPClient = dns.Client{ Timeout: defaultTimeout, } -func (p *plainDNS) Address() string { return p.address } +// Address returns the original address that we've put in initially, not resolved one +func (p *plainDNS) Address() string { return p.boot.address } func (p *plainDNS) Exchange(m *dns.Msg) (*dns.Msg, error) { + addr, _, err := p.boot.get() + if err != nil { + return nil, err + } if p.preferTCP { - reply, _, err := defaultTCPClient.Exchange(m, p.address) + reply, _, err := defaultTCPClient.Exchange(m, addr) return reply, err } - reply, _, err := defaultUDPClient.Exchange(m, p.address) + reply, _, err := defaultUDPClient.Exchange(m, addr) if err != nil && reply != nil && reply.Truncated { log.Printf("Truncated message was received, retrying over TCP, question: %s", m.Question[0].String()) - reply, _, err = defaultTCPClient.Exchange(m, p.address) + reply, _, err = defaultTCPClient.Exchange(m, addr) } + return reply, err } @@ -64,8 +70,8 @@ func (p *plainDNS) Exchange(m *dns.Msg) (*dns.Msg, error) { // DNS-over-TLS // type dnsOverTLS struct { - address string - pool *TLSPool + boot bootstrapper + pool *TLSPool sync.RWMutex // protects pool } @@ -77,7 +83,7 @@ var defaultTLSClient = dns.Client{ TLSConfig: &tls.Config{}, } -func (p *dnsOverTLS) Address() string { return p.address } +func (p *dnsOverTLS) Address() string { return p.boot.address } func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) { var pool *TLSPool @@ -87,7 +93,7 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) { if pool == nil { p.Lock() // lazy initialize it - p.pool = &TLSPool{Address: p.address} + p.pool = &TLSPool{boot: &p.boot} p.Unlock() } @@ -95,19 +101,19 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) { poolConn, err := p.pool.Get() p.RUnlock() if err != nil { - return nil, errorx.Decorate(err, "Failed to get a connection from TLSPool to %s", p.address) + return nil, errorx.Decorate(err, "Failed to get a connection from TLSPool to %s", p.Address()) } c := dns.Conn{Conn: poolConn} err = c.WriteMsg(m) if err != nil { poolConn.Close() - return nil, errorx.Decorate(err, "Failed to send a request to %s", p.address) + return nil, errorx.Decorate(err, "Failed to send a request to %s", p.Address()) } reply, err := c.ReadMsg() if err != nil { poolConn.Close() - return nil, errorx.Decorate(err, "Failed to read a request from %s", p.address) + return nil, errorx.Decorate(err, "Failed to read a request from %s", p.Address()) } p.RLock() p.pool.Put(poolConn) @@ -119,7 +125,7 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) { // DNS-over-https // type dnsOverHTTPS struct { - address string + boot bootstrapper } var defaultHTTPSTransport = http.Transport{} @@ -129,35 +135,59 @@ var defaultHTTPSClient = http.Client{ Timeout: defaultTimeout, } -func (p *dnsOverHTTPS) Address() string { return p.address } +func (p *dnsOverHTTPS) Address() string { return p.boot.address } func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (*dns.Msg, error) { + addr, tlsConfig, err := p.boot.get() + if err != nil { + return nil, errorx.Decorate(err, "Couldn't bootstrap %s", p.boot.address) + } + buf, err := m.Pack() if err != nil { return nil, errorx.Decorate(err, "Couldn't pack request msg") } bb := bytes.NewBuffer(buf) - resp, err := http.Post(p.address, "application/dns-message", bb) + + // set up a custom request with custom URL + url, err := url.Parse(p.boot.address) + if err != nil { + return nil, errorx.Decorate(err, "Couldn't parse URL %s", p.boot.address) + } + req := http.Request{ + Method: "POST", + URL: url, + Body: ioutil.NopCloser(bb), + Header: make(http.Header), + Host: url.Host, + } + url.Host = addr + req.Header.Set("Content-Type", "application/dns-message") + client := http.Client{ + Transport: &http.Transport{TLSClientConfig: tlsConfig}, + } + resp, err := client.Do(&req) if resp != nil && resp.Body != nil { defer resp.Body.Close() } if err != nil { - return nil, errorx.Decorate(err, "Couldn't do a POST request to '%s'", p.address) + return nil, errorx.Decorate(err, "Couldn't do a POST request to '%s'", addr) } + body, err := ioutil.ReadAll(resp.Body) if err != nil { - return nil, errorx.Decorate(err, "Couldn't read body contents for '%s'", p.address) + return nil, errorx.Decorate(err, "Couldn't read body contents for '%s'", addr) } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("Got an unexpected HTTP status code %d from '%s'", resp.StatusCode, p.address) + return nil, fmt.Errorf("Got an unexpected HTTP status code %d from '%s'", resp.StatusCode, addr) } if len(body) == 0 { - return nil, fmt.Errorf("Got an unexpected empty body from '%s'", p.address) + return nil, fmt.Errorf("Got an unexpected empty body from '%s'", addr) } response := dns.Msg{} err = response.Unpack(body) if err != nil { - return nil, errorx.Decorate(err, "Couldn't unpack DNS response from '%s': body is %s", p.address, string(body)) + return nil, errorx.Decorate(err, "Couldn't unpack DNS response from '%s': body is %s", addr, string(body)) } return &response, nil } @@ -178,7 +208,7 @@ func (s *Server) chooseUpstream() Upstream { return upstream } -func GetUpstream(address string) (Upstream, error) { +func AddressToUpstream(address string, bootstrap string) (Upstream, error) { if strings.Contains(address, "://") { url, err := url.Parse(address) if err != nil { @@ -189,25 +219,28 @@ func GetUpstream(address string) (Upstream, error) { if url.Port() == "" { url.Host += ":53" } - return &plainDNS{address: url.Host}, nil + return &plainDNS{boot: toBoot(url.Host, bootstrap)}, nil case "tcp": if url.Port() == "" { url.Host += ":53" } - return &plainDNS{address: url.Host, preferTCP: true}, nil + return &plainDNS{boot: toBoot(url.Host, bootstrap), preferTCP: true}, nil case "tls": if url.Port() == "" { url.Host += ":853" } - return &dnsOverTLS{address: url.String()}, nil + return &dnsOverTLS{boot: toBoot(url.String(), bootstrap)}, nil case "https": - return &dnsOverHTTPS{address: url.String()}, nil + if url.Port() == "" { + url.Host += ":443" + } + return &dnsOverHTTPS{boot: toBoot(url.String(), bootstrap)}, nil default: // assume it's plain DNS if url.Port() == "" { url.Host += ":53" } - return &plainDNS{address: url.String()}, nil + return &plainDNS{boot: toBoot(url.String(), bootstrap)}, nil } } @@ -217,5 +250,5 @@ func GetUpstream(address string) (Upstream, error) { // doesn't have port, default to 53 address = net.JoinHostPort(address, "53") } - return &plainDNS{address: address}, nil + return &plainDNS{boot: toBoot(address, bootstrap)}, nil } diff --git a/dnsforward/upstream_pool.go b/dnsforward/upstream_pool.go index f944e695..ca597808 100644 --- a/dnsforward/upstream_pool.go +++ b/dnsforward/upstream_pool.go @@ -2,9 +2,7 @@ package dnsforward import ( "crypto/tls" - "fmt" "net" - "net/url" "sync" "github.com/joomcode/errorx" @@ -27,51 +25,29 @@ import ( // log.Println(r) // pool.Put(c.Conn) type TLSPool struct { - Address string - parsedAddress *url.URL - parsedAddressMutex sync.RWMutex + boot *bootstrapper + // connections conns []net.Conn - sync.Mutex // protects conns -} - -func (n *TLSPool) getHost() (string, error) { - n.parsedAddressMutex.RLock() - if n.parsedAddress != nil { - n.parsedAddressMutex.RUnlock() - return n.parsedAddress.Host, nil - } - n.parsedAddressMutex.RUnlock() - - n.parsedAddressMutex.Lock() - defer n.parsedAddressMutex.Unlock() - url, err := url.Parse(n.Address) - if err != nil { - return "", errorx.Decorate(err, "Failed to parse %s", n.Address) - } - if url.Scheme != "tls" { - return "", fmt.Errorf("TLSPool only supports TLS") - } - n.parsedAddress = url - return n.parsedAddress.Host, nil + connsMutex sync.Mutex // protects conns } func (n *TLSPool) Get() (net.Conn, error) { - host, err := n.getHost() + address, tlsConfig, err := n.boot.get() if err != nil { return nil, err } // get the connection from the slice inside the lock var c net.Conn - n.Lock() + n.connsMutex.Lock() num := len(n.conns) if num > 0 { last := num - 1 c = n.conns[last] n.conns = n.conns[:last] } - n.Unlock() + n.connsMutex.Unlock() // if we got connection from the slice, return it if c != nil { @@ -80,10 +56,10 @@ func (n *TLSPool) Get() (net.Conn, error) { } // we'll need a new connection, dial now - // log.Printf("Dialing to %s", host) - conn, err := tls.Dial("tcp", host, nil) + // log.Printf("Dialing to %s", address) + conn, err := tls.Dial("tcp", address, tlsConfig) if err != nil { - return nil, errorx.Decorate(err, "Failed to connect to %s", host) + return nil, errorx.Decorate(err, "Failed to connect to %s", address) } return conn, nil } @@ -92,7 +68,7 @@ func (n *TLSPool) Put(c net.Conn) { if c == nil { return } - n.Lock() + n.connsMutex.Lock() n.conns = append(n.conns, c) - n.Unlock() + n.connsMutex.Unlock() } diff --git a/dnsforward/upstream_test.go b/dnsforward/upstream_test.go index 975c5035..0b83670f 100644 --- a/dnsforward/upstream_test.go +++ b/dnsforward/upstream_test.go @@ -7,53 +7,65 @@ import ( "github.com/miekg/dns" ) -func TestUpstreamDNS(t *testing.T) { - upstreams := []string{ - "8.8.8.8:53", - "1.1.1.1", - "tcp://1.1.1.1:53", - "176.103.130.130:5353", +func TestUpstreams(t *testing.T) { + upstreams := []struct { + address string + bootstrap string + }{ + { + address: "8.8.8.8:53", + bootstrap: "8.8.8.8:53", + }, + { + address: "1.1.1.1", + bootstrap: "", + }, + { + address: "tcp://1.1.1.1:53", + bootstrap: "", + }, + { + address: "176.103.130.130:5353", + bootstrap: "", + }, + { + address: "tls://1.1.1.1", + bootstrap: "", + }, + { + address: "tls://9.9.9.9:853", + bootstrap: "", + }, + { + address: "tls://security-filter-dns.cleanbrowsing.org", + bootstrap: "8.8.8.8:53", + }, + { + address: "tls://adult-filter-dns.cleanbrowsing.org:853", + bootstrap: "8.8.8.8:53", + }, + { + address: "https://cloudflare-dns.com/dns-query", + bootstrap: "8.8.8.8:53", + }, + { + address: "https://dns.google.com/experimental", + bootstrap: "8.8.8.8:53", + }, + { + address: "https://doh.cleanbrowsing.org/doh/security-filter/", + bootstrap: "", + }, } - for _, input := range upstreams { - u, err := GetUpstream(input) - if err != nil { - t.Fatalf("Failed to choose upstream for %s: %s", input, err) - } + for _, test := range upstreams { + t.Run(test.address, func(t *testing.T) { + u, err := AddressToUpstream(test.address, test.bootstrap) + if err != nil { + t.Fatalf("Failed to generate upstream from address %s: %s", test.address, err) + } - checkUpstream(t, u, input) - } -} - -func TestUpstreamTLS(t *testing.T) { - upstreams := []string{ - "tls://1.1.1.1", - "tls://9.9.9.9:853", - "tls://security-filter-dns.cleanbrowsing.org", - "tls://adult-filter-dns.cleanbrowsing.org:853", - } - for _, input := range upstreams { - u, err := GetUpstream(input) - if err != nil { - t.Fatalf("Failed to choose upstream for %s: %s", input, err) - } - - checkUpstream(t, u, input) - } -} - -func TestUpstreamHTTPS(t *testing.T) { - upstreams := []string{ - "https://cloudflare-dns.com/dns-query", - "https://dns.google.com/experimental", - "https://doh.cleanbrowsing.org/doh/security-filter/", - } - for _, input := range upstreams { - u, err := GetUpstream(input) - if err != nil { - t.Fatalf("Failed to choose upstream for %s: %s", input, err) - } - - checkUpstream(t, u, input) + checkUpstream(t, u, test.address) + }) } } From 3ad4b2864db74e8c4bdc6910112720e66d984ae7 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Thu, 6 Dec 2018 00:23:03 +0300 Subject: [PATCH 49/58] Rename coredns.go to dns.go --- coredns.go => dns.go | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename coredns.go => dns.go (100%) diff --git a/coredns.go b/dns.go similarity index 100% rename from coredns.go rename to dns.go From 250cc0ec0f453e43d8a8f66a47c236fb800cec4f Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Thu, 6 Dec 2018 00:29:38 +0300 Subject: [PATCH 50/58] config -- Avoid deleting existing dns section if someone removes schema_version from yaml file. --- config.go | 1 + upgrade.go | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/config.go b/config.go index 8facc405..407fcc7a 100644 --- a/config.go +++ b/config.go @@ -72,6 +72,7 @@ var config = configuration{ {Filter: dnsfilter.Filter{ID: 3}, Enabled: false, URL: "https://hosts-file.net/ad_servers.txt", Name: "hpHosts - Ad and Tracking servers only"}, {Filter: dnsfilter.Filter{ID: 4}, Enabled: false, URL: "http://www.malwaredomainlist.com/hostslist/hosts.txt", Name: "MalwareDomainList.com Hosts List"}, }, + SchemaVersion: currentSchemaVersion, } // Loads configuration from the YAML file diff --git a/upgrade.go b/upgrade.go index 1b8c6e34..21d7686d 100644 --- a/upgrade.go +++ b/upgrade.go @@ -126,8 +126,10 @@ func upgradeSchema1to2(diskConfig *map[string]interface{}) error { } } - (*diskConfig)["dns"] = (*diskConfig)["coredns"] - delete((*diskConfig), "coredns") + if _, ok := (*diskConfig)["dns"]; !ok { + (*diskConfig)["dns"] = (*diskConfig)["coredns"] + delete((*diskConfig), "coredns") + } (*diskConfig)["schema_version"] = 2 return nil From c82a5ac0cb5a2138e8a78386aa4e8387dfd9dece Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Thu, 6 Dec 2018 16:54:48 +0300 Subject: [PATCH 51/58] dnsfilter -- fix broken tests --- dnsfilter/dnsfilter_test.go | 35 +++++++++++------------------------ 1 file changed, 11 insertions(+), 24 deletions(-) diff --git a/dnsfilter/dnsfilter_test.go b/dnsfilter/dnsfilter_test.go index 39b33a44..a93fadfc 100644 --- a/dnsfilter/dnsfilter_test.go +++ b/dnsfilter/dnsfilter_test.go @@ -338,7 +338,7 @@ func mustLoadTestRules(d *Dnsfilter) { } func NewForTest() *Dnsfilter { - d := New() + d := New(nil) purgeCaches() return d } @@ -542,7 +542,7 @@ func TestSafeBrowsing(t *testing.T) { t.Run(fmt.Sprintf("%s in %s", tc, _Func()), func(t *testing.T) { d := NewForTest() defer d.Destroy() - d.EnableSafeBrowsing() + d.SafeBrowsingEnabled = true stats.Safebrowsing.Requests = 0 d.checkMatch(t, "wmconvirus.narod.ru") d.checkMatch(t, "wmconvirus.narod.ru") @@ -570,7 +570,7 @@ func TestSafeBrowsing(t *testing.T) { func TestParallelSB(t *testing.T) { d := NewForTest() defer d.Destroy() - d.EnableSafeBrowsing() + d.SafeBrowsingEnabled = true t.Run("group", func(t *testing.T) { for i := 0; i < 100; i++ { t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) { @@ -597,7 +597,7 @@ func TestSafeBrowsingCustomServerFail(t *testing.T) { defer ts.Close() address := ts.Listener.Addr().String() - d.EnableSafeBrowsing() + d.SafeBrowsingEnabled = true d.SetHTTPTimeout(time.Second * 5) d.SetSafeBrowsingServer(address) // this will ensure that test fails d.checkMatchEmpty(t, "wmconvirus.narod.ru") @@ -606,7 +606,8 @@ func TestSafeBrowsingCustomServerFail(t *testing.T) { func TestParentalControl(t *testing.T) { d := NewForTest() defer d.Destroy() - d.EnableParental(3) + d.ParentalEnabled = true + d.ParentalSensitivity = 3 d.checkMatch(t, "pornhub.com") d.checkMatch(t, "pornhub.com") if stats.Parental.Requests != 1 { @@ -637,7 +638,7 @@ func TestSafeSearch(t *testing.T) { if ok { t.Errorf("Expected safesearch to error when disabled") } - d.EnableSafeSearch() + d.SafeSearchEnabled = true val, ok := d.SafeSearchDomain("www.google.com") if !ok { t.Errorf("Expected safesearch to find result for www.google.com") @@ -924,7 +925,7 @@ func BenchmarkLotsOfRulesLotsOfHostsParallel(b *testing.B) { func BenchmarkSafeBrowsing(b *testing.B) { d := NewForTest() defer d.Destroy() - d.EnableSafeBrowsing() + d.SafeBrowsingEnabled = true for n := 0; n < b.N; n++ { hostname := "wmconvirus.narod.ru" ret, err := d.CheckHost(hostname) @@ -940,7 +941,7 @@ func BenchmarkSafeBrowsing(b *testing.B) { func BenchmarkSafeBrowsingParallel(b *testing.B) { d := NewForTest() defer d.Destroy() - d.EnableSafeBrowsing() + d.SafeBrowsingEnabled = true b.RunParallel(func(pb *testing.PB) { for pb.Next() { hostname := "wmconvirus.narod.ru" @@ -958,7 +959,7 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) { func BenchmarkSafeSearch(b *testing.B) { d := NewForTest() defer d.Destroy() - d.EnableSafeSearch() + d.SafeSearchEnabled = true for n := 0; n < b.N; n++ { val, ok := d.SafeSearchDomain("www.google.com") if !ok { @@ -973,7 +974,7 @@ func BenchmarkSafeSearch(b *testing.B) { func BenchmarkSafeSearchParallel(b *testing.B) { d := NewForTest() defer d.Destroy() - d.EnableSafeSearch() + d.SafeSearchEnabled = true b.RunParallel(func(pb *testing.PB) { for pb.Next() { val, ok := d.SafeSearchDomain("www.google.com") @@ -1009,17 +1010,3 @@ func _Func() string { f := runtime.FuncForPC(pc[0]) return path.Base(f.Name()) } - -func trace(format string, args ...interface{}) { - pc := make([]uintptr, 10) // at least 1 entry needed - runtime.Callers(2, pc) - f := runtime.FuncForPC(pc[0]) - var buf strings.Builder - buf.WriteString(fmt.Sprintf("%s(): ", path.Base(f.Name()))) - text := fmt.Sprintf(format, args...) - buf.WriteString(text) - if len(text) == 0 || text[len(text)-1] != '\n' { - buf.WriteRune('\n') - } - fmt.Print(buf.String()) -} From 15f3c8223896e0f3a7502f92ae38dfd9407b687f Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Thu, 6 Dec 2018 16:55:05 +0300 Subject: [PATCH 52/58] dnsforward -- fix panic on ANY request --- dnsforward/querylog_top.go | 2 +- dnsforward/stats.go | 39 ++++++++++++++++++++------------------ 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/dnsforward/querylog_top.go b/dnsforward/querylog_top.go index b78dea79..9191171f 100644 --- a/dnsforward/querylog_top.go +++ b/dnsforward/querylog_top.go @@ -168,7 +168,7 @@ func (r *dayTop) addEntry(entry *logEntry, q *dns.Msg, now time.Time) error { return err } - if entry.Result.IsFiltered { + if entry.Result != nil && entry.Result.IsFiltered { err := runningTop.hours[hour].incrementBlocked(hostname) if err != nil { log.Printf("Failed to increment value: %s", err) diff --git a/dnsforward/stats.go b/dnsforward/stats.go index 9809e6e7..fead555b 100644 --- a/dnsforward/stats.go +++ b/dnsforward/stats.go @@ -193,24 +193,27 @@ func (h *histogram) Observe(value float64) { // ----- func incrementCounters(entry *logEntry) { requests.IncWithTime(entry.Time) - if entry.Result.IsFiltered { - filtered.IncWithTime(entry.Time) - } - switch entry.Result.Reason { - case dnsfilter.NotFilteredWhiteList: - whitelisted.IncWithTime(entry.Time) - case dnsfilter.NotFilteredError: - errorsTotal.IncWithTime(entry.Time) - case dnsfilter.FilteredBlackList: - filteredLists.IncWithTime(entry.Time) - case dnsfilter.FilteredSafeBrowsing: - filteredSafebrowsing.IncWithTime(entry.Time) - case dnsfilter.FilteredParental: - filteredParental.IncWithTime(entry.Time) - case dnsfilter.FilteredInvalid: - // do nothing - case dnsfilter.FilteredSafeSearch: - safesearch.IncWithTime(entry.Time) + if entry.Result != nil { + if entry.Result.IsFiltered { + filtered.IncWithTime(entry.Time) + } + + switch entry.Result.Reason { + case dnsfilter.NotFilteredWhiteList: + whitelisted.IncWithTime(entry.Time) + case dnsfilter.NotFilteredError: + errorsTotal.IncWithTime(entry.Time) + case dnsfilter.FilteredBlackList: + filteredLists.IncWithTime(entry.Time) + case dnsfilter.FilteredSafeBrowsing: + filteredSafebrowsing.IncWithTime(entry.Time) + case dnsfilter.FilteredParental: + filteredParental.IncWithTime(entry.Time) + case dnsfilter.FilteredInvalid: + // do nothing + case dnsfilter.FilteredSafeSearch: + safesearch.IncWithTime(entry.Time) + } } elapsedTime.ObserveWithTime(entry.Elapsed.Seconds(), entry.Time) } From 217124cb3b634787ee933abd7af98833478e5bec Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Thu, 6 Dec 2018 17:17:42 +0300 Subject: [PATCH 53/58] readme -- Update config field descriptions and clarify about coredns. --- README.md | 45 ++++++++++++++++++++++++++------------------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index b9916574..438c3eb1 100644 --- a/README.md +++ b/README.md @@ -104,25 +104,32 @@ Upon the first execution, a file named `AdGuardHome.yaml` will be created, with Settings are stored in [YAML format](https://en.wikipedia.org/wiki/YAML), possible parameters that you can configure are listed below: - * `bind_host` — Web interface IP address to listen on - * `bind_port` — Web interface IP port to listen on - * `auth_name` — Web interface optional authorization username - * `auth_pass` — Web interface optional authorization password - * `dns` — DNS configuration section - * `port` — DNS server port to listen on - * `filtering_enabled` — Filtering of DNS requests based on filter lists - * `safebrowsing_enabled` — Filtering of DNS requests based on safebrowsing - * `safesearch_enabled` — Enforcing "Safe search" option for search engines, when possible - * `parental_enabled` — Parental control-based DNS requests filtering - * `parental_sensitivity` — Age group for parental control-based filtering, must be either 3, 10, 13 or 17 - * `querylog_enabled` — Query logging (also used to calculate top 50 clients, blocked domains and requested domains for statistic purposes) - * `bootstrap_dns` — DNS server used for initial hostname resolution in case if upstream server name is a hostname - * `upstream_dns` — List of upstream DNS servers + * `bind_host` — Web interface IP address to listen on. + * `bind_port` — Web interface IP port to listen on. + * `auth_name` — Web interface optional authorization username. + * `auth_pass` — Web interface optional authorization password. + * `dns` — DNS configuration section. + * `port` — DNS server port to listen on. + * `protection_enabled` — Whether any kind of filtering and protection should be done, when off it works as a plain dns forwarder. + * `filtering_enabled` — Filtering of DNS requests based on filter lists. + * `blocked_response_ttl` — For how many seconds the clients should cache a filtered response. Low values are useful on LAN if you change filters very often, high values are useful to increase performance and save traffic. + * `querylog_enabled` — Query logging (also used to calculate top 50 clients, blocked domains and requested domains for statistical purposes). + * `ratelimit` — DDoS protection, specifies in how many packets per second a client should receive. Anything above that is silently dropped. To disable set 0, default is 20. Safe to disable if DNS server is not available from internet. + * `ratelimit_whitelist` — If you want exclude some IP addresses from ratelimiting but keep ratelimiting on for others, put them here. + * `refuse_any` — Another DDoS protection mechanism. Requests of type ANY are rarely needed, so refusing to serve them mitigates against attackers trying to use your DNS as a reflection. Safe to disable if DNS server is not available from internet. + * `bootstrap_dns` — DNS server used for initial hostname resolution in case if upstream server name is a hostname. + * `parental_sensitivity` — Age group for parental control-based filtering, must be either 3, 10, 13 or 17 if enabled. + * `parental_enabled` — Parental control-based DNS requests filtering. + * `safesearch_enabled` — Enforcing "Safe search" option for search engines, when possible. + * `safebrowsing_enabled` — Filtering of DNS requests based on safebrowsing. + * `upstream_dns` — List of upstream DNS servers. * `filters` — List of filters, each filter has the following values: - * `ID` - filter ID (must be unique) - * `url` — URL pointing to the filter contents (filtering rules) - * `enabled` — Current filter's status (enabled/disabled) - * `user_rules` — User-specified filtering rules + * `enabled` — Current filter's status (enabled/disabled). + * `url` — URL pointing to the filter contents (filtering rules). + * `name` — Name of the filter. If it's an adguard syntax filter it will get updated automatically, otherwise it stays unchanged. + * `last_updated` — Time when the filter was last updated from server. + * `ID` - filter ID (must be unique). + * `user_rules` — User-specified filtering rules. Removing an entry from settings file will reset it to the default value. Deleting the file will reset all settings to the default values. @@ -208,6 +215,6 @@ This software wouldn't have been possible without: * And many more node.js packages. * [whotracks.me data](https://github.com/cliqz-oss/whotracks.me) -You might have seen that coredns was mentioned here before — we've stopped using it in AdGuardHome. While we still use it on our servers, it seemed like an overkill and it impeded with Home features that we wanted to implement. +You might have seen that [CoreDNS](https://coredns.io) was mentioned here before — we've stopped using it in AdGuardHome. While we still use it on our servers for [AdGuard DNS](https://adguard.com/adguard-dns/overview.html) service, it seemed like an overkill for Home as it impeded with Home features that we plan to implement. For a full list of all node.js packages in use, please take a look at [client/package.json](https://github.com/AdguardTeam/AdGuardHome/blob/master/client/package.json) file. From e5c1211e17ca32fdcbd148243906d3c11013bede Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Thu, 6 Dec 2018 17:18:16 +0300 Subject: [PATCH 54/58] Small code review update -- use CamelCase --- app.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/app.go b/app.go index 10469245..bbe36359 100644 --- a/app.go +++ b/app.go @@ -166,10 +166,10 @@ func main() { } }() - signal_channel := make(chan os.Signal) - signal.Notify(signal_channel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT) + signalChannel := make(chan os.Signal) + signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT) go func() { - <-signal_channel + <-signalChannel cleanup() os.Exit(0) }() From 8fb6f92753e8bc9c45890a04667944684ce9bb0c Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Thu, 6 Dec 2018 17:18:34 +0300 Subject: [PATCH 55/58] Use filter deduplication function. --- config.go | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/config.go b/config.go index 407fcc7a..1f7464b2 100644 --- a/config.go +++ b/config.go @@ -96,20 +96,7 @@ func parseConfig() error { } // Deduplicate filters - { - i := 0 // output index, used for deletion later - urls := map[string]bool{} - for _, filter := range config.Filters { - if _, ok := urls[filter.URL]; !ok { - // we didn't see it before, keep it - urls[filter.URL] = true // remember the URL - config.Filters[i] = filter - i++ - } - } - // all entries we want to keep are at front, delete the rest - config.Filters = config.Filters[:i] - } + deduplicateFilters() updateUniqueFilterID(config.Filters) From e056fb2eb98fdedc5e0caaeb04e089c4e67a8253 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Thu, 6 Dec 2018 17:18:44 +0300 Subject: [PATCH 56/58] Remove unused code. --- dnsfilter/dnsfilter.go | 3 --- dnsforward/upstream.go | 14 -------------- 2 files changed, 17 deletions(-) diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index 2c78a92b..cd408a4d 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -38,9 +38,6 @@ var ErrInvalidSyntax = errors.New("dnsfilter: invalid rule syntax") // ErrInvalidSyntax is returned by AddRule when the rule was already added to the filter var ErrAlreadyExists = errors.New("dnsfilter: rule was already added") -// ErrInvalidParental is returned by EnableParental when sensitivity is not a valid value -var ErrInvalidParental = errors.New("dnsfilter: invalid parental sensitivity, must be either 3, 10, 13 or 17") - const shortcutLength = 6 // used for rule search optimization, 6 hits the sweet spot const enableFastLookup = true // flag for debugging, must be true in production for faster performance diff --git a/dnsforward/upstream.go b/dnsforward/upstream.go index 99142929..4f6b18da 100644 --- a/dnsforward/upstream.go +++ b/dnsforward/upstream.go @@ -76,13 +76,6 @@ type dnsOverTLS struct { sync.RWMutex // protects pool } -var defaultTLSClient = dns.Client{ - Net: "tcp-tls", - Timeout: defaultTimeout, - UDPSize: dns.MaxMsgSize, - TLSConfig: &tls.Config{}, -} - func (p *dnsOverTLS) Address() string { return p.boot.address } func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) { @@ -128,13 +121,6 @@ type dnsOverHTTPS struct { boot bootstrapper } -var defaultHTTPSTransport = http.Transport{} - -var defaultHTTPSClient = http.Client{ - Transport: &defaultHTTPSTransport, - Timeout: defaultTimeout, -} - func (p *dnsOverHTTPS) Address() string { return p.boot.address } func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (*dns.Msg, error) { From 2ba5cb48b2e8a2bade7553920f3f5d8fe9cd6832 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Thu, 6 Dec 2018 17:19:57 +0300 Subject: [PATCH 57/58] Fixup of previous commit -- remove unused import. --- dnsforward/upstream.go | 1 - 1 file changed, 1 deletion(-) diff --git a/dnsforward/upstream.go b/dnsforward/upstream.go index 4f6b18da..89016951 100644 --- a/dnsforward/upstream.go +++ b/dnsforward/upstream.go @@ -2,7 +2,6 @@ package dnsforward import ( "bytes" - "crypto/tls" "fmt" "io/ioutil" "log" From 253d8a4016d66863ecee426b8f7d74841c4ed4de Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Thu, 6 Dec 2018 17:27:38 +0300 Subject: [PATCH 58/58] Pointer for dnsfilter.Result in querylog didn't make things simpler, revert that change and all related changes. --- dnsforward/querylog.go | 8 ++++++-- dnsforward/querylog_top.go | 2 +- dnsforward/stats.go | 38 ++++++++++++++++++-------------------- 3 files changed, 25 insertions(+), 23 deletions(-) diff --git a/dnsforward/querylog.go b/dnsforward/querylog.go index 65af8b74..d449990d 100644 --- a/dnsforward/querylog.go +++ b/dnsforward/querylog.go @@ -34,7 +34,7 @@ var ( type logEntry struct { Question []byte Answer []byte `json:",omitempty"` // sometimes empty answers happen like binerdunt.top or rev2.globalrootservers.net - Result *dnsfilter.Result + Result dnsfilter.Result Time time.Time Elapsed time.Duration IP string @@ -61,11 +61,15 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result *dnsfilter.Result, el } } + if result == nil { + result = &dnsfilter.Result{} + } + now := time.Now() entry := logEntry{ Question: q, Answer: a, - Result: result, + Result: *result, Time: now, Elapsed: elapsed, IP: ip, diff --git a/dnsforward/querylog_top.go b/dnsforward/querylog_top.go index 9191171f..b78dea79 100644 --- a/dnsforward/querylog_top.go +++ b/dnsforward/querylog_top.go @@ -168,7 +168,7 @@ func (r *dayTop) addEntry(entry *logEntry, q *dns.Msg, now time.Time) error { return err } - if entry.Result != nil && entry.Result.IsFiltered { + if entry.Result.IsFiltered { err := runningTop.hours[hour].incrementBlocked(hostname) if err != nil { log.Printf("Failed to increment value: %s", err) diff --git a/dnsforward/stats.go b/dnsforward/stats.go index fead555b..9cfe5f58 100644 --- a/dnsforward/stats.go +++ b/dnsforward/stats.go @@ -193,27 +193,25 @@ func (h *histogram) Observe(value float64) { // ----- func incrementCounters(entry *logEntry) { requests.IncWithTime(entry.Time) - if entry.Result != nil { - if entry.Result.IsFiltered { - filtered.IncWithTime(entry.Time) - } + if entry.Result.IsFiltered { + filtered.IncWithTime(entry.Time) + } - switch entry.Result.Reason { - case dnsfilter.NotFilteredWhiteList: - whitelisted.IncWithTime(entry.Time) - case dnsfilter.NotFilteredError: - errorsTotal.IncWithTime(entry.Time) - case dnsfilter.FilteredBlackList: - filteredLists.IncWithTime(entry.Time) - case dnsfilter.FilteredSafeBrowsing: - filteredSafebrowsing.IncWithTime(entry.Time) - case dnsfilter.FilteredParental: - filteredParental.IncWithTime(entry.Time) - case dnsfilter.FilteredInvalid: - // do nothing - case dnsfilter.FilteredSafeSearch: - safesearch.IncWithTime(entry.Time) - } + switch entry.Result.Reason { + case dnsfilter.NotFilteredWhiteList: + whitelisted.IncWithTime(entry.Time) + case dnsfilter.NotFilteredError: + errorsTotal.IncWithTime(entry.Time) + case dnsfilter.FilteredBlackList: + filteredLists.IncWithTime(entry.Time) + case dnsfilter.FilteredSafeBrowsing: + filteredSafebrowsing.IncWithTime(entry.Time) + case dnsfilter.FilteredParental: + filteredParental.IncWithTime(entry.Time) + case dnsfilter.FilteredInvalid: + // do nothing + case dnsfilter.FilteredSafeSearch: + safesearch.IncWithTime(entry.Time) } elapsedTime.ObserveWithTime(entry.Elapsed.Seconds(), entry.Time) }