diff --git a/causy/causal_effect_estimation/multivariate_regression.py b/causy/causal_effect_estimation/multivariate_regression.py index 0a4f2ef..dd512f5 100644 --- a/causy/causal_effect_estimation/multivariate_regression.py +++ b/causy/causal_effect_estimation/multivariate_regression.py @@ -14,6 +14,9 @@ class ComputeDirectEffectsMultivariateRegression(PipelineStepInterface): generator = PairsWithEdgesInBetweenGenerator() + chunk_size_parallel_processing = 1 + parallel = False + def test(self, nodes: Tuple[str], graph: BaseGraphInterface) -> TestResult: """ Calculate the direct effect of each edge in the graph using multivariate regression. diff --git a/causy/graph_model.py b/causy/graph_model.py index a98018a..6f4a65e 100644 --- a/causy/graph_model.py +++ b/causy/graph_model.py @@ -1,4 +1,5 @@ import logging +import platform from abc import ABC from copy import deepcopy from typing import Optional, List, Dict, Callable @@ -52,13 +53,15 @@ def __initialize_pool(self) -> mp.Pool: Initialize the multiprocessing pool :return: the multiprocessing pool """ - # we need to set the start method to spawn because the default fork method does not work well with torch - try: - mp.set_start_method("spawn") - except RuntimeError: - logger.warning( - "Could not set multiprocessing start method to spawn. Using default method." - ) + # we need to set the start method to spawn because the default fork method does not work well with torch on linux + # see https://pytorch.org/docs/stable/notes/multiprocessing.html + if platform.system() == "Linux": + try: + mp.set_start_method("spawn") + except RuntimeError: + logger.warning( + "Could not set multiprocessing start method to spawn. Using default method. This might cause issues on Linux." + ) return mp.Pool(mp.cpu_count() * 2) def __multiprocessing_required(self, pipeline_steps):