Pull request: 2639 use testify require vol.4

Merge in DNS/adguard-home from 2639-testify-require-4 to master

Closes #2639.

Squashed commit of the following:

commit 0bb9125f42ab6d2511c1b8e481112aa5edd581d9
Merge: 0e9e9ed1 2c9992e0
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Thu Mar 11 15:47:21 2021 +0300

    Merge branch 'master' into 2639-testify-require-4

commit 0e9e9ed16ae13ce648b5e1da6ffd123df911c2d7
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Wed Mar 10 12:43:15 2021 +0300

    home: rm deletion error check

commit 6bfbbcd2b7f9197a06856f9e6b959c2e1c4b8353
Merge: c8ebe541 8811c881
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Wed Mar 10 12:30:07 2021 +0300

    Merge branch 'master' into 2639-testify-require-4

commit c8ebe54142bba780226f76ddb72e33664ed28f30
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Wed Mar 10 12:28:43 2021 +0300

    home: imp tests

commit f0e1db456f02df5f5f56ca93e7bd40a48475b38c
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Mar 5 14:06:41 2021 +0300

    dnsforward: imp tests

commit 4528246105ed06471a8778abbe8e5c30fc5483d5
Merge: 54b08d9c 90ebc4d8
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Thu Mar 4 18:17:52 2021 +0300

    Merge branch 'master' into 2639-testify-require-4

commit 54b08d9c980b8d69d019a1a1b3931aa048275691
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Thu Feb 11 13:17:05 2021 +0300

    dnsfilter: imp tests
This commit is contained in:
Eugene Burkov 2021-03-11 17:32:58 +03:00
parent 2c9992e0cc
commit dfdbfee4fd
19 changed files with 1375 additions and 1267 deletions

View File

@ -5,6 +5,7 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"strings"
"testing" "testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
@ -13,6 +14,7 @@ import (
"github.com/AdguardTeam/urlfilter/rules" "github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
@ -58,7 +60,7 @@ func (d *DNSFilter) checkMatch(t *testing.T, hostname string) {
t.Helper() t.Helper()
res, err := d.CheckHost(hostname, dns.TypeA, &setts) res, err := d.CheckHost(hostname, dns.TypeA, &setts)
assert.Nilf(t, err, "Error while matching host %s: %s", hostname, err) require.Nilf(t, err, "Error while matching host %s: %s", hostname, err)
assert.Truef(t, res.IsFiltered, "Expected hostname %s to match", hostname) assert.Truef(t, res.IsFiltered, "Expected hostname %s to match", hostname)
} }
@ -66,20 +68,20 @@ func (d *DNSFilter) checkMatchIP(t *testing.T, hostname, ip string, qtype uint16
t.Helper() t.Helper()
res, err := d.CheckHost(hostname, qtype, &setts) res, err := d.CheckHost(hostname, qtype, &setts)
assert.Nilf(t, err, "Error while matching host %s: %s", hostname, err) require.Nilf(t, err, "Error while matching host %s: %s", hostname, err)
assert.Truef(t, res.IsFiltered, "Expected hostname %s to match", hostname) assert.Truef(t, res.IsFiltered, "Expected hostname %s to match", hostname)
if assert.NotEmpty(t, res.Rules, "Expected result to have rules") {
r := res.Rules[0] require.NotEmpty(t, res.Rules, "Expected result to have rules")
assert.NotNilf(t, r.IP, "Expected ip %s to match, actual: %v", ip, r.IP) r := res.Rules[0]
assert.Equalf(t, ip, r.IP.String(), "Expected ip %s to match, actual: %v", ip, r.IP) require.NotNilf(t, r.IP, "Expected ip %s to match, actual: %v", ip, r.IP)
} assert.Equalf(t, ip, r.IP.String(), "Expected ip %s to match, actual: %v", ip, r.IP)
} }
func (d *DNSFilter) checkMatchEmpty(t *testing.T, hostname string) { func (d *DNSFilter) checkMatchEmpty(t *testing.T, hostname string) {
t.Helper() t.Helper()
res, err := d.CheckHost(hostname, dns.TypeA, &setts) res, err := d.CheckHost(hostname, dns.TypeA, &setts)
assert.Nilf(t, err, "Error while matching host %s: %s", hostname, err) require.Nilf(t, err, "Error while matching host %s: %s", hostname, err)
assert.Falsef(t, res.IsFiltered, "Expected hostname %s to not match", hostname) assert.Falsef(t, res.IsFiltered, "Expected hostname %s to not match", hostname)
} }
@ -110,40 +112,40 @@ func TestEtcHostsMatching(t *testing.T) {
// Empty IPv6. // Empty IPv6.
res, err := d.CheckHost("block.com", dns.TypeAAAA, &setts) res, err := d.CheckHost("block.com", dns.TypeAAAA, &setts)
assert.Nil(t, err) require.Nil(t, err)
assert.True(t, res.IsFiltered) assert.True(t, res.IsFiltered)
if assert.Len(t, res.Rules, 1) {
assert.Equal(t, "0.0.0.0 block.com", res.Rules[0].Text) require.Len(t, res.Rules, 1)
assert.Empty(t, res.Rules[0].IP) assert.Equal(t, "0.0.0.0 block.com", res.Rules[0].Text)
} assert.Empty(t, res.Rules[0].IP)
// IPv6 match. // IPv6 match.
d.checkMatchIP(t, "ipv6.com", addr6, dns.TypeAAAA) d.checkMatchIP(t, "ipv6.com", addr6, dns.TypeAAAA)
// Empty IPv4. // Empty IPv4.
res, err = d.CheckHost("ipv6.com", dns.TypeA, &setts) res, err = d.CheckHost("ipv6.com", dns.TypeA, &setts)
assert.Nil(t, err) require.Nil(t, err)
assert.True(t, res.IsFiltered) assert.True(t, res.IsFiltered)
if assert.Len(t, res.Rules, 1) {
assert.Equal(t, "::1 ipv6.com", res.Rules[0].Text) require.Len(t, res.Rules, 1)
assert.Empty(t, res.Rules[0].IP) assert.Equal(t, "::1 ipv6.com", res.Rules[0].Text)
} assert.Empty(t, res.Rules[0].IP)
// Two IPv4, the first one returned. // Two IPv4, the first one returned.
res, err = d.CheckHost("host2", dns.TypeA, &setts) res, err = d.CheckHost("host2", dns.TypeA, &setts)
assert.Nil(t, err) require.Nil(t, err)
assert.True(t, res.IsFiltered) assert.True(t, res.IsFiltered)
if assert.Len(t, res.Rules, 1) {
assert.Equal(t, res.Rules[0].IP, net.IP{0, 0, 0, 1}) require.Len(t, res.Rules, 1)
} assert.Equal(t, res.Rules[0].IP, net.IP{0, 0, 0, 1})
// One IPv6 address. // One IPv6 address.
res, err = d.CheckHost("host2", dns.TypeAAAA, &setts) res, err = d.CheckHost("host2", dns.TypeAAAA, &setts)
assert.Nil(t, err) require.Nil(t, err)
assert.True(t, res.IsFiltered) assert.True(t, res.IsFiltered)
if assert.Len(t, res.Rules, 1) {
assert.Equal(t, res.Rules[0].IP, net.IPv6loopback) require.Len(t, res.Rules, 1)
} assert.Equal(t, res.Rules[0].IP, net.IPv6loopback)
} }
// Safe Browsing. // Safe Browsing.
@ -155,14 +157,14 @@ func TestSafeBrowsing(t *testing.T) {
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil) d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
matching := "wmconvirus.narod.ru" const matching = "wmconvirus.narod.ru"
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{ d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
Hostname: matching, Hostname: matching,
Block: true, Block: true,
}) })
d.checkMatch(t, matching) d.checkMatch(t, matching)
assert.Contains(t, logOutput.String(), "SafeBrowsing lookup for "+matching) require.Contains(t, logOutput.String(), "SafeBrowsing lookup for "+matching)
d.checkMatch(t, "test."+matching) d.checkMatch(t, "test."+matching)
d.checkMatchEmpty(t, "yandex.ru") d.checkMatchEmpty(t, "yandex.ru")
@ -178,7 +180,7 @@ func TestSafeBrowsing(t *testing.T) {
func TestParallelSB(t *testing.T) { func TestParallelSB(t *testing.T) {
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil) d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
matching := "wmconvirus.narod.ru" const matching = "wmconvirus.narod.ru"
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{ d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
Hostname: matching, Hostname: matching,
Block: true, Block: true,
@ -203,7 +205,7 @@ func TestSafeSearch(t *testing.T) {
d := newForTest(&Config{SafeSearchEnabled: true}, nil) d := newForTest(&Config{SafeSearchEnabled: true}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
val, ok := d.SafeSearchDomain("www.google.com") val, ok := d.SafeSearchDomain("www.google.com")
assert.True(t, ok, "Expected safesearch to find result for www.google.com") require.True(t, ok, "Expected safesearch to find result for www.google.com")
assert.Equal(t, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com") assert.Equal(t, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com")
} }
@ -211,6 +213,8 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
d := newForTest(&Config{SafeSearchEnabled: true}, nil) d := newForTest(&Config{SafeSearchEnabled: true}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
yandexIP := net.IPv4(213, 180, 193, 56)
// Check host for each domain. // Check host for each domain.
for _, host := range []string{ for _, host := range []string{
"yAndeX.ru", "yAndeX.ru",
@ -220,22 +224,27 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
"yandex.kz", "yandex.kz",
"www.yandex.com", "www.yandex.com",
} { } {
res, err := d.CheckHost(host, dns.TypeA, &setts) t.Run(strings.ToLower(host), func(t *testing.T) {
assert.Nil(t, err) res, err := d.CheckHost(host, dns.TypeA, &setts)
assert.True(t, res.IsFiltered) require.Nil(t, err)
if assert.Len(t, res.Rules, 1) { assert.True(t, res.IsFiltered)
assert.Equal(t, res.Rules[0].IP, net.IPv4(213, 180, 193, 56))
} require.Len(t, res.Rules, 1)
assert.Equal(t, yandexIP, res.Rules[0].IP)
})
} }
} }
func TestCheckHostSafeSearchGoogle(t *testing.T) { func TestCheckHostSafeSearchGoogle(t *testing.T) {
resolver := &aghtest.TestResolver{}
d := newForTest(&Config{ d := newForTest(&Config{
SafeSearchEnabled: true, SafeSearchEnabled: true,
CustomResolver: &aghtest.TestResolver{}, CustomResolver: resolver,
}, nil) }, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
ip, _ := resolver.HostToIPs("forcesafesearch.google.com")
// Check host for each domain. // Check host for each domain.
for _, host := range []string{ for _, host := range []string{
"www.google.com", "www.google.com",
@ -248,11 +257,10 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
} { } {
t.Run(host, func(t *testing.T) { t.Run(host, func(t *testing.T) {
res, err := d.CheckHost(host, dns.TypeA, &setts) res, err := d.CheckHost(host, dns.TypeA, &setts)
assert.Nil(t, err) require.Nil(t, err)
assert.True(t, res.IsFiltered) assert.True(t, res.IsFiltered)
if assert.Len(t, res.Rules, 1) { require.Len(t, res.Rules, 1)
assert.NotEqual(t, res.Rules[0].IP.String(), "0.0.0.0") assert.Equal(t, ip, res.Rules[0].IP)
}
}) })
} }
} }
@ -260,31 +268,31 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
func TestSafeSearchCacheYandex(t *testing.T) { func TestSafeSearchCacheYandex(t *testing.T) {
d := newForTest(nil, nil) d := newForTest(nil, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
domain := "yandex.ru" const domain = "yandex.ru"
// Check host with disabled safesearch. // Check host with disabled safesearch.
res, err := d.CheckHost(domain, dns.TypeA, &setts) res, err := d.CheckHost(domain, dns.TypeA, &setts)
assert.Nil(t, err) require.Nil(t, err)
assert.False(t, res.IsFiltered) assert.False(t, res.IsFiltered)
assert.Empty(t, res.Rules) require.Empty(t, res.Rules)
yandexIP := net.IPv4(213, 180, 193, 56)
d = newForTest(&Config{SafeSearchEnabled: true}, nil) d = newForTest(&Config{SafeSearchEnabled: true}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
res, err = d.CheckHost(domain, dns.TypeA, &setts) res, err = d.CheckHost(domain, dns.TypeA, &setts)
assert.Nilf(t, err, "CheckHost for safesearh domain %s failed cause %s", domain, err) require.Nilf(t, err, "CheckHost for safesearh domain %s failed cause %s", domain, err)
// For yandex we already know valid IP. // For yandex we already know valid IP.
if assert.Len(t, res.Rules, 1) { require.Len(t, res.Rules, 1)
assert.Equal(t, res.Rules[0].IP, net.IPv4(213, 180, 193, 56)) assert.Equal(t, res.Rules[0].IP, yandexIP)
}
// Check cache. // Check cache.
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, domain) cachedValue, isFound := getCachedResult(gctx.safeSearchCache, domain)
assert.True(t, isFound) require.True(t, isFound)
if assert.Len(t, cachedValue.Rules, 1) { require.Len(t, cachedValue.Rules, 1)
assert.Equal(t, cachedValue.Rules[0].IP, net.IPv4(213, 180, 193, 56)) assert.Equal(t, cachedValue.Rules[0].IP, yandexIP)
}
} }
func TestSafeSearchCacheGoogle(t *testing.T) { func TestSafeSearchCacheGoogle(t *testing.T) {
@ -294,11 +302,11 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
}, nil) }, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
domain := "www.google.ru" const domain = "www.google.ru"
res, err := d.CheckHost(domain, dns.TypeA, &setts) res, err := d.CheckHost(domain, dns.TypeA, &setts)
assert.Nil(t, err) require.Nil(t, err)
assert.False(t, res.IsFiltered) assert.False(t, res.IsFiltered)
assert.Empty(t, res.Rules) require.Empty(t, res.Rules)
d = newForTest(&Config{SafeSearchEnabled: true}, nil) d = newForTest(&Config{SafeSearchEnabled: true}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
@ -306,12 +314,10 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
// Lookup for safesearch domain. // Lookup for safesearch domain.
safeDomain, ok := d.SafeSearchDomain(domain) safeDomain, ok := d.SafeSearchDomain(domain)
assert.Truef(t, ok, "Failed to get safesearch domain for %s", domain) require.Truef(t, ok, "Failed to get safesearch domain for %s", domain)
ips, err := resolver.LookupIP(context.Background(), "ip", safeDomain) ips, err := resolver.LookupIP(context.Background(), "ip", safeDomain)
if err != nil { require.Nilf(t, err, "Failed to lookup for %s", safeDomain)
t.Fatalf("Failed to lookup for %s", safeDomain)
}
var ip net.IP var ip net.IP
for _, foundIP := range ips { for _, foundIP := range ips {
@ -323,17 +329,15 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
} }
res, err = d.CheckHost(domain, dns.TypeA, &setts) res, err = d.CheckHost(domain, dns.TypeA, &setts)
assert.Nil(t, err) require.Nil(t, err)
if assert.Len(t, res.Rules, 1) { require.Len(t, res.Rules, 1)
assert.True(t, res.Rules[0].IP.Equal(ip)) assert.True(t, res.Rules[0].IP.Equal(ip))
}
// Check cache. // Check cache.
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, domain) cachedValue, isFound := getCachedResult(gctx.safeSearchCache, domain)
assert.True(t, isFound) require.True(t, isFound)
if assert.Len(t, cachedValue.Rules, 1) { require.Len(t, cachedValue.Rules, 1)
assert.True(t, cachedValue.Rules[0].IP.Equal(ip)) assert.True(t, cachedValue.Rules[0].IP.Equal(ip))
}
} }
// Parental. // Parental.
@ -345,24 +349,23 @@ func TestParentalControl(t *testing.T) {
d := newForTest(&Config{ParentalEnabled: true}, nil) d := newForTest(&Config{ParentalEnabled: true}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
matching := "pornhub.com" const matching = "pornhub.com"
d.SetParentalUpstream(&aghtest.TestBlockUpstream{ d.SetParentalUpstream(&aghtest.TestBlockUpstream{
Hostname: matching, Hostname: matching,
Block: true, Block: true,
}) })
d.checkMatch(t, matching) d.checkMatch(t, matching)
assert.Contains(t, logOutput.String(), "Parental lookup for "+matching) require.Contains(t, logOutput.String(), "Parental lookup for "+matching)
d.checkMatch(t, "www."+matching) d.checkMatch(t, "www."+matching)
d.checkMatchEmpty(t, "www.yandex.ru") d.checkMatchEmpty(t, "www.yandex.ru")
d.checkMatchEmpty(t, "yandex.ru") d.checkMatchEmpty(t, "yandex.ru")
d.checkMatchEmpty(t, "api.jquery.com") d.checkMatchEmpty(t, "api.jquery.com")
// test cached result // Test cached result.
d.parentalServer = "127.0.0.1" d.parentalServer = "127.0.0.1"
d.checkMatch(t, matching) d.checkMatch(t, matching)
d.checkMatchEmpty(t, "yandex.ru") d.checkMatchEmpty(t, "yandex.ru")
d.parentalServer = defaultParentalServer
} }
// Filtering. // Filtering.
@ -651,7 +654,7 @@ func TestMatching(t *testing.T) {
t.Cleanup(d.Close) t.Cleanup(d.Close)
res, err := d.CheckHost(tc.host, tc.wantDNSType, &setts) res, err := d.CheckHost(tc.host, tc.wantDNSType, &setts)
assert.Nilf(t, err, "Error while matching host %s: %s", tc.host, err) require.Nilf(t, err, "Error while matching host %s: %s", tc.host, err)
assert.Equalf(t, tc.wantIsFiltered, res.IsFiltered, "Hostname %s has wrong result (%v must be %v)", tc.host, res.IsFiltered, tc.wantIsFiltered) assert.Equalf(t, tc.wantIsFiltered, res.IsFiltered, "Hostname %s has wrong result (%v must be %v)", tc.host, res.IsFiltered, tc.wantIsFiltered)
assert.Equalf(t, tc.wantReason, res.Reason, "Hostname %s has wrong reason (%v must be %v)", tc.host, res.Reason, tc.wantReason) assert.Equalf(t, tc.wantReason, res.Reason, "Hostname %s has wrong reason (%v must be %v)", tc.host, res.Reason, tc.wantReason)
}) })
@ -674,28 +677,24 @@ func TestWhitelist(t *testing.T) {
}} }}
d := newForTest(nil, filters) d := newForTest(nil, filters)
err := d.SetFilters(filters, whiteFilters, false) require.Nil(t, d.SetFilters(filters, whiteFilters, false))
assert.Nil(t, err)
t.Cleanup(d.Close) t.Cleanup(d.Close)
// Matched by white filter. // Matched by white filter.
res, err := d.CheckHost("host1", dns.TypeA, &setts) res, err := d.CheckHost("host1", dns.TypeA, &setts)
assert.Nil(t, err) require.Nil(t, err)
assert.False(t, res.IsFiltered) assert.False(t, res.IsFiltered)
assert.Equal(t, res.Reason, NotFilteredAllowList) assert.Equal(t, res.Reason, NotFilteredAllowList)
if assert.Len(t, res.Rules, 1) { require.Len(t, res.Rules, 1)
assert.Equal(t, "||host1^", res.Rules[0].Text) assert.Equal(t, "||host1^", res.Rules[0].Text)
}
// Not matched by white filter, but matched by block filter. // Not matched by white filter, but matched by block filter.
res, err = d.CheckHost("host2", dns.TypeA, &setts) res, err = d.CheckHost("host2", dns.TypeA, &setts)
assert.Nil(t, err) require.Nil(t, err)
assert.True(t, res.IsFiltered) assert.True(t, res.IsFiltered)
assert.Equal(t, res.Reason, FilteredBlockList) assert.Equal(t, res.Reason, FilteredBlockList)
if assert.Len(t, res.Rules, 1) { require.Len(t, res.Rules, 1)
assert.Equal(t, "||host2^", res.Rules[0].Text) assert.Equal(t, "||host2^", res.Rules[0].Text)
}
} }
// Client Settings. // Client Settings.
@ -797,7 +796,7 @@ func BenchmarkSafeBrowsing(b *testing.B) {
}) })
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
res, err := d.CheckHost(blocked, dns.TypeA, &setts) res, err := d.CheckHost(blocked, dns.TypeA, &setts)
assert.Nilf(b, err, "Error while matching host %s: %s", blocked, err) require.Nilf(b, err, "Error while matching host %s: %s", blocked, err)
assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked) assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked)
} }
} }
@ -813,7 +812,7 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) {
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
res, err := d.CheckHost(blocked, dns.TypeA, &setts) res, err := d.CheckHost(blocked, dns.TypeA, &setts)
assert.Nilf(b, err, "Error while matching host %s: %s", blocked, err) require.Nilf(b, err, "Error while matching host %s: %s", blocked, err)
assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked) assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked)
} }
}) })
@ -824,7 +823,7 @@ func BenchmarkSafeSearch(b *testing.B) {
b.Cleanup(d.Close) b.Cleanup(d.Close)
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
val, ok := d.SafeSearchDomain("www.google.com") val, ok := d.SafeSearchDomain("www.google.com")
assert.True(b, ok, "Expected safesearch to find result for www.google.com") require.True(b, ok, "Expected safesearch to find result for www.google.com")
assert.Equal(b, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com") assert.Equal(b, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com")
} }
} }
@ -835,7 +834,7 @@ func BenchmarkSafeSearchParallel(b *testing.B) {
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
val, ok := d.SafeSearchDomain("www.google.com") val, ok := d.SafeSearchDomain("www.google.com")
assert.True(b, ok, "Expected safesearch to find result for www.google.com") require.True(b, ok, "Expected safesearch to find result for www.google.com")
assert.Equal(b, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com") assert.Equal(b, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com")
} }
}) })

