From 939ffb478ad91e0248f1ac462c0b9028b9a7f087 Mon Sep 17 00:00:00 2001 From: NotBioWaste905 Date: Fri, 13 Dec 2024 14:31:21 +0300 Subject: [PATCH] Add DialoguePathSampler class for generating dialogue paths from a graph --- .../algorithms/dialogue_generation.py | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/algorithms/dialogue_generation.py b/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/algorithms/dialogue_generation.py index dbbb036..7448dc2 100644 --- a/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/algorithms/dialogue_generation.py +++ b/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/algorithms/dialogue_generation.py @@ -1,4 +1,5 @@ import random +import networkx as nx from chatsky_llm_autoconfig.graph import BaseGraph from chatsky_llm_autoconfig.algorithms.base import DialogueGenerator from chatsky_llm_autoconfig.dialogue import Dialogue @@ -71,3 +72,49 @@ def invoke(self, graph: BaseGraph, start_node: int = 1, end_node: int = -1, topi async def ainvoke(self, *args, **kwargs): return self.invoke(*args, **kwargs) + + +@AlgorithmRegistry.register(input_type=BaseGraph, output_type=Dialogue) +class DialoguePathSampler(DialogueGenerator): + def invoke(self, graph: BaseGraph, start_node: int = 1, end_node: int = -1, topic="") -> list[Dialogue]: + nx_graph = graph.graph + + # Find all nodes with no outgoing edges (end nodes) + end_nodes = [node for node in nx_graph.nodes() if nx_graph.out_degree(node) == 0] + dialogues = [] + # If no end nodes found, return empty list + if not end_nodes: + return [] + + all_paths = [] + # Get paths from start node to each end node + for end in end_nodes: + paths = list(nx.all_simple_paths(nx_graph, source=start_node, target=end)) + all_paths.extend(paths) + + for path in all_paths: + dialogue_turns = [] + # Process each node and edge in the path + for i in range(len(path)): + # Add assistant utterance from current node + current_node = path[i] + assistant_utterance = random.choice(nx_graph.nodes[current_node]["utterances"]) + dialogue_turns.append({"text": assistant_utterance, "participant": "assistant"}) + + # Add user utterance from edge (if not at last node) + if i < len(path) - 1: + next_node = path[i + 1] + edge_data = nx_graph.edges[current_node, next_node] + user_utterance = ( + random.choice(edge_data["utterances"]) + if isinstance(edge_data["utterances"], list) + else edge_data["utterances"] + ) + dialogue_turns.append({"text": user_utterance, "participant": "user"}) + + dialogues.append(Dialogue().from_list(dialogue_turns)) + + return dialogues + + async def ainvoke(self, *args, **kwargs): + return self.invoke(*args, **kwargs) \ No newline at end of file