Pull request: 3435 openwrt detect

Merge in DNS/adguard-home from 3435-openwrt-detect to master

Updates #3435.

Squashed commit of the following:

commit 04b10f407ced1c85ac8089f980d79e9bbfe14e95
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Aug 13 19:02:55 2021 +0300

    aghos: fix windows build

commit d387cec5f9cae9256dccef8c666c02f2fb7449a2
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Aug 13 18:22:12 2021 +0300

    aghos: imp code, tests

commit 2450b98522eb032ec8658f3ef2384fc77b627cc6
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Aug 13 13:43:46 2021 +0300

    all: imp code, docs

commit 7fabba3a8dc70fe61dbaa8fd5445453816fe9ac7
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Aug 13 04:04:09 2021 +0300

    all: log changes

commit 7cc1235308caf09eb4c80c05a4f328b8d6909ec7
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Aug 13 03:33:13 2021 +0300

    querylog: repl with golibs

commit 84592087d3b2aca23613950bb203ff3c862624dc
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Aug 13 03:16:37 2021 +0300

    aghos: use filewalker

commit e4f2964b0e031c7a9a053e85c0ff7c792c772929
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Aug 13 00:34:20 2021 +0300

    aghos: mv recurrentchecker from aghnet
This commit is contained in:
Eugene Burkov 2021-08-13 19:20:17 +03:00
parent e3ad46876f
commit 394c2f65e0
16 changed files with 478 additions and 485 deletions

View File

