From 34b3eca1687bb1806da7ba9a58a9ff24d1189f19 Mon Sep 17 00:00:00 2001 From: Bar Karov Date: Wed, 25 Dec 2024 09:42:39 +0000 Subject: [PATCH] fixed guideline proposer tests, reverted gpt version for guideline proposer --- src/parlant/adapters/nlp/openai.py | 3 ++- .../engines/alpha/test_guideline_proposer.py | 22 +++++++++++++------ 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/src/parlant/adapters/nlp/openai.py b/src/parlant/adapters/nlp/openai.py index 6d38062aa..990d628b6 100644 --- a/src/parlant/adapters/nlp/openai.py +++ b/src/parlant/adapters/nlp/openai.py @@ -34,6 +34,7 @@ import tiktoken from parlant.adapters.nlp.common import normalize_json_output +from parlant.core.engines.alpha.guideline_proposer import GuidelinePropositionsSchema from parlant.core.logging import Logger from parlant.core.nlp.policies import policy, retry from parlant.core.nlp.tokenization import EstimatingTokenizer @@ -380,7 +381,7 @@ def __init__( @override async def get_schematic_generator(self, t: type[T]) -> OpenAISchematicGenerator[T]: - if t == GuidelineConnectionPropositionsSchema: + if t == GuidelineConnectionPropositionsSchema or t == GuidelinePropositionsSchema: return GPT_4o_24_08_06[t](self._logger) # type: ignore return GPT_4o[t](self._logger) # type: ignore diff --git a/tests/core/unstable/engines/alpha/test_guideline_proposer.py b/tests/core/unstable/engines/alpha/test_guideline_proposer.py index 184056676..25a02fe35 100644 --- a/tests/core/unstable/engines/alpha/test_guideline_proposer.py +++ b/tests/core/unstable/engines/alpha/test_guideline_proposer.py @@ -182,6 +182,10 @@ "condition": "the customer asked a question about birds", "action": "answer their question enthusiastically, while not using punctuation. Also say that the kingfisher is your favorite bird", }, + "second_thanks": { + "condition": "the customer is thanking you for the second time in the interaction", + "action": "compliment the customer for their manners", + }, } @@ -392,7 +396,7 @@ def test_that_many_guidelines_are_classified_correctly( # a stress test "ai_agent", "Got it! Your blue 'City Cruiser' skateboard and black medium helmet are ready for checkout. How would you like to pay?", ), - ("customer", "I'll pay with a credit card, thanks."), + ("customer", "I'll pay with a credit card, thank you very much!"), ( "ai_agent", "Thank you for your order! Your skateboard and helmet will be shipped shortly. Enjoy your ride!", @@ -400,17 +404,20 @@ def test_that_many_guidelines_are_classified_correctly( # a stress test ("customer", "That's great! Thanks!"), ] - exceptions = ["credit_payment1", "credit_payment2", "cow_response"] + exceptions = [ + "credit_payment1", + "credit_payment2", + "cow_response", + "thankful_customer", + "payment_process", + ] conversation_guideline_names: list[str] = [ guideline_name for guideline_name in GUIDELINES_DICT.keys() if guideline_name not in exceptions ] - relevant_guideline_names = [ - "announce_shipment", - "thankful_customer", - ] + relevant_guideline_names = ["announce_shipment", "second_thanks"] base_test_that_correct_guidelines_are_proposed( context, agent, @@ -510,13 +517,14 @@ def test_that_guideline_that_needs_to_be_reapplied_is_proposed( ] conversation_guideline_names: list[str] = ["large_pizza_crust"] + relevant_guideline_names = conversation_guideline_names base_test_that_correct_guidelines_are_proposed( context, agent, customer, conversation_context, conversation_guideline_names, - [], + relevant_guideline_names, context_variables=[], )