diff --git a/internal/dhcpd/v6.go b/internal/dhcpd/v6.go index f9dab113..51788551 100644 --- a/internal/dhcpd/v6.go +++ b/internal/dhcpd/v6.go @@ -42,7 +42,6 @@ func (s *v6Server) WriteDiskConfig6(c *V6ServerConf) { } // Return TRUE if IP address is within range [start..0xff] -// nolint(staticcheck) func ip6InRange(start, ip net.IP) bool { if len(start) != 16 { return false diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index 58ea3b72..c24ba62b 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -163,15 +163,15 @@ func (s *Server) Exchange(req *dns.Msg) (*dns.Msg, error) { return ctx.Res, nil } -// Start starts the DNS server +// Start starts the DNS server. func (s *Server) Start() error { s.Lock() defer s.Unlock() - return s.startInternal() + return s.startLocked() } -// startInternal starts without locking -func (s *Server) startInternal() error { +// startLocked starts the DNS server without locking. For internal use only. +func (s *Server) startLocked() error { err := s.dnsProxy.Start() if err == nil { s.isRunning = true @@ -256,15 +256,15 @@ func (s *Server) Prepare(config *ServerConfig) error { return nil } -// Stop stops the DNS server +// Stop stops the DNS server. func (s *Server) Stop() error { s.Lock() defer s.Unlock() - return s.stopInternal() + return s.stopLocked() } -// stopInternal stops without locking -func (s *Server) stopInternal() error { +// stopLocked stops the DNS server without locking. For internal use only. +func (s *Server) stopLocked() error { if s.dnsProxy != nil { err := s.dnsProxy.Stop() if err != nil { @@ -289,7 +289,7 @@ func (s *Server) Reconfigure(config *ServerConfig) error { defer s.Unlock() log.Print("Start reconfiguring the server") - err := s.stopInternal() + err := s.stopLocked() if err != nil { return fmt.Errorf("could not reconfigure the server: %w", err) } @@ -303,7 +303,7 @@ func (s *Server) Reconfigure(config *ServerConfig) error { return fmt.Errorf("could not reconfigure the server: %w", err) } - err = s.startInternal() + err = s.startLocked() if err != nil { return fmt.Errorf("could not reconfigure the server: %w", err) } diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index 17f4b872..86356c6b 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -1,9 +1,11 @@ package dnsforward import ( + "context" "crypto/ecdsa" "crypto/rand" "crypto/rsa" + "crypto/sha256" "crypto/tls" "crypto/x509" "crypto/x509/pkix" @@ -38,82 +40,104 @@ const ( testMessagesCount = 10 ) +func startDeferStop(t *testing.T, s *Server) { + t.Helper() + + err := s.Start() + assert.Nilf(t, err, "failed to start server: %s", err) + + t.Cleanup(func() { + err := s.Stop() + assert.Nilf(t, err, "dns server failed to stop: %s", err) + }) +} + func TestServer(t *testing.T) { s := createTestServer(t) - err := s.Start() - if err != nil { - t.Fatalf("Failed to start server: %s", err) + s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ + &testUpstream{ + ipv4: map[string][]net.IP{ + "google-public-dns-a.google.com.": {{8, 8, 8, 8}}, + }, + }, } + startDeferStop(t, s) - // message over UDP - req := createGoogleATestMessage() - addr := s.dnsProxy.Addr(proxy.ProtoUDP) - client := dns.Client{Net: "udp"} - reply, _, err := client.Exchange(req, addr.String()) - if err != nil { - t.Fatalf("Couldn't talk to server %s: %s", addr, err) - } - assertGoogleAResponse(t, reply) + testCases := []struct { + name string + proto string + }{{ + name: "message_over_udp", + proto: proxy.ProtoUDP, + }, { + name: "message_over_tcp", + proto: proxy.ProtoTCP, + }} - // message over TCP - req = createGoogleATestMessage() - addr = s.dnsProxy.Addr("tcp") - client = dns.Client{Net: "tcp"} - reply, _, err = client.Exchange(req, addr.String()) - if err != nil { - t.Fatalf("Couldn't talk to server %s: %s", addr, err) - } - assertGoogleAResponse(t, reply) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + addr := s.dnsProxy.Addr(tc.proto) + client := dns.Client{Net: tc.proto} - err = s.Stop() - if err != nil { - t.Fatalf("DNS server failed to stop: %s", err) + reply, _, err := client.Exchange(createGoogleATestMessage(), addr.String()) + assert.Nilf(t, err, "сouldn't talk to server %s: %s", addr, err) + + assertGoogleAResponse(t, reply) + }) } } func TestServerWithProtectionDisabled(t *testing.T) { s := createTestServer(t) s.conf.ProtectionEnabled = false - err := s.Start() - if err != nil { - t.Fatalf("Failed to start server: %s", err) + s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ + &testUpstream{ + ipv4: map[string][]net.IP{ + "google-public-dns-a.google.com.": {{8, 8, 8, 8}}, + }, + }, } + startDeferStop(t, s) - // message over UDP + // Message over UDP. req := createGoogleATestMessage() addr := s.dnsProxy.Addr(proxy.ProtoUDP) - client := dns.Client{Net: "udp"} + client := dns.Client{Net: proxy.ProtoUDP} reply, _, err := client.Exchange(req, addr.String()) - if err != nil { - t.Fatalf("Couldn't talk to server %s: %s", addr, err) - } + assert.Nilf(t, err, "сouldn't talk to server %s: %s", addr, err) assertGoogleAResponse(t, reply) - - err = s.Stop() - if err != nil { - t.Fatalf("DNS server failed to stop: %s", err) - } } -func TestDotServer(t *testing.T) { - // Prepare the proxy server - _, certPem, keyPem := createServerTLSConfig(t) - s := createTestServer(t) +func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte) { + t.Helper() - s.conf.TLSConfig = TLSConfig{ - TLSListenAddr: &net.TCPAddr{Port: 0}, - CertificateChainData: certPem, - PrivateKeyData: keyPem, + var keyPem []byte + _, certPem, keyPem = createServerTLSConfig(t) + s = createTestServer(t) + + tlsConf.CertificateChainData, tlsConf.PrivateKeyData = certPem, keyPem + s.conf.TLSConfig = tlsConf + + err := s.Prepare(nil) + assert.Nilf(t, err, "failed to prepare server: %s", err) + + return s, certPem +} + +func TestDoTServer(t *testing.T) { + s, certPem := createTestTLS(t, TLSConfig{ + TLSListenAddr: &net.TCPAddr{Port: 0}, + }) + s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ + &testUpstream{ + ipv4: map[string][]net.IP{ + "google-public-dns-a.google.com.": {{8, 8, 8, 8}}, + }, + }, } + startDeferStop(t, s) - _ = s.Prepare(nil) - // Starting the server - err := s.Start() - if err != nil { - t.Fatalf("Failed to start server: %s", err) - } - - // Add our self-signed generated config to roots + // Add our self-signed generated config to roots. roots := x509.NewCertPool() roots.AppendCertsFromPEM(certPem) tlsConfig := &tls.Config{ @@ -122,275 +146,259 @@ func TestDotServer(t *testing.T) { MinVersion: tls.VersionTLS12, } - // Create a DNS-over-TLS client connection + // Create a DNS-over-TLS client connection. addr := s.dnsProxy.Addr(proxy.ProtoTLS) conn, err := dns.DialWithTLS("tcp-tls", addr.String(), tlsConfig) - if err != nil { - t.Fatalf("cannot connect to the proxy: %s", err) - } + assert.Nilf(t, err, "cannot connect to the proxy: %s", err) sendTestMessages(t, conn) - - // Stop the proxy - err = s.Stop() - if err != nil { - t.Fatalf("DNS server failed to stop: %s", err) - } } -func TestDoqServer(t *testing.T) { - // Prepare the proxy server - _, certPem, keyPem := createServerTLSConfig(t) - s := createTestServer(t) - - s.conf.TLSConfig = TLSConfig{ - QUICListenAddr: &net.UDPAddr{Port: 0}, - CertificateChainData: certPem, - PrivateKeyData: keyPem, +func TestDoQServer(t *testing.T) { + s, _ := createTestTLS(t, TLSConfig{ + QUICListenAddr: &net.UDPAddr{Port: 0}, + }) + s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ + &testUpstream{ + ipv4: map[string][]net.IP{ + "google-public-dns-a.google.com.": {{8, 8, 8, 8}}, + }, + }, } + startDeferStop(t, s) - _ = s.Prepare(nil) - // Starting the server - err := s.Start() - assert.Nil(t, err) - - // Create a DNS-over-QUIC upstream + // Create a DNS-over-QUIC upstream. addr := s.dnsProxy.Addr(proxy.ProtoQUIC) opts := upstream.Options{InsecureSkipVerify: true} - u, err := upstream.AddressToUpstream(fmt.Sprintf("quic://%s", addr), opts) + u, err := upstream.AddressToUpstream(fmt.Sprintf("%s://%s", proxy.ProtoQUIC, addr), opts) assert.Nil(t, err) - // Send the test message + // Send the test message. req := createGoogleATestMessage() res, err := u.Exchange(req) assert.Nil(t, err) - assertGoogleAResponse(t, res) - // Stop the proxy - err = s.Stop() - if err != nil { - t.Fatalf("DNS server failed to stop: %s", err) - } + assertGoogleAResponse(t, res) } func TestServerRace(t *testing.T) { - s := createTestServer(t) - err := s.Start() - if err != nil { - t.Fatalf("Failed to start server: %s", err) - } + t.Skip("TODO(e.burkov): inspect the golibs/cache package for locks") - // message over UDP - addr := s.dnsProxy.Addr(proxy.ProtoUDP) - conn, err := dns.Dial("udp", addr.String()) - if err != nil { - t.Fatalf("cannot connect to the proxy: %s", err) + s := createTestServer(t) + s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ + &testUpstream{ + ipv4: map[string][]net.IP{ + "google-public-dns-a.google.com.": {{8, 8, 8, 8}}, + }, + }, } + startDeferStop(t, s) + + // Message over UDP. + addr := s.dnsProxy.Addr(proxy.ProtoUDP) + conn, err := dns.Dial(proxy.ProtoUDP, addr.String()) + assert.Nilf(t, err, "cannot connect to the proxy: %s", err) sendTestMessagesAsync(t, conn) +} - // Stop the proxy - err = s.Stop() - if err != nil { - t.Fatalf("DNS server failed to stop: %s", err) +// testResolver is a Resolver for tests. +// +//lint:ignore U1000 TODO(e.burkov): move into aghtest package. +type testResolver struct{} + +// LookupIPAddr implements Resolver interface for *testResolver. +// +//lint:ignore U1000 TODO(e.burkov): move into aghtest package. +func (r *testResolver) LookupIPAddr(_ context.Context, host string) (ips []net.IPAddr, err error) { + hash := sha256.Sum256([]byte(host)) + addrs := []net.IPAddr{{ + IP: net.IP(hash[:4]), + Zone: "somezone", + }, { + IP: net.IP(hash[4:20]), + Zone: "somezone", + }} + return addrs, nil +} + +// LookupHost implements Resolver interface for *testResolver. +// +//lint:ignore U1000 TODO(e.burkov): move into aghtest package. +func (r *testResolver) LookupHost(host string) (addrs []string, err error) { + hash := sha256.Sum256([]byte(host)) + addrs = []string{ + net.IP(hash[:4]).String(), + net.IP(hash[4:20]).String(), } + return addrs, nil } func TestSafeSearch(t *testing.T) { - s := createTestServer(t) - err := s.Start() - if err != nil { - t.Fatalf("Failed to start server: %s", err) + t.Skip("TODO(e.burkov): substitute the dnsfilter by one with custom resolver from aghtest") + + testUpstreamIP := net.IP{213, 180, 193, 56} + testCases := []string{ + "yandex.com.", + "yandex.by.", + "yandex.kz.", + "yandex.ru.", + "www.google.com.", + "www.google.com.af.", + "www.google.be.", + "www.google.by.", } - // Test safe search for yandex. We already know safe search ip - addr := s.dnsProxy.Addr(proxy.ProtoUDP) - client := dns.Client{Net: "udp"} - yandexDomains := []string{"yandex.com.", "yandex.by.", "yandex.kz.", "yandex.ru.", "yandex.com."} - for _, host := range yandexDomains { - exchangeAndAssertResponse(t, &client, addr, host, "213.180.193.56") - } + for _, tc := range testCases { + t.Run("safe_search_"+tc, func(t *testing.T) { + s := createTestServer(t) + startDeferStop(t, s) - // Let's lookup for google safesearch ip - ips, err := net.LookupIP("forcesafesearch.google.com") - if err != nil { - t.Fatalf("Failed to lookup for forcesafesearch.google.com: %s", err) - } + addr := s.dnsProxy.Addr(proxy.ProtoUDP) + client := dns.Client{Net: proxy.ProtoUDP} - ip := ips[0] - for _, i := range ips { - if i.To4() != nil { - ip = i - break - } - } - - // Test safe search for google. - googleDomains := []string{"www.google.com.", "www.google.com.af.", "www.google.be.", "www.google.by."} - for _, host := range googleDomains { - exchangeAndAssertResponse(t, &client, addr, host, ip.String()) - } - - err = s.Stop() - if err != nil { - t.Fatalf("Can not stopd server cause: %s", err) + exchangeAndAssertResponse(t, &client, addr, tc, testUpstreamIP) + }) } } func TestInvalidRequest(t *testing.T) { s := createTestServer(t) - err := s.Start() - if err != nil { - t.Fatalf("Failed to start server: %s", err) + startDeferStop(t, s) + + addr := s.dnsProxy.Addr(proxy.ProtoUDP).String() + req := dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: dns.Id(), + RecursionDesired: true, + }, } - // server is running, send a message - addr := s.dnsProxy.Addr(proxy.ProtoUDP) - req := dns.Msg{} - req.Id = dns.Id() - req.RecursionDesired = true + // Send a DNS request without question. + _, _, err := (&dns.Client{ + Net: proxy.ProtoUDP, + Timeout: 500 * time.Millisecond, + }).Exchange(&req, addr) - // send a DNS request without question - client := dns.Client{Net: "udp", Timeout: 500 * time.Millisecond} - _, _, err = client.Exchange(&req, addr.String()) - if err != nil { - t.Fatalf("got a response to an invalid query") - } - - err = s.Stop() - if err != nil { - t.Fatalf("DNS server failed to stop: %s", err) - } + assert.Nil(t, err, "got a response to an invalid query") } func TestBlockedRequest(t *testing.T) { s := createTestServer(t) - err := s.Start() - if err != nil { - t.Fatalf("Failed to start server: %s", err) - } + startDeferStop(t, s) + addr := s.dnsProxy.Addr(proxy.ProtoUDP) - // - // Default blocking - NULL IP - // - req := dns.Msg{} - req.Id = dns.Id() - req.RecursionDesired = true - req.Question = []dns.Question{ - {Name: "nxdomain.example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, - } + // Default blocking. + req := createTestMessage("nxdomain.example.org.") - reply, err := dns.Exchange(&req, addr.String()) - if err != nil { - t.Fatalf("Couldn't talk to server %s: %s", addr, err) - } + reply, err := dns.Exchange(req, addr.String()) + assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) assert.Equal(t, dns.RcodeSuccess, reply.Rcode) - assert.True(t, reply.Answer[0].(*dns.A).A.Equal(net.IP{0, 0, 0, 0})) - - err = s.Stop() - if err != nil { - t.Fatalf("DNS server failed to stop: %s", err) - } + assert.True(t, reply.Answer[0].(*dns.A).A.IsUnspecified()) } func TestServerCustomClientUpstream(t *testing.T) { s := createTestServer(t) s.conf.GetCustomUpstreamByClient = func(_ string) *proxy.UpstreamConfig { - uc := &proxy.UpstreamConfig{} - u := &testUpstream{} - u.ipv4 = map[string][]net.IP{} - u.ipv4["host."] = []net.IP{{192, 168, 0, 1}} - uc.Upstreams = append(uc.Upstreams, u) - return uc + return &proxy.UpstreamConfig{ + Upstreams: []upstream.Upstream{ + &testUpstream{ + ipv4: map[string][]net.IP{ + "host.": {{192, 168, 0, 1}}, + }, + }, + }, + } } - - assert.Nil(t, s.Start()) + startDeferStop(t, s) addr := s.dnsProxy.Addr(proxy.ProtoUDP) - // Send test request - req := dns.Msg{} - req.Id = dns.Id() - req.RecursionDesired = true - req.Question = []dns.Question{ - {Name: "host.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, - } + // Send test request. + req := createTestMessage("host.") - reply, err := dns.Exchange(&req, addr.String()) + reply, err := dns.Exchange(req, addr.String()) assert.Nil(t, err) assert.Equal(t, dns.RcodeSuccess, reply.Rcode) - assert.NotNil(t, reply.Answer) - assert.True(t, net.IP{192, 168, 0, 1}.Equal(reply.Answer[0].(*dns.A).A)) - assert.Nil(t, s.Stop()) + assert.NotEmpty(t, reply.Answer) + + assert.Equal(t, net.IP{192, 168, 0, 1}, reply.Answer[0].(*dns.A).A) } -// testUpstream is a mock of real upstream. -// specify fields with necessary values to simulate real upstream behaviour +// testUpstream is a mock of real upstream. specify fields with necessary values +// to simulate real upstream behaviour. +// +// TODO(e.burkov): move into aghtest package. type testUpstream struct { - cn map[string]string // Map of [name]canonical_name - ipv4 map[string][]net.IP // Map of [name]IPv4 - ipv6 map[string][]net.IP // Map of [name]IPv6 + // cn is a map of hostname to canonical name. + cn map[string]string + // ipv4 is a map of hostname to IPv4. + ipv4 map[string][]net.IP + // ipv6 is a map of hostname to IPv6. + ipv6 map[string][]net.IP } +// Exchange implements upstream.Upstream interface for *testUpstream. func (u *testUpstream) Exchange(m *dns.Msg) (*dns.Msg, error) { - resp := dns.Msg{} + resp := &dns.Msg{} resp.SetReply(m) - hasARecord := false - hasAAAARecord := false + hasRec := false - reqType := m.Question[0].Qtype name := m.Question[0].Name - // Let's check if we have any CNAME for given name if cname, ok := u.cn[name]; ok { - cn := dns.CNAME{} - cn.Hdr.Name = name - cn.Hdr.Rrtype = dns.TypeCNAME - cn.Target = cname - resp.Answer = append(resp.Answer, &cn) + resp.Answer = append(resp.Answer, &dns.CNAME{ + Hdr: dns.RR_Header{ + Name: name, + Rrtype: dns.TypeCNAME, + }, + Target: cname, + }) } - // Let's check if we can add some A records to the answer - if ipv4addr, ok := u.ipv4[name]; ok && reqType == dns.TypeA { - hasARecord = true - for _, ipv4 := range ipv4addr { - respA := dns.A{} - respA.Hdr.Rrtype = dns.TypeA - respA.Hdr.Name = name - respA.A = ipv4 - resp.Answer = append(resp.Answer, &respA) + var rrtype uint16 + var a []net.IP + switch m.Question[0].Qtype { + case dns.TypeA: + rrtype = dns.TypeA + if ipv4addr, ok := u.ipv4[name]; ok { + hasRec = true + a = ipv4addr + } + case dns.TypeAAAA: + rrtype = dns.TypeAAAA + if ipv6addr, ok := u.ipv6[name]; ok { + hasRec = true + a = ipv6addr } } - - // Let's check if we can add some AAAA records to the answer - if u.ipv6 != nil { - if ipv6addr, ok := u.ipv6[name]; ok && reqType == dns.TypeAAAA { - hasAAAARecord = true - for _, ipv6 := range ipv6addr { - respAAAA := dns.A{} - respAAAA.Hdr.Rrtype = dns.TypeAAAA - respAAAA.Hdr.Name = name - respAAAA.A = ipv6 - resp.Answer = append(resp.Answer, &respAAAA) - } - } + for _, ip := range a { + resp.Answer = append(resp.Answer, &dns.A{ + Hdr: dns.RR_Header{ + Name: name, + Rrtype: rrtype, + }, + A: ip, + }) } if len(resp.Answer) == 0 { - if hasARecord || hasAAAARecord { - // Set No Error RCode if there are some records for given Qname but we didn't apply them + if hasRec { + // Set no error RCode if there are some records for + // given Qname but we didn't apply them. resp.SetRcode(m, dns.RcodeSuccess) - } else { - // Set NXDomain RCode otherwise - resp.SetRcode(m, dns.RcodeNameError) + return resp, nil } + // Set NXDomain RCode otherwise. + resp.SetRcode(m, dns.RcodeNameError) } - return &resp, nil + return resp, nil } +// Address implements upstream.Upstream interface for *testUpstream. func (u *testUpstream) Address() string { return "test" } @@ -422,14 +430,18 @@ var testIPv4 = map[string][]net.IP{ func TestBlockCNAMEProtectionEnabled(t *testing.T) { s := createTestServer(t) - testUpstm := &testUpstream{testCNAMEs, testIPv4, nil} + testUpstm := &testUpstream{ + cn: testCNAMEs, + ipv4: testIPv4, + ipv6: nil, + } s.conf.ProtectionEnabled = false err := s.startWithUpstream(testUpstm) assert.Nil(t, err) addr := s.dnsProxy.Addr(proxy.ProtoUDP) - // 'badhost' has a canonical name 'null.example.org' which is blocked by filters: - // but protection is disabled - response is NOT blocked + // 'badhost' has a canonical name 'null.example.org' which is blocked by + // filters: but protection is disabled so response is _not_ blocked. req := createTestMessage("badhost.") reply, err := dns.Exchange(req, addr.String()) assert.Nil(t, err) @@ -438,57 +450,81 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) { func TestBlockCNAME(t *testing.T) { s := createTestServer(t) - testUpstm := &testUpstream{testCNAMEs, testIPv4, nil} - err := s.startWithUpstream(testUpstm) - assert.Nil(t, err) - addr := s.dnsProxy.Addr(proxy.ProtoUDP) + s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ + &testUpstream{ + cn: testCNAMEs, + ipv4: testIPv4, + }, + } + startDeferStop(t, s) - // 'badhost' has a canonical name 'null.example.org' which is blocked by filters: - // response is blocked - req := createTestMessage("badhost.") - reply, err := dns.Exchange(req, addr.String()) - assert.Nil(t, err) - assert.Equal(t, dns.RcodeSuccess, reply.Rcode) - assert.True(t, reply.Answer[0].(*dns.A).A.Equal(net.IP{0, 0, 0, 0})) + addr := s.dnsProxy.Addr(proxy.ProtoUDP).String() - // 'whitelist.example.org' has a canonical name 'null.example.org' which is blocked by filters - // but 'whitelist.example.org' is in a whitelist: - // response isn't blocked - req = createTestMessage("whitelist.example.org.") - reply, err = dns.Exchange(req, addr.String()) - assert.Nil(t, err) - assert.Equal(t, dns.RcodeSuccess, reply.Rcode) + testCases := []struct { + host string + want bool + }{{ + host: "badhost.", + // 'badhost' has a canonical name 'null.example.org' which is + // blocked by filters: response is blocked. + want: true, + }, { + host: "whitelist.example.org.", + // 'whitelist.example.org' has a canonical name + // 'null.example.org' which is blocked by filters + // but 'whitelist.example.org' is in a whitelist: + // response isn't blocked. + want: false, + }, { + host: "example.org.", + // 'example.org' has a canonical name 'cname1' with IP + // 127.0.0.255 which is blocked by filters: response is blocked. + want: true, + }} - // 'example.org' has a canonical name 'cname1' with IP 127.0.0.255 which is blocked by filters: - // response is blocked - req = createTestMessage("example.org.") - reply, err = dns.Exchange(req, addr.String()) - assert.Nil(t, err) - assert.Equal(t, dns.RcodeSuccess, reply.Rcode) - assert.True(t, reply.Answer[0].(*dns.A).A.Equal(net.IP{0, 0, 0, 0})) - - _ = s.Stop() + for _, tc := range testCases { + t.Run("block_cname_"+tc.host, func(t *testing.T) { + req := createTestMessage(tc.host) + reply, err := dns.Exchange(req, addr) + assert.Nil(t, err) + assert.Equal(t, dns.RcodeSuccess, reply.Rcode) + if tc.want { + assert.True(t, reply.Answer[0].(*dns.A).A.IsUnspecified()) + } + }) + } } func TestClientRulesForCNAMEMatching(t *testing.T) { s := createTestServer(t) - testUpstm := &testUpstream{testCNAMEs, testIPv4, nil} s.conf.FilterHandler = func(_ net.IP, _ string, settings *dnsfilter.RequestFilteringSettings) { settings.FilteringEnabled = false } - err := s.startWithUpstream(testUpstm) - assert.Nil(t, err) + s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ + &testUpstream{ + cn: testCNAMEs, + ipv4: testIPv4, + }, + } + startDeferStop(t, s) + addr := s.dnsProxy.Addr(proxy.ProtoUDP) - // 'badhost' has a canonical name 'null.example.org' which is blocked by filters: - // response is blocked - req := dns.Msg{} - req.Id = dns.Id() - req.Question = []dns.Question{ - {Name: "badhost.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + // 'badhost' has a canonical name 'null.example.org' which is blocked by + // filters: response is blocked. + req := dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: dns.Id(), + }, + Question: []dns.Question{{ + Name: "badhost.", + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }}, } - // However, in our case it should not be blocked - // as filtering is disabled on the client level + + // However, in our case it should not be blocked as filtering is + // disabled on the client level. reply, err := dns.Exchange(&req, addr.String()) assert.Nil(t, err) assert.Equal(t, dns.RcodeSuccess, reply.Rcode) @@ -497,68 +533,58 @@ func TestClientRulesForCNAMEMatching(t *testing.T) { func TestNullBlockedRequest(t *testing.T) { s := createTestServer(t) s.conf.FilteringConfig.BlockingMode = "null_ip" - err := s.Start() - if err != nil { - t.Fatalf("Failed to start server: %s", err) - } + startDeferStop(t, s) addr := s.dnsProxy.Addr(proxy.ProtoUDP) - // - // Null filter blocking - // - req := dns.Msg{} - req.Id = dns.Id() - req.RecursionDesired = true - req.Question = []dns.Question{ - {Name: "null.example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + // Nil filter blocking. + req := dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: dns.Id(), + RecursionDesired: true, + }, + Question: []dns.Question{{ + Name: "null.example.org.", + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }}, } reply, err := dns.Exchange(&req, addr.String()) - if err != nil { - t.Fatalf("Couldn't talk to server %s: %s", addr, err) - } - if len(reply.Answer) != 1 { - t.Fatalf("DNS server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer)) - } - if a, ok := reply.Answer[0].(*dns.A); ok { - if !net.IPv4zero.Equal(a.A) { - t.Fatalf("DNS server %s returned wrong answer instead of 0.0.0.0: %v", addr, a.A) - } - } else { - t.Fatalf("DNS server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0]) - } - - err = s.Stop() - if err != nil { - t.Fatalf("DNS server failed to stop: %s", err) - } + assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) + assert.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer)) + a, ok := reply.Answer[0].(*dns.A) + assert.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0]) + assert.Truef(t, a.A.IsUnspecified(), "dns server %s returned wrong answer instead of 0.0.0.0: %v", addr, a.A) } func TestBlockedCustomIP(t *testing.T) { rules := "||nxdomain.example.org^\n||null.example.org^\n127.0.0.1 host.example.org\n@@||whitelist.example.org^\n||127.0.0.255\n" filters := []dnsfilter.Filter{{ - ID: 0, Data: []byte(rules), + ID: 0, + Data: []byte(rules), }} - c := dnsfilter.Config{} - f := dnsfilter.New(&c, filters) - s := NewServer(DNSCreateParams{DNSFilter: f}) - conf := ServerConfig{} - conf.UDPListenAddr = &net.UDPAddr{Port: 0} - conf.TCPListenAddr = &net.TCPAddr{Port: 0} - conf.ProtectionEnabled = true - conf.BlockingMode = "custom_ip" - conf.BlockingIPv4 = nil - conf.UpstreamDNS = []string{"8.8.8.8:53", "8.8.4.4:53"} - err := s.Prepare(&conf) - assert.NotNil(t, err) // invalid BlockingIPv4 + s := NewServer(DNSCreateParams{ + DNSFilter: dnsfilter.New(&dnsfilter.Config{}, filters), + }) + conf := ServerConfig{ + UDPListenAddr: &net.UDPAddr{Port: 0}, + TCPListenAddr: &net.TCPAddr{Port: 0}, + FilteringConfig: FilteringConfig{ + ProtectionEnabled: true, + BlockingMode: "custom_ip", + BlockingIPv4: nil, + UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"}, + }, + } + // Invalid BlockingIPv4. + assert.NotNil(t, s.Prepare(&conf)) conf.BlockingIPv4 = net.IP{0, 0, 0, 1} conf.BlockingIPv6 = net.ParseIP("::1") - err = s.Prepare(&conf) - assert.Nil(t, err) - err = s.Start() - assert.Nil(t, err) + assert.Nil(t, s.Prepare(&conf)) + + startDeferStop(t, s) addr := s.dnsProxy.Addr(proxy.ProtoUDP) @@ -577,134 +603,102 @@ func TestBlockedCustomIP(t *testing.T) { a6, ok := reply.Answer[0].(*dns.AAAA) assert.True(t, ok) assert.Equal(t, "::1", a6.AAAA.String()) - - err = s.Stop() - if err != nil { - t.Fatalf("DNS server failed to stop: %s", err) - } } func TestBlockedByHosts(t *testing.T) { s := createTestServer(t) - err := s.Start() - if err != nil { - t.Fatalf("Failed to start server: %s", err) - } + startDeferStop(t, s) addr := s.dnsProxy.Addr(proxy.ProtoUDP) - // - // Hosts blocking - // - req := dns.Msg{} - req.Id = dns.Id() - req.RecursionDesired = true - req.Question = []dns.Question{ - {Name: "host.example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, - } + // Hosts blocking. + req := createTestMessage("host.example.org.") - reply, err := dns.Exchange(&req, addr.String()) - if err != nil { - t.Fatalf("Couldn't talk to server %s: %s", addr, err) - } - if len(reply.Answer) != 1 { - t.Fatalf("DNS server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer)) - } - if a, ok := reply.Answer[0].(*dns.A); ok { - if !net.IPv4(127, 0, 0, 1).Equal(a.A) { - t.Fatalf("DNS server %s returned wrong answer instead of 8.8.8.8: %v", addr, a.A) - } - } else { - t.Fatalf("DNS server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0]) - } + reply, err := dns.Exchange(req, addr.String()) + assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) + assert.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer)) - err = s.Stop() - if err != nil { - t.Fatalf("DNS server failed to stop: %s", err) - } + a, ok := reply.Answer[0].(*dns.A) + assert.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0]) + assert.Equalf(t, net.IP{127, 0, 0, 1}, a.A, "dns server %s returned wrong answer instead of 8.8.8.8: %v", addr, a.A) } func TestBlockedBySafeBrowsing(t *testing.T) { + t.Skip("TODO(e.burkov): substitute the dnsfilter by one with custom safeBrowsingUpstream") + resolver := &testResolver{} + ips, _ := resolver.LookupIPAddr(context.Background(), safeBrowsingBlockHost) + addrs, _ := resolver.LookupHost(safeBrowsingBlockHost) + s := createTestServer(t) - err := s.Start() - if err != nil { - t.Fatalf("Failed to start server: %s", err) + s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ + &testUpstream{ + ipv4: map[string][]net.IP{ + "wmconvirus.narod.ru.": {ips[0].IP}, + }, + }, } + startDeferStop(t, s) addr := s.dnsProxy.Addr(proxy.ProtoUDP) - // - // Safebrowsing blocking - // - req := dns.Msg{} - req.Id = dns.Id() - req.RecursionDesired = true - req.Question = []dns.Question{ - {Name: "wmconvirus.narod.ru.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, - } - reply, err := dns.Exchange(&req, addr.String()) - if err != nil { - t.Fatalf("Couldn't talk to server %s: %s", addr, err) - } - if len(reply.Answer) != 1 { - t.Fatalf("DNS server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer)) - } - if a, ok := reply.Answer[0].(*dns.A); ok { - addrs, lookupErr := net.LookupHost(safeBrowsingBlockHost) - if lookupErr != nil { - t.Fatalf("cannot resolve %s due to %s", safeBrowsingBlockHost, lookupErr) - } + // SafeBrowsing blocking. + req := createTestMessage("wmconvirus.narod.ru.") + reply, err := dns.Exchange(req, addr.String()) + assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) + assert.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer)) + + a, ok := reply.Answer[0].(*dns.A) + if assert.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0]) { found := false for _, blockAddr := range addrs { if blockAddr == a.A.String() { found = true + break } } - - if !found { - t.Fatalf("DNS server %s returned wrong answer: %v", addr, a.A) - } - } else { - t.Fatalf("DNS server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0]) - } - - err = s.Stop() - if err != nil { - t.Fatalf("DNS server failed to stop: %s", err) + assert.Truef(t, found, "dns server %s returned wrong answer: %v", addr, a.A) } } func TestRewrite(t *testing.T) { - c := dnsfilter.Config{} - c.Rewrites = []dnsfilter.RewriteEntry{ - { + c := &dnsfilter.Config{ + Rewrites: []dnsfilter.RewriteEntry{{ Domain: "test.com", Answer: "1.2.3.4", Type: dns.TypeA, - }, - { + }, { Domain: "alias.test.com", Answer: "test.com", Type: dns.TypeCNAME, - }, - { + }, { Domain: "my.alias.example.org", Answer: "example.org", Type: dns.TypeCNAME, + }}, + } + f := dnsfilter.New(c, nil) + + s := NewServer(DNSCreateParams{DNSFilter: f}) + err := s.Prepare(&ServerConfig{ + UDPListenAddr: &net.UDPAddr{Port: 0}, + TCPListenAddr: &net.TCPAddr{Port: 0}, + FilteringConfig: FilteringConfig{ + ProtectionEnabled: true, + UpstreamDNS: []string{"8.8.8.8:53"}, + }, + }) + s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ + &testUpstream{ + cn: map[string]string{ + "example.org": "somename", + }, + ipv4: map[string][]net.IP{ + "example.org.": {{4, 3, 2, 1}}, + }, }, } - - f := dnsfilter.New(&c, nil) - s := NewServer(DNSCreateParams{DNSFilter: f}) - conf := ServerConfig{} - conf.UDPListenAddr = &net.UDPAddr{Port: 0} - conf.TCPListenAddr = &net.TCPAddr{Port: 0} - conf.ProtectionEnabled = true - conf.UpstreamDNS = []string{"8.8.8.8:53"} - - err := s.Prepare(&conf) - assert.Nil(t, err) - err = s.Start() assert.Nil(t, err) + startDeferStop(t, s) + addr := s.dnsProxy.Addr(proxy.ProtoUDP) req := createTestMessageWithType("test.com.", dns.TypeA) @@ -734,8 +728,6 @@ func TestRewrite(t *testing.T) { assert.Len(t, reply.Answer, 2) assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target) assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype) - - _ = s.Stop() } func createTestServer(t *testing.T) *Server { @@ -747,22 +739,27 @@ func createTestServer(t *testing.T) *Server { filters := []dnsfilter.Filter{{ ID: 0, Data: []byte(rules), }} - c := dnsfilter.Config{} - c.SafeBrowsingEnabled = true - c.SafeBrowsingCacheSize = 1000 - c.SafeSearchEnabled = true - c.SafeSearchCacheSize = 1000 - c.ParentalCacheSize = 1000 - c.CacheTime = 30 + c := dnsfilter.Config{ + SafeBrowsingEnabled: true, + SafeBrowsingCacheSize: 1000, + SafeSearchEnabled: true, + SafeSearchCacheSize: 1000, + ParentalCacheSize: 1000, + CacheTime: 30, + } f := dnsfilter.New(&c, filters) s := NewServer(DNSCreateParams{DNSFilter: f}) - s.conf.UDPListenAddr = &net.UDPAddr{Port: 0} - s.conf.TCPListenAddr = &net.TCPAddr{Port: 0} - s.conf.UpstreamDNS = []string{"8.8.8.8:53", "8.8.4.4:53"} - s.conf.FilteringConfig.ProtectionEnabled = true - s.conf.ConfigModified = func() {} + s.conf = ServerConfig{ + UDPListenAddr: &net.UDPAddr{Port: 0}, + TCPListenAddr: &net.TCPAddr{Port: 0}, + FilteringConfig: FilteringConfig{ + ProtectionEnabled: true, + UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"}, + }, + ConfigModified: func() {}, + } err := s.Prepare(nil) assert.Nil(t, err) @@ -770,16 +767,14 @@ func createTestServer(t *testing.T) *Server { } func createServerTLSConfig(t *testing.T) (*tls.Config, []byte, []byte) { + t.Helper() + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - t.Fatalf("cannot generate RSA key: %s", err) - } + assert.Nilf(t, err, "cannot generate RSA key: %s", err) serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) - if err != nil { - t.Fatalf("failed to generate serial number: %s", err) - } + assert.Nilf(t, err, "failed to generate serial number: %s", err) notBefore := time.Now() notAfter := notBefore.Add(5 * 365 * time.Hour * 24) @@ -800,76 +795,66 @@ func createServerTLSConfig(t *testing.T) (*tls.Config, []byte, []byte) { template.DNSNames = append(template.DNSNames, tlsServerName) derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(privateKey), privateKey) - if err != nil { - t.Fatalf("failed to create certificate: %s", err) - } + assert.Nilf(t, err, "failed to create certificate: %s", err) certPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) keyPem := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}) cert, err := tls.X509KeyPair(certPem, keyPem) - if err != nil { - t.Fatalf("failed to create certificate: %s", err) - } + assert.Nilf(t, err, "failed to create certificate: %s", err) - return &tls.Config{Certificates: []tls.Certificate{cert}, ServerName: tlsServerName, MinVersion: tls.VersionTLS12}, certPem, keyPem + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + ServerName: tlsServerName, + MinVersion: tls.VersionTLS12, + }, certPem, keyPem } -func sendTestMessageAsync(t *testing.T, conn *dns.Conn, g *sync.WaitGroup) { - defer func() { - g.Done() - }() - - req := createGoogleATestMessage() - err := conn.WriteMsg(req) - if err != nil { - panic(fmt.Sprintf("cannot write message: %s", err)) - } - - res, err := conn.ReadMsg() - if err != nil { - panic(fmt.Sprintf("cannot read response to message: %s", err)) - } - assertGoogleAResponse(t, res) -} - -// sendTestMessagesAsync sends messages in parallel -// so that we could find race issues +// sendTestMessagesAsync sends messages in parallel to check for race issues. +//lint:ignore U1000 it's called from the function which is skipped for now. func sendTestMessagesAsync(t *testing.T, conn *dns.Conn) { - g := &sync.WaitGroup{} - g.Add(testMessagesCount) + wg := &sync.WaitGroup{} for i := 0; i < testMessagesCount; i++ { - go sendTestMessageAsync(t, conn, g) + msg := createGoogleATestMessage() + wg.Add(1) + + go func() { + defer wg.Done() + + err := conn.WriteMsg(msg) + assert.Nilf(t, err, "cannot write message: %s", err) + + res, err := conn.ReadMsg() + assert.Nilf(t, err, "cannot read response to message: %s", err) + + assertGoogleAResponse(t, res) + }() } - g.Wait() + wg.Wait() } func sendTestMessages(t *testing.T, conn *dns.Conn) { - for i := 0; i < 10; i++ { + t.Helper() + + for i := 0; i < testMessagesCount; i++ { req := createGoogleATestMessage() err := conn.WriteMsg(req) - if err != nil { - t.Fatalf("cannot write message #%d: %s", i, err) - } + assert.Nilf(t, err, "cannot write message #%d: %s", i, err) res, err := conn.ReadMsg() - if err != nil { - t.Fatalf("cannot read response to message #%d: %s", i, err) - } + assert.Nilf(t, err, "cannot read response to message #%d: %s", i, err) assertGoogleAResponse(t, res) } } -func exchangeAndAssertResponse(t *testing.T, client *dns.Client, addr net.Addr, host, ip string) { +func exchangeAndAssertResponse(t *testing.T, client *dns.Client, addr net.Addr, host string, ip net.IP) { t.Helper() req := createTestMessage(host) reply, _, err := client.Exchange(req, addr.String()) - if err != nil { - t.Fatalf("Couldn't talk to server %s: %s", addr, err) - } + assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) assertResponse(t, reply, ip) } @@ -878,41 +863,36 @@ func createGoogleATestMessage() *dns.Msg { } func createTestMessage(host string) *dns.Msg { - req := dns.Msg{} - req.Id = dns.Id() - req.RecursionDesired = true - req.Question = []dns.Question{ - {Name: host, Qtype: dns.TypeA, Qclass: dns.ClassINET}, + return &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: dns.Id(), + RecursionDesired: true, + }, + Question: []dns.Question{{ + Name: host, + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }}, } - return &req } func createTestMessageWithType(host string, qtype uint16) *dns.Msg { - req := dns.Msg{} - req.Id = dns.Id() - req.RecursionDesired = true - req.Question = []dns.Question{ - {Name: host, Qtype: qtype, Qclass: dns.ClassINET}, - } - return &req + req := createTestMessage(host) + req.Question[0].Qtype = qtype + return req } func assertGoogleAResponse(t *testing.T, reply *dns.Msg) { - assertResponse(t, reply, "8.8.8.8") + assertResponse(t, reply, net.IP{8, 8, 8, 8}) } -func assertResponse(t *testing.T, reply *dns.Msg, ip string) { +func assertResponse(t *testing.T, reply *dns.Msg, ip net.IP) { t.Helper() - if len(reply.Answer) != 1 { - t.Fatalf("DNS server returned reply with wrong number of answers - %d", len(reply.Answer)) - } - if a, ok := reply.Answer[0].(*dns.A); ok { - if !net.ParseIP(ip).Equal(a.A) { - t.Fatalf("DNS server returned wrong answer instead of %s: %v", ip, a.A) - } - } else { - t.Fatalf("DNS server returned wrong answer type instead of A: %v", reply.Answer[0]) + assert.Lenf(t, reply.Answer, 1, "dns server returned reply with wrong number of answers - %d", len(reply.Answer)) + a, ok := reply.Answer[0].(*dns.A) + if assert.Truef(t, ok, "dns server returned wrong answer type instead of A: %v", reply.Answer[0]) { + assert.Truef(t, a.A.Equal(ip), "dns server returned wrong answer instead of %s: %s", ip, a.A) } } @@ -928,68 +908,99 @@ func publicKey(priv interface{}) interface{} { } func TestValidateUpstream(t *testing.T) { - invalidUpstreams := []string{ - "1.2.3.4.5", - "123.3.7m", - "htttps://google.com/dns-query", - "[/host.com]tls://dns.adguard.com", - "[host.ru]#", - } + testCases := []struct { + name string + upstream string + valid bool + wantDef bool + }{{ + name: "invalid", + upstream: "1.2.3.4.5", + valid: false, + }, { + name: "invalid", + upstream: "123.3.7m", + valid: false, + }, { + name: "invalid", + upstream: "htttps://google.com/dns-query", + valid: false, + }, { + name: "invalid", + upstream: "[/host.com]tls://dns.adguard.com", + valid: false, + }, { + name: "invalid", + upstream: "[host.ru]#", + valid: false, + }, { + name: "valid_default", + upstream: "1.1.1.1", + valid: true, + wantDef: true, + }, { + name: "valid_default", + upstream: "tls://1.1.1.1", + valid: true, + wantDef: true, + }, { + name: "valid_default", + upstream: "https://dns.adguard.com/dns-query", + valid: true, + wantDef: true, + }, { + name: "valid_default", + upstream: "sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", + valid: true, + wantDef: true, + }, { + name: "valid", + upstream: "[/host.com/]1.1.1.1", + valid: true, + wantDef: false, + }, { + name: "valid", + upstream: "[//]tls://1.1.1.1", + valid: true, + wantDef: false, + }, { + name: "valid", + upstream: "[/www.host.com/]#", + valid: true, + wantDef: false, + }, { + name: "valid", + upstream: "[/host.com/google.com/]8.8.8.8", + valid: true, + wantDef: false, + }, { + name: "valid", + upstream: "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", + valid: true, + wantDef: false, + }} - validDefaultUpstreams := []string{ - "1.1.1.1", - "tls://1.1.1.1", - "https://dns.adguard.com/dns-query", - "sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", - } - - validUpstreams := []string{ - "[/host.com/]1.1.1.1", - "[//]tls://1.1.1.1", - "[/www.host.com/]#", - "[/host.com/google.com/]8.8.8.8", - "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", - } - for _, u := range invalidUpstreams { - _, err := validateUpstream(u) - if err == nil { - t.Fatalf("upstream %s is invalid but it pass through validation", u) - } - } - - for _, u := range validDefaultUpstreams { - defaultUpstream, err := validateUpstream(u) - if err != nil { - t.Fatalf("upstream %s is valid but it doen't pass through validation cause: %s", u, err) - } - if !defaultUpstream { - t.Fatalf("upstream %s is default one!", u) - } - } - - for _, u := range validUpstreams { - defaultUpstream, err := validateUpstream(u) - if err != nil { - t.Fatalf("upstream %s is valid but it doen't pass through validation cause: %s", u, err) - } - if defaultUpstream { - t.Fatalf("upstream %s is default one!", u) - } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + defaultUpstream, err := validateUpstream(tc.upstream) + assert.Equal(t, tc.valid, err == nil) + if err == nil { + assert.Equal(t, tc.wantDef, defaultUpstream) + } + }) } } func TestValidateUpstreamsSet(t *testing.T) { - // Empty upstreams array + // Empty upstreams array. var upstreamsSet []string - err := ValidateUpstreams(upstreamsSet) - assert.Nil(t, err, "empty upstreams array should be valid") + assert.Nil(t, ValidateUpstreams(upstreamsSet), "empty upstreams array should be valid") - // Comment in upstreams array + // Comment in upstreams array. upstreamsSet = []string{"# comment"} - err = ValidateUpstreams(upstreamsSet) - assert.Nil(t, err, "comments should not be validated") + assert.Nil(t, ValidateUpstreams(upstreamsSet), "comments should not be validated") - // Set of valid upstreams. There is no default upstream specified + // Set of valid upstreams. There is no default upstream specified. upstreamsSet = []string{ "[/host.com/]1.1.1.1", "[//]tls://1.1.1.1", @@ -997,49 +1008,80 @@ func TestValidateUpstreamsSet(t *testing.T) { "[/host.com/google.com/]8.8.8.8", "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", } - err = ValidateUpstreams(upstreamsSet) - assert.NotNil(t, err, "there is no default upstream") + assert.NotNil(t, ValidateUpstreams(upstreamsSet), "there is no default upstream") - // Let's add default upstream + // Let's add default upstream. upstreamsSet = append(upstreamsSet, "8.8.8.8") - err = ValidateUpstreams(upstreamsSet) + err := ValidateUpstreams(upstreamsSet) assert.Nilf(t, err, "upstreams set is valid, but doesn't pass through validation cause: %s", err) - // Let's add invalid upstream + // Let's add invalid upstream. upstreamsSet = append(upstreamsSet, "dhcp://fake.dns") - err = ValidateUpstreams(upstreamsSet) - assert.NotNil(t, err, "there is an invalid upstream in set, but it pass through validation") + assert.NotNil(t, ValidateUpstreams(upstreamsSet), "there is an invalid upstream in set, but it pass through validation") } func TestIPStringFromAddr(t *testing.T) { - addr := net.UDPAddr{} - addr.IP = net.ParseIP("1:2:3::4") - addr.Port = 12345 - addr.Zone = "eth0" - assert.Equal(t, IPStringFromAddr(&addr), net.ParseIP("1:2:3::4").String()) - + addr := net.UDPAddr{ + IP: net.ParseIP("1:2:3::4"), + Port: 12345, + Zone: "eth0", + } + assert.Equal(t, IPStringFromAddr(&addr), addr.IP.String()) assert.Empty(t, IPStringFromAddr(nil)) } func TestMatchDNSName(t *testing.T) { dnsNames := []string{"host1", "*.host2", "1.2.3.4"} sort.Strings(dnsNames) - assert.True(t, matchDNSName(dnsNames, "host1")) - assert.True(t, matchDNSName(dnsNames, "a.host2")) - assert.True(t, matchDNSName(dnsNames, "b.a.host2")) - assert.True(t, matchDNSName(dnsNames, "1.2.3.4")) - assert.False(t, matchDNSName(dnsNames, "host2")) - assert.False(t, matchDNSName(dnsNames, "")) - assert.False(t, matchDNSName(dnsNames, "*.host2")) + + testCases := []struct { + name string + dnsName string + want bool + }{{ + name: "match", + dnsName: "host1", + want: true, + }, { + name: "match", + dnsName: "a.host2", + want: true, + }, { + name: "match", + dnsName: "b.a.host2", + want: true, + }, { + name: "match", + dnsName: "1.2.3.4", + want: true, + }, { + name: "mismatch", + dnsName: "host2", + want: false, + }, { + name: "mismatch", + dnsName: "", + want: false, + }, { + name: "mismatch", + dnsName: "*.host2", + want: false, + }} + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.want, matchDNSName(dnsNames, tc.dnsName)) + }) + } } type testDHCP struct{} func (d *testDHCP) Leases(flags int) []dhcpd.Lease { - l := dhcpd.Lease{} - l.IP = net.IP{127, 0, 0, 1} - l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") - l.Hostname = "localhost" + l := dhcpd.Lease{ + IP: net.IP{127, 0, 0, 1}, + HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, + Hostname: "localhost", + } return []dhcpd.Lease{l} } func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {} @@ -1047,26 +1089,30 @@ func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {} func TestPTRResponseFromDHCPLeases(t *testing.T) { dhcp := &testDHCP{} - c := dnsfilter.Config{} - f := dnsfilter.New(&c, nil) - s := NewServer(DNSCreateParams{DNSFilter: f, DHCPServer: dhcp}) + s := NewServer(DNSCreateParams{ + DNSFilter: dnsfilter.New(&dnsfilter.Config{}, nil), + DHCPServer: dhcp, + }) + s.conf.UDPListenAddr = &net.UDPAddr{Port: 0} s.conf.TCPListenAddr = &net.TCPAddr{Port: 0} s.conf.UpstreamDNS = []string{"127.0.0.1:53"} s.conf.FilteringConfig.ProtectionEnabled = true err := s.Prepare(nil) assert.Nil(t, err) - assert.Nil(t, s.Start()) + assert.Nil(t, s.Start()) addr := s.dnsProxy.Addr(proxy.ProtoUDP) - req := createTestMessage("1.0.0.127.in-addr.arpa.") - req.Question[0].Qtype = dns.TypePTR + + req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR) resp, err := dns.Exchange(req, addr.String()) + assert.Nil(t, err) assert.Len(t, resp.Answer, 1) assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype) assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name) + ptr := resp.Answer[0].(*dns.PTR) assert.Equal(t, "localhost.", ptr.Ptr) @@ -1078,37 +1124,40 @@ func TestPTRResponseFromHosts(t *testing.T) { AutoHosts: &util.AutoHosts{}, } - // Prepare test hosts file - hf, _ := ioutil.TempFile("", "") - defer func() { _ = os.Remove(hf.Name()) }() - defer hf.Close() + // Prepare test hosts file. + hf, err := ioutil.TempFile("", "") + if assert.Nil(t, err) { + t.Cleanup(func() { + assert.Nil(t, hf.Close()) + assert.Nil(t, os.Remove(hf.Name())) + }) + } _, _ = hf.WriteString(" 127.0.0.1 host # comment \n") _, _ = hf.WriteString(" ::1 localhost#comment \n") - // Init auto hosts + // Init auto hosts. c.AutoHosts.Init(hf.Name()) - defer c.AutoHosts.Close() + t.Cleanup(c.AutoHosts.Close) - f := dnsfilter.New(&c, nil) - s := NewServer(DNSCreateParams{DNSFilter: f}) + s := NewServer(DNSCreateParams{DNSFilter: dnsfilter.New(&c, nil)}) s.conf.UDPListenAddr = &net.UDPAddr{Port: 0} s.conf.TCPListenAddr = &net.TCPAddr{Port: 0} s.conf.UpstreamDNS = []string{"127.0.0.1:53"} s.conf.FilteringConfig.ProtectionEnabled = true - err := s.Prepare(nil) - assert.Nil(t, err) + assert.Nil(t, s.Prepare(nil)) + assert.Nil(t, s.Start()) addr := s.dnsProxy.Addr(proxy.ProtoUDP) - req := createTestMessage("1.0.0.127.in-addr.arpa.") - req.Question[0].Qtype = dns.TypePTR + req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR) resp, err := dns.Exchange(req, addr.String()) assert.Nil(t, err) assert.Len(t, resp.Answer, 1) assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype) assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name) + ptr := resp.Answer[0].(*dns.PTR) assert.Equal(t, "host.", ptr.Ptr)