Pull request: use testutil

Squashed commit of the following:

commit 5345a14b3565f358c56a37500cafb35b7e397951
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Oct 21 21:13:06 2021 +0300

    all: fix windows tests

commit 8b9cdbe3e78f43339d21277f04e686bb154f6968
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Oct 21 20:23:55 2021 +0300

    all: imp code

commit 271fdbe74c29d8ea4b53d7f56d2a36612dfed7b3
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Oct 21 19:43:32 2021 +0300

    all: imp testing

commit e340f9d48679c57fc8eb579b8b78d4957be111c4
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Oct 21 18:53:51 2021 +0300

    all: use testutil
This commit is contained in:
Eugene Burkov 2021-10-22 11:58:18 +03:00
parent 7804d97743
commit ea8950a80d
29 changed files with 211 additions and 249 deletions

View File

@ -11,7 +11,7 @@ type LimitReachedError struct {
Limit int64 Limit int64
} }
// Error implements error interface for LimitReachedError. // Error implements the error interface for LimitReachedError.
// //
// TODO(a.garipov): Think about error string format. // TODO(a.garipov): Think about error string format.
func (lre *LimitReachedError) Error() string { func (lre *LimitReachedError) Error() string {

View File

@ -1,30 +1,30 @@
package aghio package aghio
import ( import (
"fmt"
"io" "io"
"strings" "strings"
"testing" "testing"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestLimitReader(t *testing.T) { func TestLimitReader(t *testing.T) {
testCases := []struct { testCases := []struct {
want error wantErrMsg string
name string name string
n int64 n int64
}{{ }{{
want: nil, wantErrMsg: "",
name: "positive", name: "positive",
n: 1, n: 1,
}, { }, {
want: nil, wantErrMsg: "",
name: "zero", name: "zero",
n: 0, n: 0,
}, { }, {
want: fmt.Errorf("aghio: invalid n in LimitReader: -1"), wantErrMsg: "aghio: invalid n in LimitReader: -1",
name: "negative", name: "negative",
n: -1, n: -1,
}} }}
@ -32,7 +32,7 @@ func TestLimitReader(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
_, err := LimitReader(nil, tc.n) _, err := LimitReader(nil, tc.n)
assert.Equal(t, tc.want, err) testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
}) })
} }
} }
@ -73,36 +73,23 @@ func TestLimitedReader_Read(t *testing.T) {
}} }}
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
readCloser := io.NopCloser(strings.NewReader(tc.rStr)) readCloser := io.NopCloser(strings.NewReader(tc.rStr))
buf := make([]byte, tc.limit+1)
lreader, err := LimitReader(readCloser, tc.limit) lreader, err := LimitReader(readCloser, tc.limit)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, lreader)
t.Run(tc.name, func(t *testing.T) {
buf := make([]byte, tc.limit+1)
n, rerr := lreader.Read(buf)
require.Equal(t, rerr, tc.err)
n, err := lreader.Read(buf)
require.Equal(t, tc.err, err)
assert.Equal(t, tc.want, n) assert.Equal(t, tc.want, n)
}) })
} }
} }
func TestLimitedReader_LimitReachedError(t *testing.T) { func TestLimitedReader_LimitReachedError(t *testing.T) {
testCases := []struct { testutil.AssertErrorMsg(t, "attempted to read more than 0 bytes", &LimitReachedError{
err error
name string
want string
}{{
err: &LimitReachedError{
Limit: 0, Limit: 0,
},
name: "simplest",
want: "attempted to read more than 0 bytes",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, tc.err.Error())
}) })
}
} }

View File

@ -79,8 +79,8 @@ func TestSystemResolvers_DialFunc(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
conn, err := imp.dialFunc(context.Background(), "", tc.address) conn, err := imp.dialFunc(context.Background(), "", tc.address)
require.Nil(t, conn) require.Nil(t, conn)
assert.ErrorIs(t, err, tc.want) assert.ErrorIs(t, err, tc.want)
}) })
} }

View File

@ -5,9 +5,9 @@ package dhcpd
import ( import (
"net" "net"
"strings"
"testing" "testing"
"github.com/AdguardTeam/golibs/testutil"
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/insomniacslk/dhcp/dhcpv4" "github.com/insomniacslk/dhcp/dhcpv4"
@ -45,7 +45,7 @@ func TestDHCPConn_WriteTo_common(t *testing.T) {
n, err := conn.WriteTo(nil, &unexpectedAddrType{}) n, err := conn.WriteTo(nil, &unexpectedAddrType{})
require.Error(t, err) require.Error(t, err)
assert.True(t, strings.Contains(err.Error(), "peer is of unexpected type")) testutil.AssertErrorMsg(t, "peer is of unexpected type *dhcpd.unexpectedAddrType", err)
assert.Zero(t, n) assert.Zero(t, n)
}) })
} }

View File

@ -68,9 +68,7 @@ func TestDB(t *testing.T) {
err = s.dbStore() err = s.dbStore()
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { testutil.CleanupAndRequireSuccess(t, func() (err error) { return os.Remove(dbFilename) })
assert.NoError(t, os.Remove(dbFilename))
})
err = s.srv4.ResetLeases(nil) err = s.srv4.ResetLeases(nil)
require.NoError(t, err) require.NoError(t, err)

View File

@ -15,7 +15,7 @@ func TestServer_notImplemented(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
r, err := http.NewRequest(http.MethodGet, "/unsupported", nil) r, err := http.NewRequest(http.MethodGet, "/unsupported", nil)
require.Nil(t, err) require.NoError(t, err)
h(w, r) h(w, r)
assert.Equal(t, http.StatusNotImplemented, w.Code) assert.Equal(t, http.StatusNotImplemented, w.Code)

View File

@ -4,6 +4,7 @@ import (
"net" "net"
"testing" "testing"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -64,14 +65,8 @@ func TestNewIPRange(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
r, err := newIPRange(tc.start, tc.end) _, err := newIPRange(tc.start, tc.end)
if tc.wantErrMsg == "" { testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.NoError(t, err)
assert.NotNil(t, r)
} else {
require.Error(t, err)
assert.Equal(t, tc.wantErrMsg, err.Error())
}
}) })
} }
} }

View File

