Merge pull request #137 in DNS/adguard-dns from fix/557 to master

* commit 'ec6b1f7c42c8d2fc413d29fba55430b89fcbce2d':
  Added golangci-lint configuration and prepared for the integrattion
  Added codecov, goreport
  Use EnableAll in gometalinter config
  gometalinter
This commit is contained in:
Eugene Bujak 2019-01-25 20:14:03 +03:00
commit 0161509b5f
28 changed files with 577 additions and 464 deletions

8
.codecov.yml Normal file
View File

@ -0,0 +1,8 @@
coverage:
status:
project:
default:
target: 40%
threshold: null
patch: false
changes: false

4
.gitignore vendored
View File

@ -12,5 +12,5 @@
/scripts/translations/oneskyapp.json /scripts/translations/oneskyapp.json
# Test output # Test output
dnsfilter/dnsfilter.TestLotsOfRules*.pprof dnsfilter/tests/top-1m.csv
tests/top-1m.csv dnsfilter/tests/dnsfilter.TestLotsOfRules*.pprof

55
.golangci.yml Normal file
View File

@ -0,0 +1,55 @@
# options for analysis running
run:
# default concurrency is a available CPU number
concurrency: 4
# timeout for analysis, e.g. 30s, 5m, default is 1m
deadline: 2m
# which files to skip: they will be analyzed, but issues from them
# won't be reported. Default value is empty list, but there is
# no need to include all autogenerated files, we confidently recognize
# autogenerated files. If it's not please let us know.
skip-files:
- ".*generated.*"
- dnsfilter/rule_to_regexp.go
# all available settings of specific linters
linters-settings:
errcheck:
# [deprecated] comma-separated list of pairs of the form pkg:regex
# the regex is used to ignore names within pkg. (default "fmt:.*").
# see https://github.com/kisielk/errcheck#the-deprecated-method for details
ignore: fmt:.*,net:SetReadDeadline,net/http:^Write
gocyclo:
min-complexity: 20
lll:
line-length: 200
linters:
enable-all: true
disable:
- interfacer
- gocritic
- scopelint
- gochecknoglobals
- gochecknoinits
- prealloc
- maligned
- goconst # disabled until it's possible to configure
fast: true
issues:
# List of regexps of issue texts to exclude, empty list by default.
# But independently from this option we use default exclude patterns,
# it can be disabled by `exclude-use-default: false`. To list all
# excluded by default patterns execute `golangci-lint run --help`
exclude:
# structcheck cannot detect usages while they're there
- .parentalServer. is unused
- .safeBrowsingServer. is unused
# errcheck
- Error return value of .s.closeConn. is not checked
# goconst
- string .forcesafesearch.google.com. has 3 occurrences

23
.gometalinter.json Normal file
View File

@ -0,0 +1,23 @@
{
"Vendor": true,
"Test": true,
"Deadline": "2m",
"Sort": ["linter", "severity", "path", "line"],
"Exclude": [
".*generated.*",
"dnsfilter/rule_to_regexp.go"
],
"EnableGC": true,
"Linters": {
"nakedret": {
"Command": "nakedret",
"Pattern": "^(?P<path>.*?\\.go):(?P<line>\\d+)\\s*(?P<message>.*)$"
}
},
"WarnUnmatchedDirective": true,
"EnableAll": true,
"Cyclo": 20,
"LineLength": 200
}

View File

