Skip to content

Commit

Permalink
Merge pull request #12 from deeppavlov/feat/improved_triplet-check
Browse files Browse the repository at this point in the history
Feat/improved triplet check
  • Loading branch information
NotBioWaste905 authored Dec 10, 2024
2 parents 8115acb + c250bd3 commit 17a7169
Show file tree
Hide file tree
Showing 19 changed files with 10,597 additions and 1,572 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
2. The output must be a list of dictionaries, where each dictionary has:
- 'text': string
- 'participant': either 'user' or 'assistant'
3. Ensure all utterance variations:
- Are appropriate for the theme
- Maintain consistency in tone and style
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Optional
from chatsky_llm_autoconfig.algorithms.base import TopicGraphGenerator
from chatsky_llm_autoconfig.autometrics.registry import AlgorithmRegistry
from chatsky_llm_autoconfig.schemas import DialogueGraph
Expand All @@ -19,50 +20,54 @@ class CycleGraphGenerator(TopicGraphGenerator):
prompt: str = ""
cycle_graph_generation_prompt: str = ""

def __init__(self, prompt: Optional[PromptTemplate] = None):
    """Set up the generator with a caller-supplied or built-in prompt.

    Args:
        prompt: Template used to request a cyclic dialogue graph. Any falsy
            value (including the default ``None``) selects the built-in
            coffee-shop few-shot template defined below.
    """
    super().__init__()

    # A caller-provided template wins; nothing else to configure.
    if prompt:
        self.cycle_graph_generation_prompt = prompt
        return

    # Built-in template. The `{{` / `}}` pairs are escaped literal braces for
    # the JSON example; `{topic}` is the only placeholder.
    default_template = """
Create a cyclic dialogue graph where the conversation MUST return to an existing node.
**CRITICAL: Response Specificity**
Responses must acknowledge and build upon what the user has already specified:
INCORRECT flow:
- User: "I'd like to order a coffee"
- Staff: "What would you like to order?" (TOO GENERAL - ignores that they specified coffee)
CORRECT flow:
- User: "I'd like to order a coffee"
- Staff: "What kind of coffee would you like?" (GOOD - acknowledges they want coffee)
Example of a CORRECT cyclic graph for a coffee shop:
"edges": [
{{ "source": 1, "target": 2, "utterances": ["Hi, I'd like to order a coffee"] }},
{{ "source": 2, "target": 3, "utterances": ["A large latte please"] }},
{{ "source": 3, "target": 4, "utterances": ["Yes, that's correct"] }},
{{ "source": 4, "target": 5, "utterances": ["Here's my payment"] }},
{{ "source": 5, "target": 2, "utterances": ["I'd like to order another coffee"] }}
],
"nodes": [
{{ "id": 1, "label": "welcome", "is_start": true, "utterances": ["Welcome! How can I help you today?"] }},
{{ "id": 2, "label": "ask_coffee_type", "is_start": false, "utterances": ["What kind of coffee would you like?"] }},
{{ "id": 3, "label": "confirm", "is_start": false, "utterances": ["That's a large latte. Is this correct?"] }},
{{ "id": 4, "label": "payment", "is_start": false, "utterances": ["Great! That'll be $5. Please proceed with payment."] }},
{{ "id": 5, "label": "completed", "is_start": false, "utterances": ["Thank you! Would you like another coffee?"] }}
]
**Rules:**
1) Responses must acknowledge what the user has already specified
2) The final node MUST connect back to an existing node
3) Each node must have clear purpose
4) Return ONLY the JSON without commentary
5) Graph must be cyclic - no dead ends
6) All edges must connect to existing nodes
7) The cycle point should make logical sense
**Your task is to create a cyclic dialogue graph about the following topic:** {topic}.
"""
    self.cycle_graph_generation_prompt = PromptTemplate.from_template(default_template)

