* refactor: move blocked-services functions to dnsfilter

This commit is contained in:
Simon Zolin 2020-02-18 20:17:35 +03:00
parent e8129f15c7
commit 0789e4b20d
8 changed files with 51 additions and 51 deletions

View File

@ -1,10 +1,9 @@
package home package dnsfilter
import ( import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter/rules" "github.com/AdguardTeam/urlfilter/rules"
) )
@ -119,7 +118,7 @@ var serviceRulesArray = []svc{
} }
// convert array to map // convert array to map
func initServices() { func initBlockedServices() {
serviceRules = make(map[string][]*rules.NetworkRule) serviceRules = make(map[string][]*rules.NetworkRule)
for _, s := range serviceRulesArray { for _, s := range serviceRulesArray {
netRules := []*rules.NetworkRule{} netRules := []*rules.NetworkRule{}
@ -135,15 +134,20 @@ func initServices() {
} }
} }
// Return TRUE if a blocked service name is known // BlockedSvcKnown - return TRUE if a blocked service name is known
func blockedSvcKnown(s string) bool { func BlockedSvcKnown(s string) bool {
_, ok := serviceRules[s] _, ok := serviceRules[s]
return ok return ok
} }
// ApplyBlockedServices - set blocked services settings for this DNS request // ApplyBlockedServices - set blocked services settings for this DNS request
func ApplyBlockedServices(setts *dnsfilter.RequestFilteringSettings, list []string) { func (d *Dnsfilter) ApplyBlockedServices(setts *RequestFilteringSettings, list []string, global bool) {
setts.ServicesRules = []dnsfilter.ServiceEntry{} setts.ServicesRules = []ServiceEntry{}
if global {
d.confLock.RLock()
defer d.confLock.RUnlock()
list = d.Config.BlockedServices
}
for _, name := range list { for _, name := range list {
rules, ok := serviceRules[name] rules, ok := serviceRules[name]
@ -152,51 +156,45 @@ func ApplyBlockedServices(setts *dnsfilter.RequestFilteringSettings, list []stri
continue continue
} }
s := dnsfilter.ServiceEntry{} s := ServiceEntry{}
s.Name = name s.Name = name
s.Rules = rules s.Rules = rules
setts.ServicesRules = append(setts.ServicesRules, s) setts.ServicesRules = append(setts.ServicesRules, s)
} }
} }
func handleBlockedServicesList(w http.ResponseWriter, r *http.Request) { func (d *Dnsfilter) handleBlockedServicesList(w http.ResponseWriter, r *http.Request) {
config.RLock() d.confLock.RLock()
list := config.DNS.BlockedServices list := d.Config.BlockedServices
config.RUnlock() d.confLock.RUnlock()
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(list) err := json.NewEncoder(w).Encode(list)
if err != nil { if err != nil {
httpError(w, http.StatusInternalServerError, "json.Encode: %s", err) httpError(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
return return
} }
} }
func handleBlockedServicesSet(w http.ResponseWriter, r *http.Request) { func (d *Dnsfilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Request) {
list := []string{} list := []string{}
err := json.NewDecoder(r.Body).Decode(&list) err := json.NewDecoder(r.Body).Decode(&list)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "json.Decode: %s", err) httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return return
} }
config.Lock() d.confLock.Lock()
config.DNS.BlockedServices = list d.Config.BlockedServices = list
config.Unlock() d.confLock.Unlock()
log.Debug("Updated blocked services list: %d", len(list)) log.Debug("Updated blocked services list: %d", len(list))
err = writeAllConfigsAndReloadDNS() d.ConfigModified()
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return
}
httpOK(r, w)
} }
// RegisterBlockedServicesHandlers - register HTTP handlers // registerBlockedServicesHandlers - register HTTP handlers
func RegisterBlockedServicesHandlers() { func (d *Dnsfilter) registerBlockedServicesHandlers() {
httpRegister(http.MethodGet, "/control/blocked_services/list", handleBlockedServicesList) d.Config.HTTPRegister("GET", "/control/blocked_services/list", d.handleBlockedServicesList)
httpRegister(http.MethodPost, "/control/blocked_services/set", handleBlockedServicesSet) d.Config.HTTPRegister("POST", "/control/blocked_services/set", d.handleBlockedServicesSet)
} }

View File

