Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

X/misc fixes #1016

Merged
merged 6 commits into from
Jan 5, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions agents-api/agents_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,22 @@ class ObjectWithState(Protocol):
state: State


pool = None


# TODO: This currently doesn't use env.py, we should move to using them
@asynccontextmanager
async def lifespan(*containers: FastAPI | ObjectWithState):
# INIT POSTGRES #
pg_dsn = os.environ.get("PG_DSN")

global pool
if not pool:
pool = await create_db_pool(pg_dsn)

for container in containers:
if hasattr(container, "state") and not getattr(container.state, "postgres_pool", None):
container.state.postgres_pool = await create_db_pool(pg_dsn)
container.state.postgres_pool = pool

# INIT S3 #
s3_access_key = os.environ.get("S3_ACCESS_KEY")
Expand All @@ -50,13 +57,13 @@ async def lifespan(*containers: FastAPI | ObjectWithState):
try:
yield
finally:
# CLOSE POSTGRES #
for container in containers:
if hasattr(container, "state") and getattr(container.state, "postgres_pool", None):
pool = getattr(container.state, "postgres_pool", None)
if pool:
await pool.close()
container.state.postgres_pool = None
# # CLOSE POSTGRES #
Vedantsahai18 marked this conversation as resolved.
Show resolved Hide resolved
# for container in containers:
# if hasattr(container, "state") and getattr(container.state, "postgres_pool", None):
# pool = getattr(container.state, "postgres_pool", None)
# if pool:
# await pool.close()
# container.state.postgres_pool = None

# CLOSE S3 #
for container in containers:
Expand Down
14 changes: 7 additions & 7 deletions agents-api/agents_api/queries/docs/search_docs_by_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
search_docs_by_embedding_query = """
SELECT * FROM search_by_vector(
$1, -- developer_id
$2::vector(1024), -- query_embedding
$2::vector(1024), -- embedding
$3::text[], -- owner_types
$4::uuid[], -- owner_ids
$5, -- k
Expand All @@ -33,7 +33,7 @@
async def search_docs_by_embedding(
*,
developer_id: UUID,
query_embedding: list[float],
embedding: list[float],
k: int = 10,
owners: list[tuple[Literal["user", "agent"], UUID]],
confidence: float = 0.5,
Expand All @@ -44,7 +44,7 @@ async def search_docs_by_embedding(

Parameters:
developer_id (UUID): The ID of the developer.
query_embedding (List[float]): The vector to query.
embedding (List[float]): The vector to query.
k (int): The number of results to return.
owners (list[tuple[Literal["user", "agent"], UUID]]): List of (owner_type, owner_id) tuples.
confidence (float): The confidence threshold for the search.
Expand All @@ -56,11 +56,11 @@ async def search_docs_by_embedding(
if k < 1:
raise HTTPException(status_code=400, detail="k must be >= 1")

if not query_embedding:
if not embedding:
raise HTTPException(status_code=400, detail="Empty embedding provided")

# Convert query_embedding to a string
query_embedding_str = f"[{', '.join(map(str, query_embedding))}]"
# Convert embedding to a string
embedding_str = f"[{', '.join(map(str, embedding))}]"

# Extract owner types and IDs
owner_types: list[str] = [owner[0] for owner in owners]
Expand All @@ -70,7 +70,7 @@ async def search_docs_by_embedding(
search_docs_by_embedding_query,
[
developer_id,
query_embedding_str,
embedding_str,
owner_types,
owner_ids,
k,
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/queries/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ async def wrapper(
pool = (
connection_pool
if connection_pool is not None
else cast(asyncpg.Pool, app.state.postgres_pool)
else cast(asyncpg.Pool, getattr(app.state, "postgres_pool", None))
)

try:
Expand Down
12 changes: 6 additions & 6 deletions agents-api/agents_api/routers/docs/search_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,22 @@ def get_search_fn_and_params(
}

case VectorDocSearchRequest(
vector=query_embedding,
vector=embedding,
limit=k,
confidence=confidence,
metadata_filter=metadata_filter,
):
search_fn = search_docs_by_embedding
params = {
"query_embedding": query_embedding,
"embedding": embedding,
"k": k * 3 if search_params.mmr_strength > 0 else k,
"confidence": confidence,
"metadata_filter": metadata_filter,
}

case HybridDocSearchRequest(
text=query,
vector=query_embedding,
vector=embedding,
lang=lang,
limit=k,
confidence=confidence,
Expand All @@ -66,7 +66,7 @@ def get_search_fn_and_params(
search_fn = search_docs_hybrid
params = {
"text_query": query,
"embedding": query_embedding,
"embedding": embedding,
"k": k * 3 if search_params.mmr_strength > 0 else k,
"confidence": confidence,
"alpha": alpha,
Expand Down Expand Up @@ -111,7 +111,7 @@ async def search_user_docs(
and len(docs) > search_params.limit
):
indices = maximal_marginal_relevance(
np.asarray(params["query_embedding"]),
np.asarray(params["embedding"]),
[doc.snippet.embedding for doc in docs],
k=search_params.limit,
)
Expand Down Expand Up @@ -160,7 +160,7 @@ async def search_agent_docs(
and len(docs) > search_params.limit
):
indices = maximal_marginal_relevance(
np.asarray(params["query_embedding"]),
np.asarray(params["embedding"]),
[doc.snippet.embedding for doc in docs],
k=search_params.limit,
)
Expand Down
2 changes: 1 addition & 1 deletion memory-store/migrations/000004_agents.up.sql
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ CREATE TABLE IF NOT EXISTS agents (
),
about TEXT CONSTRAINT ct_agents_about_length CHECK (
about IS NULL
OR length(about) <= 1000
OR length(about) <= 5000
),
instructions TEXT[] DEFAULT ARRAY[]::TEXT[],
model TEXT NOT NULL,
Expand Down
2 changes: 1 addition & 1 deletion memory-store/migrations/000015_entries.up.sql
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ CREATE TABLE IF NOT EXISTS entries (
token_count INTEGER DEFAULT NULL,
tokenizer TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
timestamp DOUBLE PRECISION NOT NULL,
timestamp TIMESTAMPTZ NOT NULL,
CONSTRAINT pk_entries PRIMARY KEY (session_id, entry_id, created_at),
CONSTRAINT ct_content_is_array_of_objects CHECK (all_jsonb_elements_are_objects (content))
);
Expand Down
Loading