diff --git a/AGHTechDoc.md b/AGHTechDoc.md index b4b7bd61..4b2833d7 100644 --- a/AGHTechDoc.md +++ b/AGHTechDoc.md @@ -37,6 +37,11 @@ Contents: * Services Filter * API: Get blocked services list * API: Set blocked services list +* Statistics + * API: Get statistics data + * API: Clear statistics data + * API: Set statistics parameters + * API: Get statistics parameters ## First startup @@ -863,3 +868,111 @@ Request: Response: 200 OK + + +## Statistics + +Load (main thread): +. Load data from the last bucket from DB for the current hour + +Runtime (DNS worker threads): +. Update current unit + +Runtime (goroutine): +. Periodically check that current unit should be flushed to file (when the current hour changes) + . If so, flush it, allocate a new empty unit + +Runtime (HTTP worker threads): +. To respond to "Get statistics" API request we: + . load all units from file + . load current unit + . process data from all loaded units: + . sum up data for "total counters" output values + . add value into "per time unit counters" output arrays + . aggregate data for "top_" output arrays; sort in descending order + +Unload (main thread): +. Flush current unit to file + + +### API: Get statistics data + +Request: + + GET /control/stats + +Response: + + 200 OK + + { + time_units: hours | days + + // total counters: + num_dns_queries: 123 + num_blocked_filtering: 123 + num_replaced_safebrowsing: 123 + num_replaced_safesearch: 123 + num_replaced_parental: 123 + avg_processing_time: 123.123 + + // per time unit counters + dns_queries: [123, ...] + blocked_filtering: [123, ...] + replaced_parental: [123, ...] + replaced_safebrowsing: [123, ...] + + top_queried_domains: [ + {host: 123}, + ... + ] + top_blocked_domains: [ + {host: 123}, + ... + ] + top_clients: [ + {IP: 123}, + ... + ] + } + + +### API: Clear statistics data + +Request: + + POST /control/stats_reset + +Response: + + 200 OK + + +### API: Set statistics parameters + +Request: + + POST /control/stats_config + + { + "interval": 1 | 7 | 30 | 90 + } + +Response: + + 200 OK + + +### API: Get statistics parameters + +Request: + + GET /control/stats_info + +Response: + + 200 OK + + { + "interval": 1 | 7 | 30 | 90 + } diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index fe943b91..29146843 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -11,6 +11,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/dnsfilter" + "github.com/AdguardTeam/AdGuardHome/stats" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/log" @@ -40,6 +41,7 @@ type Server struct { dnsProxy *proxy.Proxy // DNS proxy instance dnsFilter *dnsfilter.Dnsfilter // DNS filter instance queryLog *queryLog // Query log instance + stats stats.Stats AllowedClients map[string]bool // IP addresses of whitelist clients DisallowedClients map[string]bool // IP addresses of clients that should be blocked @@ -54,10 +56,11 @@ type Server struct { // NewServer creates a new instance of the dnsforward.Server // baseDir is the base directory for query logs // Note: this function must be called only once -func NewServer(baseDir string) *Server { +func NewServer(baseDir string, stats stats.Stats) *Server { s := &Server{ queryLog: newQueryLog(baseDir), } + s.stats = stats log.Printf("Start DNS server periodic jobs") go s.queryLog.periodicQueryLogRotate() @@ -465,8 +468,8 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error { shouldLog = false } + elapsed := time.Since(start) if s.conf.QueryLogEnabled && shouldLog { - elapsed := time.Since(start) upstreamAddr := "" if d.Upstream != nil { upstreamAddr = d.Upstream.Address() @@ -474,9 +477,52 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error { _ = s.queryLog.logRequest(msg, d.Res, res, elapsed, d.Addr, upstreamAddr) } + s.updateStats(d, elapsed, *res) + return nil } +func (s *Server) updateStats(d *proxy.DNSContext, elapsed time.Duration, res dnsfilter.Result) { + if s.stats == nil { + return + } + + e := stats.Entry{} + e.Domain = strings.ToLower(d.Req.Question[0].Name) + e.Domain = e.Domain[:len(e.Domain)-1] // remove last "." + switch addr := d.Addr.(type) { + case *net.UDPAddr: + e.Client = addr.IP + case *net.TCPAddr: + e.Client = addr.IP + } + e.Time = uint(elapsed / 1000) + switch res.Reason { + + case dnsfilter.NotFilteredNotFound: + fallthrough + case dnsfilter.NotFilteredWhiteList: + fallthrough + case dnsfilter.NotFilteredError: + e.Result = stats.RNotFiltered + + case dnsfilter.FilteredSafeBrowsing: + e.Result = stats.RSafeBrowsing + case dnsfilter.FilteredParental: + e.Result = stats.RParental + case dnsfilter.FilteredSafeSearch: + e.Result = stats.RSafeSearch + + case dnsfilter.FilteredBlackList: + fallthrough + case dnsfilter.FilteredInvalid: + fallthrough + case dnsfilter.FilteredBlockedService: + e.Result = stats.RFiltered + } + s.stats.Update(e) +} + // filterDNSRequest applies the dnsFilter and sets d.Res if the request was filtered func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error) { var res dnsfilter.Result diff --git a/dnsforward/dnsforward_test.go b/dnsforward/dnsforward_test.go index eb69e3d0..740a43b8 100644 --- a/dnsforward/dnsforward_test.go +++ b/dnsforward/dnsforward_test.go @@ -422,7 +422,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) { } func createTestServer(t *testing.T) *Server { - s := NewServer(createDataDir(t)) + s := NewServer(createDataDir(t), nil) s.conf.UDPListenAddr = &net.UDPAddr{Port: 0} s.conf.TCPListenAddr = &net.TCPAddr{Port: 0} diff --git a/go.mod b/go.mod index 783486bd..8949d9b8 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,9 @@ require ( github.com/AdguardTeam/golibs v0.2.1 github.com/AdguardTeam/urlfilter v0.5.0 github.com/NYTimes/gziphandler v1.1.1 + github.com/asaskevich/govalidator v0.0.0-20180720115003-f9ffefc3facf github.com/bluele/gcache v0.0.0-20190518031135-bc40bd653833 + github.com/etcd-io/bbolt v1.3.3 github.com/go-test/deep v1.0.1 github.com/gobuffalo/packr v1.19.0 github.com/joomcode/errorx v0.8.0 @@ -17,8 +19,8 @@ require ( github.com/miekg/dns v1.1.8 github.com/sparrc/go-ping v0.0.0-20181106165434-ef3ab45e41b0 github.com/stretchr/testify v1.4.0 + go.etcd.io/bbolt v1.3.3 // indirect golang.org/x/net v0.0.0-20190620200207-3b0461eec859 golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0 - gopkg.in/asaskevich/govalidator.v4 v4.0.0-20160518190739-766470278477 gopkg.in/yaml.v2 v2.2.2 ) diff --git a/go.sum b/go.sum index 433dac2d..c26688a1 100644 --- a/go.sum +++ b/go.sum @@ -28,6 +28,7 @@ github.com/bluele/gcache v0.0.0-20190518031135-bc40bd653833/go.mod h1:8c4/i2Vlov github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/etcd-io/bbolt v1.3.3/go.mod h1:ZF2nL25h33cCyBtcyWeZ2/I3HQOfTP+0PIEvHjkjCrw= github.com/go-ole/go-ole v1.2.4 h1:nNBDSCOigTSiarFpYE9J/KtEA1IOW4CNeqT9TQDqCxI= github.com/go-ole/go-ole v1.2.4/go.mod h1:XCwSNxSkXRo4vlyPy93sltvi/qJq0jqQhjqQNIwKuxM= github.com/go-test/deep v1.0.1 h1:UQhStjbkDClarlmv0am7OXXO4/GaPdCGiUiMTvi28sg= @@ -80,6 +81,7 @@ github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0 github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9 h1:mKdxBk7AujPs8kU4m80U72y/zjbZ3UcXC7dClwKbUI0= golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M= @@ -112,8 +114,6 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -gopkg.in/asaskevich/govalidator.v4 v4.0.0-20160518190739-766470278477 h1:5xUJw+lg4zao9W4HIDzlFbMYgSgtvNVHh00MEHvbGpQ= -gopkg.in/asaskevich/govalidator.v4 v4.0.0-20160518190739-766470278477/go.mod h1:QDV1vrFSrowdoOba0UM8VJPUZONT7dnfdLsM+GG53Z8= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/home/config.go b/home/config.go index 600cad41..b3a62efe 100644 --- a/home/config.go +++ b/home/config.go @@ -12,6 +12,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/dhcpd" "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/dnsforward" + "github.com/AdguardTeam/AdGuardHome/stats" "github.com/AdguardTeam/golibs/file" "github.com/AdguardTeam/golibs/log" yaml "gopkg.in/yaml.v2" @@ -68,6 +69,7 @@ type configuration struct { controlLock sync.Mutex transport *http.Transport client *http.Client + stats stats.Stats // cached version.json to avoid hammering github.io for each page reload versionCheckJSON []byte diff --git a/home/control_stats.go b/home/control_stats.go index 0dce5510..710eb577 100644 --- a/home/control_stats.go +++ b/home/control_stats.go @@ -4,6 +4,7 @@ import ( "encoding/json" "net/http" + "github.com/AdguardTeam/AdGuardHome/stats" "github.com/AdguardTeam/golibs/log" ) @@ -11,9 +12,8 @@ type statsConfig struct { Interval uint `json:"interval"` } +// Get stats configuration func handleStatsInfo(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) - resp := statsConfig{} resp.Interval = config.DNS.StatsInterval @@ -29,9 +29,8 @@ func handleStatsInfo(w http.ResponseWriter, r *http.Request) { } } +// Set stats configuration func handleStatsConfig(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) - reqData := statsConfig{} err := json.NewDecoder(r.Body).Decode(&reqData) if err != nil { @@ -45,8 +44,47 @@ func handleStatsConfig(w http.ResponseWriter, r *http.Request) { } config.DNS.StatsInterval = reqData.Interval + config.stats.Configure(int(config.DNS.StatsInterval)) - httpUpdateConfigReloadDNSReturnOK(w, r) + returnOK(w) +} + +// handleStats returns aggregated stats data +func handleStats(w http.ResponseWriter, r *http.Request) { + units := stats.Hours + if config.DNS.StatsInterval > 7 { + units = stats.Days + } + counter := log.StartTimer() + d := config.stats.GetData(units) + counter.LogElapsed("Stats: prepared data") + + if d == nil { + httpError(w, http.StatusInternalServerError, "Couldn't get statistics data") + return + } + + data, err := json.Marshal(d) + if err != nil { + httpError(w, http.StatusInternalServerError, "json encode: %s", err) + return + } + + w.Write(data) +} + +// handleStatsReset resets the stats +func handleStatsReset(w http.ResponseWriter, r *http.Request) { + config.stats.Clear() + returnOK(w) +} + +// RegisterStatsHandlers - register handlers +func RegisterStatsHandlers() { + httpRegister(http.MethodGet, "/control/stats", handleStats) + httpRegister(http.MethodPost, "/control/stats_reset", handleStatsReset) + httpRegister(http.MethodPost, "/control/stats_config", handleStatsConfig) + httpRegister(http.MethodGet, "/control/stats_info", handleStatsInfo) } func checkStatsInterval(i uint) bool { diff --git a/home/dns.go b/home/dns.go index cace3288..c00ddf51 100644 --- a/home/dns.go +++ b/home/dns.go @@ -8,6 +8,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/dnsforward" + "github.com/AdguardTeam/AdGuardHome/stats" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/log" @@ -33,7 +34,11 @@ func initDNSServer(baseDir string) { log.Fatalf("Cannot create DNS data dir at %s: %s", baseDir, err) } - config.dnsServer = dnsforward.NewServer(baseDir) + config.stats = stats.New("./data/stats.db", int(config.DNS.StatsInterval), nil) + if config.stats == nil { + log.Fatal("config.stats == nil") + } + config.dnsServer = dnsforward.NewServer(baseDir, config.stats) initRDNS() } @@ -178,5 +183,7 @@ func stopDNSServer() error { return errorx.Decorate(err, "Couldn't stop forwarding DNS server") } + config.stats.Close() + return nil } diff --git a/stats/stats.go b/stats/stats.go new file mode 100644 index 00000000..2542b16d --- /dev/null +++ b/stats/stats.go @@ -0,0 +1,68 @@ +// Module for managing statistics for DNS filtering server + +package stats + +import ( + "net" +) + +type unitIDCallback func() int + +// New - create object +// filename: DB file name +// limit: time limit (in days) +// unitID: user function to get the current unit ID. If nil, the current time hour is used. +func New(filename string, limit int, unitID unitIDCallback) Stats { + return createObject(filename, limit, unitID) +} + +// Stats - main interface +type Stats interface { + // Close object. + // This function is not thread safe + // (can't be called in parallel with any other function of this interface). + Close() + + // Set new configuration at runtime. + // limit: time limit (in days) + Configure(limit int) + + // Reset counters and clear database + Clear() + + // Update counters + Update(e Entry) + + // Get data + GetData(timeUnit TimeUnit) map[string]interface{} +} + +// TimeUnit - time unit +type TimeUnit int + +// Supported time units +const ( + Hours TimeUnit = iota + Days +) + +// Result of DNS request processing +type Result int + +// Supported result values +const ( + RNotFiltered Result = iota + 1 + RFiltered + RSafeBrowsing + RSafeSearch + RParental + rLast +) + +// Entry - data to add +type Entry struct { + Domain string + Client net.IP + Result Result + Time uint // processing time (msec) +} diff --git a/stats/stats_test.go b/stats/stats_test.go new file mode 100644 index 00000000..45b06520 --- /dev/null +++ b/stats/stats_test.go @@ -0,0 +1,115 @@ +package stats + +import ( + "fmt" + "net" + "os" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func UIntArrayEquals(a []uint, b []uint) bool { + if len(a) != len(b) { + return false + } + + for i := range a { + if a[i] != b[i] { + return false + } + } + + return true +} + +func TestStats(t *testing.T) { + s := New("./stats.db", 1, nil) + + e := Entry{} + + e.Domain = "domain" + e.Client = net.ParseIP("127.0.0.1") + e.Result = RFiltered + e.Time = 123456 + s.Update(e) + + e.Domain = "domain" + e.Client = net.ParseIP("127.0.0.1") + e.Result = RNotFiltered + e.Time = 123456 + s.Update(e) + + d := s.GetData(Hours) + a := []uint{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2} + assert.True(t, UIntArrayEquals(d["dns_queries"].([]uint), a)) + + a = []uint{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} + assert.True(t, UIntArrayEquals(d["blocked_filtering"].([]uint), a)) + + a = []uint{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} + assert.True(t, UIntArrayEquals(d["replaced_safebrowsing"].([]uint), a)) + + a = []uint{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} + assert.True(t, UIntArrayEquals(d["replaced_parental"].([]uint), a)) + + m := d["top_queried_domains"].([]map[string]uint) + assert.True(t, m[0]["domain"] == 1) + + m = d["top_blocked_domains"].([]map[string]uint) + assert.True(t, m[0]["domain"] == 1) + + m = d["top_clients"].([]map[string]uint) + assert.True(t, m[0]["127.0.0.1"] == 2) + + assert.True(t, d["num_dns_queries"].(uint) == 2) + assert.True(t, d["num_blocked_filtering"].(uint) == 1) + assert.True(t, d["num_replaced_safebrowsing"].(uint) == 0) + assert.True(t, d["num_replaced_safesearch"].(uint) == 0) + assert.True(t, d["num_replaced_parental"].(uint) == 0) + assert.True(t, d["avg_processing_time"].(float64) == 0.123456) + + s.Clear() + s.Close() + os.Remove("./stats.db") +} + +func TestLargeNumbers(t *testing.T) { + var hour int32 + hour = 1 + newID := func() int { + // use "atomic" to make Go race detector happy + return int(atomic.LoadInt32(&hour)) + } + + // log.SetLevel(log.DEBUG) + fn := "./stats.db" + os.Remove(fn) + s := New(fn, 1, newID) + e := Entry{} + + n := 1000 // number of distinct clients and domains every hour + for h := 0; h != 12; h++ { + if h != 0 { + atomic.AddInt32(&hour, 1) + time.Sleep(1500 * time.Millisecond) + } + for i := 0; i != n; i++ { + e.Domain = fmt.Sprintf("domain%d", i) + e.Client = net.ParseIP("127.0.0.1") + e.Client[2] = byte((i & 0xff00) >> 8) + e.Client[3] = byte(i & 0xff) + e.Result = RNotFiltered + e.Time = 123456 + s.Update(e) + } + } + + d := s.GetData(Hours) + assert.True(t, d["num_dns_queries"].(uint) == uint(int(hour)*n)) + + s.Close() + os.Remove(fn) +} diff --git a/stats/stats_unit.go b/stats/stats_unit.go new file mode 100644 index 00000000..1cc69324 --- /dev/null +++ b/stats/stats_unit.go @@ -0,0 +1,672 @@ +package stats + +import ( + "bytes" + "encoding/binary" + "encoding/gob" + "fmt" + "os" + "sort" + "sync" + "time" + + "github.com/AdguardTeam/golibs/log" + bolt "github.com/etcd-io/bbolt" +) + +const ( + maxDomains = 100 // max number of top domains to store in file or return via Get() + maxClients = 100 // max number of top clients to store in file or return via Get() +) + +// statsCtx - global context +type statsCtx struct { + limit int // maximum time we need to keep data for (in hours) + filename string // database file name + unitID unitIDCallback // user function which returns the current unit ID + db *bolt.DB + + unit *unit // the current unit + unitLock sync.Mutex // protect 'unit' +} + +// data for 1 time unit +type unit struct { + id int // unit ID. Default: absolute hour since Jan 1, 1970 + + nTotal int // total requests + nResult []int // number of requests per one result + timeSum int // sum of processing time of all requests (usec) + + // top: + domains map[string]int // number of requests per domain + blockedDomains map[string]int // number of blocked requests per domain + clients map[string]int // number of requests per client +} + +// name-count pair +type countPair struct { + Name string + Count uint +} + +// structure for storing data in file +type unitDB struct { + NTotal uint + NResult []uint + + Domains []countPair + BlockedDomains []countPair + Clients []countPair + + TimeAvg uint // usec +} + +func createObject(filename string, limitDays int, unitID unitIDCallback) *statsCtx { + s := statsCtx{} + s.limit = limitDays * 24 + s.filename = filename + s.unitID = newUnitID + if unitID != nil { + s.unitID = unitID + } + + if !s.dbOpen() { + return nil + } + + id := s.unitID() + tx := s.beginTxn(true) + var udb *unitDB + if tx != nil { + log.Tracef("Deleting old units...") + firstID := id - s.limit - 1 + unitDel := 0 + forEachBkt := func(name []byte, b *bolt.Bucket) error { + id := btoi(name) + if id < firstID { + err := tx.DeleteBucket(name) + if err != nil { + log.Debug("tx.DeleteBucket: %s", err) + } + log.Debug("Stats: deleted unit %d", id) + unitDel++ + return nil + } + return fmt.Errorf("") + } + _ = tx.ForEach(forEachBkt) + + udb = s.loadUnitFromDB(tx, id) + + if unitDel != 0 { + s.commitTxn(tx) + } else { + _ = tx.Rollback() + } + } + + u := unit{} + s.initUnit(&u, id) + if udb != nil { + deserialize(&u, udb) + } + s.unit = &u + + go s.periodicFlush() + + log.Debug("Stats: initialized") + return &s +} + +func (s *statsCtx) dbOpen() bool { + var err error + log.Tracef("db.Open...") + s.db, err = bolt.Open(s.filename, 0644, nil) + if err != nil { + log.Error("Stats: open DB: %s: %s", s.filename, err) + return false + } + log.Tracef("db.Open") + return true +} + +// Atomically swap the currently active unit with a new value +// Return old value +func (s *statsCtx) swapUnit(new *unit) *unit { + s.unitLock.Lock() + u := s.unit + s.unit = new + s.unitLock.Unlock() + return u +} + +// Get unit ID for the current hour +func newUnitID() int { + return int(time.Now().Unix() / (60 * 60)) +} + +// Initialize a unit +func (s *statsCtx) initUnit(u *unit, id int) { + u.id = id + u.nResult = make([]int, rLast) + u.domains = make(map[string]int) + u.blockedDomains = make(map[string]int) + u.clients = make(map[string]int) +} + +// Open a DB transaction +func (s *statsCtx) beginTxn(wr bool) *bolt.Tx { + db := s.db + if db == nil { + return nil + } + + log.Tracef("db.Begin...") + tx, err := db.Begin(wr) + if err != nil { + log.Error("db.Begin: %s", err) + return nil + } + log.Tracef("db.Begin") + return tx +} + +func (s *statsCtx) commitTxn(tx *bolt.Tx) { + err := tx.Commit() + if err != nil { + log.Debug("tx.Commit: %s", err) + return + } + log.Tracef("tx.Commit") +} + +// Get unit name +func unitName(id int) []byte { + return itob(id) +} + +// Convert integer to 8-byte array (big endian) +func itob(v int) []byte { + b := make([]byte, 8) + binary.BigEndian.PutUint64(b, uint64(v)) + return b +} + +// Convert 8-byte array (big endian) to integer +func btoi(b []byte) int { + return int(binary.BigEndian.Uint64(b)) +} + +// Flush the current unit to DB and delete an old unit when a new hour is started +func (s *statsCtx) periodicFlush() { + for { + s.unitLock.Lock() + ptr := s.unit + s.unitLock.Unlock() + if ptr == nil { + break + } + + id := s.unitID() + if ptr.id == id { + time.Sleep(time.Second) + continue + } + + nu := unit{} + s.initUnit(&nu, id) + u := s.swapUnit(&nu) + udb := serialize(u) + + tx := s.beginTxn(true) + if tx == nil { + continue + } + ok1 := s.flushUnitToDB(tx, u.id, udb) + ok2 := s.deleteUnit(tx, id-s.limit) + if ok1 || ok2 { + s.commitTxn(tx) + } else { + _ = tx.Rollback() + } + } + log.Tracef("periodicFlush() exited") +} + +// Delete unit's data from file +func (s *statsCtx) deleteUnit(tx *bolt.Tx, id int) bool { + err := tx.DeleteBucket(unitName(id)) + if err != nil { + log.Tracef("bolt DeleteBucket: %s", err) + return false + } + log.Debug("Stats: deleted unit %d", id) + return true +} + +func convertMapToArray(m map[string]int, max int) []countPair { + a := []countPair{} + for k, v := range m { + pair := countPair{} + pair.Name = k + pair.Count = uint(v) + a = append(a, pair) + } + less := func(i, j int) bool { + if a[i].Count >= a[j].Count { + return true + } + return false + } + sort.Slice(a, less) + if max > len(a) { + max = len(a) + } + return a[:max] +} + +func convertArrayToMap(a []countPair) map[string]int { + m := map[string]int{} + for _, it := range a { + m[it.Name] = int(it.Count) + } + return m +} + +func serialize(u *unit) *unitDB { + udb := unitDB{} + udb.NTotal = uint(u.nTotal) + for _, it := range u.nResult { + udb.NResult = append(udb.NResult, uint(it)) + } + if u.nTotal != 0 { + udb.TimeAvg = uint(u.timeSum / u.nTotal) + } + udb.Domains = convertMapToArray(u.domains, maxDomains) + udb.BlockedDomains = convertMapToArray(u.blockedDomains, maxDomains) + udb.Clients = convertMapToArray(u.clients, maxClients) + return &udb +} + +func deserialize(u *unit, udb *unitDB) { + u.nTotal = int(udb.NTotal) + for _, it := range udb.NResult { + u.nResult = append(u.nResult, int(it)) + } + u.domains = convertArrayToMap(udb.Domains) + u.blockedDomains = convertArrayToMap(udb.BlockedDomains) + u.clients = convertArrayToMap(udb.Clients) + u.timeSum = int(udb.TimeAvg) * u.nTotal +} + +func (s *statsCtx) flushUnitToDB(tx *bolt.Tx, id int, udb *unitDB) bool { + log.Tracef("Flushing unit %d", id) + + bkt, err := tx.CreateBucketIfNotExists(unitName(id)) + if err != nil { + log.Error("tx.CreateBucketIfNotExists: %s", err) + return false + } + + var buf bytes.Buffer + enc := gob.NewEncoder(&buf) + err = enc.Encode(udb) + if err != nil { + log.Error("gob.Encode: %s", err) + return false + } + + err = bkt.Put([]byte{0}, buf.Bytes()) + if err != nil { + log.Error("bkt.Put: %s", err) + return false + } + + return true +} + +func (s *statsCtx) loadUnitFromDB(tx *bolt.Tx, id int) *unitDB { + bkt := tx.Bucket(unitName(id)) + if bkt == nil { + return nil + } + + log.Tracef("Loading unit %d", id) + + var buf bytes.Buffer + buf.Write(bkt.Get([]byte{0})) + dec := gob.NewDecoder(&buf) + udb := unitDB{} + err := dec.Decode(&udb) + if err != nil { + log.Error("gob Decode: %s", err) + return nil + } + + return &udb +} + +func convertTopArray(a []countPair) []map[string]uint { + m := []map[string]uint{} + for _, it := range a { + ent := map[string]uint{} + ent[it.Name] = it.Count + m = append(m, ent) + } + return m +} + +func (s *statsCtx) Configure(limit int) { + if limit < 0 { + return + } + s.limit = limit * 24 + log.Debug("Stats: set limit: %d", limit) +} + +func (s *statsCtx) Close() { + u := s.swapUnit(nil) + udb := serialize(u) + tx := s.beginTxn(true) + if tx != nil { + if s.flushUnitToDB(tx, u.id, udb) { + s.commitTxn(tx) + } else { + _ = tx.Rollback() + } + } + + if s.db != nil { + log.Tracef("db.Close...") + _ = s.db.Close() + log.Tracef("db.Close") + } + + log.Debug("Stats: closed") +} + +func (s *statsCtx) Clear() { + tx := s.beginTxn(true) + if tx != nil { + db := s.db + s.db = nil + _ = tx.Rollback() + // the active transactions can continue using database, + // but no new transactions will be opened + _ = db.Close() + log.Tracef("db.Close") + // all active transactions are now closed + } + + u := unit{} + s.initUnit(&u, s.unitID()) + _ = s.swapUnit(&u) + + err := os.Remove(s.filename) + if err != nil { + log.Error("os.Remove: %s", err) + } + + _ = s.dbOpen() + + log.Debug("Stats: cleared") +} + +func (s *statsCtx) Update(e Entry) { + if e.Result == 0 || + len(e.Domain) == 0 || + !(len(e.Client) == 4 || len(e.Client) == 16) { + return + } + client := e.Client.String() + + s.unitLock.Lock() + u := s.unit + + u.nResult[e.Result]++ + + if e.Result == RNotFiltered { + u.domains[e.Domain]++ + } else { + u.blockedDomains[e.Domain]++ + } + + u.clients[client]++ + u.timeSum += int(e.Time) + u.nTotal++ + s.unitLock.Unlock() +} + +/* Algorithm: +. Prepare array of N units, where N is the value of "limit" configuration setting + . Load data for the most recent units from file + If a unit with required ID doesn't exist, just add an empty unit + . Get data for the current unit +. Process data from the units and prepare an output map object: + * per time unit counters: + * DNS-queries/time-unit + * blocked/time-unit + * safebrowsing-blocked/time-unit + * parental-blocked/time-unit + If time-unit is an hour, just add values from each unit to an array. + If time-unit is a day, aggregate per-hour data into days. + * top counters: + * queries/domain + * queries/blocked-domain + * queries/client + To get these values we first sum up data for all units into a single map. + Then we get the pairs with the highest numbers (the values are sorted in descending order) + * total counters: + * DNS-queries + * blocked + * safebrowsing-blocked + * safesearch-blocked + * parental-blocked + These values are just the sum of data for all units. +*/ +// nolint (gocyclo) +func (s *statsCtx) GetData(timeUnit TimeUnit) map[string]interface{} { + d := map[string]interface{}{} + + tx := s.beginTxn(false) + if tx == nil { + return nil + } + + units := []*unitDB{} //per-hour units + lastID := s.unitID() + firstID := lastID - s.limit + 1 + for i := firstID; i != lastID; i++ { + u := s.loadUnitFromDB(tx, i) + if u == nil { + u = &unitDB{} + u.NResult = make([]uint, rLast) + } + units = append(units, u) + } + + _ = tx.Rollback() + + s.unitLock.Lock() + cu := serialize(s.unit) + cuID := s.unit.id + s.unitLock.Unlock() + if cuID != lastID { + units = units[1:] + } + units = append(units, cu) + + if len(units) != s.limit { + log.Fatalf("len(units) != s.limit: %d %d", len(units), s.limit) + } + + // per time unit counters: + + // 720 hours may span 31 days, so we skip data for the first day in this case + firstDayID := (firstID + 24 - 1) / 24 * 24 // align_ceil(24) + + a := []uint{} + if timeUnit == Hours { + for _, u := range units { + a = append(a, u.NTotal) + } + } else { + var sum uint + id := firstDayID + nextDayID := firstDayID + 24 + for i := firstDayID - firstID; i != len(units); i++ { + sum += units[i].NTotal + if id == nextDayID { + a = append(a, sum) + sum = 0 + nextDayID += 24 + } + id++ + } + if id < nextDayID { + a = append(a, sum) + } + if len(a) != s.limit/24 { + log.Fatalf("len(a) != s.limit: %d %d", len(a), s.limit) + } + } + d["dns_queries"] = a + + a = []uint{} + if timeUnit == Hours { + for _, u := range units { + a = append(a, u.NResult[RFiltered]) + } + } else { + var sum uint + id := firstDayID + nextDayID := firstDayID + 24 + for i := firstDayID - firstID; i != len(units); i++ { + sum += units[i].NResult[RFiltered] + if id == nextDayID { + a = append(a, sum) + sum = 0 + nextDayID += 24 + } + id++ + } + if id < nextDayID { + a = append(a, sum) + } + } + d["blocked_filtering"] = a + + a = []uint{} + if timeUnit == Hours { + for _, u := range units { + a = append(a, u.NResult[RSafeBrowsing]) + } + } else { + var sum uint + id := firstDayID + nextDayID := firstDayID + 24 + for i := firstDayID - firstID; i != len(units); i++ { + sum += units[i].NResult[RSafeBrowsing] + if id == nextDayID { + a = append(a, sum) + sum = 0 + nextDayID += 24 + } + id++ + } + if id < nextDayID { + a = append(a, sum) + } + } + d["replaced_safebrowsing"] = a + + a = []uint{} + if timeUnit == Hours { + for _, u := range units { + a = append(a, u.NResult[RParental]) + } + } else { + var sum uint + id := firstDayID + nextDayID := firstDayID + 24 + for i := firstDayID - firstID; i != len(units); i++ { + sum += units[i].NResult[RParental] + if id == nextDayID { + a = append(a, sum) + sum = 0 + nextDayID += 24 + } + id++ + } + if id < nextDayID { + a = append(a, sum) + } + } + d["replaced_parental"] = a + + // top counters: + + m := map[string]int{} + for _, u := range units { + for _, it := range u.Domains { + m[it.Name] += int(it.Count) + } + } + a2 := convertMapToArray(m, maxDomains) + d["top_queried_domains"] = convertTopArray(a2) + + m = map[string]int{} + for _, u := range units { + for _, it := range u.BlockedDomains { + m[it.Name] += int(it.Count) + } + } + a2 = convertMapToArray(m, maxDomains) + d["top_blocked_domains"] = convertTopArray(a2) + + m = map[string]int{} + for _, u := range units { + for _, it := range u.Clients { + m[it.Name] += int(it.Count) + } + } + a2 = convertMapToArray(m, maxClients) + d["top_clients"] = convertTopArray(a2) + + // total counters: + + sum := unitDB{} + timeN := 0 + sum.NResult = make([]uint, rLast) + for _, u := range units { + sum.NTotal += u.NTotal + sum.TimeAvg += u.TimeAvg + if u.TimeAvg != 0 { + timeN++ + } + sum.NResult[RFiltered] += u.NResult[RFiltered] + sum.NResult[RSafeBrowsing] += u.NResult[RSafeBrowsing] + sum.NResult[RSafeSearch] += u.NResult[RSafeSearch] + sum.NResult[RParental] += u.NResult[RParental] + } + + d["num_dns_queries"] = sum.NTotal + d["num_blocked_filtering"] = sum.NResult[RFiltered] + d["num_replaced_safebrowsing"] = sum.NResult[RSafeBrowsing] + d["num_replaced_safesearch"] = sum.NResult[RSafeSearch] + d["num_replaced_parental"] = sum.NResult[RParental] + + avgTime := float64(0) + if timeN != 0 { + avgTime = float64(sum.TimeAvg/uint(timeN)) / 1000000 + } + d["avg_processing_time"] = avgTime + + d["time_units"] = "hours" + if timeUnit == Days { + d["time_units"] = "days" + } + + return d +}