diff --git a/common/websockets.go b/common/websockets.go index 1c007b3e..2cd89467 100644 --- a/common/websockets.go +++ b/common/websockets.go @@ -31,11 +31,15 @@ type WSUser struct { // TODO: Make this an interface? type WsHubImpl struct { - // TODO: Shard this map - OnlineUsers map[int]*WSUser + // TODO: Implement some form of generics so we don't write as much odd-even sharding code + evenOnlineUsers map[int]*WSUser + oddOnlineUsers map[int]*WSUser + evenUserLock sync.RWMutex + oddUserLock sync.RWMutex + + // TODO: Add sharding for this too? OnlineGuests map[*WSUser]bool GuestLock sync.RWMutex - UserLock sync.RWMutex lastTick time.Time lastTopicList []*TopicsRow @@ -54,8 +58,9 @@ func init() { topicListWatchers = make(map[*WSUser]bool) // TODO: Do we really want to initialise this here instead of in main.go / general_test.go like the other things? WsHub = WsHubImpl{ - OnlineUsers: make(map[int]*WSUser), - OnlineGuests: make(map[*WSUser]bool), + evenOnlineUsers: make(map[int]*WSUser), + oddOnlineUsers: make(map[int]*WSUser), + OnlineGuests: make(map[*WSUser]bool), } } @@ -222,32 +227,107 @@ func (hub *WsHubImpl) GuestCount() int { return len(hub.OnlineGuests) } -func (hub *WsHubImpl) UserCount() int { - defer hub.UserLock.RUnlock() - hub.UserLock.RLock() - return len(hub.OnlineUsers) +func (hub *WsHubImpl) UserCount() (count int) { + hub.evenUserLock.RLock() + count += len(hub.evenOnlineUsers) + hub.evenUserLock.RUnlock() + hub.oddUserLock.RLock() + count += len(hub.oddOnlineUsers) + hub.oddUserLock.RUnlock() + return count } func (hub *WsHubImpl) broadcastMessage(msg string) error { - hub.UserLock.RLock() - defer hub.UserLock.RUnlock() - for _, wsUser := range hub.OnlineUsers { - w, err := wsUser.conn.NextWriter(websocket.TextMessage) - if err != nil { - return err + var userLoop = func(users map[int]*WSUser, mutex *sync.RWMutex) error { + defer mutex.RUnlock() + for _, wsUser := range users { + w, err := wsUser.conn.NextWriter(websocket.TextMessage) + if err != nil { + return err + } + _, _ = w.Write([]byte(msg)) + w.Close() } - _, _ = w.Write([]byte(msg)) - w.Close() + return nil + } + // TODO: Can we move this RLock inside the closure safely? + hub.evenUserLock.RLock() + err := userLoop(hub.evenOnlineUsers, &hub.evenUserLock) + if err != nil { + return err + } + hub.oddUserLock.RLock() + return userLoop(hub.oddOnlineUsers, &hub.oddUserLock) +} + +func (hub *WsHubImpl) getUser(uid int) (wsUser *WSUser, err error) { + var ok bool + if uid%2 == 0 { + hub.evenUserLock.RLock() + wsUser, ok = hub.evenOnlineUsers[uid] + hub.evenUserLock.RUnlock() + } else { + hub.oddUserLock.RLock() + wsUser, ok = hub.oddOnlineUsers[uid] + hub.oddUserLock.RUnlock() + } + if !ok { + return nil, errWsNouser + } + return wsUser, nil +} + +// Warning: For efficiency, some of the *WSUsers may be nil pointers, DO NOT EXPORT +func (hub *WsHubImpl) getUsers(uids []int) (wsUsers []*WSUser, err error) { + if len(uids) == 0 { + return nil, errWsNouser + } + hub.evenUserLock.RLock() + // We don't want to keep a lock on this for too long, so we'll accept some nil pointers + for _, uid := range uids { + wsUsers = append(wsUsers, hub.evenOnlineUsers[uid]) + } + hub.evenUserLock.RUnlock() + hub.oddUserLock.RLock() + // We don't want to keep a lock on this for too long, so we'll accept some nil pointers + for _, uid := range uids { + wsUsers = append(wsUsers, hub.oddOnlineUsers[uid]) + } + hub.oddUserLock.RUnlock() + if len(wsUsers) == 0 { + return nil, errWsNouser + } + return wsUsers, nil +} + +func (hub *WsHubImpl) SetUser(uid int, wsUser *WSUser) { + if uid%2 == 0 { + hub.evenUserLock.Lock() + hub.evenOnlineUsers[uid] = wsUser + hub.evenUserLock.Unlock() + } else { + hub.oddUserLock.Lock() + hub.oddOnlineUsers[uid] = wsUser + hub.oddUserLock.Unlock() + } +} + +func (hub *WsHubImpl) RemoveUser(uid int) { + if uid%2 == 0 { + hub.evenUserLock.Lock() + delete(hub.evenOnlineUsers, uid) + hub.evenUserLock.Unlock() + } else { + hub.oddUserLock.Lock() + delete(hub.oddOnlineUsers, uid) + hub.oddUserLock.Unlock() } - return nil } func (hub *WsHubImpl) pushMessage(targetUser int, msg string) error { - hub.UserLock.RLock() - wsUser, ok := hub.OnlineUsers[targetUser] - hub.UserLock.RUnlock() - if !ok { - return errWsNouser + wsUser, err := hub.getUser(targetUser) + if err != nil { + return err } w, err := wsUser.conn.NextWriter(websocket.TextMessage) @@ -261,21 +341,16 @@ func (hub *WsHubImpl) pushMessage(targetUser int, msg string) error { } func (hub *WsHubImpl) pushAlert(targetUser int, asid int, event string, elementType string, actorID int, targetUserID int, elementID int) error { - //log.Print("In pushAlert") - hub.UserLock.RLock() - wsUser, ok := hub.OnlineUsers[targetUser] - hub.UserLock.RUnlock() - if !ok { - return errWsNouser + wsUser, err := hub.getUser(targetUser) + if err != nil { + return err } - //log.Print("Building alert") alert, err := BuildAlert(asid, event, elementType, actorID, targetUserID, elementID, *wsUser.User) if err != nil { return err } - //log.Print("Getting WS Writer") w, err := wsUser.conn.NextWriter(websocket.TextMessage) if err != nil { return err @@ -287,15 +362,9 @@ func (hub *WsHubImpl) pushAlert(targetUser int, asid int, event string, elementT } func (hub *WsHubImpl) pushAlerts(users []int, asid int, event string, elementType string, actorID int, targetUserID int, elementID int) error { - var wsUsers []*WSUser - hub.UserLock.RLock() - // We don't want to keep a lock on this for too long, so we'll accept some nil pointers - for _, uid := range users { - wsUsers = append(wsUsers, hub.OnlineUsers[uid]) - } - hub.UserLock.RUnlock() - if len(wsUsers) == 0 { - return errWsNouser + wsUsers, err := hub.getUsers(users) + if err != nil { + return err } var errs []error @@ -344,9 +413,7 @@ func RouteWebsockets(w http.ResponseWriter, r *http.Request, user User) RouteErr WsHub.OnlineGuests[wsUser] = true WsHub.GuestLock.Unlock() } else { - WsHub.UserLock.Lock() - WsHub.OnlineUsers[user.ID] = wsUser - WsHub.UserLock.Unlock() + WsHub.SetUser(user.ID, wsUser) } //conn.SetReadLimit(/* put the max request size from earlier here? */) @@ -361,9 +428,7 @@ func RouteWebsockets(w http.ResponseWriter, r *http.Request, user User) RouteErr WsHub.GuestLock.Unlock() } else { // TODO: Make sure the admin is removed from the admin stats list in the case that an error happens - WsHub.UserLock.Lock() - delete(WsHub.OnlineUsers, user.ID) - WsHub.UserLock.Unlock() + WsHub.RemoveUser(user.ID) } break }