Merge pull request #137 in DNS/adguard-dns from fix/557 to master

* commit 'ec6b1f7c42c8d2fc413d29fba55430b89fcbce2d':
  Added golangci-lint configuration and prepared for the integrattion
  Added codecov, goreport
  Use EnableAll in gometalinter config
  gometalinter
This commit is contained in:
Eugene Bujak 2019-01-25 20:14:03 +03:00
commit 0161509b5f
28 changed files with 577 additions and 464 deletions

8
.codecov.yml Normal file
View File

@ -0,0 +1,8 @@
coverage:
status:
project:
default:
target: 40%
threshold: null
patch: false
changes: false

4
.gitignore vendored
View File

@ -12,5 +12,5 @@
/scripts/translations/oneskyapp.json
# Test output
dnsfilter/dnsfilter.TestLotsOfRules*.pprof
tests/top-1m.csv
dnsfilter/tests/top-1m.csv
dnsfilter/tests/dnsfilter.TestLotsOfRules*.pprof

55
.golangci.yml Normal file
View File

@ -0,0 +1,55 @@
# options for analysis running
run:
# default concurrency is a available CPU number
concurrency: 4
# timeout for analysis, e.g. 30s, 5m, default is 1m
deadline: 2m
# which files to skip: they will be analyzed, but issues from them
# won't be reported. Default value is empty list, but there is
# no need to include all autogenerated files, we confidently recognize
# autogenerated files. If it's not please let us know.
skip-files:
- ".*generated.*"
- dnsfilter/rule_to_regexp.go
# all available settings of specific linters
linters-settings:
errcheck:
# [deprecated] comma-separated list of pairs of the form pkg:regex
# the regex is used to ignore names within pkg. (default "fmt:.*").
# see https://github.com/kisielk/errcheck#the-deprecated-method for details
ignore: fmt:.*,net:SetReadDeadline,net/http:^Write
gocyclo:
min-complexity: 20
lll:
line-length: 200
linters:
enable-all: true
disable:
- interfacer
- gocritic
- scopelint
- gochecknoglobals
- gochecknoinits
- prealloc
- maligned
- goconst # disabled until it's possible to configure
fast: true
issues:
# List of regexps of issue texts to exclude, empty list by default.
# But independently from this option we use default exclude patterns,
# it can be disabled by `exclude-use-default: false`. To list all
# excluded by default patterns execute `golangci-lint run --help`
exclude:
# structcheck cannot detect usages while they're there
- .parentalServer. is unused
- .safeBrowsingServer. is unused
# errcheck
- Error return value of .s.closeConn. is not checked
# goconst
- string .forcesafesearch.google.com. has 3 occurrences

23
.gometalinter.json Normal file
View File

@ -0,0 +1,23 @@
{
"Vendor": true,
"Test": true,
"Deadline": "2m",
"Sort": ["linter", "severity", "path", "line"],
"Exclude": [
".*generated.*",
"dnsfilter/rule_to_regexp.go"
],
"EnableGC": true,
"Linters": {
"nakedret": {
"Command": "nakedret",
"Pattern": "^(?P<path>.*?\\.go):(?P<line>\\d+)\\s*(?P<message>.*)$"
}
},
"WarnUnmatchedDirective": true,
"EnableAll": true,
"Cyclo": 20,
"LineLength": 200
}

View File

