Pull request: 2846 cover aghnet vol.2

Merge in DNS/adguard-home from 2846-cover-aghnet-vol.2 to master

Updates #2846.
Closes #4408.

Squashed commit of the following:

commit 8d62b29d5b5be875cb71e518e479e321d853eb1a
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Mar 23 20:42:04 2022 +0300

    home: recover panic

commit 1d98109e910830bec712c7aecbbbcb8f659d823d
Merge: ac11d751 9ce2a0fb
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Mar 23 20:32:05 2022 +0300

    Merge branch 'master' into 2846-cover-aghnet-vol.2

commit ac11d751fb7951e3dd0940bf425a893223c32789
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Mar 23 20:29:41 2022 +0300

    aghnet: use iotest

commit 7c923df7bafd5d4b91c4b4a01e75ab161944f949
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Mar 23 20:17:19 2022 +0300

    aghnet: cover more

commit 3bfd4d587e4b887b5527d60c0eb6027da15c7e37
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Mar 23 14:13:59 2022 +0300

    aghnet: cover arpdb more

commit cd5cf7bbdecceeab6d3abee10a5572e1e907cc67
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Mar 23 13:05:35 2022 +0300

    all: rm arpdb initial refresh

commit 0fb8d9e44a4d130ca4e8fc2ea5d595ec08555302
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Mar 22 21:13:16 2022 +0300

    aghnet: cover arpdb
This commit is contained in:
Eugene Burkov 2022-03-23 20:47:45 +03:00
parent 9ce2a0fb34
commit 82505566f8
17 changed files with 143 additions and 88 deletions

View File

@ -27,15 +27,8 @@ type ARPDB interface {
} }
// NewARPDB returns the ARPDB properly initialized for the OS. // NewARPDB returns the ARPDB properly initialized for the OS.
func NewARPDB() (arp ARPDB, err error) { func NewARPDB() (arp ARPDB) {
arp = newARPDB() return newARPDB()
err = arp.Refresh()
if err != nil {
return nil, fmt.Errorf("arpdb initial refresh: %w", err)
}
return arp, nil
} }
// Empty ARPDB implementation // Empty ARPDB implementation

View File

