2018-11-05 17:40:10 +00:00
|
|
|
package upstream
|
|
|
|
|
|
|
|
import (
|
|
|
|
"crypto/tls"
|
|
|
|
"net"
|
|
|
|
"sort"
|
|
|
|
"sync/atomic"
|
|
|
|
"time"
|
|
|
|
|
|
|
|
"github.com/miekg/dns"
|
|
|
|
)
|
|
|
|
|
2018-11-05 21:47:59 +00:00
|
|
|
// Persistent connection cache -- nearly identical to the one used in the CoreDNS forward plugin.
|
|
|
|
|
2018-11-05 17:40:10 +00:00
|
|
|
const (
	// defaultExpire is the expiry a new Transport starts with; a cached
	// connection idle for longer than the expiry is not reused.
	defaultExpire = 10 * time.Second

	// minDialTimeout and maxDialTimeout bound the auto-tuned dial timeout.
	minDialTimeout = 100 * time.Millisecond
	maxDialTimeout = 30 * time.Second

	// defaultDialTimeout seeds the running dial-time average: NewTransport
	// starts the average at half of this value.
	defaultDialTimeout = 30 * time.Second

	// cumulativeAvgWeight is the divisor applied to the difference between
	// the observed dial time and the current average on each update.
	cumulativeAvgWeight = 4
)
|
|
|
|
|
|
|
|
// a persistConn hold the dns.Conn and the last used time.
|
|
|
|
type persistConn struct {
|
|
|
|
c *dns.Conn
|
|
|
|
used time.Time
|
|
|
|
}
|
|
|
|
|
|
|
|
// Transport hold the persistent cache.
|
|
|
|
type Transport struct {
|
|
|
|
avgDialTime int64 // kind of average time of dial time
|
|
|
|
conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls.
|
|
|
|
expire time.Duration // After this duration a connection is expired.
|
|
|
|
addr string
|
|
|
|
tlsConfig *tls.Config
|
|
|
|
|
|
|
|
dial chan string
|
|
|
|
yield chan *dns.Conn
|
|
|
|
ret chan *dns.Conn
|
|
|
|
stop chan bool
|
|
|
|
}
|
|
|
|
|
|
|
|
// Dial dials the address configured in transport, potentially reusing a connection or creating a new one.
|
|
|
|
func (t *Transport) Dial(proto string) (*dns.Conn, error) {
|
|
|
|
// If tls has been configured; use it.
|
|
|
|
if t.tlsConfig != nil {
|
|
|
|
proto = "tcp-tls"
|
|
|
|
}
|
|
|
|
|
|
|
|
t.dial <- proto
|
|
|
|
c := <-t.ret
|
|
|
|
|
|
|
|
if c != nil {
|
|
|
|
return c, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
reqTime := time.Now()
|
|
|
|
timeout := t.dialTimeout()
|
|
|
|
if proto == "tcp-tls" {
|
|
|
|
conn, err := dns.DialTimeoutWithTLS(proto, t.addr, t.tlsConfig, timeout)
|
|
|
|
t.updateDialTimeout(time.Since(reqTime))
|
|
|
|
return conn, err
|
|
|
|
}
|
|
|
|
conn, err := dns.DialTimeout(proto, t.addr, timeout)
|
|
|
|
t.updateDialTimeout(time.Since(reqTime))
|
|
|
|
return conn, err
|
|
|
|
}
|
|
|
|
|
|
|
|
// Yield return the connection to transport for reuse.
|
|
|
|
func (t *Transport) Yield(c *dns.Conn) { t.yield <- c }
|
|
|
|
|
|
|
|
// Start starts the transport's connection manager.
|
|
|
|
func (t *Transport) Start() { go t.connManager() }
|
|
|
|
|
|
|
|
// Stop stops the transport's connection manager.
|
|
|
|
func (t *Transport) Stop() { close(t.stop) }
|
|
|
|
|
|
|
|
// SetExpire sets the connection expire time in transport.
|
|
|
|
func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire }
|
|
|
|
|
|
|
|
// SetTLSConfig sets the TLS config in transport.
|
|
|
|
func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg }
|
|
|
|
|
|
|
|
func NewTransport(addr string) *Transport {
|
|
|
|
t := &Transport{
|
|
|
|
avgDialTime: int64(defaultDialTimeout / 2),
|
|
|
|
conns: make(map[string][]*persistConn),
|
|
|
|
expire: defaultExpire,
|
|
|
|
addr: addr,
|
|
|
|
dial: make(chan string),
|
|
|
|
yield: make(chan *dns.Conn),
|
|
|
|
ret: make(chan *dns.Conn),
|
|
|
|
stop: make(chan bool),
|
|
|
|
}
|
|
|
|
return t
|
|
|
|
}
|
|
|
|
|
|
|
|
// averageTimeout moves the average stored in *currentAvg towards
// observedDuration by (observedDuration - average) / weight, using integer
// division. The load and the add are each atomic, but the pair is not one
// atomic operation.
func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) {
	previous := atomic.LoadInt64(currentAvg)
	step := (int64(observedDuration) - previous) / weight
	atomic.AddInt64(currentAvg, step)
}
|
|
|
|
|
|
|
|
func (t *Transport) dialTimeout() time.Duration {
|
|
|
|
return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (t *Transport) updateDialTimeout(newDialTime time.Duration) {
|
|
|
|
averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight)
|
|
|
|
}
|
|
|
|
|
|
|
|
// limitTimeout is a utility function to auto-tune timeout values
|
|
|
|
// average observed time is moved towards the last observed delay moderated by a weight
|
|
|
|
// next timeout to use will be the double of the computed average, limited by min and max frame.
|
|
|
|
func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration {
|
|
|
|
rt := time.Duration(atomic.LoadInt64(currentAvg))
|
|
|
|
if rt < minValue {
|
|
|
|
return minValue
|
|
|
|
}
|
|
|
|
if rt < maxValue/2 {
|
|
|
|
return 2 * rt
|
|
|
|
}
|
|
|
|
return maxValue
|
|
|
|
}
|
|
|
|
|
|
|
|
// connManagers manages the persistent connection cache for UDP and TCP.
|
|
|
|
func (t *Transport) connManager() {
|
|
|
|
ticker := time.NewTicker(t.expire)
|
|
|
|
Wait:
|
|
|
|
for {
|
|
|
|
select {
|
|
|
|
case proto := <-t.dial:
|
|
|
|
// take the last used conn - complexity O(1)
|
|
|
|
if stack := t.conns[proto]; len(stack) > 0 {
|
|
|
|
pc := stack[len(stack)-1]
|
|
|
|
if time.Since(pc.used) < t.expire {
|
|
|
|
// Found one, remove from pool and return this conn.
|
|
|
|
t.conns[proto] = stack[:len(stack)-1]
|
|
|
|
t.ret <- pc.c
|
|
|
|
continue Wait
|
|
|
|
}
|
|
|
|
// clear entire cache if the last conn is expired
|
|
|
|
t.conns[proto] = nil
|
|
|
|
// now, the connections being passed to closeConns() are not reachable from
|
|
|
|
// transport methods anymore. So, it's safe to close them in a separate goroutine
|
|
|
|
go closeConns(stack)
|
|
|
|
}
|
|
|
|
|
|
|
|
t.ret <- nil
|
|
|
|
|
|
|
|
case conn := <-t.yield:
|
|
|
|
|
|
|
|
// no proto here, infer from config and conn
|
|
|
|
if _, ok := conn.Conn.(*net.UDPConn); ok {
|
|
|
|
t.conns["udp"] = append(t.conns["udp"], &persistConn{conn, time.Now()})
|
|
|
|
continue Wait
|
|
|
|
}
|
|
|
|
|
|
|
|
if t.tlsConfig == nil {
|
|
|
|
t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn, time.Now()})
|
|
|
|
continue Wait
|
|
|
|
}
|
|
|
|
|
|
|
|
t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn, time.Now()})
|
|
|
|
|
|
|
|
case <-ticker.C:
|
|
|
|
t.cleanup(false)
|
|
|
|
|
|
|
|
case <-t.stop:
|
|
|
|
t.cleanup(true)
|
|
|
|
close(t.ret)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// closeConns closes connections.
|
|
|
|
func closeConns(conns []*persistConn) {
|
|
|
|
for _, pc := range conns {
|
|
|
|
pc.c.Close()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// cleanup removes connections from cache.
|
|
|
|
func (t *Transport) cleanup(all bool) {
|
|
|
|
staleTime := time.Now().Add(-t.expire)
|
|
|
|
for proto, stack := range t.conns {
|
|
|
|
if len(stack) == 0 {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
if all {
|
|
|
|
t.conns[proto] = nil
|
|
|
|
// now, the connections being passed to closeConns() are not reachable from
|
|
|
|
// transport methods anymore. So, it's safe to close them in a separate goroutine
|
|
|
|
go closeConns(stack)
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
if stack[0].used.After(staleTime) {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
// connections in stack are sorted by "used"
|
|
|
|
good := sort.Search(len(stack), func(i int) bool {
|
|
|
|
return stack[i].used.After(staleTime)
|
|
|
|
})
|
|
|
|
t.conns[proto] = stack[good:]
|
|
|
|
// now, the connections being passed to closeConns() are not reachable from
|
|
|
|
// transport methods anymore. So, it's safe to close them in a separate goroutine
|
|
|
|
go closeConns(stack[:good])
|
|
|
|
}
|
|
|
|
}
|