diff --git a/control.go b/control.go
index 0bd7a5d7..952395a1 100644
--- a/control.go
+++ b/control.go
@@ -18,6 +18,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/log"
+ "github.com/AdguardTeam/golibs/utils"
"github.com/miekg/dns"
govalidator "gopkg.in/asaskevich/govalidator.v4"
)
@@ -317,11 +318,10 @@ func handleSetUpstreamConfig(w http.ResponseWriter, r *http.Request) {
return
}
- for _, u := range newconfig.Upstreams {
- if err = validateUpstream(u); err != nil {
- httpError(w, http.StatusBadRequest, "%s can not be used as upstream cause: %s", u, err)
- return
- }
+ err = validateUpstreams(newconfig.Upstreams)
+ if err != nil {
+ httpError(w, http.StatusBadRequest, "wrong upstreams specification: %s", err)
+ return
}
config.DNS.UpstreamDNS = defaultDNS
@@ -346,18 +346,81 @@ func handleSetUpstreamConfig(w http.ResponseWriter, r *http.Request) {
httpUpdateConfigReloadDNSReturnOK(w, r)
}
-func validateUpstream(upstream string) error {
- for _, proto := range protocols {
- if strings.HasPrefix(upstream, proto) {
- return nil
+// validateUpstreams validates each upstream and returns an error if any upstream is invalid or if there are no default upstreams specified
+func validateUpstreams(upstreams []string) error {
+ var defaultUpstreamFound bool
+ for _, u := range upstreams {
+ d, err := validateUpstream(u)
+ if err != nil {
+ return err
+ }
+
+ // Check this flag until default upstream will not be found
+ if !defaultUpstreamFound {
+ defaultUpstreamFound = d
}
}
- if strings.Contains(upstream, "://") {
- return fmt.Errorf("wrong protocol")
+ // Return error if there are no default upstreams
+ if !defaultUpstreamFound {
+ return fmt.Errorf("no default upstreams specified")
}
- return checkPlainDNS(upstream)
+ return nil
+}
+
+func validateUpstream(u string) (defaultUpstream bool, err error) {
+ // Check if user tries to specify upstream for domain
+ defaultUpstream = true
+ u, defaultUpstream, err = separateUpstream(u)
+ if err != nil {
+ return
+ }
+
+ // The special server address '#' means "use the default servers"
+ if u == "#" && !defaultUpstream {
+ return
+ }
+
+ // Check if the upstream has a valid protocol prefix
+ for _, proto := range protocols {
+ if strings.HasPrefix(u, proto) {
+ return
+ }
+ }
+
+ // Return error if the upstream contains '://' without any valid protocol
+ if strings.Contains(u, "://") {
+ return defaultUpstream, fmt.Errorf("wrong protocol")
+ }
+
+ // Check if upstream is valid plain DNS
+ return defaultUpstream, checkPlainDNS(u)
+}
+
+// separateUpstream returns upstream without specified domains and a bool flag that indicates if no domains were specified
+// error will be returned if upstream per domain specification is invalid
+func separateUpstream(upstream string) (string, bool, error) {
+ defaultUpstream := true
+ if strings.HasPrefix(upstream, "[/") {
+ defaultUpstream = false
+ // split domains and upstream string
+ domainsAndUpstream := strings.Split(strings.TrimPrefix(upstream, "[/"), "/]")
+ if len(domainsAndUpstream) != 2 {
+ return "", defaultUpstream, fmt.Errorf("wrong DNS upstream per domain specification: %s", upstream)
+ }
+
+ // split domains list and validate each one
+ for _, host := range strings.Split(domainsAndUpstream[0], "/") {
+ if host != "" {
+ if err := utils.IsValidHostname(host); err != nil {
+ return "", defaultUpstream, err
+ }
+ }
+ }
+ upstream = domainsAndUpstream[1]
+ }
+ return upstream, defaultUpstream, nil
}
// checkPlainDNS checks if host is plain DNS
@@ -425,7 +488,18 @@ func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
}
func checkDNS(input string, bootstrap []string) error {
- if err := validateUpstream(input); err != nil {
+ // separate upstream from domains list
+ input, defaultUpstream, err := separateUpstream(input)
+ if err != nil {
+ return fmt.Errorf("wrong upstream format: %s", err)
+ }
+
+ // No need to check this entrance
+ if input == "#" && !defaultUpstream {
+ return nil
+ }
+
+ if _, err := validateUpstream(input); err != nil {
return fmt.Errorf("wrong upstream format: %s", err)
}
diff --git a/control_test.go b/control_test.go
index b823b252..c04e1a2f 100644
--- a/control_test.go
+++ b/control_test.go
@@ -75,3 +75,79 @@ kXS9jgARhhiWXJrk
t.Fatalf("valid cert & priv key: validateCertificates(): %v", data)
}
}
+
+func TestValidateUpstream(t *testing.T) {
+ invalidUpstreams := []string{"1.2.3.4.5",
+ "123.3.7m",
+ "htttps://google.com/dns-query",
+ "[/host.com]tls://dns.adguard.com",
+ "[host.ru]#",
+ }
+
+ validDefaultUpstreams := []string{"1.1.1.1",
+ "tls://1.1.1.1",
+ "https://dns.adguard.com/dns-query",
+ "sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
+ }
+
+ validUpstreams := []string{"[/host.com/]1.1.1.1",
+ "[//]tls://1.1.1.1",
+ "[/www.host.com/]#",
+ "[/host.com/google.com/]8.8.8.8",
+ "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
+ }
+ for _, u := range invalidUpstreams {
+ _, err := validateUpstream(u)
+ if err == nil {
+ t.Fatalf("upstream %s is invalid but it pass through validation", u)
+ }
+ }
+
+ for _, u := range validDefaultUpstreams {
+ defaultUpstream, err := validateUpstream(u)
+ if err != nil {
+ t.Fatalf("upstream %s is valid but it doen't pass through validation cause: %s", u, err)
+ }
+ if !defaultUpstream {
+ t.Fatalf("upstream %s is default one!", u)
+ }
+ }
+
+ for _, u := range validUpstreams {
+ defaultUpstream, err := validateUpstream(u)
+ if err != nil {
+ t.Fatalf("upstream %s is valid but it doen't pass through validation cause: %s", u, err)
+ }
+ if defaultUpstream {
+ t.Fatalf("upstream %s is default one!", u)
+ }
+ }
+}
+
+func TestValidateUpstreamsSet(t *testing.T) {
+ // Set of valid upstreams. There is no default upstream specified
+ upstreamsSet := []string{"[/host.com/]1.1.1.1",
+ "[//]tls://1.1.1.1",
+ "[/www.host.com/]#",
+ "[/host.com/google.com/]8.8.8.8",
+ "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
+ }
+ err := validateUpstreams(upstreamsSet)
+ if err == nil {
+ t.Fatalf("there is no default upstream")
+ }
+
+ // Let's add default upstream
+ upstreamsSet = append(upstreamsSet, "8.8.8.8")
+ err = validateUpstreams(upstreamsSet)
+ if err != nil {
+ t.Fatalf("upstreams set is valid, but doesn't pass through validation cause: %s", err)
+ }
+
+ // Let's add invalid upstream
+ upstreamsSet = append(upstreamsSet, "dhcp://fake.dns")
+ err = validateUpstreams(upstreamsSet)
+ if err == nil {
+ t.Fatalf("there is an invalid upstream in set, but it pass through validation")
+ }
+}
diff --git a/dns.go b/dns.go
index 56fae956..9abbc80e 100644
--- a/dns.go
+++ b/dns.go
@@ -7,7 +7,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/dnsforward"
- "github.com/AdguardTeam/dnsproxy/upstream"
+ "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log"
"github.com/joomcode/errorx"
)
@@ -58,19 +58,12 @@ func generateServerConfig() dnsforward.ServerConfig {
}
}
- for _, u := range config.DNS.UpstreamDNS {
- opts := upstream.Options{
- Timeout: dnsforward.DefaultTimeout,
- Bootstrap: config.DNS.BootstrapDNS,
- }
- dnsUpstream, err := upstream.AddressToUpstream(u, opts)
- if err != nil {
- log.Printf("Couldn't get upstream: %s", err)
- // continue, just ignore the upstream
- continue
- }
- newconfig.Upstreams = append(newconfig.Upstreams, dnsUpstream)
+ upstreamConfig, err := proxy.ParseUpstreamsConfig(config.DNS.UpstreamDNS, config.DNS.BootstrapDNS, dnsforward.DefaultTimeout)
+ if err != nil {
+ log.Error("Couldn't get upstreams configuration cause: %s", err)
}
+ newconfig.Upstreams = upstreamConfig.Upstreams
+ newconfig.DomainsReservedUpstreams = upstreamConfig.DomainReservedUpstreams
newconfig.AllServers = config.DNS.AllServers
return newconfig
}
diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go
index 7b8430ed..6404a995 100644
--- a/dnsforward/dnsforward.go
+++ b/dnsforward/dnsforward.go
@@ -82,10 +82,11 @@ type TLSConfig struct {
// ServerConfig represents server configuration.
// The zero ServerConfig is empty and ready for use.
type ServerConfig struct {
- UDPListenAddr *net.UDPAddr // UDP listen address
- TCPListenAddr *net.TCPAddr // TCP listen address
- Upstreams []upstream.Upstream // Configured upstreams
- Filters []dnsfilter.Filter // A list of filters to use
+ UDPListenAddr *net.UDPAddr // UDP listen address
+ TCPListenAddr *net.TCPAddr // TCP listen address
+ Upstreams []upstream.Upstream // Configured upstreams
+ DomainsReservedUpstreams map[string][]upstream.Upstream // Map of domains and lists of configured upstreams
+ Filters []dnsfilter.Filter // A list of filters to use
FilteringConfig
TLSConfig
@@ -156,15 +157,16 @@ func (s *Server) startInternal(config *ServerConfig) error {
})
proxyConfig := proxy.Config{
- UDPListenAddr: s.UDPListenAddr,
- TCPListenAddr: s.TCPListenAddr,
- Ratelimit: s.Ratelimit,
- RatelimitWhitelist: s.RatelimitWhitelist,
- RefuseAny: s.RefuseAny,
- CacheEnabled: true,
- Upstreams: s.Upstreams,
- Handler: s.handleDNSRequest,
- AllServers: s.AllServers,
+ UDPListenAddr: s.UDPListenAddr,
+ TCPListenAddr: s.TCPListenAddr,
+ Ratelimit: s.Ratelimit,
+ RatelimitWhitelist: s.RatelimitWhitelist,
+ RefuseAny: s.RefuseAny,
+ CacheEnabled: true,
+ Upstreams: s.Upstreams,
+ DomainsReservedUpstreams: s.DomainsReservedUpstreams,
+ Handler: s.handleDNSRequest,
+ AllServers: s.AllServers,
}
if s.TLSListenAddr != nil && s.CertificateChain != "" && s.PrivateKey != "" {
diff --git a/go.mod b/go.mod
index 01918844..dab19303 100644
--- a/go.mod
+++ b/go.mod
@@ -3,10 +3,10 @@ module github.com/AdguardTeam/AdGuardHome
go 1.12
require (
- github.com/AdguardTeam/dnsproxy v0.11.2
+ github.com/AdguardTeam/dnsproxy v0.12.0
github.com/AdguardTeam/golibs v0.1.3
github.com/StackExchange/wmi v0.0.0-20180725035823-b12b22c5341f // indirect
- github.com/bluele/gcache v0.0.0-20171010155617-472614239ac7
+ github.com/bluele/gcache v0.0.0-20190203144525-2016d595ccb0
github.com/go-ole/go-ole v1.2.1 // indirect
github.com/go-test/deep v1.0.1
github.com/gobuffalo/packr v1.19.0
diff --git a/go.sum b/go.sum
index fbd89b79..cec7fae6 100644
--- a/go.sum
+++ b/go.sum
@@ -1,6 +1,6 @@
-github.com/AdguardTeam/dnsproxy v0.11.2 h1:S/Ag2q9qoZsmW1fvMohPZP7/5amEtz8NmFCp8kxUalQ=
-github.com/AdguardTeam/dnsproxy v0.11.2/go.mod h1:EPp92b5cYR7HZpO+OQu6xC7AyhUoBaXW3sfa3exq/0I=
-github.com/AdguardTeam/golibs v0.1.0/go.mod h1:zhi6xGwK4cMpjDocybhhLgvcGkstiSIjlpKbvyxC5Yc=
+github.com/AdguardTeam/dnsproxy v0.12.0 h1:BPgv2PlH2u4xakFcaW4EqU3Visk1BNidrqGSgxe5Qzg=
+github.com/AdguardTeam/dnsproxy v0.12.0/go.mod h1:lcZM2QPwcWGEL3pz8RYy06nQdbjj4pr+94H45jnVSHg=
+github.com/AdguardTeam/golibs v0.1.2/go.mod h1:b0XkhgIcn2TxwX6C5AQMtpIFAgjPehNgxJErWkwA3ko=
github.com/AdguardTeam/golibs v0.1.3 h1:hmapdTtMtIk3T8eQDwTOLdqZLGDKNKk9325uC8z12xg=
github.com/AdguardTeam/golibs v0.1.3/go.mod h1:b0XkhgIcn2TxwX6C5AQMtpIFAgjPehNgxJErWkwA3ko=
github.com/StackExchange/wmi v0.0.0-20180725035823-b12b22c5341f h1:5ZfJxyXo8KyX8DgGXC5B7ILL8y51fci/qYz2B4j8iLY=
@@ -15,8 +15,8 @@ github.com/ameshkov/dnsstamps v1.0.1 h1:LhGvgWDzhNJh+kBQd/AfUlq1vfVe109huiXw4Jhn
github.com/ameshkov/dnsstamps v1.0.1/go.mod h1:Ii3eUu73dx4Vw5O4wjzmT5+lkCwovjzaEZZ4gKyIH5A=
github.com/beefsack/go-rate v0.0.0-20180408011153-efa7637bb9b6 h1:KXlsf+qt/X5ttPGEjR0tPH1xaWWoKBEg9Q1THAj2h3I=
github.com/beefsack/go-rate v0.0.0-20180408011153-efa7637bb9b6/go.mod h1:6YNgTHLutezwnBvyneBbwvB8C82y3dcoOj5EQJIdGXA=
-github.com/bluele/gcache v0.0.0-20171010155617-472614239ac7 h1:NpQ+gkFOH27AyDypSCJ/LdsIi/b4rdnEb1N5+IpFfYs=
-github.com/bluele/gcache v0.0.0-20171010155617-472614239ac7/go.mod h1:8c4/i2VlovMO2gBnHGQPN5EJw+H0lx1u/5p+cgsXtCk=
+github.com/bluele/gcache v0.0.0-20190203144525-2016d595ccb0 h1:vUdUwmQLnT/yuk8PsDhhMVkrfr4aMdcv/0GWzIqOjEY=
+github.com/bluele/gcache v0.0.0-20190203144525-2016d595ccb0/go.mod h1:8c4/i2VlovMO2gBnHGQPN5EJw+H0lx1u/5p+cgsXtCk=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-ole/go-ole v1.2.1 h1:2lOsA72HgjxAuMlKpFiCbHTvu44PIVkZ5hqm3RSdI/E=