Merge pull request #102 in DNS/adguard-dns from fix/improve_ratelimit to master
* commit '166bc72ff34f839d205166b3d560e97e8d6cf208': Fix tests Fix binary search in the whitelist Added ratelimit whitelist and tests
This commit is contained in:
commit
4e27ad0c8e
|
@ -3,6 +3,7 @@ package ratelimit
|
|||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"sort"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
|
@ -22,7 +23,6 @@ import (
|
|||
)
|
||||
|
||||
const defaultRatelimit = 100
|
||||
const defaultMaxRateLimitedIPs = 1024 * 1024
|
||||
|
||||
var (
|
||||
tokenBuckets = cache.New(time.Hour, time.Hour)
|
||||
|
@ -44,6 +44,15 @@ func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (
|
|||
}
|
||||
|
||||
func (p *plug) allowRequest(ip string) (bool, error) {
|
||||
|
||||
if len(p.whitelist) > 0 {
|
||||
i := sort.SearchStrings(p.whitelist, ip)
|
||||
|
||||
if i < len(p.whitelist) && p.whitelist[i] == ip {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
if _, found := tokenBuckets.Get(ip); !found {
|
||||
tokenBuckets.Set(ip, rate.New(p.ratelimit, time.Second), time.Hour)
|
||||
}
|
||||
|
@ -84,24 +93,45 @@ type plug struct {
|
|||
|
||||
// configuration for creating above
|
||||
ratelimit int // in requests per second per IP
|
||||
whitelist []string // a list of whitelisted IP addresses
|
||||
|
||||
}
|
||||
|
||||
func setup(c *caddy.Controller) error {
|
||||
func setupPlugin(c *caddy.Controller) (*plug, error) {
|
||||
|
||||
p := &plug{ratelimit: defaultRatelimit}
|
||||
config := dnsserver.GetConfig(c)
|
||||
|
||||
for c.Next() {
|
||||
args := c.RemainingArgs()
|
||||
if len(args) <= 0 {
|
||||
continue
|
||||
}
|
||||
if len(args) > 0 {
|
||||
ratelimit, err := strconv.Atoi(args[0])
|
||||
if err != nil {
|
||||
return c.ArgErr()
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
p.ratelimit = ratelimit
|
||||
}
|
||||
for c.NextBlock() {
|
||||
switch c.Val() {
|
||||
case "whitelist":
|
||||
p.whitelist = c.RemainingArgs()
|
||||
|
||||
if len(p.whitelist) > 0 {
|
||||
sort.Strings(p.whitelist)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func setup(c *caddy.Controller) error {
|
||||
p, err := setupPlugin(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config := dnsserver.GetConfig(c)
|
||||
config.AddPlugin(func(next plugin.Handler) plugin.Handler {
|
||||
p.Next = next
|
||||
return p
|
||||
|
|
|
@ -0,0 +1,82 @@
|
|||
package ratelimit
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
)
|
||||
|
||||
func TestSetup(t *testing.T) {
|
||||
for i, testcase := range []struct {
|
||||
config string
|
||||
failing bool
|
||||
}{
|
||||
{`ratelimit`, false},
|
||||
{`ratelimit 100`, false},
|
||||
{`ratelimit {
|
||||
whitelist 127.0.0.1
|
||||
}`, false},
|
||||
{`ratelimit 50 {
|
||||
whitelist 127.0.0.1 176.103.130.130
|
||||
}`, false},
|
||||
{`ratelimit test`, true},
|
||||
} {
|
||||
c := caddy.NewTestController("dns", testcase.config)
|
||||
err := setup(c)
|
||||
if err != nil {
|
||||
if !testcase.failing {
|
||||
t.Fatalf("Test #%d expected no errors, but got: %v", i, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if testcase.failing {
|
||||
t.Fatalf("Test #%d expected to fail but it didn't", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRatelimiting(t *testing.T) {
|
||||
|
||||
// rate limit is 1 per sec
|
||||
c := caddy.NewTestController("dns", `ratelimit 1`)
|
||||
p, err := setupPlugin(c)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal("Failed to initialize the plugin")
|
||||
}
|
||||
|
||||
allowed, err := p.allowRequest("127.0.0.1")
|
||||
|
||||
if err != nil || !allowed {
|
||||
t.Fatal("First request must have been allowed")
|
||||
}
|
||||
|
||||
allowed, err = p.allowRequest("127.0.0.1")
|
||||
|
||||
if err != nil || allowed {
|
||||
t.Fatal("Second request must have been ratelimited")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhitelist(t *testing.T) {
|
||||
|
||||
// rate limit is 1 per sec
|
||||
c := caddy.NewTestController("dns", `ratelimit 1 { whitelist 127.0.0.2 127.0.0.1 127.0.0.125 }`)
|
||||
p, err := setupPlugin(c)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal("Failed to initialize the plugin")
|
||||
}
|
||||
|
||||
allowed, err := p.allowRequest("127.0.0.1")
|
||||
|
||||
if err != nil || !allowed {
|
||||
t.Fatal("First request must have been allowed")
|
||||
}
|
||||
|
||||
allowed, err = p.allowRequest("127.0.0.1")
|
||||
|
||||
if err != nil || !allowed {
|
||||
t.Fatal("Second request must have been allowed due to whitelist")
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue