Merge: + dns rewrites: support wildcard override

Close #922

Squashed commit of the following:

commit 8ab742d84916a02043989fcfa5fc258e84046205
Merge: 2abde92f e398117d
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Thu Jan 16 12:47:44 2020 +0300

    Merge remote-tracking branch 'origin/master' into 922-rewrites

commit 2abde92fe2d1bdb153787b4ecac60f9744c7fd1a
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Thu Jan 16 12:34:41 2020 +0300

    minor

commit 3c20579cde85967786830d2d1b87cd987fc2ae12
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Fri Jan 10 19:06:53 2020 +0300

    don't change the order of rewrite entry list (apply priority at runtime)

commit 4f658f98011a7e31852c0ce379f2a02738f5614f
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Tue Dec 31 15:21:50 2019 +0300

    minor

commit 9e56b4f3abefe3ff1d19bc61375f31dc55242e52
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Tue Dec 31 15:14:14 2019 +0300

    * remove unused code

commit 4178f025a388eb768914306efa91ffead87e5a0c
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Tue Dec 31 15:13:57 2019 +0300

    + dns rewrites: support wildcard override
This commit is contained in:
Simon Zolin 2020-01-16 12:51:35 +03:00
parent e398117d25
commit a6d7511806
4 changed files with 195 additions and 136 deletions

View File

