diff --git a/app.go b/app.go index b55e9863..ce6bd2f7 100644 --- a/app.go +++ b/app.go @@ -6,6 +6,7 @@ import ( "net" "net/http" "os" + "os/signal" "path/filepath" "strconv" @@ -15,7 +16,12 @@ import ( // VersionString will be set through ldflags, contains current version var VersionString = "undefined" +func cleanup() { + writeStats() +} + func main() { + c := make(chan os.Signal, 1) log.Printf("AdGuard DNS web interface backend, version %s\n", VersionString) box := packr.NewBox("build/static") { @@ -114,6 +120,18 @@ func main() { log.Fatal(err) } + err = loadStats() + if err != nil { + log.Fatal(err) + } + + signal.Notify(c, os.Interrupt) + go func() { + <-c + cleanup() + os.Exit(1) + }() + address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort)) runStatsCollectors() diff --git a/control.go b/control.go index eae9975a..b16b6e22 100644 --- a/control.go +++ b/control.go @@ -217,7 +217,7 @@ func handleStatus(w http.ResponseWriter, r *http.Request) { // stats // ----- func handleStats(w http.ResponseWriter, r *http.Request) { - histrical := generateMapFromStats(&statistics.perHour, 0, 24) + histrical := generateMapFromStats(&statistics.PerHour, 0, 24) // sum them up summed := map[string]interface{}{} for key, values := range histrical { @@ -259,16 +259,16 @@ func handleStatsHistory(w http.ResponseWriter, r *http.Request) { switch timeUnitString { case "seconds": timeUnit = time.Second - stats = &statistics.perSecond + stats = &statistics.PerSecond case "minutes": timeUnit = time.Minute - stats = &statistics.perMinute + stats = &statistics.PerMinute case "hours": timeUnit = time.Hour - stats = &statistics.perHour + stats = &statistics.PerHour case "days": timeUnit = time.Hour * 24 - stats = &statistics.perDay + stats = &statistics.PerDay default: http.Error(w, "Must specify valid time_unit parameter", 400) return diff --git a/helpers.go b/helpers.go index d6997141..a35a5f1c 100644 --- a/helpers.go +++ b/helpers.go @@ -111,8 +111,8 @@ func generateMapFromStats(stats *periodicStats, start int, end int) map[string]i avgProcessingTime := make([]float64, 0) - count := getReversedSlice(stats.entries[processingTimeCount], start, end) - sum := getReversedSlice(stats.entries[processingTimeSum], start, end) + count := getReversedSlice(stats.Entries[processingTimeCount], start, end) + sum := getReversedSlice(stats.Entries[processingTimeSum], start, end) for i := 0; i < len(count); i++ { var avg float64 if count[i] != 0 { @@ -123,11 +123,11 @@ func generateMapFromStats(stats *periodicStats, start int, end int) map[string]i } result := map[string]interface{}{ - "dns_queries": getReversedSlice(stats.entries[totalRequests], start, end), - "blocked_filtering": getReversedSlice(stats.entries[filteredTotal], start, end), - "replaced_safebrowsing": getReversedSlice(stats.entries[filteredSafebrowsing], start, end), - "replaced_safesearch": getReversedSlice(stats.entries[filteredSafesearch], start, end), - "replaced_parental": getReversedSlice(stats.entries[filteredParental], start, end), + "dns_queries": getReversedSlice(stats.Entries[totalRequests], start, end), + "blocked_filtering": getReversedSlice(stats.Entries[filteredTotal], start, end), + "replaced_safebrowsing": getReversedSlice(stats.Entries[filteredSafebrowsing], start, end), + "replaced_safesearch": getReversedSlice(stats.Entries[filteredSafesearch], start, end), + "replaced_parental": getReversedSlice(stats.Entries[filteredParental], start, end), "avg_processing_time": avgProcessingTime, } return result diff --git a/stats.go b/stats.go index f2509334..163559a8 100644 --- a/stats.go +++ b/stats.go @@ -2,12 +2,14 @@ package main import ( "bufio" + "encoding/json" "io/ioutil" "log" "net" "net/http" "net/url" "os" + "path/filepath" "strconv" "strings" "syscall" @@ -33,31 +35,41 @@ const ( processingTimeCount = `coredns_dns_request_duration_seconds_count` ) +var entryWhiteList = map[string]bool{ + totalRequests: true, + filteredTotal: true, + filteredSafebrowsing: true, + filteredSafesearch: true, + filteredParental: true, + processingTimeSum: true, + processingTimeCount: true, +} + type periodicStats struct { - entries statsEntries - lastRotate time.Time // last time this data was rotated + Entries statsEntries + LastRotate time.Time // last time this data was rotated } type stats struct { - perSecond periodicStats - perMinute periodicStats - perHour periodicStats - perDay periodicStats + PerSecond periodicStats + PerMinute periodicStats + PerHour periodicStats + PerDay periodicStats - lastSeen statsEntry + LastSeen statsEntry } var statistics stats func initPeriodicStats(periodic *periodicStats) { - periodic.entries = statsEntries{} + periodic.Entries = statsEntries{} } func init() { - initPeriodicStats(&statistics.perSecond) - initPeriodicStats(&statistics.perMinute) - initPeriodicStats(&statistics.perHour) - initPeriodicStats(&statistics.perDay) + initPeriodicStats(&statistics.PerSecond) + initPeriodicStats(&statistics.PerMinute) + initPeriodicStats(&statistics.PerHour) + initPeriodicStats(&statistics.PerDay) } func runStatsCollectors() { @@ -85,37 +97,29 @@ func isConnRefused(err error) bool { return false } -func statsRotate(periodic *periodicStats, now time.Time) { - for key, values := range periodic.entries { - newValues := [statsHistoryElements]float64{} - for i := 1; i < len(values); i++ { - newValues[i] = values[i-1] +func statsRotate(periodic *periodicStats, now time.Time, rotations int64) { + // calculate how many times we should rotate + for r := int64(0); r < rotations; r++ { + for key, values := range periodic.Entries { + newValues := [statsHistoryElements]float64{} + for i := 1; i < len(values); i++ { + newValues[i] = values[i-1] + } + periodic.Entries[key] = newValues } - periodic.entries[key] = newValues } - periodic.lastRotate = now + if rotations > 0 { + periodic.LastRotate = now + } } // called every second, accumulates stats for each second, minute, hour and day func collectStats() { now := time.Now() - // rotate each second - // NOTE: since we are called every second, always rotate perSecond, otherwise aliasing problems cause the rotation to skip - if true { - statsRotate(&statistics.perSecond, now) - } - // if minute elapsed, rotate - if now.Sub(statistics.perMinute.lastRotate).Minutes() >= 1 { - statsRotate(&statistics.perMinute, now) - } - // if hour elapsed, rotate - if now.Sub(statistics.perHour.lastRotate).Hours() >= 1 { - statsRotate(&statistics.perHour, now) - } - // if day elapsed, rotate - if now.Sub(statistics.perDay.lastRotate).Hours()/24.0 >= 1 { - statsRotate(&statistics.perDay, now) - } + statsRotate(&statistics.PerSecond, now, int64(now.Sub(statistics.PerSecond.LastRotate)/time.Second)) + statsRotate(&statistics.PerMinute, now, int64(now.Sub(statistics.PerMinute.LastRotate)/time.Minute)) + statsRotate(&statistics.PerHour, now, int64(now.Sub(statistics.PerHour.LastRotate)/time.Hour)) + statsRotate(&statistics.PerDay, now, int64(now.Sub(statistics.PerDay.LastRotate)/time.Hour/24)) // grab HTTP from prometheus resp, err := client.Get("http://127.0.0.1:9153/metrics") @@ -169,6 +173,11 @@ func collectStats() { continue } + // keys not in whitelist are not ok + if entryWhiteList[key] == false { + continue + } + got, ok := entry[key] if ok { value += got @@ -177,16 +186,16 @@ func collectStats() { } // calculate delta - delta := calcDelta(entry, statistics.lastSeen) + delta := calcDelta(entry, statistics.LastSeen) // apply delta to second/minute/hour/day - applyDelta(&statistics.perSecond, delta) - applyDelta(&statistics.perMinute, delta) - applyDelta(&statistics.perHour, delta) - applyDelta(&statistics.perDay, delta) + applyDelta(&statistics.PerSecond, delta) + applyDelta(&statistics.PerMinute, delta) + applyDelta(&statistics.PerHour, delta) + applyDelta(&statistics.PerDay, delta) // save last seen - statistics.lastSeen = entry + statistics.LastSeen = entry } func calcDelta(current, seen statsEntry) statsEntry { @@ -201,8 +210,50 @@ func calcDelta(current, seen statsEntry) statsEntry { func applyDelta(current *periodicStats, delta statsEntry) { for key, deltaValue := range delta { - currentValues := current.entries[key] + currentValues := current.Entries[key] currentValues[0] += deltaValue - current.entries[key] = currentValues + current.Entries[key] = currentValues } } + +func loadStats() error { + statsFile := filepath.Join(config.ourBinaryDir, "stats.json") + if _, err := os.Stat(statsFile); os.IsNotExist(err) { + log.Printf("Stats JSON does not exist, skipping: %s", statsFile) + return nil + } + log.Printf("Loading JSON stats: %s", statsFile) + jsonText, err := ioutil.ReadFile(statsFile) + if err != nil { + log.Printf("Couldn't read JSON stats: %s", err) + return err + } + err = json.Unmarshal(jsonText, &statistics) + if err != nil { + log.Printf("Couldn't parse JSON stats: %s", err) + return err + } + + return nil +} + +func writeStats() error { + statsFile := filepath.Join(config.ourBinaryDir, "stats.json") + log.Printf("Writing JSON file: %s", statsFile) + json, err := json.MarshalIndent(statistics, "", " ") + if err != nil { + log.Printf("Couldn't generate JSON: %s", err) + return err + } + err = ioutil.WriteFile(statsFile+".tmp", json, 0644) + if err != nil { + log.Printf("Couldn't write stats in JSON: %s", err) + return err + } + err = os.Rename(statsFile+".tmp", statsFile) + if err != nil { + log.Printf("Couldn't rename stats JSON: %s", err) + return err + } + return nil +}