diff --git a/coredns_plugin/ratelimit/ratelimit.go b/coredns_plugin/ratelimit/ratelimit.go index 2a43644c..b14a907d 100644 --- a/coredns_plugin/ratelimit/ratelimit.go +++ b/coredns_plugin/ratelimit/ratelimit.go @@ -3,6 +3,7 @@ package ratelimit import ( "errors" "log" + "sort" "strconv" "time" @@ -22,7 +23,6 @@ import ( ) const defaultRatelimit = 100 -const defaultMaxRateLimitedIPs = 1024 * 1024 var ( tokenBuckets = cache.New(time.Hour, time.Hour) @@ -44,6 +44,11 @@ func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) ( } func (p *plug) allowRequest(ip string) (bool, error) { + + if len(p.whitelist) > 0 && sort.SearchStrings(p.whitelist, ip) >= 0 { + return true, nil + } + if _, found := tokenBuckets.Get(ip); !found { tokenBuckets.Set(ip, rate.New(p.ratelimit, time.Second), time.Hour) } @@ -83,25 +88,46 @@ type plug struct { Next plugin.Handler // configuration for creating above - ratelimit int // in requests per second per IP + ratelimit int // in requests per second per IP + whitelist []string // a list of whitelisted IP addresses + } -func setup(c *caddy.Controller) error { +func setupPlugin(c *caddy.Controller) (*plug, error) { + p := &plug{ratelimit: defaultRatelimit} - config := dnsserver.GetConfig(c) for c.Next() { args := c.RemainingArgs() - if len(args) <= 0 { - continue + if len(args) > 0 { + ratelimit, err := strconv.Atoi(args[0]) + if err != nil { + return nil, c.ArgErr() + } + p.ratelimit = ratelimit } - ratelimit, err := strconv.Atoi(args[0]) - if err != nil { - return c.ArgErr() + for c.NextBlock() { + switch c.Val() { + case "whitelist": + p.whitelist = c.RemainingArgs() + + if len(p.whitelist) > 0 { + sort.Strings(p.whitelist) + } + } } - p.ratelimit = ratelimit } + return p, nil +} + +func setup(c *caddy.Controller) error { + p, err := setupPlugin(c) + if err != nil { + return err + } + + config := dnsserver.GetConfig(c) config.AddPlugin(func(next plugin.Handler) plugin.Handler { p.Next = next return p diff --git a/coredns_plugin/ratelimit/ratelimit_test.go b/coredns_plugin/ratelimit/ratelimit_test.go new file mode 100644 index 00000000..9fb05cb7 --- /dev/null +++ b/coredns_plugin/ratelimit/ratelimit_test.go @@ -0,0 +1,82 @@ +package ratelimit + +import ( + "testing" + + "github.com/mholt/caddy" +) + +func TestSetup(t *testing.T) { + for i, testcase := range []struct { + config string + failing bool + }{ + {`ratelimit`, false}, + {`ratelimit 100`, false}, + {`ratelimit { + whitelist 127.0.0.1 + }`, false}, + {`ratelimit 50 { + whitelist 127.0.0.1 176.103.130.130 + }`, false}, + {`ratelimit test`, true}, + } { + c := caddy.NewTestController("dns", testcase.config) + err := setup(c) + if err != nil { + if !testcase.failing { + t.Fatalf("Test #%d expected no errors, but got: %v", i, err) + } + continue + } + if testcase.failing { + t.Fatalf("Test #%d expected to fail but it didn't", i) + } + } +} + +func TestRatelimiting(t *testing.T) { + + // rate limit is 1 per sec + c := caddy.NewTestController("dns", `ratelimit 1`) + p, err := setupPlugin(c) + + if err != nil { + t.Fatal("Failed to initialize the plugin") + } + + allowed, err := p.allowRequest("127.0.0.1") + + if err != nil || !allowed { + t.Fatal("First request must have been allowed") + } + + allowed, err = p.allowRequest("127.0.0.1") + + if err != nil || allowed { + t.Fatal("Second request must have been ratelimited") + } +} + +func TestWhitelist(t *testing.T) { + + // rate limit is 1 per sec + c := caddy.NewTestController("dns", `ratelimit 1 { whitelist 127.0.0.1 }`) + p, err := setupPlugin(c) + + if err != nil { + t.Fatal("Failed to initialize the plugin") + } + + allowed, err := p.allowRequest("127.0.0.1") + + if err != nil || !allowed { + t.Fatal("First request must have been allowed") + } + + allowed, err = p.allowRequest("127.0.0.1") + + if err != nil || !allowed { + t.Fatal("Second request must have been allowed due to whitelist") + } +}