@ -8,8 +8,12 @@ import (
) )
const arpAOutput = ` const arpAOutput = `
invalid.mac (1.2.3.4) at 12:34:56:78:910 on el0 ifscope [ethernet]
invalid.ip (1.2.3.4.5) at ab:cd:ef:ab:cd:12 on ek0 ifscope [ethernet]
invalid.fmt 1 at 12:cd:ef:ab:cd:ef on er0 ifscope [ethernet]
hostname.one (192.168.1.2) at ab:cd:ef:ab:cd:ef on en0 ifscope [ethernet] hostname.one (192.168.1.2) at ab:cd:ef:ab:cd:ef on en0 ifscope [ethernet]
hostname.two (::ffff:ffff) at ef:cd:ab:ef:cd:ab on em0 expires in 1198 seconds [ethernet] hostname.two (::ffff:ffff) at ef:cd:ab:ef:cd:ab on em0 expires in 1198 seconds [ethernet]
? (::1234) at aa:bb:cc:dd:ee:ff on ej0 expires in 1918 seconds [ethernet]
` `
var wantNeighs = []Neighbor{{ var wantNeighs = []Neighbor{{
@ -20,4 +24,8 @@ var wantNeighs = []Neighbor{{
Name: "hostname.two", Name: "hostname.two",
IP: net.ParseIP("::ffff:ffff"), IP: net.ParseIP("::ffff:ffff"),
MAC: net.HardwareAddr{0xEF, 0xCD, 0xAB, 0xEF, 0xCD, 0xAB}, MAC: net.HardwareAddr{0xEF, 0xCD, 0xAB, 0xEF, 0xCD, 0xAB},
}, {
Name: "",
IP: net.ParseIP("::1234"),
MAC: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF},
}} }}

View File

@ -16,14 +16,21 @@ import (
const arpAOutputWrt = ` const arpAOutputWrt = `
IP address HW type Flags HW address Mask Device IP address HW type Flags HW address Mask Device
1.2.3.4.5 0x1 0x2 aa:bb:cc:dd:ee:ff * wan
1.2.3.4 0x1 0x2 12:34:56:78:910 * wan
192.168.1.2 0x1 0x2 ab:cd:ef:ab:cd:ef * wan 192.168.1.2 0x1 0x2 ab:cd:ef:ab:cd:ef * wan
::ffff:ffff 0x1 0x2 ef:cd:ab:ef:cd:ab * wan` ::ffff:ffff 0x1 0x2 ef:cd:ab:ef:cd:ab * wan`
const arpAOutput = ` const arpAOutput = `
invalid.mac (1.2.3.4) at 12:34:56:78:910 on el0 ifscope [ethernet]
invalid.ip (1.2.3.4.5) at ab:cd:ef:ab:cd:12 on ek0 ifscope [ethernet]
invalid.fmt 1 at 12:cd:ef:ab:cd:ef on er0 ifscope [ethernet]
? (192.168.1.2) at ab:cd:ef:ab:cd:ef on en0 ifscope [ethernet] ? (192.168.1.2) at ab:cd:ef:ab:cd:ef on en0 ifscope [ethernet]
? (::ffff:ffff) at ef:cd:ab:ef:cd:ab on em0 expires in 100 seconds [ethernet]` ? (::ffff:ffff) at ef:cd:ab:ef:cd:ab on em0 expires in 100 seconds [ethernet]`
const ipNeighOutput = ` const ipNeighOutput = `
1.2.3.4.5 dev enp0s3 lladdr aa:bb:cc:dd:ee:ff DELAY
1.2.3.4 dev enp0s3 lladdr 12:34:56:78:910 DELAY
192.168.1.2 dev enp0s3 lladdr ab:cd:ef:ab:cd:ef DELAY 192.168.1.2 dev enp0s3 lladdr ab:cd:ef:ab:cd:ef DELAY
::ffff:ffff dev enp0s3 lladdr ef:cd:ab:ef:cd:ab router STALE` ::ffff:ffff dev enp0s3 lladdr ef:cd:ab:ef:cd:ab router STALE`

View File

@ -9,6 +9,8 @@ import (
const arpAOutput = ` const arpAOutput = `
Host Ethernet Address Netif Expire Flags Host Ethernet Address Netif Expire Flags
1.2.3.4.5 aa:bb:cc:dd:ee:ff em0 permanent
1.2.3.4 12:34:56:78:910 em0 permanent
192.168.1.2 ab:cd:ef:ab:cd:ef em0 19m56s 192.168.1.2 ab:cd:ef:ab:cd:ef em0 19m56s
::ffff:ffff ef:cd:ab:ef:cd:ab em0 permanent l ::ffff:ffff ef:cd:ab:ef:cd:ab em0 permanent l
` `

View File

@ -6,6 +6,7 @@ import (
"strings" "strings"
"sync" "sync"
"testing" "testing"
"testing/iotest"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
@ -13,6 +14,15 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestNewARPDB(t *testing.T) {
var a ARPDB
require.NotPanics(t, func() {
a = NewARPDB()
})
assert.NotNil(t, a)
}
// TestARPDB is the mock implementation of ARPDB to use in tests. // TestARPDB is the mock implementation of ARPDB to use in tests.
type TestARPDB struct { type TestARPDB struct {
OnRefresh func() (err error) OnRefresh func() (err error)
@ -166,3 +176,46 @@ func TestCmdARPDB_arpa(t *testing.T) {
testutil.AssertErrorMsg(t, "cmd arpdb: running command: can't run", err) testutil.AssertErrorMsg(t, "cmd arpdb: running command: can't run", err)
}) })
} }
func TestCmdARPDB_errors(t *testing.T) {
const errRead errors.Error = "can't read"
badReaderRunCmd := runCmdFunc(func() (r io.Reader, err error) {
return iotest.ErrReader(errRead), nil
})
a := &cmdARPDB{
runcmd: badReaderRunCmd,
parse: parseArpA,
ns: &neighs{
mu: &sync.RWMutex{},
ns: make([]Neighbor, 0),
},
}
const wantErrMsg string = "cmd arpdb: scanning the output: " + string(errRead)
testutil.AssertErrorMsg(t, wantErrMsg, a.Refresh())
}
func TestEmptyARPDB(t *testing.T) {
a := EmptyARPDB{}
t.Run("refresh", func(t *testing.T) {
var err error
require.NotPanics(t, func() {
err = a.Refresh()
})
assert.NoError(t, err)
})
t.Run("neighbors", func(t *testing.T) {
var ns []Neighbor
require.NotPanics(t, func() {
ns = a.Neighbors()
})
assert.Empty(t, ns)
})
}

View File

@ -368,8 +368,8 @@ func (hp *hostsParser) addPairs(ip net.IP, hosts []string) {
} }
} }
// writeRules writes the actual rule for the qtype and the PTR for the // writeRules writes the actual rule for the qtype and the PTR for the host-ip
// host-ip pair into internal builders. // pair into internal builders.
func (hp *hostsParser) writeRules(host string, ip net.IP) (rule, rulePtr string) { func (hp *hostsParser) writeRules(host string, ip net.IP) (rule, rulePtr string) {
arpa, err := netutil.IPToReversedAddr(ip) arpa, err := netutil.IPToReversedAddr(ip)
if err != nil { if err != nil {

View File

@ -105,8 +105,7 @@ func GetValidNetInterfacesForWeb() ([]*NetInterface, error) {
ifaces, err := net.Interfaces() ifaces, err := net.Interfaces()
if err != nil { if err != nil {
return nil, fmt.Errorf("couldn't get interfaces: %w", err) return nil, fmt.Errorf("couldn't get interfaces: %w", err)
} } else if len(ifaces) == 0 {
if len(ifaces) == 0 {
return nil, errors.Error("couldn't find any legible interface") return nil, errors.Error("couldn't find any legible interface")
} }

View File

@ -7,6 +7,7 @@ import (
"testing" "testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -130,3 +131,27 @@ func TestCheckPort(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
}) })
} }
func TestCollectAllIfacesAddrs(t *testing.T) {
addrs, err := CollectAllIfacesAddrs()
require.NoError(t, err)
assert.NotEmpty(t, addrs)
}
func TestIsAddrInUse(t *testing.T) {
t.Run("addr_in_use", func(t *testing.T) {
l, err := net.Listen("tcp", "0.0.0.0:0")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, l.Close)
_, err = net.Listen(l.Addr().Network(), l.Addr().String())
assert.True(t, IsAddrInUse(err))
})
t.Run("another", func(t *testing.T) {
const anotherErr errors.Error = "not addr in use"
assert.False(t, IsAddrInUse(anotherErr))
})
}

