Skip to content

Commit

Permalink
chore(multiprocessing): set spawn only on linux for now + disable parallel processing for ComputeDirectEffectsMultivariateRegression as it is not optimized for it
Browse files Browse the repository at this point in the history
  • Loading branch information
LilithWittmann committed Apr 2, 2024
1 parent 5b781e9 commit 0bc44f2
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
3 changes: 3 additions & 0 deletions causy/causal_effect_estimation/multivariate_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
class ComputeDirectEffectsMultivariateRegression(PipelineStepInterface):
generator = PairsWithEdgesInBetweenGenerator()

chunk_size_parallel_processing = 1
parallel = False

def test(self, nodes: Tuple[str], graph: BaseGraphInterface) -> TestResult:
"""
Calculate the direct effect of each edge in the graph using multivariate regression.
Expand Down
17 changes: 10 additions & 7 deletions causy/graph_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import platform
from abc import ABC
from copy import deepcopy
from typing import Optional, List, Dict, Callable
Expand Down Expand Up @@ -52,13 +53,15 @@ def __initialize_pool(self) -> mp.Pool:
Initialize the multiprocessing pool
:return: the multiprocessing pool
"""
# we need to set the start method to spawn because the default fork method does not work well with torch
try:
mp.set_start_method("spawn")
except RuntimeError:
logger.warning(
"Could not set multiprocessing start method to spawn. Using default method."
)
# we need to set the start method to spawn because the default fork method does not work well with torch on linux
# see https://pytorch.org/docs/stable/notes/multiprocessing.html
if platform.system() == "Linux":
try:
mp.set_start_method("spawn")
except RuntimeError:
logger.warning(
"Could not set multiprocessing start method to spawn. Using default method. This might cause issues on Linux."
)
return mp.Pool(mp.cpu_count() * 2)

def __multiprocessing_required(self, pipeline_steps):
Expand Down

0 comments on commit 0bc44f2

Please sign in to comment.