diff --git a/database.go b/database.go index 27b9ff3b..323b9080 100644 --- a/database.go +++ b/database.go @@ -68,7 +68,10 @@ func initDatabase() (err error) { } log.Print("Loading the forum permissions.") - fpstore = NewForumPermsStore() + fpstore, err = NewMemoryForumPermsStore() + if err != nil { + return err + } err = fpstore.Init() if err != nil { return err @@ -87,5 +90,5 @@ func initDatabase() (err error) { } log.Print("Loading the themes.") - return LoadThemes() + return LoadThemeActiveStatus() } diff --git a/forum_perms_store.go b/forum_perms_store.go index 6098a541..da6db07c 100644 --- a/forum_perms_store.go +++ b/forum_perms_store.go @@ -1,20 +1,45 @@ package main import ( + "database/sql" "encoding/json" "log" + + "./query_gen/lib" ) -var fpstore *ForumPermsStore +var fpstore ForumPermsStore -type ForumPermsStore struct { +type ForumPermsStore interface { + Init() error + Get(fid int, gid int) (fperms ForumPerms, err error) + Reload(id int) error } -func NewForumPermsStore() *ForumPermsStore { - return &ForumPermsStore{} +type ForumPermsCache interface { } -func (fps *ForumPermsStore) Init() error { +type MemoryForumPermsStore struct { + get *sql.Stmt + getByForum *sql.Stmt +} + +func NewMemoryForumPermsStore() (*MemoryForumPermsStore, error) { + getPermsStmt, err := qgen.Builder.SimpleSelect("forums_permissions", "gid, fid, permissions", "", "gid ASC, fid ASC", "") + if err != nil { + return nil, err + } + getPermsByForumStmt, err := qgen.Builder.SimpleSelect("forums_permissions", "gid, permissions", "fid = ?", "gid ASC", "") + if err != nil { + return nil, err + } + return &MemoryForumPermsStore{ + get: getPermsStmt, + getByForum: getPermsByForumStmt, + }, nil +} + +func (fps *MemoryForumPermsStore) Init() error { fids, err := fstore.GetAllIDs() if err != nil { return err @@ -23,7 +48,7 @@ func (fps *ForumPermsStore) Init() error { log.Print("fids: ", fids) } - rows, err := stmts.getForumsPermissions.Query() + rows, err := fps.get.Query() if err != nil { return err } @@ -73,7 +98,7 @@ func (fps *ForumPermsStore) Init() error { } // TODO: Need a more thread-safe way of doing this. Possibly with sync.Map? -func (fps *ForumPermsStore) Reload(fid int) error { +func (fps *MemoryForumPermsStore) Reload(fid int) error { if dev.DebugMode { log.Printf("Reloading the forum permissions for forum #%d", fid) } @@ -82,7 +107,7 @@ func (fps *ForumPermsStore) Reload(fid int) error { return err } - rows, err := db.Query("select gid, permissions from forums_permissions where fid = ? order by gid asc", fid) + rows, err := fps.getByForum.Query(fid) if err != nil { return err } @@ -112,7 +137,7 @@ func (fps *ForumPermsStore) Reload(fid int) error { return fps.cascadePermSetToGroups(forumPerms, fids) } -func (fps *ForumPermsStore) cascadePermSetToGroups(forumPerms map[int]map[int]ForumPerms, fids []int) error { +func (fps *MemoryForumPermsStore) cascadePermSetToGroups(forumPerms map[int]map[int]ForumPerms, fids []int) error { groups, err := gstore.GetAll() if err != nil { return err @@ -134,7 +159,7 @@ func (fps *ForumPermsStore) cascadePermSetToGroups(forumPerms map[int]map[int]Fo return nil } -func (fps *ForumPermsStore) cascadePermSetToGroup(forumPerms map[int]map[int]ForumPerms, group *Group, fids []int) { +func (fps *MemoryForumPermsStore) cascadePermSetToGroup(forumPerms map[int]map[int]ForumPerms, group *Group, fids []int) { for _, fid := range fids { if dev.SuperDebug { log.Printf("Forum #%+v\n", fid) @@ -164,7 +189,7 @@ func (fps *ForumPermsStore) cascadePermSetToGroup(forumPerms map[int]map[int]For } } -func (fps *ForumPermsStore) Get(fid int, gid int) (fperms ForumPerms, err error) { +func (fps *MemoryForumPermsStore) Get(fid int, gid int) (fperms ForumPerms, err error) { // TODO: Add a hook here and have plugin_guilds use it group, err := gstore.Get(gid) if err != nil { diff --git a/forum_store.go b/forum_store.go index 28577ff5..55f779a5 100644 --- a/forum_store.go +++ b/forum_store.go @@ -59,10 +59,14 @@ type MemoryForumStore struct { forumView atomic.Value // []*Forum //fids []int - get *sql.Stmt - getAll *sql.Stmt - delete *sql.Stmt - getForumCount *sql.Stmt + get *sql.Stmt + getAll *sql.Stmt + delete *sql.Stmt + create *sql.Stmt + count *sql.Stmt + updateCache *sql.Stmt + addTopicsToForum *sql.Stmt + removeTopicsFromForum *sql.Stmt } // NewMemoryForumStore gives you a new instance of MemoryForumStore @@ -80,15 +84,35 @@ func NewMemoryForumStore() (*MemoryForumStore, error) { if err != nil { return nil, err } + createStmt, err := qgen.Builder.SimpleInsert("forums", "name, desc, active, preset", "?,?,?,?") + if err != nil { + return nil, err + } forumCountStmt, err := qgen.Builder.SimpleCount("forums", "name != ''", "") if err != nil { return nil, err } + updateCacheStmt, err := qgen.Builder.SimpleUpdate("forums", "lastTopicID = ?, lastReplyerID = ?", "fid = ?") + if err != nil { + return nil, err + } + addTopicsToForumStmt, err := qgen.Builder.SimpleUpdate("forums", "topicCount = topicCount + ?", "fid = ?") + if err != nil { + return nil, err + } + removeTopicsFromForumStmt, err := qgen.Builder.SimpleUpdate("forums", "topicCount = topicCount - ?", "fid = ?") + if err != nil { + return nil, err + } return &MemoryForumStore{ - get: getStmt, - getAll: getAllStmt, - delete: deleteStmt, - getForumCount: forumCountStmt, + get: getStmt, + getAll: getAllStmt, + delete: deleteStmt, + create: createStmt, + count: forumCountStmt, + updateCache: updateCacheStmt, + addTopicsToForum: addTopicsToForumStmt, + removeTopicsFromForum: removeTopicsFromForumStmt, }, nil } @@ -102,8 +126,7 @@ func (mfs *MemoryForumStore) LoadForums() error { } } - // TODO: Move this statement into the store - rows, err := stmts.getForums.Query() + rows, err := mfs.getAll.Query() if err != nil { return err } @@ -122,7 +145,7 @@ func (mfs *MemoryForumStore) LoadForums() error { log.Print("Adding a placeholder forum") } } else { - log.Print("Adding the " + forum.Name + " forum") + log.Printf("Adding the %s forum", forum.Name) } forum.Link = buildForumURL(nameToSlug(forum.Name), forum.ID) @@ -328,11 +351,11 @@ func (mfs *MemoryForumStore) Delete(id int) error { } func (mfs *MemoryForumStore) AddTopic(tid int, uid int, fid int) error { - _, err := stmts.updateForumCache.Exec(tid, uid, fid) + _, err := mfs.updateCache.Exec(tid, uid, fid) if err != nil { return err } - _, err = stmts.addTopicsToForum.Exec(1, fid) + _, err = mfs.addTopicsToForum.Exec(1, fid) if err != nil { return err } @@ -342,7 +365,7 @@ func (mfs *MemoryForumStore) AddTopic(tid int, uid int, fid int) error { // TODO: Update the forum cache with the latest topic func (mfs *MemoryForumStore) RemoveTopic(fid int) error { - _, err := stmts.removeTopicsFromForum.Exec(1, fid) + _, err := mfs.removeTopicsFromForum.Exec(1, fid) if err != nil { return err } @@ -354,7 +377,7 @@ func (mfs *MemoryForumStore) RemoveTopic(fid int) error { // DEPRECATED. forum.Update() will be the way to do this in the future, once it's completed // TODO: Have a pointer to the last topic rather than storing it on the forum itself func (mfs *MemoryForumStore) UpdateLastTopic(tid int, uid int, fid int) error { - _, err := stmts.updateForumCache.Exec(tid, uid, fid) + _, err := mfs.updateCache.Exec(tid, uid, fid) if err != nil { return err } @@ -364,8 +387,7 @@ func (mfs *MemoryForumStore) UpdateLastTopic(tid int, uid int, fid int) error { func (mfs *MemoryForumStore) Create(forumName string, forumDesc string, active bool, preset string) (int, error) { forumCreateMutex.Lock() - // TODO: Move this query into the store - res, err := stmts.createForum.Exec(forumName, forumDesc, active, preset) + res, err := mfs.create.Exec(forumName, forumDesc, active, preset) if err != nil { return 0, err } @@ -399,7 +421,7 @@ func (mfs *MemoryForumStore) Length() (length int) { // TODO: Get the total count of forums in the forum store minus the blanked forums rather than doing a heavy query for this? // GlobalCount returns the total number of forums func (mfs *MemoryForumStore) GlobalCount() (fcount int) { - err := mfs.getForumCount.QueryRow().Scan(&fcount) + err := mfs.count.QueryRow().Scan(&fcount) if err != nil { LogError(err) } diff --git a/gen_mssql.go b/gen_mssql.go index c3b15682..930e1e5d 100644 --- a/gen_mssql.go +++ b/gen_mssql.go @@ -8,16 +8,11 @@ import "database/sql" // nolint type Stmts struct { - getReply *sql.Stmt - getUserReply *sql.Stmt getPassword *sql.Stmt getSettings *sql.Stmt getSetting *sql.Stmt getFullSetting *sql.Stmt getFullSettings *sql.Stmt - getGroups *sql.Stmt - getForums *sql.Stmt - getForumsPermissions *sql.Stmt getPlugins *sql.Stmt getThemes *sql.Stmt getWidgets *sql.Stmt @@ -52,7 +47,6 @@ type Stmts struct { getWatchers *sql.Stmt createTopic *sql.Stmt createReport *sql.Stmt - createReply *sql.Stmt createActionReply *sql.Stmt createLike *sql.Stmt addActivity *sql.Stmt @@ -60,7 +54,6 @@ type Stmts struct { addEmail *sql.Stmt createProfileReply *sql.Stmt addSubscription *sql.Stmt - createForum *sql.Stmt addForumPermsToForum *sql.Stmt addPlugin *sql.Stmt addTheme *sql.Stmt @@ -70,9 +63,6 @@ type Stmts struct { createWordFilter *sql.Stmt addRepliesToTopic *sql.Stmt removeRepliesFromTopic *sql.Stmt - addTopicsToForum *sql.Stmt - removeTopicsFromForum *sql.Stmt - updateForumCache *sql.Stmt addLikesToTopic *sql.Stmt addLikesToReply *sql.Stmt editTopic *sql.Stmt @@ -140,20 +130,6 @@ func _gen_mssql() (err error) { log.Print("Building the generated statements") } - log.Print("Preparing getReply statement.") - stmts.getReply, err = db.Prepare("SELECT [tid],[content],[createdBy],[createdAt],[lastEdit],[lastEditBy],[ipaddress],[likeCount] FROM [replies] WHERE [rid] = ?1") - if err != nil { - log.Print("Bad Query: ","SELECT [tid],[content],[createdBy],[createdAt],[lastEdit],[lastEditBy],[ipaddress],[likeCount] FROM [replies] WHERE [rid] = ?1") - return err - } - - log.Print("Preparing getUserReply statement.") - stmts.getUserReply, err = db.Prepare("SELECT [uid],[content],[createdBy],[createdAt],[lastEdit],[lastEditBy],[ipaddress] FROM [users_replies] WHERE [rid] = ?1") - if err != nil { - log.Print("Bad Query: ","SELECT [uid],[content],[createdBy],[createdAt],[lastEdit],[lastEditBy],[ipaddress] FROM [users_replies] WHERE [rid] = ?1") - return err - } - log.Print("Preparing getPassword statement.") stmts.getPassword, err = db.Prepare("SELECT [password],[salt] FROM [users] WHERE [uid] = ?1") if err != nil { @@ -189,27 +165,6 @@ func _gen_mssql() (err error) { return err } - log.Print("Preparing getGroups statement.") - stmts.getGroups, err = db.Prepare("SELECT [gid],[name],[permissions],[plugin_perms],[is_mod],[is_admin],[is_banned],[tag] FROM [users_groups]") - if err != nil { - log.Print("Bad Query: ","SELECT [gid],[name],[permissions],[plugin_perms],[is_mod],[is_admin],[is_banned],[tag] FROM [users_groups]") - return err - } - - log.Print("Preparing getForums statement.") - stmts.getForums, err = db.Prepare("SELECT [fid],[name],[desc],[active],[preset],[parentID],[parentType],[topicCount],[lastTopicID],[lastReplyerID] FROM [forums] ORDER BY fid ASC") - if err != nil { - log.Print("Bad Query: ","SELECT [fid],[name],[desc],[active],[preset],[parentID],[parentType],[topicCount],[lastTopicID],[lastReplyerID] FROM [forums] ORDER BY fid ASC") - return err - } - - log.Print("Preparing getForumsPermissions statement.") - stmts.getForumsPermissions, err = db.Prepare("SELECT [gid],[fid],[permissions] FROM [forums_permissions] ORDER BY gid ASC,fid ASC") - if err != nil { - log.Print("Bad Query: ","SELECT [gid],[fid],[permissions] FROM [forums_permissions] ORDER BY gid ASC,fid ASC") - return err - } - log.Print("Preparing getPlugins statement.") stmts.getPlugins, err = db.Prepare("SELECT [uname],[active],[installed] FROM [plugins]") if err != nil { @@ -448,13 +403,6 @@ func _gen_mssql() (err error) { return err } - log.Print("Preparing createReply statement.") - stmts.createReply, err = db.Prepare("INSERT INTO [replies] ([tid],[content],[parsed_content],[createdAt],[lastUpdated],[ipaddress],[words],[createdBy]) VALUES (?,?,?,GETUTCDATE(),GETUTCDATE(),?,?,?)") - if err != nil { - log.Print("Bad Query: ","INSERT INTO [replies] ([tid],[content],[parsed_content],[createdAt],[lastUpdated],[ipaddress],[words],[createdBy]) VALUES (?,?,?,GETUTCDATE(),GETUTCDATE(),?,?,?)") - return err - } - log.Print("Preparing createActionReply statement.") stmts.createActionReply, err = db.Prepare("INSERT INTO [replies] ([tid],[actionType],[ipaddress],[createdBy],[createdAt],[lastUpdated],[content],[parsed_content]) VALUES (?,?,?,?,GETUTCDATE(),GETUTCDATE(),'','')") if err != nil { @@ -504,13 +452,6 @@ func _gen_mssql() (err error) { return err } - log.Print("Preparing createForum statement.") - stmts.createForum, err = db.Prepare("INSERT INTO [forums] ([name],[desc],[active],[preset]) VALUES (?,?,?,?)") - if err != nil { - log.Print("Bad Query: ","INSERT INTO [forums] ([name],[desc],[active],[preset]) VALUES (?,?,?,?)") - return err - } - log.Print("Preparing addForumPermsToForum statement.") stmts.addForumPermsToForum, err = db.Prepare("INSERT INTO [forums_permissions] ([gid],[fid],[preset],[permissions]) VALUES (?,?,?,?)") if err != nil { @@ -574,27 +515,6 @@ func _gen_mssql() (err error) { return err } - log.Print("Preparing addTopicsToForum statement.") - stmts.addTopicsToForum, err = db.Prepare("UPDATE [forums] SET [topicCount] = [topicCount] + ? WHERE [fid] = ?") - if err != nil { - log.Print("Bad Query: ","UPDATE [forums] SET [topicCount] = [topicCount] + ? WHERE [fid] = ?") - return err - } - - log.Print("Preparing removeTopicsFromForum statement.") - stmts.removeTopicsFromForum, err = db.Prepare("UPDATE [forums] SET [topicCount] = [topicCount] - ? WHERE [fid] = ?") - if err != nil { - log.Print("Bad Query: ","UPDATE [forums] SET [topicCount] = [topicCount] - ? WHERE [fid] = ?") - return err - } - - log.Print("Preparing updateForumCache statement.") - stmts.updateForumCache, err = db.Prepare("UPDATE [forums] SET [lastTopicID] = ?,[lastReplyerID] = ? WHERE [fid] = ?") - if err != nil { - log.Print("Bad Query: ","UPDATE [forums] SET [lastTopicID] = ?,[lastReplyerID] = ? WHERE [fid] = ?") - return err - } - log.Print("Preparing addLikesToTopic statement.") stmts.addLikesToTopic, err = db.Prepare("UPDATE [topics] SET [likeCount] = [likeCount] + ? WHERE [tid] = ?") if err != nil { diff --git a/gen_mysql.go b/gen_mysql.go index 51e304d6..025cb718 100644 --- a/gen_mysql.go +++ b/gen_mysql.go @@ -10,16 +10,11 @@ import "database/sql" // nolint type Stmts struct { - getReply *sql.Stmt - getUserReply *sql.Stmt getPassword *sql.Stmt getSettings *sql.Stmt getSetting *sql.Stmt getFullSetting *sql.Stmt getFullSettings *sql.Stmt - getGroups *sql.Stmt - getForums *sql.Stmt - getForumsPermissions *sql.Stmt getPlugins *sql.Stmt getThemes *sql.Stmt getWidgets *sql.Stmt @@ -54,7 +49,6 @@ type Stmts struct { getWatchers *sql.Stmt createTopic *sql.Stmt createReport *sql.Stmt - createReply *sql.Stmt createActionReply *sql.Stmt createLike *sql.Stmt addActivity *sql.Stmt @@ -62,7 +56,6 @@ type Stmts struct { addEmail *sql.Stmt createProfileReply *sql.Stmt addSubscription *sql.Stmt - createForum *sql.Stmt addForumPermsToForum *sql.Stmt addPlugin *sql.Stmt addTheme *sql.Stmt @@ -72,9 +65,6 @@ type Stmts struct { createWordFilter *sql.Stmt addRepliesToTopic *sql.Stmt removeRepliesFromTopic *sql.Stmt - addTopicsToForum *sql.Stmt - removeTopicsFromForum *sql.Stmt - updateForumCache *sql.Stmt addLikesToTopic *sql.Stmt addLikesToReply *sql.Stmt editTopic *sql.Stmt @@ -142,18 +132,6 @@ func _gen_mysql() (err error) { log.Print("Building the generated statements") } - log.Print("Preparing getReply statement.") - stmts.getReply, err = db.Prepare("SELECT `tid`,`content`,`createdBy`,`createdAt`,`lastEdit`,`lastEditBy`,`ipaddress`,`likeCount` FROM `replies` WHERE `rid` = ?") - if err != nil { - return err - } - - log.Print("Preparing getUserReply statement.") - stmts.getUserReply, err = db.Prepare("SELECT `uid`,`content`,`createdBy`,`createdAt`,`lastEdit`,`lastEditBy`,`ipaddress` FROM `users_replies` WHERE `rid` = ?") - if err != nil { - return err - } - log.Print("Preparing getPassword statement.") stmts.getPassword, err = db.Prepare("SELECT `password`,`salt` FROM `users` WHERE `uid` = ?") if err != nil { @@ -184,24 +162,6 @@ func _gen_mysql() (err error) { return err } - log.Print("Preparing getGroups statement.") - stmts.getGroups, err = db.Prepare("SELECT `gid`,`name`,`permissions`,`plugin_perms`,`is_mod`,`is_admin`,`is_banned`,`tag` FROM `users_groups`") - if err != nil { - return err - } - - log.Print("Preparing getForums statement.") - stmts.getForums, err = db.Prepare("SELECT `fid`,`name`,`desc`,`active`,`preset`,`parentID`,`parentType`,`topicCount`,`lastTopicID`,`lastReplyerID` FROM `forums` ORDER BY fid ASC") - if err != nil { - return err - } - - log.Print("Preparing getForumsPermissions statement.") - stmts.getForumsPermissions, err = db.Prepare("SELECT `gid`,`fid`,`permissions` FROM `forums_permissions` ORDER BY gid ASC,fid ASC") - if err != nil { - return err - } - log.Print("Preparing getPlugins statement.") stmts.getPlugins, err = db.Prepare("SELECT `uname`,`active`,`installed` FROM `plugins`") if err != nil { @@ -406,12 +366,6 @@ func _gen_mysql() (err error) { return err } - log.Print("Preparing createReply statement.") - stmts.createReply, err = db.Prepare("INSERT INTO `replies`(`tid`,`content`,`parsed_content`,`createdAt`,`lastUpdated`,`ipaddress`,`words`,`createdBy`) VALUES (?,?,?,UTC_TIMESTAMP(),UTC_TIMESTAMP(),?,?,?)") - if err != nil { - return err - } - log.Print("Preparing createActionReply statement.") stmts.createActionReply, err = db.Prepare("INSERT INTO `replies`(`tid`,`actionType`,`ipaddress`,`createdBy`,`createdAt`,`lastUpdated`,`content`,`parsed_content`) VALUES (?,?,?,?,UTC_TIMESTAMP(),UTC_TIMESTAMP(),'','')") if err != nil { @@ -454,12 +408,6 @@ func _gen_mysql() (err error) { return err } - log.Print("Preparing createForum statement.") - stmts.createForum, err = db.Prepare("INSERT INTO `forums`(`name`,`desc`,`active`,`preset`) VALUES (?,?,?,?)") - if err != nil { - return err - } - log.Print("Preparing addForumPermsToForum statement.") stmts.addForumPermsToForum, err = db.Prepare("INSERT INTO `forums_permissions`(`gid`,`fid`,`preset`,`permissions`) VALUES (?,?,?,?)") if err != nil { @@ -514,24 +462,6 @@ func _gen_mysql() (err error) { return err } - log.Print("Preparing addTopicsToForum statement.") - stmts.addTopicsToForum, err = db.Prepare("UPDATE `forums` SET `topicCount` = `topicCount` + ? WHERE `fid` = ?") - if err != nil { - return err - } - - log.Print("Preparing removeTopicsFromForum statement.") - stmts.removeTopicsFromForum, err = db.Prepare("UPDATE `forums` SET `topicCount` = `topicCount` - ? WHERE `fid` = ?") - if err != nil { - return err - } - - log.Print("Preparing updateForumCache statement.") - stmts.updateForumCache, err = db.Prepare("UPDATE `forums` SET `lastTopicID` = ?,`lastReplyerID` = ? WHERE `fid` = ?") - if err != nil { - return err - } - log.Print("Preparing addLikesToTopic statement.") stmts.addLikesToTopic, err = db.Prepare("UPDATE `topics` SET `likeCount` = `likeCount` + ? WHERE `tid` = ?") if err != nil { diff --git a/gen_pgsql.go b/gen_pgsql.go index 74913f9c..6aafb582 100644 --- a/gen_pgsql.go +++ b/gen_pgsql.go @@ -10,9 +10,6 @@ import "database/sql" type Stmts struct { addRepliesToTopic *sql.Stmt removeRepliesFromTopic *sql.Stmt - addTopicsToForum *sql.Stmt - removeTopicsFromForum *sql.Stmt - updateForumCache *sql.Stmt addLikesToTopic *sql.Stmt addLikesToReply *sql.Stmt editTopic *sql.Stmt @@ -82,24 +79,6 @@ func _gen_pgsql() (err error) { return err } - log.Print("Preparing addTopicsToForum statement.") - stmts.addTopicsToForum, err = db.Prepare("UPDATE `forums` SET `topicCount` = `topicCount` + ? WHERE `fid` = ?") - if err != nil { - return err - } - - log.Print("Preparing removeTopicsFromForum statement.") - stmts.removeTopicsFromForum, err = db.Prepare("UPDATE `forums` SET `topicCount` = `topicCount` - ? WHERE `fid` = ?") - if err != nil { - return err - } - - log.Print("Preparing updateForumCache statement.") - stmts.updateForumCache, err = db.Prepare("UPDATE `forums` SET `lastTopicID` = ?,`lastReplyerID` = ? WHERE `fid` = ?") - if err != nil { - return err - } - log.Print("Preparing addLikesToTopic statement.") stmts.addLikesToTopic, err = db.Prepare("UPDATE `topics` SET `likeCount` = `likeCount` + ? WHERE `tid` = ?") if err != nil { diff --git a/general_test.go b/general_test.go index bdc96284..8fb4aef6 100644 --- a/general_test.go +++ b/general_test.go @@ -79,8 +79,14 @@ func gloinit() (err error) { return err } - rstore = NewSQLReplyStore() - prstore = NewSQLProfileReplyStore() + rstore, err = NewSQLReplyStore() + if err != nil { + return err + } + prstore, err = NewSQLProfileReplyStore() + if err != nil { + return err + } dbProd = db //db_test, err = sql.Open("testdb","") @@ -538,7 +544,7 @@ func BenchmarkQueryPreparedTopicParallel(b *testing.B) { b.RunParallel(func(pb *testing.PB) { var tu TopicUser for pb.Next() { - err := getTopicUserStmt.QueryRow(1).Scan(&tu.Title, &tu.Content, &tu.CreatedBy, &tu.CreatedAt, &tu.IsClosed, &tu.Sticky, &tu.ParentID, &tu.IPAddress, &tu.PostCount, &tu.LikeCount, &tu.CreatedByName, &tu.Avatar, &tu.Group, &tu.URLPrefix, &tu.URLName, &tu.Level) + err := stmts.getTopicUser.QueryRow(1).Scan(&tu.Title, &tu.Content, &tu.CreatedBy, &tu.CreatedAt, &tu.IsClosed, &tu.Sticky, &tu.ParentID, &tu.IPAddress, &tu.PostCount, &tu.LikeCount, &tu.CreatedByName, &tu.Avatar, &tu.Group, &tu.URLPrefix, &tu.URLName, &tu.Level) if err == ErrNoRows { b.Fatal("No rows found!") return diff --git a/group_store.go b/group_store.go index 0efa3344..0f8d5dcf 100644 --- a/group_store.go +++ b/group_store.go @@ -36,12 +36,17 @@ type GroupCache interface { type MemoryGroupStore struct { groups map[int]*Group // TODO: Use a sync.Map instead of a map? groupCount int + getAll *sql.Stmt get *sql.Stmt sync.RWMutex } func NewMemoryGroupStore() (*MemoryGroupStore, error) { + getAllStmt, err := qgen.Builder.SimpleSelect("users_groups", "gid, name, permissions, plugin_perms, is_mod, is_admin, is_banned, tag", "", "", "") + if err != nil { + return nil, err + } getGroupStmt, err := qgen.Builder.SimpleSelect("users_groups", "name, permissions, plugin_perms, is_mod, is_admin, is_banned, tag", "gid = ?", "", "") if err != nil { return nil, err @@ -50,6 +55,7 @@ func NewMemoryGroupStore() (*MemoryGroupStore, error) { return &MemoryGroupStore{ groups: make(map[int]*Group), groupCount: 0, + getAll: getAllStmt, get: getGroupStmt, }, nil } @@ -60,7 +66,7 @@ func (mgs *MemoryGroupStore) LoadGroups() error { defer mgs.Unlock() mgs.groups[0] = &Group{ID: 0, Name: "Unknown"} - rows, err := stmts.getGroups.Query() + rows, err := mgs.getAll.Query() if err != nil { return err } diff --git a/main.go b/main.go index 5175ab34..f0b835b8 100644 --- a/main.go +++ b/main.go @@ -82,6 +82,7 @@ type Globs struct { stmts *Stmts } +// TODO: Split this function up func main() { // TODO: Recover from panics /*defer func() { @@ -130,8 +131,14 @@ func main() { log.Fatal(err) } - rstore = NewSQLReplyStore() - prstore = NewSQLProfileReplyStore() + rstore, err = NewSQLReplyStore() + if err != nil { + log.Fatal(err) + } + prstore, err = NewSQLProfileReplyStore() + if err != nil { + log.Fatal(err) + } initTemplates() diff --git a/member_routes.go b/member_routes.go index b3278858..9503c4f1 100644 --- a/member_routes.go +++ b/member_routes.go @@ -680,11 +680,7 @@ func routeReportSubmit(w http.ResponseWriter, r *http.Request, user User, sitemI return InternalError(err, w, r) } - _, err = stmts.addTopicsToForum.Exec(1, fid) - if err != nil { - return InternalError(err, w, r) - } - err = fstore.UpdateLastTopic(int(lastID), user.ID, fid) + err = fstore.AddTopic(int(lastID), user.ID, fid) if err != nil && err != ErrNoRows { return InternalError(err, w, r) } diff --git a/plugin_markdown.go b/plugin_markdown.go index ec632b12..90227281 100644 --- a/plugin_markdown.go +++ b/plugin_markdown.go @@ -45,7 +45,10 @@ func deactivateMarkdown() { // An adapter for the parser, so that the parser can call itself recursively. // This is less for the simple Markdown elements like bold and italics and more for the really complicated ones I plan on adding at some point. func markdownParse(msg string) string { - msg = strings.TrimSuffix(_markdownParse(msg+" ", 0), " ") + msg = _markdownParse(msg+" ", 0) + if msg[len(msg)-1] == ' ' { + msg = msg[:len(msg)-1] + } log.Print("final msg: ", msg) return msg } @@ -292,9 +295,13 @@ func _markdownParse(msg string, n int) string { if len(outbytes) == 0 { return msg + //return msg[:len(msg)-1] } else if lastElement < (len(msg) - 1) { - return string(outbytes) + msg[lastElement:] + msg = string(outbytes) + msg[lastElement:] + return msg + //return msg[:len(msg)-1] } + //return string(outbytes[:len(outbytes)-1]) return string(outbytes) } diff --git a/plugin_test.go b/plugin_test.go index bb2093bd..63179491 100644 --- a/plugin_test.go +++ b/plugin_test.go @@ -23,6 +23,12 @@ func TestBBCodeRender(t *testing.T) { var res string var msgList []MEPair + msgList = addMEPair(msgList, "", "") + msgList = addMEPair(msgList, " ", " ") + msgList = addMEPair(msgList, " ", " ") + msgList = addMEPair(msgList, " ", " ") + msgList = addMEPair(msgList, "[b]", "") + msgList = addMEPair(msgList, "[b][/b]", "") msgList = addMEPair(msgList, "hi", "hi") msgList = addMEPair(msgList, "😀", "😀") msgList = addMEPair(msgList, "[b]😀[/b]", "😀") @@ -57,7 +63,7 @@ func TestBBCodeRender(t *testing.T) { if res != item.Expects { t.Error("Testing string '" + item.Msg + "'") t.Error("Bad output:", "'"+res+"'") - t.Error("Expected:", item.Expects) + t.Error("Expected:", "'"+item.Expects+"'") } } @@ -70,7 +76,7 @@ func TestBBCodeRender(t *testing.T) { res = bbcodeFullParse(msg) if res != expects { t.Error("Bad output:", "'"+res+"'") - t.Error("Expected:", expects) + t.Error("Expected:", "'"+expects+"'") } msg = "[rand]-1[/rand]" @@ -79,7 +85,7 @@ func TestBBCodeRender(t *testing.T) { res = bbcodeFullParse(msg) if res != expects { t.Error("Bad output:", "'"+res+"'") - t.Error("Expected:", expects) + t.Error("Expected:", "'"+expects+"'") } msg = "[rand]-01[/rand]" @@ -88,7 +94,7 @@ func TestBBCodeRender(t *testing.T) { res = bbcodeFullParse(msg) if res != expects { t.Error("Bad output:", "'"+res+"'") - t.Error("Expected:", expects) + t.Error("Expected:", "'"+expects+"'") } msg = "[rand]NaN[/rand]" @@ -97,7 +103,7 @@ func TestBBCodeRender(t *testing.T) { res = bbcodeFullParse(msg) if res != expects { t.Error("Bad output:", "'"+res+"'") - t.Error("Expected:", expects) + t.Error("Expected:", "'"+expects+"'") } msg = "[rand]Inf[/rand]" @@ -106,7 +112,7 @@ func TestBBCodeRender(t *testing.T) { res = bbcodeFullParse(msg) if res != expects { t.Error("Bad output:", "'"+res+"'") - t.Error("Expected:", expects) + t.Error("Expected:", "'"+expects+"'") } msg = "[rand]+[/rand]" @@ -115,7 +121,7 @@ func TestBBCodeRender(t *testing.T) { res = bbcodeFullParse(msg) if res != expects { t.Error("Bad output:", "'"+res+"'") - t.Error("Expected:", expects) + t.Error("Expected:", "'"+expects+"'") } msg = "[rand]1+1[/rand]" @@ -124,7 +130,7 @@ func TestBBCodeRender(t *testing.T) { res = bbcodeFullParse(msg) if res != expects { t.Error("Bad output:", "'"+res+"'") - t.Error("Expected:", expects) + t.Error("Expected:", "'"+expects+"'") } var conv int @@ -202,6 +208,15 @@ func TestMarkdownRender(t *testing.T) { var res string var msgList []MEPair + msgList = addMEPair(msgList, "", "") + msgList = addMEPair(msgList, " ", " ") + msgList = addMEPair(msgList, " ", " ") + msgList = addMEPair(msgList, " ", " ") + msgList = addMEPair(msgList, "\t", "\t") + msgList = addMEPair(msgList, "\n", "\n") + msgList = addMEPair(msgList, "*", "*") + msgList = addMEPair(msgList, "**", "") + msgList = addMEPair(msgList, "h", "h") msgList = addMEPair(msgList, "hi", "hi") msgList = addMEPair(msgList, "**h**", "h") msgList = addMEPair(msgList, "**hi**", "hi") @@ -260,17 +275,17 @@ func TestMarkdownRender(t *testing.T) { t.Error("Testing string '" + item.Msg + "'") t.Error("Bad output:", "'"+res+"'") //t.Error("Ouput in bytes:", []byte(res)) - t.Error("Expected:", item.Expects) + t.Error("Expected:", "'"+item.Expects+"'") } } - for _, item := range msgList { + /*for _, item := range msgList { res = markdownParse("\n" + item.Msg) if res != "\n"+item.Expects { t.Error("Testing string '\n" + item.Msg + "'") t.Error("Bad output:", "'"+res+"'") //t.Error("Ouput in bytes:", []byte(res)) - t.Error("Expected:", "\n"+item.Expects) + t.Error("Expected:", "'\n"+item.Expects+"'") } } @@ -280,9 +295,9 @@ func TestMarkdownRender(t *testing.T) { t.Error("Testing string '\t" + item.Msg + "'") t.Error("Bad output:", "'"+res+"'") //t.Error("Ouput in bytes:", []byte(res)) - t.Error("Expected:", "\t"+item.Expects) + t.Error("Expected:", "'\t"+item.Expects+"'") } - } + }*/ for _, item := range msgList { res = markdownParse("d" + item.Msg) @@ -290,7 +305,7 @@ func TestMarkdownRender(t *testing.T) { t.Error("Testing string 'd" + item.Msg + "'") t.Error("Bad output:", "'"+res+"'") //t.Error("Ouput in bytes:", []byte(res)) - t.Error("Expected:", "d"+item.Expects) + t.Error("Expected:", "'d"+item.Expects+"'") } } diff --git a/query_gen/lib/accumulator.go b/query_gen/lib/accumulator.go new file mode 100644 index 00000000..d1df3508 --- /dev/null +++ b/query_gen/lib/accumulator.go @@ -0,0 +1,248 @@ +/* WIP: A version of the builder which accumulates errors, we'll see if we can't unify the implementations at some point */ +package qgen + +import "database/sql" + +type accBuilder struct { + conn *sql.DB + adapter DB_Adapter + firstErr error +} + +func (build *accBuilder) SetConn(conn *sql.DB) { + build.conn = conn +} + +func (build *accBuilder) SetAdapter(name string) error { + adap, err := GetAdapter(name) + if err != nil { + return err + } + build.adapter = adap + return nil +} + +func (build *accBuilder) GetAdapter() DB_Adapter { + return build.adapter +} + +func (build *accBuilder) FirstError() error { + return build.firstErr +} + +func (build *accBuilder) recordError(err error) { + if err == nil { + return + } + if build.firstErr != nil { + build.firstErr = err + } +} + +func (build *accBuilder) prepare(res string, err error) *sql.Stmt { + if err != nil { + build.recordError(err) + return nil + } + stmt, err := build.conn.Prepare(res) + build.recordError(err) + return stmt +} + +func (build *accBuilder) Tx(handler func(*TransactionBuilder) error) { + tx, err := build.conn.Begin() + if err != nil { + build.recordError(err) + return + } + err = handler(&TransactionBuilder{tx, build.adapter, nil}) + if err != nil { + tx.Rollback() + build.recordError(err) + return + } + build.recordError(tx.Commit()) +} + +func (build *accBuilder) SimpleSelect(table string, columns string, where string, orderby string, limit string) *sql.Stmt { + return build.prepare(build.adapter.SimpleSelect("_builder", table, columns, where, orderby, limit)) +} + +func (build *accBuilder) SimpleCount(table string, where string, limit string) *sql.Stmt { + return build.prepare(build.adapter.SimpleCount("_builder", table, where, limit)) +} + +func (build *accBuilder) SimpleLeftJoin(table1 string, table2 string, columns string, joiners string, where string, orderby string, limit string) *sql.Stmt { + return build.prepare(build.adapter.SimpleLeftJoin("_builder", table1, table2, columns, joiners, where, orderby, limit)) +} + +func (build *accBuilder) SimpleInnerJoin(table1 string, table2 string, columns string, joiners string, where string, orderby string, limit string) *sql.Stmt { + return build.prepare(build.adapter.SimpleInnerJoin("_builder", table1, table2, columns, joiners, where, orderby, limit)) +} + +func (build *accBuilder) CreateTable(table string, charset string, collation string, columns []DB_Table_Column, keys []DB_Table_Key) *sql.Stmt { + return build.prepare(build.adapter.CreateTable("_builder", table, charset, collation, columns, keys)) +} + +func (build *accBuilder) SimpleInsert(table string, columns string, fields string) *sql.Stmt { + return build.prepare(build.adapter.SimpleInsert("_builder", table, columns, fields)) +} + +func (build *accBuilder) SimpleInsertSelect(ins DB_Insert, sel DB_Select) *sql.Stmt { + return build.prepare(build.adapter.SimpleInsertSelect("_builder", ins, sel)) +} + +func (build *accBuilder) SimpleInsertLeftJoin(ins DB_Insert, sel DB_Join) *sql.Stmt { + return build.prepare(build.adapter.SimpleInsertLeftJoin("_builder", ins, sel)) +} + +func (build *accBuilder) SimpleInsertInnerJoin(ins DB_Insert, sel DB_Join) *sql.Stmt { + return build.prepare(build.adapter.SimpleInsertInnerJoin("_builder", ins, sel)) +} + +func (build *accBuilder) SimpleUpdate(table string, set string, where string) *sql.Stmt { + return build.prepare(build.adapter.SimpleUpdate("_builder", table, set, where)) +} + +func (build *accBuilder) SimpleDelete(table string, where string) *sql.Stmt { + return build.prepare(build.adapter.SimpleDelete("_builder", table, where)) +} + +// I don't know why you need this, but here it is x.x +func (build *accBuilder) Purge(table string) *sql.Stmt { + return build.prepare(build.adapter.Purge("_builder", table)) +} + +// These ones support transactions +func (build *accBuilder) SimpleSelectTx(tx *sql.Tx, table string, columns string, where string, orderby string, limit string) (stmt *sql.Stmt) { + res, err := build.adapter.SimpleSelect("_builder", table, columns, where, orderby, limit) + if err != nil { + build.recordError(err) + return nil + } + stmt, err = tx.Prepare(res) + build.recordError(err) + return stmt +} + +func (build *accBuilder) SimpleCountTx(tx *sql.Tx, table string, where string, limit string) (stmt *sql.Stmt) { + res, err := build.adapter.SimpleCount("_builder", table, where, limit) + if err != nil { + build.recordError(err) + return nil + } + stmt, err = tx.Prepare(res) + build.recordError(err) + return stmt +} + +func (build *accBuilder) SimpleLeftJoinTx(tx *sql.Tx, table1 string, table2 string, columns string, joiners string, where string, orderby string, limit string) (stmt *sql.Stmt) { + res, err := build.adapter.SimpleLeftJoin("_builder", table1, table2, columns, joiners, where, orderby, limit) + if err != nil { + build.recordError(err) + return nil + } + stmt, err = tx.Prepare(res) + build.recordError(err) + return stmt +} + +func (build *accBuilder) SimpleInnerJoinTx(tx *sql.Tx, table1 string, table2 string, columns string, joiners string, where string, orderby string, limit string) (stmt *sql.Stmt) { + res, err := build.adapter.SimpleInnerJoin("_builder", table1, table2, columns, joiners, where, orderby, limit) + if err != nil { + build.recordError(err) + return nil + } + stmt, err = tx.Prepare(res) + build.recordError(err) + return stmt +} + +func (build *accBuilder) CreateTableTx(tx *sql.Tx, table string, charset string, collation string, columns []DB_Table_Column, keys []DB_Table_Key) (stmt *sql.Stmt) { + res, err := build.adapter.CreateTable("_builder", table, charset, collation, columns, keys) + if err != nil { + build.recordError(err) + return nil + } + stmt, err = tx.Prepare(res) + build.recordError(err) + return stmt +} + +func (build *accBuilder) SimpleInsertTx(tx *sql.Tx, table string, columns string, fields string) (stmt *sql.Stmt) { + res, err := build.adapter.SimpleInsert("_builder", table, columns, fields) + if err != nil { + build.recordError(err) + return nil + } + stmt, err = tx.Prepare(res) + build.recordError(err) + return stmt +} + +func (build *accBuilder) SimpleInsertSelectTx(tx *sql.Tx, ins DB_Insert, sel DB_Select) (stmt *sql.Stmt) { + res, err := build.adapter.SimpleInsertSelect("_builder", ins, sel) + if err != nil { + build.recordError(err) + return nil + } + stmt, err = tx.Prepare(res) + build.recordError(err) + return stmt +} + +func (build *accBuilder) SimpleInsertLeftJoinTx(tx *sql.Tx, ins DB_Insert, sel DB_Join) (stmt *sql.Stmt) { + res, err := build.adapter.SimpleInsertLeftJoin("_builder", ins, sel) + if err != nil { + build.recordError(err) + return nil + } + stmt, err = tx.Prepare(res) + build.recordError(err) + return stmt +} + +func (build *accBuilder) SimpleInsertInnerJoinTx(tx *sql.Tx, ins DB_Insert, sel DB_Join) (stmt *sql.Stmt) { + res, err := build.adapter.SimpleInsertInnerJoin("_builder", ins, sel) + if err != nil { + build.recordError(err) + return nil + } + stmt, err = tx.Prepare(res) + build.recordError(err) + return stmt +} + +func (build *accBuilder) SimpleUpdateTx(tx *sql.Tx, table string, set string, where string) (stmt *sql.Stmt) { + res, err := build.adapter.SimpleUpdate("_builder", table, set, where) + if err != nil { + build.recordError(err) + return nil + } + stmt, err = tx.Prepare(res) + build.recordError(err) + return stmt +} + +func (build *accBuilder) SimpleDeleteTx(tx *sql.Tx, table string, where string) (stmt *sql.Stmt) { + res, err := build.adapter.SimpleDelete("_builder", table, where) + if err != nil { + build.recordError(err) + return nil + } + stmt, err = tx.Prepare(res) + build.recordError(err) + return stmt +} + +// I don't know why you need this, but here it is x.x +func (build *accBuilder) PurgeTx(tx *sql.Tx, table string) (stmt *sql.Stmt) { + res, err := build.adapter.Purge("_builder", table) + if err != nil { + build.recordError(err) + return nil + } + stmt, err = tx.Prepare(res) + build.recordError(err) + return stmt +} diff --git a/query_gen/lib/builder.go b/query_gen/lib/builder.go index f9e22580..842cc107 100644 --- a/query_gen/lib/builder.go +++ b/query_gen/lib/builder.go @@ -16,6 +16,10 @@ type builder struct { adapter DB_Adapter } +func (build *builder) Accumulator() *accBuilder { + return &accBuilder{build.conn, build.adapter, nil} +} + func (build *builder) SetConn(conn *sql.DB) { build.conn = conn } diff --git a/query_gen/main.go b/query_gen/main.go index 3f3557cd..2a894fd6 100644 --- a/query_gen/main.go +++ b/query_gen/main.go @@ -222,10 +222,6 @@ func seedTables(adapter qgen.DB_Adapter) error { func writeSelects(adapter qgen.DB_Adapter) error { // Looking for getTopic? Your statement is in another castle - adapter.SimpleSelect("getReply", "replies", "tid, content, createdBy, createdAt, lastEdit, lastEditBy, ipaddress, likeCount", "rid = ?", "", "") - - adapter.SimpleSelect("getUserReply", "users_replies", "uid, content, createdBy, createdAt, lastEdit, lastEditBy, ipaddress", "rid = ?", "", "") - adapter.SimpleSelect("getPassword", "users", "password, salt", "uid = ?", "", "") adapter.SimpleSelect("getSettings", "settings", "name, content, type", "", "", "") @@ -236,12 +232,6 @@ func writeSelects(adapter qgen.DB_Adapter) error { adapter.SimpleSelect("getFullSettings", "settings", "name, content, type, constraints", "", "", "") - adapter.SimpleSelect("getGroups", "users_groups", "gid, name, permissions, plugin_perms, is_mod, is_admin, is_banned, tag", "", "", "") - - adapter.SimpleSelect("getForums", "forums", "fid, name, desc, active, preset, parentID, parentType, topicCount, lastTopicID, lastReplyerID", "", "fid ASC", "") - - adapter.SimpleSelect("getForumsPermissions", "forums_permissions", "gid, fid, permissions", "", "gid ASC, fid ASC", "") - adapter.SimpleSelect("getPlugins", "plugins", "uname, active, installed", "", "", "") adapter.SimpleSelect("getThemes", "themes", "uname, default", "", "", "") @@ -328,8 +318,6 @@ func writeInserts(adapter qgen.DB_Adapter) error { adapter.SimpleInsert("createReport", "topics", "title, content, parsed_content, createdAt, lastReplyAt, createdBy, lastReplyBy, data, parentID, css_class", "?,?,?,UTC_TIMESTAMP(),UTC_TIMESTAMP(),?,?,?,1,'report'") - adapter.SimpleInsert("createReply", "replies", "tid, content, parsed_content, createdAt, lastUpdated, ipaddress, words, createdBy", "?,?,?,UTC_TIMESTAMP(),UTC_TIMESTAMP(),?,?,?") - adapter.SimpleInsert("createActionReply", "replies", "tid, actionType, ipaddress, createdBy, createdAt, lastUpdated, content, parsed_content", "?,?,?,?,UTC_TIMESTAMP(),UTC_TIMESTAMP(),'',''") adapter.SimpleInsert("createLike", "likes", "weight, targetItem, targetType, sentBy", "?,?,?,?") @@ -344,8 +332,6 @@ func writeInserts(adapter qgen.DB_Adapter) error { adapter.SimpleInsert("addSubscription", "activity_subscriptions", "user, targetID, targetType, level", "?,?,?,2") - adapter.SimpleInsert("createForum", "forums", "name, desc, active, preset", "?,?,?,?") - adapter.SimpleInsert("addForumPermsToForum", "forums_permissions", "gid,fid,preset,permissions", "?,?,?,?") adapter.SimpleInsert("addPlugin", "plugins", "uname, active, installed", "?,?,?") @@ -387,12 +373,6 @@ func writeUpdates(adapter qgen.DB_Adapter) error { adapter.SimpleUpdate("removeRepliesFromTopic", "topics", "postCount = postCount - ?", "tid = ?") - adapter.SimpleUpdate("addTopicsToForum", "forums", "topicCount = topicCount + ?", "fid = ?") - - adapter.SimpleUpdate("removeTopicsFromForum", "forums", "topicCount = topicCount - ?", "fid = ?") - - adapter.SimpleUpdate("updateForumCache", "forums", "lastTopicID = ?, lastReplyerID = ?", "fid = ?") - adapter.SimpleUpdate("addLikesToTopic", "topics", "likeCount = likeCount + ?", "tid = ?") adapter.SimpleUpdate("addLikesToReply", "replies", "likeCount = likeCount + ?", "rid = ?") diff --git a/reply.go b/reply.go index aad06070..105b6d7a 100644 --- a/reply.go +++ b/reply.go @@ -7,8 +7,11 @@ package main import ( + "database/sql" "errors" "time" + + "./query_gen/lib" ) // ? - Should we add a reply store to centralise all the reply logic? Would this cover profile replies too or would that be separate? @@ -107,22 +110,35 @@ type ReplyStore interface { } type SQLReplyStore struct { + get *sql.Stmt + create *sql.Stmt } -func NewSQLReplyStore() *SQLReplyStore { - return &SQLReplyStore{} +func NewSQLReplyStore() (*SQLReplyStore, error) { + getReplyStmt, err := qgen.Builder.SimpleSelect("replies", "tid, content, createdBy, createdAt, lastEdit, lastEditBy, ipaddress, likeCount", "rid = ?", "", "") + if err != nil { + return nil, err + } + createReplyStmt, err := qgen.Builder.SimpleInsert("replies", "tid, content, parsed_content, createdAt, lastUpdated, ipaddress, words, createdBy", "?,?,?,UTC_TIMESTAMP(),UTC_TIMESTAMP(),?,?,?") + if err != nil { + return nil, err + } + return &SQLReplyStore{ + get: getReplyStmt, + create: createReplyStmt, + }, nil } func (store *SQLReplyStore) Get(id int) (*Reply, error) { reply := Reply{ID: id} - err := stmts.getReply.QueryRow(id).Scan(&reply.ParentID, &reply.Content, &reply.CreatedBy, &reply.CreatedAt, &reply.LastEdit, &reply.LastEditBy, &reply.IPAddress, &reply.LikeCount) + err := store.get.QueryRow(id).Scan(&reply.ParentID, &reply.Content, &reply.CreatedBy, &reply.CreatedAt, &reply.LastEdit, &reply.LastEditBy, &reply.IPAddress, &reply.LikeCount) return &reply, err } // TODO: Write a test for this func (store *SQLReplyStore) Create(tid int, content string, ipaddress string, fid int, uid int) (id int, err error) { wcount := wordCount(content) - res, err := stmts.createReply.Exec(tid, content, parseMessage(content, fid, "forums"), ipaddress, wcount, uid) + res, err := store.create.Exec(tid, content, parseMessage(content, fid, "forums"), ipaddress, wcount, uid) if err != nil { return 0, err } @@ -147,15 +163,23 @@ type ProfileReplyStore interface { } // TODO: Refactor this to stop using the global stmt store +// TODO: Add more methods to this like Create() type SQLProfileReplyStore struct { + get *sql.Stmt } -func NewSQLProfileReplyStore() *SQLProfileReplyStore { - return &SQLProfileReplyStore{} +func NewSQLProfileReplyStore() (*SQLProfileReplyStore, error) { + getUserReplyStmt, err := qgen.Builder.SimpleSelect("users_replies", "uid, content, createdBy, createdAt, lastEdit, lastEditBy, ipaddress", "rid = ?", "", "") + if err != nil { + return nil, err + } + return &SQLProfileReplyStore{ + get: getUserReplyStmt, + }, nil } func (store *SQLProfileReplyStore) Get(id int) (*Reply, error) { reply := Reply{ID: id} - err := stmts.getUserReply.QueryRow(id).Scan(&reply.ParentID, &reply.Content, &reply.CreatedBy, &reply.CreatedAt, &reply.LastEdit, &reply.LastEditBy, &reply.IPAddress) + err := store.get.QueryRow(id).Scan(&reply.ParentID, &reply.Content, &reply.CreatedBy, &reply.CreatedAt, &reply.LastEdit, &reply.LastEditBy, &reply.IPAddress) return &reply, err } diff --git a/themes.go b/themes.go index 2b543f2d..343ace3a 100644 --- a/themes.go +++ b/themes.go @@ -19,7 +19,10 @@ import ( "text/template" ) -var themes = make(map[string]Theme) +//var themes = make(map[string]Theme) +type ThemeList map[string]Theme + +var themes ThemeList = make(map[string]Theme) var defaultThemeBox atomic.Value var changeDefaultThemeMutex sync.Mutex @@ -74,8 +77,9 @@ func init() { defaultThemeBox.Store(fallbackTheme) } +// TODO: Make the initThemes and LoadThemes functions less confusing // ? - Delete themes which no longer exist in the themes folder from the database? -func LoadThemes() error { +func LoadThemeActiveStatus() error { changeDefaultThemeMutex.Lock() rows, err := stmts.getThemes.Query() if err != nil {