Add GetForRenderRoute to DefaultAttachmentStore.

Use ErrCorruptAttachPath in DefaultAttachmentStore.
Consider the possibility that the requested attachment doesn't have a file extension.
Rename variables to reduce boilerplate.
Add TestThaw.
Avoid an allocation in hookgen.

Add route_attach_start hook.
Add route_attach_post_get hook.
This commit is contained in:
Azareal 2021-03-24 21:45:18 +10:00
parent adfed477a0
commit 26e8bf32a7
8 changed files with 160 additions and 137 deletions

View File

@ -1,60 +1,62 @@
package hookgen package hookgen
import ( import (
"bytes"
"log" "log"
"os" "os"
"bytes"
"text/template" "text/template"
) )
type HookVars struct { type HookVars struct {
Imports []string Imports []string
Hooks []Hook Hooks []Hook
} }
type Hook struct { type Hook struct {
Name string Name string
Params string Params string
Params2 string Params2 string
Ret string Ret string
Type string Type string
Any bool Any bool
MultiHook bool MultiHook bool
Skip bool Skip bool
DefaultRet string DefaultRet string
Pure string Pure string
} }
func AddHooks(add func(name, params, ret, htype string, multiHook, skip bool, defaultRet, pure string)) { func AddHooks(add func(name, params, ret, htype string, multiHook, skip bool, defaultRet, pure string)) {
vhookskip := func(name, params string) { vhookskip := func(name, params string) {
add(name,params,"(bool,RouteError)","VhookSkippable_",false,true,"false,nil","") add(name, params, "(bool,RouteError)", "VhookSkippable_", false, true, "false,nil", "")
} }
vhookskip("forum_check_pre_perms","w http.ResponseWriter,r *http.Request,u *User,fid *int,h *Header") vhookskip("forum_check_pre_perms", "w http.ResponseWriter,r *http.Request,u *User,fid *int,h *Header")
vhookskip("router_after_filters","w http.ResponseWriter,r *http.Request,prefix string") vhookskip("router_after_filters", "w http.ResponseWriter,r *http.Request,prefix string")
vhookskip("router_pre_route","w http.ResponseWriter,r *http.Request,u *User,prefix string") vhookskip("router_pre_route", "w http.ResponseWriter,r *http.Request,u *User,prefix string")
vhookskip("route_forum_list_start","w http.ResponseWriter,r *http.Request,u *User,h *Header") vhookskip("route_forum_list_start", "w http.ResponseWriter,r *http.Request,u *User,h *Header")
vhookskip("route_topic_list_start","w http.ResponseWriter,r *http.Request,u *User,h *Header") vhookskip("route_topic_list_start", "w http.ResponseWriter,r *http.Request,u *User,h *Header")
vhookskip("route_attach_start", "w http.ResponseWriter,r *http.Request,u *User,fname string")
vhookskip("route_attach_post_get", "w http.ResponseWriter,r *http.Request,u *User,a *Attachment")
vhooknoret := func(name, params string) { vhooknoret := func(name, params string) {
add(name,params,"","Vhooks",false,false,"false,nil","") add(name, params, "", "Vhooks", false, false, "false,nil", "")
} }
vhooknoret("router_end","w http.ResponseWriter,r *http.Request,u *User,prefix string, extraData string") vhooknoret("router_end", "w http.ResponseWriter,r *http.Request,u *User,prefix string,extraData string")
vhooknoret("topic_reply_row_assign","r *ReplyUser") vhooknoret("topic_reply_row_assign", "r *ReplyUser")
//forums_frow_assign //forums_frow_assign
//Hook(name string, data interface{}) interface{} //Hook(name string, data interface{}) interface{}
/*hook := func(name, params, ret, pure string) { /*hook := func(name, params, ret, pure string) {
add(name,params,ret,"Hooks",true,false,ret,pure) add(name,params,ret,"Hooks",true,false,ret,pure)
}*/ }*/
hooknoret := func(name, params string) { hooknoret := func(name, params string) {
add(name,params,"","HooksNoRet",true,false,"","") add(name, params, "", "HooksNoRet", true, false, "", "")
} }
hooknoret("forums_frow_assign","f *Forum") hooknoret("forums_frow_assign", "f *Forum")
hookskip := func(name, params string) { hookskip := func(name, params string) {
add(name,params,"(skip bool)","HooksSkip",true,true,"","") add(name, params, "(skip bool)", "HooksSkip", true, true, "", "")
} }
//hookskip("forums_frow_assign","f *Forum") //hookskip("forums_frow_assign","f *Forum")
hookskip("topic_create_frow_assign","f *Forum") hookskip("topic_create_frow_assign", "f *Forum")
hookss := func(name string) { hookss := func(name string) {
add(name,"d string","string","Sshooks",true,false,"","d") add(name, "d string", "string", "Sshooks", true, false, "", "d")
} }
hookss("topic_ogdesc_assign") hookss("topic_ogdesc_assign")
} }
@ -85,8 +87,8 @@ func H_{{.Name}}_hook(t *HookTable,{{.Params}}) {{.Ret}} { {{if .Any}}
if e := tmpl.Execute(&b, hookVars); e != nil { if e := tmpl.Execute(&b, hookVars); e != nil {
log.Fatal(e) log.Fatal(e)
} }
err := writeFile("./common/gen_extend.go", string(b.Bytes())) err := writeFile("./common/gen_extend.go", b.String())
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
@ -104,4 +106,4 @@ func writeFile(name, body string) error {
return e return e
} }
return f.Close() return f.Close()
} }