@ -24,7 +24,11 @@ install:
script:
- node -v
- npm -v
- go test ./...
# Run tests
- go test -race -v -bench=. -coverprofile=coverage.txt -covermode=atomic ./...
# Make
- make build/static/index.html
- make
after_success:
- bash <(curl -s https://codecov.io/bash)

View File

@ -11,11 +11,18 @@
<a href="https://adguard.com/">AdGuard.com</a> |
<a href="https://github.com/AdguardTeam/AdGuardHome/wiki">Wiki</a> |
<a href="https://reddit.com/r/Adguard">Reddit</a> |
<a href="https://twitter.com/AdGuard">Twitter</a>
<a href="https://twitter.com/AdGuard">Twitter</a> |
<a href="https://t.me/adguard_en">Telegram</a>
<br /><br />
<a href="https://travis-ci.org/AdguardTeam/AdGuardHome">
<img src="https://travis-ci.org/AdguardTeam/AdGuardHome.svg" alt="Build status" />
</a>
<a href="https://codecov.io/github/AdguardTeam/AdGuardHome?branch=master">
<img src="https://img.shields.io/codecov/c/github/AdguardTeam/AdGuardHome/master.svg" alt="Code Coverage" />
</a>
<a href="https://goreportcard.com/report/AdguardTeam/AdGuardHome">
<img src="https://goreportcard.com/badge/github.com/AdguardTeam/AdGuardHome" alt="Go Report Card" />
</a>
<a href="https://github.com/AdguardTeam/AdGuardHome/releases">
<img src="https://img.shields.io/github/release/AdguardTeam/AdGuardHome/all.svg" alt="Latest release" />
</a>

163
app.go
View File

@ -43,7 +43,89 @@ func main() {
// config can be specified, which reads options from there, but other command line flags have to override config values
// therefore, we must do it manually instead of using a lib
{
loadOptions()
// Load filters from the disk
// And if any filter has zero ID, assign a new one
for i := range config.Filters {
filter := &config.Filters[i] // otherwise we're operating on a copy
if filter.ID == 0 {
filter.ID = assignUniqueFilterID()
}
err := filter.load()
if err != nil {
// This is okay for the first start, the filter will be loaded later
log.Printf("Couldn't load filter %d contents due to %s", filter.ID, err)
// clear LastUpdated so it gets fetched right away
}
if len(filter.Rules) == 0 {
filter.LastUpdated = time.Time{}
}
}
// Update filters we've just loaded right away, don't wait for periodic update timer
go func() {
refreshFiltersIfNecessary(false)
// Save the updated config
err := config.write()
if err != nil {
log.Fatal(err)
}
}()
signalChannel := make(chan os.Signal)
signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
go func() {
<-signalChannel
cleanup()
os.Exit(0)
}()
// Save the updated config
err := config.write()
if err != nil {
log.Fatal(err)
}
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
go periodicallyRefreshFilters()
http.Handle("/", optionalAuthHandler(http.FileServer(box)))
registerControlHandlers()
err = startDNSServer()
if err != nil {
log.Fatal(err)
}
err = startDHCPServer()
if err != nil {
log.Fatal(err)
}
URL := fmt.Sprintf("http://%s", address)
log.Println("Go to " + URL)
log.Fatal(http.ListenAndServe(address, nil))
}
func cleanup() {
err := stopDNSServer()
if err != nil {
log.Printf("Couldn't stop DNS server: %s", err)
}
}
func getInput() (string, error) {
scanner := bufio.NewScanner(os.Stdin)
scanner.Scan()
text := scanner.Text()
err := scanner.Err()
return text, err
}
// loadOptions reads command line arguments and initializes configuration
func loadOptions() {
var printHelp func()
var configFilename *string
var bindHost *string
@ -130,85 +212,6 @@ func main() {
}
}
// Load filters from the disk
// And if any filter has zero ID, assign a new one
for i := range config.Filters {
filter := &config.Filters[i] // otherwise we're operating on a copy
if filter.ID == 0 {
filter.ID = assignUniqueFilterID()
}
err := filter.load()
if err != nil {
// This is okay for the first start, the filter will be loaded later
log.Printf("Couldn't load filter %d contents due to %s", filter.ID, err)
// clear LastUpdated so it gets fetched right away
}
if len(filter.Rules) == 0 {
filter.LastUpdated = time.Time{}
}
}
// Update filters we've just loaded right away, don't wait for periodic update timer
go func() {
refreshFiltersIfNeccessary(false)
// Save the updated config
err := config.write()
if err != nil {
log.Fatal(err)
}
}()
signalChannel := make(chan os.Signal)
signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
go func() {
<-signalChannel
cleanup()
os.Exit(0)
}()
// Save the updated config
err := config.write()
if err != nil {
log.Fatal(err)
}
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
go periodicallyRefreshFilters()
http.Handle("/", optionalAuthHandler(http.FileServer(box)))
registerControlHandlers()
err = startDNSServer()
if err != nil {
log.Fatal(err)
}
err = startDHCPServer()
if err != nil {
log.Fatal(err)
}
URL := fmt.Sprintf("http://%s", address)
log.Println("Go to " + URL)
log.Fatal(http.ListenAndServe(address, nil))
}
func cleanup() {
err := stopDNSServer()
if err != nil {
log.Printf("Couldn't stop DNS server: %s", err)
}
}
func getInput() (string, error) {
scanner := bufio.NewScanner(os.Stdin)
scanner.Scan()
text := scanner.Text()
err := scanner.Err()
return text, err
}
func promptAndGet(prompt string) (string, error) {
for {
fmt.Print(prompt)

View File

@ -10,7 +10,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/hmage/golibs/log"
"gopkg.in/yaml.v2"
yaml "gopkg.in/yaml.v2"
)
const (
@ -21,7 +21,7 @@ const (
// configuration is loaded from YAML
// field ordering is important -- yaml fields will mirror ordering from here
type configuration struct {
ourConfigFilename string // Config filename (can be overriden via the command line arguments)
ourConfigFilename string // Config filename (can be overridden via the command line arguments)
ourBinaryDir string // Location of our directory, used to protect against CWD being somewhere else
BindHost string `yaml:"bind_host"`

View File

@ -15,7 +15,7 @@ import (
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/hmage/golibs/log"
"github.com/miekg/dns"
"gopkg.in/asaskevich/govalidator.v4"
govalidator "gopkg.in/asaskevich/govalidator.v4"
)
const updatePeriod = time.Minute * 30
@ -40,27 +40,26 @@ func writeAllConfigsAndReloadDNS() error {
log.Printf("Couldn't write all configs: %s", err)
return err
}
reconfigureDNSServer()
return nil
return reconfigureDNSServer()
}
func httpUpdateConfigReloadDNSReturnOK(w http.ResponseWriter, r *http.Request) {
err := writeAllConfigsAndReloadDNS()
if err != nil {
errortext := fmt.Sprintf("Couldn't write config file: %s", err)
log.Println(errortext)
http.Error(w, errortext, http.StatusInternalServerError)
errorText := fmt.Sprintf("Couldn't write config file: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
return
}
returnOK(w, r)
returnOK(w)
}
func returnOK(w http.ResponseWriter, r *http.Request) {
func returnOK(w http.ResponseWriter) {
_, err := fmt.Fprintf(w, "OK\n")
if err != nil {
errortext := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errortext)
http.Error(w, errortext, http.StatusInternalServerError)
errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
}
}
@ -79,17 +78,17 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
jsonVal, err := json.Marshal(data)
if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errortext)
http.Error(w, errortext, 500)
errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errorText)
http.Error(w, errorText, 500)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext)
http.Error(w, errortext, 500)
errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errorText)
http.Error(w, errorText, 500)
return
}
}
@ -147,7 +146,13 @@ func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
http.Error(w, errorText, http.StatusInternalServerError)
return
}
reconfigureDNSServer()
err = reconfigureDNSServer()
if err != nil {
errorText := fmt.Sprintf("Couldn't reconfigure the DNS server: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
return
}
_, err = fmt.Fprintf(w, "OK %d servers\n", len(hosts))
if err != nil {
errorText := fmt.Sprintf("Couldn't write body: %s", err)
@ -206,7 +211,7 @@ func checkDNS(input string) error {
log.Printf("Checking if DNS %s works...", input)
u, err := upstream.AddressToUpstream(input, "", dnsforward.DefaultTimeout)
if err != nil {
return fmt.Errorf("Failed to choose upstream for %s: %s", input, err)
return fmt.Errorf("failed to choose upstream for %s: %s", input, err)
}
req := dns.Msg{}
@ -243,9 +248,9 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
resp, err := client.Get(versionCheckURL)
if err != nil {
errortext := fmt.Sprintf("Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err)
log.Println(errortext)
http.Error(w, errortext, http.StatusBadGateway)
errorText := fmt.Sprintf("Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err)
log.Println(errorText)
http.Error(w, errorText, http.StatusBadGateway)
return
}
if resp != nil && resp.Body != nil {
@ -255,18 +260,18 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
// read the body entirely
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
errortext := fmt.Sprintf("Couldn't read response body from %s: %s", versionCheckURL, err)
log.Println(errortext)
http.Error(w, errortext, http.StatusBadGateway)
errorText := fmt.Sprintf("Couldn't read response body from %s: %s", versionCheckURL, err)
log.Println(errorText)
http.Error(w, errorText, http.StatusBadGateway)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(body)
if err != nil {
errortext := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errortext)
http.Error(w, errortext, http.StatusInternalServerError)
errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
}
versionCheckLastTime = now
@ -299,44 +304,44 @@ func handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
config.RUnlock()
if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errortext)
http.Error(w, errortext, 500)
errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errorText)
http.Error(w, errorText, 500)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext)
http.Error(w, errortext, 500)
errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errorText)
http.Error(w, errorText, 500)
return
}
}
func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
filter := filter{}
err := json.NewDecoder(r.Body).Decode(&filter)
f := filter{}
err := json.NewDecoder(r.Body).Decode(&f)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to parse request body json: %s", err)
return
}
if len(filter.URL) == 0 {
if len(f.URL) == 0 {
http.Error(w, "URL parameter was not specified", 400)
return
}
if valid := govalidator.IsRequestURL(filter.URL); !valid {
if valid := govalidator.IsRequestURL(f.URL); !valid {
http.Error(w, "URL parameter is not valid request URL", 400)
return
}
// Check for duplicates
for i := range config.Filters {
if config.Filters[i].URL == filter.URL {
errorText := fmt.Sprintf("Filter URL already added -- %s", filter.URL)
if config.Filters[i].URL == f.URL {
errorText := fmt.Sprintf("Filter URL already added -- %s", f.URL)
log.Println(errorText)
http.Error(w, errorText, http.StatusBadRequest)
return
@ -344,42 +349,42 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
}
// Set necessary properties
filter.ID = assignUniqueFilterID()
filter.Enabled = true
f.ID = assignUniqueFilterID()
f.Enabled = true
// Download the filter contents
ok, err := filter.update(true)
ok, err := f.update(true)
if err != nil {
errorText := fmt.Sprintf("Couldn't fetch filter from url %s: %s", filter.URL, err)
errorText := fmt.Sprintf("Couldn't fetch filter from url %s: %s", f.URL, err)
log.Println(errorText)
http.Error(w, errorText, http.StatusBadRequest)
return
}
if filter.RulesCount == 0 {
errorText := fmt.Sprintf("Filter at the url %s has no rules (maybe it points to blank page?)", filter.URL)
if f.RulesCount == 0 {
errorText := fmt.Sprintf("Filter at the url %s has no rules (maybe it points to blank page?)", f.URL)
log.Println(errorText)
http.Error(w, errorText, http.StatusBadRequest)
return
}
if !ok {
errorText := fmt.Sprintf("Filter at the url %s is invalid (maybe it points to blank page?)", filter.URL)
errorText := fmt.Sprintf("Filter at the url %s is invalid (maybe it points to blank page?)", f.URL)
log.Println(errorText)
http.Error(w, errorText, http.StatusBadRequest)
return
}
// Save the filter contents
err = filter.save()
err = f.save()
if err != nil {
errorText := fmt.Sprintf("Failed to save filter %d due to %s", filter.ID, err)
errorText := fmt.Sprintf("Failed to save filter %d due to %s", f.ID, err)
log.Println(errorText)
http.Error(w, errorText, http.StatusBadRequest)
return
}
// URL is deemed valid, append it to filters, update config, write new filter file and tell dns to reload it
// TODO: since we directly feed filters in-memory, revisit if writing configs is always neccessary
config.Filters = append(config.Filters, filter)
// TODO: since we directly feed filters in-memory, revisit if writing configs is always necessary
config.Filters = append(config.Filters, f)
err = writeAllConfigs()
if err != nil {
errorText := fmt.Sprintf("Couldn't write config file: %s", err)
@ -388,9 +393,14 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
return
}
reconfigureDNSServer()
err = reconfigureDNSServer()
if err != nil {
errorText := fmt.Sprintf("Couldn't reconfigure the DNS server: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
}
_, err = fmt.Fprintf(w, "OK %d rules\n", filter.RulesCount)
_, err = fmt.Fprintf(w, "OK %d rules\n", f.RulesCount)
if err != nil {
errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errorText)
@ -473,7 +483,7 @@ func handleFilteringEnableURL(w http.ResponseWriter, r *http.Request) {
}
// kick off refresh of rules from new URLs
refreshFiltersIfNeccessary(false)
refreshFiltersIfNecessary(false)
httpUpdateConfigReloadDNSReturnOK(w, r)
}
@ -529,7 +539,7 @@ func handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
func handleFilteringRefresh(w http.ResponseWriter, r *http.Request) {
force := r.URL.Query().Get("force")
updated := refreshFiltersIfNeccessary(force != "")
updated := refreshFiltersIfNecessary(force != "")
fmt.Fprintf(w, "OK %d filters updated\n", updated)
}
@ -553,17 +563,17 @@ func handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Request) {
}
jsonVal, err := json.Marshal(data)
if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errortext)
http.Error(w, errortext, 500)
errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errorText)
http.Error(w, errorText, 500)
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext)
http.Error(w, errortext, 500)
errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errorText)
http.Error(w, errorText, 500)
return
}
}
@ -574,9 +584,9 @@ func handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Request) {
func handleParentalEnable(w http.ResponseWriter, r *http.Request) {
parameters, err := parseParametersFromBody(r.Body)
if err != nil {
errortext := fmt.Sprintf("failed to parse parameters from body: %s", err)
log.Println(errortext)
http.Error(w, errortext, 400)
errorText := fmt.Sprintf("failed to parse parameters from body: %s", err)
log.Println(errorText)
http.Error(w, errorText, 400)
return
}
@ -631,18 +641,18 @@ func handleParentalStatus(w http.ResponseWriter, r *http.Request) {
}
jsonVal, err := json.Marshal(data)
if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errortext)
http.Error(w, errortext, 500)
errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errorText)
http.Error(w, errorText, 500)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext)
http.Error(w, errortext, 500)
errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errorText)
http.Error(w, errorText, 500)
return
}
}
@ -667,18 +677,18 @@ func handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) {
}
jsonVal, err := json.Marshal(data)
if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errortext)
http.Error(w, errortext, 500)
errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errorText)
http.Error(w, errorText, 500)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext)
http.Error(w, errortext, 500)
errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errorText)
http.Error(w, errorText, 500)
return
}
}

