Skip to content

Commit

Permalink
Implement new improved retry logic (#1282)
Browse files Browse the repository at this point in the history
* Implement improved retry logic

* Pretty format retry

* Fix test_process_runs

* Support new retry for instances

* Update docs on retry
  • Loading branch information
r4victor authored May 29, 2024
1 parent 3fb54eb commit 042f12c
Show file tree
Hide file tree
Showing 21 changed files with 392 additions and 191 deletions.
4 changes: 2 additions & 2 deletions docs/docs/reference/profiles.yml.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ The profile configuration supports many properties. See below.
max_price:
type: 'Optional[float]'
### `retry_policy`
### `retry`

#SCHEMA# dstack._internal.core.models.profiles.ProfileRetryPolicy
#SCHEMA# dstack._internal.core.models.profiles.ProfileRetry
overrides:
show_root_heading: false
19 changes: 10 additions & 9 deletions src/dstack/_internal/cli/utils/run.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import List

from rich.markup import escape
from rich.table import Table

from dstack._internal.cli.utils.common import add_row_from_dict, console
from dstack._internal.core.models.instances import InstanceAvailability
from dstack._internal.core.models.profiles import TerminationPolicy
from dstack._internal.core.models.runs import RunPlan
from dstack._internal.utils.common import pretty_date
from dstack._internal.utils.common import format_pretty_duration, pretty_date
from dstack.api import Run


Expand All @@ -23,18 +24,18 @@ def print_run_plan(run_plan: RunPlan, offers_limit: int = 3):
max_duration = (
f"{job_plan.job_spec.max_duration / 3600:g}h" if job_plan.job_spec.max_duration else "-"
)
retry_policy = job_plan.job_spec.retry_policy
retry_policy = (
(f"{retry_policy.duration / 3600:g}h" if retry_policy.duration else "yes")
if retry_policy.retry
else "no"
)
if job_plan.job_spec.retry is None:
retry = "no"
else:
retry = escape(job_plan.job_spec.retry.pretty_format())

profile = run_plan.run_spec.merged_profile
creation_policy = profile.creation_policy
termination_policy = profile.termination_policy
termination_idle_time = f"{profile.termination_idle_time}s"
if termination_policy == TerminationPolicy.DONT_DESTROY:
termination_idle_time = "-"
else:
termination_idle_time = format_pretty_duration(profile.termination_idle_time)

if req.spot is None:
spot_policy = "auto"
Expand All @@ -54,7 +55,7 @@ def th(s: str) -> str:
props.add_row(th("Max price"), max_price)
props.add_row(th("Max duration"), max_duration)
props.add_row(th("Spot policy"), spot_policy)
props.add_row(th("Retry policy"), retry_policy)
props.add_row(th("Retry policy"), retry)
props.add_row(th("Creation policy"), creation_policy)
props.add_row(th("Termination policy"), termination_policy)
props.add_row(th("Termination idle time"), termination_idle_time)
Expand Down
6 changes: 4 additions & 2 deletions src/dstack/_internal/core/backends/gcp/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
)
from dstack._internal.core.backends.base.offers import get_catalog_offers
from dstack._internal.core.backends.gcp.config import GCPConfig
from dstack._internal.core.errors import ComputeResourceNotFoundError, NoCapacityError
from dstack._internal.core.errors import (
ComputeResourceNotFoundError,
NoCapacityError,
)
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.gateways import GatewayComputeConfiguration
from dstack._internal.core.models.instances import (
Expand Down Expand Up @@ -96,7 +99,6 @@ def create_instance(
instance_config: InstanceConfiguration,
) -> JobProvisioningData:
instance_name = instance_config.instance_name

if not gcp_resources.is_valid_resource_name(instance_name):
# In a rare case the instance name is invalid in GCP,
# we better use a random instance name than fail provisioning.
Expand Down
52 changes: 43 additions & 9 deletions src/dstack/_internal/core/models/profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,43 @@ class ProfileRetryPolicy(CoreModel):

_validate_duration = validator("duration", pre=True, allow_reuse=True)(parse_duration)

@root_validator()
@classmethod
def _validate_fields(cls, field_values):
if field_values["retry"] and "duration" not in field_values:
field_values["duration"] = DEFAULT_RETRY_DURATION
if field_values.get("duration") is not None:
field_values["retry"] = True
return field_values
@root_validator
def _validate_fields(cls, values):
if values["retry"] and "duration" not in values:
values["duration"] = DEFAULT_RETRY_DURATION
if values.get("duration") is not None:
values["retry"] = True
return values


# Kinds of events a retry configuration can list in `on_events`.
# The members are `str`-valued, so they compare equal to their wire names.
# A `#` comment is used instead of a docstring so any schema generated from
# this enum is left unchanged.
class RetryEvent(str, Enum):
    # No offers/capacity available when provisioning.
    NO_CAPACITY = "no-capacity"
    # NOTE(review): presumably an instance (e.g. spot) interruption - confirm.
    INTERRUPTION = "interruption"
    # NOTE(review): presumably a failure of the run itself - confirm.
    ERROR = "error"


# User-facing retry settings: which events trigger a retry and for how long
# retrying may continue. `duration` accepts an int or a string such as `4h`
# and is normalized by `parse_duration` before field validation.
# (Kept as a `#` comment rather than a docstring so the generated schema
# description of this model does not change.)
class ProfileRetry(CoreModel):
    on_events: Annotated[
        List[RetryEvent],
        Field(
            description=(
                "The list of events that should be handled with retry."
                " Supported events are `no-capacity`, `interruption`, and `error`"
            )
        ),
    ]
    duration: Annotated[
        Optional[Union[int, str]],
        Field(description="The maximum period of retrying the run, e.g., `4h` or `1d`"),
    ] = None

    _validate_duration = validator("duration", pre=True, allow_reuse=True)(parse_duration)

    @root_validator
    def _validate_fields(cls, values):
        """Reject an explicitly empty `on_events` list.

        A post root validator is still invoked when a field has failed its own
        validation, and that field is then absent from `values`. Looking the
        key up with `.get()` lets the original field error surface as a normal
        validation error instead of being masked by a `KeyError` raised here.
        """
        on_events = values.get("on_events")
        if on_events is not None and len(on_events) == 0:
            raise ValueError("`on_events` cannot be empty")
        return values


class ProfileParams(CoreModel):
Expand All @@ -86,8 +115,13 @@ class ProfileParams(CoreModel):
description="The policy for provisioning spot or on-demand instances: `spot`, `on-demand`, or `auto`"
),
]
retry: Annotated[
Optional[Union[ProfileRetry, bool]],
Field(description="The policy for resubmitting the run. Defaults to `false`"),
]
retry_policy: Annotated[
Optional[ProfileRetryPolicy], Field(description="The policy for re-submitting the run")
Optional[ProfileRetryPolicy],
Field(description="The policy for resubmitting the run. Deprecated in favor of `retry`"),
]
max_duration: Annotated[
Optional[Union[Literal["off"], str, int]],
Expand Down
18 changes: 13 additions & 5 deletions src/dstack/_internal/core/models/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@
CreationPolicy,
Profile,
ProfileParams,
RetryEvent,
SpotPolicy,
TerminationPolicy,
)
from dstack._internal.core.models.repos import AnyRunRepoData
from dstack._internal.core.models.resources import ResourcesSpec
from dstack._internal.utils import common as common_utils
from dstack._internal.utils.common import pretty_resources
from dstack._internal.utils.common import format_pretty_duration, pretty_resources


