6bae378db0
Refactored the code to use the new builder syntax. Fixed the DbInit logic. Made sure the prepared statements are cleaned up. Added the AdminOnly middleware and added it to the routes. Added the Query method to the selectBuilder.
537 lines
14 KiB
Go
537 lines
14 KiB
Go
package common
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"log"
|
|
"strconv"
|
|
"sync"
|
|
"sync/atomic"
|
|
|
|
"../query_gen/lib"
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
// TODO: Add the watchdog goroutine
|
|
// TODO: Add some sort of update method
|
|
var Users UserStore
|
|
var ErrAccountExists = errors.New("this username is already in use")
|
|
|
|
type UserStore interface {
|
|
DirtyGet(id int) *User
|
|
Get(id int) (*User, error)
|
|
Exists(id int) bool
|
|
//BulkGet(ids []int) ([]*User, error)
|
|
BulkGetMap(ids []int) (map[int]*User, error)
|
|
BypassGet(id int) (*User, error)
|
|
Create(username string, password string, email string, group int, active bool) (int, error)
|
|
GlobalCount() int
|
|
}
|
|
|
|
type UserCache interface {
|
|
CacheGet(id int) (*User, error)
|
|
CacheGetUnsafe(id int) (*User, error)
|
|
CacheSet(item *User) error
|
|
CacheAdd(item *User) error
|
|
CacheAddUnsafe(item *User) error
|
|
CacheRemove(id int) error
|
|
CacheRemoveUnsafe(id int) error
|
|
Flush()
|
|
Reload(id int) error
|
|
Length() int
|
|
SetCapacity(capacity int)
|
|
GetCapacity() int
|
|
}
|
|
|
|
type MemoryUserStore struct {
|
|
items map[int]*User
|
|
length int64
|
|
capacity int
|
|
get *sql.Stmt
|
|
exists *sql.Stmt
|
|
register *sql.Stmt
|
|
usernameExists *sql.Stmt
|
|
userCount *sql.Stmt
|
|
sync.RWMutex
|
|
}
|
|
|
|
// NewMemoryUserStore gives you a new instance of MemoryUserStore
|
|
func NewMemoryUserStore(capacity int) (*MemoryUserStore, error) {
|
|
acc := qgen.Builder.Accumulator()
|
|
// TODO: Add an admin version of registerStmt with more flexibility?
|
|
return &MemoryUserStore{
|
|
items: make(map[int]*User),
|
|
capacity: capacity,
|
|
get: acc.SimpleSelect("users", "name, group, is_super_admin, session, email, avatar, message, url_prefix, url_name, level, score, last_ip, temp_group", "uid = ?", "", ""),
|
|
exists: acc.SimpleSelect("users", "uid", "uid = ?", "", ""),
|
|
register: acc.SimpleInsert("users", "name, email, password, salt, group, is_super_admin, session, active, message, createdAt, lastActiveAt", "?,?,?,?,?,0,'',?,'',UTC_TIMESTAMP(),UTC_TIMESTAMP()"),
|
|
usernameExists: acc.SimpleSelect("users", "name", "name = ?", "", ""),
|
|
userCount: acc.SimpleCount("users", "", ""),
|
|
}, acc.FirstError()
|
|
}
|
|
|
|
func (mus *MemoryUserStore) CacheGet(id int) (*User, error) {
|
|
mus.RLock()
|
|
item, ok := mus.items[id]
|
|
mus.RUnlock()
|
|
if ok {
|
|
return item, nil
|
|
}
|
|
return item, ErrNoRows
|
|
}
|
|
|
|
func (mus *MemoryUserStore) CacheGetUnsafe(id int) (*User, error) {
|
|
item, ok := mus.items[id]
|
|
if ok {
|
|
return item, nil
|
|
}
|
|
return item, ErrNoRows
|
|
}
|
|
|
|
func (mus *MemoryUserStore) DirtyGet(id int) *User {
|
|
mus.RLock()
|
|
user, ok := mus.items[id]
|
|
mus.RUnlock()
|
|
if ok {
|
|
return user
|
|
}
|
|
|
|
user = &User{ID: id, Loggedin: true}
|
|
err := mus.get.QueryRow(id).Scan(&user.Name, &user.Group, &user.IsSuperAdmin, &user.Session, &user.Email, &user.Avatar, &user.Message, &user.URLPrefix, &user.URLName, &user.Level, &user.Score, &user.LastIP, &user.TempGroup)
|
|
|
|
user.Init()
|
|
if err == nil {
|
|
mus.CacheSet(user)
|
|
return user
|
|
}
|
|
return BlankUser()
|
|
}
|
|
|
|
func (mus *MemoryUserStore) Get(id int) (*User, error) {
|
|
mus.RLock()
|
|
user, ok := mus.items[id]
|
|
mus.RUnlock()
|
|
if ok {
|
|
return user, nil
|
|
}
|
|
|
|
user = &User{ID: id, Loggedin: true}
|
|
err := mus.get.QueryRow(id).Scan(&user.Name, &user.Group, &user.IsSuperAdmin, &user.Session, &user.Email, &user.Avatar, &user.Message, &user.URLPrefix, &user.URLName, &user.Level, &user.Score, &user.LastIP, &user.TempGroup)
|
|
|
|
user.Init()
|
|
if err == nil {
|
|
mus.CacheSet(user)
|
|
}
|
|
return user, err
|
|
}
|
|
|
|
// WARNING: We did a little hack to make this as thin and quick as possible to reduce lock contention, use the * Cascade* methods instead for normal use
|
|
func (mus *MemoryUserStore) bulkGet(ids []int) (list []*User) {
|
|
list = make([]*User, len(ids))
|
|
mus.RLock()
|
|
for i, id := range ids {
|
|
list[i] = mus.items[id]
|
|
}
|
|
mus.RUnlock()
|
|
return list
|
|
}
|
|
|
|
// TODO: Optimise the query to avoid preparing it on the spot? Maybe, use knowledge of the most common IN() parameter counts?
|
|
// TODO: ID of 0 should always error?
|
|
func (mus *MemoryUserStore) BulkGetMap(ids []int) (list map[int]*User, err error) {
|
|
var idCount = len(ids)
|
|
list = make(map[int]*User)
|
|
if idCount == 0 {
|
|
return list, nil
|
|
}
|
|
|
|
var stillHere []int
|
|
sliceList := mus.bulkGet(ids)
|
|
for i, sliceItem := range sliceList {
|
|
if sliceItem != nil {
|
|
list[sliceItem.ID] = sliceItem
|
|
} else {
|
|
stillHere = append(stillHere, ids[i])
|
|
}
|
|
}
|
|
ids = stillHere
|
|
|
|
// If every user is in the cache, then return immediately
|
|
if len(ids) == 0 {
|
|
return list, nil
|
|
}
|
|
|
|
// TODO: Add a function for the qlist stuff
|
|
var qlist string
|
|
var uidList []interface{}
|
|
for _, id := range ids {
|
|
uidList = append(uidList, strconv.Itoa(id))
|
|
qlist += "?,"
|
|
}
|
|
qlist = qlist[0 : len(qlist)-1]
|
|
|
|
acc := qgen.Builder.Accumulator()
|
|
rows, err := acc.Select("users").Columns("uid, name, group, is_super_admin, session, email, avatar, message, url_prefix, url_name, level, score, last_ip, temp_group").Where("uid IN(" + qlist + ")").Query(uidList...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for rows.Next() {
|
|
user := &User{Loggedin: true}
|
|
err := rows.Scan(&user.ID, &user.Name, &user.Group, &user.IsSuperAdmin, &user.Session, &user.Email, &user.Avatar, &user.Message, &user.URLPrefix, &user.URLName, &user.Level, &user.Score, &user.LastIP, &user.TempGroup)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
user.Init()
|
|
mus.CacheSet(user)
|
|
list[user.ID] = user
|
|
}
|
|
|
|
// Did we miss any users?
|
|
if idCount > len(list) {
|
|
var sidList string
|
|
for _, id := range ids {
|
|
_, ok := list[id]
|
|
if !ok {
|
|
sidList += strconv.Itoa(id) + ","
|
|
}
|
|
}
|
|
|
|
// We probably don't need this, but it might be useful in case of bugs in BulkCascadeGetMap
|
|
if sidList == "" {
|
|
if Dev.DebugMode {
|
|
log.Print("This data is sampled later in the BulkCascadeGetMap function, so it might miss the cached IDs")
|
|
log.Print("idCount", idCount)
|
|
log.Print("ids", ids)
|
|
log.Print("list", list)
|
|
}
|
|
return list, errors.New("We weren't able to find a user, but we don't know which one")
|
|
}
|
|
sidList = sidList[0 : len(sidList)-1]
|
|
|
|
err = errors.New("Unable to find the users with the following IDs: " + sidList)
|
|
}
|
|
|
|
return list, err
|
|
}
|
|
|
|
func (mus *MemoryUserStore) BypassGet(id int) (*User, error) {
|
|
user := &User{ID: id, Loggedin: true}
|
|
err := mus.get.QueryRow(id).Scan(&user.Name, &user.Group, &user.IsSuperAdmin, &user.Session, &user.Email, &user.Avatar, &user.Message, &user.URLPrefix, &user.URLName, &user.Level, &user.Score, &user.LastIP, &user.TempGroup)
|
|
|
|
user.Init()
|
|
return user, err
|
|
}
|
|
|
|
func (mus *MemoryUserStore) Reload(id int) error {
|
|
user := &User{ID: id, Loggedin: true}
|
|
err := mus.get.QueryRow(id).Scan(&user.Name, &user.Group, &user.IsSuperAdmin, &user.Session, &user.Email, &user.Avatar, &user.Message, &user.URLPrefix, &user.URLName, &user.Level, &user.Score, &user.LastIP, &user.TempGroup)
|
|
if err != nil {
|
|
mus.CacheRemove(id)
|
|
return err
|
|
}
|
|
|
|
user.Init()
|
|
_ = mus.CacheSet(user)
|
|
return nil
|
|
}
|
|
|
|
func (mus *MemoryUserStore) Exists(id int) bool {
|
|
err := mus.exists.QueryRow(id).Scan(&id)
|
|
if err != nil && err != ErrNoRows {
|
|
LogError(err)
|
|
}
|
|
return err != ErrNoRows
|
|
}
|
|
|
|
func (mus *MemoryUserStore) CacheSet(item *User) error {
|
|
mus.Lock()
|
|
user, ok := mus.items[item.ID]
|
|
if ok {
|
|
mus.Unlock()
|
|
*user = *item
|
|
} else if int(mus.length) >= mus.capacity {
|
|
mus.Unlock()
|
|
return ErrStoreCapacityOverflow
|
|
} else {
|
|
mus.items[item.ID] = item
|
|
mus.Unlock()
|
|
atomic.AddInt64(&mus.length, 1)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (mus *MemoryUserStore) CacheAdd(item *User) error {
|
|
if int(mus.length) >= mus.capacity {
|
|
return ErrStoreCapacityOverflow
|
|
}
|
|
mus.Lock()
|
|
mus.items[item.ID] = item
|
|
mus.length = int64(len(mus.items))
|
|
mus.Unlock()
|
|
return nil
|
|
}
|
|
|
|
func (mus *MemoryUserStore) CacheAddUnsafe(item *User) error {
|
|
if int(mus.length) >= mus.capacity {
|
|
return ErrStoreCapacityOverflow
|
|
}
|
|
mus.items[item.ID] = item
|
|
mus.length = int64(len(mus.items))
|
|
return nil
|
|
}
|
|
|
|
func (mus *MemoryUserStore) CacheRemove(id int) error {
|
|
mus.Lock()
|
|
_, ok := mus.items[id]
|
|
if !ok {
|
|
mus.Unlock()
|
|
return ErrNoRows
|
|
}
|
|
delete(mus.items, id)
|
|
mus.Unlock()
|
|
atomic.AddInt64(&mus.length, -1)
|
|
return nil
|
|
}
|
|
|
|
func (mus *MemoryUserStore) CacheRemoveUnsafe(id int) error {
|
|
_, ok := mus.items[id]
|
|
if !ok {
|
|
return ErrNoRows
|
|
}
|
|
delete(mus.items, id)
|
|
atomic.AddInt64(&mus.length, -1)
|
|
return nil
|
|
}
|
|
|
|
// TODO: Change active to a bool?
|
|
func (mus *MemoryUserStore) Create(username string, password string, email string, group int, active bool) (int, error) {
|
|
// Is this username already taken..?
|
|
err := mus.usernameExists.QueryRow(username).Scan(&username)
|
|
if err != ErrNoRows {
|
|
return 0, ErrAccountExists
|
|
}
|
|
|
|
salt, err := GenerateSafeString(SaltLength)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password+salt), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
res, err := mus.register.Exec(username, email, string(hashedPassword), salt, group, active)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
lastID, err := res.LastInsertId()
|
|
return int(lastID), err
|
|
}
|
|
|
|
func (mus *MemoryUserStore) Flush() {
|
|
mus.Lock()
|
|
mus.items = make(map[int]*User)
|
|
mus.length = 0
|
|
mus.Unlock()
|
|
}
|
|
|
|
// ! Is this concurrent?
|
|
// Length returns the number of users in the memory cache
|
|
func (mus *MemoryUserStore) Length() int {
|
|
return int(mus.length)
|
|
}
|
|
|
|
func (mus *MemoryUserStore) SetCapacity(capacity int) {
|
|
mus.capacity = capacity
|
|
}
|
|
|
|
func (mus *MemoryUserStore) GetCapacity() int {
|
|
return mus.capacity
|
|
}
|
|
|
|
// GlobalCount returns the total number of users registered on the forums
|
|
func (mus *MemoryUserStore) GlobalCount() (ucount int) {
|
|
err := mus.userCount.QueryRow().Scan(&ucount)
|
|
if err != nil {
|
|
LogError(err)
|
|
}
|
|
return ucount
|
|
}
|
|
|
|
type SQLUserStore struct {
|
|
get *sql.Stmt
|
|
exists *sql.Stmt
|
|
register *sql.Stmt
|
|
usernameExists *sql.Stmt
|
|
userCount *sql.Stmt
|
|
}
|
|
|
|
func NewSQLUserStore() (*SQLUserStore, error) {
|
|
acc := qgen.Builder.Accumulator()
|
|
// TODO: Add an admin version of registerStmt with more flexibility?
|
|
return &SQLUserStore{
|
|
get: acc.SimpleSelect("users", "name, group, is_super_admin, session, email, avatar, message, url_prefix, url_name, level, score, last_ip, temp_group", "uid = ?", "", ""),
|
|
exists: acc.SimpleSelect("users", "uid", "uid = ?", "", ""),
|
|
register: acc.SimpleInsert("users", "name, email, password, salt, group, is_super_admin, session, active, message, createdAt, lastActiveAt", "?,?,?,?,?,0,'',?,'',UTC_TIMESTAMP(),UTC_TIMESTAMP()"),
|
|
usernameExists: acc.SimpleSelect("users", "name", "name = ?", "", ""),
|
|
userCount: acc.SimpleCount("users", "", ""),
|
|
}, acc.FirstError()
|
|
}
|
|
|
|
func (mus *SQLUserStore) DirtyGet(id int) *User {
|
|
user := &User{ID: id, Loggedin: true}
|
|
err := mus.get.QueryRow(id).Scan(&user.Name, &user.Group, &user.IsSuperAdmin, &user.Session, &user.Email, &user.Avatar, &user.Message, &user.URLPrefix, &user.URLName, &user.Level, &user.Score, &user.LastIP, &user.TempGroup)
|
|
|
|
user.Init()
|
|
if err != nil {
|
|
return BlankUser()
|
|
}
|
|
return user
|
|
}
|
|
|
|
func (mus *SQLUserStore) Get(id int) (*User, error) {
|
|
user := &User{ID: id, Loggedin: true}
|
|
err := mus.get.QueryRow(id).Scan(&user.Name, &user.Group, &user.IsSuperAdmin, &user.Session, &user.Email, &user.Avatar, &user.Message, &user.URLPrefix, &user.URLName, &user.Level, &user.Score, &user.LastIP, &user.TempGroup)
|
|
|
|
user.Init()
|
|
return user, err
|
|
}
|
|
|
|
// TODO: Optimise the query to avoid preparing it on the spot? Maybe, use knowledge of the most common IN() parameter counts?
|
|
func (mus *SQLUserStore) BulkGetMap(ids []int) (list map[int]*User, err error) {
|
|
var qlist string
|
|
var uidList []interface{}
|
|
for _, id := range ids {
|
|
uidList = append(uidList, strconv.Itoa(id))
|
|
qlist += "?,"
|
|
}
|
|
qlist = qlist[0 : len(qlist)-1]
|
|
|
|
acc := qgen.Builder.Accumulator()
|
|
rows, err := acc.Select("users").Columns("uid, name, group, is_super_admin, session, email, avatar, message, url_prefix, url_name, level, score, last_ip, temp_group").Where("uid IN(" + qlist + ")").Query(uidList...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
list = make(map[int]*User)
|
|
for rows.Next() {
|
|
user := &User{Loggedin: true}
|
|
err := rows.Scan(&user.ID, &user.Name, &user.Group, &user.IsSuperAdmin, &user.Session, &user.Email, &user.Avatar, &user.Message, &user.URLPrefix, &user.URLName, &user.Level, &user.Score, &user.LastIP, &user.TempGroup)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
user.Init()
|
|
list[user.ID] = user
|
|
}
|
|
|
|
return list, nil
|
|
}
|
|
|
|
func (mus *SQLUserStore) BypassGet(id int) (*User, error) {
|
|
user := &User{ID: id, Loggedin: true}
|
|
err := mus.get.QueryRow(id).Scan(&user.Name, &user.Group, &user.IsSuperAdmin, &user.Session, &user.Email, &user.Avatar, &user.Message, &user.URLPrefix, &user.URLName, &user.Level, &user.Score, &user.LastIP, &user.TempGroup)
|
|
|
|
user.Init()
|
|
return user, err
|
|
}
|
|
|
|
func (mus *SQLUserStore) Exists(id int) bool {
|
|
err := mus.exists.QueryRow(id).Scan(&id)
|
|
if err != nil && err != ErrNoRows {
|
|
LogError(err)
|
|
}
|
|
return err != ErrNoRows
|
|
}
|
|
|
|
func (mus *SQLUserStore) Create(username string, password string, email string, group int, active bool) (int, error) {
|
|
// Is this username already taken..?
|
|
err := mus.usernameExists.QueryRow(username).Scan(&username)
|
|
if err != ErrNoRows {
|
|
return 0, ErrAccountExists
|
|
}
|
|
|
|
salt, err := GenerateSafeString(SaltLength)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password+salt), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
res, err := mus.register.Exec(username, email, string(hashedPassword), salt, group, active)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
lastID, err := res.LastInsertId()
|
|
return int(lastID), err
|
|
}
|
|
|
|
// GlobalCount returns the total number of users registered on the forums
|
|
func (mus *SQLUserStore) GlobalCount() (ucount int) {
|
|
err := mus.userCount.QueryRow().Scan(&ucount)
|
|
if err != nil {
|
|
LogError(err)
|
|
}
|
|
return ucount
|
|
}
|
|
|
|
// TODO: MockUserStore
|
|
|
|
// NullUserStore is here for tests because Go doesn't have short-circuiting
|
|
type NullUserStore struct {
|
|
}
|
|
|
|
func (nus *NullUserStore) CacheGet(_ int) (*User, error) {
|
|
return nil, ErrNoRows
|
|
}
|
|
|
|
func (nus *NullUserStore) CacheGetUnsafe(_ int) (*User, error) {
|
|
return nil, ErrNoRows
|
|
}
|
|
|
|
func (nus *NullUserStore) CacheSet(_ *User) error {
|
|
return ErrStoreCapacityOverflow
|
|
}
|
|
|
|
func (nus *NullUserStore) CacheAdd(_ *User) error {
|
|
return ErrStoreCapacityOverflow
|
|
}
|
|
|
|
func (nus *NullUserStore) CacheAddUnsafe(_ *User) error {
|
|
return ErrStoreCapacityOverflow
|
|
}
|
|
|
|
func (nus *NullUserStore) CacheRemove(_ int) error {
|
|
return ErrNoRows
|
|
}
|
|
|
|
func (nus *NullUserStore) CacheRemoveUnsafe(_ int) error {
|
|
return ErrNoRows
|
|
}
|
|
|
|
func (nus *NullUserStore) Flush() {
|
|
}
|
|
|
|
func (nus *NullUserStore) Reload(_ int) error {
|
|
return ErrNoRows
|
|
}
|
|
|
|
func (nus *NullUserStore) Length() int {
|
|
return 0
|
|
}
|
|
|
|
func (nus *NullUserStore) SetCapacity(_ int) {
|
|
}
|
|
|
|
func (nus *NullUserStore) GetCapacity() int {
|
|
return 0
|
|
}
|