Merge pull request #40 in DNS/adguard-dns from feature/347 to master
* commit 'ff86d6b7dc31e463651c11f02330630e35676e05': Set default servers to tls://1.1.1.1 and tls://1.0.0.1
This commit is contained in:
commit
de7b2d5e6b
@ -56,7 +56,7 @@ type filter struct {
|
||||
LastUpdated time.Time `json:"last_updated" yaml:"-"`
|
||||
}
|
||||
|
||||
var defaultDNS = []string{"1.1.1.1", "1.0.0.1"}
|
||||
var defaultDNS = []string{"tls://1.1.1.1", "tls://1.0.0.1"}
|
||||
|
||||
// initialize to default values, will be changed later when reading config or parsing command line
|
||||
var config = configuration{
|
||||
|
122
control.go
122
control.go
@ -506,16 +506,25 @@ func handleStatsTop(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
func httpError(w http.ResponseWriter, code int, format string, args ...interface{}) {
|
||||
text := fmt.Sprintf(format, args...)
|
||||
http.Error(w, text, code)
|
||||
}
|
||||
|
||||
func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
errortext := fmt.Sprintf("Failed to read request body: %s", err)
|
||||
log.Println(errortext)
|
||||
http.Error(w, errortext, 400)
|
||||
http.Error(w, errortext, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
// if empty body -- user is asking for default servers
|
||||
hosts := parseIPsOptionalPort(string(body))
|
||||
hosts, err := sanitiseDNSServers(string(body))
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Invalid DNS servers were given: %s", err)
|
||||
return
|
||||
}
|
||||
if len(hosts) == 0 {
|
||||
config.CoreDNS.UpstreamDNS = defaultDNS
|
||||
} else {
|
||||
@ -584,17 +593,10 @@ func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
func checkDNS(host string) error {
|
||||
host = appendPortIfMissing(host)
|
||||
{
|
||||
h, _, err := net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ip := net.ParseIP(h)
|
||||
if ip == nil {
|
||||
return fmt.Errorf("Invalid DNS server field: %s", h)
|
||||
}
|
||||
func checkDNS(input string) error {
|
||||
input, err := sanitizeDNSServer(input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req := dns.Msg{}
|
||||
@ -603,45 +605,91 @@ func checkDNS(host string) error {
|
||||
req.Question = []dns.Question{
|
||||
{"google-public-dns-a.google.com.", dns.TypeA, dns.ClassINET},
|
||||
}
|
||||
resp, err := dns.Exchange(&req, host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Couldn't communicate with DNS server %s: %s", host, err)
|
||||
|
||||
prefix, host := splitDNSServerPrefixServer(input)
|
||||
|
||||
c := dns.Client{
|
||||
Timeout: time.Minute,
|
||||
}
|
||||
switch prefix {
|
||||
case "tls://":
|
||||
c.Net = "tcp-tls"
|
||||
}
|
||||
|
||||
resp, rtt, err := c.Exchange(&req, host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Couldn't communicate with DNS server %s: %s", input, err)
|
||||
}
|
||||
trace("exchange with %s took %v", input, rtt)
|
||||
if len(resp.Answer) != 1 {
|
||||
return fmt.Errorf("DNS server %s returned wrong answer", host)
|
||||
return fmt.Errorf("DNS server %s returned wrong answer", input)
|
||||
}
|
||||
if t, ok := resp.Answer[0].(*dns.A); ok {
|
||||
if !net.IPv4(8, 8, 8, 8).Equal(t.A) {
|
||||
return fmt.Errorf("DNS server %s returned wrong answer: %v", host, t.A)
|
||||
return fmt.Errorf("DNS server %s returned wrong answer: %v", input, t.A)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func appendPortIfMissing(input string) string {
|
||||
func sanitiseDNSServers(input string) ([]string, error) {
|
||||
fields := strings.Fields(input)
|
||||
hosts := []string{}
|
||||
for _, field := range fields {
|
||||
sanitized, err := sanitizeDNSServer(field)
|
||||
if err != nil {
|
||||
return hosts, err
|
||||
}
|
||||
hosts = append(hosts, sanitized)
|
||||
}
|
||||
return hosts, nil
|
||||
}
|
||||
|
||||
func getDNSServerPrefix(input string) string {
|
||||
prefix := ""
|
||||
switch {
|
||||
case strings.HasPrefix(input, "dns://"):
|
||||
prefix = "dns://"
|
||||
case strings.HasPrefix(input, "tls://"):
|
||||
prefix = "tls://"
|
||||
}
|
||||
return prefix
|
||||
}
|
||||
|
||||
func splitDNSServerPrefixServer(input string) (string, string) {
|
||||
prefix := getDNSServerPrefix(input)
|
||||
host := strings.TrimPrefix(input, prefix)
|
||||
return prefix, host
|
||||
}
|
||||
|
||||
func sanitizeDNSServer(input string) (string, error) {
|
||||
prefix, host := splitDNSServerPrefixServer(input)
|
||||
host = appendPortIfMissing(prefix, host)
|
||||
{
|
||||
h, _, err := net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
ip := net.ParseIP(h)
|
||||
if ip == nil {
|
||||
return "", fmt.Errorf("Invalid DNS server field: %s", h)
|
||||
}
|
||||
}
|
||||
return prefix + host, nil
|
||||
}
|
||||
|
||||
func appendPortIfMissing(prefix, input string) string {
|
||||
port := "53"
|
||||
switch prefix {
|
||||
case "tls://":
|
||||
port = "853"
|
||||
}
|
||||
_, _, err := net.SplitHostPort(input)
|
||||
if err == nil {
|
||||
return input
|
||||
}
|
||||
return net.JoinHostPort(input, "53")
|
||||
}
|
||||
|
||||
func parseIPsOptionalPort(input string) []string {
|
||||
fields := strings.Fields(input)
|
||||
hosts := []string{}
|
||||
for _, field := range fields {
|
||||
_, _, err := net.SplitHostPort(field)
|
||||
if err != nil {
|
||||
ip := net.ParseIP(field)
|
||||
if ip == nil {
|
||||
log.Printf("Invalid DNS server field: %s\n", field)
|
||||
continue
|
||||
}
|
||||
}
|
||||
hosts = append(hosts, field)
|
||||
}
|
||||
return hosts
|
||||
return net.JoinHostPort(input, port)
|
||||
}
|
||||
|
||||
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"runtime"
|
||||
"sort"
|
||||
@ -259,5 +260,5 @@ func trace(format string, args ...interface{}) {
|
||||
if len(text) == 0 || text[len(text)-1] != '\n' {
|
||||
buf.WriteRune('\n')
|
||||
}
|
||||
fmt.Print(buf.String())
|
||||
fmt.Fprint(os.Stderr, buf.String())
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user