def invoke(self, topic: str) -> BaseGraph:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,5 +156,40 @@
"all_paths_sampled_avg": 1.0,
"all_utterances_present_avg": 0.6666666666666666
}
},
"2024-12-10 03:41:54.009332": {
"DialogAugmentator": {
"all_roles_correct": [
true,
true,
true
],
"is_correct_lenght": [
false,
false,
false
],
"all_roles_correct_avg": 1.0,
"is_correct_lenght_avg": 0.0
},
"CycleGraphGenerator": {
"is_theme_valid": [
true,
true,
true,
true
],
"are_triplets_valid": 1.0,
"is_theme_valid_avg": 1.0
},
"DialogueSampler": {
"all_paths_sampled": [],
"all_utterances_present": [
true,
true,
true
],
"all_utterances_present_avg": 1.0
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def run_all_algorithms():
test_topic = case["topic"]
result = class_instance.invoke(test_topic)

metrics["are_triplets_valid"].append(are_triplets_valid(result, model, topic=test_topic)["value"])
metrics["are_triplets_valid"].append(are_triplets_valid(result, model)["value"])
metrics["is_theme_valid"].append(is_theme_valid(result, model, topic=test_topic)["value"])

metrics["is_theme_valid_avg"] = sum(metrics["is_theme_valid"]) / len(metrics["is_theme_valid"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,121 +5,110 @@
This module contains functions that checks Graphs and Dialogues for various metrics using LLM calls.
"""

from chatsky_llm_autoconfig.graph import BaseGraph
from typing import List, Tuple
from chatsky_llm_autoconfig.graph import BaseGraph, Graph
from langchain_core.language_models.chat_models import BaseChatModel
from langchain.prompts import PromptTemplate
from pydantic import BaseModel, Field
from typing import List
from langchain_core.output_parsers import PydanticOutputParser
import logging
import json

# Set up logging
logging.basicConfig(level=logging.INFO)


def are_triplets_valid(G: Graph, model: BaseChatModel) -> dict[str, object]:
    """
    Validate a dialogue graph's structure and the logic of its transitions.

    Each edge is checked in two steps: both endpoints must exist among the
    graph's nodes, and the LLM must judge the source -> edge -> target
    utterances to be a logical sequence. The whole graph (as JSON) is passed
    to the LLM as context for every transition.

    Parameters:
        G (Graph): The dialogue graph to validate. Its ``graph_dict`` must
            contain "nodes" and "edges" lists.
        model (BaseChatModel): The LLM model to use for validation.

    Returns:
        dict: {'value': bool, 'description': str}. 'value' is True only if
        every edge passed both checks; 'description' joins the failure
        messages, or reads "All transitions are valid." when there are none.
    """

    # Schema the LLM reply is parsed into.
    class TransitionValidationResult(BaseModel):
        isValid: bool = Field(description="Whether the transition is valid or not")
        description: str = Field(description="Explanation of why it's valid or invalid")

    triplet_validate_prompt_template = """
You are evaluating if dialog transitions make logical sense.
Given this conversation graph in JSON:
{json_graph}
For the current transition:
Source (Assistant): {source_utterances}
User Response: {edge_utterances}
Target (Assistant): {target_utterances}
EVALUATE: Do these three messages form a logical sequence in the conversation?
Consider:
1. Does the assistant's first response naturally lead to the user's response?
2. Does the user's response logically connect to the assistant's next message?
3. Is the overall flow natural and coherent?
Reply in JSON format:
{{"isValid": true/false, "description": "Brief explanation of why it's valid or invalid"}}
"""

    triplet_validate_prompt = PromptTemplate(
        input_variables=["json_graph", "source_utterances", "edge_utterances", "target_utterances"],
        template=triplet_validate_prompt_template,
    )
    parser = PydanticOutputParser(pydantic_object=TransitionValidationResult)

    # The chain does not depend on the edge, so build it once instead of
    # re-composing it on every loop iteration.
    triplet_check_chain = triplet_validate_prompt | model | parser

    graph = G.graph_dict
    # Serialized once; the same graph context accompanies every transition.
    graph_json = json.dumps(graph)
    # Node id -> node data, for constant-time endpoint lookups.
    node_map = {node["id"]: node for node in graph["nodes"]}

    overall_valid = True
    descriptions = []

    for edge in graph["edges"]:
        source_id = edge["source"]
        target_id = edge["target"]

        # Structural check: an edge pointing at an unknown node cannot be
        # evaluated by the LLM, so record it and move on.
        if source_id not in node_map or target_id not in node_map:
            description = f"Invalid edge: missing node reference {source_id} -> {target_id}"
            logging.info(description)  # logged like invalid transitions below
            overall_valid = False
            descriptions.append(description)
            continue

        # A node without an "utterances" key is treated as having none
        # (as the previous revision did) rather than raising KeyError.
        input_data = {
            "json_graph": graph_json,
            "source_utterances": node_map[source_id].get("utterances", []),
            "edge_utterances": edge["utterances"],
            "target_utterances": node_map[target_id].get("utterances", []),
        }

        # Semantic check via the LLM.
        response = triplet_check_chain.invoke(input_data)

        if not response.isValid:
            overall_valid = False
            description = f"Invalid transition: {response.description}"
            logging.info(description)
            descriptions.append(description)

    return {
        "value": overall_valid,
        "description": " ".join(descriptions) if descriptions else "All transitions are valid.",
    }


Expand Down
Loading

0 comments on commit 17a7169

Please sign in to comment.