diff --git a/superset-websocket/src/config.ts b/superset-websocket/src/config.ts index 5d2642b4e9ac6..7d0fac323e975 100644 --- a/superset-websocket/src/config.ts +++ b/superset-websocket/src/config.ts @@ -38,6 +38,7 @@ type ConfigType = { redisStreamReadBlockMs: number; jwtSecret: string; jwtCookieName: string; + jwtChannelIdKey: string; socketResponseTimeoutMs: number; pingSocketsIntervalMs: number; gcChannelsIntervalMs: number; @@ -54,6 +55,7 @@ function defaultConfig(): ConfigType { redisStreamReadBlockMs: 5000, jwtSecret: '', jwtCookieName: 'async-token', + jwtChannelIdKey: 'channel', socketResponseTimeoutMs: 60 * 1000, pingSocketsIntervalMs: 20 * 1000, gcChannelsIntervalMs: 120 * 1000, diff --git a/superset-websocket/src/index.ts b/superset-websocket/src/index.ts index ecb20a4458c09..782275e5ca53a 100644 --- a/superset-websocket/src/index.ts +++ b/superset-websocket/src/index.ts @@ -53,7 +53,7 @@ interface EventValue { result_url?: string; } interface JwtPayload { - channel: string; + [key: string]: string; } interface FetchRangeFromStreamParams { sessionId: string; @@ -253,14 +253,20 @@ export const processStreamResults = (results: StreamResult[]): void => { /** * Verify and parse a JWT cookie from an HTTP request. - * Returns the JWT payload or throws an error on invalid token. + * Returns the channelId from the JWT payload found in the cookie + * configured via 'jwtCookieName' in the config. */ -const getJwtPayload = (request: http.IncomingMessage): JwtPayload => { +const readChannelId = (request: http.IncomingMessage): string => { const cookies = cookie.parse(request.headers.cookie || ''); const token = cookies[opts.jwtCookieName]; if (!token) throw new Error('JWT not present'); - return jwt.verify(token, opts.jwtSecret) as JwtPayload; + const jwtPayload = jwt.verify(token, opts.jwtSecret) as JwtPayload; + const channelId = jwtPayload[opts.jwtChannelIdKey]; + + if (!channelId) throw new Error('Channel ID not present in JWT'); + + return channelId; }; /** @@ -286,8 +292,7 @@ export const incrementId = (id: string): string => { * WebSocket `connection` event handler, called via wss */ export const wsConnection = (ws: WebSocket, request: http.IncomingMessage) => { - const jwtPayload: JwtPayload = getJwtPayload(request); - const channel: string = jwtPayload.channel; + const channel: string = readChannelId(request); const socketInstance: SocketInstance = { ws, channel, pongTs: Date.now() }; // add this ws instance to the internal registry @@ -351,8 +356,7 @@ export const httpUpgrade = ( head: Buffer, ) => { try { - const jwtPayload: JwtPayload = getJwtPayload(request); - if (!jwtPayload.channel) throw new Error('Channel ID not present'); + readChannelId(request); } catch (err) { // JWT invalid, do not establish a WebSocket connection logger.error(err); diff --git a/superset/async_events/api.py b/superset/async_events/api.py index 0a6ceb9c5f4b1..376671cf364a7 100644 --- a/superset/async_events/api.py +++ b/superset/async_events/api.py @@ -88,9 +88,9 @@ def events(self) -> Response: $ref: '#/components/responses/500' """ try: - async_channel_id = async_query_manager.parse_jwt_from_request(request)[ - "channel" - ] + async_channel_id = async_query_manager.parse_channel_id_from_request( + request + ) last_event_id = request.args.get("last_id") events = async_query_manager.read_events(async_channel_id, last_event_id) diff --git a/superset/async_events/async_query_manager.py b/superset/async_events/async_query_manager.py index d67d9ca0817ec..94941541fb4f9 100644 --- a/superset/async_events/async_query_manager.py +++ b/superset/async_events/async_query_manager.py @@ -82,6 +82,9 @@ def __init__(self) -> None: self._jwt_cookie_domain: Optional[str] self._jwt_cookie_samesite: Optional[Literal["None", "Lax", "Strict"]] = None self._jwt_secret: str + self._load_chart_data_into_cache_job: Any = None + # pylint: disable=invalid-name + self._load_explore_json_into_cache_job: Any = None def init_app(self, app: Flask) -> None: config = app.config @@ -115,6 +118,19 @@ def init_app(self, app: Flask) -> None: self._jwt_cookie_domain = config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_DOMAIN"] self._jwt_secret = config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"] + if config["GLOBAL_ASYNC_QUERIES_REGISTER_REQUEST_HANDLERS"]: + self.register_request_handlers(app) + + # pylint: disable=import-outside-toplevel + from superset.tasks.async_queries import ( + load_chart_data_into_cache, + load_explore_json_into_cache, + ) + + self._load_chart_data_into_cache_job = load_chart_data_into_cache + self._load_explore_json_into_cache_job = load_explore_json_into_cache + + def register_request_handlers(self, app: Flask) -> None: @app.after_request def validate_session(response: Response) -> Response: user_id = get_user_id() @@ -149,13 +165,13 @@ def validate_session(response: Response) -> Response: return response - def parse_jwt_from_request(self, req: Request) -> dict[str, Any]: + def parse_channel_id_from_request(self, req: Request) -> str: token = req.cookies.get(self._jwt_cookie_name) if not token: raise AsyncQueryTokenException("Token not preset") try: - return jwt.decode(token, self._jwt_secret, algorithms=["HS256"]) + return jwt.decode(token, self._jwt_secret, algorithms=["HS256"])["channel"] except Exception as ex: logger.warning("Parse jwt failed", exc_info=True) raise AsyncQueryTokenException("Failed to parse token") from ex @@ -166,6 +182,31 @@ def init_job(self, channel_id: str, user_id: Optional[int]) -> dict[str, Any]: channel_id, job_id, user_id, status=self.STATUS_PENDING ) + # pylint: disable=too-many-arguments + def submit_explore_json_job( + self, + channel_id: str, + form_data: dict[str, Any], + response_type: str, + force: Optional[bool] = False, + user_id: Optional[int] = None, + ) -> dict[str, Any]: + job_metadata = self.init_job(channel_id, user_id) + self._load_explore_json_into_cache_job.delay( + job_metadata, + form_data, + response_type, + force, + ) + return job_metadata + + def submit_chart_data_job( + self, channel_id: str, form_data: dict[str, Any], user_id: Optional[int] + ) -> dict[str, Any]: + job_metadata = self.init_job(channel_id, user_id) + self._load_chart_data_into_cache_job.delay(job_metadata, form_data) + return job_metadata + def read_events( self, channel: str, last_id: Optional[str] ) -> list[Optional[dict[str, Any]]]: diff --git a/superset/charts/data/commands/create_async_job_command.py b/superset/charts/data/commands/create_async_job_command.py index fb6e3f3dbff34..da126277ee7da 100644 --- a/superset/charts/data/commands/create_async_job_command.py +++ b/superset/charts/data/commands/create_async_job_command.py @@ -20,7 +20,6 @@ from flask import Request from superset.extensions import async_query_manager -from superset.tasks.async_queries import load_chart_data_into_cache logger = logging.getLogger(__name__) @@ -29,10 +28,11 @@ class CreateAsyncChartDataJobCommand: _async_channel_id: str def validate(self, request: Request) -> None: - jwt_data = async_query_manager.parse_jwt_from_request(request) - self._async_channel_id = jwt_data["channel"] + self._async_channel_id = async_query_manager.parse_channel_id_from_request( + request + ) def run(self, form_data: dict[str, Any], user_id: Optional[int]) -> dict[str, Any]: - job_metadata = async_query_manager.init_job(self._async_channel_id, user_id) - load_chart_data_into_cache.delay(job_metadata, form_data) - return job_metadata + return async_query_manager.submit_chart_data_job( + self._async_channel_id, form_data, user_id + ) diff --git a/superset/cli/lib.py b/superset/cli/lib.py index 9e14ab6aae025..68f6f0383188a 100755 --- a/superset/cli/lib.py +++ b/superset/cli/lib.py @@ -26,8 +26,8 @@ feature_flags.update(config.FEATURE_FLAGS) feature_flags_func = config.GET_FEATURE_FLAGS_FUNC if feature_flags_func: - # pylint: disable=not-callable try: + # pylint: disable=not-callable feature_flags = feature_flags_func(feature_flags) except Exception: # pylint: disable=broad-except # bypass any feature flags that depend on context diff --git a/superset/config.py b/superset/config.py index afe41fe5949ef..f2daaf5dea577 100644 --- a/superset/config.py +++ b/superset/config.py @@ -1524,6 +1524,7 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX = "async-events-" GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT = 1000 GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT_FIREHOSE = 1000000 +GLOBAL_ASYNC_QUERIES_REGISTER_REQUEST_HANDLERS = True GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME = "async-token" GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SECURE = False GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SAMESITE: None | ( diff --git a/superset/views/core.py b/superset/views/core.py index 268c6fe333d74..e67a255da2850 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -74,7 +74,6 @@ from superset.models.user_attributes import UserAttribute from superset.sqllab.utils import bootstrap_sqllab_data from superset.superset_typing import FlaskResponse -from superset.tasks.async_queries import load_explore_json_into_cache from superset.utils import core as utils from superset.utils.cache import etag_cache from superset.utils.core import ( @@ -320,14 +319,11 @@ def explore_json( # at which point they will call the /explore_json/data/ # endpoint to retrieve the results. try: - async_channel_id = async_query_manager.parse_jwt_from_request( - request - )["channel"] - job_metadata = async_query_manager.init_job( - async_channel_id, get_user_id() + async_channel_id = ( + async_query_manager.parse_channel_id_from_request(request) ) - load_explore_json_into_cache.delay( - job_metadata, form_data, response_type, force + job_metadata = async_query_manager.submit_explore_json_job( + async_channel_id, form_data, response_type, force, get_user_id() ) except AsyncQueryTokenException: return json_error_response("Not authorized", 401) diff --git a/tests/integration_tests/tasks/async_queries_tests.py b/tests/integration_tests/tasks/async_queries_tests.py index 50806ee677394..8e6e595757c4f 100644 --- a/tests/integration_tests/tasks/async_queries_tests.py +++ b/tests/integration_tests/tasks/async_queries_tests.py @@ -20,18 +20,11 @@ import pytest from celery.exceptions import SoftTimeLimitExceeded -from flask import g from superset.charts.commands.exceptions import ChartDataQueryFailedError from superset.charts.data.commands.get_data_command import ChartDataCommand from superset.exceptions import SupersetException from superset.extensions import async_query_manager, security_manager -from superset.tasks import async_queries -from superset.tasks.async_queries import ( - load_chart_data_into_cache, - load_explore_json_into_cache, -) -from superset.utils.core import get_user_id from tests.integration_tests.base_tests import SupersetTestCase from tests.integration_tests.fixtures.birth_names_dashboard import ( load_birth_names_dashboard_with_slices, @@ -43,10 +36,14 @@ class TestAsyncQueries(SupersetTestCase): - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + @pytest.mark.usefixtures( + "load_birth_names_data", "load_birth_names_dashboard_with_slices" + ) @mock.patch.object(async_query_manager, "update_job") - @mock.patch.object(async_queries, "set_form_data") + @mock.patch("superset.tasks.async_queries.set_form_data") def test_load_chart_data_into_cache(self, mock_set_form_data, mock_update_job): + from superset.tasks.async_queries import load_chart_data_into_cache + app._got_first_request = False async_query_manager.init_app(app) query_context = get_query_context("birth_names") @@ -70,6 +67,8 @@ def test_load_chart_data_into_cache(self, mock_set_form_data, mock_update_job): ) @mock.patch.object(async_query_manager, "update_job") def test_load_chart_data_into_cache_error(self, mock_update_job, mock_run_command): + from superset.tasks.async_queries import load_chart_data_into_cache + app._got_first_request = False async_query_manager.init_app(app) query_context = get_query_context("birth_names") @@ -93,6 +92,8 @@ def test_load_chart_data_into_cache_error(self, mock_update_job, mock_run_comman def test_soft_timeout_load_chart_data_into_cache( self, mock_update_job, mock_run_command ): + from superset.tasks.async_queries import load_chart_data_into_cache + app._got_first_request = False async_query_manager.init_app(app) user = security_manager.find_user("gamma") @@ -107,9 +108,8 @@ def test_soft_timeout_load_chart_data_into_cache( errors = ["A timeout occurred while loading chart data"] with pytest.raises(SoftTimeLimitExceeded): - with mock.patch.object( - async_queries, - "set_form_data", + with mock.patch( + "superset.tasks.async_queries.set_form_data" ) as set_form_data: set_form_data.side_effect = SoftTimeLimitExceeded() load_chart_data_into_cache(job_metadata, form_data) @@ -118,6 +118,8 @@ def test_soft_timeout_load_chart_data_into_cache( @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @mock.patch.object(async_query_manager, "update_job") def test_load_explore_json_into_cache(self, mock_update_job): + from superset.tasks.async_queries import load_explore_json_into_cache + app._got_first_request = False async_query_manager.init_app(app) table = self.get_table(name="birth_names") @@ -146,10 +148,12 @@ def test_load_explore_json_into_cache(self, mock_update_job): ) @mock.patch.object(async_query_manager, "update_job") - @mock.patch.object(async_queries, "set_form_data") + @mock.patch("superset.tasks.async_queries.set_form_data") def test_load_explore_json_into_cache_error( self, mock_set_form_data, mock_update_job ): + from superset.tasks.async_queries import load_explore_json_into_cache + app._got_first_request = False async_query_manager.init_app(app) user = security_manager.find_user("gamma") @@ -174,6 +178,8 @@ def test_load_explore_json_into_cache_error( def test_soft_timeout_load_explore_json_into_cache( self, mock_update_job, mock_run_command ): + from superset.tasks.async_queries import load_explore_json_into_cache + app._got_first_request = False async_query_manager.init_app(app) user = security_manager.find_user("gamma") @@ -188,9 +194,8 @@ def test_soft_timeout_load_explore_json_into_cache( errors = ["A timeout occurred while loading explore json, error"] with pytest.raises(SoftTimeLimitExceeded): - with mock.patch.object( - async_queries, - "set_form_data", + with mock.patch( + "superset.tasks.async_queries.set_form_data" ) as set_form_data: set_form_data.side_effect = SoftTimeLimitExceeded() load_explore_json_into_cache(job_metadata, form_data) diff --git a/tests/unit_tests/async_events/__init__.py b/tests/unit_tests/async_events/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/unit_tests/async_events/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/async_events/async_query_manager_tests.py b/tests/unit_tests/async_events/async_query_manager_tests.py new file mode 100644 index 0000000000000..b4ae06dfc3f6f --- /dev/null +++ b/tests/unit_tests/async_events/async_query_manager_tests.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from unittest.mock import Mock + +from jwt import encode +from pytest import fixture, raises + +from superset.async_events.async_query_manager import ( + AsyncQueryManager, + AsyncQueryTokenException, +) + +JWT_TOKEN_SECRET = "some_secret" +JWT_TOKEN_COOKIE_NAME = "superset_async_jwt" + + +@fixture +def async_query_manager(): + query_manager = AsyncQueryManager() + query_manager._jwt_secret = JWT_TOKEN_SECRET + query_manager._jwt_cookie_name = JWT_TOKEN_COOKIE_NAME + + return query_manager + + +def test_parse_channel_id_from_request(async_query_manager): + encoded_token = encode( + {"channel": "test_channel_id"}, JWT_TOKEN_SECRET, algorithm="HS256" + ) + + request = Mock() + request.cookies = {"superset_async_jwt": encoded_token} + + assert ( + async_query_manager.parse_channel_id_from_request(request) == "test_channel_id" + ) + + +def test_parse_channel_id_from_request_no_cookie(async_query_manager): + request = Mock() + request.cookies = {} + + with raises(AsyncQueryTokenException): + async_query_manager.parse_channel_id_from_request(request) + + +def test_parse_channel_id_from_request_bad_jwt(async_query_manager): + request = Mock() + request.cookies = {"superset_async_jwt": "bad_jwt"} + + with raises(AsyncQueryTokenException): + async_query_manager.parse_channel_id_from_request(request)