Pull request: 3185 detecting recursion
Merge in DNS/adguard-home from 3185-recursion to master
Closes #3185.
Squashed commit of the following:
commit 2fa44223f533c471f2b8c0e17d8550bf4ff73c7b
Merge: 7975957c 7a48e92e
Author: Eugene Burkov <e.burkov@adguard.com>
Date: Thu May 27 19:04:44 2021 +0300
Merge branch 'master' into 3185-recursion
commit 7975957cceb840f76eef0e2e434f4163a122ac34
Author: Eugene Burkov <e.burkov@adguard.com>
Date: Thu May 27 17:36:22 2021 +0300
dnsforward: imp docs
commit 1af7131a5b7c1fefed2d1eb8ee24ebfd3602dc77
Author: Eugene Burkov <e.burkov@adguard.com>
Date: Thu May 27 17:15:00 2021 +0300
dnsforward: imp code, tests, docs
commit f3f9145fb5e1174fab87ca6890da9df722cfebf0
Author: Eugene Burkov <e.burkov@adguard.com>
Date: Thu May 27 15:45:44 2021 +0300
dnsforward: add recursion detector
This commit is contained in:
parent
7a48e92e4d
commit
48b8579703
|
@ -15,6 +15,7 @@ and this project adheres to
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
|
||||||
|
- Detection and handling of recurrent requests ([#3185]).
|
||||||
- The ability to completely disable reverse DNS resolving of IPs from
|
- The ability to completely disable reverse DNS resolving of IPs from
|
||||||
locally-served networks ([#3184]).
|
locally-served networks ([#3184]).
|
||||||
- New flag `--local-frontend` to serve dinamically changeable frontend files
|
- New flag `--local-frontend` to serve dinamically changeable frontend files
|
||||||
|
@ -38,6 +39,7 @@ released by then.
|
||||||
- Go 1.15 support.
|
- Go 1.15 support.
|
||||||
|
|
||||||
[#3184]: https://github.com/AdguardTeam/AdGuardHome/issues/3184
|
[#3184]: https://github.com/AdguardTeam/AdGuardHome/issues/3184
|
||||||
|
[#3185]: https://github.com/AdguardTeam/AdGuardHome/issues/3185
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -42,11 +42,11 @@ func ValidateHardwareAddress(hwa net.HardwareAddr) (err error) {
|
||||||
// according to RFC 1035.
|
// according to RFC 1035.
|
||||||
const maxDomainLabelLen = 63
|
const maxDomainLabelLen = 63
|
||||||
|
|
||||||
// maxDomainNameLen is the maximum allowed length of a full domain name
|
// MaxDomainNameLen is the maximum allowed length of a full domain name
|
||||||
// according to RFC 1035.
|
// according to RFC 1035.
|
||||||
//
|
//
|
||||||
// See https://stackoverflow.com/a/32294443/1892060.
|
// See https://stackoverflow.com/a/32294443/1892060.
|
||||||
const maxDomainNameLen = 253
|
const MaxDomainNameLen = 253
|
||||||
|
|
||||||
// ValidateDomainNameLabel returns an error if label is not a valid label of
|
// ValidateDomainNameLabel returns an error if label is not a valid label of
|
||||||
// a domain name.
|
// a domain name.
|
||||||
|
@ -97,8 +97,8 @@ func ValidateDomainName(name string) (err error) {
|
||||||
l := len(name)
|
l := len(name)
|
||||||
if l == 0 {
|
if l == 0 {
|
||||||
return errors.Error("domain name is empty")
|
return errors.Error("domain name is empty")
|
||||||
} else if l > maxDomainNameLen {
|
} else if l > MaxDomainNameLen {
|
||||||
return fmt.Errorf("too long, max: %d", maxDomainNameLen)
|
return fmt.Errorf("too long, max: %d", MaxDomainNameLen)
|
||||||
}
|
}
|
||||||
|
|
||||||
labels := strings.Split(name, ".")
|
labels := strings.Split(name, ".")
|
||||||
|
|
|
@ -82,6 +82,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
|
||||||
// (*proxy.Proxy).handleDNSRequest method performs it before calling the
|
// (*proxy.Proxy).handleDNSRequest method performs it before calling the
|
||||||
// appropriate handler.
|
// appropriate handler.
|
||||||
mods := []modProcessFunc{
|
mods := []modProcessFunc{
|
||||||
|
s.processRecursion,
|
||||||
processInitial,
|
processInitial,
|
||||||
s.processDetermineLocal,
|
s.processDetermineLocal,
|
||||||
s.processInternalHosts,
|
s.processInternalHosts,
|
||||||
|
@ -90,7 +91,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
|
||||||
processClientID,
|
processClientID,
|
||||||
processFilteringBeforeRequest,
|
processFilteringBeforeRequest,
|
||||||
s.processLocalPTR,
|
s.processLocalPTR,
|
||||||
processUpstream,
|
s.processUpstream,
|
||||||
processDNSSECAfterResponse,
|
processDNSSECAfterResponse,
|
||||||
processFilteringAfterResponse,
|
processFilteringAfterResponse,
|
||||||
s.ipset.process,
|
s.ipset.process,
|
||||||
|
@ -116,6 +117,22 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// processRecursion checks the incoming request and halts it's handling if s
|
||||||
|
// have tried to resolve it recently.
|
||||||
|
func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) {
|
||||||
|
pctx := dctx.proxyCtx
|
||||||
|
|
||||||
|
if msg := pctx.Req; msg != nil && s.recDetector.check(*msg) {
|
||||||
|
log.Debug("recursion detected resolving %q", msg.Question[0].Name)
|
||||||
|
pctx.Res = s.genNXDomain(pctx.Req)
|
||||||
|
|
||||||
|
return resultCodeFinish
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return resultCodeSuccess
|
||||||
|
}
|
||||||
|
|
||||||
// Perform initial checks; process WHOIS & rDNS
|
// Perform initial checks; process WHOIS & rDNS
|
||||||
func processInitial(ctx *dnsContext) (rc resultCode) {
|
func processInitial(ctx *dnsContext) (rc resultCode) {
|
||||||
s := ctx.srv
|
s := ctx.srv
|
||||||
|
@ -422,6 +439,7 @@ func (s *Server) processLocalPTR(ctx *dnsContext) (rc resultCode) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.conf.UsePrivateRDNS {
|
if s.conf.UsePrivateRDNS {
|
||||||
|
s.recDetector.add(*d.Req)
|
||||||
if err := s.localResolvers.Resolve(d); err != nil {
|
if err := s.localResolvers.Resolve(d); err != nil {
|
||||||
ctx.err = err
|
ctx.err = err
|
||||||
|
|
||||||
|
@ -472,8 +490,7 @@ func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// processUpstream passes request to upstream servers and handles the response.
|
// processUpstream passes request to upstream servers and handles the response.
|
||||||
func processUpstream(ctx *dnsContext) (rc resultCode) {
|
func (s *Server) processUpstream(ctx *dnsContext) (rc resultCode) {
|
||||||
s := ctx.srv
|
|
||||||
d := ctx.proxyCtx
|
d := ctx.proxyCtx
|
||||||
if d.Res != nil {
|
if d.Res != nil {
|
||||||
return resultCodeSuccess // response is already set - nothing to do
|
return resultCodeSuccess // response is already set - nothing to do
|
||||||
|
@ -481,18 +498,18 @@ func processUpstream(ctx *dnsContext) (rc resultCode) {
|
||||||
|
|
||||||
if d.Addr != nil && s.conf.GetCustomUpstreamByClient != nil {
|
if d.Addr != nil && s.conf.GetCustomUpstreamByClient != nil {
|
||||||
clientIP := IPStringFromAddr(d.Addr)
|
clientIP := IPStringFromAddr(d.Addr)
|
||||||
upstreamsConf := s.conf.GetCustomUpstreamByClient(clientIP)
|
if upsConf := s.conf.GetCustomUpstreamByClient(clientIP); upsConf != nil {
|
||||||
if upstreamsConf != nil {
|
log.Debug("dns: using custom upstreams for client %s", clientIP)
|
||||||
log.Debug("Using custom upstreams for %s", clientIP)
|
d.CustomUpstreamConfig = upsConf
|
||||||
d.CustomUpstreamConfig = upstreamsConf
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
req := d.Req
|
||||||
if s.conf.EnableDNSSEC {
|
if s.conf.EnableDNSSEC {
|
||||||
opt := d.Req.IsEdns0()
|
opt := req.IsEdns0()
|
||||||
if opt == nil {
|
if opt == nil {
|
||||||
log.Debug("dns: Adding OPT record with DNSSEC flag")
|
log.Debug("dns: adding OPT record with DNSSEC flag")
|
||||||
d.Req.SetEdns0(4096, true)
|
req.SetEdns0(4096, true)
|
||||||
} else if !opt.Do() {
|
} else if !opt.Do() {
|
||||||
opt.SetDo(true)
|
opt.SetDo(true)
|
||||||
} else {
|
} else {
|
||||||
|
@ -501,13 +518,13 @@ func processUpstream(ctx *dnsContext) (rc resultCode) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// request was not filtered so let it be processed further
|
// request was not filtered so let it be processed further
|
||||||
err := s.dnsProxy.Resolve(d)
|
s.recDetector.add(*req)
|
||||||
if err != nil {
|
if ctx.err = s.dnsProxy.Resolve(d); ctx.err != nil {
|
||||||
ctx.err = err
|
|
||||||
return resultCodeError
|
return resultCodeError
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx.responseFromUpstream = true
|
ctx.responseFromUpstream = true
|
||||||
|
|
||||||
return resultCodeSuccess
|
return resultCodeSuccess
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -76,6 +76,7 @@ type Server struct {
|
||||||
ipset ipsetCtx
|
ipset ipsetCtx
|
||||||
subnetDetector *aghnet.SubnetDetector
|
subnetDetector *aghnet.SubnetDetector
|
||||||
localResolvers *proxy.Proxy
|
localResolvers *proxy.Proxy
|
||||||
|
recDetector *recursionDetector
|
||||||
|
|
||||||
tableHostToIP hostToIPTable
|
tableHostToIP hostToIPTable
|
||||||
tableHostToIPLock sync.Mutex
|
tableHostToIPLock sync.Mutex
|
||||||
|
@ -121,6 +122,14 @@ func domainNameToSuffix(tld string) (suffix string) {
|
||||||
return string(b)
|
return string(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// recursionTTL is the time recursive request is cached for.
|
||||||
|
recursionTTL = 5 * time.Second
|
||||||
|
// cachedRecurrentReqNum is the maximum number of cached recurrent
|
||||||
|
// requests.
|
||||||
|
cachedRecurrentReqNum = 1000
|
||||||
|
)
|
||||||
|
|
||||||
// NewServer creates a new instance of the dnsforward.Server
|
// NewServer creates a new instance of the dnsforward.Server
|
||||||
// Note: this function must be called only once
|
// Note: this function must be called only once
|
||||||
func NewServer(p DNSCreateParams) (s *Server, err error) {
|
func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||||
|
@ -142,6 +151,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||||
queryLog: p.QueryLog,
|
queryLog: p.QueryLog,
|
||||||
subnetDetector: p.SubnetDetector,
|
subnetDetector: p.SubnetDetector,
|
||||||
localDomainSuffix: localDomainSuffix,
|
localDomainSuffix: localDomainSuffix,
|
||||||
|
recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum),
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.DHCPServer != nil {
|
if p.DHCPServer != nil {
|
||||||
|
@ -160,7 +170,9 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||||
|
|
||||||
// NewCustomServer creates a new instance of *Server with custom internal proxy.
|
// NewCustomServer creates a new instance of *Server with custom internal proxy.
|
||||||
func NewCustomServer(internalProxy *proxy.Proxy) *Server {
|
func NewCustomServer(internalProxy *proxy.Proxy) *Server {
|
||||||
s := &Server{}
|
s := &Server{
|
||||||
|
recDetector: newRecursionDetector(0, 1),
|
||||||
|
}
|
||||||
if internalProxy != nil {
|
if internalProxy != nil {
|
||||||
s.internalProxy = internalProxy
|
s.internalProxy = internalProxy
|
||||||
}
|
}
|
||||||
|
@ -278,14 +290,13 @@ func (s *Server) Exchange(ip net.IP) (host string, err error) {
|
||||||
Req: req,
|
Req: req,
|
||||||
StartTime: time.Now(),
|
StartTime: time.Now(),
|
||||||
}
|
}
|
||||||
var resp *dns.Msg
|
|
||||||
err = resolver.Resolve(ctx)
|
s.recDetector.add(*req)
|
||||||
if err != nil {
|
if err = resolver.Resolve(ctx); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
resp = ctx.Res
|
resp := ctx.Res
|
||||||
|
|
||||||
if len(resp.Answer) == 0 {
|
if len(resp.Answer) == 0 {
|
||||||
return "", fmt.Errorf("lookup for %q: %w", arpa, rDNSEmptyAnswerErr)
|
return "", fmt.Errorf("lookup for %q: %w", arpa, rDNSEmptyAnswerErr)
|
||||||
}
|
}
|
||||||
|
@ -490,6 +501,8 @@ func (s *Server) Prepare(config *ServerConfig) error {
|
||||||
return fmt.Errorf("setting up resolvers: %w", err)
|
return fmt.Errorf("setting up resolvers: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.recDetector.clear()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,115 @@
|
||||||
|
package dnsforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
|
"github.com/AdguardTeam/golibs/cache"
|
||||||
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
// uint* sizes in bytes to improve readability.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): Remove when there will be a more regardful way to define
|
||||||
|
// those. See https://github.com/golang/go/issues/29982.
|
||||||
|
const (
|
||||||
|
uint16sz = 2
|
||||||
|
uint64sz = 8
|
||||||
|
)
|
||||||
|
|
||||||
|
// recursionDetector detects recursion in DNS forwarding.
|
||||||
|
type recursionDetector struct {
|
||||||
|
recentRequests cache.Cache
|
||||||
|
ttl time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// check checks if the passed req was already sent by s.
|
||||||
|
func (rd *recursionDetector) check(msg dns.Msg) (ok bool) {
|
||||||
|
if len(msg.Question) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
key := msgToSignature(msg)
|
||||||
|
expireData := rd.recentRequests.Get(key)
|
||||||
|
if expireData == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
expire := time.Unix(0, int64(binary.BigEndian.Uint64(expireData)))
|
||||||
|
|
||||||
|
return time.Now().Before(expire)
|
||||||
|
}
|
||||||
|
|
||||||
|
// add caches the msg if it has anything in the questions section.
|
||||||
|
func (rd *recursionDetector) add(msg dns.Msg) {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
if len(msg.Question) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
key := msgToSignature(msg)
|
||||||
|
expire64 := uint64(now.Add(rd.ttl).UnixNano())
|
||||||
|
expire := make([]byte, uint64sz)
|
||||||
|
binary.BigEndian.PutUint64(expire, expire64)
|
||||||
|
|
||||||
|
rd.recentRequests.Set(key, expire)
|
||||||
|
}
|
||||||
|
|
||||||
|
// clear clears the recent requests cache.
|
||||||
|
func (rd *recursionDetector) clear() {
|
||||||
|
rd.recentRequests.Clear()
|
||||||
|
}
|
||||||
|
|
||||||
|
// newRecursionDetector returns the initialized *recursionDetector.
|
||||||
|
func newRecursionDetector(ttl time.Duration, suspectsNum uint) (rd *recursionDetector) {
|
||||||
|
return &recursionDetector{
|
||||||
|
recentRequests: cache.New(cache.Config{
|
||||||
|
EnableLRU: true,
|
||||||
|
MaxCount: suspectsNum,
|
||||||
|
}),
|
||||||
|
ttl: ttl,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// msgToSignature converts msg into it's signature represented in bytes.
|
||||||
|
func msgToSignature(msg dns.Msg) (sig []byte) {
|
||||||
|
sig = make([]byte, uint16sz*2+aghnet.MaxDomainNameLen)
|
||||||
|
// The binary.BigEndian byte order is used everywhere except when the
|
||||||
|
// real machine's endianess is needed.
|
||||||
|
byteOrder := binary.BigEndian
|
||||||
|
byteOrder.PutUint16(sig[0:], msg.Id)
|
||||||
|
q := msg.Question[0]
|
||||||
|
byteOrder.PutUint16(sig[uint16sz:], q.Qtype)
|
||||||
|
copy(sig[2*uint16sz:], []byte(q.Name))
|
||||||
|
|
||||||
|
return sig
|
||||||
|
}
|
||||||
|
|
||||||
|
// msgToSignatureSlow converts msg into it's signature represented in bytes in
|
||||||
|
// the less efficient way.
|
||||||
|
//
|
||||||
|
// See BenchmarkMsgToSignature.
|
||||||
|
func msgToSignatureSlow(msg dns.Msg) (sig []byte) {
|
||||||
|
type msgSignature struct {
|
||||||
|
name [aghnet.MaxDomainNameLen]byte
|
||||||
|
id uint16
|
||||||
|
qtype uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
b := bytes.NewBuffer(sig)
|
||||||
|
q := msg.Question[0]
|
||||||
|
signature := msgSignature{
|
||||||
|
id: msg.Id,
|
||||||
|
qtype: q.Qtype,
|
||||||
|
}
|
||||||
|
copy(signature.name[:], q.Name)
|
||||||
|
if err := binary.Write(b, binary.BigEndian, signature); err != nil {
|
||||||
|
log.Debug("writing message signature: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.Bytes()
|
||||||
|
}
|
|
@ -0,0 +1,154 @@
|
||||||
|
package dnsforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRecursionDetector_Check(t *testing.T) {
|
||||||
|
rd := newRecursionDetector(0, 2)
|
||||||
|
|
||||||
|
const (
|
||||||
|
recID = 1234
|
||||||
|
recTTL = time.Hour * 100
|
||||||
|
)
|
||||||
|
|
||||||
|
const nonRecID = recID * 2
|
||||||
|
|
||||||
|
sampleQuestion := dns.Question{
|
||||||
|
Name: "some.domain",
|
||||||
|
Qtype: dns.TypeAAAA,
|
||||||
|
}
|
||||||
|
sampleMsg := dns.Msg{
|
||||||
|
MsgHdr: dns.MsgHdr{
|
||||||
|
Id: recID,
|
||||||
|
},
|
||||||
|
Question: []dns.Question{sampleQuestion},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Manually add the message with big ttl.
|
||||||
|
key := msgToSignature(sampleMsg)
|
||||||
|
expire := make([]byte, uint64sz)
|
||||||
|
binary.BigEndian.PutUint64(expire, uint64(time.Now().Add(recTTL).UnixNano()))
|
||||||
|
rd.recentRequests.Set(key, expire)
|
||||||
|
|
||||||
|
// Add an expired message.
|
||||||
|
sampleMsg.Id = nonRecID
|
||||||
|
rd.add(sampleMsg)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
questions []dns.Question
|
||||||
|
id uint16
|
||||||
|
want bool
|
||||||
|
}{{
|
||||||
|
name: "recurrent",
|
||||||
|
questions: []dns.Question{sampleQuestion},
|
||||||
|
id: recID,
|
||||||
|
want: true,
|
||||||
|
}, {
|
||||||
|
name: "not_suspected",
|
||||||
|
questions: []dns.Question{sampleQuestion},
|
||||||
|
id: recID + 1,
|
||||||
|
want: false,
|
||||||
|
}, {
|
||||||
|
name: "expired",
|
||||||
|
questions: []dns.Question{sampleQuestion},
|
||||||
|
id: nonRecID,
|
||||||
|
want: false,
|
||||||
|
}, {
|
||||||
|
name: "empty",
|
||||||
|
questions: []dns.Question{},
|
||||||
|
id: nonRecID,
|
||||||
|
want: false,
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
sampleMsg.Id = tc.id
|
||||||
|
sampleMsg.Question = tc.questions
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
detected := rd.check(sampleMsg)
|
||||||
|
assert.Equal(t, tc.want, detected)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecursionDetector_Suspect(t *testing.T) {
|
||||||
|
rd := newRecursionDetector(0, 1)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
msg dns.Msg
|
||||||
|
want bool
|
||||||
|
}{{
|
||||||
|
name: "simple",
|
||||||
|
msg: dns.Msg{
|
||||||
|
MsgHdr: dns.MsgHdr{
|
||||||
|
Id: 1234,
|
||||||
|
},
|
||||||
|
Question: []dns.Question{{
|
||||||
|
Name: "some.domain",
|
||||||
|
Qtype: dns.TypeA,
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
}, {
|
||||||
|
name: "unencumbered",
|
||||||
|
msg: dns.Msg{},
|
||||||
|
want: false,
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Cleanup(rd.clear)
|
||||||
|
|
||||||
|
rd.add(tc.msg)
|
||||||
|
|
||||||
|
if tc.want {
|
||||||
|
assert.Equal(t, 1, rd.recentRequests.Stats().Count)
|
||||||
|
} else {
|
||||||
|
assert.Zero(t, rd.recentRequests.Stats().Count)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var sink []byte
|
||||||
|
|
||||||
|
func BenchmarkMsgToSignature(b *testing.B) {
|
||||||
|
const name = "some.not.very.long.host.name"
|
||||||
|
|
||||||
|
msg := dns.Msg{
|
||||||
|
MsgHdr: dns.MsgHdr{
|
||||||
|
Id: 1234,
|
||||||
|
},
|
||||||
|
Question: []dns.Question{{
|
||||||
|
Name: name,
|
||||||
|
Qtype: dns.TypeAAAA,
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
b.Run("efficient", func(b *testing.B) {
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
sink = msgToSignature(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NotEmpty(b, sink)
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("inefficient", func(b *testing.B) {
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
sink = msgToSignatureSlow(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NotEmpty(b, sink)
|
||||||
|
})
|
||||||
|
}
|
Loading…
Reference in New Issue