Pull request: 2499 merge rewrites vol.1

Merge in DNS/adguard-home from 2499-merge-rewrites-vol.1 to master

Updates #2499.

Squashed commit of the following:

commit 6b308bc2b360cee8c22e506f31d62bacb4bf8fb3
Merge: f49e9186 2b635bf6
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Oct 14 19:23:07 2021 +0300

    Merge branch 'master' into 2499-merge-rewrites-vol.1

commit f49e9186ffc8b7074d03c6721ee56cdb09243684
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Oct 14 18:50:49 2021 +0300

    aghos: fix fs events filtering

commit 567dd646556606212af5dab60e3ecbb8fff22c25
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Oct 14 16:50:37 2021 +0300

    all: imp code, docs, fix windows

commit 140c8bf519345eb54d0e7500a996fcf465353d71
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Oct 13 19:41:53 2021 +0300

    aghnet: use const

commit bebf3f76bd394a498ccad812c57d4507c69529ba
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Oct 13 19:32:37 2021 +0300

    all: imp tests, docs

commit 9bfdbb6eb454833135d616e208e82699f98e2562
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Oct 13 18:42:20 2021 +0300

    all: imp path more, imp docs

commit ee9ea4c132a6b17787d150bf2bee703abaa57be3
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Oct 13 16:09:46 2021 +0300

    all: fix windows, imp paths

commit 6fac8338a81e9ecfebfc23a1adcb964e89f6aee6
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Mon Oct 11 19:53:35 2021 +0300

    all: imp code, docs

commit da1ce1a2a3dd2be3fdff2412a6dbd596859dc249
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Mon Oct 11 18:22:50 2021 +0300

    aghnet: fix windows tests

commit d29de359ed68118d71efb226a8433fac15ff5c66
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Fri Oct 8 21:02:14 2021 +0300

    all: repl & imp

commit 1356c08944cdbb85ce5532d90fe5b077219ce5ff
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Fri Oct 8 01:41:19 2021 +0300

    all: add tests, mv logic, added tmpfs

commit f4b11adf8998bc8d9d955c5ac9f386f671bd5213
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Oct 7 14:26:30 2021 +0300

    all: imp filewalker, refactor hosts container
This commit is contained in:
Eugene Burkov 2021-10-14 19:39:21 +03:00
parent 2b635bf689
commit 2796e65468
24 changed files with 1377 additions and 775 deletions

View File

@ -1,387 +0,0 @@
package aghnet
import (
"bufio"
"net"
"os"
"path/filepath"
"runtime"
"strings"
"sync"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/fsnotify/fsnotify"
"github.com/miekg/dns"
)
type onChangedT func()
// EtcHostsContainer - automatic DNS records
//
// TODO(e.burkov): Move the logic under interface. Refactor. Probably remove
// the resolving logic.
type EtcHostsContainer struct {
// lock protects table and tableReverse.
lock sync.RWMutex
// table is the host-to-IPs map.
table map[string][]net.IP
// tableReverse is the IP-to-hosts map. The type of the values in the
// map is []string.
tableReverse *netutil.IPMap
hostsFn string // path to the main hosts-file
hostsDirs []string // paths to OS-specific directories with hosts-files
watcher *fsnotify.Watcher // file and directory watcher object
// onlyWritesChan used to contain only writing events from watcher.
onlyWritesChan chan fsnotify.Event
onChanged onChangedT // notification to other modules
}
// SetOnChanged - set callback function that will be called when the data is changed
func (ehc *EtcHostsContainer) SetOnChanged(onChanged onChangedT) {
ehc.onChanged = onChanged
}
// Notify other modules
func (ehc *EtcHostsContainer) notify() {
if ehc.onChanged == nil {
return
}
ehc.onChanged()
}
// Init - initialize
// hostsFn: Override default name for the hosts-file (optional)
func (ehc *EtcHostsContainer) Init(hostsFn string) {
ehc.table = make(map[string][]net.IP)
ehc.onlyWritesChan = make(chan fsnotify.Event, 2)
ehc.hostsFn = "/etc/hosts"
if runtime.GOOS == "windows" {
ehc.hostsFn = os.ExpandEnv("$SystemRoot\\system32\\drivers\\etc\\hosts")
}
if len(hostsFn) != 0 {
ehc.hostsFn = hostsFn
}
if aghos.IsOpenWrt() {
// OpenWrt: "/tmp/hosts/dhcp.cfg01411c".
ehc.hostsDirs = append(ehc.hostsDirs, "/tmp/hosts")
}
// Load hosts initially
ehc.updateHosts()
var err error
ehc.watcher, err = fsnotify.NewWatcher()
if err != nil {
log.Error("etchosts: %s", err)
}
}
// Start - start module
func (ehc *EtcHostsContainer) Start() {
if ehc == nil {
return
}
log.Debug("Start etchostscontainer module")
ehc.updateHosts()
if ehc.watcher != nil {
go ehc.watcherLoop()
err := ehc.watcher.Add(ehc.hostsFn)
if err != nil {
log.Error("Error while initializing watcher for a file %s: %s", ehc.hostsFn, err)
}
for _, dir := range ehc.hostsDirs {
err = ehc.watcher.Add(dir)
if err != nil {
log.Error("Error while initializing watcher for a directory %s: %s", dir, err)
}
}
}
}
// Close - close module
func (ehc *EtcHostsContainer) Close() {
if ehc == nil {
return
}
if ehc.watcher != nil {
_ = ehc.watcher.Close()
}
// Don't close onlyWritesChan here and let onlyWrites close it after
// watcher.Events is closed to prevent close races.
}
// Process returns the list of IP addresses for the hostname or nil if nothing
// found.
func (ehc *EtcHostsContainer) Process(host string, qtype uint16) []net.IP {
if qtype == dns.TypePTR {
return nil
}
var ipsCopy []net.IP
ehc.lock.RLock()
defer ehc.lock.RUnlock()
if ips, ok := ehc.table[host]; ok {
ipsCopy = make([]net.IP, len(ips))
copy(ipsCopy, ips)
}
log.Debug("etchosts: answer: %s -> %v", host, ipsCopy)
return ipsCopy
}
// ProcessReverse processes a PTR request. It returns nil if nothing is found.
func (ehc *EtcHostsContainer) ProcessReverse(addr string, qtype uint16) (hosts []string) {
if qtype != dns.TypePTR {
return nil
}
ip, err := netutil.IPFromReversedAddr(addr)
if err != nil {
log.Error("etchosts: reversed addr: %s", err)
return nil
}
ehc.lock.RLock()
defer ehc.lock.RUnlock()
v, ok := ehc.tableReverse.Get(ip)
if !ok {
return nil
}
hosts, ok = v.([]string)
if !ok {
log.Error("etchosts: bad type %T in tableReverse for %s", v, ip)
return nil
} else if len(hosts) == 0 {
return nil
}
log.Debug("etchosts: reverse-lookup: %s -> %s", addr, hosts)
return hosts
}
// List returns an IP-to-hostnames table. The type of the values in the map is
// []string. It is safe for concurrent use.
func (ehc *EtcHostsContainer) List() (ipToHosts *netutil.IPMap) {
ehc.lock.RLock()
defer ehc.lock.RUnlock()
return ehc.tableReverse.ShallowClone()
}
// update table
func (ehc *EtcHostsContainer) updateTable(table map[string][]net.IP, host string, ipAddr net.IP) {
ips, ok := table[host]
if ok {
for _, ip := range ips {
if ip.Equal(ipAddr) {
// IP already exists: don't add duplicates
ok = false
break
}
}
if !ok {
ips = append(ips, ipAddr)
table[host] = ips
}
} else {
table[host] = []net.IP{ipAddr}
ok = true
}
if ok {
log.Debug("etchosts: added %s -> %s", ipAddr, host)
}
}
// updateTableRev updates the reverse address table.
func (ehc *EtcHostsContainer) updateTableRev(tableRev *netutil.IPMap, newHost string, ip net.IP) {
v, ok := tableRev.Get(ip)
if !ok {
tableRev.Set(ip, []string{newHost})
log.Debug("etchosts: added reverse-address %s -> %s", ip, newHost)
return
}
hosts, _ := v.([]string)
for _, host := range hosts {
if host == newHost {
return
}
}
hosts = append(hosts, newHost)
tableRev.Set(ip, hosts)
log.Debug("etchosts: added reverse-address %s -> %s", ip, newHost)
}
// parseHostsLine parses hosts from the fields.
func parseHostsLine(fields []string) (hosts []string) {
for _, f := range fields {
hashIdx := strings.IndexByte(f, '#')
if hashIdx == 0 {
// The rest of the fields are a part of the comment.
// Skip immediately.
return
} else if hashIdx > 0 {
// Only a part of the field is a comment.
hosts = append(hosts, f[:hashIdx])
return hosts
}
hosts = append(hosts, f)
}
return hosts
}
// load reads IP-hostname pairs from the hosts file. Multiple hostnames per
// line for one IP are supported.
func (ehc *EtcHostsContainer) load(
table map[string][]net.IP,
tableRev *netutil.IPMap,
fn string,
) {
f, err := os.Open(fn)
if err != nil {
log.Error("etchosts: %s", err)
return
}
defer func() {
derr := f.Close()
if derr != nil {
log.Error("etchosts: closing file: %s", err)
}
}()
log.Debug("etchosts: loading hosts from file %s", fn)
s := bufio.NewScanner(f)
for s.Scan() {
line := strings.TrimSpace(s.Text())
fields := strings.Fields(line)
if len(fields) < 2 {
continue
}
ip := net.ParseIP(fields[0])
if ip == nil {
continue
}
hosts := parseHostsLine(fields[1:])
for _, host := range hosts {
ehc.updateTable(table, host, ip)
ehc.updateTableRev(tableRev, host, ip)
}
}
err = s.Err()
if err != nil {
log.Error("etchosts: %s", err)
}
}
// onlyWrites is a filter for (*fsnotify.Watcher).Events.
func (ehc *EtcHostsContainer) onlyWrites() {
for event := range ehc.watcher.Events {
if event.Op&fsnotify.Write == fsnotify.Write {
ehc.onlyWritesChan <- event
}
}
close(ehc.onlyWritesChan)
}
// Receive notifications from fsnotify package
func (ehc *EtcHostsContainer) watcherLoop() {
go ehc.onlyWrites()
for {
select {
case event, ok := <-ehc.onlyWritesChan:
if !ok {
return
}
// Assume that we sometimes have the same event occurred
// several times.
repeat := true
for repeat {
select {
case _, ok = <-ehc.onlyWritesChan:
repeat = ok
default:
repeat = false
}
}
if event.Op&fsnotify.Write == fsnotify.Write {
log.Debug("etchosts: modified: %s", event.Name)
ehc.updateHosts()
}
case err, ok := <-ehc.watcher.Errors:
if !ok {
return
}
log.Error("etchosts: %s", err)
}
}
}
// updateHosts - loads system hosts
func (ehc *EtcHostsContainer) updateHosts() {
table := make(map[string][]net.IP)
tableRev := netutil.NewIPMap(0)
ehc.load(table, tableRev, ehc.hostsFn)
for _, dir := range ehc.hostsDirs {
des, err := os.ReadDir(dir)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
log.Error("etchosts: Opening directory: %q: %s", dir, err)
}
continue
}
for _, de := range des {
ehc.load(table, tableRev, filepath.Join(dir, de.Name()))
}
}
func() {
ehc.lock.Lock()
defer ehc.lock.Unlock()
ehc.table = table
ehc.tableReverse = tableRev
}()
ehc.notify()
}

