Commit: refactor workflow

longbinlai committed Nov 17, 2024
1 parent 9b131cc commit 6b4852e
Showing 3 changed files with 187 additions and 113 deletions.
13 changes: 0 additions & 13 deletions python/graphy/graph/nodes/inspector_node.py

This file was deleted.

183 changes: 182 additions & 1 deletion python/graphy/graph/nodes/paper_reading_nodes.py
@@ -1,9 +1,15 @@
from .base_node import BaseNode, NodeType, NodeCache
from .chain_node import BaseChainNode, DataGenerator
from .pdf_extract_node import PDFExtractNode
from graph.base_graph import BaseGraph
from graph.base_edge import BaseEdge
from memory.llm_memory import VectorDBHierarchy
from prompts import TEMPLATE_ACADEMIC_RESPONSE
from config import WF_STATE_MEMORY_KEY

from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
from langchain_core.language_models.llms import BaseLLM
from langchain_core.embeddings import Embeddings

from typing import Any, Dict, List, Generator
import logging
@@ -26,6 +32,54 @@ class NameDescListFormat(BaseModel):
)


def create_inspector_graph(
graph_dict: Dict[str, Any],
llm_model: BaseLLM,
parser_model: BaseLLM,
embeddings_model: Embeddings,
max_token_size: int = 8192,
enable_streaming: bool = False,
) -> BaseGraph:
nodes_dict = {}
nodes = []
edges = []
start_node = "Paper"

for node in graph_dict["nodes"]:
if node["name"] == start_node: # node_0 = pdf_extract
nodes_dict[node["name"]] = PDFExtractNode(
embeddings_model,
start_node,
)
else:
extract_node = ExtractNode.from_dict(
node,
llm_model,
parser_model,
max_token_size,
enable_streaming,
)
nodes_dict[node["name"]] = extract_node

for _, value in nodes_dict.items():
nodes.append(value)
for edge in graph_dict["edges"]:
edges.append(BaseEdge(edge["source"], edge["target"]))
if edge["source"] != start_node:
nodes_dict[edge["target"]].add_dependent_node(edge["source"])

graph = BaseGraph()
# Add all nodes
for node in nodes:
graph.add_node(node)

# Add all edges
for edge in edges:
graph.add_edge(edge)

return graph
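
For reference, a minimal sketch of the input this builder consumes, assuming the node/edge shape used elsewhere in this commit; llm_model, parser_model, and embeddings_model stand in for real model instances:

# Illustrative example -- not part of this commit.
graph_dict = {
    "nodes": [
        {"name": "Paper"},  # the entry node, which becomes the PDFExtractNode
        {
            "name": "Contribution",
            "query": "List the main contributions of the paper.",
            "extract_from": ["introduction"],
            "output_schema": {
                "type": "array",
                "description": "All contributions claimed by the paper",
                "item": [
                    {
                        "name": "summary",
                        "type": "string",
                        "description": "One contribution, in one sentence",
                    }
                ],
            },
        },
    ],
    "edges": [{"source": "Paper", "target": "Contribution"}],
}

graph = create_inspector_graph(
    graph_dict, llm_model, parser_model, embeddings_model
)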


class ExtractNode(BaseChainNode):
def __init__(
self,
@@ -54,6 +108,127 @@ def __init__(
self.query_dependency = ""
self.dependent_nodes = []

@classmethod
def from_dict(
cls,
node_dict: Dict[str, Any],
llm_model: BaseLLM,
parser_model: BaseLLM,
max_token_size: int = 8192,
enable_streaming: bool = False,
) -> "ExtractNode":
"""
Creates an ExtractNode instance from a dictionary.
Args:
node_dict: Dictionary containing node configuration.
llm_model: The LLM model to be used.
parser_model: The parser model to be used.
max_token_size: Maximum token size for the node.
enable_streaming: Flag to enable or disable streaming.
Returns:
ExtractNode: An initialized ExtractNode instance.
"""
# Build output schema
items = {}
for item in node_dict["output_schema"]["item"]:
item_type = item["type"]
if item_type == "string":
items[item["name"]] = (str, Field(description=item["description"]))
elif item_type == "int":
items[item["name"]] = (int, Field(description=item["description"]))
else:
raise ValueError(f"Unsupported type: {item_type}")

ItemClass = create_model(node_dict["name"] + "ItemClass", **items)

# Determine the type of data (array or single)
if node_dict["output_schema"]["type"] == "array":
item_type = {
"data": (
List[ItemClass],
Field(description=node_dict["output_schema"]["description"]),
)
}
elif node_dict["output_schema"]["type"] == "single":
item_type = {
"data": (
ItemClass,
Field(description=node_dict["output_schema"]["description"]),
)
}
else:
raise ValueError(
f"Unsupported output schema type: {node_dict['output_schema']['type']}"
)

NodeClass = create_model(node_dict["name"] + "NodeClass", **item_type)

# Build `where` conditions
extract_from = node_dict.get("extract_from")
where = None
if isinstance(extract_from, str):
where_conditions = extract_from.split("|")
condition_dict = (
{
"sec_name": {
"$in": {
"conditions": {"type": VectorDBHierarchy.FirstLayer.value},
"return": "documents",
"subquery": where_conditions[0],
"result_num": 1,
}
}
}
if len(where_conditions) == 1
else {
"$or": [
{
"sec_name": {
"$in": {
"conditions": {
"type": VectorDBHierarchy.FirstLayer.value
},
"return": "documents",
"subquery": condition,
"result_num": 1,
}
}
}
for condition in where_conditions
]
}
)
where = {
"conditions": condition_dict,
"return": "all",
"result_num": -1,
"subquery": "{slot}",
}
elif isinstance(extract_from, list):
where = {
"conditions": {
"section": {"$in": ["paper_meta", "abstract"] + extract_from}
},
"return": "all",
"result_num": -1,
"subquery": "{slot}",
}

# Create and return the ExtractNode instance
return cls(
node_name=node_dict["name"],
llm=llm_model,
parser_llm=parser_model,
output_format=NodeClass,
input_query=node_dict["query"],
max_token_size=max_token_size,
enable_streaming=enable_streaming,
block_config=None,
where=where,
)
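
A node can also be built directly. A sketch, assuming the "single" schema type and the "|"-separated string form of extract_from, both handled by the branches above:

# Illustrative example -- not part of this commit.
node_dict = {
    "name": "Dataset",
    "query": "Which dataset does the paper evaluate on?",
    # Two section-name subqueries, OR-ed together by the string branch above.
    "extract_from": "experiment|evaluation",
    "output_schema": {
        "type": "single",
        "description": "The primary evaluation dataset",
        "item": [
            {"name": "name", "type": "string", "description": "Dataset name"},
            {"name": "size", "type": "int", "description": "Number of samples"},
        ],
    },
}

node = ExtractNode.from_dict(node_dict, llm_model, parser_model)
# node.where now carries an "$or" of two sec_name conditions, with
# "{slot}" left as the outer subquery placeholder.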

def add_dependent_node(self, dependent_node):
self.dependent_nodes.append(dependent_node)
self.query_dependency = (
@@ -85,3 +260,9 @@ def execute(
# Format query for paper reading
self.query = TEMPLATE_ACADEMIC_RESPONSE.format(user_query=self.query)
yield from super().execute(state, _input)


class PaperInspector(BaseNode):
def __init__(self, name: str, graph: BaseGraph):
super().__init__(name, NodeType.INSPECTOR)
self.graph = graph
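
Wiring the pieces together might look like the following sketch; the node name "PaperInspector" and the surrounding model setup are assumptions, not part of this commit:

# Illustrative example -- not part of this commit.
graph = create_inspector_graph(
    graph_dict, llm_model, parser_model, embeddings_model
)
inspector = PaperInspector("PaperInspector", graph)
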
104 changes: 5 additions & 99 deletions python/graphy/workflows/survey_paper_reading.py
@@ -112,121 +112,27 @@ def _create_graph(self, workflow):
}
)
else:
items = {}
for item in node["output_schema"]["item"]:
if item["type"] == "string":
items[item["name"]] = (
str,
Field(description=item["description"]),
)
elif item["type"] == "int":
items[item["name"]] = (
int,
Field(description=item["description"]),
)
else:
pass
# ERROR
ItemClass = create_model(node["name"] + "ItemClass", **items)
if node["output_schema"]["type"] == "array":
item_type = {
"data": (
List[ItemClass],
Field(description=node["output_schema"]["description"]),
),
}
elif node["output_schema"]["type"] == "single":
item_type = {
"data": (
ItemClass,
Field(description=node["output_schema"]["description"]),
),
}
else:
pass
# ERROR
NodeClass = create_model(node["name"] + "NodeClass", **item_type)

if not node["extract_from"]:
where = None
elif type(node["extract_from"]) is str:
where_conditions = node["extract_from"].split("|")
condition_dict = {}
if len(where_conditions) == 1:
condition_dict = {
"sec_name": {
"$in": {
"conditions": {
"type": VectorDBHierarchy.FirstLayer.value
},
"return": "documents",
"subquery": where_conditions[0],
"result_num": 1,
}
}
}
else:
condition_dict["$or"] = []
for condition in where_conditions:
condition_dict["$or"].append(
{
"sec_name": {
"$in": {
"conditions": {
"type": VectorDBHierarchy.FirstLayer.value
},
"return": "documents",
"subquery": condition,
"result_num": 1,
}
}
}
)

where = {
"conditions": condition_dict,
"return": "all",
"result_num": -1,
"subquery": "{slot}",
}
elif type(node["extract_from"]) is list:
where = {
"conditions": {
"section": {
"$in": ["paper_meta", "abstract"] + node["extract_from"]
}
},
"return": "all",
"result_num": -1,
"subquery": "{slot}",
}
else:
pass

nodes_dict[node["name"]] = ExtractNode(
node["name"],
extract_node = ExtractNode.from_dict(
node,
self.llm_model,
self.parser_model,
NodeClass,
node["query"],
self.max_token_size,
self.enable_streaming,
None,
where,
)
nodes_dict[node["name"]] = extract_node
output_dict["nodes"].append(
{
"node_name": "ExtractNode",
"input": (
node["name"],
"[llm_model]",
"[parser_model]",
NodeClass,
extract_node.json_format,
node["query"],
"[max_token_size]",
"[enable_streaming]",
None,
where,
extract_node.where,
),
}
)
