Pull request: 2499 merge rewrites vol.1
Merge in DNS/adguard-home from 2499-merge-rewrites-vol.1 to master
Updates #2499.
Squashed commit of the following:
commit 6b308bc2b360cee8c22e506f31d62bacb4bf8fb3
Merge: f49e9186 2b635bf6
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Thu Oct 14 19:23:07 2021 +0300
Merge branch 'master' into 2499-merge-rewrites-vol.1
commit f49e9186ffc8b7074d03c6721ee56cdb09243684
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Thu Oct 14 18:50:49 2021 +0300
aghos: fix fs events filtering
commit 567dd646556606212af5dab60e3ecbb8fff22c25
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Thu Oct 14 16:50:37 2021 +0300
all: imp code, docs, fix windows
commit 140c8bf519345eb54d0e7500a996fcf465353d71
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Wed Oct 13 19:41:53 2021 +0300
aghnet: use const
commit bebf3f76bd394a498ccad812c57d4507c69529ba
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Wed Oct 13 19:32:37 2021 +0300
all: imp tests, docs
commit 9bfdbb6eb454833135d616e208e82699f98e2562
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Wed Oct 13 18:42:20 2021 +0300
all: imp path more, imp docs
commit ee9ea4c132a6b17787d150bf2bee703abaa57be3
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Wed Oct 13 16:09:46 2021 +0300
all: fix windows, imp paths
commit 6fac8338a81e9ecfebfc23a1adcb964e89f6aee6
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Mon Oct 11 19:53:35 2021 +0300
all: imp code, docs
commit da1ce1a2a3dd2be3fdff2412a6dbd596859dc249
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Mon Oct 11 18:22:50 2021 +0300
aghnet: fix windows tests
commit d29de359ed68118d71efb226a8433fac15ff5c66
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Fri Oct 8 21:02:14 2021 +0300
all: repl & imp
commit 1356c08944cdbb85ce5532d90fe5b077219ce5ff
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Fri Oct 8 01:41:19 2021 +0300
all: add tests, mv logic, added tmpfs
commit f4b11adf8998bc8d9d955c5ac9f386f671bd5213
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Thu Oct 7 14:26:30 2021 +0300
all: imp filewalker, refactor hosts container
This commit is contained in:
parent
2b635bf689
commit
2796e65468
|
@ -1,387 +0,0 @@
|
||||||
package aghnet
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"runtime"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
|
||||||
"github.com/AdguardTeam/golibs/log"
|
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
|
||||||
"github.com/fsnotify/fsnotify"
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
)
|
|
||||||
|
|
||||||
type onChangedT func()
|
|
||||||
|
|
||||||
// EtcHostsContainer - automatic DNS records
|
|
||||||
//
|
|
||||||
// TODO(e.burkov): Move the logic under interface. Refactor. Probably remove
|
|
||||||
// the resolving logic.
|
|
||||||
type EtcHostsContainer struct {
|
|
||||||
// lock protects table and tableReverse.
|
|
||||||
lock sync.RWMutex
|
|
||||||
// table is the host-to-IPs map.
|
|
||||||
table map[string][]net.IP
|
|
||||||
// tableReverse is the IP-to-hosts map. The type of the values in the
|
|
||||||
// map is []string.
|
|
||||||
tableReverse *netutil.IPMap
|
|
||||||
|
|
||||||
hostsFn string // path to the main hosts-file
|
|
||||||
hostsDirs []string // paths to OS-specific directories with hosts-files
|
|
||||||
watcher *fsnotify.Watcher // file and directory watcher object
|
|
||||||
|
|
||||||
// onlyWritesChan used to contain only writing events from watcher.
|
|
||||||
onlyWritesChan chan fsnotify.Event
|
|
||||||
|
|
||||||
onChanged onChangedT // notification to other modules
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetOnChanged - set callback function that will be called when the data is changed
|
|
||||||
func (ehc *EtcHostsContainer) SetOnChanged(onChanged onChangedT) {
|
|
||||||
ehc.onChanged = onChanged
|
|
||||||
}
|
|
||||||
|
|
||||||
// Notify other modules
|
|
||||||
func (ehc *EtcHostsContainer) notify() {
|
|
||||||
if ehc.onChanged == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ehc.onChanged()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Init - initialize
|
|
||||||
// hostsFn: Override default name for the hosts-file (optional)
|
|
||||||
func (ehc *EtcHostsContainer) Init(hostsFn string) {
|
|
||||||
ehc.table = make(map[string][]net.IP)
|
|
||||||
ehc.onlyWritesChan = make(chan fsnotify.Event, 2)
|
|
||||||
|
|
||||||
ehc.hostsFn = "/etc/hosts"
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
ehc.hostsFn = os.ExpandEnv("$SystemRoot\\system32\\drivers\\etc\\hosts")
|
|
||||||
}
|
|
||||||
if len(hostsFn) != 0 {
|
|
||||||
ehc.hostsFn = hostsFn
|
|
||||||
}
|
|
||||||
|
|
||||||
if aghos.IsOpenWrt() {
|
|
||||||
// OpenWrt: "/tmp/hosts/dhcp.cfg01411c".
|
|
||||||
ehc.hostsDirs = append(ehc.hostsDirs, "/tmp/hosts")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load hosts initially
|
|
||||||
ehc.updateHosts()
|
|
||||||
|
|
||||||
var err error
|
|
||||||
ehc.watcher, err = fsnotify.NewWatcher()
|
|
||||||
if err != nil {
|
|
||||||
log.Error("etchosts: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start - start module
|
|
||||||
func (ehc *EtcHostsContainer) Start() {
|
|
||||||
if ehc == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug("Start etchostscontainer module")
|
|
||||||
|
|
||||||
ehc.updateHosts()
|
|
||||||
|
|
||||||
if ehc.watcher != nil {
|
|
||||||
go ehc.watcherLoop()
|
|
||||||
|
|
||||||
err := ehc.watcher.Add(ehc.hostsFn)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("Error while initializing watcher for a file %s: %s", ehc.hostsFn, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, dir := range ehc.hostsDirs {
|
|
||||||
err = ehc.watcher.Add(dir)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("Error while initializing watcher for a directory %s: %s", dir, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close - close module
|
|
||||||
func (ehc *EtcHostsContainer) Close() {
|
|
||||||
if ehc == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if ehc.watcher != nil {
|
|
||||||
_ = ehc.watcher.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Don't close onlyWritesChan here and let onlyWrites close it after
|
|
||||||
// watcher.Events is closed to prevent close races.
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process returns the list of IP addresses for the hostname or nil if nothing
|
|
||||||
// found.
|
|
||||||
func (ehc *EtcHostsContainer) Process(host string, qtype uint16) []net.IP {
|
|
||||||
if qtype == dns.TypePTR {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var ipsCopy []net.IP
|
|
||||||
ehc.lock.RLock()
|
|
||||||
defer ehc.lock.RUnlock()
|
|
||||||
|
|
||||||
if ips, ok := ehc.table[host]; ok {
|
|
||||||
ipsCopy = make([]net.IP, len(ips))
|
|
||||||
copy(ipsCopy, ips)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug("etchosts: answer: %s -> %v", host, ipsCopy)
|
|
||||||
return ipsCopy
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProcessReverse processes a PTR request. It returns nil if nothing is found.
|
|
||||||
func (ehc *EtcHostsContainer) ProcessReverse(addr string, qtype uint16) (hosts []string) {
|
|
||||||
if qtype != dns.TypePTR {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ip, err := netutil.IPFromReversedAddr(addr)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("etchosts: reversed addr: %s", err)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ehc.lock.RLock()
|
|
||||||
defer ehc.lock.RUnlock()
|
|
||||||
|
|
||||||
v, ok := ehc.tableReverse.Get(ip)
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
hosts, ok = v.([]string)
|
|
||||||
if !ok {
|
|
||||||
log.Error("etchosts: bad type %T in tableReverse for %s", v, ip)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
} else if len(hosts) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug("etchosts: reverse-lookup: %s -> %s", addr, hosts)
|
|
||||||
|
|
||||||
return hosts
|
|
||||||
}
|
|
||||||
|
|
||||||
// List returns an IP-to-hostnames table. The type of the values in the map is
|
|
||||||
// []string. It is safe for concurrent use.
|
|
||||||
func (ehc *EtcHostsContainer) List() (ipToHosts *netutil.IPMap) {
|
|
||||||
ehc.lock.RLock()
|
|
||||||
defer ehc.lock.RUnlock()
|
|
||||||
|
|
||||||
return ehc.tableReverse.ShallowClone()
|
|
||||||
}
|
|
||||||
|
|
||||||
// update table
|
|
||||||
func (ehc *EtcHostsContainer) updateTable(table map[string][]net.IP, host string, ipAddr net.IP) {
|
|
||||||
ips, ok := table[host]
|
|
||||||
if ok {
|
|
||||||
for _, ip := range ips {
|
|
||||||
if ip.Equal(ipAddr) {
|
|
||||||
// IP already exists: don't add duplicates
|
|
||||||
ok = false
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
ips = append(ips, ipAddr)
|
|
||||||
table[host] = ips
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
table[host] = []net.IP{ipAddr}
|
|
||||||
ok = true
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
log.Debug("etchosts: added %s -> %s", ipAddr, host)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateTableRev updates the reverse address table.
|
|
||||||
func (ehc *EtcHostsContainer) updateTableRev(tableRev *netutil.IPMap, newHost string, ip net.IP) {
|
|
||||||
v, ok := tableRev.Get(ip)
|
|
||||||
if !ok {
|
|
||||||
tableRev.Set(ip, []string{newHost})
|
|
||||||
log.Debug("etchosts: added reverse-address %s -> %s", ip, newHost)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
hosts, _ := v.([]string)
|
|
||||||
for _, host := range hosts {
|
|
||||||
if host == newHost {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
hosts = append(hosts, newHost)
|
|
||||||
tableRev.Set(ip, hosts)
|
|
||||||
|
|
||||||
log.Debug("etchosts: added reverse-address %s -> %s", ip, newHost)
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseHostsLine parses hosts from the fields.
|
|
||||||
func parseHostsLine(fields []string) (hosts []string) {
|
|
||||||
for _, f := range fields {
|
|
||||||
hashIdx := strings.IndexByte(f, '#')
|
|
||||||
if hashIdx == 0 {
|
|
||||||
// The rest of the fields are a part of the comment.
|
|
||||||
// Skip immediately.
|
|
||||||
return
|
|
||||||
} else if hashIdx > 0 {
|
|
||||||
// Only a part of the field is a comment.
|
|
||||||
hosts = append(hosts, f[:hashIdx])
|
|
||||||
|
|
||||||
return hosts
|
|
||||||
}
|
|
||||||
|
|
||||||
hosts = append(hosts, f)
|
|
||||||
}
|
|
||||||
|
|
||||||
return hosts
|
|
||||||
}
|
|
||||||
|
|
||||||
// load reads IP-hostname pairs from the hosts file. Multiple hostnames per
|
|
||||||
// line for one IP are supported.
|
|
||||||
func (ehc *EtcHostsContainer) load(
|
|
||||||
table map[string][]net.IP,
|
|
||||||
tableRev *netutil.IPMap,
|
|
||||||
fn string,
|
|
||||||
) {
|
|
||||||
f, err := os.Open(fn)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("etchosts: %s", err)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
derr := f.Close()
|
|
||||||
if derr != nil {
|
|
||||||
log.Error("etchosts: closing file: %s", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
log.Debug("etchosts: loading hosts from file %s", fn)
|
|
||||||
|
|
||||||
s := bufio.NewScanner(f)
|
|
||||||
for s.Scan() {
|
|
||||||
line := strings.TrimSpace(s.Text())
|
|
||||||
fields := strings.Fields(line)
|
|
||||||
if len(fields) < 2 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
ip := net.ParseIP(fields[0])
|
|
||||||
if ip == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
hosts := parseHostsLine(fields[1:])
|
|
||||||
for _, host := range hosts {
|
|
||||||
ehc.updateTable(table, host, ip)
|
|
||||||
ehc.updateTableRev(tableRev, host, ip)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = s.Err()
|
|
||||||
if err != nil {
|
|
||||||
log.Error("etchosts: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// onlyWrites is a filter for (*fsnotify.Watcher).Events.
|
|
||||||
func (ehc *EtcHostsContainer) onlyWrites() {
|
|
||||||
for event := range ehc.watcher.Events {
|
|
||||||
if event.Op&fsnotify.Write == fsnotify.Write {
|
|
||||||
ehc.onlyWritesChan <- event
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
close(ehc.onlyWritesChan)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Receive notifications from fsnotify package
|
|
||||||
func (ehc *EtcHostsContainer) watcherLoop() {
|
|
||||||
go ehc.onlyWrites()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case event, ok := <-ehc.onlyWritesChan:
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assume that we sometimes have the same event occurred
|
|
||||||
// several times.
|
|
||||||
repeat := true
|
|
||||||
for repeat {
|
|
||||||
select {
|
|
||||||
case _, ok = <-ehc.onlyWritesChan:
|
|
||||||
repeat = ok
|
|
||||||
default:
|
|
||||||
repeat = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if event.Op&fsnotify.Write == fsnotify.Write {
|
|
||||||
log.Debug("etchosts: modified: %s", event.Name)
|
|
||||||
ehc.updateHosts()
|
|
||||||
}
|
|
||||||
|
|
||||||
case err, ok := <-ehc.watcher.Errors:
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Error("etchosts: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateHosts - loads system hosts
|
|
||||||
func (ehc *EtcHostsContainer) updateHosts() {
|
|
||||||
table := make(map[string][]net.IP)
|
|
||||||
tableRev := netutil.NewIPMap(0)
|
|
||||||
|
|
||||||
ehc.load(table, tableRev, ehc.hostsFn)
|
|
||||||
|
|
||||||
for _, dir := range ehc.hostsDirs {
|
|
||||||
des, err := os.ReadDir(dir)
|
|
||||||
if err != nil {
|
|
||||||
if !errors.Is(err, os.ErrNotExist) {
|
|
||||||
log.Error("etchosts: Opening directory: %q: %s", dir, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, de := range des {
|
|
||||||
ehc.load(table, tableRev, filepath.Join(dir, de.Name()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func() {
|
|
||||||
ehc.lock.Lock()
|
|
||||||
defer ehc.lock.Unlock()
|
|
||||||
|
|
||||||
ehc.table = table
|
|
||||||
ehc.tableReverse = tableRev
|
|
||||||
}()
|
|
||||||
|
|
||||||
ehc.notify()
|
|
||||||
}
|
|
|
@ -1,130 +0,0 @@
|
||||||
package aghnet
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
|
||||||
aghtest.DiscardLogOutput(m)
|
|
||||||
}
|
|
||||||
|
|
||||||
func prepareTestFile(t *testing.T) (f *os.File) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
dir := t.TempDir()
|
|
||||||
|
|
||||||
f, err := os.CreateTemp(dir, "")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, f)
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
assert.NoError(t, f.Close())
|
|
||||||
})
|
|
||||||
|
|
||||||
return f
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertWriting(t *testing.T, f *os.File, strs ...string) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
for _, str := range strs {
|
|
||||||
n, err := f.WriteString(str)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, n, len(str))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEtcHostsContainerResolution(t *testing.T) {
|
|
||||||
ehc := &EtcHostsContainer{}
|
|
||||||
|
|
||||||
f := prepareTestFile(t)
|
|
||||||
|
|
||||||
assertWriting(t, f,
|
|
||||||
" 127.0.0.1 host localhost # comment \n",
|
|
||||||
" ::1 localhost#comment \n",
|
|
||||||
)
|
|
||||||
ehc.Init(f.Name())
|
|
||||||
|
|
||||||
t.Run("existing_host", func(t *testing.T) {
|
|
||||||
ips := ehc.Process("localhost", dns.TypeA)
|
|
||||||
require.Len(t, ips, 1)
|
|
||||||
assert.Equal(t, net.IPv4(127, 0, 0, 1), ips[0])
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("unknown_host", func(t *testing.T) {
|
|
||||||
ips := ehc.Process("newhost", dns.TypeA)
|
|
||||||
assert.Nil(t, ips)
|
|
||||||
|
|
||||||
// Comment.
|
|
||||||
ips = ehc.Process("comment", dns.TypeA)
|
|
||||||
assert.Nil(t, ips)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("hosts_file", func(t *testing.T) {
|
|
||||||
names, ok := ehc.List().Get(net.IP{127, 0, 0, 1})
|
|
||||||
require.True(t, ok)
|
|
||||||
assert.Equal(t, []string{"host", "localhost"}, names)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("ptr", func(t *testing.T) {
|
|
||||||
testCases := []struct {
|
|
||||||
wantIP string
|
|
||||||
wantHost string
|
|
||||||
wantLen int
|
|
||||||
}{
|
|
||||||
{wantIP: "127.0.0.1", wantHost: "host", wantLen: 2},
|
|
||||||
{wantIP: "::1", wantHost: "localhost", wantLen: 1},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
a, err := dns.ReverseAddr(tc.wantIP)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
a = strings.TrimSuffix(a, ".")
|
|
||||||
hosts := ehc.ProcessReverse(a, dns.TypePTR)
|
|
||||||
require.Len(t, hosts, tc.wantLen)
|
|
||||||
assert.Equal(t, tc.wantHost, hosts[0])
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEtcHostsContainerFSNotify(t *testing.T) {
|
|
||||||
ehc := &EtcHostsContainer{}
|
|
||||||
|
|
||||||
f := prepareTestFile(t)
|
|
||||||
|
|
||||||
assertWriting(t, f, " 127.0.0.1 host localhost \n")
|
|
||||||
ehc.Init(f.Name())
|
|
||||||
|
|
||||||
t.Run("unknown_host", func(t *testing.T) {
|
|
||||||
ips := ehc.Process("newhost", dns.TypeA)
|
|
||||||
assert.Nil(t, ips)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Start monitoring for changes.
|
|
||||||
ehc.Start()
|
|
||||||
t.Cleanup(ehc.Close)
|
|
||||||
|
|
||||||
assertWriting(t, f, "127.0.0.2 newhost\n")
|
|
||||||
require.NoError(t, f.Sync())
|
|
||||||
|
|
||||||
// Wait until fsnotify has triggered and processed the file-modification
|
|
||||||
// event.
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
|
|
||||||
t.Run("notified", func(t *testing.T) {
|
|
||||||
ips := ehc.Process("newhost", dns.TypeA)
|
|
||||||
assert.NotNil(t, ips)
|
|
||||||
require.Len(t, ips, 1)
|
|
||||||
assert.True(t, net.IP{127, 0, 0, 2}.Equal(ips[0]))
|
|
||||||
})
|
|
||||||
}
|
|
|
@ -0,0 +1,343 @@
|
||||||
|
package aghnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/fs"
|
||||||
|
"net"
|
||||||
|
"path"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||||
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
|
"github.com/AdguardTeam/golibs/stringutil"
|
||||||
|
"github.com/AdguardTeam/urlfilter"
|
||||||
|
"github.com/AdguardTeam/urlfilter/filterlist"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultHostsPaths returns the slice of paths default for the operating system
|
||||||
|
// to files and directories which are containing the hosts database. The result
|
||||||
|
// is intended to use within fs.FS so the initial slash is omitted.
|
||||||
|
func DefaultHostsPaths() (paths []string) {
|
||||||
|
return defaultHostsPaths()
|
||||||
|
}
|
||||||
|
|
||||||
|
// hostsContainerPref is a prefix for logging and wrapping errors in
|
||||||
|
// HostsContainer's methods.
|
||||||
|
const hostsContainerPref = "hosts container"
|
||||||
|
|
||||||
|
// HostsContainer stores the relevant hosts database provided by the OS and
|
||||||
|
// processes both A/AAAA and PTR DNS requests for those.
|
||||||
|
type HostsContainer struct {
|
||||||
|
// engLock protects rulesStrg and engine.
|
||||||
|
engLock *sync.RWMutex
|
||||||
|
|
||||||
|
// rulesStrg stores the rules obtained from the hosts' file.
|
||||||
|
rulesStrg *filterlist.RuleStorage
|
||||||
|
// engine serves rulesStrg.
|
||||||
|
engine *urlfilter.DNSEngine
|
||||||
|
|
||||||
|
// Updates is the channel for receiving updated hosts. The receivable map's
|
||||||
|
// values has a type of slice of strings.
|
||||||
|
updates chan *netutil.IPMap
|
||||||
|
|
||||||
|
// fsys is the working file system to read hosts files from.
|
||||||
|
fsys fs.FS
|
||||||
|
|
||||||
|
// w tracks the changes in specified files and directories.
|
||||||
|
w aghos.FSWatcher
|
||||||
|
// patterns stores specified paths in the fs.Glob-compatible form.
|
||||||
|
patterns []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// errNoPaths is returned when there are no paths to watch passed to the
|
||||||
|
// HostsContainer.
|
||||||
|
const errNoPaths errors.Error = "hosts paths are empty"
|
||||||
|
|
||||||
|
// NewHostsContainer creates a container of hosts, that watches the paths with
|
||||||
|
// w. paths shouldn't be empty and each of them should locate either a file or
|
||||||
|
// a directory in fsys. fsys and w must be non-nil.
|
||||||
|
func NewHostsContainer(
|
||||||
|
fsys fs.FS,
|
||||||
|
w aghos.FSWatcher,
|
||||||
|
paths ...string,
|
||||||
|
) (hc *HostsContainer, err error) {
|
||||||
|
defer func() { err = errors.Annotate(err, "%s: %w", hostsContainerPref) }()
|
||||||
|
|
||||||
|
if len(paths) == 0 {
|
||||||
|
return nil, errNoPaths
|
||||||
|
}
|
||||||
|
|
||||||
|
patterns, err := pathsToPatterns(fsys, paths)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
hc = &HostsContainer{
|
||||||
|
engLock: &sync.RWMutex{},
|
||||||
|
updates: make(chan *netutil.IPMap, 1),
|
||||||
|
fsys: fsys,
|
||||||
|
w: w,
|
||||||
|
patterns: patterns,
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("%s: starting", hostsContainerPref)
|
||||||
|
|
||||||
|
// Load initially.
|
||||||
|
if err = hc.refresh(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range paths {
|
||||||
|
err = w.Add(p)
|
||||||
|
if err == nil {
|
||||||
|
continue
|
||||||
|
} else if errors.Is(err, fs.ErrNotExist) {
|
||||||
|
log.Debug("%s: file %q expected to exist but doesn't", hostsContainerPref, p)
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("adding path: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
go hc.handleEvents()
|
||||||
|
|
||||||
|
return hc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MatchRequest is the request processing method to resolve hostnames and
|
||||||
|
// addresses from the operating system's hosts files. Any request not of A/AAAA
|
||||||
|
// or PTR type will return with an empty result. It's safe for concurrent use.
|
||||||
|
func (hc *HostsContainer) MatchRequest(
|
||||||
|
req urlfilter.DNSRequest,
|
||||||
|
) (res urlfilter.DNSResult, ok bool) {
|
||||||
|
switch req.DNSType {
|
||||||
|
case dns.TypeA, dns.TypeAAAA, dns.TypePTR:
|
||||||
|
log.Debug("%s: handling the request", hostsContainerPref)
|
||||||
|
default:
|
||||||
|
return urlfilter.DNSResult{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
hc.engLock.RLock()
|
||||||
|
defer hc.engLock.RUnlock()
|
||||||
|
|
||||||
|
return hc.engine.MatchRequest(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close implements the io.Closer interface for *HostsContainer.
|
||||||
|
func (hc *HostsContainer) Close() (err error) {
|
||||||
|
log.Debug("%s: closing hosts container", hostsContainerPref)
|
||||||
|
|
||||||
|
return errors.Annotate(hc.w.Close(), "%s: closing: %w", hostsContainerPref)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upd returns the channel into which the updates are sent.
|
||||||
|
func (hc *HostsContainer) Upd() (updates <-chan *netutil.IPMap) {
|
||||||
|
return hc.updates
|
||||||
|
}
|
||||||
|
|
||||||
|
// pathsToPatterns converts paths into patterns compatible with fs.Glob.
|
||||||
|
func pathsToPatterns(fsys fs.FS, paths []string) (patterns []string, err error) {
|
||||||
|
for i, p := range paths {
|
||||||
|
var fi fs.FileInfo
|
||||||
|
if fi, err = fs.Stat(fsys, p); err != nil {
|
||||||
|
return nil, fmt.Errorf("%q at index %d: %w", p, i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fi.IsDir() {
|
||||||
|
p = path.Join(p, "*")
|
||||||
|
}
|
||||||
|
|
||||||
|
patterns = append(patterns, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
return patterns, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleEvents concurrently handles the events. It closes the update channel
|
||||||
|
// of HostsContainer when finishes. Used to be called within a goroutine.
|
||||||
|
func (hc *HostsContainer) handleEvents() {
|
||||||
|
defer log.OnPanic(fmt.Sprintf("%s: handling events", hostsContainerPref))
|
||||||
|
|
||||||
|
defer close(hc.updates)
|
||||||
|
|
||||||
|
for range hc.w.Events() {
|
||||||
|
if err := hc.refresh(); err != nil {
|
||||||
|
log.Error("%s: %s", hostsContainerPref, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// hostsParser is a helper type to parse rules from the operating system's hosts
|
||||||
|
// file.
|
||||||
|
type hostsParser struct {
|
||||||
|
// rules builds the resulting rules list content.
|
||||||
|
rules *strings.Builder
|
||||||
|
|
||||||
|
// table stores only the unique IP-hostname pairs. It's also sent to the
|
||||||
|
// updates channel afterwards.
|
||||||
|
table *netutil.IPMap
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseHostsFile is a aghtest.FileWalker for parsing the files with hosts
|
||||||
|
// syntax. It never signs to stop the walking.
|
||||||
|
//
|
||||||
|
// See man hosts(5).
|
||||||
|
func (hp hostsParser) parseHostsFile(
|
||||||
|
r io.Reader,
|
||||||
|
) (patterns []string, cont bool, err error) {
|
||||||
|
s := bufio.NewScanner(r)
|
||||||
|
for s.Scan() {
|
||||||
|
ip, hosts := hp.parseLine(s.Text())
|
||||||
|
if ip == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, host := range hosts {
|
||||||
|
hp.addPair(ip, host)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, true, s.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseLine parses the line having the hosts syntax ignoring invalid ones.
|
||||||
|
func (hp hostsParser) parseLine(line string) (ip net.IP, hosts []string) {
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
fields := strings.Fields(line)
|
||||||
|
if len(fields) < 2 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if ip = net.ParseIP(fields[0]); ip == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
loop:
|
||||||
|
for _, f := range fields[1:] {
|
||||||
|
switch hashIdx := strings.IndexByte(f, '#'); hashIdx {
|
||||||
|
case 0:
|
||||||
|
// The rest of the fields are a part of the comment so skip
|
||||||
|
// immediately.
|
||||||
|
break loop
|
||||||
|
case -1:
|
||||||
|
hosts = append(hosts, f)
|
||||||
|
default:
|
||||||
|
// Only a part of the field is a comment.
|
||||||
|
hosts = append(hosts, f[:hashIdx])
|
||||||
|
|
||||||
|
break loop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ip, hosts
|
||||||
|
}
|
||||||
|
|
||||||
|
// add returns true if the pair of ip and host wasn't added to the hp before.
|
||||||
|
func (hp hostsParser) add(ip net.IP, host string) (added bool) {
|
||||||
|
v, ok := hp.table.Get(ip)
|
||||||
|
hosts, _ := v.([]string)
|
||||||
|
if ok && stringutil.InSlice(hosts, host) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
hp.table.Set(ip, append(hosts, host))
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// addPair puts the pair of ip and host to the rules builder if needed.
|
||||||
|
func (hp hostsParser) addPair(ip net.IP, host string) {
|
||||||
|
arpa, err := netutil.IPToReversedAddr(ip)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hp.add(ip, host) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
qtype := "AAAA"
|
||||||
|
if ip.To4() != nil {
|
||||||
|
// Assume the validation of the IP address is performed already.
|
||||||
|
qtype = "A"
|
||||||
|
}
|
||||||
|
|
||||||
|
stringutil.WriteToBuilder(
|
||||||
|
hp.rules,
|
||||||
|
"||",
|
||||||
|
host,
|
||||||
|
"^$dnsrewrite=NOERROR;",
|
||||||
|
qtype,
|
||||||
|
";",
|
||||||
|
ip.String(),
|
||||||
|
"\n",
|
||||||
|
"||",
|
||||||
|
arpa,
|
||||||
|
"^$dnsrewrite=NOERROR;PTR;",
|
||||||
|
dns.Fqdn(host),
|
||||||
|
"\n",
|
||||||
|
)
|
||||||
|
|
||||||
|
log.Debug("%s: added ip-host pair %q/%q", hostsContainerPref, ip, host)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendUpd tries to send the parsed data to the ch.
|
||||||
|
func (hp hostsParser) sendUpd(ch chan *netutil.IPMap) {
|
||||||
|
log.Debug("%s: sending upd", hostsContainerPref)
|
||||||
|
select {
|
||||||
|
case ch <- hp.table:
|
||||||
|
// Updates are delivered. Go on.
|
||||||
|
default:
|
||||||
|
log.Debug("%s: the buffer is full", hostsContainerPref)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// newStrg creates a new rules storage from parsed data.
|
||||||
|
func (hp hostsParser) newStrg() (s *filterlist.RuleStorage, err error) {
|
||||||
|
return filterlist.NewRuleStorage([]filterlist.RuleList{&filterlist.StringRuleList{
|
||||||
|
ID: 1,
|
||||||
|
RulesText: hp.rules.String(),
|
||||||
|
IgnoreCosmetic: true,
|
||||||
|
}})
|
||||||
|
}
|
||||||
|
|
||||||
|
// refresh gets the data from specified files and propagates the updates.
|
||||||
|
func (hc *HostsContainer) refresh() (err error) {
|
||||||
|
log.Debug("%s: refreshing", hostsContainerPref)
|
||||||
|
|
||||||
|
hp := hostsParser{
|
||||||
|
rules: &strings.Builder{},
|
||||||
|
table: netutil.NewIPMap(0),
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = aghos.FileWalker(hp.parseHostsFile).Walk(hc.fsys, hc.patterns...)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("updating: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer hp.sendUpd(hc.updates)
|
||||||
|
|
||||||
|
var rulesStrg *filterlist.RuleStorage
|
||||||
|
if rulesStrg, err = hp.newStrg(); err != nil {
|
||||||
|
return fmt.Errorf("initializing rules storage: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
hc.resetEng(rulesStrg)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hc *HostsContainer) resetEng(rulesStrg *filterlist.RuleStorage) {
|
||||||
|
hc.engLock.Lock()
|
||||||
|
defer hc.engLock.Unlock()
|
||||||
|
|
||||||
|
hc.rulesStrg = rulesStrg
|
||||||
|
hc.engine = urlfilter.NewDNSEngine(hc.rulesStrg)
|
||||||
|
}
|
|
@ -0,0 +1,18 @@
|
||||||
|
//go:build linux
|
||||||
|
// +build linux
|
||||||
|
|
||||||
|
package aghnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||||
|
)
|
||||||
|
|
||||||
|
func defaultHostsPaths() (paths []string) {
|
||||||
|
paths = []string{"etc/hosts"}
|
||||||
|
|
||||||
|
if aghos.IsOpenWrt() {
|
||||||
|
paths = append(paths, "tmp/hosts")
|
||||||
|
}
|
||||||
|
|
||||||
|
return paths
|
||||||
|
}
|
|
@ -0,0 +1,8 @@
|
||||||
|
//go:build !(windows || linux)
|
||||||
|
// +build !windows,!linux
|
||||||
|
|
||||||
|
package aghnet
|
||||||
|
|
||||||
|
func defaultHostsPaths() (paths []string) {
|
||||||
|
return []string{"etc/hosts"}
|
||||||
|
}
|
|
@ -0,0 +1,504 @@
|
||||||
|
package aghnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/fs"
|
||||||
|
"net"
|
||||||
|
"path"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"testing/fstest"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||||
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
|
"github.com/AdguardTeam/urlfilter"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
nl = "\n"
|
||||||
|
sp = " "
|
||||||
|
)
|
||||||
|
|
||||||
|
const closeCalled errors.Error = "close method called"
|
||||||
|
|
||||||
|
// fsWatcherOnCloseStub is a stub implementation of the Close method of
|
||||||
|
// aghos.FSWatcher.
|
||||||
|
func fsWatcherOnCloseStub() (err error) {
|
||||||
|
return closeCalled
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewHostsContainer(t *testing.T) {
|
||||||
|
const dirname = "dir"
|
||||||
|
const filename = "file1"
|
||||||
|
|
||||||
|
p := path.Join(dirname, filename)
|
||||||
|
|
||||||
|
testFS := fstest.MapFS{
|
||||||
|
p: &fstest.MapFile{Data: []byte("127.0.0.1 localhost")},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
paths []string
|
||||||
|
wantErr error
|
||||||
|
wantPatterns []string
|
||||||
|
}{{
|
||||||
|
name: "one_file",
|
||||||
|
paths: []string{p},
|
||||||
|
wantErr: nil,
|
||||||
|
wantPatterns: []string{p},
|
||||||
|
}, {
|
||||||
|
name: "no_files",
|
||||||
|
paths: []string{},
|
||||||
|
wantErr: errNoPaths,
|
||||||
|
wantPatterns: nil,
|
||||||
|
}, {
|
||||||
|
name: "non-existent_file",
|
||||||
|
paths: []string{path.Join(dirname, filename+"2")},
|
||||||
|
wantErr: fs.ErrNotExist,
|
||||||
|
wantPatterns: nil,
|
||||||
|
}, {
|
||||||
|
name: "whole_dir",
|
||||||
|
paths: []string{dirname},
|
||||||
|
wantErr: nil,
|
||||||
|
wantPatterns: []string{path.Join(dirname, "*")},
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
onAdd := func(name string) (err error) {
|
||||||
|
assert.Contains(t, tc.paths, name)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var eventsCalledCounter uint32
|
||||||
|
eventsCh := make(chan struct{})
|
||||||
|
onEvents := func() (e <-chan struct{}) {
|
||||||
|
assert.Equal(t, uint32(1), atomic.AddUint32(&eventsCalledCounter, 1))
|
||||||
|
|
||||||
|
return eventsCh
|
||||||
|
}
|
||||||
|
|
||||||
|
hc, err := NewHostsContainer(testFS, &aghtest.FSWatcher{
|
||||||
|
OnEvents: onEvents,
|
||||||
|
OnAdd: onAdd,
|
||||||
|
OnClose: fsWatcherOnCloseStub,
|
||||||
|
}, tc.paths...)
|
||||||
|
if tc.wantErr != nil {
|
||||||
|
require.ErrorIs(t, err, tc.wantErr)
|
||||||
|
|
||||||
|
assert.Nil(t, hc)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.ErrorIs(t, hc.Close(), closeCalled)
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NotNil(t, hc)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.wantPatterns, hc.patterns)
|
||||||
|
assert.NotNil(t, <-hc.Upd())
|
||||||
|
|
||||||
|
eventsCh <- struct{}{}
|
||||||
|
assert.Equal(t, uint32(1), atomic.LoadUint32(&eventsCalledCounter))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("nil_fs", func(t *testing.T) {
|
||||||
|
require.Panics(t, func() {
|
||||||
|
_, _ = NewHostsContainer(nil, &aghtest.FSWatcher{
|
||||||
|
// Those shouldn't panic.
|
||||||
|
OnEvents: func() (e <-chan struct{}) { return nil },
|
||||||
|
OnAdd: func(name string) (err error) { return nil },
|
||||||
|
OnClose: func() (err error) { return nil },
|
||||||
|
}, p)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nil_watcher", func(t *testing.T) {
|
||||||
|
require.Panics(t, func() {
|
||||||
|
_, _ = NewHostsContainer(testFS, nil, p)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("err_watcher", func(t *testing.T) {
|
||||||
|
const errOnAdd errors.Error = "error"
|
||||||
|
|
||||||
|
errWatcher := &aghtest.FSWatcher{
|
||||||
|
OnEvents: func() (e <-chan struct{}) { panic("not implemented") },
|
||||||
|
OnAdd: func(name string) (err error) { return errOnAdd },
|
||||||
|
OnClose: func() (err error) { panic("not implemented") },
|
||||||
|
}
|
||||||
|
|
||||||
|
hc, err := NewHostsContainer(testFS, errWatcher, p)
|
||||||
|
require.ErrorIs(t, err, errOnAdd)
|
||||||
|
|
||||||
|
assert.Nil(t, hc)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHostsContainer_Refresh(t *testing.T) {
|
||||||
|
knownIP := net.IP{127, 0, 0, 1}
|
||||||
|
|
||||||
|
const knownHost = "localhost"
|
||||||
|
const knownAlias = "hocallost"
|
||||||
|
|
||||||
|
const dirname = "dir"
|
||||||
|
const filename1 = "file1"
|
||||||
|
const filename2 = "file2"
|
||||||
|
|
||||||
|
p1 := path.Join(dirname, filename1)
|
||||||
|
p2 := path.Join(dirname, filename2)
|
||||||
|
|
||||||
|
testFS := fstest.MapFS{
|
||||||
|
p1: &fstest.MapFile{
|
||||||
|
Data: []byte(strings.Join([]string{knownIP.String(), knownHost}, sp) + nl),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
eventsCh := make(chan struct{}, 1)
|
||||||
|
t.Cleanup(func() { close(eventsCh) })
|
||||||
|
|
||||||
|
w := &aghtest.FSWatcher{
|
||||||
|
OnEvents: func() (e <-chan struct{}) { return eventsCh },
|
||||||
|
OnAdd: func(name string) (err error) {
|
||||||
|
assert.Equal(t, dirname, name)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
OnClose: fsWatcherOnCloseStub,
|
||||||
|
}
|
||||||
|
|
||||||
|
hc, err := NewHostsContainer(testFS, w, dirname)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() { require.ErrorIs(t, hc.Close(), closeCalled) })
|
||||||
|
|
||||||
|
checkRefresh := func(t *testing.T, wantHosts []string) {
|
||||||
|
upd, ok := <-hc.Upd()
|
||||||
|
require.True(t, ok)
|
||||||
|
require.NotNil(t, upd)
|
||||||
|
|
||||||
|
assert.Equal(t, 1, upd.Len())
|
||||||
|
|
||||||
|
v, ok := upd.Get(knownIP)
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
var hosts []string
|
||||||
|
hosts, ok = v.([]string)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, hosts, len(wantHosts))
|
||||||
|
|
||||||
|
assert.Equal(t, wantHosts, hosts)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("initial_refresh", func(t *testing.T) {
|
||||||
|
checkRefresh(t, []string{knownHost})
|
||||||
|
})
|
||||||
|
|
||||||
|
testFS[p2] = &fstest.MapFile{
|
||||||
|
Data: []byte(strings.Join([]string{knownIP.String(), knownAlias}, sp) + nl),
|
||||||
|
}
|
||||||
|
|
||||||
|
eventsCh <- struct{}{}
|
||||||
|
|
||||||
|
t.Run("second_refresh", func(t *testing.T) {
|
||||||
|
checkRefresh(t, []string{knownHost, knownAlias})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHostsContainer_MatchRequest(t *testing.T) {
|
||||||
|
var (
|
||||||
|
ip4 = net.IP{127, 0, 0, 1}
|
||||||
|
ip6 = net.IP{
|
||||||
|
0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0,
|
||||||
|
0, 0, 0, 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
hostname4 = "localhost"
|
||||||
|
hostname6 = "localhostv6"
|
||||||
|
hostname4a = "abcd"
|
||||||
|
|
||||||
|
reversed4, _ = netutil.IPToReversedAddr(ip4)
|
||||||
|
reversed6, _ = netutil.IPToReversedAddr(ip6)
|
||||||
|
)
|
||||||
|
|
||||||
|
const filename = "file1"
|
||||||
|
|
||||||
|
gsfs := fstest.MapFS{
|
||||||
|
filename: &fstest.MapFile{Data: []byte(
|
||||||
|
strings.Join([]string{ip4.String(), hostname4, hostname4a}, sp) + nl +
|
||||||
|
strings.Join([]string{ip6.String(), hostname6}, sp) + nl +
|
||||||
|
strings.Join([]string{"256.256.256.256", "fakebroadcast"}, sp) + nl,
|
||||||
|
)},
|
||||||
|
}
|
||||||
|
|
||||||
|
hc, err := NewHostsContainer(gsfs, &aghtest.FSWatcher{
|
||||||
|
OnEvents: func() (e <-chan struct{}) { panic("not implemented") },
|
||||||
|
OnAdd: func(name string) (err error) {
|
||||||
|
assert.Equal(t, filename, name)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
OnClose: fsWatcherOnCloseStub,
|
||||||
|
}, filename)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() { require.ErrorIs(t, hc.Close(), closeCalled) })
|
||||||
|
|
||||||
|
testCase := []struct {
|
||||||
|
name string
|
||||||
|
want interface{}
|
||||||
|
req urlfilter.DNSRequest
|
||||||
|
}{{
|
||||||
|
name: "a",
|
||||||
|
want: ip4.To16(),
|
||||||
|
req: urlfilter.DNSRequest{
|
||||||
|
Hostname: hostname4,
|
||||||
|
DNSType: dns.TypeA,
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "aaaa",
|
||||||
|
want: ip6,
|
||||||
|
req: urlfilter.DNSRequest{
|
||||||
|
Hostname: hostname6,
|
||||||
|
DNSType: dns.TypeA,
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "ptr",
|
||||||
|
want: dns.Fqdn(hostname4),
|
||||||
|
req: urlfilter.DNSRequest{
|
||||||
|
Hostname: reversed4,
|
||||||
|
DNSType: dns.TypePTR,
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "ptr_v6",
|
||||||
|
want: dns.Fqdn(hostname6),
|
||||||
|
req: urlfilter.DNSRequest{
|
||||||
|
Hostname: reversed6,
|
||||||
|
DNSType: dns.TypePTR,
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "a_alias",
|
||||||
|
want: ip4.To16(),
|
||||||
|
req: urlfilter.DNSRequest{
|
||||||
|
Hostname: hostname4a,
|
||||||
|
DNSType: dns.TypeA,
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCase {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
res, ok := hc.MatchRequest(tc.req)
|
||||||
|
require.False(t, ok)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.want, res.DNSRewrites()[0].DNSRewrite.Value)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("cname", func(t *testing.T) {
|
||||||
|
res, ok := hc.MatchRequest(urlfilter.DNSRequest{
|
||||||
|
Hostname: hostname4,
|
||||||
|
DNSType: dns.TypeCNAME,
|
||||||
|
})
|
||||||
|
require.False(t, ok)
|
||||||
|
|
||||||
|
assert.Empty(t, res)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHostsContainer_PathsToPatterns(t *testing.T) {
|
||||||
|
const (
|
||||||
|
dir0 = "dir"
|
||||||
|
dir1 = "dir_1"
|
||||||
|
fn1 = "file_1"
|
||||||
|
fn2 = "file_2"
|
||||||
|
fn3 = "file_3"
|
||||||
|
fn4 = "file_4"
|
||||||
|
)
|
||||||
|
|
||||||
|
fp1 := path.Join(dir0, fn1)
|
||||||
|
fp2 := path.Join(dir0, fn2)
|
||||||
|
fp3 := path.Join(dir0, dir1, fn3)
|
||||||
|
|
||||||
|
gsfs := fstest.MapFS{
|
||||||
|
fp1: &fstest.MapFile{Data: []byte{1}},
|
||||||
|
fp2: &fstest.MapFile{Data: []byte{2}},
|
||||||
|
fp3: &fstest.MapFile{Data: []byte{3}},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
wantErr error
|
||||||
|
want []string
|
||||||
|
paths []string
|
||||||
|
}{{
|
||||||
|
name: "no_paths",
|
||||||
|
wantErr: nil,
|
||||||
|
want: nil,
|
||||||
|
paths: nil,
|
||||||
|
}, {
|
||||||
|
name: "single_file",
|
||||||
|
wantErr: nil,
|
||||||
|
want: []string{fp1},
|
||||||
|
paths: []string{fp1},
|
||||||
|
}, {
|
||||||
|
name: "several_files",
|
||||||
|
wantErr: nil,
|
||||||
|
want: []string{fp1, fp2},
|
||||||
|
paths: []string{fp1, fp2},
|
||||||
|
}, {
|
||||||
|
name: "whole_dir",
|
||||||
|
wantErr: nil,
|
||||||
|
want: []string{path.Join(dir0, "*")},
|
||||||
|
paths: []string{dir0},
|
||||||
|
}, {
|
||||||
|
name: "file_and_dir",
|
||||||
|
wantErr: nil,
|
||||||
|
want: []string{fp1, path.Join(dir0, dir1, "*")},
|
||||||
|
paths: []string{fp1, path.Join(dir0, dir1)},
|
||||||
|
}, {
|
||||||
|
name: "non-existing",
|
||||||
|
wantErr: fs.ErrNotExist,
|
||||||
|
want: nil,
|
||||||
|
paths: []string{path.Join(dir0, "file_3")},
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
patterns, err := pathsToPatterns(gsfs, tc.paths)
|
||||||
|
if tc.wantErr != nil {
|
||||||
|
assert.ErrorIs(t, err, tc.wantErr)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.want, patterns)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUniqueRules_AddPair(t *testing.T) {
|
||||||
|
knownIP := net.IP{1, 2, 3, 4}
|
||||||
|
|
||||||
|
const knownHost = "host1"
|
||||||
|
|
||||||
|
ipToHost := netutil.NewIPMap(0)
|
||||||
|
ipToHost.Set(knownIP, []string{knownHost})
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
host string
|
||||||
|
wantRules string
|
||||||
|
ip net.IP
|
||||||
|
}{{
|
||||||
|
name: "new_one",
|
||||||
|
host: "host2",
|
||||||
|
wantRules: "||host2^$dnsrewrite=NOERROR;A;1.2.3.4\n" +
|
||||||
|
"||4.3.2.1.in-addr.arpa^$dnsrewrite=NOERROR;PTR;host2.\n",
|
||||||
|
ip: knownIP,
|
||||||
|
}, {
|
||||||
|
name: "existing_one",
|
||||||
|
host: knownHost,
|
||||||
|
wantRules: "",
|
||||||
|
ip: knownIP,
|
||||||
|
}, {
|
||||||
|
name: "new_ip",
|
||||||
|
host: knownHost,
|
||||||
|
wantRules: "||" + knownHost + "^$dnsrewrite=NOERROR;A;1.2.3.5\n" +
|
||||||
|
"||5.3.2.1.in-addr.arpa^$dnsrewrite=NOERROR;PTR;" + knownHost + ".\n",
|
||||||
|
ip: net.IP{1, 2, 3, 5},
|
||||||
|
}, {
|
||||||
|
name: "bad_ip",
|
||||||
|
host: knownHost,
|
||||||
|
wantRules: "",
|
||||||
|
ip: net.IP{1, 2, 3, 4, 5},
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
hp := hostsParser{
|
||||||
|
rules: &strings.Builder{},
|
||||||
|
table: ipToHost.ShallowClone(),
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
hp.addPair(tc.ip, tc.host)
|
||||||
|
assert.Equal(t, tc.wantRules, hp.rules.String())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUniqueRules_ParseLine(t *testing.T) {
|
||||||
|
const (
|
||||||
|
hostname = "localhost"
|
||||||
|
alias = "hocallost"
|
||||||
|
)
|
||||||
|
|
||||||
|
knownIP := net.IP{127, 0, 0, 1}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
line string
|
||||||
|
wantIP net.IP
|
||||||
|
wantHosts []string
|
||||||
|
}{{
|
||||||
|
name: "simple",
|
||||||
|
line: strings.Join([]string{knownIP.String(), hostname}, sp),
|
||||||
|
wantIP: knownIP,
|
||||||
|
wantHosts: []string{"localhost"},
|
||||||
|
}, {
|
||||||
|
name: "aliases",
|
||||||
|
line: strings.Join([]string{knownIP.String(), hostname, alias}, sp),
|
||||||
|
wantIP: knownIP,
|
||||||
|
wantHosts: []string{"localhost", "hocallost"},
|
||||||
|
}, {
|
||||||
|
name: "invalid_line",
|
||||||
|
line: knownIP.String(),
|
||||||
|
wantIP: nil,
|
||||||
|
wantHosts: nil,
|
||||||
|
}, {
|
||||||
|
name: "invalid_line_hostname",
|
||||||
|
line: strings.Join([]string{knownIP.String(), "#" + hostname}, sp),
|
||||||
|
wantIP: knownIP,
|
||||||
|
wantHosts: nil,
|
||||||
|
}, {
|
||||||
|
name: "commented_aliases",
|
||||||
|
line: strings.Join([]string{knownIP.String(), hostname, "#" + alias}, sp),
|
||||||
|
wantIP: knownIP,
|
||||||
|
wantHosts: []string{"localhost"},
|
||||||
|
}, {
|
||||||
|
name: "whole_comment",
|
||||||
|
line: strings.Join([]string{"#", knownIP.String(), hostname}, sp),
|
||||||
|
wantIP: nil,
|
||||||
|
wantHosts: nil,
|
||||||
|
}, {
|
||||||
|
name: "partial_comment",
|
||||||
|
line: strings.Join([]string{knownIP.String(), hostname[:4] + "#" + hostname[4:]}, sp),
|
||||||
|
wantIP: knownIP,
|
||||||
|
wantHosts: []string{hostname[:4]},
|
||||||
|
}, {
|
||||||
|
name: "empty",
|
||||||
|
line: ``,
|
||||||
|
wantIP: nil,
|
||||||
|
wantHosts: nil,
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
hp := hostsParser{}
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
ip, hosts := hp.parseLine(tc.line)
|
||||||
|
assert.True(t, tc.wantIP.Equal(ip))
|
||||||
|
assert.Equal(t, tc.wantHosts, hosts)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,33 @@
|
||||||
|
//go:build windows
|
||||||
|
// +build windows
|
||||||
|
|
||||||
|
package aghnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
func defaultHostsPaths() (paths []string) {
|
||||||
|
sysDir, err := windows.GetSystemDirectory()
|
||||||
|
if err != nil {
|
||||||
|
log.Error("getting system directory: %s", err)
|
||||||
|
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split all the elements of the path to join them afterwards. This is
|
||||||
|
// needed to make the Windows-specific path string returned by
|
||||||
|
// windows.GetSystemDirectory to be compatible with fs.FS.
|
||||||
|
pathElems := strings.Split(sysDir, string(os.PathSeparator))
|
||||||
|
if len(pathElems) > 0 && pathElems[0] == filepath.VolumeName(sysDir) {
|
||||||
|
pathElems = pathElems[1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
return []string{path.Join(append(pathElems, "drivers/etc/hosts")...)}
|
||||||
|
}
|
|
@ -18,9 +18,11 @@ func canBindPrivilegedPorts() (can bool, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
|
func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
|
||||||
const filename = "/etc/rc.conf"
|
const rcConfFilename = "etc/rc.conf"
|
||||||
|
|
||||||
return aghos.FileWalker(interfaceName(ifaceName).rcConfStaticConfig).Walk(filename)
|
walker := aghos.FileWalker(interfaceName(ifaceName).rcConfStaticConfig)
|
||||||
|
|
||||||
|
return walker.Walk(aghos.RootDirFS(), rcConfFilename)
|
||||||
}
|
}
|
||||||
|
|
||||||
// rcConfStaticConfig checks if the interface is configured by /etc/rc.conf to
|
// rcConfStaticConfig checks if the interface is configured by /etc/rc.conf to
|
||||||
|
|
|
@ -85,17 +85,17 @@ func ifaceHasStaticIP(ifaceName string) (has bool, err error) {
|
||||||
|
|
||||||
iface := interfaceName(ifaceName)
|
iface := interfaceName(ifaceName)
|
||||||
|
|
||||||
for _, pair := range []struct {
|
for _, pair := range [...]struct {
|
||||||
aghos.FileWalker
|
aghos.FileWalker
|
||||||
filename string
|
filename string
|
||||||
}{{
|
}{{
|
||||||
FileWalker: iface.dhcpcdStaticConfig,
|
FileWalker: iface.dhcpcdStaticConfig,
|
||||||
filename: "/etc/dhcpcd.conf",
|
filename: "etc/dhcpcd.conf",
|
||||||
}, {
|
}, {
|
||||||
FileWalker: iface.ifacesStaticConfig,
|
FileWalker: iface.ifacesStaticConfig,
|
||||||
filename: "/etc/network/interfaces",
|
filename: "etc/network/interfaces",
|
||||||
}} {
|
}} {
|
||||||
has, err = pair.Walk(pair.filename)
|
has, err = pair.Walk(aghos.RootDirFS(), pair.filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,8 +12,6 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
const nl = "\n"
|
|
||||||
|
|
||||||
func TestDHCPCDStaticConfig(t *testing.T) {
|
func TestDHCPCDStaticConfig(t *testing.T) {
|
||||||
const iface interfaceName = `wlan0`
|
const iface interfaceName = `wlan0`
|
||||||
|
|
||||||
|
|
|
@ -18,9 +18,9 @@ func canBindPrivilegedPorts() (can bool, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
|
func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
|
||||||
filename := fmt.Sprintf("/etc/hostname.%s", ifaceName)
|
filename := fmt.Sprintf("etc/hostname.%s", ifaceName)
|
||||||
|
|
||||||
return aghos.FileWalker(hostnameIfStaticConfig).Walk(filename)
|
return aghos.FileWalker(hostnameIfStaticConfig).Walk(aghos.RootDirFS(), filename)
|
||||||
}
|
}
|
||||||
|
|
||||||
// hostnameIfStaticConfig checks if the interface is configured by
|
// hostnameIfStaticConfig checks if the interface is configured by
|
||||||
|
|
|
@ -4,10 +4,15 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
aghtest.DiscardLogOutput(m)
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetValidNetInterfacesForWeb(t *testing.T) {
|
func TestGetValidNetInterfacesForWeb(t *testing.T) {
|
||||||
ifaces, err := GetValidNetInterfacesForWeb()
|
ifaces, err := GetValidNetInterfacesForWeb()
|
||||||
require.NoErrorf(t, err, "cannot get net interfaces: %s", err)
|
require.NoErrorf(t, err, "cannot get net interfaces: %s", err)
|
||||||
|
|
|
@ -3,8 +3,7 @@ package aghos
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"io/fs"
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
@ -14,10 +13,10 @@ import (
|
||||||
// FileWalker is the signature of a function called for files in the file tree.
|
// FileWalker is the signature of a function called for files in the file tree.
|
||||||
// As opposed to filepath.Walk it only walk the files (not directories) matching
|
// As opposed to filepath.Walk it only walk the files (not directories) matching
|
||||||
// the provided pattern and those returned by function itself. All patterns
|
// the provided pattern and those returned by function itself. All patterns
|
||||||
// should be valid for filepath.Glob. If cont is false, the walking terminates.
|
// should be valid for fs.Glob. If cont is false, the walking terminates. Each
|
||||||
// Each opened file is also limited for reading to MaxWalkedFileSize.
|
// opened file is also limited for reading to MaxWalkedFileSize.
|
||||||
//
|
//
|
||||||
// TODO(e.burkov): Consider moving to the separate package like pathutil.
|
// TODO(e.burkov, a.garipov): Move into another package like aghfs.
|
||||||
//
|
//
|
||||||
// TODO(e.burkov): Think about passing filename or any additional data.
|
// TODO(e.burkov): Think about passing filename or any additional data.
|
||||||
type FileWalker func(r io.Reader) (patterns []string, cont bool, err error)
|
type FileWalker func(r io.Reader) (patterns []string, cont bool, err error)
|
||||||
|
@ -26,15 +25,19 @@ type FileWalker func(r io.Reader) (patterns []string, cont bool, err error)
|
||||||
// check.
|
// check.
|
||||||
const MaxWalkedFileSize = 1024 * 1024
|
const MaxWalkedFileSize = 1024 * 1024
|
||||||
|
|
||||||
// checkFile tries to open and process a single file located on sourcePath.
|
// checkFile tries to open and process a single file located on sourcePath in
|
||||||
func checkFile(c FileWalker, sourcePath string) (patterns []string, cont bool, err error) {
|
// the specified fsys. The path is skipped if it's a directory.
|
||||||
var f *os.File
|
func checkFile(
|
||||||
f, err = os.Open(sourcePath)
|
fsys fs.FS,
|
||||||
|
c FileWalker,
|
||||||
|
sourcePath string,
|
||||||
|
) (patterns []string, cont bool, err error) {
|
||||||
|
var f fs.File
|
||||||
|
f, err = fsys.Open(sourcePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
if errors.Is(err, fs.ErrNotExist) {
|
||||||
// Ignore non-existing files since this may only happen
|
// Ignore non-existing files since this may only happen when the
|
||||||
// when the file was removed after filepath.Glob matched
|
// file was removed after filepath.Glob matched it.
|
||||||
// it.
|
|
||||||
return nil, true, nil
|
return nil, true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -42,9 +45,18 @@ func checkFile(c FileWalker, sourcePath string) (patterns []string, cont bool, e
|
||||||
}
|
}
|
||||||
defer func() { err = errors.WithDeferred(err, f.Close()) }()
|
defer func() { err = errors.WithDeferred(err, f.Close()) }()
|
||||||
|
|
||||||
|
var fi fs.FileInfo
|
||||||
|
if fi, err = f.Stat(); err != nil {
|
||||||
|
return nil, true, err
|
||||||
|
}
|
||||||
|
if fi.IsDir() {
|
||||||
|
// Skip the directories.
|
||||||
|
return nil, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
var r io.Reader
|
var r io.Reader
|
||||||
// Ignore the error since LimitReader function returns error only if
|
// Ignore the error since LimitReader function returns error only if passed
|
||||||
// passed limit value is less than zero, but the constant used.
|
// limit value is less than zero, but the constant used.
|
||||||
//
|
//
|
||||||
// TODO(e.burkov): Make variable.
|
// TODO(e.burkov): Make variable.
|
||||||
r, _ = aghio.LimitReader(f, MaxWalkedFileSize)
|
r, _ = aghio.LimitReader(f, MaxWalkedFileSize)
|
||||||
|
@ -52,13 +64,17 @@ func checkFile(c FileWalker, sourcePath string) (patterns []string, cont bool, e
|
||||||
return c(r)
|
return c(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
// handlePatterns parses the patterns and ignores duplicates using srcSet.
|
// handlePatterns parses the patterns in fsys and ignores duplicates using
|
||||||
// srcSet must be non-nil.
|
// srcSet. srcSet must be non-nil.
|
||||||
func handlePatterns(srcSet *stringutil.Set, patterns ...string) (sub []string, err error) {
|
func handlePatterns(
|
||||||
|
fsys fs.FS,
|
||||||
|
srcSet *stringutil.Set,
|
||||||
|
patterns ...string,
|
||||||
|
) (sub []string, err error) {
|
||||||
sub = make([]string, 0, len(patterns))
|
sub = make([]string, 0, len(patterns))
|
||||||
for _, p := range patterns {
|
for _, p := range patterns {
|
||||||
var matches []string
|
var matches []string
|
||||||
matches, err = filepath.Glob(p)
|
matches, err = fs.Glob(fsys, p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Enrich error with the pattern because filepath.Glob
|
// Enrich error with the pattern because filepath.Glob
|
||||||
// doesn't do it.
|
// doesn't do it.
|
||||||
|
@ -78,14 +94,14 @@ func handlePatterns(srcSet *stringutil.Set, patterns ...string) (sub []string, e
|
||||||
return sub, nil
|
return sub, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Walk starts walking the files defined by initPattern. It only returns true
|
// Walk starts walking the files in fsys defined by patterns from initial.
|
||||||
// if c signed to stop walking.
|
// It only returns true if fw signed to stop walking.
|
||||||
func (c FileWalker) Walk(initPattern string) (ok bool, err error) {
|
func (fw FileWalker) Walk(fsys fs.FS, initial ...string) (ok bool, err error) {
|
||||||
// The slice of sources keeps the order in which the files are walked
|
// The slice of sources keeps the order in which the files are walked since
|
||||||
// since srcSet.Values() returns strings in undefined order.
|
// srcSet.Values() returns strings in undefined order.
|
||||||
srcSet := stringutil.NewSet()
|
srcSet := stringutil.NewSet()
|
||||||
var src []string
|
var src []string
|
||||||
src, err = handlePatterns(srcSet, initPattern)
|
src, err = handlePatterns(fsys, srcSet, initial...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
@ -97,7 +113,7 @@ func (c FileWalker) Walk(initPattern string) (ok bool, err error) {
|
||||||
var patterns []string
|
var patterns []string
|
||||||
var cont bool
|
var cont bool
|
||||||
filename = src[i]
|
filename = src[i]
|
||||||
patterns, cont, err = checkFile(c, src[i])
|
patterns, cont, err = checkFile(fsys, fw, src[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
@ -107,7 +123,7 @@ func (c FileWalker) Walk(initPattern string) (ok bool, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var subsrc []string
|
var subsrc []string
|
||||||
subsrc, err = handlePatterns(srcSet, patterns...)
|
subsrc, err = handlePatterns(fsys, srcSet, patterns...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,56 +4,19 @@ import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"io"
|
"io"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"os"
|
"path"
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
"testing/fstest"
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
// testFSDir maps entries' names to entries which should either be a testFSDir
|
|
||||||
// or byte slice.
|
|
||||||
type testFSDir map[string]interface{}
|
|
||||||
|
|
||||||
// testFSGen is used to generate a temporary filesystem consisting of
|
|
||||||
// directories and plain text files from itself.
|
|
||||||
type testFSGen testFSDir
|
|
||||||
|
|
||||||
// gen returns the name of top directory of the generated filesystem.
|
|
||||||
func (g testFSGen) gen(t *testing.T) (dirName string) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
dirName = t.TempDir()
|
|
||||||
g.rangeThrough(t, dirName)
|
|
||||||
|
|
||||||
return dirName
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g testFSGen) rangeThrough(t *testing.T, dirName string) {
|
|
||||||
const perm fs.FileMode = 0o777
|
|
||||||
|
|
||||||
for k, e := range g {
|
|
||||||
switch e := e.(type) {
|
|
||||||
case []byte:
|
|
||||||
require.NoError(t, os.WriteFile(filepath.Join(dirName, k), e, perm))
|
|
||||||
|
|
||||||
case testFSDir:
|
|
||||||
newDir := filepath.Join(dirName, k)
|
|
||||||
require.NoError(t, os.Mkdir(newDir, perm))
|
|
||||||
|
|
||||||
testFSGen(e).rangeThrough(t, newDir)
|
|
||||||
default:
|
|
||||||
t.Fatalf("unexpected entry type %T", e)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFileWalker_Walk(t *testing.T) {
|
func TestFileWalker_Walk(t *testing.T) {
|
||||||
const attribute = `000`
|
const attribute = `000`
|
||||||
|
|
||||||
makeFileWalker := func(dirName string) (fw FileWalker) {
|
makeFileWalker := func(_ string) (fw FileWalker) {
|
||||||
return func(r io.Reader) (patterns []string, cont bool, err error) {
|
return func(r io.Reader) (patterns []string, cont bool, err error) {
|
||||||
s := bufio.NewScanner(r)
|
s := bufio.NewScanner(r)
|
||||||
for s.Scan() {
|
for s.Scan() {
|
||||||
|
@ -63,7 +26,7 @@ func TestFileWalker_Walk(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(line) != 0 {
|
if len(line) != 0 {
|
||||||
patterns = append(patterns, filepath.Join(dirName, line))
|
patterns = append(patterns, path.Join(".", line))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -74,136 +37,150 @@ func TestFileWalker_Walk(t *testing.T) {
|
||||||
const nl = "\n"
|
const nl = "\n"
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
testFS fstest.MapFS
|
||||||
testFS testFSGen
|
want assert.BoolAssertionFunc
|
||||||
initPattern string
|
initPattern string
|
||||||
want bool
|
name string
|
||||||
}{{
|
}{{
|
||||||
name: "simple",
|
name: "simple",
|
||||||
testFS: testFSGen{
|
testFS: fstest.MapFS{
|
||||||
"simple_0001.txt": []byte(attribute + nl),
|
"simple_0001.txt": &fstest.MapFile{Data: []byte(attribute + nl)},
|
||||||
},
|
},
|
||||||
initPattern: "simple_0001.txt",
|
initPattern: "simple_0001.txt",
|
||||||
want: true,
|
want: assert.True,
|
||||||
}, {
|
}, {
|
||||||
name: "chain",
|
name: "chain",
|
||||||
testFS: testFSGen{
|
testFS: fstest.MapFS{
|
||||||
"chain_0001.txt": []byte(`chain_0002.txt` + nl),
|
"chain_0001.txt": &fstest.MapFile{Data: []byte(`chain_0002.txt` + nl)},
|
||||||
"chain_0002.txt": []byte(`chain_0003.txt` + nl),
|
"chain_0002.txt": &fstest.MapFile{Data: []byte(`chain_0003.txt` + nl)},
|
||||||
"chain_0003.txt": []byte(attribute + nl),
|
"chain_0003.txt": &fstest.MapFile{Data: []byte(attribute + nl)},
|
||||||
},
|
},
|
||||||
initPattern: "chain_0001.txt",
|
initPattern: "chain_0001.txt",
|
||||||
want: true,
|
want: assert.True,
|
||||||
}, {
|
}, {
|
||||||
name: "several",
|
name: "several",
|
||||||
testFS: testFSGen{
|
testFS: fstest.MapFS{
|
||||||
"several_0001.txt": []byte(`several_*` + nl),
|
"several_0001.txt": &fstest.MapFile{Data: []byte(`several_*` + nl)},
|
||||||
"several_0002.txt": []byte(`several_0001.txt` + nl),
|
"several_0002.txt": &fstest.MapFile{Data: []byte(`several_0001.txt` + nl)},
|
||||||
"several_0003.txt": []byte(attribute + nl),
|
"several_0003.txt": &fstest.MapFile{Data: []byte(attribute + nl)},
|
||||||
},
|
},
|
||||||
initPattern: "several_0001.txt",
|
initPattern: "several_0001.txt",
|
||||||
want: true,
|
want: assert.True,
|
||||||
}, {
|
}, {
|
||||||
name: "no",
|
name: "no",
|
||||||
testFS: testFSGen{
|
testFS: fstest.MapFS{
|
||||||
"no_0001.txt": []byte(nl),
|
"no_0001.txt": &fstest.MapFile{Data: []byte(nl)},
|
||||||
"no_0002.txt": []byte(nl),
|
"no_0002.txt": &fstest.MapFile{Data: []byte(nl)},
|
||||||
"no_0003.txt": []byte(nl),
|
"no_0003.txt": &fstest.MapFile{Data: []byte(nl)},
|
||||||
},
|
},
|
||||||
initPattern: "no_*",
|
initPattern: "no_*",
|
||||||
want: false,
|
want: assert.False,
|
||||||
}, {
|
}, {
|
||||||
name: "subdirectory",
|
name: "subdirectory",
|
||||||
testFS: testFSGen{
|
testFS: fstest.MapFS{
|
||||||
"dir": testFSDir{
|
path.Join("dir", "subdir_0002.txt"): &fstest.MapFile{
|
||||||
"subdir_0002.txt": []byte(attribute + nl),
|
Data: []byte(attribute + nl),
|
||||||
},
|
},
|
||||||
"subdir_0001.txt": []byte(`dir/*`),
|
"subdir_0001.txt": &fstest.MapFile{Data: []byte(`dir/*`)},
|
||||||
},
|
},
|
||||||
initPattern: "subdir_0001.txt",
|
initPattern: "subdir_0001.txt",
|
||||||
want: true,
|
want: assert.True,
|
||||||
}}
|
}}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
testDir := tc.testFS.gen(t)
|
fw := makeFileWalker("")
|
||||||
fw := makeFileWalker(testDir)
|
|
||||||
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
ok, err := fw.Walk(filepath.Join(testDir, tc.initPattern))
|
ok, err := fw.Walk(tc.testFS, tc.initPattern)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, tc.want, ok)
|
tc.want(t, ok)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("pattern_malformed", func(t *testing.T) {
|
t.Run("pattern_malformed", func(t *testing.T) {
|
||||||
ok, err := makeFileWalker("").Walk("[]")
|
f := fstest.MapFS{}
|
||||||
|
ok, err := makeFileWalker("").Walk(f, "[]")
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
assert.False(t, ok)
|
assert.False(t, ok)
|
||||||
assert.ErrorIs(t, err, filepath.ErrBadPattern)
|
assert.ErrorIs(t, err, path.ErrBadPattern)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("bad_filename", func(t *testing.T) {
|
t.Run("bad_filename", func(t *testing.T) {
|
||||||
dir := testFSGen{
|
const filename = "bad_filename.txt"
|
||||||
"bad_filename.txt": []byte("[]"),
|
|
||||||
}.gen(t)
|
f := fstest.MapFS{
|
||||||
fw := FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) {
|
filename: &fstest.MapFile{Data: []byte("[]")},
|
||||||
|
}
|
||||||
|
ok, err := FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) {
|
||||||
s := bufio.NewScanner(r)
|
s := bufio.NewScanner(r)
|
||||||
for s.Scan() {
|
for s.Scan() {
|
||||||
patterns = append(patterns, s.Text())
|
patterns = append(patterns, s.Text())
|
||||||
}
|
}
|
||||||
|
|
||||||
return patterns, true, s.Err()
|
return patterns, true, s.Err()
|
||||||
})
|
}).Walk(f, filename)
|
||||||
|
|
||||||
ok, err := fw.Walk(filepath.Join(dir, "bad_filename.txt"))
|
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
assert.False(t, ok)
|
assert.False(t, ok)
|
||||||
assert.ErrorIs(t, err, filepath.ErrBadPattern)
|
assert.ErrorIs(t, err, path.ErrBadPattern)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("itself_error", func(t *testing.T) {
|
t.Run("itself_error", func(t *testing.T) {
|
||||||
const rerr errors.Error = "returned error"
|
const rerr errors.Error = "returned error"
|
||||||
|
|
||||||
dir := testFSGen{
|
f := fstest.MapFS{
|
||||||
"mockfile.txt": []byte(`mockdata`),
|
"mockfile.txt": &fstest.MapFile{Data: []byte(`mockdata`)},
|
||||||
}.gen(t)
|
}
|
||||||
|
|
||||||
ok, err := FileWalker(func(r io.Reader) (patterns []string, ok bool, err error) {
|
ok, err := FileWalker(func(r io.Reader) (patterns []string, ok bool, err error) {
|
||||||
return nil, true, rerr
|
return nil, true, rerr
|
||||||
}).Walk(filepath.Join(dir, "*"))
|
}).Walk(f, "*")
|
||||||
require.Error(t, err)
|
require.ErrorIs(t, err, rerr)
|
||||||
require.False(t, ok)
|
|
||||||
|
|
||||||
assert.ErrorIs(t, err, rerr)
|
assert.False(t, ok)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type errFS struct {
|
||||||
|
fs.GlobFS
|
||||||
|
}
|
||||||
|
|
||||||
|
const errErrFSOpen errors.Error = "this error is always returned"
|
||||||
|
|
||||||
|
func (efs *errFS) Open(name string) (fs.File, error) {
|
||||||
|
return nil, errErrFSOpen
|
||||||
|
}
|
||||||
|
|
||||||
func TestWalkerFunc_CheckFile(t *testing.T) {
|
func TestWalkerFunc_CheckFile(t *testing.T) {
|
||||||
|
emptyFS := fstest.MapFS{}
|
||||||
|
|
||||||
t.Run("non-existing", func(t *testing.T) {
|
t.Run("non-existing", func(t *testing.T) {
|
||||||
_, ok, err := checkFile(nil, "lol")
|
_, ok, err := checkFile(emptyFS, nil, "lol")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("invalid_argument", func(t *testing.T) {
|
t.Run("invalid_argument", func(t *testing.T) {
|
||||||
const badPath = "\x00"
|
_, ok, err := checkFile(&errFS{}, nil, "")
|
||||||
|
require.ErrorIs(t, err, errErrFSOpen)
|
||||||
_, ok, err := checkFile(nil, badPath)
|
|
||||||
require.Error(t, err)
|
|
||||||
|
|
||||||
assert.False(t, ok)
|
assert.False(t, ok)
|
||||||
// TODO(e.burkov): Use assert.ErrorsIs within the error from
|
})
|
||||||
// less platform-dependent package instead of syscall.EINVAL.
|
|
||||||
//
|
t.Run("ignore_dirs", func(t *testing.T) {
|
||||||
// See https://github.com/golang/go/issues/46849 and
|
const dirName = "dir"
|
||||||
// https://github.com/golang/go/issues/30322.
|
|
||||||
pathErr := &os.PathError{}
|
testFS := fstest.MapFS{
|
||||||
require.ErrorAs(t, err, &pathErr)
|
path.Join(dirName, "file"): &fstest.MapFile{Data: []byte{}},
|
||||||
assert.Equal(t, "open", pathErr.Op)
|
}
|
||||||
assert.Equal(t, badPath, pathErr.Path)
|
|
||||||
|
patterns, ok, err := checkFile(testFS, nil, dirName)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Empty(t, patterns)
|
||||||
|
assert.True(t, ok)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,131 @@
|
||||||
|
package aghos
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/fs"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
"github.com/fsnotify/fsnotify"
|
||||||
|
)
|
||||||
|
|
||||||
|
// event is a convenient alias for an empty struct to signal that watching
|
||||||
|
// event happened.
|
||||||
|
type event = struct{}
|
||||||
|
|
||||||
|
// FSWatcher tracks all the fyle system events and notifies about those.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov, a.garipov): Move into another package like aghfs.
|
||||||
|
type FSWatcher interface {
|
||||||
|
io.Closer
|
||||||
|
|
||||||
|
// Events should return a read-only channel which notifies about events.
|
||||||
|
Events() (e <-chan event)
|
||||||
|
|
||||||
|
// Add should check if the file named name is accessible and starts tracking
|
||||||
|
// it.
|
||||||
|
Add(name string) (err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// osWatcher tracks the file system provided by the OS.
|
||||||
|
type osWatcher struct {
|
||||||
|
// w is the actual notifier that is handled by osWatcher.
|
||||||
|
w *fsnotify.Watcher
|
||||||
|
|
||||||
|
// events is the channel to notify.
|
||||||
|
events chan event
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// osWatcherPref is a prefix for logging and wrapping errors in osWathcer's
|
||||||
|
// methods.
|
||||||
|
osWatcherPref = "os watcher"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewOSWritesWatcher creates FSWatcher that tracks the real file system of the
|
||||||
|
// OS and notifies only about writing events.
|
||||||
|
func NewOSWritesWatcher() (w FSWatcher, err error) {
|
||||||
|
defer func() { err = errors.Annotate(err, "%s: %w", osWatcherPref) }()
|
||||||
|
|
||||||
|
var watcher *fsnotify.Watcher
|
||||||
|
watcher, err = fsnotify.NewWatcher()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating watcher: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fsw := &osWatcher{
|
||||||
|
w: watcher,
|
||||||
|
events: make(chan event, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
go fsw.handleErrors()
|
||||||
|
go fsw.handleEvents()
|
||||||
|
|
||||||
|
return fsw, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleErrors handles accompanying errors. It used to be called in a separate
|
||||||
|
// goroutine.
|
||||||
|
func (w *osWatcher) handleErrors() {
|
||||||
|
defer log.OnPanic(fmt.Sprintf("%s: handling errors", osWatcherPref))
|
||||||
|
|
||||||
|
for err := range w.w.Errors {
|
||||||
|
log.Error("%s: %s", osWatcherPref, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Events implements the FSWatcher interface for *osWatcher.
|
||||||
|
func (w *osWatcher) Events() (e <-chan event) {
|
||||||
|
return w.events
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add implements the FSWatcher interface for *osWatcher.
|
||||||
|
func (w *osWatcher) Add(name string) (err error) {
|
||||||
|
defer func() { err = errors.Annotate(err, "%s: %w", osWatcherPref) }()
|
||||||
|
|
||||||
|
if _, err = fs.Stat(RootDirFS(), name); err != nil {
|
||||||
|
return fmt.Errorf("checking file %q: %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return w.w.Add(filepath.Join("/", name))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close implements the FSWatcher interface for *osWatcher.
|
||||||
|
func (w *osWatcher) Close() (err error) {
|
||||||
|
return w.w.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleEvents notifies about the received file system's event if needed. It
|
||||||
|
// used to be called in a separate goroutine.
|
||||||
|
func (w *osWatcher) handleEvents() {
|
||||||
|
defer log.OnPanic(fmt.Sprintf("%s: handling events", osWatcherPref))
|
||||||
|
|
||||||
|
defer close(w.events)
|
||||||
|
|
||||||
|
ch := w.w.Events
|
||||||
|
for e := range ch {
|
||||||
|
if e.Op&fsnotify.Write == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip the following events assuming that sometimes the same event
|
||||||
|
// occurrs several times.
|
||||||
|
for ok := true; ok; {
|
||||||
|
select {
|
||||||
|
case _, ok = <-ch:
|
||||||
|
// Go on.
|
||||||
|
default:
|
||||||
|
ok = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case w.events <- event{}:
|
||||||
|
// Go on.
|
||||||
|
default:
|
||||||
|
log.Debug("%s: events buffer is full", osWatcherPref)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -7,6 +7,8 @@ import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"io/fs"
|
||||||
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path"
|
"path"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
@ -159,3 +161,10 @@ ScanLoop:
|
||||||
func IsOpenWrt() (ok bool) {
|
func IsOpenWrt() (ok bool) {
|
||||||
return isOpenWrt()
|
return isOpenWrt()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RootDirFS returns the fs.FS rooted at the operating system's root.
|
||||||
|
func RootDirFS() (fsys fs.FS) {
|
||||||
|
// Use empty string since os.DirFS implicitly prepends a slash to it. This
|
||||||
|
// behavior is undocumented but it currently works.
|
||||||
|
return os.DirFS("")
|
||||||
|
}
|
||||||
|
|
|
@ -26,6 +26,8 @@ func haveAdminRights() (bool, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func isOpenWrt() (ok bool) {
|
func isOpenWrt() (ok bool) {
|
||||||
|
const etcReleasePattern = "etc/*release*"
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
ok, err = FileWalker(func(r io.Reader) (_ []string, cont bool, err error) {
|
ok, err = FileWalker(func(r io.Reader) (_ []string, cont bool, err error) {
|
||||||
const osNameData = "openwrt"
|
const osNameData = "openwrt"
|
||||||
|
@ -39,7 +41,7 @@ func isOpenWrt() (ok bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, !stringutil.ContainsFold(string(data), osNameData), nil
|
return nil, !stringutil.ContainsFold(string(data), osNameData), nil
|
||||||
}).Walk("/etc/*release*")
|
}).Walk(RootDirFS(), etcReleasePattern)
|
||||||
|
|
||||||
return err == nil && ok
|
return err == nil && ok
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,23 @@
|
||||||
|
package aghtest
|
||||||
|
|
||||||
|
// FSWatcher is a mock aghos.FSWatcher implementation to use in tests.
|
||||||
|
type FSWatcher struct {
|
||||||
|
OnEvents func() (e <-chan struct{})
|
||||||
|
OnAdd func(name string) (err error)
|
||||||
|
OnClose func() (err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Events implements the aghos.FSWatcher interface for *FSWatcher.
|
||||||
|
func (w *FSWatcher) Events() (e <-chan struct{}) {
|
||||||
|
return w.OnEvents()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add implements the aghos.FSWatcher interface for *FSWatcher.
|
||||||
|
func (w *FSWatcher) Add(name string) (err error) {
|
||||||
|
return w.OnAdd(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close implements the aghos.FSWatcher interface for *FSWatcher.
|
||||||
|
func (w *FSWatcher) Close() (err error) {
|
||||||
|
return w.OnClose()
|
||||||
|
}
|
|
@ -11,9 +11,10 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
"testing/fstest"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
|
@ -1057,23 +1058,43 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPTRResponseFromHosts(t *testing.T) {
|
func TestPTRResponseFromHosts(t *testing.T) {
|
||||||
c := filtering.Config{
|
// Prepare test hosts file.
|
||||||
EtcHosts: &aghnet.EtcHostsContainer{},
|
|
||||||
|
const hostsFilename = "hosts"
|
||||||
|
|
||||||
|
testFS := fstest.MapFS{
|
||||||
|
hostsFilename: &fstest.MapFile{Data: []byte(`
|
||||||
|
127.0.0.1 host # comment
|
||||||
|
::1 localhost#comment
|
||||||
|
`)},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare test hosts file.
|
const closeCalled errors.Error = "close method called"
|
||||||
hf, err := os.CreateTemp("", "")
|
|
||||||
|
var eventsCalledCounter uint32
|
||||||
|
hc, err := aghnet.NewHostsContainer(testFS, &aghtest.FSWatcher{
|
||||||
|
OnEvents: func() (e <-chan struct{}) {
|
||||||
|
assert.Equal(t, uint32(1), atomic.AddUint32(&eventsCalledCounter, 1))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
OnAdd: func(name string) (err error) {
|
||||||
|
assert.Equal(t, hostsFilename, name)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
OnClose: func() (err error) { return closeCalled },
|
||||||
|
}, hostsFilename)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
assert.NoError(t, hf.Close())
|
assert.Equal(t, uint32(1), atomic.LoadUint32(&eventsCalledCounter))
|
||||||
assert.NoError(t, os.Remove(hf.Name()))
|
|
||||||
|
require.ErrorIs(t, hc.Close(), closeCalled)
|
||||||
})
|
})
|
||||||
|
|
||||||
_, _ = hf.WriteString(" 127.0.0.1 host # comment \n")
|
c := filtering.Config{
|
||||||
_, _ = hf.WriteString(" ::1 localhost#comment \n")
|
EtcHosts: hc,
|
||||||
|
}
|
||||||
c.EtcHosts.Init(hf.Name())
|
|
||||||
t.Cleanup(c.EtcHosts.Close)
|
|
||||||
|
|
||||||
var snd *aghnet.SubnetDetector
|
var snd *aghnet.SubnetDetector
|
||||||
snd, err = aghnet.NewSubnetDetector()
|
snd, err = aghnet.NewSubnetDetector()
|
||||||
|
@ -1109,7 +1130,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
||||||
resp, err := dns.Exchange(req, addr.String())
|
resp, err := dns.Exchange(req, addr.String())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.Len(t, resp.Answer, 1)
|
require.Lenf(t, resp.Answer, 1, "%#v", resp)
|
||||||
|
|
||||||
assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
|
assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
|
||||||
assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)
|
assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)
|
||||||
|
|
|
@ -15,9 +15,9 @@ type DNSRewriteResult struct {
|
||||||
// the server returns.
|
// the server returns.
|
||||||
type DNSRewriteResultResponse map[rules.RRType][]rules.RRValue
|
type DNSRewriteResultResponse map[rules.RRType][]rules.RRValue
|
||||||
|
|
||||||
// processDNSRewrites processes DNS rewrite rules in dnsr. It returns
|
// processDNSRewrites processes DNS rewrite rules in dnsr. It returns an empty
|
||||||
// an empty result if dnsr is empty. Otherwise, the result will have
|
// result if dnsr is empty. Otherwise, the result will have either CanonName or
|
||||||
// either CanonName or DNSRewriteResult set.
|
// DNSRewriteResult set.
|
||||||
func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) {
|
func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) {
|
||||||
if len(dnsr) == 0 {
|
if len(dnsr) == 0 {
|
||||||
return Result{}
|
return Result{}
|
||||||
|
|
|
@ -73,7 +73,7 @@ type Config struct {
|
||||||
|
|
||||||
// EtcHosts is a container of IP-hostname pairs taken from the operating
|
// EtcHosts is a container of IP-hostname pairs taken from the operating
|
||||||
// system configuration files (e.g. /etc/hosts).
|
// system configuration files (e.g. /etc/hosts).
|
||||||
EtcHosts *aghnet.EtcHostsContainer `yaml:"-"`
|
EtcHosts *aghnet.HostsContainer `yaml:"-"`
|
||||||
|
|
||||||
// Called when the configuration is changed by HTTP request
|
// Called when the configuration is changed by HTTP request
|
||||||
ConfigModified func() `yaml:"-"`
|
ConfigModified func() `yaml:"-"`
|
||||||
|
@ -176,8 +176,8 @@ const (
|
||||||
// FilteredBlockedService - the host is blocked by "blocked services" settings
|
// FilteredBlockedService - the host is blocked by "blocked services" settings
|
||||||
FilteredBlockedService
|
FilteredBlockedService
|
||||||
|
|
||||||
// Rewritten is returned when there was a rewrite by a legacy DNS
|
// Rewritten is returned when there was a rewrite by a legacy DNS rewrite
|
||||||
// rewrite rule.
|
// rule.
|
||||||
Rewritten
|
Rewritten
|
||||||
|
|
||||||
// RewrittenAutoHosts is returned when there was a rewrite by autohosts
|
// RewrittenAutoHosts is returned when there was a rewrite by autohosts
|
||||||
|
@ -186,8 +186,8 @@ const (
|
||||||
|
|
||||||
// RewrittenRule is returned when a $dnsrewrite filter rule was applied.
|
// RewrittenRule is returned when a $dnsrewrite filter rule was applied.
|
||||||
//
|
//
|
||||||
// TODO(a.garipov): Remove Rewritten and RewrittenAutoHosts by merging
|
// TODO(a.garipov): Remove Rewritten and RewrittenAutoHosts by merging their
|
||||||
// their functionality into RewrittenRule.
|
// functionality into RewrittenRule.
|
||||||
//
|
//
|
||||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2499.
|
// See https://github.com/AdguardTeam/AdGuardHome/issues/2499.
|
||||||
RewrittenRule
|
RewrittenRule
|
||||||
|
@ -371,24 +371,23 @@ type Result struct {
|
||||||
// Reason is the reason for blocking or unblocking the request.
|
// Reason is the reason for blocking or unblocking the request.
|
||||||
Reason Reason `json:",omitempty"`
|
Reason Reason `json:",omitempty"`
|
||||||
|
|
||||||
// Rules are applied rules. If Rules are not empty, each rule
|
// Rules are applied rules. If Rules are not empty, each rule is not nil.
|
||||||
// is not nil.
|
|
||||||
Rules []*ResultRule `json:",omitempty"`
|
Rules []*ResultRule `json:",omitempty"`
|
||||||
|
|
||||||
// ReverseHosts is the reverse lookup rewrite result. It is
|
// ReverseHosts is the reverse lookup rewrite result. It is empty unless
|
||||||
// empty unless Reason is set to RewrittenAutoHosts.
|
// Reason is set to RewrittenAutoHosts.
|
||||||
ReverseHosts []string `json:",omitempty"`
|
ReverseHosts []string `json:",omitempty"`
|
||||||
|
|
||||||
// IPList is the lookup rewrite result. It is empty unless
|
// IPList is the lookup rewrite result. It is empty unless Reason is set to
|
||||||
// Reason is set to RewrittenAutoHosts or Rewritten.
|
// RewrittenAutoHosts or Rewritten.
|
||||||
IPList []net.IP `json:",omitempty"`
|
IPList []net.IP `json:",omitempty"`
|
||||||
|
|
||||||
// CanonName is the CNAME value from the lookup rewrite result.
|
// CanonName is the CNAME value from the lookup rewrite result. It is empty
|
||||||
// It is empty unless Reason is set to Rewritten or RewrittenRule.
|
// unless Reason is set to Rewritten or RewrittenRule.
|
||||||
CanonName string `json:",omitempty"`
|
CanonName string `json:",omitempty"`
|
||||||
|
|
||||||
// ServiceName is the name of the blocked service. It is empty
|
// ServiceName is the name of the blocked service. It is empty unless
|
||||||
// unless Reason is set to FilteredBlockedService.
|
// Reason is set to FilteredBlockedService.
|
||||||
ServiceName string `json:",omitempty"`
|
ServiceName string `json:",omitempty"`
|
||||||
|
|
||||||
// DNSRewriteResult is the $dnsrewrite filter rule result.
|
// DNSRewriteResult is the $dnsrewrite filter rule result.
|
||||||
|
@ -446,43 +445,49 @@ func (d *DNSFilter) CheckHost(
|
||||||
return Result{}, nil
|
return Result{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkEtcHosts compares the host against our /etc/hosts table. The err is
|
// matchSysHosts tries to match the host against the operating system's hosts
|
||||||
// always nil, it is only there to make this a valid hostChecker function.
|
// database.
|
||||||
func (d *DNSFilter) checkEtcHosts(
|
func (d *DNSFilter) matchSysHosts(host string, qtype uint16, setts *Settings) (res Result, err error) {
|
||||||
host string,
|
if d.EtcHosts == nil {
|
||||||
qtype uint16,
|
|
||||||
_ *Settings,
|
|
||||||
) (res Result, err error) {
|
|
||||||
if d.Config.EtcHosts == nil {
|
|
||||||
return Result{}, nil
|
return Result{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ips := d.Config.EtcHosts.Process(host, qtype)
|
dnsres, _ := d.EtcHosts.MatchRequest(urlfilter.DNSRequest{
|
||||||
if ips != nil {
|
Hostname: host,
|
||||||
res = Result{
|
SortedClientTags: setts.ClientTags,
|
||||||
Reason: RewrittenAutoHosts,
|
// TODO(e.burkov): Wait for urlfilter update to pass net.IP.
|
||||||
IPList: ips,
|
ClientIP: setts.ClientIP.String(),
|
||||||
}
|
ClientName: setts.ClientName,
|
||||||
|
DNSType: qtype,
|
||||||
|
})
|
||||||
|
|
||||||
return res, nil
|
dnsr := dnsres.DNSRewrites()
|
||||||
|
if len(dnsr) == 0 {
|
||||||
|
return Result{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
revHosts := d.Config.EtcHosts.ProcessReverse(host, qtype)
|
var ips []net.IP
|
||||||
if len(revHosts) != 0 {
|
var revHosts []string
|
||||||
res = Result{
|
|
||||||
Reason: RewrittenAutoHosts,
|
for _, nr := range dnsr {
|
||||||
|
dr := nr.DNSRewrite
|
||||||
|
if dr == nil {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(a.garipov): Optimize this with a buffer.
|
switch val := nr.DNSRewrite.Value.(type) {
|
||||||
res.ReverseHosts = make([]string, len(revHosts))
|
case net.IP:
|
||||||
for i := range revHosts {
|
ips = append(ips, val)
|
||||||
res.ReverseHosts[i] = revHosts[i] + "."
|
case string:
|
||||||
|
revHosts = append(revHosts, val)
|
||||||
}
|
}
|
||||||
|
|
||||||
return res, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return Result{}, nil
|
return Result{
|
||||||
|
Reason: RewrittenAutoHosts,
|
||||||
|
IPList: ips,
|
||||||
|
ReverseHosts: revHosts,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process rewrites table
|
// Process rewrites table
|
||||||
|
@ -647,15 +652,18 @@ func (d *DNSFilter) initFiltering(allowFilters, blockFilters []Filter) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
d.engineLock.Lock()
|
func() {
|
||||||
d.reset()
|
d.engineLock.Lock()
|
||||||
d.rulesStorage = rulesStorage
|
defer d.engineLock.Unlock()
|
||||||
d.filteringEngine = filteringEngine
|
|
||||||
d.rulesStorageAllow = rulesStorageAllow
|
|
||||||
d.filteringEngineAllow = filteringEngineAllow
|
|
||||||
d.engineLock.Unlock()
|
|
||||||
|
|
||||||
// Make sure that the OS reclaims memory as soon as possible
|
d.reset()
|
||||||
|
d.rulesStorage = rulesStorage
|
||||||
|
d.filteringEngine = filteringEngine
|
||||||
|
d.rulesStorageAllow = rulesStorageAllow
|
||||||
|
d.filteringEngineAllow = filteringEngineAllow
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Make sure that the OS reclaims memory as soon as possible.
|
||||||
debug.FreeOSMemory()
|
debug.FreeOSMemory()
|
||||||
log.Debug("initialized filtering engine")
|
log.Debug("initialized filtering engine")
|
||||||
|
|
||||||
|
@ -734,8 +742,8 @@ func (d *DNSFilter) matchHostProcessDNSResult(
|
||||||
}
|
}
|
||||||
|
|
||||||
if dnsres.HostRulesV4 != nil || dnsres.HostRulesV6 != nil {
|
if dnsres.HostRulesV4 != nil || dnsres.HostRulesV6 != nil {
|
||||||
// Question type doesn't match the host rules. Return the first
|
// Question type doesn't match the host rules. Return the first matched
|
||||||
// matched host rule, but without an IP address.
|
// host rule, but without an IP address.
|
||||||
var matchedRules []rules.Rule
|
var matchedRules []rules.Rule
|
||||||
if dnsres.HostRulesV4 != nil {
|
if dnsres.HostRulesV4 != nil {
|
||||||
matchedRules = []rules.Rule{dnsres.HostRulesV4[0]}
|
matchedRules = []rules.Rule{dnsres.HostRulesV4[0]}
|
||||||
|
@ -760,11 +768,6 @@ func (d *DNSFilter) matchHost(
|
||||||
return Result{}, nil
|
return Result{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
d.engineLock.RLock()
|
|
||||||
// Keep in mind that this lock must be held no just when calling Match()
|
|
||||||
// but also while using the rules returned by it.
|
|
||||||
defer d.engineLock.RUnlock()
|
|
||||||
|
|
||||||
ureq := urlfilter.DNSRequest{
|
ureq := urlfilter.DNSRequest{
|
||||||
Hostname: host,
|
Hostname: host,
|
||||||
SortedClientTags: setts.ClientTags,
|
SortedClientTags: setts.ClientTags,
|
||||||
|
@ -774,6 +777,13 @@ func (d *DNSFilter) matchHost(
|
||||||
DNSType: qtype,
|
DNSType: qtype,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
d.engineLock.RLock()
|
||||||
|
// Keep in mind that this lock must be held no just when calling Match() but
|
||||||
|
// also while using the rules returned by it.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): Inspect if the above is true.
|
||||||
|
defer d.engineLock.RUnlock()
|
||||||
|
|
||||||
if d.filteringEngineAllow != nil {
|
if d.filteringEngineAllow != nil {
|
||||||
dnsres, ok := d.filteringEngineAllow.MatchRequest(ureq)
|
dnsres, ok := d.filteringEngineAllow.MatchRequest(ureq)
|
||||||
if ok {
|
if ok {
|
||||||
|
@ -791,8 +801,8 @@ func (d *DNSFilter) matchHost(
|
||||||
if dnsr := dnsres.DNSRewrites(); len(dnsr) > 0 {
|
if dnsr := dnsres.DNSRewrites(); len(dnsr) > 0 {
|
||||||
res = d.processDNSRewrites(dnsr)
|
res = d.processDNSRewrites(dnsr)
|
||||||
if res.Reason == RewrittenRule && res.CanonName == host {
|
if res.Reason == RewrittenRule && res.CanonName == host {
|
||||||
// A rewrite of a host to itself. Go on and try
|
// A rewrite of a host to itself. Go on and try matching other
|
||||||
// matching other things.
|
// things.
|
||||||
} else {
|
} else {
|
||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
@ -868,8 +878,8 @@ func New(c *Config, blockFilters []Filter) *DNSFilter {
|
||||||
}
|
}
|
||||||
|
|
||||||
d.hostCheckers = []hostChecker{{
|
d.hostCheckers = []hostChecker{{
|
||||||
check: d.checkEtcHosts,
|
check: d.matchSysHosts,
|
||||||
name: "etchosts",
|
name: "hosts container",
|
||||||
}, {
|
}, {
|
||||||
check: d.matchHost,
|
check: d.matchHost,
|
||||||
name: "filtering",
|
name: "filtering",
|
||||||
|
|
|
@ -95,7 +95,9 @@ type clientsContainer struct {
|
||||||
// dnsServer is used for checking clients IP status access list status
|
// dnsServer is used for checking clients IP status access list status
|
||||||
dnsServer *dnsforward.Server
|
dnsServer *dnsforward.Server
|
||||||
|
|
||||||
etcHosts *aghnet.EtcHostsContainer // get entries from system hosts-files
|
// etcHosts contains list of rewrite rules taken from the operating system's
|
||||||
|
// hosts databse.
|
||||||
|
etcHosts *aghnet.HostsContainer
|
||||||
|
|
||||||
testing bool // if TRUE, this object is used for internal tests
|
testing bool // if TRUE, this object is used for internal tests
|
||||||
}
|
}
|
||||||
|
@ -106,7 +108,7 @@ type clientsContainer struct {
|
||||||
func (clients *clientsContainer) Init(
|
func (clients *clientsContainer) Init(
|
||||||
objects []clientObject,
|
objects []clientObject,
|
||||||
dhcpServer *dhcpd.Server,
|
dhcpServer *dhcpd.Server,
|
||||||
etcHosts *aghnet.EtcHostsContainer,
|
etcHosts *aghnet.HostsContainer,
|
||||||
) {
|
) {
|
||||||
if clients.list != nil {
|
if clients.list != nil {
|
||||||
log.Fatal("clients.list != nil")
|
log.Fatal("clients.list != nil")
|
||||||
|
@ -121,13 +123,22 @@ func (clients *clientsContainer) Init(
|
||||||
clients.etcHosts = etcHosts
|
clients.etcHosts = etcHosts
|
||||||
clients.addFromConfig(objects)
|
clients.addFromConfig(objects)
|
||||||
|
|
||||||
if !clients.testing {
|
if clients.testing {
|
||||||
clients.updateFromDHCP(true)
|
return
|
||||||
if clients.dhcpServer != nil {
|
}
|
||||||
clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged)
|
|
||||||
}
|
clients.updateFromDHCP(true)
|
||||||
if clients.etcHosts != nil {
|
if clients.dhcpServer != nil {
|
||||||
clients.etcHosts.SetOnChanged(clients.onHostsChanged)
|
clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged)
|
||||||
|
}
|
||||||
|
|
||||||
|
go clients.handleHostsUpdates()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (clients *clientsContainer) handleHostsUpdates() {
|
||||||
|
if clients.etcHosts != nil {
|
||||||
|
for upd := range clients.etcHosts.Upd() {
|
||||||
|
clients.addFromHostsFile(upd)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -250,10 +261,6 @@ func (clients *clientsContainer) onDHCPLeaseChanged(flags int) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (clients *clientsContainer) onHostsChanged() {
|
|
||||||
clients.addFromHostsFile()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Exists checks if client with this IP address already exists.
|
// Exists checks if client with this IP address already exists.
|
||||||
func (clients *clientsContainer) Exists(ip net.IP, source clientSource) (ok bool) {
|
func (clients *clientsContainer) Exists(ip net.IP, source clientSource) (ok bool) {
|
||||||
clients.lock.Lock()
|
clients.lock.Lock()
|
||||||
|
@ -697,7 +704,7 @@ func (clients *clientsContainer) SetWHOISInfo(ip net.IP, wi *RuntimeClientWHOISI
|
||||||
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
|
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddHost adds a new IP-hostname pairing. The priorities of the sources is
|
// AddHost adds a new IP-hostname pairing. The priorities of the sources are
|
||||||
// taken into account. ok is true if the pairing was added.
|
// taken into account. ok is true if the pairing was added.
|
||||||
func (clients *clientsContainer) AddHost(ip net.IP, host string, src clientSource) (ok bool, err error) {
|
func (clients *clientsContainer) AddHost(ip net.IP, host string, src clientSource) (ok bool, err error) {
|
||||||
clients.lock.Lock()
|
clients.lock.Lock()
|
||||||
|
@ -757,13 +764,7 @@ func (clients *clientsContainer) rmHostsBySrc(src clientSource) {
|
||||||
|
|
||||||
// addFromHostsFile fills the client-hostname pairing index from the system's
|
// addFromHostsFile fills the client-hostname pairing index from the system's
|
||||||
// hosts files.
|
// hosts files.
|
||||||
func (clients *clientsContainer) addFromHostsFile() {
|
func (clients *clientsContainer) addFromHostsFile(hosts *netutil.IPMap) {
|
||||||
if clients.etcHosts == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
hosts := clients.etcHosts.List()
|
|
||||||
|
|
||||||
clients.lock.Lock()
|
clients.lock.Lock()
|
||||||
defer clients.lock.Unlock()
|
defer clients.lock.Unlock()
|
||||||
|
|
||||||
|
|
|
@ -710,7 +710,6 @@ func enableFilters(async bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func enableFiltersLocked(async bool) {
|
func enableFiltersLocked(async bool) {
|
||||||
var whiteFilters []filtering.Filter
|
|
||||||
filters := []filtering.Filter{{
|
filters := []filtering.Filter{{
|
||||||
Data: []byte(strings.Join(config.UserRules, "\n")),
|
Data: []byte(strings.Join(config.UserRules, "\n")),
|
||||||
}}
|
}}
|
||||||
|
@ -725,18 +724,20 @@ func enableFiltersLocked(async bool) {
|
||||||
FilePath: filter.Path(),
|
FilePath: filter.Path(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var allowFilters []filtering.Filter
|
||||||
for _, filter := range config.WhitelistFilters {
|
for _, filter := range config.WhitelistFilters {
|
||||||
if !filter.Enabled {
|
if !filter.Enabled {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
whiteFilters = append(whiteFilters, filtering.Filter{
|
allowFilters = append(allowFilters, filtering.Filter{
|
||||||
ID: filter.ID,
|
ID: filter.ID,
|
||||||
FilePath: filter.Path(),
|
FilePath: filter.Path(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := Context.dnsFilter.SetFilters(filters, whiteFilters, async); err != nil {
|
if err := Context.dnsFilter.SetFilters(filters, allowFilters, async); err != nil {
|
||||||
log.Debug("enabling filters: %s", err)
|
log.Debug("enabling filters: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -44,20 +44,22 @@ type homeContext struct {
|
||||||
// Modules
|
// Modules
|
||||||
// --
|
// --
|
||||||
|
|
||||||
clients clientsContainer // per-client-settings module
|
clients clientsContainer // per-client-settings module
|
||||||
stats stats.Stats // statistics module
|
stats stats.Stats // statistics module
|
||||||
queryLog querylog.QueryLog // query log module
|
queryLog querylog.QueryLog // query log module
|
||||||
dnsServer *dnsforward.Server // DNS module
|
dnsServer *dnsforward.Server // DNS module
|
||||||
rdns *RDNS // rDNS module
|
rdns *RDNS // rDNS module
|
||||||
whois *WHOIS // WHOIS module
|
whois *WHOIS // WHOIS module
|
||||||
dnsFilter *filtering.DNSFilter // DNS filtering module
|
dnsFilter *filtering.DNSFilter // DNS filtering module
|
||||||
dhcpServer *dhcpd.Server // DHCP module
|
dhcpServer *dhcpd.Server // DHCP module
|
||||||
auth *Auth // HTTP authentication module
|
auth *Auth // HTTP authentication module
|
||||||
filters Filtering // DNS filtering module
|
filters Filtering // DNS filtering module
|
||||||
web *Web // Web (HTTP, HTTPS) module
|
web *Web // Web (HTTP, HTTPS) module
|
||||||
tls *TLSMod // TLS module
|
tls *TLSMod // TLS module
|
||||||
etcHosts *aghnet.EtcHostsContainer // IP-hostname pairs taken from system configuration (e.g. /etc/hosts) files
|
// etcHosts is an IP-hostname pairs set taken from system configuration
|
||||||
updater *updater.Updater
|
// (e.g. /etc/hosts) files.
|
||||||
|
etcHosts *aghnet.HostsContainer
|
||||||
|
updater *updater.Updater
|
||||||
|
|
||||||
subnetDetector *aghnet.SubnetDetector
|
subnetDetector *aghnet.SubnetDetector
|
||||||
|
|
||||||
|
@ -257,8 +259,20 @@ func setupConfig(args options) (err error) {
|
||||||
})
|
})
|
||||||
|
|
||||||
if !args.noEtcHosts {
|
if !args.noEtcHosts {
|
||||||
Context.etcHosts = &aghnet.EtcHostsContainer{}
|
var osWritesWatcher aghos.FSWatcher
|
||||||
Context.etcHosts.Init("")
|
osWritesWatcher, err = aghos.NewOSWritesWatcher()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("initing os watcher: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
Context.etcHosts, err = aghnet.NewHostsContainer(
|
||||||
|
aghos.RootDirFS(),
|
||||||
|
osWritesWatcher,
|
||||||
|
aghnet.DefaultHostsPaths()...,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("initing hosts container: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Context.clients.Init(config.Clients, Context.dhcpServer, Context.etcHosts)
|
Context.clients.Init(config.Clients, Context.dhcpServer, Context.etcHosts)
|
||||||
config.Clients = nil
|
config.Clients = nil
|
||||||
|
@ -424,7 +438,6 @@ func run(args options, clientBuildFS fs.FS) {
|
||||||
fatalOnError(err)
|
fatalOnError(err)
|
||||||
|
|
||||||
Context.tls.Start()
|
Context.tls.Start()
|
||||||
Context.etcHosts.Start()
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
serr := startDNSServer()
|
serr := startDNSServer()
|
||||||
|
@ -647,7 +660,11 @@ func cleanup(ctx context.Context) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Context.etcHosts.Close()
|
if Context.etcHosts != nil {
|
||||||
|
if err = Context.etcHosts.Close(); err != nil {
|
||||||
|
log.Error("stopping hosts container: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if Context.tls != nil {
|
if Context.tls != nil {
|
||||||
Context.tls.Close()
|
Context.tls.Close()
|
||||||
|
|
Loading…
Reference in New Issue