diff --git a/internal/aghnet/arpdb.go b/internal/aghnet/arpdb.go index 759a688a..afb880c3 100644 --- a/internal/aghnet/arpdb.go +++ b/internal/aghnet/arpdb.go @@ -27,15 +27,8 @@ type ARPDB interface { } // NewARPDB returns the ARPDB properly initialized for the OS. -func NewARPDB() (arp ARPDB, err error) { - arp = newARPDB() - - err = arp.Refresh() - if err != nil { - return nil, fmt.Errorf("arpdb initial refresh: %w", err) - } - - return arp, nil +func NewARPDB() (arp ARPDB) { + return newARPDB() } // Empty ARPDB implementation diff --git a/internal/aghnet/arpdb_bsd_test.go b/internal/aghnet/arpdb_bsd_test.go index bbadc600..3404af69 100644 --- a/internal/aghnet/arpdb_bsd_test.go +++ b/internal/aghnet/arpdb_bsd_test.go @@ -8,8 +8,12 @@ import ( ) const arpAOutput = ` +invalid.mac (1.2.3.4) at 12:34:56:78:910 on el0 ifscope [ethernet] +invalid.ip (1.2.3.4.5) at ab:cd:ef:ab:cd:12 on ek0 ifscope [ethernet] +invalid.fmt 1 at 12:cd:ef:ab:cd:ef on er0 ifscope [ethernet] hostname.one (192.168.1.2) at ab:cd:ef:ab:cd:ef on en0 ifscope [ethernet] hostname.two (::ffff:ffff) at ef:cd:ab:ef:cd:ab on em0 expires in 1198 seconds [ethernet] +? (::1234) at aa:bb:cc:dd:ee:ff on ej0 expires in 1918 seconds [ethernet] ` var wantNeighs = []Neighbor{{ @@ -20,4 +24,8 @@ var wantNeighs = []Neighbor{{ Name: "hostname.two", IP: net.ParseIP("::ffff:ffff"), MAC: net.HardwareAddr{0xEF, 0xCD, 0xAB, 0xEF, 0xCD, 0xAB}, +}, { + Name: "", + IP: net.ParseIP("::1234"), + MAC: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF}, }} diff --git a/internal/aghnet/arpdb_linux_test.go b/internal/aghnet/arpdb_linux_test.go index 0439c9b5..b9ed859f 100644 --- a/internal/aghnet/arpdb_linux_test.go +++ b/internal/aghnet/arpdb_linux_test.go @@ -16,14 +16,21 @@ import ( const arpAOutputWrt = ` IP address HW type Flags HW address Mask Device +1.2.3.4.5 0x1 0x2 aa:bb:cc:dd:ee:ff * wan +1.2.3.4 0x1 0x2 12:34:56:78:910 * wan 192.168.1.2 0x1 0x2 ab:cd:ef:ab:cd:ef * wan ::ffff:ffff 0x1 0x2 ef:cd:ab:ef:cd:ab * wan` const arpAOutput = ` +invalid.mac (1.2.3.4) at 12:34:56:78:910 on el0 ifscope [ethernet] +invalid.ip (1.2.3.4.5) at ab:cd:ef:ab:cd:12 on ek0 ifscope [ethernet] +invalid.fmt 1 at 12:cd:ef:ab:cd:ef on er0 ifscope [ethernet] ? (192.168.1.2) at ab:cd:ef:ab:cd:ef on en0 ifscope [ethernet] ? (::ffff:ffff) at ef:cd:ab:ef:cd:ab on em0 expires in 100 seconds [ethernet]` const ipNeighOutput = ` +1.2.3.4.5 dev enp0s3 lladdr aa:bb:cc:dd:ee:ff DELAY +1.2.3.4 dev enp0s3 lladdr 12:34:56:78:910 DELAY 192.168.1.2 dev enp0s3 lladdr ab:cd:ef:ab:cd:ef DELAY ::ffff:ffff dev enp0s3 lladdr ef:cd:ab:ef:cd:ab router STALE` diff --git a/internal/aghnet/arpdb_openbsd_test.go b/internal/aghnet/arpdb_openbsd_test.go index b1021513..915c17ff 100644 --- a/internal/aghnet/arpdb_openbsd_test.go +++ b/internal/aghnet/arpdb_openbsd_test.go @@ -9,6 +9,8 @@ import ( const arpAOutput = ` Host Ethernet Address Netif Expire Flags +1.2.3.4.5 aa:bb:cc:dd:ee:ff em0 permanent +1.2.3.4 12:34:56:78:910 em0 permanent 192.168.1.2 ab:cd:ef:ab:cd:ef em0 19m56s ::ffff:ffff ef:cd:ab:ef:cd:ab em0 permanent l ` diff --git a/internal/aghnet/arpdb_test.go b/internal/aghnet/arpdb_test.go index ce3da4fb..aa5c821a 100644 --- a/internal/aghnet/arpdb_test.go +++ b/internal/aghnet/arpdb_test.go @@ -6,6 +6,7 @@ import ( "strings" "sync" "testing" + "testing/iotest" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/testutil" @@ -13,6 +14,15 @@ import ( "github.com/stretchr/testify/require" ) +func TestNewARPDB(t *testing.T) { + var a ARPDB + require.NotPanics(t, func() { + a = NewARPDB() + }) + + assert.NotNil(t, a) +} + // TestARPDB is the mock implementation of ARPDB to use in tests. type TestARPDB struct { OnRefresh func() (err error) @@ -166,3 +176,46 @@ func TestCmdARPDB_arpa(t *testing.T) { testutil.AssertErrorMsg(t, "cmd arpdb: running command: can't run", err) }) } + +func TestCmdARPDB_errors(t *testing.T) { + const errRead errors.Error = "can't read" + + badReaderRunCmd := runCmdFunc(func() (r io.Reader, err error) { + return iotest.ErrReader(errRead), nil + }) + + a := &cmdARPDB{ + runcmd: badReaderRunCmd, + parse: parseArpA, + ns: &neighs{ + mu: &sync.RWMutex{}, + ns: make([]Neighbor, 0), + }, + } + + const wantErrMsg string = "cmd arpdb: scanning the output: " + string(errRead) + + testutil.AssertErrorMsg(t, wantErrMsg, a.Refresh()) +} + +func TestEmptyARPDB(t *testing.T) { + a := EmptyARPDB{} + + t.Run("refresh", func(t *testing.T) { + var err error + require.NotPanics(t, func() { + err = a.Refresh() + }) + + assert.NoError(t, err) + }) + + t.Run("neighbors", func(t *testing.T) { + var ns []Neighbor + require.NotPanics(t, func() { + ns = a.Neighbors() + }) + + assert.Empty(t, ns) + }) +} diff --git a/internal/aghnet/hostscontainer.go b/internal/aghnet/hostscontainer.go index 198d4c78..290dc1c0 100644 --- a/internal/aghnet/hostscontainer.go +++ b/internal/aghnet/hostscontainer.go @@ -368,8 +368,8 @@ func (hp *hostsParser) addPairs(ip net.IP, hosts []string) { } } -// writeRules writes the actual rule for the qtype and the PTR for the -// host-ip pair into internal builders. +// writeRules writes the actual rule for the qtype and the PTR for the host-ip +// pair into internal builders. func (hp *hostsParser) writeRules(host string, ip net.IP) (rule, rulePtr string) { arpa, err := netutil.IPToReversedAddr(ip) if err != nil { diff --git a/internal/aghnet/net.go b/internal/aghnet/net.go index ecb70fa8..2d791c03 100644 --- a/internal/aghnet/net.go +++ b/internal/aghnet/net.go @@ -105,8 +105,7 @@ func GetValidNetInterfacesForWeb() ([]*NetInterface, error) { ifaces, err := net.Interfaces() if err != nil { return nil, fmt.Errorf("couldn't get interfaces: %w", err) - } - if len(ifaces) == 0 { + } else if len(ifaces) == 0 { return nil, errors.Error("couldn't find any legible interface") } diff --git a/internal/aghnet/net_test.go b/internal/aghnet/net_test.go index 34e99faa..b461cb44 100644 --- a/internal/aghnet/net_test.go +++ b/internal/aghnet/net_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" @@ -130,3 +131,27 @@ func TestCheckPort(t *testing.T) { assert.NoError(t, err) }) } + +func TestCollectAllIfacesAddrs(t *testing.T) { + addrs, err := CollectAllIfacesAddrs() + require.NoError(t, err) + + assert.NotEmpty(t, addrs) +} + +func TestIsAddrInUse(t *testing.T) { + t.Run("addr_in_use", func(t *testing.T) { + l, err := net.Listen("tcp", "0.0.0.0:0") + require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, l.Close) + + _, err = net.Listen(l.Addr().Network(), l.Addr().String()) + assert.True(t, IsAddrInUse(err)) + }) + + t.Run("another", func(t *testing.T) { + const anotherErr errors.Error = "not addr in use" + + assert.False(t, IsAddrInUse(anotherErr)) + }) +} diff --git a/internal/aghnet/systemresolvers.go b/internal/aghnet/systemresolvers.go index 777127a3..13fbeb32 100644 --- a/internal/aghnet/systemresolvers.go +++ b/internal/aghnet/systemresolvers.go @@ -1,11 +1,5 @@ package aghnet -import ( - "time" - - "github.com/AdguardTeam/golibs/log" -) - // DefaultRefreshIvl is the default period of time between refreshing cached // addresses. // const DefaultRefreshIvl = 5 * time.Minute @@ -16,39 +10,21 @@ type HostGenFunc func() (host string) // SystemResolvers helps to work with local resolvers' addresses provided by OS. type SystemResolvers interface { - // Get returns the slice of local resolvers' addresses. It should be - // safe for concurrent use. + // Get returns the slice of local resolvers' addresses. It must be safe for + // concurrent use. Get() (rs []string) - // refresh refreshes the local resolvers' addresses cache. It should be - // safe for concurrent use. + // refresh refreshes the local resolvers' addresses cache. It must be safe + // for concurrent use. refresh() (err error) } -// refreshWithTicker refreshes the cache of sr after each tick form tickCh. -func refreshWithTicker(sr SystemResolvers, tickCh <-chan time.Time) { - defer log.OnPanic("systemResolvers") - - // TODO(e.burkov): Implement a functionality to stop ticker. - for range tickCh { - err := sr.refresh() - if err != nil { - log.Error("systemResolvers: error in refreshing goroutine: %s", err) - - continue - } - - log.Debug("systemResolvers: local addresses cache is refreshed") - } -} - // NewSystemResolvers returns a SystemResolvers with the cache refresh rate // defined by refreshIvl. It disables auto-resfreshing if refreshIvl is 0. If // nil is passed for hostGenFunc, the default generator will be used. func NewSystemResolvers( - refreshIvl time.Duration, hostGenFunc HostGenFunc, ) (sr SystemResolvers, err error) { - sr = newSystemResolvers(refreshIvl, hostGenFunc) + sr = newSystemResolvers(hostGenFunc) // Fill cache. err = sr.refresh() @@ -56,11 +32,5 @@ func NewSystemResolvers( return nil, err } - if refreshIvl > 0 { - ticker := time.NewTicker(refreshIvl) - - go refreshWithTicker(sr, ticker.C) - } - return sr, nil } diff --git a/internal/aghnet/systemresolvers_others.go b/internal/aghnet/systemresolvers_others.go index 8acdb6c7..f8afa286 100644 --- a/internal/aghnet/systemresolvers_others.go +++ b/internal/aghnet/systemresolvers_others.go @@ -24,12 +24,15 @@ func defaultHostGen() (host string) { // systemResolvers is a default implementation of SystemResolvers interface. type systemResolvers struct { - resolver *net.Resolver - hostGenFunc HostGenFunc - - // addrs is the set that contains cached local resolvers' addresses. - addrs *stringutil.Set + // addrsLock protects addrs. addrsLock sync.RWMutex + // addrs is the set that contains cached local resolvers' addresses. + addrs *stringutil.Set + + // resolver is used to fetch the resolvers' addresses. + resolver *net.Resolver + // hostGenFunc generates hosts to resolve. + hostGenFunc HostGenFunc } const ( @@ -44,6 +47,7 @@ const ( errUnexpectedHostFormat errors.Error = "unexpected host format" ) +// refresh implements the SystemResolvers interface for *systemResolvers. func (sr *systemResolvers) refresh() (err error) { defer func() { err = errors.Annotate(err, "systemResolvers: %w") }() @@ -56,7 +60,7 @@ func (sr *systemResolvers) refresh() (err error) { return err } -func newSystemResolvers(refreshIvl time.Duration, hostGenFunc HostGenFunc) (sr SystemResolvers) { +func newSystemResolvers(hostGenFunc HostGenFunc) (sr SystemResolvers) { if hostGenFunc == nil { hostGenFunc = defaultHostGen } @@ -76,19 +80,18 @@ func newSystemResolvers(refreshIvl time.Duration, hostGenFunc HostGenFunc) (sr S func validateDialedHost(host string) (err error) { defer func() { err = errors.Annotate(err, "parsing %q: %w", host) }() - var ipStr string parts := strings.Split(host, "%") switch len(parts) { case 1: - ipStr = host + // host case 2: // Remove the zone and check the IP address part. - ipStr = parts[0] + host = parts[0] default: return errUnexpectedHostFormat } - if net.ParseIP(ipStr) == nil { + if _, err = netutil.ParseIP(host); err != nil { return errBadAddrPassed } diff --git a/internal/aghnet/systemresolvers_others_test.go b/internal/aghnet/systemresolvers_others_test.go index 79abeca2..f7cf9ef0 100644 --- a/internal/aghnet/systemresolvers_others_test.go +++ b/internal/aghnet/systemresolvers_others_test.go @@ -6,37 +6,32 @@ package aghnet import ( "context" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func createTestSystemResolversImp( +func createTestSystemResolversImpl( t *testing.T, - refreshDur time.Duration, hostGenFunc HostGenFunc, ) (imp *systemResolvers) { t.Helper() - sr := createTestSystemResolvers(t, refreshDur, hostGenFunc) + sr := createTestSystemResolvers(t, hostGenFunc) + require.IsType(t, (*systemResolvers)(nil), sr) - var ok bool - imp, ok = sr.(*systemResolvers) - require.True(t, ok) - - return imp + return sr.(*systemResolvers) } func TestSystemResolvers_Refresh(t *testing.T) { t.Run("expected_error", func(t *testing.T) { - sr := createTestSystemResolvers(t, 0, nil) + sr := createTestSystemResolvers(t, nil) assert.NoError(t, sr.refresh()) }) t.Run("unexpected_error", func(t *testing.T) { - _, err := NewSystemResolvers(0, func() string { + _, err := NewSystemResolvers(func() string { return "127.0.0.1::123" }) assert.Error(t, err) @@ -44,7 +39,7 @@ func TestSystemResolvers_Refresh(t *testing.T) { } func TestSystemResolvers_DialFunc(t *testing.T) { - imp := createTestSystemResolversImp(t, 0, nil) + imp := createTestSystemResolversImpl(t, nil) testCases := []struct { want error @@ -52,7 +47,7 @@ func TestSystemResolvers_DialFunc(t *testing.T) { address string }{{ want: errFakeDial, - name: "valid", + name: "valid_ipv4", address: "127.0.0.1", }, { want: errFakeDial, diff --git a/internal/aghnet/systemresolvers_test.go b/internal/aghnet/systemresolvers_test.go index 13145817..0a19490d 100644 --- a/internal/aghnet/systemresolvers_test.go +++ b/internal/aghnet/systemresolvers_test.go @@ -2,7 +2,6 @@ package aghnet import ( "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -10,13 +9,12 @@ import ( func createTestSystemResolvers( t *testing.T, - refreshDur time.Duration, hostGenFunc HostGenFunc, ) (sr SystemResolvers) { t.Helper() var err error - sr, err = NewSystemResolvers(refreshDur, hostGenFunc) + sr, err = NewSystemResolvers(hostGenFunc) require.NoError(t, err) require.NotNil(t, sr) @@ -24,8 +22,14 @@ func createTestSystemResolvers( } func TestSystemResolvers_Get(t *testing.T) { - sr := createTestSystemResolvers(t, 0, nil) - assert.NotEmpty(t, sr.Get()) + sr := createTestSystemResolvers(t, nil) + + var rs []string + require.NotPanics(t, func() { + rs = sr.Get() + }) + + assert.NotEmpty(t, rs) } // TODO(e.burkov): Write tests for refreshWithTicker. diff --git a/internal/aghnet/systemresolvers_windows.go b/internal/aghnet/systemresolvers_windows.go index 5acdfa85..f82d6e7e 100644 --- a/internal/aghnet/systemresolvers_windows.go +++ b/internal/aghnet/systemresolvers_windows.go @@ -11,7 +11,6 @@ import ( "os/exec" "strings" "sync" - "time" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" @@ -27,7 +26,7 @@ type systemResolvers struct { addrsLock sync.RWMutex } -func newSystemResolvers(refreshIvl time.Duration, _ HostGenFunc) (sr SystemResolvers) { +func newSystemResolvers(_ HostGenFunc) (sr SystemResolvers) { return &systemResolvers{} } diff --git a/internal/aghnet/testdata/proc_net_arp b/internal/aghnet/testdata/proc_net_arp index 07d214e1..8460c8bb 100644 --- a/internal/aghnet/testdata/proc_net_arp +++ b/internal/aghnet/testdata/proc_net_arp @@ -1,4 +1,6 @@ IP address HW type Flags HW address Mask Device 192.168.1.2 0x1 0x2 ab:cd:ef:ab:cd:ef * wan ::ffff:ffff 0x1 0x0 ef:cd:ab:ef:cd:ab * br-lan -0.0.0.0 0x0 0x0 00:00:00:00:00:00 * unspec \ No newline at end of file +0.0.0.0 0x0 0x0 00:00:00:00:00:00 * unspec +1.2.3.4.5 0x1 0x2 aa:bb:cc:dd:ee:ff * wan +1.2.3.4 0x1 0x2 12:34:56:78:910 * wan \ No newline at end of file diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index a5b0098a..c0cd0e55 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -173,7 +173,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) { // TODO(e.burkov): Enable the refresher after the actual implementation // passes the public testing. - s.sysResolvers, err = aghnet.NewSystemResolvers(0, nil) + s.sysResolvers, err = aghnet.NewSystemResolvers(nil) if err != nil { return nil, fmt.Errorf("initializing system resolvers: %w", err) } diff --git a/internal/home/clients.go b/internal/home/clients.go index 9230e565..fe15e514 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -257,6 +257,8 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) { } func (clients *clientsContainer) periodicUpdate() { + defer log.OnPanic("clients container") + for { clients.Reload() time.Sleep(clientsUpdatePeriod) diff --git a/internal/home/home.go b/internal/home/home.go index 4e4d3aee..5fae42c8 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -291,14 +291,7 @@ func setupConfig(args options) (err error) { } } - var arpdb aghnet.ARPDB - arpdb, err = aghnet.NewARPDB() - if err != nil { - log.Info("warning: creating arpdb: %s; using stub", err) - - arpdb = aghnet.EmptyARPDB{} - } - + arpdb := aghnet.NewARPDB() Context.clients.Init(config.Clients, Context.dhcpServer, Context.etcHosts, arpdb) if args.bindPort != 0 {