Pull request: 2704 local resolvers vol.1

Merge in DNS/adguard-home from 2704-local-addresses-vol.1 to master

Updates #2704.
Updates #2829.
Updates #2846.

Squashed commit of the following:

commit 9a49b3d27edcb30da7f16a065226907833b1dc81
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Mon Mar 22 15:39:17 2021 +0300

    aghnet: imp docs and logging

commit 74f95a29c55b9e732276601b0ecc63fb7c3a9f9e
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Mar 19 20:56:51 2021 +0300

    all: fix friday evening mistakes

commit 0e2066bc5c16ed807fa601780b99e154502361a9
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Mar 19 20:51:15 2021 +0300

    all: upd testify, imp code quality

commit 8237c50b670c58361ccf7adec3ff2452b1196677
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Mar 19 20:19:29 2021 +0300

    aghnet: imp test naming

commit 14eb1e189339554c0a6d38e2ba7a93917774ebab
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Mar 19 19:41:43 2021 +0300

    aghnet: isolate windows-specific functionality

commit d461ac8b18c187999da3e3aba116571b7ebe6785
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Mar 19 14:50:05 2021 +0300

    aghnet: imp code quality

commit d0ee01cb1f8613de2085c0f2f2f396e46beb52a5
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Mar 19 11:59:10 2021 +0300

    all: mv funcs to agherr, mk system resolvers getter
This commit is contained in:
Eugene Burkov 2021-03-22 16:46:36 +03:00
parent eb9526cc92
commit 3b2f5d7842
19 changed files with 568 additions and 48 deletions

2
go.mod
View File

@ -29,7 +29,7 @@ require (
github.com/satori/go.uuid v1.2.0 github.com/satori/go.uuid v1.2.0
github.com/sirupsen/logrus v1.8.1 // indirect github.com/sirupsen/logrus v1.8.1 // indirect
github.com/spf13/cobra v1.1.3 // indirect github.com/spf13/cobra v1.1.3 // indirect
github.com/stretchr/testify v1.6.1 github.com/stretchr/testify v1.7.0
github.com/ti-mo/netfilter v0.4.0 github.com/ti-mo/netfilter v0.4.0
github.com/u-root/u-root v7.0.0+incompatible github.com/u-root/u-root v7.0.0+incompatible
go.etcd.io/bbolt v1.3.5 go.etcd.io/bbolt v1.3.5

2
go.sum
View File

@ -413,6 +413,8 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw=
github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA=
github.com/ti-mo/netfilter v0.2.0/go.mod h1:8GbBGsY/8fxtyIdfwy29JiluNcPK4K7wIT+x42ipqUU= github.com/ti-mo/netfilter v0.2.0/go.mod h1:8GbBGsY/8fxtyIdfwy29JiluNcPK4K7wIT+x42ipqUU=

View File

@ -4,6 +4,8 @@ package agherr
import ( import (
"fmt" "fmt"
"strings" "strings"
"github.com/AdguardTeam/golibs/log"
) )
// Error is the constant error type. // Error is the constant error type.
@ -95,6 +97,8 @@ type wrapper interface {
// } // }
// //
// msg must contain the final ": %w" verb. // msg must contain the final ": %w" verb.
//
// TODO(a.garipov): Clearify the function usage.
func Annotate(msg string, errPtr *error, args ...interface{}) { func Annotate(msg string, errPtr *error, args ...interface{}) {
if errPtr == nil { if errPtr == nil {
return return
@ -107,3 +111,17 @@ func Annotate(msg string, errPtr *error, args ...interface{}) {
*errPtr = fmt.Errorf(msg, args...) *errPtr = fmt.Errorf(msg, args...)
} }
} }
// LogPanic is a convinient helper function to log a panic in a goroutine. It
// should not be used where proper error handling is required.
func LogPanic(prefix string) {
if v := recover(); v != nil {
if prefix != "" {
log.Error("%s: recovered from panic: %v", prefix, v)
return
}
log.Error("recovered from panic: %v", v)
}
}

View File

@ -1,10 +1,12 @@
package agherr package agherr
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"testing" "testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -120,3 +122,39 @@ func TestAnnotate(t *testing.T) {
assert.Equal(t, wantMsg, err.Error()) assert.Equal(t, wantMsg, err.Error())
}) })
} }
func TestLogPanic(t *testing.T) {
buf := &bytes.Buffer{}
aghtest.ReplaceLogWriter(t, buf)
t.Run("prefix", func(t *testing.T) {
const (
panicMsg = "spooky!"
prefix = "packagename"
errWithNoPrefix = "[error] recovered from panic: spooky!"
errWithPrefix = "[error] packagename: recovered from panic: spooky!"
)
panicFunc := func(prefix string) {
defer LogPanic(prefix)
panic(panicMsg)
}
panicFunc("")
assert.Contains(t, buf.String(), errWithNoPrefix)
buf.Reset()
panicFunc(prefix)
assert.Contains(t, buf.String(), errWithPrefix)
buf.Reset()
})
t.Run("don't_panic", func(t *testing.T) {
require.NotPanics(t, func() {
defer LogPanic("")
})
assert.Empty(t, buf.String())
})
}