class AppSpec(CoreModel):
Expand Down Expand Up @@ -58,9 +59,14 @@ def is_finished(self):
return self in self.finished_statuses()


class RetryPolicy(CoreModel):
retry: bool
duration: Optional[int]
# Resolved retry settings attached to a job spec: the events that trigger a
# retry and a concrete retry duration (an int, unlike the optional profile field).
class Retry(CoreModel):
    on_events: List[RetryEvent]
    duration: int

    def pretty_format(self) -> str:
        """Return a one-line summary: formatted duration followed by `[event, ...]`."""
        duration_text = format_pretty_duration(self.duration)
        event_names = [event.value for event in self.on_events]
        joined_events = ", ".join(event_names)
        return f"{duration_text}[{joined_events}]"


class RunTerminationReason(str, Enum):
Expand Down Expand Up @@ -187,7 +193,7 @@ class JobSpec(CoreModel):
max_duration: Optional[int]
registry_auth: Optional[RegistryAuth]
requirements: Requirements
retry_policy: RetryPolicy
retry: Optional[Retry]
working_dir: Optional[str]


Expand Down Expand Up @@ -225,6 +231,7 @@ class JobSubmission(CoreModel):
id: UUID4
submission_num: int
submitted_at: datetime
last_processed_at: datetime
finished_at: Optional[datetime]
status: JobStatus
termination_reason: Optional[JobTerminationReason]
Expand Down Expand Up @@ -323,6 +330,7 @@ class Run(CoreModel):
project_name: str
user: str
submitted_at: datetime
last_processed_at: datetime
status: RunStatus
termination_reason: Optional[RunTerminationReason]
run_spec: RunSpec
Expand Down
32 changes: 32 additions & 0 deletions src/dstack/_internal/core/services/profiles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import Optional

