Pull request: 2704 local addresses vol.4

Updates #2704.

Squashed commit of the following:

commit bbc292a1c9dc7382e1116cfd85fea22ab8de6949
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Apr 9 19:22:46 2021 +0300

    all: replace exchanger with proxy
This commit is contained in:
Ainar Garipov 2021-04-09 21:01:21 +03:00
parent 0247a2da44
commit 8b8319fca0
7 changed files with 70 additions and 205 deletions

View File

@ -1,86 +0,0 @@
package aghnet
import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns"
)
// This package is not the best place for this functionality, but we put it here
// since we need to use it in both rDNS (home) and dnsServer (dnsforward).
// NoUpstreamsErr should be returned when there are no upstreams inside
// Exchanger implementation.
const NoUpstreamsErr agherr.Error = "no upstreams specified"
// Exchanger represents an object able to resolve DNS messages.
//
// TODO(e.burkov): Maybe expand with method like ExchangeParallel to be able to
// use user's upstream mode settings. Also, think about Update method to
// refresh the internal state.
type Exchanger interface {
Exchange(req *dns.Msg) (resp *dns.Msg, err error)
}
// multiAddrExchanger is the default implementation of Exchanger interface.
type multiAddrExchanger struct {
ups []upstream.Upstream
}
// NewMultiAddrExchanger creates an Exchanger instance from passed addresses.
// It returns an error if any of addrs failed to become an upstream.
func NewMultiAddrExchanger(
addrs []string,
bootstraps []string,
timeout time.Duration,
) (e Exchanger, err error) {
defer agherr.Annotate("exchanger: %w", &err)
if len(addrs) == 0 {
return &multiAddrExchanger{}, nil
}
var ups []upstream.Upstream = make([]upstream.Upstream, 0, len(addrs))
for _, addr := range addrs {
var u upstream.Upstream
u, err = upstream.AddressToUpstream(addr, upstream.Options{
Bootstrap: bootstraps,
Timeout: timeout,
})
if err != nil {
return nil, err
}
ups = append(ups, u)
}
return &multiAddrExchanger{ups: ups}, nil
}
// Exсhange performs a query to each resolver until first response.
func (e *multiAddrExchanger) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
defer agherr.Annotate("exchanger: %w", &err)
// TODO(e.burkov): Maybe prohibit the initialization without upstreams.
if len(e.ups) == 0 {
return nil, NoUpstreamsErr
}
var errs []error
for _, u := range e.ups {
resp, err = u.Exchange(req)
if err != nil {
errs = append(errs, err)
continue
}
if resp != nil {
return resp, nil
}
}
return nil, agherr.Many("can't exchange", errs...)
}

View File

@ -1,64 +0,0 @@
package aghnet
import (
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewMultiAddrExchanger(t *testing.T) {
var e Exchanger
var err error
t.Run("empty", func(t *testing.T) {
e, err = NewMultiAddrExchanger([]string{}, nil, 0)
require.NoError(t, err)
assert.NotNil(t, e)
})
t.Run("successful", func(t *testing.T) {
e, err = NewMultiAddrExchanger([]string{"www.example.com"}, nil, 0)
require.NoError(t, err)
assert.NotNil(t, e)
})
t.Run("unsuccessful", func(t *testing.T) {
e, err = NewMultiAddrExchanger([]string{"invalid-proto://www.example.com"}, nil, 0)
require.Error(t, err)
assert.Nil(t, e)
})
}
func TestMultiAddrExchanger_Exchange(t *testing.T) {
e := &multiAddrExchanger{}
t.Run("error", func(t *testing.T) {
e.ups = []upstream.Upstream{&aghtest.TestErrUpstream{}}
resp, err := e.Exchange(nil)
require.Error(t, err)
assert.Nil(t, resp)
})
t.Run("success", func(t *testing.T) {
e.ups = []upstream.Upstream{&aghtest.TestUpstream{
Reverse: map[string][]string{
"abc": {"cba"},
},
}}
resp, err := e.Exchange(&dns.Msg{
Question: []dns.Question{{
Name: "abc",
Qtype: dns.TypePTR,
}},
})
require.NoError(t, err)
require.Len(t, resp.Answer, 1)
assert.Equal(t, "cba", resp.Answer[0].Header().Name)
})
}

View File

@ -1,7 +1,6 @@
package dnsforward package dnsforward
import ( import (
"errors"
"net" "net"
"strings" "strings"
"time" "time"
@ -403,21 +402,18 @@ func (s *Server) processLocalPTR(ctx *dnsContext) (rc resultCode) {
return resultCodeSuccess return resultCodeSuccess
} }
req := d.Req err := s.localResolvers.Resolve(d)
resp, err := s.localResolvers.Exchange(req)
if err != nil { if err != nil {
if errors.Is(err, aghnet.NoUpstreamsErr) {
d.Res = s.genNXDomain(req)
return resultCodeFinish
}
ctx.err = err ctx.err = err
return resultCodeError return resultCodeError
} }
d.Res = resp if d.Res == nil {
d.Res = s.genNXDomain(d.Req)
return resultCodeFinish
}
return resultCodeSuccess return resultCodeSuccess
} }

