Tighter ratelimiting for password resets.
Eliminated a line in agents.go Added the DateCutoff method to accCountBuilder. Function columns should now work for ComplexSelect. Added type=search to the Search and Filter Widget search box. Began cleaning some prebuilder logic up. Began work on the generic ratelimit interface.
This commit is contained in:
parent
e22ddfec40
commit
414d9c4817
@ -14,8 +14,7 @@ type DefaultAgentViewCounter struct {
|
||||
insert *sql.Stmt
|
||||
}
|
||||
|
||||
func NewDefaultAgentViewCounter() (*DefaultAgentViewCounter, error) {
|
||||
acc := qgen.NewAcc()
|
||||
func NewDefaultAgentViewCounter(acc *qgen.Accumulator) (*DefaultAgentViewCounter, error) {
|
||||
var agentBuckets = make([]*RWMutexCounterBucket, len(agentMapEnum))
|
||||
for bucketID, _ := range agentBuckets {
|
||||
agentBuckets[bucketID] = &RWMutexCounterBucket{counter: 0}
|
||||
|
100
common/ratelimit.go
Normal file
100
common/ratelimit.go
Normal file
@ -0,0 +1,100 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ErrBadRateLimiter = errors.New("That rate limiter doesn't exist")
|
||||
var ErrExceededRateLimit = errors.New("You're exceeding a rate limit. Please wait a while before trying again.")
|
||||
|
||||
// TODO: Persist rate limits to disk
|
||||
type RateLimiter interface {
|
||||
LimitIP(limit string, ip string) error
|
||||
LimitUser(limit string, user int) error
|
||||
}
|
||||
|
||||
type RateData struct {
|
||||
value int
|
||||
floorTime int
|
||||
}
|
||||
|
||||
type RateFence struct {
|
||||
duration int
|
||||
max int
|
||||
}
|
||||
|
||||
// TODO: Optimise this by using something other than a string when possible
|
||||
type RateLimit struct {
|
||||
data map[string][]RateData
|
||||
fences []RateFence
|
||||
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
func NewRateLimit(fences []RateFence) *RateLimit {
|
||||
for i, fence := range fences {
|
||||
fences[i].duration = fence.duration * 1000 * 1000 * 1000
|
||||
}
|
||||
return &RateLimit{data: make(map[string][]RateData), fences: fences}
|
||||
}
|
||||
|
||||
func (l *RateLimit) Limit(name string, ltype int) error {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
data, ok := l.data[name]
|
||||
if !ok {
|
||||
data = make([]RateData, len(l.fences))
|
||||
for i, _ := range data {
|
||||
data[i] = RateData{0, int(time.Now().Unix())}
|
||||
}
|
||||
}
|
||||
|
||||
for i, field := range data {
|
||||
fence := l.fences[i]
|
||||
diff := int(time.Now().Unix()) - field.floorTime
|
||||
|
||||
if diff >= fence.duration {
|
||||
field = RateData{0, int(time.Now().Unix())}
|
||||
data[i] = field
|
||||
}
|
||||
|
||||
if field.value > fence.max {
|
||||
return ErrExceededRateLimit
|
||||
}
|
||||
|
||||
field.value++
|
||||
data[i] = field
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type DefaultRateLimiter struct {
|
||||
limits map[string]*RateLimit
|
||||
}
|
||||
|
||||
func NewDefaultRateLimiter() *DefaultRateLimiter {
|
||||
return &DefaultRateLimiter{map[string]*RateLimit{
|
||||
"register": NewRateLimit([]RateFence{RateFence{int(time.Hour / 2), 1}}),
|
||||
}}
|
||||
}
|
||||
|
||||
func (l *DefaultRateLimiter) LimitIP(limit string, ip string) error {
|
||||
limiter, ok := l.limits[limit]
|
||||
if !ok {
|
||||
return ErrBadRateLimiter
|
||||
}
|
||||
return limiter.Limit(ip, 0)
|
||||
}
|
||||
|
||||
func (l *DefaultRateLimiter) LimitUser(limit string, user int) error {
|
||||
limiter, ok := l.limits[limit]
|
||||
if !ok {
|
||||
return ErrBadRateLimiter
|
||||
}
|
||||
return limiter.Limit(strconv.Itoa(user), 1)
|
||||
}
|
2
main.go
2
main.go
@ -177,7 +177,7 @@ func afterDBInit() (err error) {
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
counters.AgentViewCounter, err = counters.NewDefaultAgentViewCounter()
|
||||
counters.AgentViewCounter, err = counters.NewDefaultAgentViewCounter(acc)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
@ -288,32 +288,46 @@ type accCountBuilder struct {
|
||||
table string
|
||||
where string
|
||||
limit string
|
||||
dateCutoff *dateCutoff // We might want to do this in a slightly less hacky way
|
||||
inChain *AccSelectBuilder
|
||||
inColumn string
|
||||
|
||||
build *Accumulator
|
||||
}
|
||||
|
||||
func (count *accCountBuilder) Where(where string) *accCountBuilder {
|
||||
if count.where != "" {
|
||||
count.where += " AND "
|
||||
func (b *accCountBuilder) Where(where string) *accCountBuilder {
|
||||
if b.where != "" {
|
||||
b.where += " AND "
|
||||
}
|
||||
count.where += where
|
||||
return count
|
||||
b.where += where
|
||||
return b
|
||||
}
|
||||
|
||||
func (count *accCountBuilder) Limit(limit string) *accCountBuilder {
|
||||
count.limit = limit
|
||||
return count
|
||||
func (b *accCountBuilder) Limit(limit string) *accCountBuilder {
|
||||
b.limit = limit
|
||||
return b
|
||||
}
|
||||
|
||||
// TODO: Add QueryRow for this and use it in statistics.go
|
||||
func (count *accCountBuilder) Prepare() *sql.Stmt {
|
||||
return count.build.SimpleCount(count.table, count.where, count.limit)
|
||||
func (b *accCountBuilder) DateCutoff(column string, quantity int, unit string) *accCountBuilder {
|
||||
b.dateCutoff = &dateCutoff{column, quantity, unit}
|
||||
return b
|
||||
}
|
||||
|
||||
func (count *accCountBuilder) Total() (total int, err error) {
|
||||
stmt := count.Prepare()
|
||||
// TODO: Fix this nasty hack
|
||||
func (b *accCountBuilder) Prepare() *sql.Stmt {
|
||||
// TODO: Phase out the procedural API and use the adapter's OO API? The OO API might need a bit more work before we do that and it needs to be rolled out to MSSQL.
|
||||
if b.dateCutoff != nil || b.inChain != nil {
|
||||
selBuilder := b.build.GetAdapter().Builder().Count().FromCountAcc(b)
|
||||
selBuilder.columns = "COUNT(*)"
|
||||
return b.build.prepare(b.build.GetAdapter().ComplexSelect(selBuilder))
|
||||
}
|
||||
return b.build.SimpleCount(b.table, b.where, b.limit)
|
||||
}
|
||||
|
||||
func (b *accCountBuilder) Total() (total int, err error) {
|
||||
stmt := b.Prepare()
|
||||
if stmt == nil {
|
||||
return 0, count.build.FirstError()
|
||||
return 0, b.build.FirstError()
|
||||
}
|
||||
err = stmt.QueryRow().Scan(&total)
|
||||
return total, err
|
||||
|
@ -240,5 +240,5 @@ func (build *Accumulator) Insert(table string) *accInsertBuilder {
|
||||
}
|
||||
|
||||
func (build *Accumulator) Count(table string) *accCountBuilder {
|
||||
return &accCountBuilder{table, "", "", build}
|
||||
return &accCountBuilder{table, "", "", nil, nil, "", build}
|
||||
}
|
||||
|
@ -15,6 +15,11 @@ func (build *prebuilder) Select(nlist ...string) *selectPrebuilder {
|
||||
return &selectPrebuilder{name, "", "", "", "", "", nil, nil, "", build.adapter}
|
||||
}
|
||||
|
||||
func (build *prebuilder) Count(nlist ...string) *selectPrebuilder {
|
||||
name := optString(nlist, "")
|
||||
return &selectPrebuilder{name, "", "COUNT(*)", "", "", "", nil, nil, "", build.adapter}
|
||||
}
|
||||
|
||||
func (build *prebuilder) Insert(nlist ...string) *insertPrebuilder {
|
||||
name := optString(nlist, "")
|
||||
return &insertPrebuilder{name, "", "", "", build.adapter}
|
||||
@ -136,35 +141,50 @@ func (selectItem *selectPrebuilder) Where(where string) *selectPrebuilder {
|
||||
return selectItem
|
||||
}
|
||||
|
||||
func (selectItem *selectPrebuilder) InQ(subBuilder *selectPrebuilder) *selectPrebuilder {
|
||||
selectItem.inChain = subBuilder
|
||||
return selectItem
|
||||
func (b *selectPrebuilder) InQ(subBuilder *selectPrebuilder) *selectPrebuilder {
|
||||
b.inChain = subBuilder
|
||||
return b
|
||||
}
|
||||
|
||||
func (selectItem *selectPrebuilder) Orderby(orderby string) *selectPrebuilder {
|
||||
selectItem.orderby = orderby
|
||||
return selectItem
|
||||
func (b *selectPrebuilder) Orderby(orderby string) *selectPrebuilder {
|
||||
b.orderby = orderby
|
||||
return b
|
||||
}
|
||||
|
||||
func (selectItem *selectPrebuilder) Limit(limit string) *selectPrebuilder {
|
||||
selectItem.limit = limit
|
||||
return selectItem
|
||||
func (b *selectPrebuilder) Limit(limit string) *selectPrebuilder {
|
||||
b.limit = limit
|
||||
return b
|
||||
}
|
||||
|
||||
// TODO: We probably want to avoid the double allocation of two builders somehow
|
||||
func (selectItem *selectPrebuilder) FromAcc(accBuilder *AccSelectBuilder) *selectPrebuilder {
|
||||
selectItem.table = accBuilder.table
|
||||
selectItem.columns = accBuilder.columns
|
||||
selectItem.where = accBuilder.where
|
||||
selectItem.orderby = accBuilder.orderby
|
||||
selectItem.limit = accBuilder.limit
|
||||
|
||||
selectItem.dateCutoff = accBuilder.dateCutoff
|
||||
if accBuilder.inChain != nil {
|
||||
selectItem.inChain = &selectPrebuilder{"", accBuilder.inChain.table, accBuilder.inChain.columns, accBuilder.inChain.where, accBuilder.inChain.orderby, accBuilder.inChain.limit, accBuilder.inChain.dateCutoff, nil, "", selectItem.build}
|
||||
selectItem.inColumn = accBuilder.inColumn
|
||||
func (b *selectPrebuilder) FromAcc(acc *AccSelectBuilder) *selectPrebuilder {
|
||||
b.table = acc.table
|
||||
if acc.columns != "" {
|
||||
b.columns = acc.columns
|
||||
}
|
||||
return selectItem
|
||||
b.where = acc.where
|
||||
b.orderby = acc.orderby
|
||||
b.limit = acc.limit
|
||||
|
||||
b.dateCutoff = acc.dateCutoff
|
||||
if acc.inChain != nil {
|
||||
b.inChain = &selectPrebuilder{"", acc.inChain.table, acc.inChain.columns, acc.inChain.where, acc.inChain.orderby, acc.inChain.limit, acc.inChain.dateCutoff, nil, "", b.build}
|
||||
b.inColumn = acc.inColumn
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *selectPrebuilder) FromCountAcc(acc *accCountBuilder) *selectPrebuilder {
|
||||
b.table = acc.table
|
||||
b.where = acc.where
|
||||
b.limit = acc.limit
|
||||
|
||||
b.dateCutoff = acc.dateCutoff
|
||||
if acc.inChain != nil {
|
||||
b.inChain = &selectPrebuilder{"", acc.inChain.table, acc.inChain.columns, acc.inChain.where, acc.inChain.orderby, acc.inChain.limit, acc.inChain.dateCutoff, nil, "", b.build}
|
||||
b.inColumn = acc.inColumn
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// TODO: Add support for dateCutoff
|
||||
@ -209,7 +229,7 @@ func (insert *insertPrebuilder) Parse() {
|
||||
insert.build.SimpleInsert(insert.name, insert.table, insert.columns, insert.fields)
|
||||
}
|
||||
|
||||
type countPrebuilder struct {
|
||||
/*type countPrebuilder struct {
|
||||
name string
|
||||
table string
|
||||
where string
|
||||
@ -242,7 +262,7 @@ func (count *countPrebuilder) Text() (string, error) {
|
||||
|
||||
func (count *countPrebuilder) Parse() {
|
||||
count.build.SimpleCount(count.name, count.table, count.where, count.limit)
|
||||
}
|
||||
}*/
|
||||
|
||||
func optString(nlist []string, defaultStr string) string {
|
||||
if len(nlist) == 0 {
|
||||
|
@ -519,13 +519,13 @@ func (adapter *MysqlAdapter) ComplexSelect(preBuilder *selectPrebuilder) (out st
|
||||
return "", errors.New("No columns found for ComplexSelect")
|
||||
}
|
||||
|
||||
var querystr = "SELECT "
|
||||
var querystr = "SELECT " + adapter.buildJoinColumns(preBuilder.columns)
|
||||
|
||||
// Slice up the user friendly strings into something easier to process
|
||||
for _, column := range strings.Split(strings.TrimSpace(preBuilder.columns), ",") {
|
||||
/*for _, column := range strings.Split(strings.TrimSpace(preBuilder.columns), ",") {
|
||||
querystr += "`" + strings.TrimSpace(column) + "`,"
|
||||
}
|
||||
querystr = querystr[0 : len(querystr)-1]
|
||||
querystr = querystr[0 : len(querystr)-1]*/
|
||||
|
||||
var whereStr string
|
||||
// TODO: Let callers have a Where() and a InQ()
|
||||
|
@ -806,7 +806,7 @@ func AccountPasswordResetSubmit(w http.ResponseWriter, r *http.Request, user com
|
||||
return common.InternalError(err, w, r)
|
||||
}
|
||||
|
||||
// TODO: Move this query somewhere else
|
||||
// TODO: Move these queries somewhere else
|
||||
var disc string
|
||||
err = qgen.NewAcc().Select("password_resets").Columns("createdAt").DateCutoff("createdAt", 1, "hour").QueryRow().Scan(&disc)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
@ -816,6 +816,22 @@ func AccountPasswordResetSubmit(w http.ResponseWriter, r *http.Request, user com
|
||||
return common.LocalError("You can only send a password reset email for a user once an hour", w, r, user)
|
||||
}
|
||||
|
||||
count, err := qgen.NewAcc().Count("password_resets").DateCutoff("createdAt", 6, "hour").Total()
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return common.InternalError(err, w, r)
|
||||
}
|
||||
if count >= 3 {
|
||||
return common.LocalError("You can only send a password reset email for a user three times every six hours", w, r, user)
|
||||
}
|
||||
|
||||
count, err = qgen.NewAcc().Count("password_resets").DateCutoff("createdAt", 12, "hour").Total()
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return common.InternalError(err, w, r)
|
||||
}
|
||||
if count >= 4 {
|
||||
return common.LocalError("You can only send a password reset email for a user four times every twelve hours", w, r, user)
|
||||
}
|
||||
|
||||
err = common.PasswordResetter.Create(tuser.Email, tuser.ID, token)
|
||||
if err != nil {
|
||||
return common.InternalError(err, w, r)
|
||||
|
@ -1,5 +1,5 @@
|
||||
<div class="search widget_search">
|
||||
<input class="widget_search_input" name="widget_search" placeholder="Search" />
|
||||
<input class="widget_search_input" name="widget_search" placeholder="Search" type="search" />
|
||||
</div>
|
||||
<div class="rowblock filter_list widget_filter">
|
||||
{{range .Forums}} <div class="rowitem filter_item{{if .Selected}} filter_selected{{end}}" data-fid="{{.ID}}"><a href="/topics/?fids={{.ID}}" rel="nofollow">{{.Name}}</a></div>
|
||||
|
Loading…
Reference in New Issue
Block a user