diff --git a/CHANGELOG.md b/CHANGELOG.md index 12e6edf4..20aafe62 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -56,6 +56,11 @@ In this release, the schema version has changed from 12 to 13. --> - Go 1.17 support. v0.109.0 will require at least Go 1.18 to build. +## Fixed + +- Unnecessarily complex hosts-related logic leading to infinite recursion in + some cases ([#4216]). + ### Removed - Go 1.16 support. diff --git a/internal/aghnet/hostscontainer.go b/internal/aghnet/hostscontainer.go index 85379fed..198d4c78 100644 --- a/internal/aghnet/hostscontainer.go +++ b/internal/aghnet/hostscontainer.go @@ -46,13 +46,8 @@ type requestMatcher struct { } // MatchRequest processes the request rewriting hostnames and addresses read -// from the operating system's hosts files. -// -// res is nil for any request having not an A/AAAA or PTR type. Results -// containing CNAME information may be queried again with the same question type -// and the returned CNAME for Host field of request. Results are guaranteed to -// be direct, i.e. any returned CNAME resolves into actual address like an alias -// in hosts does, see man hosts (5). +// from the operating system's hosts files. res is nil for any request having +// not an A/AAAA or PTR type, see man 5 hosts. // // It's safe for concurrent use. func (rm *requestMatcher) MatchRequest( @@ -203,7 +198,7 @@ func (hc *HostsContainer) Close() (err error) { } // Upd returns the channel into which the updates are sent. The receivable -// map's values are guaranteed to be of type of *aghnet.Hosts. +// map's values are guaranteed to be of type of *stringutil.Set. func (hc *HostsContainer) Upd() (updates <-chan *netutil.IPMap) { return hc.updates } @@ -259,28 +254,14 @@ func (hc *HostsContainer) handleEvents() { } } -// ipRules is the pair of generated A/AAAA and PTR rules with related IP. -type ipRules struct { - // rule is the A/AAAA $dnsrewrite rule. - rule string - // rulePtr is the PTR $dnsrewrite rule. - rulePtr string - // ip is the IP address related to the rules. - ip net.IP -} - // hostsParser is a helper type to parse rules from the operating system's hosts // file. It exists for only a single refreshing session. type hostsParser struct { // rulesBuilder builds the resulting rules list content. rulesBuilder *strings.Builder - // rules stores the rules for main hosts to generate translations. - rules []ipRules - - // cnameSet prevents duplicating cname rules, e.g. same hostname for - // different IP versions. - cnameSet *stringutil.Set + // translations maps generated rules into actual hosts file lines. + translations map[string]string // table stores only the unique IP-hostname pairs. It's also sent to the // updates channel afterwards. @@ -290,13 +271,10 @@ type hostsParser struct { // newHostsParser creates a new *hostsParser with buffers of size taken from the // previous parse. func (hc *HostsContainer) newHostsParser() (hp *hostsParser) { - lastLen := hc.last.Len() - return &hostsParser{ rulesBuilder: &strings.Builder{}, - rules: make([]ipRules, 0, lastLen), - cnameSet: stringutil.NewSet(), - table: netutil.NewIPMap(lastLen), + translations: map[string]string{}, + table: netutil.NewIPMap(hc.last.Len()), } } @@ -342,6 +320,8 @@ func (hp *hostsParser) parseLine(line string) (ip net.IP, hosts []string) { // Make sure that invalid hosts aren't turned into rules. // // See https://github.com/AdguardTeam/AdGuardHome/issues/3946. + // + // TODO(e.burkov): Investigate if hosts may contain DNS-SD domains. err := netutil.ValidateDomainName(f) if err != nil { log.Error("%s: host %q is invalid, ignoring", hostsContainerPref, f) @@ -355,107 +335,45 @@ func (hp *hostsParser) parseLine(line string) (ip net.IP, hosts []string) { return ip, hosts } -// Hosts is used to contain the main host and all it's aliases. -type Hosts struct { - // Aliases contains all the aliases for Main. - Aliases *stringutil.Set - // Main is the host itself. - Main string -} - -// Equal returns true if h equals hh. -func (h *Hosts) Equal(hh *Hosts) (ok bool) { - if h == nil || hh == nil { - return h == hh - } - - return h.Main == hh.Main && h.Aliases.Equal(hh.Aliases) -} - -// add tries to add the ip-host pair. It returns: -// -// main host if the host is not the first one added for the ip. -// host itself if the host is the first one added for the ip. -// "" if the ip-host pair has already been added. -// -func (hp *hostsParser) add(ip net.IP, host string) (mainHost string) { - v, ok := hp.table.Get(ip) - switch h, _ := v.(*Hosts); { - case !ok: - // This is the first host for the ip. - hp.table.Set(ip, &Hosts{Main: host}) - - return host - case h.Main == host: - // This is a duplicate. Go on. - case h.Aliases == nil: - // This is the first alias. - h.Aliases = stringutil.NewSet(host) - - return h.Main - case !h.Aliases.Has(host): - // This is a new alias. - h.Aliases.Add(host) - - return h.Main - default: - // This is a duplicate. Go on. - } - - return "" -} - // addPair puts the pair of ip and host to the rules builder if needed. For // each ip the first member of hosts will become the main one. func (hp *hostsParser) addPairs(ip net.IP, hosts []string) { - for _, host := range hosts { - switch mainHost := hp.add(ip, host); mainHost { - case "": - // This host is a duplicate. + v, ok := hp.table.Get(ip) + if !ok { + // This ip is added at the first time. + v = stringutil.NewSet() + hp.table.Set(ip, v) + } + + var set *stringutil.Set + set, ok = v.(*stringutil.Set) + if !ok { + log.Debug("%s: adding pairs: unexpected value type %T", hostsContainerPref, v) + + return + } + + processed := strings.Join(append([]string{ip.String()}, hosts...), " ") + for _, h := range hosts { + if set.Has(h) { continue - case host: - // This host is main. - added, addedPtr := hp.writeMainRule(host, ip) - hp.rules = append(hp.rules, ipRules{ - rule: added, - rulePtr: addedPtr, - ip: ip, - }) - default: - // This host is an alias. - pair := fmt.Sprint(host, " ", mainHost) - if hp.cnameSet.Has(pair) { - continue - } - hp.writeAliasRule(host, mainHost) - hp.cnameSet.Add(pair) } - log.Debug("%s: added ip-host pair %q-%q", hostsContainerPref, ip, host) + set.Add(h) + + rule, rulePtr := hp.writeRules(h, ip) + hp.translations[rule], hp.translations[rulePtr] = processed, processed + + log.Debug("%s: added ip-host pair %q-%q", hostsContainerPref, ip, h) } } -// writeAliasRule writes the CNAME rule for the alias-host pair into internal -// builders. -func (hp *hostsParser) writeAliasRule(alias, host string) { - const ( - nl = "\n" - sc = ";" - - rwSuccess = rules.MaskSeparator + "$dnsrewrite=NOERROR" + sc + "CNAME" + sc - constLen = len(rules.MaskPipe) + len(rwSuccess) + len(nl) - ) - - hp.rulesBuilder.Grow(constLen + len(host) + len(alias)) - stringutil.WriteToBuilder(hp.rulesBuilder, rules.MaskPipe, alias, rwSuccess, host, nl) -} - -// writeMainRule writes the actual rule for the qtype and the PTR for the +// writeRules writes the actual rule for the qtype and the PTR for the // host-ip pair into internal builders. -func (hp *hostsParser) writeMainRule(host string, ip net.IP) (added, addedPtr string) { +func (hp *hostsParser) writeRules(host string, ip net.IP) (rule, rulePtr string) { arpa, err := netutil.IPToReversedAddr(ip) if err != nil { - return + return "", "" } const ( @@ -482,28 +400,20 @@ func (hp *hostsParser) writeMainRule(host string, ip net.IP) (added, addedPtr st ruleBuilder := &strings.Builder{} ruleBuilder.Grow(modLen + len(host) + len(qtype) + len(ipStr)) - stringutil.WriteToBuilder( - ruleBuilder, - rules.MaskPipe, - host, - rwSuccess, - qtype, - ";", - ipStr, - ) - added = ruleBuilder.String() + stringutil.WriteToBuilder(ruleBuilder, rules.MaskPipe, host, rwSuccess, qtype, ";", ipStr) + rule = ruleBuilder.String() ruleBuilder.Reset() ruleBuilder.Grow(modLenPTR + len(arpa) + len(fqdn)) stringutil.WriteToBuilder(ruleBuilder, rules.MaskPipe, arpa, rwSuccessPTR, fqdn) - addedPtr = ruleBuilder.String() + rulePtr = ruleBuilder.String() - hp.rulesBuilder.Grow(len(added) + len(addedPtr) + 2*len(nl)) - stringutil.WriteToBuilder(hp.rulesBuilder, added, nl, addedPtr, nl) + hp.rulesBuilder.Grow(len(rule) + len(rulePtr) + 2*len(nl)) + stringutil.WriteToBuilder(hp.rulesBuilder, rule, nl, rulePtr, nl) - return added, addedPtr + return rule, rulePtr } // equalSet returns true if the internal hosts table just parsed equals target. @@ -519,12 +429,11 @@ func (hp *hostsParser) equalSet(target *netutil.IPMap) (ok bool) { hp.table.Range(func(ip net.IP, b interface{}) (cont bool) { // ok is set to true if the target doesn't contain ip or if the - // appropriate hosts set isn't equal to the checked one, i.e. the main - // hosts differ or the maps have at least one discrepancy. + // appropriate hosts set isn't equal to the checked one. if a, hasIP := target.Get(ip); !hasIP { ok = true - } else if hosts, aok := a.(*Hosts); aok { - ok = !hosts.Equal(b.(*Hosts)) + } else if hosts, aok := a.(*stringutil.Set); aok { + ok = !hosts.Equal(b.(*stringutil.Set)) } // Continue only if maps has no discrepancies. @@ -563,35 +472,6 @@ func (hp *hostsParser) newStrg(id int) (s *filterlist.RuleStorage, err error) { }}) } -// translations generates the map to translate $dnsrewrite rules to -// hosts-syntax ones. -func (hp *hostsParser) translations() (trans map[string]string) { - l := len(hp.rules) - if l == 0 { - return nil - } - - trans = make(map[string]string, l*2) - for _, r := range hp.rules { - v, ok := hp.table.Get(r.ip) - if !ok { - continue - } - - var hosts *Hosts - hosts, ok = v.(*Hosts) - if !ok { - continue - } - - strs := append([]string{r.ip.String(), hosts.Main}, hosts.Aliases.Values()...) - hostsLine := strings.Join(strs, " ") - trans[r.rule], trans[r.rulePtr] = hostsLine, hostsLine - } - - return trans -} - // refresh gets the data from specified files and propagates the updates if // needed. // @@ -618,7 +498,7 @@ func (hc *HostsContainer) refresh() (err error) { return fmt.Errorf("initializing rules storage: %w", err) } - hc.resetEng(rulesStrg, hp.translations()) + hc.resetEng(rulesStrg, hp.translations) return nil } diff --git a/internal/aghnet/hostscontainer_test.go b/internal/aghnet/hostscontainer_test.go index 150e8c19..a141ce00 100644 --- a/internal/aghnet/hostscontainer_test.go +++ b/internal/aghnet/hostscontainer_test.go @@ -14,6 +14,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/stringutil" + "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/urlfilter" "github.com/AdguardTeam/urlfilter/rules" "github.com/miekg/dns" @@ -86,6 +87,7 @@ func TestNewHostsContainer(t *testing.T) { return } + testutil.CleanupAndRequireSuccess(t, hc.Close) require.NoError(t, err) require.NotNil(t, hc) @@ -156,8 +158,9 @@ func TestHostsContainer_refresh(t *testing.T) { hc, err := NewHostsContainer(0, testFS, w, "dir") require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, hc.Close) - checkRefresh := func(t *testing.T, wantHosts Hosts) { + checkRefresh := func(t *testing.T, wantHosts *stringutil.Set) { upd, ok := <-hc.Upd() require.True(t, ok) require.NotNil(t, upd) @@ -167,22 +170,21 @@ func TestHostsContainer_refresh(t *testing.T) { v, ok := upd.Get(ip) require.True(t, ok) - var hosts *Hosts - hosts, ok = v.(*Hosts) + var set *stringutil.Set + set, ok = v.(*stringutil.Set) require.True(t, ok) - assert.Equal(t, wantHosts.Main, hosts.Main) - assert.True(t, hosts.Aliases.Equal(wantHosts.Aliases)) + assert.True(t, set.Equal(wantHosts)) } t.Run("initial_refresh", func(t *testing.T) { - checkRefresh(t, Hosts{Main: "hostname"}) + checkRefresh(t, stringutil.NewSet("hostname")) }) t.Run("second_refresh", func(t *testing.T) { testFS["dir/file2"] = &fstest.MapFile{Data: []byte(ipStr + ` alias` + nl)} eventsCh <- event{} - checkRefresh(t, Hosts{Main: "hostname", Aliases: stringutil.NewSet("alias")}) + checkRefresh(t, stringutil.NewSet("hostname", "alias")) }) t.Run("double_refresh", func(t *testing.T) { @@ -288,6 +290,7 @@ func TestHostsContainer_Translate(t *testing.T) { hc, err := NewHostsContainer(0, testdata, &stubWatcher, "etc_hosts") require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, hc.Close) testCases := []struct { name string @@ -300,7 +303,11 @@ func TestHostsContainer_Translate(t *testing.T) { }, { name: "hello", rule: "|hello^$dnsrewrite=NOERROR;A;1.0.0.0", - wantTrans: []string{"1.0.0.0", "hello", "hello.world", "hello.world.again"}, + wantTrans: []string{"1.0.0.0", "hello", "hello.world"}, + }, { + name: "hello-alias", + rule: "|hello.world.again^$dnsrewrite=NOERROR;A;1.0.0.0", + wantTrans: []string{"1.0.0.0", "hello.world.again"}, }, { name: "simplehost_v6", rule: "|simplehost^$dnsrewrite=NOERROR;AAAA;::1", @@ -308,7 +315,11 @@ func TestHostsContainer_Translate(t *testing.T) { }, { name: "hello_v6", rule: "|hello^$dnsrewrite=NOERROR;AAAA;::", - wantTrans: []string{"::", "hello", "hello.world", "hello.world.again"}, + wantTrans: []string{"::", "hello", "hello.world"}, + }, { + name: "hello_v6-alias", + rule: "|hello.world.again^$dnsrewrite=NOERROR;AAAA;::", + wantTrans: []string{"::", "hello.world.again"}, }, { name: "simplehost_ptr", rule: "|1.0.0.1.in-addr.arpa^$dnsrewrite=NOERROR;PTR;simplehost.", @@ -316,7 +327,11 @@ func TestHostsContainer_Translate(t *testing.T) { }, { name: "hello_ptr", rule: "|0.0.0.1.in-addr.arpa^$dnsrewrite=NOERROR;PTR;hello.", - wantTrans: []string{"1.0.0.0", "hello", "hello.world", "hello.world.again"}, + wantTrans: []string{"1.0.0.0", "hello", "hello.world"}, + }, { + name: "hello_ptr-alias", + rule: "|0.0.0.1.in-addr.arpa^$dnsrewrite=NOERROR;PTR;hello.world.again.", + wantTrans: []string{"1.0.0.0", "hello.world.again"}, }, { name: "simplehost_ptr_v6", rule: "|1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa" + @@ -326,14 +341,18 @@ func TestHostsContainer_Translate(t *testing.T) { name: "hello_ptr_v6", rule: "|0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa" + "^$dnsrewrite=NOERROR;PTR;hello.", - wantTrans: []string{"::", "hello", "hello.world", "hello.world.again"}, + wantTrans: []string{"::", "hello", "hello.world"}, + }, { + name: "hello_ptr_v6-alias", + rule: "|0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa" + + "^$dnsrewrite=NOERROR;PTR;hello.world.again.", + wantTrans: []string{"::", "hello.world.again"}, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - want := stringutil.NewSet(tc.wantTrans...) got := stringutil.NewSet(strings.Fields(hc.Translate(tc.rule))...) - assert.True(t, want.Equal(got)) + assert.True(t, stringutil.NewSet(tc.wantTrans...).Equal(got)) }) } } @@ -354,7 +373,7 @@ func TestHostsContainer(t *testing.T) { RRType: dns.TypeA, }, { RCode: dns.RcodeSuccess, - Value: net.IP(append((&[15]byte{})[:], byte(1))), + Value: net.ParseIP("::1"), RRType: dns.TypeAAAA, }}, name: "simple", @@ -364,8 +383,13 @@ func TestHostsContainer(t *testing.T) { }, }, { want: []*rules.DNSRewrite{{ - RCode: dns.RcodeSuccess, - NewCNAME: "hello", + RCode: dns.RcodeSuccess, + Value: net.IPv4(1, 0, 0, 0), + RRType: dns.TypeA, + }, { + RCode: dns.RcodeSuccess, + Value: net.ParseIP("::"), + RRType: dns.TypeAAAA, }}, name: "hello_alias", req: urlfilter.DNSRequest{ @@ -374,8 +398,13 @@ func TestHostsContainer(t *testing.T) { }, }, { want: []*rules.DNSRewrite{{ - RCode: dns.RcodeSuccess, - NewCNAME: "hello", + RCode: dns.RcodeSuccess, + Value: net.IPv4(1, 0, 0, 0), + RRType: dns.TypeA, + }, { + RCode: dns.RcodeSuccess, + Value: net.ParseIP("::"), + RRType: dns.TypeAAAA, }}, name: "other_line_alias", req: urlfilter.DNSRequest{ @@ -398,8 +427,13 @@ func TestHostsContainer(t *testing.T) { }, }, { want: []*rules.DNSRewrite{{ - RCode: dns.RcodeSuccess, - NewCNAME: "a.whole", + RCode: dns.RcodeSuccess, + RRType: dns.TypeA, + Value: net.IPv4(1, 0, 0, 2), + }, { + RCode: dns.RcodeSuccess, + RRType: dns.TypeAAAA, + Value: net.ParseIP("::2"), }}, name: "lots_of_aliases", req: urlfilter.DNSRequest{ @@ -431,6 +465,51 @@ func TestHostsContainer(t *testing.T) { Hostname: "1.0.0.1.in-addr.arpa", DNSType: dns.TypeSRV, }, + }, { + want: []*rules.DNSRewrite{{ + RCode: dns.RcodeSuccess, + RRType: dns.TypeA, + Value: net.IPv4(4, 2, 1, 6), + }, { + RCode: dns.RcodeSuccess, + RRType: dns.TypeAAAA, + Value: net.ParseIP("::42"), + }}, + name: "issue_4216_4_6", + req: urlfilter.DNSRequest{ + Hostname: "domain", + DNSType: dns.TypeA, + }, + }, { + want: []*rules.DNSRewrite{{ + RCode: dns.RcodeSuccess, + RRType: dns.TypeA, + Value: net.IPv4(7, 5, 3, 1), + }, { + RCode: dns.RcodeSuccess, + RRType: dns.TypeA, + Value: net.IPv4(1, 3, 5, 7), + }}, + name: "issue_4216_4", + req: urlfilter.DNSRequest{ + Hostname: "domain4", + DNSType: dns.TypeA, + }, + }, { + want: []*rules.DNSRewrite{{ + RCode: dns.RcodeSuccess, + RRType: dns.TypeAAAA, + Value: net.ParseIP("::13"), + }, { + RCode: dns.RcodeSuccess, + RRType: dns.TypeAAAA, + Value: net.ParseIP("::31"), + }}, + name: "issue_4216_6", + req: urlfilter.DNSRequest{ + Hostname: "domain6", + DNSType: dns.TypeAAAA, + }, }} stubWatcher := aghtest.FSWatcher{ @@ -441,6 +520,7 @@ func TestHostsContainer(t *testing.T) { hc, err := NewHostsContainer(listID, testdata, &stubWatcher, "etc_hosts") require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, hc.Close) for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { diff --git a/internal/aghnet/testdata/etc_hosts b/internal/aghnet/testdata/etc_hosts index afe7011c..81041031 100644 --- a/internal/aghnet/testdata/etc_hosts +++ b/internal/aghnet/testdata/etc_hosts @@ -27,4 +27,12 @@ ::4 *.com :: hello.world.again ::1 simplehost -:: hello.world \ No newline at end of file +:: hello.world + +# See https://github.com/AdguardTeam/AdGuardHome/issues/4216. +4.2.1.6 domain domain.alias +::42 domain.alias domain +1.3.5.7 domain4 domain4.alias +7.5.3.1 domain4.alias domain4 +::13 domain6 domain6.alias +::31 domain6.alias domain6 \ No newline at end of file diff --git a/internal/filtering/filtering.go b/internal/filtering/filtering.go index a2a8ff21..0a8bcd6e 100644 --- a/internal/filtering/filtering.go +++ b/internal/filtering/filtering.go @@ -477,7 +477,7 @@ func (d *DNSFilter) matchSysHosts( return res, nil } - return d.matchSysHostsIntl(&urlfilter.DNSRequest{ + dnsres, _ := d.EtcHosts.MatchRequest(urlfilter.DNSRequest{ Hostname: host, SortedClientTags: setts.ClientTags, // TODO(e.burkov): Wait for urlfilter update to pass net.IP. @@ -485,12 +485,6 @@ func (d *DNSFilter) matchSysHosts( ClientName: setts.ClientName, DNSType: qtype, }) -} - -// matchSysHostsIntl actually matches the request. It's separated to avoid -// performing checks twice. -func (d *DNSFilter) matchSysHostsIntl(req *urlfilter.DNSRequest) (res Result, err error) { - dnsres, _ := d.EtcHosts.MatchRequest(*req) if dnsres == nil { return res, nil } @@ -501,13 +495,6 @@ func (d *DNSFilter) matchSysHostsIntl(req *urlfilter.DNSRequest) (res Result, er } res = d.processDNSRewrites(dnsr) - if cn := res.CanonName; cn != "" { - // Probably an alias. - req.Hostname = cn - - return d.matchSysHostsIntl(req) - } - res.Reason = RewrittenAutoHosts for _, r := range res.Rules { r.Text = stringutil.Coalesce(d.EtcHosts.Translate(r.Text), r.Text) diff --git a/internal/home/clients.go b/internal/home/clients.go index 94dd41f7..098fba91 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -783,17 +783,14 @@ func (clients *clientsContainer) addFromHostsFile(hosts *netutil.IPMap) { n := 0 hosts.Range(func(ip net.IP, v interface{}) (cont bool) { - hosts, ok := v.(*aghnet.Hosts) + hosts, ok := v.(*stringutil.Set) if !ok { log.Error("dns: bad type %T in ipToRC for %s", v, ip) return true } - if clients.addHostLocked(ip, hosts.Main, ClientSourceHostsFile) { - n++ - } - hosts.Aliases.Range(func(name string) (cont bool) { + hosts.Range(func(name string) (cont bool) { if clients.addHostLocked(ip, name, ClientSourceHostsFile) { n++ }