Skip to content

Commit

Permalink
Consolidate pendulum usage in prefect.server.services (#17100)
Browse files Browse the repository at this point in the history
  • Loading branch information
desertaxle authored Feb 11, 2025
1 parent b87f0e8 commit 949e865
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 68 deletions.
79 changes: 38 additions & 41 deletions src/prefect/server/database/orm_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union

import pendulum
import sqlalchemy as sa
from sqlalchemy import FetchedValue
from sqlalchemy.dialects import postgresql
Expand Down Expand Up @@ -46,11 +45,9 @@
Timestamp,
)
from prefect.server.utilities.encryption import decrypt_fernet, encrypt_fernet
from prefect.types._datetime import DateTime, now
from prefect.utilities.names import generate_slug

if TYPE_CHECKING:
DateTime = pendulum.DateTime

# for 'plain JSON' columns, use the postgresql variant (which comes with an
# extra operator) and fall back to the generic JSON variant for SQLite
sa_JSON: postgresql.JSON = postgresql.JSON().with_variant(sa.JSON(), "sqlite")
Expand Down Expand Up @@ -89,7 +86,7 @@ class Base(DeclarativeBase):
),
type_annotation_map={
uuid.UUID: UUID,
pendulum.DateTime: Timestamp,
DateTime: Timestamp,
},
)

Expand Down Expand Up @@ -121,17 +118,17 @@ def __tablename__(cls) -> str:
default=uuid.uuid4,
)

