Skip to content

Commit

Permalink
Add recently implemented performance improvement
Browse files Browse the repository at this point in the history
  • Loading branch information
bouweandela committed Jan 9, 2024
1 parent d21e97d commit c7fce55
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
23 changes: 22 additions & 1 deletion dask/array/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -2107,12 +2107,33 @@ def where(condition, x=None, y=None):
return elemwise(np.where, condition, x, y)


_no_nan_types = {
type(np.dtype(t))
for t in (
np.bool_,
np.int8,
np.int16,
np.int32,
np.int64,
np.uint8,
np.uint16,
np.uint32,
np.uint64,
)
}


@derived_from(np)
def array_equal(a1, a2, equal_nan=False, split_every=None):
if a1.shape != a2.shape:
return array([np.False_])[0]
cannot_have_nan = (
type(a1.dtype) in _no_nan_types and type(a2.dtype) in _no_nan_types
)
if (equal_nan or cannot_have_nan) and (a1 is a2):
return array([np.True_])[0]
equal = a1 == a2
if equal_nan:
if equal_nan and not cannot_have_nan:
equal = where(isnan(a1) & isnan(a2), True, equal)
return equal.all(split_every=split_every)

Expand Down
9 changes: 6 additions & 3 deletions dask/array/tests/test_routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -1908,7 +1908,7 @@ def _test_array_equal_parametrizations():
# where (e0 == e0).all() would raise
e0 = np.array(0, dtype="int")
e1 = np.array(1, dtype="float")
# x,y, nan_equal, expected_result
# a1, a2, equal_nan
yield (e0, e0.copy(), None)
yield (e0, e0.copy(), False)
yield (e0, e0.copy(), True)
Expand All @@ -1918,7 +1918,7 @@ def _test_array_equal_parametrizations():
yield (e1, e1.copy(), False)
yield (e1, e1.copy(), True)

# Non-nanable those cannot hold nans
# Non-nanable - those cannot hold nans
a12 = np.array([1, 2])
a12b = a12.copy()
a123 = np.array([1, 2, 3])
Expand Down Expand Up @@ -2026,7 +2026,10 @@ def _test_array_equal_parametrizations():
@pytest.mark.parametrize("a1,a2,equal_nan", _test_array_equal_parametrizations())
def test_array_equal(a1, a2, equal_nan):
d1 = da.asarray(a1, chunks=2)
d2 = da.asarray(a2, chunks=1)
if a1 is a2:
d2 = d1
else:
d2 = da.asarray(a2, chunks=1)
if equal_nan is None:
np_eq = np.array_equal(a1, a2)
da_eq = da.array_equal(d1, d2)
Expand Down

0 comments on commit c7fce55

Please sign in to comment.