Skip to content

Commit

Permalink
Ndarray strides (#1506)
Browse files Browse the repository at this point in the history
The PR adds support for custom strides to `dace.ndarray.` The support
DOES NOT extend to `numpy.ndarray` and `cupy.ndarray.` Furthermore, the
stride unit is number of elements, in contrast to NumPy/CuPy, where it
is number of bytes.

---------

Co-authored-by: Tal Ben-Nun <[email protected]>
  • Loading branch information
alexnick83 and tbennun authored Feb 21, 2024
1 parent c92ecc5 commit 5189760
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 2 deletions.
9 changes: 7 additions & 2 deletions dace/frontend/python/replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,17 @@ def _define_local_ex(pv: ProgramVisitor,
state: SDFGState,
shape: Shape,
dtype: dace.typeclass,
strides: Optional[Shape] = None,
storage: dtypes.StorageType = dtypes.StorageType.Default,
lifetime: dtypes.AllocationLifetime = dtypes.AllocationLifetime.Scope):
""" Defines a local array in a DaCe program. """
if not isinstance(shape, (list, tuple)):
shape = [shape]
name, _ = sdfg.add_temp_transient(shape, dtype, storage=storage, lifetime=lifetime)
if strides is not None:
if not isinstance(strides, (list, tuple)):
strides = [strides]
strides = [int(s) if isinstance(s, Integral) else s for s in strides]
name, _ = sdfg.add_temp_transient(shape, dtype, strides=strides, storage=storage, lifetime=lifetime)
return name


Expand Down Expand Up @@ -4691,7 +4696,7 @@ def _define_cupy_local(
sdfg: SDFG,
state: SDFGState,
shape: Shape,
dtype: typeclass,
dtype: typeclass
):
"""Defines a local array in a DaCe program."""
if not isinstance(shape, (list, tuple)):
Expand Down
60 changes: 60 additions & 0 deletions tests/numpy/array_creation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,62 @@ def test_arange_6():
return np.arange(2.5, 10, 3)


@dace.program
def program_strides_0():
A = dace.ndarray((2, 2), dtype=dace.int32, strides=(2, 1))
for i, j in dace.map[0:2, 0:2]:
A[i, j] = i * 2 + j
return A


def test_strides_0():
A = program_strides_0()
assert A.strides == (8, 4)
assert np.allclose(A, [[0, 1], [2, 3]])


@dace.program
def program_strides_1():
A = dace.ndarray((2, 2), dtype=dace.int32, strides=(4, 2))
for i, j in dace.map[0:2, 0:2]:
A[i, j] = i * 2 + j
return A


def test_strides_1():
A = program_strides_1()
assert A.strides == (16, 8)
assert np.allclose(A, [[0, 1], [2, 3]])


@dace.program
def program_strides_2():
A = dace.ndarray((2, 2), dtype=dace.int32, strides=(1, 2))
for i, j in dace.map[0:2, 0:2]:
A[i, j] = i * 2 + j
return A


def test_strides_2():
A = program_strides_2()
assert A.strides == (4, 8)
assert np.allclose(A, [[0, 1], [2, 3]])


@dace.program
def program_strides_3():
A = dace.ndarray((2, 2), dtype=dace.int32, strides=(2, 4))
for i, j in dace.map[0:2, 0:2]:
A[i, j] = i * 2 + j
return A


def test_strides_3():
A = program_strides_3()
assert A.strides == (8, 16)
assert np.allclose(A, [[0, 1], [2, 3]])


if __name__ == "__main__":
test_empty()
test_empty_like1()
Expand All @@ -173,3 +229,7 @@ def test_arange_6():
test_arange_4()
test_arange_5()
test_arange_6()
test_strides_0()
test_strides_1()
test_strides_2()
test_strides_3()

0 comments on commit 5189760

Please sign in to comment.