View File

@ -259,17 +259,16 @@ func TestServer_ProcessInternalHosts(t *testing.T) {
} }
func TestLocalRestriction(t *testing.T) { func TestLocalRestriction(t *testing.T) {
s := createTestServer(t, &dnsfilter.Config{}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
})
ups := &aghtest.TestUpstream{ ups := &aghtest.TestUpstream{
Reverse: map[string][]string{ Reverse: map[string][]string{
"251.252.253.254.in-addr.arpa.": {"host1.example.net."}, "251.252.253.254.in-addr.arpa.": {"host1.example.net."},
"1.1.168.192.in-addr.arpa.": {"some.local-client."}, "1.1.168.192.in-addr.arpa.": {"some.local-client."},
}, },
} }
s.localResolvers = &aghtest.Exchanger{Ups: ups} s := createTestServer(t, &dnsfilter.Config{}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
}, ups)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ups} s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ups}
startDeferStop(t, s) startDeferStop(t, s)

View File

@ -20,6 +20,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -66,7 +67,7 @@ type Server struct {
ipset ipsetCtx ipset ipsetCtx
subnetDetector *aghnet.SubnetDetector subnetDetector *aghnet.SubnetDetector
localResolvers aghnet.Exchanger localResolvers *proxy.Proxy
tableHostToIP map[string]net.IP // "hostname -> IP" table for internal addresses (DHCP) tableHostToIP map[string]net.IP // "hostname -> IP" table for internal addresses (DHCP)
tableHostToIPLock sync.Mutex tableHostToIPLock sync.Mutex
@ -243,24 +244,24 @@ func (s *Server) Exchange(ip net.IP) (host string, err error) {
Qclass: dns.ClassINET, Qclass: dns.ClassINET,
}}, }},
} }
var resp *dns.Msg
if s.subnetDetector.IsLocallyServedNetwork(ip) {
resp, err = s.localResolvers.Exchange(req)
} else {
ctx := &proxy.DNSContext{ ctx := &proxy.DNSContext{
Proto: "udp", Proto: "udp",
Req: req, Req: req,
StartTime: time.Now(), StartTime: time.Now(),
} }
err = s.internalProxy.Resolve(ctx)
resp = ctx.Res var resp *dns.Msg
if s.subnetDetector.IsLocallyServedNetwork(ip) {
err = s.localResolvers.Resolve(ctx)
} else {
err = s.internalProxy.Resolve(ctx)
} }
if err != nil { if err != nil {
return "", err return "", err
} }
resp = ctx.Res
if len(resp.Answer) == 0 { if len(resp.Answer) == 0 {
return "", fmt.Errorf("lookup for %q: %w", arpa, rDNSEmptyAnswerErr) return "", fmt.Errorf("lookup for %q: %w", arpa, rDNSEmptyAnswerErr)
} }
@ -376,18 +377,26 @@ func (s *Server) setupResolvers(localAddrs []string) (err error) {
return err return err
} }
// TODO(e.burkov): The approach of subtracting sets of strings // TODO(e.burkov): The approach of subtracting sets of strings is not
// is not really applicable here since in case of listening on // really applicable here since in case of listening on all network
// all network interfaces we should check the whole interface's // interfaces we should check the whole interface's network to cut off
// network to cut off all the loopback addresses as well. // all the loopback addresses as well.
localAddrs = stringSetSubtract(localAddrs, ourAddrs) localAddrs = stringSetSubtract(localAddrs, ourAddrs)
if s.localResolvers, err = aghnet.NewMultiAddrExchanger( var upsConfig proxy.UpstreamConfig
localAddrs, upsConfig, err = proxy.ParseUpstreamsConfig(localAddrs, upstream.Options{
bootstraps, Bootstrap: bootstraps,
defaultLocalTimeout, Timeout: defaultLocalTimeout,
); err != nil { // TODO(e.burkov): Should we verify server's ceritificates?
return err })
if err != nil {
return fmt.Errorf("parsing upstreams: %w", err)
}
s.localResolvers = &proxy.Proxy{
Config: proxy.Config{
UpstreamConfig: &upsConfig,
},
} }
return nil return nil

