Added persistent connections cache
This commit is contained in:
parent
484c0ceaff
commit
d6f560ecaf
|
@ -1,6 +1,7 @@
|
||||||
package upstream
|
package upstream
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/tls"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"golang.org/x/net/context"
|
"golang.org/x/net/context"
|
||||||
"time"
|
"time"
|
||||||
|
@ -8,24 +9,40 @@ import (
|
||||||
|
|
||||||
// DnsUpstream is a very simple upstream implementation for plain DNS
|
// DnsUpstream is a very simple upstream implementation for plain DNS
|
||||||
type DnsUpstream struct {
|
type DnsUpstream struct {
|
||||||
nameServer string // IP:port
|
endpoint string // IP:port
|
||||||
timeout time.Duration // Max read and write timeout
|
timeout time.Duration // Max read and write timeout
|
||||||
|
proto string // Protocol (tcp, tcp-tls, or udp)
|
||||||
|
transport *Transport // Persistent connections cache
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDnsUpstream creates a new plain-DNS upstream
|
// NewDnsUpstream creates a new DNS upstream
|
||||||
func NewDnsUpstream(nameServer string) (Upstream, error) {
|
func NewDnsUpstream(endpoint string, proto string, tlsServerName string) (Upstream, error) {
|
||||||
return &DnsUpstream{nameServer: nameServer, timeout: defaultTimeout}, nil
|
|
||||||
|
u := &DnsUpstream{
|
||||||
|
endpoint: endpoint,
|
||||||
|
timeout: defaultTimeout,
|
||||||
|
proto: proto,
|
||||||
|
}
|
||||||
|
|
||||||
|
var tlsConfig *tls.Config
|
||||||
|
|
||||||
|
if tlsServerName != "" {
|
||||||
|
tlsConfig = new(tls.Config)
|
||||||
|
tlsConfig.ServerName = tlsServerName
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize the connections cache
|
||||||
|
u.transport = NewTransport(endpoint)
|
||||||
|
u.transport.tlsConfig = tlsConfig
|
||||||
|
u.transport.Start()
|
||||||
|
|
||||||
|
return u, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exchange provides an implementation for the Upstream interface
|
// Exchange provides an implementation for the Upstream interface
|
||||||
func (u *DnsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
|
func (u *DnsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
|
||||||
|
|
||||||
dnsClient := &dns.Client{
|
resp, err := u.exchange(query)
|
||||||
ReadTimeout: u.timeout,
|
|
||||||
WriteTimeout: u.timeout,
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, _, err := dnsClient.Exchange(query, u.nameServer)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp = &dns.Msg{}
|
resp = &dns.Msg{}
|
||||||
|
@ -34,3 +51,42 @@ func (u *DnsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, e
|
||||||
|
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clear resources
|
||||||
|
func (u *DnsUpstream) Close() error {
|
||||||
|
|
||||||
|
// Close active connections
|
||||||
|
u.transport.Stop()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Performs a synchronous query. It sends the message m via the conn
|
||||||
|
// c and waits for a reply. The conn c is not closed.
|
||||||
|
func (u *DnsUpstream) exchange(query *dns.Msg) (r *dns.Msg, err error) {
|
||||||
|
|
||||||
|
// Establish a connection if needed (or reuse cached)
|
||||||
|
conn, err := u.transport.Dial(u.proto)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write the request with a timeout
|
||||||
|
conn.SetWriteDeadline(time.Now().Add(u.timeout))
|
||||||
|
if err = conn.WriteMsg(query); err != nil {
|
||||||
|
conn.Close() // Not giving it back
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write response with a timeout
|
||||||
|
conn.SetReadDeadline(time.Now().Add(u.timeout))
|
||||||
|
r, err = conn.ReadMsg()
|
||||||
|
if err != nil {
|
||||||
|
conn.Close() // Not giving it back
|
||||||
|
} else if err == nil && r.Id != query.Id {
|
||||||
|
err = dns.ErrId
|
||||||
|
conn.Close() // Not giving it back
|
||||||
|
}
|
||||||
|
|
||||||
|
u.transport.Yield(conn)
|
||||||
|
return r, err
|
||||||
|
}
|
||||||
|
|
|
@ -18,6 +18,8 @@ const (
|
||||||
dnsMessageContentType = "application/dns-message"
|
dnsMessageContentType = "application/dns-message"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TODO: Add bootstrap DNS resolver field
|
||||||
|
|
||||||
// HttpsUpstream is the upstream implementation for DNS-over-HTTPS
|
// HttpsUpstream is the upstream implementation for DNS-over-HTTPS
|
||||||
type HttpsUpstream struct {
|
type HttpsUpstream struct {
|
||||||
client *http.Client
|
client *http.Client
|
||||||
|
@ -107,3 +109,8 @@ func (u *HttpsUpstream) exchangeWireformat(msg []byte) ([]byte, error) {
|
||||||
|
|
||||||
return buf, nil
|
return buf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clear resources
|
||||||
|
func (u *HttpsUpstream) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,208 @@
|
||||||
|
package upstream
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"sort"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultExpire = 10 * time.Second
|
||||||
|
minDialTimeout = 100 * time.Millisecond
|
||||||
|
maxDialTimeout = 30 * time.Second
|
||||||
|
defaultDialTimeout = 30 * time.Second
|
||||||
|
cumulativeAvgWeight = 4
|
||||||
|
)
|
||||||
|
|
||||||
|
// a persistConn hold the dns.Conn and the last used time.
|
||||||
|
type persistConn struct {
|
||||||
|
c *dns.Conn
|
||||||
|
used time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transport hold the persistent cache.
|
||||||
|
type Transport struct {
|
||||||
|
avgDialTime int64 // kind of average time of dial time
|
||||||
|
conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls.
|
||||||
|
expire time.Duration // After this duration a connection is expired.
|
||||||
|
addr string
|
||||||
|
tlsConfig *tls.Config
|
||||||
|
|
||||||
|
dial chan string
|
||||||
|
yield chan *dns.Conn
|
||||||
|
ret chan *dns.Conn
|
||||||
|
stop chan bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial dials the address configured in transport, potentially reusing a connection or creating a new one.
|
||||||
|
func (t *Transport) Dial(proto string) (*dns.Conn, error) {
|
||||||
|
// If tls has been configured; use it.
|
||||||
|
if t.tlsConfig != nil {
|
||||||
|
proto = "tcp-tls"
|
||||||
|
}
|
||||||
|
|
||||||
|
t.dial <- proto
|
||||||
|
c := <-t.ret
|
||||||
|
|
||||||
|
if c != nil {
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
reqTime := time.Now()
|
||||||
|
timeout := t.dialTimeout()
|
||||||
|
if proto == "tcp-tls" {
|
||||||
|
conn, err := dns.DialTimeoutWithTLS(proto, t.addr, t.tlsConfig, timeout)
|
||||||
|
t.updateDialTimeout(time.Since(reqTime))
|
||||||
|
return conn, err
|
||||||
|
}
|
||||||
|
conn, err := dns.DialTimeout(proto, t.addr, timeout)
|
||||||
|
t.updateDialTimeout(time.Since(reqTime))
|
||||||
|
return conn, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Yield return the connection to transport for reuse.
|
||||||
|
func (t *Transport) Yield(c *dns.Conn) { t.yield <- c }
|
||||||
|
|
||||||
|
// Start starts the transport's connection manager.
|
||||||
|
func (t *Transport) Start() { go t.connManager() }
|
||||||
|
|
||||||
|
// Stop stops the transport's connection manager.
|
||||||
|
func (t *Transport) Stop() { close(t.stop) }
|
||||||
|
|
||||||
|
// SetExpire sets the connection expire time in transport.
|
||||||
|
func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire }
|
||||||
|
|
||||||
|
// SetTLSConfig sets the TLS config in transport.
|
||||||
|
func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg }
|
||||||
|
|
||||||
|
func NewTransport(addr string) *Transport {
|
||||||
|
t := &Transport{
|
||||||
|
avgDialTime: int64(defaultDialTimeout / 2),
|
||||||
|
conns: make(map[string][]*persistConn),
|
||||||
|
expire: defaultExpire,
|
||||||
|
addr: addr,
|
||||||
|
dial: make(chan string),
|
||||||
|
yield: make(chan *dns.Conn),
|
||||||
|
ret: make(chan *dns.Conn),
|
||||||
|
stop: make(chan bool),
|
||||||
|
}
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) {
|
||||||
|
dt := time.Duration(atomic.LoadInt64(currentAvg))
|
||||||
|
atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Transport) dialTimeout() time.Duration {
|
||||||
|
return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Transport) updateDialTimeout(newDialTime time.Duration) {
|
||||||
|
averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight)
|
||||||
|
}
|
||||||
|
|
||||||
|
// limitTimeout is a utility function to auto-tune timeout values
|
||||||
|
// average observed time is moved towards the last observed delay moderated by a weight
|
||||||
|
// next timeout to use will be the double of the computed average, limited by min and max frame.
|
||||||
|
func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration {
|
||||||
|
rt := time.Duration(atomic.LoadInt64(currentAvg))
|
||||||
|
if rt < minValue {
|
||||||
|
return minValue
|
||||||
|
}
|
||||||
|
if rt < maxValue/2 {
|
||||||
|
return 2 * rt
|
||||||
|
}
|
||||||
|
return maxValue
|
||||||
|
}
|
||||||
|
|
||||||
|
// connManagers manages the persistent connection cache for UDP and TCP.
|
||||||
|
func (t *Transport) connManager() {
|
||||||
|
ticker := time.NewTicker(t.expire)
|
||||||
|
Wait:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case proto := <-t.dial:
|
||||||
|
// take the last used conn - complexity O(1)
|
||||||
|
if stack := t.conns[proto]; len(stack) > 0 {
|
||||||
|
pc := stack[len(stack)-1]
|
||||||
|
if time.Since(pc.used) < t.expire {
|
||||||
|
// Found one, remove from pool and return this conn.
|
||||||
|
t.conns[proto] = stack[:len(stack)-1]
|
||||||
|
t.ret <- pc.c
|
||||||
|
continue Wait
|
||||||
|
}
|
||||||
|
// clear entire cache if the last conn is expired
|
||||||
|
t.conns[proto] = nil
|
||||||
|
// now, the connections being passed to closeConns() are not reachable from
|
||||||
|
// transport methods anymore. So, it's safe to close them in a separate goroutine
|
||||||
|
go closeConns(stack)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.ret <- nil
|
||||||
|
|
||||||
|
case conn := <-t.yield:
|
||||||
|
|
||||||
|
// no proto here, infer from config and conn
|
||||||
|
if _, ok := conn.Conn.(*net.UDPConn); ok {
|
||||||
|
t.conns["udp"] = append(t.conns["udp"], &persistConn{conn, time.Now()})
|
||||||
|
continue Wait
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.tlsConfig == nil {
|
||||||
|
t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn, time.Now()})
|
||||||
|
continue Wait
|
||||||
|
}
|
||||||
|
|
||||||
|
t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn, time.Now()})
|
||||||
|
|
||||||
|
case <-ticker.C:
|
||||||
|
t.cleanup(false)
|
||||||
|
|
||||||
|
case <-t.stop:
|
||||||
|
t.cleanup(true)
|
||||||
|
close(t.ret)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// closeConns closes connections.
|
||||||
|
func closeConns(conns []*persistConn) {
|
||||||
|
for _, pc := range conns {
|
||||||
|
pc.c.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanup removes connections from cache.
|
||||||
|
func (t *Transport) cleanup(all bool) {
|
||||||
|
staleTime := time.Now().Add(-t.expire)
|
||||||
|
for proto, stack := range t.conns {
|
||||||
|
if len(stack) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if all {
|
||||||
|
t.conns[proto] = nil
|
||||||
|
// now, the connections being passed to closeConns() are not reachable from
|
||||||
|
// transport methods anymore. So, it's safe to close them in a separate goroutine
|
||||||
|
go closeConns(stack)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if stack[0].used.After(staleTime) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// connections in stack are sorted by "used"
|
||||||
|
good := sort.Search(len(stack), func(i int) bool {
|
||||||
|
return stack[i].used.After(staleTime)
|
||||||
|
})
|
||||||
|
t.conns[proto] = stack[good:]
|
||||||
|
// now, the connections being passed to closeConns() are not reachable from
|
||||||
|
// transport methods anymore. So, it's safe to close them in a separate goroutine
|
||||||
|
go closeConns(stack[:good])
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,47 +0,0 @@
|
||||||
package upstream
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
"golang.org/x/net/context"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TODO: Use persistent connection here
|
|
||||||
|
|
||||||
// DnsOverTlsUpstream is the upstream implementation for plain DNS-over-TLS
|
|
||||||
type DnsOverTlsUpstream struct {
|
|
||||||
endpoint string
|
|
||||||
tlsServerName string
|
|
||||||
timeout time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewHttpsUpstream creates a new DNS-over-TLS upstream from the endpoint address and TLS server name
|
|
||||||
func NewDnsOverTlsUpstream(endpoint string, tlsServerName string) (Upstream, error) {
|
|
||||||
return &DnsOverTlsUpstream{
|
|
||||||
endpoint: endpoint,
|
|
||||||
tlsServerName: tlsServerName,
|
|
||||||
timeout: defaultTimeout,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Exchange provides an implementation for the Upstream interface
|
|
||||||
func (u *DnsOverTlsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
|
|
||||||
|
|
||||||
dnsClient := &dns.Client{
|
|
||||||
Net: "tcp-tls",
|
|
||||||
ReadTimeout: u.timeout,
|
|
||||||
WriteTimeout: u.timeout,
|
|
||||||
TLSConfig: new(tls.Config),
|
|
||||||
}
|
|
||||||
dnsClient.TLSConfig.ServerName = u.tlsServerName
|
|
||||||
|
|
||||||
resp, _, err := dnsClient.Exchange(query, u.endpoint)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
resp = &dns.Msg{}
|
|
||||||
resp.SetRcode(resp, dns.RcodeServerFailure)
|
|
||||||
}
|
|
||||||
|
|
||||||
return resp, err
|
|
||||||
}
|
|
|
@ -5,6 +5,8 @@ import (
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"golang.org/x/net/context"
|
"golang.org/x/net/context"
|
||||||
|
"log"
|
||||||
|
"runtime"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -12,9 +14,12 @@ const (
|
||||||
defaultTimeout = 5 * time.Second
|
defaultTimeout = 5 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TODO: Add a helper method for health-checking an upstream (see health.go in coredns)
|
||||||
|
|
||||||
// Upstream is a simplified interface for proxy destination
|
// Upstream is a simplified interface for proxy destination
|
||||||
type Upstream interface {
|
type Upstream interface {
|
||||||
Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error)
|
Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error)
|
||||||
|
Close() error
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpstreamPlugin is a simplified DNS proxy using a generic upstream interface
|
// UpstreamPlugin is a simplified DNS proxy using a generic upstream interface
|
||||||
|
@ -23,11 +28,21 @@ type UpstreamPlugin struct {
|
||||||
Next plugin.Handler
|
Next plugin.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize the upstream plugin
|
||||||
|
func New() *UpstreamPlugin {
|
||||||
|
p := &UpstreamPlugin{}
|
||||||
|
|
||||||
|
// Make sure all resources are cleaned up
|
||||||
|
runtime.SetFinalizer(p, (*UpstreamPlugin).finalizer)
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
// ServeDNS implements interface for CoreDNS plugin
|
// ServeDNS implements interface for CoreDNS plugin
|
||||||
func (p UpstreamPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
func (p *UpstreamPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
var reply *dns.Msg
|
var reply *dns.Msg
|
||||||
var backendErr error
|
var backendErr error
|
||||||
|
|
||||||
|
// TODO: Change the way we call upstreams
|
||||||
for _, upstream := range p.Upstreams {
|
for _, upstream := range p.Upstreams {
|
||||||
reply, backendErr = upstream.Exchange(ctx, r)
|
reply, backendErr = upstream.Exchange(ctx, r)
|
||||||
if backendErr == nil {
|
if backendErr == nil {
|
||||||
|
@ -40,4 +55,16 @@ func (p UpstreamPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *d
|
||||||
}
|
}
|
||||||
|
|
||||||
// Name implements interface for CoreDNS plugin
|
// Name implements interface for CoreDNS plugin
|
||||||
func (p UpstreamPlugin) Name() string { return "upstream" }
|
func (p *UpstreamPlugin) Name() string { return "upstream" }
|
||||||
|
|
||||||
|
func (p *UpstreamPlugin) finalizer() {
|
||||||
|
|
||||||
|
for i := range p.Upstreams {
|
||||||
|
|
||||||
|
u := p.Upstreams[i]
|
||||||
|
err := u.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error while closing the upstream: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -2,14 +2,13 @@ package upstream
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"log"
|
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDnsUpstream(t *testing.T) {
|
func TestDnsUpstream(t *testing.T) {
|
||||||
|
|
||||||
u, err := NewDnsUpstream("8.8.8.8:53")
|
u, err := NewDnsUpstream("8.8.8.8:53", "udp", "")
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("cannot create a DNS upstream")
|
t.Errorf("cannot create a DNS upstream")
|
||||||
|
@ -44,12 +43,12 @@ func TestDnsOverTlsUpstream(t *testing.T) {
|
||||||
tlsServerName string
|
tlsServerName string
|
||||||
}{
|
}{
|
||||||
{"1.1.1.1:853", ""},
|
{"1.1.1.1:853", ""},
|
||||||
{"8.8.8.8:853", ""},
|
{"9.9.9.9:853", ""},
|
||||||
{"185.228.168.10:853", "security-filter-dns.cleanbrowsing.org"},
|
{"185.228.168.10:853", "security-filter-dns.cleanbrowsing.org"},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
u, err := NewDnsOverTlsUpstream(test.endpoint, test.tlsServerName)
|
u, err := NewDnsUpstream(test.endpoint, "tcp-tls", test.tlsServerName)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("cannot create a DNS-over-TLS upstream")
|
t.Errorf("cannot create a DNS-over-TLS upstream")
|
||||||
|
@ -60,11 +59,21 @@ func TestDnsOverTlsUpstream(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func testUpstream(t *testing.T, u Upstream) {
|
func testUpstream(t *testing.T, u Upstream) {
|
||||||
|
|
||||||
|
var tests = []struct {
|
||||||
|
name string
|
||||||
|
expected net.IP
|
||||||
|
}{
|
||||||
|
{"google-public-dns-a.google.com.", net.IPv4(8, 8, 8, 8)},
|
||||||
|
{"google-public-dns-b.google.com.", net.IPv4(8, 8, 4, 4)},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
req := dns.Msg{}
|
req := dns.Msg{}
|
||||||
req.Id = dns.Id()
|
req.Id = dns.Id()
|
||||||
req.RecursionDesired = true
|
req.RecursionDesired = true
|
||||||
req.Question = []dns.Question{
|
req.Question = []dns.Question{
|
||||||
{Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
{Name: test.name, Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := u.Exchange(nil, &req)
|
resp, err := u.Exchange(nil, &req)
|
||||||
|
@ -77,10 +86,14 @@ func testUpstream(t *testing.T, u Upstream) {
|
||||||
t.Errorf("no answer section in the response")
|
t.Errorf("no answer section in the response")
|
||||||
}
|
}
|
||||||
if answer, ok := resp.Answer[0].(*dns.A); ok {
|
if answer, ok := resp.Answer[0].(*dns.A); ok {
|
||||||
if !net.IPv4(8, 8, 8, 8).Equal(answer.A) {
|
if !test.expected.Equal(answer.A) {
|
||||||
t.Errorf("wrong IP in the response: %v", answer.A)
|
t.Errorf("wrong IP in the response: %v", answer.A)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
log.Printf("response: %v", resp)
|
|
||||||
|
err := u.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error while closing the upstream: %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue