From 3b2f5d7842a290a124a04489b3cca20f02a180d4 Mon Sep 17 00:00:00 2001 From: Eugene Burkov Date: Mon, 22 Mar 2021 16:46:36 +0300 Subject: [PATCH] Pull request: 2704 local resolvers vol.1 Merge in DNS/adguard-home from 2704-local-addresses-vol.1 to master Updates #2704. Updates #2829. Updates #2846. Squashed commit of the following: commit 9a49b3d27edcb30da7f16a065226907833b1dc81 Author: Eugene Burkov Date: Mon Mar 22 15:39:17 2021 +0300 aghnet: imp docs and logging commit 74f95a29c55b9e732276601b0ecc63fb7c3a9f9e Author: Eugene Burkov Date: Fri Mar 19 20:56:51 2021 +0300 all: fix friday evening mistakes commit 0e2066bc5c16ed807fa601780b99e154502361a9 Author: Eugene Burkov Date: Fri Mar 19 20:51:15 2021 +0300 all: upd testify, imp code quality commit 8237c50b670c58361ccf7adec3ff2452b1196677 Author: Eugene Burkov Date: Fri Mar 19 20:19:29 2021 +0300 aghnet: imp test naming commit 14eb1e189339554c0a6d38e2ba7a93917774ebab Author: Eugene Burkov Date: Fri Mar 19 19:41:43 2021 +0300 aghnet: isolate windows-specific functionality commit d461ac8b18c187999da3e3aba116571b7ebe6785 Author: Eugene Burkov Date: Fri Mar 19 14:50:05 2021 +0300 aghnet: imp code quality commit d0ee01cb1f8613de2085c0f2f2f396e46beb52a5 Author: Eugene Burkov Date: Fri Mar 19 11:59:10 2021 +0300 all: mv funcs to agherr, mk system resolvers getter --- go.mod | 2 +- go.sum | 2 + internal/agherr/agherr.go | 18 ++ internal/agherr/agherr_test.go | 38 +++++ internal/aghnet/net.go | 22 +++ internal/aghnet/net_darwin.go | 10 +- internal/aghnet/systemresolvers.go | 78 +++++++++ internal/aghnet/systemresolvers_others.go | 96 +++++++++++ .../aghnet/systemresolvers_others_test.go | 74 ++++++++ internal/aghnet/systemresolvers_test.go | 33 ++++ internal/aghnet/systemresolvers_windows.go | 158 ++++++++++++++++++ .../aghnet/systemresolvers_windows_test.go | 7 + internal/aghos/os.go | 23 ++- internal/aghtest/upstream.go | 7 +- internal/home/control.go | 19 +-- internal/home/controlinstall.go | 9 +- internal/home/service.go | 4 +- internal/querylog/qlogreader_test.go | 3 +- internal/util/helpers.go | 13 -- 19 files changed, 568 insertions(+), 48 deletions(-) create mode 100644 internal/aghnet/systemresolvers.go create mode 100644 internal/aghnet/systemresolvers_others.go create mode 100644 internal/aghnet/systemresolvers_others_test.go create mode 100644 internal/aghnet/systemresolvers_test.go create mode 100644 internal/aghnet/systemresolvers_windows.go create mode 100644 internal/aghnet/systemresolvers_windows_test.go diff --git a/go.mod b/go.mod index c6eb56a2..c339c991 100644 --- a/go.mod +++ b/go.mod @@ -29,7 +29,7 @@ require ( github.com/satori/go.uuid v1.2.0 github.com/sirupsen/logrus v1.8.1 // indirect github.com/spf13/cobra v1.1.3 // indirect - github.com/stretchr/testify v1.6.1 + github.com/stretchr/testify v1.7.0 github.com/ti-mo/netfilter v0.4.0 github.com/u-root/u-root v7.0.0+incompatible go.etcd.io/bbolt v1.3.5 diff --git a/go.sum b/go.sum index c26a26cd..a01f82f9 100644 --- a/go.sum +++ b/go.sum @@ -413,6 +413,8 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/ti-mo/netfilter v0.2.0/go.mod h1:8GbBGsY/8fxtyIdfwy29JiluNcPK4K7wIT+x42ipqUU= diff --git a/internal/agherr/agherr.go b/internal/agherr/agherr.go index eb61206f..fd3cd830 100644 --- a/internal/agherr/agherr.go +++ b/internal/agherr/agherr.go @@ -4,6 +4,8 @@ package agherr import ( "fmt" "strings" + + "github.com/AdguardTeam/golibs/log" ) // Error is the constant error type. @@ -95,6 +97,8 @@ type wrapper interface { // } // // msg must contain the final ": %w" verb. +// +// TODO(a.garipov): Clearify the function usage. func Annotate(msg string, errPtr *error, args ...interface{}) { if errPtr == nil { return @@ -107,3 +111,17 @@ func Annotate(msg string, errPtr *error, args ...interface{}) { *errPtr = fmt.Errorf(msg, args...) } } + +// LogPanic is a convinient helper function to log a panic in a goroutine. It +// should not be used where proper error handling is required. +func LogPanic(prefix string) { + if v := recover(); v != nil { + if prefix != "" { + log.Error("%s: recovered from panic: %v", prefix, v) + + return + } + + log.Error("recovered from panic: %v", v) + } +} diff --git a/internal/agherr/agherr_test.go b/internal/agherr/agherr_test.go index b9f3183c..a951010f 100644 --- a/internal/agherr/agherr_test.go +++ b/internal/agherr/agherr_test.go @@ -1,10 +1,12 @@ package agherr import ( + "bytes" "errors" "fmt" "testing" + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -120,3 +122,39 @@ func TestAnnotate(t *testing.T) { assert.Equal(t, wantMsg, err.Error()) }) } + +func TestLogPanic(t *testing.T) { + buf := &bytes.Buffer{} + aghtest.ReplaceLogWriter(t, buf) + + t.Run("prefix", func(t *testing.T) { + const ( + panicMsg = "spooky!" + prefix = "packagename" + errWithNoPrefix = "[error] recovered from panic: spooky!" + errWithPrefix = "[error] packagename: recovered from panic: spooky!" + ) + + panicFunc := func(prefix string) { + defer LogPanic(prefix) + + panic(panicMsg) + } + + panicFunc("") + assert.Contains(t, buf.String(), errWithNoPrefix) + buf.Reset() + + panicFunc(prefix) + assert.Contains(t, buf.String(), errWithPrefix) + buf.Reset() + }) + + t.Run("don't_panic", func(t *testing.T) { + require.NotPanics(t, func() { + defer LogPanic("") + }) + + assert.Empty(t, buf.String()) + }) +} diff --git a/internal/aghnet/net.go b/internal/aghnet/net.go index 9bce75fd..d49e5573 100644 --- a/internal/aghnet/net.go +++ b/internal/aghnet/net.go @@ -251,3 +251,25 @@ func ErrorIsAddrInUse(err error) bool { return errErrno == syscall.EADDRINUSE } + +// SplitHost is a wrapper for net.SplitHostPort for the cases when the hostport +// does not necessarily contain a port. +func SplitHost(hostport string) (host string, err error) { + host, _, err = net.SplitHostPort(hostport) + if err != nil { + // Check for the missing port error. If it is that error, just + // use the host as is. + // + // See the source code for net.SplitHostPort. + const missingPort = "missing port in address" + + addrErr := &net.AddrError{} + if !errors.As(err, &addrErr) || addrErr.Err != missingPort { + return "", err + } + + host = hostport + } + + return host, nil +} diff --git a/internal/aghnet/net_darwin.go b/internal/aghnet/net_darwin.go index 56fa51e1..926e87e5 100644 --- a/internal/aghnet/net_darwin.go +++ b/internal/aghnet/net_darwin.go @@ -9,7 +9,7 @@ import ( "regexp" "strings" - "github.com/AdguardTeam/AdGuardHome/internal/util" + "github.com/AdguardTeam/AdGuardHome/internal/aghos" ) // hardwarePortInfo - information obtained using MacOS networksetup @@ -47,7 +47,7 @@ func getCurrentHardwarePortInfo(ifaceName string) (hardwarePortInfo, error) { // it returns a map where the key is the interface name, and the value is the "hardware port" // returns nil if it fails to parse the output func getNetworkSetupHardwareReports() map[string]string { - _, out, err := util.RunCommand("networksetup", "-listallhardwareports") + _, out, err := aghos.RunCommand("networksetup", "-listallhardwareports") if err != nil { return nil } @@ -72,7 +72,7 @@ func getNetworkSetupHardwareReports() map[string]string { func getHardwarePortInfo(hardwarePort string) (hardwarePortInfo, error) { h := hardwarePortInfo{} - _, out, err := util.RunCommand("networksetup", "-getinfo", hardwarePort) + _, out, err := aghos.RunCommand("networksetup", "-getinfo", hardwarePort) if err != nil { return h, err } @@ -116,7 +116,7 @@ func ifaceSetStaticIP(ifaceName string) (err error) { args = append(args, dnsAddrs...) // Setting DNS servers is necessary when configuring a static IP - code, _, err := util.RunCommand("networksetup", args...) + code, _, err := aghos.RunCommand("networksetup", args...) if err != nil { return err } @@ -125,7 +125,7 @@ func ifaceSetStaticIP(ifaceName string) (err error) { } // Actually configures hardware port to have static IP - code, _, err = util.RunCommand("networksetup", "-setmanual", + code, _, err = aghos.RunCommand("networksetup", "-setmanual", portInfo.name, portInfo.ip, portInfo.subnet, portInfo.gatewayIP) if err != nil { return err diff --git a/internal/aghnet/systemresolvers.go b/internal/aghnet/systemresolvers.go new file mode 100644 index 00000000..4a0ae6ca --- /dev/null +++ b/internal/aghnet/systemresolvers.go @@ -0,0 +1,78 @@ +package aghnet + +import ( + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/agherr" + "github.com/AdguardTeam/golibs/log" +) + +// DefaultRefreshIvl is the default period of time between refreshing cached +// addresses. +// const DefaultRefreshIvl = 5 * time.Minute + +// HostGenFunc is the signature for functions generating fake hostnames. The +// implementation must be safe for concurrent use. +type HostGenFunc func() (host string) + +// unit is an alias for an existing map value. +type unit = struct{} + +// SystemResolvers helps to work with local resolvers' addresses provided by OS. +type SystemResolvers interface { + // Get returns the slice of local resolvers' addresses. + // It should be safe for concurrent use. + Get() (rs []string) + // Refresh refreshes the local resolvers' addresses cache. It should be + // safe for concurrent use. + Refresh() (err error) +} + +const ( + // fakeDialErr is an error which dialFunc is expected to return. + fakeDialErr agherr.Error = "this error signals the successful dialFunc work" + + // badAddrPassedErr is returned when dialFunc can't parse an IP address. + badAddrPassedErr agherr.Error = "the passed string is not a valid IP address" +) + +// refreshWithTicker refreshes the cache of sr after each tick form tickCh. +func refreshWithTicker(sr SystemResolvers, tickCh <-chan time.Time) { + defer agherr.LogPanic("systemResolvers") + + // TODO(e.burkov): Implement a functionality to stop ticker. + for range tickCh { + err := sr.Refresh() + if err != nil { + log.Error("systemResolvers: error in refreshing goroutine: %s", err) + + continue + } + + log.Debug("systemResolvers: local addresses cache is refreshed") + } +} + +// NewSystemResolvers returns a SystemResolvers with the cache refresh rate +// defined by refreshIvl. It disables auto-resfreshing if refreshIvl is 0. If +// nil is passed for hostGenFunc, the default generator will be used. +func NewSystemResolvers( + refreshIvl time.Duration, + hostGenFunc HostGenFunc, +) (sr SystemResolvers, err error) { + sr = newSystemResolvers(refreshIvl, hostGenFunc) + + // Fill cache. + err = sr.Refresh() + if err != nil { + return nil, err + } + + if refreshIvl > 0 { + ticker := time.NewTicker(refreshIvl) + + go refreshWithTicker(sr, ticker.C) + } + + return sr, nil +} diff --git a/internal/aghnet/systemresolvers_others.go b/internal/aghnet/systemresolvers_others.go new file mode 100644 index 00000000..ad67cfdb --- /dev/null +++ b/internal/aghnet/systemresolvers_others.go @@ -0,0 +1,96 @@ +// +build !windows + +package aghnet + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/agherr" +) + +// defaultHostGen is the default method of generating host for Refresh. +func defaultHostGen() (host string) { + // TODO(e.burkov): Use strings.Builder. + return fmt.Sprintf("test%d.org", time.Now().UnixNano()) +} + +// systemResolvers is a default implementation of SystemResolvers interface. +type systemResolvers struct { + resolver *net.Resolver + hostGenFunc HostGenFunc + + // addrs is the map that contains cached local resolvers' addresses. + addrs map[string]unit + addrsLock sync.RWMutex +} + +func (sr *systemResolvers) Refresh() (err error) { + defer agherr.Annotate("systemResolvers: %w", &err) + + _, err = sr.resolver.LookupHost(context.Background(), sr.hostGenFunc()) + dnserr := &net.DNSError{} + if errors.As(err, &dnserr) && dnserr.Err == fakeDialErr.Error() { + return nil + } + + return err +} + +func newSystemResolvers(refreshIvl time.Duration, hostGenFunc HostGenFunc) (sr SystemResolvers) { + if hostGenFunc == nil { + hostGenFunc = defaultHostGen + } + s := &systemResolvers{ + resolver: &net.Resolver{ + PreferGo: true, + }, + hostGenFunc: hostGenFunc, + addrs: make(map[string]unit), + } + s.resolver.Dial = s.dialFunc + + return s +} + +// dialFunc gets the resolver's address and puts it into internal cache. +func (sr *systemResolvers) dialFunc(_ context.Context, _, address string) (_ net.Conn, err error) { + // Just validate the passed address is a valid IP. + var host string + host, err = SplitHost(address) + if err != nil { + // TODO(e.burkov): Maybe use a structured badAddrPassedErr to + // allow unwrapping of the real error. + return nil, fmt.Errorf("%s: %w", err, badAddrPassedErr) + } + + if net.ParseIP(host) == nil { + return nil, fmt.Errorf("parsing %q: %w", host, badAddrPassedErr) + } + + sr.addrsLock.Lock() + defer sr.addrsLock.Unlock() + + sr.addrs[address] = unit{} + + return nil, fakeDialErr +} + +func (sr *systemResolvers) Get() (rs []string) { + sr.addrsLock.RLock() + defer sr.addrsLock.RUnlock() + + addrs := sr.addrs + rs = make([]string, len(addrs)) + var i int + for addr := range addrs { + rs[i] = addr + i++ + } + + return rs +} diff --git a/internal/aghnet/systemresolvers_others_test.go b/internal/aghnet/systemresolvers_others_test.go new file mode 100644 index 00000000..972247b4 --- /dev/null +++ b/internal/aghnet/systemresolvers_others_test.go @@ -0,0 +1,74 @@ +// +build !windows + +package aghnet + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func createTestSystemResolversImp( + t *testing.T, + refreshDur time.Duration, + hostGenFunc HostGenFunc, +) (imp *systemResolvers) { + t.Helper() + + sr := createTestSystemResolvers(t, refreshDur, hostGenFunc) + + var ok bool + imp, ok = sr.(*systemResolvers) + require.True(t, ok) + + return imp +} + +func TestSystemResolvers_Refresh(t *testing.T) { + t.Run("expected_error", func(t *testing.T) { + sr := createTestSystemResolvers(t, 0, nil) + + assert.NoError(t, sr.Refresh()) + }) + + t.Run("unexpected_error", func(t *testing.T) { + _, err := NewSystemResolvers(0, func() string { + return "127.0.0.1::123" + }) + assert.Error(t, err) + }) +} + +func TestSystemResolvers_DialFunc(t *testing.T) { + imp := createTestSystemResolversImp(t, 0, nil) + + testCases := []struct { + name string + address string + want error + }{{ + name: "valid", + address: "127.0.0.1", + want: fakeDialErr, + }, { + name: "invalid_split_host", + address: "127.0.0.1::123", + want: badAddrPassedErr, + }, { + name: "invalid_parse_ip", + address: "not-ip", + want: badAddrPassedErr, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + conn, err := imp.dialFunc(context.Background(), "", tc.address) + + require.Nil(t, conn) + assert.ErrorIs(t, err, tc.want) + }) + } +} diff --git a/internal/aghnet/systemresolvers_test.go b/internal/aghnet/systemresolvers_test.go new file mode 100644 index 00000000..13145817 --- /dev/null +++ b/internal/aghnet/systemresolvers_test.go @@ -0,0 +1,33 @@ +package aghnet + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func createTestSystemResolvers( + t *testing.T, + refreshDur time.Duration, + hostGenFunc HostGenFunc, +) (sr SystemResolvers) { + t.Helper() + + var err error + sr, err = NewSystemResolvers(refreshDur, hostGenFunc) + require.NoError(t, err) + require.NotNil(t, sr) + + return sr +} + +func TestSystemResolvers_Get(t *testing.T) { + sr := createTestSystemResolvers(t, 0, nil) + assert.NotEmpty(t, sr.Get()) +} + +// TODO(e.burkov): Write tests for refreshWithTicker. +// +// See https://github.com/AdguardTeam/AdGuardHome/issues/2846. diff --git a/internal/aghnet/systemresolvers_windows.go b/internal/aghnet/systemresolvers_windows.go new file mode 100644 index 00000000..c918b44a --- /dev/null +++ b/internal/aghnet/systemresolvers_windows.go @@ -0,0 +1,158 @@ +// +build windows + +package aghnet + +import ( + "bufio" + "fmt" + "io" + "net" + "os/exec" + "strings" + "sync" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/agherr" + "github.com/AdguardTeam/AdGuardHome/internal/aghio" + "github.com/AdguardTeam/AdGuardHome/internal/aghos" + "github.com/AdguardTeam/golibs/log" +) + +// systemResolvers implementation differs for Windows since Go's resolver +// doesn't work there. +// +// See https://github.com/golang/go/issues/33097. +type systemResolvers struct { + // addrs is the slice of cached local resolvers' addresses. + addrs []string + addrsLock sync.RWMutex +} + +func newSystemResolvers(refreshIvl time.Duration, _ HostGenFunc) (sr SystemResolvers) { + return &systemResolvers{} +} + +func (sr *systemResolvers) Get() (rs []string) { + sr.addrsLock.RLock() + defer sr.addrsLock.RUnlock() + + addrs := sr.addrs + rs = make([]string, len(addrs)) + copy(rs, addrs) + + return rs +} + +// getAddrs gets local resolvers' addresses from OS in a special Windows way. +// +// TODO(e.burkov): This whole function needs more detailed research on getting +// local resolvers addresses on Windows. We execute the external command for +// now that is not the most accurate way. +func (sr *systemResolvers) getAddrs() (addrs []string, err error) { + cmd := exec.Command("nslookup") + + var stdin io.WriteCloser + stdin, err = cmd.StdinPipe() + if err != nil { + return nil, fmt.Errorf("getting the command's stdin pipe: %w", err) + } + + var stdout io.ReadCloser + stdout, err = cmd.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("getting the command's stdout pipe: %w", err) + } + + var stdoutLimited io.ReadCloser + stdoutLimited, err = aghio.LimitReadCloser(stdout, aghos.MaxCmdOutputSize) + if err != nil { + return nil, fmt.Errorf("limiting stdout reader: %w", err) + } + + go func() { + defer agherr.LogPanic("systemResolvers") + defer func() { + derr := stdin.Close() + if derr != nil { + log.Error("systemResolvers: closing stdin pipe: %s", derr) + } + }() + + _, werr := io.WriteString(stdin, "exit") + if werr != nil { + log.Error("systemResolvers: writing to command pipe: %s", werr) + } + }() + + err = cmd.Start() + if err != nil { + return nil, fmt.Errorf("start command executing: %w", err) + } + + // The output of nslookup looks like this: + // + // Default Server: 192-168-1-1.qualified.domain.ru + // Address: 192.168.1.1 + + var possibleIPs []string + s := bufio.NewScanner(stdoutLimited) + for s.Scan() { + line := s.Text() + if len(line) == 0 { + continue + } + + fields := strings.Fields(line) + if len(fields) != 2 || fields[0] != "Address:" { + continue + } + + // If the address contains port then it is separated with '#'. + ipStrs := strings.Split(fields[1], "#") + if len(ipStrs) == 0 { + continue + } + + possibleIPs = append(possibleIPs, ipStrs[0]) + } + + err = cmd.Wait() + if err != nil { + return nil, fmt.Errorf("executing the command: %w", err) + } + + // Don't close StdoutPipe since Wait do it for us in ¿most? cases. + // + // See go doc os/exec.Cmd.StdoutPipe. + + for _, addr := range possibleIPs { + if net.ParseIP(addr) == nil { + log.Debug("systemResolvers: %q is not a valid ip", addr) + + continue + } + + addrs = append(addrs, addr) + } + + return addrs, nil +} + +func (sr *systemResolvers) Refresh() (err error) { + defer agherr.Annotate("systemResolvers: %w", &err) + + got, err := sr.getAddrs() + if err != nil { + return fmt.Errorf("can't get addresses: %w", err) + } + if len(got) == 0 { + return nil + } + + sr.addrsLock.Lock() + defer sr.addrsLock.Unlock() + + sr.addrs = got + + return nil +} diff --git a/internal/aghnet/systemresolvers_windows_test.go b/internal/aghnet/systemresolvers_windows_test.go new file mode 100644 index 00000000..35cc2fd3 --- /dev/null +++ b/internal/aghnet/systemresolvers_windows_test.go @@ -0,0 +1,7 @@ +// +build windows + +package aghnet + +// TODO(e.burkov): Write tests for Windows implementation. +// +// See https://github.com/AdguardTeam/AdGuardHome/issues/2846. diff --git a/internal/aghos/os.go b/internal/aghos/os.go index f573d7a7..4d535100 100644 --- a/internal/aghos/os.go +++ b/internal/aghos/os.go @@ -1,7 +1,11 @@ // Package aghos contains utilities for functions requiring system calls. package aghos -import "syscall" +import ( + "fmt" + "os/exec" + "syscall" +) // CanBindPrivilegedPorts checks if current process can bind to privileged // ports. @@ -24,3 +28,20 @@ func HaveAdminRights() (bool, error) { func SendProcessSignal(pid int, sig syscall.Signal) error { return sendProcessSignal(pid, sig) } + +// MaxCmdOutputSize is the maximum length of performed shell command output. +const MaxCmdOutputSize = 2 * 1024 + +// RunCommand runs shell command. +func RunCommand(command string, arguments ...string) (int, string, error) { + cmd := exec.Command(command, arguments...) + out, err := cmd.Output() + if len(out) > MaxCmdOutputSize { + out = out[:MaxCmdOutputSize] + } + if err != nil { + return 1, "", fmt.Errorf("exec.Command(%s) failed: %v: %s", command, err, string(out)) + } + + return cmd.ProcessState.ExitCode(), string(out), nil +} diff --git a/internal/aghtest/upstream.go b/internal/aghtest/upstream.go index 78622771..5cf4925d 100644 --- a/internal/aghtest/upstream.go +++ b/internal/aghtest/upstream.go @@ -3,12 +3,12 @@ package aghtest import ( "crypto/sha256" "encoding/hex" + "errors" "fmt" "net" "strings" "sync" - "github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/miekg/dns" ) @@ -166,7 +166,10 @@ type TestErrUpstream struct{} // Exchange always returns nil Msg and non-nil error. func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) { - return nil, agherr.Error("bad") + // We don't use an agherr.Error to avoid the import cycle since aghtests + // used to provide the utilities for testing which agherr (and any other + // testable package) should be able to use. + return nil, errors.New("bad") } // Address always returns an empty string. diff --git a/internal/home/control.go b/internal/home/control.go index e7724794..57e8e731 100644 --- a/internal/home/control.go +++ b/internal/home/control.go @@ -2,7 +2,6 @@ package home import ( "encoding/json" - "errors" "fmt" "net" "net/http" @@ -11,6 +10,7 @@ import ( "strconv" "strings" + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/version" "github.com/AdguardTeam/golibs/log" @@ -213,22 +213,11 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) { return true } - host, _, err := net.SplitHostPort(r.Host) + host, err := aghnet.SplitHost(r.Host) if err != nil { - // Check for the missing port error. If it is that error, just - // use the host as is. - // - // See the source code for net.SplitHostPort. - const missingPort = "missing port in address" + httpError(w, http.StatusBadRequest, "bad host: %s", err) - addrErr := &net.AddrError{} - if !errors.As(err, &addrErr) || addrErr.Err != missingPort { - httpError(w, http.StatusBadRequest, "bad host: %s", err) - - return false - } - - host = r.Host + return false } if r.TLS == nil && web.forceHTTPS { diff --git a/internal/home/controlinstall.go b/internal/home/controlinstall.go index 6c80e1d9..be1dc8ab 100644 --- a/internal/home/controlinstall.go +++ b/internal/home/controlinstall.go @@ -15,6 +15,7 @@ import ( "strings" "time" + "github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/golibs/log" @@ -270,14 +271,8 @@ func copyInstallSettings(dst, src *configuration) { // shutdownTimeout is the timeout for shutting HTTP server down operation. const shutdownTimeout = 5 * time.Second -func logPanic() { - if v := recover(); v != nil { - log.Error("recovered from panic: %v", v) - } -} - func shutdownSrv(ctx context.Context, cancel context.CancelFunc, srv *http.Server) { - defer logPanic() + defer agherr.LogPanic("") if srv == nil { return diff --git a/internal/home/service.go b/internal/home/service.go index 2ebd4f2f..4935212a 100644 --- a/internal/home/service.go +++ b/internal/home/service.go @@ -98,7 +98,7 @@ func sendSigReload() { if os.IsNotExist(err) { var code int var psdata string - code, psdata, err = util.RunCommand("ps", "-C", serviceName, "-o", "pid=") + code, psdata, err = aghos.RunCommand("ps", "-C", serviceName, "-o", "pid=") if err != nil || code != 0 { log.Error("Can't find AdGuardHome process: %s code:%d", err, code) return @@ -301,7 +301,7 @@ func configureService(c *service.Config) { // returns command code or error if any func runInitdCommand(action string) (int, error) { confPath := "/etc/init.d/" + serviceName - code, _, err := util.RunCommand("sh", "-c", confPath+" "+action) + code, _, err := aghos.RunCommand("sh", "-c", confPath+" "+action) return code, err } diff --git a/internal/querylog/qlogreader_test.go b/internal/querylog/qlogreader_test.go index 0dce5004..060adbca 100644 --- a/internal/querylog/qlogreader_test.go +++ b/internal/querylog/qlogreader_test.go @@ -1,7 +1,6 @@ package querylog import ( - "errors" "io" "testing" "time" @@ -116,7 +115,7 @@ func TestQLogReader_Seek(t *testing.T) { require.Nil(t, err) err = r.SeekTS(ts.UnixNano()) - assert.True(t, errors.Is(err, tc.want), err) + assert.ErrorIs(t, err, tc.want) }) } } diff --git a/internal/util/helpers.go b/internal/util/helpers.go index 2ccc9bbe..7add9617 100644 --- a/internal/util/helpers.go +++ b/internal/util/helpers.go @@ -6,25 +6,12 @@ package util import ( "bytes" - "fmt" "io/ioutil" - "os/exec" "path/filepath" "runtime" "strings" ) -// RunCommand runs shell command. -func RunCommand(command string, arguments ...string) (int, string, error) { - cmd := exec.Command(command, arguments...) - out, err := cmd.Output() - if err != nil { - return 1, "", fmt.Errorf("exec.Command(%s) failed: %v: %s", command, err, string(out)) - } - - return cmd.ProcessState.ExitCode(), string(out), nil -} - // SplitNext - split string by a byte and return the first chunk // Skip empty chunks // Whitespace is trimmed