Tighter ratelimiting for password resets.

Eliminated a line in agents.go
Added the DateCutoff method to accCountBuilder.
Function columns should now work for ComplexSelect.
Added type=search to the Search and Filter Widget search box.

Began cleaning some prebuilder logic up.
Began work on the generic ratelimit interface.
This commit is contained in:
Azareal 2019-03-12 19:13:57 +10:00
parent e22ddfec40
commit 414d9c4817
9 changed files with 198 additions and 49 deletions

View File

@ -14,8 +14,7 @@ type DefaultAgentViewCounter struct {
insert *sql.Stmt insert *sql.Stmt
} }
func NewDefaultAgentViewCounter() (*DefaultAgentViewCounter, error) { func NewDefaultAgentViewCounter(acc *qgen.Accumulator) (*DefaultAgentViewCounter, error) {
acc := qgen.NewAcc()
var agentBuckets = make([]*RWMutexCounterBucket, len(agentMapEnum)) var agentBuckets = make([]*RWMutexCounterBucket, len(agentMapEnum))
for bucketID, _ := range agentBuckets { for bucketID, _ := range agentBuckets {
agentBuckets[bucketID] = &RWMutexCounterBucket{counter: 0} agentBuckets[bucketID] = &RWMutexCounterBucket{counter: 0}

100
common/ratelimit.go Normal file
View File

@ -0,0 +1,100 @@
package common
import (
"errors"
"strconv"
"sync"
"time"
)
var ErrBadRateLimiter = errors.New("That rate limiter doesn't exist")
var ErrExceededRateLimit = errors.New("You're exceeding a rate limit. Please wait a while before trying again.")
// TODO: Persist rate limits to disk
type RateLimiter interface {
LimitIP(limit string, ip string) error
LimitUser(limit string, user int) error
}
type RateData struct {
value int
floorTime int
}
type RateFence struct {
duration int
max int
}
// TODO: Optimise this by using something other than a string when possible
type RateLimit struct {
data map[string][]RateData
fences []RateFence
sync.RWMutex
}
func NewRateLimit(fences []RateFence) *RateLimit {
for i, fence := range fences {
fences[i].duration = fence.duration * 1000 * 1000 * 1000
}
return &RateLimit{data: make(map[string][]RateData), fences: fences}
}
func (l *RateLimit) Limit(name string, ltype int) error {
l.Lock()
defer l.Unlock()
data, ok := l.data[name]
if !ok {
data = make([]RateData, len(l.fences))
for i, _ := range data {
data[i] = RateData{0, int(time.Now().Unix())}
}
}
for i, field := range data {
fence := l.fences[i]
diff := int(time.Now().Unix()) - field.floorTime
if diff >= fence.duration {
field = RateData{0, int(time.Now().Unix())}
data[i] = field
}
if field.value > fence.max {
return ErrExceededRateLimit
}
field.value++
data[i] = field
}
return nil
}
type DefaultRateLimiter struct {
limits map[string]*RateLimit
}
func NewDefaultRateLimiter() *DefaultRateLimiter {
return &DefaultRateLimiter{map[string]*RateLimit{
"register": NewRateLimit([]RateFence{RateFence{int(time.Hour / 2), 1}}),
}}
}
func (l *DefaultRateLimiter) LimitIP(limit string, ip string) error {
limiter, ok := l.limits[limit]
if !ok {
return ErrBadRateLimiter
}
return limiter.Limit(ip, 0)
}
func (l *DefaultRateLimiter) LimitUser(limit string, user int) error {
limiter, ok := l.limits[limit]
if !ok {
return ErrBadRateLimiter
}
return limiter.Limit(strconv.Itoa(user), 1)
}

View File

@ -177,7 +177,7 @@ func afterDBInit() (err error) {
if err != nil { if err != nil {
return errors.WithStack(err) return errors.WithStack(err)
} }
counters.AgentViewCounter, err = counters.NewDefaultAgentViewCounter() counters.AgentViewCounter, err = counters.NewDefaultAgentViewCounter(acc)
if err != nil { if err != nil {
return errors.WithStack(err) return errors.WithStack(err)
} }

View File

@ -288,32 +288,46 @@ type accCountBuilder struct {
table string table string
where string where string
limit string limit string
dateCutoff *dateCutoff // We might want to do this in a slightly less hacky way
inChain *AccSelectBuilder
inColumn string
build *Accumulator build *Accumulator
} }
func (count *accCountBuilder) Where(where string) *accCountBuilder { func (b *accCountBuilder) Where(where string) *accCountBuilder {
if count.where != "" { if b.where != "" {
count.where += " AND " b.where += " AND "
} }
count.where += where b.where += where
return count return b
} }
func (count *accCountBuilder) Limit(limit string) *accCountBuilder { func (b *accCountBuilder) Limit(limit string) *accCountBuilder {
count.limit = limit b.limit = limit
return count return b
} }
// TODO: Add QueryRow for this and use it in statistics.go func (b *accCountBuilder) DateCutoff(column string, quantity int, unit string) *accCountBuilder {
func (count *accCountBuilder) Prepare() *sql.Stmt { b.dateCutoff = &dateCutoff{column, quantity, unit}
return count.build.SimpleCount(count.table, count.where, count.limit) return b
} }
func (count *accCountBuilder) Total() (total int, err error) { // TODO: Fix this nasty hack
stmt := count.Prepare() func (b *accCountBuilder) Prepare() *sql.Stmt {
// TODO: Phase out the procedural API and use the adapter's OO API? The OO API might need a bit more work before we do that and it needs to be rolled out to MSSQL.
if b.dateCutoff != nil || b.inChain != nil {
selBuilder := b.build.GetAdapter().Builder().Count().FromCountAcc(b)
selBuilder.columns = "COUNT(*)"
return b.build.prepare(b.build.GetAdapter().ComplexSelect(selBuilder))
}
return b.build.SimpleCount(b.table, b.where, b.limit)
}
func (b *accCountBuilder) Total() (total int, err error) {
stmt := b.Prepare()
if stmt == nil { if stmt == nil {
return 0, count.build.FirstError() return 0, b.build.FirstError()
} }
err = stmt.QueryRow().Scan(&total) err = stmt.QueryRow().Scan(&total)
return total, err return total, err

View File

@ -240,5 +240,5 @@ func (build *Accumulator) Insert(table string) *accInsertBuilder {
} }
func (build *Accumulator) Count(table string) *accCountBuilder { func (build *Accumulator) Count(table string) *accCountBuilder {
return &accCountBuilder{table, "", "", build} return &accCountBuilder{table, "", "", nil, nil, "", build}
} }

View File

@ -15,6 +15,11 @@ func (build *prebuilder) Select(nlist ...string) *selectPrebuilder {
return &selectPrebuilder{name, "", "", "", "", "", nil, nil, "", build.adapter} return &selectPrebuilder{name, "", "", "", "", "", nil, nil, "", build.adapter}
} }
func (build *prebuilder) Count(nlist ...string) *selectPrebuilder {
name := optString(nlist, "")
return &selectPrebuilder{name, "", "COUNT(*)", "", "", "", nil, nil, "", build.adapter}
}
func (build *prebuilder) Insert(nlist ...string) *insertPrebuilder { func (build *prebuilder) Insert(nlist ...string) *insertPrebuilder {
name := optString(nlist, "") name := optString(nlist, "")
return &insertPrebuilder{name, "", "", "", build.adapter} return &insertPrebuilder{name, "", "", "", build.adapter}
@ -136,35 +141,50 @@ func (selectItem *selectPrebuilder) Where(where string) *selectPrebuilder {
return selectItem return selectItem
} }
func (selectItem *selectPrebuilder) InQ(subBuilder *selectPrebuilder) *selectPrebuilder { func (b *selectPrebuilder) InQ(subBuilder *selectPrebuilder) *selectPrebuilder {
selectItem.inChain = subBuilder b.inChain = subBuilder
return selectItem return b
} }
func (selectItem *selectPrebuilder) Orderby(orderby string) *selectPrebuilder { func (b *selectPrebuilder) Orderby(orderby string) *selectPrebuilder {
selectItem.orderby = orderby b.orderby = orderby
return selectItem return b
} }
func (selectItem *selectPrebuilder) Limit(limit string) *selectPrebuilder { func (b *selectPrebuilder) Limit(limit string) *selectPrebuilder {
selectItem.limit = limit b.limit = limit
return selectItem return b
} }
// TODO: We probably want to avoid the double allocation of two builders somehow // TODO: We probably want to avoid the double allocation of two builders somehow
func (selectItem *selectPrebuilder) FromAcc(accBuilder *AccSelectBuilder) *selectPrebuilder { func (b *selectPrebuilder) FromAcc(acc *AccSelectBuilder) *selectPrebuilder {
selectItem.table = accBuilder.table b.table = acc.table
selectItem.columns = accBuilder.columns if acc.columns != "" {
selectItem.where = accBuilder.where b.columns = acc.columns
selectItem.orderby = accBuilder.orderby
selectItem.limit = accBuilder.limit
selectItem.dateCutoff = accBuilder.dateCutoff
if accBuilder.inChain != nil {
selectItem.inChain = &selectPrebuilder{"", accBuilder.inChain.table, accBuilder.inChain.columns, accBuilder.inChain.where, accBuilder.inChain.orderby, accBuilder.inChain.limit, accBuilder.inChain.dateCutoff, nil, "", selectItem.build}
selectItem.inColumn = accBuilder.inColumn
} }
return selectItem b.where = acc.where
b.orderby = acc.orderby
b.limit = acc.limit
b.dateCutoff = acc.dateCutoff
if acc.inChain != nil {
b.inChain = &selectPrebuilder{"", acc.inChain.table, acc.inChain.columns, acc.inChain.where, acc.inChain.orderby, acc.inChain.limit, acc.inChain.dateCutoff, nil, "", b.build}
b.inColumn = acc.inColumn
}
return b
}
func (b *selectPrebuilder) FromCountAcc(acc *accCountBuilder) *selectPrebuilder {
b.table = acc.table
b.where = acc.where
b.limit = acc.limit
b.dateCutoff = acc.dateCutoff
if acc.inChain != nil {
b.inChain = &selectPrebuilder{"", acc.inChain.table, acc.inChain.columns, acc.inChain.where, acc.inChain.orderby, acc.inChain.limit, acc.inChain.dateCutoff, nil, "", b.build}
b.inColumn = acc.inColumn
}
return b
} }
// TODO: Add support for dateCutoff // TODO: Add support for dateCutoff
@ -209,7 +229,7 @@ func (insert *insertPrebuilder) Parse() {
insert.build.SimpleInsert(insert.name, insert.table, insert.columns, insert.fields) insert.build.SimpleInsert(insert.name, insert.table, insert.columns, insert.fields)
} }
type countPrebuilder struct { /*type countPrebuilder struct {
name string name string
table string table string
where string where string
@ -242,7 +262,7 @@ func (count *countPrebuilder) Text() (string, error) {
func (count *countPrebuilder) Parse() { func (count *countPrebuilder) Parse() {
count.build.SimpleCount(count.name, count.table, count.where, count.limit) count.build.SimpleCount(count.name, count.table, count.where, count.limit)
} }*/
func optString(nlist []string, defaultStr string) string { func optString(nlist []string, defaultStr string) string {
if len(nlist) == 0 { if len(nlist) == 0 {

View File

@ -519,13 +519,13 @@ func (adapter *MysqlAdapter) ComplexSelect(preBuilder *selectPrebuilder) (out st
return "", errors.New("No columns found for ComplexSelect") return "", errors.New("No columns found for ComplexSelect")
} }
var querystr = "SELECT " var querystr = "SELECT " + adapter.buildJoinColumns(preBuilder.columns)
// Slice up the user friendly strings into something easier to process // Slice up the user friendly strings into something easier to process
for _, column := range strings.Split(strings.TrimSpace(preBuilder.columns), ",") { /*for _, column := range strings.Split(strings.TrimSpace(preBuilder.columns), ",") {
querystr += "`" + strings.TrimSpace(column) + "`," querystr += "`" + strings.TrimSpace(column) + "`,"
} }
querystr = querystr[0 : len(querystr)-1] querystr = querystr[0 : len(querystr)-1]*/
var whereStr string var whereStr string
// TODO: Let callers have a Where() and a InQ() // TODO: Let callers have a Where() and a InQ()

View File

@ -806,7 +806,7 @@ func AccountPasswordResetSubmit(w http.ResponseWriter, r *http.Request, user com
return common.InternalError(err, w, r) return common.InternalError(err, w, r)
} }
// TODO: Move this query somewhere else // TODO: Move these queries somewhere else
var disc string var disc string
err = qgen.NewAcc().Select("password_resets").Columns("createdAt").DateCutoff("createdAt", 1, "hour").QueryRow().Scan(&disc) err = qgen.NewAcc().Select("password_resets").Columns("createdAt").DateCutoff("createdAt", 1, "hour").QueryRow().Scan(&disc)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
@ -816,6 +816,22 @@ func AccountPasswordResetSubmit(w http.ResponseWriter, r *http.Request, user com
return common.LocalError("You can only send a password reset email for a user once an hour", w, r, user) return common.LocalError("You can only send a password reset email for a user once an hour", w, r, user)
} }
count, err := qgen.NewAcc().Count("password_resets").DateCutoff("createdAt", 6, "hour").Total()
if err != nil && err != sql.ErrNoRows {
return common.InternalError(err, w, r)
}
if count >= 3 {
return common.LocalError("You can only send a password reset email for a user three times every six hours", w, r, user)
}
count, err = qgen.NewAcc().Count("password_resets").DateCutoff("createdAt", 12, "hour").Total()
if err != nil && err != sql.ErrNoRows {
return common.InternalError(err, w, r)
}
if count >= 4 {
return common.LocalError("You can only send a password reset email for a user four times every twelve hours", w, r, user)
}
err = common.PasswordResetter.Create(tuser.Email, tuser.ID, token) err = common.PasswordResetter.Create(tuser.Email, tuser.ID, token)
if err != nil { if err != nil {
return common.InternalError(err, w, r) return common.InternalError(err, w, r)

View File

@ -1,5 +1,5 @@
<div class="search widget_search"> <div class="search widget_search">
<input class="widget_search_input" name="widget_search" placeholder="Search" /> <input class="widget_search_input" name="widget_search" placeholder="Search" type="search" />
</div> </div>
<div class="rowblock filter_list widget_filter"> <div class="rowblock filter_list widget_filter">
{{range .Forums}} <div class="rowitem filter_item{{if .Selected}} filter_selected{{end}}" data-fid="{{.ID}}"><a href="/topics/?fids={{.ID}}" rel="nofollow">{{.Name}}</a></div> {{range .Forums}} <div class="rowitem filter_item{{if .Selected}} filter_selected{{end}}" data-fid="{{.ID}}"><a href="/topics/?fids={{.ID}}" rel="nofollow">{{.Name}}</a></div>