From 21df7c9a13ba864bbdb27483b58fd5c148700f7b Mon Sep 17 00:00:00 2001 From: Connor Ward Date: Tue, 5 Jan 2021 16:43:44 +0000 Subject: [PATCH 1/2] Add zero kernel cache Co-authored-by: David A. Ham --- pyop2/base.py | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/pyop2/base.py b/pyop2/base.py index a81a3b52b..b700e4d0a 100644 --- a/pyop2/base.py +++ b/pyop2/base.py @@ -1338,6 +1338,10 @@ class Dat(DataCarrier, _EmptyDataMixin): multiplication / division by a scalar. """ + _zero_kernels = {} + """Class-level cache for zero kernels.""" + + @cached_property def pack(self): from pyop2.codegen.builder import DatPack @@ -1546,18 +1550,22 @@ def zero(self, subset=None): loop = loops.get(iterset, None) if loop is None: - import islpy as isl - import pymbolic.primitives as p - - inames = isl.make_zero_and_vars(["i"]) - domain = (inames[0].le_set(inames["i"])) & (inames["i"].lt_set(inames[0] + self.cdim)) - x = p.Variable("dat") - i = p.Variable("i") - insn = loopy.Assignment(x.index(i), 0, within_inames=frozenset(["i"])) - data = loopy.GlobalArg("dat", dtype=self.dtype, shape=(self.cdim,)) - knl = loopy.make_function([domain], [insn], [data], name="zero") - - knl = _make_object('Kernel', knl, 'zero') + try: + knl = self._zero_kernels[(self.dtype, self.cdim)] + except KeyError: + import islpy as isl + import pymbolic.primitives as p + + inames = isl.make_zero_and_vars(["i"]) + domain = (inames[0].le_set(inames["i"])) & (inames["i"].lt_set(inames[0] + self.cdim)) + x = p.Variable("dat") + i = p.Variable("i") + insn = loopy.Assignment(x.index(i), 0, within_inames=frozenset(["i"])) + data = loopy.GlobalArg("dat", dtype=self.dtype, shape=(self.cdim,)) + knl = loopy.make_function([domain], [insn], [data], name="zero") + + knl = _make_object('Kernel', knl, 'zero') + self._zero_kernels[(self.dtype, self.cdim)] = knl loop = _make_object('ParLoop', knl, iterset, self(WRITE)) From be843f28ff2915972f653d780f442ef14a113046 Mon Sep 17 00:00:00 2001 From: Connor Ward Date: Wed, 6 Jan 2021 13:48:34 +0000 Subject: [PATCH 2/2] Fix linting error --- pyop2/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyop2/base.py b/pyop2/base.py index b700e4d0a..26c01d4cd 100644 --- a/pyop2/base.py +++ b/pyop2/base.py @@ -1341,14 +1341,13 @@ class Dat(DataCarrier, _EmptyDataMixin): _zero_kernels = {} """Class-level cache for zero kernels.""" + _modes = [READ, WRITE, RW, INC, MIN, MAX] @cached_property def pack(self): from pyop2.codegen.builder import DatPack return DatPack - _modes = [READ, WRITE, RW, INC, MIN, MAX] - @validate_type(('dataset', (DataCarrier, DataSet, Set), DataSetTypeError), ('name', str, NameTypeError)) @validate_dtype(('dtype', None, DataTypeError))