diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index bc8ed460..6db87c7a 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -43,6 +43,12 @@ type Server struct { stats *stats // General server statistics once sync.Once + AllowedClients map[string]bool // IP addresses of whitelist clients + DisallowedClients map[string]bool // IP addresses of clients that should be blocked + AllowedClientsIPNet []net.IPNet // CIDRs of whitelist clients + DisallowedClientsIPNet []net.IPNet // CIDRs of clients that should be blocked + BlockedHosts map[string]bool // hosts that should be blocked + sync.RWMutex conf ServerConfig } @@ -70,6 +76,10 @@ type FilteringConfig struct { BootstrapDNS []string `yaml:"bootstrap_dns"` // a list of bootstrap DNS for DoH and DoT (plain DNS only) AllServers bool `yaml:"all_servers"` // if true, parallel queries to all configured upstream servers are enabled + AllowedClients []string `yaml:"allowed_clients"` // IP addresses of whitelist clients + DisallowedClients []string `yaml:"disallowed_clients"` // IP addresses of clients that should be blocked + BlockedHosts []string `yaml:"blocked_hosts"` // hosts that should be blocked + dnsfilter.Config `yaml:",inline"` } @@ -121,6 +131,34 @@ func (s *Server) Start(config *ServerConfig) error { return s.startInternal(config) } +func convertArrayToMap(dst *map[string]bool, src []string) { + *dst = make(map[string]bool) + for _, s := range src { + (*dst)[s] = true + } +} + +// Split array of IP or CIDR into 2 containers for fast search +func processIPCIDRArray(dst *map[string]bool, dstIPNet *[]net.IPNet, src []string) error { + *dst = make(map[string]bool) + + for _, s := range src { + ip := net.ParseIP(s) + if ip != nil { + (*dst)[s] = true + continue + } + + _, ipnet, err := net.ParseCIDR(s) + if err != nil { + return err + } + *dstIPNet = append(*dstIPNet, *ipnet) + } + + return nil +} + // startInternal starts without locking func (s *Server) startInternal(config *ServerConfig) error { if config != nil { @@ -167,10 +205,23 @@ func (s *Server) startInternal(config *ServerConfig) error { CacheEnabled: true, Upstreams: s.conf.Upstreams, DomainsReservedUpstreams: s.conf.DomainsReservedUpstreams, - Handler: s.handleDNSRequest, + BeforeRequestHandler: s.beforeRequestHandler, + RequestHandler: s.handleDNSRequest, AllServers: s.conf.AllServers, } + err = processIPCIDRArray(&s.AllowedClients, &s.AllowedClientsIPNet, s.conf.AllowedClients) + if err != nil { + return err + } + + err = processIPCIDRArray(&s.DisallowedClients, &s.DisallowedClientsIPNet, s.conf.DisallowedClients) + if err != nil { + return err + } + + convertArrayToMap(&s.BlockedHosts, s.conf.BlockedHosts) + if s.conf.TLSListenAddr != nil && s.conf.CertificateChain != "" && s.conf.PrivateKey != "" { proxyConfig.TLSListenAddr = s.conf.TLSListenAddr keypair, err := tls.X509KeyPair([]byte(s.conf.CertificateChain), []byte(s.conf.PrivateKey)) @@ -321,6 +372,67 @@ func (s *Server) GetStatsHistory(timeUnit time.Duration, startTime time.Time, en return s.stats.getStatsHistory(timeUnit, startTime, endTime) } +// Return TRUE if this client should be blocked +func (s *Server) isBlockedIP(ip string) bool { + if len(s.AllowedClients) != 0 || len(s.AllowedClientsIPNet) != 0 { + _, ok := s.AllowedClients[ip] + if ok { + return false + } + + if len(s.AllowedClientsIPNet) != 0 { + ipAddr := net.ParseIP(ip) + for _, ipnet := range s.AllowedClientsIPNet { + if ipnet.Contains(ipAddr) { + return false + } + } + } + + return true + } + + _, ok := s.DisallowedClients[ip] + if ok { + return true + } + + if len(s.DisallowedClientsIPNet) != 0 { + ipAddr := net.ParseIP(ip) + for _, ipnet := range s.DisallowedClientsIPNet { + if ipnet.Contains(ipAddr) { + return true + } + } + } + + return false +} + +// Return TRUE if this domain should be blocked +func (s *Server) isBlockedDomain(host string) bool { + _, ok := s.BlockedHosts[host] + return ok +} + +func (s *Server) beforeRequestHandler(p *proxy.Proxy, d *proxy.DNSContext) (bool, error) { + ip, _, _ := net.SplitHostPort(d.Addr.String()) + if s.isBlockedIP(ip) { + log.Tracef("Client IP %s is blocked by settings", ip) + return false, nil + } + + if len(d.Req.Question) == 1 { + host := strings.TrimSuffix(d.Req.Question[0].Name, ".") + if s.isBlockedDomain(host) { + log.Tracef("Domain %s is blocked by settings", host) + return false, nil + } + } + + return true, nil +} + // handleDNSRequest filters the incoming DNS requests and writes them to the query log func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error { start := time.Now() diff --git a/dnsforward/dnsforward_test.go b/dnsforward/dnsforward_test.go index 1a38e96b..d643bf30 100644 --- a/dnsforward/dnsforward_test.go +++ b/dnsforward/dnsforward_test.go @@ -15,12 +15,10 @@ import ( "testing" "time" - "github.com/AdguardTeam/dnsproxy/proxy" - - "github.com/stretchr/testify/assert" - "github.com/AdguardTeam/AdGuardHome/dnsfilter" + "github.com/AdguardTeam/dnsproxy/proxy" "github.com/miekg/dns" + "github.com/stretchr/testify/assert" ) const ( @@ -460,12 +458,8 @@ func createTestServer(t *testing.T) *Server { s.conf.FilteringConfig.SafeBrowsingEnabled = true s.conf.Filters = make([]dnsfilter.Filter, 0) - rules := []string{ - "||nxdomain.example.org^", - "||null.example.org^", - "127.0.0.1 host.example.org", - } - filter := dnsfilter.Filter{ID: 1, Rules: rules} + rules := "||nxdomain.example.org^\n||null.example.org^\n127.0.0.1 host.example.org\n" + filter := dnsfilter.Filter{ID: 1, Data: []byte(rules)} s.conf.Filters = append(s.conf.Filters, filter) return s } @@ -628,3 +622,72 @@ func publicKey(priv interface{}) interface{} { return nil } } + +func TestIsBlockedIPAllowed(t *testing.T) { + s := createTestServer(t) + s.conf.AllowedClients = []string{"1.1.1.1", "2.2.0.0/16"} + + err := s.Start(nil) + defer removeDataDir(t) + if err != nil { + t.Fatalf("Failed to start server: %s", err) + } + + if s.isBlockedIP("1.1.1.1") { + t.Fatalf("isBlockedIP") + } + if !s.isBlockedIP("1.1.1.2") { + t.Fatalf("isBlockedIP") + } + if s.isBlockedIP("2.2.1.1") { + t.Fatalf("isBlockedIP") + } + if !s.isBlockedIP("2.3.1.1") { + t.Fatalf("isBlockedIP") + } +} + +func TestIsBlockedIPDisallowed(t *testing.T) { + s := createTestServer(t) + s.conf.DisallowedClients = []string{"1.1.1.1", "2.2.0.0/16"} + + err := s.Start(nil) + defer removeDataDir(t) + if err != nil { + t.Fatalf("Failed to start server: %s", err) + } + + if !s.isBlockedIP("1.1.1.1") { + t.Fatalf("isBlockedIP") + } + if s.isBlockedIP("1.1.1.2") { + t.Fatalf("isBlockedIP") + } + if !s.isBlockedIP("2.2.1.1") { + t.Fatalf("isBlockedIP") + } + if s.isBlockedIP("2.3.1.1") { + t.Fatalf("isBlockedIP") + } +} + +func TestIsBlockedIPBlockedDomain(t *testing.T) { + s := createTestServer(t) + s.conf.BlockedHosts = []string{"host1", "host2"} + + err := s.Start(nil) + defer removeDataDir(t) + if err != nil { + t.Fatalf("Failed to start server: %s", err) + } + + if !s.isBlockedDomain("host1") { + t.Fatalf("isBlockedDomain") + } + if !s.isBlockedDomain("host2") { + t.Fatalf("isBlockedDomain") + } + if s.isBlockedDomain("host3") { + t.Fatalf("isBlockedDomain") + } +}