View File

@ -14,6 +14,8 @@ import (
var Attachments AttachmentStore var Attachments AttachmentStore
var ErrCorruptAttachPath = errors.New("corrupt attachment path")
type MiniAttachment struct { type MiniAttachment struct {
ID int ID int
SectionID int SectionID int
@ -41,6 +43,7 @@ type Attachment struct {
} }
type AttachmentStore interface { type AttachmentStore interface {
GetForRenderRoute(filename string, sid int, sectionTable string) (*Attachment, error)
FGet(id int) (*Attachment, error) FGet(id int) (*Attachment, error)
Get(id int) (*MiniAttachment, error) Get(id int) (*MiniAttachment, error)
MiniGetList(originTable string, originID int) (alist []*MiniAttachment, err error) MiniGetList(originTable string, originID int) (alist []*MiniAttachment, err error)
@ -58,6 +61,8 @@ type AttachmentStore interface {
} }
type DefaultAttachmentStore struct { type DefaultAttachmentStore struct {
getForRenderRoute *sql.Stmt
fget *sql.Stmt fget *sql.Stmt
get *sql.Stmt get *sql.Stmt
getByObj *sql.Stmt getByObj *sql.Stmt
@ -76,6 +81,8 @@ type DefaultAttachmentStore struct {
func NewDefaultAttachmentStore(acc *qgen.Accumulator) (*DefaultAttachmentStore, error) { func NewDefaultAttachmentStore(acc *qgen.Accumulator) (*DefaultAttachmentStore, error) {
a := "attachments" a := "attachments"
return &DefaultAttachmentStore{ return &DefaultAttachmentStore{
getForRenderRoute: acc.Select(a).Columns("sectionTable, originID, originTable, uploadedBy, path").Where("path=? AND sectionID=? AND sectionTable=?").Prepare(),
fget: acc.Select(a).Columns("originTable, originID, sectionTable, sectionID, uploadedBy, path, extra").Where("attachID=?").Prepare(), fget: acc.Select(a).Columns("originTable, originID, sectionTable, sectionID, uploadedBy, path, extra").Where("attachID=?").Prepare(),
get: acc.Select(a).Columns("originID, sectionID, uploadedBy, path, extra").Where("attachID=?").Prepare(), get: acc.Select(a).Columns("originID, sectionID, uploadedBy, path, extra").Where("attachID=?").Prepare(),
getByObj: acc.Select(a).Columns("attachID, sectionID, uploadedBy, path, extra").Where("originTable=? AND originID=?").Prepare(), getByObj: acc.Select(a).Columns("attachID, sectionID, uploadedBy, path, extra").Where("originTable=? AND originID=?").Prepare(),
@ -93,6 +100,15 @@ func NewDefaultAttachmentStore(acc *qgen.Accumulator) (*DefaultAttachmentStore,
}, acc.FirstError() }, acc.FirstError()
} }
// TODO: Revamp this to make it less of a copy-paste from the original code in the route
// ! Lacks some attachment initialisation code
func (s *DefaultAttachmentStore) GetForRenderRoute(filename string, sid int, sectionTable string) (*Attachment, error) {
a := &Attachment{SectionID: sid}
e := s.getForRenderRoute.QueryRow(filename, sid, sectionTable).Scan(&a.SectionTable, &a.OriginID, &a.OriginTable, &a.UploadedBy, &a.Path)
// TODO: Initialise attachment struct fields?
return a, e
}
func (s *DefaultAttachmentStore) MiniGetList(originTable string, originID int) (alist []*MiniAttachment, err error) { func (s *DefaultAttachmentStore) MiniGetList(originTable string, originID int) (alist []*MiniAttachment, err error) {
rows, err := s.getByObj.Query(originTable, originID) rows, err := s.getByObj.Query(originTable, originID)
defer rows.Close() defer rows.Close()
@ -104,7 +120,7 @@ func (s *DefaultAttachmentStore) MiniGetList(originTable string, originID int) (
} }
a.Ext = strings.TrimPrefix(filepath.Ext(a.Path), ".") a.Ext = strings.TrimPrefix(filepath.Ext(a.Path), ".")
if len(a.Ext) == 0 { if len(a.Ext) == 0 {
return nil, errors.New("corrupt attachment path") return nil, ErrCorruptAttachPath
} }
a.Image = ImageFileExts.Contains(a.Ext) a.Image = ImageFileExts.Contains(a.Ext)
alist = append(alist, a) alist = append(alist, a)
@ -140,7 +156,7 @@ func (s *DefaultAttachmentStore) BulkMiniGetList(originTable string, ids []int)
} }
a.Ext = strings.TrimPrefix(filepath.Ext(a.Path), ".") a.Ext = strings.TrimPrefix(filepath.Ext(a.Path), ".")
if len(a.Ext) == 0 { if len(a.Ext) == 0 {
return nil, errors.New("corrupt attachment path") return nil, ErrCorruptAttachPath
} }
a.Image = ImageFileExts.Contains(a.Ext) a.Image = ImageFileExts.Contains(a.Ext)
if currentID == 0 { if currentID == 0 {
@ -169,7 +185,7 @@ func (s *DefaultAttachmentStore) FGet(id int) (*Attachment, error) {
} }
a.Ext = strings.TrimPrefix(filepath.Ext(a.Path), ".") a.Ext = strings.TrimPrefix(filepath.Ext(a.Path), ".")
if len(a.Ext) == 0 { if len(a.Ext) == 0 {
return nil, errors.New("corrupt attachment path") return nil, ErrCorruptAttachPath
} }
a.Image = ImageFileExts.Contains(a.Ext) a.Image = ImageFileExts.Contains(a.Ext)
return a, nil return a, nil
@ -183,7 +199,7 @@ func (s *DefaultAttachmentStore) Get(id int) (*MiniAttachment, error) {
} }
a.Ext = strings.TrimPrefix(filepath.Ext(a.Path), ".") a.Ext = strings.TrimPrefix(filepath.Ext(a.Path), ".")
if len(a.Ext) == 0 { if len(a.Ext) == 0 {
return nil, errors.New("corrupt attachment path") return nil, ErrCorruptAttachPath
} }
a.Image = ImageFileExts.Contains(a.Ext) a.Image = ImageFileExts.Contains(a.Ext)
return a, nil return a, nil
@ -209,32 +225,32 @@ func (s *DefaultAttachmentStore) MoveToByExtra(sectionID int, originTable, extra
} }
func (s *DefaultAttachmentStore) Count() (count int) { func (s *DefaultAttachmentStore) Count() (count int) {
err := s.count.QueryRow().Scan(&count) e := s.count.QueryRow().Scan(&count)
if err != nil { if e != nil {
LogError(err) LogError(e)
} }
return count return count
} }
func (s *DefaultAttachmentStore) CountIn(originTable string, oid int) (count int) { func (s *DefaultAttachmentStore) CountIn(originTable string, oid int) (count int) {
err := s.countIn.QueryRow(originTable, oid).Scan(&count) e := s.countIn.QueryRow(originTable, oid).Scan(&count)
if err != nil { if e != nil {
LogError(err) LogError(e)
} }
return count return count
} }
func (s *DefaultAttachmentStore) CountInPath(path string) (count int) { func (s *DefaultAttachmentStore) CountInPath(path string) (count int) {
err := s.countInPath.QueryRow(path).Scan(&count) e := s.countInPath.QueryRow(path).Scan(&count)
if err != nil { if e != nil {
LogError(err) LogError(e)
} }
return count return count
} }
func (s *DefaultAttachmentStore) Delete(id int) error { func (s *DefaultAttachmentStore) Delete(id int) error {
_, err := s.delete.Exec(id) _, e := s.delete.Exec(id)
return err return e
} }
// TODO: Split this out of this store // TODO: Split this out of this store
@ -256,10 +272,7 @@ func (s *DefaultAttachmentStore) AddLinked(otable string, oid int) (err error) {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
err = nil err = nil
} }
if err != nil { return err
return err
}
return nil
} }
// TODO: Split this out of this store // TODO: Split this out of this store
@ -280,10 +293,7 @@ func (s *DefaultAttachmentStore) RemoveLinked(otable string, oid int) (err error
} }
err = Rstore.GetCache().Remove(oid) err = Rstore.GetCache().Remove(oid)
} }
if err != nil { return err
return err
}
return nil
} }
// TODO: Add a table for the files and lock the file row when performing tasks related to the file // TODO: Add a table for the files and lock the file row when performing tasks related to the file

