diff --git a/dnsfilter/dnsfilter_test.go b/dnsfilter/dnsfilter_test.go index 532e1c3c..6515b53d 100644 --- a/dnsfilter/dnsfilter_test.go +++ b/dnsfilter/dnsfilter_test.go @@ -532,96 +532,6 @@ func TestClientSettings(t *testing.T) { assert.True(t, r.IsFiltered && r.Reason == FilteredBlockedService) } -func TestRewrites(t *testing.T) { - d := Dnsfilter{} - // CNAME, A, AAAA - d.Rewrites = []RewriteEntry{ - RewriteEntry{"somecname", "somehost.com", 0, nil}, - RewriteEntry{"somehost.com", "0.0.0.0", 0, nil}, - - RewriteEntry{"host.com", "1.2.3.4", 0, nil}, - RewriteEntry{"host.com", "1.2.3.5", 0, nil}, - RewriteEntry{"host.com", "1:2:3::4", 0, nil}, - RewriteEntry{"www.host.com", "host.com", 0, nil}, - } - d.prepareRewrites() - r := d.processRewrites("host2.com") - assert.Equal(t, NotFilteredNotFound, r.Reason) - - r = d.processRewrites("www.host.com") - assert.Equal(t, ReasonRewrite, r.Reason) - assert.Equal(t, "host.com", r.CanonName) - assert.True(t, len(r.IPList) == 3) - assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) - assert.True(t, r.IPList[1].Equal(net.ParseIP("1.2.3.5"))) - assert.True(t, r.IPList[2].Equal(net.ParseIP("1:2:3::4"))) - - // wildcard - d.Rewrites = []RewriteEntry{ - RewriteEntry{"host.com", "1.2.3.4", 0, nil}, - RewriteEntry{"*.host.com", "1.2.3.5", 0, nil}, - } - d.prepareRewrites() - r = d.processRewrites("host.com") - assert.Equal(t, ReasonRewrite, r.Reason) - assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) - - r = d.processRewrites("www.host.com") - assert.Equal(t, ReasonRewrite, r.Reason) - assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.5"))) - - r = d.processRewrites("www.host2.com") - assert.Equal(t, NotFilteredNotFound, r.Reason) - - // override a wildcard - d.Rewrites = []RewriteEntry{ - RewriteEntry{"a.host.com", "1.2.3.4", 0, nil}, - RewriteEntry{"*.host.com", "1.2.3.5", 0, nil}, - } - d.prepareRewrites() - r = d.processRewrites("a.host.com") - assert.Equal(t, ReasonRewrite, r.Reason) - assert.True(t, len(r.IPList) == 1) - assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) - - // wildcard + CNAME - d.Rewrites = []RewriteEntry{ - RewriteEntry{"host.com", "1.2.3.4", 0, nil}, - RewriteEntry{"*.host.com", "host.com", 0, nil}, - } - d.prepareRewrites() - r = d.processRewrites("www.host.com") - assert.Equal(t, ReasonRewrite, r.Reason) - assert.Equal(t, "host.com", r.CanonName) - assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) - - // 2 CNAMEs - d.Rewrites = []RewriteEntry{ - RewriteEntry{"b.host.com", "a.host.com", 0, nil}, - RewriteEntry{"a.host.com", "host.com", 0, nil}, - RewriteEntry{"host.com", "1.2.3.4", 0, nil}, - } - d.prepareRewrites() - r = d.processRewrites("b.host.com") - assert.Equal(t, ReasonRewrite, r.Reason) - assert.Equal(t, "host.com", r.CanonName) - assert.True(t, len(r.IPList) == 1) - assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) - - // 2 CNAMEs + wildcard - d.Rewrites = []RewriteEntry{ - RewriteEntry{"b.host.com", "a.host.com", 0, nil}, - RewriteEntry{"a.host.com", "x.somehost.com", 0, nil}, - RewriteEntry{"*.somehost.com", "1.2.3.4", 0, nil}, - } - d.prepareRewrites() - r = d.processRewrites("b.host.com") - assert.Equal(t, ReasonRewrite, r.Reason) - assert.Equal(t, "x.somehost.com", r.CanonName) - assert.True(t, len(r.IPList) == 1) - assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) -} - func prepareTestDir() string { const dir = "./agh-test" _ = os.RemoveAll(dir) diff --git a/dnsfilter/rewrites.go b/dnsfilter/rewrites.go index 029d3332..cbc02a16 100644 --- a/dnsfilter/rewrites.go +++ b/dnsfilter/rewrites.go @@ -42,7 +42,10 @@ func (a rewritesArray) Len() int { return len(a) } func (a rewritesArray) Swap(i, j int) { a[i], a[j] = a[j], a[i] } -// Priority: CNAME, A/AAAA; exact, wildcard. +// Priority: +// . CNAME > A/AAAA; +// . exact > wildcard; +// . higher level wildcard > lower level wildcard func (a rewritesArray) Less(i, j int) bool { if a[i].Type == dns.TypeCNAME && a[j].Type != dns.TypeCNAME { return false @@ -50,13 +53,18 @@ func (a rewritesArray) Less(i, j int) bool { return true } - if isWildcard(a[i].Domain) && !isWildcard(a[j].Domain) { - return false - } else if !isWildcard(a[i].Domain) && isWildcard(a[j].Domain) { - return true + if isWildcard(a[i].Domain) { + if !isWildcard(a[j].Domain) { + return false + } + } else { + if isWildcard(a[j].Domain) { + return true + } } - return i < j + // both are wildcards + return len(a[i].Domain) > len(a[j].Domain) } // Prepare entry for use @@ -86,6 +94,7 @@ func (d *Dnsfilter) prepareRewrites() { // Get the list of matched rewrite entries. // Priority: CNAME, A/AAAA; exact, wildcard. // If matched exactly, don't return wildcard entries. +// If matched by several wildcards, select the more specific one func findRewrites(a []RewriteEntry, host string) []RewriteEntry { rr := rewritesArray{} for _, r := range a { @@ -111,7 +120,10 @@ func findRewrites(a []RewriteEntry, host string) []RewriteEntry { break } } + } else { + rr = rr[:1] } + return rr } diff --git a/dnsfilter/rewrites_test.go b/dnsfilter/rewrites_test.go new file mode 100644 index 00000000..6da3e0f9 --- /dev/null +++ b/dnsfilter/rewrites_test.go @@ -0,0 +1,127 @@ +package dnsfilter + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRewrites(t *testing.T) { + d := Dnsfilter{} + // CNAME, A, AAAA + d.Rewrites = []RewriteEntry{ + RewriteEntry{"somecname", "somehost.com", 0, nil}, + RewriteEntry{"somehost.com", "0.0.0.0", 0, nil}, + + RewriteEntry{"host.com", "1.2.3.4", 0, nil}, + RewriteEntry{"host.com", "1.2.3.5", 0, nil}, + RewriteEntry{"host.com", "1:2:3::4", 0, nil}, + RewriteEntry{"www.host.com", "host.com", 0, nil}, + } + d.prepareRewrites() + r := d.processRewrites("host2.com") + assert.Equal(t, NotFilteredNotFound, r.Reason) + + r = d.processRewrites("www.host.com") + assert.Equal(t, ReasonRewrite, r.Reason) + assert.Equal(t, "host.com", r.CanonName) + assert.True(t, len(r.IPList) == 3) + assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) + assert.True(t, r.IPList[1].Equal(net.ParseIP("1.2.3.5"))) + assert.True(t, r.IPList[2].Equal(net.ParseIP("1:2:3::4"))) + + // wildcard + d.Rewrites = []RewriteEntry{ + RewriteEntry{"host.com", "1.2.3.4", 0, nil}, + RewriteEntry{"*.host.com", "1.2.3.5", 0, nil}, + } + d.prepareRewrites() + r = d.processRewrites("host.com") + assert.Equal(t, ReasonRewrite, r.Reason) + assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) + + r = d.processRewrites("www.host.com") + assert.Equal(t, ReasonRewrite, r.Reason) + assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.5"))) + + r = d.processRewrites("www.host2.com") + assert.Equal(t, NotFilteredNotFound, r.Reason) + + // override a wildcard + d.Rewrites = []RewriteEntry{ + RewriteEntry{"a.host.com", "1.2.3.4", 0, nil}, + RewriteEntry{"*.host.com", "1.2.3.5", 0, nil}, + } + d.prepareRewrites() + r = d.processRewrites("a.host.com") + assert.Equal(t, ReasonRewrite, r.Reason) + assert.True(t, len(r.IPList) == 1) + assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) + + // wildcard + CNAME + d.Rewrites = []RewriteEntry{ + RewriteEntry{"host.com", "1.2.3.4", 0, nil}, + RewriteEntry{"*.host.com", "host.com", 0, nil}, + } + d.prepareRewrites() + r = d.processRewrites("www.host.com") + assert.Equal(t, ReasonRewrite, r.Reason) + assert.Equal(t, "host.com", r.CanonName) + assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) + + // 2 CNAMEs + d.Rewrites = []RewriteEntry{ + RewriteEntry{"b.host.com", "a.host.com", 0, nil}, + RewriteEntry{"a.host.com", "host.com", 0, nil}, + RewriteEntry{"host.com", "1.2.3.4", 0, nil}, + } + d.prepareRewrites() + r = d.processRewrites("b.host.com") + assert.Equal(t, ReasonRewrite, r.Reason) + assert.Equal(t, "host.com", r.CanonName) + assert.True(t, len(r.IPList) == 1) + assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) + + // 2 CNAMEs + wildcard + d.Rewrites = []RewriteEntry{ + RewriteEntry{"b.host.com", "a.host.com", 0, nil}, + RewriteEntry{"a.host.com", "x.somehost.com", 0, nil}, + RewriteEntry{"*.somehost.com", "1.2.3.4", 0, nil}, + } + d.prepareRewrites() + r = d.processRewrites("b.host.com") + assert.Equal(t, ReasonRewrite, r.Reason) + assert.Equal(t, "x.somehost.com", r.CanonName) + assert.True(t, len(r.IPList) == 1) + assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) +} + +func TestRewritesLevels(t *testing.T) { + d := Dnsfilter{} + // exact host, wildcard L2, wildcard L3 + d.Rewrites = []RewriteEntry{ + RewriteEntry{"host.com", "1.1.1.1", 0, nil}, + RewriteEntry{"*.host.com", "2.2.2.2", 0, nil}, + RewriteEntry{"*.sub.host.com", "3.3.3.3", 0, nil}, + } + d.prepareRewrites() + + // match exact + r := d.processRewrites("host.com") + assert.Equal(t, ReasonRewrite, r.Reason) + assert.Equal(t, 1, len(r.IPList)) + assert.Equal(t, "1.1.1.1", r.IPList[0].String()) + + // match L2 + r = d.processRewrites("sub.host.com") + assert.Equal(t, ReasonRewrite, r.Reason) + assert.Equal(t, 1, len(r.IPList)) + assert.Equal(t, "2.2.2.2", r.IPList[0].String()) + + // match L3 + r = d.processRewrites("my.sub.host.com") + assert.Equal(t, ReasonRewrite, r.Reason) + assert.Equal(t, 1, len(r.IPList)) + assert.Equal(t, "3.3.3.3", r.IPList[0].String()) +}