From 6b4852e9fb7751994a1b1877b72e523d04c623c0 Mon Sep 17 00:00:00 2001
From: "longbin.lailb"
Date: Sun, 17 Nov 2024 23:36:11 +0800
Subject: [PATCH] refactor workflow

---
 python/graphy/graph/nodes/inspector_node.py   |  13 --
 .../graphy/graph/nodes/paper_reading_nodes.py | 183 +++++++++++++++++-
 .../graphy/workflows/survey_paper_reading.py  | 104 +---------
 3 files changed, 187 insertions(+), 113 deletions(-)
 delete mode 100644 python/graphy/graph/nodes/inspector_node.py

diff --git a/python/graphy/graph/nodes/inspector_node.py b/python/graphy/graph/nodes/inspector_node.py
deleted file mode 100644
index 824a72d66..000000000
--- a/python/graphy/graph/nodes/inspector_node.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from .base_node import BaseNode, NodeType
-
-
-class InspectorNode(BaseNode):
-    """
-    Node representing an inspector in the graph.
-    """
-
-    def __init__(
-        self,
-        name: str,
-    ):
-        super().__init__(name, NodeType.INSPECTOR)
diff --git a/python/graphy/graph/nodes/paper_reading_nodes.py b/python/graphy/graph/nodes/paper_reading_nodes.py
index 4520b37fb..6c19dc0cc 100644
--- a/python/graphy/graph/nodes/paper_reading_nodes.py
+++ b/python/graphy/graph/nodes/paper_reading_nodes.py
@@ -1,9 +1,15 @@
+from .base_node import BaseNode, NodeType, NodeCache
 from .chain_node import BaseChainNode, DataGenerator
+from .pdf_extract_node import PDFExtractNode
+from graph.base_graph import BaseGraph
+from graph.base_edge import BaseEdge
+from memory.llm_memory import VectorDBHierarchy
 from prompts import TEMPLATE_ACADEMIC_RESPONSE
 from config import WF_STATE_MEMORY_KEY
 
-from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.pydantic_v1 import BaseModel, Field, create_model
 from langchain_core.language_models.llms import BaseLLM
+from langchain_core.embeddings import Embeddings
 
 from typing import Any, Dict, List, Generator
 import logging
@@ -26,6 +32,54 @@ class NameDescListFormat(BaseModel):
     )
 
 
+def create_inspector_graph(
+    graph_dict: Dict[str, Any],
+    llm_model: BaseLLM,
+    parser_model: BaseLLM,
+    embeddings_model: Embeddings,
+    max_token_size: int = 8192,
+    enable_streaming: bool = False,
+) -> BaseGraph:
+    nodes_dict = {}
+    nodes = []
+    edges = []
+    start_node = "Paper"
+
+    for node in graph_dict["nodes"]:
+        if node["name"] == start_node:  # node_0 = pdf_extract
+            nodes_dict[node["name"]] = PDFExtractNode(
+                embeddings_model,
+                start_node,
+            )
+        else:
+            extract_node = ExtractNode.from_dict(
+                node,
+                llm_model,
+                parser_model,
+                max_token_size,
+                enable_streaming,
+            )
+            nodes_dict[node["name"]] = extract_node
+
+    for _, value in nodes_dict.items():
+        nodes.append(value)
+    for edge in graph_dict["edges"]:
+        edges.append(BaseEdge(edge["source"], edge["target"]))
+        if edge["source"] != start_node:
+            nodes_dict[edge["target"]].add_dependent_node(edge["source"])
+
+    graph = BaseGraph()
+    # Add all nodes
+    for node in nodes:
+        graph.add_node(node)
+
+    # Add all edges
+    for edge in edges:
+        graph.add_edge(edge)
+
+    return graph
+
+
 class ExtractNode(BaseChainNode):
     def __init__(
         self,
@@ -54,6 +108,127 @@ def __init__(
         self.query_dependency = ""
         self.dependent_nodes = []
 
+    @classmethod
+    def from_dict(
+        cls,
+        node_dict: Dict[str, Any],
+        llm_model: BaseLLM,
+        parser_model: BaseLLM,
+        max_token_size: int = 8192,
+        enable_streaming: bool = False,
+    ) -> "ExtractNode":
+        """
+        Creates an ExtractNode instance from a dictionary.
+
+        Args:
+            node_dict: Dictionary containing node configuration.
+            llm_model: The LLM model to be used.
+            parser_model: The parser model to be used.
+            max_token_size: Maximum token size for the node.
+            enable_streaming: Flag to enable or disable streaming.
+
+        Returns:
+            ExtractNode: An initialized ExtractNode instance.
+        """
+        # Build output schema
+        items = {}
+        for item in node_dict["output_schema"]["item"]:
+            item_type = item["type"]
+            if item_type == "string":
+                items[item["name"]] = (str, Field(description=item["description"]))
+            elif item_type == "int":
+                items[item["name"]] = (int, Field(description=item["description"]))
+            else:
+                raise ValueError(f"Unsupported type: {item_type}")
+
+        ItemClass = create_model(node_dict["name"] + "ItemClass", **items)
+
+        # Determine the type of data (array or single)
+        if node_dict["output_schema"]["type"] == "array":
+            item_type = {
+                "data": (
+                    List[ItemClass],
+                    Field(description=node_dict["output_schema"]["description"]),
+                )
+            }
+        elif node_dict["output_schema"]["type"] == "single":
+            item_type = {
+                "data": (
+                    ItemClass,
+                    Field(description=node_dict["output_schema"]["description"]),
+                )
+            }
+        else:
+            raise ValueError(
+                f"Unsupported output schema type: {node_dict['output_schema']['type']}"
+            )
+
+        NodeClass = create_model(node_dict["name"] + "NodeClass", **item_type)
+
+        # Build `where` conditions
+        extract_from = node_dict.get("extract_from")
+        where = None
+        if isinstance(extract_from, str):
+            where_conditions = extract_from.split("|")
+            condition_dict = (
+                {
+                    "sec_name": {
+                        "$in": {
+                            "conditions": {"type": VectorDBHierarchy.FirstLayer.value},
+                            "return": "documents",
+                            "subquery": where_conditions[0],
+                            "result_num": 1,
+                        }
+                    }
+                }
+                if len(where_conditions) == 1
+                else {
+                    "$or": [
+                        {
+                            "sec_name": {
+                                "$in": {
+                                    "conditions": {
+                                        "type": VectorDBHierarchy.FirstLayer.value
+                                    },
+                                    "return": "documents",
+                                    "subquery": condition,
+                                    "result_num": 1,
+                                }
+                            }
+                        }
+                        for condition in where_conditions
+                    ]
+                }
+            )
+            where = {
+                "conditions": condition_dict,
+                "return": "all",
+                "result_num": -1,
+                "subquery": "{slot}",
+            }
+        elif isinstance(extract_from, list):
+            where = {
+                "conditions": {
+                    "section": {"$in": ["paper_meta", "abstract"] + extract_from}
+                },
+                "return": "all",
+                "result_num": -1,
+                "subquery": "{slot}",
+            }
+
+        # Create and return the ExtractNode instance
+        return cls(
+            node_name=node_dict["name"],
+            llm=llm_model,
+            parser_llm=parser_model,
+            output_format=NodeClass,
+            input_query=node_dict["query"],
+            max_token_size=max_token_size,
+            enable_streaming=enable_streaming,
+            block_config=None,
+            where=where,
+        )
+
     def add_dependent_node(self, dependent_node):
         self.dependent_nodes.append(dependent_node)
         self.query_dependency = (
@@ -85,3 +260,9 @@ def execute(
         # Format query for paper reading
         self.query = TEMPLATE_ACADEMIC_RESPONSE.format(user_query=self.query)
         yield from super().execute(state, _input)
+
+
+class PaperInspector(BaseNode):
+    def __init__(self, name: str, graph: BaseGraph):
+        super().__init__(name, NodeType.INSPECTOR)
+        self.graph = graph
diff --git a/python/graphy/workflows/survey_paper_reading.py b/python/graphy/workflows/survey_paper_reading.py
index b2d6dd9de..beec42d41 100644
--- a/python/graphy/workflows/survey_paper_reading.py
+++ b/python/graphy/workflows/survey_paper_reading.py
@@ -112,108 +112,14 @@ def _create_graph(self, workflow):
                     }
                 )
             else:
-                items = {}
-                for item in node["output_schema"]["item"]:
-                    if item["type"] == "string":
-                        items[item["name"]] = (
-                            str,
-                            Field(description=item["description"]),
-                        )
-                    elif item["type"] == "int":
-                        items[item["name"]] = (
-                            int,
-                            Field(description=item["description"]),
-                        )
-                    else:
-                        pass
-                        # ERROR
-                ItemClass = create_model(node["name"] + "ItemClass", **items)
-                if node["output_schema"]["type"] == "array":
-                    item_type = {
-                        "data": (
-                            List[ItemClass],
-                            Field(description=node["output_schema"]["description"]),
-                        ),
-                    }
-                elif node["output_schema"]["type"] == "single":
-                    item_type = {
-                        "data": (
-                            ItemClass,
-                            Field(description=node["output_schema"]["description"]),
-                        ),
-                    }
-                else:
-                    pass
-                    # ERROR
-                NodeClass = create_model(node["name"] + "NodeClass", **item_type)
-
-                if not node["extract_from"]:
-                    where = None
-                elif type(node["extract_from"]) is str:
-                    where_conditions = node["extract_from"].split("|")
-                    condition_dict = {}
-                    if len(where_conditions) == 1:
-                        condition_dict = {
-                            "sec_name": {
-                                "$in": {
-                                    "conditions": {
-                                        "type": VectorDBHierarchy.FirstLayer.value
-                                    },
-                                    "return": "documents",
-                                    "subquery": where_conditions[0],
-                                    "result_num": 1,
-                                }
-                            }
-                        }
-                    else:
-                        condition_dict["$or"] = []
-                        for condition in where_conditions:
-                            condition_dict["$or"].append(
-                                {
-                                    "sec_name": {
-                                        "$in": {
-                                            "conditions": {
-                                                "type": VectorDBHierarchy.FirstLayer.value
-                                            },
-                                            "return": "documents",
-                                            "subquery": condition,
-                                            "result_num": 1,
-                                        }
-                                    }
-                                }
-                            )
-
-                    where = {
-                        "conditions": condition_dict,
-                        "return": "all",
-                        "result_num": -1,
-                        "subquery": "{slot}",
-                    }
-                elif type(node["extract_from"]) is list:
-                    where = {
-                        "conditions": {
-                            "section": {
-                                "$in": ["paper_meta", "abstract"] + node["extract_from"]
-                            }
-                        },
-                        "return": "all",
-                        "result_num": -1,
-                        "subquery": "{slot}",
-                    }
-                else:
-                    pass
-
-                nodes_dict[node["name"]] = ExtractNode(
-                    node["name"],
+                extract_node = ExtractNode.from_dict(
+                    node,
                     self.llm_model,
                     self.parser_model,
-                    NodeClass,
-                    node["query"],
                     self.max_token_size,
                     self.enable_streaming,
-                    None,
-                    where,
                 )
+                nodes_dict[node["name"]] = extract_node
                 output_dict["nodes"].append(
                     {
                         "node_name": "ExtractNode",
@@ -221,12 +127,12 @@ def _create_graph(self, workflow):
                             node["name"],
                             "[llm_model]",
                             "[parser_model]",
-                            NodeClass,
+                            extract_node.json_format,
                             node["query"],
                             "[max_token_size]",
                             "[enable_streaming]",
                             None,
-                            where,
+                            extract_node.where,
                         ),
                     }
                 )
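
Notes on usage:

`create_inspector_graph` consumes a plain workflow definition. The sketch
below shows the shape implied by the field accesses in this patch
(`nodes`, `edges`, `name`, `query`, `extract_from`, `output_schema`); the
"Contribution" node, its query, and the section names are illustrative, not
taken from the repository.

```python
# A minimal workflow definition of the shape create_inspector_graph expects.
# The "Paper" node becomes the PDFExtractNode start node; every other node
# is turned into an ExtractNode via ExtractNode.from_dict.
workflow = {
    "nodes": [
        {"name": "Paper"},
        {
            "name": "Contribution",  # hypothetical extraction node
            "query": "What are the main contributions of the paper?",
            "extract_from": "introduction|conclusion",
            "output_schema": {
                "type": "array",  # "array" wraps items in a list; "single" does not
                "description": "One entry per contribution.",
                "item": [
                    {"name": "name", "type": "string", "description": "Short name."},
                    {"name": "summary", "type": "string", "description": "What it adds."},
                ],
            },
        },
    ],
    # Edges whose source is not the start node also register the source as a
    # dependency of the target via add_dependent_node.
    "edges": [{"source": "Paper", "target": "Contribution"}],
}
```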
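
From such a definition, `from_dict` derives the parser schema with two nested
`create_model` calls. A sketch of what that produces for the node above,
written out by hand under the same pydantic-v1 conventions the patch uses:

```python
from typing import List

from langchain_core.pydantic_v1 import Field, create_model

# Item model built from output_schema["item"]: one (type, Field) pair per
# entry, exactly what from_dict collects into `items` before unpacking.
ContributionItemClass = create_model(
    "ContributionItemClass",
    name=(str, Field(description="Short name.")),
    summary=(str, Field(description="What it adds.")),
)

# Wrapper model: a single "data" field, a List[...] because type == "array".
ContributionNodeClass = create_model(
    "ContributionNodeClass",
    data=(
        List[ContributionItemClass],
        Field(description="One entry per contribution."),
    ),
)

print(ContributionNodeClass.schema_json(indent=2))
```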
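
The `where` filter is the least obvious branch. For a string `extract_from`,
each `|`-separated subquery becomes one `sec_name` condition, and two or more
conditions fold into an `$or`. Spelled out for `"introduction|conclusion"`
(`VectorDBHierarchy` is the project-internal enum this patch imports; the
dict shape is copied from `from_dict`):

```python
from memory.llm_memory import VectorDBHierarchy  # project-internal import


def section_condition(subquery: str) -> dict:
    # One sec_name lookup per "|"-separated subquery, as built in from_dict.
    return {
        "sec_name": {
            "$in": {
                "conditions": {"type": VectorDBHierarchy.FirstLayer.value},
                "return": "documents",
                "subquery": subquery,
                "result_num": 1,
            }
        }
    }


# extract_from = "introduction|conclusion" -> two subqueries -> "$or"
where = {
    "conditions": {
        "$or": [section_condition("introduction"), section_condition("conclusion")]
    },
    "return": "all",
    "result_num": -1,
    "subquery": "{slot}",
}
```

A list-valued `extract_from` instead filters `section` against
`["paper_meta", "abstract"]` plus the given section names.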
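
Finally, `PaperInspector` is only introduced here as a thin holder of the
inner graph; how it gets wired up is not part of this diff. The lines below
are an assumption about the intended use, with placeholder model objects:

```python
# Hypothetical wiring (names from this patch; llm_model, parser_model, and
# embeddings_model are placeholder instances, not defined in this diff).
inner_graph = create_inspector_graph(
    workflow,
    llm_model,
    parser_model,
    embeddings_model,
    max_token_size=8192,
    enable_streaming=False,
)
inspector = PaperInspector("PaperInspector", inner_graph)
```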