View File

@ -1,7 +1,7 @@
/* /*
* *
* Gosora Plugin System * Gosora Plugin System
* Copyright Azareal 2016 - 2020 * Copyright Azareal 2016 - 2021
* *
*/ */
package common package common
@ -25,9 +25,9 @@ type PluginList map[string]*Plugin
// TODO: Have a proper store rather than a map? // TODO: Have a proper store rather than a map?
var Plugins PluginList = make(map[string]*Plugin) var Plugins PluginList = make(map[string]*Plugin)
func (list PluginList) Add(pl *Plugin) { func (l PluginList) Add(pl *Plugin) {
buildPlugin(pl) buildPlugin(pl)
list[pl.UName] = pl l[pl.UName] = pl
} }
func buildPlugin(pl *Plugin) { func buildPlugin(pl *Plugin) {
@ -94,6 +94,8 @@ var hookTable = &HookTable{
"route_topic_list_start": nil, "route_topic_list_start": nil,
"route_topic_list_mostviewed_start": nil, "route_topic_list_mostviewed_start": nil,
"route_forum_list_start": nil, "route_forum_list_start": nil,
"route_attach_start": nil,
"route_attach_post_get": nil,
"action_end_create_topic": nil, "action_end_create_topic": nil,
"action_end_edit_topic": nil, "action_end_edit_topic": nil,
@ -126,8 +128,8 @@ var hookTable = &HookTable{
"tasks_tick_widget_wol": nil, "tasks_tick_widget_wol": nil,
}, },
map[string][]func(string) string{ map[string][]func(string) string{
"preparse_preassign": nil, "preparse_preassign": nil,
"parse_assign": nil, "parse_assign": nil,
"topic_ogdesc_assign": nil, "topic_ogdesc_assign": nil,
}, },
nil, nil,

View File

@ -9,7 +9,7 @@ import (
var IPSearch IPSearcher var IPSearch IPSearcher
type IPSearcher interface { type IPSearcher interface {
Lookup(ip string) (uids []int, err error) Lookup(ip string) (uids []int, e error)
} }
type DefaultIPSearcher struct { type DefaultIPSearcher struct {
@ -23,27 +23,29 @@ type DefaultIPSearcher struct {
func NewDefaultIPSearcher() (*DefaultIPSearcher, error) { func NewDefaultIPSearcher() (*DefaultIPSearcher, error) {
acc := qgen.NewAcc() acc := qgen.NewAcc()
uu := "users" uu := "users"
q := func(tbl string) *sql.Stmt {
return acc.Select(uu).Columns("uid").InQ("uid", acc.Select(tbl).Columns("createdBy").Where("ip=?")).Prepare()
}
return &DefaultIPSearcher{ return &DefaultIPSearcher{
searchUsers: acc.Select(uu).Columns("uid").Where("last_ip=? OR last_ip LIKE CONCAT('%-',?)").Prepare(), searchUsers: acc.Select(uu).Columns("uid").Where("last_ip=? OR last_ip LIKE CONCAT('%-',?)").Prepare(),
searchTopics: acc.Select(uu).Columns("uid").InQ("uid", acc.Select("topics").Columns("createdBy").Where("ip=?")).Prepare(), searchTopics: q("topics"),
searchReplies: acc.Select(uu).Columns("uid").InQ("uid", acc.Select("replies").Columns("createdBy").Where("ip=?")).Prepare(), searchReplies: q("replies"),
searchUsersReplies: acc.Select(uu).Columns("uid").InQ("uid", acc.Select("users_replies").Columns("createdBy").Where("ip=?")).Prepare(), searchUsersReplies: q("users_replies"),
}, acc.FirstError() }, acc.FirstError()
} }
func (s *DefaultIPSearcher) Lookup(ip string) (uids []int, err error) { func (s *DefaultIPSearcher) Lookup(ip string) (uids []int, e error) {
var uid int var uid int
reqUserList := make(map[int]bool) reqUserList := make(map[int]bool)
runQuery2 := func(rows *sql.Rows, err error) error { runQuery2 := func(rows *sql.Rows, e error) error {
if err != nil { if e != nil {
return err return e
} }
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
err := rows.Scan(&uid) if e := rows.Scan(&uid); e != nil {
if err != nil { return e
return err
} }
reqUserList[uid] = true reqUserList[uid] = true
} }
@ -53,21 +55,21 @@ func (s *DefaultIPSearcher) Lookup(ip string) (uids []int, err error) {
return runQuery2(stmt.Query(ip)) return runQuery2(stmt.Query(ip))
} }
err = runQuery2(s.searchUsers.Query(ip, ip)) e = runQuery2(s.searchUsers.Query(ip, ip))
if err != nil { if e != nil {
return uids, err return uids, e
} }
err = runQuery(s.searchTopics) e = runQuery(s.searchTopics)
if err != nil { if e != nil {
return uids, err return uids, e
} }
err = runQuery(s.searchReplies) e = runQuery(s.searchReplies)
if err != nil { if e != nil {
return uids, err return uids, e
} }
err = runQuery(s.searchUsersReplies) e = runQuery(s.searchUsersReplies)
if err != nil { if e != nil {
return uids, err return uids, e
} }
// Convert the user ID map to a slice, then bulk load the users // Convert the user ID map to a slice, then bulk load the users

View File

@ -44,13 +44,13 @@ func simpleForumUserCheck(w http.ResponseWriter, r *http.Request, u *User, fid i
return h, rerr return h, rerr
} }
fperms, err := FPStore.Get(fid, u.Group) fp, err := FPStore.Get(fid, u.Group)
if err == ErrNoRows { if err == ErrNoRows {
fperms = BlankForumPerms() fp = BlankForumPerms()
} else if err != nil { } else if err != nil {
return h, InternalError(err, w, r) return h, InternalError(err, w, r)
} }
cascadeForumPerms(fperms, u) cascadeForumPerms(fp, u)
return h, nil return h, nil
} }
@ -72,13 +72,13 @@ func forumUserCheck(h *Header, w http.ResponseWriter, r *http.Request, u *User,
return rerr return rerr
} }
fperms, err := FPStore.Get(fid, u.Group) fp, err := FPStore.Get(fid, u.Group)
if err == ErrNoRows { if err == ErrNoRows {
fperms = BlankForumPerms() fp = BlankForumPerms()
} else if err != nil { } else if err != nil {
return InternalError(err, w, r) return InternalError(err, w, r)
} }
cascadeForumPerms(fperms, u) cascadeForumPerms(fp, u)
h.CurrentUser = u // TODO: Use a pointer instead for CurrentUser, so we don't have to do this h.CurrentUser = u // TODO: Use a pointer instead for CurrentUser, so we don't have to do this
return rerr return rerr
} }

View File

@ -44,19 +44,18 @@ func NewSQLSearcher(acc *qgen.Accumulator) (*SQLSearcher, error) {
func (s *SQLSearcher) queryAll(q string) ([]int, error) { func (s *SQLSearcher) queryAll(q string) ([]int, error) {
var ids []int var ids []int
run := func(stmt *sql.Stmt, q ...interface{}) error { run := func(stmt *sql.Stmt, q ...interface{}) error {
rows, err := stmt.Query(q...) rows, e := stmt.Query(q...)
if err == sql.ErrNoRows { if e == sql.ErrNoRows {
return nil return nil
} else if err != nil { } else if e != nil {
return err return e
} }
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
var id int var id int
err := rows.Scan(&id) if e := rows.Scan(&id); e != nil {
if err != nil { return e
return err
} }
ids = append(ids, id) ids = append(ids, id)
} }
@ -81,19 +80,18 @@ func (s *SQLSearcher) Query(q string, zones []int) (ids []int, err error) {
if len(zones) == 0 { if len(zones) == 0 {
return nil, nil return nil, nil
} }
run := func(rows *sql.Rows, err error) error { run := func(rows *sql.Rows, e error) error {
/*if err == sql.ErrNoRows { /*if e == sql.ErrNoRows {
return nil return nil
} else */if err != nil { } else */if e != nil {
return err return e
} }
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
var id int var id int
err := rows.Scan(&id) if e := rows.Scan(&id); e != nil {
if err != nil { return e
return err
} }
ids = append(ids, id) ids = append(ids, id)
} }
@ -116,14 +114,12 @@ func (s *SQLSearcher) Query(q string, zones []int) (ids []int, err error) {
acc := qgen.NewAcc() acc := qgen.NewAcc()
/*stmt := acc.RawPrepare("SELECT topics.tid FROM topics INNER JOIN replies ON topics.tid = replies.tid WHERE (MATCH(topics.title) AGAINST (? IN BOOLEAN MODE) OR MATCH(topics.content) AGAINST (? IN BOOLEAN MODE) OR MATCH(replies.content) AGAINST (? IN BOOLEAN MODE) OR topics.title=? OR topics.content=? OR replies.content=?) AND topics.parentID IN(" + zList + ")") /*stmt := acc.RawPrepare("SELECT topics.tid FROM topics INNER JOIN replies ON topics.tid = replies.tid WHERE (MATCH(topics.title) AGAINST (? IN BOOLEAN MODE) OR MATCH(topics.content) AGAINST (? IN BOOLEAN MODE) OR MATCH(replies.content) AGAINST (? IN BOOLEAN MODE) OR topics.title=? OR topics.content=? OR replies.content=?) AND topics.parentID IN(" + zList + ")")
err = acc.FirstError() if err = acc.FirstError(); err != nil {
if err != nil {
return nil, err return nil, err
}*/ }*/
// TODO: Cache common IN counts // TODO: Cache common IN counts
stmt := acc.RawPrepare("SELECT tid FROM topics WHERE (MATCH(topics.title) AGAINST (? IN BOOLEAN MODE) OR MATCH(topics.content) AGAINST (? IN BOOLEAN MODE)) AND parentID IN(" + zList + ")") stmt := acc.RawPrepare("SELECT tid FROM topics WHERE (MATCH(topics.title) AGAINST (? IN BOOLEAN MODE) OR MATCH(topics.content) AGAINST (? IN BOOLEAN MODE)) AND parentID IN(" + zList + ")")
err = acc.FirstError() if err = acc.FirstError(); err != nil {
if err != nil {
return nil, err return nil, err
} }
err = run(stmt.Query(q, q)) err = run(stmt.Query(q, q))
@ -131,8 +127,7 @@ func (s *SQLSearcher) Query(q string, zones []int) (ids []int, err error) {
return nil, err return nil, err
} }
stmt = acc.RawPrepare("SELECT tid FROM replies WHERE MATCH(replies.content) AGAINST (? IN BOOLEAN MODE) AND tid IN(" + zList + ")") stmt = acc.RawPrepare("SELECT tid FROM replies WHERE MATCH(replies.content) AGAINST (? IN BOOLEAN MODE) AND tid IN(" + zList + ")")
err = acc.FirstError() if err = acc.FirstError(); err != nil {
if err != nil {
return nil, err return nil, err
} }
err = run(stmt.Query(q)) err = run(stmt.Query(q))

View File

@ -38,6 +38,21 @@ func (t *SingleServerThaw) Thaw() {
} }
} }
type TestThaw struct {
}
func NewTestThaw() *TestThaw {
return &TestThaw{}
}
func (t *TestThaw) Thawed() bool {
return true
}
func (t *TestThaw) Thaw() {
}
func (t *TestThaw) Tick() error {
return nil
}
type DefaultThaw struct { type DefaultThaw struct {
thawed int64 thawed int64
} }

