Added DNS-over-TLS unit-test and a test looking for race-conditions
This commit is contained in:
parent
158f2f6100
commit
71df659dc9
|
@ -70,6 +70,7 @@ type FilteringConfig struct {
|
||||||
dnsfilter.Config `yaml:",inline"`
|
dnsfilter.Config `yaml:",inline"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS
|
||||||
type TLSConfig struct {
|
type TLSConfig struct {
|
||||||
TLSListenAddr *net.TCPAddr `yaml:"-" json:"-"`
|
TLSListenAddr *net.TCPAddr `yaml:"-" json:"-"`
|
||||||
CertificateChain string `yaml:"certificate_chain" json:"certificate_chain"` // PEM-encoded certificates chain
|
CertificateChain string `yaml:"certificate_chain" json:"certificate_chain"` // PEM-encoded certificates chain
|
||||||
|
|
|
@ -1,17 +1,34 @@
|
||||||
package dnsforward
|
package dnsforward
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/pem"
|
||||||
|
"math/big"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
tlsServerName = "testdns.adguard.com"
|
||||||
|
dataDir = "testData"
|
||||||
|
testMessagesCount = 10
|
||||||
|
)
|
||||||
|
|
||||||
func TestServer(t *testing.T) {
|
func TestServer(t *testing.T) {
|
||||||
s := createTestServer(t)
|
s := createTestServer(t)
|
||||||
defer removeDataDir(t)
|
defer removeDataDir(t)
|
||||||
|
@ -22,7 +39,7 @@ func TestServer(t *testing.T) {
|
||||||
|
|
||||||
// message over UDP
|
// message over UDP
|
||||||
req := createTestMessage()
|
req := createTestMessage()
|
||||||
addr := s.dnsProxy.Addr("udp")
|
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||||
client := dns.Client{Net: "udp"}
|
client := dns.Client{Net: "udp"}
|
||||||
reply, _, err := client.Exchange(req, addr.String())
|
reply, _, err := client.Exchange(req, addr.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -63,6 +80,69 @@ func TestServer(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDotServer(t *testing.T) {
|
||||||
|
// Prepare the proxy server
|
||||||
|
_, certPem, keyPem := createServerTLSConfig(t)
|
||||||
|
s := createTestServer(t)
|
||||||
|
defer removeDataDir(t)
|
||||||
|
|
||||||
|
s.TLSConfig = TLSConfig{
|
||||||
|
TLSListenAddr: &net.TCPAddr{Port: 0},
|
||||||
|
CertificateChain: string(certPem),
|
||||||
|
PrivateKey: string(keyPem),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Starting the server
|
||||||
|
err := s.Start(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to start server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add our self-signed generated config to roots
|
||||||
|
roots := x509.NewCertPool()
|
||||||
|
roots.AppendCertsFromPEM(certPem)
|
||||||
|
tlsConfig := &tls.Config{ServerName: tlsServerName, RootCAs: roots}
|
||||||
|
|
||||||
|
// Create a DNS-over-TLS client connection
|
||||||
|
addr := s.dnsProxy.Addr(proxy.ProtoTLS)
|
||||||
|
conn, err := dns.DialWithTLS("tcp-tls", addr.String(), tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("cannot connect to the proxy: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sendTestMessages(t, conn)
|
||||||
|
|
||||||
|
// Stop the proxy
|
||||||
|
err = s.Stop()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DNS server failed to stop: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerRace(t *testing.T) {
|
||||||
|
s := createTestServer(t)
|
||||||
|
defer removeDataDir(t)
|
||||||
|
err := s.Start(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to start server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// message over UDP
|
||||||
|
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||||
|
conn, err := dns.Dial("udp", addr.String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("cannot connect to the proxy: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sendTestMessagesAsync(t, conn)
|
||||||
|
|
||||||
|
// Stop the proxy
|
||||||
|
err = s.Stop()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DNS server failed to stop: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestInvalidRequest(t *testing.T) {
|
func TestInvalidRequest(t *testing.T) {
|
||||||
s := createTestServer(t)
|
s := createTestServer(t)
|
||||||
defer removeDataDir(t)
|
defer removeDataDir(t)
|
||||||
|
@ -72,7 +152,7 @@ func TestInvalidRequest(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// server is running, send a message
|
// server is running, send a message
|
||||||
addr := s.dnsProxy.Addr("udp")
|
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||||
req := dns.Msg{}
|
req := dns.Msg{}
|
||||||
req.Id = dns.Id()
|
req.Id = dns.Id()
|
||||||
req.RecursionDesired = true
|
req.RecursionDesired = true
|
||||||
|
@ -106,7 +186,7 @@ func TestBlockedRequest(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to start server: %s", err)
|
t.Fatalf("Failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
addr := s.dnsProxy.Addr("udp")
|
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||||
|
|
||||||
//
|
//
|
||||||
// NXDomain blocking
|
// NXDomain blocking
|
||||||
|
@ -147,7 +227,7 @@ func TestBlockedByHosts(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to start server: %s", err)
|
t.Fatalf("Failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
addr := s.dnsProxy.Addr("udp")
|
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||||
|
|
||||||
//
|
//
|
||||||
// Hosts blocking
|
// Hosts blocking
|
||||||
|
@ -195,7 +275,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to start server: %s", err)
|
t.Fatalf("Failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
addr := s.dnsProxy.Addr("udp")
|
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||||
|
|
||||||
//
|
//
|
||||||
// Safebrowsing blocking
|
// Safebrowsing blocking
|
||||||
|
@ -251,6 +331,7 @@ func createTestServer(t *testing.T) *Server {
|
||||||
s := NewServer(createDataDir(t))
|
s := NewServer(createDataDir(t))
|
||||||
s.UDPListenAddr = &net.UDPAddr{Port: 0}
|
s.UDPListenAddr = &net.UDPAddr{Port: 0}
|
||||||
s.TCPListenAddr = &net.TCPAddr{Port: 0}
|
s.TCPListenAddr = &net.TCPAddr{Port: 0}
|
||||||
|
|
||||||
s.QueryLogEnabled = true
|
s.QueryLogEnabled = true
|
||||||
s.FilteringConfig.FilteringEnabled = true
|
s.FilteringConfig.FilteringEnabled = true
|
||||||
s.FilteringConfig.ProtectionEnabled = true
|
s.FilteringConfig.ProtectionEnabled = true
|
||||||
|
@ -266,20 +347,111 @@ func createTestServer(t *testing.T) *Server {
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func createDataDir(t *testing.T) string {
|
func createServerTLSConfig(t *testing.T) (*tls.Config, []byte, []byte) {
|
||||||
dir := "testData"
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
err := os.MkdirAll(dir, 0755)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Cannot create %s: %s", dir, err)
|
t.Fatalf("cannot generate RSA key: %s", err)
|
||||||
}
|
}
|
||||||
return dir
|
|
||||||
|
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||||
|
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to generate serial number: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
notBefore := time.Now()
|
||||||
|
notAfter := notBefore.Add(5 * 365 * time.Hour * 24)
|
||||||
|
|
||||||
|
template := x509.Certificate{
|
||||||
|
SerialNumber: serialNumber,
|
||||||
|
Subject: pkix.Name{
|
||||||
|
Organization: []string{"AdGuard Tests"},
|
||||||
|
},
|
||||||
|
NotBefore: notBefore,
|
||||||
|
NotAfter: notAfter,
|
||||||
|
|
||||||
|
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
IsCA: true,
|
||||||
|
}
|
||||||
|
template.DNSNames = append(template.DNSNames, tlsServerName)
|
||||||
|
|
||||||
|
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(privateKey), privateKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create certificate: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
certPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||||
|
keyPem := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})
|
||||||
|
|
||||||
|
cert, err := tls.X509KeyPair(certPem, keyPem)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create certificate: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tls.Config{Certificates: []tls.Certificate{cert}, ServerName: tlsServerName}, certPem, keyPem
|
||||||
|
}
|
||||||
|
|
||||||
|
func createDataDir(t *testing.T) string {
|
||||||
|
err := os.MkdirAll(dataDir, 0755)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Cannot create %s: %s", dataDir, err)
|
||||||
|
}
|
||||||
|
return dataDir
|
||||||
}
|
}
|
||||||
|
|
||||||
func removeDataDir(t *testing.T) {
|
func removeDataDir(t *testing.T) {
|
||||||
dir := "testData"
|
err := os.RemoveAll(dataDir)
|
||||||
err := os.RemoveAll(dir)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Cannot remove %s: %s", dir, err)
|
t.Fatalf("Cannot remove %s: %s", dataDir, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendTestMessageAsync(t *testing.T, conn *dns.Conn, g *sync.WaitGroup) {
|
||||||
|
defer func() {
|
||||||
|
g.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
|
req := createTestMessage()
|
||||||
|
err := conn.WriteMsg(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("cannot write message: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := conn.ReadMsg()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("cannot read response to message: %s", err)
|
||||||
|
}
|
||||||
|
assertResponse(t, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendTestMessagesAsync sends messages in parallel
|
||||||
|
// so that we could find race issues
|
||||||
|
func sendTestMessagesAsync(t *testing.T, conn *dns.Conn) {
|
||||||
|
g := &sync.WaitGroup{}
|
||||||
|
g.Add(testMessagesCount)
|
||||||
|
|
||||||
|
for i := 0; i < testMessagesCount; i++ {
|
||||||
|
go sendTestMessageAsync(t, conn, g)
|
||||||
|
}
|
||||||
|
|
||||||
|
g.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendTestMessages(t *testing.T, conn *dns.Conn) {
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
req := createTestMessage()
|
||||||
|
err := conn.WriteMsg(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("cannot write message #%d: %s", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := conn.ReadMsg()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("cannot read response to message #%d: %s", i, err)
|
||||||
|
}
|
||||||
|
assertResponse(t, res)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -305,3 +477,14 @@ func assertResponse(t *testing.T, reply *dns.Msg) {
|
||||||
t.Fatalf("DNS server returned wrong answer type instead of A: %v", reply.Answer[0])
|
t.Fatalf("DNS server returned wrong answer type instead of A: %v", reply.Answer[0])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func publicKey(priv interface{}) interface{} {
|
||||||
|
switch k := priv.(type) {
|
||||||
|
case *rsa.PrivateKey:
|
||||||
|
return &k.PublicKey
|
||||||
|
case *ecdsa.PrivateKey:
|
||||||
|
return &k.PublicKey
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue