Skip to content
This repository has been archived by the owner on Nov 27, 2024. It is now read-only.

Commit

Permalink
Merge pull request #562 from OP2/fix/access-init
Browse files Browse the repository at this point in the history
Fix/access init
  • Loading branch information
dham authored May 10, 2019
2 parents 51d0509 + 13baae0 commit 9a07a33
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
8 changes: 4 additions & 4 deletions pyop2/codegen/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,11 @@ def pack(self, loop_indices=None):
if self.view_index is None:
shape = shape + self.outer.shape[1:]

if self.access in {INC, WRITE, MIN, MAX}:
if self.access in {INC, WRITE}:
val = Zero((), self.outer.dtype)
multiindex = MultiIndex(*(Index(e) for e in shape))
self._pack = Materialise(PackInst(), val, multiindex)
elif self.access in {READ, RW}:
elif self.access in {READ, RW, MIN, MAX}:
multiindex = MultiIndex(*(Index(e) for e in shape))
expr, mask = self._rvalue(multiindex, loop_indices=loop_indices)
if mask is not None:
Expand Down Expand Up @@ -262,11 +262,11 @@ def pack(self, loop_indices=None):
else:
_shape = (1,)

if self.access in {INC, WRITE, MIN, MAX}:
if self.access in {INC, WRITE}:
val = Zero((), self.dtype)
multiindex = MultiIndex(Index(flat_shape))
self._pack = Materialise(PackInst(), val, multiindex)
elif self.access in {READ, RW}:
elif self.access in {READ, RW, MIN, MAX}:
multiindex = MultiIndex(Index(flat_shape))
val = Zero((), self.dtype)
expressions = []
Expand Down
20 changes: 20 additions & 0 deletions test/unit/test_indirect_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,26 @@ def test_indirect_inc(self, iterset, unitset, iterset2unitset):
iterset, u(op2.INC, iterset2unitset))
assert u.data[0] == nelems

def test_indirect_max(self, iterset, indset, iterset2indset):
a = op2.Dat(indset, dtype=np.int32)
b = op2.Dat(indset, dtype=np.int32)
a.data[:] = -10
b.data[:] = -5
kernel = "static void maxify(int *a, int *b) {*a = *a < *b ? *b : *a;}\n"
op2.par_loop(op2.Kernel(kernel, "maxify"),
iterset, a(op2.MAX, iterset2indset), b(op2.READ, iterset2indset))
assert np.allclose(a.data_ro, -5)

def test_indirect_min(self, iterset, indset, iterset2indset):
a = op2.Dat(indset, dtype=np.int32)
b = op2.Dat(indset, dtype=np.int32)
a.data[:] = 10
b.data[:] = 5
kernel = "static void minify(int *a, int *b) {*a = *a > *b ? *b : *a;}\n"
op2.par_loop(op2.Kernel(kernel, "minify"),
iterset, a(op2.MIN, iterset2indset), b(op2.READ, iterset2indset))
assert np.allclose(a.data_ro, 5)

def test_global_read(self, iterset, x, iterset2indset):
"""Divide a Dat by a Global."""
g = op2.Global(1, 2, np.uint32, "g")
Expand Down

0 comments on commit 9a07a33

Please sign in to comment.