diff --git a/HACKING.md b/HACKING.md index ec9528e9..1b19d217 100644 --- a/HACKING.md +++ b/HACKING.md @@ -66,8 +66,7 @@ on GitHub and most other Markdown renderers. --> ### Code * Always `recover` from panics in new goroutines. Preferably in the very - first statement. If all you want there is a log message, use - `agherr.LogPanic`. + first statement. If all you want there is a log message, use `log.OnPanic`. * Avoid `fallthrough`. It makes it harder to rearrange `case`s, to reason about the code, and also to switch the code to a handler approach, if that diff --git a/go.mod b/go.mod index 65257868..33934b9b 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.16 require ( github.com/AdguardTeam/dnsproxy v0.37.4 - github.com/AdguardTeam/golibs v0.5.0 + github.com/AdguardTeam/golibs v0.8.0 github.com/AdguardTeam/urlfilter v0.14.5 github.com/NYTimes/gziphandler v1.1.1 github.com/ameshkov/dnscrypt/v2 v2.1.3 diff --git a/go.sum b/go.sum index 78e87d4c..454ecee9 100644 --- a/go.sum +++ b/go.sum @@ -14,8 +14,8 @@ github.com/AdguardTeam/dnsproxy v0.37.4/go.mod h1:xkJWEuTr550gPDmB9azsciKZzSXjf9 github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.4/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= -github.com/AdguardTeam/golibs v0.5.0 h1:qwhEKjDrT0UcwDnHrNU2Yg/DLR9b/GsUncnXYW6VzAU= -github.com/AdguardTeam/golibs v0.5.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= +github.com/AdguardTeam/golibs v0.8.0 h1:rHo+yIgT2fivFG0yW2Cwk/DPc2+t/Aw6QvzPpiIFre0= +github.com/AdguardTeam/golibs v0.8.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU= github.com/AdguardTeam/urlfilter v0.14.5 h1:WyF0hg0MwKevsqNPkoaZFH8f5WRi/yuy/7qePtYt5Ts= github.com/AdguardTeam/urlfilter v0.14.5/go.mod h1:klx4JbOfc4EaNb5lWLqOwfg+pVcyRukmoJRvO55lL5U= diff --git a/internal/agherr/agherr.go b/internal/agherr/agherr.go deleted file mode 100644 index 0a9f1b6d..00000000 --- a/internal/agherr/agherr.go +++ /dev/null @@ -1,128 +0,0 @@ -// Package agherr contains AdGuard Home's error handling helpers. -package agherr - -import ( - "fmt" - "strings" - - "github.com/AdguardTeam/golibs/log" -) - -// Error is the constant error type. -type Error string - -// Error implements the error interface for Error. -func (err Error) Error() (msg string) { - return string(err) -} - -// manyError is an error containing several wrapped errors. It is created to be -// a simpler version of the API provided by github.com/joomcode/errorx. -type manyError struct { - message string - underlying []error -} - -// Many wraps several errors and returns a single error. -// -// TODO(a.garipov): Add formatting to message. -func Many(message string, underlying ...error) (err error) { - err = &manyError{ - message: message, - underlying: underlying, - } - - return err -} - -// Error implements the error interface for *manyError. -func (e *manyError) Error() (msg string) { - switch len(e.underlying) { - case 0: - return e.message - case 1: - return fmt.Sprintf("%s: %s", e.message, e.underlying[0]) - default: - b := &strings.Builder{} - - // Ignore errors, since strings.(*Buffer).Write never returns - // errors. We don't use aghstrings.WriteToBuilder here since - // this package should be importable for any other. - _, _ = fmt.Fprintf(b, "%s: %s (hidden: %s", e.message, e.underlying[0], e.underlying[1]) - for _, u := range e.underlying[2:] { - // See comment above. - _, _ = fmt.Fprintf(b, ", %s", u) - } - - // See comment above. - _, _ = b.WriteString(")") - - return b.String() - } -} - -// Unwrap implements the hidden errors.wrapper interface for *manyError. -func (e *manyError) Unwrap() (err error) { - if len(e.underlying) == 0 { - return nil - } - - return e.underlying[0] -} - -// wrapper is a copy of the hidden errors.wrapper interface for tests, linting, -// etc. -type wrapper interface { - Unwrap() error -} - -// Annotate annotates the error with the message, unless the error is nil. This -// is a helper function to simplify code like this: -// -// func (f *foo) doStuff(s string) (err error) { -// defer func() { -// if err != nil { -// err = fmt.Errorf("bad foo string %q: %w", s, err) -// } -// }() -// -// // … -// } -// -// Instead, write: -// -// func (f *foo) doStuff(s string) (err error) { -// defer agherr.Annotate("bad foo string %q: %w", &err, s) -// -// // … -// } -// -// msg must contain the final ": %w" verb. -// -// TODO(a.garipov): Clearify the function usage. -func Annotate(msg string, errPtr *error, args ...interface{}) { - if errPtr == nil { - return - } - - err := *errPtr - if err != nil { - args = append(args, err) - - *errPtr = fmt.Errorf(msg, args...) - } -} - -// LogPanic is a convinient helper function to log a panic in a goroutine. It -// should not be used where proper error handling is required. -func LogPanic(prefix string) { - if v := recover(); v != nil { - if prefix != "" { - log.Error("%s: recovered from panic: %v", prefix, v) - - return - } - - log.Error("recovered from panic: %v", v) - } -} diff --git a/internal/agherr/agherr_test.go b/internal/agherr/agherr_test.go deleted file mode 100644 index e7f53083..00000000 --- a/internal/agherr/agherr_test.go +++ /dev/null @@ -1,160 +0,0 @@ -package agherr - -import ( - "bytes" - "errors" - "fmt" - "testing" - - "github.com/AdguardTeam/AdGuardHome/internal/aghtest" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestError_Error(t *testing.T) { - testCases := []struct { - err error - name string - want string - }{{ - err: Many("a"), - name: "simple", - want: "a", - }, { - err: Many("a", errors.New("b")), - name: "wrapping", - want: "a: b", - }, { - err: Many("a", errors.New("b"), errors.New("c"), errors.New("d")), - name: "wrapping several", - want: "a: b (hidden: c, d)", - }, { - err: Many("a", Many("b", errors.New("c"), errors.New("d"))), - name: "wrapping wrapper", - want: "a: b: c (hidden: d)", - }} - - for _, tc := range testCases { - assert.Equal(t, tc.want, tc.err.Error(), tc.name) - } -} - -func TestError_Unwrap(t *testing.T) { - var _ wrapper = &manyError{} - - const ( - errSimple = iota - errWrapped - errNil - ) - - errs := []error{ - errSimple: errors.New("a"), - errWrapped: fmt.Errorf("err: %w", errors.New("nested")), - errNil: nil, - } - - testCases := []struct { - want error - wrapped error - name string - }{{ - want: errs[errSimple], - wrapped: Many("a", errs[errSimple]), - name: "simple", - }, { - want: errs[errWrapped], - wrapped: Many("b", errs[errWrapped]), - name: "nested", - }, { - want: errs[errNil], - wrapped: Many("c", errs[errNil]), - name: "nil passed", - }, { - want: nil, - wrapped: Many("d"), - name: "nil not passed", - }} - - for _, tc := range testCases { - assert.Equal(t, tc.want, errors.Unwrap(tc.wrapped), tc.name) - } -} - -func TestAnnotate(t *testing.T) { - const s = "1234" - const wantMsg = `bad string "1234": test` - - // Don't use const, because we can't take a pointer of a constant. - var errTest error = Error("test") - - t.Run("nil", func(t *testing.T) { - var errPtr *error - assert.NotPanics(t, func() { - Annotate("bad string %q: %w", errPtr, s) - }) - }) - - t.Run("non_nil", func(t *testing.T) { - errPtr := &errTest - assert.NotPanics(t, func() { - Annotate("bad string %q: %w", errPtr, s) - }) - - require.NotNil(t, errPtr) - - err := *errPtr - require.Error(t, err) - - assert.Equal(t, wantMsg, err.Error()) - }) - - t.Run("defer", func(t *testing.T) { - f := func() (err error) { - defer Annotate("bad string %q: %w", &errTest, s) - - return errTest - } - - err := f() - require.Error(t, err) - - assert.Equal(t, wantMsg, err.Error()) - }) -} - -func TestLogPanic(t *testing.T) { - buf := &bytes.Buffer{} - aghtest.ReplaceLogWriter(t, buf) - - t.Run("prefix", func(t *testing.T) { - const ( - panicMsg = "spooky!" - prefix = "packagename" - errWithNoPrefix = "[error] recovered from panic: spooky!" - errWithPrefix = "[error] packagename: recovered from panic: spooky!" - ) - - panicFunc := func(prefix string) { - defer LogPanic(prefix) - - panic(panicMsg) - } - - panicFunc("") - assert.Contains(t, buf.String(), errWithNoPrefix) - buf.Reset() - - panicFunc(prefix) - assert.Contains(t, buf.String(), errWithPrefix) - buf.Reset() - }) - - t.Run("don't_panic", func(t *testing.T) { - require.NotPanics(t, func() { - defer LogPanic("") - }) - - assert.Empty(t, buf.String()) - }) -} diff --git a/internal/aghio/limitedreadcloser.go b/internal/aghio/limitedreadcloser.go deleted file mode 100644 index 7690705a..00000000 --- a/internal/aghio/limitedreadcloser.go +++ /dev/null @@ -1,59 +0,0 @@ -// Package aghio contains extensions for io package's types and methods -package aghio - -import ( - "fmt" - "io" -) - -// LimitReachedError records the limit and the operation that caused it. -type LimitReachedError struct { - Limit int64 -} - -// Error implements error interface for LimitReachedError. -// TODO(a.garipov): Think about error string format. -func (lre *LimitReachedError) Error() string { - return fmt.Sprintf("attempted to read more than %d bytes", lre.Limit) -} - -// limitedReadCloser is a wrapper for io.ReadCloser with limited reader and -// dealing with agherr package. -type limitedReadCloser struct { - limit int64 - n int64 - rc io.ReadCloser -} - -// Read implements Reader interface. -func (lrc *limitedReadCloser) Read(p []byte) (n int, err error) { - if lrc.n == 0 { - return 0, &LimitReachedError{ - Limit: lrc.limit, - } - } - if int64(len(p)) > lrc.n { - p = p[0:lrc.n] - } - n, err = lrc.rc.Read(p) - lrc.n -= int64(n) - return n, err -} - -// Close implements Closer interface. -func (lrc *limitedReadCloser) Close() error { - return lrc.rc.Close() -} - -// LimitReadCloser wraps ReadCloser to make it's Reader stop with -// ErrLimitReached after n bytes read. -func LimitReadCloser(rc io.ReadCloser, n int64) (limited io.ReadCloser, err error) { - if n < 0 { - return nil, fmt.Errorf("aghio: invalid n in LimitReadCloser: %d", n) - } - return &limitedReadCloser{ - limit: n, - n: n, - rc: rc, - }, nil -} diff --git a/internal/aghio/limitedreader.go b/internal/aghio/limitedreader.go new file mode 100644 index 00000000..0f185da0 --- /dev/null +++ b/internal/aghio/limitedreader.go @@ -0,0 +1,59 @@ +// Package aghio contains extensions for io package's types and methods +package aghio + +import ( + "fmt" + "io" +) + +// LimitReachedError records the limit and the operation that caused it. +type LimitReachedError struct { + Limit int64 +} + +// Error implements error interface for LimitReachedError. +// +// TODO(a.garipov): Think about error string format. +func (lre *LimitReachedError) Error() string { + return fmt.Sprintf("attempted to read more than %d bytes", lre.Limit) +} + +// limitedReader is a wrapper for io.Reader with limited reader and dealing with +// errors package. +type limitedReader struct { + r io.Reader + limit int64 + n int64 +} + +// Read implements Reader interface. +func (lr *limitedReader) Read(p []byte) (n int, err error) { + if lr.n == 0 { + return 0, &LimitReachedError{ + Limit: lr.limit, + } + } + + if int64(len(p)) > lr.n { + p = p[0:lr.n] + } + + n, err = lr.r.Read(p) + lr.n -= int64(n) + + return n, err +} + +// LimitReader wraps Reader to make it's Reader stop with ErrLimitReached after +// n bytes read. +func LimitReader(r io.Reader, n int64) (limited io.Reader, err error) { + if n < 0 { + return nil, fmt.Errorf("aghio: invalid n in LimitReader: %d", n) + } + + return &limitedReader{ + r: r, + limit: n, + n: n, + }, nil +} diff --git a/internal/aghio/limitedreadcloser_test.go b/internal/aghio/limitedreader_test.go similarity index 83% rename from internal/aghio/limitedreadcloser_test.go rename to internal/aghio/limitedreader_test.go index 66ef2ec3..b9514819 100644 --- a/internal/aghio/limitedreadcloser_test.go +++ b/internal/aghio/limitedreader_test.go @@ -10,7 +10,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestLimitReadCloser(t *testing.T) { +func TestLimitReader(t *testing.T) { testCases := []struct { want error name string @@ -24,20 +24,20 @@ func TestLimitReadCloser(t *testing.T) { name: "zero", n: 0, }, { - want: fmt.Errorf("aghio: invalid n in LimitReadCloser: -1"), + want: fmt.Errorf("aghio: invalid n in LimitReader: -1"), name: "negative", n: -1, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - _, err := LimitReadCloser(nil, tc.n) + _, err := LimitReader(nil, tc.n) assert.Equal(t, tc.want, err) }) } } -func TestLimitedReadCloser_Read(t *testing.T) { +func TestLimitedReader_Read(t *testing.T) { testCases := []struct { err error name string @@ -77,7 +77,7 @@ func TestLimitedReadCloser_Read(t *testing.T) { readCloser := io.NopCloser(strings.NewReader(tc.rStr)) buf := make([]byte, tc.limit+1) - lreader, err := LimitReadCloser(readCloser, tc.limit) + lreader, err := LimitReader(readCloser, tc.limit) require.NoError(t, err) n, err := lreader.Read(buf) @@ -87,7 +87,7 @@ func TestLimitedReadCloser_Read(t *testing.T) { } } -func TestLimitedReadCloser_LimitReachedError(t *testing.T) { +func TestLimitedReader_LimitReachedError(t *testing.T) { testCases := []struct { err error name string diff --git a/internal/aghnet/addr.go b/internal/aghnet/addr.go index ba81733f..ebe7d48d 100644 --- a/internal/aghnet/addr.go +++ b/internal/aghnet/addr.go @@ -6,7 +6,7 @@ import ( "strconv" "strings" - "github.com/AdguardTeam/AdGuardHome/internal/agherr" + "github.com/AdguardTeam/golibs/errors" "golang.org/x/net/idna" ) @@ -26,11 +26,11 @@ func isValidHostRune(r rune) (ok bool) { // ValidateHardwareAddress returns an error if hwa is not a valid EUI-48, // EUI-64, or 20-octet InfiniBand link-layer address. func ValidateHardwareAddress(hwa net.HardwareAddr) (err error) { - defer agherr.Annotate("validating hardware address %q: %w", &err, hwa) + defer func() { err = errors.Annotate(err, "validating hardware address %q: %w", hwa) }() switch l := len(hwa); l { case 0: - return agherr.Error("address is empty") + return errors.Error("address is empty") case 6, 8, 20: return nil default: @@ -51,13 +51,13 @@ const maxDomainNameLen = 253 // ValidateDomainNameLabel returns an error if label is not a valid label of // a domain name. func ValidateDomainNameLabel(label string) (err error) { - defer agherr.Annotate("validating label %q: %w", &err, label) + defer func() { err = errors.Annotate(err, "validating label %q: %w", label) }() l := len(label) if l > maxDomainLabelLen { return fmt.Errorf("label is too long, max: %d", maxDomainLabelLen) } else if l == 0 { - return agherr.Error("label is empty") + return errors.Error("label is empty") } if r := label[0]; !IsValidHostOuterRune(rune(r)) { @@ -87,7 +87,7 @@ func ValidateDomainNameLabel(label string) (err error) { // TODO(a.garipov): After making sure that this works correctly, port this into // module golibs. func ValidateDomainName(name string) (err error) { - defer agherr.Annotate("validating domain name %q: %w", &err, name) + defer func() { err = errors.Annotate(err, "validating domain name %q: %w", name) }() name, err = idna.ToASCII(name) if err != nil { @@ -96,7 +96,7 @@ func ValidateDomainName(name string) (err error) { l := len(name) if l == 0 { - return agherr.Error("domain name is empty") + return errors.Error("domain name is empty") } else if l > maxDomainNameLen { return fmt.Errorf("too long, max: %d", maxDomainNameLen) } diff --git a/internal/aghnet/etchostscontainer.go b/internal/aghnet/etchostscontainer.go index 9cb17c83..3a1487be 100644 --- a/internal/aghnet/etchostscontainer.go +++ b/internal/aghnet/etchostscontainer.go @@ -2,7 +2,6 @@ package aghnet import ( "bufio" - "errors" "io" "net" "os" @@ -12,6 +11,7 @@ import ( "sync" "github.com/AdguardTeam/AdGuardHome/internal/aghos" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/fsnotify/fsnotify" "github.com/miekg/dns" @@ -239,7 +239,13 @@ func (ehc *EtcHostsContainer) load(table map[string][]net.IP, tableRev map[strin log.Error("etchostscontainer: %s", err) return } - defer f.Close() + defer func() { + derr := f.Close() + if derr != nil { + log.Error("etchostscontainer: closing file: %s", err) + } + }() + r := bufio.NewReader(f) log.Debug("etchostscontainer: loading hosts from file %s", fn) diff --git a/internal/aghnet/net.go b/internal/aghnet/net.go index 643b61f8..ddb26d56 100644 --- a/internal/aghnet/net.go +++ b/internal/aghnet/net.go @@ -3,7 +3,6 @@ package aghnet import ( "encoding/json" - "errors" "fmt" "net" "os" @@ -14,14 +13,14 @@ import ( "syscall" "time" - "github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/AdguardTeam/AdGuardHome/internal/aghstrings" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" ) // ErrNoStaticIPInfo is returned by IfaceHasStaticIP when no information about // the IP being static is available. -const ErrNoStaticIPInfo agherr.Error = "no information about static ip" +const ErrNoStaticIPInfo errors.Error = "no information about static ip" // IfaceHasStaticIP checks if interface is configured to have static IP address. // If it can't give a definitive answer, it returns false and an error for which @@ -106,7 +105,7 @@ func GetValidNetInterfacesForWeb() ([]*NetInterface, error) { return nil, fmt.Errorf("couldn't get interfaces: %w", err) } if len(ifaces) == 0 { - return nil, errors.New("couldn't find any legible interface") + return nil, errors.Error("couldn't find any legible interface") } var netInterfaces []*NetInterface diff --git a/internal/aghnet/net_darwin.go b/internal/aghnet/net_darwin.go index 0289fcd5..8e68e719 100644 --- a/internal/aghnet/net_darwin.go +++ b/internal/aghnet/net_darwin.go @@ -5,13 +5,13 @@ package aghnet import ( - "errors" "fmt" "os" "regexp" "strings" "github.com/AdguardTeam/AdGuardHome/internal/aghos" + "github.com/AdguardTeam/golibs/errors" ) // hardwarePortInfo - information obtained using MacOS networksetup @@ -83,7 +83,7 @@ func getHardwarePortInfo(hardwarePort string) (hardwarePortInfo, error) { match := re.FindStringSubmatch(out) if len(match) == 0 { - return h, errors.New("could not find hardware port info") + return h, errors.Error("could not find hardware port info") } h.name = hardwarePort @@ -105,7 +105,7 @@ func ifaceSetStaticIP(ifaceName string) (err error) { } if portInfo.static { - return errors.New("IP address is already static") + return errors.Error("IP address is already static") } dnsAddrs, err := getEtcResolvConfServers() @@ -151,7 +151,7 @@ func getEtcResolvConfServers() ([]string, error) { matches := re.FindAllStringSubmatch(string(body), -1) if len(matches) == 0 { - return nil, errors.New("found no DNS servers in /etc/resolv.conf") + return nil, errors.Error("found no DNS servers in /etc/resolv.conf") } addrs := make([]string, 0) diff --git a/internal/aghnet/net_linux.go b/internal/aghnet/net_linux.go index ce95db8e..b945dea6 100644 --- a/internal/aghnet/net_linux.go +++ b/internal/aghnet/net_linux.go @@ -6,7 +6,6 @@ package aghnet import ( "bufio" - "errors" "fmt" "io" "net" @@ -14,6 +13,7 @@ import ( "strings" "github.com/AdguardTeam/AdGuardHome/internal/aghio" + "github.com/AdguardTeam/golibs/errors" "github.com/google/renameio/maybe" ) @@ -49,16 +49,15 @@ func ifaceHasStaticIP(ifaceName string) (has bool, err error) { return false, err } - defer f.Close() + defer func() { err = errors.WithDeferred(err, f.Close()) }() - var fileReadCloser io.ReadCloser - fileReadCloser, err = aghio.LimitReadCloser(f, maxConfigFileSize) + var fileReader io.Reader + fileReader, err = aghio.LimitReader(f, maxConfigFileSize) if err != nil { return false, err } - defer fileReadCloser.Close() - has, err = check.checker(fileReadCloser, ifaceName) + has, err = check.checker(fileReader, ifaceName) if err != nil { return false, err } @@ -134,7 +133,7 @@ func ifacesStaticConfig(r io.Reader, ifaceName string) (has bool, err error) { func ifaceSetStaticIP(ifaceName string) (err error) { ipNet := GetSubnet(ifaceName) if ipNet.IP == nil { - return errors.New("can't get IP address") + return errors.Error("can't get IP address") } gatewayIP := GatewayIP(ifaceName) diff --git a/internal/aghnet/systemresolvers.go b/internal/aghnet/systemresolvers.go index 4f1ed6e6..4133fc24 100644 --- a/internal/aghnet/systemresolvers.go +++ b/internal/aghnet/systemresolvers.go @@ -3,7 +3,7 @@ package aghnet import ( "time" - "github.com/AdguardTeam/AdGuardHome/internal/agherr" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" ) @@ -27,19 +27,19 @@ type SystemResolvers interface { const ( // errBadAddrPassed is returned when dialFunc can't parse an IP address. - errBadAddrPassed agherr.Error = "the passed string is not a valid IP address" + errBadAddrPassed errors.Error = "the passed string is not a valid IP address" // errFakeDial is an error which dialFunc is expected to return. - errFakeDial agherr.Error = "this error signals the successful dialFunc work" + errFakeDial errors.Error = "this error signals the successful dialFunc work" // errUnexpectedHostFormat is returned by validateDialedHost when the host has // more than one percent sign. - errUnexpectedHostFormat agherr.Error = "unexpected host format" + errUnexpectedHostFormat errors.Error = "unexpected host format" ) // refreshWithTicker refreshes the cache of sr after each tick form tickCh. func refreshWithTicker(sr SystemResolvers, tickCh <-chan time.Time) { - defer agherr.LogPanic("systemResolvers") + defer log.OnPanic("systemResolvers") // TODO(e.burkov): Implement a functionality to stop ticker. for range tickCh { diff --git a/internal/aghnet/systemresolvers_others.go b/internal/aghnet/systemresolvers_others.go index d8da42da..c0ba457a 100644 --- a/internal/aghnet/systemresolvers_others.go +++ b/internal/aghnet/systemresolvers_others.go @@ -6,15 +6,14 @@ package aghnet import ( "context" - "errors" "fmt" "net" "strings" "sync" "time" - "github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/AdguardTeam/AdGuardHome/internal/aghstrings" + "github.com/AdguardTeam/golibs/errors" ) // defaultHostGen is the default method of generating host for Refresh. @@ -34,7 +33,7 @@ type systemResolvers struct { } func (sr *systemResolvers) refresh() (err error) { - defer agherr.Annotate("systemResolvers: %w", &err) + defer func() { err = errors.Annotate(err, "systemResolvers: %w") }() _, err = sr.resolver.LookupHost(context.Background(), sr.hostGenFunc()) dnserr := &net.DNSError{} @@ -63,7 +62,7 @@ func newSystemResolvers(refreshIvl time.Duration, hostGenFunc HostGenFunc) (sr S // validateDialedHost validated the host used by resolvers in dialFunc. func validateDialedHost(host string) (err error) { - defer agherr.Annotate("parsing %q: %w", &err, host) + defer func() { err = errors.Annotate(err, "parsing %q: %w", host) }() var ipStr string parts := strings.Split(host, "%") diff --git a/internal/aghnet/systemresolvers_windows.go b/internal/aghnet/systemresolvers_windows.go index b93a3ec8..f3095860 100644 --- a/internal/aghnet/systemresolvers_windows.go +++ b/internal/aghnet/systemresolvers_windows.go @@ -14,9 +14,9 @@ import ( "sync" "time" - "github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/AdguardTeam/AdGuardHome/internal/aghio" "github.com/AdguardTeam/AdGuardHome/internal/aghos" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" ) @@ -65,14 +65,15 @@ func (sr *systemResolvers) getAddrs() (addrs []string, err error) { return nil, fmt.Errorf("getting the command's stdout pipe: %w", err) } - var stdoutLimited io.ReadCloser - stdoutLimited, err = aghio.LimitReadCloser(stdout, aghos.MaxCmdOutputSize) + var stdoutLimited io.Reader + stdoutLimited, err = aghio.LimitReader(stdout, aghos.MaxCmdOutputSize) if err != nil { return nil, fmt.Errorf("limiting stdout reader: %w", err) } go func() { - defer agherr.LogPanic("systemResolvers") + defer log.OnPanic("systemResolvers") + defer func() { derr := stdin.Close() if derr != nil { @@ -141,7 +142,7 @@ func (sr *systemResolvers) getAddrs() (addrs []string, err error) { } func (sr *systemResolvers) refresh() (err error) { - defer agherr.Annotate("systemResolvers: %w", &err) + defer func() { err = errors.Annotate(err, "systemResolvers: %w") }() got, err := sr.getAddrs() if err != nil { diff --git a/internal/aghtest/aghtest.go b/internal/aghtest/aghtest.go index a713d20d..0e7f600c 100644 --- a/internal/aghtest/aghtest.go +++ b/internal/aghtest/aghtest.go @@ -30,7 +30,7 @@ func ReplaceLogWriter(t *testing.T, w io.Writer) { // ReplaceLogLevel sets logging level to l and uses Cleanup method of t to // revert changes. -func ReplaceLogLevel(t *testing.T, l int) { +func ReplaceLogLevel(t *testing.T, l log.Level) { switch l { case log.INFO, log.DEBUG, log.ERROR: // Go on. diff --git a/internal/aghtest/upstream.go b/internal/aghtest/upstream.go index 44c6a6ce..c69982ed 100644 --- a/internal/aghtest/upstream.go +++ b/internal/aghtest/upstream.go @@ -168,9 +168,6 @@ type TestErrUpstream struct { // Exchange always returns nil Msg and non-nil error. func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) { - // We don't use an agherr.Error to avoid the import cycle since aghtests - // used to provide the utilities for testing which agherr (and any other - // testable package) should be able to use. return nil, fmt.Errorf("errupstream: %w", u.Err) } diff --git a/internal/dhcpd/checkother.go b/internal/dhcpd/checkother.go index 7d84ce79..1e5de1c0 100644 --- a/internal/dhcpd/checkother.go +++ b/internal/dhcpd/checkother.go @@ -12,6 +12,7 @@ import ( "runtime" "time" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/insomniacslk/dhcp/dhcpv4" "github.com/insomniacslk/dhcp/dhcpv4/nclient4" @@ -78,7 +79,7 @@ func CheckIfOtherDHCPServersPresentV4(ifaceName string) (ok bool, err error) { return false, fmt.Errorf("couldn't listen on :68: %w", err) } if c != nil { - defer c.Close() + defer func() { err = errors.WithDeferred(err, c.Close()) }() } // send to 255.255.255.255:67 @@ -202,7 +203,7 @@ func CheckIfOtherDHCPServersPresentV6(ifaceName string) (ok bool, err error) { return false, fmt.Errorf("dhcpv6: Couldn't listen on :546: %w", err) } if c != nil { - defer c.Close() + defer func() { err = errors.WithDeferred(err, c.Close()) }() } _, err = c.WriteTo(req.ToBytes(), dstAddr) diff --git a/internal/dhcpd/db.go b/internal/dhcpd/db.go index c3ce8486..e9fd97c0 100644 --- a/internal/dhcpd/db.go +++ b/internal/dhcpd/db.go @@ -4,11 +4,11 @@ package dhcpd import ( "encoding/json" - "errors" "net" "os" "time" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/google/renameio/maybe" ) diff --git a/internal/dhcpd/http.go b/internal/dhcpd/http.go index 757e4153..c7de4e82 100644 --- a/internal/dhcpd/http.go +++ b/internal/dhcpd/http.go @@ -2,7 +2,6 @@ package dhcpd import ( "encoding/json" - "errors" "fmt" "io" "net" @@ -11,6 +10,7 @@ import ( "strings" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" ) diff --git a/internal/dhcpd/iprange.go b/internal/dhcpd/iprange.go index 3591869a..55670743 100644 --- a/internal/dhcpd/iprange.go +++ b/internal/dhcpd/iprange.go @@ -6,7 +6,7 @@ import ( "math/big" "net" - "github.com/AdguardTeam/AdGuardHome/internal/agherr" + "github.com/AdguardTeam/golibs/errors" ) // ipRange is an inclusive range of IP addresses. A nil range is a range that @@ -28,7 +28,7 @@ const maxRangeLen = math.MaxUint32 // newIPRange creates a new IP address range. start must be less than end. The // resulting range must not be greater than maxRangeLen. func newIPRange(start, end net.IP) (r *ipRange, err error) { - defer agherr.Annotate("invalid ip range: %w", &err) + defer func() { err = errors.Annotate(err, "invalid ip range: %w") }() // Make sure that both are 16 bytes long to simplify handling in // methods. diff --git a/internal/dhcpd/options.go b/internal/dhcpd/options.go index 780eeeab..51c3658e 100644 --- a/internal/dhcpd/options.go +++ b/internal/dhcpd/options.go @@ -7,7 +7,7 @@ import ( "strconv" "strings" - "github.com/AdguardTeam/AdGuardHome/internal/agherr" + "github.com/AdguardTeam/golibs/errors" ) // hexDHCPOptionParserHandler parses a DHCP option as a hex-encoded string. @@ -32,7 +32,7 @@ func hexDHCPOptionParserHandler(s string) (data []byte, err error) { func ipDHCPOptionParserHandler(s string) (data []byte, err error) { ip := net.ParseIP(s) if ip == nil { - return nil, agherr.Error("invalid ip") + return nil, errors.Error("invalid ip") } // Most DHCP options require IPv4, so do not put the 16-byte @@ -100,12 +100,12 @@ func newDHCPOptionParser() (p *dhcpOptionParser) { // parse parses an option. See the handlers' documentation for more info. func (p *dhcpOptionParser) parse(s string) (code uint8, data []byte, err error) { - defer agherr.Annotate("invalid option string %q: %w", &err, s) + defer func() { err = errors.Annotate(err, "invalid option string %q: %w", s) }() s = strings.TrimSpace(s) parts := strings.SplitN(s, " ", 3) if len(parts) < 3 { - return 0, nil, agherr.Error("need at least three fields") + return 0, nil, errors.Error("need at least three fields") } codeStr := parts[0] diff --git a/internal/dhcpd/os_windows.go b/internal/dhcpd/os_windows.go index 607e3d1e..6c0417a0 100644 --- a/internal/dhcpd/os_windows.go +++ b/internal/dhcpd/os_windows.go @@ -1,13 +1,13 @@ package dhcpd import ( - "errors" "net" + "github.com/AdguardTeam/golibs/errors" "golang.org/x/net/ipv4" ) // Create a socket for receiving broadcast packets func newBroadcastPacketConn(bindAddr net.IP, port int, ifname string) (*ipv4.PacketConn, error) { - return nil, errors.New("newBroadcastPacketConn(): not supported on Windows") + return nil, errors.Error("newBroadcastPacketConn(): not supported on Windows") } diff --git a/internal/dhcpd/v4.go b/internal/dhcpd/v4.go index a822f7a4..ee37d097 100644 --- a/internal/dhcpd/v4.go +++ b/internal/dhcpd/v4.go @@ -6,16 +6,15 @@ package dhcpd import ( "bytes" - "errors" "fmt" "net" "strings" "sync" "time" - "github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghstrings" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/go-ping/ping" "github.com/insomniacslk/dhcp/dhcpv4" @@ -55,7 +54,7 @@ func (s *v4Server) WriteDiskConfig6(c *V6ServerConf) { // normalizeHostname normalizes a hostname sent by the client. If err is not // nil, norm is an empty string. func normalizeHostname(hostname string) (norm string, err error) { - defer agherr.Annotate("normalizing %q: %w", &err, hostname) + defer func() { err = errors.Annotate(err, "normalizing %q: %w", hostname) }() if hostname == "" { return "", nil @@ -249,7 +248,7 @@ func (s *v4Server) rmDynamicLease(lease *Lease) (err error) { if bytes.Equal(l.HWAddr, lease.HWAddr) { if l.IsStatic() { - return agherr.Error("static lease already exists") + return errors.Error("static lease already exists") } s.rmLeaseByIndex(i) @@ -262,7 +261,7 @@ func (s *v4Server) rmDynamicLease(lease *Lease) (err error) { if l.IP.Equal(lease.IP) { if l.IsStatic() { - return agherr.Error("static lease already exists") + return errors.Error("static lease already exists") } s.rmLeaseByIndex(i) @@ -322,12 +321,12 @@ func (s *v4Server) rmLease(lease Lease) (err error) { } } - return agherr.Error("lease not found") + return errors.Error("lease not found") } // AddStaticLease adds a static lease. It is safe for concurrent use. func (s *v4Server) AddStaticLease(l Lease) (err error) { - defer agherr.Annotate("dhcpv4: adding static lease: %w", &err) + defer func() { err = errors.Annotate(err, "dhcpv4: adding static lease: %w") }() if ip4 := l.IP.To4(); ip4 == nil { return fmt.Errorf("invalid ip %q, only ipv4 is supported", l.IP) @@ -397,7 +396,7 @@ func (s *v4Server) AddStaticLease(l Lease) (err error) { // RemoveStaticLease removes a static lease. It is safe for concurrent use. func (s *v4Server) RemoveStaticLease(l Lease) (err error) { - defer agherr.Annotate("dhcpv4: %w", &err) + defer func() { err = errors.Annotate(err, "dhcpv4: %w") }() if len(l.IP) != 4 { return fmt.Errorf("invalid IP") @@ -937,7 +936,7 @@ func (s *v4Server) packetHandler(conn net.PacketConn, peer net.Addr, req *dhcpv4 // Start starts the IPv4 DHCP server. func (s *v4Server) Start() (err error) { - defer agherr.Annotate("dhcpv4: %w", &err) + defer func() { err = errors.Annotate(err, "dhcpv4: %w") }() if !s.conf.Enabled { return nil diff --git a/internal/dhcpd/v46_test.go b/internal/dhcpd/v46_test.go index 6495eeee..a07ccff3 100644 --- a/internal/dhcpd/v46_test.go +++ b/internal/dhcpd/v46_test.go @@ -1,11 +1,10 @@ package dhcpd import ( - "errors" "net" "testing" - "github.com/AdguardTeam/AdGuardHome/internal/agherr" + "github.com/AdguardTeam/golibs/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -25,7 +24,7 @@ func (iface *fakeIface) Addrs() (addrs []net.Addr, err error) { } func TestIfaceIPAddrs(t *testing.T) { - const errTest agherr.Error = "test error" + const errTest errors.Error = "test error" ip4 := net.IP{1, 2, 3, 4} addr4 := &net.IPNet{IP: ip4} @@ -108,7 +107,7 @@ func (iface *waitingFakeIface) Addrs() (addrs []net.Addr, err error) { } func TestIfaceDNSIPAddrs(t *testing.T) { - const errTest agherr.Error = "test error" + const errTest errors.Error = "test error" ip4 := net.IP{1, 2, 3, 4} addr4 := &net.IPNet{IP: ip4} diff --git a/internal/dhcpd/v6.go b/internal/dhcpd/v6.go index 994a7e11..9ccfbda7 100644 --- a/internal/dhcpd/v6.go +++ b/internal/dhcpd/v6.go @@ -11,8 +11,8 @@ import ( "sync" "time" - "github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/insomniacslk/dhcp/dhcpv6" "github.com/insomniacslk/dhcp/dhcpv6/server6" @@ -165,7 +165,7 @@ func (s *v6Server) rmDynamicLease(lease Lease) error { // AddStaticLease adds a static lease. It is safe for concurrent use. func (s *v6Server) AddStaticLease(l Lease) (err error) { - defer agherr.Annotate("dhcpv6: %w", &err) + defer func() { err = errors.Annotate(err, "dhcpv6: %w") }() if len(l.IP) != 16 { return fmt.Errorf("invalid IP") @@ -194,7 +194,7 @@ func (s *v6Server) AddStaticLease(l Lease) (err error) { // RemoveStaticLease removes a static lease. It is safe for concurrent use. func (s *v6Server) RemoveStaticLease(l Lease) (err error) { - defer agherr.Annotate("dhcpv6: %w", &err) + defer func() { err = errors.Annotate(err, "dhcpv6: %w") }() if len(l.IP) != 16 { return fmt.Errorf("invalid IP") @@ -585,7 +585,7 @@ func (s *v6Server) initRA(iface *net.Interface) error { // Start starts the IPv6 DHCP server. func (s *v6Server) Start() (err error) { - defer agherr.Annotate("dhcpv6: %w", &err) + defer func() { err = errors.Annotate(err, "dhcpv6: %w") }() if !s.conf.Enabled { return nil diff --git a/internal/dnsforward/clientid.go b/internal/dnsforward/clientid.go index 6f6debe7..06e2f563 100644 --- a/internal/dnsforward/clientid.go +++ b/internal/dnsforward/clientid.go @@ -2,13 +2,13 @@ package dnsforward import ( "crypto/tls" - "errors" "fmt" "path" "strings" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/AdguardTeam/golibs/errors" "github.com/lucas-clemente/quic-go" ) diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index 8598c987..03710238 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -3,7 +3,6 @@ package dnsforward import ( "crypto/tls" "crypto/x509" - "errors" "fmt" "net" "net/http" @@ -14,6 +13,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/ameshkov/dnscrypt/v2" ) @@ -220,7 +220,7 @@ func (s *Server) createProxyConfig() (proxy.Config, error) { // Validate proxy config if proxyConfig.UpstreamConfig == nil || len(proxyConfig.UpstreamConfig.Upstreams) == 0 { - return proxyConfig, errors.New("no default upstream servers configured") + return proxyConfig, errors.Error("no default upstream servers configured") } return proxyConfig, nil diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index 471025d6..cf31500c 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -2,7 +2,6 @@ package dnsforward import ( - "errors" "fmt" "net" "net/http" @@ -12,7 +11,6 @@ import ( "sync" "time" - "github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghstrings" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" @@ -21,6 +19,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/miekg/dns" ) @@ -226,11 +225,11 @@ type RDNSExchanger interface { const ( // rDNSEmptyAnswerErr is returned by Exchange method when the answer // section of respond is empty. - rDNSEmptyAnswerErr agherr.Error = "the answer section is empty" + rDNSEmptyAnswerErr errors.Error = "the answer section is empty" // rDNSNotPTRErr is returned by Exchange method when the response is not // of PTR type. - rDNSNotPTRErr agherr.Error = "the response is not a ptr" + rDNSNotPTRErr errors.Error = "the response is not a ptr" ) // Exchange implements the RDNSExchanger interface for *Server. diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index 3037188a..8cf48ff4 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -17,13 +17,13 @@ import ( "testing" "time" - "github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/errors" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -1198,7 +1198,7 @@ func TestServer_Exchange(t *testing.T) { "2.1.168.192.in-addr.arpa.": {}, }, } - upstreamErr := agherr.Error("upstream error") + upstreamErr := errors.Error("upstream error") errUpstream := &aghtest.TestErrUpstream{ Err: upstreamErr, } diff --git a/internal/dnsforward/dnsrewrite.go b/internal/dnsforward/dnsrewrite.go index ad2d5cea..2d4cc5a7 100644 --- a/internal/dnsforward/dnsrewrite.go +++ b/internal/dnsforward/dnsrewrite.go @@ -4,9 +4,9 @@ import ( "fmt" "net" - "github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/urlfilter/rules" "github.com/miekg/dns" @@ -83,7 +83,7 @@ func (s *Server) filterDNSRewrite(req *dns.Msg, res filtering.Result, d *proxy.D resp := s.makeResponse(req) dnsrr := res.DNSRewriteResult if dnsrr == nil { - return agherr.Error("no dns rewrite rule content") + return errors.Error("no dns rewrite rule content") } resp.Rcode = dnsrr.RCode @@ -94,7 +94,7 @@ func (s *Server) filterDNSRewrite(req *dns.Msg, res filtering.Result, d *proxy.D } if dnsrr.Response == nil { - return agherr.Error("no dns rewrite rule responses") + return errors.Error("no dns rewrite rule responses") } rr := req.Question[0].Qtype diff --git a/internal/dnsforward/http.go b/internal/dnsforward/http.go index f83241fb..35c1d9d4 100644 --- a/internal/dnsforward/http.go +++ b/internal/dnsforward/http.go @@ -8,11 +8,11 @@ import ( "strconv" "strings" - "github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghstrings" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/miekg/dns" ) @@ -399,7 +399,7 @@ func validateUpstream(u string) (bool, error) { // separateUpstream returns the upstream without the specified domains. // useDefault is true when a default upstream must be used. func separateUpstream(upstreamStr string) (upstream string, useDefault bool, err error) { - defer agherr.Annotate("bad upstream for domain spec %q: %w", &err, upstreamStr) + defer func() { err = errors.Annotate(err, "bad upstream for domain spec %q: %w", upstreamStr) }() if !strings.HasPrefix(upstreamStr, "[/") { return upstreamStr, true, nil @@ -407,7 +407,7 @@ func separateUpstream(upstreamStr string) (upstream string, useDefault bool, err parts := strings.Split(upstreamStr[2:], "/]") if len(parts) != 2 { - return "", false, agherr.Error("duplicated separator") + return "", false, errors.Error("duplicated separator") } domains := parts[0] diff --git a/internal/dnsforward/ipset_linux.go b/internal/dnsforward/ipset_linux.go index 32d27c9b..4f02b634 100644 --- a/internal/dnsforward/ipset_linux.go +++ b/internal/dnsforward/ipset_linux.go @@ -10,7 +10,7 @@ import ( "strings" "sync" - "github.com/AdguardTeam/AdGuardHome/internal/agherr" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/digineo/go-ipset/v2" "github.com/mdlayher/netlink" @@ -83,7 +83,7 @@ func (c *ipsetCtx) ipsetProps(name string) (set ipsetProps, err error) { } if res == nil || res.Family == nil { - return set, agherr.Error("empty response or no family data") + return set, errors.Error("empty response or no family data") } family := netfilter.ProtoFamily(res.Family.Value) @@ -193,23 +193,23 @@ func (c *ipsetCtx) init(ipsetConfig []string) (err error) { // Close closes the Linux Netfilter connections. func (c *ipsetCtx) Close() (err error) { - var errors []error + var errs []error if c.ipv4Conn != nil { err = c.ipv4Conn.Close() if err != nil { - errors = append(errors, err) + errs = append(errs, err) } } if c.ipv6Conn != nil { err = c.ipv6Conn.Close() if err != nil { - errors = append(errors, err) + errs = append(errs, err) } } - if len(errors) != 0 { - return agherr.Many("closing ipsets", errors...) + if len(errs) != 0 { + return errors.List("closing ipsets", errs...) } return nil diff --git a/internal/home/authglinet.go b/internal/home/authglinet.go index 317e468c..213d52f9 100644 --- a/internal/home/authglinet.go +++ b/internal/home/authglinet.go @@ -73,22 +73,27 @@ func glGetTokenDate(file string) uint32 { f, err := os.Open(file) if err != nil { log.Error("os.Open: %s", err) + return 0 } - defer f.Close() + defer func() { + derr := f.Close() + if derr != nil { + log.Error("glinet: closing file: %s", err) + } + }() - fileReadCloser, err := aghio.LimitReadCloser(f, MaxFileSize) + fileReader, err := aghio.LimitReader(f, MaxFileSize) if err != nil { log.Error("creating limited reader: %s", err) return 0 } - defer fileReadCloser.Close() var dateToken uint32 // This use of ReadAll is now safe, because we limited reader. - bs, err := io.ReadAll(fileReadCloser) + bs, err := io.ReadAll(fileReader) if err != nil { log.Error("reading token: %s", err) diff --git a/internal/home/clients.go b/internal/home/clients.go index 8f05c1c4..00caa34e 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -11,7 +11,6 @@ import ( "sync" "time" - "github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghstrings" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" @@ -20,6 +19,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" ) @@ -438,11 +438,11 @@ func (clients *clientsContainer) FindRuntimeClient(ip string) (RuntimeClient, bo func (clients *clientsContainer) check(c *Client) (err error) { switch { case c == nil: - return agherr.Error("client is nil") + return errors.Error("client is nil") case c.Name == "": - return agherr.Error("invalid name") + return errors.Error("invalid name") case len(c.IDs) == 0: - return agherr.Error("id required") + return errors.Error("id required") default: // Go on. } @@ -570,14 +570,14 @@ func (clients *clientsContainer) Update(name string, c *Client) (err error) { prev, ok := clients.list[name] if !ok { - return agherr.Error("client not found") + return errors.Error("client not found") } // First, check the name index. if prev.Name != c.Name { _, ok = clients.list[c.Name] if ok { - return agherr.Error("client already exists") + return errors.Error("client already exists") } } diff --git a/internal/home/config.go b/internal/home/config.go index 46e8c2b7..7071e364 100644 --- a/internal/home/config.go +++ b/internal/home/config.go @@ -1,7 +1,6 @@ package home import ( - "errors" "fmt" "net" "os" @@ -14,6 +13,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/AdGuardHome/internal/version" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/google/renameio/maybe" yaml "gopkg.in/yaml.v2" diff --git a/internal/home/controlinstall.go b/internal/home/controlinstall.go index c441cde0..97a0c084 100644 --- a/internal/home/controlinstall.go +++ b/internal/home/controlinstall.go @@ -15,7 +15,6 @@ import ( "strings" "time" - "github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/golibs/log" ) @@ -271,7 +270,7 @@ func copyInstallSettings(dst, src *configuration) { const shutdownTimeout = 5 * time.Second func shutdownSrv(ctx context.Context, cancel context.CancelFunc, srv *http.Server) { - defer agherr.LogPanic("") + defer log.OnPanic("") if srv == nil { return diff --git a/internal/home/controlupdate.go b/internal/home/controlupdate.go index 1aafa2f7..74c5bfbc 100644 --- a/internal/home/controlupdate.go +++ b/internal/home/controlupdate.go @@ -3,7 +3,6 @@ package home import ( "context" "encoding/json" - "errors" "net/http" "os" "os/exec" @@ -14,6 +13,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/updater" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" ) diff --git a/internal/home/dns.go b/internal/home/dns.go index 68008fa2..809bb011 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -8,12 +8,12 @@ import ( "path/filepath" "strconv" - "github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/ameshkov/dnscrypt/v2" yaml "gopkg.in/yaml.v2" @@ -207,14 +207,14 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) { func newDNSCrypt(hosts []net.IP, tlsConf tlsConfigSettings) (dnscc dnsforward.DNSCryptConfig, err error) { if tlsConf.DNSCryptConfigFile == "" { - return dnscc, agherr.Error("no dnscrypt_config_file") + return dnscc, errors.Error("no dnscrypt_config_file") } f, err := os.Open(tlsConf.DNSCryptConfigFile) if err != nil { return dnscc, fmt.Errorf("opening dnscrypt config: %w", err) } - defer f.Close() + defer func() { err = errors.WithDeferred(err, f.Close()) }() rc := &dnscrypt.ResolverConfig{} err = yaml.NewDecoder(f).Decode(rc) diff --git a/internal/home/filter.go b/internal/home/filter.go index 7b46a83d..de2ad2ef 100644 --- a/internal/home/filter.go +++ b/internal/home/filter.go @@ -2,7 +2,6 @@ package home import ( "bufio" - "errors" "fmt" "hash/crc32" "io" @@ -17,6 +16,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/filtering" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" ) @@ -566,17 +566,18 @@ func (f *Filtering) updateIntl(filter *filter) (updated bool, err error) { if err != nil { return updated, fmt.Errorf("open file: %w", err) } + defer func() { err = errors.WithDeferred(err, f.Close()) }() - defer f.Close() reader = f } else { var resp *http.Response resp, err = Context.client.Get(filter.URL) if err != nil { log.Printf("Couldn't request filter from URL %s, skipping: %s", filter.URL, err) + return updated, err } - defer resp.Body.Close() + defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }() if resp.StatusCode != http.StatusOK { log.Printf("Got status code %d from URL %s, skipping", resp.StatusCode, filter.URL) @@ -634,7 +635,7 @@ func (f *Filtering) load(filter *filter) (err error) { } else if err != nil { return fmt.Errorf("opening filter file: %w", err) } - defer file.Close() + defer func() { err = errors.WithDeferred(err, file.Close()) }() st, err := file.Stat() if err != nil { diff --git a/internal/home/home.go b/internal/home/home.go index cd34927a..f40c9569 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -5,7 +5,6 @@ import ( "context" "crypto/tls" "crypto/x509" - "errors" "fmt" "io/fs" "net" @@ -21,7 +20,6 @@ import ( "syscall" "time" - "github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" @@ -31,6 +29,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/AdGuardHome/internal/updater" "github.com/AdguardTeam/AdGuardHome/internal/version" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "gopkg.in/natefinch/lumberjack.v2" ) @@ -736,7 +735,7 @@ func customDialContext(ctx context.Context, network, addr string) (conn net.Conn return conn, err } - return nil, agherr.Many(fmt.Sprintf("couldn't dial to %s", addr), dialErrs...) + return nil, errors.List(fmt.Sprintf("couldn't dial to %s", addr), dialErrs...) } func getHTTPProxy(_ *http.Request) (*url.URL, error) { diff --git a/internal/home/middlewares.go b/internal/home/middlewares.go index de38c3fa..73e1566e 100644 --- a/internal/home/middlewares.go +++ b/internal/home/middlewares.go @@ -1,6 +1,7 @@ package home import ( + "io" "net/http" "github.com/AdguardTeam/AdGuardHome/internal/aghio" @@ -58,14 +59,20 @@ func limitRequestBody(h http.Handler) (limited http.Handler) { szLim = largerReqBodySzLim } - r.Body, err = aghio.LimitReadCloser(r.Body, szLim) + var reader io.Reader + reader, err = aghio.LimitReader(r.Body, szLim) if err != nil { log.Error("limitRequestBody: %s", err) return } - h.ServeHTTP(w, r) + // HTTP handlers aren't supposed to call r.Body.Close(), so just + // replace the body in a clone. + rr := r.Clone(r.Context()) + rr.Body = io.NopCloser(reader) + + h.ServeHTTP(w, rr) }) } diff --git a/internal/home/rdns.go b/internal/home/rdns.go index a2f59c8c..81d7ca5a 100644 --- a/internal/home/rdns.go +++ b/internal/home/rdns.go @@ -5,7 +5,6 @@ import ( "net" "time" - "github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/log" @@ -85,7 +84,7 @@ func (r *RDNS) Begin(ip net.IP) { // workerLoop handles incoming IP addresses from ipChan and adds it into // clients. func (r *RDNS) workerLoop() { - defer agherr.LogPanic("rdns") + defer log.OnPanic("rdns") for ip := range r.ipCh { host, err := r.exchanger.Exchange(ip) diff --git a/internal/home/rdns_test.go b/internal/home/rdns_test.go index fd111565..329e0bf4 100644 --- a/internal/home/rdns_test.go +++ b/internal/home/rdns_test.go @@ -3,7 +3,6 @@ package home import ( "bytes" "encoding/binary" - "errors" "net" "sync" "testing" @@ -13,6 +12,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/cache" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/miekg/dns" "github.com/stretchr/testify/assert" @@ -141,7 +141,7 @@ func TestRDNS_WorkerLoop(t *testing.T) { }, } errUpstream := &aghtest.TestErrUpstream{ - Err: errors.New("1234"), + Err: errors.Error("1234"), } testCases := []struct { diff --git a/internal/home/service.go b/internal/home/service.go index 4079e239..c940da07 100644 --- a/internal/home/service.go +++ b/internal/home/service.go @@ -1,7 +1,6 @@ package home import ( - "errors" "fmt" "io/fs" "os" @@ -11,6 +10,7 @@ import ( "syscall" "github.com/AdguardTeam/AdGuardHome/internal/aghos" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/kardianos/service" ) diff --git a/internal/home/tls.go b/internal/home/tls.go index 34bc0bce..eaf5884c 100644 --- a/internal/home/tls.go +++ b/internal/home/tls.go @@ -10,7 +10,6 @@ import ( "encoding/base64" "encoding/json" "encoding/pem" - "errors" "fmt" "net/http" "os" @@ -21,6 +20,7 @@ import ( "sync" "time" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "golang.org/x/sys/cpu" ) @@ -341,14 +341,14 @@ func verifyCertChain(data *tlsConfigStatus, certChain, serverName string) error parsed, err := x509.ParseCertificate(cert.Bytes) if err != nil { data.WarningValidation = fmt.Sprintf("Failed to parse certificate: %s", err) - return errors.New(data.WarningValidation) + return errors.Error(data.WarningValidation) } parsedCerts = append(parsedCerts, parsed) } if len(parsedCerts) == 0 { data.WarningValidation = "You have specified an empty certificate" - return errors.New(data.WarningValidation) + return errors.Error(data.WarningValidation) } data.ValidCert = true @@ -415,14 +415,14 @@ func validatePkey(data *tlsConfigStatus, pkey string) error { if key == nil { data.WarningValidation = "No valid keys were found" - return errors.New(data.WarningValidation) + return errors.Error(data.WarningValidation) } // parse the decoded key _, keytype, err := parsePrivateKey(key.Bytes) if err != nil { data.WarningValidation = fmt.Sprintf("Failed to parse private key: %s", err) - return errors.New(data.WarningValidation) + return errors.Error(data.WarningValidation) } data.ValidKey = true @@ -479,7 +479,7 @@ func parsePrivateKey(der []byte) (crypto.PrivateKey, string, error) { case *ecdsa.PrivateKey: return key, "ECDSA", nil default: - return nil, "", errors.New("tls: found unknown private key type in PKCS#8 wrapping") + return nil, "", errors.Error("tls: found unknown private key type in PKCS#8 wrapping") } } @@ -487,7 +487,7 @@ func parsePrivateKey(der []byte) (crypto.PrivateKey, string, error) { return key, "ECDSA", nil } - return nil, "", errors.New("tls: failed to parse private key") + return nil, "", errors.Error("tls: failed to parse private key") } // unmarshalTLS handles base64-encoded certificates transparently diff --git a/internal/home/upgrade.go b/internal/home/upgrade.go index 278f3774..c8e54319 100644 --- a/internal/home/upgrade.go +++ b/internal/home/upgrade.go @@ -1,7 +1,6 @@ package home import ( - "errors" "fmt" "net/url" "os" @@ -12,6 +11,7 @@ import ( "strings" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/google/renameio/maybe" "golang.org/x/crypto/bcrypt" diff --git a/internal/home/whois.go b/internal/home/whois.go index 07023415..e5b0ca08 100644 --- a/internal/home/whois.go +++ b/internal/home/whois.go @@ -11,6 +11,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghio" "github.com/AdguardTeam/golibs/cache" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" ) @@ -139,22 +140,22 @@ func whoisParse(data string) (m strmap) { const MaxConnReadSize = 64 * 1024 // Send request to a server and receive the response -func (w *Whois) query(ctx context.Context, target, serverAddr string) (string, error) { +func (w *Whois) query(ctx context.Context, target, serverAddr string) (data string, err error) { addr, _, _ := net.SplitHostPort(serverAddr) if addr == "whois.arin.net" { target = "n + " + target } + conn, err := w.dialContext(ctx, "tcp", serverAddr) if err != nil { return "", err } - defer conn.Close() + defer func() { err = errors.WithDeferred(err, conn.Close()) }() - connReadCloser, err := aghio.LimitReadCloser(conn, MaxConnReadSize) + r, err := aghio.LimitReader(conn, MaxConnReadSize) if err != nil { return "", err } - defer connReadCloser.Close() _ = conn.SetReadDeadline(time.Now().Add(time.Duration(w.timeoutMsec) * time.Millisecond)) _, err = conn.Write([]byte(target + "\r\n")) @@ -163,12 +164,13 @@ func (w *Whois) query(ctx context.Context, target, serverAddr string) (string, e } // This use of ReadAll is now safe, because we limited the conn Reader. - data, err := io.ReadAll(connReadCloser) + var whoisData []byte + whoisData, err = io.ReadAll(r) if err != nil { return "", err } - return string(data), nil + return string(whoisData), nil } // Query WHOIS servers (handle redirects) diff --git a/internal/querylog/qlog.go b/internal/querylog/qlog.go index 5a556a9b..bcdaa8b4 100644 --- a/internal/querylog/qlog.go +++ b/internal/querylog/qlog.go @@ -2,7 +2,6 @@ package querylog import ( - "errors" "fmt" "net" "os" @@ -11,6 +10,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/filtering" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/miekg/dns" ) diff --git a/internal/querylog/qlogfile.go b/internal/querylog/qlogfile.go index cf8d5f8c..f0c7b7c4 100644 --- a/internal/querylog/qlogfile.go +++ b/internal/querylog/qlogfile.go @@ -8,15 +8,15 @@ import ( "sync" "time" - "github.com/AdguardTeam/AdGuardHome/internal/agherr" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" ) // Timestamp not found errors. const ( - ErrTSNotFound agherr.Error = "ts not found" - ErrTSTooLate agherr.Error = "ts too late" - ErrTSTooEarly agherr.Error = "ts too early" + ErrTSNotFound errors.Error = "ts not found" + ErrTSTooLate errors.Error = "ts too late" + ErrTSTooEarly errors.Error = "ts too early" ) // TODO: Find a way to grow buffer instead of relying on this value when reading strings diff --git a/internal/querylog/qlogfile_test.go b/internal/querylog/qlogfile_test.go index 13711558..5e82ce72 100644 --- a/internal/querylog/qlogfile_test.go +++ b/internal/querylog/qlogfile_test.go @@ -2,7 +2,6 @@ package querylog import ( "encoding/binary" - "errors" "fmt" "io" "math" @@ -12,6 +11,7 @@ import ( "testing" "time" + "github.com/AdguardTeam/golibs/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/internal/querylog/qlogreader.go b/internal/querylog/qlogreader.go index 03a93a0f..7735eb67 100644 --- a/internal/querylog/qlogreader.go +++ b/internal/querylog/qlogreader.go @@ -1,12 +1,11 @@ package querylog import ( - "errors" "fmt" "io" "os" - "github.com/AdguardTeam/AdGuardHome/internal/agherr" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" ) @@ -154,7 +153,7 @@ func closeQFiles(qFiles []*QLogFile) error { } if len(errs) > 0 { - return agherr.Many("Error while closing QLogReader", errs...) + return errors.List("error while closing QLogReader", errs...) } return nil diff --git a/internal/querylog/querylog.go b/internal/querylog/querylog.go index eb22b74a..bd158fc2 100644 --- a/internal/querylog/querylog.go +++ b/internal/querylog/querylog.go @@ -6,8 +6,8 @@ import ( "path/filepath" "time" - "github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/AdguardTeam/AdGuardHome/internal/filtering" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/miekg/dns" ) @@ -78,13 +78,13 @@ type AddParams struct { func (p *AddParams) validate() (err error) { switch { case p.Question == nil: - return agherr.Error("question is nil") + return errors.Error("question is nil") case len(p.Question.Question) != 1: - return agherr.Error("more than one question") + return errors.Error("more than one question") case len(p.Question.Question[0].Name) == 0: - return agherr.Error("no host in question") + return errors.Error("no host in question") case p.ClientIP == nil: - return agherr.Error("no client ip") + return errors.Error("no client ip") default: return nil } diff --git a/internal/querylog/querylogfile.go b/internal/querylog/querylogfile.go index 7a170706..aac9a5f6 100644 --- a/internal/querylog/querylogfile.go +++ b/internal/querylog/querylogfile.go @@ -3,10 +3,10 @@ package querylog import ( "bytes" "encoding/json" - "errors" "os" "time" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" ) @@ -39,7 +39,7 @@ func (l *queryLog) flushLogBuffer(fullFlush bool) error { } // flushToFile saves the specified log entries to the query log file -func (l *queryLog) flushToFile(buffer []*logEntry) error { +func (l *queryLog) flushToFile(buffer []*logEntry) (err error) { if len(buffer) == 0 { log.Debug("querylog: there's nothing to write to a file") return nil @@ -49,9 +49,10 @@ func (l *queryLog) flushToFile(buffer []*logEntry) error { var b bytes.Buffer e := json.NewEncoder(&b) for _, entry := range buffer { - err := e.Encode(entry) + err = e.Encode(entry) if err != nil { log.Error("Failed to marshal entry: %s", err) + return err } } @@ -59,7 +60,6 @@ func (l *queryLog) flushToFile(buffer []*logEntry) error { elapsed := time.Since(start) log.Debug("%d elements serialized via json in %v: %d kB, %v/entry, %v/entry", len(buffer), elapsed, b.Len()/1024, float64(b.Len())/float64(len(buffer)), elapsed/time.Duration(len(buffer))) - var err error var zb bytes.Buffer filename := l.logFile zb = b @@ -71,7 +71,7 @@ func (l *queryLog) flushToFile(buffer []*logEntry) error { log.Error("failed to create file \"%s\": %s", filename, err) return err } - defer f.Close() + defer func() { err = errors.WithDeferred(err, f.Close()) }() n, err := f.Write(zb.Bytes()) if err != nil { @@ -109,7 +109,12 @@ func (l *queryLog) readFileFirstTimeValue() int64 { if err != nil { return -1 } - defer f.Close() + defer func() { + derr := f.Close() + if derr != nil { + log.Error("querylog: closing file: %s", derr) + } + }() buf := make([]byte, 500) r, err := f.Read(buf) diff --git a/internal/querylog/search.go b/internal/querylog/search.go index 820e4376..2181ab04 100644 --- a/internal/querylog/search.go +++ b/internal/querylog/search.go @@ -142,7 +142,12 @@ func (l *queryLog) searchFiles( return entries, oldest, 0 } - defer r.Close() + defer func() { + derr := r.Close() + if derr != nil { + log.Error("querylog: closing file: %s", err) + } + }() if params.olderThan.IsZero() { err = r.SeekStart() diff --git a/internal/stats/unit.go b/internal/stats/unit.go index b3a42767..968a2f66 100644 --- a/internal/stats/unit.go +++ b/internal/stats/unit.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/binary" "encoding/gob" - "errors" "fmt" "net" "os" @@ -12,7 +11,7 @@ import ( "sync" "time" - "github.com/AdguardTeam/AdGuardHome/internal/agherr" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" bolt "go.etcd.io/bbolt" ) @@ -93,7 +92,7 @@ func createObject(conf Config) (s *statsCtx, err error) { // TODO(a.garipov): See if this is actually necessary. Looks // like a rather bizarre solution. - errStop := agherr.Error("stop iteration") + errStop := errors.Error("stop iteration") forEachBkt := func(name []byte, _ *bolt.Bucket) (cberr error) { nameID := uint32(btoi(name)) if nameID < firstID { diff --git a/internal/updater/check.go b/internal/updater/check.go index 71c4ac47..edf046af 100644 --- a/internal/updater/check.go +++ b/internal/updater/check.go @@ -4,10 +4,12 @@ import ( "encoding/json" "fmt" "io" + "net/http" "strings" "time" "github.com/AdguardTeam/AdGuardHome/internal/aghio" + "github.com/AdguardTeam/golibs/errors" ) // TODO(a.garipov): Make configurable. @@ -27,7 +29,7 @@ const MaxResponseSize = 64 * 1024 // VersionInfo downloads the latest version information. If forceRecheck is // false and there are cached results, those results are returned. -func (u *Updater) VersionInfo(forceRecheck bool) (VersionInfo, error) { +func (u *Updater) VersionInfo(forceRecheck bool) (vi VersionInfo, err error) { u.mu.Lock() defer u.mu.Unlock() @@ -37,22 +39,23 @@ func (u *Updater) VersionInfo(forceRecheck bool) (VersionInfo, error) { return u.prevCheckResult, u.prevCheckError } + var resp *http.Response vcu := u.versionCheckURL - resp, err := u.client.Get(vcu) + resp, err = u.client.Get(vcu) if err != nil { return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %w", vcu, err) } - defer resp.Body.Close() + defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }() - resp.Body, err = aghio.LimitReadCloser(resp.Body, MaxResponseSize) + var r io.Reader + r, err = aghio.LimitReader(resp.Body, MaxResponseSize) if err != nil { return VersionInfo{}, fmt.Errorf("updater: LimitReadCloser: %w", err) } - defer resp.Body.Close() // This use of ReadAll is safe, because we just limited the appropriate // ReadCloser. - body, err := io.ReadAll(resp.Body) + body, err := io.ReadAll(r) if err != nil { return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %w", vcu, err) } diff --git a/internal/updater/updater.go b/internal/updater/updater.go index 1d62800a..47ae1397 100644 --- a/internal/updater/updater.go +++ b/internal/updater/updater.go @@ -5,7 +5,6 @@ import ( "archive/tar" "archive/zip" "compress/gzip" - "errors" "fmt" "io" "net/http" @@ -20,6 +19,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghio" "github.com/AdguardTeam/AdGuardHome/internal/version" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" ) @@ -283,22 +283,23 @@ func (u *Updater) clean() { const MaxPackageFileSize = 32 * 1024 * 1024 // Download package file and save it to disk -func (u *Updater) downloadPackageFile(url, filename string) error { - resp, err := u.client.Get(url) +func (u *Updater) downloadPackageFile(url, filename string) (err error) { + var resp *http.Response + resp, err = u.client.Get(url) if err != nil { return fmt.Errorf("http request failed: %w", err) } - defer resp.Body.Close() + defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }() - resp.Body, err = aghio.LimitReadCloser(resp.Body, MaxPackageFileSize) + var r io.Reader + r, err = aghio.LimitReader(resp.Body, MaxPackageFileSize) if err != nil { return fmt.Errorf("http request failed: %w", err) } - defer resp.Body.Close() log.Debug("updater: reading HTTP body") // This use of ReadAll is now safe, because we limited body's Reader. - body, err := io.ReadAll(resp.Body) + body, err := io.ReadAll(r) if err != nil { return fmt.Errorf("io.ReadAll() failed: %w", err) } @@ -313,172 +314,178 @@ func (u *Updater) downloadPackageFile(url, filename string) error { return nil } +func tarGzFileUnpackOne(outDir string, tr *tar.Reader, hdr *tar.Header) (name string, err error) { + name = filepath.Base(hdr.Name) + if name == "" { + return "", nil + } + + outputName := filepath.Join(outDir, name) + + if hdr.Typeflag == tar.TypeDir { + if name == "AdGuardHome" { + // Top-level AdGuardHome/. Skip it. + // + // TODO(a.garipov): This whole package needs to be + // rewritten and covered in more integration tests. It + // has weird assumptions and file mode issues. + return "", nil + } + + err = os.Mkdir(outputName, os.FileMode(hdr.Mode&0o777)) + if err != nil && !errors.Is(err, os.ErrExist) { + return "", fmt.Errorf("os.Mkdir(%q): %w", outputName, err) + } + + log.Debug("updater: created directory %q", outputName) + + return "", nil + } + + if hdr.Typeflag != tar.TypeReg { + log.Debug("updater: %s: unknown file type %d, skipping", name, hdr.Typeflag) + + return "", nil + } + + var wc io.WriteCloser + wc, err = os.OpenFile( + outputName, + os.O_WRONLY|os.O_CREATE|os.O_TRUNC, + os.FileMode(hdr.Mode&0o777), + ) + if err != nil { + return "", fmt.Errorf("os.OpenFile(%s): %w", outputName, err) + } + defer func() { err = errors.WithDeferred(err, wc.Close()) }() + + _, err = io.Copy(wc, tr) + if err != nil { + return "", fmt.Errorf("io.Copy(): %w", err) + } + + log.Tracef("updater: created file %s", outputName) + + return name, nil +} + // Unpack all files from .tar.gz file to the specified directory // Existing files are overwritten -// All files are created inside 'outdir', subdirectories are not created +// All files are created inside outDir, subdirectories are not created // Return the list of files (not directories) written -func tarGzFileUnpack(tarfile, outdir string) ([]string, error) { +func tarGzFileUnpack(tarfile, outDir string) (files []string, err error) { f, err := os.Open(tarfile) if err != nil { return nil, fmt.Errorf("os.Open(): %w", err) } - defer func() { - _ = f.Close() - }() + defer func() { err = errors.WithDeferred(err, f.Close()) }() gzReader, err := gzip.NewReader(f) if err != nil { return nil, fmt.Errorf("gzip.NewReader(): %w", err) } + defer func() { err = errors.WithDeferred(err, gzReader.Close()) }() - var files []string - var err2 error tarReader := tar.NewReader(gzReader) for { - var header *tar.Header - header, err = tarReader.Next() - if err == io.EOF { - err2 = nil + var hdr *tar.Header + hdr, err = tarReader.Next() + if errors.Is(err, io.EOF) { + err = nil + break - } - if err != nil { - err2 = fmt.Errorf("tarReader.Next(): %w", err) + } else if err != nil { + err = fmt.Errorf("tarReader.Next(): %w", err) + break } - _, inputNameOnly := filepath.Split(header.Name) - if inputNameOnly == "" { - continue + var name string + name, err = tarGzFileUnpackOne(outDir, tarReader, hdr) + + if name != "" { + files = append(files, name) } - - outputName := filepath.Join(outdir, inputNameOnly) - - if header.Typeflag == tar.TypeDir { - if inputNameOnly == "AdGuardHome" { - // Top-level AdGuardHome/. Skip it. - // - // TODO(a.garipov): This whole package needs to - // be rewritten and covered in more integration - // tests. It has weird assumptions and file - // mode issues. - continue - } - - err = os.Mkdir(outputName, os.FileMode(header.Mode&0o777)) - if err != nil && !errors.Is(err, os.ErrExist) { - err2 = fmt.Errorf("os.Mkdir(%q): %w", outputName, err) - - break - } - - log.Debug("updater: created directory %q", outputName) - - continue - } else if header.Typeflag != tar.TypeReg { - log.Debug("updater: %s: unknown file type %d, skipping", inputNameOnly, header.Typeflag) - continue - } - - var f io.WriteCloser - f, err = os.OpenFile(outputName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode&0o777)) - if err != nil { - err2 = fmt.Errorf("os.OpenFile(%s): %w", outputName, err) - break - } - _, err = io.Copy(f, tarReader) - if err != nil { - _ = f.Close() - err2 = fmt.Errorf("io.Copy(): %w", err) - break - } - err = f.Close() - if err != nil { - err2 = fmt.Errorf("f.Close(): %w", err) - break - } - - log.Debug("updater: created file %s", outputName) - files = append(files, header.Name) } - _ = gzReader.Close() - return files, err2 + return files, err +} + +func zipFileUnpackOne(outDir string, zf *zip.File) (name string, err error) { + var rc io.ReadCloser + rc, err = zf.Open() + if err != nil { + return "", fmt.Errorf("zip file Open(): %w", err) + } + defer func() { err = errors.WithDeferred(err, rc.Close()) }() + + fi := zf.FileInfo() + name = fi.Name() + if name == "" { + return "", nil + } + + outputName := filepath.Join(outDir, name) + if fi.IsDir() { + if name == "AdGuardHome" { + // Top-level AdGuardHome/. Skip it. + // + // TODO(a.garipov): See the similar todo in + // tarGzFileUnpack. + return "", nil + } + + err = os.Mkdir(outputName, fi.Mode()) + if err != nil && !errors.Is(err, os.ErrExist) { + return "", fmt.Errorf("os.Mkdir(%q): %w", outputName, err) + } + + log.Tracef("created directory %q", outputName) + + return "", nil + } + + var wc io.WriteCloser + wc, err = os.OpenFile(outputName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode()) + if err != nil { + return "", fmt.Errorf("os.OpenFile(): %w", err) + } + defer func() { err = errors.WithDeferred(err, wc.Close()) }() + + _, err = io.Copy(wc, rc) + if err != nil { + return "", fmt.Errorf("io.Copy(): %w", err) + } + + log.Tracef("created file %s", outputName) + + return name, nil } // Unpack all files from .zip file to the specified directory // Existing files are overwritten -// All files are created inside 'outdir', subdirectories are not created +// All files are created inside 'outDir', subdirectories are not created // Return the list of files (not directories) written -func zipFileUnpack(zipfile, outdir string) ([]string, error) { - r, err := zip.OpenReader(zipfile) +func zipFileUnpack(zipfile, outDir string) (files []string, err error) { + zrc, err := zip.OpenReader(zipfile) if err != nil { return nil, fmt.Errorf("zip.OpenReader(): %w", err) } - defer r.Close() + defer func() { err = errors.WithDeferred(err, zrc.Close()) }() - var files []string - var err2 error - var zr io.ReadCloser - for _, zf := range r.File { - zr, err = zf.Open() + for _, zf := range zrc.File { + var name string + name, err = zipFileUnpackOne(outDir, zf) if err != nil { - err2 = fmt.Errorf("zip file Open(): %w", err) break } - fi := zf.FileInfo() - inputNameOnly := fi.Name() - if inputNameOnly == "" { - continue + if name != "" { + files = append(files, name) } - - outputName := filepath.Join(outdir, inputNameOnly) - - if fi.IsDir() { - if inputNameOnly == "AdGuardHome" { - // Top-level AdGuardHome/. Skip it. - // - // TODO(a.garipov): See the similar todo in - // tarGzFileUnpack. - continue - } - - err = os.Mkdir(outputName, fi.Mode()) - if err != nil && !errors.Is(err, os.ErrExist) { - err2 = fmt.Errorf("os.Mkdir(%q): %w", outputName, err) - - break - } - - log.Tracef("created directory %q", outputName) - - continue - } - - var f io.WriteCloser - f, err = os.OpenFile(outputName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode()) - if err != nil { - err2 = fmt.Errorf("os.OpenFile(): %w", err) - break - } - _, err = io.Copy(f, zr) - if err != nil { - _ = f.Close() - err2 = fmt.Errorf("io.Copy(): %w", err) - break - } - err = f.Close() - if err != nil { - err2 = fmt.Errorf("f.Close(): %w", err) - break - } - - log.Tracef("created file %s", outputName) - files = append(files, inputNameOnly) } - _ = zr.Close() - return files, err2 + return files, err } // Copy file on disk diff --git a/scripts/make/go-lint.sh b/scripts/make/go-lint.sh index e0164612..40de04bc 100644 --- a/scripts/make/go-lint.sh +++ b/scripts/make/go-lint.sh @@ -75,10 +75,10 @@ esac # Simple Analyzers # blocklist_imports is a simple check against unwanted packages. Package -# io/ioutil is soft-deprecated. Package log is replaced by our own package -# github.com/AdguardTeam/golibs/log. +# io/ioutil is soft-deprecated. Packages errors and log are replaced by our own +# packages in the github.com/AdguardTeam/golibs module. blocklist_imports() { - git grep -F -e '"io/ioutil"' -e '"log"' -- '*.go' || exit 0; + git grep -F -e '"errors"' -e '"io/ioutil"' -e '"log"' -- '*.go' || exit 0; } # method_const is a simple check against the usage of some raw strings and @@ -192,13 +192,7 @@ nilness ./... exit_on_output shadow --strict ./... -# TODO(a.garipov): Enable errcheck fully after handling all errors, including -# the deferred and generated ones, properly. Also, perhaps, enable --blank. -# -# errcheck ./... -exit_on_output sh -c ' - errcheck --asserts --ignoregenerated ./... |\ - { grep -e "defer" -v || exit 0; } -' +# TODO(a.garipov): Enable --blank? +errcheck --asserts ./... staticcheck ./...