diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index 862bf162..8c95eff4 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -44,6 +44,12 @@ type Server struct { queryLog querylog.QueryLog // Query log instance stats stats.Stats + // How many times the server was started + // While creating a dnsfilter object, + // we use this value to set s.dnsFilter property only with the most recent settings. + startCounter uint32 + dnsfilterCreatorChan chan dnsfilterCreatorParams + AllowedClients map[string]bool // IP addresses of whitelist clients DisallowedClients map[string]bool // IP addresses of clients that should be blocked AllowedClientsIPNet []net.IPNet // CIDRs of whitelist clients @@ -54,6 +60,11 @@ type Server struct { conf ServerConfig } +type dnsfilterCreatorParams struct { + conf dnsfilter.Config + filters map[int]string +} + // NewServer creates a new instance of the dnsforward.Server // Note: this function must be called only once func NewServer(stats stats.Stats, queryLog querylog.QueryLog) *Server { @@ -73,6 +84,12 @@ func (s *Server) Close() { // FilteringConfig represents the DNS filtering configuration of AdGuard Home // The zero FilteringConfig is empty and ready for use. type FilteringConfig struct { + // Create dnsfilter asynchronously. + // Requests won't be filtered until dnsfilter is created. + // If "restart" command is received while we're creating an old dnsfilter object, + // we delay creation of the new object until the old one is created. + AsyncStartup bool + ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours) @@ -254,8 +271,6 @@ func (s *Server) startInternal(config *ServerConfig) error { // Initializes the DNS filter func (s *Server) initDNSFilter(config *ServerConfig) error { - log.Tracef("Creating dnsfilter") - if config != nil { s.conf = *config } @@ -280,13 +295,71 @@ func (s *Server) initDNSFilter(config *ServerConfig) error { s.conf.SafeBrowsingBlockHost = safeBrowsingBlockHost } - s.dnsFilter = dnsfilter.New(&s.conf.Config, filters) - if s.dnsFilter == nil { - return fmt.Errorf("could not initialize dnsfilter") + if s.conf.AsyncStartup { + params := dnsfilterCreatorParams{ + conf: s.conf.Config, + filters: filters, + } + s.startCounter++ + if s.startCounter == 1 { + s.dnsfilterCreatorChan = make(chan dnsfilterCreatorParams, 1) + go s.dnsfilterCreator() + } + + // remove all pending tasks + stop := false + for !stop { + select { + case <-s.dnsfilterCreatorChan: + // + default: + stop = true + } + } + + s.dnsfilterCreatorChan <- params + } else { + log.Debug("creating dnsfilter...") + f := dnsfilter.New(&s.conf.Config, filters) + if f == nil { + return fmt.Errorf("could not initialize dnsfilter") + } + log.Debug("created dnsfilter") + s.dnsFilter = f } return nil } +func (s *Server) dnsfilterCreator() { + for { + params := <-s.dnsfilterCreatorChan + + s.Lock() + counter := s.startCounter + s.Unlock() + + log.Debug("creating dnsfilter...") + f := dnsfilter.New(¶ms.conf, params.filters) + if f == nil { + log.Error("could not initialize dnsfilter") + continue + } + + set := false + s.Lock() + if counter == s.startCounter { + s.dnsFilter = f + set = true + } + s.Unlock() + if set { + log.Debug("created and activated dnsfilter") + } else { + log.Debug("created dnsfilter") + } + } +} + // Stop stops the DNS server func (s *Server) Stop() error { s.Lock() diff --git a/home/dns.go b/home/dns.go index c9f199e1..bf5a2daa 100644 --- a/home/dns.go +++ b/home/dns.go @@ -104,6 +104,7 @@ func generateServerConfig() (dnsforward.ServerConfig, error) { FilteringConfig: config.DNS.FilteringConfig, Filters: filters, } + newconfig.AsyncStartup = true bindhost := config.DNS.BindHost if config.DNS.BindHost == "0.0.0.0" { bindhost = "127.0.0.1"