Skip to content

Commit

Permalink
chore(async): Initial Refactoring of Global Async Queries (apache#25466)
Browse files Browse the repository at this point in the history
  • Loading branch information
craig-rueda authored Oct 3, 2023
1 parent 36ed617 commit db7f5fe
Show file tree
Hide file tree
Showing 11 changed files with 176 additions and 44 deletions.
2 changes: 2 additions & 0 deletions superset-websocket/src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ type ConfigType = {
redisStreamReadBlockMs: number;
jwtSecret: string;
jwtCookieName: string;
jwtChannelIdKey: string;
socketResponseTimeoutMs: number;
pingSocketsIntervalMs: number;
gcChannelsIntervalMs: number;
Expand All @@ -54,6 +55,7 @@ function defaultConfig(): ConfigType {
redisStreamReadBlockMs: 5000,
jwtSecret: '',
jwtCookieName: 'async-token',
jwtChannelIdKey: 'channel',
socketResponseTimeoutMs: 60 * 1000,
pingSocketsIntervalMs: 20 * 1000,
gcChannelsIntervalMs: 120 * 1000,
Expand Down
20 changes: 12 additions & 8 deletions superset-websocket/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ interface EventValue {
result_url?: string;
}
interface JwtPayload {
channel: string;
[key: string]: string;
}
interface FetchRangeFromStreamParams {
sessionId: string;
Expand Down Expand Up @@ -253,14 +253,20 @@ export const processStreamResults = (results: StreamResult[]): void => {

/**
* Verify and parse a JWT cookie from an HTTP request.
* Returns the JWT payload or throws an error on invalid token.
* Returns the channelId from the JWT payload found in the cookie
* configured via 'jwtCookieName' in the config.
*/
const getJwtPayload = (request: http.IncomingMessage): JwtPayload => {
const readChannelId = (request: http.IncomingMessage): string => {
const cookies = cookie.parse(request.headers.cookie || '');
const token = cookies[opts.jwtCookieName];

if (!token) throw new Error('JWT not present');
return jwt.verify(token, opts.jwtSecret) as JwtPayload;
const jwtPayload = jwt.verify(token, opts.jwtSecret) as JwtPayload;
const channelId = jwtPayload[opts.jwtChannelIdKey];

if (!channelId) throw new Error('Channel ID not present in JWT');

return channelId;
};

/**
Expand All @@ -286,8 +292,7 @@ export const incrementId = (id: string): string => {
* WebSocket `connection` event handler, called via wss
*/
export const wsConnection = (ws: WebSocket, request: http.IncomingMessage) => {
const jwtPayload: JwtPayload = getJwtPayload(request);
const channel: string = jwtPayload.channel;
const channel: string = readChannelId(request);
const socketInstance: SocketInstance = { ws, channel, pongTs: Date.now() };

// add this ws instance to the internal registry
Expand Down Expand Up @@ -351,8 +356,7 @@ export const httpUpgrade = (
head: Buffer,
) => {
try {
const jwtPayload: JwtPayload = getJwtPayload(request);
if (!jwtPayload.channel) throw new Error('Channel ID not present');
readChannelId(request);
} catch (err) {
// JWT invalid, do not establish a WebSocket connection
logger.error(err);
Expand Down
6 changes: 3 additions & 3 deletions superset/async_events/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ def events(self) -> Response:
$ref: '#/components/responses/500'
"""
try:
async_channel_id = async_query_manager.parse_jwt_from_request(request)[
"channel"
]
async_channel_id = async_query_manager.parse_channel_id_from_request(
request
)
last_event_id = request.args.get("last_id")
events = async_query_manager.read_events(async_channel_id, last_event_id)

Expand Down
45 changes: 43 additions & 2 deletions superset/async_events/async_query_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def __init__(self) -> None:
self._jwt_cookie_domain: Optional[str]
self._jwt_cookie_samesite: Optional[Literal["None", "Lax", "Strict"]] = None
self._jwt_secret: str
self._load_chart_data_into_cache_job: Any = None
# pylint: disable=invalid-name
self._load_explore_json_into_cache_job: Any = None

def init_app(self, app: Flask) -> None:
config = app.config
Expand Down Expand Up @@ -115,6 +118,19 @@ def init_app(self, app: Flask) -> None:
self._jwt_cookie_domain = config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_DOMAIN"]
self._jwt_secret = config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"]

if config["GLOBAL_ASYNC_QUERIES_REGISTER_REQUEST_HANDLERS"]:
self.register_request_handlers(app)

# pylint: disable=import-outside-toplevel
from superset.tasks.async_queries import (
load_chart_data_into_cache,
load_explore_json_into_cache,
)

self._load_chart_data_into_cache_job = load_chart_data_into_cache
self._load_explore_json_into_cache_job = load_explore_json_into_cache

def register_request_handlers(self, app: Flask) -> None:
@app.after_request
def validate_session(response: Response) -> Response:
user_id = get_user_id()
Expand Down Expand Up @@ -149,13 +165,13 @@ def validate_session(response: Response) -> Response:

return response

def parse_jwt_from_request(self, req: Request) -> dict[str, Any]:
def parse_channel_id_from_request(self, req: Request) -> str:
token = req.cookies.get(self._jwt_cookie_name)
if not token:
raise AsyncQueryTokenException("Token not preset")

try:
return jwt.decode(token, self._jwt_secret, algorithms=["HS256"])
return jwt.decode(token, self._jwt_secret, algorithms=["HS256"])["channel"]
except Exception as ex:
logger.warning("Parse jwt failed", exc_info=True)
raise AsyncQueryTokenException("Failed to parse token") from ex
Expand All @@ -166,6 +182,31 @@ def init_job(self, channel_id: str, user_id: Optional[int]) -> dict[str, Any]:
channel_id, job_id, user_id, status=self.STATUS_PENDING
)

# pylint: disable=too-many-arguments
def submit_explore_json_job(
self,
channel_id: str,
form_data: dict[str, Any],
response_type: str,
force: Optional[bool] = False,
user_id: Optional[int] = None,
) -> dict[str, Any]:
job_metadata = self.init_job(channel_id, user_id)
self._load_explore_json_into_cache_job.delay(
job_metadata,
form_data,
response_type,
force,
)
return job_metadata

def submit_chart_data_job(
self, channel_id: str, form_data: dict[str, Any], user_id: Optional[int]
) -> dict[str, Any]:
job_metadata = self.init_job(channel_id, user_id)
self._load_chart_data_into_cache_job.delay(job_metadata, form_data)
return job_metadata

def read_events(
self, channel: str, last_id: Optional[str]
) -> list[Optional[dict[str, Any]]]:
Expand Down
12 changes: 6 additions & 6 deletions superset/charts/data/commands/create_async_job_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from flask import Request

from superset.extensions import async_query_manager
from superset.tasks.async_queries import load_chart_data_into_cache

logger = logging.getLogger(__name__)

Expand All @@ -29,10 +28,11 @@ class CreateAsyncChartDataJobCommand:
_async_channel_id: str

def validate(self, request: Request) -> None:
jwt_data = async_query_manager.parse_jwt_from_request(request)
self._async_channel_id = jwt_data["channel"]
self._async_channel_id = async_query_manager.parse_channel_id_from_request(
request
)

def run(self, form_data: dict[str, Any], user_id: Optional[int]) -> dict[str, Any]:
job_metadata = async_query_manager.init_job(self._async_channel_id, user_id)
load_chart_data_into_cache.delay(job_metadata, form_data)
return job_metadata
return async_query_manager.submit_chart_data_job(
self._async_channel_id, form_data, user_id
)
2 changes: 1 addition & 1 deletion superset/cli/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
feature_flags.update(config.FEATURE_FLAGS)
feature_flags_func = config.GET_FEATURE_FLAGS_FUNC
if feature_flags_func:
# pylint: disable=not-callable
try:
# pylint: disable=not-callable
feature_flags = feature_flags_func(feature_flags)
except Exception: # pylint: disable=broad-except
# bypass any feature flags that depend on context
Expand Down
1 change: 1 addition & 0 deletions superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1524,6 +1524,7 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument
GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX = "async-events-"
GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT = 1000
GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT_FIREHOSE = 1000000
GLOBAL_ASYNC_QUERIES_REGISTER_REQUEST_HANDLERS = True
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME = "async-token"
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SECURE = False
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SAMESITE: None | (
Expand Down
12 changes: 4 additions & 8 deletions superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@
from superset.models.user_attributes import UserAttribute
from superset.sqllab.utils import bootstrap_sqllab_data
from superset.superset_typing import FlaskResponse
from superset.tasks.async_queries import load_explore_json_into_cache
from superset.utils import core as utils
from superset.utils.cache import etag_cache
from superset.utils.core import (
Expand Down Expand Up @@ -320,14 +319,11 @@ def explore_json(
# at which point they will call the /explore_json/data/<cache_key>
# endpoint to retrieve the results.
try:
async_channel_id = async_query_manager.parse_jwt_from_request(
request
)["channel"]
job_metadata = async_query_manager.init_job(
async_channel_id, get_user_id()
async_channel_id = (
async_query_manager.parse_channel_id_from_request(request)
)
load_explore_json_into_cache.delay(
job_metadata, form_data, response_type, force
job_metadata = async_query_manager.submit_explore_json_job(
async_channel_id, form_data, response_type, force, get_user_id()
)
except AsyncQueryTokenException:
return json_error_response("Not authorized", 401)
Expand Down
37 changes: 21 additions & 16 deletions tests/integration_tests/tasks/async_queries_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,11 @@

import pytest
from celery.exceptions import SoftTimeLimitExceeded
from flask import g

from superset.charts.commands.exceptions import ChartDataQueryFailedError
from superset.charts.data.commands.get_data_command import ChartDataCommand
from superset.exceptions import SupersetException
from superset.extensions import async_query_manager, security_manager
from superset.tasks import async_queries
from superset.tasks.async_queries import (
load_chart_data_into_cache,
load_explore_json_into_cache,
)
from superset.utils.core import get_user_id
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
Expand All @@ -43,10 +36,14 @@


class TestAsyncQueries(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@pytest.mark.usefixtures(
"load_birth_names_data", "load_birth_names_dashboard_with_slices"
)
@mock.patch.object(async_query_manager, "update_job")
@mock.patch.object(async_queries, "set_form_data")
@mock.patch("superset.tasks.async_queries.set_form_data")
def test_load_chart_data_into_cache(self, mock_set_form_data, mock_update_job):
from superset.tasks.async_queries import load_chart_data_into_cache

app._got_first_request = False
async_query_manager.init_app(app)
query_context = get_query_context("birth_names")
Expand All @@ -70,6 +67,8 @@ def test_load_chart_data_into_cache(self, mock_set_form_data, mock_update_job):
)
@mock.patch.object(async_query_manager, "update_job")
def test_load_chart_data_into_cache_error(self, mock_update_job, mock_run_command):
from superset.tasks.async_queries import load_chart_data_into_cache

app._got_first_request = False
async_query_manager.init_app(app)
query_context = get_query_context("birth_names")
Expand All @@ -93,6 +92,8 @@ def test_load_chart_data_into_cache_error(self, mock_update_job, mock_run_comman
def test_soft_timeout_load_chart_data_into_cache(
self, mock_update_job, mock_run_command
):
from superset.tasks.async_queries import load_chart_data_into_cache

app._got_first_request = False
async_query_manager.init_app(app)
user = security_manager.find_user("gamma")
Expand All @@ -107,9 +108,8 @@ def test_soft_timeout_load_chart_data_into_cache(
errors = ["A timeout occurred while loading chart data"]

with pytest.raises(SoftTimeLimitExceeded):
with mock.patch.object(
async_queries,
"set_form_data",
with mock.patch(
"superset.tasks.async_queries.set_form_data"
) as set_form_data:
set_form_data.side_effect = SoftTimeLimitExceeded()
load_chart_data_into_cache(job_metadata, form_data)
Expand All @@ -118,6 +118,8 @@ def test_soft_timeout_load_chart_data_into_cache(
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@mock.patch.object(async_query_manager, "update_job")
def test_load_explore_json_into_cache(self, mock_update_job):
from superset.tasks.async_queries import load_explore_json_into_cache

app._got_first_request = False
async_query_manager.init_app(app)
table = self.get_table(name="birth_names")
Expand Down Expand Up @@ -146,10 +148,12 @@ def test_load_explore_json_into_cache(self, mock_update_job):
)

@mock.patch.object(async_query_manager, "update_job")
@mock.patch.object(async_queries, "set_form_data")
@mock.patch("superset.tasks.async_queries.set_form_data")
def test_load_explore_json_into_cache_error(
self, mock_set_form_data, mock_update_job
):
from superset.tasks.async_queries import load_explore_json_into_cache

app._got_first_request = False
async_query_manager.init_app(app)
user = security_manager.find_user("gamma")
Expand All @@ -174,6 +178,8 @@ def test_load_explore_json_into_cache_error(
def test_soft_timeout_load_explore_json_into_cache(
self, mock_update_job, mock_run_command
):
from superset.tasks.async_queries import load_explore_json_into_cache

app._got_first_request = False
async_query_manager.init_app(app)
user = security_manager.find_user("gamma")
Expand All @@ -188,9 +194,8 @@ def test_soft_timeout_load_explore_json_into_cache(
errors = ["A timeout occurred while loading explore json, error"]

with pytest.raises(SoftTimeLimitExceeded):
with mock.patch.object(
async_queries,
"set_form_data",
with mock.patch(
"superset.tasks.async_queries.set_form_data"
) as set_form_data:
set_form_data.side_effect = SoftTimeLimitExceeded()
load_explore_json_into_cache(job_metadata, form_data)
Expand Down
16 changes: 16 additions & 0 deletions tests/unit_tests/async_events/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
Loading

0 comments on commit db7f5fe

Please sign in to comment.