View File

@ -1,130 +0,0 @@
package aghnet
import (
"net"
"os"
"strings"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMain(m *testing.M) {
aghtest.DiscardLogOutput(m)
}
func prepareTestFile(t *testing.T) (f *os.File) {
t.Helper()
dir := t.TempDir()
f, err := os.CreateTemp(dir, "")
require.NoError(t, err)
require.NotNil(t, f)
t.Cleanup(func() {
assert.NoError(t, f.Close())
})
return f
}
func assertWriting(t *testing.T, f *os.File, strs ...string) {
t.Helper()
for _, str := range strs {
n, err := f.WriteString(str)
require.NoError(t, err)
assert.Equal(t, n, len(str))
}
}
func TestEtcHostsContainerResolution(t *testing.T) {
ehc := &EtcHostsContainer{}
f := prepareTestFile(t)
assertWriting(t, f,
" 127.0.0.1 host localhost # comment \n",
" ::1 localhost#comment \n",
)
ehc.Init(f.Name())
t.Run("existing_host", func(t *testing.T) {
ips := ehc.Process("localhost", dns.TypeA)
require.Len(t, ips, 1)
assert.Equal(t, net.IPv4(127, 0, 0, 1), ips[0])
})
t.Run("unknown_host", func(t *testing.T) {
ips := ehc.Process("newhost", dns.TypeA)
assert.Nil(t, ips)
// Comment.
ips = ehc.Process("comment", dns.TypeA)
assert.Nil(t, ips)
})
t.Run("hosts_file", func(t *testing.T) {
names, ok := ehc.List().Get(net.IP{127, 0, 0, 1})
require.True(t, ok)
assert.Equal(t, []string{"host", "localhost"}, names)
})
t.Run("ptr", func(t *testing.T) {
testCases := []struct {
wantIP string
wantHost string
wantLen int
}{
{wantIP: "127.0.0.1", wantHost: "host", wantLen: 2},
{wantIP: "::1", wantHost: "localhost", wantLen: 1},
}
for _, tc := range testCases {
a, err := dns.ReverseAddr(tc.wantIP)
require.NoError(t, err)
a = strings.TrimSuffix(a, ".")
hosts := ehc.ProcessReverse(a, dns.TypePTR)
require.Len(t, hosts, tc.wantLen)
assert.Equal(t, tc.wantHost, hosts[0])
}
})
}
func TestEtcHostsContainerFSNotify(t *testing.T) {
ehc := &EtcHostsContainer{}
f := prepareTestFile(t)
assertWriting(t, f, " 127.0.0.1 host localhost \n")
ehc.Init(f.Name())
t.Run("unknown_host", func(t *testing.T) {
ips := ehc.Process("newhost", dns.TypeA)
assert.Nil(t, ips)
})
// Start monitoring for changes.
ehc.Start()
t.Cleanup(ehc.Close)
assertWriting(t, f, "127.0.0.2 newhost\n")
require.NoError(t, f.Sync())
// Wait until fsnotify has triggered and processed the file-modification
// event.
time.Sleep(50 * time.Millisecond)
t.Run("notified", func(t *testing.T) {
ips := ehc.Process("newhost", dns.TypeA)
assert.NotNil(t, ips)
require.Len(t, ips, 1)
assert.True(t, net.IP{127, 0, 0, 2}.Equal(ips[0]))
})
}

View File

@ -0,0 +1,343 @@
package aghnet
import (
"bufio"
"fmt"
"io"
"io/fs"
"net"
"path"
"strings"
"sync"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/urlfilter"
"github.com/AdguardTeam/urlfilter/filterlist"
"github.com/miekg/dns"
)
// DefaultHostsPaths returns the slice of paths default for the operating system
// to files and directories which are containing the hosts database. The result
// is intended to use within fs.FS so the initial slash is omitted.
func DefaultHostsPaths() (paths []string) {
return defaultHostsPaths()
}
// hostsContainerPref is a prefix for logging and wrapping errors in
// HostsContainer's methods.
const hostsContainerPref = "hosts container"
// HostsContainer stores the relevant hosts database provided by the OS and
// processes both A/AAAA and PTR DNS requests for those.
type HostsContainer struct {
// engLock protects rulesStrg and engine.
engLock *sync.RWMutex
// rulesStrg stores the rules obtained from the hosts' file.
rulesStrg *filterlist.RuleStorage
// engine serves rulesStrg.
engine *urlfilter.DNSEngine
// Updates is the channel for receiving updated hosts. The receivable map's
// values has a type of slice of strings.
updates chan *netutil.IPMap
// fsys is the working file system to read hosts files from.
fsys fs.FS
// w tracks the changes in specified files and directories.
w aghos.FSWatcher
// patterns stores specified paths in the fs.Glob-compatible form.
patterns []string
}
// errNoPaths is returned when there are no paths to watch passed to the
// HostsContainer.
const errNoPaths errors.Error = "hosts paths are empty"
// NewHostsContainer creates a container of hosts, that watches the paths with
// w. paths shouldn't be empty and each of them should locate either a file or
// a directory in fsys. fsys and w must be non-nil.
func NewHostsContainer(
fsys fs.FS,
w aghos.FSWatcher,
paths ...string,
) (hc *HostsContainer, err error) {
defer func() { err = errors.Annotate(err, "%s: %w", hostsContainerPref) }()
if len(paths) == 0 {
return nil, errNoPaths
}
patterns, err := pathsToPatterns(fsys, paths)
if err != nil {
return nil, err
}
hc = &HostsContainer{
engLock: &sync.RWMutex{},
updates: make(chan *netutil.IPMap, 1),
fsys: fsys,
w: w,
patterns: patterns,
}
log.Debug("%s: starting", hostsContainerPref)
// Load initially.
if err = hc.refresh(); err != nil {
return nil, err
}
for _, p := range paths {
err = w.Add(p)
if err == nil {
continue
} else if errors.Is(err, fs.ErrNotExist) {
log.Debug("%s: file %q expected to exist but doesn't", hostsContainerPref, p)
continue
}
return nil, fmt.Errorf("adding path: %w", err)
}
go hc.handleEvents()
return hc, nil
}
// MatchRequest is the request processing method to resolve hostnames and
// addresses from the operating system's hosts files. Any request not of A/AAAA
// or PTR type will return with an empty result. It's safe for concurrent use.
func (hc *HostsContainer) MatchRequest(
req urlfilter.DNSRequest,
) (res urlfilter.DNSResult, ok bool) {
switch req.DNSType {
case dns.TypeA, dns.TypeAAAA, dns.TypePTR:
log.Debug("%s: handling the request", hostsContainerPref)
default:
return urlfilter.DNSResult{}, false
}
hc.engLock.RLock()
defer hc.engLock.RUnlock()
return hc.engine.MatchRequest(req)
}
// Close implements the io.Closer interface for *HostsContainer.
func (hc *HostsContainer) Close() (err error) {
log.Debug("%s: closing hosts container", hostsContainerPref)
return errors.Annotate(hc.w.Close(), "%s: closing: %w", hostsContainerPref)
}
// Upd returns the channel into which the updates are sent.
func (hc *HostsContainer) Upd() (updates <-chan *netutil.IPMap) {
return hc.updates
}
// pathsToPatterns converts paths into patterns compatible with fs.Glob.
func pathsToPatterns(fsys fs.FS, paths []string) (patterns []string, err error) {
for i, p := range paths {
var fi fs.FileInfo
if fi, err = fs.Stat(fsys, p); err != nil {
return nil, fmt.Errorf("%q at index %d: %w", p, i, err)
}
if fi.IsDir() {
p = path.Join(p, "*")
}
patterns = append(patterns, p)
}
return patterns, nil
}
// handleEvents concurrently handles the events. It closes the update channel
// of HostsContainer when finishes. Used to be called within a goroutine.
func (hc *HostsContainer) handleEvents() {
defer log.OnPanic(fmt.Sprintf("%s: handling events", hostsContainerPref))
defer close(hc.updates)
for range hc.w.Events() {
if err := hc.refresh(); err != nil {
log.Error("%s: %s", hostsContainerPref, err)
}
}
}
// hostsParser is a helper type to parse rules from the operating system's hosts
// file.
type hostsParser struct {
// rules builds the resulting rules list content.
rules *strings.Builder
// table stores only the unique IP-hostname pairs. It's also sent to the
// updates channel afterwards.
table *netutil.IPMap
}
// parseHostsFile is a aghtest.FileWalker for parsing the files with hosts
// syntax. It never signs to stop the walking.
//
// See man hosts(5).
func (hp hostsParser) parseHostsFile(
r io.Reader,
) (patterns []string, cont bool, err error) {
s := bufio.NewScanner(r)
for s.Scan() {
ip, hosts := hp.parseLine(s.Text())
if ip == nil {
continue
}
for _, host := range hosts {
hp.addPair(ip, host)
}
}
return nil, true, s.Err()
}
// parseLine parses the line having the hosts syntax ignoring invalid ones.
func (hp hostsParser) parseLine(line string) (ip net.IP, hosts []string) {
line = strings.TrimSpace(line)
fields := strings.Fields(line)
if len(fields) < 2 {
return nil, nil
}
if ip = net.ParseIP(fields[0]); ip == nil {
return nil, nil
}
loop:
for _, f := range fields[1:] {
switch hashIdx := strings.IndexByte(f, '#'); hashIdx {
case 0:
// The rest of the fields are a part of the comment so skip
// immediately.
break loop
case -1:
hosts = append(hosts, f)
default:
// Only a part of the field is a comment.
hosts = append(hosts, f[:hashIdx])
break loop
}
}
return ip, hosts
}
// add returns true if the pair of ip and host wasn't added to the hp before.
func (hp hostsParser) add(ip net.IP, host string) (added bool) {
v, ok := hp.table.Get(ip)
hosts, _ := v.([]string)
if ok && stringutil.InSlice(hosts, host) {
return false
}
hp.table.Set(ip, append(hosts, host))
return true
}
// addPair puts the pair of ip and host to the rules builder if needed.
func (hp hostsParser) addPair(ip net.IP, host string) {
arpa, err := netutil.IPToReversedAddr(ip)
if err != nil {
return
}
if !hp.add(ip, host) {
return
}
qtype := "AAAA"
if ip.To4() != nil {
// Assume the validation of the IP address is performed already.
qtype = "A"
}
stringutil.WriteToBuilder(
hp.rules,
"||",
host,
"^$dnsrewrite=NOERROR;",
qtype,
";",
ip.String(),
"\n",
"||",
arpa,
"^$dnsrewrite=NOERROR;PTR;",
dns.Fqdn(host),
"\n",
)
log.Debug("%s: added ip-host pair %q/%q", hostsContainerPref, ip, host)
}
// sendUpd tries to send the parsed data to the ch.
func (hp hostsParser) sendUpd(ch chan *netutil.IPMap) {
log.Debug("%s: sending upd", hostsContainerPref)
select {
case ch <- hp.table:
// Updates are delivered. Go on.
default:
log.Debug("%s: the buffer is full", hostsContainerPref)
}
}
// newStrg creates a new rules storage from parsed data.
func (hp hostsParser) newStrg() (s *filterlist.RuleStorage, err error) {
return filterlist.NewRuleStorage([]filterlist.RuleList{&filterlist.StringRuleList{
ID: 1,
RulesText: hp.rules.String(),
IgnoreCosmetic: true,
}})
}
// refresh gets the data from specified files and propagates the updates.
func (hc *HostsContainer) refresh() (err error) {
log.Debug("%s: refreshing", hostsContainerPref)
hp := hostsParser{
rules: &strings.Builder{},
table: netutil.NewIPMap(0),
}
_, err = aghos.FileWalker(hp.parseHostsFile).Walk(hc.fsys, hc.patterns...)
if err != nil {
return fmt.Errorf("updating: %w", err)
}
defer hp.sendUpd(hc.updates)
var rulesStrg *filterlist.RuleStorage
if rulesStrg, err = hp.newStrg(); err != nil {
return fmt.Errorf("initializing rules storage: %w", err)
}
hc.resetEng(rulesStrg)
return nil
}
func (hc *HostsContainer) resetEng(rulesStrg *filterlist.RuleStorage) {
hc.engLock.Lock()
defer hc.engLock.Unlock()
hc.rulesStrg = rulesStrg
hc.engine = urlfilter.NewDNSEngine(hc.rulesStrg)
}

View File

@ -0,0 +1,18 @@
//go:build linux
// +build linux
package aghnet
import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
)
func defaultHostsPaths() (paths []string) {
paths = []string{"etc/hosts"}
if aghos.IsOpenWrt() {
paths = append(paths, "tmp/hosts")
}
return paths
}