18
dhcp.go
View File

@ -58,7 +58,10 @@ func handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
}
}
if !newconfig.Enabled {
dhcpServer.Stop()
err := dhcpServer.Stop()
if err != nil {
log.Printf("failed to stop the DHCP server: %s", err)
}
}
config.DHCP = newconfig
httpUpdateConfigReloadDNSReturnOK(w, r)
@ -73,11 +76,6 @@ func handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
return
}
type address struct {
IP string
Netmask string
}
type responseInterface struct {
Name string `json:"name"`
MTU int `json:"mtu"`
@ -103,9 +101,9 @@ func handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
MTU: ifaces[i].MTU,
HardwareAddr: ifaces[i].HardwareAddr.String(),
}
addrs, err := ifaces[i].Addrs()
if err != nil {
httpError(w, http.StatusInternalServerError, "Failed to get addresses for interface %v: %s", ifaces[i].Name, err)
addrs, errAddrs := ifaces[i].Addrs()
if errAddrs != nil {
httpError(w, http.StatusInternalServerError, "Failed to get addresses for interface %v: %s", ifaces[i].Name, errAddrs)
return
}
for _, addr := range addrs {
@ -157,7 +155,7 @@ func handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Request) {
}
func startDHCPServer() error {
if config.DHCP.Enabled == false {
if !config.DHCP.Enabled {
// not enabled, don't do anything
return nil
}

View File

@ -13,6 +13,8 @@ import (
"github.com/krolaw/dhcp4"
)
// CheckIfOtherDHCPServersPresent sends a DHCP request to the specified network interface,
// and waits for a response for a period defined by defaultDiscoverTime
func CheckIfOtherDHCPServersPresent(ifaceName string) (bool, error) {
iface, err := net.InterfaceByName(ifaceName)
if err != nil {
@ -30,8 +32,8 @@ func CheckIfOtherDHCPServersPresent(ifaceName string) (bool, error) {
dst := "255.255.255.255:67"
// form a DHCP request packet, try to emulate existing client as much as possible
xId := make([]byte, 8)
n, err := rand.Read(xId)
xID := make([]byte, 8)
n, err := rand.Read(xID)
if n != 8 && err == nil {
err = fmt.Errorf("Generated less than 8 bytes")
}
@ -57,16 +59,16 @@ func CheckIfOtherDHCPServersPresent(ifaceName string) (bool, error) {
maxUDPsizeRaw := make([]byte, 2)
binary.BigEndian.PutUint16(maxUDPsizeRaw, 1500)
leaseTimeRaw := make([]byte, 4)
leaseTime := uint32(math.RoundToEven(time.Duration(time.Hour * 24 * 90).Seconds()))
leaseTime := uint32(math.RoundToEven((time.Hour * 24 * 90).Seconds()))
binary.BigEndian.PutUint32(leaseTimeRaw, leaseTime)
options := []dhcp4.Option{
{dhcp4.OptionParameterRequestList, requestList},
{dhcp4.OptionMaximumDHCPMessageSize, maxUDPsizeRaw},
{dhcp4.OptionClientIdentifier, append([]byte{0x01}, iface.HardwareAddr...)},
{dhcp4.OptionIPAddressLeaseTime, leaseTimeRaw},
{dhcp4.OptionHostName, []byte(hostname)},
{Code: dhcp4.OptionParameterRequestList, Value: requestList},
{Code: dhcp4.OptionMaximumDHCPMessageSize, Value: maxUDPsizeRaw},
{Code: dhcp4.OptionClientIdentifier, Value: append([]byte{0x01}, iface.HardwareAddr...)},
{Code: dhcp4.OptionIPAddressLeaseTime, Value: leaseTimeRaw},
{Code: dhcp4.OptionHostName, Value: []byte(hostname)},
}
packet := dhcp4.RequestPacket(dhcp4.Discover, iface.HardwareAddr, nil, xId, false, options)
packet := dhcp4.RequestPacket(dhcp4.Discover, iface.HardwareAddr, nil, xID, false, options)
// resolve 0.0.0.0:68
udpAddr, err := net.ResolveUDPAddr("udp4", src)
@ -98,7 +100,7 @@ func CheckIfOtherDHCPServersPresent(ifaceName string) (bool, error) {
}
// send to 255.255.255.255:67
n, err = c.WriteTo(packet, dstAddr)
_, err = c.WriteTo(packet, dstAddr)
// spew.Dump(n, err)
if err != nil {
return false, wrapErrPrint(err, "Couldn't send a packet to %s", dst)

View File

@ -13,6 +13,7 @@ import (
const defaultDiscoverTime = time.Second * 3
// Lease contains the necessary information about a DHCP lease
// field ordering is important -- yaml fields will mirror ordering from here
type Lease struct {
HWAddr net.HardwareAddr `json:"mac" yaml:"hwaddr"`
@ -21,6 +22,7 @@ type Lease struct {
Expiry time.Time `json:"expires"`
}
// ServerConfig - DHCP server configuration
// field ordering is important -- yaml fields will mirror ordering from here
type ServerConfig struct {
Enabled bool `json:"enabled" yaml:"enabled"`
@ -32,6 +34,7 @@ type ServerConfig struct {
LeaseDuration uint `json:"lease_duration" yaml:"lease_duration"` // in seconds
}
// Server - the current state of the DHCP server
type Server struct {
conn *filterConn // listening UDP socket
@ -80,6 +83,7 @@ func (s *Server) Start(config *ServerConfig) error {
s.leaseStart, err = parseIPv4(s.RangeStart)
if err != nil {
s.closeConn() // in case it was already started
return wrapErrPrint(err, "Failed to parse range start address %s", s.RangeStart)
}
@ -137,6 +141,7 @@ func (s *Server) Start(config *ServerConfig) error {
return nil
}
// Stop closes the listening UDP socket
func (s *Server) Stop() error {
if s.conn == nil {
// nothing to do, return silently
@ -174,7 +179,7 @@ func (s *Server) reserveLease(p dhcp4.Packet) (*Lease, error) {
}
// not assigned a lease, create new one, find IP from LRU
log.Tracef("Lease not found for %s: creating new one", hwaddr)
ip, err := s.findFreeIP(p, hwaddr)
ip, err := s.findFreeIP(hwaddr)
if err != nil {
return nil, wrapErrPrint(err, "Couldn't find free IP for the lease %s", hwaddr.String())
}
@ -198,7 +203,7 @@ func (s *Server) locateLease(p dhcp4.Packet) *Lease {
return nil
}
func (s *Server) findFreeIP(p dhcp4.Packet, hwaddr net.HardwareAddr) (net.IP, error) {
func (s *Server) findFreeIP(hwaddr net.HardwareAddr) (net.IP, error) {
// if IP pool is nil, lazy initialize it
if s.IPpool == nil {
s.IPpool = make(map[[4]byte]net.HardwareAddr)
@ -223,7 +228,7 @@ func (s *Server) findFreeIP(p dhcp4.Packet, hwaddr net.HardwareAddr) (net.IP, er
if foundIP == nil {
// TODO: LRU
return nil, fmt.Errorf("Couldn't find free entry in IP pool")
return nil, fmt.Errorf("couldn't find free entry in IP pool")
}
s.reserveIP(foundIP, hwaddr)
@ -249,6 +254,7 @@ func (s *Server) unreserveIP(ip net.IP) {
delete(s.IPpool, IP4)
}
// ServeDHCP handles an incoming DHCP request
func (s *Server) ServeDHCP(p dhcp4.Packet, msgType dhcp4.MessageType, options dhcp4.Options) dhcp4.Packet {
log.Tracef("Got %v message", msgType)
log.Tracef("Leases:")
@ -259,27 +265,6 @@ func (s *Server) ServeDHCP(p dhcp4.Packet, msgType dhcp4.MessageType, options dh
for ip, hwaddr := range s.IPpool {
log.Tracef("IP pool entry %s -> %s", net.IPv4(ip[0], ip[1], ip[2], ip[3]), hwaddr)
}
// spew.Dump(s.leases, s.IPpool)
// log.Printf("Called with msgType = %v, options = %+v", msgType, options)
// spew.Dump(p)
// log.Printf("%14s %v", "p.Broadcast", p.Broadcast()) // false
// log.Printf("%14s %v", "p.CHAddr", p.CHAddr()) // 2c:f0:a2:f2:31:00
// log.Printf("%14s %v", "p.CIAddr", p.CIAddr()) // 0.0.0.0
// log.Printf("%14s %v", "p.Cookie", p.Cookie()) // [99 130 83 99]
// log.Printf("%14s %v", "p.File", p.File()) // []
// log.Printf("%14s %v", "p.Flags", p.Flags()) // [0 0]
// log.Printf("%14s %v", "p.GIAddr", p.GIAddr()) // 0.0.0.0
// log.Printf("%14s %v", "p.HLen", p.HLen()) // 6
// log.Printf("%14s %v", "p.HType", p.HType()) // 1
// log.Printf("%14s %v", "p.Hops", p.Hops()) // 0
// log.Printf("%14s %v", "p.OpCode", p.OpCode()) // BootRequest
// log.Printf("%14s %v", "p.Options", p.Options()) // [53 1 1 55 10 1 121 3 6 15 119 252 95 44 46 57 2 5 220 61 7 1 44 240 162 242 49 0 51 4 0 118 167 0 12 4 119 104 109 100 255 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
// log.Printf("%14s %v", "p.ParseOptions", p.ParseOptions()) // map[OptionParameterRequestList:[1 121 3 6 15 119 252 95 44 46] OptionDHCPMessageType:[1] OptionMaximumDHCPMessageSize:[5 220] OptionClientIdentifier:[1 44 240 162 242 49 0] OptionIPAddressLeaseTime:[0 118 167 0] OptionHostName:[119 104 109 100]]
// log.Printf("%14s %v", "p.SIAddr", p.SIAddr()) // 0.0.0.0
// log.Printf("%14s %v", "p.SName", p.SName()) // []
// log.Printf("%14s %v", "p.Secs", p.Secs()) // [0 8]
// log.Printf("%14s %v", "p.XId", p.XId()) // [211 184 20 44]
// log.Printf("%14s %v", "p.YIAddr", p.YIAddr()) // 0.0.0.0
switch msgType {
case dhcp4.Discover: // Broadcast Packet From Client - Can I have an IP?
@ -297,6 +282,32 @@ func (s *Server) ServeDHCP(p dhcp4.Packet, msgType dhcp4.MessageType, options dh
case dhcp4.Request: // Broadcast From Client - I'll take that IP (Also start for renewals)
// start/renew a lease -- update lease time
// some clients (OSX) just go right ahead and do Request first from previously known IP, if they get NAK, they restart full cycle with Discover then Request
return s.handleDHCP4Request(p, options)
case dhcp4.Decline: // Broadcast From Client - Sorry I can't use that IP
log.Tracef("Got from client: Decline")
case dhcp4.Release: // From Client, I don't need that IP anymore
log.Tracef("Got from client: Release")
case dhcp4.Inform: // From Client, I have this IP and there's nothing you can do about it
log.Tracef("Got from client: Inform")
// do nothing
// from server -- ignore those but enumerate just in case
case dhcp4.Offer: // Broadcast From Server - Here's an IP
log.Printf("SHOULD NOT HAPPEN -- FROM ANOTHER DHCP SERVER: Offer")
case dhcp4.ACK: // From Server, Yes you can have that IP
log.Printf("SHOULD NOT HAPPEN -- FROM ANOTHER DHCP SERVER: ACK")
case dhcp4.NAK: // From Server, No you cannot have that IP
log.Printf("SHOULD NOT HAPPEN -- FROM ANOTHER DHCP SERVER: NAK")
default:
log.Printf("Unknown DHCP packet detected, ignoring: %v", msgType)
return nil
}
return nil
}
func (s *Server) handleDHCP4Request(p dhcp4.Packet, options dhcp4.Options) dhcp4.Packet {
log.Tracef("Got from client: Request")
if server, ok := options[dhcp4.OptionServerIdentifier]; ok && !net.IP(server).Equal(s.ipnet.IP) {
log.Tracef("Request message not for this DHCP server (%v vs %v)", server, s.ipnet.IP)
@ -305,7 +316,7 @@ func (s *Server) ServeDHCP(p dhcp4.Packet, msgType dhcp4.MessageType, options dh
reqIP := net.IP(options[dhcp4.OptionRequestedIPAddress])
if reqIP == nil {
reqIP = net.IP(p.CIAddr())
reqIP = p.CIAddr()
}
if reqIP.To4() == nil {
@ -366,30 +377,9 @@ func (s *Server) ServeDHCP(p dhcp4.Packet, msgType dhcp4.MessageType, options dh
// requested IP is outside of DHCP range
log.Tracef("Replying with NAK: request IP %s is outside of DHCP range [%s, %s], asked by %v", reqIP, s.leaseStart, s.leaseStop, p.CHAddr())
return dhcp4.ReplyPacket(p, dhcp4.NAK, s.ipnet.IP, nil, 0, nil)
case dhcp4.Decline: // Broadcast From Client - Sorry I can't use that IP
log.Tracef("Got from client: Decline")
case dhcp4.Release: // From Client, I don't need that IP anymore
log.Tracef("Got from client: Release")
case dhcp4.Inform: // From Client, I have this IP and there's nothing you can do about it
log.Tracef("Got from client: Inform")
// do nothing
// from server -- ignore those but enumerate just in case
case dhcp4.Offer: // Broadcast From Server - Here's an IP
log.Printf("SHOULD NOT HAPPEN -- FROM ANOTHER DHCP SERVER: Offer")
case dhcp4.ACK: // From Server, Yes you can have that IP
log.Printf("SHOULD NOT HAPPEN -- FROM ANOTHER DHCP SERVER: ACK")
case dhcp4.NAK: // From Server, No you cannot have that IP
log.Printf("SHOULD NOT HAPPEN -- FROM ANOTHER DHCP SERVER: NAK")
default:
log.Printf("Unknown DHCP packet detected, ignoring: %v", msgType)
return nil
}
return nil
}
// Leases returns the list of current DHCP leases
func (s *Server) Leases() []*Lease {
s.RLock()
result := s.leases

View File

@ -8,7 +8,7 @@ import (
)
// filterConn listens to 0.0.0.0:67, but accepts packets only from specific interface
// This is neccessary for DHCP daemon to work, since binding to IP address doesn't
// This is necessary for DHCP daemon to work, since binding to IP address doesn't
// us access to see Discover/Request packets from clients.
//
// TODO: on windows, controlmessage does not work, try to find out another way
@ -49,7 +49,6 @@ func (f *filterConn) ReadFrom(b []byte) (int, net.Addr, error) {
}
// packet doesn't match criteria, drop it
}
return 0, nil, nil
}
func (f *filterConn) WriteTo(b []byte, addr net.Addr) (int, error) {

View File

@ -3,7 +3,6 @@ package dhcpd
import (
"fmt"
"net"
"strings"
"github.com/hmage/golibs/log"
"github.com/joomcode/errorx"
@ -45,22 +44,6 @@ func getIfaceIPv4(iface *net.Interface) *net.IPNet {
return nil
}
func isConnClosed(err error) bool {
if err == nil {
return false
}
nerr, ok := err.(*net.OpError)
if !ok {
return false
}
if strings.Contains(nerr.Err.Error(), "use of closed network connection") {
return true
}
return false
}
func wrapErrPrint(err error, message string, args ...interface{}) error {
var errx error
if err == nil {

View File

@ -76,9 +76,9 @@ func main() {
panic(err)
}
log.Printf("Now serving DHCP")
signal_channel := make(chan os.Signal)
signal.Notify(signal_channel, syscall.SIGINT, syscall.SIGTERM)
<-signal_channel
signalChannel := make(chan os.Signal)
signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM)
<-signalChannel
}

6
dns.go
View File

@ -39,20 +39,20 @@ func generateServerConfig() dnsforward.ServerConfig {
}
for _, u := range config.DNS.UpstreamDNS {
upstream, err := upstream.AddressToUpstream(u, config.DNS.BootstrapDNS, dnsforward.DefaultTimeout)
dnsUpstream, err := upstream.AddressToUpstream(u, config.DNS.BootstrapDNS, dnsforward.DefaultTimeout)
if err != nil {
log.Printf("Couldn't get upstream: %s", err)
// continue, just ignore the upstream
continue
}
newconfig.Upstreams = append(newconfig.Upstreams, upstream)
newconfig.Upstreams = append(newconfig.Upstreams, dnsUpstream)
}
return newconfig
}
func startDNSServer() error {
if isRunning() {
return fmt.Errorf("Unable to start forwarding DNS server: Already running")
return fmt.Errorf("unable to start forwarding DNS server: Already running")
}
newconfig := generateServerConfig()

View File

@ -35,7 +35,7 @@ const defaultParentalURL = "http://%s/check-parental-control-hash?prefixes=%s&se
// ErrInvalidSyntax is returned by AddRule when the rule is invalid
var ErrInvalidSyntax = errors.New("dnsfilter: invalid rule syntax")
// ErrInvalidSyntax is returned by AddRule when the rule was already added to the filter
// ErrAlreadyExists is returned by AddRule when the rule was already added to the filter
var ErrAlreadyExists = errors.New("dnsfilter: rule was already added")
const shortcutLength = 6 // used for rule search optimization, 6 hits the sweet spot
@ -115,6 +115,7 @@ type Dnsfilter struct {
privateConfig
}
// Filter represents a filter list
type Filter struct {
ID int64 `json:"id"` // auto-assigned when filter is added (see nextFilterID), json by default keeps ID uppercase but we need lowercase
Rules []string `json:"-" yaml:"-"` // not in yaml or json
@ -127,16 +128,26 @@ type Reason int
const (
// reasons for not filtering
NotFilteredNotFound Reason = iota // host was not find in any checks, default value for result
NotFilteredWhiteList // the host is explicitly whitelisted
NotFilteredError // there was a transitive error during check
// NotFilteredNotFound - host was not find in any checks, default value for result
NotFilteredNotFound Reason = iota
// NotFilteredWhiteList - the host is explicitly whitelisted
NotFilteredWhiteList
// NotFilteredError - there was a transitive error during check
NotFilteredError
// reasons for filtering
FilteredBlackList // the host was matched to be advertising host
FilteredSafeBrowsing // the host was matched to be malicious/phishing
FilteredParental // the host was matched to be outside of parental control settings
FilteredInvalid // the request was invalid and was not processed
FilteredSafeSearch // the host was replaced with safesearch variant
// FilteredBlackList - the host was matched to be advertising host
FilteredBlackList
// FilteredSafeBrowsing - the host was matched to be malicious/phishing
FilteredSafeBrowsing
// FilteredParental - the host was matched to be outside of parental control settings
FilteredParental
// FilteredInvalid - the request was invalid and was not processed
FilteredInvalid
// FilteredSafeSearch - the host was replaced with safesearch variant
FilteredSafeSearch
)
// these variables need to survive coredns reload
@ -151,7 +162,7 @@ type Result struct {
IsFiltered bool `json:",omitempty"` // True if the host name is filtered
Reason Reason `json:",omitempty"` // Reason for blocking / unblocking
Rule string `json:",omitempty"` // Original rule text
Ip net.IP `json:",omitempty"` // Not nil only in the case of a hosts file syntax
IP net.IP `json:",omitempty"` // Not nil only in the case of a hosts file syntax
FilterID int64 `json:",omitempty"` // Filter ID the rule belongs to
}
@ -228,7 +239,6 @@ func newRulesTable() *rulesTable {
func (r *rulesTable) Add(rule *rule) {
r.Lock()
if rule.ip != nil {
// Hosts syntax
r.rulesByHost[rule.text] = rule
@ -476,7 +486,7 @@ func (rule *rule) match(host string) (Result, error) {
IsFiltered: true,
Reason: FilteredBlackList,
Rule: rule.originalText,
Ip: rule.ip,
IP: rule.ip,
FilterID: rule.listID,
}, nil
}
@ -661,8 +671,11 @@ func (d *Dnsfilter) checkParental(host string) (Result, error) {
return result, err
}
type formatHandler func(hashparam string) string
type bodyHandler func(body []byte, hashes map[string]bool) (Result, error)
// real implementation of lookup/check
func (d *Dnsfilter) lookupCommon(host string, lookupstats *LookupStats, cache gcache.Cache, hashparamNeedSlash bool, format func(hashparam string) string, handleBody func(body []byte, hashes map[string]bool) (Result, error)) (Result, error) {
func (d *Dnsfilter) lookupCommon(host string, lookupstats *LookupStats, cache gcache.Cache, hashparamNeedSlash bool, format formatHandler, handleBody bodyHandler) (Result, error) {
// if host ends with a dot, trim it
host = strings.ToLower(strings.Trim(host, "."))
@ -774,43 +787,43 @@ func (d *Dnsfilter) AddRule(input string, filterListID int64) error {
}
// Start parsing the rule
rule := rule{
r := rule{
text: input, // will be modified
originalText: input,
listID: filterListID,
}
// Mark rule as whitelist if it starts with @@
if strings.HasPrefix(rule.text, "@@") {
rule.isWhitelist = true
rule.text = rule.text[2:]
if strings.HasPrefix(r.text, "@@") {
r.isWhitelist = true
r.text = r.text[2:]
}
err := rule.parseOptions()
err := r.parseOptions()
if err != nil {
return err
}
rule.extractShortcut()
r.extractShortcut()
if !enableDelayedCompilation {
err := rule.compile()
err := r.compile()
if err != nil {
return err
}
}
destination := d.blackList
if rule.isImportant {
if r.isImportant {
destination = d.important
} else if rule.isWhitelist {
} else if r.isWhitelist {
destination = d.whiteList
}
d.storageMutex.Lock()
d.storage[input] = true
d.storageMutex.Unlock()
destination.Add(&rule)
destination.Add(&r)
return nil
}
@ -835,13 +848,13 @@ func (d *Dnsfilter) parseEtcHosts(input string, filterListID int64) bool {
d.storageMutex.Unlock()
for _, host := range fields[1:] {
rule := rule{
r := rule{
text: host,
originalText: input,
listID: filterListID,
ip: addr,
}
d.blackList.Add(&rule)
d.blackList.Add(&r)
}
return true
}
@ -913,15 +926,6 @@ func (d *Dnsfilter) Destroy() {
// config manipulation helpers
//
// IsParentalSensitivityValid checks if sensitivity is valid value
func IsParentalSensitivityValid(sensitivity int) bool {
switch sensitivity {
case 3, 10, 13, 17:
return true
}
return false
}
// SetSafeBrowsingServer lets you optionally change hostname of safesearch lookup
func (d *Dnsfilter) SetSafeBrowsingServer(host string) {
if len(host) == 0 {

View File

@ -26,7 +26,7 @@ import (
func TestLotsOfRulesMemoryUsage(t *testing.T) {
start := getRSS()
log.Tracef("RSS before loading rules - %d kB\n", start/1024)
dumpMemProfile(_Func() + "1.pprof")
dumpMemProfile("tests/" + _Func() + "1.pprof")
d := NewForTest()
defer d.Destroy()
@ -37,7 +37,7 @@ func TestLotsOfRulesMemoryUsage(t *testing.T) {
afterLoad := getRSS()
log.Tracef("RSS after loading rules - %d kB (%d kB diff)\n", afterLoad/1024, (afterLoad-start)/1024)
dumpMemProfile(_Func() + "2.pprof")
dumpMemProfile("tests/" + _Func() + "2.pprof")
tests := []struct {
host string
@ -60,7 +60,7 @@ func TestLotsOfRulesMemoryUsage(t *testing.T) {
}
afterMatch := getRSS()
log.Tracef("RSS after matching - %d kB (%d kB diff)\n", afterMatch/1024, (afterMatch-afterLoad)/1024)
dumpMemProfile(_Func() + "3.pprof")
dumpMemProfile("tests/" + _Func() + "3.pprof")
}
func getRSS() uint64 {
@ -69,6 +69,9 @@ func getRSS() uint64 {
panic(err)
}
minfo, err := proc.MemoryInfo()
if err != nil {
panic(err)
}
return minfo.RSS
}
@ -86,7 +89,7 @@ func dumpMemProfile(name string) {
}
}
const topHostsFilename = "../tests/top-1m.csv"
const topHostsFilename = "tests/top-1m.csv"
func fetchTopHostsFromNet() {
log.Tracef("Fetching top hosts from network")
@ -146,7 +149,7 @@ func getTopHosts() {
func TestLotsOfRulesLotsOfHostsMemoryUsage(t *testing.T) {
start := getRSS()
log.Tracef("RSS before loading rules - %d kB\n", start/1024)
dumpMemProfile(_Func() + "1.pprof")
dumpMemProfile("tests/" + _Func() + "1.pprof")
d := NewForTest()
defer d.Destroy()
@ -155,7 +158,7 @@ func TestLotsOfRulesLotsOfHostsMemoryUsage(t *testing.T) {
afterLoad := getRSS()
log.Tracef("RSS after loading rules - %d kB (%d kB diff)\n", afterLoad/1024, (afterLoad-start)/1024)
dumpMemProfile(_Func() + "2.pprof")
dumpMemProfile("tests/" + _Func() + "2.pprof")
getTopHosts()
hostnames, err := os.Open(topHostsFilename)
@ -165,7 +168,7 @@ func TestLotsOfRulesLotsOfHostsMemoryUsage(t *testing.T) {
defer hostnames.Close()
afterHosts := getRSS()
log.Tracef("RSS after loading hosts - %d kB (%d kB diff)\n", afterHosts/1024, (afterHosts-afterLoad)/1024)
dumpMemProfile(_Func() + "2.pprof")
dumpMemProfile("tests/" + _Func() + "2.pprof")
{
scanner := bufio.NewScanner(hostnames)
@ -184,7 +187,7 @@ func TestLotsOfRulesLotsOfHostsMemoryUsage(t *testing.T) {
afterMatch := getRSS()
log.Tracef("RSS after matching - %d kB (%d kB diff)\n", afterMatch/1024, (afterMatch-afterLoad)/1024)
dumpMemProfile(_Func() + "3.pprof")
dumpMemProfile("tests/" + _Func() + "3.pprof")
}
func TestRuleToRegexp(t *testing.T) {
@ -282,7 +285,7 @@ func (d *Dnsfilter) checkMatch(t *testing.T, hostname string) {
}
}
func (d *Dnsfilter) checkMatchIp(t *testing.T, hostname string, ip string) {
func (d *Dnsfilter) checkMatchIP(t *testing.T, hostname string, ip string) {
t.Helper()
ret, err := d.CheckHost(hostname)
if err != nil {
@ -291,8 +294,8 @@ func (d *Dnsfilter) checkMatchIp(t *testing.T, hostname string, ip string) {
if !ret.IsFiltered {
t.Errorf("Expected hostname %s to match", hostname)
}
if ret.Ip == nil || ret.Ip.String() != ip {
t.Errorf("Expected ip %s to match, actual: %v", ip, ret.Ip)
if ret.IP == nil || ret.IP.String() != ip {
t.Errorf("Expected ip %s to match, actual: %v", ip, ret.IP)
}
}
@ -308,7 +311,7 @@ func (d *Dnsfilter) checkMatchEmpty(t *testing.T, hostname string) {
}
func loadTestRules(d *Dnsfilter) error {
filterFileName := "../tests/dns.txt"
filterFileName := "tests/dns.txt"
file, err := os.Open(filterFileName)
if err != nil {
return err
@ -368,8 +371,8 @@ func TestEtcHostsMatching(t *testing.T) {
text := fmt.Sprintf(" %s google.com www.google.com # enforce google's safesearch ", addr)
d.checkAddRule(t, text)
d.checkMatchIp(t, "google.com", addr)
d.checkMatchIp(t, "www.google.com", addr)
d.checkMatchIP(t, "google.com", addr)
d.checkMatchIP(t, "www.google.com", addr)
d.checkMatchEmpty(t, "subdomain.google.com")
d.checkMatchEmpty(t, "example.org")
}
@ -866,7 +869,7 @@ func BenchmarkLotsOfRulesLotsOfHosts(b *testing.B) {
for n := 0; n < b.N; n++ {
havedata := scanner.Scan()
if !havedata {
hostnames.Seek(0, 0)
_, _ = hostnames.Seek(0, 0)
scanner = bufio.NewScanner(hostnames)
havedata = scanner.Scan()
}
@ -903,7 +906,7 @@ func BenchmarkLotsOfRulesLotsOfHostsParallel(b *testing.B) {
for pb.Next() {
havedata := scanner.Scan()
if !havedata {
hostnames.Seek(0, 0)
_, _ = hostnames.Seek(0, 0)
scanner = bufio.NewScanner(hostnames)
havedata = scanner.Scan()
}

View File

@ -310,8 +310,8 @@ func (s *Server) genDNSFilterMessage(d *proxy.DNSContext, result *dnsfilter.Resu
case dnsfilter.FilteredParental:
return s.genBlockedHost(m, parentalBlockHost, d.Upstream)
default:
if result.Ip != nil {
return s.genARecord(m, result.Ip)
if result.IP != nil {
return s.genARecord(m, result.IP)
}
return s.genNXDomain(m)

View File

@ -108,10 +108,16 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result *dnsfilter.Result, el
if len(flushBuffer) > 0 {
// write to file
// do it in separate goroutine -- we are stalling DNS response this whole time
go flushToFile(flushBuffer)
go func() {
err := flushToFile(flushBuffer)
if err != nil {
log.Printf("Failed to flush the query log: %s", err)
}
}()
}
}
// HandleQueryLog handles query log web request
func HandleQueryLog(w http.ResponseWriter, r *http.Request) {
queryLogLock.RLock()
values := make([]*logEntry, len(queryLogCache))
@ -123,6 +129,7 @@ func HandleQueryLog(w http.ResponseWriter, r *http.Request) {
values[left], values[right] = values[right], values[left]
}
// iterate
var data = []map[string]interface{}{}
for _, entry := range values {
var q *dns.Msg
@ -167,7 +174,36 @@ func HandleQueryLog(w http.ResponseWriter, r *http.Request) {
jsonEntry["filterId"] = entry.Result.FilterID
}
if a != nil && len(a.Answer) > 0 {
answers := answerToMap(a)
if answers != nil {
jsonEntry["answer"] = answers
}
data = append(data, jsonEntry)
}
jsonVal, err := json.Marshal(data)
if err != nil {
errorText := fmt.Sprintf("Couldn't marshal data into json: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
}
}
func answerToMap(a *dns.Msg) []map[string]interface{} {
if a == nil || len(a.Answer) == 0 {
return nil
}
var answers = []map[string]interface{}{}
for _, k := range a.Answer {
header := k.Header()
@ -207,27 +243,8 @@ func HandleQueryLog(w http.ResponseWriter, r *http.Request) {
}
answers = append(answers, answer)
}
jsonEntry["answer"] = answers
}
data = append(data, jsonEntry)
}
jsonVal, err := json.Marshal(data)
if err != nil {
errorText := fmt.Sprintf("Couldn't marshal data into json: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
}
return answers
}
// getIPString is a helper function that extracts IP address from net.Addr

View File

@ -156,7 +156,7 @@ func periodicQueryLogRotate() {
func genericLoader(onEntry func(entry *logEntry) error, needMore func() bool, timeWindow time.Duration) error {
now := time.Now()
// read from querylog files, try newest file first
files := []string{}
var files []string
if enableGzip {
files = []string{

View File

@ -26,10 +26,10 @@ type hourTop struct {
mutex sync.RWMutex
}
func (top *hourTop) init() {
top.domains = gcache.New(queryLogTopSize).LRU().Build()
top.blocked = gcache.New(queryLogTopSize).LRU().Build()
top.clients = gcache.New(queryLogTopSize).LRU().Build()
func (h *hourTop) init() {
h.domains = gcache.New(queryLogTopSize).LRU().Build()
h.blocked = gcache.New(queryLogTopSize).LRU().Build()
h.clients = gcache.New(queryLogTopSize).LRU().Build()
}
type dayTop struct {
@ -69,9 +69,9 @@ func periodicHourlyTopRotate() {
}
}
func (top *hourTop) incrementValue(key string, cache gcache.Cache) error {
top.Lock()
defer top.Unlock()
func (h *hourTop) incrementValue(key string, cache gcache.Cache) error {
h.Lock()
defer h.Unlock()
ivalue, err := cache.Get(key)
if err == gcache.KeyNotFoundError {
// we just set it and we're done
@ -103,20 +103,20 @@ func (top *hourTop) incrementValue(key string, cache gcache.Cache) error {
return nil
}
func (top *hourTop) incrementDomains(key string) error {
return top.incrementValue(key, top.domains)
func (h *hourTop) incrementDomains(key string) error {
return h.incrementValue(key, h.domains)
}
func (top *hourTop) incrementBlocked(key string) error {
return top.incrementValue(key, top.blocked)
func (h *hourTop) incrementBlocked(key string) error {
return h.incrementValue(key, h.blocked)
}
func (top *hourTop) incrementClients(key string) error {
return top.incrementValue(key, top.clients)
func (h *hourTop) incrementClients(key string) error {
return h.incrementValue(key, h.clients)
}
// if does not exist -- return 0
func (top *hourTop) lockedGetValue(key string, cache gcache.Cache) (int, error) {
func (h *hourTop) lockedGetValue(key string, cache gcache.Cache) (int, error) {
ivalue, err := cache.Get(key)
if err == gcache.KeyNotFoundError {
return 0, nil
@ -137,19 +137,19 @@ func (top *hourTop) lockedGetValue(key string, cache gcache.Cache) (int, error)
return value, nil
}
func (top *hourTop) lockedGetDomains(key string) (int, error) {
return top.lockedGetValue(key, top.domains)
func (h *hourTop) lockedGetDomains(key string) (int, error) {
return h.lockedGetValue(key, h.domains)
}
func (top *hourTop) lockedGetBlocked(key string) (int, error) {
return top.lockedGetValue(key, top.blocked)
func (h *hourTop) lockedGetBlocked(key string) (int, error) {
return h.lockedGetValue(key, h.blocked)
}
func (top *hourTop) lockedGetClients(key string) (int, error) {
return top.lockedGetValue(key, top.clients)
func (h *hourTop) lockedGetClients(key string) (int, error) {
return h.lockedGetValue(key, h.clients)
}
func (r *dayTop) addEntry(entry *logEntry, q *dns.Msg, now time.Time) error {
func (d *dayTop) addEntry(entry *logEntry, q *dns.Msg, now time.Time) error {
// figure out which hour bucket it belongs to
hour := int(now.Sub(entry.Time).Hours())
if hour >= 24 {
@ -252,6 +252,7 @@ func fillStatsFromQueryLog() error {
return nil
}
// HandleStatsTop returns the current top stats
func HandleStatsTop(w http.ResponseWriter, r *http.Request) {
domains := map[string]int{}
blocked := map[string]int{}
@ -320,9 +321,9 @@ func HandleStatsTop(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, err := w.Write(json.Bytes())
if err != nil {
errortext := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errortext)
http.Error(w, errortext, http.StatusInternalServerError)
errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
}
}

View File

@ -17,7 +17,6 @@ var (
filteredLists = newDNSCounter("filtered_lists_total")
filteredSafebrowsing = newDNSCounter("filtered_safebrowsing_total")
filteredParental = newDNSCounter("filtered_parental_total")
filteredInvalid = newDNSCounter("filtered_invalid_total")
whitelisted = newDNSCounter("whitelisted_total")
safesearch = newDNSCounter("safesearch_total")
errorsTotal = newDNSCounter("errors_total")
@ -91,10 +90,10 @@ func (p *periodicStats) Observe(name string, when time.Time, value float64) {
{
countname := name + "_count"
currentValues := p.Entries[countname]
value := currentValues[elapsed]
v := currentValues[elapsed]
// log.Tracef("Will change p.Entries[%s][%d] from %v to %v", countname, elapsed, value, value+1)
value += 1
currentValues[elapsed] = value
v++
currentValues[elapsed] = v
p.Entries[countname] = currentValues
}
{
@ -224,6 +223,7 @@ func incrementCounters(entry *logEntry) {
elapsedTime.ObserveWithTime(entry.Elapsed.Seconds(), entry.Time)
}
// HandleStats returns aggregated stats data for the 24 hours
func HandleStats(w http.ResponseWriter, r *http.Request) {
const numHours = 24
histrical := generateMapFromStats(&statistics.PerHour, 0, numHours)
@ -252,17 +252,17 @@ func HandleStats(w http.ResponseWriter, r *http.Request) {
json, err := json.Marshal(summed)
if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errortext)
http.Error(w, errortext, 500)
errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errorText)
http.Error(w, errorText, 500)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(json)
if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext)
http.Error(w, errortext, 500)
errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errorText)
http.Error(w, errorText, 500)
return
}
}
@ -296,6 +296,7 @@ func generateMapFromStats(stats *periodicStats, start int, end int) map[string]i
return result
}
// HandleStatsHistory returns historical stats data for the 24 hours
func HandleStatsHistory(w http.ResponseWriter, r *http.Request) {
// handle time unit and prepare our time window size
now := time.Now()
@ -323,27 +324,27 @@ func HandleStatsHistory(w http.ResponseWriter, r *http.Request) {
// parse start and end time
startTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("start_time"))
if err != nil {
errortext := fmt.Sprintf("Must specify valid start_time parameter: %s", err)
log.Println(errortext)
http.Error(w, errortext, 400)
errorText := fmt.Sprintf("Must specify valid start_time parameter: %s", err)
log.Println(errorText)
http.Error(w, errorText, 400)
return
}
endTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("end_time"))
if err != nil {
errortext := fmt.Sprintf("Must specify valid end_time parameter: %s", err)
log.Println(errortext)
http.Error(w, errortext, 400)
errorText := fmt.Sprintf("Must specify valid end_time parameter: %s", err)
log.Println(errorText)
http.Error(w, errorText, 400)
return
}
// check if start and time times are within supported time range
timeRange := timeUnit * statsHistoryElements
if startTime.Add(timeRange).Before(now) {
http.Error(w, "start_time parameter is outside of supported range", 501)
http.Error(w, "start_time parameter is outside of supported range", http.StatusBadRequest)
return
}
if endTime.Add(timeRange).Before(now) {
http.Error(w, "end_time parameter is outside of supported range", 501)
http.Error(w, "end_time parameter is outside of supported range", http.StatusBadRequest)
return
}
@ -360,28 +361,29 @@ func HandleStatsHistory(w http.ResponseWriter, r *http.Request) {
data := generateMapFromStats(stats, start, end)
json, err := json.Marshal(data)
if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errortext)
http.Error(w, errortext, 500)
errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errorText)
http.Error(w, errorText, 500)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(json)
if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext)
http.Error(w, errortext, 500)
errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errorText)
http.Error(w, errorText, 500)
return
}
}
// HandleStatsReset resets the stats caches
func HandleStatsReset(w http.ResponseWriter, r *http.Request) {
purgeStats()
_, err := fmt.Fprintf(w, "OK\n")
if err != nil {
errortext := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errortext)
http.Error(w, errortext, http.StatusInternalServerError)
errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
}
}

View File

@ -70,20 +70,20 @@ func updateUniqueFilterID(filters []filter) {
func assignUniqueFilterID() int64 {
value := nextFilterID
nextFilterID += 1
nextFilterID++
return value
}
// Sets up a timer that will be checking for filters updates periodically
func periodicallyRefreshFilters() {
for range time.Tick(time.Minute) {
refreshFiltersIfNeccessary(false)
refreshFiltersIfNecessary(false)
}
}
// Checks filters updates if necessary
// If force is true, it ignores the filter.LastUpdated field value
func refreshFiltersIfNeccessary(force bool) int {
func refreshFiltersIfNecessary(force bool) int {
config.Lock()
// fetch URLs
@ -113,8 +113,12 @@ func refreshFiltersIfNeccessary(force bool) int {
}
config.Unlock()
if updateCount > 0 {
reconfigureDNSServer()
if updateCount > 0 && isRunning() {
err := reconfigureDNSServer()
if err != nil {
msg := fmt.Sprintf("SHOULD NOT HAPPEN: cannot reconfigure DNS server with the new filters: %s", err)
panic(msg)
}
}
return updateCount
}

View File

@ -39,7 +39,7 @@ func safeWriteFile(path string, data []byte) error {
func ensure(method string, handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != method {
http.Error(w, "This request must be "+method, 405)
http.Error(w, "This request must be "+method, http.StatusMethodNotAllowed)
return
}
handler(w, r)

View File

@ -34,9 +34,9 @@ func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
log.Printf("config.Language is %s", config.Language)
_, err := fmt.Fprintf(w, "%s\n", config.Language)
if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext)
http.Error(w, errortext, http.StatusInternalServerError)
errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
return
}
}

View File

@ -7,7 +7,7 @@ import (
"path/filepath"
"github.com/hmage/golibs/log"
"gopkg.in/yaml.v2"
yaml "gopkg.in/yaml.v2"
)
const currentSchemaVersion = 2 // used for upgrading from old configs to new config