Fixed a bug where nil sockets weren't getting skipped.

Added some additional sanity checks.
This commit is contained in:
Azareal 2018-09-12 19:50:57 +10:00
parent 646c1f2545
commit 1854010840
2 changed files with 34 additions and 3 deletions

View File

@ -67,6 +67,9 @@ func RouteWebsockets(w http.ResponseWriter, r *http.Request, user User) RouteErr
} }
break break
} }
if conn == nil {
panic("conn must not be nil")
}
messages := bytes.Split(message, []byte("\r")) messages := bytes.Split(message, []byte("\r"))
for _, msg := range messages { for _, msg := range messages {
@ -89,6 +92,7 @@ func RouteWebsockets(w http.ResponseWriter, r *http.Request, user User) RouteErr
} }
} }
conn.Close() conn.Close()
DebugLog("Closing connection for user " + strconv.Itoa(user.ID))
return nil return nil
} }
@ -98,6 +102,7 @@ func wsPageResponses(wsUser *WSUser, conn *websocket.Conn, page string) {
page = Config.DefaultPath page = Config.DefaultPath
} }
DebugLog("Entering page " + page)
switch page { switch page {
// Live Topic List is an experimental feature // Live Topic List is an experimental feature
// TODO: Optimise this to reduce the amount of contention // TODO: Optimise this to reduce the amount of contention
@ -120,7 +125,10 @@ func wsPageResponses(wsUser *WSUser, conn *websocket.Conn, page string) {
default: default:
return return
} }
wsUser.SetPageForSocket(conn, page) err := wsUser.SetPageForSocket(conn, page)
if err != nil {
LogError(err)
}
} }
// TODO: Use a map instead of a switch to make this more modular? // TODO: Use a map instead of a switch to make this more modular?
@ -129,6 +137,7 @@ func wsLeavePage(wsUser *WSUser, conn *websocket.Conn, page string) {
page = Config.DefaultPath page = Config.DefaultPath
} }
DebugLog("Leaving page " + page)
switch page { switch page {
case "/topics/": case "/topics/":
wsUser.FinalizePage("/topics/", func() { wsUser.FinalizePage("/topics/", func() {
@ -141,7 +150,10 @@ func wsLeavePage(wsUser *WSUser, conn *websocket.Conn, page string) {
delete(adminStatsWatchers, conn) delete(adminStatsWatchers, conn)
adminStatsMutex.Unlock() adminStatsMutex.Unlock()
} }
wsUser.SetPageForSocket(conn, "") err := wsUser.SetPageForSocket(conn, "")
if err != nil {
LogError(err)
}
} }
// TODO: Abstract this // TODO: Abstract this

View File

@ -8,6 +8,7 @@ import (
) )
var ErrNoneOnPage = errors.New("This user isn't on that page") var ErrNoneOnPage = errors.New("This user isn't on that page")
var ErrInvalidSocket = errors.New("That's not a valid WebSocket Connection")
type WSUser struct { type WSUser struct {
User *User User *User
@ -84,6 +85,9 @@ func (wsUser *WSUser) RemoveSocket(conn *websocket.Conn) {
wsUser.Lock() wsUser.Lock()
if len(wsUser.Sockets) < 6 { if len(wsUser.Sockets) < 6 {
for i, socket := range wsUser.Sockets { for i, socket := range wsUser.Sockets {
if socket == nil {
continue
}
if socket.conn == conn { if socket.conn == conn {
wsUser.Sockets[i] = nil wsUser.Sockets[i] = nil
wsUser.Unlock() wsUser.Unlock()
@ -104,20 +108,32 @@ func (wsUser *WSUser) RemoveSocket(conn *websocket.Conn) {
wsUser.Unlock() wsUser.Unlock()
} }
func (wsUser *WSUser) SetPageForSocket(conn *websocket.Conn, page string) { func (wsUser *WSUser) SetPageForSocket(conn *websocket.Conn, page string) error {
if conn == nil {
return ErrInvalidSocket
}
wsUser.Lock() wsUser.Lock()
for _, socket := range wsUser.Sockets { for _, socket := range wsUser.Sockets {
if socket == nil {
continue
}
if socket.conn == conn { if socket.conn == conn {
socket.Page = page socket.Page = page
} }
} }
wsUser.Unlock() wsUser.Unlock()
return nil
} }
func (wsUser *WSUser) InPage(page string) bool { func (wsUser *WSUser) InPage(page string) bool {
wsUser.Lock() wsUser.Lock()
defer wsUser.Unlock() defer wsUser.Unlock()
for _, socket := range wsUser.Sockets { for _, socket := range wsUser.Sockets {
if socket == nil {
continue
}
if socket.Page == page { if socket.Page == page {
return true return true
} }
@ -129,6 +145,9 @@ func (wsUser *WSUser) FinalizePage(page string, handle func()) {
wsUser.Lock() wsUser.Lock()
defer wsUser.Unlock() defer wsUser.Unlock()
for _, socket := range wsUser.Sockets { for _, socket := range wsUser.Sockets {
if socket == nil {
continue
}
if socket.Page == page { if socket.Page == page {
return return
} }