View File

@ -0,0 +1,8 @@
//go:build !(windows || linux)
// +build !windows,!linux
package aghnet
func defaultHostsPaths() (paths []string) {
return []string{"etc/hosts"}
}

View File

@ -0,0 +1,504 @@
package aghnet
import (
"io/fs"
"net"
"path"
"strings"
"sync/atomic"
"testing"
"testing/fstest"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/urlfilter"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
nl = "\n"
sp = " "
)
const closeCalled errors.Error = "close method called"
// fsWatcherOnCloseStub is a stub implementation of the Close method of
// aghos.FSWatcher.
func fsWatcherOnCloseStub() (err error) {
return closeCalled
}
func TestNewHostsContainer(t *testing.T) {
const dirname = "dir"
const filename = "file1"
p := path.Join(dirname, filename)
testFS := fstest.MapFS{
p: &fstest.MapFile{Data: []byte("127.0.0.1 localhost")},
}
testCases := []struct {
name string
paths []string
wantErr error
wantPatterns []string
}{{
name: "one_file",
paths: []string{p},
wantErr: nil,
wantPatterns: []string{p},
}, {
name: "no_files",
paths: []string{},
wantErr: errNoPaths,
wantPatterns: nil,
}, {
name: "non-existent_file",
paths: []string{path.Join(dirname, filename+"2")},
wantErr: fs.ErrNotExist,
wantPatterns: nil,
}, {
name: "whole_dir",
paths: []string{dirname},
wantErr: nil,
wantPatterns: []string{path.Join(dirname, "*")},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
onAdd := func(name string) (err error) {
assert.Contains(t, tc.paths, name)
return nil
}
var eventsCalledCounter uint32
eventsCh := make(chan struct{})
onEvents := func() (e <-chan struct{}) {
assert.Equal(t, uint32(1), atomic.AddUint32(&eventsCalledCounter, 1))
return eventsCh
}
hc, err := NewHostsContainer(testFS, &aghtest.FSWatcher{
OnEvents: onEvents,
OnAdd: onAdd,
OnClose: fsWatcherOnCloseStub,
}, tc.paths...)
if tc.wantErr != nil {
require.ErrorIs(t, err, tc.wantErr)
assert.Nil(t, hc)
return
}
require.NoError(t, err)
t.Cleanup(func() {
require.ErrorIs(t, hc.Close(), closeCalled)
})
require.NotNil(t, hc)
assert.Equal(t, tc.wantPatterns, hc.patterns)
assert.NotNil(t, <-hc.Upd())
eventsCh <- struct{}{}
assert.Equal(t, uint32(1), atomic.LoadUint32(&eventsCalledCounter))
})
}
t.Run("nil_fs", func(t *testing.T) {
require.Panics(t, func() {
_, _ = NewHostsContainer(nil, &aghtest.FSWatcher{
// Those shouldn't panic.
OnEvents: func() (e <-chan struct{}) { return nil },
OnAdd: func(name string) (err error) { return nil },
OnClose: func() (err error) { return nil },
}, p)
})
})
t.Run("nil_watcher", func(t *testing.T) {
require.Panics(t, func() {
_, _ = NewHostsContainer(testFS, nil, p)
})
})
t.Run("err_watcher", func(t *testing.T) {
const errOnAdd errors.Error = "error"
errWatcher := &aghtest.FSWatcher{
OnEvents: func() (e <-chan struct{}) { panic("not implemented") },
OnAdd: func(name string) (err error) { return errOnAdd },
OnClose: func() (err error) { panic("not implemented") },
}
hc, err := NewHostsContainer(testFS, errWatcher, p)
require.ErrorIs(t, err, errOnAdd)
assert.Nil(t, hc)
})
}
func TestHostsContainer_Refresh(t *testing.T) {
knownIP := net.IP{127, 0, 0, 1}
const knownHost = "localhost"
const knownAlias = "hocallost"
const dirname = "dir"
const filename1 = "file1"
const filename2 = "file2"
p1 := path.Join(dirname, filename1)
p2 := path.Join(dirname, filename2)
testFS := fstest.MapFS{
p1: &fstest.MapFile{
Data: []byte(strings.Join([]string{knownIP.String(), knownHost}, sp) + nl),
},
}
eventsCh := make(chan struct{}, 1)
t.Cleanup(func() { close(eventsCh) })
w := &aghtest.FSWatcher{
OnEvents: func() (e <-chan struct{}) { return eventsCh },
OnAdd: func(name string) (err error) {
assert.Equal(t, dirname, name)
return nil
},
OnClose: fsWatcherOnCloseStub,
}
hc, err := NewHostsContainer(testFS, w, dirname)
require.NoError(t, err)
t.Cleanup(func() { require.ErrorIs(t, hc.Close(), closeCalled) })
checkRefresh := func(t *testing.T, wantHosts []string) {
upd, ok := <-hc.Upd()
require.True(t, ok)
require.NotNil(t, upd)
assert.Equal(t, 1, upd.Len())
v, ok := upd.Get(knownIP)
require.True(t, ok)
var hosts []string
hosts, ok = v.([]string)
require.True(t, ok)
require.Len(t, hosts, len(wantHosts))
assert.Equal(t, wantHosts, hosts)
}
t.Run("initial_refresh", func(t *testing.T) {
checkRefresh(t, []string{knownHost})
})
testFS[p2] = &fstest.MapFile{
Data: []byte(strings.Join([]string{knownIP.String(), knownAlias}, sp) + nl),
}
eventsCh <- struct{}{}
t.Run("second_refresh", func(t *testing.T) {
checkRefresh(t, []string{knownHost, knownAlias})
})
}
func TestHostsContainer_MatchRequest(t *testing.T) {
var (
ip4 = net.IP{127, 0, 0, 1}
ip6 = net.IP{
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 1,
}
hostname4 = "localhost"
hostname6 = "localhostv6"
hostname4a = "abcd"
reversed4, _ = netutil.IPToReversedAddr(ip4)
reversed6, _ = netutil.IPToReversedAddr(ip6)
)
const filename = "file1"
gsfs := fstest.MapFS{
filename: &fstest.MapFile{Data: []byte(
strings.Join([]string{ip4.String(), hostname4, hostname4a}, sp) + nl +
strings.Join([]string{ip6.String(), hostname6}, sp) + nl +
strings.Join([]string{"256.256.256.256", "fakebroadcast"}, sp) + nl,
)},
}
hc, err := NewHostsContainer(gsfs, &aghtest.FSWatcher{
OnEvents: func() (e <-chan struct{}) { panic("not implemented") },
OnAdd: func(name string) (err error) {
assert.Equal(t, filename, name)
return nil
},
OnClose: fsWatcherOnCloseStub,
}, filename)
require.NoError(t, err)
t.Cleanup(func() { require.ErrorIs(t, hc.Close(), closeCalled) })
testCase := []struct {
name string
want interface{}
req urlfilter.DNSRequest
}{{
name: "a",
want: ip4.To16(),
req: urlfilter.DNSRequest{
Hostname: hostname4,
DNSType: dns.TypeA,
},
}, {
name: "aaaa",
want: ip6,
req: urlfilter.DNSRequest{
Hostname: hostname6,
DNSType: dns.TypeA,
},
}, {
name: "ptr",
want: dns.Fqdn(hostname4),
req: urlfilter.DNSRequest{
Hostname: reversed4,
DNSType: dns.TypePTR,
},
}, {
name: "ptr_v6",
want: dns.Fqdn(hostname6),
req: urlfilter.DNSRequest{
Hostname: reversed6,
DNSType: dns.TypePTR,
},
}, {
name: "a_alias",
want: ip4.To16(),
req: urlfilter.DNSRequest{
Hostname: hostname4a,
DNSType: dns.TypeA,
},
}}
for _, tc := range testCase {
t.Run(tc.name, func(t *testing.T) {
res, ok := hc.MatchRequest(tc.req)
require.False(t, ok)
assert.Equal(t, tc.want, res.DNSRewrites()[0].DNSRewrite.Value)
})
}
t.Run("cname", func(t *testing.T) {
res, ok := hc.MatchRequest(urlfilter.DNSRequest{
Hostname: hostname4,
DNSType: dns.TypeCNAME,
})
require.False(t, ok)
assert.Empty(t, res)
})
}
func TestHostsContainer_PathsToPatterns(t *testing.T) {
const (
dir0 = "dir"
dir1 = "dir_1"
fn1 = "file_1"
fn2 = "file_2"
fn3 = "file_3"
fn4 = "file_4"
)
fp1 := path.Join(dir0, fn1)
fp2 := path.Join(dir0, fn2)
fp3 := path.Join(dir0, dir1, fn3)
gsfs := fstest.MapFS{
fp1: &fstest.MapFile{Data: []byte{1}},
fp2: &fstest.MapFile{Data: []byte{2}},
fp3: &fstest.MapFile{Data: []byte{3}},
}
testCases := []struct {
name string
wantErr error
want []string
paths []string
}{{
name: "no_paths",
wantErr: nil,
want: nil,
paths: nil,
}, {
name: "single_file",
wantErr: nil,
want: []string{fp1},
paths: []string{fp1},
}, {
name: "several_files",
wantErr: nil,
want: []string{fp1, fp2},
paths: []string{fp1, fp2},
}, {
name: "whole_dir",
wantErr: nil,
want: []string{path.Join(dir0, "*")},
paths: []string{dir0},
}, {
name: "file_and_dir",
wantErr: nil,
want: []string{fp1, path.Join(dir0, dir1, "*")},
paths: []string{fp1, path.Join(dir0, dir1)},
}, {
name: "non-existing",
wantErr: fs.ErrNotExist,
want: nil,
paths: []string{path.Join(dir0, "file_3")},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
patterns, err := pathsToPatterns(gsfs, tc.paths)
if tc.wantErr != nil {
assert.ErrorIs(t, err, tc.wantErr)
return
}
require.NoError(t, err)
assert.Equal(t, tc.want, patterns)
})
}
}
func TestUniqueRules_AddPair(t *testing.T) {
knownIP := net.IP{1, 2, 3, 4}
const knownHost = "host1"
ipToHost := netutil.NewIPMap(0)
ipToHost.Set(knownIP, []string{knownHost})
testCases := []struct {
name string
host string
wantRules string
ip net.IP
}{{
name: "new_one",
host: "host2",
wantRules: "||host2^$dnsrewrite=NOERROR;A;1.2.3.4\n" +
"||4.3.2.1.in-addr.arpa^$dnsrewrite=NOERROR;PTR;host2.\n",
ip: knownIP,
}, {
name: "existing_one",
host: knownHost,
wantRules: "",
ip: knownIP,
}, {
name: "new_ip",
host: knownHost,
wantRules: "||" + knownHost + "^$dnsrewrite=NOERROR;A;1.2.3.5\n" +
"||5.3.2.1.in-addr.arpa^$dnsrewrite=NOERROR;PTR;" + knownHost + ".\n",
ip: net.IP{1, 2, 3, 5},
}, {
name: "bad_ip",
host: knownHost,
wantRules: "",
ip: net.IP{1, 2, 3, 4, 5},
}}
for _, tc := range testCases {
hp := hostsParser{
rules: &strings.Builder{},
table: ipToHost.ShallowClone(),
}
t.Run(tc.name, func(t *testing.T) {
hp.addPair(tc.ip, tc.host)
assert.Equal(t, tc.wantRules, hp.rules.String())
})
}
}
func TestUniqueRules_ParseLine(t *testing.T) {
const (
hostname = "localhost"
alias = "hocallost"
)
knownIP := net.IP{127, 0, 0, 1}
testCases := []struct {
name string
line string
wantIP net.IP
wantHosts []string
}{{
name: "simple",
line: strings.Join([]string{knownIP.String(), hostname}, sp),
wantIP: knownIP,
wantHosts: []string{"localhost"},
}, {
name: "aliases",
line: strings.Join([]string{knownIP.String(), hostname, alias}, sp),
wantIP: knownIP,
wantHosts: []string{"localhost", "hocallost"},
}, {
name: "invalid_line",
line: knownIP.String(),
wantIP: nil,
wantHosts: nil,
}, {
name: "invalid_line_hostname",
line: strings.Join([]string{knownIP.String(), "#" + hostname}, sp),
wantIP: knownIP,
wantHosts: nil,
}, {
name: "commented_aliases",
line: strings.Join([]string{knownIP.String(), hostname, "#" + alias}, sp),
wantIP: knownIP,
wantHosts: []string{"localhost"},
}, {
name: "whole_comment",
line: strings.Join([]string{"#", knownIP.String(), hostname}, sp),
wantIP: nil,
wantHosts: nil,
}, {
name: "partial_comment",
line: strings.Join([]string{knownIP.String(), hostname[:4] + "#" + hostname[4:]}, sp),
wantIP: knownIP,
wantHosts: []string{hostname[:4]},
}, {
name: "empty",
line: ``,
wantIP: nil,
wantHosts: nil,
}}
for _, tc := range testCases {
hp := hostsParser{}
t.Run(tc.name, func(t *testing.T) {
ip, hosts := hp.parseLine(tc.line)
assert.True(t, tc.wantIP.Equal(ip))
assert.Equal(t, tc.wantHosts, hosts)
})
}
}

