Pull request: 3142 custom private subnets

Merge in DNS/adguard-home from 3142-custom-subnets to master

Updates #3142.

Squashed commit of the following:

commit 11469ade75b9dc32ee6d93e3aa35cf79dbaa28b2
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Mar 17 19:56:02 2022 +0300

    all: upd golibs, use subnet set
This commit is contained in:
Eugene Burkov 2022-03-18 13:37:27 +03:00
parent 573cbafe3f
commit 778585865e
15 changed files with 105 additions and 520 deletions

View File

@ -17,6 +17,8 @@ and this project adheres to
### Added ### Added
- The ability to customize the set of networks considered private through the
new `private_networks` setting ([#3142]).
- EDNS Client-Subnet information in the request details section of a query log - EDNS Client-Subnet information in the request details section of a query log
record ([#3978]). record ([#3978]).
- Support for hostnames for plain UDP upstream servers using the `udp://` scheme - Support for hostnames for plain UDP upstream servers using the `udp://` scheme
@ -88,6 +90,7 @@ In this release, the schema version has changed from 12 to 13.
[#1730]: https://github.com/AdguardTeam/AdGuardHome/issues/1730 [#1730]: https://github.com/AdguardTeam/AdGuardHome/issues/1730
[#2993]: https://github.com/AdguardTeam/AdGuardHome/issues/2993 [#2993]: https://github.com/AdguardTeam/AdGuardHome/issues/2993
[#3057]: https://github.com/AdguardTeam/AdGuardHome/issues/3057 [#3057]: https://github.com/AdguardTeam/AdGuardHome/issues/3057
[#3142]: https://github.com/AdguardTeam/AdGuardHome/issues/3142
[#3367]: https://github.com/AdguardTeam/AdGuardHome/issues/3367 [#3367]: https://github.com/AdguardTeam/AdGuardHome/issues/3367
[#3381]: https://github.com/AdguardTeam/AdGuardHome/issues/3381 [#3381]: https://github.com/AdguardTeam/AdGuardHome/issues/3381
[#3503]: https://github.com/AdguardTeam/AdGuardHome/issues/3503 [#3503]: https://github.com/AdguardTeam/AdGuardHome/issues/3503

2
go.mod
View File

@ -4,7 +4,7 @@ go 1.17
require ( require (
github.com/AdguardTeam/dnsproxy v0.41.4 github.com/AdguardTeam/dnsproxy v0.41.4
github.com/AdguardTeam/golibs v0.10.6 github.com/AdguardTeam/golibs v0.10.8
github.com/AdguardTeam/urlfilter v0.15.2 github.com/AdguardTeam/urlfilter v0.15.2
github.com/NYTimes/gziphandler v1.1.1 github.com/NYTimes/gziphandler v1.1.1
github.com/ameshkov/dnscrypt/v2 v2.2.3 github.com/ameshkov/dnscrypt/v2 v2.2.3

3
go.sum
View File

@ -12,8 +12,9 @@ github.com/AdguardTeam/dnsproxy v0.41.4/go.mod h1:GCdEbTw683vBqksJIccPSYzBg2yIFb
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.10.4/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw= github.com/AdguardTeam/golibs v0.10.4/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw=
github.com/AdguardTeam/golibs v0.10.6 h1:6UG6LxWFnG7TfjNzeApw+T68Kqqov0fcDYk9RjhTdhc=
github.com/AdguardTeam/golibs v0.10.6/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw= github.com/AdguardTeam/golibs v0.10.6/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw=
github.com/AdguardTeam/golibs v0.10.8 h1:diU9gP9qG1qeLbAkzIwfUerpHSqzR6zaBgzvRMR/m6Q=
github.com/AdguardTeam/golibs v0.10.8/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw=
github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU= github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU=
github.com/AdguardTeam/urlfilter v0.15.2 h1:LZGgrm4l4Ys9eAqB+UUmZfiC6vHlDlYFhx0WXqo6LtQ= github.com/AdguardTeam/urlfilter v0.15.2 h1:LZGgrm4l4Ys9eAqB+UUmZfiC6vHlDlYFhx0WXqo6LtQ=
github.com/AdguardTeam/urlfilter v0.15.2/go.mod h1:46YZDOV1+qtdRDuhZKVPSSp7JWWes0KayqHrKAFBdEI= github.com/AdguardTeam/urlfilter v0.15.2/go.mod h1:46YZDOV1+qtdRDuhZKVPSSp7JWWes0KayqHrKAFBdEI=

View File

@ -1,158 +0,0 @@
package aghnet
import (
"net"
)
// SubnetDetector describes IP address properties.
type SubnetDetector struct {
// spNets is the collection of special-purpose address registries as defined
// by RFC 6890.
spNets []*net.IPNet
// locServedNets is the collection of locally-served networks as defined by
// RFC 6303.
locServedNets []*net.IPNet
}
// NewSubnetDetector returns a new IP detector.
//
// TODO(a.garipov): Decide whether an error is actually needed.
func NewSubnetDetector() (snd *SubnetDetector, err error) {
spNets := []string{
// "This" network.
"0.0.0.0/8",
// Private-Use Networks.
"10.0.0.0/8",
// Shared Address Space.
"100.64.0.0/10",
// Loopback.
"127.0.0.0/8",
// Link Local.
"169.254.0.0/16",
// Private-Use Networks.
"172.16.0.0/12",
// IETF Protocol Assignments.
"192.0.0.0/24",
// DS-Lite.
"192.0.0.0/29",
// TEST-NET-1
"192.0.2.0/24",
// 6to4 Relay Anycast.
"192.88.99.0/24",
// Private-Use Networks.
"192.168.0.0/16",
// Network Interconnect Device Benchmark Testing.
"198.18.0.0/15",
// TEST-NET-2.
"198.51.100.0/24",
// TEST-NET-3.
"203.0.113.0/24",
// Reserved for Future Use.
"240.0.0.0/4",
// Limited Broadcast.
"255.255.255.255/32",
// Loopback.
"::1/128",
// Unspecified.
"::/128",
// IPv4-IPv6 Translation Address.
"64:ff9b::/96",
// IPv4-Mapped Address. Since this network is used for mapping
// IPv4 addresses, we don't include it.
// "::ffff:0:0/96",
// Discard-Only Prefix.
"100::/64",
// IETF Protocol Assignments.
"2001::/23",
// TEREDO.
"2001::/32",
// Benchmarking.
"2001:2::/48",
// Documentation.
"2001:db8::/32",
// ORCHID.
"2001:10::/28",
// 6to4.
"2002::/16",
// Unique-Local.
"fc00::/7",
// Linked-Scoped Unicast.
"fe80::/10",
}
// TODO(e.burkov): It's a subslice of the slice above. Should be done
// smarter.
locServedNets := []string{
// IPv4.
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
"127.0.0.0/8",
"169.254.0.0/16",
"192.0.2.0/24",
"198.51.100.0/24",
"203.0.113.0/24",
"255.255.255.255/32",
// IPv6.
"::/128",
"::1/128",
"fe80::/10",
"2001:db8::/32",
"fd00::/8",
}
snd = &SubnetDetector{
spNets: make([]*net.IPNet, len(spNets)),
locServedNets: make([]*net.IPNet, len(locServedNets)),
}
for i, ipnetStr := range spNets {
var ipnet *net.IPNet
_, ipnet, err = net.ParseCIDR(ipnetStr)
if err != nil {
return nil, err
}
snd.spNets[i] = ipnet
}
for i, ipnetStr := range locServedNets {
var ipnet *net.IPNet
_, ipnet, err = net.ParseCIDR(ipnetStr)
if err != nil {
return nil, err
}
snd.locServedNets[i] = ipnet
}
return snd, nil
}
// anyNetContains ranges through the given ipnets slice searching for the one
// which contains the ip. For internal use only.
//
// TODO(e.burkov): Think about memoization.
func anyNetContains(ipnets *[]*net.IPNet, ip net.IP) (is bool) {
for _, ipnet := range *ipnets {
if ipnet.Contains(ip) {
return true
}
}
return false
}
// IsSpecialNetwork returns true if IP address is contained by any of
// special-purpose IP address registries. It's safe for concurrent use.
func (snd *SubnetDetector) IsSpecialNetwork(ip net.IP) (is bool) {
return anyNetContains(&snd.spNets, ip)
}
// IsLocallyServedNetwork returns true if IP address is contained by any of
// locally-served IP address registries. It's safe for concurrent use.
func (snd *SubnetDetector) IsLocallyServedNetwork(ip net.IP) (is bool) {
return anyNetContains(&snd.locServedNets, ip)
}

View File

@ -1,252 +0,0 @@
package aghnet
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSubnetDetector_DetectSpecialNetwork(t *testing.T) {
snd, err := NewSubnetDetector()
require.NoError(t, err)
testCases := []struct {
name string
ip net.IP
want bool
}{{
name: "not_specific",
ip: net.ParseIP("8.8.8.8"),
want: false,
}, {
name: "this_host_on_this_network",
ip: net.ParseIP("0.0.0.0"),
want: true,
}, {
name: "private-Use",
ip: net.ParseIP("10.0.0.0"),
want: true,
}, {
name: "shared_address_space",
ip: net.ParseIP("100.64.0.0"),
want: true,
}, {
name: "loopback",
ip: net.ParseIP("127.0.0.0"),
want: true,
}, {
name: "link_local",
ip: net.ParseIP("169.254.0.0"),
want: true,
}, {
name: "private-use",
ip: net.ParseIP("172.16.0.0"),
want: true,
}, {
name: "ietf_protocol_assignments",
ip: net.ParseIP("192.0.0.0"),
want: true,
}, {
name: "ds-lite",
ip: net.ParseIP("192.0.0.0"),
want: true,
}, {
name: "documentation_(test-net-1)",
ip: net.ParseIP("192.0.2.0"),
want: true,
}, {
name: "6to4_relay_anycast",
ip: net.ParseIP("192.88.99.0"),
want: true,
}, {
name: "private-use",
ip: net.ParseIP("192.168.0.0"),
want: true,
}, {
name: "benchmarking",
ip: net.ParseIP("198.18.0.0"),
want: true,
}, {
name: "documentation_(test-net-2)",
ip: net.ParseIP("198.51.100.0"),
want: true,
}, {
name: "documentation_(test-net-3)",
ip: net.ParseIP("203.0.113.0"),
want: true,
}, {
name: "reserved",
ip: net.ParseIP("240.0.0.0"),
want: true,
}, {
name: "limited_broadcast",
ip: net.ParseIP("255.255.255.255"),
want: true,
}, {
name: "loopback_address",
ip: net.ParseIP("::1"),
want: true,
}, {
name: "unspecified_address",
ip: net.ParseIP("::"),
want: true,
}, {
name: "ipv4-ipv6_translation",
ip: net.ParseIP("64:ff9b::"),
want: true,
}, {
name: "discard-only_address_block",
ip: net.ParseIP("100::"),
want: true,
}, {
name: "ietf_protocol_assignments",
ip: net.ParseIP("2001::"),
want: true,
}, {
name: "teredo",
ip: net.ParseIP("2001::"),
want: true,
}, {
name: "benchmarking",
ip: net.ParseIP("2001:2::"),
want: true,
}, {
name: "documentation",
ip: net.ParseIP("2001:db8::"),
want: true,
}, {
name: "orchid",
ip: net.ParseIP("2001:10::"),
want: true,
}, {
name: "6to4",
ip: net.ParseIP("2002::"),
want: true,
}, {
name: "unique-local",
ip: net.ParseIP("fc00::"),
want: true,
}, {
name: "linked-scoped_unicast",
ip: net.ParseIP("fe80::"),
want: true,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, snd.IsSpecialNetwork(tc.ip))
})
}
}
func TestSubnetDetector_DetectLocallyServedNetwork(t *testing.T) {
snd, err := NewSubnetDetector()
require.NoError(t, err)
testCases := []struct {
name string
ip net.IP
want bool
}{{
name: "not_specific",
ip: net.ParseIP("8.8.8.8"),
want: false,
}, {
name: "private-Use",
ip: net.ParseIP("10.0.0.0"),
want: true,
}, {
name: "loopback",
ip: net.ParseIP("127.0.0.0"),
want: true,
}, {
name: "link_local",
ip: net.ParseIP("169.254.0.0"),
want: true,
}, {
name: "private-use",
ip: net.ParseIP("172.16.0.0"),
want: true,
}, {
name: "documentation_(test-net-1)",
ip: net.ParseIP("192.0.2.0"),
want: true,
}, {
name: "private-use",
ip: net.ParseIP("192.168.0.0"),
want: true,
}, {
name: "documentation_(test-net-2)",
ip: net.ParseIP("198.51.100.0"),
want: true,
}, {
name: "documentation_(test-net-3)",
ip: net.ParseIP("203.0.113.0"),
want: true,
}, {
name: "limited_broadcast",
ip: net.ParseIP("255.255.255.255"),
want: true,
}, {
name: "loopback_address",
ip: net.ParseIP("::1"),
want: true,
}, {
name: "unspecified_address",
ip: net.ParseIP("::"),
want: true,
}, {
name: "documentation",
ip: net.ParseIP("2001:db8::"),
want: true,
}, {
name: "linked-scoped_unicast",
ip: net.ParseIP("fe80::"),
want: true,
}, {
name: "locally_assigned",
ip: net.ParseIP("fd00::1"),
want: true,
}, {
name: "not_locally_assigned",
ip: net.ParseIP("fc00::1"),
want: false,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, snd.IsLocallyServedNetwork(tc.ip))
})
}
}
func TestSubnetDetector_Detect_parallel(t *testing.T) {
t.Parallel()
snd, err := NewSubnetDetector()
require.NoError(t, err)
testFunc := func() {
for _, ip := range []net.IP{
net.IPv4allrouter,
net.IPv4allsys,
net.IPv4bcast,
net.IPv4zero,
net.IPv6interfacelocalallnodes,
net.IPv6linklocalallnodes,
net.IPv6linklocalallrouters,
net.IPv6loopback,
net.IPv6unspecified,
} {
_ = snd.IsSpecialNetwork(ip)
_ = snd.IsLocallyServedNetwork(ip)
}
}
const goroutinesNum = 50
for i := 0; i < goroutinesNum; i++ {
go testFunc()
}
}

View File

@ -252,7 +252,7 @@ func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
return rc return rc
} }
dctx.isLocalClient = s.subnetDetector.IsLocallyServedNetwork(ip) dctx.isLocalClient = s.privateNets.Contains(ip)
return rc return rc
} }
@ -374,7 +374,7 @@ func (s *Server) processRestrictLocal(ctx *dnsContext) (rc resultCode) {
// Restrict an access to local addresses for external clients. We also // Restrict an access to local addresses for external clients. We also
// assume that all the DHCP leases we give are locally-served or at least // assume that all the DHCP leases we give are locally-served or at least
// don't need to be inaccessible externally. // don't need to be inaccessible externally.
if !s.subnetDetector.IsLocallyServedNetwork(ip) { if !s.privateNets.Contains(ip) {
log.Debug("dns: addr %s is not from locally-served network", ip) log.Debug("dns: addr %s is not from locally-served network", ip)
return resultCodeSuccess return resultCodeSuccess
@ -481,7 +481,7 @@ func (s *Server) processLocalPTR(ctx *dnsContext) (rc resultCode) {
s.serverLock.RLock() s.serverLock.RLock()
defer s.serverLock.RUnlock() defer s.serverLock.RUnlock()
if !s.subnetDetector.IsLocallyServedNetwork(ip) { if !s.privateNets.Contains(ip) {
return resultCodeSuccess return resultCodeSuccess
} }

View File

@ -4,35 +4,41 @@ import (
"net" "net"
"testing" "testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestServer_ProcessDetermineLocal(t *testing.T) { func TestServer_ProcessDetermineLocal(t *testing.T) {
snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err)
s := &Server{ s := &Server{
subnetDetector: snd, privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
} }
testCases := []struct { testCases := []struct {
want assert.BoolAssertionFunc
name string name string
cliIP net.IP cliIP net.IP
want bool
}{{ }{{
want: assert.True,
name: "local", name: "local",
cliIP: net.IP{192, 168, 0, 1}, cliIP: net.IP{192, 168, 0, 1},
want: true,
}, { }, {
want: assert.False,
name: "external", name: "external",
cliIP: net.IP{250, 249, 0, 1}, cliIP: net.IP{250, 249, 0, 1},
want: false, }, {
want: assert.False,
name: "invalid",
cliIP: net.IP{1, 2, 3, 4, 5},
}, {
want: assert.False,
name: "nil",
cliIP: nil,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
@ -47,7 +53,7 @@ func TestServer_ProcessDetermineLocal(t *testing.T) {
} }
s.processDetermineLocal(dctx) s.processDetermineLocal(dctx)
assert.Equal(t, tc.want, dctx.isLocalClient) tc.want(t, dctx.isLocalClient)
}) })
} }
} }

View File

@ -74,7 +74,7 @@ type Server struct {
localDomainSuffix string localDomainSuffix string
ipset ipsetCtx ipset ipsetCtx
subnetDetector *aghnet.SubnetDetector privateNets netutil.SubnetSet
localResolvers *proxy.Proxy localResolvers *proxy.Proxy
sysResolvers aghnet.SystemResolvers sysResolvers aghnet.SystemResolvers
recDetector *recursionDetector recDetector *recursionDetector
@ -115,7 +115,7 @@ type DNSCreateParams struct {
Stats stats.Stats Stats stats.Stats
QueryLog querylog.QueryLog QueryLog querylog.QueryLog
DHCPServer dhcpd.ServerInterface DHCPServer dhcpd.ServerInterface
SubnetDetector *aghnet.SubnetDetector PrivateNets netutil.SubnetSet
Anonymizer *aghnet.IPMut Anonymizer *aghnet.IPMut
LocalDomain string LocalDomain string
} }
@ -161,7 +161,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
dnsFilter: p.DNSFilter, dnsFilter: p.DNSFilter,
stats: p.Stats, stats: p.Stats,
queryLog: p.QueryLog, queryLog: p.QueryLog,
subnetDetector: p.SubnetDetector, privateNets: p.PrivateNets,
localDomainSuffix: localDomainSuffix, localDomainSuffix: localDomainSuffix,
recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum), recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum),
clientIDCache: cache.New(cache.Config{ clientIDCache: cache.New(cache.Config{
@ -315,7 +315,7 @@ func (s *Server) Exchange(ip net.IP) (host string, err error) {
} }
resolver := s.internalProxy resolver := s.internalProxy
if s.subnetDetector.IsLocallyServedNetwork(ip) { if s.privateNets.Contains(ip) {
if !s.conf.UsePrivateRDNS { if !s.conf.UsePrivateRDNS {
return "", nil return "", nil
} }

View File

@ -24,6 +24,7 @@ import (
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/timeutil" "github.com/AdguardTeam/golibs/timeutil"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -69,14 +70,11 @@ func createTestServer(
f := filtering.New(filterConf, filters) f := filtering.New(filterConf, filters)
f.SetEnabled(true) f.SetEnabled(true)
snd, err := aghnet.NewSubnetDetector() var err error
require.NoError(t, err)
require.NotNil(t, snd)
s, err = NewServer(DNSCreateParams{ s, err = NewServer(DNSCreateParams{
DHCPServer: &testDHCP{}, DHCPServer: &testDHCP{},
DNSFilter: f, DNSFilter: f,
SubnetDetector: snd, PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
}) })
require.NoError(t, err) require.NoError(t, err)
@ -770,16 +768,11 @@ func TestBlockedCustomIP(t *testing.T) {
Data: []byte(rules), Data: []byte(rules),
}} }}
snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err)
require.NotNil(t, snd)
f := filtering.New(&filtering.Config{}, filters) f := filtering.New(&filtering.Config{}, filters)
var s *Server s, err := NewServer(DNSCreateParams{
s, err = NewServer(DNSCreateParams{
DHCPServer: &testDHCP{}, DHCPServer: &testDHCP{},
DNSFilter: f, DNSFilter: f,
SubnetDetector: snd, PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
}) })
require.NoError(t, err) require.NoError(t, err)
@ -913,15 +906,10 @@ func TestRewrite(t *testing.T) {
f := filtering.New(c, nil) f := filtering.New(c, nil)
f.SetEnabled(true) f.SetEnabled(true)
snd, err := aghnet.NewSubnetDetector() s, err := NewServer(DNSCreateParams{
require.NoError(t, err)
require.NotNil(t, snd)
var s *Server
s, err = NewServer(DNSCreateParams{
DHCPServer: &testDHCP{}, DHCPServer: &testDHCP{},
DNSFilter: f, DNSFilter: f,
SubnetDetector: snd, PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
}) })
require.NoError(t, err) require.NoError(t, err)
@ -1028,15 +1016,10 @@ func (d *testDHCP) Leases(flags dhcpd.GetLeasesFlags) (leases []*dhcpd.Lease) {
func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {} func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {}
func TestPTRResponseFromDHCPLeases(t *testing.T) { func TestPTRResponseFromDHCPLeases(t *testing.T) {
snd, err := aghnet.NewSubnetDetector() s, err := NewServer(DNSCreateParams{
require.NoError(t, err)
require.NotNil(t, snd)
var s *Server
s, err = NewServer(DNSCreateParams{
DNSFilter: filtering.New(&filtering.Config{}, nil), DNSFilter: filtering.New(&filtering.Config{}, nil),
DHCPServer: &testDHCP{}, DHCPServer: &testDHCP{},
SubnetDetector: snd, PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
}) })
require.NoError(t, err) require.NoError(t, err)
@ -1105,16 +1088,11 @@ func TestPTRResponseFromHosts(t *testing.T) {
}, nil) }, nil)
flt.SetEnabled(true) flt.SetEnabled(true)
var snd *aghnet.SubnetDetector
snd, err = aghnet.NewSubnetDetector()
require.NoError(t, err)
require.NotNil(t, snd)
var s *Server var s *Server
s, err = NewServer(DNSCreateParams{ s, err = NewServer(DNSCreateParams{
DHCPServer: &testDHCP{}, DHCPServer: &testDHCP{},
DNSFilter: flt, DNSFilter: flt,
SubnetDetector: snd, PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
}) })
require.NoError(t, err) require.NoError(t, err)
@ -1227,9 +1205,7 @@ func TestServer_Exchange(t *testing.T) {
srv.conf.ResolveClients = true srv.conf.ResolveClients = true
srv.conf.UsePrivateRDNS = true srv.conf.UsePrivateRDNS = true
var err error srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
srv.subnetDetector, err = aghnet.NewSubnetDetector()
require.NoError(t, err)
localIP := net.IP{192, 168, 1, 1} localIP := net.IP{192, 168, 1, 1}
testCases := []struct { testCases := []struct {

View File

@ -4,7 +4,6 @@ import (
"net" "net"
"testing" "testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
@ -39,14 +38,10 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
f := filtering.New(&filtering.Config{}, filters) f := filtering.New(&filtering.Config{}, filters)
f.SetEnabled(true) f.SetEnabled(true)
snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err)
require.NotNil(t, snd)
s, err := NewServer(DNSCreateParams{ s, err := NewServer(DNSCreateParams{
DHCPServer: &testDHCP{}, DHCPServer: &testDHCP{},
DNSFilter: f, DNSFilter: f,
SubnetDetector: snd, PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
}) })
require.NoError(t, err) require.NoError(t, err)

View File

@ -10,7 +10,6 @@ import (
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
@ -167,7 +166,7 @@ func (req *dnsConfig) checkBootstrap() (err error) {
} }
// validate returns an error if any field of req is invalid. // validate returns an error if any field of req is invalid.
func (req *dnsConfig) validate(snd *aghnet.SubnetDetector) (err error) { func (req *dnsConfig) validate(privateNets netutil.SubnetSet) (err error) {
if req.Upstreams != nil { if req.Upstreams != nil {
err = ValidateUpstreams(*req.Upstreams) err = ValidateUpstreams(*req.Upstreams)
if err != nil { if err != nil {
@ -176,7 +175,7 @@ func (req *dnsConfig) validate(snd *aghnet.SubnetDetector) (err error) {
} }
if req.LocalPTRUpstreams != nil { if req.LocalPTRUpstreams != nil {
err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, snd) err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, privateNets)
if err != nil { if err != nil {
return fmt.Errorf("validating private upstream servers: %w", err) return fmt.Errorf("validating private upstream servers: %w", err)
} }
@ -224,7 +223,7 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
return return
} }
err = req.validate(s.subnetDetector) err = req.validate(s.privateNets)
if err != nil { if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
@ -350,17 +349,6 @@ func IsCommentOrEmpty(s string) (ok bool) {
return len(s) == 0 || s[0] == '#' return len(s) == 0 || s[0] == '#'
} }
// LocalNetChecker is used to check if the IP address belongs to a local
// network.
type LocalNetChecker interface {
// IsLocallyServedNetwork returns true if ip is contained in any of address
// registries defined by RFC 6303.
IsLocallyServedNetwork(ip net.IP) (ok bool)
}
// type check
var _ LocalNetChecker = (*aghnet.SubnetDetector)(nil)
// newUpstreamConfig validates upstreams and returns an appropriate upstream // newUpstreamConfig validates upstreams and returns an appropriate upstream
// configuration or nil if it can't be built. // configuration or nil if it can't be built.
// //
@ -422,8 +410,8 @@ func stringKeysSorted(m map[string][]upstream.Upstream) (sorted []string) {
// ValidateUpstreamsPrivate validates each upstream and returns an error if any // ValidateUpstreamsPrivate validates each upstream and returns an error if any
// upstream is invalid or if there are no default upstreams specified. It also // upstream is invalid or if there are no default upstreams specified. It also
// checks each domain of domain-specific upstreams for being ARPA pointing to // checks each domain of domain-specific upstreams for being ARPA pointing to
// a locally-served network. lnc must not be nil. // a locally-served network. privateNets must not be nil.
func ValidateUpstreamsPrivate(upstreams []string, lnc LocalNetChecker) (err error) { func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) {
conf, err := newUpstreamConfig(upstreams) conf, err := newUpstreamConfig(upstreams)
if err != nil { if err != nil {
return err return err
@ -444,7 +432,7 @@ func ValidateUpstreamsPrivate(upstreams []string, lnc LocalNetChecker) (err erro
continue continue
} }
if !lnc.IsLocallyServedNetwork(subnet.IP) { if !privateNets.Contains(subnet.IP) {
errs = append( errs = append(
errs, errs,
fmt.Errorf("arpa domain %q should point to a locally-served network", domain), fmt.Errorf("arpa domain %q should point to a locally-served network", domain),

View File

@ -14,6 +14,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -410,8 +411,7 @@ func TestValidateUpstreams(t *testing.T) {
} }
func TestValidateUpstreamsPrivate(t *testing.T) { func TestValidateUpstreamsPrivate(t *testing.T) {
snd, err := aghnet.NewSubnetDetector() ss := netutil.SubnetSetFunc(netutil.IsLocallyServed)
require.NoError(t, err)
testCases := []struct { testCases := []struct {
name string name string
@ -452,7 +452,7 @@ func TestValidateUpstreamsPrivate(t *testing.T) {
set := []string{"192.168.0.1", tc.u} set := []string{"192.168.0.1", tc.u}
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
err = ValidateUpstreamsPrivate(set, snd) err := ValidateUpstreamsPrivate(set, ss)
testutil.AssertErrorMsg(t, tc.wantErr, err) testutil.AssertErrorMsg(t, tc.wantErr, err)
}) })
} }

View File

@ -126,6 +126,10 @@ type dnsConfig struct {
// ResolveClients enables and disables resolving clients with RDNS. // ResolveClients enables and disables resolving clients with RDNS.
ResolveClients bool `yaml:"resolve_clients"` ResolveClients bool `yaml:"resolve_clients"`
// PrivateNets is the set of IP networks for which the private reverse DNS
// resolver should be used.
PrivateNets []string `yaml:"private_networks"`
// UsePrivateRDNS defines if the PTR requests for unknown addresses from // UsePrivateRDNS defines if the PTR requests for unknown addresses from
// locally-served networks should be resolved via private PTR resolvers. // locally-served networks should be resolved via private PTR resolvers.
UsePrivateRDNS bool `yaml:"use_private_ptr_resolvers"` UsePrivateRDNS bool `yaml:"use_private_ptr_resolvers"`

View File

@ -77,11 +77,34 @@ func initDNSServer() (err error) {
filterConf.HTTPRegister = httpRegister filterConf.HTTPRegister = httpRegister
Context.dnsFilter = filtering.New(&filterConf, nil) Context.dnsFilter = filtering.New(&filterConf, nil)
var privateNets netutil.SubnetSet
switch len(config.DNS.PrivateNets) {
case 0:
// Use an optimized locally-served matcher.
privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
case 1:
var n *net.IPNet
n, err = netutil.ParseSubnet(config.DNS.PrivateNets[0])
if err != nil {
return fmt.Errorf("preparing the set of private subnets: %w", err)
}
privateNets = n
default:
var nets []*net.IPNet
nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...)
if err != nil {
return fmt.Errorf("preparing the set of private subnets: %w", err)
}
privateNets = netutil.SliceSubnetSet(nets)
}
p := dnsforward.DNSCreateParams{ p := dnsforward.DNSCreateParams{
DNSFilter: Context.dnsFilter, DNSFilter: Context.dnsFilter,
Stats: Context.stats, Stats: Context.stats,
QueryLog: Context.queryLog, QueryLog: Context.queryLog,
SubnetDetector: Context.subnetDetector, PrivateNets: privateNets,
Anonymizer: anonymizer, Anonymizer: anonymizer,
LocalDomain: config.DHCP.LocalDomainName, LocalDomain: config.DHCP.LocalDomainName,
} }
@ -133,7 +156,7 @@ func onDNSRequest(pctx *proxy.DNSContext) {
if config.DNS.ResolveClients && !ip.IsLoopback() { if config.DNS.ResolveClients && !ip.IsLoopback() {
Context.rdns.Begin(ip) Context.rdns.Begin(ip)
} }
if !Context.subnetDetector.IsSpecialNetwork(ip) { if !netutil.IsSpecialPurpose(ip) {
Context.whois.Begin(ip) Context.whois.Begin(ip)
} }
} }
@ -360,10 +383,14 @@ func startDNSServer() error {
const topClientsNumber = 100 // the number of clients to get const topClientsNumber = 100 // the number of clients to get
for _, ip := range Context.stats.GetTopClientsIP(topClientsNumber) { for _, ip := range Context.stats.GetTopClientsIP(topClientsNumber) {
if ip == nil {
continue
}
if config.DNS.ResolveClients && !ip.IsLoopback() { if config.DNS.ResolveClients && !ip.IsLoopback() {
Context.rdns.Begin(ip) Context.rdns.Begin(ip)
} }
if !Context.subnetDetector.IsSpecialNetwork(ip) { if !netutil.IsSpecialPurpose(ip) {
Context.whois.Begin(ip) Context.whois.Begin(ip)
} }
} }

View File

@ -66,8 +66,6 @@ type homeContext struct {
updater *updater.Updater updater *updater.Updater
subnetDetector *aghnet.SubnetDetector
// mux is our custom http.ServeMux. // mux is our custom http.ServeMux.
mux *http.ServeMux mux *http.ServeMux
@ -477,9 +475,6 @@ func run(args options, clientBuildFS fs.FS) {
Context.web, err = initWeb(args, clientBuildFS) Context.web, err = initWeb(args, clientBuildFS)
fatalOnError(err) fatalOnError(err)
Context.subnetDetector, err = aghnet.NewSubnetDetector()
fatalOnError(err)
if !Context.firstRun { if !Context.firstRun {
err = initDNSServer() err = initDNSServer()
fatalOnError(err) fatalOnError(err)