From cc954d0cef476187ab478522a4d091a2fb3ec804 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 11 Oct 2024 16:39:11 -0500 Subject: [PATCH] Add MPINumpyArrayContext (#312) * Add MPINumpyArrayContext * add support to wave-op-mpi, add to tests * doc fix * minor update --- .github/workflows/ci.yml | 2 ++ examples/wave/wave-op-mpi.py | 22 +++++++++++------ grudge/array_context.py | 48 +++++++++++++++++++++++++++++++++--- test/test_dt_utils.py | 4 ++- test/test_metrics.py | 4 ++- 5 files changed, 67 insertions(+), 13 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c7c54d2ee..9892070df 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -69,6 +69,8 @@ jobs: # have a sufficient number of cores. mpiexec -np 2 --oversubscribe python wave/wave-op-mpi.py --lazy + mpiexec -np 2 --oversubscribe python wave/wave-op-mpi.py --numpy + docs: name: Documentation runs-on: ubuntu-latest diff --git a/examples/wave/wave-op-mpi.py b/examples/wave/wave-op-mpi.py index f72985608..ea47e202e 100644 --- a/examples/wave/wave-op-mpi.py +++ b/examples/wave/wave-op-mpi.py @@ -175,19 +175,24 @@ def bump(actx, dcoll, t=0): def main(ctx_factory, dim=2, order=3, - visualize=False, lazy=False, use_quad=False, use_nonaffine_mesh=False, - no_diagnostics=False): - cl_ctx = ctx_factory() - queue = cl.CommandQueue(cl_ctx) - + visualize=False, lazy=False, numpy=False, use_quad=False, + use_nonaffine_mesh=False, no_diagnostics=False): comm = MPI.COMM_WORLD num_parts = comm.size from grudge.array_context import get_reasonable_array_context_class - actx_class = get_reasonable_array_context_class(lazy=lazy, distributed=True) - if lazy: + actx_class = get_reasonable_array_context_class(lazy=lazy, + distributed=True, numpy=numpy) + + if numpy: + actx = actx_class(comm) + elif lazy: + cl_ctx = ctx_factory() + queue = cl.CommandQueue(cl_ctx) actx = actx_class(comm, queue, mpi_base_tag=15000) else: + cl_ctx = ctx_factory() + queue = cl.CommandQueue(cl_ctx) actx = actx_class(comm, queue, allocator=cl_tools.MemoryPool(cl_tools.ImmediateAllocator(queue)), force_device_scalars=True) @@ -323,6 +328,8 @@ def rhs(t, w): parser.add_argument("--visualize", action="store_true") parser.add_argument("--lazy", action="store_true", help="switch to a lazy computation mode") + parser.add_argument("--numpy", action="store_true", + help="switch to numpy-based array context") parser.add_argument("--quad", action="store_true") parser.add_argument("--nonaffine", action="store_true") parser.add_argument("--no-diagnostics", action="store_true") @@ -335,6 +342,7 @@ def rhs(t, w): order=args.order, visualize=args.visualize, lazy=args.lazy, + numpy=args.numpy, use_quad=args.quad, use_nonaffine_mesh=args.nonaffine, no_diagnostics=args.no_diagnostics) diff --git a/grudge/array_context.py b/grudge/array_context.py index 674dac8de..45f7d8ef4 100644 --- a/grudge/array_context.py +++ b/grudge/array_context.py @@ -3,6 +3,7 @@ .. autoclass:: PytatoPyOpenCLArrayContext .. autoclass:: MPIBasedArrayContext .. autoclass:: MPIPyOpenCLArrayContext +.. autoclass:: MPINumpyArrayContext .. class:: MPIPytatoArrayContext .. autofunction:: get_reasonable_array_context_class """ @@ -100,10 +101,11 @@ _HAVE_FUSION_ACTX = False -from arraycontext import ArrayContext +from arraycontext import ArrayContext, NumpyArrayContext from arraycontext.container import ArrayContainer from arraycontext.impl.pytato.compile import LazilyPyOpenCLCompilingFunctionCaller from arraycontext.pytest import ( + _PytestNumpyArrayContextFactory, _PytestPyOpenCLArrayContextFactoryWithClass, _PytestPytatoPyOpenCLArrayContextFactory, register_pytest_array_context_factory, @@ -468,6 +470,26 @@ def clone(self): # }}} +# {{{ distributed + numpy + +class MPINumpyArrayContext(NumpyArrayContext, MPIBasedArrayContext): + """An array context for using distributed computation with :mod:`numpy` + eager evaluation. + + .. autofunction:: __init__ + """ + + def __init__(self, mpi_communicator) -> None: + super().__init__() + + self.mpi_communicator = mpi_communicator + + def clone(self): + return type(self)(self.mpi_communicator) + +# }}} + + # {{{ distributed + pytato array context subclasses class MPIBasePytatoPyOpenCLArrayContext( @@ -535,10 +557,19 @@ def __call__(self): return self.actx_class(queue, allocator=alloc) +class PytestNumpyArrayContextFactory(_PytestNumpyArrayContextFactory): + actx_class = NumpyArrayContext + + def __call__(self): + return self.actx_class() + + register_pytest_array_context_factory("grudge.pyopencl", PytestPyOpenCLArrayContextFactory) register_pytest_array_context_factory("grudge.pytato-pyopencl", PytestPytatoPyOpenCLArrayContextFactory) +register_pytest_array_context_factory("grudge.numpy", + PytestNumpyArrayContextFactory) # }}} @@ -570,13 +601,22 @@ def _get_single_grid_pytato_actx_class(distributed: bool) -> Type[ArrayContext]: def get_reasonable_array_context_class( lazy: bool = True, distributed: bool = True, - fusion: Optional[bool] = None, + fusion: Optional[bool] = None, numpy: bool = False, ) -> Type[ArrayContext]: - """Returns a reasonable :class:`PyOpenCLArrayContext` currently - supported given the constraints of *lazy* and *distributed*.""" + """Returns a reasonable :class:`~arraycontext.ArrayContext` currently + supported given the constraints of *lazy*, *distributed*, and *numpy*.""" if fusion is None: fusion = lazy + if numpy: + assert not (lazy or fusion) + if distributed: + actx_class = MPINumpyArrayContext + else: + actx_class = NumpyArrayContext + + return actx_class + if lazy: if fusion: if not _HAVE_FUSION_ACTX: diff --git a/test/test_dt_utils.py b/test/test_dt_utils.py index 729980917..cf3ac2021 100644 --- a/test/test_dt_utils.py +++ b/test/test_dt_utils.py @@ -27,6 +27,7 @@ from arraycontext import pytest_generate_tests_for_array_contexts from grudge.array_context import ( + PytestNumpyArrayContextFactory, PytestPyOpenCLArrayContextFactory, PytestPytatoPyOpenCLArrayContextFactory, ) @@ -34,7 +35,8 @@ pytest_generate_tests = pytest_generate_tests_for_array_contexts( [PytestPyOpenCLArrayContextFactory, - PytestPytatoPyOpenCLArrayContextFactory]) + PytestPytatoPyOpenCLArrayContextFactory, + PytestNumpyArrayContextFactory]) import logging diff --git a/test/test_metrics.py b/test/test_metrics.py index e6863c7f6..1ee043b8a 100644 --- a/test/test_metrics.py +++ b/test/test_metrics.py @@ -33,6 +33,7 @@ from meshmode.dof_array import flat_norm from grudge.array_context import ( + PytestNumpyArrayContextFactory, PytestPyOpenCLArrayContextFactory, PytestPytatoPyOpenCLArrayContextFactory, ) @@ -42,7 +43,8 @@ logger = logging.getLogger(__name__) pytest_generate_tests = pytest_generate_tests_for_array_contexts( [PytestPyOpenCLArrayContextFactory, - PytestPytatoPyOpenCLArrayContextFactory]) + PytestPytatoPyOpenCLArrayContextFactory, + PytestNumpyArrayContextFactory]) # {{{ inverse metric