diff --git a/internal/aghnet/etchostscontainer.go b/internal/aghnet/etchostscontainer.go deleted file mode 100644 index af597e61..00000000 --- a/internal/aghnet/etchostscontainer.go +++ /dev/null @@ -1,387 +0,0 @@ -package aghnet - -import ( - "bufio" - "net" - "os" - "path/filepath" - "runtime" - "strings" - "sync" - - "github.com/AdguardTeam/AdGuardHome/internal/aghos" - "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" - "github.com/AdguardTeam/golibs/netutil" - "github.com/fsnotify/fsnotify" - "github.com/miekg/dns" -) - -type onChangedT func() - -// EtcHostsContainer - automatic DNS records -// -// TODO(e.burkov): Move the logic under interface. Refactor. Probably remove -// the resolving logic. -type EtcHostsContainer struct { - // lock protects table and tableReverse. - lock sync.RWMutex - // table is the host-to-IPs map. - table map[string][]net.IP - // tableReverse is the IP-to-hosts map. The type of the values in the - // map is []string. - tableReverse *netutil.IPMap - - hostsFn string // path to the main hosts-file - hostsDirs []string // paths to OS-specific directories with hosts-files - watcher *fsnotify.Watcher // file and directory watcher object - - // onlyWritesChan used to contain only writing events from watcher. - onlyWritesChan chan fsnotify.Event - - onChanged onChangedT // notification to other modules -} - -// SetOnChanged - set callback function that will be called when the data is changed -func (ehc *EtcHostsContainer) SetOnChanged(onChanged onChangedT) { - ehc.onChanged = onChanged -} - -// Notify other modules -func (ehc *EtcHostsContainer) notify() { - if ehc.onChanged == nil { - return - } - ehc.onChanged() -} - -// Init - initialize -// hostsFn: Override default name for the hosts-file (optional) -func (ehc *EtcHostsContainer) Init(hostsFn string) { - ehc.table = make(map[string][]net.IP) - ehc.onlyWritesChan = make(chan fsnotify.Event, 2) - - ehc.hostsFn = "/etc/hosts" - if runtime.GOOS == "windows" { - ehc.hostsFn = os.ExpandEnv("$SystemRoot\\system32\\drivers\\etc\\hosts") - } - if len(hostsFn) != 0 { - ehc.hostsFn = hostsFn - } - - if aghos.IsOpenWrt() { - // OpenWrt: "/tmp/hosts/dhcp.cfg01411c". - ehc.hostsDirs = append(ehc.hostsDirs, "/tmp/hosts") - } - - // Load hosts initially - ehc.updateHosts() - - var err error - ehc.watcher, err = fsnotify.NewWatcher() - if err != nil { - log.Error("etchosts: %s", err) - } -} - -// Start - start module -func (ehc *EtcHostsContainer) Start() { - if ehc == nil { - return - } - - log.Debug("Start etchostscontainer module") - - ehc.updateHosts() - - if ehc.watcher != nil { - go ehc.watcherLoop() - - err := ehc.watcher.Add(ehc.hostsFn) - if err != nil { - log.Error("Error while initializing watcher for a file %s: %s", ehc.hostsFn, err) - } - - for _, dir := range ehc.hostsDirs { - err = ehc.watcher.Add(dir) - if err != nil { - log.Error("Error while initializing watcher for a directory %s: %s", dir, err) - } - } - } -} - -// Close - close module -func (ehc *EtcHostsContainer) Close() { - if ehc == nil { - return - } - - if ehc.watcher != nil { - _ = ehc.watcher.Close() - } - - // Don't close onlyWritesChan here and let onlyWrites close it after - // watcher.Events is closed to prevent close races. -} - -// Process returns the list of IP addresses for the hostname or nil if nothing -// found. -func (ehc *EtcHostsContainer) Process(host string, qtype uint16) []net.IP { - if qtype == dns.TypePTR { - return nil - } - - var ipsCopy []net.IP - ehc.lock.RLock() - defer ehc.lock.RUnlock() - - if ips, ok := ehc.table[host]; ok { - ipsCopy = make([]net.IP, len(ips)) - copy(ipsCopy, ips) - } - - log.Debug("etchosts: answer: %s -> %v", host, ipsCopy) - return ipsCopy -} - -// ProcessReverse processes a PTR request. It returns nil if nothing is found. -func (ehc *EtcHostsContainer) ProcessReverse(addr string, qtype uint16) (hosts []string) { - if qtype != dns.TypePTR { - return nil - } - - ip, err := netutil.IPFromReversedAddr(addr) - if err != nil { - log.Error("etchosts: reversed addr: %s", err) - - return nil - } - - ehc.lock.RLock() - defer ehc.lock.RUnlock() - - v, ok := ehc.tableReverse.Get(ip) - if !ok { - return nil - } - - hosts, ok = v.([]string) - if !ok { - log.Error("etchosts: bad type %T in tableReverse for %s", v, ip) - - return nil - } else if len(hosts) == 0 { - return nil - } - - log.Debug("etchosts: reverse-lookup: %s -> %s", addr, hosts) - - return hosts -} - -// List returns an IP-to-hostnames table. The type of the values in the map is -// []string. It is safe for concurrent use. -func (ehc *EtcHostsContainer) List() (ipToHosts *netutil.IPMap) { - ehc.lock.RLock() - defer ehc.lock.RUnlock() - - return ehc.tableReverse.ShallowClone() -} - -// update table -func (ehc *EtcHostsContainer) updateTable(table map[string][]net.IP, host string, ipAddr net.IP) { - ips, ok := table[host] - if ok { - for _, ip := range ips { - if ip.Equal(ipAddr) { - // IP already exists: don't add duplicates - ok = false - break - } - } - if !ok { - ips = append(ips, ipAddr) - table[host] = ips - } - } else { - table[host] = []net.IP{ipAddr} - ok = true - } - if ok { - log.Debug("etchosts: added %s -> %s", ipAddr, host) - } -} - -// updateTableRev updates the reverse address table. -func (ehc *EtcHostsContainer) updateTableRev(tableRev *netutil.IPMap, newHost string, ip net.IP) { - v, ok := tableRev.Get(ip) - if !ok { - tableRev.Set(ip, []string{newHost}) - log.Debug("etchosts: added reverse-address %s -> %s", ip, newHost) - - return - } - - hosts, _ := v.([]string) - for _, host := range hosts { - if host == newHost { - return - } - } - - hosts = append(hosts, newHost) - tableRev.Set(ip, hosts) - - log.Debug("etchosts: added reverse-address %s -> %s", ip, newHost) -} - -// parseHostsLine parses hosts from the fields. -func parseHostsLine(fields []string) (hosts []string) { - for _, f := range fields { - hashIdx := strings.IndexByte(f, '#') - if hashIdx == 0 { - // The rest of the fields are a part of the comment. - // Skip immediately. - return - } else if hashIdx > 0 { - // Only a part of the field is a comment. - hosts = append(hosts, f[:hashIdx]) - - return hosts - } - - hosts = append(hosts, f) - } - - return hosts -} - -// load reads IP-hostname pairs from the hosts file. Multiple hostnames per -// line for one IP are supported. -func (ehc *EtcHostsContainer) load( - table map[string][]net.IP, - tableRev *netutil.IPMap, - fn string, -) { - f, err := os.Open(fn) - if err != nil { - log.Error("etchosts: %s", err) - - return - } - - defer func() { - derr := f.Close() - if derr != nil { - log.Error("etchosts: closing file: %s", err) - } - }() - - log.Debug("etchosts: loading hosts from file %s", fn) - - s := bufio.NewScanner(f) - for s.Scan() { - line := strings.TrimSpace(s.Text()) - fields := strings.Fields(line) - if len(fields) < 2 { - continue - } - - ip := net.ParseIP(fields[0]) - if ip == nil { - continue - } - - hosts := parseHostsLine(fields[1:]) - for _, host := range hosts { - ehc.updateTable(table, host, ip) - ehc.updateTableRev(tableRev, host, ip) - } - } - - err = s.Err() - if err != nil { - log.Error("etchosts: %s", err) - } -} - -// onlyWrites is a filter for (*fsnotify.Watcher).Events. -func (ehc *EtcHostsContainer) onlyWrites() { - for event := range ehc.watcher.Events { - if event.Op&fsnotify.Write == fsnotify.Write { - ehc.onlyWritesChan <- event - } - } - - close(ehc.onlyWritesChan) -} - -// Receive notifications from fsnotify package -func (ehc *EtcHostsContainer) watcherLoop() { - go ehc.onlyWrites() - for { - select { - case event, ok := <-ehc.onlyWritesChan: - if !ok { - return - } - - // Assume that we sometimes have the same event occurred - // several times. - repeat := true - for repeat { - select { - case _, ok = <-ehc.onlyWritesChan: - repeat = ok - default: - repeat = false - } - } - - if event.Op&fsnotify.Write == fsnotify.Write { - log.Debug("etchosts: modified: %s", event.Name) - ehc.updateHosts() - } - - case err, ok := <-ehc.watcher.Errors: - if !ok { - return - } - log.Error("etchosts: %s", err) - } - } -} - -// updateHosts - loads system hosts -func (ehc *EtcHostsContainer) updateHosts() { - table := make(map[string][]net.IP) - tableRev := netutil.NewIPMap(0) - - ehc.load(table, tableRev, ehc.hostsFn) - - for _, dir := range ehc.hostsDirs { - des, err := os.ReadDir(dir) - if err != nil { - if !errors.Is(err, os.ErrNotExist) { - log.Error("etchosts: Opening directory: %q: %s", dir, err) - } - - continue - } - - for _, de := range des { - ehc.load(table, tableRev, filepath.Join(dir, de.Name())) - } - } - - func() { - ehc.lock.Lock() - defer ehc.lock.Unlock() - - ehc.table = table - ehc.tableReverse = tableRev - }() - - ehc.notify() -} diff --git a/internal/aghnet/etchostscontainer_test.go b/internal/aghnet/etchostscontainer_test.go deleted file mode 100644 index b83f2dd2..00000000 --- a/internal/aghnet/etchostscontainer_test.go +++ /dev/null @@ -1,130 +0,0 @@ -package aghnet - -import ( - "net" - "os" - "strings" - "testing" - "time" - - "github.com/AdguardTeam/AdGuardHome/internal/aghtest" - "github.com/miekg/dns" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestMain(m *testing.M) { - aghtest.DiscardLogOutput(m) -} - -func prepareTestFile(t *testing.T) (f *os.File) { - t.Helper() - - dir := t.TempDir() - - f, err := os.CreateTemp(dir, "") - require.NoError(t, err) - require.NotNil(t, f) - - t.Cleanup(func() { - assert.NoError(t, f.Close()) - }) - - return f -} - -func assertWriting(t *testing.T, f *os.File, strs ...string) { - t.Helper() - - for _, str := range strs { - n, err := f.WriteString(str) - require.NoError(t, err) - assert.Equal(t, n, len(str)) - } -} - -func TestEtcHostsContainerResolution(t *testing.T) { - ehc := &EtcHostsContainer{} - - f := prepareTestFile(t) - - assertWriting(t, f, - " 127.0.0.1 host localhost # comment \n", - " ::1 localhost#comment \n", - ) - ehc.Init(f.Name()) - - t.Run("existing_host", func(t *testing.T) { - ips := ehc.Process("localhost", dns.TypeA) - require.Len(t, ips, 1) - assert.Equal(t, net.IPv4(127, 0, 0, 1), ips[0]) - }) - - t.Run("unknown_host", func(t *testing.T) { - ips := ehc.Process("newhost", dns.TypeA) - assert.Nil(t, ips) - - // Comment. - ips = ehc.Process("comment", dns.TypeA) - assert.Nil(t, ips) - }) - - t.Run("hosts_file", func(t *testing.T) { - names, ok := ehc.List().Get(net.IP{127, 0, 0, 1}) - require.True(t, ok) - assert.Equal(t, []string{"host", "localhost"}, names) - }) - - t.Run("ptr", func(t *testing.T) { - testCases := []struct { - wantIP string - wantHost string - wantLen int - }{ - {wantIP: "127.0.0.1", wantHost: "host", wantLen: 2}, - {wantIP: "::1", wantHost: "localhost", wantLen: 1}, - } - - for _, tc := range testCases { - a, err := dns.ReverseAddr(tc.wantIP) - require.NoError(t, err) - - a = strings.TrimSuffix(a, ".") - hosts := ehc.ProcessReverse(a, dns.TypePTR) - require.Len(t, hosts, tc.wantLen) - assert.Equal(t, tc.wantHost, hosts[0]) - } - }) -} - -func TestEtcHostsContainerFSNotify(t *testing.T) { - ehc := &EtcHostsContainer{} - - f := prepareTestFile(t) - - assertWriting(t, f, " 127.0.0.1 host localhost \n") - ehc.Init(f.Name()) - - t.Run("unknown_host", func(t *testing.T) { - ips := ehc.Process("newhost", dns.TypeA) - assert.Nil(t, ips) - }) - - // Start monitoring for changes. - ehc.Start() - t.Cleanup(ehc.Close) - - assertWriting(t, f, "127.0.0.2 newhost\n") - require.NoError(t, f.Sync()) - - // Wait until fsnotify has triggered and processed the file-modification - // event. - time.Sleep(50 * time.Millisecond) - - t.Run("notified", func(t *testing.T) { - ips := ehc.Process("newhost", dns.TypeA) - assert.NotNil(t, ips) - require.Len(t, ips, 1) - assert.True(t, net.IP{127, 0, 0, 2}.Equal(ips[0])) - }) -} diff --git a/internal/aghnet/hostscontainer.go b/internal/aghnet/hostscontainer.go new file mode 100644 index 00000000..c0e12a7b --- /dev/null +++ b/internal/aghnet/hostscontainer.go @@ -0,0 +1,343 @@ +package aghnet + +import ( + "bufio" + "fmt" + "io" + "io/fs" + "net" + "path" + "strings" + "sync" + + "github.com/AdguardTeam/AdGuardHome/internal/aghos" + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" + "github.com/AdguardTeam/golibs/stringutil" + "github.com/AdguardTeam/urlfilter" + "github.com/AdguardTeam/urlfilter/filterlist" + "github.com/miekg/dns" +) + +// DefaultHostsPaths returns the slice of paths default for the operating system +// to files and directories which are containing the hosts database. The result +// is intended to use within fs.FS so the initial slash is omitted. +func DefaultHostsPaths() (paths []string) { + return defaultHostsPaths() +} + +// hostsContainerPref is a prefix for logging and wrapping errors in +// HostsContainer's methods. +const hostsContainerPref = "hosts container" + +// HostsContainer stores the relevant hosts database provided by the OS and +// processes both A/AAAA and PTR DNS requests for those. +type HostsContainer struct { + // engLock protects rulesStrg and engine. + engLock *sync.RWMutex + + // rulesStrg stores the rules obtained from the hosts' file. + rulesStrg *filterlist.RuleStorage + // engine serves rulesStrg. + engine *urlfilter.DNSEngine + + // Updates is the channel for receiving updated hosts. The receivable map's + // values has a type of slice of strings. + updates chan *netutil.IPMap + + // fsys is the working file system to read hosts files from. + fsys fs.FS + + // w tracks the changes in specified files and directories. + w aghos.FSWatcher + // patterns stores specified paths in the fs.Glob-compatible form. + patterns []string +} + +// errNoPaths is returned when there are no paths to watch passed to the +// HostsContainer. +const errNoPaths errors.Error = "hosts paths are empty" + +// NewHostsContainer creates a container of hosts, that watches the paths with +// w. paths shouldn't be empty and each of them should locate either a file or +// a directory in fsys. fsys and w must be non-nil. +func NewHostsContainer( + fsys fs.FS, + w aghos.FSWatcher, + paths ...string, +) (hc *HostsContainer, err error) { + defer func() { err = errors.Annotate(err, "%s: %w", hostsContainerPref) }() + + if len(paths) == 0 { + return nil, errNoPaths + } + + patterns, err := pathsToPatterns(fsys, paths) + if err != nil { + return nil, err + } + + hc = &HostsContainer{ + engLock: &sync.RWMutex{}, + updates: make(chan *netutil.IPMap, 1), + fsys: fsys, + w: w, + patterns: patterns, + } + + log.Debug("%s: starting", hostsContainerPref) + + // Load initially. + if err = hc.refresh(); err != nil { + return nil, err + } + + for _, p := range paths { + err = w.Add(p) + if err == nil { + continue + } else if errors.Is(err, fs.ErrNotExist) { + log.Debug("%s: file %q expected to exist but doesn't", hostsContainerPref, p) + + continue + } + + return nil, fmt.Errorf("adding path: %w", err) + } + + go hc.handleEvents() + + return hc, nil +} + +// MatchRequest is the request processing method to resolve hostnames and +// addresses from the operating system's hosts files. Any request not of A/AAAA +// or PTR type will return with an empty result. It's safe for concurrent use. +func (hc *HostsContainer) MatchRequest( + req urlfilter.DNSRequest, +) (res urlfilter.DNSResult, ok bool) { + switch req.DNSType { + case dns.TypeA, dns.TypeAAAA, dns.TypePTR: + log.Debug("%s: handling the request", hostsContainerPref) + default: + return urlfilter.DNSResult{}, false + } + + hc.engLock.RLock() + defer hc.engLock.RUnlock() + + return hc.engine.MatchRequest(req) +} + +// Close implements the io.Closer interface for *HostsContainer. +func (hc *HostsContainer) Close() (err error) { + log.Debug("%s: closing hosts container", hostsContainerPref) + + return errors.Annotate(hc.w.Close(), "%s: closing: %w", hostsContainerPref) +} + +// Upd returns the channel into which the updates are sent. +func (hc *HostsContainer) Upd() (updates <-chan *netutil.IPMap) { + return hc.updates +} + +// pathsToPatterns converts paths into patterns compatible with fs.Glob. +func pathsToPatterns(fsys fs.FS, paths []string) (patterns []string, err error) { + for i, p := range paths { + var fi fs.FileInfo + if fi, err = fs.Stat(fsys, p); err != nil { + return nil, fmt.Errorf("%q at index %d: %w", p, i, err) + } + + if fi.IsDir() { + p = path.Join(p, "*") + } + + patterns = append(patterns, p) + } + + return patterns, nil +} + +// handleEvents concurrently handles the events. It closes the update channel +// of HostsContainer when finishes. Used to be called within a goroutine. +func (hc *HostsContainer) handleEvents() { + defer log.OnPanic(fmt.Sprintf("%s: handling events", hostsContainerPref)) + + defer close(hc.updates) + + for range hc.w.Events() { + if err := hc.refresh(); err != nil { + log.Error("%s: %s", hostsContainerPref, err) + } + } +} + +// hostsParser is a helper type to parse rules from the operating system's hosts +// file. +type hostsParser struct { + // rules builds the resulting rules list content. + rules *strings.Builder + + // table stores only the unique IP-hostname pairs. It's also sent to the + // updates channel afterwards. + table *netutil.IPMap +} + +// parseHostsFile is a aghtest.FileWalker for parsing the files with hosts +// syntax. It never signs to stop the walking. +// +// See man hosts(5). +func (hp hostsParser) parseHostsFile( + r io.Reader, +) (patterns []string, cont bool, err error) { + s := bufio.NewScanner(r) + for s.Scan() { + ip, hosts := hp.parseLine(s.Text()) + if ip == nil { + continue + } + + for _, host := range hosts { + hp.addPair(ip, host) + } + } + + return nil, true, s.Err() +} + +// parseLine parses the line having the hosts syntax ignoring invalid ones. +func (hp hostsParser) parseLine(line string) (ip net.IP, hosts []string) { + line = strings.TrimSpace(line) + fields := strings.Fields(line) + if len(fields) < 2 { + return nil, nil + } + + if ip = net.ParseIP(fields[0]); ip == nil { + return nil, nil + } + +loop: + for _, f := range fields[1:] { + switch hashIdx := strings.IndexByte(f, '#'); hashIdx { + case 0: + // The rest of the fields are a part of the comment so skip + // immediately. + break loop + case -1: + hosts = append(hosts, f) + default: + // Only a part of the field is a comment. + hosts = append(hosts, f[:hashIdx]) + + break loop + } + } + + return ip, hosts +} + +// add returns true if the pair of ip and host wasn't added to the hp before. +func (hp hostsParser) add(ip net.IP, host string) (added bool) { + v, ok := hp.table.Get(ip) + hosts, _ := v.([]string) + if ok && stringutil.InSlice(hosts, host) { + return false + } + + hp.table.Set(ip, append(hosts, host)) + + return true +} + +// addPair puts the pair of ip and host to the rules builder if needed. +func (hp hostsParser) addPair(ip net.IP, host string) { + arpa, err := netutil.IPToReversedAddr(ip) + if err != nil { + return + } + + if !hp.add(ip, host) { + return + } + + qtype := "AAAA" + if ip.To4() != nil { + // Assume the validation of the IP address is performed already. + qtype = "A" + } + + stringutil.WriteToBuilder( + hp.rules, + "||", + host, + "^$dnsrewrite=NOERROR;", + qtype, + ";", + ip.String(), + "\n", + "||", + arpa, + "^$dnsrewrite=NOERROR;PTR;", + dns.Fqdn(host), + "\n", + ) + + log.Debug("%s: added ip-host pair %q/%q", hostsContainerPref, ip, host) +} + +// sendUpd tries to send the parsed data to the ch. +func (hp hostsParser) sendUpd(ch chan *netutil.IPMap) { + log.Debug("%s: sending upd", hostsContainerPref) + select { + case ch <- hp.table: + // Updates are delivered. Go on. + default: + log.Debug("%s: the buffer is full", hostsContainerPref) + } +} + +// newStrg creates a new rules storage from parsed data. +func (hp hostsParser) newStrg() (s *filterlist.RuleStorage, err error) { + return filterlist.NewRuleStorage([]filterlist.RuleList{&filterlist.StringRuleList{ + ID: 1, + RulesText: hp.rules.String(), + IgnoreCosmetic: true, + }}) +} + +// refresh gets the data from specified files and propagates the updates. +func (hc *HostsContainer) refresh() (err error) { + log.Debug("%s: refreshing", hostsContainerPref) + + hp := hostsParser{ + rules: &strings.Builder{}, + table: netutil.NewIPMap(0), + } + + _, err = aghos.FileWalker(hp.parseHostsFile).Walk(hc.fsys, hc.patterns...) + if err != nil { + return fmt.Errorf("updating: %w", err) + } + + defer hp.sendUpd(hc.updates) + + var rulesStrg *filterlist.RuleStorage + if rulesStrg, err = hp.newStrg(); err != nil { + return fmt.Errorf("initializing rules storage: %w", err) + } + + hc.resetEng(rulesStrg) + + return nil +} + +func (hc *HostsContainer) resetEng(rulesStrg *filterlist.RuleStorage) { + hc.engLock.Lock() + defer hc.engLock.Unlock() + + hc.rulesStrg = rulesStrg + hc.engine = urlfilter.NewDNSEngine(hc.rulesStrg) +} diff --git a/internal/aghnet/hostscontainer_linux.go b/internal/aghnet/hostscontainer_linux.go new file mode 100644 index 00000000..b456efdd --- /dev/null +++ b/internal/aghnet/hostscontainer_linux.go @@ -0,0 +1,18 @@ +//go:build linux +// +build linux + +package aghnet + +import ( + "github.com/AdguardTeam/AdGuardHome/internal/aghos" +) + +func defaultHostsPaths() (paths []string) { + paths = []string{"etc/hosts"} + + if aghos.IsOpenWrt() { + paths = append(paths, "tmp/hosts") + } + + return paths +} diff --git a/internal/aghnet/hostscontainer_others.go b/internal/aghnet/hostscontainer_others.go new file mode 100644 index 00000000..cd9f0fb7 --- /dev/null +++ b/internal/aghnet/hostscontainer_others.go @@ -0,0 +1,8 @@ +//go:build !(windows || linux) +// +build !windows,!linux + +package aghnet + +func defaultHostsPaths() (paths []string) { + return []string{"etc/hosts"} +} diff --git a/internal/aghnet/hostscontainer_test.go b/internal/aghnet/hostscontainer_test.go new file mode 100644 index 00000000..ce137c87 --- /dev/null +++ b/internal/aghnet/hostscontainer_test.go @@ -0,0 +1,504 @@ +package aghnet + +import ( + "io/fs" + "net" + "path" + "strings" + "sync/atomic" + "testing" + "testing/fstest" + + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/netutil" + "github.com/AdguardTeam/urlfilter" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + nl = "\n" + sp = " " +) + +const closeCalled errors.Error = "close method called" + +// fsWatcherOnCloseStub is a stub implementation of the Close method of +// aghos.FSWatcher. +func fsWatcherOnCloseStub() (err error) { + return closeCalled +} + +func TestNewHostsContainer(t *testing.T) { + const dirname = "dir" + const filename = "file1" + + p := path.Join(dirname, filename) + + testFS := fstest.MapFS{ + p: &fstest.MapFile{Data: []byte("127.0.0.1 localhost")}, + } + + testCases := []struct { + name string + paths []string + wantErr error + wantPatterns []string + }{{ + name: "one_file", + paths: []string{p}, + wantErr: nil, + wantPatterns: []string{p}, + }, { + name: "no_files", + paths: []string{}, + wantErr: errNoPaths, + wantPatterns: nil, + }, { + name: "non-existent_file", + paths: []string{path.Join(dirname, filename+"2")}, + wantErr: fs.ErrNotExist, + wantPatterns: nil, + }, { + name: "whole_dir", + paths: []string{dirname}, + wantErr: nil, + wantPatterns: []string{path.Join(dirname, "*")}, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + onAdd := func(name string) (err error) { + assert.Contains(t, tc.paths, name) + + return nil + } + + var eventsCalledCounter uint32 + eventsCh := make(chan struct{}) + onEvents := func() (e <-chan struct{}) { + assert.Equal(t, uint32(1), atomic.AddUint32(&eventsCalledCounter, 1)) + + return eventsCh + } + + hc, err := NewHostsContainer(testFS, &aghtest.FSWatcher{ + OnEvents: onEvents, + OnAdd: onAdd, + OnClose: fsWatcherOnCloseStub, + }, tc.paths...) + if tc.wantErr != nil { + require.ErrorIs(t, err, tc.wantErr) + + assert.Nil(t, hc) + + return + } + + require.NoError(t, err) + t.Cleanup(func() { + require.ErrorIs(t, hc.Close(), closeCalled) + }) + + require.NotNil(t, hc) + + assert.Equal(t, tc.wantPatterns, hc.patterns) + assert.NotNil(t, <-hc.Upd()) + + eventsCh <- struct{}{} + assert.Equal(t, uint32(1), atomic.LoadUint32(&eventsCalledCounter)) + }) + } + + t.Run("nil_fs", func(t *testing.T) { + require.Panics(t, func() { + _, _ = NewHostsContainer(nil, &aghtest.FSWatcher{ + // Those shouldn't panic. + OnEvents: func() (e <-chan struct{}) { return nil }, + OnAdd: func(name string) (err error) { return nil }, + OnClose: func() (err error) { return nil }, + }, p) + }) + }) + + t.Run("nil_watcher", func(t *testing.T) { + require.Panics(t, func() { + _, _ = NewHostsContainer(testFS, nil, p) + }) + }) + + t.Run("err_watcher", func(t *testing.T) { + const errOnAdd errors.Error = "error" + + errWatcher := &aghtest.FSWatcher{ + OnEvents: func() (e <-chan struct{}) { panic("not implemented") }, + OnAdd: func(name string) (err error) { return errOnAdd }, + OnClose: func() (err error) { panic("not implemented") }, + } + + hc, err := NewHostsContainer(testFS, errWatcher, p) + require.ErrorIs(t, err, errOnAdd) + + assert.Nil(t, hc) + }) +} + +func TestHostsContainer_Refresh(t *testing.T) { + knownIP := net.IP{127, 0, 0, 1} + + const knownHost = "localhost" + const knownAlias = "hocallost" + + const dirname = "dir" + const filename1 = "file1" + const filename2 = "file2" + + p1 := path.Join(dirname, filename1) + p2 := path.Join(dirname, filename2) + + testFS := fstest.MapFS{ + p1: &fstest.MapFile{ + Data: []byte(strings.Join([]string{knownIP.String(), knownHost}, sp) + nl), + }, + } + + eventsCh := make(chan struct{}, 1) + t.Cleanup(func() { close(eventsCh) }) + + w := &aghtest.FSWatcher{ + OnEvents: func() (e <-chan struct{}) { return eventsCh }, + OnAdd: func(name string) (err error) { + assert.Equal(t, dirname, name) + + return nil + }, + OnClose: fsWatcherOnCloseStub, + } + + hc, err := NewHostsContainer(testFS, w, dirname) + require.NoError(t, err) + t.Cleanup(func() { require.ErrorIs(t, hc.Close(), closeCalled) }) + + checkRefresh := func(t *testing.T, wantHosts []string) { + upd, ok := <-hc.Upd() + require.True(t, ok) + require.NotNil(t, upd) + + assert.Equal(t, 1, upd.Len()) + + v, ok := upd.Get(knownIP) + require.True(t, ok) + + var hosts []string + hosts, ok = v.([]string) + require.True(t, ok) + require.Len(t, hosts, len(wantHosts)) + + assert.Equal(t, wantHosts, hosts) + } + + t.Run("initial_refresh", func(t *testing.T) { + checkRefresh(t, []string{knownHost}) + }) + + testFS[p2] = &fstest.MapFile{ + Data: []byte(strings.Join([]string{knownIP.String(), knownAlias}, sp) + nl), + } + + eventsCh <- struct{}{} + + t.Run("second_refresh", func(t *testing.T) { + checkRefresh(t, []string{knownHost, knownAlias}) + }) +} + +func TestHostsContainer_MatchRequest(t *testing.T) { + var ( + ip4 = net.IP{127, 0, 0, 1} + ip6 = net.IP{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 1, + } + + hostname4 = "localhost" + hostname6 = "localhostv6" + hostname4a = "abcd" + + reversed4, _ = netutil.IPToReversedAddr(ip4) + reversed6, _ = netutil.IPToReversedAddr(ip6) + ) + + const filename = "file1" + + gsfs := fstest.MapFS{ + filename: &fstest.MapFile{Data: []byte( + strings.Join([]string{ip4.String(), hostname4, hostname4a}, sp) + nl + + strings.Join([]string{ip6.String(), hostname6}, sp) + nl + + strings.Join([]string{"256.256.256.256", "fakebroadcast"}, sp) + nl, + )}, + } + + hc, err := NewHostsContainer(gsfs, &aghtest.FSWatcher{ + OnEvents: func() (e <-chan struct{}) { panic("not implemented") }, + OnAdd: func(name string) (err error) { + assert.Equal(t, filename, name) + + return nil + }, + OnClose: fsWatcherOnCloseStub, + }, filename) + require.NoError(t, err) + t.Cleanup(func() { require.ErrorIs(t, hc.Close(), closeCalled) }) + + testCase := []struct { + name string + want interface{} + req urlfilter.DNSRequest + }{{ + name: "a", + want: ip4.To16(), + req: urlfilter.DNSRequest{ + Hostname: hostname4, + DNSType: dns.TypeA, + }, + }, { + name: "aaaa", + want: ip6, + req: urlfilter.DNSRequest{ + Hostname: hostname6, + DNSType: dns.TypeA, + }, + }, { + name: "ptr", + want: dns.Fqdn(hostname4), + req: urlfilter.DNSRequest{ + Hostname: reversed4, + DNSType: dns.TypePTR, + }, + }, { + name: "ptr_v6", + want: dns.Fqdn(hostname6), + req: urlfilter.DNSRequest{ + Hostname: reversed6, + DNSType: dns.TypePTR, + }, + }, { + name: "a_alias", + want: ip4.To16(), + req: urlfilter.DNSRequest{ + Hostname: hostname4a, + DNSType: dns.TypeA, + }, + }} + + for _, tc := range testCase { + t.Run(tc.name, func(t *testing.T) { + res, ok := hc.MatchRequest(tc.req) + require.False(t, ok) + + assert.Equal(t, tc.want, res.DNSRewrites()[0].DNSRewrite.Value) + }) + } + + t.Run("cname", func(t *testing.T) { + res, ok := hc.MatchRequest(urlfilter.DNSRequest{ + Hostname: hostname4, + DNSType: dns.TypeCNAME, + }) + require.False(t, ok) + + assert.Empty(t, res) + }) +} + +func TestHostsContainer_PathsToPatterns(t *testing.T) { + const ( + dir0 = "dir" + dir1 = "dir_1" + fn1 = "file_1" + fn2 = "file_2" + fn3 = "file_3" + fn4 = "file_4" + ) + + fp1 := path.Join(dir0, fn1) + fp2 := path.Join(dir0, fn2) + fp3 := path.Join(dir0, dir1, fn3) + + gsfs := fstest.MapFS{ + fp1: &fstest.MapFile{Data: []byte{1}}, + fp2: &fstest.MapFile{Data: []byte{2}}, + fp3: &fstest.MapFile{Data: []byte{3}}, + } + + testCases := []struct { + name string + wantErr error + want []string + paths []string + }{{ + name: "no_paths", + wantErr: nil, + want: nil, + paths: nil, + }, { + name: "single_file", + wantErr: nil, + want: []string{fp1}, + paths: []string{fp1}, + }, { + name: "several_files", + wantErr: nil, + want: []string{fp1, fp2}, + paths: []string{fp1, fp2}, + }, { + name: "whole_dir", + wantErr: nil, + want: []string{path.Join(dir0, "*")}, + paths: []string{dir0}, + }, { + name: "file_and_dir", + wantErr: nil, + want: []string{fp1, path.Join(dir0, dir1, "*")}, + paths: []string{fp1, path.Join(dir0, dir1)}, + }, { + name: "non-existing", + wantErr: fs.ErrNotExist, + want: nil, + paths: []string{path.Join(dir0, "file_3")}, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + patterns, err := pathsToPatterns(gsfs, tc.paths) + if tc.wantErr != nil { + assert.ErrorIs(t, err, tc.wantErr) + + return + } + + require.NoError(t, err) + + assert.Equal(t, tc.want, patterns) + }) + } +} + +func TestUniqueRules_AddPair(t *testing.T) { + knownIP := net.IP{1, 2, 3, 4} + + const knownHost = "host1" + + ipToHost := netutil.NewIPMap(0) + ipToHost.Set(knownIP, []string{knownHost}) + + testCases := []struct { + name string + host string + wantRules string + ip net.IP + }{{ + name: "new_one", + host: "host2", + wantRules: "||host2^$dnsrewrite=NOERROR;A;1.2.3.4\n" + + "||4.3.2.1.in-addr.arpa^$dnsrewrite=NOERROR;PTR;host2.\n", + ip: knownIP, + }, { + name: "existing_one", + host: knownHost, + wantRules: "", + ip: knownIP, + }, { + name: "new_ip", + host: knownHost, + wantRules: "||" + knownHost + "^$dnsrewrite=NOERROR;A;1.2.3.5\n" + + "||5.3.2.1.in-addr.arpa^$dnsrewrite=NOERROR;PTR;" + knownHost + ".\n", + ip: net.IP{1, 2, 3, 5}, + }, { + name: "bad_ip", + host: knownHost, + wantRules: "", + ip: net.IP{1, 2, 3, 4, 5}, + }} + + for _, tc := range testCases { + hp := hostsParser{ + rules: &strings.Builder{}, + table: ipToHost.ShallowClone(), + } + + t.Run(tc.name, func(t *testing.T) { + hp.addPair(tc.ip, tc.host) + assert.Equal(t, tc.wantRules, hp.rules.String()) + }) + } +} + +func TestUniqueRules_ParseLine(t *testing.T) { + const ( + hostname = "localhost" + alias = "hocallost" + ) + + knownIP := net.IP{127, 0, 0, 1} + + testCases := []struct { + name string + line string + wantIP net.IP + wantHosts []string + }{{ + name: "simple", + line: strings.Join([]string{knownIP.String(), hostname}, sp), + wantIP: knownIP, + wantHosts: []string{"localhost"}, + }, { + name: "aliases", + line: strings.Join([]string{knownIP.String(), hostname, alias}, sp), + wantIP: knownIP, + wantHosts: []string{"localhost", "hocallost"}, + }, { + name: "invalid_line", + line: knownIP.String(), + wantIP: nil, + wantHosts: nil, + }, { + name: "invalid_line_hostname", + line: strings.Join([]string{knownIP.String(), "#" + hostname}, sp), + wantIP: knownIP, + wantHosts: nil, + }, { + name: "commented_aliases", + line: strings.Join([]string{knownIP.String(), hostname, "#" + alias}, sp), + wantIP: knownIP, + wantHosts: []string{"localhost"}, + }, { + name: "whole_comment", + line: strings.Join([]string{"#", knownIP.String(), hostname}, sp), + wantIP: nil, + wantHosts: nil, + }, { + name: "partial_comment", + line: strings.Join([]string{knownIP.String(), hostname[:4] + "#" + hostname[4:]}, sp), + wantIP: knownIP, + wantHosts: []string{hostname[:4]}, + }, { + name: "empty", + line: ``, + wantIP: nil, + wantHosts: nil, + }} + + for _, tc := range testCases { + hp := hostsParser{} + t.Run(tc.name, func(t *testing.T) { + ip, hosts := hp.parseLine(tc.line) + assert.True(t, tc.wantIP.Equal(ip)) + assert.Equal(t, tc.wantHosts, hosts) + }) + } +} diff --git a/internal/aghnet/hostscontainer_windows.go b/internal/aghnet/hostscontainer_windows.go new file mode 100644 index 00000000..9eb2fe6f --- /dev/null +++ b/internal/aghnet/hostscontainer_windows.go @@ -0,0 +1,33 @@ +//go:build windows +// +build windows + +package aghnet + +import ( + "os" + "path" + "path/filepath" + "strings" + + "github.com/AdguardTeam/golibs/log" + "golang.org/x/sys/windows" +) + +func defaultHostsPaths() (paths []string) { + sysDir, err := windows.GetSystemDirectory() + if err != nil { + log.Error("getting system directory: %s", err) + + return []string{} + } + + // Split all the elements of the path to join them afterwards. This is + // needed to make the Windows-specific path string returned by + // windows.GetSystemDirectory to be compatible with fs.FS. + pathElems := strings.Split(sysDir, string(os.PathSeparator)) + if len(pathElems) > 0 && pathElems[0] == filepath.VolumeName(sysDir) { + pathElems = pathElems[1:] + } + + return []string{path.Join(append(pathElems, "drivers/etc/hosts")...)} +} diff --git a/internal/aghnet/net_freebsd.go b/internal/aghnet/net_freebsd.go index f4c106f1..a5200fb8 100644 --- a/internal/aghnet/net_freebsd.go +++ b/internal/aghnet/net_freebsd.go @@ -18,9 +18,11 @@ func canBindPrivilegedPorts() (can bool, err error) { } func ifaceHasStaticIP(ifaceName string) (ok bool, err error) { - const filename = "/etc/rc.conf" + const rcConfFilename = "etc/rc.conf" - return aghos.FileWalker(interfaceName(ifaceName).rcConfStaticConfig).Walk(filename) + walker := aghos.FileWalker(interfaceName(ifaceName).rcConfStaticConfig) + + return walker.Walk(aghos.RootDirFS(), rcConfFilename) } // rcConfStaticConfig checks if the interface is configured by /etc/rc.conf to diff --git a/internal/aghnet/net_linux.go b/internal/aghnet/net_linux.go index 3c6a6659..93414165 100644 --- a/internal/aghnet/net_linux.go +++ b/internal/aghnet/net_linux.go @@ -85,17 +85,17 @@ func ifaceHasStaticIP(ifaceName string) (has bool, err error) { iface := interfaceName(ifaceName) - for _, pair := range []struct { + for _, pair := range [...]struct { aghos.FileWalker filename string }{{ FileWalker: iface.dhcpcdStaticConfig, - filename: "/etc/dhcpcd.conf", + filename: "etc/dhcpcd.conf", }, { FileWalker: iface.ifacesStaticConfig, - filename: "/etc/network/interfaces", + filename: "etc/network/interfaces", }} { - has, err = pair.Walk(pair.filename) + has, err = pair.Walk(aghos.RootDirFS(), pair.filename) if err != nil { return false, err } diff --git a/internal/aghnet/net_linux_test.go b/internal/aghnet/net_linux_test.go index a907819c..bf2cecfe 100644 --- a/internal/aghnet/net_linux_test.go +++ b/internal/aghnet/net_linux_test.go @@ -12,8 +12,6 @@ import ( "github.com/stretchr/testify/require" ) -const nl = "\n" - func TestDHCPCDStaticConfig(t *testing.T) { const iface interfaceName = `wlan0` diff --git a/internal/aghnet/net_openbsd.go b/internal/aghnet/net_openbsd.go index a2b5e6e6..627db0ab 100644 --- a/internal/aghnet/net_openbsd.go +++ b/internal/aghnet/net_openbsd.go @@ -18,9 +18,9 @@ func canBindPrivilegedPorts() (can bool, err error) { } func ifaceHasStaticIP(ifaceName string) (ok bool, err error) { - filename := fmt.Sprintf("/etc/hostname.%s", ifaceName) + filename := fmt.Sprintf("etc/hostname.%s", ifaceName) - return aghos.FileWalker(hostnameIfStaticConfig).Walk(filename) + return aghos.FileWalker(hostnameIfStaticConfig).Walk(aghos.RootDirFS(), filename) } // hostnameIfStaticConfig checks if the interface is configured by diff --git a/internal/aghnet/net_test.go b/internal/aghnet/net_test.go index 0773f67f..38fdf9cc 100644 --- a/internal/aghnet/net_test.go +++ b/internal/aghnet/net_test.go @@ -4,10 +4,15 @@ import ( "net" "testing" + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func TestMain(m *testing.M) { + aghtest.DiscardLogOutput(m) +} + func TestGetValidNetInterfacesForWeb(t *testing.T) { ifaces, err := GetValidNetInterfacesForWeb() require.NoErrorf(t, err, "cannot get net interfaces: %s", err) diff --git a/internal/aghos/filewalker.go b/internal/aghos/filewalker.go index b6473d1f..a72e2f5c 100644 --- a/internal/aghos/filewalker.go +++ b/internal/aghos/filewalker.go @@ -3,8 +3,7 @@ package aghos import ( "fmt" "io" - "os" - "path/filepath" + "io/fs" "github.com/AdguardTeam/AdGuardHome/internal/aghio" "github.com/AdguardTeam/golibs/errors" @@ -14,10 +13,10 @@ import ( // FileWalker is the signature of a function called for files in the file tree. // As opposed to filepath.Walk it only walk the files (not directories) matching // the provided pattern and those returned by function itself. All patterns -// should be valid for filepath.Glob. If cont is false, the walking terminates. -// Each opened file is also limited for reading to MaxWalkedFileSize. +// should be valid for fs.Glob. If cont is false, the walking terminates. Each +// opened file is also limited for reading to MaxWalkedFileSize. // -// TODO(e.burkov): Consider moving to the separate package like pathutil. +// TODO(e.burkov, a.garipov): Move into another package like aghfs. // // TODO(e.burkov): Think about passing filename or any additional data. type FileWalker func(r io.Reader) (patterns []string, cont bool, err error) @@ -26,15 +25,19 @@ type FileWalker func(r io.Reader) (patterns []string, cont bool, err error) // check. const MaxWalkedFileSize = 1024 * 1024 -// checkFile tries to open and process a single file located on sourcePath. -func checkFile(c FileWalker, sourcePath string) (patterns []string, cont bool, err error) { - var f *os.File - f, err = os.Open(sourcePath) +// checkFile tries to open and process a single file located on sourcePath in +// the specified fsys. The path is skipped if it's a directory. +func checkFile( + fsys fs.FS, + c FileWalker, + sourcePath string, +) (patterns []string, cont bool, err error) { + var f fs.File + f, err = fsys.Open(sourcePath) if err != nil { - if errors.Is(err, os.ErrNotExist) { - // Ignore non-existing files since this may only happen - // when the file was removed after filepath.Glob matched - // it. + if errors.Is(err, fs.ErrNotExist) { + // Ignore non-existing files since this may only happen when the + // file was removed after filepath.Glob matched it. return nil, true, nil } @@ -42,9 +45,18 @@ func checkFile(c FileWalker, sourcePath string) (patterns []string, cont bool, e } defer func() { err = errors.WithDeferred(err, f.Close()) }() + var fi fs.FileInfo + if fi, err = f.Stat(); err != nil { + return nil, true, err + } + if fi.IsDir() { + // Skip the directories. + return nil, true, nil + } + var r io.Reader - // Ignore the error since LimitReader function returns error only if - // passed limit value is less than zero, but the constant used. + // Ignore the error since LimitReader function returns error only if passed + // limit value is less than zero, but the constant used. // // TODO(e.burkov): Make variable. r, _ = aghio.LimitReader(f, MaxWalkedFileSize) @@ -52,13 +64,17 @@ func checkFile(c FileWalker, sourcePath string) (patterns []string, cont bool, e return c(r) } -// handlePatterns parses the patterns and ignores duplicates using srcSet. -// srcSet must be non-nil. -func handlePatterns(srcSet *stringutil.Set, patterns ...string) (sub []string, err error) { +// handlePatterns parses the patterns in fsys and ignores duplicates using +// srcSet. srcSet must be non-nil. +func handlePatterns( + fsys fs.FS, + srcSet *stringutil.Set, + patterns ...string, +) (sub []string, err error) { sub = make([]string, 0, len(patterns)) for _, p := range patterns { var matches []string - matches, err = filepath.Glob(p) + matches, err = fs.Glob(fsys, p) if err != nil { // Enrich error with the pattern because filepath.Glob // doesn't do it. @@ -78,14 +94,14 @@ func handlePatterns(srcSet *stringutil.Set, patterns ...string) (sub []string, e return sub, nil } -// Walk starts walking the files defined by initPattern. It only returns true -// if c signed to stop walking. -func (c FileWalker) Walk(initPattern string) (ok bool, err error) { - // The slice of sources keeps the order in which the files are walked - // since srcSet.Values() returns strings in undefined order. +// Walk starts walking the files in fsys defined by patterns from initial. +// It only returns true if fw signed to stop walking. +func (fw FileWalker) Walk(fsys fs.FS, initial ...string) (ok bool, err error) { + // The slice of sources keeps the order in which the files are walked since + // srcSet.Values() returns strings in undefined order. srcSet := stringutil.NewSet() var src []string - src, err = handlePatterns(srcSet, initPattern) + src, err = handlePatterns(fsys, srcSet, initial...) if err != nil { return false, err } @@ -97,7 +113,7 @@ func (c FileWalker) Walk(initPattern string) (ok bool, err error) { var patterns []string var cont bool filename = src[i] - patterns, cont, err = checkFile(c, src[i]) + patterns, cont, err = checkFile(fsys, fw, src[i]) if err != nil { return false, err } @@ -107,7 +123,7 @@ func (c FileWalker) Walk(initPattern string) (ok bool, err error) { } var subsrc []string - subsrc, err = handlePatterns(srcSet, patterns...) + subsrc, err = handlePatterns(fsys, srcSet, patterns...) if err != nil { return false, err } diff --git a/internal/aghos/filewalker_test.go b/internal/aghos/filewalker_test.go index 4ba1db20..97d1a845 100644 --- a/internal/aghos/filewalker_test.go +++ b/internal/aghos/filewalker_test.go @@ -4,56 +4,19 @@ import ( "bufio" "io" "io/fs" - "os" - "path/filepath" + "path" "testing" + "testing/fstest" "github.com/AdguardTeam/golibs/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -// testFSDir maps entries' names to entries which should either be a testFSDir -// or byte slice. -type testFSDir map[string]interface{} - -// testFSGen is used to generate a temporary filesystem consisting of -// directories and plain text files from itself. -type testFSGen testFSDir - -// gen returns the name of top directory of the generated filesystem. -func (g testFSGen) gen(t *testing.T) (dirName string) { - t.Helper() - - dirName = t.TempDir() - g.rangeThrough(t, dirName) - - return dirName -} - -func (g testFSGen) rangeThrough(t *testing.T, dirName string) { - const perm fs.FileMode = 0o777 - - for k, e := range g { - switch e := e.(type) { - case []byte: - require.NoError(t, os.WriteFile(filepath.Join(dirName, k), e, perm)) - - case testFSDir: - newDir := filepath.Join(dirName, k) - require.NoError(t, os.Mkdir(newDir, perm)) - - testFSGen(e).rangeThrough(t, newDir) - default: - t.Fatalf("unexpected entry type %T", e) - } - } -} - func TestFileWalker_Walk(t *testing.T) { const attribute = `000` - makeFileWalker := func(dirName string) (fw FileWalker) { + makeFileWalker := func(_ string) (fw FileWalker) { return func(r io.Reader) (patterns []string, cont bool, err error) { s := bufio.NewScanner(r) for s.Scan() { @@ -63,7 +26,7 @@ func TestFileWalker_Walk(t *testing.T) { } if len(line) != 0 { - patterns = append(patterns, filepath.Join(dirName, line)) + patterns = append(patterns, path.Join(".", line)) } } @@ -74,136 +37,150 @@ func TestFileWalker_Walk(t *testing.T) { const nl = "\n" testCases := []struct { - name string - testFS testFSGen + testFS fstest.MapFS + want assert.BoolAssertionFunc initPattern string - want bool + name string }{{ name: "simple", - testFS: testFSGen{ - "simple_0001.txt": []byte(attribute + nl), + testFS: fstest.MapFS{ + "simple_0001.txt": &fstest.MapFile{Data: []byte(attribute + nl)}, }, initPattern: "simple_0001.txt", - want: true, + want: assert.True, }, { name: "chain", - testFS: testFSGen{ - "chain_0001.txt": []byte(`chain_0002.txt` + nl), - "chain_0002.txt": []byte(`chain_0003.txt` + nl), - "chain_0003.txt": []byte(attribute + nl), + testFS: fstest.MapFS{ + "chain_0001.txt": &fstest.MapFile{Data: []byte(`chain_0002.txt` + nl)}, + "chain_0002.txt": &fstest.MapFile{Data: []byte(`chain_0003.txt` + nl)}, + "chain_0003.txt": &fstest.MapFile{Data: []byte(attribute + nl)}, }, initPattern: "chain_0001.txt", - want: true, + want: assert.True, }, { name: "several", - testFS: testFSGen{ - "several_0001.txt": []byte(`several_*` + nl), - "several_0002.txt": []byte(`several_0001.txt` + nl), - "several_0003.txt": []byte(attribute + nl), + testFS: fstest.MapFS{ + "several_0001.txt": &fstest.MapFile{Data: []byte(`several_*` + nl)}, + "several_0002.txt": &fstest.MapFile{Data: []byte(`several_0001.txt` + nl)}, + "several_0003.txt": &fstest.MapFile{Data: []byte(attribute + nl)}, }, initPattern: "several_0001.txt", - want: true, + want: assert.True, }, { name: "no", - testFS: testFSGen{ - "no_0001.txt": []byte(nl), - "no_0002.txt": []byte(nl), - "no_0003.txt": []byte(nl), + testFS: fstest.MapFS{ + "no_0001.txt": &fstest.MapFile{Data: []byte(nl)}, + "no_0002.txt": &fstest.MapFile{Data: []byte(nl)}, + "no_0003.txt": &fstest.MapFile{Data: []byte(nl)}, }, initPattern: "no_*", - want: false, + want: assert.False, }, { name: "subdirectory", - testFS: testFSGen{ - "dir": testFSDir{ - "subdir_0002.txt": []byte(attribute + nl), + testFS: fstest.MapFS{ + path.Join("dir", "subdir_0002.txt"): &fstest.MapFile{ + Data: []byte(attribute + nl), }, - "subdir_0001.txt": []byte(`dir/*`), + "subdir_0001.txt": &fstest.MapFile{Data: []byte(`dir/*`)}, }, initPattern: "subdir_0001.txt", - want: true, + want: assert.True, }} for _, tc := range testCases { - testDir := tc.testFS.gen(t) - fw := makeFileWalker(testDir) + fw := makeFileWalker("") t.Run(tc.name, func(t *testing.T) { - ok, err := fw.Walk(filepath.Join(testDir, tc.initPattern)) + ok, err := fw.Walk(tc.testFS, tc.initPattern) require.NoError(t, err) - assert.Equal(t, tc.want, ok) + tc.want(t, ok) }) } t.Run("pattern_malformed", func(t *testing.T) { - ok, err := makeFileWalker("").Walk("[]") + f := fstest.MapFS{} + ok, err := makeFileWalker("").Walk(f, "[]") require.Error(t, err) assert.False(t, ok) - assert.ErrorIs(t, err, filepath.ErrBadPattern) + assert.ErrorIs(t, err, path.ErrBadPattern) }) t.Run("bad_filename", func(t *testing.T) { - dir := testFSGen{ - "bad_filename.txt": []byte("[]"), - }.gen(t) - fw := FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) { + const filename = "bad_filename.txt" + + f := fstest.MapFS{ + filename: &fstest.MapFile{Data: []byte("[]")}, + } + ok, err := FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) { s := bufio.NewScanner(r) for s.Scan() { patterns = append(patterns, s.Text()) } return patterns, true, s.Err() - }) - - ok, err := fw.Walk(filepath.Join(dir, "bad_filename.txt")) + }).Walk(f, filename) require.Error(t, err) assert.False(t, ok) - assert.ErrorIs(t, err, filepath.ErrBadPattern) + assert.ErrorIs(t, err, path.ErrBadPattern) }) t.Run("itself_error", func(t *testing.T) { const rerr errors.Error = "returned error" - dir := testFSGen{ - "mockfile.txt": []byte(`mockdata`), - }.gen(t) + f := fstest.MapFS{ + "mockfile.txt": &fstest.MapFile{Data: []byte(`mockdata`)}, + } ok, err := FileWalker(func(r io.Reader) (patterns []string, ok bool, err error) { return nil, true, rerr - }).Walk(filepath.Join(dir, "*")) - require.Error(t, err) - require.False(t, ok) + }).Walk(f, "*") + require.ErrorIs(t, err, rerr) - assert.ErrorIs(t, err, rerr) + assert.False(t, ok) }) } +type errFS struct { + fs.GlobFS +} + +const errErrFSOpen errors.Error = "this error is always returned" + +func (efs *errFS) Open(name string) (fs.File, error) { + return nil, errErrFSOpen +} + func TestWalkerFunc_CheckFile(t *testing.T) { + emptyFS := fstest.MapFS{} + t.Run("non-existing", func(t *testing.T) { - _, ok, err := checkFile(nil, "lol") + _, ok, err := checkFile(emptyFS, nil, "lol") require.NoError(t, err) assert.True(t, ok) }) t.Run("invalid_argument", func(t *testing.T) { - const badPath = "\x00" - - _, ok, err := checkFile(nil, badPath) - require.Error(t, err) + _, ok, err := checkFile(&errFS{}, nil, "") + require.ErrorIs(t, err, errErrFSOpen) assert.False(t, ok) - // TODO(e.burkov): Use assert.ErrorsIs within the error from - // less platform-dependent package instead of syscall.EINVAL. - // - // See https://github.com/golang/go/issues/46849 and - // https://github.com/golang/go/issues/30322. - pathErr := &os.PathError{} - require.ErrorAs(t, err, &pathErr) - assert.Equal(t, "open", pathErr.Op) - assert.Equal(t, badPath, pathErr.Path) + }) + + t.Run("ignore_dirs", func(t *testing.T) { + const dirName = "dir" + + testFS := fstest.MapFS{ + path.Join(dirName, "file"): &fstest.MapFile{Data: []byte{}}, + } + + patterns, ok, err := checkFile(testFS, nil, dirName) + require.NoError(t, err) + + assert.Empty(t, patterns) + assert.True(t, ok) }) } diff --git a/internal/aghos/fswatcher.go b/internal/aghos/fswatcher.go new file mode 100644 index 00000000..a113610f --- /dev/null +++ b/internal/aghos/fswatcher.go @@ -0,0 +1,131 @@ +package aghos + +import ( + "fmt" + "io" + "io/fs" + "path/filepath" + + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/log" + "github.com/fsnotify/fsnotify" +) + +// event is a convenient alias for an empty struct to signal that watching +// event happened. +type event = struct{} + +// FSWatcher tracks all the fyle system events and notifies about those. +// +// TODO(e.burkov, a.garipov): Move into another package like aghfs. +type FSWatcher interface { + io.Closer + + // Events should return a read-only channel which notifies about events. + Events() (e <-chan event) + + // Add should check if the file named name is accessible and starts tracking + // it. + Add(name string) (err error) +} + +// osWatcher tracks the file system provided by the OS. +type osWatcher struct { + // w is the actual notifier that is handled by osWatcher. + w *fsnotify.Watcher + + // events is the channel to notify. + events chan event +} + +const ( + // osWatcherPref is a prefix for logging and wrapping errors in osWathcer's + // methods. + osWatcherPref = "os watcher" +) + +// NewOSWritesWatcher creates FSWatcher that tracks the real file system of the +// OS and notifies only about writing events. +func NewOSWritesWatcher() (w FSWatcher, err error) { + defer func() { err = errors.Annotate(err, "%s: %w", osWatcherPref) }() + + var watcher *fsnotify.Watcher + watcher, err = fsnotify.NewWatcher() + if err != nil { + return nil, fmt.Errorf("creating watcher: %w", err) + } + + fsw := &osWatcher{ + w: watcher, + events: make(chan event, 1), + } + + go fsw.handleErrors() + go fsw.handleEvents() + + return fsw, nil +} + +// handleErrors handles accompanying errors. It used to be called in a separate +// goroutine. +func (w *osWatcher) handleErrors() { + defer log.OnPanic(fmt.Sprintf("%s: handling errors", osWatcherPref)) + + for err := range w.w.Errors { + log.Error("%s: %s", osWatcherPref, err) + } +} + +// Events implements the FSWatcher interface for *osWatcher. +func (w *osWatcher) Events() (e <-chan event) { + return w.events +} + +// Add implements the FSWatcher interface for *osWatcher. +func (w *osWatcher) Add(name string) (err error) { + defer func() { err = errors.Annotate(err, "%s: %w", osWatcherPref) }() + + if _, err = fs.Stat(RootDirFS(), name); err != nil { + return fmt.Errorf("checking file %q: %w", name, err) + } + + return w.w.Add(filepath.Join("/", name)) +} + +// Close implements the FSWatcher interface for *osWatcher. +func (w *osWatcher) Close() (err error) { + return w.w.Close() +} + +// handleEvents notifies about the received file system's event if needed. It +// used to be called in a separate goroutine. +func (w *osWatcher) handleEvents() { + defer log.OnPanic(fmt.Sprintf("%s: handling events", osWatcherPref)) + + defer close(w.events) + + ch := w.w.Events + for e := range ch { + if e.Op&fsnotify.Write == 0 { + continue + } + + // Skip the following events assuming that sometimes the same event + // occurrs several times. + for ok := true; ok; { + select { + case _, ok = <-ch: + // Go on. + default: + ok = false + } + } + + select { + case w.events <- event{}: + // Go on. + default: + log.Debug("%s: events buffer is full", osWatcherPref) + } + } +} diff --git a/internal/aghos/os.go b/internal/aghos/os.go index 501f7d66..1723e001 100644 --- a/internal/aghos/os.go +++ b/internal/aghos/os.go @@ -7,6 +7,8 @@ import ( "bufio" "fmt" "io" + "io/fs" + "os" "os/exec" "path" "runtime" @@ -159,3 +161,10 @@ ScanLoop: func IsOpenWrt() (ok bool) { return isOpenWrt() } + +// RootDirFS returns the fs.FS rooted at the operating system's root. +func RootDirFS() (fsys fs.FS) { + // Use empty string since os.DirFS implicitly prepends a slash to it. This + // behavior is undocumented but it currently works. + return os.DirFS("") +} diff --git a/internal/aghos/os_linux.go b/internal/aghos/os_linux.go index 349fc18a..ed513b00 100644 --- a/internal/aghos/os_linux.go +++ b/internal/aghos/os_linux.go @@ -26,6 +26,8 @@ func haveAdminRights() (bool, error) { } func isOpenWrt() (ok bool) { + const etcReleasePattern = "etc/*release*" + var err error ok, err = FileWalker(func(r io.Reader) (_ []string, cont bool, err error) { const osNameData = "openwrt" @@ -39,7 +41,7 @@ func isOpenWrt() (ok bool) { } return nil, !stringutil.ContainsFold(string(data), osNameData), nil - }).Walk("/etc/*release*") + }).Walk(RootDirFS(), etcReleasePattern) return err == nil && ok } diff --git a/internal/aghtest/fswatcher.go b/internal/aghtest/fswatcher.go new file mode 100644 index 00000000..0df4470d --- /dev/null +++ b/internal/aghtest/fswatcher.go @@ -0,0 +1,23 @@ +package aghtest + +// FSWatcher is a mock aghos.FSWatcher implementation to use in tests. +type FSWatcher struct { + OnEvents func() (e <-chan struct{}) + OnAdd func(name string) (err error) + OnClose func() (err error) +} + +// Events implements the aghos.FSWatcher interface for *FSWatcher. +func (w *FSWatcher) Events() (e <-chan struct{}) { + return w.OnEvents() +} + +// Add implements the aghos.FSWatcher interface for *FSWatcher. +func (w *FSWatcher) Add(name string) (err error) { + return w.OnAdd(name) +} + +// Close implements the aghos.FSWatcher interface for *FSWatcher. +func (w *FSWatcher) Close() (err error) { + return w.OnClose() +} diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index 397f496e..3ab76a40 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -11,9 +11,10 @@ import ( "fmt" "math/big" "net" - "os" "sync" + "sync/atomic" "testing" + "testing/fstest" "time" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" @@ -1057,23 +1058,43 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) { } func TestPTRResponseFromHosts(t *testing.T) { - c := filtering.Config{ - EtcHosts: &aghnet.EtcHostsContainer{}, + // Prepare test hosts file. + + const hostsFilename = "hosts" + + testFS := fstest.MapFS{ + hostsFilename: &fstest.MapFile{Data: []byte(` + 127.0.0.1 host # comment + ::1 localhost#comment + `)}, } - // Prepare test hosts file. - hf, err := os.CreateTemp("", "") + const closeCalled errors.Error = "close method called" + + var eventsCalledCounter uint32 + hc, err := aghnet.NewHostsContainer(testFS, &aghtest.FSWatcher{ + OnEvents: func() (e <-chan struct{}) { + assert.Equal(t, uint32(1), atomic.AddUint32(&eventsCalledCounter, 1)) + + return nil + }, + OnAdd: func(name string) (err error) { + assert.Equal(t, hostsFilename, name) + + return nil + }, + OnClose: func() (err error) { return closeCalled }, + }, hostsFilename) require.NoError(t, err) t.Cleanup(func() { - assert.NoError(t, hf.Close()) - assert.NoError(t, os.Remove(hf.Name())) + assert.Equal(t, uint32(1), atomic.LoadUint32(&eventsCalledCounter)) + + require.ErrorIs(t, hc.Close(), closeCalled) }) - _, _ = hf.WriteString(" 127.0.0.1 host # comment \n") - _, _ = hf.WriteString(" ::1 localhost#comment \n") - - c.EtcHosts.Init(hf.Name()) - t.Cleanup(c.EtcHosts.Close) + c := filtering.Config{ + EtcHosts: hc, + } var snd *aghnet.SubnetDetector snd, err = aghnet.NewSubnetDetector() @@ -1109,7 +1130,7 @@ func TestPTRResponseFromHosts(t *testing.T) { resp, err := dns.Exchange(req, addr.String()) require.NoError(t, err) - require.Len(t, resp.Answer, 1) + require.Lenf(t, resp.Answer, 1, "%#v", resp) assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype) assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name) diff --git a/internal/filtering/dnsrewrite.go b/internal/filtering/dnsrewrite.go index 5cefe01d..e98dfa3d 100644 --- a/internal/filtering/dnsrewrite.go +++ b/internal/filtering/dnsrewrite.go @@ -15,9 +15,9 @@ type DNSRewriteResult struct { // the server returns. type DNSRewriteResultResponse map[rules.RRType][]rules.RRValue -// processDNSRewrites processes DNS rewrite rules in dnsr. It returns -// an empty result if dnsr is empty. Otherwise, the result will have -// either CanonName or DNSRewriteResult set. +// processDNSRewrites processes DNS rewrite rules in dnsr. It returns an empty +// result if dnsr is empty. Otherwise, the result will have either CanonName or +// DNSRewriteResult set. func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) { if len(dnsr) == 0 { return Result{} diff --git a/internal/filtering/filtering.go b/internal/filtering/filtering.go index 4db69cd2..f7a8ebe4 100644 --- a/internal/filtering/filtering.go +++ b/internal/filtering/filtering.go @@ -73,7 +73,7 @@ type Config struct { // EtcHosts is a container of IP-hostname pairs taken from the operating // system configuration files (e.g. /etc/hosts). - EtcHosts *aghnet.EtcHostsContainer `yaml:"-"` + EtcHosts *aghnet.HostsContainer `yaml:"-"` // Called when the configuration is changed by HTTP request ConfigModified func() `yaml:"-"` @@ -176,8 +176,8 @@ const ( // FilteredBlockedService - the host is blocked by "blocked services" settings FilteredBlockedService - // Rewritten is returned when there was a rewrite by a legacy DNS - // rewrite rule. + // Rewritten is returned when there was a rewrite by a legacy DNS rewrite + // rule. Rewritten // RewrittenAutoHosts is returned when there was a rewrite by autohosts @@ -186,8 +186,8 @@ const ( // RewrittenRule is returned when a $dnsrewrite filter rule was applied. // - // TODO(a.garipov): Remove Rewritten and RewrittenAutoHosts by merging - // their functionality into RewrittenRule. + // TODO(a.garipov): Remove Rewritten and RewrittenAutoHosts by merging their + // functionality into RewrittenRule. // // See https://github.com/AdguardTeam/AdGuardHome/issues/2499. RewrittenRule @@ -371,24 +371,23 @@ type Result struct { // Reason is the reason for blocking or unblocking the request. Reason Reason `json:",omitempty"` - // Rules are applied rules. If Rules are not empty, each rule - // is not nil. + // Rules are applied rules. If Rules are not empty, each rule is not nil. Rules []*ResultRule `json:",omitempty"` - // ReverseHosts is the reverse lookup rewrite result. It is - // empty unless Reason is set to RewrittenAutoHosts. + // ReverseHosts is the reverse lookup rewrite result. It is empty unless + // Reason is set to RewrittenAutoHosts. ReverseHosts []string `json:",omitempty"` - // IPList is the lookup rewrite result. It is empty unless - // Reason is set to RewrittenAutoHosts or Rewritten. + // IPList is the lookup rewrite result. It is empty unless Reason is set to + // RewrittenAutoHosts or Rewritten. IPList []net.IP `json:",omitempty"` - // CanonName is the CNAME value from the lookup rewrite result. - // It is empty unless Reason is set to Rewritten or RewrittenRule. + // CanonName is the CNAME value from the lookup rewrite result. It is empty + // unless Reason is set to Rewritten or RewrittenRule. CanonName string `json:",omitempty"` - // ServiceName is the name of the blocked service. It is empty - // unless Reason is set to FilteredBlockedService. + // ServiceName is the name of the blocked service. It is empty unless + // Reason is set to FilteredBlockedService. ServiceName string `json:",omitempty"` // DNSRewriteResult is the $dnsrewrite filter rule result. @@ -446,43 +445,49 @@ func (d *DNSFilter) CheckHost( return Result{}, nil } -// checkEtcHosts compares the host against our /etc/hosts table. The err is -// always nil, it is only there to make this a valid hostChecker function. -func (d *DNSFilter) checkEtcHosts( - host string, - qtype uint16, - _ *Settings, -) (res Result, err error) { - if d.Config.EtcHosts == nil { +// matchSysHosts tries to match the host against the operating system's hosts +// database. +func (d *DNSFilter) matchSysHosts(host string, qtype uint16, setts *Settings) (res Result, err error) { + if d.EtcHosts == nil { return Result{}, nil } - ips := d.Config.EtcHosts.Process(host, qtype) - if ips != nil { - res = Result{ - Reason: RewrittenAutoHosts, - IPList: ips, - } + dnsres, _ := d.EtcHosts.MatchRequest(urlfilter.DNSRequest{ + Hostname: host, + SortedClientTags: setts.ClientTags, + // TODO(e.burkov): Wait for urlfilter update to pass net.IP. + ClientIP: setts.ClientIP.String(), + ClientName: setts.ClientName, + DNSType: qtype, + }) - return res, nil + dnsr := dnsres.DNSRewrites() + if len(dnsr) == 0 { + return Result{}, nil } - revHosts := d.Config.EtcHosts.ProcessReverse(host, qtype) - if len(revHosts) != 0 { - res = Result{ - Reason: RewrittenAutoHosts, + var ips []net.IP + var revHosts []string + + for _, nr := range dnsr { + dr := nr.DNSRewrite + if dr == nil { + continue } - // TODO(a.garipov): Optimize this with a buffer. - res.ReverseHosts = make([]string, len(revHosts)) - for i := range revHosts { - res.ReverseHosts[i] = revHosts[i] + "." + switch val := nr.DNSRewrite.Value.(type) { + case net.IP: + ips = append(ips, val) + case string: + revHosts = append(revHosts, val) } - - return res, nil } - return Result{}, nil + return Result{ + Reason: RewrittenAutoHosts, + IPList: ips, + ReverseHosts: revHosts, + }, nil } // Process rewrites table @@ -647,15 +652,18 @@ func (d *DNSFilter) initFiltering(allowFilters, blockFilters []Filter) error { return err } - d.engineLock.Lock() - d.reset() - d.rulesStorage = rulesStorage - d.filteringEngine = filteringEngine - d.rulesStorageAllow = rulesStorageAllow - d.filteringEngineAllow = filteringEngineAllow - d.engineLock.Unlock() + func() { + d.engineLock.Lock() + defer d.engineLock.Unlock() - // Make sure that the OS reclaims memory as soon as possible + d.reset() + d.rulesStorage = rulesStorage + d.filteringEngine = filteringEngine + d.rulesStorageAllow = rulesStorageAllow + d.filteringEngineAllow = filteringEngineAllow + }() + + // Make sure that the OS reclaims memory as soon as possible. debug.FreeOSMemory() log.Debug("initialized filtering engine") @@ -734,8 +742,8 @@ func (d *DNSFilter) matchHostProcessDNSResult( } if dnsres.HostRulesV4 != nil || dnsres.HostRulesV6 != nil { - // Question type doesn't match the host rules. Return the first - // matched host rule, but without an IP address. + // Question type doesn't match the host rules. Return the first matched + // host rule, but without an IP address. var matchedRules []rules.Rule if dnsres.HostRulesV4 != nil { matchedRules = []rules.Rule{dnsres.HostRulesV4[0]} @@ -760,11 +768,6 @@ func (d *DNSFilter) matchHost( return Result{}, nil } - d.engineLock.RLock() - // Keep in mind that this lock must be held no just when calling Match() - // but also while using the rules returned by it. - defer d.engineLock.RUnlock() - ureq := urlfilter.DNSRequest{ Hostname: host, SortedClientTags: setts.ClientTags, @@ -774,6 +777,13 @@ func (d *DNSFilter) matchHost( DNSType: qtype, } + d.engineLock.RLock() + // Keep in mind that this lock must be held no just when calling Match() but + // also while using the rules returned by it. + // + // TODO(e.burkov): Inspect if the above is true. + defer d.engineLock.RUnlock() + if d.filteringEngineAllow != nil { dnsres, ok := d.filteringEngineAllow.MatchRequest(ureq) if ok { @@ -791,8 +801,8 @@ func (d *DNSFilter) matchHost( if dnsr := dnsres.DNSRewrites(); len(dnsr) > 0 { res = d.processDNSRewrites(dnsr) if res.Reason == RewrittenRule && res.CanonName == host { - // A rewrite of a host to itself. Go on and try - // matching other things. + // A rewrite of a host to itself. Go on and try matching other + // things. } else { return res, nil } @@ -868,8 +878,8 @@ func New(c *Config, blockFilters []Filter) *DNSFilter { } d.hostCheckers = []hostChecker{{ - check: d.checkEtcHosts, - name: "etchosts", + check: d.matchSysHosts, + name: "hosts container", }, { check: d.matchHost, name: "filtering", diff --git a/internal/home/clients.go b/internal/home/clients.go index 9b660757..85eec9fd 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -95,7 +95,9 @@ type clientsContainer struct { // dnsServer is used for checking clients IP status access list status dnsServer *dnsforward.Server - etcHosts *aghnet.EtcHostsContainer // get entries from system hosts-files + // etcHosts contains list of rewrite rules taken from the operating system's + // hosts databse. + etcHosts *aghnet.HostsContainer testing bool // if TRUE, this object is used for internal tests } @@ -106,7 +108,7 @@ type clientsContainer struct { func (clients *clientsContainer) Init( objects []clientObject, dhcpServer *dhcpd.Server, - etcHosts *aghnet.EtcHostsContainer, + etcHosts *aghnet.HostsContainer, ) { if clients.list != nil { log.Fatal("clients.list != nil") @@ -121,13 +123,22 @@ func (clients *clientsContainer) Init( clients.etcHosts = etcHosts clients.addFromConfig(objects) - if !clients.testing { - clients.updateFromDHCP(true) - if clients.dhcpServer != nil { - clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged) - } - if clients.etcHosts != nil { - clients.etcHosts.SetOnChanged(clients.onHostsChanged) + if clients.testing { + return + } + + clients.updateFromDHCP(true) + if clients.dhcpServer != nil { + clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged) + } + + go clients.handleHostsUpdates() +} + +func (clients *clientsContainer) handleHostsUpdates() { + if clients.etcHosts != nil { + for upd := range clients.etcHosts.Upd() { + clients.addFromHostsFile(upd) } } } @@ -250,10 +261,6 @@ func (clients *clientsContainer) onDHCPLeaseChanged(flags int) { } } -func (clients *clientsContainer) onHostsChanged() { - clients.addFromHostsFile() -} - // Exists checks if client with this IP address already exists. func (clients *clientsContainer) Exists(ip net.IP, source clientSource) (ok bool) { clients.lock.Lock() @@ -697,7 +704,7 @@ func (clients *clientsContainer) SetWHOISInfo(ip net.IP, wi *RuntimeClientWHOISI log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi) } -// AddHost adds a new IP-hostname pairing. The priorities of the sources is +// AddHost adds a new IP-hostname pairing. The priorities of the sources are // taken into account. ok is true if the pairing was added. func (clients *clientsContainer) AddHost(ip net.IP, host string, src clientSource) (ok bool, err error) { clients.lock.Lock() @@ -757,13 +764,7 @@ func (clients *clientsContainer) rmHostsBySrc(src clientSource) { // addFromHostsFile fills the client-hostname pairing index from the system's // hosts files. -func (clients *clientsContainer) addFromHostsFile() { - if clients.etcHosts == nil { - return - } - - hosts := clients.etcHosts.List() - +func (clients *clientsContainer) addFromHostsFile(hosts *netutil.IPMap) { clients.lock.Lock() defer clients.lock.Unlock() diff --git a/internal/home/filter.go b/internal/home/filter.go index 22bc7435..74f85356 100644 --- a/internal/home/filter.go +++ b/internal/home/filter.go @@ -710,7 +710,6 @@ func enableFilters(async bool) { } func enableFiltersLocked(async bool) { - var whiteFilters []filtering.Filter filters := []filtering.Filter{{ Data: []byte(strings.Join(config.UserRules, "\n")), }} @@ -725,18 +724,20 @@ func enableFiltersLocked(async bool) { FilePath: filter.Path(), }) } + + var allowFilters []filtering.Filter for _, filter := range config.WhitelistFilters { if !filter.Enabled { continue } - whiteFilters = append(whiteFilters, filtering.Filter{ + allowFilters = append(allowFilters, filtering.Filter{ ID: filter.ID, FilePath: filter.Path(), }) } - if err := Context.dnsFilter.SetFilters(filters, whiteFilters, async); err != nil { + if err := Context.dnsFilter.SetFilters(filters, allowFilters, async); err != nil { log.Debug("enabling filters: %s", err) } diff --git a/internal/home/home.go b/internal/home/home.go index 6dceb1fd..c5a6a5e6 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -44,20 +44,22 @@ type homeContext struct { // Modules // -- - clients clientsContainer // per-client-settings module - stats stats.Stats // statistics module - queryLog querylog.QueryLog // query log module - dnsServer *dnsforward.Server // DNS module - rdns *RDNS // rDNS module - whois *WHOIS // WHOIS module - dnsFilter *filtering.DNSFilter // DNS filtering module - dhcpServer *dhcpd.Server // DHCP module - auth *Auth // HTTP authentication module - filters Filtering // DNS filtering module - web *Web // Web (HTTP, HTTPS) module - tls *TLSMod // TLS module - etcHosts *aghnet.EtcHostsContainer // IP-hostname pairs taken from system configuration (e.g. /etc/hosts) files - updater *updater.Updater + clients clientsContainer // per-client-settings module + stats stats.Stats // statistics module + queryLog querylog.QueryLog // query log module + dnsServer *dnsforward.Server // DNS module + rdns *RDNS // rDNS module + whois *WHOIS // WHOIS module + dnsFilter *filtering.DNSFilter // DNS filtering module + dhcpServer *dhcpd.Server // DHCP module + auth *Auth // HTTP authentication module + filters Filtering // DNS filtering module + web *Web // Web (HTTP, HTTPS) module + tls *TLSMod // TLS module + // etcHosts is an IP-hostname pairs set taken from system configuration + // (e.g. /etc/hosts) files. + etcHosts *aghnet.HostsContainer + updater *updater.Updater subnetDetector *aghnet.SubnetDetector @@ -257,8 +259,20 @@ func setupConfig(args options) (err error) { }) if !args.noEtcHosts { - Context.etcHosts = &aghnet.EtcHostsContainer{} - Context.etcHosts.Init("") + var osWritesWatcher aghos.FSWatcher + osWritesWatcher, err = aghos.NewOSWritesWatcher() + if err != nil { + return fmt.Errorf("initing os watcher: %w", err) + } + + Context.etcHosts, err = aghnet.NewHostsContainer( + aghos.RootDirFS(), + osWritesWatcher, + aghnet.DefaultHostsPaths()..., + ) + if err != nil { + return fmt.Errorf("initing hosts container: %w", err) + } } Context.clients.Init(config.Clients, Context.dhcpServer, Context.etcHosts) config.Clients = nil @@ -424,7 +438,6 @@ func run(args options, clientBuildFS fs.FS) { fatalOnError(err) Context.tls.Start() - Context.etcHosts.Start() go func() { serr := startDNSServer() @@ -647,7 +660,11 @@ func cleanup(ctx context.Context) { } } - Context.etcHosts.Close() + if Context.etcHosts != nil { + if err = Context.etcHosts.Close(); err != nil { + log.Error("stopping hosts container: %s", err) + } + } if Context.tls != nil { Context.tls.Close()