@ -43,6 +43,7 @@ and this project adheres to
### Changed ### Changed
- Better OpenWrt detection ([#3435]).
- DNS-over-HTTPS queries that come from HTTP proxies in the `trusted_proxies` - DNS-over-HTTPS queries that come from HTTP proxies in the `trusted_proxies`
list now use the real IP address of the client instead of the address of the list now use the real IP address of the client instead of the address of the
proxy ([#2799]). proxy ([#2799]).
@ -123,6 +124,7 @@ and this project adheres to
[#3351]: https://github.com/AdguardTeam/AdGuardHome/issues/3351 [#3351]: https://github.com/AdguardTeam/AdGuardHome/issues/3351
[#3372]: https://github.com/AdguardTeam/AdGuardHome/issues/3372 [#3372]: https://github.com/AdguardTeam/AdGuardHome/issues/3372
[#3417]: https://github.com/AdguardTeam/AdGuardHome/issues/3417 [#3417]: https://github.com/AdguardTeam/AdGuardHome/issues/3417
[#3435]: https://github.com/AdguardTeam/AdGuardHome/issues/3435
[#3437]: https://github.com/AdguardTeam/AdGuardHome/issues/3437 [#3437]: https://github.com/AdguardTeam/AdGuardHome/issues/3437

2
go.mod
View File

@ -4,7 +4,7 @@ go 1.16
require ( require (
github.com/AdguardTeam/dnsproxy v0.39.2 github.com/AdguardTeam/dnsproxy v0.39.2
github.com/AdguardTeam/golibs v0.9.0 github.com/AdguardTeam/golibs v0.9.1
github.com/AdguardTeam/urlfilter v0.14.6 github.com/AdguardTeam/urlfilter v0.14.6
github.com/NYTimes/gziphandler v1.1.1 github.com/NYTimes/gziphandler v1.1.1
github.com/ameshkov/dnscrypt/v2 v2.2.1 github.com/ameshkov/dnscrypt/v2 v2.2.1

4
go.sum
View File

@ -14,8 +14,8 @@ github.com/AdguardTeam/dnsproxy v0.39.2/go.mod h1:aNXKNdTyKfgAG2OS712SYSaGIM9Aas
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.8.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.8.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.9.0 h1:QwmHqeZOVs9XpkmPb2iYpZ35OBArjgTesE8gLtEFRFg= github.com/AdguardTeam/golibs v0.9.1 h1:mHSN4LfaY1uGmHPsl97paAND/VeSnM5r9XQ7pSYx93o=
github.com/AdguardTeam/golibs v0.9.0/go.mod h1:fCAMwPBJ8S7YMYbTWvYS+eeTLblP5E04IDtNAo7y7IY= github.com/AdguardTeam/golibs v0.9.1/go.mod h1:fCAMwPBJ8S7YMYbTWvYS+eeTLblP5E04IDtNAo7y7IY=
github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU= github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU=
github.com/AdguardTeam/urlfilter v0.14.6 h1:emqoKZElooHACYehRBYENeKVN1a/rspxiqTIMYLuoIo= github.com/AdguardTeam/urlfilter v0.14.6 h1:emqoKZElooHACYehRBYENeKVN1a/rspxiqTIMYLuoIo=
github.com/AdguardTeam/urlfilter v0.14.6/go.mod h1:klx4JbOfc4EaNb5lWLqOwfg+pVcyRukmoJRvO55lL5U= github.com/AdguardTeam/urlfilter v0.14.6/go.mod h1:klx4JbOfc4EaNb5lWLqOwfg+pVcyRukmoJRvO55lL5U=

View File

@ -8,69 +8,47 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"os"
"strings" "strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
) )
func canBindPrivilegedPorts() (can bool, err error) { func canBindPrivilegedPorts() (can bool, err error) {
return aghos.HaveAdminRights() return aghos.HaveAdminRights()
} }
// maxCheckedFileSize is the maximum acceptable length of the /etc/rc.conf file.
const maxCheckedFileSize = 1024 * 1024
func ifaceHasStaticIP(ifaceName string) (ok bool, err error) { func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
const filename = "/etc/rc.conf" const filename = "/etc/rc.conf"
var f *os.File return aghos.FileWalker(interfaceName(ifaceName).rcConfStaticConfig).Walk(filename)
f, err = os.Open(filename)
if err != nil {
return false, err
}
defer func() { err = errors.WithDeferred(err, f.Close()) }()
var r io.Reader
r, err = aghio.LimitReader(f, maxCheckedFileSize)
if err != nil {
return false, err
}
return rcConfStaticConfig(r, ifaceName)
} }
// rcConfStaticConfig checks if the interface is configured by /etc/rc.conf to // rcConfStaticConfig checks if the interface is configured by /etc/rc.conf to
// have a static IP. // have a static IP.
func rcConfStaticConfig(r io.Reader, ifaceName string) (has bool, err error) { func (n interfaceName) rcConfStaticConfig(r io.Reader) (_ []string, cont bool, err error) {
s := bufio.NewScanner(r) s := bufio.NewScanner(r)
for ifaceLinePref := fmt.Sprintf("ifconfig_%s", ifaceName); s.Scan(); { for pref := fmt.Sprintf("ifconfig_%s=", n); s.Scan(); {
line := strings.TrimSpace(s.Text()) line := strings.TrimSpace(s.Text())
if !strings.HasPrefix(line, ifaceLinePref) { if !strings.HasPrefix(line, pref) {
continue continue
} }
eqIdx := len(ifaceLinePref) cfgLeft, cfgRight := len(pref)+1, len(line)-1
if line[eqIdx] != '=' { if cfgLeft >= cfgRight {
continue continue
} }
fieldsStart, fieldsEnd := eqIdx+2, len(line)-1 // TODO(e.burkov): Expand the check to cover possible
if fieldsStart >= fieldsEnd { // configurations from man rc.conf(5).
continue fields := strings.Fields(line[cfgLeft:cfgRight])
}
fields := strings.Fields(line[fieldsStart:fieldsEnd])
if len(fields) >= 2 && if len(fields) >= 2 &&
strings.ToLower(fields[0]) == "inet" && strings.EqualFold(fields[0], "inet") &&
net.ParseIP(fields[1]) != nil { net.ParseIP(fields[1]) != nil {
return true, s.Err() return nil, false, s.Err()
} }
} }
return false, s.Err() return nil, true, s.Err()
} }
func ifaceSetStaticIP(string) (err error) { func ifaceSetStaticIP(string) (err error) {

View File

@ -12,49 +12,48 @@ import (
) )
func TestRcConfStaticConfig(t *testing.T) { func TestRcConfStaticConfig(t *testing.T) {
const ifaceName = `em0` const iface interfaceName = `em0`
const nl = "\n" const nl = "\n"
testCases := []struct { testCases := []struct {
name string name string
rcconfData string rcconfData string
wantHas bool wantCont bool
}{{ }{{
name: "simple", name: "simple",
rcconfData: `ifconfig_em0="inet 127.0.0.253 netmask 0xffffffff"` + nl, rcconfData: `ifconfig_em0="inet 127.0.0.253 netmask 0xffffffff"` + nl,
wantHas: true, wantCont: false,
}, { }, {
name: "case_insensitiveness", name: "case_insensitiveness",
rcconfData: `ifconfig_em0="InEt 127.0.0.253 NeTmAsK 0xffffffff"` + nl, rcconfData: `ifconfig_em0="InEt 127.0.0.253 NeTmAsK 0xffffffff"` + nl,
wantHas: true, wantCont: false,
}, { }, {
name: "comments_and_trash", name: "comments_and_trash",
rcconfData: `# comment 1` + nl + rcconfData: `# comment 1` + nl +
`` + nl + `` + nl +
`# comment 2` + nl + `# comment 2` + nl +
`ifconfig_em0="inet 127.0.0.253 netmask 0xffffffff"` + nl, `ifconfig_em0="inet 127.0.0.253 netmask 0xffffffff"` + nl,
wantHas: true, wantCont: false,
}, { }, {
name: "aliases", name: "aliases",
rcconfData: `ifconfig_em0_alias="inet 127.0.0.1/24"` + nl + rcconfData: `ifconfig_em0_alias="inet 127.0.0.1/24"` + nl +
`ifconfig_em0="inet 127.0.0.253 netmask 0xffffffff"` + nl, `ifconfig_em0="inet 127.0.0.253 netmask 0xffffffff"` + nl,
wantHas: true, wantCont: false,
}, { }, {
name: "incorrect_config", name: "incorrect_config",
rcconfData: `ifconfig_em0="inet6 127.0.0.253 netmask 0xffffffff"` + nl + rcconfData: `ifconfig_em0="inet6 127.0.0.253 netmask 0xffffffff"` + nl +
`ifconfig_em0="inet 127.0.0.253 net-mask 0xffffffff"` + nl +
`ifconfig_em0="inet 256.256.256.256 netmask 0xffffffff"` + nl + `ifconfig_em0="inet 256.256.256.256 netmask 0xffffffff"` + nl +
`ifconfig_em0=""` + nl, `ifconfig_em0=""` + nl,
wantHas: false, wantCont: true,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
r := strings.NewReader(tc.rcconfData) r := strings.NewReader(tc.rcconfData)
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
has, err := rcConfStaticConfig(r, ifaceName) _, cont, err := iface.rcConfStaticConfig(r)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, tc.wantHas, has) assert.Equal(t, tc.wantCont, cont)
}) })
} }
} }

View File

@ -9,130 +9,72 @@ import (
"io" "io"
"net" "net"
"os" "os"
"path/filepath"
"strings" "strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/google/renameio/maybe" "github.com/google/renameio/maybe"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
// recurrentChecker is used to check all the files which may include references // dhcpcdStaticConfig checks if interface is configured by /etc/dhcpcd.conf to
// for other ones. // have a static IP.
type recurrentChecker struct { func (n interfaceName) dhcpcdStaticConfig(r io.Reader) (subsources []string, cont bool, err error) {
// checker is the function to check if r's stream contains the desired s := bufio.NewScanner(r)
// attribute. It must return all the patterns for files which should ifaceFound := findIfaceLine(s, string(n))
// also be checked and each of them should be valid for filepath.Glob if !ifaceFound {
// function. return nil, true, s.Err()
checker func(r io.Reader, desired string) (patterns []string, has bool, err error) }
// initPath is the path of the first member in the sequence of checked
// files. for s.Scan() {
initPath string line := strings.TrimSpace(s.Text())
fields := strings.Fields(line)
if len(fields) >= 2 &&
fields[0] == "static" &&
strings.HasPrefix(fields[1], "ip_address=") {
return nil, false, s.Err()
}
if len(fields) > 0 && fields[0] == "interface" {
// Another interface found.
break
}
}
return nil, true, s.Err()
} }
// maxCheckedFileSize is the maximum length of the file that recurrentChecker // ifacesStaticConfig checks if the interface is configured by any file of
// may check. // /etc/network/interfaces format to have a static IP.
const maxCheckedFileSize = 1024 * 1024 func (n interfaceName) ifacesStaticConfig(r io.Reader) (sub []string, cont bool, err error) {
s := bufio.NewScanner(r)
// checkFile tries to open and to check single file located on the sourcePath. for s.Scan() {
func (rc *recurrentChecker) checkFile(sourcePath, desired string) ( line := strings.TrimSpace(s.Text())
subsources []string, if len(line) == 0 || line[0] == '#' {
has bool, continue
err error,
) {
var f *os.File
f, err = os.Open(sourcePath)
if err != nil {
return nil, false, err
}
defer func() { err = errors.WithDeferred(err, f.Close()) }()
var r io.Reader
r, err = aghio.LimitReader(f, maxCheckedFileSize)
if err != nil {
return nil, false, err
}
subsources, has, err = rc.checker(r, desired)
if err != nil {
return nil, false, err
}
if has {
return nil, true, nil
}
return subsources, has, nil
}
// handlePatterns parses the patterns and takes care of duplicates.
func (rc *recurrentChecker) handlePatterns(sourcesSet *stringutil.Set, patterns []string) (
subsources []string,
err error,
) {
subsources = make([]string, 0, len(patterns))
for _, p := range patterns {
var matches []string
matches, err = filepath.Glob(p)
if err != nil {
return nil, fmt.Errorf("invalid pattern %q: %w", p, err)
} }
for _, m := range matches { // TODO(e.burkov): As man page interfaces(5) says, a line may be
if sourcesSet.Has(m) { // extended across multiple lines by making the last character a
continue // backslash. Provide extended lines support.
}
sourcesSet.Add(m) fields := strings.Fields(line)
subsources = append(subsources, m) fieldsNum := len(fields)
// Man page interfaces(5) declares that interface definition
// should consist of the key word "iface" followed by interface
// name, and method at fourth field.
if fieldsNum >= 4 &&
fields[0] == "iface" && fields[1] == string(n) && fields[3] == "static" {
return nil, false, nil
}
if fieldsNum >= 2 && fields[0] == "source" {
sub = append(sub, fields[1])
} }
} }
return subsources, nil return sub, true, s.Err()
}
// check walks through all the files searching for the desired attribute.
func (rc *recurrentChecker) check(desired string) (has bool, err error) {
var i int
sources := []string{rc.initPath}
defer func() {
if i >= len(sources) {
return
}
err = errors.Annotate(err, "checking %q: %w", sources[i])
}()
var patterns, subsources []string
// The slice of sources is separate from the set of sources to keep the
// order in which the files are walked.
for sourcesSet := stringutil.NewSet(rc.initPath); i < len(sources); i++ {
patterns, has, err = rc.checkFile(sources[i], desired)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
continue
}
return false, err
}
if has {
return true, nil
}
subsources, err = rc.handlePatterns(sourcesSet, patterns)
if err != nil {
return false, err
}
sources = append(sources, subsources...)
}
return false, nil
} }
func ifaceHasStaticIP(ifaceName string) (has bool, err error) { func ifaceHasStaticIP(ifaceName string) (has bool, err error) {
@ -141,14 +83,19 @@ func ifaceHasStaticIP(ifaceName string) (has bool, err error) {
// /etc/network/interfaces doesn't, it will return true. Perhaps this // /etc/network/interfaces doesn't, it will return true. Perhaps this
// is not the most desirable behavior. // is not the most desirable behavior.
for _, rc := range []*recurrentChecker{{ iface := interfaceName(ifaceName)
checker: dhcpcdStaticConfig,
initPath: "/etc/dhcpcd.conf", for _, pair := range []struct {
aghos.FileWalker
filename string
}{{
FileWalker: iface.dhcpcdStaticConfig,
filename: "/etc/dhcpcd.conf",
}, { }, {
checker: ifacesStaticConfig, FileWalker: iface.ifacesStaticConfig,
initPath: "/etc/network/interfaces", filename: "/etc/network/interfaces",
}} { }} {
has, err = rc.check(ifaceName) has, err = pair.Walk(pair.filename)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -183,67 +130,6 @@ func findIfaceLine(s *bufio.Scanner, name string) (ok bool) {
return false return false
} }
// dhcpcdStaticConfig checks if interface is configured by /etc/dhcpcd.conf to
// have a static IP.
func dhcpcdStaticConfig(r io.Reader, ifaceName string) (subsources []string, has bool, err error) {
s := bufio.NewScanner(r)
ifaceFound := findIfaceLine(s, ifaceName)
if !ifaceFound {
return nil, false, s.Err()
}
for s.Scan() {
line := strings.TrimSpace(s.Text())
fields := strings.Fields(line)
if len(fields) >= 2 &&
fields[0] == "static" &&
strings.HasPrefix(fields[1], "ip_address=") {
return nil, true, s.Err()
}
if len(fields) > 0 && fields[0] == "interface" {
// Another interface found.
break
}
}
return nil, false, s.Err()
}
// ifacesStaticConfig checks if the interface is configured by any file of
// /etc/network/interfaces format to have a static IP.
func ifacesStaticConfig(r io.Reader, ifaceName string) (subsources []string, has bool, err error) {
s := bufio.NewScanner(r)
for s.Scan() {
line := strings.TrimSpace(s.Text())
if len(line) == 0 || line[0] == '#' {
continue
}
// TODO(e.burkov): As man page interfaces(5) says, a line may be
// extended across multiple lines by making the last character a
// backslash. Provide extended lines and "source-directory"
// stanzas support.
fields := strings.Fields(line)
fieldsNum := len(fields)
// Man page interfaces(5) declares that interface definition
// should consist of the key word "iface" followed by interface
// name, and method at fourth field.
if fieldsNum >= 4 &&
fields[0] == "iface" && fields[1] == ifaceName && fields[3] == "static" {
return nil, true, nil
}
if fieldsNum >= 2 && fields[0] == "source" {
subsources = append(subsources, fields[1])
}
}
return subsources, false, s.Err()
}
// ifaceSetStaticIP configures the system to retain its current IP on the // ifaceSetStaticIP configures the system to retain its current IP on the
// interface through dhcpdc.conf. // interface through dhcpdc.conf.
func ifaceSetStaticIP(ifaceName string) (err error) { func ifaceSetStaticIP(ifaceName string) (err error) {

View File

@ -12,101 +12,90 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestRecurrentChecker(t *testing.T) {
c := &recurrentChecker{
checker: ifacesStaticConfig,
initPath: "./testdata/include-subsources",
}
has, err := c.check("sample_name")
require.NoError(t, err)
assert.True(t, has)
has, err = c.check("another_name")
require.NoError(t, err)
assert.False(t, has)
}
const nl = "\n" const nl = "\n"
func TestDHCPCDStaticConfig(t *testing.T) { func TestDHCPCDStaticConfig(t *testing.T) {
const iface interfaceName = `wlan0`
testCases := []struct { testCases := []struct {
name string name string
data []byte data []byte
want bool wantCont bool
}{{ }{{
name: "has_not", name: "has_not",
data: []byte(`#comment` + nl + data: []byte(`#comment` + nl +
`# comment` + nl + `# comment` + nl +
`interface eth0` + nl + `interface eth0` + nl +
`static ip_address=192.168.0.1/24` + nl + `static ip_address=192.168.0.1/24` + nl +
`# interface wlan0` + nl + `# interface ` + iface + nl +
`static ip_address=192.168.1.1/24` + nl + `static ip_address=192.168.1.1/24` + nl +
`# comment` + nl, `# comment` + nl,
), ),
want: false, wantCont: true,
}, { }, {
name: "has", name: "has",
data: []byte(`#comment` + nl + data: []byte(`#comment` + nl +
`# comment` + nl + `# comment` + nl +
`interface eth0` + nl + `interface eth0` + nl +
`static ip_address=192.168.0.1/24` + nl + `static ip_address=192.168.0.1/24` + nl +
`# interface wlan0` + nl + `# interface ` + iface + nl +
`static ip_address=192.168.1.1/24` + nl + `static ip_address=192.168.1.1/24` + nl +
`# comment` + nl + `# comment` + nl +
`interface wlan0` + nl + `interface ` + iface + nl +
`# comment` + nl + `# comment` + nl +
`static ip_address=192.168.2.1/24` + nl, `static ip_address=192.168.2.1/24` + nl,
), ),
want: true, wantCont: false,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
r := bytes.NewReader(tc.data) r := bytes.NewReader(tc.data)
_, has, err := dhcpcdStaticConfig(r, "wlan0") _, cont, err := iface.dhcpcdStaticConfig(r)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, tc.want, has) assert.Equal(t, tc.wantCont, cont)
}) })
} }
} }
func TestIfacesStaticConfig(t *testing.T) { func TestIfacesStaticConfig(t *testing.T) {
const iface interfaceName = `enp0s3`
testCases := []struct { testCases := []struct {
name string name string
data []byte data []byte
want bool wantCont bool
wantPatterns []string wantPatterns []string
}{{ }{{
name: "has_not", name: "has_not",
data: []byte(`allow-hotplug enp0s3` + nl + data: []byte(`allow-hotplug ` + iface + nl +
`#iface enp0s3 inet static` + nl + `#iface enp0s3 inet static` + nl +
`# address 192.168.0.200` + nl + `# address 192.168.0.200` + nl +
`# netmask 255.255.255.0` + nl + `# netmask 255.255.255.0` + nl +
`# gateway 192.168.0.1` + nl + `# gateway 192.168.0.1` + nl +
`iface enp0s3 inet dhcp` + nl, `iface ` + iface + ` inet dhcp` + nl,
), ),
want: false, wantCont: true,
wantPatterns: []string{}, wantPatterns: []string{},
}, { }, {
name: "has", name: "has",
data: []byte(`allow-hotplug enp0s3` + nl + data: []byte(`allow-hotplug ` + iface + nl +
`iface enp0s3 inet static` + nl + `iface ` + iface + ` inet static` + nl +
` address 192.168.0.200` + nl + ` address 192.168.0.200` + nl +
` netmask 255.255.255.0` + nl + ` netmask 255.255.255.0` + nl +
` gateway 192.168.0.1` + nl + ` gateway 192.168.0.1` + nl +
`#iface enp0s3 inet dhcp` + nl, `#iface ` + iface + ` inet dhcp` + nl,
), ),
want: true, wantCont: false,
wantPatterns: []string{}, wantPatterns: []string{},
}, { }, {
name: "return_patterns", name: "return_patterns",
data: []byte(`source hello` + nl + data: []byte(`source hello` + nl +
`source world` + nl + `source world` + nl +
`#iface enp0s3 inet static` + nl, `#iface ` + iface + ` inet static` + nl,
), ),
want: false, wantCont: true,
wantPatterns: []string{"hello", "world"}, wantPatterns: []string{"hello", "world"},
}, { }, {
// This one tests if the first found valid interface prevents // This one tests if the first found valid interface prevents
@ -114,19 +103,19 @@ func TestIfacesStaticConfig(t *testing.T) {
name: "ignore_patterns", name: "ignore_patterns",
data: []byte(`source hello` + nl + data: []byte(`source hello` + nl +
`source world` + nl + `source world` + nl +
`iface enp0s3 inet static` + nl, `iface ` + iface + ` inet static` + nl,
), ),
want: true, wantCont: false,
wantPatterns: []string{}, wantPatterns: []string{},
}} }}
for _, tc := range testCases { for _, tc := range testCases {
r := bytes.NewReader(tc.data)
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
r := bytes.NewReader(tc.data) patterns, has, err := iface.ifacesStaticConfig(r)
patterns, has, err := ifacesStaticConfig(r, "enp0s3")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, tc.want, has) assert.Equal(t, tc.wantCont, has)
assert.ElementsMatch(t, tc.wantPatterns, patterns) assert.ElementsMatch(t, tc.wantPatterns, patterns)
}) })
} }

View File

@ -8,61 +8,34 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"os"
"strings" "strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
) )
func canBindPrivilegedPorts() (can bool, err error) { func canBindPrivilegedPorts() (can bool, err error) {
return aghos.HaveAdminRights() return aghos.HaveAdminRights()
} }
// maxCheckedFileSize is the maximum acceptable length of the /etc/hostname.*
// files.
const maxCheckedFileSize = 1024 * 1024
func ifaceHasStaticIP(ifaceName string) (ok bool, err error) { func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
const filenameFmt = "/etc/hostname.%s" filename := fmt.Sprintf("/etc/hostname.%s", ifaceName)
filename := fmt.Sprintf(filenameFmt, ifaceName) return aghos.FileWalker(hostnameIfStaticConfig).Walk(filename)
var f *os.File
if f, err = os.Open(filename); err != nil {
if errors.Is(err, os.ErrNotExist) {
err = nil
}
return false, err
}
defer func() { err = errors.WithDeferred(err, f.Close()) }()
var r io.Reader
r, err = aghio.LimitReader(f, maxCheckedFileSize)
if err != nil {
return false, err
}
return hostnameIfStaticConfig(r)
} }
// hostnameIfStaticConfig checks if the interface is configured by // hostnameIfStaticConfig checks if the interface is configured by
// /etc/hostname.* to have a static IP. // /etc/hostname.* to have a static IP.
// func hostnameIfStaticConfig(r io.Reader) (_ []string, ok bool, err error) {
// TODO(e.burkov): The platform-dependent functions to check the static IP
// address configured are rather similar. Think about unifying common parts.
func hostnameIfStaticConfig(r io.Reader) (has bool, err error) {
s := bufio.NewScanner(r) s := bufio.NewScanner(r)
for s.Scan() { for s.Scan() {
line := strings.TrimSpace(s.Text()) line := strings.TrimSpace(s.Text())
fields := strings.Fields(line) fields := strings.Fields(line)
if len(fields) >= 2 && fields[0] == "inet" && net.ParseIP(fields[1]) != nil { if len(fields) >= 2 && fields[0] == "inet" && net.ParseIP(fields[1]) != nil {
return true, s.Err() return nil, true, s.Err()
} }
} }
return false, s.Err() return nil, false, s.Err()
} }
func ifaceSetStaticIP(string) (err error) { func ifaceSetStaticIP(string) (err error) {

View File

@ -43,7 +43,7 @@ func TestHostnameIfStaticConfig(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
r := strings.NewReader(tc.rcconfData) r := strings.NewReader(tc.rcconfData)
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
has, err := hostnameIfStaticConfig(r) _, has, err := hostnameIfStaticConfig(r)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, tc.wantHas, has) assert.Equal(t, tc.wantHas, has)

View File

@ -0,0 +1,8 @@
//go:build openbsd || freebsd || linux
// +build openbsd freebsd linux
package aghnet
// interfaceName is a string containing network interface's name. The name is
// used in file walking methods.
type interfaceName string

View File

@ -0,0 +1,119 @@
package aghos
import (
"fmt"
"io"
"os"
"path/filepath"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/stringutil"
)
// FileWalker is the signature of a function called for files in the file tree.
// As opposed to filepath.Walk it only walk the files (not directories) matching
// the provided pattern and those returned by function itself. All patterns
// should be valid for filepath.Glob. If cont is false, the walking terminates.
// Each opened file is also limited for reading to MaxWalkedFileSize.
//
// TODO(e.burkov): Consider moving to the separate package like pathutil.
//
// TODO(e.burkov): Think about passing filename or any additional data.
type FileWalker func(r io.Reader) (patterns []string, cont bool, err error)
// MaxWalkedFileSize is the maximum length of the file that FileWalker can
// check.
const MaxWalkedFileSize = 1024 * 1024
// checkFile tries to open and process a single file located on sourcePath.
func checkFile(c FileWalker, sourcePath string) (patterns []string, cont bool, err error) {
var f *os.File
f, err = os.Open(sourcePath)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
// Ignore non-existing files since this may only happen
// when the file was removed after filepath.Glob matched
// it.
return nil, true, nil
}
return nil, false, err
}
defer func() { err = errors.WithDeferred(err, f.Close()) }()
var r io.Reader
// Ignore the error since LimitReader function returns error only if
// passed limit value is less than zero, but the constant used.
//
// TODO(e.burkov): Make variable.
r, _ = aghio.LimitReader(f, MaxWalkedFileSize)
return c(r)
}
// handlePatterns parses the patterns and ignores duplicates using srcSet.
// srcSet must be non-nil.
func handlePatterns(srcSet *stringutil.Set, patterns ...string) (sub []string, err error) {
sub = make([]string, 0, len(patterns))
for _, p := range patterns {
var matches []string
matches, err = filepath.Glob(p)
if err != nil {
// Enrich error with the pattern because filepath.Glob
// doesn't do it.
return nil, fmt.Errorf("invalid pattern %q: %w", p, err)
}
for _, m := range matches {
if srcSet.Has(m) {
continue
}
srcSet.Add(m)
sub = append(sub, m)
}
}
return sub, nil
}
// Walk starts walking the files defined by initPattern. It only returns true
// if c signed to stop walking.
func (c FileWalker) Walk(initPattern string) (ok bool, err error) {
// The slice of sources keeps the order in which the files are walked
// since srcSet.Values() returns strings in undefined order.
srcSet := stringutil.NewSet()
var src []string
src, err = handlePatterns(srcSet, initPattern)
if err != nil {
return false, err
}
var filename string
defer func() { err = errors.Annotate(err, "checking %q: %w", filename) }()
for i := 0; i < len(src); i++ {
var patterns []string
var cont bool
filename = src[i]
patterns, cont, err = checkFile(c, src[i])
if err != nil {
return false, err
}
if !cont {
return true, nil
}
var subsrc []string
subsrc, err = handlePatterns(srcSet, patterns...)
if err != nil {
return false, err
}
src = append(src, subsrc...)
}
return false, nil
}

