From 86444eacc2d38666b949b37163e9e53b499119eb Mon Sep 17 00:00:00 2001 From: Eugene Burkov Date: Wed, 31 Mar 2021 15:00:47 +0300 Subject: [PATCH] Pull request: 2704 local addresses vol.2 Merge in DNS/adguard-home from 2704-local-addresses-vol.2 to master Updates #2704. Updates #2829. Squashed commit of the following: commit 507d038c2709de59246fc0b65c3c4ab8e38d1990 Author: Eugene Burkov Date: Wed Mar 31 14:33:05 2021 +0300 aghtest: fix file name commit 8e19f99337bee1d88ad6595adb96f9bb23fa3c41 Author: Eugene Burkov Date: Wed Mar 31 14:06:43 2021 +0300 aghnet: rm redundant mutexes commit 361fa418b33ed160ca20862be1c455ab9378c03f Author: Eugene Burkov Date: Wed Mar 31 13:45:30 2021 +0300 all: fix names, docs commit 14034f4f0230d7aaa3645054946ae5c278089a99 Merge: 35e265cc a72ce1cf Author: Eugene Burkov Date: Wed Mar 31 13:38:15 2021 +0300 Merge branch 'master' into 2704-local-addresses-vol.2 commit 35e265cc8cd308ef1fda414b58c0217cb5f258e4 Author: Eugene Burkov Date: Wed Mar 31 13:33:35 2021 +0300 aghnet: imp naming commit 7a7edac7208a40697d7bc50682b923a144e28e2b Author: Eugene Burkov Date: Tue Mar 30 20:59:54 2021 +0300 changelog: oops, nope yet commit d26a5d2513daf662ac92053b5e235189a64cc022 Author: Eugene Burkov Date: Tue Mar 30 20:55:53 2021 +0300 all: some renaming for the glory of semantics commit 9937fa619452b0742616217b975e3ff048d58acb Author: Eugene Burkov Date: Mon Mar 29 15:34:42 2021 +0300 all: log changes commit d8d9e6dfeea8474466ee25f27021efdd3ddb1592 Author: Eugene Burkov Date: Fri Mar 26 18:32:23 2021 +0300 all: imp localresolver, imp cutting off own addresses commit 344140df449b85925f19b460fd7dc7c08e29c35a Author: Eugene Burkov Date: Fri Mar 26 14:53:33 2021 +0300 all: imp code quality commit 1c5c0babec73b125044e23dd3aa75d8eefc19b28 Author: Eugene Burkov Date: Thu Mar 25 20:44:08 2021 +0300 all: fix go.mod commit 0b9fb3c2369a752e893af8ddc45a86bb9fb27ce5 Author: Eugene Burkov Date: Thu Mar 25 20:38:51 2021 +0300 all: add error handling commit a7a2e51f57fc6f8f74b95a264ad345cd2a9e026e Merge: c13be634 27f4f052 Author: Eugene Burkov Date: Thu Mar 25 19:48:36 2021 +0300 Merge branch 'master' into 2704-local-addresses-vol.2 commit c13be634f47bcaed9320a732a51c0e4752d0dad0 Author: Eugene Burkov Date: Thu Mar 25 18:52:28 2021 +0300 all: cover rdns with tests, imp aghnet functionality commit 48bed9025944530c613ee53e7961d6d5fbabf8be Author: Eugene Burkov Date: Wed Mar 24 20:18:07 2021 +0300 home: make rdns great again commit 1dbacfc8d5b6895807797998317fe3cc814617c1 Author: Eugene Burkov Date: Wed Mar 24 16:07:52 2021 +0300 all: imp external client restriction commit 1208a319a7f4ffe7b7fa8956f245d7a19437c0a4 Author: Eugene Burkov Date: Mon Mar 22 15:26:45 2021 +0300 all: finish local ptr processor commit c8827fc3db289e1a5d7a11d057743bab39957b02 Author: Eugene Burkov Date: Tue Mar 2 13:41:22 2021 +0300 all: imp ipdetector, add local ptr processor --- CHANGELOG.md | 3 + HACKING.md | 4 +- internal/aghnet/exchanger.go | 79 ++++++ internal/aghnet/exchanger_test.go | 64 +++++ internal/aghnet/ipdetector.go | 73 ------ internal/aghnet/net.go | 152 +++++++++-- internal/aghnet/net_darwin.go | 2 +- internal/aghnet/net_test.go | 91 +++++++ internal/aghnet/subnetdetector.go | 155 +++++++++++ ...etector_test.go => subnetdetector_test.go} | 111 +++++++- internal/aghnet/systemresolvers.go | 8 +- internal/aghnet/systemresolvers_others.go | 4 +- .../aghnet/systemresolvers_others_test.go | 2 +- internal/aghnet/systemresolvers_windows.go | 2 +- internal/aghtest/exchanger.go | 20 ++ internal/aghtest/upstream.go | 10 +- internal/dhcpd/http.go | 2 +- internal/dnsforward/dns.go | 184 ++++++++++--- internal/dnsforward/dns_test.go | 74 ++++++ internal/dnsforward/dnsforward.go | 19 +- internal/dnsforward/dnsforward_test.go | 64 ++++- internal/dnsforward/filter.go | 2 + internal/home/clients.go | 3 +- internal/home/dns.go | 20 +- internal/home/home.go | 112 +++++++- internal/home/rdns.go | 156 ++++++----- internal/home/rdns_test.go | 247 +++++++++++++++++- internal/util/autohosts.go | 3 +- internal/util/autohosts_test.go | 42 --- internal/util/dns.go | 70 ----- 30 files changed, 1418 insertions(+), 360 deletions(-) create mode 100644 internal/aghnet/exchanger.go create mode 100644 internal/aghnet/exchanger_test.go delete mode 100644 internal/aghnet/ipdetector.go create mode 100644 internal/aghnet/subnetdetector.go rename internal/aghnet/{ipdetector_test.go => subnetdetector_test.go} (54%) create mode 100644 internal/aghtest/exchanger.go delete mode 100644 internal/util/dns.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 870842c8..007b41fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,8 @@ and this project adheres to ### Changed +- The reverse lookup for local addresses is now performed via local resolvers + ([#2704]). - Stricter validation of the IP addresses of static leases in the DHCP server with regards to the netmask ([#2838]). - Stricter validation of `$dnsrewrite` filter modifier parameters ([#2498]). @@ -50,6 +52,7 @@ and this project adheres to [#2498]: https://github.com/AdguardTeam/AdGuardHome/issues/2498 [#2533]: https://github.com/AdguardTeam/AdGuardHome/issues/2533 [#2541]: https://github.com/AdguardTeam/AdGuardHome/issues/2541 +[#2704]: https://github.com/AdguardTeam/AdGuardHome/issues/2704 [#2828]: https://github.com/AdguardTeam/AdGuardHome/issues/2828 [#2835]: https://github.com/AdguardTeam/AdGuardHome/issues/2835 [#2838]: https://github.com/AdguardTeam/AdGuardHome/issues/2838 diff --git a/HACKING.md b/HACKING.md index 462c3f9c..07308745 100644 --- a/HACKING.md +++ b/HACKING.md @@ -188,8 +188,8 @@ on GitHub and most other Markdown renderers. --> ### Formatting - * Add an empty line before `break`, `continue`, `fallthrough`, and `return`, - unless it's the only statement in that block. + * Decorate `break`, `continue`, `fallthrough`, `return`, and other function + exit points with empty lines unless it's the only statement in that block. * Use `gofumpt --extra -s`. diff --git a/internal/aghnet/exchanger.go b/internal/aghnet/exchanger.go new file mode 100644 index 00000000..2ddeb7ad --- /dev/null +++ b/internal/aghnet/exchanger.go @@ -0,0 +1,79 @@ +package aghnet + +import ( + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/agherr" + "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/miekg/dns" +) + +// This package is not the best place for this functionality, but we put it here +// since we need to use it in both rDNS (home) and dnsServer (dnsforward). + +// NoUpstreamsErr should be returned when there are no upstreams inside +// Exchanger implementation. +const NoUpstreamsErr agherr.Error = "no upstreams specified" + +// Exchanger represents an object able to resolve DNS messages. +// +// TODO(e.burkov): Maybe expand with method like ExchangeParallel to be able to +// use user's upstream mode settings. Also, think about Update method to +// refresh the internal state. +type Exchanger interface { + Exchange(req *dns.Msg) (resp *dns.Msg, err error) +} + +// multiAddrExchanger is the default implementation of Exchanger interface. +type multiAddrExchanger struct { + ups []upstream.Upstream +} + +// NewMultiAddrExchanger creates an Exchanger instance from passed addresses. +// It returns an error if any of addrs failed to become an upstream. +func NewMultiAddrExchanger(addrs []string, timeout time.Duration) (e Exchanger, err error) { + defer agherr.Annotate("exchanger: %w", &err) + + if len(addrs) == 0 { + return &multiAddrExchanger{}, nil + } + + var ups []upstream.Upstream = make([]upstream.Upstream, 0, len(addrs)) + for _, addr := range addrs { + var u upstream.Upstream + u, err = upstream.AddressToUpstream(addr, upstream.Options{Timeout: timeout}) + if err != nil { + return nil, err + } + + ups = append(ups, u) + } + + return &multiAddrExchanger{ups: ups}, nil +} + +// Exсhange performs a query to each resolver until first response. +func (e *multiAddrExchanger) Exchange(req *dns.Msg) (resp *dns.Msg, err error) { + defer agherr.Annotate("exchanger: %w", &err) + + // TODO(e.burkov): Maybe prohibit the initialization without upstreams. + if len(e.ups) == 0 { + return nil, NoUpstreamsErr + } + + var errs []error + for _, u := range e.ups { + resp, err = u.Exchange(req) + if err != nil { + errs = append(errs, err) + + continue + } + + if resp != nil { + return resp, nil + } + } + + return nil, agherr.Many("can't exchange", errs...) +} diff --git a/internal/aghnet/exchanger_test.go b/internal/aghnet/exchanger_test.go new file mode 100644 index 00000000..774bec86 --- /dev/null +++ b/internal/aghnet/exchanger_test.go @@ -0,0 +1,64 @@ +package aghnet + +import ( + "testing" + + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewMultiAddrExchanger(t *testing.T) { + var e Exchanger + var err error + + t.Run("empty", func(t *testing.T) { + e, err = NewMultiAddrExchanger([]string{}, 0) + require.NoError(t, err) + assert.NotNil(t, e) + }) + + t.Run("successful", func(t *testing.T) { + e, err = NewMultiAddrExchanger([]string{"www.example.com"}, 0) + require.NoError(t, err) + assert.NotNil(t, e) + }) + + t.Run("unsuccessful", func(t *testing.T) { + e, err = NewMultiAddrExchanger([]string{"invalid-proto://www.example.com"}, 0) + require.Error(t, err) + assert.Nil(t, e) + }) +} + +func TestMultiAddrExchanger_Exchange(t *testing.T) { + e := &multiAddrExchanger{} + + t.Run("error", func(t *testing.T) { + e.ups = []upstream.Upstream{&aghtest.TestErrUpstream{}} + + resp, err := e.Exchange(nil) + require.Error(t, err) + assert.Nil(t, resp) + }) + + t.Run("success", func(t *testing.T) { + e.ups = []upstream.Upstream{&aghtest.TestUpstream{ + Reverse: map[string][]string{ + "abc": {"cba"}, + }, + }} + + resp, err := e.Exchange(&dns.Msg{ + Question: []dns.Question{{ + Name: "abc", + Qtype: dns.TypePTR, + }}, + }) + require.NoError(t, err) + require.Len(t, resp.Answer, 1) + assert.Equal(t, "cba", resp.Answer[0].Header().Name) + }) +} diff --git a/internal/aghnet/ipdetector.go b/internal/aghnet/ipdetector.go deleted file mode 100644 index 7fa9414c..00000000 --- a/internal/aghnet/ipdetector.go +++ /dev/null @@ -1,73 +0,0 @@ -package aghnet - -import "net" - -// IPDetector describes IP address properties. -type IPDetector struct { - nets []*net.IPNet -} - -// NewIPDetector returns a new IP detector. -func NewIPDetector() (ipd *IPDetector, err error) { - specialNetworks := []string{ - "0.0.0.0/8", - "10.0.0.0/8", - "100.64.0.0/10", - "127.0.0.0/8", - "169.254.0.0/16", - "172.16.0.0/12", - "192.0.0.0/24", - "192.0.0.0/29", - "192.0.2.0/24", - "192.88.99.0/24", - "192.168.0.0/16", - "198.18.0.0/15", - "198.51.100.0/24", - "203.0.113.0/24", - "240.0.0.0/4", - "255.255.255.255/32", - "::1/128", - "::/128", - "64:ff9b::/96", - // Since this network is used for mapping IPv4 addresses, we - // don't include it. - // "::ffff:0:0/96", - "100::/64", - "2001::/23", - "2001::/32", - "2001:2::/48", - "2001:db8::/32", - "2001:10::/28", - "2002::/16", - "fc00::/7", - "fe80::/10", - } - - ipd = &IPDetector{ - nets: make([]*net.IPNet, len(specialNetworks)), - } - for i, ipnetStr := range specialNetworks { - var ipnet *net.IPNet - _, ipnet, err = net.ParseCIDR(ipnetStr) - if err != nil { - return nil, err - } - - ipd.nets[i] = ipnet - } - - return ipd, nil -} - -// DetectSpecialNetwork returns true if IP address is contained by any of -// special-purpose IP address registries according to RFC-6890 -// (https://tools.ietf.org/html/rfc6890). -func (ipd *IPDetector) DetectSpecialNetwork(ip net.IP) bool { - for _, ipnet := range ipd.nets { - if ipnet.Contains(ip) { - return true - } - } - - return false -} diff --git a/internal/aghnet/net.go b/internal/aghnet/net.go index d49e5573..fd36fe24 100644 --- a/internal/aghnet/net.go +++ b/internal/aghnet/net.go @@ -97,25 +97,10 @@ func (iface *NetInterface) MarshalJSON() ([]byte, error) { }) } -// GetValidNetInterfaces returns interfaces that are eligible for DNS and/or DHCP -// invalid interface is a ppp interface or the one that doesn't allow broadcasts -func GetValidNetInterfaces() ([]net.Interface, error) { - ifaces, err := net.Interfaces() - if err != nil { - return nil, fmt.Errorf("couldn't get list of interfaces: %w", err) - } - - netIfaces := []net.Interface{} - - netIfaces = append(netIfaces, ifaces...) - - return netIfaces, nil -} - // GetValidNetInterfacesForWeb returns interfaces that are eligible for DNS and WEB only // we do not return link-local addresses here func GetValidNetInterfacesForWeb() ([]*NetInterface, error) { - ifaces, err := GetValidNetInterfaces() + ifaces, err := net.Interfaces() if err != nil { return nil, fmt.Errorf("couldn't get interfaces: %w", err) } @@ -273,3 +258,138 @@ func SplitHost(hostport string) (host string, err error) { return host, nil } + +// TODO(e.burkov): Inspect the charToHex, ipParseARPA6, ipReverse and +// UnreverseAddr and maybe refactor it. + +// charToHex converts character to a hexadecimal. +func charToHex(n byte) int8 { + if n >= '0' && n <= '9' { + return int8(n) - '0' + } else if (n|0x20) >= 'a' && (n|0x20) <= 'f' { + return (int8(n) | 0x20) - 'a' + 10 + } + return -1 +} + +// ipParseARPA6 parse IPv6 reverse address +func ipParseARPA6(s string) (ip6 net.IP) { + if len(s) != 63 { + return nil + } + + ip6 = make(net.IP, 16) + + for i := 0; i != 64; i += 4 { + // parse "0.1." + n := charToHex(s[i]) + n2 := charToHex(s[i+2]) + if s[i+1] != '.' || (i != 60 && s[i+3] != '.') || + n < 0 || n2 < 0 { + return nil + } + + ip6[16-i/4-1] = byte(n2<<4) | byte(n&0x0f) + } + return ip6 +} + +// ipReverse inverts byte order of ip. +func ipReverse(ip net.IP) (rev net.IP) { + ipLen := len(ip) + rev = make(net.IP, ipLen) + for i, b := range ip { + rev[ipLen-i-1] = b + } + + return rev +} + +// ARPA addresses' suffixes. +const ( + arpaV4Suffix = ".in-addr.arpa" + arpaV6Suffix = ".ip6.arpa" +) + +// UnreverseAddr tries to convert reversed ARPA to a normal IP address. +func UnreverseAddr(arpa string) (unreversed net.IP) { + // Unify the input data. + arpa = strings.TrimSuffix(arpa, ".") + arpa = strings.ToLower(arpa) + + if strings.HasSuffix(arpa, arpaV4Suffix) { + ip := strings.TrimSuffix(arpa, arpaV4Suffix) + ip4 := net.ParseIP(ip).To4() + if ip4 == nil { + return nil + } + + return ipReverse(ip4) + + } else if strings.HasSuffix(arpa, arpaV6Suffix) { + ip := strings.TrimSuffix(arpa, arpaV6Suffix) + return ipParseARPA6(ip) + } + + // The suffix unrecognizable. + return nil +} + +// The length of extreme cases of arpa formatted addresses. +// +// The example of IPv4 with maximum length: +// +// 49.91.20.104.in-addr.arpa +// +// The example of IPv6 with maximum length: +// +// 1.3.b.5.4.1.8.6.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.0.0.0.7.4.6.0.6.2.ip6.arpa +// +const ( + arpaV4MaxLen = len("000.000.000.000") + len(arpaV4Suffix) + arpaV6MaxLen = len("0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0") + + len(arpaV6Suffix) +) + +// ReverseAddr returns the ARPA hostname of the ip suitable for reverse DNS +// (PTR) record lookups. This is the modified version of ReverseAddr from +// github.com/miekg/dns package with no error among returned values. +func ReverseAddr(ip net.IP) (arpa string) { + var strLen int + var suffix string + // Don't handle errors in implementations since strings.WriteString + // never returns non-nil errors. + var writeByte func(val byte) + b := &strings.Builder{} + if ip4 := ip.To4(); ip4 != nil { + strLen, suffix = arpaV4MaxLen, arpaV4Suffix[1:] + ip = ip4 + writeByte = func(val byte) { + _, _ = b.WriteString(strconv.Itoa(int(val))) + _, _ = b.WriteRune('.') + } + + } else if ip6 := ip.To16(); ip6 != nil { + strLen, suffix = arpaV6MaxLen, arpaV6Suffix[1:] + ip = ip6 + writeByte = func(val byte) { + lByte, rByte := val&0xF, val>>4 + + _, _ = b.WriteString(strconv.FormatUint(uint64(lByte), 16)) + _, _ = b.WriteRune('.') + _, _ = b.WriteString(strconv.FormatUint(uint64(rByte), 16)) + _, _ = b.WriteRune('.') + } + + } else { + return "" + } + + b.Grow(strLen) + for i := len(ip) - 1; i >= 0; i-- { + writeByte(ip[i]) + } + _, _ = b.WriteString(suffix) + + return b.String() +} diff --git a/internal/aghnet/net_darwin.go b/internal/aghnet/net_darwin.go index 926e87e5..bd715569 100644 --- a/internal/aghnet/net_darwin.go +++ b/internal/aghnet/net_darwin.go @@ -31,7 +31,7 @@ func ifaceHasStaticIP(ifaceName string) (bool, error) { return portInfo.static, nil } -// getCurrentHardwarePortInfo gets information the specified network interface. +// getCurrentHardwarePortInfo gets information for the specified network interface. func getCurrentHardwarePortInfo(ifaceName string) (hardwarePortInfo, error) { // First of all we should find hardware port name m := getNetworkSetupHardwareReports() diff --git a/internal/aghnet/net_test.go b/internal/aghnet/net_test.go index 3cd2fd6a..9c5afd07 100644 --- a/internal/aghnet/net_test.go +++ b/internal/aghnet/net_test.go @@ -1,8 +1,10 @@ package aghnet import ( + "net" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -14,3 +16,92 @@ func TestGetValidNetInterfacesForWeb(t *testing.T) { require.NotEmptyf(t, iface.Addresses, "no addresses found for %s", iface.Name) } } + +func TestUnreverseAddr(t *testing.T) { + testCases := []struct { + name string + have string + want net.IP + }{{ + name: "good_ipv4", + have: "1.0.0.127.in-addr.arpa", + want: net.IP{127, 0, 0, 1}, + }, { + name: "good_ipv6", + have: "4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa", + want: net.ParseIP("::abcd:1234"), + }, { + name: "good_ipv6_case", + have: "4.3.2.1.d.c.B.A.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.iP6.ArPa", + want: net.ParseIP("::abcd:1234"), + }, { + name: "good_ipv4_dot", + have: "1.0.0.127.in-addr.arpa.", + want: net.IP{127, 0, 0, 1}, + }, { + name: "good_ipv4_case", + have: "1.0.0.127.In-Addr.Arpa", + want: net.IP{127, 0, 0, 1}, + }, { + name: "wrong_ipv4", + have: ".0.0.127.in-addr.arpa", + want: nil, + }, { + name: "wrong_ipv6", + have: ".3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa", + want: nil, + }, { + name: "bad_ipv6_dot", + have: "4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0..ip6.arpa", + want: nil, + }, { + name: "bad_ipv6_space", + have: "4.3.2.1.d.c.b. .0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa", + want: nil, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ip := UnreverseAddr(tc.have) + assert.True(t, tc.want.Equal(ip)) + }) + } +} + +func TestReverseAddr(t *testing.T) { + testCases := []struct { + name string + want string + ip net.IP + }{{ + name: "valid_ipv4", + want: "4.3.2.1.in-addr.arpa", + ip: net.IP{1, 2, 3, 4}, + }, { + name: "valid_ipv6", + want: "1.3.b.5.4.1.8.6.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.0.0.0.7.4.6.0.6.2.ip6.arpa", + ip: net.ParseIP("2606:4700:10::6814:5b31"), + }, { + name: "nil_ip", + want: "", + ip: nil, + }, { + name: "unspecified_ipv6", + want: "0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa", + ip: net.IPv6unspecified, + }, { + name: "unspecified_ipv4", + want: "0.0.0.0.in-addr.arpa", + ip: net.IPv4zero, + }, { + name: "wrong_length_ip", + want: "", + ip: net.IP{1, 2, 3, 4, 5}, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.want, ReverseAddr(tc.ip)) + }) + } +} diff --git a/internal/aghnet/subnetdetector.go b/internal/aghnet/subnetdetector.go new file mode 100644 index 00000000..e610929c --- /dev/null +++ b/internal/aghnet/subnetdetector.go @@ -0,0 +1,155 @@ +package aghnet + +import ( + "net" +) + +// SubnetDetector describes IP address properties. +type SubnetDetector struct { + // spNets is the slice of special-purpose address registries as defined + // by RFC-6890 (https://tools.ietf.org/html/rfc6890). + spNets []*net.IPNet + + // locServedNets is the slice of locally-served networks as defined by + // RFC-6303 (https://tools.ietf.org/html/rfc6303). + locServedNets []*net.IPNet +} + +// NewSubnetDetector returns a new IP detector. +func NewSubnetDetector() (snd *SubnetDetector, err error) { + spNets := []string{ + // "This" network. + "0.0.0.0/8", + // Private-Use Networks. + "10.0.0.0/8", + // Shared Address Space. + "100.64.0.0/10", + // Loopback. + "127.0.0.0/8", + // Link Local. + "169.254.0.0/16", + // Private-Use Networks. + "172.16.0.0/12", + // IETF Protocol Assignments. + "192.0.0.0/24", + // DS-Lite. + "192.0.0.0/29", + // TEST-NET-1 + "192.0.2.0/24", + // 6to4 Relay Anycast. + "192.88.99.0/24", + // Private-Use Networks. + "192.168.0.0/16", + // Network Interconnect Device Benchmark Testing. + "198.18.0.0/15", + // TEST-NET-2. + "198.51.100.0/24", + // TEST-NET-3. + "203.0.113.0/24", + // Reserved for Future Use. + "240.0.0.0/4", + // Limited Broadcast. + "255.255.255.255/32", + + // Loopback. + "::1/128", + // Unspecified. + "::/128", + // IPv4-IPv6 Translation Address. + "64:ff9b::/96", + + // IPv4-Mapped Address. Since this network is used for mapping + // IPv4 addresses, we don't include it. + // "::ffff:0:0/96", + + // Discard-Only Prefix. + "100::/64", + // IETF Protocol Assignments. + "2001::/23", + // TEREDO. + "2001::/32", + // Benchmarking. + "2001:2::/48", + // Documentation. + "2001:db8::/32", + // ORCHID. + "2001:10::/28", + // 6to4. + "2002::/16", + // Unique-Local. + "fc00::/7", + // Linked-Scoped Unicast. + "fe80::/10", + } + + // TODO(e.burkov): It's a subslice of the slice above. Should be done + // smarter. + locServedNets := []string{ + // IPv4. + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + "127.0.0.0/8", + "169.254.0.0/16", + "192.0.2.0/24", + "198.51.100.0/24", + "203.0.113.0/24", + "255.255.255.255/32", + // IPv6. + "::/128", + "::1/128", + "fe80::/10", + "2001:db8::/32", + } + + snd = &SubnetDetector{ + spNets: make([]*net.IPNet, len(spNets)), + locServedNets: make([]*net.IPNet, len(locServedNets)), + } + for i, ipnetStr := range spNets { + var ipnet *net.IPNet + _, ipnet, err = net.ParseCIDR(ipnetStr) + if err != nil { + return nil, err + } + + snd.spNets[i] = ipnet + } + for i, ipnetStr := range locServedNets { + var ipnet *net.IPNet + _, ipnet, err = net.ParseCIDR(ipnetStr) + if err != nil { + return nil, err + } + + snd.locServedNets[i] = ipnet + } + + return snd, nil +} + +// anyNetContains ranges through the given ipnets slice searching for the one +// which contains the ip. For internal use only. +// +// TODO(e.burkov): Think about memoization. +func anyNetContains(ipnets *[]*net.IPNet, ip net.IP) (is bool) { + for _, ipnet := range *ipnets { + if ipnet.Contains(ip) { + return true + } + } + + return false +} + +// IsSpecialNetwork returns true if IP address is contained by any of +// special-purpose IP address registries. It's safe for concurrent use. +func (snd *SubnetDetector) IsSpecialNetwork(ip net.IP) (is bool) { + return anyNetContains(&snd.spNets, ip) +} + +// IsLocallyServedNetwork returns true if IP address is contained by any of +// locally-served IP address registries. It's safe for concurrent use. +func (snd *SubnetDetector) IsLocallyServedNetwork(ip net.IP) (is bool) { + return anyNetContains(&snd.locServedNets, ip) +} diff --git a/internal/aghnet/ipdetector_test.go b/internal/aghnet/subnetdetector_test.go similarity index 54% rename from internal/aghnet/ipdetector_test.go rename to internal/aghnet/subnetdetector_test.go index 07c89c9e..8f2fa4b9 100644 --- a/internal/aghnet/ipdetector_test.go +++ b/internal/aghnet/subnetdetector_test.go @@ -8,11 +8,8 @@ import ( "github.com/stretchr/testify/require" ) -func TestIPDetector_detectSpecialNetwork(t *testing.T) { - var ipd *IPDetector - var err error - - ipd, err = NewIPDetector() +func TestSubnetDetector_DetectSpecialNetwork(t *testing.T) { + snd, err := NewSubnetDetector() require.NoError(t, err) testCases := []struct { @@ -139,7 +136,109 @@ func TestIPDetector_detectSpecialNetwork(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - assert.Equal(t, tc.want, ipd.DetectSpecialNetwork(tc.ip)) + assert.Equal(t, tc.want, snd.IsSpecialNetwork(tc.ip)) }) } } + +func TestSubnetDetector_DetectLocallyServedNetwork(t *testing.T) { + snd, err := NewSubnetDetector() + require.NoError(t, err) + + testCases := []struct { + name string + ip net.IP + want bool + }{{ + name: "not_specific", + ip: net.ParseIP("8.8.8.8"), + want: false, + }, { + name: "private-Use", + ip: net.ParseIP("10.0.0.0"), + want: true, + }, { + name: "loopback", + ip: net.ParseIP("127.0.0.0"), + want: true, + }, { + name: "link_local", + ip: net.ParseIP("169.254.0.0"), + want: true, + }, { + name: "private-use", + ip: net.ParseIP("172.16.0.0"), + want: true, + }, { + name: "documentation_(test-net-1)", + ip: net.ParseIP("192.0.2.0"), + want: true, + }, { + name: "private-use", + ip: net.ParseIP("192.168.0.0"), + want: true, + }, { + name: "documentation_(test-net-2)", + ip: net.ParseIP("198.51.100.0"), + want: true, + }, { + name: "documentation_(test-net-3)", + ip: net.ParseIP("203.0.113.0"), + want: true, + }, { + name: "limited_broadcast", + ip: net.ParseIP("255.255.255.255"), + want: true, + }, { + name: "loopback_address", + ip: net.ParseIP("::1"), + want: true, + }, { + name: "unspecified_address", + ip: net.ParseIP("::"), + want: true, + }, { + name: "documentation", + ip: net.ParseIP("2001:db8::"), + want: true, + }, { + name: "linked-scoped_unicast", + ip: net.ParseIP("fe80::"), + want: true, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.want, snd.IsLocallyServedNetwork(tc.ip)) + }) + } +} + +func TestSubnetDetector_Detect_parallel(t *testing.T) { + t.Parallel() + + snd, err := NewSubnetDetector() + require.NoError(t, err) + + testFunc := func() { + for _, ip := range []net.IP{ + net.IPv4allrouter, + net.IPv4allsys, + net.IPv4bcast, + net.IPv4zero, + net.IPv6interfacelocalallnodes, + net.IPv6linklocalallnodes, + net.IPv6linklocalallrouters, + net.IPv6loopback, + net.IPv6unspecified, + } { + _ = snd.IsSpecialNetwork(ip) + _ = snd.IsLocallyServedNetwork(ip) + } + } + + const goroutinesNum = 50 + for i := 0; i < goroutinesNum; i++ { + go testFunc() + } +} diff --git a/internal/aghnet/systemresolvers.go b/internal/aghnet/systemresolvers.go index 4a0ae6ca..4cba0f92 100644 --- a/internal/aghnet/systemresolvers.go +++ b/internal/aghnet/systemresolvers.go @@ -23,9 +23,9 @@ type SystemResolvers interface { // Get returns the slice of local resolvers' addresses. // It should be safe for concurrent use. Get() (rs []string) - // Refresh refreshes the local resolvers' addresses cache. It should be + // refresh refreshes the local resolvers' addresses cache. It should be // safe for concurrent use. - Refresh() (err error) + refresh() (err error) } const ( @@ -42,7 +42,7 @@ func refreshWithTicker(sr SystemResolvers, tickCh <-chan time.Time) { // TODO(e.burkov): Implement a functionality to stop ticker. for range tickCh { - err := sr.Refresh() + err := sr.refresh() if err != nil { log.Error("systemResolvers: error in refreshing goroutine: %s", err) @@ -63,7 +63,7 @@ func NewSystemResolvers( sr = newSystemResolvers(refreshIvl, hostGenFunc) // Fill cache. - err = sr.Refresh() + err = sr.refresh() if err != nil { return nil, err } diff --git a/internal/aghnet/systemresolvers_others.go b/internal/aghnet/systemresolvers_others.go index ad67cfdb..975ff744 100644 --- a/internal/aghnet/systemresolvers_others.go +++ b/internal/aghnet/systemresolvers_others.go @@ -29,7 +29,7 @@ type systemResolvers struct { addrsLock sync.RWMutex } -func (sr *systemResolvers) Refresh() (err error) { +func (sr *systemResolvers) refresh() (err error) { defer agherr.Annotate("systemResolvers: %w", &err) _, err = sr.resolver.LookupHost(context.Background(), sr.hostGenFunc()) @@ -75,7 +75,7 @@ func (sr *systemResolvers) dialFunc(_ context.Context, _, address string) (_ net sr.addrsLock.Lock() defer sr.addrsLock.Unlock() - sr.addrs[address] = unit{} + sr.addrs[host] = unit{} return nil, fakeDialErr } diff --git a/internal/aghnet/systemresolvers_others_test.go b/internal/aghnet/systemresolvers_others_test.go index 972247b4..f86cdabf 100644 --- a/internal/aghnet/systemresolvers_others_test.go +++ b/internal/aghnet/systemresolvers_others_test.go @@ -31,7 +31,7 @@ func TestSystemResolvers_Refresh(t *testing.T) { t.Run("expected_error", func(t *testing.T) { sr := createTestSystemResolvers(t, 0, nil) - assert.NoError(t, sr.Refresh()) + assert.NoError(t, sr.refresh()) }) t.Run("unexpected_error", func(t *testing.T) { diff --git a/internal/aghnet/systemresolvers_windows.go b/internal/aghnet/systemresolvers_windows.go index c918b44a..75e0a758 100644 --- a/internal/aghnet/systemresolvers_windows.go +++ b/internal/aghnet/systemresolvers_windows.go @@ -138,7 +138,7 @@ func (sr *systemResolvers) getAddrs() (addrs []string, err error) { return addrs, nil } -func (sr *systemResolvers) Refresh() (err error) { +func (sr *systemResolvers) refresh() (err error) { defer agherr.Annotate("systemResolvers: %w", &err) got, err := sr.getAddrs() diff --git a/internal/aghtest/exchanger.go b/internal/aghtest/exchanger.go new file mode 100644 index 00000000..d68a3566 --- /dev/null +++ b/internal/aghtest/exchanger.go @@ -0,0 +1,20 @@ +package aghtest + +import ( + "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/miekg/dns" +) + +// Exchanger is a mock aghnet.Exchanger implementation for tests. +type Exchanger struct { + Ups upstream.Upstream +} + +// Exchange implements aghnet.Exchanger interface for *Exchanger. +func (lr *Exchanger) Exchange(req *dns.Msg) (resp *dns.Msg, err error) { + if lr.Ups == nil { + lr.Ups = &TestErrUpstream{} + } + + return lr.Ups.Exchange(req) +} diff --git a/internal/aghtest/upstream.go b/internal/aghtest/upstream.go index 5cf4925d..44c6a6ce 100644 --- a/internal/aghtest/upstream.go +++ b/internal/aghtest/upstream.go @@ -3,7 +3,6 @@ package aghtest import ( "crypto/sha256" "encoding/hex" - "errors" "fmt" "net" "strings" @@ -71,7 +70,7 @@ func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { for _, n := range names { resp.Answer = append(resp.Answer, &dns.PTR{ Hdr: dns.RR_Header{ - Name: name, + Name: n, Rrtype: rrType, }, Ptr: n, @@ -162,14 +161,17 @@ func (u *TestBlockUpstream) RequestsCount() int { // TestErrUpstream implements upstream.Upstream interface for replacing real // upstream in tests. -type TestErrUpstream struct{} +type TestErrUpstream struct { + // The error returned by Exchange may be unwraped to the Err. + Err error +} // Exchange always returns nil Msg and non-nil error. func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) { // We don't use an agherr.Error to avoid the import cycle since aghtests // used to provide the utilities for testing which agherr (and any other // testable package) should be able to use. - return nil, errors.New("bad") + return nil, fmt.Errorf("errupstream: %w", u.Err) } // Address always returns an empty string. diff --git a/internal/dhcpd/http.go b/internal/dhcpd/http.go index 24a73b77..2fbf76c0 100644 --- a/internal/dhcpd/http.go +++ b/internal/dhcpd/http.go @@ -266,7 +266,7 @@ type netInterfaceJSON struct { func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) { response := map[string]netInterfaceJSON{} - ifaces, err := aghnet.GetValidNetInterfaces() + ifaces, err := net.Interfaces() if err != nil { httpError(r, w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err) return diff --git a/internal/dnsforward/dns.go b/internal/dnsforward/dns.go index 9b1ce6ef..4fffcc21 100644 --- a/internal/dnsforward/dns.go +++ b/internal/dnsforward/dns.go @@ -1,13 +1,14 @@ package dnsforward import ( + "errors" "net" "strings" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/dnsfilter" - "github.com/AdguardTeam/AdGuardHome/internal/util" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/golibs/log" "github.com/miekg/dns" @@ -26,6 +27,9 @@ type dnsContext struct { // origResp is the response received from upstream. It is set when the // response is modified by filters. origResp *dns.Msg + // unreversedReqIP stores an IP address obtained from PTR request if it + // was successfully parsed. + unreversedReqIP net.IP // err is the error returned from a processing function. err error // clientID is the clientID from DOH, DOQ, or DOT, if provided. @@ -78,9 +82,11 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error { mods := []modProcessFunc{ processInitial, s.processInternalHosts, - processInternalIPAddrs, + s.processRestrictLocal, + s.processInternalIPAddrs, processClientID, processFilteringBeforeRequest, + s.processLocalPTR, processUpstream, processDNSSECAfterResponse, processFilteringAfterResponse, @@ -185,6 +191,29 @@ func (s *Server) onDHCPLeaseChanged(flags int) { s.tablePTRLock.Unlock() } +// hostToIP tries to get an IP leased by DHCP and returns the copy of address +// since the data inside the internal table may be changed while request +// processing. It's safe for concurrent use. +func (s *Server) hostToIP(host string) (ip net.IP, ok bool) { + s.tableHostToIPLock.Lock() + defer s.tableHostToIPLock.Unlock() + + if s.tableHostToIP == nil { + return nil, false + } + + var ipFromTable net.IP + ipFromTable, ok = s.tableHostToIP[host] + if !ok { + return nil, false + } + + ip = make(net.IP, len(ipFromTable)) + copy(ip, ipFromTable) + + return ip, true +} + // processInternalHosts respond to A requests if the target hostname is known to // the server. // @@ -206,13 +235,9 @@ func (s *Server) processInternalHosts(dctx *dnsContext) (rc resultCode) { return resultCodeSuccess } - s.tableHostToIPLock.Lock() - if s.tableHostToIP == nil { - s.tableHostToIPLock.Unlock() - return resultCodeSuccess - } - ip, ok := s.tableHostToIP[host] - s.tableHostToIPLock.Unlock() + // TODO(e.burkov): Restrict the access for external clients. + + ip, ok := s.hostToIP(host) if !ok { return resultCodeSuccess } @@ -220,62 +245,143 @@ func (s *Server) processInternalHosts(dctx *dnsContext) (rc resultCode) { log.Debug("dns: internal record: %s -> %s", q.Name, ip) resp := s.makeResponse(req) - if q.Qtype == dns.TypeA { a := &dns.A{ Hdr: s.hdr(req, dns.TypeA), - A: make([]byte, len(ip)), + A: ip, } - - copy(a.A, ip) resp.Answer = append(resp.Answer, a) } - dctx.proxyCtx.Res = resp return resultCodeSuccess } -// Respond to PTR requests if the target IP address is leased by our DHCP server -func processInternalIPAddrs(ctx *dnsContext) (rc resultCode) { - s := ctx.srv - req := ctx.proxyCtx.Req - if req.Question[0].Qtype != dns.TypePTR { +// processRestrictLocal responds with empty answers to PTR requests for IP +// addresses in locally-served network from external clients. +func (s *Server) processRestrictLocal(ctx *dnsContext) (rc resultCode) { + d := ctx.proxyCtx + req := d.Req + q := req.Question[0] + if q.Qtype != dns.TypePTR { + // No need for restriction. return resultCodeSuccess } - arpa := req.Question[0].Name - arpa = strings.TrimSuffix(arpa, ".") - arpa = strings.ToLower(arpa) - ip := util.DNSUnreverseAddr(arpa) + ip := aghnet.UnreverseAddr(q.Name) + if ip == nil { + // That's weird. + // + // TODO(e.burkov): Research the cases when it could happen. + return resultCodeSuccess + } + + // Restrict an access to local addresses for external clients. We also + // assume that all the DHCP leases we give are locally-served or at + // least don't need to be unaccessable externally. + if s.subnetDetector.IsLocallyServedNetwork(ip) { + clientIP := IPFromAddr(d.Addr) + if !s.subnetDetector.IsLocallyServedNetwork(clientIP) { + log.Debug("dns: %q requests for internal ip", clientIP) + d.Res = s.makeResponse(req) + + // Do not even put into query log. + return resultCodeFinish + } + } + + // Do not perform unreversing ever again. + ctx.unreversedReqIP = ip + + // Nothing to restrict. + return resultCodeSuccess +} + +// ipToHost tries to get a hostname leased by DHCP. It's safe for concurrent +// use. +func (s *Server) ipToHost(ip net.IP) (host string, ok bool) { + s.tablePTRLock.Lock() + defer s.tablePTRLock.Unlock() + + if s.tablePTR == nil { + return "", false + } + + host, ok = s.tablePTR[ip.String()] + + return host, ok +} + +// Respond to PTR requests if the target IP is leased by our DHCP server and the +// requestor is inside the local network. +func (s *Server) processInternalIPAddrs(ctx *dnsContext) (rc resultCode) { + d := ctx.proxyCtx + if d.Res != nil { + return resultCodeSuccess + } + + ip := ctx.unreversedReqIP if ip == nil { return resultCodeSuccess } - s.tablePTRLock.Lock() - if s.tablePTR == nil { - s.tablePTRLock.Unlock() - return resultCodeSuccess - } - host, ok := s.tablePTR[ip.String()] - s.tablePTRLock.Unlock() + host, ok := s.ipToHost(ip) if !ok { return resultCodeSuccess } - log.Debug("dns: reverse-lookup: %s -> %s", arpa, host) + log.Debug("dns: reverse-lookup: %s -> %s", ip, host) + req := d.Req resp := s.makeResponse(req) - ptr := &dns.PTR{} - ptr.Hdr = dns.RR_Header{ - Name: req.Question[0].Name, - Rrtype: dns.TypePTR, - Ttl: s.conf.BlockedResponseTTL, - Class: dns.ClassINET, + ptr := &dns.PTR{ + Hdr: dns.RR_Header{ + Name: req.Question[0].Name, + Rrtype: dns.TypePTR, + Ttl: s.conf.BlockedResponseTTL, + Class: dns.ClassINET, + }, + Ptr: dns.Fqdn(host), } - ptr.Ptr = host + "." resp.Answer = append(resp.Answer, ptr) - ctx.proxyCtx.Res = resp + d.Res = resp + + return resultCodeSuccess +} + +// processLocalPTR responds to PTR requests if the target IP is detected to be +// inside the local network and the query was not answered from DHCP. +func (s *Server) processLocalPTR(ctx *dnsContext) (rc resultCode) { + d := ctx.proxyCtx + if d.Res != nil { + return resultCodeSuccess + } + + ip := ctx.unreversedReqIP + if ip == nil { + return resultCodeSuccess + } + + if !s.subnetDetector.IsLocallyServedNetwork(ip) { + return resultCodeSuccess + } + + req := d.Req + resp, err := s.localResolvers.Exchange(req) + if err != nil { + if errors.Is(err, aghnet.NoUpstreamsErr) { + d.Res = s.genNXDomain(req) + + return resultCodeFinish + } + + ctx.err = err + + return resultCodeError + } + + d.Res = resp + return resultCodeSuccess } diff --git a/internal/dnsforward/dns_test.go b/internal/dnsforward/dns_test.go index 188d0705..d3b91466 100644 --- a/internal/dnsforward/dns_test.go +++ b/internal/dnsforward/dns_test.go @@ -4,7 +4,10 @@ import ( "net" "testing" + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/AdGuardHome/internal/dnsfilter" "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/AdguardTeam/dnsproxy/upstream" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -120,3 +123,74 @@ func TestServer_ProcessInternalHosts(t *testing.T) { }) } } + +func TestLocalRestriction(t *testing.T) { + s := createTestServer(t, &dnsfilter.Config{}, ServerConfig{ + UDPListenAddrs: []*net.UDPAddr{{}}, + TCPListenAddrs: []*net.TCPAddr{{}}, + }) + ups := &aghtest.TestUpstream{ + Reverse: map[string][]string{ + "251.252.253.254.in-addr.arpa.": {"host1.example.net."}, + "1.1.168.192.in-addr.arpa.": {"some.local-client."}, + }, + } + s.localResolvers = &aghtest.Exchanger{Ups: ups} + s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ups} + startDeferStop(t, s) + + testCases := []struct { + name string + want string + question net.IP + cliIP net.IP + wantLen int + }{{ + name: "from_local_to_external", + want: "host1.example.net.", + question: net.IP{254, 253, 252, 251}, + cliIP: net.IP{192, 168, 10, 10}, + wantLen: 1, + }, { + name: "from_external_for_local", + want: "", + question: net.IP{192, 168, 1, 1}, + cliIP: net.IP{254, 253, 252, 251}, + wantLen: 0, + }, { + name: "from_local_for_local", + want: "some.local-client.", + question: net.IP{192, 168, 1, 1}, + cliIP: net.IP{192, 168, 1, 2}, + wantLen: 1, + }, { + name: "from_external_for_external", + want: "host1.example.net.", + question: net.IP{254, 253, 252, 251}, + cliIP: net.IP{254, 253, 252, 255}, + wantLen: 1, + }} + + for _, tc := range testCases { + reqAddr, err := dns.ReverseAddr(tc.question.String()) + require.NoError(t, err) + req := createTestMessageWithType(reqAddr, dns.TypePTR) + + pctx := &proxy.DNSContext{ + Proto: proxy.ProtoTCP, + Req: req, + Addr: &net.TCPAddr{ + IP: tc.cliIP, + }, + } + t.Run(tc.name, func(t *testing.T) { + err = s.handleDNSRequest(nil, pctx) + require.Nil(t, err) + require.NotNil(t, pctx.Res) + require.Len(t, pctx.Res.Answer, tc.wantLen) + if tc.wantLen > 0 { + assert.Equal(t, tc.want, pctx.Res.Answer[0].Header().Name) + } + }) + } +} diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index 85acf201..adab01a3 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -11,6 +11,7 @@ import ( "sync" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/dnsfilter" "github.com/AdguardTeam/AdGuardHome/internal/querylog" @@ -60,7 +61,9 @@ type Server struct { // be a valid top-level domain plus dots on each side. autohostSuffix string - ipset ipsetCtx + ipset ipsetCtx + subnetDetector *aghnet.SubnetDetector + localResolvers aghnet.Exchanger tableHostToIP map[string]net.IP // "hostname -> IP" table for internal addresses (DHCP) tableHostToIPLock sync.Mutex @@ -84,11 +87,13 @@ const defaultAutohostSuffix = ".lan." // DNSCreateParams are parameters to create a new server. type DNSCreateParams struct { - DNSFilter *dnsfilter.DNSFilter - Stats stats.Stats - QueryLog querylog.QueryLog - DHCPServer dhcpd.ServerInterface - AutohostTLD string + DNSFilter *dnsfilter.DNSFilter + Stats stats.Stats + QueryLog querylog.QueryLog + DHCPServer dhcpd.ServerInterface + SubnetDetector *aghnet.SubnetDetector + LocalResolvers aghnet.Exchanger + AutohostTLD string } // tldToSuffix converts a top-level domain into an autohost suffix. @@ -121,6 +126,8 @@ func NewServer(p DNSCreateParams) (s *Server, err error) { dnsFilter: p.DNSFilter, stats: p.Stats, queryLog: p.QueryLog, + subnetDetector: p.SubnetDetector, + localResolvers: p.LocalResolvers, autohostSuffix: autohostSuffix, } diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index 5758f2d5..177e5a2a 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -18,6 +18,7 @@ import ( "testing" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/dnsfilter" @@ -64,7 +65,16 @@ func createTestServer(t *testing.T, filterConf *dnsfilter.Config, forwardConf Se f := dnsfilter.New(filterConf, filters) - s, err := NewServer(DNSCreateParams{DNSFilter: f}) + snd, err := aghnet.NewSubnetDetector() + require.NoError(t, err) + require.NotNil(t, snd) + + var s *Server + s, err = NewServer(DNSCreateParams{ + DNSFilter: f, + SubnetDetector: snd, + LocalResolvers: &aghtest.Exchanger{}, + }) require.NoError(t, err) s.conf = forwardConf @@ -710,8 +720,15 @@ func TestBlockedCustomIP(t *testing.T) { Data: []byte(rules), }} - s, err := NewServer(DNSCreateParams{ - DNSFilter: dnsfilter.New(&dnsfilter.Config{}, filters), + snd, err := aghnet.NewSubnetDetector() + require.NoError(t, err) + require.NotNil(t, snd) + + var s *Server + s, err = NewServer(DNSCreateParams{ + DNSFilter: dnsfilter.New(&dnsfilter.Config{}, filters), + SubnetDetector: snd, + LocalResolvers: &aghtest.Exchanger{}, }) require.NoError(t, err) @@ -841,18 +858,26 @@ func TestRewrite(t *testing.T) { } f := dnsfilter.New(c, nil) - s, err := NewServer(DNSCreateParams{DNSFilter: f}) + snd, err := aghnet.NewSubnetDetector() + require.NoError(t, err) + require.NotNil(t, snd) + + var s *Server + s, err = NewServer(DNSCreateParams{ + DNSFilter: f, + SubnetDetector: snd, + LocalResolvers: &aghtest.Exchanger{}, + }) require.NoError(t, err) - err = s.Prepare(&ServerConfig{ + assert.NoError(t, s.Prepare(&ServerConfig{ UDPListenAddrs: []*net.UDPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}}, FilteringConfig: FilteringConfig{ ProtectionEnabled: true, UpstreamDNS: []string{"8.8.8.8:53"}, }, - }) - assert.NoError(t, err) + })) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ &aghtest.TestUpstream{ @@ -1134,9 +1159,16 @@ func (d *testDHCP) Leases(flags int) []dhcpd.Lease { func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {} func TestPTRResponseFromDHCPLeases(t *testing.T) { - s, err := NewServer(DNSCreateParams{ - DNSFilter: dnsfilter.New(&dnsfilter.Config{}, nil), - DHCPServer: &testDHCP{}, + snd, err := aghnet.NewSubnetDetector() + require.NoError(t, err) + require.NotNil(t, snd) + + var s *Server + s, err = NewServer(DNSCreateParams{ + DNSFilter: dnsfilter.New(&dnsfilter.Config{}, nil), + DHCPServer: &testDHCP{}, + SubnetDetector: snd, + LocalResolvers: &aghtest.Exchanger{}, }) require.NoError(t, err) @@ -1192,7 +1224,17 @@ func TestPTRResponseFromHosts(t *testing.T) { c.AutoHosts.Init(hf.Name()) t.Cleanup(c.AutoHosts.Close) - s, err := NewServer(DNSCreateParams{DNSFilter: dnsfilter.New(&c, nil)}) + var snd *aghnet.SubnetDetector + snd, err = aghnet.NewSubnetDetector() + require.NoError(t, err) + require.NotNil(t, snd) + + var s *Server + s, err = NewServer(DNSCreateParams{ + DNSFilter: dnsfilter.New(&c, nil), + SubnetDetector: snd, + LocalResolvers: &aghtest.Exchanger{}, + }) require.NoError(t, err) s.conf.UDPListenAddrs = []*net.UDPAddr{{}} diff --git a/internal/dnsforward/filter.go b/internal/dnsforward/filter.go index 3e2cebd2..8b1c3283 100644 --- a/internal/dnsforward/filter.go +++ b/internal/dnsforward/filter.go @@ -46,6 +46,8 @@ func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *dnsfilter.F // was filtered. func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) { d := ctx.proxyCtx + // TODO(e.burkov): Consistently use req instead of d.Req since it is + // declared. req := d.Req host := strings.TrimSuffix(req.Question[0].Name, ".") res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, ctx.setts) diff --git a/internal/home/clients.go b/internal/home/clients.go index 49ea1a71..9e0d0fd3 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -591,8 +591,9 @@ func (clients *clientsContainer) SetWhoisInfo(ip string, info [][]string) { // taken into account. ok is true if the pairing was added. func (clients *clientsContainer) AddHost(ip, host string, src clientSource) (ok bool, err error) { clients.lock.Lock() + defer clients.lock.Unlock() + ok = clients.addHostLocked(ip, host, src) - clients.lock.Unlock() return ok, nil } diff --git a/internal/home/dns.go b/internal/home/dns.go index d9795560..5e629d4d 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -61,10 +61,12 @@ func initDNSServer() error { Context.dnsFilter = dnsfilter.New(&filterConf, nil) p := dnsforward.DNSCreateParams{ - DNSFilter: Context.dnsFilter, - Stats: Context.stats, - QueryLog: Context.queryLog, - AutohostTLD: config.DNS.AutohostTLD, + DNSFilter: Context.dnsFilter, + Stats: Context.stats, + QueryLog: Context.queryLog, + SubnetDetector: Context.subnetDetector, + LocalResolvers: Context.localResolvers, + AutohostTLD: config.DNS.AutohostTLD, } if Context.dhcpServer != nil { p.DHCPServer = Context.dhcpServer @@ -91,7 +93,7 @@ func initDNSServer() error { return fmt.Errorf("dnsServer.Prepare: %w", err) } - Context.rdns = InitRDNS(Context.dnsServer, &Context.clients) + Context.rdns = NewRDNS(Context.dnsServer, &Context.clients, Context.subnetDetector, Context.localResolvers) Context.whois = initWhois(&Context.clients) Context.filters.Init() @@ -105,14 +107,14 @@ func isRunning() bool { func onDNSRequest(d *proxy.DNSContext) { ip := dnsforward.IPFromAddr(d.Addr) if ip == nil { - // This would be quite weird if we get here + // This would be quite weird if we get here. return } if !ip.IsLoopback() { Context.rdns.Begin(ip) } - if !Context.ipDetector.DetectSpecialNetwork(ip) { + if !Context.subnetDetector.IsSpecialNetwork(ip) { Context.whois.Begin(ip) } } @@ -333,10 +335,10 @@ func startDNSServer() error { const topClientsNumber = 100 // the number of clients to get for _, ip := range Context.stats.GetTopClientsIP(topClientsNumber) { - if !ip.IsLoopback() { + if !Context.subnetDetector.IsLocallyServedNetwork(ip) { Context.rdns.Begin(ip) } - if !Context.ipDetector.DetectSpecialNetwork(ip) { + if !Context.subnetDetector.IsSpecialNetwork(ip) { Context.whois.Begin(ip) } } diff --git a/internal/home/home.go b/internal/home/home.go index 5c2b2d3f..efbdf0ba 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -61,7 +61,9 @@ type homeContext struct { autoHosts util.AutoHosts // IP-hostname pairs taken from system configuration (e.g. /etc/hosts) files updater *updater.Updater - ipDetector *aghnet.IPDetector + subnetDetector *aghnet.SubnetDetector + systemResolvers aghnet.SystemResolvers + localResolvers aghnet.Exchanger // mux is our custom http.ServeMux. mux *http.ServeMux @@ -220,6 +222,110 @@ func setupConfig(args options) { } } +const defaultLocalTimeout = 5 * time.Second + +// stringsSetSubtract subtracts b from a interpreted as sets. +// +// TODO(e.burkov): Move into our internal package for working with strings. +func stringsSetSubtract(a, b []string) (c []string) { + // unit is an object to be used as value in set. + type unit = struct{} + + cSet := make(map[string]unit) + for _, k := range a { + cSet[k] = unit{} + } + + for _, k := range b { + delete(cSet, k) + } + + c = make([]string, len(cSet)) + i := 0 + for k := range cSet { + c[i] = k + i++ + } + + return c +} + +// collectAllIfacesAddrs returns the slice of all network interfaces IP +// addresses without port number. +func collectAllIfacesAddrs() (addrs []string, err error) { + var ifaces []net.Interface + ifaces, err = net.Interfaces() + if err != nil { + return nil, fmt.Errorf("getting network interfaces: %w", err) + } + + for _, iface := range ifaces { + var ifaceAddrs []net.Addr + ifaceAddrs, err = iface.Addrs() + if err != nil { + return nil, fmt.Errorf("getting addresses for %q: %w", iface.Name, err) + } + + for _, addr := range ifaceAddrs { + cidr := addr.String() + var ip net.IP + ip, _, err = net.ParseCIDR(cidr) + if err != nil { + return nil, fmt.Errorf("parsing %q as cidr: %w", cidr, err) + } + + addrs = append(addrs, ip.String()) + } + } + + return addrs, nil +} + +// collectDNSIPAddrs returns the slice of IP addresses without port number which +// we are listening on. +func collectDNSIPaddrs() (addrs []string, err error) { + addrs = make([]string, len(config.DNS.BindHosts)) + + for i, bh := range config.DNS.BindHosts { + if bh.IsUnspecified() { + return collectAllIfacesAddrs() + } + + addrs[i] = bh.String() + } + + return addrs, nil +} + +func setupResolvers() { + // TODO(e.burkov): Enhance when the config will contain local resolvers + // addresses. + + sysRes, err := aghnet.NewSystemResolvers(0, nil) + if err != nil { + log.Fatal(err) + } + + Context.systemResolvers = sysRes + + var ourAddrs []string + ourAddrs, err = collectDNSIPaddrs() + if err != nil { + log.Fatal(err) + } + + // TODO(e.burkov): The approach of subtracting sets of strings is not + // really applicable here since in case of listening on all network + // interfaces we should check the whole interface's network to cut off + // all the loopback addresses as well. + addrs := stringsSetSubtract(sysRes.Get(), ourAddrs) + + Context.localResolvers, err = aghnet.NewMultiAddrExchanger(addrs, defaultLocalTimeout) + if err != nil { + log.Fatal(err) + } +} + // run performs configurating and starts AdGuard Home. func run(args options) { // configure config filename @@ -305,11 +411,13 @@ func run(args options) { log.Fatalf("Can't initialize Web module") } - Context.ipDetector, err = aghnet.NewIPDetector() + Context.subnetDetector, err = aghnet.NewSubnetDetector() if err != nil { log.Fatal(err) } + setupResolvers() + if !Context.firstRun { err = initDNSServer() if err != nil { diff --git a/internal/home/rdns.go b/internal/home/rdns.go index c21a6f6e..55df779c 100644 --- a/internal/home/rdns.go +++ b/internal/home/rdns.go @@ -2,129 +2,163 @@ package home import ( "encoding/binary" + "fmt" "net" "strings" "time" + "github.com/AdguardTeam/AdGuardHome/internal/agherr" + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/log" "github.com/miekg/dns" ) -// RDNS - module context +// RDNS resolves clients' addresses to enrich their metadata. type RDNS struct { - dnsServer *dnsforward.Server - clients *clientsContainer - ipChannel chan net.IP // pass data from DNS request handling thread to rDNS thread + dnsServer *dnsforward.Server + clients *clientsContainer + subnetDetector *aghnet.SubnetDetector + localResolvers aghnet.Exchanger - // Contains IP addresses of clients to be resolved by rDNS - // If IP address is resolved, it stays here while it's inside Clients. - // If it's removed from Clients, this IP address will be resolved once again. - // If IP address couldn't be resolved, it stays here for some time to prevent further attempts to resolve the same IP. - ipAddrs cache.Cache + // ipCh used to pass client's IP to rDNS workerLoop. + ipCh chan net.IP + + // ipCache caches the IP addresses to be resolved by rDNS. The resolved + // address stays here while it's inside clients. After leaving clients + // the address will be resolved once again. If the address couldn't be + // resolved, cache prevents further attempts to resolve it for some + // time. + ipCache cache.Cache } -// InitRDNS - create module context -func InitRDNS(dnsServer *dnsforward.Server, clients *clientsContainer) *RDNS { - r := &RDNS{ - dnsServer: dnsServer, - clients: clients, - ipAddrs: cache.New(cache.Config{ +// Default rDNS values. +const ( + defaultRDNSCacheSize = 10000 + defaultRDNSCacheTTL = 1 * 60 * 60 + defaultRDNSIPChSize = 256 +) + +// NewRDNS creates and returns initialized RDNS. +func NewRDNS( + dnsServer *dnsforward.Server, + clients *clientsContainer, + snd *aghnet.SubnetDetector, + lr aghnet.Exchanger, +) (rDNS *RDNS) { + rDNS = &RDNS{ + dnsServer: dnsServer, + clients: clients, + subnetDetector: snd, + localResolvers: lr, + ipCache: cache.New(cache.Config{ EnableLRU: true, - MaxCount: 10000, + MaxCount: defaultRDNSCacheSize, }), - ipChannel: make(chan net.IP, 256), + ipCh: make(chan net.IP, defaultRDNSIPChSize), } - go r.workerLoop() - return r + go rDNS.workerLoop() + + return rDNS } -// Begin - add IP address to rDNS queue +// Begin adds the ip to the resolving queue if it is not cached or already +// resolved. func (r *RDNS) Begin(ip net.IP) { now := uint64(time.Now().Unix()) - expire := r.ipAddrs.Get(ip) - if len(expire) != 0 { - exp := binary.BigEndian.Uint64(expire) - if exp > now { + if expire := r.ipCache.Get(ip); len(expire) != 0 { + if binary.BigEndian.Uint64(expire) > now { return } - // TTL expired } - expire = make([]byte, 8) - const ttl = 1 * 60 * 60 - binary.BigEndian.PutUint64(expire, now+ttl) - _ = r.ipAddrs.Set(ip, expire) + + // The cache entry either expired or doesn't exist. + ttl := make([]byte, 8) + binary.BigEndian.PutUint64(ttl, now+defaultRDNSCacheTTL) + r.ipCache.Set(ip, ttl) id := ip.String() if r.clients.Exists(id, ClientSourceRDNS) { return } - log.Tracef("rDNS: adding %s", ip) select { - case r.ipChannel <- ip: - // + case r.ipCh <- ip: + log.Tracef("rdns: %q added to queue", ip) default: - log.Tracef("rDNS: queue is full") + log.Tracef("rdns: queue is full") } } -// Use rDNS to get hostname by IP address -func (r *RDNS) resolve(ip net.IP) string { - log.Tracef("Resolving host for %s", ip) +const ( + // rDNSEmptyAnswerErr is returned by RDNS resolve method when the answer + // section of respond is empty. + rDNSEmptyAnswerErr agherr.Error = "the answer section is empty" - name, err := dns.ReverseAddr(ip.String()) - if err != nil { - log.Debug("Error while calling dns.ReverseAddr(%s): %s", ip, err) - return "" - } + // rDNSNotPTRErr is returned by RDNS resolve method when the response is + // not of PTR type. + rDNSNotPTRErr agherr.Error = "the response is not a ptr" +) - resp, err := r.dnsServer.Exchange(&dns.Msg{ +// resolve tries to resolve the ip in a suitable way. +func (r *RDNS) resolve(ip net.IP) (host string, err error) { + log.Tracef("rdns: resolving host for %q", ip) + + arpa := dns.Fqdn(aghnet.ReverseAddr(ip)) + msg := &dns.Msg{ MsgHdr: dns.MsgHdr{ Id: dns.Id(), RecursionDesired: true, }, + Compress: true, Question: []dns.Question{{ - Name: name, + Name: arpa, Qtype: dns.TypePTR, Qclass: dns.ClassINET, }}, - }) + } + + var resp *dns.Msg + if r.subnetDetector.IsLocallyServedNetwork(ip) { + resp, err = r.localResolvers.Exchange(msg) + } else { + resp, err = r.dnsServer.Exchange(msg) + } if err != nil { - log.Debug("Error while making an rDNS lookup for %s: %s", ip, err) - return "" + return "", fmt.Errorf("performing lookup for %q: %w", arpa, err) } + if len(resp.Answer) == 0 { - log.Debug("No answer for rDNS lookup of %s", ip) - return "" + return "", fmt.Errorf("lookup for %q: %w", arpa, rDNSEmptyAnswerErr) } + ptr, ok := resp.Answer[0].(*dns.PTR) if !ok { - log.Debug("not a PTR response for %s", ip) - return "" + return "", fmt.Errorf("type checking: %w", rDNSNotPTRErr) } - log.Tracef("PTR response for %s: %s", ip, ptr.String()) - if strings.HasSuffix(ptr.Ptr, ".") { - ptr.Ptr = ptr.Ptr[:len(ptr.Ptr)-1] - } + log.Tracef("rdns: ptr response for %q: %s", ip, ptr.String()) - return ptr.Ptr + return strings.TrimSuffix(ptr.Ptr, "."), nil } -// Wait for a signal and then synchronously resolve hostname by IP address -// Add the hostname:IP pair to "Clients" array +// workerLoop handles incoming IP addresses from ipChan and adds it into +// clients. func (r *RDNS) workerLoop() { - for { - ip := <-r.ipChannel + defer agherr.LogPanic("rdns") + + for ip := range r.ipCh { + host, err := r.resolve(ip) + if err != nil { + log.Error("rdns: resolving %q: %s", ip, err) - host := r.resolve(ip) - if len(host) == 0 { continue } + // Don't handle any errors since AddHost doesn't return non-nil + // errors for now. _, _ = r.clients.AddHost(ip.String(), host, ClientSourceRDNS) } } diff --git a/internal/home/rdns_test.go b/internal/home/rdns_test.go index b17efdd8..0e313ef6 100644 --- a/internal/home/rdns_test.go +++ b/internal/home/rdns_test.go @@ -1,32 +1,265 @@ package home import ( + "bytes" + "encoding/binary" + "errors" "net" + "sync" "testing" + "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/cache" + "github.com/AdguardTeam/golibs/log" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func TestResolveRDNS(t *testing.T) { - ups := &aghtest.TestUpstream{ +func TestRDNS_Begin(t *testing.T) { + aghtest.ReplaceLogLevel(t, log.DEBUG) + w := &bytes.Buffer{} + aghtest.ReplaceLogWriter(t, w) + + ip1234, ip1235 := net.IP{1, 2, 3, 4}, net.IP{1, 2, 3, 5} + + testCases := []struct { + cliIDIndex map[string]*Client + customChan chan net.IP + name string + wantLog string + req net.IP + wantCacheHit int + wantCacheMiss int + }{{ + cliIDIndex: map[string]*Client{}, + customChan: nil, + name: "cached", + wantLog: "", + req: ip1234, + wantCacheHit: 1, + wantCacheMiss: 0, + }, { + cliIDIndex: map[string]*Client{}, + customChan: nil, + name: "not_cached", + wantLog: "rdns: queue is full", + req: ip1235, + wantCacheHit: 0, + wantCacheMiss: 1, + }, { + cliIDIndex: map[string]*Client{"1.2.3.5": {}}, + customChan: nil, + name: "already_in_clients", + wantLog: "", + req: ip1235, + wantCacheHit: 0, + wantCacheMiss: 1, + }, { + cliIDIndex: map[string]*Client{}, + customChan: make(chan net.IP, 1), + name: "add_to_queue", + wantLog: `rdns: "1.2.3.5" added to queue`, + req: ip1235, + wantCacheHit: 0, + wantCacheMiss: 1, + }} + + for _, tc := range testCases { + w.Reset() + + ipCache := cache.New(cache.Config{ + EnableLRU: true, + MaxCount: defaultRDNSCacheSize, + }) + ttl := make([]byte, binary.Size(uint64(0))) + binary.BigEndian.PutUint64(ttl, uint64(time.Now().Add(100*time.Hour).Unix())) + + rdns := &RDNS{ + ipCache: ipCache, + clients: &clientsContainer{ + list: map[string]*Client{}, + idIndex: tc.cliIDIndex, + ipHost: map[string]*ClientHost{}, + allTags: map[string]bool{}, + }, + } + ipCache.Clear() + ipCache.Set(net.IP{1, 2, 3, 4}, ttl) + + if tc.customChan != nil { + rdns.ipCh = tc.customChan + defer close(tc.customChan) + } + + t.Run(tc.name, func(t *testing.T) { + rdns.Begin(tc.req) + assert.Equal(t, tc.wantCacheHit, ipCache.Stats().Hit) + assert.Equal(t, tc.wantCacheMiss, ipCache.Stats().Miss) + assert.Contains(t, w.String(), tc.wantLog) + }) + } +} + +func TestRDNS_Resolve(t *testing.T) { + extUpstream := &aghtest.TestUpstream{ Reverse: map[string][]string{ "1.1.1.1.in-addr.arpa.": {"one.one.one.one"}, }, } + locUpstream := &aghtest.TestUpstream{ + Reverse: map[string][]string{ + "1.1.168.192.in-addr.arpa.": {"local.domain"}, + "2.1.168.192.in-addr.arpa.": {}, + }, + } + upstreamErr := errors.New("upstream error") + errUpstream := &aghtest.TestErrUpstream{ + Err: upstreamErr, + } + nonPtrUpstream := &aghtest.TestBlockUpstream{ + Hostname: "some-host", + Block: true, + } + dns := dnsforward.NewCustomServer(&proxy.Proxy{ Config: proxy.Config{ UpstreamConfig: &proxy.UpstreamConfig{ - Upstreams: []upstream.Upstream{ups}, + Upstreams: []upstream.Upstream{extUpstream}, }, }, }) - clients := &clientsContainer{} - rdns := InitRDNS(dns, clients) - r := rdns.resolve(net.IP{1, 1, 1, 1}) - assert.Equal(t, "one.one.one.one", r, r) + cc := &clientsContainer{} + + snd, err := aghnet.NewSubnetDetector() + require.NoError(t, err) + + localIP := net.IP{192, 168, 1, 1} + testCases := []struct { + name string + want string + wantErr error + locUpstream upstream.Upstream + req net.IP + }{{ + name: "external_good", + want: "one.one.one.one", + wantErr: nil, + locUpstream: nil, + req: net.IP{1, 1, 1, 1}, + }, { + name: "local_good", + want: "local.domain", + wantErr: nil, + locUpstream: locUpstream, + req: localIP, + }, { + name: "upstream_error", + want: "", + wantErr: upstreamErr, + locUpstream: errUpstream, + req: localIP, + }, { + name: "empty_answer_error", + want: "", + wantErr: rDNSEmptyAnswerErr, + locUpstream: locUpstream, + req: net.IP{192, 168, 1, 2}, + }, { + name: "not_ptr_error", + want: "", + wantErr: rDNSNotPTRErr, + locUpstream: nonPtrUpstream, + req: localIP, + }} + + for _, tc := range testCases { + rdns := NewRDNS(dns, cc, snd, &aghtest.Exchanger{ + Ups: tc.locUpstream, + }) + + t.Run(tc.name, func(t *testing.T) { + r, rerr := rdns.resolve(tc.req) + require.ErrorIs(t, rerr, tc.wantErr) + assert.Equal(t, tc.want, r) + }) + } +} + +func TestRDNS_WorkerLoop(t *testing.T) { + aghtest.ReplaceLogLevel(t, log.DEBUG) + w := &bytes.Buffer{} + aghtest.ReplaceLogWriter(t, w) + + locUpstream := &aghtest.TestUpstream{ + Reverse: map[string][]string{ + "1.1.168.192.in-addr.arpa.": {"local.domain"}, + }, + } + + snd, err := aghnet.NewSubnetDetector() + require.NoError(t, err) + + testCases := []struct { + wantLog string + name string + cliIP net.IP + }{{ + wantLog: "", + name: "all_good", + cliIP: net.IP{192, 168, 1, 1}, + }, { + wantLog: `rdns: resolving "192.168.1.2": lookup for "2.1.168.192.in-addr.arpa.": ` + + string(rDNSEmptyAnswerErr), + name: "resolve_error", + cliIP: net.IP{192, 168, 1, 2}, + }} + + for _, tc := range testCases { + w.Reset() + + lr := &aghtest.Exchanger{ + Ups: locUpstream, + } + cc := &clientsContainer{ + list: map[string]*Client{}, + idIndex: map[string]*Client{}, + ipHost: map[string]*ClientHost{}, + allTags: map[string]bool{}, + } + ch := make(chan net.IP) + rdns := &RDNS{ + dnsServer: nil, + clients: cc, + subnetDetector: snd, + localResolvers: lr, + ipCh: ch, + } + + t.Run(tc.name, func(t *testing.T) { + var wg sync.WaitGroup + wg.Add(1) + go func() { + rdns.workerLoop() + wg.Done() + }() + + ch <- tc.cliIP + close(ch) + wg.Wait() + + if tc.wantLog != "" { + assert.Contains(t, w.String(), tc.wantLog) + + return + } + + assert.True(t, cc.Exists(tc.cliIP.String(), ClientSourceRDNS)) + }) + } } diff --git a/internal/util/autohosts.go b/internal/util/autohosts.go index c3156920..22602ac4 100644 --- a/internal/util/autohosts.go +++ b/internal/util/autohosts.go @@ -12,6 +12,7 @@ import ( "strings" "sync" + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/golibs/log" "github.com/fsnotify/fsnotify" "github.com/miekg/dns" @@ -139,7 +140,7 @@ func (a *AutoHosts) ProcessReverse(addr string, qtype uint16) (hosts []string) { return nil } - ipReal := DNSUnreverseAddr(addr) + ipReal := aghnet.UnreverseAddr(addr) if ipReal == nil { return nil } diff --git a/internal/util/autohosts_test.go b/internal/util/autohosts_test.go index 367ba50a..60ff4622 100644 --- a/internal/util/autohosts_test.go +++ b/internal/util/autohosts_test.go @@ -128,45 +128,3 @@ func TestAutoHostsFSNotify(t *testing.T) { assert.True(t, net.IP{127, 0, 0, 2}.Equal(ips[0])) }) } - -func TestDNSReverseAddr(t *testing.T) { - testCases := []struct { - name string - have string - want net.IP - }{{ - name: "good_ipv4", - have: "1.0.0.127.in-addr.arpa", - want: net.IP{127, 0, 0, 1}, - }, { - name: "good_ipv6", - have: "4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa", - want: net.ParseIP("::abcd:1234"), - }, { - name: "good_ipv6_case", - have: "4.3.2.1.d.c.B.A.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa", - want: net.ParseIP("::abcd:1234"), - }, { - name: "bad_ipv4_dot", - have: "1.0.0.127.in-addr.arpa.", - }, { - name: "wrong_ipv4", - have: ".0.0.127.in-addr.arpa", - }, { - name: "wrong_ipv6", - have: ".3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa", - }, { - name: "bad_ipv6_dot", - have: "4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0..ip6.arpa", - }, { - name: "bad_ipv6_space", - have: "4.3.2.1.d.c.b. .0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa", - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - ip := DNSUnreverseAddr(tc.have) - assert.True(t, tc.want.Equal(ip)) - }) - } -} diff --git a/internal/util/dns.go b/internal/util/dns.go deleted file mode 100644 index aaf51d4d..00000000 --- a/internal/util/dns.go +++ /dev/null @@ -1,70 +0,0 @@ -package util - -import ( - "net" - "strings" -) - -// convert character to hex number -func charToHex(n byte) int8 { - if n >= '0' && n <= '9' { - return int8(n) - '0' - } else if (n|0x20) >= 'a' && (n|0x20) <= 'f' { - return (int8(n) | 0x20) - 'a' + 10 - } - return -1 -} - -// parse IPv6 reverse address -func ipParseArpa6(s string) net.IP { - if len(s) != 63 { - return nil - } - ip6 := make(net.IP, 16) - - for i := 0; i != 64; i += 4 { - - // parse "0.1." - n := charToHex(s[i]) - n2 := charToHex(s[i+2]) - if s[i+1] != '.' || (i != 60 && s[i+3] != '.') || - n < 0 || n2 < 0 { - return nil - } - - ip6[16-i/4-1] = byte(n2<<4) | byte(n&0x0f) - } - return ip6 -} - -// ipReverse - reverse IP address: 1.0.0.127 -> 127.0.0.1 -func ipReverse(ip net.IP) net.IP { - n := len(ip) - r := make(net.IP, n) - for i := 0; i != n; i++ { - r[i] = ip[n-i-1] - } - return r -} - -// DNSUnreverseAddr - convert reversed ARPA address to a normal IP address -func DNSUnreverseAddr(s string) net.IP { - const arpaV4 = ".in-addr.arpa" - const arpaV6 = ".ip6.arpa" - - if strings.HasSuffix(s, arpaV4) { - ip := strings.TrimSuffix(s, arpaV4) - ip4 := net.ParseIP(ip).To4() - if ip4 == nil { - return nil - } - - return ipReverse(ip4) - - } else if strings.HasSuffix(s, arpaV6) { - ip := strings.TrimSuffix(s, arpaV6) - return ipParseArpa6(ip) - } - - return nil // unknown suffix -}