@ -34,15 +34,8 @@ type RequestFilteringSettings struct {
ServicesRules []ServiceEntry ServicesRules []ServiceEntry
} }
// RewriteEntry is a rewrite array element
type RewriteEntry struct {
Domain string `yaml:"domain"`
Answer string `yaml:"answer"` // IP address or canonical name
}
// Config allows you to configure DNS filtering with New() or just change variables directly. // Config allows you to configure DNS filtering with New() or just change variables directly.
type Config struct { type Config struct {
ParentalSensitivity int `yaml:"parental_sensitivity"` // must be either 3, 10, 13 or 17
ParentalEnabled bool `yaml:"parental_enabled"` ParentalEnabled bool `yaml:"parental_enabled"`
SafeSearchEnabled bool `yaml:"safesearch_enabled"` SafeSearchEnabled bool `yaml:"safesearch_enabled"`
SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"` SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"`
@ -175,7 +168,10 @@ func (d *Dnsfilter) GetConfig() RequestFilteringSettings {
// WriteDiskConfig - write configuration // WriteDiskConfig - write configuration
func (d *Dnsfilter) WriteDiskConfig(c *Config) { func (d *Dnsfilter) WriteDiskConfig(c *Config) {
d.confLock.Lock()
*c = d.Config *c = d.Config
c.Rewrites = rewriteArrayDup(d.Config.Rewrites)
d.confLock.Unlock()
} }
// SetFilters - set new filters (synchronously or asynchronously) // SetFilters - set new filters (synchronously or asynchronously)
@ -343,17 +339,11 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, setts *RequestFiltering
return Result{}, nil return Result{}, nil
} }
// Return TRUE of host name matches a wildcard pattern
func matchDomainWildcard(host, wildcard string) bool {
return len(wildcard) >= 2 &&
wildcard[0] == '*' && wildcard[1] == '.' &&
strings.HasSuffix(host, wildcard[1:])
}
// Process rewrites table // Process rewrites table
// . Find CNAME for a domain name // . Find CNAME for a domain name (exact match or by wildcard)
// . if found, set domain name to canonical name // . if found, set domain name to canonical name
// . Find A or AAAA record for a domain name // . repeat for the new domain name (Note: we return only the last CNAME)
// . Find A or AAAA record for a domain name (exact match or by wildcard)
// . if found, return IP addresses // . if found, return IP addresses
func (d *Dnsfilter) processRewrites(host string, qtype uint16) Result { func (d *Dnsfilter) processRewrites(host string, qtype uint16) Result {
var res Result var res Result
@ -361,48 +351,31 @@ func (d *Dnsfilter) processRewrites(host string, qtype uint16) Result {
d.confLock.RLock() d.confLock.RLock()
defer d.confLock.RUnlock() defer d.confLock.RUnlock()
for _, r := range d.Rewrites { rr := findRewrites(d.Rewrites, host)
if r.Domain != host { if len(rr) != 0 {
if !matchDomainWildcard(host, r.Domain) {
continue
}
}
ip := net.ParseIP(r.Answer)
if ip == nil {
log.Debug("Rewrite: CNAME for %s is %s", host, r.Answer)
host = r.Answer
res.CanonName = r.Answer
res.Reason = ReasonRewrite res.Reason = ReasonRewrite
break
}
} }
for _, r := range d.Rewrites { cnames := map[string]bool{}
if r.Domain != host { origHost := host
if !matchDomainWildcard(host, r.Domain) { for len(rr) != 0 && rr[0].Type == dns.TypeCNAME {
continue log.Debug("Rewrite: CNAME for %s is %s", host, rr[0].Answer)
host = rr[0].Answer
_, ok := cnames[host]
if ok {
log.Info("Rewrite: breaking CNAME redirection loop: %s. Question: %s", host, origHost)
return res
} }
cnames[host] = false
res.CanonName = rr[0].Answer
rr = findRewrites(d.Rewrites, host)
} }
ip := net.ParseIP(r.Answer) for _, r := range rr {
if ip == nil { if r.Type != dns.TypeCNAME && r.Type == qtype {
continue res.IPList = append(res.IPList, r.IP)
log.Debug("Rewrite: A/AAAA for %s is %s", host, r.IP)
} }
ip4 := ip.To4()
if qtype == dns.TypeA && ip4 != nil {
res.IPList = append(res.IPList, ip4)
log.Debug("Rewrite: A for %s is %s", host, ip4)
} else if qtype == dns.TypeAAAA && ip4 == nil {
res.IPList = append(res.IPList, ip)
log.Debug("Rewrite: AAAA for %s is %s", host, ip)
}
}
if len(res.IPList) != 0 {
res.Reason = ReasonRewrite
} }
return res return res
@ -592,6 +565,7 @@ func New(c *Config, filters map[int]string) *Dnsfilter {
if c != nil { if c != nil {
d.Config = *c d.Config = *c
d.prepareRewrites()
} }
if filters != nil { if filters != nil {

View File

@ -317,7 +317,6 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
func TestParentalControl(t *testing.T) { func TestParentalControl(t *testing.T) {
d := NewForTest(&Config{ParentalEnabled: true}, nil) d := NewForTest(&Config{ParentalEnabled: true}, nil)
defer d.Close() defer d.Close()
d.ParentalSensitivity = 3
d.checkMatch(t, "pornhub.com") d.checkMatch(t, "pornhub.com")
d.checkMatch(t, "www.pornhub.com") d.checkMatch(t, "www.pornhub.com")
d.checkMatchEmpty(t, "www.yandex.ru") d.checkMatchEmpty(t, "www.yandex.ru")
@ -429,7 +428,6 @@ func TestClientSettings(t *testing.T) {
filters[0] = "||example.org^\n" filters[0] = "||example.org^\n"
d := NewForTest(&Config{ParentalEnabled: true, SafeBrowsingEnabled: false}, filters) d := NewForTest(&Config{ParentalEnabled: true, SafeBrowsingEnabled: false}, filters)
defer d.Close() defer d.Close()
d.ParentalSensitivity = 3
// no client settings: // no client settings:
@ -485,14 +483,15 @@ func TestRewrites(t *testing.T) {
d := Dnsfilter{} d := Dnsfilter{}
// CNAME, A, AAAA // CNAME, A, AAAA
d.Rewrites = []RewriteEntry{ d.Rewrites = []RewriteEntry{
RewriteEntry{"somecname", "somehost.com"}, RewriteEntry{"somecname", "somehost.com", 0, nil},
RewriteEntry{"somehost.com", "0.0.0.0"}, RewriteEntry{"somehost.com", "0.0.0.0", 0, nil},
RewriteEntry{"host.com", "1.2.3.4"}, RewriteEntry{"host.com", "1.2.3.4", 0, nil},
RewriteEntry{"host.com", "1.2.3.5"}, RewriteEntry{"host.com", "1.2.3.5", 0, nil},
RewriteEntry{"host.com", "1:2:3::4"}, RewriteEntry{"host.com", "1:2:3::4", 0, nil},
RewriteEntry{"www.host.com", "host.com"}, RewriteEntry{"www.host.com", "host.com", 0, nil},
} }
d.prepareRewrites()
r := d.processRewrites("host2.com", dns.TypeA) r := d.processRewrites("host2.com", dns.TypeA)
assert.Equal(t, NotFilteredNotFound, r.Reason) assert.Equal(t, NotFilteredNotFound, r.Reason)
@ -510,9 +509,10 @@ func TestRewrites(t *testing.T) {
// wildcard // wildcard
d.Rewrites = []RewriteEntry{ d.Rewrites = []RewriteEntry{
RewriteEntry{"*.host.com", "1.2.3.5"}, RewriteEntry{"host.com", "1.2.3.4", 0, nil},
RewriteEntry{"host.com", "1.2.3.4"}, RewriteEntry{"*.host.com", "1.2.3.5", 0, nil},
} }
d.prepareRewrites()
r = d.processRewrites("host.com", dns.TypeA) r = d.processRewrites("host.com", dns.TypeA)
assert.Equal(t, ReasonRewrite, r.Reason) assert.Equal(t, ReasonRewrite, r.Reason)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
@ -524,15 +524,53 @@ func TestRewrites(t *testing.T) {
r = d.processRewrites("www.host2.com", dns.TypeA) r = d.processRewrites("www.host2.com", dns.TypeA)
assert.Equal(t, NotFilteredNotFound, r.Reason) assert.Equal(t, NotFilteredNotFound, r.Reason)
// override a wildcard
d.Rewrites = []RewriteEntry{
RewriteEntry{"a.host.com", "1.2.3.4", 0, nil},
RewriteEntry{"*.host.com", "1.2.3.5", 0, nil},
}
d.prepareRewrites()
r = d.processRewrites("a.host.com", dns.TypeA)
assert.Equal(t, ReasonRewrite, r.Reason)
assert.True(t, len(r.IPList) == 1)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
// wildcard + CNAME // wildcard + CNAME
d.Rewrites = []RewriteEntry{ d.Rewrites = []RewriteEntry{
RewriteEntry{"*.host.com", "host.com"}, RewriteEntry{"host.com", "1.2.3.4", 0, nil},
RewriteEntry{"host.com", "1.2.3.4"}, RewriteEntry{"*.host.com", "host.com", 0, nil},
} }
d.prepareRewrites()
r = d.processRewrites("www.host.com", dns.TypeA) r = d.processRewrites("www.host.com", dns.TypeA)
assert.Equal(t, ReasonRewrite, r.Reason) assert.Equal(t, ReasonRewrite, r.Reason)
assert.Equal(t, "host.com", r.CanonName) assert.Equal(t, "host.com", r.CanonName)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
// 2 CNAMEs
d.Rewrites = []RewriteEntry{
RewriteEntry{"b.host.com", "a.host.com", 0, nil},
RewriteEntry{"a.host.com", "host.com", 0, nil},
RewriteEntry{"host.com", "1.2.3.4", 0, nil},
}
d.prepareRewrites()
r = d.processRewrites("b.host.com", dns.TypeA)
assert.Equal(t, ReasonRewrite, r.Reason)
assert.Equal(t, "host.com", r.CanonName)
assert.True(t, len(r.IPList) == 1)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
// 2 CNAMEs + wildcard
d.Rewrites = []RewriteEntry{
RewriteEntry{"b.host.com", "a.host.com", 0, nil},
RewriteEntry{"a.host.com", "x.somehost.com", 0, nil},
RewriteEntry{"*.somehost.com", "1.2.3.4", 0, nil},
}
d.prepareRewrites()
r = d.processRewrites("b.host.com", dns.TypeA)
assert.Equal(t, ReasonRewrite, r.Reason)
assert.Equal(t, "x.somehost.com", r.CanonName)
assert.True(t, len(r.IPList) == 1)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
} }
// BENCHMARKS // BENCHMARKS

View File

@ -4,11 +4,123 @@ package dnsfilter
import ( import (
"encoding/json" "encoding/json"
"net"
"net/http" "net/http"
"sort"
"strings"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
) )
// RewriteEntry is a rewrite array element
type RewriteEntry struct {
Domain string `yaml:"domain"`
Answer string `yaml:"answer"` // IP address or canonical name
Type uint16 `yaml:"-"` // DNS record type: CNAME, A or AAAA
IP net.IP `yaml:"-"` // Parsed IP address (if Type is A or AAAA)
}
func (r *RewriteEntry) equals(b RewriteEntry) bool {
return r.Domain == b.Domain && r.Answer == b.Answer
}
func isWildcard(host string) bool {
return len(host) >= 2 &&
host[0] == '*' && host[1] == '.'
}
// Return TRUE of host name matches a wildcard pattern
func matchDomainWildcard(host, wildcard string) bool {
return isWildcard(wildcard) &&
strings.HasSuffix(host, wildcard[1:])
}
type rewritesArray []RewriteEntry
func (a rewritesArray) Len() int { return len(a) }
func (a rewritesArray) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
// Priority: CNAME, A/AAAA; exact, wildcard.
func (a rewritesArray) Less(i, j int) bool {
if a[i].Type == dns.TypeCNAME && a[j].Type != dns.TypeCNAME {
return false
} else if a[i].Type != dns.TypeCNAME && a[j].Type == dns.TypeCNAME {
return true
}
if isWildcard(a[i].Domain) && !isWildcard(a[j].Domain) {
return false
} else if !isWildcard(a[i].Domain) && isWildcard(a[j].Domain) {
return true
}
return i < j
}
// Prepare entry for use
func (r *RewriteEntry) prepare() {
ip := net.ParseIP(r.Answer)
if ip == nil {
r.Type = dns.TypeCNAME
return
}
r.IP = ip
r.Type = dns.TypeAAAA
ip4 := ip.To4()
if ip4 != nil {
r.IP = ip4
r.Type = dns.TypeA
}
}
func (d *Dnsfilter) prepareRewrites() {
for i := range d.Rewrites {
d.Rewrites[i].prepare()
}
}
// Get the list of matched rewrite entries.
// Priority: CNAME, A/AAAA; exact, wildcard.
// If matched exactly, don't return wildcard entries.
func findRewrites(a []RewriteEntry, host string) []RewriteEntry {
rr := rewritesArray{}
for _, r := range a {
if r.Domain != host {
if !matchDomainWildcard(host, r.Domain) {
continue
}
}
rr = append(rr, r)
}
if len(rr) == 0 {
return nil
}
sort.Sort(rr)
isWC := isWildcard(rr[0].Domain)
if !isWC {
for i, r := range rr {
if isWildcard(r.Domain) {
rr = rr[:i]
break
}
}
}
return rr
}
func rewriteArrayDup(a []RewriteEntry) []RewriteEntry {
a2 := make([]RewriteEntry, len(a))
copy(a2, a)
return a2
}
type rewriteEntryJSON struct { type rewriteEntryJSON struct {
Domain string `json:"domain"` Domain string `json:"domain"`
Answer string `json:"answer"` Answer string `json:"answer"`
@ -49,6 +161,7 @@ func (d *Dnsfilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) {
Domain: jsent.Domain, Domain: jsent.Domain,
Answer: jsent.Answer, Answer: jsent.Answer,
} }
ent.prepare()
d.confLock.Lock() d.confLock.Lock()
d.Config.Rewrites = append(d.Config.Rewrites, ent) d.Config.Rewrites = append(d.Config.Rewrites, ent)
d.confLock.Unlock() d.confLock.Unlock()
@ -74,7 +187,7 @@ func (d *Dnsfilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request)
arr := []RewriteEntry{} arr := []RewriteEntry{}
d.confLock.Lock() d.confLock.Lock()
for _, ent := range d.Config.Rewrites { for _, ent := range d.Config.Rewrites {
if ent == entDel { if ent.equals(entDel) {
log.Debug("Rewrites: removed element: %s -> %s", ent.Domain, ent.Answer) log.Debug("Rewrites: removed element: %s -> %s", ent.Domain, ent.Answer)
continue continue
} }

View File

@ -3,19 +3,15 @@
package dnsfilter package dnsfilter
import ( import (
"bufio"
"bytes" "bytes"
"crypto/sha256" "crypto/sha256"
"encoding/binary" "encoding/binary"
"encoding/gob" "encoding/gob"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io"
"net" "net"
"net/http" "net/http"
"strconv"
"strings" "strings"
"time" "time"
@ -325,66 +321,7 @@ func (d *Dnsfilter) handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Requ
} }
} }
func parseParametersFromBody(r io.Reader) (map[string]string, error) {
parameters := map[string]string{}
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
if len(line) == 0 {
// skip empty lines
continue
}
parts := strings.SplitN(line, "=", 2)
if len(parts) != 2 {
return parameters, errors.New("Got invalid request body")
}
parameters[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1])
}
return parameters, nil
}
func (d *Dnsfilter) handleParentalEnable(w http.ResponseWriter, r *http.Request) { func (d *Dnsfilter) handleParentalEnable(w http.ResponseWriter, r *http.Request) {
parameters, err := parseParametersFromBody(r.Body)
if err != nil {
httpError(r, w, http.StatusBadRequest, "failed to parse parameters from body: %s", err)
return
}
sensitivity, ok := parameters["sensitivity"]
if !ok {
http.Error(w, "Sensitivity parameter was not specified", 400)
return
}
switch sensitivity {
case "3":
break
case "EARLY_CHILDHOOD":
sensitivity = "3"
case "10":
break
case "YOUNG":
sensitivity = "10"
case "13":
break
case "TEEN":
sensitivity = "13"
case "17":
break
case "MATURE":
sensitivity = "17"
default:
http.Error(w, "Sensitivity must be set to valid value", 400)
return
}
i, err := strconv.Atoi(sensitivity)
if err != nil {
http.Error(w, "Sensitivity must be set to valid value", 400)
return
}
d.Config.ParentalSensitivity = i
d.Config.ParentalEnabled = true d.Config.ParentalEnabled = true
d.Config.ConfigModified() d.Config.ConfigModified()
} }
@ -398,9 +335,6 @@ func (d *Dnsfilter) handleParentalStatus(w http.ResponseWriter, r *http.Request)
data := map[string]interface{}{ data := map[string]interface{}{
"enabled": d.Config.ParentalEnabled, "enabled": d.Config.ParentalEnabled,
} }
if d.Config.ParentalEnabled {
data["sensitivity"] = d.Config.ParentalSensitivity
}
jsonVal, err := json.Marshal(data) jsonVal, err := json.Marshal(data)
if err != nil { if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to marshal status json: %s", err) httpError(r, w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)