+ DNS: TLS handshake: terminate handshake on bad SNI

This commit is contained in:
Simon Zolin 2020-01-22 18:40:43 +03:00
parent ac156b9612
commit bfd1f3b650
2 changed files with 87 additions and 3 deletions

View File

@ -2,10 +2,12 @@ package dnsforward
import ( import (
"crypto/tls" "crypto/tls"
"crypto/x509"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"runtime" "runtime"
"sort"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -16,6 +18,7 @@ import (
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/utils"
"github.com/joomcode/errorx" "github.com/joomcode/errorx"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -151,6 +154,7 @@ type FilteringConfig struct {
// TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS // TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS
type TLSConfig struct { type TLSConfig struct {
TLSListenAddr *net.TCPAddr `yaml:"-" json:"-"` TLSListenAddr *net.TCPAddr `yaml:"-" json:"-"`
StrictSNICheck bool `yaml:"strict_sni_check" json:"-"` // Reject connection if the client uses server name (in SNI) that doesn't match the certificate
CertificateChain string `yaml:"certificate_chain" json:"certificate_chain"` // PEM-encoded certificates chain CertificateChain string `yaml:"certificate_chain" json:"certificate_chain"` // PEM-encoded certificates chain
PrivateKey string `yaml:"private_key" json:"private_key"` // PEM-encoded private key PrivateKey string `yaml:"private_key" json:"private_key"` // PEM-encoded private key
@ -159,6 +163,9 @@ type TLSConfig struct {
CertificateChainData []byte `yaml:"-" json:"-"` CertificateChainData []byte `yaml:"-" json:"-"`
PrivateKeyData []byte `yaml:"-" json:"-"` PrivateKeyData []byte `yaml:"-" json:"-"`
cert tls.Certificate // nolint(structcheck) - linter thinks that this field is unused, while TLSConfig is directly included into ServerConfig
dnsNames []string // nolint(structcheck) // DNS names from certificate (SAN) or CN value from Subject
} }
// ServerConfig represents server configuration. // ServerConfig represents server configuration.
@ -234,6 +241,7 @@ func (s *Server) startInternal() error {
} }
// Prepare the object // Prepare the object
// nolint(gocyclo)
func (s *Server) Prepare(config *ServerConfig) error { func (s *Server) Prepare(config *ServerConfig) error {
if config != nil { if config != nil {
s.conf = *config s.conf = *config
@ -305,12 +313,28 @@ func (s *Server) Prepare(config *ServerConfig) error {
if s.conf.TLSListenAddr != nil && len(s.conf.CertificateChainData) != 0 && len(s.conf.PrivateKeyData) != 0 { if s.conf.TLSListenAddr != nil && len(s.conf.CertificateChainData) != 0 && len(s.conf.PrivateKeyData) != 0 {
proxyConfig.TLSListenAddr = s.conf.TLSListenAddr proxyConfig.TLSListenAddr = s.conf.TLSListenAddr
keypair, err := tls.X509KeyPair(s.conf.CertificateChainData, s.conf.PrivateKeyData) s.conf.cert, err = tls.X509KeyPair(s.conf.CertificateChainData, s.conf.PrivateKeyData)
if err != nil { if err != nil {
return errorx.Decorate(err, "Failed to parse TLS keypair") return errorx.Decorate(err, "Failed to parse TLS keypair")
} }
if s.conf.StrictSNICheck {
x, err := x509.ParseCertificate(s.conf.cert.Certificate[0])
if err != nil {
return errorx.Decorate(err, "x509.ParseCertificate(): %s", err)
}
if len(x.DNSNames) != 0 {
s.conf.dnsNames = x.DNSNames
log.Debug("DNS: using DNS names from certificate's SAN: %v", x.DNSNames)
sort.Strings(s.conf.dnsNames)
} else {
s.conf.dnsNames = append(s.conf.dnsNames, x.Subject.CommonName)
log.Debug("DNS: using DNS name from certificate's CN: %s", x.Subject.CommonName)
}
}
proxyConfig.TLSConfig = &tls.Config{ proxyConfig.TLSConfig = &tls.Config{
Certificates: []tls.Certificate{keypair}, GetCertificate: s.onGetCertificate,
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
} }
} }
@ -329,6 +353,53 @@ func (s *Server) Prepare(config *ServerConfig) error {
return nil return nil
} }
// Find value in a sorted array
func findSorted(ar []string, val string) int {
i := sort.SearchStrings(ar, val)
if i == len(ar) || ar[i] != val {
return -1
}
return i
}
func isWildcard(host string) bool {
return len(host) >= 2 &&
host[0] == '*' && host[1] == '.'
}
// Return TRUE if host name matches a wildcard pattern
func matchDomainWildcard(host, wildcard string) bool {
return isWildcard(wildcard) &&
strings.HasSuffix(host, wildcard[1:])
}
// Return TRUE if client's SNI value matches DNS names from certificate
func matchDNSName(dnsNames []string, sni string) bool {
if utils.IsValidHostname(sni) != nil {
return false
}
if findSorted(dnsNames, sni) != -1 {
return true
}
for _, dn := range dnsNames {
if matchDomainWildcard(sni, dn) {
return true
}
}
return false
}
// Called by 'tls' package when Client Hello is received
// If the server name (from SNI) supplied by client is incorrect - we terminate the ongoing TLS handshake.
func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, error) {
if s.conf.StrictSNICheck && !matchDNSName(s.conf.dnsNames, ch.ServerName) {
log.Info("DNS: TLS: unknown SNI in Client Hello: %s", ch.ServerName)
return nil, fmt.Errorf("Invalid SNI")
}
return &s.conf.cert, nil
}
// Stop stops the DNS server // Stop stops the DNS server
func (s *Server) Stop() error { func (s *Server) Stop() error {
s.Lock() s.Lock()

View File

@ -10,6 +10,7 @@ import (
"encoding/pem" "encoding/pem"
"math/big" "math/big"
"net" "net"
"sort"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -887,3 +888,15 @@ func TestIpFromAddr(t *testing.T) {
a = ipFromAddr(nil) a = ipFromAddr(nil)
assert.True(t, a == "") assert.True(t, a == "")
} }
func TestMatchDNSName(t *testing.T) {
dnsNames := []string{"host1", "*.host2", "1.2.3.4"}
sort.Strings(dnsNames)
assert.True(t, matchDNSName(dnsNames, "host1"))
assert.True(t, matchDNSName(dnsNames, "a.host2"))
assert.True(t, matchDNSName(dnsNames, "b.a.host2"))
assert.True(t, matchDNSName(dnsNames, "1.2.3.4"))
assert.True(t, !matchDNSName(dnsNames, "host2"))
assert.True(t, !matchDNSName(dnsNames, ""))
assert.True(t, !matchDNSName(dnsNames, "*.host2"))
}