From 696a6df6ba9992bdc0693bcbd38baceeaca5634d Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Fri, 21 Feb 2020 01:07:30 +0300 Subject: [PATCH] *(dnsforward): replaced old search with the new one --- querylog/qlog.go | 217 +++++-------- querylog/qlog_file.go | 6 +- querylog/qlog_file_test.go | 4 +- querylog/qlog_reader.go | 2 +- querylog/qlog_reader_test.go | 2 +- querylog/querylog_file.go | 597 +---------------------------------- querylog/querylog_search.go | 366 +++++++++++++++++++++ querylog/querylog_test.go | 60 ++-- 8 files changed, 481 insertions(+), 773 deletions(-) create mode 100644 querylog/querylog_search.go diff --git a/querylog/qlog.go b/querylog/qlog.go index 8a34130e..cafcdb5c 100644 --- a/querylog/qlog.go +++ b/querylog/qlog.go @@ -171,82 +171,6 @@ func (l *queryLog) Add(params AddParams) { } } -// Return TRUE if this entry is needed -func isNeeded(entry *logEntry, params getDataParams) bool { - if params.ResponseStatus == responseStatusFiltered && !entry.Result.IsFiltered { - return false - } - - if len(params.QuestionType) != 0 { - if entry.QType != params.QuestionType { - return false - } - } - - if len(params.Domain) != 0 { - if (params.StrictMatchDomain && entry.QHost != params.Domain) || - (!params.StrictMatchDomain && strings.Index(entry.QHost, params.Domain) == -1) { - return false - } - } - - if len(params.Client) != 0 { - if (params.StrictMatchClient && entry.IP != params.Client) || - (!params.StrictMatchClient && strings.Index(entry.IP, params.Client) == -1) { - return false - } - } - - return true -} - -func (l *queryLog) readFromFile(params getDataParams) ([]*logEntry, time.Time, int) { - entries := []*logEntry{} - oldest := time.Time{} - - r := l.OpenReader() - if r == nil { - return entries, time.Time{}, 0 - } - r.BeginRead(params.OlderThan, getDataLimit, ¶ms) - total := uint64(0) - for total <= maxSearchEntries { - newEntries := []*logEntry{} - for { - entry := r.Next() - if entry == nil { - break - } - - if !isNeeded(entry, params) { - continue - } - if len(newEntries) == getDataLimit { - newEntries = newEntries[1:] - } - newEntries = append(newEntries, entry) - } - - log.Debug("entries: +%d (%d) [%d]", len(newEntries), len(entries), r.Total()) - - entries = append(newEntries, entries...) - if len(entries) > getDataLimit { - toremove := len(entries) - getDataLimit - entries = entries[toremove:] - break - } - if r.Total() == 0 || len(entries) == getDataLimit { - break - } - total += r.Total() - oldest = r.Oldest() - r.BeginReadPrev(getDataLimit) - } - - r.Close() - return entries, oldest, int(total) -} - // Parameters for getData() type getDataParams struct { OlderThan time.Time // return entries that are older than this value @@ -267,17 +191,12 @@ const ( responseStatusFiltered ) -// Get log entries +// Gets log entries func (l *queryLog) getData(params getDataParams) map[string]interface{} { - var data = []map[string]interface{}{} - - var oldest time.Time now := time.Now() - entries := []*logEntry{} - total := 0 // add from file - entries, oldest, total = l.readFromFile(params) + fileEntries, oldest, total := l.searchFiles(params) if params.OlderThan.IsZero() { params.OlderThan = now @@ -286,9 +205,9 @@ func (l *queryLog) getData(params getDataParams) map[string]interface{} { // add from memory buffer l.bufferLock.Lock() total += len(l.buffer) + memoryEntries := make([]*logEntry, 0) for _, entry := range l.buffer { - - if !isNeeded(entry, params) { + if !matchesGetDataParams(entry, params) { continue } @@ -296,68 +215,24 @@ func (l *queryLog) getData(params getDataParams) map[string]interface{} { break } - if len(entries) == getDataLimit { - entries = entries[1:] - } - entries = append(entries, entry) + memoryEntries = append(memoryEntries, entry) } l.bufferLock.Unlock() - // process the elements from latest to oldest - for i := len(entries) - 1; i >= 0; i-- { + // now let's get a unified collection + entries := append(memoryEntries, fileEntries...) + if len(entries) > getDataLimit { + // remove extra records + entries = entries[(len(entries) - getDataLimit):] + } + + // init the response object + var data = []map[string]interface{}{} + + // the elements order is already reversed (from newer to older) + for i := 0; i < len(entries); i++ { entry := entries[i] - var a *dns.Msg - - if len(entry.Answer) > 0 { - a = new(dns.Msg) - if err := a.Unpack(entry.Answer); err != nil { - log.Debug("Failed to unpack dns message answer: %s: %s", err, string(entry.Answer)) - a = nil - } - } - - jsonEntry := map[string]interface{}{ - "reason": entry.Result.Reason.String(), - "elapsedMs": strconv.FormatFloat(entry.Elapsed.Seconds()*1000, 'f', -1, 64), - "time": entry.Time.Format(time.RFC3339Nano), - "client": entry.IP, - } - jsonEntry["question"] = map[string]interface{}{ - "host": entry.QHost, - "type": entry.QType, - "class": entry.QClass, - } - - if a != nil { - jsonEntry["status"] = dns.RcodeToString[a.Rcode] - } - if len(entry.Result.Rule) > 0 { - jsonEntry["rule"] = entry.Result.Rule - jsonEntry["filterId"] = entry.Result.FilterID - } - - if len(entry.Result.ServiceName) != 0 { - jsonEntry["service_name"] = entry.Result.ServiceName - } - - answers := answerToMap(a) - if answers != nil { - jsonEntry["answer"] = answers - } - - if len(entry.OrigAnswer) != 0 { - a := new(dns.Msg) - err := a.Unpack(entry.OrigAnswer) - if err == nil { - answers = answerToMap(a) - if answers != nil { - jsonEntry["original_answer"] = answers - } - } else { - log.Debug("Querylog: a.Unpack(entry.OrigAnswer): %s: %s", err, string(entry.OrigAnswer)) - } - } - + jsonEntry := logEntryToJSONEntry(entry) data = append(data, jsonEntry) } @@ -376,6 +251,62 @@ func (l *queryLog) getData(params getDataParams) map[string]interface{} { return result } +func logEntryToJSONEntry(entry *logEntry) map[string]interface{} { + var msg *dns.Msg + + if len(entry.Answer) > 0 { + msg = new(dns.Msg) + if err := msg.Unpack(entry.Answer); err != nil { + log.Debug("Failed to unpack dns message answer: %s: %s", err, string(entry.Answer)) + msg = nil + } + } + + jsonEntry := map[string]interface{}{ + "reason": entry.Result.Reason.String(), + "elapsedMs": strconv.FormatFloat(entry.Elapsed.Seconds()*1000, 'f', -1, 64), + "time": entry.Time.Format(time.RFC3339Nano), + "client": entry.IP, + } + jsonEntry["question"] = map[string]interface{}{ + "host": entry.QHost, + "type": entry.QType, + "class": entry.QClass, + } + + if msg != nil { + jsonEntry["status"] = dns.RcodeToString[msg.Rcode] + } + if len(entry.Result.Rule) > 0 { + jsonEntry["rule"] = entry.Result.Rule + jsonEntry["filterId"] = entry.Result.FilterID + } + + if len(entry.Result.ServiceName) != 0 { + jsonEntry["service_name"] = entry.Result.ServiceName + } + + answers := answerToMap(msg) + if answers != nil { + jsonEntry["answer"] = answers + } + + if len(entry.OrigAnswer) != 0 { + a := new(dns.Msg) + err := a.Unpack(entry.OrigAnswer) + if err == nil { + answers = answerToMap(a) + if answers != nil { + jsonEntry["original_answer"] = answers + } + } else { + log.Debug("Querylog: msg.Unpack(entry.OrigAnswer): %s: %s", err, string(entry.OrigAnswer)) + } + } + + return jsonEntry +} + func answerToMap(a *dns.Msg) []map[string]interface{} { if a == nil || len(a.Answer) == 0 { return nil diff --git a/querylog/qlog_file.go b/querylog/qlog_file.go index 14ae2b32..24c5a4e2 100644 --- a/querylog/qlog_file.go +++ b/querylog/qlog_file.go @@ -62,7 +62,7 @@ func NewQLogFile(path string) (*QLogFile, error) { // It returns the position of the the line with the timestamp we were looking for // so that when we call "ReadNext" this line was returned. // If we could not find it, it returns 0 and ErrSeekNotFound -func (q *QLogFile) Seek(timestamp uint64) (int64, error) { +func (q *QLogFile) Seek(timestamp int64) (int64, error) { q.lock.Lock() defer q.lock.Unlock() @@ -305,7 +305,7 @@ func (q *QLogFile) readProbeLine(position int64) (string, int64, error) { } // readQLogTimestamp reads the timestamp field from the query log line -func readQLogTimestamp(str string) uint64 { +func readQLogTimestamp(str string) int64 { val := readJSONValue(str, "T") if len(val) == 0 { val = readJSONValue(str, "Time") @@ -320,5 +320,5 @@ func readQLogTimestamp(str string) uint64 { log.Error("Couldn't parse timestamp: %s", val) return 0 } - return uint64(tm.UnixNano()) + return tm.UnixNano() } diff --git a/querylog/qlog_file_test.go b/querylog/qlog_file_test.go index 2ba6f38d..c5871475 100644 --- a/querylog/qlog_file_test.go +++ b/querylog/qlog_file_test.go @@ -100,7 +100,7 @@ func TestQLogFileSeekLargeFile(t *testing.T) { // CASE 6: Seek non-existent (too high) ts, _ := time.Parse(time.RFC3339, "2100-01-02T15:04:05Z07:00") - _, err = q.Seek(uint64(ts.UnixNano())) + _, err = q.Seek(ts.UnixNano()) assert.NotNil(t, err) } @@ -136,7 +136,7 @@ func TestQLogFileSeekSmallFile(t *testing.T) { // CASE 6: Seek non-existent (too high) ts, _ := time.Parse(time.RFC3339, "2100-01-02T15:04:05Z07:00") - _, err = q.Seek(uint64(ts.UnixNano())) + _, err = q.Seek(ts.UnixNano()) assert.NotNil(t, err) } diff --git a/querylog/qlog_reader.go b/querylog/qlog_reader.go index 6b081cc2..5f4f4dfa 100644 --- a/querylog/qlog_reader.go +++ b/querylog/qlog_reader.go @@ -48,7 +48,7 @@ func NewQLogReader(files []string) (*QLogReader, error) { // // Returns nil if the record is successfully found. // Returns an error if for some reason we could not find a record with the specified timestamp. -func (r *QLogReader) Seek(timestamp uint64) error { +func (r *QLogReader) Seek(timestamp int64) error { for i := len(r.qFiles) - 1; i >= 0; i-- { q := r.qFiles[i] _, err := q.Seek(timestamp) diff --git a/querylog/qlog_reader_test.go b/querylog/qlog_reader_test.go index 592ab4b9..357b4f9d 100644 --- a/querylog/qlog_reader_test.go +++ b/querylog/qlog_reader_test.go @@ -122,7 +122,7 @@ func TestQLogReaderSeek(t *testing.T) { // CASE 6: Seek non-existent (too high) ts, _ := time.Parse(time.RFC3339, "2100-01-02T15:04:05Z07:00") - err = r.Seek(uint64(ts.UnixNano())) + err = r.Seek(ts.UnixNano()) assert.NotNil(t, err) } diff --git a/querylog/querylog_file.go b/querylog/querylog_file.go index e5544fc4..3b32b7a2 100644 --- a/querylog/querylog_file.go +++ b/querylog/querylog_file.go @@ -1,23 +1,16 @@ package querylog import ( - "bufio" "bytes" - "encoding/base64" "encoding/json" - "io" "os" - "strconv" - "strings" "time" - "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/golibs/log" - "github.com/miekg/dns" ) // TODO: Check this when we append a new line -- we don't want to have a line longer than this -const maxEntrySize = 1000 +const maxEntrySize = 1024 // flushLogBuffer flushes the current buffer to file and resets the current buffer func (l *queryLog) flushLogBuffer(fullFlush bool) error { @@ -117,591 +110,3 @@ func (l *queryLog) periodicRotate() { } } } - -// Reader is the DB reader context -type Reader struct { - ql *queryLog - search *getDataParams - - f *os.File - reader *bufio.Reader // reads file line by line - now time.Time - validFrom int64 // UNIX time (ns) - olderThan int64 // UNIX time (ns) - oldest time.Time - - files []string - ifile int - - limit uint64 - count uint64 // counter for returned elements - latest bool // return the latest entries - filePrepared bool - - seeking bool // we're seaching for an entry with exact time stamp - fseeker fileSeeker // file seeker object - fpos uint64 // current file offset - nSeekRequests uint32 // number of Seek() requests made (finding a new line doesn't count) -} - -type fileSeeker struct { - target uint64 // target value - - pos uint64 // current offset, may be adjusted by user for increased accuracy - lastpos uint64 // the last offset returned - lo uint64 // low boundary offset - hi uint64 // high boundary offset -} - -// OpenReader - return reader object -func (l *queryLog) OpenReader() *Reader { - r := Reader{} - r.ql = l - r.now = time.Now() - r.validFrom = r.now.Unix() - int64(l.conf.Interval*24*60*60) - r.validFrom *= 1000000000 - r.files = []string{ - r.ql.logFile, - r.ql.logFile + ".1", - } - return &r -} - -// Close - close the reader -func (r *Reader) Close() { - elapsed := time.Since(r.now) - var perunit time.Duration - if r.count > 0 { - perunit = elapsed / time.Duration(r.count) - } - log.Debug("querylog: read %d entries in %v, %v/entry, seek-reqs:%d", - r.count, elapsed, perunit, r.nSeekRequests) - - if r.f != nil { - r.f.Close() - } -} - -// BeginRead - start reading -// olderThan: stop returning entries when an entry with this time is reached -// count: minimum number of entries to return -func (r *Reader) BeginRead(olderThan time.Time, count uint64, search *getDataParams) { - r.olderThan = olderThan.UnixNano() - r.latest = olderThan.IsZero() - r.oldest = time.Time{} - r.search = search - r.limit = count - if r.latest { - r.olderThan = r.now.UnixNano() - } - r.filePrepared = false - r.seeking = false -} - -// BeginReadPrev - start reading the previous data chunk -func (r *Reader) BeginReadPrev(count uint64) { - r.olderThan = r.oldest.UnixNano() - r.oldest = time.Time{} - r.latest = false - r.limit = count - r.count = 0 - - off := r.fpos - maxEntrySize*(r.limit+1) - if int64(off) < maxEntrySize { - off = 0 - } - r.fpos = off - log.Debug("QueryLog: seek: %x", off) - _, err := r.f.Seek(int64(off), io.SeekStart) - if err != nil { - log.Error("file.Seek: %s: %s", r.files[r.ifile], err) - return - } - r.nSeekRequests++ - - r.seekToNewLine() - r.fseeker.pos = r.fpos - - r.filePrepared = true - r.seeking = false -} - -// Perform binary seek -// Return 0: success; 1: seek reqiured; -1: error -func (fs *fileSeeker) seekBinary(cur uint64) int32 { - log.Debug("QueryLog: seek: tgt=%x cur=%x, %x: [%x..%x]", fs.target, cur, fs.pos, fs.lo, fs.hi) - - off := uint64(0) - if fs.pos >= fs.lo && fs.pos < fs.hi { - if cur == fs.target { - return 0 - } else if cur < fs.target { - fs.lo = fs.pos + 1 - } else { - fs.hi = fs.pos - } - off = fs.lo + (fs.hi-fs.lo)/2 - } else { - // we didn't find another entry from the last file offset: now return the boundary beginning - off = fs.lo - } - - if off == fs.lastpos { - return -1 - } - - fs.lastpos = off - fs.pos = off - return 1 -} - -// Seek to a new line -func (r *Reader) seekToNewLine() bool { - r.reader = bufio.NewReader(r.f) - b, err := r.reader.ReadBytes('\n') - if err != nil { - r.reader = nil - log.Error("QueryLog: file.Read: %s: %s", r.files[r.ifile], err) - return false - } - - off := len(b) - r.fpos += uint64(off) - log.Debug("QueryLog: seek: %x (+%d)", r.fpos, off) - return true -} - -// Open a file -func (r *Reader) openFile() bool { - var err error - fn := r.files[r.ifile] - - r.f, err = os.Open(fn) - if err != nil { - if !os.IsNotExist(err) { - log.Error("QueryLog: Failed to open file \"%s\": %s", fn, err) - } - return false - } - return true -} - -// Seek to the needed position -func (r *Reader) prepareRead() bool { - fn := r.files[r.ifile] - - fi, err := r.f.Stat() - if err != nil { - log.Error("QueryLog: file.Stat: %s: %s", fn, err) - return false - } - fsize := uint64(fi.Size()) - - off := uint64(0) - if r.latest { - // read data from the end of file - off = fsize - maxEntrySize*(r.limit+1) - if int64(off) < maxEntrySize { - off = 0 - } - r.fpos = off - log.Debug("QueryLog: seek: %x", off) - _, err = r.f.Seek(int64(off), io.SeekStart) - if err != nil { - log.Error("QueryLog: file.Seek: %s: %s", fn, err) - return false - } - } else { - // start searching in file: we'll read the first chunk of data from the middle of file - r.seeking = true - r.fseeker = fileSeeker{} - r.fseeker.target = uint64(r.olderThan) - r.fseeker.hi = fsize - rc := r.fseeker.seekBinary(0) - r.fpos = r.fseeker.pos - if rc == 1 { - _, err = r.f.Seek(int64(r.fpos), io.SeekStart) - if err != nil { - log.Error("QueryLog: file.Seek: %s: %s", fn, err) - return false - } - } - } - r.nSeekRequests++ - - if !r.seekToNewLine() { - return false - } - r.fseeker.pos = r.fpos - return true -} - -// Get bool value from "key":bool -func readJSONBool(s, name string) (bool, bool) { - i := strings.Index(s, "\""+name+"\":") - if i == -1 { - return false, false - } - start := i + 1 + len(name) + 2 - b := false - if strings.HasPrefix(s[start:], "true") { - b = true - } else if !strings.HasPrefix(s[start:], "false") { - return false, false - } - return b, true -} - -// Get value from "key":"value" -func readJSONValue(s, name string) string { - i := strings.Index(s, "\""+name+"\":\"") - if i == -1 { - return "" - } - start := i + 1 + len(name) + 3 - i = strings.IndexByte(s[start:], '"') - if i == -1 { - return "" - } - end := start + i - return s[start:end] -} - -// nolint (gocyclo) -func (r *Reader) applySearch(str string) bool { - if r.search.ResponseStatus == responseStatusFiltered { - boolVal, ok := readJSONBool(str, "IsFiltered") - if !ok || !boolVal { - return false - } - } - - mq := dns.Msg{} - - if len(r.search.Domain) != 0 { - val := readJSONValue(str, "QH") - if len(val) == 0 { - // pre-v0.99.3 compatibility - val = readJSONValue(str, "Question") - if len(val) == 0 { - return false - } - bval, err := base64.StdEncoding.DecodeString(val) - if err != nil { - return false - } - err = mq.Unpack(bval) - if err != nil { - return false - } - val = strings.TrimSuffix(mq.Question[0].Name, ".") - } - if len(val) == 0 { - return false - } - - if (r.search.StrictMatchDomain && val != r.search.Domain) || - (!r.search.StrictMatchDomain && strings.Index(val, r.search.Domain) == -1) { - return false - } - } - - if len(r.search.QuestionType) != 0 { - val := readJSONValue(str, "QT") - if len(val) == 0 { - // pre-v0.99.3 compatibility - if len(mq.Question) == 0 { - val = readJSONValue(str, "Question") - if len(val) == 0 { - return false - } - bval, err := base64.StdEncoding.DecodeString(val) - if err != nil { - return false - } - err = mq.Unpack(bval) - if err != nil { - return false - } - } - ok := false - val, ok = dns.TypeToString[mq.Question[0].Qtype] - if !ok { - return false - } - } - if val != r.search.QuestionType { - return false - } - } - - if len(r.search.Client) != 0 { - val := readJSONValue(str, "IP") - if len(val) == 0 { - log.Debug("QueryLog: failed to decode") - return false - } - - if (r.search.StrictMatchClient && val != r.search.Client) || - (!r.search.StrictMatchClient && strings.Index(val, r.search.Client) == -1) { - return false - } - } - - return true -} - -const ( - jsonTErr = iota - jsonTObj - jsonTStr - jsonTNum - jsonTBool -) - -// Parse JSON key-value pair -// e.g.: "key":VALUE where VALUE is "string", true|false (boolean), or 123.456 (number) -// Note the limitations: -// . doesn't support whitespace -// . doesn't support "null" -// . doesn't validate boolean or number -// . no proper handling of {} braces -// . no handling of [] brackets -// Return (key, value, type) -func readJSON(ps *string) (string, string, int32) { - s := *ps - k := "" - v := "" - t := int32(jsonTErr) - - q1 := strings.IndexByte(s, '"') - if q1 == -1 { - return k, v, t - } - q2 := strings.IndexByte(s[q1+1:], '"') - if q2 == -1 { - return k, v, t - } - k = s[q1+1 : q1+1+q2] - s = s[q1+1+q2+1:] - - if len(s) < 2 || s[0] != ':' { - return k, v, t - } - - if s[1] == '"' { - q2 = strings.IndexByte(s[2:], '"') - if q2 == -1 { - return k, v, t - } - v = s[2 : 2+q2] - t = jsonTStr - s = s[2+q2+1:] - - } else if s[1] == '{' { - t = jsonTObj - s = s[1+1:] - - } else { - sep := strings.IndexAny(s[1:], ",}") - if sep == -1 { - return k, v, t - } - v = s[1 : 1+sep] - if s[1] == 't' || s[1] == 'f' { - t = jsonTBool - } else if s[1] == '.' || (s[1] >= '0' && s[1] <= '9') { - t = jsonTNum - } - s = s[1+sep+1:] - } - - *ps = s - return k, v, t -} - -// nolint (gocyclo) -func decode(ent *logEntry, str string) { - var b bool - var i int - var err error - for { - k, v, t := readJSON(&str) - if t == jsonTErr { - break - } - switch k { - case "IP": - if len(ent.IP) == 0 { - ent.IP = v - } - case "T": - ent.Time, err = time.Parse(time.RFC3339, v) - - case "QH": - ent.QHost = v - case "QT": - ent.QType = v - case "QC": - ent.QClass = v - - case "Answer": - ent.Answer, err = base64.StdEncoding.DecodeString(v) - case "OrigAnswer": - ent.OrigAnswer, err = base64.StdEncoding.DecodeString(v) - - case "IsFiltered": - b, err = strconv.ParseBool(v) - ent.Result.IsFiltered = b - case "Rule": - ent.Result.Rule = v - case "FilterID": - i, err = strconv.Atoi(v) - ent.Result.FilterID = int64(i) - case "Reason": - i, err = strconv.Atoi(v) - ent.Result.Reason = dnsfilter.Reason(i) - - case "Upstream": - ent.Upstream = v - case "Elapsed": - i, err = strconv.Atoi(v) - ent.Elapsed = time.Duration(i) - - // pre-v0.99.3 compatibility: - case "Question": - var qstr []byte - qstr, err = base64.StdEncoding.DecodeString(v) - if err != nil { - break - } - q := new(dns.Msg) - err = q.Unpack(qstr) - if err != nil { - break - } - ent.QHost = q.Question[0].Name - if len(ent.QHost) == 0 { - break - } - ent.QHost = ent.QHost[:len(ent.QHost)-1] - ent.QType = dns.TypeToString[q.Question[0].Qtype] - ent.QClass = dns.ClassToString[q.Question[0].Qclass] - case "Time": - ent.Time, err = time.Parse(time.RFC3339, v) - } - - if err != nil { - log.Debug("decode err: %s", err) - break - } - } -} - -// Next - return the next entry or nil if reading is finished -func (r *Reader) Next() *logEntry { // nolint - for { - // open file if needed - if r.f == nil { - if r.ifile == len(r.files) { - return nil - } - if !r.openFile() { - r.ifile++ - continue - } - } - - if !r.filePrepared { - if !r.prepareRead() { - return nil - } - r.filePrepared = true - } - - b, err := r.reader.ReadBytes('\n') - if err != nil { - return nil - } - str := string(b) - - val := readJSONValue(str, "T") - if len(val) == 0 { - val = readJSONValue(str, "Time") - } - if len(val) == 0 { - log.Debug("QueryLog: failed to decode") - continue - } - tm, err := time.Parse(time.RFC3339, val) - if err != nil { - log.Debug("QueryLog: failed to decode") - continue - } - t := tm.UnixNano() - - if r.seeking { - - r.reader = nil - rr := r.fseeker.seekBinary(uint64(t)) - r.fpos = r.fseeker.pos - if rr < 0 { - log.Error("QueryLog: File seek error: can't find the target entry: %s", r.files[r.ifile]) - return nil - } else if rr == 0 { - // We found the target entry. - // We'll start reading the previous chunk of data. - r.seeking = false - - off := r.fpos - (maxEntrySize * (r.limit + 1)) - if int64(off) < maxEntrySize { - off = 0 - } - r.fpos = off - } - - _, err = r.f.Seek(int64(r.fpos), io.SeekStart) - if err != nil { - log.Error("QueryLog: file.Seek: %s: %s", r.files[r.ifile], err) - return nil - } - r.nSeekRequests++ - - if !r.seekToNewLine() { - return nil - } - r.fseeker.pos = r.fpos - continue - } - - if r.oldest.IsZero() { - r.oldest = tm - } - - if t < r.validFrom { - continue - } - if t >= r.olderThan { - return nil - } - r.count++ - - if !r.applySearch(str) { - continue - } - - var ent logEntry - decode(&ent, str) - return &ent - } -} - -// Total returns the total number of processed items -func (r *Reader) Total() uint64 { - return r.count -} - -// Oldest returns the time of the oldest processed entry -func (r *Reader) Oldest() time.Time { - return r.oldest -} diff --git a/querylog/querylog_search.go b/querylog/querylog_search.go new file mode 100644 index 00000000..236c1940 --- /dev/null +++ b/querylog/querylog_search.go @@ -0,0 +1,366 @@ +package querylog + +import ( + "encoding/base64" + "io" + "strconv" + "strings" + "time" + + "github.com/AdguardTeam/AdGuardHome/dnsfilter" + "github.com/AdguardTeam/AdGuardHome/util" + "github.com/AdguardTeam/golibs/log" + "github.com/miekg/dns" +) + +// searchFiles reads log entries from all log files and applies the specified search criteria. +// IMPORTANT: this method does not scan more than "maxSearchEntries" so you +// may need to call it many times. +// +// it returns: +// * an array of log entries that we have read +// * time of the oldest processed entry (even if it was discarded) +// * total number of processed entries (including discarded). +func (l *queryLog) searchFiles(params getDataParams) ([]*logEntry, time.Time, int) { + entries := make([]*logEntry, 0) + oldest := time.Time{} + + r, err := l.openReader() + if err != nil { + log.Error("Failed to open qlog reader: %v", err) + return entries, oldest, 0 + } + defer r.Close() + + if params.OlderThan.IsZero() { + err = r.SeekStart() + } else { + err = r.Seek(params.OlderThan.UnixNano()) + } + + if err != nil { + log.Error("Failed to Seek(): %v", err) + return entries, oldest, 0 + } + + total := 0 + oldestNano := int64(0) + // Do not scan more than 50k at once + for total <= maxSearchEntries { + entry, ts, err := l.readNextEntry(r, params) + + if err == io.EOF { + // there's nothing to read anymore + break + } + + if entry != nil { + entries = append(entries, entry) + } + + oldestNano = ts + total++ + } + + oldest = time.Unix(0, oldestNano) + return entries, oldest, total +} + +// readNextEntry - reads the next log entry and checks if it matches the search criteria (getDataParams) +// +// returns: +// * log entry that matches search criteria or null if it was discarded (or if there's nothing to read) +// * timestamp of the processed log entry +// * error if we can't read anymore +func (l *queryLog) readNextEntry(r *QLogReader, params getDataParams) (*logEntry, int64, error) { + line, err := r.ReadNext() + if err != nil { + return nil, 0, err + } + + // Read the log record timestamp right away + timestamp := readQLogTimestamp(line) + + // Quick check without deserializing log entry + if !quickMatchesGetDataParams(line, params) { + return nil, timestamp, nil + } + + entry := logEntry{} + decodeLogEntry(&entry, line) + + // Full check of the deserialized log entry + if !matchesGetDataParams(&entry, params) { + return nil, timestamp, nil + } + + return &entry, timestamp, nil +} + +// openReader - opens QLogReader instance +func (l *queryLog) openReader() (*QLogReader, error) { + files := make([]string, 0) + + if util.FileExists(l.logFile + ".1") { + files = append(files, l.logFile+".1") + } + if util.FileExists(l.logFile) { + files = append(files, l.logFile) + } + + return NewQLogReader(files) +} + +// quickMatchesGetDataParams - quickly checks if the line matches getDataParams +// this method does not guarantee anything and the reason is to do a quick check +// without deserializing anything +func quickMatchesGetDataParams(line string, params getDataParams) bool { + if params.ResponseStatus == responseStatusFiltered { + boolVal, ok := readJSONBool(line, "IsFiltered") + if !ok || !boolVal { + return false + } + } + + if len(params.Domain) != 0 { + val := readJSONValue(line, "QH") + if len(val) == 0 { + return false + } + + if (params.StrictMatchDomain && val != params.Domain) || + (!params.StrictMatchDomain && strings.Index(val, params.Domain) == -1) { + return false + } + } + + if len(params.QuestionType) != 0 { + val := readJSONValue(line, "QT") + if val != params.QuestionType { + return false + } + } + + if len(params.Client) != 0 { + val := readJSONValue(line, "IP") + if len(val) == 0 { + log.Debug("QueryLog: failed to decodeLogEntry") + return false + } + + if (params.StrictMatchClient && val != params.Client) || + (!params.StrictMatchClient && strings.Index(val, params.Client) == -1) { + return false + } + } + + return true +} + +// matchesGetDataParams - returns true if the entry matches the search parameters +func matchesGetDataParams(entry *logEntry, params getDataParams) bool { + if params.ResponseStatus == responseStatusFiltered && !entry.Result.IsFiltered { + return false + } + + if len(params.QuestionType) != 0 { + if entry.QType != params.QuestionType { + return false + } + } + + if len(params.Domain) != 0 { + if (params.StrictMatchDomain && entry.QHost != params.Domain) || + (!params.StrictMatchDomain && strings.Index(entry.QHost, params.Domain) == -1) { + return false + } + } + + if len(params.Client) != 0 { + if (params.StrictMatchClient && entry.IP != params.Client) || + (!params.StrictMatchClient && strings.Index(entry.IP, params.Client) == -1) { + return false + } + } + + return true +} + +// decodeLogEntry - decodes query log entry from a line +// nolint (gocyclo) +func decodeLogEntry(ent *logEntry, str string) { + var b bool + var i int + var err error + for { + k, v, t := readJSON(&str) + if t == jsonTErr { + break + } + switch k { + case "IP": + if len(ent.IP) == 0 { + ent.IP = v + } + case "T": + ent.Time, err = time.Parse(time.RFC3339, v) + + case "QH": + ent.QHost = v + case "QT": + ent.QType = v + case "QC": + ent.QClass = v + + case "Answer": + ent.Answer, err = base64.StdEncoding.DecodeString(v) + case "OrigAnswer": + ent.OrigAnswer, err = base64.StdEncoding.DecodeString(v) + + case "IsFiltered": + b, err = strconv.ParseBool(v) + ent.Result.IsFiltered = b + case "Rule": + ent.Result.Rule = v + case "FilterID": + i, err = strconv.Atoi(v) + ent.Result.FilterID = int64(i) + case "Reason": + i, err = strconv.Atoi(v) + ent.Result.Reason = dnsfilter.Reason(i) + + case "Upstream": + ent.Upstream = v + case "Elapsed": + i, err = strconv.Atoi(v) + ent.Elapsed = time.Duration(i) + + // pre-v0.99.3 compatibility: + case "Question": + var qstr []byte + qstr, err = base64.StdEncoding.DecodeString(v) + if err != nil { + break + } + q := new(dns.Msg) + err = q.Unpack(qstr) + if err != nil { + break + } + ent.QHost = q.Question[0].Name + if len(ent.QHost) == 0 { + break + } + ent.QHost = ent.QHost[:len(ent.QHost)-1] + ent.QType = dns.TypeToString[q.Question[0].Qtype] + ent.QClass = dns.ClassToString[q.Question[0].Qclass] + case "Time": + ent.Time, err = time.Parse(time.RFC3339, v) + } + + if err != nil { + log.Debug("decodeLogEntry err: %s", err) + break + } + } +} + +// Get bool value from "key":bool +func readJSONBool(s, name string) (bool, bool) { + i := strings.Index(s, "\""+name+"\":") + if i == -1 { + return false, false + } + start := i + 1 + len(name) + 2 + b := false + if strings.HasPrefix(s[start:], "true") { + b = true + } else if !strings.HasPrefix(s[start:], "false") { + return false, false + } + return b, true +} + +// Get value from "key":"value" +func readJSONValue(s, name string) string { + i := strings.Index(s, "\""+name+"\":\"") + if i == -1 { + return "" + } + start := i + 1 + len(name) + 3 + i = strings.IndexByte(s[start:], '"') + if i == -1 { + return "" + } + end := start + i + return s[start:end] +} + +const ( + jsonTErr = iota + jsonTObj + jsonTStr + jsonTNum + jsonTBool +) + +// Parse JSON key-value pair +// e.g.: "key":VALUE where VALUE is "string", true|false (boolean), or 123.456 (number) +// Note the limitations: +// . doesn't support whitespace +// . doesn't support "null" +// . doesn't validate boolean or number +// . no proper handling of {} braces +// . no handling of [] brackets +// Return (key, value, type) +func readJSON(ps *string) (string, string, int32) { + s := *ps + k := "" + v := "" + t := int32(jsonTErr) + + q1 := strings.IndexByte(s, '"') + if q1 == -1 { + return k, v, t + } + q2 := strings.IndexByte(s[q1+1:], '"') + if q2 == -1 { + return k, v, t + } + k = s[q1+1 : q1+1+q2] + s = s[q1+1+q2+1:] + + if len(s) < 2 || s[0] != ':' { + return k, v, t + } + + if s[1] == '"' { + q2 = strings.IndexByte(s[2:], '"') + if q2 == -1 { + return k, v, t + } + v = s[2 : 2+q2] + t = jsonTStr + s = s[2+q2+1:] + + } else if s[1] == '{' { + t = jsonTObj + s = s[1+1:] + + } else { + sep := strings.IndexAny(s[1:], ",}") + if sep == -1 { + return k, v, t + } + v = s[1 : 1+sep] + if s[1] == 't' || s[1] == 'f' { + t = jsonTBool + } else if s[1] == '.' || (s[1] >= '0' && s[1] <= '9') { + t = jsonTNum + } + s = s[1+sep+1:] + } + + *ps = s + return k, v, t +} diff --git a/querylog/querylog_test.go b/querylog/querylog_test.go index 8c8b9bb4..8d587b7c 100644 --- a/querylog/querylog_test.go +++ b/querylog/querylog_test.go @@ -30,14 +30,17 @@ func TestQueryLog(t *testing.T) { l := newQueryLog(conf) // add disk entries - addEntry(l, "example.org", "1.2.3.4", "0.1.2.3") - addEntry(l, "example.org", "1.2.3.4", "0.1.2.3") - + addEntry(l, "example.org", "1.1.1.1", "2.2.2.1") + // write to disk (first file) + _ = l.flushLogBuffer(true) + // start writing to the second file + _ = l.rotate() + // add disk entries + addEntry(l, "example.org", "1.1.1.2", "2.2.2.2") // write to disk - l.flushLogBuffer(true) - + _ = l.flushLogBuffer(true) // add memory entries - addEntry(l, "test.example.org", "2.2.3.4", "0.1.2.4") + addEntry(l, "test.example.org", "1.1.1.3", "2.2.2.3") // get all entries params := getDataParams{ @@ -45,9 +48,10 @@ func TestQueryLog(t *testing.T) { } d := l.getData(params) mdata := d["data"].([]map[string]interface{}) - assert.True(t, len(mdata) == 2) - assert.True(t, checkEntry(t, mdata[0], "test.example.org", "2.2.3.4", "0.1.2.4")) - assert.True(t, checkEntry(t, mdata[1], "example.org", "1.2.3.4", "0.1.2.3")) + assert.Equal(t, 3, len(mdata)) + assert.True(t, checkEntry(t, mdata[0], "test.example.org", "1.1.1.3", "2.2.2.3")) + assert.True(t, checkEntry(t, mdata[1], "example.org", "1.1.1.2", "2.2.2.2")) + assert.True(t, checkEntry(t, mdata[2], "example.org", "1.1.1.1", "2.2.2.1")) // search by domain (strict) params = getDataParams{ @@ -58,9 +62,9 @@ func TestQueryLog(t *testing.T) { d = l.getData(params) mdata = d["data"].([]map[string]interface{}) assert.True(t, len(mdata) == 1) - assert.True(t, checkEntry(t, mdata[0], "test.example.org", "2.2.3.4", "0.1.2.4")) + assert.True(t, checkEntry(t, mdata[0], "test.example.org", "1.1.1.3", "2.2.2.3")) - // search by domain + // search by domain (not strict) params = getDataParams{ OlderThan: time.Time{}, Domain: "example.org", @@ -68,32 +72,34 @@ func TestQueryLog(t *testing.T) { } d = l.getData(params) mdata = d["data"].([]map[string]interface{}) - assert.True(t, len(mdata) == 2) - assert.True(t, checkEntry(t, mdata[0], "test.example.org", "2.2.3.4", "0.1.2.4")) - assert.True(t, checkEntry(t, mdata[1], "example.org", "1.2.3.4", "0.1.2.3")) + assert.Equal(t, 3, len(mdata)) + assert.True(t, checkEntry(t, mdata[0], "test.example.org", "1.1.1.3", "2.2.2.3")) + assert.True(t, checkEntry(t, mdata[1], "example.org", "1.1.1.2", "2.2.2.2")) + assert.True(t, checkEntry(t, mdata[2], "example.org", "1.1.1.1", "2.2.2.1")) // search by client IP (strict) params = getDataParams{ OlderThan: time.Time{}, - Client: "0.1.2.3", + Client: "2.2.2.2", StrictMatchClient: true, } d = l.getData(params) mdata = d["data"].([]map[string]interface{}) - assert.True(t, len(mdata) == 1) - assert.True(t, checkEntry(t, mdata[0], "example.org", "1.2.3.4", "0.1.2.3")) + assert.Equal(t, 1, len(mdata)) + assert.True(t, checkEntry(t, mdata[0], "example.org", "1.1.1.2", "2.2.2.2")) - // search by client IP + // search by client IP (part of) params = getDataParams{ OlderThan: time.Time{}, - Client: "0.1.2", + Client: "2.2.2", StrictMatchClient: false, } d = l.getData(params) mdata = d["data"].([]map[string]interface{}) - assert.True(t, len(mdata) == 2) - assert.True(t, checkEntry(t, mdata[0], "test.example.org", "2.2.3.4", "0.1.2.4")) - assert.True(t, checkEntry(t, mdata[1], "example.org", "1.2.3.4", "0.1.2.3")) + assert.Equal(t, 3, len(mdata)) + assert.True(t, checkEntry(t, mdata[0], "test.example.org", "1.1.1.3", "2.2.2.3")) + assert.True(t, checkEntry(t, mdata[1], "example.org", "1.1.1.2", "2.2.2.2")) + assert.True(t, checkEntry(t, mdata[2], "example.org", "1.1.1.1", "2.2.2.1")) } func addEntry(l *queryLog, host, answerStr, client string) { @@ -129,11 +135,11 @@ func checkEntry(t *testing.T, m map[string]interface{}, host, answer, client str mq := m["question"].(map[string]interface{}) ma := m["answer"].([]map[string]interface{}) ma0 := ma[0] - if !assert.True(t, mq["host"].(string) == host) || - !assert.True(t, mq["class"].(string) == "IN") || - !assert.True(t, mq["type"].(string) == "A") || - !assert.True(t, ma0["value"].(string) == answer) || - !assert.True(t, m["client"].(string) == client) { + if !assert.Equal(t, host, mq["host"].(string)) || + !assert.Equal(t, "IN", mq["class"].(string)) || + !assert.Equal(t, "A", mq["type"].(string)) || + !assert.Equal(t, answer, ma0["value"].(string)) || + !assert.Equal(t, client, m["client"].(string)) { return false } return true