diff --git a/README.md b/README.md index 1f0f19b..e33e322 100644 --- a/README.md +++ b/README.md @@ -72,8 +72,9 @@ Note: You can use any combination of the API keys above, but at least one is req * GPT-4o mini * GPT-4o * o1-preview +* o1-mini -> Note: o1-preview is used together with GPT-4o as it doesn't support structured outputs yet. While o1-preview provides enhanced performance (especially for BPMN diagram editing), it has longer response times and can incur significant costs. +> Note: Both o1-preview and o1-mini are used together with their respective GPT-4o models as they don't support structured outputs yet. While providing enhanced performance (especially for BPMN diagram editing), they have longer response times and can incur significant costs. ### Anthropic diff --git a/src/bpmn_assistant/app.py b/src/bpmn_assistant/app.py index d165de1..7131f79 100644 --- a/src/bpmn_assistant/app.py +++ b/src/bpmn_assistant/app.py @@ -59,9 +59,14 @@ def _available_providers() -> JSONResponse: def replace_reasoning_model(model: str) -> str: """ - Returns GPT-4o if o1-preview is requested. Otherwise returns the original model. + Returns GPT-4o if o1-preview is requested, or GPT-4o-mini if o1-mini is requested. + Otherwise returns the original model. """ - return OpenAIModels.GPT_4O.value if model == OpenAIModels.O1.value else model + if model == OpenAIModels.O1.value: + return OpenAIModels.GPT_4O.value + elif model == OpenAIModels.O1_MINI.value: + return OpenAIModels.GPT_4O_MINI.value + return model @app.post("/determine_intent") diff --git a/src/bpmn_assistant/core/enums/models.py b/src/bpmn_assistant/core/enums/models.py index 8a4a36a..580e765 100644 --- a/src/bpmn_assistant/core/enums/models.py +++ b/src/bpmn_assistant/core/enums/models.py @@ -5,6 +5,7 @@ class OpenAIModels(Enum): GPT_4O_MINI = "gpt-4o-mini" GPT_4O = "gpt-4o" O1 = "o1-preview" + O1_MINI = "o1-mini" class AnthropicModels(Enum): diff --git a/src/bpmn_assistant/core/provider_impl/openai_provider.py b/src/bpmn_assistant/core/provider_impl/openai_provider.py index f3e8a31..c140a01 100644 --- a/src/bpmn_assistant/core/provider_impl/openai_provider.py +++ b/src/bpmn_assistant/core/provider_impl/openai_provider.py @@ -42,8 +42,8 @@ def call( "messages": messages, # type: ignore[arg-type] } - if model == OpenAIModels.O1.value: - logger.info("Using o1 model") + if model == OpenAIModels.O1.value or model == OpenAIModels.O1_MINI.value: + logger.info("Using o-series models") # https://platform.openai.com/docs/guides/reasoning#limitations # constraining the max_completion_tokens can lead to empty responses else: diff --git a/src/bpmn_assistant/prompts/define_change_request.txt b/src/bpmn_assistant/prompts/define_change_request.txt index 5c6efac..4115de7 100644 --- a/src/bpmn_assistant/prompts/define_change_request.txt +++ b/src/bpmn_assistant/prompts/define_change_request.txt @@ -65,7 +65,7 @@ If the process description does not explicitly mention the 'else' branch or spec ### Parallel gateway Specify "branches" as an array of arrays, where each sub-array lists elements executed in parallel. -A converging element is automatically generated to synchronize parallel branches. Therefore, there's no need to explicitly specify it. +A converging (join) element is automatically generated to synchronize parallel branches. Therefore, there's no need to explicitly specify it. ```json { @@ -83,6 +83,8 @@ A converging element is automatically generated to synchronize parallel branches The order of execution in the BPMN process is determined by the sequence of elements in the top-level "process" array. Elements are executed in the order they appear in this array, from first to last. +## Example 1 + Textual description: "The student sends an email to the professor. The professor receives the email. If the professor agrees with the proposal, he replies to the student." @@ -133,7 +135,7 @@ Textual description: } ``` ---- +## Example 2 Textual description: "The manager sends the mail to the supplier and prepares the documents. At the same time, the customer searches for the goods and picks up the goods." @@ -183,7 +185,7 @@ Textual description: } ``` ---- +## Example 3 Textual description: "Someone starts a process by entering an exam room. After that, they take the test. Once the exam is finished, their score is checked. If they scored more than 50%, their grade is recorded and the process ends. But if they fail, they have to go back to the beginning and take the exam again." @@ -235,7 +237,7 @@ Textual description: "Someone starts a process by entering an exam room. After t } ``` ---- +## Example 4 Textual description: "The process starts with a decision. If Option A is selected, Task A is performed. If Option B is selected, Task B is performed. Task A is followed by another decision. If Sub-option 1 is selected, Task A1 is performed. If Sub-option 2 is selected, Task A2 is performed." @@ -311,8 +313,6 @@ Textual description: "The process starts with a decision. If Option A is selecte } ``` ---- - # Process editing functions - `delete_element(element_id)` @@ -363,14 +363,10 @@ Textual description: "The process starts with a decision. If Option A is selecte ::process ---- - # Message history ::message_history ---- - The last user message indicates that the user wants to make a modification to the process. Based on the last user message, construct a **concise** change request. diff --git a/src/bpmn_assistant/services/process_editing/bpmn_editing_service.py b/src/bpmn_assistant/services/process_editing/bpmn_editing_service.py index 9f4173a..fc57312 100644 --- a/src/bpmn_assistant/services/process_editing/bpmn_editing_service.py +++ b/src/bpmn_assistant/services/process_editing/bpmn_editing_service.py @@ -44,20 +44,13 @@ def edit_bpmn(self) -> list: return updated_process - def _apply_initial_edit(self) -> list: - response = self._get_initial_edit_proposal() - updated_process = self._attempt_process_update_with_retries( - self.process, response - ) - return updated_process - - def _get_initial_edit_proposal(self, max_retries: int = 3) -> dict: + def _apply_initial_edit(self, max_retries: int = 4) -> list: """ - Get an initial edit proposal from the LLM. + Apply the initial edit to the process. Args: max_retries: The maximum number of retries to perform if the response is invalid Returns: - The initial edit proposal (function and arguments) + The updated process """ attempts = 0 @@ -70,43 +63,82 @@ def _get_initial_edit_proposal(self, max_retries: int = 3) -> dict: while attempts < max_retries: attempts += 1 + # Get initial edit proposal try: - response = self.llm_facade.call(prompt, structured_output=EditProposal) - logger.info(f"Edit proposal: {response}") - self._validate_llm_response(response) - return response + edit_proposal: EditProposal = self.llm_facade.call( + prompt, structured_output=EditProposal + ) + logger.info(f"Edit proposal: {edit_proposal}") + self._validate_edit_proposal(edit_proposal) + + # Update process based on the edit proposal + try: + updated_process = self._update_process(self.process, edit_proposal) + return updated_process + except ProcessException as e: + logger.warning(f"Validation error (attempt {attempts}): {str(e)}") + prompt = f"Error: {str(e)}. Try again. Change request: {self.change_request}" except ValueError as e: logger.warning(f"Validation error (attempt {attempts}): {str(e)}") prompt = f"Editing error: {str(e)}. Provide a new edit proposal." raise Exception("Max number of retries reached.") - def _attempt_process_update_with_retries( - self, process: list, edit_proposal: dict, max_retries: int = 3 + def _apply_intermediate_edits( + self, + updated_process: list, + max_retries: int = 4, + max_num_of_iterations: int = 7, ) -> list: - attempts = 0 - - while attempts < max_retries: - attempts += 1 - - try: - updated_process = self._update_process(process, edit_proposal) - return updated_process - except ProcessException as e: - error_message = str(e) - logger.warning( - f"Validation error (attempt {attempts}): {error_message}" - ) - - new_prompt = f"Error: {error_message}. Try again. Change request: {self.change_request}" - - edit_proposal = self.llm_facade.call(new_prompt) - logger.info(f"New edit proposal: {edit_proposal}") + """ + Apply intermediate edits to the process. + Args: + updated_process: The updated process after the initial edit + max_retries: The maximum number of retries to perform if the response is invalid + max_num_of_iterations: The maximum number of iterations to perform + Returns: + The updated process + """ + for _ in range(max_num_of_iterations): + attempts = 0 - if "stop" in edit_proposal: - return process + prompt = prepare_prompt( + "edit_bpmn_intermediate_step.txt", + process=str(updated_process), + ) - raise Exception("Max number of retries reached. Process not fully edited.") + while attempts < max_retries: + attempts += 1 + + # Get intermediate edit proposal + try: + edit_proposal: IntermediateEditProposal = self.llm_facade.call( + prompt, structured_output=IntermediateEditProposal + ) + logger.info(f"Intermediate edit proposal: {edit_proposal}") + self._validate_edit_proposal(edit_proposal, is_first_edit=False) + + if "stop" in edit_proposal: + logger.info("Edit process stopped.") + return updated_process + + # Update process based on the edit proposal + try: + updated_process = self._update_process( + updated_process, edit_proposal + ) + except ProcessException as e: + logger.warning( + f"Validation error (attempt {attempts}): {str(e)}" + ) + prompt = ( + f"Editing error: {str(e)}. Provide a new edit proposal." + ) + except ValueError as e: + logger.warning(f"Validation error (attempt {attempts}): {str(e)}") + prompt = f"Editing error: {str(e)}. Provide a new edit proposal." + + raise Exception("Max number of editing iterations reached.") def _update_process(self, process: list, edit_proposal: dict) -> list: """ @@ -116,6 +148,8 @@ def _update_process(self, process: list, edit_proposal: dict) -> list: edit_proposal: The edit proposal from the LLM (function and args) Returns: The updated process + Raises: + ProcessException: If the edit proposal is invalid """ edit_functions = { "delete_element": delete_element, @@ -131,77 +165,32 @@ def _update_process(self, process: list, edit_proposal: dict) -> list: res = edit_functions[function_to_call](process, **args) return res["process"] - def _apply_intermediate_edits( - self, - updated_process: list, - max_num_of_iterations: int = 7, - ) -> list: - for _ in range(max_num_of_iterations): - response = self._get_intermediate_edit_proposal(updated_process) - - if "stop" in response: - logger.info("Edit process stopped.") - return updated_process - else: - # The 'response' is the edit proposal (function and arguments) - updated_process = self._attempt_process_update_with_retries( - updated_process, response - ) - - raise Exception("Max number of iterations reached. Process not fully edited.") - - def _get_intermediate_edit_proposal( - self, updated_process: list, max_retries: int = 3 - ) -> dict: + def _validate_edit_proposal( + self, edit_proposal: dict, is_first_edit: bool = True + ) -> None: """ - Get an intermediate edit proposal from the LLM. + Validate the edit proposal from the LLM. Args: - updated_process: The updated BPMN process - max_retries: The maximum number of retries to perform if the response is invalid - Returns: - The intermediate edit proposal (function and arguments, or 'stop') + edit_proposal: The edit proposal from the LLM + is_first_edit: Whether the response is for the initial edit + Raises: + ValueError: If the edit proposal is invalid """ - attempts = 0 - - prompt = prepare_prompt( - "edit_bpmn_intermediate_step.txt", - process=str(updated_process), - ) - - while attempts < max_retries: - attempts += 1 - - try: - response = self.llm_facade.call( - prompt, structured_output=IntermediateEditProposal - ) - logger.info(f"Intermediate edit proposal: {response}") - self._validate_llm_response(response, is_first_edit=False) - return response - except ValueError as e: - logger.warning(f"Validation error (attempt {attempts}): {str(e)}") - prompt = f"Editing error: {str(e)}. Provide a new edit proposal." - - raise Exception("Max number of retries reached.") - def _validate_llm_response( - self, response: dict, is_first_edit: bool = True - ) -> None: - - if not is_first_edit and "stop" in response: - if len(response) > 1: + if not is_first_edit and "stop" in edit_proposal: + if len(edit_proposal) > 1: raise ValueError( "If 'stop' key is present, no other key should be provided." ) return - if "function" not in response or "arguments" not in response: + if "function" not in edit_proposal or "arguments" not in edit_proposal: raise ValueError( "Function call should contain 'function' and 'arguments' keys." ) - function_to_call = response["function"] - args = response["arguments"] + function_to_call = edit_proposal["function"] + args = edit_proposal["arguments"] if function_to_call == "delete_element": self._validate_delete_element(args) diff --git a/src/bpmn_assistant/services/process_editing/helpers.py b/src/bpmn_assistant/services/process_editing/helpers.py index e0e5835..76c165b 100644 --- a/src/bpmn_assistant/services/process_editing/helpers.py +++ b/src/bpmn_assistant/services/process_editing/helpers.py @@ -1,5 +1,7 @@ from typing import Optional +from bpmn_assistant.core.enums import BPMNElementType +from bpmn_assistant.core.exceptions import ProcessException from bpmn_assistant.services.process_editing.position import Position @@ -14,16 +16,16 @@ def get_all_ids(process: list[dict]): ids = [] for element in process: ids.append(element["id"]) - if element["type"] == "exclusiveGateway": + if element["type"] == BPMNElementType.EXCLUSIVE_GATEWAY.value: for branch in element["branches"]: ids += get_all_ids(branch["path"]) - elif element["type"] == "parallelGateway": + elif element["type"] == BPMNElementType.PARALLEL_GATEWAY.value: for branch in element["branches"]: ids += get_all_ids(branch) return ids -def find_position_in_process( +def _find_position_in_process( process: list[dict], target_id: str, after: bool = False, @@ -44,9 +46,9 @@ def find_position_in_process( current_path = path + [index] if element["id"] == target_id: return {"index": index + 1 if after else index, "path": path} - if element["type"] == "exclusiveGateway": + if element["type"] == BPMNElementType.EXCLUSIVE_GATEWAY.value: for branch_index, branch in enumerate(element["branches"]): - result = find_position_in_process( + result = _find_position_in_process( branch["path"], target_id, after, @@ -54,9 +56,9 @@ def find_position_in_process( ) if result: return result - elif element["type"] == "parallelGateway": + elif element["type"] == BPMNElementType.PARALLEL_GATEWAY.value: for branch_index, branch in enumerate(element["branches"]): - result = find_position_in_process( + result = _find_position_in_process( branch, target_id, after, current_path + ["branches", branch_index] ) if result: @@ -82,25 +84,25 @@ def find_position( position = None if before_id is None and after_id is None: - raise Exception("Both before_id and after_id cannot be None") + raise ProcessException("Both before_id and after_id cannot be None") elif before_id is not None and after_id is not None: - raise Exception("Only one of before_id and after_id can be specified") + raise ProcessException("Only one of before_id and after_id can be specified") elif before_id is not None: if before_id not in ids: - raise Exception(f"Element with id {before_id} does not exist") - position = find_position_in_process(process, before_id) + raise ProcessException(f"Element with id {before_id} does not exist") + position = _find_position_in_process(process, before_id) elif after_id is not None: if after_id not in ids: - raise Exception(f"Element with id {after_id} does not exist") - position = find_position_in_process(process, after_id, after=True) + raise ProcessException(f"Element with id {after_id} does not exist") + position = _find_position_in_process(process, after_id, after=True) if position is None: - raise Exception("Element not found") + raise ProcessException("Element not found") return Position(position["index"], position["path"]) -def find_branch_by_condition( +def _find_branch_by_condition( process: list[dict], target_condition: str, path: Optional[list] = None ) -> dict | None: """ @@ -115,7 +117,7 @@ def find_branch_by_condition( path = path or [] for index, element in enumerate(process): current_path = path + [index] - if element["type"] == "exclusiveGateway": + if element["type"] == BPMNElementType.EXCLUSIVE_GATEWAY.value: for branch_index, branch in enumerate(element["branches"]): if branch["condition"] == target_condition: return { @@ -126,16 +128,16 @@ def find_branch_by_condition( # If not found in this gateway, search nested gateways for branch_index, branch in enumerate(element["branches"]): - result = find_branch_by_condition( + result = _find_branch_by_condition( branch["path"], target_condition, current_path + ["branches", branch_index, "path"], ) if result: return result - elif element["type"] == "parallelGateway": + elif element["type"] == BPMNElementType.PARALLEL_GATEWAY.value: for branch_index, branch in enumerate(element["branches"]): - result = find_branch_by_condition( + result = _find_branch_by_condition( branch, target_condition, current_path + ["branches", branch_index] ) if result: @@ -153,9 +155,9 @@ def find_branch_position(process: list[dict], condition: str) -> Position: Position: A class that contains the index and path of the branch. Example: Position(index=1, path=[2, "branches"]) """ - result = find_branch_by_condition(process, condition) + result = _find_branch_by_condition(process, condition) if result is None: - raise Exception(f"Branch with condition '{condition}' does not exist") + raise ProcessException(f"Branch with condition '{condition}' does not exist") return Position(result["branch_index"], result["path"] + ["branches"]) diff --git a/src/bpmn_assistant/services/validate_bpmn.py b/src/bpmn_assistant/services/validate_bpmn.py index 5164f7b..07a6d6a 100644 --- a/src/bpmn_assistant/services/validate_bpmn.py +++ b/src/bpmn_assistant/services/validate_bpmn.py @@ -59,9 +59,9 @@ def validate_element(element: dict) -> None: _validate_parallel_gateway(element) -def _validate_task(element: dict) -> Optional[str]: +def _validate_task(element: dict) -> None: if "label" not in element: - raise Exception(f"Task element is missing a label: {element}") + raise ValueError(f"Task element is missing a label: {element}") try: BPMNTask.model_validate(element) @@ -69,16 +69,16 @@ def _validate_task(element: dict) -> Optional[str]: raise ValueError(f"Invalid task element: {element}") -def _validate_exclusive_gateway(element: dict) -> Optional[str]: +def _validate_exclusive_gateway(element: dict) -> None: if "label" not in element: - raise Exception(f"Exclusive gateway is missing a label: {element}") + raise ValueError(f"Exclusive gateway is missing a label: {element}") if "branches" not in element or not isinstance(element["branches"], list): - raise Exception( + raise ValueError( f"Exclusive gateway is missing or has invalid 'branches': {element}" ) for branch in element["branches"]: if "condition" not in branch or "path" not in branch: - raise Exception(f"Invalid branch in exclusive gateway: {branch}") + raise ValueError(f"Invalid branch in exclusive gateway: {branch}") try: ExclusiveGateway.model_validate(element) @@ -86,10 +86,10 @@ def _validate_exclusive_gateway(element: dict) -> Optional[str]: raise ValueError(f"Invalid exclusive gateway element: {element}") -def _validate_parallel_gateway(element: dict) -> Optional[str]: +def _validate_parallel_gateway(element: dict) -> None: if "branches" not in element or not isinstance(element["branches"], list): - raise Exception( - f"Parallel gateway is missing or has invalid 'branches': {element}" + raise ValueError( + f"Parallel gateway has missing or invalid 'branches': {element}" ) try: diff --git a/src/bpmn_frontend/src/components/ModelPicker.vue b/src/bpmn_frontend/src/components/ModelPicker.vue index f94095f..7482d2d 100644 --- a/src/bpmn_frontend/src/components/ModelPicker.vue +++ b/src/bpmn_frontend/src/components/ModelPicker.vue @@ -18,6 +18,7 @@ const Models = Object.freeze({ GPT_4O_MINI: "gpt-4o-mini", GPT_4O: "gpt-4o", O1: "o1-preview", + O1_MINI: "o1-mini", HAIKU_3_5: "claude-3-5-haiku-20241022", SONNET_3_5: "claude-3-5-sonnet-20241022", GEMINI_1_5_PRO: "gemini-1.5-pro", @@ -47,6 +48,11 @@ export default { provider: Providers.OPENAI, }, { value: Models.GPT_4O, title: "GPT-4o", provider: Providers.OPENAI }, + { + value: Models.O1_MINI, + title: "GPT-4o mini + o1-mini", + provider: Providers.OPENAI, + }, { value: Models.O1, title: "GPT-4o + o1",