diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index ec53121d..7e6c01f0 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -637,6 +637,7 @@ func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) { return Result{}, err } + res.IP = addrs[0] // The next bug may occurs: LookupIP returns DNS64 mapped ipv4 address with zero-prefix for _, i := range addrs { if ipv4 := i.To4(); ipv4 != nil && len(i) == net.IPv6len { @@ -645,10 +646,6 @@ func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) { } } - if res.IP == nil || len(res.IP) == 0 { - res.IP = addrs[0] - } - // Cache result err = safeSearchCache.Set(host, res) if err != nil { diff --git a/dnsfilter/dnsfilter_test.go b/dnsfilter/dnsfilter_test.go index 50fd0bab..bbb01ac0 100644 --- a/dnsfilter/dnsfilter_test.go +++ b/dnsfilter/dnsfilter_test.go @@ -702,6 +702,8 @@ func TestSafeSearchCacheGoogle (t *testing.T) { t.Fatalf("SafeSearch is not enabled but there is an answer!") } + // Enable safesearch and check host + d.SafeSearchEnabled = true // Let's lookup for safesearch domain safeDomain, ok := d.SafeSearchDomain(domain) @@ -714,20 +716,13 @@ func TestSafeSearchCacheGoogle (t *testing.T) { t.Fatalf("Failed to lookup for %s", safeDomain) } - var ip net.IP + ip := ips[0] for _, i := range ips { if len(i) == net.IPv6len && i.To4() != nil { ip = i } } - if ip == nil || len(ip) == 0 { - ip = ips[0] - } - - // Enable safesearch and check host - d.SafeSearchEnabled = true - result, err = d.CheckHost(domain) if err != nil { t.Fatalf("CheckHost for safesearh domain %s failed cause %s", domain, err) diff --git a/dnsforward/dnsforward_test.go b/dnsforward/dnsforward_test.go index 9553b9ed..476ad67a 100644 --- a/dnsforward/dnsforward_test.go +++ b/dnsforward/dnsforward_test.go @@ -21,14 +21,14 @@ func TestServer(t *testing.T) { } // message over UDP - req := createTestMessage() + req := createGoogleATestMessage() addr := s.dnsProxy.Addr("udp") client := dns.Client{Net: "udp"} reply, _, err := client.Exchange(req, addr.String()) if err != nil { t.Fatalf("Couldn't talk to server %s: %s", addr, err) } - assertResponse(t, reply) + assertGoogleAResponse(t, reply) // check query log and stats log := s.GetQueryLog() @@ -39,14 +39,14 @@ func TestServer(t *testing.T) { assert.Equal(t, 1, len(stats.Clients), "Top clients length") // message over TCP - req = createTestMessage() + req = createGoogleATestMessage() addr = s.dnsProxy.Addr("tcp") client = dns.Client{Net: "tcp"} reply, _, err = client.Exchange(req, addr.String()) if err != nil { t.Fatalf("Couldn't talk to server %s: %s", addr, err) } - assertResponse(t, reply) + assertGoogleAResponse(t, reply) // check query log and stats again log = s.GetQueryLog() @@ -63,6 +63,49 @@ func TestServer(t *testing.T) { } } +func TestSafeSearch(t *testing.T) { + s := createTestServer(t) + s.SafeSearchEnabled = true + defer removeDataDir(t) + err := s.Start(nil) + if err != nil { + t.Fatalf("Failed to start server: %s", err) + } + + // Test safesearch for yandex. We already know safesearch ip + addr := s.dnsProxy.Addr("udp") + client := dns.Client{Net: "udp"} + yandexDomains := []string{"yandex.com.", "yandex.by.", "yandex.kz.", "yandex.ru.", "yandex.com."} + for _, host := range yandexDomains { + exchangeAndAssertResponse(t, client, addr, host, "213.180.193.56") + } + + // Let's lookup for google safesearch ip + ips, err := net.LookupIP("forcesafesearch.google.com") + if err != nil { + t.Fatalf("Failed to lookup for forcesafesearch.google.com: %s", err) + } + + ip := ips[0] + for _, i := range ips { + if len(i) == net.IPv6len && i.To4() != nil { + ip = i + break + } + } + + // Test safeseacrh for google. + googleDomains := []string{"www.google.com.", "www.google.com.af.", "www.google.be.", "www.google.by."} + for _, host := range googleDomains { + exchangeAndAssertResponse(t, client, addr, host, ip.String()) + } + + err = s.Stop() + if err != nil { + t.Fatalf("Can not stopd server cause: %s", err) + } +} + func TestInvalidRequest(t *testing.T) { s := createTestServer(t) defer removeDataDir(t) @@ -283,23 +326,40 @@ func removeDataDir(t *testing.T) { } } -func createTestMessage() *dns.Msg { +func exchangeAndAssertResponse(t *testing.T, client dns.Client, addr net.Addr, host, ip string) { + req := createTestMessage(host) + reply, _, err := client.Exchange(req, addr.String()) + if err != nil { + t.Fatalf("Couldn't talk to server %s: %s", addr, err) + } + assertResponse(t, reply, ip) +} + +func createGoogleATestMessage() *dns.Msg { + return createTestMessage("google-public-dns-a.google.com.") +} + +func createTestMessage(host string) *dns.Msg { req := dns.Msg{} req.Id = dns.Id() req.RecursionDesired = true req.Question = []dns.Question{ - {Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + {Name: host, Qtype: dns.TypeA, Qclass: dns.ClassINET}, } return &req } -func assertResponse(t *testing.T, reply *dns.Msg) { +func assertGoogleAResponse(t *testing.T, reply *dns.Msg) { + assertResponse(t, reply, "8.8.8.8") +} + +func assertResponse(t *testing.T, reply *dns.Msg, ip string) { if len(reply.Answer) != 1 { t.Fatalf("DNS server returned reply with wrong number of answers - %d", len(reply.Answer)) } if a, ok := reply.Answer[0].(*dns.A); ok { - if !net.IPv4(8, 8, 8, 8).Equal(a.A) { - t.Fatalf("DNS server returned wrong answer instead of 8.8.8.8: %v", a.A) + if !net.ParseIP(ip).Equal(a.A) { + t.Fatalf("DNS server returned wrong answer instead of %s: %v", ip, a.A) } } else { t.Fatalf("DNS server returned wrong answer type instead of A: %v", reply.Answer[0])