Skip to content

Commit

Permalink
chore(multiprocessing): try to enforce multiprocessing method as a wo…
Browse files Browse the repository at this point in the history
…rkaround for our linux issues
  • Loading branch information
LilithWittmann committed Apr 2, 2024
1 parent cdfb070 commit 5b781e9
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions causy/graph_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ class AbstractGraphModel(GraphModelInterface, ABC):
pipeline_steps: List[PipelineStepInterface]
graph: BaseGraphInterface
pool: mp.Pool
initialize_pool_fn: Callable = lambda self: mp.Pool(mp.cpu_count() * 2)

def __init__(
self,
Expand All @@ -44,10 +43,24 @@ def __init__(
self.graph = graph
self.pipeline_steps = pipeline_steps or []
if self.__multiprocessing_required(self.pipeline_steps):
self.pool = self.initialize_pool_fn()
self.pool = self.__initialize_pool()
else:
self.pool = None

def __initialize_pool(self) -> mp.Pool:
"""
Initialize the multiprocessing pool
:return: the multiprocessing pool
"""
# we need to set the start method to spawn because the default fork method does not work well with torch
try:
mp.set_start_method("spawn")
except RuntimeError:
logger.warning(
"Could not set multiprocessing start method to spawn. Using default method."
)
return mp.Pool(mp.cpu_count() * 2)

def __multiprocessing_required(self, pipeline_steps):
"""
Check if multiprocessing is required
Expand Down Expand Up @@ -244,7 +257,7 @@ def execute_pipeline_step(self, test_fn: PipelineStepInterface):
logger.warning(
"Parallel processing is enabled but no pool is initialized. Initializing pool."
)
self.pool = self.initialize_pool_fn()
self.pool = self.__initialize_pool()

for result in self.pool.imap_unordered(
unpack_run,
Expand Down

0 comments on commit 5b781e9

Please sign in to comment.