+ DNS: TLS handshake: terminate handshake on bad SNI
This commit is contained in:
parent
ac156b9612
commit
bfd1f3b650
|
@ -2,10 +2,12 @@ package dnsforward
|
|||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -16,6 +18,7 @@ import (
|
|||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/utils"
|
||||
"github.com/joomcode/errorx"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
@ -151,6 +154,7 @@ type FilteringConfig struct {
|
|||
// TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS
|
||||
type TLSConfig struct {
|
||||
TLSListenAddr *net.TCPAddr `yaml:"-" json:"-"`
|
||||
StrictSNICheck bool `yaml:"strict_sni_check" json:"-"` // Reject connection if the client uses server name (in SNI) that doesn't match the certificate
|
||||
CertificateChain string `yaml:"certificate_chain" json:"certificate_chain"` // PEM-encoded certificates chain
|
||||
PrivateKey string `yaml:"private_key" json:"private_key"` // PEM-encoded private key
|
||||
|
||||
|
@ -159,6 +163,9 @@ type TLSConfig struct {
|
|||
|
||||
CertificateChainData []byte `yaml:"-" json:"-"`
|
||||
PrivateKeyData []byte `yaml:"-" json:"-"`
|
||||
|
||||
cert tls.Certificate // nolint(structcheck) - linter thinks that this field is unused, while TLSConfig is directly included into ServerConfig
|
||||
dnsNames []string // nolint(structcheck) // DNS names from certificate (SAN) or CN value from Subject
|
||||
}
|
||||
|
||||
// ServerConfig represents server configuration.
|
||||
|
@ -234,6 +241,7 @@ func (s *Server) startInternal() error {
|
|||
}
|
||||
|
||||
// Prepare the object
|
||||
// nolint(gocyclo)
|
||||
func (s *Server) Prepare(config *ServerConfig) error {
|
||||
if config != nil {
|
||||
s.conf = *config
|
||||
|
@ -305,12 +313,28 @@ func (s *Server) Prepare(config *ServerConfig) error {
|
|||
|
||||
if s.conf.TLSListenAddr != nil && len(s.conf.CertificateChainData) != 0 && len(s.conf.PrivateKeyData) != 0 {
|
||||
proxyConfig.TLSListenAddr = s.conf.TLSListenAddr
|
||||
keypair, err := tls.X509KeyPair(s.conf.CertificateChainData, s.conf.PrivateKeyData)
|
||||
s.conf.cert, err = tls.X509KeyPair(s.conf.CertificateChainData, s.conf.PrivateKeyData)
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Failed to parse TLS keypair")
|
||||
}
|
||||
|
||||
if s.conf.StrictSNICheck {
|
||||
x, err := x509.ParseCertificate(s.conf.cert.Certificate[0])
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "x509.ParseCertificate(): %s", err)
|
||||
}
|
||||
if len(x.DNSNames) != 0 {
|
||||
s.conf.dnsNames = x.DNSNames
|
||||
log.Debug("DNS: using DNS names from certificate's SAN: %v", x.DNSNames)
|
||||
sort.Strings(s.conf.dnsNames)
|
||||
} else {
|
||||
s.conf.dnsNames = append(s.conf.dnsNames, x.Subject.CommonName)
|
||||
log.Debug("DNS: using DNS name from certificate's CN: %s", x.Subject.CommonName)
|
||||
}
|
||||
}
|
||||
|
||||
proxyConfig.TLSConfig = &tls.Config{
|
||||
Certificates: []tls.Certificate{keypair},
|
||||
GetCertificate: s.onGetCertificate,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
}
|
||||
|
@ -329,6 +353,53 @@ func (s *Server) Prepare(config *ServerConfig) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Find value in a sorted array
|
||||
func findSorted(ar []string, val string) int {
|
||||
i := sort.SearchStrings(ar, val)
|
||||
if i == len(ar) || ar[i] != val {
|
||||
return -1
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
func isWildcard(host string) bool {
|
||||
return len(host) >= 2 &&
|
||||
host[0] == '*' && host[1] == '.'
|
||||
}
|
||||
|
||||
// Return TRUE if host name matches a wildcard pattern
|
||||
func matchDomainWildcard(host, wildcard string) bool {
|
||||
return isWildcard(wildcard) &&
|
||||
strings.HasSuffix(host, wildcard[1:])
|
||||
}
|
||||
|
||||
// Return TRUE if client's SNI value matches DNS names from certificate
|
||||
func matchDNSName(dnsNames []string, sni string) bool {
|
||||
if utils.IsValidHostname(sni) != nil {
|
||||
return false
|
||||
}
|
||||
if findSorted(dnsNames, sni) != -1 {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, dn := range dnsNames {
|
||||
if matchDomainWildcard(sni, dn) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Called by 'tls' package when Client Hello is received
|
||||
// If the server name (from SNI) supplied by client is incorrect - we terminate the ongoing TLS handshake.
|
||||
func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
if s.conf.StrictSNICheck && !matchDNSName(s.conf.dnsNames, ch.ServerName) {
|
||||
log.Info("DNS: TLS: unknown SNI in Client Hello: %s", ch.ServerName)
|
||||
return nil, fmt.Errorf("Invalid SNI")
|
||||
}
|
||||
return &s.conf.cert, nil
|
||||
}
|
||||
|
||||
// Stop stops the DNS server
|
||||
func (s *Server) Stop() error {
|
||||
s.Lock()
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"encoding/pem"
|
||||
"math/big"
|
||||
"net"
|
||||
"sort"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -887,3 +888,15 @@ func TestIpFromAddr(t *testing.T) {
|
|||
a = ipFromAddr(nil)
|
||||
assert.True(t, a == "")
|
||||
}
|
||||
|
||||
func TestMatchDNSName(t *testing.T) {
|
||||
dnsNames := []string{"host1", "*.host2", "1.2.3.4"}
|
||||
sort.Strings(dnsNames)
|
||||
assert.True(t, matchDNSName(dnsNames, "host1"))
|
||||
assert.True(t, matchDNSName(dnsNames, "a.host2"))
|
||||
assert.True(t, matchDNSName(dnsNames, "b.a.host2"))
|
||||
assert.True(t, matchDNSName(dnsNames, "1.2.3.4"))
|
||||
assert.True(t, !matchDNSName(dnsNames, "host2"))
|
||||
assert.True(t, !matchDNSName(dnsNames, ""))
|
||||
assert.True(t, !matchDNSName(dnsNames, "*.host2"))
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue