Pull request: replace agherr with golibs' errors

Merge in DNS/adguard-home from golibs-errors to master

Squashed commit of the following:

commit 5aba278a31c5a213bd9e08273ce7277c57713b22
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Mon May 24 17:05:18 2021 +0300

    all: imp code

commit f447eb875b81779fa9e391d98c31c1eeba7ef323
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Mon May 24 15:33:45 2021 +0300

    replace agherr with golibs' errors
This commit is contained in:
Ainar Garipov 2021-05-24 17:28:11 +03:00
parent 14250821ab
commit 03a828ef51
60 changed files with 406 additions and 672 deletions

View File

@ -66,8 +66,7 @@ on GitHub and most other Markdown renderers. -->
### <a id="code" href="#code">Code</a> ### <a id="code" href="#code">Code</a>
* Always `recover` from panics in new goroutines. Preferably in the very * Always `recover` from panics in new goroutines. Preferably in the very
first statement. If all you want there is a log message, use first statement. If all you want there is a log message, use `log.OnPanic`.
`agherr.LogPanic`.
* Avoid `fallthrough`. It makes it harder to rearrange `case`s, to reason * Avoid `fallthrough`. It makes it harder to rearrange `case`s, to reason
about the code, and also to switch the code to a handler approach, if that about the code, and also to switch the code to a handler approach, if that

2
go.mod
View File