from dstack._internal.core.models.profiles import DEFAULT_RETRY_DURATION, Profile, RetryEvent
from dstack._internal.core.models.runs import Retry


def get_retry(profile: Profile) -> Optional[Retry]:
    """Resolve the effective retry settings of a profile.

    Resolution order:
      * `profile.retry` is unset: fall back to the deprecated `retry_policy`.
        Returns None when that policy is missing or has retries disabled;
        otherwise retries on all events for the policy's duration (or
        `DEFAULT_RETRY_DURATION` when the duration is falsy).
      * `profile.retry` is a bool: `True` means all events with
        `DEFAULT_RETRY_DURATION`; `False` means no retry (None).
      * `profile.retry` is a `ProfileRetry`: converted to `Retry`, with a
        missing duration replaced by `DEFAULT_RETRY_DURATION`.
    """

    def _retry_on_all_events(duration):
        # Shared by the legacy-policy and `retry: true` paths.
        return Retry(
            on_events=[RetryEvent.NO_CAPACITY, RetryEvent.INTERRUPTION, RetryEvent.ERROR],
            duration=duration,
        )

    retry_setting = profile.retry

    if retry_setting is None:
        # Handle retry_policy before retry was introduced
        # TODO: Remove once retry_policy no longer supported
        legacy_policy = profile.retry_policy
        if legacy_policy is None or not legacy_policy.retry:
            return None
        return _retry_on_all_events(legacy_policy.duration or DEFAULT_RETRY_DURATION)

    if isinstance(retry_setting, bool):
        if not retry_setting:
            return None
        return _retry_on_all_events(DEFAULT_RETRY_DURATION)

    # Work on a copy so the profile's own retry object is not mutated.
    resolved = retry_setting.copy()
    if resolved.duration is None:
        resolved.duration = DEFAULT_RETRY_DURATION
    return Retry.parse_obj(resolved)
69 changes: 41 additions & 28 deletions src/dstack/_internal/server/background/tasks/process_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,18 @@
InstanceRuntime,
RemoteConnectionInfo,
)
from dstack._internal.core.models.profiles import Profile, TerminationPolicy
from dstack._internal.core.models.runs import InstanceStatus, JobProvisioningData, Requirements
from dstack._internal.core.models.profiles import (
Profile,
RetryEvent,
TerminationPolicy,
)
from dstack._internal.core.models.runs import (
InstanceStatus,
JobProvisioningData,
Requirements,
Retry,
)
from dstack._internal.core.services.profiles import get_retry
from dstack._internal.server.db import get_session_ctx
from dstack._internal.server.models import InstanceModel, ProjectModel
from dstack._internal.server.schemas.runner import HealthcheckResponse
Expand Down Expand Up @@ -341,24 +351,6 @@ async def create_instance(instance_id: UUID) -> None:
)
).one()

