diff --git a/common/attachments.go b/common/attachments.go index 0ce563ce..de9a8af5 100644 --- a/common/attachments.go +++ b/common/attachments.go @@ -4,8 +4,9 @@ import ( "database/sql" "errors" "strings" + "os" - "github.com/Azareal/Gosora/query_gen" + qgen "github.com/Azareal/Gosora/query_gen" ) var Attachments AttachmentStore @@ -26,9 +27,9 @@ type AttachmentStore interface { Get(id int) (*MiniAttachment, error) MiniGetList(originTable string, originID int) (alist []*MiniAttachment, err error) BulkMiniGetList(originTable string, ids []int) (amap map[int][]*MiniAttachment, err error) - Add(sectionID int, sectionTable string, originID int, originTable string, uploadedBy int, path string, extra string) (int, error) + Add(sectionID int, sectionTable string, originID int, originTable string, uploadedBy int, path, extra string) (int, error) MoveTo(sectionID int, originID int, originTable string) error - MoveToByExtra(sectionID int, originTable string, extra string) error + MoveToByExtra(sectionID int, originTable, extra string) error Count() int CountIn(originTable string, oid int) int CountInPath(path string) int @@ -50,7 +51,7 @@ type DefaultAttachmentStore struct { func NewDefaultAttachmentStore(acc *qgen.Accumulator) (*DefaultAttachmentStore, error) { a := "attachments" return &DefaultAttachmentStore{ - get: acc.Select(a).Columns("originID, sectionID, uploadedBy, path, extra").Where("attachID = ?").Prepare(), + get: acc.Select(a).Columns("originID, sectionID, uploadedBy, path, extra").Where("attachID=?").Prepare(), getByObj: acc.Select(a).Columns("attachID, sectionID, uploadedBy, path, extra").Where("originTable = ? AND originID = ?").Prepare(), add: acc.Insert(a).Columns("sectionID, sectionTable, originID, originTable, uploadedBy, path, extra").Fields("?,?,?,?,?,?,?").Prepare(), count: acc.Count(a).Prepare(), @@ -58,7 +59,7 @@ func NewDefaultAttachmentStore(acc *qgen.Accumulator) (*DefaultAttachmentStore, countInPath: acc.Count(a).Where("path = ?").Prepare(), move: acc.Update(a).Set("sectionID = ?").Where("originID = ? AND originTable = ?").Prepare(), moveByExtra: acc.Update(a).Set("sectionID = ?").Where("originTable = ? AND extra = ?").Prepare(), - delete: acc.Delete(a).Where("attachID = ?").Prepare(), + delete: acc.Delete(a).Where("attachID=?").Prepare(), }, acc.FirstError() } @@ -141,7 +142,7 @@ func (s *DefaultAttachmentStore) Get(id int) (*MiniAttachment, error) { return a, nil } -func (s *DefaultAttachmentStore) Add(sectionID int, sectionTable string, originID int, originTable string, uploadedBy int, path string, extra string) (int, error) { +func (s *DefaultAttachmentStore) Add(sectionID int, sectionTable string, originID int, originTable string, uploadedBy int, path, extra string) (int, error) { res, err := s.add.Exec(sectionID, sectionTable, originID, originTable, uploadedBy, path, extra) if err != nil { return 0, err @@ -155,7 +156,7 @@ func (s *DefaultAttachmentStore) MoveTo(sectionID int, originID int, originTable return err } -func (s *DefaultAttachmentStore) MoveToByExtra(sectionID int, originTable string, extra string) error { +func (s *DefaultAttachmentStore) MoveToByExtra(sectionID int, originTable, extra string) error { _, err := s.moveByExtra.Exec(sectionID, originTable, extra) return err } @@ -188,3 +189,25 @@ func (s *DefaultAttachmentStore) Delete(aid int) error { _, err := s.delete.Exec(aid) return err } + +// TODO: Add a table for the files and lock the file row when performing tasks related to the file +func DeleteAttachment(aid int) error { + attach, err := Attachments.Get(aid) + if err != nil { + return err + } + err = Attachments.Delete(aid) + if err != nil { + return err + } + + count := Attachments.CountInPath(attach.Path) + if count == 0 { + err := os.Remove("./attachs/" + attach.Path) + if err != nil { + return err + } + } + + return nil +} \ No newline at end of file diff --git a/common/topic.go b/common/topic.go index 64c80621..be760145 100644 --- a/common/topic.go +++ b/common/topic.go @@ -184,25 +184,25 @@ func (t *TopicsRow) Topic() *Topic { }*/ type TopicStmts struct { - getRids *sql.Stmt - getReplies *sql.Stmt - addReplies *sql.Stmt - updateLastReply *sql.Stmt - lock *sql.Stmt - unlock *sql.Stmt - moveTo *sql.Stmt - stick *sql.Stmt - unstick *sql.Stmt - hasLikedTopic *sql.Stmt - createLike *sql.Stmt - addLikesToTopic *sql.Stmt - delete *sql.Stmt + getRids *sql.Stmt + getReplies *sql.Stmt + addReplies *sql.Stmt + updateLastReply *sql.Stmt + lock *sql.Stmt + unlock *sql.Stmt + moveTo *sql.Stmt + stick *sql.Stmt + unstick *sql.Stmt + hasLikedTopic *sql.Stmt + createLike *sql.Stmt + addLikesToTopic *sql.Stmt + delete *sql.Stmt deleteLikesForTopic *sql.Stmt - deleteActivity *sql.Stmt - deleteActivitySubs *sql.Stmt - edit *sql.Stmt - setPoll *sql.Stmt - createAction *sql.Stmt + deleteActivity *sql.Stmt + deleteActivitySubs *sql.Stmt + edit *sql.Stmt + setPoll *sql.Stmt + createAction *sql.Stmt getTopicUser *sql.Stmt // TODO: Can we get rid of this? getByReplyID *sql.Stmt @@ -214,25 +214,25 @@ func init() { DbInits.Add(func(acc *qgen.Accumulator) error { t := "topics" topicStmts = TopicStmts{ - getRids: acc.Select("replies").Columns("rid").Where("tid = ?").Orderby("rid ASC").Limit("?,?").Prepare(), - getReplies: acc.SimpleLeftJoin("replies AS r", "users AS u", "r.rid, r.content, r.createdBy, r.createdAt, r.lastEdit, r.lastEditBy, u.avatar, u.name, u.group, u.level, r.ipaddress, r.likeCount, r.attachCount, r.actionType", "r.createdBy = u.uid", "r.tid = ?", "r.rid ASC", "?,?"), - addReplies: acc.Update(t).Set("postCount = postCount + ?, lastReplyBy = ?, lastReplyAt = UTC_TIMESTAMP()").Where("tid = ?").Prepare(), - updateLastReply: acc.Update(t).Set("lastReplyID=?").Where("lastReplyID > ? AND tid = ?").Prepare(), - lock: acc.Update(t).Set("is_closed=1").Where("tid=?").Prepare(), - unlock: acc.Update(t).Set("is_closed=0").Where("tid=?").Prepare(), - moveTo: acc.Update(t).Set("parentID=?").Where("tid=?").Prepare(), - stick: acc.Update(t).Set("sticky=1").Where("tid=?").Prepare(), - unstick: acc.Update(t).Set("sticky=0").Where("tid=?").Prepare(), - hasLikedTopic: acc.Select("likes").Columns("targetItem").Where("sentBy=? and targetItem=? and targetType='topics'").Prepare(), - createLike: acc.Insert("likes").Columns("weight, targetItem, targetType, sentBy, createdAt").Fields("?,?,?,?,UTC_TIMESTAMP()").Prepare(), - addLikesToTopic: acc.Update(t).Set("likeCount=likeCount+?").Where("tid = ?").Prepare(), - delete: acc.Delete(t).Where("tid=?").Prepare(), - deleteLikesForTopic: acc.Delete("likes").Where("targetItem=? AND targetType='topics'").Prepare(), - deleteActivity: acc.Delete("activity_stream").Where("elementID=? AND elementType='topic'").Prepare(), - deleteActivitySubs: acc.Delete("activity_subscriptions").Where("targetID=? AND targetType='topic'").Prepare(), - edit: acc.Update(t).Set("title=?,content=?,parsed_content=?").Where("tid=?").Prepare(), // TODO: Only run the content update bits on non-polls, does this matter? - setPoll: acc.Update(t).Set("poll=?").Where("tid=? AND poll=0").Prepare(), - createAction: acc.Insert("replies").Columns("tid, actionType, ipaddress, createdBy, createdAt, lastUpdated, content, parsed_content").Fields("?,?,?,?,UTC_TIMESTAMP(),UTC_TIMESTAMP(),'',''").Prepare(), + getRids: acc.Select("replies").Columns("rid").Where("tid = ?").Orderby("rid ASC").Limit("?,?").Prepare(), + getReplies: acc.SimpleLeftJoin("replies AS r", "users AS u", "r.rid, r.content, r.createdBy, r.createdAt, r.lastEdit, r.lastEditBy, u.avatar, u.name, u.group, u.level, r.ipaddress, r.likeCount, r.attachCount, r.actionType", "r.createdBy = u.uid", "r.tid = ?", "r.rid ASC", "?,?"), + addReplies: acc.Update(t).Set("postCount = postCount + ?, lastReplyBy = ?, lastReplyAt = UTC_TIMESTAMP()").Where("tid = ?").Prepare(), + updateLastReply: acc.Update(t).Set("lastReplyID=?").Where("lastReplyID > ? AND tid = ?").Prepare(), + lock: acc.Update(t).Set("is_closed=1").Where("tid=?").Prepare(), + unlock: acc.Update(t).Set("is_closed=0").Where("tid=?").Prepare(), + moveTo: acc.Update(t).Set("parentID=?").Where("tid=?").Prepare(), + stick: acc.Update(t).Set("sticky=1").Where("tid=?").Prepare(), + unstick: acc.Update(t).Set("sticky=0").Where("tid=?").Prepare(), + hasLikedTopic: acc.Select("likes").Columns("targetItem").Where("sentBy=? and targetItem=? and targetType='topics'").Prepare(), + createLike: acc.Insert("likes").Columns("weight, targetItem, targetType, sentBy, createdAt").Fields("?,?,?,?,UTC_TIMESTAMP()").Prepare(), + addLikesToTopic: acc.Update(t).Set("likeCount=likeCount+?").Where("tid = ?").Prepare(), + delete: acc.Delete(t).Where("tid=?").Prepare(), + deleteLikesForTopic: acc.Delete("likes").Where("targetItem=? AND targetType='topics'").Prepare(), + deleteActivity: acc.Delete("activity_stream").Where("elementID=? AND elementType='topic'").Prepare(), + deleteActivitySubs: acc.Delete("activity_subscriptions").Where("targetID=? AND targetType='topic'").Prepare(), + edit: acc.Update(t).Set("title=?,content=?,parsed_content=?").Where("tid=?").Prepare(), // TODO: Only run the content update bits on non-polls, does this matter? + setPoll: acc.Update(t).Set("poll=?").Where("tid=? AND poll=0").Prepare(), + createAction: acc.Insert("replies").Columns("tid, actionType, ipaddress, createdBy, createdAt, lastUpdated, content, parsed_content").Fields("?,?,?,?,UTC_TIMESTAMP(),UTC_TIMESTAMP(),'',''").Prepare(), getTopicUser: acc.SimpleLeftJoin("topics AS t", "users AS u", "t.title, t.content, t.createdBy, t.createdAt, t.lastReplyAt, t.lastReplyBy, t.lastReplyID, t.is_closed, t.sticky, t.parentID, t.ipaddress, t.views, t.postCount, t.likeCount, t.attachCount,t.poll, u.name, u.avatar, u.group, u.level", "t.createdBy = u.uid", "tid = ?", "", ""), getByReplyID: acc.SimpleLeftJoin("replies AS r", "topics AS t", "t.tid, t.title, t.content, t.createdBy, t.createdAt, t.is_closed, t.sticky, t.parentID, t.ipaddress, t.views, t.postCount, t.likeCount, t.poll, t.data", "r.tid = t.tid", "rid = ?", "", ""), @@ -350,6 +350,35 @@ func handleLikedTopicReplies(tid int) error { return rows.Err() } +func handleTopicAttachments(tid int) error { + f := func(stmt *sql.Stmt) error { + rows, err := stmt.Query(tid) + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var aid int + err := rows.Scan(&aid) + if err != nil { + return err + } + err = DeleteAttachment(aid) + if err != nil && err != sql.ErrNoRows { + return err + } + } + + return rows.Err() + } + err := f(userStmts.getAttachmentsOfTopic) + if err != nil { + return err + } + return f(userStmts.getAttachmentsOfTopic2) +} + // TODO: Use a transaction here func (t *Topic) Delete() error { creator, err := Users.Get(t.CreatedBy) @@ -380,6 +409,10 @@ func (t *Topic) Delete() error { if err != nil { return err } + err = handleTopicAttachments(t.ID) + if err != nil { + return err + } _, err = topicStmts.deleteActivitySubs.Exec(t.ID) if err != nil { return err @@ -389,7 +422,7 @@ func (t *Topic) Delete() error { return err } if t.Poll > 0 { - err = (&Poll{ID:t.Poll}).Delete() + err = (&Poll{ID: t.Poll}).Delete() if err != nil { return err } @@ -523,7 +556,7 @@ func (ru *ReplyUser) Init() error { } // TODO: Factor TopicUser into a *Topic and *User, as this starting to become overly complicated x.x -func (t *TopicUser) Replies(offset int, pFrag int, user *User) (rlist []*ReplyUser, ogdesc string, err error) { +func (t *TopicUser) Replies(offset, pFrag int, user *User) (rlist []*ReplyUser, ogdesc string, err error) { var likedMap map[int]int if user.Liked > 0 { likedMap = make(map[int]int) diff --git a/common/user.go b/common/user.go index e94e25d5..7d12fdca 100644 --- a/common/user.go +++ b/common/user.go @@ -150,6 +150,8 @@ type UserStmts struct { deleteProfilePosts *sql.Stmt deleteReplyPosts *sql.Stmt getLikedRepliesOfTopic *sql.Stmt + getAttachmentsOfTopic *sql.Stmt + getAttachmentsOfTopic2 *sql.Stmt } var userStmts UserStmts @@ -191,6 +193,8 @@ func init() { deleteProfilePosts: acc.Select("users_replies").Columns("rid").Where("createdBy=?").Prepare(), deleteReplyPosts: acc.Select("replies").Columns("rid,tid").Where("createdBy=?").Prepare(), getLikedRepliesOfTopic: acc.Select("replies").Columns("rid").Where("tid=? AND likeCount>0").Prepare(), + getAttachmentsOfTopic: acc.Select("attachments").Columns("attachID").Where("originID=? AND originTable='topics'").Prepare(), + getAttachmentsOfTopic2: acc.Select("attachments").Columns("attachID").Where("extra=? AND originTable='replies'").Prepare(), } return acc.FirstError() }) @@ -356,6 +360,10 @@ func (u *User) DeletePosts() error { if err != nil { return err } + err = handleTopicAttachments(tid) + if err != nil { + return err + } _, err = topicStmts.deleteActivitySubs.Exec(tid) if err != nil { return err @@ -365,7 +373,7 @@ func (u *User) DeletePosts() error { return err } if poll > 0 { - err = (&Poll{ID:poll}).Delete() + err = (&Poll{ID: poll}).Delete() if err != nil { return err } diff --git a/routes/attachments.go b/routes/attachments.go index f0b971e9..c28c293a 100644 --- a/routes/attachments.go +++ b/routes/attachments.go @@ -3,13 +3,12 @@ package routes import ( "database/sql" "net/http" - "os" "path/filepath" "strconv" "strings" c "github.com/Azareal/Gosora/common" - "github.com/Azareal/Gosora/query_gen" + qgen "github.com/Azareal/Gosora/query_gen" ) type AttachmentStmts struct { @@ -22,12 +21,14 @@ var attachmentStmts AttachmentStmts func init() { c.DbInits.Add(func(acc *qgen.Accumulator) error { attachmentStmts = AttachmentStmts{ - get: acc.Select("attachments").Columns("sectionID, sectionTable, originID, originTable, uploadedBy, path").Where("path = ? AND sectionID = ? AND sectionTable = ?").Prepare(), + get: acc.Select("attachments").Columns("sectionID, sectionTable, originID, originTable, uploadedBy, path").Where("path=? AND sectionID=? AND sectionTable=?").Prepare(), } return acc.FirstError() }) } +var maxAgeYear = "max-age=" + strconv.Itoa(int(c.Year)) + func ShowAttachment(w http.ResponseWriter, r *http.Request, user c.User, filename string) c.RouteError { filename = c.Stripslashes(filename) ext := filepath.Ext("./attachs/" + filename) @@ -61,13 +62,13 @@ func ShowAttachment(w http.ResponseWriter, r *http.Request, user c.User, filenam } else { return c.LocalError("Unknown section", w, r, user) } - + if originTable != "topics" && originTable != "replies" { return c.LocalError("Unknown origin", w, r, user) } if !user.Loggedin { - w.Header().Set("Cache-Control", "max-age="+strconv.Itoa(int(c.Year))) + w.Header().Set("Cache-Control", maxAgeYear) } else { guest := c.GuestUser _, ferr := c.SimpleForumUserCheck(w, r, &guest, sid) @@ -76,7 +77,7 @@ func ShowAttachment(w http.ResponseWriter, r *http.Request, user c.User, filenam } h := w.Header() if guest.Perms.ViewTopic { - h.Set("Cache-Control", "max-age="+strconv.Itoa(int(c.Year))) + h.Set("Cache-Control", maxAgeYear) } else { h.Set("Cache-Control", "private") } @@ -87,32 +88,12 @@ func ShowAttachment(w http.ResponseWriter, r *http.Request, user c.User, filenam return nil } -// TODO: Add a table for the files and lock the file row when performing tasks related to the file func deleteAttachment(w http.ResponseWriter, r *http.Request, user c.User, aid int, js bool) c.RouteError { - attach, err := c.Attachments.Get(aid) + err := c.DeleteAttachment(aid) if err == sql.ErrNoRows { return c.NotFoundJSQ(w, r, nil, js) - } else if err != nil { - return c.InternalErrorJSQ(err, w, r, js) } - - err = c.Attachments.Delete(aid) - if err != nil { - return c.InternalErrorJSQ(err, w, r, js) - } - - count := c.Attachments.CountInPath(attach.Path) - if err != nil { - return c.InternalErrorJSQ(err, w, r, js) - } - if count == 0 { - err := os.Remove("./attachs/" + attach.Path) - if err != nil { - return c.InternalErrorJSQ(err, w, r, js) - } - } - - return nil + return c.InternalErrorJSQ(err, w, r, js) } // TODO: Stop duplicating this code