diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index 785f32b7..e035f8ed 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -70,6 +70,7 @@ type FilteringConfig struct { dnsfilter.Config `yaml:",inline"` } +// TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS type TLSConfig struct { TLSListenAddr *net.TCPAddr `yaml:"-" json:"-"` CertificateChain string `yaml:"certificate_chain" json:"certificate_chain"` // PEM-encoded certificates chain diff --git a/dnsforward/dnsforward_test.go b/dnsforward/dnsforward_test.go index 9553b9ed..28c115d1 100644 --- a/dnsforward/dnsforward_test.go +++ b/dnsforward/dnsforward_test.go @@ -1,17 +1,34 @@ package dnsforward import ( + "crypto/ecdsa" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" "net" "os" + "sync" "testing" "time" + "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/stretchr/testify/assert" "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/miekg/dns" ) +const ( + tlsServerName = "testdns.adguard.com" + dataDir = "testData" + testMessagesCount = 10 +) + func TestServer(t *testing.T) { s := createTestServer(t) defer removeDataDir(t) @@ -22,7 +39,7 @@ func TestServer(t *testing.T) { // message over UDP req := createTestMessage() - addr := s.dnsProxy.Addr("udp") + addr := s.dnsProxy.Addr(proxy.ProtoUDP) client := dns.Client{Net: "udp"} reply, _, err := client.Exchange(req, addr.String()) if err != nil { @@ -63,6 +80,69 @@ func TestServer(t *testing.T) { } } +func TestDotServer(t *testing.T) { + // Prepare the proxy server + _, certPem, keyPem := createServerTLSConfig(t) + s := createTestServer(t) + defer removeDataDir(t) + + s.TLSConfig = TLSConfig{ + TLSListenAddr: &net.TCPAddr{Port: 0}, + CertificateChain: string(certPem), + PrivateKey: string(keyPem), + } + + // Starting the server + err := s.Start(nil) + if err != nil { + t.Fatalf("Failed to start server: %s", err) + } + + // Add our self-signed generated config to roots + roots := x509.NewCertPool() + roots.AppendCertsFromPEM(certPem) + tlsConfig := &tls.Config{ServerName: tlsServerName, RootCAs: roots} + + // Create a DNS-over-TLS client connection + addr := s.dnsProxy.Addr(proxy.ProtoTLS) + conn, err := dns.DialWithTLS("tcp-tls", addr.String(), tlsConfig) + if err != nil { + t.Fatalf("cannot connect to the proxy: %s", err) + } + + sendTestMessages(t, conn) + + // Stop the proxy + err = s.Stop() + if err != nil { + t.Fatalf("DNS server failed to stop: %s", err) + } +} + +func TestServerRace(t *testing.T) { + s := createTestServer(t) + defer removeDataDir(t) + err := s.Start(nil) + if err != nil { + t.Fatalf("Failed to start server: %s", err) + } + + // message over UDP + addr := s.dnsProxy.Addr(proxy.ProtoUDP) + conn, err := dns.Dial("udp", addr.String()) + if err != nil { + t.Fatalf("cannot connect to the proxy: %s", err) + } + + sendTestMessagesAsync(t, conn) + + // Stop the proxy + err = s.Stop() + if err != nil { + t.Fatalf("DNS server failed to stop: %s", err) + } +} + func TestInvalidRequest(t *testing.T) { s := createTestServer(t) defer removeDataDir(t) @@ -72,7 +152,7 @@ func TestInvalidRequest(t *testing.T) { } // server is running, send a message - addr := s.dnsProxy.Addr("udp") + addr := s.dnsProxy.Addr(proxy.ProtoUDP) req := dns.Msg{} req.Id = dns.Id() req.RecursionDesired = true @@ -106,7 +186,7 @@ func TestBlockedRequest(t *testing.T) { if err != nil { t.Fatalf("Failed to start server: %s", err) } - addr := s.dnsProxy.Addr("udp") + addr := s.dnsProxy.Addr(proxy.ProtoUDP) // // NXDomain blocking @@ -147,7 +227,7 @@ func TestBlockedByHosts(t *testing.T) { if err != nil { t.Fatalf("Failed to start server: %s", err) } - addr := s.dnsProxy.Addr("udp") + addr := s.dnsProxy.Addr(proxy.ProtoUDP) // // Hosts blocking @@ -195,7 +275,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) { if err != nil { t.Fatalf("Failed to start server: %s", err) } - addr := s.dnsProxy.Addr("udp") + addr := s.dnsProxy.Addr(proxy.ProtoUDP) // // Safebrowsing blocking @@ -251,6 +331,7 @@ func createTestServer(t *testing.T) *Server { s := NewServer(createDataDir(t)) s.UDPListenAddr = &net.UDPAddr{Port: 0} s.TCPListenAddr = &net.TCPAddr{Port: 0} + s.QueryLogEnabled = true s.FilteringConfig.FilteringEnabled = true s.FilteringConfig.ProtectionEnabled = true @@ -266,20 +347,111 @@ func createTestServer(t *testing.T) *Server { return s } -func createDataDir(t *testing.T) string { - dir := "testData" - err := os.MkdirAll(dir, 0755) +func createServerTLSConfig(t *testing.T) (*tls.Config, []byte, []byte) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { - t.Fatalf("Cannot create %s: %s", dir, err) + t.Fatalf("cannot generate RSA key: %s", err) } - return dir + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + t.Fatalf("failed to generate serial number: %s", err) + } + + notBefore := time.Now() + notAfter := notBefore.Add(5 * 365 * time.Hour * 24) + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"AdGuard Tests"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + IsCA: true, + } + template.DNSNames = append(template.DNSNames, tlsServerName) + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(privateKey), privateKey) + if err != nil { + t.Fatalf("failed to create certificate: %s", err) + } + + certPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + keyPem := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}) + + cert, err := tls.X509KeyPair(certPem, keyPem) + if err != nil { + t.Fatalf("failed to create certificate: %s", err) + } + + return &tls.Config{Certificates: []tls.Certificate{cert}, ServerName: tlsServerName}, certPem, keyPem +} + +func createDataDir(t *testing.T) string { + err := os.MkdirAll(dataDir, 0755) + if err != nil { + t.Fatalf("Cannot create %s: %s", dataDir, err) + } + return dataDir } func removeDataDir(t *testing.T) { - dir := "testData" - err := os.RemoveAll(dir) + err := os.RemoveAll(dataDir) if err != nil { - t.Fatalf("Cannot remove %s: %s", dir, err) + t.Fatalf("Cannot remove %s: %s", dataDir, err) + } +} + +func sendTestMessageAsync(t *testing.T, conn *dns.Conn, g *sync.WaitGroup) { + defer func() { + g.Done() + }() + + req := createTestMessage() + err := conn.WriteMsg(req) + if err != nil { + t.Fatalf("cannot write message: %s", err) + } + + res, err := conn.ReadMsg() + if err != nil { + t.Fatalf("cannot read response to message: %s", err) + } + assertResponse(t, res) +} + +// sendTestMessagesAsync sends messages in parallel +// so that we could find race issues +func sendTestMessagesAsync(t *testing.T, conn *dns.Conn) { + g := &sync.WaitGroup{} + g.Add(testMessagesCount) + + for i := 0; i < testMessagesCount; i++ { + go sendTestMessageAsync(t, conn, g) + } + + g.Wait() +} + +func sendTestMessages(t *testing.T, conn *dns.Conn) { + for i := 0; i < 10; i++ { + req := createTestMessage() + err := conn.WriteMsg(req) + if err != nil { + t.Fatalf("cannot write message #%d: %s", i, err) + } + + res, err := conn.ReadMsg() + if err != nil { + t.Fatalf("cannot read response to message #%d: %s", i, err) + } + assertResponse(t, res) } } @@ -305,3 +477,14 @@ func assertResponse(t *testing.T, reply *dns.Msg) { t.Fatalf("DNS server returned wrong answer type instead of A: %v", reply.Answer[0]) } } + +func publicKey(priv interface{}) interface{} { + switch k := priv.(type) { + case *rsa.PrivateKey: + return &k.PublicKey + case *ecdsa.PrivateKey: + return &k.PublicKey + default: + return nil + } +}