Skip to content

Commit

Permalink
Add DialoguePathSampler class for generating dialogue paths from a graph
Browse files Browse the repository at this point in the history
  • Loading branch information
NotBioWaste905 committed Dec 13, 2024
1 parent c84e334 commit 939ffb4
Showing 1 changed file with 47 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import random
import networkx as nx
from chatsky_llm_autoconfig.graph import BaseGraph
from chatsky_llm_autoconfig.algorithms.base import DialogueGenerator
from chatsky_llm_autoconfig.dialogue import Dialogue
Expand Down Expand Up @@ -71,3 +72,49 @@ def invoke(self, graph: BaseGraph, start_node: int = 1, end_node: int = -1, topi

async def ainvoke(self, *args, **kwargs):
return self.invoke(*args, **kwargs)


@AlgorithmRegistry.register(input_type=BaseGraph, output_type=Dialogue)
class DialoguePathSampler(DialogueGenerator):
def invoke(self, graph: BaseGraph, start_node: int = 1, end_node: int = -1, topic="") -> list[Dialogue]:
nx_graph = graph.graph

# Find all nodes with no outgoing edges (end nodes)
end_nodes = [node for node in nx_graph.nodes() if nx_graph.out_degree(node) == 0]
dialogues = []
# If no end nodes found, return empty list
if not end_nodes:
return []

all_paths = []
# Get paths from start node to each end node
for end in end_nodes:
paths = list(nx.all_simple_paths(nx_graph, source=start_node, target=end))
all_paths.extend(paths)

for path in all_paths:
dialogue_turns = []
# Process each node and edge in the path
for i in range(len(path)):
# Add assistant utterance from current node
current_node = path[i]
assistant_utterance = random.choice(nx_graph.nodes[current_node]["utterances"])
dialogue_turns.append({"text": assistant_utterance, "participant": "assistant"})

# Add user utterance from edge (if not at last node)
if i < len(path) - 1:
next_node = path[i + 1]
edge_data = nx_graph.edges[current_node, next_node]
user_utterance = (
random.choice(edge_data["utterances"])
if isinstance(edge_data["utterances"], list)
else edge_data["utterances"]
)
dialogue_turns.append({"text": user_utterance, "participant": "user"})

dialogues.append(Dialogue().from_list(dialogue_turns))

return dialogues

async def ainvoke(self, *args, **kwargs):
return self.invoke(*args, **kwargs)

0 comments on commit 939ffb4

Please sign in to comment.