diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index 6095af08..64949ce1 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -179,6 +179,7 @@ type ServerConfig struct { FilteringConfig TLSConfig + TLSAllowUnencryptedDOH bool TLSv12Roots *x509.CertPool // list of root CAs for TLSv1.2 diff --git a/dnsforward/dnsforward_http.go b/dnsforward/dnsforward_http.go index 088add56..111fefbd 100644 --- a/dnsforward/dnsforward_http.go +++ b/dnsforward/dnsforward_http.go @@ -376,6 +376,20 @@ func checkDNS(input string, bootstrap []string) error { return nil } +func (s *Server) handleDOH(w http.ResponseWriter, r *http.Request) { + if !s.conf.TLSAllowUnencryptedDOH && r.TLS == nil { + httpError(r, w, http.StatusNotFound, "Not Found") + return + } + + if !s.IsRunning() { + httpError(r, w, http.StatusInternalServerError, "DNS server is not running") + return + } + + s.ServeHTTP(w, r) +} + func (s *Server) registerHandlers() { s.conf.HTTPRegister("GET", "/control/dns_info", s.handleGetConfig) s.conf.HTTPRegister("POST", "/control/dns_config", s.handleSetConfig) @@ -384,4 +398,6 @@ func (s *Server) registerHandlers() { s.conf.HTTPRegister("GET", "/control/access/list", s.handleAccessList) s.conf.HTTPRegister("POST", "/control/access/set", s.handleAccessSet) + + s.conf.HTTPRegister("", "/dns-query", s.handleDOH) } diff --git a/home/control.go b/home/control.go index 4da42ca9..4d1dbe0c 100644 --- a/home/control.go +++ b/home/control.go @@ -10,7 +10,6 @@ import ( "strings" "github.com/AdguardTeam/AdGuardHome/dnsforward" - "github.com/AdguardTeam/AdGuardHome/util" "github.com/AdguardTeam/golibs/log" "github.com/NYTimes/gziphandler" ) @@ -42,46 +41,6 @@ func addDNSAddress(dnsAddresses *[]string, addr string) { *dnsAddresses = append(*dnsAddresses, addr) } -// Get the list of DNS addresses the server is listening on -func getDNSAddresses() []string { - dnsAddresses := []string{} - - if config.DNS.BindHost == "0.0.0.0" { - ifaces, e := util.GetValidNetInterfacesForWeb() - if e != nil { - log.Error("Couldn't get network interfaces: %v", e) - return []string{} - } - - for _, iface := range ifaces { - for _, addr := range iface.Addresses { - addDNSAddress(&dnsAddresses, addr) - } - } - } else { - addDNSAddress(&dnsAddresses, config.DNS.BindHost) - } - - if config.TLS.Enabled && len(config.TLS.ServerName) != 0 { - - if config.TLS.PortHTTPS != 0 { - addr := config.TLS.ServerName - if config.TLS.PortHTTPS != 443 { - addr = fmt.Sprintf("%s:%d", addr, config.TLS.PortHTTPS) - } - addr = fmt.Sprintf("https://%s/dns-query", addr) - dnsAddresses = append(dnsAddresses, addr) - } - - if config.TLS.PortDNSOverTLS != 0 { - addr := fmt.Sprintf("tls://%s:%d", config.TLS.ServerName, config.TLS.PortDNSOverTLS) - dnsAddresses = append(dnsAddresses, addr) - } - } - - return dnsAddresses -} - func handleStatus(w http.ResponseWriter, r *http.Request) { c := dnsforward.FilteringConfig{} if Context.dnsServer != nil { @@ -143,11 +102,15 @@ func registerControlHandlers() { httpRegister("GET", "/control/profile", handleGetProfile) RegisterAuthHandlers() - - http.HandleFunc("/dns-query", postInstall(handleDOH)) } func httpRegister(method string, url string, handler func(http.ResponseWriter, *http.Request)) { + if len(method) == 0 { + // "/dns-query" handler doesn't need auth, gzip and isn't restricted by 1 HTTP method + http.HandleFunc(url, postInstall(handler)) + return + } + http.Handle(url, postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureHandler(method, handler))))) } diff --git a/home/dns.go b/home/dns.go index 3b415a8e..29cec636 100644 --- a/home/dns.go +++ b/home/dns.go @@ -9,6 +9,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/dnsforward" "github.com/AdguardTeam/AdGuardHome/querylog" "github.com/AdguardTeam/AdGuardHome/stats" + "github.com/AdguardTeam/AdGuardHome/util" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/log" @@ -166,15 +167,57 @@ func generateServerConfig() dnsforward.ServerConfig { Port: tlsConf.PortDNSOverTLS, } } - newconfig.TLSAllowUnencryptedDOH = tlsConf.AllowUnencryptedDOH } newconfig.TLSv12Roots = Context.tlsRoots + newconfig.TLSAllowUnencryptedDOH = tlsConf.AllowUnencryptedDOH newconfig.FilterHandler = applyAdditionalFiltering newconfig.GetUpstreamsByClient = getUpstreamsByClient return newconfig } +// Get the list of DNS addresses the server is listening on +func getDNSAddresses() []string { + dnsAddresses := []string{} + + if config.DNS.BindHost == "0.0.0.0" { + ifaces, e := util.GetValidNetInterfacesForWeb() + if e != nil { + log.Error("Couldn't get network interfaces: %v", e) + return []string{} + } + + for _, iface := range ifaces { + for _, addr := range iface.Addresses { + addDNSAddress(&dnsAddresses, addr) + } + } + } else { + addDNSAddress(&dnsAddresses, config.DNS.BindHost) + } + + tlsConf := tlsConfigSettings{} + Context.tls.WriteDiskConfig(&tlsConf) + if tlsConf.Enabled && len(tlsConf.ServerName) != 0 { + + if tlsConf.PortHTTPS != 0 { + addr := tlsConf.ServerName + if tlsConf.PortHTTPS != 443 { + addr = fmt.Sprintf("%s:%d", addr, tlsConf.PortHTTPS) + } + addr = fmt.Sprintf("https://%s/dns-query", addr) + dnsAddresses = append(dnsAddresses, addr) + } + + if tlsConf.PortDNSOverTLS != 0 { + addr := fmt.Sprintf("tls://%s:%d", tlsConf.ServerName, tlsConf.PortDNSOverTLS) + dnsAddresses = append(dnsAddresses, addr) + } + } + + return dnsAddresses +} + func getUpstreamsByClient(clientAddr string) []upstream.Upstream { return Context.clients.FindUpstreams(clientAddr) }