diff --git a/piccolo/columns/column_types.py b/piccolo/columns/column_types.py index 33c0f261f..75634bdf3 100644 --- a/piccolo/columns/column_types.py +++ b/piccolo/columns/column_types.py @@ -1005,6 +1005,11 @@ class Concert(Table): """ value_type = datetime + + # Currently just used by ModelBuilder, to know that we want a timezone + # aware datetime. + tz_aware = True + timedelta_delegate = TimedeltaDelegate() def __init__( diff --git a/piccolo/testing/model_builder.py b/piccolo/testing/model_builder.py index 15b50416c..2994b3ec3 100644 --- a/piccolo/testing/model_builder.py +++ b/piccolo/testing/model_builder.py @@ -1,8 +1,8 @@ from __future__ import annotations +import datetime import json import typing as t -from datetime import date, datetime, time, timedelta from decimal import Decimal from uuid import UUID @@ -16,13 +16,13 @@ class ModelBuilder: __DEFAULT_MAPPER: t.Dict[t.Type, t.Callable] = { bool: RandomBuilder.next_bool, bytes: RandomBuilder.next_bytes, - date: RandomBuilder.next_date, - datetime: RandomBuilder.next_datetime, + datetime.date: RandomBuilder.next_date, + datetime.datetime: RandomBuilder.next_datetime, float: RandomBuilder.next_float, int: RandomBuilder.next_int, str: RandomBuilder.next_str, - time: RandomBuilder.next_time, - timedelta: RandomBuilder.next_timedelta, + datetime.time: RandomBuilder.next_time, + datetime.timedelta: RandomBuilder.next_timedelta, UUID: RandomBuilder.next_uuid, } @@ -155,6 +155,9 @@ def _randomize_attribute(cls, column: Column) -> t.Any: random_value = RandomBuilder.next_float( maximum=10 ** (precision - scale), scale=scale ) + elif column.value_type == datetime.datetime: + tz_aware = getattr(column, "tz_aware", False) + random_value = RandomBuilder.next_datetime(tz_aware=tz_aware) elif column.value_type == list: length = RandomBuilder.next_int(maximum=10) base_type = t.cast(Array, column).base_column.value_type diff --git a/piccolo/testing/random_builder.py b/piccolo/testing/random_builder.py index dfc46a9f2..bca29a7f2 100644 --- a/piccolo/testing/random_builder.py +++ b/piccolo/testing/random_builder.py @@ -1,9 +1,9 @@ +import datetime import enum import random import string import typing as t import uuid -from datetime import date, datetime, time, timedelta class RandomBuilder: @@ -16,22 +16,23 @@ def next_bytes(cls, length=8) -> bytes: return random.getrandbits(length * 8).to_bytes(length, "little") @classmethod - def next_date(cls) -> date: - return date( + def next_date(cls) -> datetime.date: + return datetime.date( year=random.randint(2000, 2050), month=random.randint(1, 12), day=random.randint(1, 28), ) @classmethod - def next_datetime(cls) -> datetime: - return datetime( + def next_datetime(cls, tz_aware: bool = False) -> datetime.datetime: + return datetime.datetime( year=random.randint(2000, 2050), month=random.randint(1, 12), day=random.randint(1, 28), hour=random.randint(0, 23), minute=random.randint(0, 59), second=random.randint(0, 59), + tzinfo=datetime.timezone.utc if tz_aware else None, ) @classmethod @@ -53,16 +54,16 @@ def next_str(cls, length=16) -> str: ) @classmethod - def next_time(cls) -> time: - return time( + def next_time(cls) -> datetime.time: + return datetime.time( hour=random.randint(0, 23), minute=random.randint(0, 59), second=random.randint(0, 59), ) @classmethod - def next_timedelta(cls) -> timedelta: - return timedelta( + def next_timedelta(cls) -> datetime.timedelta: + return datetime.timedelta( days=random.randint(1, 7), hours=random.randint(1, 23), minutes=random.randint(0, 59), diff --git a/tests/testing/test_model_builder.py b/tests/testing/test_model_builder.py index 4eead76bf..f56fcf956 100644 --- a/tests/testing/test_model_builder.py +++ b/tests/testing/test_model_builder.py @@ -11,6 +11,8 @@ LazyTableReference, Numeric, Real, + Timestamp, + Timestamptz, Varchar, ) from piccolo.table import Table, create_db_tables_sync, drop_db_tables_sync @@ -97,6 +99,25 @@ def test_choices(self): ["s", "l", "m"], ) + def test_datetime(self): + """ + Make sure that ``ModelBuilder`` generates timezone aware datetime + objects for ``Timestamptz`` columns, and timezone naive datetime + objects for ``Timestamp`` columns. + """ + + class Table1(Table): + starts = Timestamptz() + + class Table2(Table): + starts = Timestamp() + + model_1 = ModelBuilder.build_sync(Table1, persist=False) + assert model_1.starts.tzinfo is not None + + model_2 = ModelBuilder.build_sync(Table2, persist=False) + assert model_2.starts.tzinfo is None + def test_foreign_key(self): model = ModelBuilder.build_sync(Band, persist=True)