diff --git a/home/blocked_services.go b/dnsfilter/blocked_services.go similarity index 77% rename from home/blocked_services.go rename to dnsfilter/blocked_services.go index 5c14499d..4969b2d9 100644 --- a/home/blocked_services.go +++ b/dnsfilter/blocked_services.go @@ -1,10 +1,9 @@ -package home +package dnsfilter import ( "encoding/json" "net/http" - "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/urlfilter/rules" ) @@ -119,7 +118,7 @@ var serviceRulesArray = []svc{ } // convert array to map -func initServices() { +func initBlockedServices() { serviceRules = make(map[string][]*rules.NetworkRule) for _, s := range serviceRulesArray { netRules := []*rules.NetworkRule{} @@ -135,15 +134,20 @@ func initServices() { } } -// Return TRUE if a blocked service name is known -func blockedSvcKnown(s string) bool { +// BlockedSvcKnown - return TRUE if a blocked service name is known +func BlockedSvcKnown(s string) bool { _, ok := serviceRules[s] return ok } // ApplyBlockedServices - set blocked services settings for this DNS request -func ApplyBlockedServices(setts *dnsfilter.RequestFilteringSettings, list []string) { - setts.ServicesRules = []dnsfilter.ServiceEntry{} +func (d *Dnsfilter) ApplyBlockedServices(setts *RequestFilteringSettings, list []string, global bool) { + setts.ServicesRules = []ServiceEntry{} + if global { + d.confLock.RLock() + defer d.confLock.RUnlock() + list = d.Config.BlockedServices + } for _, name := range list { rules, ok := serviceRules[name] @@ -152,51 +156,45 @@ func ApplyBlockedServices(setts *dnsfilter.RequestFilteringSettings, list []stri continue } - s := dnsfilter.ServiceEntry{} + s := ServiceEntry{} s.Name = name s.Rules = rules setts.ServicesRules = append(setts.ServicesRules, s) } } -func handleBlockedServicesList(w http.ResponseWriter, r *http.Request) { - config.RLock() - list := config.DNS.BlockedServices - config.RUnlock() +func (d *Dnsfilter) handleBlockedServicesList(w http.ResponseWriter, r *http.Request) { + d.confLock.RLock() + list := d.Config.BlockedServices + d.confLock.RUnlock() w.Header().Set("Content-Type", "application/json") err := json.NewEncoder(w).Encode(list) if err != nil { - httpError(w, http.StatusInternalServerError, "json.Encode: %s", err) + httpError(r, w, http.StatusInternalServerError, "json.Encode: %s", err) return } } -func handleBlockedServicesSet(w http.ResponseWriter, r *http.Request) { +func (d *Dnsfilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Request) { list := []string{} err := json.NewDecoder(r.Body).Decode(&list) if err != nil { - httpError(w, http.StatusBadRequest, "json.Decode: %s", err) + httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err) return } - config.Lock() - config.DNS.BlockedServices = list - config.Unlock() + d.confLock.Lock() + d.Config.BlockedServices = list + d.confLock.Unlock() log.Debug("Updated blocked services list: %d", len(list)) - err = writeAllConfigsAndReloadDNS() - if err != nil { - httpError(w, http.StatusBadRequest, "%s", err) - return - } - - httpOK(r, w) + d.ConfigModified() } -// RegisterBlockedServicesHandlers - register HTTP handlers -func RegisterBlockedServicesHandlers() { - httpRegister(http.MethodGet, "/control/blocked_services/list", handleBlockedServicesList) - httpRegister(http.MethodPost, "/control/blocked_services/set", handleBlockedServicesSet) +// registerBlockedServicesHandlers - register HTTP handlers +func (d *Dnsfilter) registerBlockedServicesHandlers() { + d.Config.HTTPRegister("GET", "/control/blocked_services/list", d.handleBlockedServicesList) + d.Config.HTTPRegister("POST", "/control/blocked_services/set", d.handleBlockedServicesSet) } diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index adfabd9c..3e2b8f0c 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -49,6 +49,10 @@ type Config struct { Rewrites []RewriteEntry `yaml:"rewrites"` + // Names of services to block (globally). + // Per-client settings can override this configuration. + BlockedServices []string `yaml:"blocked_services"` + // Called when the configuration is changed by HTTP request ConfigModified func() `yaml:"-"` @@ -175,6 +179,7 @@ func (d *Dnsfilter) WriteDiskConfig(c *Config) { d.confLock.Lock() *c = d.Config c.Rewrites = rewriteArrayDup(d.Config.Rewrites) + // BlockedServices d.confLock.Unlock() } @@ -633,6 +638,18 @@ func New(c *Config, blockFilters []Filter) *Dnsfilter { d.prepareRewrites() } + bsvcs := []string{} + for _, s := range d.BlockedServices { + if !BlockedSvcKnown(s) { + log.Debug("skipping unknown blocked-service '%s'", s) + continue + } + bsvcs = append(bsvcs, s) + } + d.BlockedServices = bsvcs + + initBlockedServices() + if blockFilters != nil { err := d.initFiltering(nil, blockFilters) if err != nil { @@ -655,6 +672,7 @@ func (d *Dnsfilter) Start() { if d.Config.HTTPRegister != nil { // for tests d.registerSecurityHandlers() d.registerRewritesHandlers() + d.registerBlockedServicesHandlers() } } diff --git a/home/clients.go b/home/clients.go index 271ad9b5..c7571a5f 100644 --- a/home/clients.go +++ b/home/clients.go @@ -14,6 +14,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/dhcpd" + "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/dnsforward" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/log" @@ -161,7 +162,7 @@ func (clients *clientsContainer) addFromConfig(objects []clientObject) { } for _, s := range cy.BlockedServices { - if !blockedSvcKnown(s) { + if !dnsfilter.BlockedSvcKnown(s) { log.Debug("Clients: skipping unknown blocked-service '%s'", s) continue } diff --git a/home/config.go b/home/config.go index 657325e2..41df6829 100644 --- a/home/config.go +++ b/home/config.go @@ -95,10 +95,6 @@ type dnsConfig struct { FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours) DnsfilterConf dnsfilter.Config `yaml:",inline"` - - // Names of services to block (globally). - // Per-client settings can override this configuration. - BlockedServices []string `yaml:"blocked_services"` } type tlsConfigSettings struct { @@ -234,16 +230,6 @@ func parseConfig() error { return err } - bsvcs := []string{} - for _, s := range config.DNS.BlockedServices { - if !blockedSvcKnown(s) { - log.Debug("skipping unknown blocked-service '%s'", s) - continue - } - bsvcs = append(bsvcs, s) - } - config.DNS.BlockedServices = bsvcs - if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) { config.DNS.FiltersUpdateIntervalHours = 24 } diff --git a/home/control.go b/home/control.go index 82758074..571b0708 100644 --- a/home/control.go +++ b/home/control.go @@ -9,9 +9,8 @@ import ( "strconv" "strings" - "github.com/AdguardTeam/AdGuardHome/util" - "github.com/AdguardTeam/AdGuardHome/dnsforward" + "github.com/AdguardTeam/AdGuardHome/util" "github.com/AdguardTeam/golibs/log" "github.com/NYTimes/gziphandler" ) @@ -174,7 +173,6 @@ func registerControlHandlers() { httpRegister("GET", "/control/profile", handleGetProfile) RegisterTLSHandlers() - RegisterBlockedServicesHandlers() RegisterAuthHandlers() http.HandleFunc("/dns-query", postInstall(handleDOH)) diff --git a/home/control_filtering.go b/home/control_filtering.go index 242bbf3d..837d2a45 100644 --- a/home/control_filtering.go +++ b/home/control_filtering.go @@ -349,7 +349,7 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) { setts := Context.dnsFilter.GetConfig() setts.FilteringEnabled = true - ApplyBlockedServices(&setts, config.DNS.BlockedServices) + Context.dnsFilter.ApplyBlockedServices(&setts, nil, true) result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts) if err != nil { httpError(w, http.StatusInternalServerError, "couldn't apply filtering: %s: %s", host, err) diff --git a/home/dns.go b/home/dns.go index bc97bd66..81eb4fd1 100644 --- a/home/dns.go +++ b/home/dns.go @@ -175,7 +175,7 @@ func getUpstreamsByClient(clientAddr string) []upstream.Upstream { // If a client has his own settings, apply them func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) { - ApplyBlockedServices(setts, config.DNS.BlockedServices) + Context.dnsFilter.ApplyBlockedServices(setts, nil, true) if len(clientAddr) == 0 { return @@ -189,7 +189,7 @@ func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteri log.Debug("Using settings for client with IP %s", clientAddr) if c.UseOwnBlockedServices { - ApplyBlockedServices(setts, c.BlockedServices) + Context.dnsFilter.ApplyBlockedServices(setts, c.BlockedServices, false) } setts.ClientTags = c.Tags diff --git a/home/home.go b/home/home.go index 63a01b0d..d4a4f215 100644 --- a/home/home.go +++ b/home/home.go @@ -172,7 +172,6 @@ func run(args options) { } initConfig() - initServices() Context.tlsRoots = util.LoadSystemRootCAs() Context.transport = &http.Transport{