Skip to content
This repository has been archived by the owner on Nov 27, 2024. It is now read-only.

Commit

Permalink
Merge pull request #680 from OP2/connorjward/freeze-halos
Browse files Browse the repository at this point in the history
Add halo freezing
  • Loading branch information
dham authored Nov 23, 2022
2 parents 3a1d876 + 43c14a6 commit c1158ed
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 7 deletions.
16 changes: 15 additions & 1 deletion pyop2/parloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from pyop2.global_kernel import (GlobalKernelArg, DatKernelArg, MixedDatKernelArg,
MatKernelArg, MixedMatKernelArg, GlobalKernel)
from pyop2.local_kernel import LocalKernel, CStringLocalKernel, CoffeeLocalKernel, LoopyLocalKernel
from pyop2.types import (Access, Global, Dat, DatView, MixedDat, Mat, Set,
from pyop2.types import (Access, Global, AbstractDat, Dat, DatView, MixedDat, Mat, Set,
MixedSet, ExtrudedSet, Subset, Map, ComposedMap, MixedMap)
from pyop2.utils import cached_property

Expand Down Expand Up @@ -146,6 +146,7 @@ def __init__(self, global_knl, iterset, arguments):
raise ValueError("The argument dtypes do not match those for the local kernel")

self.check_iterset(iterset, global_knl, arguments)
self._check_frozen_access_modes(global_knl.local_kernel, arguments)

self.global_kernel = global_knl
self.iterset = iterset
Expand Down Expand Up @@ -440,6 +441,19 @@ def check_iterset(cls, iterset, global_knl, arguments):
if m.iterset != iterset and m.iterset not in iterset:
raise MapValueError(f"Iterset of arg {i} map {j} does not match parloop iterset")

@classmethod
def _check_frozen_access_modes(cls, local_knl, arguments):
"""Check that any frozen :class:`Dat` are getting accessed with the right access mode."""
for lknl_arg, pl_arg in zip(local_knl.arguments, arguments):
if isinstance(pl_arg.data, AbstractDat):
if any(
d._halo_frozen and d._frozen_access_mode != lknl_arg.access
for d in pl_arg.data
):
raise RuntimeError(
"Dats with frozen halos must always be accessed with the same access mode"
)

@classmethod
def prepare_reduced_globals(cls, arguments, global_knl):
"""Swap any :class:`GlobalParloopArg` instances that are INC'd into
Expand Down
80 changes: 76 additions & 4 deletions pyop2/types/dat.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ def __init__(self, dataset, data=None, dtype=None, name=None):
self.halo_valid = True
self._name = name or "dat_#x%x" % id(self)

self._halo_frozen = False
self._frozen_access_mode = None

@utils.cached_property
def _kernel_args_(self):
return (self._data.ctypes.data, )
Expand Down Expand Up @@ -546,7 +549,7 @@ def global_to_local_begin(self, access_mode):
:kwarg access_mode: Mode with which the data will subsequently
be accessed."""
halo = self.dataset.halo
if halo is None:
if halo is None or self._halo_frozen:
return
if not self.halo_valid and access_mode in {Access.READ, Access.RW}:
halo.global_to_local_begin(self, Access.WRITE)
Expand All @@ -565,7 +568,7 @@ def global_to_local_end(self, access_mode):
:kwarg access_mode: Mode with which the data will subsequently
be accessed."""
halo = self.dataset.halo
if halo is None:
if halo is None or self._halo_frozen:
return
if not self.halo_valid and access_mode in {Access.READ, Access.RW}:
halo.global_to_local_end(self, Access.WRITE)
Expand All @@ -582,7 +585,7 @@ def local_to_global_begin(self, insert_mode):
:kwarg insert_mode: insertion mode (an access descriptor)"""
halo = self.dataset.halo
if halo is None:
if halo is None or self._halo_frozen:
return
halo.local_to_global_begin(self, insert_mode)

Expand All @@ -592,11 +595,44 @@ def local_to_global_end(self, insert_mode):
:kwarg insert_mode: insertion mode (an access descriptor)"""
halo = self.dataset.halo
if halo is None:
if halo is None or self._halo_frozen:
return
halo.local_to_global_end(self, insert_mode)
self.halo_valid = False

@mpi.collective
def frozen_halo(self, access_mode):
"""Temporarily disable halo exchanges inside a context manager.
:arg access_mode: Mode with which the data will subsequently be accessed.
This is useful in cases where one is repeatedly writing to a :class:`Dat` with
the same access descriptor since the intermediate updates can be skipped.
"""
return frozen_halo(self, access_mode)

@mpi.collective
def freeze_halo(self, access_mode):
"""Disable halo exchanges.
:arg access_mode: Mode with which the data will subsequently be accessed.
Note that some bookkeeping is needed when freezing halos. Prefer to use the
:meth:`Dat.frozen_halo` context manager.
"""
if self._halo_frozen:
raise RuntimeError("Expected an unfrozen halo")
self._halo_frozen = True
self._frozen_access_mode = access_mode

@mpi.collective
def unfreeze_halo(self):
"""Re-enable halo exchanges."""
if not self._halo_frozen:
raise RuntimeError("Expected a frozen halo")
self._halo_frozen = False
self._frozen_access_mode = None


class DatView(AbstractDat):
"""An indexed view into a :class:`Dat`.
Expand Down Expand Up @@ -834,6 +870,18 @@ def local_to_global_end(self, insert_mode):
for s in self:
s.local_to_global_end(insert_mode)

@mpi.collective
def freeze_halo(self, access_mode):
"""Disable halo exchanges."""
for d in self:
d.freeze_halo(access_mode)

@mpi.collective
def unfreeze_halo(self):
"""Re-enable halo exchanges."""
for d in self:
d.unfreeze_halo()

@mpi.collective
def zero(self, subset=None):
"""Zero the data associated with this :class:`MixedDat`.
Expand Down Expand Up @@ -1033,3 +1081,27 @@ def vec_context(self, access):
v.array[:] = array[offset:offset+size]
offset += size
self.halo_valid = False


class frozen_halo:
"""Context manager handling the freezing and unfreezing of halos.
:param dat: The :class:`Dat` whose halo is to be frozen.
:param access_mode: Mode with which the :class:`Dat` will be accessed whilst
its halo is frozen.
"""
def __init__(self, dat, access_mode):
self._dat = dat
self._access_mode = access_mode

def __enter__(self):
# Initialise the halo values (e.g. set to zero if INC'ing)
self._dat.global_to_local_begin(self._access_mode)
self._dat.global_to_local_end(self._access_mode)
self._dat.freeze_halo(self._access_mode)

def __exit__(self, *args):
# Finally do the halo exchanges
self._dat.unfreeze_halo()
self._dat.local_to_global_begin(self._access_mode)
self._dat.local_to_global_end(self._access_mode)
22 changes: 20 additions & 2 deletions pyop2/types/glob.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from contextlib import contextmanager
import contextlib
import ctypes
import operator

Expand Down Expand Up @@ -185,6 +185,24 @@ def local_to_global_end(self, insert_mode):
part of a :class:`MixedDat`."""
pass

@mpi.collective
def frozen_halo(self, access_mode):
"""Dummy halo operation for the case in which a :class:`Global` forms
part of a :class:`MixedDat`."""
return contextlib.nullcontext()

@mpi.collective
def freeze_halo(self, access_mode):
"""Dummy halo operation for the case in which a :class:`Global` forms
part of a :class:`MixedDat`."""
pass

@mpi.collective
def unfreeze_halo(self):
"""Dummy halo operation for the case in which a :class:`Global` forms
part of a :class:`MixedDat`."""
pass

def _op(self, other, op):
ret = type(self)(self.dim, dtype=self.dtype, name=self.name, comm=self.comm)
if isinstance(other, Global):
Expand Down Expand Up @@ -283,7 +301,7 @@ def _vec(self):
bsize=self.cdim,
comm=self.comm)

@contextmanager
@contextlib.contextmanager
def vec_context(self, access):
"""A context manager for a :class:`PETSc.Vec` from a :class:`Global`.
Expand Down
13 changes: 13 additions & 0 deletions test/unit/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1624,6 +1624,19 @@ def test_empty_map_and_iterset(self):
k = op2.Kernel("static void k(int *x) {}", "k")
op2.par_loop(k, s1, d(op2.READ, m))

def test_frozen_dats_cannot_use_different_access_mode(self):
s1 = op2.Set(2)
s2 = op2.Set(3)
m = op2.Map(s1, s2, 3, [0]*6)
d = op2.Dat(s2**1, [0]*3, dtype=int)
k = op2.Kernel("static void k(int *x) {}", "k")

with d.frozen_halo(op2.INC):
op2.par_loop(k, s1, d(op2.INC, m))

with pytest.raises(RuntimeError):
op2.par_loop(k, s1, d(op2.WRITE, m))


if __name__ == '__main__':
import os
Expand Down

0 comments on commit c1158ed

Please sign in to comment.