@ -24,7 +24,11 @@ install:
script: script:
- node -v - node -v
- npm -v - npm -v
- go test ./... # Run tests
- go test -race -v -bench=. -coverprofile=coverage.txt -covermode=atomic ./...
# Make
- make build/static/index.html - make build/static/index.html
- make - make
after_success:
- bash <(curl -s https://codecov.io/bash)

View File

@ -11,11 +11,18 @@
<a href="https://adguard.com/">AdGuard.com</a> | <a href="https://adguard.com/">AdGuard.com</a> |
<a href="https://github.com/AdguardTeam/AdGuardHome/wiki">Wiki</a> | <a href="https://github.com/AdguardTeam/AdGuardHome/wiki">Wiki</a> |
<a href="https://reddit.com/r/Adguard">Reddit</a> | <a href="https://reddit.com/r/Adguard">Reddit</a> |
<a href="https://twitter.com/AdGuard">Twitter</a> <a href="https://twitter.com/AdGuard">Twitter</a> |
<a href="https://t.me/adguard_en">Telegram</a>
<br /><br /> <br /><br />
<a href="https://travis-ci.org/AdguardTeam/AdGuardHome"> <a href="https://travis-ci.org/AdguardTeam/AdGuardHome">
<img src="https://travis-ci.org/AdguardTeam/AdGuardHome.svg" alt="Build status" /> <img src="https://travis-ci.org/AdguardTeam/AdGuardHome.svg" alt="Build status" />
</a> </a>
<a href="https://codecov.io/github/AdguardTeam/AdGuardHome?branch=master">
<img src="https://img.shields.io/codecov/c/github/AdguardTeam/AdGuardHome/master.svg" alt="Code Coverage" />
</a>
<a href="https://goreportcard.com/report/AdguardTeam/AdGuardHome">
<img src="https://goreportcard.com/badge/github.com/AdguardTeam/AdGuardHome" alt="Go Report Card" />
</a>
<a href="https://github.com/AdguardTeam/AdGuardHome/releases"> <a href="https://github.com/AdguardTeam/AdGuardHome/releases">
<img src="https://img.shields.io/github/release/AdguardTeam/AdGuardHome/all.svg" alt="Latest release" /> <img src="https://img.shields.io/github/release/AdguardTeam/AdGuardHome/all.svg" alt="Latest release" />
</a> </a>

163
app.go
View File

@ -43,7 +43,89 @@ func main() {
// config can be specified, which reads options from there, but other command line flags have to override config values // config can be specified, which reads options from there, but other command line flags have to override config values
// therefore, we must do it manually instead of using a lib // therefore, we must do it manually instead of using a lib
{ loadOptions()
// Load filters from the disk
// And if any filter has zero ID, assign a new one
for i := range config.Filters {
filter := &config.Filters[i] // otherwise we're operating on a copy
if filter.ID == 0 {
filter.ID = assignUniqueFilterID()
}
err := filter.load()
if err != nil {
// This is okay for the first start, the filter will be loaded later
log.Printf("Couldn't load filter %d contents due to %s", filter.ID, err)
// clear LastUpdated so it gets fetched right away
}
if len(filter.Rules) == 0 {
filter.LastUpdated = time.Time{}
}
}
// Update filters we've just loaded right away, don't wait for periodic update timer
go func() {
refreshFiltersIfNecessary(false)
// Save the updated config
err := config.write()
if err != nil {
log.Fatal(err)
}
}()
signalChannel := make(chan os.Signal)
signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
go func() {
<-signalChannel
cleanup()
os.Exit(0)
}()
// Save the updated config
err := config.write()
if err != nil {
log.Fatal(err)
}
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
go periodicallyRefreshFilters()
http.Handle("/", optionalAuthHandler(http.FileServer(box)))
registerControlHandlers()
err = startDNSServer()
if err != nil {
log.Fatal(err)
}
err = startDHCPServer()
if err != nil {
log.Fatal(err)
}
URL := fmt.Sprintf("http://%s", address)
log.Println("Go to " + URL)
log.Fatal(http.ListenAndServe(address, nil))
}
func cleanup() {
err := stopDNSServer()
if err != nil {
log.Printf("Couldn't stop DNS server: %s", err)
}
}
func getInput() (string, error) {
scanner := bufio.NewScanner(os.Stdin)
scanner.Scan()
text := scanner.Text()
err := scanner.Err()
return text, err
}
// loadOptions reads command line arguments and initializes configuration
func loadOptions() {
var printHelp func() var printHelp func()
var configFilename *string var configFilename *string
var bindHost *string var bindHost *string
@ -130,85 +212,6 @@ func main() {
} }
} }
// Load filters from the disk
// And if any filter has zero ID, assign a new one
for i := range config.Filters {
filter := &config.Filters[i] // otherwise we're operating on a copy
if filter.ID == 0 {
filter.ID = assignUniqueFilterID()
}
err := filter.load()
if err != nil {
// This is okay for the first start, the filter will be loaded later
log.Printf("Couldn't load filter %d contents due to %s", filter.ID, err)
// clear LastUpdated so it gets fetched right away
}
if len(filter.Rules) == 0 {
filter.LastUpdated = time.Time{}
}
}
// Update filters we've just loaded right away, don't wait for periodic update timer
go func() {
refreshFiltersIfNeccessary(false)
// Save the updated config
err := config.write()
if err != nil {
log.Fatal(err)
}
}()
signalChannel := make(chan os.Signal)
signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
go func() {
<-signalChannel
cleanup()
os.Exit(0)
}()
// Save the updated config
err := config.write()
if err != nil {
log.Fatal(err)
}
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
go periodicallyRefreshFilters()
http.Handle("/", optionalAuthHandler(http.FileServer(box)))
registerControlHandlers()
err = startDNSServer()
if err != nil {
log.Fatal(err)
}
err = startDHCPServer()
if err != nil {
log.Fatal(err)
}
URL := fmt.Sprintf("http://%s", address)
log.Println("Go to " + URL)
log.Fatal(http.ListenAndServe(address, nil))
}
func cleanup() {
err := stopDNSServer()
if err != nil {
log.Printf("Couldn't stop DNS server: %s", err)
}
}
func getInput() (string, error) {
scanner := bufio.NewScanner(os.Stdin)
scanner.Scan()
text := scanner.Text()
err := scanner.Err()
return text, err
}
func promptAndGet(prompt string) (string, error) { func promptAndGet(prompt string) (string, error) {
for { for {
fmt.Print(prompt) fmt.Print(prompt)

View File

@ -10,7 +10,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/dnsforward" "github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/hmage/golibs/log" "github.com/hmage/golibs/log"
"gopkg.in/yaml.v2" yaml "gopkg.in/yaml.v2"
) )
const ( const (
@ -21,7 +21,7 @@ const (
// configuration is loaded from YAML // configuration is loaded from YAML
// field ordering is important -- yaml fields will mirror ordering from here // field ordering is important -- yaml fields will mirror ordering from here
type configuration struct { type configuration struct {
ourConfigFilename string // Config filename (can be overriden via the command line arguments) ourConfigFilename string // Config filename (can be overridden via the command line arguments)
ourBinaryDir string // Location of our directory, used to protect against CWD being somewhere else ourBinaryDir string // Location of our directory, used to protect against CWD being somewhere else
BindHost string `yaml:"bind_host"` BindHost string `yaml:"bind_host"`

View File

@ -15,7 +15,7 @@ import (
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/hmage/golibs/log" "github.com/hmage/golibs/log"
"github.com/miekg/dns" "github.com/miekg/dns"
"gopkg.in/asaskevich/govalidator.v4" govalidator "gopkg.in/asaskevich/govalidator.v4"
) )
const updatePeriod = time.Minute * 30 const updatePeriod = time.Minute * 30
@ -40,27 +40,26 @@ func writeAllConfigsAndReloadDNS() error {
log.Printf("Couldn't write all configs: %s", err) log.Printf("Couldn't write all configs: %s", err)
return err return err
} }
reconfigureDNSServer() return reconfigureDNSServer()
return nil
} }
func httpUpdateConfigReloadDNSReturnOK(w http.ResponseWriter, r *http.Request) { func httpUpdateConfigReloadDNSReturnOK(w http.ResponseWriter, r *http.Request) {
err := writeAllConfigsAndReloadDNS() err := writeAllConfigsAndReloadDNS()
if err != nil { if err != nil {
errortext := fmt.Sprintf("Couldn't write config file: %s", err) errorText := fmt.Sprintf("Couldn't write config file: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, http.StatusInternalServerError) http.Error(w, errorText, http.StatusInternalServerError)
return return
} }
returnOK(w, r) returnOK(w)
} }
func returnOK(w http.ResponseWriter, r *http.Request) { func returnOK(w http.ResponseWriter) {
_, err := fmt.Fprintf(w, "OK\n") _, err := fmt.Fprintf(w, "OK\n")
if err != nil { if err != nil {
errortext := fmt.Sprintf("Couldn't write body: %s", err) errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, http.StatusInternalServerError) http.Error(w, errorText, http.StatusInternalServerError)
} }
} }
@ -79,17 +78,17 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
jsonVal, err := json.Marshal(data) jsonVal, err := json.Marshal(data)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err) errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal) _, err = w.Write(jsonVal)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err) errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
} }
@ -147,7 +146,13 @@ func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
http.Error(w, errorText, http.StatusInternalServerError) http.Error(w, errorText, http.StatusInternalServerError)
return return
} }
reconfigureDNSServer() err = reconfigureDNSServer()
if err != nil {
errorText := fmt.Sprintf("Couldn't reconfigure the DNS server: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
return
}
_, err = fmt.Fprintf(w, "OK %d servers\n", len(hosts)) _, err = fmt.Fprintf(w, "OK %d servers\n", len(hosts))
if err != nil { if err != nil {
errorText := fmt.Sprintf("Couldn't write body: %s", err) errorText := fmt.Sprintf("Couldn't write body: %s", err)
@ -206,7 +211,7 @@ func checkDNS(input string) error {
log.Printf("Checking if DNS %s works...", input) log.Printf("Checking if DNS %s works...", input)
u, err := upstream.AddressToUpstream(input, "", dnsforward.DefaultTimeout) u, err := upstream.AddressToUpstream(input, "", dnsforward.DefaultTimeout)
if err != nil { if err != nil {
return fmt.Errorf("Failed to choose upstream for %s: %s", input, err) return fmt.Errorf("failed to choose upstream for %s: %s", input, err)
} }
req := dns.Msg{} req := dns.Msg{}
@ -243,9 +248,9 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
resp, err := client.Get(versionCheckURL) resp, err := client.Get(versionCheckURL)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err) errorText := fmt.Sprintf("Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, http.StatusBadGateway) http.Error(w, errorText, http.StatusBadGateway)
return return
} }
if resp != nil && resp.Body != nil { if resp != nil && resp.Body != nil {
@ -255,18 +260,18 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
// read the body entirely // read the body entirely
body, err := ioutil.ReadAll(resp.Body) body, err := ioutil.ReadAll(resp.Body)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Couldn't read response body from %s: %s", versionCheckURL, err) errorText := fmt.Sprintf("Couldn't read response body from %s: %s", versionCheckURL, err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, http.StatusBadGateway) http.Error(w, errorText, http.StatusBadGateway)
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(body) _, err = w.Write(body)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Couldn't write body: %s", err) errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, http.StatusInternalServerError) http.Error(w, errorText, http.StatusInternalServerError)
} }
versionCheckLastTime = now versionCheckLastTime = now
@ -299,44 +304,44 @@ func handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
config.RUnlock() config.RUnlock()
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err) errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal) _, err = w.Write(jsonVal)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err) errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
} }
func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) { func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
filter := filter{} f := filter{}
err := json.NewDecoder(r.Body).Decode(&filter) err := json.NewDecoder(r.Body).Decode(&f)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "Failed to parse request body json: %s", err) httpError(w, http.StatusBadRequest, "Failed to parse request body json: %s", err)
return return
} }
if len(filter.URL) == 0 { if len(f.URL) == 0 {
http.Error(w, "URL parameter was not specified", 400) http.Error(w, "URL parameter was not specified", 400)
return return
} }
if valid := govalidator.IsRequestURL(filter.URL); !valid { if valid := govalidator.IsRequestURL(f.URL); !valid {
http.Error(w, "URL parameter is not valid request URL", 400) http.Error(w, "URL parameter is not valid request URL", 400)
return return
} }
// Check for duplicates // Check for duplicates
for i := range config.Filters { for i := range config.Filters {
if config.Filters[i].URL == filter.URL { if config.Filters[i].URL == f.URL {
errorText := fmt.Sprintf("Filter URL already added -- %s", filter.URL) errorText := fmt.Sprintf("Filter URL already added -- %s", f.URL)
log.Println(errorText) log.Println(errorText)
http.Error(w, errorText, http.StatusBadRequest) http.Error(w, errorText, http.StatusBadRequest)
return return
@ -344,42 +349,42 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
} }
// Set necessary properties // Set necessary properties
filter.ID = assignUniqueFilterID() f.ID = assignUniqueFilterID()
filter.Enabled = true f.Enabled = true
// Download the filter contents // Download the filter contents
ok, err := filter.update(true) ok, err := f.update(true)
if err != nil { if err != nil {
errorText := fmt.Sprintf("Couldn't fetch filter from url %s: %s", filter.URL, err) errorText := fmt.Sprintf("Couldn't fetch filter from url %s: %s", f.URL, err)
log.Println(errorText) log.Println(errorText)
http.Error(w, errorText, http.StatusBadRequest) http.Error(w, errorText, http.StatusBadRequest)
return return
} }
if filter.RulesCount == 0 { if f.RulesCount == 0 {
errorText := fmt.Sprintf("Filter at the url %s has no rules (maybe it points to blank page?)", filter.URL) errorText := fmt.Sprintf("Filter at the url %s has no rules (maybe it points to blank page?)", f.URL)
log.Println(errorText) log.Println(errorText)
http.Error(w, errorText, http.StatusBadRequest) http.Error(w, errorText, http.StatusBadRequest)
return return
} }
if !ok { if !ok {
errorText := fmt.Sprintf("Filter at the url %s is invalid (maybe it points to blank page?)", filter.URL) errorText := fmt.Sprintf("Filter at the url %s is invalid (maybe it points to blank page?)", f.URL)
log.Println(errorText) log.Println(errorText)
http.Error(w, errorText, http.StatusBadRequest) http.Error(w, errorText, http.StatusBadRequest)
return return
} }
// Save the filter contents // Save the filter contents
err = filter.save() err = f.save()
if err != nil { if err != nil {
errorText := fmt.Sprintf("Failed to save filter %d due to %s", filter.ID, err) errorText := fmt.Sprintf("Failed to save filter %d due to %s", f.ID, err)
log.Println(errorText) log.Println(errorText)
http.Error(w, errorText, http.StatusBadRequest) http.Error(w, errorText, http.StatusBadRequest)
return return
} }
// URL is deemed valid, append it to filters, update config, write new filter file and tell dns to reload it // URL is deemed valid, append it to filters, update config, write new filter file and tell dns to reload it
// TODO: since we directly feed filters in-memory, revisit if writing configs is always neccessary // TODO: since we directly feed filters in-memory, revisit if writing configs is always necessary
config.Filters = append(config.Filters, filter) config.Filters = append(config.Filters, f)
err = writeAllConfigs() err = writeAllConfigs()
if err != nil { if err != nil {
errorText := fmt.Sprintf("Couldn't write config file: %s", err) errorText := fmt.Sprintf("Couldn't write config file: %s", err)
@ -388,9 +393,14 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
return return
} }
reconfigureDNSServer() err = reconfigureDNSServer()
if err != nil {
errorText := fmt.Sprintf("Couldn't reconfigure the DNS server: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
}
_, err = fmt.Fprintf(w, "OK %d rules\n", filter.RulesCount) _, err = fmt.Fprintf(w, "OK %d rules\n", f.RulesCount)
if err != nil { if err != nil {
errorText := fmt.Sprintf("Couldn't write body: %s", err) errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errorText) log.Println(errorText)
@ -473,7 +483,7 @@ func handleFilteringEnableURL(w http.ResponseWriter, r *http.Request) {
} }
// kick off refresh of rules from new URLs // kick off refresh of rules from new URLs
refreshFiltersIfNeccessary(false) refreshFiltersIfNecessary(false)
httpUpdateConfigReloadDNSReturnOK(w, r) httpUpdateConfigReloadDNSReturnOK(w, r)
} }
@ -529,7 +539,7 @@ func handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
func handleFilteringRefresh(w http.ResponseWriter, r *http.Request) { func handleFilteringRefresh(w http.ResponseWriter, r *http.Request) {
force := r.URL.Query().Get("force") force := r.URL.Query().Get("force")
updated := refreshFiltersIfNeccessary(force != "") updated := refreshFiltersIfNecessary(force != "")
fmt.Fprintf(w, "OK %d filters updated\n", updated) fmt.Fprintf(w, "OK %d filters updated\n", updated)
} }
@ -553,17 +563,17 @@ func handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Request) {
} }
jsonVal, err := json.Marshal(data) jsonVal, err := json.Marshal(data)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err) errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal) _, err = w.Write(jsonVal)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err) errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
} }
@ -574,9 +584,9 @@ func handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Request) {
func handleParentalEnable(w http.ResponseWriter, r *http.Request) { func handleParentalEnable(w http.ResponseWriter, r *http.Request) {
parameters, err := parseParametersFromBody(r.Body) parameters, err := parseParametersFromBody(r.Body)
if err != nil { if err != nil {
errortext := fmt.Sprintf("failed to parse parameters from body: %s", err) errorText := fmt.Sprintf("failed to parse parameters from body: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 400) http.Error(w, errorText, 400)
return return
} }
@ -631,18 +641,18 @@ func handleParentalStatus(w http.ResponseWriter, r *http.Request) {
} }
jsonVal, err := json.Marshal(data) jsonVal, err := json.Marshal(data)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err) errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal) _, err = w.Write(jsonVal)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err) errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
} }
@ -667,18 +677,18 @@ func handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) {
} }
jsonVal, err := json.Marshal(data) jsonVal, err := json.Marshal(data)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err) errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal) _, err = w.Write(jsonVal)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err) errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
} }

