Pull request: 2704 local resolvers vol.1
Merge in DNS/adguard-home from 2704-local-addresses-vol.1 to master Updates #2704. Updates #2829. Updates #2846. Squashed commit of the following: commit 9a49b3d27edcb30da7f16a065226907833b1dc81 Author: Eugene Burkov <e.burkov@adguard.com> Date: Mon Mar 22 15:39:17 2021 +0300 aghnet: imp docs and logging commit 74f95a29c55b9e732276601b0ecc63fb7c3a9f9e Author: Eugene Burkov <e.burkov@adguard.com> Date: Fri Mar 19 20:56:51 2021 +0300 all: fix friday evening mistakes commit 0e2066bc5c16ed807fa601780b99e154502361a9 Author: Eugene Burkov <e.burkov@adguard.com> Date: Fri Mar 19 20:51:15 2021 +0300 all: upd testify, imp code quality commit 8237c50b670c58361ccf7adec3ff2452b1196677 Author: Eugene Burkov <e.burkov@adguard.com> Date: Fri Mar 19 20:19:29 2021 +0300 aghnet: imp test naming commit 14eb1e189339554c0a6d38e2ba7a93917774ebab Author: Eugene Burkov <e.burkov@adguard.com> Date: Fri Mar 19 19:41:43 2021 +0300 aghnet: isolate windows-specific functionality commit d461ac8b18c187999da3e3aba116571b7ebe6785 Author: Eugene Burkov <e.burkov@adguard.com> Date: Fri Mar 19 14:50:05 2021 +0300 aghnet: imp code quality commit d0ee01cb1f8613de2085c0f2f2f396e46beb52a5 Author: Eugene Burkov <e.burkov@adguard.com> Date: Fri Mar 19 11:59:10 2021 +0300 all: mv funcs to agherr, mk system resolvers getter
This commit is contained in:
parent
eb9526cc92
commit
3b2f5d7842
2
go.mod
2
go.mod
|
@ -29,7 +29,7 @@ require (
|
||||||
github.com/satori/go.uuid v1.2.0
|
github.com/satori/go.uuid v1.2.0
|
||||||
github.com/sirupsen/logrus v1.8.1 // indirect
|
github.com/sirupsen/logrus v1.8.1 // indirect
|
||||||
github.com/spf13/cobra v1.1.3 // indirect
|
github.com/spf13/cobra v1.1.3 // indirect
|
||||||
github.com/stretchr/testify v1.6.1
|
github.com/stretchr/testify v1.7.0
|
||||||
github.com/ti-mo/netfilter v0.4.0
|
github.com/ti-mo/netfilter v0.4.0
|
||||||
github.com/u-root/u-root v7.0.0+incompatible
|
github.com/u-root/u-root v7.0.0+incompatible
|
||||||
go.etcd.io/bbolt v1.3.5
|
go.etcd.io/bbolt v1.3.5
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -413,6 +413,8 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
|
||||||
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
|
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
|
||||||
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
|
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
|
||||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
|
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
|
||||||
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw=
|
github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw=
|
||||||
github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA=
|
github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA=
|
||||||
github.com/ti-mo/netfilter v0.2.0/go.mod h1:8GbBGsY/8fxtyIdfwy29JiluNcPK4K7wIT+x42ipqUU=
|
github.com/ti-mo/netfilter v0.2.0/go.mod h1:8GbBGsY/8fxtyIdfwy29JiluNcPK4K7wIT+x42ipqUU=
|
||||||
|
|
|
@ -4,6 +4,8 @@ package agherr
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/golibs/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Error is the constant error type.
|
// Error is the constant error type.
|
||||||
|
@ -95,6 +97,8 @@ type wrapper interface {
|
||||||
// }
|
// }
|
||||||
//
|
//
|
||||||
// msg must contain the final ": %w" verb.
|
// msg must contain the final ": %w" verb.
|
||||||
|
//
|
||||||
|
// TODO(a.garipov): Clearify the function usage.
|
||||||
func Annotate(msg string, errPtr *error, args ...interface{}) {
|
func Annotate(msg string, errPtr *error, args ...interface{}) {
|
||||||
if errPtr == nil {
|
if errPtr == nil {
|
||||||
return
|
return
|
||||||
|
@ -107,3 +111,17 @@ func Annotate(msg string, errPtr *error, args ...interface{}) {
|
||||||
*errPtr = fmt.Errorf(msg, args...)
|
*errPtr = fmt.Errorf(msg, args...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LogPanic is a convinient helper function to log a panic in a goroutine. It
|
||||||
|
// should not be used where proper error handling is required.
|
||||||
|
func LogPanic(prefix string) {
|
||||||
|
if v := recover(); v != nil {
|
||||||
|
if prefix != "" {
|
||||||
|
log.Error("%s: recovered from panic: %v", prefix, v)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Error("recovered from panic: %v", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
package agherr
|
package agherr
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
@ -120,3 +122,39 @@ func TestAnnotate(t *testing.T) {
|
||||||
assert.Equal(t, wantMsg, err.Error())
|
assert.Equal(t, wantMsg, err.Error())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLogPanic(t *testing.T) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
aghtest.ReplaceLogWriter(t, buf)
|
||||||
|
|
||||||
|
t.Run("prefix", func(t *testing.T) {
|
||||||
|
const (
|
||||||
|
panicMsg = "spooky!"
|
||||||
|
prefix = "packagename"
|
||||||
|
errWithNoPrefix = "[error] recovered from panic: spooky!"
|
||||||
|
errWithPrefix = "[error] packagename: recovered from panic: spooky!"
|
||||||
|
)
|
||||||
|
|
||||||
|
panicFunc := func(prefix string) {
|
||||||
|
defer LogPanic(prefix)
|
||||||
|
|
||||||
|
panic(panicMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
panicFunc("")
|
||||||
|
assert.Contains(t, buf.String(), errWithNoPrefix)
|
||||||
|
buf.Reset()
|
||||||
|
|
||||||
|
panicFunc(prefix)
|
||||||
|
assert.Contains(t, buf.String(), errWithPrefix)
|
||||||
|
buf.Reset()
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("don't_panic", func(t *testing.T) {
|
||||||
|
require.NotPanics(t, func() {
|
||||||
|
defer LogPanic("")
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Empty(t, buf.String())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -251,3 +251,25 @@ func ErrorIsAddrInUse(err error) bool {
|
||||||
|
|
||||||
return errErrno == syscall.EADDRINUSE
|
return errErrno == syscall.EADDRINUSE
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SplitHost is a wrapper for net.SplitHostPort for the cases when the hostport
|
||||||
|
// does not necessarily contain a port.
|
||||||
|
func SplitHost(hostport string) (host string, err error) {
|
||||||
|
host, _, err = net.SplitHostPort(hostport)
|
||||||
|
if err != nil {
|
||||||
|
// Check for the missing port error. If it is that error, just
|
||||||
|
// use the host as is.
|
||||||
|
//
|
||||||
|
// See the source code for net.SplitHostPort.
|
||||||
|
const missingPort = "missing port in address"
|
||||||
|
|
||||||
|
addrErr := &net.AddrError{}
|
||||||
|
if !errors.As(err, &addrErr) || addrErr.Err != missingPort {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
host = hostport
|
||||||
|
}
|
||||||
|
|
||||||
|
return host, nil
|
||||||
|
}
|
||||||
|
|
|
@ -9,7 +9,7 @@ import (
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/util"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||||
)
|
)
|
||||||
|
|
||||||
// hardwarePortInfo - information obtained using MacOS networksetup
|
// hardwarePortInfo - information obtained using MacOS networksetup
|
||||||
|
@ -47,7 +47,7 @@ func getCurrentHardwarePortInfo(ifaceName string) (hardwarePortInfo, error) {
|
||||||
// it returns a map where the key is the interface name, and the value is the "hardware port"
|
// it returns a map where the key is the interface name, and the value is the "hardware port"
|
||||||
// returns nil if it fails to parse the output
|
// returns nil if it fails to parse the output
|
||||||
func getNetworkSetupHardwareReports() map[string]string {
|
func getNetworkSetupHardwareReports() map[string]string {
|
||||||
_, out, err := util.RunCommand("networksetup", "-listallhardwareports")
|
_, out, err := aghos.RunCommand("networksetup", "-listallhardwareports")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -72,7 +72,7 @@ func getNetworkSetupHardwareReports() map[string]string {
|
||||||
func getHardwarePortInfo(hardwarePort string) (hardwarePortInfo, error) {
|
func getHardwarePortInfo(hardwarePort string) (hardwarePortInfo, error) {
|
||||||
h := hardwarePortInfo{}
|
h := hardwarePortInfo{}
|
||||||
|
|
||||||
_, out, err := util.RunCommand("networksetup", "-getinfo", hardwarePort)
|
_, out, err := aghos.RunCommand("networksetup", "-getinfo", hardwarePort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return h, err
|
return h, err
|
||||||
}
|
}
|
||||||
|
@ -116,7 +116,7 @@ func ifaceSetStaticIP(ifaceName string) (err error) {
|
||||||
args = append(args, dnsAddrs...)
|
args = append(args, dnsAddrs...)
|
||||||
|
|
||||||
// Setting DNS servers is necessary when configuring a static IP
|
// Setting DNS servers is necessary when configuring a static IP
|
||||||
code, _, err := util.RunCommand("networksetup", args...)
|
code, _, err := aghos.RunCommand("networksetup", args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -125,7 +125,7 @@ func ifaceSetStaticIP(ifaceName string) (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Actually configures hardware port to have static IP
|
// Actually configures hardware port to have static IP
|
||||||
code, _, err = util.RunCommand("networksetup", "-setmanual",
|
code, _, err = aghos.RunCommand("networksetup", "-setmanual",
|
||||||
portInfo.name, portInfo.ip, portInfo.subnet, portInfo.gatewayIP)
|
portInfo.name, portInfo.ip, portInfo.subnet, portInfo.gatewayIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -0,0 +1,78 @@
|
||||||
|
package aghnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||||
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultRefreshIvl is the default period of time between refreshing cached
|
||||||
|
// addresses.
|
||||||
|
// const DefaultRefreshIvl = 5 * time.Minute
|
||||||
|
|
||||||
|
// HostGenFunc is the signature for functions generating fake hostnames. The
|
||||||
|
// implementation must be safe for concurrent use.
|
||||||
|
type HostGenFunc func() (host string)
|
||||||
|
|
||||||
|
// unit is an alias for an existing map value.
|
||||||
|
type unit = struct{}
|
||||||
|
|
||||||
|
// SystemResolvers helps to work with local resolvers' addresses provided by OS.
|
||||||
|
type SystemResolvers interface {
|
||||||
|
// Get returns the slice of local resolvers' addresses.
|
||||||
|
// It should be safe for concurrent use.
|
||||||
|
Get() (rs []string)
|
||||||
|
// Refresh refreshes the local resolvers' addresses cache. It should be
|
||||||
|
// safe for concurrent use.
|
||||||
|
Refresh() (err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// fakeDialErr is an error which dialFunc is expected to return.
|
||||||
|
fakeDialErr agherr.Error = "this error signals the successful dialFunc work"
|
||||||
|
|
||||||
|
// badAddrPassedErr is returned when dialFunc can't parse an IP address.
|
||||||
|
badAddrPassedErr agherr.Error = "the passed string is not a valid IP address"
|
||||||
|
)
|
||||||
|
|
||||||
|
// refreshWithTicker refreshes the cache of sr after each tick form tickCh.
|
||||||
|
func refreshWithTicker(sr SystemResolvers, tickCh <-chan time.Time) {
|
||||||
|
defer agherr.LogPanic("systemResolvers")
|
||||||
|
|
||||||
|
// TODO(e.burkov): Implement a functionality to stop ticker.
|
||||||
|
for range tickCh {
|
||||||
|
err := sr.Refresh()
|
||||||
|
if err != nil {
|
||||||
|
log.Error("systemResolvers: error in refreshing goroutine: %s", err)
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("systemResolvers: local addresses cache is refreshed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSystemResolvers returns a SystemResolvers with the cache refresh rate
|
||||||
|
// defined by refreshIvl. It disables auto-resfreshing if refreshIvl is 0. If
|
||||||
|
// nil is passed for hostGenFunc, the default generator will be used.
|
||||||
|
func NewSystemResolvers(
|
||||||
|
refreshIvl time.Duration,
|
||||||
|
hostGenFunc HostGenFunc,
|
||||||
|
) (sr SystemResolvers, err error) {
|
||||||
|
sr = newSystemResolvers(refreshIvl, hostGenFunc)
|
||||||
|
|
||||||
|
// Fill cache.
|
||||||
|
err = sr.Refresh()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if refreshIvl > 0 {
|
||||||
|
ticker := time.NewTicker(refreshIvl)
|
||||||
|
|
||||||
|
go refreshWithTicker(sr, ticker.C)
|
||||||
|
}
|
||||||
|
|
||||||
|
return sr, nil
|
||||||
|
}
|
|
@ -0,0 +1,96 @@
|
||||||
|
// +build !windows
|
||||||
|
|
||||||
|
package aghnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||||
|
)
|
||||||
|
|
||||||
|
// defaultHostGen is the default method of generating host for Refresh.
|
||||||
|
func defaultHostGen() (host string) {
|
||||||
|
// TODO(e.burkov): Use strings.Builder.
|
||||||
|
return fmt.Sprintf("test%d.org", time.Now().UnixNano())
|
||||||
|
}
|
||||||
|
|
||||||
|
// systemResolvers is a default implementation of SystemResolvers interface.
|
||||||
|
type systemResolvers struct {
|
||||||
|
resolver *net.Resolver
|
||||||
|
hostGenFunc HostGenFunc
|
||||||
|
|
||||||
|
// addrs is the map that contains cached local resolvers' addresses.
|
||||||
|
addrs map[string]unit
|
||||||
|
addrsLock sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sr *systemResolvers) Refresh() (err error) {
|
||||||
|
defer agherr.Annotate("systemResolvers: %w", &err)
|
||||||
|
|
||||||
|
_, err = sr.resolver.LookupHost(context.Background(), sr.hostGenFunc())
|
||||||
|
dnserr := &net.DNSError{}
|
||||||
|
if errors.As(err, &dnserr) && dnserr.Err == fakeDialErr.Error() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSystemResolvers(refreshIvl time.Duration, hostGenFunc HostGenFunc) (sr SystemResolvers) {
|
||||||
|
if hostGenFunc == nil {
|
||||||
|
hostGenFunc = defaultHostGen
|
||||||
|
}
|
||||||
|
s := &systemResolvers{
|
||||||
|
resolver: &net.Resolver{
|
||||||
|
PreferGo: true,
|
||||||
|
},
|
||||||
|
hostGenFunc: hostGenFunc,
|
||||||
|
addrs: make(map[string]unit),
|
||||||
|
}
|
||||||
|
s.resolver.Dial = s.dialFunc
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// dialFunc gets the resolver's address and puts it into internal cache.
|
||||||
|
func (sr *systemResolvers) dialFunc(_ context.Context, _, address string) (_ net.Conn, err error) {
|
||||||
|
// Just validate the passed address is a valid IP.
|
||||||
|
var host string
|
||||||
|
host, err = SplitHost(address)
|
||||||
|
if err != nil {
|
||||||
|
// TODO(e.burkov): Maybe use a structured badAddrPassedErr to
|
||||||
|
// allow unwrapping of the real error.
|
||||||
|
return nil, fmt.Errorf("%s: %w", err, badAddrPassedErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if net.ParseIP(host) == nil {
|
||||||
|
return nil, fmt.Errorf("parsing %q: %w", host, badAddrPassedErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
sr.addrsLock.Lock()
|
||||||
|
defer sr.addrsLock.Unlock()
|
||||||
|
|
||||||
|
sr.addrs[address] = unit{}
|
||||||
|
|
||||||
|
return nil, fakeDialErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sr *systemResolvers) Get() (rs []string) {
|
||||||
|
sr.addrsLock.RLock()
|
||||||
|
defer sr.addrsLock.RUnlock()
|
||||||
|
|
||||||
|
addrs := sr.addrs
|
||||||
|
rs = make([]string, len(addrs))
|
||||||
|
var i int
|
||||||
|
for addr := range addrs {
|
||||||
|
rs[i] = addr
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
|
||||||
|
return rs
|
||||||
|
}
|
|
@ -0,0 +1,74 @@
|
||||||
|
// +build !windows
|
||||||
|
|
||||||
|
package aghnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func createTestSystemResolversImp(
|
||||||
|
t *testing.T,
|
||||||
|
refreshDur time.Duration,
|
||||||
|
hostGenFunc HostGenFunc,
|
||||||
|
) (imp *systemResolvers) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
sr := createTestSystemResolvers(t, refreshDur, hostGenFunc)
|
||||||
|
|
||||||
|
var ok bool
|
||||||
|
imp, ok = sr.(*systemResolvers)
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
return imp
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSystemResolvers_Refresh(t *testing.T) {
|
||||||
|
t.Run("expected_error", func(t *testing.T) {
|
||||||
|
sr := createTestSystemResolvers(t, 0, nil)
|
||||||
|
|
||||||
|
assert.NoError(t, sr.Refresh())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unexpected_error", func(t *testing.T) {
|
||||||
|
_, err := NewSystemResolvers(0, func() string {
|
||||||
|
return "127.0.0.1::123"
|
||||||
|
})
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSystemResolvers_DialFunc(t *testing.T) {
|
||||||
|
imp := createTestSystemResolversImp(t, 0, nil)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
address string
|
||||||
|
want error
|
||||||
|
}{{
|
||||||
|
name: "valid",
|
||||||
|
address: "127.0.0.1",
|
||||||
|
want: fakeDialErr,
|
||||||
|
}, {
|
||||||
|
name: "invalid_split_host",
|
||||||
|
address: "127.0.0.1::123",
|
||||||
|
want: badAddrPassedErr,
|
||||||
|
}, {
|
||||||
|
name: "invalid_parse_ip",
|
||||||
|
address: "not-ip",
|
||||||
|
want: badAddrPassedErr,
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
conn, err := imp.dialFunc(context.Background(), "", tc.address)
|
||||||
|
|
||||||
|
require.Nil(t, conn)
|
||||||
|
assert.ErrorIs(t, err, tc.want)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,33 @@
|
||||||
|
package aghnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func createTestSystemResolvers(
|
||||||
|
t *testing.T,
|
||||||
|
refreshDur time.Duration,
|
||||||
|
hostGenFunc HostGenFunc,
|
||||||
|
) (sr SystemResolvers) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
sr, err = NewSystemResolvers(refreshDur, hostGenFunc)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, sr)
|
||||||
|
|
||||||
|
return sr
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSystemResolvers_Get(t *testing.T) {
|
||||||
|
sr := createTestSystemResolvers(t, 0, nil)
|
||||||
|
assert.NotEmpty(t, sr.Get())
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(e.burkov): Write tests for refreshWithTicker.
|
||||||
|
//
|
||||||
|
// See https://github.com/AdguardTeam/AdGuardHome/issues/2846.
|
|
@ -0,0 +1,158 @@
|
||||||
|
// +build windows
|
||||||
|
|
||||||
|
package aghnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||||
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// systemResolvers implementation differs for Windows since Go's resolver
|
||||||
|
// doesn't work there.
|
||||||
|
//
|
||||||
|
// See https://github.com/golang/go/issues/33097.
|
||||||
|
type systemResolvers struct {
|
||||||
|
// addrs is the slice of cached local resolvers' addresses.
|
||||||
|
addrs []string
|
||||||
|
addrsLock sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSystemResolvers(refreshIvl time.Duration, _ HostGenFunc) (sr SystemResolvers) {
|
||||||
|
return &systemResolvers{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sr *systemResolvers) Get() (rs []string) {
|
||||||
|
sr.addrsLock.RLock()
|
||||||
|
defer sr.addrsLock.RUnlock()
|
||||||
|
|
||||||
|
addrs := sr.addrs
|
||||||
|
rs = make([]string, len(addrs))
|
||||||
|
copy(rs, addrs)
|
||||||
|
|
||||||
|
return rs
|
||||||
|
}
|
||||||
|
|
||||||
|
// getAddrs gets local resolvers' addresses from OS in a special Windows way.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): This whole function needs more detailed research on getting
|
||||||
|
// local resolvers addresses on Windows. We execute the external command for
|
||||||
|
// now that is not the most accurate way.
|
||||||
|
func (sr *systemResolvers) getAddrs() (addrs []string, err error) {
|
||||||
|
cmd := exec.Command("nslookup")
|
||||||
|
|
||||||
|
var stdin io.WriteCloser
|
||||||
|
stdin, err = cmd.StdinPipe()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("getting the command's stdin pipe: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var stdout io.ReadCloser
|
||||||
|
stdout, err = cmd.StdoutPipe()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("getting the command's stdout pipe: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var stdoutLimited io.ReadCloser
|
||||||
|
stdoutLimited, err = aghio.LimitReadCloser(stdout, aghos.MaxCmdOutputSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("limiting stdout reader: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer agherr.LogPanic("systemResolvers")
|
||||||
|
defer func() {
|
||||||
|
derr := stdin.Close()
|
||||||
|
if derr != nil {
|
||||||
|
log.Error("systemResolvers: closing stdin pipe: %s", derr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, werr := io.WriteString(stdin, "exit")
|
||||||
|
if werr != nil {
|
||||||
|
log.Error("systemResolvers: writing to command pipe: %s", werr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = cmd.Start()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("start command executing: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The output of nslookup looks like this:
|
||||||
|
//
|
||||||
|
// Default Server: 192-168-1-1.qualified.domain.ru
|
||||||
|
// Address: 192.168.1.1
|
||||||
|
|
||||||
|
var possibleIPs []string
|
||||||
|
s := bufio.NewScanner(stdoutLimited)
|
||||||
|
for s.Scan() {
|
||||||
|
line := s.Text()
|
||||||
|
if len(line) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
fields := strings.Fields(line)
|
||||||
|
if len(fields) != 2 || fields[0] != "Address:" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the address contains port then it is separated with '#'.
|
||||||
|
ipStrs := strings.Split(fields[1], "#")
|
||||||
|
if len(ipStrs) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
possibleIPs = append(possibleIPs, ipStrs[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
err = cmd.Wait()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("executing the command: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Don't close StdoutPipe since Wait do it for us in ¿most? cases.
|
||||||
|
//
|
||||||
|
// See go doc os/exec.Cmd.StdoutPipe.
|
||||||
|
|
||||||
|
for _, addr := range possibleIPs {
|
||||||
|
if net.ParseIP(addr) == nil {
|
||||||
|
log.Debug("systemResolvers: %q is not a valid ip", addr)
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
addrs = append(addrs, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
return addrs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sr *systemResolvers) Refresh() (err error) {
|
||||||
|
defer agherr.Annotate("systemResolvers: %w", &err)
|
||||||
|
|
||||||
|
got, err := sr.getAddrs()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("can't get addresses: %w", err)
|
||||||
|
}
|
||||||
|
if len(got) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
sr.addrsLock.Lock()
|
||||||
|
defer sr.addrsLock.Unlock()
|
||||||
|
|
||||||
|
sr.addrs = got
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,7 @@
|
||||||
|
// +build windows
|
||||||
|
|
||||||
|
package aghnet
|
||||||
|
|
||||||
|
// TODO(e.burkov): Write tests for Windows implementation.
|
||||||
|
//
|
||||||
|
// See https://github.com/AdguardTeam/AdGuardHome/issues/2846.
|
|
@ -1,7 +1,11 @@
|
||||||
// Package aghos contains utilities for functions requiring system calls.
|
// Package aghos contains utilities for functions requiring system calls.
|
||||||
package aghos
|
package aghos
|
||||||
|
|
||||||
import "syscall"
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
// CanBindPrivilegedPorts checks if current process can bind to privileged
|
// CanBindPrivilegedPorts checks if current process can bind to privileged
|
||||||
// ports.
|
// ports.
|
||||||
|
@ -24,3 +28,20 @@ func HaveAdminRights() (bool, error) {
|
||||||
func SendProcessSignal(pid int, sig syscall.Signal) error {
|
func SendProcessSignal(pid int, sig syscall.Signal) error {
|
||||||
return sendProcessSignal(pid, sig)
|
return sendProcessSignal(pid, sig)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MaxCmdOutputSize is the maximum length of performed shell command output.
|
||||||
|
const MaxCmdOutputSize = 2 * 1024
|
||||||
|
|
||||||
|
// RunCommand runs shell command.
|
||||||
|
func RunCommand(command string, arguments ...string) (int, string, error) {
|
||||||
|
cmd := exec.Command(command, arguments...)
|
||||||
|
out, err := cmd.Output()
|
||||||
|
if len(out) > MaxCmdOutputSize {
|
||||||
|
out = out[:MaxCmdOutputSize]
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return 1, "", fmt.Errorf("exec.Command(%s) failed: %v: %s", command, err, string(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
return cmd.ProcessState.ExitCode(), string(out), nil
|
||||||
|
}
|
||||||
|
|
|
@ -3,12 +3,12 @@ package aghtest
|
||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -166,7 +166,10 @@ type TestErrUpstream struct{}
|
||||||
|
|
||||||
// Exchange always returns nil Msg and non-nil error.
|
// Exchange always returns nil Msg and non-nil error.
|
||||||
func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) {
|
func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) {
|
||||||
return nil, agherr.Error("bad")
|
// We don't use an agherr.Error to avoid the import cycle since aghtests
|
||||||
|
// used to provide the utilities for testing which agherr (and any other
|
||||||
|
// testable package) should be able to use.
|
||||||
|
return nil, errors.New("bad")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Address always returns an empty string.
|
// Address always returns an empty string.
|
||||||
|
|
|
@ -2,7 +2,6 @@ package home
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -11,6 +10,7 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
@ -213,22 +213,11 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
host, _, err := net.SplitHostPort(r.Host)
|
host, err := aghnet.SplitHost(r.Host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Check for the missing port error. If it is that error, just
|
httpError(w, http.StatusBadRequest, "bad host: %s", err)
|
||||||
// use the host as is.
|
|
||||||
//
|
|
||||||
// See the source code for net.SplitHostPort.
|
|
||||||
const missingPort = "missing port in address"
|
|
||||||
|
|
||||||
addrErr := &net.AddrError{}
|
return false
|
||||||
if !errors.As(err, &addrErr) || addrErr.Err != missingPort {
|
|
||||||
httpError(w, http.StatusBadRequest, "bad host: %s", err)
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
host = r.Host
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.TLS == nil && web.forceHTTPS {
|
if r.TLS == nil && web.forceHTTPS {
|
||||||
|
|
|
@ -15,6 +15,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
@ -270,14 +271,8 @@ func copyInstallSettings(dst, src *configuration) {
|
||||||
// shutdownTimeout is the timeout for shutting HTTP server down operation.
|
// shutdownTimeout is the timeout for shutting HTTP server down operation.
|
||||||
const shutdownTimeout = 5 * time.Second
|
const shutdownTimeout = 5 * time.Second
|
||||||
|
|
||||||
func logPanic() {
|
|
||||||
if v := recover(); v != nil {
|
|
||||||
log.Error("recovered from panic: %v", v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func shutdownSrv(ctx context.Context, cancel context.CancelFunc, srv *http.Server) {
|
func shutdownSrv(ctx context.Context, cancel context.CancelFunc, srv *http.Server) {
|
||||||
defer logPanic()
|
defer agherr.LogPanic("")
|
||||||
|
|
||||||
if srv == nil {
|
if srv == nil {
|
||||||
return
|
return
|
||||||
|
|
|
@ -98,7 +98,7 @@ func sendSigReload() {
|
||||||
if os.IsNotExist(err) {
|
if os.IsNotExist(err) {
|
||||||
var code int
|
var code int
|
||||||
var psdata string
|
var psdata string
|
||||||
code, psdata, err = util.RunCommand("ps", "-C", serviceName, "-o", "pid=")
|
code, psdata, err = aghos.RunCommand("ps", "-C", serviceName, "-o", "pid=")
|
||||||
if err != nil || code != 0 {
|
if err != nil || code != 0 {
|
||||||
log.Error("Can't find AdGuardHome process: %s code:%d", err, code)
|
log.Error("Can't find AdGuardHome process: %s code:%d", err, code)
|
||||||
return
|
return
|
||||||
|
@ -301,7 +301,7 @@ func configureService(c *service.Config) {
|
||||||
// returns command code or error if any
|
// returns command code or error if any
|
||||||
func runInitdCommand(action string) (int, error) {
|
func runInitdCommand(action string) (int, error) {
|
||||||
confPath := "/etc/init.d/" + serviceName
|
confPath := "/etc/init.d/" + serviceName
|
||||||
code, _, err := util.RunCommand("sh", "-c", confPath+" "+action)
|
code, _, err := aghos.RunCommand("sh", "-c", confPath+" "+action)
|
||||||
return code, err
|
return code, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package querylog
|
package querylog
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"io"
|
"io"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -116,7 +115,7 @@ func TestQLogReader_Seek(t *testing.T) {
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
err = r.SeekTS(ts.UnixNano())
|
err = r.SeekTS(ts.UnixNano())
|
||||||
assert.True(t, errors.Is(err, tc.want), err)
|
assert.ErrorIs(t, err, tc.want)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,25 +6,12 @@ package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"os/exec"
|
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RunCommand runs shell command.
|
|
||||||
func RunCommand(command string, arguments ...string) (int, string, error) {
|
|
||||||
cmd := exec.Command(command, arguments...)
|
|
||||||
out, err := cmd.Output()
|
|
||||||
if err != nil {
|
|
||||||
return 1, "", fmt.Errorf("exec.Command(%s) failed: %v: %s", command, err, string(out))
|
|
||||||
}
|
|
||||||
|
|
||||||
return cmd.ProcessState.ExitCode(), string(out), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SplitNext - split string by a byte and return the first chunk
|
// SplitNext - split string by a byte and return the first chunk
|
||||||
// Skip empty chunks
|
// Skip empty chunks
|
||||||
// Whitespace is trimmed
|
// Whitespace is trimmed
|
||||||
|
|
Loading…
Reference in New Issue