Skip to content

Commit

Permalink
Add retry policies to more activities/workflows
Browse files Browse the repository at this point in the history
  • Loading branch information
HamadaSalhab committed Oct 2, 2024
1 parent 940b6ac commit 93dfee1
Show file tree
Hide file tree
Showing 9 changed files with 22 additions and 0 deletions.
2 changes: 2 additions & 0 deletions agents-api/agents_api/routers/docs/create_doc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Annotated
from uuid import UUID, uuid4

from ...common.retry_policies import DEFAULT_RETRY_POLICY
from fastapi import BackgroundTasks, Depends
from starlette.status import HTTP_201_CREATED
from temporalio.client import Client as TemporalClient
Expand Down Expand Up @@ -41,6 +42,7 @@ async def run_embed_docs_task(
embed_payload,
task_queue=temporal_task_queue,
id=str(job_id),
retry_policy=DEFAULT_RETRY_POLICY
)

# TODO: Remove this conditional once we have a way to run workflows in
Expand Down
3 changes: 3 additions & 0 deletions agents-api/agents_api/workflows/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from temporalio import workflow

from ..common.retry_policies import DEFAULT_RETRY_POLICY

with workflow.unsafe.imports_passed_through():
from ..activities.demo import demo_activity

Expand All @@ -14,4 +16,5 @@ async def run(self, a: int, b: int) -> int:
demo_activity,
args=[a, b],
start_to_close_timeout=timedelta(seconds=30),
retry_policy=DEFAULT_RETRY_POLICY
)
2 changes: 2 additions & 0 deletions agents-api/agents_api/workflows/embed_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
with workflow.unsafe.imports_passed_through():
from ..activities.embed_docs import embed_docs
from ..activities.types import EmbedDocsPayload
from ..common.retry_policies import DEFAULT_RETRY_POLICY


@workflow.defn
Expand All @@ -18,4 +19,5 @@ async def run(self, embed_payload: EmbedDocsPayload) -> None:
embed_docs,
embed_payload,
schedule_to_close_timeout=timedelta(seconds=600),
retry_policy=DEFAULT_RETRY_POLICY
)
2 changes: 2 additions & 0 deletions agents-api/agents_api/workflows/mem_rating.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

with workflow.unsafe.imports_passed_through():
from ..activities.mem_rating import mem_rating
from ..common.retry_policies import DEFAULT_RETRY_POLICY


@workflow.defn
Expand All @@ -17,4 +18,5 @@ async def run(self, memory: str) -> None:
mem_rating,
memory,
schedule_to_close_timeout=timedelta(seconds=600),
retry_policy=DEFAULT_RETRY_POLICY
)
2 changes: 2 additions & 0 deletions agents-api/agents_api/workflows/summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

with workflow.unsafe.imports_passed_through():
from ..activities.summarization import summarization
from ..common.retry_policies import DEFAULT_RETRY_POLICY


@workflow.defn
Expand All @@ -17,4 +18,5 @@ async def run(self, session_id: str) -> None:
summarization,
session_id,
schedule_to_close_timeout=timedelta(seconds=600),
retry_policy=DEFAULT_RETRY_POLICY
)
5 changes: 5 additions & 0 deletions agents-api/agents_api/workflows/task_execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ async def run(
task_steps.raise_complete_async,
args=[context, output],
schedule_to_close_timeout=timedelta(days=31),
retry_policy=DEFAULT_RETRY_POLICY
)

state = PartialTransition(type="resume", output=result)
Expand Down Expand Up @@ -419,6 +420,7 @@ async def run(
task_steps.raise_complete_async,
args=[context, tool_calls_input],
schedule_to_close_timeout=timedelta(days=31),
retry_policy=DEFAULT_RETRY_POLICY
)

# Feed the tool call results back to the model
Expand All @@ -430,6 +432,7 @@ async def run(
schedule_to_close_timeout=timedelta(
seconds=30 if debug or testing else 600
),
retry_policy=DEFAULT_RETRY_POLICY
)
state = PartialTransition(output=new_response.output, type="resume")

Expand Down Expand Up @@ -473,6 +476,7 @@ async def run(
task_steps.raise_complete_async,
args=[context, tool_call],
schedule_to_close_timeout=timedelta(days=31),
retry_policy=DEFAULT_RETRY_POLICY
)

state = PartialTransition(output=tool_call_response, type="resume")
Expand Down Expand Up @@ -503,6 +507,7 @@ async def run(
schedule_to_close_timeout=timedelta(
seconds=30 if debug or testing else 600
),
retry_policy=DEFAULT_RETRY_POLICY
)

state = PartialTransition(output=tool_call_response)
Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/workflows/task_execution/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ async def execute_map_reduce_step(
task_steps.base_evaluate,
args=[reduce, {"results": result, "_": output}],
schedule_to_close_timeout=timedelta(seconds=30),
retry_policy=DEFAULT_RETRY_POLICY
)

return result
Expand Down Expand Up @@ -245,6 +246,7 @@ async def execute_map_reduce_step_parallel(
extra_lambda_strs,
],
schedule_to_close_timeout=timedelta(seconds=30),
retry_policy=DEFAULT_RETRY_POLICY
)

except BaseException as e:
Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/workflows/truncation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

with workflow.unsafe.imports_passed_through():
from ..activities.truncation import truncation
from ..common.retry_policies import DEFAULT_RETRY_POLICY


@workflow.defn
Expand All @@ -17,4 +18,5 @@ async def run(self, session_id: str, token_count_threshold: int) -> None:
truncation,
args=[session_id, token_count_threshold],
schedule_to_close_timeout=timedelta(seconds=600),
retry_policy=DEFAULT_RETRY_POLICY
)
2 changes: 2 additions & 0 deletions agents-api/tests/test_activities.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from agents_api.clients import temporal
from agents_api.env import temporal_task_queue
from agents_api.workflows.demo import DemoWorkflow
from agents_api.workflows.task_execution.helpers import DEFAULT_RETRY_POLICY

from .fixtures import (
cozo_client,
Expand Down Expand Up @@ -49,6 +50,7 @@ async def _():
args=[1, 2],
id=str(uuid4()),
task_queue=temporal_task_queue,
retry_policy=DEFAULT_RETRY_POLICY
)

assert result == 3
Expand Down

0 comments on commit 93dfee1

Please sign in to comment.