badguardhome/dnsforward/dnsforward_test.go

491 lines
12 KiB
Go

package dnsforward
import (
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net"
"os"
"sync"
"testing"
"time"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/stretchr/testify/assert"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/miekg/dns"
)
// Shared fixtures used by the tests in this file.
const (
// tlsServerName is the DNS name placed into the generated self-signed
// certificate and used as the ServerName of the TLS client configuration.
tlsServerName = "testdns.adguard.com"
// dataDir is the directory that createDataDir creates and removeDataDir
// deletes; its path is handed to NewServer.
dataDir = "testData"
// testMessagesCount is the number of goroutines started by
// sendTestMessagesAsync, each of which sends one test query.
testMessagesCount = 10
)
// TestServer starts a test server, sends the same test query first over UDP
// and then over TCP, and after each exchange verifies the answer, the query
// log size and the top statistics.
func TestServer(t *testing.T) {
	srv := createTestServer(t)
	defer removeDataDir(t)

	if startErr := srv.Start(nil); startErr != nil {
		t.Fatalf("Failed to start server: %s", startErr)
	}

	// verifyCounters checks the query log size and the top statistics. Both
	// exchanges ask for the same domain, so only the log size differs
	// between the two calls.
	verifyCounters := func(wantLogSize int) {
		assert.Equal(t, wantLogSize, len(srv.GetQueryLog()), "Log size")
		top := srv.GetStatsTop()
		assert.Equal(t, 1, len(top.Domains), "Top domains length")
		assert.Equal(t, 0, len(top.Blocked), "Top blocked length")
		assert.Equal(t, 1, len(top.Clients), "Top clients length")
	}

	// message over UDP
	udpAddr := srv.dnsProxy.Addr(proxy.ProtoUDP)
	udpClient := dns.Client{Net: "udp"}
	udpReply, _, err := udpClient.Exchange(createTestMessage(), udpAddr.String())
	if err != nil {
		t.Fatalf("Couldn't talk to server %s: %s", udpAddr, err)
	}
	assertResponse(t, udpReply)
	verifyCounters(1)

	// message over TCP
	tcpAddr := srv.dnsProxy.Addr("tcp")
	tcpClient := dns.Client{Net: "tcp"}
	tcpReply, _, err := tcpClient.Exchange(createTestMessage(), tcpAddr.String())
	if err != nil {
		t.Fatalf("Couldn't talk to server %s: %s", tcpAddr, err)
	}
	assertResponse(t, tcpReply)
	verifyCounters(2)

	if stopErr := srv.Stop(); stopErr != nil {
		t.Fatalf("DNS server failed to stop: %s", stopErr)
	}
}
// TestDotServer starts the server with a TLS listener that uses a freshly
// generated self-signed certificate, connects to it with a DNS-over-TLS client
// that trusts that certificate, and sends a series of test queries.
func TestDotServer(t *testing.T) {
	// Prepare the proxy server
	_, certPem, keyPem := createServerTLSConfig(t)
	s := createTestServer(t)
	defer removeDataDir(t)
	s.TLSConfig = TLSConfig{
		TLSListenAddr:    &net.TCPAddr{Port: 0},
		CertificateChain: string(certPem),
		PrivateKey:       string(keyPem),
	}

	// Starting the server
	err := s.Start(nil)
	if err != nil {
		t.Fatalf("Failed to start server: %s", err)
	}

	// Add our self-signed generated config to roots. AppendCertsFromPEM
	// reports whether any certificate was parsed; fail early here rather
	// than with a less obvious handshake error later.
	roots := x509.NewCertPool()
	if !roots.AppendCertsFromPEM(certPem) {
		t.Fatalf("cannot add the generated certificate to the root pool")
	}
	tlsConfig := &tls.Config{ServerName: tlsServerName, RootCAs: roots}

	// Create a DNS-over-TLS client connection
	addr := s.dnsProxy.Addr(proxy.ProtoTLS)
	conn, err := dns.DialWithTLS("tcp-tls", addr.String(), tlsConfig)
	if err != nil {
		t.Fatalf("cannot connect to the proxy: %s", err)
	}
	// Release the client connection when the test ends. The close error is
	// deliberately ignored: by then the server may already have been
	// stopped, so the peer can be gone.
	defer func() { _ = conn.Close() }()

	sendTestMessages(t, conn)

	// Stop the proxy
	err = s.Stop()
	if err != nil {
		t.Fatalf("DNS server failed to stop: %s", err)
	}
}
// TestServerRace starts a test server and sends several queries concurrently
// over a single UDP connection so that data races in the server can surface
// (e.g. when the tests are run with -race).
func TestServerRace(t *testing.T) {
	s := createTestServer(t)
	defer removeDataDir(t)
	err := s.Start(nil)
	if err != nil {
		t.Fatalf("Failed to start server: %s", err)
	}

	// message over UDP
	addr := s.dnsProxy.Addr(proxy.ProtoUDP)
	conn, err := dns.Dial("udp", addr.String())
	if err != nil {
		t.Fatalf("cannot connect to the proxy: %s", err)
	}
	// Release the client socket when the test ends; the close error carries
	// no useful information for this test and is deliberately ignored.
	defer func() { _ = conn.Close() }()

	sendTestMessagesAsync(t, conn)

	// Stop the proxy
	err = s.Stop()
	if err != nil {
		t.Fatalf("DNS server failed to stop: %s", err)
	}
}
// TestInvalidRequest sends a DNS message that has no question section and
// expects the exchange to fail (no response within the client timeout). It
// then verifies that the query log and the top statistics stay empty.
func TestInvalidRequest(t *testing.T) {
	s := createTestServer(t)
	defer removeDataDir(t)
	err := s.Start(nil)
	if err != nil {
		t.Fatalf("Failed to start server: %s", err)
	}

	// server is running, send a message
	addr := s.dnsProxy.Addr(proxy.ProtoUDP)
	req := dns.Msg{}
	req.Id = dns.Id()
	req.RecursionDesired = true

	// send a DNS request without question; the short timeout keeps the test
	// fast, because the expected outcome is that no reply arrives
	client := dns.Client{Net: "udp", Timeout: 500 * time.Millisecond}
	_, _, err = client.Exchange(&req, addr.String())
	if err == nil {
		// A nil error means the server answered, which is the failure case.
		t.Fatalf("got a response to an invalid query")
	}

	// check query log and stats
	// invalid requests aren't written to the query log
	log := s.GetQueryLog()
	assert.Equal(t, 0, len(log), "Log size")
	stats := s.GetStatsTop()
	assert.Equal(t, 0, len(stats.Domains), "Top domains length")
	assert.Equal(t, 0, len(stats.Blocked), "Top blocked length")
	assert.Equal(t, 0, len(stats.Clients), "Top clients length")

	err = s.Stop()
	if err != nil {
		t.Fatalf("DNS server failed to stop: %s", err)
	}
}
// TestBlockedRequest queries a domain matched by the "||nxdomain.example.org^"
// rule installed by createTestServer and expects an NXDOMAIN reply. It then
// checks that the query shows up in the log and in the blocked statistics.
func TestBlockedRequest(t *testing.T) {
	srv := createTestServer(t)
	defer removeDataDir(t)

	if startErr := srv.Start(nil); startErr != nil {
		t.Fatalf("Failed to start server: %s", startErr)
	}
	udpAddr := srv.dnsProxy.Addr(proxy.ProtoUDP)

	// NXDomain blocking
	query := dns.Msg{
		MsgHdr: dns.MsgHdr{Id: dns.Id(), RecursionDesired: true},
		Question: []dns.Question{
			{Name: "nxdomain.example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
		},
	}
	resp, err := dns.Exchange(&query, udpAddr.String())
	if err != nil {
		t.Fatalf("Couldn't talk to server %s: %s", udpAddr, err)
	}
	if resp.Rcode != dns.RcodeNameError {
		t.Fatalf("Wrong response: %s", resp.String())
	}

	// check query log and stats
	assert.Equal(t, 1, len(srv.GetQueryLog()), "Log size")
	top := srv.GetStatsTop()
	assert.Equal(t, 1, len(top.Domains), "Top domains length")
	assert.Equal(t, 1, len(top.Blocked), "Top blocked length")
	assert.Equal(t, 1, len(top.Clients), "Top clients length")

	if stopErr := srv.Stop(); stopErr != nil {
		t.Fatalf("DNS server failed to stop: %s", stopErr)
	}
}
// TestBlockedByHosts queries a domain covered by the hosts-style rule
// "127.0.0.1 host.example.org" installed by createTestServer and expects a
// single A record pointing to 127.0.0.1. It then checks that the query shows
// up in the log and in the blocked statistics.
func TestBlockedByHosts(t *testing.T) {
	s := createTestServer(t)
	defer removeDataDir(t)
	err := s.Start(nil)
	if err != nil {
		t.Fatalf("Failed to start server: %s", err)
	}
	addr := s.dnsProxy.Addr(proxy.ProtoUDP)

	//
	// Hosts blocking
	//
	req := dns.Msg{}
	req.Id = dns.Id()
	req.RecursionDesired = true
	req.Question = []dns.Question{
		{Name: "host.example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
	}

	reply, err := dns.Exchange(&req, addr.String())
	if err != nil {
		t.Fatalf("Couldn't talk to server %s: %s", addr, err)
	}
	if len(reply.Answer) != 1 {
		t.Fatalf("DNS server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer))
	}
	a, ok := reply.Answer[0].(*dns.A)
	if !ok {
		t.Fatalf("DNS server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0])
	}
	// The hosts rule maps the name to 127.0.0.1, so that is the address the
	// failure message has to name.
	if !net.IPv4(127, 0, 0, 1).Equal(a.A) {
		t.Fatalf("DNS server %s returned wrong answer instead of 127.0.0.1: %v", addr, a.A)
	}

	// check query log and stats
	log := s.GetQueryLog()
	assert.Equal(t, 1, len(log), "Log size")
	stats := s.GetStatsTop()
	assert.Equal(t, 1, len(stats.Domains), "Top domains length")
	assert.Equal(t, 1, len(stats.Blocked), "Top blocked length")
	assert.Equal(t, 1, len(stats.Clients), "Top clients length")

	err = s.Stop()
	if err != nil {
		t.Fatalf("DNS server failed to stop: %s", err)
	}
}
// TestBlockedBySafeBrowsing queries a domain that is expected to be blocked by
// safe browsing and verifies that the single A record in the reply equals one
// of the addresses safeBrowsingBlockHost resolves to. It then checks that the
// query shows up in the log and in the blocked statistics.
func TestBlockedBySafeBrowsing(t *testing.T) {
	srv := createTestServer(t)
	defer removeDataDir(t)

	if startErr := srv.Start(nil); startErr != nil {
		t.Fatalf("Failed to start server: %s", startErr)
	}
	udpAddr := srv.dnsProxy.Addr(proxy.ProtoUDP)

	// Safebrowsing blocking
	query := dns.Msg{
		MsgHdr: dns.MsgHdr{Id: dns.Id(), RecursionDesired: true},
		Question: []dns.Question{
			{Name: "wmconvirus.narod.ru.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
		},
	}
	resp, err := dns.Exchange(&query, udpAddr.String())
	if err != nil {
		t.Fatalf("Couldn't talk to server %s: %s", udpAddr, err)
	}
	if n := len(resp.Answer); n != 1 {
		t.Fatalf("DNS server %s returned reply with wrong number of answers - %d", udpAddr, n)
	}

	aRecord, isA := resp.Answer[0].(*dns.A)
	if !isA {
		t.Fatalf("DNS server %s returned wrong answer type instead of A: %v", udpAddr, resp.Answer[0])
	}

	// The block host is resolved through the system resolver and the
	// answer must be one of the resulting addresses.
	blockAddrs, lookupErr := net.LookupHost(safeBrowsingBlockHost)
	if lookupErr != nil {
		t.Fatalf("cannot resolve %s due to %s", safeBrowsingBlockHost, lookupErr)
	}
	answered := aRecord.A.String()
	matched := false
	for i := range blockAddrs {
		if blockAddrs[i] == answered {
			matched = true
			break
		}
	}
	if !matched {
		t.Fatalf("DNS server %s returned wrong answer: %v", udpAddr, aRecord.A)
	}

	// check query log and stats
	assert.Equal(t, 1, len(srv.GetQueryLog()), "Log size")
	top := srv.GetStatsTop()
	assert.Equal(t, 1, len(top.Domains), "Top domains length")
	assert.Equal(t, 1, len(top.Blocked), "Top blocked length")
	assert.Equal(t, 1, len(top.Clients), "Top clients length")

	if stopErr := srv.Stop(); stopErr != nil {
		t.Fatalf("DNS server failed to stop: %s", stopErr)
	}
}
// createTestServer builds a Server rooted in a freshly created data directory.
// Query logging, filtering, protection and safe browsing are switched on, and
// a single filter with one blocking rule and one hosts-style rule is installed.
// The server is returned without being started.
func createTestServer(t *testing.T) *Server {
	srv := NewServer(createDataDir(t))

	// Port 0 on both listeners; presumably the OS then picks free ports,
	// which the tests read back through dnsProxy.Addr.
	srv.UDPListenAddr = &net.UDPAddr{Port: 0}
	srv.TCPListenAddr = &net.TCPAddr{Port: 0}

	srv.QueryLogEnabled = true
	srv.FilteringConfig.FilteringEnabled = true
	srv.FilteringConfig.ProtectionEnabled = true
	srv.FilteringConfig.SafeBrowsingEnabled = true

	srv.Filters = []dnsfilter.Filter{{
		ID: 1,
		Rules: []string{
			"||nxdomain.example.org^",
			"127.0.0.1 host.example.org",
		},
	}}

	return srv
}
// createServerTLSConfig generates a 2048-bit RSA key and a self-signed CA
// certificate for tlsServerName that is valid for five years from now. It
// returns a tls.Config holding the key pair together with the PEM-encoded
// certificate and the PEM-encoded private key. Any failure aborts the test.
func createServerTLSConfig(t *testing.T) (*tls.Config, []byte, []byte) {
	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatalf("cannot generate RSA key: %s", err)
	}

	// Random serial number below 2^128.
	serialLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serial, err := rand.Int(rand.Reader, serialLimit)
	if err != nil {
		t.Fatalf("failed to generate serial number: %s", err)
	}

	validFrom := time.Now()
	validUntil := validFrom.Add(5 * 365 * time.Hour * 24)

	certTemplate := x509.Certificate{
		SerialNumber:          serial,
		Subject:               pkix.Name{Organization: []string{"AdGuard Tests"}},
		NotBefore:             validFrom,
		NotAfter:              validUntil,
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		IsCA:                  true,
		DNSNames:              []string{tlsServerName},
	}

	// The template doubles as the parent, which makes the certificate
	// self-signed.
	der, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, publicKey(key), key)
	if err != nil {
		t.Fatalf("failed to create certificate: %s", err)
	}

	certBlock := &pem.Block{Type: "CERTIFICATE", Bytes: der}
	keyBlock := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}
	certPem := pem.EncodeToMemory(certBlock)
	keyPem := pem.EncodeToMemory(keyBlock)

	pair, err := tls.X509KeyPair(certPem, keyPem)
	if err != nil {
		t.Fatalf("failed to create certificate: %s", err)
	}

	cfg := &tls.Config{
		Certificates: []tls.Certificate{pair},
		ServerName:   tlsServerName,
	}
	return cfg, certPem, keyPem
}
// createDataDir makes sure the dataDir directory exists (mode 0755) and
// returns its path; the test is aborted if the directory cannot be created.
func createDataDir(t *testing.T) string {
	if mkErr := os.MkdirAll(dataDir, 0755); mkErr != nil {
		t.Fatalf("Cannot create %s: %s", dataDir, mkErr)
	}
	return dataDir
}
// removeDataDir deletes the dataDir directory with everything inside it and
// fails the test if the removal does not succeed.
func removeDataDir(t *testing.T) {
	if rmErr := os.RemoveAll(dataDir); rmErr != nil {
		t.Fatalf("Cannot remove %s: %s", dataDir, rmErr)
	}
}
// sendTestMessageAsync writes one test query to conn, reads one response and
// validates it. It is meant to be started with the go statement; g.Done is
// called when it returns.
//
// Failures are reported with t.Errorf followed by a return instead of
// t.Fatalf, because the testing package requires FailNow (and therefore
// Fatalf) to be called only from the goroutine that runs the test function.
func sendTestMessageAsync(t *testing.T, conn *dns.Conn, g *sync.WaitGroup) {
	defer g.Done()

	req := createTestMessage()
	if err := conn.WriteMsg(req); err != nil {
		t.Errorf("cannot write message: %s", err)
		return
	}

	res, err := conn.ReadMsg()
	if err != nil {
		t.Errorf("cannot read response to message: %s", err)
		return
	}

	// NOTE(review): assertResponse still reports mismatches with t.Fatalf.
	// The deferred g.Done above keeps the WaitGroup balanced in that case.
	assertResponse(t, res)
}
// sendTestMessagesAsync fires testMessagesCount queries over conn from
// separate goroutines and blocks until every one of them has finished, so
// that race issues have a chance to show up.
func sendTestMessagesAsync(t *testing.T, conn *dns.Conn) {
	var wg sync.WaitGroup
	for sent := 0; sent < testMessagesCount; sent++ {
		wg.Add(1)
		go sendTestMessageAsync(t, conn, &wg)
	}
	wg.Wait()
}
// sendTestMessages sends testMessagesCount test queries over conn one after
// another, reading and validating the response to each query before the next
// one is sent. The first failure aborts the test.
func sendTestMessages(t *testing.T, conn *dns.Conn) {
	// Use the shared constant instead of a literal so that the sequential
	// and the asynchronous helpers always send the same number of messages.
	for i := 0; i < testMessagesCount; i++ {
		req := createTestMessage()
		if err := conn.WriteMsg(req); err != nil {
			t.Fatalf("cannot write message #%d: %s", i, err)
		}

		res, err := conn.ReadMsg()
		if err != nil {
			t.Fatalf("cannot read response to message #%d: %s", i, err)
		}
		assertResponse(t, res)
	}
}
// createTestMessage builds a recursive A/IN query for
// "google-public-dns-a.google.com." with a newly generated message ID.
func createTestMessage() *dns.Msg {
	return &dns.Msg{
		MsgHdr: dns.MsgHdr{
			Id:               dns.Id(),
			RecursionDesired: true,
		},
		Question: []dns.Question{
			{Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
		},
	}
}
// assertResponse aborts the test unless reply carries exactly one answer and
// that answer is an A record with the address 8.8.8.8.
func assertResponse(t *testing.T, reply *dns.Msg) {
	if n := len(reply.Answer); n != 1 {
		t.Fatalf("DNS server returned reply with wrong number of answers - %d", n)
	}

	record, isA := reply.Answer[0].(*dns.A)
	if !isA {
		t.Fatalf("DNS server returned wrong answer type instead of A: %v", reply.Answer[0])
	}

	if want := net.IPv4(8, 8, 8, 8); !want.Equal(record.A) {
		t.Fatalf("DNS server returned wrong answer instead of 8.8.8.8: %v", record.A)
	}
}
// publicKey returns a pointer to the public half of priv when priv is an
// *rsa.PrivateKey or an *ecdsa.PrivateKey, and nil for any other value.
func publicKey(priv interface{}) interface{} {
	if rsaKey, isRSA := priv.(*rsa.PrivateKey); isRSA {
		return &rsaKey.PublicKey
	}
	if ecKey, isEC := priv.(*ecdsa.PrivateKey); isEC {
		return &ecKey.PublicKey
	}
	return nil
}