diff --git a/home/dns.go b/home/dns.go index a016d302..259ad5cb 100644 --- a/home/dns.go +++ b/home/dns.go @@ -36,8 +36,11 @@ func initDNSServer(baseDir string) { log.Fatalf("Cannot create DNS data dir at %s: %s", baseDir, err) } - statsDBFilename := filepath.Join(baseDir, "stats.db") - config.stats, err = stats.New(statsDBFilename, config.DNS.StatsInterval, nil) + statsConf := stats.Config{ + Filename: filepath.Join(baseDir, "stats.db"), + LimitDays: config.DNS.StatsInterval, + } + config.stats, err = stats.New(statsConf) if err != nil { log.Fatal("Couldn't initialize statistics module") } diff --git a/stats/stats.go b/stats/stats.go index b6883869..d2b59859 100644 --- a/stats/stats.go +++ b/stats/stats.go @@ -8,12 +8,16 @@ import ( type unitIDCallback func() uint32 +// Config - module configuration +type Config struct { + Filename string // database file name + LimitDays uint32 // time limit (in days) + UnitID unitIDCallback // user function to get the current unit ID. If nil, the current time hour is used. +} + // New - create object -// filename: DB file name -// limit: time limit (in days) -// unitID: user function to get the current unit ID. If nil, the current time hour is used. -func New(filename string, limit uint32, unitID unitIDCallback) (Stats, error) { - return createObject(filename, limit, unitID) +func New(conf Config) (Stats, error) { + return createObject(conf) } // Stats - main interface diff --git a/stats/stats_test.go b/stats/stats_test.go index 70437689..0590b443 100644 --- a/stats/stats_test.go +++ b/stats/stats_test.go @@ -26,7 +26,11 @@ func UIntArrayEquals(a []uint64, b []uint64) bool { } func TestStats(t *testing.T) { - s, _ := New("./stats.db", 1, nil) + conf := Config{ + Filename: "./stats.db", + LimitDays: 1, + } + s, _ := New(conf) e := Entry{} @@ -73,7 +77,7 @@ func TestStats(t *testing.T) { s.Clear() s.Close() - os.Remove("./stats.db") + os.Remove(conf.Filename) } func TestLargeNumbers(t *testing.T) { @@ -85,9 +89,13 @@ func TestLargeNumbers(t *testing.T) { } // log.SetLevel(log.DEBUG) - fn := "./stats.db" - os.Remove(fn) - s, _ := New(fn, 1, newID) + conf := Config{ + Filename: "./stats.db", + LimitDays: 1, + UnitID: newID, + } + os.Remove(conf.Filename) + s, _ := New(conf) e := Entry{} n := 1000 // number of distinct clients and domains every hour @@ -111,5 +119,5 @@ func TestLargeNumbers(t *testing.T) { assert.True(t, d["num_dns_queries"].(uint64) == uint64(int(hour)*n)) s.Close() - os.Remove(fn) + os.Remove(conf.Filename) } diff --git a/stats/stats_unit.go b/stats/stats_unit.go index bfcc09b8..20bfb721 100644 --- a/stats/stats_unit.go +++ b/stats/stats_unit.go @@ -21,10 +21,10 @@ const ( // statsCtx - global context type statsCtx struct { - limit uint32 // maximum time we need to keep data for (in hours) - filename string // database file name - unitID unitIDCallback // user function which returns the current unit ID - db *bolt.DB + limit uint32 // maximum time we need to keep data for (in hours) + db *bolt.DB + + conf Config unit *unit // the current unit unitLock sync.Mutex // protect 'unit' @@ -62,20 +62,19 @@ type unitDB struct { TimeAvg uint32 // usec } -func createObject(filename string, limitDays uint32, unitID unitIDCallback) (*statsCtx, error) { +func createObject(conf Config) (*statsCtx, error) { s := statsCtx{} - s.limit = limitDays * 24 - s.filename = filename - s.unitID = newUnitID - if unitID != nil { - s.unitID = unitID + s.limit = conf.LimitDays * 24 + s.conf = conf + if conf.UnitID == nil { + s.conf.UnitID = newUnitID } if !s.dbOpen() { return nil, fmt.Errorf("open database") } - id := s.unitID() + id := s.conf.UnitID() tx := s.beginTxn(true) var udb *unitDB if tx != nil { @@ -122,9 +121,9 @@ func createObject(filename string, limitDays uint32, unitID unitIDCallback) (*st func (s *statsCtx) dbOpen() bool { var err error log.Tracef("db.Open...") - s.db, err = bolt.Open(s.filename, 0644, nil) + s.db, err = bolt.Open(s.conf.Filename, 0644, nil) if err != nil { - log.Error("Stats: open DB: %s: %s", s.filename, err) + log.Error("Stats: open DB: %s: %s", s.conf.Filename, err) return false } log.Tracef("db.Open") @@ -208,7 +207,7 @@ func (s *statsCtx) periodicFlush() { break } - id := s.unitID() + id := s.conf.UnitID() if ptr.id == id { time.Sleep(time.Second) continue @@ -406,10 +405,10 @@ func (s *statsCtx) Clear() { } u := unit{} - s.initUnit(&u, s.unitID()) + s.initUnit(&u, s.conf.UnitID()) _ = s.swapUnit(&u) - err := os.Remove(s.filename) + err := os.Remove(s.conf.Filename) if err != nil { log.Error("os.Remove: %s", err) } @@ -481,7 +480,7 @@ func (s *statsCtx) GetData(timeUnit TimeUnit) map[string]interface{} { } units := []*unitDB{} //per-hour units - lastID := s.unitID() + lastID := s.conf.UnitID() firstID := lastID - s.limit + 1 for i := firstID; i != lastID; i++ { u := s.loadUnitFromDB(tx, i)