Pull request: home: fix dns address fallback

Closes #2868.

Squashed commit of the following:

commit 7497b0d80233fa0f0fbdc94a85007d39566eea73
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Wed Mar 24 14:23:41 2021 +0300

    home: fix specified ip collecting

commit 7b1dfa69f4edeb3e07cd1067f77ff8519bcdbe1c
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Wed Mar 24 14:01:25 2021 +0300

    home: fix dns address fallback
This commit is contained in:
Ainar Garipov 2021-03-24 14:52:37 +03:00
parent 5d0d32b926
commit e10a3fa4b3
2 changed files with 86 additions and 51 deletions

View File

@ -34,15 +34,79 @@ func httpError(w http.ResponseWriter, code int, format string, args ...interface
http.Error(w, text, code)
}
// ---------------
// dns run control
// ---------------
func addDNSAddress(dnsAddresses *[]string, addr net.IP) {
// appendDNSAddrs is a convenient helper for appending a formatted form of DNS
// addresses to a slice of strings.
func appendDNSAddrs(dst []string, addrs ...net.IP) (res []string) {
for _, addr := range addrs {
hostport := addr.String()
if config.DNS.Port != 53 {
hostport = net.JoinHostPort(hostport, strconv.Itoa(config.DNS.Port))
}
*dnsAddresses = append(*dnsAddresses, hostport)
dst = append(dst, hostport)
}
return dst
}
// appendDNSAddrsWithIfaces formats and appends all DNS addresses from src to
// dst. It also adds the IP addresses of all network interfaces if src contains
// an unspecified IP addresss.
func appendDNSAddrsWithIfaces(dst []string, src []net.IP) (res []string, err error) {
ifacesAdded := false
for _, h := range src {
if !h.IsUnspecified() {
dst = appendDNSAddrs(dst, h)
continue
} else if ifacesAdded {
continue
}
// Add addresses of all network interfaces for addresses like
// "0.0.0.0" and "::".
var ifaces []*aghnet.NetInterface
ifaces, err = aghnet.GetValidNetInterfacesForWeb()
if err != nil {
return nil, fmt.Errorf("cannot get network interfaces: %w", err)
}
for _, iface := range ifaces {
dst = appendDNSAddrs(dst, iface.Addresses...)
}
ifacesAdded = true
}
return dst, nil
}
// collectDNSAddresses returns the list of DNS addresses the server is listening
// on, including the addresses on all interfaces in cases of unspecified IPs.
func collectDNSAddresses() (addrs []string, err error) {
if hosts := config.DNS.BindHosts; len(hosts) == 0 {
addrs = appendDNSAddrs(addrs, net.IP{127, 0, 0, 1})
} else {
addrs, err = appendDNSAddrsWithIfaces(addrs, hosts)
if err != nil {
return nil, fmt.Errorf("collecting dns addresses: %w", err)
}
}
de := getDNSEncryption()
if de.https != "" {
addrs = append(addrs, de.https)
}
if de.tls != "" {
addrs = append(addrs, de.tls)
}
if de.quic != "" {
addrs = append(addrs, de.quic)
}
return addrs, nil
}
// statusResponse is a response for /control/status endpoint.
@ -60,8 +124,17 @@ type statusResponse struct {
}
func handleStatus(w http.ResponseWriter, _ *http.Request) {
dnsAddrs, err := collectDNSAddresses()
if err != nil {
// Don't add a lot of formatting, since the error is already
// wrapped by collectDNSAddresses.
httpError(w, http.StatusInternalServerError, "%s", err)
return
}
resp := statusResponse{
DNSAddrs: getDNSAddresses(),
DNSAddrs: dnsAddrs,
DNSPort: config.DNS.Port,
HTTPPort: config.BindPort,
IsRunning: isRunning(),
@ -82,9 +155,10 @@ func handleStatus(w http.ResponseWriter, _ *http.Request) {
}
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(resp)
err = json.NewEncoder(w).Encode(resp)
if err != nil {
httpError(w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return
}
}

View File

@ -9,7 +9,6 @@ import (
"strconv"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
@ -144,10 +143,8 @@ func ipsToUDPAddrs(ips []net.IP, port int) (udpAddrs []*net.UDPAddr) {
func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
dnsConf := config.DNS
hosts := dnsConf.BindHosts
for i, h := range hosts {
if h.IsUnspecified() {
hosts[i] = net.IP{127, 0, 0, 1}
}
if len(hosts) == 0 {
hosts = []net.IP{{127, 0, 0, 1}}
}
newConf = dnsforward.ServerConfig{
@ -268,42 +265,6 @@ func getDNSEncryption() (de dnsEncryption) {
return de
}
// Get the list of DNS addresses the server is listening on
func getDNSAddresses() (dnsAddrs []string) {
if hosts := config.DNS.BindHosts; len(hosts) == 0 || hosts[0].IsUnspecified() {
ifaces, e := aghnet.GetValidNetInterfacesForWeb()
if e != nil {
log.Error("Couldn't get network interfaces: %v", e)
return []string{}
}
for _, iface := range ifaces {
for _, addr := range iface.Addresses {
addDNSAddress(&dnsAddrs, addr)
}
}
} else {
for _, h := range hosts {
addDNSAddress(&dnsAddrs, h)
}
}
de := getDNSEncryption()
if de.https != "" {
dnsAddrs = append(dnsAddrs, de.https)
}
if de.tls != "" {
dnsAddrs = append(dnsAddrs, de.tls)
}
if de.quic != "" {
dnsAddrs = append(dnsAddrs, de.quic)
}
return dnsAddrs
}
// applyAdditionalFiltering adds additional client information and settings if
// the client has them.
func applyAdditionalFiltering(clientAddr net.IP, clientID string, setts *dnsfilter.RequestFilteringSettings) {