Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(agents-api): added mmr to chat #1013

Merged
merged 8 commits into from
Jan 4, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion agents-api/agents_api/activities/container.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
from typing import Any

from aiobotocore.client import AioBaseClient
from asyncpg.pool import Pool


class State:
pass
postgres_pool: Pool | None
s3_client: AioBaseClient | None

def __init__(self):
self.postgres_pool = None
self.s3_client = None

def __setattr__(self, name: str, value: Any) -> None:
super().__setattr__(name, value)


class Container:
Expand Down
11 changes: 8 additions & 3 deletions agents-api/agents_api/activities/execute_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ..common.exceptions.tools import IntegrationExecutionException
from ..common.protocol.tasks import ExecutionInput, StepContext
from ..env import testing
from ..queries.tools import get_tool_args_from_metadata
from ..queries import tools
Vedantsahai18 marked this conversation as resolved.
Show resolved Hide resolved
from .container import container


Expand All @@ -28,17 +28,22 @@ async def execute_integration(

developer_id = context.execution_input.developer_id
agent_id = context.execution_input.agent.id

if context.execution_input.task is None:
msg = "Task cannot be None in execution_input"
raise ValueError(msg)

task_id = context.execution_input.task.id

merged_tool_args = await get_tool_args_from_metadata(
merged_tool_args = await tools.get_tool_args_from_metadata(
developer_id=developer_id,
agent_id=agent_id,
task_id=task_id,
arg_type="args",
connection_pool=container.state.postgres_pool,
)

merged_tool_setup = await get_tool_args_from_metadata(
merged_tool_setup = await tools.get_tool_args_from_metadata(
developer_id=developer_id,
agent_id=agent_id,
task_id=task_id,
Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/execute_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
)
from ..common.protocol.tasks import ExecutionInput, StepContext
from ..env import testing
from ..queries.developers import get_developer
from ..queries import developers
from .container import container
from .utils import get_handler

Expand Down Expand Up @@ -95,7 +95,7 @@ async def execute_system(

# Handle chat operations
if system.operation == "chat" and system.resource == "session":
developer = await get_developer(
developer = await developers.get_developer(
developer_id=arguments["developer_id"],
connection_pool=container.state.postgres_pool,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ async def transition_step(
msg = "Expected ExecutionInput type for context.execution_input"
raise TypeError(msg)

if not context.execution_input.execution:
msg = "Execution is required in execution_input"
raise ValueError(msg)

# Create transition
try:
transition = await create_execution_transition(
Expand Down
10 changes: 10 additions & 0 deletions agents-api/agents_api/activities/task_steps/yield_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,20 @@ async def yield_step(context: StepContext) -> StepOutcome:
msg = "Expected ExecutionInput type for context.execution_input"
raise TypeError(msg)

# Add validation for task
if not context.execution_input.task:
msg = "Task is required in execution_input"
raise ValueError(msg)

all_workflows = context.execution_input.task.workflows
workflow = context.current_step.workflow
exprs = context.current_step.arguments

# Validate workflows exists
if not all_workflows:
msg = "No workflows found in task"
raise ValueError(msg)

assert workflow in [wf.name for wf in all_workflows], (
f"Workflow {workflow} not found in task"
)
Expand Down
104 changes: 99 additions & 5 deletions agents-api/agents_api/activities/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,21 +41,60 @@ def safe_range(*args):
return result


def safe_json_loads(s: str):
@beartype
def safe_json_loads(s: str) -> Any:
"""
Safely load a JSON string with size limits.

Args:
s: JSON string to parse

Returns:
Parsed JSON data

Raises:
ValueError: If string exceeds size limit
"""
if len(s) > MAX_STRING_LENGTH:
msg = f"String exceeds maximum length of {MAX_STRING_LENGTH}"
raise ValueError(msg)
return json.loads(s)


def safe_yaml_load(s: str):
@beartype
def safe_yaml_load(s: str) -> Any:
"""
Safely load a YAML string with size limits.

Args:
s: YAML string to parse

Returns:
Parsed YAML data

Raises:
ValueError: If string exceeds size limit
"""
if len(s) > MAX_STRING_LENGTH:
msg = f"String exceeds maximum length of {MAX_STRING_LENGTH}"
raise ValueError(msg)
return yaml.load(s)


@beartype
def safe_base64_decode(s: str) -> str:
"""
Safely decode a base64 string with size limits.

Args:
s: Base64 string to decode

Returns:
Decoded UTF-8 string

Raises:
ValueError: If string exceeds size limit or is invalid base64
"""
if len(s) > MAX_STRING_LENGTH:
msg = f"String exceeds maximum length of {MAX_STRING_LENGTH}"
raise ValueError(msg)
Expand All @@ -66,21 +105,66 @@ def safe_base64_decode(s: str) -> str:
raise ValueError(msg)


@beartype
def safe_base64_encode(s: str) -> str:
"""
Safely encode a string to base64 with size limits.

Args:
s: String to encode

Returns:
Base64 encoded string

Raises:
ValueError: If string exceeds size limit
"""
if len(s) > MAX_STRING_LENGTH:
msg = f"String exceeds maximum length of {MAX_STRING_LENGTH}"
raise ValueError(msg)
return base64.b64encode(s.encode("utf-8")).decode("utf-8")


def safe_random_choice(seq):
@beartype
def safe_random_choice(seq: list[Any] | tuple[Any, ...] | str) -> Any:
"""
Safely choose a random element from a sequence with size limits.

Args:
seq: A sequence (list, tuple, or string) to choose from

Returns:
A randomly selected element

Raises:
ValueError: If sequence exceeds size limit
TypeError: If input is not a valid sequence type
"""
if len(seq) > MAX_COLLECTION_SIZE:
msg = f"Sequence exceeds maximum size of {MAX_COLLECTION_SIZE}"
raise ValueError(msg)
return random.choice(seq)


def safe_random_sample(population, k):
@beartype
def safe_random_sample(population: list[T] | tuple[T, ...] | str, k: int) -> list[T]:
"""
Safely sample k elements from a population with size limits.

Args:
population: A sequence to sample from
k: Number of elements to sample

Returns:
A list containing k randomly selected elements

Raises:
ValueError: If population/sample size exceeds limits
TypeError: If input is not a valid sequence type
"""
if not isinstance(population, list | tuple | str):
msg = "Expected a sequence (list, tuple, or string)"
raise TypeError(msg)
if len(population) > MAX_COLLECTION_SIZE:
msg = f"Population exceeds maximum size of {MAX_COLLECTION_SIZE}"
raise ValueError(msg)
Expand All @@ -93,9 +177,19 @@ def safe_random_sample(population, k):
return random.sample(population, k)


@beartype
def chunk_doc(string: str) -> list[str]:
"""
Chunk a string into sentences.

Args:
string: The text to chunk into sentences

Returns:
A list of sentence chunks

Raises:
ValueError: If string exceeds size limit
"""
if len(string) > MAX_STRING_LENGTH:
msg = f"String exceeds maximum length of {MAX_STRING_LENGTH}"
Expand Down Expand Up @@ -397,8 +491,8 @@ def get_handler(system: SystemDef) -> Callable:
from ..queries.agents.update_agent import update_agent as update_agent_query
from ..queries.docs.delete_doc import delete_doc as delete_doc_query
from ..queries.docs.list_docs import list_docs as list_docs_query
from ..queries.entries.get_history import get_history as get_history_query
from ..queries.sessions.create_session import create_session as create_session_query
from ..queries.sessions.delete_session import delete_session as delete_session_query
from ..queries.sessions.get_session import get_session as get_session_query
from ..queries.sessions.list_sessions import list_sessions as list_sessions_query
from ..queries.sessions.update_session import update_session as update_session_query
Expand Down
8 changes: 4 additions & 4 deletions agents-api/agents_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ async def lifespan(*containers: list[FastAPI | ObjectWithState]):
pg_dsn = os.environ.get("PG_DSN")

for container in containers:
if not getattr(container.state, "postgres_pool", None):
if hasattr(container, "state") and not getattr(container.state, "postgres_pool", None):
container.state.postgres_pool = await create_db_pool(pg_dsn)

# INIT S3 #
Expand All @@ -35,7 +35,7 @@ async def lifespan(*containers: list[FastAPI | ObjectWithState]):
s3_endpoint = os.environ.get("S3_ENDPOINT")

for container in containers:
if not getattr(container.state, "s3_client", None):
if hasattr(container, "state") and not getattr(container.state, "s3_client", None):
session = get_session()
container.state.s3_client = await session.create_client(
"s3",
Expand All @@ -49,13 +49,13 @@ async def lifespan(*containers: list[FastAPI | ObjectWithState]):
finally:
# CLOSE POSTGRES #
for container in containers:
if getattr(container.state, "postgres_pool", None):
if hasattr(container, "state") and getattr(container.state, "postgres_pool", None):
await container.state.postgres_pool.close()
container.state.postgres_pool = None

# CLOSE S3 #
for container in containers:
if getattr(container.state, "s3_client", None):
if hasattr(container, "state") and getattr(container.state, "s3_client", None):
await container.state.s3_client.close()
container.state.s3_client = None

Expand Down
4 changes: 1 addition & 3 deletions agents-api/agents_api/clients/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,8 @@ async def aembedding(
embedding_list: list[dict[Literal["embedding"], list[float]]] = response.data

# Truncate the embedding to the specified dimensions
embedding_list = [
return [
item["embedding"][:dimensions]
for item in embedding_list
if len(item["embedding"]) >= dimensions
]

return embedding_list
4 changes: 4 additions & 0 deletions agents-api/agents_api/clients/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ async def run_task_execution_workflow(
):
from ..workflows.task_execution import TaskExecutionWorkflow

if execution_input.execution is None:
msg = "execution_input.execution cannot be None"
raise ValueError(msg)

start: TransitionTarget = start or TransitionTarget(workflow="main", step=0)

client = client or (await get_client())
Expand Down
Loading
Loading