diff --git a/CHANGELOG.md b/CHANGELOG.md index 7560957b..bfae45f1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,17 +19,20 @@ and this project adheres to ### Changed -- Increase the HTTP API request body size limit for the `/control/access/set` - API ([#2666]). +- The server name sent by clients of TLS APIs is not only checked when + `strict_sni_check` is enabled ([#2664]). +- HTTP API request body size limit for the `/control/access/set` API is + increased ([#2666]). ### Fixed -- Set the request body size limit for HTTPS reqeusts as well. +- The request body size limit is now set for HTTPS requests as well. - Incorrect version tag in the Docker release ([#2663]). - DNSCrypt queries weren't marked as such in logs ([#2662]). [#2662]: https://github.com/AdguardTeam/AdGuardHome/issues/2662 [#2663]: https://github.com/AdguardTeam/AdGuardHome/issues/2663 +[#2664]: https://github.com/AdguardTeam/AdGuardHome/issues/2664 [#2666]: https://github.com/AdguardTeam/AdGuardHome/issues/2666 diff --git a/internal/dnsforward/clientid.go b/internal/dnsforward/clientid.go new file mode 100644 index 00000000..c497c7b7 --- /dev/null +++ b/internal/dnsforward/clientid.go @@ -0,0 +1,165 @@ +package dnsforward + +import ( + "crypto/tls" + "fmt" + "path" + "strings" + + "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/lucas-clemente/quic-go" +) + +const maxDomainPartLen = 64 + +// ValidateClientID returns an error if clientID is not a valid client ID. +func ValidateClientID(clientID string) (err error) { + if len(clientID) > maxDomainPartLen { + return fmt.Errorf("client id %q is too long, max: %d", clientID, maxDomainPartLen) + } + + for i, r := range clientID { + if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-' { + continue + } + + return fmt.Errorf("invalid char %q at index %d in client id %q", r, i, clientID) + } + + return nil +} + +// clientIDFromClientServerName extracts and validates a client ID. hostSrvName +// is the server name of the host. cliSrvName is the server name as sent by the +// client. When strict is true, and client and host server name don't match, +// clientIDFromClientServerName will return an error. +func clientIDFromClientServerName(hostSrvName, cliSrvName string, strict bool) (clientID string, err error) { + if hostSrvName == cliSrvName { + return "", nil + } + + if !strings.HasSuffix(cliSrvName, hostSrvName) { + if !strict { + return "", nil + } + + return "", fmt.Errorf("client server name %q doesn't match host server name %q", cliSrvName, hostSrvName) + } + + clientID = cliSrvName[:len(cliSrvName)-len(hostSrvName)-1] + err = ValidateClientID(clientID) + if err != nil { + return "", fmt.Errorf("invalid client id: %w", err) + } + + return clientID, nil +} + +// processClientIDHTTPS extracts the client's ID from the path of the +// client's DNS-over-HTTPS request. +func processClientIDHTTPS(ctx *dnsContext) (rc resultCode) { + pctx := ctx.proxyCtx + r := pctx.HTTPRequest + if r == nil { + ctx.err = fmt.Errorf("proxy ctx http request of proto %s is nil", pctx.Proto) + + return resultCodeError + } + + origPath := r.URL.Path + parts := strings.Split(path.Clean(origPath), "/") + if parts[0] == "" { + parts = parts[1:] + } + + if len(parts) == 0 || parts[0] != "dns-query" { + ctx.err = fmt.Errorf("client id check: invalid path %q", origPath) + + return resultCodeError + } + + clientID := "" + switch len(parts) { + case 1: + // Just /dns-query, no client ID. + return resultCodeSuccess + case 2: + clientID = parts[1] + default: + ctx.err = fmt.Errorf("client id check: invalid path %q: extra parts", origPath) + + return resultCodeError + } + + err := ValidateClientID(clientID) + if err != nil { + ctx.err = fmt.Errorf("client id check: invalid client id: %w", err) + + return resultCodeError + } + + ctx.clientID = clientID + + return resultCodeSuccess +} + +// tlsConn is a narrow interface for *tls.Conn to simplify testing. +type tlsConn interface { + ConnectionState() (cs tls.ConnectionState) +} + +// quicSession is a narrow interface for quic.Session to simplify testing. +type quicSession interface { + ConnectionState() (cs quic.ConnectionState) +} + +// processClientID extracts the client's ID from the server name of the client's +// DOT or DOQ request or the path of the client's DOH. +func processClientID(dctx *dnsContext) (rc resultCode) { + pctx := dctx.proxyCtx + proto := pctx.Proto + if proto == proxy.ProtoHTTPS { + return processClientIDHTTPS(dctx) + } else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC { + return resultCodeSuccess + } + + srvConf := dctx.srv.conf + hostSrvName := srvConf.TLSConfig.ServerName + if hostSrvName == "" { + return resultCodeSuccess + } + + cliSrvName := "" + if proto == proxy.ProtoTLS { + conn := pctx.Conn + tc, ok := conn.(tlsConn) + if !ok { + dctx.err = fmt.Errorf("proxy ctx conn of proto %s is %T, want *tls.Conn", proto, conn) + + return resultCodeError + } + + cliSrvName = tc.ConnectionState().ServerName + } else if proto == proxy.ProtoQUIC { + qs, ok := pctx.QUICSession.(quicSession) + if !ok { + dctx.err = fmt.Errorf("proxy ctx quic session of proto %s is %T, want quic.Session", proto, pctx.QUICSession) + + return resultCodeError + } + + cliSrvName = qs.ConnectionState().ServerName + } + + clientID, err := clientIDFromClientServerName(hostSrvName, cliSrvName, srvConf.StrictSNICheck) + if err != nil { + dctx.err = fmt.Errorf("client id check: %w", err) + + return resultCodeError + } + + dctx.clientID = clientID + + return resultCodeSuccess +} diff --git a/internal/dnsforward/dns_test.go b/internal/dnsforward/clientid_test.go similarity index 73% rename from internal/dnsforward/dns_test.go rename to internal/dnsforward/clientid_test.go index bd0ef4ab..503203f9 100644 --- a/internal/dnsforward/dns_test.go +++ b/internal/dnsforward/clientid_test.go @@ -10,6 +10,7 @@ import ( "github.com/AdguardTeam/dnsproxy/proxy" "github.com/lucas-clemente/quic-go" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // testTLSConn is a tlsConn for tests. @@ -53,6 +54,7 @@ func TestProcessClientID(t *testing.T) { wantClientID string wantErrMsg string wantRes resultCode + strictSNI bool }{{ name: "udp", proto: proxy.ProtoUDP, @@ -61,6 +63,7 @@ func TestProcessClientID(t *testing.T) { wantClientID: "", wantErrMsg: "", wantRes: resultCodeSuccess, + strictSNI: false, }, { name: "tls_no_client_id", proto: proxy.ProtoTLS, @@ -69,6 +72,26 @@ func TestProcessClientID(t *testing.T) { wantClientID: "", wantErrMsg: "", wantRes: resultCodeSuccess, + strictSNI: true, + }, { + name: "tls_no_client_server_name", + proto: proxy.ProtoTLS, + hostSrvName: "example.com", + cliSrvName: "", + wantClientID: "", + wantErrMsg: `client id check: client server name "" ` + + `doesn't match host server name "example.com"`, + wantRes: resultCodeError, + strictSNI: true, + }, { + name: "tls_no_client_server_name_no_strict", + proto: proxy.ProtoTLS, + hostSrvName: "example.com", + cliSrvName: "", + wantClientID: "", + wantErrMsg: "", + wantRes: resultCodeSuccess, + strictSNI: false, }, { name: "tls_client_id", proto: proxy.ProtoTLS, @@ -77,30 +100,39 @@ func TestProcessClientID(t *testing.T) { wantClientID: "cli", wantErrMsg: "", wantRes: resultCodeSuccess, + strictSNI: true, }, { name: "tls_client_id_hostname_error", proto: proxy.ProtoTLS, hostSrvName: "example.com", cliSrvName: "cli.example.net", wantClientID: "", - wantErrMsg: `client id check: client server name "cli.example.net" doesn't match host server name "example.com"`, - wantRes: resultCodeError, + wantErrMsg: `client id check: client server name "cli.example.net" ` + + `doesn't match host server name "example.com"`, + wantRes: resultCodeError, + strictSNI: true, }, { name: "tls_invalid_client_id", proto: proxy.ProtoTLS, hostSrvName: "example.com", cliSrvName: "!!!.example.com", wantClientID: "", - wantErrMsg: `client id check: invalid client id: invalid char '!' at index 0 in client id "!!!"`, - wantRes: resultCodeError, + wantErrMsg: `client id check: invalid client id: invalid char '!' ` + + `at index 0 in client id "!!!"`, + wantRes: resultCodeError, + strictSNI: true, }, { - name: "tls_client_id_too_long", - proto: proxy.ProtoTLS, - hostSrvName: "example.com", - cliSrvName: "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789.example.com", + name: "tls_client_id_too_long", + proto: proxy.ProtoTLS, + hostSrvName: "example.com", + cliSrvName: `abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmno` + + `pqrstuvwxyz0123456789.example.com`, wantClientID: "", - wantErrMsg: `client id check: invalid client id: client id "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789" is too long, max: 64`, - wantRes: resultCodeError, + wantErrMsg: `client id check: invalid client id: client id "abcdefghijklmno` + + `pqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789" ` + + `is too long, max: 64`, + wantRes: resultCodeError, + strictSNI: true, }, { name: "quic_client_id", proto: proxy.ProtoQUIC, @@ -109,14 +141,17 @@ func TestProcessClientID(t *testing.T) { wantClientID: "cli", wantErrMsg: "", wantRes: resultCodeSuccess, + strictSNI: true, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + tlsConf := TLSConfig{ + ServerName: tc.hostSrvName, + StrictSNICheck: tc.strictSNI, + } srv := &Server{ - conf: ServerConfig{ - TLSConfig: TLSConfig{ServerName: tc.hostSrvName}, - }, + conf: ServerConfig{TLSConfig: tlsConf}, } var conn net.Conn @@ -146,10 +181,11 @@ func TestProcessClientID(t *testing.T) { assert.Equal(t, tc.wantRes, res) assert.Equal(t, tc.wantClientID, dctx.clientID) - if tc.wantErrMsg != "" && assert.NotNil(t, dctx.err) { - assert.Equal(t, tc.wantErrMsg, dctx.err.Error()) - } else { + if tc.wantErrMsg == "" { assert.Nil(t, dctx.err) + } else { + require.NotNil(t, dctx.err) + assert.Equal(t, tc.wantErrMsg, dctx.err.Error()) } }) } @@ -202,8 +238,9 @@ func TestProcessClientID_https(t *testing.T) { name: "invalid_client_id", path: "/dns-query/!!!", wantClientID: "", - wantErrMsg: `client id check: invalid client id: invalid char '!' at index 0 in client id "!!!"`, - wantRes: resultCodeError, + wantErrMsg: `client id check: invalid client id: invalid char '!'` + + ` at index 0 in client id "!!!"`, + wantRes: resultCodeError, }} for _, tc := range testCases { @@ -225,10 +262,11 @@ func TestProcessClientID_https(t *testing.T) { assert.Equal(t, tc.wantRes, res) assert.Equal(t, tc.wantClientID, dctx.clientID) - if tc.wantErrMsg != "" && assert.NotNil(t, dctx.err) { - assert.Equal(t, tc.wantErrMsg, dctx.err.Error()) - } else { + if tc.wantErrMsg == "" { assert.Nil(t, dctx.err) + } else { + require.NotNil(t, dctx.err) + assert.Equal(t, tc.wantErrMsg, dctx.err.Error()) } }) } diff --git a/internal/dnsforward/dns.go b/internal/dnsforward/dns.go index f8e7bff0..acc6aa86 100644 --- a/internal/dnsforward/dns.go +++ b/internal/dnsforward/dns.go @@ -1,10 +1,7 @@ package dnsforward import ( - "crypto/tls" - "fmt" "net" - "path" "strings" "time" @@ -13,7 +10,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/util" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/golibs/log" - "github.com/lucas-clemente/quic-go" "github.com/miekg/dns" ) @@ -234,154 +230,6 @@ func processInternalHosts(ctx *dnsContext) (rc resultCode) { return resultCodeSuccess } -const maxDomainPartLen = 64 - -// ValidateClientID returns an error if clientID is not a valid client ID. -func ValidateClientID(clientID string) (err error) { - if len(clientID) > maxDomainPartLen { - return fmt.Errorf("client id %q is too long, max: %d", clientID, maxDomainPartLen) - } - - for i, r := range clientID { - if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-' { - continue - } - - return fmt.Errorf("invalid char %q at index %d in client id %q", r, i, clientID) - } - - return nil -} - -// clientIDFromClientServerName extracts and validates a client ID. hostSrvName -// is the server name of the host. cliSrvName is the server name as sent by the -// client. -func clientIDFromClientServerName(hostSrvName, cliSrvName string) (clientID string, err error) { - if hostSrvName == cliSrvName { - return "", nil - } - - if !strings.HasSuffix(cliSrvName, hostSrvName) { - return "", fmt.Errorf("client server name %q doesn't match host server name %q", cliSrvName, hostSrvName) - } - - clientID = cliSrvName[:len(cliSrvName)-len(hostSrvName)-1] - err = ValidateClientID(clientID) - if err != nil { - return "", fmt.Errorf("invalid client id: %w", err) - } - - return clientID, nil -} - -// processClientIDHTTPS extracts the client's ID from the path of the -// client's DNS-over-HTTPS request. -func processClientIDHTTPS(ctx *dnsContext) (rc resultCode) { - pctx := ctx.proxyCtx - r := pctx.HTTPRequest - if r == nil { - ctx.err = fmt.Errorf("proxy ctx http request of proto %s is nil", pctx.Proto) - - return resultCodeError - } - - origPath := r.URL.Path - parts := strings.Split(path.Clean(origPath), "/") - if parts[0] == "" { - parts = parts[1:] - } - - if len(parts) == 0 || parts[0] != "dns-query" { - ctx.err = fmt.Errorf("client id check: invalid path %q", origPath) - - return resultCodeError - } - - clientID := "" - switch len(parts) { - case 1: - // Just /dns-query, no client ID. - return resultCodeSuccess - case 2: - clientID = parts[1] - default: - ctx.err = fmt.Errorf("client id check: invalid path %q: extra parts", origPath) - - return resultCodeError - } - - err := ValidateClientID(clientID) - if err != nil { - ctx.err = fmt.Errorf("client id check: invalid client id: %w", err) - - return resultCodeError - } - - ctx.clientID = clientID - - return resultCodeSuccess -} - -// tlsConn is a narrow interface for *tls.Conn to simplify testing. -type tlsConn interface { - ConnectionState() (cs tls.ConnectionState) -} - -// quicSession is a narrow interface for quic.Session to simplify testing. -type quicSession interface { - ConnectionState() (cs quic.ConnectionState) -} - -// processClientID extracts the client's ID from the server name of the client's -// DOT or DOQ request or the path of the client's DOH. -func processClientID(ctx *dnsContext) (rc resultCode) { - pctx := ctx.proxyCtx - proto := pctx.Proto - if proto == proxy.ProtoHTTPS { - return processClientIDHTTPS(ctx) - } else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC { - return resultCodeSuccess - } - - hostSrvName := ctx.srv.conf.TLSConfig.ServerName - if hostSrvName == "" { - return resultCodeSuccess - } - - cliSrvName := "" - if proto == proxy.ProtoTLS { - conn := pctx.Conn - tc, ok := conn.(tlsConn) - if !ok { - ctx.err = fmt.Errorf("proxy ctx conn of proto %s is %T, want *tls.Conn", proto, conn) - - return resultCodeError - } - - cliSrvName = tc.ConnectionState().ServerName - } else if proto == proxy.ProtoQUIC { - qs, ok := pctx.QUICSession.(quicSession) - if !ok { - ctx.err = fmt.Errorf("proxy ctx quic session of proto %s is %T, want quic.Session", proto, pctx.QUICSession) - - return resultCodeError - } - - cliSrvName = qs.ConnectionState().ServerName - } - - clientID, err := clientIDFromClientServerName(hostSrvName, cliSrvName) - if err != nil { - ctx.err = fmt.Errorf("client id check: %w", err) - - return resultCodeError - } - - ctx.clientID = clientID - - return resultCodeSuccess -} - // Respond to PTR requests if the target IP address is leased by our DHCP server func processInternalIPAddrs(ctx *dnsContext) (rc resultCode) { s := ctx.srv