Merge pull request #137 in DNS/adguard-dns from fix/557 to master

* commit 'ec6b1f7c42c8d2fc413d29fba55430b89fcbce2d':
  Added golangci-lint configuration and prepared for the integrattion
  Added codecov, goreport
  Use EnableAll in gometalinter config
  gometalinter
This commit is contained in:
Eugene Bujak 2019-01-25 20:14:03 +03:00
commit 0161509b5f
28 changed files with 577 additions and 464 deletions

8
.codecov.yml Normal file
View File

@ -0,0 +1,8 @@
coverage:
status:
project:
default:
target: 40%
threshold: null
patch: false
changes: false

4
.gitignore vendored
View File

@ -12,5 +12,5 @@
/scripts/translations/oneskyapp.json /scripts/translations/oneskyapp.json
# Test output # Test output
dnsfilter/dnsfilter.TestLotsOfRules*.pprof dnsfilter/tests/top-1m.csv
tests/top-1m.csv dnsfilter/tests/dnsfilter.TestLotsOfRules*.pprof

55
.golangci.yml Normal file
View File

@ -0,0 +1,55 @@
# options for analysis running
run:
# default concurrency is a available CPU number
concurrency: 4
# timeout for analysis, e.g. 30s, 5m, default is 1m
deadline: 2m
# which files to skip: they will be analyzed, but issues from them
# won't be reported. Default value is empty list, but there is
# no need to include all autogenerated files, we confidently recognize
# autogenerated files. If it's not please let us know.
skip-files:
- ".*generated.*"
- dnsfilter/rule_to_regexp.go
# all available settings of specific linters
linters-settings:
errcheck:
# [deprecated] comma-separated list of pairs of the form pkg:regex
# the regex is used to ignore names within pkg. (default "fmt:.*").
# see https://github.com/kisielk/errcheck#the-deprecated-method for details
ignore: fmt:.*,net:SetReadDeadline,net/http:^Write
gocyclo:
min-complexity: 20
lll:
line-length: 200
linters:
enable-all: true
disable:
- interfacer
- gocritic
- scopelint
- gochecknoglobals
- gochecknoinits
- prealloc
- maligned
- goconst # disabled until it's possible to configure
fast: true
issues:
# List of regexps of issue texts to exclude, empty list by default.
# But independently from this option we use default exclude patterns,
# it can be disabled by `exclude-use-default: false`. To list all
# excluded by default patterns execute `golangci-lint run --help`
exclude:
# structcheck cannot detect usages while they're there
- .parentalServer. is unused
- .safeBrowsingServer. is unused
# errcheck
- Error return value of .s.closeConn. is not checked
# goconst
- string .forcesafesearch.google.com. has 3 occurrences

23
.gometalinter.json Normal file
View File

@ -0,0 +1,23 @@
{
"Vendor": true,
"Test": true,
"Deadline": "2m",
"Sort": ["linter", "severity", "path", "line"],
"Exclude": [
".*generated.*",
"dnsfilter/rule_to_regexp.go"
],
"EnableGC": true,
"Linters": {
"nakedret": {
"Command": "nakedret",
"Pattern": "^(?P<path>.*?\\.go):(?P<line>\\d+)\\s*(?P<message>.*)$"
}
},
"WarnUnmatchedDirective": true,
"EnableAll": true,
"Cyclo": 20,
"LineLength": 200
}

View File

