Martynov gen autofix #16

Open · wants to merge 6 commits into dev
@@ -2,8 +2,9 @@
from pydantic import BaseModel
import abc
from chatsky_llm_autoconfig.graph import BaseGraph
-from chatsky_llm_autoconfig.dialogue import Dialogue
+from chatsky_llm_autoconfig.schemas import Dialogue
from langchain_core.language_models.chat_models import BaseChatModel
+from langchain.prompts import PromptTemplate


class BaseAlgorithm(BaseModel, abc.ABC):
@@ -72,7 +73,7 @@ async def ainvoke(self, topic: str, graph: BaseGraph) -> BaseGraph:
class TopicGraphGenerator(BaseAlgorithm):
    """Graph generator that works only with topics."""

-    def invoke(self, topic: str, model: BaseChatModel) -> BaseGraph:
+    def invoke(self, model: BaseChatModel, prompt: PromptTemplate) -> BaseGraph:
        raise NotImplementedError

    async def ainvoke(self, topic: str) -> BaseGraph:
@@ -0,0 +1,238 @@
from typing import Optional, Dict, Any, Union
from enum import Enum

import networkx as nx
from pydantic import BaseModel
from langchain_core.language_models.chat_models import BaseChatModel
from langchain.prompts import PromptTemplate

from chatsky_llm_autoconfig.algorithms.topic_graph_generation import CycleGraphGenerator
from chatsky_llm_autoconfig.algorithms.dialogue_generation import RecursiveDialogueSampler
from chatsky_llm_autoconfig.metrics.automatic_metrics import all_utterances_present
from chatsky_llm_autoconfig.metrics.llm_metrics import graph_validation, is_theme_valid
from chatsky_llm_autoconfig.graph import BaseGraph
from chatsky_llm_autoconfig.prompts import cycle_graph_generation_prompt_enhanced, cycle_graph_repair_prompt
from chatsky_llm_autoconfig.schemas import GraphGenerationResult


class ErrorType(str, Enum):
    """Types of errors that can occur during generation"""

    INVALID_GRAPH_STRUCTURE = "invalid_graph_structure"
    INSUFFICIENT_CYCLES = "insufficient_cycles"
    SAMPLING_FAILED = "sampling_failed"
    INVALID_THEME = "invalid_theme"
    GENERATION_FAILED = "generation_failed"


class GenerationError(BaseModel):
    """Base error with the essential fields"""

    error_type: ErrorType
    message: str


PipelineResult = Union[GraphGenerationResult, GenerationError]
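# Callers discriminate PipelineResult by type: isinstance(result, GenerationError)
# signals a failed stage; any other value is a successful GraphGenerationResult.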


class GraphGenerationPipeline:
    generation_model: BaseChatModel
    validation_model: BaseChatModel
    graph_generator: CycleGraphGenerator
    dialogue_sampler: RecursiveDialogueSampler
    generation_prompt: PromptTemplate
    repair_prompt: PromptTemplate
    min_cycles: int = 2
    max_fix_attempts: int = 3

    def __init__(
        self,
        generation_model: BaseChatModel,
        validation_model: BaseChatModel,
        generation_prompt: Optional[PromptTemplate] = None,
        repair_prompt: Optional[PromptTemplate] = None,
        min_cycles: int = 2,
        max_fix_attempts: int = 3,
    ):
        self.generation_model = generation_model
        self.validation_model = validation_model
        self.graph_generator = CycleGraphGenerator()
        self.dialogue_sampler = RecursiveDialogueSampler()

        self.generation_prompt = generation_prompt or cycle_graph_generation_prompt_enhanced
        self.repair_prompt = repair_prompt or cycle_graph_repair_prompt

        self.min_cycles = min_cycles
        self.max_fix_attempts = max_fix_attempts

    def validate_graph_cycle_requirement(
        self,
        graph: BaseGraph,
        min_cycles: int = 2,
    ) -> Dict[str, Any]:
        """
        Check that the graph meets the required minimum number of cycles.
        """
        print("\n🔍 Checking graph requirements...")

        try:
            cycles = list(nx.simple_cycles(graph.graph))
            cycles_count = len(cycles)

            print(f"🔄 Found {cycles_count} cycles in the graph:")
            for i, cycle in enumerate(cycles, 1):
                print(f"Cycle {i}: {' -> '.join(map(str, cycle + [cycle[0]]))}")

            meets_requirements = cycles_count >= min_cycles

            if not meets_requirements:
                print(f"❌ Graph doesn't meet cycle requirements (minimum {min_cycles} cycles needed)")
            else:
                print("✅ Graph meets cycle requirements")

            return {
                "meets_requirements": meets_requirements,
                "cycles": cycles,
                "cycles_count": cycles_count,
            }

        except Exception as e:
            print(f"❌ Validation error: {str(e)}")
            raise

    def check_and_fix_transitions(
        self,
        graph: BaseGraph,
        max_attempts: int = 3,
    ) -> Dict[str, Any]:
        """
        Validate the graph's transitions and try to repair invalid ones via the LLM.
        """
        print("Validating initial graph")

        initial_validation = graph_validation(graph, self.validation_model)
        if initial_validation["is_valid"]:
            return {
                "is_valid": True,
                "graph": graph,
                "validation_details": {
                    "invalid_transitions": [],
                    "attempts_made": 0,
                    "fixed_count": 0,
                },
            }

        initial_invalid_count = len(initial_validation["invalid_transitions"])
        validation = initial_validation  # track the latest validation result
        current_graph = graph
        current_attempt = 0

        while current_attempt < max_attempts:
            print(f"\n🔄 Fix attempt {current_attempt + 1}/{max_attempts}")

            try:
                # Use generation_model to repair the graph, passing in the
                # transitions that are still invalid
                current_graph = self.graph_generator.invoke(
                    model=self.generation_model,
                    prompt=self.repair_prompt,
                    invalid_transitions=validation["invalid_transitions"],
                    graph_json=current_graph.graph_dict,
                )

                # Re-validate the repaired graph using validation_model
                validation = graph_validation(current_graph, self.validation_model)
                if validation["is_valid"]:
                    return {
                        "is_valid": True,
                        "graph": current_graph,
                        "validation_details": {
                            "invalid_transitions": [],
                            "attempts_made": current_attempt + 1,
                            "fixed_count": initial_invalid_count,
                        },
                    }

            except Exception as e:
                print(f"⚠️ Error during fix attempt: {str(e)}")
                break

            current_attempt += 1

        remaining_invalid = len(validation["invalid_transitions"])

        return {
            "is_valid": False,
            "graph": current_graph,
            "validation_details": {
                "invalid_transitions": validation["invalid_transitions"],
                "attempts_made": current_attempt,
                "fixed_count": initial_invalid_count - remaining_invalid,
            },
        }

    def generate_and_validate(self, topic: str) -> PipelineResult:
        """
        Generate and validate a dialogue graph for the given topic.
        """
        try:
            # 1. Generate the initial graph
            print("Generating Graph ...")
            graph = self.graph_generator.invoke(
                model=self.generation_model,
                prompt=self.generation_prompt,
                topic=topic,
            )

            # 2. Validate cycles
            cycle_validation = self.validate_graph_cycle_requirement(graph, self.min_cycles)
            if not cycle_validation["meets_requirements"]:
                return GenerationError(
                    error_type=ErrorType.INSUFFICIENT_CYCLES,
                    message=f"Graph requires minimum {self.min_cycles} cycles, found {cycle_validation['cycles_count']}",
                )

            # 3. Generate and validate dialogues
            print("Sampling dialogues...")
            sampled_dialogues = self.dialogue_sampler.invoke(graph, 1, -1)
            if not all_utterances_present(graph, sampled_dialogues):
                return GenerationError(
                    error_type=ErrorType.SAMPLING_FAILED,
                    message="Failed to sample valid dialogues - not all utterances are present",
                )

            # 4. Validate theme
            theme_validation = is_theme_valid(graph, self.validation_model, topic)
            if not theme_validation["value"]:
                return GenerationError(
                    error_type=ErrorType.INVALID_THEME,
                    message=f"Theme validation failed: {theme_validation['description']}",
                )

            # 5. Validate and fix transitions
            transition_validation = self.check_and_fix_transitions(
                graph=graph,
                max_attempts=self.max_fix_attempts,
            )

            if not transition_validation["is_valid"]:
                invalid_transitions = transition_validation["validation_details"]["invalid_transitions"]
                return GenerationError(
                    error_type=ErrorType.INVALID_GRAPH_STRUCTURE,
                    message=(
                        f"Found {len(invalid_transitions)} invalid transitions after "
                        f"{transition_validation['validation_details']['attempts_made']} fix attempts"
                    ),
                )

            # All validations passed - return the successful result
            return GraphGenerationResult(
                graph=transition_validation["graph"].graph_dict,
                topic=topic,
                dialogues=sampled_dialogues,
            )

        except Exception as e:
            return GenerationError(
                error_type=ErrorType.GENERATION_FAILED,
                message=f"Unexpected error during generation: {str(e)}",
            )

    def __call__(self, topic: str) -> PipelineResult:
        """Shorthand for generate_and_validate"""
        return self.generate_and_validate(topic)
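A minimal usage sketch of the pipeline above (not part of this PR's diff); the ChatOpenAI models from langchain-openai and their names are placeholder assumptions:

from langchain_openai import ChatOpenAI

pipeline = GraphGenerationPipeline(
    generation_model=ChatOpenAI(model="gpt-4o"),       # assumed model name
    validation_model=ChatOpenAI(model="gpt-4o-mini"),  # assumed model name
)

result = pipeline("ordering coffee in a cafe")
if isinstance(result, GenerationError):
    print(f"{result.error_type.value}: {result.message}")
else:
    print(f"Graph generated for topic '{result.topic}'")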
@@ -1,4 +1,4 @@
-from chatsky_llm_autoconfig.dialogue import Dialogue
+from chatsky_llm_autoconfig.schemas import Dialogue
from chatsky_llm_autoconfig.schemas import DialogueMessage
from chatsky_llm_autoconfig.autometrics.registry import AlgorithmRegistry

@@ -3,7 +3,7 @@
import networkx as nx
from chatsky_llm_autoconfig.graph import BaseGraph
from chatsky_llm_autoconfig.algorithms.base import DialogueGenerator
-from chatsky_llm_autoconfig.dialogue import Dialogue
+from chatsky_llm_autoconfig.schemas import Dialogue
from chatsky_llm_autoconfig.autometrics.registry import AlgorithmRegistry


@@ -1,4 +1,3 @@
-from typing import Optional
from chatsky_llm_autoconfig.algorithms.base import TopicGraphGenerator
from chatsky_llm_autoconfig.autometrics.registry import AlgorithmRegistry
from chatsky_llm_autoconfig.schemas import DialogueGraph
@@ -7,82 +6,32 @@
from chatsky_llm_autoconfig.graph import BaseGraph, Graph
from langchain_core.language_models.chat_models import BaseChatModel

-from pydantic import Field
-from typing import ClassVar


@AlgorithmRegistry.register(input_type=str, output_type=BaseGraph)
class CycleGraphGenerator(TopicGraphGenerator):
    """Generator specifically for topic-based cyclic graphs"""

-    DEFAULT_TEMPLATE: ClassVar[str] = """
-    Create a complex dialogue graph where the conversation MUST return to an existing node.
-
-    **CRITICAL: Response Specificity**
-    Responses must acknowledge and build upon what the user has already specified:
-
-    INCORRECT flow:
-    - User: "I'd like to order a coffee"
-    - Staff: "What would you like to order?" (TOO GENERAL - ignores that they specified coffee)
-
-    CORRECT flow:
-    - User: "I'd like to order a coffee"
-    - Staff: "What kind of coffee would you like?" (GOOD - acknowledges they want coffee)
-
-    Example of a CORRECT cyclic graph for a coffee shop:
-    "edges": [
-        {{ "source": 1, "target": 2, "utterances": ["Hi, I'd like to order a coffee"] }},
-        {{ "source": 2, "target": 3, "utterances": ["A large latte please"] }},
-        {{ "source": 3, "target": 4, "utterances": ["Yes, that's correct"] }},
-        {{ "source": 4, "target": 5, "utterances": ["Here's my payment"] }},
-        {{ "source": 5, "target": 2, "utterances": ["I'd like to order another coffee"] }}
-    ],
-    "nodes": [
-        {{ "id": 1, "label": "welcome", "is_start": true, "utterances": ["Welcome! How can I help you today?"] }},
-        {{ "id": 2, "label": "ask_coffee_type", "is_start": false, "utterances": ["What kind of coffee would you like?"] }},
-        {{ "id": 3, "label": "confirm", "is_start": false, "utterances": ["That's a large latte. Is this correct?"] }},
-        {{ "id": 4, "label": "payment", "is_start": false, "utterances": ["Great! That'll be $5. Please proceed with payment."] }},
-        {{ "id": 5, "label": "completed", "is_start": false, "utterances": ["Thank you! Would you like another coffee?"] }}
-    ]
-
-    **Rules:**
-    1) Responses must acknowledge what the user has already specified
-    2) The final node MUST connect back to an existing node
-    3) Each node must have clear purpose
-    4) Return ONLY the JSON without commentary
-    5) Graph must be cyclic - no dead ends
-    6) All edges must connect to existing nodes
-    7) The cycle point should make logical sense
-
-    **Your task is to create a cyclic dialogue graph about the following topic:** {topic}.
-    """
-
-    cycle_graph_generation_prompt: PromptTemplate = Field(
-        default_factory=lambda: PromptTemplate.from_template(CycleGraphGenerator.DEFAULT_TEMPLATE)
-    )
-
-    def __init__(self, prompt: Optional[PromptTemplate] = None):
+    def __init__(self):
        super().__init__()
-        if prompt is not None:
-            self.cycle_graph_generation_prompt = prompt

-    def invoke(self, topic: str, model: BaseChatModel) -> BaseGraph:
+    def invoke(self, model: BaseChatModel, prompt: PromptTemplate, **kwargs) -> BaseGraph:
        """
        Generate a cyclic dialogue graph based on the topic input.

        Args:
-            topic (str): The topic for the dialogue graph
-            model_name (str): The name of the model to use
+            model (BaseChatModel): The model to use for generation
+            prompt (PromptTemplate): Prepared prompt template
+            **kwargs: Additional arguments for formatting the prompt

        Returns:
            BaseGraph: Generated Graph object with cyclic structure
        """
+        # Build the chain: prompt -> model -> parser
        parser = JsonOutputParser(pydantic_object=DialogueGraph)
+        chain = prompt | model | parser

-        chain = self.cycle_graph_generation_prompt | model | parser
-
-        generated_graph = chain.invoke({"topic": topic})
-        return Graph(generated_graph)
+        # Pass kwargs as the input mapping for the chain
+        return Graph(chain.invoke(kwargs))

    async def ainvoke(self, *args, **kwargs):
        """
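A hypothetical call site (not in the diff) illustrating the new calling convention, where the prompt is injected and formatted from **kwargs; the ChatOpenAI model from langchain-openai is a placeholder assumption:

from langchain.prompts import PromptTemplate
from langchain_openai import ChatOpenAI

generator = CycleGraphGenerator()
prompt = PromptTemplate.from_template(
    "Create a cyclic dialogue graph about the following topic: {topic}."
)
graph = generator.invoke(
    model=ChatOpenAI(model="gpt-4o"),   # assumed model name
    prompt=prompt,
    topic="booking a hotel room",       # forwarded to the prompt via **kwargs
)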
@@ -4,7 +4,7 @@
from chatsky_llm_autoconfig.algorithms.dialogue_generation import DialogueSampler
import json
from chatsky_llm_autoconfig.graph import Graph, BaseGraph
-from chatsky_llm_autoconfig.dialogue import Dialogue
+from chatsky_llm_autoconfig.schemas import Dialogue
from chatsky_llm_autoconfig.metrics.automatic_metrics import *
from chatsky_llm_autoconfig.metrics.llm_metrics import are_triplets_valid, is_theme_valid
import datetime