Merge pull request #102 in DNS/adguard-dns from fix/improve_ratelimit to master

* commit '166bc72ff34f839d205166b3d560e97e8d6cf208':
  Fix tests
  Fix binary search in the whitelist
  Added ratelimit whitelist and tests
This commit is contained in:
Andrey Meshkov 2018-11-21 00:24:34 +03:00
commit 4e27ad0c8e
2 changed files with 122 additions and 10 deletions

View File

@ -3,6 +3,7 @@ package ratelimit
import ( import (
"errors" "errors"
"log" "log"
"sort"
"strconv" "strconv"
"time" "time"
@ -22,7 +23,6 @@ import (
) )
const defaultRatelimit = 100 const defaultRatelimit = 100
const defaultMaxRateLimitedIPs = 1024 * 1024
var ( var (
tokenBuckets = cache.New(time.Hour, time.Hour) tokenBuckets = cache.New(time.Hour, time.Hour)
@ -44,6 +44,15 @@ func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (
} }
func (p *plug) allowRequest(ip string) (bool, error) { func (p *plug) allowRequest(ip string) (bool, error) {
if len(p.whitelist) > 0 {
i := sort.SearchStrings(p.whitelist, ip)
if i < len(p.whitelist) && p.whitelist[i] == ip {
return true, nil
}
}
if _, found := tokenBuckets.Get(ip); !found { if _, found := tokenBuckets.Get(ip); !found {
tokenBuckets.Set(ip, rate.New(p.ratelimit, time.Second), time.Hour) tokenBuckets.Set(ip, rate.New(p.ratelimit, time.Second), time.Hour)
} }
@ -83,25 +92,46 @@ type plug struct {
Next plugin.Handler Next plugin.Handler
// configuration for creating above // configuration for creating above
ratelimit int // in requests per second per IP ratelimit int // in requests per second per IP
whitelist []string // a list of whitelisted IP addresses
} }
func setup(c *caddy.Controller) error { func setupPlugin(c *caddy.Controller) (*plug, error) {
p := &plug{ratelimit: defaultRatelimit} p := &plug{ratelimit: defaultRatelimit}
config := dnsserver.GetConfig(c)
for c.Next() { for c.Next() {
args := c.RemainingArgs() args := c.RemainingArgs()
if len(args) <= 0 { if len(args) > 0 {
continue ratelimit, err := strconv.Atoi(args[0])
if err != nil {
return nil, c.ArgErr()
}
p.ratelimit = ratelimit
} }
ratelimit, err := strconv.Atoi(args[0]) for c.NextBlock() {
if err != nil { switch c.Val() {
return c.ArgErr() case "whitelist":
p.whitelist = c.RemainingArgs()
if len(p.whitelist) > 0 {
sort.Strings(p.whitelist)
}
}
} }
p.ratelimit = ratelimit
} }
return p, nil
}
func setup(c *caddy.Controller) error {
p, err := setupPlugin(c)
if err != nil {
return err
}
config := dnsserver.GetConfig(c)
config.AddPlugin(func(next plugin.Handler) plugin.Handler { config.AddPlugin(func(next plugin.Handler) plugin.Handler {
p.Next = next p.Next = next
return p return p

View File

@ -0,0 +1,82 @@
package ratelimit
import (
"testing"
"github.com/mholt/caddy"
)
func TestSetup(t *testing.T) {
for i, testcase := range []struct {
config string
failing bool
}{
{`ratelimit`, false},
{`ratelimit 100`, false},
{`ratelimit {
whitelist 127.0.0.1
}`, false},
{`ratelimit 50 {
whitelist 127.0.0.1 176.103.130.130
}`, false},
{`ratelimit test`, true},
} {
c := caddy.NewTestController("dns", testcase.config)
err := setup(c)
if err != nil {
if !testcase.failing {
t.Fatalf("Test #%d expected no errors, but got: %v", i, err)
}
continue
}
if testcase.failing {
t.Fatalf("Test #%d expected to fail but it didn't", i)
}
}
}
func TestRatelimiting(t *testing.T) {
// rate limit is 1 per sec
c := caddy.NewTestController("dns", `ratelimit 1`)
p, err := setupPlugin(c)
if err != nil {
t.Fatal("Failed to initialize the plugin")
}
allowed, err := p.allowRequest("127.0.0.1")
if err != nil || !allowed {
t.Fatal("First request must have been allowed")
}
allowed, err = p.allowRequest("127.0.0.1")
if err != nil || allowed {
t.Fatal("Second request must have been ratelimited")
}
}
func TestWhitelist(t *testing.T) {
// rate limit is 1 per sec
c := caddy.NewTestController("dns", `ratelimit 1 { whitelist 127.0.0.2 127.0.0.1 127.0.0.125 }`)
p, err := setupPlugin(c)
if err != nil {
t.Fatal("Failed to initialize the plugin")
}
allowed, err := p.allowRequest("127.0.0.1")
if err != nil || !allowed {
t.Fatal("First request must have been allowed")
}
allowed, err = p.allowRequest("127.0.0.1")
if err != nil || !allowed {
t.Fatal("Second request must have been allowed due to whitelist")
}
}