177 lines
4.9 KiB
Go
177 lines
4.9 KiB
Go
package querylog
|
|
|
|
import (
|
|
"net"
|
|
"os"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
|
"github.com/miekg/dns"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
// prepareTestDir wipes and re-creates the "./agh-test" scratch directory and
// returns its path. Removal/creation errors are intentionally ignored: this is
// best-effort test setup, and any real problem shows up once the directory is
// used.
func prepareTestDir() string {
	const testDir = "./agh-test"

	// Drop leftovers from a previous run, then make sure the directory exists.
	_ = os.RemoveAll(testDir)
	_ = os.MkdirAll(testDir, 0755)

	return testDir
}
|
|
|
|
// Check adding and loading (with filtering) entries from disk and memory
|
|
func TestQueryLog(t *testing.T) {
|
|
conf := Config{
|
|
Enabled: true,
|
|
Interval: 1,
|
|
MemSize: 100,
|
|
}
|
|
conf.BaseDir = prepareTestDir()
|
|
defer func() { _ = os.RemoveAll(conf.BaseDir) }()
|
|
l := newQueryLog(conf)
|
|
|
|
// add disk entries
|
|
addEntry(l, "example.org", "1.1.1.1", "2.2.2.1")
|
|
// write to disk (first file)
|
|
_ = l.flushLogBuffer(true)
|
|
// start writing to the second file
|
|
_ = l.rotate()
|
|
// add disk entries
|
|
addEntry(l, "example.org", "1.1.1.2", "2.2.2.2")
|
|
// write to disk
|
|
_ = l.flushLogBuffer(true)
|
|
// add memory entries
|
|
addEntry(l, "test.example.org", "1.1.1.3", "2.2.2.3")
|
|
addEntry(l, "example.com", "1.1.1.4", "2.2.2.4")
|
|
|
|
// get all entries
|
|
params := getDataParams{
|
|
OlderThan: time.Time{},
|
|
}
|
|
d := l.getData(params)
|
|
mdata := d["data"].([]map[string]interface{})
|
|
assert.Equal(t, 4, len(mdata))
|
|
assert.True(t, checkEntry(t, mdata[0], "example.com", "1.1.1.4", "2.2.2.4"))
|
|
assert.True(t, checkEntry(t, mdata[1], "test.example.org", "1.1.1.3", "2.2.2.3"))
|
|
assert.True(t, checkEntry(t, mdata[2], "example.org", "1.1.1.2", "2.2.2.2"))
|
|
assert.True(t, checkEntry(t, mdata[3], "example.org", "1.1.1.1", "2.2.2.1"))
|
|
|
|
// search by domain (strict)
|
|
params = getDataParams{
|
|
OlderThan: time.Time{},
|
|
Domain: "test.example.org",
|
|
StrictMatchDomain: true,
|
|
}
|
|
d = l.getData(params)
|
|
mdata = d["data"].([]map[string]interface{})
|
|
assert.True(t, len(mdata) == 1)
|
|
assert.True(t, checkEntry(t, mdata[0], "test.example.org", "1.1.1.3", "2.2.2.3"))
|
|
|
|
// search by domain (not strict)
|
|
params = getDataParams{
|
|
OlderThan: time.Time{},
|
|
Domain: "example.org",
|
|
StrictMatchDomain: false,
|
|
}
|
|
d = l.getData(params)
|
|
mdata = d["data"].([]map[string]interface{})
|
|
assert.Equal(t, 3, len(mdata))
|
|
assert.True(t, checkEntry(t, mdata[0], "test.example.org", "1.1.1.3", "2.2.2.3"))
|
|
assert.True(t, checkEntry(t, mdata[1], "example.org", "1.1.1.2", "2.2.2.2"))
|
|
assert.True(t, checkEntry(t, mdata[2], "example.org", "1.1.1.1", "2.2.2.1"))
|
|
|
|
// search by client IP (strict)
|
|
params = getDataParams{
|
|
OlderThan: time.Time{},
|
|
Client: "2.2.2.2",
|
|
StrictMatchClient: true,
|
|
}
|
|
d = l.getData(params)
|
|
mdata = d["data"].([]map[string]interface{})
|
|
assert.Equal(t, 1, len(mdata))
|
|
assert.True(t, checkEntry(t, mdata[0], "example.org", "1.1.1.2", "2.2.2.2"))
|
|
|
|
// search by client IP (part of)
|
|
params = getDataParams{
|
|
OlderThan: time.Time{},
|
|
Client: "2.2.2",
|
|
StrictMatchClient: false,
|
|
}
|
|
d = l.getData(params)
|
|
mdata = d["data"].([]map[string]interface{})
|
|
assert.Equal(t, 4, len(mdata))
|
|
assert.True(t, checkEntry(t, mdata[0], "example.com", "1.1.1.4", "2.2.2.4"))
|
|
assert.True(t, checkEntry(t, mdata[1], "test.example.org", "1.1.1.3", "2.2.2.3"))
|
|
assert.True(t, checkEntry(t, mdata[2], "example.org", "1.1.1.2", "2.2.2.2"))
|
|
assert.True(t, checkEntry(t, mdata[3], "example.org", "1.1.1.1", "2.2.2.1"))
|
|
}
|
|
|
|
func addEntry(l *queryLog, host, answerStr, client string) {
|
|
q := dns.Msg{}
|
|
q.Question = append(q.Question, dns.Question{
|
|
Name: host + ".",
|
|
Qtype: dns.TypeA,
|
|
Qclass: dns.ClassINET,
|
|
})
|
|
|
|
a := dns.Msg{}
|
|
a.Question = append(a.Question, q.Question[0])
|
|
answer := new(dns.A)
|
|
answer.Hdr = dns.RR_Header{
|
|
Name: q.Question[0].Name,
|
|
Rrtype: dns.TypeA,
|
|
Class: dns.ClassINET,
|
|
}
|
|
answer.A = net.ParseIP(answerStr)
|
|
a.Answer = append(a.Answer, answer)
|
|
res := dnsfilter.Result{}
|
|
params := AddParams{
|
|
Question: &q,
|
|
Answer: &a,
|
|
Result: &res,
|
|
ClientIP: net.ParseIP(client),
|
|
Upstream: "upstream",
|
|
}
|
|
l.Add(params)
|
|
}
|
|
|
|
func checkEntry(t *testing.T, m map[string]interface{}, host, answer, client string) bool {
|
|
mq := m["question"].(map[string]interface{})
|
|
ma := m["answer"].([]map[string]interface{})
|
|
ma0 := ma[0]
|
|
if !assert.Equal(t, host, mq["host"].(string)) ||
|
|
!assert.Equal(t, "IN", mq["class"].(string)) ||
|
|
!assert.Equal(t, "A", mq["type"].(string)) ||
|
|
!assert.Equal(t, answer, ma0["value"].(string)) ||
|
|
!assert.Equal(t, client, m["client"].(string)) {
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
func TestJSON(t *testing.T) {
|
|
s := `
|
|
{"keystr":"val","obj":{"keybool":true,"keyint":123456}}
|
|
`
|
|
k, v, jtype := readJSON(&s)
|
|
assert.Equal(t, jtype, int32(jsonTStr))
|
|
assert.Equal(t, "keystr", k)
|
|
assert.Equal(t, "val", v)
|
|
|
|
k, v, jtype = readJSON(&s)
|
|
assert.Equal(t, jtype, int32(jsonTObj))
|
|
assert.Equal(t, "obj", k)
|
|
|
|
k, v, jtype = readJSON(&s)
|
|
assert.Equal(t, jtype, int32(jsonTBool))
|
|
assert.Equal(t, "keybool", k)
|
|
assert.Equal(t, "true", v)
|
|
|
|
k, v, jtype = readJSON(&s)
|
|
assert.Equal(t, jtype, int32(jsonTNum))
|
|
assert.Equal(t, "keyint", k)
|
|
assert.Equal(t, "123456", v)
|
|
|
|
k, v, jtype = readJSON(&s)
|
|
assert.True(t, jtype == jsonTErr)
|
|
}
|