Merge: * DNS filters: optimize filter update

Close #1463

Squashed commit of the following:

commit d5bdc939a2ae9f6d1ae879e4225b1dce09657b92
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Mon Mar 16 16:39:17 2020 +0300

    minor

commit e15b56a0d9db182f9d30b434584018cb1bf038d5
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Thu Mar 12 14:39:07 2020 +0300

    minor

commit 77bf59ca6e556b75af48c5987866af6d5025dae8
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Thu Mar 12 14:30:04 2020 +0300

    minor

commit e19c13f82dd408ed638bd4b68d21cdfebbdf782f
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Thu Mar 12 14:24:50 2020 +0300

    minor

commit 9113c6dae6263aa7ee6e4295c2b60dd3083e2bf0
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Thu Mar 12 14:02:06 2020 +0300

    minor

commit 70283e329e32def3375e893f806a2a02d8ca9f57
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Thu Mar 12 13:35:11 2020 +0300

    logical module Filtering

commit 837a255c6a04941e9fc007a56d71faf4c4213257
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Thu Mar 12 13:11:37 2020 +0300

    minor

commit 1853ed2b57a86dd49508023f47218219399b4fe5
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Thu Mar 12 12:59:28 2020 +0300

    refactor

commit 1ba3cc53c76255439fe54693b40ee9665fdc15e4
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Wed Mar 11 20:12:53 2020 +0300

    * filters: optimize update procedure
Author: Simon Zolin <s.zolin@adguard.com>
Date:   2020-03-17 15:00:40 +03:00
Parent: a93652b1c0
Commit: 646725efb7

8 changed files with 227 additions and 201 deletions
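In short, the old code read each downloaded list fully into memory (ioutil.ReadAll into filter.Data) and then wrote it out with filter.save()/saveAndBackupOld(). The new updateIntl() streams the response into a temporary file inside the data/filters directory, computes the checksum and rule count while scanning it, and atomically renames the temp file over the previous one. Below is a minimal, self-contained sketch of that download pattern; the helper name downloadToTempFile, the example URL, and the destination path are illustrative, not part of AdGuardHome:

package main

import (
    "io"
    "io/ioutil"
    "log"
    "net/http"
    "os"
    "path/filepath"
)

// downloadToTempFile streams a URL into a temp file created next to destPath
// and then renames it into place, so the full body never sits in memory.
func downloadToTempFile(url, destPath string) error {
    resp, err := http.Get(url)
    if err != nil {
        return err
    }
    defer resp.Body.Close()

    // Creating the temp file in the destination directory keeps the final
    // os.Rename() on one filesystem, which makes the swap atomic.
    tmp, err := ioutil.TempFile(filepath.Dir(destPath), "")
    if err != nil {
        return err
    }
    defer func() {
        if tmp != nil {
            _ = tmp.Close()
            _ = os.Remove(tmp.Name()) // clean up on any failure path
        }
    }()

    // Stream the body to disk instead of buffering it all.
    if _, err := io.Copy(tmp, resp.Body); err != nil {
        return err
    }
    if err := tmp.Close(); err != nil {
        return err
    }

    // Readers that still hold the old file open keep a consistent view
    // until they reopen the path.
    if err := os.Rename(tmp.Name(), destPath); err != nil {
        return err
    }
    tmp = nil // moved successfully; skip the deferred cleanup
    return nil
}

func main() {
    dest := filepath.Join("data", "filters", "1.txt") // hypothetical path
    _ = os.MkdirAll(filepath.Dir(dest), 0755)
    if err := downloadToTempFile("https://example.org/filter.txt", dest); err != nil {
        log.Fatal(err)
    }
    log.Printf("saved %s", dest)
}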

View File

@@ -339,12 +339,5 @@ func writeAllConfigs() error {
         return err
     }
 
-    userFilter := userFilter()
-    err = userFilter.save()
-    if err != nil {
-        log.Error("Couldn't save the user filter: %s", err)
-        return err
-    }
-
     return nil
 }

View File

@@ -173,7 +173,6 @@ func registerControlHandlers() {
     httpRegister("GET", "/control/profile", handleGetProfile)
-    RegisterFilteringHandlers()
     RegisterTLSHandlers()
     RegisterBlockedServicesHandlers()
     RegisterAuthHandlers()

View File

@@ -33,7 +33,7 @@ type filterAddJSON struct {
     Whitelist bool `json:"whitelist"`
 }
 
-func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
+func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
     fj := filterAddJSON{}
     err := json.NewDecoder(r.Body).Decode(&fj)
     if err != nil {
@@ -53,52 +53,41 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
     }
 
     // Set necessary properties
-    f := filter{
+    filt := filter{
         Enabled: true,
         URL: fj.URL,
         Name: fj.Name,
         white: fj.Whitelist,
     }
-    f.ID = assignUniqueFilterID()
+    filt.ID = assignUniqueFilterID()
 
     // Download the filter contents
-    ok, err := f.update()
+    ok, err := f.update(&filt)
     if err != nil {
-        httpError(w, http.StatusBadRequest, "Couldn't fetch filter from url %s: %s", f.URL, err)
+        httpError(w, http.StatusBadRequest, "Couldn't fetch filter from url %s: %s", filt.URL, err)
-        return
-    }
-
-    if f.RulesCount == 0 {
-        httpError(w, http.StatusBadRequest, "Filter at the url %s has no rules (maybe it points to blank page?)", f.URL)
         return
     }
 
     if !ok {
-        httpError(w, http.StatusBadRequest, "Filter at the url %s is invalid (maybe it points to blank page?)", f.URL)
+        httpError(w, http.StatusBadRequest, "Filter at the url %s is invalid (maybe it points to blank page?)", filt.URL)
-        return
-    }
-
-    // Save the filter contents
-    err = f.save()
-    if err != nil {
-        httpError(w, http.StatusBadRequest, "Failed to save filter %d due to %s", f.ID, err)
         return
     }
 
     // URL is deemed valid, append it to filters, update config, write new filter file and tell dns to reload it
-    if !filterAdd(f) {
+    if !filterAdd(filt) {
-        httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", f.URL)
+        httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", filt.URL)
         return
     }
 
     onConfigModified()
     enableFilters(true)
 
-    _, err = fmt.Fprintf(w, "OK %d rules\n", f.RulesCount)
+    _, err = fmt.Fprintf(w, "OK %d rules\n", filt.RulesCount)
     if err != nil {
         httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
     }
 }
 
-func handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
+func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
     type request struct {
         URL string `json:"url"`
@@ -156,7 +145,7 @@ type filterURLReq struct {
     Data filterURLJSON `json:"data"`
 }
 
-func handleFilteringSetURL(w http.ResponseWriter, r *http.Request) {
+func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request) {
     fj := filterURLReq{}
     err := json.NewDecoder(r.Body).Decode(&fj)
     if err != nil {
@@ -169,12 +158,12 @@ func handleFilteringSetURL(w http.ResponseWriter, r *http.Request) {
         return
     }
 
-    f := filter{
+    filt := filter{
         Enabled: fj.Data.Enabled,
         Name: fj.Data.Name,
         URL: fj.Data.URL,
     }
-    status := filterSetProperties(fj.URL, f, fj.Whitelist)
+    status := f.filterSetProperties(fj.URL, filt, fj.Whitelist)
     if (status & statusFound) == 0 {
         http.Error(w, "URL doesn't exist", http.StatusBadRequest)
         return
@@ -196,7 +185,7 @@ func handleFilteringSetURL(w http.ResponseWriter, r *http.Request) {
         if fj.Whitelist {
             flags = FilterRefreshAllowlists
         }
-        nUpdated, _ := refreshFilters(flags, true)
+        nUpdated, _ := f.refreshFilters(flags, true)
         // if at least 1 filter has been updated, refreshFilters() restarts the filtering automatically
         // if not - we restart the filtering ourselves
         restart = false
@@ -209,7 +198,7 @@ func handleFilteringSetURL(w http.ResponseWriter, r *http.Request) {
     }
 }
 
-func handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
+func (f *Filtering) handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
     body, err := ioutil.ReadAll(r.Body)
     if err != nil {
         httpError(w, http.StatusBadRequest, "Failed to read request body: %s", err)
@@ -218,15 +207,10 @@ func handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
     config.UserRules = strings.Split(string(body), "\n")
     onConfigModified()
 
-    userFilter := userFilter()
-    err = userFilter.save()
-    if err != nil {
-        log.Error("Couldn't save the user filter: %s", err)
-    }
     enableFilters(true)
 }
 
-func handleFilteringRefresh(w http.ResponseWriter, r *http.Request) {
+func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Request) {
     type Req struct {
         White bool `json:"whitelist"`
     }
@@ -248,7 +232,7 @@ func handleFilteringRefresh(w http.ResponseWriter, r *http.Request) {
     if req.White {
         flags = FilterRefreshAllowlists
     }
-    resp.Updated, err = refreshFilters(flags|FilterRefreshForce, false)
+    resp.Updated, err = f.refreshFilters(flags|FilterRefreshForce, false)
     Context.controlLock.Lock()
     if err != nil {
         httpError(w, http.StatusInternalServerError, "%s", err)
@@ -298,7 +282,7 @@ func filterToJSON(f filter) filterJSON {
 }
 
 // Get filtering configuration
-func handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
+func (f *Filtering) handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
     resp := filteringConfig{}
     config.RLock()
     resp.Enabled = config.DNS.FilteringEnabled
@@ -327,7 +311,7 @@ func handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
 }
 
 // Set filtering configuration
-func handleFilteringConfig(w http.ResponseWriter, r *http.Request) {
+func (f *Filtering) handleFilteringConfig(w http.ResponseWriter, r *http.Request) {
     req := filteringConfig{}
     err := json.NewDecoder(r.Body).Decode(&req)
     if err != nil {
@@ -359,7 +343,7 @@ type checkHostResp struct {
     IPList []net.IP `json:"ip_addrs"` // list of IP addresses
 }
 
-func handleCheckHost(w http.ResponseWriter, r *http.Request) {
+func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
     q := r.URL.Query()
     host := q.Get("name")
@@ -389,15 +373,15 @@ func handleCheckHost(w http.ResponseWriter, r *http.Request) {
 }
 
 // RegisterFilteringHandlers - register handlers
-func RegisterFilteringHandlers() {
+func (f *Filtering) RegisterFilteringHandlers() {
-    httpRegister("GET", "/control/filtering/status", handleFilteringStatus)
+    httpRegister("GET", "/control/filtering/status", f.handleFilteringStatus)
-    httpRegister("POST", "/control/filtering/config", handleFilteringConfig)
+    httpRegister("POST", "/control/filtering/config", f.handleFilteringConfig)
-    httpRegister("POST", "/control/filtering/add_url", handleFilteringAddURL)
+    httpRegister("POST", "/control/filtering/add_url", f.handleFilteringAddURL)
-    httpRegister("POST", "/control/filtering/remove_url", handleFilteringRemoveURL)
+    httpRegister("POST", "/control/filtering/remove_url", f.handleFilteringRemoveURL)
-    httpRegister("POST", "/control/filtering/set_url", handleFilteringSetURL)
+    httpRegister("POST", "/control/filtering/set_url", f.handleFilteringSetURL)
-    httpRegister("POST", "/control/filtering/refresh", handleFilteringRefresh)
+    httpRegister("POST", "/control/filtering/refresh", f.handleFilteringRefresh)
-    httpRegister("POST", "/control/filtering/set_rules", handleFilteringSetRules)
+    httpRegister("POST", "/control/filtering/set_rules", f.handleFilteringSetRules)
-    httpRegister("GET", "/control/filtering/check_host", handleCheckHost)
+    httpRegister("GET", "/control/filtering/check_host", f.handleCheckHost)
 }
 
 func checkFiltersUpdateIntervalHours(i uint32) bool {

View File

@@ -81,7 +81,7 @@ func initDNSServer() error {
     Context.rdns = InitRDNS(Context.dnsServer, &Context.clients)
     Context.whois = initWhois(&Context.clients)
 
-    initFiltering()
+    Context.filters.Init()
     return nil
 }
@@ -230,7 +230,7 @@ func startDNSServer() error {
     }
 
     Context.dnsFilter.Start()
-    startFiltering()
+    Context.filters.Start()
     Context.stats.Start()
     Context.queryLog.Start()
@@ -300,5 +300,7 @@ func closeDNSServer() {
         Context.auth = nil
     }
 
+    Context.filters.Close()
+
     log.Debug("Closed all DNS modules")
 }
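The dns.go changes above amount to a small lifecycle contract: a Filtering value is embedded in the global context and driven through Init() at setup, Start() when the server starts, and Close() on shutdown. A toy sketch of that shape, with hypothetical names (module, appContext), assuming nothing beyond what the hunks show:

package main

import "log"

// module stands in for the Filtering type: a plain value, no constructor,
// driven through three lifecycle methods.
type module struct {
    running bool
}

func (m *module) Init()  { log.Println("module: init") }
func (m *module) Start() { m.running = true; log.Println("module: start") }
func (m *module) Close() { m.running = false; log.Println("module: close") }

// appContext mirrors homeContext holding the module by value (Context.filters).
type appContext struct {
    filters module
}

func main() {
    var ctx appContext
    ctx.filters.Init()        // as in initDNSServer()
    ctx.filters.Start()       // as in startDNSServer()
    defer ctx.filters.Close() // as in closeDNSServer()
}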

View File

@@ -1,8 +1,10 @@
 package home
 
 import (
+    "bufio"
     "fmt"
     "hash/crc32"
+    "io"
     "io/ioutil"
     "os"
     "path/filepath"
@@ -15,30 +17,50 @@ import (
     "github.com/AdguardTeam/AdGuardHome/dnsfilter"
     "github.com/AdguardTeam/AdGuardHome/util"
-    "github.com/AdguardTeam/golibs/file"
     "github.com/AdguardTeam/golibs/log"
 )
 
 var (
     nextFilterID = time.Now().Unix() // semi-stable way to generate an unique ID
-    filterTitleRegexp = regexp.MustCompile(`^! Title: +(.*)$`)
-    refreshStatus uint32 // 0:none; 1:in progress
-    refreshLock sync.Mutex
 )
 
-func initFiltering() {
+// type FilteringConf struct {
-    loadFilters(config.Filters)
+//     BlockLists []filter
-    loadFilters(config.WhitelistFilters)
+//     AllowLists []filter
+//     UserRules []string
+// }
+
+// Filtering - module object
+type Filtering struct {
+    // conf FilteringConf
+    refreshStatus uint32 // 0:none; 1:in progress
+    refreshLock sync.Mutex
+    filterTitleRegexp *regexp.Regexp
+}
+
+// Init - initialize the module
+func (f *Filtering) Init() {
+    f.filterTitleRegexp = regexp.MustCompile(`^! Title: +(.*)$`)
+    _ = os.MkdirAll(filepath.Join(Context.getDataDir(), filterDir), 0755)
+    f.loadFilters(config.Filters)
+    f.loadFilters(config.WhitelistFilters)
     deduplicateFilters()
     updateUniqueFilterID(config.Filters)
     updateUniqueFilterID(config.WhitelistFilters)
 }
 
-func startFiltering() {
+// Start - start the module
+func (f *Filtering) Start() {
+    f.RegisterFilteringHandlers()
     // Here we should start updating filters,
     // but currently we can't wake up the periodic task to do so.
     // So for now we just start this periodic task from here.
-    go periodicallyRefreshFilters()
+    go f.periodicallyRefreshFilters()
+}
+
+// Close - close the module
+func (f *Filtering) Close() {
 }
 
 func defaultFilters() []filter {
@@ -83,7 +105,7 @@ const (
 // Update properties for a filter specified by its URL
 // Return status* flags.
-func filterSetProperties(url string, newf filter, whitelist bool) int {
+func (f *Filtering) filterSetProperties(url string, newf filter, whitelist bool) int {
     r := 0
     config.Lock()
     defer config.Unlock()
@@ -94,44 +116,44 @@ func filterSetProperties(url string, newf filter, whitelist bool) int {
     }
 
     for i := range *filters {
-        f := &(*filters)[i]
+        filt := &(*filters)[i]
-        if f.URL != url {
+        if filt.URL != url {
             continue
         }
 
         log.Debug("filter: set properties: %s: {%s %s %v}",
-            f.URL, newf.Name, newf.URL, newf.Enabled)
+            filt.URL, newf.Name, newf.URL, newf.Enabled)
-        f.Name = newf.Name
+        filt.Name = newf.Name
 
-        if f.URL != newf.URL {
+        if filt.URL != newf.URL {
             r |= statusURLChanged | statusUpdateRequired
             if filterExistsNoLock(newf.URL) {
                 return statusURLExists
             }
-            f.URL = newf.URL
+            filt.URL = newf.URL
-            f.unload()
+            filt.unload()
-            f.LastUpdated = time.Time{}
+            filt.LastUpdated = time.Time{}
-            f.checksum = 0
+            filt.checksum = 0
-            f.RulesCount = 0
+            filt.RulesCount = 0
         }
 
-        if f.Enabled != newf.Enabled {
+        if filt.Enabled != newf.Enabled {
             r |= statusEnabledChanged
-            f.Enabled = newf.Enabled
+            filt.Enabled = newf.Enabled
-            if f.Enabled {
+            if filt.Enabled {
                 if (r & statusURLChanged) == 0 {
-                    e := f.load()
+                    e := f.load(filt)
                     if e != nil {
                         // This isn't a fatal error,
                         // because it may occur when someone removes the file from disk.
-                        f.LastUpdated = time.Time{}
+                        filt.LastUpdated = time.Time{}
-                        f.checksum = 0
+                        filt.checksum = 0
-                        f.RulesCount = 0
+                        filt.RulesCount = 0
                         r |= statusUpdateRequired
                     }
                 }
             } else {
-                f.unload()
+                filt.unload()
             }
         }
@@ -183,7 +205,7 @@ func filterAdd(f filter) bool {
 // Load filters from the disk
 // And if any filter has zero ID, assign a new one
-func loadFilters(array []filter) {
+func (f *Filtering) loadFilters(array []filter) {
     for i := range array {
         filter := &array[i] // otherwise we're operating on a copy
         if filter.ID == 0 {
@@ -195,7 +217,7 @@ func loadFilters(array []filter) {
             continue
         }
 
-        err := filter.load()
+        err := f.load(filter)
         if err != nil {
             log.Error("Couldn't load filter %d contents due to %s", filter.ID, err)
         }
@@ -235,16 +257,16 @@ func assignUniqueFilterID() int64 {
 }
 
 // Sets up a timer that will be checking for filters updates periodically
-func periodicallyRefreshFilters() {
+func (f *Filtering) periodicallyRefreshFilters() {
     const maxInterval = 1 * 60 * 60
     intval := 5 // use a dynamically increasing time interval
     for {
         isNetworkErr := false
-        if config.DNS.FiltersUpdateIntervalHours != 0 && atomic.CompareAndSwapUint32(&refreshStatus, 0, 1) {
+        if config.DNS.FiltersUpdateIntervalHours != 0 && atomic.CompareAndSwapUint32(&f.refreshStatus, 0, 1) {
-            refreshLock.Lock()
+            f.refreshLock.Lock()
-            _, isNetworkErr = refreshFiltersIfNecessary(FilterRefreshBlocklists | FilterRefreshAllowlists)
+            _, isNetworkErr = f.refreshFiltersIfNecessary(FilterRefreshBlocklists | FilterRefreshAllowlists)
-            refreshLock.Unlock()
+            f.refreshLock.Unlock()
-            refreshStatus = 0
+            f.refreshStatus = 0
             if !isNetworkErr {
                 intval = maxInterval
             }
@@ -265,20 +287,20 @@ func periodicallyRefreshFilters() {
 // flags: FilterRefresh*
 // important:
 // TRUE: ignore the fact that we're currently updating the filters
-func refreshFilters(flags int, important bool) (int, error) {
+func (f *Filtering) refreshFilters(flags int, important bool) (int, error) {
-    set := atomic.CompareAndSwapUint32(&refreshStatus, 0, 1)
+    set := atomic.CompareAndSwapUint32(&f.refreshStatus, 0, 1)
     if !important && !set {
         return 0, fmt.Errorf("Filters update procedure is already running")
     }
 
-    refreshLock.Lock()
+    f.refreshLock.Lock()
-    nUpdated, _ := refreshFiltersIfNecessary(flags)
+    nUpdated, _ := f.refreshFiltersIfNecessary(flags)
-    refreshLock.Unlock()
+    f.refreshLock.Unlock()
-    refreshStatus = 0
+    f.refreshStatus = 0
     return nUpdated, nil
 }
 
-func refreshFiltersArray(filters *[]filter, force bool) (int, []filter, []bool, bool) {
+func (f *Filtering) refreshFiltersArray(filters *[]filter, force bool) (int, []filter, []bool, bool) {
     var updateFilters []filter
     var updateFlags []bool // 'true' if filter data has changed
@@ -312,14 +334,13 @@ func refreshFiltersArray(filters *[]filter, force bool) (int, []filter, []bool,
     nfail := 0
     for i := range updateFilters {
         uf := &updateFilters[i]
-        updated, err := uf.update()
+        updated, err := f.update(uf)
         updateFlags = append(updateFlags, updated)
         if err != nil {
             nfail++
             log.Printf("Failed to update filter %s: %s\n", uf.URL, err)
             continue
         }
-        uf.LastUpdated = now
     }
 
     if nfail == len(updateFilters) {
@@ -330,18 +351,6 @@ func refreshFiltersArray(filters *[]filter, force bool) (int, []filter, []bool,
     for i := range updateFilters {
         uf := &updateFilters[i]
         updated := updateFlags[i]
-        if updated {
-            err := uf.saveAndBackupOld()
-            if err != nil {
-                log.Printf("Failed to save the updated filter %d: %s", uf.ID, err)
-                continue
-            }
-        } else {
-            e := os.Chtimes(uf.Path(), uf.LastUpdated, uf.LastUpdated)
-            if e != nil {
-                log.Error("os.Chtimes(): %v", e)
-            }
-        }
 
         config.Lock()
         for k := range *filters {
@@ -357,7 +366,6 @@ func refreshFiltersArray(filters *[]filter, force bool) (int, []filter, []bool,
                 log.Info("Updated filter #%d. Rules: %d -> %d",
                     f.ID, f.RulesCount, uf.RulesCount)
                 f.Name = uf.Name
-                f.Data = nil
                 f.RulesCount = uf.RulesCount
                 f.checksum = uf.checksum
                 updateCount++
@@ -381,18 +389,19 @@ const (
 // Algorithm:
 // . Get the list of filters to be updated
 // . For each filter run the download and checksum check operation
+// . Store downloaded data in a temporary file inside data/filters directory
 // . For each filter:
 // . If filter data hasn't changed, just set new update time on file
 // . If filter data has changed:
-// . rename the old file (1.txt -> 1.txt.old)
+// . rename the temporary file (<temp> -> 1.txt)
-// . store the new data on disk (1.txt)
+// Note that this method works only on UNIX.
+// On Windows we don't pass files to dnsfilter - we pass the whole data.
 // . Pass new filters to dnsfilter object - it analyzes new data while the old filters are still active
 // . dnsfilter activates new filters
-// . Remove the old filter files (1.txt.old)
 //
 // Return the number of updated filters
 // Return TRUE - there was a network error and nothing could be updated
-func refreshFiltersIfNecessary(flags int) (int, bool) {
+func (f *Filtering) refreshFiltersIfNecessary(flags int) (int, bool) {
     log.Debug("Filters: updating...")
 
     updateCount := 0
@@ -405,13 +414,13 @@ func refreshFiltersIfNecessary(flags int) (int, bool) {
         force = true
     }
     if (flags & FilterRefreshBlocklists) != 0 {
-        updateCount, updateFilters, updateFlags, netError = refreshFiltersArray(&config.Filters, force)
+        updateCount, updateFilters, updateFlags, netError = f.refreshFiltersArray(&config.Filters, force)
     }
     if (flags & FilterRefreshAllowlists) != 0 {
         updateCountW := 0
         var updateFiltersW []filter
         var updateFlagsW []bool
-        updateCountW, updateFiltersW, updateFlagsW, netErrorW = refreshFiltersArray(&config.WhitelistFilters, force)
+        updateCountW, updateFiltersW, updateFlagsW, netErrorW = f.refreshFiltersArray(&config.WhitelistFilters, force)
         updateCount += updateCountW
         updateFilters = append(updateFilters, updateFiltersW...)
         updateFlags = append(updateFlags, updateFlagsW...)
@@ -449,21 +458,28 @@ func isPrintableText(data []byte) bool {
 }
 
 // A helper function that parses filter contents and returns a number of rules and a filter name (if there's any)
-func parseFilterContents(contents []byte) (int, string) {
+func (f *Filtering) parseFilterContents(file io.Reader) (int, uint32, string) {
-    data := string(contents)
     rulesCount := 0
     name := ""
     seenTitle := false
+    r := bufio.NewReader(file)
+    checksum := uint32(0)
 
-    // Count lines in the filter
+    for {
-    for len(data) != 0 {
+        line, err := r.ReadString('\n')
-        line := util.SplitNext(&data, '\n')
+        if err != nil {
+            break
+        }
+        checksum = crc32.Update(checksum, crc32.IEEETable, []byte(line))
+        line = strings.TrimSpace(line)
         if len(line) == 0 {
             continue
         }
         if line[0] == '!' {
-            m := filterTitleRegexp.FindAllStringSubmatch(line, -1)
+            m := f.filterTitleRegexp.FindAllStringSubmatch(line, -1)
             if len(m) > 0 && len(m[0]) >= 2 && !seenTitle {
                 name = m[0][1]
                 seenTitle = true
@@ -473,13 +489,36 @@ func parseFilterContents(contents []byte) (int, string) {
         }
     }
 
-    return rulesCount, name
+    return rulesCount, checksum, name
 }
 
-// Perform upgrade on a filter
+// Perform upgrade on a filter and update LastUpdated value
-func (filter *filter) update() (bool, error) {
+func (f *Filtering) update(filter *filter) (bool, error) {
+    b, err := f.updateIntl(filter)
+    filter.LastUpdated = time.Now()
+    if !b {
+        e := os.Chtimes(filter.Path(), filter.LastUpdated, filter.LastUpdated)
+        if e != nil {
+            log.Error("os.Chtimes(): %v", e)
+        }
+    }
+
+    return b, err
+}
+
+func (f *Filtering) updateIntl(filter *filter) (bool, error) {
     log.Tracef("Downloading update for filter %d from %s", filter.ID, filter.URL)
 
+    tmpfile, err := ioutil.TempFile(filepath.Join(Context.getDataDir(), filterDir), "")
+    if err != nil {
+        return false, err
+    }
+    defer func() {
+        if tmpfile != nil {
+            _ = tmpfile.Close()
+            _ = os.Remove(tmpfile.Name())
+        }
+    }()
+
     resp, err := Context.client.Get(filter.URL)
     if resp != nil && resp.Body != nil {
         defer resp.Body.Close()
@@ -494,74 +533,81 @@ func (filter *filter) update() (bool, error) {
         return false, fmt.Errorf("got status code != 200: %d", resp.StatusCode)
     }
 
-    body, err := ioutil.ReadAll(resp.Body)
+    htmlTest := true
-    if err != nil {
+    firstChunk := make([]byte, 4*1024)
-        log.Printf("Couldn't fetch filter contents from URL %s, skipping: %s", filter.URL, err)
+    firstChunkLen := 0
-        return false, err
+    buf := make([]byte, 64*1024)
+    total := 0
+    for {
+        n, err := resp.Body.Read(buf)
+        total += n
+
+        if htmlTest {
+            // gather full buffer firstChunk and perform its data tests
+            num := util.MinInt(n, len(firstChunk)-firstChunkLen)
+            copied := copy(firstChunk[firstChunkLen:], buf[:num])
+            firstChunkLen += copied
+            if firstChunkLen == len(firstChunk) || err == io.EOF {
+                if !isPrintableText(firstChunk) {
+                    return false, fmt.Errorf("Data contains non-printable characters")
+                }
+
+                s := strings.ToLower(string(firstChunk))
+                if strings.Index(s, "<html") >= 0 ||
+                    strings.Index(s, "<!doctype") >= 0 {
+                    return false, fmt.Errorf("Data is HTML, not plain text")
+                }
+
+                htmlTest = false
+                firstChunk = nil
+            }
+        }
+
+        _, err2 := tmpfile.Write(buf[:n])
+        if err2 != nil {
+            return false, err2
+        }
+
+        if err == io.EOF {
+            break
+        }
+        if err != nil {
+            log.Printf("Couldn't fetch filter contents from URL %s, skipping: %s", filter.URL, err)
+            return false, err
+        }
     }
 
+    // Extract filter name and count number of rules
+    _, _ = tmpfile.Seek(0, io.SeekStart)
+    rulesCount, checksum, filterName := f.parseFilterContents(tmpfile)
+
     // Check if the filter has been really changed
-    checksum := crc32.ChecksumIEEE(body)
     if filter.checksum == checksum {
         log.Tracef("Filter #%d at URL %s hasn't changed, not updating it", filter.ID, filter.URL)
         return false, nil
     }
 
-    var firstChunk []byte
+    log.Printf("Filter %d has been updated: %d bytes, %d rules",
-    if len(body) <= 4096 {
+        filter.ID, total, rulesCount)
-        firstChunk = body
-    } else {
-        firstChunk = body[:4096]
-    }
-
-    if !isPrintableText(firstChunk) {
-        return false, fmt.Errorf("Data contains non-printable characters")
-    }
-
-    s := strings.ToLower(string(firstChunk))
-    if strings.Index(s, "<html") >= 0 ||
-        strings.Index(s, "<!doctype") >= 0 {
-        return false, fmt.Errorf("Data is HTML, not plain text")
-    }
-
-    // Extract filter name and count number of rules
-    rulesCount, filterName := parseFilterContents(body)
-
-    log.Printf("Filter %d has been updated: %d bytes, %d rules", filter.ID, len(body), rulesCount)
     if filterName != "" {
         filter.Name = filterName
     }
     filter.RulesCount = rulesCount
-    filter.Data = body
     filter.checksum = checksum
 
+    filterFilePath := filter.Path()
+    log.Printf("Saving filter %d contents to: %s", filter.ID, filterFilePath)
+    err = os.Rename(tmpfile.Name(), filterFilePath)
+    if err != nil {
+        return false, err
+    }
+    tmpfile.Close()
+    tmpfile = nil
+
     return true, nil
 }
 
-// saves filter contents to the file in dataDir
-// This method is safe to call during filters update,
-// because it creates a new file and then renames it,
-// so the currently opened file descriptors to the old filter file remain valid.
-func (filter *filter) save() error {
-    filterFilePath := filter.Path()
-    log.Printf("Saving filter %d contents to: %s", filter.ID, filterFilePath)
-    err := file.SafeWrite(filterFilePath, filter.Data)
-
-    // update LastUpdated field after saving the file
-    filter.LastUpdated = filter.LastTimeUpdated()
-
-    return err
-}
-
-func (filter *filter) saveAndBackupOld() error {
-    filterFilePath := filter.Path()
-    err := os.Rename(filterFilePath, filterFilePath+".old")
-    if err != nil && !os.IsNotExist(err) {
-        return err
-    }
-
-    return filter.save()
-}
-
 // loads filter contents from the file in dataDir
-func (filter *filter) load() error {
+func (f *Filtering) load(filter *filter) error {
     filterFilePath := filter.Path()
     log.Tracef("Loading filter %d contents to: %s", filter.ID, filterFilePath)
@@ -570,17 +616,19 @@ func (filter *filter) load() error {
         return err
     }
 
-    filterFileContents, err := ioutil.ReadFile(filterFilePath)
+    file, err := os.Open(filterFilePath)
     if err != nil {
        return err
     }
+    defer file.Close()
+    st, _ := file.Stat()
 
-    log.Tracef("File %s, id %d, length %d", filterFilePath, filter.ID, len(filterFileContents))
+    log.Tracef("File %s, id %d, length %d",
-    rulesCount, _ := parseFilterContents(filterFileContents)
+        filterFilePath, filter.ID, st.Size())
+    rulesCount, checksum, _ := f.parseFilterContents(file)
     filter.RulesCount = rulesCount
-    filter.Data = nil
+    filter.checksum = checksum
-    filter.checksum = crc32.ChecksumIEEE(filterFileContents)
     filter.LastUpdated = filter.LastTimeUpdated()
 
     return nil
@@ -588,8 +636,8 @@ func (filter *filter) load() error {
 // Clear filter rules
 func (filter *filter) unload() {
-    filter.Data = nil
     filter.RulesCount = 0
+    filter.checksum = 0
 }
 
 // Path to the filter contents
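The reworked parseFilterContents() above is the core of the optimization: it now takes an io.Reader, runs every line through crc32.Update, and counts rules in the same pass, so neither update() nor load() has to hold the whole list (the old filter.Data field) in memory. Below is a standalone sketch of that single-pass scan; countRulesAndChecksum is an illustrative name and the rule-counting criterion ("non-empty, non-comment line") is simplified:

package main

import (
    "bufio"
    "fmt"
    "hash/crc32"
    "io"
    "strings"
)

// countRulesAndChecksum scans a filter list once, keeping a running CRC-32
// and counting non-empty, non-comment lines, in the spirit of the new
// parseFilterContents that works on an io.Reader instead of a []byte.
func countRulesAndChecksum(src io.Reader) (rulesCount int, checksum uint32) {
    r := bufio.NewReader(src)
    for {
        line, err := r.ReadString('\n')
        if err != nil && len(line) == 0 {
            break // EOF (or read error) with nothing left to process
        }
        // The checksum covers the raw line, including the newline.
        checksum = crc32.Update(checksum, crc32.IEEETable, []byte(line))

        trimmed := strings.TrimSpace(line)
        if len(trimmed) != 0 && trimmed[0] != '!' {
            rulesCount++
        }
        if err != nil {
            break
        }
    }
    return rulesCount, checksum
}

func main() {
    list := "! Title: Example list\n||example.org^\n||example.com^\n"
    n, sum := countRulesAndChecksum(strings.NewReader(list))
    fmt.Printf("rules=%d crc32=%08x\n", n, sum)
}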

View File

@@ -12,29 +12,27 @@ import (
 func TestFilters(t *testing.T) {
     dir := prepareTestDir()
     defer func() { _ = os.RemoveAll(dir) }()
 
     Context = homeContext{}
     Context.workDir = dir
     Context.client = &http.Client{
-        Timeout: time.Minute * 5,
+        Timeout: 5 * time.Second,
     }
+    Context.filters.Init()
 
     f := filter{
         URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt",
     }
 
     // download
-    ok, err := f.update()
+    ok, err := Context.filters.update(&f)
-    assert.True(t, ok && err == nil)
+    assert.Equal(t, nil, err)
+    assert.True(t, ok)
 
     // refresh
-    ok, err = f.update()
+    ok, err = Context.filters.update(&f)
     assert.True(t, !ok && err == nil)
 
-    err = f.save()
+    err = Context.filters.load(&f)
-    assert.True(t, err == nil)
-    err = f.load()
     assert.True(t, err == nil)
 
     f.unload()

View File

@@ -69,6 +69,7 @@ type homeContext struct {
     auth *Auth // HTTP authentication module
     httpServer *http.Server // HTTP module
     httpsServer HTTPSServer // HTTPS module
+    filters Filtering
 
     // Runtime properties
     // --

View File

@@ -9,10 +9,6 @@ import (
     "strings"
 )
 
-// ---------------------
-// general helpers
-// ---------------------
-
 // fileExists returns TRUE if file exists
 func FileExists(fn string) bool {
     _, err := os.Stat(fn)
@@ -33,9 +29,6 @@ func RunCommand(command string, arguments ...string) (int, string, error) {
     return cmd.ProcessState.ExitCode(), string(out), nil
 }
 
-// ---------------------
-// debug logging helpers
-// ---------------------
 func FuncName() string {
     pc := make([]uintptr, 10) // at least 1 entry needed
     runtime.Callers(2, pc)
@@ -57,3 +50,11 @@ func SplitNext(str *string, splitBy byte) string {
     }
     return strings.TrimSpace(s)
 }
+
+// MinInt - return the minimum value
+func MinInt(a, b int) int {
+    if a < b {
+        return a
+    }
+    return b
+}
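
MinInt exists for the streaming download in updateIntl() above: as chunks arrive, the first 4 KB are accumulated into a probe buffer and validated once (printable text, not an HTML page) before the rest of the body is trusted and written to the temp file. A small sketch of that accumulation arithmetic, with hypothetical helper names (fillProbe, looksLikeFilterList):

package main

import (
    "fmt"
    "strings"
)

// minInt mirrors util.MinInt above.
func minInt(a, b int) int {
    if a < b {
        return a
    }
    return b
}

// fillProbe copies as much of chunk as still fits into probe and reports the
// new fill level - the same arithmetic updateIntl performs with MinInt while
// accumulating its 4 KB firstChunk across Read calls.
func fillProbe(probe []byte, filled int, chunk []byte) int {
    n := minInt(len(chunk), len(probe)-filled)
    return filled + copy(probe[filled:], chunk[:n])
}

// looksLikeFilterList is a hypothetical helper: it rejects payloads whose
// first bytes look like an HTML page rather than a plain-text filter list.
func looksLikeFilterList(probe []byte) bool {
    s := strings.ToLower(string(probe))
    return !strings.Contains(s, "<html") && !strings.Contains(s, "<!doctype")
}

func main() {
    probe := make([]byte, 16) // 4*1024 in the real code
    filled := fillProbe(probe, 0, []byte("||example.org^\n"))
    fmt.Println(filled, looksLikeFilterList(probe[:filled]))

    filled = fillProbe(probe, 0, []byte("<!DOCTYPE html><html>"))
    fmt.Println(filled, looksLikeFilterList(probe[:filled]))
}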