Merge pull request #117 in DNS/adguard-dns from no_coredns to master
* commit '253d8a4016d66863ecee426b8f7d74841c4ed4de': (58 commits) Pointer for dnsfilter.Result in querylog didn't make things simpler, revert that change and all related changes. Fixup of previous commit -- remove unused import. Remove unused code. Use filter deduplication function. Small code review update -- use CamelCase readme -- Update config field descriptions and clarify about coredns. dnsforward -- fix panic on ANY request dnsfilter -- fix broken tests config -- Avoid deleting existing dns section if someone removes schema_version from yaml file. Rename coredns.go to dns.go Add support for bootstrapping upstream DNS servers by hostname. dnsforward -- support tcp:// schema dnsforward -- add upstream tests. Don't omit empty user rules in configfile -- otherwise users might not be able to find that it's customizable in configfile. Get rid of mentions of CoreDNS in code except for upgrading and in readme. Add config upgrade. dnsforward -- add a simple test that launches a server and queries well-known value through it Remove old entries from .gitignore Remove unused code. Goodbye CoreDNS. Use dnsforward for checking if upstream DNS server is working. dnsforward -- implement ratelimit and refuseany ...
This commit is contained in:
commit
b5121c5754
8
.gitignore
vendored
8
.gitignore
vendored
@ -1,15 +1,11 @@
|
||||
.DS_Store
|
||||
.vscode
|
||||
.idea
|
||||
debug
|
||||
/.vscode
|
||||
/.idea
|
||||
/AdGuardHome
|
||||
/AdGuardHome.yaml
|
||||
/data/
|
||||
/build/
|
||||
/client/node_modules/
|
||||
/coredns
|
||||
/Corefile
|
||||
/dnsfilter.txt
|
||||
/querylog.json
|
||||
/querylog.json.1
|
||||
/scripts/translations/node_modules
|
||||
|
9
Makefile
9
Makefile
@ -19,9 +19,12 @@ client/node_modules: client/package.json client/package-lock.json
|
||||
$(STATIC): $(JSFILES) client/node_modules
|
||||
npm --prefix client run build-prod
|
||||
|
||||
$(TARGET): $(STATIC) *.go coredns_plugin/*.go dnsfilter/*.go
|
||||
GOPATH=$(GOPATH) GOOS=$(NATIVE_GOOS) GOARCH=$(NATIVE_GOARCH) GO111MODULE=off go get -v github.com/gobuffalo/packr/...
|
||||
GOPATH=$(GOPATH) PATH=$(GOPATH)/bin:$(PATH) packr build -ldflags="-X main.VersionString=$(GIT_VERSION)" -asmflags="-trimpath=$(PWD)" -gcflags="-trimpath=$(PWD)" -o $(TARGET)
|
||||
$(TARGET): $(STATIC) *.go dnsfilter/*.go dnsforward/*.go
|
||||
go get -d .
|
||||
GOOS=$(NATIVE_GOOS) GOARCH=$(NATIVE_GOARCH) GO111MODULE=off go get -v github.com/gobuffalo/packr/...
|
||||
PATH=$(GOPATH)/bin:$(PATH) packr -z
|
||||
CGO_ENABLED=0 go build -ldflags="-s -w -X main.VersionString=$(GIT_VERSION)" -asmflags="-trimpath=$(PWD)" -gcflags="-trimpath=$(PWD)"
|
||||
PATH=$(GOPATH)/bin:$(PATH) packr clean
|
||||
|
||||
clean:
|
||||
$(MAKE) cleanfast
|
||||
|
66
README.md
66
README.md
@ -90,7 +90,7 @@ Now open the browser and navigate to http://localhost:3000/ to control your AdGu
|
||||
You can run AdGuard Home without superuser privileges, but you need to instruct it to use a different port rather than 53. You can do that by editing `AdGuardHome.yaml` and finding these two lines:
|
||||
|
||||
```yaml
|
||||
coredns:
|
||||
dns:
|
||||
port: 53
|
||||
```
|
||||
|
||||
@ -104,25 +104,32 @@ Upon the first execution, a file named `AdGuardHome.yaml` will be created, with
|
||||
|
||||
Settings are stored in [YAML format](https://en.wikipedia.org/wiki/YAML), possible parameters that you can configure are listed below:
|
||||
|
||||
* `bind_host` — Web interface IP address to listen on
|
||||
* `bind_port` — Web interface IP port to listen on
|
||||
* `auth_name` — Web interface optional authorization username
|
||||
* `auth_pass` — Web interface optional authorization password
|
||||
* `coredns` — CoreDNS configuration section
|
||||
* `port` — DNS server port to listen on
|
||||
* `filtering_enabled` — Filtering of DNS requests based on filter lists
|
||||
* `safebrowsing_enabled` — Filtering of DNS requests based on safebrowsing
|
||||
* `safesearch_enabled` — Enforcing "Safe search" option for search engines, when possible
|
||||
* `parental_enabled` — Parental control-based DNS requests filtering
|
||||
* `parental_sensitivity` — Age group for parental control-based filtering, must be either 3, 10, 13 or 17
|
||||
* `querylog_enabled` — Query logging (also used to calculate top 50 clients, blocked domains and requested domains for statistic purposes)
|
||||
* `bootstrap_dns` — DNS server used for initial hostnames resolution in case if upstream is DoH or DoT with a hostname
|
||||
* `upstream_dns` — List of upstream DNS servers
|
||||
* `bind_host` — Web interface IP address to listen on.
|
||||
* `bind_port` — Web interface IP port to listen on.
|
||||
* `auth_name` — Web interface optional authorization username.
|
||||
* `auth_pass` — Web interface optional authorization password.
|
||||
* `dns` — DNS configuration section.
|
||||
* `port` — DNS server port to listen on.
|
||||
* `protection_enabled` — Whether any kind of filtering and protection should be done, when off it works as a plain dns forwarder.
|
||||
* `filtering_enabled` — Filtering of DNS requests based on filter lists.
|
||||
* `blocked_response_ttl` — For how many seconds the clients should cache a filtered response. Low values are useful on LAN if you change filters very often, high values are useful to increase performance and save traffic.
|
||||
* `querylog_enabled` — Query logging (also used to calculate top 50 clients, blocked domains and requested domains for statistical purposes).
|
||||
* `ratelimit` — DDoS protection, specifies in how many packets per second a client should receive. Anything above that is silently dropped. To disable set 0, default is 20. Safe to disable if DNS server is not available from internet.
|
||||
* `ratelimit_whitelist` — If you want exclude some IP addresses from ratelimiting but keep ratelimiting on for others, put them here.
|
||||
* `refuse_any` — Another DDoS protection mechanism. Requests of type ANY are rarely needed, so refusing to serve them mitigates against attackers trying to use your DNS as a reflection. Safe to disable if DNS server is not available from internet.
|
||||
* `bootstrap_dns` — DNS server used for initial hostname resolution in case if upstream server name is a hostname.
|
||||
* `parental_sensitivity` — Age group for parental control-based filtering, must be either 3, 10, 13 or 17 if enabled.
|
||||
* `parental_enabled` — Parental control-based DNS requests filtering.
|
||||
* `safesearch_enabled` — Enforcing "Safe search" option for search engines, when possible.
|
||||
* `safebrowsing_enabled` — Filtering of DNS requests based on safebrowsing.
|
||||
* `upstream_dns` — List of upstream DNS servers.
|
||||
* `filters` — List of filters, each filter has the following values:
|
||||
* `ID` - filter ID (must be unique)
|
||||
* `url` — URL pointing to the filter contents (filtering rules)
|
||||
* `enabled` — Current filter's status (enabled/disabled)
|
||||
* `user_rules` — User-specified filtering rules
|
||||
* `enabled` — Current filter's status (enabled/disabled).
|
||||
* `url` — URL pointing to the filter contents (filtering rules).
|
||||
* `name` — Name of the filter. If it's an adguard syntax filter it will get updated automatically, otherwise it stays unchanged.
|
||||
* `last_updated` — Time when the filter was last updated from server.
|
||||
* `ID` - filter ID (must be unique).
|
||||
* `user_rules` — User-specified filtering rules.
|
||||
|
||||
Removing an entry from settings file will reset it to the default value. Deleting the file will reset all settings to the default values.
|
||||
|
||||
@ -151,7 +158,15 @@ cd AdGuardHome
|
||||
make
|
||||
```
|
||||
|
||||
## How to update translations
|
||||
## Contributing
|
||||
|
||||
You are welcome to fork this repository, make your changes and submit a pull request — https://github.com/AdguardTeam/AdGuardHome/pulls
|
||||
|
||||
### How to update translations
|
||||
|
||||
If you want to help with AdGuard Home translations, please learn more about translating AdGuard products here: https://kb.adguard.com/en/general/adguard-translations
|
||||
|
||||
Here is a direct link to AdGuard Home project: http://translate.adguard.com/collaboration/project?id=153384
|
||||
|
||||
Before updating translations you need to install dependencies:
|
||||
```
|
||||
@ -181,14 +196,6 @@ node upload.js
|
||||
node download.js
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
You are welcome to fork this repository, make your changes and submit a pull request — https://github.com/AdguardTeam/AdGuardHome/pulls
|
||||
|
||||
If you want to help with AdGuard Home translations, please learn more about translating AdGuard products here: https://kb.adguard.com/en/general/adguard-translations
|
||||
|
||||
Here is a direct link to AdGuard Home project: http://translate.adguard.com/collaboration/project?id=153384
|
||||
|
||||
## Reporting issues
|
||||
|
||||
If you run into any problem or have a suggestion, head to [this page](https://github.com/AdguardTeam/AdGuardHome/issues) and click on the `New issue` button.
|
||||
@ -198,7 +205,6 @@ If you run into any problem or have a suggestion, head to [this page](https://gi
|
||||
This software wouldn't have been possible without:
|
||||
|
||||
* [Go](https://golang.org/dl/) and it's libraries:
|
||||
* [CoreDNS](https://coredns.io)
|
||||
* [packr](https://github.com/gobuffalo/packr)
|
||||
* [gcache](https://github.com/bluele/gcache)
|
||||
* [miekg's dns](https://github.com/miekg/dns)
|
||||
@ -209,4 +215,6 @@ This software wouldn't have been possible without:
|
||||
* And many more node.js packages.
|
||||
* [whotracks.me data](https://github.com/cliqz-oss/whotracks.me)
|
||||
|
||||
You might have seen that [CoreDNS](https://coredns.io) was mentioned here before — we've stopped using it in AdGuardHome. While we still use it on our servers for [AdGuard DNS](https://adguard.com/adguard-dns/overview.html) service, it seemed like an overkill for Home as it impeded with Home features that we plan to implement.
|
||||
|
||||
For a full list of all node.js packages in use, please take a look at [client/package.json](https://github.com/AdguardTeam/AdGuardHome/blob/master/client/package.json) file.
|
||||
|
22
app.go
22
app.go
@ -7,8 +7,10 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gobuffalo/packr"
|
||||
@ -149,7 +151,7 @@ func main() {
|
||||
log.Printf("Couldn't load filter %d contents due to %s", filter.ID, err)
|
||||
// clear LastUpdated so it gets fetched right away
|
||||
}
|
||||
if len(filter.Contents) == 0 {
|
||||
if len(filter.Rules) == 0 {
|
||||
filter.LastUpdated = time.Time{}
|
||||
}
|
||||
}
|
||||
@ -164,10 +166,13 @@ func main() {
|
||||
}
|
||||
}()
|
||||
|
||||
// Eat all args so that coredns can start happily
|
||||
if len(os.Args) > 1 {
|
||||
os.Args = os.Args[:1]
|
||||
}
|
||||
signalChannel := make(chan os.Signal)
|
||||
signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
|
||||
go func() {
|
||||
<-signalChannel
|
||||
cleanup()
|
||||
os.Exit(0)
|
||||
}()
|
||||
|
||||
// Save the updated config
|
||||
err := config.write()
|
||||
@ -192,6 +197,13 @@ func main() {
|
||||
log.Fatal(http.ListenAndServe(address, nil))
|
||||
}
|
||||
|
||||
func cleanup() {
|
||||
err := stopDNSServer()
|
||||
if err != nil {
|
||||
log.Printf("Couldn't stop DNS server: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func getInput() (string, error) {
|
||||
scanner := bufio.NewScanner(os.Stdin)
|
||||
scanner.Scan()
|
||||
|
254
config.go
254
config.go
@ -1,43 +1,36 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sync"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
const (
|
||||
currentSchemaVersion = 1 // used for upgrading from old configs to new config
|
||||
dataDir = "data" // data storage
|
||||
filterDir = "filters" // cache location for downloaded filters, it's under DataDir
|
||||
userFilterID = 0 // special filter ID, always 0
|
||||
dataDir = "data" // data storage
|
||||
filterDir = "filters" // cache location for downloaded filters, it's under DataDir
|
||||
)
|
||||
|
||||
// Just a counter that we use for incrementing the filter ID
|
||||
var nextFilterID int64 = time.Now().Unix()
|
||||
|
||||
// configuration is loaded from YAML
|
||||
// field ordering is important -- yaml fields will mirror ordering from here
|
||||
type configuration struct {
|
||||
ourConfigFilename string // Config filename (can be overriden via the command line arguments)
|
||||
ourBinaryDir string // Location of our directory, used to protect against CWD being somewhere else
|
||||
|
||||
BindHost string `yaml:"bind_host"`
|
||||
BindPort int `yaml:"bind_port"`
|
||||
AuthName string `yaml:"auth_name"`
|
||||
AuthPass string `yaml:"auth_pass"`
|
||||
Language string `yaml:"language"` // two-letter ISO 639-1 language code
|
||||
CoreDNS coreDNSConfig `yaml:"coredns"`
|
||||
Filters []filter `yaml:"filters"`
|
||||
UserRules []string `yaml:"user_rules,omitempty"`
|
||||
BindHost string `yaml:"bind_host"`
|
||||
BindPort int `yaml:"bind_port"`
|
||||
AuthName string `yaml:"auth_name"`
|
||||
AuthPass string `yaml:"auth_pass"`
|
||||
Language string `yaml:"language"` // two-letter ISO 639-1 language code
|
||||
DNS dnsConfig `yaml:"dns"`
|
||||
Filters []filter `yaml:"filters"`
|
||||
UserRules []string `yaml:"user_rules"`
|
||||
|
||||
sync.RWMutex `yaml:"-"`
|
||||
|
||||
@ -45,38 +38,12 @@ type configuration struct {
|
||||
}
|
||||
|
||||
// field ordering is important -- yaml fields will mirror ordering from here
|
||||
type coreDNSConfig struct {
|
||||
binaryFile string
|
||||
coreFile string
|
||||
Filters []filter `yaml:"-"`
|
||||
Port int `yaml:"port"`
|
||||
ProtectionEnabled bool `yaml:"protection_enabled"`
|
||||
FilteringEnabled bool `yaml:"filtering_enabled"`
|
||||
SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"`
|
||||
SafeSearchEnabled bool `yaml:"safesearch_enabled"`
|
||||
ParentalEnabled bool `yaml:"parental_enabled"`
|
||||
ParentalSensitivity int `yaml:"parental_sensitivity"`
|
||||
BlockedResponseTTL int `yaml:"blocked_response_ttl"`
|
||||
QueryLogEnabled bool `yaml:"querylog_enabled"`
|
||||
Ratelimit int `yaml:"ratelimit"`
|
||||
RefuseAny bool `yaml:"refuse_any"`
|
||||
Pprof string `yaml:"-"`
|
||||
Cache string `yaml:"-"`
|
||||
Prometheus string `yaml:"-"`
|
||||
BootstrapDNS string `yaml:"bootstrap_dns"`
|
||||
UpstreamDNS []string `yaml:"upstream_dns"`
|
||||
}
|
||||
type dnsConfig struct {
|
||||
Port int `yaml:"port"`
|
||||
|
||||
// field ordering is important -- yaml fields will mirror ordering from here
|
||||
type filter struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
URL string `json:"url"`
|
||||
Name string `json:"name" yaml:"name"`
|
||||
RulesCount int `json:"rulesCount" yaml:"-"`
|
||||
LastUpdated time.Time `json:"lastUpdated,omitempty" yaml:"last_updated,omitempty"`
|
||||
ID int64 `json:"id"` // auto-assigned when filter is added (see nextFilterID), json by default keeps ID uppercase but we need lowercase
|
||||
dnsforward.FilteringConfig `yaml:",inline"`
|
||||
|
||||
Contents []byte `json:"-" yaml:"-"` // not in yaml or json
|
||||
UpstreamDNS []string `yaml:"upstream_dns"`
|
||||
}
|
||||
|
||||
var defaultDNS = []string{"tls://1.1.1.1", "tls://1.0.0.1"}
|
||||
@ -86,47 +53,26 @@ var config = configuration{
|
||||
ourConfigFilename: "AdGuardHome.yaml",
|
||||
BindPort: 3000,
|
||||
BindHost: "127.0.0.1",
|
||||
CoreDNS: coreDNSConfig{
|
||||
Port: 53,
|
||||
binaryFile: "coredns", // only filename, no path
|
||||
coreFile: "Corefile", // only filename, no path
|
||||
ProtectionEnabled: true,
|
||||
FilteringEnabled: true,
|
||||
SafeBrowsingEnabled: false,
|
||||
BlockedResponseTTL: 10, // in seconds
|
||||
QueryLogEnabled: true,
|
||||
Ratelimit: 20,
|
||||
RefuseAny: true,
|
||||
BootstrapDNS: "8.8.8.8:53",
|
||||
UpstreamDNS: defaultDNS,
|
||||
Cache: "cache",
|
||||
Prometheus: "prometheus :9153",
|
||||
DNS: dnsConfig{
|
||||
Port: 53,
|
||||
FilteringConfig: dnsforward.FilteringConfig{
|
||||
ProtectionEnabled: true, // whether or not use any of dnsfilter features
|
||||
FilteringEnabled: true, // whether or not use filter lists
|
||||
BlockedResponseTTL: 10, // in seconds
|
||||
QueryLogEnabled: true,
|
||||
Ratelimit: 20,
|
||||
RefuseAny: true,
|
||||
BootstrapDNS: "8.8.8.8:53",
|
||||
},
|
||||
UpstreamDNS: defaultDNS,
|
||||
},
|
||||
Filters: []filter{
|
||||
{ID: 1, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard Simplified Domain Names filter"},
|
||||
{ID: 2, Enabled: false, URL: "https://adaway.org/hosts.txt", Name: "AdAway"},
|
||||
{ID: 3, Enabled: false, URL: "https://hosts-file.net/ad_servers.txt", Name: "hpHosts - Ad and Tracking servers only"},
|
||||
{ID: 4, Enabled: false, URL: "http://www.malwaredomainlist.com/hostslist/hosts.txt", Name: "MalwareDomainList.com Hosts List"},
|
||||
{Filter: dnsfilter.Filter{ID: 1}, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard Simplified Domain Names filter"},
|
||||
{Filter: dnsfilter.Filter{ID: 2}, Enabled: false, URL: "https://adaway.org/hosts.txt", Name: "AdAway"},
|
||||
{Filter: dnsfilter.Filter{ID: 3}, Enabled: false, URL: "https://hosts-file.net/ad_servers.txt", Name: "hpHosts - Ad and Tracking servers only"},
|
||||
{Filter: dnsfilter.Filter{ID: 4}, Enabled: false, URL: "http://www.malwaredomainlist.com/hostslist/hosts.txt", Name: "MalwareDomainList.com Hosts List"},
|
||||
},
|
||||
}
|
||||
|
||||
// Creates a helper object for working with the user rules
|
||||
func userFilter() filter {
|
||||
// TODO: This should be calculated when UserRules are set
|
||||
var contents []byte
|
||||
for _, rule := range config.UserRules {
|
||||
contents = append(contents, []byte(rule)...)
|
||||
contents = append(contents, '\n')
|
||||
}
|
||||
|
||||
userFilter := filter{
|
||||
// User filter always has constant ID=0
|
||||
ID: userFilterID,
|
||||
Contents: contents,
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
return userFilter
|
||||
SchemaVersion: currentSchemaVersion,
|
||||
}
|
||||
|
||||
// Loads configuration from the YAML file
|
||||
@ -150,20 +96,7 @@ func parseConfig() error {
|
||||
}
|
||||
|
||||
// Deduplicate filters
|
||||
{
|
||||
i := 0 // output index, used for deletion later
|
||||
urls := map[string]bool{}
|
||||
for _, filter := range config.Filters {
|
||||
if _, ok := urls[filter.URL]; !ok {
|
||||
// we didn't see it before, keep it
|
||||
urls[filter.URL] = true // remember the URL
|
||||
config.Filters[i] = filter
|
||||
i++
|
||||
}
|
||||
}
|
||||
// all entries we want to keep are at front, delete the rest
|
||||
config.Filters = config.Filters[:i]
|
||||
}
|
||||
deduplicateFilters()
|
||||
|
||||
updateUniqueFilterID(config.Filters)
|
||||
|
||||
@ -187,6 +120,16 @@ func (c *configuration) write() error {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeAllConfigs() error {
|
||||
err := config.write()
|
||||
if err != nil {
|
||||
log.Printf("Couldn't write config: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
userFilter := userFilter()
|
||||
err = userFilter.save()
|
||||
if err != nil {
|
||||
@ -196,112 +139,3 @@ func (c *configuration) write() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// --------------
|
||||
// coredns config
|
||||
// --------------
|
||||
func writeCoreDNSConfig() error {
|
||||
coreFile := filepath.Join(config.ourBinaryDir, config.CoreDNS.coreFile)
|
||||
log.Printf("Writing DNS config: %s", coreFile)
|
||||
configText, err := generateCoreDNSConfigText()
|
||||
if err != nil {
|
||||
log.Printf("Couldn't generate DNS config: %s", err)
|
||||
return err
|
||||
}
|
||||
err = safeWriteFile(coreFile, []byte(configText))
|
||||
if err != nil {
|
||||
log.Printf("Couldn't save DNS config: %s", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeAllConfigs() error {
|
||||
err := config.write()
|
||||
if err != nil {
|
||||
log.Printf("Couldn't write our config: %s", err)
|
||||
return err
|
||||
}
|
||||
err = writeCoreDNSConfig()
|
||||
if err != nil {
|
||||
log.Printf("Couldn't write DNS config: %s", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const coreDNSConfigTemplate = `.:{{.Port}} {
|
||||
{{if .ProtectionEnabled}}dnsfilter {
|
||||
{{if .SafeBrowsingEnabled}}safebrowsing{{end}}
|
||||
{{if .ParentalEnabled}}parental {{.ParentalSensitivity}}{{end}}
|
||||
{{if .SafeSearchEnabled}}safesearch{{end}}
|
||||
{{if .QueryLogEnabled}}querylog{{end}}
|
||||
blocked_ttl {{.BlockedResponseTTL}}
|
||||
{{if .FilteringEnabled}}{{range .Filters}}{{if and .Enabled .Contents}}
|
||||
filter {{.ID}} "{{.Path}}"
|
||||
{{end}}{{end}}{{end}}
|
||||
}{{end}}
|
||||
{{.Pprof}}
|
||||
{{if .RefuseAny}}refuseany{{end}}
|
||||
{{if gt .Ratelimit 0}}ratelimit {{.Ratelimit}}{{end}}
|
||||
hosts {
|
||||
fallthrough
|
||||
}
|
||||
{{if .UpstreamDNS}}upstream {{range .UpstreamDNS}}{{.}} {{end}} { bootstrap {{.BootstrapDNS}} }{{end}}
|
||||
{{.Cache}}
|
||||
{{.Prometheus}}
|
||||
}
|
||||
`
|
||||
|
||||
var removeEmptyLines = regexp.MustCompile("([\t ]*\n)+")
|
||||
|
||||
// generate CoreDNS config text
|
||||
func generateCoreDNSConfigText() (string, error) {
|
||||
t, err := template.New("config").Parse(coreDNSConfigTemplate)
|
||||
if err != nil {
|
||||
log.Printf("Couldn't generate DNS config: %s", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
var configBytes bytes.Buffer
|
||||
temporaryConfig := config.CoreDNS
|
||||
|
||||
// generate temporary filter list, needed to put userfilter in coredns config
|
||||
filters := []filter{}
|
||||
|
||||
// first of all, append the user filter
|
||||
userFilter := userFilter()
|
||||
|
||||
filters = append(filters, userFilter)
|
||||
|
||||
// then go through other filters
|
||||
filters = append(filters, config.Filters...)
|
||||
temporaryConfig.Filters = filters
|
||||
|
||||
// run the template
|
||||
err = t.Execute(&configBytes, &temporaryConfig)
|
||||
if err != nil {
|
||||
log.Printf("Couldn't generate DNS config: %s", err)
|
||||
return "", err
|
||||
}
|
||||
configText := configBytes.String()
|
||||
|
||||
// remove empty lines from generated config
|
||||
configText = removeEmptyLines.ReplaceAllString(configText, "\n")
|
||||
return configText, nil
|
||||
}
|
||||
|
||||
// Set the next filter ID to max(filter.ID) + 1
|
||||
func updateUniqueFilterID(filters []filter) {
|
||||
for _, filter := range filters {
|
||||
if nextFilterID < filter.ID {
|
||||
nextFilterID = filter.ID + 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func assignUniqueFilterID() int64 {
|
||||
value := nextFilterID
|
||||
nextFilterID += 1
|
||||
return value
|
||||
}
|
||||
|
298
control.go
298
control.go
@ -1,29 +1,25 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/upstream"
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||
"github.com/miekg/dns"
|
||||
|
||||
corednsplugin "github.com/AdguardTeam/AdGuardHome/coredns_plugin"
|
||||
"gopkg.in/asaskevich/govalidator.v4"
|
||||
)
|
||||
|
||||
const updatePeriod = time.Minute * 30
|
||||
|
||||
var filterTitleRegexp = regexp.MustCompile(`^! Title: +(.*)$`)
|
||||
|
||||
// cached version.json to avoid hammering github.io for each page reload
|
||||
var versionCheckJSON []byte
|
||||
var versionCheckLastTime time.Time
|
||||
@ -36,24 +32,20 @@ var client = &http.Client{
|
||||
}
|
||||
|
||||
// -------------------
|
||||
// coredns run control
|
||||
// dns run control
|
||||
// -------------------
|
||||
func tellCoreDNSToReload() {
|
||||
corednsplugin.Reload <- true
|
||||
}
|
||||
|
||||
func writeAllConfigsAndReloadCoreDNS() error {
|
||||
func writeAllConfigsAndReloadDNS() error {
|
||||
err := writeAllConfigs()
|
||||
if err != nil {
|
||||
log.Printf("Couldn't write all configs: %s", err)
|
||||
return err
|
||||
}
|
||||
tellCoreDNSToReload()
|
||||
reconfigureDNSServer()
|
||||
return nil
|
||||
}
|
||||
|
||||
func httpUpdateConfigReloadDNSReturnOK(w http.ResponseWriter, r *http.Request) {
|
||||
err := writeAllConfigsAndReloadCoreDNS()
|
||||
err := writeAllConfigsAndReloadDNS()
|
||||
if err != nil {
|
||||
errortext := fmt.Sprintf("Couldn't write config file: %s", err)
|
||||
log.Println(errortext)
|
||||
@ -75,12 +67,12 @@ func returnOK(w http.ResponseWriter, r *http.Request) {
|
||||
func handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
data := map[string]interface{}{
|
||||
"dns_address": config.BindHost,
|
||||
"dns_port": config.CoreDNS.Port,
|
||||
"protection_enabled": config.CoreDNS.ProtectionEnabled,
|
||||
"querylog_enabled": config.CoreDNS.QueryLogEnabled,
|
||||
"dns_port": config.DNS.Port,
|
||||
"protection_enabled": config.DNS.ProtectionEnabled,
|
||||
"querylog_enabled": config.DNS.QueryLogEnabled,
|
||||
"running": isRunning(),
|
||||
"bootstrap_dns": config.CoreDNS.BootstrapDNS,
|
||||
"upstream_dns": config.CoreDNS.UpstreamDNS,
|
||||
"bootstrap_dns": config.DNS.BootstrapDNS,
|
||||
"upstream_dns": config.DNS.UpstreamDNS,
|
||||
"version": VersionString,
|
||||
"language": config.Language,
|
||||
}
|
||||
@ -103,12 +95,12 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func handleProtectionEnable(w http.ResponseWriter, r *http.Request) {
|
||||
config.CoreDNS.ProtectionEnabled = true
|
||||
config.DNS.ProtectionEnabled = true
|
||||
httpUpdateConfigReloadDNSReturnOK(w, r)
|
||||
}
|
||||
|
||||
func handleProtectionDisable(w http.ResponseWriter, r *http.Request) {
|
||||
config.CoreDNS.ProtectionEnabled = false
|
||||
config.DNS.ProtectionEnabled = false
|
||||
httpUpdateConfigReloadDNSReturnOK(w, r)
|
||||
}
|
||||
|
||||
@ -116,12 +108,12 @@ func handleProtectionDisable(w http.ResponseWriter, r *http.Request) {
|
||||
// stats
|
||||
// -----
|
||||
func handleQueryLogEnable(w http.ResponseWriter, r *http.Request) {
|
||||
config.CoreDNS.QueryLogEnabled = true
|
||||
config.DNS.QueryLogEnabled = true
|
||||
httpUpdateConfigReloadDNSReturnOK(w, r)
|
||||
}
|
||||
|
||||
func handleQueryLogDisable(w http.ResponseWriter, r *http.Request) {
|
||||
config.CoreDNS.QueryLogEnabled = false
|
||||
config.DNS.QueryLogEnabled = false
|
||||
httpUpdateConfigReloadDNSReturnOK(w, r)
|
||||
}
|
||||
|
||||
@ -143,9 +135,9 @@ func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
hosts := strings.Fields(string(body))
|
||||
|
||||
if len(hosts) == 0 {
|
||||
config.CoreDNS.UpstreamDNS = defaultDNS
|
||||
config.DNS.UpstreamDNS = defaultDNS
|
||||
} else {
|
||||
config.CoreDNS.UpstreamDNS = hosts
|
||||
config.DNS.UpstreamDNS = hosts
|
||||
}
|
||||
|
||||
err = writeAllConfigs()
|
||||
@ -155,7 +147,7 @@ func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, errorText, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
tellCoreDNSToReload()
|
||||
reconfigureDNSServer()
|
||||
_, err = fmt.Fprintf(w, "OK %d servers\n", len(hosts))
|
||||
if err != nil {
|
||||
errorText := fmt.Sprintf("Couldn't write body: %s", err)
|
||||
@ -211,23 +203,32 @@ func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func checkDNS(input string) error {
|
||||
u, err := upstream.NewUpstream(input, config.CoreDNS.BootstrapDNS)
|
||||
|
||||
log.Printf("Checking if DNS %s works...", input)
|
||||
u, err := dnsforward.AddressToUpstream(input, "")
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("Failed to choose upstream for %s: %s", input, err)
|
||||
}
|
||||
defer u.Close()
|
||||
|
||||
alive, err := upstream.IsAlive(u)
|
||||
|
||||
req := dns.Msg{}
|
||||
req.Id = dns.Id()
|
||||
req.RecursionDesired = true
|
||||
req.Question = []dns.Question{
|
||||
{Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
||||
}
|
||||
reply, err := u.Exchange(&req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't communicate with DNS server %s: %s", input, err)
|
||||
}
|
||||
|
||||
if !alive {
|
||||
return fmt.Errorf("DNS server has not passed the healthcheck: %s", input)
|
||||
if len(reply.Answer) != 1 {
|
||||
return fmt.Errorf("DNS server %s returned wrong answer", input)
|
||||
}
|
||||
if t, ok := reply.Answer[0].(*dns.A); ok {
|
||||
if !net.IPv4(8, 8, 8, 8).Equal(t.A) {
|
||||
return fmt.Errorf("DNS server %s returned wrong answer: %v", input, t.A)
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("DNS %s works OK", input)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -242,7 +243,7 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
resp, err := client.Get(versionCheckURL)
|
||||
if err != nil {
|
||||
errortext := fmt.Sprintf("Couldn't get querylog from coredns: %T %s\n", err, err)
|
||||
errortext := fmt.Sprintf("Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err)
|
||||
log.Println(errortext)
|
||||
http.Error(w, errortext, http.StatusBadGateway)
|
||||
return
|
||||
@ -254,7 +255,7 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
// read the body entirely
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
errortext := fmt.Sprintf("Couldn't read response body: %s", err)
|
||||
errortext := fmt.Sprintf("Couldn't read response body from %s: %s", versionCheckURL, err)
|
||||
log.Println(errortext)
|
||||
http.Error(w, errortext, http.StatusBadGateway)
|
||||
return
|
||||
@ -277,18 +278,18 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
// ---------
|
||||
|
||||
func handleFilteringEnable(w http.ResponseWriter, r *http.Request) {
|
||||
config.CoreDNS.FilteringEnabled = true
|
||||
config.DNS.FilteringEnabled = true
|
||||
httpUpdateConfigReloadDNSReturnOK(w, r)
|
||||
}
|
||||
|
||||
func handleFilteringDisable(w http.ResponseWriter, r *http.Request) {
|
||||
config.CoreDNS.FilteringEnabled = false
|
||||
config.DNS.FilteringEnabled = false
|
||||
httpUpdateConfigReloadDNSReturnOK(w, r)
|
||||
}
|
||||
|
||||
func handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
|
||||
data := map[string]interface{}{
|
||||
"enabled": config.CoreDNS.FilteringEnabled,
|
||||
"enabled": config.DNS.FilteringEnabled,
|
||||
}
|
||||
|
||||
config.RLock()
|
||||
@ -376,7 +377,8 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// URL is deemed valid, append it to filters, update config, write new filter file and tell coredns to reload it
|
||||
// URL is deemed valid, append it to filters, update config, write new filter file and tell dns to reload it
|
||||
// TODO: since we directly feed filters in-memory, revisit if writing configs is always neccessary
|
||||
config.Filters = append(config.Filters, filter)
|
||||
err = writeAllConfigs()
|
||||
if err != nil {
|
||||
@ -386,7 +388,7 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
tellCoreDNSToReload()
|
||||
reconfigureDNSServer()
|
||||
|
||||
_, err = fmt.Fprintf(w, "OK %d rules\n", filter.RulesCount)
|
||||
if err != nil {
|
||||
@ -531,199 +533,23 @@ func handleFilteringRefresh(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintf(w, "OK %d filters updated\n", updated)
|
||||
}
|
||||
|
||||
// Sets up a timer that will be checking for filters updates periodically
|
||||
func periodicallyRefreshFilters() {
|
||||
for range time.Tick(time.Minute) {
|
||||
refreshFiltersIfNeccessary(false)
|
||||
}
|
||||
}
|
||||
|
||||
// Checks filters updates if necessary
|
||||
// If force is true, it ignores the filter.LastUpdated field value
|
||||
func refreshFiltersIfNeccessary(force bool) int {
|
||||
config.Lock()
|
||||
|
||||
// fetch URLs
|
||||
updateCount := 0
|
||||
for i := range config.Filters {
|
||||
filter := &config.Filters[i] // otherwise we will be operating on a copy
|
||||
|
||||
if filter.ID == 0 { // protect against users modifying the yaml and removing the ID
|
||||
filter.ID = assignUniqueFilterID()
|
||||
}
|
||||
|
||||
updated, err := filter.update(force)
|
||||
if err != nil {
|
||||
log.Printf("Failed to update filter %s: %s\n", filter.URL, err)
|
||||
continue
|
||||
}
|
||||
if updated {
|
||||
// Saving it to the filters dir now
|
||||
err = filter.save()
|
||||
if err != nil {
|
||||
log.Printf("Failed to save the updated filter %d: %s", filter.ID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
updateCount++
|
||||
}
|
||||
}
|
||||
config.Unlock()
|
||||
|
||||
if updateCount > 0 {
|
||||
tellCoreDNSToReload()
|
||||
}
|
||||
return updateCount
|
||||
}
|
||||
|
||||
// A helper function that parses filter contents and returns a number of rules and a filter name (if there's any)
|
||||
func parseFilterContents(contents []byte) (int, string) {
|
||||
lines := strings.Split(string(contents), "\n")
|
||||
rulesCount := 0
|
||||
name := ""
|
||||
seenTitle := false
|
||||
|
||||
// Count lines in the filter
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if len(line) > 0 && line[0] == '!' {
|
||||
if m := filterTitleRegexp.FindAllStringSubmatch(line, -1); len(m) > 0 && len(m[0]) >= 2 && !seenTitle {
|
||||
name = m[0][1]
|
||||
seenTitle = true
|
||||
}
|
||||
} else if len(line) != 0 {
|
||||
rulesCount++
|
||||
}
|
||||
}
|
||||
|
||||
return rulesCount, name
|
||||
}
|
||||
|
||||
// Checks for filters updates
|
||||
// If "force" is true -- does not check the filter's LastUpdated field
|
||||
// Call "save" to persist the filter contents
|
||||
func (filter *filter) update(force bool) (bool, error) {
|
||||
if filter.ID == 0 { // protect against users deleting the ID
|
||||
filter.ID = assignUniqueFilterID()
|
||||
}
|
||||
if !filter.Enabled {
|
||||
return false, nil
|
||||
}
|
||||
if !force && time.Since(filter.LastUpdated) <= updatePeriod {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
log.Printf("Downloading update for filter %d from %s", filter.ID, filter.URL)
|
||||
|
||||
// use the same update period for failed filter downloads to avoid flooding with requests
|
||||
filter.LastUpdated = time.Now()
|
||||
|
||||
resp, err := client.Get(filter.URL)
|
||||
if resp != nil && resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("Couldn't request filter from URL %s, skipping: %s", filter.URL, err)
|
||||
return false, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
log.Printf("Got status code %d from URL %s, skipping", resp.StatusCode, filter.URL)
|
||||
return false, fmt.Errorf("got status code != 200: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
contentType := strings.ToLower(resp.Header.Get("content-type"))
|
||||
if !strings.HasPrefix(contentType, "text/plain") {
|
||||
log.Printf("Non-text response %s from %s, skipping", contentType, filter.URL)
|
||||
return false, fmt.Errorf("non-text response %s", contentType)
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.Printf("Couldn't fetch filter contents from URL %s, skipping: %s", filter.URL, err)
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Extract filter name and count number of rules
|
||||
rulesCount, filterName := parseFilterContents(body)
|
||||
|
||||
if filterName != "" {
|
||||
filter.Name = filterName
|
||||
}
|
||||
|
||||
// Check if the filter has been really changed
|
||||
if bytes.Equal(filter.Contents, body) {
|
||||
log.Printf("The filter %d text has not changed", filter.ID)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
log.Printf("Filter %d has been updated: %d bytes, %d rules", filter.ID, len(body), rulesCount)
|
||||
filter.RulesCount = rulesCount
|
||||
filter.Contents = body
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// saves filter contents to the file in dataDir
|
||||
func (filter *filter) save() error {
|
||||
filterFilePath := filter.Path()
|
||||
log.Printf("Saving filter %d contents to: %s", filter.ID, filterFilePath)
|
||||
|
||||
return safeWriteFile(filterFilePath, filter.Contents)
|
||||
}
|
||||
|
||||
// loads filter contents from the file in dataDir
|
||||
func (filter *filter) load() error {
|
||||
if !filter.Enabled {
|
||||
// No need to load a filter that is not enabled
|
||||
return nil
|
||||
}
|
||||
|
||||
filterFilePath := filter.Path()
|
||||
log.Printf("Loading filter %d contents to: %s", filter.ID, filterFilePath)
|
||||
|
||||
if _, err := os.Stat(filterFilePath); os.IsNotExist(err) {
|
||||
// do nothing, file doesn't exist
|
||||
return err
|
||||
}
|
||||
|
||||
filterFileContents, err := ioutil.ReadFile(filterFilePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Filter %d length is %d", filter.ID, len(filterFileContents))
|
||||
filter.Contents = filterFileContents
|
||||
|
||||
// Now extract the rules count
|
||||
rulesCount, _ := parseFilterContents(filter.Contents)
|
||||
filter.RulesCount = rulesCount
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Path to the filter contents
|
||||
func (filter *filter) Path() string {
|
||||
return filepath.Join(config.ourBinaryDir, dataDir, filterDir, strconv.FormatInt(filter.ID, 10)+".txt")
|
||||
}
|
||||
|
||||
// ------------
|
||||
// safebrowsing
|
||||
// ------------
|
||||
|
||||
func handleSafeBrowsingEnable(w http.ResponseWriter, r *http.Request) {
|
||||
config.CoreDNS.SafeBrowsingEnabled = true
|
||||
config.DNS.SafeBrowsingEnabled = true
|
||||
httpUpdateConfigReloadDNSReturnOK(w, r)
|
||||
}
|
||||
|
||||
func handleSafeBrowsingDisable(w http.ResponseWriter, r *http.Request) {
|
||||
config.CoreDNS.SafeBrowsingEnabled = false
|
||||
config.DNS.SafeBrowsingEnabled = false
|
||||
httpUpdateConfigReloadDNSReturnOK(w, r)
|
||||
}
|
||||
|
||||
func handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Request) {
|
||||
data := map[string]interface{}{
|
||||
"enabled": config.CoreDNS.SafeBrowsingEnabled,
|
||||
"enabled": config.DNS.SafeBrowsingEnabled,
|
||||
}
|
||||
jsonVal, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
@ -786,22 +612,22 @@ func handleParentalEnable(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "Sensitivity must be set to valid value", 400)
|
||||
return
|
||||
}
|
||||
config.CoreDNS.ParentalSensitivity = i
|
||||
config.CoreDNS.ParentalEnabled = true
|
||||
config.DNS.ParentalSensitivity = i
|
||||
config.DNS.ParentalEnabled = true
|
||||
httpUpdateConfigReloadDNSReturnOK(w, r)
|
||||
}
|
||||
|
||||
func handleParentalDisable(w http.ResponseWriter, r *http.Request) {
|
||||
config.CoreDNS.ParentalEnabled = false
|
||||
config.DNS.ParentalEnabled = false
|
||||
httpUpdateConfigReloadDNSReturnOK(w, r)
|
||||
}
|
||||
|
||||
func handleParentalStatus(w http.ResponseWriter, r *http.Request) {
|
||||
data := map[string]interface{}{
|
||||
"enabled": config.CoreDNS.ParentalEnabled,
|
||||
"enabled": config.DNS.ParentalEnabled,
|
||||
}
|
||||
if config.CoreDNS.ParentalEnabled {
|
||||
data["sensitivity"] = config.CoreDNS.ParentalSensitivity
|
||||
if config.DNS.ParentalEnabled {
|
||||
data["sensitivity"] = config.DNS.ParentalSensitivity
|
||||
}
|
||||
jsonVal, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
@ -826,18 +652,18 @@ func handleParentalStatus(w http.ResponseWriter, r *http.Request) {
|
||||
// ------------
|
||||
|
||||
func handleSafeSearchEnable(w http.ResponseWriter, r *http.Request) {
|
||||
config.CoreDNS.SafeSearchEnabled = true
|
||||
config.DNS.SafeSearchEnabled = true
|
||||
httpUpdateConfigReloadDNSReturnOK(w, r)
|
||||
}
|
||||
|
||||
func handleSafeSearchDisable(w http.ResponseWriter, r *http.Request) {
|
||||
config.CoreDNS.SafeSearchEnabled = false
|
||||
config.DNS.SafeSearchEnabled = false
|
||||
httpUpdateConfigReloadDNSReturnOK(w, r)
|
||||
}
|
||||
|
||||
func handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) {
|
||||
data := map[string]interface{}{
|
||||
"enabled": config.CoreDNS.SafeSearchEnabled,
|
||||
"enabled": config.DNS.SafeSearchEnabled,
|
||||
}
|
||||
jsonVal, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
@ -861,17 +687,17 @@ func registerControlHandlers() {
|
||||
http.HandleFunc("/control/status", optionalAuth(ensureGET(handleStatus)))
|
||||
http.HandleFunc("/control/enable_protection", optionalAuth(ensurePOST(handleProtectionEnable)))
|
||||
http.HandleFunc("/control/disable_protection", optionalAuth(ensurePOST(handleProtectionDisable)))
|
||||
http.HandleFunc("/control/querylog", optionalAuth(ensureGET(corednsplugin.HandleQueryLog)))
|
||||
http.HandleFunc("/control/querylog", optionalAuth(ensureGET(dnsforward.HandleQueryLog)))
|
||||
http.HandleFunc("/control/querylog_enable", optionalAuth(ensurePOST(handleQueryLogEnable)))
|
||||
http.HandleFunc("/control/querylog_disable", optionalAuth(ensurePOST(handleQueryLogDisable)))
|
||||
http.HandleFunc("/control/set_upstream_dns", optionalAuth(ensurePOST(handleSetUpstreamDNS)))
|
||||
http.HandleFunc("/control/test_upstream_dns", optionalAuth(ensurePOST(handleTestUpstreamDNS)))
|
||||
http.HandleFunc("/control/i18n/change_language", optionalAuth(ensurePOST(handleI18nChangeLanguage)))
|
||||
http.HandleFunc("/control/i18n/current_language", optionalAuth(ensureGET(handleI18nCurrentLanguage)))
|
||||
http.HandleFunc("/control/stats_top", optionalAuth(ensureGET(corednsplugin.HandleStatsTop)))
|
||||
http.HandleFunc("/control/stats", optionalAuth(ensureGET(corednsplugin.HandleStats)))
|
||||
http.HandleFunc("/control/stats_history", optionalAuth(ensureGET(corednsplugin.HandleStatsHistory)))
|
||||
http.HandleFunc("/control/stats_reset", optionalAuth(ensurePOST(corednsplugin.HandleStatsReset)))
|
||||
http.HandleFunc("/control/stats_top", optionalAuth(ensureGET(dnsforward.HandleStatsTop)))
|
||||
http.HandleFunc("/control/stats", optionalAuth(ensureGET(dnsforward.HandleStats)))
|
||||
http.HandleFunc("/control/stats_history", optionalAuth(ensureGET(dnsforward.HandleStatsHistory)))
|
||||
http.HandleFunc("/control/stats_reset", optionalAuth(ensurePOST(dnsforward.HandleStatsReset)))
|
||||
http.HandleFunc("/control/version.json", optionalAuth(handleGetVersionJSON))
|
||||
http.HandleFunc("/control/filtering/enable", optionalAuth(ensurePOST(handleFilteringEnable)))
|
||||
http.HandleFunc("/control/filtering/disable", optionalAuth(ensurePOST(handleFilteringDisable)))
|
||||
|
132
coredns.go
132
coredns.go
@ -1,132 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync" // Include all plugins.
|
||||
|
||||
_ "github.com/AdguardTeam/AdGuardHome/coredns_plugin"
|
||||
_ "github.com/AdguardTeam/AdGuardHome/coredns_plugin/ratelimit"
|
||||
_ "github.com/AdguardTeam/AdGuardHome/coredns_plugin/refuseany"
|
||||
_ "github.com/AdguardTeam/AdGuardHome/upstream"
|
||||
"github.com/coredns/coredns/core/dnsserver"
|
||||
"github.com/coredns/coredns/coremain"
|
||||
_ "github.com/coredns/coredns/plugin/auto"
|
||||
_ "github.com/coredns/coredns/plugin/autopath"
|
||||
_ "github.com/coredns/coredns/plugin/bind"
|
||||
_ "github.com/coredns/coredns/plugin/cache"
|
||||
_ "github.com/coredns/coredns/plugin/chaos"
|
||||
_ "github.com/coredns/coredns/plugin/debug"
|
||||
_ "github.com/coredns/coredns/plugin/dnssec"
|
||||
_ "github.com/coredns/coredns/plugin/dnstap"
|
||||
_ "github.com/coredns/coredns/plugin/erratic"
|
||||
_ "github.com/coredns/coredns/plugin/errors"
|
||||
_ "github.com/coredns/coredns/plugin/file"
|
||||
_ "github.com/coredns/coredns/plugin/forward"
|
||||
_ "github.com/coredns/coredns/plugin/health"
|
||||
_ "github.com/coredns/coredns/plugin/hosts"
|
||||
_ "github.com/coredns/coredns/plugin/loadbalance"
|
||||
_ "github.com/coredns/coredns/plugin/log"
|
||||
_ "github.com/coredns/coredns/plugin/loop"
|
||||
_ "github.com/coredns/coredns/plugin/metadata"
|
||||
_ "github.com/coredns/coredns/plugin/metrics"
|
||||
_ "github.com/coredns/coredns/plugin/nsid"
|
||||
_ "github.com/coredns/coredns/plugin/pprof"
|
||||
_ "github.com/coredns/coredns/plugin/proxy"
|
||||
_ "github.com/coredns/coredns/plugin/reload"
|
||||
_ "github.com/coredns/coredns/plugin/rewrite"
|
||||
_ "github.com/coredns/coredns/plugin/root"
|
||||
_ "github.com/coredns/coredns/plugin/secondary"
|
||||
_ "github.com/coredns/coredns/plugin/template"
|
||||
_ "github.com/coredns/coredns/plugin/tls"
|
||||
_ "github.com/coredns/coredns/plugin/whoami"
|
||||
_ "github.com/mholt/caddy/onevent"
|
||||
)
|
||||
|
||||
// Directives are registered in the order they should be
|
||||
// executed.
|
||||
//
|
||||
// Ordering is VERY important. Every plugin will
|
||||
// feel the effects of all other plugin below
|
||||
// (after) them during a request, but they must not
|
||||
// care what plugin above them are doing.
|
||||
|
||||
var directives = []string{
|
||||
"metadata",
|
||||
"tls",
|
||||
"reload",
|
||||
"nsid",
|
||||
"root",
|
||||
"bind",
|
||||
"debug",
|
||||
"health",
|
||||
"pprof",
|
||||
"prometheus",
|
||||
"errors",
|
||||
"log",
|
||||
"refuseany",
|
||||
"ratelimit",
|
||||
"dnsfilter",
|
||||
"dnstap",
|
||||
"chaos",
|
||||
"loadbalance",
|
||||
"cache",
|
||||
"rewrite",
|
||||
"dnssec",
|
||||
"autopath",
|
||||
"template",
|
||||
"hosts",
|
||||
"file",
|
||||
"auto",
|
||||
"secondary",
|
||||
"loop",
|
||||
"forward",
|
||||
"proxy",
|
||||
"upstream",
|
||||
"erratic",
|
||||
"whoami",
|
||||
"on",
|
||||
}
|
||||
|
||||
func init() {
|
||||
dnsserver.Directives = directives
|
||||
}
|
||||
|
||||
var (
|
||||
isCoreDNSRunningLock sync.Mutex
|
||||
isCoreDNSRunning = false
|
||||
)
|
||||
|
||||
func isRunning() bool {
|
||||
isCoreDNSRunningLock.Lock()
|
||||
value := isCoreDNSRunning
|
||||
isCoreDNSRunningLock.Unlock()
|
||||
return value
|
||||
}
|
||||
|
||||
func startDNSServer() error {
|
||||
isCoreDNSRunningLock.Lock()
|
||||
if isCoreDNSRunning {
|
||||
isCoreDNSRunningLock.Unlock()
|
||||
return fmt.Errorf("Unable to start coreDNS: Already running")
|
||||
}
|
||||
isCoreDNSRunning = true
|
||||
isCoreDNSRunningLock.Unlock()
|
||||
|
||||
configpath := filepath.Join(config.ourBinaryDir, config.CoreDNS.coreFile)
|
||||
os.Args = os.Args[:1]
|
||||
os.Args = append(os.Args, "-conf")
|
||||
os.Args = append(os.Args, configpath)
|
||||
|
||||
err := writeCoreDNSConfig()
|
||||
if err != nil {
|
||||
errortext := fmt.Errorf("Unable to write coredns config: %s", err)
|
||||
log.Println(errortext)
|
||||
return errortext
|
||||
}
|
||||
|
||||
go coremain.Run()
|
||||
return nil
|
||||
}
|
@ -1,557 +0,0 @@
|
||||
package dnsfilter
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||
"github.com/coredns/coredns/core/dnsserver"
|
||||
"github.com/coredns/coredns/plugin"
|
||||
"github.com/coredns/coredns/plugin/metrics"
|
||||
"github.com/coredns/coredns/plugin/pkg/dnstest"
|
||||
"github.com/coredns/coredns/plugin/pkg/upstream"
|
||||
"github.com/coredns/coredns/request"
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
var defaultSOA = &dns.SOA{
|
||||
// values copied from verisign's nonexistent .com domain
|
||||
// their exact values are not important in our use case because they are used for domain transfers between primary/secondary DNS servers
|
||||
Refresh: 1800,
|
||||
Retry: 900,
|
||||
Expire: 604800,
|
||||
Minttl: 86400,
|
||||
}
|
||||
|
||||
func init() {
|
||||
caddy.RegisterPlugin("dnsfilter", caddy.Plugin{
|
||||
ServerType: "dns",
|
||||
Action: setup,
|
||||
})
|
||||
}
|
||||
|
||||
type plugFilter struct {
|
||||
ID int64
|
||||
Path string
|
||||
}
|
||||
|
||||
type plugSettings struct {
|
||||
SafeBrowsingBlockHost string
|
||||
ParentalBlockHost string
|
||||
QueryLogEnabled bool
|
||||
BlockedTTL uint32 // in seconds, default 3600
|
||||
Filters []plugFilter
|
||||
}
|
||||
|
||||
type plug struct {
|
||||
d *dnsfilter.Dnsfilter
|
||||
Next plugin.Handler
|
||||
upstream upstream.Upstream
|
||||
settings plugSettings
|
||||
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
var defaultPluginSettings = plugSettings{
|
||||
SafeBrowsingBlockHost: "safebrowsing.block.dns.adguard.com",
|
||||
ParentalBlockHost: "family.block.dns.adguard.com",
|
||||
BlockedTTL: 3600, // in seconds
|
||||
Filters: make([]plugFilter, 0),
|
||||
}
|
||||
|
||||
//
|
||||
// coredns handling functions
|
||||
//
|
||||
func setupPlugin(c *caddy.Controller) (*plug, error) {
|
||||
// create new Plugin and copy default values
|
||||
p := &plug{
|
||||
settings: defaultPluginSettings,
|
||||
d: dnsfilter.New(),
|
||||
}
|
||||
|
||||
log.Println("Initializing the CoreDNS plugin")
|
||||
|
||||
for c.Next() {
|
||||
for c.NextBlock() {
|
||||
blockValue := c.Val()
|
||||
switch blockValue {
|
||||
case "safebrowsing":
|
||||
log.Println("Browsing security service is enabled")
|
||||
p.d.EnableSafeBrowsing()
|
||||
if c.NextArg() {
|
||||
if len(c.Val()) == 0 {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
p.d.SetSafeBrowsingServer(c.Val())
|
||||
}
|
||||
case "safesearch":
|
||||
log.Println("Safe search is enabled")
|
||||
p.d.EnableSafeSearch()
|
||||
case "parental":
|
||||
if !c.NextArg() {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
sensitivity, err := strconv.Atoi(c.Val())
|
||||
if err != nil {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
|
||||
log.Println("Parental control is enabled")
|
||||
err = p.d.EnableParental(sensitivity)
|
||||
if err != nil {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
if c.NextArg() {
|
||||
if len(c.Val()) == 0 {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
p.settings.ParentalBlockHost = c.Val()
|
||||
}
|
||||
case "blocked_ttl":
|
||||
if !c.NextArg() {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
blockedTtl, err := strconv.ParseUint(c.Val(), 10, 32)
|
||||
if err != nil {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
log.Printf("Blocked request TTL is %d", blockedTtl)
|
||||
p.settings.BlockedTTL = uint32(blockedTtl)
|
||||
case "querylog":
|
||||
log.Println("Query log is enabled")
|
||||
p.settings.QueryLogEnabled = true
|
||||
case "filter":
|
||||
if !c.NextArg() {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
|
||||
filterId, err := strconv.ParseInt(c.Val(), 10, 64)
|
||||
if err != nil {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
if !c.NextArg() {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
filterPath := c.Val()
|
||||
|
||||
// Initialize filter and add it to the list
|
||||
p.settings.Filters = append(p.settings.Filters, plugFilter{
|
||||
ID: filterId,
|
||||
Path: filterPath,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, filter := range p.settings.Filters {
|
||||
log.Printf("Loading rules from %s", filter.Path)
|
||||
|
||||
file, err := os.Open(filter.Path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
count := 0
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
text := scanner.Text()
|
||||
|
||||
err = p.d.AddRule(text, filter.ID)
|
||||
if err == dnsfilter.ErrAlreadyExists || err == dnsfilter.ErrInvalidSyntax {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("Cannot add rule %s: %s", text, err)
|
||||
// Just ignore invalid rules
|
||||
continue
|
||||
}
|
||||
count++
|
||||
}
|
||||
log.Printf("Added %d rules from filter ID=%d", count, filter.ID)
|
||||
|
||||
if err = scanner.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("Loading stats from querylog")
|
||||
err := fillStatsFromQueryLog()
|
||||
if err != nil {
|
||||
log.Printf("Failed to load stats from querylog: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.settings.QueryLogEnabled {
|
||||
onceQueryLog.Do(func() {
|
||||
go periodicQueryLogRotate()
|
||||
go periodicHourlyTopRotate()
|
||||
go statsRotator()
|
||||
})
|
||||
}
|
||||
|
||||
onceHook.Do(func() {
|
||||
caddy.RegisterEventHook("dnsfilter-reload", hook)
|
||||
})
|
||||
|
||||
p.upstream, err = upstream.New(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func setup(c *caddy.Controller) error {
|
||||
p, err := setupPlugin(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
config := dnsserver.GetConfig(c)
|
||||
config.AddPlugin(func(next plugin.Handler) plugin.Handler {
|
||||
p.Next = next
|
||||
return p
|
||||
})
|
||||
|
||||
c.OnStartup(func() error {
|
||||
m := dnsserver.GetConfig(c).Handler("prometheus")
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
if x, ok := m.(*metrics.Metrics); ok {
|
||||
x.MustRegister(requests)
|
||||
x.MustRegister(filtered)
|
||||
x.MustRegister(filteredLists)
|
||||
x.MustRegister(filteredSafebrowsing)
|
||||
x.MustRegister(filteredParental)
|
||||
x.MustRegister(whitelisted)
|
||||
x.MustRegister(safesearch)
|
||||
x.MustRegister(errorsTotal)
|
||||
x.MustRegister(elapsedTime)
|
||||
x.MustRegister(p)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
c.OnShutdown(p.onShutdown)
|
||||
c.OnFinalShutdown(p.onFinalShutdown)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *plug) onShutdown() error {
|
||||
p.Lock()
|
||||
p.d.Destroy()
|
||||
p.d = nil
|
||||
p.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *plug) onFinalShutdown() error {
|
||||
logBufferLock.Lock()
|
||||
err := flushToFile(logBuffer)
|
||||
if err != nil {
|
||||
log.Printf("failed to flush to file: %s", err)
|
||||
return err
|
||||
}
|
||||
logBufferLock.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
type statsFunc func(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType)
|
||||
|
||||
func doDesc(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType) {
|
||||
realch, ok := ch.(chan<- *prometheus.Desc)
|
||||
if !ok {
|
||||
log.Printf("Couldn't convert ch to chan<- *prometheus.Desc\n")
|
||||
return
|
||||
}
|
||||
realch <- prometheus.NewDesc(name, text, nil, nil)
|
||||
}
|
||||
|
||||
func doMetric(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType) {
|
||||
realch, ok := ch.(chan<- prometheus.Metric)
|
||||
if !ok {
|
||||
log.Printf("Couldn't convert ch to chan<- prometheus.Metric\n")
|
||||
return
|
||||
}
|
||||
desc := prometheus.NewDesc(name, text, nil, nil)
|
||||
realch <- prometheus.MustNewConstMetric(desc, valueType, value)
|
||||
}
|
||||
|
||||
func gen(ch interface{}, doFunc statsFunc, name string, text string, value float64, valueType prometheus.ValueType) {
|
||||
doFunc(ch, name, text, value, valueType)
|
||||
}
|
||||
|
||||
func doStatsLookup(ch interface{}, doFunc statsFunc, name string, lookupstats *dnsfilter.LookupStats) {
|
||||
gen(ch, doFunc, fmt.Sprintf("coredns_dnsfilter_%s_requests", name), fmt.Sprintf("Number of %s HTTP requests that were sent", name), float64(lookupstats.Requests), prometheus.CounterValue)
|
||||
gen(ch, doFunc, fmt.Sprintf("coredns_dnsfilter_%s_cachehits", name), fmt.Sprintf("Number of %s lookups that didn't need HTTP requests", name), float64(lookupstats.CacheHits), prometheus.CounterValue)
|
||||
gen(ch, doFunc, fmt.Sprintf("coredns_dnsfilter_%s_pending", name), fmt.Sprintf("Number of currently pending %s HTTP requests", name), float64(lookupstats.Pending), prometheus.GaugeValue)
|
||||
gen(ch, doFunc, fmt.Sprintf("coredns_dnsfilter_%s_pending_max", name), fmt.Sprintf("Maximum number of pending %s HTTP requests", name), float64(lookupstats.PendingMax), prometheus.GaugeValue)
|
||||
}
|
||||
|
||||
func (p *plug) doStats(ch interface{}, doFunc statsFunc) {
|
||||
p.RLock()
|
||||
stats := p.d.GetStats()
|
||||
doStatsLookup(ch, doFunc, "safebrowsing", &stats.Safebrowsing)
|
||||
doStatsLookup(ch, doFunc, "parental", &stats.Parental)
|
||||
p.RUnlock()
|
||||
}
|
||||
|
||||
// Describe is called by prometheus handler to know stat types
|
||||
func (p *plug) Describe(ch chan<- *prometheus.Desc) {
|
||||
p.doStats(ch, doDesc)
|
||||
}
|
||||
|
||||
// Collect is called by prometheus handler to collect stats
|
||||
func (p *plug) Collect(ch chan<- prometheus.Metric) {
|
||||
p.doStats(ch, doMetric)
|
||||
}
|
||||
|
||||
func (p *plug) replaceHostWithValAndReply(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, host string, val string, question dns.Question) (int, error) {
|
||||
// check if it's a domain name or IP address
|
||||
addr := net.ParseIP(val)
|
||||
var records []dns.RR
|
||||
// log.Println("Will give", val, "instead of", host) // debug logging
|
||||
if addr != nil {
|
||||
// this is an IP address, return it
|
||||
result, err := dns.NewRR(fmt.Sprintf("%s %d A %s", host, p.settings.BlockedTTL, val))
|
||||
if err != nil {
|
||||
log.Printf("Got error %s\n", err)
|
||||
return dns.RcodeServerFailure, fmt.Errorf("plugin/dnsfilter: %s", err)
|
||||
}
|
||||
records = append(records, result)
|
||||
} else {
|
||||
// this is a domain name, need to look it up
|
||||
req := new(dns.Msg)
|
||||
req.SetQuestion(dns.Fqdn(val), question.Qtype)
|
||||
req.RecursionDesired = true
|
||||
reqstate := request.Request{W: w, Req: req, Context: ctx}
|
||||
result, err := p.upstream.Lookup(reqstate, dns.Fqdn(val), reqstate.QType())
|
||||
if err != nil {
|
||||
log.Printf("Got error %s\n", err)
|
||||
return dns.RcodeServerFailure, fmt.Errorf("plugin/dnsfilter: %s", err)
|
||||
}
|
||||
if result != nil {
|
||||
for _, answer := range result.Answer {
|
||||
answer.Header().Name = question.Name
|
||||
}
|
||||
records = result.Answer
|
||||
}
|
||||
}
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Authoritative, m.RecursionAvailable, m.Compress = true, true, true
|
||||
m.Answer = append(m.Answer, records...)
|
||||
state := request.Request{W: w, Req: r, Context: ctx}
|
||||
state.SizeAndDo(m)
|
||||
err := state.W.WriteMsg(m)
|
||||
if err != nil {
|
||||
log.Printf("Got error %s\n", err)
|
||||
return dns.RcodeServerFailure, fmt.Errorf("plugin/dnsfilter: %s", err)
|
||||
}
|
||||
return dns.RcodeSuccess, nil
|
||||
}
|
||||
|
||||
// generate SOA record that makes DNS clients cache NXdomain results
|
||||
// the only value that is important is TTL in header, other values like refresh, retry, expire and minttl are irrelevant
|
||||
func (p *plug) genSOA(r *dns.Msg) []dns.RR {
|
||||
zone := r.Question[0].Name
|
||||
header := dns.RR_Header{Name: zone, Rrtype: dns.TypeSOA, Ttl: p.settings.BlockedTTL, Class: dns.ClassINET}
|
||||
|
||||
Mbox := "hostmaster."
|
||||
if zone[0] != '.' {
|
||||
Mbox += zone
|
||||
}
|
||||
Ns := "fake-for-negative-caching.adguard.com."
|
||||
|
||||
soa := *defaultSOA
|
||||
soa.Hdr = header
|
||||
soa.Mbox = Mbox
|
||||
soa.Ns = Ns
|
||||
soa.Serial = 100500 // faster than uint32(time.Now().Unix())
|
||||
return []dns.RR{&soa}
|
||||
}
|
||||
|
||||
func (p *plug) writeNXdomain(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
state := request.Request{W: w, Req: r, Context: ctx}
|
||||
m := new(dns.Msg)
|
||||
m.SetRcode(state.Req, dns.RcodeNameError)
|
||||
m.Authoritative, m.RecursionAvailable, m.Compress = true, true, true
|
||||
m.Ns = p.genSOA(r)
|
||||
|
||||
state.SizeAndDo(m)
|
||||
err := state.W.WriteMsg(m)
|
||||
if err != nil {
|
||||
log.Printf("Got error %s\n", err)
|
||||
return dns.RcodeServerFailure, err
|
||||
}
|
||||
return dns.RcodeNameError, nil
|
||||
}
|
||||
|
||||
func (p *plug) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, dnsfilter.Result, error) {
|
||||
if len(r.Question) != 1 {
|
||||
// google DNS, bind and others do the same
|
||||
return dns.RcodeFormatError, dnsfilter.Result{}, fmt.Errorf("got a DNS request with more than one Question")
|
||||
}
|
||||
for _, question := range r.Question {
|
||||
host := strings.ToLower(strings.TrimSuffix(question.Name, "."))
|
||||
// is it a safesearch domain?
|
||||
p.RLock()
|
||||
if val, ok := p.d.SafeSearchDomain(host); ok {
|
||||
rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question)
|
||||
if err != nil {
|
||||
p.RUnlock()
|
||||
return rcode, dnsfilter.Result{}, err
|
||||
}
|
||||
p.RUnlock()
|
||||
return rcode, dnsfilter.Result{Reason: dnsfilter.FilteredSafeSearch}, err
|
||||
}
|
||||
p.RUnlock()
|
||||
|
||||
// needs to be filtered instead
|
||||
p.RLock()
|
||||
result, err := p.d.CheckHost(host)
|
||||
if err != nil {
|
||||
log.Printf("plugin/dnsfilter: %s\n", err)
|
||||
p.RUnlock()
|
||||
return dns.RcodeServerFailure, dnsfilter.Result{}, fmt.Errorf("plugin/dnsfilter: %s", err)
|
||||
}
|
||||
p.RUnlock()
|
||||
|
||||
if result.IsFiltered {
|
||||
switch result.Reason {
|
||||
case dnsfilter.FilteredSafeBrowsing:
|
||||
// return cname safebrowsing.block.dns.adguard.com
|
||||
val := p.settings.SafeBrowsingBlockHost
|
||||
rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question)
|
||||
if err != nil {
|
||||
return rcode, dnsfilter.Result{}, err
|
||||
}
|
||||
return rcode, result, err
|
||||
case dnsfilter.FilteredParental:
|
||||
// return cname family.block.dns.adguard.com
|
||||
val := p.settings.ParentalBlockHost
|
||||
rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question)
|
||||
if err != nil {
|
||||
return rcode, dnsfilter.Result{}, err
|
||||
}
|
||||
return rcode, result, err
|
||||
case dnsfilter.FilteredBlackList:
|
||||
|
||||
if result.Ip == nil {
|
||||
// return NXDomain
|
||||
rcode, err := p.writeNXdomain(ctx, w, r)
|
||||
if err != nil {
|
||||
return rcode, dnsfilter.Result{}, err
|
||||
}
|
||||
return rcode, result, err
|
||||
} else {
|
||||
// This is a hosts-syntax rule
|
||||
rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, result.Ip.String(), question)
|
||||
if err != nil {
|
||||
return rcode, dnsfilter.Result{}, err
|
||||
}
|
||||
return rcode, result, err
|
||||
}
|
||||
case dnsfilter.FilteredInvalid:
|
||||
// return NXdomain
|
||||
rcode, err := p.writeNXdomain(ctx, w, r)
|
||||
if err != nil {
|
||||
return rcode, dnsfilter.Result{}, err
|
||||
}
|
||||
return rcode, result, err
|
||||
default:
|
||||
log.Printf("SHOULD NOT HAPPEN -- got unknown reason for filtering host \"%s\": %v, %+v", host, result.Reason, result)
|
||||
}
|
||||
} else {
|
||||
switch result.Reason {
|
||||
case dnsfilter.NotFilteredWhiteList:
|
||||
rcode, err := plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r)
|
||||
return rcode, result, err
|
||||
case dnsfilter.NotFilteredNotFound:
|
||||
// do nothing, pass through to lower code
|
||||
default:
|
||||
log.Printf("SHOULD NOT HAPPEN -- got unknown reason for not filtering host \"%s\": %v, %+v", host, result.Reason, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
rcode, err := plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r)
|
||||
return rcode, dnsfilter.Result{}, err
|
||||
}
|
||||
|
||||
// ServeDNS handles the DNS request and refuses if it's in filterlists
|
||||
func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
start := time.Now()
|
||||
requests.Inc()
|
||||
state := request.Request{W: w, Req: r}
|
||||
ip := state.IP()
|
||||
|
||||
// capture the written answer
|
||||
rrw := dnstest.NewRecorder(w)
|
||||
rcode, result, err := p.serveDNSInternal(ctx, rrw, r)
|
||||
if rcode > 0 {
|
||||
// actually send the answer if we have one
|
||||
answer := new(dns.Msg)
|
||||
answer.SetRcode(r, rcode)
|
||||
state.SizeAndDo(answer)
|
||||
err = w.WriteMsg(answer)
|
||||
if err != nil {
|
||||
return dns.RcodeServerFailure, err
|
||||
}
|
||||
}
|
||||
|
||||
// increment counters
|
||||
switch {
|
||||
case err != nil:
|
||||
errorsTotal.Inc()
|
||||
case result.Reason == dnsfilter.FilteredBlackList:
|
||||
filtered.Inc()
|
||||
filteredLists.Inc()
|
||||
case result.Reason == dnsfilter.FilteredSafeBrowsing:
|
||||
filtered.Inc()
|
||||
filteredSafebrowsing.Inc()
|
||||
case result.Reason == dnsfilter.FilteredParental:
|
||||
filtered.Inc()
|
||||
filteredParental.Inc()
|
||||
case result.Reason == dnsfilter.FilteredInvalid:
|
||||
filtered.Inc()
|
||||
filteredInvalid.Inc()
|
||||
case result.Reason == dnsfilter.FilteredSafeSearch:
|
||||
// the request was passsed through but not filtered, don't increment filtered
|
||||
safesearch.Inc()
|
||||
case result.Reason == dnsfilter.NotFilteredWhiteList:
|
||||
whitelisted.Inc()
|
||||
case result.Reason == dnsfilter.NotFilteredNotFound:
|
||||
// do nothing
|
||||
case result.Reason == dnsfilter.NotFilteredError:
|
||||
text := "SHOULD NOT HAPPEN: got DNSFILTER_NOTFILTERED_ERROR without err != nil!"
|
||||
log.Println(text)
|
||||
err = errors.New(text)
|
||||
rcode = dns.RcodeServerFailure
|
||||
}
|
||||
|
||||
// log
|
||||
elapsed := time.Since(start)
|
||||
elapsedTime.Observe(elapsed.Seconds())
|
||||
if p.settings.QueryLogEnabled {
|
||||
logRequest(r, rrw.Msg, result, time.Since(start), ip)
|
||||
}
|
||||
return rcode, err
|
||||
}
|
||||
|
||||
// Name returns name of the plugin as seen in Corefile and plugin.cfg
|
||||
func (p *plug) Name() string { return "dnsfilter" }
|
||||
|
||||
var onceHook sync.Once
|
||||
var onceQueryLog sync.Once
|
@ -1,131 +0,0 @@
|
||||
package dnsfilter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/coredns/coredns/plugin"
|
||||
"github.com/coredns/coredns/plugin/pkg/dnstest"
|
||||
"github.com/coredns/coredns/plugin/test"
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestSetup(t *testing.T) {
|
||||
for i, testcase := range []struct {
|
||||
config string
|
||||
failing bool
|
||||
}{
|
||||
{`dnsfilter`, false},
|
||||
{`dnsfilter {
|
||||
filter 0 /dev/nonexistent/abcdef
|
||||
}`, true},
|
||||
{`dnsfilter {
|
||||
filter 0 ../tests/dns.txt
|
||||
}`, false},
|
||||
{`dnsfilter {
|
||||
safebrowsing
|
||||
filter 0 ../tests/dns.txt
|
||||
}`, false},
|
||||
{`dnsfilter {
|
||||
parental
|
||||
filter 0 ../tests/dns.txt
|
||||
}`, true},
|
||||
} {
|
||||
c := caddy.NewTestController("dns", testcase.config)
|
||||
err := setup(c)
|
||||
if err != nil {
|
||||
if !testcase.failing {
|
||||
t.Fatalf("Test #%d expected no errors, but got: %v", i, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if testcase.failing {
|
||||
t.Fatalf("Test #%d expected to fail but it didn't", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEtcHostsFilter(t *testing.T) {
|
||||
text := []byte("127.0.0.1 doubleclick.net\n" + "127.0.0.1 example.org example.net www.example.org www.example.net")
|
||||
tmpfile, err := ioutil.TempFile("", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err = tmpfile.Write(text); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err = tmpfile.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer os.Remove(tmpfile.Name())
|
||||
|
||||
configText := fmt.Sprintf("dnsfilter {\nfilter 0 %s\n}", tmpfile.Name())
|
||||
c := caddy.NewTestController("dns", configText)
|
||||
p, err := setupPlugin(c)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
p.Next = zeroTTLBackend()
|
||||
|
||||
ctx := context.TODO()
|
||||
|
||||
for _, testcase := range []struct {
|
||||
host string
|
||||
filtered bool
|
||||
}{
|
||||
{"www.doubleclick.net", false},
|
||||
{"doubleclick.net", true},
|
||||
{"www2.example.org", false},
|
||||
{"www2.example.net", false},
|
||||
{"test.www.example.org", false},
|
||||
{"test.www.example.net", false},
|
||||
{"example.org", true},
|
||||
{"example.net", true},
|
||||
{"www.example.org", true},
|
||||
{"www.example.net", true},
|
||||
} {
|
||||
req := new(dns.Msg)
|
||||
req.SetQuestion(testcase.host+".", dns.TypeA)
|
||||
|
||||
resp := test.ResponseWriter{}
|
||||
rrw := dnstest.NewRecorder(&resp)
|
||||
rcode, err := p.ServeDNS(ctx, rrw, req)
|
||||
if err != nil {
|
||||
t.Fatalf("ServeDNS returned error: %s", err)
|
||||
}
|
||||
if rcode != rrw.Rcode {
|
||||
t.Fatalf("ServeDNS return value for host %s has rcode %d that does not match captured rcode %d", testcase.host, rcode, rrw.Rcode)
|
||||
}
|
||||
A, ok := rrw.Msg.Answer[0].(*dns.A)
|
||||
if !ok {
|
||||
t.Fatalf("Host %s expected to have result A", testcase.host)
|
||||
}
|
||||
ip := net.IPv4(127, 0, 0, 1)
|
||||
filtered := ip.Equal(A.A)
|
||||
if testcase.filtered && testcase.filtered != filtered {
|
||||
t.Fatalf("Host %s expected to be filtered, instead it is not filtered", testcase.host)
|
||||
}
|
||||
if !testcase.filtered && testcase.filtered != filtered {
|
||||
t.Fatalf("Host %s expected to be not filtered, instead it is filtered", testcase.host)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func zeroTTLBackend() plugin.Handler {
|
||||
return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Response, m.RecursionAvailable = true, true
|
||||
|
||||
m.Answer = []dns.RR{test.A("example.org. 0 IN A 127.0.0.53")}
|
||||
w.WriteMsg(m)
|
||||
return dns.RcodeSuccess, nil
|
||||
})
|
||||
}
|
@ -1,182 +0,0 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"sort"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
// ratelimiting and per-ip buckets
|
||||
"github.com/beefsack/go-rate"
|
||||
"github.com/patrickmn/go-cache"
|
||||
|
||||
// coredns plugin
|
||||
"github.com/coredns/coredns/core/dnsserver"
|
||||
"github.com/coredns/coredns/plugin"
|
||||
"github.com/coredns/coredns/plugin/metrics"
|
||||
"github.com/coredns/coredns/plugin/pkg/dnstest"
|
||||
"github.com/coredns/coredns/request"
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
const defaultRatelimit = 30
|
||||
const defaultResponseSize = 1000
|
||||
|
||||
var (
|
||||
tokenBuckets = cache.New(time.Hour, time.Hour)
|
||||
)
|
||||
|
||||
// ServeDNS handles the DNS request and refuses if it's an beyind specified ratelimit
|
||||
func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
state := request.Request{W: w, Req: r}
|
||||
ip := state.IP()
|
||||
allow, err := p.allowRequest(ip)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !allow {
|
||||
ratelimited.Inc()
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Record response to get status code and size of the reply.
|
||||
rw := dnstest.NewRecorder(w)
|
||||
status, err := plugin.NextOrFailure(p.Name(), p.Next, ctx, rw, r)
|
||||
|
||||
size := rw.Len
|
||||
|
||||
if size > defaultResponseSize && state.Proto() == "udp" {
|
||||
// For large UDP responses we call allowRequest more times
|
||||
// The exact number of times depends on the response size
|
||||
for i := 0; i < size/defaultResponseSize; i++ {
|
||||
p.allowRequest(ip)
|
||||
}
|
||||
}
|
||||
|
||||
return status, err
|
||||
}
|
||||
|
||||
func (p *plug) allowRequest(ip string) (bool, error) {
|
||||
if len(p.whitelist) > 0 {
|
||||
i := sort.SearchStrings(p.whitelist, ip)
|
||||
|
||||
if i < len(p.whitelist) && p.whitelist[i] == ip {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
if _, found := tokenBuckets.Get(ip); !found {
|
||||
tokenBuckets.Set(ip, rate.New(p.ratelimit, time.Second), time.Hour)
|
||||
}
|
||||
|
||||
value, found := tokenBuckets.Get(ip)
|
||||
if !found {
|
||||
// should not happen since we've just inserted it
|
||||
text := "SHOULD NOT HAPPEN: just-inserted ratelimiter disappeared"
|
||||
log.Println(text)
|
||||
err := errors.New(text)
|
||||
return true, err
|
||||
}
|
||||
|
||||
rl, ok := value.(*rate.RateLimiter)
|
||||
if !ok {
|
||||
text := "SHOULD NOT HAPPEN: non-bool entry found in safebrowsing lookup cache"
|
||||
log.Println(text)
|
||||
err := errors.New(text)
|
||||
return true, err
|
||||
}
|
||||
|
||||
allow, _ := rl.Try()
|
||||
return allow, nil
|
||||
}
|
||||
|
||||
//
|
||||
// helper functions
|
||||
//
|
||||
func init() {
|
||||
caddy.RegisterPlugin("ratelimit", caddy.Plugin{
|
||||
ServerType: "dns",
|
||||
Action: setup,
|
||||
})
|
||||
}
|
||||
|
||||
type plug struct {
|
||||
Next plugin.Handler
|
||||
|
||||
// configuration for creating above
|
||||
ratelimit int // in requests per second per IP
|
||||
whitelist []string // a list of whitelisted IP addresses
|
||||
}
|
||||
|
||||
func setupPlugin(c *caddy.Controller) (*plug, error) {
|
||||
p := &plug{ratelimit: defaultRatelimit}
|
||||
|
||||
for c.Next() {
|
||||
args := c.RemainingArgs()
|
||||
if len(args) > 0 {
|
||||
ratelimit, err := strconv.Atoi(args[0])
|
||||
if err != nil {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
p.ratelimit = ratelimit
|
||||
}
|
||||
for c.NextBlock() {
|
||||
switch c.Val() {
|
||||
case "whitelist":
|
||||
p.whitelist = c.RemainingArgs()
|
||||
|
||||
if len(p.whitelist) > 0 {
|
||||
sort.Strings(p.whitelist)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func setup(c *caddy.Controller) error {
|
||||
p, err := setupPlugin(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config := dnsserver.GetConfig(c)
|
||||
config.AddPlugin(func(next plugin.Handler) plugin.Handler {
|
||||
p.Next = next
|
||||
return p
|
||||
})
|
||||
|
||||
c.OnStartup(func() error {
|
||||
m := dnsserver.GetConfig(c).Handler("prometheus")
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
if x, ok := m.(*metrics.Metrics); ok {
|
||||
x.MustRegister(ratelimited)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func newDNSCounter(name string, help string) prometheus.Counter {
|
||||
return prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: "ratelimit",
|
||||
Name: name,
|
||||
Help: help,
|
||||
})
|
||||
}
|
||||
|
||||
var (
|
||||
ratelimited = newDNSCounter("dropped_total", "Count of requests that have been dropped because of rate limit")
|
||||
)
|
||||
|
||||
// Name returns name of the plugin as seen in Corefile and plugin.cfg
|
||||
func (p *plug) Name() string { return "ratelimit" }
|
@ -1,80 +0,0 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
)
|
||||
|
||||
func TestSetup(t *testing.T) {
|
||||
for i, testcase := range []struct {
|
||||
config string
|
||||
failing bool
|
||||
}{
|
||||
{`ratelimit`, false},
|
||||
{`ratelimit 100`, false},
|
||||
{`ratelimit {
|
||||
whitelist 127.0.0.1
|
||||
}`, false},
|
||||
{`ratelimit 50 {
|
||||
whitelist 127.0.0.1 176.103.130.130
|
||||
}`, false},
|
||||
{`ratelimit test`, true},
|
||||
} {
|
||||
c := caddy.NewTestController("dns", testcase.config)
|
||||
err := setup(c)
|
||||
if err != nil {
|
||||
if !testcase.failing {
|
||||
t.Fatalf("Test #%d expected no errors, but got: %v", i, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if testcase.failing {
|
||||
t.Fatalf("Test #%d expected to fail but it didn't", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRatelimiting(t *testing.T) {
|
||||
// rate limit is 1 per sec
|
||||
c := caddy.NewTestController("dns", `ratelimit 1`)
|
||||
p, err := setupPlugin(c)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal("Failed to initialize the plugin")
|
||||
}
|
||||
|
||||
allowed, err := p.allowRequest("127.0.0.1")
|
||||
|
||||
if err != nil || !allowed {
|
||||
t.Fatal("First request must have been allowed")
|
||||
}
|
||||
|
||||
allowed, err = p.allowRequest("127.0.0.1")
|
||||
|
||||
if err != nil || allowed {
|
||||
t.Fatal("Second request must have been ratelimited")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhitelist(t *testing.T) {
|
||||
// rate limit is 1 per sec
|
||||
c := caddy.NewTestController("dns", `ratelimit 1 { whitelist 127.0.0.2 127.0.0.1 127.0.0.125 }`)
|
||||
p, err := setupPlugin(c)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal("Failed to initialize the plugin")
|
||||
}
|
||||
|
||||
allowed, err := p.allowRequest("127.0.0.1")
|
||||
|
||||
if err != nil || !allowed {
|
||||
t.Fatal("First request must have been allowed")
|
||||
}
|
||||
|
||||
allowed, err = p.allowRequest("127.0.0.1")
|
||||
|
||||
if err != nil || !allowed {
|
||||
t.Fatal("Second request must have been allowed due to whitelist")
|
||||
}
|
||||
}
|
@ -1,91 +0,0 @@
|
||||
package refuseany
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/coredns/coredns/core/dnsserver"
|
||||
"github.com/coredns/coredns/plugin"
|
||||
"github.com/coredns/coredns/plugin/metrics"
|
||||
"github.com/coredns/coredns/request"
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
type plug struct {
|
||||
Next plugin.Handler
|
||||
}
|
||||
|
||||
// ServeDNS handles the DNS request and refuses if it's an ANY request
|
||||
func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
if len(r.Question) != 1 {
|
||||
// google DNS, bind and others do the same
|
||||
return dns.RcodeFormatError, fmt.Errorf("Got DNS request with != 1 questions")
|
||||
}
|
||||
|
||||
q := r.Question[0]
|
||||
if q.Qtype == dns.TypeANY {
|
||||
state := request.Request{W: w, Req: r, Context: ctx}
|
||||
rcode := dns.RcodeNotImplemented
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetRcode(r, rcode)
|
||||
state.SizeAndDo(m)
|
||||
err := state.W.WriteMsg(m)
|
||||
if err != nil {
|
||||
log.Printf("Got error %s\n", err)
|
||||
return dns.RcodeServerFailure, err
|
||||
}
|
||||
return rcode, nil
|
||||
}
|
||||
|
||||
return plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r)
|
||||
}
|
||||
|
||||
func init() {
|
||||
caddy.RegisterPlugin("refuseany", caddy.Plugin{
|
||||
ServerType: "dns",
|
||||
Action: setup,
|
||||
})
|
||||
}
|
||||
|
||||
func setup(c *caddy.Controller) error {
|
||||
p := &plug{}
|
||||
config := dnsserver.GetConfig(c)
|
||||
|
||||
config.AddPlugin(func(next plugin.Handler) plugin.Handler {
|
||||
p.Next = next
|
||||
return p
|
||||
})
|
||||
|
||||
c.OnStartup(func() error {
|
||||
m := dnsserver.GetConfig(c).Handler("prometheus")
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
if x, ok := m.(*metrics.Metrics); ok {
|
||||
x.MustRegister(ratelimited)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func newDNSCounter(name string, help string) prometheus.Counter {
|
||||
return prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: "refuseany",
|
||||
Name: name,
|
||||
Help: help,
|
||||
})
|
||||
}
|
||||
|
||||
var (
|
||||
ratelimited = newDNSCounter("refusedany_total", "Count of ANY requests that have been dropped")
|
||||
)
|
||||
|
||||
// Name returns name of the plugin as seen in Corefile and plugin.cfg
|
||||
func (p *plug) Name() string { return "refuseany" }
|
@ -1,36 +0,0 @@
|
||||
package dnsfilter
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
)
|
||||
|
||||
var Reload = make(chan bool)
|
||||
|
||||
func hook(event caddy.EventName, info interface{}) error {
|
||||
if event != caddy.InstanceStartupEvent {
|
||||
return nil
|
||||
}
|
||||
|
||||
// this should be an instance. ok to panic if not
|
||||
instance := info.(*caddy.Instance)
|
||||
|
||||
go func() {
|
||||
for range Reload {
|
||||
corefile, err := caddy.LoadCaddyfile(instance.Caddyfile().ServerType())
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
_, err = instance.Restart(corefile)
|
||||
if err != nil {
|
||||
log.Printf("Corefile changed but reload failed: %s", err)
|
||||
continue
|
||||
}
|
||||
// hook will be called again from new instance
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
89
dns.go
Normal file
89
dns.go
Normal file
@ -0,0 +1,89 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||
"github.com/joomcode/errorx"
|
||||
)
|
||||
|
||||
var dnsServer = dnsforward.Server{}
|
||||
|
||||
func isRunning() bool {
|
||||
return dnsServer.IsRunning()
|
||||
}
|
||||
|
||||
func generateServerConfig() dnsforward.ServerConfig {
|
||||
filters := []dnsfilter.Filter{}
|
||||
userFilter := userFilter()
|
||||
filters = append(filters, dnsfilter.Filter{
|
||||
ID: userFilter.ID,
|
||||
Rules: userFilter.Rules,
|
||||
})
|
||||
for _, filter := range config.Filters {
|
||||
filters = append(filters, dnsfilter.Filter{
|
||||
ID: filter.ID,
|
||||
Rules: filter.Rules,
|
||||
})
|
||||
}
|
||||
|
||||
newconfig := dnsforward.ServerConfig{
|
||||
UDPListenAddr: &net.UDPAddr{Port: config.DNS.Port},
|
||||
FilteringConfig: config.DNS.FilteringConfig,
|
||||
Filters: filters,
|
||||
}
|
||||
|
||||
for _, u := range config.DNS.UpstreamDNS {
|
||||
upstream, err := dnsforward.AddressToUpstream(u, config.DNS.BootstrapDNS)
|
||||
if err != nil {
|
||||
log.Printf("Couldn't get upstream: %s", err)
|
||||
// continue, just ignore the upstream
|
||||
continue
|
||||
}
|
||||
newconfig.Upstreams = append(newconfig.Upstreams, upstream)
|
||||
}
|
||||
return newconfig
|
||||
}
|
||||
|
||||
func startDNSServer() error {
|
||||
if isRunning() {
|
||||
return fmt.Errorf("Unable to start forwarding DNS server: Already running")
|
||||
}
|
||||
|
||||
newconfig := generateServerConfig()
|
||||
err := dnsServer.Start(&newconfig)
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func reconfigureDNSServer() error {
|
||||
if !isRunning() {
|
||||
return fmt.Errorf("Refusing to reconfigure forwarding DNS server: not running")
|
||||
}
|
||||
|
||||
err := dnsServer.Reconfigure(generateServerConfig())
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func stopDNSServer() error {
|
||||
if !isRunning() {
|
||||
return fmt.Errorf("Refusing to stop forwarding DNS server: not running")
|
||||
}
|
||||
|
||||
err := dnsServer.Stop()
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't stop forwarding DNS server")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -38,21 +38,22 @@ var ErrInvalidSyntax = errors.New("dnsfilter: invalid rule syntax")
|
||||
// ErrInvalidSyntax is returned by AddRule when the rule was already added to the filter
|
||||
var ErrAlreadyExists = errors.New("dnsfilter: rule was already added")
|
||||
|
||||
// ErrInvalidParental is returned by EnableParental when sensitivity is not a valid value
|
||||
var ErrInvalidParental = errors.New("dnsfilter: invalid parental sensitivity, must be either 3, 10, 13 or 17")
|
||||
|
||||
const shortcutLength = 6 // used for rule search optimization, 6 hits the sweet spot
|
||||
|
||||
const enableFastLookup = true // flag for debugging, must be true in production for faster performance
|
||||
const enableDelayedCompilation = true // flag for debugging, must be true in production for faster performance
|
||||
|
||||
type config struct {
|
||||
parentalServer string
|
||||
parentalSensitivity int // must be either 3, 10, 13 or 17
|
||||
parentalEnabled bool
|
||||
safeSearchEnabled bool
|
||||
safeBrowsingEnabled bool
|
||||
safeBrowsingServer string
|
||||
// Config allows you to configure DNS filtering with New() or just change variables directly.
|
||||
type Config struct {
|
||||
ParentalSensitivity int `yaml:"parental_sensitivity"` // must be either 3, 10, 13 or 17
|
||||
ParentalEnabled bool `yaml:"parental_enabled"`
|
||||
SafeSearchEnabled bool `yaml:"safesearch_enabled"`
|
||||
SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"`
|
||||
}
|
||||
|
||||
type privateConfig struct {
|
||||
parentalServer string // access via methods
|
||||
safeBrowsingServer string // access via methods
|
||||
}
|
||||
|
||||
type rule struct {
|
||||
@ -110,7 +111,13 @@ type Dnsfilter struct {
|
||||
client http.Client // handle for http client -- single instance as recommended by docs
|
||||
transport *http.Transport // handle for http transport used by http client
|
||||
|
||||
config config
|
||||
Config // for direct access by library users, even a = assignment
|
||||
privateConfig
|
||||
}
|
||||
|
||||
type Filter struct {
|
||||
ID int64 `json:"id"` // auto-assigned when filter is added (see nextFilterID), json by default keeps ID uppercase but we need lowercase
|
||||
Rules []string `json:"-" yaml:"-"` // not in yaml or json
|
||||
}
|
||||
|
||||
//go:generate stringer -type=Reason
|
||||
@ -171,7 +178,7 @@ func (d *Dnsfilter) CheckHost(host string) (Result, error) {
|
||||
}
|
||||
|
||||
// check safebrowsing if no match
|
||||
if d.config.safeBrowsingEnabled {
|
||||
if d.SafeBrowsingEnabled {
|
||||
result, err = d.checkSafeBrowsing(host)
|
||||
if err != nil {
|
||||
// failed to do HTTP lookup -- treat it as if we got empty response, but don't save cache
|
||||
@ -184,7 +191,7 @@ func (d *Dnsfilter) CheckHost(host string) (Result, error) {
|
||||
}
|
||||
|
||||
// check parental if no match
|
||||
if d.config.parentalEnabled {
|
||||
if d.ParentalEnabled {
|
||||
result, err = d.checkParental(host)
|
||||
if err != nil {
|
||||
// failed to do HTTP lookup -- treat it as if we got empty response, but don't save cache
|
||||
@ -569,11 +576,11 @@ func hostnameToHashParam(host string, addslash bool) (string, map[string]bool) {
|
||||
|
||||
func (d *Dnsfilter) checkSafeBrowsing(host string) (Result, error) {
|
||||
// prevent recursion -- checking the host of safebrowsing server makes no sense
|
||||
if host == d.config.safeBrowsingServer {
|
||||
if host == d.safeBrowsingServer {
|
||||
return Result{}, nil
|
||||
}
|
||||
format := func(hashparam string) string {
|
||||
url := fmt.Sprintf(defaultSafebrowsingURL, d.config.safeBrowsingServer, hashparam)
|
||||
url := fmt.Sprintf(defaultSafebrowsingURL, d.safeBrowsingServer, hashparam)
|
||||
return url
|
||||
}
|
||||
handleBody := func(body []byte, hashes map[string]bool) (Result, error) {
|
||||
@ -610,11 +617,11 @@ func (d *Dnsfilter) checkSafeBrowsing(host string) (Result, error) {
|
||||
|
||||
func (d *Dnsfilter) checkParental(host string) (Result, error) {
|
||||
// prevent recursion -- checking the host of parental safety server makes no sense
|
||||
if host == d.config.parentalServer {
|
||||
if host == d.parentalServer {
|
||||
return Result{}, nil
|
||||
}
|
||||
format := func(hashparam string) string {
|
||||
url := fmt.Sprintf(defaultParentalURL, d.config.parentalServer, hashparam, d.config.parentalSensitivity)
|
||||
url := fmt.Sprintf(defaultParentalURL, d.parentalServer, hashparam, d.ParentalSensitivity)
|
||||
return url
|
||||
}
|
||||
handleBody := func(body []byte, hashes map[string]bool) (Result, error) {
|
||||
@ -727,6 +734,24 @@ func (d *Dnsfilter) lookupCommon(host string, lookupstats *LookupStats, cache gc
|
||||
// Adding rule and matching against the rules
|
||||
//
|
||||
|
||||
// AddRules is a convinience function to add an array of filters in one call
|
||||
func (d *Dnsfilter) AddRules(filters []Filter) error {
|
||||
for _, f := range filters {
|
||||
for _, rule := range f.Rules {
|
||||
err := d.AddRule(rule, f.ID)
|
||||
if err == ErrAlreadyExists || err == ErrInvalidSyntax {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("Cannot add rule %s: %s", rule, err)
|
||||
// Just ignore invalid rules
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddRule adds a rule, checking if it is a valid rule first and if it wasn't added already
|
||||
func (d *Dnsfilter) AddRule(input string, filterListID int64) error {
|
||||
input = strings.TrimSpace(input)
|
||||
@ -846,7 +871,7 @@ func (d *Dnsfilter) matchHost(host string) (Result, error) {
|
||||
//
|
||||
|
||||
// New creates properly initialized DNS Filter that is ready to be used
|
||||
func New() *Dnsfilter {
|
||||
func New(c *Config) *Dnsfilter {
|
||||
d := new(Dnsfilter)
|
||||
|
||||
d.storage = make(map[string]bool)
|
||||
@ -867,8 +892,11 @@ func New() *Dnsfilter {
|
||||
Transport: d.transport,
|
||||
Timeout: defaultHTTPTimeout,
|
||||
}
|
||||
d.config.safeBrowsingServer = defaultSafebrowsingServer
|
||||
d.config.parentalServer = defaultParentalServer
|
||||
d.safeBrowsingServer = defaultSafebrowsingServer
|
||||
d.parentalServer = defaultParentalServer
|
||||
if c != nil {
|
||||
d.Config = *c
|
||||
}
|
||||
|
||||
return d
|
||||
}
|
||||
@ -885,35 +913,21 @@ func (d *Dnsfilter) Destroy() {
|
||||
// config manipulation helpers
|
||||
//
|
||||
|
||||
// EnableSafeBrowsing turns on checking hostnames in malware/phishing database
|
||||
func (d *Dnsfilter) EnableSafeBrowsing() {
|
||||
d.config.safeBrowsingEnabled = true
|
||||
}
|
||||
|
||||
// EnableParental turns on checking hostnames for containing adult content
|
||||
func (d *Dnsfilter) EnableParental(sensitivity int) error {
|
||||
// IsParentalSensitivityValid checks if sensitivity is valid value
|
||||
func IsParentalSensitivityValid(sensitivity int) bool {
|
||||
switch sensitivity {
|
||||
case 3, 10, 13, 17:
|
||||
d.config.parentalSensitivity = sensitivity
|
||||
d.config.parentalEnabled = true
|
||||
return nil
|
||||
default:
|
||||
return ErrInvalidParental
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// EnableSafeSearch turns on enforcing safesearch in search engines
|
||||
// only used in coredns plugin and requires caller to use SafeSearchDomain()
|
||||
func (d *Dnsfilter) EnableSafeSearch() {
|
||||
d.config.safeSearchEnabled = true
|
||||
return false
|
||||
}
|
||||
|
||||
// SetSafeBrowsingServer lets you optionally change hostname of safesearch lookup
|
||||
func (d *Dnsfilter) SetSafeBrowsingServer(host string) {
|
||||
if len(host) == 0 {
|
||||
d.config.safeBrowsingServer = defaultSafebrowsingServer
|
||||
d.safeBrowsingServer = defaultSafebrowsingServer
|
||||
} else {
|
||||
d.config.safeBrowsingServer = host
|
||||
d.safeBrowsingServer = host
|
||||
}
|
||||
}
|
||||
|
||||
@ -929,7 +943,7 @@ func (d *Dnsfilter) ResetHTTPTimeout() {
|
||||
|
||||
// SafeSearchDomain returns replacement address for search engine
|
||||
func (d *Dnsfilter) SafeSearchDomain(host string) (string, bool) {
|
||||
if d.config.safeSearchEnabled {
|
||||
if d.SafeSearchEnabled {
|
||||
val, ok := safeSearchDomains[host]
|
||||
return val, ok
|
||||
}
|
||||
|
@ -338,7 +338,7 @@ func mustLoadTestRules(d *Dnsfilter) {
|
||||
}
|
||||
|
||||
func NewForTest() *Dnsfilter {
|
||||
d := New()
|
||||
d := New(nil)
|
||||
purgeCaches()
|
||||
return d
|
||||
}
|
||||
@ -542,7 +542,7 @@ func TestSafeBrowsing(t *testing.T) {
|
||||
t.Run(fmt.Sprintf("%s in %s", tc, _Func()), func(t *testing.T) {
|
||||
d := NewForTest()
|
||||
defer d.Destroy()
|
||||
d.EnableSafeBrowsing()
|
||||
d.SafeBrowsingEnabled = true
|
||||
stats.Safebrowsing.Requests = 0
|
||||
d.checkMatch(t, "wmconvirus.narod.ru")
|
||||
d.checkMatch(t, "wmconvirus.narod.ru")
|
||||
@ -570,7 +570,7 @@ func TestSafeBrowsing(t *testing.T) {
|
||||
func TestParallelSB(t *testing.T) {
|
||||
d := NewForTest()
|
||||
defer d.Destroy()
|
||||
d.EnableSafeBrowsing()
|
||||
d.SafeBrowsingEnabled = true
|
||||
t.Run("group", func(t *testing.T) {
|
||||
for i := 0; i < 100; i++ {
|
||||
t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) {
|
||||
@ -597,7 +597,7 @@ func TestSafeBrowsingCustomServerFail(t *testing.T) {
|
||||
defer ts.Close()
|
||||
address := ts.Listener.Addr().String()
|
||||
|
||||
d.EnableSafeBrowsing()
|
||||
d.SafeBrowsingEnabled = true
|
||||
d.SetHTTPTimeout(time.Second * 5)
|
||||
d.SetSafeBrowsingServer(address) // this will ensure that test fails
|
||||
d.checkMatchEmpty(t, "wmconvirus.narod.ru")
|
||||
@ -606,7 +606,8 @@ func TestSafeBrowsingCustomServerFail(t *testing.T) {
|
||||
func TestParentalControl(t *testing.T) {
|
||||
d := NewForTest()
|
||||
defer d.Destroy()
|
||||
d.EnableParental(3)
|
||||
d.ParentalEnabled = true
|
||||
d.ParentalSensitivity = 3
|
||||
d.checkMatch(t, "pornhub.com")
|
||||
d.checkMatch(t, "pornhub.com")
|
||||
if stats.Parental.Requests != 1 {
|
||||
@ -637,7 +638,7 @@ func TestSafeSearch(t *testing.T) {
|
||||
if ok {
|
||||
t.Errorf("Expected safesearch to error when disabled")
|
||||
}
|
||||
d.EnableSafeSearch()
|
||||
d.SafeSearchEnabled = true
|
||||
val, ok := d.SafeSearchDomain("www.google.com")
|
||||
if !ok {
|
||||
t.Errorf("Expected safesearch to find result for www.google.com")
|
||||
@ -924,7 +925,7 @@ func BenchmarkLotsOfRulesLotsOfHostsParallel(b *testing.B) {
|
||||
func BenchmarkSafeBrowsing(b *testing.B) {
|
||||
d := NewForTest()
|
||||
defer d.Destroy()
|
||||
d.EnableSafeBrowsing()
|
||||
d.SafeBrowsingEnabled = true
|
||||
for n := 0; n < b.N; n++ {
|
||||
hostname := "wmconvirus.narod.ru"
|
||||
ret, err := d.CheckHost(hostname)
|
||||
@ -940,7 +941,7 @@ func BenchmarkSafeBrowsing(b *testing.B) {
|
||||
func BenchmarkSafeBrowsingParallel(b *testing.B) {
|
||||
d := NewForTest()
|
||||
defer d.Destroy()
|
||||
d.EnableSafeBrowsing()
|
||||
d.SafeBrowsingEnabled = true
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
hostname := "wmconvirus.narod.ru"
|
||||
@ -958,7 +959,7 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) {
|
||||
func BenchmarkSafeSearch(b *testing.B) {
|
||||
d := NewForTest()
|
||||
defer d.Destroy()
|
||||
d.EnableSafeSearch()
|
||||
d.SafeSearchEnabled = true
|
||||
for n := 0; n < b.N; n++ {
|
||||
val, ok := d.SafeSearchDomain("www.google.com")
|
||||
if !ok {
|
||||
@ -973,7 +974,7 @@ func BenchmarkSafeSearch(b *testing.B) {
|
||||
func BenchmarkSafeSearchParallel(b *testing.B) {
|
||||
d := NewForTest()
|
||||
defer d.Destroy()
|
||||
d.EnableSafeSearch()
|
||||
d.SafeSearchEnabled = true
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
val, ok := d.SafeSearchDomain("www.google.com")
|
||||
@ -1009,17 +1010,3 @@ func _Func() string {
|
||||
f := runtime.FuncForPC(pc[0])
|
||||
return path.Base(f.Name())
|
||||
}
|
||||
|
||||
func trace(format string, args ...interface{}) {
|
||||
pc := make([]uintptr, 10) // at least 1 entry needed
|
||||
runtime.Callers(2, pc)
|
||||
f := runtime.FuncForPC(pc[0])
|
||||
var buf strings.Builder
|
||||
buf.WriteString(fmt.Sprintf("%s(): ", path.Base(f.Name())))
|
||||
text := fmt.Sprintf(format, args...)
|
||||
buf.WriteString(text)
|
||||
if len(text) == 0 || text[len(text)-1] != '\n' {
|
||||
buf.WriteRune('\n')
|
||||
}
|
||||
fmt.Print(buf.String())
|
||||
}
|
||||
|
@ -1,6 +1,10 @@
|
||||
package dnsfilter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
)
|
||||
@ -58,3 +62,17 @@ func updateMax(valuePtr *int64, maxPtr *int64) {
|
||||
// swapping failed because value has changed after reading, try again
|
||||
}
|
||||
}
|
||||
|
||||
func trace(format string, args ...interface{}) {
|
||||
pc := make([]uintptr, 10) // at least 1 entry needed
|
||||
runtime.Callers(2, pc)
|
||||
f := runtime.FuncForPC(pc[0])
|
||||
var buf strings.Builder
|
||||
buf.WriteString(fmt.Sprintf("%s(): ", path.Base(f.Name())))
|
||||
text := fmt.Sprintf(format, args...)
|
||||
buf.WriteString(text)
|
||||
if len(text) == 0 || text[len(text)-1] != '\n' {
|
||||
buf.WriteRune('\n')
|
||||
}
|
||||
fmt.Fprint(os.Stderr, buf.String())
|
||||
}
|
||||
|
107
dnsforward/bootstrap.go
Normal file
107
dnsforward/bootstrap.go
Normal file
@ -0,0 +1,107 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/joomcode/errorx"
|
||||
)
|
||||
|
||||
type bootstrapper struct {
|
||||
address string // in form of "tls://one.one.one.one:853"
|
||||
resolver *net.Resolver // resolver to use to resolve hostname, if neccessary
|
||||
resolved string // in form "IP:port"
|
||||
resolvedConfig *tls.Config
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func toBoot(address, bootstrapAddr string) bootstrapper {
|
||||
var resolver *net.Resolver
|
||||
if bootstrapAddr != "" {
|
||||
resolver = &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
d := net.Dialer{}
|
||||
return d.DialContext(ctx, network, bootstrapAddr)
|
||||
},
|
||||
}
|
||||
}
|
||||
return bootstrapper{
|
||||
address: address,
|
||||
resolver: resolver,
|
||||
}
|
||||
}
|
||||
|
||||
// will get usable IP address from Address field, and caches the result
|
||||
func (n *bootstrapper) get() (string, *tls.Config, error) {
|
||||
// TODO: RLock() here but atomically upgrade to Lock() if fast path doesn't work
|
||||
n.Lock()
|
||||
if n.resolved != "" { // fast path
|
||||
retval, tlsconfig := n.resolved, n.resolvedConfig
|
||||
n.Unlock()
|
||||
return retval, tlsconfig, nil
|
||||
}
|
||||
|
||||
//
|
||||
// slow path
|
||||
//
|
||||
|
||||
defer n.Unlock()
|
||||
|
||||
justHostPort := n.address
|
||||
if strings.Contains(n.address, "://") {
|
||||
url, err := url.Parse(n.address)
|
||||
if err != nil {
|
||||
return "", nil, errorx.Decorate(err, "Failed to parse %s", n.address)
|
||||
}
|
||||
|
||||
justHostPort = url.Host
|
||||
}
|
||||
|
||||
// convert host to IP if neccessary, we know that it's scheme://hostname:port/
|
||||
|
||||
// get a host without port
|
||||
host, port, err := net.SplitHostPort(justHostPort)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("bootstrapper requires port in address %s", n.address)
|
||||
}
|
||||
|
||||
// if it's an IP
|
||||
ip := net.ParseIP(host)
|
||||
if ip != nil {
|
||||
n.resolved = justHostPort
|
||||
return n.resolved, nil, nil
|
||||
}
|
||||
|
||||
//
|
||||
// if it's a hostname
|
||||
//
|
||||
|
||||
resolver := n.resolver // no need to check for nil resolver -- documented that nil is default resolver
|
||||
addrs, err := resolver.LookupIPAddr(context.TODO(), host)
|
||||
if err != nil {
|
||||
return "", nil, errorx.Decorate(err, "Failed to lookup %s", host)
|
||||
}
|
||||
for _, addr := range addrs {
|
||||
// TODO: support ipv6, support multiple ipv4
|
||||
if addr.IP.To4() == nil {
|
||||
continue
|
||||
}
|
||||
ip = addr.IP
|
||||
break
|
||||
}
|
||||
|
||||
if ip == nil {
|
||||
// couldn't find any suitable IP address
|
||||
return "", nil, fmt.Errorf("Couldn't find any suitable IP address for host %s", host)
|
||||
}
|
||||
|
||||
n.resolved = net.JoinHostPort(ip.String(), port)
|
||||
n.resolvedConfig = &tls.Config{ServerName: host}
|
||||
return n.resolved, n.resolvedConfig, nil
|
||||
}
|
225
dnsforward/cache.go
Normal file
225
dnsforward/cache.go
Normal file
@ -0,0 +1,225 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"log"
|
||||
"math"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type item struct {
|
||||
m *dns.Msg
|
||||
when time.Time
|
||||
}
|
||||
|
||||
type cache struct {
|
||||
items map[string]item
|
||||
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
func (c *cache) Get(request *dns.Msg) (*dns.Msg, bool) {
|
||||
if request == nil {
|
||||
return nil, false
|
||||
}
|
||||
ok, key := key(request)
|
||||
if !ok {
|
||||
log.Printf("Get(): key returned !ok")
|
||||
return nil, false
|
||||
}
|
||||
|
||||
c.RLock()
|
||||
item, ok := c.items[key]
|
||||
c.RUnlock()
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
// get item's TTL
|
||||
ttl := findLowestTTL(item.m)
|
||||
// zero TTL? delete and don't serve it
|
||||
if ttl == 0 {
|
||||
c.Lock()
|
||||
delete(c.items, key)
|
||||
c.Unlock()
|
||||
return nil, false
|
||||
}
|
||||
// too much time has passed? delete and don't serve it
|
||||
if time.Since(item.when) >= time.Duration(ttl)*time.Second {
|
||||
c.Lock()
|
||||
delete(c.items, key)
|
||||
c.Unlock()
|
||||
return nil, false
|
||||
}
|
||||
response := item.fromItem(request)
|
||||
return response, true
|
||||
}
|
||||
|
||||
func (c *cache) Set(m *dns.Msg) {
|
||||
if m == nil {
|
||||
return // no-op
|
||||
}
|
||||
if !isRequestCacheable(m) {
|
||||
return
|
||||
}
|
||||
if !isResponseCacheable(m) {
|
||||
return
|
||||
}
|
||||
ok, key := key(m)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
i := toItem(m)
|
||||
|
||||
c.Lock()
|
||||
if c.items == nil {
|
||||
c.items = map[string]item{}
|
||||
}
|
||||
c.items[key] = i
|
||||
c.Unlock()
|
||||
}
|
||||
|
||||
// check only request fields
|
||||
func isRequestCacheable(m *dns.Msg) bool {
|
||||
// truncated messages aren't valid
|
||||
if m.Truncated {
|
||||
log.Printf("Refusing to cache truncated message")
|
||||
return false
|
||||
}
|
||||
|
||||
// if has wrong number of questions, also don't cache
|
||||
if len(m.Question) != 1 {
|
||||
log.Printf("Refusing to cache message with wrong number of questions")
|
||||
return false
|
||||
}
|
||||
|
||||
// only OK or NXdomain replies are cached
|
||||
switch m.Rcode {
|
||||
case dns.RcodeSuccess:
|
||||
case dns.RcodeNameError: // that's an NXDomain
|
||||
case dns.RcodeServerFailure:
|
||||
return false // quietly refuse, don't log
|
||||
default:
|
||||
log.Printf("%s: Refusing to cache message with rcode: %s", m.Question[0].Name, dns.RcodeToString[m.Rcode])
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func isResponseCacheable(m *dns.Msg) bool {
|
||||
ttl := findLowestTTL(m)
|
||||
if ttl == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func findLowestTTL(m *dns.Msg) uint32 {
|
||||
var ttl uint32 = math.MaxUint32
|
||||
found := false
|
||||
|
||||
if m.Answer != nil {
|
||||
for _, r := range m.Answer {
|
||||
if r.Header().Ttl < ttl {
|
||||
ttl = r.Header().Ttl
|
||||
found = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if m.Ns != nil {
|
||||
for _, r := range m.Ns {
|
||||
if r.Header().Ttl < ttl {
|
||||
ttl = r.Header().Ttl
|
||||
found = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if m.Extra != nil {
|
||||
for _, r := range m.Extra {
|
||||
if r.Header().Rrtype == dns.TypeOPT {
|
||||
continue // OPT records use TTL for other purposes
|
||||
}
|
||||
if r.Header().Ttl < ttl {
|
||||
ttl = r.Header().Ttl
|
||||
found = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if found == false {
|
||||
return 0
|
||||
}
|
||||
|
||||
return ttl
|
||||
}
|
||||
|
||||
// key is binary little endian in sequence:
|
||||
// uint16(qtype) then uint16(qclass) then name
|
||||
func key(m *dns.Msg) (bool, string) {
|
||||
if len(m.Question) != 1 {
|
||||
log.Printf("got msg with len(m.Question) != 1: %d", len(m.Question))
|
||||
return false, ""
|
||||
}
|
||||
|
||||
bb := strings.Builder{}
|
||||
b := make([]byte, 2)
|
||||
binary.LittleEndian.PutUint16(b, m.Question[0].Qtype)
|
||||
bb.Write(b)
|
||||
binary.LittleEndian.PutUint16(b, m.Question[0].Qclass)
|
||||
bb.Write(b)
|
||||
name := strings.ToLower(m.Question[0].Name)
|
||||
bb.WriteString(name)
|
||||
return true, bb.String()
|
||||
}
|
||||
|
||||
func toItem(m *dns.Msg) item {
|
||||
return item{
|
||||
m: m,
|
||||
when: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (i *item) fromItem(request *dns.Msg) *dns.Msg {
|
||||
response := &dns.Msg{}
|
||||
response.SetReply(request)
|
||||
|
||||
response.Authoritative = false
|
||||
response.AuthenticatedData = i.m.AuthenticatedData
|
||||
response.RecursionAvailable = i.m.RecursionAvailable
|
||||
response.Rcode = i.m.Rcode
|
||||
|
||||
ttl := findLowestTTL(i.m)
|
||||
timeleft := math.Round(float64(ttl) - time.Since(i.when).Seconds())
|
||||
var newttl uint32
|
||||
if timeleft > 0 {
|
||||
newttl = uint32(timeleft)
|
||||
}
|
||||
for _, r := range i.m.Answer {
|
||||
answer := dns.Copy(r)
|
||||
answer.Header().Ttl = newttl
|
||||
response.Answer = append(response.Answer, answer)
|
||||
}
|
||||
for _, r := range i.m.Ns {
|
||||
ns := dns.Copy(r)
|
||||
ns.Header().Ttl = newttl
|
||||
response.Ns = append(response.Ns, ns)
|
||||
}
|
||||
for _, r := range i.m.Extra {
|
||||
// don't return OPT records as these are hop-by-hop
|
||||
if r.Header().Rrtype == dns.TypeOPT {
|
||||
continue
|
||||
}
|
||||
extra := dns.Copy(r)
|
||||
extra.Header().Ttl = newttl
|
||||
response.Extra = append(response.Extra, extra)
|
||||
}
|
||||
return response
|
||||
}
|
144
dnsforward/cache_test.go
Normal file
144
dnsforward/cache_test.go
Normal file
@ -0,0 +1,144 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/go-test/deep"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func RR(rr string) dns.RR {
|
||||
r, err := dns.NewRR(rr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// deepEqual is same as deep.Equal, except:
|
||||
// * ignores Id when comparing
|
||||
// * question names are not case sensetive
|
||||
func deepEqualMsg(left *dns.Msg, right *dns.Msg) []string {
|
||||
temp := *left
|
||||
temp.Id = right.Id
|
||||
for i := range left.Question {
|
||||
left.Question[i].Name = strings.ToLower(left.Question[i].Name)
|
||||
}
|
||||
for i := range right.Question {
|
||||
right.Question[i].Name = strings.ToLower(right.Question[i].Name)
|
||||
}
|
||||
return deep.Equal(&temp, right)
|
||||
}
|
||||
|
||||
func TestCacheSanity(t *testing.T) {
|
||||
cache := cache{}
|
||||
request := dns.Msg{}
|
||||
request.SetQuestion("google.com.", dns.TypeA)
|
||||
_, ok := cache.Get(&request)
|
||||
if ok {
|
||||
t.Fatal("empty cache replied with positive response")
|
||||
}
|
||||
}
|
||||
|
||||
type tests struct {
|
||||
cache []testEntry
|
||||
cases []testCase
|
||||
}
|
||||
|
||||
type testEntry struct {
|
||||
q string
|
||||
t uint16
|
||||
a []dns.RR
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
q string
|
||||
t uint16
|
||||
a []dns.RR
|
||||
ok bool
|
||||
}
|
||||
|
||||
func TestCache(t *testing.T) {
|
||||
tests := tests{
|
||||
cache: []testEntry{
|
||||
{q: "google.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}},
|
||||
},
|
||||
cases: []testCase{
|
||||
{q: "google.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true},
|
||||
{q: "google.com.", t: dns.TypeMX, ok: false},
|
||||
},
|
||||
}
|
||||
runTests(t, tests)
|
||||
}
|
||||
|
||||
func TestCacheMixedCase(t *testing.T) {
|
||||
tests := tests{
|
||||
cache: []testEntry{
|
||||
{q: "gOOgle.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}},
|
||||
},
|
||||
cases: []testCase{
|
||||
{q: "gOOgle.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true},
|
||||
{q: "google.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true},
|
||||
{q: "GOOGLE.COM.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true},
|
||||
{q: "gOOgle.com.", t: dns.TypeMX, ok: false},
|
||||
{q: "google.com.", t: dns.TypeMX, ok: false},
|
||||
{q: "GOOGLE.COM.", t: dns.TypeMX, ok: false},
|
||||
},
|
||||
}
|
||||
runTests(t, tests)
|
||||
}
|
||||
|
||||
func TestZeroTTL(t *testing.T) {
|
||||
tests := tests{
|
||||
cache: []testEntry{
|
||||
{q: "gOOgle.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 0 IN A 8.8.8.8")}},
|
||||
},
|
||||
cases: []testCase{
|
||||
{q: "google.com.", t: dns.TypeA, ok: false},
|
||||
{q: "google.com.", t: dns.TypeA, ok: false},
|
||||
{q: "google.com.", t: dns.TypeA, ok: false},
|
||||
{q: "google.com.", t: dns.TypeMX, ok: false},
|
||||
{q: "google.com.", t: dns.TypeMX, ok: false},
|
||||
{q: "google.com.", t: dns.TypeMX, ok: false},
|
||||
},
|
||||
}
|
||||
runTests(t, tests)
|
||||
}
|
||||
|
||||
func runTests(t *testing.T, tests tests) {
|
||||
t.Helper()
|
||||
cache := cache{}
|
||||
for _, tc := range tests.cache {
|
||||
reply := dns.Msg{}
|
||||
reply.SetQuestion(tc.q, tc.t)
|
||||
reply.Response = true
|
||||
reply.Answer = tc.a
|
||||
cache.Set(&reply)
|
||||
}
|
||||
for _, tc := range tests.cases {
|
||||
request := dns.Msg{}
|
||||
request.SetQuestion(tc.q, tc.t)
|
||||
val, ok := cache.Get(&request)
|
||||
if diff := deep.Equal(ok, tc.ok); diff != nil {
|
||||
t.Error(diff)
|
||||
}
|
||||
if tc.a != nil {
|
||||
if ok == false {
|
||||
continue
|
||||
}
|
||||
reply := dns.Msg{}
|
||||
reply.SetQuestion(tc.q, tc.t)
|
||||
reply.Response = true
|
||||
reply.Answer = tc.a
|
||||
cache.Set(&reply)
|
||||
if diff := deepEqualMsg(val, &reply); diff != nil {
|
||||
t.Error(diff)
|
||||
} else {
|
||||
if diff := deep.Equal(val, reply); diff == nil {
|
||||
t.Error("different message ID were not caught")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
594
dnsforward/dnsforward.go
Normal file
594
dnsforward/dnsforward.go
Normal file
@ -0,0 +1,594 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||
"github.com/joomcode/errorx"
|
||||
"github.com/miekg/dns"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
)
|
||||
|
||||
// Server is the main way to start a DNS server.
|
||||
//
|
||||
// Example:
|
||||
// s := dnsforward.Server{}
|
||||
// err := s.Start(nil) // will start a DNS server listening on default port 53, in a goroutine
|
||||
// err := s.Reconfigure(ServerConfig{UDPListenAddr: &net.UDPAddr{Port: 53535}}) // will reconfigure running DNS server to listen on UDP port 53535
|
||||
// err := s.Stop() // will stop listening on port 53535 and cancel all goroutines
|
||||
// err := s.Start(nil) // will start listening again, on port 53535, in a goroutine
|
||||
//
|
||||
// The zero Server is empty and ready for use.
|
||||
type Server struct {
|
||||
udpListen *net.UDPConn
|
||||
|
||||
dnsFilter *dnsfilter.Dnsfilter
|
||||
|
||||
cache cache
|
||||
|
||||
ratelimitBuckets *gocache.Cache // where the ratelimiters are stored, per IP
|
||||
|
||||
sync.RWMutex
|
||||
ServerConfig
|
||||
}
|
||||
|
||||
// uncomment this block to have tracing of locks
|
||||
/*
|
||||
func (s *Server) Lock() {
|
||||
pc := make([]uintptr, 10) // at least 1 entry needed
|
||||
runtime.Callers(2, pc)
|
||||
f := runtime.FuncForPC(pc[0])
|
||||
file, line := f.FileLine(pc[0])
|
||||
fmt.Fprintf(os.Stderr, "%s:%d %s() -> Lock() -> in progress\n", path.Base(file), line, path.Base(f.Name()))
|
||||
s.RWMutex.Lock()
|
||||
fmt.Fprintf(os.Stderr, "%s:%d %s() -> Lock() -> done\n", path.Base(file), line, path.Base(f.Name()))
|
||||
}
|
||||
func (s *Server) RLock() {
|
||||
pc := make([]uintptr, 10) // at least 1 entry needed
|
||||
runtime.Callers(2, pc)
|
||||
f := runtime.FuncForPC(pc[0])
|
||||
file, line := f.FileLine(pc[0])
|
||||
fmt.Fprintf(os.Stderr, "%s:%d %s() -> RLock() -> in progress\n", path.Base(file), line, path.Base(f.Name()))
|
||||
s.RWMutex.RLock()
|
||||
fmt.Fprintf(os.Stderr, "%s:%d %s() -> RLock() -> done\n", path.Base(file), line, path.Base(f.Name()))
|
||||
}
|
||||
func (s *Server) Unlock() {
|
||||
pc := make([]uintptr, 10) // at least 1 entry needed
|
||||
runtime.Callers(2, pc)
|
||||
f := runtime.FuncForPC(pc[0])
|
||||
file, line := f.FileLine(pc[0])
|
||||
fmt.Fprintf(os.Stderr, "%s:%d %s() -> Unlock() -> in progress\n", path.Base(file), line, path.Base(f.Name()))
|
||||
s.RWMutex.Unlock()
|
||||
fmt.Fprintf(os.Stderr, "%s:%d %s() -> Unlock() -> done\n", path.Base(file), line, path.Base(f.Name()))
|
||||
}
|
||||
func (s *Server) RUnlock() {
|
||||
pc := make([]uintptr, 10) // at least 1 entry needed
|
||||
runtime.Callers(2, pc)
|
||||
f := runtime.FuncForPC(pc[0])
|
||||
file, line := f.FileLine(pc[0])
|
||||
fmt.Fprintf(os.Stderr, "%s:%d %s() -> RUnlock() -> in progress\n", path.Base(file), line, path.Base(f.Name()))
|
||||
s.RWMutex.RUnlock()
|
||||
fmt.Fprintf(os.Stderr, "%s:%d %s() -> RUnlock() -> done\n", path.Base(file), line, path.Base(f.Name()))
|
||||
}
|
||||
*/
|
||||
|
||||
type FilteringConfig struct {
|
||||
ProtectionEnabled bool `yaml:"protection_enabled"`
|
||||
FilteringEnabled bool `yaml:"filtering_enabled"`
|
||||
BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` // if 0, then default is used (3600)
|
||||
QueryLogEnabled bool `yaml:"querylog_enabled"`
|
||||
Ratelimit int `yaml:"ratelimit"`
|
||||
RatelimitWhitelist []string `yaml:"ratelimit_whitelist"`
|
||||
RefuseAny bool `yaml:"refuse_any"`
|
||||
BootstrapDNS string `yaml:"bootstrap_dns"`
|
||||
|
||||
dnsfilter.Config `yaml:",inline"`
|
||||
}
|
||||
|
||||
// The zero ServerConfig is empty and ready for use.
|
||||
type ServerConfig struct {
|
||||
UDPListenAddr *net.UDPAddr // if nil, then default is is used (port 53 on *)
|
||||
Upstreams []Upstream
|
||||
Filters []dnsfilter.Filter
|
||||
|
||||
FilteringConfig
|
||||
}
|
||||
|
||||
// if any of ServerConfig values are zero, then default values from below are used
|
||||
var defaultValues = ServerConfig{
|
||||
UDPListenAddr: &net.UDPAddr{Port: 53},
|
||||
FilteringConfig: FilteringConfig{BlockedResponseTTL: 3600},
|
||||
Upstreams: []Upstream{
|
||||
//// dns over HTTPS
|
||||
// &dnsOverHTTPS{boot: toBoot("https://1.1.1.1/dns-query", "")},
|
||||
// &dnsOverHTTPS{boot: toBoot("https://dns.google.com/experimental", "")},
|
||||
// &dnsOverHTTPS{boot: toBoot("https://doh.cleanbrowsing.org/doh/security-filter/", "")},
|
||||
// &dnsOverHTTPS{boot: toBoot("https://dns10.quad9.net/dns-query", "")},
|
||||
// &dnsOverHTTPS{boot: toBoot("https://doh.powerdns.org", "")},
|
||||
// &dnsOverHTTPS{boot: toBoot("https://doh.securedns.eu/dns-query", "")},
|
||||
|
||||
//// dns over TLS
|
||||
// &dnsOverTLS{boot: toBoot("tls://8.8.8.8:853", "")},
|
||||
// &dnsOverTLS{boot: toBoot("tls://8.8.4.4:853", "")},
|
||||
// &dnsOverTLS{boot: toBoot("tls://1.1.1.1:853", "")},
|
||||
// &dnsOverTLS{boot: toBoot("tls://1.0.0.1:853", "")},
|
||||
|
||||
//// plainDNS
|
||||
&plainDNS{boot: toBoot("8.8.8.8:53", "")},
|
||||
&plainDNS{boot: toBoot("8.8.4.4:53", "")},
|
||||
&plainDNS{boot: toBoot("1.1.1.1:53", "")},
|
||||
&plainDNS{boot: toBoot("1.0.0.1:53", "")},
|
||||
},
|
||||
}
|
||||
|
||||
//
|
||||
// packet loop
|
||||
//
|
||||
func (s *Server) packetLoop() {
|
||||
log.Printf("Entering packet handle loop")
|
||||
b := make([]byte, dns.MaxMsgSize)
|
||||
for {
|
||||
s.RLock()
|
||||
conn := s.udpListen
|
||||
s.RUnlock()
|
||||
if conn == nil {
|
||||
log.Printf("udp socket has disappeared, exiting loop")
|
||||
break
|
||||
}
|
||||
n, addr, err := conn.ReadFrom(b)
|
||||
// documentation says to handle the packet even if err occurs, so do that first
|
||||
if n > 0 {
|
||||
// make a copy of all bytes because ReadFrom() will overwrite contents of b on next call
|
||||
// we need the contents to survive the call because we're handling them in goroutine
|
||||
p := make([]byte, n)
|
||||
copy(p, b)
|
||||
go s.handlePacket(p, addr, conn) // ignore errors
|
||||
}
|
||||
if err != nil {
|
||||
if isConnClosed(err) {
|
||||
log.Printf("ReadFrom() returned because we're reading from a closed connection, exiting loop")
|
||||
// don't try to nullify s.udpListen here, because s.udpListen could be already re-bound to listen
|
||||
break
|
||||
}
|
||||
log.Printf("Got error when reading from udp listen: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Control functions
|
||||
//
|
||||
|
||||
func (s *Server) Start(config *ServerConfig) error {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
if config != nil {
|
||||
s.ServerConfig = *config
|
||||
}
|
||||
// TODO: handle being called Start() second time after Stop()
|
||||
if s.udpListen == nil {
|
||||
log.Printf("Creating UDP socket")
|
||||
var err error
|
||||
addr := s.UDPListenAddr
|
||||
if addr == nil {
|
||||
addr = defaultValues.UDPListenAddr
|
||||
}
|
||||
s.udpListen, err = net.ListenUDP("udp", addr)
|
||||
if err != nil {
|
||||
s.udpListen = nil
|
||||
return errorx.Decorate(err, "Couldn't listen to UDP socket")
|
||||
}
|
||||
log.Println(s.udpListen.LocalAddr(), s.UDPListenAddr)
|
||||
}
|
||||
|
||||
if s.dnsFilter == nil {
|
||||
log.Printf("Creating dnsfilter")
|
||||
s.dnsFilter = dnsfilter.New(&s.Config)
|
||||
// add rules only if they are enabled
|
||||
if s.FilteringEnabled {
|
||||
s.dnsFilter.AddRules(s.Filters)
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("Loading stats from querylog")
|
||||
err := fillStatsFromQueryLog()
|
||||
if err != nil {
|
||||
log.Printf("Failed to load stats from querylog: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
once.Do(func() {
|
||||
go periodicQueryLogRotate()
|
||||
go periodicHourlyTopRotate()
|
||||
go statsRotator()
|
||||
})
|
||||
|
||||
go s.packetLoop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) Stop() error {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
if s.udpListen != nil {
|
||||
err := s.udpListen.Close()
|
||||
s.udpListen = nil
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't close UDP listening socket")
|
||||
}
|
||||
}
|
||||
|
||||
// flush remainder to file
|
||||
logBufferLock.Lock()
|
||||
flushBuffer := logBuffer
|
||||
logBuffer = nil
|
||||
logBufferLock.Unlock()
|
||||
err := flushToFile(flushBuffer)
|
||||
if err != nil {
|
||||
log.Printf("Saving querylog to file failed: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) IsRunning() bool {
|
||||
s.RLock()
|
||||
isRunning := true
|
||||
if s.udpListen == nil {
|
||||
isRunning = false
|
||||
}
|
||||
s.RUnlock()
|
||||
return isRunning
|
||||
}
|
||||
|
||||
//
|
||||
// Server reconfigure
|
||||
//
|
||||
|
||||
func (s *Server) reconfigureListenAddr(new ServerConfig) error {
|
||||
oldAddr := s.UDPListenAddr
|
||||
if oldAddr == nil {
|
||||
oldAddr = defaultValues.UDPListenAddr
|
||||
}
|
||||
newAddr := new.UDPListenAddr
|
||||
if newAddr == nil {
|
||||
newAddr = defaultValues.UDPListenAddr
|
||||
}
|
||||
if newAddr.Port == 0 {
|
||||
return errorx.IllegalArgument.New("new port cannot be 0")
|
||||
}
|
||||
if reflect.DeepEqual(oldAddr, newAddr) {
|
||||
// do nothing, the addresses are exactly the same
|
||||
log.Printf("Not going to rebind because addresses are same: %v -> %v", oldAddr, newAddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
// rebind, using a strategy:
|
||||
// * if ports are different, bind new first, then close old
|
||||
// * if ports are same, close old first, then bind new
|
||||
var newListen *net.UDPConn
|
||||
var err error
|
||||
if oldAddr.Port != newAddr.Port {
|
||||
log.Printf("Rebinding -- ports are different so bind first then close")
|
||||
newListen, err = net.ListenUDP("udp", newAddr)
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't bind to %v", newAddr)
|
||||
}
|
||||
s.Lock()
|
||||
if s.udpListen != nil {
|
||||
err = s.udpListen.Close()
|
||||
s.udpListen = nil
|
||||
}
|
||||
s.Unlock()
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't close UDP listening socket")
|
||||
}
|
||||
} else {
|
||||
log.Printf("Rebinding -- ports are same so close first then bind")
|
||||
s.Lock()
|
||||
if s.udpListen != nil {
|
||||
err = s.udpListen.Close()
|
||||
s.udpListen = nil
|
||||
}
|
||||
s.Unlock()
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't close UDP listening socket")
|
||||
}
|
||||
newListen, err = net.ListenUDP("udp", newAddr)
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't bind to %v", newAddr)
|
||||
}
|
||||
}
|
||||
s.Lock()
|
||||
s.udpListen = newListen
|
||||
s.UDPListenAddr = new.UDPListenAddr
|
||||
s.Unlock()
|
||||
log.Println(s.udpListen.LocalAddr(), s.UDPListenAddr)
|
||||
|
||||
go s.packetLoop() // the old one has quit, use new one
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) reconfigureBlockedResponseTTL(new ServerConfig) {
|
||||
newVal := new.BlockedResponseTTL
|
||||
if newVal == 0 {
|
||||
newVal = defaultValues.BlockedResponseTTL
|
||||
}
|
||||
oldVal := s.BlockedResponseTTL
|
||||
if oldVal == 0 {
|
||||
oldVal = defaultValues.BlockedResponseTTL
|
||||
}
|
||||
if newVal != oldVal {
|
||||
s.BlockedResponseTTL = new.BlockedResponseTTL
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) reconfigureUpstreams(new ServerConfig) {
|
||||
newVal := new.Upstreams
|
||||
if len(newVal) == 0 {
|
||||
newVal = defaultValues.Upstreams
|
||||
}
|
||||
oldVal := s.Upstreams
|
||||
if len(oldVal) == 0 {
|
||||
oldVal = defaultValues.Upstreams
|
||||
}
|
||||
if reflect.DeepEqual(newVal, oldVal) {
|
||||
// they're exactly the same, do nothing
|
||||
return
|
||||
}
|
||||
s.Upstreams = new.Upstreams
|
||||
}
|
||||
|
||||
func (s *Server) reconfigureFiltering(new ServerConfig) {
|
||||
newFilters := new.Filters
|
||||
if len(newFilters) == 0 {
|
||||
newFilters = defaultValues.Filters
|
||||
}
|
||||
oldFilters := s.Filters
|
||||
if len(oldFilters) == 0 {
|
||||
oldFilters = defaultValues.Filters
|
||||
}
|
||||
|
||||
needUpdate := false
|
||||
if !reflect.DeepEqual(newFilters, oldFilters) {
|
||||
needUpdate = true
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(new.FilteringConfig, s.FilteringConfig) {
|
||||
needUpdate = true
|
||||
}
|
||||
|
||||
if !needUpdate {
|
||||
// nothing to do, everything is same
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: instead of creating new dnsfilter, change existing one's settings and filters
|
||||
dnsFilter := dnsfilter.New(&new.Config) // sets safebrowsing, safesearch and parental
|
||||
|
||||
// add rules only if they are enabled
|
||||
if new.FilteringEnabled {
|
||||
dnsFilter.AddRules(newFilters)
|
||||
}
|
||||
|
||||
s.Lock()
|
||||
oldDNSFilter := s.dnsFilter
|
||||
s.dnsFilter = dnsFilter
|
||||
s.FilteringConfig = new.FilteringConfig
|
||||
s.Unlock()
|
||||
|
||||
oldDNSFilter.Destroy()
|
||||
}
|
||||
|
||||
func (s *Server) Reconfigure(new ServerConfig) error {
|
||||
s.reconfigureBlockedResponseTTL(new)
|
||||
s.reconfigureUpstreams(new)
|
||||
s.reconfigureFiltering(new)
|
||||
|
||||
err := s.reconfigureListenAddr(new)
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't reconfigure to new listening address %+v", new.UDPListenAddr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//
|
||||
// packet handling functions
|
||||
//
|
||||
|
||||
// handlePacketInternal processes the incoming packet bytes and returns with an optional response packet.
|
||||
//
|
||||
// If an empty dns.Msg is returned, do not try to send anything back to client, otherwise send contents of dns.Msg.
|
||||
//
|
||||
// If an error is returned, log it, don't try to generate data based on that error.
|
||||
func (s *Server) handlePacketInternal(msg *dns.Msg, addr net.Addr, conn *net.UDPConn) (*dns.Msg, *dnsfilter.Result, Upstream, error) {
|
||||
// log.Printf("Got packet %d bytes from %s: %v", len(p), addr, p)
|
||||
//
|
||||
// DNS packet byte format is valid
|
||||
//
|
||||
// any errors below here require a response to client
|
||||
// log.Printf("Unpacked: %v", msg.String())
|
||||
if len(msg.Question) != 1 {
|
||||
log.Printf("Got invalid number of questions: %v", len(msg.Question))
|
||||
return s.genServerFailure(msg), nil, nil, nil
|
||||
}
|
||||
|
||||
if msg.Question[0].Qtype == dns.TypeANY && s.RefuseAny {
|
||||
return s.genNotImpl(msg), nil, nil, nil
|
||||
}
|
||||
|
||||
// use dnsfilter before cache -- changed settings or filters would require cache invalidation otherwise
|
||||
host := strings.TrimSuffix(msg.Question[0].Name, ".")
|
||||
res, err := s.dnsFilter.CheckHost(host)
|
||||
if err != nil {
|
||||
log.Printf("dnsfilter failed to check host '%s': %s", host, err)
|
||||
return s.genServerFailure(msg), &res, nil, err
|
||||
} else if res.IsFiltered {
|
||||
log.Printf("Host %s is filtered, reason - '%s', matched rule: '%s'", host, res.Reason, res.Rule)
|
||||
return s.genNXDomain(msg), &res, nil, nil
|
||||
}
|
||||
|
||||
{
|
||||
val, ok := s.cache.Get(msg)
|
||||
if ok && val != nil {
|
||||
return val, &res, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: replace with single-socket implementation
|
||||
upstream := s.chooseUpstream()
|
||||
reply, err := upstream.Exchange(msg)
|
||||
if err != nil {
|
||||
log.Printf("talking to upstream failed for host '%s': %s", host, err)
|
||||
return s.genServerFailure(msg), &res, upstream, err
|
||||
}
|
||||
if reply == nil {
|
||||
log.Printf("SHOULD NOT HAPPEN upstream returned empty message for host '%s'. Request is %v", host, msg.String())
|
||||
return s.genServerFailure(msg), &res, upstream, nil
|
||||
}
|
||||
|
||||
s.cache.Set(reply)
|
||||
|
||||
return reply, &res, upstream, nil
|
||||
}
|
||||
|
||||
func (s *Server) handlePacket(p []byte, addr net.Addr, conn *net.UDPConn) {
|
||||
start := time.Now()
|
||||
ip, _, err := net.SplitHostPort(addr.String())
|
||||
if err != nil {
|
||||
log.Printf("Failed to split %v into host/port: %s", addr, err)
|
||||
// not a fatal error, move on
|
||||
}
|
||||
|
||||
// ratelimit based on IP only, protects CPU cycles and outbound connections
|
||||
if s.isRatelimited(ip) {
|
||||
// log.Printf("Ratelimiting %s based on IP only", ip)
|
||||
return // do nothing, don't reply, we got ratelimited
|
||||
}
|
||||
|
||||
msg := &dns.Msg{}
|
||||
err = msg.Unpack(p)
|
||||
if err != nil {
|
||||
log.Printf("got invalid DNS packet: %s", err)
|
||||
return // do nothing
|
||||
}
|
||||
|
||||
reply, result, upstream, err := s.handlePacketInternal(msg, addr, conn)
|
||||
|
||||
if reply != nil {
|
||||
// ratelimit based on reply size now
|
||||
replysize := reply.Len()
|
||||
if s.isRatelimitedForReply(ip, replysize) {
|
||||
log.Printf("Ratelimiting %s based on IP and size %d", ip, replysize)
|
||||
return // do nothing, don't reply, we got ratelimited
|
||||
}
|
||||
|
||||
// we're good to respond
|
||||
rerr := s.respond(reply, addr, conn)
|
||||
if rerr != nil {
|
||||
log.Printf("Couldn't respond to UDP packet: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// query logging and stats counters
|
||||
if s.QueryLogEnabled {
|
||||
elapsed := time.Since(start)
|
||||
upstreamAddr := ""
|
||||
if upstream != nil {
|
||||
upstreamAddr = upstream.Address()
|
||||
}
|
||||
logRequest(msg, reply, result, elapsed, ip, upstreamAddr)
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// packet sending functions
|
||||
//
|
||||
|
||||
func (s *Server) respond(resp *dns.Msg, addr net.Addr, conn *net.UDPConn) error {
|
||||
// log.Printf("Replying to %s with %s", addr, resp)
|
||||
resp.Compress = true
|
||||
bytes, err := resp.Pack()
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't convert message into wire format")
|
||||
}
|
||||
n, err := conn.WriteTo(bytes, addr)
|
||||
if n == 0 && isConnClosed(err) {
|
||||
return err
|
||||
}
|
||||
if n != len(bytes) {
|
||||
return fmt.Errorf("WriteTo() returned with %d != %d", n, len(bytes))
|
||||
}
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "WriteTo() returned error")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg {
|
||||
resp := dns.Msg{}
|
||||
resp.SetRcode(request, dns.RcodeServerFailure)
|
||||
resp.RecursionAvailable = true
|
||||
return &resp
|
||||
}
|
||||
|
||||
func (s *Server) genNotImpl(request *dns.Msg) *dns.Msg {
|
||||
resp := dns.Msg{}
|
||||
resp.SetRcode(request, dns.RcodeNotImplemented)
|
||||
resp.RecursionAvailable = true
|
||||
resp.SetEdns0(1452, false) // NOTIMPL without EDNS is treated as 'we don't support EDNS', so explicitly set it
|
||||
return &resp
|
||||
}
|
||||
|
||||
func (s *Server) genNXDomain(request *dns.Msg) *dns.Msg {
|
||||
resp := dns.Msg{}
|
||||
resp.SetRcode(request, dns.RcodeNameError)
|
||||
resp.RecursionAvailable = true
|
||||
resp.Ns = s.genSOA(request)
|
||||
return &resp
|
||||
}
|
||||
|
||||
func (s *Server) genSOA(request *dns.Msg) []dns.RR {
|
||||
zone := ""
|
||||
if len(request.Question) > 0 {
|
||||
zone = request.Question[0].Name
|
||||
}
|
||||
|
||||
soa := dns.SOA{
|
||||
// values copied from verisign's nonexistent .com domain
|
||||
// their exact values are not important in our use case because they are used for domain transfers between primary/secondary DNS servers
|
||||
Refresh: 1800,
|
||||
Retry: 900,
|
||||
Expire: 604800,
|
||||
Minttl: 86400,
|
||||
// copied from AdGuard DNS
|
||||
Ns: "fake-for-negative-caching.adguard.com.",
|
||||
Serial: 100500,
|
||||
// rest is request-specific
|
||||
Hdr: dns.RR_Header{
|
||||
Name: zone,
|
||||
Rrtype: dns.TypeSOA,
|
||||
Ttl: s.BlockedResponseTTL,
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
Mbox: "hostmaster.", // zone will be appended later if it's not empty or "."
|
||||
}
|
||||
if soa.Hdr.Ttl == 0 {
|
||||
soa.Hdr.Ttl = defaultValues.BlockedResponseTTL
|
||||
}
|
||||
if len(zone) > 0 && zone[0] != '.' {
|
||||
soa.Mbox += zone
|
||||
}
|
||||
return []dns.RR{&soa}
|
||||
}
|
||||
|
||||
var once sync.Once
|
49
dnsforward/dnsforward_test.go
Normal file
49
dnsforward/dnsforward_test.go
Normal file
@ -0,0 +1,49 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestServer(t *testing.T) {
|
||||
s := Server{}
|
||||
s.UDPListenAddr = &net.UDPAddr{Port: 0}
|
||||
err := s.Start(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start server: %s", err)
|
||||
}
|
||||
if s.udpListen == nil {
|
||||
t.Fatal("Started server has nil udpListen")
|
||||
}
|
||||
|
||||
// server is running, send a message
|
||||
addr := s.udpListen.LocalAddr()
|
||||
req := dns.Msg{}
|
||||
req.Id = dns.Id()
|
||||
req.RecursionDesired = true
|
||||
req.Question = []dns.Question{
|
||||
{Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
||||
}
|
||||
|
||||
reply, err := dns.Exchange(&req, addr.String())
|
||||
if err != nil {
|
||||
t.Fatalf("Couldn't talk to server %s: %s", addr, err)
|
||||
}
|
||||
if len(reply.Answer) != 1 {
|
||||
t.Fatalf("DNS server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer))
|
||||
}
|
||||
if a, ok := reply.Answer[0].(*dns.A); ok {
|
||||
if !net.IPv4(8, 8, 8, 8).Equal(a.A) {
|
||||
t.Fatalf("DNS server %s returned wrong answer instead of 8.8.8.8: %v", addr, a.A)
|
||||
}
|
||||
} else {
|
||||
t.Fatalf("DNS server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0])
|
||||
}
|
||||
|
||||
err = s.Stop()
|
||||
if err != nil {
|
||||
t.Fatalf("DNS server %s failed to stop: %s", addr, err)
|
||||
}
|
||||
}
|
50
dnsforward/helpers.go
Normal file
50
dnsforward/helpers.go
Normal file
@ -0,0 +1,50 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func isConnClosed(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
nerr, ok := err.(*net.OpError)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
if strings.Contains(nerr.Err.Error(), "use of closed network connection") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ---------------------
|
||||
// debug logging helpers
|
||||
// ---------------------
|
||||
func _Func() string {
|
||||
pc := make([]uintptr, 10) // at least 1 entry needed
|
||||
runtime.Callers(2, pc)
|
||||
f := runtime.FuncForPC(pc[0])
|
||||
return path.Base(f.Name())
|
||||
}
|
||||
|
||||
func trace(format string, args ...interface{}) {
|
||||
pc := make([]uintptr, 10) // at least 1 entry needed
|
||||
runtime.Callers(2, pc)
|
||||
f := runtime.FuncForPC(pc[0])
|
||||
var buf strings.Builder
|
||||
buf.WriteString(fmt.Sprintf("%s(): ", path.Base(f.Name())))
|
||||
text := fmt.Sprintf(format, args...)
|
||||
buf.WriteString(text)
|
||||
if len(text) == 0 || text[len(text)-1] != '\n' {
|
||||
buf.WriteRune('\n')
|
||||
}
|
||||
fmt.Fprint(os.Stderr, buf.String())
|
||||
}
|
@ -1,20 +1,16 @@
|
||||
package dnsfilter
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||
"github.com/coredns/coredns/plugin/pkg/response"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
@ -42,9 +38,10 @@ type logEntry struct {
|
||||
Time time.Time
|
||||
Elapsed time.Duration
|
||||
IP string
|
||||
Upstream string `json:",omitempty"` // if empty, means it was cached
|
||||
}
|
||||
|
||||
func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, elapsed time.Duration, ip string) {
|
||||
func logRequest(question *dns.Msg, answer *dns.Msg, result *dnsfilter.Result, elapsed time.Duration, ip string, upstream string) {
|
||||
var q []byte
|
||||
var a []byte
|
||||
var err error
|
||||
@ -64,14 +61,19 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, ela
|
||||
}
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
result = &dnsfilter.Result{}
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
entry := logEntry{
|
||||
Question: q,
|
||||
Answer: a,
|
||||
Result: result,
|
||||
Result: *result,
|
||||
Time: now,
|
||||
Elapsed: elapsed,
|
||||
IP: ip,
|
||||
Upstream: upstream,
|
||||
}
|
||||
var flushBuffer []*logEntry
|
||||
|
||||
@ -97,6 +99,8 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, ela
|
||||
// don't do failure, just log
|
||||
}
|
||||
|
||||
incrementCounters(&entry)
|
||||
|
||||
// if buffer needs to be flushed to disk, do it now
|
||||
if len(flushBuffer) > 0 {
|
||||
// write to file
|
||||
@ -153,8 +157,7 @@ func HandleQueryLog(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if a != nil {
|
||||
status, _ := response.Typify(a, time.Now().UTC())
|
||||
jsonEntry["status"] = status.String()
|
||||
jsonEntry["status"] = dns.RcodeToString[a.Rcode]
|
||||
}
|
||||
if len(entry.Result.Rule) > 0 {
|
||||
jsonEntry["rule"] = entry.Result.Rule
|
||||
@ -223,17 +226,3 @@ func HandleQueryLog(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, errorText, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func trace(format string, args ...interface{}) {
|
||||
pc := make([]uintptr, 10) // at least 1 entry needed
|
||||
runtime.Callers(2, pc)
|
||||
f := runtime.FuncForPC(pc[0])
|
||||
var buf strings.Builder
|
||||
buf.WriteString(fmt.Sprintf("%s(): ", path.Base(f.Name())))
|
||||
text := fmt.Sprintf(format, args...)
|
||||
buf.WriteString(text)
|
||||
if len(text) == 0 || text[len(text)-1] != '\n' {
|
||||
buf.WriteRune('\n')
|
||||
}
|
||||
fmt.Fprint(os.Stderr, buf.String())
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package dnsfilter
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@ -251,41 +251,3 @@ func genericLoader(onEntry func(entry *logEntry) error, needMore func() bool, ti
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func appendFromLogFile(values []*logEntry, maxLen int, timeWindow time.Duration) []*logEntry {
|
||||
a := []*logEntry{}
|
||||
|
||||
onEntry := func(entry *logEntry) error {
|
||||
a = append(a, entry)
|
||||
if len(a) > maxLen {
|
||||
toskip := len(a) - maxLen
|
||||
a = a[toskip:]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
needMore := func() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
err := genericLoader(onEntry, needMore, timeWindow)
|
||||
if err != nil {
|
||||
log.Printf("Failed to load entries from querylog: %s", err)
|
||||
return values
|
||||
}
|
||||
|
||||
// now that we've read all eligible entries, reverse the slice to make it go from newest->oldest
|
||||
for left, right := 0, len(a)-1; left < right; left, right = left+1, right-1 {
|
||||
a[left], a[right] = a[right], a[left]
|
||||
}
|
||||
|
||||
// append it to values
|
||||
values = append(values, a...)
|
||||
|
||||
// then cut off of it is bigger than maxLen
|
||||
if len(values) > maxLen {
|
||||
values = values[:maxLen]
|
||||
}
|
||||
|
||||
return values
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package dnsfilter
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@ -14,7 +14,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||
"github.com/bluele/gcache"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
@ -231,27 +230,7 @@ func fillStatsFromQueryLog() error {
|
||||
}
|
||||
queryLogLock.Unlock()
|
||||
|
||||
requests.IncWithTime(entry.Time)
|
||||
if entry.Result.IsFiltered {
|
||||
filtered.IncWithTime(entry.Time)
|
||||
}
|
||||
switch entry.Result.Reason {
|
||||
case dnsfilter.NotFilteredWhiteList:
|
||||
whitelisted.IncWithTime(entry.Time)
|
||||
case dnsfilter.NotFilteredError:
|
||||
errorsTotal.IncWithTime(entry.Time)
|
||||
case dnsfilter.FilteredBlackList:
|
||||
filteredLists.IncWithTime(entry.Time)
|
||||
case dnsfilter.FilteredSafeBrowsing:
|
||||
filteredSafebrowsing.IncWithTime(entry.Time)
|
||||
case dnsfilter.FilteredParental:
|
||||
filteredParental.IncWithTime(entry.Time)
|
||||
case dnsfilter.FilteredInvalid:
|
||||
// do nothing
|
||||
case dnsfilter.FilteredSafeSearch:
|
||||
safesearch.IncWithTime(entry.Time)
|
||||
}
|
||||
elapsedTime.ObserveWithTime(entry.Elapsed.Seconds(), entry.Time)
|
||||
incrementCounters(entry)
|
||||
|
||||
return nil
|
||||
}
|
80
dnsforward/ratelimit.go
Normal file
80
dnsforward/ratelimit.go
Normal file
@ -0,0 +1,80 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"log"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/beefsack/go-rate"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
)
|
||||
|
||||
func (s *Server) limiterForIP(ip string) interface{} {
|
||||
if s.ratelimitBuckets == nil {
|
||||
s.ratelimitBuckets = gocache.New(time.Hour, time.Hour)
|
||||
}
|
||||
|
||||
// check if ratelimiter for that IP already exists, if not, create
|
||||
value, found := s.ratelimitBuckets.Get(ip)
|
||||
if !found {
|
||||
value = rate.New(s.Ratelimit, time.Second)
|
||||
s.ratelimitBuckets.Set(ip, value, time.Hour)
|
||||
}
|
||||
|
||||
return value
|
||||
}
|
||||
|
||||
func (s *Server) isRatelimited(ip string) bool {
|
||||
if s.Ratelimit == 0 { // 0 -- disabled
|
||||
return false
|
||||
}
|
||||
if len(s.RatelimitWhitelist) > 0 {
|
||||
i := sort.SearchStrings(s.RatelimitWhitelist, ip)
|
||||
|
||||
if i < len(s.RatelimitWhitelist) && s.RatelimitWhitelist[i] == ip {
|
||||
// found, don't ratelimit
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
value := s.limiterForIP(ip)
|
||||
rl, ok := value.(*rate.RateLimiter)
|
||||
if !ok {
|
||||
log.Println("SHOULD NOT HAPPEN: non-bool entry found in safebrowsing lookup cache")
|
||||
return false
|
||||
}
|
||||
|
||||
allow, _ := rl.Try()
|
||||
return !allow
|
||||
}
|
||||
|
||||
func (s *Server) isRatelimitedForReply(ip string, size int) bool {
|
||||
if s.Ratelimit == 0 { // 0 -- disabled
|
||||
return false
|
||||
}
|
||||
if len(s.RatelimitWhitelist) > 0 {
|
||||
i := sort.SearchStrings(s.RatelimitWhitelist, ip)
|
||||
|
||||
if i < len(s.RatelimitWhitelist) && s.RatelimitWhitelist[i] == ip {
|
||||
// found, don't ratelimit
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
value := s.limiterForIP(ip)
|
||||
rl, ok := value.(*rate.RateLimiter)
|
||||
if !ok {
|
||||
log.Println("SHOULD NOT HAPPEN: non-bool entry found in safebrowsing lookup cache")
|
||||
return false
|
||||
}
|
||||
|
||||
// For large UDP responses we try more times, effectively limiting per bandwidth
|
||||
// The exact number of times depends on the response size
|
||||
for i := 0; i < size/1000; i++ {
|
||||
allow, _ := rl.Try()
|
||||
if !allow { // not allowed -> ratelimited
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
42
dnsforward/ratelimit_test.go
Normal file
42
dnsforward/ratelimit_test.go
Normal file
@ -0,0 +1,42 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRatelimiting(t *testing.T) {
|
||||
// rate limit is 1 per sec
|
||||
p := Server{}
|
||||
p.Ratelimit = 1
|
||||
|
||||
limited := p.isRatelimited("127.0.0.1")
|
||||
|
||||
if limited {
|
||||
t.Fatal("First request must have been allowed")
|
||||
}
|
||||
|
||||
limited = p.isRatelimited("127.0.0.1")
|
||||
|
||||
if !limited {
|
||||
t.Fatal("Second request must have been ratelimited")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhitelist(t *testing.T) {
|
||||
// rate limit is 1 per sec with whitelist
|
||||
p := Server{}
|
||||
p.Ratelimit = 1
|
||||
p.RatelimitWhitelist = []string{"127.0.0.1", "127.0.0.2", "127.0.0.125"}
|
||||
|
||||
limited := p.isRatelimited("127.0.0.1")
|
||||
|
||||
if limited {
|
||||
t.Fatal("First request must have been allowed")
|
||||
}
|
||||
|
||||
limited = p.isRatelimited("127.0.0.1")
|
||||
|
||||
if limited {
|
||||
t.Fatal("Second request must have been allowed due to whitelist")
|
||||
}
|
||||
}
|
1
dnsforward/standalone/.gitignore
vendored
Normal file
1
dnsforward/standalone/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
/standalone
|
51
dnsforward/standalone/standalone.go
Normal file
51
dnsforward/standalone/standalone.go
Normal file
@ -0,0 +1,51 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||
)
|
||||
|
||||
//
|
||||
// main function
|
||||
//
|
||||
func main() {
|
||||
go func() {
|
||||
log.Println(http.ListenAndServe("localhost:6060", nil))
|
||||
}()
|
||||
go func() {
|
||||
for range time.Tick(time.Second) {
|
||||
log.Printf("goroutines = %d", runtime.NumGoroutine())
|
||||
}
|
||||
}()
|
||||
s := dnsforward.Server{}
|
||||
err := s.Start(nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
err = s.Stop()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = s.Start(&dnsforward.ServerConfig{UDPListenAddr: &net.UDPAddr{Port: 53535}})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = s.Reconfigure(dnsforward.ServerConfig{UDPListenAddr: &net.UDPAddr{Port: 53, IP: net.ParseIP("0.0.0.0")}})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
log.Printf("Now serving DNS")
|
||||
signal_channel := make(chan os.Signal)
|
||||
signal.Notify(signal_channel, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-signal_channel
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package dnsfilter
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@ -8,21 +8,20 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/plugin"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||
)
|
||||
|
||||
var (
|
||||
requests = newDNSCounter("requests_total", "Count of requests seen by dnsfilter.")
|
||||
filtered = newDNSCounter("filtered_total", "Count of requests filtered by dnsfilter.")
|
||||
filteredLists = newDNSCounter("filtered_lists_total", "Count of requests filtered by dnsfilter using lists.")
|
||||
filteredSafebrowsing = newDNSCounter("filtered_safebrowsing_total", "Count of requests filtered by dnsfilter using safebrowsing.")
|
||||
filteredParental = newDNSCounter("filtered_parental_total", "Count of requests filtered by dnsfilter using parental.")
|
||||
filteredInvalid = newDNSCounter("filtered_invalid_total", "Count of requests filtered by dnsfilter because they were invalid.")
|
||||
whitelisted = newDNSCounter("whitelisted_total", "Count of requests not filtered by dnsfilter because they are whitelisted.")
|
||||
safesearch = newDNSCounter("safesearch_total", "Count of requests replaced by dnsfilter safesearch.")
|
||||
errorsTotal = newDNSCounter("errors_total", "Count of requests that dnsfilter couldn't process because of transitive errors.")
|
||||
elapsedTime = newDNSHistogram("request_duration", "Histogram of the time (in seconds) each request took.")
|
||||
requests = newDNSCounter("requests_total")
|
||||
filtered = newDNSCounter("filtered_total")
|
||||
filteredLists = newDNSCounter("filtered_lists_total")
|
||||
filteredSafebrowsing = newDNSCounter("filtered_safebrowsing_total")
|
||||
filteredParental = newDNSCounter("filtered_parental_total")
|
||||
filteredInvalid = newDNSCounter("filtered_invalid_total")
|
||||
whitelisted = newDNSCounter("whitelisted_total")
|
||||
safesearch = newDNSCounter("safesearch_total")
|
||||
errorsTotal = newDNSCounter("errors_total")
|
||||
elapsedTime = newDNSHistogram("request_duration")
|
||||
)
|
||||
|
||||
// entries for single time period (for example all per-second entries)
|
||||
@ -143,21 +142,13 @@ func statsRotator() {
|
||||
type counter struct {
|
||||
name string // used as key in periodic stats
|
||||
value int64
|
||||
prom prometheus.Counter
|
||||
}
|
||||
|
||||
func newDNSCounter(name string, help string) *counter {
|
||||
func newDNSCounter(name string) *counter {
|
||||
// trace("called")
|
||||
c := &counter{}
|
||||
c.prom = prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: "dnsfilter",
|
||||
Name: name,
|
||||
Help: help,
|
||||
})
|
||||
c.name = name
|
||||
|
||||
return c
|
||||
return &counter{
|
||||
name: name,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *counter) IncWithTime(when time.Time) {
|
||||
@ -166,40 +157,22 @@ func (c *counter) IncWithTime(when time.Time) {
|
||||
statistics.PerHour.Inc(c.name, when)
|
||||
statistics.PerDay.Inc(c.name, when)
|
||||
c.value++
|
||||
c.prom.Inc()
|
||||
}
|
||||
|
||||
func (c *counter) Inc() {
|
||||
c.IncWithTime(time.Now())
|
||||
}
|
||||
|
||||
func (c *counter) Describe(ch chan<- *prometheus.Desc) {
|
||||
c.prom.Describe(ch)
|
||||
}
|
||||
|
||||
func (c *counter) Collect(ch chan<- prometheus.Metric) {
|
||||
c.prom.Collect(ch)
|
||||
}
|
||||
|
||||
type histogram struct {
|
||||
name string // used as key in periodic stats
|
||||
count int64
|
||||
total float64
|
||||
prom prometheus.Histogram
|
||||
}
|
||||
|
||||
func newDNSHistogram(name string, help string) *histogram {
|
||||
// trace("called")
|
||||
h := &histogram{}
|
||||
h.prom = prometheus.NewHistogram(prometheus.HistogramOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: "dnsfilter",
|
||||
Name: name,
|
||||
Help: help,
|
||||
})
|
||||
h.name = name
|
||||
|
||||
return h
|
||||
func newDNSHistogram(name string) *histogram {
|
||||
return &histogram{
|
||||
name: name,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *histogram) ObserveWithTime(value float64, when time.Time) {
|
||||
@ -209,24 +182,40 @@ func (h *histogram) ObserveWithTime(value float64, when time.Time) {
|
||||
statistics.PerDay.Observe(h.name, when, value)
|
||||
h.count++
|
||||
h.total += value
|
||||
h.prom.Observe(value)
|
||||
}
|
||||
|
||||
func (h *histogram) Observe(value float64) {
|
||||
h.ObserveWithTime(value, time.Now())
|
||||
}
|
||||
|
||||
func (h *histogram) Describe(ch chan<- *prometheus.Desc) {
|
||||
h.prom.Describe(ch)
|
||||
}
|
||||
|
||||
func (h *histogram) Collect(ch chan<- prometheus.Metric) {
|
||||
h.prom.Collect(ch)
|
||||
}
|
||||
|
||||
// -----
|
||||
// stats
|
||||
// -----
|
||||
func incrementCounters(entry *logEntry) {
|
||||
requests.IncWithTime(entry.Time)
|
||||
if entry.Result.IsFiltered {
|
||||
filtered.IncWithTime(entry.Time)
|
||||
}
|
||||
|
||||
switch entry.Result.Reason {
|
||||
case dnsfilter.NotFilteredWhiteList:
|
||||
whitelisted.IncWithTime(entry.Time)
|
||||
case dnsfilter.NotFilteredError:
|
||||
errorsTotal.IncWithTime(entry.Time)
|
||||
case dnsfilter.FilteredBlackList:
|
||||
filteredLists.IncWithTime(entry.Time)
|
||||
case dnsfilter.FilteredSafeBrowsing:
|
||||
filteredSafebrowsing.IncWithTime(entry.Time)
|
||||
case dnsfilter.FilteredParental:
|
||||
filteredParental.IncWithTime(entry.Time)
|
||||
case dnsfilter.FilteredInvalid:
|
||||
// do nothing
|
||||
case dnsfilter.FilteredSafeSearch:
|
||||
safesearch.IncWithTime(entry.Time)
|
||||
}
|
||||
elapsedTime.ObserveWithTime(entry.Elapsed.Seconds(), entry.Time)
|
||||
}
|
||||
|
||||
func HandleStats(w http.ResponseWriter, r *http.Request) {
|
||||
const numHours = 24
|
||||
histrical := generateMapFromStats(&statistics.PerHour, 0, numHours)
|
239
dnsforward/upstream.go
Normal file
239
dnsforward/upstream.go
Normal file
@ -0,0 +1,239 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/joomcode/errorx"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
const defaultTimeout = time.Second * 10
|
||||
|
||||
type Upstream interface {
|
||||
Exchange(m *dns.Msg) (*dns.Msg, error)
|
||||
Address() string
|
||||
}
|
||||
|
||||
//
|
||||
// plain DNS
|
||||
//
|
||||
type plainDNS struct {
|
||||
boot bootstrapper
|
||||
preferTCP bool
|
||||
}
|
||||
|
||||
var defaultUDPClient = dns.Client{
|
||||
Timeout: defaultTimeout,
|
||||
UDPSize: dns.MaxMsgSize,
|
||||
}
|
||||
|
||||
var defaultTCPClient = dns.Client{
|
||||
Net: "tcp",
|
||||
UDPSize: dns.MaxMsgSize,
|
||||
Timeout: defaultTimeout,
|
||||
}
|
||||
|
||||
// Address returns the original address that we've put in initially, not resolved one
|
||||
func (p *plainDNS) Address() string { return p.boot.address }
|
||||
|
||||
func (p *plainDNS) Exchange(m *dns.Msg) (*dns.Msg, error) {
|
||||
addr, _, err := p.boot.get()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if p.preferTCP {
|
||||
reply, _, err := defaultTCPClient.Exchange(m, addr)
|
||||
return reply, err
|
||||
}
|
||||
|
||||
reply, _, err := defaultUDPClient.Exchange(m, addr)
|
||||
if err != nil && reply != nil && reply.Truncated {
|
||||
log.Printf("Truncated message was received, retrying over TCP, question: %s", m.Question[0].String())
|
||||
reply, _, err = defaultTCPClient.Exchange(m, addr)
|
||||
}
|
||||
|
||||
return reply, err
|
||||
}
|
||||
|
||||
//
|
||||
// DNS-over-TLS
|
||||
//
|
||||
type dnsOverTLS struct {
|
||||
boot bootstrapper
|
||||
pool *TLSPool
|
||||
|
||||
sync.RWMutex // protects pool
|
||||
}
|
||||
|
||||
func (p *dnsOverTLS) Address() string { return p.boot.address }
|
||||
|
||||
func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) {
|
||||
var pool *TLSPool
|
||||
p.RLock()
|
||||
pool = p.pool
|
||||
p.RUnlock()
|
||||
if pool == nil {
|
||||
p.Lock()
|
||||
// lazy initialize it
|
||||
p.pool = &TLSPool{boot: &p.boot}
|
||||
p.Unlock()
|
||||
}
|
||||
|
||||
p.RLock()
|
||||
poolConn, err := p.pool.Get()
|
||||
p.RUnlock()
|
||||
if err != nil {
|
||||
return nil, errorx.Decorate(err, "Failed to get a connection from TLSPool to %s", p.Address())
|
||||
}
|
||||
c := dns.Conn{Conn: poolConn}
|
||||
err = c.WriteMsg(m)
|
||||
if err != nil {
|
||||
poolConn.Close()
|
||||
return nil, errorx.Decorate(err, "Failed to send a request to %s", p.Address())
|
||||
}
|
||||
|
||||
reply, err := c.ReadMsg()
|
||||
if err != nil {
|
||||
poolConn.Close()
|
||||
return nil, errorx.Decorate(err, "Failed to read a request from %s", p.Address())
|
||||
}
|
||||
p.RLock()
|
||||
p.pool.Put(poolConn)
|
||||
p.RUnlock()
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
//
|
||||
// DNS-over-https
|
||||
//
|
||||
type dnsOverHTTPS struct {
|
||||
boot bootstrapper
|
||||
}
|
||||
|
||||
func (p *dnsOverHTTPS) Address() string { return p.boot.address }
|
||||
|
||||
func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (*dns.Msg, error) {
|
||||
addr, tlsConfig, err := p.boot.get()
|
||||
if err != nil {
|
||||
return nil, errorx.Decorate(err, "Couldn't bootstrap %s", p.boot.address)
|
||||
}
|
||||
|
||||
buf, err := m.Pack()
|
||||
if err != nil {
|
||||
return nil, errorx.Decorate(err, "Couldn't pack request msg")
|
||||
}
|
||||
bb := bytes.NewBuffer(buf)
|
||||
|
||||
// set up a custom request with custom URL
|
||||
url, err := url.Parse(p.boot.address)
|
||||
if err != nil {
|
||||
return nil, errorx.Decorate(err, "Couldn't parse URL %s", p.boot.address)
|
||||
}
|
||||
req := http.Request{
|
||||
Method: "POST",
|
||||
URL: url,
|
||||
Body: ioutil.NopCloser(bb),
|
||||
Header: make(http.Header),
|
||||
Host: url.Host,
|
||||
}
|
||||
url.Host = addr
|
||||
req.Header.Set("Content-Type", "application/dns-message")
|
||||
client := http.Client{
|
||||
Transport: &http.Transport{TLSClientConfig: tlsConfig},
|
||||
}
|
||||
resp, err := client.Do(&req)
|
||||
if resp != nil && resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, errorx.Decorate(err, "Couldn't do a POST request to '%s'", addr)
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, errorx.Decorate(err, "Couldn't read body contents for '%s'", addr)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("Got an unexpected HTTP status code %d from '%s'", resp.StatusCode, addr)
|
||||
}
|
||||
if len(body) == 0 {
|
||||
return nil, fmt.Errorf("Got an unexpected empty body from '%s'", addr)
|
||||
}
|
||||
response := dns.Msg{}
|
||||
err = response.Unpack(body)
|
||||
if err != nil {
|
||||
return nil, errorx.Decorate(err, "Couldn't unpack DNS response from '%s': body is %s", addr, string(body))
|
||||
}
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
func (s *Server) chooseUpstream() Upstream {
|
||||
upstreams := s.Upstreams
|
||||
if upstreams == nil {
|
||||
upstreams = defaultValues.Upstreams
|
||||
}
|
||||
if len(upstreams) == 0 {
|
||||
panic("SHOULD NOT HAPPEN: no default upstreams specified")
|
||||
}
|
||||
if len(upstreams) == 1 {
|
||||
return upstreams[0]
|
||||
}
|
||||
n := rand.Intn(len(upstreams))
|
||||
upstream := upstreams[n]
|
||||
return upstream
|
||||
}
|
||||
|
||||
func AddressToUpstream(address string, bootstrap string) (Upstream, error) {
|
||||
if strings.Contains(address, "://") {
|
||||
url, err := url.Parse(address)
|
||||
if err != nil {
|
||||
return nil, errorx.Decorate(err, "Failed to parse %s", address)
|
||||
}
|
||||
switch url.Scheme {
|
||||
case "dns":
|
||||
if url.Port() == "" {
|
||||
url.Host += ":53"
|
||||
}
|
||||
return &plainDNS{boot: toBoot(url.Host, bootstrap)}, nil
|
||||
case "tcp":
|
||||
if url.Port() == "" {
|
||||
url.Host += ":53"
|
||||
}
|
||||
return &plainDNS{boot: toBoot(url.Host, bootstrap), preferTCP: true}, nil
|
||||
case "tls":
|
||||
if url.Port() == "" {
|
||||
url.Host += ":853"
|
||||
}
|
||||
return &dnsOverTLS{boot: toBoot(url.String(), bootstrap)}, nil
|
||||
case "https":
|
||||
if url.Port() == "" {
|
||||
url.Host += ":443"
|
||||
}
|
||||
return &dnsOverHTTPS{boot: toBoot(url.String(), bootstrap)}, nil
|
||||
default:
|
||||
// assume it's plain DNS
|
||||
if url.Port() == "" {
|
||||
url.Host += ":53"
|
||||
}
|
||||
return &plainDNS{boot: toBoot(url.String(), bootstrap)}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// we don't have scheme in the url, so it's just a plain DNS host:port
|
||||
_, _, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
// doesn't have port, default to 53
|
||||
address = net.JoinHostPort(address, "53")
|
||||
}
|
||||
return &plainDNS{boot: toBoot(address, bootstrap)}, nil
|
||||
}
|
74
dnsforward/upstream_pool.go
Normal file
74
dnsforward/upstream_pool.go
Normal file
@ -0,0 +1,74 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/joomcode/errorx"
|
||||
)
|
||||
|
||||
// Upstream TLS pool.
|
||||
//
|
||||
// Example:
|
||||
// pool := TLSPool{Address: "tls://1.1.1.1:853"}
|
||||
// netConn, err := pool.Get()
|
||||
// if err != nil {panic(err)}
|
||||
// c := dns.Conn{Conn: netConn}
|
||||
// q := dns.Msg{}
|
||||
// q.SetQuestion("google.com.", dns.TypeA)
|
||||
// log.Println(q)
|
||||
// err = c.WriteMsg(&q)
|
||||
// if err != nil {panic(err)}
|
||||
// r, err := c.ReadMsg()
|
||||
// if err != nil {panic(err)}
|
||||
// log.Println(r)
|
||||
// pool.Put(c.Conn)
|
||||
type TLSPool struct {
|
||||
boot *bootstrapper
|
||||
|
||||
// connections
|
||||
conns []net.Conn
|
||||
connsMutex sync.Mutex // protects conns
|
||||
}
|
||||
|
||||
func (n *TLSPool) Get() (net.Conn, error) {
|
||||
address, tlsConfig, err := n.boot.get()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// get the connection from the slice inside the lock
|
||||
var c net.Conn
|
||||
n.connsMutex.Lock()
|
||||
num := len(n.conns)
|
||||
if num > 0 {
|
||||
last := num - 1
|
||||
c = n.conns[last]
|
||||
n.conns = n.conns[:last]
|
||||
}
|
||||
n.connsMutex.Unlock()
|
||||
|
||||
// if we got connection from the slice, return it
|
||||
if c != nil {
|
||||
// log.Printf("Returning existing connection to %s", host)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// we'll need a new connection, dial now
|
||||
// log.Printf("Dialing to %s", address)
|
||||
conn, err := tls.Dial("tcp", address, tlsConfig)
|
||||
if err != nil {
|
||||
return nil, errorx.Decorate(err, "Failed to connect to %s", address)
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (n *TLSPool) Put(c net.Conn) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
n.connsMutex.Lock()
|
||||
n.conns = append(n.conns, c)
|
||||
n.connsMutex.Unlock()
|
||||
}
|
96
dnsforward/upstream_test.go
Normal file
96
dnsforward/upstream_test.go
Normal file
@ -0,0 +1,96 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestUpstreams(t *testing.T) {
|
||||
upstreams := []struct {
|
||||
address string
|
||||
bootstrap string
|
||||
}{
|
||||
{
|
||||
address: "8.8.8.8:53",
|
||||
bootstrap: "8.8.8.8:53",
|
||||
},
|
||||
{
|
||||
address: "1.1.1.1",
|
||||
bootstrap: "",
|
||||
},
|
||||
{
|
||||
address: "tcp://1.1.1.1:53",
|
||||
bootstrap: "",
|
||||
},
|
||||
{
|
||||
address: "176.103.130.130:5353",
|
||||
bootstrap: "",
|
||||
},
|
||||
{
|
||||
address: "tls://1.1.1.1",
|
||||
bootstrap: "",
|
||||
},
|
||||
{
|
||||
address: "tls://9.9.9.9:853",
|
||||
bootstrap: "",
|
||||
},
|
||||
{
|
||||
address: "tls://security-filter-dns.cleanbrowsing.org",
|
||||
bootstrap: "8.8.8.8:53",
|
||||
},
|
||||
{
|
||||
address: "tls://adult-filter-dns.cleanbrowsing.org:853",
|
||||
bootstrap: "8.8.8.8:53",
|
||||
},
|
||||
{
|
||||
address: "https://cloudflare-dns.com/dns-query",
|
||||
bootstrap: "8.8.8.8:53",
|
||||
},
|
||||
{
|
||||
address: "https://dns.google.com/experimental",
|
||||
bootstrap: "8.8.8.8:53",
|
||||
},
|
||||
{
|
||||
address: "https://doh.cleanbrowsing.org/doh/security-filter/",
|
||||
bootstrap: "",
|
||||
},
|
||||
}
|
||||
for _, test := range upstreams {
|
||||
t.Run(test.address, func(t *testing.T) {
|
||||
u, err := AddressToUpstream(test.address, test.bootstrap)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate upstream from address %s: %s", test.address, err)
|
||||
}
|
||||
|
||||
checkUpstream(t, u, test.address)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func checkUpstream(t *testing.T, u Upstream, addr string) {
|
||||
t.Helper()
|
||||
|
||||
req := dns.Msg{}
|
||||
req.Id = dns.Id()
|
||||
req.RecursionDesired = true
|
||||
req.Question = []dns.Question{
|
||||
{Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
||||
}
|
||||
|
||||
reply, err := u.Exchange(&req)
|
||||
if err != nil {
|
||||
t.Fatalf("Couldn't talk to upstream %s: %s", addr, err)
|
||||
}
|
||||
if len(reply.Answer) != 1 {
|
||||
t.Fatalf("DNS upstream %s returned reply with wrong number of answers - %d", addr, len(reply.Answer))
|
||||
}
|
||||
if a, ok := reply.Answer[0].(*dns.A); ok {
|
||||
if !net.IPv4(8, 8, 8, 8).Equal(a.A) {
|
||||
t.Fatalf("DNS upstream %s returned wrong answer instead of 8.8.8.8: %v", addr, a.A)
|
||||
}
|
||||
} else {
|
||||
t.Fatalf("DNS upstream %s returned wrong answer type instead of A: %v", addr, reply.Answer[0])
|
||||
}
|
||||
}
|
251
filter.go
Normal file
251
filter.go
Normal file
@ -0,0 +1,251 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||
)
|
||||
|
||||
var (
|
||||
nextFilterID = time.Now().Unix() // semi-stable way to generate an unique ID
|
||||
filterTitleRegexp = regexp.MustCompile(`^! Title: +(.*)$`)
|
||||
)
|
||||
|
||||
// field ordering is important -- yaml fields will mirror ordering from here
|
||||
type filter struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
URL string `json:"url"`
|
||||
Name string `json:"name" yaml:"name"`
|
||||
RulesCount int `json:"rulesCount" yaml:"-"`
|
||||
LastUpdated time.Time `json:"lastUpdated,omitempty" yaml:"last_updated,omitempty"`
|
||||
|
||||
dnsfilter.Filter `yaml:",inline"`
|
||||
}
|
||||
|
||||
// Creates a helper object for working with the user rules
|
||||
func userFilter() filter {
|
||||
return filter{
|
||||
// User filter always has constant ID=0
|
||||
Enabled: true,
|
||||
Filter: dnsfilter.Filter{
|
||||
Rules: config.UserRules,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func deduplicateFilters() {
|
||||
// Deduplicate filters
|
||||
i := 0 // output index, used for deletion later
|
||||
urls := map[string]bool{}
|
||||
for _, filter := range config.Filters {
|
||||
if _, ok := urls[filter.URL]; !ok {
|
||||
// we didn't see it before, keep it
|
||||
urls[filter.URL] = true // remember the URL
|
||||
config.Filters[i] = filter
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
// all entries we want to keep are at front, delete the rest
|
||||
config.Filters = config.Filters[:i]
|
||||
}
|
||||
|
||||
// Set the next filter ID to max(filter.ID) + 1
|
||||
func updateUniqueFilterID(filters []filter) {
|
||||
for _, filter := range filters {
|
||||
if nextFilterID < filter.ID {
|
||||
nextFilterID = filter.ID + 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func assignUniqueFilterID() int64 {
|
||||
value := nextFilterID
|
||||
nextFilterID += 1
|
||||
return value
|
||||
}
|
||||
|
||||
// Sets up a timer that will be checking for filters updates periodically
|
||||
func periodicallyRefreshFilters() {
|
||||
for range time.Tick(time.Minute) {
|
||||
refreshFiltersIfNeccessary(false)
|
||||
}
|
||||
}
|
||||
|
||||
// Checks filters updates if necessary
|
||||
// If force is true, it ignores the filter.LastUpdated field value
|
||||
func refreshFiltersIfNeccessary(force bool) int {
|
||||
config.Lock()
|
||||
|
||||
// fetch URLs
|
||||
updateCount := 0
|
||||
for i := range config.Filters {
|
||||
filter := &config.Filters[i] // otherwise we will be operating on a copy
|
||||
|
||||
if filter.ID == 0 { // protect against users modifying the yaml and removing the ID
|
||||
filter.ID = assignUniqueFilterID()
|
||||
}
|
||||
|
||||
updated, err := filter.update(force)
|
||||
if err != nil {
|
||||
log.Printf("Failed to update filter %s: %s\n", filter.URL, err)
|
||||
continue
|
||||
}
|
||||
if updated {
|
||||
// Saving it to the filters dir now
|
||||
err = filter.save()
|
||||
if err != nil {
|
||||
log.Printf("Failed to save the updated filter %d: %s", filter.ID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
updateCount++
|
||||
}
|
||||
}
|
||||
config.Unlock()
|
||||
|
||||
if updateCount > 0 {
|
||||
reconfigureDNSServer()
|
||||
}
|
||||
return updateCount
|
||||
}
|
||||
|
||||
// A helper function that parses filter contents and returns a number of rules and a filter name (if there's any)
|
||||
func parseFilterContents(contents []byte) (int, string, []string) {
|
||||
lines := strings.Split(string(contents), "\n")
|
||||
rulesCount := 0
|
||||
name := ""
|
||||
seenTitle := false
|
||||
|
||||
// Count lines in the filter
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if len(line) > 0 && line[0] == '!' {
|
||||
if m := filterTitleRegexp.FindAllStringSubmatch(line, -1); len(m) > 0 && len(m[0]) >= 2 && !seenTitle {
|
||||
name = m[0][1]
|
||||
seenTitle = true
|
||||
}
|
||||
} else if len(line) != 0 {
|
||||
rulesCount++
|
||||
}
|
||||
}
|
||||
|
||||
return rulesCount, name, lines
|
||||
}
|
||||
|
||||
// Checks for filters updates
|
||||
// If "force" is true -- does not check the filter's LastUpdated field
|
||||
// Call "save" to persist the filter contents
|
||||
func (filter *filter) update(force bool) (bool, error) {
|
||||
if filter.ID == 0 { // protect against users deleting the ID
|
||||
filter.ID = assignUniqueFilterID()
|
||||
}
|
||||
if !filter.Enabled {
|
||||
return false, nil
|
||||
}
|
||||
if !force && time.Since(filter.LastUpdated) <= updatePeriod {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
log.Printf("Downloading update for filter %d from %s", filter.ID, filter.URL)
|
||||
|
||||
// use the same update period for failed filter downloads to avoid flooding with requests
|
||||
filter.LastUpdated = time.Now()
|
||||
|
||||
resp, err := client.Get(filter.URL)
|
||||
if resp != nil && resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("Couldn't request filter from URL %s, skipping: %s", filter.URL, err)
|
||||
return false, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
log.Printf("Got status code %d from URL %s, skipping", resp.StatusCode, filter.URL)
|
||||
return false, fmt.Errorf("got status code != 200: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
contentType := strings.ToLower(resp.Header.Get("content-type"))
|
||||
if !strings.HasPrefix(contentType, "text/plain") {
|
||||
log.Printf("Non-text response %s from %s, skipping", contentType, filter.URL)
|
||||
return false, fmt.Errorf("non-text response %s", contentType)
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.Printf("Couldn't fetch filter contents from URL %s, skipping: %s", filter.URL, err)
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Extract filter name and count number of rules
|
||||
rulesCount, filterName, rules := parseFilterContents(body)
|
||||
|
||||
if filterName != "" {
|
||||
filter.Name = filterName
|
||||
}
|
||||
|
||||
// Check if the filter has been really changed
|
||||
if reflect.DeepEqual(filter.Rules, rules) {
|
||||
log.Printf("Filter #%d at URL %s hasn't changed, not updating it", filter.ID, filter.URL)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
log.Printf("Filter %d has been updated: %d bytes, %d rules", filter.ID, len(body), rulesCount)
|
||||
filter.RulesCount = rulesCount
|
||||
filter.Rules = rules
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// saves filter contents to the file in dataDir
|
||||
func (filter *filter) save() error {
|
||||
filterFilePath := filter.Path()
|
||||
log.Printf("Saving filter %d contents to: %s", filter.ID, filterFilePath)
|
||||
body := []byte(strings.Join(filter.Rules, "\n"))
|
||||
|
||||
return safeWriteFile(filterFilePath, body)
|
||||
}
|
||||
|
||||
// loads filter contents from the file in dataDir
|
||||
func (filter *filter) load() error {
|
||||
if !filter.Enabled {
|
||||
// No need to load a filter that is not enabled
|
||||
return nil
|
||||
}
|
||||
|
||||
filterFilePath := filter.Path()
|
||||
log.Printf("Loading filter %d contents to: %s", filter.ID, filterFilePath)
|
||||
|
||||
if _, err := os.Stat(filterFilePath); os.IsNotExist(err) {
|
||||
// do nothing, file doesn't exist
|
||||
return err
|
||||
}
|
||||
|
||||
filterFileContents, err := ioutil.ReadFile(filterFilePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("File %s, id %d, length %d", filterFilePath, filter.ID, len(filterFileContents))
|
||||
rulesCount, _, rules := parseFilterContents(filterFileContents)
|
||||
|
||||
filter.RulesCount = rulesCount
|
||||
filter.Rules = rules
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Path to the filter contents
|
||||
func (filter *filter) Path() string {
|
||||
return filepath.Join(config.ourBinaryDir, dataDir, filterDir, strconv.FormatInt(filter.ID, 10)+".txt")
|
||||
}
|
17
go.mod
17
go.mod
@ -3,34 +3,19 @@ module github.com/AdguardTeam/AdGuardHome
|
||||
require (
|
||||
github.com/StackExchange/wmi v0.0.0-20180725035823-b12b22c5341f // indirect
|
||||
github.com/beefsack/go-rate v0.0.0-20180408011153-efa7637bb9b6
|
||||
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 // indirect
|
||||
github.com/bluele/gcache v0.0.0-20171010155617-472614239ac7
|
||||
github.com/coredns/coredns v1.2.6
|
||||
github.com/dnstap/golang-dnstap v0.0.0-20170829151710-2cf77a2b5e11 // indirect
|
||||
github.com/farsightsec/golang-framestream v0.0.0-20181102145529-8a0cb8ba8710 // indirect
|
||||
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect
|
||||
github.com/go-ole/go-ole v1.2.1 // indirect
|
||||
github.com/go-test/deep v1.0.1
|
||||
github.com/gobuffalo/packr v1.19.0
|
||||
github.com/google/uuid v1.0.0 // indirect
|
||||
github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 // indirect
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
|
||||
github.com/mholt/caddy v0.11.0
|
||||
github.com/joomcode/errorx v0.1.0
|
||||
github.com/miekg/dns v1.0.15
|
||||
github.com/opentracing/opentracing-go v1.0.2 // indirect
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
github.com/pkg/errors v0.8.0
|
||||
github.com/prometheus/client_golang v0.9.0-pre1
|
||||
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910 // indirect
|
||||
github.com/prometheus/common v0.0.0-20181109100915-0b1957f9d949 // indirect
|
||||
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d // indirect
|
||||
github.com/shirou/gopsutil v2.18.10+incompatible
|
||||
github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4 // indirect
|
||||
go.uber.org/goleak v0.10.0
|
||||
golang.org/x/crypto v0.0.0-20181106171534-e4dc69e5b2fd
|
||||
golang.org/x/net v0.0.0-20181108082009-03003ca0c849
|
||||
golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8 // indirect
|
||||
google.golang.org/grpc v1.16.0 // indirect
|
||||
gopkg.in/asaskevich/govalidator.v4 v4.0.0-20160518190739-766470278477
|
||||
gopkg.in/yaml.v2 v2.2.1
|
||||
)
|
||||
|
52
go.sum
52
go.sum
@ -1,23 +1,11 @@
|
||||
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||
github.com/StackExchange/wmi v0.0.0-20180725035823-b12b22c5341f h1:5ZfJxyXo8KyX8DgGXC5B7ILL8y51fci/qYz2B4j8iLY=
|
||||
github.com/StackExchange/wmi v0.0.0-20180725035823-b12b22c5341f/go.mod h1:3eOhrUMpNV+6aFIbp5/iudMxNCF27Vw2OZgy4xEx0Fg=
|
||||
github.com/beefsack/go-rate v0.0.0-20180408011153-efa7637bb9b6 h1:KXlsf+qt/X5ttPGEjR0tPH1xaWWoKBEg9Q1THAj2h3I=
|
||||
github.com/beefsack/go-rate v0.0.0-20180408011153-efa7637bb9b6/go.mod h1:6YNgTHLutezwnBvyneBbwvB8C82y3dcoOj5EQJIdGXA=
|
||||
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 h1:xJ4a3vCFaGF/jqvzLMYoU8P317H5OQ+Via4RmuPwCS0=
|
||||
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
|
||||
github.com/bluele/gcache v0.0.0-20171010155617-472614239ac7 h1:NpQ+gkFOH27AyDypSCJ/LdsIi/b4rdnEb1N5+IpFfYs=
|
||||
github.com/bluele/gcache v0.0.0-20171010155617-472614239ac7/go.mod h1:8c4/i2VlovMO2gBnHGQPN5EJw+H0lx1u/5p+cgsXtCk=
|
||||
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
|
||||
github.com/coredns/coredns v1.2.6 h1:QIAOkBqVE44Zx0ttrFqgE5YhCEn64XPIngU60JyuTGM=
|
||||
github.com/coredns/coredns v1.2.6/go.mod h1:zASH/MVDgR6XZTbxvOnsZfffS+31vg6Ackf/wo1+AM0=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dnstap/golang-dnstap v0.0.0-20170829151710-2cf77a2b5e11 h1:m8nX8hsUghn853BJ5qB0lX+VvS6LTJPksWyILFZRYN4=
|
||||
github.com/dnstap/golang-dnstap v0.0.0-20170829151710-2cf77a2b5e11/go.mod h1:s1PfVYYVmTMgCSPtho4LKBDecEHJWtiVDPNv78Z985U=
|
||||
github.com/farsightsec/golang-framestream v0.0.0-20181102145529-8a0cb8ba8710 h1:QdyRyGZWLEvJG5Kw3VcVJvhXJ5tZ1MkRgqpJOEZSySM=
|
||||
github.com/farsightsec/golang-framestream v0.0.0-20181102145529-8a0cb8ba8710/go.mod h1:eNde4IQyEiA5br02AouhEHCu3p3UzrCdFR4LuQHklMI=
|
||||
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 h1:BHsljHzVlRcyQhjrss6TZTdY2VfCqZPbv5k3iBFa2ZQ=
|
||||
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc=
|
||||
github.com/go-ole/go-ole v1.2.1 h1:2lOsA72HgjxAuMlKpFiCbHTvu44PIVkZ5hqm3RSdI/E=
|
||||
github.com/go-ole/go-ole v1.2.1/go.mod h1:7FAglXiTm7HKlQRDeOQ6ZNUHidzCWXuZWq/1dTyBNF8=
|
||||
github.com/go-test/deep v1.0.1 h1:UQhStjbkDClarlmv0am7OXXO4/GaPdCGiUiMTvi28sg=
|
||||
@ -28,44 +16,21 @@ github.com/gobuffalo/packd v0.0.0-20181031195726-c82734870264 h1:roWyi0eEdiFreSq
|
||||
github.com/gobuffalo/packd v0.0.0-20181031195726-c82734870264/go.mod h1:Yf2toFaISlyQrr5TfO3h6DB9pl9mZRmyvBGQb/aQ/pI=
|
||||
github.com/gobuffalo/packr v1.19.0 h1:3UDmBDxesCOPF8iZdMDBBWKfkBoYujIMIZePnobqIUI=
|
||||
github.com/gobuffalo/packr v1.19.0/go.mod h1:MstrNkfCQhd5o+Ct4IJ0skWlxN8emOq8DsoT1G98VIU=
|
||||
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58=
|
||||
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
|
||||
github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E=
|
||||
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
|
||||
github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM=
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/google/uuid v1.0.0 h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA=
|
||||
github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 h1:MJG/KsmcqMwFAkh8mTnAwhyKoB+sTAnY4CACC110tbU=
|
||||
github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645/go.mod h1:6iZfnjpejD4L/4DwD7NryNaJyCQdzwWwH2MWhCA90Kw=
|
||||
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
|
||||
github.com/joho/godotenv v1.3.0 h1:Zjp+RcGpHhGlrMbJzXTrZZPrWj+1vfm90La1wgB6Bhc=
|
||||
github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/joomcode/errorx v0.1.0 h1:QmJMiI1DE1UFje2aI1ZWO/VMT5a32qBoXUclGOt8vsc=
|
||||
github.com/joomcode/errorx v0.1.0/go.mod h1:kgco15ekB6cs+4Xjzo7SPeXzx38PbJzBwbnu9qfVNHQ=
|
||||
github.com/markbates/oncer v0.0.0-20181014194634-05fccaae8fc4 h1:Mlji5gkcpzkqTROyE4ZxZ8hN7osunMb2RuGVrbvMvCc=
|
||||
github.com/markbates/oncer v0.0.0-20181014194634-05fccaae8fc4/go.mod h1:Ld9puTsIW75CHf65OeIOkyKbteujpZVXDpWK6YGZbxE=
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
||||
github.com/mholt/caddy v0.11.0 h1:cuhEyR7So/SBBRiAaiRBe9BoccDu6uveIPuM9FMMavg=
|
||||
github.com/mholt/caddy v0.11.0/go.mod h1:Wb1PlT4DAYSqOEd03MsqkdkXnTxA8v9pKjdpxbqM1kY=
|
||||
github.com/miekg/dns v1.0.15 h1:9+UupePBQCG6zf1q/bGmTO1vumoG13jsrbWOSX1W6Tw=
|
||||
github.com/miekg/dns v1.0.15/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
|
||||
github.com/opentracing/opentracing-go v1.0.2 h1:3jA2P6O1F9UOrWVpwrIo17pu01KWvNWg4X946/Y5Zwg=
|
||||
github.com/opentracing/opentracing-go v1.0.2/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o=
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
|
||||
github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw=
|
||||
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_golang v0.9.0-pre1 h1:AWTOhsOI9qxeirTuA0A4By/1Es1+y9EcCGY6bBZ2fhM=
|
||||
github.com/prometheus/client_golang v0.9.0-pre1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
|
||||
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910 h1:idejC8f05m9MGOsuEi1ATq9shN03HrxNkD/luQvxCv8=
|
||||
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
|
||||
github.com/prometheus/common v0.0.0-20181109100915-0b1957f9d949 h1:MVbUQq1a49hMEISI29UcAUjywT3FyvDwx5up90OvVa4=
|
||||
github.com/prometheus/common v0.0.0-20181109100915-0b1957f9d949/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro=
|
||||
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d h1:GoAlyOgbOEIFdaDqxJVlbOQ1DtGmZWs/Qau0hIlk+WQ=
|
||||
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
|
||||
github.com/shirou/gopsutil v2.18.10+incompatible h1:cy84jW6EVRPa5g9HAHrlbxMSIjBhDSX0OFYyMYminYs=
|
||||
github.com/shirou/gopsutil v2.18.10+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
|
||||
github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4 h1:udFKJ0aHUL60LboW/A+DfgoHVedieIzIXE8uylPue0U=
|
||||
@ -80,29 +45,16 @@ go.uber.org/goleak v0.10.0 h1:G3eWbSNIskeRqtsN/1uI5B+eP73y3JUuBsv9AZjehb4=
|
||||
go.uber.org/goleak v0.10.0/go.mod h1:VCZuO8V8mFPlL0F5J5GK1rtHV3DrFcQ1R8ryq7FK0aI=
|
||||
golang.org/x/crypto v0.0.0-20181106171534-e4dc69e5b2fd h1:VtIkGDhk0ph3t+THbvXHfMZ8QHgsBO39Nh52+74pq7w=
|
||||
golang.org/x/crypto v0.0.0-20181106171534-e4dc69e5b2fd/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20181102091132-c10e9556a7bc/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20181108082009-03003ca0c849 h1:FSqE2GGG7wzsYUsWiQ8MZrvEd1EOyU3NCF0AW3Wtltg=
|
||||
golang.org/x/net v0.0.0-20181108082009-03003ca0c849/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8 h1:YoY1wS6JYVRpIfFngRf2HHo9R9dAne3xbkGOQ5rJXjU=
|
||||
golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
|
||||
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8 h1:Nw54tB0rB7hY/N0NQvRW8DG4Yk3Q6T9cu9RcFQDu1tc=
|
||||
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
|
||||
google.golang.org/grpc v1.16.0 h1:dz5IJGuC2BB7qXR5AyHNwAUBhZscK2xVez7mznh72sY=
|
||||
google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio=
|
||||
gopkg.in/asaskevich/govalidator.v4 v4.0.0-20160518190739-766470278477 h1:5xUJw+lg4zao9W4HIDzlFbMYgSgtvNVHh00MEHvbGpQ=
|
||||
gopkg.in/asaskevich/govalidator.v4 v4.0.0-20160518190739-766470278477/go.mod h1:QDV1vrFSrowdoOba0UM8VJPUZONT7dnfdLsM+GG53Z8=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE=
|
||||
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
|
51
upgrade.go
51
upgrade.go
@ -10,6 +10,8 @@ import (
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
const currentSchemaVersion = 2 // used for upgrading from old configs to new config
|
||||
|
||||
// Performs necessary upgrade operations if needed
|
||||
func upgradeConfig() error {
|
||||
// read a config file into an interface map, so we can manipulate values without losing any
|
||||
@ -57,7 +59,12 @@ func upgradeConfig() error {
|
||||
func upgradeConfigSchema(oldVersion int, diskConfig *map[string]interface{}) error {
|
||||
switch oldVersion {
|
||||
case 0:
|
||||
err := upgradeSchema0to1(diskConfig)
|
||||
err := upgradeSchema0to2(diskConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case 1:
|
||||
err := upgradeSchema1to2(diskConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -83,14 +90,13 @@ func upgradeConfigSchema(oldVersion int, diskConfig *map[string]interface{}) err
|
||||
return nil
|
||||
}
|
||||
|
||||
// The first schema upgrade:
|
||||
// No more "dnsfilter.txt", filters are now kept in data/filters/
|
||||
func upgradeSchema0to1(diskConfig *map[string]interface{}) error {
|
||||
log.Printf("%s(): called", _Func())
|
||||
|
||||
// The first schema upgrade:
|
||||
// No more "dnsfilter.txt", filters are now kept in data/filters/
|
||||
dnsFilterPath := filepath.Join(config.ourBinaryDir, "dnsfilter.txt")
|
||||
_, err := os.Stat(dnsFilterPath)
|
||||
if !os.IsNotExist(err) {
|
||||
if _, err := os.Stat(dnsFilterPath); !os.IsNotExist(err) {
|
||||
log.Printf("Deleting %s as we don't need it anymore", dnsFilterPath)
|
||||
err = os.Remove(dnsFilterPath)
|
||||
if err != nil {
|
||||
@ -103,3 +109,38 @@ func upgradeSchema0to1(diskConfig *map[string]interface{}) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Second schema upgrade:
|
||||
// coredns is now dns in config
|
||||
// delete 'Corefile', since we don't use that anymore
|
||||
func upgradeSchema1to2(diskConfig *map[string]interface{}) error {
|
||||
log.Printf("%s(): called", _Func())
|
||||
|
||||
coreFilePath := filepath.Join(config.ourBinaryDir, "Corefile")
|
||||
if _, err := os.Stat(coreFilePath); !os.IsNotExist(err) {
|
||||
log.Printf("Deleting %s as we don't need it anymore", coreFilePath)
|
||||
err = os.Remove(coreFilePath)
|
||||
if err != nil {
|
||||
log.Printf("Cannot remove %s due to %s", coreFilePath, err)
|
||||
// not fatal, move on
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := (*diskConfig)["dns"]; !ok {
|
||||
(*diskConfig)["dns"] = (*diskConfig)["coredns"]
|
||||
delete((*diskConfig), "coredns")
|
||||
}
|
||||
(*diskConfig)["schema_version"] = 2
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// jump two schemas at once -- this time we just do it sequentially
|
||||
func upgradeSchema0to2(diskConfig *map[string]interface{}) error {
|
||||
err := upgradeSchema0to1(diskConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return upgradeSchema1to2(diskConfig)
|
||||
}
|
||||
|
@ -1,105 +0,0 @@
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// DnsUpstream is a very simple upstream implementation for plain DNS
|
||||
type DnsUpstream struct {
|
||||
endpoint string // IP:port
|
||||
timeout time.Duration // Max read and write timeout
|
||||
proto string // Protocol (tcp, tcp-tls, or udp)
|
||||
transport *Transport // Persistent connections cache
|
||||
}
|
||||
|
||||
// NewDnsUpstream creates a new DNS upstream
|
||||
func NewDnsUpstream(endpoint string, proto string, tlsServerName string) (Upstream, error) {
|
||||
u := &DnsUpstream{
|
||||
endpoint: endpoint,
|
||||
timeout: defaultTimeout,
|
||||
proto: proto,
|
||||
}
|
||||
|
||||
var tlsConfig *tls.Config
|
||||
|
||||
if proto == "tcp-tls" {
|
||||
tlsConfig = new(tls.Config)
|
||||
tlsConfig.ServerName = tlsServerName
|
||||
}
|
||||
|
||||
// Initialize the connections cache
|
||||
u.transport = NewTransport(endpoint)
|
||||
u.transport.tlsConfig = tlsConfig
|
||||
u.transport.Start()
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// Exchange provides an implementation for the Upstream interface
|
||||
func (u *DnsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
|
||||
resp, err := u.exchange(u.proto, query)
|
||||
|
||||
// Retry over TCP if response is truncated
|
||||
if err == dns.ErrTruncated && u.proto == "udp" {
|
||||
resp, err = u.exchange("tcp", query)
|
||||
} else if err == dns.ErrTruncated && resp != nil {
|
||||
// Reassemble something to be sent to client
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(query)
|
||||
m.Truncated = true
|
||||
m.Authoritative = true
|
||||
m.Rcode = dns.RcodeSuccess
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
resp = &dns.Msg{}
|
||||
resp.SetRcode(resp, dns.RcodeServerFailure)
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Clear resources
|
||||
func (u *DnsUpstream) Close() error {
|
||||
// Close active connections
|
||||
u.transport.Stop()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Performs a synchronous query. It sends the message m via the conn
|
||||
// c and waits for a reply. The conn c is not closed.
|
||||
func (u *DnsUpstream) exchange(proto string, query *dns.Msg) (r *dns.Msg, err error) {
|
||||
// Establish a connection if needed (or reuse cached)
|
||||
conn, err := u.transport.Dial(proto)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Write the request with a timeout
|
||||
conn.SetWriteDeadline(time.Now().Add(u.timeout))
|
||||
if err = conn.WriteMsg(query); err != nil {
|
||||
conn.Close() // Not giving it back
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Write response with a timeout
|
||||
conn.SetReadDeadline(time.Now().Add(u.timeout))
|
||||
r, err = conn.ReadMsg()
|
||||
if err != nil {
|
||||
conn.Close() // Not giving it back
|
||||
} else if err == nil && r.Id != query.Id {
|
||||
err = dns.ErrId
|
||||
conn.Close() // Not giving it back
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
// Return it back to the connections cache if there were no errors
|
||||
u.transport.Yield(conn)
|
||||
}
|
||||
return r, err
|
||||
}
|
@ -1,98 +0,0 @@
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// Detects the upstream type from the specified url and creates a proper Upstream object
|
||||
func NewUpstream(url string, bootstrap string) (Upstream, error) {
|
||||
proto := "udp"
|
||||
prefix := ""
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(url, "tcp://"):
|
||||
proto = "tcp"
|
||||
prefix = "tcp://"
|
||||
case strings.HasPrefix(url, "tls://"):
|
||||
proto = "tcp-tls"
|
||||
prefix = "tls://"
|
||||
case strings.HasPrefix(url, "https://"):
|
||||
return NewHttpsUpstream(url, bootstrap)
|
||||
}
|
||||
|
||||
hostname := strings.TrimPrefix(url, prefix)
|
||||
|
||||
host, port, err := net.SplitHostPort(hostname)
|
||||
if err != nil {
|
||||
// Set port depending on the protocol
|
||||
switch proto {
|
||||
case "udp":
|
||||
port = "53"
|
||||
case "tcp":
|
||||
port = "53"
|
||||
case "tcp-tls":
|
||||
port = "853"
|
||||
}
|
||||
|
||||
// Set host = hostname
|
||||
host = hostname
|
||||
}
|
||||
|
||||
// Try to resolve the host address (or check if it's an IP address)
|
||||
bootstrapResolver := CreateResolver(bootstrap)
|
||||
ips, err := bootstrapResolver.LookupIPAddr(context.Background(), host)
|
||||
|
||||
if err != nil || len(ips) == 0 {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
addr := ips[0].String()
|
||||
endpoint := net.JoinHostPort(addr, port)
|
||||
tlsServerName := ""
|
||||
|
||||
if proto == "tcp-tls" && host != addr {
|
||||
// Check if we need to specify TLS server name
|
||||
tlsServerName = host
|
||||
}
|
||||
|
||||
return NewDnsUpstream(endpoint, proto, tlsServerName)
|
||||
}
|
||||
|
||||
func CreateResolver(bootstrap string) *net.Resolver {
|
||||
bootstrapResolver := net.DefaultResolver
|
||||
|
||||
if bootstrap != "" {
|
||||
bootstrapResolver = &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
return d.DialContext(ctx, network, bootstrap)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return bootstrapResolver
|
||||
}
|
||||
|
||||
// Performs a simple health-check of the specified upstream
|
||||
func IsAlive(u Upstream) (bool, error) {
|
||||
// Using ipv4only.arpa. domain as it is a part of DNS64 RFC and it should exist everywhere
|
||||
ping := new(dns.Msg)
|
||||
ping.SetQuestion("ipv4only.arpa.", dns.TypeA)
|
||||
|
||||
resp, err := u.Exchange(context.Background(), ping)
|
||||
|
||||
// If we got a header, we're alright, basically only care about I/O errors 'n stuff.
|
||||
if err != nil && resp != nil {
|
||||
// Silly check, something sane came back.
|
||||
if resp.Rcode != dns.RcodeServerFailure {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
return err == nil, err
|
||||
}
|
@ -1,128 +0,0 @@
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/net/context"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
const (
|
||||
dnsMessageContentType = "application/dns-message"
|
||||
defaultKeepAlive = 30 * time.Second
|
||||
)
|
||||
|
||||
// HttpsUpstream is the upstream implementation for DNS-over-HTTPS
|
||||
type HttpsUpstream struct {
|
||||
client *http.Client
|
||||
endpoint *url.URL
|
||||
}
|
||||
|
||||
// NewHttpsUpstream creates a new DNS-over-HTTPS upstream from the specified url
|
||||
func NewHttpsUpstream(endpoint string, bootstrap string) (Upstream, error) {
|
||||
u, err := url.Parse(endpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Initialize bootstrap resolver
|
||||
bootstrapResolver := CreateResolver(bootstrap)
|
||||
dialer := &net.Dialer{
|
||||
Timeout: defaultTimeout,
|
||||
KeepAlive: defaultKeepAlive,
|
||||
DualStack: true,
|
||||
Resolver: bootstrapResolver,
|
||||
}
|
||||
|
||||
// Update TLS and HTTP client configuration
|
||||
tlsConfig := &tls.Config{ServerName: u.Hostname()}
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: tlsConfig,
|
||||
DisableCompression: true,
|
||||
MaxIdleConns: 1,
|
||||
DialContext: dialer.DialContext,
|
||||
}
|
||||
http2.ConfigureTransport(transport)
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: defaultTimeout,
|
||||
Transport: transport,
|
||||
}
|
||||
|
||||
return &HttpsUpstream{client: client, endpoint: u}, nil
|
||||
}
|
||||
|
||||
// Exchange provides an implementation for the Upstream interface
|
||||
func (u *HttpsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
|
||||
queryBuf, err := query.Pack()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to pack DNS query")
|
||||
}
|
||||
|
||||
// No content negotiation for now, use DNS wire format
|
||||
buf, backendErr := u.exchangeWireformat(queryBuf)
|
||||
if backendErr == nil {
|
||||
response := &dns.Msg{}
|
||||
if err := response.Unpack(buf); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to unpack DNS response from body")
|
||||
}
|
||||
|
||||
response.Id = query.Id
|
||||
return response, nil
|
||||
}
|
||||
|
||||
log.Printf("failed to connect to an HTTPS backend %q due to %s", u.endpoint, backendErr)
|
||||
return nil, backendErr
|
||||
}
|
||||
|
||||
// Perform message exchange with the default UDP wireformat defined in current draft
|
||||
// https://tools.ietf.org/html/draft-ietf-doh-dns-over-https-10
|
||||
func (u *HttpsUpstream) exchangeWireformat(msg []byte) ([]byte, error) {
|
||||
req, err := http.NewRequest("POST", u.endpoint.String(), bytes.NewBuffer(msg))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create an HTTPS request")
|
||||
}
|
||||
|
||||
req.Header.Add("Content-Type", dnsMessageContentType)
|
||||
req.Header.Add("Accept", dnsMessageContentType)
|
||||
req.Host = u.endpoint.Hostname()
|
||||
|
||||
resp, err := u.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to perform an HTTPS request")
|
||||
}
|
||||
|
||||
// Check response status code
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("returned status code %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType != dnsMessageContentType {
|
||||
return nil, fmt.Errorf("return wrong content type %s", contentType)
|
||||
}
|
||||
|
||||
// Read application/dns-message response from the body
|
||||
buf, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to read the response body")
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// Clear resources
|
||||
func (u *HttpsUpstream) Close() error {
|
||||
return nil
|
||||
}
|
@ -1,210 +0,0 @@
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"sort"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Persistent connections cache -- almost similar to the same used in the CoreDNS forward plugin
|
||||
|
||||
const (
|
||||
defaultExpire = 10 * time.Second
|
||||
minDialTimeout = 100 * time.Millisecond
|
||||
maxDialTimeout = 30 * time.Second
|
||||
defaultDialTimeout = 30 * time.Second
|
||||
cumulativeAvgWeight = 4
|
||||
)
|
||||
|
||||
// a persistConn hold the dns.Conn and the last used time.
|
||||
type persistConn struct {
|
||||
c *dns.Conn
|
||||
used time.Time
|
||||
}
|
||||
|
||||
// Transport hold the persistent cache.
|
||||
type Transport struct {
|
||||
avgDialTime int64 // kind of average time of dial time
|
||||
conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls.
|
||||
expire time.Duration // After this duration a connection is expired.
|
||||
addr string
|
||||
tlsConfig *tls.Config
|
||||
|
||||
dial chan string
|
||||
yield chan *dns.Conn
|
||||
ret chan *dns.Conn
|
||||
stop chan bool
|
||||
}
|
||||
|
||||
// Dial dials the address configured in transport, potentially reusing a connection or creating a new one.
|
||||
func (t *Transport) Dial(proto string) (*dns.Conn, error) {
|
||||
// If tls has been configured; use it.
|
||||
if t.tlsConfig != nil {
|
||||
proto = "tcp-tls"
|
||||
}
|
||||
|
||||
t.dial <- proto
|
||||
c := <-t.ret
|
||||
|
||||
if c != nil {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
reqTime := time.Now()
|
||||
timeout := t.dialTimeout()
|
||||
if proto == "tcp-tls" {
|
||||
conn, err := dns.DialTimeoutWithTLS(proto, t.addr, t.tlsConfig, timeout)
|
||||
t.updateDialTimeout(time.Since(reqTime))
|
||||
return conn, err
|
||||
}
|
||||
conn, err := dns.DialTimeout(proto, t.addr, timeout)
|
||||
t.updateDialTimeout(time.Since(reqTime))
|
||||
return conn, err
|
||||
}
|
||||
|
||||
// Yield return the connection to transport for reuse.
|
||||
func (t *Transport) Yield(c *dns.Conn) { t.yield <- c }
|
||||
|
||||
// Start starts the transport's connection manager.
|
||||
func (t *Transport) Start() { go t.connManager() }
|
||||
|
||||
// Stop stops the transport's connection manager.
|
||||
func (t *Transport) Stop() { close(t.stop) }
|
||||
|
||||
// SetExpire sets the connection expire time in transport.
|
||||
func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire }
|
||||
|
||||
// SetTLSConfig sets the TLS config in transport.
|
||||
func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg }
|
||||
|
||||
func NewTransport(addr string) *Transport {
|
||||
t := &Transport{
|
||||
avgDialTime: int64(defaultDialTimeout / 2),
|
||||
conns: make(map[string][]*persistConn),
|
||||
expire: defaultExpire,
|
||||
addr: addr,
|
||||
dial: make(chan string),
|
||||
yield: make(chan *dns.Conn),
|
||||
ret: make(chan *dns.Conn),
|
||||
stop: make(chan bool),
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) {
|
||||
dt := time.Duration(atomic.LoadInt64(currentAvg))
|
||||
atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight)
|
||||
}
|
||||
|
||||
func (t *Transport) dialTimeout() time.Duration {
|
||||
return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout)
|
||||
}
|
||||
|
||||
func (t *Transport) updateDialTimeout(newDialTime time.Duration) {
|
||||
averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight)
|
||||
}
|
||||
|
||||
// limitTimeout is a utility function to auto-tune timeout values
|
||||
// average observed time is moved towards the last observed delay moderated by a weight
|
||||
// next timeout to use will be the double of the computed average, limited by min and max frame.
|
||||
func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration {
|
||||
rt := time.Duration(atomic.LoadInt64(currentAvg))
|
||||
if rt < minValue {
|
||||
return minValue
|
||||
}
|
||||
if rt < maxValue/2 {
|
||||
return 2 * rt
|
||||
}
|
||||
return maxValue
|
||||
}
|
||||
|
||||
// connManagers manages the persistent connection cache for UDP and TCP.
|
||||
func (t *Transport) connManager() {
|
||||
ticker := time.NewTicker(t.expire)
|
||||
Wait:
|
||||
for {
|
||||
select {
|
||||
case proto := <-t.dial:
|
||||
// take the last used conn - complexity O(1)
|
||||
if stack := t.conns[proto]; len(stack) > 0 {
|
||||
pc := stack[len(stack)-1]
|
||||
if time.Since(pc.used) < t.expire {
|
||||
// Found one, remove from pool and return this conn.
|
||||
t.conns[proto] = stack[:len(stack)-1]
|
||||
t.ret <- pc.c
|
||||
continue Wait
|
||||
}
|
||||
// clear entire cache if the last conn is expired
|
||||
t.conns[proto] = nil
|
||||
// now, the connections being passed to closeConns() are not reachable from
|
||||
// transport methods anymore. So, it's safe to close them in a separate goroutine
|
||||
go closeConns(stack)
|
||||
}
|
||||
|
||||
t.ret <- nil
|
||||
|
||||
case conn := <-t.yield:
|
||||
|
||||
// no proto here, infer from config and conn
|
||||
if _, ok := conn.Conn.(*net.UDPConn); ok {
|
||||
t.conns["udp"] = append(t.conns["udp"], &persistConn{conn, time.Now()})
|
||||
continue Wait
|
||||
}
|
||||
|
||||
if t.tlsConfig == nil {
|
||||
t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn, time.Now()})
|
||||
continue Wait
|
||||
}
|
||||
|
||||
t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn, time.Now()})
|
||||
|
||||
case <-ticker.C:
|
||||
t.cleanup(false)
|
||||
|
||||
case <-t.stop:
|
||||
t.cleanup(true)
|
||||
close(t.ret)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// closeConns closes connections.
|
||||
func closeConns(conns []*persistConn) {
|
||||
for _, pc := range conns {
|
||||
pc.c.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup removes connections from cache.
|
||||
func (t *Transport) cleanup(all bool) {
|
||||
staleTime := time.Now().Add(-t.expire)
|
||||
for proto, stack := range t.conns {
|
||||
if len(stack) == 0 {
|
||||
continue
|
||||
}
|
||||
if all {
|
||||
t.conns[proto] = nil
|
||||
// now, the connections being passed to closeConns() are not reachable from
|
||||
// transport methods anymore. So, it's safe to close them in a separate goroutine
|
||||
go closeConns(stack)
|
||||
continue
|
||||
}
|
||||
if stack[0].used.After(staleTime) {
|
||||
continue
|
||||
}
|
||||
|
||||
// connections in stack are sorted by "used"
|
||||
good := sort.Search(len(stack), func(i int) bool {
|
||||
return stack[i].used.After(staleTime)
|
||||
})
|
||||
t.conns[proto] = stack[good:]
|
||||
// now, the connections being passed to closeConns() are not reachable from
|
||||
// transport methods anymore. So, it's safe to close them in a separate goroutine
|
||||
go closeConns(stack[:good])
|
||||
}
|
||||
}
|
@ -1,81 +0,0 @@
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/coredns/coredns/core/dnsserver"
|
||||
"github.com/coredns/coredns/plugin"
|
||||
"github.com/mholt/caddy"
|
||||
)
|
||||
|
||||
func init() {
|
||||
caddy.RegisterPlugin("upstream", caddy.Plugin{
|
||||
ServerType: "dns",
|
||||
Action: setup,
|
||||
})
|
||||
}
|
||||
|
||||
// Read the configuration and initialize upstreams
|
||||
func setup(c *caddy.Controller) error {
|
||||
p, err := setupPlugin(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
config := dnsserver.GetConfig(c)
|
||||
config.AddPlugin(func(next plugin.Handler) plugin.Handler {
|
||||
p.Next = next
|
||||
return p
|
||||
})
|
||||
|
||||
c.OnShutdown(p.onShutdown)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read the configuration
|
||||
func setupPlugin(c *caddy.Controller) (*UpstreamPlugin, error) {
|
||||
p := New()
|
||||
|
||||
log.Println("Initializing the Upstream plugin")
|
||||
|
||||
bootstrap := ""
|
||||
upstreamUrls := []string{}
|
||||
for c.Next() {
|
||||
args := c.RemainingArgs()
|
||||
if len(args) > 0 {
|
||||
upstreamUrls = append(upstreamUrls, args...)
|
||||
}
|
||||
for c.NextBlock() {
|
||||
switch c.Val() {
|
||||
case "bootstrap":
|
||||
if !c.NextArg() {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
bootstrap = c.Val()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, url := range upstreamUrls {
|
||||
u, err := NewUpstream(url, bootstrap)
|
||||
if err != nil {
|
||||
log.Printf("Cannot initialize upstream %s", url)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p.Upstreams = append(p.Upstreams, u)
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (p *UpstreamPlugin) onShutdown() error {
|
||||
for i := range p.Upstreams {
|
||||
u := p.Upstreams[i]
|
||||
err := u.Close()
|
||||
if err != nil {
|
||||
log.Printf("Error while closing the upstream: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -1,29 +0,0 @@
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
)
|
||||
|
||||
func TestSetup(t *testing.T) {
|
||||
var tests = []struct {
|
||||
config string
|
||||
}{
|
||||
{`upstream 8.8.8.8`},
|
||||
{`upstream 8.8.8.8 {
|
||||
bootstrap 8.8.8.8:53
|
||||
}`},
|
||||
{`upstream tls://1.1.1.1 8.8.8.8 {
|
||||
bootstrap 1.1.1.1
|
||||
}`},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
c := caddy.NewTestController("dns", test.config)
|
||||
err := setup(c)
|
||||
if err != nil {
|
||||
t.Fatalf("Test failed")
|
||||
}
|
||||
}
|
||||
}
|
@ -1,57 +0,0 @@
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/plugin"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// Upstream is a simplified interface for proxy destination
|
||||
type Upstream interface {
|
||||
Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
// UpstreamPlugin is a simplified DNS proxy using a generic upstream interface
|
||||
type UpstreamPlugin struct {
|
||||
Upstreams []Upstream
|
||||
Next plugin.Handler
|
||||
}
|
||||
|
||||
// Initialize the upstream plugin
|
||||
func New() *UpstreamPlugin {
|
||||
p := &UpstreamPlugin{
|
||||
Upstreams: []Upstream{},
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// ServeDNS implements interface for CoreDNS plugin
|
||||
func (p *UpstreamPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
var reply *dns.Msg
|
||||
var backendErr error
|
||||
|
||||
for i := range p.Upstreams {
|
||||
upstream := p.Upstreams[i]
|
||||
reply, backendErr = upstream.Exchange(ctx, r)
|
||||
if backendErr == nil {
|
||||
w.WriteMsg(reply)
|
||||
return 0, nil
|
||||
}
|
||||
}
|
||||
|
||||
return dns.RcodeServerFailure, errors.Wrap(backendErr, "failed to contact any of the upstreams")
|
||||
}
|
||||
|
||||
// Name implements interface for CoreDNS plugin
|
||||
func (p *UpstreamPlugin) Name() string {
|
||||
return "upstream"
|
||||
}
|
@ -1,187 +0,0 @@
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
func TestDnsUpstreamIsAlive(t *testing.T) {
|
||||
var tests = []struct {
|
||||
url string
|
||||
bootstrap string
|
||||
}{
|
||||
{"8.8.8.8:53", "8.8.8.8:53"},
|
||||
{"1.1.1.1", ""},
|
||||
{"tcp://1.1.1.1:53", ""},
|
||||
{"176.103.130.130:5353", ""},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
u, err := NewUpstream(test.url, test.bootstrap)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("cannot create a DNS upstream")
|
||||
}
|
||||
|
||||
testUpstreamIsAlive(t, u)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHttpsUpstreamIsAlive(t *testing.T) {
|
||||
var tests = []struct {
|
||||
url string
|
||||
bootstrap string
|
||||
}{
|
||||
{"https://cloudflare-dns.com/dns-query", "8.8.8.8:53"},
|
||||
{"https://dns.google.com/experimental", "8.8.8.8:53"},
|
||||
{"https://doh.cleanbrowsing.org/doh/security-filter/", ""},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
u, err := NewUpstream(test.url, test.bootstrap)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("cannot create a DNS-over-HTTPS upstream")
|
||||
}
|
||||
|
||||
testUpstreamIsAlive(t, u)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDnsOverTlsIsAlive(t *testing.T) {
|
||||
var tests = []struct {
|
||||
url string
|
||||
bootstrap string
|
||||
}{
|
||||
{"tls://1.1.1.1", ""},
|
||||
{"tls://9.9.9.9:853", ""},
|
||||
{"tls://security-filter-dns.cleanbrowsing.org", "8.8.8.8:53"},
|
||||
{"tls://adult-filter-dns.cleanbrowsing.org:853", "8.8.8.8:53"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
u, err := NewUpstream(test.url, test.bootstrap)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("cannot create a DNS-over-TLS upstream")
|
||||
}
|
||||
|
||||
testUpstreamIsAlive(t, u)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDnsUpstream(t *testing.T) {
|
||||
var tests = []struct {
|
||||
url string
|
||||
bootstrap string
|
||||
}{
|
||||
{"8.8.8.8:53", "8.8.8.8:53"},
|
||||
{"1.1.1.1", ""},
|
||||
{"tcp://1.1.1.1:53", ""},
|
||||
{"176.103.130.130:5353", ""},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
u, err := NewUpstream(test.url, test.bootstrap)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("cannot create a DNS upstream")
|
||||
}
|
||||
|
||||
testUpstream(t, u)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHttpsUpstream(t *testing.T) {
|
||||
var tests = []struct {
|
||||
url string
|
||||
bootstrap string
|
||||
}{
|
||||
{"https://cloudflare-dns.com/dns-query", "8.8.8.8:53"},
|
||||
{"https://dns.google.com/experimental", "8.8.8.8:53"},
|
||||
{"https://doh.cleanbrowsing.org/doh/security-filter/", ""},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
u, err := NewUpstream(test.url, test.bootstrap)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("cannot create a DNS-over-HTTPS upstream")
|
||||
}
|
||||
|
||||
testUpstream(t, u)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDnsOverTlsUpstream(t *testing.T) {
|
||||
var tests = []struct {
|
||||
url string
|
||||
bootstrap string
|
||||
}{
|
||||
{"tls://1.1.1.1", ""},
|
||||
{"tls://9.9.9.9:853", ""},
|
||||
{"tls://security-filter-dns.cleanbrowsing.org", "8.8.8.8:53"},
|
||||
{"tls://adult-filter-dns.cleanbrowsing.org:853", "8.8.8.8:53"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
u, err := NewUpstream(test.url, test.bootstrap)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("cannot create a DNS-over-TLS upstream")
|
||||
}
|
||||
|
||||
testUpstream(t, u)
|
||||
}
|
||||
}
|
||||
|
||||
func testUpstreamIsAlive(t *testing.T, u Upstream) {
|
||||
alive, err := IsAlive(u)
|
||||
if !alive || err != nil {
|
||||
t.Errorf("Upstream is not alive")
|
||||
}
|
||||
|
||||
u.Close()
|
||||
}
|
||||
|
||||
func testUpstream(t *testing.T, u Upstream) {
|
||||
var tests = []struct {
|
||||
name string
|
||||
expected net.IP
|
||||
}{
|
||||
{"google-public-dns-a.google.com.", net.IPv4(8, 8, 8, 8)},
|
||||
{"google-public-dns-b.google.com.", net.IPv4(8, 8, 4, 4)},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
req := dns.Msg{}
|
||||
req.Id = dns.Id()
|
||||
req.RecursionDesired = true
|
||||
req.Question = []dns.Question{
|
||||
{Name: test.name, Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
||||
}
|
||||
|
||||
resp, err := u.Exchange(context.Background(), &req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("error while making an upstream request: %s", err)
|
||||
}
|
||||
|
||||
if len(resp.Answer) != 1 {
|
||||
t.Fatalf("no answer section in the response")
|
||||
}
|
||||
if answer, ok := resp.Answer[0].(*dns.A); ok {
|
||||
if !test.expected.Equal(answer.A) {
|
||||
t.Errorf("wrong IP in the response: %v", answer.A)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err := u.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Error while closing the upstream: %s", err)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user