Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(agents-api): Remove auto_blob_store in favor of interceptor based system #977

Merged
merged 3 commits into from
Dec 21, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions agents-api/agents_api/activities/embed_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,11 @@
from temporalio import activity

from ..clients import cozo, litellm
from ..common.storage_handler import auto_blob_store
from ..env import testing
from ..models.docs.embed_snippets import embed_snippets as embed_snippets_query
from .types import EmbedDocsPayload


@auto_blob_store(deep=True)
@beartype
async def embed_docs(
payload: EmbedDocsPayload, cozo_client=None, max_batch_size: int = 100
Expand Down
2 changes: 0 additions & 2 deletions agents-api/agents_api/activities/excecute_api_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from temporalio import activity

from ..autogen.openapi_model import ApiCallDef
from ..common.storage_handler import auto_blob_store
from ..env import testing


Expand All @@ -20,7 +19,6 @@ class RequestArgs(TypedDict):
headers: Optional[dict[str, str]]


@auto_blob_store(deep=True)
@beartype
async def execute_api_call(
api_call: ApiCallDef,
Expand Down
2 changes: 0 additions & 2 deletions agents-api/agents_api/activities/execute_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,10 @@
from ..clients import integrations
from ..common.exceptions.tools import IntegrationExecutionException
from ..common.protocol.tasks import ExecutionInput, StepContext
from ..common.storage_handler import auto_blob_store
from ..env import testing
from ..models.tools import get_tool_args_from_metadata


@auto_blob_store(deep=True)
@beartype
async def execute_integration(
context: StepContext,
Expand Down
7 changes: 1 addition & 6 deletions agents-api/agents_api/activities/execute_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,14 @@
VectorDocSearchRequest,
)
from ..common.protocol.tasks import ExecutionInput, StepContext
from ..common.storage_handler import auto_blob_store, load_from_blob_store_if_remote
from ..env import testing
from ..queries.developer import get_developer
from ..queries.developers import get_developer
from .utils import get_handler

# For running synchronous code in the background
process_pool_executor = ProcessPoolExecutor()


@auto_blob_store(deep=True)
@beartype
async def execute_system(
context: StepContext,
Expand All @@ -37,9 +35,6 @@ async def execute_system(
"""Execute a system call with the appropriate handler and transformed arguments."""
arguments: dict[str, Any] = system.arguments or {}

if set(arguments.keys()) == {"bucket", "key"}:
arguments = await load_from_blob_store_if_remote(arguments)

if not isinstance(context.execution_input, ExecutionInput):
raise TypeError("Expected ExecutionInput type for context.execution_input")

Expand Down
12 changes: 4 additions & 8 deletions agents-api/agents_api/activities/sync_items_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,16 @@

@beartype
async def save_inputs_remote_fn(inputs: list[Any]) -> list[Any | RemoteObject]:
from ..common.storage_handler import store_in_blob_store_if_large
from ..common.interceptors import offload_if_large

return await asyncio.gather(
*[store_in_blob_store_if_large(input) for input in inputs]
)
return await asyncio.gather(*[offload_if_large(input) for input in inputs])


@beartype
async def load_inputs_remote_fn(inputs: list[Any | RemoteObject]) -> list[Any]:
from ..common.storage_handler import load_from_blob_store_if_remote
from ..common.interceptors import load_if_remote

return await asyncio.gather(
*[load_from_blob_store_if_remote(input) for input in inputs]
)
return await asyncio.gather(*[load_if_remote(input) for input in inputs])


save_inputs_remote = activity.defn(name="save_inputs_remote")(save_inputs_remote_fn)
Expand Down
2 changes: 0 additions & 2 deletions agents-api/agents_api/activities/task_steps/base_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from temporalio import activity # noqa: E402
from thefuzz import fuzz # noqa: E402

from ...common.storage_handler import auto_blob_store # noqa: E402
from ...env import testing # noqa: E402
from ..utils import get_evaluator # noqa: E402

Expand Down Expand Up @@ -63,7 +62,6 @@ def _recursive_evaluate(expr, evaluator: SimpleEval):
raise ValueError(f"Invalid expression: {expr}")


@auto_blob_store(deep=True)
@beartype
async def base_evaluate(
exprs: Any,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
from temporalio import activity

from ... import models
from ...common.storage_handler import auto_blob_store
from ...env import testing


@auto_blob_store(deep=True)
@beartype
async def cozo_query_step(
query_name: str,
Expand Down
2 changes: 0 additions & 2 deletions agents-api/agents_api/activities/task_steps/evaluate_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@

from ...activities.utils import simple_eval_dict
from ...common.protocol.tasks import StepContext, StepOutcome
from ...common.storage_handler import auto_blob_store
from ...env import testing


@auto_blob_store(deep=True)
@beartype
async def evaluate_step(
context: StepContext,
Expand Down
2 changes: 0 additions & 2 deletions agents-api/agents_api/activities/task_steps/for_each_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
StepContext,
StepOutcome,
)
from ...common.storage_handler import auto_blob_store
from ...env import testing
from .base_evaluate import base_evaluate


@auto_blob_store(deep=True)
@beartype
async def for_each_step(context: StepContext) -> StepOutcome:
try:
Expand Down
5 changes: 2 additions & 3 deletions agents-api/agents_api/activities/task_steps/get_value_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
from temporalio import activity

from ...common.protocol.tasks import StepContext, StepOutcome
from ...common.storage_handler import auto_blob_store
from ...env import testing


# TODO: We should use this step to query the parent workflow and get the value from the workflow context
# SCRUM-1
@auto_blob_store(deep=True)


@beartype
async def get_value_step(
context: StepContext,
Expand Down
2 changes: 0 additions & 2 deletions agents-api/agents_api/activities/task_steps/if_else_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
StepContext,
StepOutcome,
)
from ...common.storage_handler import auto_blob_store
from ...env import testing
from .base_evaluate import base_evaluate


@auto_blob_store(deep=True)
@beartype
async def if_else_step(context: StepContext) -> StepOutcome:
# NOTE: This activity is only for logging, so we just evaluate the expression
Expand Down
2 changes: 0 additions & 2 deletions agents-api/agents_api/activities/task_steps/log_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
StepContext,
StepOutcome,
)
from ...common.storage_handler import auto_blob_store
from ...common.utils.template import render_template
from ...env import testing


@auto_blob_store(deep=True)
@beartype
async def log_step(context: StepContext) -> StepOutcome:
# NOTE: This activity is only for logging, so we just evaluate the expression
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,10 @@
StepContext,
StepOutcome,
)
from ...common.storage_handler import auto_blob_store
from ...env import testing
from .base_evaluate import base_evaluate


@auto_blob_store(deep=True)
@beartype
async def map_reduce_step(context: StepContext) -> StepOutcome:
try:
Expand Down
2 changes: 0 additions & 2 deletions agents-api/agents_api/activities/task_steps/prompt_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
litellm, # We dont directly import `acompletion` so we can mock it
)
from ...common.protocol.tasks import ExecutionInput, StepContext, StepOutcome
from ...common.storage_handler import auto_blob_store
from ...common.utils.template import render_template
from ...env import debug
from .base_evaluate import base_evaluate
Expand Down Expand Up @@ -62,7 +61,6 @@ def format_tool(tool: Tool) -> dict:


@activity.defn
@auto_blob_store(deep=True)
@beartype
async def prompt_step(context: StepContext) -> StepOutcome:
# Get context data
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@

from ...autogen.openapi_model import CreateTransitionRequest
from ...common.protocol.tasks import StepContext
from ...common.storage_handler import auto_blob_store
from .transition_step import original_transition_step


@activity.defn
@auto_blob_store(deep=True)
@beartype
async def raise_complete_async(context: StepContext, output: Any) -> None:
activity_info = activity.info()
Expand Down
2 changes: 0 additions & 2 deletions agents-api/agents_api/activities/task_steps/return_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
StepContext,
StepOutcome,
)
from ...common.storage_handler import auto_blob_store
from ...env import testing
from .base_evaluate import base_evaluate


@auto_blob_store(deep=True)
@beartype
async def return_step(context: StepContext) -> StepOutcome:
try:
Expand Down
5 changes: 2 additions & 3 deletions agents-api/agents_api/activities/task_steps/set_value_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@

from ...activities.utils import simple_eval_dict
from ...common.protocol.tasks import StepContext, StepOutcome
from ...common.storage_handler import auto_blob_store
from ...env import testing


# TODO: We should use this step to signal to the parent workflow and set the value on the workflow context
# SCRUM-2
@auto_blob_store(deep=True)


@beartype
async def set_value_step(
context: StepContext,
Expand Down
2 changes: 0 additions & 2 deletions agents-api/agents_api/activities/task_steps/switch_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
StepContext,
StepOutcome,
)
from ...common.storage_handler import auto_blob_store
from ...env import testing
from ..utils import get_evaluator


@auto_blob_store(deep=True)
@beartype
async def switch_step(context: StepContext) -> StepOutcome:
try:
Expand Down
2 changes: 0 additions & 2 deletions agents-api/agents_api/activities/task_steps/tool_call_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
StepContext,
StepOutcome,
)
from ...common.storage_handler import auto_blob_store


# FIXME: This shouldn't be here.
Expand Down Expand Up @@ -47,7 +46,6 @@ def construct_tool_call(


@activity.defn
@auto_blob_store(deep=True)
@beartype
async def tool_call_step(context: StepContext) -> StepOutcome:
assert isinstance(context.current_step, ToolCallStep)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from ...autogen.openapi_model import CreateTransitionRequest, Transition
from ...clients.temporal import get_workflow_handle
from ...common.protocol.tasks import ExecutionInput, StepContext
from ...common.storage_handler import load_from_blob_store_if_remote
from ...env import (
temporal_activity_after_retry_timeout,
testing,
Expand Down Expand Up @@ -48,11 +47,6 @@ async def transition_step(
TaskExecutionWorkflow.set_last_error, LastErrorInput(last_error=None)
)

# Load output from blob store if it is a remote object
transition_info.output = await load_from_blob_store_if_remote(
transition_info.output
)

if not isinstance(context.execution_input, ExecutionInput):
raise TypeError("Expected ExecutionInput type for context.execution_input")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@

from ...autogen.openapi_model import WaitForInputStep
from ...common.protocol.tasks import StepContext, StepOutcome
from ...common.storage_handler import auto_blob_store
from ...env import testing
from .base_evaluate import base_evaluate


@auto_blob_store(deep=True)
@beartype
async def wait_for_input_step(context: StepContext) -> StepOutcome:
try:
Expand Down
2 changes: 0 additions & 2 deletions agents-api/agents_api/activities/task_steps/yield_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@

from ...autogen.openapi_model import TransitionTarget, YieldStep
from ...common.protocol.tasks import ExecutionInput, StepContext, StepOutcome
from ...common.storage_handler import auto_blob_store
from ...env import testing
from .base_evaluate import base_evaluate


@auto_blob_store(deep=True)
@beartype
async def yield_step(context: StepContext) -> StepOutcome:
try:
Expand Down
3 changes: 1 addition & 2 deletions agents-api/agents_api/autogen/openapi_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
model_validator,
)

from ..common.storage_handler import RemoteObject
from ..common.utils.datetime import utcnow
from .Agents import *
from .Chat import *
Expand Down Expand Up @@ -358,7 +357,7 @@ def validate_subworkflows(self):


class SystemDef(SystemDef):
arguments: dict[str, Any] | None | RemoteObject = None
arguments: dict[str, Any] | None = None


class CreateTransitionRequest(Transition):
Expand Down
1 change: 1 addition & 0 deletions agents-api/agents_api/clients/async_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)


@alru_cache(maxsize=1024)
async def list_buckets() -> list[str]:
session = get_session()

Expand Down
9 changes: 6 additions & 3 deletions agents-api/agents_api/clients/temporal.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from datetime import timedelta
from uuid import UUID

Expand All @@ -12,9 +13,9 @@
from temporalio.runtime import PrometheusConfig, Runtime, TelemetryConfig

from ..autogen.openapi_model import TransitionTarget
from ..common.interceptors import offload_if_large
from ..common.protocol.tasks import ExecutionInput
from ..common.retry_policies import DEFAULT_RETRY_POLICY
from ..common.storage_handler import store_in_blob_store_if_large
from ..env import (
temporal_client_cert,
temporal_metrics_bind_host,
Expand Down Expand Up @@ -96,8 +97,10 @@ async def run_task_execution_workflow(
client = client or (await get_client())
execution_id = execution_input.execution.id
execution_id_key = SearchAttributeKey.for_keyword("CustomStringField")
execution_input.arguments = await store_in_blob_store_if_large(
execution_input.arguments

old_args = execution_input.arguments
execution_input.arguments = await asyncio.gather(
*[offload_if_large(arg) for arg in old_args]
)

return await client.start_workflow(
Expand Down
Loading
Loading