Skip to content

Commit

Permalink
add support to wave-op-mpi, add to tests
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener committed Sep 6, 2024
1 parent 66319b7 commit 67e7993
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 14 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ jobs:
# have a sufficient number of cores.
mpiexec -np 2 --oversubscribe python wave/wave-op-mpi.py --lazy
mpiexec -np 2 --oversubscribe python wave/wave-op-mpi.py --numpy
docs:
name: Documentation
runs-on: ubuntu-latest
Expand Down
22 changes: 15 additions & 7 deletions examples/wave/wave-op-mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,19 +175,24 @@ def bump(actx, dcoll, t=0):


def main(ctx_factory, dim=2, order=3,
visualize=False, lazy=False, use_quad=False, use_nonaffine_mesh=False,
no_diagnostics=False):
cl_ctx = ctx_factory()
queue = cl.CommandQueue(cl_ctx)

visualize=False, lazy=False, numpy=False, use_quad=False,
use_nonaffine_mesh=False, no_diagnostics=False):
comm = MPI.COMM_WORLD
num_parts = comm.size

from grudge.array_context import get_reasonable_array_context_class
actx_class = get_reasonable_array_context_class(lazy=lazy, distributed=True)
if lazy:
actx_class = get_reasonable_array_context_class(lazy=lazy,
distributed=True, numpy=numpy)

if numpy:
actx = actx_class(comm)
elif lazy:
cl_ctx = ctx_factory()
queue = cl.CommandQueue(cl_ctx)
actx = actx_class(comm, queue, mpi_base_tag=15000)
else:
cl_ctx = ctx_factory()
queue = cl.CommandQueue(cl_ctx)
actx = actx_class(comm, queue,
allocator=cl_tools.MemoryPool(cl_tools.ImmediateAllocator(queue)),
force_device_scalars=True)
Expand Down Expand Up @@ -323,6 +328,8 @@ def rhs(t, w):
parser.add_argument("--visualize", action="store_true")
parser.add_argument("--lazy", action="store_true",
help="switch to a lazy computation mode")
parser.add_argument("--numpy", action="store_true",
help="switch to numpy-based array context")
parser.add_argument("--quad", action="store_true")
parser.add_argument("--nonaffine", action="store_true")
parser.add_argument("--no-diagnostics", action="store_true")
Expand All @@ -335,6 +342,7 @@ def rhs(t, w):
order=args.order,
visualize=args.visualize,
lazy=args.lazy,
numpy=args.numpy,
use_quad=args.quad,
use_nonaffine_mesh=args.nonaffine,
no_diagnostics=args.no_diagnostics)
Expand Down
27 changes: 22 additions & 5 deletions grudge/array_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,18 +101,17 @@
_HAVE_FUSION_ACTX = False


from arraycontext import ArrayContext
from arraycontext import ArrayContext, NumpyArrayContext
from arraycontext.container import ArrayContainer
from arraycontext.impl.pytato.compile import LazilyPyOpenCLCompilingFunctionCaller
from arraycontext.pytest import (
_PytestNumpyArrayContextFactory,
_PytestPyOpenCLArrayContextFactoryWithClass,
_PytestPytatoPyOpenCLArrayContextFactory,
register_pytest_array_context_factory,
)


from arraycontext import NumpyArrayContext

if TYPE_CHECKING:
import pytato as pt
from mpi4py import MPI
Expand Down Expand Up @@ -558,10 +557,19 @@ def __call__(self):
return self.actx_class(queue, allocator=alloc)


class PytestNumpyArrayContextFactory(_PytestNumpyArrayContextFactory):
actx_class = NumpyArrayContext

def __call__(self):
return self.actx_class()


register_pytest_array_context_factory("grudge.pyopencl",
PytestPyOpenCLArrayContextFactory)
register_pytest_array_context_factory("grudge.pytato-pyopencl",
PytestPytatoPyOpenCLArrayContextFactory)
register_pytest_array_context_factory("grudge.numpy",
PytestNumpyArrayContextFactory)

# }}}

Expand Down Expand Up @@ -593,13 +601,22 @@ def _get_single_grid_pytato_actx_class(distributed: bool) -> Type[ArrayContext]:

def get_reasonable_array_context_class(
lazy: bool = True, distributed: bool = True,
fusion: Optional[bool] = None,
fusion: Optional[bool] = None, numpy: bool = False,
) -> Type[ArrayContext]:
"""Returns a reasonable :class:`PyOpenCLArrayContext` currently
"""Returns a reasonable :class:`ArrayContext` currently
supported given the constraints of *lazy* and *distributed*."""
if fusion is None:
fusion = lazy

if numpy:
assert not lazy
if distributed:
actx_class = MPINumpyArrayContext
else:
actx_class = NumpyArrayContext

return actx_class

if lazy:
if fusion:
if not _HAVE_FUSION_ACTX:
Expand Down
4 changes: 3 additions & 1 deletion test/test_dt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,16 @@
from arraycontext import pytest_generate_tests_for_array_contexts

from grudge.array_context import (
PytestNumpyArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
PytestPytatoPyOpenCLArrayContextFactory,
)


pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory,
PytestPytatoPyOpenCLArrayContextFactory])
PytestPytatoPyOpenCLArrayContextFactory,
PytestNumpyArrayContextFactory])

import logging

Expand Down
4 changes: 3 additions & 1 deletion test/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from meshmode.dof_array import flat_norm

from grudge.array_context import (
PytestNumpyArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
PytestPytatoPyOpenCLArrayContextFactory,
)
Expand All @@ -42,7 +43,8 @@
logger = logging.getLogger(__name__)
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory,
PytestPytatoPyOpenCLArrayContextFactory])
PytestPytatoPyOpenCLArrayContextFactory,
PytestNumpyArrayContextFactory])


# {{{ inverse metric
Expand Down

0 comments on commit 67e7993

Please sign in to comment.