diff --git a/superset/tasks/cache.py b/superset/tasks/cache.py index d1051c8fcb89f..1509b0d0624aa 100644 --- a/superset/tasks/cache.py +++ b/superset/tasks/cache.py @@ -29,6 +29,7 @@ from superset.models.dashboard import Dashboard from superset.models.slice import Slice from superset.tags.models import Tag, TaggedObject +from superset.tasks.utils import fetch_csrf_token from superset.utils import json from superset.utils.date_parser import parse_human_datetime from superset.utils.machine_auth import MachineAuthProvider @@ -219,7 +220,10 @@ def fetch_url(data: str, headers: dict[str, str]) -> dict[str, str]: """ result = {} try: - url = get_url_path("Superset.warm_up_cache") + # Fetch CSRF token for API request + headers.update(fetch_csrf_token(headers)) + + url = get_url_path("ChartRestApi.warm_up_cache") logger.info("Fetching %s with payload %s", url, data) req = request.Request( url, data=bytes(data, "utf-8"), headers=headers, method="PUT" diff --git a/superset/tasks/utils.py b/superset/tasks/utils.py index 5012330bbd43e..6fc799c4abc2c 100644 --- a/superset/tasks/utils.py +++ b/superset/tasks/utils.py @@ -17,12 +17,18 @@ from __future__ import annotations -from typing import TYPE_CHECKING +import logging +from http.client import HTTPResponse +from typing import Optional, TYPE_CHECKING +from urllib import request +from celery.utils.log import get_task_logger from flask import current_app, g from superset.tasks.exceptions import ExecutorNotFoundError from superset.tasks.types import ExecutorType +from superset.utils import json +from superset.utils.urls import get_url_path if TYPE_CHECKING: from superset.models.dashboard import Dashboard @@ -30,6 +36,10 @@ from superset.reports.models import ReportSchedule +logger = get_task_logger(__name__) +logger.setLevel(logging.INFO) + + # pylint: disable=too-many-branches def get_executor( executor_types: list[ExecutorType], @@ -92,3 +102,39 @@ def get_current_user() -> str | None: return user.username return None + + +def fetch_csrf_token( + headers: dict[str, str], session_cookie_name: str = "session" +) -> dict[str, str]: + """ + Fetches a CSRF token for API requests + + :param headers: A map of headers to use in the request, including the session cookie + :returns: A map of headers, including the session cookie and csrf token + """ + url = get_url_path("SecurityRestApi.csrf_token") + logger.info("Fetching %s", url) + req = request.Request(url, headers=headers, method="GET") + response: HTTPResponse + with request.urlopen(req, timeout=600) as response: + body = response.read().decode("utf-8") + session_cookie: Optional[str] = None + cookie_headers = response.headers.get_all("set-cookie") + if cookie_headers: + for cookie in cookie_headers: + cookie = cookie.split(";", 1)[0] + name, value = cookie.split("=", 1) + if name == session_cookie_name: + session_cookie = value + break + + if response.status == 200: + data = json.loads(body) + res = {"X-CSRF-Token": data["result"]} + if session_cookie is not None: + res["Cookie"] = session_cookie + return res + + logger.error("Error fetching CSRF token, status code: %s", response.status) + return {} diff --git a/tests/integration_tests/tasks/test_cache.py b/tests/integration_tests/tasks/test_cache.py index 943b444f76936..6e8d3ffe03b4d 100644 --- a/tests/integration_tests/tasks/test_cache.py +++ b/tests/integration_tests/tasks/test_cache.py @@ -29,9 +29,10 @@ ], ids=["Without trailing slash", "With trailing slash"], ) +@mock.patch("superset.tasks.cache.fetch_csrf_token") @mock.patch("superset.tasks.cache.request.Request") @mock.patch("superset.tasks.cache.request.urlopen") -def test_fetch_url(mock_urlopen, mock_request_cls, base_url): +def test_fetch_url(mock_urlopen, mock_request_cls, mock_fetch_csrf_token, base_url): from superset.tasks.cache import fetch_url mock_request = mock.MagicMock() @@ -40,18 +41,22 @@ def test_fetch_url(mock_urlopen, mock_request_cls, base_url): mock_urlopen.return_value = mock.MagicMock() mock_urlopen.return_value.code = 200 + initial_headers = {"Cookie": "cookie", "key": "value"} + csrf_headers = initial_headers | {"X-CSRF-Token": "csrf_token"} + mock_fetch_csrf_token.return_value = csrf_headers + app.config["WEBDRIVER_BASEURL"] = base_url - headers = {"key": "value"} data = "data" data_encoded = b"data" - result = fetch_url(data, headers) + result = fetch_url(data, initial_headers) assert data == result["success"] + mock_fetch_csrf_token.assert_called_once_with(initial_headers) mock_request_cls.assert_called_once_with( - "http://base-url/superset/warm_up_cache/", + "http://base-url/api/v1/chart/warm_up_cache", data=data_encoded, - headers=headers, + headers=csrf_headers, method="PUT", ) # assert the same Request object is used diff --git a/tests/integration_tests/tasks/test_utils.py b/tests/integration_tests/tasks/test_utils.py new file mode 100644 index 0000000000000..b1213b78c85a0 --- /dev/null +++ b/tests/integration_tests/tasks/test_utils.py @@ -0,0 +1,64 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from unittest import mock + +import pytest + +from tests.integration_tests.test_app import app + + +@pytest.mark.parametrize( + "base_url", + [ + "http://base-url", + "http://base-url/", + ], + ids=["Without trailing slash", "With trailing slash"], +) +@mock.patch("superset.tasks.cache.request.Request") +@mock.patch("superset.tasks.cache.request.urlopen") +def test_fetch_csrf_token(mock_urlopen, mock_request_cls, base_url, app_context): + from superset.tasks.utils import fetch_csrf_token + + mock_request = mock.MagicMock() + mock_request_cls.return_value = mock_request + + mock_response = mock.MagicMock() + mock_urlopen.return_value.__enter__.return_value = mock_response + + mock_response.status = 200 + mock_response.read.return_value = b'{"result": "csrf_token"}' + mock_response.headers.get_all.return_value = [ + "session=new_session_cookie", + "async-token=websocket_cookie", + ] + + app.config["WEBDRIVER_BASEURL"] = base_url + headers = {"Cookie": "original_session_cookie"} + + result_headers = fetch_csrf_token(headers) + + mock_request_cls.assert_called_with( + "http://base-url/api/v1/security/csrf_token/", + headers=headers, + method="GET", + ) + + assert result_headers["X-CSRF-Token"] == "csrf_token" + assert result_headers["Cookie"] == "new_session_cookie" + # assert the same Request object is used + mock_urlopen.assert_called_once_with(mock_request, timeout=mock.ANY)