View File

@ -1,11 +1,5 @@
package aghnet package aghnet
import (
"time"
"github.com/AdguardTeam/golibs/log"
)
// DefaultRefreshIvl is the default period of time between refreshing cached // DefaultRefreshIvl is the default period of time between refreshing cached
// addresses. // addresses.
// const DefaultRefreshIvl = 5 * time.Minute // const DefaultRefreshIvl = 5 * time.Minute
@ -16,39 +10,21 @@ type HostGenFunc func() (host string)
// SystemResolvers helps to work with local resolvers' addresses provided by OS. // SystemResolvers helps to work with local resolvers' addresses provided by OS.
type SystemResolvers interface { type SystemResolvers interface {
// Get returns the slice of local resolvers' addresses. It should be // Get returns the slice of local resolvers' addresses. It must be safe for
// safe for concurrent use. // concurrent use.
Get() (rs []string) Get() (rs []string)
// refresh refreshes the local resolvers' addresses cache. It should be // refresh refreshes the local resolvers' addresses cache. It must be safe
// safe for concurrent use. // for concurrent use.
refresh() (err error) refresh() (err error)
} }
// refreshWithTicker refreshes the cache of sr after each tick form tickCh.
func refreshWithTicker(sr SystemResolvers, tickCh <-chan time.Time) {
defer log.OnPanic("systemResolvers")
// TODO(e.burkov): Implement a functionality to stop ticker.
for range tickCh {
err := sr.refresh()
if err != nil {
log.Error("systemResolvers: error in refreshing goroutine: %s", err)
continue
}
log.Debug("systemResolvers: local addresses cache is refreshed")
}
}
// NewSystemResolvers returns a SystemResolvers with the cache refresh rate // NewSystemResolvers returns a SystemResolvers with the cache refresh rate
// defined by refreshIvl. It disables auto-resfreshing if refreshIvl is 0. If // defined by refreshIvl. It disables auto-resfreshing if refreshIvl is 0. If
// nil is passed for hostGenFunc, the default generator will be used. // nil is passed for hostGenFunc, the default generator will be used.
func NewSystemResolvers( func NewSystemResolvers(
refreshIvl time.Duration,
hostGenFunc HostGenFunc, hostGenFunc HostGenFunc,
) (sr SystemResolvers, err error) { ) (sr SystemResolvers, err error) {
sr = newSystemResolvers(refreshIvl, hostGenFunc) sr = newSystemResolvers(hostGenFunc)
// Fill cache. // Fill cache.
err = sr.refresh() err = sr.refresh()
@ -56,11 +32,5 @@ func NewSystemResolvers(
return nil, err return nil, err
} }
if refreshIvl > 0 {
ticker := time.NewTicker(refreshIvl)
go refreshWithTicker(sr, ticker.C)
}
return sr, nil return sr, nil
} }

