Move websocket routes into a separate app
This is mostly so we don't have to do any wacky patching but it also makes it so we don't have to keep checking if the request is a web socket request every time we add middleware.
This commit is contained in:
parent
9e09c1f92b
commit
7b2752a62c
|
@ -4,12 +4,12 @@ import { promises as fs } from "fs"
|
||||||
import http from "http"
|
import http from "http"
|
||||||
import * as httpolyglot from "httpolyglot"
|
import * as httpolyglot from "httpolyglot"
|
||||||
import { DefaultedArgs } from "./cli"
|
import { DefaultedArgs } from "./cli"
|
||||||
import { handleUpgrade } from "./http"
|
import { handleUpgrade } from "./wsRouter"
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create an Express app and an HTTP/S server to serve it.
|
* Create an Express app and an HTTP/S server to serve it.
|
||||||
*/
|
*/
|
||||||
export const createApp = async (args: DefaultedArgs): Promise<[Express, http.Server]> => {
|
export const createApp = async (args: DefaultedArgs): Promise<[Express, Express, http.Server]> => {
|
||||||
const app = express()
|
const app = express()
|
||||||
|
|
||||||
const server = args.cert
|
const server = args.cert
|
||||||
|
@ -39,9 +39,10 @@ export const createApp = async (args: DefaultedArgs): Promise<[Express, http.Ser
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
handleUpgrade(app, server)
|
const wsApp = express()
|
||||||
|
handleUpgrade(wsApp, server)
|
||||||
|
|
||||||
return [app, server]
|
return [app, wsApp, server]
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -102,9 +102,9 @@ const main = async (args: DefaultedArgs): Promise<void> => {
|
||||||
throw new Error("Please pass in a password via the config file or $PASSWORD")
|
throw new Error("Please pass in a password via the config file or $PASSWORD")
|
||||||
}
|
}
|
||||||
|
|
||||||
const [app, server] = await createApp(args)
|
const [app, wsApp, server] = await createApp(args)
|
||||||
const serverAddress = ensureAddress(server)
|
const serverAddress = ensureAddress(server)
|
||||||
await register(app, server, args)
|
await register(app, wsApp, server, args)
|
||||||
|
|
||||||
logger.info(`Using config file ${humanPath(args.config)}`)
|
logger.info(`Using config file ${humanPath(args.config)}`)
|
||||||
logger.info(`HTTP server listening on ${serverAddress} ${args.link ? "(randomized by --link)" : ""}`)
|
logger.info(`HTTP server listening on ${serverAddress} ${args.link ? "(randomized by --link)" : ""}`)
|
||||||
|
|
110
src/node/http.ts
110
src/node/http.ts
|
@ -1,8 +1,6 @@
|
||||||
import { field, logger } from "@coder/logger"
|
import { field, logger } from "@coder/logger"
|
||||||
import * as express from "express"
|
import * as express from "express"
|
||||||
import * as expressCore from "express-serve-static-core"
|
import * as expressCore from "express-serve-static-core"
|
||||||
import * as http from "http"
|
|
||||||
import * as net from "net"
|
|
||||||
import qs from "qs"
|
import qs from "qs"
|
||||||
import safeCompare from "safe-compare"
|
import safeCompare from "safe-compare"
|
||||||
import { HttpCode, HttpError } from "../common/http"
|
import { HttpCode, HttpError } from "../common/http"
|
||||||
|
@ -135,111 +133,3 @@ export const getCookieDomain = (host: string, proxyDomains: string[]): string |
|
||||||
logger.debug("got cookie doman", field("host", host))
|
logger.debug("got cookie doman", field("host", host))
|
||||||
return host || undefined
|
return host || undefined
|
||||||
}
|
}
|
||||||
|
|
||||||
declare module "express" {
|
|
||||||
function Router(options?: express.RouterOptions): express.Router & WithWebsocketMethod
|
|
||||||
|
|
||||||
type WebSocketRequestHandler = (
|
|
||||||
req: express.Request & WithWebSocket,
|
|
||||||
res: express.Response,
|
|
||||||
next: express.NextFunction,
|
|
||||||
) => void | Promise<void>
|
|
||||||
|
|
||||||
type WebSocketMethod<T> = (route: expressCore.PathParams, ...handlers: WebSocketRequestHandler[]) => T
|
|
||||||
|
|
||||||
interface WithWebSocket {
|
|
||||||
ws: net.Socket
|
|
||||||
head: Buffer
|
|
||||||
}
|
|
||||||
|
|
||||||
interface WithWebsocketMethod {
|
|
||||||
ws: WebSocketMethod<this>
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
interface WebsocketRequest extends express.Request, express.WithWebSocket {
|
|
||||||
_ws_handled: boolean
|
|
||||||
}
|
|
||||||
|
|
||||||
function isWebSocketRequest(req: express.Request): req is WebsocketRequest {
|
|
||||||
return !!(req as WebsocketRequest).ws
|
|
||||||
}
|
|
||||||
|
|
||||||
export const handleUpgrade = (app: express.Express, server: http.Server): void => {
|
|
||||||
server.on("upgrade", (req, socket, head) => {
|
|
||||||
socket.on("error", () => socket.destroy())
|
|
||||||
|
|
||||||
req.ws = socket
|
|
||||||
req.head = head
|
|
||||||
req._ws_handled = false
|
|
||||||
|
|
||||||
const res = new http.ServerResponse(req)
|
|
||||||
res.writeHead = function writeHead(statusCode: number) {
|
|
||||||
if (statusCode > 200) {
|
|
||||||
socket.destroy(new Error(`${statusCode}`))
|
|
||||||
}
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send the request off to be handled by Express.
|
|
||||||
;(app as any).handle(req, res, () => {
|
|
||||||
if (!req._ws_handled) {
|
|
||||||
socket.destroy(new Error("Not found"))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Patch Express routers to handle web sockets.
|
|
||||||
*
|
|
||||||
* Not using express-ws since the ws-wrapped sockets don't work with the proxy.
|
|
||||||
*/
|
|
||||||
function patchRouter(): void {
|
|
||||||
// This works because Router is also the prototype assigned to the routers it
|
|
||||||
// returns.
|
|
||||||
|
|
||||||
// Store this since the original method will be overridden.
|
|
||||||
const originalGet = (express.Router as any).prototype.get
|
|
||||||
|
|
||||||
// Inject the `ws` method.
|
|
||||||
;(express.Router as any).prototype.ws = function ws(
|
|
||||||
route: expressCore.PathParams,
|
|
||||||
...handlers: express.WebSocketRequestHandler[]
|
|
||||||
) {
|
|
||||||
originalGet.apply(this, [
|
|
||||||
route,
|
|
||||||
...handlers.map((handler) => {
|
|
||||||
const wrapped: express.Handler = (req, res, next) => {
|
|
||||||
if (isWebSocketRequest(req)) {
|
|
||||||
req._ws_handled = true
|
|
||||||
return handler(req, res, next)
|
|
||||||
}
|
|
||||||
next()
|
|
||||||
}
|
|
||||||
return wrapped
|
|
||||||
}),
|
|
||||||
])
|
|
||||||
return this
|
|
||||||
}
|
|
||||||
// Overwrite `get` so we can distinguish between websocket and non-websocket
|
|
||||||
// routes.
|
|
||||||
;(express.Router as any).prototype.get = function get(route: expressCore.PathParams, ...handlers: express.Handler[]) {
|
|
||||||
originalGet.apply(this, [
|
|
||||||
route,
|
|
||||||
...handlers.map((handler) => {
|
|
||||||
const wrapped: express.Handler = (req, res, next) => {
|
|
||||||
if (!isWebSocketRequest(req)) {
|
|
||||||
return handler(req, res, next)
|
|
||||||
}
|
|
||||||
next()
|
|
||||||
}
|
|
||||||
return wrapped
|
|
||||||
}),
|
|
||||||
])
|
|
||||||
return this
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// This needs to happen before anything creates a router.
|
|
||||||
patchRouter()
|
|
||||||
|
|
|
@ -2,6 +2,7 @@ import { Request, Router } from "express"
|
||||||
import proxyServer from "http-proxy"
|
import proxyServer from "http-proxy"
|
||||||
import { HttpCode, HttpError } from "../common/http"
|
import { HttpCode, HttpError } from "../common/http"
|
||||||
import { authenticated, ensureAuthenticated } from "./http"
|
import { authenticated, ensureAuthenticated } from "./http"
|
||||||
|
import { Router as WsRouter } from "./wsRouter"
|
||||||
|
|
||||||
export const proxy = proxyServer.createProxyServer({})
|
export const proxy = proxyServer.createProxyServer({})
|
||||||
proxy.on("error", (error, _, res) => {
|
proxy.on("error", (error, _, res) => {
|
||||||
|
@ -82,7 +83,9 @@ router.all("*", (req, res, next) => {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
router.ws("*", (req, _, next) => {
|
export const wsRouter = WsRouter()
|
||||||
|
|
||||||
|
wsRouter.ws("*", (req, _, next) => {
|
||||||
const port = maybeProxy(req)
|
const port = maybeProxy(req)
|
||||||
if (!port) {
|
if (!port) {
|
||||||
return next()
|
return next()
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import { logger } from "@coder/logger"
|
import { logger } from "@coder/logger"
|
||||||
import bodyParser from "body-parser"
|
import bodyParser from "body-parser"
|
||||||
import cookieParser from "cookie-parser"
|
import cookieParser from "cookie-parser"
|
||||||
import { ErrorRequestHandler, Express } from "express"
|
import * as express from "express"
|
||||||
import { promises as fs } from "fs"
|
import { promises as fs } from "fs"
|
||||||
import http from "http"
|
import http from "http"
|
||||||
import * as path from "path"
|
import * as path from "path"
|
||||||
|
@ -15,6 +15,7 @@ import { replaceTemplates } from "../http"
|
||||||
import { loadPlugins } from "../plugin"
|
import { loadPlugins } from "../plugin"
|
||||||
import * as domainProxy from "../proxy"
|
import * as domainProxy from "../proxy"
|
||||||
import { getMediaMime, paths } from "../util"
|
import { getMediaMime, paths } from "../util"
|
||||||
|
import { WebsocketRequest } from "../wsRouter"
|
||||||
import * as health from "./health"
|
import * as health from "./health"
|
||||||
import * as login from "./login"
|
import * as login from "./login"
|
||||||
import * as proxy from "./proxy"
|
import * as proxy from "./proxy"
|
||||||
|
@ -36,7 +37,12 @@ declare global {
|
||||||
/**
|
/**
|
||||||
* Register all routes and middleware.
|
* Register all routes and middleware.
|
||||||
*/
|
*/
|
||||||
export const register = async (app: Express, server: http.Server, args: DefaultedArgs): Promise<void> => {
|
export const register = async (
|
||||||
|
app: express.Express,
|
||||||
|
wsApp: express.Express,
|
||||||
|
server: http.Server,
|
||||||
|
args: DefaultedArgs,
|
||||||
|
): Promise<void> => {
|
||||||
const heart = new Heart(path.join(paths.data, "heartbeat"), async () => {
|
const heart = new Heart(path.join(paths.data, "heartbeat"), async () => {
|
||||||
return new Promise((resolve, reject) => {
|
return new Promise((resolve, reject) => {
|
||||||
server.getConnections((error, count) => {
|
server.getConnections((error, count) => {
|
||||||
|
@ -50,14 +56,28 @@ export const register = async (app: Express, server: http.Server, args: Defaulte
|
||||||
})
|
})
|
||||||
|
|
||||||
app.disable("x-powered-by")
|
app.disable("x-powered-by")
|
||||||
|
wsApp.disable("x-powered-by")
|
||||||
|
|
||||||
app.use(cookieParser())
|
app.use(cookieParser())
|
||||||
|
wsApp.use(cookieParser())
|
||||||
|
|
||||||
app.use(bodyParser.json())
|
app.use(bodyParser.json())
|
||||||
app.use(bodyParser.urlencoded({ extended: true }))
|
app.use(bodyParser.urlencoded({ extended: true }))
|
||||||
|
|
||||||
app.use(async (req, res, next) => {
|
const common: express.RequestHandler = (req, _, next) => {
|
||||||
heart.beat()
|
heart.beat()
|
||||||
|
|
||||||
|
// Add common variables routes can use.
|
||||||
|
req.args = args
|
||||||
|
req.heart = heart
|
||||||
|
|
||||||
|
next()
|
||||||
|
}
|
||||||
|
|
||||||
|
app.use(common)
|
||||||
|
wsApp.use(common)
|
||||||
|
|
||||||
|
app.use(async (req, res, next) => {
|
||||||
// If we're handling TLS ensure all requests are redirected to HTTPS.
|
// If we're handling TLS ensure all requests are redirected to HTTPS.
|
||||||
// TODO: This does *NOT* work if you have a base path since to specify the
|
// TODO: This does *NOT* work if you have a base path since to specify the
|
||||||
// protocol we need to specify the whole path.
|
// protocol we need to specify the whole path.
|
||||||
|
@ -72,23 +92,28 @@ export const register = async (app: Express, server: http.Server, args: Defaulte
|
||||||
return res.send(await fs.readFile(resourcePath))
|
return res.send(await fs.readFile(resourcePath))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add common variables routes can use.
|
next()
|
||||||
req.args = args
|
|
||||||
req.heart = heart
|
|
||||||
|
|
||||||
return next()
|
|
||||||
})
|
})
|
||||||
|
|
||||||
app.use("/", domainProxy.router)
|
app.use("/", domainProxy.router)
|
||||||
|
wsApp.use("/", domainProxy.wsRouter.router)
|
||||||
|
|
||||||
app.use("/", vscode.router)
|
app.use("/", vscode.router)
|
||||||
|
wsApp.use("/", vscode.wsRouter.router)
|
||||||
|
app.use("/vscode", vscode.router)
|
||||||
|
wsApp.use("/vscode", vscode.wsRouter.router)
|
||||||
|
|
||||||
app.use("/healthz", health.router)
|
app.use("/healthz", health.router)
|
||||||
|
|
||||||
if (args.auth === AuthType.Password) {
|
if (args.auth === AuthType.Password) {
|
||||||
app.use("/login", login.router)
|
app.use("/login", login.router)
|
||||||
}
|
}
|
||||||
|
|
||||||
app.use("/proxy", proxy.router)
|
app.use("/proxy", proxy.router)
|
||||||
|
wsApp.use("/proxy", proxy.wsRouter.router)
|
||||||
|
|
||||||
app.use("/static", _static.router)
|
app.use("/static", _static.router)
|
||||||
app.use("/update", update.router)
|
app.use("/update", update.router)
|
||||||
app.use("/vscode", vscode.router)
|
|
||||||
|
|
||||||
await loadPlugins(app, args)
|
await loadPlugins(app, args)
|
||||||
|
|
||||||
|
@ -96,7 +121,7 @@ export const register = async (app: Express, server: http.Server, args: Defaulte
|
||||||
throw new HttpError("Not Found", HttpCode.NotFound)
|
throw new HttpError("Not Found", HttpCode.NotFound)
|
||||||
})
|
})
|
||||||
|
|
||||||
const errorHandler: ErrorRequestHandler = async (err, req, res, next) => {
|
const errorHandler: express.ErrorRequestHandler = async (err, req, res, next) => {
|
||||||
const resourcePath = path.resolve(rootPath, "src/browser/pages/error.html")
|
const resourcePath = path.resolve(rootPath, "src/browser/pages/error.html")
|
||||||
res.set("Content-Type", getMediaMime(resourcePath))
|
res.set("Content-Type", getMediaMime(resourcePath))
|
||||||
try {
|
try {
|
||||||
|
@ -117,4 +142,11 @@ export const register = async (app: Express, server: http.Server, args: Defaulte
|
||||||
}
|
}
|
||||||
|
|
||||||
app.use(errorHandler)
|
app.use(errorHandler)
|
||||||
|
|
||||||
|
const wsErrorHandler: express.ErrorRequestHandler = async (err, req) => {
|
||||||
|
logger.error(`${err.message} ${err.stack}`)
|
||||||
|
;(req as WebsocketRequest).ws.destroy(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
wsApp.use(wsErrorHandler)
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ import qs from "qs"
|
||||||
import { HttpCode, HttpError } from "../../common/http"
|
import { HttpCode, HttpError } from "../../common/http"
|
||||||
import { authenticated, redirect } from "../http"
|
import { authenticated, redirect } from "../http"
|
||||||
import { proxy } from "../proxy"
|
import { proxy } from "../proxy"
|
||||||
|
import { Router as WsRouter } from "../wsRouter"
|
||||||
|
|
||||||
export const router = Router()
|
export const router = Router()
|
||||||
|
|
||||||
|
@ -35,7 +36,9 @@ router.all("/(:port)(/*)?", (req, res) => {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
router.ws("/(:port)(/*)?", (req) => {
|
export const wsRouter = WsRouter()
|
||||||
|
|
||||||
|
wsRouter.ws("/(:port)(/*)?", (req) => {
|
||||||
proxy.ws(req, req.ws, req.head, {
|
proxy.ws(req, req.ws, req.head, {
|
||||||
ignorePath: true,
|
ignorePath: true,
|
||||||
target: getProxyTarget(req, true),
|
target: getProxyTarget(req, true),
|
||||||
|
|
|
@ -6,6 +6,7 @@ import { commit, rootPath, version } from "../constants"
|
||||||
import { authenticated, ensureAuthenticated, redirect, replaceTemplates } from "../http"
|
import { authenticated, ensureAuthenticated, redirect, replaceTemplates } from "../http"
|
||||||
import { getMediaMime, pathToFsPath } from "../util"
|
import { getMediaMime, pathToFsPath } from "../util"
|
||||||
import { VscodeProvider } from "../vscode"
|
import { VscodeProvider } from "../vscode"
|
||||||
|
import { Router as WsRouter } from "../wsRouter"
|
||||||
|
|
||||||
export const router = Router()
|
export const router = Router()
|
||||||
|
|
||||||
|
@ -53,23 +54,6 @@ router.get("/", async (req, res) => {
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
router.ws("/", ensureAuthenticated, async (req) => {
|
|
||||||
const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
|
||||||
const reply = crypto
|
|
||||||
.createHash("sha1")
|
|
||||||
.update(req.headers["sec-websocket-key"] + magic)
|
|
||||||
.digest("base64")
|
|
||||||
req.ws.write(
|
|
||||||
[
|
|
||||||
"HTTP/1.1 101 Switching Protocols",
|
|
||||||
"Upgrade: websocket",
|
|
||||||
"Connection: Upgrade",
|
|
||||||
`Sec-WebSocket-Accept: ${reply}`,
|
|
||||||
].join("\r\n") + "\r\n\r\n",
|
|
||||||
)
|
|
||||||
await vscode.sendWebsocket(req.ws, req.query)
|
|
||||||
})
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* TODO: Might currently be unused.
|
* TODO: Might currently be unused.
|
||||||
*/
|
*/
|
||||||
|
@ -103,3 +87,22 @@ router.get("/webview/*", ensureAuthenticated, async (req, res) => {
|
||||||
await fs.readFile(path.join(vscode.vsRootPath, "out/vs/workbench/contrib/webview/browser/pre", req.params[0])),
|
await fs.readFile(path.join(vscode.vsRootPath, "out/vs/workbench/contrib/webview/browser/pre", req.params[0])),
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
export const wsRouter = WsRouter()
|
||||||
|
|
||||||
|
wsRouter.ws("/", ensureAuthenticated, async (req) => {
|
||||||
|
const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||||
|
const reply = crypto
|
||||||
|
.createHash("sha1")
|
||||||
|
.update(req.headers["sec-websocket-key"] + magic)
|
||||||
|
.digest("base64")
|
||||||
|
req.ws.write(
|
||||||
|
[
|
||||||
|
"HTTP/1.1 101 Switching Protocols",
|
||||||
|
"Upgrade: websocket",
|
||||||
|
"Connection: Upgrade",
|
||||||
|
`Sec-WebSocket-Accept: ${reply}`,
|
||||||
|
].join("\r\n") + "\r\n\r\n",
|
||||||
|
)
|
||||||
|
await vscode.sendWebsocket(req.ws, req.query)
|
||||||
|
})
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
import * as express from "express"
|
||||||
|
import * as expressCore from "express-serve-static-core"
|
||||||
|
import * as http from "http"
|
||||||
|
import * as net from "net"
|
||||||
|
|
||||||
|
export const handleUpgrade = (app: express.Express, server: http.Server): void => {
|
||||||
|
server.on("upgrade", (req, socket, head) => {
|
||||||
|
socket.on("error", () => socket.destroy())
|
||||||
|
|
||||||
|
req.ws = socket
|
||||||
|
req.head = head
|
||||||
|
req._ws_handled = false
|
||||||
|
|
||||||
|
// Send the request off to be handled by Express.
|
||||||
|
;(app as any).handle(req, new http.ServerResponse(req), () => {
|
||||||
|
if (!req._ws_handled) {
|
||||||
|
socket.destroy(new Error("Not found"))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface WebsocketRequest extends express.Request {
|
||||||
|
ws: net.Socket
|
||||||
|
head: Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
interface InternalWebsocketRequest extends WebsocketRequest {
|
||||||
|
_ws_handled: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
export type WebSocketHandler = (
|
||||||
|
req: WebsocketRequest,
|
||||||
|
res: express.Response,
|
||||||
|
next: express.NextFunction,
|
||||||
|
) => void | Promise<void>
|
||||||
|
|
||||||
|
export class WebsocketRouter {
|
||||||
|
public readonly router = express.Router()
|
||||||
|
|
||||||
|
public ws(route: expressCore.PathParams, ...handlers: WebSocketHandler[]): void {
|
||||||
|
this.router.get(
|
||||||
|
route,
|
||||||
|
...handlers.map((handler) => {
|
||||||
|
const wrapped: express.Handler = (req, res, next) => {
|
||||||
|
;(req as InternalWebsocketRequest)._ws_handled = true
|
||||||
|
return handler(req as WebsocketRequest, res, next)
|
||||||
|
}
|
||||||
|
return wrapped
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export function Router(): WebsocketRouter {
|
||||||
|
return new WebsocketRouter()
|
||||||
|
}
|
Loading…
Reference in New Issue