@ -24,7 +24,11 @@ install:
script: script:
- node -v - node -v
- npm -v - npm -v
- go test ./... # Run tests
- go test -race -v -bench=. -coverprofile=coverage.txt -covermode=atomic ./...
# Make
- make build/static/index.html - make build/static/index.html
- make - make
after_success:
- bash <(curl -s https://codecov.io/bash)

View File

@ -11,11 +11,18 @@
<a href="https://adguard.com/">AdGuard.com</a> | <a href="https://adguard.com/">AdGuard.com</a> |
<a href="https://github.com/AdguardTeam/AdGuardHome/wiki">Wiki</a> | <a href="https://github.com/AdguardTeam/AdGuardHome/wiki">Wiki</a> |
<a href="https://reddit.com/r/Adguard">Reddit</a> | <a href="https://reddit.com/r/Adguard">Reddit</a> |
<a href="https://twitter.com/AdGuard">Twitter</a> <a href="https://twitter.com/AdGuard">Twitter</a> |
<a href="https://t.me/adguard_en">Telegram</a>
<br /><br /> <br /><br />
<a href="https://travis-ci.org/AdguardTeam/AdGuardHome"> <a href="https://travis-ci.org/AdguardTeam/AdGuardHome">
<img src="https://travis-ci.org/AdguardTeam/AdGuardHome.svg" alt="Build status" /> <img src="https://travis-ci.org/AdguardTeam/AdGuardHome.svg" alt="Build status" />
</a> </a>
<a href="https://codecov.io/github/AdguardTeam/AdGuardHome?branch=master">
<img src="https://img.shields.io/codecov/c/github/AdguardTeam/AdGuardHome/master.svg" alt="Code Coverage" />
</a>
<a href="https://goreportcard.com/report/AdguardTeam/AdGuardHome">
<img src="https://goreportcard.com/badge/github.com/AdguardTeam/AdGuardHome" alt="Go Report Card" />
</a>
<a href="https://github.com/AdguardTeam/AdGuardHome/releases"> <a href="https://github.com/AdguardTeam/AdGuardHome/releases">
<img src="https://img.shields.io/github/release/AdguardTeam/AdGuardHome/all.svg" alt="Latest release" /> <img src="https://img.shields.io/github/release/AdguardTeam/AdGuardHome/all.svg" alt="Latest release" />
</a> </a>

177
app.go
View File

@ -43,92 +43,7 @@ func main() {
// config can be specified, which reads options from there, but other command line flags have to override config values // config can be specified, which reads options from there, but other command line flags have to override config values
// therefore, we must do it manually instead of using a lib // therefore, we must do it manually instead of using a lib
{ loadOptions()
var printHelp func()
var configFilename *string
var bindHost *string
var bindPort *int
var opts = []struct {
longName string
shortName string
description string
callbackWithValue func(value string)
callbackNoValue func()
}{
{"config", "c", "path to config file", func(value string) { configFilename = &value }, nil},
{"host", "h", "host address to bind HTTP server on", func(value string) { bindHost = &value }, nil},
{"port", "p", "port to serve HTTP pages on", func(value string) {
v, err := strconv.Atoi(value)
if err != nil {
panic("Got port that is not a number")
}
bindPort = &v
}, nil},
{"verbose", "v", "enable verbose output", nil, func() { log.Verbose = true }},
{"help", "h", "print this help", nil, func() { printHelp(); os.Exit(64) }},
}
printHelp = func() {
fmt.Printf("Usage:\n\n")
fmt.Printf("%s [options]\n\n", os.Args[0])
fmt.Printf("Options:\n")
for _, opt := range opts {
fmt.Printf(" -%s, %-30s %s\n", opt.shortName, "--"+opt.longName, opt.description)
}
}
for i := 1; i < len(os.Args); i++ {
v := os.Args[i]
knownParam := false
for _, opt := range opts {
if v == "--"+opt.longName || v == "-"+opt.shortName {
if opt.callbackWithValue != nil {
if i+1 > len(os.Args) {
log.Printf("ERROR: Got %s without argument\n", v)
os.Exit(64)
}
i++
opt.callbackWithValue(os.Args[i])
} else if opt.callbackNoValue != nil {
opt.callbackNoValue()
}
knownParam = true
break
}
}
if !knownParam {
log.Printf("ERROR: unknown option %v\n", v)
printHelp()
os.Exit(64)
}
}
if configFilename != nil {
config.ourConfigFilename = *configFilename
}
err := askUsernamePasswordIfPossible()
if err != nil {
log.Fatal(err)
}
// Do the upgrade if necessary
err = upgradeConfig()
if err != nil {
log.Fatal(err)
}
// parse from config file
err = parseConfig()
if err != nil {
log.Fatal(err)
}
// override bind host/port from the console
if bindHost != nil {
config.BindHost = *bindHost
}
if bindPort != nil {
config.BindPort = *bindPort
}
}
// Load filters from the disk // Load filters from the disk
// And if any filter has zero ID, assign a new one // And if any filter has zero ID, assign a new one
@ -150,7 +65,7 @@ func main() {
// Update filters we've just loaded right away, don't wait for periodic update timer // Update filters we've just loaded right away, don't wait for periodic update timer
go func() { go func() {
refreshFiltersIfNeccessary(false) refreshFiltersIfNecessary(false)
// Save the updated config // Save the updated config
err := config.write() err := config.write()
if err != nil { if err != nil {
@ -209,6 +124,94 @@ func getInput() (string, error) {
return text, err return text, err
} }
// loadOptions reads command line arguments and initializes configuration
func loadOptions() {
var printHelp func()
var configFilename *string
var bindHost *string
var bindPort *int
var opts = []struct {
longName string
shortName string
description string
callbackWithValue func(value string)
callbackNoValue func()
}{
{"config", "c", "path to config file", func(value string) { configFilename = &value }, nil},
{"host", "h", "host address to bind HTTP server on", func(value string) { bindHost = &value }, nil},
{"port", "p", "port to serve HTTP pages on", func(value string) {
v, err := strconv.Atoi(value)
if err != nil {
panic("Got port that is not a number")
}
bindPort = &v
}, nil},
{"verbose", "v", "enable verbose output", nil, func() { log.Verbose = true }},
{"help", "h", "print this help", nil, func() { printHelp(); os.Exit(64) }},
}
printHelp = func() {
fmt.Printf("Usage:\n\n")
fmt.Printf("%s [options]\n\n", os.Args[0])
fmt.Printf("Options:\n")
for _, opt := range opts {
fmt.Printf(" -%s, %-30s %s\n", opt.shortName, "--"+opt.longName, opt.description)
}
}
for i := 1; i < len(os.Args); i++ {
v := os.Args[i]
knownParam := false
for _, opt := range opts {
if v == "--"+opt.longName || v == "-"+opt.shortName {
if opt.callbackWithValue != nil {
if i+1 > len(os.Args) {
log.Printf("ERROR: Got %s without argument\n", v)
os.Exit(64)
}
i++
opt.callbackWithValue(os.Args[i])
} else if opt.callbackNoValue != nil {
opt.callbackNoValue()
}
knownParam = true
break
}
}
if !knownParam {
log.Printf("ERROR: unknown option %v\n", v)
printHelp()
os.Exit(64)
}
}
if configFilename != nil {
config.ourConfigFilename = *configFilename
}
err := askUsernamePasswordIfPossible()
if err != nil {
log.Fatal(err)
}
// Do the upgrade if necessary
err = upgradeConfig()
if err != nil {
log.Fatal(err)
}
// parse from config file
err = parseConfig()
if err != nil {
log.Fatal(err)
}
// override bind host/port from the console
if bindHost != nil {
config.BindHost = *bindHost
}
if bindPort != nil {
config.BindPort = *bindPort
}
}
func promptAndGet(prompt string) (string, error) { func promptAndGet(prompt string) (string, error) {
for { for {
fmt.Print(prompt) fmt.Print(prompt)

View File

@ -10,7 +10,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/dnsforward" "github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/hmage/golibs/log" "github.com/hmage/golibs/log"
"gopkg.in/yaml.v2" yaml "gopkg.in/yaml.v2"
) )
const ( const (
@ -21,7 +21,7 @@ const (
// configuration is loaded from YAML // configuration is loaded from YAML
// field ordering is important -- yaml fields will mirror ordering from here // field ordering is important -- yaml fields will mirror ordering from here
type configuration struct { type configuration struct {
ourConfigFilename string // Config filename (can be overriden via the command line arguments) ourConfigFilename string // Config filename (can be overridden via the command line arguments)
ourBinaryDir string // Location of our directory, used to protect against CWD being somewhere else ourBinaryDir string // Location of our directory, used to protect against CWD being somewhere else
BindHost string `yaml:"bind_host"` BindHost string `yaml:"bind_host"`
@ -41,8 +41,8 @@ type configuration struct {
// field ordering is important -- yaml fields will mirror ordering from here // field ordering is important -- yaml fields will mirror ordering from here
type dnsConfig struct { type dnsConfig struct {
BindHost string `yaml:"bind_host"` BindHost string `yaml:"bind_host"`
Port int `yaml:"port"` Port int `yaml:"port"`
dnsforward.FilteringConfig `yaml:",inline"` dnsforward.FilteringConfig `yaml:",inline"`

View File

@ -15,7 +15,7 @@ import (
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/hmage/golibs/log" "github.com/hmage/golibs/log"
"github.com/miekg/dns" "github.com/miekg/dns"
"gopkg.in/asaskevich/govalidator.v4" govalidator "gopkg.in/asaskevich/govalidator.v4"
) )
const updatePeriod = time.Minute * 30 const updatePeriod = time.Minute * 30
@ -40,27 +40,26 @@ func writeAllConfigsAndReloadDNS() error {
log.Printf("Couldn't write all configs: %s", err) log.Printf("Couldn't write all configs: %s", err)
return err return err
} }
reconfigureDNSServer() return reconfigureDNSServer()
return nil
} }
func httpUpdateConfigReloadDNSReturnOK(w http.ResponseWriter, r *http.Request) { func httpUpdateConfigReloadDNSReturnOK(w http.ResponseWriter, r *http.Request) {
err := writeAllConfigsAndReloadDNS() err := writeAllConfigsAndReloadDNS()
if err != nil { if err != nil {
errortext := fmt.Sprintf("Couldn't write config file: %s", err) errorText := fmt.Sprintf("Couldn't write config file: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, http.StatusInternalServerError) http.Error(w, errorText, http.StatusInternalServerError)
return return
} }
returnOK(w, r) returnOK(w)
} }
func returnOK(w http.ResponseWriter, r *http.Request) { func returnOK(w http.ResponseWriter) {
_, err := fmt.Fprintf(w, "OK\n") _, err := fmt.Fprintf(w, "OK\n")
if err != nil { if err != nil {
errortext := fmt.Sprintf("Couldn't write body: %s", err) errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, http.StatusInternalServerError) http.Error(w, errorText, http.StatusInternalServerError)
} }
} }
@ -79,17 +78,17 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
jsonVal, err := json.Marshal(data) jsonVal, err := json.Marshal(data)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err) errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal) _, err = w.Write(jsonVal)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err) errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
} }
@ -147,7 +146,13 @@ func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
http.Error(w, errorText, http.StatusInternalServerError) http.Error(w, errorText, http.StatusInternalServerError)
return return
} }
reconfigureDNSServer() err = reconfigureDNSServer()
if err != nil {
errorText := fmt.Sprintf("Couldn't reconfigure the DNS server: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
return
}
_, err = fmt.Fprintf(w, "OK %d servers\n", len(hosts)) _, err = fmt.Fprintf(w, "OK %d servers\n", len(hosts))
if err != nil { if err != nil {
errorText := fmt.Sprintf("Couldn't write body: %s", err) errorText := fmt.Sprintf("Couldn't write body: %s", err)
@ -206,7 +211,7 @@ func checkDNS(input string) error {
log.Printf("Checking if DNS %s works...", input) log.Printf("Checking if DNS %s works...", input)
u, err := upstream.AddressToUpstream(input, "", dnsforward.DefaultTimeout) u, err := upstream.AddressToUpstream(input, "", dnsforward.DefaultTimeout)
if err != nil { if err != nil {
return fmt.Errorf("Failed to choose upstream for %s: %s", input, err) return fmt.Errorf("failed to choose upstream for %s: %s", input, err)
} }
req := dns.Msg{} req := dns.Msg{}
@ -243,9 +248,9 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
resp, err := client.Get(versionCheckURL) resp, err := client.Get(versionCheckURL)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err) errorText := fmt.Sprintf("Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, http.StatusBadGateway) http.Error(w, errorText, http.StatusBadGateway)
return return
} }
if resp != nil && resp.Body != nil { if resp != nil && resp.Body != nil {
@ -255,18 +260,18 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
// read the body entirely // read the body entirely
body, err := ioutil.ReadAll(resp.Body) body, err := ioutil.ReadAll(resp.Body)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Couldn't read response body from %s: %s", versionCheckURL, err) errorText := fmt.Sprintf("Couldn't read response body from %s: %s", versionCheckURL, err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, http.StatusBadGateway) http.Error(w, errorText, http.StatusBadGateway)
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(body) _, err = w.Write(body)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Couldn't write body: %s", err) errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, http.StatusInternalServerError) http.Error(w, errorText, http.StatusInternalServerError)
} }
versionCheckLastTime = now versionCheckLastTime = now
@ -299,44 +304,44 @@ func handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
config.RUnlock() config.RUnlock()
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err) errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal) _, err = w.Write(jsonVal)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err) errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
} }
func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) { func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
filter := filter{} f := filter{}
err := json.NewDecoder(r.Body).Decode(&filter) err := json.NewDecoder(r.Body).Decode(&f)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "Failed to parse request body json: %s", err) httpError(w, http.StatusBadRequest, "Failed to parse request body json: %s", err)
return return
} }
if len(filter.URL) == 0 { if len(f.URL) == 0 {
http.Error(w, "URL parameter was not specified", 400) http.Error(w, "URL parameter was not specified", 400)
return return
} }
if valid := govalidator.IsRequestURL(filter.URL); !valid { if valid := govalidator.IsRequestURL(f.URL); !valid {
http.Error(w, "URL parameter is not valid request URL", 400) http.Error(w, "URL parameter is not valid request URL", 400)
return return
} }
// Check for duplicates // Check for duplicates
for i := range config.Filters { for i := range config.Filters {
if config.Filters[i].URL == filter.URL { if config.Filters[i].URL == f.URL {
errorText := fmt.Sprintf("Filter URL already added -- %s", filter.URL) errorText := fmt.Sprintf("Filter URL already added -- %s", f.URL)
log.Println(errorText) log.Println(errorText)
http.Error(w, errorText, http.StatusBadRequest) http.Error(w, errorText, http.StatusBadRequest)
return return
@ -344,42 +349,42 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
} }
// Set necessary properties // Set necessary properties
filter.ID = assignUniqueFilterID() f.ID = assignUniqueFilterID()
filter.Enabled = true f.Enabled = true
// Download the filter contents // Download the filter contents
ok, err := filter.update(true) ok, err := f.update(true)
if err != nil { if err != nil {
errorText := fmt.Sprintf("Couldn't fetch filter from url %s: %s", filter.URL, err) errorText := fmt.Sprintf("Couldn't fetch filter from url %s: %s", f.URL, err)
log.Println(errorText) log.Println(errorText)
http.Error(w, errorText, http.StatusBadRequest) http.Error(w, errorText, http.StatusBadRequest)
return return
} }
if filter.RulesCount == 0 { if f.RulesCount == 0 {
errorText := fmt.Sprintf("Filter at the url %s has no rules (maybe it points to blank page?)", filter.URL) errorText := fmt.Sprintf("Filter at the url %s has no rules (maybe it points to blank page?)", f.URL)
log.Println(errorText) log.Println(errorText)
http.Error(w, errorText, http.StatusBadRequest) http.Error(w, errorText, http.StatusBadRequest)
return return
} }
if !ok { if !ok {
errorText := fmt.Sprintf("Filter at the url %s is invalid (maybe it points to blank page?)", filter.URL) errorText := fmt.Sprintf("Filter at the url %s is invalid (maybe it points to blank page?)", f.URL)
log.Println(errorText) log.Println(errorText)
http.Error(w, errorText, http.StatusBadRequest) http.Error(w, errorText, http.StatusBadRequest)
return return
} }
// Save the filter contents // Save the filter contents
err = filter.save() err = f.save()
if err != nil { if err != nil {
errorText := fmt.Sprintf("Failed to save filter %d due to %s", filter.ID, err) errorText := fmt.Sprintf("Failed to save filter %d due to %s", f.ID, err)
log.Println(errorText) log.Println(errorText)
http.Error(w, errorText, http.StatusBadRequest) http.Error(w, errorText, http.StatusBadRequest)
return return
} }
// URL is deemed valid, append it to filters, update config, write new filter file and tell dns to reload it // URL is deemed valid, append it to filters, update config, write new filter file and tell dns to reload it
// TODO: since we directly feed filters in-memory, revisit if writing configs is always neccessary // TODO: since we directly feed filters in-memory, revisit if writing configs is always necessary
config.Filters = append(config.Filters, filter) config.Filters = append(config.Filters, f)
err = writeAllConfigs() err = writeAllConfigs()
if err != nil { if err != nil {
errorText := fmt.Sprintf("Couldn't write config file: %s", err) errorText := fmt.Sprintf("Couldn't write config file: %s", err)
@ -388,9 +393,14 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
return return
} }
reconfigureDNSServer() err = reconfigureDNSServer()
if err != nil {
errorText := fmt.Sprintf("Couldn't reconfigure the DNS server: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
}
_, err = fmt.Fprintf(w, "OK %d rules\n", filter.RulesCount) _, err = fmt.Fprintf(w, "OK %d rules\n", f.RulesCount)
if err != nil { if err != nil {
errorText := fmt.Sprintf("Couldn't write body: %s", err) errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errorText) log.Println(errorText)
@ -473,7 +483,7 @@ func handleFilteringEnableURL(w http.ResponseWriter, r *http.Request) {
} }
// kick off refresh of rules from new URLs // kick off refresh of rules from new URLs
refreshFiltersIfNeccessary(false) refreshFiltersIfNecessary(false)
httpUpdateConfigReloadDNSReturnOK(w, r) httpUpdateConfigReloadDNSReturnOK(w, r)
} }
@ -529,7 +539,7 @@ func handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
func handleFilteringRefresh(w http.ResponseWriter, r *http.Request) { func handleFilteringRefresh(w http.ResponseWriter, r *http.Request) {
force := r.URL.Query().Get("force") force := r.URL.Query().Get("force")
updated := refreshFiltersIfNeccessary(force != "") updated := refreshFiltersIfNecessary(force != "")
fmt.Fprintf(w, "OK %d filters updated\n", updated) fmt.Fprintf(w, "OK %d filters updated\n", updated)
} }
@ -553,17 +563,17 @@ func handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Request) {
} }
jsonVal, err := json.Marshal(data) jsonVal, err := json.Marshal(data)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err) errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal) _, err = w.Write(jsonVal)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err) errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
} }
@ -574,9 +584,9 @@ func handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Request) {
func handleParentalEnable(w http.ResponseWriter, r *http.Request) { func handleParentalEnable(w http.ResponseWriter, r *http.Request) {
parameters, err := parseParametersFromBody(r.Body) parameters, err := parseParametersFromBody(r.Body)
if err != nil { if err != nil {
errortext := fmt.Sprintf("failed to parse parameters from body: %s", err) errorText := fmt.Sprintf("failed to parse parameters from body: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 400) http.Error(w, errorText, 400)
return return
} }
@ -631,18 +641,18 @@ func handleParentalStatus(w http.ResponseWriter, r *http.Request) {
} }
jsonVal, err := json.Marshal(data) jsonVal, err := json.Marshal(data)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err) errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal) _, err = w.Write(jsonVal)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err) errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
} }
@ -667,18 +677,18 @@ func handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) {
} }
jsonVal, err := json.Marshal(data) jsonVal, err := json.Marshal(data)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err) errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal) _, err = w.Write(jsonVal)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err) errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
} }