View File

@ -24,12 +24,15 @@ func defaultHostGen() (host string) {
// systemResolvers is a default implementation of SystemResolvers interface. // systemResolvers is a default implementation of SystemResolvers interface.
type systemResolvers struct { type systemResolvers struct {
resolver *net.Resolver // addrsLock protects addrs.
hostGenFunc HostGenFunc
// addrs is the set that contains cached local resolvers' addresses.
addrs *stringutil.Set
addrsLock sync.RWMutex addrsLock sync.RWMutex
// addrs is the set that contains cached local resolvers' addresses.
addrs *stringutil.Set
// resolver is used to fetch the resolvers' addresses.
resolver *net.Resolver
// hostGenFunc generates hosts to resolve.
hostGenFunc HostGenFunc
} }
const ( const (
@ -44,6 +47,7 @@ const (
errUnexpectedHostFormat errors.Error = "unexpected host format" errUnexpectedHostFormat errors.Error = "unexpected host format"
) )
// refresh implements the SystemResolvers interface for *systemResolvers.
func (sr *systemResolvers) refresh() (err error) { func (sr *systemResolvers) refresh() (err error) {
defer func() { err = errors.Annotate(err, "systemResolvers: %w") }() defer func() { err = errors.Annotate(err, "systemResolvers: %w") }()
@ -56,7 +60,7 @@ func (sr *systemResolvers) refresh() (err error) {
return err return err
} }
func newSystemResolvers(refreshIvl time.Duration, hostGenFunc HostGenFunc) (sr SystemResolvers) { func newSystemResolvers(hostGenFunc HostGenFunc) (sr SystemResolvers) {
if hostGenFunc == nil { if hostGenFunc == nil {
hostGenFunc = defaultHostGen hostGenFunc = defaultHostGen
} }
@ -76,19 +80,18 @@ func newSystemResolvers(refreshIvl time.Duration, hostGenFunc HostGenFunc) (sr S
func validateDialedHost(host string) (err error) { func validateDialedHost(host string) (err error) {
defer func() { err = errors.Annotate(err, "parsing %q: %w", host) }() defer func() { err = errors.Annotate(err, "parsing %q: %w", host) }()
var ipStr string
parts := strings.Split(host, "%") parts := strings.Split(host, "%")
switch len(parts) { switch len(parts) {
case 1: case 1:
ipStr = host // host
case 2: case 2:
// Remove the zone and check the IP address part. // Remove the zone and check the IP address part.
ipStr = parts[0] host = parts[0]
default: default:
return errUnexpectedHostFormat return errUnexpectedHostFormat
} }
if net.ParseIP(ipStr) == nil { if _, err = netutil.ParseIP(host); err != nil {
return errBadAddrPassed return errBadAddrPassed
} }

View File

@ -6,37 +6,32 @@ package aghnet
import ( import (
"context" "context"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func createTestSystemResolversImp( func createTestSystemResolversImpl(
t *testing.T, t *testing.T,
refreshDur time.Duration,
hostGenFunc HostGenFunc, hostGenFunc HostGenFunc,
) (imp *systemResolvers) { ) (imp *systemResolvers) {
t.Helper() t.Helper()
sr := createTestSystemResolvers(t, refreshDur, hostGenFunc) sr := createTestSystemResolvers(t, hostGenFunc)
require.IsType(t, (*systemResolvers)(nil), sr)
var ok bool return sr.(*systemResolvers)
imp, ok = sr.(*systemResolvers)
require.True(t, ok)
return imp
} }
func TestSystemResolvers_Refresh(t *testing.T) { func TestSystemResolvers_Refresh(t *testing.T) {
t.Run("expected_error", func(t *testing.T) { t.Run("expected_error", func(t *testing.T) {
sr := createTestSystemResolvers(t, 0, nil) sr := createTestSystemResolvers(t, nil)
assert.NoError(t, sr.refresh()) assert.NoError(t, sr.refresh())
}) })
t.Run("unexpected_error", func(t *testing.T) { t.Run("unexpected_error", func(t *testing.T) {
_, err := NewSystemResolvers(0, func() string { _, err := NewSystemResolvers(func() string {
return "127.0.0.1::123" return "127.0.0.1::123"
}) })
assert.Error(t, err) assert.Error(t, err)
@ -44,7 +39,7 @@ func TestSystemResolvers_Refresh(t *testing.T) {
} }
func TestSystemResolvers_DialFunc(t *testing.T) { func TestSystemResolvers_DialFunc(t *testing.T) {
imp := createTestSystemResolversImp(t, 0, nil) imp := createTestSystemResolversImpl(t, nil)
testCases := []struct { testCases := []struct {
want error want error
@ -52,7 +47,7 @@ func TestSystemResolvers_DialFunc(t *testing.T) {
address string address string
}{{ }{{
want: errFakeDial, want: errFakeDial,
name: "valid", name: "valid_ipv4",
address: "127.0.0.1", address: "127.0.0.1",
}, { }, {
want: errFakeDial, want: errFakeDial,

View File

@ -2,7 +2,6 @@ package aghnet
import ( import (
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -10,13 +9,12 @@ import (
func createTestSystemResolvers( func createTestSystemResolvers(
t *testing.T, t *testing.T,
refreshDur time.Duration,
hostGenFunc HostGenFunc, hostGenFunc HostGenFunc,
) (sr SystemResolvers) { ) (sr SystemResolvers) {
t.Helper() t.Helper()
var err error var err error
sr, err = NewSystemResolvers(refreshDur, hostGenFunc) sr, err = NewSystemResolvers(hostGenFunc)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, sr) require.NotNil(t, sr)
@ -24,8 +22,14 @@ func createTestSystemResolvers(
} }
func TestSystemResolvers_Get(t *testing.T) { func TestSystemResolvers_Get(t *testing.T) {
sr := createTestSystemResolvers(t, 0, nil) sr := createTestSystemResolvers(t, nil)
assert.NotEmpty(t, sr.Get())
var rs []string
require.NotPanics(t, func() {
rs = sr.Get()
})
assert.NotEmpty(t, rs)
} }
// TODO(e.burkov): Write tests for refreshWithTicker. // TODO(e.burkov): Write tests for refreshWithTicker.

View File

@ -11,7 +11,6 @@ import (
"os/exec" "os/exec"
"strings" "strings"
"sync" "sync"
"time"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
@ -27,7 +26,7 @@ type systemResolvers struct {
addrsLock sync.RWMutex addrsLock sync.RWMutex
} }
func newSystemResolvers(refreshIvl time.Duration, _ HostGenFunc) (sr SystemResolvers) { func newSystemResolvers(_ HostGenFunc) (sr SystemResolvers) {
return &systemResolvers{} return &systemResolvers{}
} }

View File

@ -2,3 +2,5 @@ IP address HW type Flags HW address Mask Device
192.168.1.2 0x1 0x2 ab:cd:ef:ab:cd:ef * wan 192.168.1.2 0x1 0x2 ab:cd:ef:ab:cd:ef * wan
::ffff:ffff 0x1 0x0 ef:cd:ab:ef:cd:ab * br-lan ::ffff:ffff 0x1 0x0 ef:cd:ab:ef:cd:ab * br-lan
0.0.0.0 0x0 0x0 00:00:00:00:00:00 * unspec 0.0.0.0 0x0 0x0 00:00:00:00:00:00 * unspec
1.2.3.4.5 0x1 0x2 aa:bb:cc:dd:ee:ff * wan
1.2.3.4 0x1 0x2 12:34:56:78:910 * wan

View File

@ -173,7 +173,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
// TODO(e.burkov): Enable the refresher after the actual implementation // TODO(e.burkov): Enable the refresher after the actual implementation
// passes the public testing. // passes the public testing.
s.sysResolvers, err = aghnet.NewSystemResolvers(0, nil) s.sysResolvers, err = aghnet.NewSystemResolvers(nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("initializing system resolvers: %w", err) return nil, fmt.Errorf("initializing system resolvers: %w", err)
} }

View File

@ -257,6 +257,8 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
} }
func (clients *clientsContainer) periodicUpdate() { func (clients *clientsContainer) periodicUpdate() {
defer log.OnPanic("clients container")
for { for {
clients.Reload() clients.Reload()
time.Sleep(clientsUpdatePeriod) time.Sleep(clientsUpdatePeriod)

View File

@ -291,14 +291,7 @@ func setupConfig(args options) (err error) {
} }
} }
var arpdb aghnet.ARPDB arpdb := aghnet.NewARPDB()
arpdb, err = aghnet.NewARPDB()
if err != nil {
log.Info("warning: creating arpdb: %s; using stub", err)
arpdb = aghnet.EmptyARPDB{}
}
Context.clients.Init(config.Clients, Context.dhcpServer, Context.etcHosts, arpdb) Context.clients.Init(config.Clients, Context.dhcpServer, Context.etcHosts, arpdb)
if args.bindPort != 0 { if args.bindPort != 0 {