View File

@ -8,50 +8,47 @@ import (
"strings" "strings"
c "github.com/Azareal/Gosora/common" c "github.com/Azareal/Gosora/common"
qgen "github.com/Azareal/Gosora/query_gen"
) )
type AttachmentStmts struct {
get *sql.Stmt
}
var attachmentStmts AttachmentStmts
// TODO: Abstract this with an attachment store
func init() {
c.DbInits.Add(func(acc *qgen.Accumulator) error {
attachmentStmts = AttachmentStmts{
get: acc.Select("attachments").Columns("sectionID, sectionTable, originID, originTable, uploadedBy, path").Where("path=? AND sectionID=? AND sectionTable=?").Prepare(),
}
return acc.FirstError()
})
}
var maxAgeYear = "max-age=" + strconv.Itoa(int(c.Year)) var maxAgeYear = "max-age=" + strconv.Itoa(int(c.Year))
func ShowAttachment(w http.ResponseWriter, r *http.Request, u *c.User, filename string) c.RouteError { func ShowAttachment(w http.ResponseWriter, r *http.Request, u *c.User, filename string) c.RouteError {
filename = c.Stripslashes(filename)
ext := filepath.Ext("./attachs/" + filename)
if !c.AllowedFileExts.Contains(strings.TrimPrefix(ext, ".")) {
return c.LocalError("Bad extension", w, r, u)
}
sid, err := strconv.Atoi(r.FormValue("sid")) sid, err := strconv.Atoi(r.FormValue("sid"))
if err != nil { if err != nil {
return c.LocalError("The sid is not an integer", w, r, u) return c.LocalError("The sid is not an integer", w, r, u)
} }
sectionTable := r.FormValue("stype") sectionTable := r.FormValue("stype")
var originTable string filename = c.Stripslashes(filename)
var originID, uploadedBy int if filename == "" {
err = attachmentStmts.get.QueryRow(filename, sid, sectionTable).Scan(&sid, &sectionTable, &originID, &originTable, &uploadedBy, &filename) return c.LocalError("Bad filename", w, r, u)
}
ext := filepath.Ext(filename)
if ext == "" || !c.AllowedFileExts.Contains(strings.TrimPrefix(ext, ".")) {
return c.LocalError("Bad extension", w, r, u)
}
// TODO: Use the same hook table as upstream
hTbl := c.GetHookTable()
skip, rerr := c.H_route_attach_start_hook(hTbl, w, r, u, filename)
if skip || rerr != nil {
return rerr
}
a, err := c.Attachments.GetForRenderRoute(filename, sid, sectionTable)
// ErrCorruptAttachPath is a possibility now
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return c.NotFound(w, r, nil) return c.NotFound(w, r, nil)
} else if err != nil { } else if err != nil {
return c.InternalError(err, w, r) return c.InternalError(err, w, r)
} }
if sectionTable == "forums" { skip, rerr = c.H_route_attach_post_get_hook(hTbl, w, r, u, a)
if skip || rerr != nil {
return rerr
}
if a.SectionTable == "forums" {
_, ferr := c.SimpleForumUserCheck(w, r, u, sid) _, ferr := c.SimpleForumUserCheck(w, r, u, sid)
if ferr != nil { if ferr != nil {
return ferr return ferr
@ -63,7 +60,7 @@ func ShowAttachment(w http.ResponseWriter, r *http.Request, u *c.User, filename
return c.LocalError("Unknown section", w, r, u) return c.LocalError("Unknown section", w, r, u)
} }
if originTable != "topics" && originTable != "replies" { if a.OriginTable != "topics" && a.OriginTable != "replies" {
return c.LocalError("Unknown origin", w, r, u) return c.LocalError("Unknown origin", w, r, u)
} }
@ -89,11 +86,11 @@ func ShowAttachment(w http.ResponseWriter, r *http.Request, u *c.User, filename
} }
func deleteAttachment(w http.ResponseWriter, r *http.Request, u *c.User, aid int, js bool) c.RouteError { func deleteAttachment(w http.ResponseWriter, r *http.Request, u *c.User, aid int, js bool) c.RouteError {
err := c.DeleteAttachment(aid) e := c.DeleteAttachment(aid)
if err == sql.ErrNoRows { if e == sql.ErrNoRows {
return c.NotFoundJSQ(w, r, nil, js) return c.NotFoundJSQ(w, r, nil, js)
} else if err != nil { } else if e != nil {
return c.InternalErrorJSQ(err, w, r, js) return c.InternalErrorJSQ(e, w, r, js)
} }
return nil return nil
} }