Skip to content

Commit

Permalink
make sure ModeBuilder generates tz aware values for Timestamptz (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
dantownsend authored Jan 23, 2024
1 parent bade80f commit 44940bb
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 14 deletions.
5 changes: 5 additions & 0 deletions piccolo/columns/column_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,6 +1005,11 @@ class Concert(Table):
"""

value_type = datetime

# Currently just used by ModelBuilder, to know that we want a timezone
# aware datetime.
tz_aware = True

timedelta_delegate = TimedeltaDelegate()

def __init__(
Expand Down
13 changes: 8 additions & 5 deletions piccolo/testing/model_builder.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

import datetime
import json
import typing as t
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from uuid import UUID

Expand All @@ -16,13 +16,13 @@ class ModelBuilder:
__DEFAULT_MAPPER: t.Dict[t.Type, t.Callable] = {
bool: RandomBuilder.next_bool,
bytes: RandomBuilder.next_bytes,
date: RandomBuilder.next_date,
datetime: RandomBuilder.next_datetime,
datetime.date: RandomBuilder.next_date,
datetime.datetime: RandomBuilder.next_datetime,
float: RandomBuilder.next_float,
int: RandomBuilder.next_int,
str: RandomBuilder.next_str,
time: RandomBuilder.next_time,
timedelta: RandomBuilder.next_timedelta,
datetime.time: RandomBuilder.next_time,
datetime.timedelta: RandomBuilder.next_timedelta,
UUID: RandomBuilder.next_uuid,
}

Expand Down Expand Up @@ -155,6 +155,9 @@ def _randomize_attribute(cls, column: Column) -> t.Any:
random_value = RandomBuilder.next_float(
maximum=10 ** (precision - scale), scale=scale
)
elif column.value_type == datetime.datetime:
tz_aware = getattr(column, "tz_aware", False)
random_value = RandomBuilder.next_datetime(tz_aware=tz_aware)
elif column.value_type == list:
length = RandomBuilder.next_int(maximum=10)
base_type = t.cast(Array, column).base_column.value_type
Expand Down
19 changes: 10 additions & 9 deletions piccolo/testing/random_builder.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import datetime
import enum
import random
import string
import typing as t
import uuid
from datetime import date, datetime, time, timedelta


class RandomBuilder:
Expand All @@ -16,22 +16,23 @@ def next_bytes(cls, length=8) -> bytes:
return random.getrandbits(length * 8).to_bytes(length, "little")

@classmethod
def next_date(cls) -> date:
return date(
def next_date(cls) -> datetime.date:
return datetime.date(
year=random.randint(2000, 2050),
month=random.randint(1, 12),
day=random.randint(1, 28),
)

@classmethod
def next_datetime(cls) -> datetime:
return datetime(
def next_datetime(cls, tz_aware: bool = False) -> datetime.datetime:
return datetime.datetime(
year=random.randint(2000, 2050),
month=random.randint(1, 12),
day=random.randint(1, 28),
hour=random.randint(0, 23),
minute=random.randint(0, 59),
second=random.randint(0, 59),
tzinfo=datetime.timezone.utc if tz_aware else None,
)

@classmethod
Expand All @@ -53,16 +54,16 @@ def next_str(cls, length=16) -> str:
)

@classmethod
def next_time(cls) -> time:
return time(
def next_time(cls) -> datetime.time:
return datetime.time(
hour=random.randint(0, 23),
minute=random.randint(0, 59),
second=random.randint(0, 59),
)

@classmethod
def next_timedelta(cls) -> timedelta:
return timedelta(
def next_timedelta(cls) -> datetime.timedelta:
return datetime.timedelta(
days=random.randint(1, 7),
hours=random.randint(1, 23),
minutes=random.randint(0, 59),
Expand Down
21 changes: 21 additions & 0 deletions tests/testing/test_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
LazyTableReference,
Numeric,
Real,
Timestamp,
Timestamptz,
Varchar,
)
from piccolo.table import Table, create_db_tables_sync, drop_db_tables_sync
Expand Down Expand Up @@ -97,6 +99,25 @@ def test_choices(self):
["s", "l", "m"],
)

def test_datetime(self):
"""
Make sure that ``ModelBuilder`` generates timezone aware datetime
objects for ``Timestamptz`` columns, and timezone naive datetime
objects for ``Timestamp`` columns.
"""

class Table1(Table):
starts = Timestamptz()

class Table2(Table):
starts = Timestamp()

model_1 = ModelBuilder.build_sync(Table1, persist=False)
assert model_1.starts.tzinfo is not None

model_2 = ModelBuilder.build_sync(Table2, persist=False)
assert model_2.starts.tzinfo is None

def test_foreign_key(self):
model = ModelBuilder.build_sync(Band, persist=True)

Expand Down

0 comments on commit 44940bb

Please sign in to comment.