View File

@ -7,6 +7,7 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) { func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
@ -55,138 +56,89 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
ipv6p1 := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} ipv6p1 := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
ipv6p2 := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2} ipv6p2 := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}
testCasesA := []struct {
name string
dtyp uint16
rcode int
want []interface{}
}{{
name: "a-record",
dtyp: dns.TypeA,
rcode: dns.RcodeSuccess,
want: []interface{}{ipv4p1},
}, {
name: "aaaa-record",
dtyp: dns.TypeAAAA,
rcode: dns.RcodeSuccess,
want: []interface{}{ipv6p1},
}, {
name: "txt-record",
dtyp: dns.TypeTXT,
rcode: dns.RcodeSuccess,
want: []interface{}{"hello-world"},
}, {
name: "refused",
rcode: dns.RcodeRefused,
}, {
name: "a-records",
dtyp: dns.TypeA,
rcode: dns.RcodeSuccess,
want: []interface{}{ipv4p1, ipv4p2},
}, {
name: "aaaa-records",
dtyp: dns.TypeAAAA,
rcode: dns.RcodeSuccess,
want: []interface{}{ipv6p1, ipv6p2},
}, {
name: "disable-one",
dtyp: dns.TypeA,
rcode: dns.RcodeSuccess,
want: []interface{}{ipv4p2},
}, {
name: "disable-cname",
dtyp: dns.TypeA,
rcode: dns.RcodeSuccess,
want: []interface{}{ipv4p1},
}}
for _, tc := range testCasesA {
t.Run(tc.name, func(t *testing.T) {
host := path.Base(tc.name)
res, err := f.CheckHostRules(host, tc.dtyp, setts)
require.Nil(t, err)
dnsrr := res.DNSRewriteResult
require.NotNil(t, dnsrr)
assert.Equal(t, tc.rcode, dnsrr.RCode)
if tc.rcode == dns.RcodeRefused {
return
}
ipVals := dnsrr.Response[tc.dtyp]
require.Len(t, ipVals, len(tc.want))
for i, val := range tc.want {
require.Equal(t, val, ipVals[i])
}
})
}
t.Run("cname", func(t *testing.T) { t.Run("cname", func(t *testing.T) {
dtyp := dns.TypeA dtyp := dns.TypeA
host := path.Base(t.Name()) host := path.Base(t.Name())
res, err := f.CheckHostRules(host, dtyp, setts) res, err := f.CheckHostRules(host, dtyp, setts)
assert.Nil(t, err) require.Nil(t, err)
assert.Equal(t, "new-cname", res.CanonName) assert.Equal(t, "new-cname", res.CanonName)
}) })
t.Run("a-record", func(t *testing.T) {
dtyp := dns.TypeA
host := path.Base(t.Name())
res, err := f.CheckHostRules(host, dtyp, setts)
assert.Nil(t, err)
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
if ipVals := dnsrr.Response[dtyp]; assert.Len(t, ipVals, 1) {
assert.Equal(t, ipv4p1, ipVals[0])
}
}
})
t.Run("aaaa-record", func(t *testing.T) {
dtyp := dns.TypeAAAA
host := path.Base(t.Name())
res, err := f.CheckHostRules(host, dtyp, setts)
assert.Nil(t, err)
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
if ipVals := dnsrr.Response[dtyp]; assert.Len(t, ipVals, 1) {
assert.Equal(t, ipv6p1, ipVals[0])
}
}
})
t.Run("txt-record", func(t *testing.T) {
dtyp := dns.TypeTXT
host := path.Base(t.Name())
res, err := f.CheckHostRules(host, dtyp, setts)
assert.Nil(t, err)
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
if strVals := dnsrr.Response[dtyp]; assert.Len(t, strVals, 1) {
assert.Equal(t, "hello-world", strVals[0])
}
}
})
t.Run("refused", func(t *testing.T) {
host := path.Base(t.Name())
res, err := f.CheckHostRules(host, dns.TypeA, setts)
assert.Nil(t, err)
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
assert.Equal(t, dns.RcodeRefused, dnsrr.RCode)
}
})
t.Run("a-records", func(t *testing.T) {
dtyp := dns.TypeA
host := path.Base(t.Name())
res, err := f.CheckHostRules(host, dtyp, setts)
assert.Nil(t, err)
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
if ipVals := dnsrr.Response[dtyp]; assert.Len(t, ipVals, 2) {
assert.Equal(t, ipv4p1, ipVals[0])
assert.Equal(t, ipv4p2, ipVals[1])
}
}
})
t.Run("aaaa-records", func(t *testing.T) {
dtyp := dns.TypeAAAA
host := path.Base(t.Name())
res, err := f.CheckHostRules(host, dtyp, setts)
assert.Nil(t, err)
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
if ipVals := dnsrr.Response[dtyp]; assert.Len(t, ipVals, 2) {
assert.Equal(t, ipv6p1, ipVals[0])
assert.Equal(t, ipv6p2, ipVals[1])
}
}
})
t.Run("disable-one", func(t *testing.T) {
dtyp := dns.TypeA
host := path.Base(t.Name())
res, err := f.CheckHostRules(host, dtyp, setts)
assert.Nil(t, err)
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
if ipVals := dnsrr.Response[dtyp]; assert.Len(t, ipVals, 1) {
assert.Equal(t, ipv4p2, ipVals[0])
}
}
})
t.Run("disable-cname", func(t *testing.T) {
dtyp := dns.TypeA
host := path.Base(t.Name())
res, err := f.CheckHostRules(host, dtyp, setts)
assert.Nil(t, err)
assert.Empty(t, res.CanonName)
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
if ipVals := dnsrr.Response[dtyp]; assert.Len(t, ipVals, 1) {
assert.Equal(t, ipv4p1, ipVals[0])
}
}
})
t.Run("disable-cname-many", func(t *testing.T) { t.Run("disable-cname-many", func(t *testing.T) {
dtyp := dns.TypeA dtyp := dns.TypeA
host := path.Base(t.Name()) host := path.Base(t.Name())
res, err := f.CheckHostRules(host, dtyp, setts) res, err := f.CheckHostRules(host, dtyp, setts)
assert.Nil(t, err) require.Nil(t, err)
assert.Equal(t, "new-cname-2", res.CanonName) assert.Equal(t, "new-cname-2", res.CanonName)
assert.Nil(t, res.DNSRewriteResult) assert.Nil(t, res.DNSRewriteResult)
}) })
@ -196,7 +148,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
host := path.Base(t.Name()) host := path.Base(t.Name())
res, err := f.CheckHostRules(host, dtyp, setts) res, err := f.CheckHostRules(host, dtyp, setts)
assert.Nil(t, err) require.Nil(t, err)
assert.Empty(t, res.CanonName) assert.Empty(t, res.CanonName)
assert.Empty(t, res.Rules) assert.Empty(t, res.Rules)
}) })

View File

@ -6,215 +6,297 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
// TODO(e.burkov): All the tests in this file may and should me merged together.
func TestRewrites(t *testing.T) { func TestRewrites(t *testing.T) {
d := newForTest(nil, nil) d := newForTest(nil, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
// CNAME, A, AAAA
d.Rewrites = []RewriteEntry{
{"somecname", "somehost.com", 0, nil},
{"somehost.com", "0.0.0.0", 0, nil},
{"host.com", "1.2.3.4", 0, nil}, d.Rewrites = []RewriteEntry{{
{"host.com", "1.2.3.5", 0, nil}, // This one and below are about CNAME, A and AAAA.
{"host.com", "1:2:3::4", 0, nil}, Domain: "somecname",
{"www.host.com", "host.com", 0, nil}, Answer: "somehost.com",
} }, {
Domain: "somehost.com",
Answer: "0.0.0.0",
}, {
Domain: "host.com",
Answer: "1.2.3.4",
}, {
Domain: "host.com",
Answer: "1.2.3.5",
}, {
Domain: "host.com",
Answer: "1:2:3::4",
}, {
Domain: "www.host.com",
Answer: "host.com",
}, {
// This one is a wildcard.
Domain: "*.host.com",
Answer: "1.2.3.5",
}, {
// This one and below are about wildcard overriding.
Domain: "a.host.com",
Answer: "1.2.3.4",
}, {
// This one is about CNAME and wildcard interacting.
Domain: "*.host2.com",
Answer: "host.com",
}, {
// This one and below are about 2 level CNAME.
Domain: "b.host.com",
Answer: "somecname",
}, {
// This one and below are about 2 level CNAME and wildcard.
Domain: "b.host3.com",
Answer: "a.host3.com",
}, {
Domain: "a.host3.com",
Answer: "x.host.com",
}}
d.prepareRewrites() d.prepareRewrites()
r := d.processRewrites("host2.com", dns.TypeA)
assert.Equal(t, NotFilteredNotFound, r.Reason)
r = d.processRewrites("www.host.com", dns.TypeA) testCases := []struct {
assert.Equal(t, Rewritten, r.Reason) name string
assert.Equal(t, "host.com", r.CanonName) host string
assert.Len(t, r.IPList, 2) dtyp uint16
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4})) wantCName string
assert.True(t, r.IPList[1].Equal(net.IP{1, 2, 3, 5})) wantVals []net.IP
}{{
name: "not_filtered_not_found",
host: "hoost.com",
dtyp: dns.TypeA,
}, {
name: "rewritten_a",
host: "www.host.com",
dtyp: dns.TypeA,
wantCName: "host.com",
wantVals: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
}, {
name: "rewritten_aaaa",
host: "www.host.com",
dtyp: dns.TypeAAAA,
wantCName: "host.com",
wantVals: []net.IP{net.ParseIP("1:2:3::4")},
}, {
name: "wildcard_match",
host: "abc.host.com",
dtyp: dns.TypeA,
wantVals: []net.IP{{1, 2, 3, 5}},
}, {
name: "wildcard_override",
host: "a.host.com",
dtyp: dns.TypeA,
wantVals: []net.IP{{1, 2, 3, 4}},
}, {
name: "wildcard_cname_interaction",
host: "www.host2.com",
dtyp: dns.TypeA,
wantCName: "host.com",
wantVals: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
}, {
name: "two_cnames",
host: "b.host.com",
dtyp: dns.TypeA,
wantCName: "somehost.com",
wantVals: []net.IP{{0, 0, 0, 0}},
}, {
name: "two_cnames_and_wildcard",
host: "b.host3.com",
dtyp: dns.TypeA,
wantCName: "x.host.com",
wantVals: []net.IP{{1, 2, 3, 5}},
}}
r = d.processRewrites("www.host.com", dns.TypeAAAA) for _, tc := range testCases {
assert.Equal(t, Rewritten, r.Reason) t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, "host.com", r.CanonName) valsNum := len(tc.wantVals)
assert.Len(t, r.IPList, 1)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1:2:3::4")))
// wildcard r := d.processRewrites(tc.host, tc.dtyp)
d.Rewrites = []RewriteEntry{ if valsNum == 0 {
{"host.com", "1.2.3.4", 0, nil}, assert.Equal(t, NotFilteredNotFound, r.Reason)
{"*.host.com", "1.2.3.5", 0, nil},
return
}
require.Equal(t, Rewritten, r.Reason)
if tc.wantCName != "" {
assert.Equal(t, tc.wantCName, r.CanonName)
}
require.Len(t, r.IPList, valsNum)
for i, ip := range tc.wantVals {
assert.Equal(t, ip, r.IPList[i])
}
})
} }
d.prepareRewrites()
r = d.processRewrites("host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
r = d.processRewrites("www.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 5}))
r = d.processRewrites("www.host2.com", dns.TypeA)
assert.Equal(t, NotFilteredNotFound, r.Reason)
// override a wildcard
d.Rewrites = []RewriteEntry{
{"a.host.com", "1.2.3.4", 0, nil},
{"*.host.com", "1.2.3.5", 0, nil},
}
d.prepareRewrites()
r = d.processRewrites("a.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Len(t, r.IPList, 1)
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
// wildcard + CNAME
d.Rewrites = []RewriteEntry{
{"host.com", "1.2.3.4", 0, nil},
{"*.host.com", "host.com", 0, nil},
}
d.prepareRewrites()
r = d.processRewrites("www.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, "host.com", r.CanonName)
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
// 2 CNAMEs
d.Rewrites = []RewriteEntry{
{"b.host.com", "a.host.com", 0, nil},
{"a.host.com", "host.com", 0, nil},
{"host.com", "1.2.3.4", 0, nil},
}
d.prepareRewrites()
r = d.processRewrites("b.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, "host.com", r.CanonName)
assert.Len(t, r.IPList, 1)
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
// 2 CNAMEs + wildcard
d.Rewrites = []RewriteEntry{
{"b.host.com", "a.host.com", 0, nil},
{"a.host.com", "x.somehost.com", 0, nil},
{"*.somehost.com", "1.2.3.4", 0, nil},
}
d.prepareRewrites()
r = d.processRewrites("b.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, "x.somehost.com", r.CanonName)
assert.Len(t, r.IPList, 1)
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
} }
func TestRewritesLevels(t *testing.T) { func TestRewritesLevels(t *testing.T) {
d := newForTest(nil, nil) d := newForTest(nil, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
// exact host, wildcard L2, wildcard L3 // Exact host, wildcard L2, wildcard L3.
d.Rewrites = []RewriteEntry{ d.Rewrites = []RewriteEntry{{
{"host.com", "1.1.1.1", 0, nil}, Domain: "host.com",
{"*.host.com", "2.2.2.2", 0, nil}, Answer: "1.1.1.1",
{"*.sub.host.com", "3.3.3.3", 0, nil}, }, {
} Domain: "*.host.com",
Answer: "2.2.2.2",
}, {
Domain: "*.sub.host.com",
Answer: "3.3.3.3",
}}
d.prepareRewrites() d.prepareRewrites()
// match exact testCases := []struct {
r := d.processRewrites("host.com", dns.TypeA) name string
assert.Equal(t, Rewritten, r.Reason) host string
assert.Len(t, r.IPList, 1) want net.IP
assert.True(t, net.IP{1, 1, 1, 1}.Equal(r.IPList[0])) }{{
name: "exact_match",
host: "host.com",
want: net.IP{1, 1, 1, 1},
}, {
name: "l2_match",
host: "sub.host.com",
want: net.IP{2, 2, 2, 2},
}, {
name: "l3_match",
host: "my.sub.host.com",
want: net.IP{3, 3, 3, 3},
}}
// match L2 for _, tc := range testCases {
r = d.processRewrites("sub.host.com", dns.TypeA) t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, Rewritten, r.Reason) r := d.processRewrites(tc.host, dns.TypeA)
assert.Len(t, r.IPList, 1) assert.Equal(t, Rewritten, r.Reason)
assert.True(t, net.IP{2, 2, 2, 2}.Equal(r.IPList[0])) require.Len(t, r.IPList, 1)
})
// match L3 }
r = d.processRewrites("my.sub.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Len(t, r.IPList, 1)
assert.True(t, net.IP{3, 3, 3, 3}.Equal(r.IPList[0]))
} }
func TestRewritesExceptionCNAME(t *testing.T) { func TestRewritesExceptionCNAME(t *testing.T) {
d := newForTest(nil, nil) d := newForTest(nil, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
// wildcard; exception for a sub-domain // Wildcard and exception for a sub-domain.
d.Rewrites = []RewriteEntry{ d.Rewrites = []RewriteEntry{{
{"*.host.com", "2.2.2.2", 0, nil}, Domain: "*.host.com",
{"sub.host.com", "sub.host.com", 0, nil}, Answer: "2.2.2.2",
} }, {
Domain: "sub.host.com",
Answer: "sub.host.com",
}, {
Domain: "*.sub.host.com",
Answer: "*.sub.host.com",
}}
d.prepareRewrites() d.prepareRewrites()
// match sub-domain testCases := []struct {
r := d.processRewrites("my.host.com", dns.TypeA) name string
assert.Equal(t, Rewritten, r.Reason) host string
assert.Len(t, r.IPList, 1) want net.IP
assert.True(t, net.IP{2, 2, 2, 2}.Equal(r.IPList[0])) }{{
name: "match_sub-domain",
host: "my.host.com",
want: net.IP{2, 2, 2, 2},
}, {
name: "exception_cname",
host: "sub.host.com",
}, {
name: "exception_wildcard",
host: "my.sub.host.com",
}}
// match sub-domain, but handle exception for _, tc := range testCases {
r = d.processRewrites("sub.host.com", dns.TypeA) t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, NotFilteredNotFound, r.Reason) r := d.processRewrites(tc.host, dns.TypeA)
} if tc.want == nil {
assert.Equal(t, NotFilteredNotFound, r.Reason)
func TestRewritesExceptionWC(t *testing.T) { return
d := newForTest(nil, nil) }
t.Cleanup(d.Close)
// wildcard; exception for a sub-wildcard assert.Equal(t, Rewritten, r.Reason)
d.Rewrites = []RewriteEntry{ require.Len(t, r.IPList, 1)
{"*.host.com", "2.2.2.2", 0, nil}, assert.True(t, tc.want.Equal(r.IPList[0]))
{"*.sub.host.com", "*.sub.host.com", 0, nil}, })
} }
d.prepareRewrites()
// match sub-domain
r := d.processRewrites("my.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Len(t, r.IPList, 1)
assert.True(t, net.IP{2, 2, 2, 2}.Equal(r.IPList[0]))
// match sub-domain, but handle exception
r = d.processRewrites("my.sub.host.com", dns.TypeA)
assert.Equal(t, NotFilteredNotFound, r.Reason)
} }
func TestRewritesExceptionIP(t *testing.T) { func TestRewritesExceptionIP(t *testing.T) {
d := newForTest(nil, nil) d := newForTest(nil, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
// exception for AAAA record // Exception for AAAA record.
d.Rewrites = []RewriteEntry{ d.Rewrites = []RewriteEntry{{
{"host.com", "1.2.3.4", 0, nil}, Domain: "host.com",
{"host.com", "AAAA", 0, nil}, Answer: "1.2.3.4",
{"host2.com", "::1", 0, nil}, }, {
{"host2.com", "A", 0, nil}, Domain: "host.com",
{"host3.com", "A", 0, nil}, Answer: "AAAA",
} }, {
Domain: "host2.com",
Answer: "::1",
}, {
Domain: "host2.com",
Answer: "A",
}, {
Domain: "host3.com",
Answer: "A",
}}
d.prepareRewrites() d.prepareRewrites()
// match domain testCases := []struct {
r := d.processRewrites("host.com", dns.TypeA) name string
assert.Equal(t, Rewritten, r.Reason) host string
assert.Len(t, r.IPList, 1) dtyp uint16
assert.True(t, net.IP{1, 2, 3, 4}.Equal(r.IPList[0])) want []net.IP
}{{
name: "match_A",
host: "host.com",
dtyp: dns.TypeA,
want: []net.IP{{1, 2, 3, 4}},
}, {
name: "exception_AAAA_host.com",
host: "host.com",
dtyp: dns.TypeAAAA,
}, {
name: "exception_A_host2.com",
host: "host2.com",
dtyp: dns.TypeA,
}, {
name: "match_AAAA_host2.com",
host: "host2.com",
dtyp: dns.TypeAAAA,
want: []net.IP{net.ParseIP("::1")},
}, {
name: "exception_A_host3.com",
host: "host3.com",
dtyp: dns.TypeA,
}, {
name: "match_AAAA_host3.com",
host: "host3.com",
dtyp: dns.TypeAAAA,
want: []net.IP{},
}}
// match exception for _, tc := range testCases {
r = d.processRewrites("host.com", dns.TypeAAAA) t.Run(tc.name+"_"+tc.host, func(t *testing.T) {
assert.Equal(t, NotFilteredNotFound, r.Reason) r := d.processRewrites(tc.host, tc.dtyp)
if tc.want == nil {
assert.Equal(t, NotFilteredNotFound, r.Reason)
// match exception return
r = d.processRewrites("host2.com", dns.TypeA) }
assert.Equal(t, NotFilteredNotFound, r.Reason)
// match domain assert.Equal(t, Rewritten, r.Reason)
r = d.processRewrites("host2.com", dns.TypeAAAA) require.Len(t, r.IPList, len(tc.want))
assert.Equal(t, Rewritten, r.Reason) for _, ip := range tc.want {
assert.Len(t, r.IPList, 1) assert.True(t, ip.Equal(r.IPList[0]))
assert.Equal(t, "::1", r.IPList[0].String()) }
})
// match exception }
r = d.processRewrites("host3.com", dns.TypeA)
assert.Equal(t, NotFilteredNotFound, r.Reason)
// match domain
r = d.processRewrites("host3.com", dns.TypeAAAA)
assert.Equal(t, Rewritten, r.Reason)
assert.Empty(t, r.IPList)
} }

