diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index 90bd61ea..dd5fe2e4 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -91,10 +91,11 @@ type LookupStats struct { PendingMax int64 // maximum number of pending HTTP requests } -// Stats store LookupStats for both safebrowsing and parental +// Stats store LookupStats for safebrowsing, parental and safesearch type Stats struct { Safebrowsing LookupStats Parental LookupStats + Safesearch LookupStats } // Dnsfilter holds added rules and performs hostname matches against the rules @@ -155,6 +156,7 @@ var ( stats Stats safebrowsingCache gcache.Cache parentalCache gcache.Cache + safeSearchCache gcache.Cache ) // Result holds state of hostname check @@ -188,6 +190,19 @@ func (d *Dnsfilter) CheckHost(host string) (Result, error) { return result, nil } + // check safeSearch if no match + if d.SafeSearchEnabled { + result, err = d.checkSafeSearch(host) + if err != nil { + log.Printf("Failed to safesearch HTTP lookup, ignoring check: %v", err) + return Result{}, nil + } + + if result.Reason.Matched() { + return result, nil + } + } + // check safebrowsing if no match if d.SafeBrowsingEnabled { result, err = d.checkSafeBrowsing(host) @@ -584,6 +599,62 @@ func hostnameToHashParam(host string, addslash bool) (string, map[string]bool) { return hashparam.String(), hashes } +func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) { + if safeSearchCache == nil { + safeSearchCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build() + } + + // Check cache. Return cached result if it was found + cachedValue, isFound, err := getCachedReason(safeSearchCache, host) + if isFound { + atomic.AddUint64(&stats.Safesearch.CacheHits, 1) + return cachedValue, nil + } + + if err != nil { + return Result{}, err + } + + safeHost, ok := d.SafeSearchDomain(host) + if !ok { + return Result{}, nil + } + + res := Result{IsFiltered: true, Reason: FilteredSafeSearch} + if ip := net.ParseIP(safeHost); ip != nil { + res.IP = ip + err = safeSearchCache.Set(host, res) + if err != nil { + return Result{}, nil + } + + return res, nil + } + + // TODO this address should be resolved with upstream that was configured in dnsforward + addrs, err := net.LookupIP(safeHost) + if err != nil { + log.Tracef("SafeSearchDomain for %s was found but failed to lookup for %s cause %s", host, safeHost, err) + return Result{}, err + } + + res.IP = addrs[0] + // The next bug may occurs: LookupIP returns DNS64 mapped ipv4 address with zero-prefix + for _, i := range addrs { + if ipv4 := i.To4(); ipv4 != nil && len(i) == net.IPv6len { + res.IP = ipv4 + break + } + } + + // Cache result + err = safeSearchCache.Set(host, res) + if err != nil { + return Result{}, nil + } + return res, nil +} + func (d *Dnsfilter) checkSafeBrowsing(host string) (Result, error) { // prevent recursion -- checking the host of safebrowsing server makes no sense if host == d.safeBrowsingServer { diff --git a/dnsfilter/dnsfilter_test.go b/dnsfilter/dnsfilter_test.go index 9692a60d..bbb01ac0 100644 --- a/dnsfilter/dnsfilter_test.go +++ b/dnsfilter/dnsfilter_test.go @@ -4,6 +4,7 @@ import ( "archive/zip" "bytes" "io/ioutil" + "net" "net/http" "net/http/httptest" "path" @@ -607,6 +608,146 @@ func TestSafeBrowsingCustomServerFail(t *testing.T) { d.checkMatchEmpty(t, "wmconvirus.narod.ru") } +func TestCheckHostSafeSearchYandex(t *testing.T) { + d := NewForTest() + defer d.Destroy() + + // Enable safesearch + d.SafeSearchEnabled = true + + // Slice of yandex domains + yandex := []string{"yAndeX.ru", "YANdex.COM", "yandex.ua", "yandex.by", "yandex.kz", "www.yandex.com"} + + // Check host for each domain + for _, host := range yandex { + result, err := d.CheckHost(host) + if err != nil { + t.Errorf("SafeSearch doesn't work for yandex domain `%s` cause %s", host, err) + } + + if result.IP.String() != "213.180.193.56" { + t.Errorf("SafeSearch doesn't work for yandex domain `%s`", host) + } + } +} + +func TestCheckHostSafeSearchGoogle(t *testing.T) { + d := NewForTest() + defer d.Destroy() + + // Enable safesearch + d.SafeSearchEnabled = true + + // Slice of google domains + googleDomains := []string{"www.google.com", "www.google.im", "www.google.co.in", "www.google.iq", "www.google.is", "www.google.it", "www.google.je"} + + // Check host for each domain + for _, host := range googleDomains { + result, err := d.CheckHost(host) + if err != nil { + t.Errorf("SafeSearch doesn't work for %s cause %s", host, err) + } + + if result.IP == nil { + t.Errorf("SafeSearch doesn't work for %s", host) + } + } +} + +func TestSafeSearchCacheYandex (t *testing.T) { + d := NewForTest() + defer d.Destroy() + domain := "yandex.ru" + + // Check host with disabled safesearch + result, err := d.CheckHost(domain) + if result.IP != nil { + t.Fatalf("SafeSearch is not enabled but there is an answer for `%s` !", domain) + } + + // Enable safesearch + d.SafeSearchEnabled = true + result, err = d.CheckHost(domain) + if err != nil { + t.Fatalf("CheckHost for safesearh domain %s failed cause %s", domain, err) + } + + // Fir yandex we already know valid ip + if result.IP.String() != "213.180.193.56" { + t.Fatalf("Wrong IP for %s safesearch: %s", domain, result.IP.String()) + } + + // Check cache + cachedValue, isFound, err := getCachedReason(safeSearchCache, domain) + + if err != nil { + t.Fatalf("An error occured during cache search for %s: %s", domain, err) + } + + if !isFound { + t.Fatalf("Safesearch cache doesn't work for %s!", domain) + } + + if cachedValue.IP.String() != "213.180.193.56" { + t.Fatalf("Wrong IP in cache for %s safesearch: %s", domain, cachedValue.IP.String()) + } +} + +func TestSafeSearchCacheGoogle (t *testing.T) { + d := NewForTest() + defer d.Destroy() + domain := "www.google.ru" + result, err := d.CheckHost(domain) + if result.IP != nil { + t.Fatalf("SafeSearch is not enabled but there is an answer!") + } + + // Enable safesearch and check host + d.SafeSearchEnabled = true + + // Let's lookup for safesearch domain + safeDomain, ok := d.SafeSearchDomain(domain) + if !ok { + t.Fatalf("Failed to get safesearch domain for %s", domain) + } + + ips, err := net.LookupIP(safeDomain) + if err != nil { + t.Fatalf("Failed to lookup for %s", safeDomain) + } + + ip := ips[0] + for _, i := range ips { + if len(i) == net.IPv6len && i.To4() != nil { + ip = i + } + } + + result, err = d.CheckHost(domain) + if err != nil { + t.Fatalf("CheckHost for safesearh domain %s failed cause %s", domain, err) + } + + if result.IP.String() != ip.String() { + t.Fatalf("Wrong IP for %s safesearch: %s", domain, result.IP.String()) + } + + // Check cache + cachedValue, isFound, err := getCachedReason(safeSearchCache, domain) + + if err != nil { + t.Fatalf("An error occured during cache search for %s: %s", domain, err) + } + + if !isFound { + t.Fatalf("Safesearch cache doesn't work for %s!", domain) + } + + if cachedValue.IP.String() != ip.String() { + t.Fatalf("Wrong IP in cache for %s safesearch: %s", domain, cachedValue.IP.String()) + } +} + func TestParentalControl(t *testing.T) { d := NewForTest() defer d.Destroy() diff --git a/dnsfilter/safesearch.go b/dnsfilter/safesearch.go index 902c0f09..c4834b2f 100644 --- a/dnsfilter/safesearch.go +++ b/dnsfilter/safesearch.go @@ -6,6 +6,11 @@ var safeSearchDomains = map[string]string{ "yandex.ua": "213.180.193.56", "yandex.by": "213.180.193.56", "yandex.kz": "213.180.193.56", + "www.yandex.com": "213.180.193.56", + "www.yandex.ru": "213.180.193.56", + "www.yandex.ua": "213.180.193.56", + "www.yandex.by": "213.180.193.56", + "www.yandex.kz": "213.180.193.56", "www.bing.com": "strict.bing.com", diff --git a/dnsforward/dnsforward_test.go b/dnsforward/dnsforward_test.go index 28c115d1..df14b8ae 100644 --- a/dnsforward/dnsforward_test.go +++ b/dnsforward/dnsforward_test.go @@ -38,14 +38,14 @@ func TestServer(t *testing.T) { } // message over UDP - req := createTestMessage() + req := createGoogleATestMessage() addr := s.dnsProxy.Addr(proxy.ProtoUDP) client := dns.Client{Net: "udp"} reply, _, err := client.Exchange(req, addr.String()) if err != nil { t.Fatalf("Couldn't talk to server %s: %s", addr, err) } - assertResponse(t, reply) + assertGoogleAResponse(t, reply) // check query log and stats log := s.GetQueryLog() @@ -56,14 +56,14 @@ func TestServer(t *testing.T) { assert.Equal(t, 1, len(stats.Clients), "Top clients length") // message over TCP - req = createTestMessage() + req = createGoogleATestMessage() addr = s.dnsProxy.Addr("tcp") client = dns.Client{Net: "tcp"} reply, _, err = client.Exchange(req, addr.String()) if err != nil { t.Fatalf("Couldn't talk to server %s: %s", addr, err) } - assertResponse(t, reply) + assertGoogleAResponse(t, reply) // check query log and stats again log = s.GetQueryLog() @@ -143,6 +143,75 @@ func TestServerRace(t *testing.T) { } } +func TestSafeSearch(t *testing.T) { + s := createTestServer(t) + s.SafeSearchEnabled = true + defer removeDataDir(t) + err := s.Start(nil) + if err != nil { + t.Fatalf("Failed to start server: %s", err) + } + + // Test safe search for yandex. We already know safe search ip + addr := s.dnsProxy.Addr(proxy.ProtoUDP) + client := dns.Client{Net: "udp"} + yandexDomains := []string{"yandex.com.", "yandex.by.", "yandex.kz.", "yandex.ru.", "yandex.com."} + for _, host := range yandexDomains { + exchangeAndAssertResponse(t, client, addr, host, "213.180.193.56") + } + + // Check aggregated stats + assert.Equal(t, s.GetAggregatedStats()["replaced_safesearch"], float64(len(yandexDomains))) + assert.Equal(t, s.GetAggregatedStats()["blocked_filtering"], float64(len(yandexDomains))) + assert.Equal(t, s.GetAggregatedStats()["dns_queries"], float64(len(yandexDomains))) + + // Let's lookup for google safesearch ip + ips, err := net.LookupIP("forcesafesearch.google.com") + if err != nil { + t.Fatalf("Failed to lookup for forcesafesearch.google.com: %s", err) + } + + ip := ips[0] + for _, i := range ips { + if len(i) == net.IPv6len && i.To4() != nil { + ip = i + break + } + } + + // Test safe search for google. + googleDomains := []string{"www.google.com.", "www.google.com.af.", "www.google.be.", "www.google.by."} + for _, host := range googleDomains { + exchangeAndAssertResponse(t, client, addr, host, ip.String()) + } + + // Check aggregated stats + assert.Equal(t, s.GetAggregatedStats()["replaced_safesearch"], float64(len(yandexDomains)+len(googleDomains))) + assert.Equal(t, s.GetAggregatedStats()["blocked_filtering"], float64(len(yandexDomains)+len(googleDomains))) + assert.Equal(t, s.GetAggregatedStats()["dns_queries"], float64(len(yandexDomains)+len(googleDomains))) + + // Do one more exchange + exchangeAndAssertResponse(t, client, addr, "google-public-dns-a.google.com.", "8.8.8.8") + + // Check aggregated stats + assert.Equal(t, s.GetAggregatedStats()["replaced_safesearch"], float64(len(yandexDomains)+len(googleDomains))) + assert.Equal(t, s.GetAggregatedStats()["blocked_filtering"], float64(len(yandexDomains)+len(googleDomains))) + assert.Equal(t, s.GetAggregatedStats()["dns_queries"], float64(len(yandexDomains)+len(googleDomains)+1)) + + // Count of blocked domains (there is `yandex.com` duplicate in yandexDomains array) + blockedCount := len(yandexDomains) - 1 + len(googleDomains) + assert.Equal(t, len(s.GetStatsTop().Blocked), blockedCount) + + // Count of domains (blocked domains + `google-public-dns-a.google.com`) + domainsCount := blockedCount + 1 + assert.Equal(t, len(s.GetStatsTop().Domains), domainsCount) + + err = s.Stop() + if err != nil { + t.Fatalf("Can not stopd server cause: %s", err) + } +} + func TestInvalidRequest(t *testing.T) { s := createTestServer(t) defer removeDataDir(t) @@ -413,7 +482,7 @@ func sendTestMessageAsync(t *testing.T, conn *dns.Conn, g *sync.WaitGroup) { g.Done() }() - req := createTestMessage() + req := createGoogleATestMessage() err := conn.WriteMsg(req) if err != nil { t.Fatalf("cannot write message: %s", err) @@ -423,7 +492,7 @@ func sendTestMessageAsync(t *testing.T, conn *dns.Conn, g *sync.WaitGroup) { if err != nil { t.Fatalf("cannot read response to message: %s", err) } - assertResponse(t, res) + assertGoogleAResponse(t, res) } // sendTestMessagesAsync sends messages in parallel @@ -441,7 +510,7 @@ func sendTestMessagesAsync(t *testing.T, conn *dns.Conn) { func sendTestMessages(t *testing.T, conn *dns.Conn) { for i := 0; i < 10; i++ { - req := createTestMessage() + req := createGoogleATestMessage() err := conn.WriteMsg(req) if err != nil { t.Fatalf("cannot write message #%d: %s", i, err) @@ -451,27 +520,44 @@ func sendTestMessages(t *testing.T, conn *dns.Conn) { if err != nil { t.Fatalf("cannot read response to message #%d: %s", i, err) } - assertResponse(t, res) + assertGoogleAResponse(t, res) } } -func createTestMessage() *dns.Msg { +func exchangeAndAssertResponse(t *testing.T, client dns.Client, addr net.Addr, host, ip string) { + req := createTestMessage(host) + reply, _, err := client.Exchange(req, addr.String()) + if err != nil { + t.Fatalf("Couldn't talk to server %s: %s", addr, err) + } + assertResponse(t, reply, ip) +} + +func createGoogleATestMessage() *dns.Msg { + return createTestMessage("google-public-dns-a.google.com.") +} + +func createTestMessage(host string) *dns.Msg { req := dns.Msg{} req.Id = dns.Id() req.RecursionDesired = true req.Question = []dns.Question{ - {Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + {Name: host, Qtype: dns.TypeA, Qclass: dns.ClassINET}, } return &req } -func assertResponse(t *testing.T, reply *dns.Msg) { +func assertGoogleAResponse(t *testing.T, reply *dns.Msg) { + assertResponse(t, reply, "8.8.8.8") +} + +func assertResponse(t *testing.T, reply *dns.Msg, ip string) { if len(reply.Answer) != 1 { t.Fatalf("DNS server returned reply with wrong number of answers - %d", len(reply.Answer)) } if a, ok := reply.Answer[0].(*dns.A); ok { - if !net.IPv4(8, 8, 8, 8).Equal(a.A) { - t.Fatalf("DNS server returned wrong answer instead of 8.8.8.8: %v", a.A) + if !net.ParseIP(ip).Equal(a.A) { + t.Fatalf("DNS server returned wrong answer instead of %s: %v", ip, a.A) } } else { t.Fatalf("DNS server returned wrong answer type instead of A: %v", reply.Answer[0])