diff --git a/internal/aghnet/arpdb.go b/internal/aghnet/arpdb.go index afb880c3..4909af5f 100644 --- a/internal/aghnet/arpdb.go +++ b/internal/aghnet/arpdb.go @@ -2,13 +2,11 @@ package aghnet import ( "bufio" + "bytes" "fmt" - "io" "net" - "strings" "sync" - "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/netutil" ) @@ -116,50 +114,33 @@ func (ns *neighs) reset(with []Neighbor) { // of Neighbors. type parseNeighsFunc func(sc *bufio.Scanner, lenHint int) (ns []Neighbor) -// runCmdFunc is the function that runs some command and returns its output -// wrapped to be a io.Reader. -type runCmdFunc func() (r io.Reader, err error) - // cmdARPDB is the implementation of the ARPDB that uses command line to // retrieve data. type cmdARPDB struct { - parse parseNeighsFunc - runcmd runCmdFunc - ns *neighs + parse parseNeighsFunc + ns *neighs + cmd string + args []string } // type check var _ ARPDB = (*cmdARPDB)(nil) -// runCmd runs the cmd with it's args and returns the result wrapped to be an -// io.Reader. The error is returned either if the exit code retured by command -// not equals 0 or the execution itself failed. -func runCmd(cmd string, args ...string) (r io.Reader, err error) { - var code int - var out string - code, out, err = aghos.RunCommand(cmd, args...) - if err != nil { - return nil, err - } else if code != 0 { - return nil, fmt.Errorf("unexpected exit code %d", code) - } - - return strings.NewReader(out), nil -} - // Refresh implements the ARPDB interface for *cmdARPDB. func (arp *cmdARPDB) Refresh() (err error) { defer func() { err = errors.Annotate(err, "cmd arpdb: %w") }() - var r io.Reader - r, err = arp.runcmd() + code, out, err := aghosRunCommand(arp.cmd, arp.args...) if err != nil { return fmt.Errorf("running command: %w", err) + } else if code != 0 { + return fmt.Errorf("running command: unexpected exit code %d", code) } - sc := bufio.NewScanner(r) + sc := bufio.NewScanner(bytes.NewReader(out)) ns := arp.parse(sc, arp.ns.len()) if err = sc.Err(); err != nil { + // TODO(e.burkov): This error seems unreachable. Investigate. return fmt.Errorf("scanning the output: %w", err) } @@ -180,8 +161,7 @@ func (arp *cmdARPDB) Neighbors() (ns []Neighbor) { type arpdbs struct { // arps is the set of ARPDB implementations to range through. arps []ARPDB - // last is the last succeeded ARPDB index. - last int + neighs } // newARPDBs returns a properly initialized *arpdbs. It begins refreshing from @@ -189,7 +169,10 @@ type arpdbs struct { func newARPDBs(arps ...ARPDB) (arp *arpdbs) { return &arpdbs{ arps: arps, - last: 0, + neighs: neighs{ + mu: &sync.RWMutex{}, + ns: make([]Neighbor, 0), + }, } } @@ -199,20 +182,18 @@ var _ ARPDB = (*arpdbs)(nil) // Refresh implements the ARPDB interface for *arpdbs. func (arp *arpdbs) Refresh() (err error) { var errs []error - l := len(arp.arps) - // Start from the last succeeded implementation. - for i := 0; i < l; i++ { - cur := (arp.last + i) % l - err = arp.arps[cur].Refresh() - if err == nil { - // The succeeded implementation found so update the last succeeded - // index. - arp.last = cur - return nil + for _, a := range arp.arps { + err = a.Refresh() + if err != nil { + errs = append(errs, err) + + continue } - errs = append(errs, err) + arp.reset(a.Neighbors()) + + return nil } if len(errs) > 0 { @@ -223,10 +204,8 @@ func (arp *arpdbs) Refresh() (err error) { } // Neighbors implements the ARPDB interface for *arpdbs. +// +// TODO(e.burkov): Think of a way to avoid cloning the slice twice. func (arp *arpdbs) Neighbors() (ns []Neighbor) { - if l := len(arp.arps); l > 0 && arp.last < l { - return arp.arps[arp.last].Neighbors() - } - - return nil + return arp.clone() } diff --git a/internal/aghnet/arpdb_bsd.go b/internal/aghnet/arpdb_bsd.go index fe00418a..a82da76c 100644 --- a/internal/aghnet/arpdb_bsd.go +++ b/internal/aghnet/arpdb_bsd.go @@ -15,12 +15,13 @@ import ( func newARPDB() *cmdARPDB { return &cmdARPDB{ - parse: parseArpA, - runcmd: rcArpA, + parse: parseArpA, ns: &neighs{ mu: &sync.RWMutex{}, ns: make([]Neighbor, 0), }, + cmd: "arp", + args: []string{"-a"}, } } diff --git a/internal/aghnet/arpdb_linux.go b/internal/aghnet/arpdb_linux.go index 976d8b7a..3d391f29 100644 --- a/internal/aghnet/arpdb_linux.go +++ b/internal/aghnet/arpdb_linux.go @@ -6,7 +6,6 @@ package aghnet import ( "bufio" "fmt" - "io" "io/fs" "net" "strings" @@ -34,11 +33,25 @@ func newARPDB() (arp *arpdbs) { return newARPDBs( // Try /proc/net/arp first. - &fsysARPDB{ns: ns, fsys: aghos.RootDirFS(), filename: "proc/net/arp"}, - // Try "arp -a" then. - &cmdARPDB{parse: parseF, runcmd: rcArpA, ns: ns}, - // Try "ip neigh" finally. - &cmdARPDB{parse: parseIPNeigh, runcmd: rcIPNeigh, ns: ns}, + &fsysARPDB{ + ns: ns, + fsys: rootDirFS, + filename: "proc/net/arp", + }, + // Then, try "arp -a". + &cmdARPDB{ + parse: parseF, + ns: ns, + cmd: "arp", + args: []string{"-a"}, + }, + // Finally, try "ip neigh". + &cmdARPDB{ + parse: parseIPNeigh, + ns: ns, + cmd: "ip", + args: []string{"neigh"}, + }, ) } @@ -187,11 +200,6 @@ func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) { return ns } -// rcIPNeigh runs "ip neigh". -func rcIPNeigh() (r io.Reader, err error) { - return runCmd("ip", "neigh") -} - // parseIPNeigh parses the output of the "ip neigh" command on Linux. The // expected input format: // diff --git a/internal/aghnet/arpdb_linux_test.go b/internal/aghnet/arpdb_linux_test.go index b9ed859f..46d87150 100644 --- a/internal/aghnet/arpdb_linux_test.go +++ b/internal/aghnet/arpdb_linux_test.go @@ -4,11 +4,10 @@ package aghnet import ( - "io" "net" - "strings" "sync" "testing" + "testing/fstest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -43,6 +42,8 @@ var wantNeighs = []Neighbor{{ }} func TestFSysARPDB(t *testing.T) { + require.NoError(t, fstest.TestFS(testdata, "proc_net_arp")) + a := &fsysARPDB{ ns: &neighs{ mu: &sync.RWMutex{}, @@ -59,33 +60,43 @@ func TestFSysARPDB(t *testing.T) { assert.Equal(t, wantNeighs, ns) } -func TestCmdARPDB_arpawrt(t *testing.T) { - a := &cmdARPDB{ - parse: parseArpAWrt, - runcmd: func() (r io.Reader, err error) { return strings.NewReader(arpAOutputWrt), nil }, - ns: &neighs{ - mu: &sync.RWMutex{}, - ns: make([]Neighbor, 0), - }, +func TestCmdARPDB_linux(t *testing.T) { + sh := mapShell{ + "arp -a": {err: nil, out: arpAOutputWrt, code: 0}, + "ip neigh": {err: nil, out: ipNeighOutput, code: 0}, } + substShell(t, sh.RunCmd) - err := a.Refresh() - require.NoError(t, err) + t.Run("wrt", func(t *testing.T) { + a := &cmdARPDB{ + parse: parseArpAWrt, + cmd: "arp", + args: []string{"-a"}, + ns: &neighs{ + mu: &sync.RWMutex{}, + ns: make([]Neighbor, 0), + }, + } - assert.Equal(t, wantNeighs, a.Neighbors()) -} - -func TestCmdARPDB_ipneigh(t *testing.T) { - a := &cmdARPDB{ - parse: parseIPNeigh, - runcmd: func() (r io.Reader, err error) { return strings.NewReader(ipNeighOutput), nil }, - ns: &neighs{ - mu: &sync.RWMutex{}, - ns: make([]Neighbor, 0), - }, - } - err := a.Refresh() - require.NoError(t, err) - - assert.Equal(t, wantNeighs, a.Neighbors()) + err := a.Refresh() + require.NoError(t, err) + + assert.Equal(t, wantNeighs, a.Neighbors()) + }) + + t.Run("ip_neigh", func(t *testing.T) { + a := &cmdARPDB{ + parse: parseIPNeigh, + cmd: "ip", + args: []string{"neigh"}, + ns: &neighs{ + mu: &sync.RWMutex{}, + ns: make([]Neighbor, 0), + }, + } + err := a.Refresh() + require.NoError(t, err) + + assert.Equal(t, wantNeighs, a.Neighbors()) + }) } diff --git a/internal/aghnet/arpdb_openbsd.go b/internal/aghnet/arpdb_openbsd.go index 805cb459..a00ffa85 100644 --- a/internal/aghnet/arpdb_openbsd.go +++ b/internal/aghnet/arpdb_openbsd.go @@ -14,12 +14,13 @@ import ( func newARPDB() *cmdARPDB { return &cmdARPDB{ - runcmd: rcArpA, - parse: parseArpA, + parse: parseArpA, ns: &neighs{ mu: &sync.RWMutex{}, ns: make([]Neighbor, 0), }, + cmd: "arp", + args: []string{"-a"}, } } diff --git a/internal/aghnet/arpdb_test.go b/internal/aghnet/arpdb_test.go index aa5c821a..75778b74 100644 --- a/internal/aghnet/arpdb_test.go +++ b/internal/aghnet/arpdb_test.go @@ -1,12 +1,9 @@ package aghnet import ( - "io" "net" - "strings" "sync" "testing" - "testing/iotest" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/testutil" @@ -16,9 +13,7 @@ import ( func TestNewARPDB(t *testing.T) { var a ARPDB - require.NotPanics(t, func() { - a = NewARPDB() - }) + require.NotPanics(t, func() { a = NewARPDB() }) assert.NotNil(t, a) } @@ -135,11 +130,11 @@ func TestARPDBS(t *testing.T) { assert.Equal(t, 1, succRefrCount) assert.NotEmpty(t, a.Neighbors()) - // Only the last succeeded ARPDB should be used. + // Unstable ARPDB should refresh successfully again. err = a.Refresh() require.NoError(t, err) - assert.Equal(t, 2, succRefrCount) + assert.Equal(t, 1, succRefrCount) assert.NotEmpty(t, a.Neighbors()) }) @@ -153,6 +148,7 @@ func TestARPDBS(t *testing.T) { func TestCmdARPDB_arpa(t *testing.T) { a := &cmdARPDB{ + cmd: "cmd", parse: parseArpA, ns: &neighs{ mu: &sync.RWMutex{}, @@ -161,7 +157,8 @@ func TestCmdARPDB_arpa(t *testing.T) { } t.Run("arp_a", func(t *testing.T) { - a.runcmd = func() (r io.Reader, err error) { return strings.NewReader(arpAOutput), nil } + sh := theOnlyCmd("cmd", 0, arpAOutput, nil) + substShell(t, sh.RunCmd) err := a.Refresh() require.NoError(t, err) @@ -170,32 +167,20 @@ func TestCmdARPDB_arpa(t *testing.T) { }) t.Run("runcmd_error", func(t *testing.T) { - a.runcmd = func() (r io.Reader, err error) { return nil, errors.Error("can't run") } + sh := theOnlyCmd("cmd", 0, "", errors.Error("can't run")) + substShell(t, sh.RunCmd) err := a.Refresh() testutil.AssertErrorMsg(t, "cmd arpdb: running command: can't run", err) }) -} -func TestCmdARPDB_errors(t *testing.T) { - const errRead errors.Error = "can't read" + t.Run("bad_code", func(t *testing.T) { + sh := theOnlyCmd("cmd", 1, "", nil) + substShell(t, sh.RunCmd) - badReaderRunCmd := runCmdFunc(func() (r io.Reader, err error) { - return iotest.ErrReader(errRead), nil + err := a.Refresh() + testutil.AssertErrorMsg(t, "cmd arpdb: running command: unexpected exit code 1", err) }) - - a := &cmdARPDB{ - runcmd: badReaderRunCmd, - parse: parseArpA, - ns: &neighs{ - mu: &sync.RWMutex{}, - ns: make([]Neighbor, 0), - }, - } - - const wantErrMsg string = "cmd arpdb: scanning the output: " + string(errRead) - - testutil.AssertErrorMsg(t, wantErrMsg, a.Refresh()) } func TestEmptyARPDB(t *testing.T) { diff --git a/internal/aghnet/arpdb_unix.go b/internal/aghnet/arpdb_unix.go deleted file mode 100644 index 50346f92..00000000 --- a/internal/aghnet/arpdb_unix.go +++ /dev/null @@ -1,13 +0,0 @@ -//go:build !windows -// +build !windows - -package aghnet - -import ( - "io" -) - -// rcArpA runs "arp -a". -func rcArpA() (r io.Reader, err error) { - return runCmd("arp", "-a") -} diff --git a/internal/aghnet/arpdb_windows.go b/internal/aghnet/arpdb_windows.go index 5156330b..2a70125f 100644 --- a/internal/aghnet/arpdb_windows.go +++ b/internal/aghnet/arpdb_windows.go @@ -5,7 +5,6 @@ package aghnet import ( "bufio" - "io" "net" "strings" "sync" @@ -13,20 +12,16 @@ import ( func newARPDB() *cmdARPDB { return &cmdARPDB{ - runcmd: rcArpA, + parse: parseArpA, ns: &neighs{ mu: &sync.RWMutex{}, ns: make([]Neighbor, 0), }, - parse: parseArpA, + cmd: "arp", + args: []string{"/a"}, } } -// rcArpA runs "arp /a". -func rcArpA() (r io.Reader, err error) { - return runCmd("arp", "/a") -} - // parseArpA parses the output of the "arp /a" command on Windows. The expected // input format (the first line is empty): // diff --git a/internal/aghnet/hostscontainer_test.go b/internal/aghnet/hostscontainer_test.go index 40bfb34c..807722a8 100644 --- a/internal/aghnet/hostscontainer_test.go +++ b/internal/aghnet/hostscontainer_test.go @@ -286,6 +286,8 @@ func TestHostsContainer_Translate(t *testing.T) { OnClose: func() (err error) { panic("not implemented") }, } + require.NoError(t, fstest.TestFS(testdata, "etc_hosts")) + hc, err := NewHostsContainer(0, testdata, &stubWatcher, "etc_hosts") require.NoError(t, err) testutil.CleanupAndRequireSuccess(t, hc.Close) @@ -358,6 +360,8 @@ func TestHostsContainer_Translate(t *testing.T) { func TestHostsContainer(t *testing.T) { const listID = 1234 + require.NoError(t, fstest.TestFS(testdata, "etc_hosts")) + testCases := []struct { want []*rules.DNSRewrite name string diff --git a/internal/aghnet/net.go b/internal/aghnet/net.go index 2d791c03..d17b9165 100644 --- a/internal/aghnet/net.go +++ b/internal/aghnet/net.go @@ -2,19 +2,27 @@ package aghnet import ( + "bytes" "encoding/json" "fmt" "io" "net" - "os/exec" - "strings" "syscall" + "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" ) +// aghosRunCommand is the function to run shell commands. It's an unexported +// variable instead of a direct call to make it substitutable in tests. +var aghosRunCommand = aghos.RunCommand + +// rootDirFS is the filesystem pointing to the root directory. It's an +// unexported variable instead to make it substitutable in tests. +var rootDirFS = aghos.RootDirFS() + // ErrNoStaticIPInfo is returned by IfaceHasStaticIP when no information about // the IP being static is available. const ErrNoStaticIPInfo errors.Error = "no information about static ip" @@ -32,22 +40,29 @@ func IfaceSetStaticIP(ifaceName string) (err error) { } // GatewayIP returns IP address of interface's gateway. -func GatewayIP(ifaceName string) net.IP { - cmd := exec.Command("ip", "route", "show", "dev", ifaceName) - log.Tracef("executing %s %v", cmd.Path, cmd.Args) - d, err := cmd.Output() - if err != nil || cmd.ProcessState.ExitCode() != 0 { +// +// TODO(e.burkov): Investigate if the gateway address may be fetched in another +// way since not every machine has the software installed. +func GatewayIP(ifaceName string) (ip net.IP) { + code, out, err := aghosRunCommand("ip", "route", "show", "dev", ifaceName) + if err != nil { + log.Debug("%s", err) + + return nil + } else if code != 0 { + log.Debug("fetching gateway ip: unexpected exit code: %d", code) + return nil } - fields := strings.Fields(string(d)) + fields := bytes.Fields(out) // The meaningful "ip route" command output should contain the word // "default" at first field and default gateway IP address at third field. - if len(fields) < 3 || fields[0] != "default" { + if len(fields) < 3 || string(fields[0]) != "default" { return nil } - return net.ParseIP(fields[2]) + return net.ParseIP(string(fields[2])) } // CanBindPort checks if we can bind to the given port. @@ -101,7 +116,7 @@ func (iface NetInterface) MarshalJSON() ([]byte, error) { // GetValidNetInterfacesForWeb returns interfaces that are eligible for DNS and WEB only // we do not return link-local addresses here -func GetValidNetInterfacesForWeb() ([]*NetInterface, error) { +func GetValidNetInterfacesForWeb() (netInterfaces []*NetInterface, err error) { ifaces, err := net.Interfaces() if err != nil { return nil, fmt.Errorf("couldn't get interfaces: %w", err) @@ -109,8 +124,6 @@ func GetValidNetInterfacesForWeb() ([]*NetInterface, error) { return nil, errors.Error("couldn't find any legible interface") } - var netInterfaces []*NetInterface - for _, iface := range ifaces { var addrs []net.Addr addrs, err = iface.Addrs() @@ -130,12 +143,14 @@ func GetValidNetInterfacesForWeb() ([]*NetInterface, error) { ipNet, ok := addr.(*net.IPNet) if !ok { // Should be net.IPNet, this is weird. - return nil, fmt.Errorf("got iface.Addrs() element %s that is not net.IPNet, it is %T", addr, addr) + return nil, fmt.Errorf("got %s that is not net.IPNet, it is %T", addr, addr) } + // Ignore link-local. if ipNet.IP.IsLinkLocalUnicast() { continue } + netIface.Addresses = append(netIface.Addresses, ipNet.IP) netIface.Subnets = append(netIface.Subnets, ipNet) } diff --git a/internal/aghnet/net_darwin.go b/internal/aghnet/net_darwin.go index 3c504988..63e57dc4 100644 --- a/internal/aghnet/net_darwin.go +++ b/internal/aghnet/net_darwin.go @@ -4,10 +4,11 @@ package aghnet import ( + "bufio" + "bytes" "fmt" - "os" + "io" "regexp" - "strings" "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/golibs/errors" @@ -27,7 +28,7 @@ func canBindPrivilegedPorts() (can bool, err error) { return aghos.HaveAdminRights() } -func ifaceHasStaticIP(ifaceName string) (bool, error) { +func ifaceHasStaticIP(ifaceName string) (ok bool, err error) { portInfo, err := getCurrentHardwarePortInfo(ifaceName) if err != nil { return false, err @@ -36,9 +37,10 @@ func ifaceHasStaticIP(ifaceName string) (bool, error) { return portInfo.static, nil } -// getCurrentHardwarePortInfo gets information for the specified network interface. +// getCurrentHardwarePortInfo gets information for the specified network +// interface. func getCurrentHardwarePortInfo(ifaceName string) (hardwarePortInfo, error) { - // First of all we should find hardware port name + // First of all we should find hardware port name. m := getNetworkSetupHardwareReports() hardwarePort, ok := m[ifaceName] if !ok { @@ -48,6 +50,10 @@ func getCurrentHardwarePortInfo(ifaceName string) (hardwarePortInfo, error) { return getHardwarePortInfo(hardwarePort) } +// hardwareReportsReg is the regular expression matching the lines of +// networksetup command output lines containing the interface information. +var hardwareReportsReg = regexp.MustCompile("Hardware Port: (.*?)\nDevice: (.*?)\n") + // getNetworkSetupHardwareReports parses the output of the `networksetup // -listallhardwareports` command it returns a map where the key is the // interface name, and the value is the "hardware port" returns nil if it fails @@ -56,54 +62,44 @@ func getCurrentHardwarePortInfo(ifaceName string) (hardwarePortInfo, error) { // TODO(e.burkov): There should be more proper approach than parsing the // command output. For example, see // https://developer.apple.com/documentation/systemconfiguration. -func getNetworkSetupHardwareReports() map[string]string { - _, out, err := aghos.RunCommand("networksetup", "-listallhardwareports") +func getNetworkSetupHardwareReports() (reports map[string]string) { + _, out, err := aghosRunCommand("networksetup", "-listallhardwareports") if err != nil { return nil } - re, err := regexp.Compile("Hardware Port: (.*?)\nDevice: (.*?)\n") - if err != nil { - return nil + reports = make(map[string]string) + + matches := hardwareReportsReg.FindAllSubmatch(out, -1) + for _, m := range matches { + reports[string(m[2])] = string(m[1]) } - m := make(map[string]string) - - matches := re.FindAllStringSubmatch(out, -1) - for i := range matches { - port := matches[i][1] - device := matches[i][2] - m[device] = port - } - - return m + return reports } -func getHardwarePortInfo(hardwarePort string) (hardwarePortInfo, error) { - h := hardwarePortInfo{} +// hardwarePortReg is the regular expression matching the lines of networksetup +// command output lines containing the port information. +var hardwarePortReg = regexp.MustCompile("IP address: (.*?)\nSubnet mask: (.*?)\nRouter: (.*?)\n") - _, out, err := aghos.RunCommand("networksetup", "-getinfo", hardwarePort) +func getHardwarePortInfo(hardwarePort string) (h hardwarePortInfo, err error) { + _, out, err := aghosRunCommand("networksetup", "-getinfo", hardwarePort) if err != nil { return h, err } - re := regexp.MustCompile("IP address: (.*?)\nSubnet mask: (.*?)\nRouter: (.*?)\n") - - match := re.FindStringSubmatch(out) - if len(match) == 0 { + match := hardwarePortReg.FindSubmatch(out) + if len(match) != 4 { return h, errors.Error("could not find hardware port info") } - h.name = hardwarePort - h.ip = match[1] - h.subnet = match[2] - h.gatewayIP = match[3] - - if strings.Index(out, "Manual Configuration") == 0 { - h.static = true - } - - return h, nil + return hardwarePortInfo{ + name: hardwarePort, + ip: string(match[1]), + subnet: string(match[2]), + gatewayIP: string(match[3]), + static: bytes.Index(out, []byte("Manual Configuration")) == 0, + }, nil } func ifaceSetStaticIP(ifaceName string) (err error) { @@ -113,7 +109,7 @@ func ifaceSetStaticIP(ifaceName string) (err error) { } if portInfo.static { - return errors.Error("IP address is already static") + return errors.Error("ip address is already static") } dnsAddrs, err := getEtcResolvConfServers() @@ -121,50 +117,62 @@ func ifaceSetStaticIP(ifaceName string) (err error) { return err } - args := make([]string, 0) - args = append(args, "-setdnsservers", portInfo.name) - args = append(args, dnsAddrs...) + args := append([]string{"-setdnsservers", portInfo.name}, dnsAddrs...) // Setting DNS servers is necessary when configuring a static IP - code, _, err := aghos.RunCommand("networksetup", args...) + code, _, err := aghosRunCommand("networksetup", args...) if err != nil { return err - } - if code != 0 { + } else if code != 0 { return fmt.Errorf("failed to set DNS servers, code=%d", code) } // Actually configures hardware port to have static IP - code, _, err = aghos.RunCommand("networksetup", "-setmanual", - portInfo.name, portInfo.ip, portInfo.subnet, portInfo.gatewayIP) + code, _, err = aghosRunCommand( + "networksetup", + "-setmanual", + portInfo.name, + portInfo.ip, + portInfo.subnet, + portInfo.gatewayIP, + ) if err != nil { return err - } - if code != 0 { + } else if code != 0 { return fmt.Errorf("failed to set DNS servers, code=%d", code) } return nil } +// etcResolvConfReg is the regular expression matching the lines of resolv.conf +// file containing a name server information. +var etcResolvConfReg = regexp.MustCompile("nameserver ([a-zA-Z0-9.:]+)") + // getEtcResolvConfServers returns a list of nameservers configured in // /etc/resolv.conf. -func getEtcResolvConfServers() ([]string, error) { - body, err := os.ReadFile("/etc/resolv.conf") +func getEtcResolvConfServers() (addrs []string, err error) { + const filename = "etc/resolv.conf" + + _, err = aghos.FileWalker(func(r io.Reader) (_ []string, _ bool, err error) { + sc := bufio.NewScanner(r) + for sc.Scan() { + matches := etcResolvConfReg.FindAllStringSubmatch(sc.Text(), -1) + if len(matches) == 0 { + continue + } + + for _, m := range matches { + addrs = append(addrs, m[1]) + } + } + + return nil, false, sc.Err() + }).Walk(rootDirFS, filename) if err != nil { - return nil, err - } - - re := regexp.MustCompile("nameserver ([a-zA-Z0-9.:]+)") - - matches := re.FindAllStringSubmatch(string(body), -1) - if len(matches) == 0 { - return nil, errors.Error("found no DNS servers in /etc/resolv.conf") - } - - addrs := make([]string, 0) - for i := range matches { - addrs = append(addrs, matches[i][1]) + return nil, fmt.Errorf("parsing etc/resolv.conf file: %w", err) + } else if len(addrs) == 0 { + return nil, fmt.Errorf("found no dns servers in %s", filename) } return addrs, nil diff --git a/internal/aghnet/net_darwin_test.go b/internal/aghnet/net_darwin_test.go new file mode 100644 index 00000000..905600d5 --- /dev/null +++ b/internal/aghnet/net_darwin_test.go @@ -0,0 +1,261 @@ +package aghnet + +import ( + "io/fs" + "testing" + "testing/fstest" + + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/testutil" + "github.com/stretchr/testify/assert" +) + +func TestIfaceHasStaticIP(t *testing.T) { + testCases := []struct { + name string + shell mapShell + ifaceName string + wantHas assert.BoolAssertionFunc + wantErrMsg string + }{{ + name: "success", + shell: mapShell{ + "networksetup -listallhardwareports": { + err: nil, + out: "Hardware Port: hwport\nDevice: en0\n", + code: 0, + }, + "networksetup -getinfo hwport": { + err: nil, + out: "IP address: 1.2.3.4\nSubnet mask: 255.255.255.0\nRouter: 1.2.3.1\n", + code: 0, + }, + }, + ifaceName: "en0", + wantHas: assert.False, + wantErrMsg: ``, + }, { + name: "success_static", + shell: mapShell{ + "networksetup -listallhardwareports": { + err: nil, + out: "Hardware Port: hwport\nDevice: en0\n", + code: 0, + }, + "networksetup -getinfo hwport": { + err: nil, + out: "Manual Configuration\nIP address: 1.2.3.4\n" + + "Subnet mask: 255.255.255.0\nRouter: 1.2.3.1\n", + code: 0, + }, + }, + ifaceName: "en0", + wantHas: assert.True, + wantErrMsg: ``, + }, { + name: "reports_error", + shell: theOnlyCmd( + "networksetup -listallhardwareports", + 0, + "", + errors.Error("can't list"), + ), + ifaceName: "en0", + wantHas: assert.False, + wantErrMsg: `could not find hardware port for en0`, + }, { + name: "port_error", + shell: mapShell{ + "networksetup -listallhardwareports": { + err: nil, + out: "Hardware Port: hwport\nDevice: en0\n", + code: 0, + }, + "networksetup -getinfo hwport": { + err: errors.Error("can't get"), + out: ``, + code: 0, + }, + }, + ifaceName: "en0", + wantHas: assert.False, + wantErrMsg: `can't get`, + }, { + name: "port_bad_output", + shell: mapShell{ + "networksetup -listallhardwareports": { + err: nil, + out: "Hardware Port: hwport\nDevice: en0\n", + code: 0, + }, + "networksetup -getinfo hwport": { + err: nil, + out: "nothing meaningful", + code: 0, + }, + }, + ifaceName: "en0", + wantHas: assert.False, + wantErrMsg: `could not find hardware port info`, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + substShell(t, tc.shell.RunCmd) + + has, err := IfaceHasStaticIP(tc.ifaceName) + testutil.AssertErrorMsg(t, tc.wantErrMsg, err) + + tc.wantHas(t, has) + }) + } +} + +func TestIfaceSetStaticIP(t *testing.T) { + succFsys := fstest.MapFS{ + "etc/resolv.conf": &fstest.MapFile{ + Data: []byte(`nameserver 1.1.1.1`), + }, + } + panicFsys := &aghtest.FS{ + OnOpen: func(name string) (fs.File, error) { panic("not implemented") }, + } + + testCases := []struct { + name string + shell mapShell + fsys fs.FS + wantErrMsg string + }{{ + name: "success", + shell: mapShell{ + "networksetup -listallhardwareports": { + err: nil, + out: "Hardware Port: hwport\nDevice: en0\n", + code: 0, + }, + "networksetup -getinfo hwport": { + err: nil, + out: "IP address: 1.2.3.4\nSubnet mask: 255.255.255.0\nRouter: 1.2.3.1\n", + code: 0, + }, + "networksetup -setdnsservers hwport 1.1.1.1": { + err: nil, + out: "", + code: 0, + }, + "networksetup -setmanual hwport 1.2.3.4 255.255.255.0 1.2.3.1": { + err: nil, + out: "", + code: 0, + }, + }, + fsys: succFsys, + wantErrMsg: ``, + }, { + name: "static_already", + shell: mapShell{ + "networksetup -listallhardwareports": { + err: nil, + out: "Hardware Port: hwport\nDevice: en0\n", + code: 0, + }, + "networksetup -getinfo hwport": { + err: nil, + out: "Manual Configuration\nIP address: 1.2.3.4\n" + + "Subnet mask: 255.255.255.0\nRouter: 1.2.3.1\n", + code: 0, + }, + }, + fsys: panicFsys, + wantErrMsg: `ip address is already static`, + }, { + name: "reports_error", + shell: theOnlyCmd( + "networksetup -listallhardwareports", + 0, + "", + errors.Error("can't list"), + ), + fsys: panicFsys, + wantErrMsg: `could not find hardware port for en0`, + }, { + name: "resolv_conf_error", + shell: mapShell{ + "networksetup -listallhardwareports": { + err: nil, + out: "Hardware Port: hwport\nDevice: en0\n", + code: 0, + }, + "networksetup -getinfo hwport": { + err: nil, + out: "IP address: 1.2.3.4\nSubnet mask: 255.255.255.0\nRouter: 1.2.3.1\n", + code: 0, + }, + }, + fsys: fstest.MapFS{ + "etc/resolv.conf": &fstest.MapFile{ + Data: []byte("this resolv.conf is invalid"), + }, + }, + wantErrMsg: `found no dns servers in etc/resolv.conf`, + }, { + name: "set_dns_error", + shell: mapShell{ + "networksetup -listallhardwareports": { + err: nil, + out: "Hardware Port: hwport\nDevice: en0\n", + code: 0, + }, + "networksetup -getinfo hwport": { + err: nil, + out: "IP address: 1.2.3.4\nSubnet mask: 255.255.255.0\nRouter: 1.2.3.1\n", + code: 0, + }, + "networksetup -setdnsservers hwport 1.1.1.1": { + err: errors.Error("can't set"), + out: "", + code: 0, + }, + }, + fsys: succFsys, + wantErrMsg: `can't set`, + }, { + name: "set_manual_error", + shell: mapShell{ + "networksetup -listallhardwareports": { + err: nil, + out: "Hardware Port: hwport\nDevice: en0\n", + code: 0, + }, + "networksetup -getinfo hwport": { + err: nil, + out: "IP address: 1.2.3.4\nSubnet mask: 255.255.255.0\nRouter: 1.2.3.1\n", + code: 0, + }, + "networksetup -setdnsservers hwport 1.1.1.1": { + err: nil, + out: "", + code: 0, + }, + "networksetup -setmanual hwport 1.2.3.4 255.255.255.0 1.2.3.1": { + err: errors.Error("can't set"), + out: "", + code: 0, + }, + }, + fsys: succFsys, + wantErrMsg: `can't set`, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + substShell(t, tc.shell.RunCmd) + substRootDirFS(t, tc.fsys) + + err := IfaceSetStaticIP("en0") + testutil.AssertErrorMsg(t, tc.wantErrMsg, err) + }) + } +} diff --git a/internal/aghnet/net_freebsd.go b/internal/aghnet/net_freebsd.go index a5200fb8..34d93303 100644 --- a/internal/aghnet/net_freebsd.go +++ b/internal/aghnet/net_freebsd.go @@ -22,7 +22,7 @@ func ifaceHasStaticIP(ifaceName string) (ok bool, err error) { walker := aghos.FileWalker(interfaceName(ifaceName).rcConfStaticConfig) - return walker.Walk(aghos.RootDirFS(), rcConfFilename) + return walker.Walk(rootDirFS, rcConfFilename) } // rcConfStaticConfig checks if the interface is configured by /etc/rc.conf to diff --git a/internal/aghnet/net_freebsd_test.go b/internal/aghnet/net_freebsd_test.go index 3781b154..e00dafa7 100644 --- a/internal/aghnet/net_freebsd_test.go +++ b/internal/aghnet/net_freebsd_test.go @@ -4,56 +4,74 @@ package aghnet import ( - "strings" + "io/fs" "testing" + "testing/fstest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestRcConfStaticConfig(t *testing.T) { - const iface interfaceName = `em0` - const nl = "\n" + const ( + ifaceName = `em0` + rcConf = "etc/rc.conf" + ) testCases := []struct { - name string - rcconfData string - wantCont bool + name string + rootFsys fs.FS + wantHas assert.BoolAssertionFunc }{{ - name: "simple", - rcconfData: `ifconfig_em0="inet 127.0.0.253 netmask 0xffffffff"` + nl, - wantCont: false, + name: "simple", + rootFsys: fstest.MapFS{rcConf: &fstest.MapFile{ + Data: []byte(`ifconfig_` + ifaceName + `="inet 127.0.0.253 netmask 0xffffffff"` + nl), + }}, + wantHas: assert.True, }, { - name: "case_insensitiveness", - rcconfData: `ifconfig_em0="InEt 127.0.0.253 NeTmAsK 0xffffffff"` + nl, - wantCont: false, + name: "case_insensitiveness", + rootFsys: fstest.MapFS{rcConf: &fstest.MapFile{ + Data: []byte(`ifconfig_` + ifaceName + `="InEt 127.0.0.253 NeTmAsK 0xffffffff"` + nl), + }}, + wantHas: assert.True, }, { name: "comments_and_trash", - rcconfData: `# comment 1` + nl + - `` + nl + - `# comment 2` + nl + - `ifconfig_em0="inet 127.0.0.253 netmask 0xffffffff"` + nl, - wantCont: false, + rootFsys: fstest.MapFS{rcConf: &fstest.MapFile{ + Data: []byte(`# comment 1` + nl + + `` + nl + + `# comment 2` + nl + + `ifconfig_` + ifaceName + `="inet 127.0.0.253 netmask 0xffffffff"` + nl, + ), + }}, + wantHas: assert.True, }, { name: "aliases", - rcconfData: `ifconfig_em0_alias="inet 127.0.0.1/24"` + nl + - `ifconfig_em0="inet 127.0.0.253 netmask 0xffffffff"` + nl, - wantCont: false, + rootFsys: fstest.MapFS{rcConf: &fstest.MapFile{ + Data: []byte(`ifconfig_` + ifaceName + `_alias="inet 127.0.0.1/24"` + nl + + `ifconfig_` + ifaceName + `="inet 127.0.0.253 netmask 0xffffffff"` + nl, + ), + }}, + wantHas: assert.True, }, { name: "incorrect_config", - rcconfData: `ifconfig_em0="inet6 127.0.0.253 netmask 0xffffffff"` + nl + - `ifconfig_em0="inet 256.256.256.256 netmask 0xffffffff"` + nl + - `ifconfig_em0=""` + nl, - wantCont: true, + rootFsys: fstest.MapFS{rcConf: &fstest.MapFile{ + Data: []byte( + `ifconfig_` + ifaceName + `="inet6 127.0.0.253 netmask 0xffffffff"` + nl + + `ifconfig_` + ifaceName + `="inet 256.256.256.256 netmask 0xffffffff"` + nl + + `ifconfig_` + ifaceName + `=""` + nl, + ), + }}, + wantHas: assert.False, }} for _, tc := range testCases { - r := strings.NewReader(tc.rcconfData) t.Run(tc.name, func(t *testing.T) { - _, cont, err := iface.rcConfStaticConfig(r) + substRootDirFS(t, tc.rootFsys) + + has, err := IfaceHasStaticIP(ifaceName) require.NoError(t, err) - assert.Equal(t, tc.wantCont, cont) + tc.wantHas(t, has) }) } } diff --git a/internal/aghnet/net_linux.go b/internal/aghnet/net_linux.go index 93414165..c2526524 100644 --- a/internal/aghnet/net_linux.go +++ b/internal/aghnet/net_linux.go @@ -21,8 +21,7 @@ import ( // have a static IP. func (n interfaceName) dhcpcdStaticConfig(r io.Reader) (subsources []string, cont bool, err error) { s := bufio.NewScanner(r) - ifaceFound := findIfaceLine(s, string(n)) - if !ifaceFound { + if !findIfaceLine(s, string(n)) { return nil, true, s.Err() } @@ -61,9 +60,9 @@ func (n interfaceName) ifacesStaticConfig(r io.Reader) (sub []string, cont bool, fields := strings.Fields(line) fieldsNum := len(fields) - // Man page interfaces(5) declares that interface definition - // should consist of the key word "iface" followed by interface - // name, and method at fourth field. + // Man page interfaces(5) declares that interface definition should + // consist of the key word "iface" followed by interface name, and + // method at fourth field. if fieldsNum >= 4 && fields[0] == "iface" && fields[1] == string(n) && fields[3] == "static" { return nil, false, nil @@ -78,10 +77,10 @@ func (n interfaceName) ifacesStaticConfig(r io.Reader) (sub []string, cont bool, } func ifaceHasStaticIP(ifaceName string) (has bool, err error) { - // TODO(a.garipov): Currently, this function returns the first - // definitive result. So if /etc/dhcpcd.conf has a static IP while - // /etc/network/interfaces doesn't, it will return true. Perhaps this - // is not the most desirable behavior. + // TODO(a.garipov): Currently, this function returns the first definitive + // result. So if /etc/dhcpcd.conf has and /etc/network/interfaces has no + // static IP configuration, it will return true. Perhaps this is not the + // most desirable behavior. iface := interfaceName(ifaceName) @@ -95,12 +94,10 @@ func ifaceHasStaticIP(ifaceName string) (has bool, err error) { FileWalker: iface.ifacesStaticConfig, filename: "etc/network/interfaces", }} { - has, err = pair.Walk(aghos.RootDirFS(), pair.filename) + has, err = pair.Walk(rootDirFS, pair.filename) if err != nil { return false, err - } - - if has { + } else if has { return true, nil } } @@ -141,13 +138,15 @@ func ifaceSetStaticIP(ifaceName string) (err error) { gatewayIP := GatewayIP(ifaceName) add := dhcpcdConfIface(ifaceName, ipNet, gatewayIP, ipNet.IP) - body, err := os.ReadFile("/etc/dhcpcd.conf") + const filename = "/etc/dhcpcd.conf" + + body, err := os.ReadFile(filename) if err != nil && !errors.Is(err, os.ErrNotExist) { return err } body = append(body, []byte(add)...) - err = maybe.WriteFile("/etc/dhcpcd.conf", body, 0o644) + err = maybe.WriteFile(filename, body, 0o644) if err != nil { return fmt.Errorf("writing conf: %w", err) } diff --git a/internal/aghnet/net_linux_test.go b/internal/aghnet/net_linux_test.go index bf2cecfe..e46da4df 100644 --- a/internal/aghnet/net_linux_test.go +++ b/internal/aghnet/net_linux_test.go @@ -4,122 +4,130 @@ package aghnet import ( - "bytes" + "io/fs" "net" "testing" + "testing/fstest" + "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) -func TestDHCPCDStaticConfig(t *testing.T) { - const iface interfaceName = `wlan0` +func TestHasStaticIP(t *testing.T) { + const ifaceName = "wlan0" + + const ( + dhcpcd = "etc/dhcpcd.conf" + netifaces = "etc/network/interfaces" + ) testCases := []struct { - name string - data []byte - wantCont bool + rootFsys fs.FS + name string + wantHas assert.BoolAssertionFunc + wantErrMsg string }{{ - name: "has_not", - data: []byte(`#comment` + nl + - `# comment` + nl + - `interface eth0` + nl + - `static ip_address=192.168.0.1/24` + nl + - `# interface ` + iface + nl + - `static ip_address=192.168.1.1/24` + nl + - `# comment` + nl, - ), - wantCont: true, + rootFsys: fstest.MapFS{ + dhcpcd: &fstest.MapFile{ + Data: []byte(`#comment` + nl + + `# comment` + nl + + `interface eth0` + nl + + `static ip_address=192.168.0.1/24` + nl + + `# interface ` + ifaceName + nl + + `static ip_address=192.168.1.1/24` + nl + + `# comment` + nl, + ), + }, + }, + name: "dhcpcd_has_not", + wantHas: assert.False, + wantErrMsg: `no information about static ip`, }, { - name: "has", - data: []byte(`#comment` + nl + - `# comment` + nl + - `interface eth0` + nl + - `static ip_address=192.168.0.1/24` + nl + - `# interface ` + iface + nl + - `static ip_address=192.168.1.1/24` + nl + - `# comment` + nl + - `interface ` + iface + nl + - `# comment` + nl + - `static ip_address=192.168.2.1/24` + nl, - ), - wantCont: false, + rootFsys: fstest.MapFS{ + dhcpcd: &fstest.MapFile{ + Data: []byte(`#comment` + nl + + `# comment` + nl + + `interface ` + ifaceName + nl + + `static ip_address=192.168.0.1/24` + nl + + `# interface ` + ifaceName + nl + + `static ip_address=192.168.1.1/24` + nl + + `# comment` + nl, + ), + }, + }, + name: "dhcpcd_has", + wantHas: assert.True, + wantErrMsg: ``, + }, { + rootFsys: fstest.MapFS{ + netifaces: &fstest.MapFile{ + Data: []byte(`allow-hotplug ` + ifaceName + nl + + `#iface enp0s3 inet static` + nl + + `# address 192.168.0.200` + nl + + `# netmask 255.255.255.0` + nl + + `# gateway 192.168.0.1` + nl + + `iface ` + ifaceName + ` inet dhcp` + nl, + ), + }, + }, + name: "netifaces_has_not", + wantHas: assert.False, + wantErrMsg: `no information about static ip`, + }, { + rootFsys: fstest.MapFS{ + netifaces: &fstest.MapFile{ + Data: []byte(`allow-hotplug ` + ifaceName + nl + + `iface ` + ifaceName + ` inet static` + nl + + ` address 192.168.0.200` + nl + + ` netmask 255.255.255.0` + nl + + ` gateway 192.168.0.1` + nl + + `#iface ` + ifaceName + ` inet dhcp` + nl, + ), + }, + }, + name: "netifaces_has", + wantHas: assert.True, + wantErrMsg: ``, + }, { + rootFsys: fstest.MapFS{ + netifaces: &fstest.MapFile{ + Data: []byte(`source hello` + nl + + `#iface ` + ifaceName + ` inet static` + nl, + ), + }, + "hello": &fstest.MapFile{ + Data: []byte(`iface ` + ifaceName + ` inet static` + nl), + }, + }, + name: "netifaces_another_file", + wantHas: assert.True, + wantErrMsg: ``, + }, { + rootFsys: fstest.MapFS{ + netifaces: &fstest.MapFile{ + Data: []byte(`source hello` + nl + + `iface ` + ifaceName + ` inet static` + nl, + ), + }, + }, + name: "netifaces_ignore_another", + wantHas: assert.True, + wantErrMsg: ``, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - r := bytes.NewReader(tc.data) - _, cont, err := iface.dhcpcdStaticConfig(r) - require.NoError(t, err) + substRootDirFS(t, tc.rootFsys) - assert.Equal(t, tc.wantCont, cont) + has, err := IfaceHasStaticIP(ifaceName) + testutil.AssertErrorMsg(t, tc.wantErrMsg, err) + + tc.wantHas(t, has) }) } } -func TestIfacesStaticConfig(t *testing.T) { - const iface interfaceName = `enp0s3` - - testCases := []struct { - name string - data []byte - wantCont bool - wantPatterns []string - }{{ - name: "has_not", - data: []byte(`allow-hotplug ` + iface + nl + - `#iface enp0s3 inet static` + nl + - `# address 192.168.0.200` + nl + - `# netmask 255.255.255.0` + nl + - `# gateway 192.168.0.1` + nl + - `iface ` + iface + ` inet dhcp` + nl, - ), - wantCont: true, - wantPatterns: []string{}, - }, { - name: "has", - data: []byte(`allow-hotplug ` + iface + nl + - `iface ` + iface + ` inet static` + nl + - ` address 192.168.0.200` + nl + - ` netmask 255.255.255.0` + nl + - ` gateway 192.168.0.1` + nl + - `#iface ` + iface + ` inet dhcp` + nl, - ), - wantCont: false, - wantPatterns: []string{}, - }, { - name: "return_patterns", - data: []byte(`source hello` + nl + - `source world` + nl + - `#iface ` + iface + ` inet static` + nl, - ), - wantCont: true, - wantPatterns: []string{"hello", "world"}, - }, { - // This one tests if the first found valid interface prevents - // checking files under the `source` directive. - name: "ignore_patterns", - data: []byte(`source hello` + nl + - `source world` + nl + - `iface ` + iface + ` inet static` + nl, - ), - wantCont: false, - wantPatterns: []string{}, - }} - - for _, tc := range testCases { - r := bytes.NewReader(tc.data) - t.Run(tc.name, func(t *testing.T) { - patterns, has, err := iface.ifacesStaticConfig(r) - require.NoError(t, err) - - assert.Equal(t, tc.wantCont, has) - assert.ElementsMatch(t, tc.wantPatterns, patterns) - }) - } -} - -func TestSetStaticIPdhcpcdConf(t *testing.T) { +func TestSetStaticIP_dhcpcdConfIface(t *testing.T) { testCases := []struct { name string dhcpcdConf string diff --git a/internal/aghnet/net_openbsd.go b/internal/aghnet/net_openbsd.go index 627db0ab..68ef90e0 100644 --- a/internal/aghnet/net_openbsd.go +++ b/internal/aghnet/net_openbsd.go @@ -20,7 +20,7 @@ func canBindPrivilegedPorts() (can bool, err error) { func ifaceHasStaticIP(ifaceName string) (ok bool, err error) { filename := fmt.Sprintf("etc/hostname.%s", ifaceName) - return aghos.FileWalker(hostnameIfStaticConfig).Walk(aghos.RootDirFS(), filename) + return aghos.FileWalker(hostnameIfStaticConfig).Walk(rootDirFS, filename) } // hostnameIfStaticConfig checks if the interface is configured by diff --git a/internal/aghnet/net_openbsd_test.go b/internal/aghnet/net_openbsd_test.go index e157d93a..356799b7 100644 --- a/internal/aghnet/net_openbsd_test.go +++ b/internal/aghnet/net_openbsd_test.go @@ -4,49 +4,69 @@ package aghnet import ( - "strings" + "fmt" + "io/fs" "testing" + "testing/fstest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestHostnameIfStaticConfig(t *testing.T) { - const nl = "\n" +func TestIfaceHasStaticIP(t *testing.T) { + const ifaceName = "em0" + + confFile := fmt.Sprintf("etc/hostname.%s", ifaceName) testCases := []struct { - name string - rcconfData string - wantHas bool + name string + rootFsys fs.FS + wantHas assert.BoolAssertionFunc }{{ - name: "simple", - rcconfData: `inet 127.0.0.253` + nl, - wantHas: true, + name: "simple", + rootFsys: fstest.MapFS{ + confFile: &fstest.MapFile{ + Data: []byte(`inet 127.0.0.253` + nl), + }, + }, + wantHas: assert.True, }, { - name: "case_sensitiveness", - rcconfData: `InEt 127.0.0.253` + nl, - wantHas: false, + name: "case_sensitiveness", + rootFsys: fstest.MapFS{ + confFile: &fstest.MapFile{ + Data: []byte(`InEt 127.0.0.253` + nl), + }, + }, + wantHas: assert.False, }, { name: "comments_and_trash", - rcconfData: `# comment 1` + nl + - `` + nl + - `# inet 127.0.0.253` + nl + - `inet` + nl, - wantHas: false, + rootFsys: fstest.MapFS{ + confFile: &fstest.MapFile{ + Data: []byte(`# comment 1` + nl + nl + + `# inet 127.0.0.253` + nl + + `inet` + nl, + ), + }, + }, + wantHas: assert.False, }, { name: "incorrect_config", - rcconfData: `inet6 127.0.0.253` + nl + - `inet 256.256.256.256` + nl, - wantHas: false, + rootFsys: fstest.MapFS{ + confFile: &fstest.MapFile{ + Data: []byte(`inet6 127.0.0.253` + nl + `inet 256.256.256.256` + nl), + }, + }, + wantHas: assert.False, }} for _, tc := range testCases { - r := strings.NewReader(tc.rcconfData) t.Run(tc.name, func(t *testing.T) { - _, has, err := hostnameIfStaticConfig(r) + substRootDirFS(t, tc.rootFsys) + + has, err := IfaceHasStaticIP(ifaceName) require.NoError(t, err) - assert.Equal(t, tc.wantHas, has) + tc.wantHas(t, has) }) } } diff --git a/internal/aghnet/net_test.go b/internal/aghnet/net_test.go index b461cb44..29de869b 100644 --- a/internal/aghnet/net_test.go +++ b/internal/aghnet/net_test.go @@ -1,9 +1,11 @@ package aghnet import ( + "fmt" "io/fs" "net" "os" + "strings" "testing" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" @@ -14,11 +16,102 @@ import ( "github.com/stretchr/testify/require" ) +func TestMain(m *testing.M) { + aghtest.DiscardLogOutput(m) +} + // testdata is the filesystem containing data for testing the package. var testdata fs.FS = os.DirFS("./testdata") -func TestMain(m *testing.M) { - aghtest.DiscardLogOutput(m) +// substRootDirFS replaces the aghos.RootDirFS function used throughout the +// package with fsys for tests ran under t. +func substRootDirFS(t testing.TB, fsys fs.FS) { + t.Helper() + + prev := rootDirFS + t.Cleanup(func() { rootDirFS = prev }) + rootDirFS = fsys +} + +// RunCmdFunc is the signature of aghos.RunCommand function. +type RunCmdFunc func(cmd string, args ...string) (code int, out []byte, err error) + +// substShell replaces the the aghos.RunCommand function used throughout the +// package with rc for tests ran under t. +func substShell(t testing.TB, rc RunCmdFunc) { + t.Helper() + + prev := aghosRunCommand + t.Cleanup(func() { aghosRunCommand = prev }) + aghosRunCommand = rc +} + +// mapShell is a substitution of aghos.RunCommand that maps the command to it's +// execution result. It's only needed to simplify testing. +// +// TODO(e.burkov): Perhaps put all the shell interactions behind an interface. +type mapShell map[string]struct { + err error + out string + code int +} + +// theOnlyCmd returns s that only handles a single command and arguments +// combination from cmd. +func theOnlyCmd(cmd string, code int, out string, err error) (s mapShell) { + return mapShell{cmd: {code: code, out: out, err: err}} +} + +// RunCmd is a RunCmdFunc handled by s. +func (s mapShell) RunCmd(cmd string, args ...string) (code int, out []byte, err error) { + key := strings.Join(append([]string{cmd}, args...), " ") + ret, ok := s[key] + if !ok { + return 0, nil, fmt.Errorf("unexpected shell command %q", key) + } + + return ret.code, []byte(ret.out), ret.err +} + +func TestGatewayIP(t *testing.T) { + testCases := []struct { + name string + shell mapShell + want net.IP + }{{ + name: "success_v4", + shell: theOnlyCmd("ip route show dev ifaceName", 0, `default via 1.2.3.4 onlink`, nil), + want: net.IP{1, 2, 3, 4}.To16(), + }, { + name: "success_v6", + shell: theOnlyCmd("ip route show dev ifaceName", 0, `default via ::ffff onlink`, nil), + want: net.IP{ + 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0xFF, 0xFF, + }, + }, { + name: "bad_output", + shell: theOnlyCmd("ip route show dev ifaceName", 0, `non-default via 1.2.3.4 onlink`, nil), + want: nil, + }, { + name: "err_runcmd", + shell: theOnlyCmd("ip route show dev ifaceName", 0, "", errors.Error("can't run command")), + want: nil, + }, { + name: "bad_code", + shell: theOnlyCmd("ip route show dev ifaceName", 1, "", nil), + want: nil, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + substShell(t, tc.shell.RunCmd) + + assert.Equal(t, tc.want, GatewayIP("ifaceName")) + }) + } } func TestGetInterfaceByIP(t *testing.T) { @@ -133,6 +226,8 @@ func TestCheckPort(t *testing.T) { } func TestCollectAllIfacesAddrs(t *testing.T) { + t.Skip("TODO(e.burkov): Substitute the net.Interfaces.") + addrs, err := CollectAllIfacesAddrs() require.NoError(t, err) diff --git a/internal/aghos/os.go b/internal/aghos/os.go index 8ac189a1..018a3e89 100644 --- a/internal/aghos/os.go +++ b/internal/aghos/os.go @@ -57,20 +57,22 @@ func HaveAdminRights() (bool, error) { const MaxCmdOutputSize = 64 * 1024 // RunCommand runs shell command. -func RunCommand(command string, arguments ...string) (code int, output string, err error) { +func RunCommand(command string, arguments ...string) (code int, output []byte, err error) { cmd := exec.Command(command, arguments...) out, err := cmd.Output() if len(out) > MaxCmdOutputSize { out = out[:MaxCmdOutputSize] } - if errors.As(err, new(*exec.ExitError)) { - return cmd.ProcessState.ExitCode(), string(out), nil - } else if err != nil { - return 1, "", fmt.Errorf("command %q failed: %w: %s", command, err, out) + if err != nil { + if eerr := new(exec.ExitError); errors.As(err, &eerr) { + return eerr.ExitCode(), eerr.Stderr, nil + } + + return 1, nil, fmt.Errorf("command %q failed: %w: %s", command, err, out) } - return cmd.ProcessState.ExitCode(), string(out), nil + return cmd.ProcessState.ExitCode(), out, nil } // PIDByCommand searches for process named command and returns its PID ignoring diff --git a/internal/home/service_openbsd.go b/internal/home/service_openbsd.go index 679a7437..8ad0d212 100644 --- a/internal/home/service_openbsd.go +++ b/internal/home/service_openbsd.go @@ -314,12 +314,13 @@ func (s *openbsdRunComService) runCom(cmd string) (out string, err error) { // TODO(e.burkov): It's possible that os.ErrNotExist is caused by // something different than the service script's non-existence. Keep it // in mind, when replace the aghos.RunCommand. - _, out, err = aghos.RunCommand(scriptPath, cmd) + var outData []byte + _, outData, err = aghos.RunCommand(scriptPath, cmd) if errors.Is(err, os.ErrNotExist) { return "", service.ErrNotInstalled } - return out, err + return string(outData), err } // Status implements service.Service interface for *openbsdRunComService.