diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index f14845a8..e94b47cc 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -445,7 +445,15 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error { // A better approach is for proxy.Stop() to wait until all its workers exit, // but this would require the Upstream interface to have Close() function // (to prevent from hanging while waiting for unresponsive DNS server to respond). - res, err := s.filterDNSRequest(d) + + var setts *dnsfilter.RequestFilteringSettings + var err error + res := &dnsfilter.Result{} + protectionEnabled := s.conf.ProtectionEnabled && s.dnsFilter != nil + if protectionEnabled { + setts = s.getClientRequestFilteringSettings(d) + res, err = s.filterDNSRequest(d, setts) + } s.RUnlock() if err != nil { return err @@ -486,9 +494,9 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error { d.Res.Answer = answer } - } else if res.Reason != dnsfilter.NotFilteredWhiteList { + } else if res.Reason != dnsfilter.NotFilteredWhiteList && protectionEnabled { origResp2 := d.Res - res, err = s.filterDNSResponse(d) + res, err = s.filterDNSResponse(d, setts) if err != nil { return err } @@ -602,12 +610,7 @@ func (s *Server) getClientRequestFilteringSettings(d *proxy.DNSContext) *dnsfilt } // filterDNSRequest applies the dnsFilter and sets d.Res if the request was filtered -func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error) { - if !s.conf.ProtectionEnabled || s.dnsFilter == nil { - return &dnsfilter.Result{}, nil - } - - setts := s.getClientRequestFilteringSettings(d) +func (s *Server) filterDNSRequest(d *proxy.DNSContext, setts *dnsfilter.RequestFilteringSettings) (*dnsfilter.Result, error) { req := d.Req host := strings.TrimSuffix(req.Question[0].Name, ".") res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, setts) @@ -648,7 +651,7 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error // If response contains CNAME, A or AAAA records, we apply filtering to each canonical host name or IP address. // If this is a match, we set a new response in d.Res and return. -func (s *Server) filterDNSResponse(d *proxy.DNSContext) (*dnsfilter.Result, error) { +func (s *Server) filterDNSResponse(d *proxy.DNSContext, setts *dnsfilter.RequestFilteringSettings) (*dnsfilter.Result, error) { for _, a := range d.Res.Answer { host := "" @@ -676,7 +679,6 @@ func (s *Server) filterDNSResponse(d *proxy.DNSContext) (*dnsfilter.Result, erro s.RUnlock() continue } - setts := s.getClientRequestFilteringSettings(d) res, err := s.dnsFilter.CheckHostRules(host, d.Req.Question[0].Qtype, setts) s.RUnlock() diff --git a/dnsforward/dnsforward_test.go b/dnsforward/dnsforward_test.go index 76f8f028..53803c78 100644 --- a/dnsforward/dnsforward_test.go +++ b/dnsforward/dnsforward_test.go @@ -340,6 +340,22 @@ var testIPv4 = map[string][]net.IP{ "example.org.": {{127, 0, 0, 255}}, } +func TestBlockCNAMEProtectionEnabled(t *testing.T) { + s := createTestServer(t) + testUpstm := &testUpstream{testCNAMEs, testIPv4, nil} + s.conf.ProtectionEnabled = false + err := s.startWithUpstream(testUpstm) + assert.True(t, err == nil) + addr := s.dnsProxy.Addr(proxy.ProtoUDP) + + // 'badhost' has a canonical name 'null.example.org' which is blocked by filters: + // but protection is disabled - response is NOT blocked + req := createTestMessage("badhost.") + reply, err := dns.Exchange(req, addr.String()) + assert.True(t, err == nil) + assert.True(t, reply.Rcode == dns.RcodeSuccess) +} + func TestBlockCNAME(t *testing.T) { s := createTestServer(t) testUpstm := &testUpstream{testCNAMEs, testIPv4, nil} @@ -349,35 +365,23 @@ func TestBlockCNAME(t *testing.T) { // 'badhost' has a canonical name 'null.example.org' which is blocked by filters: // response is blocked - req := dns.Msg{} - req.Id = dns.Id() - req.Question = []dns.Question{ - {Name: "badhost.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, - } - reply, err := dns.Exchange(&req, addr.String()) + req := createTestMessage("badhost.") + reply, err := dns.Exchange(req, addr.String()) assert.True(t, err == nil) assert.True(t, reply.Rcode == dns.RcodeNameError) // 'whitelist.example.org' has a canonical name 'null.example.org' which is blocked by filters // but 'whitelist.example.org' is in a whitelist: // response isn't blocked - req = dns.Msg{} - req.Id = dns.Id() - req.Question = []dns.Question{ - {Name: "whitelist.example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, - } - reply, err = dns.Exchange(&req, addr.String()) + req = createTestMessage("whitelist.example.org.") + reply, err = dns.Exchange(req, addr.String()) assert.True(t, err == nil) assert.True(t, reply.Rcode == dns.RcodeSuccess) // 'example.org' has a canonical name 'cname1' with IP 127.0.0.255 which is blocked by filters: // response is blocked - req = dns.Msg{} - req.Id = dns.Id() - req.Question = []dns.Question{ - {Name: "example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, - } - reply, err = dns.Exchange(&req, addr.String()) + req = createTestMessage("example.org.") + reply, err = dns.Exchange(req, addr.String()) assert.True(t, err == nil) assert.True(t, reply.Rcode == dns.RcodeNameError) diff --git a/home/clients.go b/home/clients.go index 3b94edef..65ce78e2 100644 --- a/home/clients.go +++ b/home/clients.go @@ -32,12 +32,16 @@ type Client struct { SafeSearchEnabled bool SafeBrowsingEnabled bool ParentalEnabled bool - WhoisInfo [][]string // [[key,value], ...] UseOwnBlockedServices bool // false: use global settings BlockedServices []string Upstreams []string // list of upstream servers to be used for the client's requests + // Upstream objects: + // nil: not yet initialized + // not nil, but empty: initialized, no good upstreams + // not nil, not empty: Upstreams ready to be used + upstreamObjects []upstream.Upstream } type clientSource uint @@ -63,12 +67,7 @@ type clientsContainer struct { list map[string]*Client // name -> client idIndex map[string]*Client // IP -> client ipHost map[string]*ClientHost // IP -> Hostname - - // cache for Upstream instances that are used in the case - // when custom DNS servers are configured for a client - upstreamsCache map[string][]upstream.Upstream // name -> []Upstream - - lock sync.Mutex + lock sync.Mutex // dhcpServer is used for looking up clients IP addresses by MAC addresses dhcpServer *dhcpd.Server @@ -85,7 +84,6 @@ func (clients *clientsContainer) Init(objects []clientObject, dhcpServer *dhcpd. clients.list = make(map[string]*Client) clients.idIndex = make(map[string]*Client) clients.ipHost = make(map[string]*ClientHost) - clients.upstreamsCache = make(map[string][]upstream.Upstream) clients.dhcpServer = dhcpServer clients.addFromConfig(objects) @@ -199,6 +197,12 @@ func (clients *clientsContainer) Find(ip string) (Client, bool) { return clients.findByIP(ip) } +func upstreamArrayCopy(a []upstream.Upstream) []upstream.Upstream { + a2 := make([]upstream.Upstream, len(a)) + copy(a2, a) + return a2 +} + // FindUpstreams looks for upstreams configured for the client // If no client found for this IP, or if no custom upstreams are configured, // this method returns nil @@ -211,31 +215,22 @@ func (clients *clientsContainer) FindUpstreams(ip string) []upstream.Upstream { return nil } - if len(c.Upstreams) == 0 { + if c.upstreamObjects == nil { + c.upstreamObjects = make([]upstream.Upstream, 0) + for _, us := range c.Upstreams { + u, err := upstream.AddressToUpstream(us, upstream.Options{Timeout: dnsforward.DefaultTimeout}) + if err != nil { + log.Error("upstream.AddressToUpstream: %s: %s", us, err) + continue + } + c.upstreamObjects = append(c.upstreamObjects, u) + } + } + + if len(c.upstreamObjects) == 0 { return nil } - - upstreams, ok := clients.upstreamsCache[c.Name] - if ok { - return upstreams - } - - for _, us := range c.Upstreams { - u, err := upstream.AddressToUpstream(us, upstream.Options{Timeout: dnsforward.DefaultTimeout}) - if err != nil { - log.Error("upstream.AddressToUpstream: %s: %s", us, err) - continue - } - upstreams = append(upstreams, u) - } - - if len(upstreams) == 0 { - clients.upstreamsCache[c.Name] = nil - } else { - clients.upstreamsCache[c.Name] = upstreams - } - - return upstreams + return upstreamArrayCopy(c.upstreamObjects) } // Find searches for a client by IP (and does not lock anything) @@ -366,17 +361,6 @@ func (clients *clientsContainer) Add(c Client) (bool, error) { } } - // remove auto-clients with the same IP address, keeping WHOIS info if possible - for _, id := range c.IDs { - ch, ok := clients.ipHost[id] - if ok { - if len(c.WhoisInfo) == 0 { - c.WhoisInfo = ch.WhoisInfo - } - delete(clients.ipHost, id) - } - } - // update Name index clients.list[c.Name] = &c @@ -402,9 +386,6 @@ func (clients *clientsContainer) Del(name string) bool { // update Name index delete(clients.list, name) - // update upstreams cache - delete(clients.upstreamsCache, name) - // update ID index for _, id := range c.IDs { delete(clients.idIndex, id) @@ -473,11 +454,7 @@ func (clients *clientsContainer) Update(name string, c Client) error { } // update upstreams cache - if old.Name != c.Name { - delete(clients.upstreamsCache, old.Name) - } else { - delete(clients.upstreamsCache, c.Name) - } + c.upstreamObjects = nil *old = c return nil @@ -517,12 +494,6 @@ func (clients *clientsContainer) AddHost(ip, host string, source clientSource) ( clients.lock.Lock() defer clients.lock.Unlock() - // check existing clients first - _, ok := clients.findByIP(ip) - if ok { - return false, nil - } - // check auto-clients index ch, ok := clients.ipHost[ip] if ok && ch.Source > source { diff --git a/home/clients_http.go b/home/clients_http.go index 37336298..aaaef5e0 100644 --- a/home/clients_http.go +++ b/home/clients_http.go @@ -16,8 +16,6 @@ type clientJSON struct { SafeSearchEnabled bool `json:"safesearch_enabled"` SafeBrowsingEnabled bool `json:"safebrowsing_enabled"` - WhoisInfo map[string]interface{} `json:"whois_info"` - UseGlobalBlockedServices bool `json:"use_global_blocked_services"` BlockedServices []string `json:"blocked_services"` @@ -116,11 +114,6 @@ func clientToJSON(c *Client) clientJSON { Upstreams: c.Upstreams, } - - cj.WhoisInfo = make(map[string]interface{}) - for _, wi := range c.WhoisInfo { - cj.WhoisInfo[wi[0]] = wi[1] - } return cj } diff --git a/home/clients_test.go b/home/clients_test.go index e227d699..4468de35 100644 --- a/home/clients_test.go +++ b/home/clients_test.go @@ -168,18 +168,18 @@ func TestClientsWhois(t *testing.T) { clients.SetWhoisInfo("1.1.1.1", whois) assert.True(t, clients.ipHost["1.1.1.1"].WhoisInfo[0][1] == "orgname-val") - // Check that we cannot set whois info on existing client + // Check that we cannot set whois info on a manually-added client c = Client{ IDs: []string{"1.1.1.2"}, Name: "client1", } _, _ = clients.Add(c) clients.SetWhoisInfo("1.1.1.2", whois) - assert.Nil(t, clients.idIndex["1.1.1.2"].WhoisInfo) + assert.True(t, clients.ipHost["1.1.1.2"] == nil) _ = clients.Del("client1") } -func TestClientsAddExistingHost(t *testing.T) { +func TestClientsAddExisting(t *testing.T) { var c Client clients := clientsContainer{} clients.testing = true @@ -198,9 +198,9 @@ func TestClientsAddExistingHost(t *testing.T) { assert.True(t, ok) assert.Nil(t, err) - // try adding a duplicate by IP + // add an auto-client with the same IP - it's allowed ok, err = clients.AddHost("1.1.1.1", "test", ClientSourceRDNS) - assert.False(t, ok) + assert.True(t, ok) assert.Nil(t, err) // now some more complicated stuff @@ -218,13 +218,21 @@ func TestClientsAddExistingHost(t *testing.T) { }) assert.Nil(t, err) - // try adding a duplicate IP which for a Mac-based client - ok, err = clients.AddHost(testIP, "test", ClientSourceRDNS) - assert.False(t, ok) + // add a new client with the same IP as for a client with MAC + c = Client{ + IDs: []string{testIP}, + Name: "client2", + } + ok, err = clients.Add(c) + assert.True(t, ok) assert.Nil(t, err) - // don't allow duplicates by CIDR - ok, err = clients.AddHost("2.2.2.2", "test", ClientSourceRDNS) - assert.False(t, ok) + // add a new client with the IP from the client1's IP range + c = Client{ + IDs: []string{"2.2.2.2"}, + Name: "client3", + } + ok, err = clients.Add(c) + assert.True(t, ok) assert.Nil(t, err) }