diff --git a/superset-websocket/src/config.ts b/superset-websocket/src/config.ts index 5d2642b4e9ac6..7d0fac323e975 100644 --- a/superset-websocket/src/config.ts +++ b/superset-websocket/src/config.ts @@ -38,6 +38,7 @@ type ConfigType = { redisStreamReadBlockMs: number; jwtSecret: string; jwtCookieName: string; + jwtChannelIdKey: string; socketResponseTimeoutMs: number; pingSocketsIntervalMs: number; gcChannelsIntervalMs: number; @@ -54,6 +55,7 @@ function defaultConfig(): ConfigType { redisStreamReadBlockMs: 5000, jwtSecret: '', jwtCookieName: 'async-token', + jwtChannelIdKey: 'channel', socketResponseTimeoutMs: 60 * 1000, pingSocketsIntervalMs: 20 * 1000, gcChannelsIntervalMs: 120 * 1000, diff --git a/superset-websocket/src/index.ts b/superset-websocket/src/index.ts index ecb20a4458c09..782275e5ca53a 100644 --- a/superset-websocket/src/index.ts +++ b/superset-websocket/src/index.ts @@ -53,7 +53,7 @@ interface EventValue { result_url?: string; } interface JwtPayload { - channel: string; + [key: string]: string; } interface FetchRangeFromStreamParams { sessionId: string; @@ -253,14 +253,20 @@ export const processStreamResults = (results: StreamResult[]): void => { /** * Verify and parse a JWT cookie from an HTTP request. - * Returns the JWT payload or throws an error on invalid token. + * Returns the channelId from the JWT payload found in the cookie + * configured via 'jwtCookieName' in the config. */ -const getJwtPayload = (request: http.IncomingMessage): JwtPayload => { +const readChannelId = (request: http.IncomingMessage): string => { const cookies = cookie.parse(request.headers.cookie || ''); const token = cookies[opts.jwtCookieName]; if (!token) throw new Error('JWT not present'); - return jwt.verify(token, opts.jwtSecret) as JwtPayload; + const jwtPayload = jwt.verify(token, opts.jwtSecret) as JwtPayload; + const channelId = jwtPayload[opts.jwtChannelIdKey]; + + if (!channelId) throw new Error('Channel ID not present in JWT'); + + return channelId; }; /** @@ -286,8 +292,7 @@ export const incrementId = (id: string): string => { * WebSocket `connection` event handler, called via wss */ export const wsConnection = (ws: WebSocket, request: http.IncomingMessage) => { - const jwtPayload: JwtPayload = getJwtPayload(request); - const channel: string = jwtPayload.channel; + const channel: string = readChannelId(request); const socketInstance: SocketInstance = { ws, channel, pongTs: Date.now() }; // add this ws instance to the internal registry @@ -351,8 +356,7 @@ export const httpUpgrade = ( head: Buffer, ) => { try { - const jwtPayload: JwtPayload = getJwtPayload(request); - if (!jwtPayload.channel) throw new Error('Channel ID not present'); + readChannelId(request); } catch (err) { // JWT invalid, do not establish a WebSocket connection logger.error(err);