Pull request: 2499 merge rewrites vol.1
Merge in DNS/adguard-home from 2499-merge-rewrites-vol.1 to master
Updates #2499.
Squashed commit of the following:
commit 6b308bc2b360cee8c22e506f31d62bacb4bf8fb3
Merge: f49e9186 2b635bf6
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Thu Oct 14 19:23:07 2021 +0300
Merge branch 'master' into 2499-merge-rewrites-vol.1
commit f49e9186ffc8b7074d03c6721ee56cdb09243684
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Thu Oct 14 18:50:49 2021 +0300
aghos: fix fs events filtering
commit 567dd646556606212af5dab60e3ecbb8fff22c25
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Thu Oct 14 16:50:37 2021 +0300
all: imp code, docs, fix windows
commit 140c8bf519345eb54d0e7500a996fcf465353d71
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Wed Oct 13 19:41:53 2021 +0300
aghnet: use const
commit bebf3f76bd394a498ccad812c57d4507c69529ba
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Wed Oct 13 19:32:37 2021 +0300
all: imp tests, docs
commit 9bfdbb6eb454833135d616e208e82699f98e2562
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Wed Oct 13 18:42:20 2021 +0300
all: imp path more, imp docs
commit ee9ea4c132a6b17787d150bf2bee703abaa57be3
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Wed Oct 13 16:09:46 2021 +0300
all: fix windows, imp paths
commit 6fac8338a81e9ecfebfc23a1adcb964e89f6aee6
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Mon Oct 11 19:53:35 2021 +0300
all: imp code, docs
commit da1ce1a2a3dd2be3fdff2412a6dbd596859dc249
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Mon Oct 11 18:22:50 2021 +0300
aghnet: fix windows tests
commit d29de359ed68118d71efb226a8433fac15ff5c66
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Fri Oct 8 21:02:14 2021 +0300
all: repl & imp
commit 1356c08944cdbb85ce5532d90fe5b077219ce5ff
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Fri Oct 8 01:41:19 2021 +0300
all: add tests, mv logic, added tmpfs
commit f4b11adf8998bc8d9d955c5ac9f386f671bd5213
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Thu Oct 7 14:26:30 2021 +0300
all: imp filewalker, refactor hosts container
This commit is contained in:
parent
2b635bf689
commit
2796e65468
@ -1,387 +0,0 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type onChangedT func()
|
||||
|
||||
// EtcHostsContainer - automatic DNS records
|
||||
//
|
||||
// TODO(e.burkov): Move the logic under interface. Refactor. Probably remove
|
||||
// the resolving logic.
|
||||
type EtcHostsContainer struct {
|
||||
// lock protects table and tableReverse.
|
||||
lock sync.RWMutex
|
||||
// table is the host-to-IPs map.
|
||||
table map[string][]net.IP
|
||||
// tableReverse is the IP-to-hosts map. The type of the values in the
|
||||
// map is []string.
|
||||
tableReverse *netutil.IPMap
|
||||
|
||||
hostsFn string // path to the main hosts-file
|
||||
hostsDirs []string // paths to OS-specific directories with hosts-files
|
||||
watcher *fsnotify.Watcher // file and directory watcher object
|
||||
|
||||
// onlyWritesChan used to contain only writing events from watcher.
|
||||
onlyWritesChan chan fsnotify.Event
|
||||
|
||||
onChanged onChangedT // notification to other modules
|
||||
}
|
||||
|
||||
// SetOnChanged - set callback function that will be called when the data is changed
|
||||
func (ehc *EtcHostsContainer) SetOnChanged(onChanged onChangedT) {
|
||||
ehc.onChanged = onChanged
|
||||
}
|
||||
|
||||
// Notify other modules
|
||||
func (ehc *EtcHostsContainer) notify() {
|
||||
if ehc.onChanged == nil {
|
||||
return
|
||||
}
|
||||
ehc.onChanged()
|
||||
}
|
||||
|
||||
// Init - initialize
|
||||
// hostsFn: Override default name for the hosts-file (optional)
|
||||
func (ehc *EtcHostsContainer) Init(hostsFn string) {
|
||||
ehc.table = make(map[string][]net.IP)
|
||||
ehc.onlyWritesChan = make(chan fsnotify.Event, 2)
|
||||
|
||||
ehc.hostsFn = "/etc/hosts"
|
||||
if runtime.GOOS == "windows" {
|
||||
ehc.hostsFn = os.ExpandEnv("$SystemRoot\\system32\\drivers\\etc\\hosts")
|
||||
}
|
||||
if len(hostsFn) != 0 {
|
||||
ehc.hostsFn = hostsFn
|
||||
}
|
||||
|
||||
if aghos.IsOpenWrt() {
|
||||
// OpenWrt: "/tmp/hosts/dhcp.cfg01411c".
|
||||
ehc.hostsDirs = append(ehc.hostsDirs, "/tmp/hosts")
|
||||
}
|
||||
|
||||
// Load hosts initially
|
||||
ehc.updateHosts()
|
||||
|
||||
var err error
|
||||
ehc.watcher, err = fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
log.Error("etchosts: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Start - start module
|
||||
func (ehc *EtcHostsContainer) Start() {
|
||||
if ehc == nil {
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("Start etchostscontainer module")
|
||||
|
||||
ehc.updateHosts()
|
||||
|
||||
if ehc.watcher != nil {
|
||||
go ehc.watcherLoop()
|
||||
|
||||
err := ehc.watcher.Add(ehc.hostsFn)
|
||||
if err != nil {
|
||||
log.Error("Error while initializing watcher for a file %s: %s", ehc.hostsFn, err)
|
||||
}
|
||||
|
||||
for _, dir := range ehc.hostsDirs {
|
||||
err = ehc.watcher.Add(dir)
|
||||
if err != nil {
|
||||
log.Error("Error while initializing watcher for a directory %s: %s", dir, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close - close module
|
||||
func (ehc *EtcHostsContainer) Close() {
|
||||
if ehc == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if ehc.watcher != nil {
|
||||
_ = ehc.watcher.Close()
|
||||
}
|
||||
|
||||
// Don't close onlyWritesChan here and let onlyWrites close it after
|
||||
// watcher.Events is closed to prevent close races.
|
||||
}
|
||||
|
||||
// Process returns the list of IP addresses for the hostname or nil if nothing
|
||||
// found.
|
||||
func (ehc *EtcHostsContainer) Process(host string, qtype uint16) []net.IP {
|
||||
if qtype == dns.TypePTR {
|
||||
return nil
|
||||
}
|
||||
|
||||
var ipsCopy []net.IP
|
||||
ehc.lock.RLock()
|
||||
defer ehc.lock.RUnlock()
|
||||
|
||||
if ips, ok := ehc.table[host]; ok {
|
||||
ipsCopy = make([]net.IP, len(ips))
|
||||
copy(ipsCopy, ips)
|
||||
}
|
||||
|
||||
log.Debug("etchosts: answer: %s -> %v", host, ipsCopy)
|
||||
return ipsCopy
|
||||
}
|
||||
|
||||
// ProcessReverse processes a PTR request. It returns nil if nothing is found.
|
||||
func (ehc *EtcHostsContainer) ProcessReverse(addr string, qtype uint16) (hosts []string) {
|
||||
if qtype != dns.TypePTR {
|
||||
return nil
|
||||
}
|
||||
|
||||
ip, err := netutil.IPFromReversedAddr(addr)
|
||||
if err != nil {
|
||||
log.Error("etchosts: reversed addr: %s", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
ehc.lock.RLock()
|
||||
defer ehc.lock.RUnlock()
|
||||
|
||||
v, ok := ehc.tableReverse.Get(ip)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
hosts, ok = v.([]string)
|
||||
if !ok {
|
||||
log.Error("etchosts: bad type %T in tableReverse for %s", v, ip)
|
||||
|
||||
return nil
|
||||
} else if len(hosts) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debug("etchosts: reverse-lookup: %s -> %s", addr, hosts)
|
||||
|
||||
return hosts
|
||||
}
|
||||
|
||||
// List returns an IP-to-hostnames table. The type of the values in the map is
|
||||
// []string. It is safe for concurrent use.
|
||||
func (ehc *EtcHostsContainer) List() (ipToHosts *netutil.IPMap) {
|
||||
ehc.lock.RLock()
|
||||
defer ehc.lock.RUnlock()
|
||||
|
||||
return ehc.tableReverse.ShallowClone()
|
||||
}
|
||||
|
||||
// update table
|
||||
func (ehc *EtcHostsContainer) updateTable(table map[string][]net.IP, host string, ipAddr net.IP) {
|
||||
ips, ok := table[host]
|
||||
if ok {
|
||||
for _, ip := range ips {
|
||||
if ip.Equal(ipAddr) {
|
||||
// IP already exists: don't add duplicates
|
||||
ok = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
ips = append(ips, ipAddr)
|
||||
table[host] = ips
|
||||
}
|
||||
} else {
|
||||
table[host] = []net.IP{ipAddr}
|
||||
ok = true
|
||||
}
|
||||
if ok {
|
||||
log.Debug("etchosts: added %s -> %s", ipAddr, host)
|
||||
}
|
||||
}
|
||||
|
||||
// updateTableRev updates the reverse address table.
|
||||
func (ehc *EtcHostsContainer) updateTableRev(tableRev *netutil.IPMap, newHost string, ip net.IP) {
|
||||
v, ok := tableRev.Get(ip)
|
||||
if !ok {
|
||||
tableRev.Set(ip, []string{newHost})
|
||||
log.Debug("etchosts: added reverse-address %s -> %s", ip, newHost)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
hosts, _ := v.([]string)
|
||||
for _, host := range hosts {
|
||||
if host == newHost {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
hosts = append(hosts, newHost)
|
||||
tableRev.Set(ip, hosts)
|
||||
|
||||
log.Debug("etchosts: added reverse-address %s -> %s", ip, newHost)
|
||||
}
|
||||
|
||||
// parseHostsLine parses hosts from the fields.
|
||||
func parseHostsLine(fields []string) (hosts []string) {
|
||||
for _, f := range fields {
|
||||
hashIdx := strings.IndexByte(f, '#')
|
||||
if hashIdx == 0 {
|
||||
// The rest of the fields are a part of the comment.
|
||||
// Skip immediately.
|
||||
return
|
||||
} else if hashIdx > 0 {
|
||||
// Only a part of the field is a comment.
|
||||
hosts = append(hosts, f[:hashIdx])
|
||||
|
||||
return hosts
|
||||
}
|
||||
|
||||
hosts = append(hosts, f)
|
||||
}
|
||||
|
||||
return hosts
|
||||
}
|
||||
|
||||
// load reads IP-hostname pairs from the hosts file. Multiple hostnames per
|
||||
// line for one IP are supported.
|
||||
func (ehc *EtcHostsContainer) load(
|
||||
table map[string][]net.IP,
|
||||
tableRev *netutil.IPMap,
|
||||
fn string,
|
||||
) {
|
||||
f, err := os.Open(fn)
|
||||
if err != nil {
|
||||
log.Error("etchosts: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
derr := f.Close()
|
||||
if derr != nil {
|
||||
log.Error("etchosts: closing file: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
log.Debug("etchosts: loading hosts from file %s", fn)
|
||||
|
||||
s := bufio.NewScanner(f)
|
||||
for s.Scan() {
|
||||
line := strings.TrimSpace(s.Text())
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
ip := net.ParseIP(fields[0])
|
||||
if ip == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
hosts := parseHostsLine(fields[1:])
|
||||
for _, host := range hosts {
|
||||
ehc.updateTable(table, host, ip)
|
||||
ehc.updateTableRev(tableRev, host, ip)
|
||||
}
|
||||
}
|
||||
|
||||
err = s.Err()
|
||||
if err != nil {
|
||||
log.Error("etchosts: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// onlyWrites is a filter for (*fsnotify.Watcher).Events.
|
||||
func (ehc *EtcHostsContainer) onlyWrites() {
|
||||
for event := range ehc.watcher.Events {
|
||||
if event.Op&fsnotify.Write == fsnotify.Write {
|
||||
ehc.onlyWritesChan <- event
|
||||
}
|
||||
}
|
||||
|
||||
close(ehc.onlyWritesChan)
|
||||
}
|
||||
|
||||
// Receive notifications from fsnotify package
|
||||
func (ehc *EtcHostsContainer) watcherLoop() {
|
||||
go ehc.onlyWrites()
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-ehc.onlyWritesChan:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Assume that we sometimes have the same event occurred
|
||||
// several times.
|
||||
repeat := true
|
||||
for repeat {
|
||||
select {
|
||||
case _, ok = <-ehc.onlyWritesChan:
|
||||
repeat = ok
|
||||
default:
|
||||
repeat = false
|
||||
}
|
||||
}
|
||||
|
||||
if event.Op&fsnotify.Write == fsnotify.Write {
|
||||
log.Debug("etchosts: modified: %s", event.Name)
|
||||
ehc.updateHosts()
|
||||
}
|
||||
|
||||
case err, ok := <-ehc.watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
log.Error("etchosts: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// updateHosts - loads system hosts
|
||||
func (ehc *EtcHostsContainer) updateHosts() {
|
||||
table := make(map[string][]net.IP)
|
||||
tableRev := netutil.NewIPMap(0)
|
||||
|
||||
ehc.load(table, tableRev, ehc.hostsFn)
|
||||
|
||||
for _, dir := range ehc.hostsDirs {
|
||||
des, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
log.Error("etchosts: Opening directory: %q: %s", dir, err)
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
for _, de := range des {
|
||||
ehc.load(table, tableRev, filepath.Join(dir, de.Name()))
|
||||
}
|
||||
}
|
||||
|
||||
func() {
|
||||
ehc.lock.Lock()
|
||||
defer ehc.lock.Unlock()
|
||||
|
||||
ehc.table = table
|
||||
ehc.tableReverse = tableRev
|
||||
}()
|
||||
|
||||
ehc.notify()
|
||||
}
|
@ -1,130 +0,0 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
aghtest.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
func prepareTestFile(t *testing.T) (f *os.File) {
|
||||
t.Helper()
|
||||
|
||||
dir := t.TempDir()
|
||||
|
||||
f, err := os.CreateTemp(dir, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, f)
|
||||
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, f.Close())
|
||||
})
|
||||
|
||||
return f
|
||||
}
|
||||
|
||||
func assertWriting(t *testing.T, f *os.File, strs ...string) {
|
||||
t.Helper()
|
||||
|
||||
for _, str := range strs {
|
||||
n, err := f.WriteString(str)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, n, len(str))
|
||||
}
|
||||
}
|
||||
|
||||
func TestEtcHostsContainerResolution(t *testing.T) {
|
||||
ehc := &EtcHostsContainer{}
|
||||
|
||||
f := prepareTestFile(t)
|
||||
|
||||
assertWriting(t, f,
|
||||
" 127.0.0.1 host localhost # comment \n",
|
||||
" ::1 localhost#comment \n",
|
||||
)
|
||||
ehc.Init(f.Name())
|
||||
|
||||
t.Run("existing_host", func(t *testing.T) {
|
||||
ips := ehc.Process("localhost", dns.TypeA)
|
||||
require.Len(t, ips, 1)
|
||||
assert.Equal(t, net.IPv4(127, 0, 0, 1), ips[0])
|
||||
})
|
||||
|
||||
t.Run("unknown_host", func(t *testing.T) {
|
||||
ips := ehc.Process("newhost", dns.TypeA)
|
||||
assert.Nil(t, ips)
|
||||
|
||||
// Comment.
|
||||
ips = ehc.Process("comment", dns.TypeA)
|
||||
assert.Nil(t, ips)
|
||||
})
|
||||
|
||||
t.Run("hosts_file", func(t *testing.T) {
|
||||
names, ok := ehc.List().Get(net.IP{127, 0, 0, 1})
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, []string{"host", "localhost"}, names)
|
||||
})
|
||||
|
||||
t.Run("ptr", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
wantIP string
|
||||
wantHost string
|
||||
wantLen int
|
||||
}{
|
||||
{wantIP: "127.0.0.1", wantHost: "host", wantLen: 2},
|
||||
{wantIP: "::1", wantHost: "localhost", wantLen: 1},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
a, err := dns.ReverseAddr(tc.wantIP)
|
||||
require.NoError(t, err)
|
||||
|
||||
a = strings.TrimSuffix(a, ".")
|
||||
hosts := ehc.ProcessReverse(a, dns.TypePTR)
|
||||
require.Len(t, hosts, tc.wantLen)
|
||||
assert.Equal(t, tc.wantHost, hosts[0])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestEtcHostsContainerFSNotify(t *testing.T) {
|
||||
ehc := &EtcHostsContainer{}
|
||||
|
||||
f := prepareTestFile(t)
|
||||
|
||||
assertWriting(t, f, " 127.0.0.1 host localhost \n")
|
||||
ehc.Init(f.Name())
|
||||
|
||||
t.Run("unknown_host", func(t *testing.T) {
|
||||
ips := ehc.Process("newhost", dns.TypeA)
|
||||
assert.Nil(t, ips)
|
||||
})
|
||||
|
||||
// Start monitoring for changes.
|
||||
ehc.Start()
|
||||
t.Cleanup(ehc.Close)
|
||||
|
||||
assertWriting(t, f, "127.0.0.2 newhost\n")
|
||||
require.NoError(t, f.Sync())
|
||||
|
||||
// Wait until fsnotify has triggered and processed the file-modification
|
||||
// event.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
t.Run("notified", func(t *testing.T) {
|
||||
ips := ehc.Process("newhost", dns.TypeA)
|
||||
assert.NotNil(t, ips)
|
||||
require.Len(t, ips, 1)
|
||||
assert.True(t, net.IP{127, 0, 0, 2}.Equal(ips[0]))
|
||||
})
|
||||
}
|
343
internal/aghnet/hostscontainer.go
Normal file
343
internal/aghnet/hostscontainer.go
Normal file
@ -0,0 +1,343 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
"github.com/AdguardTeam/urlfilter/filterlist"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// DefaultHostsPaths returns the slice of paths default for the operating system
|
||||
// to files and directories which are containing the hosts database. The result
|
||||
// is intended to use within fs.FS so the initial slash is omitted.
|
||||
func DefaultHostsPaths() (paths []string) {
|
||||
return defaultHostsPaths()
|
||||
}
|
||||
|
||||
// hostsContainerPref is a prefix for logging and wrapping errors in
|
||||
// HostsContainer's methods.
|
||||
const hostsContainerPref = "hosts container"
|
||||
|
||||
// HostsContainer stores the relevant hosts database provided by the OS and
|
||||
// processes both A/AAAA and PTR DNS requests for those.
|
||||
type HostsContainer struct {
|
||||
// engLock protects rulesStrg and engine.
|
||||
engLock *sync.RWMutex
|
||||
|
||||
// rulesStrg stores the rules obtained from the hosts' file.
|
||||
rulesStrg *filterlist.RuleStorage
|
||||
// engine serves rulesStrg.
|
||||
engine *urlfilter.DNSEngine
|
||||
|
||||
// Updates is the channel for receiving updated hosts. The receivable map's
|
||||
// values has a type of slice of strings.
|
||||
updates chan *netutil.IPMap
|
||||
|
||||
// fsys is the working file system to read hosts files from.
|
||||
fsys fs.FS
|
||||
|
||||
// w tracks the changes in specified files and directories.
|
||||
w aghos.FSWatcher
|
||||
// patterns stores specified paths in the fs.Glob-compatible form.
|
||||
patterns []string
|
||||
}
|
||||
|
||||
// errNoPaths is returned when there are no paths to watch passed to the
|
||||
// HostsContainer.
|
||||
const errNoPaths errors.Error = "hosts paths are empty"
|
||||
|
||||
// NewHostsContainer creates a container of hosts, that watches the paths with
|
||||
// w. paths shouldn't be empty and each of them should locate either a file or
|
||||
// a directory in fsys. fsys and w must be non-nil.
|
||||
func NewHostsContainer(
|
||||
fsys fs.FS,
|
||||
w aghos.FSWatcher,
|
||||
paths ...string,
|
||||
) (hc *HostsContainer, err error) {
|
||||
defer func() { err = errors.Annotate(err, "%s: %w", hostsContainerPref) }()
|
||||
|
||||
if len(paths) == 0 {
|
||||
return nil, errNoPaths
|
||||
}
|
||||
|
||||
patterns, err := pathsToPatterns(fsys, paths)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hc = &HostsContainer{
|
||||
engLock: &sync.RWMutex{},
|
||||
updates: make(chan *netutil.IPMap, 1),
|
||||
fsys: fsys,
|
||||
w: w,
|
||||
patterns: patterns,
|
||||
}
|
||||
|
||||
log.Debug("%s: starting", hostsContainerPref)
|
||||
|
||||
// Load initially.
|
||||
if err = hc.refresh(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, p := range paths {
|
||||
err = w.Add(p)
|
||||
if err == nil {
|
||||
continue
|
||||
} else if errors.Is(err, fs.ErrNotExist) {
|
||||
log.Debug("%s: file %q expected to exist but doesn't", hostsContainerPref, p)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("adding path: %w", err)
|
||||
}
|
||||
|
||||
go hc.handleEvents()
|
||||
|
||||
return hc, nil
|
||||
}
|
||||
|
||||
// MatchRequest is the request processing method to resolve hostnames and
|
||||
// addresses from the operating system's hosts files. Any request not of A/AAAA
|
||||
// or PTR type will return with an empty result. It's safe for concurrent use.
|
||||
func (hc *HostsContainer) MatchRequest(
|
||||
req urlfilter.DNSRequest,
|
||||
) (res urlfilter.DNSResult, ok bool) {
|
||||
switch req.DNSType {
|
||||
case dns.TypeA, dns.TypeAAAA, dns.TypePTR:
|
||||
log.Debug("%s: handling the request", hostsContainerPref)
|
||||
default:
|
||||
return urlfilter.DNSResult{}, false
|
||||
}
|
||||
|
||||
hc.engLock.RLock()
|
||||
defer hc.engLock.RUnlock()
|
||||
|
||||
return hc.engine.MatchRequest(req)
|
||||
}
|
||||
|
||||
// Close implements the io.Closer interface for *HostsContainer.
|
||||
func (hc *HostsContainer) Close() (err error) {
|
||||
log.Debug("%s: closing hosts container", hostsContainerPref)
|
||||
|
||||
return errors.Annotate(hc.w.Close(), "%s: closing: %w", hostsContainerPref)
|
||||
}
|
||||
|
||||
// Upd returns the channel into which the updates are sent.
|
||||
func (hc *HostsContainer) Upd() (updates <-chan *netutil.IPMap) {
|
||||
return hc.updates
|
||||
}
|
||||
|
||||
// pathsToPatterns converts paths into patterns compatible with fs.Glob.
|
||||
func pathsToPatterns(fsys fs.FS, paths []string) (patterns []string, err error) {
|
||||
for i, p := range paths {
|
||||
var fi fs.FileInfo
|
||||
if fi, err = fs.Stat(fsys, p); err != nil {
|
||||
return nil, fmt.Errorf("%q at index %d: %w", p, i, err)
|
||||
}
|
||||
|
||||
if fi.IsDir() {
|
||||
p = path.Join(p, "*")
|
||||
}
|
||||
|
||||
patterns = append(patterns, p)
|
||||
}
|
||||
|
||||
return patterns, nil
|
||||
}
|
||||
|
||||
// handleEvents concurrently handles the events. It closes the update channel
|
||||
// of HostsContainer when finishes. Used to be called within a goroutine.
|
||||
func (hc *HostsContainer) handleEvents() {
|
||||
defer log.OnPanic(fmt.Sprintf("%s: handling events", hostsContainerPref))
|
||||
|
||||
defer close(hc.updates)
|
||||
|
||||
for range hc.w.Events() {
|
||||
if err := hc.refresh(); err != nil {
|
||||
log.Error("%s: %s", hostsContainerPref, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// hostsParser is a helper type to parse rules from the operating system's hosts
|
||||
// file.
|
||||
type hostsParser struct {
|
||||
// rules builds the resulting rules list content.
|
||||
rules *strings.Builder
|
||||
|
||||
// table stores only the unique IP-hostname pairs. It's also sent to the
|
||||
// updates channel afterwards.
|
||||
table *netutil.IPMap
|
||||
}
|
||||
|
||||
// parseHostsFile is a aghtest.FileWalker for parsing the files with hosts
|
||||
// syntax. It never signs to stop the walking.
|
||||
//
|
||||
// See man hosts(5).
|
||||
func (hp hostsParser) parseHostsFile(
|
||||
r io.Reader,
|
||||
) (patterns []string, cont bool, err error) {
|
||||
s := bufio.NewScanner(r)
|
||||
for s.Scan() {
|
||||
ip, hosts := hp.parseLine(s.Text())
|
||||
if ip == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, host := range hosts {
|
||||
hp.addPair(ip, host)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, true, s.Err()
|
||||
}
|
||||
|
||||
// parseLine parses the line having the hosts syntax ignoring invalid ones.
|
||||
func (hp hostsParser) parseLine(line string) (ip net.IP, hosts []string) {
|
||||
line = strings.TrimSpace(line)
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 2 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if ip = net.ParseIP(fields[0]); ip == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
loop:
|
||||
for _, f := range fields[1:] {
|
||||
switch hashIdx := strings.IndexByte(f, '#'); hashIdx {
|
||||
case 0:
|
||||
// The rest of the fields are a part of the comment so skip
|
||||
// immediately.
|
||||
break loop
|
||||
case -1:
|
||||
hosts = append(hosts, f)
|
||||
default:
|
||||
// Only a part of the field is a comment.
|
||||
hosts = append(hosts, f[:hashIdx])
|
||||
|
||||
break loop
|
||||
}
|
||||
}
|
||||
|
||||
return ip, hosts
|
||||
}
|
||||
|
||||
// add returns true if the pair of ip and host wasn't added to the hp before.
|
||||
func (hp hostsParser) add(ip net.IP, host string) (added bool) {
|
||||
v, ok := hp.table.Get(ip)
|
||||
hosts, _ := v.([]string)
|
||||
if ok && stringutil.InSlice(hosts, host) {
|
||||
return false
|
||||
}
|
||||
|
||||
hp.table.Set(ip, append(hosts, host))
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// addPair puts the pair of ip and host to the rules builder if needed.
|
||||
func (hp hostsParser) addPair(ip net.IP, host string) {
|
||||
arpa, err := netutil.IPToReversedAddr(ip)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !hp.add(ip, host) {
|
||||
return
|
||||
}
|
||||
|
||||
qtype := "AAAA"
|
||||
if ip.To4() != nil {
|
||||
// Assume the validation of the IP address is performed already.
|
||||
qtype = "A"
|
||||
}
|
||||
|
||||
stringutil.WriteToBuilder(
|
||||
hp.rules,
|
||||
"||",
|
||||
host,
|
||||
"^$dnsrewrite=NOERROR;",
|
||||
qtype,
|
||||
";",
|
||||
ip.String(),
|
||||
"\n",
|
||||
"||",
|
||||
arpa,
|
||||
"^$dnsrewrite=NOERROR;PTR;",
|
||||
dns.Fqdn(host),
|
||||
"\n",
|
||||
)
|
||||
|
||||
log.Debug("%s: added ip-host pair %q/%q", hostsContainerPref, ip, host)
|
||||
}
|
||||
|
||||
// sendUpd tries to send the parsed data to the ch.
|
||||
func (hp hostsParser) sendUpd(ch chan *netutil.IPMap) {
|
||||
log.Debug("%s: sending upd", hostsContainerPref)
|
||||
select {
|
||||
case ch <- hp.table:
|
||||
// Updates are delivered. Go on.
|
||||
default:
|
||||
log.Debug("%s: the buffer is full", hostsContainerPref)
|
||||
}
|
||||
}
|
||||
|
||||
// newStrg creates a new rules storage from parsed data.
|
||||
func (hp hostsParser) newStrg() (s *filterlist.RuleStorage, err error) {
|
||||
return filterlist.NewRuleStorage([]filterlist.RuleList{&filterlist.StringRuleList{
|
||||
ID: 1,
|
||||
RulesText: hp.rules.String(),
|
||||
IgnoreCosmetic: true,
|
||||
}})
|
||||
}
|
||||
|
||||
// refresh gets the data from specified files and propagates the updates.
|
||||
func (hc *HostsContainer) refresh() (err error) {
|
||||
log.Debug("%s: refreshing", hostsContainerPref)
|
||||
|
||||
hp := hostsParser{
|
||||
rules: &strings.Builder{},
|
||||
table: netutil.NewIPMap(0),
|
||||
}
|
||||
|
||||
_, err = aghos.FileWalker(hp.parseHostsFile).Walk(hc.fsys, hc.patterns...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("updating: %w", err)
|
||||
}
|
||||
|
||||
defer hp.sendUpd(hc.updates)
|
||||
|
||||
var rulesStrg *filterlist.RuleStorage
|
||||
if rulesStrg, err = hp.newStrg(); err != nil {
|
||||
return fmt.Errorf("initializing rules storage: %w", err)
|
||||
}
|
||||
|
||||
hc.resetEng(rulesStrg)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hc *HostsContainer) resetEng(rulesStrg *filterlist.RuleStorage) {
|
||||
hc.engLock.Lock()
|
||||
defer hc.engLock.Unlock()
|
||||
|
||||
hc.rulesStrg = rulesStrg
|
||||
hc.engine = urlfilter.NewDNSEngine(hc.rulesStrg)
|
||||
}
|
18
internal/aghnet/hostscontainer_linux.go
Normal file
18
internal/aghnet/hostscontainer_linux.go
Normal file
@ -0,0 +1,18 @@
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
)
|
||||
|
||||
func defaultHostsPaths() (paths []string) {
|
||||
paths = []string{"etc/hosts"}
|
||||
|
||||
if aghos.IsOpenWrt() {
|
||||
paths = append(paths, "tmp/hosts")
|
||||
}
|
||||
|
||||
return paths
|
||||
}
|
8
internal/aghnet/hostscontainer_others.go
Normal file
8
internal/aghnet/hostscontainer_others.go
Normal file
@ -0,0 +1,8 @@
|
||||
//go:build !(windows || linux)
|
||||
// +build !windows,!linux
|
||||
|
||||
package aghnet
|
||||
|
||||
func defaultHostsPaths() (paths []string) {
|
||||
return []string{"etc/hosts"}
|
||||
}
|
504
internal/aghnet/hostscontainer_test.go
Normal file
504
internal/aghnet/hostscontainer_test.go
Normal file
@ -0,0 +1,504 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"net"
|
||||
"path"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
nl = "\n"
|
||||
sp = " "
|
||||
)
|
||||
|
||||
const closeCalled errors.Error = "close method called"
|
||||
|
||||
// fsWatcherOnCloseStub is a stub implementation of the Close method of
|
||||
// aghos.FSWatcher.
|
||||
func fsWatcherOnCloseStub() (err error) {
|
||||
return closeCalled
|
||||
}
|
||||
|
||||
func TestNewHostsContainer(t *testing.T) {
|
||||
const dirname = "dir"
|
||||
const filename = "file1"
|
||||
|
||||
p := path.Join(dirname, filename)
|
||||
|
||||
testFS := fstest.MapFS{
|
||||
p: &fstest.MapFile{Data: []byte("127.0.0.1 localhost")},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
paths []string
|
||||
wantErr error
|
||||
wantPatterns []string
|
||||
}{{
|
||||
name: "one_file",
|
||||
paths: []string{p},
|
||||
wantErr: nil,
|
||||
wantPatterns: []string{p},
|
||||
}, {
|
||||
name: "no_files",
|
||||
paths: []string{},
|
||||
wantErr: errNoPaths,
|
||||
wantPatterns: nil,
|
||||
}, {
|
||||
name: "non-existent_file",
|
||||
paths: []string{path.Join(dirname, filename+"2")},
|
||||
wantErr: fs.ErrNotExist,
|
||||
wantPatterns: nil,
|
||||
}, {
|
||||
name: "whole_dir",
|
||||
paths: []string{dirname},
|
||||
wantErr: nil,
|
||||
wantPatterns: []string{path.Join(dirname, "*")},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
onAdd := func(name string) (err error) {
|
||||
assert.Contains(t, tc.paths, name)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var eventsCalledCounter uint32
|
||||
eventsCh := make(chan struct{})
|
||||
onEvents := func() (e <-chan struct{}) {
|
||||
assert.Equal(t, uint32(1), atomic.AddUint32(&eventsCalledCounter, 1))
|
||||
|
||||
return eventsCh
|
||||
}
|
||||
|
||||
hc, err := NewHostsContainer(testFS, &aghtest.FSWatcher{
|
||||
OnEvents: onEvents,
|
||||
OnAdd: onAdd,
|
||||
OnClose: fsWatcherOnCloseStub,
|
||||
}, tc.paths...)
|
||||
if tc.wantErr != nil {
|
||||
require.ErrorIs(t, err, tc.wantErr)
|
||||
|
||||
assert.Nil(t, hc)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.ErrorIs(t, hc.Close(), closeCalled)
|
||||
})
|
||||
|
||||
require.NotNil(t, hc)
|
||||
|
||||
assert.Equal(t, tc.wantPatterns, hc.patterns)
|
||||
assert.NotNil(t, <-hc.Upd())
|
||||
|
||||
eventsCh <- struct{}{}
|
||||
assert.Equal(t, uint32(1), atomic.LoadUint32(&eventsCalledCounter))
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("nil_fs", func(t *testing.T) {
|
||||
require.Panics(t, func() {
|
||||
_, _ = NewHostsContainer(nil, &aghtest.FSWatcher{
|
||||
// Those shouldn't panic.
|
||||
OnEvents: func() (e <-chan struct{}) { return nil },
|
||||
OnAdd: func(name string) (err error) { return nil },
|
||||
OnClose: func() (err error) { return nil },
|
||||
}, p)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("nil_watcher", func(t *testing.T) {
|
||||
require.Panics(t, func() {
|
||||
_, _ = NewHostsContainer(testFS, nil, p)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("err_watcher", func(t *testing.T) {
|
||||
const errOnAdd errors.Error = "error"
|
||||
|
||||
errWatcher := &aghtest.FSWatcher{
|
||||
OnEvents: func() (e <-chan struct{}) { panic("not implemented") },
|
||||
OnAdd: func(name string) (err error) { return errOnAdd },
|
||||
OnClose: func() (err error) { panic("not implemented") },
|
||||
}
|
||||
|
||||
hc, err := NewHostsContainer(testFS, errWatcher, p)
|
||||
require.ErrorIs(t, err, errOnAdd)
|
||||
|
||||
assert.Nil(t, hc)
|
||||
})
|
||||
}
|
||||
|
||||
func TestHostsContainer_Refresh(t *testing.T) {
|
||||
knownIP := net.IP{127, 0, 0, 1}
|
||||
|
||||
const knownHost = "localhost"
|
||||
const knownAlias = "hocallost"
|
||||
|
||||
const dirname = "dir"
|
||||
const filename1 = "file1"
|
||||
const filename2 = "file2"
|
||||
|
||||
p1 := path.Join(dirname, filename1)
|
||||
p2 := path.Join(dirname, filename2)
|
||||
|
||||
testFS := fstest.MapFS{
|
||||
p1: &fstest.MapFile{
|
||||
Data: []byte(strings.Join([]string{knownIP.String(), knownHost}, sp) + nl),
|
||||
},
|
||||
}
|
||||
|
||||
eventsCh := make(chan struct{}, 1)
|
||||
t.Cleanup(func() { close(eventsCh) })
|
||||
|
||||
w := &aghtest.FSWatcher{
|
||||
OnEvents: func() (e <-chan struct{}) { return eventsCh },
|
||||
OnAdd: func(name string) (err error) {
|
||||
assert.Equal(t, dirname, name)
|
||||
|
||||
return nil
|
||||
},
|
||||
OnClose: fsWatcherOnCloseStub,
|
||||
}
|
||||
|
||||
hc, err := NewHostsContainer(testFS, w, dirname)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.ErrorIs(t, hc.Close(), closeCalled) })
|
||||
|
||||
checkRefresh := func(t *testing.T, wantHosts []string) {
|
||||
upd, ok := <-hc.Upd()
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, upd)
|
||||
|
||||
assert.Equal(t, 1, upd.Len())
|
||||
|
||||
v, ok := upd.Get(knownIP)
|
||||
require.True(t, ok)
|
||||
|
||||
var hosts []string
|
||||
hosts, ok = v.([]string)
|
||||
require.True(t, ok)
|
||||
require.Len(t, hosts, len(wantHosts))
|
||||
|
||||
assert.Equal(t, wantHosts, hosts)
|
||||
}
|
||||
|
||||
t.Run("initial_refresh", func(t *testing.T) {
|
||||
checkRefresh(t, []string{knownHost})
|
||||
})
|
||||
|
||||
testFS[p2] = &fstest.MapFile{
|
||||
Data: []byte(strings.Join([]string{knownIP.String(), knownAlias}, sp) + nl),
|
||||
}
|
||||
|
||||
eventsCh <- struct{}{}
|
||||
|
||||
t.Run("second_refresh", func(t *testing.T) {
|
||||
checkRefresh(t, []string{knownHost, knownAlias})
|
||||
})
|
||||
}
|
||||
|
||||
func TestHostsContainer_MatchRequest(t *testing.T) {
|
||||
var (
|
||||
ip4 = net.IP{127, 0, 0, 1}
|
||||
ip6 = net.IP{
|
||||
0, 0, 0, 0,
|
||||
0, 0, 0, 0,
|
||||
0, 0, 0, 0,
|
||||
0, 0, 0, 1,
|
||||
}
|
||||
|
||||
hostname4 = "localhost"
|
||||
hostname6 = "localhostv6"
|
||||
hostname4a = "abcd"
|
||||
|
||||
reversed4, _ = netutil.IPToReversedAddr(ip4)
|
||||
reversed6, _ = netutil.IPToReversedAddr(ip6)
|
||||
)
|
||||
|
||||
const filename = "file1"
|
||||
|
||||
gsfs := fstest.MapFS{
|
||||
filename: &fstest.MapFile{Data: []byte(
|
||||
strings.Join([]string{ip4.String(), hostname4, hostname4a}, sp) + nl +
|
||||
strings.Join([]string{ip6.String(), hostname6}, sp) + nl +
|
||||
strings.Join([]string{"256.256.256.256", "fakebroadcast"}, sp) + nl,
|
||||
)},
|
||||
}
|
||||
|
||||
hc, err := NewHostsContainer(gsfs, &aghtest.FSWatcher{
|
||||
OnEvents: func() (e <-chan struct{}) { panic("not implemented") },
|
||||
OnAdd: func(name string) (err error) {
|
||||
assert.Equal(t, filename, name)
|
||||
|
||||
return nil
|
||||
},
|
||||
OnClose: fsWatcherOnCloseStub,
|
||||
}, filename)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.ErrorIs(t, hc.Close(), closeCalled) })
|
||||
|
||||
testCase := []struct {
|
||||
name string
|
||||
want interface{}
|
||||
req urlfilter.DNSRequest
|
||||
}{{
|
||||
name: "a",
|
||||
want: ip4.To16(),
|
||||
req: urlfilter.DNSRequest{
|
||||
Hostname: hostname4,
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
}, {
|
||||
name: "aaaa",
|
||||
want: ip6,
|
||||
req: urlfilter.DNSRequest{
|
||||
Hostname: hostname6,
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
}, {
|
||||
name: "ptr",
|
||||
want: dns.Fqdn(hostname4),
|
||||
req: urlfilter.DNSRequest{
|
||||
Hostname: reversed4,
|
||||
DNSType: dns.TypePTR,
|
||||
},
|
||||
}, {
|
||||
name: "ptr_v6",
|
||||
want: dns.Fqdn(hostname6),
|
||||
req: urlfilter.DNSRequest{
|
||||
Hostname: reversed6,
|
||||
DNSType: dns.TypePTR,
|
||||
},
|
||||
}, {
|
||||
name: "a_alias",
|
||||
want: ip4.To16(),
|
||||
req: urlfilter.DNSRequest{
|
||||
Hostname: hostname4a,
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
}}
|
||||
|
||||
for _, tc := range testCase {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
res, ok := hc.MatchRequest(tc.req)
|
||||
require.False(t, ok)
|
||||
|
||||
assert.Equal(t, tc.want, res.DNSRewrites()[0].DNSRewrite.Value)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("cname", func(t *testing.T) {
|
||||
res, ok := hc.MatchRequest(urlfilter.DNSRequest{
|
||||
Hostname: hostname4,
|
||||
DNSType: dns.TypeCNAME,
|
||||
})
|
||||
require.False(t, ok)
|
||||
|
||||
assert.Empty(t, res)
|
||||
})
|
||||
}
|
||||
|
||||
func TestHostsContainer_PathsToPatterns(t *testing.T) {
|
||||
const (
|
||||
dir0 = "dir"
|
||||
dir1 = "dir_1"
|
||||
fn1 = "file_1"
|
||||
fn2 = "file_2"
|
||||
fn3 = "file_3"
|
||||
fn4 = "file_4"
|
||||
)
|
||||
|
||||
fp1 := path.Join(dir0, fn1)
|
||||
fp2 := path.Join(dir0, fn2)
|
||||
fp3 := path.Join(dir0, dir1, fn3)
|
||||
|
||||
gsfs := fstest.MapFS{
|
||||
fp1: &fstest.MapFile{Data: []byte{1}},
|
||||
fp2: &fstest.MapFile{Data: []byte{2}},
|
||||
fp3: &fstest.MapFile{Data: []byte{3}},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantErr error
|
||||
want []string
|
||||
paths []string
|
||||
}{{
|
||||
name: "no_paths",
|
||||
wantErr: nil,
|
||||
want: nil,
|
||||
paths: nil,
|
||||
}, {
|
||||
name: "single_file",
|
||||
wantErr: nil,
|
||||
want: []string{fp1},
|
||||
paths: []string{fp1},
|
||||
}, {
|
||||
name: "several_files",
|
||||
wantErr: nil,
|
||||
want: []string{fp1, fp2},
|
||||
paths: []string{fp1, fp2},
|
||||
}, {
|
||||
name: "whole_dir",
|
||||
wantErr: nil,
|
||||
want: []string{path.Join(dir0, "*")},
|
||||
paths: []string{dir0},
|
||||
}, {
|
||||
name: "file_and_dir",
|
||||
wantErr: nil,
|
||||
want: []string{fp1, path.Join(dir0, dir1, "*")},
|
||||
paths: []string{fp1, path.Join(dir0, dir1)},
|
||||
}, {
|
||||
name: "non-existing",
|
||||
wantErr: fs.ErrNotExist,
|
||||
want: nil,
|
||||
paths: []string{path.Join(dir0, "file_3")},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
patterns, err := pathsToPatterns(gsfs, tc.paths)
|
||||
if tc.wantErr != nil {
|
||||
assert.ErrorIs(t, err, tc.wantErr)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.want, patterns)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUniqueRules_AddPair(t *testing.T) {
|
||||
knownIP := net.IP{1, 2, 3, 4}
|
||||
|
||||
const knownHost = "host1"
|
||||
|
||||
ipToHost := netutil.NewIPMap(0)
|
||||
ipToHost.Set(knownIP, []string{knownHost})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
wantRules string
|
||||
ip net.IP
|
||||
}{{
|
||||
name: "new_one",
|
||||
host: "host2",
|
||||
wantRules: "||host2^$dnsrewrite=NOERROR;A;1.2.3.4\n" +
|
||||
"||4.3.2.1.in-addr.arpa^$dnsrewrite=NOERROR;PTR;host2.\n",
|
||||
ip: knownIP,
|
||||
}, {
|
||||
name: "existing_one",
|
||||
host: knownHost,
|
||||
wantRules: "",
|
||||
ip: knownIP,
|
||||
}, {
|
||||
name: "new_ip",
|
||||
host: knownHost,
|
||||
wantRules: "||" + knownHost + "^$dnsrewrite=NOERROR;A;1.2.3.5\n" +
|
||||
"||5.3.2.1.in-addr.arpa^$dnsrewrite=NOERROR;PTR;" + knownHost + ".\n",
|
||||
ip: net.IP{1, 2, 3, 5},
|
||||
}, {
|
||||
name: "bad_ip",
|
||||
host: knownHost,
|
||||
wantRules: "",
|
||||
ip: net.IP{1, 2, 3, 4, 5},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
hp := hostsParser{
|
||||
rules: &strings.Builder{},
|
||||
table: ipToHost.ShallowClone(),
|
||||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
hp.addPair(tc.ip, tc.host)
|
||||
assert.Equal(t, tc.wantRules, hp.rules.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUniqueRules_ParseLine(t *testing.T) {
|
||||
const (
|
||||
hostname = "localhost"
|
||||
alias = "hocallost"
|
||||
)
|
||||
|
||||
knownIP := net.IP{127, 0, 0, 1}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
line string
|
||||
wantIP net.IP
|
||||
wantHosts []string
|
||||
}{{
|
||||
name: "simple",
|
||||
line: strings.Join([]string{knownIP.String(), hostname}, sp),
|
||||
wantIP: knownIP,
|
||||
wantHosts: []string{"localhost"},
|
||||
}, {
|
||||
name: "aliases",
|
||||
line: strings.Join([]string{knownIP.String(), hostname, alias}, sp),
|
||||
wantIP: knownIP,
|
||||
wantHosts: []string{"localhost", "hocallost"},
|
||||
}, {
|
||||
name: "invalid_line",
|
||||
line: knownIP.String(),
|
||||
wantIP: nil,
|
||||
wantHosts: nil,
|
||||
}, {
|
||||
name: "invalid_line_hostname",
|
||||
line: strings.Join([]string{knownIP.String(), "#" + hostname}, sp),
|
||||
wantIP: knownIP,
|
||||
wantHosts: nil,
|
||||
}, {
|
||||
name: "commented_aliases",
|
||||
line: strings.Join([]string{knownIP.String(), hostname, "#" + alias}, sp),
|
||||
wantIP: knownIP,
|
||||
wantHosts: []string{"localhost"},
|
||||
}, {
|
||||
name: "whole_comment",
|
||||
line: strings.Join([]string{"#", knownIP.String(), hostname}, sp),
|
||||
wantIP: nil,
|
||||
wantHosts: nil,
|
||||
}, {
|
||||
name: "partial_comment",
|
||||
line: strings.Join([]string{knownIP.String(), hostname[:4] + "#" + hostname[4:]}, sp),
|
||||
wantIP: knownIP,
|
||||
wantHosts: []string{hostname[:4]},
|
||||
}, {
|
||||
name: "empty",
|
||||
line: ``,
|
||||
wantIP: nil,
|
||||
wantHosts: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
hp := hostsParser{}
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ip, hosts := hp.parseLine(tc.line)
|
||||
assert.True(t, tc.wantIP.Equal(ip))
|
||||
assert.Equal(t, tc.wantHosts, hosts)
|
||||
})
|
||||
}
|
||||
}
|
33
internal/aghnet/hostscontainer_windows.go
Normal file
33
internal/aghnet/hostscontainer_windows.go
Normal file
@ -0,0 +1,33 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func defaultHostsPaths() (paths []string) {
|
||||
sysDir, err := windows.GetSystemDirectory()
|
||||
if err != nil {
|
||||
log.Error("getting system directory: %s", err)
|
||||
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Split all the elements of the path to join them afterwards. This is
|
||||
// needed to make the Windows-specific path string returned by
|
||||
// windows.GetSystemDirectory to be compatible with fs.FS.
|
||||
pathElems := strings.Split(sysDir, string(os.PathSeparator))
|
||||
if len(pathElems) > 0 && pathElems[0] == filepath.VolumeName(sysDir) {
|
||||
pathElems = pathElems[1:]
|
||||
}
|
||||
|
||||
return []string{path.Join(append(pathElems, "drivers/etc/hosts")...)}
|
||||
}
|
@ -18,9 +18,11 @@ func canBindPrivilegedPorts() (can bool, err error) {
|
||||
}
|
||||
|
||||
func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
|
||||
const filename = "/etc/rc.conf"
|
||||
const rcConfFilename = "etc/rc.conf"
|
||||
|
||||
return aghos.FileWalker(interfaceName(ifaceName).rcConfStaticConfig).Walk(filename)
|
||||
walker := aghos.FileWalker(interfaceName(ifaceName).rcConfStaticConfig)
|
||||
|
||||
return walker.Walk(aghos.RootDirFS(), rcConfFilename)
|
||||
}
|
||||
|
||||
// rcConfStaticConfig checks if the interface is configured by /etc/rc.conf to
|
||||
|
@ -85,17 +85,17 @@ func ifaceHasStaticIP(ifaceName string) (has bool, err error) {
|
||||
|
||||
iface := interfaceName(ifaceName)
|
||||
|
||||
for _, pair := range []struct {
|
||||
for _, pair := range [...]struct {
|
||||
aghos.FileWalker
|
||||
filename string
|
||||
}{{
|
||||
FileWalker: iface.dhcpcdStaticConfig,
|
||||
filename: "/etc/dhcpcd.conf",
|
||||
filename: "etc/dhcpcd.conf",
|
||||
}, {
|
||||
FileWalker: iface.ifacesStaticConfig,
|
||||
filename: "/etc/network/interfaces",
|
||||
filename: "etc/network/interfaces",
|
||||
}} {
|
||||
has, err = pair.Walk(pair.filename)
|
||||
has, err = pair.Walk(aghos.RootDirFS(), pair.filename)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
@ -12,8 +12,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const nl = "\n"
|
||||
|
||||
func TestDHCPCDStaticConfig(t *testing.T) {
|
||||
const iface interfaceName = `wlan0`
|
||||
|
||||
|
@ -18,9 +18,9 @@ func canBindPrivilegedPorts() (can bool, err error) {
|
||||
}
|
||||
|
||||
func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
|
||||
filename := fmt.Sprintf("/etc/hostname.%s", ifaceName)
|
||||
filename := fmt.Sprintf("etc/hostname.%s", ifaceName)
|
||||
|
||||
return aghos.FileWalker(hostnameIfStaticConfig).Walk(filename)
|
||||
return aghos.FileWalker(hostnameIfStaticConfig).Walk(aghos.RootDirFS(), filename)
|
||||
}
|
||||
|
||||
// hostnameIfStaticConfig checks if the interface is configured by
|
||||
|
@ -4,10 +4,15 @@ import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
aghtest.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
func TestGetValidNetInterfacesForWeb(t *testing.T) {
|
||||
ifaces, err := GetValidNetInterfacesForWeb()
|
||||
require.NoErrorf(t, err, "cannot get net interfaces: %s", err)
|
||||
|
@ -3,8 +3,7 @@ package aghos
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"io/fs"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
@ -14,10 +13,10 @@ import (
|
||||
// FileWalker is the signature of a function called for files in the file tree.
|
||||
// As opposed to filepath.Walk it only walk the files (not directories) matching
|
||||
// the provided pattern and those returned by function itself. All patterns
|
||||
// should be valid for filepath.Glob. If cont is false, the walking terminates.
|
||||
// Each opened file is also limited for reading to MaxWalkedFileSize.
|
||||
// should be valid for fs.Glob. If cont is false, the walking terminates. Each
|
||||
// opened file is also limited for reading to MaxWalkedFileSize.
|
||||
//
|
||||
// TODO(e.burkov): Consider moving to the separate package like pathutil.
|
||||
// TODO(e.burkov, a.garipov): Move into another package like aghfs.
|
||||
//
|
||||
// TODO(e.burkov): Think about passing filename or any additional data.
|
||||
type FileWalker func(r io.Reader) (patterns []string, cont bool, err error)
|
||||
@ -26,15 +25,19 @@ type FileWalker func(r io.Reader) (patterns []string, cont bool, err error)
|
||||
// check.
|
||||
const MaxWalkedFileSize = 1024 * 1024
|
||||
|
||||
// checkFile tries to open and process a single file located on sourcePath.
|
||||
func checkFile(c FileWalker, sourcePath string) (patterns []string, cont bool, err error) {
|
||||
var f *os.File
|
||||
f, err = os.Open(sourcePath)
|
||||
// checkFile tries to open and process a single file located on sourcePath in
|
||||
// the specified fsys. The path is skipped if it's a directory.
|
||||
func checkFile(
|
||||
fsys fs.FS,
|
||||
c FileWalker,
|
||||
sourcePath string,
|
||||
) (patterns []string, cont bool, err error) {
|
||||
var f fs.File
|
||||
f, err = fsys.Open(sourcePath)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
// Ignore non-existing files since this may only happen
|
||||
// when the file was removed after filepath.Glob matched
|
||||
// it.
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
// Ignore non-existing files since this may only happen when the
|
||||
// file was removed after filepath.Glob matched it.
|
||||
return nil, true, nil
|
||||
}
|
||||
|
||||
@ -42,9 +45,18 @@ func checkFile(c FileWalker, sourcePath string) (patterns []string, cont bool, e
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, f.Close()) }()
|
||||
|
||||
var fi fs.FileInfo
|
||||
if fi, err = f.Stat(); err != nil {
|
||||
return nil, true, err
|
||||
}
|
||||
if fi.IsDir() {
|
||||
// Skip the directories.
|
||||
return nil, true, nil
|
||||
}
|
||||
|
||||
var r io.Reader
|
||||
// Ignore the error since LimitReader function returns error only if
|
||||
// passed limit value is less than zero, but the constant used.
|
||||
// Ignore the error since LimitReader function returns error only if passed
|
||||
// limit value is less than zero, but the constant used.
|
||||
//
|
||||
// TODO(e.burkov): Make variable.
|
||||
r, _ = aghio.LimitReader(f, MaxWalkedFileSize)
|
||||
@ -52,13 +64,17 @@ func checkFile(c FileWalker, sourcePath string) (patterns []string, cont bool, e
|
||||
return c(r)
|
||||
}
|
||||
|
||||
// handlePatterns parses the patterns and ignores duplicates using srcSet.
|
||||
// srcSet must be non-nil.
|
||||
func handlePatterns(srcSet *stringutil.Set, patterns ...string) (sub []string, err error) {
|
||||
// handlePatterns parses the patterns in fsys and ignores duplicates using
|
||||
// srcSet. srcSet must be non-nil.
|
||||
func handlePatterns(
|
||||
fsys fs.FS,
|
||||
srcSet *stringutil.Set,
|
||||
patterns ...string,
|
||||
) (sub []string, err error) {
|
||||
sub = make([]string, 0, len(patterns))
|
||||
for _, p := range patterns {
|
||||
var matches []string
|
||||
matches, err = filepath.Glob(p)
|
||||
matches, err = fs.Glob(fsys, p)
|
||||
if err != nil {
|
||||
// Enrich error with the pattern because filepath.Glob
|
||||
// doesn't do it.
|
||||
@ -78,14 +94,14 @@ func handlePatterns(srcSet *stringutil.Set, patterns ...string) (sub []string, e
|
||||
return sub, nil
|
||||
}
|
||||
|
||||
// Walk starts walking the files defined by initPattern. It only returns true
|
||||
// if c signed to stop walking.
|
||||
func (c FileWalker) Walk(initPattern string) (ok bool, err error) {
|
||||
// The slice of sources keeps the order in which the files are walked
|
||||
// since srcSet.Values() returns strings in undefined order.
|
||||
// Walk starts walking the files in fsys defined by patterns from initial.
|
||||
// It only returns true if fw signed to stop walking.
|
||||
func (fw FileWalker) Walk(fsys fs.FS, initial ...string) (ok bool, err error) {
|
||||
// The slice of sources keeps the order in which the files are walked since
|
||||
// srcSet.Values() returns strings in undefined order.
|
||||
srcSet := stringutil.NewSet()
|
||||
var src []string
|
||||
src, err = handlePatterns(srcSet, initPattern)
|
||||
src, err = handlePatterns(fsys, srcSet, initial...)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@ -97,7 +113,7 @@ func (c FileWalker) Walk(initPattern string) (ok bool, err error) {
|
||||
var patterns []string
|
||||
var cont bool
|
||||
filename = src[i]
|
||||
patterns, cont, err = checkFile(c, src[i])
|
||||
patterns, cont, err = checkFile(fsys, fw, src[i])
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@ -107,7 +123,7 @@ func (c FileWalker) Walk(initPattern string) (ok bool, err error) {
|
||||
}
|
||||
|
||||
var subsrc []string
|
||||
subsrc, err = handlePatterns(srcSet, patterns...)
|
||||
subsrc, err = handlePatterns(fsys, srcSet, patterns...)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
@ -4,56 +4,19 @@ import (
|
||||
"bufio"
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"path"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testFSDir maps entries' names to entries which should either be a testFSDir
|
||||
// or byte slice.
|
||||
type testFSDir map[string]interface{}
|
||||
|
||||
// testFSGen is used to generate a temporary filesystem consisting of
|
||||
// directories and plain text files from itself.
|
||||
type testFSGen testFSDir
|
||||
|
||||
// gen returns the name of top directory of the generated filesystem.
|
||||
func (g testFSGen) gen(t *testing.T) (dirName string) {
|
||||
t.Helper()
|
||||
|
||||
dirName = t.TempDir()
|
||||
g.rangeThrough(t, dirName)
|
||||
|
||||
return dirName
|
||||
}
|
||||
|
||||
func (g testFSGen) rangeThrough(t *testing.T, dirName string) {
|
||||
const perm fs.FileMode = 0o777
|
||||
|
||||
for k, e := range g {
|
||||
switch e := e.(type) {
|
||||
case []byte:
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dirName, k), e, perm))
|
||||
|
||||
case testFSDir:
|
||||
newDir := filepath.Join(dirName, k)
|
||||
require.NoError(t, os.Mkdir(newDir, perm))
|
||||
|
||||
testFSGen(e).rangeThrough(t, newDir)
|
||||
default:
|
||||
t.Fatalf("unexpected entry type %T", e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileWalker_Walk(t *testing.T) {
|
||||
const attribute = `000`
|
||||
|
||||
makeFileWalker := func(dirName string) (fw FileWalker) {
|
||||
makeFileWalker := func(_ string) (fw FileWalker) {
|
||||
return func(r io.Reader) (patterns []string, cont bool, err error) {
|
||||
s := bufio.NewScanner(r)
|
||||
for s.Scan() {
|
||||
@ -63,7 +26,7 @@ func TestFileWalker_Walk(t *testing.T) {
|
||||
}
|
||||
|
||||
if len(line) != 0 {
|
||||
patterns = append(patterns, filepath.Join(dirName, line))
|
||||
patterns = append(patterns, path.Join(".", line))
|
||||
}
|
||||
}
|
||||
|
||||
@ -74,136 +37,150 @@ func TestFileWalker_Walk(t *testing.T) {
|
||||
const nl = "\n"
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
testFS testFSGen
|
||||
testFS fstest.MapFS
|
||||
want assert.BoolAssertionFunc
|
||||
initPattern string
|
||||
want bool
|
||||
name string
|
||||
}{{
|
||||
name: "simple",
|
||||
testFS: testFSGen{
|
||||
"simple_0001.txt": []byte(attribute + nl),
|
||||
testFS: fstest.MapFS{
|
||||
"simple_0001.txt": &fstest.MapFile{Data: []byte(attribute + nl)},
|
||||
},
|
||||
initPattern: "simple_0001.txt",
|
||||
want: true,
|
||||
want: assert.True,
|
||||
}, {
|
||||
name: "chain",
|
||||
testFS: testFSGen{
|
||||
"chain_0001.txt": []byte(`chain_0002.txt` + nl),
|
||||
"chain_0002.txt": []byte(`chain_0003.txt` + nl),
|
||||
"chain_0003.txt": []byte(attribute + nl),
|
||||
testFS: fstest.MapFS{
|
||||
"chain_0001.txt": &fstest.MapFile{Data: []byte(`chain_0002.txt` + nl)},
|
||||
"chain_0002.txt": &fstest.MapFile{Data: []byte(`chain_0003.txt` + nl)},
|
||||
"chain_0003.txt": &fstest.MapFile{Data: []byte(attribute + nl)},
|
||||
},
|
||||
initPattern: "chain_0001.txt",
|
||||
want: true,
|
||||
want: assert.True,
|
||||
}, {
|
||||
name: "several",
|
||||
testFS: testFSGen{
|
||||
"several_0001.txt": []byte(`several_*` + nl),
|
||||
"several_0002.txt": []byte(`several_0001.txt` + nl),
|
||||
"several_0003.txt": []byte(attribute + nl),
|
||||
testFS: fstest.MapFS{
|
||||
"several_0001.txt": &fstest.MapFile{Data: []byte(`several_*` + nl)},
|
||||
"several_0002.txt": &fstest.MapFile{Data: []byte(`several_0001.txt` + nl)},
|
||||
"several_0003.txt": &fstest.MapFile{Data: []byte(attribute + nl)},
|
||||
},
|
||||
initPattern: "several_0001.txt",
|
||||
want: true,
|
||||
want: assert.True,
|
||||
}, {
|
||||
name: "no",
|
||||
testFS: testFSGen{
|
||||
"no_0001.txt": []byte(nl),
|
||||
"no_0002.txt": []byte(nl),
|
||||
"no_0003.txt": []byte(nl),
|
||||
testFS: fstest.MapFS{
|
||||
"no_0001.txt": &fstest.MapFile{Data: []byte(nl)},
|
||||
"no_0002.txt": &fstest.MapFile{Data: []byte(nl)},
|
||||
"no_0003.txt": &fstest.MapFile{Data: []byte(nl)},
|
||||
},
|
||||
initPattern: "no_*",
|
||||
want: false,
|
||||
want: assert.False,
|
||||
}, {
|
||||
name: "subdirectory",
|
||||
testFS: testFSGen{
|
||||
"dir": testFSDir{
|
||||
"subdir_0002.txt": []byte(attribute + nl),
|
||||
testFS: fstest.MapFS{
|
||||
path.Join("dir", "subdir_0002.txt"): &fstest.MapFile{
|
||||
Data: []byte(attribute + nl),
|
||||
},
|
||||
"subdir_0001.txt": []byte(`dir/*`),
|
||||
"subdir_0001.txt": &fstest.MapFile{Data: []byte(`dir/*`)},
|
||||
},
|
||||
initPattern: "subdir_0001.txt",
|
||||
want: true,
|
||||
want: assert.True,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
testDir := tc.testFS.gen(t)
|
||||
fw := makeFileWalker(testDir)
|
||||
fw := makeFileWalker("")
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ok, err := fw.Walk(filepath.Join(testDir, tc.initPattern))
|
||||
ok, err := fw.Walk(tc.testFS, tc.initPattern)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.want, ok)
|
||||
tc.want(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("pattern_malformed", func(t *testing.T) {
|
||||
ok, err := makeFileWalker("").Walk("[]")
|
||||
f := fstest.MapFS{}
|
||||
ok, err := makeFileWalker("").Walk(f, "[]")
|
||||
require.Error(t, err)
|
||||
|
||||
assert.False(t, ok)
|
||||
assert.ErrorIs(t, err, filepath.ErrBadPattern)
|
||||
assert.ErrorIs(t, err, path.ErrBadPattern)
|
||||
})
|
||||
|
||||
t.Run("bad_filename", func(t *testing.T) {
|
||||
dir := testFSGen{
|
||||
"bad_filename.txt": []byte("[]"),
|
||||
}.gen(t)
|
||||
fw := FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) {
|
||||
const filename = "bad_filename.txt"
|
||||
|
||||
f := fstest.MapFS{
|
||||
filename: &fstest.MapFile{Data: []byte("[]")},
|
||||
}
|
||||
ok, err := FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) {
|
||||
s := bufio.NewScanner(r)
|
||||
for s.Scan() {
|
||||
patterns = append(patterns, s.Text())
|
||||
}
|
||||
|
||||
return patterns, true, s.Err()
|
||||
})
|
||||
|
||||
ok, err := fw.Walk(filepath.Join(dir, "bad_filename.txt"))
|
||||
}).Walk(f, filename)
|
||||
require.Error(t, err)
|
||||
|
||||
assert.False(t, ok)
|
||||
assert.ErrorIs(t, err, filepath.ErrBadPattern)
|
||||
assert.ErrorIs(t, err, path.ErrBadPattern)
|
||||
})
|
||||
|
||||
t.Run("itself_error", func(t *testing.T) {
|
||||
const rerr errors.Error = "returned error"
|
||||
|
||||
dir := testFSGen{
|
||||
"mockfile.txt": []byte(`mockdata`),
|
||||
}.gen(t)
|
||||
f := fstest.MapFS{
|
||||
"mockfile.txt": &fstest.MapFile{Data: []byte(`mockdata`)},
|
||||
}
|
||||
|
||||
ok, err := FileWalker(func(r io.Reader) (patterns []string, ok bool, err error) {
|
||||
return nil, true, rerr
|
||||
}).Walk(filepath.Join(dir, "*"))
|
||||
require.Error(t, err)
|
||||
require.False(t, ok)
|
||||
}).Walk(f, "*")
|
||||
require.ErrorIs(t, err, rerr)
|
||||
|
||||
assert.ErrorIs(t, err, rerr)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
type errFS struct {
|
||||
fs.GlobFS
|
||||
}
|
||||
|
||||
const errErrFSOpen errors.Error = "this error is always returned"
|
||||
|
||||
func (efs *errFS) Open(name string) (fs.File, error) {
|
||||
return nil, errErrFSOpen
|
||||
}
|
||||
|
||||
func TestWalkerFunc_CheckFile(t *testing.T) {
|
||||
emptyFS := fstest.MapFS{}
|
||||
|
||||
t.Run("non-existing", func(t *testing.T) {
|
||||
_, ok, err := checkFile(nil, "lol")
|
||||
_, ok, err := checkFile(emptyFS, nil, "lol")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("invalid_argument", func(t *testing.T) {
|
||||
const badPath = "\x00"
|
||||
|
||||
_, ok, err := checkFile(nil, badPath)
|
||||
require.Error(t, err)
|
||||
_, ok, err := checkFile(&errFS{}, nil, "")
|
||||
require.ErrorIs(t, err, errErrFSOpen)
|
||||
|
||||
assert.False(t, ok)
|
||||
// TODO(e.burkov): Use assert.ErrorsIs within the error from
|
||||
// less platform-dependent package instead of syscall.EINVAL.
|
||||
//
|
||||
// See https://github.com/golang/go/issues/46849 and
|
||||
// https://github.com/golang/go/issues/30322.
|
||||
pathErr := &os.PathError{}
|
||||
require.ErrorAs(t, err, &pathErr)
|
||||
assert.Equal(t, "open", pathErr.Op)
|
||||
assert.Equal(t, badPath, pathErr.Path)
|
||||
})
|
||||
|
||||
t.Run("ignore_dirs", func(t *testing.T) {
|
||||
const dirName = "dir"
|
||||
|
||||
testFS := fstest.MapFS{
|
||||
path.Join(dirName, "file"): &fstest.MapFile{Data: []byte{}},
|
||||
}
|
||||
|
||||
patterns, ok, err := checkFile(testFS, nil, dirName)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Empty(t, patterns)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
}
|
||||
|
131
internal/aghos/fswatcher.go
Normal file
131
internal/aghos/fswatcher.go
Normal file
@ -0,0 +1,131 @@
|
||||
package aghos
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
// event is a convenient alias for an empty struct to signal that watching
|
||||
// event happened.
|
||||
type event = struct{}
|
||||
|
||||
// FSWatcher tracks all the fyle system events and notifies about those.
|
||||
//
|
||||
// TODO(e.burkov, a.garipov): Move into another package like aghfs.
|
||||
type FSWatcher interface {
|
||||
io.Closer
|
||||
|
||||
// Events should return a read-only channel which notifies about events.
|
||||
Events() (e <-chan event)
|
||||
|
||||
// Add should check if the file named name is accessible and starts tracking
|
||||
// it.
|
||||
Add(name string) (err error)
|
||||
}
|
||||
|
||||
// osWatcher tracks the file system provided by the OS.
|
||||
type osWatcher struct {
|
||||
// w is the actual notifier that is handled by osWatcher.
|
||||
w *fsnotify.Watcher
|
||||
|
||||
// events is the channel to notify.
|
||||
events chan event
|
||||
}
|
||||
|
||||
const (
|
||||
// osWatcherPref is a prefix for logging and wrapping errors in osWathcer's
|
||||
// methods.
|
||||
osWatcherPref = "os watcher"
|
||||
)
|
||||
|
||||
// NewOSWritesWatcher creates FSWatcher that tracks the real file system of the
|
||||
// OS and notifies only about writing events.
|
||||
func NewOSWritesWatcher() (w FSWatcher, err error) {
|
||||
defer func() { err = errors.Annotate(err, "%s: %w", osWatcherPref) }()
|
||||
|
||||
var watcher *fsnotify.Watcher
|
||||
watcher, err = fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating watcher: %w", err)
|
||||
}
|
||||
|
||||
fsw := &osWatcher{
|
||||
w: watcher,
|
||||
events: make(chan event, 1),
|
||||
}
|
||||
|
||||
go fsw.handleErrors()
|
||||
go fsw.handleEvents()
|
||||
|
||||
return fsw, nil
|
||||
}
|
||||
|
||||
// handleErrors handles accompanying errors. It used to be called in a separate
|
||||
// goroutine.
|
||||
func (w *osWatcher) handleErrors() {
|
||||
defer log.OnPanic(fmt.Sprintf("%s: handling errors", osWatcherPref))
|
||||
|
||||
for err := range w.w.Errors {
|
||||
log.Error("%s: %s", osWatcherPref, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Events implements the FSWatcher interface for *osWatcher.
|
||||
func (w *osWatcher) Events() (e <-chan event) {
|
||||
return w.events
|
||||
}
|
||||
|
||||
// Add implements the FSWatcher interface for *osWatcher.
|
||||
func (w *osWatcher) Add(name string) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "%s: %w", osWatcherPref) }()
|
||||
|
||||
if _, err = fs.Stat(RootDirFS(), name); err != nil {
|
||||
return fmt.Errorf("checking file %q: %w", name, err)
|
||||
}
|
||||
|
||||
return w.w.Add(filepath.Join("/", name))
|
||||
}
|
||||
|
||||
// Close implements the FSWatcher interface for *osWatcher.
|
||||
func (w *osWatcher) Close() (err error) {
|
||||
return w.w.Close()
|
||||
}
|
||||
|
||||
// handleEvents notifies about the received file system's event if needed. It
|
||||
// used to be called in a separate goroutine.
|
||||
func (w *osWatcher) handleEvents() {
|
||||
defer log.OnPanic(fmt.Sprintf("%s: handling events", osWatcherPref))
|
||||
|
||||
defer close(w.events)
|
||||
|
||||
ch := w.w.Events
|
||||
for e := range ch {
|
||||
if e.Op&fsnotify.Write == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip the following events assuming that sometimes the same event
|
||||
// occurrs several times.
|
||||
for ok := true; ok; {
|
||||
select {
|
||||
case _, ok = <-ch:
|
||||
// Go on.
|
||||
default:
|
||||
ok = false
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case w.events <- event{}:
|
||||
// Go on.
|
||||
default:
|
||||
log.Debug("%s: events buffer is full", osWatcherPref)
|
||||
}
|
||||
}
|
||||
}
|
@ -7,6 +7,8 @@ import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"runtime"
|
||||
@ -159,3 +161,10 @@ ScanLoop:
|
||||
func IsOpenWrt() (ok bool) {
|
||||
return isOpenWrt()
|
||||
}
|
||||
|
||||
// RootDirFS returns the fs.FS rooted at the operating system's root.
|
||||
func RootDirFS() (fsys fs.FS) {
|
||||
// Use empty string since os.DirFS implicitly prepends a slash to it. This
|
||||
// behavior is undocumented but it currently works.
|
||||
return os.DirFS("")
|
||||
}
|
||||
|
@ -26,6 +26,8 @@ func haveAdminRights() (bool, error) {
|
||||
}
|
||||
|
||||
func isOpenWrt() (ok bool) {
|
||||
const etcReleasePattern = "etc/*release*"
|
||||
|
||||
var err error
|
||||
ok, err = FileWalker(func(r io.Reader) (_ []string, cont bool, err error) {
|
||||
const osNameData = "openwrt"
|
||||
@ -39,7 +41,7 @@ func isOpenWrt() (ok bool) {
|
||||
}
|
||||
|
||||
return nil, !stringutil.ContainsFold(string(data), osNameData), nil
|
||||
}).Walk("/etc/*release*")
|
||||
}).Walk(RootDirFS(), etcReleasePattern)
|
||||
|
||||
return err == nil && ok
|
||||
}
|
||||
|
23
internal/aghtest/fswatcher.go
Normal file
23
internal/aghtest/fswatcher.go
Normal file
@ -0,0 +1,23 @@
|
||||
package aghtest
|
||||
|
||||
// FSWatcher is a mock aghos.FSWatcher implementation to use in tests.
|
||||
type FSWatcher struct {
|
||||
OnEvents func() (e <-chan struct{})
|
||||
OnAdd func(name string) (err error)
|
||||
OnClose func() (err error)
|
||||
}
|
||||
|
||||
// Events implements the aghos.FSWatcher interface for *FSWatcher.
|
||||
func (w *FSWatcher) Events() (e <-chan struct{}) {
|
||||
return w.OnEvents()
|
||||
}
|
||||
|
||||
// Add implements the aghos.FSWatcher interface for *FSWatcher.
|
||||
func (w *FSWatcher) Add(name string) (err error) {
|
||||
return w.OnAdd(name)
|
||||
}
|
||||
|
||||
// Close implements the aghos.FSWatcher interface for *FSWatcher.
|
||||
func (w *FSWatcher) Close() (err error) {
|
||||
return w.OnClose()
|
||||
}
|
@ -11,9 +11,10 @@ import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
@ -1057,23 +1058,43 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPTRResponseFromHosts(t *testing.T) {
|
||||
c := filtering.Config{
|
||||
EtcHosts: &aghnet.EtcHostsContainer{},
|
||||
// Prepare test hosts file.
|
||||
|
||||
const hostsFilename = "hosts"
|
||||
|
||||
testFS := fstest.MapFS{
|
||||
hostsFilename: &fstest.MapFile{Data: []byte(`
|
||||
127.0.0.1 host # comment
|
||||
::1 localhost#comment
|
||||
`)},
|
||||
}
|
||||
|
||||
// Prepare test hosts file.
|
||||
hf, err := os.CreateTemp("", "")
|
||||
const closeCalled errors.Error = "close method called"
|
||||
|
||||
var eventsCalledCounter uint32
|
||||
hc, err := aghnet.NewHostsContainer(testFS, &aghtest.FSWatcher{
|
||||
OnEvents: func() (e <-chan struct{}) {
|
||||
assert.Equal(t, uint32(1), atomic.AddUint32(&eventsCalledCounter, 1))
|
||||
|
||||
return nil
|
||||
},
|
||||
OnAdd: func(name string) (err error) {
|
||||
assert.Equal(t, hostsFilename, name)
|
||||
|
||||
return nil
|
||||
},
|
||||
OnClose: func() (err error) { return closeCalled },
|
||||
}, hostsFilename)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, hf.Close())
|
||||
assert.NoError(t, os.Remove(hf.Name()))
|
||||
assert.Equal(t, uint32(1), atomic.LoadUint32(&eventsCalledCounter))
|
||||
|
||||
require.ErrorIs(t, hc.Close(), closeCalled)
|
||||
})
|
||||
|
||||
_, _ = hf.WriteString(" 127.0.0.1 host # comment \n")
|
||||
_, _ = hf.WriteString(" ::1 localhost#comment \n")
|
||||
|
||||
c.EtcHosts.Init(hf.Name())
|
||||
t.Cleanup(c.EtcHosts.Close)
|
||||
c := filtering.Config{
|
||||
EtcHosts: hc,
|
||||
}
|
||||
|
||||
var snd *aghnet.SubnetDetector
|
||||
snd, err = aghnet.NewSubnetDetector()
|
||||
@ -1109,7 +1130,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
||||
resp, err := dns.Exchange(req, addr.String())
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, resp.Answer, 1)
|
||||
require.Lenf(t, resp.Answer, 1, "%#v", resp)
|
||||
|
||||
assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
|
||||
assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)
|
||||
|
@ -15,9 +15,9 @@ type DNSRewriteResult struct {
|
||||
// the server returns.
|
||||
type DNSRewriteResultResponse map[rules.RRType][]rules.RRValue
|
||||
|
||||
// processDNSRewrites processes DNS rewrite rules in dnsr. It returns
|
||||
// an empty result if dnsr is empty. Otherwise, the result will have
|
||||
// either CanonName or DNSRewriteResult set.
|
||||
// processDNSRewrites processes DNS rewrite rules in dnsr. It returns an empty
|
||||
// result if dnsr is empty. Otherwise, the result will have either CanonName or
|
||||
// DNSRewriteResult set.
|
||||
func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) {
|
||||
if len(dnsr) == 0 {
|
||||
return Result{}
|
||||
|
@ -73,7 +73,7 @@ type Config struct {
|
||||
|
||||
// EtcHosts is a container of IP-hostname pairs taken from the operating
|
||||
// system configuration files (e.g. /etc/hosts).
|
||||
EtcHosts *aghnet.EtcHostsContainer `yaml:"-"`
|
||||
EtcHosts *aghnet.HostsContainer `yaml:"-"`
|
||||
|
||||
// Called when the configuration is changed by HTTP request
|
||||
ConfigModified func() `yaml:"-"`
|
||||
@ -176,8 +176,8 @@ const (
|
||||
// FilteredBlockedService - the host is blocked by "blocked services" settings
|
||||
FilteredBlockedService
|
||||
|
||||
// Rewritten is returned when there was a rewrite by a legacy DNS
|
||||
// rewrite rule.
|
||||
// Rewritten is returned when there was a rewrite by a legacy DNS rewrite
|
||||
// rule.
|
||||
Rewritten
|
||||
|
||||
// RewrittenAutoHosts is returned when there was a rewrite by autohosts
|
||||
@ -186,8 +186,8 @@ const (
|
||||
|
||||
// RewrittenRule is returned when a $dnsrewrite filter rule was applied.
|
||||
//
|
||||
// TODO(a.garipov): Remove Rewritten and RewrittenAutoHosts by merging
|
||||
// their functionality into RewrittenRule.
|
||||
// TODO(a.garipov): Remove Rewritten and RewrittenAutoHosts by merging their
|
||||
// functionality into RewrittenRule.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2499.
|
||||
RewrittenRule
|
||||
@ -371,24 +371,23 @@ type Result struct {
|
||||
// Reason is the reason for blocking or unblocking the request.
|
||||
Reason Reason `json:",omitempty"`
|
||||
|
||||
// Rules are applied rules. If Rules are not empty, each rule
|
||||
// is not nil.
|
||||
// Rules are applied rules. If Rules are not empty, each rule is not nil.
|
||||
Rules []*ResultRule `json:",omitempty"`
|
||||
|
||||
// ReverseHosts is the reverse lookup rewrite result. It is
|
||||
// empty unless Reason is set to RewrittenAutoHosts.
|
||||
// ReverseHosts is the reverse lookup rewrite result. It is empty unless
|
||||
// Reason is set to RewrittenAutoHosts.
|
||||
ReverseHosts []string `json:",omitempty"`
|
||||
|
||||
// IPList is the lookup rewrite result. It is empty unless
|
||||
// Reason is set to RewrittenAutoHosts or Rewritten.
|
||||
// IPList is the lookup rewrite result. It is empty unless Reason is set to
|
||||
// RewrittenAutoHosts or Rewritten.
|
||||
IPList []net.IP `json:",omitempty"`
|
||||
|
||||
// CanonName is the CNAME value from the lookup rewrite result.
|
||||
// It is empty unless Reason is set to Rewritten or RewrittenRule.
|
||||
// CanonName is the CNAME value from the lookup rewrite result. It is empty
|
||||
// unless Reason is set to Rewritten or RewrittenRule.
|
||||
CanonName string `json:",omitempty"`
|
||||
|
||||
// ServiceName is the name of the blocked service. It is empty
|
||||
// unless Reason is set to FilteredBlockedService.
|
||||
// ServiceName is the name of the blocked service. It is empty unless
|
||||
// Reason is set to FilteredBlockedService.
|
||||
ServiceName string `json:",omitempty"`
|
||||
|
||||
// DNSRewriteResult is the $dnsrewrite filter rule result.
|
||||
@ -446,43 +445,49 @@ func (d *DNSFilter) CheckHost(
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
// checkEtcHosts compares the host against our /etc/hosts table. The err is
|
||||
// always nil, it is only there to make this a valid hostChecker function.
|
||||
func (d *DNSFilter) checkEtcHosts(
|
||||
host string,
|
||||
qtype uint16,
|
||||
_ *Settings,
|
||||
) (res Result, err error) {
|
||||
if d.Config.EtcHosts == nil {
|
||||
// matchSysHosts tries to match the host against the operating system's hosts
|
||||
// database.
|
||||
func (d *DNSFilter) matchSysHosts(host string, qtype uint16, setts *Settings) (res Result, err error) {
|
||||
if d.EtcHosts == nil {
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
ips := d.Config.EtcHosts.Process(host, qtype)
|
||||
if ips != nil {
|
||||
res = Result{
|
||||
Reason: RewrittenAutoHosts,
|
||||
IPList: ips,
|
||||
}
|
||||
dnsres, _ := d.EtcHosts.MatchRequest(urlfilter.DNSRequest{
|
||||
Hostname: host,
|
||||
SortedClientTags: setts.ClientTags,
|
||||
// TODO(e.burkov): Wait for urlfilter update to pass net.IP.
|
||||
ClientIP: setts.ClientIP.String(),
|
||||
ClientName: setts.ClientName,
|
||||
DNSType: qtype,
|
||||
})
|
||||
|
||||
return res, nil
|
||||
dnsr := dnsres.DNSRewrites()
|
||||
if len(dnsr) == 0 {
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
revHosts := d.Config.EtcHosts.ProcessReverse(host, qtype)
|
||||
if len(revHosts) != 0 {
|
||||
res = Result{
|
||||
Reason: RewrittenAutoHosts,
|
||||
var ips []net.IP
|
||||
var revHosts []string
|
||||
|
||||
for _, nr := range dnsr {
|
||||
dr := nr.DNSRewrite
|
||||
if dr == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Optimize this with a buffer.
|
||||
res.ReverseHosts = make([]string, len(revHosts))
|
||||
for i := range revHosts {
|
||||
res.ReverseHosts[i] = revHosts[i] + "."
|
||||
switch val := nr.DNSRewrite.Value.(type) {
|
||||
case net.IP:
|
||||
ips = append(ips, val)
|
||||
case string:
|
||||
revHosts = append(revHosts, val)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
return Result{}, nil
|
||||
return Result{
|
||||
Reason: RewrittenAutoHosts,
|
||||
IPList: ips,
|
||||
ReverseHosts: revHosts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Process rewrites table
|
||||
@ -647,15 +652,18 @@ func (d *DNSFilter) initFiltering(allowFilters, blockFilters []Filter) error {
|
||||
return err
|
||||
}
|
||||
|
||||
d.engineLock.Lock()
|
||||
d.reset()
|
||||
d.rulesStorage = rulesStorage
|
||||
d.filteringEngine = filteringEngine
|
||||
d.rulesStorageAllow = rulesStorageAllow
|
||||
d.filteringEngineAllow = filteringEngineAllow
|
||||
d.engineLock.Unlock()
|
||||
func() {
|
||||
d.engineLock.Lock()
|
||||
defer d.engineLock.Unlock()
|
||||
|
||||
// Make sure that the OS reclaims memory as soon as possible
|
||||
d.reset()
|
||||
d.rulesStorage = rulesStorage
|
||||
d.filteringEngine = filteringEngine
|
||||
d.rulesStorageAllow = rulesStorageAllow
|
||||
d.filteringEngineAllow = filteringEngineAllow
|
||||
}()
|
||||
|
||||
// Make sure that the OS reclaims memory as soon as possible.
|
||||
debug.FreeOSMemory()
|
||||
log.Debug("initialized filtering engine")
|
||||
|
||||
@ -734,8 +742,8 @@ func (d *DNSFilter) matchHostProcessDNSResult(
|
||||
}
|
||||
|
||||
if dnsres.HostRulesV4 != nil || dnsres.HostRulesV6 != nil {
|
||||
// Question type doesn't match the host rules. Return the first
|
||||
// matched host rule, but without an IP address.
|
||||
// Question type doesn't match the host rules. Return the first matched
|
||||
// host rule, but without an IP address.
|
||||
var matchedRules []rules.Rule
|
||||
if dnsres.HostRulesV4 != nil {
|
||||
matchedRules = []rules.Rule{dnsres.HostRulesV4[0]}
|
||||
@ -760,11 +768,6 @@ func (d *DNSFilter) matchHost(
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
d.engineLock.RLock()
|
||||
// Keep in mind that this lock must be held no just when calling Match()
|
||||
// but also while using the rules returned by it.
|
||||
defer d.engineLock.RUnlock()
|
||||
|
||||
ureq := urlfilter.DNSRequest{
|
||||
Hostname: host,
|
||||
SortedClientTags: setts.ClientTags,
|
||||
@ -774,6 +777,13 @@ func (d *DNSFilter) matchHost(
|
||||
DNSType: qtype,
|
||||
}
|
||||
|
||||
d.engineLock.RLock()
|
||||
// Keep in mind that this lock must be held no just when calling Match() but
|
||||
// also while using the rules returned by it.
|
||||
//
|
||||
// TODO(e.burkov): Inspect if the above is true.
|
||||
defer d.engineLock.RUnlock()
|
||||
|
||||
if d.filteringEngineAllow != nil {
|
||||
dnsres, ok := d.filteringEngineAllow.MatchRequest(ureq)
|
||||
if ok {
|
||||
@ -791,8 +801,8 @@ func (d *DNSFilter) matchHost(
|
||||
if dnsr := dnsres.DNSRewrites(); len(dnsr) > 0 {
|
||||
res = d.processDNSRewrites(dnsr)
|
||||
if res.Reason == RewrittenRule && res.CanonName == host {
|
||||
// A rewrite of a host to itself. Go on and try
|
||||
// matching other things.
|
||||
// A rewrite of a host to itself. Go on and try matching other
|
||||
// things.
|
||||
} else {
|
||||
return res, nil
|
||||
}
|
||||
@ -868,8 +878,8 @@ func New(c *Config, blockFilters []Filter) *DNSFilter {
|
||||
}
|
||||
|
||||
d.hostCheckers = []hostChecker{{
|
||||
check: d.checkEtcHosts,
|
||||
name: "etchosts",
|
||||
check: d.matchSysHosts,
|
||||
name: "hosts container",
|
||||
}, {
|
||||
check: d.matchHost,
|
||||
name: "filtering",
|
||||
|
@ -95,7 +95,9 @@ type clientsContainer struct {
|
||||
// dnsServer is used for checking clients IP status access list status
|
||||
dnsServer *dnsforward.Server
|
||||
|
||||
etcHosts *aghnet.EtcHostsContainer // get entries from system hosts-files
|
||||
// etcHosts contains list of rewrite rules taken from the operating system's
|
||||
// hosts databse.
|
||||
etcHosts *aghnet.HostsContainer
|
||||
|
||||
testing bool // if TRUE, this object is used for internal tests
|
||||
}
|
||||
@ -106,7 +108,7 @@ type clientsContainer struct {
|
||||
func (clients *clientsContainer) Init(
|
||||
objects []clientObject,
|
||||
dhcpServer *dhcpd.Server,
|
||||
etcHosts *aghnet.EtcHostsContainer,
|
||||
etcHosts *aghnet.HostsContainer,
|
||||
) {
|
||||
if clients.list != nil {
|
||||
log.Fatal("clients.list != nil")
|
||||
@ -121,13 +123,22 @@ func (clients *clientsContainer) Init(
|
||||
clients.etcHosts = etcHosts
|
||||
clients.addFromConfig(objects)
|
||||
|
||||
if !clients.testing {
|
||||
clients.updateFromDHCP(true)
|
||||
if clients.dhcpServer != nil {
|
||||
clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged)
|
||||
}
|
||||
if clients.etcHosts != nil {
|
||||
clients.etcHosts.SetOnChanged(clients.onHostsChanged)
|
||||
if clients.testing {
|
||||
return
|
||||
}
|
||||
|
||||
clients.updateFromDHCP(true)
|
||||
if clients.dhcpServer != nil {
|
||||
clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged)
|
||||
}
|
||||
|
||||
go clients.handleHostsUpdates()
|
||||
}
|
||||
|
||||
func (clients *clientsContainer) handleHostsUpdates() {
|
||||
if clients.etcHosts != nil {
|
||||
for upd := range clients.etcHosts.Upd() {
|
||||
clients.addFromHostsFile(upd)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -250,10 +261,6 @@ func (clients *clientsContainer) onDHCPLeaseChanged(flags int) {
|
||||
}
|
||||
}
|
||||
|
||||
func (clients *clientsContainer) onHostsChanged() {
|
||||
clients.addFromHostsFile()
|
||||
}
|
||||
|
||||
// Exists checks if client with this IP address already exists.
|
||||
func (clients *clientsContainer) Exists(ip net.IP, source clientSource) (ok bool) {
|
||||
clients.lock.Lock()
|
||||
@ -697,7 +704,7 @@ func (clients *clientsContainer) SetWHOISInfo(ip net.IP, wi *RuntimeClientWHOISI
|
||||
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
|
||||
}
|
||||
|
||||
// AddHost adds a new IP-hostname pairing. The priorities of the sources is
|
||||
// AddHost adds a new IP-hostname pairing. The priorities of the sources are
|
||||
// taken into account. ok is true if the pairing was added.
|
||||
func (clients *clientsContainer) AddHost(ip net.IP, host string, src clientSource) (ok bool, err error) {
|
||||
clients.lock.Lock()
|
||||
@ -757,13 +764,7 @@ func (clients *clientsContainer) rmHostsBySrc(src clientSource) {
|
||||
|
||||
// addFromHostsFile fills the client-hostname pairing index from the system's
|
||||
// hosts files.
|
||||
func (clients *clientsContainer) addFromHostsFile() {
|
||||
if clients.etcHosts == nil {
|
||||
return
|
||||
}
|
||||
|
||||
hosts := clients.etcHosts.List()
|
||||
|
||||
func (clients *clientsContainer) addFromHostsFile(hosts *netutil.IPMap) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
|
@ -710,7 +710,6 @@ func enableFilters(async bool) {
|
||||
}
|
||||
|
||||
func enableFiltersLocked(async bool) {
|
||||
var whiteFilters []filtering.Filter
|
||||
filters := []filtering.Filter{{
|
||||
Data: []byte(strings.Join(config.UserRules, "\n")),
|
||||
}}
|
||||
@ -725,18 +724,20 @@ func enableFiltersLocked(async bool) {
|
||||
FilePath: filter.Path(),
|
||||
})
|
||||
}
|
||||
|
||||
var allowFilters []filtering.Filter
|
||||
for _, filter := range config.WhitelistFilters {
|
||||
if !filter.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
whiteFilters = append(whiteFilters, filtering.Filter{
|
||||
allowFilters = append(allowFilters, filtering.Filter{
|
||||
ID: filter.ID,
|
||||
FilePath: filter.Path(),
|
||||
})
|
||||
}
|
||||
|
||||
if err := Context.dnsFilter.SetFilters(filters, whiteFilters, async); err != nil {
|
||||
if err := Context.dnsFilter.SetFilters(filters, allowFilters, async); err != nil {
|
||||
log.Debug("enabling filters: %s", err)
|
||||
}
|
||||
|
||||
|
@ -44,20 +44,22 @@ type homeContext struct {
|
||||
// Modules
|
||||
// --
|
||||
|
||||
clients clientsContainer // per-client-settings module
|
||||
stats stats.Stats // statistics module
|
||||
queryLog querylog.QueryLog // query log module
|
||||
dnsServer *dnsforward.Server // DNS module
|
||||
rdns *RDNS // rDNS module
|
||||
whois *WHOIS // WHOIS module
|
||||
dnsFilter *filtering.DNSFilter // DNS filtering module
|
||||
dhcpServer *dhcpd.Server // DHCP module
|
||||
auth *Auth // HTTP authentication module
|
||||
filters Filtering // DNS filtering module
|
||||
web *Web // Web (HTTP, HTTPS) module
|
||||
tls *TLSMod // TLS module
|
||||
etcHosts *aghnet.EtcHostsContainer // IP-hostname pairs taken from system configuration (e.g. /etc/hosts) files
|
||||
updater *updater.Updater
|
||||
clients clientsContainer // per-client-settings module
|
||||
stats stats.Stats // statistics module
|
||||
queryLog querylog.QueryLog // query log module
|
||||
dnsServer *dnsforward.Server // DNS module
|
||||
rdns *RDNS // rDNS module
|
||||
whois *WHOIS // WHOIS module
|
||||
dnsFilter *filtering.DNSFilter // DNS filtering module
|
||||
dhcpServer *dhcpd.Server // DHCP module
|
||||
auth *Auth // HTTP authentication module
|
||||
filters Filtering // DNS filtering module
|
||||
web *Web // Web (HTTP, HTTPS) module
|
||||
tls *TLSMod // TLS module
|
||||
// etcHosts is an IP-hostname pairs set taken from system configuration
|
||||
// (e.g. /etc/hosts) files.
|
||||
etcHosts *aghnet.HostsContainer
|
||||
updater *updater.Updater
|
||||
|
||||
subnetDetector *aghnet.SubnetDetector
|
||||
|
||||
@ -257,8 +259,20 @@ func setupConfig(args options) (err error) {
|
||||
})
|
||||
|
||||
if !args.noEtcHosts {
|
||||
Context.etcHosts = &aghnet.EtcHostsContainer{}
|
||||
Context.etcHosts.Init("")
|
||||
var osWritesWatcher aghos.FSWatcher
|
||||
osWritesWatcher, err = aghos.NewOSWritesWatcher()
|
||||
if err != nil {
|
||||
return fmt.Errorf("initing os watcher: %w", err)
|
||||
}
|
||||
|
||||
Context.etcHosts, err = aghnet.NewHostsContainer(
|
||||
aghos.RootDirFS(),
|
||||
osWritesWatcher,
|
||||
aghnet.DefaultHostsPaths()...,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initing hosts container: %w", err)
|
||||
}
|
||||
}
|
||||
Context.clients.Init(config.Clients, Context.dhcpServer, Context.etcHosts)
|
||||
config.Clients = nil
|
||||
@ -424,7 +438,6 @@ func run(args options, clientBuildFS fs.FS) {
|
||||
fatalOnError(err)
|
||||
|
||||
Context.tls.Start()
|
||||
Context.etcHosts.Start()
|
||||
|
||||
go func() {
|
||||
serr := startDNSServer()
|
||||
@ -647,7 +660,11 @@ func cleanup(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
Context.etcHosts.Close()
|
||||
if Context.etcHosts != nil {
|
||||
if err = Context.etcHosts.Close(); err != nil {
|
||||
log.Error("stopping hosts container: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
if Context.tls != nil {
|
||||
Context.tls.Close()
|
||||
|
Loading…
Reference in New Issue
Block a user