18
dhcp.go
View File

@ -58,7 +58,10 @@ func handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
} }
} }
if !newconfig.Enabled { if !newconfig.Enabled {
dhcpServer.Stop() err := dhcpServer.Stop()
if err != nil {
log.Printf("failed to stop the DHCP server: %s", err)
}
} }
config.DHCP = newconfig config.DHCP = newconfig
httpUpdateConfigReloadDNSReturnOK(w, r) httpUpdateConfigReloadDNSReturnOK(w, r)
@ -73,11 +76,6 @@ func handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
return return
} }
type address struct {
IP string
Netmask string
}
type responseInterface struct { type responseInterface struct {
Name string `json:"name"` Name string `json:"name"`
MTU int `json:"mtu"` MTU int `json:"mtu"`
@ -103,9 +101,9 @@ func handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
MTU: ifaces[i].MTU, MTU: ifaces[i].MTU,
HardwareAddr: ifaces[i].HardwareAddr.String(), HardwareAddr: ifaces[i].HardwareAddr.String(),
} }
addrs, err := ifaces[i].Addrs() addrs, errAddrs := ifaces[i].Addrs()
if err != nil { if errAddrs != nil {
httpError(w, http.StatusInternalServerError, "Failed to get addresses for interface %v: %s", ifaces[i].Name, err) httpError(w, http.StatusInternalServerError, "Failed to get addresses for interface %v: %s", ifaces[i].Name, errAddrs)
return return
} }
for _, addr := range addrs { for _, addr := range addrs {
@ -157,7 +155,7 @@ func handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Request) {
} }
func startDHCPServer() error { func startDHCPServer() error {
if config.DHCP.Enabled == false { if !config.DHCP.Enabled {
// not enabled, don't do anything // not enabled, don't do anything
return nil return nil
} }

View File

@ -13,6 +13,8 @@ import (
"github.com/krolaw/dhcp4" "github.com/krolaw/dhcp4"
) )
// CheckIfOtherDHCPServersPresent sends a DHCP request to the specified network interface,
// and waits for a response for a period defined by defaultDiscoverTime
func CheckIfOtherDHCPServersPresent(ifaceName string) (bool, error) { func CheckIfOtherDHCPServersPresent(ifaceName string) (bool, error) {
iface, err := net.InterfaceByName(ifaceName) iface, err := net.InterfaceByName(ifaceName)
if err != nil { if err != nil {
@ -30,8 +32,8 @@ func CheckIfOtherDHCPServersPresent(ifaceName string) (bool, error) {
dst := "255.255.255.255:67" dst := "255.255.255.255:67"
// form a DHCP request packet, try to emulate existing client as much as possible // form a DHCP request packet, try to emulate existing client as much as possible
xId := make([]byte, 8) xID := make([]byte, 8)
n, err := rand.Read(xId) n, err := rand.Read(xID)
if n != 8 && err == nil { if n != 8 && err == nil {
err = fmt.Errorf("Generated less than 8 bytes") err = fmt.Errorf("Generated less than 8 bytes")
} }
@ -57,16 +59,16 @@ func CheckIfOtherDHCPServersPresent(ifaceName string) (bool, error) {
maxUDPsizeRaw := make([]byte, 2) maxUDPsizeRaw := make([]byte, 2)
binary.BigEndian.PutUint16(maxUDPsizeRaw, 1500) binary.BigEndian.PutUint16(maxUDPsizeRaw, 1500)
leaseTimeRaw := make([]byte, 4) leaseTimeRaw := make([]byte, 4)
leaseTime := uint32(math.RoundToEven(time.Duration(time.Hour * 24 * 90).Seconds())) leaseTime := uint32(math.RoundToEven((time.Hour * 24 * 90).Seconds()))
binary.BigEndian.PutUint32(leaseTimeRaw, leaseTime) binary.BigEndian.PutUint32(leaseTimeRaw, leaseTime)
options := []dhcp4.Option{ options := []dhcp4.Option{
{dhcp4.OptionParameterRequestList, requestList}, {Code: dhcp4.OptionParameterRequestList, Value: requestList},
{dhcp4.OptionMaximumDHCPMessageSize, maxUDPsizeRaw}, {Code: dhcp4.OptionMaximumDHCPMessageSize, Value: maxUDPsizeRaw},
{dhcp4.OptionClientIdentifier, append([]byte{0x01}, iface.HardwareAddr...)}, {Code: dhcp4.OptionClientIdentifier, Value: append([]byte{0x01}, iface.HardwareAddr...)},
{dhcp4.OptionIPAddressLeaseTime, leaseTimeRaw}, {Code: dhcp4.OptionIPAddressLeaseTime, Value: leaseTimeRaw},
{dhcp4.OptionHostName, []byte(hostname)}, {Code: dhcp4.OptionHostName, Value: []byte(hostname)},
} }
packet := dhcp4.RequestPacket(dhcp4.Discover, iface.HardwareAddr, nil, xId, false, options) packet := dhcp4.RequestPacket(dhcp4.Discover, iface.HardwareAddr, nil, xID, false, options)
// resolve 0.0.0.0:68 // resolve 0.0.0.0:68
udpAddr, err := net.ResolveUDPAddr("udp4", src) udpAddr, err := net.ResolveUDPAddr("udp4", src)
@ -98,7 +100,7 @@ func CheckIfOtherDHCPServersPresent(ifaceName string) (bool, error) {
} }
// send to 255.255.255.255:67 // send to 255.255.255.255:67
n, err = c.WriteTo(packet, dstAddr) _, err = c.WriteTo(packet, dstAddr)
// spew.Dump(n, err) // spew.Dump(n, err)
if err != nil { if err != nil {
return false, wrapErrPrint(err, "Couldn't send a packet to %s", dst) return false, wrapErrPrint(err, "Couldn't send a packet to %s", dst)

View File

@ -13,6 +13,7 @@ import (
const defaultDiscoverTime = time.Second * 3 const defaultDiscoverTime = time.Second * 3
// Lease contains the necessary information about a DHCP lease
// field ordering is important -- yaml fields will mirror ordering from here // field ordering is important -- yaml fields will mirror ordering from here
type Lease struct { type Lease struct {
HWAddr net.HardwareAddr `json:"mac" yaml:"hwaddr"` HWAddr net.HardwareAddr `json:"mac" yaml:"hwaddr"`
@ -21,6 +22,7 @@ type Lease struct {
Expiry time.Time `json:"expires"` Expiry time.Time `json:"expires"`
} }
// ServerConfig - DHCP server configuration
// field ordering is important -- yaml fields will mirror ordering from here // field ordering is important -- yaml fields will mirror ordering from here
type ServerConfig struct { type ServerConfig struct {
Enabled bool `json:"enabled" yaml:"enabled"` Enabled bool `json:"enabled" yaml:"enabled"`
@ -32,6 +34,7 @@ type ServerConfig struct {
LeaseDuration uint `json:"lease_duration" yaml:"lease_duration"` // in seconds LeaseDuration uint `json:"lease_duration" yaml:"lease_duration"` // in seconds
} }
// Server - the current state of the DHCP server
type Server struct { type Server struct {
conn *filterConn // listening UDP socket conn *filterConn // listening UDP socket
@ -80,6 +83,7 @@ func (s *Server) Start(config *ServerConfig) error {
s.leaseStart, err = parseIPv4(s.RangeStart) s.leaseStart, err = parseIPv4(s.RangeStart)
if err != nil { if err != nil {
s.closeConn() // in case it was already started s.closeConn() // in case it was already started
return wrapErrPrint(err, "Failed to parse range start address %s", s.RangeStart) return wrapErrPrint(err, "Failed to parse range start address %s", s.RangeStart)
} }
@ -137,6 +141,7 @@ func (s *Server) Start(config *ServerConfig) error {
return nil return nil
} }
// Stop closes the listening UDP socket
func (s *Server) Stop() error { func (s *Server) Stop() error {
if s.conn == nil { if s.conn == nil {
// nothing to do, return silently // nothing to do, return silently
@ -174,7 +179,7 @@ func (s *Server) reserveLease(p dhcp4.Packet) (*Lease, error) {
} }
// not assigned a lease, create new one, find IP from LRU // not assigned a lease, create new one, find IP from LRU
log.Tracef("Lease not found for %s: creating new one", hwaddr) log.Tracef("Lease not found for %s: creating new one", hwaddr)
ip, err := s.findFreeIP(p, hwaddr) ip, err := s.findFreeIP(hwaddr)
if err != nil { if err != nil {
return nil, wrapErrPrint(err, "Couldn't find free IP for the lease %s", hwaddr.String()) return nil, wrapErrPrint(err, "Couldn't find free IP for the lease %s", hwaddr.String())
} }
@ -198,7 +203,7 @@ func (s *Server) locateLease(p dhcp4.Packet) *Lease {
return nil return nil
} }
func (s *Server) findFreeIP(p dhcp4.Packet, hwaddr net.HardwareAddr) (net.IP, error) { func (s *Server) findFreeIP(hwaddr net.HardwareAddr) (net.IP, error) {
// if IP pool is nil, lazy initialize it // if IP pool is nil, lazy initialize it
if s.IPpool == nil { if s.IPpool == nil {
s.IPpool = make(map[[4]byte]net.HardwareAddr) s.IPpool = make(map[[4]byte]net.HardwareAddr)
@ -223,7 +228,7 @@ func (s *Server) findFreeIP(p dhcp4.Packet, hwaddr net.HardwareAddr) (net.IP, er
if foundIP == nil { if foundIP == nil {
// TODO: LRU // TODO: LRU
return nil, fmt.Errorf("Couldn't find free entry in IP pool") return nil, fmt.Errorf("couldn't find free entry in IP pool")
} }
s.reserveIP(foundIP, hwaddr) s.reserveIP(foundIP, hwaddr)
@ -249,6 +254,7 @@ func (s *Server) unreserveIP(ip net.IP) {
delete(s.IPpool, IP4) delete(s.IPpool, IP4)
} }
// ServeDHCP handles an incoming DHCP request
func (s *Server) ServeDHCP(p dhcp4.Packet, msgType dhcp4.MessageType, options dhcp4.Options) dhcp4.Packet { func (s *Server) ServeDHCP(p dhcp4.Packet, msgType dhcp4.MessageType, options dhcp4.Options) dhcp4.Packet {
log.Tracef("Got %v message", msgType) log.Tracef("Got %v message", msgType)
log.Tracef("Leases:") log.Tracef("Leases:")
@ -259,27 +265,6 @@ func (s *Server) ServeDHCP(p dhcp4.Packet, msgType dhcp4.MessageType, options dh
for ip, hwaddr := range s.IPpool { for ip, hwaddr := range s.IPpool {
log.Tracef("IP pool entry %s -> %s", net.IPv4(ip[0], ip[1], ip[2], ip[3]), hwaddr) log.Tracef("IP pool entry %s -> %s", net.IPv4(ip[0], ip[1], ip[2], ip[3]), hwaddr)
} }
// spew.Dump(s.leases, s.IPpool)
// log.Printf("Called with msgType = %v, options = %+v", msgType, options)
// spew.Dump(p)
// log.Printf("%14s %v", "p.Broadcast", p.Broadcast()) // false
// log.Printf("%14s %v", "p.CHAddr", p.CHAddr()) // 2c:f0:a2:f2:31:00
// log.Printf("%14s %v", "p.CIAddr", p.CIAddr()) // 0.0.0.0
// log.Printf("%14s %v", "p.Cookie", p.Cookie()) // [99 130 83 99]
// log.Printf("%14s %v", "p.File", p.File()) // []
// log.Printf("%14s %v", "p.Flags", p.Flags()) // [0 0]
// log.Printf("%14s %v", "p.GIAddr", p.GIAddr()) // 0.0.0.0
// log.Printf("%14s %v", "p.HLen", p.HLen()) // 6
// log.Printf("%14s %v", "p.HType", p.HType()) // 1
// log.Printf("%14s %v", "p.Hops", p.Hops()) // 0
// log.Printf("%14s %v", "p.OpCode", p.OpCode()) // BootRequest
// log.Printf("%14s %v", "p.Options", p.Options()) // [53 1 1 55 10 1 121 3 6 15 119 252 95 44 46 57 2 5 220 61 7 1 44 240 162 242 49 0 51 4 0 118 167 0 12 4 119 104 109 100 255 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
// log.Printf("%14s %v", "p.ParseOptions", p.ParseOptions()) // map[OptionParameterRequestList:[1 121 3 6 15 119 252 95 44 46] OptionDHCPMessageType:[1] OptionMaximumDHCPMessageSize:[5 220] OptionClientIdentifier:[1 44 240 162 242 49 0] OptionIPAddressLeaseTime:[0 118 167 0] OptionHostName:[119 104 109 100]]
// log.Printf("%14s %v", "p.SIAddr", p.SIAddr()) // 0.0.0.0
// log.Printf("%14s %v", "p.SName", p.SName()) // []
// log.Printf("%14s %v", "p.Secs", p.Secs()) // [0 8]
// log.Printf("%14s %v", "p.XId", p.XId()) // [211 184 20 44]
// log.Printf("%14s %v", "p.YIAddr", p.YIAddr()) // 0.0.0.0
switch msgType { switch msgType {
case dhcp4.Discover: // Broadcast Packet From Client - Can I have an IP? case dhcp4.Discover: // Broadcast Packet From Client - Can I have an IP?
@ -297,6 +282,32 @@ func (s *Server) ServeDHCP(p dhcp4.Packet, msgType dhcp4.MessageType, options dh
case dhcp4.Request: // Broadcast From Client - I'll take that IP (Also start for renewals) case dhcp4.Request: // Broadcast From Client - I'll take that IP (Also start for renewals)
// start/renew a lease -- update lease time // start/renew a lease -- update lease time
// some clients (OSX) just go right ahead and do Request first from previously known IP, if they get NAK, they restart full cycle with Discover then Request // some clients (OSX) just go right ahead and do Request first from previously known IP, if they get NAK, they restart full cycle with Discover then Request
return s.handleDHCP4Request(p, options)
case dhcp4.Decline: // Broadcast From Client - Sorry I can't use that IP
log.Tracef("Got from client: Decline")
case dhcp4.Release: // From Client, I don't need that IP anymore
log.Tracef("Got from client: Release")
case dhcp4.Inform: // From Client, I have this IP and there's nothing you can do about it
log.Tracef("Got from client: Inform")
// do nothing
// from server -- ignore those but enumerate just in case
case dhcp4.Offer: // Broadcast From Server - Here's an IP
log.Printf("SHOULD NOT HAPPEN -- FROM ANOTHER DHCP SERVER: Offer")
case dhcp4.ACK: // From Server, Yes you can have that IP
log.Printf("SHOULD NOT HAPPEN -- FROM ANOTHER DHCP SERVER: ACK")
case dhcp4.NAK: // From Server, No you cannot have that IP
log.Printf("SHOULD NOT HAPPEN -- FROM ANOTHER DHCP SERVER: NAK")
default:
log.Printf("Unknown DHCP packet detected, ignoring: %v", msgType)
return nil
}
return nil
}
func (s *Server) handleDHCP4Request(p dhcp4.Packet, options dhcp4.Options) dhcp4.Packet {
log.Tracef("Got from client: Request") log.Tracef("Got from client: Request")
if server, ok := options[dhcp4.OptionServerIdentifier]; ok && !net.IP(server).Equal(s.ipnet.IP) { if server, ok := options[dhcp4.OptionServerIdentifier]; ok && !net.IP(server).Equal(s.ipnet.IP) {
log.Tracef("Request message not for this DHCP server (%v vs %v)", server, s.ipnet.IP) log.Tracef("Request message not for this DHCP server (%v vs %v)", server, s.ipnet.IP)
@ -305,7 +316,7 @@ func (s *Server) ServeDHCP(p dhcp4.Packet, msgType dhcp4.MessageType, options dh
reqIP := net.IP(options[dhcp4.OptionRequestedIPAddress]) reqIP := net.IP(options[dhcp4.OptionRequestedIPAddress])
if reqIP == nil { if reqIP == nil {
reqIP = net.IP(p.CIAddr()) reqIP = p.CIAddr()
} }
if reqIP.To4() == nil { if reqIP.To4() == nil {
@ -366,30 +377,9 @@ func (s *Server) ServeDHCP(p dhcp4.Packet, msgType dhcp4.MessageType, options dh
// requested IP is outside of DHCP range // requested IP is outside of DHCP range
log.Tracef("Replying with NAK: request IP %s is outside of DHCP range [%s, %s], asked by %v", reqIP, s.leaseStart, s.leaseStop, p.CHAddr()) log.Tracef("Replying with NAK: request IP %s is outside of DHCP range [%s, %s], asked by %v", reqIP, s.leaseStart, s.leaseStop, p.CHAddr())
return dhcp4.ReplyPacket(p, dhcp4.NAK, s.ipnet.IP, nil, 0, nil) return dhcp4.ReplyPacket(p, dhcp4.NAK, s.ipnet.IP, nil, 0, nil)
case dhcp4.Decline: // Broadcast From Client - Sorry I can't use that IP
log.Tracef("Got from client: Decline")
case dhcp4.Release: // From Client, I don't need that IP anymore
log.Tracef("Got from client: Release")
case dhcp4.Inform: // From Client, I have this IP and there's nothing you can do about it
log.Tracef("Got from client: Inform")
// do nothing
// from server -- ignore those but enumerate just in case
case dhcp4.Offer: // Broadcast From Server - Here's an IP
log.Printf("SHOULD NOT HAPPEN -- FROM ANOTHER DHCP SERVER: Offer")
case dhcp4.ACK: // From Server, Yes you can have that IP
log.Printf("SHOULD NOT HAPPEN -- FROM ANOTHER DHCP SERVER: ACK")
case dhcp4.NAK: // From Server, No you cannot have that IP
log.Printf("SHOULD NOT HAPPEN -- FROM ANOTHER DHCP SERVER: NAK")
default:
log.Printf("Unknown DHCP packet detected, ignoring: %v", msgType)
return nil
}
return nil
} }
// Leases returns the list of current DHCP leases
func (s *Server) Leases() []*Lease { func (s *Server) Leases() []*Lease {
s.RLock() s.RLock()
result := s.leases result := s.leases

View File

@ -8,7 +8,7 @@ import (
) )
// filterConn listens to 0.0.0.0:67, but accepts packets only from specific interface // filterConn listens to 0.0.0.0:67, but accepts packets only from specific interface
// This is neccessary for DHCP daemon to work, since binding to IP address doesn't // This is necessary for DHCP daemon to work, since binding to IP address doesn't
// us access to see Discover/Request packets from clients. // us access to see Discover/Request packets from clients.
// //
// TODO: on windows, controlmessage does not work, try to find out another way // TODO: on windows, controlmessage does not work, try to find out another way
@ -49,7 +49,6 @@ func (f *filterConn) ReadFrom(b []byte) (int, net.Addr, error) {
} }
// packet doesn't match criteria, drop it // packet doesn't match criteria, drop it
} }
return 0, nil, nil
} }
func (f *filterConn) WriteTo(b []byte, addr net.Addr) (int, error) { func (f *filterConn) WriteTo(b []byte, addr net.Addr) (int, error) {

View File

@ -3,7 +3,6 @@ package dhcpd
import ( import (
"fmt" "fmt"
"net" "net"
"strings"
"github.com/hmage/golibs/log" "github.com/hmage/golibs/log"
"github.com/joomcode/errorx" "github.com/joomcode/errorx"
@ -45,22 +44,6 @@ func getIfaceIPv4(iface *net.Interface) *net.IPNet {
return nil return nil
} }
func isConnClosed(err error) bool {
if err == nil {
return false
}
nerr, ok := err.(*net.OpError)
if !ok {
return false
}
if strings.Contains(nerr.Err.Error(), "use of closed network connection") {
return true
}
return false
}
func wrapErrPrint(err error, message string, args ...interface{}) error { func wrapErrPrint(err error, message string, args ...interface{}) error {
var errx error var errx error
if err == nil { if err == nil {

View File

@ -76,9 +76,9 @@ func main() {
panic(err) panic(err)
} }
log.Printf("Now serving DHCP") log.Printf("Now serving DHCP")
signal_channel := make(chan os.Signal) signalChannel := make(chan os.Signal)
signal.Notify(signal_channel, syscall.SIGINT, syscall.SIGTERM) signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM)
<-signal_channel <-signalChannel
} }

6
dns.go
View File

@ -39,20 +39,20 @@ func generateServerConfig() dnsforward.ServerConfig {
} }
for _, u := range config.DNS.UpstreamDNS { for _, u := range config.DNS.UpstreamDNS {
upstream, err := upstream.AddressToUpstream(u, config.DNS.BootstrapDNS, dnsforward.DefaultTimeout) dnsUpstream, err := upstream.AddressToUpstream(u, config.DNS.BootstrapDNS, dnsforward.DefaultTimeout)
if err != nil { if err != nil {
log.Printf("Couldn't get upstream: %s", err) log.Printf("Couldn't get upstream: %s", err)
// continue, just ignore the upstream // continue, just ignore the upstream
continue continue
} }
newconfig.Upstreams = append(newconfig.Upstreams, upstream) newconfig.Upstreams = append(newconfig.Upstreams, dnsUpstream)
} }
return newconfig return newconfig
} }
func startDNSServer() error { func startDNSServer() error {
if isRunning() { if isRunning() {
return fmt.Errorf("Unable to start forwarding DNS server: Already running") return fmt.Errorf("unable to start forwarding DNS server: Already running")
} }
newconfig := generateServerConfig() newconfig := generateServerConfig()

View File

@ -35,7 +35,7 @@ const defaultParentalURL = "http://%s/check-parental-control-hash?prefixes=%s&se
// ErrInvalidSyntax is returned by AddRule when the rule is invalid // ErrInvalidSyntax is returned by AddRule when the rule is invalid
var ErrInvalidSyntax = errors.New("dnsfilter: invalid rule syntax") var ErrInvalidSyntax = errors.New("dnsfilter: invalid rule syntax")
// ErrInvalidSyntax is returned by AddRule when the rule was already added to the filter // ErrAlreadyExists is returned by AddRule when the rule was already added to the filter
var ErrAlreadyExists = errors.New("dnsfilter: rule was already added") var ErrAlreadyExists = errors.New("dnsfilter: rule was already added")
const shortcutLength = 6 // used for rule search optimization, 6 hits the sweet spot const shortcutLength = 6 // used for rule search optimization, 6 hits the sweet spot
@ -115,6 +115,7 @@ type Dnsfilter struct {
privateConfig privateConfig
} }
// Filter represents a filter list
type Filter struct { type Filter struct {
ID int64 `json:"id"` // auto-assigned when filter is added (see nextFilterID), json by default keeps ID uppercase but we need lowercase ID int64 `json:"id"` // auto-assigned when filter is added (see nextFilterID), json by default keeps ID uppercase but we need lowercase
Rules []string `json:"-" yaml:"-"` // not in yaml or json Rules []string `json:"-" yaml:"-"` // not in yaml or json
@ -127,16 +128,26 @@ type Reason int
const ( const (
// reasons for not filtering // reasons for not filtering
NotFilteredNotFound Reason = iota // host was not find in any checks, default value for result
NotFilteredWhiteList // the host is explicitly whitelisted // NotFilteredNotFound - host was not find in any checks, default value for result
NotFilteredError // there was a transitive error during check NotFilteredNotFound Reason = iota
// NotFilteredWhiteList - the host is explicitly whitelisted
NotFilteredWhiteList
// NotFilteredError - there was a transitive error during check
NotFilteredError
// reasons for filtering // reasons for filtering
FilteredBlackList // the host was matched to be advertising host
FilteredSafeBrowsing // the host was matched to be malicious/phishing // FilteredBlackList - the host was matched to be advertising host
FilteredParental // the host was matched to be outside of parental control settings FilteredBlackList
FilteredInvalid // the request was invalid and was not processed // FilteredSafeBrowsing - the host was matched to be malicious/phishing
FilteredSafeSearch // the host was replaced with safesearch variant FilteredSafeBrowsing
// FilteredParental - the host was matched to be outside of parental control settings
FilteredParental
// FilteredInvalid - the request was invalid and was not processed
FilteredInvalid
// FilteredSafeSearch - the host was replaced with safesearch variant
FilteredSafeSearch
) )
// these variables need to survive coredns reload // these variables need to survive coredns reload
@ -151,7 +162,7 @@ type Result struct {
IsFiltered bool `json:",omitempty"` // True if the host name is filtered IsFiltered bool `json:",omitempty"` // True if the host name is filtered
Reason Reason `json:",omitempty"` // Reason for blocking / unblocking Reason Reason `json:",omitempty"` // Reason for blocking / unblocking
Rule string `json:",omitempty"` // Original rule text Rule string `json:",omitempty"` // Original rule text
Ip net.IP `json:",omitempty"` // Not nil only in the case of a hosts file syntax IP net.IP `json:",omitempty"` // Not nil only in the case of a hosts file syntax
FilterID int64 `json:",omitempty"` // Filter ID the rule belongs to FilterID int64 `json:",omitempty"` // Filter ID the rule belongs to
} }
@ -228,7 +239,6 @@ func newRulesTable() *rulesTable {
func (r *rulesTable) Add(rule *rule) { func (r *rulesTable) Add(rule *rule) {
r.Lock() r.Lock()
if rule.ip != nil { if rule.ip != nil {
// Hosts syntax // Hosts syntax
r.rulesByHost[rule.text] = rule r.rulesByHost[rule.text] = rule
@ -476,7 +486,7 @@ func (rule *rule) match(host string) (Result, error) {
IsFiltered: true, IsFiltered: true,
Reason: FilteredBlackList, Reason: FilteredBlackList,
Rule: rule.originalText, Rule: rule.originalText,
Ip: rule.ip, IP: rule.ip,
FilterID: rule.listID, FilterID: rule.listID,
}, nil }, nil
} }
@ -661,8 +671,11 @@ func (d *Dnsfilter) checkParental(host string) (Result, error) {
return result, err return result, err
} }
type formatHandler func(hashparam string) string
type bodyHandler func(body []byte, hashes map[string]bool) (Result, error)
// real implementation of lookup/check // real implementation of lookup/check
func (d *Dnsfilter) lookupCommon(host string, lookupstats *LookupStats, cache gcache.Cache, hashparamNeedSlash bool, format func(hashparam string) string, handleBody func(body []byte, hashes map[string]bool) (Result, error)) (Result, error) { func (d *Dnsfilter) lookupCommon(host string, lookupstats *LookupStats, cache gcache.Cache, hashparamNeedSlash bool, format formatHandler, handleBody bodyHandler) (Result, error) {
// if host ends with a dot, trim it // if host ends with a dot, trim it
host = strings.ToLower(strings.Trim(host, ".")) host = strings.ToLower(strings.Trim(host, "."))
@ -774,43 +787,43 @@ func (d *Dnsfilter) AddRule(input string, filterListID int64) error {
} }
// Start parsing the rule // Start parsing the rule
rule := rule{ r := rule{
text: input, // will be modified text: input, // will be modified
originalText: input, originalText: input,
listID: filterListID, listID: filterListID,
} }
// Mark rule as whitelist if it starts with @@ // Mark rule as whitelist if it starts with @@
if strings.HasPrefix(rule.text, "@@") { if strings.HasPrefix(r.text, "@@") {
rule.isWhitelist = true r.isWhitelist = true
rule.text = rule.text[2:] r.text = r.text[2:]
} }
err := rule.parseOptions() err := r.parseOptions()
if err != nil { if err != nil {
return err return err
} }
rule.extractShortcut() r.extractShortcut()
if !enableDelayedCompilation { if !enableDelayedCompilation {
err := rule.compile() err := r.compile()
if err != nil { if err != nil {
return err return err
} }
} }
destination := d.blackList destination := d.blackList
if rule.isImportant { if r.isImportant {
destination = d.important destination = d.important
} else if rule.isWhitelist { } else if r.isWhitelist {
destination = d.whiteList destination = d.whiteList
} }
d.storageMutex.Lock() d.storageMutex.Lock()
d.storage[input] = true d.storage[input] = true
d.storageMutex.Unlock() d.storageMutex.Unlock()
destination.Add(&rule) destination.Add(&r)
return nil return nil
} }
@ -835,13 +848,13 @@ func (d *Dnsfilter) parseEtcHosts(input string, filterListID int64) bool {
d.storageMutex.Unlock() d.storageMutex.Unlock()
for _, host := range fields[1:] { for _, host := range fields[1:] {
rule := rule{ r := rule{
text: host, text: host,
originalText: input, originalText: input,
listID: filterListID, listID: filterListID,
ip: addr, ip: addr,
} }
d.blackList.Add(&rule) d.blackList.Add(&r)
} }
return true return true
} }
@ -913,15 +926,6 @@ func (d *Dnsfilter) Destroy() {
// config manipulation helpers // config manipulation helpers
// //
// IsParentalSensitivityValid checks if sensitivity is valid value
func IsParentalSensitivityValid(sensitivity int) bool {
switch sensitivity {
case 3, 10, 13, 17:
return true
}
return false
}
// SetSafeBrowsingServer lets you optionally change hostname of safesearch lookup // SetSafeBrowsingServer lets you optionally change hostname of safesearch lookup
func (d *Dnsfilter) SetSafeBrowsingServer(host string) { func (d *Dnsfilter) SetSafeBrowsingServer(host string) {
if len(host) == 0 { if len(host) == 0 {

View File

@ -26,7 +26,7 @@ import (
func TestLotsOfRulesMemoryUsage(t *testing.T) { func TestLotsOfRulesMemoryUsage(t *testing.T) {
start := getRSS() start := getRSS()
log.Tracef("RSS before loading rules - %d kB\n", start/1024) log.Tracef("RSS before loading rules - %d kB\n", start/1024)
dumpMemProfile(_Func() + "1.pprof") dumpMemProfile("tests/" + _Func() + "1.pprof")
d := NewForTest() d := NewForTest()
defer d.Destroy() defer d.Destroy()
@ -37,7 +37,7 @@ func TestLotsOfRulesMemoryUsage(t *testing.T) {
afterLoad := getRSS() afterLoad := getRSS()
log.Tracef("RSS after loading rules - %d kB (%d kB diff)\n", afterLoad/1024, (afterLoad-start)/1024) log.Tracef("RSS after loading rules - %d kB (%d kB diff)\n", afterLoad/1024, (afterLoad-start)/1024)
dumpMemProfile(_Func() + "2.pprof") dumpMemProfile("tests/" + _Func() + "2.pprof")
tests := []struct { tests := []struct {
host string host string
@ -60,7 +60,7 @@ func TestLotsOfRulesMemoryUsage(t *testing.T) {
} }
afterMatch := getRSS() afterMatch := getRSS()
log.Tracef("RSS after matching - %d kB (%d kB diff)\n", afterMatch/1024, (afterMatch-afterLoad)/1024) log.Tracef("RSS after matching - %d kB (%d kB diff)\n", afterMatch/1024, (afterMatch-afterLoad)/1024)
dumpMemProfile(_Func() + "3.pprof") dumpMemProfile("tests/" + _Func() + "3.pprof")
} }
func getRSS() uint64 { func getRSS() uint64 {
@ -69,6 +69,9 @@ func getRSS() uint64 {
panic(err) panic(err)
} }
minfo, err := proc.MemoryInfo() minfo, err := proc.MemoryInfo()
if err != nil {
panic(err)
}
return minfo.RSS return minfo.RSS
} }
@ -86,7 +89,7 @@ func dumpMemProfile(name string) {
} }
} }
const topHostsFilename = "../tests/top-1m.csv" const topHostsFilename = "tests/top-1m.csv"
func fetchTopHostsFromNet() { func fetchTopHostsFromNet() {
log.Tracef("Fetching top hosts from network") log.Tracef("Fetching top hosts from network")
@ -146,7 +149,7 @@ func getTopHosts() {
func TestLotsOfRulesLotsOfHostsMemoryUsage(t *testing.T) { func TestLotsOfRulesLotsOfHostsMemoryUsage(t *testing.T) {
start := getRSS() start := getRSS()
log.Tracef("RSS before loading rules - %d kB\n", start/1024) log.Tracef("RSS before loading rules - %d kB\n", start/1024)
dumpMemProfile(_Func() + "1.pprof") dumpMemProfile("tests/" + _Func() + "1.pprof")
d := NewForTest() d := NewForTest()
defer d.Destroy() defer d.Destroy()
@ -155,7 +158,7 @@ func TestLotsOfRulesLotsOfHostsMemoryUsage(t *testing.T) {
afterLoad := getRSS() afterLoad := getRSS()
log.Tracef("RSS after loading rules - %d kB (%d kB diff)\n", afterLoad/1024, (afterLoad-start)/1024) log.Tracef("RSS after loading rules - %d kB (%d kB diff)\n", afterLoad/1024, (afterLoad-start)/1024)
dumpMemProfile(_Func() + "2.pprof") dumpMemProfile("tests/" + _Func() + "2.pprof")
getTopHosts() getTopHosts()
hostnames, err := os.Open(topHostsFilename) hostnames, err := os.Open(topHostsFilename)
@ -165,7 +168,7 @@ func TestLotsOfRulesLotsOfHostsMemoryUsage(t *testing.T) {
defer hostnames.Close() defer hostnames.Close()
afterHosts := getRSS() afterHosts := getRSS()
log.Tracef("RSS after loading hosts - %d kB (%d kB diff)\n", afterHosts/1024, (afterHosts-afterLoad)/1024) log.Tracef("RSS after loading hosts - %d kB (%d kB diff)\n", afterHosts/1024, (afterHosts-afterLoad)/1024)
dumpMemProfile(_Func() + "2.pprof") dumpMemProfile("tests/" + _Func() + "2.pprof")
{ {
scanner := bufio.NewScanner(hostnames) scanner := bufio.NewScanner(hostnames)
@ -184,7 +187,7 @@ func TestLotsOfRulesLotsOfHostsMemoryUsage(t *testing.T) {
afterMatch := getRSS() afterMatch := getRSS()
log.Tracef("RSS after matching - %d kB (%d kB diff)\n", afterMatch/1024, (afterMatch-afterLoad)/1024) log.Tracef("RSS after matching - %d kB (%d kB diff)\n", afterMatch/1024, (afterMatch-afterLoad)/1024)
dumpMemProfile(_Func() + "3.pprof") dumpMemProfile("tests/" + _Func() + "3.pprof")
} }
func TestRuleToRegexp(t *testing.T) { func TestRuleToRegexp(t *testing.T) {
@ -282,7 +285,7 @@ func (d *Dnsfilter) checkMatch(t *testing.T, hostname string) {
} }
} }
func (d *Dnsfilter) checkMatchIp(t *testing.T, hostname string, ip string) { func (d *Dnsfilter) checkMatchIP(t *testing.T, hostname string, ip string) {
t.Helper() t.Helper()
ret, err := d.CheckHost(hostname) ret, err := d.CheckHost(hostname)
if err != nil { if err != nil {
@ -291,8 +294,8 @@ func (d *Dnsfilter) checkMatchIp(t *testing.T, hostname string, ip string) {
if !ret.IsFiltered { if !ret.IsFiltered {
t.Errorf("Expected hostname %s to match", hostname) t.Errorf("Expected hostname %s to match", hostname)
} }
if ret.Ip == nil || ret.Ip.String() != ip { if ret.IP == nil || ret.IP.String() != ip {
t.Errorf("Expected ip %s to match, actual: %v", ip, ret.Ip) t.Errorf("Expected ip %s to match, actual: %v", ip, ret.IP)
} }
} }
@ -308,7 +311,7 @@ func (d *Dnsfilter) checkMatchEmpty(t *testing.T, hostname string) {
} }
func loadTestRules(d *Dnsfilter) error { func loadTestRules(d *Dnsfilter) error {
filterFileName := "../tests/dns.txt" filterFileName := "tests/dns.txt"
file, err := os.Open(filterFileName) file, err := os.Open(filterFileName)
if err != nil { if err != nil {
return err return err
@ -368,8 +371,8 @@ func TestEtcHostsMatching(t *testing.T) {
text := fmt.Sprintf(" %s google.com www.google.com # enforce google's safesearch ", addr) text := fmt.Sprintf(" %s google.com www.google.com # enforce google's safesearch ", addr)
d.checkAddRule(t, text) d.checkAddRule(t, text)
d.checkMatchIp(t, "google.com", addr) d.checkMatchIP(t, "google.com", addr)
d.checkMatchIp(t, "www.google.com", addr) d.checkMatchIP(t, "www.google.com", addr)
d.checkMatchEmpty(t, "subdomain.google.com") d.checkMatchEmpty(t, "subdomain.google.com")
d.checkMatchEmpty(t, "example.org") d.checkMatchEmpty(t, "example.org")
} }
@ -866,7 +869,7 @@ func BenchmarkLotsOfRulesLotsOfHosts(b *testing.B) {
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
havedata := scanner.Scan() havedata := scanner.Scan()
if !havedata { if !havedata {
hostnames.Seek(0, 0) _, _ = hostnames.Seek(0, 0)
scanner = bufio.NewScanner(hostnames) scanner = bufio.NewScanner(hostnames)
havedata = scanner.Scan() havedata = scanner.Scan()
} }
@ -903,7 +906,7 @@ func BenchmarkLotsOfRulesLotsOfHostsParallel(b *testing.B) {
for pb.Next() { for pb.Next() {
havedata := scanner.Scan() havedata := scanner.Scan()
if !havedata { if !havedata {
hostnames.Seek(0, 0) _, _ = hostnames.Seek(0, 0)
scanner = bufio.NewScanner(hostnames) scanner = bufio.NewScanner(hostnames)
havedata = scanner.Scan() havedata = scanner.Scan()
} }

View File

@ -310,8 +310,8 @@ func (s *Server) genDNSFilterMessage(d *proxy.DNSContext, result *dnsfilter.Resu
case dnsfilter.FilteredParental: case dnsfilter.FilteredParental:
return s.genBlockedHost(m, parentalBlockHost, d.Upstream) return s.genBlockedHost(m, parentalBlockHost, d.Upstream)
default: default:
if result.Ip != nil { if result.IP != nil {
return s.genARecord(m, result.Ip) return s.genARecord(m, result.IP)
} }
return s.genNXDomain(m) return s.genNXDomain(m)

View File

@ -108,10 +108,16 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result *dnsfilter.Result, el
if len(flushBuffer) > 0 { if len(flushBuffer) > 0 {
// write to file // write to file
// do it in separate goroutine -- we are stalling DNS response this whole time // do it in separate goroutine -- we are stalling DNS response this whole time
go flushToFile(flushBuffer) go func() {
err := flushToFile(flushBuffer)
if err != nil {
log.Printf("Failed to flush the query log: %s", err)
}
}()
} }
} }
// HandleQueryLog handles query log web request
func HandleQueryLog(w http.ResponseWriter, r *http.Request) { func HandleQueryLog(w http.ResponseWriter, r *http.Request) {
queryLogLock.RLock() queryLogLock.RLock()
values := make([]*logEntry, len(queryLogCache)) values := make([]*logEntry, len(queryLogCache))
@ -123,6 +129,7 @@ func HandleQueryLog(w http.ResponseWriter, r *http.Request) {
values[left], values[right] = values[right], values[left] values[left], values[right] = values[right], values[left]
} }
// iterate
var data = []map[string]interface{}{} var data = []map[string]interface{}{}
for _, entry := range values { for _, entry := range values {
var q *dns.Msg var q *dns.Msg
@ -167,7 +174,36 @@ func HandleQueryLog(w http.ResponseWriter, r *http.Request) {
jsonEntry["filterId"] = entry.Result.FilterID jsonEntry["filterId"] = entry.Result.FilterID
} }
if a != nil && len(a.Answer) > 0 { answers := answerToMap(a)
if answers != nil {
jsonEntry["answer"] = answers
}
data = append(data, jsonEntry)
}
jsonVal, err := json.Marshal(data)
if err != nil {
errorText := fmt.Sprintf("Couldn't marshal data into json: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
}
}
func answerToMap(a *dns.Msg) []map[string]interface{} {
if a == nil || len(a.Answer) == 0 {
return nil
}
var answers = []map[string]interface{}{} var answers = []map[string]interface{}{}
for _, k := range a.Answer { for _, k := range a.Answer {
header := k.Header() header := k.Header()
@ -207,27 +243,8 @@ func HandleQueryLog(w http.ResponseWriter, r *http.Request) {
} }
answers = append(answers, answer) answers = append(answers, answer)
} }
jsonEntry["answer"] = answers
}
data = append(data, jsonEntry) return answers
}
jsonVal, err := json.Marshal(data)
if err != nil {
errorText := fmt.Sprintf("Couldn't marshal data into json: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
}
} }
// getIPString is a helper function that extracts IP address from net.Addr // getIPString is a helper function that extracts IP address from net.Addr

View File

@ -156,7 +156,7 @@ func periodicQueryLogRotate() {
func genericLoader(onEntry func(entry *logEntry) error, needMore func() bool, timeWindow time.Duration) error { func genericLoader(onEntry func(entry *logEntry) error, needMore func() bool, timeWindow time.Duration) error {
now := time.Now() now := time.Now()
// read from querylog files, try newest file first // read from querylog files, try newest file first
files := []string{} var files []string
if enableGzip { if enableGzip {
files = []string{ files = []string{

View File

@ -26,10 +26,10 @@ type hourTop struct {
mutex sync.RWMutex mutex sync.RWMutex
} }
func (top *hourTop) init() { func (h *hourTop) init() {
top.domains = gcache.New(queryLogTopSize).LRU().Build() h.domains = gcache.New(queryLogTopSize).LRU().Build()
top.blocked = gcache.New(queryLogTopSize).LRU().Build() h.blocked = gcache.New(queryLogTopSize).LRU().Build()
top.clients = gcache.New(queryLogTopSize).LRU().Build() h.clients = gcache.New(queryLogTopSize).LRU().Build()
} }
type dayTop struct { type dayTop struct {
@ -69,9 +69,9 @@ func periodicHourlyTopRotate() {
} }
} }
func (top *hourTop) incrementValue(key string, cache gcache.Cache) error { func (h *hourTop) incrementValue(key string, cache gcache.Cache) error {
top.Lock() h.Lock()
defer top.Unlock() defer h.Unlock()
ivalue, err := cache.Get(key) ivalue, err := cache.Get(key)
if err == gcache.KeyNotFoundError { if err == gcache.KeyNotFoundError {
// we just set it and we're done // we just set it and we're done
@ -103,20 +103,20 @@ func (top *hourTop) incrementValue(key string, cache gcache.Cache) error {
return nil return nil
} }
func (top *hourTop) incrementDomains(key string) error { func (h *hourTop) incrementDomains(key string) error {
return top.incrementValue(key, top.domains) return h.incrementValue(key, h.domains)
} }
func (top *hourTop) incrementBlocked(key string) error { func (h *hourTop) incrementBlocked(key string) error {
return top.incrementValue(key, top.blocked) return h.incrementValue(key, h.blocked)
} }
func (top *hourTop) incrementClients(key string) error { func (h *hourTop) incrementClients(key string) error {
return top.incrementValue(key, top.clients) return h.incrementValue(key, h.clients)
} }
// if does not exist -- return 0 // if does not exist -- return 0
func (top *hourTop) lockedGetValue(key string, cache gcache.Cache) (int, error) { func (h *hourTop) lockedGetValue(key string, cache gcache.Cache) (int, error) {
ivalue, err := cache.Get(key) ivalue, err := cache.Get(key)
if err == gcache.KeyNotFoundError { if err == gcache.KeyNotFoundError {
return 0, nil return 0, nil
@ -137,19 +137,19 @@ func (top *hourTop) lockedGetValue(key string, cache gcache.Cache) (int, error)
return value, nil return value, nil
} }
func (top *hourTop) lockedGetDomains(key string) (int, error) { func (h *hourTop) lockedGetDomains(key string) (int, error) {
return top.lockedGetValue(key, top.domains) return h.lockedGetValue(key, h.domains)
} }
func (top *hourTop) lockedGetBlocked(key string) (int, error) { func (h *hourTop) lockedGetBlocked(key string) (int, error) {
return top.lockedGetValue(key, top.blocked) return h.lockedGetValue(key, h.blocked)
} }
func (top *hourTop) lockedGetClients(key string) (int, error) { func (h *hourTop) lockedGetClients(key string) (int, error) {
return top.lockedGetValue(key, top.clients) return h.lockedGetValue(key, h.clients)
} }
func (r *dayTop) addEntry(entry *logEntry, q *dns.Msg, now time.Time) error { func (d *dayTop) addEntry(entry *logEntry, q *dns.Msg, now time.Time) error {
// figure out which hour bucket it belongs to // figure out which hour bucket it belongs to
hour := int(now.Sub(entry.Time).Hours()) hour := int(now.Sub(entry.Time).Hours())
if hour >= 24 { if hour >= 24 {
@ -252,6 +252,7 @@ func fillStatsFromQueryLog() error {
return nil return nil
} }
// HandleStatsTop returns the current top stats
func HandleStatsTop(w http.ResponseWriter, r *http.Request) { func HandleStatsTop(w http.ResponseWriter, r *http.Request) {
domains := map[string]int{} domains := map[string]int{}
blocked := map[string]int{} blocked := map[string]int{}
@ -320,9 +321,9 @@ func HandleStatsTop(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err := w.Write(json.Bytes()) _, err := w.Write(json.Bytes())
if err != nil { if err != nil {
errortext := fmt.Sprintf("Couldn't write body: %s", err) errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, http.StatusInternalServerError) http.Error(w, errorText, http.StatusInternalServerError)
} }
} }

View File

@ -17,7 +17,6 @@ var (
filteredLists = newDNSCounter("filtered_lists_total") filteredLists = newDNSCounter("filtered_lists_total")
filteredSafebrowsing = newDNSCounter("filtered_safebrowsing_total") filteredSafebrowsing = newDNSCounter("filtered_safebrowsing_total")
filteredParental = newDNSCounter("filtered_parental_total") filteredParental = newDNSCounter("filtered_parental_total")
filteredInvalid = newDNSCounter("filtered_invalid_total")
whitelisted = newDNSCounter("whitelisted_total") whitelisted = newDNSCounter("whitelisted_total")
safesearch = newDNSCounter("safesearch_total") safesearch = newDNSCounter("safesearch_total")
errorsTotal = newDNSCounter("errors_total") errorsTotal = newDNSCounter("errors_total")
@ -91,10 +90,10 @@ func (p *periodicStats) Observe(name string, when time.Time, value float64) {
{ {
countname := name + "_count" countname := name + "_count"
currentValues := p.Entries[countname] currentValues := p.Entries[countname]
value := currentValues[elapsed] v := currentValues[elapsed]
// log.Tracef("Will change p.Entries[%s][%d] from %v to %v", countname, elapsed, value, value+1) // log.Tracef("Will change p.Entries[%s][%d] from %v to %v", countname, elapsed, value, value+1)
value += 1 v++
currentValues[elapsed] = value currentValues[elapsed] = v
p.Entries[countname] = currentValues p.Entries[countname] = currentValues
} }
{ {
@ -224,6 +223,7 @@ func incrementCounters(entry *logEntry) {
elapsedTime.ObserveWithTime(entry.Elapsed.Seconds(), entry.Time) elapsedTime.ObserveWithTime(entry.Elapsed.Seconds(), entry.Time)
} }
// HandleStats returns aggregated stats data for the 24 hours
func HandleStats(w http.ResponseWriter, r *http.Request) { func HandleStats(w http.ResponseWriter, r *http.Request) {
const numHours = 24 const numHours = 24
histrical := generateMapFromStats(&statistics.PerHour, 0, numHours) histrical := generateMapFromStats(&statistics.PerHour, 0, numHours)
@ -252,17 +252,17 @@ func HandleStats(w http.ResponseWriter, r *http.Request) {
json, err := json.Marshal(summed) json, err := json.Marshal(summed)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err) errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(json) _, err = w.Write(json)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err) errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
} }
@ -296,6 +296,7 @@ func generateMapFromStats(stats *periodicStats, start int, end int) map[string]i
return result return result
} }
// HandleStatsHistory returns historical stats data for the 24 hours
func HandleStatsHistory(w http.ResponseWriter, r *http.Request) { func HandleStatsHistory(w http.ResponseWriter, r *http.Request) {
// handle time unit and prepare our time window size // handle time unit and prepare our time window size
now := time.Now() now := time.Now()
@ -323,27 +324,27 @@ func HandleStatsHistory(w http.ResponseWriter, r *http.Request) {
// parse start and end time // parse start and end time
startTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("start_time")) startTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("start_time"))
if err != nil { if err != nil {
errortext := fmt.Sprintf("Must specify valid start_time parameter: %s", err) errorText := fmt.Sprintf("Must specify valid start_time parameter: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 400) http.Error(w, errorText, 400)
return return
} }
endTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("end_time")) endTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("end_time"))
if err != nil { if err != nil {
errortext := fmt.Sprintf("Must specify valid end_time parameter: %s", err) errorText := fmt.Sprintf("Must specify valid end_time parameter: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 400) http.Error(w, errorText, 400)
return return
} }
// check if start and time times are within supported time range // check if start and time times are within supported time range
timeRange := timeUnit * statsHistoryElements timeRange := timeUnit * statsHistoryElements
if startTime.Add(timeRange).Before(now) { if startTime.Add(timeRange).Before(now) {
http.Error(w, "start_time parameter is outside of supported range", 501) http.Error(w, "start_time parameter is outside of supported range", http.StatusBadRequest)
return return
} }
if endTime.Add(timeRange).Before(now) { if endTime.Add(timeRange).Before(now) {
http.Error(w, "end_time parameter is outside of supported range", 501) http.Error(w, "end_time parameter is outside of supported range", http.StatusBadRequest)
return return
} }
@ -360,28 +361,29 @@ func HandleStatsHistory(w http.ResponseWriter, r *http.Request) {
data := generateMapFromStats(stats, start, end) data := generateMapFromStats(stats, start, end)
json, err := json.Marshal(data) json, err := json.Marshal(data)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err) errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(json) _, err = w.Write(json)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err) errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
} }
// HandleStatsReset resets the stats caches
func HandleStatsReset(w http.ResponseWriter, r *http.Request) { func HandleStatsReset(w http.ResponseWriter, r *http.Request) {
purgeStats() purgeStats()
_, err := fmt.Fprintf(w, "OK\n") _, err := fmt.Fprintf(w, "OK\n")
if err != nil { if err != nil {
errortext := fmt.Sprintf("Couldn't write body: %s", err) errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, http.StatusInternalServerError) http.Error(w, errorText, http.StatusInternalServerError)
} }
} }

View File

@ -70,20 +70,20 @@ func updateUniqueFilterID(filters []filter) {
func assignUniqueFilterID() int64 { func assignUniqueFilterID() int64 {
value := nextFilterID value := nextFilterID
nextFilterID += 1 nextFilterID++
return value return value
} }
// Sets up a timer that will be checking for filters updates periodically // Sets up a timer that will be checking for filters updates periodically
func periodicallyRefreshFilters() { func periodicallyRefreshFilters() {
for range time.Tick(time.Minute) { for range time.Tick(time.Minute) {
refreshFiltersIfNeccessary(false) refreshFiltersIfNecessary(false)
} }
} }
// Checks filters updates if necessary // Checks filters updates if necessary
// If force is true, it ignores the filter.LastUpdated field value // If force is true, it ignores the filter.LastUpdated field value
func refreshFiltersIfNeccessary(force bool) int { func refreshFiltersIfNecessary(force bool) int {
config.Lock() config.Lock()
// fetch URLs // fetch URLs
@ -113,8 +113,12 @@ func refreshFiltersIfNeccessary(force bool) int {
} }
config.Unlock() config.Unlock()
if updateCount > 0 { if updateCount > 0 && isRunning() {
reconfigureDNSServer() err := reconfigureDNSServer()
if err != nil {
msg := fmt.Sprintf("SHOULD NOT HAPPEN: cannot reconfigure DNS server with the new filters: %s", err)
panic(msg)
}
} }
return updateCount return updateCount
} }

View File

@ -39,7 +39,7 @@ func safeWriteFile(path string, data []byte) error {
func ensure(method string, handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { func ensure(method string, handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if r.Method != method { if r.Method != method {
http.Error(w, "This request must be "+method, 405) http.Error(w, "This request must be "+method, http.StatusMethodNotAllowed)
return return
} }
handler(w, r) handler(w, r)

View File

@ -34,9 +34,9 @@ func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
log.Printf("config.Language is %s", config.Language) log.Printf("config.Language is %s", config.Language)
_, err := fmt.Fprintf(w, "%s\n", config.Language) _, err := fmt.Fprintf(w, "%s\n", config.Language)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err) errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, http.StatusInternalServerError) http.Error(w, errorText, http.StatusInternalServerError)
return return
} }
} }

View File

@ -7,7 +7,7 @@ import (
"path/filepath" "path/filepath"
"github.com/hmage/golibs/log" "github.com/hmage/golibs/log"
"gopkg.in/yaml.v2" yaml "gopkg.in/yaml.v2"
) )
const currentSchemaVersion = 2 // used for upgrading from old configs to new config const currentSchemaVersion = 2 // used for upgrading from old configs to new config