View File

@ -251,3 +251,25 @@ func ErrorIsAddrInUse(err error) bool {
return errErrno == syscall.EADDRINUSE return errErrno == syscall.EADDRINUSE
} }
// SplitHost is a wrapper for net.SplitHostPort for the cases when the hostport
// does not necessarily contain a port.
func SplitHost(hostport string) (host string, err error) {
host, _, err = net.SplitHostPort(hostport)
if err != nil {
// Check for the missing port error. If it is that error, just
// use the host as is.
//
// See the source code for net.SplitHostPort.
const missingPort = "missing port in address"
addrErr := &net.AddrError{}
if !errors.As(err, &addrErr) || addrErr.Err != missingPort {
return "", err
}
host = hostport
}
return host, nil
}

View File

@ -9,7 +9,7 @@ import (
"regexp" "regexp"
"strings" "strings"
"github.com/AdguardTeam/AdGuardHome/internal/util" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
) )
// hardwarePortInfo - information obtained using MacOS networksetup // hardwarePortInfo - information obtained using MacOS networksetup
@ -47,7 +47,7 @@ func getCurrentHardwarePortInfo(ifaceName string) (hardwarePortInfo, error) {
// it returns a map where the key is the interface name, and the value is the "hardware port" // it returns a map where the key is the interface name, and the value is the "hardware port"
// returns nil if it fails to parse the output // returns nil if it fails to parse the output
func getNetworkSetupHardwareReports() map[string]string { func getNetworkSetupHardwareReports() map[string]string {
_, out, err := util.RunCommand("networksetup", "-listallhardwareports") _, out, err := aghos.RunCommand("networksetup", "-listallhardwareports")
if err != nil { if err != nil {
return nil return nil
} }
@ -72,7 +72,7 @@ func getNetworkSetupHardwareReports() map[string]string {
func getHardwarePortInfo(hardwarePort string) (hardwarePortInfo, error) { func getHardwarePortInfo(hardwarePort string) (hardwarePortInfo, error) {
h := hardwarePortInfo{} h := hardwarePortInfo{}
_, out, err := util.RunCommand("networksetup", "-getinfo", hardwarePort) _, out, err := aghos.RunCommand("networksetup", "-getinfo", hardwarePort)
if err != nil { if err != nil {
return h, err return h, err
} }
@ -116,7 +116,7 @@ func ifaceSetStaticIP(ifaceName string) (err error) {
args = append(args, dnsAddrs...) args = append(args, dnsAddrs...)
// Setting DNS servers is necessary when configuring a static IP // Setting DNS servers is necessary when configuring a static IP
code, _, err := util.RunCommand("networksetup", args...) code, _, err := aghos.RunCommand("networksetup", args...)
if err != nil { if err != nil {
return err return err
} }
@ -125,7 +125,7 @@ func ifaceSetStaticIP(ifaceName string) (err error) {
} }
// Actually configures hardware port to have static IP // Actually configures hardware port to have static IP
code, _, err = util.RunCommand("networksetup", "-setmanual", code, _, err = aghos.RunCommand("networksetup", "-setmanual",
portInfo.name, portInfo.ip, portInfo.subnet, portInfo.gatewayIP) portInfo.name, portInfo.ip, portInfo.subnet, portInfo.gatewayIP)
if err != nil { if err != nil {
return err return err

View File

@ -0,0 +1,78 @@
package aghnet
import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/golibs/log"
)
// DefaultRefreshIvl is the default period of time between refreshing cached
// addresses.
// const DefaultRefreshIvl = 5 * time.Minute
// HostGenFunc is the signature for functions generating fake hostnames. The
// implementation must be safe for concurrent use.
type HostGenFunc func() (host string)
// unit is an alias for an existing map value.
type unit = struct{}
// SystemResolvers helps to work with local resolvers' addresses provided by OS.
type SystemResolvers interface {
// Get returns the slice of local resolvers' addresses.
// It should be safe for concurrent use.
Get() (rs []string)
// Refresh refreshes the local resolvers' addresses cache. It should be
// safe for concurrent use.
Refresh() (err error)
}
const (
// fakeDialErr is an error which dialFunc is expected to return.
fakeDialErr agherr.Error = "this error signals the successful dialFunc work"
// badAddrPassedErr is returned when dialFunc can't parse an IP address.
badAddrPassedErr agherr.Error = "the passed string is not a valid IP address"
)
// refreshWithTicker refreshes the cache of sr after each tick form tickCh.
func refreshWithTicker(sr SystemResolvers, tickCh <-chan time.Time) {
defer agherr.LogPanic("systemResolvers")
// TODO(e.burkov): Implement a functionality to stop ticker.
for range tickCh {
err := sr.Refresh()
if err != nil {
log.Error("systemResolvers: error in refreshing goroutine: %s", err)
continue
}
log.Debug("systemResolvers: local addresses cache is refreshed")
}
}
// NewSystemResolvers returns a SystemResolvers with the cache refresh rate
// defined by refreshIvl. It disables auto-resfreshing if refreshIvl is 0. If
// nil is passed for hostGenFunc, the default generator will be used.
func NewSystemResolvers(
refreshIvl time.Duration,
hostGenFunc HostGenFunc,
) (sr SystemResolvers, err error) {
sr = newSystemResolvers(refreshIvl, hostGenFunc)
// Fill cache.
err = sr.Refresh()
if err != nil {
return nil, err
}
if refreshIvl > 0 {
ticker := time.NewTicker(refreshIvl)
go refreshWithTicker(sr, ticker.C)
}
return sr, nil
}

View File

@ -0,0 +1,96 @@
// +build !windows
package aghnet
import (
"context"
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
)
// defaultHostGen is the default method of generating host for Refresh.
func defaultHostGen() (host string) {
// TODO(e.burkov): Use strings.Builder.
return fmt.Sprintf("test%d.org", time.Now().UnixNano())
}
// systemResolvers is a default implementation of SystemResolvers interface.
type systemResolvers struct {
resolver *net.Resolver
hostGenFunc HostGenFunc
// addrs is the map that contains cached local resolvers' addresses.
addrs map[string]unit
addrsLock sync.RWMutex
}
func (sr *systemResolvers) Refresh() (err error) {
defer agherr.Annotate("systemResolvers: %w", &err)
_, err = sr.resolver.LookupHost(context.Background(), sr.hostGenFunc())
dnserr := &net.DNSError{}
if errors.As(err, &dnserr) && dnserr.Err == fakeDialErr.Error() {
return nil
}
return err
}
func newSystemResolvers(refreshIvl time.Duration, hostGenFunc HostGenFunc) (sr SystemResolvers) {
if hostGenFunc == nil {
hostGenFunc = defaultHostGen
}
s := &systemResolvers{
resolver: &net.Resolver{
PreferGo: true,
},
hostGenFunc: hostGenFunc,
addrs: make(map[string]unit),
}
s.resolver.Dial = s.dialFunc
return s
}
// dialFunc gets the resolver's address and puts it into internal cache.
func (sr *systemResolvers) dialFunc(_ context.Context, _, address string) (_ net.Conn, err error) {
// Just validate the passed address is a valid IP.
var host string
host, err = SplitHost(address)
if err != nil {
// TODO(e.burkov): Maybe use a structured badAddrPassedErr to
// allow unwrapping of the real error.
return nil, fmt.Errorf("%s: %w", err, badAddrPassedErr)
}
if net.ParseIP(host) == nil {
return nil, fmt.Errorf("parsing %q: %w", host, badAddrPassedErr)
}
sr.addrsLock.Lock()
defer sr.addrsLock.Unlock()
sr.addrs[address] = unit{}
return nil, fakeDialErr
}
func (sr *systemResolvers) Get() (rs []string) {
sr.addrsLock.RLock()
defer sr.addrsLock.RUnlock()
addrs := sr.addrs
rs = make([]string, len(addrs))
var i int
for addr := range addrs {
rs[i] = addr
i++
}
return rs
}

View File

@ -0,0 +1,74 @@
// +build !windows
package aghnet
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func createTestSystemResolversImp(
t *testing.T,
refreshDur time.Duration,
hostGenFunc HostGenFunc,
) (imp *systemResolvers) {
t.Helper()
sr := createTestSystemResolvers(t, refreshDur, hostGenFunc)
var ok bool
imp, ok = sr.(*systemResolvers)
require.True(t, ok)
return imp
}
func TestSystemResolvers_Refresh(t *testing.T) {
t.Run("expected_error", func(t *testing.T) {
sr := createTestSystemResolvers(t, 0, nil)
assert.NoError(t, sr.Refresh())
})
t.Run("unexpected_error", func(t *testing.T) {
_, err := NewSystemResolvers(0, func() string {
return "127.0.0.1::123"
})
assert.Error(t, err)
})
}
func TestSystemResolvers_DialFunc(t *testing.T) {
imp := createTestSystemResolversImp(t, 0, nil)
testCases := []struct {
name string
address string
want error
}{{
name: "valid",
address: "127.0.0.1",
want: fakeDialErr,
}, {
name: "invalid_split_host",
address: "127.0.0.1::123",
want: badAddrPassedErr,
}, {
name: "invalid_parse_ip",
address: "not-ip",
want: badAddrPassedErr,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
conn, err := imp.dialFunc(context.Background(), "", tc.address)
require.Nil(t, conn)
assert.ErrorIs(t, err, tc.want)
})
}
}

View File

@ -0,0 +1,33 @@
package aghnet
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func createTestSystemResolvers(
t *testing.T,
refreshDur time.Duration,
hostGenFunc HostGenFunc,
) (sr SystemResolvers) {
t.Helper()
var err error
sr, err = NewSystemResolvers(refreshDur, hostGenFunc)
require.NoError(t, err)
require.NotNil(t, sr)
return sr
}
func TestSystemResolvers_Get(t *testing.T) {
sr := createTestSystemResolvers(t, 0, nil)
assert.NotEmpty(t, sr.Get())
}
// TODO(e.burkov): Write tests for refreshWithTicker.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2846.

View File

@ -0,0 +1,158 @@
// +build windows
package aghnet
import (
"bufio"
"fmt"
"io"
"net"
"os/exec"
"strings"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/log"
)
// systemResolvers implementation differs for Windows since Go's resolver
// doesn't work there.
//
// See https://github.com/golang/go/issues/33097.
type systemResolvers struct {
// addrs is the slice of cached local resolvers' addresses.
addrs []string
addrsLock sync.RWMutex
}
func newSystemResolvers(refreshIvl time.Duration, _ HostGenFunc) (sr SystemResolvers) {
return &systemResolvers{}
}
func (sr *systemResolvers) Get() (rs []string) {
sr.addrsLock.RLock()
defer sr.addrsLock.RUnlock()
addrs := sr.addrs
rs = make([]string, len(addrs))
copy(rs, addrs)
return rs
}
// getAddrs gets local resolvers' addresses from OS in a special Windows way.
//
// TODO(e.burkov): This whole function needs more detailed research on getting
// local resolvers addresses on Windows. We execute the external command for
// now that is not the most accurate way.
func (sr *systemResolvers) getAddrs() (addrs []string, err error) {
cmd := exec.Command("nslookup")
var stdin io.WriteCloser
stdin, err = cmd.StdinPipe()
if err != nil {
return nil, fmt.Errorf("getting the command's stdin pipe: %w", err)
}
var stdout io.ReadCloser
stdout, err = cmd.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("getting the command's stdout pipe: %w", err)
}
var stdoutLimited io.ReadCloser
stdoutLimited, err = aghio.LimitReadCloser(stdout, aghos.MaxCmdOutputSize)
if err != nil {
return nil, fmt.Errorf("limiting stdout reader: %w", err)
}
go func() {
defer agherr.LogPanic("systemResolvers")
defer func() {
derr := stdin.Close()
if derr != nil {
log.Error("systemResolvers: closing stdin pipe: %s", derr)
}
}()
_, werr := io.WriteString(stdin, "exit")
if werr != nil {
log.Error("systemResolvers: writing to command pipe: %s", werr)
}
}()
err = cmd.Start()
if err != nil {
return nil, fmt.Errorf("start command executing: %w", err)
}
// The output of nslookup looks like this:
//
// Default Server: 192-168-1-1.qualified.domain.ru
// Address: 192.168.1.1
var possibleIPs []string
s := bufio.NewScanner(stdoutLimited)
for s.Scan() {
line := s.Text()
if len(line) == 0 {
continue
}
fields := strings.Fields(line)
if len(fields) != 2 || fields[0] != "Address:" {
continue
}
// If the address contains port then it is separated with '#'.
ipStrs := strings.Split(fields[1], "#")
if len(ipStrs) == 0 {
continue
}
possibleIPs = append(possibleIPs, ipStrs[0])
}
err = cmd.Wait()
if err != nil {
return nil, fmt.Errorf("executing the command: %w", err)
}
// Don't close StdoutPipe since Wait do it for us in ¿most? cases.
//
// See go doc os/exec.Cmd.StdoutPipe.
for _, addr := range possibleIPs {
if net.ParseIP(addr) == nil {
log.Debug("systemResolvers: %q is not a valid ip", addr)
continue
}
addrs = append(addrs, addr)
}
return addrs, nil
}
func (sr *systemResolvers) Refresh() (err error) {
defer agherr.Annotate("systemResolvers: %w", &err)
got, err := sr.getAddrs()
if err != nil {
return fmt.Errorf("can't get addresses: %w", err)
}
if len(got) == 0 {
return nil
}
sr.addrsLock.Lock()
defer sr.addrsLock.Unlock()
sr.addrs = got
return nil
}

View File

@ -0,0 +1,7 @@
// +build windows
package aghnet
// TODO(e.burkov): Write tests for Windows implementation.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2846.

View File

@ -1,7 +1,11 @@
// Package aghos contains utilities for functions requiring system calls. // Package aghos contains utilities for functions requiring system calls.
package aghos package aghos
import "syscall" import (
"fmt"
"os/exec"
"syscall"
)
// CanBindPrivilegedPorts checks if current process can bind to privileged // CanBindPrivilegedPorts checks if current process can bind to privileged
// ports. // ports.
@ -24,3 +28,20 @@ func HaveAdminRights() (bool, error) {
func SendProcessSignal(pid int, sig syscall.Signal) error { func SendProcessSignal(pid int, sig syscall.Signal) error {
return sendProcessSignal(pid, sig) return sendProcessSignal(pid, sig)
} }
// MaxCmdOutputSize is the maximum length of performed shell command output.
const MaxCmdOutputSize = 2 * 1024
// RunCommand runs shell command.
func RunCommand(command string, arguments ...string) (int, string, error) {
cmd := exec.Command(command, arguments...)
out, err := cmd.Output()
if len(out) > MaxCmdOutputSize {
out = out[:MaxCmdOutputSize]
}
if err != nil {
return 1, "", fmt.Errorf("exec.Command(%s) failed: %v: %s", command, err, string(out))
}
return cmd.ProcessState.ExitCode(), string(out), nil
}

View File

@ -3,12 +3,12 @@ package aghtest
import ( import (
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"net" "net"
"strings" "strings"
"sync" "sync"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -166,7 +166,10 @@ type TestErrUpstream struct{}
// Exchange always returns nil Msg and non-nil error. // Exchange always returns nil Msg and non-nil error.
func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) { func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) {
return nil, agherr.Error("bad") // We don't use an agherr.Error to avoid the import cycle since aghtests
// used to provide the utilities for testing which agherr (and any other
// testable package) should be able to use.
return nil, errors.New("bad")
} }
// Address always returns an empty string. // Address always returns an empty string.

View File

@ -2,7 +2,6 @@ package home
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@ -11,6 +10,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/version" "github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
@ -213,22 +213,11 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
return true return true
} }
host, _, err := net.SplitHostPort(r.Host) host, err := aghnet.SplitHost(r.Host)
if err != nil { if err != nil {
// Check for the missing port error. If it is that error, just httpError(w, http.StatusBadRequest, "bad host: %s", err)
// use the host as is.
//
// See the source code for net.SplitHostPort.
const missingPort = "missing port in address"
addrErr := &net.AddrError{} return false
if !errors.As(err, &addrErr) || addrErr.Err != missingPort {
httpError(w, http.StatusBadRequest, "bad host: %s", err)
return false
}
host = r.Host
} }
if r.TLS == nil && web.forceHTTPS { if r.TLS == nil && web.forceHTTPS {

View File

@ -15,6 +15,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
@ -270,14 +271,8 @@ func copyInstallSettings(dst, src *configuration) {
// shutdownTimeout is the timeout for shutting HTTP server down operation. // shutdownTimeout is the timeout for shutting HTTP server down operation.
const shutdownTimeout = 5 * time.Second const shutdownTimeout = 5 * time.Second
func logPanic() {
if v := recover(); v != nil {
log.Error("recovered from panic: %v", v)
}
}
func shutdownSrv(ctx context.Context, cancel context.CancelFunc, srv *http.Server) { func shutdownSrv(ctx context.Context, cancel context.CancelFunc, srv *http.Server) {
defer logPanic() defer agherr.LogPanic("")
if srv == nil { if srv == nil {
return return

View File

@ -98,7 +98,7 @@ func sendSigReload() {
if os.IsNotExist(err) { if os.IsNotExist(err) {
var code int var code int
var psdata string var psdata string
code, psdata, err = util.RunCommand("ps", "-C", serviceName, "-o", "pid=") code, psdata, err = aghos.RunCommand("ps", "-C", serviceName, "-o", "pid=")
if err != nil || code != 0 { if err != nil || code != 0 {
log.Error("Can't find AdGuardHome process: %s code:%d", err, code) log.Error("Can't find AdGuardHome process: %s code:%d", err, code)
return return
@ -301,7 +301,7 @@ func configureService(c *service.Config) {
// returns command code or error if any // returns command code or error if any
func runInitdCommand(action string) (int, error) { func runInitdCommand(action string) (int, error) {
confPath := "/etc/init.d/" + serviceName confPath := "/etc/init.d/" + serviceName
code, _, err := util.RunCommand("sh", "-c", confPath+" "+action) code, _, err := aghos.RunCommand("sh", "-c", confPath+" "+action)
return code, err return code, err
} }

View File

@ -1,7 +1,6 @@
package querylog package querylog
import ( import (
"errors"
"io" "io"
"testing" "testing"
"time" "time"
@ -116,7 +115,7 @@ func TestQLogReader_Seek(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
err = r.SeekTS(ts.UnixNano()) err = r.SeekTS(ts.UnixNano())
assert.True(t, errors.Is(err, tc.want), err) assert.ErrorIs(t, err, tc.want)
}) })
} }
} }

View File

@ -6,25 +6,12 @@ package util
import ( import (
"bytes" "bytes"
"fmt"
"io/ioutil" "io/ioutil"
"os/exec"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strings" "strings"
) )
// RunCommand runs shell command.
func RunCommand(command string, arguments ...string) (int, string, error) {
cmd := exec.Command(command, arguments...)
out, err := cmd.Output()
if err != nil {
return 1, "", fmt.Errorf("exec.Command(%s) failed: %v: %s", command, err, string(out))
}
return cmd.ProcessState.ExitCode(), string(out), nil
}
// SplitNext - split string by a byte and return the first chunk // SplitNext - split string by a byte and return the first chunk
// Skip empty chunks // Skip empty chunks
// Whitespace is trimmed // Whitespace is trimmed