Merge: resolve minor issues

* commit '3b9d758510cd65e483524de1e698076326950ab0':
  * refactor
  * dnsforward: get per-client settings only once
  * clients: manual clients don't exclude auto-clients anymore
  * clients: Upstreams cache: refactor
  * clients: remove WHOIS info for manually-added clients
This commit is contained in:
Simon Zolin 2020-01-10 19:18:10 +03:00
commit 318ed0dafb
5 changed files with 81 additions and 103 deletions

View File

@ -445,7 +445,15 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
// A better approach is for proxy.Stop() to wait until all its workers exit,
// but this would require the Upstream interface to have Close() function
// (to prevent from hanging while waiting for unresponsive DNS server to respond).
res, err := s.filterDNSRequest(d)
var setts *dnsfilter.RequestFilteringSettings
var err error
res := &dnsfilter.Result{}
protectionEnabled := s.conf.ProtectionEnabled && s.dnsFilter != nil
if protectionEnabled {
setts = s.getClientRequestFilteringSettings(d)
res, err = s.filterDNSRequest(d, setts)
}
s.RUnlock()
if err != nil {
return err
@ -486,9 +494,9 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
d.Res.Answer = answer
}
} else if res.Reason != dnsfilter.NotFilteredWhiteList {
} else if res.Reason != dnsfilter.NotFilteredWhiteList && protectionEnabled {
origResp2 := d.Res
res, err = s.filterDNSResponse(d)
res, err = s.filterDNSResponse(d, setts)
if err != nil {
return err
}
@ -602,12 +610,7 @@ func (s *Server) getClientRequestFilteringSettings(d *proxy.DNSContext) *dnsfilt
}
// filterDNSRequest applies the dnsFilter and sets d.Res if the request was filtered
func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error) {
if !s.conf.ProtectionEnabled || s.dnsFilter == nil {
return &dnsfilter.Result{}, nil
}
setts := s.getClientRequestFilteringSettings(d)
func (s *Server) filterDNSRequest(d *proxy.DNSContext, setts *dnsfilter.RequestFilteringSettings) (*dnsfilter.Result, error) {
req := d.Req
host := strings.TrimSuffix(req.Question[0].Name, ".")
res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, setts)
@ -648,7 +651,7 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error
// If response contains CNAME, A or AAAA records, we apply filtering to each canonical host name or IP address.
// If this is a match, we set a new response in d.Res and return.
func (s *Server) filterDNSResponse(d *proxy.DNSContext) (*dnsfilter.Result, error) {
func (s *Server) filterDNSResponse(d *proxy.DNSContext, setts *dnsfilter.RequestFilteringSettings) (*dnsfilter.Result, error) {
for _, a := range d.Res.Answer {
host := ""
@ -676,7 +679,6 @@ func (s *Server) filterDNSResponse(d *proxy.DNSContext) (*dnsfilter.Result, erro
s.RUnlock()
continue
}
setts := s.getClientRequestFilteringSettings(d)
res, err := s.dnsFilter.CheckHostRules(host, d.Req.Question[0].Qtype, setts)
s.RUnlock()

View File

@ -340,6 +340,22 @@ var testIPv4 = map[string][]net.IP{
"example.org.": {{127, 0, 0, 255}},
}
func TestBlockCNAMEProtectionEnabled(t *testing.T) {
s := createTestServer(t)
testUpstm := &testUpstream{testCNAMEs, testIPv4, nil}
s.conf.ProtectionEnabled = false
err := s.startWithUpstream(testUpstm)
assert.True(t, err == nil)
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
// 'badhost' has a canonical name 'null.example.org' which is blocked by filters:
// but protection is disabled - response is NOT blocked
req := createTestMessage("badhost.")
reply, err := dns.Exchange(req, addr.String())
assert.True(t, err == nil)
assert.True(t, reply.Rcode == dns.RcodeSuccess)
}
func TestBlockCNAME(t *testing.T) {
s := createTestServer(t)
testUpstm := &testUpstream{testCNAMEs, testIPv4, nil}
@ -349,35 +365,23 @@ func TestBlockCNAME(t *testing.T) {
// 'badhost' has a canonical name 'null.example.org' which is blocked by filters:
// response is blocked
req := dns.Msg{}
req.Id = dns.Id()
req.Question = []dns.Question{
{Name: "badhost.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
}
reply, err := dns.Exchange(&req, addr.String())
req := createTestMessage("badhost.")
reply, err := dns.Exchange(req, addr.String())
assert.True(t, err == nil)
assert.True(t, reply.Rcode == dns.RcodeNameError)
// 'whitelist.example.org' has a canonical name 'null.example.org' which is blocked by filters
// but 'whitelist.example.org' is in a whitelist:
// response isn't blocked
req = dns.Msg{}
req.Id = dns.Id()
req.Question = []dns.Question{
{Name: "whitelist.example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
}
reply, err = dns.Exchange(&req, addr.String())
req = createTestMessage("whitelist.example.org.")
reply, err = dns.Exchange(req, addr.String())
assert.True(t, err == nil)
assert.True(t, reply.Rcode == dns.RcodeSuccess)
// 'example.org' has a canonical name 'cname1' with IP 127.0.0.255 which is blocked by filters:
// response is blocked
req = dns.Msg{}
req.Id = dns.Id()
req.Question = []dns.Question{
{Name: "example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
}
reply, err = dns.Exchange(&req, addr.String())
req = createTestMessage("example.org.")
reply, err = dns.Exchange(req, addr.String())
assert.True(t, err == nil)
assert.True(t, reply.Rcode == dns.RcodeNameError)

View File

@ -32,12 +32,16 @@ type Client struct {
SafeSearchEnabled bool
SafeBrowsingEnabled bool
ParentalEnabled bool
WhoisInfo [][]string // [[key,value], ...]
UseOwnBlockedServices bool // false: use global settings
BlockedServices []string
Upstreams []string // list of upstream servers to be used for the client's requests
// Upstream objects:
// nil: not yet initialized
// not nil, but empty: initialized, no good upstreams
// not nil, not empty: Upstreams ready to be used
upstreamObjects []upstream.Upstream
}
type clientSource uint
@ -63,12 +67,7 @@ type clientsContainer struct {
list map[string]*Client // name -> client
idIndex map[string]*Client // IP -> client
ipHost map[string]*ClientHost // IP -> Hostname
// cache for Upstream instances that are used in the case
// when custom DNS servers are configured for a client
upstreamsCache map[string][]upstream.Upstream // name -> []Upstream
lock sync.Mutex
lock sync.Mutex
// dhcpServer is used for looking up clients IP addresses by MAC addresses
dhcpServer *dhcpd.Server
@ -85,7 +84,6 @@ func (clients *clientsContainer) Init(objects []clientObject, dhcpServer *dhcpd.
clients.list = make(map[string]*Client)
clients.idIndex = make(map[string]*Client)
clients.ipHost = make(map[string]*ClientHost)
clients.upstreamsCache = make(map[string][]upstream.Upstream)
clients.dhcpServer = dhcpServer
clients.addFromConfig(objects)
@ -199,6 +197,12 @@ func (clients *clientsContainer) Find(ip string) (Client, bool) {
return clients.findByIP(ip)
}
func upstreamArrayCopy(a []upstream.Upstream) []upstream.Upstream {
a2 := make([]upstream.Upstream, len(a))
copy(a2, a)
return a2
}
// FindUpstreams looks for upstreams configured for the client
// If no client found for this IP, or if no custom upstreams are configured,
// this method returns nil
@ -211,31 +215,22 @@ func (clients *clientsContainer) FindUpstreams(ip string) []upstream.Upstream {
return nil
}
if len(c.Upstreams) == 0 {
if c.upstreamObjects == nil {
c.upstreamObjects = make([]upstream.Upstream, 0)
for _, us := range c.Upstreams {
u, err := upstream.AddressToUpstream(us, upstream.Options{Timeout: dnsforward.DefaultTimeout})
if err != nil {
log.Error("upstream.AddressToUpstream: %s: %s", us, err)
continue
}
c.upstreamObjects = append(c.upstreamObjects, u)
}
}
if len(c.upstreamObjects) == 0 {
return nil
}
upstreams, ok := clients.upstreamsCache[c.Name]
if ok {
return upstreams
}
for _, us := range c.Upstreams {
u, err := upstream.AddressToUpstream(us, upstream.Options{Timeout: dnsforward.DefaultTimeout})
if err != nil {
log.Error("upstream.AddressToUpstream: %s: %s", us, err)
continue
}
upstreams = append(upstreams, u)
}
if len(upstreams) == 0 {
clients.upstreamsCache[c.Name] = nil
} else {
clients.upstreamsCache[c.Name] = upstreams
}
return upstreams
return upstreamArrayCopy(c.upstreamObjects)
}
// Find searches for a client by IP (and does not lock anything)
@ -366,17 +361,6 @@ func (clients *clientsContainer) Add(c Client) (bool, error) {
}
}
// remove auto-clients with the same IP address, keeping WHOIS info if possible
for _, id := range c.IDs {
ch, ok := clients.ipHost[id]
if ok {
if len(c.WhoisInfo) == 0 {
c.WhoisInfo = ch.WhoisInfo
}
delete(clients.ipHost, id)
}
}
// update Name index
clients.list[c.Name] = &c
@ -402,9 +386,6 @@ func (clients *clientsContainer) Del(name string) bool {
// update Name index
delete(clients.list, name)
// update upstreams cache
delete(clients.upstreamsCache, name)
// update ID index
for _, id := range c.IDs {
delete(clients.idIndex, id)
@ -473,11 +454,7 @@ func (clients *clientsContainer) Update(name string, c Client) error {
}
// update upstreams cache
if old.Name != c.Name {
delete(clients.upstreamsCache, old.Name)
} else {
delete(clients.upstreamsCache, c.Name)
}
c.upstreamObjects = nil
*old = c
return nil
@ -517,12 +494,6 @@ func (clients *clientsContainer) AddHost(ip, host string, source clientSource) (
clients.lock.Lock()
defer clients.lock.Unlock()
// check existing clients first
_, ok := clients.findByIP(ip)
if ok {
return false, nil
}
// check auto-clients index
ch, ok := clients.ipHost[ip]
if ok && ch.Source > source {

View File

@ -16,8 +16,6 @@ type clientJSON struct {
SafeSearchEnabled bool `json:"safesearch_enabled"`
SafeBrowsingEnabled bool `json:"safebrowsing_enabled"`
WhoisInfo map[string]interface{} `json:"whois_info"`
UseGlobalBlockedServices bool `json:"use_global_blocked_services"`
BlockedServices []string `json:"blocked_services"`
@ -116,11 +114,6 @@ func clientToJSON(c *Client) clientJSON {
Upstreams: c.Upstreams,
}
cj.WhoisInfo = make(map[string]interface{})
for _, wi := range c.WhoisInfo {
cj.WhoisInfo[wi[0]] = wi[1]
}
return cj
}

View File

@ -168,18 +168,18 @@ func TestClientsWhois(t *testing.T) {
clients.SetWhoisInfo("1.1.1.1", whois)
assert.True(t, clients.ipHost["1.1.1.1"].WhoisInfo[0][1] == "orgname-val")
// Check that we cannot set whois info on existing client
// Check that we cannot set whois info on a manually-added client
c = Client{
IDs: []string{"1.1.1.2"},
Name: "client1",
}
_, _ = clients.Add(c)
clients.SetWhoisInfo("1.1.1.2", whois)
assert.Nil(t, clients.idIndex["1.1.1.2"].WhoisInfo)
assert.True(t, clients.ipHost["1.1.1.2"] == nil)
_ = clients.Del("client1")
}
func TestClientsAddExistingHost(t *testing.T) {
func TestClientsAddExisting(t *testing.T) {
var c Client
clients := clientsContainer{}
clients.testing = true
@ -198,9 +198,9 @@ func TestClientsAddExistingHost(t *testing.T) {
assert.True(t, ok)
assert.Nil(t, err)
// try adding a duplicate by IP
// add an auto-client with the same IP - it's allowed
ok, err = clients.AddHost("1.1.1.1", "test", ClientSourceRDNS)
assert.False(t, ok)
assert.True(t, ok)
assert.Nil(t, err)
// now some more complicated stuff
@ -218,13 +218,21 @@ func TestClientsAddExistingHost(t *testing.T) {
})
assert.Nil(t, err)
// try adding a duplicate IP which for a Mac-based client
ok, err = clients.AddHost(testIP, "test", ClientSourceRDNS)
assert.False(t, ok)
// add a new client with the same IP as for a client with MAC
c = Client{
IDs: []string{testIP},
Name: "client2",
}
ok, err = clients.Add(c)
assert.True(t, ok)
assert.Nil(t, err)
// don't allow duplicates by CIDR
ok, err = clients.AddHost("2.2.2.2", "test", ClientSourceRDNS)
assert.False(t, ok)
// add a new client with the IP from the client1's IP range
c = Client{
IDs: []string{"2.2.2.2"},
Name: "client3",
}
ok, err = clients.Add(c)
assert.True(t, ok)
assert.Nil(t, err)
}