Add GetForRenderRoute to DefaultAttachmentStore.

Use ErrCorruptAttachPath in DefaultAttachmentStore.
Consider the possibility that the requested attachment doesn't have a file extension.
Rename variables to reduce boilerplate.
Add TestThaw.
Avoid an allocation in hookgen.

Add route_attach_start hook.
Add route_attach_post_get hook.
This commit is contained in:
Azareal 2021-03-24 21:45:18 +10:00
parent adfed477a0
commit 26e8bf32a7
8 changed files with 160 additions and 137 deletions

View File

@ -1,9 +1,9 @@
package hookgen
import (
"bytes"
"log"
"os"
"bytes"
"text/template"
)
@ -27,34 +27,36 @@ type Hook struct {
func AddHooks(add func(name, params, ret, htype string, multiHook, skip bool, defaultRet, pure string)) {
vhookskip := func(name, params string) {
add(name,params,"(bool,RouteError)","VhookSkippable_",false,true,"false,nil","")
add(name, params, "(bool,RouteError)", "VhookSkippable_", false, true, "false,nil", "")
}
vhookskip("forum_check_pre_perms","w http.ResponseWriter,r *http.Request,u *User,fid *int,h *Header")
vhookskip("router_after_filters","w http.ResponseWriter,r *http.Request,prefix string")
vhookskip("router_pre_route","w http.ResponseWriter,r *http.Request,u *User,prefix string")
vhookskip("route_forum_list_start","w http.ResponseWriter,r *http.Request,u *User,h *Header")
vhookskip("route_topic_list_start","w http.ResponseWriter,r *http.Request,u *User,h *Header")
vhookskip("forum_check_pre_perms", "w http.ResponseWriter,r *http.Request,u *User,fid *int,h *Header")
vhookskip("router_after_filters", "w http.ResponseWriter,r *http.Request,prefix string")
vhookskip("router_pre_route", "w http.ResponseWriter,r *http.Request,u *User,prefix string")
vhookskip("route_forum_list_start", "w http.ResponseWriter,r *http.Request,u *User,h *Header")
vhookskip("route_topic_list_start", "w http.ResponseWriter,r *http.Request,u *User,h *Header")
vhookskip("route_attach_start", "w http.ResponseWriter,r *http.Request,u *User,fname string")
vhookskip("route_attach_post_get", "w http.ResponseWriter,r *http.Request,u *User,a *Attachment")
vhooknoret := func(name, params string) {
add(name,params,"","Vhooks",false,false,"false,nil","")
add(name, params, "", "Vhooks", false, false, "false,nil", "")
}
vhooknoret("router_end","w http.ResponseWriter,r *http.Request,u *User,prefix string, extraData string")
vhooknoret("topic_reply_row_assign","r *ReplyUser")
vhooknoret("router_end", "w http.ResponseWriter,r *http.Request,u *User,prefix string,extraData string")
vhooknoret("topic_reply_row_assign", "r *ReplyUser")
//forums_frow_assign
//Hook(name string, data interface{}) interface{}
/*hook := func(name, params, ret, pure string) {
add(name,params,ret,"Hooks",true,false,ret,pure)
}*/
hooknoret := func(name, params string) {
add(name,params,"","HooksNoRet",true,false,"","")
add(name, params, "", "HooksNoRet", true, false, "", "")
}
hooknoret("forums_frow_assign","f *Forum")
hooknoret("forums_frow_assign", "f *Forum")
hookskip := func(name, params string) {
add(name,params,"(skip bool)","HooksSkip",true,true,"","")
add(name, params, "(skip bool)", "HooksSkip", true, true, "", "")
}
//hookskip("forums_frow_assign","f *Forum")
hookskip("topic_create_frow_assign","f *Forum")
hookskip("topic_create_frow_assign", "f *Forum")
hookss := func(name string) {
add(name,"d string","string","Sshooks",true,false,"","d")
add(name, "d string", "string", "Sshooks", true, false, "", "d")
}
hookss("topic_ogdesc_assign")
}
@ -86,7 +88,7 @@ func H_{{.Name}}_hook(t *HookTable,{{.Params}}) {{.Ret}} { {{if .Any}}
log.Fatal(e)
}
err := writeFile("./common/gen_extend.go", string(b.Bytes()))
err := writeFile("./common/gen_extend.go", b.String())
if err != nil {
log.Fatal(err)
}

View File

@ -14,6 +14,8 @@ import (
var Attachments AttachmentStore
var ErrCorruptAttachPath = errors.New("corrupt attachment path")
type MiniAttachment struct {
ID int
SectionID int
@ -41,6 +43,7 @@ type Attachment struct {
}
type AttachmentStore interface {
GetForRenderRoute(filename string, sid int, sectionTable string) (*Attachment, error)
FGet(id int) (*Attachment, error)
Get(id int) (*MiniAttachment, error)
MiniGetList(originTable string, originID int) (alist []*MiniAttachment, err error)
@ -58,6 +61,8 @@ type AttachmentStore interface {
}
type DefaultAttachmentStore struct {
getForRenderRoute *sql.Stmt
fget *sql.Stmt
get *sql.Stmt
getByObj *sql.Stmt
@ -76,6 +81,8 @@ type DefaultAttachmentStore struct {
func NewDefaultAttachmentStore(acc *qgen.Accumulator) (*DefaultAttachmentStore, error) {
a := "attachments"
return &DefaultAttachmentStore{
getForRenderRoute: acc.Select(a).Columns("sectionTable, originID, originTable, uploadedBy, path").Where("path=? AND sectionID=? AND sectionTable=?").Prepare(),
fget: acc.Select(a).Columns("originTable, originID, sectionTable, sectionID, uploadedBy, path, extra").Where("attachID=?").Prepare(),
get: acc.Select(a).Columns("originID, sectionID, uploadedBy, path, extra").Where("attachID=?").Prepare(),
getByObj: acc.Select(a).Columns("attachID, sectionID, uploadedBy, path, extra").Where("originTable=? AND originID=?").Prepare(),
@ -93,6 +100,15 @@ func NewDefaultAttachmentStore(acc *qgen.Accumulator) (*DefaultAttachmentStore,
}, acc.FirstError()
}
// TODO: Revamp this to make it less of a copy-paste from the original code in the route
// ! Lacks some attachment initialisation code
func (s *DefaultAttachmentStore) GetForRenderRoute(filename string, sid int, sectionTable string) (*Attachment, error) {
a := &Attachment{SectionID: sid}
e := s.getForRenderRoute.QueryRow(filename, sid, sectionTable).Scan(&a.SectionTable, &a.OriginID, &a.OriginTable, &a.UploadedBy, &a.Path)
// TODO: Initialise attachment struct fields?
return a, e
}
func (s *DefaultAttachmentStore) MiniGetList(originTable string, originID int) (alist []*MiniAttachment, err error) {
rows, err := s.getByObj.Query(originTable, originID)
defer rows.Close()
@ -104,7 +120,7 @@ func (s *DefaultAttachmentStore) MiniGetList(originTable string, originID int) (
}
a.Ext = strings.TrimPrefix(filepath.Ext(a.Path), ".")
if len(a.Ext) == 0 {
return nil, errors.New("corrupt attachment path")
return nil, ErrCorruptAttachPath
}
a.Image = ImageFileExts.Contains(a.Ext)
alist = append(alist, a)
@ -140,7 +156,7 @@ func (s *DefaultAttachmentStore) BulkMiniGetList(originTable string, ids []int)
}
a.Ext = strings.TrimPrefix(filepath.Ext(a.Path), ".")
if len(a.Ext) == 0 {
return nil, errors.New("corrupt attachment path")
return nil, ErrCorruptAttachPath
}
a.Image = ImageFileExts.Contains(a.Ext)
if currentID == 0 {
@ -169,7 +185,7 @@ func (s *DefaultAttachmentStore) FGet(id int) (*Attachment, error) {
}
a.Ext = strings.TrimPrefix(filepath.Ext(a.Path), ".")
if len(a.Ext) == 0 {
return nil, errors.New("corrupt attachment path")
return nil, ErrCorruptAttachPath
}
a.Image = ImageFileExts.Contains(a.Ext)
return a, nil
@ -183,7 +199,7 @@ func (s *DefaultAttachmentStore) Get(id int) (*MiniAttachment, error) {
}
a.Ext = strings.TrimPrefix(filepath.Ext(a.Path), ".")
if len(a.Ext) == 0 {
return nil, errors.New("corrupt attachment path")
return nil, ErrCorruptAttachPath
}
a.Image = ImageFileExts.Contains(a.Ext)
return a, nil
@ -209,32 +225,32 @@ func (s *DefaultAttachmentStore) MoveToByExtra(sectionID int, originTable, extra
}
func (s *DefaultAttachmentStore) Count() (count int) {
err := s.count.QueryRow().Scan(&count)
if err != nil {
LogError(err)
e := s.count.QueryRow().Scan(&count)
if e != nil {
LogError(e)
}
return count
}
func (s *DefaultAttachmentStore) CountIn(originTable string, oid int) (count int) {
err := s.countIn.QueryRow(originTable, oid).Scan(&count)
if err != nil {
LogError(err)
e := s.countIn.QueryRow(originTable, oid).Scan(&count)
if e != nil {
LogError(e)
}
return count
}
func (s *DefaultAttachmentStore) CountInPath(path string) (count int) {
err := s.countInPath.QueryRow(path).Scan(&count)
if err != nil {
LogError(err)
e := s.countInPath.QueryRow(path).Scan(&count)
if e != nil {
LogError(e)
}
return count
}
func (s *DefaultAttachmentStore) Delete(id int) error {
_, err := s.delete.Exec(id)
return err
_, e := s.delete.Exec(id)
return e
}
// TODO: Split this out of this store
@ -256,10 +272,7 @@ func (s *DefaultAttachmentStore) AddLinked(otable string, oid int) (err error) {
if err == sql.ErrNoRows {
err = nil
}
if err != nil {
return err
}
return nil
}
// TODO: Split this out of this store
@ -280,10 +293,7 @@ func (s *DefaultAttachmentStore) RemoveLinked(otable string, oid int) (err error
}
err = Rstore.GetCache().Remove(oid)
}
if err != nil {
return err
}
return nil
}
// TODO: Add a table for the files and lock the file row when performing tasks related to the file

View File

@ -1,7 +1,7 @@
/*
*
* Gosora Plugin System
* Copyright Azareal 2016 - 2020
* Copyright Azareal 2016 - 2021
*
*/
package common
@ -25,9 +25,9 @@ type PluginList map[string]*Plugin
// TODO: Have a proper store rather than a map?
var Plugins PluginList = make(map[string]*Plugin)
func (list PluginList) Add(pl *Plugin) {
func (l PluginList) Add(pl *Plugin) {
buildPlugin(pl)
list[pl.UName] = pl
l[pl.UName] = pl
}
func buildPlugin(pl *Plugin) {
@ -94,6 +94,8 @@ var hookTable = &HookTable{
"route_topic_list_start": nil,
"route_topic_list_mostviewed_start": nil,
"route_forum_list_start": nil,
"route_attach_start": nil,
"route_attach_post_get": nil,
"action_end_create_topic": nil,
"action_end_edit_topic": nil,

View File

@ -9,7 +9,7 @@ import (
var IPSearch IPSearcher
type IPSearcher interface {
Lookup(ip string) (uids []int, err error)
Lookup(ip string) (uids []int, e error)
}
type DefaultIPSearcher struct {
@ -23,27 +23,29 @@ type DefaultIPSearcher struct {
func NewDefaultIPSearcher() (*DefaultIPSearcher, error) {
acc := qgen.NewAcc()
uu := "users"
q := func(tbl string) *sql.Stmt {
return acc.Select(uu).Columns("uid").InQ("uid", acc.Select(tbl).Columns("createdBy").Where("ip=?")).Prepare()
}
return &DefaultIPSearcher{
searchUsers: acc.Select(uu).Columns("uid").Where("last_ip=? OR last_ip LIKE CONCAT('%-',?)").Prepare(),
searchTopics: acc.Select(uu).Columns("uid").InQ("uid", acc.Select("topics").Columns("createdBy").Where("ip=?")).Prepare(),
searchReplies: acc.Select(uu).Columns("uid").InQ("uid", acc.Select("replies").Columns("createdBy").Where("ip=?")).Prepare(),
searchUsersReplies: acc.Select(uu).Columns("uid").InQ("uid", acc.Select("users_replies").Columns("createdBy").Where("ip=?")).Prepare(),
searchTopics: q("topics"),
searchReplies: q("replies"),
searchUsersReplies: q("users_replies"),
}, acc.FirstError()
}
func (s *DefaultIPSearcher) Lookup(ip string) (uids []int, err error) {
func (s *DefaultIPSearcher) Lookup(ip string) (uids []int, e error) {
var uid int
reqUserList := make(map[int]bool)
runQuery2 := func(rows *sql.Rows, err error) error {
if err != nil {
return err
runQuery2 := func(rows *sql.Rows, e error) error {
if e != nil {
return e
}
defer rows.Close()
for rows.Next() {
err := rows.Scan(&uid)
if err != nil {
return err
if e := rows.Scan(&uid); e != nil {
return e
}
reqUserList[uid] = true
}
@ -53,21 +55,21 @@ func (s *DefaultIPSearcher) Lookup(ip string) (uids []int, err error) {
return runQuery2(stmt.Query(ip))
}
err = runQuery2(s.searchUsers.Query(ip, ip))
if err != nil {
return uids, err
e = runQuery2(s.searchUsers.Query(ip, ip))
if e != nil {
return uids, e
}
err = runQuery(s.searchTopics)
if err != nil {
return uids, err
e = runQuery(s.searchTopics)
if e != nil {
return uids, e
}
err = runQuery(s.searchReplies)
if err != nil {
return uids, err
e = runQuery(s.searchReplies)
if e != nil {
return uids, e
}
err = runQuery(s.searchUsersReplies)
if err != nil {
return uids, err
e = runQuery(s.searchUsersReplies)
if e != nil {
return uids, e
}
// Convert the user ID map to a slice, then bulk load the users

View File

@ -44,13 +44,13 @@ func simpleForumUserCheck(w http.ResponseWriter, r *http.Request, u *User, fid i
return h, rerr
}
fperms, err := FPStore.Get(fid, u.Group)
fp, err := FPStore.Get(fid, u.Group)
if err == ErrNoRows {
fperms = BlankForumPerms()
fp = BlankForumPerms()
} else if err != nil {
return h, InternalError(err, w, r)
}
cascadeForumPerms(fperms, u)
cascadeForumPerms(fp, u)
return h, nil
}
@ -72,13 +72,13 @@ func forumUserCheck(h *Header, w http.ResponseWriter, r *http.Request, u *User,
return rerr
}
fperms, err := FPStore.Get(fid, u.Group)
fp, err := FPStore.Get(fid, u.Group)
if err == ErrNoRows {
fperms = BlankForumPerms()
fp = BlankForumPerms()
} else if err != nil {
return InternalError(err, w, r)
}
cascadeForumPerms(fperms, u)
cascadeForumPerms(fp, u)
h.CurrentUser = u // TODO: Use a pointer instead for CurrentUser, so we don't have to do this
return rerr
}

View File

@ -44,19 +44,18 @@ func NewSQLSearcher(acc *qgen.Accumulator) (*SQLSearcher, error) {
func (s *SQLSearcher) queryAll(q string) ([]int, error) {
var ids []int
run := func(stmt *sql.Stmt, q ...interface{}) error {
rows, err := stmt.Query(q...)
if err == sql.ErrNoRows {
rows, e := stmt.Query(q...)
if e == sql.ErrNoRows {
return nil
} else if err != nil {
return err
} else if e != nil {
return e
}
defer rows.Close()
for rows.Next() {
var id int
err := rows.Scan(&id)
if err != nil {
return err
if e := rows.Scan(&id); e != nil {
return e
}
ids = append(ids, id)
}
@ -81,19 +80,18 @@ func (s *SQLSearcher) Query(q string, zones []int) (ids []int, err error) {
if len(zones) == 0 {
return nil, nil
}
run := func(rows *sql.Rows, err error) error {
/*if err == sql.ErrNoRows {
run := func(rows *sql.Rows, e error) error {
/*if e == sql.ErrNoRows {
return nil
} else */if err != nil {
return err
} else */if e != nil {
return e
}
defer rows.Close()
for rows.Next() {
var id int
err := rows.Scan(&id)
if err != nil {
return err
if e := rows.Scan(&id); e != nil {
return e
}
ids = append(ids, id)
}
@ -116,14 +114,12 @@ func (s *SQLSearcher) Query(q string, zones []int) (ids []int, err error) {
acc := qgen.NewAcc()
/*stmt := acc.RawPrepare("SELECT topics.tid FROM topics INNER JOIN replies ON topics.tid = replies.tid WHERE (MATCH(topics.title) AGAINST (? IN BOOLEAN MODE) OR MATCH(topics.content) AGAINST (? IN BOOLEAN MODE) OR MATCH(replies.content) AGAINST (? IN BOOLEAN MODE) OR topics.title=? OR topics.content=? OR replies.content=?) AND topics.parentID IN(" + zList + ")")
err = acc.FirstError()
if err != nil {
if err = acc.FirstError(); err != nil {
return nil, err
}*/
// TODO: Cache common IN counts
stmt := acc.RawPrepare("SELECT tid FROM topics WHERE (MATCH(topics.title) AGAINST (? IN BOOLEAN MODE) OR MATCH(topics.content) AGAINST (? IN BOOLEAN MODE)) AND parentID IN(" + zList + ")")
err = acc.FirstError()
if err != nil {
if err = acc.FirstError(); err != nil {
return nil, err
}
err = run(stmt.Query(q, q))
@ -131,8 +127,7 @@ func (s *SQLSearcher) Query(q string, zones []int) (ids []int, err error) {
return nil, err
}
stmt = acc.RawPrepare("SELECT tid FROM replies WHERE MATCH(replies.content) AGAINST (? IN BOOLEAN MODE) AND tid IN(" + zList + ")")
err = acc.FirstError()
if err != nil {
if err = acc.FirstError(); err != nil {
return nil, err
}
err = run(stmt.Query(q))

View File

@ -38,6 +38,21 @@ func (t *SingleServerThaw) Thaw() {
}
}
type TestThaw struct {
}
func NewTestThaw() *TestThaw {
return &TestThaw{}
}
func (t *TestThaw) Thawed() bool {
return true
}
func (t *TestThaw) Thaw() {
}
func (t *TestThaw) Tick() error {
return nil
}
type DefaultThaw struct {
thawed int64
}

View File

@ -8,50 +8,47 @@ import (
"strings"
c "github.com/Azareal/Gosora/common"
qgen "github.com/Azareal/Gosora/query_gen"
)
type AttachmentStmts struct {
get *sql.Stmt
}
var attachmentStmts AttachmentStmts
// TODO: Abstract this with an attachment store
func init() {
c.DbInits.Add(func(acc *qgen.Accumulator) error {
attachmentStmts = AttachmentStmts{
get: acc.Select("attachments").Columns("sectionID, sectionTable, originID, originTable, uploadedBy, path").Where("path=? AND sectionID=? AND sectionTable=?").Prepare(),
}
return acc.FirstError()
})
}
var maxAgeYear = "max-age=" + strconv.Itoa(int(c.Year))
func ShowAttachment(w http.ResponseWriter, r *http.Request, u *c.User, filename string) c.RouteError {
filename = c.Stripslashes(filename)
ext := filepath.Ext("./attachs/" + filename)
if !c.AllowedFileExts.Contains(strings.TrimPrefix(ext, ".")) {
return c.LocalError("Bad extension", w, r, u)
}
sid, err := strconv.Atoi(r.FormValue("sid"))
if err != nil {
return c.LocalError("The sid is not an integer", w, r, u)
}
sectionTable := r.FormValue("stype")
var originTable string
var originID, uploadedBy int
err = attachmentStmts.get.QueryRow(filename, sid, sectionTable).Scan(&sid, &sectionTable, &originID, &originTable, &uploadedBy, &filename)
filename = c.Stripslashes(filename)
if filename == "" {
return c.LocalError("Bad filename", w, r, u)
}
ext := filepath.Ext(filename)
if ext == "" || !c.AllowedFileExts.Contains(strings.TrimPrefix(ext, ".")) {
return c.LocalError("Bad extension", w, r, u)
}
// TODO: Use the same hook table as upstream
hTbl := c.GetHookTable()
skip, rerr := c.H_route_attach_start_hook(hTbl, w, r, u, filename)
if skip || rerr != nil {
return rerr
}
a, err := c.Attachments.GetForRenderRoute(filename, sid, sectionTable)
// ErrCorruptAttachPath is a possibility now
if err == sql.ErrNoRows {
return c.NotFound(w, r, nil)
} else if err != nil {
return c.InternalError(err, w, r)
}
if sectionTable == "forums" {
skip, rerr = c.H_route_attach_post_get_hook(hTbl, w, r, u, a)
if skip || rerr != nil {
return rerr
}
if a.SectionTable == "forums" {
_, ferr := c.SimpleForumUserCheck(w, r, u, sid)
if ferr != nil {
return ferr
@ -63,7 +60,7 @@ func ShowAttachment(w http.ResponseWriter, r *http.Request, u *c.User, filename
return c.LocalError("Unknown section", w, r, u)
}
if originTable != "topics" && originTable != "replies" {
if a.OriginTable != "topics" && a.OriginTable != "replies" {
return c.LocalError("Unknown origin", w, r, u)
}
@ -89,11 +86,11 @@ func ShowAttachment(w http.ResponseWriter, r *http.Request, u *c.User, filename
}
func deleteAttachment(w http.ResponseWriter, r *http.Request, u *c.User, aid int, js bool) c.RouteError {
err := c.DeleteAttachment(aid)
if err == sql.ErrNoRows {
e := c.DeleteAttachment(aid)
if e == sql.ErrNoRows {
return c.NotFoundJSQ(w, r, nil, js)
} else if err != nil {
return c.InternalErrorJSQ(err, w, r, js)
} else if e != nil {
return c.InternalErrorJSQ(e, w, r, js)
}
return nil
}