diff --git a/CHANGELOG.md b/CHANGELOG.md index 9d07fc3c..0f888d72 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -64,6 +64,7 @@ and this project adheres to ### Fixed +- Client ID checking ([#3437]). - Discovering other DHCP servers on `darwin` and `freebsd` ([#3417]). - Switching listening address to unspecified one when bound to a single specified IPv4 address on Darwin (macOS) ([#2807]). @@ -122,6 +123,7 @@ and this project adheres to [#3351]: https://github.com/AdguardTeam/AdGuardHome/issues/3351 [#3372]: https://github.com/AdguardTeam/AdGuardHome/issues/3372 [#3417]: https://github.com/AdguardTeam/AdGuardHome/issues/3417 +[#3437]: https://github.com/AdguardTeam/AdGuardHome/issues/3437 diff --git a/internal/dnsforward/clientid.go b/internal/dnsforward/clientid.go index 00433358..3e92e8de 100644 --- a/internal/dnsforward/clientid.go +++ b/internal/dnsforward/clientid.go @@ -24,21 +24,39 @@ func ValidateClientID(clientID string) (err error) { return nil } +// hasLabelSuffix returns true if s ends with suffix preceded by a dot. It's +// a helper function to prevent unnecessary allocations in code like: +// +// if strings.HasSuffix(s, "." + suffix) { /* … */ } +// +// s must be longer than suffix. +func hasLabelSuffix(s, suffix string) (ok bool) { + return strings.HasSuffix(s, suffix) && s[len(s)-len(suffix)-1] == '.' +} + // clientIDFromClientServerName extracts and validates a client ID. hostSrvName // is the server name of the host. cliSrvName is the server name as sent by the // client. When strict is true, and client and host server name don't match, // clientIDFromClientServerName will return an error. -func clientIDFromClientServerName(hostSrvName, cliSrvName string, strict bool) (clientID string, err error) { +func clientIDFromClientServerName( + hostSrvName string, + cliSrvName string, + strict bool, +) (clientID string, err error) { if hostSrvName == cliSrvName { return "", nil } - if !strings.HasSuffix(cliSrvName, hostSrvName) { + if !hasLabelSuffix(cliSrvName, hostSrvName) { if !strict { return "", nil } - return "", fmt.Errorf("client server name %q doesn't match host server name %q", cliSrvName, hostSrvName) + return "", fmt.Errorf( + "client server name %q doesn't match host server name %q", + cliSrvName, + hostSrvName, + ) } clientID = cliSrvName[:len(cliSrvName)-len(hostSrvName)-1] diff --git a/internal/dnsforward/clientid_test.go b/internal/dnsforward/clientid_test.go index 7b14a836..f91eaeff 100644 --- a/internal/dnsforward/clientid_test.go +++ b/internal/dnsforward/clientid_test.go @@ -134,6 +134,15 @@ func TestServer_clientIDFromDNSContext(t *testing.T) { wantClientID: "cli", wantErrMsg: "", strictSNI: true, + }, { + name: "tls_client_id_issue3437", + proto: proxy.ProtoTLS, + hostSrvName: "example.com", + cliSrvName: "cli.myexample.com", + wantClientID: "", + wantErrMsg: `client id check: client server name "cli.myexample.com" ` + + `doesn't match host server name "example.com"`, + strictSNI: true, }} for _, tc := range testCases {