diff --git a/common/ws_user.go b/common/ws_user.go index de83ed58..d61441c0 100644 --- a/common/ws_user.go +++ b/common/ws_user.go @@ -23,16 +23,36 @@ type WSUserSocket struct { } func (u *WSUser) Ping() error { - for _, socket := range u.Sockets { - if socket == nil { - continue + var sockets []*WSUserSocket + var del int + func() { + u.Lock() + defer u.Unlock() + for i, s := range u.Sockets { + if s == nil || s.conn == nil { + del++ + u.Sockets[i] = u.Sockets[len(u.Sockets)-del] + continue + } + sockets = append(sockets, s) } - socket.conn.SetWriteDeadline(time.Now().Add(time.Minute)) - e := socket.conn.WriteMessage(websocket.PingMessage, nil) + }() + if del > 0 { + // TODO: Resize the capacity to release memory more eagerly? + u.Sockets = u.Sockets[:len(u.Sockets)-del] + } + + for _, s := range sockets { + _ = s.conn.SetWriteDeadline(time.Now().Add(time.Minute)) + e := s.conn.WriteMessage(websocket.PingMessage, nil) if e != nil { - socket.conn.Close() + s.conn.Close() + u.Lock() + s.conn = nil + u.Unlock() } } + return nil } @@ -106,10 +126,16 @@ func (u *WSUser) WriteToPageBytesMulti(msgs [][]byte, page string) error { return nil } +func (u *WSUser) CountSockets() int { + u.Lock() + defer u.Unlock() + return len(u.Sockets) +} + func (u *WSUser) AddSocket(conn *websocket.Conn, page string) { u.Lock() // If the number of the sockets is small, then we can keep the size of the slice mostly static and just walk through it looking for empty slots - if len(u.Sockets) < 6 { + /*if len(u.Sockets) < 6 { for i, socket := range u.Sockets { if socket == nil { u.Sockets[i] = &WSUserSocket{conn, page} @@ -118,15 +144,35 @@ func (u *WSUser) AddSocket(conn *websocket.Conn, page string) { return } } - } + }*/ u.Sockets = append(u.Sockets, &WSUserSocket{conn, page}) //fmt.Printf("%+v\n", u.Sockets) u.Unlock() } func (u *WSUser) RemoveSocket(conn *websocket.Conn) { + var del int u.Lock() defer u.Unlock() + for i, socket := range u.Sockets { + if socket == nil || socket.conn == nil { + del++ + u.Sockets[i] = u.Sockets[len(u.Sockets)-del] + } else if socket.conn == conn { + del++ + u.Sockets[i] = u.Sockets[len(u.Sockets)-del] + //break + } + } + //Logf("%+v\n", u.Sockets) + //Log("del: ", del) + if del > 0 { + // TODO: Resize the capacity to release memory more eagerly? + u.Sockets = u.Sockets[:len(u.Sockets)-del] + } + //Logf("%+v\n", u.Sockets) + return + if len(u.Sockets) < 6 { for i, socket := range u.Sockets { if socket == nil { diff --git a/misc_test.go b/misc_test.go index cf475cfd..003d22b2 100644 --- a/misc_test.go +++ b/misc_test.go @@ -3201,3 +3201,76 @@ func TestTick(t *testing.T) { expectNilErr(t, c.CTickLoop.Hourf()) expectNilErr(t, c.CTickLoop.Dayf()) } + +func TestWSHub(t *testing.T) { + ex, exf, h := exp(t), expf(t), &c.WsHub + exf(h.GuestCount() == 0, "GuestCount should be %d not %d", 0, h.GuestCount()) + exf(h.UserCount() == 0, "UserCount should be %d not %d", 0, h.UserCount()) + ex(!h.HasUser(-1), "HasUser(-1) should be false") + ex(!h.HasUser(0), "HasUser(0) should be false") + ex(!h.HasUser(1), "HasUser(1) should be false") + + uid, e := c.Users.Create("WsHub Test", "WsHub Test", "", 1, true) + expectNilErr(t, e) + exf(!h.HasUser(uid), "HasUser(%d) should be false", uid) + exf(len(h.AllUsers()) == 0, "len(AllUsers()) should be %d not %d", 0, len(h.AllUsers())) + + f := func(uid, guestCount, userCount, allUserListLen int, hasUser bool) { + exf(h.GuestCount() == guestCount, "GuestCount should be %d not %d", guestCount, h.GuestCount()) + exf(h.UserCount() == userCount, "UserCount should be %d not %d", userCount, h.UserCount()) + exf(len(h.AllUsers()) == allUserListLen, "len(AllUsers()) should be %d not %d", allUserListLen, len(h.AllUsers())) + if hasUser { + exf(h.HasUser(uid), "HasUser(%d) should be true", uid) + } else { + exf(!h.HasUser(uid), "HasUser(%d) should be false", uid) + } + } + + u, e := c.Users.Get(uid) + expectNilErr(t, e) + wsUser, e := h.AddConn(u, nil) + expectNilErr(t, e) + f(uid, 0, 1, 1, true) + + uid, e = c.Users.Create("WsHub Test 2", "WsHub Test 2", "", 1, true) + expectNilErr(t, e) + u2, e := c.Users.Get(uid) + expectNilErr(t, e) + wsUser2, e := h.AddConn(u2, nil) + expectNilErr(t, e) + f(uid, 0, 2, 2, true) + + h.RemoveConn(wsUser2, nil) + f(uid, 0, 1, 1, false) + h.RemoveConn(wsUser2, nil) + f(uid, 0, 1, 1, false) + h.RemoveConn(wsUser, nil) + f(uid, 0, 0, 0, false) + + countSockets := func(wsUser *c.WSUser, expect int) { + exf(wsUser.CountSockets() == expect, "CountSockets() should be %d not %d", expect, wsUser.CountSockets()) + } + wsUser2, e = h.AddConn(u2, nil) + expectNilErr(t, e) + f(uid, 0, 1, 1, true) + countSockets(wsUser2, 1) + wsUser2.RemoveSocket(nil) + f(uid, 0, 1, 1, true) + countSockets(wsUser2, 0) + h.RemoveConn(wsUser2, nil) + f(uid, 0, 0, 0, false) + countSockets(wsUser2, 0) + + wsUser2, e = h.AddConn(u2, nil) + expectNilErr(t, e) + f(uid, 0, 1, 1, true) + countSockets(wsUser2, 1) + expectNilErr(t, wsUser2.Ping()) + f(uid, 0, 1, 1, true) + countSockets(wsUser2, 0) + h.RemoveConn(wsUser2, nil) + f(uid, 0, 0, 0, false) + countSockets(wsUser2, 0) + + // TODO: Add more tests +}