From 076c9de68e73bbddc63cf6f7212818c91f3e5c08 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Fri, 14 Sep 2018 16:50:56 +0300 Subject: [PATCH] Fix many lint warnings found by gometalinter --- config.go | 6 +- control.go | 143 +++++++++++++++++++--- coredns_plugin/coredns_plugin.go | 169 ++++++++++++++------------ coredns_plugin/coredns_plugin_test.go | 12 +- coredns_plugin/ratelimit/ratelimit.go | 38 +++--- coredns_plugin/refuseany/refuseany.go | 37 +++--- dnsfilter/dnsfilter.go | 93 ++++++++------ dnsfilter/dnsfilter_test.go | 41 +++++-- dnsfilter/helpers.go | 29 +---- helpers.go | 26 ++-- stats.go | 6 +- 11 files changed, 359 insertions(+), 241 deletions(-) diff --git a/config.go b/config.go index 3be34373..20142b42 100644 --- a/config.go +++ b/config.go @@ -29,20 +29,20 @@ type configuration struct { } type coreDNSConfig struct { - Port int `yaml:"port"` binaryFile string coreFile string FilterFile string `yaml:"-"` + Port int `yaml:"port"` FilteringEnabled bool `yaml:"filtering_enabled"` SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"` SafeSearchEnabled bool `yaml:"safesearch_enabled"` + QueryLogEnabled bool `yaml:"querylog_enabled"` ParentalEnabled bool `yaml:"parental_enabled"` ParentalSensitivity int `yaml:"parental_sensitivity"` - QueryLogEnabled bool `yaml:"querylog_enabled"` Pprof string `yaml:"pprof"` - UpstreamDNS []string `yaml:"upstream_dns"` Cache string `yaml:"cache"` Prometheus string `yaml:"prometheus"` + UpstreamDNS []string `yaml:"upstream_dns"` } type filter struct { diff --git a/control.go b/control.go index ffdb2fd3..070a52d2 100644 --- a/control.go +++ b/control.go @@ -125,7 +125,11 @@ func handleStart(w http.ResponseWriter, r *http.Request) { return } - fmt.Fprintf(w, "OK, PID %d\n", coreDNSCommand.Process.Pid) + _, err = fmt.Fprintf(w, "OK, PID %d\n", coreDNSCommand.Process.Pid) + if err != nil { + log.Printf("Couldn't write body in %s(): %s", _Func(), err) + return + } } func childwaiter() { @@ -162,8 +166,11 @@ func handleStop(w http.ResponseWriter, r *http.Request) { http.Error(w, errortext, 500) return } - // this err is ignorable, it shows exit status of coredns - fmt.Fprintf(w, "OK\n%s\n", exitstatus) + _, err = fmt.Fprintf(w, "OK\n%s\n", exitstatus) + if err != nil { + log.Printf("Couldn't write body in %s(): %s", _Func(), err) + return + } } func handleRestart(w http.ResponseWriter, r *http.Request) { @@ -359,7 +366,12 @@ func handleQueryLogEnable(w http.ResponseWriter, r *http.Request) { http.Error(w, errortext, http.StatusInternalServerError) return } - fmt.Fprintf(w, "OK\n") + _, err = fmt.Fprintf(w, "OK\n") + if err != nil { + errortext := fmt.Sprintf("Couldn't write body: %s", err) + log.Println(errortext) + http.Error(w, errortext, http.StatusInternalServerError) + } } func handleQueryLogDisable(w http.ResponseWriter, r *http.Request) { @@ -371,7 +383,13 @@ func handleQueryLogDisable(w http.ResponseWriter, r *http.Request) { http.Error(w, errortext, http.StatusInternalServerError) return } - fmt.Fprintf(w, "OK\n") + _, err = fmt.Fprintf(w, "OK\n") + if err != nil { + errortext := fmt.Sprintf("Couldn't write body: %s", err) + log.Println(errortext) + http.Error(w, errortext, http.StatusInternalServerError) + } + } func handleStatsTop(w http.ResponseWriter, r *http.Request) { @@ -450,7 +468,10 @@ func handleStatsTop(w http.ResponseWriter, r *http.Request) { json.WriteString("\": {\n") sorted := sortByValue(top) for i, key := range sorted { - fmt.Fprintf(json, " \"%s\": %d", key, top[key]) + json.WriteString(" \"") + json.WriteString(key) + json.WriteString("\": ") + json.WriteString(strconv.Itoa(top[key])) if i+1 != len(sorted) { json.WriteByte(',') } @@ -500,7 +521,12 @@ func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) { return } tellCoreDNSToReload() - fmt.Fprintf(w, "OK %d servers\n", len(hosts)) + _, err = fmt.Fprintf(w, "OK %d servers\n", len(hosts)) + if err != nil { + errortext := fmt.Sprintf("Couldn't write body: %s", err) + log.Println(errortext) + http.Error(w, errortext, http.StatusInternalServerError) + } } func parseIPsOptionalPort(input string) []string { @@ -533,7 +559,13 @@ func handleFilteringEnable(w http.ResponseWriter, r *http.Request) { http.Error(w, errortext, http.StatusInternalServerError) return } - fmt.Fprintf(w, "OK\n") + _, err = fmt.Fprintf(w, "OK\n") + if err != nil { + errortext := fmt.Sprintf("Couldn't write body: %s", err) + log.Println(errortext) + http.Error(w, errortext, http.StatusInternalServerError) + } + } func handleFilteringDisable(w http.ResponseWriter, r *http.Request) { @@ -545,7 +577,13 @@ func handleFilteringDisable(w http.ResponseWriter, r *http.Request) { http.Error(w, errortext, http.StatusInternalServerError) return } - fmt.Fprintf(w, "OK\n") + _, err = fmt.Fprintf(w, "OK\n") + if err != nil { + errortext := fmt.Sprintf("Couldn't write body: %s", err) + log.Println(errortext) + http.Error(w, errortext, http.StatusInternalServerError) + } + } func handleFilteringStatus(w http.ResponseWriter, r *http.Request) { @@ -647,7 +685,12 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) { return } tellCoreDNSToReload() - fmt.Fprintf(w, "OK %d rules\n", filter.RulesCount) + _, err = fmt.Fprintf(w, "OK %d rules\n", filter.RulesCount) + if err != nil { + errortext := fmt.Sprintf("Couldn't write body: %s", err) + log.Println(errortext) + http.Error(w, errortext, http.StatusInternalServerError) + } } func handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) { @@ -693,7 +736,13 @@ func handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) { return } tellCoreDNSToReload() - fmt.Fprintf(w, "OK\n") + _, err = fmt.Fprintf(w, "OK\n") + if err != nil { + errortext := fmt.Sprintf("Couldn't write body: %s", err) + log.Println(errortext) + http.Error(w, errortext, http.StatusInternalServerError) + } + } func handleFilteringEnableURL(w http.ResponseWriter, r *http.Request) { @@ -748,7 +797,13 @@ func handleFilteringEnableURL(w http.ResponseWriter, r *http.Request) { return } tellCoreDNSToReload() - fmt.Fprintf(w, "OK\n") + _, err = fmt.Fprintf(w, "OK\n") + if err != nil { + errortext := fmt.Sprintf("Couldn't write body: %s", err) + log.Println(errortext) + http.Error(w, errortext, http.StatusInternalServerError) + } + } func handleFilteringDisableURL(w http.ResponseWriter, r *http.Request) { @@ -800,7 +855,13 @@ func handleFilteringDisableURL(w http.ResponseWriter, r *http.Request) { return } tellCoreDNSToReload() - fmt.Fprintf(w, "OK\n") + _, err = fmt.Fprintf(w, "OK\n") + if err != nil { + errortext := fmt.Sprintf("Couldn't write body: %s", err) + log.Println(errortext) + http.Error(w, errortext, http.StatusInternalServerError) + } + // TODO: regenerate coredns config and tell coredns to reload it if it's running } @@ -829,7 +890,13 @@ func handleFilteringSetRules(w http.ResponseWriter, r *http.Request) { return } tellCoreDNSToReload() - fmt.Fprintf(w, "OK\n") + _, err = fmt.Fprintf(w, "OK\n") + if err != nil { + errortext := fmt.Sprintf("Couldn't write body: %s", err) + log.Println(errortext) + http.Error(w, errortext, http.StatusInternalServerError) + } + } func handleFilteringRefresh(w http.ResponseWriter, r *http.Request) { @@ -1015,7 +1082,13 @@ func handleSafeBrowsingEnable(w http.ResponseWriter, r *http.Request) { http.Error(w, errortext, http.StatusInternalServerError) return } - fmt.Fprintf(w, "OK\n") + _, err = fmt.Fprintf(w, "OK\n") + if err != nil { + errortext := fmt.Sprintf("Couldn't write body: %s", err) + log.Println(errortext) + http.Error(w, errortext, http.StatusInternalServerError) + } + } func handleSafeBrowsingDisable(w http.ResponseWriter, r *http.Request) { @@ -1027,7 +1100,13 @@ func handleSafeBrowsingDisable(w http.ResponseWriter, r *http.Request) { http.Error(w, errortext, http.StatusInternalServerError) return } - fmt.Fprintf(w, "OK\n") + _, err = fmt.Fprintf(w, "OK\n") + if err != nil { + errortext := fmt.Sprintf("Couldn't write body: %s", err) + log.Println(errortext) + http.Error(w, errortext, http.StatusInternalServerError) + } + } func handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Request) { @@ -1104,7 +1183,13 @@ func handleParentalEnable(w http.ResponseWriter, r *http.Request) { http.Error(w, errortext, http.StatusInternalServerError) return } - fmt.Fprintf(w, "OK\n") + _, err = fmt.Fprintf(w, "OK\n") + if err != nil { + errortext := fmt.Sprintf("Couldn't write body: %s", err) + log.Println(errortext) + http.Error(w, errortext, http.StatusInternalServerError) + } + } func handleParentalDisable(w http.ResponseWriter, r *http.Request) { @@ -1116,7 +1201,13 @@ func handleParentalDisable(w http.ResponseWriter, r *http.Request) { http.Error(w, errortext, http.StatusInternalServerError) return } - fmt.Fprintf(w, "OK\n") + _, err = fmt.Fprintf(w, "OK\n") + if err != nil { + errortext := fmt.Sprintf("Couldn't write body: %s", err) + log.Println(errortext) + http.Error(w, errortext, http.StatusInternalServerError) + } + } func handleParentalStatus(w http.ResponseWriter, r *http.Request) { @@ -1157,7 +1248,13 @@ func handleSafeSearchEnable(w http.ResponseWriter, r *http.Request) { http.Error(w, errortext, http.StatusInternalServerError) return } - fmt.Fprintf(w, "OK\n") + _, err = fmt.Fprintf(w, "OK\n") + if err != nil { + errortext := fmt.Sprintf("Couldn't write body: %s", err) + log.Println(errortext) + http.Error(w, errortext, http.StatusInternalServerError) + } + } func handleSafeSearchDisable(w http.ResponseWriter, r *http.Request) { @@ -1169,7 +1266,13 @@ func handleSafeSearchDisable(w http.ResponseWriter, r *http.Request) { http.Error(w, errortext, http.StatusInternalServerError) return } - fmt.Fprintf(w, "OK\n") + _, err = fmt.Fprintf(w, "OK\n") + if err != nil { + errortext := fmt.Sprintf("Couldn't write body: %s", err) + log.Println(errortext) + http.Error(w, errortext, http.StatusInternalServerError) + } + } func handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) { diff --git a/coredns_plugin/coredns_plugin.go b/coredns_plugin/coredns_plugin.go index a61e1134..78c9bb9a 100644 --- a/coredns_plugin/coredns_plugin.go +++ b/coredns_plugin/coredns_plugin.go @@ -45,7 +45,7 @@ func init() { }) } -type Plugin struct { +type plug struct { d *dnsfilter.Dnsfilter Next plugin.Handler upstream upstream.Upstream @@ -56,12 +56,12 @@ type Plugin struct { QueryLogEnabled bool } -var defaultPlugin = Plugin{ +var defaultPlugin = plug{ SafeBrowsingBlockHost: "safebrowsing.block.dns.adguard.com", ParentalBlockHost: "family.block.dns.adguard.com", } -func newDnsCounter(name string, help string) prometheus.Counter { +func newDNSCounter(name string, help string) prometheus.Counter { return prometheus.NewCounter(prometheus.CounterOpts{ Namespace: plugin.Namespace, Subsystem: "dnsfilter", @@ -71,26 +71,26 @@ func newDnsCounter(name string, help string) prometheus.Counter { } var ( - requests = newDnsCounter("requests_total", "Count of requests seen by dnsfilter.") - filtered = newDnsCounter("filtered_total", "Count of requests filtered by dnsfilter.") - filteredLists = newDnsCounter("filtered_lists_total", "Count of requests filtered by dnsfilter using lists.") - filteredSafebrowsing = newDnsCounter("filtered_safebrowsing_total", "Count of requests filtered by dnsfilter using safebrowsing.") - filteredParental = newDnsCounter("filtered_parental_total", "Count of requests filtered by dnsfilter using parental.") - filteredInvalid = newDnsCounter("filtered_invalid_total", "Count of requests filtered by dnsfilter because they were invalid.") - whitelisted = newDnsCounter("whitelisted_total", "Count of requests not filtered by dnsfilter because they are whitelisted.") - safesearch = newDnsCounter("safesearch_total", "Count of requests replaced by dnsfilter safesearch.") - errorsTotal = newDnsCounter("errors_total", "Count of requests that dnsfilter couldn't process because of transitive errors.") + requests = newDNSCounter("requests_total", "Count of requests seen by dnsfilter.") + filtered = newDNSCounter("filtered_total", "Count of requests filtered by dnsfilter.") + filteredLists = newDNSCounter("filtered_lists_total", "Count of requests filtered by dnsfilter using lists.") + filteredSafebrowsing = newDNSCounter("filtered_safebrowsing_total", "Count of requests filtered by dnsfilter using safebrowsing.") + filteredParental = newDNSCounter("filtered_parental_total", "Count of requests filtered by dnsfilter using parental.") + filteredInvalid = newDNSCounter("filtered_invalid_total", "Count of requests filtered by dnsfilter because they were invalid.") + whitelisted = newDNSCounter("whitelisted_total", "Count of requests not filtered by dnsfilter because they are whitelisted.") + safesearch = newDNSCounter("safesearch_total", "Count of requests replaced by dnsfilter safesearch.") + errorsTotal = newDNSCounter("errors_total", "Count of requests that dnsfilter couldn't process because of transitive errors.") ) // // coredns handling functions // -func setupPlugin(c *caddy.Controller) (*Plugin, error) { +func setupPlugin(c *caddy.Controller) (*plug, error) { // create new Plugin and copy default values - var d = new(Plugin) - *d = defaultPlugin - d.d = dnsfilter.New() - d.hosts = make(map[string]net.IP) + var p = new(plug) + *p = defaultPlugin + p.d = dnsfilter.New() + p.hosts = make(map[string]net.IP) var filterFileName string for c.Next() { @@ -103,15 +103,15 @@ func setupPlugin(c *caddy.Controller) (*Plugin, error) { for c.NextBlock() { switch c.Val() { case "safebrowsing": - d.d.EnableSafeBrowsing() + p.d.EnableSafeBrowsing() if c.NextArg() { if len(c.Val()) == 0 { return nil, c.ArgErr() } - d.d.SetSafeBrowsingServer(c.Val()) + p.d.SetSafeBrowsingServer(c.Val()) } case "safesearch": - d.d.EnableSafeSearch() + p.d.EnableSafeSearch() case "parental": if !c.NextArg() { return nil, c.ArgErr() @@ -120,7 +120,7 @@ func setupPlugin(c *caddy.Controller) (*Plugin, error) { if err != nil { return nil, c.ArgErr() } - err = d.d.EnableParental(sensitivity) + err = p.d.EnableParental(sensitivity) if err != nil { return nil, c.ArgErr() } @@ -128,10 +128,10 @@ func setupPlugin(c *caddy.Controller) (*Plugin, error) { if len(c.Val()) == 0 { return nil, c.ArgErr() } - d.ParentalBlockHost = c.Val() + p.ParentalBlockHost = c.Val() } case "querylog": - d.QueryLogEnabled = true + p.QueryLogEnabled = true onceQueryLog.Do(func() { go startQueryLogServer() // TODO: how to handle errors? }) @@ -149,10 +149,10 @@ func setupPlugin(c *caddy.Controller) (*Plugin, error) { scanner := bufio.NewScanner(file) for scanner.Scan() { text := scanner.Text() - if d.parseEtcHosts(text) { + if p.parseEtcHosts(text) { continue } - err = d.d.AddRule(text, 0) + err = p.d.AddRule(text, 0) if err == dnsfilter.ErrInvalidSyntax { continue } @@ -167,23 +167,23 @@ func setupPlugin(c *caddy.Controller) (*Plugin, error) { return nil, err } - d.upstream, err = upstream.New(nil) + p.upstream, err = upstream.New(nil) if err != nil { return nil, err } - return d, nil + return p, nil } func setup(c *caddy.Controller) error { - d, err := setupPlugin(c) + p, err := setupPlugin(c) if err != nil { return err } config := dnsserver.GetConfig(c) config.AddPlugin(func(next plugin.Handler) plugin.Handler { - d.Next = next - return d + p.Next = next + return p }) c.OnStartup(func() error { @@ -200,16 +200,16 @@ func setup(c *caddy.Controller) error { x.MustRegister(whitelisted) x.MustRegister(safesearch) x.MustRegister(errorsTotal) - x.MustRegister(d) + x.MustRegister(p) } return nil }) - c.OnShutdown(d.OnShutdown) + c.OnShutdown(p.onShutdown) return nil } -func (d *Plugin) parseEtcHosts(text string) bool { +func (p *plug) parseEtcHosts(text string) bool { if pos := strings.IndexByte(text, '#'); pos != -1 { text = text[0:pos] } @@ -222,17 +222,17 @@ func (d *Plugin) parseEtcHosts(text string) bool { return false } for _, host := range fields[1:] { - if val, ok := d.hosts[host]; ok { + if val, ok := p.hosts[host]; ok { log.Printf("warning: host %s already has value %s, will overwrite it with %s", host, val, addr) } - d.hosts[host] = addr + p.hosts[host] = addr } return true } -func (d *Plugin) OnShutdown() error { - d.d.Destroy() - d.d = nil +func (p *plug) onShutdown() error { + p.d.Destroy() + p.d = nil return nil } @@ -240,7 +240,7 @@ type statsFunc func(ch interface{}, name string, text string, value float64, val func doDesc(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType) { realch, ok := ch.(chan<- *prometheus.Desc) - if ok == false { + if !ok { log.Printf("Couldn't convert ch to chan<- *prometheus.Desc\n") return } @@ -249,7 +249,7 @@ func doDesc(ch interface{}, name string, text string, value float64, valueType p func doMetric(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType) { realch, ok := ch.(chan<- prometheus.Metric) - if ok == false { + if !ok { log.Printf("Couldn't convert ch to chan<- prometheus.Metric\n") return } @@ -268,21 +268,23 @@ func doStatsLookup(ch interface{}, doFunc statsFunc, name string, lookupstats *d gen(ch, doFunc, fmt.Sprintf("coredns_dnsfilter_%s_pending_max", name), fmt.Sprintf("Maximum number of pending %s HTTP requests", name), float64(lookupstats.PendingMax), prometheus.GaugeValue) } -func (d *Plugin) doStats(ch interface{}, doFunc statsFunc) { - stats := d.d.GetStats() +func (p *plug) doStats(ch interface{}, doFunc statsFunc) { + stats := p.d.GetStats() doStatsLookup(ch, doFunc, "safebrowsing", &stats.Safebrowsing) doStatsLookup(ch, doFunc, "parental", &stats.Parental) } -func (d *Plugin) Describe(ch chan<- *prometheus.Desc) { - d.doStats(ch, doDesc) +// Describe is called by prometheus handler to know stat types +func (p *plug) Describe(ch chan<- *prometheus.Desc) { + p.doStats(ch, doDesc) } -func (d *Plugin) Collect(ch chan<- prometheus.Metric) { - d.doStats(ch, doMetric) +// Collect is called by prometheus handler to collect stats +func (p *plug) Collect(ch chan<- prometheus.Metric) { + p.doStats(ch, doMetric) } -func (d *Plugin) replaceHostWithValAndReply(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, host string, val string, question dns.Question) (int, error) { +func (p *plug) replaceHostWithValAndReply(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, host string, val string, question dns.Question) (int, error) { // check if it's a domain name or IP address addr := net.ParseIP(val) var records []dns.RR @@ -301,7 +303,7 @@ func (d *Plugin) replaceHostWithValAndReply(ctx context.Context, w dns.ResponseW req.SetQuestion(dns.Fqdn(val), question.Qtype) req.RecursionDesired = true reqstate := request.Request{W: w, Req: req, Context: ctx} - result, err := d.upstream.Lookup(reqstate, dns.Fqdn(val), reqstate.QType()) + result, err := p.upstream.Lookup(reqstate, dns.Fqdn(val), reqstate.QType()) if err != nil { log.Printf("Got error %s\n", err) return dns.RcodeServerFailure, fmt.Errorf("plugin/dnsfilter: %s", err) @@ -363,80 +365,80 @@ func writeNXdomain(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, return dns.RcodeNameError, nil } -func (d *Plugin) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error, dnsfilter.Result) { +func (p *plug) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, dnsfilter.Result, error) { if len(r.Question) != 1 { // google DNS, bind and others do the same - return dns.RcodeFormatError, fmt.Errorf("Got DNS request with != 1 questions"), dnsfilter.Result{} + return dns.RcodeFormatError, dnsfilter.Result{}, fmt.Errorf("Got DNS request with != 1 questions") } for _, question := range r.Question { host := strings.ToLower(strings.TrimSuffix(question.Name, ".")) // is it a safesearch domain? - if val, ok := d.d.SafeSearchDomain(host); ok { - rcode, err := d.replaceHostWithValAndReply(ctx, w, r, host, val, question) + if val, ok := p.d.SafeSearchDomain(host); ok { + rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question) if err != nil { - return rcode, err, dnsfilter.Result{} + return rcode, dnsfilter.Result{}, err } - return rcode, err, dnsfilter.Result{Reason: dnsfilter.FilteredSafeSearch} + return rcode, dnsfilter.Result{Reason: dnsfilter.FilteredSafeSearch}, err } // is it in hosts? - if val, ok := d.hosts[host]; ok { + if val, ok := p.hosts[host]; ok { // it is, if it's a loopback host, reply with NXDOMAIN if val.IsLoopback() { rcode, err := writeNXdomain(ctx, w, r) if err != nil { - return rcode, err, dnsfilter.Result{} + return rcode, dnsfilter.Result{}, err } - return rcode, err, dnsfilter.Result{Reason: dnsfilter.FilteredInvalid} + return rcode, dnsfilter.Result{Reason: dnsfilter.FilteredInvalid}, err } // it's not a loopback host, replace it with value specified - rcode, err := d.replaceHostWithValAndReply(ctx, w, r, host, val.String(), question) + rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val.String(), question) if err != nil { - return rcode, err, dnsfilter.Result{} + return rcode, dnsfilter.Result{}, err } - return rcode, err, dnsfilter.Result{Reason: dnsfilter.FilteredSafeSearch} + return rcode, dnsfilter.Result{Reason: dnsfilter.FilteredSafeSearch}, err } // needs to be filtered instead - result, err := d.d.CheckHost(host) + result, err := p.d.CheckHost(host) if err != nil { log.Printf("plugin/dnsfilter: %s\n", err) - return dns.RcodeServerFailure, fmt.Errorf("plugin/dnsfilter: %s", err), dnsfilter.Result{} + return dns.RcodeServerFailure, dnsfilter.Result{}, fmt.Errorf("plugin/dnsfilter: %s", err) } if result.IsFiltered { switch result.Reason { case dnsfilter.FilteredSafeBrowsing: // return cname safebrowsing.block.dns.adguard.com - val := d.SafeBrowsingBlockHost - rcode, err := d.replaceHostWithValAndReply(ctx, w, r, host, val, question) + val := p.SafeBrowsingBlockHost + rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question) if err != nil { - return rcode, err, dnsfilter.Result{} + return rcode, dnsfilter.Result{}, err } - return rcode, err, result + return rcode, result, err case dnsfilter.FilteredParental: // return cname family.block.dns.adguard.com - val := d.ParentalBlockHost - rcode, err := d.replaceHostWithValAndReply(ctx, w, r, host, val, question) + val := p.ParentalBlockHost + rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question) if err != nil { - return rcode, err, dnsfilter.Result{} + return rcode, dnsfilter.Result{}, err } - return rcode, err, result + return rcode, result, err case dnsfilter.FilteredBlackList: // return NXdomain rcode, err := writeNXdomain(ctx, w, r) if err != nil { - return rcode, err, dnsfilter.Result{} + return rcode, dnsfilter.Result{}, err } - return rcode, err, result + return rcode, result, err default: log.Printf("SHOULD NOT HAPPEN -- got unknown reason for filtering: %T %v %s", result.Reason, result.Reason, result.Reason.String()) } } else { switch result.Reason { case dnsfilter.NotFilteredWhiteList: - rcode, err := plugin.NextOrFailure(d.Name(), d.Next, ctx, w, r) - return rcode, err, result + rcode, err := plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r) + return rcode, result, err case dnsfilter.NotFilteredNotFound: // do nothing, pass through to lower code default: @@ -444,11 +446,12 @@ func (d *Plugin) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r * } } } - rcode, err := plugin.NextOrFailure(d.Name(), d.Next, ctx, w, r) - return rcode, err, dnsfilter.Result{} + rcode, err := plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r) + return rcode, dnsfilter.Result{}, err } -func (d *Plugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { +// ServeDNS handles the DNS request and refuses if it's in filterlists +func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { start := time.Now() requests.Inc() state := request.Request{W: w, Req: r} @@ -456,13 +459,16 @@ func (d *Plugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) // capture the written answer rrw := dnstest.NewRecorder(w) - rcode, err, result := d.serveDNSInternal(ctx, rrw, r) + rcode, result, err := p.serveDNSInternal(ctx, rrw, r) if rcode > 0 { // actually send the answer if we have one answer := new(dns.Msg) answer.SetRcode(r, rcode) state.SizeAndDo(answer) - w.WriteMsg(answer) + err = w.WriteMsg(answer) + if err != nil { + return dns.RcodeServerFailure, err + } } // increment counters @@ -496,12 +502,13 @@ func (d *Plugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) } // log - if d.QueryLogEnabled { + if p.QueryLogEnabled { logRequest(r, rrw.Msg, result, time.Since(start), ip) } return rcode, err } -func (d *Plugin) Name() string { return "dnsfilter" } +// Name returns name of the plugin as seen in Corefile and plugin.cfg +func (p *plug) Name() string { return "dnsfilter" } var onceQueryLog sync.Once diff --git a/coredns_plugin/coredns_plugin_test.go b/coredns_plugin/coredns_plugin_test.go index 5fceb997..da4f5cd4 100644 --- a/coredns_plugin/coredns_plugin_test.go +++ b/coredns_plugin/coredns_plugin_test.go @@ -46,10 +46,10 @@ func TestEtcHostsParse(t *testing.T) { if err != nil { t.Fatal(err) } - if _, err := tmpfile.Write(text); err != nil { + if _, err = tmpfile.Write(text); err != nil { t.Fatal(err) } - if err := tmpfile.Close(); err != nil { + if err = tmpfile.Close(); err != nil { t.Fatal(err) } @@ -80,10 +80,10 @@ func TestEtcHostsFilter(t *testing.T) { if err != nil { t.Fatal(err) } - if _, err := tmpfile.Write(text); err != nil { + if _, err = tmpfile.Write(text); err != nil { t.Fatal(err) } - if err := tmpfile.Close(); err != nil { + if err = tmpfile.Close(); err != nil { t.Fatal(err) } @@ -127,10 +127,10 @@ func TestEtcHostsFilter(t *testing.T) { t.Fatalf("ServeDNS return value for host %s has rcode %d that does not match captured rcode %d", testcase.host, rcode, rrw.Rcode) } filtered := rcode == dns.RcodeNameError - if testcase.filtered == true && testcase.filtered != filtered { + if testcase.filtered && testcase.filtered != filtered { t.Fatalf("Host %s expected to be filtered, instead it is not filtered", testcase.host) } - if testcase.filtered == false && testcase.filtered != filtered { + if !testcase.filtered && testcase.filtered != filtered { t.Fatalf("Host %s expected to be not filtered, instead it is filtered", testcase.host) } } diff --git a/coredns_plugin/ratelimit/ratelimit.go b/coredns_plugin/ratelimit/ratelimit.go index 29d335ba..2a43644c 100644 --- a/coredns_plugin/ratelimit/ratelimit.go +++ b/coredns_plugin/ratelimit/ratelimit.go @@ -4,7 +4,6 @@ import ( "errors" "log" "strconv" - "sync" "time" // ratelimiting and per-ip buckets @@ -29,8 +28,8 @@ var ( tokenBuckets = cache.New(time.Hour, time.Hour) ) -// main function -func (p *Plugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { +// ServeDNS handles the DNS request and refuses if it's an beyind specified ratelimit +func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { state := request.Request{W: w, Req: r} ip := state.IP() allow, err := p.allowRequest(ip) @@ -44,7 +43,7 @@ func (p *Plugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) return plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r) } -func (p *Plugin) allowRequest(ip string) (bool, error) { +func (p *plug) allowRequest(ip string) (bool, error) { if _, found := tokenBuckets.Get(ip); !found { tokenBuckets.Set(ip, rate.New(p.ratelimit, time.Second), time.Hour) } @@ -59,7 +58,7 @@ func (p *Plugin) allowRequest(ip string) (bool, error) { } rl, ok := value.(*rate.RateLimiter) - if ok == false { + if !ok { text := "SHOULD NOT HAPPEN: non-bool entry found in safebrowsing lookup cache" log.Println(text) err := errors.New(text) @@ -80,7 +79,7 @@ func init() { }) } -type Plugin struct { +type plug struct { Next plugin.Handler // configuration for creating above @@ -88,7 +87,7 @@ type Plugin struct { } func setup(c *caddy.Controller) error { - p := &Plugin{ratelimit: defaultRatelimit} + p := &plug{ratelimit: defaultRatelimit} config := dnsserver.GetConfig(c) for c.Next() { @@ -109,22 +108,20 @@ func setup(c *caddy.Controller) error { }) c.OnStartup(func() error { - once.Do(func() { - m := dnsserver.GetConfig(c).Handler("prometheus") - if m == nil { - return - } - if x, ok := m.(*metrics.Metrics); ok { - x.MustRegister(ratelimited) - } - }) + m := dnsserver.GetConfig(c).Handler("prometheus") + if m == nil { + return nil + } + if x, ok := m.(*metrics.Metrics); ok { + x.MustRegister(ratelimited) + } return nil }) return nil } -func newDnsCounter(name string, help string) prometheus.Counter { +func newDNSCounter(name string, help string) prometheus.Counter { return prometheus.NewCounter(prometheus.CounterOpts{ Namespace: plugin.Namespace, Subsystem: "ratelimit", @@ -134,9 +131,8 @@ func newDnsCounter(name string, help string) prometheus.Counter { } var ( - ratelimited = newDnsCounter("dropped_total", "Count of requests that have been dropped because of rate limit") + ratelimited = newDNSCounter("dropped_total", "Count of requests that have been dropped because of rate limit") ) -func (d *Plugin) Name() string { return "ratelimit" } - -var once sync.Once +// Name returns name of the plugin as seen in Corefile and plugin.cfg +func (p *plug) Name() string { return "ratelimit" } diff --git a/coredns_plugin/refuseany/refuseany.go b/coredns_plugin/refuseany/refuseany.go index 4517b173..02721cc1 100644 --- a/coredns_plugin/refuseany/refuseany.go +++ b/coredns_plugin/refuseany/refuseany.go @@ -3,7 +3,6 @@ package refuseany import ( "fmt" "log" - "sync" "github.com/coredns/coredns/core/dnsserver" "github.com/coredns/coredns/plugin" @@ -15,11 +14,12 @@ import ( "golang.org/x/net/context" ) -type Plugin struct { +type plug struct { Next plugin.Handler } -func (p *Plugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { +// ServeDNS handles the DNS request and refuses if it's an ANY request +func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { if len(r.Question) != 1 { // google DNS, bind and others do the same return dns.RcodeFormatError, fmt.Errorf("Got DNS request with != 1 questions") @@ -41,9 +41,9 @@ func (p *Plugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) return dns.RcodeServerFailure, err } return rcode, nil - } else { - return plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r) } + + return plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r) } func init() { @@ -54,7 +54,7 @@ func init() { } func setup(c *caddy.Controller) error { - p := &Plugin{} + p := &plug{} config := dnsserver.GetConfig(c) config.AddPlugin(func(next plugin.Handler) plugin.Handler { @@ -63,22 +63,20 @@ func setup(c *caddy.Controller) error { }) c.OnStartup(func() error { - once.Do(func() { - m := dnsserver.GetConfig(c).Handler("prometheus") - if m == nil { - return - } - if x, ok := m.(*metrics.Metrics); ok { - x.MustRegister(ratelimited) - } - }) + m := dnsserver.GetConfig(c).Handler("prometheus") + if m == nil { + return nil + } + if x, ok := m.(*metrics.Metrics); ok { + x.MustRegister(ratelimited) + } return nil }) return nil } -func newDnsCounter(name string, help string) prometheus.Counter { +func newDNSCounter(name string, help string) prometheus.Counter { return prometheus.NewCounter(prometheus.CounterOpts{ Namespace: plugin.Namespace, Subsystem: "refuseany", @@ -88,9 +86,8 @@ func newDnsCounter(name string, help string) prometheus.Counter { } var ( - ratelimited = newDnsCounter("refusedany_total", "Count of ANY requests that have been dropped") + ratelimited = newDNSCounter("refusedany_total", "Count of ANY requests that have been dropped") ) -func (d *Plugin) Name() string { return "refuseany" } - -var once sync.Once +// Name returns name of the plugin as seen in Corefile and plugin.cfg +func (p *plug) Name() string { return "refuseany" } diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index 90db03b7..9a031219 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -31,7 +31,10 @@ const defaultSafebrowsingURL = "http://%s/safebrowsing-lookup-hash.html?prefixes const defaultParentalServer = "pctrl.adguard.com" const defaultParentalURL = "http://%s/check-parental-control-hash?prefixes=%s&sensitivity=%d" +// ErrInvalidSyntax is returned by AddRule when rule is invalid var ErrInvalidSyntax = errors.New("dnsfilter: invalid rule syntax") + +// ErrInvalidParental is returned by EnableParental when sensitivity is not a valid value var ErrInvalidParental = errors.New("dnsfilter: invalid parental sensitivity, must be either 3, 10, 13 or 17") const shortcutLength = 6 // used for rule search optimization, 6 hits the sweet spot @@ -39,16 +42,16 @@ const shortcutLength = 6 // used for rule search optimization, 6 hits the sweet const enableFastLookup = true // flag for debugging, must be true in production for faster performance const enableDelayedCompilation = true // flag for debugging, must be true in production for faster performance -type Config struct { +type config struct { + parentalServer string + parentalSensitivity int // must be either 3, 10, 13 or 17 + parentalEnabled bool safeSearchEnabled bool safeBrowsingEnabled bool safeBrowsingServer string - parentalEnabled bool - parentalServer string - parentalSensitivity int // must be either 3, 10, 13 or 17 } -type Rule struct { +type rule struct { text string // text without @@ decorators or $ options shortcut string // for speeding up lookup originalText string // original text for reporting back to applications @@ -57,9 +60,9 @@ type Rule struct { options []string // optional options after $ // parsed options + apps []string isWhitelist bool isImportant bool - apps []string // user-supplied data listID uint32 @@ -70,6 +73,7 @@ type Rule struct { sync.RWMutex } +// LookupStats store stats collected during safebrowsing or parental checks type LookupStats struct { Requests uint64 // number of HTTP requests that were sent CacheHits uint64 // number of lookups that didn't need HTTP requests @@ -77,6 +81,7 @@ type LookupStats struct { PendingMax int64 // maximum number of pending HTTP requests } +// Stats store LookupStats for both safebrowsing and parental type Stats struct { Safebrowsing LookupStats Parental LookupStats @@ -84,7 +89,7 @@ type Stats struct { // Dnsfilter holds added rules and performs hostname matches against the rules type Dnsfilter struct { - storage map[string]*Rule // rule storage, not used for matching, needs to be key->value + storage map[string]*rule // rule storage, not used for matching, needs to be key->value storageMutex sync.RWMutex // rules are checked against these lists in the order defined here @@ -96,12 +101,12 @@ type Dnsfilter struct { client http.Client // handle for http client -- single instance as recommended by docs transport *http.Transport // handle for http transport used by http client - config Config + config config } //go:generate stringer -type=Reason -// filtered/notfiltered reason +// Reason holds an enum detailing why it was filtered or not filtered type Reason int const ( @@ -125,13 +130,14 @@ var ( parentalCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build() ) -// search result +// Result holds state of hostname check type Result struct { IsFiltered bool Reason Reason Rule string } +// Matched can be used to see if any match at all was found, no matter filtered or not func (r Reason) Matched() bool { return r != NotFilteredNotFound } @@ -188,19 +194,19 @@ func (d *Dnsfilter) CheckHost(host string) (Result, error) { // type rulesTable struct { - rulesByShortcut map[string][]*Rule - rulesLeftovers []*Rule + rulesByShortcut map[string][]*rule + rulesLeftovers []*rule sync.RWMutex } func newRulesTable() *rulesTable { return &rulesTable{ - rulesByShortcut: make(map[string][]*Rule), - rulesLeftovers: make([]*Rule, 0), + rulesByShortcut: make(map[string][]*rule), + rulesLeftovers: make([]*rule, 0), } } -func (r *rulesTable) Add(rule *Rule) { +func (r *rulesTable) Add(rule *rule) { r.Lock() if len(rule.shortcut) == shortcutLength && enableFastLookup { r.rulesByShortcut[rule.shortcut] = append(r.rulesByShortcut[rule.shortcut], rule) @@ -295,7 +301,7 @@ func findOptionIndex(text string) int { return -1 } -func (rule *Rule) extractOptions() error { +func (rule *rule) extractOptions() error { optIndex := findOptionIndex(rule.text) if optIndex == 0 { // starts with $ return ErrInvalidSyntax @@ -333,7 +339,7 @@ func (rule *Rule) extractOptions() error { return nil } -func (rule *Rule) parseOptions() error { +func (rule *rule) parseOptions() error { err := rule.extractOptions() if err != nil { return err @@ -354,7 +360,7 @@ func (rule *Rule) parseOptions() error { return nil } -func (rule *Rule) extractShortcut() { +func (rule *rule) extractShortcut() { // regex rules have no shortcuts if rule.text[0] == '/' && rule.text[len(rule.text)-1] == '/' { return @@ -379,7 +385,7 @@ func (rule *Rule) extractShortcut() { rule.shortcut = strings.ToLower(longestField) } -func (rule *Rule) compile() error { +func (rule *rule) compile() error { rule.RLock() isCompiled := rule.compiled != nil rule.RUnlock() @@ -404,7 +410,7 @@ func (rule *Rule) compile() error { return nil } -func (rule *Rule) match(host string) (Result, error) { +func (rule *rule) match(host string) (Result, error) { res := Result{} err := rule.compile() if err != nil { @@ -442,7 +448,7 @@ func getCachedReason(cache gcache.Cache, host string) (result Result, isFound bo // since it can be something else, validate that it belongs to proper type cachedValue, ok := rawValue.(Result) - if ok == false { + if !ok { // this is not our type -- error text := "SHOULD NOT HAPPEN: entry with invalid type was found in lookup cache" log.Println(text) @@ -458,7 +464,7 @@ func hostnameToHashParam(host string, addslash bool) (string, map[string]bool) { var hashparam bytes.Buffer hashes := map[string]bool{} tld, icann := publicsuffix.PublicSuffix(host) - if icann == false { + if !icann { // private suffixes like cloudfront.net tld = "" } @@ -612,7 +618,10 @@ func (d *Dnsfilter) lookupCommon(host string, lookupstats *LookupStats, cache gc switch { case resp.StatusCode == 204: // empty result, save cache - cache.Set(host, Result{}) + err = cache.Set(host, Result{}) + if err != nil { + return Result{}, err + } return Result{}, nil case resp.StatusCode != 200: // error, don't save cache @@ -625,7 +634,10 @@ func (d *Dnsfilter) lookupCommon(host string, lookupstats *LookupStats, cache gc return Result{}, err } - cache.Set(host, result) + err = cache.Set(host, result) + if err != nil { + return Result{}, err + } return result, nil } @@ -648,7 +660,7 @@ func (d *Dnsfilter) AddRule(input string, filterListID uint32) error { return ErrInvalidSyntax } - rule := Rule{ + rule := rule{ text: input, // will be modified originalText: input, listID: filterListID, @@ -712,10 +724,11 @@ func (d *Dnsfilter) matchHost(host string) (Result, error) { // lifecycle helper functions // +// New creates properly initialized DNS Filter that is ready to be used func New() *Dnsfilter { d := new(Dnsfilter) - d.storage = make(map[string]*Rule) + d.storage = make(map[string]*rule) d.important = newRulesTable() d.whiteList = newRulesTable() d.blackList = newRulesTable() @@ -739,6 +752,8 @@ func New() *Dnsfilter { return d } +// Destroy is optional if you want to tidy up goroutines without waiting for them to die off +// right now it closes idle HTTP connections if there are any func (d *Dnsfilter) Destroy() { d.transport.CloseIdleConnections() } @@ -747,10 +762,12 @@ func (d *Dnsfilter) Destroy() { // config manipulation helpers // +// EnableSafeBrowsing turns on checking hostnames in malware/phishing database func (d *Dnsfilter) EnableSafeBrowsing() { d.config.safeBrowsingEnabled = true } +// EnableParental turns on checking hostnames for containing adult content func (d *Dnsfilter) EnableParental(sensitivity int) error { switch sensitivity { case 3, 10, 13, 17: @@ -762,10 +779,13 @@ func (d *Dnsfilter) EnableParental(sensitivity int) error { } } +// EnableSafeSearch turns on enforcing safesearch in search engines +// only used in coredns plugin and requires caller to use SafeSearchDomain() func (d *Dnsfilter) EnableSafeSearch() { d.config.safeSearchEnabled = true } +// SetSafeBrowsingServer lets you optionally change hostname of safesearch lookup func (d *Dnsfilter) SetSafeBrowsingServer(host string) { if len(host) == 0 { d.config.safeBrowsingServer = defaultSafebrowsingServer @@ -774,38 +794,35 @@ func (d *Dnsfilter) SetSafeBrowsingServer(host string) { } } +// SetHTTPTimeout lets you optionally change timeout during lookups func (d *Dnsfilter) SetHTTPTimeout(t time.Duration) { d.client.Timeout = t } +// ResetHTTPTimeout resets lookup timeouts func (d *Dnsfilter) ResetHTTPTimeout() { d.client.Timeout = defaultHTTPTimeout } +// SafeSearchDomain returns replacement address for search engine func (d *Dnsfilter) SafeSearchDomain(host string) (string, bool) { - if d.config.safeSearchEnabled == false { - return "", false + if d.config.safeSearchEnabled { + val, ok := safeSearchDomains[host] + return val, ok } - val, ok := safeSearchDomains[host] - return val, ok + return "", false } // // stats // +// GetStats return dns filtering stats since startup func (d *Dnsfilter) GetStats() Stats { return stats } +// Count returns number of rules added to filter func (d *Dnsfilter) Count() int { return len(d.storage) } - -// -// cache control, right now needed only for tests -// -func purgeCaches() { - safebrowsingCache.Purge() - parentalCache.Purge() -} diff --git a/dnsfilter/dnsfilter_test.go b/dnsfilter/dnsfilter_test.go index 387a58a1..82ca472c 100644 --- a/dnsfilter/dnsfilter_test.go +++ b/dnsfilter/dnsfilter_test.go @@ -3,6 +3,8 @@ package dnsfilter import ( "net/http" "net/http/httptest" + "path" + "strings" "testing" "time" @@ -217,12 +219,6 @@ func TestAddRuleFail(t *testing.T) { d.checkAddRuleFail(t, "lkfaojewhoawehfwacoefawr$@#$@3413841384") } -func printMemStats(r runtime.MemStats) { - fmt.Printf("Alloc: %.2f, HeapAlloc: %.2f Mb, Sys: %.2f Mb, HeapSys: %.2f Mb\n", - float64(r.Alloc)/1024.0/1024.0, float64(r.HeapAlloc)/1024.0/1024.0, - float64(r.Sys)/1024.0/1024.0, float64(r.HeapSys)/1024.0/1024.0) -} - func TestLotsOfRulesMemoryUsage(t *testing.T) { var start, afterLoad, end runtime.MemStats runtime.GC() @@ -251,10 +247,10 @@ func TestLotsOfRulesMemoryUsage(t *testing.T) { if err != nil { t.Errorf("Error while matching host %s: %s", testcase.host, err) } - if ret.IsFiltered == false && ret.IsFiltered != testcase.match { + if !ret.IsFiltered && ret.IsFiltered != testcase.match { t.Errorf("Expected hostname %s to not match", testcase.host) } - if ret.IsFiltered == true && ret.IsFiltered != testcase.match { + if ret.IsFiltered && ret.IsFiltered != testcase.match { t.Errorf("Expected hostname %s to match", testcase.host) } } @@ -642,3 +638,32 @@ func BenchmarkSafeSearchParallel(b *testing.B) { func TestMain(m *testing.M) { goleak.VerifyTestMain(m) } + +// +// helper functions for debugging and testing +// +func purgeCaches() { + safebrowsingCache.Purge() + parentalCache.Purge() +} + +func _Func() string { + pc := make([]uintptr, 10) // at least 1 entry needed + runtime.Callers(2, pc) + f := runtime.FuncForPC(pc[0]) + return path.Base(f.Name()) +} + +func trace(format string, args ...interface{}) { + pc := make([]uintptr, 10) // at least 1 entry needed + runtime.Callers(2, pc) + f := runtime.FuncForPC(pc[0]) + var buf strings.Builder + buf.WriteString(fmt.Sprintf("%s(): ", path.Base(f.Name()))) + text := fmt.Sprintf(format, args...) + buf.WriteString(text) + if len(text) == 0 || text[len(text)-1] != '\n' { + buf.WriteRune('\n') + } + fmt.Print(buf.String()) +} diff --git a/dnsfilter/helpers.go b/dnsfilter/helpers.go index a08b453e..fb34cff9 100644 --- a/dnsfilter/helpers.go +++ b/dnsfilter/helpers.go @@ -1,9 +1,6 @@ package dnsfilter import ( - "fmt" - "path" - "runtime" "strings" "sync/atomic" ) @@ -49,33 +46,9 @@ func updateMax(valuePtr *int64, maxPtr *int64) { break } swapped := atomic.CompareAndSwapInt64(maxPtr, max, current) - if swapped == true { + if swapped { break } // swapping failed because value has changed after reading, try again } } - -// -// helper functions for debugging and testing -// -func _Func() string { - pc := make([]uintptr, 10) // at least 1 entry needed - runtime.Callers(2, pc) - f := runtime.FuncForPC(pc[0]) - return path.Base(f.Name()) -} - -func trace(format string, args ...interface{}) { - pc := make([]uintptr, 10) // at least 1 entry needed - runtime.Callers(2, pc) - f := runtime.FuncForPC(pc[0]) - var buf strings.Builder - buf.WriteString(fmt.Sprintf("%s(): ", path.Base(f.Name()))) - text := fmt.Sprintf(format, args...) - buf.WriteString(text) - if len(text) == 0 || text[len(text)-1] != '\n' { - buf.WriteRune('\n') - } - fmt.Print(buf.String()) -} diff --git a/helpers.go b/helpers.go index d39c5e93..62e260e6 100644 --- a/helpers.go +++ b/helpers.go @@ -5,6 +5,8 @@ import ( "errors" "io" "net/http" + "path" + "runtime" "sort" "strings" "time" @@ -80,7 +82,7 @@ func generateMapFromStats(stats *periodicStats, start int, end int) map[string]i result := map[string]interface{}{ "dns_queries": getReversedSlice(stats.entries[totalRequests], start, end), - "blocked_filtering": getReversedSlice(stats.entries[filteredLists], start, end), + "blocked_filtering": getReversedSlice(stats.entries[filteredTotal], start, end), "replaced_safebrowsing": getReversedSlice(stats.entries[filteredSafebrowsing], start, end), "replaced_safesearch": getReversedSlice(stats.entries[filteredSafesearch], start, end), "replaced_parental": getReversedSlice(stats.entries[filteredParental], start, end), @@ -89,18 +91,6 @@ func generateMapFromStats(stats *periodicStats, start int, end int) map[string]i return result } -func produceTop(m map[string]int, top int) map[string]int { - toMarshal := map[string]int{} - topKeys := sortByValue(m) - for i, k := range topKeys { - if i == top { - break - } - toMarshal[k] = m[k] - } - return toMarshal -} - // ------------------------------------- // helper functions for querylog parsing // ------------------------------------- @@ -206,3 +196,13 @@ func parseParametersFromBody(r io.Reader) (map[string]string, error) { return parameters, nil } + +// --------------------- +// debug logging helpers +// --------------------- +func _Func() string { + pc := make([]uintptr, 10) // at least 1 entry needed + runtime.Callers(2, pc) + f := runtime.FuncForPC(pc[0]) + return path.Base(f.Name()) +} diff --git a/stats.go b/stats.go index 4829cd0c..f2509334 100644 --- a/stats.go +++ b/stats.go @@ -26,7 +26,6 @@ const ( statsHistoryElements = 60 + 1 // +1 for calculating delta totalRequests = `coredns_dns_request_count_total` filteredTotal = `coredns_dnsfilter_filtered_total` - filteredLists = `coredns_dnsfilter_filtered_lists_total` filteredSafebrowsing = `coredns_dnsfilter_filtered_safebrowsing_total` filteredSafesearch = `coredns_dnsfilter_safesearch_total` filteredParental = `coredns_dnsfilter_filtered_parental_total` @@ -124,9 +123,10 @@ func collectStats() { defer resp.Body.Close() } if err != nil { - if isConnRefused(err) == false { - log.Printf("Couldn't get coredns metrics: %T %s\n", err, err) + if isConnRefused(err) { + return } + log.Printf("Couldn't get coredns metrics: %T %s\n", err, err) return }