From a2434d4574bf5dc3ab892748776086ba565e4356 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Sun, 7 Oct 2018 23:24:04 +0300 Subject: [PATCH] coredns plugin -- Calculate top for domains, clients and blocked both from querylog and running requests. This moves the functionality from frontend to coredns plugin. --- control.go | 87 +-------- coredns_plugin/coredns_plugin.go | 17 +- coredns_plugin/querylog.go | 14 +- coredns_plugin/querylog_file.go | 49 +++-- coredns_plugin/querylog_top.go | 325 +++++++++++++++++++++++++++++++ helpers.go | 85 -------- 6 files changed, 394 insertions(+), 183 deletions(-) create mode 100644 coredns_plugin/querylog_top.go diff --git a/control.go b/control.go index 06adb498..dea887b2 100644 --- a/control.go +++ b/control.go @@ -417,9 +417,9 @@ func handleStatsReset(w http.ResponseWriter, r *http.Request) { } func handleStatsTop(w http.ResponseWriter, r *http.Request) { - resp, err := client.Get("http://127.0.0.1:8618/querylog") + resp, err := client.Get("http://127.0.0.1:8618/stats_top") if err != nil { - errortext := fmt.Sprintf("Couldn't get querylog from coredns: %T %s\n", err, err) + errortext := fmt.Sprintf("Couldn't get stats_top from coredns: %T %s\n", err, err) log.Println(errortext) http.Error(w, errortext, http.StatusBadGateway) return @@ -428,7 +428,7 @@ func handleStatsTop(w http.ResponseWriter, r *http.Request) { defer resp.Body.Close() } - // read body + // read the body entirely body, err := ioutil.ReadAll(resp.Body) if err != nil { errortext := fmt.Sprintf("Couldn't read response body: %s", err) @@ -436,85 +436,12 @@ func handleStatsTop(w http.ResponseWriter, r *http.Request) { http.Error(w, errortext, http.StatusBadGateway) return } - // empty body - if len(body) == 0 { - return - } - - values := []interface{}{} - err = json.Unmarshal(body, &values) - if err != nil { - errortext := fmt.Sprintf("Couldn't parse response body: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusBadGateway) - return - } - - domains := map[string]int{} - blocked := map[string]int{} - clients := map[string]int{} - now := time.Now() - timeWindow := time.Hour * 24 - notBefore := now.Add(timeWindow * -1) - - for _, value := range values { - entry, ok := value.(map[string]interface{}) - if !ok { - // ignore anything else - continue - } - host := getHost(entry) - reason := getReason(entry) - client := getClient(entry) - time := getTime(entry) - if time.Before(notBefore) { - // skip if the entry is before specified cutoff - continue - } - if len(host) > 0 { - domains[host]++ - } - if len(host) > 0 && strings.HasPrefix(reason, "Filtered") { - blocked[host]++ - } - if len(client) > 0 { - clients[client]++ - } - } - - // use manual json marshalling because we want maps to be sorted by value - json := bytes.Buffer{} - json.WriteString("{\n") - - gen := func(json *bytes.Buffer, name string, top map[string]int, addComma bool) { - json.WriteString(" \"") - json.WriteString(name) - json.WriteString("\": {\n") - sorted := sortByValue(top) - for i, key := range sorted { - json.WriteString(" \"") - json.WriteString(key) - json.WriteString("\": ") - json.WriteString(strconv.Itoa(top[key])) - if i+1 != len(sorted) { - json.WriteByte(',') - } - json.WriteByte('\n') - } - json.WriteString(" }") - if addComma { - json.WriteByte(',') - } - json.WriteByte('\n') - } - gen(&json, "top_queried_domains", domains, true) - gen(&json, "top_blocked_domains", blocked, true) - gen(&json, "top_clients", clients, true) - json.WriteString(" \"stats_period\": \"24 hours\"\n") - json.WriteString("}\n") + // forward body entirely with status code w.Header().Set("Content-Type", "application/json") - _, err = w.Write(json.Bytes()) + w.Header().Set("Content-Length", strconv.Itoa(len(body))) + w.WriteHeader(resp.StatusCode) + _, err = w.Write(body) if err != nil { errortext := fmt.Sprintf("Couldn't write body: %s", err) log.Println(errortext) diff --git a/coredns_plugin/coredns_plugin.go b/coredns_plugin/coredns_plugin.go index 2db776ac..474c3f22 100644 --- a/coredns_plugin/coredns_plugin.go +++ b/coredns_plugin/coredns_plugin.go @@ -147,9 +147,6 @@ func setupPlugin(c *caddy.Controller) (*plug, error) { p.settings.BlockedTTL = uint32(blockttl) case "querylog": p.settings.QueryLogEnabled = true - onceQueryLog.Do(func() { - go startQueryLogServer() // TODO: how to handle errors? - }) } } } @@ -186,7 +183,19 @@ func setupPlugin(c *caddy.Controller) (*plug, error) { } } - var err error + log.Printf("Loading top from querylog") + err := loadTopFromFiles() + if err != nil { + log.Printf("Failed to load top from querylog: %s", err) + return nil, err + } + + if p.settings.QueryLogEnabled { + onceQueryLog.Do(func() { + go startQueryLogServer() // TODO: how to handle errors? + }) + } + p.upstream, err = upstream.New(nil) if err != nil { return nil, err diff --git a/coredns_plugin/querylog.go b/coredns_plugin/querylog.go index d064ea08..c0df50e2 100644 --- a/coredns_plugin/querylog.go +++ b/coredns_plugin/querylog.go @@ -57,11 +57,12 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, ela } } + now := time.Now() entry := logEntry{ Question: q, Answer: a, Result: result, - Time: time.Now(), + Time: now, Elapsed: elapsed, IP: ip, } @@ -74,6 +75,15 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, ela logBuffer = nil } logBufferLock.Unlock() + + // add it to running top + err = runningTop.addEntry(&entry, now) + if err != nil { + log.Printf("Failed to add entry to running top: %s", err) + // don't do failure, just log + } + + // if buffer needs to be flushed to disk, do it now if len(flushBuffer) > 0 { // write to file // do it in separate goroutine -- we are stalling DNS response this whole time @@ -204,8 +214,10 @@ func startQueryLogServer() { listenAddr := "127.0.0.1:8618" // 8618 is sha512sum of "querylog" then each byte summed go periodicQueryLogRotate(queryLogRotationPeriod) + go periodicHourlyTopRotate() http.HandleFunc("/querylog", handleQueryLog) + http.HandleFunc("/stats_top", handleStatsTop) if err := http.ListenAndServe(listenAddr, nil); err != nil { log.Fatalf("error in ListenAndServe: %s", err) } diff --git a/coredns_plugin/querylog_file.go b/coredns_plugin/querylog_file.go index 412817ed..fffd050a 100644 --- a/coredns_plugin/querylog_file.go +++ b/coredns_plugin/querylog_file.go @@ -145,7 +145,7 @@ func periodicQueryLogRotate(t time.Duration) { } } -func appendFromLogFile(values []logEntry, maxLen int, timeWindow time.Duration) []logEntry { +func genericLoader(onEntry func(entry *logEntry) error, needMore func() bool, timeWindow time.Duration) error { now := time.Now() // read from querylog files, try newest file first files := []string{ @@ -153,12 +153,9 @@ func appendFromLogFile(values []logEntry, maxLen int, timeWindow time.Duration) queryLogFileName + ".gz.1", } - a := []logEntry{} - // read from all files for _, file := range files { - if len(a) >= maxLen { - // previous file filled us with enough fresh entries + if !needMore() { break } if _, err := os.Stat(file); os.IsNotExist(err) { @@ -181,6 +178,7 @@ func appendFromLogFile(values []logEntry, maxLen int, timeWindow time.Duration) log.Printf("Failed to create gzip reader: %s", err) continue } + defer zr.Close() trace("Creating json decoder") d := json.NewDecoder(zr) @@ -189,6 +187,9 @@ func appendFromLogFile(values []logEntry, maxLen int, timeWindow time.Duration) // entries on file are in oldest->newest order // we want maxLen newest for d.More() { + if !needMore() { + break + } var entry logEntry err := d.Decode(&entry) if err != nil { @@ -203,18 +204,40 @@ func appendFromLogFile(values []logEntry, maxLen int, timeWindow time.Duration) } i++ - a = append(a, entry) - if len(a) > maxLen { - toskip := len(a) - maxLen - a = a[toskip:] + err = onEntry(&entry) + if err != nil { + return err } } - err = zr.Close() - if err != nil { - log.Printf("Encountered error while closing gzip reader: %s", err) - } log.Printf("file \"%s\": read %d entries", file, i) } + return nil +} + +func appendFromLogFile(values []logEntry, maxLen int, timeWindow time.Duration) []logEntry { + a := []logEntry{} + + onEntry := func(entry *logEntry) error { + a = append(a, *entry) + if len(a) > maxLen { + toskip := len(a) - maxLen + a = a[toskip:] + } + return nil + } + + needMore := func() bool { + if len(a) >= maxLen { + return false + } + return true + } + + err := genericLoader(onEntry, needMore, timeWindow) + if err != nil { + log.Printf("Failed to load entries from querylog: %s", err) + return values + } // now that we've read all eligible entries, reverse the slice to make it go from newest->oldest for left, right := 0, len(a)-1; left < right; left, right = left+1, right-1 { diff --git a/coredns_plugin/querylog_top.go b/coredns_plugin/querylog_top.go new file mode 100644 index 00000000..2ba7bcb4 --- /dev/null +++ b/coredns_plugin/querylog_top.go @@ -0,0 +1,325 @@ +package dnsfilter + +import ( + "bytes" + "fmt" + "log" + "net/http" + "sort" + "strconv" + "strings" + "sync" + "time" + + "github.com/bluele/gcache" + "github.com/miekg/dns" +) + +// top domains/clients/blocked stats in the last 24 hours + +// on start we read the saved stats from the last 24 hours and add them to the stats + +// stats are counted using hourly LRU, rotating hourly and keeping last 24 hours + +type hourTop struct { + domains gcache.Cache + blocked gcache.Cache + clients gcache.Cache + sync.RWMutex +} + +func (top *hourTop) init() { + top.domains = gcache.New(500).LRU().Build() + top.blocked = gcache.New(500).LRU().Build() + top.clients = gcache.New(500).LRU().Build() +} + +type dayTop struct { + hours []*hourTop + sync.RWMutex // write -- rotating hourTop, read -- anything else +} + +var runningTop dayTop + +func init() { + runningTop.Lock() + for i := 0; i < 24; i++ { + hour := hourTop{} + hour.init() + runningTop.hours = append(runningTop.hours, &hour) + } + runningTop.Unlock() +} + +func rotateHourlyTop() { + log.Printf("Rotating hourly top") + hour := &hourTop{} + hour.init() + runningTop.Lock() + runningTop.hours = append([]*hourTop{hour}, runningTop.hours...) + runningTop.hours = runningTop.hours[:24] + runningTop.Unlock() +} + +func periodicHourlyTopRotate() { + t := time.Hour + for range time.Tick(t) { + rotateHourlyTop() + } +} + +func (top *hourTop) incrementValue(key string, cache gcache.Cache) error { + top.Lock() + defer top.Unlock() + ivalue, err := cache.Get(key) + if err == gcache.KeyNotFoundError { + // we just set it and we're done + err = cache.Set(key, 1) + if err != nil { + log.Printf("Failed to set hourly top value: %s", err) + return err + } + return nil + } + + if err != nil { + log.Printf("gcache encountered an error during get: %s", err) + return err + } + + cachedValue, ok := ivalue.(int) + if !ok { + err = fmt.Errorf("SHOULD NOT HAPPEN: gcache has non-int as value: %v", ivalue) + log.Println(err) + return err + } + + err = cache.Set(key, cachedValue+1) + if err != nil { + log.Printf("Failed to set hourly top value: %s", err) + return err + } + return nil +} + +func (top *hourTop) incrementDomains(key string) error { + return top.incrementValue(key, top.domains) +} + +func (top *hourTop) incrementBlocked(key string) error { + return top.incrementValue(key, top.blocked) +} + +func (top *hourTop) incrementClients(key string) error { + return top.incrementValue(key, top.clients) +} + +// if does not exist -- return 0 +func (top *hourTop) lockedGetValue(key string, cache gcache.Cache) (int, error) { + ivalue, err := cache.Get(key) + if err == gcache.KeyNotFoundError { + return 0, nil + } + + if err != nil { + log.Printf("gcache encountered an error during get: %s", err) + return 0, err + } + + value, ok := ivalue.(int) + if !ok { + err := fmt.Errorf("SHOULD NOT HAPPEN: gcache has non-int as value: %v", ivalue) + log.Println(err) + return 0, err + } + + return value, nil +} + +func (top *hourTop) lockedGetDomains(key string) (int, error) { + return top.lockedGetValue(key, top.domains) +} + +func (top *hourTop) lockedGetBlocked(key string) (int, error) { + return top.lockedGetValue(key, top.blocked) +} + +func (top *hourTop) lockedGetClients(key string) (int, error) { + return top.lockedGetValue(key, top.clients) +} + +func (r *dayTop) addEntry(entry *logEntry, now time.Time) error { + if len(entry.Question) == 0 { + log.Printf("entry question is absent, skipping") + return nil + } + + if entry.Time.After(now) { + log.Printf("t %v vs %v is in the future, ignoring", entry.Time, now) + return nil + } + // figure out which hour bucket it belongs to + hour := int(now.Sub(entry.Time).Hours()) + if hour >= 24 { + log.Printf("t %v is >24 hours ago, ignoring", entry.Time) + return nil + } + + q := new(dns.Msg) + if err := q.Unpack(entry.Question); err != nil { + log.Printf("failed to unpack dns message question: %s", err) + return err + } + + if len(q.Question) != 1 { + log.Printf("malformed dns message, has no questions, skipping") + return nil + } + + hostname := strings.ToLower(strings.TrimSuffix(q.Question[0].Name, ".")) + + // get value, if not set, crate one + runningTop.RLock() + defer runningTop.RUnlock() + err := runningTop.hours[hour].incrementDomains(hostname) + if err != nil { + log.Printf("Failed to increment value: %s", err) + return err + } + + if entry.Result.IsFiltered { + err := runningTop.hours[hour].incrementBlocked(hostname) + if err != nil { + log.Printf("Failed to increment value: %s", err) + return err + } + } + + if len(entry.IP) > 0 { + err := runningTop.hours[hour].incrementClients(entry.IP) + if err != nil { + log.Printf("Failed to increment value: %s", err) + return err + } + } + + return nil +} + +func loadTopFromFiles() error { + now := time.Now() + runningTop.RLock() + defer runningTop.RUnlock() + onEntry := func(entry *logEntry) error { + err := runningTop.addEntry(entry, now) + if err != nil { + log.Printf("Failed to add entry to running top: %s", err) + return err + } + return nil + } + + needMore := func() bool { return true } + err := genericLoader(onEntry, needMore, time.Hour*24) + if err != nil { + log.Printf("Failed to load entries from querylog: %s", err) + return err + } + + return nil +} + +func handleStatsTop(w http.ResponseWriter, r *http.Request) { + domains := map[string]int{} + blocked := map[string]int{} + clients := map[string]int{} + + do := func(keys []interface{}, getter func(key string) (int, error), result map[string]int) { + for _, ikey := range keys { + key, ok := ikey.(string) + if !ok { + continue + } + value, err := getter(key) + if err != nil { + log.Printf("Failed to get top domains value for %v: %s", key, err) + return + } + result[key] += value + } + } + + runningTop.RLock() + for hour := 0; hour < 24; hour++ { + runningTop.hours[hour].RLock() + do(runningTop.hours[hour].domains.Keys(), runningTop.hours[hour].lockedGetDomains, domains) + do(runningTop.hours[hour].blocked.Keys(), runningTop.hours[hour].lockedGetBlocked, blocked) + do(runningTop.hours[hour].clients.Keys(), runningTop.hours[hour].lockedGetClients, clients) + runningTop.hours[hour].RUnlock() + } + runningTop.RUnlock() + + // use manual json marshalling because we want maps to be sorted by value + json := bytes.Buffer{} + json.WriteString("{\n") + + gen := func(json *bytes.Buffer, name string, top map[string]int, addComma bool) { + json.WriteString(" \"") + json.WriteString(name) + json.WriteString("\": {\n") + sorted := sortByValue(top) + for i, key := range sorted { + // no more than 50 entries + if i >= 50 { + break + } + json.WriteString(" \"") + json.WriteString(key) + json.WriteString("\": ") + json.WriteString(strconv.Itoa(top[key])) + if i+1 != len(sorted) { + json.WriteByte(',') + } + json.WriteByte('\n') + } + json.WriteString(" }") + if addComma { + json.WriteByte(',') + } + json.WriteByte('\n') + } + gen(&json, "top_queried_domains", domains, true) + gen(&json, "top_blocked_domains", blocked, true) + gen(&json, "top_clients", clients, true) + json.WriteString(" \"stats_period\": \"24 hours\"\n") + json.WriteString("}\n") + + w.Header().Set("Content-Type", "application/json") + _, err := w.Write(json.Bytes()) + if err != nil { + errortext := fmt.Sprintf("Couldn't write body: %s", err) + log.Println(errortext) + http.Error(w, errortext, http.StatusInternalServerError) + } +} + +// helper function for querylog API +func sortByValue(m map[string]int) []string { + type kv struct { + k string + v int + } + var ss []kv + for k, v := range m { + ss = append(ss, kv{k, v}) + } + sort.Slice(ss, func(l, r int) bool { + return ss[l].v > ss[r].v + }) + + sorted := []string{} + for _, v := range ss { + sorted = append(sorted, v.k) + } + return sorted +} diff --git a/helpers.go b/helpers.go index a35a5f1c..1bbca87c 100644 --- a/helpers.go +++ b/helpers.go @@ -9,9 +9,7 @@ import ( "os" "path" "runtime" - "sort" "strings" - "time" ) func clamp(value, low, high int) int { @@ -133,89 +131,6 @@ func generateMapFromStats(stats *periodicStats, start int, end int) map[string]i return result } -// ------------------------------------- -// helper functions for querylog parsing -// ------------------------------------- -func sortByValue(m map[string]int) []string { - type kv struct { - k string - v int - } - var ss []kv - for k, v := range m { - ss = append(ss, kv{k, v}) - } - sort.Slice(ss, func(l, r int) bool { - return ss[l].v > ss[r].v - }) - - sorted := []string{} - for _, v := range ss { - sorted = append(sorted, v.k) - } - return sorted -} - -func getHost(entry map[string]interface{}) string { - q, ok := entry["question"] - if !ok { - return "" - } - question, ok := q.(map[string]interface{}) - if !ok { - return "" - } - h, ok := question["host"] - if !ok { - return "" - } - host, ok := h.(string) - if !ok { - return "" - } - return host -} - -func getReason(entry map[string]interface{}) string { - r, ok := entry["reason"] - if !ok { - return "" - } - reason, ok := r.(string) - if !ok { - return "" - } - return reason -} - -func getClient(entry map[string]interface{}) string { - c, ok := entry["client"] - if !ok { - return "" - } - client, ok := c.(string) - if !ok { - return "" - } - return client -} - -func getTime(entry map[string]interface{}) time.Time { - t, ok := entry["time"] - if !ok { - return time.Time{} - } - tstr, ok := t.(string) - if !ok { - return time.Time{} - } - value, err := time.Parse(time.RFC3339, tstr) - if err != nil { - return time.Time{} - } - return value -} - // ------------------------------------------------- // helper functions for parsing parameters from body // -------------------------------------------------