Merge: resolve minor issues
* commit '3b9d758510cd65e483524de1e698076326950ab0': * refactor * dnsforward: get per-client settings only once * clients: manual clients don't exclude auto-clients anymore * clients: Upstreams cache: refactor * clients: remove WHOIS info for manually-added clients
This commit is contained in:
commit
318ed0dafb
|
@ -445,7 +445,15 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
|
|||
// A better approach is for proxy.Stop() to wait until all its workers exit,
|
||||
// but this would require the Upstream interface to have Close() function
|
||||
// (to prevent from hanging while waiting for unresponsive DNS server to respond).
|
||||
res, err := s.filterDNSRequest(d)
|
||||
|
||||
var setts *dnsfilter.RequestFilteringSettings
|
||||
var err error
|
||||
res := &dnsfilter.Result{}
|
||||
protectionEnabled := s.conf.ProtectionEnabled && s.dnsFilter != nil
|
||||
if protectionEnabled {
|
||||
setts = s.getClientRequestFilteringSettings(d)
|
||||
res, err = s.filterDNSRequest(d, setts)
|
||||
}
|
||||
s.RUnlock()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -486,9 +494,9 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
|
|||
d.Res.Answer = answer
|
||||
}
|
||||
|
||||
} else if res.Reason != dnsfilter.NotFilteredWhiteList {
|
||||
} else if res.Reason != dnsfilter.NotFilteredWhiteList && protectionEnabled {
|
||||
origResp2 := d.Res
|
||||
res, err = s.filterDNSResponse(d)
|
||||
res, err = s.filterDNSResponse(d, setts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -602,12 +610,7 @@ func (s *Server) getClientRequestFilteringSettings(d *proxy.DNSContext) *dnsfilt
|
|||
}
|
||||
|
||||
// filterDNSRequest applies the dnsFilter and sets d.Res if the request was filtered
|
||||
func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error) {
|
||||
if !s.conf.ProtectionEnabled || s.dnsFilter == nil {
|
||||
return &dnsfilter.Result{}, nil
|
||||
}
|
||||
|
||||
setts := s.getClientRequestFilteringSettings(d)
|
||||
func (s *Server) filterDNSRequest(d *proxy.DNSContext, setts *dnsfilter.RequestFilteringSettings) (*dnsfilter.Result, error) {
|
||||
req := d.Req
|
||||
host := strings.TrimSuffix(req.Question[0].Name, ".")
|
||||
res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, setts)
|
||||
|
@ -648,7 +651,7 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error
|
|||
|
||||
// If response contains CNAME, A or AAAA records, we apply filtering to each canonical host name or IP address.
|
||||
// If this is a match, we set a new response in d.Res and return.
|
||||
func (s *Server) filterDNSResponse(d *proxy.DNSContext) (*dnsfilter.Result, error) {
|
||||
func (s *Server) filterDNSResponse(d *proxy.DNSContext, setts *dnsfilter.RequestFilteringSettings) (*dnsfilter.Result, error) {
|
||||
for _, a := range d.Res.Answer {
|
||||
host := ""
|
||||
|
||||
|
@ -676,7 +679,6 @@ func (s *Server) filterDNSResponse(d *proxy.DNSContext) (*dnsfilter.Result, erro
|
|||
s.RUnlock()
|
||||
continue
|
||||
}
|
||||
setts := s.getClientRequestFilteringSettings(d)
|
||||
res, err := s.dnsFilter.CheckHostRules(host, d.Req.Question[0].Qtype, setts)
|
||||
s.RUnlock()
|
||||
|
||||
|
|
|
@ -340,6 +340,22 @@ var testIPv4 = map[string][]net.IP{
|
|||
"example.org.": {{127, 0, 0, 255}},
|
||||
}
|
||||
|
||||
func TestBlockCNAMEProtectionEnabled(t *testing.T) {
|
||||
s := createTestServer(t)
|
||||
testUpstm := &testUpstream{testCNAMEs, testIPv4, nil}
|
||||
s.conf.ProtectionEnabled = false
|
||||
err := s.startWithUpstream(testUpstm)
|
||||
assert.True(t, err == nil)
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
|
||||
// 'badhost' has a canonical name 'null.example.org' which is blocked by filters:
|
||||
// but protection is disabled - response is NOT blocked
|
||||
req := createTestMessage("badhost.")
|
||||
reply, err := dns.Exchange(req, addr.String())
|
||||
assert.True(t, err == nil)
|
||||
assert.True(t, reply.Rcode == dns.RcodeSuccess)
|
||||
}
|
||||
|
||||
func TestBlockCNAME(t *testing.T) {
|
||||
s := createTestServer(t)
|
||||
testUpstm := &testUpstream{testCNAMEs, testIPv4, nil}
|
||||
|
@ -349,35 +365,23 @@ func TestBlockCNAME(t *testing.T) {
|
|||
|
||||
// 'badhost' has a canonical name 'null.example.org' which is blocked by filters:
|
||||
// response is blocked
|
||||
req := dns.Msg{}
|
||||
req.Id = dns.Id()
|
||||
req.Question = []dns.Question{
|
||||
{Name: "badhost.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
||||
}
|
||||
reply, err := dns.Exchange(&req, addr.String())
|
||||
req := createTestMessage("badhost.")
|
||||
reply, err := dns.Exchange(req, addr.String())
|
||||
assert.True(t, err == nil)
|
||||
assert.True(t, reply.Rcode == dns.RcodeNameError)
|
||||
|
||||
// 'whitelist.example.org' has a canonical name 'null.example.org' which is blocked by filters
|
||||
// but 'whitelist.example.org' is in a whitelist:
|
||||
// response isn't blocked
|
||||
req = dns.Msg{}
|
||||
req.Id = dns.Id()
|
||||
req.Question = []dns.Question{
|
||||
{Name: "whitelist.example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
||||
}
|
||||
reply, err = dns.Exchange(&req, addr.String())
|
||||
req = createTestMessage("whitelist.example.org.")
|
||||
reply, err = dns.Exchange(req, addr.String())
|
||||
assert.True(t, err == nil)
|
||||
assert.True(t, reply.Rcode == dns.RcodeSuccess)
|
||||
|
||||
// 'example.org' has a canonical name 'cname1' with IP 127.0.0.255 which is blocked by filters:
|
||||
// response is blocked
|
||||
req = dns.Msg{}
|
||||
req.Id = dns.Id()
|
||||
req.Question = []dns.Question{
|
||||
{Name: "example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
||||
}
|
||||
reply, err = dns.Exchange(&req, addr.String())
|
||||
req = createTestMessage("example.org.")
|
||||
reply, err = dns.Exchange(req, addr.String())
|
||||
assert.True(t, err == nil)
|
||||
assert.True(t, reply.Rcode == dns.RcodeNameError)
|
||||
|
||||
|
|
|
@ -32,12 +32,16 @@ type Client struct {
|
|||
SafeSearchEnabled bool
|
||||
SafeBrowsingEnabled bool
|
||||
ParentalEnabled bool
|
||||
WhoisInfo [][]string // [[key,value], ...]
|
||||
|
||||
UseOwnBlockedServices bool // false: use global settings
|
||||
BlockedServices []string
|
||||
|
||||
Upstreams []string // list of upstream servers to be used for the client's requests
|
||||
// Upstream objects:
|
||||
// nil: not yet initialized
|
||||
// not nil, but empty: initialized, no good upstreams
|
||||
// not nil, not empty: Upstreams ready to be used
|
||||
upstreamObjects []upstream.Upstream
|
||||
}
|
||||
|
||||
type clientSource uint
|
||||
|
@ -63,12 +67,7 @@ type clientsContainer struct {
|
|||
list map[string]*Client // name -> client
|
||||
idIndex map[string]*Client // IP -> client
|
||||
ipHost map[string]*ClientHost // IP -> Hostname
|
||||
|
||||
// cache for Upstream instances that are used in the case
|
||||
// when custom DNS servers are configured for a client
|
||||
upstreamsCache map[string][]upstream.Upstream // name -> []Upstream
|
||||
|
||||
lock sync.Mutex
|
||||
lock sync.Mutex
|
||||
|
||||
// dhcpServer is used for looking up clients IP addresses by MAC addresses
|
||||
dhcpServer *dhcpd.Server
|
||||
|
@ -85,7 +84,6 @@ func (clients *clientsContainer) Init(objects []clientObject, dhcpServer *dhcpd.
|
|||
clients.list = make(map[string]*Client)
|
||||
clients.idIndex = make(map[string]*Client)
|
||||
clients.ipHost = make(map[string]*ClientHost)
|
||||
clients.upstreamsCache = make(map[string][]upstream.Upstream)
|
||||
clients.dhcpServer = dhcpServer
|
||||
clients.addFromConfig(objects)
|
||||
|
||||
|
@ -199,6 +197,12 @@ func (clients *clientsContainer) Find(ip string) (Client, bool) {
|
|||
return clients.findByIP(ip)
|
||||
}
|
||||
|
||||
func upstreamArrayCopy(a []upstream.Upstream) []upstream.Upstream {
|
||||
a2 := make([]upstream.Upstream, len(a))
|
||||
copy(a2, a)
|
||||
return a2
|
||||
}
|
||||
|
||||
// FindUpstreams looks for upstreams configured for the client
|
||||
// If no client found for this IP, or if no custom upstreams are configured,
|
||||
// this method returns nil
|
||||
|
@ -211,31 +215,22 @@ func (clients *clientsContainer) FindUpstreams(ip string) []upstream.Upstream {
|
|||
return nil
|
||||
}
|
||||
|
||||
if len(c.Upstreams) == 0 {
|
||||
if c.upstreamObjects == nil {
|
||||
c.upstreamObjects = make([]upstream.Upstream, 0)
|
||||
for _, us := range c.Upstreams {
|
||||
u, err := upstream.AddressToUpstream(us, upstream.Options{Timeout: dnsforward.DefaultTimeout})
|
||||
if err != nil {
|
||||
log.Error("upstream.AddressToUpstream: %s: %s", us, err)
|
||||
continue
|
||||
}
|
||||
c.upstreamObjects = append(c.upstreamObjects, u)
|
||||
}
|
||||
}
|
||||
|
||||
if len(c.upstreamObjects) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
upstreams, ok := clients.upstreamsCache[c.Name]
|
||||
if ok {
|
||||
return upstreams
|
||||
}
|
||||
|
||||
for _, us := range c.Upstreams {
|
||||
u, err := upstream.AddressToUpstream(us, upstream.Options{Timeout: dnsforward.DefaultTimeout})
|
||||
if err != nil {
|
||||
log.Error("upstream.AddressToUpstream: %s: %s", us, err)
|
||||
continue
|
||||
}
|
||||
upstreams = append(upstreams, u)
|
||||
}
|
||||
|
||||
if len(upstreams) == 0 {
|
||||
clients.upstreamsCache[c.Name] = nil
|
||||
} else {
|
||||
clients.upstreamsCache[c.Name] = upstreams
|
||||
}
|
||||
|
||||
return upstreams
|
||||
return upstreamArrayCopy(c.upstreamObjects)
|
||||
}
|
||||
|
||||
// Find searches for a client by IP (and does not lock anything)
|
||||
|
@ -366,17 +361,6 @@ func (clients *clientsContainer) Add(c Client) (bool, error) {
|
|||
}
|
||||
}
|
||||
|
||||
// remove auto-clients with the same IP address, keeping WHOIS info if possible
|
||||
for _, id := range c.IDs {
|
||||
ch, ok := clients.ipHost[id]
|
||||
if ok {
|
||||
if len(c.WhoisInfo) == 0 {
|
||||
c.WhoisInfo = ch.WhoisInfo
|
||||
}
|
||||
delete(clients.ipHost, id)
|
||||
}
|
||||
}
|
||||
|
||||
// update Name index
|
||||
clients.list[c.Name] = &c
|
||||
|
||||
|
@ -402,9 +386,6 @@ func (clients *clientsContainer) Del(name string) bool {
|
|||
// update Name index
|
||||
delete(clients.list, name)
|
||||
|
||||
// update upstreams cache
|
||||
delete(clients.upstreamsCache, name)
|
||||
|
||||
// update ID index
|
||||
for _, id := range c.IDs {
|
||||
delete(clients.idIndex, id)
|
||||
|
@ -473,11 +454,7 @@ func (clients *clientsContainer) Update(name string, c Client) error {
|
|||
}
|
||||
|
||||
// update upstreams cache
|
||||
if old.Name != c.Name {
|
||||
delete(clients.upstreamsCache, old.Name)
|
||||
} else {
|
||||
delete(clients.upstreamsCache, c.Name)
|
||||
}
|
||||
c.upstreamObjects = nil
|
||||
|
||||
*old = c
|
||||
return nil
|
||||
|
@ -517,12 +494,6 @@ func (clients *clientsContainer) AddHost(ip, host string, source clientSource) (
|
|||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
// check existing clients first
|
||||
_, ok := clients.findByIP(ip)
|
||||
if ok {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// check auto-clients index
|
||||
ch, ok := clients.ipHost[ip]
|
||||
if ok && ch.Source > source {
|
||||
|
|
|
@ -16,8 +16,6 @@ type clientJSON struct {
|
|||
SafeSearchEnabled bool `json:"safesearch_enabled"`
|
||||
SafeBrowsingEnabled bool `json:"safebrowsing_enabled"`
|
||||
|
||||
WhoisInfo map[string]interface{} `json:"whois_info"`
|
||||
|
||||
UseGlobalBlockedServices bool `json:"use_global_blocked_services"`
|
||||
BlockedServices []string `json:"blocked_services"`
|
||||
|
||||
|
@ -116,11 +114,6 @@ func clientToJSON(c *Client) clientJSON {
|
|||
|
||||
Upstreams: c.Upstreams,
|
||||
}
|
||||
|
||||
cj.WhoisInfo = make(map[string]interface{})
|
||||
for _, wi := range c.WhoisInfo {
|
||||
cj.WhoisInfo[wi[0]] = wi[1]
|
||||
}
|
||||
return cj
|
||||
}
|
||||
|
||||
|
|
|
@ -168,18 +168,18 @@ func TestClientsWhois(t *testing.T) {
|
|||
clients.SetWhoisInfo("1.1.1.1", whois)
|
||||
assert.True(t, clients.ipHost["1.1.1.1"].WhoisInfo[0][1] == "orgname-val")
|
||||
|
||||
// Check that we cannot set whois info on existing client
|
||||
// Check that we cannot set whois info on a manually-added client
|
||||
c = Client{
|
||||
IDs: []string{"1.1.1.2"},
|
||||
Name: "client1",
|
||||
}
|
||||
_, _ = clients.Add(c)
|
||||
clients.SetWhoisInfo("1.1.1.2", whois)
|
||||
assert.Nil(t, clients.idIndex["1.1.1.2"].WhoisInfo)
|
||||
assert.True(t, clients.ipHost["1.1.1.2"] == nil)
|
||||
_ = clients.Del("client1")
|
||||
}
|
||||
|
||||
func TestClientsAddExistingHost(t *testing.T) {
|
||||
func TestClientsAddExisting(t *testing.T) {
|
||||
var c Client
|
||||
clients := clientsContainer{}
|
||||
clients.testing = true
|
||||
|
@ -198,9 +198,9 @@ func TestClientsAddExistingHost(t *testing.T) {
|
|||
assert.True(t, ok)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// try adding a duplicate by IP
|
||||
// add an auto-client with the same IP - it's allowed
|
||||
ok, err = clients.AddHost("1.1.1.1", "test", ClientSourceRDNS)
|
||||
assert.False(t, ok)
|
||||
assert.True(t, ok)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// now some more complicated stuff
|
||||
|
@ -218,13 +218,21 @@ func TestClientsAddExistingHost(t *testing.T) {
|
|||
})
|
||||
assert.Nil(t, err)
|
||||
|
||||
// try adding a duplicate IP which for a Mac-based client
|
||||
ok, err = clients.AddHost(testIP, "test", ClientSourceRDNS)
|
||||
assert.False(t, ok)
|
||||
// add a new client with the same IP as for a client with MAC
|
||||
c = Client{
|
||||
IDs: []string{testIP},
|
||||
Name: "client2",
|
||||
}
|
||||
ok, err = clients.Add(c)
|
||||
assert.True(t, ok)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// don't allow duplicates by CIDR
|
||||
ok, err = clients.AddHost("2.2.2.2", "test", ClientSourceRDNS)
|
||||
assert.False(t, ok)
|
||||
// add a new client with the IP from the client1's IP range
|
||||
c = Client{
|
||||
IDs: []string{"2.2.2.2"},
|
||||
Name: "client3",
|
||||
}
|
||||
ok, err = clients.Add(c)
|
||||
assert.True(t, ok)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue