Merge branch 'dev' into d/cookbookfix
Vedantsahai18 committed Oct 18, 2024
2 parents 8e9e93f + 3669ca5 commit 676ff97
Showing 22 changed files with 400 additions and 238 deletions.
4 changes: 3 additions & 1 deletion agents-api/agents_api/activities/execute_system.py
@@ -41,14 +41,16 @@
from ..routers.docs.create_doc import create_agent_doc, create_user_doc
from ..routers.docs.search_docs import search_agent_docs, search_user_docs

# FIXME: This is a total mess. Should be refactored.


@auto_blob_store
@beartype
async def execute_system(
context: StepContext,
system: SystemDef,
) -> Any:
arguments = system.arguments
arguments: dict[str, Any] = system.arguments or {}
arguments["developer_id"] = context.execution_input.developer_id

# Unbox all the arguments
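Note: this hunk defaults `system.arguments` to an empty dict so the subsequent `developer_id` assignment cannot fail when a system call carries no arguments. A minimal, self-contained sketch of that guard (the helper name below is illustrative, not the project's actual code):

```python
from typing import Any


def build_system_arguments(raw_arguments: dict[str, Any] | None, developer_id: str) -> dict[str, Any]:
    # Fall back to an empty dict when the system call carries no arguments,
    # then inject the developer id expected by downstream handlers.
    arguments: dict[str, Any] = raw_arguments or {}
    arguments["developer_id"] = developer_id
    return arguments


assert build_system_arguments(None, "dev-123") == {"developer_id": "dev-123"}
```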
3 changes: 3 additions & 0 deletions agents-api/agents_api/common/interceptors.py
@@ -6,6 +6,7 @@

from typing import Optional, Type

from temporalio.activity import _CompleteAsyncError as CompleteAsyncError
from temporalio.exceptions import ApplicationError, FailureError, TemporalError
from temporalio.service import RPCError
from temporalio.worker import (
@@ -42,6 +43,7 @@ async def execute_activity(self, input: ExecuteActivityInput):
ReadOnlyContextError,
NondeterminismError,
RPCError,
CompleteAsyncError,
TemporalError,
FailureError,
):
@@ -73,6 +75,7 @@ async def execute_workflow(self, input: ExecuteWorkflowInput):
ReadOnlyContextError,
NondeterminismError,
RPCError,
CompleteAsyncError,
TemporalError,
FailureError,
):
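Note: both interceptor hunks add `CompleteAsyncError` to an existing tuple of Temporal exception types that get special treatment. The handler bodies are not shown in this diff, so the sketch below only illustrates the general shape of catching a fixed tuple of framework errors separately from everything else; all class names are stand-ins, not the temporalio API:

```python
class PassthroughError(Exception):
    """Stand-in for errors (e.g. CompleteAsyncError) that must propagate unchanged."""


class WrappedActivityError(Exception):
    """Stand-in for the uniform error used for everything else."""


PASSTHROUGH_EXCEPTIONS = (PassthroughError,)  # the diff grows a tuple like this


async def execute_with_interceptor(activity):
    try:
        return await activity()
    except PASSTHROUGH_EXCEPTIONS:
        # Known framework/control-flow errors: let them surface as-is.
        raise
    except Exception as exc:
        # Anything else is wrapped so failures are reported uniformly.
        raise WrappedActivityError(str(exc)) from exc
```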
2 changes: 1 addition & 1 deletion agents-api/agents_api/common/nlp.py
@@ -41,7 +41,7 @@ def extract_keywords(text: str, top_n: int = 10, clean: bool = True) -> list[str
combined = entities + nouns

# Normalize and count frequency
normalized = [re.sub(r"\s+", " ", kw).strip().lower() for kw in combined]
normalized = [re.sub(r"\s+", " ", kw).strip() for kw in combined]
freq = Counter(normalized)

# Get top_n keywords
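Note: this one-line change appears to stop lowercasing keywords before counting them, so differently-cased surface forms are counted separately and returned with their original casing. A small standard-library-only illustration of the difference:

```python
import re
from collections import Counter

combined = ["New  York", "new york", "Python"]

# Previous behaviour: whitespace-normalized and lowercased before counting.
lowered = [re.sub(r"\s+", " ", kw).strip().lower() for kw in combined]
# New behaviour: whitespace-normalized only; case is preserved.
preserved = [re.sub(r"\s+", " ", kw).strip() for kw in combined]

print(Counter(lowered))    # Counter({'new york': 2, 'python': 1})
print(Counter(preserved))  # Counter({'New York': 1, 'new york': 1, 'Python': 1})
```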
10 changes: 7 additions & 3 deletions agents-api/agents_api/common/protocol/remote.py
@@ -124,7 +124,9 @@ def __save_item(self, item: Any) -> Any:

return store_in_blob_store_if_large(item)

def __getitem__(self, index: int | slice) -> Any:
def __getitem__(
self, index: int | slice
) -> Any: # pytype: disable=signature-mismatch
if isinstance(index, slice):
# Obtain the slice without triggering __getitem__ recursively
sliced_items = super().__getitem__(
@@ -162,7 +164,9 @@ def _extend_without_processing(self, items: list[Any]) -> None:
"""
super().extend(items)

def __setitem__(self, index: int | slice, value: Any) -> None:
def __setitem__(
self, index: int | slice, value: Any
) -> None: # pytype: disable=signature-mismatch
if isinstance(index, slice):
# Handle slice assignment without processing existing RemoteObjects
processed_values = [self.__save_item(v) for v in value]
@@ -231,7 +235,7 @@ def extend(self, iterable: list[Any]) -> None:
for item in iterable:
self.append(item)

def __iter__(self) -> Iterator[Any]:
def __iter__(self) -> Iterator[Any]: # pytype: disable=signature-mismatch
for index in range(len(self)):
yield self.__getitem__(index)

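Note: these hunks only add `# pytype: disable=signature-mismatch` comments where the overridden `__getitem__`, `__setitem__`, and `__iter__` intentionally differ from `list`'s declared signatures; runtime behavior is unchanged. For context, a heavily simplified sketch of the kind of list this module implements, where large items are swapped for handles on write and loaded back on read (the `Handle` class and `store`/`load` helpers are stand-ins, not the real blob-store API):

```python
from typing import Any, Iterator

_BLOBS: dict[int, Any] = {}   # toy in-memory stand-in for the blob store
CUTOFF = 64                   # bytes; illustrative size threshold


class Handle:
    def __init__(self, key: int) -> None:
        self.key = key


def store(item: Any) -> Any:
    # Keep small items inline; move large ones to the "blob store".
    if len(repr(item).encode()) <= CUTOFF:
        return item
    key = len(_BLOBS)
    _BLOBS[key] = item
    return Handle(key)


def load(item: Any) -> Any:
    return _BLOBS[item.key] if isinstance(item, Handle) else item


class RemoteList(list):
    def append(self, item: Any) -> None:
        super().append(store(item))

    def __getitem__(self, index: int) -> Any:  # int-only, for brevity
        return load(super().__getitem__(index))

    def __iter__(self) -> Iterator[Any]:
        for i in range(len(self)):
            yield self[i]


xs = RemoteList()
xs.append("x" * 200)                  # large: stored behind a Handle
assert xs[0] == "x" * 200             # loaded back transparently on access
assert list(iter(xs)) == ["x" * 200]
```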
12 changes: 9 additions & 3 deletions agents-api/agents_api/common/storage_handler.py
@@ -15,6 +15,9 @@


def store_in_blob_store_if_large(x: Any) -> RemoteObject | Any:
if not use_blob_store_for_temporal:
return x

s3.setup()

serialized = serialize(x)
@@ -28,6 +31,9 @@ def store_in_blob_store_if_large(x: Any) -> RemoteObject | Any:


def load_from_blob_store_if_remote(x: Any | RemoteObject) -> Any:
if not use_blob_store_for_temporal:
return x

s3.setup()

if isinstance(x, RemoteObject):
@@ -45,8 +51,8 @@ def load_from_blob_store_if_remote(x: Any | RemoteObject) -> Any:
def auto_blob_store(f: Callable | None = None, *, deep: bool = False) -> Callable:
def auto_blob_store_decorator(f: Callable) -> Callable:
def load_args(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
args: list | tuple, kwargs: dict[str, Any]
) -> tuple[list | tuple, dict[str, Any]]:
new_args = [load_from_blob_store_if_remote(arg) for arg in args]
new_kwargs = {
k: load_from_blob_store_if_remote(v) for k, v in kwargs.items()
@@ -143,4 +149,4 @@ async def wrapper(*args, **kwargs) -> Any:

return result

return wrapper
return wrapper if use_blob_store_for_temporal else f
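Note: these hunks make the blob-store path opt-in: the store/load helpers return their input untouched when `use_blob_store_for_temporal` is off, and `auto_blob_store` hands back the undecorated function in that case. A minimal sketch of that feature-flag/decorator-bypass pattern (the flag name and helpers below are simplified stand-ins, not the module's real implementations):

```python
from functools import wraps
from typing import Any, Callable

USE_BLOB_STORE = False  # stand-in for use_blob_store_for_temporal


def store_if_large(x: Any) -> Any:
    if not USE_BLOB_STORE:
        return x  # flag off: skip S3 setup and serialization entirely
    raise NotImplementedError("upload large values and return a handle")


def load_if_remote(x: Any) -> Any:
    if not USE_BLOB_STORE:
        return x  # flag off: values are always passed around inline
    raise NotImplementedError("fetch the value behind a remote handle")


def auto_blob_store(f: Callable) -> Callable:
    @wraps(f)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        args = tuple(load_if_remote(a) for a in args)
        kwargs = {k: load_if_remote(v) for k, v in kwargs.items()}
        return store_if_large(f(*args, **kwargs))

    # With the flag off, return the original function so the blob-store
    # machinery adds no call overhead at all.
    return wrapper if USE_BLOB_STORE else f


@auto_blob_store
def echo(x: Any) -> Any:
    return x


assert echo("hello") == "hello"  # works with the blob store disabled
```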
6 changes: 6 additions & 0 deletions agents-api/agents_api/env.py
@@ -118,6 +118,12 @@
temporal_worker_url=temporal_worker_url,
temporal_namespace=temporal_namespace,
embedding_model_id=embedding_model_id,
use_blob_store_for_temporal=use_blob_store_for_temporal,
blob_store_bucket=blob_store_bucket,
blob_store_cutoff_kb=blob_store_cutoff_kb,
s3_endpoint=s3_endpoint,
s3_access_key=s3_access_key,
s3_secret_key=s3_secret_key,
testing=testing,
)

6 changes: 5 additions & 1 deletion agents-api/agents_api/models/docs/get_doc.py
@@ -37,7 +37,11 @@
one=True,
transform=lambda d: {
"content": [s[1] for s in sorted(d["snippet_data"], key=lambda x: x[0])],
"embeddings": [s[2] for s in sorted(d["snippet_data"], key=lambda x: x[0])],
"embeddings": [
s[2]
for s in sorted(d["snippet_data"], key=lambda x: x[0])
if s[2] is not None
],
**d,
},
)
6 changes: 5 additions & 1 deletion agents-api/agents_api/models/docs/list_docs.py
@@ -34,7 +34,11 @@
Doc,
transform=lambda d: {
"content": [s[1] for s in sorted(d["snippet_data"], key=lambda x: x[0])],
"embeddings": [s[2] for s in sorted(d["snippet_data"], key=lambda x: x[0])],
"embeddings": [
s[2]
for s in sorted(d["snippet_data"], key=lambda x: x[0])
if s[2] is not None
],
**d,
},
)
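Note: both `get_doc` and `list_docs` now drop `None` entries when collecting snippet embeddings, so documents whose snippets were stored without embeddings no longer yield lists full of `None`. A standalone sketch of the transform on a toy row (the `[index, content, embedding]` row shape is inferred from the diff):

```python
snippet_data = [
    [1, "second snippet", None],
    [0, "first snippet", [0.1, 0.2]],
]

ordered = sorted(snippet_data, key=lambda x: x[0])
content = [s[1] for s in ordered]
embeddings = [s[2] for s in ordered if s[2] is not None]

print(content)     # ['first snippet', 'second snippet']
print(embeddings)  # [[0.1, 0.2]]  -- the None entry is skipped
```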
6 changes: 4 additions & 2 deletions agents-api/agents_api/models/docs/search_docs_by_text.py
@@ -1,5 +1,6 @@
"""This module contains functions for searching documents in the CozoDB based on embedding queries."""

import re
from typing import Any, Literal, TypeVar
from uuid import UUID

@@ -62,9 +63,10 @@ def search_docs_by_text(
[owner_type, str(owner_id)] for owner_type, owner_id in owners
]

# Need to use NEAR/3($query) to search for arbitrary text within 3 words of each other
# See: https://docs.cozodb.org/en/latest/vector.html#full-text-search-fts
fts_queries = paragraph_to_custom_queries(query)
fts_queries = paragraph_to_custom_queries(query) or [
re.sub(r"[^\w\s\-_]+", "", query)
]

# Construct the datalog query for searching document snippets
search_query = f"""
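Note: when `paragraph_to_custom_queries` produces nothing for a query (for example, very short or stop-word-only input), the search now falls back to the raw query with non-word characters stripped instead of issuing an empty FTS query. A small sketch of that fallback (the stub for `paragraph_to_custom_queries` is illustrative):

```python
import re


def paragraph_to_custom_queries(query: str) -> list[str]:
    # Stand-in: pretend the real helper found no usable FTS terms.
    return []


def build_fts_queries(query: str) -> list[str]:
    # Prefer the structured queries; otherwise fall back to the sanitized
    # raw text so the datalog query still has something to match against.
    return paragraph_to_custom_queries(query) or [re.sub(r"[^\w\s\-_]+", "", query)]


print(build_fts_queries("C++?!"))  # ['C'] -- punctuation stripped, query kept
```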
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/execution/get_execution.py
@@ -33,7 +33,7 @@
transform=lambda d: {
**d,
"output": d["output"][OUTPUT_UNNEST_KEY]
if OUTPUT_UNNEST_KEY in d["output"]
if isinstance(d["output"], dict) and OUTPUT_UNNEST_KEY in d["output"]
else d["output"],
},
)
6 changes: 4 additions & 2 deletions agents-api/agents_api/models/execution/list_executions.py
@@ -33,8 +33,8 @@
transform=lambda d: {
**d,
"output": d["output"][OUTPUT_UNNEST_KEY]
if OUTPUT_UNNEST_KEY in d["output"]
else d["output"],
if isinstance(d.get("output"), dict) and OUTPUT_UNNEST_KEY in d["output"]
else d.get("output"),
},
)
@cozo_query
@@ -58,6 +58,7 @@ def list_executions(
task_id,
status,
input,
output,
session_id,
metadata,
created_at,
@@ -68,6 +69,7 @@
execution_id: id,
status,
input,
output,
session_id,
metadata,
created_at,
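Note: both `get_execution` and `list_executions` now check that `output` is actually a dict before probing it for the unnest key, and the `list_executions` query appears to also pull the `output` column it previously omitted. A compact sketch of the guarded unwrap (the key value below is a stand-in for `OUTPUT_UNNEST_KEY`):

```python
OUTPUT_UNNEST_KEY = "returns"  # illustrative; the real constant lives in the module


def unwrap_output(row: dict) -> object:
    output = row.get("output")
    # Only index into the output when it really is a dict containing the key;
    # scalars and lists are returned untouched instead of raising.
    if isinstance(output, dict) and OUTPUT_UNNEST_KEY in output:
        return output[OUTPUT_UNNEST_KEY]
    return output


print(unwrap_output({"output": {"returns": 42}}))  # 42
print(unwrap_output({"output": [1, 2, 3]}))        # [1, 2, 3]
print(unwrap_output({}))                           # None
```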
6 changes: 3 additions & 3 deletions agents-api/agents_api/models/utils.py
@@ -8,7 +8,7 @@
from pydantic import BaseModel

from ..common.utils.cozo import uuid_int_list_to_uuid4
from ..env import debug, do_verify_developer, do_verify_developer_owns_resource
from ..env import do_verify_developer, do_verify_developer_owns_resource

P = ParamSpec("P")
T = TypeVar("T")
@@ -185,8 +185,8 @@ def make_cozo_json_query(fields):

def cozo_query(
func: Callable[P, tuple[str | list[str | None], dict]] | None = None,
debug: bool | None = debug,
only_on_error: bool = True,
debug: bool | None = None,
only_on_error: bool = False,
):
def cozo_query_dec(func: Callable[P, tuple[str | list[Any], dict]]):
"""
4 changes: 3 additions & 1 deletion agents-api/agents_api/routers/sessions/chat.py
@@ -168,7 +168,9 @@ async def chat(
)

total_tokens_per_user.labels(str(developer.id)).inc(
amount=chat_response.usage.total_tokens or 0
amount=chat_response.usage.total_tokens
if chat_response.usage is not None
else 0
)

return chat_response
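Note: the counter increment now tolerates a missing `usage` object rather than only a missing `total_tokens`, avoiding an `AttributeError` when the response carries no usage block. A tiny sketch of the guard (the `ChatResponse`/`Usage` classes are stand-ins for the API models):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Usage:
    total_tokens: int


@dataclass
class ChatResponse:
    usage: Optional[Usage] = None


def tokens_to_count(response: ChatResponse) -> int:
    # `response.usage.total_tokens or 0` would raise if usage were None;
    # checking the usage object first covers both missing cases.
    return response.usage.total_tokens if response.usage is not None else 0


print(tokens_to_count(ChatResponse(usage=Usage(total_tokens=120))))  # 120
print(tokens_to_count(ChatResponse()))                               # 0
```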
4 changes: 0 additions & 4 deletions agents-api/agents_api/routers/tasks/create_or_update_task.py
@@ -29,10 +29,6 @@ async def create_or_update_task(
# TODO: Do thorough validation of the task spec
# SCRUM-10

# FIXME: There is also some subtle bug here that prevents us from
# starting executions from tasks created via this endpoint
# SCRUM-9

# Validate the input schema
try:
if data.input_schema is not None:
1 change: 0 additions & 1 deletion agents-api/agents_api/routers/tasks/create_task.py
@@ -23,7 +23,6 @@ async def create_task(
) -> ResourceCreatedResponse:
# TODO: Do thorough validation of the task spec
# SCRUM-10
# TODO: Validate the jinja templates

# Validate the input schema
try:
2 changes: 1 addition & 1 deletion agents-api/agents_api/routers/tasks/router.py
@@ -16,7 +16,7 @@ async def body(self) -> bytes:
"application/yaml",
"text/yaml",
]:
body = yaml.load(body, yaml.CSafeLoader)
body = yaml.load(body)

self._body = body

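Note: this hunk touches the `yaml.load` call used for `application/yaml` bodies; the flattened rendering does not make clear which of the two one-line variants is the new one. Purely as a neutral reference, a minimal sketch of parsing a YAML request body with an explicit safe loader (real PyYAML API; the body content is made up):

```python
import yaml

raw_body = b"name: demo-task\nsteps:\n  - kind: prompt\n"

# Explicit safe loader: parses plain data types only and refuses arbitrary
# Python object tags. CSafeLoader is the libyaml-backed equivalent.
parsed = yaml.load(raw_body, Loader=yaml.SafeLoader)

print(parsed)  # {'name': 'demo-task', 'steps': [{'kind': 'prompt'}]}
```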
(remaining changed files not loaded in this view)
0 comments on commit 676ff97