diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index 8c95eff4..fde411a9 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -88,7 +88,7 @@ type FilteringConfig struct { // Requests won't be filtered until dnsfilter is created. // If "restart" command is received while we're creating an old dnsfilter object, // we delay creation of the new object until the old one is created. - AsyncStartup bool + AsyncStartup bool `yaml:"-"` ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists diff --git a/home/config.go b/home/config.go index fd8bbed4..e2c52302 100644 --- a/home/config.go +++ b/home/config.go @@ -269,9 +269,6 @@ func parseConfig() error { return err } - if !checkStatsInterval(config.DNS.StatsInterval) { - config.DNS.StatsInterval = 1 - } if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) { config.DNS.FiltersUpdateIntervalHours = 24 } @@ -356,6 +353,12 @@ func (c *configuration) write() error { config.Users = config.auth.GetUsers() } + if config.stats != nil { + sdc := stats.DiskConfig{} + config.stats.WriteDiskConfig(&sdc) + config.DNS.StatsInterval = sdc.Interval + } + configFile := config.getConfigFilename() log.Debug("Writing YAML file: %s", configFile) yamlText, err := yaml.Marshal(&config) diff --git a/home/control.go b/home/control.go index f6cded37..7d3d3f21 100644 --- a/home/control.go +++ b/home/control.go @@ -569,14 +569,11 @@ func registerControlHandlers() { registerRewritesHandlers() RegisterBlockedServicesHandlers() RegisterQueryLogHandlers() - RegisterStatsHandlers() RegisterAuthHandlers() http.HandleFunc("/dns-query", postInstall(handleDOH)) } -type httpHandlerType func(http.ResponseWriter, *http.Request) - -func httpRegister(method string, url string, handler httpHandlerType) { +func httpRegister(method string, url string, handler func(http.ResponseWriter, *http.Request)) { http.Handle(url, postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureHandler(method, handler))))) } diff --git a/home/control_stats.go b/home/control_stats.go deleted file mode 100644 index 67b26f6b..00000000 --- a/home/control_stats.go +++ /dev/null @@ -1,94 +0,0 @@ -package home - -import ( - "encoding/json" - "net/http" - - "github.com/AdguardTeam/AdGuardHome/stats" - "github.com/AdguardTeam/golibs/log" -) - -type statsConfig struct { - Interval uint32 `json:"interval"` -} - -// Get stats configuration -func handleStatsInfo(w http.ResponseWriter, r *http.Request) { - resp := statsConfig{} - resp.Interval = config.DNS.StatsInterval - - jsonVal, err := json.Marshal(resp) - if err != nil { - httpError(w, http.StatusInternalServerError, "json encode: %s", err) - return - } - w.Header().Set("Content-Type", "application/json") - _, err = w.Write(jsonVal) - if err != nil { - httpError(w, http.StatusInternalServerError, "http write: %s", err) - } -} - -// Set stats configuration -func handleStatsConfig(w http.ResponseWriter, r *http.Request) { - reqData := statsConfig{} - err := json.NewDecoder(r.Body).Decode(&reqData) - if err != nil { - httpError(w, http.StatusBadRequest, "json decode: %s", err) - return - } - - if !checkStatsInterval(reqData.Interval) { - httpError(w, http.StatusBadRequest, "Unsupported interval") - return - } - - config.DNS.StatsInterval = reqData.Interval - _ = config.write() - - config.stats.Configure(int(config.DNS.StatsInterval)) - - returnOK(w) -} - -// handleStats returns aggregated stats data -func handleStats(w http.ResponseWriter, r *http.Request) { - units := stats.Hours - if config.DNS.StatsInterval > 7 { - units = stats.Days - } - counter := log.StartTimer() - d := config.stats.GetData(units) - counter.LogElapsed("Stats: prepared data") - - if d == nil { - httpError(w, http.StatusInternalServerError, "Couldn't get statistics data") - return - } - - data, err := json.Marshal(d) - if err != nil { - httpError(w, http.StatusInternalServerError, "json encode: %s", err) - return - } - - w.Write(data) -} - -// handleStatsReset resets the stats -func handleStatsReset(w http.ResponseWriter, r *http.Request) { - config.stats.Clear() - returnOK(w) -} - -// RegisterStatsHandlers - register handlers -func RegisterStatsHandlers() { - httpRegister(http.MethodGet, "/control/stats", handleStats) - httpRegister(http.MethodPost, "/control/stats_reset", handleStatsReset) - httpRegister(http.MethodPost, "/control/stats_config", handleStatsConfig) - httpRegister(http.MethodGet, "/control/stats_info", handleStatsInfo) -} - -func checkStatsInterval(i uint32) bool { - return i == 1 || i == 7 || i == 30 || i == 90 -} diff --git a/home/dns.go b/home/dns.go index e51cd9a2..c9c8c003 100644 --- a/home/dns.go +++ b/home/dns.go @@ -21,6 +21,11 @@ type dnsContext struct { whois *Whois } +// Called by other modules when configuration is changed +func onConfigModified() { + _ = config.write() +} + // initDNSServer creates an instance of the dnsforward.Server // Please note that we must do it even if we don't start it // so that we had access to the query log and the stats @@ -31,8 +36,10 @@ func initDNSServer(baseDir string) { } statsConf := stats.Config{ - Filename: filepath.Join(baseDir, "stats.db"), - LimitDays: config.DNS.StatsInterval, + Filename: filepath.Join(baseDir, "stats.db"), + LimitDays: config.DNS.StatsInterval, + ConfigModified: onConfigModified, + HTTPRegister: httpRegister, } config.stats, err = stats.New(statsConf) if err != nil { diff --git a/stats/stats.go b/stats/stats.go index d2b59859..85d0e149 100644 --- a/stats/stats.go +++ b/stats/stats.go @@ -4,15 +4,27 @@ package stats import ( "net" + "net/http" ) type unitIDCallback func() uint32 +// DiskConfig - configuration settings that are stored on disk +type DiskConfig struct { + Interval uint32 `yaml:"statistics_interval"` // time interval for statistics (in days) +} + // Config - module configuration type Config struct { Filename string // database file name LimitDays uint32 // time limit (in days) UnitID unitIDCallback // user function to get the current unit ID. If nil, the current time hour is used. + + // Called when the configuration is changed by HTTP request + ConfigModified func() + + // Register an HTTP handler + HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) } // New - create object @@ -27,18 +39,11 @@ type Stats interface { // (can't be called in parallel with any other function of this interface). Close() - // Set new configuration at runtime. - // limit: time limit (in days) - Configure(limit int) - - // Reset counters and clear database - Clear() - // Update counters Update(e Entry) - // Get data - GetData(timeUnit TimeUnit) map[string]interface{} + // WriteDiskConfig - write configuration + WriteDiskConfig(dc *DiskConfig) } // TimeUnit - time unit diff --git a/stats/stats_http.go b/stats/stats_http.go new file mode 100644 index 00000000..d5f08ddc --- /dev/null +++ b/stats/stats_http.go @@ -0,0 +1,107 @@ +// HTTP request handlers for accessing statistics data and configuration settings + +package stats + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/AdguardTeam/golibs/log" +) + +func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) { + text := fmt.Sprintf(format, args...) + + log.Info("Stats: %s %s: %s", r.Method, r.URL, text) + + http.Error(w, text, code) +} + +// Return data +func (s *statsCtx) handleStats(w http.ResponseWriter, r *http.Request) { + units := Hours + if s.limit/24 > 7 { + units = Days + } + counter := log.StartTimer() + d := s.getData(units) + counter.LogElapsed("Stats: prepared data") + + if d == nil { + httpError(r, w, http.StatusInternalServerError, "Couldn't get statistics data") + return + } + + data, err := json.Marshal(d) + if err != nil { + httpError(r, w, http.StatusInternalServerError, "json encode: %s", err) + return + } + + w.Write(data) +} + +type config struct { + IntervalDays uint32 `json:"interval"` +} + +// Get configuration +func (s *statsCtx) handleStatsInfo(w http.ResponseWriter, r *http.Request) { + resp := config{} + resp.IntervalDays = s.limit / 24 + + data, err := json.Marshal(resp) + if err != nil { + httpError(r, w, http.StatusInternalServerError, "json encode: %s", err) + return + } + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(data) + if err != nil { + httpError(r, w, http.StatusInternalServerError, "http write: %s", err) + } +} + +// Set configuration +func (s *statsCtx) handleStatsConfig(w http.ResponseWriter, r *http.Request) { + reqData := config{} + err := json.NewDecoder(r.Body).Decode(&reqData) + if err != nil { + httpError(r, w, http.StatusBadRequest, "json decode: %s", err) + return + } + + if !checkInterval(reqData.IntervalDays) { + httpError(r, w, http.StatusBadRequest, "Unsupported interval") + return + } + + s.setLimit(int(reqData.IntervalDays)) + s.conf.ConfigModified() +} + +// Reset data +func (s *statsCtx) handleStatsReset(w http.ResponseWriter, r *http.Request) { + s.clear() +} + +// Register web handlers +func (s *statsCtx) initWeb() { + if s.conf.HTTPRegister == nil { + return + } + + s.conf.HTTPRegister("GET", "/control/stats", func(w http.ResponseWriter, r *http.Request) { + s.handleStats(w, r) + }) + s.conf.HTTPRegister("POST", "/control/stats_reset", func(w http.ResponseWriter, r *http.Request) { + s.handleStatsReset(w, r) + }) + s.conf.HTTPRegister("POST", "/control/stats_config", func(w http.ResponseWriter, r *http.Request) { + s.handleStatsConfig(w, r) + }) + s.conf.HTTPRegister("GET", "/control/stats_info", func(w http.ResponseWriter, r *http.Request) { + s.handleStatsInfo(w, r) + }) +} diff --git a/stats/stats_test.go b/stats/stats_test.go index 0590b443..3fa4ecd3 100644 --- a/stats/stats_test.go +++ b/stats/stats_test.go @@ -30,7 +30,7 @@ func TestStats(t *testing.T) { Filename: "./stats.db", LimitDays: 1, } - s, _ := New(conf) + s, _ := createObject(conf) e := Entry{} @@ -46,7 +46,7 @@ func TestStats(t *testing.T) { e.Time = 123456 s.Update(e) - d := s.GetData(Hours) + d := s.getData(Hours) a := []uint64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2} assert.True(t, UIntArrayEquals(d["dns_queries"].([]uint64), a)) @@ -75,7 +75,7 @@ func TestStats(t *testing.T) { assert.True(t, d["num_replaced_parental"].(uint64) == 0) assert.True(t, d["avg_processing_time"].(float64) == 0.123456) - s.Clear() + s.clear() s.Close() os.Remove(conf.Filename) } @@ -95,7 +95,7 @@ func TestLargeNumbers(t *testing.T) { UnitID: newID, } os.Remove(conf.Filename) - s, _ := New(conf) + s, _ := createObject(conf) e := Entry{} n := 1000 // number of distinct clients and domains every hour @@ -115,7 +115,7 @@ func TestLargeNumbers(t *testing.T) { } } - d := s.GetData(Hours) + d := s.getData(Hours) assert.True(t, d["num_dns_queries"].(uint64) == uint64(int(hour)*n)) s.Close() diff --git a/stats/stats_unit.go b/stats/stats_unit.go index eb6d6e85..3f7e49a0 100644 --- a/stats/stats_unit.go +++ b/stats/stats_unit.go @@ -64,6 +64,9 @@ type unitDB struct { func createObject(conf Config) (*statsCtx, error) { s := statsCtx{} + if !checkInterval(conf.LimitDays) { + conf.LimitDays = 1 + } s.limit = conf.LimitDays * 24 s.conf = conf if conf.UnitID == nil { @@ -112,12 +115,18 @@ func createObject(conf Config) (*statsCtx, error) { } s.unit = &u + s.initWeb() + go s.periodicFlush() log.Debug("Stats: initialized") return &s, nil } +func checkInterval(days uint32) bool { + return days == 1 || days == 7 || days == 30 || days == 90 +} + func (s *statsCtx) dbOpen() bool { var err error log.Tracef("db.Open...") @@ -362,12 +371,13 @@ func convertTopArray(a []countPair) []map[string]uint64 { return m } -func (s *statsCtx) Configure(limit int) { - if limit < 0 { - return - } - s.limit = uint32(limit) * 24 - log.Debug("Stats: set limit: %d", limit) +func (s *statsCtx) setLimit(limitDays int) { + s.limit = uint32(limitDays) * 24 + log.Debug("Stats: set limit: %d", limitDays) +} + +func (s *statsCtx) WriteDiskConfig(dc *DiskConfig) { + dc.Interval = s.limit / 24 } func (s *statsCtx) Close() { @@ -391,7 +401,8 @@ func (s *statsCtx) Close() { log.Debug("Stats: closed") } -func (s *statsCtx) Clear() { +// Reset counters and clear database +func (s *statsCtx) clear() { tx := s.beginTxn(true) if tx != nil { db := s.db @@ -472,7 +483,7 @@ func (s *statsCtx) Update(e Entry) { These values are just the sum of data for all units. */ // nolint (gocyclo) -func (s *statsCtx) GetData(timeUnit TimeUnit) map[string]interface{} { +func (s *statsCtx) getData(timeUnit TimeUnit) map[string]interface{} { d := map[string]interface{}{} tx := s.beginTxn(false)