*(querylog): added offset/limit parameters

Actually, this is a serious refactoring of the query log module. The
rest API part is refactored, it's now more clear how the search is
conducted. Split the logic into more files and added more tests.

Closes: https://github.com/AdguardTeam/AdGuardHome/issues/1559
This commit is contained in:
Andrey Meshkov 2020-05-26 15:37:37 +03:00
parent 6f3cd4e7eb
commit ceedc2d1ff
12 changed files with 1013 additions and 776 deletions

View File

@ -16,12 +16,14 @@ module.exports = {
],
"scopes": [
"",
"ui",
"global",
"dnsfilter",
"home",
"dnsforward",
"dhcpd",
"documentation"
"querylog",
"documentation",
],
"types": {
"+": {

View File

@ -1,5 +1,12 @@
# AdGuard Home API Change Log
## v0.103: API changes
### API: Get querylog: GET /control/querylog
* Added optional "offset" and "limit" parameters
We are still using "older_than" approach in AdGuard Home UI, but we realize that it's easier to use offset/limit so here is this option now.
## v0.102: API changes

View File

@ -143,13 +143,26 @@ paths:
tags:
- log
operationId: queryLog
summary: Get DNS server query log
summary: Get DNS server query log.
parameters:
- name: older_than
in: query
description: Filter by older than
schema:
type: string
- name: offset
in: query
description:
Specify the ranking number of the first item on the page.
Even though it is possible to use "offset" and "older_than",
we recommend choosing one of them and sticking to it.
schema:
type: integer
- name: limit
in: query
description: Limit the number of records to be returned
schema:
type: integer
- name: filter_domain
in: query
description: Filter by domain name

330
querylog/json.go Normal file
View File

@ -0,0 +1,330 @@
package querylog
import (
"encoding/base64"
"fmt"
"net"
"strconv"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
// decodeLogEntry - decodes query log entry from a line
// nolint (gocyclo)
func decodeLogEntry(ent *logEntry, str string) {
var b bool
var i int
var err error
for {
k, v, t := readJSON(&str)
if t == jsonTErr {
break
}
switch k {
case "IP":
if len(ent.IP) == 0 {
ent.IP = v
}
case "T":
ent.Time, err = time.Parse(time.RFC3339, v)
case "QH":
ent.QHost = v
case "QT":
ent.QType = v
case "QC":
ent.QClass = v
case "Answer":
ent.Answer, err = base64.StdEncoding.DecodeString(v)
case "OrigAnswer":
ent.OrigAnswer, err = base64.StdEncoding.DecodeString(v)
case "IsFiltered":
b, err = strconv.ParseBool(v)
ent.Result.IsFiltered = b
case "Rule":
ent.Result.Rule = v
case "FilterID":
i, err = strconv.Atoi(v)
ent.Result.FilterID = int64(i)
case "Reason":
i, err = strconv.Atoi(v)
ent.Result.Reason = dnsfilter.Reason(i)
case "Upstream":
ent.Upstream = v
case "Elapsed":
i, err = strconv.Atoi(v)
ent.Elapsed = time.Duration(i)
// pre-v0.99.3 compatibility:
case "Question":
var qstr []byte
qstr, err = base64.StdEncoding.DecodeString(v)
if err != nil {
break
}
q := new(dns.Msg)
err = q.Unpack(qstr)
if err != nil {
break
}
ent.QHost = q.Question[0].Name
if len(ent.QHost) == 0 {
break
}
ent.QHost = ent.QHost[:len(ent.QHost)-1]
ent.QType = dns.TypeToString[q.Question[0].Qtype]
ent.QClass = dns.ClassToString[q.Question[0].Qclass]
case "Time":
ent.Time, err = time.Parse(time.RFC3339, v)
}
if err != nil {
log.Debug("decodeLogEntry err: %s", err)
break
}
}
}
// Get value from "key":"value"
func readJSONValue(s, name string) string {
i := strings.Index(s, "\""+name+"\":\"")
if i == -1 {
return ""
}
start := i + 1 + len(name) + 3
i = strings.IndexByte(s[start:], '"')
if i == -1 {
return ""
}
end := start + i
return s[start:end]
}
const (
jsonTErr = iota
jsonTObj
jsonTStr
jsonTNum
jsonTBool
)
// Parse JSON key-value pair
// e.g.: "key":VALUE where VALUE is "string", true|false (boolean), or 123.456 (number)
// Note the limitations:
// . doesn't support whitespace
// . doesn't support "null"
// . doesn't validate boolean or number
// . no proper handling of {} braces
// . no handling of [] brackets
// Return (key, value, type)
func readJSON(ps *string) (string, string, int32) {
s := *ps
k := ""
v := ""
t := int32(jsonTErr)
q1 := strings.IndexByte(s, '"')
if q1 == -1 {
return k, v, t
}
q2 := strings.IndexByte(s[q1+1:], '"')
if q2 == -1 {
return k, v, t
}
k = s[q1+1 : q1+1+q2]
s = s[q1+1+q2+1:]
if len(s) < 2 || s[0] != ':' {
return k, v, t
}
if s[1] == '"' {
q2 = strings.IndexByte(s[2:], '"')
if q2 == -1 {
return k, v, t
}
v = s[2 : 2+q2]
t = jsonTStr
s = s[2+q2+1:]
} else if s[1] == '{' {
t = jsonTObj
s = s[1+1:]
} else {
sep := strings.IndexAny(s[1:], ",}")
if sep == -1 {
return k, v, t
}
v = s[1 : 1+sep]
if s[1] == 't' || s[1] == 'f' {
t = jsonTBool
} else if s[1] == '.' || (s[1] >= '0' && s[1] <= '9') {
t = jsonTNum
}
s = s[1+sep+1:]
}
*ps = s
return k, v, t
}
// Get Client IP address
func (l *queryLog) getClientIP(clientIP string) string {
if l.conf.AnonymizeClientIP {
ip := net.ParseIP(clientIP)
if ip != nil {
ip4 := ip.To4()
const AnonymizeClientIP4Mask = 24
const AnonymizeClientIP6Mask = 112
if ip4 != nil {
clientIP = ip4.Mask(net.CIDRMask(AnonymizeClientIP4Mask, 32)).String()
} else {
clientIP = ip.Mask(net.CIDRMask(AnonymizeClientIP6Mask, 128)).String()
}
}
}
return clientIP
}
// entriesToJSON - converts log entries to JSON
func (l *queryLog) entriesToJSON(entries []*logEntry, oldest time.Time) map[string]interface{} {
// init the response object
var data = []map[string]interface{}{}
// the elements order is already reversed (from newer to older)
for i := 0; i < len(entries); i++ {
entry := entries[i]
jsonEntry := l.logEntryToJSONEntry(entry)
data = append(data, jsonEntry)
}
var result = map[string]interface{}{}
result["oldest"] = ""
if !oldest.IsZero() {
result["oldest"] = oldest.Format(time.RFC3339Nano)
}
result["data"] = data
return result
}
func (l *queryLog) logEntryToJSONEntry(entry *logEntry) map[string]interface{} {
var msg *dns.Msg
if len(entry.Answer) > 0 {
msg = new(dns.Msg)
if err := msg.Unpack(entry.Answer); err != nil {
log.Debug("Failed to unpack dns message answer: %s: %s", err, string(entry.Answer))
msg = nil
}
}
jsonEntry := map[string]interface{}{
"reason": entry.Result.Reason.String(),
"elapsedMs": strconv.FormatFloat(entry.Elapsed.Seconds()*1000, 'f', -1, 64),
"time": entry.Time.Format(time.RFC3339Nano),
"client": l.getClientIP(entry.IP),
}
jsonEntry["question"] = map[string]interface{}{
"host": entry.QHost,
"type": entry.QType,
"class": entry.QClass,
}
if msg != nil {
jsonEntry["status"] = dns.RcodeToString[msg.Rcode]
opt := msg.IsEdns0()
dnssecOk := false
if opt != nil {
dnssecOk = opt.Do()
}
jsonEntry["answer_dnssec"] = dnssecOk
}
if len(entry.Result.Rule) > 0 {
jsonEntry["rule"] = entry.Result.Rule
jsonEntry["filterId"] = entry.Result.FilterID
}
if len(entry.Result.ServiceName) != 0 {
jsonEntry["service_name"] = entry.Result.ServiceName
}
answers := answerToMap(msg)
if answers != nil {
jsonEntry["answer"] = answers
}
if len(entry.OrigAnswer) != 0 {
a := new(dns.Msg)
err := a.Unpack(entry.OrigAnswer)
if err == nil {
answers = answerToMap(a)
if answers != nil {
jsonEntry["original_answer"] = answers
}
} else {
log.Debug("Querylog: msg.Unpack(entry.OrigAnswer): %s: %s", err, string(entry.OrigAnswer))
}
}
return jsonEntry
}
func answerToMap(a *dns.Msg) []map[string]interface{} {
if a == nil || len(a.Answer) == 0 {
return nil
}
var answers = []map[string]interface{}{}
for _, k := range a.Answer {
header := k.Header()
answer := map[string]interface{}{
"type": dns.TypeToString[header.Rrtype],
"ttl": header.Ttl,
}
// try most common record types
switch v := k.(type) {
case *dns.A:
answer["value"] = v.A.String()
case *dns.AAAA:
answer["value"] = v.AAAA.String()
case *dns.MX:
answer["value"] = fmt.Sprintf("%v %v", v.Preference, v.Mx)
case *dns.CNAME:
answer["value"] = v.Target
case *dns.NS:
answer["value"] = v.Ns
case *dns.SPF:
answer["value"] = v.Txt
case *dns.TXT:
answer["value"] = v.Txt
case *dns.PTR:
answer["value"] = v.Ptr
case *dns.SOA:
answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v", v.Ns, v.Mbox, v.Serial, v.Refresh, v.Retry, v.Expire, v.Minttl)
case *dns.CAA:
answer["value"] = fmt.Sprintf("%v %v \"%v\"", v.Flag, v.Tag, v.Value)
case *dns.HINFO:
answer["value"] = fmt.Sprintf("\"%v\" \"%v\"", v.Cpu, v.Os)
case *dns.RRSIG:
answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v %v %v", dns.TypeToString[v.TypeCovered], v.Algorithm, v.Labels, v.OrigTtl, v.Expiration, v.Inception, v.KeyTag, v.SignerName, v.Signature)
default:
// type unknown, marshall it as-is
answer["value"] = v
}
answers = append(answers, answer)
}
return answers
}

View File

@ -1,11 +1,8 @@
package querylog
import (
"fmt"
"net"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
@ -17,10 +14,6 @@ import (
const (
queryLogFileName = "querylog.json" // .gz added during compression
getDataLimit = 500 // GetData(): maximum log entries to return
// maximum entries to parse when searching
maxSearchEntries = 50000
)
// queryLog is a structure that writes and reads the DNS query log
@ -36,6 +29,23 @@ type queryLog struct {
fileWriteLock sync.Mutex
}
// logEntry - represents a single log entry
type logEntry struct {
IP string `json:"IP"` // Client IP
Time time.Time `json:"T"`
QHost string `json:"QH"`
QType string `json:"QT"`
QClass string `json:"QC"`
Answer []byte `json:",omitempty"` // sometimes empty answers happen like binerdunt.top or rev2.globalrootservers.net
OrigAnswer []byte `json:",omitempty"`
Result dnsfilter.Result
Elapsed time.Duration
Upstream string `json:",omitempty"` // if empty, means it was cached
}
// create a new instance of the query log
func newQueryLog(conf Config) *queryLog {
l := queryLog{}
@ -93,22 +103,6 @@ func (l *queryLog) clear() {
log.Debug("Query log: cleared")
}
type logEntry struct {
IP string `json:"IP"`
Time time.Time `json:"T"`
QHost string `json:"QH"`
QType string `json:"QT"`
QClass string `json:"QC"`
Answer []byte `json:",omitempty"` // sometimes empty answers happen like binerdunt.top or rev2.globalrootservers.net
OrigAnswer []byte `json:",omitempty"`
Result dnsfilter.Result
Elapsed time.Duration
Upstream string `json:",omitempty"` // if empty, means it was cached
}
func (l *queryLog) Add(params AddParams) {
if !l.conf.Enabled {
return
@ -173,230 +167,3 @@ func (l *queryLog) Add(params AddParams) {
go l.flushLogBuffer(false) // nolint
}
}
// Parameters for getData()
type getDataParams struct {
OlderThan time.Time // return entries that are older than this value
Domain string // filter by domain name in question
Client string // filter by client IP
QuestionType string // filter by question type
ResponseStatus responseStatusType // filter by response status
StrictMatchDomain bool // if Domain value must be matched strictly
StrictMatchClient bool // if Client value must be matched strictly
}
// Response status
type responseStatusType int32
// Response status constants
const (
responseStatusAll responseStatusType = iota + 1
responseStatusFiltered
)
// Gets log entries
func (l *queryLog) getData(params getDataParams) map[string]interface{} {
now := time.Now()
if len(params.Client) != 0 && l.conf.AnonymizeClientIP {
params.Client = l.getClientIP(params.Client)
}
// add from file
fileEntries, oldest, total := l.searchFiles(params)
if params.OlderThan.IsZero() {
// In case if the timer is not precise (for instance, on Windows)
// We really want to get all records including those added just before the call
params.OlderThan = now.Add(time.Millisecond)
}
// add from memory buffer
l.bufferLock.Lock()
total += len(l.buffer)
memoryEntries := make([]*logEntry, 0)
// go through the buffer in the reverse order
// from NEWER to OLDER
for i := len(l.buffer) - 1; i >= 0; i-- {
entry := l.buffer[i]
if entry.Time.UnixNano() >= params.OlderThan.UnixNano() {
// Ignore entries newer than what was requested
continue
}
if !matchesGetDataParams(entry, params) {
continue
}
memoryEntries = append(memoryEntries, entry)
}
l.bufferLock.Unlock()
// now let's get a unified collection
entries := append(memoryEntries, fileEntries...)
if len(entries) > getDataLimit {
// remove extra records
entries = entries[:getDataLimit]
}
if len(entries) == getDataLimit {
// change the "oldest" value here.
// we cannot use the "oldest" we got from "searchFiles" anymore
// because after adding in-memory records and removing extra records
// the situation has changed
oldest = entries[len(entries)-1].Time
}
// init the response object
var data = []map[string]interface{}{}
// the elements order is already reversed (from newer to older)
for i := 0; i < len(entries); i++ {
entry := entries[i]
jsonEntry := l.logEntryToJSONEntry(entry)
data = append(data, jsonEntry)
}
log.Debug("QueryLog: prepared data (%d/%d) older than %s in %s",
len(entries), total, params.OlderThan, time.Since(now))
var result = map[string]interface{}{}
result["oldest"] = ""
if !oldest.IsZero() {
result["oldest"] = oldest.Format(time.RFC3339Nano)
}
result["data"] = data
return result
}
// Get Client IP address
func (l *queryLog) getClientIP(clientIP string) string {
if l.conf.AnonymizeClientIP {
ip := net.ParseIP(clientIP)
if ip != nil {
ip4 := ip.To4()
const AnonymizeClientIP4Mask = 24
const AnonymizeClientIP6Mask = 112
if ip4 != nil {
clientIP = ip4.Mask(net.CIDRMask(AnonymizeClientIP4Mask, 32)).String()
} else {
clientIP = ip.Mask(net.CIDRMask(AnonymizeClientIP6Mask, 128)).String()
}
}
}
return clientIP
}
func (l *queryLog) logEntryToJSONEntry(entry *logEntry) map[string]interface{} {
var msg *dns.Msg
if len(entry.Answer) > 0 {
msg = new(dns.Msg)
if err := msg.Unpack(entry.Answer); err != nil {
log.Debug("Failed to unpack dns message answer: %s: %s", err, string(entry.Answer))
msg = nil
}
}
jsonEntry := map[string]interface{}{
"reason": entry.Result.Reason.String(),
"elapsedMs": strconv.FormatFloat(entry.Elapsed.Seconds()*1000, 'f', -1, 64),
"time": entry.Time.Format(time.RFC3339Nano),
"client": l.getClientIP(entry.IP),
}
jsonEntry["question"] = map[string]interface{}{
"host": entry.QHost,
"type": entry.QType,
"class": entry.QClass,
}
if msg != nil {
jsonEntry["status"] = dns.RcodeToString[msg.Rcode]
opt := msg.IsEdns0()
dnssecOk := false
if opt != nil {
dnssecOk = opt.Do()
}
jsonEntry["answer_dnssec"] = dnssecOk
}
if len(entry.Result.Rule) > 0 {
jsonEntry["rule"] = entry.Result.Rule
jsonEntry["filterId"] = entry.Result.FilterID
}
if len(entry.Result.ServiceName) != 0 {
jsonEntry["service_name"] = entry.Result.ServiceName
}
answers := answerToMap(msg)
if answers != nil {
jsonEntry["answer"] = answers
}
if len(entry.OrigAnswer) != 0 {
a := new(dns.Msg)
err := a.Unpack(entry.OrigAnswer)
if err == nil {
answers = answerToMap(a)
if answers != nil {
jsonEntry["original_answer"] = answers
}
} else {
log.Debug("Querylog: msg.Unpack(entry.OrigAnswer): %s: %s", err, string(entry.OrigAnswer))
}
}
return jsonEntry
}
func answerToMap(a *dns.Msg) []map[string]interface{} {
if a == nil || len(a.Answer) == 0 {
return nil
}
var answers = []map[string]interface{}{}
for _, k := range a.Answer {
header := k.Header()
answer := map[string]interface{}{
"type": dns.TypeToString[header.Rrtype],
"ttl": header.Ttl,
}
// try most common record types
switch v := k.(type) {
case *dns.A:
answer["value"] = v.A.String()
case *dns.AAAA:
answer["value"] = v.AAAA.String()
case *dns.MX:
answer["value"] = fmt.Sprintf("%v %v", v.Preference, v.Mx)
case *dns.CNAME:
answer["value"] = v.Target
case *dns.NS:
answer["value"] = v.Ns
case *dns.SPF:
answer["value"] = v.Txt
case *dns.TXT:
answer["value"] = v.Txt
case *dns.PTR:
answer["value"] = v.Ptr
case *dns.SOA:
answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v", v.Ns, v.Mbox, v.Serial, v.Refresh, v.Retry, v.Expire, v.Minttl)
case *dns.CAA:
answer["value"] = fmt.Sprintf("%v %v \"%v\"", v.Flag, v.Tag, v.Value)
case *dns.HINFO:
answer["value"] = fmt.Sprintf("\"%v\" \"%v\"", v.Cpu, v.Os)
case *dns.RRSIG:
answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v %v %v", dns.TypeToString[v.TypeCovered], v.Algorithm, v.Labels, v.OrigTtl, v.Expiration, v.Inception, v.KeyTag, v.SignerName, v.Signature)
default:
// type unknown, marshall it as-is
answer["value"] = v
}
answers = append(answers, answer)
}
return answers
}

View File

@ -4,13 +4,30 @@ import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"strconv"
"time"
"github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/golibs/jsonutil"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
type qlogConfig struct {
Enabled bool `json:"enabled"`
Interval uint32 `json:"interval"`
AnonymizeClientIP bool `json:"anonymize_client_ip"`
}
// Register web handlers
func (l *queryLog) initWeb() {
l.conf.HTTPRegister("GET", "/control/querylog", l.handleQueryLog)
l.conf.HTTPRegister("GET", "/control/querylog_info", l.handleQueryLogInfo)
l.conf.HTTPRegister("POST", "/control/querylog_clear", l.handleQueryLogClear)
l.conf.HTTPRegister("POST", "/control/querylog_config", l.handleQueryLogConfig)
}
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
text := fmt.Sprintf(format, args...)
@ -19,74 +36,18 @@ func httpError(r *http.Request, w http.ResponseWriter, code int, format string,
http.Error(w, text, code)
}
type request struct {
olderThan string
filterDomain string
filterClient string
filterQuestionType string
filterResponseStatus string
}
// "value" -> value, return TRUE
func getDoubleQuotesEnclosedValue(s *string) bool {
t := *s
if len(t) >= 2 && t[0] == '"' && t[len(t)-1] == '"' {
*s = t[1 : len(t)-1]
return true
}
return false
}
func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) {
var err error
req := request{}
q := r.URL.Query()
req.olderThan = q.Get("older_than")
req.filterDomain = q.Get("filter_domain")
req.filterClient = q.Get("filter_client")
req.filterQuestionType = q.Get("filter_question_type")
req.filterResponseStatus = q.Get("filter_response_status")
params := getDataParams{
Domain: req.filterDomain,
Client: req.filterClient,
ResponseStatus: responseStatusAll,
}
if len(req.olderThan) != 0 {
params.OlderThan, err = time.Parse(time.RFC3339Nano, req.olderThan)
params, err := l.parseSearchParams(r)
if err != nil {
httpError(r, w, http.StatusBadRequest, "invalid time stamp: %s", err)
httpError(r, w, http.StatusBadRequest, "failed to parse params: %s", err)
return
}
}
if getDoubleQuotesEnclosedValue(&params.Domain) {
params.StrictMatchDomain = true
}
if getDoubleQuotesEnclosedValue(&params.Client) {
params.StrictMatchClient = true
}
// search for the log entries
entries, oldest := l.search(params)
if len(req.filterQuestionType) != 0 {
_, ok := dns.StringToType[req.filterQuestionType]
if !ok {
httpError(r, w, http.StatusBadRequest, "invalid question_type")
return
}
params.QuestionType = req.filterQuestionType
}
if len(req.filterResponseStatus) != 0 {
switch req.filterResponseStatus {
case "filtered":
params.ResponseStatus = responseStatusFiltered
default:
httpError(r, w, http.StatusBadRequest, "invalid response_status")
return
}
}
data := l.getData(params)
// convert log entries to JSON
var data = l.entriesToJSON(entries, oldest)
jsonVal, err := json.Marshal(data)
if err != nil {
@ -101,16 +62,10 @@ func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) {
}
}
func (l *queryLog) handleQueryLogClear(w http.ResponseWriter, r *http.Request) {
func (l *queryLog) handleQueryLogClear(_ http.ResponseWriter, _ *http.Request) {
l.clear()
}
type qlogConfig struct {
Enabled bool `json:"enabled"`
Interval uint32 `json:"interval"`
AnonymizeClientIP bool `json:"anonymize_client_ip"`
}
// Get configuration
func (l *queryLog) handleQueryLogInfo(w http.ResponseWriter, r *http.Request) {
resp := qlogConfig{}
@ -162,10 +117,85 @@ func (l *queryLog) handleQueryLogConfig(w http.ResponseWriter, r *http.Request)
l.conf.ConfigModified()
}
// Register web handlers
func (l *queryLog) initWeb() {
l.conf.HTTPRegister("GET", "/control/querylog", l.handleQueryLog)
l.conf.HTTPRegister("GET", "/control/querylog_info", l.handleQueryLogInfo)
l.conf.HTTPRegister("POST", "/control/querylog_clear", l.handleQueryLogClear)
l.conf.HTTPRegister("POST", "/control/querylog_config", l.handleQueryLogConfig)
// "value" -> value, return TRUE
func getDoubleQuotesEnclosedValue(s *string) bool {
t := *s
if len(t) >= 2 && t[0] == '"' && t[len(t)-1] == '"' {
*s = t[1 : len(t)-1]
return true
}
return false
}
// parseSearchCriteria - parses "searchCriteria" from the specified query parameter
func (l *queryLog) parseSearchCriteria(q url.Values, name string, ct criteriaType) (bool, searchCriteria, error) {
val := q.Get(name)
if len(val) == 0 {
return false, searchCriteria{}, nil
}
c := searchCriteria{
criteriaType: ct,
value: val,
}
if getDoubleQuotesEnclosedValue(&c.value) {
c.strict = true
}
if ct == ctClient && l.conf.AnonymizeClientIP {
c.value = l.getClientIP(c.value)
}
if ct == ctFilteringStatus && !util.ContainsString(filteringStatusValues, c.value) {
return false, c, fmt.Errorf("invalid value %s", c.value)
}
return true, c, nil
}
// parseSearchParams - parses "searchParams" from the HTTP request's query string
func (l *queryLog) parseSearchParams(r *http.Request) (*searchParams, error) {
p := newSearchParams()
var err error
q := r.URL.Query()
olderThan := q.Get("older_than")
if len(olderThan) != 0 {
p.olderThan, err = time.Parse(time.RFC3339Nano, olderThan)
if err != nil {
return nil, err
}
}
if limit, err := strconv.ParseInt(q.Get("limit"), 10, 64); err == nil {
p.limit = int(limit)
// If limit or offset are specified explicitly, we should change the default behavior
// and scan all log records until we found enough log entries
p.maxFileScanEntries = 0
}
if offset, err := strconv.ParseInt(q.Get("offset"), 10, 64); err == nil {
p.offset = int(offset)
p.maxFileScanEntries = 0
}
paramNames := map[string]criteriaType{
"filter_domain": ctDomain,
"filter_client": ctClient,
"filter_question_type": ctQuestionType,
"filter_response_status": ctFilteringStatus,
}
for k, v := range paramNames {
ok, c, err := l.parseSearchCriteria(q, k, v)
if err != nil {
return nil, err
}
if ok {
p.searchCriteria = append(p.searchCriteria, c)
}
}
return p, nil
}

255
querylog/qlog_test.go Normal file
View File

@ -0,0 +1,255 @@
package querylog
import (
"net"
"os"
"testing"
"github.com/AdguardTeam/dnsproxy/proxyutil"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
)
func prepareTestDir() string {
const dir = "./agh-test"
_ = os.RemoveAll(dir)
_ = os.MkdirAll(dir, 0755)
return dir
}
// Check adding and loading (with filtering) entries from disk and memory
func TestQueryLog(t *testing.T) {
conf := Config{
Enabled: true,
Interval: 1,
MemSize: 100,
}
conf.BaseDir = prepareTestDir()
defer func() { _ = os.RemoveAll(conf.BaseDir) }()
l := newQueryLog(conf)
// add disk entries
addEntry(l, "example.org", "1.1.1.1", "2.2.2.1")
// write to disk (first file)
_ = l.flushLogBuffer(true)
// start writing to the second file
_ = l.rotate()
// add disk entries
addEntry(l, "example.org", "1.1.1.2", "2.2.2.2")
// write to disk
_ = l.flushLogBuffer(true)
// add memory entries
addEntry(l, "test.example.org", "1.1.1.3", "2.2.2.3")
addEntry(l, "example.com", "1.1.1.4", "2.2.2.4")
// get all entries
params := newSearchParams()
entries, _ := l.search(params)
assert.Equal(t, 4, len(entries))
assertLogEntry(t, entries[0], "example.com", "1.1.1.4", "2.2.2.4")
assertLogEntry(t, entries[1], "test.example.org", "1.1.1.3", "2.2.2.3")
assertLogEntry(t, entries[2], "example.org", "1.1.1.2", "2.2.2.2")
assertLogEntry(t, entries[3], "example.org", "1.1.1.1", "2.2.2.1")
// search by domain (strict)
params = newSearchParams()
params.searchCriteria = append(params.searchCriteria, searchCriteria{
criteriaType: ctDomain,
strict: true,
value: "test.example.org",
})
entries, _ = l.search(params)
assert.Equal(t, 1, len(entries))
assertLogEntry(t, entries[0], "test.example.org", "1.1.1.3", "2.2.2.3")
// search by domain (not strict)
params = newSearchParams()
params.searchCriteria = append(params.searchCriteria, searchCriteria{
criteriaType: ctDomain,
strict: false,
value: "example.org",
})
entries, _ = l.search(params)
assert.Equal(t, 3, len(entries))
assertLogEntry(t, entries[0], "test.example.org", "1.1.1.3", "2.2.2.3")
assertLogEntry(t, entries[1], "example.org", "1.1.1.2", "2.2.2.2")
assertLogEntry(t, entries[2], "example.org", "1.1.1.1", "2.2.2.1")
// search by client IP (strict)
params = newSearchParams()
params.searchCriteria = append(params.searchCriteria, searchCriteria{
criteriaType: ctClient,
strict: true,
value: "2.2.2.2",
})
entries, _ = l.search(params)
assert.Equal(t, 1, len(entries))
assertLogEntry(t, entries[0], "example.org", "1.1.1.2", "2.2.2.2")
// search by client IP (part of)
params = newSearchParams()
params.searchCriteria = append(params.searchCriteria, searchCriteria{
criteriaType: ctClient,
strict: false,
value: "2.2.2",
})
entries, _ = l.search(params)
assert.Equal(t, 4, len(entries))
assertLogEntry(t, entries[0], "example.com", "1.1.1.4", "2.2.2.4")
assertLogEntry(t, entries[1], "test.example.org", "1.1.1.3", "2.2.2.3")
assertLogEntry(t, entries[2], "example.org", "1.1.1.2", "2.2.2.2")
assertLogEntry(t, entries[3], "example.org", "1.1.1.1", "2.2.2.1")
}
func TestQueryLogOffsetLimit(t *testing.T) {
conf := Config{
Enabled: true,
Interval: 1,
MemSize: 100,
}
conf.BaseDir = prepareTestDir()
defer func() { _ = os.RemoveAll(conf.BaseDir) }()
l := newQueryLog(conf)
// add 10 entries to the log
for i := 0; i < 10; i++ {
addEntry(l, "second.example.org", "1.1.1.1", "2.2.2.1")
}
// write them to disk (first file)
_ = l.flushLogBuffer(true)
// add 10 more entries to the log (memory)
for i := 0; i < 10; i++ {
addEntry(l, "first.example.org", "1.1.1.1", "2.2.2.1")
}
// First page
params := newSearchParams()
params.offset = 0
params.limit = 10
entries, _ := l.search(params)
assert.Equal(t, 10, len(entries))
assert.Equal(t, entries[0].QHost, "first.example.org")
assert.Equal(t, entries[9].QHost, "first.example.org")
// Second page
params.offset = 10
params.limit = 10
entries, _ = l.search(params)
assert.Equal(t, 10, len(entries))
assert.Equal(t, entries[0].QHost, "second.example.org")
assert.Equal(t, entries[9].QHost, "second.example.org")
// Second and a half page
params.offset = 15
params.limit = 10
entries, _ = l.search(params)
assert.Equal(t, 5, len(entries))
assert.Equal(t, entries[0].QHost, "second.example.org")
assert.Equal(t, entries[4].QHost, "second.example.org")
// Third page
params.offset = 20
params.limit = 10
entries, _ = l.search(params)
assert.Equal(t, 0, len(entries))
}
func TestQueryLogMaxFileScanEntries(t *testing.T) {
conf := Config{
Enabled: true,
Interval: 1,
MemSize: 100,
}
conf.BaseDir = prepareTestDir()
defer func() { _ = os.RemoveAll(conf.BaseDir) }()
l := newQueryLog(conf)
// add 10 entries to the log
for i := 0; i < 10; i++ {
addEntry(l, "example.org", "1.1.1.1", "2.2.2.1")
}
// write them to disk (first file)
_ = l.flushLogBuffer(true)
params := newSearchParams()
params.maxFileScanEntries = 5 // do not scan more than 5 records
entries, _ := l.search(params)
assert.Equal(t, 5, len(entries))
params.maxFileScanEntries = 0 // disable the limit
entries, _ = l.search(params)
assert.Equal(t, 10, len(entries))
}
func TestJSON(t *testing.T) {
s := `
{"keystr":"val","obj":{"keybool":true,"keyint":123456}}
`
k, v, jtype := readJSON(&s)
assert.Equal(t, jtype, int32(jsonTStr))
assert.Equal(t, "keystr", k)
assert.Equal(t, "val", v)
k, v, jtype = readJSON(&s)
assert.Equal(t, jtype, int32(jsonTObj))
assert.Equal(t, "obj", k)
k, v, jtype = readJSON(&s)
assert.Equal(t, jtype, int32(jsonTBool))
assert.Equal(t, "keybool", k)
assert.Equal(t, "true", v)
k, v, jtype = readJSON(&s)
assert.Equal(t, jtype, int32(jsonTNum))
assert.Equal(t, "keyint", k)
assert.Equal(t, "123456", v)
k, v, jtype = readJSON(&s)
assert.True(t, jtype == jsonTErr)
}
func addEntry(l *queryLog, host, answerStr, client string) {
q := dns.Msg{}
q.Question = append(q.Question, dns.Question{
Name: host + ".",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
})
a := dns.Msg{}
a.Question = append(a.Question, q.Question[0])
answer := new(dns.A)
answer.Hdr = dns.RR_Header{
Name: q.Question[0].Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
}
answer.A = net.ParseIP(answerStr)
a.Answer = append(a.Answer, answer)
res := dnsfilter.Result{}
params := AddParams{
Question: &q,
Answer: &a,
Result: &res,
ClientIP: net.ParseIP(client),
Upstream: "upstream",
}
l.Add(params)
}
func assertLogEntry(t *testing.T, entry *logEntry, host, answer, client string) bool {
assert.Equal(t, host, entry.QHost)
assert.Equal(t, client, entry.IP)
assert.Equal(t, "A", entry.QType)
assert.Equal(t, "IN", entry.QClass)
msg := new(dns.Msg)
assert.Nil(t, msg.Unpack(entry.Answer))
assert.Equal(t, 1, len(msg.Answer))
ip := proxyutil.GetIPFromDNSRecord(msg.Answer[0])
assert.NotNil(t, ip)
assert.Equal(t, answer, ip.String())
return true
}

View File

@ -1,18 +1,72 @@
package querylog
import (
"encoding/base64"
"io"
"strconv"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
// search - searches log entries in the query log using specified parameters
// returns the list of entries found + time of the oldest entry
func (l *queryLog) search(params *searchParams) ([]*logEntry, time.Time) {
now := time.Now()
if params.limit == 0 {
return []*logEntry{}, time.Time{}
}
// add from file
fileEntries, oldest, total := l.searchFiles(params)
// add from memory buffer
l.bufferLock.Lock()
total += len(l.buffer)
memoryEntries := make([]*logEntry, 0)
// go through the buffer in the reverse order
// from NEWER to OLDER
for i := len(l.buffer) - 1; i >= 0; i-- {
entry := l.buffer[i]
if !params.match(entry) {
continue
}
memoryEntries = append(memoryEntries, entry)
}
l.bufferLock.Unlock()
// limits
totalLimit := params.offset + params.limit
// now let's get a unified collection
entries := append(memoryEntries, fileEntries...)
if len(entries) > totalLimit {
// remove extra records
entries = entries[:totalLimit]
}
if params.offset > 0 {
if len(entries) > params.offset {
entries = entries[params.offset:]
} else {
entries = make([]*logEntry, 0)
oldest = time.Time{}
}
}
if len(entries) == totalLimit {
// change the "oldest" value here.
// we cannot use the "oldest" we got from "searchFiles" anymore
// because after adding in-memory records and removing extra records
// the situation has changed
oldest = entries[len(entries)-1].Time
}
log.Debug("QueryLog: prepared data (%d/%d) older than %s in %s",
len(entries), total, params.olderThan, time.Since(now))
return entries, oldest
}
// searchFiles reads log entries from all log files and applies the specified search criteria.
// IMPORTANT: this method does not scan more than "maxSearchEntries" so you
// may need to call it many times.
@ -21,7 +75,7 @@ import (
// * an array of log entries that we have read
// * time of the oldest processed entry (even if it was discarded)
// * total number of processed entries (including discarded).
func (l *queryLog) searchFiles(params getDataParams) ([]*logEntry, time.Time, int) {
func (l *queryLog) searchFiles(params *searchParams) ([]*logEntry, time.Time, int) {
entries := make([]*logEntry, 0)
oldest := time.Time{}
@ -32,10 +86,10 @@ func (l *queryLog) searchFiles(params getDataParams) ([]*logEntry, time.Time, in
}
defer r.Close()
if params.OlderThan.IsZero() {
if params.olderThan.IsZero() {
err = r.SeekStart()
} else {
err = r.Seek(params.OlderThan.UnixNano())
err = r.Seek(params.olderThan.UnixNano())
if err == nil {
// Read to the next record right away
// The one that was specified in the "oldest" param is not needed,
@ -45,14 +99,17 @@ func (l *queryLog) searchFiles(params getDataParams) ([]*logEntry, time.Time, in
}
if err != nil {
log.Debug("Cannot Seek() to %v: %v", params.OlderThan, err)
log.Debug("Cannot Seek() to %v: %v", params.olderThan, err)
return entries, oldest, 0
}
totalLimit := params.offset + params.limit
total := 0
oldestNano := int64(0)
// Do not scan more than 50k at once
for total <= maxSearchEntries {
// By default, we do not scan more than "maxFileScanEntries" at once
// The idea is to make search calls faster so that the UI could handle it and show something
// This behavior can be overridden if "maxFileScanEntries" is set to 0
for total < params.maxFileScanEntries || params.maxFileScanEntries <= 0 {
entry, ts, err := l.readNextEntry(r, params)
if err == io.EOF {
@ -65,8 +122,8 @@ func (l *queryLog) searchFiles(params getDataParams) ([]*logEntry, time.Time, in
if entry != nil {
entries = append(entries, entry)
if len(entries) == getDataLimit {
// Do not read more than "getDataLimit" records at once
if len(entries) == totalLimit {
// Do not read more than "totalLimit" records at once
break
}
}
@ -82,7 +139,7 @@ func (l *queryLog) searchFiles(params getDataParams) ([]*logEntry, time.Time, in
// * log entry that matches search criteria or null if it was discarded (or if there's nothing to read)
// * timestamp of the processed log entry
// * error if we can't read anymore
func (l *queryLog) readNextEntry(r *QLogReader, params getDataParams) (*logEntry, int64, error) {
func (l *queryLog) readNextEntry(r *QLogReader, params *searchParams) (*logEntry, int64, error) {
line, err := r.ReadNext()
if err != nil {
return nil, 0, err
@ -92,7 +149,7 @@ func (l *queryLog) readNextEntry(r *QLogReader, params getDataParams) (*logEntry
timestamp := readQLogTimestamp(line)
// Quick check without deserializing log entry
if !quickMatchesGetDataParams(line, params) {
if !params.quickMatch(line) {
return nil, timestamp, nil
}
@ -100,7 +157,7 @@ func (l *queryLog) readNextEntry(r *QLogReader, params getDataParams) (*logEntry
decodeLogEntry(&entry, line)
// Full check of the deserialized log entry
if !matchesGetDataParams(&entry, params) {
if !params.match(&entry) {
return nil, timestamp, nil
}
@ -120,257 +177,3 @@ func (l *queryLog) openReader() (*QLogReader, error) {
return NewQLogReader(files)
}
// quickMatchesGetDataParams - quickly checks if the line matches getDataParams
// this method does not guarantee anything and the reason is to do a quick check
// without deserializing anything
func quickMatchesGetDataParams(line string, params getDataParams) bool {
if params.ResponseStatus == responseStatusFiltered {
boolVal, ok := readJSONBool(line, "IsFiltered")
if !ok || !boolVal {
return false
}
}
if len(params.Domain) != 0 {
val := readJSONValue(line, "QH")
if len(val) == 0 {
return false
}
if (params.StrictMatchDomain && val != params.Domain) ||
(!params.StrictMatchDomain && strings.Index(val, params.Domain) == -1) {
return false
}
}
if len(params.QuestionType) != 0 {
val := readJSONValue(line, "QT")
if val != params.QuestionType {
return false
}
}
if len(params.Client) != 0 {
val := readJSONValue(line, "IP")
if len(val) == 0 {
log.Debug("QueryLog: failed to decodeLogEntry")
return false
}
if (params.StrictMatchClient && val != params.Client) ||
(!params.StrictMatchClient && strings.Index(val, params.Client) == -1) {
return false
}
}
return true
}
// matchesGetDataParams - returns true if the entry matches the search parameters
func matchesGetDataParams(entry *logEntry, params getDataParams) bool {
if params.ResponseStatus == responseStatusFiltered && !entry.Result.IsFiltered {
return false
}
if len(params.QuestionType) != 0 {
if entry.QType != params.QuestionType {
return false
}
}
if len(params.Domain) != 0 {
if (params.StrictMatchDomain && entry.QHost != params.Domain) ||
(!params.StrictMatchDomain && strings.Index(entry.QHost, params.Domain) == -1) {
return false
}
}
if len(params.Client) != 0 {
if (params.StrictMatchClient && entry.IP != params.Client) ||
(!params.StrictMatchClient && strings.Index(entry.IP, params.Client) == -1) {
return false
}
}
return true
}
// decodeLogEntry - decodes query log entry from a line
// nolint (gocyclo)
func decodeLogEntry(ent *logEntry, str string) {
var b bool
var i int
var err error
for {
k, v, t := readJSON(&str)
if t == jsonTErr {
break
}
switch k {
case "IP":
if len(ent.IP) == 0 {
ent.IP = v
}
case "T":
ent.Time, err = time.Parse(time.RFC3339, v)
case "QH":
ent.QHost = v
case "QT":
ent.QType = v
case "QC":
ent.QClass = v
case "Answer":
ent.Answer, err = base64.StdEncoding.DecodeString(v)
case "OrigAnswer":
ent.OrigAnswer, err = base64.StdEncoding.DecodeString(v)
case "IsFiltered":
b, err = strconv.ParseBool(v)
ent.Result.IsFiltered = b
case "Rule":
ent.Result.Rule = v
case "FilterID":
i, err = strconv.Atoi(v)
ent.Result.FilterID = int64(i)
case "Reason":
i, err = strconv.Atoi(v)
ent.Result.Reason = dnsfilter.Reason(i)
case "Upstream":
ent.Upstream = v
case "Elapsed":
i, err = strconv.Atoi(v)
ent.Elapsed = time.Duration(i)
// pre-v0.99.3 compatibility:
case "Question":
var qstr []byte
qstr, err = base64.StdEncoding.DecodeString(v)
if err != nil {
break
}
q := new(dns.Msg)
err = q.Unpack(qstr)
if err != nil {
break
}
ent.QHost = q.Question[0].Name
if len(ent.QHost) == 0 {
break
}
ent.QHost = ent.QHost[:len(ent.QHost)-1]
ent.QType = dns.TypeToString[q.Question[0].Qtype]
ent.QClass = dns.ClassToString[q.Question[0].Qclass]
case "Time":
ent.Time, err = time.Parse(time.RFC3339, v)
}
if err != nil {
log.Debug("decodeLogEntry err: %s", err)
break
}
}
}
// Get bool value from "key":bool
func readJSONBool(s, name string) (bool, bool) {
i := strings.Index(s, "\""+name+"\":")
if i == -1 {
return false, false
}
start := i + 1 + len(name) + 2
b := false
if strings.HasPrefix(s[start:], "true") {
b = true
} else if !strings.HasPrefix(s[start:], "false") {
return false, false
}
return b, true
}
// Get value from "key":"value"
func readJSONValue(s, name string) string {
i := strings.Index(s, "\""+name+"\":\"")
if i == -1 {
return ""
}
start := i + 1 + len(name) + 3
i = strings.IndexByte(s[start:], '"')
if i == -1 {
return ""
}
end := start + i
return s[start:end]
}
const (
jsonTErr = iota
jsonTObj
jsonTStr
jsonTNum
jsonTBool
)
// Parse JSON key-value pair
// e.g.: "key":VALUE where VALUE is "string", true|false (boolean), or 123.456 (number)
// Note the limitations:
// . doesn't support whitespace
// . doesn't support "null"
// . doesn't validate boolean or number
// . no proper handling of {} braces
// . no handling of [] brackets
// Return (key, value, type)
func readJSON(ps *string) (string, string, int32) {
s := *ps
k := ""
v := ""
t := int32(jsonTErr)
q1 := strings.IndexByte(s, '"')
if q1 == -1 {
return k, v, t
}
q2 := strings.IndexByte(s[q1+1:], '"')
if q2 == -1 {
return k, v, t
}
k = s[q1+1 : q1+1+q2]
s = s[q1+1+q2+1:]
if len(s) < 2 || s[0] != ':' {
return k, v, t
}
if s[1] == '"' {
q2 = strings.IndexByte(s[2:], '"')
if q2 == -1 {
return k, v, t
}
v = s[2 : 2+q2]
t = jsonTStr
s = s[2+q2+1:]
} else if s[1] == '{' {
t = jsonTObj
s = s[1+1:]
} else {
sep := strings.IndexAny(s[1:], ",}")
if sep == -1 {
return k, v, t
}
v = s[1 : 1+sep]
if s[1] == 't' || s[1] == 'f' {
t = jsonTBool
} else if s[1] == '.' || (s[1] >= '0' && s[1] <= '9') {
t = jsonTNum
}
s = s[1+sep+1:]
}
*ps = s
return k, v, t
}

View File

@ -1,176 +0,0 @@
package querylog
import (
"net"
"os"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
)
func prepareTestDir() string {
const dir = "./agh-test"
_ = os.RemoveAll(dir)
_ = os.MkdirAll(dir, 0755)
return dir
}
// Check adding and loading (with filtering) entries from disk and memory
func TestQueryLog(t *testing.T) {
conf := Config{
Enabled: true,
Interval: 1,
MemSize: 100,
}
conf.BaseDir = prepareTestDir()
defer func() { _ = os.RemoveAll(conf.BaseDir) }()
l := newQueryLog(conf)
// add disk entries
addEntry(l, "example.org", "1.1.1.1", "2.2.2.1")
// write to disk (first file)
_ = l.flushLogBuffer(true)
// start writing to the second file
_ = l.rotate()
// add disk entries
addEntry(l, "example.org", "1.1.1.2", "2.2.2.2")
// write to disk
_ = l.flushLogBuffer(true)
// add memory entries
addEntry(l, "test.example.org", "1.1.1.3", "2.2.2.3")
addEntry(l, "example.com", "1.1.1.4", "2.2.2.4")
// get all entries
params := getDataParams{
OlderThan: time.Time{},
}
d := l.getData(params)
mdata := d["data"].([]map[string]interface{})
assert.Equal(t, 4, len(mdata))
assert.True(t, checkEntry(t, mdata[0], "example.com", "1.1.1.4", "2.2.2.4"))
assert.True(t, checkEntry(t, mdata[1], "test.example.org", "1.1.1.3", "2.2.2.3"))
assert.True(t, checkEntry(t, mdata[2], "example.org", "1.1.1.2", "2.2.2.2"))
assert.True(t, checkEntry(t, mdata[3], "example.org", "1.1.1.1", "2.2.2.1"))
// search by domain (strict)
params = getDataParams{
OlderThan: time.Time{},
Domain: "test.example.org",
StrictMatchDomain: true,
}
d = l.getData(params)
mdata = d["data"].([]map[string]interface{})
assert.True(t, len(mdata) == 1)
assert.True(t, checkEntry(t, mdata[0], "test.example.org", "1.1.1.3", "2.2.2.3"))
// search by domain (not strict)
params = getDataParams{
OlderThan: time.Time{},
Domain: "example.org",
StrictMatchDomain: false,
}
d = l.getData(params)
mdata = d["data"].([]map[string]interface{})
assert.Equal(t, 3, len(mdata))
assert.True(t, checkEntry(t, mdata[0], "test.example.org", "1.1.1.3", "2.2.2.3"))
assert.True(t, checkEntry(t, mdata[1], "example.org", "1.1.1.2", "2.2.2.2"))
assert.True(t, checkEntry(t, mdata[2], "example.org", "1.1.1.1", "2.2.2.1"))
// search by client IP (strict)
params = getDataParams{
OlderThan: time.Time{},
Client: "2.2.2.2",
StrictMatchClient: true,
}
d = l.getData(params)
mdata = d["data"].([]map[string]interface{})
assert.Equal(t, 1, len(mdata))
assert.True(t, checkEntry(t, mdata[0], "example.org", "1.1.1.2", "2.2.2.2"))
// search by client IP (part of)
params = getDataParams{
OlderThan: time.Time{},
Client: "2.2.2",
StrictMatchClient: false,
}
d = l.getData(params)
mdata = d["data"].([]map[string]interface{})
assert.Equal(t, 4, len(mdata))
assert.True(t, checkEntry(t, mdata[0], "example.com", "1.1.1.4", "2.2.2.4"))
assert.True(t, checkEntry(t, mdata[1], "test.example.org", "1.1.1.3", "2.2.2.3"))
assert.True(t, checkEntry(t, mdata[2], "example.org", "1.1.1.2", "2.2.2.2"))
assert.True(t, checkEntry(t, mdata[3], "example.org", "1.1.1.1", "2.2.2.1"))
}
func addEntry(l *queryLog, host, answerStr, client string) {
q := dns.Msg{}
q.Question = append(q.Question, dns.Question{
Name: host + ".",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
})
a := dns.Msg{}
a.Question = append(a.Question, q.Question[0])
answer := new(dns.A)
answer.Hdr = dns.RR_Header{
Name: q.Question[0].Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
}
answer.A = net.ParseIP(answerStr)
a.Answer = append(a.Answer, answer)
res := dnsfilter.Result{}
params := AddParams{
Question: &q,
Answer: &a,
Result: &res,
ClientIP: net.ParseIP(client),
Upstream: "upstream",
}
l.Add(params)
}
func checkEntry(t *testing.T, m map[string]interface{}, host, answer, client string) bool {
mq := m["question"].(map[string]interface{})
ma := m["answer"].([]map[string]interface{})
ma0 := ma[0]
if !assert.Equal(t, host, mq["host"].(string)) ||
!assert.Equal(t, "IN", mq["class"].(string)) ||
!assert.Equal(t, "A", mq["type"].(string)) ||
!assert.Equal(t, answer, ma0["value"].(string)) ||
!assert.Equal(t, client, m["client"].(string)) {
return false
}
return true
}
func TestJSON(t *testing.T) {
s := `
{"keystr":"val","obj":{"keybool":true,"keyint":123456}}
`
k, v, jtype := readJSON(&s)
assert.Equal(t, jtype, int32(jsonTStr))
assert.Equal(t, "keystr", k)
assert.Equal(t, "val", v)
k, v, jtype = readJSON(&s)
assert.Equal(t, jtype, int32(jsonTObj))
assert.Equal(t, "obj", k)
k, v, jtype = readJSON(&s)
assert.Equal(t, jtype, int32(jsonTBool))
assert.Equal(t, "keybool", k)
assert.Equal(t, "true", v)
k, v, jtype = readJSON(&s)
assert.Equal(t, jtype, int32(jsonTNum))
assert.Equal(t, "keyint", k)
assert.Equal(t, "123456", v)
k, v, jtype = readJSON(&s)
assert.True(t, jtype == jsonTErr)
}

139
querylog/search_criteria.go Normal file
View File

@ -0,0 +1,139 @@
package querylog
import (
"strings"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
)
type criteriaType int
const (
ctDomain criteriaType = iota // domain name
ctClient // client IP address
ctQuestionType // question type
ctFilteringStatus // filtering status
)
const (
filteringStatusAll = "all"
filteringStatusFiltered = "filtered" // all kinds of filtering
filteringStatusBlocked = "blocked" // blocked or blocked service
filteringStatusBlockedSafebrowsing = "blocked_safebrowsing" // blocked by safebrowsing
filteringStatusBlockedParental = "blocked_parental" // blocked by parental control
filteringStatusWhitelisted = "whitelisted" // whitelisted
filteringStatusRewritten = "rewritten" // all kinds of rewrites
filteringStatusSafeSearch = "safe_search" // enforced safe search
)
// filteringStatusValues -- array with all possible filteringStatus values
var filteringStatusValues = []string{
filteringStatusAll, filteringStatusFiltered, filteringStatusBlocked,
filteringStatusBlockedSafebrowsing, filteringStatusBlockedParental,
filteringStatusWhitelisted, filteringStatusRewritten, filteringStatusSafeSearch,
}
// searchCriteria - every search request may contain a list of different search criteria
// we use each of them to match the query
type searchCriteria struct {
criteriaType criteriaType // type of the criteria
strict bool // should we strictly match (equality) or not (indexOf)
value string // search criteria value
}
// quickMatch - quickly checks if the log entry matches this search criteria
// the reason is to do it as quickly as possible without de-serializing the entry
func (c *searchCriteria) quickMatch(line string) bool {
// note that we do this only for a limited set of criteria
switch c.criteriaType {
case ctDomain:
return c.quickMatchJSONValue(line, "QH")
case ctClient:
return c.quickMatchJSONValue(line, "IP")
case ctQuestionType:
return c.quickMatchJSONValue(line, "QT")
default:
return true
}
}
// quickMatchJSONValue - helper used by quickMatch
func (c *searchCriteria) quickMatchJSONValue(line string, propertyName string) bool {
val := readJSONValue(line, propertyName)
if len(val) == 0 {
return false
}
if c.strict && c.value == val {
return true
}
if !c.strict && strings.Contains(val, c.value) {
return true
}
return false
}
// match - checks if the log entry matches this search criteria
// nolint (gocyclo)
func (c *searchCriteria) match(entry *logEntry) bool {
switch c.criteriaType {
case ctDomain:
if c.strict && entry.QHost == c.value {
return true
}
if !c.strict && strings.Contains(entry.QHost, c.value) {
return true
}
return false
case ctClient:
if c.strict && entry.IP == c.value {
return true
}
if !c.strict && strings.Contains(entry.IP, c.value) {
return true
}
return false
case ctQuestionType:
if c.strict && entry.QType == c.value {
return true
}
if !c.strict && strings.Contains(entry.QType, c.value) {
return true
}
case ctFilteringStatus:
res := entry.Result
switch c.value {
case filteringStatusAll:
return true
case filteringStatusFiltered:
return res.IsFiltered
case filteringStatusBlocked:
return res.IsFiltered &&
(res.Reason == dnsfilter.FilteredBlackList ||
res.Reason == dnsfilter.FilteredBlockedService)
case filteringStatusBlockedParental:
return res.IsFiltered && res.Reason == dnsfilter.FilteredParental
case filteringStatusBlockedSafebrowsing:
return res.IsFiltered && res.Reason == dnsfilter.FilteredSafeBrowsing
case filteringStatusWhitelisted:
return res.IsFiltered && res.Reason == dnsfilter.NotFilteredWhiteList
case filteringStatusRewritten:
return res.IsFiltered &&
(res.Reason == dnsfilter.ReasonRewrite ||
res.Reason == dnsfilter.RewriteEtcHosts)
case filteringStatusSafeSearch:
return res.IsFiltered && res.Reason == dnsfilter.FilteredSafeSearch
default:
return false
}
default:
return false
}
return false
}

57
querylog/search_params.go Normal file
View File

@ -0,0 +1,57 @@
package querylog
import "time"
// searchParams represent the search query sent by the client
type searchParams struct {
// searchCriteria - list of search criteria that we use to get filter results
searchCriteria []searchCriteria
// olderThen - return entries that are older than this value
// if not set - disregard it and return any value
olderThan time.Time
offset int // offset for the search
limit int // limit the number of records returned
maxFileScanEntries int // maximum log entries to scan in query log files. if 0 - no limit
}
// newSearchParams - creates an empty instance of searchParams
func newSearchParams() *searchParams {
return &searchParams{
// default max log entries to return
limit: 500,
// by default, we scan up to 50k entries at once
maxFileScanEntries: 50000,
}
}
// quickMatchesGetDataParams - quickly checks if the line matches the searchParams
// this method does not guarantee anything and the reason is to do a quick check
// without deserializing anything
func (s *searchParams) quickMatch(line string) bool {
for _, c := range s.searchCriteria {
if !c.quickMatch(line) {
return false
}
}
return true
}
// match - checks if the logEntry matches the searchParams
func (s *searchParams) match(entry *logEntry) bool {
if !s.olderThan.IsZero() && entry.Time.UnixNano() >= s.olderThan.UnixNano() {
// Ignore entries newer than what was requested
return false
}
for _, c := range s.searchCriteria {
if !c.match(entry) {
return false
}
}
return true
}

View File

@ -10,6 +10,16 @@ import (
"strings"
)
// ContainsString checks if "v" is in the array "arr"
func ContainsString(arr []string, v string) bool {
for _, i := range arr {
if i == v {
return true
}
}
return false
}
// fileExists returns TRUE if file exists
func FileExists(fn string) bool {
_, err := os.Stat(fn)