Use ensureAuthenticated as middleware

This commit is contained in:
Asher 2020-11-03 16:45:03 -06:00
parent 476379a77e
commit 34225e2bdf
No known key found for this signature in database
GPG Key ID: D63C1EF81242354A
5 changed files with 40 additions and 34 deletions

View File

@ -34,12 +34,15 @@ export const replaceTemplates = <T extends object>(
} }
/** /**
* Throw an error if not authorized. * Throw an error if not authorized. Call `next` if provided.
*/ */
export const ensureAuthenticated = (req: express.Request): void => { export const ensureAuthenticated = (req: express.Request, _?: express.Response, next?: express.NextFunction): void => {
if (!authenticated(req)) { if (!authenticated(req)) {
throw new HttpError("Unauthorized", HttpCode.Unauthorized) throw new HttpError("Unauthorized", HttpCode.Unauthorized)
} }
if (next) {
next()
}
} }
/** /**
@ -136,20 +139,32 @@ export const getCookieDomain = (host: string, proxyDomains: string[]): string |
declare module "express" { declare module "express" {
function Router(options?: express.RouterOptions): express.Router & WithWebsocketMethod function Router(options?: express.RouterOptions): express.Router & WithWebsocketMethod
type WebsocketRequestHandler = ( type WebSocketRequestHandler = (
socket: net.Socket, req: express.Request & WithWebSocket,
head: Buffer, res: express.Response,
req: express.Request,
next: express.NextFunction, next: express.NextFunction,
) => void | Promise<void> ) => void | Promise<void>
type WebsocketMethod<T> = (route: expressCore.PathParams, ...handlers: WebsocketRequestHandler[]) => T type WebSocketMethod<T> = (route: expressCore.PathParams, ...handlers: WebSocketRequestHandler[]) => T
interface WithWebSocket {
ws: net.Socket
head: Buffer
}
interface WithWebsocketMethod { interface WithWebsocketMethod {
ws: WebsocketMethod<this> ws: WebSocketMethod<this>
} }
} }
interface WebsocketRequest extends express.Request, express.WithWebSocket {
_ws_handled: boolean
}
function isWebSocketRequest(req: express.Request): req is WebsocketRequest {
return !!(req as WebsocketRequest).ws
}
export const handleUpgrade = (app: express.Express, server: http.Server): void => { export const handleUpgrade = (app: express.Express, server: http.Server): void => {
server.on("upgrade", (req, socket, head) => { server.on("upgrade", (req, socket, head) => {
socket.on("error", () => socket.destroy()) socket.on("error", () => socket.destroy())
@ -193,15 +208,15 @@ function patchRouter(): void {
// Inject the `ws` method. // Inject the `ws` method.
;(express.Router as any).ws = function ws( ;(express.Router as any).ws = function ws(
route: expressCore.PathParams, route: expressCore.PathParams,
...handlers: express.WebsocketRequestHandler[] ...handlers: express.WebSocketRequestHandler[]
) { ) {
originalGet.apply(this, [ originalGet.apply(this, [
route, route,
...handlers.map((handler) => { ...handlers.map((handler) => {
const wrapped: express.Handler = (req, _, next) => { const wrapped: express.Handler = (req, res, next) => {
if ((req as any).ws) { if (isWebSocketRequest(req)) {
;(req as any)._ws_handled = true req._ws_handled = true
Promise.resolve(handler((req as any).ws, (req as any).head, req, next)).catch(next) Promise.resolve(handler(req, res, next)).catch(next)
} else { } else {
next() next()
} }
@ -218,7 +233,7 @@ function patchRouter(): void {
route, route,
...handlers.map((handler) => { ...handlers.map((handler) => {
const wrapped: express.Handler = (req, res, next) => { const wrapped: express.Handler = (req, res, next) => {
if (!(req as any).ws) { if (!isWebSocketRequest(req)) {
Promise.resolve(handler(req, res, next)).catch(next) Promise.resolve(handler(req, res, next)).catch(next)
} else { } else {
next() next()

View File

@ -82,7 +82,7 @@ router.all("*", (req, res, next) => {
}) })
}) })
router.ws("*", (socket, head, req, next) => { router.ws("*", (req, _, next) => {
const port = maybeProxy(req) const port = maybeProxy(req)
if (!port) { if (!port) {
return next() return next()
@ -91,7 +91,7 @@ router.ws("*", (socket, head, req, next) => {
// Must be authenticated to use the proxy. // Must be authenticated to use the proxy.
ensureAuthenticated(req) ensureAuthenticated(req)
proxy.ws(req, socket, head, { proxy.ws(req, req.ws, req.head, {
ignorePath: true, ignorePath: true,
target: `http://0.0.0.0:${port}${req.originalUrl}`, target: `http://0.0.0.0:${port}${req.originalUrl}`,
}) })

View File

@ -35,8 +35,8 @@ router.all("/(:port)(/*)?", (req, res) => {
}) })
}) })
router.ws("/(:port)(/*)?", (socket, head, req) => { router.ws("/(:port)(/*)?", (req) => {
proxy.ws(req, socket, head, { proxy.ws(req, req.ws, req.head, {
ignorePath: true, ignorePath: true,
target: getProxyTarget(req, true), target: getProxyTarget(req, true),
}) })

View File

@ -7,12 +7,7 @@ export const router = Router()
const provider = new UpdateProvider() const provider = new UpdateProvider()
router.use((req, _, next) => { router.get("/", ensureAuthenticated, async (req, res) => {
ensureAuthenticated(req)
next()
})
router.get("/", async (req, res) => {
const update = await provider.getUpdate(req.query.force === "true") const update = await provider.getUpdate(req.query.force === "true")
res.json({ res.json({
checked: update.checked, checked: update.checked,

View File

@ -53,14 +53,13 @@ router.get("/", async (req, res) => {
) )
}) })
router.ws("/", async (socket, _, req) => { router.ws("/", ensureAuthenticated, async (req) => {
ensureAuthenticated(req)
const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
const reply = crypto const reply = crypto
.createHash("sha1") .createHash("sha1")
.update(req.headers["sec-websocket-key"] + magic) .update(req.headers["sec-websocket-key"] + magic)
.digest("base64") .digest("base64")
socket.write( req.ws.write(
[ [
"HTTP/1.1 101 Switching Protocols", "HTTP/1.1 101 Switching Protocols",
"Upgrade: websocket", "Upgrade: websocket",
@ -68,14 +67,13 @@ router.ws("/", async (socket, _, req) => {
`Sec-WebSocket-Accept: ${reply}`, `Sec-WebSocket-Accept: ${reply}`,
].join("\r\n") + "\r\n\r\n", ].join("\r\n") + "\r\n\r\n",
) )
await vscode.sendWebsocket(socket, req.query) await vscode.sendWebsocket(req.ws, req.query)
}) })
/** /**
* TODO: Might currently be unused. * TODO: Might currently be unused.
*/ */
router.get("/resource(/*)?", async (req, res) => { router.get("/resource(/*)?", ensureAuthenticated, async (req, res) => {
ensureAuthenticated(req)
if (typeof req.query.path === "string") { if (typeof req.query.path === "string") {
res.set("Content-Type", getMediaMime(req.query.path)) res.set("Content-Type", getMediaMime(req.query.path))
res.send(await fs.readFile(pathToFsPath(req.query.path))) res.send(await fs.readFile(pathToFsPath(req.query.path)))
@ -85,8 +83,7 @@ router.get("/resource(/*)?", async (req, res) => {
/** /**
* Used by VS Code to load files. * Used by VS Code to load files.
*/ */
router.get("/vscode-remote-resource(/*)?", async (req, res) => { router.get("/vscode-remote-resource(/*)?", ensureAuthenticated, async (req, res) => {
ensureAuthenticated(req)
if (typeof req.query.path === "string") { if (typeof req.query.path === "string") {
res.set("Content-Type", getMediaMime(req.query.path)) res.set("Content-Type", getMediaMime(req.query.path))
res.send(await fs.readFile(pathToFsPath(req.query.path))) res.send(await fs.readFile(pathToFsPath(req.query.path)))
@ -97,8 +94,7 @@ router.get("/vscode-remote-resource(/*)?", async (req, res) => {
* VS Code webviews use these paths to load files and to load webview assets * VS Code webviews use these paths to load files and to load webview assets
* like HTML and JavaScript. * like HTML and JavaScript.
*/ */
router.get("/webview/*", async (req, res) => { router.get("/webview/*", ensureAuthenticated, async (req, res) => {
ensureAuthenticated(req)
res.set("Content-Type", getMediaMime(req.path)) res.set("Content-Type", getMediaMime(req.path))
if (/^vscode-resource/.test(req.params[0])) { if (/^vscode-resource/.test(req.params[0])) {
return res.send(await fs.readFile(req.params[0].replace(/^vscode-resource(\/file)?/, ""))) return res.send(await fs.readFile(req.params[0].replace(/^vscode-resource(\/file)?/, "")))