diff --git a/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/algorithms/base.py b/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/algorithms/base.py
index 26da80c..08c7639 100644
--- a/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/algorithms/base.py
+++ b/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/algorithms/base.py
@@ -2,8 +2,9 @@
 from pydantic import BaseModel
 import abc
 from chatsky_llm_autoconfig.graph import BaseGraph
-from chatsky_llm_autoconfig.dialogue import Dialogue
+from chatsky_llm_autoconfig.schemas import Dialogue
 from langchain_core.language_models.chat_models import BaseChatModel
+from langchain.prompts import PromptTemplate
 
 
 class BaseAlgorithm(BaseModel, abc.ABC):
@@ -72,7 +73,7 @@ async def ainvoke(self, topic: str, graph: BaseGraph) -> BaseGraph:
 class TopicGraphGenerator(BaseAlgorithm):
     """Graph generator that works only with topics."""
 
-    def invoke(self, topic: str, model: BaseChatModel) -> BaseGraph:
+    def invoke(self, model: BaseChatModel, prompt: PromptTemplate) -> BaseGraph:
         raise NotImplementedError
 
     async def ainvoke(self, topic: str) -> BaseGraph:
diff --git a/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/algorithms/cycle_graph_generation_pipeline.py b/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/algorithms/cycle_graph_generation_pipeline.py
new file mode 100644
index 0000000..dec4092
--- /dev/null
+++ b/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/algorithms/cycle_graph_generation_pipeline.py
@@ -0,0 +1,239 @@
+from dataclasses import dataclass
+from typing import Optional, Dict, Any
+import networkx as nx
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain.prompts import PromptTemplate
+from chatsky_llm_autoconfig.algorithms.topic_graph_generation import CycleGraphGenerator
+from chatsky_llm_autoconfig.algorithms.dialogue_generation import RecursiveDialogueSampler
+from chatsky_llm_autoconfig.metrics.automatic_metrics import all_utterances_present
+from chatsky_llm_autoconfig.metrics.llm_metrics import graph_validation, is_theme_valid
+from chatsky_llm_autoconfig.graph import BaseGraph
+from chatsky_llm_autoconfig.prompts import cycle_graph_generation_prompt_enhanced, cycle_graph_repair_prompt
+from pydantic import BaseModel
+
+from enum import Enum
+from typing import Union
+
+from chatsky_llm_autoconfig.schemas import GraphGenerationResult
+
+
+class ErrorType(str, Enum):
+    """Types of errors that can occur during generation"""
+    INVALID_GRAPH_STRUCTURE = "invalid_graph_structure"
+    TOO_FEW_CYCLES = "too_few_cycles"
+    SAMPLING_FAILED = "sampling_failed"
+    INVALID_THEME = "invalid_theme"
+    GENERATION_FAILED = "generation_failed"
+
+
+class GenerationError(BaseModel):
+    """Base error with essential fields"""
+    error_type: ErrorType
+    message: str
+
+
+PipelineResult = Union[GraphGenerationResult, GenerationError]
+
+
+@dataclass
+class GraphGenerationPipeline:
+    generation_model: BaseChatModel
+    validation_model: BaseChatModel
+    graph_generator: CycleGraphGenerator
+    generation_prompt: PromptTemplate
+    repair_prompt: PromptTemplate
+    min_cycles: int = 2
+    max_fix_attempts: int = 3
+
+    def __init__(
+        self,
+        generation_model: BaseChatModel,
+        validation_model: BaseChatModel,
+        generation_prompt: Optional[PromptTemplate] = None,
+        repair_prompt: Optional[PromptTemplate] = None,
+        min_cycles: int = 2,
+        max_fix_attempts: int = 3
+    ):
+        self.generation_model = generation_model
+        self.validation_model = validation_model
+        self.graph_generator = CycleGraphGenerator()
+        self.dialogue_sampler = RecursiveDialogueSampler()
+
+        self.generation_prompt = generation_prompt or cycle_graph_generation_prompt_enhanced
+        self.repair_prompt = repair_prompt or cycle_graph_repair_prompt
+
+        self.min_cycles = min_cycles
+        self.max_fix_attempts = max_fix_attempts
+
+    def validate_graph_cycle_requirement(
+        self,
+        graph: BaseGraph,
+        min_cycles: int = 2
+    ) -> Dict[str, Any]:
+        """
+        Checks whether the graph satisfies the minimum cycle-count requirement
+        """
+        print("\n🔍 Checking graph requirements...")
+
+        try:
+            cycles = list(nx.simple_cycles(graph.graph))
+            cycles_count = len(cycles)
+
+            print(f"🔄 Found {cycles_count} cycles in the graph:")
+            for i, cycle in enumerate(cycles, 1):
+                print(f"Cycle {i}: {' -> '.join(map(str, cycle + [cycle[0]]))}")
+
+            meets_requirements = cycles_count >= min_cycles
+
+            if not meets_requirements:
+                print(f"❌ Graph doesn't meet cycle requirements (minimum {min_cycles} cycles needed)")
+            else:
+                print("✅ Graph meets cycle requirements")
+
+            return {
+                "meets_requirements": meets_requirements,
+                "cycles": cycles,
+                "cycles_count": cycles_count
+            }
+
+        except Exception as e:
+            print(f"❌ Validation error: {str(e)}")
+            raise
+
+    def check_and_fix_transitions(
+        self,
+        graph: BaseGraph,
+        max_attempts: int = 3
+    ) -> Dict[str, Any]:
+        """
+        Validates the graph's transitions and asks the LLM to repair any invalid ones
+        """
+        print("Validating initial graph")
+
+        initial_validation = graph_validation(graph, self.validation_model)
+        if initial_validation["is_valid"]:
+            return {
+                "is_valid": True,
+                "graph": graph,
+                "validation_details": {
+                    "invalid_transitions": [],
+                    "attempts_made": 0,
+                    "fixed_count": 0
+                }
+            }
+
+        initial_invalid_count = len(initial_validation["invalid_transitions"])
+        validation = initial_validation
+        current_graph = graph
+        current_attempt = 0
+
+        while current_attempt < max_attempts:
+            print(f"\n🔄 Fix attempt {current_attempt + 1}/{max_attempts}")
+
+            try:
+                # Use generation_model to repair the currently invalid transitions
+                current_graph = self.graph_generator.invoke(
+                    model=self.generation_model,
+                    prompt=self.repair_prompt,
+                    invalid_transitions=validation["invalid_transitions"],
+                    graph_json=current_graph.graph_dict
+                )
+
+                # Re-validate the repaired graph with validation_model
+                validation = graph_validation(current_graph, self.validation_model)
+                if validation["is_valid"]:
+                    return {
+                        "is_valid": True,
+                        "graph": current_graph,
+                        "validation_details": {
+                            "invalid_transitions": [],
+                            "attempts_made": current_attempt + 1,
+                            "fixed_count": initial_invalid_count
+                        }
+                    }
+
+            except Exception as e:
+                print(f"⚠️ Error during fix attempt: {str(e)}")
+                break
+
+            current_attempt += 1
+
+        remaining_invalid = len(validation["invalid_transitions"])
+
+        return {
+            "is_valid": False,
+            "graph": current_graph,
+            "validation_details": {
+                "invalid_transitions": validation["invalid_transitions"],
+                "attempts_made": current_attempt,
+                "fixed_count": initial_invalid_count - remaining_invalid
+            }
+        }
+
+    def generate_and_validate(self, topic: str) -> PipelineResult:
+        """
+        Generates and validates a dialogue graph for a given topic
+        """
+        try:
+            # 1. Generate initial graph
+            print("Generating Graph ...")
+            graph = self.graph_generator.invoke(
+                model=self.generation_model,
+                prompt=self.generation_prompt,
+                topic=topic
+            )
+
+            # 2. Validate cycles
+            cycle_validation = self.validate_graph_cycle_requirement(graph, self.min_cycles)
+            if not cycle_validation["meets_requirements"]:
+                return GenerationError(
+                    error_type=ErrorType.TOO_FEW_CYCLES,
+                    message=f"Graph requires a minimum of {self.min_cycles} cycles, found {cycle_validation['cycles_count']}"
+                )
+
+            # 3. Generate and validate dialogues
+            print("Sampling dialogues...")
+            sampled_dialogues = self.dialogue_sampler.invoke(graph, 1, -1)
+            if not all_utterances_present(graph, sampled_dialogues):
+                return GenerationError(
+                    error_type=ErrorType.SAMPLING_FAILED,
+                    message="Failed to sample valid dialogues - not all utterances are present"
+                )
+
+            # 4. Validate theme
+            theme_validation = is_theme_valid(graph, self.validation_model, topic)
+            if not theme_validation["value"]:
+                return GenerationError(
+                    error_type=ErrorType.INVALID_THEME,
+                    message=f"Theme validation failed: {theme_validation['description']}"
+                )
+
+            # 5. Validate and fix transitions
+            transition_validation = self.check_and_fix_transitions(
+                graph=graph,
+                max_attempts=self.max_fix_attempts
+            )
+
+            if not transition_validation["is_valid"]:
+                invalid_transitions = transition_validation["validation_details"]["invalid_transitions"]
+                return GenerationError(
+                    error_type=ErrorType.INVALID_GRAPH_STRUCTURE,
+                    message=f"Found {len(invalid_transitions)} invalid transitions after {transition_validation['validation_details']['attempts_made']} fix attempts"
+                )
+
+            # All validations passed - return successful result
+            return GraphGenerationResult(
+                graph=transition_validation["graph"].graph_dict,
+                topic=topic,
+                dialogues=sampled_dialogues
+            )
+
+        except Exception as e:
+            return GenerationError(
+                error_type=ErrorType.GENERATION_FAILED,
+                message=f"Unexpected error during generation: {str(e)}"
+            )
+
+    def __call__(self, topic: str) -> PipelineResult:
+        """Shorthand for generate_and_validate"""
+        return self.generate_and_validate(topic)
diff --git a/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/algorithms/dialogue_augmentation.py b/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/algorithms/dialogue_augmentation.py
index 37261e7..76757ad 100644
--- a/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/algorithms/dialogue_augmentation.py
+++ b/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/algorithms/dialogue_augmentation.py
@@ -1,4 +1,4 @@
-from chatsky_llm_autoconfig.dialogue import Dialogue
+from chatsky_llm_autoconfig.schemas import Dialogue
 from chatsky_llm_autoconfig.schemas import DialogueMessage
 from chatsky_llm_autoconfig.autometrics.registry import AlgorithmRegistry
diff --git a/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/algorithms/dialogue_generation.py b/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/algorithms/dialogue_generation.py
index 43526b1..44a02ca 100644
--- a/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/algorithms/dialogue_generation.py
+++ b/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/algorithms/dialogue_generation.py
@@ -3,7 +3,7 @@
 import networkx as nx
 from chatsky_llm_autoconfig.graph import BaseGraph
 from chatsky_llm_autoconfig.algorithms.base import DialogueGenerator
-from chatsky_llm_autoconfig.dialogue import Dialogue
+from chatsky_llm_autoconfig.schemas import Dialogue
 from chatsky_llm_autoconfig.autometrics.registry import AlgorithmRegistry
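For reviewers: the cycle gate in `validate_graph_cycle_requirement` above is plain `networkx`. A minimal standalone sketch of the same check on a toy graph (invented edges, no `BaseGraph` wrapper, no model calls):

```python
import networkx as nx

# Toy digraph shaped like the coffee-shop example: only 5 -> 2 closes a loop
g = nx.DiGraph([(1, 2), (2, 3), (3, 4), (4, 5), (5, 2)])

cycles = list(nx.simple_cycles(g))
print(cycles)            # one cycle, e.g. [[2, 3, 4, 5]]
print(len(cycles) >= 2)  # False: the pipeline would return ErrorType.TOO_FEW_CYCLES
```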
diff --git a/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/algorithms/topic_graph_generation.py b/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/algorithms/topic_graph_generation.py
index 054fa93..a337d70 100644
--- a/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/algorithms/topic_graph_generation.py
+++ b/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/algorithms/topic_graph_generation.py
@@ -1,4 +1,3 @@
-from typing import Optional
 from chatsky_llm_autoconfig.algorithms.base import TopicGraphGenerator
 from chatsky_llm_autoconfig.autometrics.registry import AlgorithmRegistry
 from chatsky_llm_autoconfig.schemas import DialogueGraph
@@ -7,82 +6,32 @@
 from chatsky_llm_autoconfig.graph import BaseGraph, Graph
 from langchain_core.language_models.chat_models import BaseChatModel
 
-from pydantic import Field
-from typing import ClassVar
-
 
 @AlgorithmRegistry.register(input_type=str, output_type=BaseGraph)
 class CycleGraphGenerator(TopicGraphGenerator):
     """Generator specifically for topic-based cyclic graphs"""
 
-    DEFAULT_TEMPLATE: ClassVar[str] = """
-    Create a complex dialogue graph where the conversation MUST return to an existing node.
-
-    **CRITICAL: Response Specificity**
-    Responses must acknowledge and build upon what the user has already specified:
-
-    INCORRECT flow:
-    - User: "I'd like to order a coffee"
-    - Staff: "What would you like to order?" (TOO GENERAL - ignores that they specified coffee)
-
-    CORRECT flow:
-    - User: "I'd like to order a coffee"
-    - Staff: "What kind of coffee would you like?" (GOOD - acknowledges they want coffee)
-
-    Example of a CORRECT cyclic graph for a coffee shop:
-    "edges": [
-        {{ "source": 1, "target": 2, "utterances": ["Hi, I'd like to order a coffee"] }},
-        {{ "source": 2, "target": 3, "utterances": ["A large latte please"] }},
-        {{ "source": 3, "target": 4, "utterances": ["Yes, that's correct"] }},
-        {{ "source": 4, "target": 5, "utterances": ["Here's my payment"] }},
-        {{ "source": 5, "target": 2, "utterances": ["I'd like to order another coffee"] }}
-    ],
-    "nodes": [
-        {{ "id": 1, "label": "welcome", "is_start": true, "utterances": ["Welcome! How can I help you today?"] }},
-        {{ "id": 2, "label": "ask_coffee_type", "is_start": false, "utterances": ["What kind of coffee would you like?"] }},
-        {{ "id": 3, "label": "confirm", "is_start": false, "utterances": ["That's a large latte. Is this correct?"] }},
-        {{ "id": 4, "label": "payment", "is_start": false, "utterances": ["Great! That'll be $5. Please proceed with payment."] }},
-        {{ "id": 5, "label": "completed", "is_start": false, "utterances": ["Thank you! Would you like another coffee?"] }}
-    ]
-
-    **Rules:**
-    1) Responses must acknowledge what the user has already specified
-    2) The final node MUST connect back to an existing node
-    3) Each node must have clear purpose
-    4) Return ONLY the JSON without commentary
-    5) Graph must be cyclic - no dead ends
-    6) All edges must connect to existing nodes
-    7) The cycle point should make logical sense
-
-    **Your task is to create a cyclic dialogue graph about the following topic:** {topic}.
-    """
-
-    cycle_graph_generation_prompt: PromptTemplate = Field(
-        default_factory=lambda: PromptTemplate.from_template(CycleGraphGenerator.DEFAULT_TEMPLATE)
-    )
-
-    def __init__(self, prompt: Optional[PromptTemplate] = None):
+    def __init__(self):
         super().__init__()
-        if prompt is not None:
-            self.cycle_graph_generation_prompt = prompt
 
-    def invoke(self, topic: str, model: BaseChatModel) -> BaseGraph:
+    def invoke(self, model: BaseChatModel, prompt: PromptTemplate, **kwargs) -> BaseGraph:
         """
         Generate a cyclic dialogue graph based on the topic input.
 
         Args:
-            topic (str): The topic for the dialogue graph
-            model_name (str): The name of the model to use
+            model (BaseChatModel): The model to use for generation
+            prompt (PromptTemplate): Prepared prompt template
+            **kwargs: Additional arguments for formatting the prompt
 
         Returns:
             BaseGraph: Generated Graph object with cyclic structure
         """
+        # Build the chain: prompt -> model -> parser
         parser = JsonOutputParser(pydantic_object=DialogueGraph)
+        chain = prompt | model | parser
 
-        chain = self.cycle_graph_generation_prompt | model | parser
-
-        generated_graph = chain.invoke({"topic": topic})
-        return Graph(generated_graph)
+        # Pass kwargs through as the chain's input variables
+        return Graph(chain.invoke(kwargs))
 
     async def ainvoke(self, *args, **kwargs):
         """
diff --git a/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/autometrics/run_autometrics.py b/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/autometrics/run_autometrics.py
index 4440b65..53c5d1f 100644
--- a/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/autometrics/run_autometrics.py
+++ b/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/autometrics/run_autometrics.py
@@ -4,7 +4,7 @@
 from chatsky_llm_autoconfig.algorithms.dialogue_generation import DialogueSampler
 import json
 from chatsky_llm_autoconfig.graph import Graph, BaseGraph
-from chatsky_llm_autoconfig.dialogue import Dialogue
+from chatsky_llm_autoconfig.schemas import Dialogue
 from chatsky_llm_autoconfig.metrics.automatic_metrics import *
 from chatsky_llm_autoconfig.metrics.llm_metrics import are_triplets_valid, is_theme_valid
 import datetime
diff --git a/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/dialogue.py b/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/dialogue.py
deleted file mode 100644
index b193aef..0000000
--- a/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/dialogue.py
+++ /dev/null
@@ -1,127 +0,0 @@
-import networkx as nx
-from typing import List, Union, Dict
-from chatsky_llm_autoconfig.schemas import DialogueMessage
-from pydantic import BaseModel, Field, ConfigDict
-
-
-class Dialogue(BaseModel):
-    """Represents a complete dialogue consisting of multiple messages.
-
-    The class provides methods for creating dialogues from different formats
-    and converting dialogues to various representations.
-    """
-
-    messages: List[DialogueMessage] = Field(default_factory=list)
-    topic: str = ""
-    validate: bool = Field(default=True, description="Whether to validate messages upon initialization")
-
-    model_config = ConfigDict(
-        arbitrary_types_allowed=True,
-        frozen=False,  # Dialogue needs to be mutable to append messages
-    )
-
-    def __init__(self, **data):
-        super().__init__(**data)
-        if self.validate:
-            self.__validate(self.messages)
-
-    @classmethod
-    def from_string(cls, string: str) -> "Dialogue":
-        """Creates a Dialogue from a tab-separated string format.
-
-        Args:
-            string: Tab-separated string with format: "participant\ttext\n"
-
-        Returns:
-            Dialogue object with parsed messages
-        """
-        messages: List[DialogueMessage] = [
-            DialogueMessage(participant=line.split("\t")[0], text=line.split("\t")[1]) for line in string.strip().split("\n")
-        ]
-        return cls(messages=messages)
-
-    @classmethod
-    def from_list(cls, messages: List[Dict[str, str]], validate: bool = True) -> "Dialogue":
-        """Create a Dialogue from a list of dictionaries."""
-        dialogue_messages = [DialogueMessage(**m) for m in messages]
-        return cls(messages=dialogue_messages, validate=validate)
-
-    @classmethod
-    def from_nodes_ids(cls, graph, node_list, validate: bool = True) -> "Dialogue":
-        utts = []
-        nodes_attributes = nx.get_node_attributes(graph.graph, "utterances")
-        edges_attributes = nx.get_edge_attributes(graph.graph, "utterances")
-        for node in range(len(node_list)):
-            utts.append({"participant": "assistant", "text": nodes_attributes[node_list[node]][0]})
-            if node == len(node_list) - 1:
-                if graph.graph.has_edge(node_list[node], node_list[0]):
-                    utts.append({"participant": "user", "text": edges_attributes[(node_list[node], node_list[0])][0]})
-            else:
-                if graph.graph.has_edge(node_list[node], node_list[node + 1]):
-                    utts.append({"participant": "user", "text": edges_attributes[(node_list[node], node_list[node + 1])][0]})
-
-        return cls(messages=utts, validate=validate)
-
-    def to_list(self) -> List[Dict[str, str]]:
-        """Converts Dialogue to a list of message dictionaries."""
-        return [msg.model_dump() for msg in self.messages]
-
-    def __str__(self) -> str:
-        """Returns a readable string representation of the dialogue."""
-        return "\n".join(f"{msg.participant}: {msg.text}" for msg in self.messages).strip()
-
-    def append(self, text: str, participant: str) -> None:
-        """Adds a new message to the dialogue.
-
-        Args:
-            text: Content of the message
-            participant: Sender of the message
-        """
-        self.messages.append(DialogueMessage(text=text, participant=participant))
-
-    def extend(self, messages: List[Union[DialogueMessage, Dict[str, str]]]) -> None:
-        """Adds multiple messages to the dialogue.
-
-        Args:
-            messages: List of DialogueMessage objects or dicts to add
-        """
-        new_messages = [msg if isinstance(msg, DialogueMessage) else DialogueMessage(**msg) for msg in messages]
-        self.__validate(new_messages)
-        self.messages.extend(new_messages)
-
-    def __validate(self, messages):
-        """Ensure that messages meets expectations."""
-        if not messages:
-            return
-
-        # Check if first message is from assistant
-        if messages[0].participant != "assistant":
-            raise ValueError(f"First message must be from assistant, got: {messages[0]}")
-
-        # Check for consecutive messages from same participant
-        for i in range(len(messages) - 1):
-            if messages[i].participant == messages[i + 1].participant:
-                raise ValueError(f"Cannot have consecutive messages from the same participant. Messages: {messages[i]}, {messages[i + 1]}")
-
-
-# Type-safe usage examples
-if __name__ == "__main__":
-    # Create from list of dicts
-    dialogue1 = Dialogue(
-        messages=[DialogueMessage(text="How can I help?", participant="assistant"), DialogueMessage(text="I need coffee", participant="user")]
-    )
-
-    # Create using from_list
-    dialogue2 = Dialogue.from_list([{"text": "How can I help?", "participant": "assistant"}, {"text": "I need coffee", "participant": "user"}])
-
-    # Create from string
-    dialogue3 = Dialogue.from_string(
-        """
-assistant\tHow can I help?
-user\tI need coffee
-""".strip()
-    )
-
-    # Append and extend
-    dialogue1.append("What kind of coffee?", "assistant")
-    dialogue1.extend([{"text": "Espresso please", "participant": "user"}, DialogueMessage(text="Coming right up!", participant="assistant")])
diff --git a/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/metrics/automatic_metrics.py b/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/metrics/automatic_metrics.py
index 4e9ef05..dcd0398 100644
--- a/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/metrics/automatic_metrics.py
+++ b/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/metrics/automatic_metrics.py
@@ -9,7 +9,7 @@
 import networkx as nx
 from chatsky_llm_autoconfig.metrics.jaccard import jaccard_edges, jaccard_nodes, collapse_multiedges
 from chatsky_llm_autoconfig.graph import BaseGraph
-from chatsky_llm_autoconfig.dialogue import Dialogue
+from chatsky_llm_autoconfig.schemas import Dialogue
 
 
 def edge_match_for_multigraph(x, y):
diff --git a/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/metrics/llm_metrics.py b/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/metrics/llm_metrics.py
index d5b0f4f..3455530 100644
--- a/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/metrics/llm_metrics.py
+++ b/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/metrics/llm_metrics.py
@@ -7,7 +7,8 @@
 import logging
 import json
+from typing import List, TypedDict
 from chatsky_llm_autoconfig.graph import BaseGraph, Graph
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain.prompts import PromptTemplate
 from pydantic import BaseModel, Field
@@ -172,3 +173,112 @@ class ThemeValidationResult(BaseModel):
 
     response = theme_check_chain.invoke(input_data)
     return {"value": response.isValid, "description": response.description}
+
+
+# "from" is a reserved word, so declare the TypedDict via the functional syntax
+InvalidTransition = TypedDict(
+    "InvalidTransition",
+    {"from": List[str], "user": List[str], "to": List[str], "reason": str},
+)
+
+
+class GraphValidationResult(TypedDict):
+    is_valid: bool
+    invalid_transitions: List[InvalidTransition]
+
+
+def graph_validation(G: BaseGraph, model: BaseChatModel) -> GraphValidationResult:
+    """
+    Validates the dialogue graph.
+    Returns:
+    {
+        "is_valid": bool,  # whether the graph as a whole is valid
+        "invalid_transitions": [  # list of invalid transitions
+            {
+                "from": ["source utterance"],
+                "user": ["user utterance"],
+                "to": ["target utterance"],
+                "reason": "why the transition is invalid"
+            },
+            ...
+        ]
+    }
+    """
+    # Define validation result model
+    class TransitionValidationResult(BaseModel):
+        isValid: bool = Field(description="Whether the transition is valid or not")
+        description: str = Field(description="Explanation of why it's valid or invalid")
+
+    # Create prompt template
+    triplet_validate_prompt = PromptTemplate(
+        input_variables=["json_graph", "source_utterances", "edge_utterances", "target_utterances"],
+        template="""
+    You are evaluating if dialog transitions make logical sense.
+
+    Given this conversation graph in JSON:
+    {json_graph}
+
+    For the current transition:
+    Source (Assistant): {source_utterances}
+    User Response: {edge_utterances}
+    Target (Assistant): {target_utterances}
+
+    EVALUATE: Do these three messages form a logical sequence in the conversation?
+    Consider:
+    1. Does the assistant's first response naturally lead to the user's response?
+    2. Does the user's response logically connect to the assistant's next message?
+    3. Is the overall flow natural and coherent?
+
+    Reply in JSON format:
+    {{"isValid": true/false, "description": "Brief explanation of why it's valid or invalid"}}
+    """
+    )
+
+    parser = PydanticOutputParser(pydantic_object=TransitionValidationResult)
+
+    # Convert graph to JSON string
+    graph_json = json.dumps(G.graph_dict)
+
+    # Create node mapping
+    node_map = {node["id"]: node for node in G.graph_dict["nodes"]}
+    invalid_transitions = []
+    is_valid = True
+
+    for edge in G.graph_dict["edges"]:
+        source_id = edge["source"]
+        target_id = edge["target"]
+
+        # Ensure both endpoint nodes exist
+        if source_id not in node_map or target_id not in node_map:
+            is_valid = False
+            continue
+
+        # Get utterances
+        source_node = node_map[source_id]
+        target_node = node_map[target_id]
+
+        # Prepare input for validation
+        input_data = {
+            "json_graph": graph_json,
+            "source_utterances": source_node["utterances"],
+            "edge_utterances": edge["utterances"],
+            "target_utterances": target_node["utterances"]
+        }
+
+        # Run validation
+        triplet_check_chain = triplet_validate_prompt | model | parser
+        result = triplet_check_chain.invoke(input_data)
+
+        if not result.isValid:
+            is_valid = False
+            invalid_transitions.append({
+                "from": source_node["utterances"],
+                "user": edge["utterances"],
+                "to": target_node["utterances"],
+                "reason": result.description
+            })
+
+    return {
+        "is_valid": is_valid,
+        "invalid_transitions": invalid_transitions
+    }
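Note on the contract between `graph_validation` and the repair loop: `check_and_fix_transitions` passes `invalid_transitions` verbatim into `cycle_graph_repair_prompt`. A sketch with invented utterances (no model call is made here):

```python
from chatsky_llm_autoconfig.metrics.llm_metrics import GraphValidationResult

bad: GraphValidationResult = {
    "is_valid": False,
    "invalid_transitions": [
        {
            "from": ["Great! That'll be $5. Please proceed with payment."],
            "user": ["A large latte please"],
            "to": ["Welcome! How can I help you today?"],
            "reason": "The user places a new order while the assistant is mid-payment",
        }
    ],
}

# The repair attempt then formats these into the prompt along with the raw graph dict:
# cycle_graph_repair_prompt.format(
#     invalid_transitions=bad["invalid_transitions"], graph_json=graph.graph_dict
# )
```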
diff --git a/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/prompts.py b/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/prompts.py
index d45e571..e5f41fb 100644
--- a/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/prompts.py
+++ b/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/prompts.py
@@ -95,191 +95,122 @@
     "could'nt happen because it contradicts the rules print NO.\nDialogue: {dialog}.\nSet of rules: {rules}"
 )
 
-cycle_graph_generation_prompt = PromptTemplate.from_template(
-    "You have an example of dialogue from customer chatbot system. You also have an "
-    "example of set of rules how chatbot system works should be looking - it is "
-    "a set of nodes when chatbot system respons and a set of transitions that are "
-    "triggered by user requests. "
-    "Here is the example of set of rules: "
-    "'edges': [ [ 'source': 1, 'target': 2, 'utterances': 'I need to make an order' ], "
-    "[ 'source': 1, 'target': 2, 'utterances': 'I want to order from you' ], "
-    "[ 'source': 2, 'target': 3, 'utterances': 'I would like to purchase 'Pale Fire' and 'Anna Karenina', please' ], "
-    "'nodes': [ [ 'id': 1, 'label': 'start', 'is_start': true, 'utterances': [ 'How can I help?', 'Hello' ], "
-    "[ 'id': 2, 'label': 'ask_books', 'is_start': false, 'utterances': [ 'What books do you like?'] ] "
-    "I will give a dialogue, your task is to build a graph for this dialogue in the format above. We allow several edges with equal "
-    "source and target and also multiple responses on one node so try not to add new nodes if it is logical just to extend an "
-    "exsiting one. utterances in one node or on multiedge should close between each other and correspond to different answers "
-    "to one question or different ways to say something. "
-    "If two nodes has the same responses they "
-    "should be united in one node. Do not make up utterances that aren’t present in the dialogue. Please do not combine "
-    "utterances for multiedges in one list, write them separately like in example above. Every utterance from the dialogue, "
-    "whether it is from user or assistanst, should contain in one of the nodes. Edges must be utterances from the user. Do not forget ending nodes with goodbyes. "
-    "Sometimes dialogue can correspond to several iterations of loop, for example: "
-    "['text': 'Do you have apples?', 'participant': 'user'], "
-    "['text': 'Yes, add it to your cart?', 'participant': 'assistant'], "
-    "['text': 'No', 'participant': 'user'], "
-    "['text': 'Okay. Anything else?', 'participant': 'assistant'], "
-    "['text': 'I need a pack of chips', 'participant': 'user'], "
-    "['text': 'Yes, add it to your cart?', 'participant': 'assistant'], "
-    "['text': 'Yes', 'participant': 'user'], "
-    "['text': 'Done. Anything else?', 'participant': 'assistant'], "
-    "['text': 'No, that’s all', 'participant': 'user'], "
-    "it corresponds to following graph: "
-    "[ nodes: "
-    "'id': 1, "
-    "'label': 'confirm_availability_and_ask_to_add', "
-    "'is_start': false, "
-    "'utterances': 'Yes, add it to your cart?' "
-    "], "
-    "[ "
-    "'id': 2, "
-    "'label': 'reply_to_yes', "
-    "'is_start': false, "
-    "'utterances': ['Done. Anything else?', 'Okay. Anything else?'] "
-    "], "
-    "[ "
-    "'id': 3, "
-    "'label': 'finish_filling_cart', "
-    "'is_start': false, "
-    "'utterances': 'Okay, everything is done, you can go to cart and finish the order.' "
-    "], "
-    "edges: "
-    "[ "
-    "'source': 1, "
-    "'target': 2, "
-    "'utterances': 'Yes' "
-    "], "
-    "[ "
-    "'source': 1, "
-    "'target': 2, "
-    "'utterances': 'No' "
-    "], "
-    "[ "
-    "'source': 2, "
-    "'target': 1, "
-    "'utterances': 'I need a pack of chips' "
-    "], "
-    "[ "
-    "'source': 2, "
-    "'target': 2, "
-    "'utterances': 'No, that’s all' "
-    "]. "
-    "Another example:"
-    """
-    [
-        [
-            "text": "How can I help?",
-            "participant": "assistant"
-        ],
-        [
-            "text": "I need to make an order",
-            "participant": "user"
-        ],
-        [
-            "text": "Which books would you like to order?",
-            "participant": "assistant"
-        ],
-        [
-            "text": "One War and Piece in hard cover and one Pride and Prejudice",
-            "participant": "user"
-        ],
-        [
-            "text": "Please, enter the payment method you would like to use: cash or credit card.",
-            "participant": "assistant"
-        ],
-        [
-            "text": "With credit card, please",
-            "participant": "user"
-        ],
-        [
-            "text": "Something is wrong, can you please use other payment method or start order again",
-            "participant": "assistant"
-        ],
-        [
-            "text": "I will enter new payment method",
-            "participant": "user"
-        ]
-    ]
-    """
-    "Should result in graph like this (note that even in the case of negative result 'something is wrong' it must be cycled):"
-    """
-    "edges": [
-        [
-            "utterances": [
-                "I need to make an order",
-                "I want to order from you"
-            ],
-            "source": 1,
-            "target": 2
-        ],
-        [
-            "utterances": [
-                "I would like to purchase 'Pale Fire' and 'Anna Karenina', please",
-                "One War and Piece in hard cover and one Pride and Prejudice"
-            ],
-            "source": 2,
-            "target": 3
-        ],
-        [
-            "utterances": [
-                "Cash",
-                "With credit card, please"
-            ],
-            "source": 3,
-            "target": 4
-        ],
-        [
-            "utterances": [
-                "I will enter new payment method"
-            ],
-            "source": 4,
-            "target": 3
-        ],
-        [
-            "utterances": [
-                "Start new order"
-            ],
-            "source": 4,
-            "target": 1
-        ]
-    ],
-    "nodes": [
-        [
-            "id": 1,
-            "label": "start",
-            "is_start": true,
-            "utterances": [
-                "How can I help?",
-                "Hello"
-            ]
-        ],
-        [
-            "id": 2,
-            "label": "ask_item",
-            "is_start": false,
-            "utterances": [
-                "Which books would you like to order?"
-            ]
-        ],
-        [
-            "id": 3,
-            "label": "ask_payment_method",
-            "is_start": false,
-            "utterances": [
-                "Please, enter the payment method you would like to use: cash or credit card."
-            ]
-        ],
-        [
-            "id": 4,
-            "label": "ask_to_redo",
-            "is_start": false,
-            "utterances": [
-                "Something is wrong, can you please use other payment method or start order again"
-            ]
-        ]
-    ]
-    """
-    "This is the end of the example."
-    "IMPORTANT: all the dialogues you've prompted are cyclic. Before answering you must check where the dialog can loop or cycle and make the first node of a cycle a target node for the last node of the cycle. Brackets must be changed back into curly braces to create a valid JSON string. Return ONLY JSON string in plain text (no code blocks) without any additional commentaries."
-    "Dialogue: {dialog}"
-)
+cycle_graph_generation_prompt_basic = PromptTemplate.from_template("""
+    Create a complex dialogue graph where the conversation MUST return to an existing node.
+
+    **CRITICAL: Response Specificity**
+    Responses must acknowledge and build upon what the user has already specified:
+
+    INCORRECT flow:
+    - User: "I'd like to order a coffee"
+    - Staff: "What would you like to order?" (TOO GENERAL - ignores that they specified coffee)
+
+    CORRECT flow:
+    - User: "I'd like to order a coffee"
+    - Staff: "What kind of coffee would you like?" (GOOD - acknowledges they want coffee)
+
+    Example of a CORRECT cyclic graph for a coffee shop:
+    "edges": [
+        {{ "source": 1, "target": 2, "utterances": ["Hi, I'd like to order a coffee"] }},
+        {{ "source": 2, "target": 3, "utterances": ["A large latte please"] }},
+        {{ "source": 3, "target": 4, "utterances": ["Yes, that's correct"] }},
+        {{ "source": 4, "target": 5, "utterances": ["Here's my payment"] }},
+        {{ "source": 5, "target": 2, "utterances": ["I'd like to order another coffee"] }}
+    ],
+    "nodes": [
+        {{ "id": 1, "label": "welcome", "is_start": true, "utterances": ["Welcome! How can I help you today?"] }},
+        {{ "id": 2, "label": "ask_coffee_type", "is_start": false, "utterances": ["What kind of coffee would you like?"] }},
+        {{ "id": 3, "label": "confirm", "is_start": false, "utterances": ["That's a large latte. Is this correct?"] }},
+        {{ "id": 4, "label": "payment", "is_start": false, "utterances": ["Great! That'll be $5. Please proceed with payment."] }},
+        {{ "id": 5, "label": "completed", "is_start": false, "utterances": ["Thank you! Would you like another coffee?"] }}
+    ]
+
+    **Rules:**
+    1) Responses must acknowledge what the user has already specified
+    2) The final node MUST connect back to an existing node
+    3) Each node must have clear purpose
+    4) Return ONLY the JSON without commentary
+    5) Graph must be cyclic - no dead ends
+    6) All edges must connect to existing nodes
+    7) The cycle point should make logical sense
+
+    **Your task is to create a cyclic dialogue graph about the following topic:** {topic}.
+    """)
+
+cycle_graph_generation_prompt_enhanced = PromptTemplate.from_template(
+    """
+Create a dialogue graph for a {topic} conversation that will be used for training data generation. The graph must follow these requirements:
+
+1. Dialogue Flow Requirements:
+   - Each assistant message (node) must be a precise question or statement that expects a specific type of response
+   - Each user message (edge) must logically and directly respond to the previous assistant message
+   - All paths must maintain clear context and natural conversation flow
+   - Avoid any ambiguous or overly generic responses
+
+2. Graph Structure Requirements:
+   - Must contain at least 2 distinct cycles (return paths)
+   - Each cycle should allow users to:
+     * Return to previous choices for modification
+     * Restart specific parts of the conversation
+     * Change their mind about earlier decisions
+   - Include clear exit points from each major decision path
+
+3. Core Path Types:
+   - Main success path (completing the intended task)
+   - Multiple modification paths (returning to change choices)
+   - Early exit paths (user decides to stop)
+   - Alternative success paths (achieving goal differently)
+
+Example of a good cycle structure:
+Assistant: "What size coffee would you like?"
+User: "Medium please"
+Assistant: "Would you like that hot or iced?"
+User: "Actually, can I change my size?"
+Assistant: "Of course! What size would you like instead?"
+
+Format:
+{{
+    "edges": [
+        {{
+            "source": "node_id",
+            "target": "node_id",
+            "utterances": ["User response text"]
+        }}
+    ],
+    "nodes": [
+        {{
+            "id": "node_id",
+            "label": "semantic_label",
+            "is_start": boolean,
+            "utterances": ["Assistant message text"]
+        }}
+    ]
+}}
+
+Requirements for node IDs:
+- Must be unique integers
+- Start node should have ID 1
+- IDs should increment sequentially
+
+Return ONLY the valid JSON without any additional text or explanations.
+"""
+)
+
+cycle_graph_repair_prompt = PromptTemplate.from_template("""
+Fix the invalid transitions in this dialogue graph while keeping its structure.
+
+Current invalid transitions that need to be fixed:
+{invalid_transitions}
+
+Original graph structure:
+{graph_json}
+
+Requirements for the fix:
+1. Keep all node IDs and structure the same
+2. Fix ONLY the invalid transitions
+3. Make sure the fixed transitions are logical and natural
+4. Each user response must logically follow from the assistant's previous message
+5. Each assistant response must properly address the user's input
+
+Return ONLY the complete fixed graph JSON with the same structure.
+""")
diff --git a/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/schemas.py b/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/schemas.py
index 0fe1437..c12e4c5 100644
--- a/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/schemas.py
+++ b/dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/schemas.py
@@ -1,5 +1,118 @@
-from typing import List
-from pydantic import BaseModel, Field
+import networkx as nx
+from typing import List, Union, Dict
+from pydantic import BaseModel, Field, ConfigDict
+
+
+class DialogueMessage(BaseModel):
+    """Represents a single message in a dialogue.
+
+    Attributes:
+        text: The content of the message
+        participant: The sender of the message (e.g. "user" or "assistant")
+    """
+
+    text: str
+    participant: str
+
+
+class Dialogue(BaseModel):
+    """Represents a complete dialogue consisting of multiple messages.
+
+    The class provides methods for creating dialogues from different formats
+    and converting dialogues to various representations.
+    """
+
+    messages: List[DialogueMessage] = Field(default_factory=list)
+    topic: str = ""
+    validate: bool = Field(default=True, description="Whether to validate messages upon initialization")
+
+    model_config = ConfigDict(
+        arbitrary_types_allowed=True,
+        frozen=False,  # Dialogue needs to be mutable to append messages
+    )
+
+    def __init__(self, **data):
+        super().__init__(**data)
+        if self.validate:
+            self.__validate(self.messages)
+
+    @classmethod
+    def from_string(cls, string: str) -> "Dialogue":
+        """Creates a Dialogue from a tab-separated string format.
+
+        Args:
+            string: Tab-separated string with format: "participant\ttext\n"
+
+        Returns:
+            Dialogue object with parsed messages
+        """
+        messages: List[DialogueMessage] = [
+            DialogueMessage(participant=line.split("\t")[0], text=line.split("\t")[1]) for line in string.strip().split("\n")
+        ]
+        return cls(messages=messages)
+
+    @classmethod
+    def from_list(cls, messages: List[Dict[str, str]], validate: bool = True) -> "Dialogue":
+        """Create a Dialogue from a list of dictionaries."""
+        dialogue_messages = [DialogueMessage(**m) for m in messages]
+        return cls(messages=dialogue_messages, validate=validate)
+
+    @classmethod
+    def from_nodes_ids(cls, graph, node_list, validate: bool = True) -> "Dialogue":
+        utts = []
+        nodes_attributes = nx.get_node_attributes(graph.graph, "utterances")
+        edges_attributes = nx.get_edge_attributes(graph.graph, "utterances")
+        for node in range(len(node_list)):
+            utts.append({"participant": "assistant", "text": nodes_attributes[node_list[node]][0]})
+            if node == len(node_list) - 1:
+                if graph.graph.has_edge(node_list[node], node_list[0]):
+                    utts.append({"participant": "user", "text": edges_attributes[(node_list[node], node_list[0])][0]})
+            else:
+                if graph.graph.has_edge(node_list[node], node_list[node + 1]):
+                    utts.append({"participant": "user", "text": edges_attributes[(node_list[node], node_list[node + 1])][0]})
+
+        return cls(messages=utts, validate=validate)
+
+    def to_list(self) -> List[Dict[str, str]]:
+        """Converts Dialogue to a list of message dictionaries."""
+        return [msg.model_dump() for msg in self.messages]
+
+    def __str__(self) -> str:
+        """Returns a readable string representation of the dialogue."""
+        return "\n".join(f"{msg.participant}: {msg.text}" for msg in self.messages).strip()
+
+    def append(self, text: str, participant: str) -> None:
+        """Adds a new message to the dialogue.
+
+        Args:
+            text: Content of the message
+            participant: Sender of the message
+        """
+        self.messages.append(DialogueMessage(text=text, participant=participant))
+
+    def extend(self, messages: List[Union[DialogueMessage, Dict[str, str]]]) -> None:
+        """Adds multiple messages to the dialogue.
+
+        Args:
+            messages: List of DialogueMessage objects or dicts to add
+        """
+        new_messages = [msg if isinstance(msg, DialogueMessage) else DialogueMessage(**msg) for msg in messages]
+        self.__validate(new_messages)
+        self.messages.extend(new_messages)
+
+    def __validate(self, messages):
+        """Ensure that messages meet expectations."""
+        if not messages:
+            return
+
+        # Check if first message is from assistant
+        if messages[0].participant != "assistant":
+            raise ValueError(f"First message must be from assistant, got: {messages[0]}")
+
+        # Check for consecutive messages from same participant
+        for i in range(len(messages) - 1):
+            if messages[i].participant == messages[i + 1].participant:
+                raise ValueError(f"Cannot have consecutive messages from the same participant. Messages: {messages[i]}, {messages[i + 1]}")
 
 
 class Edge(BaseModel):
@@ -20,13 +133,8 @@ class DialogueGraph(BaseModel):
     nodes: List[Node] = Field(description="List of nodes representing assistant states")
 
 
-class DialogueMessage(BaseModel):
-    """Represents a single message in a dialogue.
-
-    Attributes:
-        text: The content of the message
-        participant: The sender of the message (e.g. "user" or "assistant")
-    """
-
-    text: str
-    participant: str
+class GraphGenerationResult(BaseModel):
+    """Complete result with graph and dialogues"""
+    graph: DialogueGraph
+    topic: str
+    dialogues: List[Dialogue]
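Since `Dialogue` now lives in `chatsky_llm_autoconfig.schemas`, call sites only change their import; behaviour is unchanged. A quick sanity sketch of the moved class:

```python
from chatsky_llm_autoconfig.schemas import Dialogue

d = Dialogue.from_string("assistant\tHow can I help?\nuser\tI need coffee")
d.append("What kind of coffee?", "assistant")
print(d.to_list())
# [{'text': 'How can I help?', 'participant': 'assistant'}, ...]
```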
diff --git a/experiments/2025.01.13_graph_generation_autofix/grpah_gen_autofix.ipynb b/experiments/2025.01.13_graph_generation_autofix/grpah_gen_autofix.ipynb
new file mode 100644
index 0000000..ca939be
--- /dev/null
+++ b/experiments/2025.01.13_graph_generation_autofix/grpah_gen_autofix.ipynb
@@ -0,0 +1,496 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "True"
+      ]
+     },
+     "execution_count": 1,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from langchain_openai import ChatOpenAI\n",
+    "from dotenv import load_dotenv\n",
+    "\n",
+    "\n",
+    "load_dotenv() "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# enhanced_graph_prompt = PromptTemplate.from_template(\n",
+    "# \"\"\"\n",
+    "# Create a dialogue graph for a {topic} conversation that will be used for training data generation. The graph must follow these requirements:\n",
+    "\n",
+    "# 1. Dialogue Flow Requirements:\n",
+    "#    - Each assistant message (node) must be a precise question or statement that expects a specific type of response\n",
+    "#    - Each user message (edge) must logically and directly respond to the previous assistant message\n",
+    "#    - All paths must maintain clear context and natural conversation flow\n",
+    "#    - Avoid any ambiguous or overly generic responses\n",
+    "\n",
+    "# 2. Graph Structure Requirements:\n",
+    "#    - Must contain at least 2 distinct cycles (return paths)\n",
+    "#    - Each cycle should allow users to:\n",
+    "#      * Return to previous choices for modification\n",
+    "#      * Restart specific parts of the conversation\n",
+    "#      * Change their mind about earlier decisions\n",
+    "#    - Include clear exit points from each major decision path\n",
+    "    \n",
+    "# 3. Core Path Types:\n",
+    "#    - Main success path (completing the intended task)\n",
+    "#    - Multiple modification paths (returning to change choices)\n",
+    "#    - Early exit paths (user decides to stop)\n",
+    "#    - Alternative success paths (achieving goal differently)\n",
+    "\n",
+    "# Example of a good cycle structure:\n",
+    "# Assistant: \"What size coffee would you like?\"\n",
+    "# User: \"Medium please\"\n",
+    "# Assistant: \"Would you like that hot or iced?\"\n",
+    "# User: \"Actually, can I change my size?\"\n",
+    "# Assistant: \"Of course! What size would you like instead?\"\n",
+    "\n",
+    "# Format:\n",
+    "# {{\n",
+    "#     \"edges\": [\n",
+    "#         {{\n",
+    "#             \"source\": \"node_id\",\n",
+    "#             \"target\": \"node_id\",\n",
+    "#             \"utterances\": [\"User response text\"]\n",
+    "#         }}\n",
+    "#     ],\n",
+    "#     \"nodes\": [\n",
+    "#         {{\n",
+    "#             \"id\": \"node_id\",\n",
+    "#             \"label\": \"semantic_label\",\n",
+    "#             \"is_start\": boolean,\n",
+    "#             \"utterances\": [\"Assistant message text\"]\n",
+    "#         }}\n",
+    "#     ]\n",
+    "# }}\n",
+    "\n",
+    "# Requirements for node IDs:\n",
+    "# - Must be unique integers\n",
+    "# - Start node should have ID 1\n",
+    "# - IDs should increment sequentially\n",
+    "\n",
+    "# Return ONLY the valid JSON without any additional text or explanations.\n",
+    "# \"\"\"\n",
+    "# )\n",
+    "\n",
+    "# graph_generator = CycleGraphGenerator()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# def validate_graph_cycle_requirement(\n",
+    "#     graph: BaseGraph,\n",
+    "#     min_cycles: int = 2\n",
+    "# ) -> Dict[str, Any]:\n",
+    "#     \"\"\"\n",
+    "#     Checks the graph against the structural requirements\n",
+    "    \n",
+    "#     Args:\n",
+    "#         graph: BaseGraph to check\n",
+    "#         min_cycles: minimum required number of cycles\n",
+    "    \n",
+    "#     Returns:\n",
+    "#         Dict with the check results:\n",
+    "#         {\n",
+    "#             \"meets_requirements\": bool,\n",
+    "#             \"cycles\": List[List[int]],\n",
+    "#             \"cycles_count\": int\n",
+    "#         }\n",
+    "#     \"\"\"\n",
+    "#     print(\"\\n🔍 Checking graph requirements...\")\n",
+    "    \n",
+    "#     try:\n",
+    "#         cycles = list(nx.simple_cycles(graph.graph))\n",
+    "#         cycles_count = len(cycles)\n",
+    "    \n",
+    "#         print(f\"🔄 Found {cycles_count} cycles in the graph:\")\n",
+    "#         for i, cycle in enumerate(cycles, 1):\n",
+    "#             print(f\"Cycle {i}: {' -> '.join(map(str, cycle + [cycle[0]]))}\")\n",
+    "    \n",
+    "#         meets_requirements = cycles_count >= min_cycles\n",
+    "    \n",
+    "#         if not meets_requirements:\n",
+    "#             print(f\"❌ Graph doesn't meet cycle requirements (minimum {min_cycles} cycles needed)\")\n",
+    "#         else:\n",
+    "#             print(\"✅ Graph meets cycle requirements\")\n",
+    "    \n",
+    "#         return {\n",
+    "#             \"meets_requirements\": meets_requirements,\n",
+    "#             \"cycles\": cycles,\n",
+    "#             \"cycles_count\": cycles_count\n",
+    "#         }\n",
+    "    \n",
+    "#     except Exception as e:\n",
+    "#         print(f\"❌ Validation error: {str(e)}\")\n",
+    "#         raise"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# from chatsky_llm_autoconfig.metrics.llm_metrics import graph_validation\n",
+    "# import json\n",
+    "# repair_template = PromptTemplate.from_template(\"\"\"\n",
+    "# Fix the invalid transitions in this dialogue graph while keeping its structure.\n",
+    "\n",
+    "# Current invalid transitions that need to be fixed:\n",
+    "# {invalid_transitions}\n",
+    "\n",
+    "# Original graph structure:\n",
+    "# {graph_json}\n",
+    "\n",
+    "# Requirements for the fix:\n",
+    "# 1. Keep all node IDs and structure the same\n",
+    "# 2. Fix ONLY the invalid transitions\n",
+    "# 3. Make sure the fixed transitions are logical and natural\n",
+    "# 4. Each user response must logically follow from the assistant's previous message\n",
+    "# 5. Each assistant response must properly address the user's input\n",
+    "\n",
+    "# Return ONLY the complete fixed graph JSON with the same structure.\n",
+    "# \"\"\")\n",
+    "\n",
+    "# def check_and_fix_transitions(graph: BaseGraph, graph_generator: CycleGraphGenerator, model: BaseChatModel, max_attempts: int = 3) -> Dict[str, Any]:\n",
+    "#     \"\"\"\n",
+    "#     Validates the graph's transitions and asks the LLM to repair any invalid ones\n",
+    "    \n",
+    "#     Args:\n",
+    "#         graph: The original graph to check and repair\n",
+    "#         graph_generator: Graph generator used for the repair\n",
+    "#         model: Model used for validation\n",
+    "#         max_attempts: Maximum number of repair attempts\n",
+    "    \n",
+    "#     Returns:\n",
+    "#         Dict: {\n",
+    "#             \"is_valid\": bool,  # Whether a valid graph was obtained\n",
+    "#             \"graph\": BaseGraph,  # Latest version of the graph (repaired or not)\n",
+    "#             \"validation_details\": {  # Details of the last validation\n",
+    "#                 \"invalid_transitions\": [...],  # List of remaining invalid transitions\n",
+    "#                 \"attempts_made\": int,  # How many repair attempts were made\n",
+    "#                 \"fixed_count\": int,  # How many transitions were fixed\n",
+    "#             }\n",
+    "#         }\n",
+    "#     \"\"\"\n",
+    "#     # Validate the original graph first\n",
+    "#     initial_validation = graph_validation(graph, model)\n",
+    "#     if initial_validation[\"is_valid\"]:\n",
+    "#         return {\n",
+    "#             \"is_valid\": True,\n",
+    "#             \"graph\": graph,\n",
+    "#             \"validation_details\": {\n",
+    "#                 \"invalid_transitions\": [],\n",
+    "#                 \"attempts_made\": 0,\n",
+    "#                 \"fixed_count\": 0\n",
+    "#             }\n",
+    "#         }\n",
+    "    \n",
+    "#     initial_invalid_count = len(initial_validation[\"invalid_transitions\"])\n",
+    "#     current_graph = graph\n",
+    "#     current_attempt = 0\n",
+    "    \n",
+    "#     while current_attempt < max_attempts:\n",
+    "#         print(f\"\\n🔄 Fix attempt {current_attempt + 1}/{max_attempts}\")\n",
+    "    \n",
+    "    \n",
+    "#         try:\n",
+    "#             # Use graph_generator to generate a repaired graph\n",
+    "#             current_graph = graph_generator.invoke(model=model, prompt=repair_template, invalid_transitions=initial_validation[\"invalid_transitions\"], graph_json=current_graph.graph_dict)\n",
+    "    \n",
+    "#             # Validate the repaired graph\n",
+    "#             validation = graph_validation(current_graph, model)\n",
+    "#             if validation[\"is_valid\"]:\n",
+    "#                 return {\n",
+    "#                     \"is_valid\": True,\n",
+    "#                     \"graph\": current_graph,\n",
+    "#                     \"validation_details\": {\n",
+    "#                         \"invalid_transitions\": [],\n",
+    "#                         \"attempts_made\": current_attempt + 1,\n",
+    "#                         \"fixed_count\": initial_invalid_count\n",
+    "#                     }\n",
+    "#                 }\n",
+    "    \n",
+    "#         except Exception as e:\n",
+    "#             print(f\"⚠️ Error during fix attempt: {str(e)}\")\n",
+    "#             break\n",
+    "    \n",
+    "#         current_attempt += 1\n",
+    "    \n",
+    "#     remaining_invalid = len(validation[\"invalid_transitions\"])\n",
+    "    \n",
+    "#     return {\n",
+    "#         \"is_valid\": False,\n",
+    "#         \"graph\": current_graph,\n",
+    "#         \"validation_details\": {\n",
+    "#             \"invalid_transitions\": validation[\"invalid_transitions\"],\n",
+    "#             \"attempts_made\": current_attempt,\n",
+    "#             \"fixed_count\": initial_invalid_count - remaining_invalid\n",
+    "#         }\n",
+    "#     }"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# from chatsky_llm_autoconfig.algorithms.dialogue_generation import RecursiveDialogueSampler\n",
+    "# from chatsky_llm_autoconfig.metrics.automatic_metrics import all_utterances_present\n",
+    "# from chatsky_llm_autoconfig.metrics.llm_metrics import graph_validation, is_theme_valid\n",
+    "\n",
+    "# CYCLE_REQUIREMENT = 2\n",
+    "\n",
+    "# # Generation\n",
+    "# gen_model = ChatOpenAI(\n",
+    "#     model='o1-mini',\n",
+    "#     api_key=os.getenv(\"OPENAI_API_KEY\"),\n",
+    "#     base_url=os.getenv(\"OPENAI_BASE_URL\"),\n",
+    "#     temperature=1\n",
+    "# )\n",
+    "\n",
+    "# try:\n",
+    "#     topic = \"ordering pizza\"\n",
+    "    \n",
+    "#     graph = graph_generator.invoke(model=gen_model, prompt=enhanced_graph_prompt, topic=topic)\n",
+    "\n",
+    "#     # Check the requirements\n",
+    "#     validation_result = validate_graph_cycle_requirement(graph, min_cycles=CYCLE_REQUIREMENT)\n",
+    "\n",
+    "#     # Sample dialogues\n",
+    "#     dial_sampler = RecursiveDialogueSampler()\n",
+    "#     sampled_dialogues = dial_sampler.invoke(graph, 1, -1)\n",
+    "\n",
+    "#     # Check the sampling\n",
+    "#     sampling_result = all_utterances_present(graph, sampled_dialogues)\n",
+    "    \n",
+    "#     if sampling_result is False:\n",
+    "#         raise ValueError(\"Failed to sample valid dialogues from the graph or sampling error occurred\")\n",
+    "    \n",
+    "#     # First check that the theme is valid\n",
+    "#     theme_validation = is_theme_valid(graph, gen_model, topic)\n",
+    "#     if not theme_validation['value']:\n",
+    "#         raise ValueError(f\"Theme validation failed: {theme_validation['description']}\")\n",
+    "    \n",
+    "#     # If the theme is valid, check the triplets in a loop\n",
+    "#     transition_validation = check_and_fix_transitions(graph, graph_generator, gen_model, max_attempts=3)\n",
+    "    \n",
+    "#     print(\"\\nValidation results:\")\n",
+    "#     print(f\"Theme valid: {theme_validation['value']}\")\n",
+    "#     print(f\"Transitions valid: {transition_validation['is_valid']}\")\n",
+    "    \n",
+    "#     if not transition_validation['is_valid']:\n",
+    "#         print(\"\\nInvalid transitions:\")\n",
+    "#         for t in transition_validation['invalid_transitions']:\n",
+    "#             print(f\"- {t['reason']}\")\n",
+    "\n",
+    "# except Exception as e:\n",
+    "#     print(f\"❌ Error during graph generation or validation: {str(e)}\")\n",
+    "#     raise\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/Users/dmitriimartynov/Documents/Projects/chatsky-llm-autoconfig/.venv/lib/python3.12/site-packages/pydantic/_internal/_fields.py:172: UserWarning: Field name \"validate\" in \"Dialogue\" shadows an attribute in parent \"BaseModel\"\n",
+      "  warnings.warn(\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "==================================================\n",
+      "Generating graph for topic: technical support conversation\n",
+      "==================================================\n",
+      "Generating Graph ...\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:httpx:HTTP Request: POST http://193.187.173.33:8002/api/providers/openai/v1/chat/completions \"HTTP/1.1 200 OK\"\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "🔍 Checking graph requirements...\n",
+      "🔄 Found 30 cycles in the graph:\n",
+      "Cycle 1: 1 -> 2 -> 5 -> 7 -> 11 -> 1\n",
+      "Cycle 2: 1 -> 2 -> 5 -> 7 -> 12 -> 1\n",
+      "Cycle 3: 1 -> 2 -> 5 -> 8 -> 11 -> 1\n",
+      "Cycle 4: 1 -> 2 -> 5 -> 8 -> 12 -> 1\n",
+      "Cycle 5: 1 -> 2 -> 6 -> 9 -> 11 -> 1\n",
+      "Cycle 6: 1 -> 2 -> 6 -> 9 -> 12 -> 1\n",
+      "Cycle 7: 1 -> 2 -> 6 -> 10 -> 11 -> 1\n",
+      "Cycle 8: 1 -> 2 -> 6 -> 10 -> 12 -> 1\n",
+      "Cycle 9: 1 -> 2 -> 1\n",
+      "Cycle 10: 1 -> 3 -> 14 -> 16 -> 1\n",
+      "Cycle 11: 1 -> 3 -> 14 -> 17 -> 11 -> 1\n",
+      "Cycle 12: 1 -> 3 -> 14 -> 17 -> 16 -> 1\n",
+      "Cycle 13: 1 -> 3 -> 15 -> 18 -> 1\n",
+      "Cycle 14: 1 -> 3 -> 15 -> 19 -> 11 -> 1\n",
+      "Cycle 15: 1 -> 3 -> 15 -> 19 -> 18 -> 1\n",
+      "Cycle 16: 1 -> 3 -> 1\n",
+      "Cycle 17: 1 -> 4 -> 20 -> 1\n",
+      "Cycle 18: 1 -> 4 -> 1\n",
+      "Cycle 19: 1 -> 4 -> 21 -> 20 -> 1\n",
+      "Cycle 20: 1 -> 4 -> 21 -> 1\n",
+      "Cycle 21: 3 -> 14 -> 17 -> 3\n",
+      "Cycle 22: 3 -> 14 -> 3\n",
+      "Cycle 23: 3 -> 15 -> 19 -> 3\n",
+      "Cycle 24: 3 -> 15 -> 3\n",
+      "Cycle 25: 2 -> 5 -> 7 -> 2\n",
+      "Cycle 26: 2 -> 5 -> 8 -> 2\n",
+      "Cycle 27: 2 -> 5 -> 2\n",
+      "Cycle 28: 2 -> 6 -> 9 -> 2\n",
+      "Cycle 29: 2 -> 6 -> 10 -> 2\n",
+      "Cycle 30: 2 -> 6 -> 2\n",
+      "✅ Graph meets cycle requirements\n",
+      "Sampling dialogues...\n"
+     ]
+    }
+   ],
+   "source": [
+    "from chatsky_llm_autoconfig.algorithms.cycle_graph_generation_pipeline import GraphGenerationPipeline\n",
+    "from langchain_openai import ChatOpenAI\n",
+    "from dotenv import load_dotenv\n",
+    "from chatsky_llm_autoconfig.schemas import GraphGenerationResult\n",
+    "from datetime import datetime\n",
+    "from pathlib import Path\n",
+    "import os\n",
+    "import json\n",
+    "\n",
+    "\n",
+    "def generate_graphs():\n",
+    "    output_dir = Path(\"generated_graphs\")\n",
+    "    output_dir.mkdir(exist_ok=True)\n",
+    "    \n",
+    "    generation_model = ChatOpenAI(\n",
+    "        model='o1-mini',\n",
+    "        api_key=os.getenv(\"OPENAI_API_KEY\"),\n",
+    "        base_url=os.getenv(\"OPENAI_BASE_URL\"),\n",
+    "        temperature=1\n",
+    "    )\n",
+    "    \n",
+    "    validation_model = ChatOpenAI(\n",
+    "        model='gpt-4o',\n",
+    "        api_key=os.getenv(\"OPENAI_API_KEY\"),\n",
+    "        base_url=os.getenv(\"OPENAI_BASE_URL\"),\n",
+    "        temperature=0\n",
+    "    )\n",
+    "    \n",
+    "    pipeline = GraphGenerationPipeline(\n",
+    "        generation_model=generation_model,\n",
+    "        validation_model=validation_model\n",
+    "    )\n",
+    "    \n",
+    "    topics = [\n",
+    "        \"technical support conversation\",\n",
+    "        # \"restaurant reservation\",\n",
+    "        # \"online shopping checkout\",\n",
+    "        # \"job interview\",\n",
+    "        # \"travel booking\"\n",
+    "    ]\n",
+    "    \n",
+    "    successful_generations = []\n",
+    "    \n",
+    "    for topic in topics:\n",
+    "        print(f\"\\n{'='*50}\")\n",
+    "        print(f\"Generating graph for topic: {topic}\")\n",
+    "        print(f\"{'='*50}\")\n",
+    "        \n",
+    "        try:\n",
+    "            result = pipeline(topic)\n",
+    "            \n",
+    "            # Check which result type we got\n",
+    "            if isinstance(result, GraphGenerationResult):\n",
+    "                print(f\"✅ Successfully generated graph for {topic}\")\n",
+    "                # Save the full result with the graph and dialogues;\n",
+    "                # model_dump() recursively serializes the nested\n",
+    "                # Pydantic models so the entry is JSON-safe\n",
+    "                successful_generations.append(\n",
+    "                    result.model_dump()\n",
+    "                )\n",
+    "            else:  # isinstance(result, GenerationError)\n",
+    "                print(f\"❌ Failed to generate graph for {topic}\")\n",
+    "                print(f\"Error type: {result.error_type}\")\n",
+    "                print(f\"Error message: {result.message}\")\n",
+    "        \n",
+    "        except Exception as e:\n",
+    "            print(f\"❌ Unexpected error processing topic '{topic}': {str(e)}\")\n",
+    "            continue\n",
+    "    \n",
+    "    if successful_generations:\n",
+    "        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')\n",
+    "        filename = f\"generated_graphs_{timestamp}.json\"\n",
+    "        with open(output_dir / filename, \"w\", encoding=\"utf-8\") as f:\n",
+    "            # The entries were already serialized via model_dump(), so json.dump is safe\n",
+    "            json_data = [result for result in successful_generations]\n",
+    "            json.dump(json_data, f, indent=2, ensure_ascii=False)\n",
+    "        print(f\"\\nAll graphs saved to: {filename}\")\n",
+    "    \n",
+    "        print(f\"\\nSuccessfully generated {len(successful_generations)} graphs out of {len(topics)} topics\")\n",
+    "    else:\n",
+    "        print(\"\\nNo graphs were successfully generated\")\n",
+    "\n",
+    "\n",
+    "if __name__ == \"__main__\":\n",
+    "    generate_graphs()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": ".venv",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.7"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/experiments/2025.01.13_graph_generation_autofix/report.md b/experiments/2025.01.13_graph_generation_autofix/report.md
new file mode 100644
index 0000000..e69de29
diff --git a/experiments/2025.01.13_graph_generation_autofix/task.md b/experiments/2025.01.13_graph_generation_autofix/task.md
new file mode 100644
index 0000000..a2a2ed9
--- /dev/null
+++ b/experiments/2025.01.13_graph_generation_autofix/task.md
@@ -0,0 +1 @@
+write autofixing graph gen flow
\ No newline at end of file
diff --git a/test_classes.py b/test_classes.py
index b2769f3..79c336e 100644
--- a/test_classes.py
+++ b/test_classes.py
@@ -1,5 +1,5 @@
 from chatsky_llm_autoconfig.algorithms.dialogue_generation import DialogueSampler
-from chatsky_llm_autoconfig.dialogue import Dialogue
+from chatsky_llm_autoconfig.schemas import Dialogue
 from chatsky_llm_autoconfig.graph import Graph
 from chatsky_llm_autoconfig.metrics.automatic_metrics import all_paths_sampled
 import json
diff --git a/tests/conftest.py b/tests/conftest.py
index 8ceaa16..cc3398d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,4 +1,4 @@
-from chatsky_llm_autoconfig.dialogue import Dialogue
+from chatsky_llm_autoconfig.schemas import Dialogue
 from chatsky_llm_autoconfig.algorithms.dialogue_augmentation import *
 from chatsky_llm_autoconfig.algorithms.dialogue_generation import *
 from chatsky_llm_autoconfig.algorithms.topic_graph_generation import *
diff --git a/tests/test_models.py b/tests/test_models.py
index 2d9b0b2..1c1e882 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -1,4 +1,4 @@
 import pytest
-from chatsky_llm_autoconfig.dialogue import Dialogue
+from chatsky_llm_autoconfig.schemas import Dialogue
 from chatsky_llm_autoconfig.graph import Graph