diff --git a/causy/graph_model.py b/causy/graph_model.py
index e7a6228..a98018a 100644
--- a/causy/graph_model.py
+++ b/causy/graph_model.py
@@ -34,7 +34,6 @@ class AbstractGraphModel(GraphModelInterface, ABC):
     pipeline_steps: List[PipelineStepInterface]
     graph: BaseGraphInterface
     pool: mp.Pool
-    initialize_pool_fn: Callable = lambda self: mp.Pool(mp.cpu_count() * 2)
 
     def __init__(
         self,
@@ -44,10 +43,24 @@ def __init__(
         self.graph = graph
         self.pipeline_steps = pipeline_steps or []
         if self.__multiprocessing_required(self.pipeline_steps):
-            self.pool = self.initialize_pool_fn()
+            self.pool = self.__initialize_pool()
         else:
             self.pool = None
 
+    def __initialize_pool(self) -> mp.Pool:
+        """
+        Initialize the multiprocessing pool
+        :return: the multiprocessing pool
+        """
+        # we need to set the start method to spawn because the default fork method does not work well with torch
+        try:
+            mp.set_start_method("spawn")
+        except RuntimeError:
+            logger.warning(
+                "Could not set multiprocessing start method to spawn. Using default method."
+            )
+        return mp.Pool(mp.cpu_count() * 2)
+
     def __multiprocessing_required(self, pipeline_steps):
         """
         Check if multiprocessing is required
@@ -244,7 +257,7 @@ def execute_pipeline_step(self, test_fn: PipelineStepInterface):
                 logger.warning(
                     "Parallel processing is enabled but no pool is initialized. Initializing pool."
                 )
-                self.pool = self.initialize_pool_fn()
+                self.pool = self.__initialize_pool()
 
             for result in self.pool.imap_unordered(
                 unpack_run,
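
Note (not part of the patch): the new __initialize_pool relies on the fact that multiprocessing.set_start_method() raises RuntimeError when the start method has already been fixed for the current process, which is why the call is wrapped in try/except before the pool is created. Below is a minimal, standalone sketch of the same pattern, runnable on its own; the initialize_pool name, the pow smoke test, and the pool size are illustrative assumptions, not code taken from causy.

# Standalone sketch (illustrative, not part of the patch) of the pool setup
# pattern used by __initialize_pool above.
import logging
import multiprocessing as mp

logger = logging.getLogger(__name__)


def initialize_pool() -> mp.Pool:
    """Create a worker pool, preferring the 'spawn' start method."""
    try:
        # set_start_method() raises RuntimeError if the start method has
        # already been set for this process (e.g. by another library or an
        # earlier call), so we fall back to the default method in that case.
        mp.set_start_method("spawn")
    except RuntimeError:
        logger.warning(
            "Could not set multiprocessing start method to spawn. Using default method."
        )
    return mp.Pool(mp.cpu_count() * 2)


if __name__ == "__main__":
    # The __main__ guard matters under the spawn start method, because
    # child processes re-import this module on startup.
    pool = initialize_pool()
    print(pool.starmap(pow, [(2, 2), (3, 2), (4, 2)]))  # [4, 9, 16]
    pool.close()
    pool.join()

A possible alternative with the same effect but no global side effect is mp.get_context("spawn").Pool(mp.cpu_count() * 2), which scopes the spawn method to that one pool; the patch instead sets the start method process-wide and tolerates the RuntimeError if something else set it first.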