simplify workflow
longbinlai committed Nov 20, 2024
1 parent 2b87c47 commit 04de028
Showing 3 changed files with 88 additions and 160 deletions.
3 changes: 1 addition & 2 deletions python/graphy/graph/nodes/paper_reading_nodes.py
@@ -532,8 +532,7 @@ def execute(
self.persist_store.save_state(data_id, "_DONE", {"done": True})
# clean state
state[data_id][WF_STATE_CACHE_KEY].clear()
state[data_id][WF_STATE_EXTRACTOR_KEY].clear_memory()
state[data_id][WF_STATE_MEMORY_KEY].clear()
state[data_id][WF_STATE_MEMORY_KEY].clear_memory()
state.pop(data_id)

yield self.persist_store.get_state(data_id, first_node_name)
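The cleanup above now releases per-paper memory with a single clear_memory() call instead of clearing the extractor and memory separately. A minimal, self-contained sketch of that pattern, assuming an illustrative Memory stand-in and key values (none of these are the project's real types):

# Illustrative sketch of the per-paper cleanup pattern; Memory and the key
# values below are stand-ins, not the project's actual types.
from typing import Any, Dict, List

WF_STATE_CACHE_KEY = "cache"    # assumed key value
WF_STATE_MEMORY_KEY = "memory"  # assumed key value


class Memory:
    """Stand-in for the workflow's memory manager."""

    def __init__(self) -> None:
        self._items: List[Any] = []

    def add(self, item: Any) -> None:
        self._items.append(item)

    def clear_memory(self) -> None:
        # Release everything the memory manager holds.
        self._items.clear()


def cleanup(state: Dict[str, Dict[str, Any]], data_id: str) -> None:
    """Drop cached results, release memory, then remove the per-paper entry."""
    entry = state[data_id]
    entry[WF_STATE_CACHE_KEY].clear()
    entry[WF_STATE_MEMORY_KEY].clear_memory()
    state.pop(data_id)


state = {"paper-1": {WF_STATE_CACHE_KEY: {"summary": "..."}, WF_STATE_MEMORY_KEY: Memory()}}
cleanup(state, "paper-1")
assert "paper-1" not in state
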
7 changes: 6 additions & 1 deletion python/graphy/workflow/__init__.py
@@ -6,4 +6,9 @@
from .survey_paper_reading import SurveyPaperReading
from .executor import ThreadPoolWorkflowExecutor

__all__ = ["BaseWorkflow", "SurveyPaperReading", "ThreadPoolWorkflowExecutor"]
__all__ = [
"BaseWorkflow",
"SurveyPaperReading",
"ThreadPoolWorkflowExecutor",
"RayWorkflowExecutor",
]
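With the expanded __all__, downstream code can import the executors directly from the workflow package. A hedged usage note, assuming the package-level import path below and that a matching RayWorkflowExecutor import exists elsewhere in this __init__.py:

# Assumed import path based on the repository layout (python/graphy).
from workflow import BaseWorkflow, SurveyPaperReading, ThreadPoolWorkflowExecutor
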
238 changes: 81 additions & 157 deletions python/graphy/workflow/executor.py
@@ -1,27 +1,26 @@
from abc import ABC, abstractmethod
from typing import Dict, Any, Union

from concurrent.futures import ThreadPoolExecutor, as_completed
from queue import Queue
from typing import Dict, Any, List, Generator

from concurrent.futures import ThreadPoolExecutor, Future
from queue import Queue, Empty
from typing import List, Dict, Any
from graph.types import DataType, DataGenerator
from graph.nodes import BaseNode, NodeType
from graph.edges import BaseEdge

import asyncio
import threading
import logging

logger = logging.getLogger(__name__)


class WorkflowExecutor(ABC):
"""
Abstract class for a Workflow Executor.
Abstract base class for Workflow Executors.
"""

def __init__(self, workflow):
def __init__(self, workflow, max_inspectors: int = 100):
self.workflow = workflow
self.max_inspectors = max_inspectors
self.processed_inspectors = 0
self.state = workflow.state

@abstractmethod
def execute(self, initial_inputs: List[DataType]):
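The base class now carries max_inspectors and the shared workflow state, and the hunk below drops the per-node and per-edge abstract hooks, so execute() is the only method a subclass must provide. A hedged sketch of a custom executor under this simplified interface (the SerialWorkflowExecutor name and its scheduling loop are illustrative assumptions, not part of this commit):

# Illustrative subclass: runs every node and edge sequentially in the calling
# thread. Only graph/node/edge calls that appear in this module are used.
from typing import Any, List


class SerialWorkflowExecutor(WorkflowExecutor):
    """Hypothetical executor that processes tasks one at a time."""

    def execute(self, initial_inputs: List[Any]):
        first_node = self.workflow.graph.get_first_node()
        pending = [(first_node, data) for data in initial_inputs]
        while pending:
            node, data = pending.pop(0)
            for result in node.execute(self.state, iter([data])):
                for edge in self.workflow.graph.get_adjacent_edges(node.name):
                    # Edges transform a result before it reaches the target node.
                    for out in edge.execute(self.state, iter([result])):
                        pending.append(
                            (self.workflow.graph.get_node(edge.target), out)
                        )
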
@@ -36,192 +35,117 @@ def execute(self, initial_inputs: List[DataType]):
"""
pass

@abstractmethod
def _execute_node_task(self, node, input_data: DataType, state: Dict[str, Any]):
"""
Execute a node.
Args:
node: The node to execute.
input_data (DataType): Input data for the node.
state (Dict[str, Any]): Workflow state.
"""
pass

@abstractmethod
def _execute_edge_task(self, edge, input_data: DataType, state: Dict[str, Any]):
"""
Execute an edge.
Args:
edge: The edge to execute.
input_data (DataType): Input data for the edge.
state (Dict[str, Any]): Workflow state.
"""
pass


class Task:
"""
Represents a unit of work in the workflow execution.
Attributes:
input (DataGenerator): The input data for the task.
executor (Union[BaseNode, BaseEdge, str]): The executor responsible for processing the task.
Can be a BaseNode, or a BaseEdge.
executor (Union[BaseNode, BaseEdge]): The executor responsible for processing the task.
"""

def __init__(self, input: DataGenerator, executor: Union[BaseNode, BaseEdge]):
def __init__(self, input: DataGenerator, executor: Any):
self.input = input # Input data for the task
self.executor = executor # Executor (node/edge) for this task

def __repr__(self):
return f"Task(executor={self.executor}, data={list(self.data)[:5]}...)" # Limit repr for large data
return f"Task(executor={self.executor})"


class ThreadPoolWorkflowExecutor(WorkflowExecutor):
"""
WorkflowExecutor implementation using ThreadPoolExecutor with asynchronous execution.
WorkflowExecutor implementation using ThreadPoolExecutor with parallel execution.
"""

def __init__(
self,
workflow,
max_workers: int = 4,
queue_size: int = 100,
max_inspectors: int = 100,
):
super().__init__(workflow)
self.task_queue = asyncio.Queue(queue_size)
def __init__(self, workflow, max_workers: int = 4, max_inspectors: int = 100):
super().__init__(workflow, max_inspectors)
self.task_queue = Queue()
self.executor = ThreadPoolExecutor(max_workers=max_workers)
self.max_inspectors = max_inspectors
self.lock = threading.Lock()
self.active_futures = set()

async def execute(self, initial_inputs: List[DataType]):
def execute(self, initial_inputs: List[DataType]):
"""
Execute the workflow starting with the initial inputs.
Args:
initial_inputs (List[DataType]): The initial input data.
Returns:
Dict[str, Any]: Final state after workflow execution.
"""
state = self.workflow.state

# Add all initial inputs to the task queue with the first node
# Add initial tasks to the queue
first_node = self.workflow.graph.get_first_node()
if not first_node:
raise ValueError("No nodes found in the workflow graph.")
for input_data in initial_inputs:
await self.task_queue.put(Task(iter([input_data]), first_node))

# Start processing tasks asynchronously
tasks = [
asyncio.create_task(self._process_task(state))
for _ in range(self.executor._max_workers)
]

# Wait for all tasks to complete
await self.task_queue.join()
await asyncio.gather(*tasks)

# Shutdown the executor
self.executor.shutdown(wait=True)

async def _process_task(self, state: Dict[str, Any]):
processed_inspectors = 0
"""
Process a single task from the task queue.
Args:
state (Dict[str, Any]): The current workflow state.
"""
while not self.task_queue.empty():
task = await self.task_queue.get()
logger.debug(f"======= GET TASK {task.executor} ========")

if isinstance(task.executor, BaseNode): # Node
node = task.executor
logger.info(f"Executing node: {node}")
if not node:
raise ValueError(f"Node not found in the graph.")

# Execute the node in the thread pool
downstream_tasks = await asyncio.to_thread(
self._execute_node_task, node, task.input, state
)

# Add downstream tasks to the queue
for down_task in downstream_tasks:
logger.debug(f"ADD EDGE TASK {down_task}")
await self.task_queue.put(down_task)

elif isinstance(task.executor, BaseEdge): # Edge
edge = task.executor
logger.info(f"Executing edge: {edge.name}")

# Execute the edge in the thread pool
downstream_tasks = await asyncio.to_thread(
self._execute_edge_task, edge, task.input, state
)

# Add downstream tasks to the queue
for down_task in downstream_tasks:
await self.task_queue.put(down_task)
else:
raise ValueError(f"Invalid executor type: {task.executor}")

# Mark the task as done
logger.debug(f"======= FINISH TASK {task.executor} ========")
self.task_queue.task_done()

def _execute_node_task(self, node, input_gen, state):
for input_data in initial_inputs:
self.task_queue.put(Task(iter([input_data]), first_node))

try:
while not self.task_queue.empty() or self.active_futures:
# Submit tasks to executor
while not self.task_queue.empty():
task = self.task_queue.get()
future = self.executor.submit(self._process_task, task)
self.active_futures.add(future)
future.add_done_callback(self._on_task_complete)

# Wait for any task to complete
completed_futures = [f for f in self.active_futures if f.done()]
for future in completed_futures:
self.active_futures.remove(future)

finally:
self.executor.shutdown(wait=True)

def _on_task_complete(self, future: Future):
"""
Execute a node task and generate downstream tasks.
Args:
node (BaseNode): The node to execute.
input_gen (DataGenerator): The input generator.
state (Dict[str, Any]): The current workflow state.
Returns:
List[Tuple[DataGenerator, BaseEdge]]: Downstream tasks for adjacent edges.
Callback for task completion to handle downstream tasks.
"""
logger.debug(f"================== EXECUTE Node {node} =================")
results = node.execute(state, input_gen)
downstream_tasks = []

for result in results:
# Create tasks for all adjacent edges
for edge in self.workflow.graph.get_adjacent_edges(node.name):
downstream_tasks.append(Task(iter([result]), edge))

return downstream_tasks

def _execute_edge_task(self, edge, input_gen, state):
try:
downstream_tasks = future.result()
for task in downstream_tasks:
self.task_queue.put(task)
except Exception as e:
logger.error(f"Task failed with error: {e}")

def _process_task(self, task: Task) -> List[Task]:
"""
Execute an edge task and generate downstream tasks.
Process a single task.
Args:
edge (BaseEdge): The edge to execute.
input_gen (DataGenerator): The input generator.
state (Dict[str, Any]): The current workflow state.
task (Task): The task to process.
Returns:
List[Tuple[DataGenerator, BaseNode]]: Downstream tasks for target nodes.
List[Task]: Downstream tasks to enqueue.
"""
logger.debug(f"================== EXECUTE Edge {edge} =================")

results = edge.execute(state, input_gen)
downstream_tasks = []

for result in results:
# Create tasks for the target node
logger.debug(
f"#################### ADD TASK {(iter([result]), edge.target)} ###################"
)
downstream_tasks.append(
Task(iter([result]), self.workflow.graph.get_node(edge.target))
)
executor = task.executor

if isinstance(executor, BaseNode): # Node task
logger.info(f"Executing node: {executor}")
results = executor.execute(self.state, task.input)

for result in results:
if executor.node_type == NodeType.INSPECTOR:
with self.lock:
self.processed_inspectors += 1
if self.processed_inspectors >= self.max_inspectors:
logger.info(
f"Reached max inspectors limit '{self.max_inspectors}', stopping execution)"
)
return []
for edge in self.workflow.graph.get_adjacent_edges(executor.name):
downstream_tasks.append(Task(iter([result]), edge))

elif isinstance(executor, BaseEdge): # Edge task
logger.info(f"Executing edge: {executor.name}")
results = executor.execute(self.state, task.input)

for result in results:
target_node = self.workflow.graph.get_node(executor.target)
downstream_tasks.append(Task(iter([result]), target_node))

else:
raise ValueError(f"Invalid executor type: {executor}")

return downstream_tasks
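
Taken together, execute() seeds the queue with one Task per initial input, submits tasks to the thread pool, and _on_task_complete re-enqueues whatever downstream tasks _process_task returns, while the locked inspector counter caps how many inspector results get expanded. A hedged driver sketch follows; build_workflow() and the input dictionary are illustrative assumptions, not part of this diff:

# Hypothetical driver for the reworked executor. build_workflow() stands in for
# however the project constructs a workflow (e.g. SurveyPaperReading), and the
# input dictionary format is assumed.
from workflow import ThreadPoolWorkflowExecutor

workflow = build_workflow()  # assumed helper returning a workflow object
executor = ThreadPoolWorkflowExecutor(
    workflow,
    max_workers=4,       # worker threads pulling node/edge tasks from the pool
    max_inspectors=100,  # cap on inspector-node results before execution stops
)
executor.execute([{"paper": "example.pdf"}])  # seed the queue with initial inputs
print(workflow.state)  # results accumulate in the shared workflow state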