View File

@ -8,6 +8,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/cache"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestSafeBrowsingHash(t *testing.T) { func TestSafeBrowsingHash(t *testing.T) {
@ -155,25 +156,25 @@ func TestSBPC(t *testing.T) {
}} }}
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { // Prepare the upstream.
// Prepare the upstream. ups := &aghtest.TestBlockUpstream{
ups := &aghtest.TestBlockUpstream{ Hostname: hostname,
Hostname: hostname, Block: tc.block,
Block: tc.block, }
} d.SetSafeBrowsingUpstream(ups)
d.SetSafeBrowsingUpstream(ups) d.SetParentalUpstream(ups)
d.SetParentalUpstream(ups)
t.Run(tc.name, func(t *testing.T) {
// Firstly, check the request blocking. // Firstly, check the request blocking.
hits := 0 hits := 0
res, err := tc.testFunc(hostname) res, err := tc.testFunc(hostname)
assert.Nil(t, err) require.Nil(t, err)
if tc.block { if tc.block {
assert.True(t, res.IsFiltered) assert.True(t, res.IsFiltered)
assert.Len(t, res.Rules, 1) require.Len(t, res.Rules, 1)
hits++ hits++
} else { } else {
assert.False(t, res.IsFiltered) require.False(t, res.IsFiltered)
} }
// Check the cache state, check the response is now cached. // Check the cache state, check the response is now cached.
@ -185,12 +186,12 @@ func TestSBPC(t *testing.T) {
// Now make the same request to check the cache was used. // Now make the same request to check the cache was used.
res, err = tc.testFunc(hostname) res, err = tc.testFunc(hostname)
assert.Nil(t, err) require.Nil(t, err)
if tc.block { if tc.block {
assert.True(t, res.IsFiltered) assert.True(t, res.IsFiltered)
assert.Len(t, res.Rules, 1) require.Len(t, res.Rules, 1)
} else { } else {
assert.False(t, res.IsFiltered) require.False(t, res.IsFiltered)
} }
// Check the cache state, it should've been used. // Check the cache state, it should've been used.
@ -199,8 +200,8 @@ func TestSBPC(t *testing.T) {
// Check that there were no additional requests. // Check that there were no additional requests.
assert.Equal(t, 1, ups.RequestsCount()) assert.Equal(t, 1, ups.RequestsCount())
purgeCaches()
}) })
purgeCaches()
} }
} }

View File

@ -5,71 +5,153 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestIsBlockedIPAllowed(t *testing.T) { func TestIsBlockedIP(t *testing.T) {
a := &accessCtx{} const (
assert.Nil(t, a.Init([]string{"1.1.1.1", "2.2.0.0/16"}, nil, nil)) ip int = iota
cidr
)
disallowed, disallowedRule := a.IsBlockedIP(net.IPv4(1, 1, 1, 1)) rules := []string{
assert.False(t, disallowed) ip: "1.1.1.1",
assert.Empty(t, disallowedRule) cidr: "2.2.0.0/16",
}
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(1, 1, 1, 2)) testCases := []struct {
assert.True(t, disallowed) name string
assert.Empty(t, disallowedRule) allowed bool
ip net.IP
wantDis bool
wantRule string
}{{
name: "allow_ip",
allowed: true,
ip: net.IPv4(1, 1, 1, 1),
wantDis: false,
wantRule: "",
}, {
name: "disallow_ip",
allowed: true,
ip: net.IPv4(1, 1, 1, 2),
wantDis: true,
wantRule: "",
}, {
name: "allow_cidr",
allowed: true,
ip: net.IPv4(2, 2, 1, 1),
wantDis: false,
wantRule: "",
}, {
name: "disallow_cidr",
allowed: true,
ip: net.IPv4(2, 3, 1, 1),
wantDis: true,
wantRule: "",
}, {
name: "allow_ip",
allowed: false,
ip: net.IPv4(1, 1, 1, 1),
wantDis: true,
wantRule: rules[ip],
}, {
name: "disallow_ip",
allowed: false,
ip: net.IPv4(1, 1, 1, 2),
wantDis: false,
wantRule: "",
}, {
name: "allow_cidr",
allowed: false,
ip: net.IPv4(2, 2, 1, 1),
wantDis: true,
wantRule: rules[cidr],
}, {
name: "disallow_cidr",
allowed: false,
ip: net.IPv4(2, 3, 1, 1),
wantDis: false,
wantRule: "",
}}
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 2, 1, 1)) for _, tc := range testCases {
assert.False(t, disallowed) prefix := "allowed_"
assert.Empty(t, disallowedRule) if !tc.allowed {
prefix = "disallowed_"
}
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 3, 1, 1)) t.Run(prefix+tc.name, func(t *testing.T) {
assert.True(t, disallowed) aCtx := &accessCtx{}
assert.Empty(t, disallowedRule) allowedRules := rules
var disallowedRules []string
if !tc.allowed {
allowedRules, disallowedRules = disallowedRules, allowedRules
}
require.Nil(t, aCtx.Init(allowedRules, disallowedRules, nil))
disallowed, rule := aCtx.IsBlockedIP(tc.ip)
assert.Equal(t, tc.wantDis, disallowed)
assert.Equal(t, tc.wantRule, rule)
})
}
} }
func TestIsBlockedIPDisallowed(t *testing.T) { func TestIsBlockedDomain(t *testing.T) {
a := &accessCtx{} aCtx := &accessCtx{}
assert.Nil(t, a.Init(nil, []string{"1.1.1.1", "2.2.0.0/16"}, nil)) require.Nil(t, aCtx.Init(nil, nil, []string{
disallowed, disallowedRule := a.IsBlockedIP(net.IPv4(1, 1, 1, 1))
assert.True(t, disallowed)
assert.Equal(t, "1.1.1.1", disallowedRule)
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(1, 1, 1, 2))
assert.False(t, disallowed)
assert.Empty(t, disallowedRule)
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 2, 1, 1))
assert.True(t, disallowed)
assert.Equal(t, "2.2.0.0/16", disallowedRule)
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 3, 1, 1))
assert.False(t, disallowed)
assert.Empty(t, disallowedRule)
}
func TestIsBlockedIPBlockedDomain(t *testing.T) {
a := &accessCtx{}
assert.True(t, a.Init(nil, nil, []string{
"host1", "host1",
"host2",
"*.host.com", "*.host.com",
"||host3.com^", "||host3.com^",
}) == nil) }))
// match by "host2.com" testCases := []struct {
assert.True(t, a.IsBlockedDomain("host1")) name string
assert.True(t, a.IsBlockedDomain("host2")) domain string
assert.False(t, a.IsBlockedDomain("host3")) want bool
}{{
name: "plain_match",
domain: "host1",
want: true,
}, {
name: "plain_mismatch",
domain: "host2",
want: false,
}, {
name: "wildcard_type-1_match_short",
domain: "asdf.host.com",
want: true,
}, {
name: "wildcard_type-1_match_long",
domain: "qwer.asdf.host.com",
want: true,
}, {
name: "wildcard_type-1_mismatch_no-lead",
domain: "host.com",
want: false,
}, {
name: "wildcard_type-1_mismatch_bad-asterisk",
domain: "asdf.zhost.com",
want: false,
}, {
name: "wildcard_type-2_match_simple",
domain: "host3.com",
want: true,
}, {
name: "wildcard_type-2_match_complex",
domain: "asdf.host3.com",
want: true,
}, {
name: "wildcard_type-2_mismatch",
domain: ".host3.com",
want: false,
}}
// match by wildcard "*.host.com" for _, tc := range testCases {
assert.False(t, a.IsBlockedDomain("host.com")) t.Run(tc.name, func(t *testing.T) {
assert.True(t, a.IsBlockedDomain("asdf.host.com")) assert.Equal(t, tc.want, aCtx.IsBlockedDomain(tc.domain))
assert.True(t, a.IsBlockedDomain("qwer.asdf.host.com")) })
assert.False(t, a.IsBlockedDomain("asdf.zhost.com")) }
// match by wildcard "||host3.com^"
assert.True(t, a.IsBlockedDomain("host3.com"))
assert.True(t, a.IsBlockedDomain("asdf.host3.com"))
} }

View File