@ -49,6 +49,10 @@ type Config struct {
Rewrites []RewriteEntry `yaml:"rewrites"` Rewrites []RewriteEntry `yaml:"rewrites"`
// Names of services to block (globally).
// Per-client settings can override this configuration.
BlockedServices []string `yaml:"blocked_services"`
// Called when the configuration is changed by HTTP request // Called when the configuration is changed by HTTP request
ConfigModified func() `yaml:"-"` ConfigModified func() `yaml:"-"`
@ -175,6 +179,7 @@ func (d *Dnsfilter) WriteDiskConfig(c *Config) {
d.confLock.Lock() d.confLock.Lock()
*c = d.Config *c = d.Config
c.Rewrites = rewriteArrayDup(d.Config.Rewrites) c.Rewrites = rewriteArrayDup(d.Config.Rewrites)
// BlockedServices
d.confLock.Unlock() d.confLock.Unlock()
} }
@ -633,6 +638,18 @@ func New(c *Config, blockFilters []Filter) *Dnsfilter {
d.prepareRewrites() d.prepareRewrites()
} }
bsvcs := []string{}
for _, s := range d.BlockedServices {
if !BlockedSvcKnown(s) {
log.Debug("skipping unknown blocked-service '%s'", s)
continue
}
bsvcs = append(bsvcs, s)
}
d.BlockedServices = bsvcs
initBlockedServices()
if blockFilters != nil { if blockFilters != nil {
err := d.initFiltering(nil, blockFilters) err := d.initFiltering(nil, blockFilters)
if err != nil { if err != nil {
@ -655,6 +672,7 @@ func (d *Dnsfilter) Start() {
if d.Config.HTTPRegister != nil { // for tests if d.Config.HTTPRegister != nil { // for tests
d.registerSecurityHandlers() d.registerSecurityHandlers()
d.registerRewritesHandlers() d.registerRewritesHandlers()
d.registerBlockedServicesHandlers()
} }
} }

View File

@ -14,6 +14,7 @@ import (
"time" "time"
"github.com/AdguardTeam/AdGuardHome/dhcpd" "github.com/AdguardTeam/AdGuardHome/dhcpd"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/dnsforward" "github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
@ -161,7 +162,7 @@ func (clients *clientsContainer) addFromConfig(objects []clientObject) {
} }
for _, s := range cy.BlockedServices { for _, s := range cy.BlockedServices {
if !blockedSvcKnown(s) { if !dnsfilter.BlockedSvcKnown(s) {
log.Debug("Clients: skipping unknown blocked-service '%s'", s) log.Debug("Clients: skipping unknown blocked-service '%s'", s)
continue continue
} }

View File

@ -95,10 +95,6 @@ type dnsConfig struct {
FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists
FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours) FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours)
DnsfilterConf dnsfilter.Config `yaml:",inline"` DnsfilterConf dnsfilter.Config `yaml:",inline"`
// Names of services to block (globally).
// Per-client settings can override this configuration.
BlockedServices []string `yaml:"blocked_services"`
} }
type tlsConfigSettings struct { type tlsConfigSettings struct {
@ -234,16 +230,6 @@ func parseConfig() error {
return err return err
} }
bsvcs := []string{}
for _, s := range config.DNS.BlockedServices {
if !blockedSvcKnown(s) {
log.Debug("skipping unknown blocked-service '%s'", s)
continue
}
bsvcs = append(bsvcs, s)
}
config.DNS.BlockedServices = bsvcs
if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) { if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) {
config.DNS.FiltersUpdateIntervalHours = 24 config.DNS.FiltersUpdateIntervalHours = 24
} }

View File

@ -9,9 +9,8 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/AdGuardHome/dnsforward" "github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/NYTimes/gziphandler" "github.com/NYTimes/gziphandler"
) )
@ -174,7 +173,6 @@ func registerControlHandlers() {
httpRegister("GET", "/control/profile", handleGetProfile) httpRegister("GET", "/control/profile", handleGetProfile)
RegisterTLSHandlers() RegisterTLSHandlers()
RegisterBlockedServicesHandlers()
RegisterAuthHandlers() RegisterAuthHandlers()
http.HandleFunc("/dns-query", postInstall(handleDOH)) http.HandleFunc("/dns-query", postInstall(handleDOH))

View File

@ -349,7 +349,7 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
setts := Context.dnsFilter.GetConfig() setts := Context.dnsFilter.GetConfig()
setts.FilteringEnabled = true setts.FilteringEnabled = true
ApplyBlockedServices(&setts, config.DNS.BlockedServices) Context.dnsFilter.ApplyBlockedServices(&setts, nil, true)
result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts) result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts)
if err != nil { if err != nil {
httpError(w, http.StatusInternalServerError, "couldn't apply filtering: %s: %s", host, err) httpError(w, http.StatusInternalServerError, "couldn't apply filtering: %s: %s", host, err)

View File

@ -175,7 +175,7 @@ func getUpstreamsByClient(clientAddr string) []upstream.Upstream {
// If a client has his own settings, apply them // If a client has his own settings, apply them
func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) { func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
ApplyBlockedServices(setts, config.DNS.BlockedServices) Context.dnsFilter.ApplyBlockedServices(setts, nil, true)
if len(clientAddr) == 0 { if len(clientAddr) == 0 {
return return
@ -189,7 +189,7 @@ func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteri
log.Debug("Using settings for client with IP %s", clientAddr) log.Debug("Using settings for client with IP %s", clientAddr)
if c.UseOwnBlockedServices { if c.UseOwnBlockedServices {
ApplyBlockedServices(setts, c.BlockedServices) Context.dnsFilter.ApplyBlockedServices(setts, c.BlockedServices, false)
} }
setts.ClientTags = c.Tags setts.ClientTags = c.Tags

View File

@ -172,7 +172,6 @@ func run(args options) {
} }
initConfig() initConfig()
initServices()
Context.tlsRoots = util.LoadSystemRootCAs() Context.tlsRoots = util.LoadSystemRootCAs()
Context.transport = &http.Transport{ Context.transport = &http.Transport{