* move getDNSAddresses() and "/dns-query" handler to DNS module
This commit is contained in:
parent
db30f27c8f
commit
dd7d9dc334
|
@ -179,6 +179,7 @@ type ServerConfig struct {
|
||||||
|
|
||||||
FilteringConfig
|
FilteringConfig
|
||||||
TLSConfig
|
TLSConfig
|
||||||
|
TLSAllowUnencryptedDOH bool
|
||||||
|
|
||||||
TLSv12Roots *x509.CertPool // list of root CAs for TLSv1.2
|
TLSv12Roots *x509.CertPool // list of root CAs for TLSv1.2
|
||||||
|
|
||||||
|
|
|
@ -376,6 +376,20 @@ func checkDNS(input string, bootstrap []string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleDOH(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if !s.conf.TLSAllowUnencryptedDOH && r.TLS == nil {
|
||||||
|
httpError(r, w, http.StatusNotFound, "Not Found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !s.IsRunning() {
|
||||||
|
httpError(r, w, http.StatusInternalServerError, "DNS server is not running")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) registerHandlers() {
|
func (s *Server) registerHandlers() {
|
||||||
s.conf.HTTPRegister("GET", "/control/dns_info", s.handleGetConfig)
|
s.conf.HTTPRegister("GET", "/control/dns_info", s.handleGetConfig)
|
||||||
s.conf.HTTPRegister("POST", "/control/dns_config", s.handleSetConfig)
|
s.conf.HTTPRegister("POST", "/control/dns_config", s.handleSetConfig)
|
||||||
|
@ -384,4 +398,6 @@ func (s *Server) registerHandlers() {
|
||||||
|
|
||||||
s.conf.HTTPRegister("GET", "/control/access/list", s.handleAccessList)
|
s.conf.HTTPRegister("GET", "/control/access/list", s.handleAccessList)
|
||||||
s.conf.HTTPRegister("POST", "/control/access/set", s.handleAccessSet)
|
s.conf.HTTPRegister("POST", "/control/access/set", s.handleAccessSet)
|
||||||
|
|
||||||
|
s.conf.HTTPRegister("", "/dns-query", s.handleDOH)
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,7 +10,6 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||||
"github.com/AdguardTeam/AdGuardHome/util"
|
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/NYTimes/gziphandler"
|
"github.com/NYTimes/gziphandler"
|
||||||
)
|
)
|
||||||
|
@ -42,46 +41,6 @@ func addDNSAddress(dnsAddresses *[]string, addr string) {
|
||||||
*dnsAddresses = append(*dnsAddresses, addr)
|
*dnsAddresses = append(*dnsAddresses, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the list of DNS addresses the server is listening on
|
|
||||||
func getDNSAddresses() []string {
|
|
||||||
dnsAddresses := []string{}
|
|
||||||
|
|
||||||
if config.DNS.BindHost == "0.0.0.0" {
|
|
||||||
ifaces, e := util.GetValidNetInterfacesForWeb()
|
|
||||||
if e != nil {
|
|
||||||
log.Error("Couldn't get network interfaces: %v", e)
|
|
||||||
return []string{}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, iface := range ifaces {
|
|
||||||
for _, addr := range iface.Addresses {
|
|
||||||
addDNSAddress(&dnsAddresses, addr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
addDNSAddress(&dnsAddresses, config.DNS.BindHost)
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.TLS.Enabled && len(config.TLS.ServerName) != 0 {
|
|
||||||
|
|
||||||
if config.TLS.PortHTTPS != 0 {
|
|
||||||
addr := config.TLS.ServerName
|
|
||||||
if config.TLS.PortHTTPS != 443 {
|
|
||||||
addr = fmt.Sprintf("%s:%d", addr, config.TLS.PortHTTPS)
|
|
||||||
}
|
|
||||||
addr = fmt.Sprintf("https://%s/dns-query", addr)
|
|
||||||
dnsAddresses = append(dnsAddresses, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.TLS.PortDNSOverTLS != 0 {
|
|
||||||
addr := fmt.Sprintf("tls://%s:%d", config.TLS.ServerName, config.TLS.PortDNSOverTLS)
|
|
||||||
dnsAddresses = append(dnsAddresses, addr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return dnsAddresses
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleStatus(w http.ResponseWriter, r *http.Request) {
|
func handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||||
c := dnsforward.FilteringConfig{}
|
c := dnsforward.FilteringConfig{}
|
||||||
if Context.dnsServer != nil {
|
if Context.dnsServer != nil {
|
||||||
|
@ -143,11 +102,15 @@ func registerControlHandlers() {
|
||||||
|
|
||||||
httpRegister("GET", "/control/profile", handleGetProfile)
|
httpRegister("GET", "/control/profile", handleGetProfile)
|
||||||
RegisterAuthHandlers()
|
RegisterAuthHandlers()
|
||||||
|
|
||||||
http.HandleFunc("/dns-query", postInstall(handleDOH))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func httpRegister(method string, url string, handler func(http.ResponseWriter, *http.Request)) {
|
func httpRegister(method string, url string, handler func(http.ResponseWriter, *http.Request)) {
|
||||||
|
if len(method) == 0 {
|
||||||
|
// "/dns-query" handler doesn't need auth, gzip and isn't restricted by 1 HTTP method
|
||||||
|
http.HandleFunc(url, postInstall(handler))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
http.Handle(url, postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureHandler(method, handler)))))
|
http.Handle(url, postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureHandler(method, handler)))))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
45
home/dns.go
45
home/dns.go
|
@ -9,6 +9,7 @@ import (
|
||||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||||
"github.com/AdguardTeam/AdGuardHome/querylog"
|
"github.com/AdguardTeam/AdGuardHome/querylog"
|
||||||
"github.com/AdguardTeam/AdGuardHome/stats"
|
"github.com/AdguardTeam/AdGuardHome/stats"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/util"
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
@ -166,15 +167,57 @@ func generateServerConfig() dnsforward.ServerConfig {
|
||||||
Port: tlsConf.PortDNSOverTLS,
|
Port: tlsConf.PortDNSOverTLS,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
newconfig.TLSAllowUnencryptedDOH = tlsConf.AllowUnencryptedDOH
|
|
||||||
}
|
}
|
||||||
newconfig.TLSv12Roots = Context.tlsRoots
|
newconfig.TLSv12Roots = Context.tlsRoots
|
||||||
|
newconfig.TLSAllowUnencryptedDOH = tlsConf.AllowUnencryptedDOH
|
||||||
|
|
||||||
newconfig.FilterHandler = applyAdditionalFiltering
|
newconfig.FilterHandler = applyAdditionalFiltering
|
||||||
newconfig.GetUpstreamsByClient = getUpstreamsByClient
|
newconfig.GetUpstreamsByClient = getUpstreamsByClient
|
||||||
return newconfig
|
return newconfig
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get the list of DNS addresses the server is listening on
|
||||||
|
func getDNSAddresses() []string {
|
||||||
|
dnsAddresses := []string{}
|
||||||
|
|
||||||
|
if config.DNS.BindHost == "0.0.0.0" {
|
||||||
|
ifaces, e := util.GetValidNetInterfacesForWeb()
|
||||||
|
if e != nil {
|
||||||
|
log.Error("Couldn't get network interfaces: %v", e)
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, iface := range ifaces {
|
||||||
|
for _, addr := range iface.Addresses {
|
||||||
|
addDNSAddress(&dnsAddresses, addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
addDNSAddress(&dnsAddresses, config.DNS.BindHost)
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConf := tlsConfigSettings{}
|
||||||
|
Context.tls.WriteDiskConfig(&tlsConf)
|
||||||
|
if tlsConf.Enabled && len(tlsConf.ServerName) != 0 {
|
||||||
|
|
||||||
|
if tlsConf.PortHTTPS != 0 {
|
||||||
|
addr := tlsConf.ServerName
|
||||||
|
if tlsConf.PortHTTPS != 443 {
|
||||||
|
addr = fmt.Sprintf("%s:%d", addr, tlsConf.PortHTTPS)
|
||||||
|
}
|
||||||
|
addr = fmt.Sprintf("https://%s/dns-query", addr)
|
||||||
|
dnsAddresses = append(dnsAddresses, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tlsConf.PortDNSOverTLS != 0 {
|
||||||
|
addr := fmt.Sprintf("tls://%s:%d", tlsConf.ServerName, tlsConf.PortDNSOverTLS)
|
||||||
|
dnsAddresses = append(dnsAddresses, addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return dnsAddresses
|
||||||
|
}
|
||||||
|
|
||||||
func getUpstreamsByClient(clientAddr string) []upstream.Upstream {
|
func getUpstreamsByClient(clientAddr string) []upstream.Upstream {
|
||||||
return Context.clients.FindUpstreams(clientAddr)
|
return Context.clients.FindUpstreams(clientAddr)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue