Pull request: all: fix client upstreams, imp code
Updates #3186. Squashed commit of the following: commit a8dd0e2cda3039839d069fe71a5bd0f9635ec064 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Fri May 28 12:54:07 2021 +0300 all: imp code, names commit 98f86c21ae23b665095075feb4a59dcfcc622bc7 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Thu May 27 21:11:37 2021 +0300 all: fix client upstreams, imp code
This commit is contained in:
parent
48b8579703
commit
3be783bd34
|
@ -32,6 +32,8 @@ released by then.
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
|
|
||||||
|
- Custom upstreams selection for clients with client IDs in DNS-over-TLS and
|
||||||
|
DNS-over-HTTP ([#3186]).
|
||||||
- Incorrect client-based filtering applying logic ([#2875]).
|
- Incorrect client-based filtering applying logic ([#2875]).
|
||||||
|
|
||||||
### Removed
|
### Removed
|
||||||
|
@ -40,6 +42,7 @@ released by then.
|
||||||
|
|
||||||
[#3184]: https://github.com/AdguardTeam/AdGuardHome/issues/3184
|
[#3184]: https://github.com/AdguardTeam/AdGuardHome/issues/3184
|
||||||
[#3185]: https://github.com/AdguardTeam/AdGuardHome/issues/3185
|
[#3185]: https://github.com/AdguardTeam/AdGuardHome/issues/3185
|
||||||
|
[#3186]: https://github.com/AdguardTeam/AdGuardHome/issues/3186
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,19 @@ import (
|
||||||
"golang.org/x/net/idna"
|
"golang.org/x/net/idna"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// IPFromAddr returns an IP address from addr. If addr is neither
|
||||||
|
// a *net.TCPAddr nor a *net.UDPAddr, it returns nil.
|
||||||
|
func IPFromAddr(addr net.Addr) (ip net.IP) {
|
||||||
|
switch addr := addr.(type) {
|
||||||
|
case *net.TCPAddr:
|
||||||
|
return addr.IP
|
||||||
|
case *net.UDPAddr:
|
||||||
|
return addr.IP
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// IsValidHostOuterRune returns true if r is a valid initial or final rune for
|
// IsValidHostOuterRune returns true if r is a valid initial or final rune for
|
||||||
// a hostname label.
|
// a hostname label.
|
||||||
func IsValidHostOuterRune(r rune) (ok bool) {
|
func IsValidHostOuterRune(r rune) (ok bool) {
|
||||||
|
|
|
@ -9,6 +9,14 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestIPFromAddr(t *testing.T) {
|
||||||
|
ip := net.IP{1, 2, 3, 4}
|
||||||
|
assert.Equal(t, net.IP(nil), IPFromAddr(nil))
|
||||||
|
assert.Equal(t, net.IP(nil), IPFromAddr(struct{ net.Addr }{}))
|
||||||
|
assert.Equal(t, ip, IPFromAddr(&net.TCPAddr{IP: ip}))
|
||||||
|
assert.Equal(t, ip, IPFromAddr(&net.UDPAddr{IP: ip}))
|
||||||
|
}
|
||||||
|
|
||||||
func TestValidateHardwareAddress(t *testing.T) {
|
func TestValidateHardwareAddress(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
|
|
|
@ -19,6 +19,19 @@ func CloneSlice(a []string) (b []string) {
|
||||||
return CloneSliceOrEmpty(a)
|
return CloneSliceOrEmpty(a)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Coalesce returns the first non-empty string. It is named after the function
|
||||||
|
// COALESCE in SQL except that since strings in Go are non-nullable, it uses an
|
||||||
|
// empty string as a NULL value. If strs is empty, it returns an empty string.
|
||||||
|
func Coalesce(strs ...string) (res string) {
|
||||||
|
for _, s := range strs {
|
||||||
|
if s != "" {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// FilterOut returns a copy of strs with all strings for which f returned true
|
// FilterOut returns a copy of strs with all strings for which f returned true
|
||||||
// removed.
|
// removed.
|
||||||
func FilterOut(strs []string, f func(s string) (ok bool)) (filtered []string) {
|
func FilterOut(strs []string, f func(s string) (ok bool)) (filtered []string) {
|
||||||
|
|
|
@ -36,6 +36,14 @@ func TestCloneSlice_family(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCoalesce(t *testing.T) {
|
||||||
|
assert.Equal(t, "", Coalesce())
|
||||||
|
assert.Equal(t, "a", Coalesce("a"))
|
||||||
|
assert.Equal(t, "a", Coalesce("", "a"))
|
||||||
|
assert.Equal(t, "a", Coalesce("a", ""))
|
||||||
|
assert.Equal(t, "a", Coalesce("a", "b"))
|
||||||
|
}
|
||||||
|
|
||||||
func TestFilterOut(t *testing.T) {
|
func TestFilterOut(t *testing.T) {
|
||||||
strs := []string{
|
strs := []string{
|
||||||
"1.2.3.4",
|
"1.2.3.4",
|
||||||
|
|
|
@ -8,7 +8,9 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"sort"
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
|
@ -27,11 +29,10 @@ type FilteringConfig struct {
|
||||||
// FilterHandler is an optional additional filtering callback.
|
// FilterHandler is an optional additional filtering callback.
|
||||||
FilterHandler func(clientAddr net.IP, clientID string, settings *filtering.Settings) `yaml:"-"`
|
FilterHandler func(clientAddr net.IP, clientID string, settings *filtering.Settings) `yaml:"-"`
|
||||||
|
|
||||||
// GetCustomUpstreamByClient - a callback function that returns upstreams configuration
|
// GetCustomUpstreamByClient is a callback that returns upstreams
|
||||||
// based on the client IP address. Returns nil if there are no custom upstreams for the client
|
// configuration based on the client IP address or ClientID. It returns
|
||||||
//
|
// nil if there are no custom upstreams for the client.
|
||||||
// TODO(e.burkov): Replace argument type with net.IP.
|
GetCustomUpstreamByClient func(id string) (conf *proxy.UpstreamConfig, err error) `yaml:"-"`
|
||||||
GetCustomUpstreamByClient func(clientAddr string) *proxy.UpstreamConfig `yaml:"-"`
|
|
||||||
|
|
||||||
// Protection configuration
|
// Protection configuration
|
||||||
// --
|
// --
|
||||||
|
@ -384,10 +385,51 @@ func (s *Server) prepareTLS(proxyConfig *proxy.Config) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isInSorted returns true if s is in the sorted slice strs.
|
||||||
|
func isInSorted(strs []string, s string) (ok bool) {
|
||||||
|
i := sort.SearchStrings(strs, s)
|
||||||
|
if i == len(strs) || strs[i] != s {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// isWildcard returns true if host is a wildcard hostname.
|
||||||
|
func isWildcard(host string) (ok bool) {
|
||||||
|
return len(host) >= 2 && host[0] == '*' && host[1] == '.'
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchesDomainWildcard returns true if host matches the domain wildcard
|
||||||
|
// pattern pat.
|
||||||
|
func matchesDomainWildcard(host, pat string) (ok bool) {
|
||||||
|
return isWildcard(pat) && strings.HasSuffix(host, pat[1:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// anyNameMatches returns true if sni, the client's SNI value, matches any of
|
||||||
|
// the DNS names and patterns from certificate. dnsNames must be sorted.
|
||||||
|
func anyNameMatches(dnsNames []string, sni string) (ok bool) {
|
||||||
|
if aghnet.ValidateDomainName(sni) != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if isInSorted(dnsNames, sni) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, dn := range dnsNames {
|
||||||
|
if matchesDomainWildcard(sni, dn) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// Called by 'tls' package when Client Hello is received
|
// Called by 'tls' package when Client Hello is received
|
||||||
// If the server name (from SNI) supplied by client is incorrect - we terminate the ongoing TLS handshake.
|
// If the server name (from SNI) supplied by client is incorrect - we terminate the ongoing TLS handshake.
|
||||||
func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
if s.conf.StrictSNICheck && !matchDNSName(s.conf.dnsNames, ch.ServerName) {
|
if s.conf.StrictSNICheck && !anyNameMatches(s.conf.dnsNames, ch.ServerName) {
|
||||||
log.Info("dns: tls: unknown SNI in Client Hello: %s", ch.ServerName)
|
log.Info("dns: tls: unknown SNI in Client Hello: %s", ch.ServerName)
|
||||||
return nil, fmt.Errorf("invalid SNI")
|
return nil, fmt.Errorf("invalid SNI")
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,53 @@
|
||||||
|
package dnsforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sort"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAnyNameMatches(t *testing.T) {
|
||||||
|
dnsNames := []string{"host1", "*.host2", "1.2.3.4"}
|
||||||
|
sort.Strings(dnsNames)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
dnsName string
|
||||||
|
want bool
|
||||||
|
}{{
|
||||||
|
name: "match",
|
||||||
|
dnsName: "host1",
|
||||||
|
want: true,
|
||||||
|
}, {
|
||||||
|
name: "match",
|
||||||
|
dnsName: "a.host2",
|
||||||
|
want: true,
|
||||||
|
}, {
|
||||||
|
name: "match",
|
||||||
|
dnsName: "b.a.host2",
|
||||||
|
want: true,
|
||||||
|
}, {
|
||||||
|
name: "match",
|
||||||
|
dnsName: "1.2.3.4",
|
||||||
|
want: true,
|
||||||
|
}, {
|
||||||
|
name: "mismatch",
|
||||||
|
dnsName: "host2",
|
||||||
|
want: false,
|
||||||
|
}, {
|
||||||
|
name: "mismatch",
|
||||||
|
dnsName: "",
|
||||||
|
want: false,
|
||||||
|
}, {
|
||||||
|
name: "mismatch",
|
||||||
|
dnsName: "*.host2",
|
||||||
|
want: false,
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tc.want, anyNameMatches(dnsNames, tc.dnsName))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
|
@ -229,7 +230,7 @@ func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
|
||||||
rc = resultCodeSuccess
|
rc = resultCodeSuccess
|
||||||
|
|
||||||
var ip net.IP
|
var ip net.IP
|
||||||
if ip = IPFromAddr(dctx.proxyCtx.Addr); ip == nil {
|
if ip = aghnet.IPFromAddr(dctx.proxyCtx.Addr); ip == nil {
|
||||||
return rc
|
return rc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -489,6 +490,15 @@ func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
|
||||||
return resultCodeSuccess
|
return resultCodeSuccess
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ipStringFromAddr extracts an IP address string from net.Addr.
|
||||||
|
func ipStringFromAddr(addr net.Addr) (ipStr string) {
|
||||||
|
if ip := aghnet.IPFromAddr(addr); ip != nil {
|
||||||
|
return ip.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// processUpstream passes request to upstream servers and handles the response.
|
// processUpstream passes request to upstream servers and handles the response.
|
||||||
func (s *Server) processUpstream(ctx *dnsContext) (rc resultCode) {
|
func (s *Server) processUpstream(ctx *dnsContext) (rc resultCode) {
|
||||||
d := ctx.proxyCtx
|
d := ctx.proxyCtx
|
||||||
|
@ -497,9 +507,13 @@ func (s *Server) processUpstream(ctx *dnsContext) (rc resultCode) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if d.Addr != nil && s.conf.GetCustomUpstreamByClient != nil {
|
if d.Addr != nil && s.conf.GetCustomUpstreamByClient != nil {
|
||||||
clientIP := IPStringFromAddr(d.Addr)
|
// Use the clientID first, since it has a higher priority.
|
||||||
if upsConf := s.conf.GetCustomUpstreamByClient(clientIP); upsConf != nil {
|
id := aghstrings.Coalesce(ctx.clientID, ipStringFromAddr(d.Addr))
|
||||||
log.Debug("dns: using custom upstreams for client %s", clientIP)
|
upsConf, err := s.conf.GetCustomUpstreamByClient(id)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("dns: getting custom upstreams for client %s: %s", id, err)
|
||||||
|
} else if upsConf != nil {
|
||||||
|
log.Debug("dns: using custom upstreams for client %s", id)
|
||||||
d.CustomUpstreamConfig = upsConf
|
d.CustomUpstreamConfig = upsConf
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -379,3 +379,18 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
|
||||||
require.Empty(t, proxyCtx.Res.Answer)
|
require.Empty(t, proxyCtx.Res.Answer)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIPStringFromAddr(t *testing.T) {
|
||||||
|
t.Run("not_nil", func(t *testing.T) {
|
||||||
|
addr := net.UDPAddr{
|
||||||
|
IP: net.ParseIP("1:2:3::4"),
|
||||||
|
Port: 12345,
|
||||||
|
Zone: "eth0",
|
||||||
|
}
|
||||||
|
assert.Equal(t, ipStringFromAddr(&addr), addr.IP.String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nil", func(t *testing.T) {
|
||||||
|
assert.Empty(t, ipStringFromAddr(nil))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -12,7 +12,6 @@ import (
|
||||||
"math/big"
|
"math/big"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"sort"
|
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -521,16 +520,16 @@ func TestServerCustomClientUpstream(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
|
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
|
||||||
s.conf.GetCustomUpstreamByClient = func(_ string) *proxy.UpstreamConfig {
|
s.conf.GetCustomUpstreamByClient = func(_ string) (conf *proxy.UpstreamConfig, err error) {
|
||||||
return &proxy.UpstreamConfig{
|
ups := &aghtest.TestUpstream{
|
||||||
Upstreams: []upstream.Upstream{
|
IPv4: map[string][]net.IP{
|
||||||
&aghtest.TestUpstream{
|
"host.": {{192, 168, 0, 1}},
|
||||||
IPv4: map[string][]net.IP{
|
|
||||||
"host.": {{192, 168, 0, 1}},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return &proxy.UpstreamConfig{
|
||||||
|
Upstreams: []upstream.Upstream{ups},
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
startDeferStop(t, s)
|
startDeferStop(t, s)
|
||||||
|
|
||||||
|
@ -962,65 +961,6 @@ func publicKey(priv interface{}) interface{} {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIPStringFromAddr(t *testing.T) {
|
|
||||||
t.Run("not_nil", func(t *testing.T) {
|
|
||||||
addr := net.UDPAddr{
|
|
||||||
IP: net.ParseIP("1:2:3::4"),
|
|
||||||
Port: 12345,
|
|
||||||
Zone: "eth0",
|
|
||||||
}
|
|
||||||
assert.Equal(t, IPStringFromAddr(&addr), addr.IP.String())
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("nil", func(t *testing.T) {
|
|
||||||
assert.Empty(t, IPStringFromAddr(nil))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMatchDNSName(t *testing.T) {
|
|
||||||
dnsNames := []string{"host1", "*.host2", "1.2.3.4"}
|
|
||||||
sort.Strings(dnsNames)
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
dnsName string
|
|
||||||
want bool
|
|
||||||
}{{
|
|
||||||
name: "match",
|
|
||||||
dnsName: "host1",
|
|
||||||
want: true,
|
|
||||||
}, {
|
|
||||||
name: "match",
|
|
||||||
dnsName: "a.host2",
|
|
||||||
want: true,
|
|
||||||
}, {
|
|
||||||
name: "match",
|
|
||||||
dnsName: "b.a.host2",
|
|
||||||
want: true,
|
|
||||||
}, {
|
|
||||||
name: "match",
|
|
||||||
dnsName: "1.2.3.4",
|
|
||||||
want: true,
|
|
||||||
}, {
|
|
||||||
name: "mismatch",
|
|
||||||
dnsName: "host2",
|
|
||||||
want: false,
|
|
||||||
}, {
|
|
||||||
name: "mismatch",
|
|
||||||
dnsName: "",
|
|
||||||
want: false,
|
|
||||||
}, {
|
|
||||||
name: "mismatch",
|
|
||||||
dnsName: "*.host2",
|
|
||||||
want: false,
|
|
||||||
}}
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
assert.Equal(t, tc.want, matchDNSName(dnsNames, tc.dnsName))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type testDHCP struct{}
|
type testDHCP struct{}
|
||||||
|
|
||||||
func (d *testDHCP) Enabled() (ok bool) { return true }
|
func (d *testDHCP) Enabled() (ok bool) { return true }
|
||||||
|
|
|
@ -4,15 +4,15 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) {
|
func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) {
|
||||||
ip := IPFromAddr(d.Addr)
|
ip := aghnet.IPFromAddr(d.Addr)
|
||||||
disallowed, _ := s.access.IsBlockedIP(ip)
|
disallowed, _ := s.access.IsBlockedIP(ip)
|
||||||
if disallowed {
|
if disallowed {
|
||||||
log.Tracef("Client IP %s is blocked by settings", ip)
|
log.Tracef("Client IP %s is blocked by settings", ip)
|
||||||
|
@ -39,7 +39,7 @@ func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool
|
||||||
func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *filtering.Settings {
|
func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *filtering.Settings {
|
||||||
setts := s.dnsFilter.GetConfig()
|
setts := s.dnsFilter.GetConfig()
|
||||||
if s.conf.FilterHandler != nil {
|
if s.conf.FilterHandler != nil {
|
||||||
s.conf.FilterHandler(IPFromAddr(ctx.proxyCtx.Addr), ctx.clientID, &setts)
|
s.conf.FilterHandler(aghnet.IPFromAddr(ctx.proxyCtx.Addr), ctx.clientID, &setts)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &setts
|
return &setts
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||||
|
@ -37,7 +38,7 @@ func processQueryLogsAndStats(ctx *dnsContext) (rc resultCode) {
|
||||||
OrigAnswer: ctx.origResp,
|
OrigAnswer: ctx.origResp,
|
||||||
Result: ctx.result,
|
Result: ctx.result,
|
||||||
Elapsed: elapsed,
|
Elapsed: elapsed,
|
||||||
ClientIP: IPFromAddr(pctx.Addr),
|
ClientIP: aghnet.IPFromAddr(pctx.Addr),
|
||||||
ClientID: ctx.clientID,
|
ClientID: ctx.clientID,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -79,7 +80,7 @@ func (s *Server) updateStats(ctx *dnsContext, elapsed time.Duration, res filteri
|
||||||
|
|
||||||
if clientID := ctx.clientID; clientID != "" {
|
if clientID := ctx.clientID; clientID != "" {
|
||||||
e.Client = clientID
|
e.Client = clientID
|
||||||
} else if ip := IPFromAddr(pctx.Addr); ip != nil {
|
} else if ip := aghnet.IPFromAddr(pctx.Addr); ip != nil {
|
||||||
e.Client = ip.String()
|
e.Client = ip.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,69 +0,0 @@
|
||||||
package dnsforward
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"sort"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
|
||||||
)
|
|
||||||
|
|
||||||
// IPFromAddr gets IP address from addr.
|
|
||||||
func IPFromAddr(addr net.Addr) (ip net.IP) {
|
|
||||||
switch addr := addr.(type) {
|
|
||||||
case *net.UDPAddr:
|
|
||||||
return addr.IP
|
|
||||||
case *net.TCPAddr:
|
|
||||||
return addr.IP
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// IPStringFromAddr extracts IP address from net.Addr.
|
|
||||||
// Note: we can't use net.SplitHostPort(a.String()) because of IPv6 zone:
|
|
||||||
// https://github.com/AdguardTeam/AdGuardHome/internal/issues/1261
|
|
||||||
func IPStringFromAddr(addr net.Addr) (ipStr string) {
|
|
||||||
if ip := IPFromAddr(addr); ip != nil {
|
|
||||||
return ip.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find value in a sorted array
|
|
||||||
func findSorted(ar []string, val string) int {
|
|
||||||
i := sort.SearchStrings(ar, val)
|
|
||||||
if i == len(ar) || ar[i] != val {
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
return i
|
|
||||||
}
|
|
||||||
|
|
||||||
func isWildcard(host string) bool {
|
|
||||||
return len(host) >= 2 &&
|
|
||||||
host[0] == '*' && host[1] == '.'
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return TRUE if host name matches a wildcard pattern
|
|
||||||
func matchDomainWildcard(host, wildcard string) bool {
|
|
||||||
return isWildcard(wildcard) &&
|
|
||||||
strings.HasSuffix(host, wildcard[1:])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return TRUE if client's SNI value matches DNS names from certificate
|
|
||||||
func matchDNSName(dnsNames []string, sni string) bool {
|
|
||||||
if aghnet.ValidateDomainName(sni) != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if findSorted(dnsNames, sni) != -1 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, dn := range dnsNames {
|
|
||||||
if matchDomainWildcard(sni, dn) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
|
@ -1,60 +0,0 @@
|
||||||
package dnsforward
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
// fakeAddr is a mock implementation of net.Addr interface to simplify testing.
|
|
||||||
type fakeAddr struct {
|
|
||||||
// Addr is embedded here simply to make fakeAddr a net.Addr without
|
|
||||||
// actually implementing all methods.
|
|
||||||
net.Addr
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIPFromAddr(t *testing.T) {
|
|
||||||
supIPv4 := net.IP{1, 2, 3, 4}
|
|
||||||
supIPv6 := net.ParseIP("2a00:1450:400c:c06::93")
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
addr net.Addr
|
|
||||||
want net.IP
|
|
||||||
}{{
|
|
||||||
name: "ipv4_tcp",
|
|
||||||
addr: &net.TCPAddr{
|
|
||||||
IP: supIPv4,
|
|
||||||
},
|
|
||||||
want: supIPv4,
|
|
||||||
}, {
|
|
||||||
name: "ipv6_tcp",
|
|
||||||
addr: &net.TCPAddr{
|
|
||||||
IP: supIPv6,
|
|
||||||
},
|
|
||||||
want: supIPv6,
|
|
||||||
}, {
|
|
||||||
name: "ipv4_udp",
|
|
||||||
addr: &net.UDPAddr{
|
|
||||||
IP: supIPv4,
|
|
||||||
},
|
|
||||||
want: supIPv4,
|
|
||||||
}, {
|
|
||||||
name: "ipv6_udp",
|
|
||||||
addr: &net.UDPAddr{
|
|
||||||
IP: supIPv6,
|
|
||||||
},
|
|
||||||
want: supIPv6,
|
|
||||||
}, {
|
|
||||||
name: "non-ip_addr",
|
|
||||||
addr: &fakeAddr{},
|
|
||||||
want: nil,
|
|
||||||
}}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
assert.Equal(t, tc.want, IPFromAddr(tc.addr))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -335,37 +335,44 @@ func (clients *clientsContainer) Find(id string) (c *Client, ok bool) {
|
||||||
return c, true
|
return c, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// FindUpstreams looks for upstreams configured for the client
|
// findUpstreams returns upstreams configured for the client, identified either
|
||||||
// If no client found for this IP, or if no custom upstreams are configured,
|
// by its IP address or its ClientID. upsConf is nil if the client isn't found
|
||||||
// this method returns nil
|
// or if the client has no custom upstreams.
|
||||||
func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig {
|
func (clients *clientsContainer) findUpstreams(
|
||||||
|
id string,
|
||||||
|
) (upsConf *proxy.UpstreamConfig, err error) {
|
||||||
clients.lock.Lock()
|
clients.lock.Lock()
|
||||||
defer clients.lock.Unlock()
|
defer clients.lock.Unlock()
|
||||||
|
|
||||||
c, ok := clients.findLocked(ip)
|
c, ok := clients.findLocked(id)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
upstreams := aghstrings.FilterOut(c.Upstreams, aghstrings.IsCommentOrEmpty)
|
upstreams := aghstrings.FilterOut(c.Upstreams, aghstrings.IsCommentOrEmpty)
|
||||||
if len(upstreams) == 0 {
|
if len(upstreams) == 0 {
|
||||||
return nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.upstreamConfig == nil {
|
if c.upstreamConfig != nil {
|
||||||
conf, err := proxy.ParseUpstreamsConfig(
|
return c.upstreamConfig, nil
|
||||||
upstreams,
|
|
||||||
upstream.Options{
|
|
||||||
Bootstrap: config.DNS.BootstrapDNS,
|
|
||||||
Timeout: dnsforward.DefaultTimeout,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
if err == nil {
|
|
||||||
c.upstreamConfig = &conf
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.upstreamConfig
|
var conf proxy.UpstreamConfig
|
||||||
|
conf, err = proxy.ParseUpstreamsConfig(
|
||||||
|
upstreams,
|
||||||
|
upstream.Options{
|
||||||
|
Bootstrap: config.DNS.BootstrapDNS,
|
||||||
|
Timeout: dnsforward.DefaultTimeout,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.upstreamConfig = &conf
|
||||||
|
|
||||||
|
return &conf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// findLocked searches for a client by its ID. For internal use only.
|
// findLocked searches for a client by its ID. For internal use only.
|
||||||
|
|
|
@ -25,7 +25,7 @@ func TestClients(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
ok, err := clients.Add(c)
|
ok, err := clients.Add(c)
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
c = &Client{
|
c = &Client{
|
||||||
|
@ -34,7 +34,7 @@ func TestClients(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
ok, err = clients.Add(c)
|
ok, err = clients.Add(c)
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
c, ok = clients.Find("1.1.1.1")
|
c, ok = clients.Find("1.1.1.1")
|
||||||
|
@ -59,7 +59,7 @@ func TestClients(t *testing.T) {
|
||||||
IDs: []string{"1.2.3.5"},
|
IDs: []string{"1.2.3.5"},
|
||||||
Name: "client1",
|
Name: "client1",
|
||||||
})
|
})
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.False(t, ok)
|
assert.False(t, ok)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -68,7 +68,7 @@ func TestClients(t *testing.T) {
|
||||||
IDs: []string{"2.2.2.2"},
|
IDs: []string{"2.2.2.2"},
|
||||||
Name: "client3",
|
Name: "client3",
|
||||||
})
|
})
|
||||||
require.NotNil(t, err)
|
require.Error(t, err)
|
||||||
assert.False(t, ok)
|
assert.False(t, ok)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -77,13 +77,13 @@ func TestClients(t *testing.T) {
|
||||||
IDs: []string{"1.2.3.0"},
|
IDs: []string{"1.2.3.0"},
|
||||||
Name: "client3",
|
Name: "client3",
|
||||||
})
|
})
|
||||||
require.NotNil(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
err = clients.Update("client3", &Client{
|
err = clients.Update("client3", &Client{
|
||||||
IDs: []string{"1.2.3.0"},
|
IDs: []string{"1.2.3.0"},
|
||||||
Name: "client2",
|
Name: "client2",
|
||||||
})
|
})
|
||||||
assert.NotNil(t, err)
|
assert.Error(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("update_fail_ip", func(t *testing.T) {
|
t.Run("update_fail_ip", func(t *testing.T) {
|
||||||
|
@ -91,7 +91,7 @@ func TestClients(t *testing.T) {
|
||||||
IDs: []string{"2.2.2.2"},
|
IDs: []string{"2.2.2.2"},
|
||||||
Name: "client1",
|
Name: "client1",
|
||||||
})
|
})
|
||||||
assert.NotNil(t, err)
|
assert.Error(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("update_success", func(t *testing.T) {
|
t.Run("update_success", func(t *testing.T) {
|
||||||
|
@ -99,7 +99,7 @@ func TestClients(t *testing.T) {
|
||||||
IDs: []string{"1.1.1.2"},
|
IDs: []string{"1.1.1.2"},
|
||||||
Name: "client1",
|
Name: "client1",
|
||||||
})
|
})
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.False(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
|
assert.False(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
|
||||||
assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
|
assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
|
||||||
|
@ -109,7 +109,7 @@ func TestClients(t *testing.T) {
|
||||||
Name: "client1-renamed",
|
Name: "client1-renamed",
|
||||||
UseOwnSettings: true,
|
UseOwnSettings: true,
|
||||||
})
|
})
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
c, ok := clients.Find("1.1.1.2")
|
c, ok := clients.Find("1.1.1.2")
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
|
@ -137,15 +137,15 @@ func TestClients(t *testing.T) {
|
||||||
|
|
||||||
t.Run("addhost_success", func(t *testing.T) {
|
t.Run("addhost_success", func(t *testing.T) {
|
||||||
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceARP)
|
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceARP)
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
ok, err = clients.AddHost("1.1.1.1", "host2", ClientSourceARP)
|
ok, err = clients.AddHost("1.1.1.1", "host2", ClientSourceARP)
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
ok, err = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile)
|
ok, err = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile)
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
|
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
|
||||||
|
@ -153,7 +153,7 @@ func TestClients(t *testing.T) {
|
||||||
|
|
||||||
t.Run("addhost_fail", func(t *testing.T) {
|
t.Run("addhost_fail", func(t *testing.T) {
|
||||||
ok, err := clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS)
|
ok, err := clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS)
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.False(t, ok)
|
assert.False(t, ok)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -181,7 +181,7 @@ func TestClientsWhois(t *testing.T) {
|
||||||
|
|
||||||
t.Run("existing_auto-client", func(t *testing.T) {
|
t.Run("existing_auto-client", func(t *testing.T) {
|
||||||
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceRDNS)
|
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceRDNS)
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
clients.SetWhoisInfo("1.1.1.1", whois)
|
clients.SetWhoisInfo("1.1.1.1", whois)
|
||||||
|
@ -198,7 +198,7 @@ func TestClientsWhois(t *testing.T) {
|
||||||
IDs: []string{"1.1.1.2"},
|
IDs: []string{"1.1.1.2"},
|
||||||
Name: "client1",
|
Name: "client1",
|
||||||
})
|
})
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
clients.SetWhoisInfo("1.1.1.2", whois)
|
clients.SetWhoisInfo("1.1.1.2", whois)
|
||||||
|
@ -219,12 +219,12 @@ func TestClientsAddExisting(t *testing.T) {
|
||||||
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"},
|
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"},
|
||||||
Name: "client1",
|
Name: "client1",
|
||||||
})
|
})
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
// Now add an auto-client with the same IP.
|
// Now add an auto-client with the same IP.
|
||||||
ok, err = clients.AddHost("1.1.1.1", "test", ClientSourceRDNS)
|
ok, err = clients.AddHost("1.1.1.1", "test", ClientSourceRDNS)
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -253,14 +253,14 @@ func TestClientsAddExisting(t *testing.T) {
|
||||||
Hostname: "testhost",
|
Hostname: "testhost",
|
||||||
Expiry: time.Now().Add(time.Hour),
|
Expiry: time.Now().Add(time.Hour),
|
||||||
})
|
})
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Add a new client with the same IP as for a client with MAC.
|
// Add a new client with the same IP as for a client with MAC.
|
||||||
ok, err := clients.Add(&Client{
|
ok, err := clients.Add(&Client{
|
||||||
IDs: []string{testIP.String()},
|
IDs: []string{testIP.String()},
|
||||||
Name: "client2",
|
Name: "client2",
|
||||||
})
|
})
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
// Add a new client with the IP from the first client's IP
|
// Add a new client with the IP from the first client's IP
|
||||||
|
@ -269,7 +269,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||||
IDs: []string{"2.2.2.2"},
|
IDs: []string{"2.2.2.2"},
|
||||||
Name: "client3",
|
Name: "client3",
|
||||||
})
|
})
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -289,14 +289,16 @@ func TestClientsCustomUpstream(t *testing.T) {
|
||||||
"[/example.org/]8.8.8.8",
|
"[/example.org/]8.8.8.8",
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
config := clients.FindUpstreams("1.2.3.4")
|
config, err := clients.findUpstreams("1.2.3.4")
|
||||||
assert.Nil(t, config)
|
assert.Nil(t, config)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
config = clients.FindUpstreams("1.1.1.1")
|
config, err = clients.findUpstreams("1.1.1.1")
|
||||||
require.NotNil(t, config)
|
require.NotNil(t, config)
|
||||||
|
assert.NoError(t, err)
|
||||||
assert.Len(t, config.Upstreams, 1)
|
assert.Len(t, config.Upstreams, 1)
|
||||||
assert.Len(t, config.DomainReservedUpstreams, 1)
|
assert.Len(t, config.DomainReservedUpstreams, 1)
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||||
|
@ -106,7 +107,7 @@ func isRunning() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func onDNSRequest(d *proxy.DNSContext) {
|
func onDNSRequest(d *proxy.DNSContext) {
|
||||||
ip := dnsforward.IPFromAddr(d.Addr)
|
ip := aghnet.IPFromAddr(d.Addr)
|
||||||
if ip == nil {
|
if ip == nil {
|
||||||
// This would be quite weird if we get here.
|
// This would be quite weird if we get here.
|
||||||
return
|
return
|
||||||
|
@ -197,7 +198,7 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
|
||||||
newConf.TLSAllowUnencryptedDOH = tlsConf.AllowUnencryptedDOH
|
newConf.TLSAllowUnencryptedDOH = tlsConf.AllowUnencryptedDOH
|
||||||
|
|
||||||
newConf.FilterHandler = applyAdditionalFiltering
|
newConf.FilterHandler = applyAdditionalFiltering
|
||||||
newConf.GetCustomUpstreamByClient = Context.clients.FindUpstreams
|
newConf.GetCustomUpstreamByClient = Context.clients.findUpstreams
|
||||||
|
|
||||||
newConf.ResolveClients = dnsConf.ResolveClients
|
newConf.ResolveClients = dnsConf.ResolveClients
|
||||||
newConf.UsePrivateRDNS = dnsConf.UsePrivateRDNS
|
newConf.UsePrivateRDNS = dnsConf.UsePrivateRDNS
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
|
||||||
"github.com/AdguardTeam/golibs/cache"
|
"github.com/AdguardTeam/golibs/cache"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
@ -66,19 +67,6 @@ func trimValue(s string) string {
|
||||||
return s[:maxValueLength-3] + "..."
|
return s[:maxValueLength-3] + "..."
|
||||||
}
|
}
|
||||||
|
|
||||||
// coalesceStr returns the first non-empty string.
|
|
||||||
//
|
|
||||||
// TODO(a.garipov): Move to aghstrings?
|
|
||||||
func coalesceStr(strs ...string) (res string) {
|
|
||||||
for _, s := range strs {
|
|
||||||
if s != "" {
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// isWhoisComment returns true if the string is empty or is a WHOIS comment.
|
// isWhoisComment returns true if the string is empty or is a WHOIS comment.
|
||||||
func isWhoisComment(s string) (ok bool) {
|
func isWhoisComment(s string) (ok bool) {
|
||||||
return len(s) == 0 || s[0] == '#' || s[0] == '%'
|
return len(s) == 0 || s[0] == '#' || s[0] == '%'
|
||||||
|
@ -119,7 +107,7 @@ func whoisParse(data string) (m strmap) {
|
||||||
v = trimValue(v)
|
v = trimValue(v)
|
||||||
case "descr", "netname":
|
case "descr", "netname":
|
||||||
k = "orgname"
|
k = "orgname"
|
||||||
v = coalesceStr(orgname, v)
|
v = aghstrings.Coalesce(orgname, v)
|
||||||
orgname = v
|
orgname = v
|
||||||
case "whois":
|
case "whois":
|
||||||
k = "whois"
|
k = "whois"
|
||||||
|
|
Loading…
Reference in New Issue