+ DNS: TLS handshake: terminate handshake on bad SNI
This commit is contained in:
parent
ac156b9612
commit
bfd1f3b650
|
@ -2,10 +2,12 @@ package dnsforward
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -16,6 +18,7 @@ import (
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
"github.com/AdguardTeam/golibs/utils"
|
||||||
"github.com/joomcode/errorx"
|
"github.com/joomcode/errorx"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
@ -151,6 +154,7 @@ type FilteringConfig struct {
|
||||||
// TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS
|
// TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS
|
||||||
type TLSConfig struct {
|
type TLSConfig struct {
|
||||||
TLSListenAddr *net.TCPAddr `yaml:"-" json:"-"`
|
TLSListenAddr *net.TCPAddr `yaml:"-" json:"-"`
|
||||||
|
StrictSNICheck bool `yaml:"strict_sni_check" json:"-"` // Reject connection if the client uses server name (in SNI) that doesn't match the certificate
|
||||||
CertificateChain string `yaml:"certificate_chain" json:"certificate_chain"` // PEM-encoded certificates chain
|
CertificateChain string `yaml:"certificate_chain" json:"certificate_chain"` // PEM-encoded certificates chain
|
||||||
PrivateKey string `yaml:"private_key" json:"private_key"` // PEM-encoded private key
|
PrivateKey string `yaml:"private_key" json:"private_key"` // PEM-encoded private key
|
||||||
|
|
||||||
|
@ -159,6 +163,9 @@ type TLSConfig struct {
|
||||||
|
|
||||||
CertificateChainData []byte `yaml:"-" json:"-"`
|
CertificateChainData []byte `yaml:"-" json:"-"`
|
||||||
PrivateKeyData []byte `yaml:"-" json:"-"`
|
PrivateKeyData []byte `yaml:"-" json:"-"`
|
||||||
|
|
||||||
|
cert tls.Certificate // nolint(structcheck) - linter thinks that this field is unused, while TLSConfig is directly included into ServerConfig
|
||||||
|
dnsNames []string // nolint(structcheck) // DNS names from certificate (SAN) or CN value from Subject
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServerConfig represents server configuration.
|
// ServerConfig represents server configuration.
|
||||||
|
@ -234,6 +241,7 @@ func (s *Server) startInternal() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare the object
|
// Prepare the object
|
||||||
|
// nolint(gocyclo)
|
||||||
func (s *Server) Prepare(config *ServerConfig) error {
|
func (s *Server) Prepare(config *ServerConfig) error {
|
||||||
if config != nil {
|
if config != nil {
|
||||||
s.conf = *config
|
s.conf = *config
|
||||||
|
@ -305,13 +313,29 @@ func (s *Server) Prepare(config *ServerConfig) error {
|
||||||
|
|
||||||
if s.conf.TLSListenAddr != nil && len(s.conf.CertificateChainData) != 0 && len(s.conf.PrivateKeyData) != 0 {
|
if s.conf.TLSListenAddr != nil && len(s.conf.CertificateChainData) != 0 && len(s.conf.PrivateKeyData) != 0 {
|
||||||
proxyConfig.TLSListenAddr = s.conf.TLSListenAddr
|
proxyConfig.TLSListenAddr = s.conf.TLSListenAddr
|
||||||
keypair, err := tls.X509KeyPair(s.conf.CertificateChainData, s.conf.PrivateKeyData)
|
s.conf.cert, err = tls.X509KeyPair(s.conf.CertificateChainData, s.conf.PrivateKeyData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorx.Decorate(err, "Failed to parse TLS keypair")
|
return errorx.Decorate(err, "Failed to parse TLS keypair")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.conf.StrictSNICheck {
|
||||||
|
x, err := x509.ParseCertificate(s.conf.cert.Certificate[0])
|
||||||
|
if err != nil {
|
||||||
|
return errorx.Decorate(err, "x509.ParseCertificate(): %s", err)
|
||||||
|
}
|
||||||
|
if len(x.DNSNames) != 0 {
|
||||||
|
s.conf.dnsNames = x.DNSNames
|
||||||
|
log.Debug("DNS: using DNS names from certificate's SAN: %v", x.DNSNames)
|
||||||
|
sort.Strings(s.conf.dnsNames)
|
||||||
|
} else {
|
||||||
|
s.conf.dnsNames = append(s.conf.dnsNames, x.Subject.CommonName)
|
||||||
|
log.Debug("DNS: using DNS name from certificate's CN: %s", x.Subject.CommonName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
proxyConfig.TLSConfig = &tls.Config{
|
proxyConfig.TLSConfig = &tls.Config{
|
||||||
Certificates: []tls.Certificate{keypair},
|
GetCertificate: s.onGetCertificate,
|
||||||
MinVersion: tls.VersionTLS12,
|
MinVersion: tls.VersionTLS12,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -329,6 +353,53 @@ func (s *Server) Prepare(config *ServerConfig) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Find value in a sorted array
|
||||||
|
func findSorted(ar []string, val string) int {
|
||||||
|
i := sort.SearchStrings(ar, val)
|
||||||
|
if i == len(ar) || ar[i] != val {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
func isWildcard(host string) bool {
|
||||||
|
return len(host) >= 2 &&
|
||||||
|
host[0] == '*' && host[1] == '.'
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return TRUE if host name matches a wildcard pattern
|
||||||
|
func matchDomainWildcard(host, wildcard string) bool {
|
||||||
|
return isWildcard(wildcard) &&
|
||||||
|
strings.HasSuffix(host, wildcard[1:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return TRUE if client's SNI value matches DNS names from certificate
|
||||||
|
func matchDNSName(dnsNames []string, sni string) bool {
|
||||||
|
if utils.IsValidHostname(sni) != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if findSorted(dnsNames, sni) != -1 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, dn := range dnsNames {
|
||||||
|
if matchDomainWildcard(sni, dn) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Called by 'tls' package when Client Hello is received
|
||||||
|
// If the server name (from SNI) supplied by client is incorrect - we terminate the ongoing TLS handshake.
|
||||||
|
func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
|
if s.conf.StrictSNICheck && !matchDNSName(s.conf.dnsNames, ch.ServerName) {
|
||||||
|
log.Info("DNS: TLS: unknown SNI in Client Hello: %s", ch.ServerName)
|
||||||
|
return nil, fmt.Errorf("Invalid SNI")
|
||||||
|
}
|
||||||
|
return &s.conf.cert, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Stop stops the DNS server
|
// Stop stops the DNS server
|
||||||
func (s *Server) Stop() error {
|
func (s *Server) Stop() error {
|
||||||
s.Lock()
|
s.Lock()
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net"
|
"net"
|
||||||
|
"sort"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -887,3 +888,15 @@ func TestIpFromAddr(t *testing.T) {
|
||||||
a = ipFromAddr(nil)
|
a = ipFromAddr(nil)
|
||||||
assert.True(t, a == "")
|
assert.True(t, a == "")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMatchDNSName(t *testing.T) {
|
||||||
|
dnsNames := []string{"host1", "*.host2", "1.2.3.4"}
|
||||||
|
sort.Strings(dnsNames)
|
||||||
|
assert.True(t, matchDNSName(dnsNames, "host1"))
|
||||||
|
assert.True(t, matchDNSName(dnsNames, "a.host2"))
|
||||||
|
assert.True(t, matchDNSName(dnsNames, "b.a.host2"))
|
||||||
|
assert.True(t, matchDNSName(dnsNames, "1.2.3.4"))
|
||||||
|
assert.True(t, !matchDNSName(dnsNames, "host2"))
|
||||||
|
assert.True(t, !matchDNSName(dnsNames, ""))
|
||||||
|
assert.True(t, !matchDNSName(dnsNames, "*.host2"))
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue