Pull request: dnsforward: do not check client srv name unless asked
Merge in DNS/adguard-home from 2664-non-strict-sni to master Updates #2664. Squashed commit of the following: commit e8d625fe3b1f06f97328809a3330b37e5bd578d7 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Thu Feb 11 14:46:52 2021 +0300 all: imp doc commit 10537b8bdf126eca9608353e57d92edba632232a Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Thu Feb 11 14:30:25 2021 +0300 dnsforward: do not check client srv name unless asked
This commit is contained in:
parent
f016ae172c
commit
841bb9bc35
@ -19,17 +19,20 @@ and this project adheres to
|
|||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
|
|
||||||
- Increase the HTTP API request body size limit for the `/control/access/set`
|
- The server name sent by clients of TLS APIs is not only checked when
|
||||||
API ([#2666]).
|
`strict_sni_check` is enabled ([#2664]).
|
||||||
|
- HTTP API request body size limit for the `/control/access/set` API is
|
||||||
|
increased ([#2666]).
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
|
|
||||||
- Set the request body size limit for HTTPS reqeusts as well.
|
- The request body size limit is now set for HTTPS requests as well.
|
||||||
- Incorrect version tag in the Docker release ([#2663]).
|
- Incorrect version tag in the Docker release ([#2663]).
|
||||||
- DNSCrypt queries weren't marked as such in logs ([#2662]).
|
- DNSCrypt queries weren't marked as such in logs ([#2662]).
|
||||||
|
|
||||||
[#2662]: https://github.com/AdguardTeam/AdGuardHome/issues/2662
|
[#2662]: https://github.com/AdguardTeam/AdGuardHome/issues/2662
|
||||||
[#2663]: https://github.com/AdguardTeam/AdGuardHome/issues/2663
|
[#2663]: https://github.com/AdguardTeam/AdGuardHome/issues/2663
|
||||||
|
[#2664]: https://github.com/AdguardTeam/AdGuardHome/issues/2664
|
||||||
[#2666]: https://github.com/AdguardTeam/AdGuardHome/issues/2666
|
[#2666]: https://github.com/AdguardTeam/AdGuardHome/issues/2666
|
||||||
|
|
||||||
|
|
||||||
|
165
internal/dnsforward/clientid.go
Normal file
165
internal/dnsforward/clientid.go
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
package dnsforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"path"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
|
"github.com/lucas-clemente/quic-go"
|
||||||
|
)
|
||||||
|
|
||||||
|
const maxDomainPartLen = 64
|
||||||
|
|
||||||
|
// ValidateClientID returns an error if clientID is not a valid client ID.
|
||||||
|
func ValidateClientID(clientID string) (err error) {
|
||||||
|
if len(clientID) > maxDomainPartLen {
|
||||||
|
return fmt.Errorf("client id %q is too long, max: %d", clientID, maxDomainPartLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, r := range clientID {
|
||||||
|
if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-' {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("invalid char %q at index %d in client id %q", r, i, clientID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// clientIDFromClientServerName extracts and validates a client ID. hostSrvName
|
||||||
|
// is the server name of the host. cliSrvName is the server name as sent by the
|
||||||
|
// client. When strict is true, and client and host server name don't match,
|
||||||
|
// clientIDFromClientServerName will return an error.
|
||||||
|
func clientIDFromClientServerName(hostSrvName, cliSrvName string, strict bool) (clientID string, err error) {
|
||||||
|
if hostSrvName == cliSrvName {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.HasSuffix(cliSrvName, hostSrvName) {
|
||||||
|
if !strict {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("client server name %q doesn't match host server name %q", cliSrvName, hostSrvName)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientID = cliSrvName[:len(cliSrvName)-len(hostSrvName)-1]
|
||||||
|
err = ValidateClientID(clientID)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("invalid client id: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return clientID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// processClientIDHTTPS extracts the client's ID from the path of the
|
||||||
|
// client's DNS-over-HTTPS request.
|
||||||
|
func processClientIDHTTPS(ctx *dnsContext) (rc resultCode) {
|
||||||
|
pctx := ctx.proxyCtx
|
||||||
|
r := pctx.HTTPRequest
|
||||||
|
if r == nil {
|
||||||
|
ctx.err = fmt.Errorf("proxy ctx http request of proto %s is nil", pctx.Proto)
|
||||||
|
|
||||||
|
return resultCodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
origPath := r.URL.Path
|
||||||
|
parts := strings.Split(path.Clean(origPath), "/")
|
||||||
|
if parts[0] == "" {
|
||||||
|
parts = parts[1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(parts) == 0 || parts[0] != "dns-query" {
|
||||||
|
ctx.err = fmt.Errorf("client id check: invalid path %q", origPath)
|
||||||
|
|
||||||
|
return resultCodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
clientID := ""
|
||||||
|
switch len(parts) {
|
||||||
|
case 1:
|
||||||
|
// Just /dns-query, no client ID.
|
||||||
|
return resultCodeSuccess
|
||||||
|
case 2:
|
||||||
|
clientID = parts[1]
|
||||||
|
default:
|
||||||
|
ctx.err = fmt.Errorf("client id check: invalid path %q: extra parts", origPath)
|
||||||
|
|
||||||
|
return resultCodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
err := ValidateClientID(clientID)
|
||||||
|
if err != nil {
|
||||||
|
ctx.err = fmt.Errorf("client id check: invalid client id: %w", err)
|
||||||
|
|
||||||
|
return resultCodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.clientID = clientID
|
||||||
|
|
||||||
|
return resultCodeSuccess
|
||||||
|
}
|
||||||
|
|
||||||
|
// tlsConn is a narrow interface for *tls.Conn to simplify testing.
|
||||||
|
type tlsConn interface {
|
||||||
|
ConnectionState() (cs tls.ConnectionState)
|
||||||
|
}
|
||||||
|
|
||||||
|
// quicSession is a narrow interface for quic.Session to simplify testing.
|
||||||
|
type quicSession interface {
|
||||||
|
ConnectionState() (cs quic.ConnectionState)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processClientID extracts the client's ID from the server name of the client's
|
||||||
|
// DOT or DOQ request or the path of the client's DOH.
|
||||||
|
func processClientID(dctx *dnsContext) (rc resultCode) {
|
||||||
|
pctx := dctx.proxyCtx
|
||||||
|
proto := pctx.Proto
|
||||||
|
if proto == proxy.ProtoHTTPS {
|
||||||
|
return processClientIDHTTPS(dctx)
|
||||||
|
} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
|
||||||
|
return resultCodeSuccess
|
||||||
|
}
|
||||||
|
|
||||||
|
srvConf := dctx.srv.conf
|
||||||
|
hostSrvName := srvConf.TLSConfig.ServerName
|
||||||
|
if hostSrvName == "" {
|
||||||
|
return resultCodeSuccess
|
||||||
|
}
|
||||||
|
|
||||||
|
cliSrvName := ""
|
||||||
|
if proto == proxy.ProtoTLS {
|
||||||
|
conn := pctx.Conn
|
||||||
|
tc, ok := conn.(tlsConn)
|
||||||
|
if !ok {
|
||||||
|
dctx.err = fmt.Errorf("proxy ctx conn of proto %s is %T, want *tls.Conn", proto, conn)
|
||||||
|
|
||||||
|
return resultCodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
cliSrvName = tc.ConnectionState().ServerName
|
||||||
|
} else if proto == proxy.ProtoQUIC {
|
||||||
|
qs, ok := pctx.QUICSession.(quicSession)
|
||||||
|
if !ok {
|
||||||
|
dctx.err = fmt.Errorf("proxy ctx quic session of proto %s is %T, want quic.Session", proto, pctx.QUICSession)
|
||||||
|
|
||||||
|
return resultCodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
cliSrvName = qs.ConnectionState().ServerName
|
||||||
|
}
|
||||||
|
|
||||||
|
clientID, err := clientIDFromClientServerName(hostSrvName, cliSrvName, srvConf.StrictSNICheck)
|
||||||
|
if err != nil {
|
||||||
|
dctx.err = fmt.Errorf("client id check: %w", err)
|
||||||
|
|
||||||
|
return resultCodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
dctx.clientID = clientID
|
||||||
|
|
||||||
|
return resultCodeSuccess
|
||||||
|
}
|
@ -10,6 +10,7 @@ import (
|
|||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/lucas-clemente/quic-go"
|
"github.com/lucas-clemente/quic-go"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
// testTLSConn is a tlsConn for tests.
|
// testTLSConn is a tlsConn for tests.
|
||||||
@ -53,6 +54,7 @@ func TestProcessClientID(t *testing.T) {
|
|||||||
wantClientID string
|
wantClientID string
|
||||||
wantErrMsg string
|
wantErrMsg string
|
||||||
wantRes resultCode
|
wantRes resultCode
|
||||||
|
strictSNI bool
|
||||||
}{{
|
}{{
|
||||||
name: "udp",
|
name: "udp",
|
||||||
proto: proxy.ProtoUDP,
|
proto: proxy.ProtoUDP,
|
||||||
@ -61,6 +63,7 @@ func TestProcessClientID(t *testing.T) {
|
|||||||
wantClientID: "",
|
wantClientID: "",
|
||||||
wantErrMsg: "",
|
wantErrMsg: "",
|
||||||
wantRes: resultCodeSuccess,
|
wantRes: resultCodeSuccess,
|
||||||
|
strictSNI: false,
|
||||||
}, {
|
}, {
|
||||||
name: "tls_no_client_id",
|
name: "tls_no_client_id",
|
||||||
proto: proxy.ProtoTLS,
|
proto: proxy.ProtoTLS,
|
||||||
@ -69,6 +72,26 @@ func TestProcessClientID(t *testing.T) {
|
|||||||
wantClientID: "",
|
wantClientID: "",
|
||||||
wantErrMsg: "",
|
wantErrMsg: "",
|
||||||
wantRes: resultCodeSuccess,
|
wantRes: resultCodeSuccess,
|
||||||
|
strictSNI: true,
|
||||||
|
}, {
|
||||||
|
name: "tls_no_client_server_name",
|
||||||
|
proto: proxy.ProtoTLS,
|
||||||
|
hostSrvName: "example.com",
|
||||||
|
cliSrvName: "",
|
||||||
|
wantClientID: "",
|
||||||
|
wantErrMsg: `client id check: client server name "" ` +
|
||||||
|
`doesn't match host server name "example.com"`,
|
||||||
|
wantRes: resultCodeError,
|
||||||
|
strictSNI: true,
|
||||||
|
}, {
|
||||||
|
name: "tls_no_client_server_name_no_strict",
|
||||||
|
proto: proxy.ProtoTLS,
|
||||||
|
hostSrvName: "example.com",
|
||||||
|
cliSrvName: "",
|
||||||
|
wantClientID: "",
|
||||||
|
wantErrMsg: "",
|
||||||
|
wantRes: resultCodeSuccess,
|
||||||
|
strictSNI: false,
|
||||||
}, {
|
}, {
|
||||||
name: "tls_client_id",
|
name: "tls_client_id",
|
||||||
proto: proxy.ProtoTLS,
|
proto: proxy.ProtoTLS,
|
||||||
@ -77,30 +100,39 @@ func TestProcessClientID(t *testing.T) {
|
|||||||
wantClientID: "cli",
|
wantClientID: "cli",
|
||||||
wantErrMsg: "",
|
wantErrMsg: "",
|
||||||
wantRes: resultCodeSuccess,
|
wantRes: resultCodeSuccess,
|
||||||
|
strictSNI: true,
|
||||||
}, {
|
}, {
|
||||||
name: "tls_client_id_hostname_error",
|
name: "tls_client_id_hostname_error",
|
||||||
proto: proxy.ProtoTLS,
|
proto: proxy.ProtoTLS,
|
||||||
hostSrvName: "example.com",
|
hostSrvName: "example.com",
|
||||||
cliSrvName: "cli.example.net",
|
cliSrvName: "cli.example.net",
|
||||||
wantClientID: "",
|
wantClientID: "",
|
||||||
wantErrMsg: `client id check: client server name "cli.example.net" doesn't match host server name "example.com"`,
|
wantErrMsg: `client id check: client server name "cli.example.net" ` +
|
||||||
wantRes: resultCodeError,
|
`doesn't match host server name "example.com"`,
|
||||||
|
wantRes: resultCodeError,
|
||||||
|
strictSNI: true,
|
||||||
}, {
|
}, {
|
||||||
name: "tls_invalid_client_id",
|
name: "tls_invalid_client_id",
|
||||||
proto: proxy.ProtoTLS,
|
proto: proxy.ProtoTLS,
|
||||||
hostSrvName: "example.com",
|
hostSrvName: "example.com",
|
||||||
cliSrvName: "!!!.example.com",
|
cliSrvName: "!!!.example.com",
|
||||||
wantClientID: "",
|
wantClientID: "",
|
||||||
wantErrMsg: `client id check: invalid client id: invalid char '!' at index 0 in client id "!!!"`,
|
wantErrMsg: `client id check: invalid client id: invalid char '!' ` +
|
||||||
wantRes: resultCodeError,
|
`at index 0 in client id "!!!"`,
|
||||||
|
wantRes: resultCodeError,
|
||||||
|
strictSNI: true,
|
||||||
}, {
|
}, {
|
||||||
name: "tls_client_id_too_long",
|
name: "tls_client_id_too_long",
|
||||||
proto: proxy.ProtoTLS,
|
proto: proxy.ProtoTLS,
|
||||||
hostSrvName: "example.com",
|
hostSrvName: "example.com",
|
||||||
cliSrvName: "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789.example.com",
|
cliSrvName: `abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmno` +
|
||||||
|
`pqrstuvwxyz0123456789.example.com`,
|
||||||
wantClientID: "",
|
wantClientID: "",
|
||||||
wantErrMsg: `client id check: invalid client id: client id "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789" is too long, max: 64`,
|
wantErrMsg: `client id check: invalid client id: client id "abcdefghijklmno` +
|
||||||
wantRes: resultCodeError,
|
`pqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789" ` +
|
||||||
|
`is too long, max: 64`,
|
||||||
|
wantRes: resultCodeError,
|
||||||
|
strictSNI: true,
|
||||||
}, {
|
}, {
|
||||||
name: "quic_client_id",
|
name: "quic_client_id",
|
||||||
proto: proxy.ProtoQUIC,
|
proto: proxy.ProtoQUIC,
|
||||||
@ -109,14 +141,17 @@ func TestProcessClientID(t *testing.T) {
|
|||||||
wantClientID: "cli",
|
wantClientID: "cli",
|
||||||
wantErrMsg: "",
|
wantErrMsg: "",
|
||||||
wantRes: resultCodeSuccess,
|
wantRes: resultCodeSuccess,
|
||||||
|
strictSNI: true,
|
||||||
}}
|
}}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
tlsConf := TLSConfig{
|
||||||
|
ServerName: tc.hostSrvName,
|
||||||
|
StrictSNICheck: tc.strictSNI,
|
||||||
|
}
|
||||||
srv := &Server{
|
srv := &Server{
|
||||||
conf: ServerConfig{
|
conf: ServerConfig{TLSConfig: tlsConf},
|
||||||
TLSConfig: TLSConfig{ServerName: tc.hostSrvName},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var conn net.Conn
|
var conn net.Conn
|
||||||
@ -146,10 +181,11 @@ func TestProcessClientID(t *testing.T) {
|
|||||||
assert.Equal(t, tc.wantRes, res)
|
assert.Equal(t, tc.wantRes, res)
|
||||||
assert.Equal(t, tc.wantClientID, dctx.clientID)
|
assert.Equal(t, tc.wantClientID, dctx.clientID)
|
||||||
|
|
||||||
if tc.wantErrMsg != "" && assert.NotNil(t, dctx.err) {
|
if tc.wantErrMsg == "" {
|
||||||
assert.Equal(t, tc.wantErrMsg, dctx.err.Error())
|
|
||||||
} else {
|
|
||||||
assert.Nil(t, dctx.err)
|
assert.Nil(t, dctx.err)
|
||||||
|
} else {
|
||||||
|
require.NotNil(t, dctx.err)
|
||||||
|
assert.Equal(t, tc.wantErrMsg, dctx.err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -202,8 +238,9 @@ func TestProcessClientID_https(t *testing.T) {
|
|||||||
name: "invalid_client_id",
|
name: "invalid_client_id",
|
||||||
path: "/dns-query/!!!",
|
path: "/dns-query/!!!",
|
||||||
wantClientID: "",
|
wantClientID: "",
|
||||||
wantErrMsg: `client id check: invalid client id: invalid char '!' at index 0 in client id "!!!"`,
|
wantErrMsg: `client id check: invalid client id: invalid char '!'` +
|
||||||
wantRes: resultCodeError,
|
` at index 0 in client id "!!!"`,
|
||||||
|
wantRes: resultCodeError,
|
||||||
}}
|
}}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
@ -225,10 +262,11 @@ func TestProcessClientID_https(t *testing.T) {
|
|||||||
assert.Equal(t, tc.wantRes, res)
|
assert.Equal(t, tc.wantRes, res)
|
||||||
assert.Equal(t, tc.wantClientID, dctx.clientID)
|
assert.Equal(t, tc.wantClientID, dctx.clientID)
|
||||||
|
|
||||||
if tc.wantErrMsg != "" && assert.NotNil(t, dctx.err) {
|
if tc.wantErrMsg == "" {
|
||||||
assert.Equal(t, tc.wantErrMsg, dctx.err.Error())
|
|
||||||
} else {
|
|
||||||
assert.Nil(t, dctx.err)
|
assert.Nil(t, dctx.err)
|
||||||
|
} else {
|
||||||
|
require.NotNil(t, dctx.err)
|
||||||
|
assert.Equal(t, tc.wantErrMsg, dctx.err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
@ -1,10 +1,7 @@
|
|||||||
package dnsforward
|
package dnsforward
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"path"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -13,7 +10,6 @@ import (
|
|||||||
"github.com/AdguardTeam/AdGuardHome/internal/util"
|
"github.com/AdguardTeam/AdGuardHome/internal/util"
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/lucas-clemente/quic-go"
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -234,154 +230,6 @@ func processInternalHosts(ctx *dnsContext) (rc resultCode) {
|
|||||||
return resultCodeSuccess
|
return resultCodeSuccess
|
||||||
}
|
}
|
||||||
|
|
||||||
const maxDomainPartLen = 64
|
|
||||||
|
|
||||||
// ValidateClientID returns an error if clientID is not a valid client ID.
|
|
||||||
func ValidateClientID(clientID string) (err error) {
|
|
||||||
if len(clientID) > maxDomainPartLen {
|
|
||||||
return fmt.Errorf("client id %q is too long, max: %d", clientID, maxDomainPartLen)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, r := range clientID {
|
|
||||||
if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-' {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("invalid char %q at index %d in client id %q", r, i, clientID)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// clientIDFromClientServerName extracts and validates a client ID. hostSrvName
|
|
||||||
// is the server name of the host. cliSrvName is the server name as sent by the
|
|
||||||
// client.
|
|
||||||
func clientIDFromClientServerName(hostSrvName, cliSrvName string) (clientID string, err error) {
|
|
||||||
if hostSrvName == cliSrvName {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.HasSuffix(cliSrvName, hostSrvName) {
|
|
||||||
return "", fmt.Errorf("client server name %q doesn't match host server name %q", cliSrvName, hostSrvName)
|
|
||||||
}
|
|
||||||
|
|
||||||
clientID = cliSrvName[:len(cliSrvName)-len(hostSrvName)-1]
|
|
||||||
err = ValidateClientID(clientID)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("invalid client id: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return clientID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// processClientIDHTTPS extracts the client's ID from the path of the
|
|
||||||
// client's DNS-over-HTTPS request.
|
|
||||||
func processClientIDHTTPS(ctx *dnsContext) (rc resultCode) {
|
|
||||||
pctx := ctx.proxyCtx
|
|
||||||
r := pctx.HTTPRequest
|
|
||||||
if r == nil {
|
|
||||||
ctx.err = fmt.Errorf("proxy ctx http request of proto %s is nil", pctx.Proto)
|
|
||||||
|
|
||||||
return resultCodeError
|
|
||||||
}
|
|
||||||
|
|
||||||
origPath := r.URL.Path
|
|
||||||
parts := strings.Split(path.Clean(origPath), "/")
|
|
||||||
if parts[0] == "" {
|
|
||||||
parts = parts[1:]
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(parts) == 0 || parts[0] != "dns-query" {
|
|
||||||
ctx.err = fmt.Errorf("client id check: invalid path %q", origPath)
|
|
||||||
|
|
||||||
return resultCodeError
|
|
||||||
}
|
|
||||||
|
|
||||||
clientID := ""
|
|
||||||
switch len(parts) {
|
|
||||||
case 1:
|
|
||||||
// Just /dns-query, no client ID.
|
|
||||||
return resultCodeSuccess
|
|
||||||
case 2:
|
|
||||||
clientID = parts[1]
|
|
||||||
default:
|
|
||||||
ctx.err = fmt.Errorf("client id check: invalid path %q: extra parts", origPath)
|
|
||||||
|
|
||||||
return resultCodeError
|
|
||||||
}
|
|
||||||
|
|
||||||
err := ValidateClientID(clientID)
|
|
||||||
if err != nil {
|
|
||||||
ctx.err = fmt.Errorf("client id check: invalid client id: %w", err)
|
|
||||||
|
|
||||||
return resultCodeError
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx.clientID = clientID
|
|
||||||
|
|
||||||
return resultCodeSuccess
|
|
||||||
}
|
|
||||||
|
|
||||||
// tlsConn is a narrow interface for *tls.Conn to simplify testing.
|
|
||||||
type tlsConn interface {
|
|
||||||
ConnectionState() (cs tls.ConnectionState)
|
|
||||||
}
|
|
||||||
|
|
||||||
// quicSession is a narrow interface for quic.Session to simplify testing.
|
|
||||||
type quicSession interface {
|
|
||||||
ConnectionState() (cs quic.ConnectionState)
|
|
||||||
}
|
|
||||||
|
|
||||||
// processClientID extracts the client's ID from the server name of the client's
|
|
||||||
// DOT or DOQ request or the path of the client's DOH.
|
|
||||||
func processClientID(ctx *dnsContext) (rc resultCode) {
|
|
||||||
pctx := ctx.proxyCtx
|
|
||||||
proto := pctx.Proto
|
|
||||||
if proto == proxy.ProtoHTTPS {
|
|
||||||
return processClientIDHTTPS(ctx)
|
|
||||||
} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
|
|
||||||
return resultCodeSuccess
|
|
||||||
}
|
|
||||||
|
|
||||||
hostSrvName := ctx.srv.conf.TLSConfig.ServerName
|
|
||||||
if hostSrvName == "" {
|
|
||||||
return resultCodeSuccess
|
|
||||||
}
|
|
||||||
|
|
||||||
cliSrvName := ""
|
|
||||||
if proto == proxy.ProtoTLS {
|
|
||||||
conn := pctx.Conn
|
|
||||||
tc, ok := conn.(tlsConn)
|
|
||||||
if !ok {
|
|
||||||
ctx.err = fmt.Errorf("proxy ctx conn of proto %s is %T, want *tls.Conn", proto, conn)
|
|
||||||
|
|
||||||
return resultCodeError
|
|
||||||
}
|
|
||||||
|
|
||||||
cliSrvName = tc.ConnectionState().ServerName
|
|
||||||
} else if proto == proxy.ProtoQUIC {
|
|
||||||
qs, ok := pctx.QUICSession.(quicSession)
|
|
||||||
if !ok {
|
|
||||||
ctx.err = fmt.Errorf("proxy ctx quic session of proto %s is %T, want quic.Session", proto, pctx.QUICSession)
|
|
||||||
|
|
||||||
return resultCodeError
|
|
||||||
}
|
|
||||||
|
|
||||||
cliSrvName = qs.ConnectionState().ServerName
|
|
||||||
}
|
|
||||||
|
|
||||||
clientID, err := clientIDFromClientServerName(hostSrvName, cliSrvName)
|
|
||||||
if err != nil {
|
|
||||||
ctx.err = fmt.Errorf("client id check: %w", err)
|
|
||||||
|
|
||||||
return resultCodeError
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx.clientID = clientID
|
|
||||||
|
|
||||||
return resultCodeSuccess
|
|
||||||
}
|
|
||||||
|
|
||||||
// Respond to PTR requests if the target IP address is leased by our DHCP server
|
// Respond to PTR requests if the target IP address is leased by our DHCP server
|
||||||
func processInternalIPAddrs(ctx *dnsContext) (rc resultCode) {
|
func processInternalIPAddrs(ctx *dnsContext) (rc resultCode) {
|
||||||
s := ctx.srv
|
s := ctx.srv
|
||||||
|
Loading…
Reference in New Issue
Block a user