View File

@ -0,0 +1,33 @@
//go:build windows
// +build windows
package aghnet
import (
"os"
"path"
"path/filepath"
"strings"
"github.com/AdguardTeam/golibs/log"
"golang.org/x/sys/windows"
)
func defaultHostsPaths() (paths []string) {
sysDir, err := windows.GetSystemDirectory()
if err != nil {
log.Error("getting system directory: %s", err)
return []string{}
}
// Split all the elements of the path to join them afterwards. This is
// needed to make the Windows-specific path string returned by
// windows.GetSystemDirectory to be compatible with fs.FS.
pathElems := strings.Split(sysDir, string(os.PathSeparator))
if len(pathElems) > 0 && pathElems[0] == filepath.VolumeName(sysDir) {
pathElems = pathElems[1:]
}
return []string{path.Join(append(pathElems, "drivers/etc/hosts")...)}
}

View File

@ -18,9 +18,11 @@ func canBindPrivilegedPorts() (can bool, err error) {
} }
func ifaceHasStaticIP(ifaceName string) (ok bool, err error) { func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
const filename = "/etc/rc.conf" const rcConfFilename = "etc/rc.conf"
return aghos.FileWalker(interfaceName(ifaceName).rcConfStaticConfig).Walk(filename) walker := aghos.FileWalker(interfaceName(ifaceName).rcConfStaticConfig)
return walker.Walk(aghos.RootDirFS(), rcConfFilename)
} }
// rcConfStaticConfig checks if the interface is configured by /etc/rc.conf to // rcConfStaticConfig checks if the interface is configured by /etc/rc.conf to

View File

