From d8d5793f1bf064eabeba2303ea37a7159b09641d Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sat, 18 Nov 2023 23:25:07 +0800 Subject: [PATCH] [running] fix multiprocessing bugs (#547) * [running] fix multiprocessing bugs * fix tests --- .../_src/running/pathos_multiprocessing.py | 7 ++++ .../tests/test_pathos_multiprocessing.py | 41 +++++++++++++++++++ requirements-dev.txt | 3 +- requirements-doc.txt | 4 +- 4 files changed, 52 insertions(+), 3 deletions(-) create mode 100644 brainpy/_src/running/tests/test_pathos_multiprocessing.py diff --git a/brainpy/_src/running/pathos_multiprocessing.py b/brainpy/_src/running/pathos_multiprocessing.py index 1573a541c..f652217d9 100644 --- a/brainpy/_src/running/pathos_multiprocessing.py +++ b/brainpy/_src/running/pathos_multiprocessing.py @@ -9,6 +9,7 @@ - ``cpu_unordered_parallel``: Performs a parallel unordered map. """ +import sys from collections.abc import Sized from typing import (Any, Callable, Generator, Iterable, List, Union, Optional, Sequence, Dict) @@ -20,6 +21,8 @@ try: from pathos.helpers import cpu_count # noqa from pathos.multiprocessing import ProcessPool # noqa + import multiprocess.context as ctx # noqa + ctx._force_start_method('spawn') except ModuleNotFoundError: cpu_count = None ProcessPool = None @@ -63,6 +66,10 @@ def _parallel( A generator which will apply the function to each element of the given Iterables in parallel in order with a progress bar. """ + if sys.platform == 'win32' and sys.version_info.minor >= 11: + raise NotImplementedError('Multiprocessing is not available in Python >=3.11 on Windows. ' + 'Please use Linux or MacOS, or Windows with Python <= 3.10.') + if ProcessPool is None or cpu_count is None: raise PackageMissingError( ''' diff --git a/brainpy/_src/running/tests/test_pathos_multiprocessing.py b/brainpy/_src/running/tests/test_pathos_multiprocessing.py new file mode 100644 index 000000000..6f92bda7e --- /dev/null +++ b/brainpy/_src/running/tests/test_pathos_multiprocessing.py @@ -0,0 +1,41 @@ +import sys + +import jax +import pytest +from absl.testing import parameterized + +import brainpy as bp +import brainpy.math as bm + +if sys.platform == 'win32' and sys.version_info.minor >= 11: + pytest.skip('python 3.11 does not support.', allow_module_level=True) +else: + pytest.skip('Cannot pass tests.', allow_module_level=True) + + +class TestParallel(parameterized.TestCase): + @parameterized.product( + duration=[1e2, 1e3, 1e4, 1e5] + ) + def test_cpu_unordered_parallel_v1(self, duration): + @jax.jit + def body(inp): + return bm.for_loop(lambda x: x + 1e-9, inp) + + input_long = bm.random.randn(1, int(duration / bm.dt), 3) / 100 + + r = bp.running.cpu_ordered_parallel(body, {'inp': [input_long, input_long]}, num_process=2) + assert bm.allclose(r[0], r[1]) + + @parameterized.product( + duration=[1e2, 1e3, 1e4, 1e5] + ) + def test_cpu_unordered_parallel_v2(self, duration): + @jax.jit + def body(inp): + return bm.for_loop(lambda x: x + 1e-9, inp) + + input_long = bm.random.randn(1, int(duration / bm.dt), 3) / 100 + + r = bp.running.cpu_unordered_parallel(body, {'inp': [input_long, input_long]}, num_process=2) + assert bm.allclose(r[0], r[1]) diff --git a/requirements-dev.txt b/requirements-dev.txt index 93fa26af3..068c38546 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,9 +3,10 @@ numba brainpylib jax jaxlib -matplotlib>=3.4 +matplotlib msgpack tqdm +pathos # test requirements pytest diff --git a/requirements-doc.txt b/requirements-doc.txt index d4fe3f43e..c399c03b0 100644 --- a/requirements-doc.txt +++ b/requirements-doc.txt @@ -4,8 +4,8 @@ msgpack numba jax jaxlib -matplotlib>=3.4 -scipy>=1.1.0 +matplotlib +scipy numba # document requirements