diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index bf1e7a40..08604f34 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -46,6 +46,14 @@ const shortcutLength = 6 // used for rule search optimization, 6 hits the sweet const enableFastLookup = true // flag for debugging, must be true in production for faster performance const enableDelayedCompilation = true // flag for debugging, must be true in production for faster performance +// Custom filtering settings +type RequestFilteringSettings struct { + FilteringEnabled bool + SafeSearchEnabled bool + SafeBrowsingEnabled bool + ParentalEnabled bool +} + // Config allows you to configure DNS filtering with New() or just change variables directly. type Config struct { FilteringTempFilename string `yaml:"filtering_temp_filename"` // temporary file for storing unused filtering rules @@ -55,6 +63,9 @@ type Config struct { SafeSearchEnabled bool `yaml:"safesearch_enabled"` SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"` ResolverAddress string // DNS server address + + // Filtering callback function + FilterHandler func(clientAddr string, settings *RequestFilteringSettings) `yaml:"-"` } type privateConfig struct { @@ -149,7 +160,7 @@ func (r Reason) Matched() bool { } // CheckHost tries to match host against rules, then safebrowsing and parental if they are enabled -func (d *Dnsfilter) CheckHost(host string, qtype uint16) (Result, error) { +func (d *Dnsfilter) CheckHost(host string, qtype uint16, clientAddr string) (Result, error) { // sometimes DNS clients will try to resolve ".", which is a request to get root servers if host == "" { return Result{Reason: NotFilteredNotFound}, nil @@ -160,17 +171,30 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16) (Result, error) { return Result{}, nil } - // try filter lists first - result, err := d.matchHost(host, qtype) - if err != nil { - return result, err + var setts RequestFilteringSettings + setts.FilteringEnabled = true + setts.SafeSearchEnabled = d.SafeSearchEnabled + setts.SafeBrowsingEnabled = d.SafeBrowsingEnabled + setts.ParentalEnabled = d.ParentalEnabled + if len(clientAddr) != 0 && d.FilterHandler != nil { + d.FilterHandler(clientAddr, &setts) } - if result.Reason.Matched() { - return result, nil + + var result Result + var err error + // try filter lists first + if setts.FilteringEnabled { + result, err = d.matchHost(host, qtype) + if err != nil { + return result, err + } + if result.Reason.Matched() { + return result, nil + } } // check safeSearch if no match - if d.SafeSearchEnabled { + if setts.SafeSearchEnabled { result, err = d.checkSafeSearch(host) if err != nil { log.Printf("Failed to safesearch HTTP lookup, ignoring check: %v", err) @@ -183,7 +207,7 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16) (Result, error) { } // check safebrowsing if no match - if d.SafeBrowsingEnabled { + if setts.SafeBrowsingEnabled { result, err = d.checkSafeBrowsing(host) if err != nil { // failed to do HTTP lookup -- treat it as if we got empty response, but don't save cache @@ -196,7 +220,7 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16) (Result, error) { } // check parental if no match - if d.ParentalEnabled { + if setts.ParentalEnabled { result, err = d.checkParental(host) if err != nil { // failed to do HTTP lookup -- treat it as if we got empty response, but don't save cache diff --git a/dnsfilter/dnsfilter_test.go b/dnsfilter/dnsfilter_test.go index 4619310f..7488985f 100644 --- a/dnsfilter/dnsfilter_test.go +++ b/dnsfilter/dnsfilter_test.go @@ -18,6 +18,7 @@ import ( // SAFE SEARCH // PARENTAL // FILTERING +// CLIENTS SETTINGS // BENCHMARKS // HELPERS @@ -52,7 +53,7 @@ func NewForTestFilters(filters map[int]string) *Dnsfilter { func (d *Dnsfilter) checkMatch(t *testing.T, hostname string) { t.Helper() - ret, err := d.CheckHost(hostname, dns.TypeA) + ret, err := d.CheckHost(hostname, dns.TypeA, "") if err != nil { t.Errorf("Error while matching host %s: %s", hostname, err) } @@ -63,7 +64,7 @@ func (d *Dnsfilter) checkMatch(t *testing.T, hostname string) { func (d *Dnsfilter) checkMatchIP(t *testing.T, hostname string, ip string, qtype uint16) { t.Helper() - ret, err := d.CheckHost(hostname, qtype) + ret, err := d.CheckHost(hostname, qtype, "") if err != nil { t.Errorf("Error while matching host %s: %s", hostname, err) } @@ -77,7 +78,7 @@ func (d *Dnsfilter) checkMatchIP(t *testing.T, hostname string, ip string, qtype func (d *Dnsfilter) checkMatchEmpty(t *testing.T, hostname string) { t.Helper() - ret, err := d.CheckHost(hostname, dns.TypeA) + ret, err := d.CheckHost(hostname, dns.TypeA, "") if err != nil { t.Errorf("Error while matching host %s: %s", hostname, err) } @@ -212,7 +213,7 @@ func TestCheckHostSafeSearchYandex(t *testing.T) { // Check host for each domain for _, host := range yandex { - result, err := d.CheckHost(host, dns.TypeA) + result, err := d.CheckHost(host, dns.TypeA, "") if err != nil { t.Errorf("SafeSearch doesn't work for yandex domain `%s` cause %s", host, err) } @@ -235,7 +236,7 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) { // Check host for each domain for _, host := range googleDomains { - result, err := d.CheckHost(host, dns.TypeA) + result, err := d.CheckHost(host, dns.TypeA, "") if err != nil { t.Errorf("SafeSearch doesn't work for %s cause %s", host, err) } @@ -255,7 +256,7 @@ func TestSafeSearchCacheYandex(t *testing.T) { var err error // Check host with disabled safesearch - result, err = d.CheckHost(domain, dns.TypeA) + result, err = d.CheckHost(domain, dns.TypeA, "") if err != nil { t.Fatalf("Cannot check host due to %s", err) } @@ -265,7 +266,7 @@ func TestSafeSearchCacheYandex(t *testing.T) { // Enable safesearch d.SafeSearchEnabled = true - result, err = d.CheckHost(domain, dns.TypeA) + result, err = d.CheckHost(domain, dns.TypeA, "") if err != nil { t.Fatalf("CheckHost for safesearh domain %s failed cause %s", domain, err) } @@ -295,7 +296,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) { d := NewForTest() defer d.Destroy() domain := "www.google.ru" - result, err := d.CheckHost(domain, dns.TypeA) + result, err := d.CheckHost(domain, dns.TypeA, "") if err != nil { t.Fatalf("Cannot check host due to %s", err) } @@ -324,7 +325,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) { } } - result, err = d.CheckHost(domain, dns.TypeA) + result, err = d.CheckHost(domain, dns.TypeA, "") if err != nil { t.Fatalf("CheckHost for safesearh domain %s failed cause %s", domain, err) } @@ -441,7 +442,7 @@ func TestMatching(t *testing.T) { d := NewForTestFilters(filters) defer d.Destroy() - ret, err := d.CheckHost(test.hostname, dns.TypeA) + ret, err := d.CheckHost(test.hostname, dns.TypeA, "") if err != nil { t.Errorf("Error while matching host %s: %s", test.hostname, err) } @@ -455,6 +456,52 @@ func TestMatching(t *testing.T) { } } +// CLIENT SETTINGS + +func applyClientSettings(clientAddr string, setts *RequestFilteringSettings) { + setts.FilteringEnabled = false + setts.ParentalEnabled = false +} + +func TestClientSettings(t *testing.T) { + var r Result + filters := make(map[int]string) + filters[0] = "||example.org^\n" + d := NewForTestFilters(filters) + defer d.Destroy() + d.ParentalEnabled = true + d.ParentalSensitivity = 3 + + // no client settings: + + // blocked by filters + r, _ = d.CheckHost("example.org", dns.TypeA, "1.1.1.1") + if !r.IsFiltered || r.Reason != FilteredBlackList { + t.Fatalf("CheckHost FilteredBlackList") + } + + // blocked by parental + r, _ = d.CheckHost("pornhub.com", dns.TypeA, "1.1.1.1") + if !r.IsFiltered || r.Reason != FilteredParental { + t.Fatalf("CheckHost FilteredParental") + } + + // override client settings: + d.FilterHandler = applyClientSettings + + // override filtering settings + r, _ = d.CheckHost("example.org", dns.TypeA, "1.1.1.1") + if r.IsFiltered { + t.Fatalf("CheckHost") + } + + // override parental settings + r, _ = d.CheckHost("pornhub.com", dns.TypeA, "1.1.1.1") + if r.IsFiltered { + t.Fatalf("CheckHost") + } +} + // BENCHMARKS func BenchmarkSafeBrowsing(b *testing.B) { @@ -463,7 +510,7 @@ func BenchmarkSafeBrowsing(b *testing.B) { d.SafeBrowsingEnabled = true for n := 0; n < b.N; n++ { hostname := "wmconvirus.narod.ru" - ret, err := d.CheckHost(hostname, dns.TypeA) + ret, err := d.CheckHost(hostname, dns.TypeA, "") if err != nil { b.Errorf("Error while matching host %s: %s", hostname, err) } @@ -480,7 +527,7 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { hostname := "wmconvirus.narod.ru" - ret, err := d.CheckHost(hostname, dns.TypeA) + ret, err := d.CheckHost(hostname, dns.TypeA, "") if err != nil { b.Errorf("Error while matching host %s: %s", hostname, err) } diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index b728c90f..34e6285c 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -378,7 +378,11 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error var res dnsfilter.Result var err error - res, err = dnsFilter.CheckHost(host, d.Req.Question[0].Qtype) + clientAddr := "" + if d.Addr != nil { + clientAddr, _, _ = net.SplitHostPort(d.Addr.String()) + } + res, err = dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, clientAddr) if err != nil { // Return immediately if there's an error return nil, errorx.Decorate(err, "dnsfilter failed to check host '%s'", host)