18
dhcp.go
View File

@ -58,7 +58,10 @@ func handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
} }
} }
if !newconfig.Enabled { if !newconfig.Enabled {
dhcpServer.Stop() err := dhcpServer.Stop()
if err != nil {
log.Printf("failed to stop the DHCP server: %s", err)
}
} }
config.DHCP = newconfig config.DHCP = newconfig
httpUpdateConfigReloadDNSReturnOK(w, r) httpUpdateConfigReloadDNSReturnOK(w, r)
@ -73,11 +76,6 @@ func handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
return return
} }
type address struct {
IP string
Netmask string
}
type responseInterface struct { type responseInterface struct {
Name string `json:"name"` Name string `json:"name"`
MTU int `json:"mtu"` MTU int `json:"mtu"`
@ -103,9 +101,9 @@ func handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
MTU: ifaces[i].MTU, MTU: ifaces[i].MTU,
HardwareAddr: ifaces[i].HardwareAddr.String(), HardwareAddr: ifaces[i].HardwareAddr.String(),
} }
addrs, err := ifaces[i].Addrs() addrs, errAddrs := ifaces[i].Addrs()
if err != nil { if errAddrs != nil {
httpError(w, http.StatusInternalServerError, "Failed to get addresses for interface %v: %s", ifaces[i].Name, err) httpError(w, http.StatusInternalServerError, "Failed to get addresses for interface %v: %s", ifaces[i].Name, errAddrs)
return return
} }
for _, addr := range addrs { for _, addr := range addrs {
@ -157,7 +155,7 @@ func handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Request) {
} }
func startDHCPServer() error { func startDHCPServer() error {
if config.DHCP.Enabled == false { if !config.DHCP.Enabled {
// not enabled, don't do anything // not enabled, don't do anything
return nil return nil
} }

View File

@ -13,6 +13,8 @@ import (
"github.com/krolaw/dhcp4" "github.com/krolaw/dhcp4"
) )
// CheckIfOtherDHCPServersPresent sends a DHCP request to the specified network interface,
// and waits for a response for a period defined by defaultDiscoverTime
func CheckIfOtherDHCPServersPresent(ifaceName string) (bool, error) { func CheckIfOtherDHCPServersPresent(ifaceName string) (bool, error) {
iface, err := net.InterfaceByName(ifaceName) iface, err := net.InterfaceByName(ifaceName)
if err != nil { if err != nil {
@ -30,8 +32,8 @@ func CheckIfOtherDHCPServersPresent(ifaceName string) (bool, error) {
dst := "255.255.255.255:67" dst := "255.255.255.255:67"
// form a DHCP request packet, try to emulate existing client as much as possible // form a DHCP request packet, try to emulate existing client as much as possible
xId := make([]byte, 8) xID := make([]byte, 8)
n, err := rand.Read(xId) n, err := rand.Read(xID)
if n != 8 && err == nil { if n != 8 && err == nil {
err = fmt.Errorf("Generated less than 8 bytes") err = fmt.Errorf("Generated less than 8 bytes")
} }
@ -57,16 +59,16 @@ func CheckIfOtherDHCPServersPresent(ifaceName string) (bool, error) {
maxUDPsizeRaw := make([]byte, 2) maxUDPsizeRaw := make([]byte, 2)
binary.BigEndian.PutUint16(maxUDPsizeRaw, 1500) binary.BigEndian.PutUint16(maxUDPsizeRaw, 1500)
leaseTimeRaw := make([]byte, 4) leaseTimeRaw := make([]byte, 4)
leaseTime := uint32(math.RoundToEven(time.Duration(time.Hour * 24 * 90).Seconds())) leaseTime := uint32(math.RoundToEven((time.Hour * 24 * 90).Seconds()))
binary.BigEndian.PutUint32(leaseTimeRaw, leaseTime) binary.BigEndian.PutUint32(leaseTimeRaw, leaseTime)
options := []dhcp4.Option{ options := []dhcp4.Option{
{dhcp4.OptionParameterRequestList, requestList}, {Code: dhcp4.OptionParameterRequestList, Value: requestList},
{dhcp4.OptionMaximumDHCPMessageSize, maxUDPsizeRaw}, {Code: dhcp4.OptionMaximumDHCPMessageSize, Value: maxUDPsizeRaw},
{dhcp4.OptionClientIdentifier, append([]byte{0x01}, iface.HardwareAddr...)}, {Code: dhcp4.OptionClientIdentifier, Value: append([]byte{0x01}, iface.HardwareAddr...)},
{dhcp4.OptionIPAddressLeaseTime, leaseTimeRaw}, {Code: dhcp4.OptionIPAddressLeaseTime, Value: leaseTimeRaw},
{dhcp4.OptionHostName, []byte(hostname)}, {Code: dhcp4.OptionHostName, Value: []byte(hostname)},
} }
packet := dhcp4.RequestPacket(dhcp4.Discover, iface.HardwareAddr, nil, xId, false, options) packet := dhcp4.RequestPacket(dhcp4.Discover, iface.HardwareAddr, nil, xID, false, options)
// resolve 0.0.0.0:68 // resolve 0.0.0.0:68
udpAddr, err := net.ResolveUDPAddr("udp4", src) udpAddr, err := net.ResolveUDPAddr("udp4", src)
@ -98,7 +100,7 @@ func CheckIfOtherDHCPServersPresent(ifaceName string) (bool, error) {
} }
// send to 255.255.255.255:67 // send to 255.255.255.255:67
n, err = c.WriteTo(packet, dstAddr) _, err = c.WriteTo(packet, dstAddr)
// spew.Dump(n, err) // spew.Dump(n, err)
if err != nil { if err != nil {
return false, wrapErrPrint(err, "Couldn't send a packet to %s", dst) return false, wrapErrPrint(err, "Couldn't send a packet to %s", dst)

View File

@ -13,6 +13,7 @@ import (
const defaultDiscoverTime = time.Second * 3 const defaultDiscoverTime = time.Second * 3
// Lease contains the necessary information about a DHCP lease
// field ordering is important -- yaml fields will mirror ordering from here // field ordering is important -- yaml fields will mirror ordering from here
type Lease struct { type Lease struct {
HWAddr net.HardwareAddr `json:"mac" yaml:"hwaddr"` HWAddr net.HardwareAddr `json:"mac" yaml:"hwaddr"`
@ -21,6 +22,7 @@ type Lease struct {
Expiry time.Time `json:"expires"` Expiry time.Time `json:"expires"`
} }
// ServerConfig - DHCP server configuration
// field ordering is important -- yaml fields will mirror ordering from here // field ordering is important -- yaml fields will mirror ordering from here
type ServerConfig struct { type ServerConfig struct {
Enabled bool `json:"enabled" yaml:"enabled"` Enabled bool `json:"enabled" yaml:"enabled"`
@ -32,6 +34,7 @@ type ServerConfig struct {
LeaseDuration uint `json:"lease_duration" yaml:"lease_duration"` // in seconds LeaseDuration uint `json:"lease_duration" yaml:"lease_duration"` // in seconds
} }
// Server - the current state of the DHCP server
type Server struct { type Server struct {
conn *filterConn // listening UDP socket conn *filterConn // listening UDP socket
@ -80,6 +83,7 @@ func (s *Server) Start(config *ServerConfig) error {
s.leaseStart, err = parseIPv4(s.RangeStart) s.leaseStart, err = parseIPv4(s.RangeStart)
if err != nil { if err != nil {
s.closeConn() // in case it was already started s.closeConn() // in case it was already started
return wrapErrPrint(err, "Failed to parse range start address %s", s.RangeStart) return wrapErrPrint(err, "Failed to parse range start address %s", s.RangeStart)
} }
@ -137,6 +141,7 @@ func (s *Server) Start(config *ServerConfig) error {
return nil return nil
} }
// Stop closes the listening UDP socket
func (s *Server) Stop() error { func (s *Server) Stop() error {
if s.conn == nil { if s.conn == nil {
// nothing to do, return silently // nothing to do, return silently
@ -174,7 +179,7 @@ func (s *Server) reserveLease(p dhcp4.Packet) (*Lease, error) {
} }
// not assigned a lease, create new one, find IP from LRU // not assigned a lease, create new one, find IP from LRU
log.Tracef("Lease not found for %s: creating new one", hwaddr) log.Tracef("Lease not found for %s: creating new one", hwaddr)
ip, err := s.findFreeIP(p, hwaddr) ip, err := s.findFreeIP(hwaddr)
if err != nil { if err != nil {
return nil, wrapErrPrint(err, "Couldn't find free IP for the lease %s", hwaddr.String()) return nil, wrapErrPrint(err, "Couldn't find free IP for the lease %s", hwaddr.String())
} }
@ -198,7 +203,7 @@ func (s *Server) locateLease(p dhcp4.Packet) *Lease {
return nil return nil
} }
func (s *Server) findFreeIP(p dhcp4.Packet, hwaddr net.HardwareAddr) (net.IP, error) { func (s *Server) findFreeIP(hwaddr net.HardwareAddr) (net.IP, error) {
// if IP pool is nil, lazy initialize it // if IP pool is nil, lazy initialize it
if s.IPpool == nil { if s.IPpool == nil {
s.IPpool = make(map[[4]byte]net.HardwareAddr) s.IPpool = make(map[[4]byte]net.HardwareAddr)
@ -223,7 +228,7 @@ func (s *Server) findFreeIP(p dhcp4.Packet, hwaddr net.HardwareAddr) (net.IP, er
if foundIP == nil { if foundIP == nil {
// TODO: LRU // TODO: LRU
return nil, fmt.Errorf("Couldn't find free entry in IP pool") return nil, fmt.Errorf("couldn't find free entry in IP pool")
} }
s.reserveIP(foundIP, hwaddr) s.reserveIP(foundIP, hwaddr)
@ -249,6 +254,7 @@ func (s *Server) unreserveIP(ip net.IP) {
delete(s.IPpool, IP4) delete(s.IPpool, IP4)
} }
// ServeDHCP handles an incoming DHCP request
func (s *Server) ServeDHCP(p dhcp4.Packet, msgType dhcp4.MessageType, options dhcp4.Options) dhcp4.Packet { func (s *Server) ServeDHCP(p dhcp4.Packet, msgType dhcp4.MessageType, options dhcp4.Options) dhcp4.Packet {
log.Tracef("Got %v message", msgType) log.Tracef("Got %v message", msgType)
log.Tracef("Leases:") log.Tracef("Leases:")
@ -259,27 +265,6 @@ func (s *Server) ServeDHCP(p dhcp4.Packet, msgType dhcp4.MessageType, options dh
for ip, hwaddr := range s.IPpool { for ip, hwaddr := range s.IPpool {
log.Tracef("IP pool entry %s -> %s", net.IPv4(ip[0], ip[1], ip[2], ip[3]), hwaddr) log.Tracef("IP pool entry %s -> %s", net.IPv4(ip[0], ip[1], ip[2], ip[3]), hwaddr)
} }
// spew.Dump(s.leases, s.IPpool)
// log.Printf("Called with msgType = %v, options = %+v", msgType, options)
// spew.Dump(p)
// log.Printf("%14s %v", "p.Broadcast", p.Broadcast()) // false
// log.Printf("%14s %v", "p.CHAddr", p.CHAddr()) // 2c:f0:a2:f2:31:00
// log.Printf("%14s %v", "p.CIAddr", p.CIAddr()) // 0.0.0.0
// log.Printf("%14s %v", "p.Cookie", p.Cookie()) // [99 130 83 99]
// log.Printf("%14s %v", "p.File", p.File()) // []
// log.Printf("%14s %v", "p.Flags", p.Flags()) // [0 0]
// log.Printf("%14s %v", "p.GIAddr", p.GIAddr()) // 0.0.0.0
// log.Printf("%14s %v", "p.HLen", p.HLen()) // 6
// log.Printf("%14s %v", "p.HType", p.HType()) // 1
// log.Printf("%14s %v", "p.Hops", p.Hops()) // 0
// log.Printf("%14s %v", "p.OpCode", p.OpCode()) // BootRequest
// log.Printf("%14s %v", "p.Options", p.Options()) // [53 1 1 55 10 1 121 3 6 15 119 252 95 44 46 57 2 5 220 61 7 1 44 240 162 242 49 0 51 4 0 118 167 0 12 4 119 104 109 100 255 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
// log.Printf("%14s %v", "p.ParseOptions", p.ParseOptions()) // map[OptionParameterRequestList:[1 121 3 6 15 119 252 95 44 46] OptionDHCPMessageType:[1] OptionMaximumDHCPMessageSize:[5 220] OptionClientIdentifier:[1 44 240 162 242 49 0] OptionIPAddressLeaseTime:[0 118 167 0] OptionHostName:[119 104 109 100]]
// log.Printf("%14s %v", "p.SIAddr", p.SIAddr()) // 0.0.0.0
// log.Printf("%14s %v", "p.SName", p.SName()) // []
// log.Printf("%14s %v", "p.Secs", p.Secs()) // [0 8]
// log.Printf("%14s %v", "p.XId", p.XId()) // [211 184 20 44]
// log.Printf("%14s %v", "p.YIAddr", p.YIAddr()) // 0.0.0.0
switch msgType { switch msgType {
case dhcp4.Discover: // Broadcast Packet From Client - Can I have an IP? case dhcp4.Discover: // Broadcast Packet From Client - Can I have an IP?
@ -297,75 +282,7 @@ func (s *Server) ServeDHCP(p dhcp4.Packet, msgType dhcp4.MessageType, options dh
case dhcp4.Request: // Broadcast From Client - I'll take that IP (Also start for renewals) case dhcp4.Request: // Broadcast From Client - I'll take that IP (Also start for renewals)
// start/renew a lease -- update lease time // start/renew a lease -- update lease time
// some clients (OSX) just go right ahead and do Request first from previously known IP, if they get NAK, they restart full cycle with Discover then Request // some clients (OSX) just go right ahead and do Request first from previously known IP, if they get NAK, they restart full cycle with Discover then Request
log.Tracef("Got from client: Request") return s.handleDHCP4Request(p, options)
if server, ok := options[dhcp4.OptionServerIdentifier]; ok && !net.IP(server).Equal(s.ipnet.IP) {
log.Tracef("Request message not for this DHCP server (%v vs %v)", server, s.ipnet.IP)
return nil // Message not for this dhcp server
}
reqIP := net.IP(options[dhcp4.OptionRequestedIPAddress])
if reqIP == nil {
reqIP = net.IP(p.CIAddr())
}
if reqIP.To4() == nil {
log.Tracef("Replying with NAK: request IP isn't valid IPv4: %s", reqIP)
return dhcp4.ReplyPacket(p, dhcp4.NAK, s.ipnet.IP, nil, 0, nil)
}
if reqIP.Equal(net.IPv4zero) {
log.Tracef("Replying with NAK: request IP is 0.0.0.0")
return dhcp4.ReplyPacket(p, dhcp4.NAK, s.ipnet.IP, nil, 0, nil)
}
log.Tracef("requested IP is %s", reqIP)
lease, err := s.reserveLease(p)
if err != nil {
log.Tracef("Couldn't find free lease: %s", err)
// couldn't find lease, don't respond
return nil
}
if lease.IP.Equal(reqIP) {
// IP matches lease IP, nothing else to do
lease.Expiry = time.Now().Add(s.leaseTime)
log.Tracef("Replying with ACK: request IP matches lease IP, nothing else to do. IP %v for %v", lease.IP, p.CHAddr())
return dhcp4.ReplyPacket(p, dhcp4.ACK, s.ipnet.IP, lease.IP, s.leaseTime, s.leaseOptions.SelectOrderOrAll(options[dhcp4.OptionParameterRequestList]))
}
//
// requested IP different from lease
//
log.Tracef("lease IP is different from requested IP: %s vs %s", lease.IP, reqIP)
hwaddr := s.getIPpool(reqIP)
if hwaddr == nil {
// not in pool, check if it's in DHCP range
if dhcp4.IPInRange(s.leaseStart, s.leaseStop, reqIP) {
// okay, we can give it to our client -- it's in our DHCP range and not taken, so let them use their IP
log.Tracef("Replying with ACK: request IP %v is not taken, so assigning lease IP %v to it, for %v", reqIP, lease.IP, p.CHAddr())
s.unreserveIP(lease.IP)
lease.IP = reqIP
s.reserveIP(reqIP, p.CHAddr())
lease.Expiry = time.Now().Add(s.leaseTime)
return dhcp4.ReplyPacket(p, dhcp4.ACK, s.ipnet.IP, lease.IP, s.leaseTime, s.leaseOptions.SelectOrderOrAll(options[dhcp4.OptionParameterRequestList]))
}
}
if hwaddr != nil && !bytes.Equal(hwaddr, lease.HWAddr) {
log.Printf("SHOULD NOT HAPPEN: IP pool hwaddr does not match lease hwaddr: %s vs %s", hwaddr, lease.HWAddr)
}
// requsted IP is not sufficient, reply with NAK
if hwaddr != nil {
log.Tracef("Replying with NAK: request IP %s is taken, asked by %v", reqIP, p.CHAddr())
return dhcp4.ReplyPacket(p, dhcp4.NAK, s.ipnet.IP, nil, 0, nil)
}
// requested IP is outside of DHCP range
log.Tracef("Replying with NAK: request IP %s is outside of DHCP range [%s, %s], asked by %v", reqIP, s.leaseStart, s.leaseStop, p.CHAddr())
return dhcp4.ReplyPacket(p, dhcp4.NAK, s.ipnet.IP, nil, 0, nil)
case dhcp4.Decline: // Broadcast From Client - Sorry I can't use that IP case dhcp4.Decline: // Broadcast From Client - Sorry I can't use that IP
log.Tracef("Got from client: Decline") log.Tracef("Got from client: Decline")
@ -390,6 +307,79 @@ func (s *Server) ServeDHCP(p dhcp4.Packet, msgType dhcp4.MessageType, options dh
return nil return nil
} }
func (s *Server) handleDHCP4Request(p dhcp4.Packet, options dhcp4.Options) dhcp4.Packet {
log.Tracef("Got from client: Request")
if server, ok := options[dhcp4.OptionServerIdentifier]; ok && !net.IP(server).Equal(s.ipnet.IP) {
log.Tracef("Request message not for this DHCP server (%v vs %v)", server, s.ipnet.IP)
return nil // Message not for this dhcp server
}
reqIP := net.IP(options[dhcp4.OptionRequestedIPAddress])
if reqIP == nil {
reqIP = p.CIAddr()
}
if reqIP.To4() == nil {
log.Tracef("Replying with NAK: request IP isn't valid IPv4: %s", reqIP)
return dhcp4.ReplyPacket(p, dhcp4.NAK, s.ipnet.IP, nil, 0, nil)
}
if reqIP.Equal(net.IPv4zero) {
log.Tracef("Replying with NAK: request IP is 0.0.0.0")
return dhcp4.ReplyPacket(p, dhcp4.NAK, s.ipnet.IP, nil, 0, nil)
}
log.Tracef("requested IP is %s", reqIP)
lease, err := s.reserveLease(p)
if err != nil {
log.Tracef("Couldn't find free lease: %s", err)
// couldn't find lease, don't respond
return nil
}
if lease.IP.Equal(reqIP) {
// IP matches lease IP, nothing else to do
lease.Expiry = time.Now().Add(s.leaseTime)
log.Tracef("Replying with ACK: request IP matches lease IP, nothing else to do. IP %v for %v", lease.IP, p.CHAddr())
return dhcp4.ReplyPacket(p, dhcp4.ACK, s.ipnet.IP, lease.IP, s.leaseTime, s.leaseOptions.SelectOrderOrAll(options[dhcp4.OptionParameterRequestList]))
}
//
// requested IP different from lease
//
log.Tracef("lease IP is different from requested IP: %s vs %s", lease.IP, reqIP)
hwaddr := s.getIPpool(reqIP)
if hwaddr == nil {
// not in pool, check if it's in DHCP range
if dhcp4.IPInRange(s.leaseStart, s.leaseStop, reqIP) {
// okay, we can give it to our client -- it's in our DHCP range and not taken, so let them use their IP
log.Tracef("Replying with ACK: request IP %v is not taken, so assigning lease IP %v to it, for %v", reqIP, lease.IP, p.CHAddr())
s.unreserveIP(lease.IP)
lease.IP = reqIP
s.reserveIP(reqIP, p.CHAddr())
lease.Expiry = time.Now().Add(s.leaseTime)
return dhcp4.ReplyPacket(p, dhcp4.ACK, s.ipnet.IP, lease.IP, s.leaseTime, s.leaseOptions.SelectOrderOrAll(options[dhcp4.OptionParameterRequestList]))
}
}
if hwaddr != nil && !bytes.Equal(hwaddr, lease.HWAddr) {
log.Printf("SHOULD NOT HAPPEN: IP pool hwaddr does not match lease hwaddr: %s vs %s", hwaddr, lease.HWAddr)
}
// requsted IP is not sufficient, reply with NAK
if hwaddr != nil {
log.Tracef("Replying with NAK: request IP %s is taken, asked by %v", reqIP, p.CHAddr())
return dhcp4.ReplyPacket(p, dhcp4.NAK, s.ipnet.IP, nil, 0, nil)
}
// requested IP is outside of DHCP range
log.Tracef("Replying with NAK: request IP %s is outside of DHCP range [%s, %s], asked by %v", reqIP, s.leaseStart, s.leaseStop, p.CHAddr())
return dhcp4.ReplyPacket(p, dhcp4.NAK, s.ipnet.IP, nil, 0, nil)
}
// Leases returns the list of current DHCP leases
func (s *Server) Leases() []*Lease { func (s *Server) Leases() []*Lease {
s.RLock() s.RLock()
result := s.leases result := s.leases

View File

@ -8,7 +8,7 @@ import (
) )
// filterConn listens to 0.0.0.0:67, but accepts packets only from specific interface // filterConn listens to 0.0.0.0:67, but accepts packets only from specific interface
// This is neccessary for DHCP daemon to work, since binding to IP address doesn't // This is necessary for DHCP daemon to work, since binding to IP address doesn't
// us access to see Discover/Request packets from clients. // us access to see Discover/Request packets from clients.
// //
// TODO: on windows, controlmessage does not work, try to find out another way // TODO: on windows, controlmessage does not work, try to find out another way
@ -49,7 +49,6 @@ func (f *filterConn) ReadFrom(b []byte) (int, net.Addr, error) {
} }
// packet doesn't match criteria, drop it // packet doesn't match criteria, drop it
} }
return 0, nil, nil
} }
func (f *filterConn) WriteTo(b []byte, addr net.Addr) (int, error) { func (f *filterConn) WriteTo(b []byte, addr net.Addr) (int, error) {

View File

@ -3,7 +3,6 @@ package dhcpd
import ( import (
"fmt" "fmt"
"net" "net"
"strings"
"github.com/hmage/golibs/log" "github.com/hmage/golibs/log"
"github.com/joomcode/errorx" "github.com/joomcode/errorx"
@ -45,22 +44,6 @@ func getIfaceIPv4(iface *net.Interface) *net.IPNet {
return nil return nil
} }
func isConnClosed(err error) bool {
if err == nil {
return false
}
nerr, ok := err.(*net.OpError)
if !ok {
return false
}
if strings.Contains(nerr.Err.Error(), "use of closed network connection") {
return true
}
return false
}
func wrapErrPrint(err error, message string, args ...interface{}) error { func wrapErrPrint(err error, message string, args ...interface{}) error {
var errx error var errx error
if err == nil { if err == nil {

View File

@ -76,9 +76,9 @@ func main() {
panic(err) panic(err)
} }
log.Printf("Now serving DHCP") log.Printf("Now serving DHCP")
signal_channel := make(chan os.Signal) signalChannel := make(chan os.Signal)
signal.Notify(signal_channel, syscall.SIGINT, syscall.SIGTERM) signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM)
<-signal_channel <-signalChannel
} }

6
dns.go
View File

@ -39,20 +39,20 @@ func generateServerConfig() dnsforward.ServerConfig {
} }
for _, u := range config.DNS.UpstreamDNS { for _, u := range config.DNS.UpstreamDNS {
upstream, err := upstream.AddressToUpstream(u, config.DNS.BootstrapDNS, dnsforward.DefaultTimeout) dnsUpstream, err := upstream.AddressToUpstream(u, config.DNS.BootstrapDNS, dnsforward.DefaultTimeout)
if err != nil { if err != nil {
log.Printf("Couldn't get upstream: %s", err) log.Printf("Couldn't get upstream: %s", err)
// continue, just ignore the upstream // continue, just ignore the upstream
continue continue
} }
newconfig.Upstreams = append(newconfig.Upstreams, upstream) newconfig.Upstreams = append(newconfig.Upstreams, dnsUpstream)
} }
return newconfig return newconfig
} }
func startDNSServer() error { func startDNSServer() error {
if isRunning() { if isRunning() {
return fmt.Errorf("Unable to start forwarding DNS server: Already running") return fmt.Errorf("unable to start forwarding DNS server: Already running")
} }
newconfig := generateServerConfig() newconfig := generateServerConfig()

View File

@ -35,7 +35,7 @@ const defaultParentalURL = "http://%s/check-parental-control-hash?prefixes=%s&se
// ErrInvalidSyntax is returned by AddRule when the rule is invalid // ErrInvalidSyntax is returned by AddRule when the rule is invalid
var ErrInvalidSyntax = errors.New("dnsfilter: invalid rule syntax") var ErrInvalidSyntax = errors.New("dnsfilter: invalid rule syntax")
// ErrInvalidSyntax is returned by AddRule when the rule was already added to the filter // ErrAlreadyExists is returned by AddRule when the rule was already added to the filter
var ErrAlreadyExists = errors.New("dnsfilter: rule was already added") var ErrAlreadyExists = errors.New("dnsfilter: rule was already added")
const shortcutLength = 6 // used for rule search optimization, 6 hits the sweet spot const shortcutLength = 6 // used for rule search optimization, 6 hits the sweet spot
@ -115,6 +115,7 @@ type Dnsfilter struct {
privateConfig privateConfig
} }
// Filter represents a filter list
type Filter struct { type Filter struct {
ID int64 `json:"id"` // auto-assigned when filter is added (see nextFilterID), json by default keeps ID uppercase but we need lowercase ID int64 `json:"id"` // auto-assigned when filter is added (see nextFilterID), json by default keeps ID uppercase but we need lowercase
Rules []string `json:"-" yaml:"-"` // not in yaml or json Rules []string `json:"-" yaml:"-"` // not in yaml or json
@ -127,16 +128,26 @@ type Reason int
const ( const (
// reasons for not filtering // reasons for not filtering
NotFilteredNotFound Reason = iota // host was not find in any checks, default value for result
NotFilteredWhiteList // the host is explicitly whitelisted // NotFilteredNotFound - host was not find in any checks, default value for result
NotFilteredError // there was a transitive error during check NotFilteredNotFound Reason = iota
// NotFilteredWhiteList - the host is explicitly whitelisted
NotFilteredWhiteList
// NotFilteredError - there was a transitive error during check
NotFilteredError
// reasons for filtering // reasons for filtering
FilteredBlackList // the host was matched to be advertising host
FilteredSafeBrowsing // the host was matched to be malicious/phishing // FilteredBlackList - the host was matched to be advertising host
FilteredParental // the host was matched to be outside of parental control settings FilteredBlackList
FilteredInvalid // the request was invalid and was not processed // FilteredSafeBrowsing - the host was matched to be malicious/phishing
FilteredSafeSearch // the host was replaced with safesearch variant FilteredSafeBrowsing
// FilteredParental - the host was matched to be outside of parental control settings
FilteredParental
// FilteredInvalid - the request was invalid and was not processed
FilteredInvalid
// FilteredSafeSearch - the host was replaced with safesearch variant
FilteredSafeSearch
) )
// these variables need to survive coredns reload // these variables need to survive coredns reload
@ -151,7 +162,7 @@ type Result struct {
IsFiltered bool `json:",omitempty"` // True if the host name is filtered IsFiltered bool `json:",omitempty"` // True if the host name is filtered
Reason Reason `json:",omitempty"` // Reason for blocking / unblocking Reason Reason `json:",omitempty"` // Reason for blocking / unblocking
Rule string `json:",omitempty"` // Original rule text Rule string `json:",omitempty"` // Original rule text
Ip net.IP `json:",omitempty"` // Not nil only in the case of a hosts file syntax IP net.IP `json:",omitempty"` // Not nil only in the case of a hosts file syntax
FilterID int64 `json:",omitempty"` // Filter ID the rule belongs to FilterID int64 `json:",omitempty"` // Filter ID the rule belongs to
} }
@ -228,7 +239,6 @@ func newRulesTable() *rulesTable {
func (r *rulesTable) Add(rule *rule) { func (r *rulesTable) Add(rule *rule) {
r.Lock() r.Lock()
if rule.ip != nil { if rule.ip != nil {
// Hosts syntax // Hosts syntax
r.rulesByHost[rule.text] = rule r.rulesByHost[rule.text] = rule
@ -476,7 +486,7 @@ func (rule *rule) match(host string) (Result, error) {
IsFiltered: true, IsFiltered: true,
Reason: FilteredBlackList, Reason: FilteredBlackList,
Rule: rule.originalText, Rule: rule.originalText,
Ip: rule.ip, IP: rule.ip,
FilterID: rule.listID, FilterID: rule.listID,
}, nil }, nil
} }
@ -661,8 +671,11 @@ func (d *Dnsfilter) checkParental(host string) (Result, error) {
return result, err return result, err
} }
type formatHandler func(hashparam string) string
type bodyHandler func(body []byte, hashes map[string]bool) (Result, error)
// real implementation of lookup/check // real implementation of lookup/check
func (d *Dnsfilter) lookupCommon(host string, lookupstats *LookupStats, cache gcache.Cache, hashparamNeedSlash bool, format func(hashparam string) string, handleBody func(body []byte, hashes map[string]bool) (Result, error)) (Result, error) { func (d *Dnsfilter) lookupCommon(host string, lookupstats *LookupStats, cache gcache.Cache, hashparamNeedSlash bool, format formatHandler, handleBody bodyHandler) (Result, error) {
// if host ends with a dot, trim it // if host ends with a dot, trim it
host = strings.ToLower(strings.Trim(host, ".")) host = strings.ToLower(strings.Trim(host, "."))
@ -774,43 +787,43 @@ func (d *Dnsfilter) AddRule(input string, filterListID int64) error {
} }
// Start parsing the rule // Start parsing the rule
rule := rule{ r := rule{
text: input, // will be modified text: input, // will be modified
originalText: input, originalText: input,
listID: filterListID, listID: filterListID,
} }
// Mark rule as whitelist if it starts with @@ // Mark rule as whitelist if it starts with @@
if strings.HasPrefix(rule.text, "@@") { if strings.HasPrefix(r.text, "@@") {
rule.isWhitelist = true r.isWhitelist = true
rule.text = rule.text[2:] r.text = r.text[2:]
} }
err := rule.parseOptions() err := r.parseOptions()
if err != nil { if err != nil {
return err return err
} }
rule.extractShortcut() r.extractShortcut()
if !enableDelayedCompilation { if !enableDelayedCompilation {
err := rule.compile() err := r.compile()
if err != nil { if err != nil {
return err return err
} }
} }
destination := d.blackList destination := d.blackList
if rule.isImportant { if r.isImportant {
destination = d.important destination = d.important
} else if rule.isWhitelist { } else if r.isWhitelist {
destination = d.whiteList destination = d.whiteList
} }
d.storageMutex.Lock() d.storageMutex.Lock()
d.storage[input] = true d.storage[input] = true
d.storageMutex.Unlock() d.storageMutex.Unlock()
destination.Add(&rule) destination.Add(&r)
return nil return nil
} }
@ -835,13 +848,13 @@ func (d *Dnsfilter) parseEtcHosts(input string, filterListID int64) bool {
d.storageMutex.Unlock() d.storageMutex.Unlock()
for _, host := range fields[1:] { for _, host := range fields[1:] {
rule := rule{ r := rule{
text: host, text: host,
originalText: input, originalText: input,
listID: filterListID, listID: filterListID,
ip: addr, ip: addr,
} }
d.blackList.Add(&rule) d.blackList.Add(&r)
} }
return true return true
} }
@ -913,15 +926,6 @@ func (d *Dnsfilter) Destroy() {
// config manipulation helpers // config manipulation helpers
// //
// IsParentalSensitivityValid checks if sensitivity is valid value
func IsParentalSensitivityValid(sensitivity int) bool {
switch sensitivity {
case 3, 10, 13, 17:
return true
}
return false
}
// SetSafeBrowsingServer lets you optionally change hostname of safesearch lookup // SetSafeBrowsingServer lets you optionally change hostname of safesearch lookup
func (d *Dnsfilter) SetSafeBrowsingServer(host string) { func (d *Dnsfilter) SetSafeBrowsingServer(host string) {
if len(host) == 0 { if len(host) == 0 {

View File

@ -26,7 +26,7 @@ import (
func TestLotsOfRulesMemoryUsage(t *testing.T) { func TestLotsOfRulesMemoryUsage(t *testing.T) {
start := getRSS() start := getRSS()
log.Tracef("RSS before loading rules - %d kB\n", start/1024) log.Tracef("RSS before loading rules - %d kB\n", start/1024)
dumpMemProfile(_Func() + "1.pprof") dumpMemProfile("tests/" + _Func() + "1.pprof")
d := NewForTest() d := NewForTest()
defer d.Destroy() defer d.Destroy()
@ -37,7 +37,7 @@ func TestLotsOfRulesMemoryUsage(t *testing.T) {
afterLoad := getRSS() afterLoad := getRSS()
log.Tracef("RSS after loading rules - %d kB (%d kB diff)\n", afterLoad/1024, (afterLoad-start)/1024) log.Tracef("RSS after loading rules - %d kB (%d kB diff)\n", afterLoad/1024, (afterLoad-start)/1024)
dumpMemProfile(_Func() + "2.pprof") dumpMemProfile("tests/" + _Func() + "2.pprof")
tests := []struct { tests := []struct {
host string host string
@ -60,7 +60,7 @@ func TestLotsOfRulesMemoryUsage(t *testing.T) {
} }
afterMatch := getRSS() afterMatch := getRSS()
log.Tracef("RSS after matching - %d kB (%d kB diff)\n", afterMatch/1024, (afterMatch-afterLoad)/1024) log.Tracef("RSS after matching - %d kB (%d kB diff)\n", afterMatch/1024, (afterMatch-afterLoad)/1024)
dumpMemProfile(_Func() + "3.pprof") dumpMemProfile("tests/" + _Func() + "3.pprof")
} }
func getRSS() uint64 { func getRSS() uint64 {
@ -69,6 +69,9 @@ func getRSS() uint64 {
panic(err) panic(err)
} }
minfo, err := proc.MemoryInfo() minfo, err := proc.MemoryInfo()
if err != nil {
panic(err)
}
return minfo.RSS return minfo.RSS
} }
@ -86,7 +89,7 @@ func dumpMemProfile(name string) {
} }
} }
const topHostsFilename = "../tests/top-1m.csv" const topHostsFilename = "tests/top-1m.csv"
func fetchTopHostsFromNet() { func fetchTopHostsFromNet() {
log.Tracef("Fetching top hosts from network") log.Tracef("Fetching top hosts from network")
@ -146,7 +149,7 @@ func getTopHosts() {
func TestLotsOfRulesLotsOfHostsMemoryUsage(t *testing.T) { func TestLotsOfRulesLotsOfHostsMemoryUsage(t *testing.T) {
start := getRSS() start := getRSS()
log.Tracef("RSS before loading rules - %d kB\n", start/1024) log.Tracef("RSS before loading rules - %d kB\n", start/1024)
dumpMemProfile(_Func() + "1.pprof") dumpMemProfile("tests/" + _Func() + "1.pprof")
d := NewForTest() d := NewForTest()
defer d.Destroy() defer d.Destroy()
@ -155,7 +158,7 @@ func TestLotsOfRulesLotsOfHostsMemoryUsage(t *testing.T) {
afterLoad := getRSS() afterLoad := getRSS()
log.Tracef("RSS after loading rules - %d kB (%d kB diff)\n", afterLoad/1024, (afterLoad-start)/1024) log.Tracef("RSS after loading rules - %d kB (%d kB diff)\n", afterLoad/1024, (afterLoad-start)/1024)
dumpMemProfile(_Func() + "2.pprof") dumpMemProfile("tests/" + _Func() + "2.pprof")
getTopHosts() getTopHosts()
hostnames, err := os.Open(topHostsFilename) hostnames, err := os.Open(topHostsFilename)
@ -165,7 +168,7 @@ func TestLotsOfRulesLotsOfHostsMemoryUsage(t *testing.T) {
defer hostnames.Close() defer hostnames.Close()
afterHosts := getRSS() afterHosts := getRSS()
log.Tracef("RSS after loading hosts - %d kB (%d kB diff)\n", afterHosts/1024, (afterHosts-afterLoad)/1024) log.Tracef("RSS after loading hosts - %d kB (%d kB diff)\n", afterHosts/1024, (afterHosts-afterLoad)/1024)
dumpMemProfile(_Func() + "2.pprof") dumpMemProfile("tests/" + _Func() + "2.pprof")
{ {
scanner := bufio.NewScanner(hostnames) scanner := bufio.NewScanner(hostnames)
@ -184,7 +187,7 @@ func TestLotsOfRulesLotsOfHostsMemoryUsage(t *testing.T) {
afterMatch := getRSS() afterMatch := getRSS()
log.Tracef("RSS after matching - %d kB (%d kB diff)\n", afterMatch/1024, (afterMatch-afterLoad)/1024) log.Tracef("RSS after matching - %d kB (%d kB diff)\n", afterMatch/1024, (afterMatch-afterLoad)/1024)
dumpMemProfile(_Func() + "3.pprof") dumpMemProfile("tests/" + _Func() + "3.pprof")
} }
func TestRuleToRegexp(t *testing.T) { func TestRuleToRegexp(t *testing.T) {
@ -282,7 +285,7 @@ func (d *Dnsfilter) checkMatch(t *testing.T, hostname string) {
} }
} }
func (d *Dnsfilter) checkMatchIp(t *testing.T, hostname string, ip string) { func (d *Dnsfilter) checkMatchIP(t *testing.T, hostname string, ip string) {
t.Helper() t.Helper()
ret, err := d.CheckHost(hostname) ret, err := d.CheckHost(hostname)
if err != nil { if err != nil {
@ -291,8 +294,8 @@ func (d *Dnsfilter) checkMatchIp(t *testing.T, hostname string, ip string) {
if !ret.IsFiltered { if !ret.IsFiltered {
t.Errorf("Expected hostname %s to match", hostname) t.Errorf("Expected hostname %s to match", hostname)
} }
if ret.Ip == nil || ret.Ip.String() != ip { if ret.IP == nil || ret.IP.String() != ip {
t.Errorf("Expected ip %s to match, actual: %v", ip, ret.Ip) t.Errorf("Expected ip %s to match, actual: %v", ip, ret.IP)
} }
} }
@ -308,7 +311,7 @@ func (d *Dnsfilter) checkMatchEmpty(t *testing.T, hostname string) {
} }
func loadTestRules(d *Dnsfilter) error { func loadTestRules(d *Dnsfilter) error {
filterFileName := "../tests/dns.txt" filterFileName := "tests/dns.txt"
file, err := os.Open(filterFileName) file, err := os.Open(filterFileName)
if err != nil { if err != nil {
return err return err
@ -368,8 +371,8 @@ func TestEtcHostsMatching(t *testing.T) {
text := fmt.Sprintf(" %s google.com www.google.com # enforce google's safesearch ", addr) text := fmt.Sprintf(" %s google.com www.google.com # enforce google's safesearch ", addr)
d.checkAddRule(t, text) d.checkAddRule(t, text)
d.checkMatchIp(t, "google.com", addr) d.checkMatchIP(t, "google.com", addr)
d.checkMatchIp(t, "www.google.com", addr) d.checkMatchIP(t, "www.google.com", addr)
d.checkMatchEmpty(t, "subdomain.google.com") d.checkMatchEmpty(t, "subdomain.google.com")
d.checkMatchEmpty(t, "example.org") d.checkMatchEmpty(t, "example.org")
} }
@ -866,7 +869,7 @@ func BenchmarkLotsOfRulesLotsOfHosts(b *testing.B) {
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
havedata := scanner.Scan() havedata := scanner.Scan()
if !havedata { if !havedata {
hostnames.Seek(0, 0) _, _ = hostnames.Seek(0, 0)
scanner = bufio.NewScanner(hostnames) scanner = bufio.NewScanner(hostnames)
havedata = scanner.Scan() havedata = scanner.Scan()
} }
@ -903,7 +906,7 @@ func BenchmarkLotsOfRulesLotsOfHostsParallel(b *testing.B) {
for pb.Next() { for pb.Next() {
havedata := scanner.Scan() havedata := scanner.Scan()
if !havedata { if !havedata {
hostnames.Seek(0, 0) _, _ = hostnames.Seek(0, 0)
scanner = bufio.NewScanner(hostnames) scanner = bufio.NewScanner(hostnames)
havedata = scanner.Scan() havedata = scanner.Scan()
} }

View File

@ -310,8 +310,8 @@ func (s *Server) genDNSFilterMessage(d *proxy.DNSContext, result *dnsfilter.Resu
case dnsfilter.FilteredParental: case dnsfilter.FilteredParental:
return s.genBlockedHost(m, parentalBlockHost, d.Upstream) return s.genBlockedHost(m, parentalBlockHost, d.Upstream)
default: default:
if result.Ip != nil { if result.IP != nil {
return s.genARecord(m, result.Ip) return s.genARecord(m, result.IP)
} }
return s.genNXDomain(m) return s.genNXDomain(m)

View File

@ -108,10 +108,16 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result *dnsfilter.Result, el
if len(flushBuffer) > 0 { if len(flushBuffer) > 0 {
// write to file // write to file
// do it in separate goroutine -- we are stalling DNS response this whole time // do it in separate goroutine -- we are stalling DNS response this whole time
go flushToFile(flushBuffer) go func() {
err := flushToFile(flushBuffer)
if err != nil {
log.Printf("Failed to flush the query log: %s", err)
}
}()
} }
} }
// HandleQueryLog handles query log web request
func HandleQueryLog(w http.ResponseWriter, r *http.Request) { func HandleQueryLog(w http.ResponseWriter, r *http.Request) {
queryLogLock.RLock() queryLogLock.RLock()
values := make([]*logEntry, len(queryLogCache)) values := make([]*logEntry, len(queryLogCache))
@ -123,6 +129,7 @@ func HandleQueryLog(w http.ResponseWriter, r *http.Request) {
values[left], values[right] = values[right], values[left] values[left], values[right] = values[right], values[left]
} }
// iterate
var data = []map[string]interface{}{} var data = []map[string]interface{}{}
for _, entry := range values { for _, entry := range values {
var q *dns.Msg var q *dns.Msg
@ -167,46 +174,8 @@ func HandleQueryLog(w http.ResponseWriter, r *http.Request) {
jsonEntry["filterId"] = entry.Result.FilterID jsonEntry["filterId"] = entry.Result.FilterID
} }
if a != nil && len(a.Answer) > 0 { answers := answerToMap(a)
var answers = []map[string]interface{}{} if answers != nil {
for _, k := range a.Answer {
header := k.Header()
answer := map[string]interface{}{
"type": dns.TypeToString[header.Rrtype],
"ttl": header.Ttl,
}
// try most common record types
switch v := k.(type) {
case *dns.A:
answer["value"] = v.A
case *dns.AAAA:
answer["value"] = v.AAAA
case *dns.MX:
answer["value"] = fmt.Sprintf("%v %v", v.Preference, v.Mx)
case *dns.CNAME:
answer["value"] = v.Target
case *dns.NS:
answer["value"] = v.Ns
case *dns.SPF:
answer["value"] = v.Txt
case *dns.TXT:
answer["value"] = v.Txt
case *dns.PTR:
answer["value"] = v.Ptr
case *dns.SOA:
answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v", v.Ns, v.Mbox, v.Serial, v.Refresh, v.Retry, v.Expire, v.Minttl)
case *dns.CAA:
answer["value"] = fmt.Sprintf("%v %v \"%v\"", v.Flag, v.Tag, v.Value)
case *dns.HINFO:
answer["value"] = fmt.Sprintf("\"%v\" \"%v\"", v.Cpu, v.Os)
case *dns.RRSIG:
answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v %v %v", dns.TypeToString[v.TypeCovered], v.Algorithm, v.Labels, v.OrigTtl, v.Expiration, v.Inception, v.KeyTag, v.SignerName, v.Signature)
default:
// type unknown, marshall it as-is
answer["value"] = v
}
answers = append(answers, answer)
}
jsonEntry["answer"] = answers jsonEntry["answer"] = answers
} }
@ -230,6 +199,54 @@ func HandleQueryLog(w http.ResponseWriter, r *http.Request) {
} }
} }
func answerToMap(a *dns.Msg) []map[string]interface{} {
if a == nil || len(a.Answer) == 0 {
return nil
}
var answers = []map[string]interface{}{}
for _, k := range a.Answer {
header := k.Header()
answer := map[string]interface{}{
"type": dns.TypeToString[header.Rrtype],
"ttl": header.Ttl,
}
// try most common record types
switch v := k.(type) {
case *dns.A:
answer["value"] = v.A
case *dns.AAAA:
answer["value"] = v.AAAA
case *dns.MX:
answer["value"] = fmt.Sprintf("%v %v", v.Preference, v.Mx)
case *dns.CNAME:
answer["value"] = v.Target
case *dns.NS:
answer["value"] = v.Ns
case *dns.SPF:
answer["value"] = v.Txt
case *dns.TXT:
answer["value"] = v.Txt
case *dns.PTR:
answer["value"] = v.Ptr
case *dns.SOA:
answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v", v.Ns, v.Mbox, v.Serial, v.Refresh, v.Retry, v.Expire, v.Minttl)
case *dns.CAA:
answer["value"] = fmt.Sprintf("%v %v \"%v\"", v.Flag, v.Tag, v.Value)
case *dns.HINFO:
answer["value"] = fmt.Sprintf("\"%v\" \"%v\"", v.Cpu, v.Os)
case *dns.RRSIG:
answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v %v %v", dns.TypeToString[v.TypeCovered], v.Algorithm, v.Labels, v.OrigTtl, v.Expiration, v.Inception, v.KeyTag, v.SignerName, v.Signature)
default:
// type unknown, marshall it as-is
answer["value"] = v
}
answers = append(answers, answer)
}
return answers
}
// getIPString is a helper function that extracts IP address from net.Addr // getIPString is a helper function that extracts IP address from net.Addr
func getIPString(addr net.Addr) string { func getIPString(addr net.Addr) string {
switch addr := addr.(type) { switch addr := addr.(type) {

View File

@ -156,7 +156,7 @@ func periodicQueryLogRotate() {
func genericLoader(onEntry func(entry *logEntry) error, needMore func() bool, timeWindow time.Duration) error { func genericLoader(onEntry func(entry *logEntry) error, needMore func() bool, timeWindow time.Duration) error {
now := time.Now() now := time.Now()
// read from querylog files, try newest file first // read from querylog files, try newest file first
files := []string{} var files []string
if enableGzip { if enableGzip {
files = []string{ files = []string{

View File

@ -26,10 +26,10 @@ type hourTop struct {
mutex sync.RWMutex mutex sync.RWMutex
} }
func (top *hourTop) init() { func (h *hourTop) init() {
top.domains = gcache.New(queryLogTopSize).LRU().Build() h.domains = gcache.New(queryLogTopSize).LRU().Build()
top.blocked = gcache.New(queryLogTopSize).LRU().Build() h.blocked = gcache.New(queryLogTopSize).LRU().Build()
top.clients = gcache.New(queryLogTopSize).LRU().Build() h.clients = gcache.New(queryLogTopSize).LRU().Build()
} }
type dayTop struct { type dayTop struct {
@ -69,9 +69,9 @@ func periodicHourlyTopRotate() {
} }
} }
func (top *hourTop) incrementValue(key string, cache gcache.Cache) error { func (h *hourTop) incrementValue(key string, cache gcache.Cache) error {
top.Lock() h.Lock()
defer top.Unlock() defer h.Unlock()
ivalue, err := cache.Get(key) ivalue, err := cache.Get(key)
if err == gcache.KeyNotFoundError { if err == gcache.KeyNotFoundError {
// we just set it and we're done // we just set it and we're done
@ -103,20 +103,20 @@ func (top *hourTop) incrementValue(key string, cache gcache.Cache) error {
return nil return nil
} }
func (top *hourTop) incrementDomains(key string) error { func (h *hourTop) incrementDomains(key string) error {
return top.incrementValue(key, top.domains) return h.incrementValue(key, h.domains)
} }
func (top *hourTop) incrementBlocked(key string) error { func (h *hourTop) incrementBlocked(key string) error {
return top.incrementValue(key, top.blocked) return h.incrementValue(key, h.blocked)
} }
func (top *hourTop) incrementClients(key string) error { func (h *hourTop) incrementClients(key string) error {
return top.incrementValue(key, top.clients) return h.incrementValue(key, h.clients)
} }
// if does not exist -- return 0 // if does not exist -- return 0
func (top *hourTop) lockedGetValue(key string, cache gcache.Cache) (int, error) { func (h *hourTop) lockedGetValue(key string, cache gcache.Cache) (int, error) {
ivalue, err := cache.Get(key) ivalue, err := cache.Get(key)
if err == gcache.KeyNotFoundError { if err == gcache.KeyNotFoundError {
return 0, nil return 0, nil
@ -137,19 +137,19 @@ func (top *hourTop) lockedGetValue(key string, cache gcache.Cache) (int, error)
return value, nil return value, nil
} }
func (top *hourTop) lockedGetDomains(key string) (int, error) { func (h *hourTop) lockedGetDomains(key string) (int, error) {
return top.lockedGetValue(key, top.domains) return h.lockedGetValue(key, h.domains)
} }
func (top *hourTop) lockedGetBlocked(key string) (int, error) { func (h *hourTop) lockedGetBlocked(key string) (int, error) {
return top.lockedGetValue(key, top.blocked) return h.lockedGetValue(key, h.blocked)
} }
func (top *hourTop) lockedGetClients(key string) (int, error) { func (h *hourTop) lockedGetClients(key string) (int, error) {
return top.lockedGetValue(key, top.clients) return h.lockedGetValue(key, h.clients)
} }
func (r *dayTop) addEntry(entry *logEntry, q *dns.Msg, now time.Time) error { func (d *dayTop) addEntry(entry *logEntry, q *dns.Msg, now time.Time) error {
// figure out which hour bucket it belongs to // figure out which hour bucket it belongs to
hour := int(now.Sub(entry.Time).Hours()) hour := int(now.Sub(entry.Time).Hours())
if hour >= 24 { if hour >= 24 {
@ -252,6 +252,7 @@ func fillStatsFromQueryLog() error {
return nil return nil
} }
// HandleStatsTop returns the current top stats
func HandleStatsTop(w http.ResponseWriter, r *http.Request) { func HandleStatsTop(w http.ResponseWriter, r *http.Request) {
domains := map[string]int{} domains := map[string]int{}
blocked := map[string]int{} blocked := map[string]int{}
@ -320,9 +321,9 @@ func HandleStatsTop(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err := w.Write(json.Bytes()) _, err := w.Write(json.Bytes())
if err != nil { if err != nil {
errortext := fmt.Sprintf("Couldn't write body: %s", err) errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, http.StatusInternalServerError) http.Error(w, errorText, http.StatusInternalServerError)
} }
} }

View File

@ -17,7 +17,6 @@ var (
filteredLists = newDNSCounter("filtered_lists_total") filteredLists = newDNSCounter("filtered_lists_total")
filteredSafebrowsing = newDNSCounter("filtered_safebrowsing_total") filteredSafebrowsing = newDNSCounter("filtered_safebrowsing_total")
filteredParental = newDNSCounter("filtered_parental_total") filteredParental = newDNSCounter("filtered_parental_total")
filteredInvalid = newDNSCounter("filtered_invalid_total")
whitelisted = newDNSCounter("whitelisted_total") whitelisted = newDNSCounter("whitelisted_total")
safesearch = newDNSCounter("safesearch_total") safesearch = newDNSCounter("safesearch_total")
errorsTotal = newDNSCounter("errors_total") errorsTotal = newDNSCounter("errors_total")
@ -91,10 +90,10 @@ func (p *periodicStats) Observe(name string, when time.Time, value float64) {
{ {
countname := name + "_count" countname := name + "_count"
currentValues := p.Entries[countname] currentValues := p.Entries[countname]
value := currentValues[elapsed] v := currentValues[elapsed]
// log.Tracef("Will change p.Entries[%s][%d] from %v to %v", countname, elapsed, value, value+1) // log.Tracef("Will change p.Entries[%s][%d] from %v to %v", countname, elapsed, value, value+1)
value += 1 v++
currentValues[elapsed] = value currentValues[elapsed] = v
p.Entries[countname] = currentValues p.Entries[countname] = currentValues
} }
{ {
@ -224,6 +223,7 @@ func incrementCounters(entry *logEntry) {
elapsedTime.ObserveWithTime(entry.Elapsed.Seconds(), entry.Time) elapsedTime.ObserveWithTime(entry.Elapsed.Seconds(), entry.Time)
} }
// HandleStats returns aggregated stats data for the 24 hours
func HandleStats(w http.ResponseWriter, r *http.Request) { func HandleStats(w http.ResponseWriter, r *http.Request) {
const numHours = 24 const numHours = 24
histrical := generateMapFromStats(&statistics.PerHour, 0, numHours) histrical := generateMapFromStats(&statistics.PerHour, 0, numHours)
@ -252,17 +252,17 @@ func HandleStats(w http.ResponseWriter, r *http.Request) {
json, err := json.Marshal(summed) json, err := json.Marshal(summed)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err) errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(json) _, err = w.Write(json)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err) errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
} }
@ -296,6 +296,7 @@ func generateMapFromStats(stats *periodicStats, start int, end int) map[string]i
return result return result
} }
// HandleStatsHistory returns historical stats data for the 24 hours
func HandleStatsHistory(w http.ResponseWriter, r *http.Request) { func HandleStatsHistory(w http.ResponseWriter, r *http.Request) {
// handle time unit and prepare our time window size // handle time unit and prepare our time window size
now := time.Now() now := time.Now()
@ -323,27 +324,27 @@ func HandleStatsHistory(w http.ResponseWriter, r *http.Request) {
// parse start and end time // parse start and end time
startTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("start_time")) startTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("start_time"))
if err != nil { if err != nil {
errortext := fmt.Sprintf("Must specify valid start_time parameter: %s", err) errorText := fmt.Sprintf("Must specify valid start_time parameter: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 400) http.Error(w, errorText, 400)
return return
} }
endTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("end_time")) endTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("end_time"))
if err != nil { if err != nil {
errortext := fmt.Sprintf("Must specify valid end_time parameter: %s", err) errorText := fmt.Sprintf("Must specify valid end_time parameter: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 400) http.Error(w, errorText, 400)
return return
} }
// check if start and time times are within supported time range // check if start and time times are within supported time range
timeRange := timeUnit * statsHistoryElements timeRange := timeUnit * statsHistoryElements
if startTime.Add(timeRange).Before(now) { if startTime.Add(timeRange).Before(now) {
http.Error(w, "start_time parameter is outside of supported range", 501) http.Error(w, "start_time parameter is outside of supported range", http.StatusBadRequest)
return return
} }
if endTime.Add(timeRange).Before(now) { if endTime.Add(timeRange).Before(now) {
http.Error(w, "end_time parameter is outside of supported range", 501) http.Error(w, "end_time parameter is outside of supported range", http.StatusBadRequest)
return return
} }
@ -360,28 +361,29 @@ func HandleStatsHistory(w http.ResponseWriter, r *http.Request) {
data := generateMapFromStats(stats, start, end) data := generateMapFromStats(stats, start, end)
json, err := json.Marshal(data) json, err := json.Marshal(data)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err) errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(json) _, err = w.Write(json)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err) errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
} }
// HandleStatsReset resets the stats caches
func HandleStatsReset(w http.ResponseWriter, r *http.Request) { func HandleStatsReset(w http.ResponseWriter, r *http.Request) {
purgeStats() purgeStats()
_, err := fmt.Fprintf(w, "OK\n") _, err := fmt.Fprintf(w, "OK\n")
if err != nil { if err != nil {
errortext := fmt.Sprintf("Couldn't write body: %s", err) errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, http.StatusInternalServerError) http.Error(w, errorText, http.StatusInternalServerError)
} }
} }

View File

@ -70,20 +70,20 @@ func updateUniqueFilterID(filters []filter) {
func assignUniqueFilterID() int64 { func assignUniqueFilterID() int64 {
value := nextFilterID value := nextFilterID
nextFilterID += 1 nextFilterID++
return value return value
} }
// Sets up a timer that will be checking for filters updates periodically // Sets up a timer that will be checking for filters updates periodically
func periodicallyRefreshFilters() { func periodicallyRefreshFilters() {
for range time.Tick(time.Minute) { for range time.Tick(time.Minute) {
refreshFiltersIfNeccessary(false) refreshFiltersIfNecessary(false)
} }
} }
// Checks filters updates if necessary // Checks filters updates if necessary
// If force is true, it ignores the filter.LastUpdated field value // If force is true, it ignores the filter.LastUpdated field value
func refreshFiltersIfNeccessary(force bool) int { func refreshFiltersIfNecessary(force bool) int {
config.Lock() config.Lock()
// fetch URLs // fetch URLs
@ -113,8 +113,12 @@ func refreshFiltersIfNeccessary(force bool) int {
} }
config.Unlock() config.Unlock()
if updateCount > 0 { if updateCount > 0 && isRunning() {
reconfigureDNSServer() err := reconfigureDNSServer()
if err != nil {
msg := fmt.Sprintf("SHOULD NOT HAPPEN: cannot reconfigure DNS server with the new filters: %s", err)
panic(msg)
}
} }
return updateCount return updateCount
} }

View File

@ -39,7 +39,7 @@ func safeWriteFile(path string, data []byte) error {
func ensure(method string, handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { func ensure(method string, handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if r.Method != method { if r.Method != method {
http.Error(w, "This request must be "+method, 405) http.Error(w, "This request must be "+method, http.StatusMethodNotAllowed)
return return
} }
handler(w, r) handler(w, r)

View File

@ -34,9 +34,9 @@ func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
log.Printf("config.Language is %s", config.Language) log.Printf("config.Language is %s", config.Language)
_, err := fmt.Fprintf(w, "%s\n", config.Language) _, err := fmt.Fprintf(w, "%s\n", config.Language)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err) errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, http.StatusInternalServerError) http.Error(w, errorText, http.StatusInternalServerError)
return return
} }
} }

View File

@ -7,7 +7,7 @@ import (
"path/filepath" "path/filepath"
"github.com/hmage/golibs/log" "github.com/hmage/golibs/log"
"gopkg.in/yaml.v2" yaml "gopkg.in/yaml.v2"
) )
const currentSchemaVersion = 2 // used for upgrading from old configs to new config const currentSchemaVersion = 2 // used for upgrading from old configs to new config