Add GetForRenderRoute to DefaultAttachmentStore.

Use ErrCorruptAttachPath in DefaultAttachmentStore.
Consider the possibility that the requested attachment doesn't have a file extension.
Rename variables to reduce boilerplate.
Add TestThaw.
Avoid an allocation in hookgen.

Add route_attach_start hook.
Add route_attach_post_get hook.
This commit is contained in:
Azareal 2021-03-24 21:45:18 +10:00
parent adfed477a0
commit 26e8bf32a7
8 changed files with 160 additions and 137 deletions

View File

@ -1,60 +1,62 @@
package hookgen
import (
"bytes"
"log"
"os"
"bytes"
"text/template"
)
type HookVars struct {
Imports []string
Hooks []Hook
Hooks []Hook
}
type Hook struct {
Name string
Params string
Params2 string
Ret string
Type string
Any bool
MultiHook bool
Skip bool
Name string
Params string
Params2 string
Ret string
Type string
Any bool
MultiHook bool
Skip bool
DefaultRet string
Pure string
Pure string
}
func AddHooks(add func(name, params, ret, htype string, multiHook, skip bool, defaultRet, pure string)) {
vhookskip := func(name, params string) {
add(name,params,"(bool,RouteError)","VhookSkippable_",false,true,"false,nil","")
add(name, params, "(bool,RouteError)", "VhookSkippable_", false, true, "false,nil", "")
}
vhookskip("forum_check_pre_perms","w http.ResponseWriter,r *http.Request,u *User,fid *int,h *Header")
vhookskip("router_after_filters","w http.ResponseWriter,r *http.Request,prefix string")
vhookskip("router_pre_route","w http.ResponseWriter,r *http.Request,u *User,prefix string")
vhookskip("route_forum_list_start","w http.ResponseWriter,r *http.Request,u *User,h *Header")
vhookskip("route_topic_list_start","w http.ResponseWriter,r *http.Request,u *User,h *Header")
vhookskip("forum_check_pre_perms", "w http.ResponseWriter,r *http.Request,u *User,fid *int,h *Header")
vhookskip("router_after_filters", "w http.ResponseWriter,r *http.Request,prefix string")
vhookskip("router_pre_route", "w http.ResponseWriter,r *http.Request,u *User,prefix string")
vhookskip("route_forum_list_start", "w http.ResponseWriter,r *http.Request,u *User,h *Header")
vhookskip("route_topic_list_start", "w http.ResponseWriter,r *http.Request,u *User,h *Header")
vhookskip("route_attach_start", "w http.ResponseWriter,r *http.Request,u *User,fname string")
vhookskip("route_attach_post_get", "w http.ResponseWriter,r *http.Request,u *User,a *Attachment")
vhooknoret := func(name, params string) {
add(name,params,"","Vhooks",false,false,"false,nil","")
add(name, params, "", "Vhooks", false, false, "false,nil", "")
}
vhooknoret("router_end","w http.ResponseWriter,r *http.Request,u *User,prefix string, extraData string")
vhooknoret("topic_reply_row_assign","r *ReplyUser")
vhooknoret("router_end", "w http.ResponseWriter,r *http.Request,u *User,prefix string,extraData string")
vhooknoret("topic_reply_row_assign", "r *ReplyUser")
//forums_frow_assign
//Hook(name string, data interface{}) interface{}
/*hook := func(name, params, ret, pure string) {
add(name,params,ret,"Hooks",true,false,ret,pure)
}*/
hooknoret := func(name, params string) {
add(name,params,"","HooksNoRet",true,false,"","")
add(name, params, "", "HooksNoRet", true, false, "", "")
}
hooknoret("forums_frow_assign","f *Forum")
hooknoret("forums_frow_assign", "f *Forum")
hookskip := func(name, params string) {
add(name,params,"(skip bool)","HooksSkip",true,true,"","")
add(name, params, "(skip bool)", "HooksSkip", true, true, "", "")
}
//hookskip("forums_frow_assign","f *Forum")
hookskip("topic_create_frow_assign","f *Forum")
hookskip("topic_create_frow_assign", "f *Forum")
hookss := func(name string) {
add(name,"d string","string","Sshooks",true,false,"","d")
add(name, "d string", "string", "Sshooks", true, false, "", "d")
}
hookss("topic_ogdesc_assign")
}
@ -85,8 +87,8 @@ func H_{{.Name}}_hook(t *HookTable,{{.Params}}) {{.Ret}} { {{if .Any}}
if e := tmpl.Execute(&b, hookVars); e != nil {
log.Fatal(e)
}
err := writeFile("./common/gen_extend.go", string(b.Bytes()))
err := writeFile("./common/gen_extend.go", b.String())
if err != nil {
log.Fatal(err)
}
@ -104,4 +106,4 @@ func writeFile(name, body string) error {
return e
}
return f.Close()
}
}

