summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorEmelia Smith <ThisIsMissEm@users.noreply.github.com>2024-01-15 11:36:30 +0100
committerGitHub <noreply@github.com>2024-01-15 10:36:30 +0000
commit58830be94329ce1a541f1f79c9e4e70aef430b19 (patch)
tree6075fa38e564f87cc5d7909bec82c281daa5de93
parente72676e83a3c638ddacbb89cfe15fe9b1b83b92f (diff)
Streaming: Rework websocket server initialisation & authentication code (#28631)
-rw-r--r--streaming/index.js128
1 files changed, 95 insertions, 33 deletions
diff --git a/streaming/index.js b/streaming/index.js
index 42d0afc7c5b..c8124fcc0f1 100644
--- a/streaming/index.js
+++ b/streaming/index.js
@@ -182,14 +182,74 @@ const CHANNEL_NAMES = [
];
const startServer = async () => {
+ const pgPool = new pg.Pool(pgConfigFromEnv(process.env));
+ const server = http.createServer();
+ const wss = new WebSocket.Server({ noServer: true });
+
+ // Set the X-Request-Id header on WebSockets:
+ wss.on("headers", function onHeaders(headers, req) {
+ headers.push(`X-Request-Id: ${req.id}`);
+ });
+
const app = express();
app.set('trust proxy', process.env.TRUSTED_PROXY_IP ? process.env.TRUSTED_PROXY_IP.split(/(?:\s*,\s*|\s+)/) : 'loopback,uniquelocal');
- const pgPool = new pg.Pool(pgConfigFromEnv(process.env));
- const server = http.createServer(app);
app.use(cors());
+ // Handle eventsource & other http requests:
+ server.on('request', app);
+
+ // Handle upgrade requests:
+ server.on('upgrade', async function handleUpgrade(request, socket, head) {
+ /** @param {Error} err */
+ const onSocketError = (err) => {
+ log.error(`Error with websocket upgrade: ${err}`);
+ };
+
+ socket.on('error', onSocketError);
+
+ // Authenticate:
+ try {
+ await accountFromRequest(request);
+ } catch (err) {
+ log.error(`Error authenticating request: ${err}`);
+
+ // Unfortunately for using the on('upgrade') setup, we need to manually
+ // write a HTTP Response to the Socket to close the connection upgrade
+ // attempt, so the following code is to handle all of that.
+ const statusCode = err.status ?? 401;
+
+ /** @type {Record<string, string | number>} */
+ const headers = {
+ 'Connection': 'close',
+ 'Content-Type': 'text/plain',
+ 'Content-Length': 0,
+ 'X-Request-Id': request.id,
+ // TODO: Send the error message via header so it can be debugged in
+ // developer tools
+ };
+
+ // Ensure the socket is closed once we've finished writing to it:
+ socket.once('finish', () => {
+ socket.destroy();
+ });
+
+ // Write the HTTP response manually:
+ socket.end(`HTTP/1.1 ${statusCode} ${http.STATUS_CODES[statusCode]}\r\n${Object.keys(headers).map((key) => `${key}: ${headers[key]}`).join('\r\n')}\r\n\r\n`);
+
+ return;
+ }
+
+ wss.handleUpgrade(request, socket, head, function done(ws) {
+ // Remove the error handler:
+ socket.removeListener('error', onSocketError);
+
+ // Start the connection:
+ wss.emit('connection', ws, request);
+ });
+ });
+
/**
* @type {Object.<string, Array.<function(Object<string, any>): void>>}
*/
@@ -361,9 +421,18 @@ const startServer = async () => {
req.scopes.some(scope => necessaryScopes.includes(scope));
/**
+ * @typedef ResolvedAccount
+ * @property {string} accessTokenId
+ * @property {string[]} scopes
+ * @property {string} accountId
+ * @property {string[]} chosenLanguages
+ * @property {string} deviceId
+ */
+
+ /**
* @param {string} token
* @param {any} req
- * @returns {Promise.<void>}
+ * @returns {Promise<ResolvedAccount>}
*/
const accountFromToken = (token, req) => new Promise((resolve, reject) => {
pgPool.connect((err, client, done) => {
@@ -394,14 +463,20 @@ const startServer = async () => {
req.chosenLanguages = result.rows[0].chosen_languages;
req.deviceId = result.rows[0].device_id;
- resolve();
+ resolve({
+ accessTokenId: result.rows[0].id,
+ scopes: result.rows[0].scopes.split(' '),
+ accountId: result.rows[0].account_id,
+ chosenLanguages: result.rows[0].chosen_languages,
+ deviceId: result.rows[0].device_id
+ });
});
});
});
/**
* @param {any} req
- * @returns {Promise.<void>}
+ * @returns {Promise<ResolvedAccount>}
*/
const accountFromRequest = (req) => new Promise((resolve, reject) => {
const authorization = req.headers.authorization;
@@ -495,25 +570,6 @@ const startServer = async () => {
});
/**
- * @param {any} info
- * @param {function(boolean, number, string): void} callback
- */
- const wsVerifyClient = (info, callback) => {
- // When verifying the websockets connection, we no longer pre-emptively
- // check OAuth scopes and drop the connection if they're missing. We only
- // drop the connection if access without token is not allowed by environment
- // variables. OAuth scope checks are moved to the point of subscription
- // to a specific stream.
-
- accountFromRequest(info.req).then(() => {
- callback(true, undefined, undefined);
- }).catch(err => {
- log.error(info.req.requestId, err.toString());
- callback(false, 401, 'Unauthorized');
- });
- };
-
- /**
* @typedef SystemMessageHandlers
* @property {function(): void} onKill
*/
@@ -944,8 +1000,8 @@ const startServer = async () => {
};
/**
- * @param {any} req
- * @param {any} ws
+ * @param {http.IncomingMessage} req
+ * @param {WebSocket} ws
* @param {string[]} streamName
* @returns {function(string, string): void}
*/
@@ -955,7 +1011,9 @@ const startServer = async () => {
return;
}
- ws.send(JSON.stringify({ stream: streamName, event, payload }), (err) => {
+ const message = JSON.stringify({ stream: streamName, event, payload });
+
+ ws.send(message, (/** @type {Error} */ err) => {
if (err) {
log.error(req.requestId, `Failed to send to websocket: ${err}`);
}
@@ -992,8 +1050,6 @@ const startServer = async () => {
});
});
- const wss = new WebSocket.Server({ server, verifyClient: wsVerifyClient });
-
/**
* @typedef StreamParams
* @property {string} [tag]
@@ -1173,8 +1229,8 @@ const startServer = async () => {
/**
* @typedef WebSocketSession
- * @property {any} socket
- * @property {any} request
+ * @property {WebSocket} websocket
+ * @property {http.IncomingMessage} request
* @property {Object.<string, { channelName: string, listener: SubscriptionListener, stopHeartbeat: function(): void }>} subscriptions
*/
@@ -1297,7 +1353,11 @@ const startServer = async () => {
}
};
- wss.on('connection', (ws, req) => {
+ /**
+ * @param {WebSocket & { isAlive: boolean }} ws
+ * @param {http.IncomingMessage} req
+ */
+ function onConnection(ws, req) {
// Note: url.parse could throw, which would terminate the connection, so we
// increment the connected clients metric straight away when we establish
// the connection, without waiting:
@@ -1375,7 +1435,9 @@ const startServer = async () => {
if (location && location.query.stream) {
subscribeWebsocketToChannel(session, firstParam(location.query.stream), location.query);
}
- });
+ }
+
+ wss.on('connection', onConnection);
setInterval(() => {
wss.clients.forEach(ws => {