diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index 61493f05..0f667850 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -577,6 +577,10 @@ func New(c *Config, filters map[int]string) *Dnsfilter { } } + return d +} + +func (d *Dnsfilter) Start() { d.filtersInitializerChan = make(chan filtersInitializerParams, 1) go d.filtersInitializer() @@ -584,7 +588,6 @@ func New(c *Config, filters map[int]string) *Dnsfilter { d.registerSecurityHandlers() d.registerRewritesHandlers() } - return d } // diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index c51a25b3..32875db1 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -33,6 +33,8 @@ var defaultDNS = []string{ } var defaultBootstrap = []string{"9.9.9.9", "149.112.112.112"} +var webRegistered bool + // Server is the main way to start a DNS server. // // Example: @@ -54,8 +56,7 @@ type Server struct { // We don't Start() it and so no listen port is required. internalProxy *proxy.Proxy - webRegistered bool - isRunning bool + isRunning bool sync.RWMutex conf ServerConfig @@ -318,8 +319,8 @@ func (s *Server) Prepare(config *ServerConfig) error { log.Fatal("len(proxyConfig.Upstreams) == 0") } - if !s.webRegistered && s.conf.HTTPRegister != nil { - s.webRegistered = true + if !webRegistered && s.conf.HTTPRegister != nil { + webRegistered = true s.registerHandlers() } diff --git a/home/control_install.go b/home/control_install.go index 9d13ec74..fe811ba0 100644 --- a/home/control_install.go +++ b/home/control_install.go @@ -244,6 +244,9 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) { var err2 error if err == nil { err2 = startDNSServer() + if err2 != nil { + closeDNSServer() + } } if err != nil || err2 != nil { config.firstRun = true diff --git a/home/dns.go b/home/dns.go index f760ff77..6a8be1fd 100644 --- a/home/dns.go +++ b/home/dns.go @@ -66,12 +66,14 @@ func initDNSServer() error { dnsConfig := generateServerConfig() err = Context.dnsServer.Prepare(&dnsConfig) if err != nil { + closeDNSServer() return fmt.Errorf("dnsServer.Prepare: %s", err) } sessFilename := filepath.Join(baseDir, "sessions.db") config.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60) if config.auth == nil { + closeDNSServer() return fmt.Errorf("Couldn't initialize Auth module") } config.Users = nil @@ -224,7 +226,10 @@ func startDNSServer() error { return errorx.Decorate(err, "Couldn't start forwarding DNS server") } + Context.dnsFilter.Start() startFiltering() + Context.stats.Start() + Context.queryLog.Start() const topClientsNumber = 100 // the number of clients to get topClients := Context.stats.GetTopClientsIP(topClientsNumber) @@ -261,19 +266,36 @@ func stopDNSServer() error { return errorx.Decorate(err, "Couldn't stop forwarding DNS server") } - // DNS forward module must be closed BEFORE stats or queryLog because it depends on them - Context.dnsServer.Close() - - Context.dnsFilter.Close() - Context.dnsFilter = nil - - Context.stats.Close() - Context.stats = nil - - Context.queryLog.Close() - Context.queryLog = nil - - config.auth.Close() - config.auth = nil + closeDNSServer() return nil } + +func closeDNSServer() { + // DNS forward module must be closed BEFORE stats or queryLog because it depends on them + if Context.dnsServer != nil { + Context.dnsServer.Close() + Context.dnsServer = nil + } + + if Context.dnsFilter != nil { + Context.dnsFilter.Close() + Context.dnsFilter = nil + } + + if Context.stats != nil { + Context.stats.Close() + Context.stats = nil + } + + if Context.queryLog != nil { + Context.queryLog.Close() + Context.queryLog = nil + } + + if config.auth != nil { + config.auth.Close() + config.auth = nil + } + + log.Debug("Closed all DNS modules") +} diff --git a/querylog/qlog.go b/querylog/qlog.go index 8c48c969..8a34130e 100644 --- a/querylog/qlog.go +++ b/querylog/qlog.go @@ -44,11 +44,14 @@ func newQueryLog(conf Config) *queryLog { if !checkInterval(l.conf.Interval) { l.conf.Interval = 1 } + return &l +} + +func (l *queryLog) Start() { if l.conf.HTTPRegister != nil { l.initWeb() } go l.periodicRotate() - return &l } func (l *queryLog) Close() { diff --git a/querylog/querylog.go b/querylog/querylog.go index 7d479d92..dcca14dd 100644 --- a/querylog/querylog.go +++ b/querylog/querylog.go @@ -18,6 +18,8 @@ type DiskConfig struct { // QueryLog - main interface type QueryLog interface { + Start() + // Close query log object Close() diff --git a/stats/stats.go b/stats/stats.go index f4c05b94..91b6b25f 100644 --- a/stats/stats.go +++ b/stats/stats.go @@ -36,6 +36,8 @@ func New(conf Config) (Stats, error) { // Stats - main interface type Stats interface { + Start() + // Close object. // This function is not thread safe // (can't be called in parallel with any other function of this interface). diff --git a/stats/stats_unit.go b/stats/stats_unit.go index b52b31f0..44aa66d5 100644 --- a/stats/stats_unit.go +++ b/stats/stats_unit.go @@ -114,14 +114,15 @@ func createObject(conf Config) (*statsCtx, error) { } s.unit = &u - s.initWeb() - - go s.periodicFlush() - log.Debug("Stats: initialized") return &s, nil } +func (s *statsCtx) Start() { + s.initWeb() + go s.periodicFlush() +} + func checkInterval(days uint32) bool { return days == 1 || days == 7 || days == 30 || days == 90 }