Added ratelimit whitelist and tests
This commit is contained in:
parent
54ee16634c
commit
345e4dc89a
|
@ -3,6 +3,7 @@ package ratelimit
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"log"
|
"log"
|
||||||
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -22,7 +23,6 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultRatelimit = 100
|
const defaultRatelimit = 100
|
||||||
const defaultMaxRateLimitedIPs = 1024 * 1024
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
tokenBuckets = cache.New(time.Hour, time.Hour)
|
tokenBuckets = cache.New(time.Hour, time.Hour)
|
||||||
|
@ -44,6 +44,11 @@ func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *plug) allowRequest(ip string) (bool, error) {
|
func (p *plug) allowRequest(ip string) (bool, error) {
|
||||||
|
|
||||||
|
if len(p.whitelist) > 0 && sort.SearchStrings(p.whitelist, ip) >= 0 {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
if _, found := tokenBuckets.Get(ip); !found {
|
if _, found := tokenBuckets.Get(ip); !found {
|
||||||
tokenBuckets.Set(ip, rate.New(p.ratelimit, time.Second), time.Hour)
|
tokenBuckets.Set(ip, rate.New(p.ratelimit, time.Second), time.Hour)
|
||||||
}
|
}
|
||||||
|
@ -83,25 +88,46 @@ type plug struct {
|
||||||
Next plugin.Handler
|
Next plugin.Handler
|
||||||
|
|
||||||
// configuration for creating above
|
// configuration for creating above
|
||||||
ratelimit int // in requests per second per IP
|
ratelimit int // in requests per second per IP
|
||||||
|
whitelist []string // a list of whitelisted IP addresses
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func setup(c *caddy.Controller) error {
|
func setupPlugin(c *caddy.Controller) (*plug, error) {
|
||||||
|
|
||||||
p := &plug{ratelimit: defaultRatelimit}
|
p := &plug{ratelimit: defaultRatelimit}
|
||||||
config := dnsserver.GetConfig(c)
|
|
||||||
|
|
||||||
for c.Next() {
|
for c.Next() {
|
||||||
args := c.RemainingArgs()
|
args := c.RemainingArgs()
|
||||||
if len(args) <= 0 {
|
if len(args) > 0 {
|
||||||
continue
|
ratelimit, err := strconv.Atoi(args[0])
|
||||||
|
if err != nil {
|
||||||
|
return nil, c.ArgErr()
|
||||||
|
}
|
||||||
|
p.ratelimit = ratelimit
|
||||||
}
|
}
|
||||||
ratelimit, err := strconv.Atoi(args[0])
|
for c.NextBlock() {
|
||||||
if err != nil {
|
switch c.Val() {
|
||||||
return c.ArgErr()
|
case "whitelist":
|
||||||
|
p.whitelist = c.RemainingArgs()
|
||||||
|
|
||||||
|
if len(p.whitelist) > 0 {
|
||||||
|
sort.Strings(p.whitelist)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
p.ratelimit = ratelimit
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setup(c *caddy.Controller) error {
|
||||||
|
p, err := setupPlugin(c)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
config := dnsserver.GetConfig(c)
|
||||||
config.AddPlugin(func(next plugin.Handler) plugin.Handler {
|
config.AddPlugin(func(next plugin.Handler) plugin.Handler {
|
||||||
p.Next = next
|
p.Next = next
|
||||||
return p
|
return p
|
||||||
|
|
|
@ -0,0 +1,82 @@
|
||||||
|
package ratelimit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/mholt/caddy"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSetup(t *testing.T) {
|
||||||
|
for i, testcase := range []struct {
|
||||||
|
config string
|
||||||
|
failing bool
|
||||||
|
}{
|
||||||
|
{`ratelimit`, false},
|
||||||
|
{`ratelimit 100`, false},
|
||||||
|
{`ratelimit {
|
||||||
|
whitelist 127.0.0.1
|
||||||
|
}`, false},
|
||||||
|
{`ratelimit 50 {
|
||||||
|
whitelist 127.0.0.1 176.103.130.130
|
||||||
|
}`, false},
|
||||||
|
{`ratelimit test`, true},
|
||||||
|
} {
|
||||||
|
c := caddy.NewTestController("dns", testcase.config)
|
||||||
|
err := setup(c)
|
||||||
|
if err != nil {
|
||||||
|
if !testcase.failing {
|
||||||
|
t.Fatalf("Test #%d expected no errors, but got: %v", i, err)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if testcase.failing {
|
||||||
|
t.Fatalf("Test #%d expected to fail but it didn't", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRatelimiting(t *testing.T) {
|
||||||
|
|
||||||
|
// rate limit is 1 per sec
|
||||||
|
c := caddy.NewTestController("dns", `ratelimit 1`)
|
||||||
|
p, err := setupPlugin(c)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to initialize the plugin")
|
||||||
|
}
|
||||||
|
|
||||||
|
allowed, err := p.allowRequest("127.0.0.1")
|
||||||
|
|
||||||
|
if err != nil || !allowed {
|
||||||
|
t.Fatal("First request must have been allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
allowed, err = p.allowRequest("127.0.0.1")
|
||||||
|
|
||||||
|
if err != nil || allowed {
|
||||||
|
t.Fatal("Second request must have been ratelimited")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWhitelist(t *testing.T) {
|
||||||
|
|
||||||
|
// rate limit is 1 per sec
|
||||||
|
c := caddy.NewTestController("dns", `ratelimit 1 { whitelist 127.0.0.1 }`)
|
||||||
|
p, err := setupPlugin(c)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to initialize the plugin")
|
||||||
|
}
|
||||||
|
|
||||||
|
allowed, err := p.allowRequest("127.0.0.1")
|
||||||
|
|
||||||
|
if err != nil || !allowed {
|
||||||
|
t.Fatal("First request must have been allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
allowed, err = p.allowRequest("127.0.0.1")
|
||||||
|
|
||||||
|
if err != nil || !allowed {
|
||||||
|
t.Fatal("Second request must have been allowed due to whitelist")
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue