* move getDNSAddresses() and "/dns-query" handler to DNS module
This commit is contained in:
parent
db30f27c8f
commit
dd7d9dc334
@ -179,6 +179,7 @@ type ServerConfig struct {
|
||||
|
||||
FilteringConfig
|
||||
TLSConfig
|
||||
TLSAllowUnencryptedDOH bool
|
||||
|
||||
TLSv12Roots *x509.CertPool // list of root CAs for TLSv1.2
|
||||
|
||||
|
@ -376,6 +376,20 @@ func checkDNS(input string, bootstrap []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleDOH(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.conf.TLSAllowUnencryptedDOH && r.TLS == nil {
|
||||
httpError(r, w, http.StatusNotFound, "Not Found")
|
||||
return
|
||||
}
|
||||
|
||||
if !s.IsRunning() {
|
||||
httpError(r, w, http.StatusInternalServerError, "DNS server is not running")
|
||||
return
|
||||
}
|
||||
|
||||
s.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (s *Server) registerHandlers() {
|
||||
s.conf.HTTPRegister("GET", "/control/dns_info", s.handleGetConfig)
|
||||
s.conf.HTTPRegister("POST", "/control/dns_config", s.handleSetConfig)
|
||||
@ -384,4 +398,6 @@ func (s *Server) registerHandlers() {
|
||||
|
||||
s.conf.HTTPRegister("GET", "/control/access/list", s.handleAccessList)
|
||||
s.conf.HTTPRegister("POST", "/control/access/set", s.handleAccessSet)
|
||||
|
||||
s.conf.HTTPRegister("", "/dns-query", s.handleDOH)
|
||||
}
|
||||
|
@ -10,7 +10,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/util"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/NYTimes/gziphandler"
|
||||
)
|
||||
@ -42,46 +41,6 @@ func addDNSAddress(dnsAddresses *[]string, addr string) {
|
||||
*dnsAddresses = append(*dnsAddresses, addr)
|
||||
}
|
||||
|
||||
// Get the list of DNS addresses the server is listening on
|
||||
func getDNSAddresses() []string {
|
||||
dnsAddresses := []string{}
|
||||
|
||||
if config.DNS.BindHost == "0.0.0.0" {
|
||||
ifaces, e := util.GetValidNetInterfacesForWeb()
|
||||
if e != nil {
|
||||
log.Error("Couldn't get network interfaces: %v", e)
|
||||
return []string{}
|
||||
}
|
||||
|
||||
for _, iface := range ifaces {
|
||||
for _, addr := range iface.Addresses {
|
||||
addDNSAddress(&dnsAddresses, addr)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
addDNSAddress(&dnsAddresses, config.DNS.BindHost)
|
||||
}
|
||||
|
||||
if config.TLS.Enabled && len(config.TLS.ServerName) != 0 {
|
||||
|
||||
if config.TLS.PortHTTPS != 0 {
|
||||
addr := config.TLS.ServerName
|
||||
if config.TLS.PortHTTPS != 443 {
|
||||
addr = fmt.Sprintf("%s:%d", addr, config.TLS.PortHTTPS)
|
||||
}
|
||||
addr = fmt.Sprintf("https://%s/dns-query", addr)
|
||||
dnsAddresses = append(dnsAddresses, addr)
|
||||
}
|
||||
|
||||
if config.TLS.PortDNSOverTLS != 0 {
|
||||
addr := fmt.Sprintf("tls://%s:%d", config.TLS.ServerName, config.TLS.PortDNSOverTLS)
|
||||
dnsAddresses = append(dnsAddresses, addr)
|
||||
}
|
||||
}
|
||||
|
||||
return dnsAddresses
|
||||
}
|
||||
|
||||
func handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
c := dnsforward.FilteringConfig{}
|
||||
if Context.dnsServer != nil {
|
||||
@ -143,11 +102,15 @@ func registerControlHandlers() {
|
||||
|
||||
httpRegister("GET", "/control/profile", handleGetProfile)
|
||||
RegisterAuthHandlers()
|
||||
|
||||
http.HandleFunc("/dns-query", postInstall(handleDOH))
|
||||
}
|
||||
|
||||
func httpRegister(method string, url string, handler func(http.ResponseWriter, *http.Request)) {
|
||||
if len(method) == 0 {
|
||||
// "/dns-query" handler doesn't need auth, gzip and isn't restricted by 1 HTTP method
|
||||
http.HandleFunc(url, postInstall(handler))
|
||||
return
|
||||
}
|
||||
|
||||
http.Handle(url, postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureHandler(method, handler)))))
|
||||
}
|
||||
|
||||
|
45
home/dns.go
45
home/dns.go
@ -9,6 +9,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/stats"
|
||||
"github.com/AdguardTeam/AdGuardHome/util"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
@ -166,15 +167,57 @@ func generateServerConfig() dnsforward.ServerConfig {
|
||||
Port: tlsConf.PortDNSOverTLS,
|
||||
}
|
||||
}
|
||||
newconfig.TLSAllowUnencryptedDOH = tlsConf.AllowUnencryptedDOH
|
||||
}
|
||||
newconfig.TLSv12Roots = Context.tlsRoots
|
||||
newconfig.TLSAllowUnencryptedDOH = tlsConf.AllowUnencryptedDOH
|
||||
|
||||
newconfig.FilterHandler = applyAdditionalFiltering
|
||||
newconfig.GetUpstreamsByClient = getUpstreamsByClient
|
||||
return newconfig
|
||||
}
|
||||
|
||||
// Get the list of DNS addresses the server is listening on
|
||||
func getDNSAddresses() []string {
|
||||
dnsAddresses := []string{}
|
||||
|
||||
if config.DNS.BindHost == "0.0.0.0" {
|
||||
ifaces, e := util.GetValidNetInterfacesForWeb()
|
||||
if e != nil {
|
||||
log.Error("Couldn't get network interfaces: %v", e)
|
||||
return []string{}
|
||||
}
|
||||
|
||||
for _, iface := range ifaces {
|
||||
for _, addr := range iface.Addresses {
|
||||
addDNSAddress(&dnsAddresses, addr)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
addDNSAddress(&dnsAddresses, config.DNS.BindHost)
|
||||
}
|
||||
|
||||
tlsConf := tlsConfigSettings{}
|
||||
Context.tls.WriteDiskConfig(&tlsConf)
|
||||
if tlsConf.Enabled && len(tlsConf.ServerName) != 0 {
|
||||
|
||||
if tlsConf.PortHTTPS != 0 {
|
||||
addr := tlsConf.ServerName
|
||||
if tlsConf.PortHTTPS != 443 {
|
||||
addr = fmt.Sprintf("%s:%d", addr, tlsConf.PortHTTPS)
|
||||
}
|
||||
addr = fmt.Sprintf("https://%s/dns-query", addr)
|
||||
dnsAddresses = append(dnsAddresses, addr)
|
||||
}
|
||||
|
||||
if tlsConf.PortDNSOverTLS != 0 {
|
||||
addr := fmt.Sprintf("tls://%s:%d", tlsConf.ServerName, tlsConf.PortDNSOverTLS)
|
||||
dnsAddresses = append(dnsAddresses, addr)
|
||||
}
|
||||
}
|
||||
|
||||
return dnsAddresses
|
||||
}
|
||||
|
||||
func getUpstreamsByClient(clientAddr string) []upstream.Upstream {
|
||||
return Context.clients.FindUpstreams(clientAddr)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user