diff --git a/pyop2/base.py b/pyop2/base.py index f814799da..ee60c2a33 100644 --- a/pyop2/base.py +++ b/pyop2/base.py @@ -47,7 +47,7 @@ import types from hashlib import md5 -from pyop2.datatypes import IntType, as_cstr, _EntityMask, _MapMask +from pyop2.datatypes import IntType, as_cstr, _EntityMask, _MapMask, dtype_limits from pyop2.configuration import configuration from pyop2.caching import Cached, ObjectCached from pyop2.exceptions import * @@ -473,7 +473,7 @@ def global_to_local_begin(self): assert self._is_dat, "Doing halo exchanges only makes sense for Dats" assert not self._in_flight, \ "Halo exchange already in flight for Arg %s" % self - if self.access in [READ, RW, INC]: + if self.access in [READ, RW, INC, MIN, MAX]: self._in_flight = True self.data.global_to_local_begin(self.access) @@ -483,7 +483,7 @@ def global_to_local_end(self): Doing halo exchanges only makes sense for :class:`Dat` objects. """ assert self._is_dat, "Doing halo exchanges only makes sense for Dats" - if self.access in [READ, RW, INC] and self._in_flight: + if self.access in [READ, RW, INC, MIN, MAX] and self._in_flight: self._in_flight = False self.data.global_to_local_end(self.access) @@ -1647,7 +1647,7 @@ class Dat(DataCarrier, _EmptyDataMixin): """ _globalcount = 0 - _modes = [READ, WRITE, RW, INC] + _modes = [READ, WRITE, RW, INC, MIN, MAX] @validate_type(('dataset', (DataCarrier, DataSet, Set), DataSetTypeError), ('name', str, NameTypeError)) @@ -2121,6 +2121,9 @@ def global_to_local_begin(self, access_mode): halo.global_to_local_begin(self, WRITE) elif access_mode is INC: self._data[self.dataset.size:] = 0 + elif access_mode in [MIN, MAX]: + min_, max_ = dtype_limits(self.dtype) + self._data[self.dataset.size:] = {MAX: min_, MIN: max_}[access_mode] @collective def global_to_local_end(self, access_mode): @@ -2134,7 +2137,7 @@ def global_to_local_end(self, access_mode): if access_mode in [READ, RW] and not self.halo_valid: halo.global_to_local_end(self, WRITE) self.halo_valid = True - elif access_mode is INC: + elif access_mode in [MIN, MAX, INC]: self.halo_valid = False @collective diff --git a/pyop2/datatypes.py b/pyop2/datatypes.py index b8115b1c6..7fcf14088 100644 --- a/pyop2/datatypes.py +++ b/pyop2/datatypes.py @@ -50,3 +50,21 @@ class _EntityMask(ctypes.Structure): _fields_ = [("section", ctypes.c_voidp), ("bottom", ctypes.c_voidp), ("top", ctypes.c_voidp)] + + +def dtype_limits(dtype): + """Attempt to determine the min and max values of a datatype. + + :arg dtype: A numpy datatype. + :returns: a 2-tuple of min, max + :raises ValueError: If numeric limits could not be determined. + """ + try: + info = numpy.finfo(dtype) + except ValueError: + # maybe an int? + try: + info = numpy.iinfo(dtype) + except ValueError as e: + raise ValueError("Unable to determine numeric limits from %s" % dtype) from e + return info.min, info.max diff --git a/test/unit/test_api.py b/test/unit/test_api.py index eb0e05018..f14a4905c 100644 --- a/test/unit/test_api.py +++ b/test/unit/test_api.py @@ -790,12 +790,6 @@ def test_dat_initialise_data_type(self, dset): d = op2.Dat(dset, dtype=np.int32) assert d.data.dtype == np.int32 - @pytest.mark.parametrize("mode", [op2.MAX, op2.MIN]) - def test_dat_arg_illegal_mode(self, dat, mode): - """Dat __call__ should not allow access modes not allowed for a Dat.""" - with pytest.raises(exceptions.ModeValueError): - dat(mode) - def test_dat_subscript(self, dat): """Extracting component 0 of a Dat should yield self.""" assert dat[0] is dat