@ -27,6 +27,7 @@ import (
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
@ -42,14 +43,180 @@ func startDeferStop(t *testing.T, s *Server) {
t.Helper() t.Helper()
err := s.Start() err := s.Start()
assert.Nilf(t, err, "failed to start server: %s", err) require.Nilf(t, err, "failed to start server: %s", err)
t.Cleanup(func() { t.Cleanup(func() {
err := s.Stop() err := s.Stop()
assert.Nilf(t, err, "dns server failed to stop: %s", err) require.Nilf(t, err, "dns server failed to stop: %s", err)
}) })
} }
func createTestServer(t *testing.T, filterConf *dnsfilter.Config, forwardConf ServerConfig) *Server {
t.Helper()
rules := `||nxdomain.example.org
||null.example.org^
127.0.0.1 host.example.org
@@||whitelist.example.org^
||127.0.0.255`
filters := []dnsfilter.Filter{{
ID: 0, Data: []byte(rules),
}}
f := dnsfilter.New(filterConf, filters)
s := NewServer(DNSCreateParams{DNSFilter: f})
s.conf = forwardConf
require.Nil(t, s.Prepare(nil))
return s
}
func createServerTLSConfig(t *testing.T) (*tls.Config, []byte, []byte) {
t.Helper()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.Nilf(t, err, "cannot generate RSA key: %s", err)
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
require.Nilf(t, err, "failed to generate serial number: %s", err)
notBefore := time.Now()
notAfter := notBefore.Add(5 * 365 * time.Hour * 24)
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"AdGuard Tests"},
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IsCA: true,
}
template.DNSNames = append(template.DNSNames, tlsServerName)
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(privateKey), privateKey)
require.Nilf(t, err, "failed to create certificate: %s", err)
certPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
keyPem := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})
cert, err := tls.X509KeyPair(certPem, keyPem)
require.Nilf(t, err, "failed to create certificate: %s", err)
return &tls.Config{
Certificates: []tls.Certificate{cert},
ServerName: tlsServerName,
MinVersion: tls.VersionTLS12,
}, certPem, keyPem
}
func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte) {
t.Helper()
var keyPem []byte
_, certPem, keyPem = createServerTLSConfig(t)
s = createTestServer(t, &dnsfilter.Config{}, ServerConfig{
UDPListenAddr: &net.UDPAddr{},
TCPListenAddr: &net.TCPAddr{},
})
tlsConf.CertificateChainData, tlsConf.PrivateKeyData = certPem, keyPem
s.conf.TLSConfig = tlsConf
err := s.Prepare(nil)
require.Nilf(t, err, "failed to prepare server: %s", err)
return s, certPem
}
func createGoogleATestMessage() *dns.Msg {
return createTestMessage("google-public-dns-a.google.com.")
}
func createTestMessage(host string) *dns.Msg {
return &dns.Msg{
MsgHdr: dns.MsgHdr{
Id: dns.Id(),
RecursionDesired: true,
},
Question: []dns.Question{{
Name: host,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}},
}
}
func createTestMessageWithType(host string, qtype uint16) *dns.Msg {
req := createTestMessage(host)
req.Question[0].Qtype = qtype
return req
}
func assertGoogleAResponse(t *testing.T, reply *dns.Msg) {
assertResponse(t, reply, net.IP{8, 8, 8, 8})
}
func assertResponse(t *testing.T, reply *dns.Msg, ip net.IP) {
t.Helper()
require.Lenf(t, reply.Answer, 1, "dns server returned reply with wrong number of answers - %d", len(reply.Answer))
a, ok := reply.Answer[0].(*dns.A)
require.Truef(t, ok, "dns server returned wrong answer type instead of A: %v", reply.Answer[0])
assert.Truef(t, a.A.Equal(ip), "dns server returned wrong answer instead of %s: %s", ip, a.A)
}
// sendTestMessagesAsync sends messages in parallel to check for race issues.
//
//lint:ignore U1000 it's called from the function which is skipped for now.
func sendTestMessagesAsync(t *testing.T, conn *dns.Conn) {
t.Helper()
wg := &sync.WaitGroup{}
for i := 0; i < testMessagesCount; i++ {
msg := createGoogleATestMessage()
wg.Add(1)
go func() {
defer wg.Done()
err := conn.WriteMsg(msg)
require.Nilf(t, err, "cannot write message: %s", err)
res, err := conn.ReadMsg()
require.Nilf(t, err, "cannot read response to message: %s", err)
assertGoogleAResponse(t, res)
}()
}
wg.Wait()
}
func sendTestMessages(t *testing.T, conn *dns.Conn) {
t.Helper()
for i := 0; i < testMessagesCount; i++ {
req := createGoogleATestMessage()
err := conn.WriteMsg(req)
assert.Nilf(t, err, "cannot write message #%d: %s", i, err)
res, err := conn.ReadMsg()
assert.Nilf(t, err, "cannot read response to message #%d: %s", i, err)
assertGoogleAResponse(t, res)
}
}
func TestServer(t *testing.T) { func TestServer(t *testing.T) {
s := createTestServer(t, &dnsfilter.Config{}, ServerConfig{ s := createTestServer(t, &dnsfilter.Config{}, ServerConfig{
UDPListenAddr: &net.UDPAddr{}, UDPListenAddr: &net.UDPAddr{},
@ -81,7 +248,7 @@ func TestServer(t *testing.T) {
client := dns.Client{Net: tc.proto} client := dns.Client{Net: tc.proto}
reply, _, err := client.Exchange(createGoogleATestMessage(), addr.String()) reply, _, err := client.Exchange(createGoogleATestMessage(), addr.String())
assert.Nilf(t, err, "сouldn't talk to server %s: %s", addr, err) require.Nilf(t, err, "сouldn't talk to server %s: %s", addr, err)
assertGoogleAResponse(t, reply) assertGoogleAResponse(t, reply)
}) })
@ -106,31 +273,12 @@ func TestServerWithProtectionDisabled(t *testing.T) {
req := createGoogleATestMessage() req := createGoogleATestMessage()
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
client := dns.Client{Net: proxy.ProtoUDP} client := dns.Client{Net: proxy.ProtoUDP}
reply, _, err := client.Exchange(req, addr.String()) reply, _, err := client.Exchange(req, addr.String())
assert.Nilf(t, err, "сouldn't talk to server %s: %s", addr, err) require.Nilf(t, err, "сouldn't talk to server %s: %s", addr, err)
assertGoogleAResponse(t, reply) assertGoogleAResponse(t, reply)
} }
func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte) {
t.Helper()
var keyPem []byte
_, certPem, keyPem = createServerTLSConfig(t)
s = createTestServer(t, &dnsfilter.Config{}, ServerConfig{
UDPListenAddr: &net.UDPAddr{},
TCPListenAddr: &net.TCPAddr{},
})
tlsConf.CertificateChainData, tlsConf.PrivateKeyData = certPem, keyPem
s.conf.TLSConfig = tlsConf
err := s.Prepare(nil)
assert.Nilf(t, err, "failed to prepare server: %s", err)
return s, certPem
}
func TestDoTServer(t *testing.T) { func TestDoTServer(t *testing.T) {
s, certPem := createTestTLS(t, TLSConfig{ s, certPem := createTestTLS(t, TLSConfig{
TLSListenAddr: &net.TCPAddr{}, TLSListenAddr: &net.TCPAddr{},
@ -156,7 +304,7 @@ func TestDoTServer(t *testing.T) {
// Create a DNS-over-TLS client connection. // Create a DNS-over-TLS client connection.
addr := s.dnsProxy.Addr(proxy.ProtoTLS) addr := s.dnsProxy.Addr(proxy.ProtoTLS)
conn, err := dns.DialWithTLS("tcp-tls", addr.String(), tlsConfig) conn, err := dns.DialWithTLS("tcp-tls", addr.String(), tlsConfig)
assert.Nilf(t, err, "cannot connect to the proxy: %s", err) require.Nilf(t, err, "cannot connect to the proxy: %s", err)
sendTestMessages(t, conn) sendTestMessages(t, conn)
} }
@ -178,12 +326,12 @@ func TestDoQServer(t *testing.T) {
addr := s.dnsProxy.Addr(proxy.ProtoQUIC) addr := s.dnsProxy.Addr(proxy.ProtoQUIC)
opts := upstream.Options{InsecureSkipVerify: true} opts := upstream.Options{InsecureSkipVerify: true}
u, err := upstream.AddressToUpstream(fmt.Sprintf("%s://%s", proxy.ProtoQUIC, addr), opts) u, err := upstream.AddressToUpstream(fmt.Sprintf("%s://%s", proxy.ProtoQUIC, addr), opts)
assert.Nil(t, err) require.Nil(t, err)
// Send the test message. // Send the test message.
req := createGoogleATestMessage() req := createGoogleATestMessage()
res, err := u.Exchange(req) res, err := u.Exchange(req)
assert.Nil(t, err) require.Nil(t, err)
assertGoogleAResponse(t, res) assertGoogleAResponse(t, res)
} }
@ -221,7 +369,7 @@ func TestServerRace(t *testing.T) {
// Message over UDP. // Message over UDP.
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
conn, err := dns.Dial(proxy.ProtoUDP, addr.String()) conn, err := dns.Dial(proxy.ProtoUDP, addr.String())
assert.Nilf(t, err, "cannot connect to the proxy: %s", err) require.Nilf(t, err, "cannot connect to the proxy: %s", err)
sendTestMessagesAsync(t, conn) sendTestMessagesAsync(t, conn)
} }
@ -282,8 +430,9 @@ func TestSafeSearch(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.host, func(t *testing.T) { t.Run(tc.host, func(t *testing.T) {
req := createTestMessage(tc.host) req := createTestMessage(tc.host)
reply, _, err := client.Exchange(req, addr) reply, _, err := client.Exchange(req, addr)
assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) require.Nilf(t, err, "couldn't talk to server %s: %s", addr, err)
assertResponse(t, reply, tc.want) assertResponse(t, reply, tc.want)
}) })
} }
@ -330,8 +479,10 @@ func TestBlockedRequest(t *testing.T) {
req := createTestMessage("nxdomain.example.org.") req := createTestMessage("nxdomain.example.org.")
reply, err := dns.Exchange(req, addr.String()) reply, err := dns.Exchange(req, addr.String())
assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) require.Nilf(t, err, "couldn't talk to server %s: %s", addr, err)
assert.Equal(t, dns.RcodeSuccess, reply.Rcode) assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
require.Len(t, reply.Answer, 1)
assert.True(t, reply.Answer[0].(*dns.A).A.IsUnspecified()) assert.True(t, reply.Answer[0].(*dns.A).A.IsUnspecified())
} }
@ -364,28 +515,14 @@ func TestServerCustomClientUpstream(t *testing.T) {
reply, err := dns.Exchange(req, addr.String()) reply, err := dns.Exchange(req, addr.String())
assert.Nil(t, err) require.Nil(t, err)
assert.Equal(t, dns.RcodeSuccess, reply.Rcode) assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
assert.NotEmpty(t, reply.Answer) require.NotEmpty(t, reply.Answer)
require.Len(t, reply.Answer, 1)
assert.Equal(t, net.IP{192, 168, 0, 1}, reply.Answer[0].(*dns.A).A) assert.Equal(t, net.IP{192, 168, 0, 1}, reply.Answer[0].(*dns.A).A)
} }
func (s *Server) startWithUpstream(u upstream.Upstream) error {
s.Lock()
defer s.Unlock()
err := s.Prepare(nil)
if err != nil {
return err
}
s.dnsProxy.UpstreamConfig = &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{u},
}
return s.dnsProxy.Start()
}
// testCNAMEs is a map of names and CNAMEs necessary for the TestUpstream work. // testCNAMEs is a map of names and CNAMEs necessary for the TestUpstream work.
var testCNAMEs = map[string]string{ var testCNAMEs = map[string]string{
"badhost.": "null.example.org.", "badhost.": "null.example.org.",
@ -409,15 +546,19 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
IPv6: nil, IPv6: nil,
} }
s.conf.ProtectionEnabled = false s.conf.ProtectionEnabled = false
err := s.startWithUpstream(testUpstm) s.dnsProxy.UpstreamConfig = &proxy.UpstreamConfig{
assert.Nil(t, err) Upstreams: []upstream.Upstream{testUpstm},
}
startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
// 'badhost' has a canonical name 'null.example.org' which is blocked by // 'badhost' has a canonical name 'null.example.org' which should be
// filters: but protection is disabled so response is _not_ blocked. // blocked by filters, but protection is disabled so it is not.
req := createTestMessage("badhost.") req := createTestMessage("badhost.")
reply, err := dns.Exchange(req, addr.String()) reply, err := dns.Exchange(req, addr.String())
assert.Nil(t, err) require.Nil(t, err)
assert.Equal(t, dns.RcodeSuccess, reply.Rcode) assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
} }
@ -465,11 +606,15 @@ func TestBlockCNAME(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run("block_cname_"+tc.host, func(t *testing.T) { t.Run("block_cname_"+tc.host, func(t *testing.T) {
req := createTestMessage(tc.host) req := createTestMessage(tc.host)
reply, err := dns.Exchange(req, addr) reply, err := dns.Exchange(req, addr)
assert.Nil(t, err) require.Nil(t, err)
assert.Equal(t, dns.RcodeSuccess, reply.Rcode) assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
if tc.want { if tc.want {
assert.True(t, reply.Answer[0].(*dns.A).A.IsUnspecified()) require.Len(t, reply.Answer, 1)
a, ok := reply.Answer[0].(*dns.A)
require.True(t, ok)
assert.True(t, a.A.IsUnspecified())
} }
}) })
} }
@ -513,7 +658,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) {
// However, in our case it should not be blocked as filtering is // However, in our case it should not be blocked as filtering is
// disabled on the client level. // disabled on the client level.
reply, err := dns.Exchange(&req, addr.String()) reply, err := dns.Exchange(&req, addr.String())
assert.Nil(t, err) require.Nil(t, err)
assert.Equal(t, dns.RcodeSuccess, reply.Rcode) assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
} }
@ -544,10 +689,10 @@ func TestNullBlockedRequest(t *testing.T) {
} }
reply, err := dns.Exchange(&req, addr.String()) reply, err := dns.Exchange(&req, addr.String())
assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) require.Nilf(t, err, "couldn't talk to server %s: %s", addr, err)
assert.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer)) require.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer))
a, ok := reply.Answer[0].(*dns.A) a, ok := reply.Answer[0].(*dns.A)
assert.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0]) require.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0])
assert.Truef(t, a.A.IsUnspecified(), "dns server %s returned wrong answer instead of 0.0.0.0: %v", addr, a.A) assert.Truef(t, a.A.IsUnspecified(), "dns server %s returned wrong answer instead of 0.0.0.0: %v", addr, a.A)
} }
@ -561,7 +706,7 @@ func TestBlockedCustomIP(t *testing.T) {
s := NewServer(DNSCreateParams{ s := NewServer(DNSCreateParams{
DNSFilter: dnsfilter.New(&dnsfilter.Config{}, filters), DNSFilter: dnsfilter.New(&dnsfilter.Config{}, filters),
}) })
conf := ServerConfig{ conf := &ServerConfig{
UDPListenAddr: &net.UDPAddr{}, UDPListenAddr: &net.UDPAddr{},
TCPListenAddr: &net.TCPAddr{}, TCPListenAddr: &net.TCPAddr{},
FilteringConfig: FilteringConfig{ FilteringConfig: FilteringConfig{
@ -572,11 +717,11 @@ func TestBlockedCustomIP(t *testing.T) {
}, },
} }
// Invalid BlockingIPv4. // Invalid BlockingIPv4.
assert.NotNil(t, s.Prepare(&conf)) assert.NotNil(t, s.Prepare(conf))
conf.BlockingIPv4 = net.IP{0, 0, 0, 1} conf.BlockingIPv4 = net.IP{0, 0, 0, 1}
conf.BlockingIPv6 = net.ParseIP("::1") conf.BlockingIPv6 = net.ParseIP("::1")
assert.Nil(t, s.Prepare(&conf)) require.Nil(t, s.Prepare(conf))
startDeferStop(t, s) startDeferStop(t, s)
@ -584,18 +729,18 @@ func TestBlockedCustomIP(t *testing.T) {
req := createTestMessageWithType("null.example.org.", dns.TypeA) req := createTestMessageWithType("null.example.org.", dns.TypeA)
reply, err := dns.Exchange(req, addr.String()) reply, err := dns.Exchange(req, addr.String())
assert.Nil(t, err) require.Nil(t, err)
assert.Len(t, reply.Answer, 1) require.Len(t, reply.Answer, 1)
a, ok := reply.Answer[0].(*dns.A) a, ok := reply.Answer[0].(*dns.A)
assert.True(t, ok) require.True(t, ok)
assert.True(t, net.IP{0, 0, 0, 1}.Equal(a.A)) assert.True(t, net.IP{0, 0, 0, 1}.Equal(a.A))
req = createTestMessageWithType("null.example.org.", dns.TypeAAAA) req = createTestMessageWithType("null.example.org.", dns.TypeAAAA)
reply, err = dns.Exchange(req, addr.String()) reply, err = dns.Exchange(req, addr.String())
assert.Nil(t, err) require.Nil(t, err)
assert.Len(t, reply.Answer, 1) require.Len(t, reply.Answer, 1)
a6, ok := reply.Answer[0].(*dns.AAAA) a6, ok := reply.Answer[0].(*dns.AAAA)
assert.True(t, ok) require.True(t, ok)
assert.Equal(t, "::1", a6.AAAA.String()) assert.Equal(t, "::1", a6.AAAA.String())
} }
@ -615,11 +760,10 @@ func TestBlockedByHosts(t *testing.T) {
req := createTestMessage("host.example.org.") req := createTestMessage("host.example.org.")
reply, err := dns.Exchange(req, addr.String()) reply, err := dns.Exchange(req, addr.String())
assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) require.Nilf(t, err, "couldn't talk to server %s: %s", addr, err)
assert.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer)) require.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer))
a, ok := reply.Answer[0].(*dns.A) a, ok := reply.Answer[0].(*dns.A)
assert.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0]) require.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0])
assert.Equalf(t, net.IP{127, 0, 0, 1}, a.A, "dns server %s returned wrong answer instead of 8.8.8.8: %v", addr, a.A) assert.Equalf(t, net.IP{127, 0, 0, 1}, a.A, "dns server %s returned wrong answer instead of 8.8.8.8: %v", addr, a.A)
} }
@ -630,7 +774,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
Hostname: hostname, Hostname: hostname,
Block: true, Block: true,
} }
ans, _ := (&aghtest.TestResolver{}).HostToIPs(hostname) ans4, _ := (&aghtest.TestResolver{}).HostToIPs(hostname)
filterConf := &dnsfilter.Config{ filterConf := &dnsfilter.Config{
SafeBrowsingEnabled: true, SafeBrowsingEnabled: true,
@ -639,7 +783,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
UDPListenAddr: &net.UDPAddr{}, UDPListenAddr: &net.UDPAddr{},
TCPListenAddr: &net.TCPAddr{}, TCPListenAddr: &net.TCPAddr{},
FilteringConfig: FilteringConfig{ FilteringConfig: FilteringConfig{
SafeBrowsingBlockHost: ans.String(), SafeBrowsingBlockHost: ans4.String(),
ProtectionEnabled: true, ProtectionEnabled: true,
}, },
} }
@ -652,13 +796,12 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
req := createTestMessage(hostname + ".") req := createTestMessage(hostname + ".")
reply, err := dns.Exchange(req, addr.String()) reply, err := dns.Exchange(req, addr.String())
assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) require.Nilf(t, err, "couldn't talk to server %s: %s", addr, err)
assert.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer)) require.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer))
a, ok := reply.Answer[0].(*dns.A) a, ok := reply.Answer[0].(*dns.A)
if assert.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0]) { require.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0])
assert.Equal(t, ans, a.A, "dns server %s returned wrong answer: %v", addr, a.A) assert.Equal(t, ans4, a.A, "dns server %s returned wrong answer: %v", addr, a.A)
}
} }
func TestRewrite(t *testing.T) { func TestRewrite(t *testing.T) {
@ -680,14 +823,14 @@ func TestRewrite(t *testing.T) {
f := dnsfilter.New(c, nil) f := dnsfilter.New(c, nil)
s := NewServer(DNSCreateParams{DNSFilter: f}) s := NewServer(DNSCreateParams{DNSFilter: f})
err := s.Prepare(&ServerConfig{ assert.Nil(t, s.Prepare(&ServerConfig{
UDPListenAddr: &net.UDPAddr{}, UDPListenAddr: &net.UDPAddr{},
TCPListenAddr: &net.TCPAddr{}, TCPListenAddr: &net.TCPAddr{},
FilteringConfig: FilteringConfig{ FilteringConfig: FilteringConfig{
ProtectionEnabled: true, ProtectionEnabled: true,
UpstreamDNS: []string{"8.8.8.8:53"}, UpstreamDNS: []string{"8.8.8.8:53"},
}, },
}) }))
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.TestUpstream{ &aghtest.TestUpstream{
CName: map[string]string{ CName: map[string]string{
@ -698,185 +841,44 @@ func TestRewrite(t *testing.T) {
}, },
}, },
} }
assert.Nil(t, err)
startDeferStop(t, s) startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
req := createTestMessageWithType("test.com.", dns.TypeA) req := createTestMessageWithType("test.com.", dns.TypeA)
reply, err := dns.Exchange(req, addr.String()) reply, err := dns.Exchange(req, addr.String())
assert.Nil(t, err) require.Nil(t, err)
assert.Len(t, reply.Answer, 1) require.Len(t, reply.Answer, 1)
a, ok := reply.Answer[0].(*dns.A) a, ok := reply.Answer[0].(*dns.A)
assert.True(t, ok) require.True(t, ok)
assert.True(t, net.IP{1, 2, 3, 4}.Equal(a.A)) assert.True(t, net.IP{1, 2, 3, 4}.Equal(a.A))
req = createTestMessageWithType("test.com.", dns.TypeAAAA) req = createTestMessageWithType("test.com.", dns.TypeAAAA)
reply, err = dns.Exchange(req, addr.String()) reply, err = dns.Exchange(req, addr.String())
assert.Nil(t, err) require.Nil(t, err)
assert.Empty(t, reply.Answer) assert.Empty(t, reply.Answer)
req = createTestMessageWithType("alias.test.com.", dns.TypeA) req = createTestMessageWithType("alias.test.com.", dns.TypeA)
reply, err = dns.Exchange(req, addr.String()) reply, err = dns.Exchange(req, addr.String())
assert.Nil(t, err) require.Nil(t, err)
assert.Len(t, reply.Answer, 2)
require.Len(t, reply.Answer, 2)
assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target) assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target)
assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A)) assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A))
req = createTestMessageWithType("my.alias.example.org.", dns.TypeA) req = createTestMessageWithType("my.alias.example.org.", dns.TypeA)
reply, err = dns.Exchange(req, addr.String()) reply, err = dns.Exchange(req, addr.String())
assert.Nil(t, err) require.Nil(t, err)
// The original question is restored. // The original question is restored.
require.Len(t, reply.Question, 1)
assert.Equal(t, "my.alias.example.org.", reply.Question[0].Name) assert.Equal(t, "my.alias.example.org.", reply.Question[0].Name)
assert.Len(t, reply.Answer, 2)
require.Len(t, reply.Answer, 2)
assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target) assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target)
assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype) assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
} }
func createTestServer(t *testing.T, filterConf *dnsfilter.Config, forwardConf ServerConfig) *Server {
rules := `||nxdomain.example.org
||null.example.org^
127.0.0.1 host.example.org
@@||whitelist.example.org^
||127.0.0.255`
filters := []dnsfilter.Filter{{
ID: 0, Data: []byte(rules),
}}
f := dnsfilter.New(filterConf, filters)
s := NewServer(DNSCreateParams{DNSFilter: f})
s.conf = forwardConf
assert.Nil(t, s.Prepare(nil))
return s
}
func createServerTLSConfig(t *testing.T) (*tls.Config, []byte, []byte) {
t.Helper()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
assert.Nilf(t, err, "cannot generate RSA key: %s", err)
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
assert.Nilf(t, err, "failed to generate serial number: %s", err)
notBefore := time.Now()
notAfter := notBefore.Add(5 * 365 * time.Hour * 24)
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"AdGuard Tests"},
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IsCA: true,
}
template.DNSNames = append(template.DNSNames, tlsServerName)
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(privateKey), privateKey)
assert.Nilf(t, err, "failed to create certificate: %s", err)
certPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
keyPem := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})
cert, err := tls.X509KeyPair(certPem, keyPem)
assert.Nilf(t, err, "failed to create certificate: %s", err)
return &tls.Config{
Certificates: []tls.Certificate{cert},
ServerName: tlsServerName,
MinVersion: tls.VersionTLS12,
}, certPem, keyPem
}
// sendTestMessagesAsync sends messages in parallel to check for race issues.
//lint:ignore U1000 it's called from the function which is skipped for now.
func sendTestMessagesAsync(t *testing.T, conn *dns.Conn) {
wg := &sync.WaitGroup{}
for i := 0; i < testMessagesCount; i++ {
msg := createGoogleATestMessage()
wg.Add(1)
go func() {
defer wg.Done()
err := conn.WriteMsg(msg)
assert.Nilf(t, err, "cannot write message: %s", err)
res, err := conn.ReadMsg()
assert.Nilf(t, err, "cannot read response to message: %s", err)
assertGoogleAResponse(t, res)
}()
}
wg.Wait()
}
func sendTestMessages(t *testing.T, conn *dns.Conn) {
t.Helper()
for i := 0; i < testMessagesCount; i++ {
req := createGoogleATestMessage()
err := conn.WriteMsg(req)
assert.Nilf(t, err, "cannot write message #%d: %s", i, err)
res, err := conn.ReadMsg()
assert.Nilf(t, err, "cannot read response to message #%d: %s", i, err)
assertGoogleAResponse(t, res)
}
}
func createGoogleATestMessage() *dns.Msg {
return createTestMessage("google-public-dns-a.google.com.")
}
func createTestMessage(host string) *dns.Msg {
return &dns.Msg{
MsgHdr: dns.MsgHdr{
Id: dns.Id(),
RecursionDesired: true,
},
Question: []dns.Question{{
Name: host,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}},
}
}
func createTestMessageWithType(host string, qtype uint16) *dns.Msg {
req := createTestMessage(host)
req.Question[0].Qtype = qtype
return req
}
func assertGoogleAResponse(t *testing.T, reply *dns.Msg) {
assertResponse(t, reply, net.IP{8, 8, 8, 8})
}
func assertResponse(t *testing.T, reply *dns.Msg, ip net.IP) {
t.Helper()
if !assert.Lenf(t, reply.Answer, 1, "dns server returned reply with wrong number of answers - %d", len(reply.Answer)) {
return
}
a, ok := reply.Answer[0].(*dns.A)
if assert.Truef(t, ok, "dns server returned wrong answer type instead of A: %v", reply.Answer[0]) {
assert.Truef(t, a.A.Equal(ip), "dns server returned wrong answer instead of %s: %s", ip, a.A)
}
}
func publicKey(priv interface{}) interface{} { func publicKey(priv interface{}) interface{} {
switch k := priv.(type) { switch k := priv.(type) {
case *rsa.PrivateKey: case *rsa.PrivateKey:
@ -966,8 +968,8 @@ func TestValidateUpstream(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
defaultUpstream, err := validateUpstream(tc.upstream) defaultUpstream, err := validateUpstream(tc.upstream)
assert.Equal(t, tc.valid, err == nil) require.Equal(t, tc.valid, err == nil)
if err == nil { if tc.valid {
assert.Equal(t, tc.wantDef, defaultUpstream) assert.Equal(t, tc.wantDef, defaultUpstream)
} }
}) })
@ -975,42 +977,73 @@ func TestValidateUpstream(t *testing.T) {
} }
func TestValidateUpstreamsSet(t *testing.T) { func TestValidateUpstreamsSet(t *testing.T) {
// Empty upstreams array. testCases := []struct {
var upstreamsSet []string name string
assert.Nil(t, ValidateUpstreams(upstreamsSet), "empty upstreams array should be valid") msg string
set []string
wantNil bool
}{{
name: "empty",
msg: "empty upstreams array should be valid",
set: nil,
wantNil: true,
}, {
name: "comment",
msg: "comments should not be validated",
set: []string{"# comment"},
wantNil: true,
}, {
name: "valid_no_default",
msg: "there is no default upstream",
set: []string{
"[/host.com/]1.1.1.1",
"[//]tls://1.1.1.1",
"[/www.host.com/]#",
"[/host.com/google.com/]8.8.8.8",
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
},
wantNil: false,
}, {
name: "valid_with_default",
msg: "upstreams set is valid, but doesn't pass through validation cause: %s",
set: []string{
"[/host.com/]1.1.1.1",
"[//]tls://1.1.1.1",
"[/www.host.com/]#",
"[/host.com/google.com/]8.8.8.8",
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
"8.8.8.8",
},
wantNil: true,
}, {
name: "invalid",
msg: "there is an invalid upstream in set, but it pass through validation",
set: []string{"dhcp://fake.dns"},
wantNil: false,
}}
// Comment in upstreams array. for _, tc := range testCases {
upstreamsSet = []string{"# comment"} t.Run(tc.name, func(t *testing.T) {
assert.Nil(t, ValidateUpstreams(upstreamsSet), "comments should not be validated") err := ValidateUpstreams(tc.set)
// Set of valid upstreams. There is no default upstream specified. assert.Equalf(t, tc.wantNil, err == nil, tc.msg, err)
upstreamsSet = []string{ })
"[/host.com/]1.1.1.1",
"[//]tls://1.1.1.1",
"[/www.host.com/]#",
"[/host.com/google.com/]8.8.8.8",
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
} }
assert.NotNil(t, ValidateUpstreams(upstreamsSet), "there is no default upstream")
// Let's add default upstream.
upstreamsSet = append(upstreamsSet, "8.8.8.8")
err := ValidateUpstreams(upstreamsSet)
assert.Nilf(t, err, "upstreams set is valid, but doesn't pass through validation cause: %s", err)
// Let's add invalid upstream.
upstreamsSet = append(upstreamsSet, "dhcp://fake.dns")
assert.NotNil(t, ValidateUpstreams(upstreamsSet), "there is an invalid upstream in set, but it pass through validation")
} }
func TestIPStringFromAddr(t *testing.T) { func TestIPStringFromAddr(t *testing.T) {
addr := net.UDPAddr{ t.Run("not_nil", func(t *testing.T) {
IP: net.ParseIP("1:2:3::4"), addr := net.UDPAddr{
Port: 12345, IP: net.ParseIP("1:2:3::4"),
Zone: "eth0", Port: 12345,
} Zone: "eth0",
assert.Equal(t, IPStringFromAddr(&addr), addr.IP.String()) }
assert.Empty(t, IPStringFromAddr(nil)) assert.Equal(t, IPStringFromAddr(&addr), addr.IP.String())
})
t.Run("nil", func(t *testing.T) {
assert.Empty(t, IPStringFromAddr(nil))
})
} }
func TestMatchDNSName(t *testing.T) { func TestMatchDNSName(t *testing.T) {
@ -1071,38 +1104,33 @@ func (d *testDHCP) Leases(flags int) []dhcpd.Lease {
func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {} func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {}
func TestPTRResponseFromDHCPLeases(t *testing.T) { func TestPTRResponseFromDHCPLeases(t *testing.T) {
dhcp := &testDHCP{}
s := NewServer(DNSCreateParams{ s := NewServer(DNSCreateParams{
DNSFilter: dnsfilter.New(&dnsfilter.Config{}, nil), DNSFilter: dnsfilter.New(&dnsfilter.Config{}, nil),
DHCPServer: dhcp, DHCPServer: &testDHCP{},
}) })
s.conf.UDPListenAddr = &net.UDPAddr{} s.conf.UDPListenAddr = &net.UDPAddr{}
s.conf.TCPListenAddr = &net.TCPAddr{} s.conf.TCPListenAddr = &net.TCPAddr{}
s.conf.UpstreamDNS = []string{"127.0.0.1:53"} s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
s.conf.FilteringConfig.ProtectionEnabled = true s.conf.FilteringConfig.ProtectionEnabled = true
err := s.Prepare(nil) require.Nil(t, s.Prepare(nil))
assert.Nil(t, err) require.Nil(t, s.Start())
t.Cleanup(func() {
s.Close()
})
assert.Nil(t, s.Start())
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR) req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR)
resp, err := dns.Exchange(req, addr.String()) resp, err := dns.Exchange(req, addr.String())
require.Nil(t, err)
assert.Nil(t, err) require.Len(t, resp.Answer, 1)
assert.Len(t, resp.Answer, 1)
assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype) assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name) assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)
ptr, ok := resp.Answer[0].(*dns.PTR) ptr, ok := resp.Answer[0].(*dns.PTR)
if assert.True(t, ok) { require.True(t, ok)
assert.Equal(t, "localhost.", ptr.Ptr) assert.Equal(t, "localhost.", ptr.Ptr)
}
s.Close()
} }
func TestPTRResponseFromHosts(t *testing.T) { func TestPTRResponseFromHosts(t *testing.T) {
@ -1112,12 +1140,11 @@ func TestPTRResponseFromHosts(t *testing.T) {
// Prepare test hosts file. // Prepare test hosts file.
hf, err := ioutil.TempFile("", "") hf, err := ioutil.TempFile("", "")
if assert.Nil(t, err) { require.Nil(t, err)
t.Cleanup(func() { t.Cleanup(func() {
assert.Nil(t, hf.Close()) assert.Nil(t, hf.Close())
assert.Nil(t, os.Remove(hf.Name())) assert.Nil(t, os.Remove(hf.Name()))
}) })
}
_, _ = hf.WriteString(" 127.0.0.1 host # comment \n") _, _ = hf.WriteString(" 127.0.0.1 host # comment \n")
_, _ = hf.WriteString(" ::1 localhost#comment \n") _, _ = hf.WriteString(" ::1 localhost#comment \n")
@ -1131,23 +1158,23 @@ func TestPTRResponseFromHosts(t *testing.T) {
s.conf.TCPListenAddr = &net.TCPAddr{} s.conf.TCPListenAddr = &net.TCPAddr{}
s.conf.UpstreamDNS = []string{"127.0.0.1:53"} s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
s.conf.FilteringConfig.ProtectionEnabled = true s.conf.FilteringConfig.ProtectionEnabled = true
assert.Nil(t, s.Prepare(nil)) require.Nil(t, s.Prepare(nil))
assert.Nil(t, s.Start()) require.Nil(t, s.Start())
t.Cleanup(func() {
s.Close()
})
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR) req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR)
resp, err := dns.Exchange(req, addr.String()) resp, err := dns.Exchange(req, addr.String())
assert.Nil(t, err) require.Nil(t, err)
assert.Len(t, resp.Answer, 1) require.Len(t, resp.Answer, 1)
assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype) assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name) assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)
ptr, ok := resp.Answer[0].(*dns.PTR) ptr, ok := resp.Answer[0].(*dns.PTR)
if assert.True(t, ok) { require.True(t, ok)
assert.Equal(t, "host.", ptr.Ptr) assert.Equal(t, "host.", ptr.Ptr)
}
s.Close()
} }

