*(dnsforward): cache upstream instances
✅ Closes: https://github.com/AdguardTeam/AdGuardHome/issues/1296
This commit is contained in:
parent
3dd91cf179
commit
cdd55139fa
@ -113,7 +113,7 @@ type FilteringConfig struct {
|
|||||||
FilterHandler func(clientAddr string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"`
|
FilterHandler func(clientAddr string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"`
|
||||||
|
|
||||||
// This callback function returns the list of upstream servers for a client specified by IP address
|
// This callback function returns the list of upstream servers for a client specified by IP address
|
||||||
GetUpstreamsByClient func(clientAddr string) []string `yaml:"-"`
|
GetUpstreamsByClient func(clientAddr string) []upstream.Upstream `yaml:"-"`
|
||||||
|
|
||||||
ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features
|
ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features
|
||||||
|
|
||||||
@ -465,13 +465,9 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
|
|||||||
if d.Addr != nil && s.conf.GetUpstreamsByClient != nil {
|
if d.Addr != nil && s.conf.GetUpstreamsByClient != nil {
|
||||||
clientIP := ipFromAddr(d.Addr)
|
clientIP := ipFromAddr(d.Addr)
|
||||||
upstreams := s.conf.GetUpstreamsByClient(clientIP)
|
upstreams := s.conf.GetUpstreamsByClient(clientIP)
|
||||||
for _, us := range upstreams {
|
if len(upstreams) > 0 {
|
||||||
u, err := upstream.AddressToUpstream(us, upstream.Options{Timeout: 30 * time.Second})
|
log.Debug("Using custom upstreams for %s", clientIP)
|
||||||
if err != nil {
|
d.Upstreams = upstreams
|
||||||
log.Error("upstream.AddressToUpstream: %s: %s", us, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
d.Upstreams = append(d.Upstreams, u)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/dhcpd"
|
"github.com/AdguardTeam/AdGuardHome/dhcpd"
|
||||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||||
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/golibs/utils"
|
"github.com/AdguardTeam/golibs/utils"
|
||||||
)
|
)
|
||||||
@ -62,8 +63,14 @@ type clientsContainer struct {
|
|||||||
list map[string]*Client // name -> client
|
list map[string]*Client // name -> client
|
||||||
idIndex map[string]*Client // IP -> client
|
idIndex map[string]*Client // IP -> client
|
||||||
ipHost map[string]*ClientHost // IP -> Hostname
|
ipHost map[string]*ClientHost // IP -> Hostname
|
||||||
|
|
||||||
|
// cache for Upstream instances that are used in the case
|
||||||
|
// when custom DNS servers are configured for a client
|
||||||
|
upstreamsCache map[string][]upstream.Upstream // name -> []Upstream
|
||||||
|
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
|
|
||||||
|
// dhcpServer is used for looking up clients IP addresses by MAC addresses
|
||||||
dhcpServer *dhcpd.Server
|
dhcpServer *dhcpd.Server
|
||||||
|
|
||||||
testing bool // if TRUE, this object is used for internal tests
|
testing bool // if TRUE, this object is used for internal tests
|
||||||
@ -78,6 +85,7 @@ func (clients *clientsContainer) Init(objects []clientObject, dhcpServer *dhcpd.
|
|||||||
clients.list = make(map[string]*Client)
|
clients.list = make(map[string]*Client)
|
||||||
clients.idIndex = make(map[string]*Client)
|
clients.idIndex = make(map[string]*Client)
|
||||||
clients.ipHost = make(map[string]*ClientHost)
|
clients.ipHost = make(map[string]*ClientHost)
|
||||||
|
clients.upstreamsCache = make(map[string][]upstream.Upstream)
|
||||||
clients.dhcpServer = dhcpServer
|
clients.dhcpServer = dhcpServer
|
||||||
clients.addFromConfig(objects)
|
clients.addFromConfig(objects)
|
||||||
|
|
||||||
@ -191,6 +199,45 @@ func (clients *clientsContainer) Find(ip string) (Client, bool) {
|
|||||||
return clients.findByIP(ip)
|
return clients.findByIP(ip)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FindUpstreams looks for upstreams configured for the client
|
||||||
|
// If no client found for this IP, or if no custom upstreams are configured,
|
||||||
|
// this method returns nil
|
||||||
|
func (clients *clientsContainer) FindUpstreams(ip string) []upstream.Upstream {
|
||||||
|
clients.lock.Lock()
|
||||||
|
defer clients.lock.Unlock()
|
||||||
|
|
||||||
|
c, ok := clients.findByIP(ip)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(c.Upstreams) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
upstreams, ok := clients.upstreamsCache[c.Name]
|
||||||
|
if ok {
|
||||||
|
return upstreams
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, us := range c.Upstreams {
|
||||||
|
u, err := upstream.AddressToUpstream(us, upstream.Options{Timeout: dnsforward.DefaultTimeout})
|
||||||
|
if err != nil {
|
||||||
|
log.Error("upstream.AddressToUpstream: %s: %s", us, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
upstreams = append(upstreams, u)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(upstreams) == 0 {
|
||||||
|
clients.upstreamsCache[c.Name] = nil
|
||||||
|
} else {
|
||||||
|
clients.upstreamsCache[c.Name] = upstreams
|
||||||
|
}
|
||||||
|
|
||||||
|
return upstreams
|
||||||
|
}
|
||||||
|
|
||||||
// Find searches for a client by IP (and does not lock anything)
|
// Find searches for a client by IP (and does not lock anything)
|
||||||
func (clients *clientsContainer) findByIP(ip string) (Client, bool) {
|
func (clients *clientsContainer) findByIP(ip string) (Client, bool) {
|
||||||
ipAddr := net.ParseIP(ip)
|
ipAddr := net.ParseIP(ip)
|
||||||
@ -355,6 +402,9 @@ func (clients *clientsContainer) Del(name string) bool {
|
|||||||
// update Name index
|
// update Name index
|
||||||
delete(clients.list, name)
|
delete(clients.list, name)
|
||||||
|
|
||||||
|
// update upstreams cache
|
||||||
|
delete(clients.upstreamsCache, name)
|
||||||
|
|
||||||
// update ID index
|
// update ID index
|
||||||
for _, id := range c.IDs {
|
for _, id := range c.IDs {
|
||||||
delete(clients.idIndex, id)
|
delete(clients.idIndex, id)
|
||||||
@ -418,10 +468,13 @@ func (clients *clientsContainer) Update(name string, c Client) error {
|
|||||||
|
|
||||||
// update Name index
|
// update Name index
|
||||||
if old.Name != c.Name {
|
if old.Name != c.Name {
|
||||||
delete(clients.list, old.Name)
|
|
||||||
clients.list[c.Name] = old
|
clients.list[c.Name] = old
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// update upstreams cache
|
||||||
|
delete(clients.upstreamsCache, name)
|
||||||
|
delete(clients.upstreamsCache, old.Name)
|
||||||
|
|
||||||
*old = c
|
*old = c
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
11
home/dns.go
11
home/dns.go
@ -11,6 +11,7 @@ import (
|
|||||||
"github.com/AdguardTeam/AdGuardHome/querylog"
|
"github.com/AdguardTeam/AdGuardHome/querylog"
|
||||||
"github.com/AdguardTeam/AdGuardHome/stats"
|
"github.com/AdguardTeam/AdGuardHome/stats"
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/joomcode/errorx"
|
"github.com/joomcode/errorx"
|
||||||
)
|
)
|
||||||
@ -178,18 +179,12 @@ func generateServerConfig() dnsforward.ServerConfig {
|
|||||||
return newconfig
|
return newconfig
|
||||||
}
|
}
|
||||||
|
|
||||||
func getUpstreamsByClient(clientAddr string) []string {
|
func getUpstreamsByClient(clientAddr string) []upstream.Upstream {
|
||||||
c, ok := Context.clients.Find(clientAddr)
|
return Context.clients.FindUpstreams(clientAddr)
|
||||||
if !ok {
|
|
||||||
return []string{}
|
|
||||||
}
|
|
||||||
log.Debug("Using upstreams %v for client %s (IP: %s)", c.Upstreams, c.Name, clientAddr)
|
|
||||||
return c.Upstreams
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// If a client has his own settings, apply them
|
// If a client has his own settings, apply them
|
||||||
func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
|
func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
|
||||||
|
|
||||||
ApplyBlockedServices(setts, config.DNS.BlockedServices)
|
ApplyBlockedServices(setts, config.DNS.BlockedServices)
|
||||||
|
|
||||||
if len(clientAddr) == 0 {
|
if len(clientAddr) == 0 {
|
||||||
|
Loading…
Reference in New Issue
Block a user