Fix more race conditions found by race detector

This commit is contained in:
Eugene Bujak 2018-10-07 21:24:22 +03:00
parent dc1042c3e9
commit 3b1faa1365
2 changed files with 23 additions and 19 deletions

View File

@ -45,21 +45,24 @@ func init() {
}) })
} }
type plugSettings struct {
SafeBrowsingBlockHost string
ParentalBlockHost string
QueryLogEnabled bool
BlockedTTL uint32 // in seconds, default 3600
}
type plug struct { type plug struct {
d *dnsfilter.Dnsfilter d *dnsfilter.Dnsfilter
Next plugin.Handler Next plugin.Handler
upstream upstream.Upstream upstream upstream.Upstream
hosts map[string]net.IP hosts map[string]net.IP
settings plugSettings
SafeBrowsingBlockHost string
ParentalBlockHost string
QueryLogEnabled bool
BlockedTTL uint32 // in seconds, default 3600
sync.RWMutex sync.RWMutex
} }
var defaultPlugin = plug{ var defaultPluginSettings = plugSettings{
SafeBrowsingBlockHost: "safebrowsing.block.dns.adguard.com", SafeBrowsingBlockHost: "safebrowsing.block.dns.adguard.com",
ParentalBlockHost: "family.block.dns.adguard.com", ParentalBlockHost: "family.block.dns.adguard.com",
BlockedTTL: 3600, // in seconds BlockedTTL: 3600, // in seconds
@ -91,10 +94,11 @@ var (
// //
func setupPlugin(c *caddy.Controller) (*plug, error) { func setupPlugin(c *caddy.Controller) (*plug, error) {
// create new Plugin and copy default values // create new Plugin and copy default values
var p = new(plug) p := &plug{
*p = defaultPlugin settings: defaultPluginSettings,
p.d = dnsfilter.New() d: dnsfilter.New(),
p.hosts = make(map[string]net.IP) hosts: make(map[string]net.IP),
}
filterFileNames := []string{} filterFileNames := []string{}
for c.Next() { for c.Next() {
@ -130,7 +134,7 @@ func setupPlugin(c *caddy.Controller) (*plug, error) {
if len(c.Val()) == 0 { if len(c.Val()) == 0 {
return nil, c.ArgErr() return nil, c.ArgErr()
} }
p.ParentalBlockHost = c.Val() p.settings.ParentalBlockHost = c.Val()
} }
case "blocked_ttl": case "blocked_ttl":
if !c.NextArg() { if !c.NextArg() {
@ -140,9 +144,9 @@ func setupPlugin(c *caddy.Controller) (*plug, error) {
if err != nil { if err != nil {
return nil, c.ArgErr() return nil, c.ArgErr()
} }
p.BlockedTTL = uint32(blockttl) p.settings.BlockedTTL = uint32(blockttl)
case "querylog": case "querylog":
p.QueryLogEnabled = true p.settings.QueryLogEnabled = true
onceQueryLog.Do(func() { onceQueryLog.Do(func() {
go startQueryLogServer() // TODO: how to handle errors? go startQueryLogServer() // TODO: how to handle errors?
}) })
@ -323,7 +327,7 @@ func (p *plug) replaceHostWithValAndReply(ctx context.Context, w dns.ResponseWri
log.Println("Will give", val, "instead of", host) log.Println("Will give", val, "instead of", host)
if addr != nil { if addr != nil {
// this is an IP address, return it // this is an IP address, return it
result, err := dns.NewRR(fmt.Sprintf("%s %d A %s", host, p.BlockedTTL, val)) result, err := dns.NewRR(fmt.Sprintf("%s %d A %s", host, p.settings.BlockedTTL, val))
if err != nil { if err != nil {
log.Printf("Got error %s\n", err) log.Printf("Got error %s\n", err)
return dns.RcodeServerFailure, fmt.Errorf("plugin/dnsfilter: %s", err) return dns.RcodeServerFailure, fmt.Errorf("plugin/dnsfilter: %s", err)
@ -365,7 +369,7 @@ func (p *plug) replaceHostWithValAndReply(ctx context.Context, w dns.ResponseWri
// the only value that is important is TTL in header, other values like refresh, retry, expire and minttl are irrelevant // the only value that is important is TTL in header, other values like refresh, retry, expire and minttl are irrelevant
func (p *plug) genSOA(r *dns.Msg) []dns.RR { func (p *plug) genSOA(r *dns.Msg) []dns.RR {
zone := r.Question[0].Name zone := r.Question[0].Name
header := dns.RR_Header{Name: zone, Rrtype: dns.TypeSOA, Ttl: p.BlockedTTL, Class: dns.ClassINET} header := dns.RR_Header{Name: zone, Rrtype: dns.TypeSOA, Ttl: p.settings.BlockedTTL, Class: dns.ClassINET}
Mbox := "hostmaster." Mbox := "hostmaster."
if zone[0] != '.' { if zone[0] != '.' {
@ -450,7 +454,7 @@ func (p *plug) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *dn
switch result.Reason { switch result.Reason {
case dnsfilter.FilteredSafeBrowsing: case dnsfilter.FilteredSafeBrowsing:
// return cname safebrowsing.block.dns.adguard.com // return cname safebrowsing.block.dns.adguard.com
val := p.SafeBrowsingBlockHost val := p.settings.SafeBrowsingBlockHost
rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question) rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question)
if err != nil { if err != nil {
return rcode, dnsfilter.Result{}, err return rcode, dnsfilter.Result{}, err
@ -458,7 +462,7 @@ func (p *plug) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *dn
return rcode, result, err return rcode, result, err
case dnsfilter.FilteredParental: case dnsfilter.FilteredParental:
// return cname family.block.dns.adguard.com // return cname family.block.dns.adguard.com
val := p.ParentalBlockHost val := p.settings.ParentalBlockHost
rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question) rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question)
if err != nil { if err != nil {
return rcode, dnsfilter.Result{}, err return rcode, dnsfilter.Result{}, err
@ -549,7 +553,7 @@ func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (
} }
// log // log
if p.QueryLogEnabled { if p.settings.QueryLogEnabled {
logRequest(r, rrw.Msg, result, time.Since(start), ip) logRequest(r, rrw.Msg, result, time.Since(start), ip)
} }
return rcode, err return rcode, err

View File

@ -254,7 +254,7 @@ func writeStats() error {
statsFile := filepath.Join(config.ourBinaryDir, "stats.json") statsFile := filepath.Join(config.ourBinaryDir, "stats.json")
log.Printf("Writing JSON file: %s", statsFile) log.Printf("Writing JSON file: %s", statsFile)
statistics.RLock() statistics.RLock()
json, err := json.MarshalIndent(statistics, "", " ") json, err := json.MarshalIndent(&statistics, "", " ")
statistics.RUnlock() statistics.RUnlock()
if err != nil { if err != nil {
log.Printf("Couldn't generate JSON: %s", err) log.Printf("Couldn't generate JSON: %s", err)