Refactor prompt handling

jtlicardo committed Dec 18, 2024
1 parent 86829e5 commit 1ffca51
Showing 6 changed files with 63 additions and 59 deletions.
8 changes: 2 additions & 6 deletions src/bpmn_assistant/services/bpmn_modeling_service.py
@@ -1,6 +1,5 @@
 import json
 import traceback
-from importlib import resources
 
 from bpmn_assistant.config import logger
 from bpmn_assistant.core import LLMFacade, MessageItem
@@ -32,12 +31,9 @@ def create_bpmn(
         Returns:
             list: The BPMN process.
         """
-        prompt_template = resources.read_text(
-            "bpmn_assistant.prompts", "create_bpmn.txt"
-        )
-
         prompt = prepare_prompt(
-            prompt_template, message_history=message_history_to_string(message_history)
+            "create_bpmn.txt",
+            message_history=message_history_to_string(message_history),
         )
 
         attempts = 0
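
The same call-site change repeats in every service below: callers stop reading the template themselves and pass the file name to prepare_prompt instead. A minimal before/after sketch of the pattern (names taken from this diff; the surrounding function and the message_history variable are elided):

    from importlib import resources
    from bpmn_assistant.utils import prepare_prompt, message_history_to_string

    # Before this commit: each caller loaded the template text itself
    prompt_template = resources.read_text("bpmn_assistant.prompts", "create_bpmn.txt")
    prompt = prepare_prompt(
        prompt_template,
        message_history=message_history_to_string(message_history),
    )

    # After this commit: callers pass the file name; prepare_prompt does the reading
    prompt = prepare_prompt(
        "create_bpmn.txt",
        message_history=message_history_to_string(message_history),
    )
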
24 changes: 5 additions & 19 deletions src/bpmn_assistant/services/conversational_service.py
@@ -1,12 +1,11 @@
-from importlib import resources
-from typing import Optional, Generator
+from typing import Generator, Optional
 
 from bpmn_assistant.core import MessageItem
 from bpmn_assistant.core.enums import OutputMode
 from bpmn_assistant.utils import (
-    prepare_prompt,
     get_llm_facade,
     message_history_to_string,
+    prepare_prompt,
 )
 
 
@@ -28,21 +27,13 @@ def respond_to_query(
             Generator: A generator that yields the response
         """
         if not process:
-            prompt_template = resources.read_text(
-                "bpmn_assistant.prompts", "respond_to_query_no_process.txt"
-            )
-
             prompt = prepare_prompt(
-                prompt_template,
+                "respond_to_query_no_process.txt",
                 message_history=message_history_to_string(message_history),
             )
         else:
-            prompt_template = resources.read_text(
-                "bpmn_assistant.prompts", "respond_to_query.txt"
-            )
-
             prompt = prepare_prompt(
-                prompt_template,
+                "respond_to_query.txt",
                 message_history=message_history_to_string(message_history),
                 process=str(process),
             )
@@ -60,13 +51,8 @@ def make_final_comment(
         Returns:
             Generator: A generator that yields the final comment
         """
-        # TODO: prepare_prompt should take care of reading the template
-        prompt_template = resources.read_text(
-            "bpmn_assistant.prompts", "make_final_comment.txt"
-        )
-
         prompt = prepare_prompt(
-            prompt_template,
+            "make_final_comment.txt",
             message_history=message_history_to_string(message_history),
             process=str(process),
         )
7 changes: 1 addition & 6 deletions src/bpmn_assistant/services/determine_intent.py
@@ -1,5 +1,4 @@
 import traceback
-from importlib import resources
 
 from pydantic import BaseModel
 
@@ -42,12 +41,8 @@ def determine_intent(
     Returns:
         dict: The response containing the intent
     """
-    prompt_template = resources.read_text(
-        "bpmn_assistant.prompts", "determine_intent.txt"
-    )
-
    prompt = prepare_prompt(
-        prompt_template,
+        "determine_intent.txt",
        message_history=message_history_to_string(message_history),
    )
 
16 changes: 4 additions & 12 deletions src/bpmn_assistant/services/process_editing/bpmn_editor_service.py
@@ -1,13 +1,11 @@
-from importlib import resources
-
 from bpmn_assistant.config import logger
 from bpmn_assistant.core import LLMFacade
 from bpmn_assistant.core.exceptions import ProcessException
 from bpmn_assistant.services.process_editing import (
-    delete_element,
-    redirect_branch,
     add_element,
+    delete_element,
     move_element,
+    redirect_branch,
     update_element,
 )
 from bpmn_assistant.utils import prepare_prompt
@@ -107,10 +105,8 @@ def _update_process(self, process: list, edit_proposal: dict) -> list:
         return res["process"]
 
     def _get_initial_edit_proposal(self, max_retries: int = 3) -> dict:
-        prompt_template = resources.read_text("bpmn_assistant.prompts", "edit_bpmn.txt")
-
         prompt = prepare_prompt(
-            prompt_template,
+            "edit_bpmn.txt",
             process=str(self.process),
             change_request=self.change_request,
         )
@@ -150,12 +146,8 @@ def _get_intermediate_edit_proposal(
         Returns:
             The intermediate edit proposal (function and arguments)
         """
-        prompt_template = resources.read_text(
-            "bpmn_assistant.prompts", "edit_bpmn_intermediate_step.txt"
-        )
-
         prompt = prepare_prompt(
-            prompt_template,
+            "edit_bpmn_intermediate_step.txt",
             process=str(updated_process),
         )
 
@@ -1,37 +1,69 @@
-from importlib import resources
+import traceback
 
 from pydantic import BaseModel
 
 from bpmn_assistant.config import logger
 from bpmn_assistant.core import LLMFacade, MessageItem
-from bpmn_assistant.utils import prepare_prompt, message_history_to_string
+from bpmn_assistant.utils import message_history_to_string, prepare_prompt
 
 
 class DefineChangeRequestResponse(BaseModel):
     change_request: str
 
 
+def _validate_define_change_request(response: dict) -> None:
+    """
+    Validate the response from the define_change_request function.
+    Args:
+        response: The response to validate
+    Raises:
+        ValueError: If the response is invalid
+    """
+    if "change_request" not in response:
+        raise ValueError("Invalid response: 'change_request' key not found")
+
+
 def define_change_request(
-    llm_facade: LLMFacade, process: list[dict], message_history: list[MessageItem]
+    llm_facade: LLMFacade,
+    process: list[dict],
+    message_history: list[MessageItem],
+    max_retries: int = 3,
 ) -> str:
     """
     Defines the change to be made in the BPMN process based on the message history.
     Args:
         llm_facade: The LLM facade object
         process: The BPMN process
         message_history: The message history
+        max_retries: The maximum number of retries in case of failure
     Returns:
         str: The change request
     """
-
-    prompt_template = resources.read_text(
-        "bpmn_assistant.prompts", "define_change_request.txt"
-    )
-
     prompt = prepare_prompt(
-        prompt_template,
+        "define_change_request.txt",
         process=str(process),
         message_history=message_history_to_string(message_history),
     )
 
-    json_object = llm_facade.call(prompt, max_tokens=100, temperature=0.4)
+    attempts = 0
 
-    # TODO: validate the response, retry until it's valid
+    while attempts < max_retries:
 
-    logger.info(f"Change request: {json_object['change_request']}")
+        attempts += 1
 
-    return json_object["change_request"]
+        try:
+            json_object = llm_facade.call(prompt, max_tokens=100, temperature=0.4)
+            _validate_define_change_request(json_object)
+            logger.info(f"Change request: {json_object['change_request']}")
+            return json_object["change_request"]
+        except Exception as e:
+            logger.warning(
+                f"Validation error (attempt {attempts}): {str(e)}\n"
+                f"Traceback: {traceback.format_exc()}"
+            )
+
+            prompt = f"Error: {str(e)}. Try again."
+
+    raise Exception(
+        "Maximum number of retries reached. Could not define change request."
+    )
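
The loop above replaces a single unvalidated LLM call with retry-until-valid, feeding the validation error back into the prompt. The same pattern, stripped of the repository specifics, looks roughly like this (call_llm and validate are hypothetical stand-ins, not functions from this codebase):

    def call_until_valid(call_llm, validate, prompt: str, max_retries: int = 3) -> dict:
        """Call an LLM, validate the reply, and retry with the error fed back."""
        for _ in range(max_retries):
            try:
                response = call_llm(prompt)
                validate(response)  # raises ValueError on a malformed reply
                return response
            except ValueError as e:
                # Feed the failure back so the next attempt can self-correct
                prompt = f"Error: {e}. Try again."
        raise RuntimeError("Maximum number of retries reached.")
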
9 changes: 6 additions & 3 deletions src/bpmn_assistant/utils/utils.py
@@ -1,4 +1,5 @@
 import os
+from importlib import resources
 
 from dotenv import load_dotenv
 
@@ -13,16 +14,18 @@
 )
 
 
-def prepare_prompt(prompt_template, **kwargs):
+def prepare_prompt(template_file: str, **kwargs) -> str:
     """
-    Replace the placeholders in the prompt template with the given values.
+    Read the prompt template from the given resource and replace the placeholders with the given values.
     Args:
-        prompt_template (str): The prompt template.
+        template_file (str): The template file name.
         **kwargs: Keyword arguments where keys are variable names (without '::')
                   and values are the replacement strings.
     Returns:
         str: The prompt
     """
+    prompt_template = resources.read_text("bpmn_assistant.prompts", template_file)
+
     prompt = prompt_template
 
     # Extract variables from the template
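
The body of the substitution loop is collapsed above, but the docstring implies templates mark variables with a '::' prefix. Under that assumption, the refactored helper plausibly behaves like this sketch (the replacement loop is illustrative, not the repository's exact code):

    from importlib import resources

    def prepare_prompt(template_file: str, **kwargs) -> str:
        # Read the template from the bpmn_assistant.prompts package...
        prompt = resources.read_text("bpmn_assistant.prompts", template_file)
        # ...then substitute each ::name placeholder with its keyword argument
        for name, value in kwargs.items():
            prompt = prompt.replace(f"::{name}", value)
        return prompt
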
