- (dnsforward): fix ipset race

https://github.com/AdguardTeam/AdGuardHome/pull/2182#issuecomment-707807737
This commit is contained in:
Andrey Meshkov 2020-10-22 23:31:51 +03:00
parent 0b43bf4cd9
commit e37ccdbb70
1 changed files with 10 additions and 1 deletions

View File

@ -3,6 +3,7 @@ package dnsforward
import ( import (
"net" "net"
"strings" "strings"
"sync"
"github.com/AdguardTeam/AdGuardHome/util" "github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
@ -12,7 +13,9 @@ import (
type ipsetCtx struct { type ipsetCtx struct {
ipsetList map[string][]string // domain -> []ipset_name ipsetList map[string][]string // domain -> []ipset_name
ipsetCache map[[4]byte]bool // cache for IP[] to prevent duplicate calls to ipset program ipsetCache map[[4]byte]bool // cache for IP[] to prevent duplicate calls to ipset program
ipset6Cache map[[16]byte]bool // cache for IP[] to prevent duplicate calls to ipset program ipsetMutex *sync.Mutex
ipset6Cache map[[16]byte]bool // cache for IP[] to prevent duplicate calls to ipset program
ipset6Mutex *sync.Mutex
} }
// Convert configuration settings to an internal map // Convert configuration settings to an internal map
@ -20,7 +23,9 @@ type ipsetCtx struct {
func (c *ipsetCtx) init(ipsetConfig []string) { func (c *ipsetCtx) init(ipsetConfig []string) {
c.ipsetList = make(map[string][]string) c.ipsetList = make(map[string][]string)
c.ipsetCache = make(map[[4]byte]bool) c.ipsetCache = make(map[[4]byte]bool)
c.ipsetMutex = &sync.Mutex{}
c.ipset6Cache = make(map[[16]byte]bool) c.ipset6Cache = make(map[[16]byte]bool)
c.ipset6Mutex = &sync.Mutex{}
for _, it := range ipsetConfig { for _, it := range ipsetConfig {
it = strings.TrimSpace(it) it = strings.TrimSpace(it)
@ -67,6 +72,8 @@ func (c *ipsetCtx) getIP(rr dns.RR) net.IP {
case *dns.A: case *dns.A:
var ip4 [4]byte var ip4 [4]byte
copy(ip4[:], a.A.To4()) copy(ip4[:], a.A.To4())
c.ipsetMutex.Lock()
defer c.ipsetMutex.Unlock()
_, found := c.ipsetCache[ip4] _, found := c.ipsetCache[ip4]
if found { if found {
return nil // this IP was added before return nil // this IP was added before
@ -77,6 +84,8 @@ func (c *ipsetCtx) getIP(rr dns.RR) net.IP {
case *dns.AAAA: case *dns.AAAA:
var ip6 [16]byte var ip6 [16]byte
copy(ip6[:], a.AAAA) copy(ip6[:], a.AAAA)
c.ipset6Mutex.Lock()
defer c.ipset6Mutex.Unlock()
_, found := c.ipset6Cache[ip6] _, found := c.ipset6Cache[ip6]
if found { if found {
return nil // this IP was added before return nil // this IP was added before