+ dnsforward: add access settings for blocking DNS requests
Block by client IP or target domain name.
This commit is contained in:
parent
147344afa3
commit
36ffcf7d22
|
@ -43,6 +43,12 @@ type Server struct {
|
||||||
stats *stats // General server statistics
|
stats *stats // General server statistics
|
||||||
once sync.Once
|
once sync.Once
|
||||||
|
|
||||||
|
AllowedClients map[string]bool // IP addresses of whitelist clients
|
||||||
|
DisallowedClients map[string]bool // IP addresses of clients that should be blocked
|
||||||
|
AllowedClientsIPNet []net.IPNet // CIDRs of whitelist clients
|
||||||
|
DisallowedClientsIPNet []net.IPNet // CIDRs of clients that should be blocked
|
||||||
|
BlockedHosts map[string]bool // hosts that should be blocked
|
||||||
|
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
conf ServerConfig
|
conf ServerConfig
|
||||||
}
|
}
|
||||||
|
@ -70,6 +76,10 @@ type FilteringConfig struct {
|
||||||
BootstrapDNS []string `yaml:"bootstrap_dns"` // a list of bootstrap DNS for DoH and DoT (plain DNS only)
|
BootstrapDNS []string `yaml:"bootstrap_dns"` // a list of bootstrap DNS for DoH and DoT (plain DNS only)
|
||||||
AllServers bool `yaml:"all_servers"` // if true, parallel queries to all configured upstream servers are enabled
|
AllServers bool `yaml:"all_servers"` // if true, parallel queries to all configured upstream servers are enabled
|
||||||
|
|
||||||
|
AllowedClients []string `yaml:"allowed_clients"` // IP addresses of whitelist clients
|
||||||
|
DisallowedClients []string `yaml:"disallowed_clients"` // IP addresses of clients that should be blocked
|
||||||
|
BlockedHosts []string `yaml:"blocked_hosts"` // hosts that should be blocked
|
||||||
|
|
||||||
dnsfilter.Config `yaml:",inline"`
|
dnsfilter.Config `yaml:",inline"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -121,6 +131,34 @@ func (s *Server) Start(config *ServerConfig) error {
|
||||||
return s.startInternal(config)
|
return s.startInternal(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func convertArrayToMap(dst *map[string]bool, src []string) {
|
||||||
|
*dst = make(map[string]bool)
|
||||||
|
for _, s := range src {
|
||||||
|
(*dst)[s] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split array of IP or CIDR into 2 containers for fast search
|
||||||
|
func processIPCIDRArray(dst *map[string]bool, dstIPNet *[]net.IPNet, src []string) error {
|
||||||
|
*dst = make(map[string]bool)
|
||||||
|
|
||||||
|
for _, s := range src {
|
||||||
|
ip := net.ParseIP(s)
|
||||||
|
if ip != nil {
|
||||||
|
(*dst)[s] = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
_, ipnet, err := net.ParseCIDR(s)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*dstIPNet = append(*dstIPNet, *ipnet)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// startInternal starts without locking
|
// startInternal starts without locking
|
||||||
func (s *Server) startInternal(config *ServerConfig) error {
|
func (s *Server) startInternal(config *ServerConfig) error {
|
||||||
if config != nil {
|
if config != nil {
|
||||||
|
@ -167,10 +205,23 @@ func (s *Server) startInternal(config *ServerConfig) error {
|
||||||
CacheEnabled: true,
|
CacheEnabled: true,
|
||||||
Upstreams: s.conf.Upstreams,
|
Upstreams: s.conf.Upstreams,
|
||||||
DomainsReservedUpstreams: s.conf.DomainsReservedUpstreams,
|
DomainsReservedUpstreams: s.conf.DomainsReservedUpstreams,
|
||||||
Handler: s.handleDNSRequest,
|
BeforeRequestHandler: s.beforeRequestHandler,
|
||||||
|
RequestHandler: s.handleDNSRequest,
|
||||||
AllServers: s.conf.AllServers,
|
AllServers: s.conf.AllServers,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = processIPCIDRArray(&s.AllowedClients, &s.AllowedClientsIPNet, s.conf.AllowedClients)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = processIPCIDRArray(&s.DisallowedClients, &s.DisallowedClientsIPNet, s.conf.DisallowedClients)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
convertArrayToMap(&s.BlockedHosts, s.conf.BlockedHosts)
|
||||||
|
|
||||||
if s.conf.TLSListenAddr != nil && s.conf.CertificateChain != "" && s.conf.PrivateKey != "" {
|
if s.conf.TLSListenAddr != nil && s.conf.CertificateChain != "" && s.conf.PrivateKey != "" {
|
||||||
proxyConfig.TLSListenAddr = s.conf.TLSListenAddr
|
proxyConfig.TLSListenAddr = s.conf.TLSListenAddr
|
||||||
keypair, err := tls.X509KeyPair([]byte(s.conf.CertificateChain), []byte(s.conf.PrivateKey))
|
keypair, err := tls.X509KeyPair([]byte(s.conf.CertificateChain), []byte(s.conf.PrivateKey))
|
||||||
|
@ -321,6 +372,67 @@ func (s *Server) GetStatsHistory(timeUnit time.Duration, startTime time.Time, en
|
||||||
return s.stats.getStatsHistory(timeUnit, startTime, endTime)
|
return s.stats.getStatsHistory(timeUnit, startTime, endTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return TRUE if this client should be blocked
|
||||||
|
func (s *Server) isBlockedIP(ip string) bool {
|
||||||
|
if len(s.AllowedClients) != 0 || len(s.AllowedClientsIPNet) != 0 {
|
||||||
|
_, ok := s.AllowedClients[ip]
|
||||||
|
if ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(s.AllowedClientsIPNet) != 0 {
|
||||||
|
ipAddr := net.ParseIP(ip)
|
||||||
|
for _, ipnet := range s.AllowedClientsIPNet {
|
||||||
|
if ipnet.Contains(ipAddr) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
_, ok := s.DisallowedClients[ip]
|
||||||
|
if ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(s.DisallowedClientsIPNet) != 0 {
|
||||||
|
ipAddr := net.ParseIP(ip)
|
||||||
|
for _, ipnet := range s.DisallowedClientsIPNet {
|
||||||
|
if ipnet.Contains(ipAddr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return TRUE if this domain should be blocked
|
||||||
|
func (s *Server) isBlockedDomain(host string) bool {
|
||||||
|
_, ok := s.BlockedHosts[host]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) beforeRequestHandler(p *proxy.Proxy, d *proxy.DNSContext) (bool, error) {
|
||||||
|
ip, _, _ := net.SplitHostPort(d.Addr.String())
|
||||||
|
if s.isBlockedIP(ip) {
|
||||||
|
log.Tracef("Client IP %s is blocked by settings", ip)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(d.Req.Question) == 1 {
|
||||||
|
host := strings.TrimSuffix(d.Req.Question[0].Name, ".")
|
||||||
|
if s.isBlockedDomain(host) {
|
||||||
|
log.Tracef("Domain %s is blocked by settings", host)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
// handleDNSRequest filters the incoming DNS requests and writes them to the query log
|
// handleDNSRequest filters the incoming DNS requests and writes them to the query log
|
||||||
func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
|
func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|
|
@ -15,12 +15,10 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||||
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -460,12 +458,8 @@ func createTestServer(t *testing.T) *Server {
|
||||||
s.conf.FilteringConfig.SafeBrowsingEnabled = true
|
s.conf.FilteringConfig.SafeBrowsingEnabled = true
|
||||||
s.conf.Filters = make([]dnsfilter.Filter, 0)
|
s.conf.Filters = make([]dnsfilter.Filter, 0)
|
||||||
|
|
||||||
rules := []string{
|
rules := "||nxdomain.example.org^\n||null.example.org^\n127.0.0.1 host.example.org\n"
|
||||||
"||nxdomain.example.org^",
|
filter := dnsfilter.Filter{ID: 1, Data: []byte(rules)}
|
||||||
"||null.example.org^",
|
|
||||||
"127.0.0.1 host.example.org",
|
|
||||||
}
|
|
||||||
filter := dnsfilter.Filter{ID: 1, Rules: rules}
|
|
||||||
s.conf.Filters = append(s.conf.Filters, filter)
|
s.conf.Filters = append(s.conf.Filters, filter)
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
@ -628,3 +622,72 @@ func publicKey(priv interface{}) interface{} {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsBlockedIPAllowed(t *testing.T) {
|
||||||
|
s := createTestServer(t)
|
||||||
|
s.conf.AllowedClients = []string{"1.1.1.1", "2.2.0.0/16"}
|
||||||
|
|
||||||
|
err := s.Start(nil)
|
||||||
|
defer removeDataDir(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to start server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.isBlockedIP("1.1.1.1") {
|
||||||
|
t.Fatalf("isBlockedIP")
|
||||||
|
}
|
||||||
|
if !s.isBlockedIP("1.1.1.2") {
|
||||||
|
t.Fatalf("isBlockedIP")
|
||||||
|
}
|
||||||
|
if s.isBlockedIP("2.2.1.1") {
|
||||||
|
t.Fatalf("isBlockedIP")
|
||||||
|
}
|
||||||
|
if !s.isBlockedIP("2.3.1.1") {
|
||||||
|
t.Fatalf("isBlockedIP")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsBlockedIPDisallowed(t *testing.T) {
|
||||||
|
s := createTestServer(t)
|
||||||
|
s.conf.DisallowedClients = []string{"1.1.1.1", "2.2.0.0/16"}
|
||||||
|
|
||||||
|
err := s.Start(nil)
|
||||||
|
defer removeDataDir(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to start server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !s.isBlockedIP("1.1.1.1") {
|
||||||
|
t.Fatalf("isBlockedIP")
|
||||||
|
}
|
||||||
|
if s.isBlockedIP("1.1.1.2") {
|
||||||
|
t.Fatalf("isBlockedIP")
|
||||||
|
}
|
||||||
|
if !s.isBlockedIP("2.2.1.1") {
|
||||||
|
t.Fatalf("isBlockedIP")
|
||||||
|
}
|
||||||
|
if s.isBlockedIP("2.3.1.1") {
|
||||||
|
t.Fatalf("isBlockedIP")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsBlockedIPBlockedDomain(t *testing.T) {
|
||||||
|
s := createTestServer(t)
|
||||||
|
s.conf.BlockedHosts = []string{"host1", "host2"}
|
||||||
|
|
||||||
|
err := s.Start(nil)
|
||||||
|
defer removeDataDir(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to start server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !s.isBlockedDomain("host1") {
|
||||||
|
t.Fatalf("isBlockedDomain")
|
||||||
|
}
|
||||||
|
if !s.isBlockedDomain("host2") {
|
||||||
|
t.Fatalf("isBlockedDomain")
|
||||||
|
}
|
||||||
|
if s.isBlockedDomain("host3") {
|
||||||
|
t.Fatalf("isBlockedDomain")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue