Skip to content

Commit

Permalink
chore: minor refactors
Browse files Browse the repository at this point in the history
  • Loading branch information
Vedantsahai18 committed Dec 18, 2024
1 parent db31801 commit 638fefb
Show file tree
Hide file tree
Showing 12 changed files with 171 additions and 153 deletions.
10 changes: 10 additions & 0 deletions agents-api/agents_api/queries/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,13 @@
from .list_agents import list_agents
from .patch_agent import patch_agent
from .update_agent import update_agent

__all__ = [
"create_agent",
"create_or_update_agent",
"delete_agent",
"get_agent",
"list_agents",
"patch_agent",
"update_agent",
]
15 changes: 7 additions & 8 deletions agents-api/agents_api/queries/agents/create_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@
wrap_in_class,
)

ModelT = TypeVar("ModelT", bound=Any)
T = TypeVar("T")

raw_query = """
# Define the raw SQL query
agent_query = parse_one("""
INSERT INTO agents (
developer_id,
agent_id,
Expand All @@ -46,9 +44,7 @@
$9
)
RETURNING *;
"""

query = parse_one(raw_query).sql(pretty=True)
""").sql(pretty=True)


# @rewrap_exceptions(
Expand Down Expand Up @@ -135,4 +131,7 @@ async def create_agent(
default_settings,
]

return query, params
return (
agent_query,
params,
)
15 changes: 7 additions & 8 deletions agents-api/agents_api/queries/agents/create_or_update_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,8 @@
wrap_in_class,
)

ModelT = TypeVar("ModelT", bound=Any)
T = TypeVar("T")

raw_query = """
# Define the raw SQL query
agent_query = parse_one("""
INSERT INTO agents (
developer_id,
agent_id,
Expand All @@ -45,9 +43,7 @@
$9
)
RETURNING *;
"""

query = parse_one(raw_query).sql(pretty=True)
""").sql(pretty=True)


# @rewrap_exceptions(
Expand Down Expand Up @@ -110,4 +106,7 @@ async def create_or_update_agent(
default_settings,
]

return (query, params)
return (
agent_query,
params,
)
20 changes: 8 additions & 12 deletions agents-api/agents_api/queries/agents/delete_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,14 @@

from ...autogen.openapi_model import ResourceDeletedResponse
from ...common.utils.datetime import utcnow
from ...metrics.counters import increase_counter
from ..utils import (
pg_query,
rewrap_exceptions,
wrap_in_class,
)

ModelT = TypeVar("ModelT", bound=Any)
T = TypeVar("T")

raw_query = """
# Define the raw SQL query
agent_query = parse_one("""
WITH deleted_docs AS (
DELETE FROM docs
WHERE developer_id = $1
Expand All @@ -41,13 +38,10 @@
DELETE FROM agents
WHERE agent_id = $2 AND developer_id = $1
RETURNING developer_id, agent_id;
"""


# Convert the list of queries into a single query string
query = parse_one(raw_query).sql(pretty=True)
""").sql(pretty=True)


# @rewrap_exceptions(
# @rewrap_exceptions(
# {
# psycopg_errors.ForeignKeyViolation: partialclass(
Expand All @@ -63,7 +57,6 @@
one=True,
transform=lambda d: {**d, "id": d["agent_id"], "deleted_at": utcnow()},
)
@increase_counter("delete_agent")
@pg_query
@beartype
async def delete_agent(*, agent_id: UUID, developer_id: UUID) -> tuple[str, list]:
Expand All @@ -80,4 +73,7 @@ async def delete_agent(*, agent_id: UUID, developer_id: UUID) -> tuple[str, list
# Note: We swap the parameter order because the queries use $1 for developer_id and $2 for agent_id
params = [developer_id, agent_id]

return (query, params)
return (
agent_query,
params,
)
17 changes: 7 additions & 10 deletions agents-api/agents_api/queries/agents/get_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
from sqlglot import parse_one

from ...autogen.openapi_model import Agent
from ...metrics.counters import increase_counter
from ..utils import (
pg_query,
rewrap_exceptions,
wrap_in_class,
)

raw_query = """
# Define the raw SQL query
agent_query = parse_one("""
SELECT
agent_id,
developer_id,
Expand All @@ -34,12 +34,7 @@
agents
WHERE
agent_id = $2 AND developer_id = $1;
"""

query = parse_one(raw_query).sql(pretty=True)

ModelT = TypeVar("ModelT", bound=Any)
T = TypeVar("T")
""").sql(pretty=True)