View File

@ -0,0 +1,209 @@
package aghos
import (
"bufio"
"io"
"io/fs"
"os"
"path/filepath"
"testing"
"github.com/AdguardTeam/golibs/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testFSDir maps entries' names to entries which should either be a testFSDir
// or byte slice.
type testFSDir map[string]interface{}
// testFSGen is used to generate a temporary filesystem consisting of
// directories and plain text files from itself.
type testFSGen testFSDir
// gen returns the name of top directory of the generated filesystem.
func (g testFSGen) gen(t *testing.T) (dirName string) {
t.Helper()
dirName = t.TempDir()
g.rangeThrough(t, dirName)
return dirName
}
func (g testFSGen) rangeThrough(t *testing.T, dirName string) {
const perm fs.FileMode = 0o777
for k, e := range g {
switch e := e.(type) {
case []byte:
require.NoError(t, os.WriteFile(filepath.Join(dirName, k), e, perm))
case testFSDir:
newDir := filepath.Join(dirName, k)
require.NoError(t, os.Mkdir(newDir, perm))
testFSGen(e).rangeThrough(t, newDir)
default:
t.Fatalf("unexpected entry type %T", e)
}
}
}
func TestFileWalker_Walk(t *testing.T) {
const attribute = `000`
makeFileWalker := func(dirName string) (fw FileWalker) {
return func(r io.Reader) (patterns []string, cont bool, err error) {
s := bufio.NewScanner(r)
for s.Scan() {
line := s.Text()
if line == attribute {
return nil, false, nil
}
if len(line) != 0 {
patterns = append(patterns, filepath.Join(dirName, line))
}
}
return patterns, true, s.Err()
}
}
const nl = "\n"
testCases := []struct {
name string
testFS testFSGen
initPattern string
want bool
}{{
name: "simple",
testFS: testFSGen{
"simple_0001.txt": []byte(attribute + nl),
},
initPattern: "simple_0001.txt",
want: true,
}, {
name: "chain",
testFS: testFSGen{
"chain_0001.txt": []byte(`chain_0002.txt` + nl),
"chain_0002.txt": []byte(`chain_0003.txt` + nl),
"chain_0003.txt": []byte(attribute + nl),
},
initPattern: "chain_0001.txt",
want: true,
}, {
name: "several",
testFS: testFSGen{
"several_0001.txt": []byte(`several_*` + nl),
"several_0002.txt": []byte(`several_0001.txt` + nl),
"several_0003.txt": []byte(attribute + nl),
},
initPattern: "several_0001.txt",
want: true,
}, {
name: "no",
testFS: testFSGen{
"no_0001.txt": []byte(nl),
"no_0002.txt": []byte(nl),
"no_0003.txt": []byte(nl),
},
initPattern: "no_*",
want: false,
}, {
name: "subdirectory",
testFS: testFSGen{
"dir": testFSDir{
"subdir_0002.txt": []byte(attribute + nl),
},
"subdir_0001.txt": []byte(`dir/*`),
},
initPattern: "subdir_0001.txt",
want: true,
}}
for _, tc := range testCases {
testDir := tc.testFS.gen(t)
fw := makeFileWalker(testDir)
t.Run(tc.name, func(t *testing.T) {
ok, err := fw.Walk(filepath.Join(testDir, tc.initPattern))
require.NoError(t, err)
assert.Equal(t, tc.want, ok)
})
}
t.Run("pattern_malformed", func(t *testing.T) {
ok, err := makeFileWalker("").Walk("[]")
require.Error(t, err)
assert.False(t, ok)
assert.ErrorIs(t, err, filepath.ErrBadPattern)
})
t.Run("bad_filename", func(t *testing.T) {
dir := testFSGen{
"bad_filename.txt": []byte("[]"),
}.gen(t)
fw := FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) {
s := bufio.NewScanner(r)
for s.Scan() {
patterns = append(patterns, s.Text())
}
return patterns, true, s.Err()
})
ok, err := fw.Walk(filepath.Join(dir, "bad_filename.txt"))
require.Error(t, err)
assert.False(t, ok)
assert.ErrorIs(t, err, filepath.ErrBadPattern)
})
t.Run("itself_error", func(t *testing.T) {
const rerr errors.Error = "returned error"
dir := testFSGen{
"mockfile.txt": []byte(`mockdata`),
}.gen(t)
ok, err := FileWalker(func(r io.Reader) (patterns []string, ok bool, err error) {
return nil, true, rerr
}).Walk(filepath.Join(dir, "*"))
require.Error(t, err)
require.False(t, ok)
assert.ErrorIs(t, err, rerr)
})
}
func TestWalkerFunc_CheckFile(t *testing.T) {
t.Run("non-existing", func(t *testing.T) {
_, ok, err := checkFile(nil, "lol")
require.NoError(t, err)
assert.True(t, ok)
})
t.Run("invalid_argument", func(t *testing.T) {
const badPath = "\x00"
_, ok, err := checkFile(nil, badPath)
require.Error(t, err)
assert.False(t, ok)
// TODO(e.burkov): Use assert.ErrorsIs within the error from
// less platform-dependent package instead of syscall.EINVAL.
//
// See https://github.com/golang/go/issues/46849 and
// https://github.com/golang/go/issues/30322.
pathErr := &os.PathError{}
require.ErrorAs(t, err, &pathErr)
assert.Equal(t, "open", pathErr.Op)
assert.Equal(t, badPath, pathErr.Path)
})
}

