diff --git a/AGHTechDoc.md b/AGHTechDoc.md index b1303621..a91e25c1 100644 --- a/AGHTechDoc.md +++ b/AGHTechDoc.md @@ -946,7 +946,7 @@ Response: { allowed_clients: ["127.0.0.1", ...] disallowed_clients: ["127.0.0.1", ...] - blocked_hosts: ["host.com", ...] + blocked_hosts: ["host.com", ...] // host name or a wildcard } diff --git a/client/src/__locales/en.json b/client/src/__locales/en.json index a4f2293e..f0bfb04f 100644 --- a/client/src/__locales/en.json +++ b/client/src/__locales/en.json @@ -357,7 +357,7 @@ "access_disallowed_title": "Disallowed clients", "access_disallowed_desc": "A list of CIDR or IP addresses. If configured, AdGuard Home will drop requests from these IP addresses.", "access_blocked_title": "Blocked domains", - "access_blocked_desc": "Don't confuse this with filters. AdGuard Home will drop DNS queries with these domains in query's question.", + "access_blocked_desc": "Don't confuse this with filters. AdGuard Home will drop DNS queries with these domains in query's question. Here you can specify the exact domain names, wildcards and urlfilter-rules, e.g. 'example.org', '*.example.org' or '||example.org^'.", "access_settings_saved": "Access settings successfully saved", "updates_checked": "Updates successfully checked", "updates_version_equal": "AdGuard Home is up-to-date", diff --git a/dnsforward/access.go b/dnsforward/access.go index 6868c6d7..6eed4337 100644 --- a/dnsforward/access.go +++ b/dnsforward/access.go @@ -2,11 +2,15 @@ package dnsforward import ( "encoding/json" + "fmt" "net" "net/http" + "strings" "sync" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/urlfilter" + "github.com/AdguardTeam/urlfilter/filterlist" ) type accessCtx struct { @@ -18,7 +22,7 @@ type accessCtx struct { allowedClientsIPNet []net.IPNet // CIDRs of whitelist clients disallowedClientsIPNet []net.IPNet // CIDRs of clients that should be blocked - blockedHosts map[string]bool // hosts that should be blocked + blockedHostsEngine *urlfilter.DNSEngine // finds hosts that should be blocked } func (a *accessCtx) Init(allowedClients, disallowedClients, blockedHosts []string) error { @@ -32,15 +36,26 @@ func (a *accessCtx) Init(allowedClients, disallowedClients, blockedHosts []strin return err } - convertArrayToMap(&a.blockedHosts, blockedHosts) - return nil -} - -func convertArrayToMap(dst *map[string]bool, src []string) { - *dst = make(map[string]bool) - for _, s := range src { - (*dst)[s] = true + buf := strings.Builder{} + for _, s := range blockedHosts { + buf.WriteString(s) + buf.WriteString("\n") } + + listArray := []filterlist.RuleList{} + list := &filterlist.StringRuleList{ + ID: int(0), + RulesText: buf.String(), + IgnoreCosmetic: true, + } + listArray = append(listArray, list) + rulesStorage, err := filterlist.NewRuleStorage(listArray) + if err != nil { + return fmt.Errorf("filterlist.NewRuleStorage(): %s", err) + } + a.blockedHostsEngine = urlfilter.NewDNSEngine(rulesStorage) + + return nil } // Split array of IP or CIDR into 2 containers for fast search @@ -107,7 +122,7 @@ func (a *accessCtx) IsBlockedIP(ip string) bool { // IsBlockedDomain - return TRUE if this domain should be blocked func (a *accessCtx) IsBlockedDomain(host string) bool { a.lock.Lock() - _, ok := a.blockedHosts[host] + _, ok := a.blockedHostsEngine.Match(host, nil) a.lock.Unlock() return ok } diff --git a/dnsforward/dnsforward_test.go b/dnsforward/dnsforward_test.go index 49dcc6ae..60b606bf 100644 --- a/dnsforward/dnsforward_test.go +++ b/dnsforward/dnsforward_test.go @@ -796,11 +796,26 @@ func TestIsBlockedIPDisallowed(t *testing.T) { func TestIsBlockedIPBlockedDomain(t *testing.T) { a := &accessCtx{} - assert.True(t, a.Init(nil, nil, []string{"host1", "host2"}) == nil) + assert.True(t, a.Init(nil, nil, []string{"host1", + "host2", + "*.host.com", + "||host3.com^", + }) == nil) + // match by "host2.com" assert.True(t, a.IsBlockedDomain("host1")) assert.True(t, a.IsBlockedDomain("host2")) assert.True(t, !a.IsBlockedDomain("host3")) + + // match by wildcard "*.host.com" + assert.True(t, !a.IsBlockedDomain("host.com")) + assert.True(t, a.IsBlockedDomain("asdf.host.com")) + assert.True(t, a.IsBlockedDomain("qwer.asdf.host.com")) + assert.True(t, !a.IsBlockedDomain("asdf.zhost.com")) + + // match by wildcard "||host3.com^" + assert.True(t, a.IsBlockedDomain("host3.com")) + assert.True(t, a.IsBlockedDomain("asdf.host3.com")) } func TestValidateUpstream(t *testing.T) {