Merge: *(querylog): added offset/limit parameters

* commit '61a24ffc71eda26e10a11e3a9c8506909b6c4b52':
  *(querylog): decode methods to a new file
  *(querylog): added offset/limit parameters
This commit is contained in:
Andrey Meshkov 2020-05-27 11:40:15 +03:00
commit 16a6aada6f
14 changed files with 1029 additions and 776 deletions

View File

@ -16,12 +16,14 @@ module.exports = {
], ],
"scopes": [ "scopes": [
"", "",
"ui",
"global", "global",
"dnsfilter", "dnsfilter",
"home", "home",
"dnsforward", "dnsforward",
"dhcpd", "dhcpd",
"documentation" "querylog",
"documentation",
], ],
"types": { "types": {
"+": { "+": {

View File

@ -1,5 +1,12 @@
# AdGuard Home API Change Log # AdGuard Home API Change Log
## v0.103: API changes
### API: Get querylog: GET /control/querylog
* Added optional "offset" and "limit" parameters
We are still using "older_than" approach in AdGuard Home UI, but we realize that it's easier to use offset/limit so here is this option now.
## v0.102: API changes ## v0.102: API changes

View File

@ -143,13 +143,26 @@ paths:
tags: tags:
- log - log
operationId: queryLog operationId: queryLog
summary: Get DNS server query log summary: Get DNS server query log.
parameters: parameters:
- name: older_than - name: older_than
in: query in: query
description: Filter by older than description: Filter by older than
schema: schema:
type: string type: string
- name: offset
in: query
description:
Specify the ranking number of the first item on the page.
Even though it is possible to use "offset" and "older_than",
we recommend choosing one of them and sticking to it.
schema:
type: integer
- name: limit
in: query
description: Limit the number of records to be returned
schema:
type: integer
- name: filter_domain - name: filter_domain
in: query in: query
description: Filter by domain name description: Filter by domain name

175
querylog/decode.go Normal file
View File

@ -0,0 +1,175 @@
package querylog
import (
"encoding/base64"
"strconv"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
// decodeLogEntry - decodes query log entry from a line
// nolint (gocyclo)
func decodeLogEntry(ent *logEntry, str string) {
var b bool
var i int
var err error
for {
k, v, t := readJSON(&str)
if t == jsonTErr {
break
}
switch k {
case "IP":
if len(ent.IP) == 0 {
ent.IP = v
}
case "T":
ent.Time, err = time.Parse(time.RFC3339, v)
case "QH":
ent.QHost = v
case "QT":
ent.QType = v
case "QC":
ent.QClass = v
case "Answer":
ent.Answer, err = base64.StdEncoding.DecodeString(v)
case "OrigAnswer":
ent.OrigAnswer, err = base64.StdEncoding.DecodeString(v)
case "IsFiltered":
b, err = strconv.ParseBool(v)
ent.Result.IsFiltered = b
case "Rule":
ent.Result.Rule = v
case "FilterID":
i, err = strconv.Atoi(v)
ent.Result.FilterID = int64(i)
case "Reason":
i, err = strconv.Atoi(v)
ent.Result.Reason = dnsfilter.Reason(i)
case "Upstream":
ent.Upstream = v
case "Elapsed":
i, err = strconv.Atoi(v)
ent.Elapsed = time.Duration(i)
// pre-v0.99.3 compatibility:
case "Question":
var qstr []byte
qstr, err = base64.StdEncoding.DecodeString(v)
if err != nil {
break
}
q := new(dns.Msg)
err = q.Unpack(qstr)
if err != nil {
break
}
ent.QHost = q.Question[0].Name
if len(ent.QHost) == 0 {
break
}
ent.QHost = ent.QHost[:len(ent.QHost)-1]
ent.QType = dns.TypeToString[q.Question[0].Qtype]
ent.QClass = dns.ClassToString[q.Question[0].Qclass]
case "Time":
ent.Time, err = time.Parse(time.RFC3339, v)
}
if err != nil {
log.Debug("decodeLogEntry err: %s", err)
break
}
}
}
// Get value from "key":"value"
func readJSONValue(s, name string) string {
i := strings.Index(s, "\""+name+"\":\"")
if i == -1 {
return ""
}
start := i + 1 + len(name) + 3
i = strings.IndexByte(s[start:], '"')
if i == -1 {
return ""
}
end := start + i
return s[start:end]
}
const (
jsonTErr = iota
jsonTObj
jsonTStr
jsonTNum
jsonTBool
)
// Parse JSON key-value pair
// e.g.: "key":VALUE where VALUE is "string", true|false (boolean), or 123.456 (number)
// Note the limitations:
// . doesn't support whitespace
// . doesn't support "null"
// . doesn't validate boolean or number
// . no proper handling of {} braces
// . no handling of [] brackets
// Return (key, value, type)
func readJSON(ps *string) (string, string, int32) {
s := *ps
k := ""
v := ""
t := int32(jsonTErr)
q1 := strings.IndexByte(s, '"')
if q1 == -1 {
return k, v, t
}
q2 := strings.IndexByte(s[q1+1:], '"')
if q2 == -1 {
return k, v, t
}
k = s[q1+1 : q1+1+q2]
s = s[q1+1+q2+1:]
if len(s) < 2 || s[0] != ':' {
return k, v, t
}
if s[1] == '"' {
q2 = strings.IndexByte(s[2:], '"')
if q2 == -1 {
return k, v, t
}
v = s[2 : 2+q2]
t = jsonTStr
s = s[2+q2+1:]
} else if s[1] == '{' {
t = jsonTObj
s = s[1+1:]
} else {
sep := strings.IndexAny(s[1:], ",}")
if sep == -1 {
return k, v, t
}
v = s[1 : 1+sep]
if s[1] == 't' || s[1] == 'f' {
t = jsonTBool
} else if s[1] == '.' || (s[1] >= '0' && s[1] <= '9') {
t = jsonTNum
}
s = s[1+sep+1:]
}
*ps = s
return k, v, t
}

34
querylog/decode_test.go Normal file
View File

@ -0,0 +1,34 @@
package querylog
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestJSON(t *testing.T) {
s := `
{"keystr":"val","obj":{"keybool":true,"keyint":123456}}
`
k, v, jtype := readJSON(&s)
assert.Equal(t, jtype, int32(jsonTStr))
assert.Equal(t, "keystr", k)
assert.Equal(t, "val", v)
k, v, jtype = readJSON(&s)
assert.Equal(t, jtype, int32(jsonTObj))
assert.Equal(t, "obj", k)
k, v, jtype = readJSON(&s)
assert.Equal(t, jtype, int32(jsonTBool))
assert.Equal(t, "keybool", k)
assert.Equal(t, "true", v)
k, v, jtype = readJSON(&s)
assert.Equal(t, jtype, int32(jsonTNum))
assert.Equal(t, "keyint", k)
assert.Equal(t, "123456", v)
k, v, jtype = readJSON(&s)
assert.True(t, jtype == jsonTErr)
}

164
querylog/json.go Normal file
View File

@ -0,0 +1,164 @@
package querylog
import (
"fmt"
"net"
"strconv"
"time"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
// Get Client IP address
func (l *queryLog) getClientIP(clientIP string) string {
if l.conf.AnonymizeClientIP {
ip := net.ParseIP(clientIP)
if ip != nil {
ip4 := ip.To4()
const AnonymizeClientIP4Mask = 24
const AnonymizeClientIP6Mask = 112
if ip4 != nil {
clientIP = ip4.Mask(net.CIDRMask(AnonymizeClientIP4Mask, 32)).String()
} else {
clientIP = ip.Mask(net.CIDRMask(AnonymizeClientIP6Mask, 128)).String()
}
}
}
return clientIP
}
// entriesToJSON - converts log entries to JSON
func (l *queryLog) entriesToJSON(entries []*logEntry, oldest time.Time) map[string]interface{} {
// init the response object
var data = []map[string]interface{}{}
// the elements order is already reversed (from newer to older)
for i := 0; i < len(entries); i++ {
entry := entries[i]
jsonEntry := l.logEntryToJSONEntry(entry)
data = append(data, jsonEntry)
}
var result = map[string]interface{}{}
result["oldest"] = ""
if !oldest.IsZero() {
result["oldest"] = oldest.Format(time.RFC3339Nano)
}
result["data"] = data
return result
}
func (l *queryLog) logEntryToJSONEntry(entry *logEntry) map[string]interface{} {
var msg *dns.Msg
if len(entry.Answer) > 0 {
msg = new(dns.Msg)
if err := msg.Unpack(entry.Answer); err != nil {
log.Debug("Failed to unpack dns message answer: %s: %s", err, string(entry.Answer))
msg = nil
}
}
jsonEntry := map[string]interface{}{
"reason": entry.Result.Reason.String(),
"elapsedMs": strconv.FormatFloat(entry.Elapsed.Seconds()*1000, 'f', -1, 64),
"time": entry.Time.Format(time.RFC3339Nano),
"client": l.getClientIP(entry.IP),
}
jsonEntry["question"] = map[string]interface{}{
"host": entry.QHost,
"type": entry.QType,
"class": entry.QClass,
}
if msg != nil {
jsonEntry["status"] = dns.RcodeToString[msg.Rcode]
opt := msg.IsEdns0()
dnssecOk := false
if opt != nil {
dnssecOk = opt.Do()
}
jsonEntry["answer_dnssec"] = dnssecOk
}
if len(entry.Result.Rule) > 0 {
jsonEntry["rule"] = entry.Result.Rule
jsonEntry["filterId"] = entry.Result.FilterID
}
if len(entry.Result.ServiceName) != 0 {
jsonEntry["service_name"] = entry.Result.ServiceName
}
answers := answerToMap(msg)
if answers != nil {
jsonEntry["answer"] = answers
}
if len(entry.OrigAnswer) != 0 {
a := new(dns.Msg)
err := a.Unpack(entry.OrigAnswer)
if err == nil {
answers = answerToMap(a)
if answers != nil {
jsonEntry["original_answer"] = answers
}
} else {
log.Debug("Querylog: msg.Unpack(entry.OrigAnswer): %s: %s", err, string(entry.OrigAnswer))
}
}
return jsonEntry
}
func answerToMap(a *dns.Msg) []map[string]interface{} {
if a == nil || len(a.Answer) == 0 {
return nil
}
var answers = []map[string]interface{}{}
for _, k := range a.Answer {
header := k.Header()
answer := map[string]interface{}{
"type": dns.TypeToString[header.Rrtype],
"ttl": header.Ttl,
}
// try most common record types
switch v := k.(type) {
case *dns.A:
answer["value"] = v.A.String()
case *dns.AAAA:
answer["value"] = v.AAAA.String()
case *dns.MX:
answer["value"] = fmt.Sprintf("%v %v", v.Preference, v.Mx)
case *dns.CNAME:
answer["value"] = v.Target
case *dns.NS:
answer["value"] = v.Ns
case *dns.SPF:
answer["value"] = v.Txt
case *dns.TXT:
answer["value"] = v.Txt
case *dns.PTR:
answer["value"] = v.Ptr
case *dns.SOA:
answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v", v.Ns, v.Mbox, v.Serial, v.Refresh, v.Retry, v.Expire, v.Minttl)
case *dns.CAA:
answer["value"] = fmt.Sprintf("%v %v \"%v\"", v.Flag, v.Tag, v.Value)
case *dns.HINFO:
answer["value"] = fmt.Sprintf("\"%v\" \"%v\"", v.Cpu, v.Os)
case *dns.RRSIG:
answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v %v %v", dns.TypeToString[v.TypeCovered], v.Algorithm, v.Labels, v.OrigTtl, v.Expiration, v.Inception, v.KeyTag, v.SignerName, v.Signature)
default:
// type unknown, marshall it as-is
answer["value"] = v
}
answers = append(answers, answer)
}
return answers
}

View File

@ -1,11 +1,8 @@
package querylog package querylog
import ( import (
"fmt"
"net"
"os" "os"
"path/filepath" "path/filepath"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -17,10 +14,6 @@ import (
const ( const (
queryLogFileName = "querylog.json" // .gz added during compression queryLogFileName = "querylog.json" // .gz added during compression
getDataLimit = 500 // GetData(): maximum log entries to return
// maximum entries to parse when searching
maxSearchEntries = 50000
) )
// queryLog is a structure that writes and reads the DNS query log // queryLog is a structure that writes and reads the DNS query log
@ -36,6 +29,23 @@ type queryLog struct {
fileWriteLock sync.Mutex fileWriteLock sync.Mutex
} }
// logEntry - represents a single log entry
type logEntry struct {
IP string `json:"IP"` // Client IP
Time time.Time `json:"T"`
QHost string `json:"QH"`
QType string `json:"QT"`
QClass string `json:"QC"`
Answer []byte `json:",omitempty"` // sometimes empty answers happen like binerdunt.top or rev2.globalrootservers.net
OrigAnswer []byte `json:",omitempty"`
Result dnsfilter.Result
Elapsed time.Duration
Upstream string `json:",omitempty"` // if empty, means it was cached
}
// create a new instance of the query log // create a new instance of the query log
func newQueryLog(conf Config) *queryLog { func newQueryLog(conf Config) *queryLog {
l := queryLog{} l := queryLog{}
@ -93,22 +103,6 @@ func (l *queryLog) clear() {
log.Debug("Query log: cleared") log.Debug("Query log: cleared")
} }
type logEntry struct {
IP string `json:"IP"`
Time time.Time `json:"T"`
QHost string `json:"QH"`
QType string `json:"QT"`
QClass string `json:"QC"`
Answer []byte `json:",omitempty"` // sometimes empty answers happen like binerdunt.top or rev2.globalrootservers.net
OrigAnswer []byte `json:",omitempty"`
Result dnsfilter.Result
Elapsed time.Duration
Upstream string `json:",omitempty"` // if empty, means it was cached
}
func (l *queryLog) Add(params AddParams) { func (l *queryLog) Add(params AddParams) {
if !l.conf.Enabled { if !l.conf.Enabled {
return return
@ -173,230 +167,3 @@ func (l *queryLog) Add(params AddParams) {
go l.flushLogBuffer(false) // nolint go l.flushLogBuffer(false) // nolint
} }
} }
// Parameters for getData()
type getDataParams struct {
OlderThan time.Time // return entries that are older than this value
Domain string // filter by domain name in question
Client string // filter by client IP
QuestionType string // filter by question type
ResponseStatus responseStatusType // filter by response status
StrictMatchDomain bool // if Domain value must be matched strictly
StrictMatchClient bool // if Client value must be matched strictly
}
// Response status
type responseStatusType int32
// Response status constants
const (
responseStatusAll responseStatusType = iota + 1
responseStatusFiltered
)
// Gets log entries
func (l *queryLog) getData(params getDataParams) map[string]interface{} {
now := time.Now()
if len(params.Client) != 0 && l.conf.AnonymizeClientIP {
params.Client = l.getClientIP(params.Client)
}
// add from file
fileEntries, oldest, total := l.searchFiles(params)
if params.OlderThan.IsZero() {
// In case if the timer is not precise (for instance, on Windows)
// We really want to get all records including those added just before the call
params.OlderThan = now.Add(time.Millisecond)
}
// add from memory buffer
l.bufferLock.Lock()
total += len(l.buffer)
memoryEntries := make([]*logEntry, 0)
// go through the buffer in the reverse order
// from NEWER to OLDER
for i := len(l.buffer) - 1; i >= 0; i-- {
entry := l.buffer[i]
if entry.Time.UnixNano() >= params.OlderThan.UnixNano() {
// Ignore entries newer than what was requested
continue
}
if !matchesGetDataParams(entry, params) {
continue
}
memoryEntries = append(memoryEntries, entry)
}
l.bufferLock.Unlock()
// now let's get a unified collection
entries := append(memoryEntries, fileEntries...)
if len(entries) > getDataLimit {
// remove extra records
entries = entries[:getDataLimit]
}
if len(entries) == getDataLimit {
// change the "oldest" value here.
// we cannot use the "oldest" we got from "searchFiles" anymore
// because after adding in-memory records and removing extra records
// the situation has changed
oldest = entries[len(entries)-1].Time
}
// init the response object
var data = []map[string]interface{}{}
// the elements order is already reversed (from newer to older)
for i := 0; i < len(entries); i++ {
entry := entries[i]
jsonEntry := l.logEntryToJSONEntry(entry)
data = append(data, jsonEntry)
}
log.Debug("QueryLog: prepared data (%d/%d) older than %s in %s",
len(entries), total, params.OlderThan, time.Since(now))
var result = map[string]interface{}{}
result["oldest"] = ""
if !oldest.IsZero() {
result["oldest"] = oldest.Format(time.RFC3339Nano)
}
result["data"] = data
return result
}
// Get Client IP address
func (l *queryLog) getClientIP(clientIP string) string {
if l.conf.AnonymizeClientIP {
ip := net.ParseIP(clientIP)
if ip != nil {
ip4 := ip.To4()
const AnonymizeClientIP4Mask = 24
const AnonymizeClientIP6Mask = 112
if ip4 != nil {
clientIP = ip4.Mask(net.CIDRMask(AnonymizeClientIP4Mask, 32)).String()
} else {
clientIP = ip.Mask(net.CIDRMask(AnonymizeClientIP6Mask, 128)).String()
}
}
}
return clientIP
}
func (l *queryLog) logEntryToJSONEntry(entry *logEntry) map[string]interface{} {
var msg *dns.Msg
if len(entry.Answer) > 0 {
msg = new(dns.Msg)
if err := msg.Unpack(entry.Answer); err != nil {
log.Debug("Failed to unpack dns message answer: %s: %s", err, string(entry.Answer))
msg = nil
}
}
jsonEntry := map[string]interface{}{
"reason": entry.Result.Reason.String(),
"elapsedMs": strconv.FormatFloat(entry.Elapsed.Seconds()*1000, 'f', -1, 64),
"time": entry.Time.Format(time.RFC3339Nano),
"client": l.getClientIP(entry.IP),
}
jsonEntry["question"] = map[string]interface{}{
"host": entry.QHost,
"type": entry.QType,
"class": entry.QClass,
}
if msg != nil {
jsonEntry["status"] = dns.RcodeToString[msg.Rcode]
opt := msg.IsEdns0()
dnssecOk := false
if opt != nil {
dnssecOk = opt.Do()
}
jsonEntry["answer_dnssec"] = dnssecOk
}
if len(entry.Result.Rule) > 0 {
jsonEntry["rule"] = entry.Result.Rule
jsonEntry["filterId"] = entry.Result.FilterID
}
if len(entry.Result.ServiceName) != 0 {
jsonEntry["service_name"] = entry.Result.ServiceName
}
answers := answerToMap(msg)
if answers != nil {
jsonEntry["answer"] = answers
}
if len(entry.OrigAnswer) != 0 {
a := new(dns.Msg)
err := a.Unpack(entry.OrigAnswer)
if err == nil {
answers = answerToMap(a)
if answers != nil {
jsonEntry["original_answer"] = answers
}
} else {
log.Debug("Querylog: msg.Unpack(entry.OrigAnswer): %s: %s", err, string(entry.OrigAnswer))
}
}
return jsonEntry
}
func answerToMap(a *dns.Msg) []map[string]interface{} {
if a == nil || len(a.Answer) == 0 {
return nil
}
var answers = []map[string]interface{}{}
for _, k := range a.Answer {
header := k.Header()
answer := map[string]interface{}{
"type": dns.TypeToString[header.Rrtype],
"ttl": header.Ttl,
}
// try most common record types
switch v := k.(type) {
case *dns.A:
answer["value"] = v.A.String()
case *dns.AAAA:
answer["value"] = v.AAAA.String()
case *dns.MX:
answer["value"] = fmt.Sprintf("%v %v", v.Preference, v.Mx)
case *dns.CNAME:
answer["value"] = v.Target
case *dns.NS:
answer["value"] = v.Ns
case *dns.SPF:
answer["value"] = v.Txt
case *dns.TXT:
answer["value"] = v.Txt
case *dns.PTR:
answer["value"] = v.Ptr
case *dns.SOA:
answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v", v.Ns, v.Mbox, v.Serial, v.Refresh, v.Retry, v.Expire, v.Minttl)
case *dns.CAA:
answer["value"] = fmt.Sprintf("%v %v \"%v\"", v.Flag, v.Tag, v.Value)
case *dns.HINFO:
answer["value"] = fmt.Sprintf("\"%v\" \"%v\"", v.Cpu, v.Os)
case *dns.RRSIG:
answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v %v %v", dns.TypeToString[v.TypeCovered], v.Algorithm, v.Labels, v.OrigTtl, v.Expiration, v.Inception, v.KeyTag, v.SignerName, v.Signature)
default:
// type unknown, marshall it as-is
answer["value"] = v
}
answers = append(answers, answer)
}
return answers
}

View File

@ -4,13 +4,30 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"net/url"
"strconv"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/golibs/jsonutil" "github.com/AdguardTeam/golibs/jsonutil"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
) )
type qlogConfig struct {
Enabled bool `json:"enabled"`
Interval uint32 `json:"interval"`
AnonymizeClientIP bool `json:"anonymize_client_ip"`
}
// Register web handlers
func (l *queryLog) initWeb() {
l.conf.HTTPRegister("GET", "/control/querylog", l.handleQueryLog)
l.conf.HTTPRegister("GET", "/control/querylog_info", l.handleQueryLogInfo)
l.conf.HTTPRegister("POST", "/control/querylog_clear", l.handleQueryLogClear)
l.conf.HTTPRegister("POST", "/control/querylog_config", l.handleQueryLogConfig)
}
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) { func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
text := fmt.Sprintf(format, args...) text := fmt.Sprintf(format, args...)
@ -19,74 +36,18 @@ func httpError(r *http.Request, w http.ResponseWriter, code int, format string,
http.Error(w, text, code) http.Error(w, text, code)
} }
type request struct {
olderThan string
filterDomain string
filterClient string
filterQuestionType string
filterResponseStatus string
}
// "value" -> value, return TRUE
func getDoubleQuotesEnclosedValue(s *string) bool {
t := *s
if len(t) >= 2 && t[0] == '"' && t[len(t)-1] == '"' {
*s = t[1 : len(t)-1]
return true
}
return false
}
func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) { func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) {
var err error params, err := l.parseSearchParams(r)
req := request{}
q := r.URL.Query()
req.olderThan = q.Get("older_than")
req.filterDomain = q.Get("filter_domain")
req.filterClient = q.Get("filter_client")
req.filterQuestionType = q.Get("filter_question_type")
req.filterResponseStatus = q.Get("filter_response_status")
params := getDataParams{
Domain: req.filterDomain,
Client: req.filterClient,
ResponseStatus: responseStatusAll,
}
if len(req.olderThan) != 0 {
params.OlderThan, err = time.Parse(time.RFC3339Nano, req.olderThan)
if err != nil { if err != nil {
httpError(r, w, http.StatusBadRequest, "invalid time stamp: %s", err) httpError(r, w, http.StatusBadRequest, "failed to parse params: %s", err)
return return
} }
}
if getDoubleQuotesEnclosedValue(&params.Domain) { // search for the log entries
params.StrictMatchDomain = true entries, oldest := l.search(params)
}
if getDoubleQuotesEnclosedValue(&params.Client) {
params.StrictMatchClient = true
}
if len(req.filterQuestionType) != 0 { // convert log entries to JSON
_, ok := dns.StringToType[req.filterQuestionType] var data = l.entriesToJSON(entries, oldest)
if !ok {
httpError(r, w, http.StatusBadRequest, "invalid question_type")
return
}
params.QuestionType = req.filterQuestionType
}
if len(req.filterResponseStatus) != 0 {
switch req.filterResponseStatus {
case "filtered":
params.ResponseStatus = responseStatusFiltered
default:
httpError(r, w, http.StatusBadRequest, "invalid response_status")
return
}
}
data := l.getData(params)
jsonVal, err := json.Marshal(data) jsonVal, err := json.Marshal(data)
if err != nil { if err != nil {
@ -101,16 +62,10 @@ func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) {
} }
} }
func (l *queryLog) handleQueryLogClear(w http.ResponseWriter, r *http.Request) { func (l *queryLog) handleQueryLogClear(_ http.ResponseWriter, _ *http.Request) {
l.clear() l.clear()
} }
type qlogConfig struct {
Enabled bool `json:"enabled"`
Interval uint32 `json:"interval"`
AnonymizeClientIP bool `json:"anonymize_client_ip"`
}
// Get configuration // Get configuration
func (l *queryLog) handleQueryLogInfo(w http.ResponseWriter, r *http.Request) { func (l *queryLog) handleQueryLogInfo(w http.ResponseWriter, r *http.Request) {
resp := qlogConfig{} resp := qlogConfig{}
@ -162,10 +117,85 @@ func (l *queryLog) handleQueryLogConfig(w http.ResponseWriter, r *http.Request)
l.conf.ConfigModified() l.conf.ConfigModified()
} }
// Register web handlers // "value" -> value, return TRUE
func (l *queryLog) initWeb() { func getDoubleQuotesEnclosedValue(s *string) bool {
l.conf.HTTPRegister("GET", "/control/querylog", l.handleQueryLog) t := *s
l.conf.HTTPRegister("GET", "/control/querylog_info", l.handleQueryLogInfo) if len(t) >= 2 && t[0] == '"' && t[len(t)-1] == '"' {
l.conf.HTTPRegister("POST", "/control/querylog_clear", l.handleQueryLogClear) *s = t[1 : len(t)-1]
l.conf.HTTPRegister("POST", "/control/querylog_config", l.handleQueryLogConfig) return true
}
return false
}
// parseSearchCriteria - parses "searchCriteria" from the specified query parameter
func (l *queryLog) parseSearchCriteria(q url.Values, name string, ct criteriaType) (bool, searchCriteria, error) {
val := q.Get(name)
if len(val) == 0 {
return false, searchCriteria{}, nil
}
c := searchCriteria{
criteriaType: ct,
value: val,
}
if getDoubleQuotesEnclosedValue(&c.value) {
c.strict = true
}
if ct == ctClient && l.conf.AnonymizeClientIP {
c.value = l.getClientIP(c.value)
}
if ct == ctFilteringStatus && !util.ContainsString(filteringStatusValues, c.value) {
return false, c, fmt.Errorf("invalid value %s", c.value)
}
return true, c, nil
}
// parseSearchParams - parses "searchParams" from the HTTP request's query string
func (l *queryLog) parseSearchParams(r *http.Request) (*searchParams, error) {
p := newSearchParams()
var err error
q := r.URL.Query()
olderThan := q.Get("older_than")
if len(olderThan) != 0 {
p.olderThan, err = time.Parse(time.RFC3339Nano, olderThan)
if err != nil {
return nil, err
}
}
if limit, err := strconv.ParseInt(q.Get("limit"), 10, 64); err == nil {
p.limit = int(limit)
// If limit or offset are specified explicitly, we should change the default behavior
// and scan all log records until we found enough log entries
p.maxFileScanEntries = 0
}
if offset, err := strconv.ParseInt(q.Get("offset"), 10, 64); err == nil {
p.offset = int(offset)
p.maxFileScanEntries = 0
}
paramNames := map[string]criteriaType{
"filter_domain": ctDomain,
"filter_client": ctClient,
"filter_question_type": ctQuestionType,
"filter_response_status": ctFilteringStatus,
}
for k, v := range paramNames {
ok, c, err := l.parseSearchCriteria(q, k, v)
if err != nil {
return nil, err
}
if ok {
p.searchCriteria = append(p.searchCriteria, c)
}
}
return p, nil
} }

228
querylog/qlog_test.go Normal file
View File

@ -0,0 +1,228 @@
package querylog
import (
"net"
"os"
"testing"
"github.com/AdguardTeam/dnsproxy/proxyutil"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
)
func prepareTestDir() string {
const dir = "./agh-test"
_ = os.RemoveAll(dir)
_ = os.MkdirAll(dir, 0755)
return dir
}
// Check adding and loading (with filtering) entries from disk and memory
func TestQueryLog(t *testing.T) {
conf := Config{
Enabled: true,
Interval: 1,
MemSize: 100,
}
conf.BaseDir = prepareTestDir()
defer func() { _ = os.RemoveAll(conf.BaseDir) }()
l := newQueryLog(conf)
// add disk entries
addEntry(l, "example.org", "1.1.1.1", "2.2.2.1")
// write to disk (first file)
_ = l.flushLogBuffer(true)
// start writing to the second file
_ = l.rotate()
// add disk entries
addEntry(l, "example.org", "1.1.1.2", "2.2.2.2")
// write to disk
_ = l.flushLogBuffer(true)
// add memory entries
addEntry(l, "test.example.org", "1.1.1.3", "2.2.2.3")
addEntry(l, "example.com", "1.1.1.4", "2.2.2.4")
// get all entries
params := newSearchParams()
entries, _ := l.search(params)
assert.Equal(t, 4, len(entries))
assertLogEntry(t, entries[0], "example.com", "1.1.1.4", "2.2.2.4")
assertLogEntry(t, entries[1], "test.example.org", "1.1.1.3", "2.2.2.3")
assertLogEntry(t, entries[2], "example.org", "1.1.1.2", "2.2.2.2")
assertLogEntry(t, entries[3], "example.org", "1.1.1.1", "2.2.2.1")
// search by domain (strict)
params = newSearchParams()
params.searchCriteria = append(params.searchCriteria, searchCriteria{
criteriaType: ctDomain,
strict: true,
value: "test.example.org",
})
entries, _ = l.search(params)
assert.Equal(t, 1, len(entries))
assertLogEntry(t, entries[0], "test.example.org", "1.1.1.3", "2.2.2.3")
// search by domain (not strict)
params = newSearchParams()
params.searchCriteria = append(params.searchCriteria, searchCriteria{
criteriaType: ctDomain,
strict: false,
value: "example.org",
})
entries, _ = l.search(params)
assert.Equal(t, 3, len(entries))
assertLogEntry(t, entries[0], "test.example.org", "1.1.1.3", "2.2.2.3")
assertLogEntry(t, entries[1], "example.org", "1.1.1.2", "2.2.2.2")
assertLogEntry(t, entries[2], "example.org", "1.1.1.1", "2.2.2.1")
// search by client IP (strict)
params = newSearchParams()
params.searchCriteria = append(params.searchCriteria, searchCriteria{
criteriaType: ctClient,
strict: true,
value: "2.2.2.2",
})
entries, _ = l.search(params)
assert.Equal(t, 1, len(entries))
assertLogEntry(t, entries[0], "example.org", "1.1.1.2", "2.2.2.2")
// search by client IP (part of)
params = newSearchParams()
params.searchCriteria = append(params.searchCriteria, searchCriteria{
criteriaType: ctClient,
strict: false,
value: "2.2.2",
})
entries, _ = l.search(params)
assert.Equal(t, 4, len(entries))
assertLogEntry(t, entries[0], "example.com", "1.1.1.4", "2.2.2.4")
assertLogEntry(t, entries[1], "test.example.org", "1.1.1.3", "2.2.2.3")
assertLogEntry(t, entries[2], "example.org", "1.1.1.2", "2.2.2.2")
assertLogEntry(t, entries[3], "example.org", "1.1.1.1", "2.2.2.1")
}
func TestQueryLogOffsetLimit(t *testing.T) {
conf := Config{
Enabled: true,
Interval: 1,
MemSize: 100,
}
conf.BaseDir = prepareTestDir()
defer func() { _ = os.RemoveAll(conf.BaseDir) }()
l := newQueryLog(conf)
// add 10 entries to the log
for i := 0; i < 10; i++ {
addEntry(l, "second.example.org", "1.1.1.1", "2.2.2.1")
}
// write them to disk (first file)
_ = l.flushLogBuffer(true)
// add 10 more entries to the log (memory)
for i := 0; i < 10; i++ {
addEntry(l, "first.example.org", "1.1.1.1", "2.2.2.1")
}
// First page
params := newSearchParams()
params.offset = 0
params.limit = 10
entries, _ := l.search(params)
assert.Equal(t, 10, len(entries))
assert.Equal(t, entries[0].QHost, "first.example.org")
assert.Equal(t, entries[9].QHost, "first.example.org")
// Second page
params.offset = 10
params.limit = 10
entries, _ = l.search(params)
assert.Equal(t, 10, len(entries))
assert.Equal(t, entries[0].QHost, "second.example.org")
assert.Equal(t, entries[9].QHost, "second.example.org")
// Second and a half page
params.offset = 15
params.limit = 10
entries, _ = l.search(params)
assert.Equal(t, 5, len(entries))
assert.Equal(t, entries[0].QHost, "second.example.org")
assert.Equal(t, entries[4].QHost, "second.example.org")
// Third page
params.offset = 20
params.limit = 10
entries, _ = l.search(params)
assert.Equal(t, 0, len(entries))
}
func TestQueryLogMaxFileScanEntries(t *testing.T) {
conf := Config{
Enabled: true,
Interval: 1,
MemSize: 100,
}
conf.BaseDir = prepareTestDir()
defer func() { _ = os.RemoveAll(conf.BaseDir) }()
l := newQueryLog(conf)
// add 10 entries to the log
for i := 0; i < 10; i++ {
addEntry(l, "example.org", "1.1.1.1", "2.2.2.1")
}
// write them to disk (first file)
_ = l.flushLogBuffer(true)
params := newSearchParams()
params.maxFileScanEntries = 5 // do not scan more than 5 records
entries, _ := l.search(params)
assert.Equal(t, 5, len(entries))
params.maxFileScanEntries = 0 // disable the limit
entries, _ = l.search(params)
assert.Equal(t, 10, len(entries))
}
func addEntry(l *queryLog, host, answerStr, client string) {
q := dns.Msg{}
q.Question = append(q.Question, dns.Question{
Name: host + ".",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
})
a := dns.Msg{}
a.Question = append(a.Question, q.Question[0])
answer := new(dns.A)
answer.Hdr = dns.RR_Header{
Name: q.Question[0].Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
}
answer.A = net.ParseIP(answerStr)
a.Answer = append(a.Answer, answer)
res := dnsfilter.Result{}
params := AddParams{
Question: &q,
Answer: &a,
Result: &res,
ClientIP: net.ParseIP(client),
Upstream: "upstream",
}
l.Add(params)
}
func assertLogEntry(t *testing.T, entry *logEntry, host, answer, client string) bool {
assert.Equal(t, host, entry.QHost)
assert.Equal(t, client, entry.IP)
assert.Equal(t, "A", entry.QType)
assert.Equal(t, "IN", entry.QClass)
msg := new(dns.Msg)
assert.Nil(t, msg.Unpack(entry.Answer))
assert.Equal(t, 1, len(msg.Answer))
ip := proxyutil.GetIPFromDNSRecord(msg.Answer[0])
assert.NotNil(t, ip)
assert.Equal(t, answer, ip.String())
return true
}

View File

@ -1,18 +1,72 @@
package querylog package querylog
import ( import (
"encoding/base64"
"io" "io"
"strconv"
"strings"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/util" "github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
) )
// search - searches log entries in the query log using specified parameters
// returns the list of entries found + time of the oldest entry
func (l *queryLog) search(params *searchParams) ([]*logEntry, time.Time) {
now := time.Now()
if params.limit == 0 {
return []*logEntry{}, time.Time{}
}
// add from file
fileEntries, oldest, total := l.searchFiles(params)
// add from memory buffer
l.bufferLock.Lock()
total += len(l.buffer)
memoryEntries := make([]*logEntry, 0)
// go through the buffer in the reverse order
// from NEWER to OLDER
for i := len(l.buffer) - 1; i >= 0; i-- {
entry := l.buffer[i]
if !params.match(entry) {
continue
}
memoryEntries = append(memoryEntries, entry)
}
l.bufferLock.Unlock()
// limits
totalLimit := params.offset + params.limit
// now let's get a unified collection
entries := append(memoryEntries, fileEntries...)
if len(entries) > totalLimit {
// remove extra records
entries = entries[:totalLimit]
}
if params.offset > 0 {
if len(entries) > params.offset {
entries = entries[params.offset:]
} else {
entries = make([]*logEntry, 0)
oldest = time.Time{}
}
}
if len(entries) == totalLimit {
// change the "oldest" value here.
// we cannot use the "oldest" we got from "searchFiles" anymore
// because after adding in-memory records and removing extra records
// the situation has changed
oldest = entries[len(entries)-1].Time
}
log.Debug("QueryLog: prepared data (%d/%d) older than %s in %s",
len(entries), total, params.olderThan, time.Since(now))
return entries, oldest
}
// searchFiles reads log entries from all log files and applies the specified search criteria. // searchFiles reads log entries from all log files and applies the specified search criteria.
// IMPORTANT: this method does not scan more than "maxSearchEntries" so you // IMPORTANT: this method does not scan more than "maxSearchEntries" so you
// may need to call it many times. // may need to call it many times.
@ -21,7 +75,7 @@ import (
// * an array of log entries that we have read // * an array of log entries that we have read
// * time of the oldest processed entry (even if it was discarded) // * time of the oldest processed entry (even if it was discarded)
// * total number of processed entries (including discarded). // * total number of processed entries (including discarded).
func (l *queryLog) searchFiles(params getDataParams) ([]*logEntry, time.Time, int) { func (l *queryLog) searchFiles(params *searchParams) ([]*logEntry, time.Time, int) {
entries := make([]*logEntry, 0) entries := make([]*logEntry, 0)
oldest := time.Time{} oldest := time.Time{}
@ -32,10 +86,10 @@ func (l *queryLog) searchFiles(params getDataParams) ([]*logEntry, time.Time, in
} }
defer r.Close() defer r.Close()
if params.OlderThan.IsZero() { if params.olderThan.IsZero() {
err = r.SeekStart() err = r.SeekStart()
} else { } else {
err = r.Seek(params.OlderThan.UnixNano()) err = r.Seek(params.olderThan.UnixNano())
if err == nil { if err == nil {
// Read to the next record right away // Read to the next record right away
// The one that was specified in the "oldest" param is not needed, // The one that was specified in the "oldest" param is not needed,
@ -45,14 +99,17 @@ func (l *queryLog) searchFiles(params getDataParams) ([]*logEntry, time.Time, in
} }
if err != nil { if err != nil {
log.Debug("Cannot Seek() to %v: %v", params.OlderThan, err) log.Debug("Cannot Seek() to %v: %v", params.olderThan, err)
return entries, oldest, 0 return entries, oldest, 0
} }
totalLimit := params.offset + params.limit
total := 0 total := 0
oldestNano := int64(0) oldestNano := int64(0)
// Do not scan more than 50k at once // By default, we do not scan more than "maxFileScanEntries" at once
for total <= maxSearchEntries { // The idea is to make search calls faster so that the UI could handle it and show something
// This behavior can be overridden if "maxFileScanEntries" is set to 0
for total < params.maxFileScanEntries || params.maxFileScanEntries <= 0 {
entry, ts, err := l.readNextEntry(r, params) entry, ts, err := l.readNextEntry(r, params)
if err == io.EOF { if err == io.EOF {
@ -65,8 +122,8 @@ func (l *queryLog) searchFiles(params getDataParams) ([]*logEntry, time.Time, in
if entry != nil { if entry != nil {
entries = append(entries, entry) entries = append(entries, entry)
if len(entries) == getDataLimit { if len(entries) == totalLimit {
// Do not read more than "getDataLimit" records at once // Do not read more than "totalLimit" records at once
break break
} }
} }
@ -82,7 +139,7 @@ func (l *queryLog) searchFiles(params getDataParams) ([]*logEntry, time.Time, in
// * log entry that matches search criteria or null if it was discarded (or if there's nothing to read) // * log entry that matches search criteria or null if it was discarded (or if there's nothing to read)
// * timestamp of the processed log entry // * timestamp of the processed log entry
// * error if we can't read anymore // * error if we can't read anymore
func (l *queryLog) readNextEntry(r *QLogReader, params getDataParams) (*logEntry, int64, error) { func (l *queryLog) readNextEntry(r *QLogReader, params *searchParams) (*logEntry, int64, error) {
line, err := r.ReadNext() line, err := r.ReadNext()
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
@ -92,7 +149,7 @@ func (l *queryLog) readNextEntry(r *QLogReader, params getDataParams) (*logEntry
timestamp := readQLogTimestamp(line) timestamp := readQLogTimestamp(line)
// Quick check without deserializing log entry // Quick check without deserializing log entry
if !quickMatchesGetDataParams(line, params) { if !params.quickMatch(line) {
return nil, timestamp, nil return nil, timestamp, nil
} }
@ -100,7 +157,7 @@ func (l *queryLog) readNextEntry(r *QLogReader, params getDataParams) (*logEntry
decodeLogEntry(&entry, line) decodeLogEntry(&entry, line)
// Full check of the deserialized log entry // Full check of the deserialized log entry
if !matchesGetDataParams(&entry, params) { if !params.match(&entry) {
return nil, timestamp, nil return nil, timestamp, nil
} }
@ -120,257 +177,3 @@ func (l *queryLog) openReader() (*QLogReader, error) {
return NewQLogReader(files) return NewQLogReader(files)
} }
// quickMatchesGetDataParams - quickly checks if the line matches getDataParams
// this method does not guarantee anything and the reason is to do a quick check
// without deserializing anything
func quickMatchesGetDataParams(line string, params getDataParams) bool {
if params.ResponseStatus == responseStatusFiltered {
boolVal, ok := readJSONBool(line, "IsFiltered")
if !ok || !boolVal {
return false
}
}
if len(params.Domain) != 0 {
val := readJSONValue(line, "QH")
if len(val) == 0 {
return false
}
if (params.StrictMatchDomain && val != params.Domain) ||
(!params.StrictMatchDomain && strings.Index(val, params.Domain) == -1) {
return false
}
}
if len(params.QuestionType) != 0 {
val := readJSONValue(line, "QT")
if val != params.QuestionType {
return false
}
}
if len(params.Client) != 0 {
val := readJSONValue(line, "IP")
if len(val) == 0 {
log.Debug("QueryLog: failed to decodeLogEntry")
return false
}
if (params.StrictMatchClient && val != params.Client) ||
(!params.StrictMatchClient && strings.Index(val, params.Client) == -1) {
return false
}
}
return true
}
// matchesGetDataParams - returns true if the entry matches the search parameters
func matchesGetDataParams(entry *logEntry, params getDataParams) bool {
if params.ResponseStatus == responseStatusFiltered && !entry.Result.IsFiltered {
return false
}
if len(params.QuestionType) != 0 {
if entry.QType != params.QuestionType {
return false
}
}
if len(params.Domain) != 0 {
if (params.StrictMatchDomain && entry.QHost != params.Domain) ||
(!params.StrictMatchDomain && strings.Index(entry.QHost, params.Domain) == -1) {
return false
}
}
if len(params.Client) != 0 {
if (params.StrictMatchClient && entry.IP != params.Client) ||
(!params.StrictMatchClient && strings.Index(entry.IP, params.Client) == -1) {
return false
}
}
return true
}
// decodeLogEntry - decodes query log entry from a line
// nolint (gocyclo)
func decodeLogEntry(ent *logEntry, str string) {
var b bool
var i int
var err error
for {
k, v, t := readJSON(&str)
if t == jsonTErr {
break
}
switch k {
case "IP":
if len(ent.IP) == 0 {
ent.IP = v
}
case "T":
ent.Time, err = time.Parse(time.RFC3339, v)
case "QH":
ent.QHost = v
case "QT":
ent.QType = v
case "QC":
ent.QClass = v
case "Answer":
ent.Answer, err = base64.StdEncoding.DecodeString(v)
case "OrigAnswer":
ent.OrigAnswer, err = base64.StdEncoding.DecodeString(v)
case "IsFiltered":
b, err = strconv.ParseBool(v)
ent.Result.IsFiltered = b
case "Rule":
ent.Result.Rule = v
case "FilterID":
i, err = strconv.Atoi(v)
ent.Result.FilterID = int64(i)
case "Reason":
i, err = strconv.Atoi(v)
ent.Result.Reason = dnsfilter.Reason(i)
case "Upstream":
ent.Upstream = v
case "Elapsed":
i, err = strconv.Atoi(v)
ent.Elapsed = time.Duration(i)
// pre-v0.99.3 compatibility:
case "Question":
var qstr []byte
qstr, err = base64.StdEncoding.DecodeString(v)
if err != nil {
break
}
q := new(dns.Msg)
err = q.Unpack(qstr)
if err != nil {
break
}
ent.QHost = q.Question[0].Name
if len(ent.QHost) == 0 {
break
}
ent.QHost = ent.QHost[:len(ent.QHost)-1]
ent.QType = dns.TypeToString[q.Question[0].Qtype]
ent.QClass = dns.ClassToString[q.Question[0].Qclass]
case "Time":
ent.Time, err = time.Parse(time.RFC3339, v)
}
if err != nil {
log.Debug("decodeLogEntry err: %s", err)
break
}
}
}
// Get bool value from "key":bool
func readJSONBool(s, name string) (bool, bool) {
i := strings.Index(s, "\""+name+"\":")
if i == -1 {
return false, false
}
start := i + 1 + len(name) + 2
b := false
if strings.HasPrefix(s[start:], "true") {
b = true
} else if !strings.HasPrefix(s[start:], "false") {
return false, false
}
return b, true
}
// Get value from "key":"value"
func readJSONValue(s, name string) string {
i := strings.Index(s, "\""+name+"\":\"")
if i == -1 {
return ""
}
start := i + 1 + len(name) + 3
i = strings.IndexByte(s[start:], '"')
if i == -1 {
return ""
}
end := start + i
return s[start:end]
}
const (
jsonTErr = iota
jsonTObj
jsonTStr
jsonTNum
jsonTBool
)
// Parse JSON key-value pair
// e.g.: "key":VALUE where VALUE is "string", true|false (boolean), or 123.456 (number)
// Note the limitations:
// . doesn't support whitespace
// . doesn't support "null"
// . doesn't validate boolean or number
// . no proper handling of {} braces
// . no handling of [] brackets
// Return (key, value, type)
func readJSON(ps *string) (string, string, int32) {
s := *ps
k := ""
v := ""
t := int32(jsonTErr)
q1 := strings.IndexByte(s, '"')
if q1 == -1 {
return k, v, t
}
q2 := strings.IndexByte(s[q1+1:], '"')
if q2 == -1 {
return k, v, t
}
k = s[q1+1 : q1+1+q2]
s = s[q1+1+q2+1:]
if len(s) < 2 || s[0] != ':' {
return k, v, t
}
if s[1] == '"' {
q2 = strings.IndexByte(s[2:], '"')
if q2 == -1 {
return k, v, t
}
v = s[2 : 2+q2]
t = jsonTStr
s = s[2+q2+1:]
} else if s[1] == '{' {
t = jsonTObj
s = s[1+1:]
} else {
sep := strings.IndexAny(s[1:], ",}")
if sep == -1 {
return k, v, t
}
v = s[1 : 1+sep]
if s[1] == 't' || s[1] == 'f' {
t = jsonTBool
} else if s[1] == '.' || (s[1] >= '0' && s[1] <= '9') {
t = jsonTNum
}
s = s[1+sep+1:]
}
*ps = s
return k, v, t
}

View File

@ -1,176 +0,0 @@
package querylog
import (
"net"
"os"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
)
func prepareTestDir() string {
const dir = "./agh-test"
_ = os.RemoveAll(dir)
_ = os.MkdirAll(dir, 0755)
return dir
}
// Check adding and loading (with filtering) entries from disk and memory
func TestQueryLog(t *testing.T) {
conf := Config{
Enabled: true,
Interval: 1,
MemSize: 100,
}
conf.BaseDir = prepareTestDir()
defer func() { _ = os.RemoveAll(conf.BaseDir) }()
l := newQueryLog(conf)
// add disk entries
addEntry(l, "example.org", "1.1.1.1", "2.2.2.1")
// write to disk (first file)
_ = l.flushLogBuffer(true)
// start writing to the second file
_ = l.rotate()
// add disk entries
addEntry(l, "example.org", "1.1.1.2", "2.2.2.2")
// write to disk
_ = l.flushLogBuffer(true)
// add memory entries
addEntry(l, "test.example.org", "1.1.1.3", "2.2.2.3")
addEntry(l, "example.com", "1.1.1.4", "2.2.2.4")
// get all entries
params := getDataParams{
OlderThan: time.Time{},
}
d := l.getData(params)
mdata := d["data"].([]map[string]interface{})
assert.Equal(t, 4, len(mdata))
assert.True(t, checkEntry(t, mdata[0], "example.com", "1.1.1.4", "2.2.2.4"))
assert.True(t, checkEntry(t, mdata[1], "test.example.org", "1.1.1.3", "2.2.2.3"))
assert.True(t, checkEntry(t, mdata[2], "example.org", "1.1.1.2", "2.2.2.2"))
assert.True(t, checkEntry(t, mdata[3], "example.org", "1.1.1.1", "2.2.2.1"))
// search by domain (strict)
params = getDataParams{
OlderThan: time.Time{},
Domain: "test.example.org",
StrictMatchDomain: true,
}
d = l.getData(params)
mdata = d["data"].([]map[string]interface{})
assert.True(t, len(mdata) == 1)
assert.True(t, checkEntry(t, mdata[0], "test.example.org", "1.1.1.3", "2.2.2.3"))
// search by domain (not strict)
params = getDataParams{
OlderThan: time.Time{},
Domain: "example.org",
StrictMatchDomain: false,
}
d = l.getData(params)
mdata = d["data"].([]map[string]interface{})
assert.Equal(t, 3, len(mdata))
assert.True(t, checkEntry(t, mdata[0], "test.example.org", "1.1.1.3", "2.2.2.3"))
assert.True(t, checkEntry(t, mdata[1], "example.org", "1.1.1.2", "2.2.2.2"))
assert.True(t, checkEntry(t, mdata[2], "example.org", "1.1.1.1", "2.2.2.1"))
// search by client IP (strict)
params = getDataParams{
OlderThan: time.Time{},
Client: "2.2.2.2",
StrictMatchClient: true,
}
d = l.getData(params)
mdata = d["data"].([]map[string]interface{})
assert.Equal(t, 1, len(mdata))
assert.True(t, checkEntry(t, mdata[0], "example.org", "1.1.1.2", "2.2.2.2"))
// search by client IP (part of)
params = getDataParams{
OlderThan: time.Time{},
Client: "2.2.2",
StrictMatchClient: false,
}
d = l.getData(params)
mdata = d["data"].([]map[string]interface{})
assert.Equal(t, 4, len(mdata))
assert.True(t, checkEntry(t, mdata[0], "example.com", "1.1.1.4", "2.2.2.4"))
assert.True(t, checkEntry(t, mdata[1], "test.example.org", "1.1.1.3", "2.2.2.3"))
assert.True(t, checkEntry(t, mdata[2], "example.org", "1.1.1.2", "2.2.2.2"))
assert.True(t, checkEntry(t, mdata[3], "example.org", "1.1.1.1", "2.2.2.1"))
}
func addEntry(l *queryLog, host, answerStr, client string) {
q := dns.Msg{}
q.Question = append(q.Question, dns.Question{
Name: host + ".",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
})
a := dns.Msg{}
a.Question = append(a.Question, q.Question[0])
answer := new(dns.A)
answer.Hdr = dns.RR_Header{
Name: q.Question[0].Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
}
answer.A = net.ParseIP(answerStr)
a.Answer = append(a.Answer, answer)
res := dnsfilter.Result{}
params := AddParams{
Question: &q,
Answer: &a,
Result: &res,
ClientIP: net.ParseIP(client),
Upstream: "upstream",
}
l.Add(params)
}
func checkEntry(t *testing.T, m map[string]interface{}, host, answer, client string) bool {
mq := m["question"].(map[string]interface{})
ma := m["answer"].([]map[string]interface{})
ma0 := ma[0]
if !assert.Equal(t, host, mq["host"].(string)) ||
!assert.Equal(t, "IN", mq["class"].(string)) ||
!assert.Equal(t, "A", mq["type"].(string)) ||
!assert.Equal(t, answer, ma0["value"].(string)) ||
!assert.Equal(t, client, m["client"].(string)) {
return false
}
return true
}
func TestJSON(t *testing.T) {
s := `
{"keystr":"val","obj":{"keybool":true,"keyint":123456}}
`
k, v, jtype := readJSON(&s)
assert.Equal(t, jtype, int32(jsonTStr))
assert.Equal(t, "keystr", k)
assert.Equal(t, "val", v)
k, v, jtype = readJSON(&s)
assert.Equal(t, jtype, int32(jsonTObj))
assert.Equal(t, "obj", k)
k, v, jtype = readJSON(&s)
assert.Equal(t, jtype, int32(jsonTBool))
assert.Equal(t, "keybool", k)
assert.Equal(t, "true", v)
k, v, jtype = readJSON(&s)
assert.Equal(t, jtype, int32(jsonTNum))
assert.Equal(t, "keyint", k)
assert.Equal(t, "123456", v)
k, v, jtype = readJSON(&s)
assert.True(t, jtype == jsonTErr)
}

139
querylog/search_criteria.go Normal file
View File

@ -0,0 +1,139 @@
package querylog
import (
"strings"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
)
type criteriaType int
const (
ctDomain criteriaType = iota // domain name
ctClient // client IP address
ctQuestionType // question type
ctFilteringStatus // filtering status
)
const (
filteringStatusAll = "all"
filteringStatusFiltered = "filtered" // all kinds of filtering
filteringStatusBlocked = "blocked" // blocked or blocked service
filteringStatusBlockedSafebrowsing = "blocked_safebrowsing" // blocked by safebrowsing
filteringStatusBlockedParental = "blocked_parental" // blocked by parental control
filteringStatusWhitelisted = "whitelisted" // whitelisted
filteringStatusRewritten = "rewritten" // all kinds of rewrites
filteringStatusSafeSearch = "safe_search" // enforced safe search
)
// filteringStatusValues -- array with all possible filteringStatus values
var filteringStatusValues = []string{
filteringStatusAll, filteringStatusFiltered, filteringStatusBlocked,
filteringStatusBlockedSafebrowsing, filteringStatusBlockedParental,
filteringStatusWhitelisted, filteringStatusRewritten, filteringStatusSafeSearch,
}
// searchCriteria - every search request may contain a list of different search criteria
// we use each of them to match the query
type searchCriteria struct {
criteriaType criteriaType // type of the criteria
strict bool // should we strictly match (equality) or not (indexOf)
value string // search criteria value
}
// quickMatch - quickly checks if the log entry matches this search criteria
// the reason is to do it as quickly as possible without de-serializing the entry
func (c *searchCriteria) quickMatch(line string) bool {
// note that we do this only for a limited set of criteria
switch c.criteriaType {
case ctDomain:
return c.quickMatchJSONValue(line, "QH")
case ctClient:
return c.quickMatchJSONValue(line, "IP")
case ctQuestionType:
return c.quickMatchJSONValue(line, "QT")
default:
return true
}
}
// quickMatchJSONValue - helper used by quickMatch
func (c *searchCriteria) quickMatchJSONValue(line string, propertyName string) bool {
val := readJSONValue(line, propertyName)
if len(val) == 0 {
return false
}
if c.strict && c.value == val {
return true
}
if !c.strict && strings.Contains(val, c.value) {
return true
}
return false
}
// match - checks if the log entry matches this search criteria
// nolint (gocyclo)
func (c *searchCriteria) match(entry *logEntry) bool {
switch c.criteriaType {
case ctDomain:
if c.strict && entry.QHost == c.value {
return true
}
if !c.strict && strings.Contains(entry.QHost, c.value) {
return true
}
return false
case ctClient:
if c.strict && entry.IP == c.value {
return true
}
if !c.strict && strings.Contains(entry.IP, c.value) {
return true
}
return false
case ctQuestionType:
if c.strict && entry.QType == c.value {
return true
}
if !c.strict && strings.Contains(entry.QType, c.value) {
return true
}
case ctFilteringStatus:
res := entry.Result
switch c.value {
case filteringStatusAll:
return true
case filteringStatusFiltered:
return res.IsFiltered
case filteringStatusBlocked:
return res.IsFiltered &&
(res.Reason == dnsfilter.FilteredBlackList ||
res.Reason == dnsfilter.FilteredBlockedService)
case filteringStatusBlockedParental:
return res.IsFiltered && res.Reason == dnsfilter.FilteredParental
case filteringStatusBlockedSafebrowsing:
return res.IsFiltered && res.Reason == dnsfilter.FilteredSafeBrowsing
case filteringStatusWhitelisted:
return res.IsFiltered && res.Reason == dnsfilter.NotFilteredWhiteList
case filteringStatusRewritten:
return res.IsFiltered &&
(res.Reason == dnsfilter.ReasonRewrite ||
res.Reason == dnsfilter.RewriteEtcHosts)
case filteringStatusSafeSearch:
return res.IsFiltered && res.Reason == dnsfilter.FilteredSafeSearch
default:
return false
}
default:
return false
}
return false
}

57
querylog/search_params.go Normal file
View File

@ -0,0 +1,57 @@
package querylog
import "time"
// searchParams represent the search query sent by the client
type searchParams struct {
// searchCriteria - list of search criteria that we use to get filter results
searchCriteria []searchCriteria
// olderThen - return entries that are older than this value
// if not set - disregard it and return any value
olderThan time.Time
offset int // offset for the search
limit int // limit the number of records returned
maxFileScanEntries int // maximum log entries to scan in query log files. if 0 - no limit
}
// newSearchParams - creates an empty instance of searchParams
func newSearchParams() *searchParams {
return &searchParams{
// default max log entries to return
limit: 500,
// by default, we scan up to 50k entries at once
maxFileScanEntries: 50000,
}
}
// quickMatchesGetDataParams - quickly checks if the line matches the searchParams
// this method does not guarantee anything and the reason is to do a quick check
// without deserializing anything
func (s *searchParams) quickMatch(line string) bool {
for _, c := range s.searchCriteria {
if !c.quickMatch(line) {
return false
}
}
return true
}
// match - checks if the logEntry matches the searchParams
func (s *searchParams) match(entry *logEntry) bool {
if !s.olderThan.IsZero() && entry.Time.UnixNano() >= s.olderThan.UnixNano() {
// Ignore entries newer than what was requested
return false
}
for _, c := range s.searchCriteria {
if !c.match(entry) {
return false
}
}
return true
}

View File

@ -10,6 +10,16 @@ import (
"strings" "strings"
) )
// ContainsString checks if "v" is in the array "arr"
func ContainsString(arr []string, v string) bool {
for _, i := range arr {
if i == v {
return true
}
}
return false
}
// fileExists returns TRUE if file exists // fileExists returns TRUE if file exists
func FileExists(fn string) bool { func FileExists(fn string) bool {
_, err := os.Stat(fn) _, err := os.Stat(fn)