View File

@ -4,11 +4,11 @@
package aghos package aghos
import ( import (
"bytes" "io"
"os" "os"
"path/filepath"
"strings"
"syscall" "syscall"
"github.com/AdguardTeam/golibs/stringutil"
) )
func setRlimit(val uint64) (err error) { func setRlimit(val uint64) (err error) {
@ -30,37 +30,20 @@ func sendProcessSignal(pid int, sig syscall.Signal) error {
} }
func isOpenWrt() (ok bool) { func isOpenWrt() (ok bool) {
const etcDir = "/etc" var err error
ok, err = FileWalker(func(r io.Reader) (_ []string, cont bool, err error) {
const osNameData = "openwrt"
dirEnts, err := os.ReadDir(etcDir) // This use of ReadAll is now safe, because FileWalker's Walk()
if err != nil { // have limited r.
return false var data []byte
} data, err = io.ReadAll(r)
// fNameSubstr is a part of a name of the desired file.
const fNameSubstr = "release"
osNameData := []byte("OpenWrt")
for _, dirEnt := range dirEnts {
if dirEnt.IsDir() {
continue
}
fn := dirEnt.Name()
if !strings.Contains(fn, fNameSubstr) {
continue
}
var body []byte
body, err = os.ReadFile(filepath.Join(etcDir, fn))
if err != nil { if err != nil {
continue return nil, false, err
} }
if bytes.Contains(body, osNameData) { return nil, !stringutil.ContainsFold(string(data), osNameData), nil
return true }).Walk("/etc/*release*")
}
}
return false return err == nil && ok
} }

View File

@ -2,10 +2,9 @@ package querylog
import ( import (
"strings" "strings"
"unicode"
"unicode/utf8"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/stringutil"
) )
type criterionType int type criterionType int
@ -69,37 +68,6 @@ func ctDomainOrClientCaseStrict(
strings.EqualFold(name, term) strings.EqualFold(name, term)
} }
// containsFold reports whehter s contains, ignoring letter case, substr.
//
// TODO(a.garipov): Move to aghstrings if needed elsewhere.
func containsFold(s, substr string) (ok bool) {
sLen, substrLen := len(s), len(substr)
if sLen < substrLen {
return false
}
if sLen == substrLen {
return strings.EqualFold(s, substr)
}
first, _ := utf8.DecodeRuneInString(substr)
firstFolded := unicode.SimpleFold(first)
for i := 0; i != -1 && len(s) >= len(substr); {
if strings.EqualFold(s[:substrLen], substr) {
return true
}
i = strings.IndexFunc(s[1:], func(r rune) (eq bool) {
return r == first || r == firstFolded
})
s = s[1+i:]
}
return false
}
func ctDomainOrClientCaseNonStrict( func ctDomainOrClientCaseNonStrict(
term string, term string,
asciiTerm string, asciiTerm string,
@ -108,11 +76,11 @@ func ctDomainOrClientCaseNonStrict(
host string, host string,
ip string, ip string,
) (ok bool) { ) (ok bool) {
return containsFold(clientID, term) || return stringutil.ContainsFold(clientID, term) ||
containsFold(host, term) || stringutil.ContainsFold(host, term) ||
(asciiTerm != "" && containsFold(host, asciiTerm)) || (asciiTerm != "" && stringutil.ContainsFold(host, asciiTerm)) ||
containsFold(ip, term) || stringutil.ContainsFold(ip, term) ||
containsFold(name, term) stringutil.ContainsFold(name, term)
} }
// quickMatch quickly checks if the line matches the given search criterion. // quickMatch quickly checks if the line matches the given search criterion.

View File

@ -1,121 +0,0 @@
package querylog
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestContainsFold(t *testing.T) {
testCases := []struct {
name string
inS string
inSubstr string
want bool
}{{
name: "empty",
inS: "",
inSubstr: "",
want: true,
}, {
name: "shorter",
inS: "a",
inSubstr: "abc",
want: false,
}, {
name: "same_len_true",
inS: "abc",
inSubstr: "abc",
want: true,
}, {
name: "same_len_true_fold",
inS: "abc",
inSubstr: "aBc",
want: true,
}, {
name: "same_len_false",
inS: "abc",
inSubstr: "def",
want: false,
}, {
name: "longer_true",
inS: "abcdedef",
inSubstr: "def",
want: true,
}, {
name: "longer_false",
inS: "abcded",
inSubstr: "ghi",
want: false,
}, {
name: "longer_true_fold",
inS: "abcdedef",
inSubstr: "dEf",
want: true,
}, {
name: "longer_false_fold",
inS: "abcded",
inSubstr: "gHi",
want: false,
}, {
name: "longer_true_cyr_fold",
inS: "абвгдедеё",
inSubstr: "дЕЁ",
want: true,
}, {
name: "longer_false_cyr_fold",
inS: "абвгдедеё",
inSubstr: "жЗИ",
want: false,
}, {
name: "no_letters_true",
inS: "1.2.3.4",
inSubstr: "2.3.4",
want: true,
}, {
name: "no_letters_false",
inS: "1.2.3.4",
inSubstr: "2.3.5",
want: false,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.want {
assert.True(t, containsFold(tc.inS, tc.inSubstr))
} else {
assert.False(t, containsFold(tc.inS, tc.inSubstr))
}
})
}
}
var sink bool
func BenchmarkContainsFold(b *testing.B) {
const s = "aaahBbBhccchDDDeEehFfFhGGGhHhh"
const substr = "HHH"
// Compare our implementation of containsFold against a stupid solution
// of calling strings.ToLower and strings.Contains.
b.Run("containsfold", func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
sink = containsFold(s, substr)
}
assert.True(b, sink)
})
b.Run("tolower_contains", func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
sink = strings.Contains(strings.ToLower(s), strings.ToLower(substr))
}
assert.True(b, sink)
})
}