diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index 85be3ac9..6ecc7442 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -2,10 +2,12 @@ package dnsforward import ( "crypto/tls" + "crypto/x509" "fmt" "net" "net/http" "runtime" + "sort" "strings" "sync" "time" @@ -16,6 +18,7 @@ import ( "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/utils" "github.com/joomcode/errorx" "github.com/miekg/dns" ) @@ -151,6 +154,7 @@ type FilteringConfig struct { // TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS type TLSConfig struct { TLSListenAddr *net.TCPAddr `yaml:"-" json:"-"` + StrictSNICheck bool `yaml:"strict_sni_check" json:"-"` // Reject connection if the client uses server name (in SNI) that doesn't match the certificate CertificateChain string `yaml:"certificate_chain" json:"certificate_chain"` // PEM-encoded certificates chain PrivateKey string `yaml:"private_key" json:"private_key"` // PEM-encoded private key @@ -159,6 +163,9 @@ type TLSConfig struct { CertificateChainData []byte `yaml:"-" json:"-"` PrivateKeyData []byte `yaml:"-" json:"-"` + + cert tls.Certificate // nolint(structcheck) - linter thinks that this field is unused, while TLSConfig is directly included into ServerConfig + dnsNames []string // nolint(structcheck) // DNS names from certificate (SAN) or CN value from Subject } // ServerConfig represents server configuration. @@ -234,6 +241,7 @@ func (s *Server) startInternal() error { } // Prepare the object +// nolint(gocyclo) func (s *Server) Prepare(config *ServerConfig) error { if config != nil { s.conf = *config @@ -305,13 +313,29 @@ func (s *Server) Prepare(config *ServerConfig) error { if s.conf.TLSListenAddr != nil && len(s.conf.CertificateChainData) != 0 && len(s.conf.PrivateKeyData) != 0 { proxyConfig.TLSListenAddr = s.conf.TLSListenAddr - keypair, err := tls.X509KeyPair(s.conf.CertificateChainData, s.conf.PrivateKeyData) + s.conf.cert, err = tls.X509KeyPair(s.conf.CertificateChainData, s.conf.PrivateKeyData) if err != nil { return errorx.Decorate(err, "Failed to parse TLS keypair") } + + if s.conf.StrictSNICheck { + x, err := x509.ParseCertificate(s.conf.cert.Certificate[0]) + if err != nil { + return errorx.Decorate(err, "x509.ParseCertificate(): %s", err) + } + if len(x.DNSNames) != 0 { + s.conf.dnsNames = x.DNSNames + log.Debug("DNS: using DNS names from certificate's SAN: %v", x.DNSNames) + sort.Strings(s.conf.dnsNames) + } else { + s.conf.dnsNames = append(s.conf.dnsNames, x.Subject.CommonName) + log.Debug("DNS: using DNS name from certificate's CN: %s", x.Subject.CommonName) + } + } + proxyConfig.TLSConfig = &tls.Config{ - Certificates: []tls.Certificate{keypair}, - MinVersion: tls.VersionTLS12, + GetCertificate: s.onGetCertificate, + MinVersion: tls.VersionTLS12, } } @@ -329,6 +353,53 @@ func (s *Server) Prepare(config *ServerConfig) error { return nil } +// Find value in a sorted array +func findSorted(ar []string, val string) int { + i := sort.SearchStrings(ar, val) + if i == len(ar) || ar[i] != val { + return -1 + } + return i +} + +func isWildcard(host string) bool { + return len(host) >= 2 && + host[0] == '*' && host[1] == '.' +} + +// Return TRUE if host name matches a wildcard pattern +func matchDomainWildcard(host, wildcard string) bool { + return isWildcard(wildcard) && + strings.HasSuffix(host, wildcard[1:]) +} + +// Return TRUE if client's SNI value matches DNS names from certificate +func matchDNSName(dnsNames []string, sni string) bool { + if utils.IsValidHostname(sni) != nil { + return false + } + if findSorted(dnsNames, sni) != -1 { + return true + } + + for _, dn := range dnsNames { + if matchDomainWildcard(sni, dn) { + return true + } + } + return false +} + +// Called by 'tls' package when Client Hello is received +// If the server name (from SNI) supplied by client is incorrect - we terminate the ongoing TLS handshake. +func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, error) { + if s.conf.StrictSNICheck && !matchDNSName(s.conf.dnsNames, ch.ServerName) { + log.Info("DNS: TLS: unknown SNI in Client Hello: %s", ch.ServerName) + return nil, fmt.Errorf("Invalid SNI") + } + return &s.conf.cert, nil +} + // Stop stops the DNS server func (s *Server) Stop() error { s.Lock() diff --git a/dnsforward/dnsforward_test.go b/dnsforward/dnsforward_test.go index 53803c78..35140eaa 100644 --- a/dnsforward/dnsforward_test.go +++ b/dnsforward/dnsforward_test.go @@ -10,6 +10,7 @@ import ( "encoding/pem" "math/big" "net" + "sort" "sync" "testing" "time" @@ -887,3 +888,15 @@ func TestIpFromAddr(t *testing.T) { a = ipFromAddr(nil) assert.True(t, a == "") } + +func TestMatchDNSName(t *testing.T) { + dnsNames := []string{"host1", "*.host2", "1.2.3.4"} + sort.Strings(dnsNames) + assert.True(t, matchDNSName(dnsNames, "host1")) + assert.True(t, matchDNSName(dnsNames, "a.host2")) + assert.True(t, matchDNSName(dnsNames, "b.a.host2")) + assert.True(t, matchDNSName(dnsNames, "1.2.3.4")) + assert.True(t, !matchDNSName(dnsNames, "host2")) + assert.True(t, !matchDNSName(dnsNames, "")) + assert.True(t, !matchDNSName(dnsNames, "*.host2")) +}