diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index cb86c235..cace0cbc 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -3,6 +3,7 @@ package dnsforward import ( "crypto/tls" "errors" + "fmt" "net" "net/http" "runtime" @@ -50,6 +51,8 @@ type Server struct { DisallowedClientsIPNet []net.IPNet // CIDRs of clients that should be blocked BlockedHosts map[string]bool // hosts that should be blocked + webRegistered bool + sync.RWMutex conf ServerConfig } @@ -64,6 +67,7 @@ func NewServer(dnsFilter *dnsfilter.Dnsfilter, stats stats.Stats, queryLog query return s } +// Close - close object func (s *Server) Close() { s.Lock() s.dnsFilter = nil @@ -72,20 +76,23 @@ func (s *Server) Close() { s.Unlock() } +// WriteDiskConfig - write configuration +func (s *Server) WriteDiskConfig(c *FilteringConfig) { + s.Lock() + *c = s.conf.FilteringConfig + s.Unlock() +} + // FilteringConfig represents the DNS filtering configuration of AdGuard Home // The zero FilteringConfig is empty and ready for use. type FilteringConfig struct { // Filtering callback function FilterHandler func(clientAddr string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"` - ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features - FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists - FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours) + ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features BlockingMode string `yaml:"blocking_mode"` // mode how to answer filtered requests BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` // if 0, then default is used (3600) - QueryLogEnabled bool `yaml:"querylog_enabled"` // if true, query log is enabled - QueryLogInterval uint32 `yaml:"querylog_interval"` // time interval for query log (in days) Ratelimit int `yaml:"ratelimit"` // max number of requests per second from a given IP (0 to disable) RatelimitWhitelist []string `yaml:"ratelimit_whitelist"` // a list of whitelisted client IP addresses RefuseAny bool `yaml:"refuse_any"` // if true, refuse ANY requests @@ -104,9 +111,8 @@ type FilteringConfig struct { // Per-client settings can override this configuration. BlockedServices []string `yaml:"blocked_services"` - CacheSize uint `yaml:"cache_size"` // DNS cache size (in bytes) - - DnsfilterConf dnsfilter.Config `yaml:",inline"` + CacheSize uint `yaml:"cache_size"` // DNS cache size (in bytes) + UpstreamDNS []string `yaml:"upstream_dns"` } // TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS @@ -133,6 +139,12 @@ type ServerConfig struct { FilteringConfig TLSConfig + + // Called when the configuration is changed by HTTP request + ConfigModified func() + + // Register an HTTP handler + HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) } // if any of ServerConfig values are zero, then default values from below are used @@ -198,6 +210,13 @@ func (s *Server) startInternal(config *ServerConfig) error { if config != nil { s.conf = *config + upstreamConfig, err := proxy.ParseUpstreamsConfig(s.conf.UpstreamDNS, s.conf.BootstrapDNS, DefaultTimeout) + if err != nil { + return fmt.Errorf("DNS: proxy.ParseUpstreamsConfig: %s", err) + } + s.conf.Upstreams = upstreamConfig.Upstreams + s.conf.DomainsReservedUpstreams = upstreamConfig.DomainReservedUpstreams + } if len(s.conf.ParentalBlockHost) == 0 { s.conf.ParentalBlockHost = parentalBlockHost @@ -257,6 +276,11 @@ func (s *Server) startInternal(config *ServerConfig) error { proxyConfig.Upstreams = defaultValues.Upstreams } + if !s.webRegistered && s.conf.HTTPRegister != nil { + s.webRegistered = true + s.registerHandlers() + } + // Initialize and start the DNS proxy s.dnsProxy = &proxy.Proxy{Config: proxyConfig} return s.dnsProxy.Start() @@ -293,7 +317,28 @@ func (s *Server) IsRunning() bool { return isRunning } -// Reconfigure applies the new configuration to the DNS server +// Reconfigure2 - safely apply and write new configuration and restart +func (s *Server) Reconfigure2(newconf FilteringConfig) error { + s.Lock() + s.conf.FilteringConfig = newconf + s.Unlock() + s.conf.ConfigModified() + + s.Lock() + defer s.Unlock() + log.Print("Start reconfiguring the server") + err := s.stopInternal() + if err != nil { + return errorx.Decorate(err, "could not reconfigure the server") + } + err = s.startInternal(nil) + if err != nil { + return errorx.Decorate(err, "could not reconfigure the server") + } + + return nil +} + func (s *Server) Reconfigure(config *ServerConfig) error { s.Lock() defer s.Unlock() diff --git a/dnsforward/dnsforward_http.go b/dnsforward/dnsforward_http.go new file mode 100644 index 00000000..dec75713 --- /dev/null +++ b/dnsforward/dnsforward_http.go @@ -0,0 +1,286 @@ +package dnsforward + +import ( + "encoding/json" + "fmt" + "net" + "net/http" + "runtime" + "strconv" + "strings" + + "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/utils" + "github.com/miekg/dns" +) + +var defaultDNS = []string{ + "https://1.1.1.1/dns-query", + "https://1.0.0.1/dns-query", +} +var defaultBootstrap = []string{"1.1.1.1", "1.0.0.1"} + +func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) { + text := fmt.Sprintf(format, args...) + log.Info("DNS: %s %s: %s", r.Method, r.URL, text) + http.Error(w, text, code) +} + +func (s *Server) handleProtectionEnable(w http.ResponseWriter, r *http.Request) { + s.conf.ProtectionEnabled = true + s.conf.ConfigModified() +} + +func (s *Server) handleProtectionDisable(w http.ResponseWriter, r *http.Request) { + s.conf.ProtectionEnabled = false + s.conf.ConfigModified() +} + +type upstreamJSON struct { + Upstreams []string `json:"upstream_dns"` // Upstreams + BootstrapDNS []string `json:"bootstrap_dns"` // Bootstrap DNS + AllServers bool `json:"all_servers"` // --all-servers param for dnsproxy +} + +func (s *Server) handleSetUpstreamConfig(w http.ResponseWriter, r *http.Request) { + req := upstreamJSON{} + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + httpError(r, w, http.StatusBadRequest, "Failed to parse new upstreams config json: %s", err) + return + } + + err = validateUpstreams(req.Upstreams) + if err != nil { + httpError(r, w, http.StatusBadRequest, "wrong upstreams specification: %s", err) + return + } + + newconf := FilteringConfig{} + newconf.UpstreamDNS = defaultDNS + if runtime.GOARCH == "mips" || runtime.GOARCH == "mipsle" { + // Use plain DNS on MIPS, encryption is too slow + newconf.UpstreamDNS = []string{"1.1.1.1", "1.0.0.1"} + } + if len(req.Upstreams) != 0 { + newconf.UpstreamDNS = req.Upstreams + } + + // bootstrap servers are plain DNS only + for _, host := range req.BootstrapDNS { + if err := checkPlainDNS(host); err != nil { + httpError(r, w, http.StatusBadRequest, "%s can not be used as bootstrap dns cause: %s", host, err) + return + } + } + newconf.BootstrapDNS = defaultBootstrap + if len(req.BootstrapDNS) != 0 { + newconf.BootstrapDNS = req.BootstrapDNS + } + + newconf.AllServers = req.AllServers + err = s.Reconfigure2(newconf) + if err != nil { + httpError(r, w, http.StatusInternalServerError, "%s", err) + return + } +} + +// validateUpstreams validates each upstream and returns an error if any upstream is invalid or if there are no default upstreams specified +func validateUpstreams(upstreams []string) error { + var defaultUpstreamFound bool + for _, u := range upstreams { + d, err := validateUpstream(u) + if err != nil { + return err + } + + // Check this flag until default upstream will not be found + if !defaultUpstreamFound { + defaultUpstreamFound = d + } + } + + // Return error if there are no default upstreams + if !defaultUpstreamFound { + return fmt.Errorf("no default upstreams specified") + } + + return nil +} + +var protocols = []string{"tls://", "https://", "tcp://", "sdns://"} + +func validateUpstream(u string) (bool, error) { + // Check if user tries to specify upstream for domain + u, defaultUpstream, err := separateUpstream(u) + if err != nil { + return defaultUpstream, err + } + + // The special server address '#' means "use the default servers" + if u == "#" && !defaultUpstream { + return defaultUpstream, nil + } + + // Check if the upstream has a valid protocol prefix + for _, proto := range protocols { + if strings.HasPrefix(u, proto) { + return defaultUpstream, nil + } + } + + // Return error if the upstream contains '://' without any valid protocol + if strings.Contains(u, "://") { + return defaultUpstream, fmt.Errorf("wrong protocol") + } + + // Check if upstream is valid plain DNS + return defaultUpstream, checkPlainDNS(u) +} + +// separateUpstream returns upstream without specified domains and a bool flag that indicates if no domains were specified +// error will be returned if upstream per domain specification is invalid +func separateUpstream(upstream string) (string, bool, error) { + defaultUpstream := true + if strings.HasPrefix(upstream, "[/") { + defaultUpstream = false + // split domains and upstream string + domainsAndUpstream := strings.Split(strings.TrimPrefix(upstream, "[/"), "/]") + if len(domainsAndUpstream) != 2 { + return "", defaultUpstream, fmt.Errorf("wrong DNS upstream per domain specification: %s", upstream) + } + + // split domains list and validate each one + for _, host := range strings.Split(domainsAndUpstream[0], "/") { + if host != "" { + if err := utils.IsValidHostname(host); err != nil { + return "", defaultUpstream, err + } + } + } + upstream = domainsAndUpstream[1] + } + return upstream, defaultUpstream, nil +} + +// checkPlainDNS checks if host is plain DNS +func checkPlainDNS(upstream string) error { + // Check if host is ip without port + if net.ParseIP(upstream) != nil { + return nil + } + + // Check if host is ip with port + ip, port, err := net.SplitHostPort(upstream) + if err != nil { + return err + } + + if net.ParseIP(ip) == nil { + return fmt.Errorf("%s is not a valid IP", ip) + } + + _, err = strconv.ParseInt(port, 0, 64) + if err != nil { + return fmt.Errorf("%s is not a valid port: %s", port, err) + } + + return nil +} + +func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { + req := upstreamJSON{} + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + httpError(r, w, http.StatusBadRequest, "Failed to read request body: %s", err) + return + } + + if len(req.Upstreams) == 0 { + httpError(r, w, http.StatusBadRequest, "No servers specified") + return + } + + result := map[string]string{} + + for _, host := range req.Upstreams { + err = checkDNS(host, req.BootstrapDNS) + if err != nil { + log.Info("%v", err) + result[host] = err.Error() + } else { + result[host] = "OK" + } + } + + jsonVal, err := json.Marshal(result) + if err != nil { + httpError(r, w, http.StatusInternalServerError, "Unable to marshal status json: %s", err) + return + } + + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(jsonVal) + if err != nil { + httpError(r, w, http.StatusInternalServerError, "Couldn't write body: %s", err) + return + } +} + +func checkDNS(input string, bootstrap []string) error { + // separate upstream from domains list + input, defaultUpstream, err := separateUpstream(input) + if err != nil { + return fmt.Errorf("wrong upstream format: %s", err) + } + + // No need to check this entrance + if input == "#" && !defaultUpstream { + return nil + } + + if _, err := validateUpstream(input); err != nil { + return fmt.Errorf("wrong upstream format: %s", err) + } + + if len(bootstrap) == 0 { + bootstrap = defaultBootstrap + } + + log.Debug("Checking if DNS %s works...", input) + u, err := upstream.AddressToUpstream(input, upstream.Options{Bootstrap: bootstrap, Timeout: DefaultTimeout}) + if err != nil { + return fmt.Errorf("failed to choose upstream for %s: %s", input, err) + } + + req := dns.Msg{} + req.Id = dns.Id() + req.RecursionDesired = true + req.Question = []dns.Question{ + {Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + } + reply, err := u.Exchange(&req) + if err != nil { + return fmt.Errorf("couldn't communicate with DNS server %s: %s", input, err) + } + if len(reply.Answer) != 1 { + return fmt.Errorf("DNS server %s returned wrong answer", input) + } + if t, ok := reply.Answer[0].(*dns.A); ok { + if !net.IPv4(8, 8, 8, 8).Equal(t.A) { + return fmt.Errorf("DNS server %s returned wrong answer: %v", input, t.A) + } + } + + log.Debug("DNS %s works OK", input) + return nil +} + +func (s *Server) registerHandlers() { + s.conf.HTTPRegister("POST", "/control/enable_protection", s.handleProtectionEnable) + s.conf.HTTPRegister("POST", "/control/disable_protection", s.handleProtectionDisable) + s.conf.HTTPRegister("POST", "/control/set_upstreams_config", s.handleSetUpstreamConfig) + s.conf.HTTPRegister("POST", "/control/test_upstream_dns", s.handleTestUpstreamDNS) +} diff --git a/home/blocked_services.go b/home/blocked_services.go index 7795867a..7cb09f11 100644 --- a/home/blocked_services.go +++ b/home/blocked_services.go @@ -186,8 +186,6 @@ func handleBlockedServicesSet(w http.ResponseWriter, r *http.Request) { httpError(w, http.StatusBadRequest, "%s", err) return } - - returnOK(w) } // RegisterBlockedServicesHandlers - register HTTP handlers diff --git a/home/clients_http.go b/home/clients_http.go index 7c39b68d..5a6bd332 100644 --- a/home/clients_http.go +++ b/home/clients_http.go @@ -168,8 +168,7 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http. return } - _ = writeAllConfigsAndReloadDNS() - returnOK(w) + onConfigModified() } // Remove client @@ -192,8 +191,7 @@ func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http. return } - _ = writeAllConfigsAndReloadDNS() - returnOK(w) + onConfigModified() } type updateJSON struct { @@ -232,8 +230,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht return } - _ = writeAllConfigsAndReloadDNS() - returnOK(w) + onConfigModified() } // Get the list of clients by IP address list diff --git a/home/config.go b/home/config.go index 61053ebe..c63db20f 100644 --- a/home/config.go +++ b/home/config.go @@ -105,9 +105,14 @@ type dnsConfig struct { // time interval for statistics (in days) StatsInterval uint32 `yaml:"statistics_interval"` + QueryLogEnabled bool `yaml:"querylog_enabled"` // if true, query log is enabled + QueryLogInterval uint32 `yaml:"querylog_interval"` // time interval for query log (in days) + dnsforward.FilteringConfig `yaml:",inline"` - UpstreamDNS []string `yaml:"upstream_dns"` + FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists + FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours) + DnsfilterConf dnsfilter.Config `yaml:",inline"` } var defaultDNS = []string{ @@ -159,23 +164,21 @@ var config = configuration{ BindPort: 3000, BindHost: "0.0.0.0", DNS: dnsConfig{ - BindHost: "0.0.0.0", - Port: 53, - StatsInterval: 1, + BindHost: "0.0.0.0", + Port: 53, + StatsInterval: 1, + QueryLogEnabled: true, + QueryLogInterval: 1, FilteringConfig: dnsforward.FilteringConfig{ - ProtectionEnabled: true, // whether or not use any of dnsfilter features - FilteringEnabled: true, // whether or not use filter lists - FiltersUpdateIntervalHours: 24, - BlockingMode: "nxdomain", // mode how to answer filtered requests - BlockedResponseTTL: 10, // in seconds - QueryLogEnabled: true, - QueryLogInterval: 1, - Ratelimit: 20, - RefuseAny: true, - BootstrapDNS: defaultBootstrap, - AllServers: false, + ProtectionEnabled: true, // whether or not use any of dnsfilter features + BlockingMode: "nxdomain", // mode how to answer filtered requests + BlockedResponseTTL: 10, // in seconds + Ratelimit: 20, + RefuseAny: true, + AllServers: false, }, - UpstreamDNS: defaultDNS, + FilteringEnabled: true, // whether or not use filter lists + FiltersUpdateIntervalHours: 24, }, TLS: tlsConfig{ tlsConfigSettings: tlsConfigSettings{ @@ -202,13 +205,13 @@ func initConfig() { config.WebSessionTTLHours = 30 * 24 + config.DNS.UpstreamDNS = defaultDNS if runtime.GOARCH == "mips" || runtime.GOARCH == "mipsle" { // Use plain DNS on MIPS, encryption is too slow - defaultDNS = []string{"1.1.1.1", "1.0.0.1"} - // also change the default config - config.DNS.UpstreamDNS = defaultDNS + config.DNS.UpstreamDNS = []string{"1.1.1.1", "1.0.0.1"} } + config.DNS.BootstrapDNS = defaultBootstrap config.DNS.CacheSize = 4 * 1024 * 1024 config.DNS.DnsfilterConf.SafeBrowsingCacheSize = 1 * 1024 * 1024 config.DNS.DnsfilterConf.SafeSearchCacheSize = 1 * 1024 * 1024 @@ -325,6 +328,12 @@ func (c *configuration) write() error { config.DNS.DnsfilterConf = c } + if config.dnsServer != nil { + c := dnsforward.FilteringConfig{} + config.dnsServer.WriteDiskConfig(&c) + config.DNS.FilteringConfig = c + } + if config.dhcpServer != nil { c := dhcpd.ServerConfig{} config.dhcpServer.WriteDiskConfig(&c) diff --git a/home/control.go b/home/control.go index 73508f79..97a8d232 100644 --- a/home/control.go +++ b/home/control.go @@ -3,21 +3,13 @@ package home import ( "encoding/json" "fmt" - "net" "net/http" - "strconv" - "strings" "github.com/AdguardTeam/AdGuardHome/dnsforward" - "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/log" - "github.com/AdguardTeam/golibs/utils" "github.com/NYTimes/gziphandler" - "github.com/miekg/dns" ) -var protocols = []string{"tls://", "https://", "tcp://", "sdns://"} - // ---------------- // helper functions // ---------------- @@ -47,15 +39,6 @@ func writeAllConfigsAndReloadDNS() error { return reconfigureDNSServer() } -func httpUpdateConfigReloadDNSReturnOK(w http.ResponseWriter, r *http.Request) { - err := writeAllConfigsAndReloadDNS() - if err != nil { - httpError(w, http.StatusInternalServerError, "Couldn't write config file: %s", err) - return - } - returnOK(w) -} - func addDNSAddress(dnsAddresses *[]string, addr string) { if config.DNS.Port != 53 { addr = fmt.Sprintf("%s:%d", addr, config.DNS.Port) @@ -106,17 +89,22 @@ func getDNSAddresses() []string { } func handleStatus(w http.ResponseWriter, r *http.Request) { + c := dnsforward.FilteringConfig{} + if config.dnsServer != nil { + config.dnsServer.WriteDiskConfig(&c) + } data := map[string]interface{}{ - "dns_addresses": getDNSAddresses(), - "http_port": config.BindPort, - "dns_port": config.DNS.Port, - "protection_enabled": config.DNS.ProtectionEnabled, - "running": isRunning(), - "bootstrap_dns": config.DNS.BootstrapDNS, - "upstream_dns": config.DNS.UpstreamDNS, - "all_servers": config.DNS.AllServers, - "version": versionString, - "language": config.Language, + "dns_addresses": getDNSAddresses(), + "http_port": config.BindPort, + "dns_port": config.DNS.Port, + "running": isRunning(), + "version": versionString, + "language": config.Language, + + "protection_enabled": c.ProtectionEnabled, + "bootstrap_dns": c.BootstrapDNS, + "upstream_dns": c.UpstreamDNS, + "all_servers": c.AllServers, } jsonVal, err := json.Marshal(data) @@ -132,251 +120,6 @@ func handleStatus(w http.ResponseWriter, r *http.Request) { } } -func handleProtectionEnable(w http.ResponseWriter, r *http.Request) { - config.DNS.ProtectionEnabled = true - httpUpdateConfigReloadDNSReturnOK(w, r) -} - -func handleProtectionDisable(w http.ResponseWriter, r *http.Request) { - config.DNS.ProtectionEnabled = false - httpUpdateConfigReloadDNSReturnOK(w, r) -} - -// ----------------------- -// upstreams configuration -// ----------------------- - -// TODO this struct will become unnecessary after config file rework -type upstreamConfig struct { - Upstreams []string `json:"upstream_dns"` // Upstreams - BootstrapDNS []string `json:"bootstrap_dns"` // Bootstrap DNS - AllServers bool `json:"all_servers"` // --all-servers param for dnsproxy -} - -func handleSetUpstreamConfig(w http.ResponseWriter, r *http.Request) { - newconfig := upstreamConfig{} - err := json.NewDecoder(r.Body).Decode(&newconfig) - if err != nil { - httpError(w, http.StatusBadRequest, "Failed to parse new upstreams config json: %s", err) - return - } - - err = validateUpstreams(newconfig.Upstreams) - if err != nil { - httpError(w, http.StatusBadRequest, "wrong upstreams specification: %s", err) - return - } - - config.DNS.UpstreamDNS = defaultDNS - if len(newconfig.Upstreams) > 0 { - config.DNS.UpstreamDNS = newconfig.Upstreams - } - - // bootstrap servers are plain DNS only. - for _, host := range newconfig.BootstrapDNS { - if err := checkPlainDNS(host); err != nil { - httpError(w, http.StatusBadRequest, "%s can not be used as bootstrap dns cause: %s", host, err) - return - } - } - - config.DNS.BootstrapDNS = defaultBootstrap - if len(newconfig.BootstrapDNS) > 0 { - config.DNS.BootstrapDNS = newconfig.BootstrapDNS - } - - config.DNS.AllServers = newconfig.AllServers - httpUpdateConfigReloadDNSReturnOK(w, r) -} - -// validateUpstreams validates each upstream and returns an error if any upstream is invalid or if there are no default upstreams specified -func validateUpstreams(upstreams []string) error { - var defaultUpstreamFound bool - for _, u := range upstreams { - d, err := validateUpstream(u) - if err != nil { - return err - } - - // Check this flag until default upstream will not be found - if !defaultUpstreamFound { - defaultUpstreamFound = d - } - } - - // Return error if there are no default upstreams - if !defaultUpstreamFound { - return fmt.Errorf("no default upstreams specified") - } - - return nil -} - -func validateUpstream(u string) (bool, error) { - // Check if user tries to specify upstream for domain - u, defaultUpstream, err := separateUpstream(u) - if err != nil { - return defaultUpstream, err - } - - // The special server address '#' means "use the default servers" - if u == "#" && !defaultUpstream { - return defaultUpstream, nil - } - - // Check if the upstream has a valid protocol prefix - for _, proto := range protocols { - if strings.HasPrefix(u, proto) { - return defaultUpstream, nil - } - } - - // Return error if the upstream contains '://' without any valid protocol - if strings.Contains(u, "://") { - return defaultUpstream, fmt.Errorf("wrong protocol") - } - - // Check if upstream is valid plain DNS - return defaultUpstream, checkPlainDNS(u) -} - -// separateUpstream returns upstream without specified domains and a bool flag that indicates if no domains were specified -// error will be returned if upstream per domain specification is invalid -func separateUpstream(upstream string) (string, bool, error) { - defaultUpstream := true - if strings.HasPrefix(upstream, "[/") { - defaultUpstream = false - // split domains and upstream string - domainsAndUpstream := strings.Split(strings.TrimPrefix(upstream, "[/"), "/]") - if len(domainsAndUpstream) != 2 { - return "", defaultUpstream, fmt.Errorf("wrong DNS upstream per domain specification: %s", upstream) - } - - // split domains list and validate each one - for _, host := range strings.Split(domainsAndUpstream[0], "/") { - if host != "" { - if err := utils.IsValidHostname(host); err != nil { - return "", defaultUpstream, err - } - } - } - upstream = domainsAndUpstream[1] - } - return upstream, defaultUpstream, nil -} - -// checkPlainDNS checks if host is plain DNS -func checkPlainDNS(upstream string) error { - // Check if host is ip without port - if net.ParseIP(upstream) != nil { - return nil - } - - // Check if host is ip with port - ip, port, err := net.SplitHostPort(upstream) - if err != nil { - return err - } - - if net.ParseIP(ip) == nil { - return fmt.Errorf("%s is not a valid IP", ip) - } - - _, err = strconv.ParseInt(port, 0, 64) - if err != nil { - return fmt.Errorf("%s is not a valid port: %s", port, err) - } - - return nil -} - -func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { - upstreamConfig := upstreamConfig{} - err := json.NewDecoder(r.Body).Decode(&upstreamConfig) - if err != nil { - httpError(w, http.StatusBadRequest, "Failed to read request body: %s", err) - return - } - - if len(upstreamConfig.Upstreams) == 0 { - httpError(w, http.StatusBadRequest, "No servers specified") - return - } - - result := map[string]string{} - - for _, host := range upstreamConfig.Upstreams { - err = checkDNS(host, upstreamConfig.BootstrapDNS) - if err != nil { - log.Info("%v", err) - result[host] = err.Error() - } else { - result[host] = "OK" - } - } - - jsonVal, err := json.Marshal(result) - if err != nil { - httpError(w, http.StatusInternalServerError, "Unable to marshal status json: %s", err) - return - } - - w.Header().Set("Content-Type", "application/json") - _, err = w.Write(jsonVal) - if err != nil { - httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err) - } -} - -func checkDNS(input string, bootstrap []string) error { - // separate upstream from domains list - input, defaultUpstream, err := separateUpstream(input) - if err != nil { - return fmt.Errorf("wrong upstream format: %s", err) - } - - // No need to check this entrance - if input == "#" && !defaultUpstream { - return nil - } - - if _, err := validateUpstream(input); err != nil { - return fmt.Errorf("wrong upstream format: %s", err) - } - - if len(bootstrap) == 0 { - bootstrap = defaultBootstrap - } - - log.Debug("Checking if DNS %s works...", input) - u, err := upstream.AddressToUpstream(input, upstream.Options{Bootstrap: bootstrap, Timeout: dnsforward.DefaultTimeout}) - if err != nil { - return fmt.Errorf("failed to choose upstream for %s: %s", input, err) - } - - req := dns.Msg{} - req.Id = dns.Id() - req.RecursionDesired = true - req.Question = []dns.Question{ - {Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, - } - reply, err := u.Exchange(&req) - if err != nil { - return fmt.Errorf("couldn't communicate with DNS server %s: %s", input, err) - } - if len(reply.Answer) != 1 { - return fmt.Errorf("DNS server %s returned wrong answer", input) - } - if t, ok := reply.Answer[0].(*dns.A); ok { - if !net.IPv4(8, 8, 8, 8).Equal(t.A) { - return fmt.Errorf("DNS server %s returned wrong answer: %v", input, t.A) - } - } - - log.Debug("DNS %s works OK", input) - return nil -} - type profileJSON struct { Name string `json:"name"` } @@ -416,10 +159,6 @@ func handleDOH(w http.ResponseWriter, r *http.Request) { // ------------------------ func registerControlHandlers() { httpRegister(http.MethodGet, "/control/status", handleStatus) - httpRegister(http.MethodPost, "/control/enable_protection", handleProtectionEnable) - httpRegister(http.MethodPost, "/control/disable_protection", handleProtectionDisable) - httpRegister(http.MethodPost, "/control/set_upstreams_config", handleSetUpstreamConfig) - httpRegister(http.MethodPost, "/control/test_upstream_dns", handleTestUpstreamDNS) httpRegister(http.MethodPost, "/control/i18n/change_language", handleI18nChangeLanguage) httpRegister(http.MethodGet, "/control/i18n/current_language", handleI18nCurrentLanguage) http.HandleFunc("/control/version.json", postInstall(optionalAuth(handleGetVersionJSON))) diff --git a/home/control_filtering.go b/home/control_filtering.go index 92a0a01b..5b4b963e 100644 --- a/home/control_filtering.go +++ b/home/control_filtering.go @@ -172,7 +172,12 @@ func handleFilteringSetRules(w http.ResponseWriter, r *http.Request) { } config.UserRules = strings.Split(string(body), "\n") - _ = writeAllConfigs() + onConfigModified() + userFilter := userFilter() + err = userFilter.save() + if err != nil { + log.Error("Couldn't save the user filter: %s", err) + } enableFilters(true) } @@ -218,7 +223,7 @@ func handleFilteringStatus(w http.ResponseWriter, r *http.Request) { RulesCount: uint32(f.RulesCount), } - if f.LastUpdated.Second() != 0 { + if !f.LastUpdated.IsZero() { fj.LastUpdated = f.LastUpdated.Format(time.RFC3339) } diff --git a/home/dns.go b/home/dns.go index 7b81a89e..af39ee12 100644 --- a/home/dns.go +++ b/home/dns.go @@ -156,6 +156,9 @@ func generateServerConfig() (dnsforward.ServerConfig, error) { UDPListenAddr: &net.UDPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.DNS.Port}, TCPListenAddr: &net.TCPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.DNS.Port}, FilteringConfig: config.DNS.FilteringConfig, + ConfigModified: onConfigModified, + HTTPRegister: httpRegister, + OnDNSRequest: onDNSRequest, } if config.TLS.Enabled { @@ -165,15 +168,7 @@ func generateServerConfig() (dnsforward.ServerConfig, error) { } } - upstreamConfig, err := proxy.ParseUpstreamsConfig(config.DNS.UpstreamDNS, config.DNS.BootstrapDNS, dnsforward.DefaultTimeout) - if err != nil { - return newconfig, fmt.Errorf("Couldn't get upstreams configuration cause: %s", err) - } - newconfig.Upstreams = upstreamConfig.Upstreams - newconfig.DomainsReservedUpstreams = upstreamConfig.DomainReservedUpstreams - newconfig.AllServers = config.DNS.AllServers newconfig.FilterHandler = applyAdditionalFiltering - newconfig.OnDNSRequest = onDNSRequest return newconfig, nil } diff --git a/home/filter.go b/home/filter.go index d5ce7ffc..04a32d87 100644 --- a/home/filter.go +++ b/home/filter.go @@ -499,7 +499,7 @@ func (filter *filter) LastTimeUpdated() time.Time { func enableFilters(async bool) { var filters map[int]string - if config.DNS.FilteringConfig.FilteringEnabled { + if config.DNS.FilteringEnabled { // convert array of filters filters = make(map[int]string) diff --git a/home/i18n.go b/home/i18n.go index 610a44c2..4b20f606 100644 --- a/home/i18n.go +++ b/home/i18n.go @@ -79,6 +79,6 @@ func handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) { } config.Language = language - - httpUpdateConfigReloadDNSReturnOK(w, r) + onConfigModified() + returnOK(w) }