diff --git a/common/poll.go b/common/poll.go index c607f405..eab83053 100644 --- a/common/poll.go +++ b/common/poll.go @@ -28,30 +28,49 @@ func (p *Poll) CastVote(optionIndex, uid int, ip string) error { if Config.DisablePollIP || !p.AntiCheat { ip = "" } - _, err := pollStmts.addVote.Exec(p.ID, uid, optionIndex, ip) - if err != nil { - return err + _, e := pollStmts.addVote.Exec(p.ID, uid, optionIndex, ip) + if e != nil { + return e } - _, err = pollStmts.incVoteCount.Exec(p.ID) - if err != nil { - return err + _, e = pollStmts.incVoteCount.Exec(p.ID) + if e != nil { + return e } - _, err = pollStmts.incVoteCountForOption.Exec(optionIndex, p.ID) - return err + _, e = pollStmts.incVoteCountForOption.Exec(optionIndex, p.ID) + return e } func (p *Poll) Delete() error { - _, err := pollStmts.deletePollVotes.Exec(p.ID) - if err != nil { - return err + _, e := pollStmts.deletePollVotes.Exec(p.ID) + if e != nil { + return e } - _, err = pollStmts.deletePollOptions.Exec(p.ID) - if err != nil { - return err + _, e = pollStmts.deletePollOptions.Exec(p.ID) + if e != nil { + return e } - _, err = pollStmts.deletePoll.Exec(p.ID) + _, e = pollStmts.deletePoll.Exec(p.ID) _ = Polls.GetCache().Remove(p.ID) - return err + return e +} + +func (p *Poll) Resultsf(f func(votes int) error) error { + rows, e := pollStmts.getResults.Query(p.ID) + if e != nil { + return e + } + defer rows.Close() + + var votes int + for rows.Next() { + if e := rows.Scan(&votes); e != nil { + return e + } + if e := f(votes); e != nil { + return e + } + } + return rows.Err() } func (p *Poll) Copy() Poll { @@ -59,24 +78,31 @@ func (p *Poll) Copy() Poll { } type PollStmts struct { + getResults *sql.Stmt + addVote *sql.Stmt incVoteCount *sql.Stmt incVoteCountForOption *sql.Stmt - deletePoll *sql.Stmt - deletePollOptions *sql.Stmt - deletePollVotes *sql.Stmt + + deletePoll *sql.Stmt + deletePollOptions *sql.Stmt + deletePollVotes *sql.Stmt } func init() { DbInits.Add(func(acc *qgen.Accumulator) error { p := "polls" + wh := "pollID=?" pollStmts = PollStmts{ + getResults: acc.Select("polls_options").Columns("votes").Where("pollID=?").Orderby("option ASC").Prepare(), + addVote: acc.Insert("polls_votes").Columns("pollID,uid,option,castAt,ip").Fields("?,?,?,UTC_TIMESTAMP(),?").Prepare(), - incVoteCount: acc.Update(p).Set("votes = votes + 1").Where("pollID=?").Prepare(), - incVoteCountForOption: acc.Update("polls_options").Set("votes = votes + 1").Where("option=? AND pollID=?").Prepare(), - deletePoll: acc.Delete(p).Where("pollID=?").Prepare(), - deletePollOptions: acc.Delete("polls_options").Where("pollID=?").Prepare(), - deletePollVotes: acc.Delete("polls_votes").Where("pollID=?").Prepare(), + incVoteCount: acc.Update(p).Set("votes=votes+1").Where(wh).Prepare(), + incVoteCountForOption: acc.Update("polls_options").Set("votes=votes+1").Where("option=? AND pollID=?").Prepare(), + + deletePoll: acc.Delete(p).Where(wh).Prepare(), + deletePollOptions: acc.Delete("polls_options").Where(wh).Prepare(), + deletePollVotes: acc.Delete("polls_votes").Where(wh).Prepare(), } return acc.FirstError() }) diff --git a/misc_test.go b/misc_test.go index 7c40bb24..4eeeb718 100644 --- a/misc_test.go +++ b/misc_test.go @@ -1546,6 +1546,32 @@ func TestPolls(t *testing.T) { expectNilErr(t, e) testPoll(p, 1, tid, "topics", 0, false, 1) + var vslice []int + expectNilErr(t, p.Resultsf(func(votes int) error { + vslice = append(vslice, votes) + return nil + })) + //fmt.Printf("vslice: %+v\n", vslice) + exf(vslice[0] == 1, "vslice[0] should be %d not %d", 0, vslice[0]) + exf(vslice[1] == 0, "vslice[1] should be %d not %d", 1, vslice[1]) + exf(vslice[2] == 0, "vslice[2] should be %d not %d", 0, vslice[2]) + + expectNilErr(t, p.CastVote(2, 1, "")) + expectNilErr(t, c.Polls.Reload(p.ID)) + p, e = c.Polls.Get(1) + expectNilErr(t, e) + testPoll(p, 1, tid, "topics", 0, false, 2) + + vslice = nil + expectNilErr(t, p.Resultsf(func(votes int) error { + vslice = append(vslice, votes) + return nil + })) + //fmt.Printf("vslice: %+v\n", vslice) + exf(vslice[0] == 1, "vslice[0] should be %d not %d", 1, vslice[0]) + exf(vslice[1] == 0, "vslice[1] should be %d not %d", 0, vslice[1]) + exf(vslice[2] == 1, "vslice[2] should be %d not %d", 1, vslice[2]) + expectNilErr(t, c.Polls.ClearIPs()) // TODO: Test to see if it worked diff --git a/routes/poll.go b/routes/poll.go index c38a3cd0..a897c067 100644 --- a/routes/poll.go +++ b/routes/poll.go @@ -1,13 +1,13 @@ package routes import ( + "bytes" "database/sql" "errors" "net/http" "strconv" c "github.com/Azareal/Gosora/common" - qgen "github.com/Azareal/Gosora/query_gen" ) func PollVote(w http.ResponseWriter, r *http.Request, u *c.User, sPollID string) c.RouteError { @@ -78,29 +78,23 @@ func PollResults(w http.ResponseWriter, r *http.Request, u *c.User, sPollID stri return c.InternalError(err, w, r) } - // TODO: Abstract this - rows, err := qgen.NewAcc().Select("polls_options").Columns("votes").Where("pollID=?").Orderby("option ASC").Query(poll.ID) - if err != nil { - return c.InternalError(err, w, r) - } - defer rows.Close() - - optList := "" - var votes int - for rows.Next() { - if e := rows.Scan(&votes); e != nil { - return c.InternalError(e, w, r) - } - optList += strconv.Itoa(votes) + "," - } - if err = rows.Err(); err != nil { - return c.InternalError(err, w, r) - } - // TODO: Implement a version of this which doesn't rely so much on sequential order - if len(optList) > 0 { - optList = optList[:len(optList)-1] + var ob bytes.Buffer + ob.WriteRune('[') + var i int + e := poll.Resultsf(func(votes int) error { + if i != 0 { + ob.WriteRune(',') + } + ob.WriteString(strconv.Itoa(votes)) + i++ + return nil + }) + if e != nil && e != sql.ErrNoRows { + return c.InternalError(e, w, r) } - w.Write([]byte("[" + optList + "]")) + ob.WriteRune(']') + w.Write(ob.Bytes()) + return nil }