diff --git a/common/likes.go b/common/likes.go index 347be566..91c3eaed 100644 --- a/common/likes.go +++ b/common/likes.go @@ -15,20 +15,30 @@ type LikeStore interface { } type DefaultLikeStore struct { - count *sql.Stmt - delete *sql.Stmt + count *sql.Stmt + delete *sql.Stmt + singleExists *sql.Stmt } func NewDefaultLikeStore(acc *qgen.Accumulator) (*DefaultLikeStore, error) { return &DefaultLikeStore{ - count: acc.Count("likes").Prepare(), - delete: acc.Delete("likes").Where("targetItem=? AND targetType=?").Prepare(), + count: acc.Count("likes").Prepare(), + delete: acc.Delete("likes").Where("targetItem=? AND targetType=?").Prepare(), + singleExists: acc.Select("likes").Columns("targetItem").Where("sentBy=? AND targetType=? AND targetItem=?").Prepare(), }, acc.FirstError() } // TODO: Write a test for this func (s *DefaultLikeStore) BulkExists(ids []int, sentBy int, targetType string) (eids []int, err error) { - rows, err := qgen.NewAcc().Select("likes").Columns("targetItem").Where("sentBy=? AND targetType=?").In("targetItem", ids).Query(sentBy, targetType) + if len(ids) == 0 { + return nil, nil + } + var rows *sql.Rows + if len(ids) == 1 { + rows, err = s.singleExists.Query(sentBy, targetType, ids[0]) + } else { + rows, err = qgen.NewAcc().Select("likes").Columns("targetItem").Where("sentBy=? AND targetType=?").In("targetItem", ids).Query(sentBy, targetType) + } if err == sql.ErrNoRows { return nil, nil } else if err != nil { diff --git a/misc_test.go b/misc_test.go index 03b48540..f20ddf57 100644 --- a/misc_test.go +++ b/misc_test.go @@ -13,8 +13,8 @@ import ( "time" c "github.com/Azareal/Gosora/common" - "github.com/Azareal/Gosora/common/phrases" "github.com/Azareal/Gosora/common/gauth" + "github.com/Azareal/Gosora/common/phrases" ) func miscinit(t *testing.T) { @@ -1058,6 +1058,54 @@ func testReplyStore(t *testing.T, newID, newPostCount int, ip string) { // TODO: Add tests for ReplyCache } +func TestLikes(t *testing.T) { + miscinit(t) + if !c.PluginsInited { + c.InitPlugins() + } + + uid := 1 + ids, err := c.Likes.BulkExists([]int{}, uid, "replies") + //recordMustNotExist(t, err, "no likes should be found") + expectNilErr(t, err) + expect(t, len(ids) == 0, "len ids should be 0") + + topic, err := c.Topics.Get(1) + expectNilErr(t, err) + rid, err := c.Rstore.Create(topic, "hiii", "", uid) + expectNilErr(t, err) + r, err := c.Rstore.Get(rid) + expectNilErr(t, err) + expectNilErr(t, r.Like(uid)) + ids, err = c.Likes.BulkExists([]int{rid}, uid, "replies") + expectNilErr(t, err) + expect(t, len(ids) == 1, fmt.Sprintf("ids should be %d not %d", 1, len(ids))) + + rid2, err := c.Rstore.Create(topic, "hi 2 u 2", "", uid) + expectNilErr(t, err) + r2, err := c.Rstore.Get(rid2) + expectNilErr(t, err) + expectNilErr(t, r2.Like(uid)) + ids, err = c.Likes.BulkExists([]int{rid, rid2}, uid, "replies") + expectNilErr(t, err) + expect(t, len(ids) == 2, fmt.Sprintf("ids should be %d not %d", 2, len(ids))) + + expectNilErr(t, r.Unlike(uid)) + ids, err = c.Likes.BulkExists([]int{rid2}, uid, "replies") + expectNilErr(t, err) + expect(t, len(ids) == 1, fmt.Sprintf("ids should be %d not %d", 1, len(ids))) + expectNilErr(t, r2.Unlike(uid)) + ids, err = c.Likes.BulkExists([]int{}, uid, "replies") + //recordMustNotExist(t, err, "no likes should be found") + expectNilErr(t, err) + expect(t, len(ids) == 0, "len ids should be 0") + + //BulkExists(ids []int, sentBy int, targetType string) (eids []int, err error) + + expectNilErr(t, topic.Like(1, uid)) + expectNilErr(t, topic.Unlike(uid)) +} + func TestAttachments(t *testing.T) { miscinit(t) if !c.PluginsInited { @@ -1813,7 +1861,7 @@ func TestMFAStore(t *testing.T) { recordMustNotExist(t, err, "mfa uid 1 should not exist") secret, err := c.GenerateGAuthSecret() - expectNilErr(t,err) + expectNilErr(t, err) expectNilErr(t, c.MFAstore.Create(secret, 1)) _, err = c.MFAstore.Get(0) recordMustNotExist(t, err, "mfa uid 0 should not exist") @@ -1844,7 +1892,7 @@ func TestMFAStore(t *testing.T) { test(i) } token, err := gauth.GetTOTPToken(secret) - expectNilErr(t,err) + expectNilErr(t, err) expectNilErr(t, c.Auth.ValidateMFAToken(token, 1)) expectNilErr(t, it.Delete()) _, err = c.MFAstore.Get(-1)