Pull request: 2639 use testify require vol.4
Merge in DNS/adguard-home from 2639-testify-require-4 to master Closes #2639. Squashed commit of the following: commit 0bb9125f42ab6d2511c1b8e481112aa5edd581d9 Merge: 0e9e9ed12c9992e0
Author: Eugene Burkov <e.burkov@adguard.com> Date: Thu Mar 11 15:47:21 2021 +0300 Merge branch 'master' into 2639-testify-require-4 commit 0e9e9ed16ae13ce648b5e1da6ffd123df911c2d7 Author: Eugene Burkov <e.burkov@adguard.com> Date: Wed Mar 10 12:43:15 2021 +0300 home: rm deletion error check commit 6bfbbcd2b7f9197a06856f9e6b959c2e1c4b8353 Merge: c8ebe5418811c881
Author: Eugene Burkov <e.burkov@adguard.com> Date: Wed Mar 10 12:30:07 2021 +0300 Merge branch 'master' into 2639-testify-require-4 commit c8ebe54142bba780226f76ddb72e33664ed28f30 Author: Eugene Burkov <e.burkov@adguard.com> Date: Wed Mar 10 12:28:43 2021 +0300 home: imp tests commit f0e1db456f02df5f5f56ca93e7bd40a48475b38c Author: Eugene Burkov <e.burkov@adguard.com> Date: Fri Mar 5 14:06:41 2021 +0300 dnsforward: imp tests commit 4528246105ed06471a8778abbe8e5c30fc5483d5 Merge: 54b08d9c90ebc4d8
Author: Eugene Burkov <e.burkov@adguard.com> Date: Thu Mar 4 18:17:52 2021 +0300 Merge branch 'master' into 2639-testify-require-4 commit 54b08d9c980b8d69d019a1a1b3931aa048275691 Author: Eugene Burkov <e.burkov@adguard.com> Date: Thu Feb 11 13:17:05 2021 +0300 dnsfilter: imp tests
This commit is contained in:
parent
2c9992e0cc
commit
dfdbfee4fd
|
@ -5,6 +5,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||||
|
@ -13,6 +14,7 @@ import (
|
||||||
"github.com/AdguardTeam/urlfilter/rules"
|
"github.com/AdguardTeam/urlfilter/rules"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
|
@ -58,7 +60,7 @@ func (d *DNSFilter) checkMatch(t *testing.T, hostname string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
res, err := d.CheckHost(hostname, dns.TypeA, &setts)
|
res, err := d.CheckHost(hostname, dns.TypeA, &setts)
|
||||||
assert.Nilf(t, err, "Error while matching host %s: %s", hostname, err)
|
require.Nilf(t, err, "Error while matching host %s: %s", hostname, err)
|
||||||
assert.Truef(t, res.IsFiltered, "Expected hostname %s to match", hostname)
|
assert.Truef(t, res.IsFiltered, "Expected hostname %s to match", hostname)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -66,20 +68,20 @@ func (d *DNSFilter) checkMatchIP(t *testing.T, hostname, ip string, qtype uint16
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
res, err := d.CheckHost(hostname, qtype, &setts)
|
res, err := d.CheckHost(hostname, qtype, &setts)
|
||||||
assert.Nilf(t, err, "Error while matching host %s: %s", hostname, err)
|
require.Nilf(t, err, "Error while matching host %s: %s", hostname, err)
|
||||||
assert.Truef(t, res.IsFiltered, "Expected hostname %s to match", hostname)
|
assert.Truef(t, res.IsFiltered, "Expected hostname %s to match", hostname)
|
||||||
if assert.NotEmpty(t, res.Rules, "Expected result to have rules") {
|
|
||||||
r := res.Rules[0]
|
require.NotEmpty(t, res.Rules, "Expected result to have rules")
|
||||||
assert.NotNilf(t, r.IP, "Expected ip %s to match, actual: %v", ip, r.IP)
|
r := res.Rules[0]
|
||||||
assert.Equalf(t, ip, r.IP.String(), "Expected ip %s to match, actual: %v", ip, r.IP)
|
require.NotNilf(t, r.IP, "Expected ip %s to match, actual: %v", ip, r.IP)
|
||||||
}
|
assert.Equalf(t, ip, r.IP.String(), "Expected ip %s to match, actual: %v", ip, r.IP)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DNSFilter) checkMatchEmpty(t *testing.T, hostname string) {
|
func (d *DNSFilter) checkMatchEmpty(t *testing.T, hostname string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
res, err := d.CheckHost(hostname, dns.TypeA, &setts)
|
res, err := d.CheckHost(hostname, dns.TypeA, &setts)
|
||||||
assert.Nilf(t, err, "Error while matching host %s: %s", hostname, err)
|
require.Nilf(t, err, "Error while matching host %s: %s", hostname, err)
|
||||||
assert.Falsef(t, res.IsFiltered, "Expected hostname %s to not match", hostname)
|
assert.Falsef(t, res.IsFiltered, "Expected hostname %s to not match", hostname)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -110,40 +112,40 @@ func TestEtcHostsMatching(t *testing.T) {
|
||||||
|
|
||||||
// Empty IPv6.
|
// Empty IPv6.
|
||||||
res, err := d.CheckHost("block.com", dns.TypeAAAA, &setts)
|
res, err := d.CheckHost("block.com", dns.TypeAAAA, &setts)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.True(t, res.IsFiltered)
|
assert.True(t, res.IsFiltered)
|
||||||
if assert.Len(t, res.Rules, 1) {
|
|
||||||
assert.Equal(t, "0.0.0.0 block.com", res.Rules[0].Text)
|
require.Len(t, res.Rules, 1)
|
||||||
assert.Empty(t, res.Rules[0].IP)
|
assert.Equal(t, "0.0.0.0 block.com", res.Rules[0].Text)
|
||||||
}
|
assert.Empty(t, res.Rules[0].IP)
|
||||||
|
|
||||||
// IPv6 match.
|
// IPv6 match.
|
||||||
d.checkMatchIP(t, "ipv6.com", addr6, dns.TypeAAAA)
|
d.checkMatchIP(t, "ipv6.com", addr6, dns.TypeAAAA)
|
||||||
|
|
||||||
// Empty IPv4.
|
// Empty IPv4.
|
||||||
res, err = d.CheckHost("ipv6.com", dns.TypeA, &setts)
|
res, err = d.CheckHost("ipv6.com", dns.TypeA, &setts)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.True(t, res.IsFiltered)
|
assert.True(t, res.IsFiltered)
|
||||||
if assert.Len(t, res.Rules, 1) {
|
|
||||||
assert.Equal(t, "::1 ipv6.com", res.Rules[0].Text)
|
require.Len(t, res.Rules, 1)
|
||||||
assert.Empty(t, res.Rules[0].IP)
|
assert.Equal(t, "::1 ipv6.com", res.Rules[0].Text)
|
||||||
}
|
assert.Empty(t, res.Rules[0].IP)
|
||||||
|
|
||||||
// Two IPv4, the first one returned.
|
// Two IPv4, the first one returned.
|
||||||
res, err = d.CheckHost("host2", dns.TypeA, &setts)
|
res, err = d.CheckHost("host2", dns.TypeA, &setts)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.True(t, res.IsFiltered)
|
assert.True(t, res.IsFiltered)
|
||||||
if assert.Len(t, res.Rules, 1) {
|
|
||||||
assert.Equal(t, res.Rules[0].IP, net.IP{0, 0, 0, 1})
|
require.Len(t, res.Rules, 1)
|
||||||
}
|
assert.Equal(t, res.Rules[0].IP, net.IP{0, 0, 0, 1})
|
||||||
|
|
||||||
// One IPv6 address.
|
// One IPv6 address.
|
||||||
res, err = d.CheckHost("host2", dns.TypeAAAA, &setts)
|
res, err = d.CheckHost("host2", dns.TypeAAAA, &setts)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.True(t, res.IsFiltered)
|
assert.True(t, res.IsFiltered)
|
||||||
if assert.Len(t, res.Rules, 1) {
|
|
||||||
assert.Equal(t, res.Rules[0].IP, net.IPv6loopback)
|
require.Len(t, res.Rules, 1)
|
||||||
}
|
assert.Equal(t, res.Rules[0].IP, net.IPv6loopback)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Safe Browsing.
|
// Safe Browsing.
|
||||||
|
@ -155,14 +157,14 @@ func TestSafeBrowsing(t *testing.T) {
|
||||||
|
|
||||||
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
|
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
matching := "wmconvirus.narod.ru"
|
const matching = "wmconvirus.narod.ru"
|
||||||
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
|
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
|
||||||
Hostname: matching,
|
Hostname: matching,
|
||||||
Block: true,
|
Block: true,
|
||||||
})
|
})
|
||||||
d.checkMatch(t, matching)
|
d.checkMatch(t, matching)
|
||||||
|
|
||||||
assert.Contains(t, logOutput.String(), "SafeBrowsing lookup for "+matching)
|
require.Contains(t, logOutput.String(), "SafeBrowsing lookup for "+matching)
|
||||||
|
|
||||||
d.checkMatch(t, "test."+matching)
|
d.checkMatch(t, "test."+matching)
|
||||||
d.checkMatchEmpty(t, "yandex.ru")
|
d.checkMatchEmpty(t, "yandex.ru")
|
||||||
|
@ -178,7 +180,7 @@ func TestSafeBrowsing(t *testing.T) {
|
||||||
func TestParallelSB(t *testing.T) {
|
func TestParallelSB(t *testing.T) {
|
||||||
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
|
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
matching := "wmconvirus.narod.ru"
|
const matching = "wmconvirus.narod.ru"
|
||||||
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
|
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
|
||||||
Hostname: matching,
|
Hostname: matching,
|
||||||
Block: true,
|
Block: true,
|
||||||
|
@ -203,7 +205,7 @@ func TestSafeSearch(t *testing.T) {
|
||||||
d := newForTest(&Config{SafeSearchEnabled: true}, nil)
|
d := newForTest(&Config{SafeSearchEnabled: true}, nil)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
val, ok := d.SafeSearchDomain("www.google.com")
|
val, ok := d.SafeSearchDomain("www.google.com")
|
||||||
assert.True(t, ok, "Expected safesearch to find result for www.google.com")
|
require.True(t, ok, "Expected safesearch to find result for www.google.com")
|
||||||
assert.Equal(t, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com")
|
assert.Equal(t, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -211,6 +213,8 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
|
||||||
d := newForTest(&Config{SafeSearchEnabled: true}, nil)
|
d := newForTest(&Config{SafeSearchEnabled: true}, nil)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
|
|
||||||
|
yandexIP := net.IPv4(213, 180, 193, 56)
|
||||||
|
|
||||||
// Check host for each domain.
|
// Check host for each domain.
|
||||||
for _, host := range []string{
|
for _, host := range []string{
|
||||||
"yAndeX.ru",
|
"yAndeX.ru",
|
||||||
|
@ -220,22 +224,27 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
|
||||||
"yandex.kz",
|
"yandex.kz",
|
||||||
"www.yandex.com",
|
"www.yandex.com",
|
||||||
} {
|
} {
|
||||||
res, err := d.CheckHost(host, dns.TypeA, &setts)
|
t.Run(strings.ToLower(host), func(t *testing.T) {
|
||||||
assert.Nil(t, err)
|
res, err := d.CheckHost(host, dns.TypeA, &setts)
|
||||||
assert.True(t, res.IsFiltered)
|
require.Nil(t, err)
|
||||||
if assert.Len(t, res.Rules, 1) {
|
assert.True(t, res.IsFiltered)
|
||||||
assert.Equal(t, res.Rules[0].IP, net.IPv4(213, 180, 193, 56))
|
|
||||||
}
|
require.Len(t, res.Rules, 1)
|
||||||
|
assert.Equal(t, yandexIP, res.Rules[0].IP)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCheckHostSafeSearchGoogle(t *testing.T) {
|
func TestCheckHostSafeSearchGoogle(t *testing.T) {
|
||||||
|
resolver := &aghtest.TestResolver{}
|
||||||
d := newForTest(&Config{
|
d := newForTest(&Config{
|
||||||
SafeSearchEnabled: true,
|
SafeSearchEnabled: true,
|
||||||
CustomResolver: &aghtest.TestResolver{},
|
CustomResolver: resolver,
|
||||||
}, nil)
|
}, nil)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
|
|
||||||
|
ip, _ := resolver.HostToIPs("forcesafesearch.google.com")
|
||||||
|
|
||||||
// Check host for each domain.
|
// Check host for each domain.
|
||||||
for _, host := range []string{
|
for _, host := range []string{
|
||||||
"www.google.com",
|
"www.google.com",
|
||||||
|
@ -248,11 +257,10 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
|
||||||
} {
|
} {
|
||||||
t.Run(host, func(t *testing.T) {
|
t.Run(host, func(t *testing.T) {
|
||||||
res, err := d.CheckHost(host, dns.TypeA, &setts)
|
res, err := d.CheckHost(host, dns.TypeA, &setts)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.True(t, res.IsFiltered)
|
assert.True(t, res.IsFiltered)
|
||||||
if assert.Len(t, res.Rules, 1) {
|
require.Len(t, res.Rules, 1)
|
||||||
assert.NotEqual(t, res.Rules[0].IP.String(), "0.0.0.0")
|
assert.Equal(t, ip, res.Rules[0].IP)
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -260,31 +268,31 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
|
||||||
func TestSafeSearchCacheYandex(t *testing.T) {
|
func TestSafeSearchCacheYandex(t *testing.T) {
|
||||||
d := newForTest(nil, nil)
|
d := newForTest(nil, nil)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
domain := "yandex.ru"
|
const domain = "yandex.ru"
|
||||||
|
|
||||||
// Check host with disabled safesearch.
|
// Check host with disabled safesearch.
|
||||||
res, err := d.CheckHost(domain, dns.TypeA, &setts)
|
res, err := d.CheckHost(domain, dns.TypeA, &setts)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.False(t, res.IsFiltered)
|
assert.False(t, res.IsFiltered)
|
||||||
assert.Empty(t, res.Rules)
|
require.Empty(t, res.Rules)
|
||||||
|
|
||||||
|
yandexIP := net.IPv4(213, 180, 193, 56)
|
||||||
|
|
||||||
d = newForTest(&Config{SafeSearchEnabled: true}, nil)
|
d = newForTest(&Config{SafeSearchEnabled: true}, nil)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
|
|
||||||
res, err = d.CheckHost(domain, dns.TypeA, &setts)
|
res, err = d.CheckHost(domain, dns.TypeA, &setts)
|
||||||
assert.Nilf(t, err, "CheckHost for safesearh domain %s failed cause %s", domain, err)
|
require.Nilf(t, err, "CheckHost for safesearh domain %s failed cause %s", domain, err)
|
||||||
|
|
||||||
// For yandex we already know valid IP.
|
// For yandex we already know valid IP.
|
||||||
if assert.Len(t, res.Rules, 1) {
|
require.Len(t, res.Rules, 1)
|
||||||
assert.Equal(t, res.Rules[0].IP, net.IPv4(213, 180, 193, 56))
|
assert.Equal(t, res.Rules[0].IP, yandexIP)
|
||||||
}
|
|
||||||
|
|
||||||
// Check cache.
|
// Check cache.
|
||||||
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, domain)
|
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, domain)
|
||||||
assert.True(t, isFound)
|
require.True(t, isFound)
|
||||||
if assert.Len(t, cachedValue.Rules, 1) {
|
require.Len(t, cachedValue.Rules, 1)
|
||||||
assert.Equal(t, cachedValue.Rules[0].IP, net.IPv4(213, 180, 193, 56))
|
assert.Equal(t, cachedValue.Rules[0].IP, yandexIP)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSafeSearchCacheGoogle(t *testing.T) {
|
func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||||
|
@ -294,11 +302,11 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||||
}, nil)
|
}, nil)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
|
|
||||||
domain := "www.google.ru"
|
const domain = "www.google.ru"
|
||||||
res, err := d.CheckHost(domain, dns.TypeA, &setts)
|
res, err := d.CheckHost(domain, dns.TypeA, &setts)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.False(t, res.IsFiltered)
|
assert.False(t, res.IsFiltered)
|
||||||
assert.Empty(t, res.Rules)
|
require.Empty(t, res.Rules)
|
||||||
|
|
||||||
d = newForTest(&Config{SafeSearchEnabled: true}, nil)
|
d = newForTest(&Config{SafeSearchEnabled: true}, nil)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
|
@ -306,12 +314,10 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||||
|
|
||||||
// Lookup for safesearch domain.
|
// Lookup for safesearch domain.
|
||||||
safeDomain, ok := d.SafeSearchDomain(domain)
|
safeDomain, ok := d.SafeSearchDomain(domain)
|
||||||
assert.Truef(t, ok, "Failed to get safesearch domain for %s", domain)
|
require.Truef(t, ok, "Failed to get safesearch domain for %s", domain)
|
||||||
|
|
||||||
ips, err := resolver.LookupIP(context.Background(), "ip", safeDomain)
|
ips, err := resolver.LookupIP(context.Background(), "ip", safeDomain)
|
||||||
if err != nil {
|
require.Nilf(t, err, "Failed to lookup for %s", safeDomain)
|
||||||
t.Fatalf("Failed to lookup for %s", safeDomain)
|
|
||||||
}
|
|
||||||
|
|
||||||
var ip net.IP
|
var ip net.IP
|
||||||
for _, foundIP := range ips {
|
for _, foundIP := range ips {
|
||||||
|
@ -323,17 +329,15 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err = d.CheckHost(domain, dns.TypeA, &setts)
|
res, err = d.CheckHost(domain, dns.TypeA, &setts)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
if assert.Len(t, res.Rules, 1) {
|
require.Len(t, res.Rules, 1)
|
||||||
assert.True(t, res.Rules[0].IP.Equal(ip))
|
assert.True(t, res.Rules[0].IP.Equal(ip))
|
||||||
}
|
|
||||||
|
|
||||||
// Check cache.
|
// Check cache.
|
||||||
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, domain)
|
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, domain)
|
||||||
assert.True(t, isFound)
|
require.True(t, isFound)
|
||||||
if assert.Len(t, cachedValue.Rules, 1) {
|
require.Len(t, cachedValue.Rules, 1)
|
||||||
assert.True(t, cachedValue.Rules[0].IP.Equal(ip))
|
assert.True(t, cachedValue.Rules[0].IP.Equal(ip))
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parental.
|
// Parental.
|
||||||
|
@ -345,24 +349,23 @@ func TestParentalControl(t *testing.T) {
|
||||||
|
|
||||||
d := newForTest(&Config{ParentalEnabled: true}, nil)
|
d := newForTest(&Config{ParentalEnabled: true}, nil)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
matching := "pornhub.com"
|
const matching = "pornhub.com"
|
||||||
d.SetParentalUpstream(&aghtest.TestBlockUpstream{
|
d.SetParentalUpstream(&aghtest.TestBlockUpstream{
|
||||||
Hostname: matching,
|
Hostname: matching,
|
||||||
Block: true,
|
Block: true,
|
||||||
})
|
})
|
||||||
|
|
||||||
d.checkMatch(t, matching)
|
d.checkMatch(t, matching)
|
||||||
assert.Contains(t, logOutput.String(), "Parental lookup for "+matching)
|
require.Contains(t, logOutput.String(), "Parental lookup for "+matching)
|
||||||
d.checkMatch(t, "www."+matching)
|
d.checkMatch(t, "www."+matching)
|
||||||
d.checkMatchEmpty(t, "www.yandex.ru")
|
d.checkMatchEmpty(t, "www.yandex.ru")
|
||||||
d.checkMatchEmpty(t, "yandex.ru")
|
d.checkMatchEmpty(t, "yandex.ru")
|
||||||
d.checkMatchEmpty(t, "api.jquery.com")
|
d.checkMatchEmpty(t, "api.jquery.com")
|
||||||
|
|
||||||
// test cached result
|
// Test cached result.
|
||||||
d.parentalServer = "127.0.0.1"
|
d.parentalServer = "127.0.0.1"
|
||||||
d.checkMatch(t, matching)
|
d.checkMatch(t, matching)
|
||||||
d.checkMatchEmpty(t, "yandex.ru")
|
d.checkMatchEmpty(t, "yandex.ru")
|
||||||
d.parentalServer = defaultParentalServer
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filtering.
|
// Filtering.
|
||||||
|
@ -651,7 +654,7 @@ func TestMatching(t *testing.T) {
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
|
|
||||||
res, err := d.CheckHost(tc.host, tc.wantDNSType, &setts)
|
res, err := d.CheckHost(tc.host, tc.wantDNSType, &setts)
|
||||||
assert.Nilf(t, err, "Error while matching host %s: %s", tc.host, err)
|
require.Nilf(t, err, "Error while matching host %s: %s", tc.host, err)
|
||||||
assert.Equalf(t, tc.wantIsFiltered, res.IsFiltered, "Hostname %s has wrong result (%v must be %v)", tc.host, res.IsFiltered, tc.wantIsFiltered)
|
assert.Equalf(t, tc.wantIsFiltered, res.IsFiltered, "Hostname %s has wrong result (%v must be %v)", tc.host, res.IsFiltered, tc.wantIsFiltered)
|
||||||
assert.Equalf(t, tc.wantReason, res.Reason, "Hostname %s has wrong reason (%v must be %v)", tc.host, res.Reason, tc.wantReason)
|
assert.Equalf(t, tc.wantReason, res.Reason, "Hostname %s has wrong reason (%v must be %v)", tc.host, res.Reason, tc.wantReason)
|
||||||
})
|
})
|
||||||
|
@ -674,28 +677,24 @@ func TestWhitelist(t *testing.T) {
|
||||||
}}
|
}}
|
||||||
d := newForTest(nil, filters)
|
d := newForTest(nil, filters)
|
||||||
|
|
||||||
err := d.SetFilters(filters, whiteFilters, false)
|
require.Nil(t, d.SetFilters(filters, whiteFilters, false))
|
||||||
assert.Nil(t, err)
|
|
||||||
|
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
|
|
||||||
// Matched by white filter.
|
// Matched by white filter.
|
||||||
res, err := d.CheckHost("host1", dns.TypeA, &setts)
|
res, err := d.CheckHost("host1", dns.TypeA, &setts)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.False(t, res.IsFiltered)
|
assert.False(t, res.IsFiltered)
|
||||||
assert.Equal(t, res.Reason, NotFilteredAllowList)
|
assert.Equal(t, res.Reason, NotFilteredAllowList)
|
||||||
if assert.Len(t, res.Rules, 1) {
|
require.Len(t, res.Rules, 1)
|
||||||
assert.Equal(t, "||host1^", res.Rules[0].Text)
|
assert.Equal(t, "||host1^", res.Rules[0].Text)
|
||||||
}
|
|
||||||
|
|
||||||
// Not matched by white filter, but matched by block filter.
|
// Not matched by white filter, but matched by block filter.
|
||||||
res, err = d.CheckHost("host2", dns.TypeA, &setts)
|
res, err = d.CheckHost("host2", dns.TypeA, &setts)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.True(t, res.IsFiltered)
|
assert.True(t, res.IsFiltered)
|
||||||
assert.Equal(t, res.Reason, FilteredBlockList)
|
assert.Equal(t, res.Reason, FilteredBlockList)
|
||||||
if assert.Len(t, res.Rules, 1) {
|
require.Len(t, res.Rules, 1)
|
||||||
assert.Equal(t, "||host2^", res.Rules[0].Text)
|
assert.Equal(t, "||host2^", res.Rules[0].Text)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client Settings.
|
// Client Settings.
|
||||||
|
@ -797,7 +796,7 @@ func BenchmarkSafeBrowsing(b *testing.B) {
|
||||||
})
|
})
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
res, err := d.CheckHost(blocked, dns.TypeA, &setts)
|
res, err := d.CheckHost(blocked, dns.TypeA, &setts)
|
||||||
assert.Nilf(b, err, "Error while matching host %s: %s", blocked, err)
|
require.Nilf(b, err, "Error while matching host %s: %s", blocked, err)
|
||||||
assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked)
|
assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -813,7 +812,7 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) {
|
||||||
b.RunParallel(func(pb *testing.PB) {
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
for pb.Next() {
|
for pb.Next() {
|
||||||
res, err := d.CheckHost(blocked, dns.TypeA, &setts)
|
res, err := d.CheckHost(blocked, dns.TypeA, &setts)
|
||||||
assert.Nilf(b, err, "Error while matching host %s: %s", blocked, err)
|
require.Nilf(b, err, "Error while matching host %s: %s", blocked, err)
|
||||||
assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked)
|
assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -824,7 +823,7 @@ func BenchmarkSafeSearch(b *testing.B) {
|
||||||
b.Cleanup(d.Close)
|
b.Cleanup(d.Close)
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
val, ok := d.SafeSearchDomain("www.google.com")
|
val, ok := d.SafeSearchDomain("www.google.com")
|
||||||
assert.True(b, ok, "Expected safesearch to find result for www.google.com")
|
require.True(b, ok, "Expected safesearch to find result for www.google.com")
|
||||||
assert.Equal(b, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com")
|
assert.Equal(b, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -835,7 +834,7 @@ func BenchmarkSafeSearchParallel(b *testing.B) {
|
||||||
b.RunParallel(func(pb *testing.PB) {
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
for pb.Next() {
|
for pb.Next() {
|
||||||
val, ok := d.SafeSearchDomain("www.google.com")
|
val, ok := d.SafeSearchDomain("www.google.com")
|
||||||
assert.True(b, ok, "Expected safesearch to find result for www.google.com")
|
require.True(b, ok, "Expected safesearch to find result for www.google.com")
|
||||||
assert.Equal(b, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com")
|
assert.Equal(b, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|
func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|
||||||
|
@ -55,138 +56,89 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|
||||||
ipv6p1 := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
|
ipv6p1 := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
|
||||||
ipv6p2 := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}
|
ipv6p2 := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}
|
||||||
|
|
||||||
|
testCasesA := []struct {
|
||||||
|
name string
|
||||||
|
dtyp uint16
|
||||||
|
rcode int
|
||||||
|
want []interface{}
|
||||||
|
}{{
|
||||||
|
name: "a-record",
|
||||||
|
dtyp: dns.TypeA,
|
||||||
|
rcode: dns.RcodeSuccess,
|
||||||
|
want: []interface{}{ipv4p1},
|
||||||
|
}, {
|
||||||
|
name: "aaaa-record",
|
||||||
|
dtyp: dns.TypeAAAA,
|
||||||
|
rcode: dns.RcodeSuccess,
|
||||||
|
want: []interface{}{ipv6p1},
|
||||||
|
}, {
|
||||||
|
name: "txt-record",
|
||||||
|
dtyp: dns.TypeTXT,
|
||||||
|
rcode: dns.RcodeSuccess,
|
||||||
|
want: []interface{}{"hello-world"},
|
||||||
|
}, {
|
||||||
|
name: "refused",
|
||||||
|
rcode: dns.RcodeRefused,
|
||||||
|
}, {
|
||||||
|
name: "a-records",
|
||||||
|
dtyp: dns.TypeA,
|
||||||
|
rcode: dns.RcodeSuccess,
|
||||||
|
want: []interface{}{ipv4p1, ipv4p2},
|
||||||
|
}, {
|
||||||
|
name: "aaaa-records",
|
||||||
|
dtyp: dns.TypeAAAA,
|
||||||
|
rcode: dns.RcodeSuccess,
|
||||||
|
want: []interface{}{ipv6p1, ipv6p2},
|
||||||
|
}, {
|
||||||
|
name: "disable-one",
|
||||||
|
dtyp: dns.TypeA,
|
||||||
|
rcode: dns.RcodeSuccess,
|
||||||
|
want: []interface{}{ipv4p2},
|
||||||
|
}, {
|
||||||
|
name: "disable-cname",
|
||||||
|
dtyp: dns.TypeA,
|
||||||
|
rcode: dns.RcodeSuccess,
|
||||||
|
want: []interface{}{ipv4p1},
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCasesA {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
host := path.Base(tc.name)
|
||||||
|
|
||||||
|
res, err := f.CheckHostRules(host, tc.dtyp, setts)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
dnsrr := res.DNSRewriteResult
|
||||||
|
require.NotNil(t, dnsrr)
|
||||||
|
assert.Equal(t, tc.rcode, dnsrr.RCode)
|
||||||
|
|
||||||
|
if tc.rcode == dns.RcodeRefused {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ipVals := dnsrr.Response[tc.dtyp]
|
||||||
|
require.Len(t, ipVals, len(tc.want))
|
||||||
|
for i, val := range tc.want {
|
||||||
|
require.Equal(t, val, ipVals[i])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
t.Run("cname", func(t *testing.T) {
|
t.Run("cname", func(t *testing.T) {
|
||||||
dtyp := dns.TypeA
|
dtyp := dns.TypeA
|
||||||
host := path.Base(t.Name())
|
host := path.Base(t.Name())
|
||||||
|
|
||||||
res, err := f.CheckHostRules(host, dtyp, setts)
|
res, err := f.CheckHostRules(host, dtyp, setts)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Equal(t, "new-cname", res.CanonName)
|
assert.Equal(t, "new-cname", res.CanonName)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("a-record", func(t *testing.T) {
|
|
||||||
dtyp := dns.TypeA
|
|
||||||
host := path.Base(t.Name())
|
|
||||||
|
|
||||||
res, err := f.CheckHostRules(host, dtyp, setts)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
|
|
||||||
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
|
|
||||||
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
|
|
||||||
if ipVals := dnsrr.Response[dtyp]; assert.Len(t, ipVals, 1) {
|
|
||||||
assert.Equal(t, ipv4p1, ipVals[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("aaaa-record", func(t *testing.T) {
|
|
||||||
dtyp := dns.TypeAAAA
|
|
||||||
host := path.Base(t.Name())
|
|
||||||
|
|
||||||
res, err := f.CheckHostRules(host, dtyp, setts)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
|
|
||||||
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
|
|
||||||
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
|
|
||||||
if ipVals := dnsrr.Response[dtyp]; assert.Len(t, ipVals, 1) {
|
|
||||||
assert.Equal(t, ipv6p1, ipVals[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("txt-record", func(t *testing.T) {
|
|
||||||
dtyp := dns.TypeTXT
|
|
||||||
host := path.Base(t.Name())
|
|
||||||
res, err := f.CheckHostRules(host, dtyp, setts)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
|
|
||||||
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
|
|
||||||
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
|
|
||||||
if strVals := dnsrr.Response[dtyp]; assert.Len(t, strVals, 1) {
|
|
||||||
assert.Equal(t, "hello-world", strVals[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("refused", func(t *testing.T) {
|
|
||||||
host := path.Base(t.Name())
|
|
||||||
res, err := f.CheckHostRules(host, dns.TypeA, setts)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
|
|
||||||
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
|
|
||||||
assert.Equal(t, dns.RcodeRefused, dnsrr.RCode)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("a-records", func(t *testing.T) {
|
|
||||||
dtyp := dns.TypeA
|
|
||||||
host := path.Base(t.Name())
|
|
||||||
|
|
||||||
res, err := f.CheckHostRules(host, dtyp, setts)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
|
|
||||||
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
|
|
||||||
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
|
|
||||||
if ipVals := dnsrr.Response[dtyp]; assert.Len(t, ipVals, 2) {
|
|
||||||
assert.Equal(t, ipv4p1, ipVals[0])
|
|
||||||
assert.Equal(t, ipv4p2, ipVals[1])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("aaaa-records", func(t *testing.T) {
|
|
||||||
dtyp := dns.TypeAAAA
|
|
||||||
host := path.Base(t.Name())
|
|
||||||
|
|
||||||
res, err := f.CheckHostRules(host, dtyp, setts)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
|
|
||||||
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
|
|
||||||
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
|
|
||||||
if ipVals := dnsrr.Response[dtyp]; assert.Len(t, ipVals, 2) {
|
|
||||||
assert.Equal(t, ipv6p1, ipVals[0])
|
|
||||||
assert.Equal(t, ipv6p2, ipVals[1])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("disable-one", func(t *testing.T) {
|
|
||||||
dtyp := dns.TypeA
|
|
||||||
host := path.Base(t.Name())
|
|
||||||
|
|
||||||
res, err := f.CheckHostRules(host, dtyp, setts)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
|
|
||||||
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
|
|
||||||
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
|
|
||||||
if ipVals := dnsrr.Response[dtyp]; assert.Len(t, ipVals, 1) {
|
|
||||||
assert.Equal(t, ipv4p2, ipVals[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("disable-cname", func(t *testing.T) {
|
|
||||||
dtyp := dns.TypeA
|
|
||||||
host := path.Base(t.Name())
|
|
||||||
|
|
||||||
res, err := f.CheckHostRules(host, dtyp, setts)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Empty(t, res.CanonName)
|
|
||||||
|
|
||||||
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
|
|
||||||
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
|
|
||||||
if ipVals := dnsrr.Response[dtyp]; assert.Len(t, ipVals, 1) {
|
|
||||||
assert.Equal(t, ipv4p1, ipVals[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("disable-cname-many", func(t *testing.T) {
|
t.Run("disable-cname-many", func(t *testing.T) {
|
||||||
dtyp := dns.TypeA
|
dtyp := dns.TypeA
|
||||||
host := path.Base(t.Name())
|
host := path.Base(t.Name())
|
||||||
|
|
||||||
res, err := f.CheckHostRules(host, dtyp, setts)
|
res, err := f.CheckHostRules(host, dtyp, setts)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Equal(t, "new-cname-2", res.CanonName)
|
assert.Equal(t, "new-cname-2", res.CanonName)
|
||||||
assert.Nil(t, res.DNSRewriteResult)
|
assert.Nil(t, res.DNSRewriteResult)
|
||||||
})
|
})
|
||||||
|
@ -196,7 +148,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|
||||||
host := path.Base(t.Name())
|
host := path.Base(t.Name())
|
||||||
|
|
||||||
res, err := f.CheckHostRules(host, dtyp, setts)
|
res, err := f.CheckHostRules(host, dtyp, setts)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Empty(t, res.CanonName)
|
assert.Empty(t, res.CanonName)
|
||||||
assert.Empty(t, res.Rules)
|
assert.Empty(t, res.Rules)
|
||||||
})
|
})
|
||||||
|
|
|
@ -6,215 +6,297 @@ import (
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TODO(e.burkov): All the tests in this file may and should me merged together.
|
||||||
|
|
||||||
func TestRewrites(t *testing.T) {
|
func TestRewrites(t *testing.T) {
|
||||||
d := newForTest(nil, nil)
|
d := newForTest(nil, nil)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
// CNAME, A, AAAA
|
|
||||||
d.Rewrites = []RewriteEntry{
|
|
||||||
{"somecname", "somehost.com", 0, nil},
|
|
||||||
{"somehost.com", "0.0.0.0", 0, nil},
|
|
||||||
|
|
||||||
{"host.com", "1.2.3.4", 0, nil},
|
d.Rewrites = []RewriteEntry{{
|
||||||
{"host.com", "1.2.3.5", 0, nil},
|
// This one and below are about CNAME, A and AAAA.
|
||||||
{"host.com", "1:2:3::4", 0, nil},
|
Domain: "somecname",
|
||||||
{"www.host.com", "host.com", 0, nil},
|
Answer: "somehost.com",
|
||||||
}
|
}, {
|
||||||
|
Domain: "somehost.com",
|
||||||
|
Answer: "0.0.0.0",
|
||||||
|
}, {
|
||||||
|
Domain: "host.com",
|
||||||
|
Answer: "1.2.3.4",
|
||||||
|
}, {
|
||||||
|
Domain: "host.com",
|
||||||
|
Answer: "1.2.3.5",
|
||||||
|
}, {
|
||||||
|
Domain: "host.com",
|
||||||
|
Answer: "1:2:3::4",
|
||||||
|
}, {
|
||||||
|
Domain: "www.host.com",
|
||||||
|
Answer: "host.com",
|
||||||
|
}, {
|
||||||
|
// This one is a wildcard.
|
||||||
|
Domain: "*.host.com",
|
||||||
|
Answer: "1.2.3.5",
|
||||||
|
}, {
|
||||||
|
// This one and below are about wildcard overriding.
|
||||||
|
Domain: "a.host.com",
|
||||||
|
Answer: "1.2.3.4",
|
||||||
|
}, {
|
||||||
|
// This one is about CNAME and wildcard interacting.
|
||||||
|
Domain: "*.host2.com",
|
||||||
|
Answer: "host.com",
|
||||||
|
}, {
|
||||||
|
// This one and below are about 2 level CNAME.
|
||||||
|
Domain: "b.host.com",
|
||||||
|
Answer: "somecname",
|
||||||
|
}, {
|
||||||
|
// This one and below are about 2 level CNAME and wildcard.
|
||||||
|
Domain: "b.host3.com",
|
||||||
|
Answer: "a.host3.com",
|
||||||
|
}, {
|
||||||
|
Domain: "a.host3.com",
|
||||||
|
Answer: "x.host.com",
|
||||||
|
}}
|
||||||
d.prepareRewrites()
|
d.prepareRewrites()
|
||||||
r := d.processRewrites("host2.com", dns.TypeA)
|
|
||||||
assert.Equal(t, NotFilteredNotFound, r.Reason)
|
|
||||||
|
|
||||||
r = d.processRewrites("www.host.com", dns.TypeA)
|
testCases := []struct {
|
||||||
assert.Equal(t, Rewritten, r.Reason)
|
name string
|
||||||
assert.Equal(t, "host.com", r.CanonName)
|
host string
|
||||||
assert.Len(t, r.IPList, 2)
|
dtyp uint16
|
||||||
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
|
wantCName string
|
||||||
assert.True(t, r.IPList[1].Equal(net.IP{1, 2, 3, 5}))
|
wantVals []net.IP
|
||||||
|
}{{
|
||||||
|
name: "not_filtered_not_found",
|
||||||
|
host: "hoost.com",
|
||||||
|
dtyp: dns.TypeA,
|
||||||
|
}, {
|
||||||
|
name: "rewritten_a",
|
||||||
|
host: "www.host.com",
|
||||||
|
dtyp: dns.TypeA,
|
||||||
|
wantCName: "host.com",
|
||||||
|
wantVals: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
|
||||||
|
}, {
|
||||||
|
name: "rewritten_aaaa",
|
||||||
|
host: "www.host.com",
|
||||||
|
dtyp: dns.TypeAAAA,
|
||||||
|
wantCName: "host.com",
|
||||||
|
wantVals: []net.IP{net.ParseIP("1:2:3::4")},
|
||||||
|
}, {
|
||||||
|
name: "wildcard_match",
|
||||||
|
host: "abc.host.com",
|
||||||
|
dtyp: dns.TypeA,
|
||||||
|
wantVals: []net.IP{{1, 2, 3, 5}},
|
||||||
|
}, {
|
||||||
|
name: "wildcard_override",
|
||||||
|
host: "a.host.com",
|
||||||
|
dtyp: dns.TypeA,
|
||||||
|
wantVals: []net.IP{{1, 2, 3, 4}},
|
||||||
|
}, {
|
||||||
|
name: "wildcard_cname_interaction",
|
||||||
|
host: "www.host2.com",
|
||||||
|
dtyp: dns.TypeA,
|
||||||
|
wantCName: "host.com",
|
||||||
|
wantVals: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
|
||||||
|
}, {
|
||||||
|
name: "two_cnames",
|
||||||
|
host: "b.host.com",
|
||||||
|
dtyp: dns.TypeA,
|
||||||
|
wantCName: "somehost.com",
|
||||||
|
wantVals: []net.IP{{0, 0, 0, 0}},
|
||||||
|
}, {
|
||||||
|
name: "two_cnames_and_wildcard",
|
||||||
|
host: "b.host3.com",
|
||||||
|
dtyp: dns.TypeA,
|
||||||
|
wantCName: "x.host.com",
|
||||||
|
wantVals: []net.IP{{1, 2, 3, 5}},
|
||||||
|
}}
|
||||||
|
|
||||||
r = d.processRewrites("www.host.com", dns.TypeAAAA)
|
for _, tc := range testCases {
|
||||||
assert.Equal(t, Rewritten, r.Reason)
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
assert.Equal(t, "host.com", r.CanonName)
|
valsNum := len(tc.wantVals)
|
||||||
assert.Len(t, r.IPList, 1)
|
|
||||||
assert.True(t, r.IPList[0].Equal(net.ParseIP("1:2:3::4")))
|
|
||||||
|
|
||||||
// wildcard
|
r := d.processRewrites(tc.host, tc.dtyp)
|
||||||
d.Rewrites = []RewriteEntry{
|
if valsNum == 0 {
|
||||||
{"host.com", "1.2.3.4", 0, nil},
|
assert.Equal(t, NotFilteredNotFound, r.Reason)
|
||||||
{"*.host.com", "1.2.3.5", 0, nil},
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, Rewritten, r.Reason)
|
||||||
|
if tc.wantCName != "" {
|
||||||
|
assert.Equal(t, tc.wantCName, r.CanonName)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Len(t, r.IPList, valsNum)
|
||||||
|
for i, ip := range tc.wantVals {
|
||||||
|
assert.Equal(t, ip, r.IPList[i])
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
d.prepareRewrites()
|
|
||||||
r = d.processRewrites("host.com", dns.TypeA)
|
|
||||||
assert.Equal(t, Rewritten, r.Reason)
|
|
||||||
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
|
|
||||||
|
|
||||||
r = d.processRewrites("www.host.com", dns.TypeA)
|
|
||||||
assert.Equal(t, Rewritten, r.Reason)
|
|
||||||
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 5}))
|
|
||||||
|
|
||||||
r = d.processRewrites("www.host2.com", dns.TypeA)
|
|
||||||
assert.Equal(t, NotFilteredNotFound, r.Reason)
|
|
||||||
|
|
||||||
// override a wildcard
|
|
||||||
d.Rewrites = []RewriteEntry{
|
|
||||||
{"a.host.com", "1.2.3.4", 0, nil},
|
|
||||||
{"*.host.com", "1.2.3.5", 0, nil},
|
|
||||||
}
|
|
||||||
d.prepareRewrites()
|
|
||||||
r = d.processRewrites("a.host.com", dns.TypeA)
|
|
||||||
assert.Equal(t, Rewritten, r.Reason)
|
|
||||||
assert.Len(t, r.IPList, 1)
|
|
||||||
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
|
|
||||||
|
|
||||||
// wildcard + CNAME
|
|
||||||
d.Rewrites = []RewriteEntry{
|
|
||||||
{"host.com", "1.2.3.4", 0, nil},
|
|
||||||
{"*.host.com", "host.com", 0, nil},
|
|
||||||
}
|
|
||||||
d.prepareRewrites()
|
|
||||||
r = d.processRewrites("www.host.com", dns.TypeA)
|
|
||||||
assert.Equal(t, Rewritten, r.Reason)
|
|
||||||
assert.Equal(t, "host.com", r.CanonName)
|
|
||||||
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
|
|
||||||
|
|
||||||
// 2 CNAMEs
|
|
||||||
d.Rewrites = []RewriteEntry{
|
|
||||||
{"b.host.com", "a.host.com", 0, nil},
|
|
||||||
{"a.host.com", "host.com", 0, nil},
|
|
||||||
{"host.com", "1.2.3.4", 0, nil},
|
|
||||||
}
|
|
||||||
d.prepareRewrites()
|
|
||||||
r = d.processRewrites("b.host.com", dns.TypeA)
|
|
||||||
assert.Equal(t, Rewritten, r.Reason)
|
|
||||||
assert.Equal(t, "host.com", r.CanonName)
|
|
||||||
assert.Len(t, r.IPList, 1)
|
|
||||||
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
|
|
||||||
|
|
||||||
// 2 CNAMEs + wildcard
|
|
||||||
d.Rewrites = []RewriteEntry{
|
|
||||||
{"b.host.com", "a.host.com", 0, nil},
|
|
||||||
{"a.host.com", "x.somehost.com", 0, nil},
|
|
||||||
{"*.somehost.com", "1.2.3.4", 0, nil},
|
|
||||||
}
|
|
||||||
d.prepareRewrites()
|
|
||||||
r = d.processRewrites("b.host.com", dns.TypeA)
|
|
||||||
assert.Equal(t, Rewritten, r.Reason)
|
|
||||||
assert.Equal(t, "x.somehost.com", r.CanonName)
|
|
||||||
assert.Len(t, r.IPList, 1)
|
|
||||||
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRewritesLevels(t *testing.T) {
|
func TestRewritesLevels(t *testing.T) {
|
||||||
d := newForTest(nil, nil)
|
d := newForTest(nil, nil)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
// exact host, wildcard L2, wildcard L3
|
// Exact host, wildcard L2, wildcard L3.
|
||||||
d.Rewrites = []RewriteEntry{
|
d.Rewrites = []RewriteEntry{{
|
||||||
{"host.com", "1.1.1.1", 0, nil},
|
Domain: "host.com",
|
||||||
{"*.host.com", "2.2.2.2", 0, nil},
|
Answer: "1.1.1.1",
|
||||||
{"*.sub.host.com", "3.3.3.3", 0, nil},
|
}, {
|
||||||
}
|
Domain: "*.host.com",
|
||||||
|
Answer: "2.2.2.2",
|
||||||
|
}, {
|
||||||
|
Domain: "*.sub.host.com",
|
||||||
|
Answer: "3.3.3.3",
|
||||||
|
}}
|
||||||
d.prepareRewrites()
|
d.prepareRewrites()
|
||||||
|
|
||||||
// match exact
|
testCases := []struct {
|
||||||
r := d.processRewrites("host.com", dns.TypeA)
|
name string
|
||||||
assert.Equal(t, Rewritten, r.Reason)
|
host string
|
||||||
assert.Len(t, r.IPList, 1)
|
want net.IP
|
||||||
assert.True(t, net.IP{1, 1, 1, 1}.Equal(r.IPList[0]))
|
}{{
|
||||||
|
name: "exact_match",
|
||||||
|
host: "host.com",
|
||||||
|
want: net.IP{1, 1, 1, 1},
|
||||||
|
}, {
|
||||||
|
name: "l2_match",
|
||||||
|
host: "sub.host.com",
|
||||||
|
want: net.IP{2, 2, 2, 2},
|
||||||
|
}, {
|
||||||
|
name: "l3_match",
|
||||||
|
host: "my.sub.host.com",
|
||||||
|
want: net.IP{3, 3, 3, 3},
|
||||||
|
}}
|
||||||
|
|
||||||
// match L2
|
for _, tc := range testCases {
|
||||||
r = d.processRewrites("sub.host.com", dns.TypeA)
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
assert.Equal(t, Rewritten, r.Reason)
|
r := d.processRewrites(tc.host, dns.TypeA)
|
||||||
assert.Len(t, r.IPList, 1)
|
assert.Equal(t, Rewritten, r.Reason)
|
||||||
assert.True(t, net.IP{2, 2, 2, 2}.Equal(r.IPList[0]))
|
require.Len(t, r.IPList, 1)
|
||||||
|
})
|
||||||
// match L3
|
}
|
||||||
r = d.processRewrites("my.sub.host.com", dns.TypeA)
|
|
||||||
assert.Equal(t, Rewritten, r.Reason)
|
|
||||||
assert.Len(t, r.IPList, 1)
|
|
||||||
assert.True(t, net.IP{3, 3, 3, 3}.Equal(r.IPList[0]))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRewritesExceptionCNAME(t *testing.T) {
|
func TestRewritesExceptionCNAME(t *testing.T) {
|
||||||
d := newForTest(nil, nil)
|
d := newForTest(nil, nil)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
// wildcard; exception for a sub-domain
|
// Wildcard and exception for a sub-domain.
|
||||||
d.Rewrites = []RewriteEntry{
|
d.Rewrites = []RewriteEntry{{
|
||||||
{"*.host.com", "2.2.2.2", 0, nil},
|
Domain: "*.host.com",
|
||||||
{"sub.host.com", "sub.host.com", 0, nil},
|
Answer: "2.2.2.2",
|
||||||
}
|
}, {
|
||||||
|
Domain: "sub.host.com",
|
||||||
|
Answer: "sub.host.com",
|
||||||
|
}, {
|
||||||
|
Domain: "*.sub.host.com",
|
||||||
|
Answer: "*.sub.host.com",
|
||||||
|
}}
|
||||||
d.prepareRewrites()
|
d.prepareRewrites()
|
||||||
|
|
||||||
// match sub-domain
|
testCases := []struct {
|
||||||
r := d.processRewrites("my.host.com", dns.TypeA)
|
name string
|
||||||
assert.Equal(t, Rewritten, r.Reason)
|
host string
|
||||||
assert.Len(t, r.IPList, 1)
|
want net.IP
|
||||||
assert.True(t, net.IP{2, 2, 2, 2}.Equal(r.IPList[0]))
|
}{{
|
||||||
|
name: "match_sub-domain",
|
||||||
|
host: "my.host.com",
|
||||||
|
want: net.IP{2, 2, 2, 2},
|
||||||
|
}, {
|
||||||
|
name: "exception_cname",
|
||||||
|
host: "sub.host.com",
|
||||||
|
}, {
|
||||||
|
name: "exception_wildcard",
|
||||||
|
host: "my.sub.host.com",
|
||||||
|
}}
|
||||||
|
|
||||||
// match sub-domain, but handle exception
|
for _, tc := range testCases {
|
||||||
r = d.processRewrites("sub.host.com", dns.TypeA)
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
assert.Equal(t, NotFilteredNotFound, r.Reason)
|
r := d.processRewrites(tc.host, dns.TypeA)
|
||||||
}
|
if tc.want == nil {
|
||||||
|
assert.Equal(t, NotFilteredNotFound, r.Reason)
|
||||||
|
|
||||||
func TestRewritesExceptionWC(t *testing.T) {
|
return
|
||||||
d := newForTest(nil, nil)
|
}
|
||||||
t.Cleanup(d.Close)
|
|
||||||
// wildcard; exception for a sub-wildcard
|
assert.Equal(t, Rewritten, r.Reason)
|
||||||
d.Rewrites = []RewriteEntry{
|
require.Len(t, r.IPList, 1)
|
||||||
{"*.host.com", "2.2.2.2", 0, nil},
|
assert.True(t, tc.want.Equal(r.IPList[0]))
|
||||||
{"*.sub.host.com", "*.sub.host.com", 0, nil},
|
})
|
||||||
}
|
}
|
||||||
d.prepareRewrites()
|
|
||||||
|
|
||||||
// match sub-domain
|
|
||||||
r := d.processRewrites("my.host.com", dns.TypeA)
|
|
||||||
assert.Equal(t, Rewritten, r.Reason)
|
|
||||||
assert.Len(t, r.IPList, 1)
|
|
||||||
assert.True(t, net.IP{2, 2, 2, 2}.Equal(r.IPList[0]))
|
|
||||||
|
|
||||||
// match sub-domain, but handle exception
|
|
||||||
r = d.processRewrites("my.sub.host.com", dns.TypeA)
|
|
||||||
assert.Equal(t, NotFilteredNotFound, r.Reason)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRewritesExceptionIP(t *testing.T) {
|
func TestRewritesExceptionIP(t *testing.T) {
|
||||||
d := newForTest(nil, nil)
|
d := newForTest(nil, nil)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
// exception for AAAA record
|
// Exception for AAAA record.
|
||||||
d.Rewrites = []RewriteEntry{
|
d.Rewrites = []RewriteEntry{{
|
||||||
{"host.com", "1.2.3.4", 0, nil},
|
Domain: "host.com",
|
||||||
{"host.com", "AAAA", 0, nil},
|
Answer: "1.2.3.4",
|
||||||
{"host2.com", "::1", 0, nil},
|
}, {
|
||||||
{"host2.com", "A", 0, nil},
|
Domain: "host.com",
|
||||||
{"host3.com", "A", 0, nil},
|
Answer: "AAAA",
|
||||||
}
|
}, {
|
||||||
|
Domain: "host2.com",
|
||||||
|
Answer: "::1",
|
||||||
|
}, {
|
||||||
|
Domain: "host2.com",
|
||||||
|
Answer: "A",
|
||||||
|
}, {
|
||||||
|
Domain: "host3.com",
|
||||||
|
Answer: "A",
|
||||||
|
}}
|
||||||
d.prepareRewrites()
|
d.prepareRewrites()
|
||||||
|
|
||||||
// match domain
|
testCases := []struct {
|
||||||
r := d.processRewrites("host.com", dns.TypeA)
|
name string
|
||||||
assert.Equal(t, Rewritten, r.Reason)
|
host string
|
||||||
assert.Len(t, r.IPList, 1)
|
dtyp uint16
|
||||||
assert.True(t, net.IP{1, 2, 3, 4}.Equal(r.IPList[0]))
|
want []net.IP
|
||||||
|
}{{
|
||||||
|
name: "match_A",
|
||||||
|
host: "host.com",
|
||||||
|
dtyp: dns.TypeA,
|
||||||
|
want: []net.IP{{1, 2, 3, 4}},
|
||||||
|
}, {
|
||||||
|
name: "exception_AAAA_host.com",
|
||||||
|
host: "host.com",
|
||||||
|
dtyp: dns.TypeAAAA,
|
||||||
|
}, {
|
||||||
|
name: "exception_A_host2.com",
|
||||||
|
host: "host2.com",
|
||||||
|
dtyp: dns.TypeA,
|
||||||
|
}, {
|
||||||
|
name: "match_AAAA_host2.com",
|
||||||
|
host: "host2.com",
|
||||||
|
dtyp: dns.TypeAAAA,
|
||||||
|
want: []net.IP{net.ParseIP("::1")},
|
||||||
|
}, {
|
||||||
|
name: "exception_A_host3.com",
|
||||||
|
host: "host3.com",
|
||||||
|
dtyp: dns.TypeA,
|
||||||
|
}, {
|
||||||
|
name: "match_AAAA_host3.com",
|
||||||
|
host: "host3.com",
|
||||||
|
dtyp: dns.TypeAAAA,
|
||||||
|
want: []net.IP{},
|
||||||
|
}}
|
||||||
|
|
||||||
// match exception
|
for _, tc := range testCases {
|
||||||
r = d.processRewrites("host.com", dns.TypeAAAA)
|
t.Run(tc.name+"_"+tc.host, func(t *testing.T) {
|
||||||
assert.Equal(t, NotFilteredNotFound, r.Reason)
|
r := d.processRewrites(tc.host, tc.dtyp)
|
||||||
|
if tc.want == nil {
|
||||||
|
assert.Equal(t, NotFilteredNotFound, r.Reason)
|
||||||
|
|
||||||
// match exception
|
return
|
||||||
r = d.processRewrites("host2.com", dns.TypeA)
|
}
|
||||||
assert.Equal(t, NotFilteredNotFound, r.Reason)
|
|
||||||
|
|
||||||
// match domain
|
assert.Equal(t, Rewritten, r.Reason)
|
||||||
r = d.processRewrites("host2.com", dns.TypeAAAA)
|
require.Len(t, r.IPList, len(tc.want))
|
||||||
assert.Equal(t, Rewritten, r.Reason)
|
for _, ip := range tc.want {
|
||||||
assert.Len(t, r.IPList, 1)
|
assert.True(t, ip.Equal(r.IPList[0]))
|
||||||
assert.Equal(t, "::1", r.IPList[0].String())
|
}
|
||||||
|
})
|
||||||
// match exception
|
}
|
||||||
r = d.processRewrites("host3.com", dns.TypeA)
|
|
||||||
assert.Equal(t, NotFilteredNotFound, r.Reason)
|
|
||||||
|
|
||||||
// match domain
|
|
||||||
r = d.processRewrites("host3.com", dns.TypeAAAA)
|
|
||||||
assert.Equal(t, Rewritten, r.Reason)
|
|
||||||
assert.Empty(t, r.IPList)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||||
"github.com/AdguardTeam/golibs/cache"
|
"github.com/AdguardTeam/golibs/cache"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSafeBrowsingHash(t *testing.T) {
|
func TestSafeBrowsingHash(t *testing.T) {
|
||||||
|
@ -155,25 +156,25 @@ func TestSBPC(t *testing.T) {
|
||||||
}}
|
}}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
// Prepare the upstream.
|
||||||
// Prepare the upstream.
|
ups := &aghtest.TestBlockUpstream{
|
||||||
ups := &aghtest.TestBlockUpstream{
|
Hostname: hostname,
|
||||||
Hostname: hostname,
|
Block: tc.block,
|
||||||
Block: tc.block,
|
}
|
||||||
}
|
d.SetSafeBrowsingUpstream(ups)
|
||||||
d.SetSafeBrowsingUpstream(ups)
|
d.SetParentalUpstream(ups)
|
||||||
d.SetParentalUpstream(ups)
|
|
||||||
|
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
// Firstly, check the request blocking.
|
// Firstly, check the request blocking.
|
||||||
hits := 0
|
hits := 0
|
||||||
res, err := tc.testFunc(hostname)
|
res, err := tc.testFunc(hostname)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
if tc.block {
|
if tc.block {
|
||||||
assert.True(t, res.IsFiltered)
|
assert.True(t, res.IsFiltered)
|
||||||
assert.Len(t, res.Rules, 1)
|
require.Len(t, res.Rules, 1)
|
||||||
hits++
|
hits++
|
||||||
} else {
|
} else {
|
||||||
assert.False(t, res.IsFiltered)
|
require.False(t, res.IsFiltered)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check the cache state, check the response is now cached.
|
// Check the cache state, check the response is now cached.
|
||||||
|
@ -185,12 +186,12 @@ func TestSBPC(t *testing.T) {
|
||||||
|
|
||||||
// Now make the same request to check the cache was used.
|
// Now make the same request to check the cache was used.
|
||||||
res, err = tc.testFunc(hostname)
|
res, err = tc.testFunc(hostname)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
if tc.block {
|
if tc.block {
|
||||||
assert.True(t, res.IsFiltered)
|
assert.True(t, res.IsFiltered)
|
||||||
assert.Len(t, res.Rules, 1)
|
require.Len(t, res.Rules, 1)
|
||||||
} else {
|
} else {
|
||||||
assert.False(t, res.IsFiltered)
|
require.False(t, res.IsFiltered)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check the cache state, it should've been used.
|
// Check the cache state, it should've been used.
|
||||||
|
@ -199,8 +200,8 @@ func TestSBPC(t *testing.T) {
|
||||||
|
|
||||||
// Check that there were no additional requests.
|
// Check that there were no additional requests.
|
||||||
assert.Equal(t, 1, ups.RequestsCount())
|
assert.Equal(t, 1, ups.RequestsCount())
|
||||||
|
|
||||||
purgeCaches()
|
|
||||||
})
|
})
|
||||||
|
|
||||||
|
purgeCaches()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,71 +5,153 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestIsBlockedIPAllowed(t *testing.T) {
|
func TestIsBlockedIP(t *testing.T) {
|
||||||
a := &accessCtx{}
|
const (
|
||||||
assert.Nil(t, a.Init([]string{"1.1.1.1", "2.2.0.0/16"}, nil, nil))
|
ip int = iota
|
||||||
|
cidr
|
||||||
|
)
|
||||||
|
|
||||||
disallowed, disallowedRule := a.IsBlockedIP(net.IPv4(1, 1, 1, 1))
|
rules := []string{
|
||||||
assert.False(t, disallowed)
|
ip: "1.1.1.1",
|
||||||
assert.Empty(t, disallowedRule)
|
cidr: "2.2.0.0/16",
|
||||||
|
}
|
||||||
|
|
||||||
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(1, 1, 1, 2))
|
testCases := []struct {
|
||||||
assert.True(t, disallowed)
|
name string
|
||||||
assert.Empty(t, disallowedRule)
|
allowed bool
|
||||||
|
ip net.IP
|
||||||
|
wantDis bool
|
||||||
|
wantRule string
|
||||||
|
}{{
|
||||||
|
name: "allow_ip",
|
||||||
|
allowed: true,
|
||||||
|
ip: net.IPv4(1, 1, 1, 1),
|
||||||
|
wantDis: false,
|
||||||
|
wantRule: "",
|
||||||
|
}, {
|
||||||
|
name: "disallow_ip",
|
||||||
|
allowed: true,
|
||||||
|
ip: net.IPv4(1, 1, 1, 2),
|
||||||
|
wantDis: true,
|
||||||
|
wantRule: "",
|
||||||
|
}, {
|
||||||
|
name: "allow_cidr",
|
||||||
|
allowed: true,
|
||||||
|
ip: net.IPv4(2, 2, 1, 1),
|
||||||
|
wantDis: false,
|
||||||
|
wantRule: "",
|
||||||
|
}, {
|
||||||
|
name: "disallow_cidr",
|
||||||
|
allowed: true,
|
||||||
|
ip: net.IPv4(2, 3, 1, 1),
|
||||||
|
wantDis: true,
|
||||||
|
wantRule: "",
|
||||||
|
}, {
|
||||||
|
name: "allow_ip",
|
||||||
|
allowed: false,
|
||||||
|
ip: net.IPv4(1, 1, 1, 1),
|
||||||
|
wantDis: true,
|
||||||
|
wantRule: rules[ip],
|
||||||
|
}, {
|
||||||
|
name: "disallow_ip",
|
||||||
|
allowed: false,
|
||||||
|
ip: net.IPv4(1, 1, 1, 2),
|
||||||
|
wantDis: false,
|
||||||
|
wantRule: "",
|
||||||
|
}, {
|
||||||
|
name: "allow_cidr",
|
||||||
|
allowed: false,
|
||||||
|
ip: net.IPv4(2, 2, 1, 1),
|
||||||
|
wantDis: true,
|
||||||
|
wantRule: rules[cidr],
|
||||||
|
}, {
|
||||||
|
name: "disallow_cidr",
|
||||||
|
allowed: false,
|
||||||
|
ip: net.IPv4(2, 3, 1, 1),
|
||||||
|
wantDis: false,
|
||||||
|
wantRule: "",
|
||||||
|
}}
|
||||||
|
|
||||||
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 2, 1, 1))
|
for _, tc := range testCases {
|
||||||
assert.False(t, disallowed)
|
prefix := "allowed_"
|
||||||
assert.Empty(t, disallowedRule)
|
if !tc.allowed {
|
||||||
|
prefix = "disallowed_"
|
||||||
|
}
|
||||||
|
|
||||||
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 3, 1, 1))
|
t.Run(prefix+tc.name, func(t *testing.T) {
|
||||||
assert.True(t, disallowed)
|
aCtx := &accessCtx{}
|
||||||
assert.Empty(t, disallowedRule)
|
allowedRules := rules
|
||||||
|
var disallowedRules []string
|
||||||
|
|
||||||
|
if !tc.allowed {
|
||||||
|
allowedRules, disallowedRules = disallowedRules, allowedRules
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Nil(t, aCtx.Init(allowedRules, disallowedRules, nil))
|
||||||
|
|
||||||
|
disallowed, rule := aCtx.IsBlockedIP(tc.ip)
|
||||||
|
assert.Equal(t, tc.wantDis, disallowed)
|
||||||
|
assert.Equal(t, tc.wantRule, rule)
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIsBlockedIPDisallowed(t *testing.T) {
|
func TestIsBlockedDomain(t *testing.T) {
|
||||||
a := &accessCtx{}
|
aCtx := &accessCtx{}
|
||||||
assert.Nil(t, a.Init(nil, []string{"1.1.1.1", "2.2.0.0/16"}, nil))
|
require.Nil(t, aCtx.Init(nil, nil, []string{
|
||||||
|
|
||||||
disallowed, disallowedRule := a.IsBlockedIP(net.IPv4(1, 1, 1, 1))
|
|
||||||
assert.True(t, disallowed)
|
|
||||||
assert.Equal(t, "1.1.1.1", disallowedRule)
|
|
||||||
|
|
||||||
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(1, 1, 1, 2))
|
|
||||||
assert.False(t, disallowed)
|
|
||||||
assert.Empty(t, disallowedRule)
|
|
||||||
|
|
||||||
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 2, 1, 1))
|
|
||||||
assert.True(t, disallowed)
|
|
||||||
assert.Equal(t, "2.2.0.0/16", disallowedRule)
|
|
||||||
|
|
||||||
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 3, 1, 1))
|
|
||||||
assert.False(t, disallowed)
|
|
||||||
assert.Empty(t, disallowedRule)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsBlockedIPBlockedDomain(t *testing.T) {
|
|
||||||
a := &accessCtx{}
|
|
||||||
assert.True(t, a.Init(nil, nil, []string{
|
|
||||||
"host1",
|
"host1",
|
||||||
"host2",
|
|
||||||
"*.host.com",
|
"*.host.com",
|
||||||
"||host3.com^",
|
"||host3.com^",
|
||||||
}) == nil)
|
}))
|
||||||
|
|
||||||
// match by "host2.com"
|
testCases := []struct {
|
||||||
assert.True(t, a.IsBlockedDomain("host1"))
|
name string
|
||||||
assert.True(t, a.IsBlockedDomain("host2"))
|
domain string
|
||||||
assert.False(t, a.IsBlockedDomain("host3"))
|
want bool
|
||||||
|
}{{
|
||||||
|
name: "plain_match",
|
||||||
|
domain: "host1",
|
||||||
|
want: true,
|
||||||
|
}, {
|
||||||
|
name: "plain_mismatch",
|
||||||
|
domain: "host2",
|
||||||
|
want: false,
|
||||||
|
}, {
|
||||||
|
name: "wildcard_type-1_match_short",
|
||||||
|
domain: "asdf.host.com",
|
||||||
|
want: true,
|
||||||
|
}, {
|
||||||
|
name: "wildcard_type-1_match_long",
|
||||||
|
domain: "qwer.asdf.host.com",
|
||||||
|
want: true,
|
||||||
|
}, {
|
||||||
|
name: "wildcard_type-1_mismatch_no-lead",
|
||||||
|
domain: "host.com",
|
||||||
|
want: false,
|
||||||
|
}, {
|
||||||
|
name: "wildcard_type-1_mismatch_bad-asterisk",
|
||||||
|
domain: "asdf.zhost.com",
|
||||||
|
want: false,
|
||||||
|
}, {
|
||||||
|
name: "wildcard_type-2_match_simple",
|
||||||
|
domain: "host3.com",
|
||||||
|
want: true,
|
||||||
|
}, {
|
||||||
|
name: "wildcard_type-2_match_complex",
|
||||||
|
domain: "asdf.host3.com",
|
||||||
|
want: true,
|
||||||
|
}, {
|
||||||
|
name: "wildcard_type-2_mismatch",
|
||||||
|
domain: ".host3.com",
|
||||||
|
want: false,
|
||||||
|
}}
|
||||||
|
|
||||||
// match by wildcard "*.host.com"
|
for _, tc := range testCases {
|
||||||
assert.False(t, a.IsBlockedDomain("host.com"))
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
assert.True(t, a.IsBlockedDomain("asdf.host.com"))
|
assert.Equal(t, tc.want, aCtx.IsBlockedDomain(tc.domain))
|
||||||
assert.True(t, a.IsBlockedDomain("qwer.asdf.host.com"))
|
})
|
||||||
assert.False(t, a.IsBlockedDomain("asdf.zhost.com"))
|
}
|
||||||
|
|
||||||
// match by wildcard "||host3.com^"
|
|
||||||
assert.True(t, a.IsBlockedDomain("host3.com"))
|
|
||||||
assert.True(t, a.IsBlockedDomain("asdf.host3.com"))
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,6 +27,7 @@ import (
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
|
@ -42,14 +43,180 @@ func startDeferStop(t *testing.T, s *Server) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
err := s.Start()
|
err := s.Start()
|
||||||
assert.Nilf(t, err, "failed to start server: %s", err)
|
require.Nilf(t, err, "failed to start server: %s", err)
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
err := s.Stop()
|
err := s.Stop()
|
||||||
assert.Nilf(t, err, "dns server failed to stop: %s", err)
|
require.Nilf(t, err, "dns server failed to stop: %s", err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func createTestServer(t *testing.T, filterConf *dnsfilter.Config, forwardConf ServerConfig) *Server {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
rules := `||nxdomain.example.org
|
||||||
|
||null.example.org^
|
||||||
|
127.0.0.1 host.example.org
|
||||||
|
@@||whitelist.example.org^
|
||||||
|
||127.0.0.255`
|
||||||
|
filters := []dnsfilter.Filter{{
|
||||||
|
ID: 0, Data: []byte(rules),
|
||||||
|
}}
|
||||||
|
|
||||||
|
f := dnsfilter.New(filterConf, filters)
|
||||||
|
|
||||||
|
s := NewServer(DNSCreateParams{DNSFilter: f})
|
||||||
|
s.conf = forwardConf
|
||||||
|
require.Nil(t, s.Prepare(nil))
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func createServerTLSConfig(t *testing.T) (*tls.Config, []byte, []byte) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
require.Nilf(t, err, "cannot generate RSA key: %s", err)
|
||||||
|
|
||||||
|
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||||
|
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||||
|
require.Nilf(t, err, "failed to generate serial number: %s", err)
|
||||||
|
|
||||||
|
notBefore := time.Now()
|
||||||
|
notAfter := notBefore.Add(5 * 365 * time.Hour * 24)
|
||||||
|
|
||||||
|
template := x509.Certificate{
|
||||||
|
SerialNumber: serialNumber,
|
||||||
|
Subject: pkix.Name{
|
||||||
|
Organization: []string{"AdGuard Tests"},
|
||||||
|
},
|
||||||
|
NotBefore: notBefore,
|
||||||
|
NotAfter: notAfter,
|
||||||
|
|
||||||
|
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
IsCA: true,
|
||||||
|
}
|
||||||
|
template.DNSNames = append(template.DNSNames, tlsServerName)
|
||||||
|
|
||||||
|
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(privateKey), privateKey)
|
||||||
|
require.Nilf(t, err, "failed to create certificate: %s", err)
|
||||||
|
|
||||||
|
certPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||||
|
keyPem := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})
|
||||||
|
|
||||||
|
cert, err := tls.X509KeyPair(certPem, keyPem)
|
||||||
|
require.Nilf(t, err, "failed to create certificate: %s", err)
|
||||||
|
|
||||||
|
return &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
ServerName: tlsServerName,
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
}, certPem, keyPem
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var keyPem []byte
|
||||||
|
_, certPem, keyPem = createServerTLSConfig(t)
|
||||||
|
|
||||||
|
s = createTestServer(t, &dnsfilter.Config{}, ServerConfig{
|
||||||
|
UDPListenAddr: &net.UDPAddr{},
|
||||||
|
TCPListenAddr: &net.TCPAddr{},
|
||||||
|
})
|
||||||
|
|
||||||
|
tlsConf.CertificateChainData, tlsConf.PrivateKeyData = certPem, keyPem
|
||||||
|
s.conf.TLSConfig = tlsConf
|
||||||
|
|
||||||
|
err := s.Prepare(nil)
|
||||||
|
require.Nilf(t, err, "failed to prepare server: %s", err)
|
||||||
|
|
||||||
|
return s, certPem
|
||||||
|
}
|
||||||
|
|
||||||
|
func createGoogleATestMessage() *dns.Msg {
|
||||||
|
return createTestMessage("google-public-dns-a.google.com.")
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTestMessage(host string) *dns.Msg {
|
||||||
|
return &dns.Msg{
|
||||||
|
MsgHdr: dns.MsgHdr{
|
||||||
|
Id: dns.Id(),
|
||||||
|
RecursionDesired: true,
|
||||||
|
},
|
||||||
|
Question: []dns.Question{{
|
||||||
|
Name: host,
|
||||||
|
Qtype: dns.TypeA,
|
||||||
|
Qclass: dns.ClassINET,
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTestMessageWithType(host string, qtype uint16) *dns.Msg {
|
||||||
|
req := createTestMessage(host)
|
||||||
|
req.Question[0].Qtype = qtype
|
||||||
|
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertGoogleAResponse(t *testing.T, reply *dns.Msg) {
|
||||||
|
assertResponse(t, reply, net.IP{8, 8, 8, 8})
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertResponse(t *testing.T, reply *dns.Msg, ip net.IP) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
require.Lenf(t, reply.Answer, 1, "dns server returned reply with wrong number of answers - %d", len(reply.Answer))
|
||||||
|
|
||||||
|
a, ok := reply.Answer[0].(*dns.A)
|
||||||
|
require.Truef(t, ok, "dns server returned wrong answer type instead of A: %v", reply.Answer[0])
|
||||||
|
assert.Truef(t, a.A.Equal(ip), "dns server returned wrong answer instead of %s: %s", ip, a.A)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendTestMessagesAsync sends messages in parallel to check for race issues.
|
||||||
|
//
|
||||||
|
//lint:ignore U1000 it's called from the function which is skipped for now.
|
||||||
|
func sendTestMessagesAsync(t *testing.T, conn *dns.Conn) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
wg := &sync.WaitGroup{}
|
||||||
|
|
||||||
|
for i := 0; i < testMessagesCount; i++ {
|
||||||
|
msg := createGoogleATestMessage()
|
||||||
|
wg.Add(1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
err := conn.WriteMsg(msg)
|
||||||
|
require.Nilf(t, err, "cannot write message: %s", err)
|
||||||
|
|
||||||
|
res, err := conn.ReadMsg()
|
||||||
|
require.Nilf(t, err, "cannot read response to message: %s", err)
|
||||||
|
|
||||||
|
assertGoogleAResponse(t, res)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendTestMessages(t *testing.T, conn *dns.Conn) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
for i := 0; i < testMessagesCount; i++ {
|
||||||
|
req := createGoogleATestMessage()
|
||||||
|
err := conn.WriteMsg(req)
|
||||||
|
assert.Nilf(t, err, "cannot write message #%d: %s", i, err)
|
||||||
|
|
||||||
|
res, err := conn.ReadMsg()
|
||||||
|
assert.Nilf(t, err, "cannot read response to message #%d: %s", i, err)
|
||||||
|
assertGoogleAResponse(t, res)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestServer(t *testing.T) {
|
func TestServer(t *testing.T) {
|
||||||
s := createTestServer(t, &dnsfilter.Config{}, ServerConfig{
|
s := createTestServer(t, &dnsfilter.Config{}, ServerConfig{
|
||||||
UDPListenAddr: &net.UDPAddr{},
|
UDPListenAddr: &net.UDPAddr{},
|
||||||
|
@ -81,7 +248,7 @@ func TestServer(t *testing.T) {
|
||||||
client := dns.Client{Net: tc.proto}
|
client := dns.Client{Net: tc.proto}
|
||||||
|
|
||||||
reply, _, err := client.Exchange(createGoogleATestMessage(), addr.String())
|
reply, _, err := client.Exchange(createGoogleATestMessage(), addr.String())
|
||||||
assert.Nilf(t, err, "сouldn't talk to server %s: %s", addr, err)
|
require.Nilf(t, err, "сouldn't talk to server %s: %s", addr, err)
|
||||||
|
|
||||||
assertGoogleAResponse(t, reply)
|
assertGoogleAResponse(t, reply)
|
||||||
})
|
})
|
||||||
|
@ -106,31 +273,12 @@ func TestServerWithProtectionDisabled(t *testing.T) {
|
||||||
req := createGoogleATestMessage()
|
req := createGoogleATestMessage()
|
||||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||||
client := dns.Client{Net: proxy.ProtoUDP}
|
client := dns.Client{Net: proxy.ProtoUDP}
|
||||||
|
|
||||||
reply, _, err := client.Exchange(req, addr.String())
|
reply, _, err := client.Exchange(req, addr.String())
|
||||||
assert.Nilf(t, err, "сouldn't talk to server %s: %s", addr, err)
|
require.Nilf(t, err, "сouldn't talk to server %s: %s", addr, err)
|
||||||
assertGoogleAResponse(t, reply)
|
assertGoogleAResponse(t, reply)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
var keyPem []byte
|
|
||||||
_, certPem, keyPem = createServerTLSConfig(t)
|
|
||||||
|
|
||||||
s = createTestServer(t, &dnsfilter.Config{}, ServerConfig{
|
|
||||||
UDPListenAddr: &net.UDPAddr{},
|
|
||||||
TCPListenAddr: &net.TCPAddr{},
|
|
||||||
})
|
|
||||||
|
|
||||||
tlsConf.CertificateChainData, tlsConf.PrivateKeyData = certPem, keyPem
|
|
||||||
s.conf.TLSConfig = tlsConf
|
|
||||||
|
|
||||||
err := s.Prepare(nil)
|
|
||||||
assert.Nilf(t, err, "failed to prepare server: %s", err)
|
|
||||||
|
|
||||||
return s, certPem
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDoTServer(t *testing.T) {
|
func TestDoTServer(t *testing.T) {
|
||||||
s, certPem := createTestTLS(t, TLSConfig{
|
s, certPem := createTestTLS(t, TLSConfig{
|
||||||
TLSListenAddr: &net.TCPAddr{},
|
TLSListenAddr: &net.TCPAddr{},
|
||||||
|
@ -156,7 +304,7 @@ func TestDoTServer(t *testing.T) {
|
||||||
// Create a DNS-over-TLS client connection.
|
// Create a DNS-over-TLS client connection.
|
||||||
addr := s.dnsProxy.Addr(proxy.ProtoTLS)
|
addr := s.dnsProxy.Addr(proxy.ProtoTLS)
|
||||||
conn, err := dns.DialWithTLS("tcp-tls", addr.String(), tlsConfig)
|
conn, err := dns.DialWithTLS("tcp-tls", addr.String(), tlsConfig)
|
||||||
assert.Nilf(t, err, "cannot connect to the proxy: %s", err)
|
require.Nilf(t, err, "cannot connect to the proxy: %s", err)
|
||||||
|
|
||||||
sendTestMessages(t, conn)
|
sendTestMessages(t, conn)
|
||||||
}
|
}
|
||||||
|
@ -178,12 +326,12 @@ func TestDoQServer(t *testing.T) {
|
||||||
addr := s.dnsProxy.Addr(proxy.ProtoQUIC)
|
addr := s.dnsProxy.Addr(proxy.ProtoQUIC)
|
||||||
opts := upstream.Options{InsecureSkipVerify: true}
|
opts := upstream.Options{InsecureSkipVerify: true}
|
||||||
u, err := upstream.AddressToUpstream(fmt.Sprintf("%s://%s", proxy.ProtoQUIC, addr), opts)
|
u, err := upstream.AddressToUpstream(fmt.Sprintf("%s://%s", proxy.ProtoQUIC, addr), opts)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
// Send the test message.
|
// Send the test message.
|
||||||
req := createGoogleATestMessage()
|
req := createGoogleATestMessage()
|
||||||
res, err := u.Exchange(req)
|
res, err := u.Exchange(req)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
assertGoogleAResponse(t, res)
|
assertGoogleAResponse(t, res)
|
||||||
}
|
}
|
||||||
|
@ -221,7 +369,7 @@ func TestServerRace(t *testing.T) {
|
||||||
// Message over UDP.
|
// Message over UDP.
|
||||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||||
conn, err := dns.Dial(proxy.ProtoUDP, addr.String())
|
conn, err := dns.Dial(proxy.ProtoUDP, addr.String())
|
||||||
assert.Nilf(t, err, "cannot connect to the proxy: %s", err)
|
require.Nilf(t, err, "cannot connect to the proxy: %s", err)
|
||||||
|
|
||||||
sendTestMessagesAsync(t, conn)
|
sendTestMessagesAsync(t, conn)
|
||||||
}
|
}
|
||||||
|
@ -282,8 +430,9 @@ func TestSafeSearch(t *testing.T) {
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.host, func(t *testing.T) {
|
t.Run(tc.host, func(t *testing.T) {
|
||||||
req := createTestMessage(tc.host)
|
req := createTestMessage(tc.host)
|
||||||
|
|
||||||
reply, _, err := client.Exchange(req, addr)
|
reply, _, err := client.Exchange(req, addr)
|
||||||
assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err)
|
require.Nilf(t, err, "couldn't talk to server %s: %s", addr, err)
|
||||||
assertResponse(t, reply, tc.want)
|
assertResponse(t, reply, tc.want)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -330,8 +479,10 @@ func TestBlockedRequest(t *testing.T) {
|
||||||
req := createTestMessage("nxdomain.example.org.")
|
req := createTestMessage("nxdomain.example.org.")
|
||||||
|
|
||||||
reply, err := dns.Exchange(req, addr.String())
|
reply, err := dns.Exchange(req, addr.String())
|
||||||
assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err)
|
require.Nilf(t, err, "couldn't talk to server %s: %s", addr, err)
|
||||||
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
|
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
|
||||||
|
|
||||||
|
require.Len(t, reply.Answer, 1)
|
||||||
assert.True(t, reply.Answer[0].(*dns.A).A.IsUnspecified())
|
assert.True(t, reply.Answer[0].(*dns.A).A.IsUnspecified())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -364,28 +515,14 @@ func TestServerCustomClientUpstream(t *testing.T) {
|
||||||
|
|
||||||
reply, err := dns.Exchange(req, addr.String())
|
reply, err := dns.Exchange(req, addr.String())
|
||||||
|
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
|
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
|
||||||
assert.NotEmpty(t, reply.Answer)
|
require.NotEmpty(t, reply.Answer)
|
||||||
|
|
||||||
|
require.Len(t, reply.Answer, 1)
|
||||||
assert.Equal(t, net.IP{192, 168, 0, 1}, reply.Answer[0].(*dns.A).A)
|
assert.Equal(t, net.IP{192, 168, 0, 1}, reply.Answer[0].(*dns.A).A)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) startWithUpstream(u upstream.Upstream) error {
|
|
||||||
s.Lock()
|
|
||||||
defer s.Unlock()
|
|
||||||
err := s.Prepare(nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
s.dnsProxy.UpstreamConfig = &proxy.UpstreamConfig{
|
|
||||||
Upstreams: []upstream.Upstream{u},
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.dnsProxy.Start()
|
|
||||||
}
|
|
||||||
|
|
||||||
// testCNAMEs is a map of names and CNAMEs necessary for the TestUpstream work.
|
// testCNAMEs is a map of names and CNAMEs necessary for the TestUpstream work.
|
||||||
var testCNAMEs = map[string]string{
|
var testCNAMEs = map[string]string{
|
||||||
"badhost.": "null.example.org.",
|
"badhost.": "null.example.org.",
|
||||||
|
@ -409,15 +546,19 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
|
||||||
IPv6: nil,
|
IPv6: nil,
|
||||||
}
|
}
|
||||||
s.conf.ProtectionEnabled = false
|
s.conf.ProtectionEnabled = false
|
||||||
err := s.startWithUpstream(testUpstm)
|
s.dnsProxy.UpstreamConfig = &proxy.UpstreamConfig{
|
||||||
assert.Nil(t, err)
|
Upstreams: []upstream.Upstream{testUpstm},
|
||||||
|
}
|
||||||
|
startDeferStop(t, s)
|
||||||
|
|
||||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||||
|
|
||||||
// 'badhost' has a canonical name 'null.example.org' which is blocked by
|
// 'badhost' has a canonical name 'null.example.org' which should be
|
||||||
// filters: but protection is disabled so response is _not_ blocked.
|
// blocked by filters, but protection is disabled so it is not.
|
||||||
req := createTestMessage("badhost.")
|
req := createTestMessage("badhost.")
|
||||||
|
|
||||||
reply, err := dns.Exchange(req, addr.String())
|
reply, err := dns.Exchange(req, addr.String())
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
|
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -465,11 +606,15 @@ func TestBlockCNAME(t *testing.T) {
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run("block_cname_"+tc.host, func(t *testing.T) {
|
t.Run("block_cname_"+tc.host, func(t *testing.T) {
|
||||||
req := createTestMessage(tc.host)
|
req := createTestMessage(tc.host)
|
||||||
|
|
||||||
reply, err := dns.Exchange(req, addr)
|
reply, err := dns.Exchange(req, addr)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
|
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
|
||||||
if tc.want {
|
if tc.want {
|
||||||
assert.True(t, reply.Answer[0].(*dns.A).A.IsUnspecified())
|
require.Len(t, reply.Answer, 1)
|
||||||
|
a, ok := reply.Answer[0].(*dns.A)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.True(t, a.A.IsUnspecified())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -513,7 +658,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) {
|
||||||
// However, in our case it should not be blocked as filtering is
|
// However, in our case it should not be blocked as filtering is
|
||||||
// disabled on the client level.
|
// disabled on the client level.
|
||||||
reply, err := dns.Exchange(&req, addr.String())
|
reply, err := dns.Exchange(&req, addr.String())
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
|
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -544,10 +689,10 @@ func TestNullBlockedRequest(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
reply, err := dns.Exchange(&req, addr.String())
|
reply, err := dns.Exchange(&req, addr.String())
|
||||||
assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err)
|
require.Nilf(t, err, "couldn't talk to server %s: %s", addr, err)
|
||||||
assert.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer))
|
require.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer))
|
||||||
a, ok := reply.Answer[0].(*dns.A)
|
a, ok := reply.Answer[0].(*dns.A)
|
||||||
assert.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0])
|
require.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0])
|
||||||
assert.Truef(t, a.A.IsUnspecified(), "dns server %s returned wrong answer instead of 0.0.0.0: %v", addr, a.A)
|
assert.Truef(t, a.A.IsUnspecified(), "dns server %s returned wrong answer instead of 0.0.0.0: %v", addr, a.A)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -561,7 +706,7 @@ func TestBlockedCustomIP(t *testing.T) {
|
||||||
s := NewServer(DNSCreateParams{
|
s := NewServer(DNSCreateParams{
|
||||||
DNSFilter: dnsfilter.New(&dnsfilter.Config{}, filters),
|
DNSFilter: dnsfilter.New(&dnsfilter.Config{}, filters),
|
||||||
})
|
})
|
||||||
conf := ServerConfig{
|
conf := &ServerConfig{
|
||||||
UDPListenAddr: &net.UDPAddr{},
|
UDPListenAddr: &net.UDPAddr{},
|
||||||
TCPListenAddr: &net.TCPAddr{},
|
TCPListenAddr: &net.TCPAddr{},
|
||||||
FilteringConfig: FilteringConfig{
|
FilteringConfig: FilteringConfig{
|
||||||
|
@ -572,11 +717,11 @@ func TestBlockedCustomIP(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
// Invalid BlockingIPv4.
|
// Invalid BlockingIPv4.
|
||||||
assert.NotNil(t, s.Prepare(&conf))
|
assert.NotNil(t, s.Prepare(conf))
|
||||||
|
|
||||||
conf.BlockingIPv4 = net.IP{0, 0, 0, 1}
|
conf.BlockingIPv4 = net.IP{0, 0, 0, 1}
|
||||||
conf.BlockingIPv6 = net.ParseIP("::1")
|
conf.BlockingIPv6 = net.ParseIP("::1")
|
||||||
assert.Nil(t, s.Prepare(&conf))
|
require.Nil(t, s.Prepare(conf))
|
||||||
|
|
||||||
startDeferStop(t, s)
|
startDeferStop(t, s)
|
||||||
|
|
||||||
|
@ -584,18 +729,18 @@ func TestBlockedCustomIP(t *testing.T) {
|
||||||
|
|
||||||
req := createTestMessageWithType("null.example.org.", dns.TypeA)
|
req := createTestMessageWithType("null.example.org.", dns.TypeA)
|
||||||
reply, err := dns.Exchange(req, addr.String())
|
reply, err := dns.Exchange(req, addr.String())
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Len(t, reply.Answer, 1)
|
require.Len(t, reply.Answer, 1)
|
||||||
a, ok := reply.Answer[0].(*dns.A)
|
a, ok := reply.Answer[0].(*dns.A)
|
||||||
assert.True(t, ok)
|
require.True(t, ok)
|
||||||
assert.True(t, net.IP{0, 0, 0, 1}.Equal(a.A))
|
assert.True(t, net.IP{0, 0, 0, 1}.Equal(a.A))
|
||||||
|
|
||||||
req = createTestMessageWithType("null.example.org.", dns.TypeAAAA)
|
req = createTestMessageWithType("null.example.org.", dns.TypeAAAA)
|
||||||
reply, err = dns.Exchange(req, addr.String())
|
reply, err = dns.Exchange(req, addr.String())
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Len(t, reply.Answer, 1)
|
require.Len(t, reply.Answer, 1)
|
||||||
a6, ok := reply.Answer[0].(*dns.AAAA)
|
a6, ok := reply.Answer[0].(*dns.AAAA)
|
||||||
assert.True(t, ok)
|
require.True(t, ok)
|
||||||
assert.Equal(t, "::1", a6.AAAA.String())
|
assert.Equal(t, "::1", a6.AAAA.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -615,11 +760,10 @@ func TestBlockedByHosts(t *testing.T) {
|
||||||
req := createTestMessage("host.example.org.")
|
req := createTestMessage("host.example.org.")
|
||||||
|
|
||||||
reply, err := dns.Exchange(req, addr.String())
|
reply, err := dns.Exchange(req, addr.String())
|
||||||
assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err)
|
require.Nilf(t, err, "couldn't talk to server %s: %s", addr, err)
|
||||||
assert.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer))
|
require.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer))
|
||||||
|
|
||||||
a, ok := reply.Answer[0].(*dns.A)
|
a, ok := reply.Answer[0].(*dns.A)
|
||||||
assert.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0])
|
require.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0])
|
||||||
assert.Equalf(t, net.IP{127, 0, 0, 1}, a.A, "dns server %s returned wrong answer instead of 8.8.8.8: %v", addr, a.A)
|
assert.Equalf(t, net.IP{127, 0, 0, 1}, a.A, "dns server %s returned wrong answer instead of 8.8.8.8: %v", addr, a.A)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -630,7 +774,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
|
||||||
Hostname: hostname,
|
Hostname: hostname,
|
||||||
Block: true,
|
Block: true,
|
||||||
}
|
}
|
||||||
ans, _ := (&aghtest.TestResolver{}).HostToIPs(hostname)
|
ans4, _ := (&aghtest.TestResolver{}).HostToIPs(hostname)
|
||||||
|
|
||||||
filterConf := &dnsfilter.Config{
|
filterConf := &dnsfilter.Config{
|
||||||
SafeBrowsingEnabled: true,
|
SafeBrowsingEnabled: true,
|
||||||
|
@ -639,7 +783,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
|
||||||
UDPListenAddr: &net.UDPAddr{},
|
UDPListenAddr: &net.UDPAddr{},
|
||||||
TCPListenAddr: &net.TCPAddr{},
|
TCPListenAddr: &net.TCPAddr{},
|
||||||
FilteringConfig: FilteringConfig{
|
FilteringConfig: FilteringConfig{
|
||||||
SafeBrowsingBlockHost: ans.String(),
|
SafeBrowsingBlockHost: ans4.String(),
|
||||||
ProtectionEnabled: true,
|
ProtectionEnabled: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -652,13 +796,12 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
|
||||||
req := createTestMessage(hostname + ".")
|
req := createTestMessage(hostname + ".")
|
||||||
|
|
||||||
reply, err := dns.Exchange(req, addr.String())
|
reply, err := dns.Exchange(req, addr.String())
|
||||||
assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err)
|
require.Nilf(t, err, "couldn't talk to server %s: %s", addr, err)
|
||||||
assert.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer))
|
require.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer))
|
||||||
|
|
||||||
a, ok := reply.Answer[0].(*dns.A)
|
a, ok := reply.Answer[0].(*dns.A)
|
||||||
if assert.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0]) {
|
require.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0])
|
||||||
assert.Equal(t, ans, a.A, "dns server %s returned wrong answer: %v", addr, a.A)
|
assert.Equal(t, ans4, a.A, "dns server %s returned wrong answer: %v", addr, a.A)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRewrite(t *testing.T) {
|
func TestRewrite(t *testing.T) {
|
||||||
|
@ -680,14 +823,14 @@ func TestRewrite(t *testing.T) {
|
||||||
f := dnsfilter.New(c, nil)
|
f := dnsfilter.New(c, nil)
|
||||||
|
|
||||||
s := NewServer(DNSCreateParams{DNSFilter: f})
|
s := NewServer(DNSCreateParams{DNSFilter: f})
|
||||||
err := s.Prepare(&ServerConfig{
|
assert.Nil(t, s.Prepare(&ServerConfig{
|
||||||
UDPListenAddr: &net.UDPAddr{},
|
UDPListenAddr: &net.UDPAddr{},
|
||||||
TCPListenAddr: &net.TCPAddr{},
|
TCPListenAddr: &net.TCPAddr{},
|
||||||
FilteringConfig: FilteringConfig{
|
FilteringConfig: FilteringConfig{
|
||||||
ProtectionEnabled: true,
|
ProtectionEnabled: true,
|
||||||
UpstreamDNS: []string{"8.8.8.8:53"},
|
UpstreamDNS: []string{"8.8.8.8:53"},
|
||||||
},
|
},
|
||||||
})
|
}))
|
||||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||||
&aghtest.TestUpstream{
|
&aghtest.TestUpstream{
|
||||||
CName: map[string]string{
|
CName: map[string]string{
|
||||||
|
@ -698,185 +841,44 @@ func TestRewrite(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
assert.Nil(t, err)
|
|
||||||
startDeferStop(t, s)
|
startDeferStop(t, s)
|
||||||
|
|
||||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||||
|
|
||||||
req := createTestMessageWithType("test.com.", dns.TypeA)
|
req := createTestMessageWithType("test.com.", dns.TypeA)
|
||||||
reply, err := dns.Exchange(req, addr.String())
|
reply, err := dns.Exchange(req, addr.String())
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Len(t, reply.Answer, 1)
|
require.Len(t, reply.Answer, 1)
|
||||||
a, ok := reply.Answer[0].(*dns.A)
|
a, ok := reply.Answer[0].(*dns.A)
|
||||||
assert.True(t, ok)
|
require.True(t, ok)
|
||||||
assert.True(t, net.IP{1, 2, 3, 4}.Equal(a.A))
|
assert.True(t, net.IP{1, 2, 3, 4}.Equal(a.A))
|
||||||
|
|
||||||
req = createTestMessageWithType("test.com.", dns.TypeAAAA)
|
req = createTestMessageWithType("test.com.", dns.TypeAAAA)
|
||||||
reply, err = dns.Exchange(req, addr.String())
|
reply, err = dns.Exchange(req, addr.String())
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Empty(t, reply.Answer)
|
assert.Empty(t, reply.Answer)
|
||||||
|
|
||||||
req = createTestMessageWithType("alias.test.com.", dns.TypeA)
|
req = createTestMessageWithType("alias.test.com.", dns.TypeA)
|
||||||
reply, err = dns.Exchange(req, addr.String())
|
reply, err = dns.Exchange(req, addr.String())
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Len(t, reply.Answer, 2)
|
|
||||||
|
require.Len(t, reply.Answer, 2)
|
||||||
assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target)
|
assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target)
|
||||||
assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A))
|
assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A))
|
||||||
|
|
||||||
req = createTestMessageWithType("my.alias.example.org.", dns.TypeA)
|
req = createTestMessageWithType("my.alias.example.org.", dns.TypeA)
|
||||||
reply, err = dns.Exchange(req, addr.String())
|
reply, err = dns.Exchange(req, addr.String())
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
// The original question is restored.
|
// The original question is restored.
|
||||||
|
require.Len(t, reply.Question, 1)
|
||||||
assert.Equal(t, "my.alias.example.org.", reply.Question[0].Name)
|
assert.Equal(t, "my.alias.example.org.", reply.Question[0].Name)
|
||||||
assert.Len(t, reply.Answer, 2)
|
|
||||||
|
require.Len(t, reply.Answer, 2)
|
||||||
assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target)
|
assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target)
|
||||||
assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
|
assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createTestServer(t *testing.T, filterConf *dnsfilter.Config, forwardConf ServerConfig) *Server {
|
|
||||||
rules := `||nxdomain.example.org
|
|
||||||
||null.example.org^
|
|
||||||
127.0.0.1 host.example.org
|
|
||||||
@@||whitelist.example.org^
|
|
||||||
||127.0.0.255`
|
|
||||||
filters := []dnsfilter.Filter{{
|
|
||||||
ID: 0, Data: []byte(rules),
|
|
||||||
}}
|
|
||||||
|
|
||||||
f := dnsfilter.New(filterConf, filters)
|
|
||||||
|
|
||||||
s := NewServer(DNSCreateParams{DNSFilter: f})
|
|
||||||
s.conf = forwardConf
|
|
||||||
assert.Nil(t, s.Prepare(nil))
|
|
||||||
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func createServerTLSConfig(t *testing.T) (*tls.Config, []byte, []byte) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
assert.Nilf(t, err, "cannot generate RSA key: %s", err)
|
|
||||||
|
|
||||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
|
||||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
|
||||||
assert.Nilf(t, err, "failed to generate serial number: %s", err)
|
|
||||||
|
|
||||||
notBefore := time.Now()
|
|
||||||
notAfter := notBefore.Add(5 * 365 * time.Hour * 24)
|
|
||||||
|
|
||||||
template := x509.Certificate{
|
|
||||||
SerialNumber: serialNumber,
|
|
||||||
Subject: pkix.Name{
|
|
||||||
Organization: []string{"AdGuard Tests"},
|
|
||||||
},
|
|
||||||
NotBefore: notBefore,
|
|
||||||
NotAfter: notAfter,
|
|
||||||
|
|
||||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
|
||||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
|
||||||
BasicConstraintsValid: true,
|
|
||||||
IsCA: true,
|
|
||||||
}
|
|
||||||
template.DNSNames = append(template.DNSNames, tlsServerName)
|
|
||||||
|
|
||||||
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(privateKey), privateKey)
|
|
||||||
assert.Nilf(t, err, "failed to create certificate: %s", err)
|
|
||||||
|
|
||||||
certPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
|
||||||
keyPem := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})
|
|
||||||
|
|
||||||
cert, err := tls.X509KeyPair(certPem, keyPem)
|
|
||||||
assert.Nilf(t, err, "failed to create certificate: %s", err)
|
|
||||||
|
|
||||||
return &tls.Config{
|
|
||||||
Certificates: []tls.Certificate{cert},
|
|
||||||
ServerName: tlsServerName,
|
|
||||||
MinVersion: tls.VersionTLS12,
|
|
||||||
}, certPem, keyPem
|
|
||||||
}
|
|
||||||
|
|
||||||
// sendTestMessagesAsync sends messages in parallel to check for race issues.
|
|
||||||
//lint:ignore U1000 it's called from the function which is skipped for now.
|
|
||||||
func sendTestMessagesAsync(t *testing.T, conn *dns.Conn) {
|
|
||||||
wg := &sync.WaitGroup{}
|
|
||||||
|
|
||||||
for i := 0; i < testMessagesCount; i++ {
|
|
||||||
msg := createGoogleATestMessage()
|
|
||||||
wg.Add(1)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
err := conn.WriteMsg(msg)
|
|
||||||
assert.Nilf(t, err, "cannot write message: %s", err)
|
|
||||||
|
|
||||||
res, err := conn.ReadMsg()
|
|
||||||
assert.Nilf(t, err, "cannot read response to message: %s", err)
|
|
||||||
|
|
||||||
assertGoogleAResponse(t, res)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
func sendTestMessages(t *testing.T, conn *dns.Conn) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
for i := 0; i < testMessagesCount; i++ {
|
|
||||||
req := createGoogleATestMessage()
|
|
||||||
err := conn.WriteMsg(req)
|
|
||||||
assert.Nilf(t, err, "cannot write message #%d: %s", i, err)
|
|
||||||
|
|
||||||
res, err := conn.ReadMsg()
|
|
||||||
assert.Nilf(t, err, "cannot read response to message #%d: %s", i, err)
|
|
||||||
assertGoogleAResponse(t, res)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func createGoogleATestMessage() *dns.Msg {
|
|
||||||
return createTestMessage("google-public-dns-a.google.com.")
|
|
||||||
}
|
|
||||||
|
|
||||||
func createTestMessage(host string) *dns.Msg {
|
|
||||||
return &dns.Msg{
|
|
||||||
MsgHdr: dns.MsgHdr{
|
|
||||||
Id: dns.Id(),
|
|
||||||
RecursionDesired: true,
|
|
||||||
},
|
|
||||||
Question: []dns.Question{{
|
|
||||||
Name: host,
|
|
||||||
Qtype: dns.TypeA,
|
|
||||||
Qclass: dns.ClassINET,
|
|
||||||
}},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func createTestMessageWithType(host string, qtype uint16) *dns.Msg {
|
|
||||||
req := createTestMessage(host)
|
|
||||||
req.Question[0].Qtype = qtype
|
|
||||||
|
|
||||||
return req
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertGoogleAResponse(t *testing.T, reply *dns.Msg) {
|
|
||||||
assertResponse(t, reply, net.IP{8, 8, 8, 8})
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertResponse(t *testing.T, reply *dns.Msg, ip net.IP) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
if !assert.Lenf(t, reply.Answer, 1, "dns server returned reply with wrong number of answers - %d", len(reply.Answer)) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
a, ok := reply.Answer[0].(*dns.A)
|
|
||||||
if assert.Truef(t, ok, "dns server returned wrong answer type instead of A: %v", reply.Answer[0]) {
|
|
||||||
assert.Truef(t, a.A.Equal(ip), "dns server returned wrong answer instead of %s: %s", ip, a.A)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func publicKey(priv interface{}) interface{} {
|
func publicKey(priv interface{}) interface{} {
|
||||||
switch k := priv.(type) {
|
switch k := priv.(type) {
|
||||||
case *rsa.PrivateKey:
|
case *rsa.PrivateKey:
|
||||||
|
@ -966,8 +968,8 @@ func TestValidateUpstream(t *testing.T) {
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
defaultUpstream, err := validateUpstream(tc.upstream)
|
defaultUpstream, err := validateUpstream(tc.upstream)
|
||||||
assert.Equal(t, tc.valid, err == nil)
|
require.Equal(t, tc.valid, err == nil)
|
||||||
if err == nil {
|
if tc.valid {
|
||||||
assert.Equal(t, tc.wantDef, defaultUpstream)
|
assert.Equal(t, tc.wantDef, defaultUpstream)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -975,42 +977,73 @@ func TestValidateUpstream(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateUpstreamsSet(t *testing.T) {
|
func TestValidateUpstreamsSet(t *testing.T) {
|
||||||
// Empty upstreams array.
|
testCases := []struct {
|
||||||
var upstreamsSet []string
|
name string
|
||||||
assert.Nil(t, ValidateUpstreams(upstreamsSet), "empty upstreams array should be valid")
|
msg string
|
||||||
|
set []string
|
||||||
|
wantNil bool
|
||||||
|
}{{
|
||||||
|
name: "empty",
|
||||||
|
msg: "empty upstreams array should be valid",
|
||||||
|
set: nil,
|
||||||
|
wantNil: true,
|
||||||
|
}, {
|
||||||
|
name: "comment",
|
||||||
|
msg: "comments should not be validated",
|
||||||
|
set: []string{"# comment"},
|
||||||
|
wantNil: true,
|
||||||
|
}, {
|
||||||
|
name: "valid_no_default",
|
||||||
|
msg: "there is no default upstream",
|
||||||
|
set: []string{
|
||||||
|
"[/host.com/]1.1.1.1",
|
||||||
|
"[//]tls://1.1.1.1",
|
||||||
|
"[/www.host.com/]#",
|
||||||
|
"[/host.com/google.com/]8.8.8.8",
|
||||||
|
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||||
|
},
|
||||||
|
wantNil: false,
|
||||||
|
}, {
|
||||||
|
name: "valid_with_default",
|
||||||
|
msg: "upstreams set is valid, but doesn't pass through validation cause: %s",
|
||||||
|
set: []string{
|
||||||
|
"[/host.com/]1.1.1.1",
|
||||||
|
"[//]tls://1.1.1.1",
|
||||||
|
"[/www.host.com/]#",
|
||||||
|
"[/host.com/google.com/]8.8.8.8",
|
||||||
|
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||||
|
"8.8.8.8",
|
||||||
|
},
|
||||||
|
wantNil: true,
|
||||||
|
}, {
|
||||||
|
name: "invalid",
|
||||||
|
msg: "there is an invalid upstream in set, but it pass through validation",
|
||||||
|
set: []string{"dhcp://fake.dns"},
|
||||||
|
wantNil: false,
|
||||||
|
}}
|
||||||
|
|
||||||
// Comment in upstreams array.
|
for _, tc := range testCases {
|
||||||
upstreamsSet = []string{"# comment"}
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
assert.Nil(t, ValidateUpstreams(upstreamsSet), "comments should not be validated")
|
err := ValidateUpstreams(tc.set)
|
||||||
|
|
||||||
// Set of valid upstreams. There is no default upstream specified.
|
assert.Equalf(t, tc.wantNil, err == nil, tc.msg, err)
|
||||||
upstreamsSet = []string{
|
})
|
||||||
"[/host.com/]1.1.1.1",
|
|
||||||
"[//]tls://1.1.1.1",
|
|
||||||
"[/www.host.com/]#",
|
|
||||||
"[/host.com/google.com/]8.8.8.8",
|
|
||||||
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
|
||||||
}
|
}
|
||||||
assert.NotNil(t, ValidateUpstreams(upstreamsSet), "there is no default upstream")
|
|
||||||
|
|
||||||
// Let's add default upstream.
|
|
||||||
upstreamsSet = append(upstreamsSet, "8.8.8.8")
|
|
||||||
err := ValidateUpstreams(upstreamsSet)
|
|
||||||
assert.Nilf(t, err, "upstreams set is valid, but doesn't pass through validation cause: %s", err)
|
|
||||||
|
|
||||||
// Let's add invalid upstream.
|
|
||||||
upstreamsSet = append(upstreamsSet, "dhcp://fake.dns")
|
|
||||||
assert.NotNil(t, ValidateUpstreams(upstreamsSet), "there is an invalid upstream in set, but it pass through validation")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIPStringFromAddr(t *testing.T) {
|
func TestIPStringFromAddr(t *testing.T) {
|
||||||
addr := net.UDPAddr{
|
t.Run("not_nil", func(t *testing.T) {
|
||||||
IP: net.ParseIP("1:2:3::4"),
|
addr := net.UDPAddr{
|
||||||
Port: 12345,
|
IP: net.ParseIP("1:2:3::4"),
|
||||||
Zone: "eth0",
|
Port: 12345,
|
||||||
}
|
Zone: "eth0",
|
||||||
assert.Equal(t, IPStringFromAddr(&addr), addr.IP.String())
|
}
|
||||||
assert.Empty(t, IPStringFromAddr(nil))
|
assert.Equal(t, IPStringFromAddr(&addr), addr.IP.String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nil", func(t *testing.T) {
|
||||||
|
assert.Empty(t, IPStringFromAddr(nil))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMatchDNSName(t *testing.T) {
|
func TestMatchDNSName(t *testing.T) {
|
||||||
|
@ -1071,38 +1104,33 @@ func (d *testDHCP) Leases(flags int) []dhcpd.Lease {
|
||||||
func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {}
|
func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {}
|
||||||
|
|
||||||
func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
||||||
dhcp := &testDHCP{}
|
|
||||||
|
|
||||||
s := NewServer(DNSCreateParams{
|
s := NewServer(DNSCreateParams{
|
||||||
DNSFilter: dnsfilter.New(&dnsfilter.Config{}, nil),
|
DNSFilter: dnsfilter.New(&dnsfilter.Config{}, nil),
|
||||||
DHCPServer: dhcp,
|
DHCPServer: &testDHCP{},
|
||||||
})
|
})
|
||||||
|
|
||||||
s.conf.UDPListenAddr = &net.UDPAddr{}
|
s.conf.UDPListenAddr = &net.UDPAddr{}
|
||||||
s.conf.TCPListenAddr = &net.TCPAddr{}
|
s.conf.TCPListenAddr = &net.TCPAddr{}
|
||||||
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
|
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
|
||||||
s.conf.FilteringConfig.ProtectionEnabled = true
|
s.conf.FilteringConfig.ProtectionEnabled = true
|
||||||
err := s.Prepare(nil)
|
require.Nil(t, s.Prepare(nil))
|
||||||
assert.Nil(t, err)
|
require.Nil(t, s.Start())
|
||||||
|
t.Cleanup(func() {
|
||||||
|
s.Close()
|
||||||
|
})
|
||||||
|
|
||||||
assert.Nil(t, s.Start())
|
|
||||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||||
|
|
||||||
req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR)
|
req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR)
|
||||||
|
|
||||||
resp, err := dns.Exchange(req, addr.String())
|
resp, err := dns.Exchange(req, addr.String())
|
||||||
|
require.Nil(t, err)
|
||||||
assert.Nil(t, err)
|
require.Len(t, resp.Answer, 1)
|
||||||
assert.Len(t, resp.Answer, 1)
|
|
||||||
assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
|
assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
|
||||||
assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)
|
assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)
|
||||||
|
|
||||||
ptr, ok := resp.Answer[0].(*dns.PTR)
|
ptr, ok := resp.Answer[0].(*dns.PTR)
|
||||||
if assert.True(t, ok) {
|
require.True(t, ok)
|
||||||
assert.Equal(t, "localhost.", ptr.Ptr)
|
assert.Equal(t, "localhost.", ptr.Ptr)
|
||||||
}
|
|
||||||
|
|
||||||
s.Close()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPTRResponseFromHosts(t *testing.T) {
|
func TestPTRResponseFromHosts(t *testing.T) {
|
||||||
|
@ -1112,12 +1140,11 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
||||||
|
|
||||||
// Prepare test hosts file.
|
// Prepare test hosts file.
|
||||||
hf, err := ioutil.TempFile("", "")
|
hf, err := ioutil.TempFile("", "")
|
||||||
if assert.Nil(t, err) {
|
require.Nil(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
assert.Nil(t, hf.Close())
|
assert.Nil(t, hf.Close())
|
||||||
assert.Nil(t, os.Remove(hf.Name()))
|
assert.Nil(t, os.Remove(hf.Name()))
|
||||||
})
|
})
|
||||||
}
|
|
||||||
|
|
||||||
_, _ = hf.WriteString(" 127.0.0.1 host # comment \n")
|
_, _ = hf.WriteString(" 127.0.0.1 host # comment \n")
|
||||||
_, _ = hf.WriteString(" ::1 localhost#comment \n")
|
_, _ = hf.WriteString(" ::1 localhost#comment \n")
|
||||||
|
@ -1131,23 +1158,23 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
||||||
s.conf.TCPListenAddr = &net.TCPAddr{}
|
s.conf.TCPListenAddr = &net.TCPAddr{}
|
||||||
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
|
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
|
||||||
s.conf.FilteringConfig.ProtectionEnabled = true
|
s.conf.FilteringConfig.ProtectionEnabled = true
|
||||||
assert.Nil(t, s.Prepare(nil))
|
require.Nil(t, s.Prepare(nil))
|
||||||
|
|
||||||
assert.Nil(t, s.Start())
|
require.Nil(t, s.Start())
|
||||||
|
t.Cleanup(func() {
|
||||||
|
s.Close()
|
||||||
|
})
|
||||||
|
|
||||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||||
req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR)
|
req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR)
|
||||||
|
|
||||||
resp, err := dns.Exchange(req, addr.String())
|
resp, err := dns.Exchange(req, addr.String())
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Len(t, resp.Answer, 1)
|
require.Len(t, resp.Answer, 1)
|
||||||
assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
|
assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
|
||||||
assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)
|
assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)
|
||||||
|
|
||||||
ptr, ok := resp.Answer[0].(*dns.PTR)
|
ptr, ok := resp.Answer[0].(*dns.PTR)
|
||||||
if assert.True(t, ok) {
|
require.True(t, ok)
|
||||||
assert.Equal(t, "host.", ptr.Ptr)
|
assert.Equal(t, "host.", ptr.Ptr)
|
||||||
}
|
|
||||||
|
|
||||||
s.Close()
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"github.com/AdguardTeam/urlfilter/rules"
|
"github.com/AdguardTeam/urlfilter/rules"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestServer_FilterDNSRewrite(t *testing.T) {
|
func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||||
|
@ -54,7 +55,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||||
d := &proxy.DNSContext{}
|
d := &proxy.DNSContext{}
|
||||||
|
|
||||||
err := srv.filterDNSRewrite(req, res, d)
|
err := srv.filterDNSRewrite(req, res, d)
|
||||||
assert.Nil(t, err)
|
|
||||||
|
require.Nil(t, err)
|
||||||
assert.Equal(t, dns.RcodeNameError, d.Res.Rcode)
|
assert.Equal(t, dns.RcodeNameError, d.Res.Rcode)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -64,7 +66,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||||
d := &proxy.DNSContext{}
|
d := &proxy.DNSContext{}
|
||||||
|
|
||||||
err := srv.filterDNSRewrite(req, res, d)
|
err := srv.filterDNSRewrite(req, res, d)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||||
assert.Empty(t, d.Res.Answer)
|
assert.Empty(t, d.Res.Answer)
|
||||||
})
|
})
|
||||||
|
@ -75,11 +77,11 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||||
d := &proxy.DNSContext{}
|
d := &proxy.DNSContext{}
|
||||||
|
|
||||||
err := srv.filterDNSRewrite(req, res, d)
|
err := srv.filterDNSRewrite(req, res, d)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||||
if assert.Len(t, d.Res.Answer, 1) {
|
|
||||||
assert.Equal(t, ip4, d.Res.Answer[0].(*dns.A).A)
|
require.Len(t, d.Res.Answer, 1)
|
||||||
}
|
assert.Equal(t, ip4, d.Res.Answer[0].(*dns.A).A)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("noerror_aaaa", func(t *testing.T) {
|
t.Run("noerror_aaaa", func(t *testing.T) {
|
||||||
|
@ -88,11 +90,11 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||||
d := &proxy.DNSContext{}
|
d := &proxy.DNSContext{}
|
||||||
|
|
||||||
err := srv.filterDNSRewrite(req, res, d)
|
err := srv.filterDNSRewrite(req, res, d)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||||
if assert.Len(t, d.Res.Answer, 1) {
|
|
||||||
assert.Equal(t, ip6, d.Res.Answer[0].(*dns.AAAA).AAAA)
|
require.Len(t, d.Res.Answer, 1)
|
||||||
}
|
assert.Equal(t, ip6, d.Res.Answer[0].(*dns.AAAA).AAAA)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("noerror_ptr", func(t *testing.T) {
|
t.Run("noerror_ptr", func(t *testing.T) {
|
||||||
|
@ -101,11 +103,11 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||||
d := &proxy.DNSContext{}
|
d := &proxy.DNSContext{}
|
||||||
|
|
||||||
err := srv.filterDNSRewrite(req, res, d)
|
err := srv.filterDNSRewrite(req, res, d)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||||
if assert.Len(t, d.Res.Answer, 1) {
|
|
||||||
assert.Equal(t, domain, d.Res.Answer[0].(*dns.PTR).Ptr)
|
require.Len(t, d.Res.Answer, 1)
|
||||||
}
|
assert.Equal(t, domain, d.Res.Answer[0].(*dns.PTR).Ptr)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("noerror_txt", func(t *testing.T) {
|
t.Run("noerror_txt", func(t *testing.T) {
|
||||||
|
@ -114,11 +116,11 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||||
d := &proxy.DNSContext{}
|
d := &proxy.DNSContext{}
|
||||||
|
|
||||||
err := srv.filterDNSRewrite(req, res, d)
|
err := srv.filterDNSRewrite(req, res, d)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||||
if assert.Len(t, d.Res.Answer, 1) {
|
|
||||||
assert.Equal(t, []string{domain}, d.Res.Answer[0].(*dns.TXT).Txt)
|
require.Len(t, d.Res.Answer, 1)
|
||||||
}
|
assert.Equal(t, []string{domain}, d.Res.Answer[0].(*dns.TXT).Txt)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("noerror_mx", func(t *testing.T) {
|
t.Run("noerror_mx", func(t *testing.T) {
|
||||||
|
@ -127,15 +129,15 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||||
d := &proxy.DNSContext{}
|
d := &proxy.DNSContext{}
|
||||||
|
|
||||||
err := srv.filterDNSRewrite(req, res, d)
|
err := srv.filterDNSRewrite(req, res, d)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||||
if assert.Len(t, d.Res.Answer, 1) {
|
|
||||||
ans, ok := d.Res.Answer[0].(*dns.MX)
|
require.Len(t, d.Res.Answer, 1)
|
||||||
if assert.True(t, ok) {
|
ans, ok := d.Res.Answer[0].(*dns.MX)
|
||||||
assert.Equal(t, mx.Exchange, ans.Mx)
|
|
||||||
assert.Equal(t, mx.Preference, ans.Preference)
|
require.True(t, ok)
|
||||||
}
|
assert.Equal(t, mx.Exchange, ans.Mx)
|
||||||
}
|
assert.Equal(t, mx.Preference, ans.Preference)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("noerror_svcb", func(t *testing.T) {
|
t.Run("noerror_svcb", func(t *testing.T) {
|
||||||
|
@ -144,17 +146,17 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||||
d := &proxy.DNSContext{}
|
d := &proxy.DNSContext{}
|
||||||
|
|
||||||
err := srv.filterDNSRewrite(req, res, d)
|
err := srv.filterDNSRewrite(req, res, d)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||||
if assert.Len(t, d.Res.Answer, 1) {
|
|
||||||
ans, ok := d.Res.Answer[0].(*dns.SVCB)
|
require.Len(t, d.Res.Answer, 1)
|
||||||
if assert.True(t, ok) {
|
ans, ok := d.Res.Answer[0].(*dns.SVCB)
|
||||||
assert.Equal(t, dns.SVCB_ALPN, ans.Value[0].Key())
|
require.True(t, ok)
|
||||||
assert.Equal(t, svcb.Params["alpn"], ans.Value[0].String())
|
|
||||||
assert.Equal(t, svcb.Target, ans.Target)
|
assert.Equal(t, dns.SVCB_ALPN, ans.Value[0].Key())
|
||||||
assert.Equal(t, svcb.Priority, ans.Priority)
|
assert.Equal(t, svcb.Params["alpn"], ans.Value[0].String())
|
||||||
}
|
assert.Equal(t, svcb.Target, ans.Target)
|
||||||
}
|
assert.Equal(t, svcb.Priority, ans.Priority)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("noerror_https", func(t *testing.T) {
|
t.Run("noerror_https", func(t *testing.T) {
|
||||||
|
@ -163,16 +165,16 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||||
d := &proxy.DNSContext{}
|
d := &proxy.DNSContext{}
|
||||||
|
|
||||||
err := srv.filterDNSRewrite(req, res, d)
|
err := srv.filterDNSRewrite(req, res, d)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||||
if assert.Len(t, d.Res.Answer, 1) {
|
|
||||||
ans, ok := d.Res.Answer[0].(*dns.HTTPS)
|
require.Len(t, d.Res.Answer, 1)
|
||||||
if assert.True(t, ok) {
|
ans, ok := d.Res.Answer[0].(*dns.HTTPS)
|
||||||
assert.Equal(t, dns.SVCB_ALPN, ans.Value[0].Key())
|
|
||||||
assert.Equal(t, svcb.Params["alpn"], ans.Value[0].String())
|
require.True(t, ok)
|
||||||
assert.Equal(t, svcb.Target, ans.Target)
|
assert.Equal(t, dns.SVCB_ALPN, ans.Value[0].Key())
|
||||||
assert.Equal(t, svcb.Priority, ans.Priority)
|
assert.Equal(t, svcb.Params["alpn"], ans.Value[0].String())
|
||||||
}
|
assert.Equal(t, svcb.Target, ans.Target)
|
||||||
}
|
assert.Equal(t, svcb.Priority, ans.Priority)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
|
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) {
|
func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) {
|
||||||
|
@ -31,9 +32,10 @@ func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) {
|
||||||
ConfigModified: func() {},
|
ConfigModified: func() {},
|
||||||
}
|
}
|
||||||
s := createTestServer(t, filterConf, forwardConf)
|
s := createTestServer(t, filterConf, forwardConf)
|
||||||
err := s.Start()
|
require.Nil(t, s.Start())
|
||||||
assert.Nil(t, err)
|
t.Cleanup(func() {
|
||||||
defer assert.Nil(t, s.Stop())
|
require.Nil(t, s.Stop())
|
||||||
|
})
|
||||||
|
|
||||||
defaultConf := s.conf
|
defaultConf := s.conf
|
||||||
|
|
||||||
|
@ -71,13 +73,14 @@ func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) {
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Cleanup(w.Body.Reset)
|
||||||
|
|
||||||
s.conf = tc.conf()
|
s.conf = tc.conf()
|
||||||
s.handleGetConfig(w, nil)
|
s.handleGetConfig(w, nil)
|
||||||
assert.Equal(t, tc.want, w.Body.String())
|
|
||||||
|
|
||||||
assert.Equal(t, "application/json", w.Header().Get("Content-Type"))
|
assert.Equal(t, "application/json", w.Header().Get("Content-Type"))
|
||||||
|
assert.Equal(t, tc.want, w.Body.String())
|
||||||
})
|
})
|
||||||
w.Body.Reset()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -191,9 +194,13 @@ func TestDNSForwardHTTTP_handleSetConfig(t *testing.T) {
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Cleanup(func() {
|
||||||
|
s.conf = defaultConf
|
||||||
|
})
|
||||||
|
|
||||||
rBody := ioutil.NopCloser(strings.NewReader(tc.req))
|
rBody := ioutil.NopCloser(strings.NewReader(tc.req))
|
||||||
r, err := http.NewRequest(http.MethodPost, "http://example.com", rBody)
|
r, err := http.NewRequest(http.MethodPost, "http://example.com", rBody)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
s.handleSetConfig(w, r)
|
s.handleSetConfig(w, r)
|
||||||
assert.Equal(t, tc.wantSet, w.Body.String())
|
assert.Equal(t, tc.wantSet, w.Body.String())
|
||||||
|
@ -203,6 +210,5 @@ func TestDNSForwardHTTTP_handleSetConfig(t *testing.T) {
|
||||||
assert.Equal(t, tc.wantGet, w.Body.String())
|
assert.Equal(t, tc.wantGet, w.Body.String())
|
||||||
w.Body.Reset()
|
w.Body.Reset()
|
||||||
})
|
})
|
||||||
s.conf = defaultConf
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
// testQueryLog is a simple querylog.QueryLog implementation for tests.
|
// testQueryLog is a simple querylog.QueryLog implementation for tests.
|
||||||
|
@ -156,7 +157,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||||
}}
|
}}
|
||||||
|
|
||||||
ups, err := upstream.AddressToUpstream("1.1.1.1", upstream.Options{})
|
ups, err := upstream.AddressToUpstream("1.1.1.1", upstream.Options{})
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
|
|
@ -20,10 +20,18 @@ func TestMain(m *testing.M) {
|
||||||
aghtest.DiscardLogOutput(m)
|
aghtest.DiscardLogOutput(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
func prepareTestDir() string {
|
func prepareTestDir(t *testing.T) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
const dir = "./agh-test"
|
const dir = "./agh-test"
|
||||||
_ = os.RemoveAll(dir)
|
|
||||||
_ = os.MkdirAll(dir, 0o755)
|
require.Nil(t, os.RemoveAll(dir))
|
||||||
|
// TODO(e.burkov): Replace with testing.TempDir after updating Go
|
||||||
|
// version to 1.16.
|
||||||
|
require.Nil(t, os.MkdirAll(dir, 0o755))
|
||||||
|
|
||||||
|
t.Cleanup(func() { require.Nil(t, os.RemoveAll(dir)) })
|
||||||
|
|
||||||
return dir
|
return dir
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -47,8 +55,7 @@ func TestNewSessionToken(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuth(t *testing.T) {
|
func TestAuth(t *testing.T) {
|
||||||
dir := prepareTestDir()
|
dir := prepareTestDir(t)
|
||||||
t.Cleanup(func() { _ = os.RemoveAll(dir) })
|
|
||||||
fn := filepath.Join(dir, "sessions.db")
|
fn := filepath.Join(dir, "sessions.db")
|
||||||
|
|
||||||
users := []User{{
|
users := []User{{
|
||||||
|
@ -123,8 +130,7 @@ func (w *testResponseWriter) WriteHeader(statusCode int) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthHTTP(t *testing.T) {
|
func TestAuthHTTP(t *testing.T) {
|
||||||
dir := prepareTestDir()
|
dir := prepareTestDir(t)
|
||||||
defer func() { _ = os.RemoveAll(dir) }()
|
|
||||||
fn := filepath.Join(dir, "sessions.db")
|
fn := filepath.Join(dir, "sessions.db")
|
||||||
|
|
||||||
users := []User{
|
users := []User{
|
||||||
|
|
|
@ -4,40 +4,38 @@ import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAuthGL(t *testing.T) {
|
func TestAuthGL(t *testing.T) {
|
||||||
dir := prepareTestDir()
|
dir := prepareTestDir(t)
|
||||||
defer func() { _ = os.RemoveAll(dir) }()
|
|
||||||
|
|
||||||
GLMode = true
|
GLMode = true
|
||||||
|
t.Cleanup(func() {
|
||||||
|
GLMode = false
|
||||||
|
})
|
||||||
glFilePrefix = dir + "/gl_token_"
|
glFilePrefix = dir + "/gl_token_"
|
||||||
|
|
||||||
tval := uint32(1)
|
putFunc := binary.BigEndian.PutUint32
|
||||||
data := make([]byte, 4)
|
|
||||||
if archIsLittleEndian() {
|
if archIsLittleEndian() {
|
||||||
binary.LittleEndian.PutUint32(data, tval)
|
putFunc = binary.LittleEndian.PutUint32
|
||||||
} else {
|
|
||||||
binary.BigEndian.PutUint32(data, tval)
|
|
||||||
}
|
}
|
||||||
assert.Nil(t, ioutil.WriteFile(glFilePrefix+"test", data, 0o644))
|
|
||||||
|
data := make([]byte, 4)
|
||||||
|
putFunc(data, 1)
|
||||||
|
|
||||||
|
require.Nil(t, ioutil.WriteFile(glFilePrefix+"test", data, 0o644))
|
||||||
assert.False(t, glCheckToken("test"))
|
assert.False(t, glCheckToken("test"))
|
||||||
|
|
||||||
tval = uint32(time.Now().UTC().Unix() + 60)
|
|
||||||
data = make([]byte, 4)
|
data = make([]byte, 4)
|
||||||
if archIsLittleEndian() {
|
putFunc(data, uint32(time.Now().UTC().Unix()+60))
|
||||||
binary.LittleEndian.PutUint32(data, tval)
|
|
||||||
} else {
|
require.Nil(t, ioutil.WriteFile(glFilePrefix+"test", data, 0o644))
|
||||||
binary.BigEndian.PutUint32(data, tval)
|
|
||||||
}
|
|
||||||
assert.Nil(t, ioutil.WriteFile(glFilePrefix+"test", data, 0o644))
|
|
||||||
r, _ := http.NewRequest(http.MethodGet, "http://localhost/", nil)
|
r, _ := http.NewRequest(http.MethodGet, "http://localhost/", nil)
|
||||||
r.AddCookie(&http.Cookie{Name: glCookieName, Value: "test"})
|
r.AddCookie(&http.Cookie{Name: glCookieName, Value: "test"})
|
||||||
assert.True(t, glProcessCookie(r))
|
assert.True(t, glProcessCookie(r))
|
||||||
GLMode = false
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestClients(t *testing.T) {
|
func TestClients(t *testing.T) {
|
||||||
|
@ -24,8 +25,8 @@ func TestClients(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
ok, err := clients.Add(c)
|
ok, err := clients.Add(c)
|
||||||
|
require.Nil(t, err)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
assert.Nil(t, err)
|
|
||||||
|
|
||||||
c = &Client{
|
c = &Client{
|
||||||
IDs: []string{"2.2.2.2"},
|
IDs: []string{"2.2.2.2"},
|
||||||
|
@ -33,110 +34,99 @@ func TestClients(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
ok, err = clients.Add(c)
|
ok, err = clients.Add(c)
|
||||||
|
require.Nil(t, err)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
assert.Nil(t, err)
|
|
||||||
|
|
||||||
c, ok = clients.Find("1.1.1.1")
|
c, ok = clients.Find("1.1.1.1")
|
||||||
assert.True(t, ok)
|
require.True(t, ok)
|
||||||
assert.Equal(t, "client1", c.Name)
|
assert.Equal(t, "client1", c.Name)
|
||||||
|
|
||||||
c, ok = clients.Find("1:2:3::4")
|
c, ok = clients.Find("1:2:3::4")
|
||||||
assert.True(t, ok)
|
require.True(t, ok)
|
||||||
assert.Equal(t, "client1", c.Name)
|
assert.Equal(t, "client1", c.Name)
|
||||||
|
|
||||||
c, ok = clients.Find("2.2.2.2")
|
c, ok = clients.Find("2.2.2.2")
|
||||||
assert.True(t, ok)
|
require.True(t, ok)
|
||||||
assert.Equal(t, "client2", c.Name)
|
assert.Equal(t, "client2", c.Name)
|
||||||
|
|
||||||
assert.True(t, !clients.Exists("1.2.3.4", ClientSourceHostsFile))
|
assert.False(t, clients.Exists("1.2.3.4", ClientSourceHostsFile))
|
||||||
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
|
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
|
||||||
assert.True(t, clients.Exists("2.2.2.2", ClientSourceHostsFile))
|
assert.True(t, clients.Exists("2.2.2.2", ClientSourceHostsFile))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("add_fail_name", func(t *testing.T) {
|
t.Run("add_fail_name", func(t *testing.T) {
|
||||||
c := &Client{
|
ok, err := clients.Add(&Client{
|
||||||
IDs: []string{"1.2.3.5"},
|
IDs: []string{"1.2.3.5"},
|
||||||
Name: "client1",
|
Name: "client1",
|
||||||
}
|
})
|
||||||
|
require.Nil(t, err)
|
||||||
ok, err := clients.Add(c)
|
|
||||||
assert.False(t, ok)
|
assert.False(t, ok)
|
||||||
assert.Nil(t, err)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("add_fail_ip", func(t *testing.T) {
|
t.Run("add_fail_ip", func(t *testing.T) {
|
||||||
c := &Client{
|
ok, err := clients.Add(&Client{
|
||||||
IDs: []string{"2.2.2.2"},
|
IDs: []string{"2.2.2.2"},
|
||||||
Name: "client3",
|
Name: "client3",
|
||||||
}
|
})
|
||||||
|
require.NotNil(t, err)
|
||||||
ok, err := clients.Add(c)
|
|
||||||
assert.False(t, ok)
|
assert.False(t, ok)
|
||||||
assert.NotNil(t, err)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("update_fail_name", func(t *testing.T) {
|
t.Run("update_fail_name", func(t *testing.T) {
|
||||||
c := &Client{
|
err := clients.Update("client3", &Client{
|
||||||
IDs: []string{"1.2.3.0"},
|
IDs: []string{"1.2.3.0"},
|
||||||
Name: "client3",
|
Name: "client3",
|
||||||
}
|
})
|
||||||
|
require.NotNil(t, err)
|
||||||
|
|
||||||
err := clients.Update("client3", c)
|
err = clients.Update("client3", &Client{
|
||||||
assert.NotNil(t, err)
|
|
||||||
|
|
||||||
c = &Client{
|
|
||||||
IDs: []string{"1.2.3.0"},
|
IDs: []string{"1.2.3.0"},
|
||||||
Name: "client2",
|
Name: "client2",
|
||||||
}
|
})
|
||||||
|
|
||||||
err = clients.Update("client3", c)
|
|
||||||
assert.NotNil(t, err)
|
assert.NotNil(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("update_fail_ip", func(t *testing.T) {
|
t.Run("update_fail_ip", func(t *testing.T) {
|
||||||
c := &Client{
|
err := clients.Update("client1", &Client{
|
||||||
IDs: []string{"2.2.2.2"},
|
IDs: []string{"2.2.2.2"},
|
||||||
Name: "client1",
|
Name: "client1",
|
||||||
}
|
})
|
||||||
|
|
||||||
err := clients.Update("client1", c)
|
|
||||||
assert.NotNil(t, err)
|
assert.NotNil(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("update_success", func(t *testing.T) {
|
t.Run("update_success", func(t *testing.T) {
|
||||||
c := &Client{
|
err := clients.Update("client1", &Client{
|
||||||
IDs: []string{"1.1.1.2"},
|
IDs: []string{"1.1.1.2"},
|
||||||
Name: "client1",
|
Name: "client1",
|
||||||
}
|
})
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
err := clients.Update("client1", c)
|
assert.False(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
|
||||||
assert.Nil(t, err)
|
|
||||||
|
|
||||||
assert.True(t, !clients.Exists("1.1.1.1", ClientSourceHostsFile))
|
|
||||||
assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
|
assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
|
||||||
|
|
||||||
c = &Client{
|
err = clients.Update("client1", &Client{
|
||||||
IDs: []string{"1.1.1.2"},
|
IDs: []string{"1.1.1.2"},
|
||||||
Name: "client1-renamed",
|
Name: "client1-renamed",
|
||||||
UseOwnSettings: true,
|
UseOwnSettings: true,
|
||||||
}
|
})
|
||||||
|
require.Nil(t, err)
|
||||||
err = clients.Update("client1", c)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
|
|
||||||
c, ok := clients.Find("1.1.1.2")
|
c, ok := clients.Find("1.1.1.2")
|
||||||
assert.True(t, ok)
|
require.True(t, ok)
|
||||||
assert.Equal(t, "client1-renamed", c.Name)
|
assert.Equal(t, "client1-renamed", c.Name)
|
||||||
assert.True(t, c.UseOwnSettings)
|
assert.True(t, c.UseOwnSettings)
|
||||||
assert.Nil(t, clients.list["client1"])
|
|
||||||
if assert.Len(t, c.IDs, 1) {
|
nilCli, ok := clients.list["client1"]
|
||||||
assert.Equal(t, "1.1.1.2", c.IDs[0])
|
require.False(t, ok)
|
||||||
}
|
assert.Nil(t, nilCli)
|
||||||
|
|
||||||
|
require.Len(t, c.IDs, 1)
|
||||||
|
assert.Equal(t, "1.1.1.2", c.IDs[0])
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("del_success", func(t *testing.T) {
|
t.Run("del_success", func(t *testing.T) {
|
||||||
ok := clients.Del("client1-renamed")
|
ok := clients.Del("client1-renamed")
|
||||||
assert.True(t, ok)
|
require.True(t, ok)
|
||||||
assert.False(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
|
assert.False(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -147,146 +137,155 @@ func TestClients(t *testing.T) {
|
||||||
|
|
||||||
t.Run("addhost_success", func(t *testing.T) {
|
t.Run("addhost_success", func(t *testing.T) {
|
||||||
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceARP)
|
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceARP)
|
||||||
|
require.Nil(t, err)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
assert.Nil(t, err)
|
|
||||||
|
|
||||||
ok, err = clients.AddHost("1.1.1.1", "host2", ClientSourceARP)
|
ok, err = clients.AddHost("1.1.1.1", "host2", ClientSourceARP)
|
||||||
|
require.Nil(t, err)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
assert.Nil(t, err)
|
|
||||||
|
|
||||||
ok, err = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile)
|
ok, err = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile)
|
||||||
|
require.Nil(t, err)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
assert.Nil(t, err)
|
|
||||||
|
|
||||||
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
|
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("addhost_fail", func(t *testing.T) {
|
t.Run("addhost_fail", func(t *testing.T) {
|
||||||
ok, err := clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS)
|
ok, err := clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS)
|
||||||
|
require.Nil(t, err)
|
||||||
assert.False(t, ok)
|
assert.False(t, ok)
|
||||||
assert.Nil(t, err)
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClientsWhois(t *testing.T) {
|
func TestClientsWhois(t *testing.T) {
|
||||||
var c *Client
|
clients := clientsContainer{
|
||||||
clients := clientsContainer{}
|
testing: true,
|
||||||
clients.testing = true
|
}
|
||||||
clients.Init(nil, nil, nil)
|
clients.Init(nil, nil, nil)
|
||||||
|
|
||||||
whois := [][]string{{"orgname", "orgname-val"}, {"country", "country-val"}}
|
whois := [][]string{{"orgname", "orgname-val"}, {"country", "country-val"}}
|
||||||
// set whois info on new client
|
|
||||||
clients.SetWhoisInfo("1.1.1.255", whois)
|
t.Run("new_client", func(t *testing.T) {
|
||||||
if assert.NotNil(t, clients.ipHost["1.1.1.255"]) {
|
clients.SetWhoisInfo("1.1.1.255", whois)
|
||||||
|
|
||||||
|
require.NotNil(t, clients.ipHost["1.1.1.255"])
|
||||||
h := clients.ipHost["1.1.1.255"]
|
h := clients.ipHost["1.1.1.255"]
|
||||||
if assert.Len(t, h.WhoisInfo, 2) && assert.Len(t, h.WhoisInfo[0], 2) {
|
|
||||||
assert.Equal(t, "orgname-val", h.WhoisInfo[0][1])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// set whois info on existing auto-client
|
require.Len(t, h.WhoisInfo, 2)
|
||||||
_, _ = clients.AddHost("1.1.1.1", "host", ClientSourceRDNS)
|
require.Len(t, h.WhoisInfo[0], 2)
|
||||||
clients.SetWhoisInfo("1.1.1.1", whois)
|
assert.Equal(t, "orgname-val", h.WhoisInfo[0][1])
|
||||||
if assert.NotNil(t, clients.ipHost["1.1.1.1"]) {
|
})
|
||||||
|
|
||||||
|
t.Run("existing_auto-client", func(t *testing.T) {
|
||||||
|
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceRDNS)
|
||||||
|
require.Nil(t, err)
|
||||||
|
assert.True(t, ok)
|
||||||
|
|
||||||
|
clients.SetWhoisInfo("1.1.1.1", whois)
|
||||||
|
|
||||||
|
require.NotNil(t, clients.ipHost["1.1.1.1"])
|
||||||
h := clients.ipHost["1.1.1.1"]
|
h := clients.ipHost["1.1.1.1"]
|
||||||
if assert.Len(t, h.WhoisInfo, 2) && assert.Len(t, h.WhoisInfo[0], 2) {
|
|
||||||
assert.Equal(t, "orgname-val", h.WhoisInfo[0][1])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that we cannot set whois info on a manually-added client
|
require.Len(t, h.WhoisInfo, 2)
|
||||||
c = &Client{
|
require.Len(t, h.WhoisInfo[0], 2)
|
||||||
IDs: []string{"1.1.1.2"},
|
assert.Equal(t, "orgname-val", h.WhoisInfo[0][1])
|
||||||
Name: "client1",
|
})
|
||||||
}
|
|
||||||
_, _ = clients.Add(c)
|
t.Run("can't_set_manually-added", func(t *testing.T) {
|
||||||
clients.SetWhoisInfo("1.1.1.2", whois)
|
ok, err := clients.Add(&Client{
|
||||||
assert.Nil(t, clients.ipHost["1.1.1.2"])
|
IDs: []string{"1.1.1.2"},
|
||||||
_ = clients.Del("client1")
|
Name: "client1",
|
||||||
|
})
|
||||||
|
require.Nil(t, err)
|
||||||
|
assert.True(t, ok)
|
||||||
|
|
||||||
|
clients.SetWhoisInfo("1.1.1.2", whois)
|
||||||
|
require.Nil(t, clients.ipHost["1.1.1.2"])
|
||||||
|
assert.True(t, clients.Del("client1"))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClientsAddExisting(t *testing.T) {
|
func TestClientsAddExisting(t *testing.T) {
|
||||||
var c *Client
|
clients := clientsContainer{
|
||||||
clients := clientsContainer{}
|
testing: true,
|
||||||
clients.testing = true
|
}
|
||||||
clients.Init(nil, nil, nil)
|
clients.Init(nil, nil, nil)
|
||||||
|
|
||||||
// some test variables
|
t.Run("simple", func(t *testing.T) {
|
||||||
mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
// Add a client.
|
||||||
testIP := "1.2.3.4"
|
ok, err := clients.Add(&Client{
|
||||||
|
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"},
|
||||||
|
Name: "client1",
|
||||||
|
})
|
||||||
|
require.Nil(t, err)
|
||||||
|
assert.True(t, ok)
|
||||||
|
|
||||||
// add a client
|
// Now add an auto-client with the same IP.
|
||||||
c = &Client{
|
ok, err = clients.AddHost("1.1.1.1", "test", ClientSourceRDNS)
|
||||||
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"},
|
require.Nil(t, err)
|
||||||
Name: "client1",
|
assert.True(t, ok)
|
||||||
}
|
|
||||||
ok, err := clients.Add(c)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
|
|
||||||
// add an auto-client with the same IP - it's allowed
|
|
||||||
ok, err = clients.AddHost("1.1.1.1", "test", ClientSourceRDNS)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
|
|
||||||
// now some more complicated stuff
|
|
||||||
// first, init a DHCP server with a single static lease
|
|
||||||
config := dhcpd.ServerConfig{
|
|
||||||
DBFilePath: "leases.db",
|
|
||||||
}
|
|
||||||
defer func() { _ = os.Remove("leases.db") }()
|
|
||||||
clients.dhcpServer = dhcpd.Create(config)
|
|
||||||
err = clients.dhcpServer.AddStaticLease(dhcpd.Lease{
|
|
||||||
HWAddr: mac,
|
|
||||||
IP: net.ParseIP(testIP).To4(),
|
|
||||||
Hostname: "testhost",
|
|
||||||
Expiry: time.Now().Add(time.Hour),
|
|
||||||
})
|
})
|
||||||
assert.Nil(t, err)
|
|
||||||
|
|
||||||
// add a new client with the same IP as for a client with MAC
|
t.Run("complicated", func(t *testing.T) {
|
||||||
c = &Client{
|
testIP := net.IP{1, 2, 3, 4}
|
||||||
IDs: []string{testIP},
|
|
||||||
Name: "client2",
|
|
||||||
}
|
|
||||||
ok, err = clients.Add(c)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
|
|
||||||
// add a new client with the IP from the client1's IP range
|
// First, init a DHCP server with a single static lease.
|
||||||
c = &Client{
|
config := dhcpd.ServerConfig{
|
||||||
IDs: []string{"2.2.2.2"},
|
DBFilePath: "leases.db",
|
||||||
Name: "client3",
|
}
|
||||||
}
|
clients.dhcpServer = dhcpd.Create(config)
|
||||||
ok, err = clients.Add(c)
|
t.Cleanup(func() { _ = os.Remove("leases.db") })
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Nil(t, err)
|
err := clients.dhcpServer.AddStaticLease(dhcpd.Lease{
|
||||||
|
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||||
|
IP: testIP,
|
||||||
|
Hostname: "testhost",
|
||||||
|
Expiry: time.Now().Add(time.Hour),
|
||||||
|
})
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
// Add a new client with the same IP as for a client with MAC.
|
||||||
|
ok, err := clients.Add(&Client{
|
||||||
|
IDs: []string{testIP.String()},
|
||||||
|
Name: "client2",
|
||||||
|
})
|
||||||
|
require.Nil(t, err)
|
||||||
|
assert.True(t, ok)
|
||||||
|
|
||||||
|
// Add a new client with the IP from the first client's IP
|
||||||
|
// range.
|
||||||
|
ok, err = clients.Add(&Client{
|
||||||
|
IDs: []string{"2.2.2.2"},
|
||||||
|
Name: "client3",
|
||||||
|
})
|
||||||
|
require.Nil(t, err)
|
||||||
|
assert.True(t, ok)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClientsCustomUpstream(t *testing.T) {
|
func TestClientsCustomUpstream(t *testing.T) {
|
||||||
clients := clientsContainer{}
|
clients := clientsContainer{
|
||||||
clients.testing = true
|
testing: true,
|
||||||
|
}
|
||||||
clients.Init(nil, nil, nil)
|
clients.Init(nil, nil, nil)
|
||||||
|
|
||||||
// add client with upstreams
|
// Add client with upstreams.
|
||||||
c := &Client{
|
ok, err := clients.Add(&Client{
|
||||||
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
|
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
|
||||||
Name: "client1",
|
Name: "client1",
|
||||||
Upstreams: []string{
|
Upstreams: []string{
|
||||||
"1.1.1.1",
|
"1.1.1.1",
|
||||||
"[/example.org/]8.8.8.8",
|
"[/example.org/]8.8.8.8",
|
||||||
},
|
},
|
||||||
}
|
})
|
||||||
ok, err := clients.Add(c)
|
require.Nil(t, err)
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
config := clients.FindUpstreams("1.2.3.4")
|
config := clients.FindUpstreams("1.2.3.4")
|
||||||
assert.Nil(t, config)
|
assert.Nil(t, config)
|
||||||
|
|
||||||
config = clients.FindUpstreams("1.1.1.1")
|
config = clients.FindUpstreams("1.1.1.1")
|
||||||
assert.NotNil(t, config)
|
require.NotNil(t, config)
|
||||||
assert.Equal(t, 1, len(config.Upstreams))
|
assert.Len(t, config.Upstreams, 1)
|
||||||
assert.Equal(t, 1, len(config.DomainReservedUpstreams))
|
assert.Len(t, config.DomainReservedUpstreams, 1)
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,32 +3,12 @@ package home
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
/* Tests performed:
|
const (
|
||||||
. Bad certificate
|
CertificateChain = `-----BEGIN CERTIFICATE-----
|
||||||
. Bad private key
|
|
||||||
. Valid certificate & private key */
|
|
||||||
func TestValidateCertificates(t *testing.T) {
|
|
||||||
var data tlsConfigStatus
|
|
||||||
|
|
||||||
// bad cert
|
|
||||||
data = validateCertificates("bad cert", "", "")
|
|
||||||
if !(data.WarningValidation != "" &&
|
|
||||||
!data.ValidCert &&
|
|
||||||
!data.ValidChain) {
|
|
||||||
t.Fatalf("bad cert: validateCertificates(): %v", data)
|
|
||||||
}
|
|
||||||
|
|
||||||
// bad priv key
|
|
||||||
data = validateCertificates("", "bad priv key", "")
|
|
||||||
if !(data.WarningValidation != "" &&
|
|
||||||
!data.ValidKey) {
|
|
||||||
t.Fatalf("bad priv key: validateCertificates(): %v", data)
|
|
||||||
}
|
|
||||||
|
|
||||||
// valid cert & priv key
|
|
||||||
CertificateChain := `-----BEGIN CERTIFICATE-----
|
|
||||||
MIICKzCCAZSgAwIBAgIJAMT9kPVJdM7LMA0GCSqGSIb3DQEBCwUAMC0xFDASBgNV
|
MIICKzCCAZSgAwIBAgIJAMT9kPVJdM7LMA0GCSqGSIb3DQEBCwUAMC0xFDASBgNV
|
||||||
BAoMC0FkR3VhcmQgTHRkMRUwEwYDVQQDDAxBZEd1YXJkIEhvbWUwHhcNMTkwMjI3
|
BAoMC0FkR3VhcmQgTHRkMRUwEwYDVQQDDAxBZEd1YXJkIEhvbWUwHhcNMTkwMjI3
|
||||||
MDkyNDIzWhcNNDYwNzE0MDkyNDIzWjAtMRQwEgYDVQQKDAtBZEd1YXJkIEx0ZDEV
|
MDkyNDIzWhcNNDYwNzE0MDkyNDIzWjAtMRQwEgYDVQQKDAtBZEd1YXJkIEx0ZDEV
|
||||||
|
@ -42,7 +22,7 @@ LwlXfbakf7qkVTlCNXgoY7RaJ8rJdPgOZPoCTVToEhT6u/cb1c2qp8QB0dNExDna
|
||||||
b0Z+dnODTZqQOJo6z/wIXlcUrnR4cQVvytXt8lFn+26l6Y6EMI26twC/xWr+1swq
|
b0Z+dnODTZqQOJo6z/wIXlcUrnR4cQVvytXt8lFn+26l6Y6EMI26twC/xWr+1swq
|
||||||
Muj4FeWHVDerquH4yMr1jsYLD3ci+kc5sbIX6TfVxQ==
|
Muj4FeWHVDerquH4yMr1jsYLD3ci+kc5sbIX6TfVxQ==
|
||||||
-----END CERTIFICATE-----`
|
-----END CERTIFICATE-----`
|
||||||
PrivateKey := `-----BEGIN PRIVATE KEY-----
|
PrivateKey = `-----BEGIN PRIVATE KEY-----
|
||||||
MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBALC/BSc8mI68tw5p
|
MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBALC/BSc8mI68tw5p
|
||||||
aYa7pjrySwWvXeetcFywOWHGVfLw9qiFWLdfESa3Y6tWMpZAXD9t1Xh9n211YUBV
|
aYa7pjrySwWvXeetcFywOWHGVfLw9qiFWLdfESa3Y6tWMpZAXD9t1Xh9n211YUBV
|
||||||
FGSB4ZshnM/tgEPU6t787lJD4NsIIRp++MkJxdAitN4oUTqL0bdpIwezQ/CrYuBX
|
FGSB4ZshnM/tgEPU6t787lJD4NsIIRp++MkJxdAitN4oUTqL0bdpIwezQ/CrYuBX
|
||||||
|
@ -58,20 +38,35 @@ O5EX70gpeGQMPDK0QSWpaazg956njJSDbNCFM4BccrdQbJu1cW4qOsfBAkAMgZuG
|
||||||
O88slmgTRHX4JGFmy3rrLiHNI2BbJSuJ++Yllz8beVzh6NfvuY+HKRCmPqoBPATU
|
O88slmgTRHX4JGFmy3rrLiHNI2BbJSuJ++Yllz8beVzh6NfvuY+HKRCmPqoBPATU
|
||||||
kXS9jgARhhiWXJrk
|
kXS9jgARhhiWXJrk
|
||||||
-----END PRIVATE KEY-----`
|
-----END PRIVATE KEY-----`
|
||||||
data = validateCertificates(CertificateChain, PrivateKey, "")
|
)
|
||||||
notBefore, _ := time.Parse(time.RFC3339, "2019-02-27T09:24:23Z")
|
|
||||||
notAfter, _ := time.Parse(time.RFC3339, "2046-07-14T09:24:23Z")
|
func TestValidateCertificates(t *testing.T) {
|
||||||
if !(data.WarningValidation != "" /* self signed */ &&
|
t.Run("bad_certificate", func(t *testing.T) {
|
||||||
data.ValidCert &&
|
data := validateCertificates("bad cert", "", "")
|
||||||
!data.ValidChain &&
|
assert.NotEmpty(t, data.WarningValidation)
|
||||||
data.ValidKey &&
|
assert.False(t, data.ValidCert)
|
||||||
data.KeyType == "RSA" &&
|
assert.False(t, data.ValidChain)
|
||||||
data.Subject == "CN=AdGuard Home,O=AdGuard Ltd" &&
|
})
|
||||||
data.Issuer == "CN=AdGuard Home,O=AdGuard Ltd" &&
|
|
||||||
data.NotBefore.Equal(notBefore) &&
|
t.Run("bad_private_key", func(t *testing.T) {
|
||||||
data.NotAfter.Equal(notAfter) &&
|
data := validateCertificates("", "bad priv key", "")
|
||||||
// data.DNSNames[0] == &&
|
assert.NotEmpty(t, data.WarningValidation)
|
||||||
data.ValidPair) {
|
assert.False(t, data.ValidKey)
|
||||||
t.Fatalf("valid cert & priv key: validateCertificates(): %v", data)
|
})
|
||||||
}
|
|
||||||
|
t.Run("valid", func(t *testing.T) {
|
||||||
|
data := validateCertificates(CertificateChain, PrivateKey, "")
|
||||||
|
notBefore, _ := time.Parse(time.RFC3339, "2019-02-27T09:24:23Z")
|
||||||
|
notAfter, _ := time.Parse(time.RFC3339, "2046-07-14T09:24:23Z")
|
||||||
|
assert.NotEmpty(t, data.WarningValidation)
|
||||||
|
assert.True(t, data.ValidCert)
|
||||||
|
assert.False(t, data.ValidChain)
|
||||||
|
assert.True(t, data.ValidKey)
|
||||||
|
assert.Equal(t, "RSA", data.KeyType)
|
||||||
|
assert.Equal(t, "CN=AdGuard Home,O=AdGuard Ltd", data.Subject)
|
||||||
|
assert.Equal(t, "CN=AdGuard Home,O=AdGuard Ltd", data.Issuer)
|
||||||
|
assert.Equal(t, notBefore, data.NotBefore)
|
||||||
|
assert.Equal(t, notAfter, data.NotAfter)
|
||||||
|
assert.True(t, data.ValidPair)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,38 +9,47 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func testStartFilterListener() net.Listener {
|
func testStartFilterListener(t *testing.T) net.Listener {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
const content = `||example.org^$third-party
|
||||||
|
# Inline comment example
|
||||||
|
||example.com^$third-party
|
||||||
|
0.0.0.0 example.com
|
||||||
|
`
|
||||||
|
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
mux.HandleFunc("/filters/1.txt", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/filters/1.txt", func(w http.ResponseWriter, r *http.Request) {
|
||||||
content := `||example.org^$third-party
|
_, werr := w.Write([]byte(content))
|
||||||
# Inline comment example
|
assert.Nil(t, werr)
|
||||||
||example.com^$third-party
|
|
||||||
0.0.0.0 example.com
|
|
||||||
`
|
|
||||||
_, _ = w.Write([]byte(content))
|
|
||||||
})
|
})
|
||||||
|
|
||||||
listener, err := net.Listen("tcp", ":0")
|
listener, err := net.Listen("tcp", ":0")
|
||||||
if err != nil {
|
require.Nil(t, err)
|
||||||
panic(err)
|
|
||||||
}
|
go func() {
|
||||||
|
_ = http.Serve(listener, mux)
|
||||||
|
}()
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
assert.Nil(t, listener.Close())
|
||||||
|
})
|
||||||
|
|
||||||
go func() { _ = http.Serve(listener, mux) }()
|
|
||||||
return listener
|
return listener
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFilters(t *testing.T) {
|
func TestFilters(t *testing.T) {
|
||||||
l := testStartFilterListener()
|
l := testStartFilterListener(t)
|
||||||
defer func() { _ = l.Close() }()
|
dir := prepareTestDir(t)
|
||||||
|
|
||||||
dir := prepareTestDir()
|
Context = homeContext{
|
||||||
defer func() { _ = os.RemoveAll(dir) }()
|
workDir: dir,
|
||||||
Context = homeContext{}
|
client: &http.Client{
|
||||||
Context.workDir = dir
|
Timeout: 5 * time.Second,
|
||||||
Context.client = &http.Client{
|
},
|
||||||
Timeout: 5 * time.Second,
|
|
||||||
}
|
}
|
||||||
Context.filters.Init()
|
Context.filters.Init()
|
||||||
|
|
||||||
|
@ -48,20 +57,20 @@ func TestFilters(t *testing.T) {
|
||||||
URL: fmt.Sprintf("http://127.0.0.1:%d/filters/1.txt", l.Addr().(*net.TCPAddr).Port),
|
URL: fmt.Sprintf("http://127.0.0.1:%d/filters/1.txt", l.Addr().(*net.TCPAddr).Port),
|
||||||
}
|
}
|
||||||
|
|
||||||
// download
|
// Download.
|
||||||
ok, err := Context.filters.update(&f)
|
ok, err := Context.filters.update(&f)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.True(t, ok)
|
require.True(t, ok)
|
||||||
assert.Equal(t, 3, f.RulesCount)
|
assert.Equal(t, 3, f.RulesCount)
|
||||||
|
|
||||||
// refresh
|
// Refresh.
|
||||||
ok, err = Context.filters.update(&f)
|
ok, err = Context.filters.update(&f)
|
||||||
assert.False(t, ok)
|
require.Nil(t, err)
|
||||||
assert.Nil(t, err)
|
require.False(t, ok)
|
||||||
|
|
||||||
err = Context.filters.load(&f)
|
err = Context.filters.load(&f)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
f.unload()
|
f.unload()
|
||||||
_ = os.Remove(f.Path())
|
require.Nil(t, os.Remove(f.Path()))
|
||||||
}
|
}
|
||||||
|
|
|
@ -114,8 +114,7 @@ func TestHome(t *testing.T) {
|
||||||
// Init new context
|
// Init new context
|
||||||
Context = homeContext{}
|
Context = homeContext{}
|
||||||
|
|
||||||
dir := prepareTestDir()
|
dir := prepareTestDir(t)
|
||||||
defer func() { _ = os.RemoveAll(dir) }()
|
|
||||||
fn := filepath.Join(dir, "AdGuardHome.yaml")
|
fn := filepath.Join(dir, "AdGuardHome.yaml")
|
||||||
|
|
||||||
// Prepare the test config
|
// Prepare the test config
|
||||||
|
|
|
@ -39,21 +39,21 @@ func TestLimitRequestBody(t *testing.T) {
|
||||||
wantErr: nil,
|
wantErr: nil,
|
||||||
}}
|
}}
|
||||||
|
|
||||||
makeHandler := func(err *error) http.HandlerFunc {
|
makeHandler := func(t *testing.T, err *error) http.HandlerFunc {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
var b []byte
|
var b []byte
|
||||||
b, *err = ioutil.ReadAll(r.Body)
|
b, *err = ioutil.ReadAll(r.Body)
|
||||||
_, werr := w.Write(b)
|
_, werr := w.Write(b)
|
||||||
if werr != nil {
|
require.Nil(t, werr)
|
||||||
panic(werr)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
var err error
|
var err error
|
||||||
handler := makeHandler(&err)
|
handler := makeHandler(t, &err)
|
||||||
lim := limitRequestBody(handler)
|
lim := limitRequestBody(handler)
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "https://www.example.com", strings.NewReader(tc.body))
|
req := httptest.NewRequest(http.MethodPost, "https://www.example.com", strings.NewReader(tc.body))
|
||||||
|
@ -61,7 +61,7 @@ func TestLimitRequestBody(t *testing.T) {
|
||||||
|
|
||||||
lim.ServeHTTP(res, req)
|
lim.ServeHTTP(res, req)
|
||||||
|
|
||||||
require.Equal(t, tc.wantErr, err)
|
assert.Equal(t, tc.wantErr, err)
|
||||||
assert.Equal(t, tc.want, res.Body.Bytes())
|
assert.Equal(t, tc.want, res.Body.Bytes())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,29 +6,29 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"howett.net/plist"
|
"howett.net/plist"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestHandleMobileConfigDOH(t *testing.T) {
|
func TestHandleMobileConfigDOH(t *testing.T) {
|
||||||
t.Run("success", func(t *testing.T) {
|
t.Run("success", func(t *testing.T) {
|
||||||
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig?host=example.org", nil)
|
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig?host=example.org", nil)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handleMobileConfigDOH(w, r)
|
handleMobileConfigDOH(w, r)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
var mc mobileConfig
|
var mc mobileConfig
|
||||||
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
|
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
if assert.Len(t, mc.PayloadContent, 1) {
|
require.Len(t, mc.PayloadContent, 1)
|
||||||
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name)
|
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name)
|
||||||
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName)
|
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName)
|
||||||
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
|
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
|
||||||
assert.Equal(t, "https://example.org/dns-query", mc.PayloadContent[0].DNSSettings.ServerURL)
|
assert.Equal(t, "https://example.org/dns-query", mc.PayloadContent[0].DNSSettings.ServerURL)
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("success_no_host", func(t *testing.T) {
|
t.Run("success_no_host", func(t *testing.T) {
|
||||||
|
@ -40,23 +40,22 @@ func TestHandleMobileConfigDOH(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig", nil)
|
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig", nil)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handleMobileConfigDOH(w, r)
|
handleMobileConfigDOH(w, r)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
var mc mobileConfig
|
var mc mobileConfig
|
||||||
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
|
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
if assert.Len(t, mc.PayloadContent, 1) {
|
require.Len(t, mc.PayloadContent, 1)
|
||||||
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name)
|
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name)
|
||||||
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName)
|
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName)
|
||||||
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
|
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
|
||||||
assert.Equal(t, "https://example.org/dns-query", mc.PayloadContent[0].DNSSettings.ServerURL)
|
assert.Equal(t, "https://example.org/dns-query", mc.PayloadContent[0].DNSSettings.ServerURL)
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("error_no_host", func(t *testing.T) {
|
t.Run("error_no_host", func(t *testing.T) {
|
||||||
|
@ -66,7 +65,7 @@ func TestHandleMobileConfigDOH(t *testing.T) {
|
||||||
Context.tls = &TLSMod{conf: tlsConfigSettings{}}
|
Context.tls = &TLSMod{conf: tlsConfigSettings{}}
|
||||||
|
|
||||||
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig", nil)
|
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig", nil)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
@ -76,45 +75,43 @@ func TestHandleMobileConfigDOH(t *testing.T) {
|
||||||
|
|
||||||
t.Run("client_id", func(t *testing.T) {
|
t.Run("client_id", func(t *testing.T) {
|
||||||
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig?host=example.org&client_id=cli42", nil)
|
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig?host=example.org&client_id=cli42", nil)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handleMobileConfigDOH(w, r)
|
handleMobileConfigDOH(w, r)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
var mc mobileConfig
|
var mc mobileConfig
|
||||||
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
|
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
if assert.Len(t, mc.PayloadContent, 1) {
|
require.Len(t, mc.PayloadContent, 1)
|
||||||
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name)
|
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name)
|
||||||
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName)
|
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName)
|
||||||
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
|
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
|
||||||
assert.Equal(t, "https://example.org/dns-query/cli42", mc.PayloadContent[0].DNSSettings.ServerURL)
|
assert.Equal(t, "https://example.org/dns-query/cli42", mc.PayloadContent[0].DNSSettings.ServerURL)
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandleMobileConfigDOT(t *testing.T) {
|
func TestHandleMobileConfigDOT(t *testing.T) {
|
||||||
t.Run("success", func(t *testing.T) {
|
t.Run("success", func(t *testing.T) {
|
||||||
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig?host=example.org", nil)
|
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig?host=example.org", nil)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handleMobileConfigDOT(w, r)
|
handleMobileConfigDOT(w, r)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
var mc mobileConfig
|
var mc mobileConfig
|
||||||
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
|
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
if assert.Len(t, mc.PayloadContent, 1) {
|
require.Len(t, mc.PayloadContent, 1)
|
||||||
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name)
|
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name)
|
||||||
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName)
|
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName)
|
||||||
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
|
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("success_no_host", func(t *testing.T) {
|
t.Run("success_no_host", func(t *testing.T) {
|
||||||
|
@ -126,22 +123,21 @@ func TestHandleMobileConfigDOT(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil)
|
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handleMobileConfigDOT(w, r)
|
handleMobileConfigDOT(w, r)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
var mc mobileConfig
|
var mc mobileConfig
|
||||||
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
|
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
if assert.Len(t, mc.PayloadContent, 1) {
|
require.Len(t, mc.PayloadContent, 1)
|
||||||
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name)
|
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name)
|
||||||
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName)
|
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName)
|
||||||
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
|
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("error_no_host", func(t *testing.T) {
|
t.Run("error_no_host", func(t *testing.T) {
|
||||||
|
@ -151,7 +147,7 @@ func TestHandleMobileConfigDOT(t *testing.T) {
|
||||||
Context.tls = &TLSMod{conf: tlsConfigSettings{}}
|
Context.tls = &TLSMod{conf: tlsConfigSettings{}}
|
||||||
|
|
||||||
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil)
|
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
@ -161,21 +157,20 @@ func TestHandleMobileConfigDOT(t *testing.T) {
|
||||||
|
|
||||||
t.Run("client_id", func(t *testing.T) {
|
t.Run("client_id", func(t *testing.T) {
|
||||||
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig?host=example.org&client_id=cli42", nil)
|
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig?host=example.org&client_id=cli42", nil)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handleMobileConfigDOT(w, r)
|
handleMobileConfigDOT(w, r)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
var mc mobileConfig
|
var mc mobileConfig
|
||||||
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
|
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
if assert.Len(t, mc.PayloadContent, 1) {
|
require.Len(t, mc.PayloadContent, 1)
|
||||||
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name)
|
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name)
|
||||||
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName)
|
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName)
|
||||||
assert.Equal(t, "cli42.example.org", mc.PayloadContent[0].DNSSettings.ServerName)
|
assert.Equal(t, "cli42.example.org", mc.PayloadContent[0].DNSSettings.ServerName)
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,96 +4,74 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func testParseOk(t *testing.T, ss ...string) options {
|
func testParseOK(t *testing.T, ss ...string) options {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
o, _, err := parse("", ss)
|
o, _, err := parse("", ss)
|
||||||
if err != nil {
|
require.Nil(t, err)
|
||||||
t.Fatal(err.Error())
|
|
||||||
}
|
|
||||||
return o
|
return o
|
||||||
}
|
}
|
||||||
|
|
||||||
func testParseErr(t *testing.T, descr string, ss ...string) {
|
func testParseErr(t *testing.T, descr string, ss ...string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
_, _, err := parse("", ss)
|
_, _, err := parse("", ss)
|
||||||
if err == nil {
|
require.NotNilf(t, err, "expected an error because %s but no error returned", descr)
|
||||||
t.Fatalf("expected an error because %s but no error returned", descr)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func testParseParamMissing(t *testing.T, param string) {
|
func testParseParamMissing(t *testing.T, param string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
testParseErr(t, fmt.Sprintf("%s parameter missing", param), param)
|
testParseErr(t, fmt.Sprintf("%s parameter missing", param), param)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseVerbose(t *testing.T) {
|
func TestParseVerbose(t *testing.T) {
|
||||||
if testParseOk(t).verbose {
|
assert.False(t, testParseOK(t).verbose, "empty is not verbose")
|
||||||
t.Fatal("empty is not verbose")
|
assert.True(t, testParseOK(t, "-v").verbose, "-v is verbose")
|
||||||
}
|
assert.True(t, testParseOK(t, "--verbose").verbose, "--verbose is verbose")
|
||||||
if !testParseOk(t, "-v").verbose {
|
|
||||||
t.Fatal("-v is verbose")
|
|
||||||
}
|
|
||||||
if !testParseOk(t, "--verbose").verbose {
|
|
||||||
t.Fatal("--verbose is verbose")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseConfigFilename(t *testing.T) {
|
func TestParseConfigFilename(t *testing.T) {
|
||||||
if testParseOk(t).configFilename != "" {
|
assert.Equal(t, "", testParseOK(t).configFilename, "empty is no config filename")
|
||||||
t.Fatal("empty is no config filename")
|
assert.Equal(t, "path", testParseOK(t, "-c", "path").configFilename, "-c is config filename")
|
||||||
}
|
|
||||||
if testParseOk(t, "-c", "path").configFilename != "path" {
|
|
||||||
t.Fatal("-c is config filename")
|
|
||||||
}
|
|
||||||
testParseParamMissing(t, "-c")
|
testParseParamMissing(t, "-c")
|
||||||
if testParseOk(t, "--config", "path").configFilename != "path" {
|
|
||||||
t.Fatal("--configFilename is config filename")
|
assert.Equal(t, "path", testParseOK(t, "--config", "path").configFilename, "--config is config filename")
|
||||||
}
|
|
||||||
testParseParamMissing(t, "--config")
|
testParseParamMissing(t, "--config")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseWorkDir(t *testing.T) {
|
func TestParseWorkDir(t *testing.T) {
|
||||||
if testParseOk(t).workDir != "" {
|
assert.Equal(t, "", testParseOK(t).workDir, "empty is no work dir")
|
||||||
t.Fatal("empty is no work dir")
|
assert.Equal(t, "path", testParseOK(t, "-w", "path").workDir, "-w is work dir")
|
||||||
}
|
|
||||||
if testParseOk(t, "-w", "path").workDir != "path" {
|
|
||||||
t.Fatal("-w is work dir")
|
|
||||||
}
|
|
||||||
testParseParamMissing(t, "-w")
|
testParseParamMissing(t, "-w")
|
||||||
if testParseOk(t, "--work-dir", "path").workDir != "path" {
|
|
||||||
t.Fatal("--work-dir is work dir")
|
assert.Equal(t, "path", testParseOK(t, "--work-dir", "path").workDir, "--work-dir is work dir")
|
||||||
}
|
|
||||||
testParseParamMissing(t, "--work-dir")
|
testParseParamMissing(t, "--work-dir")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseBindHost(t *testing.T) {
|
func TestParseBindHost(t *testing.T) {
|
||||||
if testParseOk(t).bindHost != nil {
|
assert.Nil(t, testParseOK(t).bindHost, "empty is not host")
|
||||||
t.Fatal("empty is no host")
|
assert.Equal(t, net.IPv4(1, 2, 3, 4), testParseOK(t, "-h", "1.2.3.4").bindHost, "-h is host")
|
||||||
}
|
|
||||||
if !testParseOk(t, "-h", "1.2.3.4").bindHost.Equal(net.IP{1, 2, 3, 4}) {
|
|
||||||
t.Fatal("-h is host")
|
|
||||||
}
|
|
||||||
testParseParamMissing(t, "-h")
|
testParseParamMissing(t, "-h")
|
||||||
if !testParseOk(t, "--host", "1.2.3.4").bindHost.Equal(net.IP{1, 2, 3, 4}) {
|
|
||||||
t.Fatal("--host is host")
|
assert.Equal(t, net.IPv4(1, 2, 3, 4), testParseOK(t, "--host", "1.2.3.4").bindHost, "--host is host")
|
||||||
}
|
|
||||||
testParseParamMissing(t, "--host")
|
testParseParamMissing(t, "--host")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseBindPort(t *testing.T) {
|
func TestParseBindPort(t *testing.T) {
|
||||||
if testParseOk(t).bindPort != 0 {
|
assert.Equal(t, 0, testParseOK(t).bindPort, "empty is port 0")
|
||||||
t.Fatal("empty is port 0")
|
assert.Equal(t, 65535, testParseOK(t, "-p", "65535").bindPort, "-p is port")
|
||||||
}
|
|
||||||
if testParseOk(t, "-p", "65535").bindPort != 65535 {
|
|
||||||
t.Fatal("-p is port")
|
|
||||||
}
|
|
||||||
testParseParamMissing(t, "-p")
|
testParseParamMissing(t, "-p")
|
||||||
if testParseOk(t, "--port", "65535").bindPort != 65535 {
|
|
||||||
t.Fatal("--port is port")
|
|
||||||
}
|
|
||||||
testParseParamMissing(t, "--port")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseBindPortBad(t *testing.T) {
|
assert.Equal(t, 65535, testParseOK(t, "--port", "65535").bindPort, "--port is port")
|
||||||
|
testParseParamMissing(t, "--port")
|
||||||
|
|
||||||
testParseErr(t, "not an int", "-p", "x")
|
testParseErr(t, "not an int", "-p", "x")
|
||||||
testParseErr(t, "hex not supported", "-p", "0x100")
|
testParseErr(t, "hex not supported", "-p", "0x100")
|
||||||
testParseErr(t, "port negative", "-p", "-1")
|
testParseErr(t, "port negative", "-p", "-1")
|
||||||
|
@ -103,72 +81,40 @@ func TestParseBindPortBad(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseLogfile(t *testing.T) {
|
func TestParseLogfile(t *testing.T) {
|
||||||
if testParseOk(t).logFile != "" {
|
assert.Equal(t, "", testParseOK(t).logFile, "empty is no log file")
|
||||||
t.Fatal("empty is no log file")
|
assert.Equal(t, "path", testParseOK(t, "-l", "path").logFile, "-l is log file")
|
||||||
}
|
assert.Equal(t, "path", testParseOK(t, "--logfile", "path").logFile, "--logfile is log file")
|
||||||
if testParseOk(t, "-l", "path").logFile != "path" {
|
|
||||||
t.Fatal("-l is log file")
|
|
||||||
}
|
|
||||||
if testParseOk(t, "--logfile", "path").logFile != "path" {
|
|
||||||
t.Fatal("--logfile is log file")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParsePidfile(t *testing.T) {
|
func TestParsePidfile(t *testing.T) {
|
||||||
if testParseOk(t).pidFile != "" {
|
assert.Equal(t, "", testParseOK(t).pidFile, "empty is no pid file")
|
||||||
t.Fatal("empty is no pid file")
|
assert.Equal(t, "path", testParseOK(t, "--pidfile", "path").pidFile, "--pidfile is pid file")
|
||||||
}
|
|
||||||
if testParseOk(t, "--pidfile", "path").pidFile != "path" {
|
|
||||||
t.Fatal("--pidfile is pid file")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseCheckConfig(t *testing.T) {
|
func TestParseCheckConfig(t *testing.T) {
|
||||||
if testParseOk(t).checkConfig {
|
assert.False(t, testParseOK(t).checkConfig, "empty is not check config")
|
||||||
t.Fatal("empty is not check config")
|
assert.True(t, testParseOK(t, "--check-config").checkConfig, "--check-config is check config")
|
||||||
}
|
|
||||||
if !testParseOk(t, "--check-config").checkConfig {
|
|
||||||
t.Fatal("--check-config is check config")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseDisableUpdate(t *testing.T) {
|
func TestParseDisableUpdate(t *testing.T) {
|
||||||
if testParseOk(t).disableUpdate {
|
assert.False(t, testParseOK(t).disableUpdate, "empty is not disable update")
|
||||||
t.Fatal("empty is not disable update")
|
assert.True(t, testParseOK(t, "--no-check-update").disableUpdate, "--no-check-update is disable update")
|
||||||
}
|
|
||||||
if !testParseOk(t, "--no-check-update").disableUpdate {
|
|
||||||
t.Fatal("--no-check-update is disable update")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseDisableMemoryOptimization(t *testing.T) {
|
func TestParseDisableMemoryOptimization(t *testing.T) {
|
||||||
if testParseOk(t).disableMemoryOptimization {
|
assert.False(t, testParseOK(t).disableMemoryOptimization, "empty is not disable update")
|
||||||
t.Fatal("empty is not disable update")
|
assert.True(t, testParseOK(t, "--no-mem-optimization").disableMemoryOptimization, "--no-mem-optimization is disable update")
|
||||||
}
|
|
||||||
if !testParseOk(t, "--no-mem-optimization").disableMemoryOptimization {
|
|
||||||
t.Fatal("--no-mem-optimization is disable update")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseService(t *testing.T) {
|
func TestParseService(t *testing.T) {
|
||||||
if testParseOk(t).serviceControlAction != "" {
|
assert.Equal(t, "", testParseOK(t).serviceControlAction, "empty is not service cmd")
|
||||||
t.Fatal("empty is no service command")
|
assert.Equal(t, "cmd", testParseOK(t, "-s", "cmd").serviceControlAction, "-s is service cmd")
|
||||||
}
|
assert.Equal(t, "cmd", testParseOK(t, "--service", "cmd").serviceControlAction, "--service is service cmd")
|
||||||
if testParseOk(t, "-s", "command").serviceControlAction != "command" {
|
|
||||||
t.Fatal("-s is service command")
|
|
||||||
}
|
|
||||||
if testParseOk(t, "--service", "command").serviceControlAction != "command" {
|
|
||||||
t.Fatal("--service is service command")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseGLInet(t *testing.T) {
|
func TestParseGLInet(t *testing.T) {
|
||||||
if testParseOk(t).glinetMode {
|
assert.False(t, testParseOK(t).glinetMode, "empty is not GL-Inet mode")
|
||||||
t.Fatal("empty is not GL-Inet mode")
|
assert.True(t, testParseOK(t, "--glinet").glinetMode, "--glinet is GL-Inet mode")
|
||||||
}
|
|
||||||
if !testParseOk(t, "--glinet").glinetMode {
|
|
||||||
t.Fatal("--glinet is GL-Inet mode")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseUnknown(t *testing.T) {
|
func TestParseUnknown(t *testing.T) {
|
||||||
|
@ -180,73 +126,85 @@ func TestParseUnknown(t *testing.T) {
|
||||||
testParseErr(t, "unknown dash", "-")
|
testParseErr(t, "unknown dash", "-")
|
||||||
}
|
}
|
||||||
|
|
||||||
func testSerialize(t *testing.T, o options, ss ...string) {
|
func TestSerialize(t *testing.T) {
|
||||||
result := serialize(o)
|
const reportFmt = "expected %s but got %s"
|
||||||
if len(result) != len(ss) {
|
|
||||||
t.Fatalf("expected %s but got %s", ss, result)
|
testCases := []struct {
|
||||||
}
|
name string
|
||||||
for i, r := range result {
|
opts options
|
||||||
if r != ss[i] {
|
ss []string
|
||||||
t.Fatalf("expected %s but got %s", ss, result)
|
}{{
|
||||||
}
|
name: "empty",
|
||||||
|
opts: options{},
|
||||||
|
ss: []string{},
|
||||||
|
}, {
|
||||||
|
name: "config_filename",
|
||||||
|
opts: options{configFilename: "path"},
|
||||||
|
ss: []string{"-c", "path"},
|
||||||
|
}, {
|
||||||
|
name: "work_dir",
|
||||||
|
opts: options{workDir: "path"},
|
||||||
|
ss: []string{"-w", "path"},
|
||||||
|
}, {
|
||||||
|
name: "bind_host",
|
||||||
|
opts: options{bindHost: net.IP{1, 2, 3, 4}},
|
||||||
|
ss: []string{"-h", "1.2.3.4"},
|
||||||
|
}, {
|
||||||
|
name: "bind_port",
|
||||||
|
opts: options{bindPort: 666},
|
||||||
|
ss: []string{"-p", "666"},
|
||||||
|
}, {
|
||||||
|
name: "log_file",
|
||||||
|
opts: options{logFile: "path"},
|
||||||
|
ss: []string{"-l", "path"},
|
||||||
|
}, {
|
||||||
|
name: "pid_file",
|
||||||
|
opts: options{pidFile: "path"},
|
||||||
|
ss: []string{"--pidfile", "path"},
|
||||||
|
}, {
|
||||||
|
name: "disable_update",
|
||||||
|
opts: options{disableUpdate: true},
|
||||||
|
ss: []string{"--no-check-update"},
|
||||||
|
}, {
|
||||||
|
name: "control_action",
|
||||||
|
opts: options{serviceControlAction: "run"},
|
||||||
|
ss: []string{"-s", "run"},
|
||||||
|
}, {
|
||||||
|
name: "glinet_mode",
|
||||||
|
opts: options{glinetMode: true},
|
||||||
|
ss: []string{"--glinet"},
|
||||||
|
}, {
|
||||||
|
name: "disable_mem_opt",
|
||||||
|
opts: options{disableMemoryOptimization: true},
|
||||||
|
ss: []string{"--no-mem-optimization"},
|
||||||
|
}, {
|
||||||
|
name: "multiple",
|
||||||
|
opts: options{
|
||||||
|
serviceControlAction: "run",
|
||||||
|
configFilename: "config",
|
||||||
|
workDir: "work",
|
||||||
|
pidFile: "pid",
|
||||||
|
disableUpdate: true,
|
||||||
|
disableMemoryOptimization: true,
|
||||||
|
},
|
||||||
|
ss: []string{
|
||||||
|
"-c", "config",
|
||||||
|
"-w", "work",
|
||||||
|
"-s", "run",
|
||||||
|
"--pidfile", "pid",
|
||||||
|
"--no-check-update",
|
||||||
|
"--no-mem-optimization",
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result := serialize(tc.opts)
|
||||||
|
require.Lenf(t, result, len(tc.ss), reportFmt, tc.ss, result)
|
||||||
|
|
||||||
|
for i, r := range result {
|
||||||
|
assert.Equalf(t, tc.ss[i], r, reportFmt, tc.ss, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSerializeEmpty(t *testing.T) {
|
|
||||||
testSerialize(t, options{})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSerializeConfigFilename(t *testing.T) {
|
|
||||||
testSerialize(t, options{configFilename: "path"}, "-c", "path")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSerializeWorkDir(t *testing.T) {
|
|
||||||
testSerialize(t, options{workDir: "path"}, "-w", "path")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSerializeBindHost(t *testing.T) {
|
|
||||||
testSerialize(t, options{bindHost: net.IP{1, 2, 3, 4}}, "-h", "1.2.3.4")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSerializeBindPort(t *testing.T) {
|
|
||||||
testSerialize(t, options{bindPort: 666}, "-p", "666")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSerializeLogfile(t *testing.T) {
|
|
||||||
testSerialize(t, options{logFile: "path"}, "-l", "path")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSerializePidfile(t *testing.T) {
|
|
||||||
testSerialize(t, options{pidFile: "path"}, "--pidfile", "path")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSerializeCheckConfig(t *testing.T) {
|
|
||||||
testSerialize(t, options{checkConfig: true}, "--check-config")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSerializeDisableUpdate(t *testing.T) {
|
|
||||||
testSerialize(t, options{disableUpdate: true}, "--no-check-update")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSerializeService(t *testing.T) {
|
|
||||||
testSerialize(t, options{serviceControlAction: "run"}, "-s", "run")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSerializeGLInet(t *testing.T) {
|
|
||||||
testSerialize(t, options{glinetMode: true}, "--glinet")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSerializeDisableMemoryOptimization(t *testing.T) {
|
|
||||||
testSerialize(t, options{disableMemoryOptimization: true}, "--no-mem-optimization")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSerializeMultiple(t *testing.T) {
|
|
||||||
testSerialize(t, options{
|
|
||||||
serviceControlAction: "run",
|
|
||||||
configFilename: "config",
|
|
||||||
workDir: "work",
|
|
||||||
pidFile: "pid",
|
|
||||||
disableUpdate: true,
|
|
||||||
disableMemoryOptimization: true,
|
|
||||||
}, "-c", "config", "-w", "work", "-s", "run", "--pidfile", "pid", "--no-check-update", "--no-mem-optimization")
|
|
||||||
}
|
|
||||||
|
|
|
@ -20,6 +20,7 @@ func newTestQLogReader(t *testing.T, filesNum, linesNum int) (reader *QLogReader
|
||||||
// Create the new QLogReader instance.
|
// Create the new QLogReader instance.
|
||||||
reader, err := NewQLogReader(testFiles)
|
reader, err := NewQLogReader(testFiles)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
assert.NotNil(t, reader)
|
assert.NotNil(t, reader)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
assert.Nil(t, reader.Close())
|
assert.Nil(t, reader.Close())
|
||||||
|
@ -112,11 +113,7 @@ func TestQLogReader_Seek(t *testing.T) {
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
timestamp, err := time.Parse(time.RFC3339Nano, tc.time)
|
timestamp, err := time.Parse(time.RFC3339Nano, tc.time)
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
if tc.name == "first" {
|
|
||||||
assert.True(t, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = r.SeekTS(timestamp.UnixNano())
|
err = r.SeekTS(timestamp.UnixNano())
|
||||||
assert.True(t, errors.Is(err, tc.want))
|
assert.True(t, errors.Is(err, tc.want))
|
||||||
|
@ -146,11 +143,11 @@ func TestQLogReader_ReadNext(t *testing.T) {
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
err := r.SeekStart()
|
err := r.SeekStart()
|
||||||
assert.Nil(t, err, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
for i := 1; i < tc.start; i++ {
|
for i := 1; i < tc.start; i++ {
|
||||||
_, err := r.ReadNext()
|
_, err := r.ReadNext()
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = r.ReadNext()
|
_, err = r.ReadNext()
|
||||||
|
|
Loading…
Reference in New Issue