Pull request: all: add idna handling, imp domain validation
Updates #2915. Squashed commit of the following: commit b907324426c87ee7334edbd61e43c44444ad27a9 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Wed Apr 7 16:26:41 2021 +0300 all: imp docs, upd commit c022f75cac006e077095cad283fea0a91d3a0eea Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Wed Apr 7 15:51:30 2021 +0300 all: add idna handling, imp domain validation
This commit is contained in:
parent
00a61fdea0
commit
c133b01ef7
4
go.mod
4
go.mod
|
@ -35,9 +35,9 @@ require (
|
||||||
github.com/u-root/u-root v7.0.0+incompatible
|
github.com/u-root/u-root v7.0.0+incompatible
|
||||||
go.etcd.io/bbolt v1.3.5
|
go.etcd.io/bbolt v1.3.5
|
||||||
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83
|
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83
|
||||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110
|
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4
|
||||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect
|
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect
|
||||||
golang.org/x/sys v0.0.0-20210309074719-68d13333faf2
|
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44
|
||||||
golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d // indirect
|
golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d // indirect
|
||||||
golang.org/x/text v0.3.5 // indirect
|
golang.org/x/text v0.3.5 // indirect
|
||||||
gopkg.in/natefinch/lumberjack.v2 v2.0.0
|
gopkg.in/natefinch/lumberjack.v2 v2.0.0
|
||||||
|
|
8
go.sum
8
go.sum
|
@ -516,8 +516,8 @@ golang.org/x/net v0.0.0-20201209123823-ac852fbbde11/go.mod h1:m0MpNAwzfU5UDzcl9v
|
||||||
golang.org/x/net v0.0.0-20201216054612-986b41b23924/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
golang.org/x/net v0.0.0-20201216054612-986b41b23924/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
golang.org/x/net v0.0.0-20210119194325-5f4716e94777/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
golang.org/x/net v0.0.0-20210119194325-5f4716e94777/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 h1:qWPm9rbaAMKs8Bq/9LRpbMqxWRVUAQwMI9fVrssnTfw=
|
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 h1:4nGaVu0QrbjT/AK2PRLuQfQuh6DJve+pELhqTdAj3x0=
|
||||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||||
golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||||
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||||
|
@ -582,8 +582,8 @@ golang.org/x/sys v0.0.0-20210110051926-789bb1bd4061/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||||
golang.org/x/sys v0.0.0-20210123111255-9b0068b26619/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20210123111255-9b0068b26619/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20210216163648-f7da38b97c65/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20210216163648-f7da38b97c65/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20210309074719-68d13333faf2 h1:46ULzRKLh1CwgRq2dC5SlBzEqqNCi8rreOZnNrbqcIY=
|
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44 h1:Bli41pIlzTzf3KEY06n+xnzK/BESIg2ze4Pgfh/aI8c=
|
||||||
golang.org/x/sys v0.0.0-20210309074719-68d13333faf2/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d h1:SZxvLBoTP5yHO3Frd4z4vrF+DBX9vMVanchswa69toE=
|
golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d h1:SZxvLBoTP5yHO3Frd4z4vrF+DBX9vMVanchswa69toE=
|
||||||
|
|
|
@ -3,8 +3,10 @@ package aghnet
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||||
|
"golang.org/x/net/idna"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ValidateHardwareAddress returns an error if hwa is not a valid EUI-48,
|
// ValidateHardwareAddress returns an error if hwa is not a valid EUI-48,
|
||||||
|
@ -21,3 +23,79 @@ func ValidateHardwareAddress(hwa net.HardwareAddr) (err error) {
|
||||||
return fmt.Errorf("bad len: %d", l)
|
return fmt.Errorf("bad len: %d", l)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// maxDomainLabelLen is the maximum allowed length of a domain name label
|
||||||
|
// according to RFC 1035.
|
||||||
|
const maxDomainLabelLen = 63
|
||||||
|
|
||||||
|
// maxDomainNameLen is the maximum allowed length of a full domain name
|
||||||
|
// according to RFC 1035.
|
||||||
|
//
|
||||||
|
// See https://stackoverflow.com/a/32294443/1892060.
|
||||||
|
const maxDomainNameLen = 253
|
||||||
|
|
||||||
|
const invalidCharMsg = "invalid char %q at index %d in %q"
|
||||||
|
|
||||||
|
// isValidHostFirstRune returns true if r is a valid first rune for a hostname
|
||||||
|
// label.
|
||||||
|
func isValidHostFirstRune(r rune) (ok bool) {
|
||||||
|
return (r >= 'a' && r <= 'z') ||
|
||||||
|
(r >= 'A' && r <= 'Z') ||
|
||||||
|
(r >= '0' && r <= '9')
|
||||||
|
}
|
||||||
|
|
||||||
|
// isValidHostRune returns true if r is a valid rune for a hostname label.
|
||||||
|
func isValidHostRune(r rune) (ok bool) {
|
||||||
|
return r == '-' || isValidHostFirstRune(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateDomainNameLabel returns an error if label is not a valid label of
|
||||||
|
// a domain name.
|
||||||
|
func ValidateDomainNameLabel(label string) (err error) {
|
||||||
|
if len(label) > maxDomainLabelLen {
|
||||||
|
return fmt.Errorf("%q is too long, max: %d", label, maxDomainLabelLen)
|
||||||
|
} else if len(label) == 0 {
|
||||||
|
return agherr.Error("label is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
if r := label[0]; !isValidHostFirstRune(rune(r)) {
|
||||||
|
return fmt.Errorf(invalidCharMsg, r, 0, label)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, r := range label[1:] {
|
||||||
|
if !isValidHostRune(r) {
|
||||||
|
return fmt.Errorf(invalidCharMsg, r, i+1, label)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateDomainName validates the domain name in accordance to RFC 952, RFC
|
||||||
|
// 1035, and with RFC-1123's inclusion of digits at the start of the host. It
|
||||||
|
// doesn't validate against two or more hyphens to allow punycode and
|
||||||
|
// internationalized domains.
|
||||||
|
//
|
||||||
|
// TODO(a.garipov): After making sure that this works correctly, port this into
|
||||||
|
// module golibs.
|
||||||
|
func ValidateDomainName(name string) (err error) {
|
||||||
|
name, err = idna.ToASCII(name)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
l := len(name)
|
||||||
|
if l == 0 || l > maxDomainNameLen {
|
||||||
|
return fmt.Errorf("%q is too long, max: %d", name, maxDomainNameLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
labels := strings.Split(name, ".")
|
||||||
|
for i, l := range labels {
|
||||||
|
err = ValidateDomainNameLabel(l)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid domain name label at index %d: %w", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package aghnet
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
@ -50,6 +51,81 @@ func TestValidateHardwareAddress(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.wantErrMsg, err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func repeatStr(b *strings.Builder, s string, n int) {
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
_, _ = b.WriteString(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateDomainName(t *testing.T) {
|
||||||
|
b := &strings.Builder{}
|
||||||
|
repeatStr(b, "a", 255)
|
||||||
|
longDomainName := b.String()
|
||||||
|
|
||||||
|
b.Reset()
|
||||||
|
repeatStr(b, "a", 64)
|
||||||
|
longLabel := b.String()
|
||||||
|
|
||||||
|
_, _ = b.WriteString(".com")
|
||||||
|
longLabelDomainName := b.String()
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
in string
|
||||||
|
wantErrMsg string
|
||||||
|
}{{
|
||||||
|
name: "success",
|
||||||
|
in: "example.com",
|
||||||
|
wantErrMsg: "",
|
||||||
|
}, {
|
||||||
|
name: "success_idna",
|
||||||
|
in: "пример.рф",
|
||||||
|
wantErrMsg: "",
|
||||||
|
}, {
|
||||||
|
name: "bad_symbol",
|
||||||
|
in: "!!!",
|
||||||
|
wantErrMsg: `invalid domain name label at index 0: ` +
|
||||||
|
`invalid char '!' at index 0 in "!!!"`,
|
||||||
|
}, {
|
||||||
|
name: "bad_length",
|
||||||
|
in: longDomainName,
|
||||||
|
wantErrMsg: `"` + longDomainName + `" is too long, max: 253`,
|
||||||
|
}, {
|
||||||
|
name: "bad_label_length",
|
||||||
|
in: longLabelDomainName,
|
||||||
|
wantErrMsg: `invalid domain name label at index 0: "` + longLabel +
|
||||||
|
`" is too long, max: 63`,
|
||||||
|
}, {
|
||||||
|
name: "bad_label_empty",
|
||||||
|
in: "example..com",
|
||||||
|
wantErrMsg: `invalid domain name label at index 1: label is empty`,
|
||||||
|
}, {
|
||||||
|
name: "bad_label_first_symbol",
|
||||||
|
in: "example.-aa.com",
|
||||||
|
wantErrMsg: `invalid domain name label at index 1:` +
|
||||||
|
` invalid char '-' at index 0 in "-aa"`,
|
||||||
|
}, {
|
||||||
|
name: "bad_label_symbol",
|
||||||
|
in: "example.a!!!.com",
|
||||||
|
wantErrMsg: `invalid domain name label at index 1:` +
|
||||||
|
` invalid char '!' at index 1 in "a!!!"`,
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
err := ValidateDomainName(tc.in)
|
||||||
|
if tc.wantErrMsg == "" {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
assert.Equal(t, tc.wantErrMsg, err.Error())
|
assert.Equal(t, tc.wantErrMsg, err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -6,33 +6,14 @@ import (
|
||||||
"path"
|
"path"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/lucas-clemente/quic-go"
|
"github.com/lucas-clemente/quic-go"
|
||||||
)
|
)
|
||||||
|
|
||||||
// maxDomainLabelLen is the maximum allowed length of a domain name label
|
|
||||||
// according to RFC 1035.
|
|
||||||
const maxDomainLabelLen = 63
|
|
||||||
|
|
||||||
// validateDomainNameLabel returns an error if label is not a valid label of
|
|
||||||
// a domain name.
|
|
||||||
func validateDomainNameLabel(label string) (err error) {
|
|
||||||
if len(label) > maxDomainLabelLen {
|
|
||||||
return fmt.Errorf("%q is too long, max: %d", label, maxDomainLabelLen)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, r := range label {
|
|
||||||
if (r < 'a' || r > 'z') && (r < '0' || r > '9') && r != '-' {
|
|
||||||
return fmt.Errorf("invalid char %q at index %d in %q", r, i, label)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidateClientID returns an error if clientID is not a valid client ID.
|
// ValidateClientID returns an error if clientID is not a valid client ID.
|
||||||
func ValidateClientID(clientID string) (err error) {
|
func ValidateClientID(clientID string) (err error) {
|
||||||
err = validateDomainNameLabel(clientID)
|
err = aghnet.ValidateDomainNameLabel(clientID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid client id: %w", err)
|
return fmt.Errorf("invalid client id: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -114,7 +114,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||||
if p.AutohostTLD == "" {
|
if p.AutohostTLD == "" {
|
||||||
autohostSuffix = defaultAutohostSuffix
|
autohostSuffix = defaultAutohostSuffix
|
||||||
} else {
|
} else {
|
||||||
err = validateDomainNameLabel(p.AutohostTLD)
|
err = aghnet.ValidateDomainNameLabel(p.AutohostTLD)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("autohost tld: %w", err)
|
return nil, fmt.Errorf("autohost tld: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -947,145 +947,6 @@ func publicKey(priv interface{}) interface{} {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateUpstream(t *testing.T) {
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
upstream string
|
|
||||||
valid bool
|
|
||||||
wantDef bool
|
|
||||||
}{{
|
|
||||||
name: "invalid",
|
|
||||||
upstream: "1.2.3.4.5",
|
|
||||||
valid: false,
|
|
||||||
}, {
|
|
||||||
name: "invalid",
|
|
||||||
upstream: "123.3.7m",
|
|
||||||
valid: false,
|
|
||||||
}, {
|
|
||||||
name: "invalid",
|
|
||||||
upstream: "htttps://google.com/dns-query",
|
|
||||||
valid: false,
|
|
||||||
}, {
|
|
||||||
name: "invalid",
|
|
||||||
upstream: "[/host.com]tls://dns.adguard.com",
|
|
||||||
valid: false,
|
|
||||||
}, {
|
|
||||||
name: "invalid",
|
|
||||||
upstream: "[host.ru]#",
|
|
||||||
valid: false,
|
|
||||||
}, {
|
|
||||||
name: "valid_default",
|
|
||||||
upstream: "1.1.1.1",
|
|
||||||
valid: true,
|
|
||||||
wantDef: true,
|
|
||||||
}, {
|
|
||||||
name: "valid_default",
|
|
||||||
upstream: "tls://1.1.1.1",
|
|
||||||
valid: true,
|
|
||||||
wantDef: true,
|
|
||||||
}, {
|
|
||||||
name: "valid_default",
|
|
||||||
upstream: "https://dns.adguard.com/dns-query",
|
|
||||||
valid: true,
|
|
||||||
wantDef: true,
|
|
||||||
}, {
|
|
||||||
name: "valid_default",
|
|
||||||
upstream: "sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
|
||||||
valid: true,
|
|
||||||
wantDef: true,
|
|
||||||
}, {
|
|
||||||
name: "valid",
|
|
||||||
upstream: "[/host.com/]1.1.1.1",
|
|
||||||
valid: true,
|
|
||||||
wantDef: false,
|
|
||||||
}, {
|
|
||||||
name: "valid",
|
|
||||||
upstream: "[//]tls://1.1.1.1",
|
|
||||||
valid: true,
|
|
||||||
wantDef: false,
|
|
||||||
}, {
|
|
||||||
name: "valid",
|
|
||||||
upstream: "[/www.host.com/]#",
|
|
||||||
valid: true,
|
|
||||||
wantDef: false,
|
|
||||||
}, {
|
|
||||||
name: "valid",
|
|
||||||
upstream: "[/host.com/google.com/]8.8.8.8",
|
|
||||||
valid: true,
|
|
||||||
wantDef: false,
|
|
||||||
}, {
|
|
||||||
name: "valid",
|
|
||||||
upstream: "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
|
||||||
valid: true,
|
|
||||||
wantDef: false,
|
|
||||||
}}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
defaultUpstream, err := validateUpstream(tc.upstream)
|
|
||||||
require.Equal(t, tc.valid, err == nil)
|
|
||||||
if tc.valid {
|
|
||||||
assert.Equal(t, tc.wantDef, defaultUpstream)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateUpstreamsSet(t *testing.T) {
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
msg string
|
|
||||||
set []string
|
|
||||||
wantNil bool
|
|
||||||
}{{
|
|
||||||
name: "empty",
|
|
||||||
msg: "empty upstreams array should be valid",
|
|
||||||
set: nil,
|
|
||||||
wantNil: true,
|
|
||||||
}, {
|
|
||||||
name: "comment",
|
|
||||||
msg: "comments should not be validated",
|
|
||||||
set: []string{"# comment"},
|
|
||||||
wantNil: true,
|
|
||||||
}, {
|
|
||||||
name: "valid_no_default",
|
|
||||||
msg: "there is no default upstream",
|
|
||||||
set: []string{
|
|
||||||
"[/host.com/]1.1.1.1",
|
|
||||||
"[//]tls://1.1.1.1",
|
|
||||||
"[/www.host.com/]#",
|
|
||||||
"[/host.com/google.com/]8.8.8.8",
|
|
||||||
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
|
||||||
},
|
|
||||||
wantNil: false,
|
|
||||||
}, {
|
|
||||||
name: "valid_with_default",
|
|
||||||
msg: "upstreams set is valid, but doesn't pass through validation cause: %s",
|
|
||||||
set: []string{
|
|
||||||
"[/host.com/]1.1.1.1",
|
|
||||||
"[//]tls://1.1.1.1",
|
|
||||||
"[/www.host.com/]#",
|
|
||||||
"[/host.com/google.com/]8.8.8.8",
|
|
||||||
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
|
||||||
"8.8.8.8",
|
|
||||||
},
|
|
||||||
wantNil: true,
|
|
||||||
}, {
|
|
||||||
name: "invalid",
|
|
||||||
msg: "there is an invalid upstream in set, but it pass through validation",
|
|
||||||
set: []string{"dhcp://fake.dns"},
|
|
||||||
wantNil: false,
|
|
||||||
}}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
err := ValidateUpstreams(tc.set)
|
|
||||||
|
|
||||||
assert.Equalf(t, tc.wantNil, err == nil, tc.msg, err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIPStringFromAddr(t *testing.T) {
|
func TestIPStringFromAddr(t *testing.T) {
|
||||||
t.Run("not_nil", func(t *testing.T) {
|
t.Run("not_nil", func(t *testing.T) {
|
||||||
addr := net.UDPAddr{
|
addr := net.UDPAddr{
|
||||||
|
|
|
@ -8,10 +8,11 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/golibs/utils"
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -302,7 +303,7 @@ type upstreamJSON struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateUpstreams validates each upstream and returns an error if any upstream is invalid or if there are no default upstreams specified
|
// ValidateUpstreams validates each upstream and returns an error if any upstream is invalid or if there are no default upstreams specified
|
||||||
func ValidateUpstreams(upstreams []string) error {
|
func ValidateUpstreams(upstreams []string) (err error) {
|
||||||
// No need to validate comments
|
// No need to validate comments
|
||||||
upstreams = filterOutComments(upstreams)
|
upstreams = filterOutComments(upstreams)
|
||||||
|
|
||||||
|
@ -311,7 +312,7 @@ func ValidateUpstreams(upstreams []string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := proxy.ParseUpstreamsConfig(
|
_, err = proxy.ParseUpstreamsConfig(
|
||||||
upstreams,
|
upstreams,
|
||||||
upstream.Options{
|
upstream.Options{
|
||||||
Bootstrap: []string{},
|
Bootstrap: []string{},
|
||||||
|
@ -345,56 +346,61 @@ func ValidateUpstreams(upstreams []string) error {
|
||||||
var protocols = []string{"tls://", "https://", "tcp://", "sdns://", "quic://"}
|
var protocols = []string{"tls://", "https://", "tcp://", "sdns://", "quic://"}
|
||||||
|
|
||||||
func validateUpstream(u string) (bool, error) {
|
func validateUpstream(u string) (bool, error) {
|
||||||
// Check if user tries to specify upstream for domain
|
// Check if the user tries to specify upstream for domain.
|
||||||
u, defaultUpstream, err := separateUpstream(u)
|
u, useDefault, err := separateUpstream(u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return defaultUpstream, err
|
return useDefault, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// The special server address '#' means "use the default servers"
|
// The special server address '#' means "use the default servers"
|
||||||
if u == "#" && !defaultUpstream {
|
if u == "#" && !useDefault {
|
||||||
return defaultUpstream, nil
|
return useDefault, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the upstream has a valid protocol prefix
|
// Check if the upstream has a valid protocol prefix
|
||||||
for _, proto := range protocols {
|
for _, proto := range protocols {
|
||||||
if strings.HasPrefix(u, proto) {
|
if strings.HasPrefix(u, proto) {
|
||||||
return defaultUpstream, nil
|
return useDefault, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return error if the upstream contains '://' without any valid protocol
|
// Return error if the upstream contains '://' without any valid protocol
|
||||||
if strings.Contains(u, "://") {
|
if strings.Contains(u, "://") {
|
||||||
return defaultUpstream, fmt.Errorf("wrong protocol")
|
return useDefault, fmt.Errorf("wrong protocol")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if upstream is valid plain DNS
|
// Check if upstream is valid plain DNS
|
||||||
return defaultUpstream, checkPlainDNS(u)
|
return useDefault, checkPlainDNS(u)
|
||||||
}
|
}
|
||||||
|
|
||||||
// separateUpstream returns upstream without specified domains and a bool flag that indicates if no domains were specified
|
// separateUpstream returns the upstream without the specified domains.
|
||||||
// error will be returned if upstream per domain specification is invalid
|
// useDefault is true when a default upstream must be used.
|
||||||
func separateUpstream(upstream string) (string, bool, error) {
|
func separateUpstream(upstreamStr string) (upstream string, useDefault bool, err error) {
|
||||||
defaultUpstream := true
|
defer agherr.Annotate("bad upstream for domain spec %q: %w", &err, upstreamStr)
|
||||||
if strings.HasPrefix(upstream, "[/") {
|
|
||||||
defaultUpstream = false
|
if !strings.HasPrefix(upstreamStr, "[/") {
|
||||||
// split domains and upstream string
|
return upstreamStr, true, nil
|
||||||
domainsAndUpstream := strings.Split(strings.TrimPrefix(upstream, "[/"), "/]")
|
|
||||||
if len(domainsAndUpstream) != 2 {
|
|
||||||
return "", defaultUpstream, fmt.Errorf("wrong dns upstream per domain specification: %s", upstream)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// split domains list and validate each one
|
parts := strings.Split(upstreamStr[2:], "/]")
|
||||||
for _, host := range strings.Split(domainsAndUpstream[0], "/") {
|
if len(parts) != 2 {
|
||||||
if host != "" {
|
return "", false, agherr.Error("duplicated separator")
|
||||||
if err := utils.IsValidHostname(host); err != nil {
|
}
|
||||||
return "", defaultUpstream, err
|
|
||||||
|
domains := parts[0]
|
||||||
|
upstream = parts[1]
|
||||||
|
for i, host := range strings.Split(domains, "/") {
|
||||||
|
if host == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err = aghnet.ValidateDomainName(host)
|
||||||
|
if err != nil {
|
||||||
|
return "", false, fmt.Errorf("domain at index %d: %w", i, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
upstream = domainsAndUpstream[1]
|
return upstream, false, nil
|
||||||
}
|
|
||||||
return upstream, defaultUpstream, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkPlainDNS checks if host is plain DNS
|
// checkPlainDNS checks if host is plain DNS
|
||||||
|
@ -462,13 +468,13 @@ func checkDNS(input string, bootstrap []string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// separate upstream from domains list
|
// separate upstream from domains list
|
||||||
input, defaultUpstream, err := separateUpstream(input)
|
input, useDefault, err := separateUpstream(input)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("wrong upstream format: %w", err)
|
return fmt.Errorf("wrong upstream format: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// No need to check this DNS server
|
// No need to check this DNS server
|
||||||
if !defaultUpstream {
|
if !useDefault {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -213,3 +213,158 @@ func TestDNSForwardHTTTP_handleSetConfig(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(a.garipov): Rewrite to check the actual error messages.
|
||||||
|
func TestValidateUpstream(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
upstream string
|
||||||
|
valid bool
|
||||||
|
wantDef bool
|
||||||
|
}{{
|
||||||
|
name: "invalid",
|
||||||
|
upstream: "1.2.3.4.5",
|
||||||
|
valid: false,
|
||||||
|
wantDef: false,
|
||||||
|
}, {
|
||||||
|
name: "invalid",
|
||||||
|
upstream: "123.3.7m",
|
||||||
|
valid: false,
|
||||||
|
wantDef: false,
|
||||||
|
}, {
|
||||||
|
name: "invalid",
|
||||||
|
upstream: "htttps://google.com/dns-query",
|
||||||
|
valid: false,
|
||||||
|
wantDef: false,
|
||||||
|
}, {
|
||||||
|
name: "invalid",
|
||||||
|
upstream: "[/host.com]tls://dns.adguard.com",
|
||||||
|
valid: false,
|
||||||
|
wantDef: false,
|
||||||
|
}, {
|
||||||
|
name: "invalid",
|
||||||
|
upstream: "[host.ru]#",
|
||||||
|
valid: false,
|
||||||
|
wantDef: false,
|
||||||
|
}, {
|
||||||
|
name: "valid_default",
|
||||||
|
upstream: "1.1.1.1",
|
||||||
|
valid: true,
|
||||||
|
wantDef: true,
|
||||||
|
}, {
|
||||||
|
name: "valid_default",
|
||||||
|
upstream: "tls://1.1.1.1",
|
||||||
|
valid: true,
|
||||||
|
wantDef: true,
|
||||||
|
}, {
|
||||||
|
name: "valid_default",
|
||||||
|
upstream: "https://dns.adguard.com/dns-query",
|
||||||
|
valid: true,
|
||||||
|
wantDef: true,
|
||||||
|
}, {
|
||||||
|
name: "valid_default",
|
||||||
|
upstream: "sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||||
|
valid: true,
|
||||||
|
wantDef: true,
|
||||||
|
}, {
|
||||||
|
name: "valid",
|
||||||
|
upstream: "[/host.com/]1.1.1.1",
|
||||||
|
valid: true,
|
||||||
|
wantDef: false,
|
||||||
|
}, {
|
||||||
|
name: "valid",
|
||||||
|
upstream: "[//]tls://1.1.1.1",
|
||||||
|
valid: true,
|
||||||
|
wantDef: false,
|
||||||
|
}, {
|
||||||
|
name: "valid",
|
||||||
|
upstream: "[/www.host.com/]#",
|
||||||
|
valid: true,
|
||||||
|
wantDef: false,
|
||||||
|
}, {
|
||||||
|
name: "valid",
|
||||||
|
upstream: "[/host.com/google.com/]8.8.8.8",
|
||||||
|
valid: true,
|
||||||
|
wantDef: false,
|
||||||
|
}, {
|
||||||
|
name: "valid",
|
||||||
|
upstream: "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||||
|
valid: true,
|
||||||
|
wantDef: false,
|
||||||
|
}, {
|
||||||
|
name: "idna",
|
||||||
|
upstream: "[/пример.рф/]8.8.8.8",
|
||||||
|
valid: true,
|
||||||
|
wantDef: false,
|
||||||
|
}, {
|
||||||
|
name: "bad_domain",
|
||||||
|
upstream: "[/!/]8.8.8.8",
|
||||||
|
valid: false,
|
||||||
|
wantDef: false,
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
defaultUpstream, err := validateUpstream(tc.upstream)
|
||||||
|
require.Equal(t, tc.valid, err == nil)
|
||||||
|
if tc.valid {
|
||||||
|
assert.Equal(t, tc.wantDef, defaultUpstream)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateUpstreamsSet(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
msg string
|
||||||
|
set []string
|
||||||
|
wantNil bool
|
||||||
|
}{{
|
||||||
|
name: "empty",
|
||||||
|
msg: "empty upstreams array should be valid",
|
||||||
|
set: nil,
|
||||||
|
wantNil: true,
|
||||||
|
}, {
|
||||||
|
name: "comment",
|
||||||
|
msg: "comments should not be validated",
|
||||||
|
set: []string{"# comment"},
|
||||||
|
wantNil: true,
|
||||||
|
}, {
|
||||||
|
name: "valid_no_default",
|
||||||
|
msg: "there is no default upstream",
|
||||||
|
set: []string{
|
||||||
|
"[/host.com/]1.1.1.1",
|
||||||
|
"[//]tls://1.1.1.1",
|
||||||
|
"[/www.host.com/]#",
|
||||||
|
"[/host.com/google.com/]8.8.8.8",
|
||||||
|
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||||
|
},
|
||||||
|
wantNil: false,
|
||||||
|
}, {
|
||||||
|
name: "valid_with_default",
|
||||||
|
msg: "upstreams set is valid, but doesn't pass through validation cause: %s",
|
||||||
|
set: []string{
|
||||||
|
"[/host.com/]1.1.1.1",
|
||||||
|
"[//]tls://1.1.1.1",
|
||||||
|
"[/www.host.com/]#",
|
||||||
|
"[/host.com/google.com/]8.8.8.8",
|
||||||
|
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||||
|
"8.8.8.8",
|
||||||
|
},
|
||||||
|
wantNil: true,
|
||||||
|
}, {
|
||||||
|
name: "invalid",
|
||||||
|
msg: "there is an invalid upstream in set, but it pass through validation",
|
||||||
|
set: []string{"dhcp://fake.dns"},
|
||||||
|
wantNil: false,
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
err := ValidateUpstreams(tc.set)
|
||||||
|
|
||||||
|
assert.Equalf(t, tc.wantNil, err == nil, tc.msg, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -5,7 +5,7 @@ import (
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/utils"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
)
|
)
|
||||||
|
|
||||||
// IPFromAddr gets IP address from addr.
|
// IPFromAddr gets IP address from addr.
|
||||||
|
@ -58,9 +58,10 @@ func matchDomainWildcard(host, wildcard string) bool {
|
||||||
|
|
||||||
// Return TRUE if client's SNI value matches DNS names from certificate
|
// Return TRUE if client's SNI value matches DNS names from certificate
|
||||||
func matchDNSName(dnsNames []string, sni string) bool {
|
func matchDNSName(dnsNames []string, sni string) bool {
|
||||||
if utils.IsValidHostname(sni) != nil {
|
if aghnet.ValidateDomainName(sni) != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if findSorted(dnsNames, sni) != -1 {
|
if findSorted(dnsNames, sni) != -1 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
|
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||||
|
@ -20,7 +21,6 @@ import (
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/golibs/utils"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const clientsUpdatePeriod = 10 * time.Minute
|
const clientsUpdatePeriod = 10 * time.Minute
|
||||||
|
@ -751,7 +751,7 @@ func (clients *clientsContainer) addFromSystemARP() {
|
||||||
|
|
||||||
host := ln[:open]
|
host := ln[:open]
|
||||||
ip := ln[open+2 : close]
|
ip := ln[open+2 : close]
|
||||||
if utils.IsValidHostname(host) != nil || net.ParseIP(ip) == nil {
|
if aghnet.ValidateDomainName(host) != nil || net.ParseIP(ip) == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -123,6 +123,7 @@ func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
clientID := q.Get("client_id")
|
clientID := q.Get("client_id")
|
||||||
|
if clientID != "" {
|
||||||
err = dnsforward.ValidateClientID(clientID)
|
err = dnsforward.ValidateClientID(clientID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
@ -136,6 +137,7 @@ func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) {
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
d := dnsSettings{
|
d := dnsSettings{
|
||||||
DNSProtocol: dnsp,
|
DNSProtocol: dnsp,
|
||||||
|
|
Loading…
Reference in New Issue