diff --git a/control.go b/control.go
index a1b88900..d00d5e07 100644
--- a/control.go
+++ b/control.go
@@ -222,115 +222,68 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
 // stats
 // -----
 func handleStats(w http.ResponseWriter, r *http.Request) {
-	histrical := generateMapFromStats(&statistics.PerHour, 0, 24)
-	// sum them up
-	summed := map[string]interface{}{}
-	for key, values := range histrical {
-		summedValue := 0.0
-		floats, ok := values.([]float64)
-		if !ok {
-			continue
-		}
-		for _, v := range floats {
-			summedValue += v
-		}
-		summed[key] = summedValue
+	resp, err := client.Get("http://127.0.0.1:8618/stats")
+	if err != nil {
+		errortext := fmt.Sprintf("Couldn't get stats from coredns: %T %s", err, err)
+		log.Println(errortext)
+		http.Error(w, errortext, http.StatusBadGateway)
+		return
+	}
+	if resp != nil && resp.Body != nil {
+		defer resp.Body.Close()
 	}
-	summed["stats_period"] = "24 hours"

-	json, err := json.Marshal(summed)
+	// read the body entirely
+	body, err := ioutil.ReadAll(resp.Body)
 	if err != nil {
-		errortext := fmt.Sprintf("Unable to marshal status json: %s", err)
+		errortext := fmt.Sprintf("Couldn't read response body: %s", err)
 		log.Println(errortext)
-		http.Error(w, errortext, 500)
+		http.Error(w, errortext, http.StatusBadGateway)
 		return
 	}
+
+	// forward body entirely with status code
 	w.Header().Set("Content-Type", "application/json")
-	_, err = w.Write(json)
+	w.Header().Set("Content-Length", strconv.Itoa(len(body)))
+	w.WriteHeader(resp.StatusCode)
+	_, err = w.Write(body)
 	if err != nil {
-		errortext := fmt.Sprintf("Unable to write response json: %s", err)
+		errortext := fmt.Sprintf("Couldn't write body: %s", err)
 		log.Println(errortext)
-		http.Error(w, errortext, 500)
-		return
+		http.Error(w, errortext, http.StatusInternalServerError)
 	}
 }

 func handleStatsHistory(w http.ResponseWriter, r *http.Request) {
-	// handle time unit and prepare our time window size
-	now := time.Now()
-	timeUnitString := r.URL.Query().Get("time_unit")
-	var stats *periodicStats
-	var timeUnit time.Duration
-	switch timeUnitString {
-	case "seconds":
-		timeUnit = time.Second
-		stats = &statistics.PerSecond
-	case "minutes":
-		timeUnit = time.Minute
-		stats = &statistics.PerMinute
-	case "hours":
-		timeUnit = time.Hour
-		stats = &statistics.PerHour
-	case "days":
-		timeUnit = time.Hour * 24
-		stats = &statistics.PerDay
-	default:
-		http.Error(w, "Must specify valid time_unit parameter", 400)
-		return
-	}
-
-	// parse start and end time
-	startTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("start_time"))
+	resp, err := client.Get("http://127.0.0.1:8618/stats_history?" + r.URL.RawQuery)
 	if err != nil {
-		errortext := fmt.Sprintf("Must specify valid start_time parameter: %s", err)
+		errortext := fmt.Sprintf("Couldn't get stats_history from coredns: %T %s", err, err)
 		log.Println(errortext)
-		http.Error(w, errortext, 400)
+		http.Error(w, errortext, http.StatusBadGateway)
 		return
 	}
-	endTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("end_time"))
+	if resp != nil && resp.Body != nil {
+		defer resp.Body.Close()
+	}
+
+	// read the body entirely
+	body, err := ioutil.ReadAll(resp.Body)
 	if err != nil {
-		errortext := fmt.Sprintf("Must specify valid end_time parameter: %s", err)
+		errortext := fmt.Sprintf("Couldn't read response body: %s", err)
 		log.Println(errortext)
-		http.Error(w, errortext, 400)
+		http.Error(w, errortext, http.StatusBadGateway)
 		return
 	}

-	// check if start and time times are within supported time range
-	timeRange := timeUnit * statsHistoryElements
-	if startTime.Add(timeRange).Before(now) {
-		http.Error(w, "start_time parameter is outside of supported range", 501)
-		return
-	}
-	if endTime.Add(timeRange).Before(now) {
-		http.Error(w, "end_time parameter is outside of supported range", 501)
-		return
-	}
-
-	// calculate start and end of our array
-	// basically it's how many hours/minutes/etc have passed since now
-	start := int(now.Sub(endTime) / timeUnit)
-	end := int(now.Sub(startTime) / timeUnit)
-
-	// swap them around if they're inverted
-	if start > end {
-		start, end = end, start
-	}
-
-	data := generateMapFromStats(stats, start, end)
-	json, err := json.Marshal(data)
-	if err != nil {
-		errortext := fmt.Sprintf("Unable to marshal status json: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, 500)
-		return
-	}
+	// forward body entirely with status code
 	w.Header().Set("Content-Type", "application/json")
-	_, err = w.Write(json)
+	w.Header().Set("Content-Length", strconv.Itoa(len(body)))
+	w.WriteHeader(resp.StatusCode)
+	_, err = w.Write(body)
 	if err != nil {
-		errortext := fmt.Sprintf("Unable to write response json: %s", err)
+		errortext := fmt.Sprintf("Couldn't write body: %s", err)
 		log.Println(errortext)
-		http.Error(w, errortext, 500)
-		return
+		http.Error(w, errortext, http.StatusInternalServerError)
 	}
 }
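Editor's note: the two handlers above now share identical fetch-and-forward logic. A possible follow-up (not part of this diff) is to extract it into a helper. A minimal sketch, assuming control.go's package-level client and its existing fmt, io/ioutil, log, net/http, and strconv imports; proxyToCoredns is a hypothetical name:

// proxyToCoredns fetches url from the coredns plugin's HTTP server and
// replays the body and status code back to the web client.
func proxyToCoredns(w http.ResponseWriter, url string) {
	resp, err := client.Get(url)
	if err != nil {
		errortext := fmt.Sprintf("Couldn't get %s from coredns: %s", url, err)
		log.Println(errortext)
		http.Error(w, errortext, http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		errortext := fmt.Sprintf("Couldn't read response body: %s", err)
		log.Println(errortext)
		http.Error(w, errortext, http.StatusBadGateway)
		return
	}

	// forward body entirely with status code
	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Content-Length", strconv.Itoa(len(body)))
	w.WriteHeader(resp.StatusCode)
	if _, err := w.Write(body); err != nil {
		log.Printf("Couldn't write body: %s", err)
	}
}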
diff --git a/coredns_plugin/coredns_plugin.go b/coredns_plugin/coredns_plugin.go
index 474c3f22..195a9218 100644
--- a/coredns_plugin/coredns_plugin.go
+++ b/coredns_plugin/coredns_plugin.go
@@ -68,27 +68,6 @@ var defaultPluginSettings = plugSettings{
 	BlockedTTL: 3600, // in seconds
 }

-func newDNSCounter(name string, help string) prometheus.Counter {
-	return prometheus.NewCounter(prometheus.CounterOpts{
-		Namespace: plugin.Namespace,
-		Subsystem: "dnsfilter",
-		Name:      name,
-		Help:      help,
-	})
-}
-
-var (
-	requests             = newDNSCounter("requests_total", "Count of requests seen by dnsfilter.")
-	filtered             = newDNSCounter("filtered_total", "Count of requests filtered by dnsfilter.")
-	filteredLists        = newDNSCounter("filtered_lists_total", "Count of requests filtered by dnsfilter using lists.")
-	filteredSafebrowsing = newDNSCounter("filtered_safebrowsing_total", "Count of requests filtered by dnsfilter using safebrowsing.")
-	filteredParental     = newDNSCounter("filtered_parental_total", "Count of requests filtered by dnsfilter using parental.")
-	filteredInvalid      = newDNSCounter("filtered_invalid_total", "Count of requests filtered by dnsfilter because they were invalid.")
-	whitelisted          = newDNSCounter("whitelisted_total", "Count of requests not filtered by dnsfilter because they are whitelisted.")
-	safesearch           = newDNSCounter("safesearch_total", "Count of requests replaced by dnsfilter safesearch.")
-	errorsTotal          = newDNSCounter("errors_total", "Count of requests that dnsfilter couldn't process because of transitive errors.")
-)
-
 //
 // coredns handling functions
 //
@@ -183,10 +162,10 @@ func setupPlugin(c *caddy.Controller) (*plug, error) {
 		}
 	}

-	log.Printf("Loading top from querylog")
-	err := loadTopFromFiles()
+	log.Printf("Loading stats from querylog")
+	err := fillStatsFromQueryLog()
 	if err != nil {
-		log.Printf("Failed to load top from querylog: %s", err)
+		log.Printf("Failed to load stats from querylog: %s", err)
 		return nil, err
 	}

@@ -229,6 +208,7 @@ func setup(c *caddy.Controller) error {
 		x.MustRegister(whitelisted)
 		x.MustRegister(safesearch)
 		x.MustRegister(errorsTotal)
+		x.MustRegister(elapsedTime)
 		x.MustRegister(p)
 	}
 	return nil
@@ -562,6 +542,8 @@ func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (
 	}

 	// log
+	elapsed := time.Since(start)
+	elapsedTime.Observe(elapsed.Seconds())
 	if p.settings.QueryLogEnabled {
 		logRequest(r, rrw.Msg, result, time.Since(start), ip)
 	}
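The elapsedTime.Observe call above is the live-traffic half of the new stats; the querylog replay in querylog_top.go below uses ObserveWithTime so old entries land in the correct history bucket. Roughly, the two entry points relate like this (illustrative fragment using calls that appear in this diff):

// Live request in ServeDNS: the observation is stamped "now".
elapsed := time.Since(start)
elapsedTime.Observe(elapsed.Seconds()) // equivalent to ObserveWithTime(v, time.Now())

// Replayed querylog entry: keep its original timestamp so it is counted
// in the per-second/minute/hour/day bucket it actually belongs to.
elapsedTime.ObserveWithTime(entry.Elapsed.Seconds(), entry.Time)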
diff --git a/coredns_plugin/coredns_stats.go b/coredns_plugin/coredns_stats.go
new file mode 100644
index 00000000..d57ba397
--- /dev/null
+++ b/coredns_plugin/coredns_stats.go
@@ -0,0 +1,391 @@
+package dnsfilter
+
+import (
+	"encoding/json"
+	"fmt"
+	"log"
+	"net/http"
+	"sync"
+	"time"
+
+	"github.com/coredns/coredns/plugin"
+	"github.com/prometheus/client_golang/prometheus"
+)
+
+var (
+	requests             = newDNSCounter("requests_total", "Count of requests seen by dnsfilter.")
+	filtered             = newDNSCounter("filtered_total", "Count of requests filtered by dnsfilter.")
+	filteredLists        = newDNSCounter("filtered_lists_total", "Count of requests filtered by dnsfilter using lists.")
+	filteredSafebrowsing = newDNSCounter("filtered_safebrowsing_total", "Count of requests filtered by dnsfilter using safebrowsing.")
+	filteredParental     = newDNSCounter("filtered_parental_total", "Count of requests filtered by dnsfilter using parental.")
+	filteredInvalid      = newDNSCounter("filtered_invalid_total", "Count of requests filtered by dnsfilter because they were invalid.")
+	whitelisted          = newDNSCounter("whitelisted_total", "Count of requests not filtered by dnsfilter because they are whitelisted.")
+	safesearch           = newDNSCounter("safesearch_total", "Count of requests replaced by dnsfilter safesearch.")
+	errorsTotal          = newDNSCounter("errors_total", "Count of requests that dnsfilter couldn't process because of transitive errors.")
+	elapsedTime          = newDNSHistogram("request_duration", "Histogram of the time (in seconds) each request took.")
+)
+
+// entries for a single time period (for example, all per-second entries)
+type statsEntries map[string][statsHistoryElements]float64
+
+// how far back to keep the stats
+const statsHistoryElements = 60 + 1 // +1 for calculating delta
+
+// each periodic stat is a map of arrays
+type periodicStats struct {
+	Entries    statsEntries
+	period     time.Duration // how long one entry lasts
+	LastRotate time.Time     // last time this data was rotated
+
+	sync.RWMutex
+}
+
+type stats struct {
+	PerSecond periodicStats
+	PerMinute periodicStats
+	PerHour   periodicStats
+	PerDay    periodicStats
+}
+
+// per-second/per-minute/per-hour/per-day stats
+var statistics stats
+
+func initPeriodicStats(periodic *periodicStats, period time.Duration) {
+	periodic.Entries = statsEntries{}
+	periodic.LastRotate = time.Now()
+	periodic.period = period
+}
+
+func init() {
+	purgeStats()
+}
+
+func purgeStats() {
+	initPeriodicStats(&statistics.PerSecond, time.Second)
+	initPeriodicStats(&statistics.PerMinute, time.Minute)
+	initPeriodicStats(&statistics.PerHour, time.Hour)
+	initPeriodicStats(&statistics.PerDay, time.Hour*24)
+}
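How the fixed-size bucket arrays behave: index 0 is the current period, index N holds the value from N periods ago, and anything at least statsHistoryElements (61) periods old is dropped. A small illustration of the index math used by Inc and Observe below:

// An event from 90 seconds ago:
when := time.Now().Add(-90 * time.Second)

// Per-second stats: 90 >= 61, so the event is outside the window and ignored.
_ = int64(time.Since(when) / time.Second) // 90

// Per-minute stats: it lands in bucket [1], i.e. "one minute ago".
_ = int64(time.Since(when) / time.Minute) // 1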
+
+func (p *periodicStats) Inc(name string, when time.Time) {
+	// calculate how many periods ago this happened
+	elapsed := int64(time.Since(when) / p.period)
+	// trace("%s: %v as %v -> [%v]", name, time.Since(when), p.period, elapsed)
+	if elapsed >= statsHistoryElements {
+		return // outside of our timeframe
+	}
+	p.Lock()
+	currentValues := p.Entries[name]
+	currentValues[elapsed]++
+	p.Entries[name] = currentValues
+	p.Unlock()
+}
+
+func (p *periodicStats) Observe(name string, when time.Time, value float64) {
+	// calculate how many periods ago this happened
+	elapsed := int64(time.Since(when) / p.period)
+	// trace("%s: %v as %v -> [%v]", name, time.Since(when), p.period, elapsed)
+	if elapsed >= statsHistoryElements {
+		return // outside of our timeframe
+	}
+	p.Lock()
+	{
+		countname := name + "_count"
+		currentValues := p.Entries[countname]
+		count := currentValues[elapsed]
+		// trace("Will change p.Entries[%s][%d] from %v to %v", countname, elapsed, count, count+1)
+		count++
+		currentValues[elapsed] = count
+		p.Entries[countname] = currentValues
+	}
+	{
+		totalname := name + "_sum"
+		currentValues := p.Entries[totalname]
+		currentValues[elapsed] += value
+		p.Entries[totalname] = currentValues
+	}
+	p.Unlock()
+}
+
+func (p *periodicStats) statsRotate(now time.Time) {
+	p.Lock()
+	rotations := int64(now.Sub(p.LastRotate) / p.period)
+	if rotations > statsHistoryElements {
+		rotations = statsHistoryElements
+	}
+	// calculate how many times we should rotate
+	for r := int64(0); r < rotations; r++ {
+		for key, values := range p.Entries {
+			newValues := [statsHistoryElements]float64{}
+			for i := 1; i < len(values); i++ {
+				newValues[i] = values[i-1]
+			}
+			p.Entries[key] = newValues
+		}
+	}
+	if rotations > 0 {
+		p.LastRotate = now
+	}
+	p.Unlock()
+}
+
+func statsRotator() {
+	for range time.Tick(time.Second) {
+		now := time.Now()
+		statistics.PerSecond.statsRotate(now)
+		statistics.PerMinute.statsRotate(now)
+		statistics.PerHour.statsRotate(now)
+		statistics.PerDay.statsRotate(now)
+	}
+}
+
+// counter that wraps around prometheus Counter but also adds to periodic stats
+type counter struct {
+	name  string // used as key in periodic stats
+	value int64
+	prom  prometheus.Counter
+}
+
+func newDNSCounter(name string, help string) *counter {
+	// trace("called")
+	c := &counter{}
+	c.prom = prometheus.NewCounter(prometheus.CounterOpts{
+		Namespace: plugin.Namespace,
+		Subsystem: "dnsfilter",
+		Name:      name,
+		Help:      help,
+	})
+	c.name = name
+
+	return c
+}
+
+func (c *counter) IncWithTime(when time.Time) {
+	statistics.PerSecond.Inc(c.name, when)
+	statistics.PerMinute.Inc(c.name, when)
+	statistics.PerHour.Inc(c.name, when)
+	statistics.PerDay.Inc(c.name, when)
+	c.value++
+	c.prom.Inc()
+}
+
+func (c *counter) Inc() {
+	c.IncWithTime(time.Now())
+}
+
+func (c *counter) Describe(ch chan<- *prometheus.Desc) {
+	c.prom.Describe(ch)
+}
+
+func (c *counter) Collect(ch chan<- prometheus.Metric) {
+	c.prom.Collect(ch)
+}
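Describe and Collect make the wrapper satisfy prometheus.Collector, which is why the x.MustRegister(requests) calls in coredns_plugin.go keep working unchanged. Two optional hardening ideas, sketches only and not part of this change:

// Compile-time check that the wrapper still implements prometheus.Collector.
var _ prometheus.Collector = (*counter)(nil)

// ServeDNS runs concurrently, so if c.value is ever read outside of
// Prometheus scraping, the bare c.value++ would be safer as an atomic
// increment:
//	atomic.AddInt64(&c.value, 1) // needs "sync/atomic"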
+
+type histogram struct {
+	name  string // used as key in periodic stats
+	count int64
+	total float64
+	prom  prometheus.Histogram
+}
+
+func newDNSHistogram(name string, help string) *histogram {
+	// trace("called")
+	h := &histogram{}
+	h.prom = prometheus.NewHistogram(prometheus.HistogramOpts{
+		Namespace: plugin.Namespace,
+		Subsystem: "dnsfilter",
+		Name:      name,
+		Help:      help,
+	})
+	h.name = name
+
+	return h
+}
+
+func (h *histogram) ObserveWithTime(value float64, when time.Time) {
+	statistics.PerSecond.Observe(h.name, when, value)
+	statistics.PerMinute.Observe(h.name, when, value)
+	statistics.PerHour.Observe(h.name, when, value)
+	statistics.PerDay.Observe(h.name, when, value)
+	h.count++
+	h.total += value
+	h.prom.Observe(value)
+}
+
+func (h *histogram) Observe(value float64) {
+	h.ObserveWithTime(value, time.Now())
+}
+
+func (h *histogram) Describe(ch chan<- *prometheus.Desc) {
+	h.prom.Describe(ch)
+}
+
+func (h *histogram) Collect(ch chan<- prometheus.Metric) {
+	h.prom.Collect(ch)
+}
+
+// -----
+// stats
+// -----
+func handleStats(w http.ResponseWriter, r *http.Request) {
+	historical := generateMapFromStats(&statistics.PerHour, 0, 24)
+	// sum them up
+	summed := map[string]interface{}{}
+	for key, values := range historical {
+		summedValue := 0.0
+		floats, ok := values.([]float64)
+		if !ok {
+			continue
+		}
+		for _, v := range floats {
+			summedValue += v
+		}
+		summed[key] = summedValue
+	}
+	summed["stats_period"] = "24 hours"
+
+	json, err := json.Marshal(summed)
+	if err != nil {
+		errortext := fmt.Sprintf("Unable to marshal stats json: %s", err)
+		log.Println(errortext)
+		http.Error(w, errortext, 500)
+		return
+	}
+	w.Header().Set("Content-Type", "application/json")
+	_, err = w.Write(json)
+	if err != nil {
+		errortext := fmt.Sprintf("Unable to write response json: %s", err)
+		log.Println(errortext)
+		http.Error(w, errortext, 500)
+		return
+	}
+}
+
+func generateMapFromStats(stats *periodicStats, start int, end int) map[string]interface{} {
+	// clamp
+	start = clamp(start, 0, statsHistoryElements)
+	end = clamp(end, 0, statsHistoryElements)
+
+	avgProcessingTime := make([]float64, 0)
+
+	count := getReversedSlice(stats.Entries[elapsedTime.name+"_count"], start, end)
+	sum := getReversedSlice(stats.Entries[elapsedTime.name+"_sum"], start, end)
+	for i := 0; i < len(count); i++ {
+		var avg float64
+		if count[i] != 0 {
+			avg = sum[i] / count[i]
+			avg *= 1000
+		}
+		avgProcessingTime = append(avgProcessingTime, avg)
+	}
+
+	result := map[string]interface{}{
+		"dns_queries":           getReversedSlice(stats.Entries[requests.name], start, end),
+		"blocked_filtering":     getReversedSlice(stats.Entries[filtered.name], start, end),
+		"replaced_safebrowsing": getReversedSlice(stats.Entries[filteredSafebrowsing.name], start, end),
+		"replaced_safesearch":   getReversedSlice(stats.Entries[safesearch.name], start, end),
+		"replaced_parental":     getReversedSlice(stats.Entries[filteredParental.name], start, end),
+		"avg_processing_time":   avgProcessingTime,
+	}
+	return result
+}
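The avg_processing_time series is derived, not stored: each point is the bucket's _sum divided by its _count, scaled to milliseconds. A worked example:

// One per-hour bucket that saw 4 requests totalling 0.02s of processing:
count, sum := 4.0, 0.02
avgMs := sum / count * 1000 // == 5; that point plots as 5 ms
fmt.Println(avgMs)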
+
+func handleStatsHistory(w http.ResponseWriter, r *http.Request) {
+	// handle time unit and prepare our time window size
+	now := time.Now()
+	timeUnitString := r.URL.Query().Get("time_unit")
+	var stats *periodicStats
+	var timeUnit time.Duration
+	switch timeUnitString {
+	case "seconds":
+		timeUnit = time.Second
+		stats = &statistics.PerSecond
+	case "minutes":
+		timeUnit = time.Minute
+		stats = &statistics.PerMinute
+	case "hours":
+		timeUnit = time.Hour
+		stats = &statistics.PerHour
+	case "days":
+		timeUnit = time.Hour * 24
+		stats = &statistics.PerDay
+	default:
+		http.Error(w, "Must specify valid time_unit parameter", 400)
+		return
+	}
+
+	// parse start and end time
+	startTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("start_time"))
+	if err != nil {
+		errortext := fmt.Sprintf("Must specify valid start_time parameter: %s", err)
+		log.Println(errortext)
+		http.Error(w, errortext, 400)
+		return
+	}
+	endTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("end_time"))
+	if err != nil {
+		errortext := fmt.Sprintf("Must specify valid end_time parameter: %s", err)
+		log.Println(errortext)
+		http.Error(w, errortext, 400)
+		return
+	}
+
+	// check that start and end times are within the supported time range
+	timeRange := timeUnit * statsHistoryElements
+	if startTime.Add(timeRange).Before(now) {
+		http.Error(w, "start_time parameter is outside of supported range", 501)
+		return
+	}
+	if endTime.Add(timeRange).Before(now) {
+		http.Error(w, "end_time parameter is outside of supported range", 501)
+		return
+	}
+
+	// calculate start and end of our array
+	// basically it's how many hours/minutes/etc have passed since now
+	start := int(now.Sub(endTime) / timeUnit)
+	end := int(now.Sub(startTime) / timeUnit)
+
+	// swap them around if they're inverted
+	if start > end {
+		start, end = end, start
+	}
+
+	data := generateMapFromStats(stats, start, end)
+	json, err := json.Marshal(data)
+	if err != nil {
+		errortext := fmt.Sprintf("Unable to marshal stats json: %s", err)
+		log.Println(errortext)
+		http.Error(w, errortext, 500)
+		return
+	}
+	w.Header().Set("Content-Type", "application/json")
+	_, err = w.Write(json)
+	if err != nil {
+		errortext := fmt.Sprintf("Unable to write response json: %s", err)
+		log.Println(errortext)
+		http.Error(w, errortext, 500)
+		return
+	}
+}
+
+func clamp(value, low, high int) int {
+	if value < low {
+		return low
+	}
+	if value > high {
+		return high
+	}
+	return value
+}
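For reference, a request that passes all of the validation above (hypothetical client code; net/url, net/http, and time imports assumed; the handler is served on the plugin's listen address, 127.0.0.1:8618 in control.go):

now := time.Now()
q := url.Values{}
q.Set("time_unit", "hours")                                      // seconds|minutes|hours|days
q.Set("start_time", now.Add(-24*time.Hour).Format(time.RFC3339)) // within 61 hours of now
q.Set("end_time", now.Format(time.RFC3339))
resp, err := http.Get("http://127.0.0.1:8618/stats_history?" + q.Encode())
// check err, then decode the JSON map from resp.Body as usual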
+
+// --------------------------
+// helper functions for stats
+// --------------------------
+func getReversedSlice(input [statsHistoryElements]float64, start int, end int) []float64 {
+	output := make([]float64, 0)
+	for i := start; i <= end; i++ {
+		output = append([]float64{input[i]}, output...)
+	}
+	return output
+}
diff --git a/coredns_plugin/querylog.go b/coredns_plugin/querylog.go
index 9e45315f..ecf9185b 100644
--- a/coredns_plugin/querylog.go
+++ b/coredns_plugin/querylog.go
@@ -88,7 +88,7 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, ela
 	logBufferLock.Unlock()

 	// add it to running top
-	err = runningTop.addEntry(&entry, now)
+	err = runningTop.addEntry(&entry, question, now)
 	if err != nil {
 		log.Printf("Failed to add entry to running top: %s", err)
 		// don't do failure, just log
@@ -100,7 +100,6 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, ela
 		// do it in separate goroutine -- we are stalling DNS response this whole time
 		go flushToFile(flushBuffer)
 	}
-	return
 }

 func handleQueryLog(w http.ResponseWriter, r *http.Request) {
@@ -114,7 +113,7 @@
 	if needRefresh {
 		// need to get fresh data
 		logBufferLock.RLock()
-		values := logBuffer
+		values = logBuffer
 		logBufferLock.RUnlock()

 		if len(values) < queryLogCacheSize {
@@ -238,9 +237,12 @@ func startQueryLogServer() {

 	go periodicQueryLogRotate()
 	go periodicHourlyTopRotate()
+	go statsRotator()

 	http.HandleFunc("/querylog", handleQueryLog)
+	http.HandleFunc("/stats", handleStats)
 	http.HandleFunc("/stats_top", handleStatsTop)
+	http.HandleFunc("/stats_history", handleStatsHistory)
 	if err := http.ListenAndServe(listenAddr, nil); err != nil {
 		log.Fatalf("error in ListenAndServe: %s", err)
 	}
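The values := logBuffer -> values = logBuffer change above deserves a second look: it fixes a variable-shadowing bug, where := inside the if block declared a new values that vanished at the end of the block. In miniature:

var values []logEntry
if needRefresh {
	values := logBuffer // BUG: ':=' declares a new, inner 'values'
	_ = values
}
// here the outer 'values' is still nil; 'values = logBuffer' assigns to it instead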
diff --git a/coredns_plugin/querylog_file.go b/coredns_plugin/querylog_file.go
index 7025fcd3..72cd4d32 100644
--- a/coredns_plugin/querylog_file.go
+++ b/coredns_plugin/querylog_file.go
@@ -17,6 +17,8 @@ var (
 	fileWriteLock sync.Mutex
 )

+const enableGzip = false
+
 func flushToFile(buffer []logEntry) error {
 	if len(buffer) == 0 {
 		return nil
@@ -42,31 +44,37 @@ func flushToFile(buffer []logEntry) error {
 		return err
 	}

-	filenamegz := queryLogFileName + ".gz"
-
 	var zb bytes.Buffer
+	filename := queryLogFileName

-	zw := gzip.NewWriter(&zb)
-	zw.Name = queryLogFileName
-	zw.ModTime = time.Now()
+	// gzip enabled?
+	if enableGzip {
+		filename += ".gz"

-	_, err = zw.Write(b.Bytes())
-	if err != nil {
-		log.Printf("Couldn't compress to gzip: %s", err)
-		zw.Close()
-		return err
-	}
+		zw := gzip.NewWriter(&zb)
+		zw.Name = queryLogFileName
+		zw.ModTime = time.Now()

-	if err = zw.Close(); err != nil {
-		log.Printf("Couldn't close gzip writer: %s", err)
-		return err
+		_, err = zw.Write(b.Bytes())
+		if err != nil {
+			log.Printf("Couldn't compress to gzip: %s", err)
+			zw.Close()
+			return err
+		}
+
+		if err = zw.Close(); err != nil {
+			log.Printf("Couldn't close gzip writer: %s", err)
+			return err
+		}
+	} else {
+		zb = b
 	}

 	fileWriteLock.Lock()
 	defer fileWriteLock.Unlock()
-	f, err := os.OpenFile(filenamegz, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
+	f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
 	if err != nil {
-		log.Printf("failed to create file \"%s\": %s", filenamegz, err)
+		log.Printf("failed to create file \"%s\": %s", filename, err)
 		return err
 	}
 	defer f.Close()
@@ -77,7 +85,7 @@ func flushToFile(buffer []logEntry) error {
 		return err
 	}

-	log.Printf("ok \"%s\": %v bytes written", filenamegz, n)
+	log.Printf("ok \"%s\": %v bytes written", filename, n)

 	return nil
 }
@@ -111,8 +119,13 @@ func checkBuffer(buffer []logEntry, b bytes.Buffer) error {
 }

 func rotateQueryLog() error {
-	from := queryLogFileName + ".gz"
-	to := queryLogFileName + ".gz.1"
+	from := queryLogFileName
+	to := queryLogFileName + ".1"
+
+	if enableGzip {
+		from = queryLogFileName + ".gz"
+		to = queryLogFileName + ".gz.1"
+	}

 	if _, err := os.Stat(from); os.IsNotExist(err) {
 		// do nothing, file doesn't exist
@@ -143,9 +156,18 @@ func periodicQueryLogRotate() {
 func genericLoader(onEntry func(entry *logEntry) error, needMore func() bool, timeWindow time.Duration) error {
 	now := time.Now()
 	// read from querylog files, try newest file first
-	files := []string{
-		queryLogFileName + ".gz",
-		queryLogFileName + ".gz.1",
+	files := []string{}
+
+	if enableGzip {
+		files = []string{
+			queryLogFileName + ".gz",
+			queryLogFileName + ".gz.1",
+		}
+	} else {
+		files = []string{
+			queryLogFileName,
+			queryLogFileName + ".1",
+		}
 	}

 	// read from all files
@@ -158,7 +180,6 @@ func genericLoader(onEntry func(entry *logEntry) error, needMore func() bool, ti
 			continue
 		}

-		trace("Opening file %s", file)
 		f, err := os.Open(file)
 		if err != nil {
 			log.Printf("Failed to open file \"%s\": %s", file, err)
@@ -167,16 +188,22 @@ func genericLoader(onEntry func(entry *logEntry) error, needMore func() bool, ti
 		}
 		defer f.Close()

-		trace("Creating gzip reader")
-		zr, err := gzip.NewReader(f)
-		if err != nil {
-			log.Printf("Failed to create gzip reader: %s", err)
-			continue
-		}
-		defer zr.Close()
+		var d *json.Decoder

-		trace("Creating json decoder")
-		d := json.NewDecoder(zr)
+		if enableGzip {
+			trace("Creating gzip reader")
+			zr, err := gzip.NewReader(f)
+			if err != nil {
+				log.Printf("Failed to create gzip reader: %s", err)
+				continue
+			}
+			defer zr.Close()
+
+			trace("Creating json decoder")
+			d = json.NewDecoder(zr)
+		} else {
+			d = json.NewDecoder(f)
+		}

 		i := 0
 		// entries on file are in oldest->newest order
@@ -204,7 +231,12 @@ func genericLoader(onEntry func(entry *logEntry) error, needMore func() bool, ti
 				return err
 			}
 		}
-		log.Printf("file \"%s\": read %d entries", file, i)
+		elapsed := time.Since(now)
+		perEntry := time.Duration(0)
+		if i > 0 { // guard against division by zero on an empty file
+			perEntry = elapsed / time.Duration(i)
+		}
+		log.Printf("file \"%s\": read %d entries in %v, %v/entry", file, i, elapsed, perEntry)
 	}
 	return nil
 }
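The enableGzip branches in genericLoader could also be expressed by picking the reader once. A sketch under the same constant; newLogReader is a hypothetical helper, with compress/gzip, io, io/ioutil, and os imports assumed:

// newLogReader returns the stream to decode entries from, transparently
// decompressing when gzip is enabled.
func newLogReader(f *os.File) (io.ReadCloser, error) {
	if enableGzip {
		return gzip.NewReader(f) // *gzip.Reader implements io.ReadCloser
	}
	return ioutil.NopCloser(f), nil
}

// usage: r, err := newLogReader(f); handle err; d := json.NewDecoder(r)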
diff --git a/coredns_plugin/querylog_top.go b/coredns_plugin/querylog_top.go
index 7e92a7f5..6ff531f2 100644
--- a/coredns_plugin/querylog_top.go
+++ b/coredns_plugin/querylog_top.go
@@ -14,6 +14,7 @@ import (
 	"sync"
 	"time"

+	"github.com/AdguardTeam/AdguardDNS/dnsfilter"
 	"github.com/bluele/gcache"
 	"github.com/miekg/dns"
 )
@@ -155,16 +156,7 @@ func (top *hourTop) lockedGetClients(key string) (int, error) {
 	return top.lockedGetValue(key, top.clients)
 }

-func (r *dayTop) addEntry(entry *logEntry, now time.Time) error {
-	if len(entry.Question) == 0 {
-		log.Printf("entry question is absent, skipping")
-		return nil
-	}
-
-	if entry.Time.After(now) {
-		log.Printf("t %v vs %v is in the future, ignoring", entry.Time, now)
-		return nil
-	}
+func (r *dayTop) addEntry(entry *logEntry, q *dns.Msg, now time.Time) error {
 	// figure out which hour bucket it belongs to
 	hour := int(now.Sub(entry.Time).Hours())
 	if hour >= 24 {
@@ -172,17 +164,6 @@ func (r *dayTop) addEntry(entry *logEntry, now time.Time) error {
 		return nil
 	}

-	q := new(dns.Msg)
-	if err := q.Unpack(entry.Question); err != nil {
-		log.Printf("failed to unpack dns message question: %s", err)
-		return err
-	}
-
-	if len(q.Question) != 1 {
-		log.Printf("malformed dns message, has no questions, skipping")
-		return nil
-	}
-
 	hostname := strings.ToLower(strings.TrimSuffix(q.Question[0].Name, "."))

 	// get value, if not set, crate one
@@ -213,7 +194,7 @@ func (r *dayTop) addEntry(entry *logEntry, now time.Time) error {
 	return nil
 }

-func loadTopFromFiles() error {
+func fillStatsFromQueryLog() error {
 	now := time.Now()
 	runningTop.loadedWriteLock()
 	defer runningTop.loadedWriteUnlock()
@@ -221,11 +202,55 @@ func loadTopFromFiles() error {
 		return nil
 	}
 	onEntry := func(entry *logEntry) error {
-		err := runningTop.addEntry(entry, now)
+		if len(entry.Question) == 0 {
+			log.Printf("entry question is absent, skipping")
+			return nil
+		}
+
+		if entry.Time.After(now) {
+			log.Printf("t %v vs %v is in the future, ignoring", entry.Time, now)
+			return nil
+		}
+
+		q := new(dns.Msg)
+		if err := q.Unpack(entry.Question); err != nil {
+			log.Printf("failed to unpack dns message question: %s", err)
+			return err
+		}
+
+		if len(q.Question) != 1 {
+			log.Printf("malformed dns message, has no questions, skipping")
+			return nil
+		}
+
+		err := runningTop.addEntry(entry, q, now)
 		if err != nil {
 			log.Printf("Failed to add entry to running top: %s", err)
 			return err
 		}
+
+		requests.IncWithTime(entry.Time)
+		if entry.Result.IsFiltered {
+			filtered.IncWithTime(entry.Time)
+		}
+		switch entry.Result.Reason {
+		case dnsfilter.NotFilteredWhiteList:
+			whitelisted.IncWithTime(entry.Time)
+		case dnsfilter.NotFilteredError:
+			errorsTotal.IncWithTime(entry.Time)
+		case dnsfilter.FilteredBlackList:
+			filteredLists.IncWithTime(entry.Time)
+		case dnsfilter.FilteredSafeBrowsing:
+			filteredSafebrowsing.IncWithTime(entry.Time)
+		case dnsfilter.FilteredParental:
+			filteredParental.IncWithTime(entry.Time)
+		case dnsfilter.FilteredInvalid:
+			// do nothing
+		case dnsfilter.FilteredSafeSearch:
+			safesearch.IncWithTime(entry.Time)
+		}
+		elapsedTime.ObserveWithTime(entry.Elapsed.Seconds(), entry.Time)
+
 		return nil
 	}
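One possible tidy-up for the switch in fillStatsFromQueryLog (a sketch only, same behavior for the reasons it lists; it assumes the reason constants share one comparable type, here called dnsfilter.Reason): drive the reason-to-counter mapping from a table, which makes it harder to forget a counter when a new reason is added.

var reasonCounters = map[dnsfilter.Reason]*counter{
	dnsfilter.NotFilteredWhiteList: whitelisted,
	dnsfilter.NotFilteredError:     errorsTotal,
	dnsfilter.FilteredBlackList:    filteredLists,
	dnsfilter.FilteredSafeBrowsing: filteredSafebrowsing,
	dnsfilter.FilteredParental:     filteredParental,
	dnsfilter.FilteredSafeSearch:   safesearch,
	// dnsfilter.FilteredInvalid deliberately has no per-time counter
}

if c, ok := reasonCounters[entry.Result.Reason]; ok {
	c.IncWithTime(entry.Time)
}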