diff --git a/AGHTechDoc.md b/AGHTechDoc.md index 17e94307..3e3574e2 100644 --- a/AGHTechDoc.md +++ b/AGHTechDoc.md @@ -986,6 +986,110 @@ Response: This section allows the administrator to easily configure custom DNS response for a specific domain name. A, AAAA and CNAME records are supported. +Syntax: + + key -> value + +where `key` is a host name or a wild card that matches Question in DNS request +and `value` is either: +* IPv4 address: use this IP in A response +* IPv6 address: use this IP in AAAA response +* canonical name: add CNAME record +* "": CNAME exception - pass request to upstream +* "A": A exception - pass A request to upstream +* "AAAA": AAAA exception - pass AAAA request to upstream + + +#### Example: A record + + host.com -> 1.2.3.4 + +Response: + + A: + A = 1.2.3.4 + AAAA: + + +#### Example: AAAA record + + host.com -> ::1 + +Response: + + A: + + AAAA: + AAAA = ::1 + +#### Example: CNAME record + + sub.host.com -> host.com + +Response: + + A: + CNAME = host.com + A = + AAAA: + CNAME = host.com + AAAA = + +#### Example: CNAME+A records + + sub.host.com -> host.com + host.com -> 1.2.3.4 + +Response: + + A: + CNAME = host.com + A = 1.2.3.4 + AAAA: + CNAME = host.com + +#### Example: Wildcard CNAME+A record with CNAME exception + + *.host.com -> 1.2.3.4 + pass.host.com -> pass.host.com + +Response to `my.host.com`: + + A: + A = 1.2.3.4 + AAAA: + + +Response to `pass.host.com`: + + A: + A = + AAAA: + AAAA = + +#### Example: A record with AAAA exception + + host.com -> 1.2.3.4 + host.com -> AAAA + +Response: + + A: + A = 1.2.3.4 + AAAA: + AAAA = + +#### Example: pass A only + + host.com -> A + +Response: + + A: + A = + AAAA: + + ### API: List rewrite entries diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index 1014c462..71c73b1f 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -316,7 +316,7 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, setts *RequestFiltering var result Result var err error - result = d.processRewrites(host) + result = d.processRewrites(host, qtype) if result.Reason == ReasonRewrite { return result, nil } @@ -398,8 +398,8 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, setts *RequestFiltering // . if found, set domain name to canonical name // . repeat for the new domain name (Note: we return only the last CNAME) // . Find A or AAAA record for a domain name (exact match or by wildcard) -// . if found, return IP addresses (both IPv4 and IPv6) -func (d *Dnsfilter) processRewrites(host string) Result { +// . if found, set IP addresses (IPv4 or IPv6 depending on qtype) in Result.IPList array +func (d *Dnsfilter) processRewrites(host string, qtype uint16) Result { var res Result d.confLock.RLock() @@ -432,7 +432,14 @@ func (d *Dnsfilter) processRewrites(host string) Result { } for _, r := range rr { - if r.Type != dns.TypeCNAME { + if (r.Type == dns.TypeA && qtype == dns.TypeA) || + (r.Type == dns.TypeAAAA && qtype == dns.TypeAAAA) { + + if r.IP == nil { // IP exception + res.Reason = 0 + return res + } + res.IPList = append(res.IPList, r.IP) log.Debug("Rewrite: A/AAAA for %s is %s", host, r.IP) } diff --git a/dnsfilter/rewrites.go b/dnsfilter/rewrites.go index 166ff0bb..9c042228 100644 --- a/dnsfilter/rewrites.go +++ b/dnsfilter/rewrites.go @@ -69,6 +69,16 @@ func (a rewritesArray) Less(i, j int) bool { // Prepare entry for use func (r *RewriteEntry) prepare() { + if r.Answer == "AAAA" { + r.IP = nil + r.Type = dns.TypeAAAA + return + } else if r.Answer == "A" { + r.IP = nil + r.Type = dns.TypeA + return + } + ip := net.ParseIP(r.Answer) if ip == nil { r.Type = dns.TypeCNAME diff --git a/dnsfilter/rewrites_test.go b/dnsfilter/rewrites_test.go index aac9cedd..2ed8210e 100644 --- a/dnsfilter/rewrites_test.go +++ b/dnsfilter/rewrites_test.go @@ -4,6 +4,7 @@ import ( "net" "testing" + "github.com/miekg/dns" "github.com/stretchr/testify/assert" ) @@ -20,16 +21,21 @@ func TestRewrites(t *testing.T) { RewriteEntry{"www.host.com", "host.com", 0, nil}, } d.prepareRewrites() - r := d.processRewrites("host2.com") + r := d.processRewrites("host2.com", dns.TypeA) assert.Equal(t, NotFilteredNotFound, r.Reason) - r = d.processRewrites("www.host.com") + r = d.processRewrites("www.host.com", dns.TypeA) assert.Equal(t, ReasonRewrite, r.Reason) assert.Equal(t, "host.com", r.CanonName) - assert.True(t, len(r.IPList) == 3) + assert.Equal(t, 2, len(r.IPList)) assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) assert.True(t, r.IPList[1].Equal(net.ParseIP("1.2.3.5"))) - assert.True(t, r.IPList[2].Equal(net.ParseIP("1:2:3::4"))) + + r = d.processRewrites("www.host.com", dns.TypeAAAA) + assert.Equal(t, ReasonRewrite, r.Reason) + assert.Equal(t, "host.com", r.CanonName) + assert.Equal(t, 1, len(r.IPList)) + assert.True(t, r.IPList[0].Equal(net.ParseIP("1:2:3::4"))) // wildcard d.Rewrites = []RewriteEntry{ @@ -37,15 +43,15 @@ func TestRewrites(t *testing.T) { RewriteEntry{"*.host.com", "1.2.3.5", 0, nil}, } d.prepareRewrites() - r = d.processRewrites("host.com") + r = d.processRewrites("host.com", dns.TypeA) assert.Equal(t, ReasonRewrite, r.Reason) assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) - r = d.processRewrites("www.host.com") + r = d.processRewrites("www.host.com", dns.TypeA) assert.Equal(t, ReasonRewrite, r.Reason) assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.5"))) - r = d.processRewrites("www.host2.com") + r = d.processRewrites("www.host2.com", dns.TypeA) assert.Equal(t, NotFilteredNotFound, r.Reason) // override a wildcard @@ -54,7 +60,7 @@ func TestRewrites(t *testing.T) { RewriteEntry{"*.host.com", "1.2.3.5", 0, nil}, } d.prepareRewrites() - r = d.processRewrites("a.host.com") + r = d.processRewrites("a.host.com", dns.TypeA) assert.Equal(t, ReasonRewrite, r.Reason) assert.True(t, len(r.IPList) == 1) assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) @@ -65,7 +71,7 @@ func TestRewrites(t *testing.T) { RewriteEntry{"*.host.com", "host.com", 0, nil}, } d.prepareRewrites() - r = d.processRewrites("www.host.com") + r = d.processRewrites("www.host.com", dns.TypeA) assert.Equal(t, ReasonRewrite, r.Reason) assert.Equal(t, "host.com", r.CanonName) assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) @@ -77,7 +83,7 @@ func TestRewrites(t *testing.T) { RewriteEntry{"host.com", "1.2.3.4", 0, nil}, } d.prepareRewrites() - r = d.processRewrites("b.host.com") + r = d.processRewrites("b.host.com", dns.TypeA) assert.Equal(t, ReasonRewrite, r.Reason) assert.Equal(t, "host.com", r.CanonName) assert.True(t, len(r.IPList) == 1) @@ -90,7 +96,7 @@ func TestRewrites(t *testing.T) { RewriteEntry{"*.somehost.com", "1.2.3.4", 0, nil}, } d.prepareRewrites() - r = d.processRewrites("b.host.com") + r = d.processRewrites("b.host.com", dns.TypeA) assert.Equal(t, ReasonRewrite, r.Reason) assert.Equal(t, "x.somehost.com", r.CanonName) assert.True(t, len(r.IPList) == 1) @@ -108,25 +114,25 @@ func TestRewritesLevels(t *testing.T) { d.prepareRewrites() // match exact - r := d.processRewrites("host.com") + r := d.processRewrites("host.com", dns.TypeA) assert.Equal(t, ReasonRewrite, r.Reason) assert.Equal(t, 1, len(r.IPList)) assert.Equal(t, "1.1.1.1", r.IPList[0].String()) // match L2 - r = d.processRewrites("sub.host.com") + r = d.processRewrites("sub.host.com", dns.TypeA) assert.Equal(t, ReasonRewrite, r.Reason) assert.Equal(t, 1, len(r.IPList)) assert.Equal(t, "2.2.2.2", r.IPList[0].String()) // match L3 - r = d.processRewrites("my.sub.host.com") + r = d.processRewrites("my.sub.host.com", dns.TypeA) assert.Equal(t, ReasonRewrite, r.Reason) assert.Equal(t, 1, len(r.IPList)) assert.Equal(t, "3.3.3.3", r.IPList[0].String()) } -func TestRewritesException(t *testing.T) { +func TestRewritesExceptionCNAME(t *testing.T) { d := Dnsfilter{} // wildcard; exception for a sub-domain d.Rewrites = []RewriteEntry{ @@ -136,13 +142,13 @@ func TestRewritesException(t *testing.T) { d.prepareRewrites() // match sub-domain - r := d.processRewrites("my.host.com") + r := d.processRewrites("my.host.com", dns.TypeA) assert.Equal(t, ReasonRewrite, r.Reason) assert.Equal(t, 1, len(r.IPList)) assert.Equal(t, "2.2.2.2", r.IPList[0].String()) // match sub-domain, but handle exception - r = d.processRewrites("sub.host.com") + r = d.processRewrites("sub.host.com", dns.TypeA) assert.Equal(t, NotFilteredNotFound, r.Reason) } @@ -156,12 +162,54 @@ func TestRewritesExceptionWC(t *testing.T) { d.prepareRewrites() // match sub-domain - r := d.processRewrites("my.host.com") + r := d.processRewrites("my.host.com", dns.TypeA) assert.Equal(t, ReasonRewrite, r.Reason) assert.Equal(t, 1, len(r.IPList)) assert.Equal(t, "2.2.2.2", r.IPList[0].String()) // match sub-domain, but handle exception - r = d.processRewrites("my.sub.host.com") + r = d.processRewrites("my.sub.host.com", dns.TypeA) assert.Equal(t, NotFilteredNotFound, r.Reason) } + +func TestRewritesExceptionIP(t *testing.T) { + d := Dnsfilter{} + // exception for AAAA record + d.Rewrites = []RewriteEntry{ + RewriteEntry{"host.com", "1.2.3.4", 0, nil}, + RewriteEntry{"host.com", "AAAA", 0, nil}, + RewriteEntry{"host2.com", "::1", 0, nil}, + RewriteEntry{"host2.com", "A", 0, nil}, + RewriteEntry{"host3.com", "A", 0, nil}, + } + d.prepareRewrites() + + // match domain + r := d.processRewrites("host.com", dns.TypeA) + assert.Equal(t, ReasonRewrite, r.Reason) + assert.Equal(t, 1, len(r.IPList)) + assert.Equal(t, "1.2.3.4", r.IPList[0].String()) + + // match exception + r = d.processRewrites("host.com", dns.TypeAAAA) + assert.Equal(t, NotFilteredNotFound, r.Reason) + + // match exception + r = d.processRewrites("host2.com", dns.TypeA) + assert.Equal(t, NotFilteredNotFound, r.Reason) + + // match domain + r = d.processRewrites("host2.com", dns.TypeAAAA) + assert.Equal(t, ReasonRewrite, r.Reason) + assert.Equal(t, 1, len(r.IPList)) + assert.Equal(t, "::1", r.IPList[0].String()) + + // match exception + r = d.processRewrites("host3.com", dns.TypeA) + assert.Equal(t, NotFilteredNotFound, r.Reason) + + // match domain + r = d.processRewrites("host3.com", dns.TypeAAAA) + assert.Equal(t, ReasonRewrite, r.Reason) + assert.Equal(t, 0, len(r.IPList)) +} diff --git a/dnsforward/filter.go b/dnsforward/filter.go index 08c81c40..1591d74f 100644 --- a/dnsforward/filter.go +++ b/dnsforward/filter.go @@ -65,12 +65,11 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) { } for _, ip := range res.IPList { - ip4 := ip.To4() - if req.Question[0].Qtype == dns.TypeA && ip4 != nil { - a := s.genAAnswer(req, ip4) + if req.Question[0].Qtype == dns.TypeA { + a := s.genAAnswer(req, ip.To4()) a.Hdr.Name = dns.Fqdn(name) resp.Answer = append(resp.Answer, a) - } else if req.Question[0].Qtype == dns.TypeAAAA && ip4 == nil { + } else if req.Question[0].Qtype == dns.TypeAAAA { a := s.genAAAAAnswer(req, ip) a.Hdr.Name = dns.Fqdn(name) resp.Answer = append(resp.Answer, a)