// Source: badguardhome/internal/dnsforward/recursiondetector_test.go
//
// Listed as 155 lines, 2.8 KiB, Go.

package dnsforward
import (
"encoding/binary"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
)
// TestRecursionDetector_Check seeds a detector with two entries and verifies
// which incoming messages check reports as recurrent.
//
// One entry is written straight into recentRequests with an expiry 100 hours
// in the future; the other goes through add on a detector constructed with a
// zero first argument and is expected to be treated as already expired.
func TestRecursionDetector_Check(t *testing.T) {
	const (
		recID    = 1234
		nonRecID = recID * 2
		recTTL   = time.Hour * 100
	)

	detector := newRecursionDetector(0, 2)

	question := dns.Question{
		Name:  "some.domain",
		Qtype: dns.TypeAAAA,
	}
	seed := dns.Msg{
		MsgHdr: dns.MsgHdr{
			Id: recID,
		},
		Question: []dns.Question{question},
	}

	// Bypass add and store the signature directly so the entry carries a
	// far-future expiration timestamp encoded as big-endian nanoseconds.
	deadline := time.Now().Add(recTTL).UnixNano()
	expiry := make([]byte, uint64sz)
	binary.BigEndian.PutUint64(expiry, uint64(deadline))
	detector.recentRequests.Set(msgToSignature(seed), expiry)

	// Register the same question under another ID through the regular path;
	// the "expired" case below expects this entry not to be detected.
	seed.Id = nonRecID
	detector.add(seed)

	tests := []struct {
		name      string
		questions []dns.Question
		id        uint16
		want      bool
	}{{
		name:      "recurrent",
		questions: []dns.Question{question},
		id:        recID,
		want:      true,
	}, {
		name:      "not_suspected",
		questions: []dns.Question{question},
		id:        recID + 1,
		want:      false,
	}, {
		name:      "expired",
		questions: []dns.Question{question},
		id:        nonRecID,
		want:      false,
	}, {
		name:      "empty",
		questions: []dns.Question{},
		id:        nonRecID,
		want:      false,
	}}

	for _, tc := range tests {
		// Only the ID and the question section vary between cases.
		probe := dns.Msg{
			MsgHdr: dns.MsgHdr{
				Id: tc.id,
			},
			Question: tc.questions,
		}

		t.Run(tc.name, func(t *testing.T) {
			assert.Equal(t, tc.want, detector.check(probe))
		})
	}
}
// TestRecursionDetector_Suspect checks how many entries recentRequests holds
// after add is called: exactly one for a message carrying a question, and
// none for a zero-value message.
func TestRecursionDetector_Suspect(t *testing.T) {
	detector := newRecursionDetector(0, 1)

	tests := []struct {
		name string
		msg  dns.Msg
		want bool
	}{{
		name: "simple",
		msg: dns.Msg{
			MsgHdr: dns.MsgHdr{
				Id: 1234,
			},
			Question: []dns.Question{{
				Name:  "some.domain",
				Qtype: dns.TypeA,
			}},
		},
		want: true,
	}, {
		name: "unencumbered",
		msg:  dns.Msg{},
		want: false,
	}}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// The detector is shared between cases, so wipe it once the
			// subtest finishes.
			t.Cleanup(detector.clear)

			detector.add(tc.msg)

			stored := detector.recentRequests.Stats().Count
			if !tc.want {
				assert.Zero(t, stored)

				return
			}

			assert.Equal(t, 1, stored)
		})
	}
}
// sink receives the signatures produced inside the benchmark loops.
// NOTE(review): presumably a package-level variable so that the compiler
// cannot discard the benchmarked calls as dead code — confirm.
var sink []byte
// BenchmarkMsgToSignature measures msgToSignature against msgToSignatureSlow
// on the same single-question message, reporting allocations for both.
func BenchmarkMsgToSignature(b *testing.B) {
	query := dns.Msg{
		MsgHdr: dns.MsgHdr{
			Id: 1234,
		},
		Question: []dns.Question{{
			Name:  "some.not.very.long.host.name",
			Qtype: dns.TypeAAAA,
		}},
	}

	b.Run("efficient", func(b *testing.B) {
		b.ReportAllocs()

		for n := 0; n < b.N; n++ {
			sink = msgToSignature(query)
		}

		// Sanity check that a signature was actually produced.
		assert.NotEmpty(b, sink)
	})

	b.Run("inefficient", func(b *testing.B) {
		b.ReportAllocs()

		for n := 0; n < b.N; n++ {
			sink = msgToSignatureSlow(query)
		}

		// Sanity check that a signature was actually produced.
		assert.NotEmpty(b, sink)
	})
}