diff --git a/CHANGELOG.md b/CHANGELOG.md index fdb58899..78ea8efd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,7 @@ TODO(a.garipov): Remove this deprecation, if v0.108.0 is released before the Go ### Fixed +- Poor testing of domain-specific upstream servers ([#4074]). - Omitted aliases of hosts specified by another line within the OS's hosts file ([#4079]). @@ -37,6 +38,8 @@ TODO(a.garipov): Remove this deprecation, if v0.108.0 is released before the Go - Go 1.16 support. [#3057]: https://github.com/AdguardTeam/AdGuardHome/issues/3057 +[#4074]: https://github.com/AdguardTeam/AdGuardHome/issues/4074 +[#4079]: https://github.com/AdguardTeam/AdGuardHome/issues/4079 @@ -82,7 +85,6 @@ TODO(a.garipov): Remove this deprecation, if v0.108.0 is released before the Go [#4008]: https://github.com/AdguardTeam/AdGuardHome/issues/4008 [#4016]: https://github.com/AdguardTeam/AdGuardHome/issues/4016 [#4027]: https://github.com/AdguardTeam/AdGuardHome/issues/4027 -[#4079]: https://github.com/AdguardTeam/AdGuardHome/issues/4079 diff --git a/internal/dnsforward/http.go b/internal/dnsforward/http.go index 42a81946..b0de11b9 100644 --- a/internal/dnsforward/http.go +++ b/internal/dnsforward/http.go @@ -5,7 +5,6 @@ import ( "fmt" "net" "net/http" - "strconv" "strings" "time" @@ -192,22 +191,23 @@ func (req *dnsConfig) checkCacheTTL() bool { func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) { req := dnsConfig{} - dec := json.NewDecoder(r.Body) - if err := dec.Decode(&req); err != nil { + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "json Encode: %s", err) return } if req.Upstreams != nil { - if err := ValidateUpstreams(*req.Upstreams); err != nil { + if err = ValidateUpstreams(*req.Upstreams); err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "wrong upstreams specification: %s", err) return } } - if errBoot, err := req.checkBootstrap(); err != nil { + var errBoot string + if errBoot, err = req.checkBootstrap(); err != nil { aghhttp.Error( r, w, @@ -220,19 +220,16 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) { return } - if !req.checkBlockingMode() { + switch { + case !req.checkBlockingMode(): aghhttp.Error(r, w, http.StatusBadRequest, "blocking_mode: incorrect value") return - } - - if !req.checkUpstreamsMode() { + case !req.checkUpstreamsMode(): aghhttp.Error(r, w, http.StatusBadRequest, "upstream_mode: incorrect value") return - } - - if !req.checkCacheTTL() { + case !req.checkCacheTTL(): aghhttp.Error( r, w, @@ -241,13 +238,15 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) { ) return + default: + // Go on. } restart := s.setConfig(req) s.conf.ConfigModified() if restart { - if err := s.Reconfigure(nil); err != nil { + if err = s.Reconfigure(nil); err != nil { aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err) } } @@ -387,14 +386,14 @@ func ValidateUpstreams(upstreams []string) (err error) { var defaultUpstreamFound bool for _, u := range upstreams { - var ok bool - ok, err = validateUpstream(u) + var useDefault bool + useDefault, err = validateUpstream(u) if err != nil { return err } if !defaultUpstreamFound { - defaultUpstreamFound = ok + defaultUpstreamFound = useDefault } } @@ -407,50 +406,62 @@ func ValidateUpstreams(upstreams []string) (err error) { var protocols = []string{"tls://", "https://", "tcp://", "sdns://", "quic://"} -func validateUpstream(u string) (bool, error) { +func validateUpstream(u string) (useDefault bool, err error) { // Check if the user tries to specify upstream for domain. - u, useDefault, err := separateUpstream(u) + var isDomainSpec bool + u, isDomainSpec, err = separateUpstream(u) if err != nil { - return useDefault, err + return !isDomainSpec, err } - // The special server address '#' means "use the default servers" - if u == "#" && !useDefault { + // The special server address '#' means that default server must be used. + if useDefault = !isDomainSpec; u == "#" && isDomainSpec { return useDefault, nil } - // Check if the upstream has a valid protocol prefix + // Check if the upstream has a valid protocol prefix. + // + // TODO(e.burkov): Validate the domain name. for _, proto := range protocols { if strings.HasPrefix(u, proto) { return useDefault, nil } } - // Return error if the upstream contains '://' without any valid protocol if strings.Contains(u, "://") { - return useDefault, fmt.Errorf("wrong protocol") + return useDefault, errors.Error("wrong protocol") } - // Check if upstream is valid plain DNS - return useDefault, checkPlainDNS(u) + // Check if upstream is either an IP or IP with port. + if net.ParseIP(u) != nil { + return useDefault, nil + } else if _, err = netutil.ParseIPPort(u); err != nil { + return useDefault, err + } + + return useDefault, nil } // separateUpstream returns the upstream without the specified domains. -// useDefault is true when a default upstream must be used. -func separateUpstream(upstreamStr string) (upstream string, useDefault bool, err error) { - defer func() { err = errors.Annotate(err, "bad upstream for domain spec %q: %w", upstreamStr) }() - +// isDomainSpec is true when the upstream is domains-specific. +func separateUpstream(upstreamStr string) (upstream string, isDomainSpec bool, err error) { if !strings.HasPrefix(upstreamStr, "[/") { - return upstreamStr, true, nil + return upstreamStr, false, nil } + defer func() { err = errors.Annotate(err, "bad upstream for domain %q: %w", upstreamStr) }() parts := strings.Split(upstreamStr[2:], "/]") - if len(parts) != 2 { - return "", false, errors.Error("duplicated separator") + switch len(parts) { + case 2: + // Go on. + case 1: + return "", false, errors.Error("missing separator") + default: + return "", true, errors.Error("duplicated separator") } - domains := parts[0] - upstream = parts[1] + var domains string + domains, upstream = parts[0], parts[1] for i, host := range strings.Split(domains, "/") { if host == "" { continue @@ -458,36 +469,11 @@ func separateUpstream(upstreamStr string) (upstream string, useDefault bool, err err = netutil.ValidateDomainName(host) if err != nil { - return "", false, fmt.Errorf("domain at index %d: %w", i, err) + return "", true, fmt.Errorf("domain at index %d: %w", i, err) } } - return upstream, false, nil -} - -// checkPlainDNS checks if host is plain DNS -func checkPlainDNS(upstream string) error { - // Check if host is ip without port - if net.ParseIP(upstream) != nil { - return nil - } - - // Check if host is ip with port - ip, port, err := net.SplitHostPort(upstream) - if err != nil { - return err - } - - if net.ParseIP(ip) == nil { - return fmt.Errorf("%s is not a valid IP", ip) - } - - _, err = strconv.ParseInt(port, 0, 64) - if err != nil { - return fmt.Errorf("%s is not a valid port: %w", port, err) - } - - return nil + return upstream, true, nil } // excFunc is a signature of function to check if upstream exchanges correctly. @@ -515,12 +501,8 @@ func checkDNSUpstreamExc(u upstream.Upstream) (err error) { if len(reply.Answer) != 1 { return fmt.Errorf("wrong response") - } - - if t, ok := reply.Answer[0].(*dns.A); ok { - if !net.IPv4(8, 8, 8, 8).Equal(t.A) { - return fmt.Errorf("wrong response") - } + } else if a, ok := reply.Answer[0].(*dns.A); !ok || !a.A.Equal(net.IP{8, 8, 8, 8}) { + return fmt.Errorf("wrong response") } return nil @@ -555,7 +537,7 @@ func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFun // Separate upstream from domains list. var useDefault bool - if input, useDefault, err = separateUpstream(input); err != nil { + if useDefault, err = validateUpstream(input); err != nil { return fmt.Errorf("wrong upstream format: %w", err) } @@ -564,7 +546,7 @@ func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFun return nil } - if _, err = validateUpstream(input); err != nil { + if input, _, err = separateUpstream(input); err != nil { return fmt.Errorf("wrong upstream format: %w", err) } @@ -572,7 +554,8 @@ func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFun bootstrap = defaultBootstrap } - log.Debug("checking if dns server %q works...", input) + log.Debug("checking if upstream %s works", input) + var u upstream.Upstream u, err = upstream.AddressToUpstream(input, &upstream.Options{ Bootstrap: bootstrap, @@ -586,7 +569,7 @@ func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFun return fmt.Errorf("upstream %q fails to exchange: %w", input, err) } - log.Debug("dns %s works OK", input) + log.Debug("upstream %s is ok", input) return nil } @@ -620,9 +603,9 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { err = checkDNS(host, bootstraps, timeout, checkPrivateUpstreamExc) if err != nil { log.Info("%v", err) - // TODO(e.burkov): If passed upstream have already - // written an error above, we rewriting the error for - // it. These cases should be handled properly instead. + // TODO(e.burkov): If passed upstream have already written an error + // above, we rewriting the error for it. These cases should be + // handled properly instead. result[host] = err.Error() continue diff --git a/internal/dnsforward/http_test.go b/internal/dnsforward/http_test.go index f411a1b0..876cfdcf 100644 --- a/internal/dnsforward/http_test.go +++ b/internal/dnsforward/http_test.go @@ -184,7 +184,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) { wantSet: "", }, { name: "upstream_dns_bad", - wantSet: `wrong upstreams specification: address !!!: ` + + wantSet: `wrong upstreams specification: bad ipport address "!!!": address !!!: ` + `missing port in address`, }, { name: "bootstraps_bad", @@ -235,107 +235,117 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) { } func TestIsCommentOrEmpty(t *testing.T) { - assert.True(t, IsCommentOrEmpty("")) - assert.True(t, IsCommentOrEmpty("# comment")) - assert.False(t, IsCommentOrEmpty("1.2.3.4")) + for _, tc := range []struct { + want assert.BoolAssertionFunc + str string + }{{ + want: assert.True, + str: "", + }, { + want: assert.True, + str: "# comment", + }, { + want: assert.False, + str: "1.2.3.4", + }} { + tc.want(t, IsCommentOrEmpty(tc.str)) + } } -// TODO(a.garipov): Rewrite to check the actual error messages. func TestValidateUpstream(t *testing.T) { testCases := []struct { + wantDef assert.BoolAssertionFunc name string upstream string - valid bool - wantDef bool + wantErr string }{{ + wantDef: assert.True, name: "invalid", upstream: "1.2.3.4.5", - valid: false, - wantDef: false, + wantErr: `bad ipport address "1.2.3.4.5": address 1.2.3.4.5: missing port in address`, }, { + wantDef: assert.True, name: "invalid", upstream: "123.3.7m", - valid: false, - wantDef: false, + wantErr: `bad ipport address "123.3.7m": address 123.3.7m: missing port in address`, }, { + wantDef: assert.True, name: "invalid", upstream: "htttps://google.com/dns-query", - valid: false, - wantDef: false, + wantErr: `wrong protocol`, }, { + wantDef: assert.True, name: "invalid", upstream: "[/host.com]tls://dns.adguard.com", - valid: false, - wantDef: false, + wantErr: `bad upstream for domain "[/host.com]tls://dns.adguard.com": missing separator`, }, { + wantDef: assert.True, name: "invalid", upstream: "[host.ru]#", - valid: false, - wantDef: false, + wantErr: `bad ipport address "[host.ru]#": address [host.ru]#: missing port in address`, }, { + wantDef: assert.True, name: "valid_default", upstream: "1.1.1.1", - valid: true, - wantDef: true, + wantErr: ``, }, { + wantDef: assert.True, name: "valid_default", upstream: "tls://1.1.1.1", - valid: true, - wantDef: true, + wantErr: ``, }, { + wantDef: assert.True, name: "valid_default", upstream: "https://dns.adguard.com/dns-query", - valid: true, - wantDef: true, + wantErr: ``, }, { + wantDef: assert.True, name: "valid_default", upstream: "sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", - valid: true, - wantDef: true, + wantErr: ``, }, { + wantDef: assert.False, name: "valid", upstream: "[/host.com/]1.1.1.1", - valid: true, - wantDef: false, + wantErr: ``, }, { + wantDef: assert.False, name: "valid", upstream: "[//]tls://1.1.1.1", - valid: true, - wantDef: false, + wantErr: ``, }, { + wantDef: assert.False, name: "valid", upstream: "[/www.host.com/]#", - valid: true, - wantDef: false, + wantErr: ``, }, { + wantDef: assert.False, name: "valid", upstream: "[/host.com/google.com/]8.8.8.8", - valid: true, - wantDef: false, + wantErr: ``, }, { + wantDef: assert.False, name: "valid", upstream: "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", - valid: true, - wantDef: false, + wantErr: ``, }, { + wantDef: assert.False, name: "idna", upstream: "[/пример.рф/]8.8.8.8", - valid: true, - wantDef: false, + wantErr: ``, }, { + wantDef: assert.False, name: "bad_domain", upstream: "[/!/]8.8.8.8", - valid: false, - wantDef: false, + wantErr: `bad upstream for domain "[/!/]8.8.8.8": domain at index 0: ` + + `bad domain name "!": bad domain name label "!": bad domain name label rune '!'`, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { defaultUpstream, err := validateUpstream(tc.upstream) - require.Equal(t, tc.valid, err == nil) - if tc.valid { - assert.Equal(t, tc.wantDef, defaultUpstream) - } + testutil.AssertErrorMsg(t, tc.wantErr, err) + tc.wantDef(t, defaultUpstream) }) } } @@ -343,22 +353,19 @@ func TestValidateUpstream(t *testing.T) { func TestValidateUpstreamsSet(t *testing.T) { testCases := []struct { name string - msg string + wantErr string set []string - wantNil bool }{{ name: "empty", - msg: "empty upstreams array should be valid", + wantErr: ``, set: nil, - wantNil: true, }, { name: "comment", - msg: "comments should not be validated", + wantErr: ``, set: []string{"# comment"}, - wantNil: true, }, { - name: "valid_no_default", - msg: "there is no default upstream", + name: "valid_no_default", + wantErr: `no default upstreams specified`, set: []string{ "[/host.com/]1.1.1.1", "[//]tls://1.1.1.1", @@ -366,10 +373,9 @@ func TestValidateUpstreamsSet(t *testing.T) { "[/host.com/google.com/]8.8.8.8", "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", }, - wantNil: false, }, { - name: "valid_with_default", - msg: "upstreams set is valid, but doesn't pass through validation cause: %s", + name: "valid_with_default", + wantErr: ``, set: []string{ "[/host.com/]1.1.1.1", "[//]tls://1.1.1.1", @@ -378,19 +384,16 @@ func TestValidateUpstreamsSet(t *testing.T) { "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", "8.8.8.8", }, - wantNil: true, }, { name: "invalid", - msg: "there is an invalid upstream in set, but it pass through validation", + wantErr: `cannot prepare the upstream dhcp://fake.dns ([]): unsupported URL scheme: dhcp`, set: []string{"dhcp://fake.dns"}, - wantNil: false, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { err := ValidateUpstreams(tc.set) - - assert.Equalf(t, tc.wantNil, err == nil, tc.msg, err) + testutil.AssertErrorMsg(t, tc.wantErr, err) }) } } diff --git a/openapi/openapi.yaml b/openapi/openapi.yaml index ef530663..720eecdf 100644 --- a/openapi/openapi.yaml +++ b/openapi/openapi.yaml @@ -123,7 +123,9 @@ '8.8.8.8': 'OK' '8.8.4.4': 'OK' '192.168.1.104:53535': > - Couldn't communicate with DNS server + upstream "192.168.1.104:1234" fails to exchange: couldn't + communicate with upstream: read udp + 192.168.1.100:60675->8.8.8.8:1234: i/o timeout '/version.json': 'post': 'tags':