diff --git a/app/authentication/authenticator.py b/app/authentication/authenticator.py index ccf0285b72..5e1dcc1475 100644 --- a/app/authentication/authenticator.py +++ b/app/authentication/authenticator.py @@ -1,6 +1,6 @@ from contextlib import contextmanager from datetime import datetime, timedelta, timezone -from typing import Any, Generator, Mapping, MutableMapping, Optional +from typing import Any, Generator, Mapping, MutableMapping from uuid import uuid4 from blinker import ANY @@ -30,7 +30,7 @@ @login_manager.user_loader -def user_loader(user_id: str) -> Optional[str]: +def user_loader(user_id: str) -> str | None: logger.debug("loading user", user_id=user_id) return load_user() @@ -38,7 +38,7 @@ def user_loader(user_id: str) -> Optional[str]: @login_manager.request_loader def request_load_user( request: Request, -) -> Optional[User]: +) -> User | None: logger.debug("load user") extend_session = not ( @@ -94,7 +94,7 @@ def _is_session_valid(session_store: SessionStore) -> bool: ) -def load_user(extend_session: bool = True) -> Optional[User]: +def load_user(extend_session: bool = True) -> User | None: """ Checks for the present of the JWT in the users sessions :return: A user object if a JWT token is available in the session diff --git a/app/authentication/no_questionnaire_state_exception.py b/app/authentication/no_questionnaire_state_exception.py index 2affbdb90e..aadba31c59 100644 --- a/app/authentication/no_questionnaire_state_exception.py +++ b/app/authentication/no_questionnaire_state_exception.py @@ -1,8 +1,5 @@ -from typing import Union - - class NoQuestionnaireStateException(Exception): - def __init__(self, value: Union[str, int]) -> None: + def __init__(self, value: str | int) -> None: super().__init__() self.value = value diff --git a/app/authentication/no_token_exception.py b/app/authentication/no_token_exception.py index 0790d3a9df..d79e2e82ef 100644 --- a/app/authentication/no_token_exception.py +++ b/app/authentication/no_token_exception.py @@ -1,8 +1,5 @@ -from typing import Union - - class NoTokenException(Exception): - def __init__(self, value: Union[str, int]) -> None: + def __init__(self, value: str | int) -> None: super().__init__() self.value = value diff --git a/app/data_models/answer.py b/app/data_models/answer.py index 6e16eaeb38..d199021276 100644 --- a/app/data_models/answer.py +++ b/app/data_models/answer.py @@ -2,21 +2,26 @@ from dataclasses import asdict, dataclass, field from decimal import Decimal -from typing import Optional, TypedDict, Union, overload +from typing import TypedDict, overload from markupsafe import Markup, escape -DictAnswer = dict[str, Union[int, str]] +DictAnswer = dict[str, int | str] ListAnswer = list[str] ListDictAnswer = list[DictAnswer] -DictAnswerEscaped = dict[str, Union[int, Markup]] +DictAnswerEscaped = dict[str, int | Markup] ListAnswerEscaped = list[Markup] ListDictAnswerEscaped = list[DictAnswerEscaped] -AnswerValueTypes = Union[str, int, Decimal, DictAnswer, ListAnswer, ListDictAnswer] -AnswerValueEscapedTypes = Union[ - Markup, int, Decimal, DictAnswerEscaped, ListAnswerEscaped, ListDictAnswerEscaped -] +AnswerValueTypes = str | int | Decimal | DictAnswer | ListAnswer | ListDictAnswer +AnswerValueEscapedTypes = ( + Markup + | int + | Decimal + | DictAnswerEscaped + | ListAnswerEscaped + | ListDictAnswerEscaped +) class AnswerDict(TypedDict, total=False): @@ -29,7 +34,7 @@ class AnswerDict(TypedDict, total=False): class Answer: answer_id: str value: AnswerValueTypes - list_item_id: Optional[str] = field(default=None) + list_item_id: str | None = field(default=None) @classmethod def from_dict(cls, answer_dict: AnswerDict) -> Answer: @@ -69,13 +74,13 @@ def escape_answer_value(value: str) -> Markup: ... # pragma: no cover @overload def escape_answer_value( - value: Union[None, int, Decimal] -) -> Union[None, int, Decimal]: ... # pragma: no cover + value: None | int | Decimal, +) -> None | int | Decimal: ... # pragma: no cover def escape_answer_value( - value: Optional[AnswerValueTypes], -) -> Optional[AnswerValueEscapedTypes]: + value: AnswerValueTypes | None, +) -> AnswerValueEscapedTypes | None: if isinstance(value, list): return [escape(item) for item in value] diff --git a/app/data_models/answer_store.py b/app/data_models/answer_store.py index afa291b543..d79796a92e 100644 --- a/app/data_models/answer_store.py +++ b/app/data_models/answer_store.py @@ -1,10 +1,10 @@ from __future__ import annotations -from typing import Iterable, Iterator, Optional +from typing import Iterable, Iterator from app.data_models.answer import Answer, AnswerDict -AnswerKeyType = tuple[str, Optional[str]] +AnswerKeyType = tuple[str, str | None] class AnswerStore: @@ -21,7 +21,7 @@ class AnswerStore: } """ - def __init__(self, answers: Optional[Iterable[AnswerDict]] = None): + def __init__(self, answers: Iterable[AnswerDict] | None = None): """Instantiate an answer_store. Args: @@ -81,8 +81,8 @@ def add_or_update(self, answer: Answer) -> bool: return False def get_answer( - self, answer_id: str, list_item_id: Optional[str] = None - ) -> Optional[Answer]: + self, answer_id: str, list_item_id: str | None = None + ) -> Answer | None: """Get a single answer from the store Args: @@ -95,7 +95,7 @@ def get_answer( return self.answer_map.get((answer_id, list_item_id)) def get_answers_by_answer_id( - self, answer_ids: Iterable[str], list_item_id: Optional[str] = None + self, answer_ids: Iterable[str], list_item_id: str | None = None ) -> list[Answer]: """Get multiple answers from the store using the answer_id @@ -121,9 +121,7 @@ def clear(self) -> None: """ self.answer_map.clear() - def remove_answer( - self, answer_id: str, *, list_item_id: Optional[str] = None - ) -> bool: + def remove_answer(self, answer_id: str, *, list_item_id: str | None = None) -> bool: """ Removes answer *in place* from the answer store. :return: True if answer removed else False diff --git a/app/data_models/app_models.py b/app/data_models/app_models.py index 68aeffc656..d9f0b5d7f6 100644 --- a/app/data_models/app_models.py +++ b/app/data_models/app_models.py @@ -1,5 +1,5 @@ from datetime import datetime, timezone -from typing import Any, Optional, Union +from typing import Any from marshmallow import Schema, fields, post_load, pre_dump @@ -11,8 +11,8 @@ def __init__( state_data: str, collection_exercise_sid: str, version: int, - submitted_at: Optional[datetime] = None, - expires_at: Optional[datetime] = None, + submitted_at: datetime | None = None, + expires_at: datetime | None = None, ): self.user_id = user_id self.state_data = state_data @@ -28,9 +28,9 @@ class EQSession: def __init__( self, eq_session_id: str, - user_id: Optional[str], + user_id: str | None, expires_at: datetime, - session_data: Optional[str], + session_data: str | None, ): self.eq_session_id = eq_session_id self.user_id = user_id @@ -52,9 +52,9 @@ class Timestamp(fields.Field): def _serialize( self, value: datetime, - *args: Optional[list], + *args: list | None, **kwargs: Any, - ) -> Optional[int]: + ) -> int | None: if value: # Timezone aware datetime to timestamp return int(value.replace(tzinfo=timezone.utc).timestamp()) @@ -63,9 +63,9 @@ def _serialize( def _deserialize( self, value: float, - *args: Optional[list], + *args: list | None, **kwargs: Any, - ) -> Optional[datetime]: + ) -> datetime | None: if value: # Timestamp to timezone aware datetime return datetime.fromtimestamp(value, tz=timezone.utc) @@ -79,9 +79,9 @@ class DateTimeSchemaMixin: @staticmethod @pre_dump def set_date( - data: Union[EQSession, QuestionnaireState], + data: EQSession | QuestionnaireState, **kwargs: Any, - ) -> Union[EQSession, QuestionnaireState]: + ) -> EQSession | QuestionnaireState: data.updated_at = datetime.now(tz=timezone.utc) return data diff --git a/app/data_models/list_store.py b/app/data_models/list_store.py index 0ca90df7ca..e6c3dc487a 100644 --- a/app/data_models/list_store.py +++ b/app/data_models/list_store.py @@ -3,7 +3,7 @@ import random from functools import cached_property from string import ascii_letters -from typing import Iterable, Iterator, Optional, TypedDict, overload +from typing import Iterable, Iterator, TypedDict, overload from structlog import get_logger @@ -27,9 +27,9 @@ class ListModel: def __init__( self, name: str, - items: Optional[list[str]] = None, - primary_person: Optional[str] = None, - same_name_items: Optional[list[str]] = None, + items: list[str] | None = None, + primary_person: str | None = None, + same_name_items: list[str] | None = None, ): self.name = name self.items = items or [] @@ -127,7 +127,7 @@ class ListStore: ``` """ - def __init__(self, items: Optional[Iterable[ListModelDictType]] = None): + def __init__(self, items: Iterable[ListModelDictType] | None = None): items = items or [] self._lists = self._build_map(items) diff --git a/app/data_models/progress.py b/app/data_models/progress.py index 21e641da20..99d8fdbf84 100644 --- a/app/data_models/progress.py +++ b/app/data_models/progress.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from enum import StrEnum -from typing import Mapping, Optional, TypedDict +from typing import Mapping, TypedDict class CompletionStatus(StrEnum): @@ -24,7 +24,7 @@ class Progress: section_id: str block_ids: list[str] status: CompletionStatus - list_item_id: Optional[str] = None + list_item_id: str | None = None @classmethod def from_dict(cls, progress_dict: ProgressDict) -> Progress: diff --git a/app/data_models/questionnaire_store.py b/app/data_models/questionnaire_store.py index 18a5f457b3..3e5625eeb0 100644 --- a/app/data_models/questionnaire_store.py +++ b/app/data_models/questionnaire_store.py @@ -1,7 +1,7 @@ from __future__ import annotations from datetime import datetime -from typing import TYPE_CHECKING, MutableMapping, Optional +from typing import TYPE_CHECKING, MutableMapping from app.data_models.answer_store import AnswerStore from app.data_models.data_stores import DataStores @@ -22,7 +22,7 @@ class QuestionnaireStore: LATEST_VERSION = 1 def __init__( - self, storage: EncryptedQuestionnaireStorage, version: Optional[int] = None + self, storage: EncryptedQuestionnaireStorage, version: int | None = None ): self._storage = storage if version is None: @@ -31,8 +31,8 @@ def __init__( self._metadata: MutableMapping = {} self._stores = DataStores() self.data_stores = self._stores - self.submitted_at: Optional[datetime] - self.collection_exercise_sid: Optional[str] + self.submitted_at: datetime | None + self.collection_exercise_sid: str | None ( raw_data, diff --git a/app/data_models/relationship_store.py b/app/data_models/relationship_store.py index 8beae8cd8a..1fd0d8ed59 100644 --- a/app/data_models/relationship_store.py +++ b/app/data_models/relationship_store.py @@ -1,5 +1,5 @@ from dataclasses import asdict, dataclass -from typing import Iterable, Iterator, Optional, TypedDict, cast +from typing import Iterable, Iterator, TypedDict, cast class RelationshipDict(TypedDict, total=False): @@ -27,9 +27,7 @@ class RelationshipStore: Stores and updates relationships. """ - def __init__( - self, relationships: Optional[Iterable[RelationshipDict]] = None - ) -> None: + def __init__(self, relationships: Iterable[RelationshipDict] | None = None) -> None: self._is_dirty = False self._relationships = self._build_map(relationships or []) @@ -60,7 +58,7 @@ def serialize(self) -> list[RelationshipDict]: def get_relationship( self, list_item_id: str, to_list_item_id: str - ) -> Optional[Relationship]: + ) -> Relationship | None: key = (list_item_id, to_list_item_id) return self._relationships.get(key) diff --git a/app/data_models/session_data.py b/app/data_models/session_data.py index 723eda8a85..31341c96f0 100644 --- a/app/data_models/session_data.py +++ b/app/data_models/session_data.py @@ -1,10 +1,10 @@ -from typing import Any, Optional +from typing import Any class SessionData: def __init__( self, - language_code: Optional[str] = None, + language_code: str | None = None, confirmation_email_count: int = 0, feedback_count: int = 0, **_: Any, diff --git a/app/data_models/session_store.py b/app/data_models/session_store.py index 04271b7fd2..ee73c90416 100644 --- a/app/data_models/session_store.py +++ b/app/data_models/session_store.py @@ -1,7 +1,6 @@ from __future__ import annotations from datetime import datetime -from typing import Optional from flask import current_app from jwcrypto.common import base64url_decode @@ -17,19 +16,19 @@ class SessionStore: def __init__( - self, user_ik: str, pepper: str, eq_session_id: Optional[str] = None + self, user_ik: str, pepper: str, eq_session_id: str | None = None ) -> None: self.eq_session_id = eq_session_id - self.user_id: Optional[str] = None + self.user_id: str | None = None self.user_ik = user_ik - self.session_data: Optional[SessionData] = None - self._eq_session: Optional[EQSession] = None + self.session_data: SessionData | None = None + self._eq_session: EQSession | None = None self.pepper = pepper if eq_session_id: self._load() @property - def expiration_time(self) -> Optional[datetime]: + def expiration_time(self) -> datetime | None: """ Checking if expires_at is available can be removed soon after deployment, it is only needed to cater for in-flight sessions. @@ -97,7 +96,7 @@ def _load(self) -> None: logger.debug( "finding eq_session_id in database", eq_session_id=self.eq_session_id ) - self._eq_session: Optional[EQSession] = current_app.eq["storage"].get(EQSession, self.eq_session_id) # type: ignore + self._eq_session: EQSession | None = current_app.eq["storage"].get(EQSession, self.eq_session_id) # type: ignore if self._eq_session and self._eq_session.session_data: self.user_id = self._eq_session.user_id diff --git a/app/forms/duration_form.py b/app/forms/duration_form.py index 221256a3c9..e10038fa93 100644 --- a/app/forms/duration_form.py +++ b/app/forms/duration_form.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Callable, Mapping, Optional +from typing import Callable, Mapping from wtforms import Form @@ -12,7 +12,7 @@ # pylint: disable=no-member class DurationForm(Form): def validate( - self, extra_validators: Optional[dict[str, list[Callable]]] = None + self, extra_validators: dict[str, list[Callable]] | None = None ) -> bool: super().validate(extra_validators) @@ -43,8 +43,8 @@ def _set_error(self, key: str) -> None: list(self._fields.values())[0].errors = [self.answer_errors[key]] @property - def data(self) -> Optional[dict[str, Optional[str]]]: - data: dict[str, Optional[str]] = super().data + def data(self) -> dict[str, str | None] | None: + data: dict[str, str | None] = super().data if all(value is None for value in data.values()): return None return data diff --git a/app/forms/field_handlers/__init__.py b/app/forms/field_handlers/__init__.py index 6f32ee4d20..53d5d96413 100644 --- a/app/forms/field_handlers/__init__.py +++ b/app/forms/field_handlers/__init__.py @@ -1,5 +1,3 @@ -from typing import Optional - from werkzeug.datastructures import ImmutableDict from app.forms.field_handlers.address_handler import AddressHandler @@ -49,7 +47,7 @@ def get_field_handler( rule_evaluator: RuleEvaluator, error_messages: ImmutableDict, disable_validation: bool = False, - question_title: Optional[str] = None, + question_title: str | None = None, ) -> FieldHandler: return FIELD_HANDLER_MAPPINGS[answer_schema["type"]]( answer_schema=answer_schema, diff --git a/app/forms/field_handlers/date_handlers.py b/app/forms/field_handlers/date_handlers.py index 6330329aa4..e5882c5e53 100644 --- a/app/forms/field_handlers/date_handlers.py +++ b/app/forms/field_handlers/date_handlers.py @@ -1,6 +1,6 @@ from datetime import datetime, timezone from functools import cached_property -from typing import Any, Optional +from typing import Any from dateutil.relativedelta import relativedelta from wtforms.fields.core import UnboundField @@ -56,7 +56,7 @@ def get_field(self) -> UnboundField | DateField: ) def get_min_max_validator( - self, minimum_date: Optional[datetime], maximum_date: Optional[datetime] + self, minimum_date: datetime | None, maximum_date: datetime | None ) -> SingleDatePeriodCheck: messages = self.answer_schema.get("validation", {}).get("messages") @@ -67,7 +67,7 @@ def get_min_max_validator( maximum_date=maximum_date, ) - def get_referenced_date(self, key: str) -> Optional[datetime]: + def get_referenced_date(self, key: str) -> datetime | None: """ Gets value of the referenced date type, whether it is a value, id of an answer or a meta date. @@ -83,8 +83,8 @@ def get_referenced_date(self, key: str) -> Optional[datetime]: @staticmethod def transform_date_by_offset( - date_to_offset: Optional[datetime], offset: dict[str, int] - ) -> Optional[datetime]: + date_to_offset: datetime | None, offset: dict[str, int] + ) -> datetime | None: """ Adds/subtracts offset from a date and returns the new offset value @@ -102,7 +102,7 @@ def transform_date_by_offset( return date_to_offset - def get_date_value(self, key: str) -> Optional[datetime]: + def get_date_value(self, key: str) -> datetime | None: """ Gets attributes within a minimum or maximum of a date field and validates that the entered date is valid. @@ -129,7 +129,7 @@ def get_field(self) -> UnboundField | MonthYearDateField: ) def get_min_max_validator( - self, minimum_date: Optional[datetime], maximum_date: Optional[datetime] + self, minimum_date: datetime | None, maximum_date: datetime | None ) -> SingleDatePeriodCheck: messages = self.answer_schema.get("validation", {}).get("messages") @@ -158,7 +158,7 @@ def get_field(self) -> UnboundField | YearDateField: ) def get_min_max_validator( - self, minimum_date: Optional[datetime], maximum_date: Optional[datetime] + self, minimum_date: datetime | None, maximum_date: datetime | None ) -> SingleDatePeriodCheck: messages = self.answer_schema.get("validation", {}).get("messages") diff --git a/app/forms/field_handlers/field_handler.py b/app/forms/field_handlers/field_handler.py index 38a4ee4c68..4cb0d736ce 100644 --- a/app/forms/field_handlers/field_handler.py +++ b/app/forms/field_handlers/field_handler.py @@ -1,6 +1,6 @@ from abc import ABC from functools import cached_property -from typing import Any, Mapping, Optional, Union +from typing import Any, Mapping from wtforms import Field, validators from wtforms.validators import Optional as OptionalValidator @@ -24,7 +24,7 @@ def __init__( rule_evaluator: RuleEvaluator, error_messages: Mapping[str, str], disable_validation: bool = False, - question_title: Optional[str] = None, + question_title: str | None = None, ): self.answer_schema = answer_schema self.value_source_resolver = value_source_resolver @@ -44,7 +44,7 @@ def validators(self) -> list[validators.Optional]: return [] @cached_property - def label(self) -> Optional[str]: + def label(self) -> str | None: return self.answer_schema.get("label") @cached_property @@ -57,7 +57,7 @@ def get_validation_message(self, message_key: str) -> str: or self.error_messages[message_key] ) - def get_mandatory_validator(self) -> Union[ResponseRequired, OptionalValidator]: + def get_mandatory_validator(self) -> ResponseRequired | OptionalValidator: if self.answer_schema["mandatory"] is True: mandatory_message = self.get_validation_message(self.MANDATORY_MESSAGE_KEY) @@ -70,10 +70,10 @@ def get_mandatory_validator(self) -> Union[ResponseRequired, OptionalValidator]: def get_schema_value( self, schema_element: dict - ) -> Union[ValueSourceEscapedTypes, ValueSourceTypes]: + ) -> ValueSourceEscapedTypes | ValueSourceTypes: if isinstance(schema_element["value"], dict): return self.value_source_resolver.resolve(schema_element["value"]) - schema_element_value: Union[ValueSourceEscapedTypes, ValueSourceTypes] = ( + schema_element_value: ValueSourceEscapedTypes | ValueSourceTypes = ( schema_element["value"] ) return schema_element_value diff --git a/app/forms/field_handlers/mobile_number_handler.py b/app/forms/field_handlers/mobile_number_handler.py index 2344ad5c69..fc78691a52 100644 --- a/app/forms/field_handlers/mobile_number_handler.py +++ b/app/forms/field_handlers/mobile_number_handler.py @@ -1,5 +1,4 @@ from functools import cached_property -from typing import Union from wtforms import StringField from wtforms.fields.core import UnboundField @@ -7,7 +6,7 @@ from app.forms.field_handlers.field_handler import FieldHandler from app.forms.validators import MobileNumberCheck, ResponseRequired -MobileNumberValidatorTypes = list[Union[ResponseRequired, MobileNumberCheck]] +MobileNumberValidatorTypes = list[ResponseRequired | MobileNumberCheck] class MobileNumberHandler(FieldHandler): diff --git a/app/forms/field_handlers/number_handler.py b/app/forms/field_handlers/number_handler.py index d5d712e89c..bc721d389a 100644 --- a/app/forms/field_handlers/number_handler.py +++ b/app/forms/field_handlers/number_handler.py @@ -1,5 +1,6 @@ from functools import cached_property -from typing import Any, Union + +from typing import Any from wtforms import DecimalField, IntegerField from wtforms.fields.core import UnboundField @@ -15,7 +16,7 @@ from app.settings import MAX_NUMBER NumberValidatorTypes = list[ - Union[ResponseRequired, NumberCheck, NumberRange, DecimalPlaces] + ResponseRequired | NumberCheck | NumberRange | DecimalPlaces ] @@ -58,7 +59,7 @@ def max_decimals(self) -> int: @property def _field_type( self, - ) -> type[Union[DecimalFieldWithSeparator, IntegerFieldWithSeparator]]: + ) -> type[DecimalFieldWithSeparator | IntegerFieldWithSeparator]: return ( DecimalFieldWithSeparator if self.max_decimals > 0 @@ -80,7 +81,7 @@ def get_field(self) -> UnboundField | DecimalField | IntegerField: def _get_number_field_validators( self, - ) -> list[Union[NumberCheck, NumberRange, DecimalPlaces]]: + ) -> list[NumberCheck | NumberRange | DecimalPlaces]: answer_errors = dict(self.error_messages) for error_key in self.validation_messages.keys(): diff --git a/app/forms/field_handlers/select_handlers.py b/app/forms/field_handlers/select_handlers.py index e76e7d65cf..fb56a3b172 100644 --- a/app/forms/field_handlers/select_handlers.py +++ b/app/forms/field_handlers/select_handlers.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Sequence +from typing import Any, Sequence from wtforms.fields.core import UnboundField @@ -56,7 +56,7 @@ class SelectHandler(SelectHandlerBase): MANDATORY_MESSAGE_KEY = "MANDATORY_RADIO" @staticmethod - def coerce_str_unless_none(value: Optional[str]) -> Optional[str]: + def coerce_str_unless_none(value: str | None) -> str | None: """ Coerces a value using str() unless that value is None :param value: Any value that can be coerced using str() or None diff --git a/app/forms/field_handlers/string_handler.py b/app/forms/field_handlers/string_handler.py index 161cb9e1ff..2f1d6d45cd 100644 --- a/app/forms/field_handlers/string_handler.py +++ b/app/forms/field_handlers/string_handler.py @@ -1,5 +1,4 @@ from functools import cached_property -from typing import Union from wtforms import StringField, validators from wtforms.fields.core import UnboundField @@ -7,7 +6,7 @@ from app.forms.field_handlers.field_handler import FieldHandler -StringValidatorTypes = list[Union[validators.Optional, validators.Length]] +StringValidatorTypes = list[validators.Optional | validators.Length] class StringHandler(FieldHandler): diff --git a/app/forms/field_handlers/text_area_handler.py b/app/forms/field_handlers/text_area_handler.py index 5bc484f10a..d0885b5abe 100644 --- a/app/forms/field_handlers/text_area_handler.py +++ b/app/forms/field_handlers/text_area_handler.py @@ -1,5 +1,4 @@ from functools import cached_property -from typing import Union from wtforms import validators from wtforms.fields.core import UnboundField @@ -8,7 +7,7 @@ from app.forms.field_handlers.field_handler import FieldHandler from app.forms.fields import MaxTextAreaField -TextAreaValidatorTypes = list[Union[validators.Optional, validators.Length]] +TextAreaValidatorTypes = list[validators.Optional | validators.Length] class TextAreaHandler(FieldHandler): diff --git a/app/forms/questionnaire_form.py b/app/forms/questionnaire_form.py index bea1885e14..e11aca2700 100644 --- a/app/forms/questionnaire_form.py +++ b/app/forms/questionnaire_form.py @@ -5,7 +5,7 @@ from collections.abc import Callable from datetime import datetime, timedelta, timezone from decimal import Decimal -from typing import Any, Mapping, Optional, Sequence, Union +from typing import Any, Mapping, Sequence from dateutil.relativedelta import relativedelta from flask_wtf import FlaskForm @@ -33,7 +33,7 @@ QuestionnaireExtraValidators = Mapping[str, Sequence[Callable]] Period = Mapping[str, int] PeriodLimits = Mapping[str, Any] -Error = Union[Mapping, Sequence] +Error = Mapping | Sequence Errors = Mapping[str, Error] ErrorList = Sequence[tuple[str, str]] @@ -45,8 +45,8 @@ def __init__( schema: QuestionnaireSchema, question_schema: QuestionSchemaType, data_stores: DataStores, - location: Union[None, Location, RelationshipLocation], - **kwargs: Union[MultiDict, Mapping, None], + location: None | Location | RelationshipLocation, + **kwargs: MultiDict | Mapping | None, ): self.schema = schema self.question = question_schema @@ -66,7 +66,7 @@ def __init__( super().__init__(**kwargs) def validate( - self, extra_validators: Optional[QuestionnaireExtraValidators] = None + self, extra_validators: QuestionnaireExtraValidators | None = None ) -> bool: """ Validate this form as usual and check for any form-level validation errors based on question type @@ -213,7 +213,7 @@ def _get_target_total_and_currency( def validate_date_range_with_period_limits_and_single_date_limits( self, - question_id: Union[str, Sequence[Mapping]], + question_id: str | Sequence[Mapping], period_limits: PeriodLimits, period_range: timedelta, ) -> None: @@ -244,7 +244,7 @@ def _validate_date_range_question( period_from_id: str, period_to_id: str, messages: Mapping[str, str], - period_limits: Optional[PeriodLimits], + period_limits: PeriodLimits | None, ) -> bool: period_from = getattr(self, period_from_id) period_to = getattr(self, period_to_id) @@ -268,7 +268,7 @@ def _validate_calculated_question( calculation: Calculation, question: QuestionSchemaType, target_total: Any, - currency: Optional[str], + currency: str | None, decimal_places: int | None, ) -> bool: messages = None @@ -370,8 +370,8 @@ def _get_offset_value(period_object: Mapping[str, int]) -> timedelta: @staticmethod def _get_period_limits( - limits: Optional[PeriodLimits], - ) -> tuple[Optional[dict[str, Any]], Optional[dict[str, Any]]]: + limits: PeriodLimits | None, + ) -> tuple[dict[str, Any] | None, dict[str, Any] | None]: minimum, maximum = None, None if limits: if "minimum" in limits: @@ -405,7 +405,7 @@ def _get_formatted_calculation_values( @staticmethod def _get_calculation_total( - calculation_type: Callable, values: Sequence[Union[float, int, Decimal, str]] + calculation_type: Callable, values: Sequence[float | int | Decimal | str] ) -> Decimal: result: Decimal = calculation_type(Decimal(value or 0) for value in values) return result @@ -447,7 +447,7 @@ def get_data(self, answer_id: str) -> str: def _option_value_in_data( answer: Mapping[str, str], option: Mapping[str, Any], - data: Union[MultiDict[str, Any], Mapping[str, Any]], + data: MultiDict[str, Any] | Mapping[str, Any], ) -> bool: data_to_inspect = data.to_dict(flat=False) if isinstance(data, MultiDict) else data diff --git a/app/forms/validators.py b/app/forms/validators.py index 9426038a8d..2d76c8f4a9 100644 --- a/app/forms/validators.py +++ b/app/forms/validators.py @@ -4,7 +4,8 @@ import re from datetime import datetime, timezone from decimal import Decimal, InvalidOperation -from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union +from typing import TYPE_CHECKING, Iterable, Mapping, Sequence + import flask_babel from babel import numbers @@ -39,19 +40,19 @@ ) email_regex = re.compile(r"^.+@([^.@][^@\s]+)$") -OptionalMessage = Optional[Mapping[str, str]] -NumType = Union[int, Decimal] +OptionalMessage = Mapping[str, str] | None +NumType = int | Decimal PeriodType = Mapping[str, int] class NumberCheck: - def __init__(self, message: Optional[str] = None): + def __init__(self, message: str | None = None): self.message = message or error_messages["INVALID_NUMBER"] def __call__( self, form: FlaskForm, - field: Union[DecimalFieldWithSeparator, IntegerFieldWithSeparator], + field: DecimalFieldWithSeparator | IntegerFieldWithSeparator, ) -> None: try: # number is sanitised to guard against inputs like `,NaN_` etc @@ -108,12 +109,12 @@ class NumberRange: def __init__( self, - minimum: Optional[NumType] = None, + minimum: NumType | None = None, minimum_exclusive: bool = False, - maximum: Optional[NumType] = None, + maximum: NumType | None = None, maximum_exclusive: bool = False, messages: OptionalMessage = None, - currency: Optional[str] = None, + currency: str | None = None, ): self.minimum = minimum self.maximum = maximum @@ -125,7 +126,7 @@ def __init__( def __call__( self, form: "QuestionnaireForm", - field: Union[DecimalFieldWithSeparator, IntegerFieldWithSeparator], + field: DecimalFieldWithSeparator | IntegerFieldWithSeparator, ) -> None: value: int | Decimal | None = field.data @@ -141,7 +142,7 @@ def __call__( def validate_minimum( self, *, value: NumType, decimal_limit: int | None - ) -> Optional[str]: + ) -> str | None: if self.minimum is None: return None @@ -161,7 +162,7 @@ def validate_minimum( def validate_maximum( self, *, value: NumType, decimal_limit: int | None - ) -> Optional[str]: + ) -> str | None: if self.maximum is None: return None @@ -240,7 +241,7 @@ def __call__(self, form: Sequence["QuestionnaireForm"], field: Field) -> None: class DateRequired: field_flags = ("required",) - def __init__(self, message: Optional[str] = None): + def __init__(self, message: str | None = None): self.message = message or error_messages["MANDATORY_DATE"] def __call__(self, form: "QuestionnaireForm", field: DateField) -> None: @@ -259,7 +260,7 @@ def __call__(self, form: "QuestionnaireForm", field: DateField) -> None: class DateCheck: - def __init__(self, message: Optional[str] = None): + def __init__(self, message: str | None = None): self.message = message or error_messages["INVALID_DATE"] def __call__(self, form: "QuestionnaireForm", field: StringField) -> None: @@ -285,8 +286,8 @@ def __init__( self, messages: OptionalMessage = None, date_format: str = "d MMMM yyyy", - minimum_date: Optional[datetime] = None, - maximum_date: Optional[datetime] = None, + minimum_date: datetime | None = None, + maximum_date: datetime | None = None, ): self.messages = {**error_messages, **(messages or {})} self.minimum_date = minimum_date @@ -326,8 +327,8 @@ class DateRangeCheck: def __init__( self, messages: OptionalMessage = None, - period_min: Optional[dict[str, int]] = None, - period_max: Optional[dict[str, int]] = None, + period_min: dict[str, int] | None = None, + period_max: dict[str, int] | None = None, ): self.messages = {**error_messages, **(messages or {})} self.period_min = period_min @@ -408,9 +409,7 @@ def _build_range_length_error(period_object: PeriodType) -> str: class SumCheck: - def __init__( - self, messages: OptionalMessage = None, currency: Optional[str] = None - ): + def __init__(self, messages: OptionalMessage = None, currency: str | None = None): self.messages = {**error_messages, **(messages or {})} self.currency = currency @@ -457,8 +456,8 @@ def __call__( @staticmethod def _is_valid( condition: str, - total: Union[Decimal, float], - target_total: Union[Decimal, float], + total: Decimal | float, + target_total: Decimal | float, ) -> tuple[bool, str]: if condition == "equals": return total == target_total, "TOTAL_SUM_NOT_EQUALS" @@ -516,7 +515,7 @@ def __call__(self, form: "QuestionnaireForm", field: StringField) -> None: class EmailTLDCheck: - def __init__(self, message: Optional[str] = None): + def __init__(self, message: str | None = None): self.message = message or error_messages["INVALID_EMAIL_FORMAT"] def __call__(self, form: "QuestionnaireForm", field: StringField) -> None: diff --git a/app/helpers/address_lookup_api_helper.py b/app/helpers/address_lookup_api_helper.py index d1f5e7991a..98f72a488d 100644 --- a/app/helpers/address_lookup_api_helper.py +++ b/app/helpers/address_lookup_api_helper.py @@ -1,5 +1,4 @@ from datetime import datetime, timezone -from typing import Optional from uuid import uuid4 from flask import current_app @@ -11,7 +10,7 @@ def get_jwk_from_secret(secret: str) -> jwk.JWK: return jwk.JWK(kty="oct", k=base64url_encode(secret.encode("utf-8"))) -def get_address_lookup_api_auth_token() -> Optional[str]: +def get_address_lookup_api_auth_token() -> str | None: if current_app.config["ADDRESS_LOOKUP_API_AUTH_ENABLED"]: secret = current_app.eq["secret_store"].get_secret_by_name( # type: ignore "ADDRESS_LOOKUP_API_AUTH_TOKEN_SECRET" diff --git a/app/helpers/header_helpers.py b/app/helpers/header_helpers.py index a4aeb8654e..1525a6b901 100644 --- a/app/helpers/header_helpers.py +++ b/app/helpers/header_helpers.py @@ -1,11 +1,9 @@ -from typing import Union - from werkzeug.datastructures import EnvironHeaders def get_span_and_trace( headers: EnvironHeaders, -) -> Union[tuple[None, None], tuple[str, str]]: +) -> tuple[None, None] | tuple[str, str]: try: trace, span = headers.get("X-Cloud-Trace-Context").split("/") # type: ignore except (ValueError, AttributeError): diff --git a/app/jinja_filters.py b/app/jinja_filters.py index c3dfd776e3..9c8754a8ee 100644 --- a/app/jinja_filters.py +++ b/app/jinja_filters.py @@ -3,7 +3,7 @@ import re from datetime import datetime from decimal import Decimal -from typing import Any, Callable, Literal, Mapping, Optional, TypeAlias, Union +from typing import Any, Callable, Literal, Mapping, TypeAlias import flask import flask_babel @@ -31,7 +31,7 @@ UnitLengthType: TypeAlias = Literal["short", "long", "narrow"] -def mark_safe(context: nodes.EvalContext, value: str) -> Union[Markup, str]: +def mark_safe(context: nodes.EvalContext, value: str) -> Markup | str: return Markup(value) if context.autoescape else value @@ -61,7 +61,7 @@ def get_currency_symbol(currency: str = "GBP") -> str: @blueprint.app_template_filter() -def format_percentage(value: Union[int, float, Decimal]) -> str: +def format_percentage(value: int | float | Decimal) -> str: return f"{value}%" @@ -151,9 +151,7 @@ def get_format_date(value: Markup) -> str: @pass_eval_context @blueprint.app_template_filter() -def format_datetime( - context: nodes.EvalContext, date_time: datetime -) -> Union[str, Markup]: +def format_datetime(context: nodes.EvalContext, date_time: datetime) -> str | Markup: # flask babel on formatting will automatically convert based on the time zone specified in setup.py formatted_date = flask_babel.format_date(date_time, format="d MMMM yyyy") formatted_time = flask_babel.format_time(date_time, format="HH:mm") @@ -266,7 +264,7 @@ def get_min_max_value_width( @blueprint.app_template_filter() -def get_width_for_number(answer: AnswerType) -> Optional[int]: +def get_width_for_number(answer: AnswerType) -> int | None: allowable_widths = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 30, 40, 50] min_value_width = get_min_max_value_width("minimum", answer, 0) @@ -287,7 +285,7 @@ def get_width_for_number_processor() -> dict[str, Callable]: class LabelConfig: - def __init__(self, _for: str, text: str, description: Optional[str] = None) -> None: + def __init__(self, _for: str, text: str, description: str | None = None) -> None: self._for = _for self.text = text self.description = description @@ -299,7 +297,7 @@ def __init__( option: SelectFieldBase._Option, index: int, answer: AnswerType, - form: Optional[FormType] = None, + form: FormType | None = None, ) -> None: self.id = option.id self.name = option.name @@ -395,7 +393,7 @@ def map_select_config_processor() -> dict[str, Callable]: @blueprint.app_template_filter() def map_relationships_config( - form: Mapping[str, str], answer: Mapping[str, Union[int, slice]] + form: Mapping[str, str], answer: Mapping[str, int | slice] ) -> list[RelationshipRadioConfig]: options = form["fields"][answer["id"]] @@ -459,7 +457,7 @@ def __init__( class SummaryRowItemValue: - def __init__(self, text: str, other: Optional[str] = None) -> None: + def __init__(self, text: str, other: str | None = None) -> None: self.text = text if other or other == 0: @@ -588,17 +586,17 @@ def __init__( @blueprint.app_template_filter() def map_summary_item_config( - group: dict[str, Union[list, dict]], + group: dict[str, list | dict], summary_type: str, answers_are_editable: bool, no_answer_provided: str, edit_link_text: str, edit_link_aria_label: str, - calculated_question: Optional[dict[str, list]], + calculated_question: dict[str, list] | None, remove_link_text: str | None = None, remove_link_aria_label: str | None = None, -) -> list[Union[dict[str, list], SummaryRow]]: - rows: list[Union[dict[str, list], SummaryRow]] = [] +) -> list[dict[str, list] | SummaryRow]: + rows: list[dict[str, list] | SummaryRow] = [] for block in group["blocks"]: if block.get("question"): diff --git a/app/questionnaire/dynamic_answer_options.py b/app/questionnaire/dynamic_answer_options.py index 0ee21e60b5..f267f57208 100644 --- a/app/questionnaire/dynamic_answer_options.py +++ b/app/questionnaire/dynamic_answer_options.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Mapping, Union +from typing import Mapping from app.questionnaire.rules.operator import Operator from app.questionnaire.rules.rule_evaluator import RuleEvaluator, RuleEvaluatorTypes @@ -18,9 +18,7 @@ class DynamicAnswerOptions: def evaluate(self) -> tuple[dict[str, str], ...]: values = self.dynamic_options_schema["values"] - resolved_values: Union[ - ValueSourceEscapedTypes, ValueSourceTypes, RuleEvaluatorTypes - ] + resolved_values: ValueSourceEscapedTypes | ValueSourceTypes | RuleEvaluatorTypes if "source" in values: if values["source"] != "answers": diff --git a/app/questionnaire/rules/helpers.py b/app/questionnaire/rules/helpers.py index 5db0abb155..3c6bf24ac9 100644 --- a/app/questionnaire/rules/helpers.py +++ b/app/questionnaire/rules/helpers.py @@ -1,12 +1,12 @@ from datetime import datetime from decimal import Decimal from functools import wraps -from typing import Any, Callable, Sequence, Union +from typing import Any, Callable, Sequence -ValueTypes = Union[bool, str, int, float, Decimal, None, datetime] +ValueTypes = bool | str | int | float | Decimal | None | datetime -def _casefold(value: Union[list, ValueTypes]) -> Union[list, ValueTypes]: +def _casefold(value: list | ValueTypes) -> list | ValueTypes: if isinstance(value, str): return value.casefold() diff --git a/app/questionnaire/rules/operations_helper.py b/app/questionnaire/rules/operations_helper.py index 29d0c9be8e..9bce3a555c 100644 --- a/app/questionnaire/rules/operations_helper.py +++ b/app/questionnaire/rules/operations_helper.py @@ -5,7 +5,7 @@ """ from datetime import date -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from app.questionnaire.questionnaire_schema import QuestionnaireSchema from app.questionnaire.rules.operations import DateOffset, Operations @@ -30,10 +30,10 @@ def __init__( def string_to_datetime( self, - date_string: Optional[str], - offset: Optional[DateOffset] = None, + date_string: str | None, + offset: DateOffset | None = None, offset_by_full_weeks: bool = False, - ) -> Optional[date]: + ) -> date | None: return self.ops.resolve_date_from_string( date_string, offset, offset_by_full_weeks ) diff --git a/app/questionnaire/rules/operator.py b/app/questionnaire/rules/operator.py index 0b7ec073b4..b387db11f3 100644 --- a/app/questionnaire/rules/operator.py +++ b/app/questionnaire/rules/operator.py @@ -1,5 +1,5 @@ from datetime import date -from typing import TYPE_CHECKING, Generator, Iterable, Optional, Sequence, Union +from typing import TYPE_CHECKING, Generator, Iterable, Sequence from app.questionnaire.rules.helpers import ValueTypes @@ -40,15 +40,13 @@ def __init__(self, name: str, operations: "Operations") -> None: Operator.ANY_IN, } - def evaluate( - self, operands: Union[Generator, Iterable] - ) -> Union[bool, Optional[date]]: + def evaluate(self, operands: Generator | Iterable) -> bool | date | None: if self._ensure_operands_not_none: operands = list(operands) if self._any_operands_none(*operands): return False - value: Union[bool, Optional[date]] = ( + value: bool | date | None = ( self._operation(operands) if self.name in {Operator.AND, Operator.OR} else self._operation(*operands) @@ -56,7 +54,7 @@ def evaluate( return value @staticmethod - def _any_operands_none(*operands: Union[Sequence, ValueTypes]) -> bool: + def _any_operands_none(*operands: Sequence | ValueTypes) -> bool: return any(operand is None for operand in operands) diff --git a/app/questionnaire/rules/utils.py b/app/questionnaire/rules/utils.py index 7a8ca7e836..ea7d268f3f 100644 --- a/app/questionnaire/rules/utils.py +++ b/app/questionnaire/rules/utils.py @@ -1,5 +1,5 @@ from datetime import datetime, timezone -from typing import Optional, overload +from typing import overload from dateutil import parser @@ -16,7 +16,7 @@ def parse_datetime(date_string: None) -> None: ... # pragma: no cover def parse_datetime(date_string: str) -> datetime: ... # pragma: no cover -def parse_datetime(date_string: Optional[str]) -> Optional[datetime]: +def parse_datetime(date_string: str | None) -> datetime | None: """ :param date_string: string representing a date :return: datetime of that date string diff --git a/app/secrets.py b/app/secrets.py index 840f6047cb..073f0d0b93 100644 --- a/app/secrets.py +++ b/app/secrets.py @@ -1,4 +1,4 @@ -from typing import Mapping, Optional +from typing import Mapping SecretsType = Mapping[str, Mapping[str, str]] @@ -13,7 +13,7 @@ def validate_required_secrets( - secrets: SecretsType, additional_required_secrets: Optional[list[str]] = None + secrets: SecretsType, additional_required_secrets: list[str] | None = None ) -> None: all_required_secrets = ( REQUIRED_SECRETS + additional_required_secrets @@ -29,5 +29,5 @@ class SecretStore: def __init__(self, secrets: SecretsType) -> None: self.secrets = secrets.get("secrets", {}) - def get_secret_by_name(self, secret_name: str) -> Optional[str]: + def get_secret_by_name(self, secret_name: str) -> str | None: return self.secrets.get(secret_name) diff --git a/app/storage/datastore.py b/app/storage/datastore.py index 790ff2ca40..4eac362543 100644 --- a/app/storage/datastore.py +++ b/app/storage/datastore.py @@ -1,5 +1,3 @@ -from typing import Optional - from google.api_core.retry import Retry from google.cloud import datastore from google.cloud.datastore import Entity @@ -36,7 +34,7 @@ def put(self, model: ModelTypes, overwrite: bool = True) -> bool: return True @Retry() - def get(self, model_type: type[ModelTypes], key_value: str) -> Optional[ModelTypes]: + def get(self, model_type: type[ModelTypes], key_value: str) -> ModelTypes | None: storage_model = StorageModel(model_type=model_type) key = self.client.key(storage_model.table_name, key_value) diff --git a/app/storage/dynamodb.py b/app/storage/dynamodb.py index d578999009..f3f9335d31 100644 --- a/app/storage/dynamodb.py +++ b/app/storage/dynamodb.py @@ -1,5 +1,3 @@ -from typing import Optional - import boto3 from botocore.exceptions import ClientError @@ -33,7 +31,7 @@ def put(self, model: ModelTypes, overwrite: bool = True) -> bool: raise # pragma: no cover - def get(self, model_type: type[ModelTypes], key_value: str) -> Optional[ModelTypes]: + def get(self, model_type: type[ModelTypes], key_value: str) -> ModelTypes | None: storage_model = StorageModel(model_type=model_type) table = self.client.Table(storage_model.table_name) key = {storage_model.key_field: key_value} diff --git a/app/storage/encrypted_questionnaire_storage.py b/app/storage/encrypted_questionnaire_storage.py index 49c66b8c46..c83abddcce 100644 --- a/app/storage/encrypted_questionnaire_storage.py +++ b/app/storage/encrypted_questionnaire_storage.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import Optional, Union import snappy from flask import current_app @@ -21,8 +20,8 @@ def save( self, data: str, collection_exercise_sid: str, - submitted_at: Optional[datetime] = None, - expires_at: Optional[datetime] = None, + submitted_at: datetime | None = None, + expires_at: datetime | None = None, ) -> None: compressed_data = snappy.compress(data) encrypted_data = self.encrypter.encrypt_data(compressed_data) @@ -39,7 +38,7 @@ def save( def get_user_data( self, - ) -> Union[tuple[None, None, None, None], tuple[str, str, int, Optional[datetime]]]: + ) -> tuple[None, None, None, None] | tuple[str, str, int, datetime | None]: questionnaire_state = self._find_questionnaire_state() if questionnaire_state and questionnaire_state.state_data: version = questionnaire_state.version @@ -58,7 +57,7 @@ def delete(self) -> None: if questionnaire_state: current_app.eq["storage"].delete(questionnaire_state) # type: ignore - def _find_questionnaire_state(self) -> Optional[QuestionnaireState]: + def _find_questionnaire_state(self) -> QuestionnaireState | None: logger.debug("getting questionnaire data", user_id=self._user_id) state: QuestionnaireState = current_app.eq["storage"].get(QuestionnaireState, self._user_id) # type: ignore return state diff --git a/app/storage/redis.py b/app/storage/redis.py index 9a3f5c4b27..b9aa1e2d7b 100644 --- a/app/storage/redis.py +++ b/app/storage/redis.py @@ -1,5 +1,4 @@ from datetime import datetime, timezone -from typing import Optional import redis from redis.exceptions import ConnectionError as RedisConnectionError @@ -56,7 +55,7 @@ def put(self, model: ModelTypes, overwrite: bool = True) -> bool: return True - def get(self, model_type: type[ModelTypes], key_value: str) -> Optional[ModelTypes]: + def get(self, model_type: type[ModelTypes], key_value: str) -> ModelTypes | None: storage_model = StorageModel(model_type=model_type) try: item = self.client.get(key_value) diff --git a/app/storage/storage.py b/app/storage/storage.py index 49a925204d..da0704831c 100644 --- a/app/storage/storage.py +++ b/app/storage/storage.py @@ -2,22 +2,22 @@ from abc import ABC, abstractmethod from functools import cached_property -from typing import Any, Optional, TypedDict, Union +from typing import Any, TypedDict from flask import current_app from google.cloud import datastore from app.data_models import app_models -ModelSchemaTypes = Union[ - app_models.QuestionnaireStateSchema, - app_models.EQSessionSchema, - app_models.UsedJtiClaimSchema, -] +ModelSchemaTypes = ( + app_models.QuestionnaireStateSchema + | app_models.EQSessionSchema + | app_models.UsedJtiClaimSchema +) -ModelTypes = Union[ - app_models.QuestionnaireState, app_models.EQSession, app_models.UsedJtiClaim -] +ModelTypes = ( + app_models.QuestionnaireState | app_models.EQSession | app_models.UsedJtiClaim +) class TableConfig(TypedDict, total=False): @@ -65,7 +65,7 @@ def key_field(self) -> str: return self._config["key_field"] @cached_property - def expiry_field(self) -> Optional[str]: + def expiry_field(self) -> str | None: return self._config.get("expiry_field") @cached_property @@ -81,7 +81,7 @@ def serialize(self, model_to_serialize: ModelTypes) -> dict: serialized_data: dict = self._schema.dump(model_to_serialize) return serialized_data - def deserialize(self, serialized_item: Union[datastore.Entity]) -> ModelTypes: + def deserialize(self, serialized_item: datastore.Entity) -> ModelTypes: deserialized_data: ModelTypes = self._schema.load(serialized_item) return deserialized_data @@ -95,7 +95,7 @@ def put(self, model: ModelTypes, overwrite: bool = True) -> bool: pass # pragma: no cover @abstractmethod - def get(self, model_type: type[ModelTypes], key_value: str) -> Optional[ModelTypes]: + def get(self, model_type: type[ModelTypes], key_value: str) -> ModelTypes | None: pass # pragma: no cover @abstractmethod diff --git a/app/storage/storage_encryption.py b/app/storage/storage_encryption.py index 9b9311a0e1..cd86b74c12 100644 --- a/app/storage/storage_encryption.py +++ b/app/storage/storage_encryption.py @@ -1,5 +1,4 @@ import hashlib -from typing import Optional, Union from jwcrypto import jwe, jwk from jwcrypto.common import base64url_encode @@ -13,7 +12,7 @@ class StorageEncryption: def __init__( - self, user_id: Optional[str], user_ik: Optional[str], pepper: Optional[str] + self, user_id: str | None, user_ik: str | None, pepper: str | None ) -> None: if not user_id: raise ValueError("user_id not provided") @@ -38,7 +37,7 @@ def _generate_key(user_id: str, user_ik: str, pepper: str) -> jwk.JWK: return jwk.JWK(**password) - def encrypt_data(self, data: Union[str, dict]) -> str: + def encrypt_data(self, data: str | dict) -> str: if isinstance(data, dict): data = json_dumps(data) diff --git a/app/submitter/previously_submitted_exception.py b/app/submitter/previously_submitted_exception.py index f467943703..70249a63d7 100644 --- a/app/submitter/previously_submitted_exception.py +++ b/app/submitter/previously_submitted_exception.py @@ -1,8 +1,5 @@ -from typing import Union - - class PreviouslySubmittedException(Exception): - def __init__(self, value: Union[str, int]) -> None: + def __init__(self, value: str | int) -> None: super().__init__() self.value = value diff --git a/app/submitter/submitter.py b/app/submitter/submitter.py index 958d17b812..21d6bd7698 100644 --- a/app/submitter/submitter.py +++ b/app/submitter/submitter.py @@ -1,4 +1,4 @@ -from typing import Mapping, Optional, Union +from typing import Mapping from uuid import uuid4 from google.api_core.exceptions import Forbidden @@ -19,7 +19,7 @@ def send_message( message: str, tx_id: str, case_id: str, - **kwargs: Mapping[str, Union[str, int]], + **kwargs: Mapping[str, str | int], ) -> bool: logger.info("sending message") logger.info( @@ -77,8 +77,8 @@ def __init__( secondary_host: str, port: int, queue: str, - username: Optional[str] = None, - password: Optional[str] = None, + username: str | None = None, + password: str | None = None, ) -> None: self.queue = queue if username and password: @@ -120,7 +120,7 @@ def _connect(self) -> BlockingConnection: raise err @staticmethod - def _disconnect(connection: Optional[BlockingConnection]) -> None: + def _disconnect(connection: BlockingConnection | None) -> None: try: if connection: logger.info("attempt to close connection", category="rabbitmq") diff --git a/app/survey_config/business_config.py b/app/survey_config/business_config.py index a4c3362ed1..bcd3dcb748 100644 --- a/app/survey_config/business_config.py +++ b/app/survey_config/business_config.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Iterable, Mapping, MutableMapping, Optional +from typing import Iterable, Mapping, MutableMapping from urllib.parse import urlencode from warnings import warn @@ -31,7 +31,7 @@ def __post_init__(self) -> None: self.account_service_todo_url: str = f"{self.base_url}/surveys/todo" def _get_account_service_help_url( - self, *, is_authenticated: bool, ru_ref: Optional[str] + self, *, is_authenticated: bool, ru_ref: str | None ) -> str: if self.schema and is_authenticated and ru_ref: request_data = { @@ -52,8 +52,8 @@ def get_service_links( *, is_authenticated: bool, cookie_has_theme: bool, - ru_ref: Optional[str], - ) -> Optional[list[dict]]: + ru_ref: str | None, + ) -> list[dict] | None: links = ( [ HeaderLink( @@ -102,7 +102,7 @@ def get_footer_links(self, cookie_has_theme: bool) -> list[dict]: return links - def get_footer_legal_links(self, cookie_has_theme: bool) -> Optional[list[dict]]: + def get_footer_legal_links(self, cookie_has_theme: bool) -> list[dict] | None: if cookie_has_theme: return [ Link(lazy_gettext("Cookies"), self.cookie_settings_url).as_dict(), diff --git a/app/survey_config/link.py b/app/survey_config/link.py index 155117c344..c6460db774 100644 --- a/app/survey_config/link.py +++ b/app/survey_config/link.py @@ -1,5 +1,4 @@ from dataclasses import dataclass, field -from typing import Optional from flask_babel.speaklater import LazyString @@ -8,8 +7,8 @@ class Link: text: LazyString url: str - target: Optional[str] = "_blank" - attributes: Optional[dict] = field(default_factory=dict) + target: str | None = "_blank" + attributes: dict | None = field(default_factory=dict) def as_dict(self): return {k: v for k, v in self.__dict__.items() if v} diff --git a/app/survey_config/social_survey_config.py b/app/survey_config/social_survey_config.py index c8fb8b7fbc..e3103231b0 100644 --- a/app/survey_config/social_survey_config.py +++ b/app/survey_config/social_survey_config.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Iterable, Mapping, MutableMapping, Optional +from typing import Iterable, Mapping, MutableMapping from flask_babel import lazy_gettext @@ -52,7 +52,7 @@ def get_footer_links(self, cookie_has_theme: bool) -> list[dict]: return links - def get_footer_legal_links(self, cookie_has_theme: bool) -> Optional[list[dict]]: + def get_footer_legal_links(self, cookie_has_theme: bool) -> list[dict] | None: if cookie_has_theme: return [ Link(lazy_gettext("Cookies"), self.cookie_settings_url).as_dict(), diff --git a/app/survey_config/survey_config.py b/app/survey_config/survey_config.py index 36aeb5ae58..50d7876248 100644 --- a/app/survey_config/survey_config.py +++ b/app/survey_config/survey_config.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Iterable, Mapping, MutableMapping, Optional +from typing import Iterable, Mapping, MutableMapping from flask_babel import lazy_gettext from flask_babel.speaklater import LazyString @@ -13,32 +13,32 @@ class SurveyConfig: """Valid options for defining survey-based configuration.""" - schema: Optional[QuestionnaireSchema] = None - copyright_declaration: Optional[LazyString] = lazy_gettext( + schema: QuestionnaireSchema | None = None + copyright_declaration: LazyString | None = lazy_gettext( "Crown copyright and database rights 2020 OS 100019153." ) - copyright_text: Optional[LazyString] = lazy_gettext( + copyright_text: LazyString | None = lazy_gettext( "Use of address data is subject to the terms and conditions." ) base_url: str = ACCOUNT_SERVICE_BASE_URL - account_service_my_account_url: Optional[str] = None - account_service_todo_url: Optional[str] = None - account_service_log_out_url: Optional[str] = None + account_service_my_account_url: str | None = None + account_service_todo_url: str | None = None + account_service_log_out_url: str | None = None accessibility_url: str = f"{ONS_URL}/help/accessibility/" what_we_do_url: str = f"{ONS_URL}/aboutus/whatwedo/" - masthead_logo: Optional[str] = None - masthead_logo_mobile: Optional[str] = None + masthead_logo: str | None = None + masthead_logo_mobile: str | None = None crest: bool = True - footer_links: Optional[Iterable[MutableMapping]] = None - footer_legal_links: Optional[Iterable[Mapping]] = None - survey_title: Optional[LazyString] = None - design_system_theme: Optional[str] = None + footer_links: Iterable[MutableMapping] | None = None + footer_legal_links: Iterable[Mapping] | None = None + survey_title: LazyString | None = None + design_system_theme: str | None = None sign_out_button_text: str = lazy_gettext("Save and exit survey") contact_us_url: str = field(init=False) cookie_settings_url: str = field(init=False) cookie_domain: str = field(init=False) privacy_and_data_protection_url: str = field(init=False) - language_code: Optional[str] = None + language_code: str | None = None def __post_init__(self) -> None: self.contact_us_url: str = f"{self.base_url}/contact-us/" @@ -57,16 +57,16 @@ def get_service_links( # pylint: disable=unused-argument, no-self-use *, is_authenticated: bool, cookie_has_theme: bool, - ru_ref: Optional[str], - ) -> Optional[list[dict]]: + ru_ref: str | None, + ) -> list[dict] | None: return None def get_footer_links( # pylint: disable=unused-argument, no-self-use self, cookie_has_theme: bool - ) -> Optional[list[dict]]: + ) -> list[dict] | None: return None def get_footer_legal_links( # pylint: disable=unused-argument, no-self-use self, cookie_has_theme: bool - ) -> Optional[list[dict]]: + ) -> list[dict] | None: return None diff --git a/app/utilities/strings.py b/app/utilities/strings.py index 1bfdf0fb37..f58ce64e10 100644 --- a/app/utilities/strings.py +++ b/app/utilities/strings.py @@ -1,8 +1,7 @@ import re -from typing import Union -def to_bytes(bytes_or_str: Union[bytes, str]) -> bytes: +def to_bytes(bytes_or_str: bytes | str) -> bytes: """ Converts supplied data into bytes if the data is of type str. :param bytes_or_str: Data to be converted. @@ -13,7 +12,7 @@ def to_bytes(bytes_or_str: Union[bytes, str]) -> bytes: return bytes_or_str -def to_str(bytes_or_str: Union[bytes, str]) -> str: +def to_str(bytes_or_str: bytes | str) -> str: """ Converts supplied data into a UTF-8 encoded string if the data is of type bytes. :param bytes_or_str: Data to be converted. diff --git a/app/utilities/types.py b/app/utilities/types.py index 9a0c7d8d2c..2d159b736b 100644 --- a/app/utilities/types.py +++ b/app/utilities/types.py @@ -12,15 +12,15 @@ RelationshipLocation, # pragma: no cover ) -LocationType: TypeAlias = Union["Location", "RelationshipLocation"] +LocationType: TypeAlias = Union["Location", "RelationshipLocation"] # noqa: UP007 SupplementaryDataKeyType: TypeAlias = tuple[str, str | None] SupplementaryDataValueType: TypeAlias = dict | str | list | None DateValidatorType: TypeAlias = Union[ "OptionalForm", "DateRequired", "DateCheck", "SingleDatePeriodCheck" -] +] # noqa: UP007 -ChoiceType: TypeAlias = Union["Choice", "ChoiceWithDetailAnswer"] +ChoiceType: TypeAlias = Union["Choice", "ChoiceWithDetailAnswer"] # noqa: UP007 ChoiceWidgetRenderType: TypeAlias = tuple[str, str, bool, str | None] diff --git a/app/views/contexts/email_form_context.py b/app/views/contexts/email_form_context.py index 1b248baec1..4c57f1eb49 100644 --- a/app/views/contexts/email_form_context.py +++ b/app/views/contexts/email_form_context.py @@ -1,4 +1,4 @@ -from typing import Any, Union +from typing import Any from flask import url_for @@ -7,7 +7,7 @@ def build_confirmation_email_form_context( email_confirmation_form: EmailForm, -) -> dict[str, Union[bool, str, Any]]: +) -> dict[str, bool | str | Any]: return { "hide_sign_out_button": False, "sign_out_url": url_for("session.get_sign_out"), diff --git a/app/views/contexts/feedback_form_context.py b/app/views/contexts/feedback_form_context.py index 478dc6a62a..2757d3d831 100644 --- a/app/views/contexts/feedback_form_context.py +++ b/app/views/contexts/feedback_form_context.py @@ -1,5 +1,3 @@ -from typing import Union - from flask import url_for from app.forms.questionnaire_form import QuestionnaireForm @@ -9,7 +7,7 @@ def build_feedback_context( question_schema: QuestionSchemaType, form: QuestionnaireForm -) -> dict[str, Union[str, bool, dict]]: +) -> dict[str, str | bool | dict]: block = {"question": question_schema} context = build_question_context(block, form) context["hide_sign_out_button"] = False diff --git a/app/views/contexts/hub_context.py b/app/views/contexts/hub_context.py index b2e6986f84..937f2cfb8e 100644 --- a/app/views/contexts/hub_context.py +++ b/app/views/contexts/hub_context.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import Any, Iterable, Mapping, Optional, Union +from typing import Any, Iterable, Mapping from flask import url_for from flask_babel import lazy_gettext @@ -82,11 +82,11 @@ def __call__( def get_row_context_for_section( self, - section_name: Optional[str], + section_name: str | None, section_status: CompletionStatus, section_url: str, row_id: str, - ) -> dict[str, Union[str, list]]: + ) -> dict[str, str | list]: section_content = self.SECTION_CONTENT_STATES[section_status] context: dict = { "rowItems": [ @@ -119,7 +119,7 @@ def get_row_context_for_section( @staticmethod def get_section_url( - section_id: str, list_item_id: Optional[str], section_status: CompletionStatus + section_id: str, list_item_id: str | None, section_status: CompletionStatus ) -> str: if section_status == CompletionStatus.INDIVIDUAL_RESPONSE_REQUESTED: return url_for( @@ -137,8 +137,8 @@ def get_section_url( return url_for("questionnaire.get_section", section_id=section_id) def _get_row_for_repeating_section( - self, section_id: str, list_item_id: str, list_item_index: Optional[int] - ) -> dict[str, Union[str, list]]: + self, section_id: str, list_item_id: str, list_item_index: int | None + ) -> dict[str, str | list]: # Type ignore: section id will be valid and repeat will be present at this stage repeating_title: ImmutableDict = self._schema.get_repeating_title_for_section(section_id) # type: ignore @@ -153,11 +153,11 @@ def _get_row_for_repeating_section( def _get_row_for_section( self, - section_title: Optional[str], + section_title: str | None, section_id: str, - list_item_id: Optional[str] = None, - list_item_index: Optional[int] = None, - ) -> dict[str, Union[str, list]]: + list_item_id: str | None = None, + list_item_index: int | None = None, + ) -> dict[str, str | list]: row_id = f"{section_id}-{list_item_index}" if list_item_index else section_id section_status = self._data_stores.progress_store.get_section_status( @@ -173,7 +173,7 @@ def _get_row_for_section( def _get_rows( self, enabled_section_ids: Iterable[str] - ) -> list[dict[str, Union[str, list]]]: + ) -> list[dict[str, str | list]]: rows: list[dict] = [] for section_id in enabled_section_ids: @@ -209,7 +209,7 @@ def _individual_response_enabled(self) -> bool: return True @cached_property - def _individual_response_url(self) -> Union[str, None]: + def _individual_response_url(self) -> str | None: if ( self._individual_response_enabled and self._schema.get_individual_response_show_on_hub() diff --git a/app/views/contexts/preview/preview_block.py b/app/views/contexts/preview/preview_block.py index 187a0677cd..9181879d36 100644 --- a/app/views/contexts/preview/preview_block.py +++ b/app/views/contexts/preview/preview_block.py @@ -1,4 +1,4 @@ -from typing import Any, Union +from typing import Any from werkzeug.datastructures import ImmutableDict @@ -19,12 +19,12 @@ def __init__( @staticmethod def _get_question( block: ImmutableDict, - ) -> dict[str, Union[str, dict]]: + ) -> dict[str, str | dict]: return PreviewQuestion( block=block, ).serialize() - def serialize(self) -> dict[str, Union[str, dict, Any]]: + def serialize(self) -> dict[str, str | dict | Any]: return { "question": self._question, } diff --git a/app/views/contexts/preview/preview_question.py b/app/views/contexts/preview/preview_question.py index 29364aa80d..30630f890a 100644 --- a/app/views/contexts/preview/preview_question.py +++ b/app/views/contexts/preview/preview_question.py @@ -1,4 +1,4 @@ -from typing import Any, Union +from typing import Any from flask_babel import lazy_gettext from werkzeug.datastructures import ImmutableDict @@ -50,7 +50,7 @@ def _build_answers(self) -> list[dict]: answers_list.append(answer_dict) return answers_list - def serialize(self) -> dict[str, Union[str, dict, Any]]: + def serialize(self) -> dict[str, str | dict | Any]: return { "id": self._block_id, "title": self._title, diff --git a/app/views/contexts/preview_context.py b/app/views/contexts/preview_context.py index f3c58b9e00..2abc9a9005 100644 --- a/app/views/contexts/preview_context.py +++ b/app/views/contexts/preview_context.py @@ -1,4 +1,4 @@ -from typing import Generator, Union +from typing import Generator from flask_babel import lazy_gettext @@ -26,7 +26,7 @@ def __init__( placeholder_preview_mode=True, ) - def __call__(self) -> dict[str, Union[str, list, bool]]: + def __call__(self) -> dict[str, str | list | bool]: sections = list(self.build_all_sections()) return { "sections": sections, diff --git a/app/views/contexts/section_summary_context.py b/app/views/contexts/section_summary_context.py index 335cce9b59..bd83394339 100644 --- a/app/views/contexts/section_summary_context.py +++ b/app/views/contexts/section_summary_context.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import Any, Generator, Iterable, Mapping, Union +from typing import Any, Generator, Iterable, Mapping from werkzeug.datastructures import ImmutableDict @@ -77,7 +77,7 @@ def section(self) -> ImmutableDict: section: ImmutableDict = self._schema.get_section(self.current_location.section_id) # type: ignore return section - def get_page_title(self, title_for_location: Union[Mapping, str]) -> str: + def get_page_title(self, title_for_location: Mapping | str) -> str: section_repeating_page_title = ( self._schema.get_repeating_page_title_for_section( self.current_location.section_id @@ -146,7 +146,7 @@ def build_summary( return groups - def title_for_location(self) -> Union[str, dict]: + def title_for_location(self) -> str | dict: section_id = self.current_location.section_id return ( # Type ignore: section id should exist at this point @@ -170,7 +170,7 @@ def _custom_summary_elements( ) yield list_collector_block.list_summary_element(summary_element) - def _get_safe_page_title(self, title: Union[Mapping, str]) -> str: + def _get_safe_page_title(self, title: Mapping | str) -> str: return ( safe_content(self._schema.get_single_string_value(title)) if title else "" ) diff --git a/app/views/contexts/submit_questionnaire_context.py b/app/views/contexts/submit_questionnaire_context.py index 2325f64acd..8f3aed9977 100644 --- a/app/views/contexts/submit_questionnaire_context.py +++ b/app/views/contexts/submit_questionnaire_context.py @@ -1,4 +1,4 @@ -from typing import Mapping, Union +from typing import Mapping from flask_babel import lazy_gettext @@ -7,7 +7,7 @@ class SubmitQuestionnaireContext(Context): - def __call__(self) -> dict[str, Union[str, dict]]: + def __call__(self) -> dict[str, str | dict]: submission_schema: Mapping = self._schema.get_submission() title = submission_schema.get("title") or lazy_gettext( diff --git a/app/views/contexts/summary/question.py b/app/views/contexts/summary/question.py index cbb35a69df..da54ae3790 100644 --- a/app/views/contexts/summary/question.py +++ b/app/views/contexts/summary/question.py @@ -1,4 +1,4 @@ -from typing import Any, Mapping, Optional +from typing import Any, Mapping from flask import url_for from markupsafe import Markup, escape @@ -77,7 +77,7 @@ def __init__( def get_answer( self, answer_store: AnswerStore, answer_id: str, list_item_id: str | None = None - ) -> Optional[AnswerValueEscapedTypes]: + ) -> AnswerValueEscapedTypes | None: answer = answer_store.get_answer( answer_id, list_item_id or self.list_item_id ) or self.schema.get_default_answer(answer_id) @@ -178,7 +178,7 @@ def _build_answer( answer_store: AnswerStore, question_schema: QuestionSchemaType, answer_schema: Mapping[str, Any], - answer_value: Optional[AnswerValueEscapedTypes] = None, + answer_value: AnswerValueEscapedTypes | None = None, ) -> InferredAnswerValueTypes: if answer_value is None: return None @@ -200,8 +200,8 @@ def _build_answer( return answer_value def _build_date_range_answer( - self, answer_store: AnswerStore, answer: Optional[AnswerValueEscapedTypes] - ) -> dict[str, Optional[AnswerValueEscapedTypes]]: + self, answer_store: AnswerStore, answer: AnswerValueEscapedTypes | None + ) -> dict[str, AnswerValueEscapedTypes | None]: next_answer = next(self.answer_schemas) to_date = self.get_answer(answer_store, next_answer["id"]) return {"from": answer, "to": to_date} @@ -233,7 +233,7 @@ def _build_checkbox_answers( answer: Markup, answer_schema: Mapping[str, Any], answer_store: AnswerStore, - ) -> Optional[list[RadioCheckboxTypes]]: + ) -> list[RadioCheckboxTypes] | None: multiple_answers = [] for option in self.get_answer_options(answer_schema): if escape(option["value"]) in answer: @@ -255,7 +255,7 @@ def _build_radio_answer( answer: Markup, answer_schema: Mapping[str, Any], answer_store: AnswerStore, - ) -> Optional[RadioCheckboxTypes]: + ) -> RadioCheckboxTypes | None: for option in self.get_answer_options(answer_schema): if answer == escape(option["value"]): detail_answer_value = self._get_detail_answer_value( @@ -268,15 +268,15 @@ def _build_radio_answer( def _get_detail_answer_value( self, option: dict, answer_store: AnswerStore - ) -> Optional[AnswerValueEscapedTypes]: + ) -> AnswerValueEscapedTypes | None: if "detail_answer" in option: return self.get_answer(answer_store, option["detail_answer"]["id"]) def _build_dropdown_answer( self, - answer: Optional[AnswerValueEscapedTypes], + answer: AnswerValueEscapedTypes | None, answer_schema: Mapping[str, Any], - ) -> Optional[str]: + ) -> str | None: for option in self.get_answer_options(answer_schema): if answer == option["value"]: return option["label"] diff --git a/app/views/contexts/thank_you_context.py b/app/views/contexts/thank_you_context.py index fcad74e705..925208f1e4 100644 --- a/app/views/contexts/thank_you_context.py +++ b/app/views/contexts/thank_you_context.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, Optional +from typing import Any from flask import url_for from flask_babel import lazy_gettext @@ -23,8 +23,8 @@ def build_thank_you_context( metadata: MetadataProxy, submitted_at: datetime, survey_type: SurveyType, - guidance_content: Optional[dict] = None, - confirmation_email_form: Optional[EmailForm] = None, + guidance_content: dict | None = None, + confirmation_email_form: EmailForm | None = None, ) -> dict[str, Any]: if (ru_name := metadata["ru_name"]) and (trad_as := metadata["trad_as"]): submission_text = lazy_gettext( diff --git a/app/views/contexts/view_submitted_response_context.py b/app/views/contexts/view_submitted_response_context.py index a4f8d0cf40..fc2a348d37 100644 --- a/app/views/contexts/view_submitted_response_context.py +++ b/app/views/contexts/view_submitted_response_context.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import Union from flask import url_for from flask_babel import lazy_gettext @@ -20,7 +19,7 @@ def build_view_submitted_response_context( schema: QuestionnaireSchema, questionnaire_store: QuestionnaireStore, survey_type: SurveyType, -) -> dict[str, Union[str, datetime, dict]]: +) -> dict[str, str | datetime | dict]: view_submitted_response_expired = has_view_submitted_response_expired( questionnaire_store.submitted_at # type: ignore ) diff --git a/app/views/handlers/block.py b/app/views/handlers/block.py index 04ccee7c72..16bab685d4 100644 --- a/app/views/handlers/block.py +++ b/app/views/handlers/block.py @@ -1,6 +1,6 @@ from datetime import datetime, timezone from functools import cached_property -from typing import Mapping, MutableMapping, Optional, Union +from typing import Mapping, MutableMapping from structlog import get_logger from werkzeug.datastructures import ImmutableDict, ImmutableMultiDict @@ -41,7 +41,7 @@ def __init__( # Type ignore: Block has to exist at this point. Block existence is checked beforehand in block_factory.py self.block: ImmutableDict = self._schema.get_block(self._current_location.block_id) # type: ignore self._routing_path = self._get_routing_path() - self.page_title: Optional[str] = None + self.page_title: str | None = None self._return_location = ReturnLocation( return_to=request_args.get("return_to"), @@ -122,7 +122,7 @@ def _get_routing_path(self) -> RoutingPath: return self.router.routing_path(self._current_location.section_key) def _update_section_completeness( - self, location: Optional[Union[Location, RelationshipLocation]] = None + self, location: Location | RelationshipLocation | None = None ) -> None: location = location or self._current_location diff --git a/app/views/handlers/feedback.py b/app/views/handlers/feedback.py index 1fe9d35d19..72a63f32fc 100644 --- a/app/views/handlers/feedback.py +++ b/app/views/handlers/feedback.py @@ -1,6 +1,6 @@ from datetime import datetime, timezone from functools import cached_property -from typing import Any, Mapping, MutableMapping, Optional, Union +from typing import Any, Mapping, MutableMapping from flask import current_app from flask_babel import gettext, lazy_gettext @@ -43,7 +43,7 @@ def __init__( questionnaire_store: QuestionnaireStore, schema: QuestionnaireSchema, session_store: SessionStore, - form_data: Optional[MultiDict[str, Any]], + form_data: MultiDict[str, Any] | None, ): if not self.is_enabled(schema): raise FeedbackNotEnabled @@ -65,7 +65,7 @@ def form(self) -> QuestionnaireForm: form_data=self._form_data, ) - def get_context(self) -> Mapping[str, Union[str, bool, dict]]: + def get_context(self) -> Mapping[str, str | bool | dict]: return build_feedback_context(self.question_schema, self.form) def get_page_title(self) -> str: @@ -108,14 +108,14 @@ def handle_post(self) -> None: tx_id=tx_id, case_id=case_id, **additional_metadata ) - submitter: Union[GCSFeedbackSubmitter, LogFeedbackSubmitter] = current_app.eq["feedback_submitter"] # type: ignore + submitter: GCSFeedbackSubmitter | LogFeedbackSubmitter = current_app.eq["feedback_submitter"] # type: ignore if not submitter.upload(feedback_metadata(), encrypted_message): raise FeedbackUploadFailed() self._session_store.save() @cached_property - def question_schema(self) -> Mapping[str, Union[str, list]]: + def question_schema(self) -> Mapping[str, str | list]: return { "type": "General", "id": "feedback", @@ -220,10 +220,10 @@ class FeedbackPayloadV2: def __init__( self, metadata: MetadataProxy, - response_metadata: MutableMapping[str, Union[str, int, list]], + response_metadata: MutableMapping[str, str | int | list], schema: QuestionnaireSchema, - case_id: Optional[str], - submission_language_code: Optional[str], + case_id: str | None, + submission_language_code: str | None, feedback_count: int, feedback_text: str, feedback_type: str, diff --git a/app/views/handlers/individual_response.py b/app/views/handlers/individual_response.py index 79c36fce56..f16741a436 100644 --- a/app/views/handlers/individual_response.py +++ b/app/views/handlers/individual_response.py @@ -1,6 +1,6 @@ from datetime import datetime, timezone from functools import cached_property -from typing import Any, Mapping, Optional +from typing import Any, Mapping from uuid import uuid4 from flask import current_app, redirect @@ -972,7 +972,7 @@ def handle_post(self) -> Response: class IndividualResponseFulfilmentRequest(FulfilmentRequest): - def __init__(self, metadata: MetadataProxy, mobile_number: Optional[str] = None): + def __init__(self, metadata: MetadataProxy, mobile_number: str | None = None): self._metadata = metadata self._mobile_number = mobile_number self._fulfilment_type = "sms" if self._mobile_number else "postal" @@ -991,7 +991,7 @@ def _get_contact_mapping(self) -> Mapping: else {} ) - def _get_fulfilment_code(self) -> Optional[str]: + def _get_fulfilment_code(self) -> str | None: fulfilment_codes = { "sms": { GB_ENG_REGION_CODE: "UACITA1", diff --git a/app/views/handlers/view_submitted_response.py b/app/views/handlers/view_submitted_response.py index d898db1422..40ee9c059f 100644 --- a/app/views/handlers/view_submitted_response.py +++ b/app/views/handlers/view_submitted_response.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import Union from flask_babel import lazy_gettext @@ -42,7 +41,7 @@ def has_expired(self) -> bool: ) return False - def get_context(self) -> dict[str, Union[str, datetime, dict]]: + def get_context(self) -> dict[str, str | datetime | dict]: return build_view_submitted_response_context( self._language, self._schema, self._questionnaire_store, get_survey_type() ) diff --git a/pyproject.toml b/pyproject.toml index 4a1f78e9bc..574a84c484 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -126,7 +126,6 @@ extend-ignore = [ "UP032", # Use f-string instead of `format` call "UP018", # Unnecessary {literal_type} call (rewrite as a literal) "UP015", # Unnecessary open mode parameters - "UP007", # Use `X | Y` for type annotations "UP009", # UTF-8 encoding declaration is unnecessary "UP017", # Use `datetime.UTC` alias "UP033", # Use @functools.cache instead of @functools.lru_cache(maxsize=None) diff --git a/tests/app/questionnaire/rules/test_rule_evaluator.py b/tests/app/questionnaire/rules/test_rule_evaluator.py index b76e0d496e..65283a1b5f 100644 --- a/tests/app/questionnaire/rules/test_rule_evaluator.py +++ b/tests/app/questionnaire/rules/test_rule_evaluator.py @@ -1,5 +1,4 @@ from datetime import datetime, timezone -from typing import Optional, Union import pytest from freezegun import freeze_time @@ -47,10 +46,10 @@ def get_rule_evaluator( language="en", schema: QuestionnaireSchema = None, data_stores: DataStores = None, - location: Union[Location, RelationshipLocation] = Location( + location: Location | RelationshipLocation = Location( section_id="test-section", block_id="test-block" ), - routing_path_block_ids: Optional[list] = None, + routing_path_block_ids: list | None = None, ): if not schema: schema = get_mock_schema() diff --git a/tests/app/questionnaire/test_value_source_resolver.py b/tests/app/questionnaire/test_value_source_resolver.py index c6220782fa..0e4e43fcd0 100644 --- a/tests/app/questionnaire/test_value_source_resolver.py +++ b/tests/app/questionnaire/test_value_source_resolver.py @@ -1,5 +1,4 @@ # pylint: disable=too-many-lines -from typing import Optional, Union import pytest from mock import MagicMock, Mock @@ -60,11 +59,11 @@ def get_calculation_block( def get_value_source_resolver( schema: QuestionnaireSchema = None, data_stores: DataStores = None, - location: Union[Location, RelationshipLocation] = Location( + location: Location | RelationshipLocation = Location( section_id="test-section", block_id="test-block" ), - list_item_id: Optional[str] = None, - routing_path_block_ids: Optional[list] = None, + list_item_id: str | None = None, + routing_path_block_ids: list | None = None, use_default_answer=False, escape_answer_values=False, ):