Protect against users deleting the filter ID's in the config file.
Incidentally, it also simplifies upgrade schema from 0 to 1.
This commit is contained in:
parent
6cb991fe7f
commit
701fd10c1c
39
app.go
39
app.go
|
@ -9,6 +9,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gobuffalo/packr"
|
"github.com/gobuffalo/packr"
|
||||||
"golang.org/x/crypto/ssh/terminal"
|
"golang.org/x/crypto/ssh/terminal"
|
||||||
|
@ -135,6 +136,34 @@ func main() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Load filters from the disk
|
||||||
|
// And if any filter has zero ID, assign a new one
|
||||||
|
for i := range config.Filters {
|
||||||
|
filter := &config.Filters[i] // otherwise we're operating on a copy
|
||||||
|
if filter.ID == 0 {
|
||||||
|
filter.ID = assignUniqueFilterID()
|
||||||
|
}
|
||||||
|
err := filter.load()
|
||||||
|
if err != nil {
|
||||||
|
// This is okay for the first start, the filter will be loaded later
|
||||||
|
log.Printf("Couldn't load filter %d contents due to %s", filter.ID, err)
|
||||||
|
// clear LastUpdated so it gets fetched right away
|
||||||
|
}
|
||||||
|
if len(filter.Contents) == 0 {
|
||||||
|
filter.LastUpdated = time.Time{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update filters we've just loaded right away, don't wait for periodic update timer
|
||||||
|
go func() {
|
||||||
|
checkFiltersUpdates(false)
|
||||||
|
// Save the updated config
|
||||||
|
err := writeConfig()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// Eat all args so that coredns can start happily
|
// Eat all args so that coredns can start happily
|
||||||
if len(os.Args) > 1 {
|
if len(os.Args) > 1 {
|
||||||
os.Args = os.Args[:1]
|
os.Args = os.Args[:1]
|
||||||
|
@ -146,16 +175,6 @@ func main() {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load filters from the disk
|
|
||||||
for i := range config.Filters {
|
|
||||||
filter := &config.Filters[i]
|
|
||||||
err = filter.load()
|
|
||||||
if err != nil {
|
|
||||||
// This is okay for the first start, the filter will be loaded later
|
|
||||||
log.Printf("Couldn't load filter %d contents due to %s", filter.ID, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
|
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
|
||||||
|
|
||||||
runFiltersUpdatesTimer()
|
runFiltersUpdatesTimer()
|
||||||
|
|
26
config.go
26
config.go
|
@ -22,7 +22,7 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
// Just a counter that we use for incrementing the filter ID
|
// Just a counter that we use for incrementing the filter ID
|
||||||
var NextFilterId = time.Now().Unix()
|
var nextFilterID int64 = time.Now().Unix()
|
||||||
|
|
||||||
// configuration is loaded from YAML
|
// configuration is loaded from YAML
|
||||||
// field ordering is important -- yaml fields will mirror ordering from here
|
// field ordering is important -- yaml fields will mirror ordering from here
|
||||||
|
@ -74,7 +74,7 @@ type filter struct {
|
||||||
Name string `json:"name" yaml:"name"`
|
Name string `json:"name" yaml:"name"`
|
||||||
RulesCount int `json:"rulesCount" yaml:"-"`
|
RulesCount int `json:"rulesCount" yaml:"-"`
|
||||||
LastUpdated time.Time `json:"lastUpdated,omitempty" yaml:"last_updated,omitempty"`
|
LastUpdated time.Time `json:"lastUpdated,omitempty" yaml:"last_updated,omitempty"`
|
||||||
ID int64 // auto-assigned when filter is added (see NextFilterId)
|
ID int64 // auto-assigned when filter is added (see nextFilterID)
|
||||||
|
|
||||||
Contents []byte `json:"-" yaml:"-"` // not in yaml or json
|
Contents []byte `json:"-" yaml:"-"` // not in yaml or json
|
||||||
}
|
}
|
||||||
|
@ -165,12 +165,7 @@ func parseConfig() error {
|
||||||
config.Filters = config.Filters[:i]
|
config.Filters = config.Filters[:i]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the next filter ID to max(filter.ID) + 1
|
updateUniqueFilterID(config.Filters)
|
||||||
for i := range config.Filters {
|
|
||||||
if NextFilterId < config.Filters[i].ID {
|
|
||||||
NextFilterId = config.Filters[i].ID + 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -293,3 +288,18 @@ func generateCoreDNSConfigText() (string, error) {
|
||||||
configText = removeEmptyLines.ReplaceAllString(configText, "\n")
|
configText = removeEmptyLines.ReplaceAllString(configText, "\n")
|
||||||
return configText, nil
|
return configText, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set the next filter ID to max(filter.ID) + 1
|
||||||
|
func updateUniqueFilterID(filters []filter) {
|
||||||
|
for _, filter := range filters {
|
||||||
|
if nextFilterID < filter.ID {
|
||||||
|
nextFilterID = filter.ID + 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func assignUniqueFilterID() int64 {
|
||||||
|
value := nextFilterID
|
||||||
|
nextFilterID += 1
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
11
control.go
11
control.go
|
@ -343,9 +343,8 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set necessary properties
|
// Set necessary properties
|
||||||
filter.ID = NextFilterId
|
filter.ID = assignUniqueFilterID()
|
||||||
filter.Enabled = true
|
filter.Enabled = true
|
||||||
NextFilterId++
|
|
||||||
|
|
||||||
// Download the filter contents
|
// Download the filter contents
|
||||||
ok, err := filter.update(true)
|
ok, err := filter.update(true)
|
||||||
|
@ -550,6 +549,11 @@ func checkFiltersUpdates(force bool) int {
|
||||||
updateCount := 0
|
updateCount := 0
|
||||||
for i := range config.Filters {
|
for i := range config.Filters {
|
||||||
filter := &config.Filters[i] // otherwise we will be operating on a copy
|
filter := &config.Filters[i] // otherwise we will be operating on a copy
|
||||||
|
|
||||||
|
if filter.ID == 0 { // protect against users modifying the yaml and removing the ID
|
||||||
|
filter.ID = assignUniqueFilterID()
|
||||||
|
}
|
||||||
|
|
||||||
updated, err := filter.update(force)
|
updated, err := filter.update(force)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to update filter %s: %s\n", filter.URL, err)
|
log.Printf("Failed to update filter %s: %s\n", filter.URL, err)
|
||||||
|
@ -601,6 +605,9 @@ func parseFilterContents(contents []byte) (int, string) {
|
||||||
// If "force" is true -- does not check the filter's LastUpdated field
|
// If "force" is true -- does not check the filter's LastUpdated field
|
||||||
// Call "save" to persist the filter contents
|
// Call "save" to persist the filter contents
|
||||||
func (filter *filter) update(force bool) (bool, error) {
|
func (filter *filter) update(force bool) (bool, error) {
|
||||||
|
if filter.ID == 0 { // protect against users deleting the ID
|
||||||
|
filter.ID = assignUniqueFilterID()
|
||||||
|
}
|
||||||
if !filter.Enabled {
|
if !filter.Enabled {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
25
upgrade.go
25
upgrade.go
|
@ -87,30 +87,7 @@ func upgradeSchema0to1(diskConfig *map[string]interface{}) error {
|
||||||
trace("Called")
|
trace("Called")
|
||||||
|
|
||||||
// The first schema upgrade:
|
// The first schema upgrade:
|
||||||
// Added "ID" field to "filter" -- we need to populate this field now
|
// No more "dnsfilter.txt", filters are now kept in data/filters/
|
||||||
// Added "config.ourDataDir" -- where we will now store filters contents
|
|
||||||
for i := range config.Filters {
|
|
||||||
filter := &config.Filters[i] // otherwise we will be operating on a copy
|
|
||||||
|
|
||||||
// Set the filter ID
|
|
||||||
log.Printf("Seting ID=%d for filter %s", NextFilterId, filter.URL)
|
|
||||||
filter.ID = NextFilterId
|
|
||||||
NextFilterId++
|
|
||||||
|
|
||||||
// Forcibly update the filter
|
|
||||||
_, err := filter.update(true)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Saving it to the filters dir now
|
|
||||||
err = filter.save()
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// No more "dnsfilter.txt", filters are now loaded from config.ourDataDir/filters/
|
|
||||||
dnsFilterPath := filepath.Join(config.ourBinaryDir, "dnsfilter.txt")
|
dnsFilterPath := filepath.Join(config.ourBinaryDir, "dnsfilter.txt")
|
||||||
_, err := os.Stat(dnsFilterPath)
|
_, err := os.Stat(dnsFilterPath)
|
||||||
if !os.IsNotExist(err) {
|
if !os.IsNotExist(err) {
|
||||||
|
|
Loading…
Reference in New Issue