View File

@ -9,6 +9,7 @@ import (
"github.com/AdguardTeam/urlfilter/rules" "github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestServer_FilterDNSRewrite(t *testing.T) { func TestServer_FilterDNSRewrite(t *testing.T) {
@ -54,7 +55,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
d := &proxy.DNSContext{} d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d) err := srv.filterDNSRewrite(req, res, d)
assert.Nil(t, err)
require.Nil(t, err)
assert.Equal(t, dns.RcodeNameError, d.Res.Rcode) assert.Equal(t, dns.RcodeNameError, d.Res.Rcode)
}) })
@ -64,7 +66,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
d := &proxy.DNSContext{} d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d) err := srv.filterDNSRewrite(req, res, d)
assert.Nil(t, err) require.Nil(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
assert.Empty(t, d.Res.Answer) assert.Empty(t, d.Res.Answer)
}) })
@ -75,11 +77,11 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
d := &proxy.DNSContext{} d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d) err := srv.filterDNSRewrite(req, res, d)
assert.Nil(t, err) require.Nil(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
if assert.Len(t, d.Res.Answer, 1) {
assert.Equal(t, ip4, d.Res.Answer[0].(*dns.A).A) require.Len(t, d.Res.Answer, 1)
} assert.Equal(t, ip4, d.Res.Answer[0].(*dns.A).A)
}) })
t.Run("noerror_aaaa", func(t *testing.T) { t.Run("noerror_aaaa", func(t *testing.T) {
@ -88,11 +90,11 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
d := &proxy.DNSContext{} d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d) err := srv.filterDNSRewrite(req, res, d)
assert.Nil(t, err) require.Nil(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
if assert.Len(t, d.Res.Answer, 1) {
assert.Equal(t, ip6, d.Res.Answer[0].(*dns.AAAA).AAAA) require.Len(t, d.Res.Answer, 1)
} assert.Equal(t, ip6, d.Res.Answer[0].(*dns.AAAA).AAAA)
}) })
t.Run("noerror_ptr", func(t *testing.T) { t.Run("noerror_ptr", func(t *testing.T) {
@ -101,11 +103,11 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
d := &proxy.DNSContext{} d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d) err := srv.filterDNSRewrite(req, res, d)
assert.Nil(t, err) require.Nil(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
if assert.Len(t, d.Res.Answer, 1) {
assert.Equal(t, domain, d.Res.Answer[0].(*dns.PTR).Ptr) require.Len(t, d.Res.Answer, 1)
} assert.Equal(t, domain, d.Res.Answer[0].(*dns.PTR).Ptr)
}) })
t.Run("noerror_txt", func(t *testing.T) { t.Run("noerror_txt", func(t *testing.T) {
@ -114,11 +116,11 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
d := &proxy.DNSContext{} d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d) err := srv.filterDNSRewrite(req, res, d)
assert.Nil(t, err) require.Nil(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
if assert.Len(t, d.Res.Answer, 1) {
assert.Equal(t, []string{domain}, d.Res.Answer[0].(*dns.TXT).Txt) require.Len(t, d.Res.Answer, 1)
} assert.Equal(t, []string{domain}, d.Res.Answer[0].(*dns.TXT).Txt)
}) })
t.Run("noerror_mx", func(t *testing.T) { t.Run("noerror_mx", func(t *testing.T) {
@ -127,15 +129,15 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
d := &proxy.DNSContext{} d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d) err := srv.filterDNSRewrite(req, res, d)
assert.Nil(t, err) require.Nil(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
if assert.Len(t, d.Res.Answer, 1) {
ans, ok := d.Res.Answer[0].(*dns.MX) require.Len(t, d.Res.Answer, 1)
if assert.True(t, ok) { ans, ok := d.Res.Answer[0].(*dns.MX)
assert.Equal(t, mx.Exchange, ans.Mx)
assert.Equal(t, mx.Preference, ans.Preference) require.True(t, ok)
} assert.Equal(t, mx.Exchange, ans.Mx)
} assert.Equal(t, mx.Preference, ans.Preference)
}) })
t.Run("noerror_svcb", func(t *testing.T) { t.Run("noerror_svcb", func(t *testing.T) {
@ -144,17 +146,17 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
d := &proxy.DNSContext{} d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d) err := srv.filterDNSRewrite(req, res, d)
assert.Nil(t, err) require.Nil(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
if assert.Len(t, d.Res.Answer, 1) {
ans, ok := d.Res.Answer[0].(*dns.SVCB) require.Len(t, d.Res.Answer, 1)
if assert.True(t, ok) { ans, ok := d.Res.Answer[0].(*dns.SVCB)
assert.Equal(t, dns.SVCB_ALPN, ans.Value[0].Key()) require.True(t, ok)
assert.Equal(t, svcb.Params["alpn"], ans.Value[0].String())
assert.Equal(t, svcb.Target, ans.Target) assert.Equal(t, dns.SVCB_ALPN, ans.Value[0].Key())
assert.Equal(t, svcb.Priority, ans.Priority) assert.Equal(t, svcb.Params["alpn"], ans.Value[0].String())
} assert.Equal(t, svcb.Target, ans.Target)
} assert.Equal(t, svcb.Priority, ans.Priority)
}) })
t.Run("noerror_https", func(t *testing.T) { t.Run("noerror_https", func(t *testing.T) {
@ -163,16 +165,16 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
d := &proxy.DNSContext{} d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d) err := srv.filterDNSRewrite(req, res, d)
assert.Nil(t, err) require.Nil(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
if assert.Len(t, d.Res.Answer, 1) {
ans, ok := d.Res.Answer[0].(*dns.HTTPS) require.Len(t, d.Res.Answer, 1)
if assert.True(t, ok) { ans, ok := d.Res.Answer[0].(*dns.HTTPS)
assert.Equal(t, dns.SVCB_ALPN, ans.Value[0].Key())
assert.Equal(t, svcb.Params["alpn"], ans.Value[0].String()) require.True(t, ok)
assert.Equal(t, svcb.Target, ans.Target) assert.Equal(t, dns.SVCB_ALPN, ans.Value[0].Key())
assert.Equal(t, svcb.Priority, ans.Priority) assert.Equal(t, svcb.Params["alpn"], ans.Value[0].String())
} assert.Equal(t, svcb.Target, ans.Target)
} assert.Equal(t, svcb.Priority, ans.Priority)
}) })
} }

