From 26e8bf32a7701022c6495ca7227647b09871187f Mon Sep 17 00:00:00 2001 From: Azareal Date: Wed, 24 Mar 2021 21:45:18 +1000 Subject: [PATCH] Add GetForRenderRoute to DefaultAttachmentStore. Use ErrCorruptAttachPath in DefaultAttachmentStore. Consider the possibility that the requested attachment doesn't have a file extension. Rename variables to reduce boilerplate. Add TestThaw. Avoid an allocation in hookgen. Add route_attach_start hook. Add route_attach_post_get hook. --- cmd/common_hook_gen/hookgen.go | 58 ++++++++++++++++---------------- common/attachments.go | 56 ++++++++++++++++++------------- common/extend.go | 12 ++++--- common/ip_search.go | 48 +++++++++++++------------- common/routes_common.go | 12 +++---- common/search.go | 35 +++++++++---------- common/thaw.go | 15 +++++++++ routes/attachments.go | 61 ++++++++++++++++------------------ 8 files changed, 160 insertions(+), 137 deletions(-) diff --git a/cmd/common_hook_gen/hookgen.go b/cmd/common_hook_gen/hookgen.go index c167ab35..21bf6214 100644 --- a/cmd/common_hook_gen/hookgen.go +++ b/cmd/common_hook_gen/hookgen.go @@ -1,60 +1,62 @@ package hookgen import ( + "bytes" "log" "os" - "bytes" "text/template" ) type HookVars struct { Imports []string - Hooks []Hook + Hooks []Hook } type Hook struct { - Name string - Params string - Params2 string - Ret string - Type string - Any bool - MultiHook bool - Skip bool + Name string + Params string + Params2 string + Ret string + Type string + Any bool + MultiHook bool + Skip bool DefaultRet string - Pure string + Pure string } func AddHooks(add func(name, params, ret, htype string, multiHook, skip bool, defaultRet, pure string)) { vhookskip := func(name, params string) { - add(name,params,"(bool,RouteError)","VhookSkippable_",false,true,"false,nil","") + add(name, params, "(bool,RouteError)", "VhookSkippable_", false, true, "false,nil", "") } - vhookskip("forum_check_pre_perms","w http.ResponseWriter,r *http.Request,u *User,fid *int,h *Header") - vhookskip("router_after_filters","w http.ResponseWriter,r *http.Request,prefix string") - vhookskip("router_pre_route","w http.ResponseWriter,r *http.Request,u *User,prefix string") - vhookskip("route_forum_list_start","w http.ResponseWriter,r *http.Request,u *User,h *Header") - vhookskip("route_topic_list_start","w http.ResponseWriter,r *http.Request,u *User,h *Header") + vhookskip("forum_check_pre_perms", "w http.ResponseWriter,r *http.Request,u *User,fid *int,h *Header") + vhookskip("router_after_filters", "w http.ResponseWriter,r *http.Request,prefix string") + vhookskip("router_pre_route", "w http.ResponseWriter,r *http.Request,u *User,prefix string") + vhookskip("route_forum_list_start", "w http.ResponseWriter,r *http.Request,u *User,h *Header") + vhookskip("route_topic_list_start", "w http.ResponseWriter,r *http.Request,u *User,h *Header") + vhookskip("route_attach_start", "w http.ResponseWriter,r *http.Request,u *User,fname string") + vhookskip("route_attach_post_get", "w http.ResponseWriter,r *http.Request,u *User,a *Attachment") vhooknoret := func(name, params string) { - add(name,params,"","Vhooks",false,false,"false,nil","") + add(name, params, "", "Vhooks", false, false, "false,nil", "") } - vhooknoret("router_end","w http.ResponseWriter,r *http.Request,u *User,prefix string, extraData string") - vhooknoret("topic_reply_row_assign","r *ReplyUser") + vhooknoret("router_end", "w http.ResponseWriter,r *http.Request,u *User,prefix string,extraData string") + vhooknoret("topic_reply_row_assign", "r *ReplyUser") //forums_frow_assign //Hook(name string, data interface{}) interface{} /*hook := func(name, params, ret, pure string) { add(name,params,ret,"Hooks",true,false,ret,pure) }*/ hooknoret := func(name, params string) { - add(name,params,"","HooksNoRet",true,false,"","") + add(name, params, "", "HooksNoRet", true, false, "", "") } - hooknoret("forums_frow_assign","f *Forum") + hooknoret("forums_frow_assign", "f *Forum") hookskip := func(name, params string) { - add(name,params,"(skip bool)","HooksSkip",true,true,"","") + add(name, params, "(skip bool)", "HooksSkip", true, true, "", "") } //hookskip("forums_frow_assign","f *Forum") - hookskip("topic_create_frow_assign","f *Forum") + hookskip("topic_create_frow_assign", "f *Forum") hookss := func(name string) { - add(name,"d string","string","Sshooks",true,false,"","d") + add(name, "d string", "string", "Sshooks", true, false, "", "d") } hookss("topic_ogdesc_assign") } @@ -85,8 +87,8 @@ func H_{{.Name}}_hook(t *HookTable,{{.Params}}) {{.Ret}} { {{if .Any}} if e := tmpl.Execute(&b, hookVars); e != nil { log.Fatal(e) } - - err := writeFile("./common/gen_extend.go", string(b.Bytes())) + + err := writeFile("./common/gen_extend.go", b.String()) if err != nil { log.Fatal(err) } @@ -104,4 +106,4 @@ func writeFile(name, body string) error { return e } return f.Close() -} \ No newline at end of file +} diff --git a/common/attachments.go b/common/attachments.go index ba78b554..0f448e7b 100644 --- a/common/attachments.go +++ b/common/attachments.go @@ -14,6 +14,8 @@ import ( var Attachments AttachmentStore +var ErrCorruptAttachPath = errors.New("corrupt attachment path") + type MiniAttachment struct { ID int SectionID int @@ -41,6 +43,7 @@ type Attachment struct { } type AttachmentStore interface { + GetForRenderRoute(filename string, sid int, sectionTable string) (*Attachment, error) FGet(id int) (*Attachment, error) Get(id int) (*MiniAttachment, error) MiniGetList(originTable string, originID int) (alist []*MiniAttachment, err error) @@ -58,6 +61,8 @@ type AttachmentStore interface { } type DefaultAttachmentStore struct { + getForRenderRoute *sql.Stmt + fget *sql.Stmt get *sql.Stmt getByObj *sql.Stmt @@ -76,6 +81,8 @@ type DefaultAttachmentStore struct { func NewDefaultAttachmentStore(acc *qgen.Accumulator) (*DefaultAttachmentStore, error) { a := "attachments" return &DefaultAttachmentStore{ + getForRenderRoute: acc.Select(a).Columns("sectionTable, originID, originTable, uploadedBy, path").Where("path=? AND sectionID=? AND sectionTable=?").Prepare(), + fget: acc.Select(a).Columns("originTable, originID, sectionTable, sectionID, uploadedBy, path, extra").Where("attachID=?").Prepare(), get: acc.Select(a).Columns("originID, sectionID, uploadedBy, path, extra").Where("attachID=?").Prepare(), getByObj: acc.Select(a).Columns("attachID, sectionID, uploadedBy, path, extra").Where("originTable=? AND originID=?").Prepare(), @@ -93,6 +100,15 @@ func NewDefaultAttachmentStore(acc *qgen.Accumulator) (*DefaultAttachmentStore, }, acc.FirstError() } +// TODO: Revamp this to make it less of a copy-paste from the original code in the route +// ! Lacks some attachment initialisation code +func (s *DefaultAttachmentStore) GetForRenderRoute(filename string, sid int, sectionTable string) (*Attachment, error) { + a := &Attachment{SectionID: sid} + e := s.getForRenderRoute.QueryRow(filename, sid, sectionTable).Scan(&a.SectionTable, &a.OriginID, &a.OriginTable, &a.UploadedBy, &a.Path) + // TODO: Initialise attachment struct fields? + return a, e +} + func (s *DefaultAttachmentStore) MiniGetList(originTable string, originID int) (alist []*MiniAttachment, err error) { rows, err := s.getByObj.Query(originTable, originID) defer rows.Close() @@ -104,7 +120,7 @@ func (s *DefaultAttachmentStore) MiniGetList(originTable string, originID int) ( } a.Ext = strings.TrimPrefix(filepath.Ext(a.Path), ".") if len(a.Ext) == 0 { - return nil, errors.New("corrupt attachment path") + return nil, ErrCorruptAttachPath } a.Image = ImageFileExts.Contains(a.Ext) alist = append(alist, a) @@ -140,7 +156,7 @@ func (s *DefaultAttachmentStore) BulkMiniGetList(originTable string, ids []int) } a.Ext = strings.TrimPrefix(filepath.Ext(a.Path), ".") if len(a.Ext) == 0 { - return nil, errors.New("corrupt attachment path") + return nil, ErrCorruptAttachPath } a.Image = ImageFileExts.Contains(a.Ext) if currentID == 0 { @@ -169,7 +185,7 @@ func (s *DefaultAttachmentStore) FGet(id int) (*Attachment, error) { } a.Ext = strings.TrimPrefix(filepath.Ext(a.Path), ".") if len(a.Ext) == 0 { - return nil, errors.New("corrupt attachment path") + return nil, ErrCorruptAttachPath } a.Image = ImageFileExts.Contains(a.Ext) return a, nil @@ -183,7 +199,7 @@ func (s *DefaultAttachmentStore) Get(id int) (*MiniAttachment, error) { } a.Ext = strings.TrimPrefix(filepath.Ext(a.Path), ".") if len(a.Ext) == 0 { - return nil, errors.New("corrupt attachment path") + return nil, ErrCorruptAttachPath } a.Image = ImageFileExts.Contains(a.Ext) return a, nil @@ -209,32 +225,32 @@ func (s *DefaultAttachmentStore) MoveToByExtra(sectionID int, originTable, extra } func (s *DefaultAttachmentStore) Count() (count int) { - err := s.count.QueryRow().Scan(&count) - if err != nil { - LogError(err) + e := s.count.QueryRow().Scan(&count) + if e != nil { + LogError(e) } return count } func (s *DefaultAttachmentStore) CountIn(originTable string, oid int) (count int) { - err := s.countIn.QueryRow(originTable, oid).Scan(&count) - if err != nil { - LogError(err) + e := s.countIn.QueryRow(originTable, oid).Scan(&count) + if e != nil { + LogError(e) } return count } func (s *DefaultAttachmentStore) CountInPath(path string) (count int) { - err := s.countInPath.QueryRow(path).Scan(&count) - if err != nil { - LogError(err) + e := s.countInPath.QueryRow(path).Scan(&count) + if e != nil { + LogError(e) } return count } func (s *DefaultAttachmentStore) Delete(id int) error { - _, err := s.delete.Exec(id) - return err + _, e := s.delete.Exec(id) + return e } // TODO: Split this out of this store @@ -256,10 +272,7 @@ func (s *DefaultAttachmentStore) AddLinked(otable string, oid int) (err error) { if err == sql.ErrNoRows { err = nil } - if err != nil { - return err - } - return nil + return err } // TODO: Split this out of this store @@ -280,10 +293,7 @@ func (s *DefaultAttachmentStore) RemoveLinked(otable string, oid int) (err error } err = Rstore.GetCache().Remove(oid) } - if err != nil { - return err - } - return nil + return err } // TODO: Add a table for the files and lock the file row when performing tasks related to the file diff --git a/common/extend.go b/common/extend.go index 7169e2f0..58dc1bcc 100644 --- a/common/extend.go +++ b/common/extend.go @@ -1,7 +1,7 @@ /* * * Gosora Plugin System -* Copyright Azareal 2016 - 2020 +* Copyright Azareal 2016 - 2021 * */ package common @@ -25,9 +25,9 @@ type PluginList map[string]*Plugin // TODO: Have a proper store rather than a map? var Plugins PluginList = make(map[string]*Plugin) -func (list PluginList) Add(pl *Plugin) { +func (l PluginList) Add(pl *Plugin) { buildPlugin(pl) - list[pl.UName] = pl + l[pl.UName] = pl } func buildPlugin(pl *Plugin) { @@ -94,6 +94,8 @@ var hookTable = &HookTable{ "route_topic_list_start": nil, "route_topic_list_mostviewed_start": nil, "route_forum_list_start": nil, + "route_attach_start": nil, + "route_attach_post_get": nil, "action_end_create_topic": nil, "action_end_edit_topic": nil, @@ -126,8 +128,8 @@ var hookTable = &HookTable{ "tasks_tick_widget_wol": nil, }, map[string][]func(string) string{ - "preparse_preassign": nil, - "parse_assign": nil, + "preparse_preassign": nil, + "parse_assign": nil, "topic_ogdesc_assign": nil, }, nil, diff --git a/common/ip_search.go b/common/ip_search.go index d88f7574..1a38a936 100644 --- a/common/ip_search.go +++ b/common/ip_search.go @@ -9,7 +9,7 @@ import ( var IPSearch IPSearcher type IPSearcher interface { - Lookup(ip string) (uids []int, err error) + Lookup(ip string) (uids []int, e error) } type DefaultIPSearcher struct { @@ -23,27 +23,29 @@ type DefaultIPSearcher struct { func NewDefaultIPSearcher() (*DefaultIPSearcher, error) { acc := qgen.NewAcc() uu := "users" + q := func(tbl string) *sql.Stmt { + return acc.Select(uu).Columns("uid").InQ("uid", acc.Select(tbl).Columns("createdBy").Where("ip=?")).Prepare() + } return &DefaultIPSearcher{ searchUsers: acc.Select(uu).Columns("uid").Where("last_ip=? OR last_ip LIKE CONCAT('%-',?)").Prepare(), - searchTopics: acc.Select(uu).Columns("uid").InQ("uid", acc.Select("topics").Columns("createdBy").Where("ip=?")).Prepare(), - searchReplies: acc.Select(uu).Columns("uid").InQ("uid", acc.Select("replies").Columns("createdBy").Where("ip=?")).Prepare(), - searchUsersReplies: acc.Select(uu).Columns("uid").InQ("uid", acc.Select("users_replies").Columns("createdBy").Where("ip=?")).Prepare(), + searchTopics: q("topics"), + searchReplies: q("replies"), + searchUsersReplies: q("users_replies"), }, acc.FirstError() } -func (s *DefaultIPSearcher) Lookup(ip string) (uids []int, err error) { +func (s *DefaultIPSearcher) Lookup(ip string) (uids []int, e error) { var uid int reqUserList := make(map[int]bool) - runQuery2 := func(rows *sql.Rows, err error) error { - if err != nil { - return err + runQuery2 := func(rows *sql.Rows, e error) error { + if e != nil { + return e } defer rows.Close() for rows.Next() { - err := rows.Scan(&uid) - if err != nil { - return err + if e := rows.Scan(&uid); e != nil { + return e } reqUserList[uid] = true } @@ -53,21 +55,21 @@ func (s *DefaultIPSearcher) Lookup(ip string) (uids []int, err error) { return runQuery2(stmt.Query(ip)) } - err = runQuery2(s.searchUsers.Query(ip, ip)) - if err != nil { - return uids, err + e = runQuery2(s.searchUsers.Query(ip, ip)) + if e != nil { + return uids, e } - err = runQuery(s.searchTopics) - if err != nil { - return uids, err + e = runQuery(s.searchTopics) + if e != nil { + return uids, e } - err = runQuery(s.searchReplies) - if err != nil { - return uids, err + e = runQuery(s.searchReplies) + if e != nil { + return uids, e } - err = runQuery(s.searchUsersReplies) - if err != nil { - return uids, err + e = runQuery(s.searchUsersReplies) + if e != nil { + return uids, e } // Convert the user ID map to a slice, then bulk load the users diff --git a/common/routes_common.go b/common/routes_common.go index d39a4780..167b9f32 100644 --- a/common/routes_common.go +++ b/common/routes_common.go @@ -44,13 +44,13 @@ func simpleForumUserCheck(w http.ResponseWriter, r *http.Request, u *User, fid i return h, rerr } - fperms, err := FPStore.Get(fid, u.Group) + fp, err := FPStore.Get(fid, u.Group) if err == ErrNoRows { - fperms = BlankForumPerms() + fp = BlankForumPerms() } else if err != nil { return h, InternalError(err, w, r) } - cascadeForumPerms(fperms, u) + cascadeForumPerms(fp, u) return h, nil } @@ -72,13 +72,13 @@ func forumUserCheck(h *Header, w http.ResponseWriter, r *http.Request, u *User, return rerr } - fperms, err := FPStore.Get(fid, u.Group) + fp, err := FPStore.Get(fid, u.Group) if err == ErrNoRows { - fperms = BlankForumPerms() + fp = BlankForumPerms() } else if err != nil { return InternalError(err, w, r) } - cascadeForumPerms(fperms, u) + cascadeForumPerms(fp, u) h.CurrentUser = u // TODO: Use a pointer instead for CurrentUser, so we don't have to do this return rerr } diff --git a/common/search.go b/common/search.go index 457f412f..3d196126 100644 --- a/common/search.go +++ b/common/search.go @@ -44,19 +44,18 @@ func NewSQLSearcher(acc *qgen.Accumulator) (*SQLSearcher, error) { func (s *SQLSearcher) queryAll(q string) ([]int, error) { var ids []int run := func(stmt *sql.Stmt, q ...interface{}) error { - rows, err := stmt.Query(q...) - if err == sql.ErrNoRows { + rows, e := stmt.Query(q...) + if e == sql.ErrNoRows { return nil - } else if err != nil { - return err + } else if e != nil { + return e } defer rows.Close() for rows.Next() { var id int - err := rows.Scan(&id) - if err != nil { - return err + if e := rows.Scan(&id); e != nil { + return e } ids = append(ids, id) } @@ -81,19 +80,18 @@ func (s *SQLSearcher) Query(q string, zones []int) (ids []int, err error) { if len(zones) == 0 { return nil, nil } - run := func(rows *sql.Rows, err error) error { - /*if err == sql.ErrNoRows { + run := func(rows *sql.Rows, e error) error { + /*if e == sql.ErrNoRows { return nil - } else */if err != nil { - return err + } else */if e != nil { + return e } defer rows.Close() for rows.Next() { var id int - err := rows.Scan(&id) - if err != nil { - return err + if e := rows.Scan(&id); e != nil { + return e } ids = append(ids, id) } @@ -116,14 +114,12 @@ func (s *SQLSearcher) Query(q string, zones []int) (ids []int, err error) { acc := qgen.NewAcc() /*stmt := acc.RawPrepare("SELECT topics.tid FROM topics INNER JOIN replies ON topics.tid = replies.tid WHERE (MATCH(topics.title) AGAINST (? IN BOOLEAN MODE) OR MATCH(topics.content) AGAINST (? IN BOOLEAN MODE) OR MATCH(replies.content) AGAINST (? IN BOOLEAN MODE) OR topics.title=? OR topics.content=? OR replies.content=?) AND topics.parentID IN(" + zList + ")") - err = acc.FirstError() - if err != nil { + if err = acc.FirstError(); err != nil { return nil, err }*/ // TODO: Cache common IN counts stmt := acc.RawPrepare("SELECT tid FROM topics WHERE (MATCH(topics.title) AGAINST (? IN BOOLEAN MODE) OR MATCH(topics.content) AGAINST (? IN BOOLEAN MODE)) AND parentID IN(" + zList + ")") - err = acc.FirstError() - if err != nil { + if err = acc.FirstError(); err != nil { return nil, err } err = run(stmt.Query(q, q)) @@ -131,8 +127,7 @@ func (s *SQLSearcher) Query(q string, zones []int) (ids []int, err error) { return nil, err } stmt = acc.RawPrepare("SELECT tid FROM replies WHERE MATCH(replies.content) AGAINST (? IN BOOLEAN MODE) AND tid IN(" + zList + ")") - err = acc.FirstError() - if err != nil { + if err = acc.FirstError(); err != nil { return nil, err } err = run(stmt.Query(q)) diff --git a/common/thaw.go b/common/thaw.go index 8f39d894..0bbdaf15 100644 --- a/common/thaw.go +++ b/common/thaw.go @@ -38,6 +38,21 @@ func (t *SingleServerThaw) Thaw() { } } +type TestThaw struct { +} + +func NewTestThaw() *TestThaw { + return &TestThaw{} +} +func (t *TestThaw) Thawed() bool { + return true +} +func (t *TestThaw) Thaw() { +} +func (t *TestThaw) Tick() error { + return nil +} + type DefaultThaw struct { thawed int64 } diff --git a/routes/attachments.go b/routes/attachments.go index 0f4bcecd..5706f0ee 100644 --- a/routes/attachments.go +++ b/routes/attachments.go @@ -8,50 +8,47 @@ import ( "strings" c "github.com/Azareal/Gosora/common" - qgen "github.com/Azareal/Gosora/query_gen" ) -type AttachmentStmts struct { - get *sql.Stmt -} - -var attachmentStmts AttachmentStmts - -// TODO: Abstract this with an attachment store -func init() { - c.DbInits.Add(func(acc *qgen.Accumulator) error { - attachmentStmts = AttachmentStmts{ - get: acc.Select("attachments").Columns("sectionID, sectionTable, originID, originTable, uploadedBy, path").Where("path=? AND sectionID=? AND sectionTable=?").Prepare(), - } - return acc.FirstError() - }) -} - var maxAgeYear = "max-age=" + strconv.Itoa(int(c.Year)) func ShowAttachment(w http.ResponseWriter, r *http.Request, u *c.User, filename string) c.RouteError { - filename = c.Stripslashes(filename) - ext := filepath.Ext("./attachs/" + filename) - if !c.AllowedFileExts.Contains(strings.TrimPrefix(ext, ".")) { - return c.LocalError("Bad extension", w, r, u) - } - sid, err := strconv.Atoi(r.FormValue("sid")) if err != nil { return c.LocalError("The sid is not an integer", w, r, u) } sectionTable := r.FormValue("stype") - var originTable string - var originID, uploadedBy int - err = attachmentStmts.get.QueryRow(filename, sid, sectionTable).Scan(&sid, §ionTable, &originID, &originTable, &uploadedBy, &filename) + filename = c.Stripslashes(filename) + if filename == "" { + return c.LocalError("Bad filename", w, r, u) + } + ext := filepath.Ext(filename) + if ext == "" || !c.AllowedFileExts.Contains(strings.TrimPrefix(ext, ".")) { + return c.LocalError("Bad extension", w, r, u) + } + + // TODO: Use the same hook table as upstream + hTbl := c.GetHookTable() + skip, rerr := c.H_route_attach_start_hook(hTbl, w, r, u, filename) + if skip || rerr != nil { + return rerr + } + + a, err := c.Attachments.GetForRenderRoute(filename, sid, sectionTable) + // ErrCorruptAttachPath is a possibility now if err == sql.ErrNoRows { return c.NotFound(w, r, nil) } else if err != nil { return c.InternalError(err, w, r) } - if sectionTable == "forums" { + skip, rerr = c.H_route_attach_post_get_hook(hTbl, w, r, u, a) + if skip || rerr != nil { + return rerr + } + + if a.SectionTable == "forums" { _, ferr := c.SimpleForumUserCheck(w, r, u, sid) if ferr != nil { return ferr @@ -63,7 +60,7 @@ func ShowAttachment(w http.ResponseWriter, r *http.Request, u *c.User, filename return c.LocalError("Unknown section", w, r, u) } - if originTable != "topics" && originTable != "replies" { + if a.OriginTable != "topics" && a.OriginTable != "replies" { return c.LocalError("Unknown origin", w, r, u) } @@ -89,11 +86,11 @@ func ShowAttachment(w http.ResponseWriter, r *http.Request, u *c.User, filename } func deleteAttachment(w http.ResponseWriter, r *http.Request, u *c.User, aid int, js bool) c.RouteError { - err := c.DeleteAttachment(aid) - if err == sql.ErrNoRows { + e := c.DeleteAttachment(aid) + if e == sql.ErrNoRows { return c.NotFoundJSQ(w, r, nil, js) - } else if err != nil { - return c.InternalErrorJSQ(err, w, r, js) + } else if e != nil { + return c.InternalErrorJSQ(e, w, r, js) } return nil }