@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"testing" "testing"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -45,13 +46,7 @@ func TestNullBool_UnmarshalJSON(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
var got nullBool var got nullBool
err := got.UnmarshalJSON(tc.data) err := got.UnmarshalJSON(tc.data)
if tc.wantErrMsg == "" { testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.NoError(t, err)
} else {
require.Error(t, err)
assert.Equal(t, tc.wantErrMsg, err.Error())
}
assert.Equal(t, tc.want, got) assert.Equal(t, tc.want, got)
}) })

View File

@ -95,6 +95,7 @@ func TestParseOpt(t *testing.T) {
opt, err := parseDHCPOption(tc.in) opt, err := parseDHCPOption(tc.in)
if tc.wantErrMsg != "" { if tc.wantErrMsg != "" {
require.Error(t, err) require.Error(t, err)
assert.Equal(t, tc.wantErrMsg, err.Error()) assert.Equal(t, tc.wantErrMsg, err.Error())
return return

View File

@ -9,6 +9,7 @@ import (
"testing" "testing"
"github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/insomniacslk/dhcp/dhcpv4" "github.com/insomniacslk/dhcp/dhcpv4"
"github.com/mdlayher/raw" "github.com/mdlayher/raw"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -430,14 +431,7 @@ func TestNormalizeHostname(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
got, err := normalizeHostname(tc.hostname) got, err := normalizeHostname(tc.hostname)
if tc.wantErrMsg == "" { testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.NoError(t, err)
} else {
require.Error(t, err)
assert.Equal(t, tc.wantErrMsg, err.Error())
}
assert.Equal(t, tc.want, got) assert.Equal(t, tc.want, got)
}) })
} }

View File

@ -8,9 +8,9 @@ import (
"testing" "testing"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/testutil"
"github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
// testTLSConn is a tlsConn for tests. // testTLSConn is a tlsConn for tests.
@ -179,13 +179,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
clientID, err := srv.clientIDFromDNSContext(pctx) clientID, err := srv.clientIDFromDNSContext(pctx)
assert.Equal(t, tc.wantClientID, clientID) assert.Equal(t, tc.wantClientID, clientID)
if tc.wantErrMsg == "" { testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.NoError(t, err)
} else {
require.Error(t, err)
assert.Equal(t, tc.wantErrMsg, err.Error())
}
}) })
} }
} }
@ -250,13 +244,7 @@ func TestClientIDFromDNSContextHTTPS(t *testing.T) {
clientID, err := clientIDFromDNSContextHTTPS(pctx) clientID, err := clientIDFromDNSContextHTTPS(pctx)
assert.Equal(t, tc.wantClientID, clientID) assert.Equal(t, tc.wantClientID, clientID)
if tc.wantErrMsg == "" { testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.NoError(t, err)
} else {
require.Error(t, err)
assert.Equal(t, tc.wantErrMsg, err.Error())
}
}) })
} }
} }

View File

@ -24,6 +24,7 @@ import (
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/timeutil" "github.com/AdguardTeam/golibs/timeutil"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -45,10 +46,7 @@ func startDeferStop(t *testing.T, s *Server) {
err := s.Start() err := s.Start()
require.NoErrorf(t, err, "failed to start server: %s", err) require.NoErrorf(t, err, "failed to start server: %s", err)
t.Cleanup(func() { testutil.CleanupAndRequireSuccess(t, s.Stop)
serr := s.Stop()
require.NoErrorf(t, serr, "dns server failed to stop: %s", serr)
})
} }
func createTestServer( func createTestServer(
@ -1049,9 +1047,7 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
err = s.Start() err = s.Start()
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(s.Close)
s.Close()
})
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
req := createTestMessageWithType("34.12.168.192.in-addr.arpa.", dns.TypePTR) req := createTestMessageWithType("34.12.168.192.in-addr.arpa.", dns.TypePTR)
@ -1131,9 +1127,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
err = s.Start() err = s.Start()
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(s.Close)
s.Close()
})
subTestFunc := func(t *testing.T) { subTestFunc := func(t *testing.T) {
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
@ -1197,12 +1191,7 @@ func TestNewServer(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
_, err := NewServer(tc.in) _, err := NewServer(tc.in)
if tc.wantErrMsg == "" { testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.NoError(t, err)
} else {
require.Error(t, err)
assert.Equal(t, tc.wantErrMsg, err.Error())
}
}) })
} }
} }

View File

