Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(agents-api): Add jinja templates support #300

Merged
merged 4 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions agents-api/agents_api/common/protocol/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,7 @@ class SessionData(BaseModel):
created_at: float
model: str
default_settings: SessionSettings
render_templates: bool = False
metadata: dict = {}
user_metadata: dict = {}
agent_metadata: dict = {}
38 changes: 38 additions & 0 deletions agents-api/agents_api/common/utils/template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import arrow
from jinja2.sandbox import ImmutableSandboxedEnvironment
from jinja2schema import infer, to_json_schema
from jsonschema import validate

__all__ = [
"render_template",
]

# jinja environment
jinja_env = ImmutableSandboxedEnvironment(
autoescape=False,
trim_blocks=True,
lstrip_blocks=True,
auto_reload=False,
enable_async=True,
loader=None,
)

# Add arrow to jinja
jinja_env.globals["arrow"] = arrow


# Funcs
async def render_template(
template_string: str, variables: dict, check: bool = False
) -> str:
# Parse template
template = jinja_env.from_string(template_string)

# If check is required, get required vars from template and validate variables
if check:
schema = to_json_schema(infer(template_string))
validate(instance=variables, schema=schema)

# Render
rendered = await template.render_async(**variables)
return rendered
7 changes: 6 additions & 1 deletion agents-api/agents_api/models/session/create_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def create_session_query(
user_id: UUID | None,
situation: str | None,
metadata: dict = {},
render_templates: bool = False,
) -> tuple[str, dict]:
"""
Constructs and executes a datalog query to create a new session in the database.
Expand All @@ -28,6 +29,7 @@ def create_session_query(
- user_id (UUID | None): The unique identifier for the user, if applicable.
- situation (str | None): The situation/context of the session.
- metadata (dict): Additional metadata for the session.
- render_templates (bool): Specifies whether to render templates.

Returns:
- pd.DataFrame: The result of the query execution.
Expand All @@ -52,18 +54,20 @@ def create_session_query(
}
} {
# Insert the new session data into the 'session' table with the specified columns.
?[session_id, developer_id, situation, metadata] <- [[
?[session_id, developer_id, situation, metadata, render_templates] <- [[
$session_id,
$developer_id,
$situation,
$metadata,
$render_templates,
]]

:insert sessions {
developer_id,
session_id,
situation,
metadata,
render_templates,
}
# Specify the data to return after the query execution, typically the newly created session's ID.
:returning
Expand All @@ -79,5 +83,6 @@ def create_session_query(
"developer_id": str(developer_id),
"situation": situation,
"metadata": metadata,
"render_templates": render_templates,
},
)
2 changes: 2 additions & 0 deletions agents-api/agents_api/models/session/get_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def get_session_query(
updated_at,
created_at,
metadata,
render_templates,
] := input[developer_id, id],
*sessions{
developer_id,
Expand All @@ -49,6 +50,7 @@ def get_session_query(
created_at,
updated_at: validity,
metadata,
render_templates,
@ "NOW"
},
*session_lookup{
Expand Down
4 changes: 4 additions & 0 deletions agents-api/agents_api/models/session/patch_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
]


# TODO: Add support for updating `render_templates` field


@cozo_query
def patch_session_query(
session_id: UUID,
Expand All @@ -41,6 +44,7 @@ def patch_session_query(
},
session_id = to_uuid($session_id),
developer_id = to_uuid($developer_id),

# Assertion to ensure the session exists before updating.
:assert some
"""
Expand Down
14 changes: 8 additions & 6 deletions agents-api/agents_api/models/session/session_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ def session_data_query(
agent_about,
model,
default_settings,
session_metadata,
users_metadata,
agents_metadata,
metadata,
render_templates,
user_metadata,
agent_metadata,
] := input[developer_id, session_id],
*sessions{
developer_id,
Expand All @@ -53,7 +54,8 @@ def session_data_query(
summary,
created_at,
updated_at: validity,
metadata: session_metadata,
metadata,
render_templates,
@ "NOW"
},
*session_lookup{
Expand All @@ -65,14 +67,14 @@ def session_data_query(
user_id,
name: user_name,
about: user_about,
metadata: users_metadata,
metadata: user_metadata,
},
*agents{
agent_id,
name: agent_name,
about: agent_about,
model,
metadata: agents_metadata,
metadata: agent_metadata,
},
*agent_default_settings {
agent_id,
Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/models/session/update_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
"developer_id",
]

# TODO: Add support for updating `render_templates` field


@cozo_query
def update_session_query(
Expand Down
1 change: 1 addition & 0 deletions agents-api/agents_api/routers/sessions/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ async def create_session(
user_id=request.user_id,
situation=request.situation,
metadata=request.metadata or {},
render_templates=request.render_templates or False,
)

return ResourceCreatedResponse(
Expand Down
67 changes: 52 additions & 15 deletions agents-api/agents_api/routers/sessions/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,26 @@
from openai.types.chat.chat_completion import ChatCompletion
from pydantic import UUID4

from agents_api.clients.embed import embed
from agents_api.env import summarization_tokens_threshold
from agents_api.clients.temporal import run_summarization_task
from agents_api.models.entry.add_entries import add_entries_query
from agents_api.common.protocol.entries import Entry
from agents_api.common.exceptions.sessions import SessionNotFoundError
from agents_api.clients.worker.types import ChatML
from agents_api.models.session.session_data import get_session_data
from agents_api.models.entry.proc_mem_context import proc_mem_context_query
from agents_api.autogen.openapi_model import InputChatMLMessage, Tool
from agents_api.model_registry import (
from ...autogen.openapi_model import InputChatMLMessage, Tool
from ...clients.embed import embed
from ...clients.temporal import run_summarization_task
from ...clients.worker.types import ChatML
from ...common.exceptions.sessions import SessionNotFoundError
from ...common.protocol.entries import Entry
from ...common.protocol.sessions import SessionData
from ...common.utils.template import render_template
from ...env import summarization_tokens_threshold
from ...model_registry import (
get_extra_settings,
get_model_client,
load_context,
)
from ...common.protocol.sessions import SessionData
from .protocol import Settings
from ...models.entry.add_entries import add_entries_query
from ...models.entry.proc_mem_context import proc_mem_context_query
from ...models.session.session_data import get_session_data

from .exceptions import InputTooBigError
from .protocol import Settings


THOUGHTS_STRIP_LEN = 2
Expand Down Expand Up @@ -118,18 +120,22 @@ async def run(
self, new_input, settings: Settings
) -> tuple[ChatCompletion, Entry, Callable | None]:
# TODO: implement locking at some point

# Get session data
session_data = get_session_data(self.developer_id, self.session_id)
if session_data is None:
raise SessionNotFoundError(self.developer_id, self.session_id)

# Assemble context
init_context, final_settings = await self.forward(
session_data, new_input, settings
)

# Generate response
response = await self.generate(
self.truncate(init_context, summarization_tokens_threshold), final_settings
)

# Save response to session
# if final_settings.get("remember"):
# await self.add_to_session(new_input, response)
Expand Down Expand Up @@ -195,10 +201,11 @@ async def forward(
)

entries: list[Entry] = []
instructions = "IMPORTANT INSTRUCTIONS:\n\n"
instructions = "Instructions:\n\n"
first_instruction_idx = -1
first_instruction_created_at = 0
tools = []

for idx, row in proc_mem_context_query(
session_id=self.session_id,
tool_query_embedding=tool_query_embedding,
Expand All @@ -224,7 +231,7 @@ async def forward(
first_instruction_idx = idx
first_instruction_created_at = row["created_at"]

instructions += f"- {row['content']}\n"
instructions += f"{row['content']}\n\n"

continue

Expand Down Expand Up @@ -266,6 +273,36 @@ async def forward(
if e.content
]

# If render_templates=True, render the templates
if session_data is not None and session_data.render_templates:

template_data = {
"session": {
"id": session_data.session_id,
"situation": session_data.situation,
"metadata": session_data.metadata,
},
"user": {
"id": session_data.user_id,
"name": session_data.user_name,
"about": session_data.user_about,
"metadata": session_data.user_metadata,
},
"agent": {
"id": session_data.agent_id,
"name": session_data.agent_name,
"about": session_data.agent_about,
"metadata": session_data.agent_metadata,
},
}

for i, msg in enumerate(messages):
# Only render templates for system/assistant messages
if msg.role not in ["system", "assistant"]:
continue

messages[i].content = await render_template(msg.content, template_data)

# FIXME: This sometimes returns "The model `` does not exist."
if session_data is not None:
settings.model = session_data.model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,13 +206,7 @@ def run(client, *queries):
query = joiner.join(queries)
query = f"{{\n{query}\n}}"

try:
client.run(query)
except Exception as error:
print(error)
import pdb

pdb.set_trace()
client.run(query)


def up(client):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# /usr/bin/env python3

MIGRATION_ID = "session_render_templates"
CREATED_AT = 1714119679.493182

extend_sessions = {
"up": """
?[render_templates, developer_id, session_id, updated_at, situation, summary, created_at, developer_id, metadata] := *sessions{
session_id,
updated_at,
situation,
summary,
created_at,
developer_id
},
metadata = {},
render_templates = false

:replace sessions {
developer_id: Uuid,
session_id: Uuid,
updated_at: Validity default [floor(now()), true],
=>
situation: String,
summary: String? default null,
created_at: Float default now(),
metadata: Json default {},
render_templates: Bool default false,
}
""",
"down": """
?[developer_id, session_id, updated_at, situation, summary, created_at, developer_id, metadata] := *sessions{
session_id,
updated_at,
situation,
summary,
created_at,
developer_id
}, metadata = {}

:replace sessions {
developer_id: Uuid,
session_id: Uuid,
updated_at: Validity default [floor(now()), true],
=>
situation: String,
summary: String? default null,
created_at: Float default now(),
metadata: Json default {},
}
""",
}


queries_to_run = [
extend_sessions,
]


def up(client):
for q in queries_to_run:
client.run(q["up"])


def down(client):
for q in reversed(queries_to_run):
client.run(q["down"])
Loading
Loading