From 04de028e1f66b0a2a931ac901311eb1e01f247a5 Mon Sep 17 00:00:00 2001
From: "longbin.lailb"
Date: Wed, 20 Nov 2024 11:56:53 +0800
Subject: [PATCH] simplify workflow

---
 .../graphy/graph/nodes/paper_reading_nodes.py |   3 +-
 python/graphy/workflow/__init__.py            |   7 +-
 python/graphy/workflow/executor.py            | 238 ++++++------------
 3 files changed, 88 insertions(+), 160 deletions(-)

diff --git a/python/graphy/graph/nodes/paper_reading_nodes.py b/python/graphy/graph/nodes/paper_reading_nodes.py
index cec8571b6..99d230c4b 100644
--- a/python/graphy/graph/nodes/paper_reading_nodes.py
+++ b/python/graphy/graph/nodes/paper_reading_nodes.py
@@ -532,8 +532,7 @@ def execute(
         self.persist_store.save_state(data_id, "_DONE", {"done": True})
         # clean state
         state[data_id][WF_STATE_CACHE_KEY].clear()
-        state[data_id][WF_STATE_EXTRACTOR_KEY].clear_memory()
-        state[data_id][WF_STATE_MEMORY_KEY].clear()
+        state[data_id][WF_STATE_MEMORY_KEY].clear_memory()
         state.pop(data_id)

         yield self.persist_store.get_state(data_id, first_node_name)
diff --git a/python/graphy/workflow/__init__.py b/python/graphy/workflow/__init__.py
index 745ccf92e..2f07cf4da 100644
--- a/python/graphy/workflow/__init__.py
+++ b/python/graphy/workflow/__init__.py
@@ -6,4 +6,8 @@
 from .survey_paper_reading import SurveyPaperReading
 from .executor import ThreadPoolWorkflowExecutor

-__all__ = ["BaseWorkflow", "SurveyPaperReading", "ThreadPoolWorkflowExecutor"]
+__all__ = [
+    "BaseWorkflow",
+    "SurveyPaperReading",
+    "ThreadPoolWorkflowExecutor",
+]
diff --git a/python/graphy/workflow/executor.py b/python/graphy/workflow/executor.py
index 0d65712ec..53b71712f 100644
--- a/python/graphy/workflow/executor.py
+++ b/python/graphy/workflow/executor.py
@@ -1,15 +1,11 @@
 from abc import ABC, abstractmethod
-from typing import Dict, Any, Union
-
-from concurrent.futures import ThreadPoolExecutor, as_completed
-from queue import Queue
-from typing import Dict, Any, List, Generator
-
+from concurrent.futures import ThreadPoolExecutor, Future
+from queue import Queue, Empty
+from typing import List, Union
 from graph.types import DataType, DataGenerator
 from graph.nodes import BaseNode, NodeType
 from graph.edges import BaseEdge
-
-import asyncio
+import threading
 import logging

 logger = logging.getLogger(__name__)
@@ -17,11 +13,14 @@

 class WorkflowExecutor(ABC):
     """
-    Abstract class for a Workflow Executor.
+    Abstract base class for Workflow Executors.
     """

-    def __init__(self, workflow):
+    def __init__(self, workflow, max_inspectors: int = 100):
         self.workflow = workflow
+        self.max_inspectors = max_inspectors
+        self.processed_inspectors = 0
+        self.state = workflow.state

     @abstractmethod
     def execute(self, initial_inputs: List[DataType]):
         """
         Execute the workflow.

         Args:
             initial_inputs (List[DataType]): The initial input data.

         Returns:
             Dict[str, Any]: Final state after workflow execution.
         """
         pass

-    @abstractmethod
-    def _execute_node_task(self, node, input_data: DataType, state: Dict[str, Any]):
-        """
-        Execute a node.
-
-        Args:
-            node: The node to execute.
-            input_data (DataType): Input data for the node.
-            state (Dict[str, Any]): Workflow state.
-        """
-        pass
-
-    @abstractmethod
-    def _execute_edge_task(self, edge, input_data: DataType, state: Dict[str, Any]):
-        """
-        Execute an edge.
-
-        Args:
-            edge: The edge to execute.
-            input_data (DataType): Input data for the edge.
-            state (Dict[str, Any]): Workflow state.
-        """
-        pass


 class Task:
     """
     Represents a task to be executed in the workflow.

     Attributes:
         input (DataGenerator): The input data for the task.
-        executor (Union[BaseNode, BaseEdge, str]): The executor responsible for processing the task.
-            Can be a BaseNode, or a BaseEdge.
+        executor (Union[BaseNode, BaseEdge]): The executor responsible for processing the task.
     """

     def __init__(self, input: DataGenerator, executor: Union[BaseNode, BaseEdge]):
         self.input = input  # Input data for the task
         self.executor = executor  # Executor (node/edge) for this task

     def __repr__(self):
-        return f"Task(executor={self.executor}, data={list(self.data)[:5]}...)"  # Limit repr for large data
+        return f"Task(executor={self.executor})"


 class ThreadPoolWorkflowExecutor(WorkflowExecutor):
     """
-    WorkflowExecutor implementation using ThreadPoolExecutor with asynchronous execution.
+    WorkflowExecutor implementation using a ThreadPoolExecutor for parallel execution.
     """

-    def __init__(
-        self,
-        workflow,
-        max_workers: int = 4,
-        queue_size: int = 100,
-        max_inspectors: int = 100,
-    ):
-        super().__init__(workflow)
-        self.task_queue = asyncio.Queue(queue_size)
+    def __init__(self, workflow, max_workers: int = 4, max_inspectors: int = 100):
+        super().__init__(workflow, max_inspectors)
+        self.task_queue = Queue()
         self.executor = ThreadPoolExecutor(max_workers=max_workers)
-        self.max_inspectors = max_inspectors
+        self.lock = threading.Lock()
+        self.active_futures = set()

-    async def execute(self, initial_inputs: List[DataType]):
+    def execute(self, initial_inputs: List[DataType]):
         """
         Execute the workflow starting with the initial inputs.

         Args:
             initial_inputs (List[DataType]): The initial input data.
-
-        Returns:
-            Dict[str, Any]: Final state after workflow execution.
         """
-        state = self.workflow.state
-
-        # Add all initial inputs to the task queue with the first node
+        # Add initial tasks to the queue
         first_node = self.workflow.graph.get_first_node()
         if not first_node:
             raise ValueError("No nodes found in the workflow graph.")
-        for input_data in initial_inputs:
-            await self.task_queue.put(Task(iter([input_data]), first_node))
-
-        # Start processing tasks asynchronously
-        tasks = [
-            asyncio.create_task(self._process_task(state))
-            for _ in range(self.executor._max_workers)
-        ]
-
-        # Wait for all tasks to complete
-        await self.task_queue.join()
-        await asyncio.gather(*tasks)
-
-        # Shutdown the executor
-        self.executor.shutdown(wait=True)

-    async def _process_task(self, state: Dict[str, Any]):
-        processed_inspectors = 0
-        """
-        Process a single task from the task queue.
-
-        Args:
-            state (Dict[str, Any]): The current workflow state.
- """ - while not self.task_queue.empty(): - task = await self.task_queue.get() - logger.debug(f"======= GET TASK {task.executor} ========") - - if isinstance(task.executor, BaseNode): # Node - node = task.executor - logger.info(f"Executing node: {node}") - if not node: - raise ValueError(f"Node not found in the graph.") - - # Execute the node in the thread pool - downstream_tasks = await asyncio.to_thread( - self._execute_node_task, node, task.input, state - ) - - # Add downstream tasks to the queue - for down_task in downstream_tasks: - logger.debug(f"ADD EDGE TASK {down_task}") - await self.task_queue.put(down_task) - - elif isinstance(task.executor, BaseEdge): # Edge - edge = task.executor - logger.info(f"Executing edge: {edge.name}") - - # Execute the edge in the thread pool - downstream_tasks = await asyncio.to_thread( - self._execute_edge_task, edge, task.input, state - ) - - # Add downstream tasks to the queue - for down_task in downstream_tasks: - await self.task_queue.put(down_task) - else: - raise ValueError(f"Invalid executor type: {task.executor}") - - # Mark the task as done - logger.debug(f"======= FINISH TASK {task.executor} ========") - self.task_queue.task_done() - - def _execute_node_task(self, node, input_gen, state): + for input_data in initial_inputs: + self.task_queue.put(Task(iter([input_data]), first_node)) + + try: + while not self.task_queue.empty() or self.active_futures: + # Submit tasks to executor + while not self.task_queue.empty(): + task = self.task_queue.get() + future = self.executor.submit(self._process_task, task) + self.active_futures.add(future) + future.add_done_callback(self._on_task_complete) + + # Wait for any task to complete + completed_futures = [f for f in self.active_futures if f.done()] + for future in completed_futures: + self.active_futures.remove(future) + + finally: + self.executor.shutdown(wait=True) + + def _on_task_complete(self, future: Future): """ - Execute a node task and generate downstream tasks. - - Args: - node (BaseNode): The node to execute. - input_gen (DataGenerator): The input generator. - state (Dict[str, Any]): The current workflow state. - - Returns: - List[Tuple[DataGenerator, BaseEdge]]: Downstream tasks for adjacent edges. + Callback for task completion to handle downstream tasks. """ - logger.debug(f"================== EXECUTE Node {node} =================") - results = node.execute(state, input_gen) - downstream_tasks = [] - - for result in results: - # Create tasks for all adjacent edges - for edge in self.workflow.graph.get_adjacent_edges(node.name): - downstream_tasks.append(Task(iter([result]), edge)) - - return downstream_tasks - - def _execute_edge_task(self, edge, input_gen, state): + try: + downstream_tasks = future.result() + for task in downstream_tasks: + self.task_queue.put(task) + except Exception as e: + logger.error(f"Task failed with error: {e}") + + def _process_task(self, task: Task) -> List[Task]: """ - Execute an edge task and generate downstream tasks. + Process a single task. Args: - edge (BaseEdge): The edge to execute. - input_gen (DataGenerator): The input generator. - state (Dict[str, Any]): The current workflow state. + task (Task): The task to process. Returns: - List[Tuple[DataGenerator, BaseNode]]: Downstream tasks for target nodes. + List[Task]: Downstream tasks to enqueue. 
""" - logger.debug(f"================== EXECUTE Edge {edge} =================") - - results = edge.execute(state, input_gen) downstream_tasks = [] - - for result in results: - # Create tasks for the target node - logger.debug( - f"#################### ADD TASK {(iter([result]), edge.target)} ###################" - ) - downstream_tasks.append( - Task(iter([result]), self.workflow.graph.get_node(edge.target)) - ) + executor = task.executor + + if isinstance(executor, BaseNode): # Node task + logger.info(f"Executing node: {executor}") + results = executor.execute(self.state, task.input) + + for result in results: + if executor.node_type == NodeType.INSPECTOR: + with self.lock: + self.processed_inspectors += 1 + if self.processed_inspectors >= self.max_inspectors: + logger.info( + f"Reached max inspectors limit '{self.max_inspectors}', stopping execution)" + ) + return [] + for edge in self.workflow.graph.get_adjacent_edges(executor.name): + downstream_tasks.append(Task(iter([result]), edge)) + + elif isinstance(executor, BaseEdge): # Edge task + logger.info(f"Executing edge: {executor.name}") + results = executor.execute(self.state, task.input) + + for result in results: + target_node = self.workflow.graph.get_node(executor.target) + downstream_tasks.append(Task(iter([result]), target_node)) + + else: + raise ValueError(f"Invalid executor type: {executor}") return downstream_tasks