From 2850cf43ce180ab85031e523a9f00ee633c73917 Mon Sep 17 00:00:00 2001 From: Diwank Singh Tomer Date: Fri, 26 Apr 2024 22:51:49 +0530 Subject: [PATCH] feat(agents-api): Add jinja templates support (#300) * feat(agents-api): Add 'render_templates' field to sessions relation Signed-off-by: Diwank Singh Tomer * feat(agents-api): Add jinja env Signed-off-by: Diwank Singh Tomer * feat(agents-api): Add jinja templates support Signed-off-by: Diwank Singh Tomer --------- Signed-off-by: Diwank Singh Tomer --- .../agents_api/common/protocol/sessions.py | 4 + .../agents_api/common/utils/template.py | 38 ++++ .../models/session/create_session.py | 7 +- .../agents_api/models/session/get_session.py | 2 + .../models/session/patch_session.py | 4 + .../agents_api/models/session/session_data.py | 14 +- .../models/session/update_session.py | 2 + .../agents_api/routers/sessions/routers.py | 1 + .../agents_api/routers/sessions/session.py | 67 ++++-- ...grate_1707537826_rename_additional_info.py | 8 +- ...ate_1714119679_session_render_templates.py | 67 ++++++ sdks/python/julep/managers/session.py | 5 + sdks/python/tests/fixtures.py | 21 ++ sdks/python/tests/test_sessions.py | 21 ++ sdks/ts/src/managers/session.ts | 192 ++++++++---------- sdks/ts/tests/sessions.test.ts | 67 ++++-- 16 files changed, 364 insertions(+), 156 deletions(-) create mode 100644 agents-api/agents_api/common/utils/template.py create mode 100644 agents-api/migrations/migrate_1714119679_session_render_templates.py diff --git a/agents-api/agents_api/common/protocol/sessions.py b/agents-api/agents_api/common/protocol/sessions.py index b2a6d7c28..a85ee1ae4 100644 --- a/agents-api/agents_api/common/protocol/sessions.py +++ b/agents-api/agents_api/common/protocol/sessions.py @@ -38,3 +38,7 @@ class SessionData(BaseModel): created_at: float model: str default_settings: SessionSettings + render_templates: bool = False + metadata: dict = {} + user_metadata: dict = {} + agent_metadata: dict = {} diff --git a/agents-api/agents_api/common/utils/template.py b/agents-api/agents_api/common/utils/template.py new file mode 100644 index 000000000..7aa73b5c1 --- /dev/null +++ b/agents-api/agents_api/common/utils/template.py @@ -0,0 +1,38 @@ +import arrow +from jinja2.sandbox import ImmutableSandboxedEnvironment +from jinja2schema import infer, to_json_schema +from jsonschema import validate + +__all__ = [ + "render_template", +] + +# jinja environment +jinja_env = ImmutableSandboxedEnvironment( + autoescape=False, + trim_blocks=True, + lstrip_blocks=True, + auto_reload=False, + enable_async=True, + loader=None, +) + +# Add arrow to jinja +jinja_env.globals["arrow"] = arrow + + +# Funcs +async def render_template( + template_string: str, variables: dict, check: bool = False +) -> str: + # Parse template + template = jinja_env.from_string(template_string) + + # If check is required, get required vars from template and validate variables + if check: + schema = to_json_schema(infer(template_string)) + validate(instance=variables, schema=schema) + + # Render + rendered = await template.render_async(**variables) + return rendered diff --git a/agents-api/agents_api/models/session/create_session.py b/agents-api/agents_api/models/session/create_session.py index 62dd926cc..8179bf433 100644 --- a/agents-api/agents_api/models/session/create_session.py +++ b/agents-api/agents_api/models/session/create_session.py @@ -17,6 +17,7 @@ def create_session_query( user_id: UUID | None, situation: str | None, metadata: dict = {}, + render_templates: bool = False, ) -> tuple[str, dict]: """ Constructs and executes a datalog query to create a new session in the database. @@ -28,6 +29,7 @@ def create_session_query( - user_id (UUID | None): The unique identifier for the user, if applicable. - situation (str | None): The situation/context of the session. - metadata (dict): Additional metadata for the session. + - render_templates (bool): Specifies whether to render templates. Returns: - pd.DataFrame: The result of the query execution. @@ -52,11 +54,12 @@ def create_session_query( } } { # Insert the new session data into the 'session' table with the specified columns. - ?[session_id, developer_id, situation, metadata] <- [[ + ?[session_id, developer_id, situation, metadata, render_templates] <- [[ $session_id, $developer_id, $situation, $metadata, + $render_templates, ]] :insert sessions { @@ -64,6 +67,7 @@ def create_session_query( session_id, situation, metadata, + render_templates, } # Specify the data to return after the query execution, typically the newly created session's ID. :returning @@ -79,5 +83,6 @@ def create_session_query( "developer_id": str(developer_id), "situation": situation, "metadata": metadata, + "render_templates": render_templates, }, ) diff --git a/agents-api/agents_api/models/session/get_session.py b/agents-api/agents_api/models/session/get_session.py index ef8af6ce2..2b9e5cc4d 100644 --- a/agents-api/agents_api/models/session/get_session.py +++ b/agents-api/agents_api/models/session/get_session.py @@ -40,6 +40,7 @@ def get_session_query( updated_at, created_at, metadata, + render_templates, ] := input[developer_id, id], *sessions{ developer_id, @@ -49,6 +50,7 @@ def get_session_query( created_at, updated_at: validity, metadata, + render_templates, @ "NOW" }, *session_lookup{ diff --git a/agents-api/agents_api/models/session/patch_session.py b/agents-api/agents_api/models/session/patch_session.py index f8d9b10ca..a3cbc3748 100644 --- a/agents-api/agents_api/models/session/patch_session.py +++ b/agents-api/agents_api/models/session/patch_session.py @@ -16,6 +16,9 @@ ] +# TODO: Add support for updating `render_templates` field + + @cozo_query def patch_session_query( session_id: UUID, @@ -41,6 +44,7 @@ def patch_session_query( }, session_id = to_uuid($session_id), developer_id = to_uuid($developer_id), + # Assertion to ensure the session exists before updating. :assert some """ diff --git a/agents-api/agents_api/models/session/session_data.py b/agents-api/agents_api/models/session/session_data.py index ea765371a..8d09d452b 100644 --- a/agents-api/agents_api/models/session/session_data.py +++ b/agents-api/agents_api/models/session/session_data.py @@ -42,9 +42,10 @@ def session_data_query( agent_about, model, default_settings, - session_metadata, - users_metadata, - agents_metadata, + metadata, + render_templates, + user_metadata, + agent_metadata, ] := input[developer_id, session_id], *sessions{ developer_id, @@ -53,7 +54,8 @@ def session_data_query( summary, created_at, updated_at: validity, - metadata: session_metadata, + metadata, + render_templates, @ "NOW" }, *session_lookup{ @@ -65,14 +67,14 @@ def session_data_query( user_id, name: user_name, about: user_about, - metadata: users_metadata, + metadata: user_metadata, }, *agents{ agent_id, name: agent_name, about: agent_about, model, - metadata: agents_metadata, + metadata: agent_metadata, }, *agent_default_settings { agent_id, diff --git a/agents-api/agents_api/models/session/update_session.py b/agents-api/agents_api/models/session/update_session.py index 44abce3fd..4c76965a9 100644 --- a/agents-api/agents_api/models/session/update_session.py +++ b/agents-api/agents_api/models/session/update_session.py @@ -14,6 +14,8 @@ "developer_id", ] +# TODO: Add support for updating `render_templates` field + @cozo_query def update_session_query( diff --git a/agents-api/agents_api/routers/sessions/routers.py b/agents-api/agents_api/routers/sessions/routers.py index 59c841fed..6831d96f5 100644 --- a/agents-api/agents_api/routers/sessions/routers.py +++ b/agents-api/agents_api/routers/sessions/routers.py @@ -89,6 +89,7 @@ async def create_session( user_id=request.user_id, situation=request.situation, metadata=request.metadata or {}, + render_templates=request.render_templates or False, ) return ResourceCreatedResponse( diff --git a/agents-api/agents_api/routers/sessions/session.py b/agents-api/agents_api/routers/sessions/session.py index 51fd41500..169aef92e 100644 --- a/agents-api/agents_api/routers/sessions/session.py +++ b/agents-api/agents_api/routers/sessions/session.py @@ -8,24 +8,26 @@ from openai.types.chat.chat_completion import ChatCompletion from pydantic import UUID4 -from agents_api.clients.embed import embed -from agents_api.env import summarization_tokens_threshold -from agents_api.clients.temporal import run_summarization_task -from agents_api.models.entry.add_entries import add_entries_query -from agents_api.common.protocol.entries import Entry -from agents_api.common.exceptions.sessions import SessionNotFoundError -from agents_api.clients.worker.types import ChatML -from agents_api.models.session.session_data import get_session_data -from agents_api.models.entry.proc_mem_context import proc_mem_context_query -from agents_api.autogen.openapi_model import InputChatMLMessage, Tool -from agents_api.model_registry import ( +from ...autogen.openapi_model import InputChatMLMessage, Tool +from ...clients.embed import embed +from ...clients.temporal import run_summarization_task +from ...clients.worker.types import ChatML +from ...common.exceptions.sessions import SessionNotFoundError +from ...common.protocol.entries import Entry +from ...common.protocol.sessions import SessionData +from ...common.utils.template import render_template +from ...env import summarization_tokens_threshold +from ...model_registry import ( get_extra_settings, get_model_client, load_context, ) -from ...common.protocol.sessions import SessionData -from .protocol import Settings +from ...models.entry.add_entries import add_entries_query +from ...models.entry.proc_mem_context import proc_mem_context_query +from ...models.session.session_data import get_session_data + from .exceptions import InputTooBigError +from .protocol import Settings THOUGHTS_STRIP_LEN = 2 @@ -118,18 +120,22 @@ async def run( self, new_input, settings: Settings ) -> tuple[ChatCompletion, Entry, Callable | None]: # TODO: implement locking at some point + # Get session data session_data = get_session_data(self.developer_id, self.session_id) if session_data is None: raise SessionNotFoundError(self.developer_id, self.session_id) + # Assemble context init_context, final_settings = await self.forward( session_data, new_input, settings ) + # Generate response response = await self.generate( self.truncate(init_context, summarization_tokens_threshold), final_settings ) + # Save response to session # if final_settings.get("remember"): # await self.add_to_session(new_input, response) @@ -195,10 +201,11 @@ async def forward( ) entries: list[Entry] = [] - instructions = "IMPORTANT INSTRUCTIONS:\n\n" + instructions = "Instructions:\n\n" first_instruction_idx = -1 first_instruction_created_at = 0 tools = [] + for idx, row in proc_mem_context_query( session_id=self.session_id, tool_query_embedding=tool_query_embedding, @@ -224,7 +231,7 @@ async def forward( first_instruction_idx = idx first_instruction_created_at = row["created_at"] - instructions += f"- {row['content']}\n" + instructions += f"{row['content']}\n\n" continue @@ -266,6 +273,36 @@ async def forward( if e.content ] + # If render_templates=True, render the templates + if session_data is not None and session_data.render_templates: + + template_data = { + "session": { + "id": session_data.session_id, + "situation": session_data.situation, + "metadata": session_data.metadata, + }, + "user": { + "id": session_data.user_id, + "name": session_data.user_name, + "about": session_data.user_about, + "metadata": session_data.user_metadata, + }, + "agent": { + "id": session_data.agent_id, + "name": session_data.agent_name, + "about": session_data.agent_about, + "metadata": session_data.agent_metadata, + }, + } + + for i, msg in enumerate(messages): + # Only render templates for system/assistant messages + if msg.role not in ["system", "assistant"]: + continue + + messages[i].content = await render_template(msg.content, template_data) + # FIXME: This sometimes returns "The model `` does not exist." if session_data is not None: settings.model = session_data.model diff --git a/agents-api/migrations/migrate_1707537826_rename_additional_info.py b/agents-api/migrations/migrate_1707537826_rename_additional_info.py index 7e26322df..d71576f05 100644 --- a/agents-api/migrations/migrate_1707537826_rename_additional_info.py +++ b/agents-api/migrations/migrate_1707537826_rename_additional_info.py @@ -206,13 +206,7 @@ def run(client, *queries): query = joiner.join(queries) query = f"{{\n{query}\n}}" - try: - client.run(query) - except Exception as error: - print(error) - import pdb - - pdb.set_trace() + client.run(query) def up(client): diff --git a/agents-api/migrations/migrate_1714119679_session_render_templates.py b/agents-api/migrations/migrate_1714119679_session_render_templates.py new file mode 100644 index 000000000..93d7dba14 --- /dev/null +++ b/agents-api/migrations/migrate_1714119679_session_render_templates.py @@ -0,0 +1,67 @@ +# /usr/bin/env python3 + +MIGRATION_ID = "session_render_templates" +CREATED_AT = 1714119679.493182 + +extend_sessions = { + "up": """ + ?[render_templates, developer_id, session_id, updated_at, situation, summary, created_at, developer_id, metadata] := *sessions{ + session_id, + updated_at, + situation, + summary, + created_at, + developer_id + }, + metadata = {}, + render_templates = false + + :replace sessions { + developer_id: Uuid, + session_id: Uuid, + updated_at: Validity default [floor(now()), true], + => + situation: String, + summary: String? default null, + created_at: Float default now(), + metadata: Json default {}, + render_templates: Bool default false, + } + """, + "down": """ + ?[developer_id, session_id, updated_at, situation, summary, created_at, developer_id, metadata] := *sessions{ + session_id, + updated_at, + situation, + summary, + created_at, + developer_id + }, metadata = {} + + :replace sessions { + developer_id: Uuid, + session_id: Uuid, + updated_at: Validity default [floor(now()), true], + => + situation: String, + summary: String? default null, + created_at: Float default now(), + metadata: Json default {}, + } + """, +} + + +queries_to_run = [ + extend_sessions, +] + + +def up(client): + for q in queries_to_run: + client.run(q["up"]) + + +def down(client): + for q in reversed(queries_to_run): + client.run(q["down"]) diff --git a/sdks/python/julep/managers/session.py b/sdks/python/julep/managers/session.py index ed448db31..d1ea3907d 100644 --- a/sdks/python/julep/managers/session.py +++ b/sdks/python/julep/managers/session.py @@ -39,6 +39,7 @@ class SessionCreateArgs(TypedDict): agent_id: Union[str, UUID] situation: Optional[str] = None metadata: Dict[str, Any] = {} + render_templates: bool = False class SessionUpdateArgs(TypedDict): @@ -46,6 +47,7 @@ class SessionUpdateArgs(TypedDict): situation: Optional[str] = None metadata: Optional[Dict[str, Any]] = None overwrite: bool = False + render_templates: bool = False class BaseSessionsManager(BaseManager): @@ -179,6 +181,7 @@ def _create( user_id: Optional[Union[str, UUID]] = None, situation: Optional[str] = None, metadata: Dict[str, Any] = {}, + render_templates: bool = False, ) -> Union[ResourceCreatedResponse, Awaitable[ResourceCreatedResponse]]: # Cast instructions to a list of Instruction objects """ @@ -191,6 +194,7 @@ def _create( user_id (Optional[Union[str, UUID]]): The user's identifier which could be a string or a UUID object. situation (Optional[str], optional): An optional description of the situation. metadata (Dict[str, Any]) + render_templates (bool, optional): Whether to render templates in the metadata. Defaults to False. Returns: Union[ResourceCreatedResponse, Awaitable[ResourceCreatedResponse]]: The response from the API client upon successful session creation, which can be a synchronous `ResourceCreatedResponse` or an asynchronous `Awaitable` of it. @@ -208,6 +212,7 @@ def _create( agent_id=agent_id, situation=situation, metadata=metadata, + render_templates=render_templates, ) def _list_items( diff --git a/sdks/python/tests/fixtures.py b/sdks/python/tests/fixtures.py index e1ace94ff..34c9691b2 100644 --- a/sdks/python/tests/fixtures.py +++ b/sdks/python/tests/fixtures.py @@ -49,6 +49,12 @@ "metadata": {"type": "test"}, } +mock_session_with_template = { + "situation": "Say 'hello {{ session.metadata.arg }}'", + "metadata": {"type": "test", "arg": "banana"}, + "render_templates": True, +} + mock_session_update = { "situation": "updated situation", "metadata": {"type": "test"}, @@ -189,6 +195,21 @@ def test_session(client=client, user=test_user, agent=test_agent) -> Session: client.sessions.delete(session.id) +@fixture +def test_session_with_template( + client=client, user=test_user, agent=test_agent +) -> Session: + session = client.sessions.create( + user_id=user.id, + agent_id=agent.id, + **mock_session_with_template, + ) + + yield session + + client.sessions.delete(session.id) + + @fixture def test_session_agent_user(client=client, user=test_user, agent=test_agent) -> Session: session = client.sessions.create( diff --git a/sdks/python/tests/test_sessions.py b/sdks/python/tests/test_sessions.py index a8774aa80..23bdddda9 100644 --- a/sdks/python/tests/test_sessions.py +++ b/sdks/python/tests/test_sessions.py @@ -12,10 +12,12 @@ async_client, client, test_session, + test_session_with_template, test_session_agent_user, test_session_no_user, mock_session, mock_session_update, + mock_session_with_template, TEST_API_KEY, TEST_API_URL, ) @@ -148,6 +150,25 @@ def _(client=client, session=test_session): assert len(history) > 0 +@test("sessions: sessions.chat with template") +def _(client=client, session=test_session_with_template): + response = client.sessions.chat( + session_id=session.id, + messages=[ + InputChatMlMessage( + role=InputChatMlMessageRole.USER, + content="say it please", + ) + ], + max_tokens=10, + ) + + assert isinstance(response, ChatResponse) + assert ( + mock_session_with_template["metadata"]["arg"] in response.response[0][0].content + ) + + # @test("sessions: sessions.suggestions") # def _(client=client, session=test_session): # response = client.sessions.suggestions( diff --git a/sdks/ts/src/managers/session.ts b/sdks/ts/src/managers/session.ts index 1e14c4fe3..e14ce84a5 100644 --- a/sdks/ts/src/managers/session.ts +++ b/sdks/ts/src/managers/session.ts @@ -16,6 +16,8 @@ export interface CreateSessionPayload { userId: string; agentId: string; situation?: string; + metadata?: Record; + renderTemplates?: boolean; } export class SessionsManager extends BaseManager { @@ -25,37 +27,37 @@ export class SessionsManager extends BaseManager { * @returns A promise that resolves with the session object. */ async get(sessionId: string): Promise { - try { - return this.apiClient.default.getSession({ sessionId }); - } catch (error) { - throw error; - } + return this.apiClient.default.getSession({ sessionId }); } async create({ userId, agentId, situation, + metadata = {}, + renderTemplates = false, }: CreateSessionPayload): Promise { - try { - invariant( - isValidUuid4(userId), - `userId must be a valid UUID v4. Got "${userId}"`, - ); - - invariant( - isValidUuid4(agentId), - `agentId must be a valid UUID v4. Got "${agentId}"`, - ); - - const requestBody = { user_id: userId, agent_id: agentId, situation }; - - return this.apiClient.default - .createSession({ requestBody }) - .catch((error) => Promise.reject(error)); - } catch (error) { - throw error; - } + invariant( + isValidUuid4(userId), + `userId must be a valid UUID v4. Got "${userId}"`, + ); + + invariant( + isValidUuid4(agentId), + `agentId must be a valid UUID v4. Got "${agentId}"`, + ); + + const requestBody = { + user_id: userId, + agent_id: agentId, + situation, + metadata, + render_templates: renderTemplates, + }; + + return this.apiClient.default + .createSession({ requestBody }) + .catch((error) => Promise.reject(error)); } async list({ @@ -79,13 +81,9 @@ export class SessionsManager extends BaseManager { } async delete(sessionId: string): Promise { - try { - invariant(isValidUuid4(sessionId), "sessionId must be a valid UUID v4"); + invariant(isValidUuid4(sessionId), "sessionId must be a valid UUID v4"); - await this.apiClient.default.deleteSession({ sessionId }); - } catch (error) { - throw error; - } + await this.apiClient.default.deleteSession({ sessionId }); } async update( @@ -93,17 +91,13 @@ export class SessionsManager extends BaseManager { { situation, metadata = {} }: { situation: string; metadata?: any }, overwrite = false, ): Promise { - try { - invariant(isValidUuid4(sessionId), "sessionId must be a valid UUID v4"); - const requestBody = { situation, metadata }; - - if (overwrite) { - return this.apiClient.default.updateSession({ sessionId, requestBody }); - } else { - return this.apiClient.default.patchSession({ sessionId, requestBody }); - } - } catch (error) { - throw error; + invariant(isValidUuid4(sessionId), "sessionId must be a valid UUID v4"); + const requestBody = { situation, metadata }; + + if (overwrite) { + return this.apiClient.default.updateSession({ sessionId, requestBody }); + } else { + return this.apiClient.default.patchSession({ sessionId, requestBody }); } } @@ -129,87 +123,71 @@ export class SessionsManager extends BaseManager { top_p, }: ChatInput, ): Promise { - try { - invariant(isValidUuid4(sessionId), "sessionId must be a valid UUID v4"); - - const options = omitBy( - { - tools, - tool_choice, - frequency_penalty, - length_penalty, - logit_bias, - max_tokens, - presence_penalty, - repetition_penalty, - response_format, - seed, - stop, - stream, - temperature, - top_p, - recall, - remember, - }, - isUndefined, - ); - - const requestBody = { - messages, - ...options, - }; - - return await this.apiClient.default.chat({ sessionId, requestBody }); - } catch (error) { - throw error; - } + invariant(isValidUuid4(sessionId), "sessionId must be a valid UUID v4"); + + const options = omitBy( + { + tools, + tool_choice, + frequency_penalty, + length_penalty, + logit_bias, + max_tokens, + presence_penalty, + repetition_penalty, + response_format, + seed, + stop, + stream, + temperature, + top_p, + recall, + remember, + }, + isUndefined, + ); + + const requestBody = { + messages, + ...options, + }; + + return await this.apiClient.default.chat({ sessionId, requestBody }); } async suggestions( sessionId: string, { limit = 100, offset = 0 }: { limit?: number; offset?: number } = {}, ): Promise> { - try { - invariant(isValidUuid4(sessionId), "sessionId must be a valid UUID v4"); - - const result = await this.apiClient.default.getSuggestions({ - sessionId, - limit, - offset, - }); - - return result.items || []; - } catch (error) { - throw error; - } + invariant(isValidUuid4(sessionId), "sessionId must be a valid UUID v4"); + + const result = await this.apiClient.default.getSuggestions({ + sessionId, + limit, + offset, + }); + + return result.items || []; } async history( sessionId: string, { limit = 100, offset = 0 }: { limit?: number; offset?: number } = {}, ): Promise> { - try { - invariant(isValidUuid4(sessionId), "sessionId must be a valid UUID v4"); - - const result = await this.apiClient.default.getHistory({ - sessionId, - limit, - offset, - }); - - return result.items || []; - } catch (error) { - throw error; - } + invariant(isValidUuid4(sessionId), "sessionId must be a valid UUID v4"); + + const result = await this.apiClient.default.getHistory({ + sessionId, + limit, + offset, + }); + + return result.items || []; } async deleteHistory(sessionId: string): Promise { - try { - invariant(isValidUuid4(sessionId), "sessionId must be a valid UUID v4"); + invariant(isValidUuid4(sessionId), "sessionId must be a valid UUID v4"); - await this.apiClient.default.deleteSessionHistory({ sessionId }); - } catch (error) { - throw error; - } + await this.apiClient.default.deleteSessionHistory({ sessionId }); } } diff --git a/sdks/ts/tests/sessions.test.ts b/sdks/ts/tests/sessions.test.ts index 5a7ce6d9b..d174f2ba5 100644 --- a/sdks/ts/tests/sessions.test.ts +++ b/sdks/ts/tests/sessions.test.ts @@ -26,6 +26,12 @@ const mockSession = { situation: "test situation", }; +const mockSessionWithTemplate = { + situation: "Say 'hello {{ session.metadata.arg }}'", + metadata: { arg: "banana" }, + renderTemplates: true, +}; + const mockSessionUpdate = { situation: "updated situation", }; @@ -97,26 +103,47 @@ describe("Sessions API", () => { }); it("sessions.chat", async () => { - try { - const response = await client.sessions.chat(testSessionId, { - messages: [ - { - role: "user", - content: "test content", - name: "test name", - }, - ], - max_tokens: 1000, - presence_penalty: 0.5, - repetition_penalty: 0.5, - temperature: 0.7, - top_p: 0.9, - }); - - expect(response.response).toBeDefined(); - } catch (error) { - console.error("error", error); - } + const response = await client.sessions.chat(testSessionId, { + messages: [ + { + role: "user", + content: "test content", + name: "test name", + }, + ], + max_tokens: 1000, + presence_penalty: 0.5, + repetition_penalty: 0.5, + temperature: 0.7, + top_p: 0.9, + }); + + expect(response.response).toBeDefined(); + }, 5000); + + it("sessions.chat with template", async () => { + const session = await client.sessions.create({ + userId: testUser.id, + agentId: testAgent.id, + ...mockSessionWithTemplate, + }); + + const response = await client.sessions.chat(session.id, { + messages: [ + { + role: "user", + content: "please say it", + }, + ], + max_tokens: 10, + }); + + expect(response.response).toBeDefined(); + + // Check that the template was filled in + expect(response.response[0][0].content).toContain( + mockSessionWithTemplate.metadata.arg, + ); }, 5000); // it("sessions.suggestions", async () => {