created: Mapped[pendulum.DateTime] = mapped_column(
server_default=sa.func.now(), default=lambda: pendulum.now("UTC")
created: Mapped[DateTime] = mapped_column(
server_default=sa.func.now(), default=lambda: now("UTC")
)

# onupdate is only called when statements are actually issued
# against the database. until COMMIT is issued, this column
# will not be updated
updated: Mapped[pendulum.DateTime] = mapped_column(
updated: Mapped[DateTime] = mapped_column(
index=True,
server_default=sa.func.now(),
default=lambda: pendulum.now("UTC"),
default=lambda: now("UTC"),
onupdate=sa.func.now(),
server_onupdate=FetchedValue(),
)
Expand Down Expand Up @@ -170,8 +167,8 @@ class FlowRunState(Base):
type: Mapped[schemas.states.StateType] = mapped_column(
sa.Enum(schemas.states.StateType, name="state_type"), index=True
)
timestamp: Mapped[pendulum.DateTime] = mapped_column(
server_default=sa.func.now(), default=lambda: pendulum.now("UTC")
timestamp: Mapped[DateTime] = mapped_column(
server_default=sa.func.now(), default=lambda: now("UTC")
)
name: Mapped[str] = mapped_column(index=True)
message: Mapped[Optional[str]]
Expand Down Expand Up @@ -235,8 +232,8 @@ class TaskRunState(Base):
type: Mapped[schemas.states.StateType] = mapped_column(
sa.Enum(schemas.states.StateType, name="state_type"), index=True
)
timestamp: Mapped[pendulum.DateTime] = mapped_column(
server_default=sa.func.now(), default=lambda: pendulum.now("UTC")
timestamp: Mapped[DateTime] = mapped_column(
server_default=sa.func.now(), default=lambda: now("UTC")
)
name: Mapped[str] = mapped_column(index=True)
message: Mapped[Optional[str]]
Expand Down Expand Up @@ -358,7 +355,7 @@ class TaskRunStateCache(Base):
"""

cache_key: Mapped[str] = mapped_column()
cache_expiration: Mapped[Optional[pendulum.DateTime]]
cache_expiration: Mapped[Optional[DateTime]]
task_run_state_id: Mapped[uuid.UUID]

@declared_attr.directive
Expand All @@ -385,12 +382,12 @@ class Run(Base):
sa.Enum(schemas.states.StateType, name="state_type")
)
state_name: Mapped[Optional[str]]
state_timestamp: Mapped[Optional[pendulum.DateTime]]
state_timestamp: Mapped[Optional[DateTime]]
run_count: Mapped[int] = mapped_column(server_default="0", default=0)
expected_start_time: Mapped[Optional[pendulum.DateTime]]
next_scheduled_start_time: Mapped[Optional[pendulum.DateTime]]
start_time: Mapped[Optional[pendulum.DateTime]]
end_time: Mapped[Optional[pendulum.DateTime]]
expected_start_time: Mapped[Optional[DateTime]]
next_scheduled_start_time: Mapped[Optional[DateTime]]
start_time: Mapped[Optional[DateTime]]
end_time: Mapped[Optional[DateTime]]
total_run_time: Mapped[datetime.timedelta] = mapped_column(
server_default="0", default=datetime.timedelta(0)
)
Expand All @@ -403,7 +400,7 @@ def estimated_run_time(self) -> datetime.timedelta:
if self.state_type and self.state_type == schemas.states.StateType.RUNNING:
if TYPE_CHECKING:
assert self.state_timestamp is not None
return self.total_run_time + (pendulum.now("UTC") - self.state_timestamp)
return self.total_run_time + (now("UTC") - self.state_timestamp)
else:
return self.total_run_time

Expand Down Expand Up @@ -445,10 +442,10 @@ def estimated_start_time_delta(self) -> datetime.timedelta:
elif (
self.start_time is None
and self.expected_start_time
and self.expected_start_time < pendulum.now("UTC")
and self.expected_start_time < now("UTC")
and self.state_type not in schemas.states.TERMINAL_STATES
):
return pendulum.now("UTC") - self.expected_start_time
return now("UTC") - self.expected_start_time
else:
return datetime.timedelta(0)

Expand Down Expand Up @@ -660,7 +657,7 @@ class TaskRun(Run):
task_key: Mapped[str] = mapped_column()
dynamic_key: Mapped[str] = mapped_column()
cache_key: Mapped[Optional[str]]
cache_expiration: Mapped[Optional[pendulum.DateTime]]
cache_expiration: Mapped[Optional[DateTime]]
task_version: Mapped[Optional[str]]
flow_run_run_count: Mapped[int] = mapped_column(server_default="0", default=0)
empirical_policy: Mapped[schemas.core.TaskRunPolicy] = mapped_column(
Expand Down Expand Up @@ -803,7 +800,7 @@ class Deployment(Base):
path: Mapped[Optional[str]]
entrypoint: Mapped[Optional[str]]

last_polled: Mapped[Optional[pendulum.DateTime]]
last_polled: Mapped[Optional[DateTime]]
status: Mapped[DeploymentStatus] = mapped_column(
sa.Enum(DeploymentStatus, name="deployment_status"),
default=DeploymentStatus.NOT_READY,
Expand Down Expand Up @@ -913,7 +910,7 @@ class Log(Base):
message: Mapped[str] = mapped_column(sa.Text)

# The client-side timestamp of this logged statement.
timestamp: Mapped[pendulum.DateTime] = mapped_column(index=True)
timestamp: Mapped[DateTime] = mapped_column(index=True)

__table_args__: Any = (
sa.Index(
Expand Down Expand Up @@ -1104,7 +1101,7 @@ class WorkQueue(Base):
concurrency_limit: Mapped[Optional[int]]
priority: Mapped[int]

last_polled: Mapped[Optional[pendulum.DateTime]]
last_polled: Mapped[Optional[DateTime]]
status: Mapped[WorkQueueStatus] = mapped_column(
sa.Enum(WorkQueueStatus, name="work_queue_status"),
default=WorkQueueStatus.NOT_READY,
Expand Down Expand Up @@ -1150,7 +1147,7 @@ class WorkPool(Base):
default=WorkPoolStatus.NOT_READY,
server_default=WorkPoolStatus.NOT_READY,
)
last_transitioned_status_at: Mapped[Optional[pendulum.DateTime]]
last_transitioned_status_at: Mapped[Optional[DateTime]]
last_status_event_id: Mapped[Optional[uuid.UUID]]

__table_args__: Any = (sa.UniqueConstraint("name"),)
Expand All @@ -1164,8 +1161,8 @@ class Worker(Base):
)

name: Mapped[str]
last_heartbeat_time: Mapped[pendulum.DateTime] = mapped_column(
server_default=sa.func.now(), default=lambda: pendulum.now("UTC")
last_heartbeat_time: Mapped[DateTime] = mapped_column(
server_default=sa.func.now(), default=lambda: now("UTC")
)
heartbeat_interval_seconds: Mapped[Optional[int]]

Expand Down Expand Up @@ -1194,8 +1191,8 @@ class Agent(Base):
sa.ForeignKey("work_queue.id"), index=True
)

last_activity_time: Mapped[pendulum.DateTime] = mapped_column(
server_default=sa.func.now(), default=lambda: pendulum.now("UTC")
last_activity_time: Mapped[DateTime] = mapped_column(
server_default=sa.func.now(), default=lambda: now("UTC")
)

__table_args__: Any = (sa.UniqueConstraint("name"),)
Expand Down Expand Up @@ -1249,7 +1246,7 @@ class FlowRunInput(Base):
class CsrfToken(Base):
token: Mapped[str]
client: Mapped[str] = mapped_column(unique=True)
expiration: Mapped[pendulum.DateTime]
expiration: Mapped[DateTime]


class Automation(Base):
Expand Down Expand Up @@ -1314,14 +1311,14 @@ class AutomationBucket(Base):

last_event: Mapped[Optional[ReceivedEvent]] = mapped_column(Pydantic(ReceivedEvent))

start: Mapped[pendulum.DateTime]
end: Mapped[pendulum.DateTime]
start: Mapped[DateTime]
end: Mapped[DateTime]

count: Mapped[int]

last_operation: Mapped[Optional[str]]

triggered_at: Mapped[Optional[pendulum.DateTime]]
triggered_at: Mapped[Optional[DateTime]]


class AutomationRelatedResource(Base):
Expand Down Expand Up @@ -1367,7 +1364,7 @@ class CompositeTriggerChildFiring(Base):

child_trigger_id: Mapped[uuid.UUID]
child_firing_id: Mapped[uuid.UUID]
child_fired_at: Mapped[Optional[pendulum.DateTime]]
child_fired_at: Mapped[Optional[DateTime]]
child_firing: Mapped[Firing] = mapped_column(Pydantic(Firing))


Expand All @@ -1383,7 +1380,7 @@ class AutomationEventFollower(Base):
scope: Mapped[str] = mapped_column(default="", index=True)
leader_event_id: Mapped[uuid.UUID] = mapped_column(index=True)
follower_event_id: Mapped[uuid.UUID]
received: Mapped[pendulum.DateTime] = mapped_column(index=True)
received: Mapped[DateTime] = mapped_column(index=True)
follower: Mapped[ReceivedEvent] = mapped_column(Pydantic(ReceivedEvent))


Expand All @@ -1407,7 +1404,7 @@ def __tablename__(cls) -> str:
sa.Index("ix_events__event_related_occurred", "event", "related", "occurred"),
)

occurred: Mapped[pendulum.DateTime]
occurred: Mapped[DateTime]
event: Mapped[str] = mapped_column(sa.Text())
resource_id: Mapped[str] = mapped_column(sa.Text())
resource: Mapped[dict[str, Any]] = mapped_column(JSON())
Expand All @@ -1418,8 +1415,8 @@ def __tablename__(cls) -> str:
JSON(), server_default="[]", default=list
)
payload: Mapped[dict[str, Any]] = mapped_column(JSON())
received: Mapped[pendulum.DateTime]
recorded: Mapped[pendulum.DateTime]
received: Mapped[DateTime]
recorded: Mapped[DateTime]
follows: Mapped[Optional[uuid.UUID]]


Expand All @@ -1436,7 +1433,7 @@ def __tablename__(cls) -> str:
),
)

occurred: Mapped[pendulum.DateTime]
occurred: Mapped[DateTime]
resource_id: Mapped[str] = mapped_column(sa.Text())
resource_role: Mapped[str] = mapped_column(sa.Text())
resource: Mapped[dict[str, Any]] = mapped_column(sa_JSON)
Expand Down
5 changes: 3 additions & 2 deletions src/prefect/server/services/cancellation_cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
"""

import asyncio
import datetime
from typing import Any, Optional
from uuid import UUID

import pendulum
import sqlalchemy as sa
from sqlalchemy.sql.expression import or_

Expand All @@ -18,6 +18,7 @@
from prefect.settings import PREFECT_API_SERVICES_CANCELLATION_CLEANUP_LOOP_SECONDS
from prefect.settings.context import get_current_settings
from prefect.settings.models.server.services import ServicesBaseSetting
from prefect.types._datetime import now

NON_TERMINAL_STATES = list(set(states.StateType) - states.TERMINAL_STATES)

Expand Down Expand Up @@ -64,7 +65,7 @@ async def clean_up_cancelled_flow_run_task_runs(
.where(
db.FlowRun.state_type == states.StateType.CANCELLED,
db.FlowRun.end_time.is_not(None),
db.FlowRun.end_time >= (pendulum.now("UTC").subtract(days=1)),
db.FlowRun.end_time >= (now("UTC") - datetime.timedelta(days=1)),
)
.limit(self.batch_size)
)
Expand Down
6 changes: 3 additions & 3 deletions src/prefect/server/services/foreman.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from datetime import timedelta
from typing import Any, Optional

import pendulum
import sqlalchemy as sa

from prefect.server import models
Expand All @@ -29,6 +28,7 @@
get_current_settings,
)
from prefect.settings.models.server.services import ServicesBaseSetting
from prefect.types._datetime import now


class Foreman(LoopService):
Expand Down Expand Up @@ -179,7 +179,7 @@ async def _mark_deployments_as_not_ready(self, db: PrefectDBInterface) -> None:
session (AsyncSession): The session to use for the database operation.
"""
async with db.session_context(begin_transaction=True) as session:
status_timeout_threshold = pendulum.now("UTC") - timedelta(
status_timeout_threshold = now("UTC") - timedelta(
seconds=self._deployment_last_polled_timeout_seconds
)
deployment_id_select_stmt = (
Expand Down Expand Up @@ -222,7 +222,7 @@ async def _mark_work_queues_as_not_ready(self, db: PrefectDBInterface):
session (AsyncSession): The session to use for the database operation.
"""
async with db.session_context(begin_transaction=True) as session:
status_timeout_threshold = pendulum.now("UTC") - timedelta(
status_timeout_threshold = now("UTC") - timedelta(
seconds=self._work_queue_last_polled_timeout_seconds
)
id_select_stmt = (
Expand Down
8 changes: 4 additions & 4 deletions src/prefect/server/services/late_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import datetime
from typing import TYPE_CHECKING, Any

import pendulum
import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncSession

Expand All @@ -26,6 +25,7 @@
)
from prefect.settings.context import get_current_settings
from prefect.settings.models.server.services import ServicesBaseSetting
from prefect.types._datetime import DateTime, now

if TYPE_CHECKING:
from uuid import UUID
Expand Down Expand Up @@ -67,7 +67,7 @@ async def run_once(self, db: PrefectDBInterface) -> None:
- Querying for flow runs in a scheduled state that are Scheduled to start in the past
- For any runs past the "late" threshold, setting the flow run state to a new `Late` state
"""
scheduled_to_start_before = pendulum.now("UTC").subtract(
scheduled_to_start_before = now("UTC") - datetime.timedelta(
seconds=self.mark_late_after.total_seconds()
)

Expand All @@ -93,7 +93,7 @@ async def run_once(self, db: PrefectDBInterface) -> None:
@inject_db
def _get_select_late_flow_runs_query(
self, scheduled_to_start_before: datetime.datetime, db: PrefectDBInterface
) -> sa.Select[tuple["UUID", pendulum.DateTime | None]]:
) -> sa.Select[tuple["UUID", DateTime | None]]:
"""
Returns a sqlalchemy query for late flow runs.
Expand All @@ -120,7 +120,7 @@ def _get_select_late_flow_runs_query(
async def _mark_flow_run_as_late(
self,
session: AsyncSession,
flow_run: sa.Row[tuple["UUID", pendulum.DateTime | None]],
flow_run: sa.Row[tuple["UUID", DateTime | None]],
) -> None:
"""
Mark a flow run as late.
Expand Down
4 changes: 2 additions & 2 deletions src/prefect/server/services/pause_expirations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import asyncio
from typing import Any, Optional

import pendulum
import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncSession

Expand All @@ -18,6 +17,7 @@
from prefect.settings import PREFECT_API_SERVICES_PAUSE_EXPIRATIONS_LOOP_SECONDS
from prefect.settings.context import get_current_settings
from prefect.settings.models.server.services import ServicesBaseSetting
from prefect.types._datetime import now


class FailExpiredPauses(LoopService):
Expand Down Expand Up @@ -82,7 +82,7 @@ async def _mark_flow_run_as_failed(
if (
flow_run.state is not None
and flow_run.state.state_details.pause_timeout is not None
and flow_run.state.state_details.pause_timeout < pendulum.now("UTC")
and flow_run.state.state_details.pause_timeout < now("UTC")
):
await models.flow_runs.set_flow_run_state(
session=session,
Expand Down
Loading

0 comments on commit 949e865

Please sign in to comment.