* refactor: move blocked-services functions to dnsfilter
This commit is contained in:
parent
e8129f15c7
commit
0789e4b20d
|
@ -1,10 +1,9 @@
|
||||||
package home
|
package dnsfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/urlfilter/rules"
|
"github.com/AdguardTeam/urlfilter/rules"
|
||||||
)
|
)
|
||||||
|
@ -119,7 +118,7 @@ var serviceRulesArray = []svc{
|
||||||
}
|
}
|
||||||
|
|
||||||
// convert array to map
|
// convert array to map
|
||||||
func initServices() {
|
func initBlockedServices() {
|
||||||
serviceRules = make(map[string][]*rules.NetworkRule)
|
serviceRules = make(map[string][]*rules.NetworkRule)
|
||||||
for _, s := range serviceRulesArray {
|
for _, s := range serviceRulesArray {
|
||||||
netRules := []*rules.NetworkRule{}
|
netRules := []*rules.NetworkRule{}
|
||||||
|
@ -135,15 +134,20 @@ func initServices() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return TRUE if a blocked service name is known
|
// BlockedSvcKnown - return TRUE if a blocked service name is known
|
||||||
func blockedSvcKnown(s string) bool {
|
func BlockedSvcKnown(s string) bool {
|
||||||
_, ok := serviceRules[s]
|
_, ok := serviceRules[s]
|
||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
// ApplyBlockedServices - set blocked services settings for this DNS request
|
// ApplyBlockedServices - set blocked services settings for this DNS request
|
||||||
func ApplyBlockedServices(setts *dnsfilter.RequestFilteringSettings, list []string) {
|
func (d *Dnsfilter) ApplyBlockedServices(setts *RequestFilteringSettings, list []string, global bool) {
|
||||||
setts.ServicesRules = []dnsfilter.ServiceEntry{}
|
setts.ServicesRules = []ServiceEntry{}
|
||||||
|
if global {
|
||||||
|
d.confLock.RLock()
|
||||||
|
defer d.confLock.RUnlock()
|
||||||
|
list = d.Config.BlockedServices
|
||||||
|
}
|
||||||
for _, name := range list {
|
for _, name := range list {
|
||||||
rules, ok := serviceRules[name]
|
rules, ok := serviceRules[name]
|
||||||
|
|
||||||
|
@ -152,51 +156,45 @@ func ApplyBlockedServices(setts *dnsfilter.RequestFilteringSettings, list []stri
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
s := dnsfilter.ServiceEntry{}
|
s := ServiceEntry{}
|
||||||
s.Name = name
|
s.Name = name
|
||||||
s.Rules = rules
|
s.Rules = rules
|
||||||
setts.ServicesRules = append(setts.ServicesRules, s)
|
setts.ServicesRules = append(setts.ServicesRules, s)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleBlockedServicesList(w http.ResponseWriter, r *http.Request) {
|
func (d *Dnsfilter) handleBlockedServicesList(w http.ResponseWriter, r *http.Request) {
|
||||||
config.RLock()
|
d.confLock.RLock()
|
||||||
list := config.DNS.BlockedServices
|
list := d.Config.BlockedServices
|
||||||
config.RUnlock()
|
d.confLock.RUnlock()
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
err := json.NewEncoder(w).Encode(list)
|
err := json.NewEncoder(w).Encode(list)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, http.StatusInternalServerError, "json.Encode: %s", err)
|
httpError(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleBlockedServicesSet(w http.ResponseWriter, r *http.Request) {
|
func (d *Dnsfilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Request) {
|
||||||
list := []string{}
|
list := []string{}
|
||||||
err := json.NewDecoder(r.Body).Decode(&list)
|
err := json.NewDecoder(r.Body).Decode(&list)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, http.StatusBadRequest, "json.Decode: %s", err)
|
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
config.Lock()
|
d.confLock.Lock()
|
||||||
config.DNS.BlockedServices = list
|
d.Config.BlockedServices = list
|
||||||
config.Unlock()
|
d.confLock.Unlock()
|
||||||
|
|
||||||
log.Debug("Updated blocked services list: %d", len(list))
|
log.Debug("Updated blocked services list: %d", len(list))
|
||||||
|
|
||||||
err = writeAllConfigsAndReloadDNS()
|
d.ConfigModified()
|
||||||
if err != nil {
|
|
||||||
httpError(w, http.StatusBadRequest, "%s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
httpOK(r, w)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterBlockedServicesHandlers - register HTTP handlers
|
// registerBlockedServicesHandlers - register HTTP handlers
|
||||||
func RegisterBlockedServicesHandlers() {
|
func (d *Dnsfilter) registerBlockedServicesHandlers() {
|
||||||
httpRegister(http.MethodGet, "/control/blocked_services/list", handleBlockedServicesList)
|
d.Config.HTTPRegister("GET", "/control/blocked_services/list", d.handleBlockedServicesList)
|
||||||
httpRegister(http.MethodPost, "/control/blocked_services/set", handleBlockedServicesSet)
|
d.Config.HTTPRegister("POST", "/control/blocked_services/set", d.handleBlockedServicesSet)
|
||||||
}
|
}
|
|
@ -49,6 +49,10 @@ type Config struct {
|
||||||
|
|
||||||
Rewrites []RewriteEntry `yaml:"rewrites"`
|
Rewrites []RewriteEntry `yaml:"rewrites"`
|
||||||
|
|
||||||
|
// Names of services to block (globally).
|
||||||
|
// Per-client settings can override this configuration.
|
||||||
|
BlockedServices []string `yaml:"blocked_services"`
|
||||||
|
|
||||||
// Called when the configuration is changed by HTTP request
|
// Called when the configuration is changed by HTTP request
|
||||||
ConfigModified func() `yaml:"-"`
|
ConfigModified func() `yaml:"-"`
|
||||||
|
|
||||||
|
@ -175,6 +179,7 @@ func (d *Dnsfilter) WriteDiskConfig(c *Config) {
|
||||||
d.confLock.Lock()
|
d.confLock.Lock()
|
||||||
*c = d.Config
|
*c = d.Config
|
||||||
c.Rewrites = rewriteArrayDup(d.Config.Rewrites)
|
c.Rewrites = rewriteArrayDup(d.Config.Rewrites)
|
||||||
|
// BlockedServices
|
||||||
d.confLock.Unlock()
|
d.confLock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -633,6 +638,18 @@ func New(c *Config, blockFilters []Filter) *Dnsfilter {
|
||||||
d.prepareRewrites()
|
d.prepareRewrites()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bsvcs := []string{}
|
||||||
|
for _, s := range d.BlockedServices {
|
||||||
|
if !BlockedSvcKnown(s) {
|
||||||
|
log.Debug("skipping unknown blocked-service '%s'", s)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
bsvcs = append(bsvcs, s)
|
||||||
|
}
|
||||||
|
d.BlockedServices = bsvcs
|
||||||
|
|
||||||
|
initBlockedServices()
|
||||||
|
|
||||||
if blockFilters != nil {
|
if blockFilters != nil {
|
||||||
err := d.initFiltering(nil, blockFilters)
|
err := d.initFiltering(nil, blockFilters)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -655,6 +672,7 @@ func (d *Dnsfilter) Start() {
|
||||||
if d.Config.HTTPRegister != nil { // for tests
|
if d.Config.HTTPRegister != nil { // for tests
|
||||||
d.registerSecurityHandlers()
|
d.registerSecurityHandlers()
|
||||||
d.registerRewritesHandlers()
|
d.registerRewritesHandlers()
|
||||||
|
d.registerBlockedServicesHandlers()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/dhcpd"
|
"github.com/AdguardTeam/AdGuardHome/dhcpd"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
@ -161,7 +162,7 @@ func (clients *clientsContainer) addFromConfig(objects []clientObject) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, s := range cy.BlockedServices {
|
for _, s := range cy.BlockedServices {
|
||||||
if !blockedSvcKnown(s) {
|
if !dnsfilter.BlockedSvcKnown(s) {
|
||||||
log.Debug("Clients: skipping unknown blocked-service '%s'", s)
|
log.Debug("Clients: skipping unknown blocked-service '%s'", s)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
|
@ -95,10 +95,6 @@ type dnsConfig struct {
|
||||||
FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists
|
FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists
|
||||||
FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours)
|
FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours)
|
||||||
DnsfilterConf dnsfilter.Config `yaml:",inline"`
|
DnsfilterConf dnsfilter.Config `yaml:",inline"`
|
||||||
|
|
||||||
// Names of services to block (globally).
|
|
||||||
// Per-client settings can override this configuration.
|
|
||||||
BlockedServices []string `yaml:"blocked_services"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type tlsConfigSettings struct {
|
type tlsConfigSettings struct {
|
||||||
|
@ -234,16 +230,6 @@ func parseConfig() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
bsvcs := []string{}
|
|
||||||
for _, s := range config.DNS.BlockedServices {
|
|
||||||
if !blockedSvcKnown(s) {
|
|
||||||
log.Debug("skipping unknown blocked-service '%s'", s)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
bsvcs = append(bsvcs, s)
|
|
||||||
}
|
|
||||||
config.DNS.BlockedServices = bsvcs
|
|
||||||
|
|
||||||
if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) {
|
if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) {
|
||||||
config.DNS.FiltersUpdateIntervalHours = 24
|
config.DNS.FiltersUpdateIntervalHours = 24
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,9 +9,8 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/util"
|
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/util"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/NYTimes/gziphandler"
|
"github.com/NYTimes/gziphandler"
|
||||||
)
|
)
|
||||||
|
@ -174,7 +173,6 @@ func registerControlHandlers() {
|
||||||
httpRegister("GET", "/control/profile", handleGetProfile)
|
httpRegister("GET", "/control/profile", handleGetProfile)
|
||||||
|
|
||||||
RegisterTLSHandlers()
|
RegisterTLSHandlers()
|
||||||
RegisterBlockedServicesHandlers()
|
|
||||||
RegisterAuthHandlers()
|
RegisterAuthHandlers()
|
||||||
|
|
||||||
http.HandleFunc("/dns-query", postInstall(handleDOH))
|
http.HandleFunc("/dns-query", postInstall(handleDOH))
|
||||||
|
|
|
@ -349,7 +349,7 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
setts := Context.dnsFilter.GetConfig()
|
setts := Context.dnsFilter.GetConfig()
|
||||||
setts.FilteringEnabled = true
|
setts.FilteringEnabled = true
|
||||||
ApplyBlockedServices(&setts, config.DNS.BlockedServices)
|
Context.dnsFilter.ApplyBlockedServices(&setts, nil, true)
|
||||||
result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts)
|
result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, http.StatusInternalServerError, "couldn't apply filtering: %s: %s", host, err)
|
httpError(w, http.StatusInternalServerError, "couldn't apply filtering: %s: %s", host, err)
|
||||||
|
|
|
@ -175,7 +175,7 @@ func getUpstreamsByClient(clientAddr string) []upstream.Upstream {
|
||||||
|
|
||||||
// If a client has his own settings, apply them
|
// If a client has his own settings, apply them
|
||||||
func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
|
func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
|
||||||
ApplyBlockedServices(setts, config.DNS.BlockedServices)
|
Context.dnsFilter.ApplyBlockedServices(setts, nil, true)
|
||||||
|
|
||||||
if len(clientAddr) == 0 {
|
if len(clientAddr) == 0 {
|
||||||
return
|
return
|
||||||
|
@ -189,7 +189,7 @@ func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteri
|
||||||
log.Debug("Using settings for client with IP %s", clientAddr)
|
log.Debug("Using settings for client with IP %s", clientAddr)
|
||||||
|
|
||||||
if c.UseOwnBlockedServices {
|
if c.UseOwnBlockedServices {
|
||||||
ApplyBlockedServices(setts, c.BlockedServices)
|
Context.dnsFilter.ApplyBlockedServices(setts, c.BlockedServices, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
setts.ClientTags = c.Tags
|
setts.ClientTags = c.Tags
|
||||||
|
|
|
@ -172,7 +172,6 @@ func run(args options) {
|
||||||
}
|
}
|
||||||
|
|
||||||
initConfig()
|
initConfig()
|
||||||
initServices()
|
|
||||||
|
|
||||||
Context.tlsRoots = util.LoadSystemRootCAs()
|
Context.tlsRoots = util.LoadSystemRootCAs()
|
||||||
Context.transport = &http.Transport{
|
Context.transport = &http.Transport{
|
||||||
|
|
Loading…
Reference in New Issue