View File

@ -14,6 +14,8 @@ import (
var Attachments AttachmentStore
var ErrCorruptAttachPath = errors.New("corrupt attachment path")
type MiniAttachment struct {
ID int
SectionID int
@ -41,6 +43,7 @@ type Attachment struct {
}
type AttachmentStore interface {
GetForRenderRoute(filename string, sid int, sectionTable string) (*Attachment, error)
FGet(id int) (*Attachment, error)
Get(id int) (*MiniAttachment, error)
MiniGetList(originTable string, originID int) (alist []*MiniAttachment, err error)
@ -58,6 +61,8 @@ type AttachmentStore interface {
}
type DefaultAttachmentStore struct {
getForRenderRoute *sql.Stmt
fget *sql.Stmt
get *sql.Stmt
getByObj *sql.Stmt
@ -76,6 +81,8 @@ type DefaultAttachmentStore struct {
func NewDefaultAttachmentStore(acc *qgen.Accumulator) (*DefaultAttachmentStore, error) {
a := "attachments"
return &DefaultAttachmentStore{
getForRenderRoute: acc.Select(a).Columns("sectionTable, originID, originTable, uploadedBy, path").Where("path=? AND sectionID=? AND sectionTable=?").Prepare(),
fget: acc.Select(a).Columns("originTable, originID, sectionTable, sectionID, uploadedBy, path, extra").Where("attachID=?").Prepare(),
get: acc.Select(a).Columns("originID, sectionID, uploadedBy, path, extra").Where("attachID=?").Prepare(),
getByObj: acc.Select(a).Columns("attachID, sectionID, uploadedBy, path, extra").Where("originTable=? AND originID=?").Prepare(),
@ -93,6 +100,15 @@ func NewDefaultAttachmentStore(acc *qgen.Accumulator) (*DefaultAttachmentStore,
}, acc.FirstError()
}
// TODO: Revamp this to make it less of a copy-paste from the original code in the route
// ! Lacks some attachment initialisation code
func (s *DefaultAttachmentStore) GetForRenderRoute(filename string, sid int, sectionTable string) (*Attachment, error) {
a := &Attachment{SectionID: sid}
e := s.getForRenderRoute.QueryRow(filename, sid, sectionTable).Scan(&a.SectionTable, &a.OriginID, &a.OriginTable, &a.UploadedBy, &a.Path)
// TODO: Initialise attachment struct fields?
return a, e
}
func (s *DefaultAttachmentStore) MiniGetList(originTable string, originID int) (alist []*MiniAttachment, err error) {
rows, err := s.getByObj.Query(originTable, originID)
defer rows.Close()
@ -104,7 +120,7 @@ func (s *DefaultAttachmentStore) MiniGetList(originTable string, originID int) (
}
a.Ext = strings.TrimPrefix(filepath.Ext(a.Path), ".")
if len(a.Ext) == 0 {
return nil, errors.New("corrupt attachment path")
return nil, ErrCorruptAttachPath
}
a.Image = ImageFileExts.Contains(a.Ext)
alist = append(alist, a)
@ -140,7 +156,7 @@ func (s *DefaultAttachmentStore) BulkMiniGetList(originTable string, ids []int)
}
a.Ext = strings.TrimPrefix(filepath.Ext(a.Path), ".")
if len(a.Ext) == 0 {
return nil, errors.New("corrupt attachment path")
return nil, ErrCorruptAttachPath
}
a.Image = ImageFileExts.Contains(a.Ext)
if currentID == 0 {
@ -169,7 +185,7 @@ func (s *DefaultAttachmentStore) FGet(id int) (*Attachment, error) {
}
a.Ext = strings.TrimPrefix(filepath.Ext(a.Path), ".")
if len(a.Ext) == 0 {
return nil, errors.New("corrupt attachment path")
return nil, ErrCorruptAttachPath
}
a.Image = ImageFileExts.Contains(a.Ext)
return a, nil
@ -183,7 +199,7 @@ func (s *DefaultAttachmentStore) Get(id int) (*MiniAttachment, error) {
}
a.Ext = strings.TrimPrefix(filepath.Ext(a.Path), ".")
if len(a.Ext) == 0 {
return nil, errors.New("corrupt attachment path")
return nil, ErrCorruptAttachPath
}
a.Image = ImageFileExts.Contains(a.Ext)
return a, nil
@ -209,32 +225,32 @@ func (s *DefaultAttachmentStore) MoveToByExtra(sectionID int, originTable, extra
}
func (s *DefaultAttachmentStore) Count() (count int) {
err := s.count.QueryRow().Scan(&count)
if err != nil {
LogError(err)
e := s.count.QueryRow().Scan(&count)
if e != nil {
LogError(e)
}
return count
}
func (s *DefaultAttachmentStore) CountIn(originTable string, oid int) (count int) {
err := s.countIn.QueryRow(originTable, oid).Scan(&count)
if err != nil {
LogError(err)
e := s.countIn.QueryRow(originTable, oid).Scan(&count)
if e != nil {
LogError(e)
}
return count
}
func (s *DefaultAttachmentStore) CountInPath(path string) (count int) {
err := s.countInPath.QueryRow(path).Scan(&count)
if err != nil {
LogError(err)
e := s.countInPath.QueryRow(path).Scan(&count)
if e != nil {
LogError(e)
}
return count
}
func (s *DefaultAttachmentStore) Delete(id int) error {
_, err := s.delete.Exec(id)
return err
_, e := s.delete.Exec(id)
return e
}
// TODO: Split this out of this store
@ -256,10 +272,7 @@ func (s *DefaultAttachmentStore) AddLinked(otable string, oid int) (err error) {
if err == sql.ErrNoRows {
err = nil
}
if err != nil {
return err
}
return nil
return err
}
// TODO: Split this out of this store
@ -280,10 +293,7 @@ func (s *DefaultAttachmentStore) RemoveLinked(otable string, oid int) (err error
}
err = Rstore.GetCache().Remove(oid)
}
if err != nil {
return err
}
return nil
return err
}
// TODO: Add a table for the files and lock the file row when performing tasks related to the file

View File

@ -1,7 +1,7 @@
/*
*
* Gosora Plugin System
* Copyright Azareal 2016 - 2020
* Copyright Azareal 2016 - 2021
*
*/
package common
@ -25,9 +25,9 @@ type PluginList map[string]*Plugin
// TODO: Have a proper store rather than a map?
var Plugins PluginList = make(map[string]*Plugin)
func (list PluginList) Add(pl *Plugin) {
func (l PluginList) Add(pl *Plugin) {
buildPlugin(pl)
list[pl.UName] = pl
l[pl.UName] = pl
}
func buildPlugin(pl *Plugin) {
@ -94,6 +94,8 @@ var hookTable = &HookTable{
"route_topic_list_start": nil,
"route_topic_list_mostviewed_start": nil,
"route_forum_list_start": nil,
"route_attach_start": nil,
"route_attach_post_get": nil,
"action_end_create_topic": nil,
"action_end_edit_topic": nil,
@ -126,8 +128,8 @@ var hookTable = &HookTable{
"tasks_tick_widget_wol": nil,
},
map[string][]func(string) string{
"preparse_preassign": nil,
"parse_assign": nil,
"preparse_preassign": nil,
"parse_assign": nil,
"topic_ogdesc_assign": nil,
},
nil,

View File

@ -9,7 +9,7 @@ import (
var IPSearch IPSearcher
type IPSearcher interface {
Lookup(ip string) (uids []int, err error)
Lookup(ip string) (uids []int, e error)
}
type DefaultIPSearcher struct {
@ -23,27 +23,29 @@ type DefaultIPSearcher struct {
func NewDefaultIPSearcher() (*DefaultIPSearcher, error) {
acc := qgen.NewAcc()
uu := "users"
q := func(tbl string) *sql.Stmt {
return acc.Select(uu).Columns("uid").InQ("uid", acc.Select(tbl).Columns("createdBy").Where("ip=?")).Prepare()
}
return &DefaultIPSearcher{
searchUsers: acc.Select(uu).Columns("uid").Where("last_ip=? OR last_ip LIKE CONCAT('%-',?)").Prepare(),
searchTopics: acc.Select(uu).Columns("uid").InQ("uid", acc.Select("topics").Columns("createdBy").Where("ip=?")).Prepare(),
searchReplies: acc.Select(uu).Columns("uid").InQ("uid", acc.Select("replies").Columns("createdBy").Where("ip=?")).Prepare(),
searchUsersReplies: acc.Select(uu).Columns("uid").InQ("uid", acc.Select("users_replies").Columns("createdBy").Where("ip=?")).Prepare(),
searchTopics: q("topics"),
searchReplies: q("replies"),
searchUsersReplies: q("users_replies"),
}, acc.FirstError()
}
func (s *DefaultIPSearcher) Lookup(ip string) (uids []int, err error) {
func (s *DefaultIPSearcher) Lookup(ip string) (uids []int, e error) {
var uid int
reqUserList := make(map[int]bool)
runQuery2 := func(rows *sql.Rows, err error) error {
if err != nil {
return err
runQuery2 := func(rows *sql.Rows, e error) error {
if e != nil {
return e
}
defer rows.Close()
for rows.Next() {
err := rows.Scan(&uid)
if err != nil {
return err
if e := rows.Scan(&uid); e != nil {
return e
}
reqUserList[uid] = true
}
@ -53,21 +55,21 @@ func (s *DefaultIPSearcher) Lookup(ip string) (uids []int, err error) {
return runQuery2(stmt.Query(ip))
}
err = runQuery2(s.searchUsers.Query(ip, ip))
if err != nil {
return uids, err
e = runQuery2(s.searchUsers.Query(ip, ip))
if e != nil {
return uids, e
}
err = runQuery(s.searchTopics)
if err != nil {
return uids, err
e = runQuery(s.searchTopics)
if e != nil {
return uids, e
}
err = runQuery(s.searchReplies)
if err != nil {
return uids, err
e = runQuery(s.searchReplies)
if e != nil {
return uids, e
}
err = runQuery(s.searchUsersReplies)
if err != nil {
return uids, err
e = runQuery(s.searchUsersReplies)
if e != nil {
return uids, e
}
// Convert the user ID map to a slice, then bulk load the users

View File

@ -44,13 +44,13 @@ func simpleForumUserCheck(w http.ResponseWriter, r *http.Request, u *User, fid i
return h, rerr
}
fperms, err := FPStore.Get(fid, u.Group)
fp, err := FPStore.Get(fid, u.Group)
if err == ErrNoRows {
fperms = BlankForumPerms()
fp = BlankForumPerms()
} else if err != nil {
return h, InternalError(err, w, r)
}
cascadeForumPerms(fperms, u)
cascadeForumPerms(fp, u)
return h, nil
}
@ -72,13 +72,13 @@ func forumUserCheck(h *Header, w http.ResponseWriter, r *http.Request, u *User,
return rerr
}
fperms, err := FPStore.Get(fid, u.Group)
fp, err := FPStore.Get(fid, u.Group)
if err == ErrNoRows {
fperms = BlankForumPerms()
fp = BlankForumPerms()
} else if err != nil {
return InternalError(err, w, r)
}
cascadeForumPerms(fperms, u)
cascadeForumPerms(fp, u)
h.CurrentUser = u // TODO: Use a pointer instead for CurrentUser, so we don't have to do this
return rerr
}

View File

@ -44,19 +44,18 @@ func NewSQLSearcher(acc *qgen.Accumulator) (*SQLSearcher, error) {
func (s *SQLSearcher) queryAll(q string) ([]int, error) {
var ids []int
run := func(stmt *sql.Stmt, q ...interface{}) error {
rows, err := stmt.Query(q...)
if err == sql.ErrNoRows {
rows, e := stmt.Query(q...)
if e == sql.ErrNoRows {
return nil
} else if err != nil {
return err
} else if e != nil {
return e
}
defer rows.Close()
for rows.Next() {
var id int
err := rows.Scan(&id)
if err != nil {
return err
if e := rows.Scan(&id); e != nil {
return e
}
ids = append(ids, id)
}
@ -81,19 +80,18 @@ func (s *SQLSearcher) Query(q string, zones []int) (ids []int, err error) {
if len(zones) == 0 {
return nil, nil
}
run := func(rows *sql.Rows, err error) error {
/*if err == sql.ErrNoRows {
run := func(rows *sql.Rows, e error) error {
/*if e == sql.ErrNoRows {
return nil
} else */if err != nil {
return err
} else */if e != nil {
return e
}
defer rows.Close()
for rows.Next() {
var id int
err := rows.Scan(&id)
if err != nil {
return err
if e := rows.Scan(&id); e != nil {
return e
}
ids = append(ids, id)
}
@ -116,14 +114,12 @@ func (s *SQLSearcher) Query(q string, zones []int) (ids []int, err error) {
acc := qgen.NewAcc()
/*stmt := acc.RawPrepare("SELECT topics.tid FROM topics INNER JOIN replies ON topics.tid = replies.tid WHERE (MATCH(topics.title) AGAINST (? IN BOOLEAN MODE) OR MATCH(topics.content) AGAINST (? IN BOOLEAN MODE) OR MATCH(replies.content) AGAINST (? IN BOOLEAN MODE) OR topics.title=? OR topics.content=? OR replies.content=?) AND topics.parentID IN(" + zList + ")")
err = acc.FirstError()
if err != nil {
if err = acc.FirstError(); err != nil {
return nil, err
}*/
// TODO: Cache common IN counts
stmt := acc.RawPrepare("SELECT tid FROM topics WHERE (MATCH(topics.title) AGAINST (? IN BOOLEAN MODE) OR MATCH(topics.content) AGAINST (? IN BOOLEAN MODE)) AND parentID IN(" + zList + ")")
err = acc.FirstError()
if err != nil {
if err = acc.FirstError(); err != nil {
return nil, err
}
err = run(stmt.Query(q, q))
@ -131,8 +127,7 @@ func (s *SQLSearcher) Query(q string, zones []int) (ids []int, err error) {
return nil, err
}
stmt = acc.RawPrepare("SELECT tid FROM replies WHERE MATCH(replies.content) AGAINST (? IN BOOLEAN MODE) AND tid IN(" + zList + ")")
err = acc.FirstError()
if err != nil {
if err = acc.FirstError(); err != nil {
return nil, err
}
err = run(stmt.Query(q))

View File

@ -38,6 +38,21 @@ func (t *SingleServerThaw) Thaw() {
}
}
type TestThaw struct {
}
func NewTestThaw() *TestThaw {
return &TestThaw{}
}
func (t *TestThaw) Thawed() bool {
return true
}
func (t *TestThaw) Thaw() {
}
func (t *TestThaw) Tick() error {
return nil
}
type DefaultThaw struct {
thawed int64
}

View File

@ -8,50 +8,47 @@ import (
"strings"
c "github.com/Azareal/Gosora/common"
qgen "github.com/Azareal/Gosora/query_gen"
)
type AttachmentStmts struct {
get *sql.Stmt
}
var attachmentStmts AttachmentStmts
// TODO: Abstract this with an attachment store
func init() {
c.DbInits.Add(func(acc *qgen.Accumulator) error {
attachmentStmts = AttachmentStmts{
get: acc.Select("attachments").Columns("sectionID, sectionTable, originID, originTable, uploadedBy, path").Where("path=? AND sectionID=? AND sectionTable=?").Prepare(),
}
return acc.FirstError()
})
}
var maxAgeYear = "max-age=" + strconv.Itoa(int(c.Year))
func ShowAttachment(w http.ResponseWriter, r *http.Request, u *c.User, filename string) c.RouteError {
filename = c.Stripslashes(filename)
ext := filepath.Ext("./attachs/" + filename)
if !c.AllowedFileExts.Contains(strings.TrimPrefix(ext, ".")) {
return c.LocalError("Bad extension", w, r, u)
}
sid, err := strconv.Atoi(r.FormValue("sid"))
if err != nil {
return c.LocalError("The sid is not an integer", w, r, u)
}
sectionTable := r.FormValue("stype")
var originTable string
var originID, uploadedBy int
err = attachmentStmts.get.QueryRow(filename, sid, sectionTable).Scan(&sid, &sectionTable, &originID, &originTable, &uploadedBy, &filename)
filename = c.Stripslashes(filename)
if filename == "" {
return c.LocalError("Bad filename", w, r, u)
}
ext := filepath.Ext(filename)
if ext == "" || !c.AllowedFileExts.Contains(strings.TrimPrefix(ext, ".")) {
return c.LocalError("Bad extension", w, r, u)
}
// TODO: Use the same hook table as upstream
hTbl := c.GetHookTable()
skip, rerr := c.H_route_attach_start_hook(hTbl, w, r, u, filename)
if skip || rerr != nil {
return rerr
}
a, err := c.Attachments.GetForRenderRoute(filename, sid, sectionTable)
// ErrCorruptAttachPath is a possibility now
if err == sql.ErrNoRows {
return c.NotFound(w, r, nil)
} else if err != nil {
return c.InternalError(err, w, r)
}
if sectionTable == "forums" {
skip, rerr = c.H_route_attach_post_get_hook(hTbl, w, r, u, a)
if skip || rerr != nil {
return rerr
}
if a.SectionTable == "forums" {
_, ferr := c.SimpleForumUserCheck(w, r, u, sid)
if ferr != nil {
return ferr
@ -63,7 +60,7 @@ func ShowAttachment(w http.ResponseWriter, r *http.Request, u *c.User, filename
return c.LocalError("Unknown section", w, r, u)
}
if originTable != "topics" && originTable != "replies" {
if a.OriginTable != "topics" && a.OriginTable != "replies" {
return c.LocalError("Unknown origin", w, r, u)
}
@ -89,11 +86,11 @@ func ShowAttachment(w http.ResponseWriter, r *http.Request, u *c.User, filename
}
func deleteAttachment(w http.ResponseWriter, r *http.Request, u *c.User, aid int, js bool) c.RouteError {
err := c.DeleteAttachment(aid)
if err == sql.ErrNoRows {
e := c.DeleteAttachment(aid)
if e == sql.ErrNoRows {
return c.NotFoundJSQ(w, r, nil, js)
} else if err != nil {
return c.InternalErrorJSQ(err, w, r, js)
} else if e != nil {
return c.InternalErrorJSQ(e, w, r, js)
}
return nil
}