View File

@ -52,7 +52,12 @@ func startDeferStop(t *testing.T, s *Server) {
}) })
} }
func createTestServer(t *testing.T, filterConf *dnsfilter.Config, forwardConf ServerConfig) *Server { func createTestServer(
t *testing.T,
filterConf *dnsfilter.Config,
forwardConf ServerConfig,
localUps upstream.Upstream,
) (s *Server) {
t.Helper() t.Helper()
rules := `||nxdomain.example.org rules := `||nxdomain.example.org
@ -70,7 +75,6 @@ func createTestServer(t *testing.T, filterConf *dnsfilter.Config, forwardConf Se
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, snd) require.NotNil(t, snd)
var s *Server
s, err = NewServer(DNSCreateParams{ s, err = NewServer(DNSCreateParams{
DNSFilter: f, DNSFilter: f,
SubnetDetector: snd, SubnetDetector: snd,
@ -85,7 +89,9 @@ func createTestServer(t *testing.T, filterConf *dnsfilter.Config, forwardConf Se
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
s.localResolvers = &aghtest.Exchanger{} if localUps != nil {
s.localResolvers.Config.UpstreamConfig.Upstreams = []upstream.Upstream{localUps}
}
return s return s
} }
@ -143,7 +149,7 @@ func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte)
s = createTestServer(t, &dnsfilter.Config{}, ServerConfig{ s = createTestServer(t, &dnsfilter.Config{}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}}, UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}},
}) }, nil)
tlsConf.CertificateChainData, tlsConf.PrivateKeyData = certPem, keyPem tlsConf.CertificateChainData, tlsConf.PrivateKeyData = certPem, keyPem
s.conf.TLSConfig = tlsConf s.conf.TLSConfig = tlsConf
@ -239,7 +245,7 @@ func TestServer(t *testing.T) {
s := createTestServer(t, &dnsfilter.Config{}, ServerConfig{ s := createTestServer(t, &dnsfilter.Config{}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}}, UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}},
}) }, nil)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.TestUpstream{ &aghtest.TestUpstream{
IPv4: map[string][]net.IP{ IPv4: map[string][]net.IP{
@ -277,7 +283,7 @@ func TestServerWithProtectionDisabled(t *testing.T) {
s := createTestServer(t, &dnsfilter.Config{}, ServerConfig{ s := createTestServer(t, &dnsfilter.Config{}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}}, UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}},
}) }, nil)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.TestUpstream{ &aghtest.TestUpstream{
IPv4: map[string][]net.IP{ IPv4: map[string][]net.IP{
@ -374,7 +380,7 @@ func TestServerRace(t *testing.T) {
}, },
ConfigModified: func() {}, ConfigModified: func() {},
} }
s := createTestServer(t, filterConf, forwardConf) s := createTestServer(t, filterConf, forwardConf, nil)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.TestUpstream{ &aghtest.TestUpstream{
IPv4: map[string][]net.IP{ IPv4: map[string][]net.IP{
@ -407,7 +413,7 @@ func TestSafeSearch(t *testing.T) {
ProtectionEnabled: true, ProtectionEnabled: true,
}, },
} }
s := createTestServer(t, filterConf, forwardConf) s := createTestServer(t, filterConf, forwardConf, nil)
startDeferStop(t, s) startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String() addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
@ -460,7 +466,7 @@ func TestInvalidRequest(t *testing.T) {
s := createTestServer(t, &dnsfilter.Config{}, ServerConfig{ s := createTestServer(t, &dnsfilter.Config{}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}}, UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}},
}) }, nil)
startDeferStop(t, s) startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String() addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
@ -488,7 +494,7 @@ func TestBlockedRequest(t *testing.T) {
ProtectionEnabled: true, ProtectionEnabled: true,
}, },
} }
s := createTestServer(t, &dnsfilter.Config{}, forwardConf) s := createTestServer(t, &dnsfilter.Config{}, forwardConf, nil)
startDeferStop(t, s) startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
@ -513,7 +519,7 @@ func TestServerCustomClientUpstream(t *testing.T) {
ProtectionEnabled: true, ProtectionEnabled: true,
}, },
} }
s := createTestServer(t, &dnsfilter.Config{}, forwardConf) s := createTestServer(t, &dnsfilter.Config{}, forwardConf, nil)
s.conf.GetCustomUpstreamByClient = func(_ string) *proxy.UpstreamConfig { s.conf.GetCustomUpstreamByClient = func(_ string) *proxy.UpstreamConfig {
return &proxy.UpstreamConfig{ return &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{ Upstreams: []upstream.Upstream{
@ -558,7 +564,7 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
s := createTestServer(t, &dnsfilter.Config{}, ServerConfig{ s := createTestServer(t, &dnsfilter.Config{}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}}, UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}},
}) }, nil)
testUpstm := &aghtest.TestUpstream{ testUpstm := &aghtest.TestUpstream{
CName: testCNAMEs, CName: testCNAMEs,
IPv4: testIPv4, IPv4: testIPv4,
@ -590,7 +596,7 @@ func TestBlockCNAME(t *testing.T) {
ProtectionEnabled: true, ProtectionEnabled: true,
}, },
} }
s := createTestServer(t, &dnsfilter.Config{}, forwardConf) s := createTestServer(t, &dnsfilter.Config{}, forwardConf, nil)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.TestUpstream{ &aghtest.TestUpstream{
CName: testCNAMEs, CName: testCNAMEs,
@ -652,7 +658,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) {
}, },
}, },
} }
s := createTestServer(t, &dnsfilter.Config{}, forwardConf) s := createTestServer(t, &dnsfilter.Config{}, forwardConf, nil)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.TestUpstream{ &aghtest.TestUpstream{
CName: testCNAMEs, CName: testCNAMEs,
@ -693,7 +699,7 @@ func TestNullBlockedRequest(t *testing.T) {
BlockingMode: "null_ip", BlockingMode: "null_ip",
}, },
} }
s := createTestServer(t, &dnsfilter.Config{}, forwardConf) s := createTestServer(t, &dnsfilter.Config{}, forwardConf, nil)
startDeferStop(t, s) startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
@ -792,7 +798,7 @@ func TestBlockedByHosts(t *testing.T) {
ProtectionEnabled: true, ProtectionEnabled: true,
}, },
} }
s := createTestServer(t, &dnsfilter.Config{}, forwardConf) s := createTestServer(t, &dnsfilter.Config{}, forwardConf, nil)
startDeferStop(t, s) startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
@ -827,7 +833,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
ProtectionEnabled: true, ProtectionEnabled: true,
}, },
} }
s := createTestServer(t, filterConf, forwardConf) s := createTestServer(t, filterConf, forwardConf, nil)
s.dnsFilter.SetSafeBrowsingUpstream(sbUps) s.dnsFilter.SetSafeBrowsingUpstream(sbUps)
startDeferStop(t, s) startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
@ -1240,8 +1246,13 @@ func TestServer_Exchange(t *testing.T) {
}} }}
for _, tc := range testCases { for _, tc := range testCases {
dns.localResolvers = &aghtest.Exchanger{ pcfg := proxy.Config{
Ups: tc.locUpstream, UpstreamConfig: &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{tc.locUpstream},
},
}
dns.localResolvers = &proxy.Proxy{
Config: pcfg,
} }
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {

View File

@ -50,7 +50,7 @@ func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) {
}, },
ConfigModified: func() {}, ConfigModified: func() {},
} }
s := createTestServer(t, filterConf, forwardConf) s := createTestServer(t, filterConf, forwardConf, nil)
require.Nil(t, s.Start()) require.Nil(t, s.Start())
t.Cleanup(func() { t.Cleanup(func() {
require.Nil(t, s.Stop()) require.Nil(t, s.Stop())
@ -123,7 +123,7 @@ func TestDNSForwardHTTTP_handleSetConfig(t *testing.T) {
}, },
ConfigModified: func() {}, ConfigModified: func() {},
} }
s := createTestServer(t, filterConf, forwardConf) s := createTestServer(t, filterConf, forwardConf, nil)
defaultConf := s.conf defaultConf := s.conf