Skip to content

Commit

Permalink
add workflow executor
Browse files Browse the repository at this point in the history
  • Loading branch information
longbinlai committed Nov 18, 2024
1 parent e634301 commit 7805b5a
Show file tree
Hide file tree
Showing 10 changed files with 397 additions and 794 deletions.
27 changes: 23 additions & 4 deletions python/graphy/graph/base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,35 @@ class BaseGraph:
"""

def __init__(self):
    """Initialize an empty graph with no nodes or edges."""
    # Node names are tracked in a separate insertion-ordered list so the
    # graph exposes a stable, cheap ordering of its nodes.
    self.node_names = []
    self.adjacency_list: Dict[str, List[str]] = {}
    self.nodes: Dict[str, BaseNode] = {}
    self.edges: Dict[str, BaseEdge] = {}

def get_node_names(self) -> List[str]:
    """Returns a list of all node names in the graph.

    Returns:
        List[str]: node names in insertion order.
    """
    # Backed by the explicit ``node_names`` list rather than
    # ``self.nodes.keys()`` so ordering is guaranteed and no copy is built.
    return self.node_names

def get_first_node_name(self) -> None | str:
"""Returns the name of the first node in the graph."""
if not self.node_names:
return None
else:
return self.node_names[0]

def get_first_node(self) -> None | BaseNode:
    """Return the first node object in the graph, or ``None`` if empty."""
    name = self.get_first_node_name()
    # NOTE(review): a falsy (empty-string) name is treated the same as
    # "no nodes" here — presumably node names are always non-empty.
    if not name:
        return None
    return self.get_node(name)

def add_node(self, node: BaseNode):
    """Register *node* in the graph under its name.

    Raises:
        ValueError: if a node with the same name is already present.
    """
    key = node.name
    if key in self.nodes:
        raise ValueError(f"Node {node.name} already exists.")
    # Keep the ordered name list, the node map and the (initially empty)
    # adjacency entry in sync.
    self.node_names.append(key)
    self.nodes[key] = node
    self.adjacency_list[key] = []

Expand All @@ -37,6 +54,7 @@ def remove_node_by_name(self, node_name: str):
self.edges.pop(edge_name, None)
del self.adjacency_list[node_name]
del self.nodes[node_name]
self.node_names.remove(node_name)

# Remove any edges pointing to this node
for source, edges in self.adjacency_list.items():
Expand Down Expand Up @@ -65,14 +83,15 @@ def get_adjacent_edges(self, node_name: str) -> List[BaseEdge]:
"""Returns a list of edges adjacent to a given node."""
return [self.edges[edge] for edge in self.adjacency_list.get(node_name, [])]

def get_adjacent_nodes(self, node_name: str) -> List[BaseNode]:
    """Returns a list of nodes adjacent to a given node.

    Args:
        node_name: name of the source node.

    Returns:
        The target ``BaseNode`` of each outgoing edge of *node_name*;
        empty when the node is unknown or has no outgoing edges.
    """
    # Resolve each outgoing edge's target name to its node object.
    return [
        self.get_node(self.edges[edge].target)
        for edge in self.adjacency_list.get(node_name, [])
    ]

def nodes_count(self) -> int:
    """Return the number of nodes currently in the graph."""
    # ``node_names`` is kept in lockstep with ``self.nodes`` by
    # add_node/remove_node_by_name, so its length is the node count.
    return len(self.node_names)

def edges_count(self) -> int:
    """Return how many edges the graph currently holds."""
    return len(self.edges)
Expand Down
6 changes: 0 additions & 6 deletions python/graphy/graph/nodes/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,6 @@ def __init__(self, name: str, node_type: NodeType = NodeType.BASE):
self.name = name
self.node_type = node_type

def set_name(self, new_name):
    """Rename this node in place to *new_name*."""
    self.name = new_name

def get_node_key(self) -> str:
    """Return the key identifying this node (its name)."""
    return self.name

def pre_execute(self, state: Dict[str, Any] = None):
"""define pre-execution logic"""
logger.info(f"Executing node: {self.get_node_key()}")
Expand Down
122 changes: 74 additions & 48 deletions python/graphy/graph/nodes/paper_reading_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,10 @@ def backpedal(self):
self.completed -= 1

def __str__(self):
    """Compact ``"<number>, <completed>"`` rendering for progress output."""
    return f"{self.number}, {self.completed}"

def __repr__(self) -> str:
    """Verbose, labeled representation intended for debugging/logging."""
    return f"ProgressInfo [ Number: {self.number}, Completed: {self.completed} ]"


def create_inspector_graph(
Expand Down Expand Up @@ -330,6 +333,7 @@ def execute(

# Format query for paper reading
self.query = TEMPLATE_ACADEMIC_RESPONSE.format(user_query=self.query)

yield from super().execute(state, _input)


Expand All @@ -347,18 +351,21 @@ def __init__(
self.llm_model = llm_model
self.embeddings_model = embeddings_model
self.persist_store = persist_store
self.progress = {"total": ProgressInfo(self.graph.nodes_count(), 0)}
self.progress = {"total": ProgressInfo(0, 0)}
for node in self.graph.get_node_names():
self.progress[node] = ProgressInfo(1, 0)
self.progress[node] = ProgressInfo(0, 0)

def get_progress(self) -> ProgressInfo:
    """Return the aggregate progress entry covering all workflow nodes."""
    return self.progress["total"]

def run_through(
self,
input_data: DataType,
data_id,
state,
continue_on_error: bool = True,
is_persist: bool = True,
skipped_nodes: List[str] = [],
) -> DataGenerator:
):
"""
Runs through the workflow and executes all nodes.
Expand All @@ -371,33 +378,7 @@ def run_through(
DataGenerator: Outputs generated by the workflow nodes.
"""

logger.info(f"Executing {self.name} for input data: {input_data}")

paper_file_path = input_data.get("paper_file_path", None)
if not paper_file_path:
logger.error("No 'paper_file_path' provided in input data.")
yield {}
return

# Initialize the paper extractor and other components
pdf_extractor = PaperExtractor(paper_file_path)
base_name = pdf_extractor.get_meta_data().get("title", "").lower()
if not base_name: # If no title, fallback to filename
base_name = os.path.basename(paper_file_path).split(".")[0]
data_id = process_id(base_name)
pdf_extractor.set_img_path(f"{WF_IMAGE_DIR}/{data_id}")
state[data_id] = {
WF_STATE_CACHE_KEY: {},
WF_STATE_EXTRACTOR_KEY: pdf_extractor,
WF_STATE_MEMORY_KEY: PaperReadingMemoryManager(
self.llm_model.model,
self.embeddings_model,
data_id,
self.llm_model.context_size,
),
}

current_node_name = "Paper"
current_node_name = self.graph.get_first_node_name()
next_nodes = [current_node_name]

while next_nodes:
Expand All @@ -410,22 +391,20 @@ def run_through(
curr_node = self.graph.get_node(current_node_name)
if not curr_node:
logger.error(f"Node '{current_node_name}' not found in the graph.")
yield {}
continue

last_output = None
try:
curr_node.pre_execute(state)
# Execute the current node
output_generator = curr_node.execute(state[data_id])
output_generator = curr_node.execute(state)
for output in output_generator:
last_output = output
yield last_output # Yield each output from the node execution
except Exception as e:
logger.error(f"Error executing node '{current_node_name}': {e}")
if continue_on_error:
yield {} # Yield empty result in case of error if continue_on_error is True
continue
else:
print("raise error")
raise ValueError(f"Error executing node '{current_node_name}': {e}")
finally:
# Complete progress tracking
Expand All @@ -435,29 +414,33 @@ def run_through(
# Persist the output and queries if applicable
if last_output and is_persist and self.persist_store:
self.persist_store.save_state(
data_id, curr_node.get_node_key(), last_output
data_id, current_node_name, last_output
)
if curr_node.get_query():
input_query = f"**************QUERY***************:\n{curr_node.get_query()}\n**************MEMORY**************:\n{curr_node.get_memory()}"
input_query = f"""
**************QUERY***************:
{curr_node.get_query()}
**************MEMORY**************:
{curr_node.get_memory()}
"""
self.persist_store.save_query(
data_id, curr_node.get_node_key(), input_query
data_id, current_node_name, input_query
)

# Cache the output
if last_output:
node_caches: dict = state[data_id].get(WF_STATE_CACHE_KEY, {})
node_key = curr_node.get_node_key()
node_caches: dict = state.get(WF_STATE_CACHE_KEY, {})
node_cache: NodeCache = node_caches.setdefault(
node_key, NodeCache(node_key)
current_node_name, NodeCache(current_node_name)
)
if node_cache:
node_cache.add_chat_cache("", last_output)

curr_node.post_execute(last_output)

# Add adjacent nodes to the processing queue
for next_node_name in reversed(
self.graph.get_adjacent_nodes(current_node_name)
):
next_nodes.append(next_node_name)
for next_node in reversed(self.graph.get_adjacent_nodes(current_node_name)):
next_nodes.append(next_node.name)

@profiler.profile
def execute(
Expand All @@ -475,4 +458,47 @@ def execute(
"""

for input_data in input:
yield self.run_through(input_data, state)

paper_file_path = input_data.get("paper_file_path", None)
logger.info(f"Executing {self.name} for paper: {paper_file_path}")

if not paper_file_path:
logger.error("No 'paper_file_path' provided in input data.")
continue

try:
# Initialize the paper extractor and other components
pdf_extractor = PaperExtractor(paper_file_path)
base_name = pdf_extractor.get_meta_data().get("title", "").lower()
if not base_name: # If no title, fallback to filename
base_name = os.path.basename(paper_file_path).split(".")[0]
data_id = process_id(base_name)
pdf_extractor.set_img_path(f"{WF_IMAGE_DIR}/{data_id}")
state[data_id] = {
WF_STATE_CACHE_KEY: {},
WF_STATE_EXTRACTOR_KEY: pdf_extractor,
WF_STATE_MEMORY_KEY: PaperReadingMemoryManager(
self.llm_model.model,
self.embeddings_model,
data_id,
self.llm_model.context_size,
),
}
except Exception as e:
logger.error(f"Error initializing PaperExtractor: {e}")
continue

self.progress["total"].add(ProgressInfo(self.graph.nodes_count(), 0))
for node in self.graph.get_node_names():
self.progress[node].add(ProgressInfo(1, 0))

self.run_through(data_id, state[data_id])
first_node_name = self.graph.get_first_node_name()

# Debugging
response = state[data_id][WF_STATE_CACHE_KEY][
first_node_name
].get_response()

# Ensure the correct response is yielded
yield response
64 changes: 19 additions & 45 deletions python/graphy/tests/workflow/paper_inspector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest
from unittest.mock import MagicMock, create_autospec
from graph import BaseGraph
from graph import BaseGraph, BaseEdge
from graph.nodes import PaperInspector, BaseNode
from graph.nodes.paper_reading_nodes import ProgressInfo, ExtractNode
from graph.types import DataGenerator
Expand All @@ -22,9 +22,16 @@ def mock_graph():
mock_node.name = "Paper"
mock_node.get_query.return_value = ""
mock_node.get_memory.return_value = ""
mock_node.get_node_key.return_value = "Paper"
mock_node.execute.return_value = iter([{"result": "node_output"}])

mock_node2 = create_autospec(ExtractNode, instance=True)
mock_node2.name = "Extract"
mock_node2.get_query.return_value = ""
mock_node2.get_memory.return_value = ""
mock_node2.execute.return_value = iter([{"result": "node_output2"}])
graph.add_node(mock_node)
graph.add_node(mock_node2)
graph.add_edge(BaseEdge(source="Paper", target="Extract"))
return graph


Expand Down Expand Up @@ -56,60 +63,27 @@ def test_initialization(mock_paper_inspector, mock_graph):
assert mock_paper_inspector.persist_store is not None


def test_run_through(mock_paper_inspector):
def test_execute(mock_paper_inspector):
"""
Test the run_through method for processing a simple workflow.
Test the execute method for processing input generator.
"""
input_data = {"paper_file_path": "inputs/samples/graphrag.pdf"}
input_data = [{"paper_file_path": "inputs/samples/graphrag.pdf"}]
state = {}
output = list(mock_paper_inspector.run_through(input_data, state))
input_gen = (item for item in input_data) # Create a generator from input_data

# Ensure the expected results are returned
assert len(output) == 1
assert output[0] == {"result": "node_output"}
output_gen = mock_paper_inspector.execute(state, input_gen)
# Fully consume the generator
outputs = list(output_gen)

# Verify state and progress updates
data_id = list(state.keys())[0]
assert WF_STATE_CACHE_KEY in state[data_id]
assert WF_STATE_EXTRACTOR_KEY in state[data_id]
assert WF_STATE_MEMORY_KEY in state[data_id]
assert mock_paper_inspector.progress["Paper"].completed == 1
assert mock_paper_inspector.progress["total"].completed == 1
assert mock_paper_inspector.progress["Extract"].completed == 1
assert mock_paper_inspector.progress["total"].completed == 2
cached_response = state[data_id][WF_STATE_CACHE_KEY]["Paper"].get_response()
assert cached_response == {"result": "node_output"}


def test_execute(mock_paper_inspector):
"""
Test the execute method for processing input generator.
"""
input_data = [{"paper_file_path": "inputs/samples/graphrag.pdf"}]
state = {}
input_gen = (item for item in input_data) # Create a generator from input_data

output_gen = mock_paper_inspector.execute(state, input_gen)
# Fully consume the generator
outputs = [list(inner_gen) for inner_gen in output_gen]

assert len(outputs) == 1
assert len(outputs[0]) == 1 # `run_through` yields single outputs
assert outputs[0][0] == {"result": "node_output"}


def test_run_through_with_error(mock_paper_inspector):
"""
Test the run_through method for error handling.
"""
mock_paper_inspector.graph.get_node("Paper").execute.side_effect = Exception(
"Test error"
)

input_data = {"paper_file_path": "inputs/samples/graphrag.pdf"}
state = {}

# Test with continue_on_error=False
with pytest.raises(ValueError, match="Error executing node 'Paper': Test error"):
output_gen = mock_paper_inspector.run_through(
input_data, state, continue_on_error=False
)
[list(inner_gen) for inner_gen in output_gen]
assert outputs[0] == {"result": "node_output"}
6 changes: 3 additions & 3 deletions python/graphy/workflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""
__init__.py for the graphs package
__init__.py for the workflow package
"""

from .abstract_workflow import AbstractWorkflow
from .base_workflow import BaseWorkflow
from .survey_paper_reading import SurveyPaperReading

__all__ = ["AbstractWorkflow", "SurveyPaperReading"]
__all__ = ["BaseWorkflow", "SurveyPaperReading"]
Loading

0 comments on commit 7805b5a

Please sign in to comment.