@ -85,17 +85,17 @@ func ifaceHasStaticIP(ifaceName string) (has bool, err error) {
iface := interfaceName(ifaceName) iface := interfaceName(ifaceName)
for _, pair := range []struct { for _, pair := range [...]struct {
aghos.FileWalker aghos.FileWalker
filename string filename string
}{{ }{{
FileWalker: iface.dhcpcdStaticConfig, FileWalker: iface.dhcpcdStaticConfig,
filename: "/etc/dhcpcd.conf", filename: "etc/dhcpcd.conf",
}, { }, {
FileWalker: iface.ifacesStaticConfig, FileWalker: iface.ifacesStaticConfig,
filename: "/etc/network/interfaces", filename: "etc/network/interfaces",
}} { }} {
has, err = pair.Walk(pair.filename) has, err = pair.Walk(aghos.RootDirFS(), pair.filename)
if err != nil { if err != nil {
return false, err return false, err
} }

View File

@ -12,8 +12,6 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
const nl = "\n"
func TestDHCPCDStaticConfig(t *testing.T) { func TestDHCPCDStaticConfig(t *testing.T) {
const iface interfaceName = `wlan0` const iface interfaceName = `wlan0`

View File

@ -18,9 +18,9 @@ func canBindPrivilegedPorts() (can bool, err error) {
} }
func ifaceHasStaticIP(ifaceName string) (ok bool, err error) { func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
filename := fmt.Sprintf("/etc/hostname.%s", ifaceName) filename := fmt.Sprintf("etc/hostname.%s", ifaceName)
return aghos.FileWalker(hostnameIfStaticConfig).Walk(filename) return aghos.FileWalker(hostnameIfStaticConfig).Walk(aghos.RootDirFS(), filename)
} }
// hostnameIfStaticConfig checks if the interface is configured by // hostnameIfStaticConfig checks if the interface is configured by

View File

@ -4,10 +4,15 @@ import (
"net" "net"
"testing" "testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestMain(m *testing.M) {
aghtest.DiscardLogOutput(m)
}
func TestGetValidNetInterfacesForWeb(t *testing.T) { func TestGetValidNetInterfacesForWeb(t *testing.T) {
ifaces, err := GetValidNetInterfacesForWeb() ifaces, err := GetValidNetInterfacesForWeb()
require.NoErrorf(t, err, "cannot get net interfaces: %s", err) require.NoErrorf(t, err, "cannot get net interfaces: %s", err)

View File

@ -3,8 +3,7 @@ package aghos
import ( import (
"fmt" "fmt"
"io" "io"
"os" "io/fs"
"path/filepath"
"github.com/AdguardTeam/AdGuardHome/internal/aghio" "github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
@ -14,10 +13,10 @@ import (
// FileWalker is the signature of a function called for files in the file tree. // FileWalker is the signature of a function called for files in the file tree.
// As opposed to filepath.Walk it only walk the files (not directories) matching // As opposed to filepath.Walk it only walk the files (not directories) matching
// the provided pattern and those returned by function itself. All patterns // the provided pattern and those returned by function itself. All patterns
// should be valid for filepath.Glob. If cont is false, the walking terminates. // should be valid for fs.Glob. If cont is false, the walking terminates. Each
// Each opened file is also limited for reading to MaxWalkedFileSize. // opened file is also limited for reading to MaxWalkedFileSize.
// //
// TODO(e.burkov): Consider moving to the separate package like pathutil. // TODO(e.burkov, a.garipov): Move into another package like aghfs.
// //
// TODO(e.burkov): Think about passing filename or any additional data. // TODO(e.burkov): Think about passing filename or any additional data.
type FileWalker func(r io.Reader) (patterns []string, cont bool, err error) type FileWalker func(r io.Reader) (patterns []string, cont bool, err error)
@ -26,15 +25,19 @@ type FileWalker func(r io.Reader) (patterns []string, cont bool, err error)
// check. // check.
const MaxWalkedFileSize = 1024 * 1024 const MaxWalkedFileSize = 1024 * 1024
// checkFile tries to open and process a single file located on sourcePath. // checkFile tries to open and process a single file located on sourcePath in
func checkFile(c FileWalker, sourcePath string) (patterns []string, cont bool, err error) { // the specified fsys. The path is skipped if it's a directory.
var f *os.File func checkFile(
f, err = os.Open(sourcePath) fsys fs.FS,
c FileWalker,
sourcePath string,
) (patterns []string, cont bool, err error) {
var f fs.File
f, err = fsys.Open(sourcePath)
if err != nil { if err != nil {
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, fs.ErrNotExist) {
// Ignore non-existing files since this may only happen // Ignore non-existing files since this may only happen when the
// when the file was removed after filepath.Glob matched // file was removed after filepath.Glob matched it.
// it.
return nil, true, nil return nil, true, nil
} }
@ -42,9 +45,18 @@ func checkFile(c FileWalker, sourcePath string) (patterns []string, cont bool, e
} }
defer func() { err = errors.WithDeferred(err, f.Close()) }() defer func() { err = errors.WithDeferred(err, f.Close()) }()
var fi fs.FileInfo
if fi, err = f.Stat(); err != nil {
return nil, true, err
}
if fi.IsDir() {
// Skip the directories.
return nil, true, nil
}
var r io.Reader var r io.Reader
// Ignore the error since LimitReader function returns error only if // Ignore the error since LimitReader function returns error only if passed
// passed limit value is less than zero, but the constant used. // limit value is less than zero, but the constant used.
// //
// TODO(e.burkov): Make variable. // TODO(e.burkov): Make variable.
r, _ = aghio.LimitReader(f, MaxWalkedFileSize) r, _ = aghio.LimitReader(f, MaxWalkedFileSize)
@ -52,13 +64,17 @@ func checkFile(c FileWalker, sourcePath string) (patterns []string, cont bool, e
return c(r) return c(r)
} }
// handlePatterns parses the patterns and ignores duplicates using srcSet. // handlePatterns parses the patterns in fsys and ignores duplicates using
// srcSet must be non-nil. // srcSet. srcSet must be non-nil.
func handlePatterns(srcSet *stringutil.Set, patterns ...string) (sub []string, err error) { func handlePatterns(
fsys fs.FS,
srcSet *stringutil.Set,
patterns ...string,
) (sub []string, err error) {
sub = make([]string, 0, len(patterns)) sub = make([]string, 0, len(patterns))
for _, p := range patterns { for _, p := range patterns {
var matches []string var matches []string
matches, err = filepath.Glob(p) matches, err = fs.Glob(fsys, p)
if err != nil { if err != nil {
// Enrich error with the pattern because filepath.Glob // Enrich error with the pattern because filepath.Glob
// doesn't do it. // doesn't do it.
@ -78,14 +94,14 @@ func handlePatterns(srcSet *stringutil.Set, patterns ...string) (sub []string, e
return sub, nil return sub, nil
} }
// Walk starts walking the files defined by initPattern. It only returns true // Walk starts walking the files in fsys defined by patterns from initial.
// if c signed to stop walking. // It only returns true if fw signed to stop walking.
func (c FileWalker) Walk(initPattern string) (ok bool, err error) { func (fw FileWalker) Walk(fsys fs.FS, initial ...string) (ok bool, err error) {
// The slice of sources keeps the order in which the files are walked // The slice of sources keeps the order in which the files are walked since
// since srcSet.Values() returns strings in undefined order. // srcSet.Values() returns strings in undefined order.
srcSet := stringutil.NewSet() srcSet := stringutil.NewSet()
var src []string var src []string
src, err = handlePatterns(srcSet, initPattern) src, err = handlePatterns(fsys, srcSet, initial...)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -97,7 +113,7 @@ func (c FileWalker) Walk(initPattern string) (ok bool, err error) {
var patterns []string var patterns []string
var cont bool var cont bool
filename = src[i] filename = src[i]
patterns, cont, err = checkFile(c, src[i]) patterns, cont, err = checkFile(fsys, fw, src[i])
if err != nil { if err != nil {
return false, err return false, err
} }
@ -107,7 +123,7 @@ func (c FileWalker) Walk(initPattern string) (ok bool, err error) {
} }
var subsrc []string var subsrc []string
subsrc, err = handlePatterns(srcSet, patterns...) subsrc, err = handlePatterns(fsys, srcSet, patterns...)
if err != nil { if err != nil {
return false, err return false, err
} }

View File

@ -4,56 +4,19 @@ import (
"bufio" "bufio"
"io" "io"
"io/fs" "io/fs"
"os" "path"
"path/filepath"
"testing" "testing"
"testing/fstest"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// testFSDir maps entries' names to entries which should either be a testFSDir
// or byte slice.
type testFSDir map[string]interface{}
// testFSGen is used to generate a temporary filesystem consisting of
// directories and plain text files from itself.
type testFSGen testFSDir
// gen returns the name of top directory of the generated filesystem.
func (g testFSGen) gen(t *testing.T) (dirName string) {
t.Helper()
dirName = t.TempDir()
g.rangeThrough(t, dirName)
return dirName
}
func (g testFSGen) rangeThrough(t *testing.T, dirName string) {
const perm fs.FileMode = 0o777
for k, e := range g {
switch e := e.(type) {
case []byte:
require.NoError(t, os.WriteFile(filepath.Join(dirName, k), e, perm))
case testFSDir:
newDir := filepath.Join(dirName, k)
require.NoError(t, os.Mkdir(newDir, perm))
testFSGen(e).rangeThrough(t, newDir)
default:
t.Fatalf("unexpected entry type %T", e)
}
}
}
func TestFileWalker_Walk(t *testing.T) { func TestFileWalker_Walk(t *testing.T) {
const attribute = `000` const attribute = `000`
makeFileWalker := func(dirName string) (fw FileWalker) { makeFileWalker := func(_ string) (fw FileWalker) {
return func(r io.Reader) (patterns []string, cont bool, err error) { return func(r io.Reader) (patterns []string, cont bool, err error) {
s := bufio.NewScanner(r) s := bufio.NewScanner(r)
for s.Scan() { for s.Scan() {
@ -63,7 +26,7 @@ func TestFileWalker_Walk(t *testing.T) {
} }
if len(line) != 0 { if len(line) != 0 {
patterns = append(patterns, filepath.Join(dirName, line)) patterns = append(patterns, path.Join(".", line))
} }
} }
@ -74,136 +37,150 @@ func TestFileWalker_Walk(t *testing.T) {
const nl = "\n" const nl = "\n"
testCases := []struct { testCases := []struct {
name string testFS fstest.MapFS
testFS testFSGen want assert.BoolAssertionFunc
initPattern string initPattern string
want bool name string
}{{ }{{
name: "simple", name: "simple",
testFS: testFSGen{ testFS: fstest.MapFS{
"simple_0001.txt": []byte(attribute + nl), "simple_0001.txt": &fstest.MapFile{Data: []byte(attribute + nl)},
}, },
initPattern: "simple_0001.txt", initPattern: "simple_0001.txt",
want: true, want: assert.True,
}, { }, {
name: "chain", name: "chain",
testFS: testFSGen{ testFS: fstest.MapFS{
"chain_0001.txt": []byte(`chain_0002.txt` + nl), "chain_0001.txt": &fstest.MapFile{Data: []byte(`chain_0002.txt` + nl)},
"chain_0002.txt": []byte(`chain_0003.txt` + nl), "chain_0002.txt": &fstest.MapFile{Data: []byte(`chain_0003.txt` + nl)},
"chain_0003.txt": []byte(attribute + nl), "chain_0003.txt": &fstest.MapFile{Data: []byte(attribute + nl)},
}, },
initPattern: "chain_0001.txt", initPattern: "chain_0001.txt",
want: true, want: assert.True,
}, { }, {
name: "several", name: "several",
testFS: testFSGen{ testFS: fstest.MapFS{
"several_0001.txt": []byte(`several_*` + nl), "several_0001.txt": &fstest.MapFile{Data: []byte(`several_*` + nl)},
"several_0002.txt": []byte(`several_0001.txt` + nl), "several_0002.txt": &fstest.MapFile{Data: []byte(`several_0001.txt` + nl)},
"several_0003.txt": []byte(attribute + nl), "several_0003.txt": &fstest.MapFile{Data: []byte(attribute + nl)},
}, },
initPattern: "several_0001.txt", initPattern: "several_0001.txt",
want: true, want: assert.True,
}, { }, {
name: "no", name: "no",
testFS: testFSGen{ testFS: fstest.MapFS{
"no_0001.txt": []byte(nl), "no_0001.txt": &fstest.MapFile{Data: []byte(nl)},
"no_0002.txt": []byte(nl), "no_0002.txt": &fstest.MapFile{Data: []byte(nl)},
"no_0003.txt": []byte(nl), "no_0003.txt": &fstest.MapFile{Data: []byte(nl)},
}, },
initPattern: "no_*", initPattern: "no_*",
want: false, want: assert.False,
}, { }, {
name: "subdirectory", name: "subdirectory",
testFS: testFSGen{ testFS: fstest.MapFS{
"dir": testFSDir{ path.Join("dir", "subdir_0002.txt"): &fstest.MapFile{
"subdir_0002.txt": []byte(attribute + nl), Data: []byte(attribute + nl),
}, },
"subdir_0001.txt": []byte(`dir/*`), "subdir_0001.txt": &fstest.MapFile{Data: []byte(`dir/*`)},
}, },
initPattern: "subdir_0001.txt", initPattern: "subdir_0001.txt",
want: true, want: assert.True,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
testDir := tc.testFS.gen(t) fw := makeFileWalker("")
fw := makeFileWalker(testDir)
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
ok, err := fw.Walk(filepath.Join(testDir, tc.initPattern)) ok, err := fw.Walk(tc.testFS, tc.initPattern)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, tc.want, ok) tc.want(t, ok)
}) })
} }
t.Run("pattern_malformed", func(t *testing.T) { t.Run("pattern_malformed", func(t *testing.T) {
ok, err := makeFileWalker("").Walk("[]") f := fstest.MapFS{}
ok, err := makeFileWalker("").Walk(f, "[]")
require.Error(t, err) require.Error(t, err)
assert.False(t, ok) assert.False(t, ok)
assert.ErrorIs(t, err, filepath.ErrBadPattern) assert.ErrorIs(t, err, path.ErrBadPattern)
}) })
t.Run("bad_filename", func(t *testing.T) { t.Run("bad_filename", func(t *testing.T) {
dir := testFSGen{ const filename = "bad_filename.txt"
"bad_filename.txt": []byte("[]"),
}.gen(t) f := fstest.MapFS{
fw := FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) { filename: &fstest.MapFile{Data: []byte("[]")},
}
ok, err := FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) {
s := bufio.NewScanner(r) s := bufio.NewScanner(r)
for s.Scan() { for s.Scan() {
patterns = append(patterns, s.Text()) patterns = append(patterns, s.Text())
} }
return patterns, true, s.Err() return patterns, true, s.Err()
}) }).Walk(f, filename)
ok, err := fw.Walk(filepath.Join(dir, "bad_filename.txt"))
require.Error(t, err) require.Error(t, err)
assert.False(t, ok) assert.False(t, ok)
assert.ErrorIs(t, err, filepath.ErrBadPattern) assert.ErrorIs(t, err, path.ErrBadPattern)
}) })
t.Run("itself_error", func(t *testing.T) { t.Run("itself_error", func(t *testing.T) {
const rerr errors.Error = "returned error" const rerr errors.Error = "returned error"
dir := testFSGen{ f := fstest.MapFS{
"mockfile.txt": []byte(`mockdata`), "mockfile.txt": &fstest.MapFile{Data: []byte(`mockdata`)},
}.gen(t) }
ok, err := FileWalker(func(r io.Reader) (patterns []string, ok bool, err error) { ok, err := FileWalker(func(r io.Reader) (patterns []string, ok bool, err error) {
return nil, true, rerr return nil, true, rerr
}).Walk(filepath.Join(dir, "*")) }).Walk(f, "*")
require.Error(t, err) require.ErrorIs(t, err, rerr)
require.False(t, ok)
assert.ErrorIs(t, err, rerr) assert.False(t, ok)
}) })
} }
type errFS struct {
fs.GlobFS
}
const errErrFSOpen errors.Error = "this error is always returned"
func (efs *errFS) Open(name string) (fs.File, error) {
return nil, errErrFSOpen
}
func TestWalkerFunc_CheckFile(t *testing.T) { func TestWalkerFunc_CheckFile(t *testing.T) {
emptyFS := fstest.MapFS{}
t.Run("non-existing", func(t *testing.T) { t.Run("non-existing", func(t *testing.T) {
_, ok, err := checkFile(nil, "lol") _, ok, err := checkFile(emptyFS, nil, "lol")
require.NoError(t, err) require.NoError(t, err)
assert.True(t, ok) assert.True(t, ok)
}) })
t.Run("invalid_argument", func(t *testing.T) { t.Run("invalid_argument", func(t *testing.T) {
const badPath = "\x00" _, ok, err := checkFile(&errFS{}, nil, "")
require.ErrorIs(t, err, errErrFSOpen)
_, ok, err := checkFile(nil, badPath)
require.Error(t, err)
assert.False(t, ok) assert.False(t, ok)
// TODO(e.burkov): Use assert.ErrorsIs within the error from })
// less platform-dependent package instead of syscall.EINVAL.
// t.Run("ignore_dirs", func(t *testing.T) {
// See https://github.com/golang/go/issues/46849 and const dirName = "dir"
// https://github.com/golang/go/issues/30322.
pathErr := &os.PathError{} testFS := fstest.MapFS{
require.ErrorAs(t, err, &pathErr) path.Join(dirName, "file"): &fstest.MapFile{Data: []byte{}},
assert.Equal(t, "open", pathErr.Op) }
assert.Equal(t, badPath, pathErr.Path)
patterns, ok, err := checkFile(testFS, nil, dirName)
require.NoError(t, err)
assert.Empty(t, patterns)
assert.True(t, ok)
}) })
} }

131
internal/aghos/fswatcher.go Normal file
View File

@ -0,0 +1,131 @@
package aghos
import (
"fmt"
"io"
"io/fs"
"path/filepath"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/fsnotify/fsnotify"
)
// event is a convenient alias for an empty struct to signal that watching
// event happened.
type event = struct{}
// FSWatcher tracks all the fyle system events and notifies about those.
//
// TODO(e.burkov, a.garipov): Move into another package like aghfs.
type FSWatcher interface {
io.Closer
// Events should return a read-only channel which notifies about events.
Events() (e <-chan event)
// Add should check if the file named name is accessible and starts tracking
// it.
Add(name string) (err error)
}
// osWatcher tracks the file system provided by the OS.
type osWatcher struct {
// w is the actual notifier that is handled by osWatcher.
w *fsnotify.Watcher
// events is the channel to notify.
events chan event
}
const (
// osWatcherPref is a prefix for logging and wrapping errors in osWathcer's
// methods.
osWatcherPref = "os watcher"
)
// NewOSWritesWatcher creates FSWatcher that tracks the real file system of the
// OS and notifies only about writing events.
func NewOSWritesWatcher() (w FSWatcher, err error) {
defer func() { err = errors.Annotate(err, "%s: %w", osWatcherPref) }()
var watcher *fsnotify.Watcher
watcher, err = fsnotify.NewWatcher()
if err != nil {
return nil, fmt.Errorf("creating watcher: %w", err)
}
fsw := &osWatcher{
w: watcher,
events: make(chan event, 1),
}
go fsw.handleErrors()
go fsw.handleEvents()
return fsw, nil
}
// handleErrors handles accompanying errors. It used to be called in a separate
// goroutine.
func (w *osWatcher) handleErrors() {
defer log.OnPanic(fmt.Sprintf("%s: handling errors", osWatcherPref))
for err := range w.w.Errors {
log.Error("%s: %s", osWatcherPref, err)
}
}
// Events implements the FSWatcher interface for *osWatcher.
func (w *osWatcher) Events() (e <-chan event) {
return w.events
}
// Add implements the FSWatcher interface for *osWatcher.
func (w *osWatcher) Add(name string) (err error) {
defer func() { err = errors.Annotate(err, "%s: %w", osWatcherPref) }()
if _, err = fs.Stat(RootDirFS(), name); err != nil {
return fmt.Errorf("checking file %q: %w", name, err)
}
return w.w.Add(filepath.Join("/", name))
}
// Close implements the FSWatcher interface for *osWatcher.
func (w *osWatcher) Close() (err error) {
return w.w.Close()
}
// handleEvents notifies about the received file system's event if needed. It
// used to be called in a separate goroutine.
func (w *osWatcher) handleEvents() {
defer log.OnPanic(fmt.Sprintf("%s: handling events", osWatcherPref))
defer close(w.events)
ch := w.w.Events
for e := range ch {
if e.Op&fsnotify.Write == 0 {
continue
}
// Skip the following events assuming that sometimes the same event
// occurrs several times.
for ok := true; ok; {
select {
case _, ok = <-ch:
// Go on.
default:
ok = false
}
}
select {
case w.events <- event{}:
// Go on.
default:
log.Debug("%s: events buffer is full", osWatcherPref)
}
}
}

View File

@ -7,6 +7,8 @@ import (
"bufio" "bufio"
"fmt" "fmt"
"io" "io"
"io/fs"
"os"
"os/exec" "os/exec"
"path" "path"
"runtime" "runtime"
@ -159,3 +161,10 @@ ScanLoop:
func IsOpenWrt() (ok bool) { func IsOpenWrt() (ok bool) {
return isOpenWrt() return isOpenWrt()
} }
// RootDirFS returns the fs.FS rooted at the operating system's root.
func RootDirFS() (fsys fs.FS) {
// Use empty string since os.DirFS implicitly prepends a slash to it. This
// behavior is undocumented but it currently works.
return os.DirFS("")
}

View File

@ -26,6 +26,8 @@ func haveAdminRights() (bool, error) {
} }
func isOpenWrt() (ok bool) { func isOpenWrt() (ok bool) {
const etcReleasePattern = "etc/*release*"
var err error var err error
ok, err = FileWalker(func(r io.Reader) (_ []string, cont bool, err error) { ok, err = FileWalker(func(r io.Reader) (_ []string, cont bool, err error) {
const osNameData = "openwrt" const osNameData = "openwrt"
@ -39,7 +41,7 @@ func isOpenWrt() (ok bool) {
} }
return nil, !stringutil.ContainsFold(string(data), osNameData), nil return nil, !stringutil.ContainsFold(string(data), osNameData), nil
}).Walk("/etc/*release*") }).Walk(RootDirFS(), etcReleasePattern)
return err == nil && ok return err == nil && ok
} }

View File

@ -0,0 +1,23 @@
package aghtest
// FSWatcher is a mock aghos.FSWatcher implementation to use in tests.
type FSWatcher struct {
OnEvents func() (e <-chan struct{})
OnAdd func(name string) (err error)
OnClose func() (err error)
}
// Events implements the aghos.FSWatcher interface for *FSWatcher.
func (w *FSWatcher) Events() (e <-chan struct{}) {
return w.OnEvents()
}
// Add implements the aghos.FSWatcher interface for *FSWatcher.
func (w *FSWatcher) Add(name string) (err error) {
return w.OnAdd(name)
}
// Close implements the aghos.FSWatcher interface for *FSWatcher.
func (w *FSWatcher) Close() (err error) {
return w.OnClose()
}

View File

@ -11,9 +11,10 @@ import (
"fmt" "fmt"
"math/big" "math/big"
"net" "net"
"os"
"sync" "sync"
"sync/atomic"
"testing" "testing"
"testing/fstest"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
@ -1057,23 +1058,43 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
} }
func TestPTRResponseFromHosts(t *testing.T) { func TestPTRResponseFromHosts(t *testing.T) {
c := filtering.Config{ // Prepare test hosts file.
EtcHosts: &aghnet.EtcHostsContainer{},
const hostsFilename = "hosts"
testFS := fstest.MapFS{
hostsFilename: &fstest.MapFile{Data: []byte(`
127.0.0.1 host # comment
::1 localhost#comment
`)},
} }
// Prepare test hosts file. const closeCalled errors.Error = "close method called"
hf, err := os.CreateTemp("", "")
var eventsCalledCounter uint32
hc, err := aghnet.NewHostsContainer(testFS, &aghtest.FSWatcher{
OnEvents: func() (e <-chan struct{}) {
assert.Equal(t, uint32(1), atomic.AddUint32(&eventsCalledCounter, 1))
return nil
},
OnAdd: func(name string) (err error) {
assert.Equal(t, hostsFilename, name)
return nil
},
OnClose: func() (err error) { return closeCalled },
}, hostsFilename)
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
assert.NoError(t, hf.Close()) assert.Equal(t, uint32(1), atomic.LoadUint32(&eventsCalledCounter))
assert.NoError(t, os.Remove(hf.Name()))
require.ErrorIs(t, hc.Close(), closeCalled)
}) })
_, _ = hf.WriteString(" 127.0.0.1 host # comment \n") c := filtering.Config{
_, _ = hf.WriteString(" ::1 localhost#comment \n") EtcHosts: hc,
}
c.EtcHosts.Init(hf.Name())
t.Cleanup(c.EtcHosts.Close)
var snd *aghnet.SubnetDetector var snd *aghnet.SubnetDetector
snd, err = aghnet.NewSubnetDetector() snd, err = aghnet.NewSubnetDetector()
@ -1109,7 +1130,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
resp, err := dns.Exchange(req, addr.String()) resp, err := dns.Exchange(req, addr.String())
require.NoError(t, err) require.NoError(t, err)
require.Len(t, resp.Answer, 1) require.Lenf(t, resp.Answer, 1, "%#v", resp)
assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype) assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name) assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)

View File

@ -15,9 +15,9 @@ type DNSRewriteResult struct {
// the server returns. // the server returns.
type DNSRewriteResultResponse map[rules.RRType][]rules.RRValue type DNSRewriteResultResponse map[rules.RRType][]rules.RRValue
// processDNSRewrites processes DNS rewrite rules in dnsr. It returns // processDNSRewrites processes DNS rewrite rules in dnsr. It returns an empty
// an empty result if dnsr is empty. Otherwise, the result will have // result if dnsr is empty. Otherwise, the result will have either CanonName or
// either CanonName or DNSRewriteResult set. // DNSRewriteResult set.
func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) { func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) {
if len(dnsr) == 0 { if len(dnsr) == 0 {
return Result{} return Result{}

View File

@ -73,7 +73,7 @@ type Config struct {
// EtcHosts is a container of IP-hostname pairs taken from the operating // EtcHosts is a container of IP-hostname pairs taken from the operating
// system configuration files (e.g. /etc/hosts). // system configuration files (e.g. /etc/hosts).
EtcHosts *aghnet.EtcHostsContainer `yaml:"-"` EtcHosts *aghnet.HostsContainer `yaml:"-"`
// Called when the configuration is changed by HTTP request // Called when the configuration is changed by HTTP request
ConfigModified func() `yaml:"-"` ConfigModified func() `yaml:"-"`
@ -176,8 +176,8 @@ const (
// FilteredBlockedService - the host is blocked by "blocked services" settings // FilteredBlockedService - the host is blocked by "blocked services" settings
FilteredBlockedService FilteredBlockedService
// Rewritten is returned when there was a rewrite by a legacy DNS // Rewritten is returned when there was a rewrite by a legacy DNS rewrite
// rewrite rule. // rule.
Rewritten Rewritten
// RewrittenAutoHosts is returned when there was a rewrite by autohosts // RewrittenAutoHosts is returned when there was a rewrite by autohosts
@ -186,8 +186,8 @@ const (
// RewrittenRule is returned when a $dnsrewrite filter rule was applied. // RewrittenRule is returned when a $dnsrewrite filter rule was applied.
// //
// TODO(a.garipov): Remove Rewritten and RewrittenAutoHosts by merging // TODO(a.garipov): Remove Rewritten and RewrittenAutoHosts by merging their
// their functionality into RewrittenRule. // functionality into RewrittenRule.
// //
// See https://github.com/AdguardTeam/AdGuardHome/issues/2499. // See https://github.com/AdguardTeam/AdGuardHome/issues/2499.
RewrittenRule RewrittenRule
@ -371,24 +371,23 @@ type Result struct {
// Reason is the reason for blocking or unblocking the request. // Reason is the reason for blocking or unblocking the request.
Reason Reason `json:",omitempty"` Reason Reason `json:",omitempty"`
// Rules are applied rules. If Rules are not empty, each rule // Rules are applied rules. If Rules are not empty, each rule is not nil.
// is not nil.
Rules []*ResultRule `json:",omitempty"` Rules []*ResultRule `json:",omitempty"`
// ReverseHosts is the reverse lookup rewrite result. It is // ReverseHosts is the reverse lookup rewrite result. It is empty unless
// empty unless Reason is set to RewrittenAutoHosts. // Reason is set to RewrittenAutoHosts.
ReverseHosts []string `json:",omitempty"` ReverseHosts []string `json:",omitempty"`
// IPList is the lookup rewrite result. It is empty unless // IPList is the lookup rewrite result. It is empty unless Reason is set to
// Reason is set to RewrittenAutoHosts or Rewritten. // RewrittenAutoHosts or Rewritten.
IPList []net.IP `json:",omitempty"` IPList []net.IP `json:",omitempty"`
// CanonName is the CNAME value from the lookup rewrite result. // CanonName is the CNAME value from the lookup rewrite result. It is empty
// It is empty unless Reason is set to Rewritten or RewrittenRule. // unless Reason is set to Rewritten or RewrittenRule.
CanonName string `json:",omitempty"` CanonName string `json:",omitempty"`
// ServiceName is the name of the blocked service. It is empty // ServiceName is the name of the blocked service. It is empty unless
// unless Reason is set to FilteredBlockedService. // Reason is set to FilteredBlockedService.
ServiceName string `json:",omitempty"` ServiceName string `json:",omitempty"`
// DNSRewriteResult is the $dnsrewrite filter rule result. // DNSRewriteResult is the $dnsrewrite filter rule result.
@ -446,43 +445,49 @@ func (d *DNSFilter) CheckHost(
return Result{}, nil return Result{}, nil
} }
// checkEtcHosts compares the host against our /etc/hosts table. The err is // matchSysHosts tries to match the host against the operating system's hosts
// always nil, it is only there to make this a valid hostChecker function. // database.
func (d *DNSFilter) checkEtcHosts( func (d *DNSFilter) matchSysHosts(host string, qtype uint16, setts *Settings) (res Result, err error) {
host string, if d.EtcHosts == nil {
qtype uint16,
_ *Settings,
) (res Result, err error) {
if d.Config.EtcHosts == nil {
return Result{}, nil return Result{}, nil
} }
ips := d.Config.EtcHosts.Process(host, qtype) dnsres, _ := d.EtcHosts.MatchRequest(urlfilter.DNSRequest{
if ips != nil { Hostname: host,
res = Result{ SortedClientTags: setts.ClientTags,
// TODO(e.burkov): Wait for urlfilter update to pass net.IP.
ClientIP: setts.ClientIP.String(),
ClientName: setts.ClientName,
DNSType: qtype,
})
dnsr := dnsres.DNSRewrites()
if len(dnsr) == 0 {
return Result{}, nil
}
var ips []net.IP
var revHosts []string
for _, nr := range dnsr {
dr := nr.DNSRewrite
if dr == nil {
continue
}
switch val := nr.DNSRewrite.Value.(type) {
case net.IP:
ips = append(ips, val)
case string:
revHosts = append(revHosts, val)
}
}
return Result{
Reason: RewrittenAutoHosts, Reason: RewrittenAutoHosts,
IPList: ips, IPList: ips,
} ReverseHosts: revHosts,
}, nil
return res, nil
}
revHosts := d.Config.EtcHosts.ProcessReverse(host, qtype)
if len(revHosts) != 0 {
res = Result{
Reason: RewrittenAutoHosts,
}
// TODO(a.garipov): Optimize this with a buffer.
res.ReverseHosts = make([]string, len(revHosts))
for i := range revHosts {
res.ReverseHosts[i] = revHosts[i] + "."
}
return res, nil
}
return Result{}, nil
} }
// Process rewrites table // Process rewrites table
@ -647,15 +652,18 @@ func (d *DNSFilter) initFiltering(allowFilters, blockFilters []Filter) error {
return err return err
} }
func() {
d.engineLock.Lock() d.engineLock.Lock()
defer d.engineLock.Unlock()
d.reset() d.reset()
d.rulesStorage = rulesStorage d.rulesStorage = rulesStorage
d.filteringEngine = filteringEngine d.filteringEngine = filteringEngine
d.rulesStorageAllow = rulesStorageAllow d.rulesStorageAllow = rulesStorageAllow
d.filteringEngineAllow = filteringEngineAllow d.filteringEngineAllow = filteringEngineAllow
d.engineLock.Unlock() }()
// Make sure that the OS reclaims memory as soon as possible // Make sure that the OS reclaims memory as soon as possible.
debug.FreeOSMemory() debug.FreeOSMemory()
log.Debug("initialized filtering engine") log.Debug("initialized filtering engine")
@ -734,8 +742,8 @@ func (d *DNSFilter) matchHostProcessDNSResult(
} }
if dnsres.HostRulesV4 != nil || dnsres.HostRulesV6 != nil { if dnsres.HostRulesV4 != nil || dnsres.HostRulesV6 != nil {
// Question type doesn't match the host rules. Return the first // Question type doesn't match the host rules. Return the first matched
// matched host rule, but without an IP address. // host rule, but without an IP address.
var matchedRules []rules.Rule var matchedRules []rules.Rule
if dnsres.HostRulesV4 != nil { if dnsres.HostRulesV4 != nil {
matchedRules = []rules.Rule{dnsres.HostRulesV4[0]} matchedRules = []rules.Rule{dnsres.HostRulesV4[0]}
@ -760,11 +768,6 @@ func (d *DNSFilter) matchHost(
return Result{}, nil return Result{}, nil
} }
d.engineLock.RLock()
// Keep in mind that this lock must be held no just when calling Match()
// but also while using the rules returned by it.
defer d.engineLock.RUnlock()
ureq := urlfilter.DNSRequest{ ureq := urlfilter.DNSRequest{
Hostname: host, Hostname: host,
SortedClientTags: setts.ClientTags, SortedClientTags: setts.ClientTags,
@ -774,6 +777,13 @@ func (d *DNSFilter) matchHost(
DNSType: qtype, DNSType: qtype,
} }
d.engineLock.RLock()
// Keep in mind that this lock must be held no just when calling Match() but
// also while using the rules returned by it.
//
// TODO(e.burkov): Inspect if the above is true.
defer d.engineLock.RUnlock()
if d.filteringEngineAllow != nil { if d.filteringEngineAllow != nil {
dnsres, ok := d.filteringEngineAllow.MatchRequest(ureq) dnsres, ok := d.filteringEngineAllow.MatchRequest(ureq)
if ok { if ok {
@ -791,8 +801,8 @@ func (d *DNSFilter) matchHost(
if dnsr := dnsres.DNSRewrites(); len(dnsr) > 0 { if dnsr := dnsres.DNSRewrites(); len(dnsr) > 0 {
res = d.processDNSRewrites(dnsr) res = d.processDNSRewrites(dnsr)
if res.Reason == RewrittenRule && res.CanonName == host { if res.Reason == RewrittenRule && res.CanonName == host {
// A rewrite of a host to itself. Go on and try // A rewrite of a host to itself. Go on and try matching other
// matching other things. // things.
} else { } else {
return res, nil return res, nil
} }
@ -868,8 +878,8 @@ func New(c *Config, blockFilters []Filter) *DNSFilter {
} }
d.hostCheckers = []hostChecker{{ d.hostCheckers = []hostChecker{{
check: d.checkEtcHosts, check: d.matchSysHosts,
name: "etchosts", name: "hosts container",
}, { }, {
check: d.matchHost, check: d.matchHost,
name: "filtering", name: "filtering",

View File

@ -95,7 +95,9 @@ type clientsContainer struct {
// dnsServer is used for checking clients IP status access list status // dnsServer is used for checking clients IP status access list status
dnsServer *dnsforward.Server dnsServer *dnsforward.Server
etcHosts *aghnet.EtcHostsContainer // get entries from system hosts-files // etcHosts contains list of rewrite rules taken from the operating system's
// hosts databse.
etcHosts *aghnet.HostsContainer
testing bool // if TRUE, this object is used for internal tests testing bool // if TRUE, this object is used for internal tests
} }
@ -106,7 +108,7 @@ type clientsContainer struct {
func (clients *clientsContainer) Init( func (clients *clientsContainer) Init(
objects []clientObject, objects []clientObject,
dhcpServer *dhcpd.Server, dhcpServer *dhcpd.Server,
etcHosts *aghnet.EtcHostsContainer, etcHosts *aghnet.HostsContainer,
) { ) {
if clients.list != nil { if clients.list != nil {
log.Fatal("clients.list != nil") log.Fatal("clients.list != nil")
@ -121,13 +123,22 @@ func (clients *clientsContainer) Init(
clients.etcHosts = etcHosts clients.etcHosts = etcHosts
clients.addFromConfig(objects) clients.addFromConfig(objects)
if !clients.testing { if clients.testing {
return
}
clients.updateFromDHCP(true) clients.updateFromDHCP(true)
if clients.dhcpServer != nil { if clients.dhcpServer != nil {
clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged) clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged)
} }
go clients.handleHostsUpdates()
}
func (clients *clientsContainer) handleHostsUpdates() {
if clients.etcHosts != nil { if clients.etcHosts != nil {
clients.etcHosts.SetOnChanged(clients.onHostsChanged) for upd := range clients.etcHosts.Upd() {
clients.addFromHostsFile(upd)
} }
} }
} }
@ -250,10 +261,6 @@ func (clients *clientsContainer) onDHCPLeaseChanged(flags int) {
} }
} }
func (clients *clientsContainer) onHostsChanged() {
clients.addFromHostsFile()
}
// Exists checks if client with this IP address already exists. // Exists checks if client with this IP address already exists.
func (clients *clientsContainer) Exists(ip net.IP, source clientSource) (ok bool) { func (clients *clientsContainer) Exists(ip net.IP, source clientSource) (ok bool) {
clients.lock.Lock() clients.lock.Lock()
@ -697,7 +704,7 @@ func (clients *clientsContainer) SetWHOISInfo(ip net.IP, wi *RuntimeClientWHOISI
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi) log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
} }
// AddHost adds a new IP-hostname pairing. The priorities of the sources is // AddHost adds a new IP-hostname pairing. The priorities of the sources are
// taken into account. ok is true if the pairing was added. // taken into account. ok is true if the pairing was added.
func (clients *clientsContainer) AddHost(ip net.IP, host string, src clientSource) (ok bool, err error) { func (clients *clientsContainer) AddHost(ip net.IP, host string, src clientSource) (ok bool, err error) {
clients.lock.Lock() clients.lock.Lock()
@ -757,13 +764,7 @@ func (clients *clientsContainer) rmHostsBySrc(src clientSource) {
// addFromHostsFile fills the client-hostname pairing index from the system's // addFromHostsFile fills the client-hostname pairing index from the system's
// hosts files. // hosts files.
func (clients *clientsContainer) addFromHostsFile() { func (clients *clientsContainer) addFromHostsFile(hosts *netutil.IPMap) {
if clients.etcHosts == nil {
return
}
hosts := clients.etcHosts.List()
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()

View File

@ -710,7 +710,6 @@ func enableFilters(async bool) {
} }
func enableFiltersLocked(async bool) { func enableFiltersLocked(async bool) {
var whiteFilters []filtering.Filter
filters := []filtering.Filter{{ filters := []filtering.Filter{{
Data: []byte(strings.Join(config.UserRules, "\n")), Data: []byte(strings.Join(config.UserRules, "\n")),
}} }}
@ -725,18 +724,20 @@ func enableFiltersLocked(async bool) {
FilePath: filter.Path(), FilePath: filter.Path(),
}) })
} }
var allowFilters []filtering.Filter
for _, filter := range config.WhitelistFilters { for _, filter := range config.WhitelistFilters {
if !filter.Enabled { if !filter.Enabled {
continue continue
} }
whiteFilters = append(whiteFilters, filtering.Filter{ allowFilters = append(allowFilters, filtering.Filter{
ID: filter.ID, ID: filter.ID,
FilePath: filter.Path(), FilePath: filter.Path(),
}) })
} }
if err := Context.dnsFilter.SetFilters(filters, whiteFilters, async); err != nil { if err := Context.dnsFilter.SetFilters(filters, allowFilters, async); err != nil {
log.Debug("enabling filters: %s", err) log.Debug("enabling filters: %s", err)
} }

View File

@ -56,7 +56,9 @@ type homeContext struct {
filters Filtering // DNS filtering module filters Filtering // DNS filtering module
web *Web // Web (HTTP, HTTPS) module web *Web // Web (HTTP, HTTPS) module
tls *TLSMod // TLS module tls *TLSMod // TLS module
etcHosts *aghnet.EtcHostsContainer // IP-hostname pairs taken from system configuration (e.g. /etc/hosts) files // etcHosts is an IP-hostname pairs set taken from system configuration
// (e.g. /etc/hosts) files.
etcHosts *aghnet.HostsContainer
updater *updater.Updater updater *updater.Updater
subnetDetector *aghnet.SubnetDetector subnetDetector *aghnet.SubnetDetector
@ -257,8 +259,20 @@ func setupConfig(args options) (err error) {
}) })
if !args.noEtcHosts { if !args.noEtcHosts {
Context.etcHosts = &aghnet.EtcHostsContainer{} var osWritesWatcher aghos.FSWatcher
Context.etcHosts.Init("") osWritesWatcher, err = aghos.NewOSWritesWatcher()
if err != nil {
return fmt.Errorf("initing os watcher: %w", err)
}
Context.etcHosts, err = aghnet.NewHostsContainer(
aghos.RootDirFS(),
osWritesWatcher,
aghnet.DefaultHostsPaths()...,
)
if err != nil {
return fmt.Errorf("initing hosts container: %w", err)
}
} }
Context.clients.Init(config.Clients, Context.dhcpServer, Context.etcHosts) Context.clients.Init(config.Clients, Context.dhcpServer, Context.etcHosts)
config.Clients = nil config.Clients = nil
@ -424,7 +438,6 @@ func run(args options, clientBuildFS fs.FS) {
fatalOnError(err) fatalOnError(err)
Context.tls.Start() Context.tls.Start()
Context.etcHosts.Start()
go func() { go func() {
serr := startDNSServer() serr := startDNSServer()
@ -647,7 +660,11 @@ func cleanup(ctx context.Context) {
} }
} }
Context.etcHosts.Close() if Context.etcHosts != nil {
if err = Context.etcHosts.Close(); err != nil {
log.Error("stopping hosts container: %s", err)
}
}
if Context.tls != nil { if Context.tls != nil {
Context.tls.Close() Context.tls.Close()