diff --git a/app.go b/app.go index 91f4fbdd..ba8aee3c 100644 --- a/app.go +++ b/app.go @@ -12,7 +12,6 @@ import ( "strconv" "sync" "syscall" - "time" "github.com/AdguardTeam/golibs/log" "github.com/gobuffalo/packr" @@ -101,24 +100,7 @@ func run(args options) { config.BindPort = args.bindPort } - // Load filters from the disk - // And if any filter has zero ID, assign a new one - for i := range config.Filters { - filter := &config.Filters[i] // otherwise we're operating on a copy - if filter.ID == 0 { - filter.ID = assignUniqueFilterID() - } - err = filter.load() - if err != nil { - // This is okay for the first start, the filter will be loaded later - log.Debug("Couldn't load filter %d contents due to %s", filter.ID, err) - // clear LastUpdated so it gets fetched right away - } - - if len(filter.Rules) == 0 { - filter.LastUpdated = time.Time{} - } - } + loadFilters() // Save the updated config err = config.write() @@ -145,11 +127,6 @@ func run(args options) { // Update filters we've just loaded right away, don't wait for periodic update timer go func() { refreshFiltersIfNecessary(false) - // Save the updated config - err := config.write() - if err != nil { - log.Fatal(err) - } }() // Schedule automatic filters updates go periodicallyRefreshFilters() diff --git a/control.go b/control.go index ad2da0ca..be0d61b2 100644 --- a/control.go +++ b/control.go @@ -576,11 +576,9 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) { } // Check for duplicates - for i := range config.Filters { - if config.Filters[i].URL == f.URL { - httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", f.URL) - return - } + if filterExists(f.URL) { + httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", f.URL) + return } // Set necessary properties @@ -588,7 +586,7 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) { f.Enabled = true // Download the filter contents - ok, err := f.update(true) + ok, err := f.update() if err != nil { httpError(w, http.StatusBadRequest, "Couldn't fetch filter from url %s: %s", f.URL, err) return @@ -611,7 +609,11 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) { // URL is deemed valid, append it to filters, update config, write new filter file and tell dns to reload it // TODO: since we directly feed filters in-memory, revisit if writing configs is always necessary - config.Filters = append(config.Filters, f) + if !filterAdd(f) { + httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", f.URL) + return + } + err = writeAllConfigs() if err != nil { httpError(w, http.StatusInternalServerError, "Couldn't write config file: %s", err) @@ -621,6 +623,7 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) { err = reconfigureDNSServer() if err != nil { httpError(w, http.StatusInternalServerError, "Couldn't reconfigure the DNS server: %s", err) + return } _, err = fmt.Fprintf(w, "OK %d rules\n", f.RulesCount) @@ -649,6 +652,7 @@ func handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) { } // go through each element and delete if url matches + config.Lock() newFilters := config.Filters[:0] for _, filter := range config.Filters { if filter.URL != url { @@ -664,6 +668,7 @@ func handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) { } // Update the configuration after removing filter files config.Filters = newFilters + config.Unlock() httpUpdateConfigReloadDNSReturnOK(w, r) } @@ -686,22 +691,12 @@ func handleFilteringEnableURL(w http.ResponseWriter, r *http.Request) { return } - found := false - for i := range config.Filters { - filter := &config.Filters[i] // otherwise we will be operating on a copy - if filter.URL == url { - filter.Enabled = true - found = true - } - } - + found := filterEnable(url, true) if !found { http.Error(w, "URL parameter was not previously added", http.StatusBadRequest) return } - // kick off refresh of rules from new URLs - refreshFiltersIfNecessary(false) httpUpdateConfigReloadDNSReturnOK(w, r) } @@ -724,15 +719,7 @@ func handleFilteringDisableURL(w http.ResponseWriter, r *http.Request) { return } - found := false - for i := range config.Filters { - filter := &config.Filters[i] // otherwise we will be operating on a copy - if filter.URL == url { - filter.Enabled = false - found = true - } - } - + found := filterEnable(url, false) if !found { http.Error(w, "URL parameter was not previously added", http.StatusBadRequest) return diff --git a/filter.go b/filter.go index 9625b207..819e683e 100644 --- a/filter.go +++ b/filter.go @@ -2,10 +2,10 @@ package main import ( "fmt" + "hash/crc32" "io/ioutil" "os" "path/filepath" - "reflect" "regexp" "strconv" "strings" @@ -28,6 +28,7 @@ type filter struct { Name string `json:"name" yaml:"name"` RulesCount int `json:"rulesCount" yaml:"-"` LastUpdated time.Time `json:"lastUpdated,omitempty" yaml:"-"` + checksum uint32 // checksum of the file data dnsfilter.Filter `yaml:",inline"` } @@ -43,6 +44,88 @@ func userFilter() filter { } } +// Enable or disable a filter +func filterEnable(url string, enable bool) bool { + r := false + config.Lock() + for i := range config.Filters { + filter := &config.Filters[i] // otherwise we will be operating on a copy + if filter.URL == url { + filter.Enabled = enable + if enable { + e := filter.load() + if e != nil { + // This isn't a fatal error, + // because it may occur when someone removes the file from disk. + // In this case the periodic update task will try to download the file. + filter.LastUpdated = time.Time{} + log.Tracef("%s filter load: %v", url, e) + } + } else { + filter.unload() + } + r = true + break + } + } + config.Unlock() + return r +} + +// Return TRUE if a filter with this URL exists +func filterExists(url string) bool { + r := false + config.RLock() + for i := range config.Filters { + if config.Filters[i].URL == url { + r = true + break + } + } + config.RUnlock() + return r +} + +// Add a filter +// Return FALSE if a filter with this URL exists +func filterAdd(f filter) bool { + config.Lock() + + // Check for duplicates + for i := range config.Filters { + if config.Filters[i].URL == f.URL { + config.Unlock() + return false + } + } + + config.Filters = append(config.Filters, f) + config.Unlock() + return true +} + +// Load filters from the disk +// And if any filter has zero ID, assign a new one +func loadFilters() { + for i := range config.Filters { + filter := &config.Filters[i] // otherwise we're operating on a copy + if filter.ID == 0 { + filter.ID = assignUniqueFilterID() + } + + if !filter.Enabled { + // No need to load a filter that is not enabled + continue + } + + err := filter.load() + if err != nil { + // This is okay for the first start, the filter will be loaded later + log.Debug("Couldn't load filter %d contents due to %s", filter.ID, err) + } + } +} + func deduplicateFilters() { // Deduplicate filters i := 0 // output index, used for deletion later @@ -84,41 +167,83 @@ func periodicallyRefreshFilters() { // Checks filters updates if necessary // If force is true, it ignores the filter.LastUpdated field value +// +// Algorithm: +// . Get the list of filters to be updated +// . For each filter run the download and checksum check operation +// . If filter data hasn't changed, set new update time +// . If filter data has changed, parse it, save it on disk, set new update time +// . Apply changes to the current configuration +// . Restart server func refreshFiltersIfNecessary(force bool) int { - config.Lock() + var updateFilters []filter - // fetch URLs - updateCount := 0 + config.RLock() for i := range config.Filters { - filter := &config.Filters[i] // otherwise we will be operating on a copy + f := &config.Filters[i] // otherwise we will be operating on a copy - if filter.ID == 0 { // protect against users modifying the yaml and removing the ID - filter.ID = assignUniqueFilterID() + if !f.Enabled { + continue } - if len(filter.Rules) == 0 { - // Try reloading filter from the disk before updating - // This is useful for the case when we simply enable a previously downloaded filter - _ = filter.load() + if !force && time.Since(f.LastUpdated) <= updatePeriod { + continue } - updated, err := filter.update(force) + var uf filter + uf.ID = f.ID + uf.URL = f.URL + uf.checksum = f.checksum + updateFilters = append(updateFilters, uf) + } + config.RUnlock() + + updateCount := 0 + for i := range updateFilters { + uf := &updateFilters[i] + updated, err := uf.update() if err != nil { - log.Printf("Failed to update filter %s: %s\n", filter.URL, err) + log.Printf("Failed to update filter %s: %s\n", uf.URL, err) continue } if updated { // Saving it to the filters dir now - err = filter.save() + err = uf.save() if err != nil { - log.Printf("Failed to save the updated filter %d: %s", filter.ID, err) + log.Printf("Failed to save the updated filter %d: %s", uf.ID, err) continue } + } else { + mtime := time.Now() + e := os.Chtimes(uf.Path(), mtime, mtime) + if e != nil { + log.Error("os.Chtimes(): %v", e) + } + uf.LastUpdated = mtime + } + + config.Lock() + for k := range config.Filters { + f := &config.Filters[k] + if f.ID != uf.ID || f.URL != uf.URL { + continue + } + f.LastUpdated = uf.LastUpdated + if !updated { + continue + } + + log.Info("Updated filter #%d. Rules: %d -> %d", + f.ID, f.RulesCount, uf.RulesCount) + f.Name = uf.Name + f.Rules = uf.Rules + f.RulesCount = uf.RulesCount + f.checksum = uf.checksum updateCount++ } + config.Unlock() } - config.Unlock() if updateCount > 0 && isRunning() { err := reconfigureDNSServer() @@ -139,13 +264,19 @@ func parseFilterContents(contents []byte) (int, string, []string) { // Count lines in the filter for _, line := range lines { + line = strings.TrimSpace(line) - if len(line) > 0 && line[0] == '!' { - if m := filterTitleRegexp.FindAllStringSubmatch(line, -1); len(m) > 0 && len(m[0]) >= 2 && !seenTitle { + if len(line) == 0 { + continue + } + + if line[0] == '!' { + m := filterTitleRegexp.FindAllStringSubmatch(line, -1) + if len(m) > 0 && len(m[0]) >= 2 && !seenTitle { name = m[0][1] seenTitle = true } - } else if len(line) != 0 { + } else { rulesCount++ } } @@ -153,20 +284,8 @@ func parseFilterContents(contents []byte) (int, string, []string) { return rulesCount, name, lines } -// Checks for filters updates -// If "force" is true -- does not check the filter's LastUpdated field -// Call "save" to persist the filter contents -func (filter *filter) update(force bool) (bool, error) { - if filter.ID == 0 { // protect against users deleting the ID - filter.ID = assignUniqueFilterID() - } - if !filter.Enabled { - return false, nil - } - if !force && time.Since(filter.LastUpdated) <= updatePeriod { - return false, nil - } - +// Perform upgrade on a filter +func (filter *filter) update() (bool, error) { log.Tracef("Downloading update for filter %d from %s", filter.ID, filter.URL) resp, err := client.Get(filter.URL) @@ -195,22 +314,22 @@ func (filter *filter) update(force bool) (bool, error) { return false, err } - // Extract filter name and count number of rules - rulesCount, filterName, rules := parseFilterContents(body) - - if filterName != "" { - filter.Name = filterName - } - // Check if the filter has been really changed - if reflect.DeepEqual(filter.Rules, rules) { + checksum := crc32.ChecksumIEEE(body) + if filter.checksum == checksum { log.Tracef("Filter #%d at URL %s hasn't changed, not updating it", filter.ID, filter.URL) return false, nil } + // Extract filter name and count number of rules + rulesCount, filterName, rules := parseFilterContents(body) log.Printf("Filter %d has been updated: %d bytes, %d rules", filter.ID, len(body), rulesCount) + if filterName != "" { + filter.Name = filterName + } filter.RulesCount = rulesCount filter.Rules = rules + filter.checksum = checksum return true, nil } @@ -230,11 +349,6 @@ func (filter *filter) save() error { // loads filter contents from the file in dataDir func (filter *filter) load() error { - if !filter.Enabled { - // No need to load a filter that is not enabled - return nil - } - filterFilePath := filter.Path() log.Tracef("Loading filter %d contents to: %s", filter.ID, filterFilePath) @@ -253,11 +367,18 @@ func (filter *filter) load() error { filter.RulesCount = rulesCount filter.Rules = rules + filter.checksum = crc32.ChecksumIEEE(filterFileContents) filter.LastUpdated = filter.LastTimeUpdated() return nil } +// Clear filter rules +func (filter *filter) unload() { + filter.Rules = []string{} + filter.RulesCount = 0 +} + // Path to the filter contents func (filter *filter) Path() string { return filepath.Join(config.ourWorkingDir, dataDir, filterDir, strconv.FormatInt(filter.ID, 10)+".txt") @@ -266,12 +387,12 @@ func (filter *filter) Path() string { // LastTimeUpdated returns the time when the filter was last time updated func (filter *filter) LastTimeUpdated() time.Time { filterFilePath := filter.Path() - if _, err := os.Stat(filterFilePath); os.IsNotExist(err) { + s, err := os.Stat(filterFilePath) + if os.IsNotExist(err) { // if the filter file does not exist, return 0001-01-01 return time.Time{} } - s, err := os.Stat(filterFilePath) if err != nil { // if the filter file does not exist, return 0001-01-01 return time.Time{}