From 1854010840cab722232660115e764015ebbe41a3 Mon Sep 17 00:00:00 2001 From: Azareal Date: Wed, 12 Sep 2018 19:50:57 +1000 Subject: [PATCH] Fixed a bug where nil sockets weren't getting skipped. Added some additional sanity checks. --- common/websockets.go | 16 ++++++++++++++-- common/ws_user.go | 21 ++++++++++++++++++++- 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/common/websockets.go b/common/websockets.go index 1d1728cb..b9f336ec 100644 --- a/common/websockets.go +++ b/common/websockets.go @@ -67,6 +67,9 @@ func RouteWebsockets(w http.ResponseWriter, r *http.Request, user User) RouteErr } break } + if conn == nil { + panic("conn must not be nil") + } messages := bytes.Split(message, []byte("\r")) for _, msg := range messages { @@ -89,6 +92,7 @@ func RouteWebsockets(w http.ResponseWriter, r *http.Request, user User) RouteErr } } conn.Close() + DebugLog("Closing connection for user " + strconv.Itoa(user.ID)) return nil } @@ -98,6 +102,7 @@ func wsPageResponses(wsUser *WSUser, conn *websocket.Conn, page string) { page = Config.DefaultPath } + DebugLog("Entering page " + page) switch page { // Live Topic List is an experimental feature // TODO: Optimise this to reduce the amount of contention @@ -120,7 +125,10 @@ func wsPageResponses(wsUser *WSUser, conn *websocket.Conn, page string) { default: return } - wsUser.SetPageForSocket(conn, page) + err := wsUser.SetPageForSocket(conn, page) + if err != nil { + LogError(err) + } } // TODO: Use a map instead of a switch to make this more modular? @@ -129,6 +137,7 @@ func wsLeavePage(wsUser *WSUser, conn *websocket.Conn, page string) { page = Config.DefaultPath } + DebugLog("Leaving page " + page) switch page { case "/topics/": wsUser.FinalizePage("/topics/", func() { @@ -141,7 +150,10 @@ func wsLeavePage(wsUser *WSUser, conn *websocket.Conn, page string) { delete(adminStatsWatchers, conn) adminStatsMutex.Unlock() } - wsUser.SetPageForSocket(conn, "") + err := wsUser.SetPageForSocket(conn, "") + if err != nil { + LogError(err) + } } // TODO: Abstract this diff --git a/common/ws_user.go b/common/ws_user.go index 50f2c325..d71f2687 100644 --- a/common/ws_user.go +++ b/common/ws_user.go @@ -8,6 +8,7 @@ import ( ) var ErrNoneOnPage = errors.New("This user isn't on that page") +var ErrInvalidSocket = errors.New("That's not a valid WebSocket Connection") type WSUser struct { User *User @@ -84,6 +85,9 @@ func (wsUser *WSUser) RemoveSocket(conn *websocket.Conn) { wsUser.Lock() if len(wsUser.Sockets) < 6 { for i, socket := range wsUser.Sockets { + if socket == nil { + continue + } if socket.conn == conn { wsUser.Sockets[i] = nil wsUser.Unlock() @@ -104,20 +108,32 @@ func (wsUser *WSUser) RemoveSocket(conn *websocket.Conn) { wsUser.Unlock() } -func (wsUser *WSUser) SetPageForSocket(conn *websocket.Conn, page string) { +func (wsUser *WSUser) SetPageForSocket(conn *websocket.Conn, page string) error { + if conn == nil { + return ErrInvalidSocket + } + wsUser.Lock() for _, socket := range wsUser.Sockets { + if socket == nil { + continue + } if socket.conn == conn { socket.Page = page } } wsUser.Unlock() + + return nil } func (wsUser *WSUser) InPage(page string) bool { wsUser.Lock() defer wsUser.Unlock() for _, socket := range wsUser.Sockets { + if socket == nil { + continue + } if socket.Page == page { return true } @@ -129,6 +145,9 @@ func (wsUser *WSUser) FinalizePage(page string, handle func()) { wsUser.Lock() defer wsUser.Unlock() for _, socket := range wsUser.Sockets { + if socket == nil { + continue + } if socket.Page == page { return }