View File

@ -10,6 +10,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter" "github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) { func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) {
@ -31,9 +32,10 @@ func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) {
ConfigModified: func() {}, ConfigModified: func() {},
} }
s := createTestServer(t, filterConf, forwardConf) s := createTestServer(t, filterConf, forwardConf)
err := s.Start() require.Nil(t, s.Start())
assert.Nil(t, err) t.Cleanup(func() {
defer assert.Nil(t, s.Stop()) require.Nil(t, s.Stop())
})
defaultConf := s.conf defaultConf := s.conf
@ -71,13 +73,14 @@ func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Cleanup(w.Body.Reset)
s.conf = tc.conf() s.conf = tc.conf()
s.handleGetConfig(w, nil) s.handleGetConfig(w, nil)
assert.Equal(t, tc.want, w.Body.String())
assert.Equal(t, "application/json", w.Header().Get("Content-Type")) assert.Equal(t, "application/json", w.Header().Get("Content-Type"))
assert.Equal(t, tc.want, w.Body.String())
}) })
w.Body.Reset()
} }
} }
@ -191,9 +194,13 @@ func TestDNSForwardHTTTP_handleSetConfig(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Cleanup(func() {
s.conf = defaultConf
})
rBody := ioutil.NopCloser(strings.NewReader(tc.req)) rBody := ioutil.NopCloser(strings.NewReader(tc.req))
r, err := http.NewRequest(http.MethodPost, "http://example.com", rBody) r, err := http.NewRequest(http.MethodPost, "http://example.com", rBody)
assert.Nil(t, err) require.Nil(t, err)
s.handleSetConfig(w, r) s.handleSetConfig(w, r)
assert.Equal(t, tc.wantSet, w.Body.String()) assert.Equal(t, tc.wantSet, w.Body.String())
@ -203,6 +210,5 @@ func TestDNSForwardHTTTP_handleSetConfig(t *testing.T) {
assert.Equal(t, tc.wantGet, w.Body.String()) assert.Equal(t, tc.wantGet, w.Body.String())
w.Body.Reset() w.Body.Reset()
}) })
s.conf = defaultConf
} }
} }

View File

@ -12,6 +12,7 @@ import (
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
// testQueryLog is a simple querylog.QueryLog implementation for tests. // testQueryLog is a simple querylog.QueryLog implementation for tests.
@ -156,7 +157,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
}} }}
ups, err := upstream.AddressToUpstream("1.1.1.1", upstream.Options{}) ups, err := upstream.AddressToUpstream("1.1.1.1", upstream.Options{})
assert.Nil(t, err) require.Nil(t, err)
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {

View File

@ -20,10 +20,18 @@ func TestMain(m *testing.M) {
aghtest.DiscardLogOutput(m) aghtest.DiscardLogOutput(m)
} }
func prepareTestDir() string { func prepareTestDir(t *testing.T) string {
t.Helper()
const dir = "./agh-test" const dir = "./agh-test"
_ = os.RemoveAll(dir)
_ = os.MkdirAll(dir, 0o755) require.Nil(t, os.RemoveAll(dir))
// TODO(e.burkov): Replace with testing.TempDir after updating Go
// version to 1.16.
require.Nil(t, os.MkdirAll(dir, 0o755))
t.Cleanup(func() { require.Nil(t, os.RemoveAll(dir)) })
return dir return dir
} }
@ -47,8 +55,7 @@ func TestNewSessionToken(t *testing.T) {
} }
func TestAuth(t *testing.T) { func TestAuth(t *testing.T) {
dir := prepareTestDir() dir := prepareTestDir(t)
t.Cleanup(func() { _ = os.RemoveAll(dir) })
fn := filepath.Join(dir, "sessions.db") fn := filepath.Join(dir, "sessions.db")
users := []User{{ users := []User{{
@ -123,8 +130,7 @@ func (w *testResponseWriter) WriteHeader(statusCode int) {
} }
func TestAuthHTTP(t *testing.T) { func TestAuthHTTP(t *testing.T) {
dir := prepareTestDir() dir := prepareTestDir(t)
defer func() { _ = os.RemoveAll(dir) }()
fn := filepath.Join(dir, "sessions.db") fn := filepath.Join(dir, "sessions.db")
users := []User{ users := []User{

View File

@ -4,40 +4,38 @@ import (
"encoding/binary" "encoding/binary"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"os"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestAuthGL(t *testing.T) { func TestAuthGL(t *testing.T) {
dir := prepareTestDir() dir := prepareTestDir(t)
defer func() { _ = os.RemoveAll(dir) }()
GLMode = true GLMode = true
t.Cleanup(func() {
GLMode = false
})
glFilePrefix = dir + "/gl_token_" glFilePrefix = dir + "/gl_token_"
tval := uint32(1) putFunc := binary.BigEndian.PutUint32
data := make([]byte, 4)
if archIsLittleEndian() { if archIsLittleEndian() {
binary.LittleEndian.PutUint32(data, tval) putFunc = binary.LittleEndian.PutUint32
} else {
binary.BigEndian.PutUint32(data, tval)
} }
assert.Nil(t, ioutil.WriteFile(glFilePrefix+"test", data, 0o644))
data := make([]byte, 4)
putFunc(data, 1)
require.Nil(t, ioutil.WriteFile(glFilePrefix+"test", data, 0o644))
assert.False(t, glCheckToken("test")) assert.False(t, glCheckToken("test"))
tval = uint32(time.Now().UTC().Unix() + 60)
data = make([]byte, 4) data = make([]byte, 4)
if archIsLittleEndian() { putFunc(data, uint32(time.Now().UTC().Unix()+60))
binary.LittleEndian.PutUint32(data, tval)
} else { require.Nil(t, ioutil.WriteFile(glFilePrefix+"test", data, 0o644))
binary.BigEndian.PutUint32(data, tval)
}
assert.Nil(t, ioutil.WriteFile(glFilePrefix+"test", data, 0o644))
r, _ := http.NewRequest(http.MethodGet, "http://localhost/", nil) r, _ := http.NewRequest(http.MethodGet, "http://localhost/", nil)
r.AddCookie(&http.Cookie{Name: glCookieName, Value: "test"}) r.AddCookie(&http.Cookie{Name: glCookieName, Value: "test"})
assert.True(t, glProcessCookie(r)) assert.True(t, glProcessCookie(r))
GLMode = false
} }

View File

@ -9,6 +9,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestClients(t *testing.T) { func TestClients(t *testing.T) {
@ -24,8 +25,8 @@ func TestClients(t *testing.T) {
} }
ok, err := clients.Add(c) ok, err := clients.Add(c)
require.Nil(t, err)
assert.True(t, ok) assert.True(t, ok)
assert.Nil(t, err)
c = &Client{ c = &Client{
IDs: []string{"2.2.2.2"}, IDs: []string{"2.2.2.2"},
@ -33,110 +34,99 @@ func TestClients(t *testing.T) {
} }
ok, err = clients.Add(c) ok, err = clients.Add(c)
require.Nil(t, err)
assert.True(t, ok) assert.True(t, ok)
assert.Nil(t, err)
c, ok = clients.Find("1.1.1.1") c, ok = clients.Find("1.1.1.1")
assert.True(t, ok) require.True(t, ok)
assert.Equal(t, "client1", c.Name) assert.Equal(t, "client1", c.Name)
c, ok = clients.Find("1:2:3::4") c, ok = clients.Find("1:2:3::4")
assert.True(t, ok) require.True(t, ok)
assert.Equal(t, "client1", c.Name) assert.Equal(t, "client1", c.Name)
c, ok = clients.Find("2.2.2.2") c, ok = clients.Find("2.2.2.2")
assert.True(t, ok) require.True(t, ok)
assert.Equal(t, "client2", c.Name) assert.Equal(t, "client2", c.Name)
assert.True(t, !clients.Exists("1.2.3.4", ClientSourceHostsFile)) assert.False(t, clients.Exists("1.2.3.4", ClientSourceHostsFile))
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile)) assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
assert.True(t, clients.Exists("2.2.2.2", ClientSourceHostsFile)) assert.True(t, clients.Exists("2.2.2.2", ClientSourceHostsFile))
}) })
t.Run("add_fail_name", func(t *testing.T) { t.Run("add_fail_name", func(t *testing.T) {
c := &Client{ ok, err := clients.Add(&Client{
IDs: []string{"1.2.3.5"}, IDs: []string{"1.2.3.5"},
Name: "client1", Name: "client1",
} })
require.Nil(t, err)
ok, err := clients.Add(c)
assert.False(t, ok) assert.False(t, ok)
assert.Nil(t, err)
}) })
t.Run("add_fail_ip", func(t *testing.T) { t.Run("add_fail_ip", func(t *testing.T) {
c := &Client{ ok, err := clients.Add(&Client{
IDs: []string{"2.2.2.2"}, IDs: []string{"2.2.2.2"},
Name: "client3", Name: "client3",
} })
require.NotNil(t, err)
ok, err := clients.Add(c)
assert.False(t, ok) assert.False(t, ok)
assert.NotNil(t, err)
}) })
t.Run("update_fail_name", func(t *testing.T) { t.Run("update_fail_name", func(t *testing.T) {
c := &Client{ err := clients.Update("client3", &Client{
IDs: []string{"1.2.3.0"}, IDs: []string{"1.2.3.0"},
Name: "client3", Name: "client3",
} })
require.NotNil(t, err)
err := clients.Update("client3", c) err = clients.Update("client3", &Client{
assert.NotNil(t, err)
c = &Client{
IDs: []string{"1.2.3.0"}, IDs: []string{"1.2.3.0"},
Name: "client2", Name: "client2",
} })
err = clients.Update("client3", c)
assert.NotNil(t, err) assert.NotNil(t, err)
}) })
t.Run("update_fail_ip", func(t *testing.T) { t.Run("update_fail_ip", func(t *testing.T) {
c := &Client{ err := clients.Update("client1", &Client{
IDs: []string{"2.2.2.2"}, IDs: []string{"2.2.2.2"},
Name: "client1", Name: "client1",
} })
err := clients.Update("client1", c)
assert.NotNil(t, err) assert.NotNil(t, err)
}) })
t.Run("update_success", func(t *testing.T) { t.Run("update_success", func(t *testing.T) {
c := &Client{ err := clients.Update("client1", &Client{
IDs: []string{"1.1.1.2"}, IDs: []string{"1.1.1.2"},
Name: "client1", Name: "client1",
} })
require.Nil(t, err)
err := clients.Update("client1", c) assert.False(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
assert.Nil(t, err)
assert.True(t, !clients.Exists("1.1.1.1", ClientSourceHostsFile))
assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile)) assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
c = &Client{ err = clients.Update("client1", &Client{
IDs: []string{"1.1.1.2"}, IDs: []string{"1.1.1.2"},
Name: "client1-renamed", Name: "client1-renamed",
UseOwnSettings: true, UseOwnSettings: true,
} })
require.Nil(t, err)
err = clients.Update("client1", c)
assert.Nil(t, err)
c, ok := clients.Find("1.1.1.2") c, ok := clients.Find("1.1.1.2")
assert.True(t, ok) require.True(t, ok)
assert.Equal(t, "client1-renamed", c.Name) assert.Equal(t, "client1-renamed", c.Name)
assert.True(t, c.UseOwnSettings) assert.True(t, c.UseOwnSettings)
assert.Nil(t, clients.list["client1"])
if assert.Len(t, c.IDs, 1) { nilCli, ok := clients.list["client1"]
assert.Equal(t, "1.1.1.2", c.IDs[0]) require.False(t, ok)
} assert.Nil(t, nilCli)
require.Len(t, c.IDs, 1)
assert.Equal(t, "1.1.1.2", c.IDs[0])
}) })
t.Run("del_success", func(t *testing.T) { t.Run("del_success", func(t *testing.T) {
ok := clients.Del("client1-renamed") ok := clients.Del("client1-renamed")
assert.True(t, ok) require.True(t, ok)
assert.False(t, clients.Exists("1.1.1.2", ClientSourceHostsFile)) assert.False(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
}) })
@ -147,146 +137,155 @@ func TestClients(t *testing.T) {
t.Run("addhost_success", func(t *testing.T) { t.Run("addhost_success", func(t *testing.T) {
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceARP) ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceARP)
require.Nil(t, err)
assert.True(t, ok) assert.True(t, ok)
assert.Nil(t, err)
ok, err = clients.AddHost("1.1.1.1", "host2", ClientSourceARP) ok, err = clients.AddHost("1.1.1.1", "host2", ClientSourceARP)
require.Nil(t, err)
assert.True(t, ok) assert.True(t, ok)
assert.Nil(t, err)
ok, err = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile) ok, err = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile)
require.Nil(t, err)
assert.True(t, ok) assert.True(t, ok)
assert.Nil(t, err)
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile)) assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
}) })
t.Run("addhost_fail", func(t *testing.T) { t.Run("addhost_fail", func(t *testing.T) {
ok, err := clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS) ok, err := clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS)
require.Nil(t, err)
assert.False(t, ok) assert.False(t, ok)
assert.Nil(t, err)
}) })
} }
func TestClientsWhois(t *testing.T) { func TestClientsWhois(t *testing.T) {
var c *Client clients := clientsContainer{
clients := clientsContainer{} testing: true,
clients.testing = true }
clients.Init(nil, nil, nil) clients.Init(nil, nil, nil)
whois := [][]string{{"orgname", "orgname-val"}, {"country", "country-val"}} whois := [][]string{{"orgname", "orgname-val"}, {"country", "country-val"}}
// set whois info on new client
clients.SetWhoisInfo("1.1.1.255", whois) t.Run("new_client", func(t *testing.T) {
if assert.NotNil(t, clients.ipHost["1.1.1.255"]) { clients.SetWhoisInfo("1.1.1.255", whois)
require.NotNil(t, clients.ipHost["1.1.1.255"])
h := clients.ipHost["1.1.1.255"] h := clients.ipHost["1.1.1.255"]
if assert.Len(t, h.WhoisInfo, 2) && assert.Len(t, h.WhoisInfo[0], 2) {
assert.Equal(t, "orgname-val", h.WhoisInfo[0][1])
}
}
// set whois info on existing auto-client require.Len(t, h.WhoisInfo, 2)
_, _ = clients.AddHost("1.1.1.1", "host", ClientSourceRDNS) require.Len(t, h.WhoisInfo[0], 2)
clients.SetWhoisInfo("1.1.1.1", whois) assert.Equal(t, "orgname-val", h.WhoisInfo[0][1])
if assert.NotNil(t, clients.ipHost["1.1.1.1"]) { })
t.Run("existing_auto-client", func(t *testing.T) {
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceRDNS)
require.Nil(t, err)
assert.True(t, ok)
clients.SetWhoisInfo("1.1.1.1", whois)
require.NotNil(t, clients.ipHost["1.1.1.1"])
h := clients.ipHost["1.1.1.1"] h := clients.ipHost["1.1.1.1"]
if assert.Len(t, h.WhoisInfo, 2) && assert.Len(t, h.WhoisInfo[0], 2) {
assert.Equal(t, "orgname-val", h.WhoisInfo[0][1])
}
}
// Check that we cannot set whois info on a manually-added client require.Len(t, h.WhoisInfo, 2)
c = &Client{ require.Len(t, h.WhoisInfo[0], 2)
IDs: []string{"1.1.1.2"}, assert.Equal(t, "orgname-val", h.WhoisInfo[0][1])
Name: "client1", })
}
_, _ = clients.Add(c) t.Run("can't_set_manually-added", func(t *testing.T) {
clients.SetWhoisInfo("1.1.1.2", whois) ok, err := clients.Add(&Client{
assert.Nil(t, clients.ipHost["1.1.1.2"]) IDs: []string{"1.1.1.2"},
_ = clients.Del("client1") Name: "client1",
})
require.Nil(t, err)
assert.True(t, ok)
clients.SetWhoisInfo("1.1.1.2", whois)
require.Nil(t, clients.ipHost["1.1.1.2"])
assert.True(t, clients.Del("client1"))
})
} }
func TestClientsAddExisting(t *testing.T) { func TestClientsAddExisting(t *testing.T) {
var c *Client clients := clientsContainer{
clients := clientsContainer{} testing: true,
clients.testing = true }
clients.Init(nil, nil, nil) clients.Init(nil, nil, nil)
// some test variables t.Run("simple", func(t *testing.T) {
mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa") // Add a client.
testIP := "1.2.3.4" ok, err := clients.Add(&Client{
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"},
Name: "client1",
})
require.Nil(t, err)
assert.True(t, ok)
// add a client // Now add an auto-client with the same IP.
c = &Client{ ok, err = clients.AddHost("1.1.1.1", "test", ClientSourceRDNS)
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"}, require.Nil(t, err)
Name: "client1", assert.True(t, ok)
}
ok, err := clients.Add(c)
assert.True(t, ok)
assert.Nil(t, err)
// add an auto-client with the same IP - it's allowed
ok, err = clients.AddHost("1.1.1.1", "test", ClientSourceRDNS)
assert.True(t, ok)
assert.Nil(t, err)
// now some more complicated stuff
// first, init a DHCP server with a single static lease
config := dhcpd.ServerConfig{
DBFilePath: "leases.db",
}
defer func() { _ = os.Remove("leases.db") }()
clients.dhcpServer = dhcpd.Create(config)
err = clients.dhcpServer.AddStaticLease(dhcpd.Lease{
HWAddr: mac,
IP: net.ParseIP(testIP).To4(),
Hostname: "testhost",
Expiry: time.Now().Add(time.Hour),
}) })
assert.Nil(t, err)
// add a new client with the same IP as for a client with MAC t.Run("complicated", func(t *testing.T) {
c = &Client{ testIP := net.IP{1, 2, 3, 4}
IDs: []string{testIP},
Name: "client2",
}
ok, err = clients.Add(c)
assert.True(t, ok)
assert.Nil(t, err)
// add a new client with the IP from the client1's IP range // First, init a DHCP server with a single static lease.
c = &Client{ config := dhcpd.ServerConfig{
IDs: []string{"2.2.2.2"}, DBFilePath: "leases.db",
Name: "client3", }
} clients.dhcpServer = dhcpd.Create(config)
ok, err = clients.Add(c) t.Cleanup(func() { _ = os.Remove("leases.db") })
assert.True(t, ok)
assert.Nil(t, err) err := clients.dhcpServer.AddStaticLease(dhcpd.Lease{
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: testIP,
Hostname: "testhost",
Expiry: time.Now().Add(time.Hour),
})
require.Nil(t, err)
// Add a new client with the same IP as for a client with MAC.
ok, err := clients.Add(&Client{
IDs: []string{testIP.String()},
Name: "client2",
})
require.Nil(t, err)
assert.True(t, ok)
// Add a new client with the IP from the first client's IP
// range.
ok, err = clients.Add(&Client{
IDs: []string{"2.2.2.2"},
Name: "client3",
})
require.Nil(t, err)
assert.True(t, ok)
})
} }
func TestClientsCustomUpstream(t *testing.T) { func TestClientsCustomUpstream(t *testing.T) {
clients := clientsContainer{} clients := clientsContainer{
clients.testing = true testing: true,
}
clients.Init(nil, nil, nil) clients.Init(nil, nil, nil)
// add client with upstreams // Add client with upstreams.
c := &Client{ ok, err := clients.Add(&Client{
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"}, IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
Name: "client1", Name: "client1",
Upstreams: []string{ Upstreams: []string{
"1.1.1.1", "1.1.1.1",
"[/example.org/]8.8.8.8", "[/example.org/]8.8.8.8",
}, },
} })
ok, err := clients.Add(c) require.Nil(t, err)
assert.Nil(t, err)
assert.True(t, ok) assert.True(t, ok)
config := clients.FindUpstreams("1.2.3.4") config := clients.FindUpstreams("1.2.3.4")
assert.Nil(t, config) assert.Nil(t, config)
config = clients.FindUpstreams("1.1.1.1") config = clients.FindUpstreams("1.1.1.1")
assert.NotNil(t, config) require.NotNil(t, config)
assert.Equal(t, 1, len(config.Upstreams)) assert.Len(t, config.Upstreams, 1)
assert.Equal(t, 1, len(config.DomainReservedUpstreams)) assert.Len(t, config.DomainReservedUpstreams, 1)
} }

