diff --git a/common/topic.go b/common/topic.go index a9f2c02e..a942f8a1 100644 --- a/common/topic.go +++ b/common/topic.go @@ -199,6 +199,7 @@ type TopicStmts struct { deleteActivity *sql.Stmt edit *sql.Stmt setPoll *sql.Stmt + removePoll *sql.Stmt testSetCreatedAt *sql.Stmt createAction *sql.Stmt @@ -235,6 +236,7 @@ func init() { deleteActivity: acc.Delete("activity_stream").Where("elementID=? AND elementType='topic'").Prepare(), edit: set("title=?,content=?,parsed_content=?"), // TODO: Only run the content update bits on non-polls, does this matter? setPoll: acc.Update(t).Set("poll=?").Where("tid=? AND poll=0").Prepare(), + removePoll: acc.Update(t).Set("poll=0").Where("tid=?").Prepare(), testSetCreatedAt: set("createdAt=?"), createAction: acc.Insert("replies").Columns("tid,actionType,ip,createdBy,createdAt,lastUpdated,content,parsed_content").Fields("?,?,?,?,UTC_TIMESTAMP(),UTC_TIMESTAMP(),'',''").Prepare(), @@ -519,6 +521,12 @@ func (t *Topic) SetPoll(pollID int) error { return e } +func (t *Topic) RemovePoll() error { + _, e := topicStmts.removePoll.Exec(t.ID) // TODO: Sniff if this changed anything to see if we hit an existing poll + t.cacheRemove() + return e +} + // TODO: Have this go through the ReplyStore? // TODO: Return the rid? func (t *Topic) CreateActionReply(action, ip string, uid int) (err error) { diff --git a/misc_test.go b/misc_test.go index 003d22b2..7c40bb24 100644 --- a/misc_test.go +++ b/misc_test.go @@ -1505,6 +1505,8 @@ func TestPolls(t *testing.T) { tid, e := c.Topics.Create(2, "Poll Test", "Filler Body", 1, "") expectNilErr(t, e) topic, e := c.Topics.Get(tid) + expectNilErr(t, e) + exf(topic.Poll == 0, "t.Poll should be %d not %d", 0, topic.Poll) /*Options map[int]string Results map[int]int // map[optionIndex]points QuickOptions []PollOption // TODO: Fix up the template transpiler so we don't need to use this hack anymore @@ -1515,6 +1517,9 @@ func TestPolls(t *testing.T) { exf(pid == 1, "poll id should be 1 not %d", pid) ex(c.Polls.Exists(1), "poll 1 should exist") exf(c.Polls.Count() == 1, "count should be %d not %d", 1, c.Polls.Count()) + topic, e = c.Topics.BypassGet(tid) + expectNilErr(t, e) + exf(topic.Poll == pid, "t.Poll should be %d not %d", pid, topic.Poll) testPoll := func(p *c.Poll, id, parentID int, parentTable string, ptype int, antiCheat bool, voteCount int) { ef := exf @@ -1549,6 +1554,24 @@ func TestPolls(t *testing.T) { _, e = c.Polls.Get(1) recordMustNotExist(t, e, "poll 1 should no longer exist") exf(c.Polls.Count() == 0, "count should be %d not %d", 0, c.Polls.Count()) + topic, e = c.Topics.BypassGet(tid) + expectNilErr(t, e) + exf(topic.Poll == pid, "t.Poll should be %d not %d", pid, topic.Poll) + + expectNilErr(t, topic.SetPoll(999)) + topic, e = c.Topics.BypassGet(tid) + expectNilErr(t, e) + exf(topic.Poll == pid, "t.Poll should be %d not %d", pid, topic.Poll) + + expectNilErr(t, topic.SetPoll(0)) + topic, e = c.Topics.BypassGet(tid) + expectNilErr(t, e) + exf(topic.Poll == pid, "t.Poll should be %d not %d", pid, topic.Poll) + + expectNilErr(t, topic.RemovePoll()) + topic, e = c.Topics.BypassGet(tid) + expectNilErr(t, e) + exf(topic.Poll == 0, "t.Poll should be %d not %d", 0, topic.Poll) } func TestSearch(t *testing.T) {