if instance.retry_policy and instance.retry_policy_duration is not None:
retry_duration_deadline = _get_retry_duration_deadline(instance)
if get_current_datetime() > retry_duration_deadline:
instance.status = InstanceStatus.TERMINATED
instance.deleted = True
instance.deleted_at = get_current_datetime()
instance.termination_reason = "Retry duration expired"
await session.commit()
logger.warning(
"Retry duration expired. Terminate instance %s",
instance.name,
extra={
"instance_name": instance.name,
"instance_status": InstanceStatus.TERMINATED.value,
},
)
return

if instance.last_retry_at is not None:
last_retry = instance.last_retry_at.replace(tzinfo=datetime.timezone.utc)
if get_current_datetime() < last_retry + timedelta(minutes=1):
Expand Down Expand Up @@ -386,10 +378,10 @@ async def create_instance(instance_id: UUID) -> None:
return

try:
profile = Profile.__response__.parse_raw(instance.profile)
requirements = Requirements.__response__.parse_raw(instance.requirements)
instance_configuration = InstanceConfiguration.__response__.parse_raw(
instance.instance_configuration
profile: Profile = Profile.__response__.parse_raw(instance.profile)
requirements: Requirements = Requirements.__response__.parse_raw(instance.requirements)
instance_configuration: InstanceConfiguration = (
InstanceConfiguration.__response__.parse_raw(instance.instance_configuration)
)
except ValidationError as e:
instance.status = InstanceStatus.TERMINATED
Expand All @@ -410,14 +402,35 @@ async def create_instance(instance_id: UUID) -> None:
await session.commit()
return

retry = get_retry(profile)
should_retry = retry is not None and RetryEvent.NO_CAPACITY in retry.on_events

if retry is not None:
retry_duration_deadline = _get_retry_duration_deadline(instance, retry)
if get_current_datetime() > retry_duration_deadline:
instance.status = InstanceStatus.TERMINATED
instance.deleted = True
instance.deleted_at = get_current_datetime()
instance.termination_reason = "Retry duration expired"
await session.commit()
logger.warning(
"Retry duration expired. Terminate instance %s",
instance.name,
extra={
"instance_name": instance.name,
"instance_status": InstanceStatus.TERMINATED.value,
},
)
return

offers = await get_create_instance_offers(
project=instance.project,
profile=profile,
requirements=requirements,
exclude_not_available=True,
)

if not offers and instance.retry_policy:
if not offers and should_retry:
instance.last_retry_at = get_current_datetime()
await session.commit()
logger.debug(
Expand Down Expand Up @@ -479,7 +492,7 @@ async def create_instance(instance_id: UUID) -> None:

instance.last_retry_at = get_current_datetime()

if not instance.retry_policy:
if not should_retry:
instance.status = InstanceStatus.TERMINATED
instance.deleted = True
instance.deleted_at = get_current_datetime()
Expand Down Expand Up @@ -749,9 +762,9 @@ def _get_instance_idle_duration(instance: InstanceModel) -> datetime.timedelta:
return get_current_datetime() - last_time


def _get_retry_duration_deadline(instance: InstanceModel) -> datetime.datetime:
def _get_retry_duration_deadline(instance: InstanceModel, retry: Retry) -> datetime.datetime:
return instance.created_at.replace(tzinfo=datetime.timezone.utc) + timedelta(
seconds=instance.retry_policy_duration
seconds=retry.duration
)


Expand Down
Loading

0 comments on commit 042f12c

Please sign in to comment.