+ dnsfilter: use callback function for applying per-client settings
This commit is contained in:
parent
4bb7b654ab
commit
8bf76c331d
|
@ -46,6 +46,14 @@ const shortcutLength = 6 // used for rule search optimization, 6 hits the sweet
|
||||||
const enableFastLookup = true // flag for debugging, must be true in production for faster performance
|
const enableFastLookup = true // flag for debugging, must be true in production for faster performance
|
||||||
const enableDelayedCompilation = true // flag for debugging, must be true in production for faster performance
|
const enableDelayedCompilation = true // flag for debugging, must be true in production for faster performance
|
||||||
|
|
||||||
|
// Custom filtering settings
|
||||||
|
type RequestFilteringSettings struct {
|
||||||
|
FilteringEnabled bool
|
||||||
|
SafeSearchEnabled bool
|
||||||
|
SafeBrowsingEnabled bool
|
||||||
|
ParentalEnabled bool
|
||||||
|
}
|
||||||
|
|
||||||
// Config allows you to configure DNS filtering with New() or just change variables directly.
|
// Config allows you to configure DNS filtering with New() or just change variables directly.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
FilteringTempFilename string `yaml:"filtering_temp_filename"` // temporary file for storing unused filtering rules
|
FilteringTempFilename string `yaml:"filtering_temp_filename"` // temporary file for storing unused filtering rules
|
||||||
|
@ -55,6 +63,9 @@ type Config struct {
|
||||||
SafeSearchEnabled bool `yaml:"safesearch_enabled"`
|
SafeSearchEnabled bool `yaml:"safesearch_enabled"`
|
||||||
SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"`
|
SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"`
|
||||||
ResolverAddress string // DNS server address
|
ResolverAddress string // DNS server address
|
||||||
|
|
||||||
|
// Filtering callback function
|
||||||
|
FilterHandler func(clientAddr string, settings *RequestFilteringSettings) `yaml:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type privateConfig struct {
|
type privateConfig struct {
|
||||||
|
@ -149,7 +160,7 @@ func (r Reason) Matched() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckHost tries to match host against rules, then safebrowsing and parental if they are enabled
|
// CheckHost tries to match host against rules, then safebrowsing and parental if they are enabled
|
||||||
func (d *Dnsfilter) CheckHost(host string, qtype uint16) (Result, error) {
|
func (d *Dnsfilter) CheckHost(host string, qtype uint16, clientAddr string) (Result, error) {
|
||||||
// sometimes DNS clients will try to resolve ".", which is a request to get root servers
|
// sometimes DNS clients will try to resolve ".", which is a request to get root servers
|
||||||
if host == "" {
|
if host == "" {
|
||||||
return Result{Reason: NotFilteredNotFound}, nil
|
return Result{Reason: NotFilteredNotFound}, nil
|
||||||
|
@ -160,17 +171,30 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16) (Result, error) {
|
||||||
return Result{}, nil
|
return Result{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// try filter lists first
|
var setts RequestFilteringSettings
|
||||||
result, err := d.matchHost(host, qtype)
|
setts.FilteringEnabled = true
|
||||||
if err != nil {
|
setts.SafeSearchEnabled = d.SafeSearchEnabled
|
||||||
return result, err
|
setts.SafeBrowsingEnabled = d.SafeBrowsingEnabled
|
||||||
|
setts.ParentalEnabled = d.ParentalEnabled
|
||||||
|
if len(clientAddr) != 0 && d.FilterHandler != nil {
|
||||||
|
d.FilterHandler(clientAddr, &setts)
|
||||||
}
|
}
|
||||||
if result.Reason.Matched() {
|
|
||||||
return result, nil
|
var result Result
|
||||||
|
var err error
|
||||||
|
// try filter lists first
|
||||||
|
if setts.FilteringEnabled {
|
||||||
|
result, err = d.matchHost(host, qtype)
|
||||||
|
if err != nil {
|
||||||
|
return result, err
|
||||||
|
}
|
||||||
|
if result.Reason.Matched() {
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// check safeSearch if no match
|
// check safeSearch if no match
|
||||||
if d.SafeSearchEnabled {
|
if setts.SafeSearchEnabled {
|
||||||
result, err = d.checkSafeSearch(host)
|
result, err = d.checkSafeSearch(host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to safesearch HTTP lookup, ignoring check: %v", err)
|
log.Printf("Failed to safesearch HTTP lookup, ignoring check: %v", err)
|
||||||
|
@ -183,7 +207,7 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16) (Result, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// check safebrowsing if no match
|
// check safebrowsing if no match
|
||||||
if d.SafeBrowsingEnabled {
|
if setts.SafeBrowsingEnabled {
|
||||||
result, err = d.checkSafeBrowsing(host)
|
result, err = d.checkSafeBrowsing(host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// failed to do HTTP lookup -- treat it as if we got empty response, but don't save cache
|
// failed to do HTTP lookup -- treat it as if we got empty response, but don't save cache
|
||||||
|
@ -196,7 +220,7 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16) (Result, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// check parental if no match
|
// check parental if no match
|
||||||
if d.ParentalEnabled {
|
if setts.ParentalEnabled {
|
||||||
result, err = d.checkParental(host)
|
result, err = d.checkParental(host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// failed to do HTTP lookup -- treat it as if we got empty response, but don't save cache
|
// failed to do HTTP lookup -- treat it as if we got empty response, but don't save cache
|
||||||
|
|
|
@ -18,6 +18,7 @@ import (
|
||||||
// SAFE SEARCH
|
// SAFE SEARCH
|
||||||
// PARENTAL
|
// PARENTAL
|
||||||
// FILTERING
|
// FILTERING
|
||||||
|
// CLIENTS SETTINGS
|
||||||
// BENCHMARKS
|
// BENCHMARKS
|
||||||
|
|
||||||
// HELPERS
|
// HELPERS
|
||||||
|
@ -52,7 +53,7 @@ func NewForTestFilters(filters map[int]string) *Dnsfilter {
|
||||||
|
|
||||||
func (d *Dnsfilter) checkMatch(t *testing.T, hostname string) {
|
func (d *Dnsfilter) checkMatch(t *testing.T, hostname string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
ret, err := d.CheckHost(hostname, dns.TypeA)
|
ret, err := d.CheckHost(hostname, dns.TypeA, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Error while matching host %s: %s", hostname, err)
|
t.Errorf("Error while matching host %s: %s", hostname, err)
|
||||||
}
|
}
|
||||||
|
@ -63,7 +64,7 @@ func (d *Dnsfilter) checkMatch(t *testing.T, hostname string) {
|
||||||
|
|
||||||
func (d *Dnsfilter) checkMatchIP(t *testing.T, hostname string, ip string, qtype uint16) {
|
func (d *Dnsfilter) checkMatchIP(t *testing.T, hostname string, ip string, qtype uint16) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
ret, err := d.CheckHost(hostname, qtype)
|
ret, err := d.CheckHost(hostname, qtype, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Error while matching host %s: %s", hostname, err)
|
t.Errorf("Error while matching host %s: %s", hostname, err)
|
||||||
}
|
}
|
||||||
|
@ -77,7 +78,7 @@ func (d *Dnsfilter) checkMatchIP(t *testing.T, hostname string, ip string, qtype
|
||||||
|
|
||||||
func (d *Dnsfilter) checkMatchEmpty(t *testing.T, hostname string) {
|
func (d *Dnsfilter) checkMatchEmpty(t *testing.T, hostname string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
ret, err := d.CheckHost(hostname, dns.TypeA)
|
ret, err := d.CheckHost(hostname, dns.TypeA, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Error while matching host %s: %s", hostname, err)
|
t.Errorf("Error while matching host %s: %s", hostname, err)
|
||||||
}
|
}
|
||||||
|
@ -212,7 +213,7 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
|
||||||
|
|
||||||
// Check host for each domain
|
// Check host for each domain
|
||||||
for _, host := range yandex {
|
for _, host := range yandex {
|
||||||
result, err := d.CheckHost(host, dns.TypeA)
|
result, err := d.CheckHost(host, dns.TypeA, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("SafeSearch doesn't work for yandex domain `%s` cause %s", host, err)
|
t.Errorf("SafeSearch doesn't work for yandex domain `%s` cause %s", host, err)
|
||||||
}
|
}
|
||||||
|
@ -235,7 +236,7 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
|
||||||
|
|
||||||
// Check host for each domain
|
// Check host for each domain
|
||||||
for _, host := range googleDomains {
|
for _, host := range googleDomains {
|
||||||
result, err := d.CheckHost(host, dns.TypeA)
|
result, err := d.CheckHost(host, dns.TypeA, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("SafeSearch doesn't work for %s cause %s", host, err)
|
t.Errorf("SafeSearch doesn't work for %s cause %s", host, err)
|
||||||
}
|
}
|
||||||
|
@ -255,7 +256,7 @@ func TestSafeSearchCacheYandex(t *testing.T) {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// Check host with disabled safesearch
|
// Check host with disabled safesearch
|
||||||
result, err = d.CheckHost(domain, dns.TypeA)
|
result, err = d.CheckHost(domain, dns.TypeA, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Cannot check host due to %s", err)
|
t.Fatalf("Cannot check host due to %s", err)
|
||||||
}
|
}
|
||||||
|
@ -265,7 +266,7 @@ func TestSafeSearchCacheYandex(t *testing.T) {
|
||||||
|
|
||||||
// Enable safesearch
|
// Enable safesearch
|
||||||
d.SafeSearchEnabled = true
|
d.SafeSearchEnabled = true
|
||||||
result, err = d.CheckHost(domain, dns.TypeA)
|
result, err = d.CheckHost(domain, dns.TypeA, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("CheckHost for safesearh domain %s failed cause %s", domain, err)
|
t.Fatalf("CheckHost for safesearh domain %s failed cause %s", domain, err)
|
||||||
}
|
}
|
||||||
|
@ -295,7 +296,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||||
d := NewForTest()
|
d := NewForTest()
|
||||||
defer d.Destroy()
|
defer d.Destroy()
|
||||||
domain := "www.google.ru"
|
domain := "www.google.ru"
|
||||||
result, err := d.CheckHost(domain, dns.TypeA)
|
result, err := d.CheckHost(domain, dns.TypeA, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Cannot check host due to %s", err)
|
t.Fatalf("Cannot check host due to %s", err)
|
||||||
}
|
}
|
||||||
|
@ -324,7 +325,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err = d.CheckHost(domain, dns.TypeA)
|
result, err = d.CheckHost(domain, dns.TypeA, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("CheckHost for safesearh domain %s failed cause %s", domain, err)
|
t.Fatalf("CheckHost for safesearh domain %s failed cause %s", domain, err)
|
||||||
}
|
}
|
||||||
|
@ -441,7 +442,7 @@ func TestMatching(t *testing.T) {
|
||||||
d := NewForTestFilters(filters)
|
d := NewForTestFilters(filters)
|
||||||
defer d.Destroy()
|
defer d.Destroy()
|
||||||
|
|
||||||
ret, err := d.CheckHost(test.hostname, dns.TypeA)
|
ret, err := d.CheckHost(test.hostname, dns.TypeA, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Error while matching host %s: %s", test.hostname, err)
|
t.Errorf("Error while matching host %s: %s", test.hostname, err)
|
||||||
}
|
}
|
||||||
|
@ -455,6 +456,52 @@ func TestMatching(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CLIENT SETTINGS
|
||||||
|
|
||||||
|
func applyClientSettings(clientAddr string, setts *RequestFilteringSettings) {
|
||||||
|
setts.FilteringEnabled = false
|
||||||
|
setts.ParentalEnabled = false
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientSettings(t *testing.T) {
|
||||||
|
var r Result
|
||||||
|
filters := make(map[int]string)
|
||||||
|
filters[0] = "||example.org^\n"
|
||||||
|
d := NewForTestFilters(filters)
|
||||||
|
defer d.Destroy()
|
||||||
|
d.ParentalEnabled = true
|
||||||
|
d.ParentalSensitivity = 3
|
||||||
|
|
||||||
|
// no client settings:
|
||||||
|
|
||||||
|
// blocked by filters
|
||||||
|
r, _ = d.CheckHost("example.org", dns.TypeA, "1.1.1.1")
|
||||||
|
if !r.IsFiltered || r.Reason != FilteredBlackList {
|
||||||
|
t.Fatalf("CheckHost FilteredBlackList")
|
||||||
|
}
|
||||||
|
|
||||||
|
// blocked by parental
|
||||||
|
r, _ = d.CheckHost("pornhub.com", dns.TypeA, "1.1.1.1")
|
||||||
|
if !r.IsFiltered || r.Reason != FilteredParental {
|
||||||
|
t.Fatalf("CheckHost FilteredParental")
|
||||||
|
}
|
||||||
|
|
||||||
|
// override client settings:
|
||||||
|
d.FilterHandler = applyClientSettings
|
||||||
|
|
||||||
|
// override filtering settings
|
||||||
|
r, _ = d.CheckHost("example.org", dns.TypeA, "1.1.1.1")
|
||||||
|
if r.IsFiltered {
|
||||||
|
t.Fatalf("CheckHost")
|
||||||
|
}
|
||||||
|
|
||||||
|
// override parental settings
|
||||||
|
r, _ = d.CheckHost("pornhub.com", dns.TypeA, "1.1.1.1")
|
||||||
|
if r.IsFiltered {
|
||||||
|
t.Fatalf("CheckHost")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// BENCHMARKS
|
// BENCHMARKS
|
||||||
|
|
||||||
func BenchmarkSafeBrowsing(b *testing.B) {
|
func BenchmarkSafeBrowsing(b *testing.B) {
|
||||||
|
@ -463,7 +510,7 @@ func BenchmarkSafeBrowsing(b *testing.B) {
|
||||||
d.SafeBrowsingEnabled = true
|
d.SafeBrowsingEnabled = true
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
hostname := "wmconvirus.narod.ru"
|
hostname := "wmconvirus.narod.ru"
|
||||||
ret, err := d.CheckHost(hostname, dns.TypeA)
|
ret, err := d.CheckHost(hostname, dns.TypeA, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Errorf("Error while matching host %s: %s", hostname, err)
|
b.Errorf("Error while matching host %s: %s", hostname, err)
|
||||||
}
|
}
|
||||||
|
@ -480,7 +527,7 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) {
|
||||||
b.RunParallel(func(pb *testing.PB) {
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
for pb.Next() {
|
for pb.Next() {
|
||||||
hostname := "wmconvirus.narod.ru"
|
hostname := "wmconvirus.narod.ru"
|
||||||
ret, err := d.CheckHost(hostname, dns.TypeA)
|
ret, err := d.CheckHost(hostname, dns.TypeA, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Errorf("Error while matching host %s: %s", hostname, err)
|
b.Errorf("Error while matching host %s: %s", hostname, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -378,7 +378,11 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error
|
||||||
var res dnsfilter.Result
|
var res dnsfilter.Result
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
res, err = dnsFilter.CheckHost(host, d.Req.Question[0].Qtype)
|
clientAddr := ""
|
||||||
|
if d.Addr != nil {
|
||||||
|
clientAddr, _, _ = net.SplitHostPort(d.Addr.String())
|
||||||
|
}
|
||||||
|
res, err = dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, clientAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Return immediately if there's an error
|
// Return immediately if there's an error
|
||||||
return nil, errorx.Decorate(err, "dnsfilter failed to check host '%s'", host)
|
return nil, errorx.Decorate(err, "dnsfilter failed to check host '%s'", host)
|
||||||
|
|
Loading…
Reference in New Issue