diff --git a/pyop2/base.py b/pyop2/base.py index a81a3b52b..26c01d4cd 100644 --- a/pyop2/base.py +++ b/pyop2/base.py @@ -1338,13 +1338,16 @@ class Dat(DataCarrier, _EmptyDataMixin): multiplication / division by a scalar. """ + _zero_kernels = {} + """Class-level cache for zero kernels.""" + + _modes = [READ, WRITE, RW, INC, MIN, MAX] + @cached_property def pack(self): from pyop2.codegen.builder import DatPack return DatPack - _modes = [READ, WRITE, RW, INC, MIN, MAX] - @validate_type(('dataset', (DataCarrier, DataSet, Set), DataSetTypeError), ('name', str, NameTypeError)) @validate_dtype(('dtype', None, DataTypeError)) @@ -1546,18 +1549,22 @@ def zero(self, subset=None): loop = loops.get(iterset, None) if loop is None: - import islpy as isl - import pymbolic.primitives as p - - inames = isl.make_zero_and_vars(["i"]) - domain = (inames[0].le_set(inames["i"])) & (inames["i"].lt_set(inames[0] + self.cdim)) - x = p.Variable("dat") - i = p.Variable("i") - insn = loopy.Assignment(x.index(i), 0, within_inames=frozenset(["i"])) - data = loopy.GlobalArg("dat", dtype=self.dtype, shape=(self.cdim,)) - knl = loopy.make_function([domain], [insn], [data], name="zero") - - knl = _make_object('Kernel', knl, 'zero') + try: + knl = self._zero_kernels[(self.dtype, self.cdim)] + except KeyError: + import islpy as isl + import pymbolic.primitives as p + + inames = isl.make_zero_and_vars(["i"]) + domain = (inames[0].le_set(inames["i"])) & (inames["i"].lt_set(inames[0] + self.cdim)) + x = p.Variable("dat") + i = p.Variable("i") + insn = loopy.Assignment(x.index(i), 0, within_inames=frozenset(["i"])) + data = loopy.GlobalArg("dat", dtype=self.dtype, shape=(self.cdim,)) + knl = loopy.make_function([domain], [insn], [data], name="zero") + + knl = _make_object('Kernel', knl, 'zero') + self._zero_kernels[(self.dtype, self.cdim)] = knl loop = _make_object('ParLoop', knl, iterset, self(WRITE))