diff --git a/common/counters/agents.go b/common/counters/agents.go index f23efdad..fb9bacdb 100644 --- a/common/counters/agents.go +++ b/common/counters/agents.go @@ -14,8 +14,7 @@ type DefaultAgentViewCounter struct { insert *sql.Stmt } -func NewDefaultAgentViewCounter() (*DefaultAgentViewCounter, error) { - acc := qgen.NewAcc() +func NewDefaultAgentViewCounter(acc *qgen.Accumulator) (*DefaultAgentViewCounter, error) { var agentBuckets = make([]*RWMutexCounterBucket, len(agentMapEnum)) for bucketID, _ := range agentBuckets { agentBuckets[bucketID] = &RWMutexCounterBucket{counter: 0} diff --git a/common/ratelimit.go b/common/ratelimit.go new file mode 100644 index 00000000..dfb0080b --- /dev/null +++ b/common/ratelimit.go @@ -0,0 +1,100 @@ +package common + +import ( + "errors" + "strconv" + "sync" + "time" +) + +var ErrBadRateLimiter = errors.New("That rate limiter doesn't exist") +var ErrExceededRateLimit = errors.New("You're exceeding a rate limit. Please wait a while before trying again.") + +// TODO: Persist rate limits to disk +type RateLimiter interface { + LimitIP(limit string, ip string) error + LimitUser(limit string, user int) error +} + +type RateData struct { + value int + floorTime int +} + +type RateFence struct { + duration int + max int +} + +// TODO: Optimise this by using something other than a string when possible +type RateLimit struct { + data map[string][]RateData + fences []RateFence + + sync.RWMutex +} + +func NewRateLimit(fences []RateFence) *RateLimit { + for i, fence := range fences { + fences[i].duration = fence.duration * 1000 * 1000 * 1000 + } + return &RateLimit{data: make(map[string][]RateData), fences: fences} +} + +func (l *RateLimit) Limit(name string, ltype int) error { + l.Lock() + defer l.Unlock() + + data, ok := l.data[name] + if !ok { + data = make([]RateData, len(l.fences)) + for i, _ := range data { + data[i] = RateData{0, int(time.Now().Unix())} + } + } + + for i, field := range data { + fence := l.fences[i] + diff := int(time.Now().Unix()) - field.floorTime + + if diff >= fence.duration { + field = RateData{0, int(time.Now().Unix())} + data[i] = field + } + + if field.value > fence.max { + return ErrExceededRateLimit + } + + field.value++ + data[i] = field + } + + return nil +} + +type DefaultRateLimiter struct { + limits map[string]*RateLimit +} + +func NewDefaultRateLimiter() *DefaultRateLimiter { + return &DefaultRateLimiter{map[string]*RateLimit{ + "register": NewRateLimit([]RateFence{RateFence{int(time.Hour / 2), 1}}), + }} +} + +func (l *DefaultRateLimiter) LimitIP(limit string, ip string) error { + limiter, ok := l.limits[limit] + if !ok { + return ErrBadRateLimiter + } + return limiter.Limit(ip, 0) +} + +func (l *DefaultRateLimiter) LimitUser(limit string, user int) error { + limiter, ok := l.limits[limit] + if !ok { + return ErrBadRateLimiter + } + return limiter.Limit(strconv.Itoa(user), 1) +} diff --git a/main.go b/main.go index e68c11db..f91b2fd4 100644 --- a/main.go +++ b/main.go @@ -177,7 +177,7 @@ func afterDBInit() (err error) { if err != nil { return errors.WithStack(err) } - counters.AgentViewCounter, err = counters.NewDefaultAgentViewCounter() + counters.AgentViewCounter, err = counters.NewDefaultAgentViewCounter(acc) if err != nil { return errors.WithStack(err) } diff --git a/query_gen/acc_builders.go b/query_gen/acc_builders.go index 083676d2..71e60758 100644 --- a/query_gen/acc_builders.go +++ b/query_gen/acc_builders.go @@ -285,35 +285,49 @@ func (builder *accInsertBuilder) Run(args ...interface{}) (int, error) { } type accCountBuilder struct { - table string - where string - limit string + table string + where string + limit string + dateCutoff *dateCutoff // We might want to do this in a slightly less hacky way + inChain *AccSelectBuilder + inColumn string build *Accumulator } -func (count *accCountBuilder) Where(where string) *accCountBuilder { - if count.where != "" { - count.where += " AND " +func (b *accCountBuilder) Where(where string) *accCountBuilder { + if b.where != "" { + b.where += " AND " } - count.where += where - return count + b.where += where + return b } -func (count *accCountBuilder) Limit(limit string) *accCountBuilder { - count.limit = limit - return count +func (b *accCountBuilder) Limit(limit string) *accCountBuilder { + b.limit = limit + return b } -// TODO: Add QueryRow for this and use it in statistics.go -func (count *accCountBuilder) Prepare() *sql.Stmt { - return count.build.SimpleCount(count.table, count.where, count.limit) +func (b *accCountBuilder) DateCutoff(column string, quantity int, unit string) *accCountBuilder { + b.dateCutoff = &dateCutoff{column, quantity, unit} + return b } -func (count *accCountBuilder) Total() (total int, err error) { - stmt := count.Prepare() +// TODO: Fix this nasty hack +func (b *accCountBuilder) Prepare() *sql.Stmt { + // TODO: Phase out the procedural API and use the adapter's OO API? The OO API might need a bit more work before we do that and it needs to be rolled out to MSSQL. + if b.dateCutoff != nil || b.inChain != nil { + selBuilder := b.build.GetAdapter().Builder().Count().FromCountAcc(b) + selBuilder.columns = "COUNT(*)" + return b.build.prepare(b.build.GetAdapter().ComplexSelect(selBuilder)) + } + return b.build.SimpleCount(b.table, b.where, b.limit) +} + +func (b *accCountBuilder) Total() (total int, err error) { + stmt := b.Prepare() if stmt == nil { - return 0, count.build.FirstError() + return 0, b.build.FirstError() } err = stmt.QueryRow().Scan(&total) return total, err diff --git a/query_gen/accumulator.go b/query_gen/accumulator.go index 457c3a03..f2673deb 100644 --- a/query_gen/accumulator.go +++ b/query_gen/accumulator.go @@ -240,5 +240,5 @@ func (build *Accumulator) Insert(table string) *accInsertBuilder { } func (build *Accumulator) Count(table string) *accCountBuilder { - return &accCountBuilder{table, "", "", build} + return &accCountBuilder{table, "", "", nil, nil, "", build} } diff --git a/query_gen/micro_builders.go b/query_gen/micro_builders.go index 271e8ee1..23a73364 100644 --- a/query_gen/micro_builders.go +++ b/query_gen/micro_builders.go @@ -15,6 +15,11 @@ func (build *prebuilder) Select(nlist ...string) *selectPrebuilder { return &selectPrebuilder{name, "", "", "", "", "", nil, nil, "", build.adapter} } +func (build *prebuilder) Count(nlist ...string) *selectPrebuilder { + name := optString(nlist, "") + return &selectPrebuilder{name, "", "COUNT(*)", "", "", "", nil, nil, "", build.adapter} +} + func (build *prebuilder) Insert(nlist ...string) *insertPrebuilder { name := optString(nlist, "") return &insertPrebuilder{name, "", "", "", build.adapter} @@ -136,35 +141,50 @@ func (selectItem *selectPrebuilder) Where(where string) *selectPrebuilder { return selectItem } -func (selectItem *selectPrebuilder) InQ(subBuilder *selectPrebuilder) *selectPrebuilder { - selectItem.inChain = subBuilder - return selectItem +func (b *selectPrebuilder) InQ(subBuilder *selectPrebuilder) *selectPrebuilder { + b.inChain = subBuilder + return b } -func (selectItem *selectPrebuilder) Orderby(orderby string) *selectPrebuilder { - selectItem.orderby = orderby - return selectItem +func (b *selectPrebuilder) Orderby(orderby string) *selectPrebuilder { + b.orderby = orderby + return b } -func (selectItem *selectPrebuilder) Limit(limit string) *selectPrebuilder { - selectItem.limit = limit - return selectItem +func (b *selectPrebuilder) Limit(limit string) *selectPrebuilder { + b.limit = limit + return b } // TODO: We probably want to avoid the double allocation of two builders somehow -func (selectItem *selectPrebuilder) FromAcc(accBuilder *AccSelectBuilder) *selectPrebuilder { - selectItem.table = accBuilder.table - selectItem.columns = accBuilder.columns - selectItem.where = accBuilder.where - selectItem.orderby = accBuilder.orderby - selectItem.limit = accBuilder.limit - - selectItem.dateCutoff = accBuilder.dateCutoff - if accBuilder.inChain != nil { - selectItem.inChain = &selectPrebuilder{"", accBuilder.inChain.table, accBuilder.inChain.columns, accBuilder.inChain.where, accBuilder.inChain.orderby, accBuilder.inChain.limit, accBuilder.inChain.dateCutoff, nil, "", selectItem.build} - selectItem.inColumn = accBuilder.inColumn +func (b *selectPrebuilder) FromAcc(acc *AccSelectBuilder) *selectPrebuilder { + b.table = acc.table + if acc.columns != "" { + b.columns = acc.columns } - return selectItem + b.where = acc.where + b.orderby = acc.orderby + b.limit = acc.limit + + b.dateCutoff = acc.dateCutoff + if acc.inChain != nil { + b.inChain = &selectPrebuilder{"", acc.inChain.table, acc.inChain.columns, acc.inChain.where, acc.inChain.orderby, acc.inChain.limit, acc.inChain.dateCutoff, nil, "", b.build} + b.inColumn = acc.inColumn + } + return b +} + +func (b *selectPrebuilder) FromCountAcc(acc *accCountBuilder) *selectPrebuilder { + b.table = acc.table + b.where = acc.where + b.limit = acc.limit + + b.dateCutoff = acc.dateCutoff + if acc.inChain != nil { + b.inChain = &selectPrebuilder{"", acc.inChain.table, acc.inChain.columns, acc.inChain.where, acc.inChain.orderby, acc.inChain.limit, acc.inChain.dateCutoff, nil, "", b.build} + b.inColumn = acc.inColumn + } + return b } // TODO: Add support for dateCutoff @@ -209,7 +229,7 @@ func (insert *insertPrebuilder) Parse() { insert.build.SimpleInsert(insert.name, insert.table, insert.columns, insert.fields) } -type countPrebuilder struct { +/*type countPrebuilder struct { name string table string where string @@ -242,7 +262,7 @@ func (count *countPrebuilder) Text() (string, error) { func (count *countPrebuilder) Parse() { count.build.SimpleCount(count.name, count.table, count.where, count.limit) -} +}*/ func optString(nlist []string, defaultStr string) string { if len(nlist) == 0 { diff --git a/query_gen/mysql.go b/query_gen/mysql.go index 4bd91372..524258a2 100644 --- a/query_gen/mysql.go +++ b/query_gen/mysql.go @@ -519,13 +519,13 @@ func (adapter *MysqlAdapter) ComplexSelect(preBuilder *selectPrebuilder) (out st return "", errors.New("No columns found for ComplexSelect") } - var querystr = "SELECT " + var querystr = "SELECT " + adapter.buildJoinColumns(preBuilder.columns) // Slice up the user friendly strings into something easier to process - for _, column := range strings.Split(strings.TrimSpace(preBuilder.columns), ",") { + /*for _, column := range strings.Split(strings.TrimSpace(preBuilder.columns), ",") { querystr += "`" + strings.TrimSpace(column) + "`," } - querystr = querystr[0 : len(querystr)-1] + querystr = querystr[0 : len(querystr)-1]*/ var whereStr string // TODO: Let callers have a Where() and a InQ() diff --git a/routes/account.go b/routes/account.go index 6049cb69..a36a3732 100644 --- a/routes/account.go +++ b/routes/account.go @@ -806,7 +806,7 @@ func AccountPasswordResetSubmit(w http.ResponseWriter, r *http.Request, user com return common.InternalError(err, w, r) } - // TODO: Move this query somewhere else + // TODO: Move these queries somewhere else var disc string err = qgen.NewAcc().Select("password_resets").Columns("createdAt").DateCutoff("createdAt", 1, "hour").QueryRow().Scan(&disc) if err != nil && err != sql.ErrNoRows { @@ -816,6 +816,22 @@ func AccountPasswordResetSubmit(w http.ResponseWriter, r *http.Request, user com return common.LocalError("You can only send a password reset email for a user once an hour", w, r, user) } + count, err := qgen.NewAcc().Count("password_resets").DateCutoff("createdAt", 6, "hour").Total() + if err != nil && err != sql.ErrNoRows { + return common.InternalError(err, w, r) + } + if count >= 3 { + return common.LocalError("You can only send a password reset email for a user three times every six hours", w, r, user) + } + + count, err = qgen.NewAcc().Count("password_resets").DateCutoff("createdAt", 12, "hour").Total() + if err != nil && err != sql.ErrNoRows { + return common.InternalError(err, w, r) + } + if count >= 4 { + return common.LocalError("You can only send a password reset email for a user four times every twelve hours", w, r, user) + } + err = common.PasswordResetter.Create(tuser.Email, tuser.ID, token) if err != nil { return common.InternalError(err, w, r) diff --git a/templates/widget_search_and_filter.html b/templates/widget_search_and_filter.html index 9e9c7a7e..ac57acfc 100644 --- a/templates/widget_search_and_filter.html +++ b/templates/widget_search_and_filter.html @@ -1,5 +1,5 @@
{{range .Forums}}
{{.Name}}