chore: misc fixes and refactors
Vedantsahai18 committed Jan 3, 2025
1 parent a38dca1 commit 5608d8e
Showing 16 changed files with 235 additions and 207 deletions.
2 changes: 1 addition & 1 deletion agents-api/agents_api/activities/utils.py
@@ -397,8 +397,8 @@ def get_handler(system: SystemDef) -> Callable:
from ..queries.agents.update_agent import update_agent as update_agent_query
from ..queries.docs.delete_doc import delete_doc as delete_doc_query
from ..queries.docs.list_docs import list_docs as list_docs_query
from ..queries.entries.get_history import get_history as get_history_query
from ..queries.sessions.create_session import create_session as create_session_query
from ..queries.sessions.delete_session import delete_session as delete_session_query
from ..queries.sessions.get_session import get_session as get_session_query
from ..queries.sessions.list_sessions import list_sessions as list_sessions_query
from ..queries.sessions.update_session import update_session as update_session_query
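The hunk header shows these imports sit inside `get_handler` rather than at module top level, the usual way to dodge circular imports between the activities and queries packages. A runnable sketch of the pattern with stand-in modules (the real function imports from `agents_api.queries.*` and dispatches on a `SystemDef`):

```python
# Sketch: function-local imports defer module loading until call time,
# which breaks import cycles at the cost of a lookup per call.
def get_handler(resource: str, operation: str):
    # Imported lazily; at module-import time the sibling package may not
    # be fully initialized yet (or may import back into this module).
    import csv
    import json

    handlers = {
        ("doc", "parse"): json.loads,
        ("doc", "read"): csv.reader,
    }
    return handlers[(resource, operation)]


print(get_handler("doc", "parse")('{"ok": true}'))  # {'ok': True}
```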
8 changes: 4 additions & 4 deletions agents-api/agents_api/app.py
@@ -26,7 +26,7 @@ async def lifespan(*containers: list[FastAPI | ObjectWithState]):
pg_dsn = os.environ.get("PG_DSN")

for container in containers:
-if not getattr(container.state, "postgres_pool", None):
+if hasattr(container, "state") and not getattr(container.state, "postgres_pool", None):
container.state.postgres_pool = await create_db_pool(pg_dsn)

# INIT S3 #
@@ -35,7 +35,7 @@ async def lifespan(*containers: list[FastAPI | ObjectWithState]):
s3_endpoint = os.environ.get("S3_ENDPOINT")

for container in containers:
-if not getattr(container.state, "s3_client", None):
+if hasattr(container, "state") and not getattr(container.state, "s3_client", None):
session = get_session()
container.state.s3_client = await session.create_client(
"s3",
@@ -49,13 +49,13 @@ async def lifespan(*containers: list[FastAPI | ObjectWithState]):
finally:
# CLOSE POSTGRES #
for container in containers:
-if getattr(container.state, "postgres_pool", None):
+if hasattr(container, "state") and getattr(container.state, "postgres_pool", None):
await container.state.postgres_pool.close()
container.state.postgres_pool = None

# CLOSE S3 #
for container in containers:
-if getattr(container.state, "s3_client", None):
+if hasattr(container, "state") and getattr(container.state, "s3_client", None):
await container.state.s3_client.close()
container.state.s3_client = None
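The recurring edit in this file adds a `hasattr(container, "state")` guard ahead of every `getattr` check. `lifespan` accepts both FastAPI apps and plain `ObjectWithState` containers, and the guard presumably protects against a container that never received a `state` attribute. A runnable sketch of the difference, with stand-in container types:

```python
from types import SimpleNamespace


class ObjectWithState:
    """Stand-in for the repo's container type; .state may be absent entirely."""


def attach_pool(container, pool) -> None:
    # hasattr() first: a container without .state would otherwise raise
    # AttributeError before the getattr() default could help.
    if hasattr(container, "state") and not getattr(container.state, "postgres_pool", None):
        container.state.postgres_pool = pool


app = SimpleNamespace(state=SimpleNamespace(postgres_pool=None))
attach_pool(app, "fake-pool")
assert app.state.postgres_pool == "fake-pool"

bare = ObjectWithState()        # no .state attribute at all
attach_pool(bare, "fake-pool")  # the guard makes this a no-op, not a crash
```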

4 changes: 1 addition & 3 deletions agents-api/agents_api/clients/litellm.py
@@ -105,10 +105,8 @@ async def aembedding(
embedding_list: list[dict[Literal["embedding"], list[float]]] = response.data

# Truncate the embedding to the specified dimensions
-embedding_list = [
+return [
item["embedding"][:dimensions]
for item in embedding_list
if len(item["embedding"]) >= dimensions
]

-return embedding_list
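The `aembedding` change returns the comprehension directly instead of assigning `embedding_list` and returning it two lines later; behavior is unchanged. A standalone sketch of what that comprehension computes (no litellm call involved):

```python
def truncate_embeddings(items: list[dict], dimensions: int) -> list[list[float]]:
    # Drop vectors shorter than `dimensions`; cut the rest down to exactly
    # `dimensions` entries.
    return [
        item["embedding"][:dimensions]
        for item in items
        if len(item["embedding"]) >= dimensions
    ]


data = [{"embedding": [0.1, 0.2, 0.3]}, {"embedding": [0.4]}]
print(truncate_embeddings(data, 2))  # [[0.1, 0.2]] -- the short vector is dropped
```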
164 changes: 84 additions & 80 deletions agents-api/agents_api/queries/chat/gather_messages.py
@@ -1,6 +1,7 @@
from typing import TypeVar
from uuid import UUID

+import numpy as np
from beartype import beartype
from fastapi import HTTPException
from pydantic import ValidationError
@@ -10,14 +11,13 @@
from ...common.protocol.developers import Developer
from ...common.protocol.sessions import ChatContext
from ...common.utils.db_exceptions import common_db_exceptions, partialclass
+from ..docs.mmr import maximal_marginal_relevance
from ..docs.search_docs_by_embedding import search_docs_by_embedding
from ..docs.search_docs_by_text import search_docs_by_text
from ..docs.search_docs_hybrid import search_docs_hybrid
from ..entries.get_history import get_history
from ..sessions.get_session import get_session
from ..utils import rewrap_exceptions
-from ..docs.mmr import maximal_marginal_relevance
-import numpy as np

T = TypeVar("T")

@@ -77,86 +77,90 @@ async def gather_messages(
)
recall_options = session.recall_options

# search the last `search_threshold` messages
search_messages = [
msg
for msg in (past_messages + new_raw_messages)[-(recall_options.num_search_messages) :]
if isinstance(msg["content"], str) and msg["role"] in ["user", "assistant"]
]

if len(search_messages) == 0:
return past_messages, []
# Ensure recall_options is not None and has the necessary attributes
if recall and recall_options:
# search the last `search_threshold` messages
search_messages = [
msg
for msg in (past_messages + new_raw_messages)[
-(recall_options.num_search_messages) :
]
if isinstance(msg["content"], str) and msg["role"] in ["user", "assistant"]
]

# Search matching docs
embed_text = "\n\n".join([
f"{msg.get('name') or msg['role']}: {msg['content']}" for msg in search_messages
]).strip()

# Don't embed if search mode is text only
if recall_options.mode != "text":
[query_embedding, *_] = await litellm.aembedding(
# Truncate on the left to keep the last `search_query_chars` characters
inputs=embed_text[-(recall_options.max_query_length) :],
# TODO: Make this configurable once it's added to the ChatInput model
embed_instruction="Represent the query for retrieving supporting documents: ",
)

# Truncate on the right to take only the first `search_query_chars` characters
query_text = search_messages[-1]["content"].strip()[: recall_options.max_query_length]

# List all the applicable owners to search docs from
active_agent_id = chat_context.get_active_agent().id
user_ids = [user.id for user in chat_context.users]
owners = [("user", user_id) for user_id in user_ids] + [("agent", active_agent_id)]

# Search for doc references
doc_references: list[DocReference] = []
match recall_options.mode:
case "vector":
doc_references: list[DocReference] = await search_docs_by_embedding(
developer_id=developer.id,
owners=owners,
query_embedding=query_embedding,
connection_pool=connection_pool,
)
case "hybrid":
doc_references: list[DocReference] = await search_docs_hybrid(
developer_id=developer.id,
owners=owners,
text_query=query_text,
embedding=query_embedding,
connection_pool=connection_pool,
)
case "text":
doc_references: list[DocReference] = await search_docs_by_text(
developer_id=developer.id,
owners=owners,
query=query_text,
connection_pool=connection_pool,
if len(search_messages) == 0:
return past_messages, []

# Search matching docs
embed_text = "\n\n".join([
f"{msg.get('name') or msg['role']}: {msg['content']}" for msg in search_messages
]).strip()

# Don't embed if search mode is text only
if recall_options.mode != "text":
[query_embedding, *_] = await litellm.aembedding(
# Truncate on the left to keep the last `search_query_chars` characters
inputs=embed_text[-(recall_options.max_query_length) :],
# TODO: Make this configurable once it's added to the ChatInput model
embed_instruction="Represent the query for retrieving supporting documents: ",
)

# Apply MMR if enabled
if (
recall_options.mmr_strength > 0
and len(doc_references) > recall_options.limit
and recall_options.mode != "text"
and len([doc for doc in doc_references if doc.snippet.embedding is not None]) >= 2
):
# FIXME: This is a temporary fix to ensure that the MMR algorithm works.
# We shouldn't be having references without embeddings.
doc_references = [
doc for doc in doc_references if doc.snippet.embedding is not None
]
# Truncate on the right to take only the first `search_query_chars` characters
query_text = search_messages[-1]["content"].strip()[: recall_options.max_query_length]

# List all the applicable owners to search docs from
active_agent_id = chat_context.get_active_agent().id
user_ids = [user.id for user in chat_context.users]
owners = [("user", user_id) for user_id in user_ids] + [("agent", active_agent_id)]

# Search for doc references
doc_references: list[DocReference] = []
match recall_options.mode:
case "vector":
doc_references = await search_docs_by_embedding(
developer_id=developer.id,
owners=owners,
query_embedding=query_embedding,
connection_pool=connection_pool,
)
case "hybrid":
doc_references = await search_docs_hybrid(
developer_id=developer.id,
owners=owners,
text_query=query_text,
embedding=query_embedding,
connection_pool=connection_pool,
)
case "text":
doc_references = await search_docs_by_text(
developer_id=developer.id,
owners=owners,
query=query_text,
connection_pool=connection_pool,
)

# Apply MMR if enabled
if (
recall_options.mmr_strength > 0
and len(doc_references) > recall_options.limit
and recall_options.mode != "text"
and len([doc for doc in doc_references if doc.snippet.embedding is not None]) >= 2
):
# FIXME: This is a temporary fix to ensure that the MMR algorithm works.
# We shouldn't be having references without embeddings.
doc_references = [
doc for doc in doc_references if doc.snippet.embedding is not None
]

# Apply MMR
indices = maximal_marginal_relevance(
np.asarray(query_embedding),
[doc.snippet.embedding for doc in doc_references],
k=recall_options.limit,
)
doc_references = [doc for i, doc in enumerate(doc_references) if i in set(indices)]

# Apply MMR
indices = maximal_marginal_relevance(
np.asarray(query_embedding),
[doc.snippet.embedding for doc in doc_references],
k=recall_options.limit,
)
# Apply MMR
doc_references = [
doc for i, doc in enumerate(doc_references) if i in set(indices)
]
return past_messages, doc_references

return past_messages, doc_references
# If recall is False or recall_options is None, return past messages with no doc references
return past_messages, []
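Most of this diff indents the recall path under `if recall and recall_options:` so that calls with recall disabled, or sessions without recall options, short-circuit to `return past_messages, []`. The retained MMR step re-ranks search hits to be relevant to the query yet not redundant with each other, scoring each candidate as λ·sim(query, doc) − (1 − λ)·max sim(doc, already picked). Below is a generic sketch of maximal marginal relevance; this is the textbook algorithm with a signature inferred from the call site, not necessarily the repo's `..docs.mmr` implementation:

```python
import numpy as np


def maximal_marginal_relevance(query_embedding, embeddings, k, lambda_mult=0.5):
    """Return indices of k embeddings balancing relevance against diversity."""

    def cosine(a, b):
        return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))

    query = np.asarray(query_embedding, dtype=float)
    vectors = [np.asarray(e, dtype=float) for e in embeddings]
    candidates = list(range(len(vectors)))
    selected: list[int] = []

    while candidates and len(selected) < k:

        def mmr_score(i):
            relevance = cosine(query, vectors[i])
            # Penalize similarity to anything already picked.
            redundancy = max(
                (cosine(vectors[i], vectors[j]) for j in selected), default=0.0
            )
            return lambda_mult * relevance - (1 - lambda_mult) * redundancy

        selected.append(max(candidates, key=mmr_score))
        candidates.remove(selected[-1])

    return selected


# The near-duplicate of the first pick is skipped in favor of the diverse vector.
docs = [[1.0, 0.0], [0.99, 0.01], [0.2, 1.0]]
print(maximal_marginal_relevance([1.0, 0.2], docs, k=2))  # [1, 2]
```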
6 changes: 0 additions & 6 deletions agents-api/scripts/agents_api.py

This file was deleted.

2 changes: 1 addition & 1 deletion agents-api/tests/test_chat_routes.py
@@ -97,7 +97,7 @@ async def _(
connection_pool=pool,
)

-(embed, _) = mocks
+(_embed, _) = mocks

chat_context = await prepare_chat_context(
developer_id=developer_id,
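The one-line test change renames an unused unpacked mock from `embed` to `_embed`; a leading underscore is the conventional signal, to linters and readers alike, that a binding is intentionally unused:

```python
mocks = ("embed-mock", "acompletion-mock")  # illustrative stand-ins
(_embed, _) = mocks  # underscore prefix: deliberately unused, lint-clean
```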
14 changes: 11 additions & 3 deletions cookbooks/01-website-crawler.py
@@ -1,5 +1,6 @@
import os
import uuid

import yaml
from julep import Client

@@ -10,7 +11,8 @@
# Creating Julep Client with the API Key
api_key = os.getenv("JULEP_API_KEY")
if not api_key:
raise ValueError("JULEP_API_KEY not found in environment variables")
msg = "JULEP_API_KEY not found in environment variables"
raise ValueError(msg)

client = Client(api_key=api_key, environment="dev")

@@ -26,6 +28,11 @@
model="gpt-4o",
)

+spider_api_key = os.getenv("SPIDER_API_KEY")
+if not spider_api_key:
+    msg = "SPIDER_API_KEY not found in environment variables"
+    raise ValueError(msg)

# Defining a Task
task_def = yaml.safe_load(f"""
name: Crawling Task
@@ -63,7 +70,7 @@
page['content'] for page in _['result']
)
)
# Prompt step to create a summary of the results
- prompt: |
You are {{{{agent.about}}}}
@@ -90,6 +97,7 @@

# Waiting for the execution to complete
import time

time.sleep(5)

# Getting the execution details
@@ -104,4 +112,4 @@

# Stream the steps of the defined task
print("Streaming execution transitions:")
-print(client.executions.transitions.stream(execution_id=execution.id))
\ No newline at end of file
+print(client.executions.transitions.stream(execution_id=execution.id))
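Several cookbook diffs in this commit replace string literals inside `raise` with a pre-bound `msg` variable. This reads like the standard fix for ruff's EM101 rule ("exception must not use a string literal"); the rule name is an inference, the commit message doesn't say. The pattern in isolation:

```python
import os

api_key = os.getenv("JULEP_API_KEY")
if not api_key:
    # Bind the message first so the traceback's raise line stays free of a
    # long string literal (EM101-style fix).
    msg = "JULEP_API_KEY not found in environment variables"
    raise ValueError(msg)
```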
8 changes: 5 additions & 3 deletions cookbooks/02-sarcastic-news-headline-generator.py
@@ -1,5 +1,6 @@
import os
import uuid

import yaml
from julep import Client

@@ -10,7 +11,8 @@
# Create Julep Client with the API Key
api_key = os.getenv("JULEP_API_KEY")
if not api_key:
raise ValueError("JULEP_API_KEY not found in environment variables")
msg = "JULEP_API_KEY not found in environment variables"
raise ValueError(msg)

client = Client(api_key=api_key, environment="dev")

@@ -76,7 +78,8 @@
)

# Waiting for the execution to complete
-import time
+import time

time.sleep(5)

# Getting the execution details
@@ -92,4 +95,3 @@
# Stream the steps of the defined task
print("Streaming execution transitions:")
print(client.executions.transitions.stream(execution_id=execution.id))

12 changes: 7 additions & 5 deletions cookbooks/03-trip-planning-assistant.py
@@ -1,7 +1,8 @@
+import os
import uuid

import yaml
from julep import Client
-import os

openweathermap_api_key = os.getenv("OPENWEATHERMAP_API_KEY")
brave_api_key = os.getenv("BRAVE_API_KEY")
@@ -139,17 +140,18 @@

# Wait for the execution to complete
import time

time.sleep(200)

-# Getting the execution details
+# Get execution details
execution = client.executions.get(execution.id)
# Print the output
print(execution.output)
print("-"*50)
print("-" * 50)

-if 'final_plan' in execution.output:
-    print(execution.output['final_plan'])
+if "final_plan" in execution.output:
+    print(execution.output["final_plan"])

# Lists all the task steps that have been executed up to this point in time
transitions = client.executions.transitions.list(execution_id=execution.id).items
@@ -158,4 +160,4 @@
for transition in reversed(transitions):
print("Transition type: ", transition.type)
print("Transition output: ", transition.output)
print("-"*50)
print("-" * 50)