@ -61,8 +61,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
d := &proxy.DNSContext{} d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d) err := srv.filterDNSRewrite(req, res, d)
require.NoError(t, err)
require.Nil(t, err)
assert.Equal(t, dns.RcodeNameError, d.Res.Rcode) assert.Equal(t, dns.RcodeNameError, d.Res.Rcode)
}) })
@ -72,7 +72,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
d := &proxy.DNSContext{} d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d) err := srv.filterDNSRewrite(req, res, d)
require.Nil(t, err) require.NoError(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
assert.Empty(t, d.Res.Answer) assert.Empty(t, d.Res.Answer)
}) })
@ -83,7 +84,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
d := &proxy.DNSContext{} d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d) err := srv.filterDNSRewrite(req, res, d)
require.Nil(t, err) require.NoError(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
require.Len(t, d.Res.Answer, 1) require.Len(t, d.Res.Answer, 1)
@ -96,7 +98,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
d := &proxy.DNSContext{} d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d) err := srv.filterDNSRewrite(req, res, d)
require.Nil(t, err) require.NoError(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
require.Len(t, d.Res.Answer, 1) require.Len(t, d.Res.Answer, 1)
@ -109,7 +112,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
d := &proxy.DNSContext{} d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d) err := srv.filterDNSRewrite(req, res, d)
require.Nil(t, err) require.NoError(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
require.Len(t, d.Res.Answer, 1) require.Len(t, d.Res.Answer, 1)
@ -122,7 +126,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
d := &proxy.DNSContext{} d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d) err := srv.filterDNSRewrite(req, res, d)
require.Nil(t, err) require.NoError(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
require.Len(t, d.Res.Answer, 1) require.Len(t, d.Res.Answer, 1)
@ -135,7 +140,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
d := &proxy.DNSContext{} d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d) err := srv.filterDNSRewrite(req, res, d)
require.Nil(t, err) require.NoError(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
require.Len(t, d.Res.Answer, 1) require.Len(t, d.Res.Answer, 1)
@ -152,7 +158,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
d := &proxy.DNSContext{} d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d) err := srv.filterDNSRewrite(req, res, d)
require.Nil(t, err) require.NoError(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
require.Len(t, d.Res.Answer, 1) require.Len(t, d.Res.Answer, 1)
@ -171,7 +178,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
d := &proxy.DNSContext{} d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d) err := srv.filterDNSRewrite(req, res, d)
require.Nil(t, err) require.NoError(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
require.Len(t, d.Res.Answer, 1) require.Len(t, d.Res.Answer, 1)
@ -190,7 +198,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
d := &proxy.DNSContext{} d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d) err := srv.filterDNSRewrite(req, res, d)
require.Nil(t, err) require.NoError(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
require.Len(t, d.Res.Answer, 1) require.Len(t, d.Res.Answer, 1)

View File

@ -14,6 +14,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -38,9 +39,7 @@ func loadTestData(t *testing.T, casesFileName string, cases interface{}) {
var f *os.File var f *os.File
f, err := os.Open(filepath.Join("testdata", casesFileName)) f, err := os.Open(filepath.Join("testdata", casesFileName))
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { testutil.CleanupAndRequireSuccess(t, f.Close)
require.NoError(t, f.Close())
})
err = json.NewDecoder(f).Decode(cases) err = json.NewDecoder(f).Decode(cases)
require.NoError(t, err) require.NoError(t, err)
@ -69,10 +68,8 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
s := createTestServer(t, filterConf, forwardConf, nil) s := createTestServer(t, filterConf, forwardConf, nil)
s.sysResolvers = &fakeSystemResolvers{} s.sysResolvers = &fakeSystemResolvers{}
require.Nil(t, s.Start()) require.NoError(t, s.Start())
t.Cleanup(func() { testutil.CleanupAndRequireSuccess(t, s.Stop)
require.Nil(t, s.Stop())
})
defaultConf := s.conf defaultConf := s.conf
@ -147,10 +144,8 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
defaultConf := s.conf defaultConf := s.conf
err := s.Start() err := s.Start()
assert.Nil(t, err) assert.NoError(t, err)
t.Cleanup(func() { testutil.CleanupAndRequireSuccess(t, s.Stop)
assert.Nil(t, s.Stop())
})
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -221,14 +216,12 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
require.True(t, ok) require.True(t, ok)
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Cleanup(func() { t.Cleanup(func() { s.conf = defaultConf })
s.conf = defaultConf
})
rBody := io.NopCloser(bytes.NewReader(caseData.Req)) rBody := io.NopCloser(bytes.NewReader(caseData.Req))
var r *http.Request var r *http.Request
r, err = http.NewRequest(http.MethodPost, "http://example.com", rBody) r, err = http.NewRequest(http.MethodPost, "http://example.com", rBody)
require.Nil(t, err) require.NoError(t, err)
s.handleSetConfig(w, r) s.handleSetConfig(w, r)
assert.Equal(t, tc.wantSet, strings.TrimSuffix(w.Body.String(), "\n")) assert.Equal(t, tc.wantSet, strings.TrimSuffix(w.Body.String(), "\n"))

View File

@ -157,7 +157,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
}} }}
ups, err := upstream.AddressToUpstream("1.1.1.1", nil) ups, err := upstream.AddressToUpstream("1.1.1.1", nil)
require.Nil(t, err) require.NoError(t, err)
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {

View File

@ -13,6 +13,7 @@ import (
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -24,19 +25,17 @@ func TestMain(m *testing.M) {
func TestNewSessionToken(t *testing.T) { func TestNewSessionToken(t *testing.T) {
// Successful case. // Successful case.
token, err := newSessionToken() token, err := newSessionToken()
require.Nil(t, err) require.NoError(t, err)
assert.Len(t, token, sessionTokenSize) assert.Len(t, token, sessionTokenSize)
// Break the rand.Reader. // Break the rand.Reader.
prevReader := rand.Reader prevReader := rand.Reader
t.Cleanup(func() { t.Cleanup(func() { rand.Reader = prevReader })
rand.Reader = prevReader
})
rand.Reader = &bytes.Buffer{} rand.Reader = &bytes.Buffer{}
// Unsuccessful case. // Unsuccessful case.
token, err = newSessionToken() token, err = newSessionToken()
require.NotNil(t, err) require.Error(t, err)
assert.Empty(t, token) assert.Empty(t, token)
} }
@ -58,7 +57,7 @@ func TestAuth(t *testing.T) {
a.RemoveSession("notfound") a.RemoveSession("notfound")
sess, err := newSessionToken() sess, err := newSessionToken()
assert.Nil(t, err) require.NoError(t, err)
sessStr := hex.EncodeToString(sess) sessStr := hex.EncodeToString(sess)
now := time.Now().UTC().Unix() now := time.Now().UTC().Unix()
@ -152,7 +151,7 @@ func TestAuthHTTP(t *testing.T) {
// perform login // perform login
cookie, err := Context.auth.httpCookie(loginJSON{Name: "name", Password: "password"}, "") cookie, err := Context.auth.httpCookie(loginJSON{Name: "name", Password: "password"}, "")
assert.Nil(t, err) require.NoError(t, err)
assert.NotEmpty(t, cookie) assert.NotEmpty(t, cookie)
// get / // get /
@ -251,12 +250,7 @@ func TestRealIP(t *testing.T) {
ip, err := realIP(r) ip, err := realIP(r)
assert.Equal(t, tc.wantIP, ip) assert.Equal(t, tc.wantIP, ip)
if tc.wantErrMsg == "" { testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.NoError(t, err)
} else {
require.Error(t, err)
assert.Equal(t, tc.wantErrMsg, err.Error())
}
}) })
} }
} }

View File

