Skip to content

Commit

Permalink
Merge pull request #606 from connorjward/add-zero-kernel-cache
Browse files Browse the repository at this point in the history
Add zero kernel cache
  • Loading branch information
wence- authored Jan 6, 2021
2 parents 3652924 + be843f2 commit 8de26ea
Showing 1 changed file with 21 additions and 14 deletions.
35 changes: 21 additions & 14 deletions pyop2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1338,13 +1338,16 @@ class Dat(DataCarrier, _EmptyDataMixin):
multiplication / division by a scalar.
"""

_zero_kernels = {}
"""Class-level cache for zero kernels."""

_modes = [READ, WRITE, RW, INC, MIN, MAX]

@cached_property
def pack(self):
from pyop2.codegen.builder import DatPack
return DatPack

_modes = [READ, WRITE, RW, INC, MIN, MAX]

@validate_type(('dataset', (DataCarrier, DataSet, Set), DataSetTypeError),
('name', str, NameTypeError))
@validate_dtype(('dtype', None, DataTypeError))
Expand Down Expand Up @@ -1546,18 +1549,22 @@ def zero(self, subset=None):
loop = loops.get(iterset, None)

if loop is None:
import islpy as isl
import pymbolic.primitives as p

inames = isl.make_zero_and_vars(["i"])
domain = (inames[0].le_set(inames["i"])) & (inames["i"].lt_set(inames[0] + self.cdim))
x = p.Variable("dat")
i = p.Variable("i")
insn = loopy.Assignment(x.index(i), 0, within_inames=frozenset(["i"]))
data = loopy.GlobalArg("dat", dtype=self.dtype, shape=(self.cdim,))
knl = loopy.make_function([domain], [insn], [data], name="zero")

knl = _make_object('Kernel', knl, 'zero')
try:
knl = self._zero_kernels[(self.dtype, self.cdim)]
except KeyError:
import islpy as isl
import pymbolic.primitives as p

inames = isl.make_zero_and_vars(["i"])
domain = (inames[0].le_set(inames["i"])) & (inames["i"].lt_set(inames[0] + self.cdim))
x = p.Variable("dat")
i = p.Variable("i")
insn = loopy.Assignment(x.index(i), 0, within_inames=frozenset(["i"]))
data = loopy.GlobalArg("dat", dtype=self.dtype, shape=(self.cdim,))
knl = loopy.make_function([domain], [insn], [data], name="zero")

knl = _make_object('Kernel', knl, 'zero')
self._zero_kernels[(self.dtype, self.cdim)] = knl
loop = _make_object('ParLoop', knl,
iterset,
self(WRITE))
Expand Down

0 comments on commit 8de26ea

Please sign in to comment.