diff --git a/src/eduid/webapp/common/api/app.py b/src/eduid/webapp/common/api/app.py index 74101906f..5f6d4514d 100644 --- a/src/eduid/webapp/common/api/app.py +++ b/src/eduid/webapp/common/api/app.py @@ -34,7 +34,6 @@ from eduid.webapp.common.api.debug import init_app_debug from eduid.webapp.common.api.exceptions import init_exception_handlers, init_sentry from eduid.webapp.common.api.middleware import PrefixMiddleware -from eduid.webapp.common.api.request import Request from eduid.webapp.common.authn.utils import no_authn_views from eduid.webapp.common.session.eduid_session import SessionFactory @@ -82,7 +81,6 @@ def __init__( # App setup self.wsgi_app = ProxyFix(self.wsgi_app) # type: ignore[method-assign] - self.request_class = Request # autocorrect location header means that redirects defaults to an absolute path # werkzeug 2.1.0 changed default value to False self.response_class.autocorrect_location_header = True diff --git a/src/eduid/webapp/common/api/request.py b/src/eduid/webapp/common/api/request.py deleted file mode 100644 index bffe307cb..000000000 --- a/src/eduid/webapp/common/api/request.py +++ /dev/null @@ -1,258 +0,0 @@ -""" -This module provides a Request class that extends flask.Request -and adds sanitation to user inputs. This sanitation is performed -on the access methods of the data structures that the request uses to -hold data inputs by the user. -For more information on these structures, see werkzeug.datastructures. - -To use this request, assign it to the `request_class` attribute -of the Flask application:: - - >>> from eduid.webapp.common.api.request import Request - >>> from flask import Flask - >>> app = Flask('name') - >>> app.request_class = Request -""" - -import logging -from collections.abc import Callable, Iterator -from typing import Any, AnyStr, TypeVar - -from flask import abort -from flask.wrappers import Request as FlaskRequest -from werkzeug.datastructures import EnvironHeaders, ImmutableMultiDict, ImmutableTypeConversionDict - -from eduid.webapp.common.api.sanitation import SanitationProblem, Sanitizer - -logger = logging.getLogger(__name__) - - -class SanitationMixin(Sanitizer): - """ - Mixin for werkzeug datastructures providing methods to - sanitize user inputs. - """ - - def sanitize_input( - self, - untrusted_text: AnyStr, - content_type: str | None = None, - strip_characters: bool = False, - ) -> str: - try: - return super().sanitize_input( - untrusted_text=untrusted_text, content_type=content_type, strip_characters=strip_characters - ) - except SanitationProblem: - abort(400) - - -class SanitizedImmutableMultiDict(ImmutableMultiDict, SanitationMixin): # type: ignore[misc] - """ - See `werkzeug.datastructures.ImmutableMultiDict`. - This class is an extension that overrides all access methods to - sanitize the extracted data. - """ - - def __getitem__(self, key: str) -> str: - """ - Return the first data value for this key; - raises KeyError if not found. - - :param key: The key to be looked up. - :raise KeyError: if the key does not exist. - """ - value = super().__getitem__(key) - return self.sanitize_input(value) - - def getlist(self, key: str, type: Callable[[Any], Any] | None = None) -> list: - """ - Return the list of items for a given key. If that key is not in the - `MultiDict`, the return value will be an empty list. Just as `get` - `getlist` accepts a `type` parameter. All items will be converted - with the callable defined there. - - :param key: The key to be looked up. - :param type: A callable that is used to cast the value in the - :class:`MultiDict`. If a :exc:`ValueError` is raised - by this callable the value will be removed from the list. - :return: a :class:`list` of all the values for the key. - """ - assert type is not None - value_list = super().getlist(key, type=type) - return [self.sanitize_input(v) for v in value_list] - - def items(self, multi: bool = False) -> Iterator[tuple[Any, str]]: # type: ignore[override] - """ - Return an iterator of ``(key, value)`` pairs. - - :param multi: If set to `True` the iterator returned will have a pair - for each value of each key. Otherwise it will only - contain pairs for the first value of each key. - """ - - for key, values in dict.items(self): - values = [self.sanitize_input(v) for v in values] - if multi: - for value in values: - yield key, value - else: - yield key, values[0] - - def lists(self) -> Iterator[tuple[Any, list[str]]]: - """Return a list of ``(key, values)`` pairs, where values is the list - of all values associated with the key.""" - - for key, values in dict.items(self): - values = [self.sanitize_input(v) for v in values] - yield key, values - - def values(self) -> Iterator[str]: # type: ignore[override] - """ - Returns an iterator of the first value on every key's value list. - """ - for values in dict.values(self): - yield self.sanitize_input(values[0]) - - def listvalues(self) -> Iterator[Iterator[Any]]: # type: ignore[override] - """ - Return an iterator of all values associated with a key. Zipping - :meth:`keys` and this is the same as calling :meth:`lists`: - - >>> d = MultiDict({"foo": [1, 2, 3]}) - >>> zip(d.keys(), d.listvalues()) == d.lists() - True - """ - for values in dict.values(self): - yield (self.sanitize_input(v) for v in values) - - def to_dict(self, flat: bool = True) -> dict | dict[Any, list[str]]: - """Return the contents as regular dict. If `flat` is `True` the - returned dict will only have the first item present, if `flat` is - `False` all values will be returned as lists. - - :param flat: If set to `False` the dict returned will have lists - with all the values in it. Otherwise it will only - contain the first value for each key. - :return: a :class:`dict` - """ - if flat: - d = {} - for k, v in dict.items(self): - v = self.sanitize_input(v) - d[k] = v - return d - return dict(self.lists()) - - -T = TypeVar("T") - - -class SanitizedTypeConversionDict(ImmutableTypeConversionDict, SanitationMixin): # type: ignore[misc] - """ - See `werkzeug.datastructures.TypeConversionDict`. - This class is an extension that overrides all access methods to - sanitize the extracted data. - """ - - def __getitem__(self, key: str) -> str: - """ - Sanitized __getitem__ - """ - val = super(ImmutableTypeConversionDict, self).__getitem__(key) - return self.sanitize_input(str(val)) - - def get(self, key: str, default: str | None = None, type: Callable[[Any], T] | None = None) -> str | T | None: # type: ignore[override] - """ - Sanitized, type conversion get. - The value identified by `key` is sanitized, and if `type` - is provided, the value is cast to it. - - :param key: the key for the value - :type key: str - :para default: the default if `key` is absent - :type default: str - :param type: The type to cast the value - :type type: type - - :rtype: object - """ - try: - val: Any = self.sanitize_input(self[key]) - if type is not None: - val = type(val) - except (KeyError, ValueError): - val = default - return val - - def values(self) -> list[str]: # type: ignore[override] - """ - sanitized values - """ - return [self.sanitize_input(v) for v in super(ImmutableTypeConversionDict, self).values()] - - def items(self) -> list[tuple[str, str]]: # type: ignore[override] - """ - Sanitized items - """ - return [(v[0], self.sanitize_input(v[1])) for v in super(ImmutableTypeConversionDict, self).items()] - - -class SanitizedEnvironHeaders(EnvironHeaders, SanitationMixin): # type: ignore[misc] - """ - Sanitized and read only version of the headers from a WSGI environment. - """ - - def __init__(self, environ: dict[str, Any]) -> None: - # set content type from environ at init so we don't get in to an infinite recursion - # when sanitize_input tries to look it up later - self.content_type = environ.get("CONTENT_TYPE") - super().__init__(environ=environ) - - def __getitem__(self, key: str, _get_mode: bool = False) -> str: # type: ignore[override] - """ - Sanitized __getitem__ - - :param key: the key for the value - :param _get_mode: is a no-op for this class as there is no index but - used because get() calls it. - """ - val = super().__getitem__(key) - return self.sanitize_input(untrusted_text=val, content_type=self.content_type) - - def __iter__(self) -> Iterator[tuple[str, str]]: - """ - Sanitized __iter__ - """ - for key, value in EnvironHeaders.__iter__(self): - yield key, self.sanitize_input(untrusted_text=value, content_type=self.content_type) - - def get(self, key: str, default: str | None = None, type: Callable[[str], str | None] | None = None) -> str | None: # type: ignore[override] - """ - Sanitized get - """ - val = super().get(key=key, default=default, type=type) # type: ignore[arg-type] - if val is None: - return None - return self.sanitize_input(untrusted_text=val, content_type=self.content_type) - - -class Request(FlaskRequest, SanitationMixin): - """ - Request objects with sanitized inputs - """ - - parameter_storage_class = SanitizedImmutableMultiDict - dict_storage_class = SanitizedTypeConversionDict # type: ignore[assignment] - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self.headers = SanitizedEnvironHeaders(environ=self.environ) - - def get_data(self, *args: Any, **kwargs: Any) -> str: # type: ignore[override] - text = super().get_data(*args, **kwargs) - if text: - text = self.sanitize_input(untrusted_text=text, content_type=self.mimetype) - if text is None: - text = "" - return text diff --git a/src/eduid/webapp/common/api/sanitation.py b/src/eduid/webapp/common/api/sanitation.py index da9f78e79..79575646e 100644 --- a/src/eduid/webapp/common/api/sanitation.py +++ b/src/eduid/webapp/common/api/sanitation.py @@ -1,8 +1,10 @@ import logging -from typing import AnyStr +from collections.abc import Iterable, Mapping, Sequence +from typing import Any, AnyStr from urllib.parse import quote, unquote from bleach import clean +from werkzeug.exceptions import BadRequest logger = logging.getLogger(__name__) @@ -122,7 +124,8 @@ def _sanitize_input( return cleaned_text - def _safe_clean(self, untrusted_text: str, strip_characters: bool = False) -> str: + @staticmethod + def _safe_clean(untrusted_text: str, strip_characters: bool = False) -> str: """ Wrapper for the clean function of bleach to be able to catch when illegal UTF-8 is processed. @@ -140,3 +143,37 @@ def _safe_clean(self, untrusted_text: str, strip_characters: bool = False) -> st "user input." ) raise SanitationProblem("Illegal UTF-8") + + +def sanitize_map(data: Mapping[str, Any]) -> dict[str, Any]: + return {str(sanitize_item(k)): sanitize_item(v) for k, v in data.items()} + + +def sanitize_iter(data: Iterable[str] | Iterable[Sequence[Any]]) -> list[str | dict[str, Any] | list[Any] | None]: + return [sanitize_item(item) for item in data] + + +def sanitize_item( + data: str | dict[str, Any] | Sequence[Any] | list[Sequence[Any]] | None, +) -> str | dict[str, Any] | list[Any] | None: + match data: + case None: + return None + case dict(): + return sanitize_map(data) + case list(): + return sanitize_iter(data) + case str(): + san = Sanitizer() + try: + assert isinstance(data, str) + safe_data = san.sanitize_input(data) + if safe_data != data: + logger.warning("Sanitized input from unsafe characters") + logger.debug(f"data: {data} -> safe_data: {safe_data}") + except SanitationProblem: + logger.exception("There was a problem sanitizing inputs") + raise BadRequest() + return str(safe_data) + case _: + raise SanitationProblem(f"incompatible type {type(data)}") diff --git a/src/eduid/webapp/common/api/schemas/sanitize.py b/src/eduid/webapp/common/api/schemas/sanitize.py new file mode 100644 index 000000000..2c1362530 --- /dev/null +++ b/src/eduid/webapp/common/api/schemas/sanitize.py @@ -0,0 +1,16 @@ +from collections.abc import Mapping +from typing import Any, AnyStr + +from marshmallow.fields import String + +from eduid.webapp.common.api.sanitation import Sanitizer + +__author__ = "lundberg" + + +class SanitizedString(String): + sanitizer = Sanitizer() + + def _deserialize(self, value: AnyStr, attr: str | None, data: Mapping[str, Any] | None, **kwargs: Any) -> str: + _value = self.sanitizer.sanitize_input(untrusted_text=value) + return super()._deserialize(_value, attr, data, **kwargs) diff --git a/src/eduid/webapp/common/api/tests/test_inputs.py b/src/eduid/webapp/common/api/tests/test_inputs.py index d03134886..6fed270b6 100644 --- a/src/eduid/webapp/common/api/tests/test_inputs.py +++ b/src/eduid/webapp/common/api/tests/test_inputs.py @@ -11,8 +11,10 @@ from eduid.common.config.parsers import load_config from eduid.webapp.common.api.app import EduIDBaseApp from eduid.webapp.common.api.decorators import UnmarshalWith +from eduid.webapp.common.api.sanitation import sanitize_item, sanitize_iter, sanitize_map from eduid.webapp.common.api.schemas.base import EduidSchema from eduid.webapp.common.api.schemas.csrf import CSRFRequestMixin +from eduid.webapp.common.api.schemas.sanitize import SanitizedString from eduid.webapp.common.api.testing import EduidAPITestCase from eduid.webapp.common.session.eduid_session import SessionFactory @@ -26,10 +28,11 @@ def dont_validate(value: str | bytes) -> NoReturn: class NonValidatingSchema(EduidSchema, CSRFRequestMixin): - test_data = fields.String(required=True, validate=dont_validate) + test_data = fields.String(required=False, validate=dont_validate) - class Meta: - strict = True + +class SanitizingSchema(EduidSchema, CSRFRequestMixin): + test_data = SanitizedString(required=False) test_views = Blueprint("test", __name__) @@ -45,28 +48,37 @@ def _make_response(data: str) -> Response: @test_views.route("/test-get-param", methods=["GET"]) def get_param_view() -> Response: param = request.args.get("test-param") - assert param - return _make_response(param) + safe_param = sanitize_item(param) + assert safe_param + return _make_response(str(safe_param)) @test_views.route("/test-post-param", methods=["POST"]) def post_param_view() -> Response: - param = request.form.get("test-param") - assert param - return _make_response(param) + safe_form = sanitize_map(request.form) + safe_param = safe_form.get("test-param") + assert safe_param + return _make_response(safe_param) -@test_views.route("/test-post-json", methods=["POST"]) # type: ignore[arg-type] +@test_views.route("/test-post-json", methods=["POST"]) @UnmarshalWith(NonValidatingSchema) -def post_json_view(test_data: str) -> None: - """never validates""" +def post_json_view(test_data: str) -> Response: + return _make_response(test_data) + + +@test_views.route("/test-post-json-sanitizing", methods=["POST"]) +@UnmarshalWith(SanitizingSchema) +def post_json_view_sanitizing(test_data: str) -> Response: + return _make_response(test_data) @test_views.route("/test-cookie") def cookie_view() -> Response: cookie = request.cookies.get("test-cookie") - assert cookie - return _make_response(cookie) + safe_cookie = sanitize_item(cookie) + assert safe_cookie + return _make_response(str(safe_cookie)) @test_views.route("/test-empty-session") @@ -78,16 +90,18 @@ def empty_session_view() -> Response: @test_views.route("/test-header") def header_view() -> Response: - header = request.headers.get("X-TEST") - assert header - return _make_response(header) + safe_headers = sanitize_map(dict(request.headers)) + safe_header = safe_headers.get("X-Test") + assert safe_header + return _make_response(safe_header) @test_views.route("/test-values", methods=["GET", "POST"]) def values_view() -> Response: - param = request.values.get("test-param") - assert param - return _make_response(param) + safe_values = sanitize_map(request.values) + safe_param = safe_values.get("test-param") + assert safe_param + return _make_response(safe_param) class InputsTestApp(EduIDBaseApp): @@ -118,14 +132,12 @@ def load_app(self, test_config: Mapping[str, Any]) -> InputsTestApp: return app def test_get_param(self) -> None: - """""" url = "/test-get-param?test-param=test-param" with self.app.test_request_context(url, method="GET"): response = self.app.dispatch_request() self.assertIn(b"test-param", response.data) def test_get_param_script(self) -> None: - """""" url = '/test-get-param?test-param=' with self.app.test_request_context(url, method="GET"): response = self.app.dispatch_request() @@ -161,7 +173,6 @@ def test_get_param_unicode_percent_encoded(self) -> None: self.assertIn("åäöхэжこんにちわ", response.data.decode("utf8")) def test_post_param_script(self) -> None: - """""" url = "/test-post-param" with self.app.test_request_context(url, method="POST", data={"test-param": ''}): response = self.app.dispatch_request() @@ -185,25 +196,28 @@ def test_post_param_script_percent_encoded_twice(self) -> None: self.assertNotIn(b"", "csrf_token": sess.get_csrf_token()} + response = client.post(url, json=data) + assert response.status_code == 200 + assert b"", "csrf_token": "failing-token"}', - ): - response = self.app.dispatch_request() - self.assertNotIn(b"", "csrf_token": sess.get_csrf_token()} + response = client.post(url, json=data) + self._check_error_response( + response=response, + type_="POST_TEST_TEST_POST_JSON_FAIL", + payload={"error": {"test_data": ["Problem with ''"]}}, + ) def test_cookie_script(self) -> None: - """""" url = "/test-cookie" cookie = dump_cookie("test-cookie", '') with self.app.test_request_context(url, method="GET", headers={"Cookie": cookie}): @@ -211,7 +225,6 @@ def test_cookie_script(self) -> None: self.assertNotIn(b"' with self.app.test_request_context(url, method="GET", headers={"X-TEST": script}): @@ -219,14 +232,12 @@ def test_header_script(self) -> None: self.assertNotIn(b"'}): response = self.app.dispatch_request() @@ -243,3 +254,23 @@ def test_get_using_empty_session(self) -> None: # instead of crashing. response = self.app.dispatch_request() self.assertEqual(response.data, b"") + + @staticmethod + def test_sanitize_mapping() -> None: + unsafe_d = {"test": "", "test2": ["test", ""]} + safe_d = sanitize_map(unsafe_d) + assert safe_d == { + "test": "<script>alert('ho')</script>", + "test2": ["test", "<script>alert('ho')</script>"], + } + + @staticmethod + def test_sanitize_iterable() -> None: + unsafe_l = ["test", "", "test2", ["test", ""]] + safe_l = sanitize_iter(unsafe_l) + assert safe_l == [ + "test", + "<script>alert('ho')</script>", + "test2", + ["test", "<script>alert('ho')</script>"], + ] diff --git a/src/eduid/webapp/idp/mischttp.py b/src/eduid/webapp/idp/mischttp.py index dad84935a..8f1aeecab 100644 --- a/src/eduid/webapp/idp/mischttp.py +++ b/src/eduid/webapp/idp/mischttp.py @@ -63,7 +63,7 @@ import logging import pprint -from collections.abc import Mapping, Sequence +from collections.abc import Sequence from dataclasses import dataclass from typing import Any, Self @@ -72,11 +72,11 @@ from flask import make_response, redirect, request from saml2 import BINDING_HTTP_REDIRECT from user_agents.parsers import UserAgent -from werkzeug.exceptions import BadRequest, InternalServerError +from werkzeug.exceptions import InternalServerError from werkzeug.wrappers import Response as WerkzeugResponse from eduid.common.config.base import CookieConfig -from eduid.webapp.common.api.sanitation import SanitationProblem, Sanitizer +from eduid.webapp.common.api.sanitation import sanitize_map from eduid.webapp.idp.settings.common import IdPConfig logger = logging.getLogger(__name__) @@ -168,23 +168,7 @@ def get_post() -> dict[str, Any]: :return: query string """ - return _sanitise_items(request.form) - - -def _sanitise_items(data: Mapping[str, Any]) -> dict[str, str]: - res = dict() - san = Sanitizer() - for k, v in data.items(): - try: - safe_k = san.sanitize_input(k, content_type="text/plain") - if safe_k != k: - raise BadRequest() - safe_v = san.sanitize_input(v, content_type="text/plain") - except SanitationProblem: - logger.exception("There was a problem sanitizing inputs") - raise BadRequest() - res[str(safe_k)] = str(safe_v) - return res + return sanitize_map(request.form) # ---------------------------------------------------------------------------- @@ -239,7 +223,7 @@ def parse_query_string() -> dict[str, str]: :return: parsed query string """ - args = _sanitise_items(request.args) + args = sanitize_map(request.args) res = {} for k, v in args.items(): if isinstance(v, list): diff --git a/src/eduid/webapp/support/views.py b/src/eduid/webapp/support/views.py index d62c13dfd..08d34e8d6 100644 --- a/src/eduid/webapp/support/views.py +++ b/src/eduid/webapp/support/views.py @@ -6,6 +6,7 @@ from eduid.userdb import User from eduid.userdb.exceptions import UserDoesNotExist, UserHasNotCompletedSignup from eduid.userdb.support.models import SupportSignupUserFilter, SupportUserFilter +from eduid.webapp.common.api.sanitation import sanitize_map from eduid.webapp.support.app import current_support_app as current_app from eduid.webapp.support.helpers import get_credentials_aux_data, require_support_personnel @@ -15,7 +16,8 @@ @support_views.route("/", methods=["GET", "POST"]) @require_support_personnel def index(support_user: User) -> str: - search_query = request.form.get("query") + data = sanitize_map(request.form) + search_query = data.get("query") if request.method != "POST" or not search_query: return render_template(