@ -15,21 +15,19 @@ func TestAuthGL(t *testing.T) {
dir := t.TempDir() dir := t.TempDir()
GLMode = true GLMode = true
t.Cleanup(func() { t.Cleanup(func() { GLMode = false })
GLMode = false
})
glFilePrefix = dir + "/gl_token_" glFilePrefix = dir + "/gl_token_"
data := make([]byte, 4) data := make([]byte, 4)
aghos.NativeEndian.PutUint32(data, 1) aghos.NativeEndian.PutUint32(data, 1)
require.Nil(t, os.WriteFile(glFilePrefix+"test", data, 0o644)) require.NoError(t, os.WriteFile(glFilePrefix+"test", data, 0o644))
assert.False(t, glCheckToken("test")) assert.False(t, glCheckToken("test"))
data = make([]byte, 4) data = make([]byte, 4)
aghos.NativeEndian.PutUint32(data, uint32(time.Now().UTC().Unix()+60)) aghos.NativeEndian.PutUint32(data, uint32(time.Now().UTC().Unix()+60))
require.Nil(t, os.WriteFile(glFilePrefix+"test", data, 0o644)) require.NoError(t, os.WriteFile(glFilePrefix+"test", data, 0o644))
r, _ := http.NewRequest(http.MethodGet, "http://localhost/", nil) r, _ := http.NewRequest(http.MethodGet, "http://localhost/", nil)
r.AddCookie(&http.Cookie{Name: glCookieName, Value: "test"}) r.AddCookie(&http.Cookie{Name: glCookieName, Value: "test"})
assert.True(t, glProcessCookie(r)) assert.True(t, glProcessCookie(r))

View File

@ -290,7 +290,9 @@ func TestClientsAddExisting(t *testing.T) {
clients.dhcpServer, err = dhcpd.Create(config) clients.dhcpServer, err = dhcpd.Create(config)
require.NoError(t, err) require.NoError(t, err)
// TODO(e.burkov): leases.db isn't created on Windows so removing it
// causes an error. Split the test to make it run properly on different
// operating systems.
t.Cleanup(func() { _ = os.Remove("leases.db") }) t.Cleanup(func() { _ = os.Remove("leases.db") })
err = clients.dhcpServer.AddStaticLease(&dhcpd.Lease{ err = clients.dhcpServer.AddStaticLease(&dhcpd.Lease{
@ -309,8 +311,7 @@ func TestClientsAddExisting(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.True(t, ok) assert.True(t, ok)
// Add a new client with the IP from the first client's IP // Add a new client with the IP from the first client's IP range.
// range.
ok, err = clients.Add(&Client{ ok, err = clients.Add(&Client{
IDs: []string{"2.2.2.2"}, IDs: []string{"2.2.2.2"},
Name: "client3", Name: "client3",

View File

@ -12,6 +12,7 @@ import (
"time" "time"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -21,7 +22,7 @@ const testFltsFileName = "1.txt"
func testStartFilterListener(t *testing.T, fltContent *[]byte) (l net.Listener) { func testStartFilterListener(t *testing.T, fltContent *[]byte) (l net.Listener) {
t.Helper() t.Helper()
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
n, werr := w.Write(*fltContent) n, werr := w.Write(*fltContent)
require.NoError(t, werr) require.NoError(t, werr)
require.Equal(t, len(*fltContent), n) require.Equal(t, len(*fltContent), n)
@ -34,9 +35,7 @@ func testStartFilterListener(t *testing.T, fltContent *[]byte) (l net.Listener)
go func() { go func() {
_ = http.Serve(l, h) _ = http.Serve(l, h)
}() }()
t.Cleanup(func() { testutil.CleanupAndRequireSuccess(t, l.Close)
require.NoError(t, l.Close())
})
return l return l
} }
@ -100,9 +99,7 @@ func TestFilters(t *testing.T) {
t.Run("refresh_actually", func(t *testing.T) { t.Run("refresh_actually", func(t *testing.T) {
fltContent = []byte(`||example.com^`) fltContent = []byte(`||example.com^`)
t.Cleanup(func() { t.Cleanup(func() { fltContent = []byte(content) })
fltContent = []byte(content)
})
updateAndAssert(t, require.True, 1) updateAndAssert(t, require.True, 1)
}) })

View File

@ -46,7 +46,7 @@ func TestLimitRequestBody(t *testing.T) {
var b []byte var b []byte
b, *err = io.ReadAll(r.Body) b, *err = io.ReadAll(r.Body)
_, werr := w.Write(b) _, werr := w.Write(b)
require.Nil(t, werr) require.NoError(t, werr)
}) })
} }

View File

@ -153,7 +153,6 @@ func TestHandleMobileConfigDoT(t *testing.T) {
handleMobileConfigDoT(w, r) handleMobileConfigDoT(w, r)
assert.Equal(t, http.StatusInternalServerError, w.Code) assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.JSONEq(t, w.Body.String(), b.String()) assert.JSONEq(t, w.Body.String(), b.String())
}) })

View File

@ -13,7 +13,7 @@ func testParseOK(t *testing.T, ss ...string) options {
t.Helper() t.Helper()
o, _, err := parse("", ss) o, _, err := parse("", ss)
require.Nil(t, err) require.NoError(t, err)
return o return o
} }
@ -22,7 +22,7 @@ func testParseErr(t *testing.T, descr string, ss ...string) {
t.Helper() t.Helper()
_, _, err := parse("", ss) _, _, err := parse("", ss)
require.NotNilf(t, err, "expected an error because %s but no error returned", descr) require.Error(t, err)
} }
func testParseParamMissing(t *testing.T, param string) { func testParseParamMissing(t *testing.T, param string) {

View File

@ -4,6 +4,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/timeutil" "github.com/AdguardTeam/golibs/timeutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -333,8 +334,7 @@ func TestUpgradeSchema9to10(t *testing.T) {
err := upgradeSchema9to10(conf) err := upgradeSchema9to10(conf)
if tc.wantErr != "" { if tc.wantErr != "" {
require.Error(t, err) testutil.AssertErrorMsg(t, tc.wantErr, err)
assert.Equal(t, tc.wantErr, err.Error())
return return
} }
@ -366,8 +366,7 @@ func TestUpgradeSchema9to10(t *testing.T) {
"dns": ultimateAns, "dns": ultimateAns,
}) })
require.Error(t, err) testutil.AssertErrorMsg(t, "unexpected type of dns: int", err)
assert.Equal(t, "unexpected type of dns: int", err.Error())
}) })
} }
@ -448,6 +447,7 @@ func TestUpgradeSchema11to12(t *testing.T) {
if tc.wantErr != "" { if tc.wantErr != "" {
require.Error(t, err) require.Error(t, err)
assert.Equal(t, tc.wantErr, err.Error()) assert.Equal(t, tc.wantErr, err.Error())
return return
@ -482,8 +482,7 @@ func TestUpgradeSchema11to12(t *testing.T) {
"dns": 0, "dns": 0,
}) })
require.Error(t, err) testutil.AssertErrorMsg(t, "unexpected type of dns: int", err)
assert.Equal(t, "unexpected type of dns: int", err.Error())
}) })
t.Run("no_field", func(t *testing.T) { t.Run("no_field", func(t *testing.T) {

View File

@ -14,6 +14,7 @@ import (
"github.com/AdguardTeam/urlfilter/rules" "github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestDecodeLogEntry(t *testing.T) { func TestDecodeLogEntry(t *testing.T) {
@ -45,7 +46,7 @@ func TestDecodeLogEntry(t *testing.T) {
`"Elapsed":837429}` `"Elapsed":837429}`
ans, err := base64.StdEncoding.DecodeString(ansStr) ans, err := base64.StdEncoding.DecodeString(ansStr)
assert.Nil(t, err) require.NoError(t, err)
want := &logEntry{ want := &logEntry{
IP: net.IPv4(127, 0, 0, 1), IP: net.IPv4(127, 0, 0, 1),

View File

@ -35,13 +35,13 @@ func TestQueryLog(t *testing.T) {
// Add disk entries. // Add disk entries.
addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
// Write to disk (first file). // Write to disk (first file).
require.Nil(t, l.flushLogBuffer(true)) require.NoError(t, l.flushLogBuffer(true))
// Start writing to the second file. // Start writing to the second file.
require.Nil(t, l.rotate()) require.NoError(t, l.rotate())
// Add disk entries. // Add disk entries.
addEntry(l, "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2)) addEntry(l, "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2))
// Write to disk. // Write to disk.
require.Nil(t, l.flushLogBuffer(true)) require.NoError(t, l.flushLogBuffer(true))
// Add memory entries. // Add memory entries.
addEntry(l, "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3)) addEntry(l, "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3))
addEntry(l, "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4)) addEntry(l, "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4))
@ -144,7 +144,7 @@ func TestQueryLogOffsetLimit(t *testing.T) {
addEntry(l, secondPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) addEntry(l, secondPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
} }
// Write them to the first file. // Write them to the first file.
require.Nil(t, l.flushLogBuffer(true)) require.NoError(t, l.flushLogBuffer(true))
// Add more to the in-memory part of log. // Add more to the in-memory part of log.
for i := 0; i < entNum; i++ { for i := 0; i < entNum; i++ {
addEntry(l, firstPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) addEntry(l, firstPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
@ -215,7 +215,7 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) {
addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
} }
// Write them to disk. // Write them to disk.
require.Nil(t, l.flushLogBuffer(true)) require.NoError(t, l.flushLogBuffer(true))
params := newSearchParams() params := newSearchParams()
@ -300,7 +300,7 @@ func assertLogEntry(t *testing.T, entry *logEntry, host string, answer, client n
assert.Equal(t, "IN", entry.QClass) assert.Equal(t, "IN", entry.QClass)
msg := &dns.Msg{} msg := &dns.Msg{}
require.Nil(t, msg.Unpack(entry.Answer)) require.NoError(t, msg.Unpack(entry.Answer))
require.Len(t, msg.Answer, 1) require.Len(t, msg.Answer, 1)
ip := proxyutil.GetIPFromDNSRecord(msg.Answer[0]).To16() ip := proxyutil.GetIPFromDNSRecord(msg.Answer[0]).To16()

View File

@ -12,6 +12,7 @@ import (
"time" "time"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -80,11 +81,10 @@ func newTestQLogFile(t *testing.T, linesNum int) (file *QLogFile) {
// Create the new QLogFile instance. // Create the new QLogFile instance.
file, err := NewQLogFile(testFile) file, err := NewQLogFile(testFile)
require.Nil(t, err) require.NoError(t, err)
assert.NotNil(t, file) assert.NotNil(t, file)
t.Cleanup(func() { testutil.CleanupAndRequireSuccess(t, file.Close)
assert.Nil(t, file.Close())
})
return file return file
} }
@ -107,7 +107,7 @@ func TestQLogFile_ReadNext(t *testing.T) {
// Calculate the expected position. // Calculate the expected position.
fileInfo, err := q.file.Stat() fileInfo, err := q.file.Stat()
require.Nil(t, err) require.NoError(t, err)
var expPos int64 var expPos int64
if expPos = fileInfo.Size(); expPos > 0 { if expPos = fileInfo.Size(); expPos > 0 {
expPos-- expPos--
@ -115,7 +115,7 @@ func TestQLogFile_ReadNext(t *testing.T) {
// Seek to the start. // Seek to the start.
pos, err := q.SeekStart() pos, err := q.SeekStart()
require.Nil(t, err) require.NoError(t, err)
require.EqualValues(t, expPos, pos) require.EqualValues(t, expPos, pos)
var read int var read int
@ -170,17 +170,17 @@ func TestQLogFile_SeekTS_good(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(l.name+"_"+tc.name, func(t *testing.T) { t.Run(l.name+"_"+tc.name, func(t *testing.T) {
line, err := getQLogFileLine(q, tc.line) line, err := getQLogFileLine(q, tc.line)
require.Nil(t, err) require.NoError(t, err)
ts := readQLogTimestamp(line) ts := readQLogTimestamp(line)
assert.NotEqualValues(t, 0, ts) assert.NotEqualValues(t, 0, ts)
// Try seeking to that line now. // Try seeking to that line now.
pos, _, err := q.SeekTS(ts) pos, _, err := q.SeekTS(ts)
require.Nil(t, err) require.NoError(t, err)
assert.NotEqualValues(t, 0, pos) assert.NotEqualValues(t, 0, pos)
testLine, err := q.ReadNext() testLine, err := q.ReadNext()
require.Nil(t, err) require.NoError(t, err)
assert.Equal(t, line, testLine) assert.Equal(t, line, testLine)
}) })
} }
@ -220,7 +220,7 @@ func TestQLogFile_SeekTS_bad(t *testing.T) {
testCases[1].ts = lateTS.UnixNano() testCases[1].ts = lateTS.UnixNano()
line, err := getQLogFileLine(q, l.num/2) line, err := getQLogFileLine(q, l.num/2)
require.Nil(t, err) require.NoError(t, err)
testCases[2].ts = readQLogTimestamp(line) - 1 testCases[2].ts = readQLogTimestamp(line) - 1
for _, tc := range testCases { for _, tc := range testCases {
@ -230,7 +230,8 @@ func TestQLogFile_SeekTS_bad(t *testing.T) {
var depth int var depth int
_, depth, err = q.SeekTS(tc.ts) _, depth, err = q.SeekTS(tc.ts)
assert.NotEmpty(t, l.num) assert.NotEmpty(t, l.num)
require.NotNil(t, err) require.Error(t, err)
if tc.leq { if tc.leq {
assert.LessOrEqual(t, depth, int(math.Log2(float64(l.num))+3)) assert.LessOrEqual(t, depth, int(math.Log2(float64(l.num))+3))
} }
@ -260,19 +261,19 @@ func TestQLogFile(t *testing.T) {
// Seek to the start. // Seek to the start.
pos, err := q.SeekStart() pos, err := q.SeekStart()
require.Nil(t, err) require.NoError(t, err)
assert.Greater(t, pos, int64(0)) assert.Greater(t, pos, int64(0))
// Read first line. // Read first line.
line, err := q.ReadNext() line, err := q.ReadNext()
require.Nil(t, err) require.NoError(t, err)
assert.Contains(t, line, "0.0.0.2") assert.Contains(t, line, "0.0.0.2")
assert.True(t, strings.HasPrefix(line, "{"), line) assert.True(t, strings.HasPrefix(line, "{"), line)
assert.True(t, strings.HasSuffix(line, "}"), line) assert.True(t, strings.HasSuffix(line, "}"), line)
// Read second line. // Read second line.
line, err = q.ReadNext() line, err = q.ReadNext()
require.Nil(t, err) require.NoError(t, err)
assert.EqualValues(t, 0, q.position) assert.EqualValues(t, 0, q.position)
assert.Contains(t, line, "0.0.0.1") assert.Contains(t, line, "0.0.0.1")
assert.True(t, strings.HasPrefix(line, "{"), line) assert.True(t, strings.HasPrefix(line, "{"), line)
@ -286,19 +287,15 @@ func TestQLogFile(t *testing.T) {
func NewTestQLogFileData(t *testing.T, data string) (file *QLogFile) { func NewTestQLogFileData(t *testing.T, data string) (file *QLogFile) {
f, err := os.CreateTemp(t.TempDir(), "*.txt") f, err := os.CreateTemp(t.TempDir(), "*.txt")
require.Nil(t, err) require.NoError(t, err)
t.Cleanup(func() { testutil.CleanupAndRequireSuccess(t, f.Close)
assert.Nil(t, f.Close())
})
_, err = f.WriteString(data) _, err = f.WriteString(data)
require.Nil(t, err) require.NoError(t, err)
file, err = NewQLogFile(f.Name()) file, err = NewQLogFile(f.Name())
require.Nil(t, err) require.NoError(t, err)
t.Cleanup(func() { testutil.CleanupAndRequireSuccess(t, file.Close)
assert.Nil(t, file.Close())
})
return file return file
} }

View File

@ -5,6 +5,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -18,12 +19,10 @@ func newTestQLogReader(t *testing.T, filesNum, linesNum int) (reader *QLogReader
// Create the new QLogReader instance. // Create the new QLogReader instance.
reader, err := NewQLogReader(testFiles) reader, err := NewQLogReader(testFiles)
require.Nil(t, err) require.NoError(t, err)
assert.NotNil(t, reader) assert.NotNil(t, reader)
t.Cleanup(func() { testutil.CleanupAndRequireSuccess(t, reader.Close)
assert.Nil(t, reader.Close())
})
return reader return reader
} }
@ -53,7 +52,7 @@ func TestQLogReader(t *testing.T) {
// Seek to the start. // Seek to the start.
err := r.SeekStart() err := r.SeekStart()
require.Nil(t, err) require.NoError(t, err)
// Read everything. // Read everything.
var read int var read int
@ -112,7 +111,7 @@ func TestQLogReader_Seek(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
ts, err := time.Parse(time.RFC3339Nano, tc.time) ts, err := time.Parse(time.RFC3339Nano, tc.time)
require.Nil(t, err) require.NoError(t, err)
err = r.SeekTS(ts.UnixNano()) err = r.SeekTS(ts.UnixNano())
assert.ErrorIs(t, err, tc.want) assert.ErrorIs(t, err, tc.want)
@ -142,11 +141,11 @@ func TestQLogReader_ReadNext(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
err := r.SeekStart() err := r.SeekStart()
require.Nil(t, err) require.NoError(t, err)
for i := 1; i < tc.start; i++ { for i := 1; i < tc.start; i++ {
_, err = r.ReadNext() _, err = r.ReadNext()
require.Nil(t, err) require.NoError(t, err)
} }
_, err = r.ReadNext() _, err = r.ReadNext()

View File

@ -8,6 +8,7 @@ import (
"testing" "testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -37,11 +38,12 @@ func TestStats(t *testing.T) {
} }
s, err := createObject(conf) s, err := createObject(conf)
require.Nil(t, err) require.NoError(t, err)
t.Cleanup(func() { testutil.CleanupAndRequireSuccess(t, func() (err error) {
s.clear() s.clear()
s.Close() s.Close()
assert.Nil(t, os.Remove(conf.Filename))
return os.Remove(conf.Filename)
}) })
s.Update(Entry{ s.Update(Entry{
@ -109,10 +111,11 @@ func TestLargeNumbers(t *testing.T) {
UnitID: newID, UnitID: newID,
} }
s, err := createObject(conf) s, err := createObject(conf)
require.Nil(t, err) require.NoError(t, err)
t.Cleanup(func() { testutil.CleanupAndRequireSuccess(t, func() (err error) {
s.Close() s.Close()
assert.Nil(t, os.Remove(conf.Filename))
return os.Remove(conf.Filename)
}) })
// Number of distinct clients and domains every hour. // Number of distinct clients and domains every hour.

View File

@ -12,7 +12,9 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/version" "github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
// TODO(a.garipov): Rewrite these tests. // TODO(a.garipov): Rewrite these tests.
@ -67,7 +69,7 @@ func TestUpdateGetVersion(t *testing.T) {
}` }`
l, lport := startHTTPServer(jsonData) l, lport := startHTTPServer(jsonData)
t.Cleanup(func() { assert.Nil(t, l.Close()) }) testutil.CleanupAndRequireSuccess(t, l.Close)
u := NewUpdater(&Config{ u := NewUpdater(&Config{
Client: &http.Client{}, Client: &http.Client{},
@ -85,7 +87,8 @@ func TestUpdateGetVersion(t *testing.T) {
u.versionCheckURL = fakeURL.String() u.versionCheckURL = fakeURL.String()
info, err := u.VersionInfo(false) info, err := u.VersionInfo(false)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, "v0.103.0-beta.2", info.NewVersion) assert.Equal(t, "v0.103.0-beta.2", info.NewVersion)
assert.Equal(t, "AdGuard Home v0.103.0-beta.2 is now available!", info.Announcement) assert.Equal(t, "AdGuard Home v0.103.0-beta.2 is now available!", info.Announcement)
assert.Equal(t, "https://github.com/AdguardTeam/AdGuardHome/internal/releases", info.AnnouncementURL) assert.Equal(t, "https://github.com/AdguardTeam/AdGuardHome/internal/releases", info.AnnouncementURL)
@ -96,22 +99,23 @@ func TestUpdateGetVersion(t *testing.T) {
// check cached // check cached
_, err = u.VersionInfo(false) _, err = u.VersionInfo(false)
assert.Nil(t, err) require.NoError(t, err)
} }
func TestUpdate(t *testing.T) { func TestUpdate(t *testing.T) {
wd := t.TempDir() wd := t.TempDir()
assert.Nil(t, os.WriteFile(filepath.Join(wd, "AdGuardHome"), []byte("AdGuardHome"), 0o755)) require.NoError(t, os.WriteFile(filepath.Join(wd, "AdGuardHome"), []byte("AdGuardHome"), 0o755))
assert.Nil(t, os.WriteFile(filepath.Join(wd, "README.md"), []byte("README.md"), 0o644)) require.NoError(t, os.WriteFile(filepath.Join(wd, "README.md"), []byte("README.md"), 0o644))
assert.Nil(t, os.WriteFile(filepath.Join(wd, "LICENSE.txt"), []byte("LICENSE.txt"), 0o644)) require.NoError(t, os.WriteFile(filepath.Join(wd, "LICENSE.txt"), []byte("LICENSE.txt"), 0o644))
assert.Nil(t, os.WriteFile(filepath.Join(wd, "AdGuardHome.yaml"), []byte("AdGuardHome.yaml"), 0o644)) require.NoError(t, os.WriteFile(filepath.Join(wd, "AdGuardHome.yaml"), []byte("AdGuardHome.yaml"), 0o644))
// start server for returning package file // start server for returning package file
pkgData, err := os.ReadFile("testdata/AdGuardHome.tar.gz") pkgData, err := os.ReadFile("testdata/AdGuardHome.tar.gz")
assert.Nil(t, err) require.NoError(t, err)
l, lport := startHTTPServer(string(pkgData)) l, lport := startHTTPServer(string(pkgData))
t.Cleanup(func() { assert.Nil(t, l.Close()) }) testutil.CleanupAndRequireSuccess(t, l.Close)
u := NewUpdater(&Config{ u := NewUpdater(&Config{
Client: &http.Client{}, Client: &http.Client{},
@ -129,56 +133,66 @@ func TestUpdate(t *testing.T) {
u.newVersion = "v0.103.1" u.newVersion = "v0.103.1"
u.packageURL = fakeURL.String() u.packageURL = fakeURL.String()
assert.Nil(t, u.prepare()) require.NoError(t, u.prepare())
u.currentExeName = filepath.Join(wd, "AdGuardHome") u.currentExeName = filepath.Join(wd, "AdGuardHome")
assert.Nil(t, u.downloadPackageFile(u.packageURL, u.packageName))
assert.Nil(t, u.unpack()) require.NoError(t, u.downloadPackageFile(u.packageURL, u.packageName))
// assert.Nil(t, u.check()) require.NoError(t, u.unpack())
assert.Nil(t, u.backup())
assert.Nil(t, u.replace()) // require.NoError(t, u.check())
require.NoError(t, u.backup())
require.NoError(t, u.replace())
u.clean() u.clean()
// check backup files // check backup files
d, err := os.ReadFile(filepath.Join(wd, "agh-backup", "AdGuardHome.yaml")) d, err := os.ReadFile(filepath.Join(wd, "agh-backup", "AdGuardHome.yaml"))
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, "AdGuardHome.yaml", string(d)) assert.Equal(t, "AdGuardHome.yaml", string(d))
d, err = os.ReadFile(filepath.Join(wd, "agh-backup", "AdGuardHome")) d, err = os.ReadFile(filepath.Join(wd, "agh-backup", "AdGuardHome"))
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, "AdGuardHome", string(d)) assert.Equal(t, "AdGuardHome", string(d))
// check updated files // check updated files
d, err = os.ReadFile(filepath.Join(wd, "AdGuardHome")) d, err = os.ReadFile(filepath.Join(wd, "AdGuardHome"))
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, "1", string(d)) assert.Equal(t, "1", string(d))
d, err = os.ReadFile(filepath.Join(wd, "README.md")) d, err = os.ReadFile(filepath.Join(wd, "README.md"))
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, "2", string(d)) assert.Equal(t, "2", string(d))
d, err = os.ReadFile(filepath.Join(wd, "LICENSE.txt")) d, err = os.ReadFile(filepath.Join(wd, "LICENSE.txt"))
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, "3", string(d)) assert.Equal(t, "3", string(d))
d, err = os.ReadFile(filepath.Join(wd, "AdGuardHome.yaml")) d, err = os.ReadFile(filepath.Join(wd, "AdGuardHome.yaml"))
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, "AdGuardHome.yaml", string(d)) assert.Equal(t, "AdGuardHome.yaml", string(d))
} }
func TestUpdateWindows(t *testing.T) { func TestUpdateWindows(t *testing.T) {
wd := t.TempDir() wd := t.TempDir()
assert.Nil(t, os.WriteFile(filepath.Join(wd, "AdGuardHome.exe"), []byte("AdGuardHome.exe"), 0o755)) require.NoError(t, os.WriteFile(filepath.Join(wd, "AdGuardHome.exe"), []byte("AdGuardHome.exe"), 0o755))
assert.Nil(t, os.WriteFile(filepath.Join(wd, "README.md"), []byte("README.md"), 0o644)) require.NoError(t, os.WriteFile(filepath.Join(wd, "README.md"), []byte("README.md"), 0o644))
assert.Nil(t, os.WriteFile(filepath.Join(wd, "LICENSE.txt"), []byte("LICENSE.txt"), 0o644)) require.NoError(t, os.WriteFile(filepath.Join(wd, "LICENSE.txt"), []byte("LICENSE.txt"), 0o644))
assert.Nil(t, os.WriteFile(filepath.Join(wd, "AdGuardHome.yaml"), []byte("AdGuardHome.yaml"), 0o644)) require.NoError(t, os.WriteFile(filepath.Join(wd, "AdGuardHome.yaml"), []byte("AdGuardHome.yaml"), 0o644))
// start server for returning package file // start server for returning package file
pkgData, err := os.ReadFile("testdata/AdGuardHome.zip") pkgData, err := os.ReadFile("testdata/AdGuardHome.zip")
assert.Nil(t, err) require.NoError(t, err)
l, lport := startHTTPServer(string(pkgData)) l, lport := startHTTPServer(string(pkgData))
t.Cleanup(func() { assert.Nil(t, l.Close()) }) testutil.CleanupAndRequireSuccess(t, l.Close)
u := NewUpdater(&Config{ u := NewUpdater(&Config{
Client: &http.Client{}, Client: &http.Client{},
@ -197,39 +211,48 @@ func TestUpdateWindows(t *testing.T) {
u.newVersion = "v0.103.1" u.newVersion = "v0.103.1"
u.packageURL = fakeURL.String() u.packageURL = fakeURL.String()
assert.Nil(t, u.prepare()) require.NoError(t, u.prepare())
u.currentExeName = filepath.Join(wd, "AdGuardHome.exe") u.currentExeName = filepath.Join(wd, "AdGuardHome.exe")
assert.Nil(t, u.downloadPackageFile(u.packageURL, u.packageName))
assert.Nil(t, u.unpack()) require.NoError(t, u.downloadPackageFile(u.packageURL, u.packageName))
require.NoError(t, u.unpack())
// assert.Nil(t, u.check()) // assert.Nil(t, u.check())
assert.Nil(t, u.backup()) require.NoError(t, u.backup())
assert.Nil(t, u.replace()) require.NoError(t, u.replace())
u.clean() u.clean()
// check backup files // check backup files
d, err := os.ReadFile(filepath.Join(wd, "agh-backup", "AdGuardHome.yaml")) d, err := os.ReadFile(filepath.Join(wd, "agh-backup", "AdGuardHome.yaml"))
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, "AdGuardHome.yaml", string(d)) assert.Equal(t, "AdGuardHome.yaml", string(d))
d, err = os.ReadFile(filepath.Join(wd, "agh-backup", "AdGuardHome.exe")) d, err = os.ReadFile(filepath.Join(wd, "agh-backup", "AdGuardHome.exe"))
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, "AdGuardHome.exe", string(d)) assert.Equal(t, "AdGuardHome.exe", string(d))
// check updated files // check updated files
d, err = os.ReadFile(filepath.Join(wd, "AdGuardHome.exe")) d, err = os.ReadFile(filepath.Join(wd, "AdGuardHome.exe"))
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, "1", string(d)) assert.Equal(t, "1", string(d))
d, err = os.ReadFile(filepath.Join(wd, "README.md")) d, err = os.ReadFile(filepath.Join(wd, "README.md"))
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, "2", string(d)) assert.Equal(t, "2", string(d))
d, err = os.ReadFile(filepath.Join(wd, "LICENSE.txt")) d, err = os.ReadFile(filepath.Join(wd, "LICENSE.txt"))
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, "3", string(d)) assert.Equal(t, "3", string(d))
d, err = os.ReadFile(filepath.Join(wd, "AdGuardHome.yaml")) d, err = os.ReadFile(filepath.Join(wd, "AdGuardHome.yaml"))
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, "AdGuardHome.yaml", string(d)) assert.Equal(t, "AdGuardHome.yaml", string(d))
} }
@ -243,7 +266,7 @@ func TestUpdater_VersionInto_ARM(t *testing.T) {
}` }`
l, lport := startHTTPServer(jsonData) l, lport := startHTTPServer(jsonData)
t.Cleanup(func() { assert.Nil(t, l.Close()) }) testutil.CleanupAndRequireSuccess(t, l.Close)
u := NewUpdater(&Config{ u := NewUpdater(&Config{
Client: &http.Client{}, Client: &http.Client{},
@ -262,7 +285,8 @@ func TestUpdater_VersionInto_ARM(t *testing.T) {
u.versionCheckURL = fakeURL.String() u.versionCheckURL = fakeURL.String()
info, err := u.VersionInfo(false) info, err := u.VersionInfo(false)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, "v0.103.0-beta.2", info.NewVersion) assert.Equal(t, "v0.103.0-beta.2", info.NewVersion)
assert.Equal(t, "AdGuard Home v0.103.0-beta.2 is now available!", info.Announcement) assert.Equal(t, "AdGuard Home v0.103.0-beta.2 is now available!", info.Announcement)
assert.Equal(t, "https://github.com/AdguardTeam/AdGuardHome/internal/releases", info.AnnouncementURL) assert.Equal(t, "https://github.com/AdguardTeam/AdGuardHome/internal/releases", info.AnnouncementURL)
@ -282,7 +306,7 @@ func TestUpdater_VersionInto_MIPS(t *testing.T) {
}` }`
l, lport := startHTTPServer(jsonData) l, lport := startHTTPServer(jsonData)
t.Cleanup(func() { assert.Nil(t, l.Close()) }) testutil.CleanupAndRequireSuccess(t, l.Close)
u := NewUpdater(&Config{ u := NewUpdater(&Config{
Client: &http.Client{}, Client: &http.Client{},
@ -301,7 +325,8 @@ func TestUpdater_VersionInto_MIPS(t *testing.T) {
u.versionCheckURL = fakeURL.String() u.versionCheckURL = fakeURL.String()
info, err := u.VersionInfo(false) info, err := u.VersionInfo(false)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, "v0.103.0-beta.2", info.NewVersion) assert.Equal(t, "v0.103.0-beta.2", info.NewVersion)
assert.Equal(t, "AdGuard Home v0.103.0-beta.2 is now available!", info.Announcement) assert.Equal(t, "AdGuard Home v0.103.0-beta.2 is now available!", info.Announcement)
assert.Equal(t, "https://github.com/AdguardTeam/AdGuardHome/internal/releases", info.AnnouncementURL) assert.Equal(t, "https://github.com/AdguardTeam/AdGuardHome/internal/releases", info.AnnouncementURL)