# @rewrap_exceptions(
Expand All @@ -53,7 +48,6 @@
# # TODO: Add more exceptions
# )
@wrap_in_class(Agent, one=True, transform=lambda d: {"id": d["agent_id"], **d})
@increase_counter("get_agent")
@pg_query
@beartype
async def get_agent(*, agent_id: UUID, developer_id: UUID) -> tuple[str, list]:
Expand All @@ -68,4 +62,7 @@ async def get_agent(*, agent_id: UUID, developer_id: UUID) -> tuple[str, list]:
tuple[list[str], dict]: A tuple containing the SQL query and its parameters.
"""

return (query, [developer_id, agent_id])
return (
agent_query,
[developer_id, agent_id],
)
13 changes: 6 additions & 7 deletions agents-api/agents_api/queries/agents/list_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,13 @@
from fastapi import HTTPException

from ...autogen.openapi_model import Agent
from ...metrics.counters import increase_counter
from ..utils import (
pg_query,
rewrap_exceptions,
wrap_in_class,
)

ModelT = TypeVar("ModelT", bound=Any)
T = TypeVar("T")

# Define the raw SQL query
raw_query = """
SELECT
agent_id,
Expand Down Expand Up @@ -55,7 +52,6 @@
# # TODO: Add more exceptions
# )
@wrap_in_class(Agent, transform=lambda d: {"id": d["agent_id"], **d})
@increase_counter("list_agents")
@pg_query
@beartype
async def list_agents(
Expand Down Expand Up @@ -87,7 +83,7 @@ async def list_agents(

# Build metadata filter clause if needed

final_query = raw_query.format(
agent_query = raw_query.format(
metadata_filter_query="AND metadata @> $6::jsonb" if metadata_filter else ""
)

Expand All @@ -102,4 +98,7 @@ async def list_agents(
if metadata_filter:
params.append(metadata_filter)

return final_query, params
return (
agent_query,
params,
)
14 changes: 7 additions & 7 deletions agents-api/agents_api/queries/agents/patch_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@
wrap_in_class,
)

ModelT = TypeVar("ModelT", bound=Any)
T = TypeVar("T")

raw_query = """
# Define the raw SQL query
agent_query = parse_one("""
UPDATE agents
SET
name = CASE
Expand All @@ -45,9 +44,7 @@
END
WHERE agent_id = $2 AND developer_id = $1
RETURNING *;
"""

query = parse_one(raw_query).sql(pretty=True)
""").sql(pretty=True)


# @rewrap_exceptions(
Expand Down Expand Up @@ -92,4 +89,7 @@ async def patch_agent(
data.default_settings.model_dump() if data.default_settings else None,
]

return query, params
return (
agent_query,
params,
)
15 changes: 7 additions & 8 deletions agents-api/agents_api/queries/agents/update_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@
wrap_in_class,
)

ModelT = TypeVar("ModelT", bound=Any)
T = TypeVar("T")

raw_query = """
# Define the raw SQL query
agent_query = parse_one("""
UPDATE agents
SET
metadata = $3,
Expand All @@ -30,9 +28,7 @@
default_settings = $7::jsonb
WHERE agent_id = $2 AND developer_id = $1
RETURNING *;
"""

query = parse_one(raw_query).sql(pretty=True)
""").sql(pretty=True)


# @rewrap_exceptions(
Expand Down Expand Up @@ -77,4 +73,7 @@ async def update_agent(
data.default_settings.model_dump() if data.default_settings else {},
]

return (query, params)
return (
agent_query,
params,
)
72 changes: 39 additions & 33 deletions agents-api/agents_api/queries/entries/create_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ...common.utils.datetime import utcnow
from ...common.utils.messages import content_to_json
from ...metrics.counters import increase_counter
from ..utils import pg_query, rewrap_exceptions, wrap_in_class
from ..utils import pg_query, rewrap_exceptions, wrap_in_class, partialclass

# Query for checking if the session exists
session_exists_query = """
Expand Down Expand Up @@ -53,26 +53,30 @@
"""


@rewrap_exceptions(
{
asyncpg.ForeignKeyViolationError: lambda exc: HTTPException(
status_code=404,
detail=str(exc),
),
asyncpg.UniqueViolationError: lambda exc: HTTPException(
status_code=409,
detail=str(exc),
),
asyncpg.NotNullViolationError: lambda exc: HTTPException(
status_code=400,
detail=str(exc),
),
asyncpg.NoDataFoundError: lambda exc: HTTPException(
status_code=404,
detail="Session not found",
),
}
)
# @rewrap_exceptions(
# {
# asyncpg.ForeignKeyViolationError: partialclass(
# HTTPException,
# status_code=404,
# detail="Session not found",
# ),
# asyncpg.UniqueViolationError: partialclass(
# HTTPException,
# status_code=409,
# detail="Entry already exists",
# ),
# asyncpg.NotNullViolationError: partialclass(
# HTTPException,
# status_code=400,
# detail="Not null violation",
# ),
# asyncpg.NoDataFoundError: partialclass(
# HTTPException,
# status_code=404,
# detail="Session not found",
# ),
# }
# )
@wrap_in_class(
Entry,
transform=lambda d: {
Expand Down Expand Up @@ -128,18 +132,20 @@ async def create_entries(
]


@rewrap_exceptions(
{
asyncpg.ForeignKeyViolationError: lambda exc: HTTPException(
status_code=404,
detail=str(exc),
),
asyncpg.UniqueViolationError: lambda exc: HTTPException(
status_code=409,
detail=str(exc),
),
}
)
# @rewrap_exceptions(
# {
# asyncpg.ForeignKeyViolationError: partialclass(
# HTTPException,
# status_code=404,
# detail="Session not found",
# ),
# asyncpg.UniqueViolationError: partialclass(
# HTTPException,
# status_code=409,
# detail="Entry already exists",
# ),
# }
# )
@wrap_in_class(Relation)
@increase_counter("add_entry_relations")
@pg_query
Expand Down
Loading

0 comments on commit 638fefb

Please sign in to comment.