@ -4,7 +4,7 @@ go 1.16
require ( require (
github.com/AdguardTeam/dnsproxy v0.37.4 github.com/AdguardTeam/dnsproxy v0.37.4
github.com/AdguardTeam/golibs v0.5.0 github.com/AdguardTeam/golibs v0.8.0
github.com/AdguardTeam/urlfilter v0.14.5 github.com/AdguardTeam/urlfilter v0.14.5
github.com/NYTimes/gziphandler v1.1.1 github.com/NYTimes/gziphandler v1.1.1
github.com/ameshkov/dnscrypt/v2 v2.1.3 github.com/ameshkov/dnscrypt/v2 v2.1.3

4
go.sum
View File

@ -14,8 +14,8 @@ github.com/AdguardTeam/dnsproxy v0.37.4/go.mod h1:xkJWEuTr550gPDmB9azsciKZzSXjf9
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.4.4/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.4/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.5.0 h1:qwhEKjDrT0UcwDnHrNU2Yg/DLR9b/GsUncnXYW6VzAU= github.com/AdguardTeam/golibs v0.8.0 h1:rHo+yIgT2fivFG0yW2Cwk/DPc2+t/Aw6QvzPpiIFre0=
github.com/AdguardTeam/golibs v0.5.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.8.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU= github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU=
github.com/AdguardTeam/urlfilter v0.14.5 h1:WyF0hg0MwKevsqNPkoaZFH8f5WRi/yuy/7qePtYt5Ts= github.com/AdguardTeam/urlfilter v0.14.5 h1:WyF0hg0MwKevsqNPkoaZFH8f5WRi/yuy/7qePtYt5Ts=
github.com/AdguardTeam/urlfilter v0.14.5/go.mod h1:klx4JbOfc4EaNb5lWLqOwfg+pVcyRukmoJRvO55lL5U= github.com/AdguardTeam/urlfilter v0.14.5/go.mod h1:klx4JbOfc4EaNb5lWLqOwfg+pVcyRukmoJRvO55lL5U=

View File

@ -1,128 +0,0 @@
// Package agherr contains AdGuard Home's error handling helpers.
package agherr
import (
"fmt"
"strings"
"github.com/AdguardTeam/golibs/log"
)
// Error is the constant error type.
type Error string
// Error implements the error interface for Error.
func (err Error) Error() (msg string) {
return string(err)
}
// manyError is an error containing several wrapped errors. It is created to be
// a simpler version of the API provided by github.com/joomcode/errorx.
type manyError struct {
message string
underlying []error
}
// Many wraps several errors and returns a single error.
//
// TODO(a.garipov): Add formatting to message.
func Many(message string, underlying ...error) (err error) {
err = &manyError{
message: message,
underlying: underlying,
}
return err
}
// Error implements the error interface for *manyError.
func (e *manyError) Error() (msg string) {
switch len(e.underlying) {
case 0:
return e.message
case 1:
return fmt.Sprintf("%s: %s", e.message, e.underlying[0])
default:
b := &strings.Builder{}
// Ignore errors, since strings.(*Buffer).Write never returns
// errors. We don't use aghstrings.WriteToBuilder here since
// this package should be importable for any other.
_, _ = fmt.Fprintf(b, "%s: %s (hidden: %s", e.message, e.underlying[0], e.underlying[1])
for _, u := range e.underlying[2:] {
// See comment above.
_, _ = fmt.Fprintf(b, ", %s", u)
}
// See comment above.
_, _ = b.WriteString(")")
return b.String()
}
}
// Unwrap implements the hidden errors.wrapper interface for *manyError.
func (e *manyError) Unwrap() (err error) {
if len(e.underlying) == 0 {
return nil
}
return e.underlying[0]
}
// wrapper is a copy of the hidden errors.wrapper interface for tests, linting,
// etc.
type wrapper interface {
Unwrap() error
}
// Annotate annotates the error with the message, unless the error is nil. This
// is a helper function to simplify code like this:
//
// func (f *foo) doStuff(s string) (err error) {
// defer func() {
// if err != nil {
// err = fmt.Errorf("bad foo string %q: %w", s, err)
// }
// }()
//
// // …
// }
//
// Instead, write:
//
// func (f *foo) doStuff(s string) (err error) {
// defer agherr.Annotate("bad foo string %q: %w", &err, s)
//
// // …
// }
//
// msg must contain the final ": %w" verb.
//
// TODO(a.garipov): Clearify the function usage.
func Annotate(msg string, errPtr *error, args ...interface{}) {
if errPtr == nil {
return
}
err := *errPtr
if err != nil {
args = append(args, err)
*errPtr = fmt.Errorf(msg, args...)
}
}
// LogPanic is a convinient helper function to log a panic in a goroutine. It
// should not be used where proper error handling is required.
func LogPanic(prefix string) {
if v := recover(); v != nil {
if prefix != "" {
log.Error("%s: recovered from panic: %v", prefix, v)
return
}
log.Error("recovered from panic: %v", v)
}
}

View File

@ -1,160 +0,0 @@
package agherr
import (
"bytes"
"errors"
"fmt"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestError_Error(t *testing.T) {
testCases := []struct {
err error
name string
want string
}{{
err: Many("a"),
name: "simple",
want: "a",
}, {
err: Many("a", errors.New("b")),
name: "wrapping",
want: "a: b",
}, {
err: Many("a", errors.New("b"), errors.New("c"), errors.New("d")),
name: "wrapping several",
want: "a: b (hidden: c, d)",
}, {
err: Many("a", Many("b", errors.New("c"), errors.New("d"))),
name: "wrapping wrapper",
want: "a: b: c (hidden: d)",
}}
for _, tc := range testCases {
assert.Equal(t, tc.want, tc.err.Error(), tc.name)
}
}
func TestError_Unwrap(t *testing.T) {
var _ wrapper = &manyError{}
const (
errSimple = iota
errWrapped
errNil
)
errs := []error{
errSimple: errors.New("a"),
errWrapped: fmt.Errorf("err: %w", errors.New("nested")),
errNil: nil,
}
testCases := []struct {
want error
wrapped error
name string
}{{
want: errs[errSimple],
wrapped: Many("a", errs[errSimple]),
name: "simple",
}, {
want: errs[errWrapped],
wrapped: Many("b", errs[errWrapped]),
name: "nested",
}, {
want: errs[errNil],
wrapped: Many("c", errs[errNil]),
name: "nil passed",
}, {
want: nil,
wrapped: Many("d"),
name: "nil not passed",
}}
for _, tc := range testCases {
assert.Equal(t, tc.want, errors.Unwrap(tc.wrapped), tc.name)
}
}
func TestAnnotate(t *testing.T) {
const s = "1234"
const wantMsg = `bad string "1234": test`
// Don't use const, because we can't take a pointer of a constant.
var errTest error = Error("test")
t.Run("nil", func(t *testing.T) {
var errPtr *error
assert.NotPanics(t, func() {
Annotate("bad string %q: %w", errPtr, s)
})
})
t.Run("non_nil", func(t *testing.T) {
errPtr := &errTest
assert.NotPanics(t, func() {
Annotate("bad string %q: %w", errPtr, s)
})
require.NotNil(t, errPtr)
err := *errPtr
require.Error(t, err)
assert.Equal(t, wantMsg, err.Error())
})
t.Run("defer", func(t *testing.T) {
f := func() (err error) {
defer Annotate("bad string %q: %w", &errTest, s)
return errTest
}
err := f()
require.Error(t, err)
assert.Equal(t, wantMsg, err.Error())
})
}
func TestLogPanic(t *testing.T) {
buf := &bytes.Buffer{}
aghtest.ReplaceLogWriter(t, buf)
t.Run("prefix", func(t *testing.T) {
const (
panicMsg = "spooky!"
prefix = "packagename"
errWithNoPrefix = "[error] recovered from panic: spooky!"
errWithPrefix = "[error] packagename: recovered from panic: spooky!"
)
panicFunc := func(prefix string) {
defer LogPanic(prefix)
panic(panicMsg)
}
panicFunc("")
assert.Contains(t, buf.String(), errWithNoPrefix)
buf.Reset()
panicFunc(prefix)
assert.Contains(t, buf.String(), errWithPrefix)
buf.Reset()
})
t.Run("don't_panic", func(t *testing.T) {
require.NotPanics(t, func() {
defer LogPanic("")
})
assert.Empty(t, buf.String())
})
}

View File

@ -1,59 +0,0 @@
// Package aghio contains extensions for io package's types and methods
package aghio
import (
"fmt"
"io"
)
// LimitReachedError records the limit and the operation that caused it.
type LimitReachedError struct {
Limit int64
}
// Error implements error interface for LimitReachedError.
// TODO(a.garipov): Think about error string format.
func (lre *LimitReachedError) Error() string {
return fmt.Sprintf("attempted to read more than %d bytes", lre.Limit)
}
// limitedReadCloser is a wrapper for io.ReadCloser with limited reader and
// dealing with agherr package.
type limitedReadCloser struct {
limit int64
n int64
rc io.ReadCloser
}
// Read implements Reader interface.
func (lrc *limitedReadCloser) Read(p []byte) (n int, err error) {
if lrc.n == 0 {
return 0, &LimitReachedError{
Limit: lrc.limit,
}
}
if int64(len(p)) > lrc.n {
p = p[0:lrc.n]
}
n, err = lrc.rc.Read(p)
lrc.n -= int64(n)
return n, err
}
// Close implements Closer interface.
func (lrc *limitedReadCloser) Close() error {
return lrc.rc.Close()
}
// LimitReadCloser wraps ReadCloser to make it's Reader stop with
// ErrLimitReached after n bytes read.
func LimitReadCloser(rc io.ReadCloser, n int64) (limited io.ReadCloser, err error) {
if n < 0 {
return nil, fmt.Errorf("aghio: invalid n in LimitReadCloser: %d", n)
}
return &limitedReadCloser{
limit: n,
n: n,
rc: rc,
}, nil
}

View File

@ -0,0 +1,59 @@
// Package aghio contains extensions for io package's types and methods
package aghio
import (
"fmt"
"io"
)
// LimitReachedError records the limit and the operation that caused it.
type LimitReachedError struct {
Limit int64
}
// Error implements error interface for LimitReachedError.
//
// TODO(a.garipov): Think about error string format.
func (lre *LimitReachedError) Error() string {
return fmt.Sprintf("attempted to read more than %d bytes", lre.Limit)
}
// limitedReader is a wrapper for io.Reader with limited reader and dealing with
// errors package.
type limitedReader struct {
r io.Reader
limit int64
n int64
}
// Read implements Reader interface.
func (lr *limitedReader) Read(p []byte) (n int, err error) {
if lr.n == 0 {
return 0, &LimitReachedError{
Limit: lr.limit,
}
}
if int64(len(p)) > lr.n {
p = p[0:lr.n]
}
n, err = lr.r.Read(p)
lr.n -= int64(n)
return n, err
}
// LimitReader wraps Reader to make it's Reader stop with ErrLimitReached after
// n bytes read.
func LimitReader(r io.Reader, n int64) (limited io.Reader, err error) {
if n < 0 {
return nil, fmt.Errorf("aghio: invalid n in LimitReader: %d", n)
}
return &limitedReader{
r: r,
limit: n,
n: n,
}, nil
}

View File

@ -10,7 +10,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestLimitReadCloser(t *testing.T) { func TestLimitReader(t *testing.T) {
testCases := []struct { testCases := []struct {
want error want error
name string name string
@ -24,20 +24,20 @@ func TestLimitReadCloser(t *testing.T) {
name: "zero", name: "zero",
n: 0, n: 0,
}, { }, {
want: fmt.Errorf("aghio: invalid n in LimitReadCloser: -1"), want: fmt.Errorf("aghio: invalid n in LimitReader: -1"),
name: "negative", name: "negative",
n: -1, n: -1,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
_, err := LimitReadCloser(nil, tc.n) _, err := LimitReader(nil, tc.n)
assert.Equal(t, tc.want, err) assert.Equal(t, tc.want, err)
}) })
} }
} }
func TestLimitedReadCloser_Read(t *testing.T) { func TestLimitedReader_Read(t *testing.T) {
testCases := []struct { testCases := []struct {
err error err error
name string name string
@ -77,7 +77,7 @@ func TestLimitedReadCloser_Read(t *testing.T) {
readCloser := io.NopCloser(strings.NewReader(tc.rStr)) readCloser := io.NopCloser(strings.NewReader(tc.rStr))
buf := make([]byte, tc.limit+1) buf := make([]byte, tc.limit+1)
lreader, err := LimitReadCloser(readCloser, tc.limit) lreader, err := LimitReader(readCloser, tc.limit)
require.NoError(t, err) require.NoError(t, err)
n, err := lreader.Read(buf) n, err := lreader.Read(buf)
@ -87,7 +87,7 @@ func TestLimitedReadCloser_Read(t *testing.T) {
} }
} }
func TestLimitedReadCloser_LimitReachedError(t *testing.T) { func TestLimitedReader_LimitReachedError(t *testing.T) {
testCases := []struct { testCases := []struct {
err error err error
name string name string

View File

@ -6,7 +6,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/AdguardTeam/golibs/errors"
"golang.org/x/net/idna" "golang.org/x/net/idna"
) )
@ -26,11 +26,11 @@ func isValidHostRune(r rune) (ok bool) {
// ValidateHardwareAddress returns an error if hwa is not a valid EUI-48, // ValidateHardwareAddress returns an error if hwa is not a valid EUI-48,
// EUI-64, or 20-octet InfiniBand link-layer address. // EUI-64, or 20-octet InfiniBand link-layer address.
func ValidateHardwareAddress(hwa net.HardwareAddr) (err error) { func ValidateHardwareAddress(hwa net.HardwareAddr) (err error) {
defer agherr.Annotate("validating hardware address %q: %w", &err, hwa) defer func() { err = errors.Annotate(err, "validating hardware address %q: %w", hwa) }()
switch l := len(hwa); l { switch l := len(hwa); l {
case 0: case 0:
return agherr.Error("address is empty") return errors.Error("address is empty")
case 6, 8, 20: case 6, 8, 20:
return nil return nil
default: default:
@ -51,13 +51,13 @@ const maxDomainNameLen = 253
// ValidateDomainNameLabel returns an error if label is not a valid label of // ValidateDomainNameLabel returns an error if label is not a valid label of
// a domain name. // a domain name.
func ValidateDomainNameLabel(label string) (err error) { func ValidateDomainNameLabel(label string) (err error) {
defer agherr.Annotate("validating label %q: %w", &err, label) defer func() { err = errors.Annotate(err, "validating label %q: %w", label) }()
l := len(label) l := len(label)
if l > maxDomainLabelLen { if l > maxDomainLabelLen {
return fmt.Errorf("label is too long, max: %d", maxDomainLabelLen) return fmt.Errorf("label is too long, max: %d", maxDomainLabelLen)
} else if l == 0 { } else if l == 0 {
return agherr.Error("label is empty") return errors.Error("label is empty")
} }
if r := label[0]; !IsValidHostOuterRune(rune(r)) { if r := label[0]; !IsValidHostOuterRune(rune(r)) {
@ -87,7 +87,7 @@ func ValidateDomainNameLabel(label string) (err error) {
// TODO(a.garipov): After making sure that this works correctly, port this into // TODO(a.garipov): After making sure that this works correctly, port this into
// module golibs. // module golibs.
func ValidateDomainName(name string) (err error) { func ValidateDomainName(name string) (err error) {
defer agherr.Annotate("validating domain name %q: %w", &err, name) defer func() { err = errors.Annotate(err, "validating domain name %q: %w", name) }()
name, err = idna.ToASCII(name) name, err = idna.ToASCII(name)
if err != nil { if err != nil {
@ -96,7 +96,7 @@ func ValidateDomainName(name string) (err error) {
l := len(name) l := len(name)
if l == 0 { if l == 0 {
return agherr.Error("domain name is empty") return errors.Error("domain name is empty")
} else if l > maxDomainNameLen { } else if l > maxDomainNameLen {
return fmt.Errorf("too long, max: %d", maxDomainNameLen) return fmt.Errorf("too long, max: %d", maxDomainNameLen)
} }

View File

@ -2,7 +2,6 @@ package aghnet
import ( import (
"bufio" "bufio"
"errors"
"io" "io"
"net" "net"
"os" "os"
@ -12,6 +11,7 @@ import (
"sync" "sync"
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/fsnotify/fsnotify" "github.com/fsnotify/fsnotify"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -239,7 +239,13 @@ func (ehc *EtcHostsContainer) load(table map[string][]net.IP, tableRev map[strin
log.Error("etchostscontainer: %s", err) log.Error("etchostscontainer: %s", err)
return return
} }
defer f.Close() defer func() {
derr := f.Close()
if derr != nil {
log.Error("etchostscontainer: closing file: %s", err)
}
}()
r := bufio.NewReader(f) r := bufio.NewReader(f)
log.Debug("etchostscontainer: loading hosts from file %s", fn) log.Debug("etchostscontainer: loading hosts from file %s", fn)

View File

@ -3,7 +3,6 @@ package aghnet
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net" "net"
"os" "os"
@ -14,14 +13,14 @@ import (
"syscall" "syscall"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings" "github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
) )
// ErrNoStaticIPInfo is returned by IfaceHasStaticIP when no information about // ErrNoStaticIPInfo is returned by IfaceHasStaticIP when no information about
// the IP being static is available. // the IP being static is available.
const ErrNoStaticIPInfo agherr.Error = "no information about static ip" const ErrNoStaticIPInfo errors.Error = "no information about static ip"
// IfaceHasStaticIP checks if interface is configured to have static IP address. // IfaceHasStaticIP checks if interface is configured to have static IP address.
// If it can't give a definitive answer, it returns false and an error for which // If it can't give a definitive answer, it returns false and an error for which
@ -106,7 +105,7 @@ func GetValidNetInterfacesForWeb() ([]*NetInterface, error) {
return nil, fmt.Errorf("couldn't get interfaces: %w", err) return nil, fmt.Errorf("couldn't get interfaces: %w", err)
} }
if len(ifaces) == 0 { if len(ifaces) == 0 {
return nil, errors.New("couldn't find any legible interface") return nil, errors.Error("couldn't find any legible interface")
} }
var netInterfaces []*NetInterface var netInterfaces []*NetInterface

View File

@ -5,13 +5,13 @@
package aghnet package aghnet
import ( import (
"errors"
"fmt" "fmt"
"os" "os"
"regexp" "regexp"
"strings" "strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
) )
// hardwarePortInfo - information obtained using MacOS networksetup // hardwarePortInfo - information obtained using MacOS networksetup
@ -83,7 +83,7 @@ func getHardwarePortInfo(hardwarePort string) (hardwarePortInfo, error) {
match := re.FindStringSubmatch(out) match := re.FindStringSubmatch(out)
if len(match) == 0 { if len(match) == 0 {
return h, errors.New("could not find hardware port info") return h, errors.Error("could not find hardware port info")
} }
h.name = hardwarePort h.name = hardwarePort
@ -105,7 +105,7 @@ func ifaceSetStaticIP(ifaceName string) (err error) {
} }
if portInfo.static { if portInfo.static {
return errors.New("IP address is already static") return errors.Error("IP address is already static")
} }
dnsAddrs, err := getEtcResolvConfServers() dnsAddrs, err := getEtcResolvConfServers()
@ -151,7 +151,7 @@ func getEtcResolvConfServers() ([]string, error) {
matches := re.FindAllStringSubmatch(string(body), -1) matches := re.FindAllStringSubmatch(string(body), -1)
if len(matches) == 0 { if len(matches) == 0 {
return nil, errors.New("found no DNS servers in /etc/resolv.conf") return nil, errors.Error("found no DNS servers in /etc/resolv.conf")
} }
addrs := make([]string, 0) addrs := make([]string, 0)

View File

@ -6,7 +6,6 @@ package aghnet
import ( import (
"bufio" "bufio"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -14,6 +13,7 @@ import (
"strings" "strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghio" "github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/golibs/errors"
"github.com/google/renameio/maybe" "github.com/google/renameio/maybe"
) )
@ -49,16 +49,15 @@ func ifaceHasStaticIP(ifaceName string) (has bool, err error) {
return false, err return false, err
} }
defer f.Close() defer func() { err = errors.WithDeferred(err, f.Close()) }()
var fileReadCloser io.ReadCloser var fileReader io.Reader
fileReadCloser, err = aghio.LimitReadCloser(f, maxConfigFileSize) fileReader, err = aghio.LimitReader(f, maxConfigFileSize)
if err != nil { if err != nil {
return false, err return false, err
} }
defer fileReadCloser.Close()
has, err = check.checker(fileReadCloser, ifaceName) has, err = check.checker(fileReader, ifaceName)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -134,7 +133,7 @@ func ifacesStaticConfig(r io.Reader, ifaceName string) (has bool, err error) {
func ifaceSetStaticIP(ifaceName string) (err error) { func ifaceSetStaticIP(ifaceName string) (err error) {
ipNet := GetSubnet(ifaceName) ipNet := GetSubnet(ifaceName)
if ipNet.IP == nil { if ipNet.IP == nil {
return errors.New("can't get IP address") return errors.Error("can't get IP address")
} }
gatewayIP := GatewayIP(ifaceName) gatewayIP := GatewayIP(ifaceName)

View File

@ -3,7 +3,7 @@ package aghnet
import ( import (
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
) )
@ -27,19 +27,19 @@ type SystemResolvers interface {
const ( const (
// errBadAddrPassed is returned when dialFunc can't parse an IP address. // errBadAddrPassed is returned when dialFunc can't parse an IP address.
errBadAddrPassed agherr.Error = "the passed string is not a valid IP address" errBadAddrPassed errors.Error = "the passed string is not a valid IP address"
// errFakeDial is an error which dialFunc is expected to return. // errFakeDial is an error which dialFunc is expected to return.
errFakeDial agherr.Error = "this error signals the successful dialFunc work" errFakeDial errors.Error = "this error signals the successful dialFunc work"
// errUnexpectedHostFormat is returned by validateDialedHost when the host has // errUnexpectedHostFormat is returned by validateDialedHost when the host has
// more than one percent sign. // more than one percent sign.
errUnexpectedHostFormat agherr.Error = "unexpected host format" errUnexpectedHostFormat errors.Error = "unexpected host format"
) )
// refreshWithTicker refreshes the cache of sr after each tick form tickCh. // refreshWithTicker refreshes the cache of sr after each tick form tickCh.
func refreshWithTicker(sr SystemResolvers, tickCh <-chan time.Time) { func refreshWithTicker(sr SystemResolvers, tickCh <-chan time.Time) {
defer agherr.LogPanic("systemResolvers") defer log.OnPanic("systemResolvers")
// TODO(e.burkov): Implement a functionality to stop ticker. // TODO(e.burkov): Implement a functionality to stop ticker.
for range tickCh { for range tickCh {

View File

@ -6,15 +6,14 @@ package aghnet
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings" "github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
"github.com/AdguardTeam/golibs/errors"
) )
// defaultHostGen is the default method of generating host for Refresh. // defaultHostGen is the default method of generating host for Refresh.
@ -34,7 +33,7 @@ type systemResolvers struct {
} }
func (sr *systemResolvers) refresh() (err error) { func (sr *systemResolvers) refresh() (err error) {
defer agherr.Annotate("systemResolvers: %w", &err) defer func() { err = errors.Annotate(err, "systemResolvers: %w") }()
_, err = sr.resolver.LookupHost(context.Background(), sr.hostGenFunc()) _, err = sr.resolver.LookupHost(context.Background(), sr.hostGenFunc())
dnserr := &net.DNSError{} dnserr := &net.DNSError{}
@ -63,7 +62,7 @@ func newSystemResolvers(refreshIvl time.Duration, hostGenFunc HostGenFunc) (sr S
// validateDialedHost validated the host used by resolvers in dialFunc. // validateDialedHost validated the host used by resolvers in dialFunc.
func validateDialedHost(host string) (err error) { func validateDialedHost(host string) (err error) {
defer agherr.Annotate("parsing %q: %w", &err, host) defer func() { err = errors.Annotate(err, "parsing %q: %w", host) }()
var ipStr string var ipStr string
parts := strings.Split(host, "%") parts := strings.Split(host, "%")

View File

@ -14,9 +14,9 @@ import (
"sync" "sync"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/aghio" "github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
) )
@ -65,14 +65,15 @@ func (sr *systemResolvers) getAddrs() (addrs []string, err error) {
return nil, fmt.Errorf("getting the command's stdout pipe: %w", err) return nil, fmt.Errorf("getting the command's stdout pipe: %w", err)
} }
var stdoutLimited io.ReadCloser var stdoutLimited io.Reader
stdoutLimited, err = aghio.LimitReadCloser(stdout, aghos.MaxCmdOutputSize) stdoutLimited, err = aghio.LimitReader(stdout, aghos.MaxCmdOutputSize)
if err != nil { if err != nil {
return nil, fmt.Errorf("limiting stdout reader: %w", err) return nil, fmt.Errorf("limiting stdout reader: %w", err)
} }
go func() { go func() {
defer agherr.LogPanic("systemResolvers") defer log.OnPanic("systemResolvers")
defer func() { defer func() {
derr := stdin.Close() derr := stdin.Close()
if derr != nil { if derr != nil {
@ -141,7 +142,7 @@ func (sr *systemResolvers) getAddrs() (addrs []string, err error) {
} }
func (sr *systemResolvers) refresh() (err error) { func (sr *systemResolvers) refresh() (err error) {
defer agherr.Annotate("systemResolvers: %w", &err) defer func() { err = errors.Annotate(err, "systemResolvers: %w") }()
got, err := sr.getAddrs() got, err := sr.getAddrs()
if err != nil { if err != nil {

View File

@ -30,7 +30,7 @@ func ReplaceLogWriter(t *testing.T, w io.Writer) {
// ReplaceLogLevel sets logging level to l and uses Cleanup method of t to // ReplaceLogLevel sets logging level to l and uses Cleanup method of t to
// revert changes. // revert changes.
func ReplaceLogLevel(t *testing.T, l int) { func ReplaceLogLevel(t *testing.T, l log.Level) {
switch l { switch l {
case log.INFO, log.DEBUG, log.ERROR: case log.INFO, log.DEBUG, log.ERROR:
// Go on. // Go on.

View File

@ -168,9 +168,6 @@ type TestErrUpstream struct {
// Exchange always returns nil Msg and non-nil error. // Exchange always returns nil Msg and non-nil error.
func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) { func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) {
// We don't use an agherr.Error to avoid the import cycle since aghtests
// used to provide the utilities for testing which agherr (and any other
// testable package) should be able to use.
return nil, fmt.Errorf("errupstream: %w", u.Err) return nil, fmt.Errorf("errupstream: %w", u.Err)
} }

View File

@ -12,6 +12,7 @@ import (
"runtime" "runtime"
"time" "time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/insomniacslk/dhcp/dhcpv4" "github.com/insomniacslk/dhcp/dhcpv4"
"github.com/insomniacslk/dhcp/dhcpv4/nclient4" "github.com/insomniacslk/dhcp/dhcpv4/nclient4"
@ -78,7 +79,7 @@ func CheckIfOtherDHCPServersPresentV4(ifaceName string) (ok bool, err error) {
return false, fmt.Errorf("couldn't listen on :68: %w", err) return false, fmt.Errorf("couldn't listen on :68: %w", err)
} }
if c != nil { if c != nil {
defer c.Close() defer func() { err = errors.WithDeferred(err, c.Close()) }()
} }
// send to 255.255.255.255:67 // send to 255.255.255.255:67
@ -202,7 +203,7 @@ func CheckIfOtherDHCPServersPresentV6(ifaceName string) (ok bool, err error) {
return false, fmt.Errorf("dhcpv6: Couldn't listen on :546: %w", err) return false, fmt.Errorf("dhcpv6: Couldn't listen on :546: %w", err)
} }
if c != nil { if c != nil {
defer c.Close() defer func() { err = errors.WithDeferred(err, c.Close()) }()
} }
_, err = c.WriteTo(req.ToBytes(), dstAddr) _, err = c.WriteTo(req.ToBytes(), dstAddr)

View File

@ -4,11 +4,11 @@ package dhcpd
import ( import (
"encoding/json" "encoding/json"
"errors"
"net" "net"
"os" "os"
"time" "time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/google/renameio/maybe" "github.com/google/renameio/maybe"
) )

View File

@ -2,7 +2,6 @@ package dhcpd
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -11,6 +10,7 @@ import (
"strings" "strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
) )

View File

@ -6,7 +6,7 @@ import (
"math/big" "math/big"
"net" "net"
"github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/AdguardTeam/golibs/errors"
) )
// ipRange is an inclusive range of IP addresses. A nil range is a range that // ipRange is an inclusive range of IP addresses. A nil range is a range that
@ -28,7 +28,7 @@ const maxRangeLen = math.MaxUint32
// newIPRange creates a new IP address range. start must be less than end. The // newIPRange creates a new IP address range. start must be less than end. The
// resulting range must not be greater than maxRangeLen. // resulting range must not be greater than maxRangeLen.
func newIPRange(start, end net.IP) (r *ipRange, err error) { func newIPRange(start, end net.IP) (r *ipRange, err error) {
defer agherr.Annotate("invalid ip range: %w", &err) defer func() { err = errors.Annotate(err, "invalid ip range: %w") }()
// Make sure that both are 16 bytes long to simplify handling in // Make sure that both are 16 bytes long to simplify handling in
// methods. // methods.

View File

@ -7,7 +7,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/AdguardTeam/golibs/errors"
) )
// hexDHCPOptionParserHandler parses a DHCP option as a hex-encoded string. // hexDHCPOptionParserHandler parses a DHCP option as a hex-encoded string.
@ -32,7 +32,7 @@ func hexDHCPOptionParserHandler(s string) (data []byte, err error) {
func ipDHCPOptionParserHandler(s string) (data []byte, err error) { func ipDHCPOptionParserHandler(s string) (data []byte, err error) {
ip := net.ParseIP(s) ip := net.ParseIP(s)
if ip == nil { if ip == nil {
return nil, agherr.Error("invalid ip") return nil, errors.Error("invalid ip")
} }
// Most DHCP options require IPv4, so do not put the 16-byte // Most DHCP options require IPv4, so do not put the 16-byte
@ -100,12 +100,12 @@ func newDHCPOptionParser() (p *dhcpOptionParser) {
// parse parses an option. See the handlers' documentation for more info. // parse parses an option. See the handlers' documentation for more info.
func (p *dhcpOptionParser) parse(s string) (code uint8, data []byte, err error) { func (p *dhcpOptionParser) parse(s string) (code uint8, data []byte, err error) {
defer agherr.Annotate("invalid option string %q: %w", &err, s) defer func() { err = errors.Annotate(err, "invalid option string %q: %w", s) }()
s = strings.TrimSpace(s) s = strings.TrimSpace(s)
parts := strings.SplitN(s, " ", 3) parts := strings.SplitN(s, " ", 3)
if len(parts) < 3 { if len(parts) < 3 {
return 0, nil, agherr.Error("need at least three fields") return 0, nil, errors.Error("need at least three fields")
} }
codeStr := parts[0] codeStr := parts[0]

View File

@ -1,13 +1,13 @@
package dhcpd package dhcpd
import ( import (
"errors"
"net" "net"
"github.com/AdguardTeam/golibs/errors"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
) )
// Create a socket for receiving broadcast packets // Create a socket for receiving broadcast packets
func newBroadcastPacketConn(bindAddr net.IP, port int, ifname string) (*ipv4.PacketConn, error) { func newBroadcastPacketConn(bindAddr net.IP, port int, ifname string) (*ipv4.PacketConn, error) {
return nil, errors.New("newBroadcastPacketConn(): not supported on Windows") return nil, errors.Error("newBroadcastPacketConn(): not supported on Windows")
} }

View File

@ -6,16 +6,15 @@ package dhcpd
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"net" "net"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings" "github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/go-ping/ping" "github.com/go-ping/ping"
"github.com/insomniacslk/dhcp/dhcpv4" "github.com/insomniacslk/dhcp/dhcpv4"
@ -55,7 +54,7 @@ func (s *v4Server) WriteDiskConfig6(c *V6ServerConf) {
// normalizeHostname normalizes a hostname sent by the client. If err is not // normalizeHostname normalizes a hostname sent by the client. If err is not
// nil, norm is an empty string. // nil, norm is an empty string.
func normalizeHostname(hostname string) (norm string, err error) { func normalizeHostname(hostname string) (norm string, err error) {
defer agherr.Annotate("normalizing %q: %w", &err, hostname) defer func() { err = errors.Annotate(err, "normalizing %q: %w", hostname) }()
if hostname == "" { if hostname == "" {
return "", nil return "", nil
@ -249,7 +248,7 @@ func (s *v4Server) rmDynamicLease(lease *Lease) (err error) {
if bytes.Equal(l.HWAddr, lease.HWAddr) { if bytes.Equal(l.HWAddr, lease.HWAddr) {
if l.IsStatic() { if l.IsStatic() {
return agherr.Error("static lease already exists") return errors.Error("static lease already exists")
} }
s.rmLeaseByIndex(i) s.rmLeaseByIndex(i)
@ -262,7 +261,7 @@ func (s *v4Server) rmDynamicLease(lease *Lease) (err error) {
if l.IP.Equal(lease.IP) { if l.IP.Equal(lease.IP) {
if l.IsStatic() { if l.IsStatic() {
return agherr.Error("static lease already exists") return errors.Error("static lease already exists")
} }
s.rmLeaseByIndex(i) s.rmLeaseByIndex(i)
@ -322,12 +321,12 @@ func (s *v4Server) rmLease(lease Lease) (err error) {
} }
} }
return agherr.Error("lease not found") return errors.Error("lease not found")
} }
// AddStaticLease adds a static lease. It is safe for concurrent use. // AddStaticLease adds a static lease. It is safe for concurrent use.
func (s *v4Server) AddStaticLease(l Lease) (err error) { func (s *v4Server) AddStaticLease(l Lease) (err error) {
defer agherr.Annotate("dhcpv4: adding static lease: %w", &err) defer func() { err = errors.Annotate(err, "dhcpv4: adding static lease: %w") }()
if ip4 := l.IP.To4(); ip4 == nil { if ip4 := l.IP.To4(); ip4 == nil {
return fmt.Errorf("invalid ip %q, only ipv4 is supported", l.IP) return fmt.Errorf("invalid ip %q, only ipv4 is supported", l.IP)
@ -397,7 +396,7 @@ func (s *v4Server) AddStaticLease(l Lease) (err error) {
// RemoveStaticLease removes a static lease. It is safe for concurrent use. // RemoveStaticLease removes a static lease. It is safe for concurrent use.
func (s *v4Server) RemoveStaticLease(l Lease) (err error) { func (s *v4Server) RemoveStaticLease(l Lease) (err error) {
defer agherr.Annotate("dhcpv4: %w", &err) defer func() { err = errors.Annotate(err, "dhcpv4: %w") }()
if len(l.IP) != 4 { if len(l.IP) != 4 {
return fmt.Errorf("invalid IP") return fmt.Errorf("invalid IP")
@ -937,7 +936,7 @@ func (s *v4Server) packetHandler(conn net.PacketConn, peer net.Addr, req *dhcpv4
// Start starts the IPv4 DHCP server. // Start starts the IPv4 DHCP server.
func (s *v4Server) Start() (err error) { func (s *v4Server) Start() (err error) {
defer agherr.Annotate("dhcpv4: %w", &err) defer func() { err = errors.Annotate(err, "dhcpv4: %w") }()
if !s.conf.Enabled { if !s.conf.Enabled {
return nil return nil

View File

@ -1,11 +1,10 @@
package dhcpd package dhcpd
import ( import (
"errors"
"net" "net"
"testing" "testing"
"github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/AdguardTeam/golibs/errors"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -25,7 +24,7 @@ func (iface *fakeIface) Addrs() (addrs []net.Addr, err error) {
} }
func TestIfaceIPAddrs(t *testing.T) { func TestIfaceIPAddrs(t *testing.T) {
const errTest agherr.Error = "test error" const errTest errors.Error = "test error"
ip4 := net.IP{1, 2, 3, 4} ip4 := net.IP{1, 2, 3, 4}
addr4 := &net.IPNet{IP: ip4} addr4 := &net.IPNet{IP: ip4}
@ -108,7 +107,7 @@ func (iface *waitingFakeIface) Addrs() (addrs []net.Addr, err error) {
} }
func TestIfaceDNSIPAddrs(t *testing.T) { func TestIfaceDNSIPAddrs(t *testing.T) {
const errTest agherr.Error = "test error" const errTest errors.Error = "test error"
ip4 := net.IP{1, 2, 3, 4} ip4 := net.IP{1, 2, 3, 4}
addr4 := &net.IPNet{IP: ip4} addr4 := &net.IPNet{IP: ip4}

View File

@ -11,8 +11,8 @@ import (
"sync" "sync"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/insomniacslk/dhcp/dhcpv6" "github.com/insomniacslk/dhcp/dhcpv6"
"github.com/insomniacslk/dhcp/dhcpv6/server6" "github.com/insomniacslk/dhcp/dhcpv6/server6"
@ -165,7 +165,7 @@ func (s *v6Server) rmDynamicLease(lease Lease) error {
// AddStaticLease adds a static lease. It is safe for concurrent use. // AddStaticLease adds a static lease. It is safe for concurrent use.
func (s *v6Server) AddStaticLease(l Lease) (err error) { func (s *v6Server) AddStaticLease(l Lease) (err error) {
defer agherr.Annotate("dhcpv6: %w", &err) defer func() { err = errors.Annotate(err, "dhcpv6: %w") }()
if len(l.IP) != 16 { if len(l.IP) != 16 {
return fmt.Errorf("invalid IP") return fmt.Errorf("invalid IP")
@ -194,7 +194,7 @@ func (s *v6Server) AddStaticLease(l Lease) (err error) {
// RemoveStaticLease removes a static lease. It is safe for concurrent use. // RemoveStaticLease removes a static lease. It is safe for concurrent use.
func (s *v6Server) RemoveStaticLease(l Lease) (err error) { func (s *v6Server) RemoveStaticLease(l Lease) (err error) {
defer agherr.Annotate("dhcpv6: %w", &err) defer func() { err = errors.Annotate(err, "dhcpv6: %w") }()
if len(l.IP) != 16 { if len(l.IP) != 16 {
return fmt.Errorf("invalid IP") return fmt.Errorf("invalid IP")
@ -585,7 +585,7 @@ func (s *v6Server) initRA(iface *net.Interface) error {
// Start starts the IPv6 DHCP server. // Start starts the IPv6 DHCP server.
func (s *v6Server) Start() (err error) { func (s *v6Server) Start() (err error) {
defer agherr.Annotate("dhcpv6: %w", &err) defer func() { err = errors.Annotate(err, "dhcpv6: %w") }()
if !s.conf.Enabled { if !s.conf.Enabled {
return nil return nil

View File

@ -2,13 +2,13 @@ package dnsforward
import ( import (
"crypto/tls" "crypto/tls"
"errors"
"fmt" "fmt"
"path" "path"
"strings" "strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/errors"
"github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
) )

View File

@ -3,7 +3,6 @@ package dnsforward
import ( import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"errors"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@ -14,6 +13,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/ameshkov/dnscrypt/v2" "github.com/ameshkov/dnscrypt/v2"
) )
@ -220,7 +220,7 @@ func (s *Server) createProxyConfig() (proxy.Config, error) {
// Validate proxy config // Validate proxy config
if proxyConfig.UpstreamConfig == nil || len(proxyConfig.UpstreamConfig.Upstreams) == 0 { if proxyConfig.UpstreamConfig == nil || len(proxyConfig.UpstreamConfig.Upstreams) == 0 {
return proxyConfig, errors.New("no default upstream servers configured") return proxyConfig, errors.Error("no default upstream servers configured")
} }
return proxyConfig, nil return proxyConfig, nil

View File

@ -2,7 +2,6 @@
package dnsforward package dnsforward
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@ -12,7 +11,6 @@ import (
"sync" "sync"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings" "github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
@ -21,6 +19,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -226,11 +225,11 @@ type RDNSExchanger interface {
const ( const (
// rDNSEmptyAnswerErr is returned by Exchange method when the answer // rDNSEmptyAnswerErr is returned by Exchange method when the answer
// section of respond is empty. // section of respond is empty.
rDNSEmptyAnswerErr agherr.Error = "the answer section is empty" rDNSEmptyAnswerErr errors.Error = "the answer section is empty"
// rDNSNotPTRErr is returned by Exchange method when the response is not // rDNSNotPTRErr is returned by Exchange method when the response is not
// of PTR type. // of PTR type.
rDNSNotPTRErr agherr.Error = "the response is not a ptr" rDNSNotPTRErr errors.Error = "the response is not a ptr"
) )
// Exchange implements the RDNSExchanger interface for *Server. // Exchange implements the RDNSExchanger interface for *Server.

View File

@ -17,13 +17,13 @@ import (
"testing" "testing"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -1198,7 +1198,7 @@ func TestServer_Exchange(t *testing.T) {
"2.1.168.192.in-addr.arpa.": {}, "2.1.168.192.in-addr.arpa.": {},
}, },
} }
upstreamErr := agherr.Error("upstream error") upstreamErr := errors.Error("upstream error")
errUpstream := &aghtest.TestErrUpstream{ errUpstream := &aghtest.TestErrUpstream{
Err: upstreamErr, Err: upstreamErr,
} }

View File

@ -4,9 +4,9 @@ import (
"fmt" "fmt"
"net" "net"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter/rules" "github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -83,7 +83,7 @@ func (s *Server) filterDNSRewrite(req *dns.Msg, res filtering.Result, d *proxy.D
resp := s.makeResponse(req) resp := s.makeResponse(req)
dnsrr := res.DNSRewriteResult dnsrr := res.DNSRewriteResult
if dnsrr == nil { if dnsrr == nil {
return agherr.Error("no dns rewrite rule content") return errors.Error("no dns rewrite rule content")
} }
resp.Rcode = dnsrr.RCode resp.Rcode = dnsrr.RCode
@ -94,7 +94,7 @@ func (s *Server) filterDNSRewrite(req *dns.Msg, res filtering.Result, d *proxy.D
} }
if dnsrr.Response == nil { if dnsrr.Response == nil {
return agherr.Error("no dns rewrite rule responses") return errors.Error("no dns rewrite rule responses")
} }
rr := req.Question[0].Qtype rr := req.Question[0].Qtype

View File

@ -8,11 +8,11 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings" "github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -399,7 +399,7 @@ func validateUpstream(u string) (bool, error) {
// separateUpstream returns the upstream without the specified domains. // separateUpstream returns the upstream without the specified domains.
// useDefault is true when a default upstream must be used. // useDefault is true when a default upstream must be used.
func separateUpstream(upstreamStr string) (upstream string, useDefault bool, err error) { func separateUpstream(upstreamStr string) (upstream string, useDefault bool, err error) {
defer agherr.Annotate("bad upstream for domain spec %q: %w", &err, upstreamStr) defer func() { err = errors.Annotate(err, "bad upstream for domain spec %q: %w", upstreamStr) }()
if !strings.HasPrefix(upstreamStr, "[/") { if !strings.HasPrefix(upstreamStr, "[/") {
return upstreamStr, true, nil return upstreamStr, true, nil
@ -407,7 +407,7 @@ func separateUpstream(upstreamStr string) (upstream string, useDefault bool, err
parts := strings.Split(upstreamStr[2:], "/]") parts := strings.Split(upstreamStr[2:], "/]")
if len(parts) != 2 { if len(parts) != 2 {
return "", false, agherr.Error("duplicated separator") return "", false, errors.Error("duplicated separator")
} }
domains := parts[0] domains := parts[0]

View File

@ -10,7 +10,7 @@ import (
"strings" "strings"
"sync" "sync"
"github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/digineo/go-ipset/v2" "github.com/digineo/go-ipset/v2"
"github.com/mdlayher/netlink" "github.com/mdlayher/netlink"
@ -83,7 +83,7 @@ func (c *ipsetCtx) ipsetProps(name string) (set ipsetProps, err error) {
} }
if res == nil || res.Family == nil { if res == nil || res.Family == nil {
return set, agherr.Error("empty response or no family data") return set, errors.Error("empty response or no family data")
} }
family := netfilter.ProtoFamily(res.Family.Value) family := netfilter.ProtoFamily(res.Family.Value)
@ -193,23 +193,23 @@ func (c *ipsetCtx) init(ipsetConfig []string) (err error) {
// Close closes the Linux Netfilter connections. // Close closes the Linux Netfilter connections.
func (c *ipsetCtx) Close() (err error) { func (c *ipsetCtx) Close() (err error) {
var errors []error var errs []error
if c.ipv4Conn != nil { if c.ipv4Conn != nil {
err = c.ipv4Conn.Close() err = c.ipv4Conn.Close()
if err != nil { if err != nil {
errors = append(errors, err) errs = append(errs, err)
} }
} }
if c.ipv6Conn != nil { if c.ipv6Conn != nil {
err = c.ipv6Conn.Close() err = c.ipv6Conn.Close()
if err != nil { if err != nil {
errors = append(errors, err) errs = append(errs, err)
} }
} }
if len(errors) != 0 { if len(errs) != 0 {
return agherr.Many("closing ipsets", errors...) return errors.List("closing ipsets", errs...)
} }
return nil return nil

View File

@ -73,22 +73,27 @@ func glGetTokenDate(file string) uint32 {
f, err := os.Open(file) f, err := os.Open(file)
if err != nil { if err != nil {
log.Error("os.Open: %s", err) log.Error("os.Open: %s", err)
return 0 return 0
} }
defer f.Close() defer func() {
derr := f.Close()
if derr != nil {
log.Error("glinet: closing file: %s", err)
}
}()
fileReadCloser, err := aghio.LimitReadCloser(f, MaxFileSize) fileReader, err := aghio.LimitReader(f, MaxFileSize)
if err != nil { if err != nil {
log.Error("creating limited reader: %s", err) log.Error("creating limited reader: %s", err)
return 0 return 0
} }
defer fileReadCloser.Close()
var dateToken uint32 var dateToken uint32
// This use of ReadAll is now safe, because we limited reader. // This use of ReadAll is now safe, because we limited reader.
bs, err := io.ReadAll(fileReadCloser) bs, err := io.ReadAll(fileReader)
if err != nil { if err != nil {
log.Error("reading token: %s", err) log.Error("reading token: %s", err)

View File

@ -11,7 +11,6 @@ import (
"sync" "sync"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings" "github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
@ -20,6 +19,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
) )
@ -438,11 +438,11 @@ func (clients *clientsContainer) FindRuntimeClient(ip string) (RuntimeClient, bo
func (clients *clientsContainer) check(c *Client) (err error) { func (clients *clientsContainer) check(c *Client) (err error) {
switch { switch {
case c == nil: case c == nil:
return agherr.Error("client is nil") return errors.Error("client is nil")
case c.Name == "": case c.Name == "":
return agherr.Error("invalid name") return errors.Error("invalid name")
case len(c.IDs) == 0: case len(c.IDs) == 0:
return agherr.Error("id required") return errors.Error("id required")
default: default:
// Go on. // Go on.
} }
@ -570,14 +570,14 @@ func (clients *clientsContainer) Update(name string, c *Client) (err error) {
prev, ok := clients.list[name] prev, ok := clients.list[name]
if !ok { if !ok {
return agherr.Error("client not found") return errors.Error("client not found")
} }
// First, check the name index. // First, check the name index.
if prev.Name != c.Name { if prev.Name != c.Name {
_, ok = clients.list[c.Name] _, ok = clients.list[c.Name]
if ok { if ok {
return agherr.Error("client already exists") return errors.Error("client already exists")
} }
} }

View File

@ -1,7 +1,6 @@
package home package home
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
"os" "os"
@ -14,6 +13,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/AdGuardHome/internal/version" "github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/google/renameio/maybe" "github.com/google/renameio/maybe"
yaml "gopkg.in/yaml.v2" yaml "gopkg.in/yaml.v2"

View File

@ -15,7 +15,6 @@ import (
"strings" "strings"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
) )
@ -271,7 +270,7 @@ func copyInstallSettings(dst, src *configuration) {
const shutdownTimeout = 5 * time.Second const shutdownTimeout = 5 * time.Second
func shutdownSrv(ctx context.Context, cancel context.CancelFunc, srv *http.Server) { func shutdownSrv(ctx context.Context, cancel context.CancelFunc, srv *http.Server) {
defer agherr.LogPanic("") defer log.OnPanic("")
if srv == nil { if srv == nil {
return return

View File

@ -3,7 +3,6 @@ package home
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"net/http" "net/http"
"os" "os"
"os/exec" "os/exec"
@ -14,6 +13,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/updater" "github.com/AdguardTeam/AdGuardHome/internal/updater"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
) )

View File

@ -8,12 +8,12 @@ import (
"path/filepath" "path/filepath"
"strconv" "strconv"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/ameshkov/dnscrypt/v2" "github.com/ameshkov/dnscrypt/v2"
yaml "gopkg.in/yaml.v2" yaml "gopkg.in/yaml.v2"
@ -207,14 +207,14 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
func newDNSCrypt(hosts []net.IP, tlsConf tlsConfigSettings) (dnscc dnsforward.DNSCryptConfig, err error) { func newDNSCrypt(hosts []net.IP, tlsConf tlsConfigSettings) (dnscc dnsforward.DNSCryptConfig, err error) {
if tlsConf.DNSCryptConfigFile == "" { if tlsConf.DNSCryptConfigFile == "" {
return dnscc, agherr.Error("no dnscrypt_config_file") return dnscc, errors.Error("no dnscrypt_config_file")
} }
f, err := os.Open(tlsConf.DNSCryptConfigFile) f, err := os.Open(tlsConf.DNSCryptConfigFile)
if err != nil { if err != nil {
return dnscc, fmt.Errorf("opening dnscrypt config: %w", err) return dnscc, fmt.Errorf("opening dnscrypt config: %w", err)
} }
defer f.Close() defer func() { err = errors.WithDeferred(err, f.Close()) }()
rc := &dnscrypt.ResolverConfig{} rc := &dnscrypt.ResolverConfig{}
err = yaml.NewDecoder(f).Decode(rc) err = yaml.NewDecoder(f).Decode(rc)

View File

@ -2,7 +2,6 @@ package home
import ( import (
"bufio" "bufio"
"errors"
"fmt" "fmt"
"hash/crc32" "hash/crc32"
"io" "io"
@ -17,6 +16,7 @@ import (
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
) )
@ -566,17 +566,18 @@ func (f *Filtering) updateIntl(filter *filter) (updated bool, err error) {
if err != nil { if err != nil {
return updated, fmt.Errorf("open file: %w", err) return updated, fmt.Errorf("open file: %w", err)
} }
defer func() { err = errors.WithDeferred(err, f.Close()) }()
defer f.Close()
reader = f reader = f
} else { } else {
var resp *http.Response var resp *http.Response
resp, err = Context.client.Get(filter.URL) resp, err = Context.client.Get(filter.URL)
if err != nil { if err != nil {
log.Printf("Couldn't request filter from URL %s, skipping: %s", filter.URL, err) log.Printf("Couldn't request filter from URL %s, skipping: %s", filter.URL, err)
return updated, err return updated, err
} }
defer resp.Body.Close() defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
log.Printf("Got status code %d from URL %s, skipping", resp.StatusCode, filter.URL) log.Printf("Got status code %d from URL %s, skipping", resp.StatusCode, filter.URL)
@ -634,7 +635,7 @@ func (f *Filtering) load(filter *filter) (err error) {
} else if err != nil { } else if err != nil {
return fmt.Errorf("opening filter file: %w", err) return fmt.Errorf("opening filter file: %w", err)
} }
defer file.Close() defer func() { err = errors.WithDeferred(err, file.Close()) }()
st, err := file.Stat() st, err := file.Stat()
if err != nil { if err != nil {

View File

@ -5,7 +5,6 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"errors"
"fmt" "fmt"
"io/fs" "io/fs"
"net" "net"
@ -21,7 +20,6 @@ import (
"syscall" "syscall"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
@ -31,6 +29,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/AdGuardHome/internal/updater" "github.com/AdguardTeam/AdGuardHome/internal/updater"
"github.com/AdguardTeam/AdGuardHome/internal/version" "github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"gopkg.in/natefinch/lumberjack.v2" "gopkg.in/natefinch/lumberjack.v2"
) )
@ -736,7 +735,7 @@ func customDialContext(ctx context.Context, network, addr string) (conn net.Conn
return conn, err return conn, err
} }
return nil, agherr.Many(fmt.Sprintf("couldn't dial to %s", addr), dialErrs...) return nil, errors.List(fmt.Sprintf("couldn't dial to %s", addr), dialErrs...)
} }
func getHTTPProxy(_ *http.Request) (*url.URL, error) { func getHTTPProxy(_ *http.Request) (*url.URL, error) {

View File

@ -1,6 +1,7 @@
package home package home
import ( import (
"io"
"net/http" "net/http"
"github.com/AdguardTeam/AdGuardHome/internal/aghio" "github.com/AdguardTeam/AdGuardHome/internal/aghio"
@ -58,14 +59,20 @@ func limitRequestBody(h http.Handler) (limited http.Handler) {
szLim = largerReqBodySzLim szLim = largerReqBodySzLim
} }
r.Body, err = aghio.LimitReadCloser(r.Body, szLim) var reader io.Reader
reader, err = aghio.LimitReader(r.Body, szLim)
if err != nil { if err != nil {
log.Error("limitRequestBody: %s", err) log.Error("limitRequestBody: %s", err)
return return
} }
h.ServeHTTP(w, r) // HTTP handlers aren't supposed to call r.Body.Close(), so just
// replace the body in a clone.
rr := r.Clone(r.Context())
rr.Body = io.NopCloser(reader)
h.ServeHTTP(w, rr)
}) })
} }

View File

@ -5,7 +5,6 @@ import (
"net" "net"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
@ -85,7 +84,7 @@ func (r *RDNS) Begin(ip net.IP) {
// workerLoop handles incoming IP addresses from ipChan and adds it into // workerLoop handles incoming IP addresses from ipChan and adds it into
// clients. // clients.
func (r *RDNS) workerLoop() { func (r *RDNS) workerLoop() {
defer agherr.LogPanic("rdns") defer log.OnPanic("rdns")
for ip := range r.ipCh { for ip := range r.ipCh {
host, err := r.exchanger.Exchange(ip) host, err := r.exchanger.Exchange(ip)

View File

@ -3,7 +3,6 @@ package home
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors"
"net" "net"
"sync" "sync"
"testing" "testing"
@ -13,6 +12,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -141,7 +141,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
}, },
} }
errUpstream := &aghtest.TestErrUpstream{ errUpstream := &aghtest.TestErrUpstream{
Err: errors.New("1234"), Err: errors.Error("1234"),
} }
testCases := []struct { testCases := []struct {

View File

@ -1,7 +1,6 @@
package home package home
import ( import (
"errors"
"fmt" "fmt"
"io/fs" "io/fs"
"os" "os"
@ -11,6 +10,7 @@ import (
"syscall" "syscall"
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/kardianos/service" "github.com/kardianos/service"
) )

View File

@ -10,7 +10,6 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"os" "os"
@ -21,6 +20,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"golang.org/x/sys/cpu" "golang.org/x/sys/cpu"
) )
@ -341,14 +341,14 @@ func verifyCertChain(data *tlsConfigStatus, certChain, serverName string) error
parsed, err := x509.ParseCertificate(cert.Bytes) parsed, err := x509.ParseCertificate(cert.Bytes)
if err != nil { if err != nil {
data.WarningValidation = fmt.Sprintf("Failed to parse certificate: %s", err) data.WarningValidation = fmt.Sprintf("Failed to parse certificate: %s", err)
return errors.New(data.WarningValidation) return errors.Error(data.WarningValidation)
} }
parsedCerts = append(parsedCerts, parsed) parsedCerts = append(parsedCerts, parsed)
} }
if len(parsedCerts) == 0 { if len(parsedCerts) == 0 {
data.WarningValidation = "You have specified an empty certificate" data.WarningValidation = "You have specified an empty certificate"
return errors.New(data.WarningValidation) return errors.Error(data.WarningValidation)
} }
data.ValidCert = true data.ValidCert = true
@ -415,14 +415,14 @@ func validatePkey(data *tlsConfigStatus, pkey string) error {
if key == nil { if key == nil {
data.WarningValidation = "No valid keys were found" data.WarningValidation = "No valid keys were found"
return errors.New(data.WarningValidation) return errors.Error(data.WarningValidation)
} }
// parse the decoded key // parse the decoded key
_, keytype, err := parsePrivateKey(key.Bytes) _, keytype, err := parsePrivateKey(key.Bytes)
if err != nil { if err != nil {
data.WarningValidation = fmt.Sprintf("Failed to parse private key: %s", err) data.WarningValidation = fmt.Sprintf("Failed to parse private key: %s", err)
return errors.New(data.WarningValidation) return errors.Error(data.WarningValidation)
} }
data.ValidKey = true data.ValidKey = true
@ -479,7 +479,7 @@ func parsePrivateKey(der []byte) (crypto.PrivateKey, string, error) {
case *ecdsa.PrivateKey: case *ecdsa.PrivateKey:
return key, "ECDSA", nil return key, "ECDSA", nil
default: default:
return nil, "", errors.New("tls: found unknown private key type in PKCS#8 wrapping") return nil, "", errors.Error("tls: found unknown private key type in PKCS#8 wrapping")
} }
} }
@ -487,7 +487,7 @@ func parsePrivateKey(der []byte) (crypto.PrivateKey, string, error) {
return key, "ECDSA", nil return key, "ECDSA", nil
} }
return nil, "", errors.New("tls: failed to parse private key") return nil, "", errors.Error("tls: failed to parse private key")
} }
// unmarshalTLS handles base64-encoded certificates transparently // unmarshalTLS handles base64-encoded certificates transparently

View File

@ -1,7 +1,6 @@
package home package home
import ( import (
"errors"
"fmt" "fmt"
"net/url" "net/url"
"os" "os"
@ -12,6 +11,7 @@ import (
"strings" "strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/google/renameio/maybe" "github.com/google/renameio/maybe"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"

View File

@ -11,6 +11,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghio" "github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
) )
@ -139,22 +140,22 @@ func whoisParse(data string) (m strmap) {
const MaxConnReadSize = 64 * 1024 const MaxConnReadSize = 64 * 1024
// Send request to a server and receive the response // Send request to a server and receive the response
func (w *Whois) query(ctx context.Context, target, serverAddr string) (string, error) { func (w *Whois) query(ctx context.Context, target, serverAddr string) (data string, err error) {
addr, _, _ := net.SplitHostPort(serverAddr) addr, _, _ := net.SplitHostPort(serverAddr)
if addr == "whois.arin.net" { if addr == "whois.arin.net" {
target = "n + " + target target = "n + " + target
} }
conn, err := w.dialContext(ctx, "tcp", serverAddr) conn, err := w.dialContext(ctx, "tcp", serverAddr)
if err != nil { if err != nil {
return "", err return "", err
} }
defer conn.Close() defer func() { err = errors.WithDeferred(err, conn.Close()) }()
connReadCloser, err := aghio.LimitReadCloser(conn, MaxConnReadSize) r, err := aghio.LimitReader(conn, MaxConnReadSize)
if err != nil { if err != nil {
return "", err return "", err
} }
defer connReadCloser.Close()
_ = conn.SetReadDeadline(time.Now().Add(time.Duration(w.timeoutMsec) * time.Millisecond)) _ = conn.SetReadDeadline(time.Now().Add(time.Duration(w.timeoutMsec) * time.Millisecond))
_, err = conn.Write([]byte(target + "\r\n")) _, err = conn.Write([]byte(target + "\r\n"))
@ -163,12 +164,13 @@ func (w *Whois) query(ctx context.Context, target, serverAddr string) (string, e
} }
// This use of ReadAll is now safe, because we limited the conn Reader. // This use of ReadAll is now safe, because we limited the conn Reader.
data, err := io.ReadAll(connReadCloser) var whoisData []byte
whoisData, err = io.ReadAll(r)
if err != nil { if err != nil {
return "", err return "", err
} }
return string(data), nil return string(whoisData), nil
} }
// Query WHOIS servers (handle redirects) // Query WHOIS servers (handle redirects)

View File

@ -2,7 +2,6 @@
package querylog package querylog
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
"os" "os"
@ -11,6 +10,7 @@ import (
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns" "github.com/miekg/dns"
) )

View File

@ -8,15 +8,15 @@ import (
"sync" "sync"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
) )
// Timestamp not found errors. // Timestamp not found errors.
const ( const (
ErrTSNotFound agherr.Error = "ts not found" ErrTSNotFound errors.Error = "ts not found"
ErrTSTooLate agherr.Error = "ts too late" ErrTSTooLate errors.Error = "ts too late"
ErrTSTooEarly agherr.Error = "ts too early" ErrTSTooEarly errors.Error = "ts too early"
) )
// TODO: Find a way to grow buffer instead of relying on this value when reading strings // TODO: Find a way to grow buffer instead of relying on this value when reading strings

View File

@ -2,7 +2,6 @@ package querylog
import ( import (
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"io" "io"
"math" "math"
@ -12,6 +11,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/AdguardTeam/golibs/errors"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )

View File

@ -1,12 +1,11 @@
package querylog package querylog
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
"os" "os"
"github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
) )
@ -154,7 +153,7 @@ func closeQFiles(qFiles []*QLogFile) error {
} }
if len(errs) > 0 { if len(errs) > 0 {
return agherr.Many("Error while closing QLogReader", errs...) return errors.List("error while closing QLogReader", errs...)
} }
return nil return nil

View File

@ -6,8 +6,8 @@ import (
"path/filepath" "path/filepath"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -78,13 +78,13 @@ type AddParams struct {
func (p *AddParams) validate() (err error) { func (p *AddParams) validate() (err error) {
switch { switch {
case p.Question == nil: case p.Question == nil:
return agherr.Error("question is nil") return errors.Error("question is nil")
case len(p.Question.Question) != 1: case len(p.Question.Question) != 1:
return agherr.Error("more than one question") return errors.Error("more than one question")
case len(p.Question.Question[0].Name) == 0: case len(p.Question.Question[0].Name) == 0:
return agherr.Error("no host in question") return errors.Error("no host in question")
case p.ClientIP == nil: case p.ClientIP == nil:
return agherr.Error("no client ip") return errors.Error("no client ip")
default: default:
return nil return nil
} }

View File

@ -3,10 +3,10 @@ package querylog
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"errors"
"os" "os"
"time" "time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
) )
@ -39,7 +39,7 @@ func (l *queryLog) flushLogBuffer(fullFlush bool) error {
} }
// flushToFile saves the specified log entries to the query log file // flushToFile saves the specified log entries to the query log file
func (l *queryLog) flushToFile(buffer []*logEntry) error { func (l *queryLog) flushToFile(buffer []*logEntry) (err error) {
if len(buffer) == 0 { if len(buffer) == 0 {
log.Debug("querylog: there's nothing to write to a file") log.Debug("querylog: there's nothing to write to a file")
return nil return nil
@ -49,9 +49,10 @@ func (l *queryLog) flushToFile(buffer []*logEntry) error {
var b bytes.Buffer var b bytes.Buffer
e := json.NewEncoder(&b) e := json.NewEncoder(&b)
for _, entry := range buffer { for _, entry := range buffer {
err := e.Encode(entry) err = e.Encode(entry)
if err != nil { if err != nil {
log.Error("Failed to marshal entry: %s", err) log.Error("Failed to marshal entry: %s", err)
return err return err
} }
} }
@ -59,7 +60,6 @@ func (l *queryLog) flushToFile(buffer []*logEntry) error {
elapsed := time.Since(start) elapsed := time.Since(start)
log.Debug("%d elements serialized via json in %v: %d kB, %v/entry, %v/entry", len(buffer), elapsed, b.Len()/1024, float64(b.Len())/float64(len(buffer)), elapsed/time.Duration(len(buffer))) log.Debug("%d elements serialized via json in %v: %d kB, %v/entry, %v/entry", len(buffer), elapsed, b.Len()/1024, float64(b.Len())/float64(len(buffer)), elapsed/time.Duration(len(buffer)))
var err error
var zb bytes.Buffer var zb bytes.Buffer
filename := l.logFile filename := l.logFile
zb = b zb = b
@ -71,7 +71,7 @@ func (l *queryLog) flushToFile(buffer []*logEntry) error {
log.Error("failed to create file \"%s\": %s", filename, err) log.Error("failed to create file \"%s\": %s", filename, err)
return err return err
} }
defer f.Close() defer func() { err = errors.WithDeferred(err, f.Close()) }()
n, err := f.Write(zb.Bytes()) n, err := f.Write(zb.Bytes())
if err != nil { if err != nil {
@ -109,7 +109,12 @@ func (l *queryLog) readFileFirstTimeValue() int64 {
if err != nil { if err != nil {
return -1 return -1
} }
defer f.Close() defer func() {
derr := f.Close()
if derr != nil {
log.Error("querylog: closing file: %s", derr)
}
}()
buf := make([]byte, 500) buf := make([]byte, 500)
r, err := f.Read(buf) r, err := f.Read(buf)

View File

@ -142,7 +142,12 @@ func (l *queryLog) searchFiles(
return entries, oldest, 0 return entries, oldest, 0
} }
defer r.Close() defer func() {
derr := r.Close()
if derr != nil {
log.Error("querylog: closing file: %s", err)
}
}()
if params.olderThan.IsZero() { if params.olderThan.IsZero() {
err = r.SeekStart() err = r.SeekStart()

View File

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"encoding/gob" "encoding/gob"
"errors"
"fmt" "fmt"
"net" "net"
"os" "os"
@ -12,7 +11,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
bolt "go.etcd.io/bbolt" bolt "go.etcd.io/bbolt"
) )
@ -93,7 +92,7 @@ func createObject(conf Config) (s *statsCtx, err error) {
// TODO(a.garipov): See if this is actually necessary. Looks // TODO(a.garipov): See if this is actually necessary. Looks
// like a rather bizarre solution. // like a rather bizarre solution.
errStop := agherr.Error("stop iteration") errStop := errors.Error("stop iteration")
forEachBkt := func(name []byte, _ *bolt.Bucket) (cberr error) { forEachBkt := func(name []byte, _ *bolt.Bucket) (cberr error) {
nameID := uint32(btoi(name)) nameID := uint32(btoi(name))
if nameID < firstID { if nameID < firstID {

View File

@ -4,10 +4,12 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net/http"
"strings" "strings"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghio" "github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/golibs/errors"
) )
// TODO(a.garipov): Make configurable. // TODO(a.garipov): Make configurable.
@ -27,7 +29,7 @@ const MaxResponseSize = 64 * 1024
// VersionInfo downloads the latest version information. If forceRecheck is // VersionInfo downloads the latest version information. If forceRecheck is
// false and there are cached results, those results are returned. // false and there are cached results, those results are returned.
func (u *Updater) VersionInfo(forceRecheck bool) (VersionInfo, error) { func (u *Updater) VersionInfo(forceRecheck bool) (vi VersionInfo, err error) {
u.mu.Lock() u.mu.Lock()
defer u.mu.Unlock() defer u.mu.Unlock()
@ -37,22 +39,23 @@ func (u *Updater) VersionInfo(forceRecheck bool) (VersionInfo, error) {
return u.prevCheckResult, u.prevCheckError return u.prevCheckResult, u.prevCheckError
} }
var resp *http.Response
vcu := u.versionCheckURL vcu := u.versionCheckURL
resp, err := u.client.Get(vcu) resp, err = u.client.Get(vcu)
if err != nil { if err != nil {
return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %w", vcu, err) return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %w", vcu, err)
} }
defer resp.Body.Close() defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }()
resp.Body, err = aghio.LimitReadCloser(resp.Body, MaxResponseSize) var r io.Reader
r, err = aghio.LimitReader(resp.Body, MaxResponseSize)
if err != nil { if err != nil {
return VersionInfo{}, fmt.Errorf("updater: LimitReadCloser: %w", err) return VersionInfo{}, fmt.Errorf("updater: LimitReadCloser: %w", err)
} }
defer resp.Body.Close()
// This use of ReadAll is safe, because we just limited the appropriate // This use of ReadAll is safe, because we just limited the appropriate
// ReadCloser. // ReadCloser.
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(r)
if err != nil { if err != nil {
return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %w", vcu, err) return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %w", vcu, err)
} }

View File

@ -5,7 +5,6 @@ import (
"archive/tar" "archive/tar"
"archive/zip" "archive/zip"
"compress/gzip" "compress/gzip"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -20,6 +19,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghio" "github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/AdGuardHome/internal/version" "github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
) )
@ -283,22 +283,23 @@ func (u *Updater) clean() {
const MaxPackageFileSize = 32 * 1024 * 1024 const MaxPackageFileSize = 32 * 1024 * 1024
// Download package file and save it to disk // Download package file and save it to disk
func (u *Updater) downloadPackageFile(url, filename string) error { func (u *Updater) downloadPackageFile(url, filename string) (err error) {
resp, err := u.client.Get(url) var resp *http.Response
resp, err = u.client.Get(url)
if err != nil { if err != nil {
return fmt.Errorf("http request failed: %w", err) return fmt.Errorf("http request failed: %w", err)
} }
defer resp.Body.Close() defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }()
resp.Body, err = aghio.LimitReadCloser(resp.Body, MaxPackageFileSize) var r io.Reader
r, err = aghio.LimitReader(resp.Body, MaxPackageFileSize)
if err != nil { if err != nil {
return fmt.Errorf("http request failed: %w", err) return fmt.Errorf("http request failed: %w", err)
} }
defer resp.Body.Close()
log.Debug("updater: reading HTTP body") log.Debug("updater: reading HTTP body")
// This use of ReadAll is now safe, because we limited body's Reader. // This use of ReadAll is now safe, because we limited body's Reader.
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(r)
if err != nil { if err != nil {
return fmt.Errorf("io.ReadAll() failed: %w", err) return fmt.Errorf("io.ReadAll() failed: %w", err)
} }
@ -313,172 +314,178 @@ func (u *Updater) downloadPackageFile(url, filename string) error {
return nil return nil
} }
func tarGzFileUnpackOne(outDir string, tr *tar.Reader, hdr *tar.Header) (name string, err error) {
name = filepath.Base(hdr.Name)
if name == "" {
return "", nil
}
outputName := filepath.Join(outDir, name)
if hdr.Typeflag == tar.TypeDir {
if name == "AdGuardHome" {
// Top-level AdGuardHome/. Skip it.
//
// TODO(a.garipov): This whole package needs to be
// rewritten and covered in more integration tests. It
// has weird assumptions and file mode issues.
return "", nil
}
err = os.Mkdir(outputName, os.FileMode(hdr.Mode&0o777))
if err != nil && !errors.Is(err, os.ErrExist) {
return "", fmt.Errorf("os.Mkdir(%q): %w", outputName, err)
}
log.Debug("updater: created directory %q", outputName)
return "", nil
}
if hdr.Typeflag != tar.TypeReg {
log.Debug("updater: %s: unknown file type %d, skipping", name, hdr.Typeflag)
return "", nil
}
var wc io.WriteCloser
wc, err = os.OpenFile(
outputName,
os.O_WRONLY|os.O_CREATE|os.O_TRUNC,
os.FileMode(hdr.Mode&0o777),
)
if err != nil {
return "", fmt.Errorf("os.OpenFile(%s): %w", outputName, err)
}
defer func() { err = errors.WithDeferred(err, wc.Close()) }()
_, err = io.Copy(wc, tr)
if err != nil {
return "", fmt.Errorf("io.Copy(): %w", err)
}
log.Tracef("updater: created file %s", outputName)
return name, nil
}
// Unpack all files from .tar.gz file to the specified directory // Unpack all files from .tar.gz file to the specified directory
// Existing files are overwritten // Existing files are overwritten
// All files are created inside 'outdir', subdirectories are not created // All files are created inside outDir, subdirectories are not created
// Return the list of files (not directories) written // Return the list of files (not directories) written
func tarGzFileUnpack(tarfile, outdir string) ([]string, error) { func tarGzFileUnpack(tarfile, outDir string) (files []string, err error) {
f, err := os.Open(tarfile) f, err := os.Open(tarfile)
if err != nil { if err != nil {
return nil, fmt.Errorf("os.Open(): %w", err) return nil, fmt.Errorf("os.Open(): %w", err)
} }
defer func() { defer func() { err = errors.WithDeferred(err, f.Close()) }()
_ = f.Close()
}()
gzReader, err := gzip.NewReader(f) gzReader, err := gzip.NewReader(f)
if err != nil { if err != nil {
return nil, fmt.Errorf("gzip.NewReader(): %w", err) return nil, fmt.Errorf("gzip.NewReader(): %w", err)
} }
defer func() { err = errors.WithDeferred(err, gzReader.Close()) }()
var files []string
var err2 error
tarReader := tar.NewReader(gzReader) tarReader := tar.NewReader(gzReader)
for { for {
var header *tar.Header var hdr *tar.Header
header, err = tarReader.Next() hdr, err = tarReader.Next()
if err == io.EOF { if errors.Is(err, io.EOF) {
err2 = nil err = nil
break break
} } else if err != nil {
if err != nil { err = fmt.Errorf("tarReader.Next(): %w", err)
err2 = fmt.Errorf("tarReader.Next(): %w", err)
break
}
_, inputNameOnly := filepath.Split(header.Name)
if inputNameOnly == "" {
continue
}
outputName := filepath.Join(outdir, inputNameOnly)
if header.Typeflag == tar.TypeDir {
if inputNameOnly == "AdGuardHome" {
// Top-level AdGuardHome/. Skip it.
//
// TODO(a.garipov): This whole package needs to
// be rewritten and covered in more integration
// tests. It has weird assumptions and file
// mode issues.
continue
}
err = os.Mkdir(outputName, os.FileMode(header.Mode&0o777))
if err != nil && !errors.Is(err, os.ErrExist) {
err2 = fmt.Errorf("os.Mkdir(%q): %w", outputName, err)
break break
} }
log.Debug("updater: created directory %q", outputName) var name string
name, err = tarGzFileUnpackOne(outDir, tarReader, hdr)
continue if name != "" {
} else if header.Typeflag != tar.TypeReg { files = append(files, name)
log.Debug("updater: %s: unknown file type %d, skipping", inputNameOnly, header.Typeflag) }
continue
} }
var f io.WriteCloser return files, err
f, err = os.OpenFile(outputName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode&0o777))
if err != nil {
err2 = fmt.Errorf("os.OpenFile(%s): %w", outputName, err)
break
}
_, err = io.Copy(f, tarReader)
if err != nil {
_ = f.Close()
err2 = fmt.Errorf("io.Copy(): %w", err)
break
}
err = f.Close()
if err != nil {
err2 = fmt.Errorf("f.Close(): %w", err)
break
}
log.Debug("updater: created file %s", outputName)
files = append(files, header.Name)
}
_ = gzReader.Close()
return files, err2
} }
// Unpack all files from .zip file to the specified directory func zipFileUnpackOne(outDir string, zf *zip.File) (name string, err error) {
// Existing files are overwritten var rc io.ReadCloser
// All files are created inside 'outdir', subdirectories are not created rc, err = zf.Open()
// Return the list of files (not directories) written
func zipFileUnpack(zipfile, outdir string) ([]string, error) {
r, err := zip.OpenReader(zipfile)
if err != nil { if err != nil {
return nil, fmt.Errorf("zip.OpenReader(): %w", err) return "", fmt.Errorf("zip file Open(): %w", err)
}
defer r.Close()
var files []string
var err2 error
var zr io.ReadCloser
for _, zf := range r.File {
zr, err = zf.Open()
if err != nil {
err2 = fmt.Errorf("zip file Open(): %w", err)
break
} }
defer func() { err = errors.WithDeferred(err, rc.Close()) }()
fi := zf.FileInfo() fi := zf.FileInfo()
inputNameOnly := fi.Name() name = fi.Name()
if inputNameOnly == "" { if name == "" {
continue return "", nil
} }
outputName := filepath.Join(outdir, inputNameOnly) outputName := filepath.Join(outDir, name)
if fi.IsDir() { if fi.IsDir() {
if inputNameOnly == "AdGuardHome" { if name == "AdGuardHome" {
// Top-level AdGuardHome/. Skip it. // Top-level AdGuardHome/. Skip it.
// //
// TODO(a.garipov): See the similar todo in // TODO(a.garipov): See the similar todo in
// tarGzFileUnpack. // tarGzFileUnpack.
continue return "", nil
} }
err = os.Mkdir(outputName, fi.Mode()) err = os.Mkdir(outputName, fi.Mode())
if err != nil && !errors.Is(err, os.ErrExist) { if err != nil && !errors.Is(err, os.ErrExist) {
err2 = fmt.Errorf("os.Mkdir(%q): %w", outputName, err) return "", fmt.Errorf("os.Mkdir(%q): %w", outputName, err)
break
} }
log.Tracef("created directory %q", outputName) log.Tracef("created directory %q", outputName)
continue return "", nil
} }
var f io.WriteCloser var wc io.WriteCloser
f, err = os.OpenFile(outputName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode()) wc, err = os.OpenFile(outputName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode())
if err != nil { if err != nil {
err2 = fmt.Errorf("os.OpenFile(): %w", err) return "", fmt.Errorf("os.OpenFile(): %w", err)
break
} }
_, err = io.Copy(f, zr) defer func() { err = errors.WithDeferred(err, wc.Close()) }()
_, err = io.Copy(wc, rc)
if err != nil { if err != nil {
_ = f.Close() return "", fmt.Errorf("io.Copy(): %w", err)
err2 = fmt.Errorf("io.Copy(): %w", err)
break
}
err = f.Close()
if err != nil {
err2 = fmt.Errorf("f.Close(): %w", err)
break
} }
log.Tracef("created file %s", outputName) log.Tracef("created file %s", outputName)
files = append(files, inputNameOnly)
return name, nil
}
// Unpack all files from .zip file to the specified directory
// Existing files are overwritten
// All files are created inside 'outDir', subdirectories are not created
// Return the list of files (not directories) written
func zipFileUnpack(zipfile, outDir string) (files []string, err error) {
zrc, err := zip.OpenReader(zipfile)
if err != nil {
return nil, fmt.Errorf("zip.OpenReader(): %w", err)
}
defer func() { err = errors.WithDeferred(err, zrc.Close()) }()
for _, zf := range zrc.File {
var name string
name, err = zipFileUnpackOne(outDir, zf)
if err != nil {
break
} }
_ = zr.Close() if name != "" {
return files, err2 files = append(files, name)
}
}
return files, err
} }
// Copy file on disk // Copy file on disk

View File

@ -75,10 +75,10 @@ esac
# Simple Analyzers # Simple Analyzers
# blocklist_imports is a simple check against unwanted packages. Package # blocklist_imports is a simple check against unwanted packages. Package
# io/ioutil is soft-deprecated. Package log is replaced by our own package # io/ioutil is soft-deprecated. Packages errors and log are replaced by our own
# github.com/AdguardTeam/golibs/log. # packages in the github.com/AdguardTeam/golibs module.
blocklist_imports() { blocklist_imports() {
git grep -F -e '"io/ioutil"' -e '"log"' -- '*.go' || exit 0; git grep -F -e '"errors"' -e '"io/ioutil"' -e '"log"' -- '*.go' || exit 0;
} }
# method_const is a simple check against the usage of some raw strings and # method_const is a simple check against the usage of some raw strings and
@ -192,13 +192,7 @@ nilness ./...
exit_on_output shadow --strict ./... exit_on_output shadow --strict ./...
# TODO(a.garipov): Enable errcheck fully after handling all errors, including # TODO(a.garipov): Enable --blank?
# the deferred and generated ones, properly. Also, perhaps, enable --blank. errcheck --asserts ./...
#
# errcheck ./...
exit_on_output sh -c '
errcheck --asserts --ignoregenerated ./... |\
{ grep -e "defer" -v || exit 0; }
'
staticcheck ./... staticcheck ./...