diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index 25ca9fa1..61493f05 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -34,15 +34,8 @@ type RequestFilteringSettings struct { ServicesRules []ServiceEntry } -// RewriteEntry is a rewrite array element -type RewriteEntry struct { - Domain string `yaml:"domain"` - Answer string `yaml:"answer"` // IP address or canonical name -} - // Config allows you to configure DNS filtering with New() or just change variables directly. type Config struct { - ParentalSensitivity int `yaml:"parental_sensitivity"` // must be either 3, 10, 13 or 17 ParentalEnabled bool `yaml:"parental_enabled"` SafeSearchEnabled bool `yaml:"safesearch_enabled"` SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"` @@ -175,7 +168,10 @@ func (d *Dnsfilter) GetConfig() RequestFilteringSettings { // WriteDiskConfig - write configuration func (d *Dnsfilter) WriteDiskConfig(c *Config) { + d.confLock.Lock() *c = d.Config + c.Rewrites = rewriteArrayDup(d.Config.Rewrites) + d.confLock.Unlock() } // SetFilters - set new filters (synchronously or asynchronously) @@ -343,17 +339,11 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, setts *RequestFiltering return Result{}, nil } -// Return TRUE of host name matches a wildcard pattern -func matchDomainWildcard(host, wildcard string) bool { - return len(wildcard) >= 2 && - wildcard[0] == '*' && wildcard[1] == '.' && - strings.HasSuffix(host, wildcard[1:]) -} - // Process rewrites table -// . Find CNAME for a domain name +// . Find CNAME for a domain name (exact match or by wildcard) // . if found, set domain name to canonical name -// . Find A or AAAA record for a domain name +// . repeat for the new domain name (Note: we return only the last CNAME) +// . Find A or AAAA record for a domain name (exact match or by wildcard) // . if found, return IP addresses func (d *Dnsfilter) processRewrites(host string, qtype uint16) Result { var res Result @@ -361,50 +351,33 @@ func (d *Dnsfilter) processRewrites(host string, qtype uint16) Result { d.confLock.RLock() defer d.confLock.RUnlock() - for _, r := range d.Rewrites { - if r.Domain != host { - if !matchDomainWildcard(host, r.Domain) { - continue - } - } - - ip := net.ParseIP(r.Answer) - if ip == nil { - log.Debug("Rewrite: CNAME for %s is %s", host, r.Answer) - host = r.Answer - res.CanonName = r.Answer - res.Reason = ReasonRewrite - break - } - } - - for _, r := range d.Rewrites { - if r.Domain != host { - if !matchDomainWildcard(host, r.Domain) { - continue - } - } - - ip := net.ParseIP(r.Answer) - if ip == nil { - continue - } - ip4 := ip.To4() - - if qtype == dns.TypeA && ip4 != nil { - res.IPList = append(res.IPList, ip4) - log.Debug("Rewrite: A for %s is %s", host, ip4) - - } else if qtype == dns.TypeAAAA && ip4 == nil { - res.IPList = append(res.IPList, ip) - log.Debug("Rewrite: AAAA for %s is %s", host, ip) - } - } - - if len(res.IPList) != 0 { + rr := findRewrites(d.Rewrites, host) + if len(rr) != 0 { res.Reason = ReasonRewrite } + cnames := map[string]bool{} + origHost := host + for len(rr) != 0 && rr[0].Type == dns.TypeCNAME { + log.Debug("Rewrite: CNAME for %s is %s", host, rr[0].Answer) + host = rr[0].Answer + _, ok := cnames[host] + if ok { + log.Info("Rewrite: breaking CNAME redirection loop: %s. Question: %s", host, origHost) + return res + } + cnames[host] = false + res.CanonName = rr[0].Answer + rr = findRewrites(d.Rewrites, host) + } + + for _, r := range rr { + if r.Type != dns.TypeCNAME && r.Type == qtype { + res.IPList = append(res.IPList, r.IP) + log.Debug("Rewrite: A/AAAA for %s is %s", host, r.IP) + } + } + return res } @@ -592,6 +565,7 @@ func New(c *Config, filters map[int]string) *Dnsfilter { if c != nil { d.Config = *c + d.prepareRewrites() } if filters != nil { diff --git a/dnsfilter/dnsfilter_test.go b/dnsfilter/dnsfilter_test.go index 2881d0bf..3d1f07dd 100644 --- a/dnsfilter/dnsfilter_test.go +++ b/dnsfilter/dnsfilter_test.go @@ -317,7 +317,6 @@ func TestSafeSearchCacheGoogle(t *testing.T) { func TestParentalControl(t *testing.T) { d := NewForTest(&Config{ParentalEnabled: true}, nil) defer d.Close() - d.ParentalSensitivity = 3 d.checkMatch(t, "pornhub.com") d.checkMatch(t, "www.pornhub.com") d.checkMatchEmpty(t, "www.yandex.ru") @@ -429,7 +428,6 @@ func TestClientSettings(t *testing.T) { filters[0] = "||example.org^\n" d := NewForTest(&Config{ParentalEnabled: true, SafeBrowsingEnabled: false}, filters) defer d.Close() - d.ParentalSensitivity = 3 // no client settings: @@ -485,14 +483,15 @@ func TestRewrites(t *testing.T) { d := Dnsfilter{} // CNAME, A, AAAA d.Rewrites = []RewriteEntry{ - RewriteEntry{"somecname", "somehost.com"}, - RewriteEntry{"somehost.com", "0.0.0.0"}, + RewriteEntry{"somecname", "somehost.com", 0, nil}, + RewriteEntry{"somehost.com", "0.0.0.0", 0, nil}, - RewriteEntry{"host.com", "1.2.3.4"}, - RewriteEntry{"host.com", "1.2.3.5"}, - RewriteEntry{"host.com", "1:2:3::4"}, - RewriteEntry{"www.host.com", "host.com"}, + RewriteEntry{"host.com", "1.2.3.4", 0, nil}, + RewriteEntry{"host.com", "1.2.3.5", 0, nil}, + RewriteEntry{"host.com", "1:2:3::4", 0, nil}, + RewriteEntry{"www.host.com", "host.com", 0, nil}, } + d.prepareRewrites() r := d.processRewrites("host2.com", dns.TypeA) assert.Equal(t, NotFilteredNotFound, r.Reason) @@ -510,9 +509,10 @@ func TestRewrites(t *testing.T) { // wildcard d.Rewrites = []RewriteEntry{ - RewriteEntry{"*.host.com", "1.2.3.5"}, - RewriteEntry{"host.com", "1.2.3.4"}, + RewriteEntry{"host.com", "1.2.3.4", 0, nil}, + RewriteEntry{"*.host.com", "1.2.3.5", 0, nil}, } + d.prepareRewrites() r = d.processRewrites("host.com", dns.TypeA) assert.Equal(t, ReasonRewrite, r.Reason) assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) @@ -524,15 +524,53 @@ func TestRewrites(t *testing.T) { r = d.processRewrites("www.host2.com", dns.TypeA) assert.Equal(t, NotFilteredNotFound, r.Reason) + // override a wildcard + d.Rewrites = []RewriteEntry{ + RewriteEntry{"a.host.com", "1.2.3.4", 0, nil}, + RewriteEntry{"*.host.com", "1.2.3.5", 0, nil}, + } + d.prepareRewrites() + r = d.processRewrites("a.host.com", dns.TypeA) + assert.Equal(t, ReasonRewrite, r.Reason) + assert.True(t, len(r.IPList) == 1) + assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) + // wildcard + CNAME d.Rewrites = []RewriteEntry{ - RewriteEntry{"*.host.com", "host.com"}, - RewriteEntry{"host.com", "1.2.3.4"}, + RewriteEntry{"host.com", "1.2.3.4", 0, nil}, + RewriteEntry{"*.host.com", "host.com", 0, nil}, } + d.prepareRewrites() r = d.processRewrites("www.host.com", dns.TypeA) assert.Equal(t, ReasonRewrite, r.Reason) assert.Equal(t, "host.com", r.CanonName) assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) + + // 2 CNAMEs + d.Rewrites = []RewriteEntry{ + RewriteEntry{"b.host.com", "a.host.com", 0, nil}, + RewriteEntry{"a.host.com", "host.com", 0, nil}, + RewriteEntry{"host.com", "1.2.3.4", 0, nil}, + } + d.prepareRewrites() + r = d.processRewrites("b.host.com", dns.TypeA) + assert.Equal(t, ReasonRewrite, r.Reason) + assert.Equal(t, "host.com", r.CanonName) + assert.True(t, len(r.IPList) == 1) + assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) + + // 2 CNAMEs + wildcard + d.Rewrites = []RewriteEntry{ + RewriteEntry{"b.host.com", "a.host.com", 0, nil}, + RewriteEntry{"a.host.com", "x.somehost.com", 0, nil}, + RewriteEntry{"*.somehost.com", "1.2.3.4", 0, nil}, + } + d.prepareRewrites() + r = d.processRewrites("b.host.com", dns.TypeA) + assert.Equal(t, ReasonRewrite, r.Reason) + assert.Equal(t, "x.somehost.com", r.CanonName) + assert.True(t, len(r.IPList) == 1) + assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) } // BENCHMARKS diff --git a/dnsfilter/rewrites.go b/dnsfilter/rewrites.go index 6cc18784..029d3332 100644 --- a/dnsfilter/rewrites.go +++ b/dnsfilter/rewrites.go @@ -4,11 +4,123 @@ package dnsfilter import ( "encoding/json" + "net" "net/http" + "sort" + "strings" "github.com/AdguardTeam/golibs/log" + "github.com/miekg/dns" ) +// RewriteEntry is a rewrite array element +type RewriteEntry struct { + Domain string `yaml:"domain"` + Answer string `yaml:"answer"` // IP address or canonical name + Type uint16 `yaml:"-"` // DNS record type: CNAME, A or AAAA + IP net.IP `yaml:"-"` // Parsed IP address (if Type is A or AAAA) +} + +func (r *RewriteEntry) equals(b RewriteEntry) bool { + return r.Domain == b.Domain && r.Answer == b.Answer +} + +func isWildcard(host string) bool { + return len(host) >= 2 && + host[0] == '*' && host[1] == '.' +} + +// Return TRUE of host name matches a wildcard pattern +func matchDomainWildcard(host, wildcard string) bool { + return isWildcard(wildcard) && + strings.HasSuffix(host, wildcard[1:]) +} + +type rewritesArray []RewriteEntry + +func (a rewritesArray) Len() int { return len(a) } + +func (a rewritesArray) Swap(i, j int) { a[i], a[j] = a[j], a[i] } + +// Priority: CNAME, A/AAAA; exact, wildcard. +func (a rewritesArray) Less(i, j int) bool { + if a[i].Type == dns.TypeCNAME && a[j].Type != dns.TypeCNAME { + return false + } else if a[i].Type != dns.TypeCNAME && a[j].Type == dns.TypeCNAME { + return true + } + + if isWildcard(a[i].Domain) && !isWildcard(a[j].Domain) { + return false + } else if !isWildcard(a[i].Domain) && isWildcard(a[j].Domain) { + return true + } + + return i < j +} + +// Prepare entry for use +func (r *RewriteEntry) prepare() { + ip := net.ParseIP(r.Answer) + if ip == nil { + r.Type = dns.TypeCNAME + return + } + + r.IP = ip + r.Type = dns.TypeAAAA + + ip4 := ip.To4() + if ip4 != nil { + r.IP = ip4 + r.Type = dns.TypeA + } +} + +func (d *Dnsfilter) prepareRewrites() { + for i := range d.Rewrites { + d.Rewrites[i].prepare() + } +} + +// Get the list of matched rewrite entries. +// Priority: CNAME, A/AAAA; exact, wildcard. +// If matched exactly, don't return wildcard entries. +func findRewrites(a []RewriteEntry, host string) []RewriteEntry { + rr := rewritesArray{} + for _, r := range a { + if r.Domain != host { + if !matchDomainWildcard(host, r.Domain) { + continue + } + } + rr = append(rr, r) + } + + if len(rr) == 0 { + return nil + } + + sort.Sort(rr) + + isWC := isWildcard(rr[0].Domain) + if !isWC { + for i, r := range rr { + if isWildcard(r.Domain) { + rr = rr[:i] + break + } + } + } + return rr +} + +func rewriteArrayDup(a []RewriteEntry) []RewriteEntry { + a2 := make([]RewriteEntry, len(a)) + copy(a2, a) + return a2 +} + type rewriteEntryJSON struct { Domain string `json:"domain"` Answer string `json:"answer"` @@ -49,6 +161,7 @@ func (d *Dnsfilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) { Domain: jsent.Domain, Answer: jsent.Answer, } + ent.prepare() d.confLock.Lock() d.Config.Rewrites = append(d.Config.Rewrites, ent) d.confLock.Unlock() @@ -74,7 +187,7 @@ func (d *Dnsfilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request) arr := []RewriteEntry{} d.confLock.Lock() for _, ent := range d.Config.Rewrites { - if ent == entDel { + if ent.equals(entDel) { log.Debug("Rewrites: removed element: %s -> %s", ent.Domain, ent.Answer) continue } diff --git a/dnsfilter/security.go b/dnsfilter/security.go index 14751c6e..c0341c99 100644 --- a/dnsfilter/security.go +++ b/dnsfilter/security.go @@ -3,19 +3,15 @@ package dnsfilter import ( - "bufio" "bytes" "crypto/sha256" "encoding/binary" "encoding/gob" "encoding/hex" "encoding/json" - "errors" "fmt" - "io" "net" "net/http" - "strconv" "strings" "time" @@ -325,66 +321,7 @@ func (d *Dnsfilter) handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Requ } } -func parseParametersFromBody(r io.Reader) (map[string]string, error) { - parameters := map[string]string{} - - scanner := bufio.NewScanner(r) - for scanner.Scan() { - line := scanner.Text() - if len(line) == 0 { - // skip empty lines - continue - } - parts := strings.SplitN(line, "=", 2) - if len(parts) != 2 { - return parameters, errors.New("Got invalid request body") - } - parameters[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1]) - } - - return parameters, nil -} - func (d *Dnsfilter) handleParentalEnable(w http.ResponseWriter, r *http.Request) { - parameters, err := parseParametersFromBody(r.Body) - if err != nil { - httpError(r, w, http.StatusBadRequest, "failed to parse parameters from body: %s", err) - return - } - - sensitivity, ok := parameters["sensitivity"] - if !ok { - http.Error(w, "Sensitivity parameter was not specified", 400) - return - } - - switch sensitivity { - case "3": - break - case "EARLY_CHILDHOOD": - sensitivity = "3" - case "10": - break - case "YOUNG": - sensitivity = "10" - case "13": - break - case "TEEN": - sensitivity = "13" - case "17": - break - case "MATURE": - sensitivity = "17" - default: - http.Error(w, "Sensitivity must be set to valid value", 400) - return - } - i, err := strconv.Atoi(sensitivity) - if err != nil { - http.Error(w, "Sensitivity must be set to valid value", 400) - return - } - d.Config.ParentalSensitivity = i d.Config.ParentalEnabled = true d.Config.ConfigModified() } @@ -398,9 +335,6 @@ func (d *Dnsfilter) handleParentalStatus(w http.ResponseWriter, r *http.Request) data := map[string]interface{}{ "enabled": d.Config.ParentalEnabled, } - if d.Config.ParentalEnabled { - data["sensitivity"] = d.Config.ParentalSensitivity - } jsonVal, err := json.Marshal(data) if err != nil { httpError(r, w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)