dnsforward -- implement ratelimit and refuseany
This commit is contained in:
parent
15f0dee719
commit
478ce03386
27
config.go
27
config.go
@ -46,14 +46,11 @@ type coreDNSConfig struct {
|
|||||||
|
|
||||||
dnsforward.FilteringConfig `yaml:",inline"`
|
dnsforward.FilteringConfig `yaml:",inline"`
|
||||||
|
|
||||||
QueryLogEnabled bool `yaml:"querylog_enabled"`
|
Pprof string `yaml:"-"`
|
||||||
Ratelimit int `yaml:"ratelimit"`
|
Cache string `yaml:"-"`
|
||||||
RefuseAny bool `yaml:"refuse_any"`
|
Prometheus string `yaml:"-"`
|
||||||
Pprof string `yaml:"-"`
|
BootstrapDNS string `yaml:"bootstrap_dns"`
|
||||||
Cache string `yaml:"-"`
|
UpstreamDNS []string `yaml:"upstream_dns"`
|
||||||
Prometheus string `yaml:"-"`
|
|
||||||
BootstrapDNS string `yaml:"bootstrap_dns"`
|
|
||||||
UpstreamDNS []string `yaml:"upstream_dns"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var defaultDNS = []string{"tls://1.1.1.1", "tls://1.0.0.1"}
|
var defaultDNS = []string{"tls://1.1.1.1", "tls://1.0.0.1"}
|
||||||
@ -71,14 +68,14 @@ var config = configuration{
|
|||||||
ProtectionEnabled: true, // whether or not use any of dnsfilter features
|
ProtectionEnabled: true, // whether or not use any of dnsfilter features
|
||||||
FilteringEnabled: true, // whether or not use filter lists
|
FilteringEnabled: true, // whether or not use filter lists
|
||||||
BlockedResponseTTL: 10, // in seconds
|
BlockedResponseTTL: 10, // in seconds
|
||||||
|
QueryLogEnabled: true,
|
||||||
|
Ratelimit: 20,
|
||||||
|
RefuseAny: true,
|
||||||
},
|
},
|
||||||
QueryLogEnabled: true,
|
BootstrapDNS: "8.8.8.8:53",
|
||||||
Ratelimit: 20,
|
UpstreamDNS: defaultDNS,
|
||||||
RefuseAny: true,
|
Cache: "cache",
|
||||||
BootstrapDNS: "8.8.8.8:53",
|
Prometheus: "prometheus :9153",
|
||||||
UpstreamDNS: defaultDNS,
|
|
||||||
Cache: "cache",
|
|
||||||
Prometheus: "prometheus :9153",
|
|
||||||
},
|
},
|
||||||
Filters: []filter{
|
Filters: []filter{
|
||||||
{Filter: dnsfilter.Filter{ID: 1}, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard Simplified Domain Names filter"},
|
{Filter: dnsfilter.Filter{ID: 1}, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard Simplified Domain Names filter"},
|
||||||
|
@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||||
"github.com/joomcode/errorx"
|
"github.com/joomcode/errorx"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
gocache "github.com/patrickmn/go-cache"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Server is the main way to start a DNS server.
|
// Server is the main way to start a DNS server.
|
||||||
@ -31,6 +32,8 @@ type Server struct {
|
|||||||
|
|
||||||
cache cache
|
cache cache
|
||||||
|
|
||||||
|
ratelimitBuckets *gocache.Cache // where the ratelimiters are stored, per IP
|
||||||
|
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
ServerConfig
|
ServerConfig
|
||||||
}
|
}
|
||||||
@ -76,9 +79,13 @@ func (s *Server) RUnlock() {
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
type FilteringConfig struct {
|
type FilteringConfig struct {
|
||||||
ProtectionEnabled bool `yaml:"protection_enabled"`
|
ProtectionEnabled bool `yaml:"protection_enabled"`
|
||||||
FilteringEnabled bool `yaml:"filtering_enabled"`
|
FilteringEnabled bool `yaml:"filtering_enabled"`
|
||||||
BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` // if 0, then default is used (3600)
|
BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` // if 0, then default is used (3600)
|
||||||
|
QueryLogEnabled bool `yaml:"querylog_enabled"`
|
||||||
|
Ratelimit int `yaml:"ratelimit"`
|
||||||
|
RatelimitWhitelist []string `yaml:"ratelimit_whitelist"`
|
||||||
|
RefuseAny bool `yaml:"refuse_any"`
|
||||||
|
|
||||||
dnsfilter.Config `yaml:",inline"`
|
dnsfilter.Config `yaml:",inline"`
|
||||||
}
|
}
|
||||||
@ -92,6 +99,7 @@ type ServerConfig struct {
|
|||||||
FilteringConfig
|
FilteringConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if any of ServerConfig values are zero, then default values from below are used
|
||||||
var defaultValues = ServerConfig{
|
var defaultValues = ServerConfig{
|
||||||
UDPListenAddr: &net.UDPAddr{Port: 53},
|
UDPListenAddr: &net.UDPAddr{Port: 53},
|
||||||
FilteringConfig: FilteringConfig{BlockedResponseTTL: 3600},
|
FilteringConfig: FilteringConfig{BlockedResponseTTL: 3600},
|
||||||
@ -413,6 +421,10 @@ func (s *Server) handlePacketInternal(msg *dns.Msg, addr net.Addr, conn *net.UDP
|
|||||||
return s.genServerFailure(msg), nil, nil, nil
|
return s.genServerFailure(msg), nil, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if msg.Question[0].Qtype == dns.TypeANY && s.RefuseAny {
|
||||||
|
return s.genNotImpl(msg), nil, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
// use dnsfilter before cache -- changed settings or filters would require cache invalidation otherwise
|
// use dnsfilter before cache -- changed settings or filters would require cache invalidation otherwise
|
||||||
host := strings.TrimSuffix(msg.Question[0].Name, ".")
|
host := strings.TrimSuffix(msg.Question[0].Name, ".")
|
||||||
res, err := s.dnsFilter.CheckHost(host)
|
res, err := s.dnsFilter.CheckHost(host)
|
||||||
@ -450,16 +462,36 @@ func (s *Server) handlePacketInternal(msg *dns.Msg, addr net.Addr, conn *net.UDP
|
|||||||
|
|
||||||
func (s *Server) handlePacket(p []byte, addr net.Addr, conn *net.UDPConn) {
|
func (s *Server) handlePacket(p []byte, addr net.Addr, conn *net.UDPConn) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
ip, _, err := net.SplitHostPort(addr.String())
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to split %v into host/port: %s", addr, err)
|
||||||
|
// not a fatal error, move on
|
||||||
|
}
|
||||||
|
|
||||||
|
// ratelimit based on IP only, protects CPU cycles and outbound connections
|
||||||
|
if s.isRatelimited(ip) {
|
||||||
|
// log.Printf("Ratelimiting %s based on IP only", ip)
|
||||||
|
return // do nothing, don't reply, we got ratelimited
|
||||||
|
}
|
||||||
|
|
||||||
msg := &dns.Msg{}
|
msg := &dns.Msg{}
|
||||||
err := msg.Unpack(p)
|
err = msg.Unpack(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("got invalid DNS packet: %s", err)
|
log.Printf("got invalid DNS packet: %s", err)
|
||||||
return // do nothing
|
return // do nothing
|
||||||
}
|
}
|
||||||
|
|
||||||
reply, result, upstream, err := s.handlePacketInternal(msg, addr, conn)
|
reply, result, upstream, err := s.handlePacketInternal(msg, addr, conn)
|
||||||
|
|
||||||
if reply != nil {
|
if reply != nil {
|
||||||
|
// ratelimit based on reply size now
|
||||||
|
replysize := reply.Len()
|
||||||
|
if s.isRatelimitedForReply(ip, replysize) {
|
||||||
|
log.Printf("Ratelimiting %s based on IP and size %d", ip, replysize)
|
||||||
|
return // do nothing, don't reply, we got ratelimited
|
||||||
|
}
|
||||||
|
|
||||||
|
// we're good to respond
|
||||||
rerr := s.respond(reply, addr, conn)
|
rerr := s.respond(reply, addr, conn)
|
||||||
if rerr != nil {
|
if rerr != nil {
|
||||||
log.Printf("Couldn't respond to UDP packet: %s", err)
|
log.Printf("Couldn't respond to UDP packet: %s", err)
|
||||||
@ -467,16 +499,14 @@ func (s *Server) handlePacket(p []byte, addr net.Addr, conn *net.UDPConn) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// query logging and stats counters
|
// query logging and stats counters
|
||||||
elapsed := time.Since(start)
|
if s.QueryLogEnabled {
|
||||||
upstreamAddr := ""
|
elapsed := time.Since(start)
|
||||||
if upstream != nil {
|
upstreamAddr := ""
|
||||||
upstreamAddr = upstream.Address()
|
if upstream != nil {
|
||||||
|
upstreamAddr = upstream.Address()
|
||||||
|
}
|
||||||
|
logRequest(msg, reply, result, elapsed, ip, upstreamAddr)
|
||||||
}
|
}
|
||||||
host, _, err := net.SplitHostPort(addr.String())
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to split %v into host/port: %s", addr, err)
|
|
||||||
}
|
|
||||||
logRequest(msg, reply, result, elapsed, host, upstreamAddr)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -506,12 +536,22 @@ func (s *Server) respond(resp *dns.Msg, addr net.Addr, conn *net.UDPConn) error
|
|||||||
func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg {
|
func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg {
|
||||||
resp := dns.Msg{}
|
resp := dns.Msg{}
|
||||||
resp.SetRcode(request, dns.RcodeServerFailure)
|
resp.SetRcode(request, dns.RcodeServerFailure)
|
||||||
|
resp.RecursionAvailable = true
|
||||||
|
return &resp
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) genNotImpl(request *dns.Msg) *dns.Msg {
|
||||||
|
resp := dns.Msg{}
|
||||||
|
resp.SetRcode(request, dns.RcodeNotImplemented)
|
||||||
|
resp.RecursionAvailable = true
|
||||||
|
resp.SetEdns0(1452, false) // NOTIMPL without EDNS is treated as 'we don't support EDNS', so explicitly set it
|
||||||
return &resp
|
return &resp
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) genNXDomain(request *dns.Msg) *dns.Msg {
|
func (s *Server) genNXDomain(request *dns.Msg) *dns.Msg {
|
||||||
resp := dns.Msg{}
|
resp := dns.Msg{}
|
||||||
resp.SetRcode(request, dns.RcodeNameError)
|
resp.SetRcode(request, dns.RcodeNameError)
|
||||||
|
resp.RecursionAvailable = true
|
||||||
resp.Ns = s.genSOA(request)
|
resp.Ns = s.genSOA(request)
|
||||||
return &resp
|
return &resp
|
||||||
}
|
}
|
||||||
|
80
dnsforward/ratelimit.go
Normal file
80
dnsforward/ratelimit.go
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
package dnsforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"sort"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/beefsack/go-rate"
|
||||||
|
gocache "github.com/patrickmn/go-cache"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Server) limiterForIP(ip string) interface{} {
|
||||||
|
if s.ratelimitBuckets == nil {
|
||||||
|
s.ratelimitBuckets = gocache.New(time.Hour, time.Hour)
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if ratelimiter for that IP already exists, if not, create
|
||||||
|
value, found := s.ratelimitBuckets.Get(ip)
|
||||||
|
if !found {
|
||||||
|
value = rate.New(s.Ratelimit, time.Second)
|
||||||
|
s.ratelimitBuckets.Set(ip, value, time.Hour)
|
||||||
|
}
|
||||||
|
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) isRatelimited(ip string) bool {
|
||||||
|
if s.Ratelimit == 0 { // 0 -- disabled
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(s.RatelimitWhitelist) > 0 {
|
||||||
|
i := sort.SearchStrings(s.RatelimitWhitelist, ip)
|
||||||
|
|
||||||
|
if i < len(s.RatelimitWhitelist) && s.RatelimitWhitelist[i] == ip {
|
||||||
|
// found, don't ratelimit
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
value := s.limiterForIP(ip)
|
||||||
|
rl, ok := value.(*rate.RateLimiter)
|
||||||
|
if !ok {
|
||||||
|
log.Println("SHOULD NOT HAPPEN: non-bool entry found in safebrowsing lookup cache")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
allow, _ := rl.Try()
|
||||||
|
return !allow
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) isRatelimitedForReply(ip string, size int) bool {
|
||||||
|
if s.Ratelimit == 0 { // 0 -- disabled
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(s.RatelimitWhitelist) > 0 {
|
||||||
|
i := sort.SearchStrings(s.RatelimitWhitelist, ip)
|
||||||
|
|
||||||
|
if i < len(s.RatelimitWhitelist) && s.RatelimitWhitelist[i] == ip {
|
||||||
|
// found, don't ratelimit
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
value := s.limiterForIP(ip)
|
||||||
|
rl, ok := value.(*rate.RateLimiter)
|
||||||
|
if !ok {
|
||||||
|
log.Println("SHOULD NOT HAPPEN: non-bool entry found in safebrowsing lookup cache")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// For large UDP responses we try more times, effectively limiting per bandwidth
|
||||||
|
// The exact number of times depends on the response size
|
||||||
|
for i := 0; i < size/1000; i++ {
|
||||||
|
allow, _ := rl.Try()
|
||||||
|
if !allow { // not allowed -> ratelimited
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
42
dnsforward/ratelimit_test.go
Normal file
42
dnsforward/ratelimit_test.go
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
package dnsforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRatelimiting(t *testing.T) {
|
||||||
|
// rate limit is 1 per sec
|
||||||
|
p := Server{}
|
||||||
|
p.Ratelimit = 1
|
||||||
|
|
||||||
|
limited := p.isRatelimited("127.0.0.1")
|
||||||
|
|
||||||
|
if limited {
|
||||||
|
t.Fatal("First request must have been allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
limited = p.isRatelimited("127.0.0.1")
|
||||||
|
|
||||||
|
if !limited {
|
||||||
|
t.Fatal("Second request must have been ratelimited")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWhitelist(t *testing.T) {
|
||||||
|
// rate limit is 1 per sec with whitelist
|
||||||
|
p := Server{}
|
||||||
|
p.Ratelimit = 1
|
||||||
|
p.RatelimitWhitelist = []string{"127.0.0.1", "127.0.0.2", "127.0.0.125"}
|
||||||
|
|
||||||
|
limited := p.isRatelimited("127.0.0.1")
|
||||||
|
|
||||||
|
if limited {
|
||||||
|
t.Fatal("First request must have been allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
limited = p.isRatelimited("127.0.0.1")
|
||||||
|
|
||||||
|
if limited {
|
||||||
|
t.Fatal("Second request must have been allowed due to whitelist")
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user