Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add timezone name validation #193

Merged
merged 4 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 189 additions & 0 deletions pydantic_extra_types/timezone_name.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
"""Time zone name validation and serialization module."""

from __future__ import annotations

import importlib
import sys
import warnings
from typing import Any, Callable, List, Set, Type, cast

from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler
from pydantic_core import PydanticCustomError, core_schema


def _is_available(name: str) -> bool:
"""Check if a module is available for import."""
try:
importlib.import_module(name)
return True
except ModuleNotFoundError: # pragma: no cover
return False


def _tz_provider_from_zone_info() -> Set[str]: # pragma: no cover
"""Get timezones from the zoneinfo module."""
from zoneinfo import available_timezones

return set(available_timezones())


def _tz_provider_from_pytz() -> Set[str]: # pragma: no cover
"""Get timezones from the pytz module."""
from pytz import all_timezones

return set(all_timezones)


def _warn_about_pytz_usage() -> None:
"""Warn about using pytz with Python 3.9 or later."""
warnings.warn( # pragma: no cover
'Projects using Python 3.9 or later should be using the support now included as part of the standard library. '
'Please consider switching to the standard library (zoneinfo) module.'
)


def get_timezones() -> Set[str]:
"""Determine the timezone provider and return available timezones."""
if _is_available('zoneinfo') and _is_available('tzdata'): # pragma: no cover
return _tz_provider_from_zone_info()
elif _is_available('pytz'): # pragma: no cover
if sys.version_info[:2] > (3, 8):
_warn_about_pytz_usage()
return _tz_provider_from_pytz()
else: # pragma: no cover
if sys.version_info[:2] == (3, 8):
raise ImportError('No pytz module found. Please install it with "pip install pytz"')
raise ImportError('No timezone provider found. Please install tzdata with "pip install tzdata"')


class TimeZoneNameSettings(type):
def __new__(cls, name: str, bases: tuple[type, ...], dct: dict[str, Any], **kwargs: Any) -> Type[TimeZoneName]:
dct['strict'] = kwargs.pop('strict', True)
return cast(Type[TimeZoneName], super().__new__(cls, name, bases, dct))

def __init__(cls, name: str, bases: tuple[type, ...], dct: dict[str, Any], **kwargs: Any) -> None:
super().__init__(name, bases, dct)
cls.strict = kwargs.get('strict', True)


def timezone_name_settings(**kwargs: Any) -> Callable[[Type[TimeZoneName]], Type[TimeZoneName]]:
def wrapper(cls: Type[TimeZoneName]) -> Type[TimeZoneName]:
cls.strict = kwargs.get('strict', True)
return cls

return wrapper


@timezone_name_settings(strict=True)
class TimeZoneName(str):
"""
TimeZoneName is a custom string subclass for validating and serializing timezone names.

The TimeZoneName class uses the IANA Time Zone Database for validation.
It supports both strict and non-strict modes for timezone name validation.


## Examples:

Some examples of using the TimeZoneName class:

### Normal usage:

```python
from pydantic_extra_types.timezone_name import TimeZoneName
from pydantic import BaseModel
class Location(BaseModel):
city: str
timezone: TimeZoneName

loc = Location(city="New York", timezone="America/New_York")
print(loc.timezone)

>> America/New_York

```

### Non-strict mode:

```python

from pydantic_extra_types.timezone_name import TimeZoneName, timezone_name_settings

@timezone_name_settings(strict=False)
class TZNonStrict(TimeZoneName):
pass

tz = TZNonStrict("america/new_york")

print(tz)

>> america/new_york

```
"""

__slots__: List[str] = []
allowed_values: Set[str] = set(get_timezones())
allowed_values_list: List[str] = sorted(allowed_values)
allowed_values_upper_to_correct: dict[str, str] = {val.upper(): val for val in allowed_values}
strict: bool

@classmethod
def _validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> TimeZoneName:
"""
Validate a time zone name from the provided str value.

Args:
__input_value: The str value to be validated.
_: The Pydantic ValidationInfo.

Returns:
The validated time zone name.

Raises:
PydanticCustomError: If the timezone name is not valid.
"""
if __input_value not in cls.allowed_values: # be fast for the most common case
if not cls.strict:
upper_value = __input_value.strip().upper()
if upper_value in cls.allowed_values_upper_to_correct:
return cls(cls.allowed_values_upper_to_correct[upper_value])
raise PydanticCustomError('TimeZoneName', 'Invalid timezone name.')
return cls(__input_value)

@classmethod
def __get_pydantic_core_schema__(
cls, _: Type[Any], __: GetCoreSchemaHandler
) -> core_schema.AfterValidatorFunctionSchema:
"""
Return a Pydantic CoreSchema with the timezone name validation.

Args:
_: The source type.
__: The handler to get the CoreSchema.

Returns:
A Pydantic CoreSchema with the timezone name validation.
"""
return core_schema.with_info_after_validator_function(
cls._validate,
core_schema.str_schema(min_length=1),
)

@classmethod
def __get_pydantic_json_schema__(
cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> dict[str, Any]:
"""
Return a Pydantic JSON Schema with the timezone name validation.

Args:
schema: The Pydantic CoreSchema.
handler: The handler to get the JSON Schema.

Returns:
A Pydantic JSON Schema with the timezone name validation.
"""
json_schema = handler(schema)
json_schema.update({'enum': cls.allowed_values_list})
return json_schema
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ all = [
'semver>=3.0.2',
'python-ulid>=1,<2; python_version<"3.9"',
'python-ulid>=1,<3; python_version>="3.9"',
'pendulum>=3.0.0,<4.0.0'
'pendulum>=3.0.0,<4.0.0',
'pytz>=2024.1',
'tzdata>=2024.1',
]
phonenumbers = ['phonenumbers>=8,<9']
pycountry = ['pycountry>=23']
Expand Down
1 change: 1 addition & 0 deletions requirements/linting.in
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ pre-commit
mypy
annotated-types
ruff
types-pytz
4 changes: 3 additions & 1 deletion requirements/linting.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# This file is autogenerated by pip-compile with Python 3.11
# This file is autogenerated by pip-compile with Python 3.12
# by the following command:
#
# pip-compile --no-emit-index-url --output-file=requirements/linting.txt requirements/linting.in
Expand Down Expand Up @@ -28,6 +28,8 @@ pyyaml==6.0.1
# via pre-commit
ruff==0.5.0
# via -r requirements/linting.in
types-pytz==2024.1.0.20240417
# via -r requirements/linting.in
typing-extensions==4.10.0
# via mypy
virtualenv==20.25.1
Expand Down
19 changes: 19 additions & 0 deletions tests/test_json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from pydantic_extra_types.pendulum_dt import DateTime
from pydantic_extra_types.script_code import ISO_15924
from pydantic_extra_types.semantic_version import SemanticVersion
from pydantic_extra_types.timezone_name import TimeZoneName
from pydantic_extra_types.ulid import ULID

languages = [lang.alpha_3 for lang in pycountry.languages]
Expand All @@ -36,6 +37,8 @@

scripts = [script.alpha_4 for script in pycountry.scripts]

timezone_names = TimeZoneName.allowed_values_list

everyday_currencies.sort()


Expand Down Expand Up @@ -335,6 +338,22 @@
'type': 'object',
},
),
(
TimeZoneName,
{
'properties': {
'x': {
'title': 'X',
'type': 'string',
'enum': timezone_names,
'minLength': 1,
}
},
'required': ['x'],
'title': 'Model',
'type': 'object',
},
),
],
)
def test_json_schema(cls, expected):
Expand Down
Loading
Loading