From 636f7057a2df19a426c5d069475a5dadeee58778 Mon Sep 17 00:00:00 2001 From: 07pepa Date: Sun, 30 Jun 2024 18:13:53 +0200 Subject: [PATCH] add timezone name validation --- Makefile | 2 +- pydantic_extra_types/timezone_name.py | 125 ++++++++++++++++++++++++++ pyproject.toml | 5 +- tests/test_json_schema.py | 19 ++++ tests/test_timezone_names.py | 78 ++++++++++++++++ 5 files changed, 227 insertions(+), 2 deletions(-) create mode 100644 pydantic_extra_types/timezone_name.py create mode 100644 tests/test_timezone_names.py diff --git a/Makefile b/Makefile index f8d968f7..af890d57 100644 --- a/Makefile +++ b/Makefile @@ -28,7 +28,7 @@ lint: .PHONY: mypy mypy: - mypy pydantic_extra_types + @mypy pydantic_extra_types .PHONY: test test: diff --git a/pydantic_extra_types/timezone_name.py b/pydantic_extra_types/timezone_name.py new file mode 100644 index 00000000..27a50825 --- /dev/null +++ b/pydantic_extra_types/timezone_name.py @@ -0,0 +1,125 @@ +"""Time zone name validation and serialization module.""" + +from __future__ import annotations + +import importlib +import sys +import warnings +from typing import Any, List + +from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler +from pydantic_core import PydanticCustomError, core_schema + + +def _is_available(name: str) -> bool: + try: + importlib.import_module(name=name) + return True + except ModuleNotFoundError: # pragma: no cover + return False + + +if _is_available('zoneinfo') and _is_available('tzdata'): # pragma: no cover + from zoneinfo import available_timezones + + def _tz_provider() -> set[str]: + return set(available_timezones()) + +elif _is_available('pytz'): # pragma: no cover + if sys.version_info[:2] > (3, 8): + warnings.warn( + 'Projects using Python 3.9 or later' + ' should be using the support now included as part of the standard library zone-info. ' + 'Please consider switching to the standard library module.' + ) + from pytz import all_timezones + + def _tz_provider() -> set[str]: + return set(all_timezones) +else: # pragma: no cover + if sys.version_info[:2] == (3, 8): + raise ImportError('No pytz module not found. Please install it with "pip install pytz') + raise ImportError('No timezone provider found. Please install tzdata' 'Please install it with "pip install tzdata"') + + +class TimeZoneNameSettings(type): + def __new__(cls, name, bases, dct, **kwargs): # type: ignore[no-untyped-def] + dct['strict'] = kwargs.pop('strict', True) + return super().__new__(cls, name, bases, dct) + + def __init__(cls, name, bases, dct, **kwargs): # type: ignore[no-untyped-def] + super().__init__(name, bases, dct) + cls.strict = kwargs.get('strict', True) + + +class TimeZoneName(str, metaclass=TimeZoneNameSettings): # type: ignore[misc] + """If the mode is not strict matching, it is case-insensitive with whitespace stripped. + Value is then coerced to the correct case.""" + + __slots__: List[str] = [] + allowed_values = set(_tz_provider()) + allowed_values_list = list(allowed_values) + allowed_values_list.sort() + allowed_values_upper_to_correct = {val.upper(): val for val in allowed_values} + + @classmethod + def _validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> TimeZoneName: + """ + Validate a time zone name from the provided str value. + + Args: + __input_value: The str value to be validated. + _: The Pydantic ValidationInfo. + + Returns: + The validated time zone name. + + Raises: + PydanticCustomError: If the timezone name is not valid. + """ + if __input_value not in cls.allowed_values: # be fast for the most common case + if not cls.strict: + upper_value = __input_value.strip().upper() + if upper_value in cls.allowed_values_upper_to_correct: + return cls(cls.allowed_values_upper_to_correct[upper_value]) + raise PydanticCustomError('TimeZoneName', 'Invalid timezone name.') + return cls(__input_value) + + @classmethod + def __get_pydantic_core_schema__( + cls, _: type[Any], __: GetCoreSchemaHandler + ) -> core_schema.AfterValidatorFunctionSchema: + """ + Return a Pydantic CoreSchema with the ISO 639-3 language code validation. + + Args: + _: The source type. + __: The handler to get the CoreSchema. + + Returns: + A Pydantic CoreSchema with the ISO 639-3 language code validation. + + """ + return core_schema.with_info_after_validator_function( + cls._validate, + core_schema.str_schema(min_length=1), + ) + + @classmethod + def __get_pydantic_json_schema__( + cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> dict[str, Any]: + """ + Return a Pydantic JSON Schema with the ISO 639-3 language code validation. + + Args: + schema: The Pydantic CoreSchema. + handler: The handler to get the JSON Schema. + + Returns: + A Pydantic JSON Schema with the ISO 639-3 language code validation. + + """ + json_schema = handler(schema) + json_schema.update({'enum': cls.allowed_values_list}) + return json_schema diff --git a/pyproject.toml b/pyproject.toml index 26f15d0e..b4b6a1a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,10 @@ all = [ 'pycountry>=23', 'python-ulid>=1,<2; python_version<"3.9"', 'python-ulid>=1,<3; python_version>="3.9"', - 'pendulum>=3.0.0,<4.0.0' + 'pendulum>=3.0.0,<4.0.0', + 'pytz>=2024.1', + 'tzdata>=2024.1', + 'types-pytz>=2024.4.0.2024' ] phonenumbers = ['phonenumbers>=8,<9'] pycountry = ['pycountry>=23'] diff --git a/tests/test_json_schema.py b/tests/test_json_schema.py index 43ad9326..c37e4357 100644 --- a/tests/test_json_schema.py +++ b/tests/test_json_schema.py @@ -18,6 +18,7 @@ from pydantic_extra_types.payment import PaymentCardNumber from pydantic_extra_types.pendulum_dt import DateTime from pydantic_extra_types.script_code import ISO_15924 +from pydantic_extra_types.timezone_name import TimeZoneName from pydantic_extra_types.ulid import ULID languages = [lang.alpha_3 for lang in pycountry.languages] @@ -35,6 +36,8 @@ scripts = [script.alpha_4 for script in pycountry.scripts] +timezone_names = TimeZoneName.allowed_values_list + everyday_currencies.sort() @@ -325,6 +328,22 @@ 'type': 'object', }, ), + ( + TimeZoneName, + { + 'properties': { + 'x': { + 'title': 'X', + 'type': 'string', + 'enum': timezone_names, + 'minLength': 1, + } + }, + 'required': ['x'], + 'title': 'Model', + 'type': 'object', + }, + ), ], ) def test_json_schema(cls, expected): diff --git a/tests/test_timezone_names.py b/tests/test_timezone_names.py new file mode 100644 index 00000000..471ca27b --- /dev/null +++ b/tests/test_timezone_names.py @@ -0,0 +1,78 @@ +import re + +import pytest +import pytz +from pydantic import BaseModel, ValidationError + +from pydantic_extra_types.timezone_name import TimeZoneName + +has_zone_info = True +try: + from zoneinfo import available_timezones +except ImportError: + has_zone_info = False + +pytz_zones_bad = [(zone.lower(), zone) for zone in pytz.all_timezones] +pytz_zones_bad.extend([(f' {zone}', zone) for zone in pytz.all_timezones_set]) + + +class TZNameCheck(BaseModel): + timezone_name: TimeZoneName + + +class TZNonStrict(TimeZoneName, strict=False): + pass + + +class NonStrictTzName(BaseModel): + timezone_name: TZNonStrict + + +@pytest.mark.parametrize('zone', pytz.all_timezones) +def test_all_timezones_non_strict_pytz(zone): + assert TZNameCheck(timezone_name=zone).timezone_name == zone + assert NonStrictTzName(timezone_name=zone).timezone_name == zone + + +@pytest.mark.parametrize('zone', pytz_zones_bad) +def test_all_timezones_pytz_lower(zone): + assert NonStrictTzName(timezone_name=zone[0]).timezone_name == zone[1] + + +def test_fail_non_existing_timezone(): + with pytest.raises( + ValidationError, + match=re.escape( + '1 validation error for TZNameCheck\n' + 'timezone_name\n ' + 'Invalid timezone name. ' + "[type=TimeZoneName, input_value='mars', input_type=str]" + ), + ): + TZNameCheck(timezone_name='mars') + + with pytest.raises( + ValidationError, + match=re.escape( + '1 validation error for NonStrictTzName\n' + 'timezone_name\n ' + 'Invalid timezone name. ' + "[type=TimeZoneName, input_value='mars', input_type=str]" + ), + ): + NonStrictTzName(timezone_name='mars') + + +if has_zone_info: + zones = list(available_timezones()) + zones.sort() + zones_bad = [(zone.lower(), zone) for zone in zones] + + @pytest.mark.parametrize('zone', zones) + def test_all_timezones_zone_info(zone): + assert TZNameCheck(timezone_name=zone).timezone_name == zone + assert NonStrictTzName(timezone_name=zone).timezone_name == zone + + @pytest.mark.parametrize('zone', zones_bad) + def test_all_timezones_zone_info_NonStrict(zone): + assert NonStrictTzName(timezone_name=zone[0]).timezone_name == zone[1]