View File

@ -3,32 +3,12 @@ package home
import ( import (
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
) )
/* Tests performed: const (
. Bad certificate CertificateChain = `-----BEGIN CERTIFICATE-----
. Bad private key
. Valid certificate & private key */
func TestValidateCertificates(t *testing.T) {
var data tlsConfigStatus
// bad cert
data = validateCertificates("bad cert", "", "")
if !(data.WarningValidation != "" &&
!data.ValidCert &&
!data.ValidChain) {
t.Fatalf("bad cert: validateCertificates(): %v", data)
}
// bad priv key
data = validateCertificates("", "bad priv key", "")
if !(data.WarningValidation != "" &&
!data.ValidKey) {
t.Fatalf("bad priv key: validateCertificates(): %v", data)
}
// valid cert & priv key
CertificateChain := `-----BEGIN CERTIFICATE-----
MIICKzCCAZSgAwIBAgIJAMT9kPVJdM7LMA0GCSqGSIb3DQEBCwUAMC0xFDASBgNV MIICKzCCAZSgAwIBAgIJAMT9kPVJdM7LMA0GCSqGSIb3DQEBCwUAMC0xFDASBgNV
BAoMC0FkR3VhcmQgTHRkMRUwEwYDVQQDDAxBZEd1YXJkIEhvbWUwHhcNMTkwMjI3 BAoMC0FkR3VhcmQgTHRkMRUwEwYDVQQDDAxBZEd1YXJkIEhvbWUwHhcNMTkwMjI3
MDkyNDIzWhcNNDYwNzE0MDkyNDIzWjAtMRQwEgYDVQQKDAtBZEd1YXJkIEx0ZDEV MDkyNDIzWhcNNDYwNzE0MDkyNDIzWjAtMRQwEgYDVQQKDAtBZEd1YXJkIEx0ZDEV
@ -42,7 +22,7 @@ LwlXfbakf7qkVTlCNXgoY7RaJ8rJdPgOZPoCTVToEhT6u/cb1c2qp8QB0dNExDna
b0Z+dnODTZqQOJo6z/wIXlcUrnR4cQVvytXt8lFn+26l6Y6EMI26twC/xWr+1swq b0Z+dnODTZqQOJo6z/wIXlcUrnR4cQVvytXt8lFn+26l6Y6EMI26twC/xWr+1swq
Muj4FeWHVDerquH4yMr1jsYLD3ci+kc5sbIX6TfVxQ== Muj4FeWHVDerquH4yMr1jsYLD3ci+kc5sbIX6TfVxQ==
-----END CERTIFICATE-----` -----END CERTIFICATE-----`
PrivateKey := `-----BEGIN PRIVATE KEY----- PrivateKey = `-----BEGIN PRIVATE KEY-----
MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBALC/BSc8mI68tw5p MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBALC/BSc8mI68tw5p
aYa7pjrySwWvXeetcFywOWHGVfLw9qiFWLdfESa3Y6tWMpZAXD9t1Xh9n211YUBV aYa7pjrySwWvXeetcFywOWHGVfLw9qiFWLdfESa3Y6tWMpZAXD9t1Xh9n211YUBV
FGSB4ZshnM/tgEPU6t787lJD4NsIIRp++MkJxdAitN4oUTqL0bdpIwezQ/CrYuBX FGSB4ZshnM/tgEPU6t787lJD4NsIIRp++MkJxdAitN4oUTqL0bdpIwezQ/CrYuBX
@ -58,20 +38,35 @@ O5EX70gpeGQMPDK0QSWpaazg956njJSDbNCFM4BccrdQbJu1cW4qOsfBAkAMgZuG
O88slmgTRHX4JGFmy3rrLiHNI2BbJSuJ++Yllz8beVzh6NfvuY+HKRCmPqoBPATU O88slmgTRHX4JGFmy3rrLiHNI2BbJSuJ++Yllz8beVzh6NfvuY+HKRCmPqoBPATU
kXS9jgARhhiWXJrk kXS9jgARhhiWXJrk
-----END PRIVATE KEY-----` -----END PRIVATE KEY-----`
data = validateCertificates(CertificateChain, PrivateKey, "") )
notBefore, _ := time.Parse(time.RFC3339, "2019-02-27T09:24:23Z")
notAfter, _ := time.Parse(time.RFC3339, "2046-07-14T09:24:23Z") func TestValidateCertificates(t *testing.T) {
if !(data.WarningValidation != "" /* self signed */ && t.Run("bad_certificate", func(t *testing.T) {
data.ValidCert && data := validateCertificates("bad cert", "", "")
!data.ValidChain && assert.NotEmpty(t, data.WarningValidation)
data.ValidKey && assert.False(t, data.ValidCert)
data.KeyType == "RSA" && assert.False(t, data.ValidChain)
data.Subject == "CN=AdGuard Home,O=AdGuard Ltd" && })
data.Issuer == "CN=AdGuard Home,O=AdGuard Ltd" &&
data.NotBefore.Equal(notBefore) && t.Run("bad_private_key", func(t *testing.T) {
data.NotAfter.Equal(notAfter) && data := validateCertificates("", "bad priv key", "")
// data.DNSNames[0] == && assert.NotEmpty(t, data.WarningValidation)
data.ValidPair) { assert.False(t, data.ValidKey)
t.Fatalf("valid cert & priv key: validateCertificates(): %v", data) })
}
t.Run("valid", func(t *testing.T) {
data := validateCertificates(CertificateChain, PrivateKey, "")
notBefore, _ := time.Parse(time.RFC3339, "2019-02-27T09:24:23Z")
notAfter, _ := time.Parse(time.RFC3339, "2046-07-14T09:24:23Z")
assert.NotEmpty(t, data.WarningValidation)
assert.True(t, data.ValidCert)
assert.False(t, data.ValidChain)
assert.True(t, data.ValidKey)
assert.Equal(t, "RSA", data.KeyType)
assert.Equal(t, "CN=AdGuard Home,O=AdGuard Ltd", data.Subject)
assert.Equal(t, "CN=AdGuard Home,O=AdGuard Ltd", data.Issuer)
assert.Equal(t, notBefore, data.NotBefore)
assert.Equal(t, notAfter, data.NotAfter)
assert.True(t, data.ValidPair)
})
} }

View File

@ -9,38 +9,47 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func testStartFilterListener() net.Listener { func testStartFilterListener(t *testing.T) net.Listener {
t.Helper()
const content = `||example.org^$third-party
# Inline comment example
||example.com^$third-party
0.0.0.0 example.com
`
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("/filters/1.txt", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/filters/1.txt", func(w http.ResponseWriter, r *http.Request) {
content := `||example.org^$third-party _, werr := w.Write([]byte(content))
# Inline comment example assert.Nil(t, werr)
||example.com^$third-party
0.0.0.0 example.com
`
_, _ = w.Write([]byte(content))
}) })
listener, err := net.Listen("tcp", ":0") listener, err := net.Listen("tcp", ":0")
if err != nil { require.Nil(t, err)
panic(err)
} go func() {
_ = http.Serve(listener, mux)
}()
t.Cleanup(func() {
assert.Nil(t, listener.Close())
})
go func() { _ = http.Serve(listener, mux) }()
return listener return listener
} }
func TestFilters(t *testing.T) { func TestFilters(t *testing.T) {
l := testStartFilterListener() l := testStartFilterListener(t)
defer func() { _ = l.Close() }() dir := prepareTestDir(t)
dir := prepareTestDir() Context = homeContext{
defer func() { _ = os.RemoveAll(dir) }() workDir: dir,
Context = homeContext{} client: &http.Client{
Context.workDir = dir Timeout: 5 * time.Second,
Context.client = &http.Client{ },
Timeout: 5 * time.Second,
} }
Context.filters.Init() Context.filters.Init()
@ -48,20 +57,20 @@ func TestFilters(t *testing.T) {
URL: fmt.Sprintf("http://127.0.0.1:%d/filters/1.txt", l.Addr().(*net.TCPAddr).Port), URL: fmt.Sprintf("http://127.0.0.1:%d/filters/1.txt", l.Addr().(*net.TCPAddr).Port),
} }
// download // Download.
ok, err := Context.filters.update(&f) ok, err := Context.filters.update(&f)
assert.Nil(t, err) require.Nil(t, err)
assert.True(t, ok) require.True(t, ok)
assert.Equal(t, 3, f.RulesCount) assert.Equal(t, 3, f.RulesCount)
// refresh // Refresh.
ok, err = Context.filters.update(&f) ok, err = Context.filters.update(&f)
assert.False(t, ok) require.Nil(t, err)
assert.Nil(t, err) require.False(t, ok)
err = Context.filters.load(&f) err = Context.filters.load(&f)
assert.Nil(t, err) require.Nil(t, err)
f.unload() f.unload()
_ = os.Remove(f.Path()) require.Nil(t, os.Remove(f.Path()))
} }

View File

@ -114,8 +114,7 @@ func TestHome(t *testing.T) {
// Init new context // Init new context
Context = homeContext{} Context = homeContext{}
dir := prepareTestDir() dir := prepareTestDir(t)
defer func() { _ = os.RemoveAll(dir) }()
fn := filepath.Join(dir, "AdGuardHome.yaml") fn := filepath.Join(dir, "AdGuardHome.yaml")
// Prepare the test config // Prepare the test config

View File

@ -39,21 +39,21 @@ func TestLimitRequestBody(t *testing.T) {
wantErr: nil, wantErr: nil,
}} }}
makeHandler := func(err *error) http.HandlerFunc { makeHandler := func(t *testing.T, err *error) http.HandlerFunc {
t.Helper()
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var b []byte var b []byte
b, *err = ioutil.ReadAll(r.Body) b, *err = ioutil.ReadAll(r.Body)
_, werr := w.Write(b) _, werr := w.Write(b)
if werr != nil { require.Nil(t, werr)
panic(werr)
}
}) })
} }
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
var err error var err error
handler := makeHandler(&err) handler := makeHandler(t, &err)
lim := limitRequestBody(handler) lim := limitRequestBody(handler)
req := httptest.NewRequest(http.MethodPost, "https://www.example.com", strings.NewReader(tc.body)) req := httptest.NewRequest(http.MethodPost, "https://www.example.com", strings.NewReader(tc.body))
@ -61,7 +61,7 @@ func TestLimitRequestBody(t *testing.T) {
lim.ServeHTTP(res, req) lim.ServeHTTP(res, req)
require.Equal(t, tc.wantErr, err) assert.Equal(t, tc.wantErr, err)
assert.Equal(t, tc.want, res.Body.Bytes()) assert.Equal(t, tc.want, res.Body.Bytes())
}) })
} }

View File

@ -6,29 +6,29 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"howett.net/plist" "howett.net/plist"
) )
func TestHandleMobileConfigDOH(t *testing.T) { func TestHandleMobileConfigDOH(t *testing.T) {
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig?host=example.org", nil) r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig?host=example.org", nil)
assert.Nil(t, err) require.Nil(t, err)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handleMobileConfigDOH(w, r) handleMobileConfigDOH(w, r)
assert.Equal(t, http.StatusOK, w.Code) require.Equal(t, http.StatusOK, w.Code)
var mc mobileConfig var mc mobileConfig
_, err = plist.Unmarshal(w.Body.Bytes(), &mc) _, err = plist.Unmarshal(w.Body.Bytes(), &mc)
assert.Nil(t, err) require.Nil(t, err)
if assert.Len(t, mc.PayloadContent, 1) { require.Len(t, mc.PayloadContent, 1)
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name) assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName) assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName) assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
assert.Equal(t, "https://example.org/dns-query", mc.PayloadContent[0].DNSSettings.ServerURL) assert.Equal(t, "https://example.org/dns-query", mc.PayloadContent[0].DNSSettings.ServerURL)
}
}) })
t.Run("success_no_host", func(t *testing.T) { t.Run("success_no_host", func(t *testing.T) {
@ -40,23 +40,22 @@ func TestHandleMobileConfigDOH(t *testing.T) {
} }
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig", nil) r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig", nil)
assert.Nil(t, err) require.Nil(t, err)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handleMobileConfigDOH(w, r) handleMobileConfigDOH(w, r)
assert.Equal(t, http.StatusOK, w.Code) require.Equal(t, http.StatusOK, w.Code)
var mc mobileConfig var mc mobileConfig
_, err = plist.Unmarshal(w.Body.Bytes(), &mc) _, err = plist.Unmarshal(w.Body.Bytes(), &mc)
assert.Nil(t, err) require.Nil(t, err)
if assert.Len(t, mc.PayloadContent, 1) { require.Len(t, mc.PayloadContent, 1)
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name) assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName) assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName) assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
assert.Equal(t, "https://example.org/dns-query", mc.PayloadContent[0].DNSSettings.ServerURL) assert.Equal(t, "https://example.org/dns-query", mc.PayloadContent[0].DNSSettings.ServerURL)
}
}) })
t.Run("error_no_host", func(t *testing.T) { t.Run("error_no_host", func(t *testing.T) {
@ -66,7 +65,7 @@ func TestHandleMobileConfigDOH(t *testing.T) {
Context.tls = &TLSMod{conf: tlsConfigSettings{}} Context.tls = &TLSMod{conf: tlsConfigSettings{}}
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig", nil) r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig", nil)
assert.Nil(t, err) require.Nil(t, err)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -76,45 +75,43 @@ func TestHandleMobileConfigDOH(t *testing.T) {
t.Run("client_id", func(t *testing.T) { t.Run("client_id", func(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig?host=example.org&client_id=cli42", nil) r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig?host=example.org&client_id=cli42", nil)
assert.Nil(t, err) require.Nil(t, err)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handleMobileConfigDOH(w, r) handleMobileConfigDOH(w, r)
assert.Equal(t, http.StatusOK, w.Code) require.Equal(t, http.StatusOK, w.Code)
var mc mobileConfig var mc mobileConfig
_, err = plist.Unmarshal(w.Body.Bytes(), &mc) _, err = plist.Unmarshal(w.Body.Bytes(), &mc)
assert.Nil(t, err) require.Nil(t, err)
if assert.Len(t, mc.PayloadContent, 1) { require.Len(t, mc.PayloadContent, 1)
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name) assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName) assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName) assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
assert.Equal(t, "https://example.org/dns-query/cli42", mc.PayloadContent[0].DNSSettings.ServerURL) assert.Equal(t, "https://example.org/dns-query/cli42", mc.PayloadContent[0].DNSSettings.ServerURL)
}
}) })
} }
func TestHandleMobileConfigDOT(t *testing.T) { func TestHandleMobileConfigDOT(t *testing.T) {
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig?host=example.org", nil) r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig?host=example.org", nil)
assert.Nil(t, err) require.Nil(t, err)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handleMobileConfigDOT(w, r) handleMobileConfigDOT(w, r)
assert.Equal(t, http.StatusOK, w.Code) require.Equal(t, http.StatusOK, w.Code)
var mc mobileConfig var mc mobileConfig
_, err = plist.Unmarshal(w.Body.Bytes(), &mc) _, err = plist.Unmarshal(w.Body.Bytes(), &mc)
assert.Nil(t, err) require.Nil(t, err)
if assert.Len(t, mc.PayloadContent, 1) { require.Len(t, mc.PayloadContent, 1)
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name) assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName) assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName) assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
}
}) })
t.Run("success_no_host", func(t *testing.T) { t.Run("success_no_host", func(t *testing.T) {
@ -126,22 +123,21 @@ func TestHandleMobileConfigDOT(t *testing.T) {
} }
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil) r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil)
assert.Nil(t, err) require.Nil(t, err)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handleMobileConfigDOT(w, r) handleMobileConfigDOT(w, r)
assert.Equal(t, http.StatusOK, w.Code) require.Equal(t, http.StatusOK, w.Code)
var mc mobileConfig var mc mobileConfig
_, err = plist.Unmarshal(w.Body.Bytes(), &mc) _, err = plist.Unmarshal(w.Body.Bytes(), &mc)
assert.Nil(t, err) require.Nil(t, err)
if assert.Len(t, mc.PayloadContent, 1) { require.Len(t, mc.PayloadContent, 1)
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name) assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName) assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName) assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
}
}) })
t.Run("error_no_host", func(t *testing.T) { t.Run("error_no_host", func(t *testing.T) {
@ -151,7 +147,7 @@ func TestHandleMobileConfigDOT(t *testing.T) {
Context.tls = &TLSMod{conf: tlsConfigSettings{}} Context.tls = &TLSMod{conf: tlsConfigSettings{}}
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil) r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil)
assert.Nil(t, err) require.Nil(t, err)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -161,21 +157,20 @@ func TestHandleMobileConfigDOT(t *testing.T) {
t.Run("client_id", func(t *testing.T) { t.Run("client_id", func(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig?host=example.org&client_id=cli42", nil) r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig?host=example.org&client_id=cli42", nil)
assert.Nil(t, err) require.Nil(t, err)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handleMobileConfigDOT(w, r) handleMobileConfigDOT(w, r)
assert.Equal(t, http.StatusOK, w.Code) require.Equal(t, http.StatusOK, w.Code)
var mc mobileConfig var mc mobileConfig
_, err = plist.Unmarshal(w.Body.Bytes(), &mc) _, err = plist.Unmarshal(w.Body.Bytes(), &mc)
assert.Nil(t, err) require.Nil(t, err)
if assert.Len(t, mc.PayloadContent, 1) { require.Len(t, mc.PayloadContent, 1)
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name) assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName) assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "cli42.example.org", mc.PayloadContent[0].DNSSettings.ServerName) assert.Equal(t, "cli42.example.org", mc.PayloadContent[0].DNSSettings.ServerName)
}
}) })
} }

View File

@ -4,96 +4,74 @@ import (
"fmt" "fmt"
"net" "net"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func testParseOk(t *testing.T, ss ...string) options { func testParseOK(t *testing.T, ss ...string) options {
t.Helper()
o, _, err := parse("", ss) o, _, err := parse("", ss)
if err != nil { require.Nil(t, err)
t.Fatal(err.Error())
}
return o return o
} }
func testParseErr(t *testing.T, descr string, ss ...string) { func testParseErr(t *testing.T, descr string, ss ...string) {
t.Helper()
_, _, err := parse("", ss) _, _, err := parse("", ss)
if err == nil { require.NotNilf(t, err, "expected an error because %s but no error returned", descr)
t.Fatalf("expected an error because %s but no error returned", descr)
}
} }
func testParseParamMissing(t *testing.T, param string) { func testParseParamMissing(t *testing.T, param string) {
t.Helper()
testParseErr(t, fmt.Sprintf("%s parameter missing", param), param) testParseErr(t, fmt.Sprintf("%s parameter missing", param), param)
} }
func TestParseVerbose(t *testing.T) { func TestParseVerbose(t *testing.T) {
if testParseOk(t).verbose { assert.False(t, testParseOK(t).verbose, "empty is not verbose")
t.Fatal("empty is not verbose") assert.True(t, testParseOK(t, "-v").verbose, "-v is verbose")
} assert.True(t, testParseOK(t, "--verbose").verbose, "--verbose is verbose")
if !testParseOk(t, "-v").verbose {
t.Fatal("-v is verbose")
}
if !testParseOk(t, "--verbose").verbose {
t.Fatal("--verbose is verbose")
}
} }
func TestParseConfigFilename(t *testing.T) { func TestParseConfigFilename(t *testing.T) {
if testParseOk(t).configFilename != "" { assert.Equal(t, "", testParseOK(t).configFilename, "empty is no config filename")
t.Fatal("empty is no config filename") assert.Equal(t, "path", testParseOK(t, "-c", "path").configFilename, "-c is config filename")
}
if testParseOk(t, "-c", "path").configFilename != "path" {
t.Fatal("-c is config filename")
}
testParseParamMissing(t, "-c") testParseParamMissing(t, "-c")
if testParseOk(t, "--config", "path").configFilename != "path" {
t.Fatal("--configFilename is config filename") assert.Equal(t, "path", testParseOK(t, "--config", "path").configFilename, "--config is config filename")
}
testParseParamMissing(t, "--config") testParseParamMissing(t, "--config")
} }
func TestParseWorkDir(t *testing.T) { func TestParseWorkDir(t *testing.T) {
if testParseOk(t).workDir != "" { assert.Equal(t, "", testParseOK(t).workDir, "empty is no work dir")
t.Fatal("empty is no work dir") assert.Equal(t, "path", testParseOK(t, "-w", "path").workDir, "-w is work dir")
}
if testParseOk(t, "-w", "path").workDir != "path" {
t.Fatal("-w is work dir")
}
testParseParamMissing(t, "-w") testParseParamMissing(t, "-w")
if testParseOk(t, "--work-dir", "path").workDir != "path" {
t.Fatal("--work-dir is work dir") assert.Equal(t, "path", testParseOK(t, "--work-dir", "path").workDir, "--work-dir is work dir")
}
testParseParamMissing(t, "--work-dir") testParseParamMissing(t, "--work-dir")
} }
func TestParseBindHost(t *testing.T) { func TestParseBindHost(t *testing.T) {
if testParseOk(t).bindHost != nil { assert.Nil(t, testParseOK(t).bindHost, "empty is not host")
t.Fatal("empty is no host") assert.Equal(t, net.IPv4(1, 2, 3, 4), testParseOK(t, "-h", "1.2.3.4").bindHost, "-h is host")
}
if !testParseOk(t, "-h", "1.2.3.4").bindHost.Equal(net.IP{1, 2, 3, 4}) {
t.Fatal("-h is host")
}
testParseParamMissing(t, "-h") testParseParamMissing(t, "-h")
if !testParseOk(t, "--host", "1.2.3.4").bindHost.Equal(net.IP{1, 2, 3, 4}) {
t.Fatal("--host is host") assert.Equal(t, net.IPv4(1, 2, 3, 4), testParseOK(t, "--host", "1.2.3.4").bindHost, "--host is host")
}
testParseParamMissing(t, "--host") testParseParamMissing(t, "--host")
} }
func TestParseBindPort(t *testing.T) { func TestParseBindPort(t *testing.T) {
if testParseOk(t).bindPort != 0 { assert.Equal(t, 0, testParseOK(t).bindPort, "empty is port 0")
t.Fatal("empty is port 0") assert.Equal(t, 65535, testParseOK(t, "-p", "65535").bindPort, "-p is port")
}
if testParseOk(t, "-p", "65535").bindPort != 65535 {
t.Fatal("-p is port")
}
testParseParamMissing(t, "-p") testParseParamMissing(t, "-p")
if testParseOk(t, "--port", "65535").bindPort != 65535 {
t.Fatal("--port is port")
}
testParseParamMissing(t, "--port")
}
func TestParseBindPortBad(t *testing.T) { assert.Equal(t, 65535, testParseOK(t, "--port", "65535").bindPort, "--port is port")
testParseParamMissing(t, "--port")
testParseErr(t, "not an int", "-p", "x") testParseErr(t, "not an int", "-p", "x")
testParseErr(t, "hex not supported", "-p", "0x100") testParseErr(t, "hex not supported", "-p", "0x100")
testParseErr(t, "port negative", "-p", "-1") testParseErr(t, "port negative", "-p", "-1")
@ -103,72 +81,40 @@ func TestParseBindPortBad(t *testing.T) {
} }
func TestParseLogfile(t *testing.T) { func TestParseLogfile(t *testing.T) {
if testParseOk(t).logFile != "" { assert.Equal(t, "", testParseOK(t).logFile, "empty is no log file")
t.Fatal("empty is no log file") assert.Equal(t, "path", testParseOK(t, "-l", "path").logFile, "-l is log file")
} assert.Equal(t, "path", testParseOK(t, "--logfile", "path").logFile, "--logfile is log file")
if testParseOk(t, "-l", "path").logFile != "path" {
t.Fatal("-l is log file")
}
if testParseOk(t, "--logfile", "path").logFile != "path" {
t.Fatal("--logfile is log file")
}
} }
func TestParsePidfile(t *testing.T) { func TestParsePidfile(t *testing.T) {
if testParseOk(t).pidFile != "" { assert.Equal(t, "", testParseOK(t).pidFile, "empty is no pid file")
t.Fatal("empty is no pid file") assert.Equal(t, "path", testParseOK(t, "--pidfile", "path").pidFile, "--pidfile is pid file")
}
if testParseOk(t, "--pidfile", "path").pidFile != "path" {
t.Fatal("--pidfile is pid file")
}
} }
func TestParseCheckConfig(t *testing.T) { func TestParseCheckConfig(t *testing.T) {
if testParseOk(t).checkConfig { assert.False(t, testParseOK(t).checkConfig, "empty is not check config")
t.Fatal("empty is not check config") assert.True(t, testParseOK(t, "--check-config").checkConfig, "--check-config is check config")
}
if !testParseOk(t, "--check-config").checkConfig {
t.Fatal("--check-config is check config")
}
} }
func TestParseDisableUpdate(t *testing.T) { func TestParseDisableUpdate(t *testing.T) {
if testParseOk(t).disableUpdate { assert.False(t, testParseOK(t).disableUpdate, "empty is not disable update")
t.Fatal("empty is not disable update") assert.True(t, testParseOK(t, "--no-check-update").disableUpdate, "--no-check-update is disable update")
}
if !testParseOk(t, "--no-check-update").disableUpdate {
t.Fatal("--no-check-update is disable update")
}
} }
func TestParseDisableMemoryOptimization(t *testing.T) { func TestParseDisableMemoryOptimization(t *testing.T) {
if testParseOk(t).disableMemoryOptimization { assert.False(t, testParseOK(t).disableMemoryOptimization, "empty is not disable update")
t.Fatal("empty is not disable update") assert.True(t, testParseOK(t, "--no-mem-optimization").disableMemoryOptimization, "--no-mem-optimization is disable update")
}
if !testParseOk(t, "--no-mem-optimization").disableMemoryOptimization {
t.Fatal("--no-mem-optimization is disable update")
}
} }
func TestParseService(t *testing.T) { func TestParseService(t *testing.T) {
if testParseOk(t).serviceControlAction != "" { assert.Equal(t, "", testParseOK(t).serviceControlAction, "empty is not service cmd")
t.Fatal("empty is no service command") assert.Equal(t, "cmd", testParseOK(t, "-s", "cmd").serviceControlAction, "-s is service cmd")
} assert.Equal(t, "cmd", testParseOK(t, "--service", "cmd").serviceControlAction, "--service is service cmd")
if testParseOk(t, "-s", "command").serviceControlAction != "command" {
t.Fatal("-s is service command")
}
if testParseOk(t, "--service", "command").serviceControlAction != "command" {
t.Fatal("--service is service command")
}
} }
func TestParseGLInet(t *testing.T) { func TestParseGLInet(t *testing.T) {
if testParseOk(t).glinetMode { assert.False(t, testParseOK(t).glinetMode, "empty is not GL-Inet mode")
t.Fatal("empty is not GL-Inet mode") assert.True(t, testParseOK(t, "--glinet").glinetMode, "--glinet is GL-Inet mode")
}
if !testParseOk(t, "--glinet").glinetMode {
t.Fatal("--glinet is GL-Inet mode")
}
} }
func TestParseUnknown(t *testing.T) { func TestParseUnknown(t *testing.T) {
@ -180,73 +126,85 @@ func TestParseUnknown(t *testing.T) {
testParseErr(t, "unknown dash", "-") testParseErr(t, "unknown dash", "-")
} }
func testSerialize(t *testing.T, o options, ss ...string) { func TestSerialize(t *testing.T) {
result := serialize(o) const reportFmt = "expected %s but got %s"
if len(result) != len(ss) {
t.Fatalf("expected %s but got %s", ss, result) testCases := []struct {
} name string
for i, r := range result { opts options
if r != ss[i] { ss []string
t.Fatalf("expected %s but got %s", ss, result) }{{
} name: "empty",
opts: options{},
ss: []string{},
}, {
name: "config_filename",
opts: options{configFilename: "path"},
ss: []string{"-c", "path"},
}, {
name: "work_dir",
opts: options{workDir: "path"},
ss: []string{"-w", "path"},
}, {
name: "bind_host",
opts: options{bindHost: net.IP{1, 2, 3, 4}},
ss: []string{"-h", "1.2.3.4"},
}, {
name: "bind_port",
opts: options{bindPort: 666},
ss: []string{"-p", "666"},
}, {
name: "log_file",
opts: options{logFile: "path"},
ss: []string{"-l", "path"},
}, {
name: "pid_file",
opts: options{pidFile: "path"},
ss: []string{"--pidfile", "path"},
}, {
name: "disable_update",
opts: options{disableUpdate: true},
ss: []string{"--no-check-update"},
}, {
name: "control_action",
opts: options{serviceControlAction: "run"},
ss: []string{"-s", "run"},
}, {
name: "glinet_mode",
opts: options{glinetMode: true},
ss: []string{"--glinet"},
}, {
name: "disable_mem_opt",
opts: options{disableMemoryOptimization: true},
ss: []string{"--no-mem-optimization"},
}, {
name: "multiple",
opts: options{
serviceControlAction: "run",
configFilename: "config",
workDir: "work",
pidFile: "pid",
disableUpdate: true,
disableMemoryOptimization: true,
},
ss: []string{
"-c", "config",
"-w", "work",
"-s", "run",
"--pidfile", "pid",
"--no-check-update",
"--no-mem-optimization",
},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := serialize(tc.opts)
require.Lenf(t, result, len(tc.ss), reportFmt, tc.ss, result)
for i, r := range result {
assert.Equalf(t, tc.ss[i], r, reportFmt, tc.ss, result)
}
})
} }
} }
func TestSerializeEmpty(t *testing.T) {
testSerialize(t, options{})
}
func TestSerializeConfigFilename(t *testing.T) {
testSerialize(t, options{configFilename: "path"}, "-c", "path")
}
func TestSerializeWorkDir(t *testing.T) {
testSerialize(t, options{workDir: "path"}, "-w", "path")
}
func TestSerializeBindHost(t *testing.T) {
testSerialize(t, options{bindHost: net.IP{1, 2, 3, 4}}, "-h", "1.2.3.4")
}
func TestSerializeBindPort(t *testing.T) {
testSerialize(t, options{bindPort: 666}, "-p", "666")
}
func TestSerializeLogfile(t *testing.T) {
testSerialize(t, options{logFile: "path"}, "-l", "path")
}
func TestSerializePidfile(t *testing.T) {
testSerialize(t, options{pidFile: "path"}, "--pidfile", "path")
}
func TestSerializeCheckConfig(t *testing.T) {
testSerialize(t, options{checkConfig: true}, "--check-config")
}
func TestSerializeDisableUpdate(t *testing.T) {
testSerialize(t, options{disableUpdate: true}, "--no-check-update")
}
func TestSerializeService(t *testing.T) {
testSerialize(t, options{serviceControlAction: "run"}, "-s", "run")
}
func TestSerializeGLInet(t *testing.T) {
testSerialize(t, options{glinetMode: true}, "--glinet")
}
func TestSerializeDisableMemoryOptimization(t *testing.T) {
testSerialize(t, options{disableMemoryOptimization: true}, "--no-mem-optimization")
}
func TestSerializeMultiple(t *testing.T) {
testSerialize(t, options{
serviceControlAction: "run",
configFilename: "config",
workDir: "work",
pidFile: "pid",
disableUpdate: true,
disableMemoryOptimization: true,
}, "-c", "config", "-w", "work", "-s", "run", "--pidfile", "pid", "--no-check-update", "--no-mem-optimization")
}

View File

@ -20,6 +20,7 @@ func newTestQLogReader(t *testing.T, filesNum, linesNum int) (reader *QLogReader
// Create the new QLogReader instance. // Create the new QLogReader instance.
reader, err := NewQLogReader(testFiles) reader, err := NewQLogReader(testFiles)
require.Nil(t, err) require.Nil(t, err)
assert.NotNil(t, reader) assert.NotNil(t, reader)
t.Cleanup(func() { t.Cleanup(func() {
assert.Nil(t, reader.Close()) assert.Nil(t, reader.Close())
@ -112,11 +113,7 @@ func TestQLogReader_Seek(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
timestamp, err := time.Parse(time.RFC3339Nano, tc.time) timestamp, err := time.Parse(time.RFC3339Nano, tc.time)
assert.Nil(t, err) require.Nil(t, err)
if tc.name == "first" {
assert.True(t, true)
}
err = r.SeekTS(timestamp.UnixNano()) err = r.SeekTS(timestamp.UnixNano())
assert.True(t, errors.Is(err, tc.want)) assert.True(t, errors.Is(err, tc.want))
@ -146,11 +143,11 @@ func TestQLogReader_ReadNext(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
err := r.SeekStart() err := r.SeekStart()
assert.Nil(t, err, err) require.Nil(t, err)
for i := 1; i < tc.start; i++ { for i := 1; i < tc.start; i++ {
_, err := r.ReadNext() _, err := r.ReadNext()
assert.Nil(t, err) require.Nil(t, err)
} }
_, err = r.ReadNext() _, err = r.ReadNext()