diff --git a/dnsforward/access.go b/dnsforward/access.go new file mode 100644 index 00000000..e27ddb1f --- /dev/null +++ b/dnsforward/access.go @@ -0,0 +1,188 @@ +package dnsforward + +import ( + "encoding/json" + "net" + "net/http" + "sync" + + "github.com/AdguardTeam/golibs/log" +) + +type accessCtx struct { + lock sync.Mutex + + allowedClients map[string]bool // IP addresses of whitelist clients + disallowedClients map[string]bool // IP addresses of clients that should be blocked + + allowedClientsIPNet []net.IPNet // CIDRs of whitelist clients + disallowedClientsIPNet []net.IPNet // CIDRs of clients that should be blocked + + blockedHosts map[string]bool // hosts that should be blocked +} + +func (a *accessCtx) Init(allowedClients, disallowedClients, blockedHosts []string) error { + err := processIPCIDRArray(&a.allowedClients, &a.allowedClientsIPNet, allowedClients) + if err != nil { + return err + } + + err = processIPCIDRArray(&a.disallowedClients, &a.disallowedClientsIPNet, disallowedClients) + if err != nil { + return err + } + + convertArrayToMap(&a.blockedHosts, blockedHosts) + return nil +} + +func convertArrayToMap(dst *map[string]bool, src []string) { + *dst = make(map[string]bool) + for _, s := range src { + (*dst)[s] = true + } +} + +// Split array of IP or CIDR into 2 containers for fast search +func processIPCIDRArray(dst *map[string]bool, dstIPNet *[]net.IPNet, src []string) error { + *dst = make(map[string]bool) + + for _, s := range src { + ip := net.ParseIP(s) + if ip != nil { + (*dst)[s] = true + continue + } + + _, ipnet, err := net.ParseCIDR(s) + if err != nil { + return err + } + *dstIPNet = append(*dstIPNet, *ipnet) + } + + return nil +} + +// IsBlockedIP - return TRUE if this client should be blocked +func (a *accessCtx) IsBlockedIP(ip string) bool { + a.lock.Lock() + defer a.lock.Unlock() + + if len(a.allowedClients) != 0 || len(a.allowedClientsIPNet) != 0 { + _, ok := a.allowedClients[ip] + if ok { + return false + } + + if len(a.allowedClientsIPNet) != 0 { + ipAddr := net.ParseIP(ip) + for _, ipnet := range a.allowedClientsIPNet { + if ipnet.Contains(ipAddr) { + return false + } + } + } + + return true + } + + _, ok := a.disallowedClients[ip] + if ok { + return true + } + + if len(a.disallowedClientsIPNet) != 0 { + ipAddr := net.ParseIP(ip) + for _, ipnet := range a.disallowedClientsIPNet { + if ipnet.Contains(ipAddr) { + return true + } + } + } + + return false +} + +// IsBlockedDomain - return TRUE if this domain should be blocked +func (a *accessCtx) IsBlockedDomain(host string) bool { + a.lock.Lock() + _, ok := a.blockedHosts[host] + a.lock.Unlock() + return ok +} + +type accessListJSON struct { + AllowedClients []string `json:"allowed_clients"` + DisallowedClients []string `json:"disallowed_clients"` + BlockedHosts []string `json:"blocked_hosts"` +} + +func (s *Server) handleAccessList(w http.ResponseWriter, r *http.Request) { + s.Lock() + j := accessListJSON{ + AllowedClients: s.conf.AllowedClients, + DisallowedClients: s.conf.DisallowedClients, + BlockedHosts: s.conf.BlockedHosts, + } + s.Unlock() + + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(j) + if err != nil { + httpError(r, w, http.StatusInternalServerError, "json.Encode: %s", err) + return + } +} + +func checkIPCIDRArray(src []string) error { + for _, s := range src { + ip := net.ParseIP(s) + if ip != nil { + continue + } + + _, _, err := net.ParseCIDR(s) + if err != nil { + return err + } + } + + return nil +} + +func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) { + j := accessListJSON{} + err := json.NewDecoder(r.Body).Decode(&j) + if err != nil { + httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err) + return + } + + err = checkIPCIDRArray(j.AllowedClients) + if err == nil { + err = checkIPCIDRArray(j.DisallowedClients) + } + if err != nil { + httpError(r, w, http.StatusBadRequest, "%s", err) + return + } + + a := &accessCtx{} + err = a.Init(j.AllowedClients, j.DisallowedClients, j.BlockedHosts) + if err != nil { + httpError(r, w, http.StatusBadRequest, "access.Init: %s", err) + return + } + + s.Lock() + s.conf.AllowedClients = j.AllowedClients + s.conf.DisallowedClients = j.DisallowedClients + s.conf.BlockedHosts = j.BlockedHosts + s.access = a + s.Unlock() + s.conf.ConfigModified() + + log.Debug("Access: updated lists: %d, %d, %d", + len(j.AllowedClients), len(j.DisallowedClients), len(j.BlockedHosts)) +} diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index cb86c235..94de90cf 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -3,6 +3,7 @@ package dnsforward import ( "crypto/tls" "errors" + "fmt" "net" "net/http" "runtime" @@ -28,6 +29,12 @@ const ( parentalBlockHost = "family-block.dns.adguard.com" ) +var defaultDNS = []string{ + "https://1.1.1.1/dns-query", + "https://1.0.0.1/dns-query", +} +var defaultBootstrap = []string{"1.1.1.1", "1.0.0.1"} + // Server is the main way to start a DNS server. // // Example: @@ -43,12 +50,9 @@ type Server struct { dnsFilter *dnsfilter.Dnsfilter // DNS filter instance queryLog querylog.QueryLog // Query log instance stats stats.Stats + access *accessCtx - AllowedClients map[string]bool // IP addresses of whitelist clients - DisallowedClients map[string]bool // IP addresses of clients that should be blocked - AllowedClientsIPNet []net.IPNet // CIDRs of whitelist clients - DisallowedClientsIPNet []net.IPNet // CIDRs of clients that should be blocked - BlockedHosts map[string]bool // hosts that should be blocked + webRegistered bool sync.RWMutex conf ServerConfig @@ -61,9 +65,15 @@ func NewServer(dnsFilter *dnsfilter.Dnsfilter, stats stats.Stats, queryLog query s.dnsFilter = dnsFilter s.stats = stats s.queryLog = queryLog + + if runtime.GOARCH == "mips" || runtime.GOARCH == "mipsle" { + // Use plain DNS on MIPS, encryption is too slow + defaultDNS = []string{"1.1.1.1", "1.0.0.1"} + } return s } +// Close - close object func (s *Server) Close() { s.Lock() s.dnsFilter = nil @@ -72,20 +82,23 @@ func (s *Server) Close() { s.Unlock() } +// WriteDiskConfig - write configuration +func (s *Server) WriteDiskConfig(c *FilteringConfig) { + s.Lock() + *c = s.conf.FilteringConfig + s.Unlock() +} + // FilteringConfig represents the DNS filtering configuration of AdGuard Home // The zero FilteringConfig is empty and ready for use. type FilteringConfig struct { // Filtering callback function FilterHandler func(clientAddr string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"` - ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features - FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists - FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours) + ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features BlockingMode string `yaml:"blocking_mode"` // mode how to answer filtered requests BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` // if 0, then default is used (3600) - QueryLogEnabled bool `yaml:"querylog_enabled"` // if true, query log is enabled - QueryLogInterval uint32 `yaml:"querylog_interval"` // time interval for query log (in days) Ratelimit int `yaml:"ratelimit"` // max number of requests per second from a given IP (0 to disable) RatelimitWhitelist []string `yaml:"ratelimit_whitelist"` // a list of whitelisted client IP addresses RefuseAny bool `yaml:"refuse_any"` // if true, refuse ANY requests @@ -100,13 +113,8 @@ type FilteringConfig struct { ParentalBlockHost string `yaml:"parental_block_host"` SafeBrowsingBlockHost string `yaml:"safebrowsing_block_host"` - // Names of services to block (globally). - // Per-client settings can override this configuration. - BlockedServices []string `yaml:"blocked_services"` - - CacheSize uint `yaml:"cache_size"` // DNS cache size (in bytes) - - DnsfilterConf dnsfilter.Config `yaml:",inline"` + CacheSize uint `yaml:"cache_size"` // DNS cache size (in bytes) + UpstreamDNS []string `yaml:"upstream_dns"` } // TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS @@ -133,6 +141,12 @@ type ServerConfig struct { FilteringConfig TLSConfig + + // Called when the configuration is changed by HTTP request + ConfigModified func() + + // Register an HTTP handler + HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) } // if any of ServerConfig values are zero, then default values from below are used @@ -142,19 +156,6 @@ var defaultValues = ServerConfig{ FilteringConfig: FilteringConfig{BlockedResponseTTL: 3600}, } -func init() { - defaultDNS := []string{"8.8.8.8:53", "8.8.4.4:53"} - - defaultUpstreams := make([]upstream.Upstream, 0) - for _, addr := range defaultDNS { - u, err := upstream.AddressToUpstream(addr, upstream.Options{Timeout: DefaultTimeout}) - if err == nil { - defaultUpstreams = append(defaultUpstreams, u) - } - } - defaultValues.Upstreams = defaultUpstreams -} - // Start starts the DNS server func (s *Server) Start(config *ServerConfig) error { s.Lock() @@ -162,34 +163,6 @@ func (s *Server) Start(config *ServerConfig) error { return s.startInternal(config) } -func convertArrayToMap(dst *map[string]bool, src []string) { - *dst = make(map[string]bool) - for _, s := range src { - (*dst)[s] = true - } -} - -// Split array of IP or CIDR into 2 containers for fast search -func processIPCIDRArray(dst *map[string]bool, dstIPNet *[]net.IPNet, src []string) error { - *dst = make(map[string]bool) - - for _, s := range src { - ip := net.ParseIP(s) - if ip != nil { - (*dst)[s] = true - continue - } - - _, ipnet, err := net.ParseCIDR(s) - if err != nil { - return err - } - *dstIPNet = append(*dstIPNet, *ipnet) - } - - return nil -} - // startInternal starts without locking func (s *Server) startInternal(config *ServerConfig) error { if s.dnsProxy != nil { @@ -199,12 +172,33 @@ func (s *Server) startInternal(config *ServerConfig) error { if config != nil { s.conf = *config } + + if len(s.conf.UpstreamDNS) == 0 { + s.conf.UpstreamDNS = defaultDNS + } + if len(s.conf.BootstrapDNS) == 0 { + s.conf.BootstrapDNS = defaultBootstrap + } + + upstreamConfig, err := proxy.ParseUpstreamsConfig(s.conf.UpstreamDNS, s.conf.BootstrapDNS, DefaultTimeout) + if err != nil { + return fmt.Errorf("DNS: proxy.ParseUpstreamsConfig: %s", err) + } + s.conf.Upstreams = upstreamConfig.Upstreams + s.conf.DomainsReservedUpstreams = upstreamConfig.DomainReservedUpstreams + if len(s.conf.ParentalBlockHost) == 0 { s.conf.ParentalBlockHost = parentalBlockHost } if len(s.conf.SafeBrowsingBlockHost) == 0 { s.conf.SafeBrowsingBlockHost = safeBrowsingBlockHost } + if s.conf.UDPListenAddr == nil { + s.conf.UDPListenAddr = defaultValues.UDPListenAddr + } + if s.conf.TCPListenAddr == nil { + s.conf.TCPListenAddr = defaultValues.TCPListenAddr + } proxyConfig := proxy.Config{ UDPListenAddr: s.conf.UDPListenAddr, @@ -221,18 +215,12 @@ func (s *Server) startInternal(config *ServerConfig) error { AllServers: s.conf.AllServers, } - err := processIPCIDRArray(&s.AllowedClients, &s.AllowedClientsIPNet, s.conf.AllowedClients) + s.access = &accessCtx{} + err = s.access.Init(s.conf.AllowedClients, s.conf.DisallowedClients, s.conf.BlockedHosts) if err != nil { return err } - err = processIPCIDRArray(&s.DisallowedClients, &s.DisallowedClientsIPNet, s.conf.DisallowedClients) - if err != nil { - return err - } - - convertArrayToMap(&s.BlockedHosts, s.conf.BlockedHosts) - if s.conf.TLSListenAddr != nil && len(s.conf.CertificateChainData) != 0 && len(s.conf.PrivateKeyData) != 0 { proxyConfig.TLSListenAddr = s.conf.TLSListenAddr keypair, err := tls.X509KeyPair(s.conf.CertificateChainData, s.conf.PrivateKeyData) @@ -245,16 +233,13 @@ func (s *Server) startInternal(config *ServerConfig) error { } } - if proxyConfig.UDPListenAddr == nil { - proxyConfig.UDPListenAddr = defaultValues.UDPListenAddr - } - - if proxyConfig.TCPListenAddr == nil { - proxyConfig.TCPListenAddr = defaultValues.TCPListenAddr - } - if len(proxyConfig.Upstreams) == 0 { - proxyConfig.Upstreams = defaultValues.Upstreams + log.Fatal("len(proxyConfig.Upstreams) == 0") + } + + if !s.webRegistered && s.conf.HTTPRegister != nil { + s.webRegistered = true + s.registerHandlers() } // Initialize and start the DNS proxy @@ -293,6 +278,23 @@ func (s *Server) IsRunning() bool { return isRunning } +// Restart - restart server +func (s *Server) Restart() error { + s.Lock() + defer s.Unlock() + log.Print("Start reconfiguring the server") + err := s.stopInternal() + if err != nil { + return errorx.Decorate(err, "could not reconfigure the server") + } + err = s.startInternal(nil) + if err != nil { + return errorx.Decorate(err, "could not reconfigure the server") + } + + return nil +} + // Reconfigure applies the new configuration to the DNS server func (s *Server) Reconfigure(config *ServerConfig) error { s.Lock() @@ -324,59 +326,16 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.RUnlock() } -// Return TRUE if this client should be blocked -func (s *Server) isBlockedIP(ip string) bool { - if len(s.AllowedClients) != 0 || len(s.AllowedClientsIPNet) != 0 { - _, ok := s.AllowedClients[ip] - if ok { - return false - } - - if len(s.AllowedClientsIPNet) != 0 { - ipAddr := net.ParseIP(ip) - for _, ipnet := range s.AllowedClientsIPNet { - if ipnet.Contains(ipAddr) { - return false - } - } - } - - return true - } - - _, ok := s.DisallowedClients[ip] - if ok { - return true - } - - if len(s.DisallowedClientsIPNet) != 0 { - ipAddr := net.ParseIP(ip) - for _, ipnet := range s.DisallowedClientsIPNet { - if ipnet.Contains(ipAddr) { - return true - } - } - } - - return false -} - -// Return TRUE if this domain should be blocked -func (s *Server) isBlockedDomain(host string) bool { - _, ok := s.BlockedHosts[host] - return ok -} - func (s *Server) beforeRequestHandler(p *proxy.Proxy, d *proxy.DNSContext) (bool, error) { ip, _, _ := net.SplitHostPort(d.Addr.String()) - if s.isBlockedIP(ip) { + if s.access.IsBlockedIP(ip) { log.Tracef("Client IP %s is blocked by settings", ip) return false, nil } if len(d.Req.Question) == 1 { host := strings.TrimSuffix(d.Req.Question[0].Name, ".") - if s.isBlockedDomain(host) { + if s.access.IsBlockedDomain(host) { log.Tracef("Domain %s is blocked by settings", host) return false, nil } diff --git a/dnsforward/dnsforward_http.go b/dnsforward/dnsforward_http.go new file mode 100644 index 00000000..b8b51d76 --- /dev/null +++ b/dnsforward/dnsforward_http.go @@ -0,0 +1,281 @@ +package dnsforward + +import ( + "encoding/json" + "fmt" + "net" + "net/http" + "strconv" + "strings" + + "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/utils" + "github.com/miekg/dns" +) + +func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) { + text := fmt.Sprintf(format, args...) + log.Info("DNS: %s %s: %s", r.Method, r.URL, text) + http.Error(w, text, code) +} + +func (s *Server) handleProtectionEnable(w http.ResponseWriter, r *http.Request) { + s.conf.ProtectionEnabled = true + s.conf.ConfigModified() +} + +func (s *Server) handleProtectionDisable(w http.ResponseWriter, r *http.Request) { + s.conf.ProtectionEnabled = false + s.conf.ConfigModified() +} + +type upstreamJSON struct { + Upstreams []string `json:"upstream_dns"` // Upstreams + BootstrapDNS []string `json:"bootstrap_dns"` // Bootstrap DNS + AllServers bool `json:"all_servers"` // --all-servers param for dnsproxy +} + +func (s *Server) handleSetUpstreamConfig(w http.ResponseWriter, r *http.Request) { + req := upstreamJSON{} + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + httpError(r, w, http.StatusBadRequest, "Failed to parse new upstreams config json: %s", err) + return + } + + err = validateUpstreams(req.Upstreams) + if err != nil { + httpError(r, w, http.StatusBadRequest, "wrong upstreams specification: %s", err) + return + } + + newconf := FilteringConfig{} + newconf.UpstreamDNS = req.Upstreams + + // bootstrap servers are plain DNS only + for _, host := range req.BootstrapDNS { + if err := checkPlainDNS(host); err != nil { + httpError(r, w, http.StatusBadRequest, "%s can not be used as bootstrap dns cause: %s", host, err) + return + } + } + newconf.BootstrapDNS = req.BootstrapDNS + + newconf.AllServers = req.AllServers + + s.Lock() + s.conf.UpstreamDNS = newconf.UpstreamDNS + s.conf.BootstrapDNS = newconf.BootstrapDNS + s.conf.AllServers = newconf.AllServers + s.Unlock() + s.conf.ConfigModified() + + err = s.Restart() + if err != nil { + httpError(r, w, http.StatusInternalServerError, "%s", err) + return + } +} + +// validateUpstreams validates each upstream and returns an error if any upstream is invalid or if there are no default upstreams specified +func validateUpstreams(upstreams []string) error { + var defaultUpstreamFound bool + for _, u := range upstreams { + d, err := validateUpstream(u) + if err != nil { + return err + } + + // Check this flag until default upstream will not be found + if !defaultUpstreamFound { + defaultUpstreamFound = d + } + } + + // Return error if there are no default upstreams + if !defaultUpstreamFound { + return fmt.Errorf("no default upstreams specified") + } + + return nil +} + +var protocols = []string{"tls://", "https://", "tcp://", "sdns://"} + +func validateUpstream(u string) (bool, error) { + // Check if user tries to specify upstream for domain + u, defaultUpstream, err := separateUpstream(u) + if err != nil { + return defaultUpstream, err + } + + // The special server address '#' means "use the default servers" + if u == "#" && !defaultUpstream { + return defaultUpstream, nil + } + + // Check if the upstream has a valid protocol prefix + for _, proto := range protocols { + if strings.HasPrefix(u, proto) { + return defaultUpstream, nil + } + } + + // Return error if the upstream contains '://' without any valid protocol + if strings.Contains(u, "://") { + return defaultUpstream, fmt.Errorf("wrong protocol") + } + + // Check if upstream is valid plain DNS + return defaultUpstream, checkPlainDNS(u) +} + +// separateUpstream returns upstream without specified domains and a bool flag that indicates if no domains were specified +// error will be returned if upstream per domain specification is invalid +func separateUpstream(upstream string) (string, bool, error) { + defaultUpstream := true + if strings.HasPrefix(upstream, "[/") { + defaultUpstream = false + // split domains and upstream string + domainsAndUpstream := strings.Split(strings.TrimPrefix(upstream, "[/"), "/]") + if len(domainsAndUpstream) != 2 { + return "", defaultUpstream, fmt.Errorf("wrong DNS upstream per domain specification: %s", upstream) + } + + // split domains list and validate each one + for _, host := range strings.Split(domainsAndUpstream[0], "/") { + if host != "" { + if err := utils.IsValidHostname(host); err != nil { + return "", defaultUpstream, err + } + } + } + upstream = domainsAndUpstream[1] + } + return upstream, defaultUpstream, nil +} + +// checkPlainDNS checks if host is plain DNS +func checkPlainDNS(upstream string) error { + // Check if host is ip without port + if net.ParseIP(upstream) != nil { + return nil + } + + // Check if host is ip with port + ip, port, err := net.SplitHostPort(upstream) + if err != nil { + return err + } + + if net.ParseIP(ip) == nil { + return fmt.Errorf("%s is not a valid IP", ip) + } + + _, err = strconv.ParseInt(port, 0, 64) + if err != nil { + return fmt.Errorf("%s is not a valid port: %s", port, err) + } + + return nil +} + +func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { + req := upstreamJSON{} + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + httpError(r, w, http.StatusBadRequest, "Failed to read request body: %s", err) + return + } + + if len(req.Upstreams) == 0 { + httpError(r, w, http.StatusBadRequest, "No servers specified") + return + } + + result := map[string]string{} + + for _, host := range req.Upstreams { + err = checkDNS(host, req.BootstrapDNS) + if err != nil { + log.Info("%v", err) + result[host] = err.Error() + } else { + result[host] = "OK" + } + } + + jsonVal, err := json.Marshal(result) + if err != nil { + httpError(r, w, http.StatusInternalServerError, "Unable to marshal status json: %s", err) + return + } + + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(jsonVal) + if err != nil { + httpError(r, w, http.StatusInternalServerError, "Couldn't write body: %s", err) + return + } +} + +func checkDNS(input string, bootstrap []string) error { + // separate upstream from domains list + input, defaultUpstream, err := separateUpstream(input) + if err != nil { + return fmt.Errorf("wrong upstream format: %s", err) + } + + // No need to check this entrance + if input == "#" && !defaultUpstream { + return nil + } + + if _, err := validateUpstream(input); err != nil { + return fmt.Errorf("wrong upstream format: %s", err) + } + + if len(bootstrap) == 0 { + bootstrap = defaultBootstrap + } + + log.Debug("Checking if DNS %s works...", input) + u, err := upstream.AddressToUpstream(input, upstream.Options{Bootstrap: bootstrap, Timeout: DefaultTimeout}) + if err != nil { + return fmt.Errorf("failed to choose upstream for %s: %s", input, err) + } + + req := dns.Msg{} + req.Id = dns.Id() + req.RecursionDesired = true + req.Question = []dns.Question{ + {Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + } + reply, err := u.Exchange(&req) + if err != nil { + return fmt.Errorf("couldn't communicate with DNS server %s: %s", input, err) + } + if len(reply.Answer) != 1 { + return fmt.Errorf("DNS server %s returned wrong answer", input) + } + if t, ok := reply.Answer[0].(*dns.A); ok { + if !net.IPv4(8, 8, 8, 8).Equal(t.A) { + return fmt.Errorf("DNS server %s returned wrong answer: %v", input, t.A) + } + } + + log.Debug("DNS %s works OK", input) + return nil +} + +func (s *Server) registerHandlers() { + s.conf.HTTPRegister("POST", "/control/enable_protection", s.handleProtectionEnable) + s.conf.HTTPRegister("POST", "/control/disable_protection", s.handleProtectionDisable) + s.conf.HTTPRegister("POST", "/control/set_upstreams_config", s.handleSetUpstreamConfig) + s.conf.HTTPRegister("POST", "/control/test_upstream_dns", s.handleTestUpstreamDNS) + + s.conf.HTTPRegister("GET", "/control/access/list", s.handleAccessList) + s.conf.HTTPRegister("POST", "/control/access/set", s.handleAccessSet) + +} diff --git a/dnsforward/dnsforward_test.go b/dnsforward/dnsforward_test.go index 2568ef7b..f05934f6 100644 --- a/dnsforward/dnsforward_test.go +++ b/dnsforward/dnsforward_test.go @@ -17,6 +17,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/miekg/dns" + "github.com/stretchr/testify/assert" ) const ( @@ -390,8 +391,7 @@ func createTestServer(t *testing.T) *Server { s := NewServer(f, nil, nil) s.conf.UDPListenAddr = &net.UDPAddr{Port: 0} s.conf.TCPListenAddr = &net.TCPAddr{Port: 0} - - s.conf.FilteringConfig.FilteringEnabled = true + s.conf.UpstreamDNS = []string{"8.8.8.8:53", "8.8.4.4:53"} s.conf.FilteringConfig.ProtectionEnabled = true return s } @@ -541,67 +541,106 @@ func publicKey(priv interface{}) interface{} { } func TestIsBlockedIPAllowed(t *testing.T) { - s := createTestServer(t) - s.conf.AllowedClients = []string{"1.1.1.1", "2.2.0.0/16"} + a := &accessCtx{} + assert.True(t, a.Init([]string{"1.1.1.1", "2.2.0.0/16"}, nil, nil) == nil) - err := s.Start(nil) - if err != nil { - t.Fatalf("Failed to start server: %s", err) - } - - if s.isBlockedIP("1.1.1.1") { - t.Fatalf("isBlockedIP") - } - if !s.isBlockedIP("1.1.1.2") { - t.Fatalf("isBlockedIP") - } - if s.isBlockedIP("2.2.1.1") { - t.Fatalf("isBlockedIP") - } - if !s.isBlockedIP("2.3.1.1") { - t.Fatalf("isBlockedIP") - } + assert.True(t, !a.IsBlockedIP("1.1.1.1")) + assert.True(t, a.IsBlockedIP("1.1.1.2")) + assert.True(t, !a.IsBlockedIP("2.2.1.1")) + assert.True(t, a.IsBlockedIP("2.3.1.1")) } func TestIsBlockedIPDisallowed(t *testing.T) { - s := createTestServer(t) - s.conf.DisallowedClients = []string{"1.1.1.1", "2.2.0.0/16"} + a := &accessCtx{} + assert.True(t, a.Init(nil, []string{"1.1.1.1", "2.2.0.0/16"}, nil) == nil) - err := s.Start(nil) - if err != nil { - t.Fatalf("Failed to start server: %s", err) - } - - if !s.isBlockedIP("1.1.1.1") { - t.Fatalf("isBlockedIP") - } - if s.isBlockedIP("1.1.1.2") { - t.Fatalf("isBlockedIP") - } - if !s.isBlockedIP("2.2.1.1") { - t.Fatalf("isBlockedIP") - } - if s.isBlockedIP("2.3.1.1") { - t.Fatalf("isBlockedIP") - } + assert.True(t, a.IsBlockedIP("1.1.1.1")) + assert.True(t, !a.IsBlockedIP("1.1.1.2")) + assert.True(t, a.IsBlockedIP("2.2.1.1")) + assert.True(t, !a.IsBlockedIP("2.3.1.1")) } func TestIsBlockedIPBlockedDomain(t *testing.T) { - s := createTestServer(t) - s.conf.BlockedHosts = []string{"host1", "host2"} + a := &accessCtx{} + assert.True(t, a.Init(nil, nil, []string{"host1", "host2"}) == nil) - err := s.Start(nil) - if err != nil { - t.Fatalf("Failed to start server: %s", err) + assert.True(t, a.IsBlockedDomain("host1")) + assert.True(t, a.IsBlockedDomain("host2")) + assert.True(t, !a.IsBlockedDomain("host3")) +} + +func TestValidateUpstream(t *testing.T) { + invalidUpstreams := []string{"1.2.3.4.5", + "123.3.7m", + "htttps://google.com/dns-query", + "[/host.com]tls://dns.adguard.com", + "[host.ru]#", } - if !s.isBlockedDomain("host1") { - t.Fatalf("isBlockedDomain") + validDefaultUpstreams := []string{"1.1.1.1", + "tls://1.1.1.1", + "https://dns.adguard.com/dns-query", + "sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", } - if !s.isBlockedDomain("host2") { - t.Fatalf("isBlockedDomain") + + validUpstreams := []string{"[/host.com/]1.1.1.1", + "[//]tls://1.1.1.1", + "[/www.host.com/]#", + "[/host.com/google.com/]8.8.8.8", + "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", } - if s.isBlockedDomain("host3") { - t.Fatalf("isBlockedDomain") + for _, u := range invalidUpstreams { + _, err := validateUpstream(u) + if err == nil { + t.Fatalf("upstream %s is invalid but it pass through validation", u) + } + } + + for _, u := range validDefaultUpstreams { + defaultUpstream, err := validateUpstream(u) + if err != nil { + t.Fatalf("upstream %s is valid but it doen't pass through validation cause: %s", u, err) + } + if !defaultUpstream { + t.Fatalf("upstream %s is default one!", u) + } + } + + for _, u := range validUpstreams { + defaultUpstream, err := validateUpstream(u) + if err != nil { + t.Fatalf("upstream %s is valid but it doen't pass through validation cause: %s", u, err) + } + if defaultUpstream { + t.Fatalf("upstream %s is default one!", u) + } + } +} + +func TestValidateUpstreamsSet(t *testing.T) { + // Set of valid upstreams. There is no default upstream specified + upstreamsSet := []string{"[/host.com/]1.1.1.1", + "[//]tls://1.1.1.1", + "[/www.host.com/]#", + "[/host.com/google.com/]8.8.8.8", + "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", + } + err := validateUpstreams(upstreamsSet) + if err == nil { + t.Fatalf("there is no default upstream") + } + + // Let's add default upstream + upstreamsSet = append(upstreamsSet, "8.8.8.8") + err = validateUpstreams(upstreamsSet) + if err != nil { + t.Fatalf("upstreams set is valid, but doesn't pass through validation cause: %s", err) + } + + // Let's add invalid upstream + upstreamsSet = append(upstreamsSet, "dhcp://fake.dns") + err = validateUpstreams(upstreamsSet) + if err == nil { + t.Fatalf("there is an invalid upstream in set, but it pass through validation") } } diff --git a/home/blocked_services.go b/home/blocked_services.go index 7795867a..335668bc 100644 --- a/home/blocked_services.go +++ b/home/blocked_services.go @@ -187,7 +187,7 @@ func handleBlockedServicesSet(w http.ResponseWriter, r *http.Request) { return } - returnOK(w) + httpOK(r, w) } // RegisterBlockedServicesHandlers - register HTTP handlers diff --git a/home/clients_http.go b/home/clients_http.go index 7c39b68d..5a6bd332 100644 --- a/home/clients_http.go +++ b/home/clients_http.go @@ -168,8 +168,7 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http. return } - _ = writeAllConfigsAndReloadDNS() - returnOK(w) + onConfigModified() } // Remove client @@ -192,8 +191,7 @@ func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http. return } - _ = writeAllConfigsAndReloadDNS() - returnOK(w) + onConfigModified() } type updateJSON struct { @@ -232,8 +230,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht return } - _ = writeAllConfigsAndReloadDNS() - returnOK(w) + onConfigModified() } // Get the list of clients by IP address list diff --git a/home/config.go b/home/config.go index 61053ebe..8622155f 100644 --- a/home/config.go +++ b/home/config.go @@ -5,7 +5,6 @@ import ( "net/http" "os" "path/filepath" - "runtime" "sync" "time" @@ -105,16 +104,19 @@ type dnsConfig struct { // time interval for statistics (in days) StatsInterval uint32 `yaml:"statistics_interval"` + QueryLogEnabled bool `yaml:"querylog_enabled"` // if true, query log is enabled + QueryLogInterval uint32 `yaml:"querylog_interval"` // time interval for query log (in days) + dnsforward.FilteringConfig `yaml:",inline"` - UpstreamDNS []string `yaml:"upstream_dns"` -} + FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists + FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours) + DnsfilterConf dnsfilter.Config `yaml:",inline"` -var defaultDNS = []string{ - "https://1.1.1.1/dns-query", - "https://1.0.0.1/dns-query", + // Names of services to block (globally). + // Per-client settings can override this configuration. + BlockedServices []string `yaml:"blocked_services"` } -var defaultBootstrap = []string{"1.1.1.1", "1.0.0.1"} type tlsConfigSettings struct { Enabled bool `yaml:"enabled" json:"enabled"` // Enabled is the encryption (DOT/DOH/HTTPS) status @@ -159,23 +161,21 @@ var config = configuration{ BindPort: 3000, BindHost: "0.0.0.0", DNS: dnsConfig{ - BindHost: "0.0.0.0", - Port: 53, - StatsInterval: 1, + BindHost: "0.0.0.0", + Port: 53, + StatsInterval: 1, + QueryLogEnabled: true, + QueryLogInterval: 1, FilteringConfig: dnsforward.FilteringConfig{ - ProtectionEnabled: true, // whether or not use any of dnsfilter features - FilteringEnabled: true, // whether or not use filter lists - FiltersUpdateIntervalHours: 24, - BlockingMode: "nxdomain", // mode how to answer filtered requests - BlockedResponseTTL: 10, // in seconds - QueryLogEnabled: true, - QueryLogInterval: 1, - Ratelimit: 20, - RefuseAny: true, - BootstrapDNS: defaultBootstrap, - AllServers: false, + ProtectionEnabled: true, // whether or not use any of dnsfilter features + BlockingMode: "nxdomain", // mode how to answer filtered requests + BlockedResponseTTL: 10, // in seconds + Ratelimit: 20, + RefuseAny: true, + AllServers: false, }, - UpstreamDNS: defaultDNS, + FilteringEnabled: true, // whether or not use filter lists + FiltersUpdateIntervalHours: 24, }, TLS: tlsConfig{ tlsConfigSettings: tlsConfigSettings{ @@ -202,13 +202,6 @@ func initConfig() { config.WebSessionTTLHours = 30 * 24 - if runtime.GOARCH == "mips" || runtime.GOARCH == "mipsle" { - // Use plain DNS on MIPS, encryption is too slow - defaultDNS = []string{"1.1.1.1", "1.0.0.1"} - // also change the default config - config.DNS.UpstreamDNS = defaultDNS - } - config.DNS.CacheSize = 4 * 1024 * 1024 config.DNS.DnsfilterConf.SafeBrowsingCacheSize = 1 * 1024 * 1024 config.DNS.DnsfilterConf.SafeSearchCacheSize = 1 * 1024 * 1024 @@ -325,6 +318,12 @@ func (c *configuration) write() error { config.DNS.DnsfilterConf = c } + if config.dnsServer != nil { + c := dnsforward.FilteringConfig{} + config.dnsServer.WriteDiskConfig(&c) + config.DNS.FilteringConfig = c + } + if config.dhcpServer != nil { c := dhcpd.ServerConfig{} config.dhcpServer.WriteDiskConfig(&c) diff --git a/home/control.go b/home/control.go index 73508f79..e27164c9 100644 --- a/home/control.go +++ b/home/control.go @@ -3,21 +3,13 @@ package home import ( "encoding/json" "fmt" - "net" "net/http" - "strconv" - "strings" "github.com/AdguardTeam/AdGuardHome/dnsforward" - "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/log" - "github.com/AdguardTeam/golibs/utils" "github.com/NYTimes/gziphandler" - "github.com/miekg/dns" ) -var protocols = []string{"tls://", "https://", "tcp://", "sdns://"} - // ---------------- // helper functions // ---------------- @@ -29,6 +21,9 @@ func returnOK(w http.ResponseWriter) { } } +func httpOK(r *http.Request, w http.ResponseWriter) { +} + func httpError(w http.ResponseWriter, code int, format string, args ...interface{}) { text := fmt.Sprintf(format, args...) log.Info(text) @@ -47,15 +42,6 @@ func writeAllConfigsAndReloadDNS() error { return reconfigureDNSServer() } -func httpUpdateConfigReloadDNSReturnOK(w http.ResponseWriter, r *http.Request) { - err := writeAllConfigsAndReloadDNS() - if err != nil { - httpError(w, http.StatusInternalServerError, "Couldn't write config file: %s", err) - return - } - returnOK(w) -} - func addDNSAddress(dnsAddresses *[]string, addr string) { if config.DNS.Port != 53 { addr = fmt.Sprintf("%s:%d", addr, config.DNS.Port) @@ -106,17 +92,22 @@ func getDNSAddresses() []string { } func handleStatus(w http.ResponseWriter, r *http.Request) { + c := dnsforward.FilteringConfig{} + if config.dnsServer != nil { + config.dnsServer.WriteDiskConfig(&c) + } data := map[string]interface{}{ - "dns_addresses": getDNSAddresses(), - "http_port": config.BindPort, - "dns_port": config.DNS.Port, - "protection_enabled": config.DNS.ProtectionEnabled, - "running": isRunning(), - "bootstrap_dns": config.DNS.BootstrapDNS, - "upstream_dns": config.DNS.UpstreamDNS, - "all_servers": config.DNS.AllServers, - "version": versionString, - "language": config.Language, + "dns_addresses": getDNSAddresses(), + "http_port": config.BindPort, + "dns_port": config.DNS.Port, + "running": isRunning(), + "version": versionString, + "language": config.Language, + + "protection_enabled": c.ProtectionEnabled, + "bootstrap_dns": c.BootstrapDNS, + "upstream_dns": c.UpstreamDNS, + "all_servers": c.AllServers, } jsonVal, err := json.Marshal(data) @@ -132,251 +123,6 @@ func handleStatus(w http.ResponseWriter, r *http.Request) { } } -func handleProtectionEnable(w http.ResponseWriter, r *http.Request) { - config.DNS.ProtectionEnabled = true - httpUpdateConfigReloadDNSReturnOK(w, r) -} - -func handleProtectionDisable(w http.ResponseWriter, r *http.Request) { - config.DNS.ProtectionEnabled = false - httpUpdateConfigReloadDNSReturnOK(w, r) -} - -// ----------------------- -// upstreams configuration -// ----------------------- - -// TODO this struct will become unnecessary after config file rework -type upstreamConfig struct { - Upstreams []string `json:"upstream_dns"` // Upstreams - BootstrapDNS []string `json:"bootstrap_dns"` // Bootstrap DNS - AllServers bool `json:"all_servers"` // --all-servers param for dnsproxy -} - -func handleSetUpstreamConfig(w http.ResponseWriter, r *http.Request) { - newconfig := upstreamConfig{} - err := json.NewDecoder(r.Body).Decode(&newconfig) - if err != nil { - httpError(w, http.StatusBadRequest, "Failed to parse new upstreams config json: %s", err) - return - } - - err = validateUpstreams(newconfig.Upstreams) - if err != nil { - httpError(w, http.StatusBadRequest, "wrong upstreams specification: %s", err) - return - } - - config.DNS.UpstreamDNS = defaultDNS - if len(newconfig.Upstreams) > 0 { - config.DNS.UpstreamDNS = newconfig.Upstreams - } - - // bootstrap servers are plain DNS only. - for _, host := range newconfig.BootstrapDNS { - if err := checkPlainDNS(host); err != nil { - httpError(w, http.StatusBadRequest, "%s can not be used as bootstrap dns cause: %s", host, err) - return - } - } - - config.DNS.BootstrapDNS = defaultBootstrap - if len(newconfig.BootstrapDNS) > 0 { - config.DNS.BootstrapDNS = newconfig.BootstrapDNS - } - - config.DNS.AllServers = newconfig.AllServers - httpUpdateConfigReloadDNSReturnOK(w, r) -} - -// validateUpstreams validates each upstream and returns an error if any upstream is invalid or if there are no default upstreams specified -func validateUpstreams(upstreams []string) error { - var defaultUpstreamFound bool - for _, u := range upstreams { - d, err := validateUpstream(u) - if err != nil { - return err - } - - // Check this flag until default upstream will not be found - if !defaultUpstreamFound { - defaultUpstreamFound = d - } - } - - // Return error if there are no default upstreams - if !defaultUpstreamFound { - return fmt.Errorf("no default upstreams specified") - } - - return nil -} - -func validateUpstream(u string) (bool, error) { - // Check if user tries to specify upstream for domain - u, defaultUpstream, err := separateUpstream(u) - if err != nil { - return defaultUpstream, err - } - - // The special server address '#' means "use the default servers" - if u == "#" && !defaultUpstream { - return defaultUpstream, nil - } - - // Check if the upstream has a valid protocol prefix - for _, proto := range protocols { - if strings.HasPrefix(u, proto) { - return defaultUpstream, nil - } - } - - // Return error if the upstream contains '://' without any valid protocol - if strings.Contains(u, "://") { - return defaultUpstream, fmt.Errorf("wrong protocol") - } - - // Check if upstream is valid plain DNS - return defaultUpstream, checkPlainDNS(u) -} - -// separateUpstream returns upstream without specified domains and a bool flag that indicates if no domains were specified -// error will be returned if upstream per domain specification is invalid -func separateUpstream(upstream string) (string, bool, error) { - defaultUpstream := true - if strings.HasPrefix(upstream, "[/") { - defaultUpstream = false - // split domains and upstream string - domainsAndUpstream := strings.Split(strings.TrimPrefix(upstream, "[/"), "/]") - if len(domainsAndUpstream) != 2 { - return "", defaultUpstream, fmt.Errorf("wrong DNS upstream per domain specification: %s", upstream) - } - - // split domains list and validate each one - for _, host := range strings.Split(domainsAndUpstream[0], "/") { - if host != "" { - if err := utils.IsValidHostname(host); err != nil { - return "", defaultUpstream, err - } - } - } - upstream = domainsAndUpstream[1] - } - return upstream, defaultUpstream, nil -} - -// checkPlainDNS checks if host is plain DNS -func checkPlainDNS(upstream string) error { - // Check if host is ip without port - if net.ParseIP(upstream) != nil { - return nil - } - - // Check if host is ip with port - ip, port, err := net.SplitHostPort(upstream) - if err != nil { - return err - } - - if net.ParseIP(ip) == nil { - return fmt.Errorf("%s is not a valid IP", ip) - } - - _, err = strconv.ParseInt(port, 0, 64) - if err != nil { - return fmt.Errorf("%s is not a valid port: %s", port, err) - } - - return nil -} - -func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { - upstreamConfig := upstreamConfig{} - err := json.NewDecoder(r.Body).Decode(&upstreamConfig) - if err != nil { - httpError(w, http.StatusBadRequest, "Failed to read request body: %s", err) - return - } - - if len(upstreamConfig.Upstreams) == 0 { - httpError(w, http.StatusBadRequest, "No servers specified") - return - } - - result := map[string]string{} - - for _, host := range upstreamConfig.Upstreams { - err = checkDNS(host, upstreamConfig.BootstrapDNS) - if err != nil { - log.Info("%v", err) - result[host] = err.Error() - } else { - result[host] = "OK" - } - } - - jsonVal, err := json.Marshal(result) - if err != nil { - httpError(w, http.StatusInternalServerError, "Unable to marshal status json: %s", err) - return - } - - w.Header().Set("Content-Type", "application/json") - _, err = w.Write(jsonVal) - if err != nil { - httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err) - } -} - -func checkDNS(input string, bootstrap []string) error { - // separate upstream from domains list - input, defaultUpstream, err := separateUpstream(input) - if err != nil { - return fmt.Errorf("wrong upstream format: %s", err) - } - - // No need to check this entrance - if input == "#" && !defaultUpstream { - return nil - } - - if _, err := validateUpstream(input); err != nil { - return fmt.Errorf("wrong upstream format: %s", err) - } - - if len(bootstrap) == 0 { - bootstrap = defaultBootstrap - } - - log.Debug("Checking if DNS %s works...", input) - u, err := upstream.AddressToUpstream(input, upstream.Options{Bootstrap: bootstrap, Timeout: dnsforward.DefaultTimeout}) - if err != nil { - return fmt.Errorf("failed to choose upstream for %s: %s", input, err) - } - - req := dns.Msg{} - req.Id = dns.Id() - req.RecursionDesired = true - req.Question = []dns.Question{ - {Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, - } - reply, err := u.Exchange(&req) - if err != nil { - return fmt.Errorf("couldn't communicate with DNS server %s: %s", input, err) - } - if len(reply.Answer) != 1 { - return fmt.Errorf("DNS server %s returned wrong answer", input) - } - if t, ok := reply.Answer[0].(*dns.A); ok { - if !net.IPv4(8, 8, 8, 8).Equal(t.A) { - return fmt.Errorf("DNS server %s returned wrong answer: %v", input, t.A) - } - } - - log.Debug("DNS %s works OK", input) - return nil -} - type profileJSON struct { Name string `json:"name"` } @@ -416,17 +162,11 @@ func handleDOH(w http.ResponseWriter, r *http.Request) { // ------------------------ func registerControlHandlers() { httpRegister(http.MethodGet, "/control/status", handleStatus) - httpRegister(http.MethodPost, "/control/enable_protection", handleProtectionEnable) - httpRegister(http.MethodPost, "/control/disable_protection", handleProtectionDisable) - httpRegister(http.MethodPost, "/control/set_upstreams_config", handleSetUpstreamConfig) - httpRegister(http.MethodPost, "/control/test_upstream_dns", handleTestUpstreamDNS) httpRegister(http.MethodPost, "/control/i18n/change_language", handleI18nChangeLanguage) httpRegister(http.MethodGet, "/control/i18n/current_language", handleI18nCurrentLanguage) http.HandleFunc("/control/version.json", postInstall(optionalAuth(handleGetVersionJSON))) httpRegister(http.MethodPost, "/control/update", handleUpdate) - httpRegister(http.MethodGet, "/control/access/list", handleAccessList) - httpRegister(http.MethodPost, "/control/access/set", handleAccessSet) httpRegister("GET", "/control/profile", handleGetProfile) RegisterFilteringHandlers() diff --git a/home/control_access.go b/home/control_access.go deleted file mode 100644 index abc51101..00000000 --- a/home/control_access.go +++ /dev/null @@ -1,84 +0,0 @@ -package home - -import ( - "encoding/json" - "net" - "net/http" - - "github.com/AdguardTeam/golibs/log" -) - -type accessListJSON struct { - AllowedClients []string `json:"allowed_clients"` - DisallowedClients []string `json:"disallowed_clients"` - BlockedHosts []string `json:"blocked_hosts"` -} - -func handleAccessList(w http.ResponseWriter, r *http.Request) { - config.controlLock.Lock() - j := accessListJSON{ - AllowedClients: config.DNS.AllowedClients, - DisallowedClients: config.DNS.DisallowedClients, - BlockedHosts: config.DNS.BlockedHosts, - } - config.controlLock.Unlock() - - w.Header().Set("Content-Type", "application/json") - err := json.NewEncoder(w).Encode(j) - if err != nil { - httpError(w, http.StatusInternalServerError, "json.Encode: %s", err) - return - } -} - -func checkIPCIDRArray(src []string) error { - for _, s := range src { - ip := net.ParseIP(s) - if ip != nil { - continue - } - - _, _, err := net.ParseCIDR(s) - if err != nil { - return err - } - } - - return nil -} - -func handleAccessSet(w http.ResponseWriter, r *http.Request) { - - j := accessListJSON{} - err := json.NewDecoder(r.Body).Decode(&j) - if err != nil { - httpError(w, http.StatusBadRequest, "json.Decode: %s", err) - return - } - - err = checkIPCIDRArray(j.AllowedClients) - if err == nil { - err = checkIPCIDRArray(j.DisallowedClients) - } - if err != nil { - httpError(w, http.StatusBadRequest, "%s", err) - return - } - - config.Lock() - config.DNS.AllowedClients = j.AllowedClients - config.DNS.DisallowedClients = j.DisallowedClients - config.DNS.BlockedHosts = j.BlockedHosts - config.Unlock() - - log.Tracef("Update access lists: %d, %d, %d", - len(j.AllowedClients), len(j.DisallowedClients), len(j.BlockedHosts)) - - err = writeAllConfigsAndReloadDNS() - if err != nil { - httpError(w, http.StatusBadRequest, "%s", err) - return - } - - returnOK(w) -} diff --git a/home/control_filtering.go b/home/control_filtering.go index 92a0a01b..5b4b963e 100644 --- a/home/control_filtering.go +++ b/home/control_filtering.go @@ -172,7 +172,12 @@ func handleFilteringSetRules(w http.ResponseWriter, r *http.Request) { } config.UserRules = strings.Split(string(body), "\n") - _ = writeAllConfigs() + onConfigModified() + userFilter := userFilter() + err = userFilter.save() + if err != nil { + log.Error("Couldn't save the user filter: %s", err) + } enableFilters(true) } @@ -218,7 +223,7 @@ func handleFilteringStatus(w http.ResponseWriter, r *http.Request) { RulesCount: uint32(f.RulesCount), } - if f.LastUpdated.Second() != 0 { + if !f.LastUpdated.IsZero() { fj.LastUpdated = f.LastUpdated.Format(time.RFC3339) } diff --git a/home/control_test.go b/home/control_test.go index 642b5bba..b047b65a 100644 --- a/home/control_test.go +++ b/home/control_test.go @@ -75,79 +75,3 @@ kXS9jgARhhiWXJrk t.Fatalf("valid cert & priv key: validateCertificates(): %v", data) } } - -func TestValidateUpstream(t *testing.T) { - invalidUpstreams := []string{"1.2.3.4.5", - "123.3.7m", - "htttps://google.com/dns-query", - "[/host.com]tls://dns.adguard.com", - "[host.ru]#", - } - - validDefaultUpstreams := []string{"1.1.1.1", - "tls://1.1.1.1", - "https://dns.adguard.com/dns-query", - "sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", - } - - validUpstreams := []string{"[/host.com/]1.1.1.1", - "[//]tls://1.1.1.1", - "[/www.host.com/]#", - "[/host.com/google.com/]8.8.8.8", - "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", - } - for _, u := range invalidUpstreams { - _, err := validateUpstream(u) - if err == nil { - t.Fatalf("upstream %s is invalid but it pass through validation", u) - } - } - - for _, u := range validDefaultUpstreams { - defaultUpstream, err := validateUpstream(u) - if err != nil { - t.Fatalf("upstream %s is valid but it doen't pass through validation cause: %s", u, err) - } - if !defaultUpstream { - t.Fatalf("upstream %s is default one!", u) - } - } - - for _, u := range validUpstreams { - defaultUpstream, err := validateUpstream(u) - if err != nil { - t.Fatalf("upstream %s is valid but it doen't pass through validation cause: %s", u, err) - } - if defaultUpstream { - t.Fatalf("upstream %s is default one!", u) - } - } -} - -func TestValidateUpstreamsSet(t *testing.T) { - // Set of valid upstreams. There is no default upstream specified - upstreamsSet := []string{"[/host.com/]1.1.1.1", - "[//]tls://1.1.1.1", - "[/www.host.com/]#", - "[/host.com/google.com/]8.8.8.8", - "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", - } - err := validateUpstreams(upstreamsSet) - if err == nil { - t.Fatalf("there is no default upstream") - } - - // Let's add default upstream - upstreamsSet = append(upstreamsSet, "8.8.8.8") - err = validateUpstreams(upstreamsSet) - if err != nil { - t.Fatalf("upstreams set is valid, but doesn't pass through validation cause: %s", err) - } - - // Let's add invalid upstream - upstreamsSet = append(upstreamsSet, "dhcp://fake.dns") - err = validateUpstreams(upstreamsSet) - if err == nil { - t.Fatalf("there is an invalid upstream in set, but it pass through validation") - } -} diff --git a/home/dns.go b/home/dns.go index 7b81a89e..af39ee12 100644 --- a/home/dns.go +++ b/home/dns.go @@ -156,6 +156,9 @@ func generateServerConfig() (dnsforward.ServerConfig, error) { UDPListenAddr: &net.UDPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.DNS.Port}, TCPListenAddr: &net.TCPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.DNS.Port}, FilteringConfig: config.DNS.FilteringConfig, + ConfigModified: onConfigModified, + HTTPRegister: httpRegister, + OnDNSRequest: onDNSRequest, } if config.TLS.Enabled { @@ -165,15 +168,7 @@ func generateServerConfig() (dnsforward.ServerConfig, error) { } } - upstreamConfig, err := proxy.ParseUpstreamsConfig(config.DNS.UpstreamDNS, config.DNS.BootstrapDNS, dnsforward.DefaultTimeout) - if err != nil { - return newconfig, fmt.Errorf("Couldn't get upstreams configuration cause: %s", err) - } - newconfig.Upstreams = upstreamConfig.Upstreams - newconfig.DomainsReservedUpstreams = upstreamConfig.DomainReservedUpstreams - newconfig.AllServers = config.DNS.AllServers newconfig.FilterHandler = applyAdditionalFiltering - newconfig.OnDNSRequest = onDNSRequest return newconfig, nil } diff --git a/home/filter.go b/home/filter.go index d5ce7ffc..04a32d87 100644 --- a/home/filter.go +++ b/home/filter.go @@ -499,7 +499,7 @@ func (filter *filter) LastTimeUpdated() time.Time { func enableFilters(async bool) { var filters map[int]string - if config.DNS.FilteringConfig.FilteringEnabled { + if config.DNS.FilteringEnabled { // convert array of filters filters = make(map[int]string) diff --git a/home/i18n.go b/home/i18n.go index 610a44c2..4b20f606 100644 --- a/home/i18n.go +++ b/home/i18n.go @@ -79,6 +79,6 @@ func handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) { } config.Language = language - - httpUpdateConfigReloadDNSReturnOK(w, r) + onConfigModified() + returnOK(w) }