From 646725efb7ea5693fc537ae88ec259ddfa21dd80 Mon Sep 17 00:00:00 2001 From: Simon Zolin Date: Tue, 17 Mar 2020 15:00:40 +0300 Subject: [PATCH] Merge: * DNS filters: optimize filter update Close #1463 Squashed commit of the following: commit d5bdc939a2ae9f6d1ae879e4225b1dce09657b92 Author: Simon Zolin Date: Mon Mar 16 16:39:17 2020 +0300 minor commit e15b56a0d9db182f9d30b434584018cb1bf038d5 Author: Simon Zolin Date: Thu Mar 12 14:39:07 2020 +0300 minor commit 77bf59ca6e556b75af48c5987866af6d5025dae8 Author: Simon Zolin Date: Thu Mar 12 14:30:04 2020 +0300 minor commit e19c13f82dd408ed638bd4b68d21cdfebbdf782f Author: Simon Zolin Date: Thu Mar 12 14:24:50 2020 +0300 minor commit 9113c6dae6263aa7ee6e4295c2b60dd3083e2bf0 Author: Simon Zolin Date: Thu Mar 12 14:02:06 2020 +0300 minor commit 70283e329e32def3375e893f806a2a02d8ca9f57 Author: Simon Zolin Date: Thu Mar 12 13:35:11 2020 +0300 logical module Filtering commit 837a255c6a04941e9fc007a56d71faf4c4213257 Author: Simon Zolin Date: Thu Mar 12 13:11:37 2020 +0300 minor commit 1853ed2b57a86dd49508023f47218219399b4fe5 Author: Simon Zolin Date: Thu Mar 12 12:59:28 2020 +0300 refactor commit 1ba3cc53c76255439fe54693b40ee9665fdc15e4 Author: Simon Zolin Date: Wed Mar 11 20:12:53 2020 +0300 * filters: optimize update procedure --- home/config.go | 7 - home/control.go | 1 - home/control_filtering.go | 74 ++++----- home/dns.go | 6 +- home/filter.go | 308 ++++++++++++++++++++++---------------- home/filter_test.go | 16 +- home/home.go | 1 + util/helpers.go | 15 +- 8 files changed, 227 insertions(+), 201 deletions(-) diff --git a/home/config.go b/home/config.go index 09bdbd9e..657325e2 100644 --- a/home/config.go +++ b/home/config.go @@ -339,12 +339,5 @@ func writeAllConfigs() error { return err } - userFilter := userFilter() - err = userFilter.save() - if err != nil { - log.Error("Couldn't save the user filter: %s", err) - return err - } - return nil } diff --git a/home/control.go b/home/control.go index 031f7a36..ca8939ca 100644 
--- a/home/control.go +++ b/home/control.go @@ -173,7 +173,6 @@ func registerControlHandlers() { httpRegister("GET", "/control/profile", handleGetProfile) - RegisterFilteringHandlers() RegisterTLSHandlers() RegisterBlockedServicesHandlers() RegisterAuthHandlers() diff --git a/home/control_filtering.go b/home/control_filtering.go index d78badf6..242bbf3d 100644 --- a/home/control_filtering.go +++ b/home/control_filtering.go @@ -33,7 +33,7 @@ type filterAddJSON struct { Whitelist bool `json:"whitelist"` } -func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) { +func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request) { fj := filterAddJSON{} err := json.NewDecoder(r.Body).Decode(&fj) if err != nil { @@ -53,52 +53,41 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) { } // Set necessary properties - f := filter{ + filt := filter{ Enabled: true, URL: fj.URL, Name: fj.Name, white: fj.Whitelist, } - f.ID = assignUniqueFilterID() + filt.ID = assignUniqueFilterID() // Download the filter contents - ok, err := f.update() + ok, err := f.update(&filt) if err != nil { - httpError(w, http.StatusBadRequest, "Couldn't fetch filter from url %s: %s", f.URL, err) - return - } - if f.RulesCount == 0 { - httpError(w, http.StatusBadRequest, "Filter at the url %s has no rules (maybe it points to blank page?)", f.URL) + httpError(w, http.StatusBadRequest, "Couldn't fetch filter from url %s: %s", filt.URL, err) return } if !ok { - httpError(w, http.StatusBadRequest, "Filter at the url %s is invalid (maybe it points to blank page?)", f.URL) - return - } - - // Save the filter contents - err = f.save() - if err != nil { - httpError(w, http.StatusBadRequest, "Failed to save filter %d due to %s", f.ID, err) + httpError(w, http.StatusBadRequest, "Filter at the url %s is invalid (maybe it points to blank page?)", filt.URL) return } // URL is deemed valid, append it to filters, update config, write new filter file and tell dns to reload 
it - if !filterAdd(f) { - httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", f.URL) + if !filterAdd(filt) { + httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", filt.URL) return } onConfigModified() enableFilters(true) - _, err = fmt.Fprintf(w, "OK %d rules\n", f.RulesCount) + _, err = fmt.Fprintf(w, "OK %d rules\n", filt.RulesCount) if err != nil { httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err) } } -func handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) { +func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) { type request struct { URL string `json:"url"` @@ -156,7 +145,7 @@ type filterURLReq struct { Data filterURLJSON `json:"data"` } -func handleFilteringSetURL(w http.ResponseWriter, r *http.Request) { +func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request) { fj := filterURLReq{} err := json.NewDecoder(r.Body).Decode(&fj) if err != nil { @@ -169,12 +158,12 @@ func handleFilteringSetURL(w http.ResponseWriter, r *http.Request) { return } - f := filter{ + filt := filter{ Enabled: fj.Data.Enabled, Name: fj.Data.Name, URL: fj.Data.URL, } - status := filterSetProperties(fj.URL, f, fj.Whitelist) + status := f.filterSetProperties(fj.URL, filt, fj.Whitelist) if (status & statusFound) == 0 { http.Error(w, "URL doesn't exist", http.StatusBadRequest) return @@ -196,7 +185,7 @@ func handleFilteringSetURL(w http.ResponseWriter, r *http.Request) { if fj.Whitelist { flags = FilterRefreshAllowlists } - nUpdated, _ := refreshFilters(flags, true) + nUpdated, _ := f.refreshFilters(flags, true) // if at least 1 filter has been updated, refreshFilters() restarts the filtering automatically // if not - we restart the filtering ourselves restart = false @@ -209,7 +198,7 @@ func handleFilteringSetURL(w http.ResponseWriter, r *http.Request) { } } -func handleFilteringSetRules(w http.ResponseWriter, r *http.Request) { +func (f *Filtering) 
handleFilteringSetRules(w http.ResponseWriter, r *http.Request) { body, err := ioutil.ReadAll(r.Body) if err != nil { httpError(w, http.StatusBadRequest, "Failed to read request body: %s", err) @@ -218,15 +207,10 @@ func handleFilteringSetRules(w http.ResponseWriter, r *http.Request) { config.UserRules = strings.Split(string(body), "\n") onConfigModified() - userFilter := userFilter() - err = userFilter.save() - if err != nil { - log.Error("Couldn't save the user filter: %s", err) - } enableFilters(true) } -func handleFilteringRefresh(w http.ResponseWriter, r *http.Request) { +func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Request) { type Req struct { White bool `json:"whitelist"` } @@ -248,7 +232,7 @@ func handleFilteringRefresh(w http.ResponseWriter, r *http.Request) { if req.White { flags = FilterRefreshAllowlists } - resp.Updated, err = refreshFilters(flags|FilterRefreshForce, false) + resp.Updated, err = f.refreshFilters(flags|FilterRefreshForce, false) Context.controlLock.Lock() if err != nil { httpError(w, http.StatusInternalServerError, "%s", err) @@ -298,7 +282,7 @@ func filterToJSON(f filter) filterJSON { } // Get filtering configuration -func handleFilteringStatus(w http.ResponseWriter, r *http.Request) { +func (f *Filtering) handleFilteringStatus(w http.ResponseWriter, r *http.Request) { resp := filteringConfig{} config.RLock() resp.Enabled = config.DNS.FilteringEnabled @@ -327,7 +311,7 @@ func handleFilteringStatus(w http.ResponseWriter, r *http.Request) { } // Set filtering configuration -func handleFilteringConfig(w http.ResponseWriter, r *http.Request) { +func (f *Filtering) handleFilteringConfig(w http.ResponseWriter, r *http.Request) { req := filteringConfig{} err := json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -359,7 +343,7 @@ type checkHostResp struct { IPList []net.IP `json:"ip_addrs"` // list of IP addresses } -func handleCheckHost(w http.ResponseWriter, r *http.Request) { +func (f *Filtering) 
handleCheckHost(w http.ResponseWriter, r *http.Request) { q := r.URL.Query() host := q.Get("name") @@ -389,15 +373,15 @@ func handleCheckHost(w http.ResponseWriter, r *http.Request) { } // RegisterFilteringHandlers - register handlers -func RegisterFilteringHandlers() { - httpRegister("GET", "/control/filtering/status", handleFilteringStatus) - httpRegister("POST", "/control/filtering/config", handleFilteringConfig) - httpRegister("POST", "/control/filtering/add_url", handleFilteringAddURL) - httpRegister("POST", "/control/filtering/remove_url", handleFilteringRemoveURL) - httpRegister("POST", "/control/filtering/set_url", handleFilteringSetURL) - httpRegister("POST", "/control/filtering/refresh", handleFilteringRefresh) - httpRegister("POST", "/control/filtering/set_rules", handleFilteringSetRules) - httpRegister("GET", "/control/filtering/check_host", handleCheckHost) +func (f *Filtering) RegisterFilteringHandlers() { + httpRegister("GET", "/control/filtering/status", f.handleFilteringStatus) + httpRegister("POST", "/control/filtering/config", f.handleFilteringConfig) + httpRegister("POST", "/control/filtering/add_url", f.handleFilteringAddURL) + httpRegister("POST", "/control/filtering/remove_url", f.handleFilteringRemoveURL) + httpRegister("POST", "/control/filtering/set_url", f.handleFilteringSetURL) + httpRegister("POST", "/control/filtering/refresh", f.handleFilteringRefresh) + httpRegister("POST", "/control/filtering/set_rules", f.handleFilteringSetRules) + httpRegister("GET", "/control/filtering/check_host", f.handleCheckHost) } func checkFiltersUpdateIntervalHours(i uint32) bool { diff --git a/home/dns.go b/home/dns.go index 90461930..6e53dc79 100644 --- a/home/dns.go +++ b/home/dns.go @@ -81,7 +81,7 @@ func initDNSServer() error { Context.rdns = InitRDNS(Context.dnsServer, &Context.clients) Context.whois = initWhois(&Context.clients) - initFiltering() + Context.filters.Init() return nil } @@ -230,7 +230,7 @@ func startDNSServer() error { } 
Context.dnsFilter.Start() - startFiltering() + Context.filters.Start() Context.stats.Start() Context.queryLog.Start() @@ -300,5 +300,7 @@ func closeDNSServer() { Context.auth = nil } + Context.filters.Close() + log.Debug("Closed all DNS modules") } diff --git a/home/filter.go b/home/filter.go index f9770ca3..af03e230 100644 --- a/home/filter.go +++ b/home/filter.go @@ -1,8 +1,10 @@ package home import ( + "bufio" "fmt" "hash/crc32" + "io" "io/ioutil" "os" "path/filepath" @@ -15,30 +17,50 @@ import ( "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/util" - "github.com/AdguardTeam/golibs/file" "github.com/AdguardTeam/golibs/log" ) var ( - nextFilterID = time.Now().Unix() // semi-stable way to generate an unique ID - filterTitleRegexp = regexp.MustCompile(`^! Title: +(.*)$`) - refreshStatus uint32 // 0:none; 1:in progress - refreshLock sync.Mutex + nextFilterID = time.Now().Unix() // semi-stable way to generate an unique ID ) -func initFiltering() { - loadFilters(config.Filters) - loadFilters(config.WhitelistFilters) +// type FilteringConf struct { +// BlockLists []filter +// AllowLists []filter +// UserRules []string +// } + +// Filtering - module object +type Filtering struct { + // conf FilteringConf + refreshStatus uint32 // 0:none; 1:in progress + refreshLock sync.Mutex + filterTitleRegexp *regexp.Regexp +} + +// Init - initialize the module +func (f *Filtering) Init() { + f.filterTitleRegexp = regexp.MustCompile(`^! Title: +(.*)$`) + _ = os.MkdirAll(filepath.Join(Context.getDataDir(), filterDir), 0755) + f.loadFilters(config.Filters) + f.loadFilters(config.WhitelistFilters) deduplicateFilters() updateUniqueFilterID(config.Filters) updateUniqueFilterID(config.WhitelistFilters) } -func startFiltering() { +// Start - start the module +func (f *Filtering) Start() { + f.RegisterFilteringHandlers() + // Here we should start updating filters, // but currently we can't wake up the periodic task to do so. 
// So for now we just start this periodic task from here. - go periodicallyRefreshFilters() + go f.periodicallyRefreshFilters() +} + +// Close - close the module +func (f *Filtering) Close() { } func defaultFilters() []filter { @@ -83,7 +105,7 @@ const ( // Update properties for a filter specified by its URL // Return status* flags. -func filterSetProperties(url string, newf filter, whitelist bool) int { +func (f *Filtering) filterSetProperties(url string, newf filter, whitelist bool) int { r := 0 config.Lock() defer config.Unlock() @@ -94,44 +116,44 @@ func filterSetProperties(url string, newf filter, whitelist bool) int { } for i := range *filters { - f := &(*filters)[i] - if f.URL != url { + filt := &(*filters)[i] + if filt.URL != url { continue } log.Debug("filter: set properties: %s: {%s %s %v}", - f.URL, newf.Name, newf.URL, newf.Enabled) - f.Name = newf.Name + filt.URL, newf.Name, newf.URL, newf.Enabled) + filt.Name = newf.Name - if f.URL != newf.URL { + if filt.URL != newf.URL { r |= statusURLChanged | statusUpdateRequired if filterExistsNoLock(newf.URL) { return statusURLExists } - f.URL = newf.URL - f.unload() - f.LastUpdated = time.Time{} - f.checksum = 0 - f.RulesCount = 0 + filt.URL = newf.URL + filt.unload() + filt.LastUpdated = time.Time{} + filt.checksum = 0 + filt.RulesCount = 0 } - if f.Enabled != newf.Enabled { + if filt.Enabled != newf.Enabled { r |= statusEnabledChanged - f.Enabled = newf.Enabled - if f.Enabled { + filt.Enabled = newf.Enabled + if filt.Enabled { if (r & statusURLChanged) == 0 { - e := f.load() + e := f.load(filt) if e != nil { // This isn't a fatal error, // because it may occur when someone removes the file from disk. 
- f.LastUpdated = time.Time{} - f.checksum = 0 - f.RulesCount = 0 + filt.LastUpdated = time.Time{} + filt.checksum = 0 + filt.RulesCount = 0 r |= statusUpdateRequired } } } else { - f.unload() + filt.unload() } } @@ -183,7 +205,7 @@ func filterAdd(f filter) bool { // Load filters from the disk // And if any filter has zero ID, assign a new one -func loadFilters(array []filter) { +func (f *Filtering) loadFilters(array []filter) { for i := range array { filter := &array[i] // otherwise we're operating on a copy if filter.ID == 0 { @@ -195,7 +217,7 @@ func loadFilters(array []filter) { continue } - err := filter.load() + err := f.load(filter) if err != nil { log.Error("Couldn't load filter %d contents due to %s", filter.ID, err) } @@ -235,16 +257,16 @@ func assignUniqueFilterID() int64 { } // Sets up a timer that will be checking for filters updates periodically -func periodicallyRefreshFilters() { +func (f *Filtering) periodicallyRefreshFilters() { const maxInterval = 1 * 60 * 60 intval := 5 // use a dynamically increasing time interval for { isNetworkErr := false - if config.DNS.FiltersUpdateIntervalHours != 0 && atomic.CompareAndSwapUint32(&refreshStatus, 0, 1) { - refreshLock.Lock() - _, isNetworkErr = refreshFiltersIfNecessary(FilterRefreshBlocklists | FilterRefreshAllowlists) - refreshLock.Unlock() - refreshStatus = 0 + if config.DNS.FiltersUpdateIntervalHours != 0 && atomic.CompareAndSwapUint32(&f.refreshStatus, 0, 1) { + f.refreshLock.Lock() + _, isNetworkErr = f.refreshFiltersIfNecessary(FilterRefreshBlocklists | FilterRefreshAllowlists) + f.refreshLock.Unlock() + f.refreshStatus = 0 if !isNetworkErr { intval = maxInterval } @@ -265,20 +287,20 @@ func periodicallyRefreshFilters() { // flags: FilterRefresh* // important: // TRUE: ignore the fact that we're currently updating the filters -func refreshFilters(flags int, important bool) (int, error) { - set := atomic.CompareAndSwapUint32(&refreshStatus, 0, 1) +func (f *Filtering) refreshFilters(flags int, 
important bool) (int, error) { + set := atomic.CompareAndSwapUint32(&f.refreshStatus, 0, 1) if !important && !set { return 0, fmt.Errorf("Filters update procedure is already running") } - refreshLock.Lock() - nUpdated, _ := refreshFiltersIfNecessary(flags) - refreshLock.Unlock() - refreshStatus = 0 + f.refreshLock.Lock() + nUpdated, _ := f.refreshFiltersIfNecessary(flags) + f.refreshLock.Unlock() + f.refreshStatus = 0 return nUpdated, nil } -func refreshFiltersArray(filters *[]filter, force bool) (int, []filter, []bool, bool) { +func (f *Filtering) refreshFiltersArray(filters *[]filter, force bool) (int, []filter, []bool, bool) { var updateFilters []filter var updateFlags []bool // 'true' if filter data has changed @@ -312,14 +334,13 @@ func refreshFiltersArray(filters *[]filter, force bool) (int, []filter, []bool, nfail := 0 for i := range updateFilters { uf := &updateFilters[i] - updated, err := uf.update() + updated, err := f.update(uf) updateFlags = append(updateFlags, updated) if err != nil { nfail++ log.Printf("Failed to update filter %s: %s\n", uf.URL, err) continue } - uf.LastUpdated = now } if nfail == len(updateFilters) { @@ -330,18 +351,6 @@ func refreshFiltersArray(filters *[]filter, force bool) (int, []filter, []bool, for i := range updateFilters { uf := &updateFilters[i] updated := updateFlags[i] - if updated { - err := uf.saveAndBackupOld() - if err != nil { - log.Printf("Failed to save the updated filter %d: %s", uf.ID, err) - continue - } - } else { - e := os.Chtimes(uf.Path(), uf.LastUpdated, uf.LastUpdated) - if e != nil { - log.Error("os.Chtimes(): %v", e) - } - } config.Lock() for k := range *filters { @@ -357,7 +366,6 @@ func refreshFiltersArray(filters *[]filter, force bool) (int, []filter, []bool, log.Info("Updated filter #%d. Rules: %d -> %d", f.ID, f.RulesCount, uf.RulesCount) f.Name = uf.Name - f.Data = nil f.RulesCount = uf.RulesCount f.checksum = uf.checksum updateCount++ @@ -381,18 +389,19 @@ const ( // Algorithm: // . 
Get the list of filters to be updated // . For each filter run the download and checksum check operation +// . Store downloaded data in a temporary file inside data/filters directory // . For each filter: // . If filter data hasn't changed, just set new update time on file // . If filter data has changed: -// . rename the old file (1.txt -> 1.txt.old) -// . store the new data on disk (1.txt) +// . rename the temporary file (<temp> -> 1.txt) +// Note that this method works only on UNIX. +// On Windows we don't pass files to dnsfilter - we pass the whole data. // . Pass new filters to dnsfilter object - it analyzes new data while the old filters are still active // . dnsfilter activates new filters -// . Remove the old filter files (1.txt.old) // // Return the number of updated filters // Return TRUE - there was a network error and nothing could be updated -func refreshFiltersIfNecessary(flags int) (int, bool) { +func (f *Filtering) refreshFiltersIfNecessary(flags int) (int, bool) { log.Debug("Filters: updating...") updateCount := 0 @@ -405,13 +414,13 @@ func refreshFiltersIfNecessary(flags int) (int, bool) { force = true } if (flags & FilterRefreshBlocklists) != 0 { - updateCount, updateFilters, updateFlags, netError = refreshFiltersArray(&config.Filters, force) + updateCount, updateFilters, updateFlags, netError = f.refreshFiltersArray(&config.Filters, force) } if (flags & FilterRefreshAllowlists) != 0 { updateCountW := 0 var updateFiltersW []filter var updateFlagsW []bool - updateCountW, updateFiltersW, updateFlagsW, netErrorW = refreshFiltersArray(&config.WhitelistFilters, force) + updateCountW, updateFiltersW, updateFlagsW, netErrorW = f.refreshFiltersArray(&config.WhitelistFilters, force) updateCount += updateCountW updateFilters = append(updateFilters, updateFiltersW...) updateFlags = append(updateFlags, updateFlagsW...) 
@@ -449,21 +458,28 @@ func isPrintableText(data []byte) bool { } // A helper function that parses filter contents and returns a number of rules and a filter name (if there's any) -func parseFilterContents(contents []byte) (int, string) { - data := string(contents) +func (f *Filtering) parseFilterContents(file io.Reader) (int, uint32, string) { rulesCount := 0 name := "" seenTitle := false + r := bufio.NewReader(file) + checksum := uint32(0) - // Count lines in the filter - for len(data) != 0 { - line := util.SplitNext(&data, '\n') + for { + line, err := r.ReadString('\n') + if err != nil { + break + } + + checksum = crc32.Update(checksum, crc32.IEEETable, []byte(line)) + + line = strings.TrimSpace(line) if len(line) == 0 { continue } if line[0] == '!' { - m := filterTitleRegexp.FindAllStringSubmatch(line, -1) + m := f.filterTitleRegexp.FindAllStringSubmatch(line, -1) if len(m) > 0 && len(m[0]) >= 2 && !seenTitle { name = m[0][1] seenTitle = true @@ -473,13 +489,36 @@ func parseFilterContents(contents []byte) (int, string) { } } - return rulesCount, name + return rulesCount, checksum, name } -// Perform upgrade on a filter -func (filter *filter) update() (bool, error) { +// Perform upgrade on a filter and update LastUpdated value +func (f *Filtering) update(filter *filter) (bool, error) { + b, err := f.updateIntl(filter) + filter.LastUpdated = time.Now() + if !b { + e := os.Chtimes(filter.Path(), filter.LastUpdated, filter.LastUpdated) + if e != nil { + log.Error("os.Chtimes(): %v", e) + } + } + return b, err +} + +func (f *Filtering) updateIntl(filter *filter) (bool, error) { log.Tracef("Downloading update for filter %d from %s", filter.ID, filter.URL) + tmpfile, err := ioutil.TempFile(filepath.Join(Context.getDataDir(), filterDir), "") + if err != nil { + return false, err + } + defer func() { + if tmpfile != nil { + _ = tmpfile.Close() + _ = os.Remove(tmpfile.Name()) + } + }() + resp, err := Context.client.Get(filter.URL) if resp != nil && resp.Body != nil { 
defer resp.Body.Close() @@ -494,74 +533,81 @@ func (filter *filter) update() (bool, error) { return false, fmt.Errorf("got status code != 200: %d", resp.StatusCode) } - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - log.Printf("Couldn't fetch filter contents from URL %s, skipping: %s", filter.URL, err) - return false, err + htmlTest := true + firstChunk := make([]byte, 4*1024) + firstChunkLen := 0 + buf := make([]byte, 64*1024) + total := 0 + for { + n, err := resp.Body.Read(buf) + total += n + + if htmlTest { + // gather full buffer firstChunk and perform its data tests + num := util.MinInt(n, len(firstChunk)-firstChunkLen) + copied := copy(firstChunk[firstChunkLen:], buf[:num]) + firstChunkLen += copied + + if firstChunkLen == len(firstChunk) || err == io.EOF { + if !isPrintableText(firstChunk) { + return false, fmt.Errorf("Data contains non-printable characters") + } + + s := strings.ToLower(string(firstChunk)) + if strings.Index(s, "<html") >= 0 || + strings.Index(s, "<!doctype") >= 0 { + return false, fmt.Errorf("Data is HTML, not plain text") + } + + htmlTest = false + firstChunk = nil + } + } + + _, err2 := tmpfile.Write(buf[:n]) + if err2 != nil { + return false, err2 + } + + if err == io.EOF { + break + } + if err != nil { + log.Printf("Couldn't fetch filter contents from URL %s, skipping: %s", filter.URL, err) + return false, err + } } + // Extract filter name and count number of rules + _, _ = tmpfile.Seek(0, io.SeekStart) + rulesCount, checksum, filterName := f.parseFilterContents(tmpfile) // Check if the filter has been really changed - checksum := crc32.ChecksumIEEE(body) if filter.checksum == checksum { log.Tracef("Filter #%d at URL %s hasn't changed, not updating it", filter.ID, filter.URL) return false, nil } - var firstChunk []byte - if len(body) <= 4096 { - firstChunk = body - } else { - firstChunk = body[:4096] - } - if !isPrintableText(firstChunk) { - return false, fmt.Errorf("Data contains non-printable characters") - } - - s := 
strings.ToLower(string(firstChunk)) - if strings.Index(s, "<html") >= 0 || - strings.Index(s, "<!doctype") >= 0 { - return false, fmt.Errorf("Data is HTML, not plain text") - } - - // Extract filter name and count number of rules - rulesCount, filterName := parseFilterContents(body) - log.Printf("Filter %d has been updated: %d bytes, %d rules", filter.ID, len(body), rulesCount) + log.Printf("Filter %d has been updated: %d bytes, %d rules", + filter.ID, total, rulesCount) if filterName != "" { filter.Name = filterName } filter.RulesCount = rulesCount - filter.Data = body filter.checksum = checksum + filterFilePath := filter.Path() + log.Printf("Saving filter %d contents to: %s", filter.ID, filterFilePath) + err = os.Rename(tmpfile.Name(), filterFilePath) + if err != nil { + return false, err + } + tmpfile.Close() + tmpfile = nil return true, nil } -// saves filter contents to the file in dataDir -// This method is safe to call during filters update, -// because it creates a new file and then renames it, -// so the currently opened file descriptors to the old filter file remain valid. 
-func (filter *filter) save() error { - filterFilePath := filter.Path() - log.Printf("Saving filter %d contents to: %s", filter.ID, filterFilePath) - - err := file.SafeWrite(filterFilePath, filter.Data) - - // update LastUpdated field after saving the file - filter.LastUpdated = filter.LastTimeUpdated() - return err -} - -func (filter *filter) saveAndBackupOld() error { - filterFilePath := filter.Path() - err := os.Rename(filterFilePath, filterFilePath+".old") - if err != nil && !os.IsNotExist(err) { - return err - } - return filter.save() -} - // loads filter contents from the file in dataDir -func (filter *filter) load() error { +func (f *Filtering) load(filter *filter) error { filterFilePath := filter.Path() log.Tracef("Loading filter %d contents to: %s", filter.ID, filterFilePath) @@ -570,17 +616,19 @@ func (filter *filter) load() error { return err } - filterFileContents, err := ioutil.ReadFile(filterFilePath) + file, err := os.Open(filterFilePath) if err != nil { return err } + defer file.Close() + st, _ := file.Stat() - log.Tracef("File %s, id %d, length %d", filterFilePath, filter.ID, len(filterFileContents)) - rulesCount, _ := parseFilterContents(filterFileContents) + log.Tracef("File %s, id %d, length %d", + filterFilePath, filter.ID, st.Size()) + rulesCount, checksum, _ := f.parseFilterContents(file) filter.RulesCount = rulesCount - filter.Data = nil - filter.checksum = crc32.ChecksumIEEE(filterFileContents) + filter.checksum = checksum filter.LastUpdated = filter.LastTimeUpdated() return nil @@ -588,8 +636,8 @@ func (filter *filter) load() error { // Clear filter rules func (filter *filter) unload() { - filter.Data = nil filter.RulesCount = 0 + filter.checksum = 0 } // Path to the filter contents diff --git a/home/filter_test.go b/home/filter_test.go index edda556a..7449f037 100644 --- a/home/filter_test.go +++ b/home/filter_test.go @@ -12,29 +12,27 @@ import ( func TestFilters(t *testing.T) { dir := prepareTestDir() defer func() { _ = os.RemoveAll(dir) 
}() - Context = homeContext{} Context.workDir = dir Context.client = &http.Client{ - Timeout: time.Minute * 5, + Timeout: 5 * time.Second, } + Context.filters.Init() f := filter{ URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", } // download - ok, err := f.update() - assert.True(t, ok && err == nil) + ok, err := Context.filters.update(&f) + assert.Equal(t, nil, err) + assert.True(t, ok) // refresh - ok, err = f.update() + ok, err = Context.filters.update(&f) assert.True(t, !ok && err == nil) - err = f.save() - assert.True(t, err == nil) - - err = f.load() + err = Context.filters.load(&f) assert.True(t, err == nil) f.unload() diff --git a/home/home.go b/home/home.go index e30d4f7a..760d48ed 100644 --- a/home/home.go +++ b/home/home.go @@ -69,6 +69,7 @@ type homeContext struct { auth *Auth // HTTP authentication module httpServer *http.Server // HTTP module httpsServer HTTPSServer // HTTPS module + filters Filtering // Runtime properties // -- diff --git a/util/helpers.go b/util/helpers.go index c50c940d..45b3311e 100644 --- a/util/helpers.go +++ b/util/helpers.go @@ -9,10 +9,6 @@ import ( "strings" ) -// --------------------- -// general helpers -// --------------------- - // fileExists returns TRUE if file exists func FileExists(fn string) bool { _, err := os.Stat(fn) @@ -33,9 +29,6 @@ func RunCommand(command string, arguments ...string) (int, string, error) { return cmd.ProcessState.ExitCode(), string(out), nil } -// --------------------- -// debug logging helpers -// --------------------- func FuncName() string { pc := make([]uintptr, 10) // at least 1 entry needed runtime.Callers(2, pc) @@ -57,3 +50,11 @@ func SplitNext(str *string, splitBy byte) string { } return strings.TrimSpace(s) } + +// MinInt - return the minimum value +func MinInt(a, b int) int { + if a < b { + return a + } + return b +}