From 3bc51bda7e34facf1b6d2aa9156f27c627d9a314 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 6 Aug 2024 20:31:02 -0600 Subject: [PATCH 01/37] Add GroupBy.shuffle() --- xarray/core/duck_array_ops.py | 15 +++++++++++ xarray/core/groupby.py | 48 +++++++++++++++++++++++++++++++++++ xarray/tests/__init__.py | 1 + xarray/tests/test_groupby.py | 18 ++++++++++++- 4 files changed, 81 insertions(+), 1 deletion(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 8993c136ba6..25bd86177df 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -831,3 +831,18 @@ def chunked_nanfirst(darray, axis): def chunked_nanlast(darray, axis): return _chunked_first_or_last(darray, axis, op=nputils.nanlast) + + +def shuffle_array(array, indices: list[list[int]], axis: int): + # TODO: do chunk manager dance here. + if is_duck_dask_array(array): + if not module_available("dask", minversion="2024.08.0"): + raise ValueError( + "This method is very inefficient on dask<2024.08.0. Please upgrade." + ) + # TODO: handle dimensions + return array.shuffle(indexer=indices, axis=axis) + else: + indexer = np.concatenate(indices) + # TODO: Do the array API thing here. + return np.take(array, indices=indexer, axis=axis) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 9b0758d030b..9fbf6778aea 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -517,6 +517,54 @@ def sizes(self) -> Mapping[Hashable, int]: self._sizes = self._obj.isel({self._group_dim: index}).sizes return self._sizes + def shuffle(self) -> None: + """ + Shuffle the underlying object so that all members in a group occur sequentially. + + The order of appearance is not guaranteed. This method modifies the underlying Xarray + object in place. + + Use this method first if you need to map a function that requires all members of a group + be in a single chunk. + """ + from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset + from xarray.core.duck_array_ops import shuffle_array + + (grouper,) = self.groupers + dim = self._group_dim + + # Slices mean this is already sorted. E.g. 
resampling ops, _DummyGroup + if all(isinstance(idx, slice) for idx in self._group_indices): + return + + was_array = isinstance(self._obj, DataArray) + as_dataset = self._obj._to_temp_dataset() if was_array else self._obj + + shuffled = Dataset() + for name, var in as_dataset._variables.items(): + if dim not in var.dims: + shuffled[name] = var + continue + shuffled_data = shuffle_array( + var._data, list(self._group_indices), axis=var.get_axis_num(dim) + ) + shuffled[name] = var._replace(data=shuffled_data) + + # Replace self._group_indices with slices + slices = [] + start = 0 + for idxr in self._group_indices: + slices.append(slice(start, start + len(idxr))) + start += len(idxr) + # TODO: we have now broken the invariant + # self._group_indices ≠ self.groupers[0].group_indices + self._group_indices = tuple(slices) + if was_array: + self._obj = self._obj._from_temp_dataset(shuffled) + else: + self._obj = shuffled + def map( self, func: Callable, diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 0caab6e8247..31d8e88dde1 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -106,6 +106,7 @@ def _importorskip( has_h5netcdf, requires_h5netcdf = _importorskip("h5netcdf") has_cftime, requires_cftime = _importorskip("cftime") has_dask, requires_dask = _importorskip("dask") +has_dask_ge_2024_08_0, _ = _importorskip("dask", minversion="2024.08.0") with warnings.catch_warnings(): warnings.filterwarnings( "ignore", diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 6c9254966d9..c41086cdf97 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -21,6 +21,7 @@ assert_identical, create_test_data, has_cftime, + has_dask_ge_2024_08_0, has_flox, requires_cftime, requires_dask, @@ -1293,11 +1294,26 @@ def test_groupby_sum(self) -> None: assert_allclose(expected_sum_axis1, grouped.reduce(np.sum, "y")) assert_allclose(expected_sum_axis1, grouped.sum("y")) + @pytest.mark.parametrize( + "shuffle", + [ + pytest.param( + True, + marks=pytest.mark.skipif( + not has_dask_ge_2024_08_0, reason="dask too old" + ), + ), + False, + ], + ) @pytest.mark.parametrize("method", ["sum", "mean", "median"]) - def test_groupby_reductions(self, method) -> None: + def test_groupby_reductions(self, method: str, shuffle: bool) -> None: array = self.da grouped = array.groupby("abc") + if shuffle: + grouped.shuffle() + reduction = getattr(np, method) expected = Dataset( { From 60d76197388180945434beae2e3cda4287e1254f Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 6 Aug 2024 20:36:50 -0600 Subject: [PATCH 02/37] Cleanup --- xarray/core/duck_array_ops.py | 15 --------------- xarray/core/groupby.py | 12 +++++++----- xarray/core/types.py | 2 +- xarray/core/variable.py | 18 +++++++++++++++++- xarray/namedarray/daskmanager.py | 9 +++++++++ xarray/namedarray/parallelcompat.py | 5 +++++ 6 files changed, 39 insertions(+), 22 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 25bd86177df..8993c136ba6 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -831,18 +831,3 @@ def chunked_nanfirst(darray, axis): def chunked_nanlast(darray, axis): return _chunked_first_or_last(darray, axis, op=nputils.nanlast) - - -def shuffle_array(array, indices: list[list[int]], axis: int): - # TODO: do chunk manager dance here. - if is_duck_dask_array(array): - if not module_available("dask", minversion="2024.08.0"): - raise ValueError( - "This method is very inefficient on dask<2024.08.0. 
Please upgrade." - ) - # TODO: handle dimensions - return array.shuffle(indexer=indices, axis=axis) - else: - indexer = np.concatenate(indices) - # TODO: Do the array API thing here. - return np.take(array, indices=indexer, axis=axis) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 9fbf6778aea..95a1680e6f0 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -529,7 +529,6 @@ def shuffle(self) -> None: """ from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset - from xarray.core.duck_array_ops import shuffle_array (grouper,) = self.groupers dim = self._group_dim @@ -538,6 +537,8 @@ def shuffle(self) -> None: if all(isinstance(idx, slice) for idx in self._group_indices): return + indices: tuple[list[int]] = self._group_indices # type: ignore[assignment] + was_array = isinstance(self._obj, DataArray) as_dataset = self._obj._to_temp_dataset() if was_array else self._obj @@ -546,21 +547,22 @@ def shuffle(self) -> None: if dim not in var.dims: shuffled[name] = var continue - shuffled_data = shuffle_array( - var._data, list(self._group_indices), axis=var.get_axis_num(dim) - ) - shuffled[name] = var._replace(data=shuffled_data) + shuffled[name] = var._shuffle(indices=list(indices), dim=dim) # Replace self._group_indices with slices slices = [] start = 0 for idxr in self._group_indices: + if TYPE_CHECKING: + assert not isinstance(idxr, slice) slices.append(slice(start, start + len(idxr))) start += len(idxr) # TODO: we have now broken the invariant # self._group_indices ≠ self.groupers[0].group_indices self._group_indices = tuple(slices) if was_array: + if TYPE_CHECKING: + assert isinstance(self._obj, DataArray) self._obj = self._obj._from_temp_dataset(shuffled) else: self._obj = shuffled diff --git a/xarray/core/types.py b/xarray/core/types.py index 591320d26da..96e75e18b51 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -297,7 +297,7 @@ def copy( ZarrWriteModes = Literal["w", "w-", "a", "a-", "r+", "r"] GroupKey = Any -GroupIndex = Union[int, slice, list[int]] +GroupIndex = Union[slice, list[int]] GroupIndices = tuple[GroupIndex, ...] 
Bins = Union[ int, Sequence[int], Sequence[float], Sequence[pd.Timestamp], np.ndarray, pd.Index diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 828c53e6187..b37959f2a38 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -44,7 +44,13 @@ maybe_coerce_to_str, ) from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions -from xarray.namedarray.pycompat import integer_types, is_0d_dask_array, to_duck_array +from xarray.namedarray.parallelcompat import get_chunked_array_type +from xarray.namedarray.pycompat import ( + integer_types, + is_0d_dask_array, + is_chunked_array, + to_duck_array, +) from xarray.util.deprecation_helpers import deprecate_dims NON_NUMPY_SUPPORTED_ARRAY_TYPES = ( @@ -998,6 +1004,16 @@ def compute(self, **kwargs): new = self.copy(deep=False) return new.load(**kwargs) + def _shuffle(self, indices: list[list[int]], dim: Hashable) -> Self: + array = self._data + if is_chunked_array(array): + chunkmanager = get_chunked_array_type(array) + return chunkmanager.shuffle( + array, indexer=indices, axis=self.get_axis_num(dim) + ) + else: + return self.isel({dim: np.concatenate(indices)}) + def isel( self, indexers: Mapping[Any, Any] | None = None, diff --git a/xarray/namedarray/daskmanager.py b/xarray/namedarray/daskmanager.py index 963d12fd865..aa4ced9f37a 100644 --- a/xarray/namedarray/daskmanager.py +++ b/xarray/namedarray/daskmanager.py @@ -251,3 +251,12 @@ def store( targets=targets, **kwargs, ) + + def shuffle(self, x: DaskArray, indexer: list[list[int]], axis: int) -> DaskArray: + import dask.array + + if not module_available("dask", minversion="2024.08.0"): + raise ValueError( + "This method is very inefficient on dask<2024.08.0. Please upgrade." + ) + return dask.array.shuffle(x, indexer, axis) diff --git a/xarray/namedarray/parallelcompat.py b/xarray/namedarray/parallelcompat.py index dd555fe200a..f3c73027a8a 100644 --- a/xarray/namedarray/parallelcompat.py +++ b/xarray/namedarray/parallelcompat.py @@ -364,6 +364,11 @@ def compute( """ raise NotImplementedError() + def shuffle( + self, x: T_ChunkedArray, indexer: list[list[int]], axis: int + ) -> T_ChunkedArray: + raise NotImplementedError() + @property def array_api(self) -> Any: """ From d1429cd06373d16111cda5f8942a977f324c78a8 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 7 Aug 2024 09:23:58 -0600 Subject: [PATCH 03/37] Cleanup --- xarray/core/groupby.py | 10 ++-------- xarray/core/variable.py | 7 ++++++- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 95a1680e6f0..b62fb399023 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -532,13 +532,6 @@ def shuffle(self) -> None: (grouper,) = self.groupers dim = self._group_dim - - # Slices mean this is already sorted. E.g. 
resampling ops, _DummyGroup - if all(isinstance(idx, slice) for idx in self._group_indices): - return - - indices: tuple[list[int]] = self._group_indices # type: ignore[assignment] - was_array = isinstance(self._obj, DataArray) as_dataset = self._obj._to_temp_dataset() if was_array else self._obj @@ -547,7 +540,7 @@ def shuffle(self) -> None: if dim not in var.dims: shuffled[name] = var continue - shuffled[name] = var._shuffle(indices=list(indices), dim=dim) + shuffled[name] = var._shuffle(indices=list(self._group_indices), dim=dim) # Replace self._group_indices with slices slices = [] @@ -557,6 +550,7 @@ def shuffle(self) -> None: assert not isinstance(idxr, slice) slices.append(slice(start, start + len(idxr))) start += len(idxr) + # TODO: we have now broken the invariant # self._group_indices ≠ self.groupers[0].group_indices self._group_indices = tuple(slices) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index b37959f2a38..9272303061a 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1004,7 +1004,12 @@ def compute(self, **kwargs): new = self.copy(deep=False) return new.load(**kwargs) - def _shuffle(self, indices: list[list[int]], dim: Hashable) -> Self: + def _shuffle(self, indices: list[slice | list[int]], dim: Hashable) -> Self: + size = self.sizes[dim] + indices: list[list[int]] = [ + list(range(*idx.indices(size))) if isinstance(idx, slice) else idx + for idx in self._group_indices + ] array = self._data if is_chunked_array(array): chunkmanager = get_chunked_array_type(array) From 31fc00e5382470611d32e0aa14b6dd55ca99e05c Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 7 Aug 2024 11:18:59 -0600 Subject: [PATCH 04/37] fix --- xarray/core/variable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 9272303061a..b78e316b258 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1008,7 +1008,7 @@ def _shuffle(self, indices: list[slice | list[int]], dim: Hashable) -> Self: size = self.sizes[dim] indices: list[list[int]] = [ list(range(*idx.indices(size))) if isinstance(idx, slice) else idx - for idx in self._group_indices + for idx in indices ] array = self._data if is_chunked_array(array): From 458385364687d774d9d7d34bbfe7ed3a260af70c Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 13 Aug 2024 15:39:25 -0600 Subject: [PATCH 05/37] return groupby instance from shuffle --- xarray/core/groupby.py | 34 +++++++++------------------- xarray/core/variable.py | 10 ++++---- xarray/groupers.py | 31 ++++++++++++++++++++++++- xarray/tests/test_groupby.py | 44 ++++++++++++++++-------------------- 4 files changed, 66 insertions(+), 53 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index b62fb399023..99ce647c3fc 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -4,7 +4,7 @@ import warnings from collections.abc import Hashable, Iterator, Mapping, Sequence from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Self, Union import numpy as np import pandas as pd @@ -517,12 +517,11 @@ def sizes(self) -> Mapping[Hashable, int]: self._sizes = self._obj.isel({self._group_dim: index}).sizes return self._sizes - def shuffle(self) -> None: + def shuffle(self) -> Self: """ Shuffle the underlying object so that all members in a group occur sequentially. - The order of appearance is not guaranteed. 
This method modifies the underlying Xarray - object in place. + The order of appearance is not guaranteed. Use this method first if you need to map a function that requires all members of a group be in a single chunk. @@ -536,30 +535,19 @@ def shuffle(self) -> None: as_dataset = self._obj._to_temp_dataset() if was_array else self._obj shuffled = Dataset() + if grouper.name not in as_dataset._variables: + as_dataset.coords[grouper.name] = grouper.group1d for name, var in as_dataset._variables.items(): if dim not in var.dims: shuffled[name] = var continue shuffled[name] = var._shuffle(indices=list(self._group_indices), dim=dim) - - # Replace self._group_indices with slices - slices = [] - start = 0 - for idxr in self._group_indices: - if TYPE_CHECKING: - assert not isinstance(idxr, slice) - slices.append(slice(start, start + len(idxr))) - start += len(idxr) - - # TODO: we have now broken the invariant - # self._group_indices ≠ self.groupers[0].group_indices - self._group_indices = tuple(slices) - if was_array: - if TYPE_CHECKING: - assert isinstance(self._obj, DataArray) - self._obj = self._obj._from_temp_dataset(shuffled) - else: - self._obj = shuffled + shuffled = self._maybe_unstack(shuffled) + new_obj = self._obj._from_temp_dataset(shuffled) if was_array else shuffled + return new_obj.groupby( + {grouper.name: grouper.grouper.reset()}, + restore_coord_dims=self._restore_coord_dims, + ) def map( self, diff --git a/xarray/core/variable.py b/xarray/core/variable.py index b78e316b258..b32e1cbfafe 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1006,18 +1006,20 @@ def compute(self, **kwargs): def _shuffle(self, indices: list[slice | list[int]], dim: Hashable) -> Self: size = self.sizes[dim] - indices: list[list[int]] = [ + no_slices: list[list[int]] = [ list(range(*idx.indices(size))) if isinstance(idx, slice) else idx for idx in indices ] array = self._data if is_chunked_array(array): chunkmanager = get_chunked_array_type(array) - return chunkmanager.shuffle( - array, indexer=indices, axis=self.get_axis_num(dim) + return self._replace( + data=chunkmanager.shuffle( + array, indexer=no_slices, axis=self.get_axis_num(dim) + ) ) else: - return self.isel({dim: np.concatenate(indices)}) + return self.isel({dim: np.concatenate(no_slices)}) def isel( self, diff --git a/xarray/groupers.py b/xarray/groupers.py index becb005b66c..0c6619f89bd 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -9,7 +9,7 @@ import datetime from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Any, Literal, cast +from typing import Any, Literal, Self, cast import numpy as np import pandas as pd @@ -90,6 +90,13 @@ def factorize(self, group: T_Group) -> EncodedGroups: """ pass + @abstractmethod + def reset(self) -> Self: + """ + Creates a new version of this Grouper clearing any caches. 
+ """ + pass + class Resampler(Grouper): """ @@ -114,6 +121,9 @@ def group_as_index(self) -> pd.Index: self._group_as_index = self.group.to_index() return self._group_as_index + def reset(self) -> Self: + return type(self)() + def factorize(self, group1d: T_Group) -> EncodedGroups: self.group = group1d @@ -221,6 +231,16 @@ class BinGrouper(Grouper): include_lowest: bool = False duplicates: Literal["raise", "drop"] = "raise" + def reset(self) -> Self: + return type(self)( + bins=self.bins, + right=self.right, + labels=self.labels, + precision=self.precision, + include_lowest=self.include_lowest, + duplicates=self.duplicates, + ) + def __post_init__(self) -> None: if duck_array_ops.isnull(self.bins).all(): raise ValueError("All bin edges are NaN.") @@ -302,6 +322,15 @@ class TimeResampler(Resampler): index_grouper: CFTimeGrouper | pd.Grouper = field(init=False, repr=False) group_as_index: pd.Index = field(init=False, repr=False) + def reset(self) -> Self: + return type(self)( + freq=self.freq, + closed=self.closed, + label=self.label, + origin=self.origin, + offset=self.offset, + ) + def _init_properties(self, group: T_Group) -> None: from xarray import CFTimeIndex diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index c41086cdf97..11c57ef465e 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -23,6 +23,7 @@ has_cftime, has_dask_ge_2024_08_0, has_flox, + raise_if_dask_computes, requires_cftime, requires_dask, requires_flox, @@ -1294,26 +1295,19 @@ def test_groupby_sum(self) -> None: assert_allclose(expected_sum_axis1, grouped.reduce(np.sum, "y")) assert_allclose(expected_sum_axis1, grouped.sum("y")) - @pytest.mark.parametrize( - "shuffle", - [ - pytest.param( - True, - marks=pytest.mark.skipif( - not has_dask_ge_2024_08_0, reason="dask too old" - ), - ), - False, - ], - ) + @pytest.mark.parametrize("use_flox", [True, False]) + @pytest.mark.parametrize("shuffle", [True, False]) + @pytest.mark.parametrize("chunk", [True, False]) @pytest.mark.parametrize("method", ["sum", "mean", "median"]) - def test_groupby_reductions(self, method: str, shuffle: bool) -> None: - array = self.da - grouped = array.groupby("abc") - - if shuffle: - grouped.shuffle() + def test_groupby_reductions( + self, use_flox: bool, method: str, shuffle: bool, chunk: bool + ) -> None: + if shuffle and chunk and not has_dask_ge_2024_08_0: + pytest.skip() + array = self.da + if chunk: + array.data = array.chunk({"y": 5}).data reduction = getattr(np, method) expected = Dataset( { @@ -1331,14 +1325,14 @@ def test_groupby_reductions(self, method: str, shuffle: bool) -> None: } )["foo"] - with xr.set_options(use_flox=False): - actual_legacy = getattr(grouped, method)(dim="y") - - with xr.set_options(use_flox=True): - actual_npg = getattr(grouped, method)(dim="y") + with raise_if_dask_computes(): + grouped = array.groupby("abc") + if shuffle: + grouped = grouped.shuffle() - assert_allclose(expected, actual_legacy) - assert_allclose(expected, actual_npg) + with xr.set_options(use_flox=use_flox): + actual = getattr(grouped, method)(dim="y") + assert_allclose(expected, actual) def test_groupby_count(self) -> None: array = DataArray( From abd9dd27c27d6613629a9f6549a64737361d2f91 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 13 Aug 2024 17:01:59 -0600 Subject: [PATCH 06/37] Fix nD by --- xarray/core/groupby.py | 26 ++++++++++++++++++++------ xarray/core/variable.py | 12 ++++-------- xarray/tests/test_groupby.py | 18 ++++++++++++++---- 3 files changed, 38 insertions(+), 18 
deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 99ce647c3fc..8f9db486f8e 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -45,6 +45,7 @@ peek_at, ) from xarray.core.variable import IndexVariable, Variable +from xarray.namedarray.pycompat import is_chunked_array from xarray.util.deprecation_helpers import _deprecate_positional_args if TYPE_CHECKING: @@ -527,20 +528,33 @@ def shuffle(self) -> Self: be in a single chunk. """ from xarray.core.dataarray import DataArray - from xarray.core.dataset import Dataset (grouper,) = self.groupers dim = self._group_dim + size = self._obj.sizes[dim] was_array = isinstance(self._obj, DataArray) as_dataset = self._obj._to_temp_dataset() if was_array else self._obj + no_slices: list[list[int]] = [ + list(range(*idx.indices(size))) if isinstance(idx, slice) else idx + for idx in self._group_indices + ] - shuffled = Dataset() if grouper.name not in as_dataset._variables: as_dataset.coords[grouper.name] = grouper.group1d - for name, var in as_dataset._variables.items(): - if dim not in var.dims: - shuffled[name] = var - continue + + # Shuffling is only different from `isel` for chunked arrays. + # Extract them out, and treat them specially. The rest, we route through isel. + # This makes it easy to ensure correct handling of indexes. + is_chunked = { + name: var + for name, var in as_dataset._variables.items() + if is_chunked_array(var._data) + } + subset = as_dataset[ + [name for name in as_dataset._variables if name not in is_chunked] + ] + shuffled = subset.isel({dim: np.concatenate(no_slices)}) + for name, var in is_chunked.items(): shuffled[name] = var._shuffle(indices=list(self._group_indices), dim=dim) shuffled = self._maybe_unstack(shuffled) new_obj = self._obj._from_temp_dataset(shuffled) if was_array else shuffled diff --git a/xarray/core/variable.py b/xarray/core/variable.py index b32e1cbfafe..9f73bc5c71b 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1004,22 +1004,18 @@ def compute(self, **kwargs): new = self.copy(deep=False) return new.load(**kwargs) - def _shuffle(self, indices: list[slice | list[int]], dim: Hashable) -> Self: - size = self.sizes[dim] - no_slices: list[list[int]] = [ - list(range(*idx.indices(size))) if isinstance(idx, slice) else idx - for idx in indices - ] + def _shuffle(self, indices: list[list[int]], dim: Hashable) -> Self: array = self._data if is_chunked_array(array): chunkmanager = get_chunked_array_type(array) return self._replace( data=chunkmanager.shuffle( - array, indexer=no_slices, axis=self.get_axis_num(dim) + array, indexer=indices, axis=self.get_axis_num(dim) ) ) else: - return self.isel({dim: np.concatenate(no_slices)}) + assert False, "this should be unreachable" + return self.isel({dim: np.concatenate(indices)}) def isel( self, diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 11c57ef465e..e3a52fc299f 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -2,6 +2,7 @@ import operator import warnings +from typing import Literal from unittest import mock import numpy as np @@ -584,7 +585,12 @@ def test_groupby_repr_datetime(obj) -> None: @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") @pytest.mark.filterwarnings("ignore:invalid value encountered in divide:RuntimeWarning") -def test_groupby_drops_nans() -> None: +@pytest.mark.parametrize("shuffle", [True, False]) +@pytest.mark.parametrize("chunk", [dict(lat=1), dict(lat=2, lon=2), False]) +def 
test_groupby_drops_nans(shuffle: bool, chunk: Literal[False] | dict) -> None: + xr.set_options(use_flox=False) # TODO: remove + if shuffle and chunk and not has_dask_ge_2024_08_0: + pytest.skip() # GH2383 # nan in 2D data variable (requires stacking) ds = xr.Dataset( @@ -599,13 +605,17 @@ def test_groupby_drops_nans() -> None: ds["id"].values[3, 0] = np.nan ds["id"].values[-1, -1] = np.nan + if chunk: + ds = ds.chunk(chunk) grouped = ds.groupby(ds.id) + if shuffle: + grouped = grouped.shuffle() # non reduction operation expected1 = ds.copy() - expected1.variable.values[0, 0, :] = np.nan - expected1.variable.values[-1, -1, :] = np.nan - expected1.variable.values[3, 0, :] = np.nan + expected1.variable.data[0, 0, :] = np.nan + expected1.variable.data[-1, -1, :] = np.nan + expected1.variable.data[3, 0, :] = np.nan actual1 = grouped.map(lambda x: x).transpose(*ds.variable.dims) assert_identical(actual1, expected1) From 0d70656a03644bd9b5b0b0059bf5d262e3c895b0 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 13 Aug 2024 21:19:41 -0600 Subject: [PATCH 07/37] Skip if no dask --- xarray/tests/test_groupby.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index e3a52fc299f..fb93866d8d1 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -22,6 +22,7 @@ assert_identical, create_test_data, has_cftime, + has_dask, has_dask_ge_2024_08_0, has_flox, raise_if_dask_computes, @@ -586,7 +587,18 @@ def test_groupby_repr_datetime(obj) -> None: @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") @pytest.mark.filterwarnings("ignore:invalid value encountered in divide:RuntimeWarning") @pytest.mark.parametrize("shuffle", [True, False]) -@pytest.mark.parametrize("chunk", [dict(lat=1), dict(lat=2, lon=2), False]) +@pytest.mark.parametrize( + "chunk", + [ + pytest.param( + dict(lat=1), marks=pytest.mark.skipif(not has_dask, reason="no dask") + ), + pytest.param( + dict(lat=2, lon=2), marks=pytest.mark.skipif(not has_dask, reason="no dask") + ), + False, + ], +) def test_groupby_drops_nans(shuffle: bool, chunk: Literal[False] | dict) -> None: xr.set_options(use_flox=False) # TODO: remove if shuffle and chunk and not has_dask_ge_2024_08_0: @@ -744,7 +756,6 @@ def test_groupby_none_group_name() -> None: def test_groupby_getitem(dataset) -> None: - assert_identical(dataset.sel(x=["a"]), dataset.groupby("x")["a"]) assert_identical(dataset.sel(z=[1]), dataset.groupby("z")[1]) assert_identical(dataset.foo.sel(x=["a"]), dataset.foo.groupby("x")["a"]) @@ -1307,7 +1318,15 @@ def test_groupby_sum(self) -> None: @pytest.mark.parametrize("use_flox", [True, False]) @pytest.mark.parametrize("shuffle", [True, False]) - @pytest.mark.parametrize("chunk", [True, False]) + @pytest.mark.parametrize( + "chunk", + [ + pytest.param( + True, marks=pytest.mark.skipif(not has_dask, reason="no dask") + ), + False, + ], + ) @pytest.mark.parametrize("method", ["sum", "mean", "median"]) def test_groupby_reductions( self, use_flox: bool, method: str, shuffle: bool, chunk: bool From fafb937778285ea38bc2cbc1407ba4c072cc2944 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 13 Aug 2024 21:21:58 -0600 Subject: [PATCH 08/37] fix tests --- xarray/tests/test_groupby.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index fb93866d8d1..368ec29684d 100644 --- a/xarray/tests/test_groupby.py +++ 
b/xarray/tests/test_groupby.py @@ -600,7 +600,6 @@ def test_groupby_repr_datetime(obj) -> None: ], ) def test_groupby_drops_nans(shuffle: bool, chunk: Literal[False] | dict) -> None: - xr.set_options(use_flox=False) # TODO: remove if shuffle and chunk and not has_dask_ge_2024_08_0: pytest.skip() # GH2383 @@ -632,7 +631,8 @@ def test_groupby_drops_nans(shuffle: bool, chunk: Literal[False] | dict) -> None assert_identical(actual1, expected1) # reduction along grouped dimension - actual2 = grouped.mean() + with xr.set_options(use_flox=False): # TODO: remove + actual2 = grouped.mean() stacked = ds.stack({"xy": ["lat", "lon"]}) expected2 = ( stacked.variable.where(stacked.id.notnull()) @@ -2573,6 +2573,9 @@ def factorize(self, group) -> EncodedGroups: codes = group.copy(data=codes_).rename("year") return EncodedGroups(codes=codes, full_index=pd.Index(uniques)) + def reset(self): + return type(self)() + da = xr.DataArray( dims="time", data=np.arange(20), From a08450efa2ebf796631b4b08509043f246944186 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 14 Aug 2024 13:44:08 -0600 Subject: [PATCH 09/37] Add `chunks` to signature --- xarray/core/groupby.py | 8 +++++--- xarray/namedarray/daskmanager.py | 10 ++++++---- xarray/namedarray/parallelcompat.py | 3 ++- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 4a6093ff21b..d45e8d93104 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -53,7 +53,7 @@ from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset - from xarray.core.types import GroupIndex, GroupIndices, GroupKey + from xarray.core.types import GroupIndex, GroupIndices, GroupKey, T_Chunks from xarray.core.utils import Frozen from xarray.groupers import Grouper @@ -518,7 +518,7 @@ def sizes(self) -> Mapping[Hashable, int]: self._sizes = self._obj.isel({self._group_dim: index}).sizes return self._sizes - def shuffle(self) -> Self: + def shuffle(self, chunks: T_Chunks = "auto") -> Self: """ Shuffle the underlying object so that all members in a group occur sequentially. @@ -555,7 +555,9 @@ def shuffle(self) -> Self: ] shuffled = subset.isel({dim: np.concatenate(no_slices)}) for name, var in is_chunked.items(): - shuffled[name] = var._shuffle(indices=list(self._group_indices), dim=dim) + shuffled[name] = var._shuffle( + indices=list(self._group_indices), dim=dim, chunks=chunks + ) shuffled = self._maybe_unstack(shuffled) new_obj = self._obj._from_temp_dataset(shuffled) if was_array else shuffled return new_obj.groupby( diff --git a/xarray/namedarray/daskmanager.py b/xarray/namedarray/daskmanager.py index e3089472051..16206264f9e 100644 --- a/xarray/namedarray/daskmanager.py +++ b/xarray/namedarray/daskmanager.py @@ -252,11 +252,13 @@ def store( **kwargs, ) - def shuffle(self, x: DaskArray, indexer: list[list[int]], axis: int) -> DaskArray: + def shuffle( + self, x: DaskArray, indexer: list[list[int]], axis: int, chunks: T_Chunks + ) -> DaskArray: import dask.array - if not module_available("dask", minversion="2024.08.0"): + if not module_available("dask", minversion="2024.08.1"): raise ValueError( - "This method is very inefficient on dask<2024.08.0. Please upgrade." + "This method is very inefficient on dask<2024.08.1. Please upgrade." 
) - return dask.array.shuffle(x, indexer, axis) + return dask.array.shuffle(x, indexer, axis, chunks=chunks) diff --git a/xarray/namedarray/parallelcompat.py b/xarray/namedarray/parallelcompat.py index 5271712e691..dbbc9df0d95 100644 --- a/xarray/namedarray/parallelcompat.py +++ b/xarray/namedarray/parallelcompat.py @@ -19,6 +19,7 @@ if TYPE_CHECKING: from xarray.namedarray._typing import ( + T_Chunks, _Chunks, _DType, _DType_co, @@ -357,7 +358,7 @@ def compute( raise NotImplementedError() def shuffle( - self, x: T_ChunkedArray, indexer: list[list[int]], axis: int + self, x: T_ChunkedArray, indexer: list[list[int]], axis: int, chunks: T_Chunks ) -> T_ChunkedArray: raise NotImplementedError() From d0cd218b0ea6cd7ad29481d5fc63fd4a5db03d99 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 14 Aug 2024 13:53:40 -0600 Subject: [PATCH 10/37] FIx self --- xarray/core/groupby.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index d45e8d93104..9329e637fd8 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1,10 +1,11 @@ from __future__ import annotations import copy +import sys import warnings from collections.abc import Callable, Hashable, Iterator, Mapping, Sequence from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Generic, Literal, Self, Union +from typing import TYPE_CHECKING, Any, Generic, Literal, Union import numpy as np import pandas as pd @@ -57,6 +58,11 @@ from xarray.core.utils import Frozen from xarray.groupers import Grouper + if sys.version_info >= (3, 11): + from typing import Self + else: + from typing_extensions import Self + def check_reduce_dims(reduce_dims, dimensions): if reduce_dims is not ...: From 4edc976a9c26dc592484cc168ae9e6ac85e33fb0 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 14 Aug 2024 13:57:07 -0600 Subject: [PATCH 11/37] Another Self fix --- xarray/groupers.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/xarray/groupers.py b/xarray/groupers.py index dd3456ac533..b8dfa14c522 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -7,13 +7,19 @@ from __future__ import annotations import datetime +import sys from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Any, Literal, Self, cast +from typing import Any, Literal, cast import numpy as np import pandas as pd +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + from xarray.coding.cftime_offsets import _new_to_legacy_freq from xarray.core import duck_array_ops from xarray.core.dataarray import DataArray From 0b42be48798211daa060d4fb3a061d4b2bfb33a1 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 14 Aug 2024 14:12:19 -0600 Subject: [PATCH 12/37] Forward chunks too --- xarray/core/variable.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 7a949ba8d8c..a68139cb5a7 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1004,13 +1004,18 @@ def compute(self, **kwargs): new = self.copy(deep=False) return new.load(**kwargs) - def _shuffle(self, indices: list[list[int]], dim: Hashable) -> Self: + def _shuffle( + self, indices: list[list[int]], dim: Hashable, chunks: T_Chunks + ) -> Self: array = self._data if is_chunked_array(array): chunkmanager = get_chunked_array_type(array) return self._replace( data=chunkmanager.shuffle( - array, indexer=indices, 
axis=self.get_axis_num(dim) + array, + indexer=indices, + axis=self.get_axis_num(dim), + chunks=chunks, ) ) else: From c52734dd8bf1569a43987f9c8fcb90dcd6cad244 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 14 Aug 2024 14:13:08 -0600 Subject: [PATCH 13/37] [revert] --- xarray/namedarray/daskmanager.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xarray/namedarray/daskmanager.py b/xarray/namedarray/daskmanager.py index 16206264f9e..c1b6f92ad0a 100644 --- a/xarray/namedarray/daskmanager.py +++ b/xarray/namedarray/daskmanager.py @@ -261,4 +261,6 @@ def shuffle( raise ValueError( "This method is very inefficient on dask<2024.08.1. Please upgrade." ) - return dask.array.shuffle(x, indexer, axis, chunks=chunks) + if chunks is not None: + raise NotImplementedError + return dask.array.shuffle(x, indexer, axis, chunks=None) From 81806253ebba10e9a8212962b83c434e2808c502 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 14 Aug 2024 14:13:51 -0600 Subject: [PATCH 14/37] undo flox limit --- xarray/tests/test_groupby.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 368ec29684d..e8b718d2deb 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -631,8 +631,7 @@ def test_groupby_drops_nans(shuffle: bool, chunk: Literal[False] | dict) -> None assert_identical(actual1, expected1) # reduction along grouped dimension - with xr.set_options(use_flox=False): # TODO: remove - actual2 = grouped.mean() + actual2 = grouped.mean() stacked = ds.stack({"xy": ["lat", "lon"]}) expected2 = ( stacked.variable.where(stacked.id.notnull()) From 7897c919eaa3952274888d11db5e805eb0d9f46a Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 14 Aug 2024 14:20:29 -0600 Subject: [PATCH 15/37] [revert] --- xarray/core/groupby.py | 2 +- xarray/namedarray/daskmanager.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 9329e637fd8..c671dfe0906 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -524,7 +524,7 @@ def sizes(self) -> Mapping[Hashable, int]: self._sizes = self._obj.isel({self._group_dim: index}).sizes return self._sizes - def shuffle(self, chunks: T_Chunks = "auto") -> Self: + def shuffle(self, chunks: T_Chunks = None) -> Self: """ Shuffle the underlying object so that all members in a group occur sequentially. diff --git a/xarray/namedarray/daskmanager.py b/xarray/namedarray/daskmanager.py index c1b6f92ad0a..ddc5ab88c6e 100644 --- a/xarray/namedarray/daskmanager.py +++ b/xarray/namedarray/daskmanager.py @@ -257,10 +257,10 @@ def shuffle( ) -> DaskArray: import dask.array - if not module_available("dask", minversion="2024.08.1"): + if not module_available("dask", minversion="2024.08.0"): raise ValueError( - "This method is very inefficient on dask<2024.08.1. Please upgrade." + "This method is very inefficient on dask<2024.08.0. Please upgrade." 
) if chunks is not None: raise NotImplementedError - return dask.array.shuffle(x, indexer, axis, chunks=None) + return dask.array.shuffle(x, indexer, axis) From 7773548171e2c368d42283c870f71ec30d663f42 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 14 Aug 2024 15:03:24 -0600 Subject: [PATCH 16/37] fix types --- xarray/core/groupby.py | 8 +------- xarray/groupers.py | 8 +------- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index c671dfe0906..4dde2318772 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1,7 +1,6 @@ from __future__ import annotations import copy -import sys import warnings from collections.abc import Callable, Hashable, Iterator, Mapping, Sequence from dataclasses import dataclass, field @@ -54,15 +53,10 @@ from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset - from xarray.core.types import GroupIndex, GroupIndices, GroupKey, T_Chunks + from xarray.core.types import GroupIndex, GroupIndices, GroupKey, Self, T_Chunks from xarray.core.utils import Frozen from xarray.groupers import Grouper - if sys.version_info >= (3, 11): - from typing import Self - else: - from typing_extensions import Self - def check_reduce_dims(reduce_dims, dimensions): if reduce_dims is not ...: diff --git a/xarray/groupers.py b/xarray/groupers.py index b8dfa14c522..5402e36c8a6 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -7,7 +7,6 @@ from __future__ import annotations import datetime -import sys from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Any, Literal, cast @@ -15,18 +14,13 @@ import numpy as np import pandas as pd -if sys.version_info >= (3, 11): - from typing import Self -else: - from typing_extensions import Self - from xarray.coding.cftime_offsets import _new_to_legacy_freq from xarray.core import duck_array_ops from xarray.core.dataarray import DataArray from xarray.core.groupby import T_Group, _DummyGroup from xarray.core.indexes import safe_cast_to_index from xarray.core.resample_cftime import CFTimeGrouper -from xarray.core.types import Bins, DatetimeLike, GroupIndices, SideOptions +from xarray.core.types import Bins, DatetimeLike, GroupIndices, Self, SideOptions from xarray.core.variable import Variable __all__ = [ From 51a7723605f3b977a4cbc3ffb33984b870a1d300 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 14 Aug 2024 20:59:17 -0600 Subject: [PATCH 17/37] Add DataArray.shuffle_by, Dataset.shuffle_by --- doc/api.rst | 4 ++++ xarray/core/common.py | 47 +++++++++++++++++++++++++++++++++++++++++- xarray/core/groupby.py | 33 +++++++++++++++++++++++++---- 3 files changed, 79 insertions(+), 5 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 6ed8d513934..336ec16f5ea 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -250,6 +250,7 @@ Reshaping and reorganizing Dataset.roll Dataset.pad Dataset.sortby + Dataset.shuffle_by Dataset.broadcast_like DataArray @@ -588,6 +589,7 @@ Reshaping and reorganizing DataArray.roll DataArray.pad DataArray.sortby + DataArray.shuffle_by DataArray.broadcast_like IO / Conversion @@ -771,6 +773,7 @@ Dataset DatasetGroupBy.var DatasetGroupBy.dims DatasetGroupBy.groups + DatasetGroupBy.shuffle DataArray --------- @@ -802,6 +805,7 @@ DataArray DataArrayGroupBy.var DataArrayGroupBy.dims DataArrayGroupBy.groups + DataArrayGroupBy.shuffle Grouper Objects --------------- diff --git a/xarray/core/common.py b/xarray/core/common.py index 664de7146d7..4afae9d4e68 100644 --- 
a/xarray/core/common.py +++ b/xarray/core/common.py @@ -52,7 +52,7 @@ T_Variable, ) from xarray.core.variable import Variable - from xarray.groupers import Resampler + from xarray.groupers import Grouper, Resampler DTypeMaybeMapping = Union[DTypeLikeSave, Mapping[Any, DTypeLikeSave]] @@ -874,6 +874,51 @@ def rolling_exp( return rolling_exp.RollingExp(self, window, window_type) + def shuffle_by(self, **groupers: Grouper) -> Self: + """ + Shuffle this object by a Grouper. + + Parameters + ---------- + **groupers : Grouper + Grouper objects using which to shuffle the data. + + Examples + -------- + >>> import dask + >>> from xarray.groupers import UniqueGrouper + >>> da = xr.DataArray( + ... dims="x", + ... data=dask.array.arange(10, chunks=1), + ... coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]}, + ... name="a", + ... ) + >>> da + Size: 80B + dask.array + Coordinates: + * x (x) int64 80B 1 2 3 1 2 3 1 2 3 0 + + >>> da.shuffle_by(x=UniqueGrouper()) + Size: 80B + dask.array + Coordinates: + * x (x) int64 80B 0 1 1 1 2 2 2 3 3 3 + + Returns + ------- + DataArray or Dataset + The same type as this object + + See Also + -------- + DataArrayGroupBy.shuffle + DatasetGroupBy.shuffle + dask.dataframe.DataFrame.shuffle + dask.array.shuffle + """ + return self.groupby(**groupers).shuffle()._obj + def _resample( self, resample_cls: type[T_Resample], diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 4dde2318772..e70c8fc92ad 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -520,12 +520,37 @@ def sizes(self) -> Mapping[Hashable, int]: def shuffle(self, chunks: T_Chunks = None) -> Self: """ - Shuffle the underlying object so that all members in a group occur sequentially. + Sort or "shuffle" the underlying object so that all members in a group occur sequentially. - The order of appearance is not guaranteed. + This method is particularly useful for chunked arrays (e.g. dask, cubed). + particularly when you need to map a function that requires all members of a group + to be present in a single chunk. + For chunked array types, the order of appearance is not guaranteed, but will depend on + the input chunking. - Use this method first if you need to map a function that requires all members of a group - be in a single chunk. + Parameters + ---------- + chunks : int, tuple of int, "auto" or mapping of hashable to int or a TimeResampler, optional + How to adjust chunks along dimensions not present in the array being grouped by. + + Returns + ------- + DataArrayGroupBy or DatasetGroupBy + + Examples + -------- + >>> da = xr.DataArray( + ... dims="x", + ... data=dask.array.arange(10, chunks=3), + ... coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]}, + ... ) + >>> shuffled = da.groupby("x").shuffle() + >>> shuffled.quantile().compute() + + See Also + -------- + dask.dataframe.shuffle + dask.array.shuffle """ from xarray.core.dataarray import DataArray From cc9551336b0603d66dabdfbb9e4de614b68d85d6 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 14 Aug 2024 21:06:01 -0600 Subject: [PATCH 18/37] Add doctest --- xarray/core/groupby.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index e70c8fc92ad..088b61c5eae 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -539,13 +539,20 @@ def shuffle(self, chunks: T_Chunks = None) -> Self: Examples -------- + >>> import dask >>> da = xr.DataArray( ... dims="x", ... data=dask.array.arange(10, chunks=3), ... 
coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]}, + ... name="a", ... ) >>> shuffled = da.groupby("x").shuffle() - >>> shuffled.quantile().compute() + >>> shuffled.quantile(q=0.5).compute() + Size: 32B + array([9., 3., 4., 5.]) + Coordinates: + quantile float64 8B 0.5 + * x (x) int64 32B 0 1 2 3 See Also -------- From 18f4a40790cb9330ca300830e47e956aff4c4693 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 14 Aug 2024 21:08:05 -0600 Subject: [PATCH 19/37] Refactor --- xarray/core/common.py | 2 +- xarray/core/groupby.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index 4afae9d4e68..edf55d02652 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -917,7 +917,7 @@ def shuffle_by(self, **groupers: Grouper) -> Self: dask.dataframe.DataFrame.shuffle dask.array.shuffle """ - return self.groupby(**groupers).shuffle()._obj + return self.groupby(**groupers)._shuffle_obj() def _resample( self, diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 088b61c5eae..1abf4dd5d40 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -559,6 +559,13 @@ def shuffle(self, chunks: T_Chunks = None) -> Self: dask.dataframe.shuffle dask.array.shuffle """ + (grouper,) = self.groupers + return self._shuffle_obj(chunks).groupby( + {grouper.name: grouper.grouper.reset()}, + restore_coord_dims=self._restore_coord_dims, + ) + + def _shuffle_obj(self, chunks: T_Chunks) -> T_DataWithCoords: from xarray.core.dataarray import DataArray (grouper,) = self.groupers @@ -592,10 +599,7 @@ def shuffle(self, chunks: T_Chunks = None) -> Self: ) shuffled = self._maybe_unstack(shuffled) new_obj = self._obj._from_temp_dataset(shuffled) if was_array else shuffled - return new_obj.groupby( - {grouper.name: grouper.grouper.reset()}, - restore_coord_dims=self._restore_coord_dims, - ) + return new_obj def map( self, From f489bcf433382cb42f832edadd0c0894d7d4db7f Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 14 Aug 2024 21:10:39 -0600 Subject: [PATCH 20/37] tweak docstrings --- xarray/core/common.py | 8 +++++++- xarray/core/groupby.py | 9 +++++---- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index edf55d02652..3304ebd7061 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -876,7 +876,13 @@ def rolling_exp( def shuffle_by(self, **groupers: Grouper) -> Self: """ - Shuffle this object by a Grouper. + Sort or "shuffle" this object by a Grouper. + + "Shuffle" means the object is sorted so that all group members occur sequentially, + in the same chunk. Multiple groups may occur in the same chunk. + This method is particularly useful for chunked arrays (e.g. dask, cubed). + For chunked array types, the order of appearance is not guaranteed, but will depend on + the input chunking. Parameters ---------- diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 1abf4dd5d40..698cdf7b1ed 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -520,13 +520,14 @@ def sizes(self) -> Mapping[Hashable, int]: def shuffle(self, chunks: T_Chunks = None) -> Self: """ - Sort or "shuffle" the underlying object so that all members in a group occur sequentially. + Sort or "shuffle" the underlying object. + "Shuffle" means the object is sorted so that all group members occur sequentially, + in the same chunk. Multiple groups may occur in the same chunk. This method is particularly useful for chunked arrays (e.g. dask, cubed). 
particularly when you need to map a function that requires all members of a group - to be present in a single chunk. - For chunked array types, the order of appearance is not guaranteed, but will depend on - the input chunking. + to be present in a single chunk. For chunked array types, the order of appearance + is not guaranteed, but will depend on the input chunking. Parameters ---------- From ead1bb4dd889ebcce40de15b8ddf114f576c0e97 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 14 Aug 2024 21:14:07 -0600 Subject: [PATCH 21/37] fix typing --- xarray/namedarray/_typing.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 57b17385558..5d96f0a4b92 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -78,7 +78,8 @@ def dtype(self) -> _DType_co: ... _Chunks = tuple[_Shape, ...] _NormalizedChunks = tuple[tuple[int, ...], ...] # FYI in some cases we don't allow `None`, which this doesn't take account of. -T_ChunkDim: TypeAlias = int | Literal["auto"] | None | tuple[int, ...] +# # FYI the `str` is for a size string, e.g. "16MB", supported by dask. +T_ChunkDim: TypeAlias = str | int | Literal["auto"] | None | tuple[int, ...] # We allow the tuple form of this (though arguably we could transition to named dims only) T_Chunks: TypeAlias = T_ChunkDim | Mapping[Any, T_ChunkDim] From 75115d053473ab0f9115a1b89fd823967f8b0433 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 14 Aug 2024 21:22:15 -0600 Subject: [PATCH 22/37] Fix --- xarray/core/common.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index 3304ebd7061..a6d719ca595 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -874,7 +874,12 @@ def rolling_exp( return rolling_exp.RollingExp(self, window, window_type) - def shuffle_by(self, **groupers: Grouper) -> Self: + def shuffle_by( + self, + group: Hashable | DataArray | Mapping[Any, Grouper] | None = None, + chunks: T_Chunks = None, + **groupers: Grouper, + ) -> Self: """ Sort or "shuffle" this object by a Grouper. @@ -886,6 +891,12 @@ def shuffle_by(self, **groupers: Grouper) -> Self: Parameters ---------- + group : Hashable or DataArray or IndexVariable or mapping of Hashable to Grouper + Array whose unique values should be used to group this array. If a + Hashable, must be the name of a coordinate contained in this dataarray. If a dictionary, + must map an existing variable name to a :py:class:`Grouper` instance. + chunks : int, tuple of int, "auto" or mapping of hashable to int or a TimeResampler, optional + How to adjust chunks along dimensions not present in the array being grouped by. **groupers : Grouper Grouper objects using which to shuffle the data. 
@@ -923,7 +934,7 @@ def shuffle_by(self, **groupers: Grouper) -> Self: dask.dataframe.DataFrame.shuffle dask.array.shuffle """ - return self.groupby(**groupers)._shuffle_obj() + return self.groupby(group=group, **groupers)._shuffle_obj(chunks) def _resample( self, From 390863a0269b7081cbfd714bad878b3a602ffec4 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 14 Aug 2024 22:52:19 -0600 Subject: [PATCH 23/37] fix docstring --- xarray/core/common.py | 2 +- xarray/core/groupby.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index a6d719ca595..f0ed81e851b 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -895,7 +895,7 @@ def shuffle_by( Array whose unique values should be used to group this array. If a Hashable, must be the name of a coordinate contained in this dataarray. If a dictionary, must map an existing variable name to a :py:class:`Grouper` instance. - chunks : int, tuple of int, "auto" or mapping of hashable to int or a TimeResampler, optional + chunks : int, tuple of int, "auto" or mapping of hashable to int or tuple of int, optional How to adjust chunks along dimensions not present in the array being grouped by. **groupers : Grouper Grouper objects using which to shuffle the data. diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 698cdf7b1ed..19ea185628a 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -531,7 +531,7 @@ def shuffle(self, chunks: T_Chunks = None) -> Self: Parameters ---------- - chunks : int, tuple of int, "auto" or mapping of hashable to int or a TimeResampler, optional + chunks : int, tuple of int, "auto" or mapping of hashable to int or tuple of int, optional How to adjust chunks along dimensions not present in the array being grouped by. Returns @@ -557,7 +557,7 @@ def shuffle(self, chunks: T_Chunks = None) -> Self: See Also -------- - dask.dataframe.shuffle + dask.dataframe.DataFrame.shuffle dask.array.shuffle """ (grouper,) = self.groupers From a408cb0b2994a8db2ea083f132718c3bcac5246f Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 16 Aug 2024 20:10:58 -0600 Subject: [PATCH 24/37] bump min version to dask>=2024.08.1 --- xarray/namedarray/daskmanager.py | 11 ++++++----- xarray/tests/__init__.py | 2 +- xarray/tests/test_groupby.py | 6 +++--- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/xarray/namedarray/daskmanager.py b/xarray/namedarray/daskmanager.py index ddc5ab88c6e..e36781c3f47 100644 --- a/xarray/namedarray/daskmanager.py +++ b/xarray/namedarray/daskmanager.py @@ -257,10 +257,11 @@ def shuffle( ) -> DaskArray: import dask.array - if not module_available("dask", minversion="2024.08.0"): + if not module_available("dask", minversion="2024.08.1"): raise ValueError( - "This method is very inefficient on dask<2024.08.0. Please upgrade." + "This method is very inefficient on dask<2024.08.1. Please upgrade." 
) - if chunks is not None: - raise NotImplementedError - return dask.array.shuffle(x, indexer, axis) + chunks = chunks or "auto" + if chunks != "auto": + raise NotImplementedError("Only chunks='auto' is supported at present.") + return dask.array.shuffle(x, indexer, axis, chunks=chunks) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 31d8e88dde1..b950f3520b5 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -106,7 +106,7 @@ def _importorskip( has_h5netcdf, requires_h5netcdf = _importorskip("h5netcdf") has_cftime, requires_cftime = _importorskip("cftime") has_dask, requires_dask = _importorskip("dask") -has_dask_ge_2024_08_0, _ = _importorskip("dask", minversion="2024.08.0") +has_dask_ge_2024_08_1, _ = _importorskip("dask", minversion="2024.08.1") with warnings.catch_warnings(): warnings.filterwarnings( "ignore", diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index e8b718d2deb..6e060ea2358 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -23,7 +23,7 @@ create_test_data, has_cftime, has_dask, - has_dask_ge_2024_08_0, + has_dask_ge_2024_08_1, has_flox, raise_if_dask_computes, requires_cftime, @@ -600,7 +600,7 @@ def test_groupby_repr_datetime(obj) -> None: ], ) def test_groupby_drops_nans(shuffle: bool, chunk: Literal[False] | dict) -> None: - if shuffle and chunk and not has_dask_ge_2024_08_0: + if shuffle and chunk and not has_dask_ge_2024_08_1: pytest.skip() # GH2383 # nan in 2D data variable (requires stacking) @@ -1330,7 +1330,7 @@ def test_groupby_sum(self) -> None: def test_groupby_reductions( self, use_flox: bool, method: str, shuffle: bool, chunk: bool ) -> None: - if shuffle and chunk and not has_dask_ge_2024_08_0: + if shuffle and chunk and not has_dask_ge_2024_08_1: pytest.skip() array = self.da From 05a0fb480283eed2e9fecfa8b2c60a8f28d544be Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 16 Aug 2024 20:21:51 -0600 Subject: [PATCH 25/37] Fix typing --- xarray/namedarray/daskmanager.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/daskmanager.py b/xarray/namedarray/daskmanager.py index e36781c3f47..82ceadf548b 100644 --- a/xarray/namedarray/daskmanager.py +++ b/xarray/namedarray/daskmanager.py @@ -261,7 +261,8 @@ def shuffle( raise ValueError( "This method is very inefficient on dask<2024.08.1. Please upgrade." 
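# (On the change just below: swapping `chunks = chunks or "auto"` for an
# explicit None check is presumably the typing fix the subject refers to,
# and it also avoids coercing a falsy but valid value such as 0 to "auto",
# which `or` would silently do.)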
) - chunks = chunks or "auto" + if chunks is None: + chunks = "auto" if chunks != "auto": raise NotImplementedError("Only chunks='auto' is supported at present.") - return dask.array.shuffle(x, indexer, axis, chunks=chunks) + return dask.array.shuffle(x, indexer, axis, chunks="auto") From b8e7f62deba1f6ef40e51fb8283f9d41fb30486b Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 16 Aug 2024 20:24:06 -0600 Subject: [PATCH 26/37] Fix types --- xarray/core/groupby.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 19ea185628a..2b7d9033399 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -53,7 +53,7 @@ from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset - from xarray.core.types import GroupIndex, GroupIndices, GroupKey, Self, T_Chunks + from xarray.core.types import GroupIndex, GroupIndices, GroupKey, T_Chunks from xarray.core.utils import Frozen from xarray.groupers import Grouper @@ -518,7 +518,7 @@ def sizes(self) -> Mapping[Hashable, int]: self._sizes = self._obj.isel({self._group_dim: index}).sizes return self._sizes - def shuffle(self, chunks: T_Chunks = None) -> Self: + def shuffle(self, chunks: T_Chunks = None) -> DataArrayGroupBy | DatasetGroupBy: """ Sort or "shuffle" the underlying object. @@ -566,7 +566,7 @@ def shuffle(self, chunks: T_Chunks = None) -> Self: restore_coord_dims=self._restore_coord_dims, ) - def _shuffle_obj(self, chunks: T_Chunks) -> T_DataWithCoords: + def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray: from xarray.core.dataarray import DataArray (grouper,) = self.groupers From 7a99c8fe53926f0d47dea503cf4fa305be6fa73c Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 30 Aug 2024 11:35:30 -0600 Subject: [PATCH 27/37] remove shuffle_by for now. --- xarray/core/common.py | 64 +------------------------------------------ 1 file changed, 1 insertion(+), 63 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index 66f094da72a..74c03f9baf5 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -52,7 +52,7 @@ T_Variable, ) from xarray.core.variable import Variable - from xarray.groupers import Grouper, Resampler + from xarray.groupers import Resampler DTypeMaybeMapping = Union[DTypeLikeSave, Mapping[Any, DTypeLikeSave]] @@ -888,68 +888,6 @@ def rolling_exp( return rolling_exp.RollingExp(self, window, window_type) - def shuffle_by( - self, - group: Hashable | DataArray | Mapping[Any, Grouper] | None = None, - chunks: T_Chunks = None, - **groupers: Grouper, - ) -> Self: - """ - Sort or "shuffle" this object by a Grouper. - - "Shuffle" means the object is sorted so that all group members occur sequentially, - in the same chunk. Multiple groups may occur in the same chunk. - This method is particularly useful for chunked arrays (e.g. dask, cubed). - For chunked array types, the order of appearance is not guaranteed, but will depend on - the input chunking. - - Parameters - ---------- - group : Hashable or DataArray or IndexVariable or mapping of Hashable to Grouper - Array whose unique values should be used to group this array. If a - Hashable, must be the name of a coordinate contained in this dataarray. If a dictionary, - must map an existing variable name to a :py:class:`Grouper` instance. - chunks : int, tuple of int, "auto" or mapping of hashable to int or tuple of int, optional - How to adjust chunks along dimensions not present in the array being grouped by. 
-        **groupers : Grouper
-            Grouper objects using which to shuffle the data.
-
-        Examples
-        --------
-        >>> import dask
-        >>> from xarray.groupers import UniqueGrouper
-        >>> da = xr.DataArray(
-        ...     dims="x",
-        ...     data=dask.array.arange(10, chunks=1),
-        ...     coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
-        ...     name="a",
-        ... )
-        >>> da
-        <xarray.DataArray 'a' (x: 10)> Size: 80B
-        dask.array<arange, shape=(10,), dtype=int64, chunksize=(1,), chunktype=numpy.ndarray>
-        Coordinates:
-          * x        (x) int64 80B 1 2 3 1 2 3 1 2 3 0
-
-        >>> da.shuffle_by(x=UniqueGrouper())
-        <xarray.DataArray 'a' (x: 10)> Size: 80B
-        dask.array<shuffle, shape=(10,), dtype=int64, chunksize=(3,), chunktype=numpy.ndarray>
-        Coordinates:
-          * x        (x) int64 80B 0 1 1 1 2 2 2 3 3 3
-
-        Returns
-        -------
-        DataArray or Dataset
-            The same type as this object
-
-        See Also
-        --------
-        DataArrayGroupBy.shuffle
-        DatasetGroupBy.shuffle
-        dask.dataframe.DataFrame.shuffle
-        dask.array.shuffle
-        """
-        return self.groupby(group=group, **groupers)._shuffle_obj(chunks)
-
     def _resample(
         self,
         resample_cls: type[T_Resample],

From 5e2fdfb77802e9ebcf9802b2397094f5bbac11a7 Mon Sep 17 00:00:00 2001
From: Deepak Cherian
Date: Fri, 30 Aug 2024 13:14:12 -0600
Subject: [PATCH 28/37] Add tests

---
 xarray/core/groupby.py       | 13 ++++----
 xarray/core/resample.py      | 58 +++++++++++++++++++++++++++++++++++-
 xarray/tests/test_groupby.py | 47 +++++++++++++++++++++++------
 3 files changed, 103 insertions(+), 15 deletions(-)

diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py
index 3c387fde072..cf150b1966a 100644
--- a/xarray/core/groupby.py
+++ b/xarray/core/groupby.py
@@ -566,7 +566,7 @@ def sizes(self) -> Mapping[Hashable, int]:
             self._sizes = self._obj.isel({self._group_dim: index}).sizes
         return self._sizes

-    def shuffle(self, chunks: T_Chunks = None) -> DataArrayGroupBy | DatasetGroupBy:
+    def shuffle(self, chunks: T_Chunks = None):
         """
         Sort or "shuffle" the underlying object.

@@ -610,7 +610,10 @@ def shuffle(self, chunks: T_Chunks = None):
         """
         (grouper,) = self.groupers
         return self._shuffle_obj(chunks).groupby(
-            {grouper.name: grouper.grouper.reset()},
+            # Using group.name handles the BinGrouper case
+            # It does *not* handle the TimeResampler case,
+            # so we just override this method in Resample
+            {grouper.group.name: grouper.grouper.reset()},
             restore_coord_dims=self._restore_coord_dims,
         )

@@ -624,11 +627,11 @@ def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray:
         as_dataset = self._obj._to_temp_dataset() if was_array else self._obj
         no_slices: list[list[int]] = [
             list(range(*idx.indices(size))) if isinstance(idx, slice) else idx
-            for idx in self._group_indices
+            for idx in self.encoded.group_indices
         ]

         if grouper.name not in as_dataset._variables:
-            as_dataset.coords[grouper.name] = grouper.group1d
+            as_dataset.coords[grouper.name] = grouper.group

         # Shuffling is only different from `isel` for chunked arrays.
         # Extract them out, and treat them specially. The rest, we route through isel.
@@ -644,7 +647,7 @@ def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray:
         shuffled = subset.isel({dim: np.concatenate(no_slices)})
         for name, var in is_chunked.items():
             shuffled[name] = var._shuffle(
-                indices=list(self._group_indices), dim=dim, chunks=chunks
+                indices=list(self.encoded.group_indices), dim=dim, chunks=chunks
             )
         shuffled = self._maybe_unstack(shuffled)
         new_obj = self._obj._from_temp_dataset(shuffled) if was_array else shuffled
diff --git a/xarray/core/resample.py b/xarray/core/resample.py
index 677de48f0b6..8e0c258debb 100644
--- a/xarray/core/resample.py
+++ b/xarray/core/resample.py
@@ -2,7 +2,7 @@

 import warnings
 from collections.abc import Callable, Hashable, Iterable, Sequence
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, cast

 from xarray.core._aggregations import (
     DataArrayResampleAggregations,
@@ -14,6 +14,8 @@
 if TYPE_CHECKING:
     from xarray.core.dataarray import DataArray
     from xarray.core.dataset import Dataset
+    from xarray.core.types import T_Chunks
+    from xarray.groupers import Resampler

 from xarray.groupers import RESAMPLE_DIM

@@ -58,6 +60,60 @@ def _flox_reduce(
             result = result.rename({RESAMPLE_DIM: self._group_dim})
         return result

+    def shuffle(self, chunks: T_Chunks = None):
+        """
+        Sort or "shuffle" the underlying object.
+
+        "Shuffle" means the object is sorted so that all group members occur sequentially,
+        in the same chunk. Multiple groups may occur in the same chunk.
+        This method is particularly useful for chunked arrays (e.g. dask, cubed),
+        especially when you need to map a function that requires all members of a group
+        to be present in a single chunk. For chunked array types, the order of appearance
+        is not guaranteed, but will depend on the input chunking.
+
+        .. warning::
+
+            With resampling it is usually better to use ``.chunk`` instead of ``.shuffle``,
+            since one can only resample a sorted time coordinate.
+
+        Parameters
+        ----------
+        chunks : int, tuple of int, "auto" or mapping of hashable to int or tuple of int, optional
+            How to adjust chunks along dimensions not present in the array being grouped by.
+
+        Returns
+        -------
+        DataArrayGroupBy or DatasetGroupBy
+
+        Examples
+        --------
+        >>> import dask
+        >>> da = xr.DataArray(
+        ...     dims="x",
+        ...     data=dask.array.arange(10, chunks=3),
+        ...     coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
+        ...     name="a",
+        ... )
+        >>> shuffled = da.groupby("x").shuffle()
+        >>> shuffled.quantile(q=0.5).compute()
+        <xarray.DataArray 'a' (x: 4)> Size: 32B
+        array([9., 3., 4., 5.])
+        Coordinates:
+            quantile  float64 8B 0.5
+          * x         (x) int64 32B 0 1 2 3
+
+        See Also
+        --------
+        dask.dataframe.DataFrame.shuffle
+        dask.array.shuffle
+        """
+        (grouper,) = self.groupers
+        shuffled = self._shuffle_obj(chunks).drop_vars(RESAMPLE_DIM)
+        return shuffled.resample(
+            {self._group_dim: cast("Resampler", grouper.grouper.reset())},
+            restore_coord_dims=self._restore_coord_dims,
+        )
+
     def _drop_coords(self) -> T_Xarray:
         """Drop non-dimension coordinates along the resampled dimension."""
         obj = self._obj
diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py
index 41947d6626a..a83d840caaf 100644
--- a/xarray/tests/test_groupby.py
+++ b/xarray/tests/test_groupby.py
@@ -1659,13 +1659,14 @@ def test_groupby_bins(
         )

         with xr.set_options(use_flox=use_flox):
-            actual = array.groupby_bins("dim_0", bins=bins, **cut_kwargs).sum()
+            gb = array.groupby_bins("dim_0", bins=bins, **cut_kwargs)
+            actual = gb.sum()
             assert_identical(expected, actual)
+            assert_identical(expected, gb.shuffle().sum())

-            actual = array.groupby_bins("dim_0", bins=bins, **cut_kwargs).map(
-                lambda x: x.sum()
-            )
+            actual = gb.map(lambda x: x.sum())
             assert_identical(expected, actual)
+            assert_identical(expected, gb.shuffle().map(lambda x: x.sum()))

             # make sure original array dims are unchanged
             assert len(array.dim_0) == 4
@@ -1810,8 +1811,9 @@ def test_groupby_fastpath_for_monotonic(self, use_flox: bool) -> None:


 class TestDataArrayResample:
+    @pytest.mark.parametrize("shuffle", [True, False])
     @pytest.mark.parametrize("use_cftime", [True, False])
-    def test_resample(self, use_cftime: bool) -> None:
+    def test_resample(self, use_cftime: bool, shuffle: bool) -> None:
         if use_cftime and not has_cftime:
             pytest.skip()
         times = xr.date_range(
@@ -1833,16 +1835,22 @@ def resample_as_pandas(array, *args, **kwargs):

         array = DataArray(np.arange(10), [("time", times)])

-        actual = array.resample(time="24h").mean()
+        rs = array.resample(time="24h")
+
+        actual = rs.mean()
         expected = resample_as_pandas(array, "24h")
         assert_identical(expected, actual)
+        assert_identical(expected, rs.shuffle().mean())

-        actual = array.resample(time="24h").reduce(np.mean)
-        assert_identical(expected, actual)
+        assert_identical(expected, rs.reduce(np.mean))
+        assert_identical(expected, rs.shuffle().reduce(np.mean))

-        actual = array.resample(time="24h", closed="right").mean()
+        rs = array.resample(time="24h", closed="right")
+        actual = rs.mean()
+        shuffled = rs.shuffle().mean()
         expected = resample_as_pandas(array, "24h", closed="right")
         assert_identical(expected, actual)
+        assert_identical(expected, shuffled)

         with pytest.raises(ValueError, match=r"Index must be monotonic"):
             array[[2, 0, 1]].resample(time="1D")
@@ -2795,6 +2803,27 @@ def test_multiple_groupers_mixed(use_flox) -> None:
     # ------


+@requires_dask
+def test_groupby_shuffle():
+    import dask
+
+    da = DataArray(
+        dask.array.from_array(np.array([1, 2, 3, 0, 2, np.nan]), chunks=2),
+        dims="d",
+        coords=dict(
+            labels1=("d", np.array(["a", "b", "c", "c", "b", "a"])),
+            labels2=("d", np.array(["x", "y", "z", "z", "y", "x"])),
+        ),
+        name="foo",
+    )
+
+    gb = da.groupby("labels1")
+    shuffled = gb.shuffle()
+    shuffled_obj = shuffled._obj
+    with xr.set_options(use_flox=False):
+        xr.testing.assert_identical(gb.mean(), shuffled.mean())
+
+
 # Possible property tests
 # 1. lambda x: x
 # 2. grouped-reduce on unique coords is identical to array
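Note: the "possible property tests" listed in the comment above are easy to sketch against the API added in this patch. A minimal, illustrative version of the second property (not part of the patch series; the data and names are hypothetical, and it assumes the patched xarray with dask installed):

    import dask.array
    import numpy as np
    import xarray as xr

    def test_shuffle_preserves_grouped_reductions():
        # Property: shuffling must not change any grouped reduction,
        # regardless of how the input is chunked.
        da = xr.DataArray(
            dask.array.from_array(np.arange(10.0), chunks=3),
            dims="x",
            coords={"label": ("x", list("abcabcabca"))},
            name="foo",
        )
        gb = da.groupby("label")
        # .shuffle() returns a new GroupBy over the reordered object, so both
        # sides reduce over exactly the same group members.
        xr.testing.assert_identical(gb.mean(), gb.shuffle().mean())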
From a22c7ed0166a4ccb2cfffc7181f08c159de69826 Mon Sep 17 00:00:00 2001
From: Deepak Cherian
Date: Fri, 30 Aug 2024 13:42:20 -0600
Subject: [PATCH 29/37] Support shuffling with multiple groupers

---
 xarray/core/groupby.py       | 25 ++++++++++++++--------
 xarray/tests/test_groupby.py | 37 ++++++++++++++----------------------
 2 files changed, 31 insertions(+), 31 deletions(-)

diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py
index cf150b1966a..6d80f351795 100644
--- a/xarray/core/groupby.py
+++ b/xarray/core/groupby.py
@@ -608,19 +608,21 @@ def shuffle(self, chunks: T_Chunks = None):
         dask.dataframe.DataFrame.shuffle
         dask.array.shuffle
         """
-        (grouper,) = self.groupers
-        return self._shuffle_obj(chunks).groupby(
+        new_groupers = {
             # Using group.name handles the BinGrouper case
             # It does *not* handle the TimeResampler case,
             # so we just override this method in Resample
-            {grouper.group.name: grouper.grouper.reset()},
+            grouper.group.name: grouper.grouper.reset()
+            for grouper in self.groupers
+        }
+        return self._shuffle_obj(chunks).groupby(
+            new_groupers,
             restore_coord_dims=self._restore_coord_dims,
         )

     def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray:
         from xarray.core.dataarray import DataArray

-        (grouper,) = self.groupers
         dim = self._group_dim
         size = self._obj.sizes[dim]
         was_array = isinstance(self._obj, DataArray)
@@ -629,9 +631,11 @@ def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray:
             list(range(*idx.indices(size))) if isinstance(idx, slice) else idx
             for idx in self.encoded.group_indices
         ]
+        no_slices = [idx for idx in no_slices if idx]

-        if grouper.name not in as_dataset._variables:
-            as_dataset.coords[grouper.name] = grouper.group
+        for grouper in self.groupers:
+            if grouper.name not in as_dataset._variables:
+                as_dataset.coords[grouper.name] = grouper.group

         # Shuffling is only different from `isel` for chunked arrays.
         # Extract them out, and treat them specially. The rest, we route through isel.
@@ -644,10 +648,13 @@ def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray:
         subset = as_dataset[
             [name for name in as_dataset._variables if name not in is_chunked]
         ]
+
         shuffled = subset.isel({dim: np.concatenate(no_slices)})
         for name, var in is_chunked.items():
             shuffled[name] = var._shuffle(
-                indices=list(self.encoded.group_indices), dim=dim, chunks=chunks
+                indices=list(idx for idx in self.encoded.group_indices if idx),
+                dim=dim,
+                chunks=chunks,
             )
         shuffled = self._maybe_unstack(shuffled)
         new_obj = self._obj._from_temp_dataset(shuffled) if was_array else shuffled
@@ -861,7 +868,9 @@ def _maybe_unstack(self, obj):
             # and `inserted_dims`
             # if multiple groupers all share the same single dimension, then
             # we don't stack/unstack. Do that manually now.
-            obj = obj.unstack(*self.encoded.unique_coord.dims)
+            dims_to_unstack = self.encoded.unique_coord.dims
+            if all(dim in obj.dims for dim in dims_to_unstack):
+                obj = obj.unstack(*dims_to_unstack)
             to_drop = [
                 grouper.name
                 for grouper in self.groupers
diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py
index a83d840caaf..11fe8a19e8f 100644
--- a/xarray/tests/test_groupby.py
+++ b/xarray/tests/test_groupby.py
@@ -2684,8 +2684,9 @@ def test_weather_data_resample(use_flox):
     assert expected.location.attrs == ds.location.attrs


+@pytest.mark.parametrize("shuffle", [True, False])
 @pytest.mark.parametrize("use_flox", [True, False])
-def test_multiple_groupers(use_flox) -> None:
+def test_multiple_groupers(use_flox: bool, shuffle: bool) -> None:
     da = DataArray(
         np.array([1, 2, 3, 0, 2, np.nan]),
         dims="d",
@@ -2697,6 +2698,8 @@ def test_multiple_groupers(use_flox) -> None:
     )

     gb = da.groupby(labels1=UniqueGrouper(), labels2=UniqueGrouper())
+    if shuffle:
+        gb = gb.shuffle()
     repr(gb)

     expected = DataArray(
@@ -2716,6 +2719,8 @@ def test_multiple_groupers(use_flox) -> None:
     coords = {"a": ("x", [0, 0, 1, 1]), "b": ("y", [0, 0, 1, 1])}
     square = DataArray(np.arange(16).reshape(4, 4), coords=coords, dims=["x", "y"])
     gb = square.groupby(a=UniqueGrouper(), b=UniqueGrouper())
+    if shuffle:
+        gb = gb.shuffle()
     repr(gb)
     with xr.set_options(use_flox=use_flox):
         actual = gb.mean()
@@ -2739,11 +2744,15 @@ def test_multiple_groupers(use_flox) -> None:
         dims=["x", "y", "z"],
     )
     gb = b.groupby(x=UniqueGrouper(), y=UniqueGrouper())
+    if shuffle:
+        gb = gb.shuffle()
     repr(gb)
     with xr.set_options(use_flox=use_flox):
         assert_identical(gb.mean("z"), b.mean("z"))

     gb = b.groupby(x=UniqueGrouper(), xy=UniqueGrouper())
+    if shuffle:
+        gb = gb.shuffle()
     repr(gb)
     with xr.set_options(use_flox=use_flox):
         actual = gb.mean()
@@ -2758,13 +2767,16 @@ def test_multiple_groupers(use_flox) -> None:


 @pytest.mark.parametrize("use_flox", [True, False])
-def test_multiple_groupers_mixed(use_flox) -> None:
+@pytest.mark.parametrize("shuffle", [True, False])
+def test_multiple_groupers_mixed(use_flox: bool, shuffle: bool) -> None:
     # This groupby has missing groups
     ds = xr.Dataset(
         {"foo": (("x", "y"), np.arange(12).reshape((4, 3)))},
         coords={"x": [10, 20, 30, 40], "letters": ("x", list("abba"))},
     )
     gb = ds.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper())
+    if shuffle:
+        gb = gb.shuffle()
     expected_data = np.array(
         [
             [[0.0, np.nan], [np.nan, 3.0]],
@@ -2803,27 +2815,6 @@ def test_multiple_groupers_mixed(use_flox) -> None:
     # ------


-@requires_dask
-def test_groupby_shuffle():
-    import dask
-
-    da = DataArray(
-        dask.array.from_array(np.array([1, 2, 3, 0, 2, np.nan]), chunks=2),
-        dims="d",
-        coords=dict(
-            labels1=("d", np.array(["a", "b", "c", "c", "b", "a"])),
-            labels2=("d", np.array(["x", "y", "z", "z", "y", "x"])),
-        ),
-        name="foo",
-    )
-
-    gb = da.groupby("labels1")
-    shuffled = gb.shuffle()
-    shuffled_obj = shuffled._obj
-    with xr.set_options(use_flox=False):
-        xr.testing.assert_identical(gb.mean(), shuffled.mean())
-
-
 # Possible property tests
 # 1. lambda x: x
 # 2. grouped-reduce on unique coords is identical to array

From 2d48690603c68abe1c82122f7f2030a62b5683b9 Mon Sep 17 00:00:00 2001
From: Deepak Cherian
Date: Wed, 11 Sep 2024 15:35:39 -0500
Subject: [PATCH 30/37] Revert "remove shuffle_by for now."

This reverts commit 7a99c8fe53926f0d47dea503cf4fa305be6fa73c.
---
 xarray/core/common.py | 64 ++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 63 insertions(+), 1 deletion(-)

diff --git a/xarray/core/common.py b/xarray/core/common.py
index 74c03f9baf5..66f094da72a 100644
--- a/xarray/core/common.py
+++ b/xarray/core/common.py
@@ -52,7 +52,7 @@
         T_Variable,
     )
     from xarray.core.variable import Variable
-    from xarray.groupers import Resampler
+    from xarray.groupers import Grouper, Resampler

     DTypeMaybeMapping = Union[DTypeLikeSave, Mapping[Any, DTypeLikeSave]]

@@ -888,6 +888,68 @@ def rolling_exp(

         return rolling_exp.RollingExp(self, window, window_type)

+    def shuffle_by(
+        self,
+        group: Hashable | DataArray | Mapping[Any, Grouper] | None = None,
+        chunks: T_Chunks = None,
+        **groupers: Grouper,
+    ) -> Self:
+        """
+        Sort or "shuffle" this object by a Grouper.
+
+        "Shuffle" means the object is sorted so that all group members occur sequentially,
+        in the same chunk. Multiple groups may occur in the same chunk.
+        This method is particularly useful for chunked arrays (e.g. dask, cubed).
+        For chunked array types, the order of appearance is not guaranteed, but will depend on
+        the input chunking.
+
+        Parameters
+        ----------
+        group : Hashable or DataArray or IndexVariable or mapping of Hashable to Grouper
+            Array whose unique values should be used to group this array. If a
+            Hashable, must be the name of a coordinate contained in this dataarray. If a dictionary,
+            must map an existing variable name to a :py:class:`Grouper` instance.
+        chunks : int, tuple of int, "auto" or mapping of hashable to int or tuple of int, optional
+            How to adjust chunks along dimensions not present in the array being grouped by.
+        **groupers : Grouper
+            Grouper objects using which to shuffle the data.
+
+        Examples
+        --------
+        >>> import dask
+        >>> from xarray.groupers import UniqueGrouper
+        >>> da = xr.DataArray(
+        ...     dims="x",
+        ...     data=dask.array.arange(10, chunks=1),
+        ...     coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
+        ...     name="a",
+        ... )
+        >>> da
+        <xarray.DataArray 'a' (x: 10)> Size: 80B
+        dask.array<arange, shape=(10,), dtype=int64, chunksize=(1,), chunktype=numpy.ndarray>
+        Coordinates:
+          * x        (x) int64 80B 1 2 3 1 2 3 1 2 3 0
+
+        >>> da.shuffle_by(x=UniqueGrouper())
+        <xarray.DataArray 'a' (x: 10)> Size: 80B
+        dask.array<shuffle, shape=(10,), dtype=int64, chunksize=(3,), chunktype=numpy.ndarray>
+        Coordinates:
+          * x        (x) int64 80B 0 1 1 1 2 2 2 3 3 3
+
+        Returns
+        -------
+        DataArray or Dataset
+            The same type as this object
+
+        See Also
+        --------
+        DataArrayGroupBy.shuffle
+        DatasetGroupBy.shuffle
+        dask.dataframe.DataFrame.shuffle
+        dask.array.shuffle
+        """
+        return self.groupby(group=group, **groupers)._shuffle_obj(chunks)
+
     def _resample(
         self,
         resample_cls: type[T_Resample],

From 7dc5dd188ed38246ebb55b51dd42bdeb72a3cce9 Mon Sep 17 00:00:00 2001
From: Deepak Cherian
Date: Mon, 16 Sep 2024 22:03:14 -0600
Subject: [PATCH 31/37] bad merge

---
 xarray/tests/test_groupby.py | 30 ++++++++++++++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py
index 4a8d03233e6..c765c718ee8 100644
--- a/xarray/tests/test_groupby.py
+++ b/xarray/tests/test_groupby.py
@@ -2759,6 +2759,36 @@ def test_weather_data_resample(use_flox):
     assert expected.location.attrs == ds.location.attrs


+@pytest.mark.parametrize("as_dataset", [True, False])
+def test_multiple_groupers_string(as_dataset) -> None:
+    obj = DataArray(
+        np.array([1, 2, 3, 0, 2, np.nan]),
+        dims="d",
+        coords=dict(
+            labels1=("d", np.array(["a", "b", "c", "c", "b", "a"])),
+            labels2=("d", np.array(["x", "y", "z", "z", "y", "x"])),
+        ),
+        name="foo",
+    )
+
+    if as_dataset:
+        obj = obj.to_dataset()  # type: ignore[assignment]
+
+    expected = obj.groupby(labels1=UniqueGrouper(), labels2=UniqueGrouper()).mean()
+    actual = obj.groupby(("labels1", "labels2")).mean()
+    assert_identical(expected, actual)
+
+    # Passes `"labels2"` to squeeze; will raise an error around kwargs rather than the
+    # warning & type error in the future
+    with pytest.warns(FutureWarning):
+        with pytest.raises(TypeError):
+            obj.groupby("labels1", "labels2")  # type: ignore[arg-type, misc]
+    with pytest.raises(ValueError):
+        obj.groupby("labels1", foo="bar")  # type: ignore[arg-type]
+    with pytest.raises(ValueError):
+        obj.groupby("labels1", foo=UniqueGrouper())
+
+

From 91e4bd8057fe727a6a1574ecad93a23b5d86859a Mon Sep 17 00:00:00 2001
From: Deepak Cherian
Date: Tue, 17 Sep 2024 22:25:48 -0600
Subject: [PATCH 32/37] Add a test

---
 xarray/tests/test_dask.py | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py
index 062f0525593..3c7f0321acc 100644
--- a/xarray/tests/test_dask.py
+++ b/xarray/tests/test_dask.py
@@ -1803,3 +1803,27 @@ def test_minimize_graph_size():
         # all the other dimensions.
         # e.g. previously for 'x', actual == numchunks['y'] * numchunks['z']
         assert actual == numchunks[var], (actual, numchunks[var])
+
+
+@pytest.mark.parametrize(
+    "chunks, expected_chunks",
+    [
+        ((1,), (1, 3, 3, 3)),
+        ((10,), (10,)),
+    ],
+)
+def test_shuffle_by(chunks, expected_chunks):
+    from xarray.groupers import UniqueGrouper
+
+    da = xr.DataArray(
+        dims="x",
+        data=dask.array.arange(10, chunks=chunks),
+        coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
+        name="a",
+    )
+    ds = da.to_dataset()
+
+    for obj in [ds, da]:
+        actual = obj.shuffle_by(x=UniqueGrouper())
+        assert_identical(actual, obj.sortby("x"))
+        assert actual.chunksizes["x"] == expected_chunks

From 1e4f805ead5543ea1be0ae0fb0254ae4af08af0d Mon Sep 17 00:00:00 2001
From: Deepak Cherian
Date: Fri, 1 Nov 2024 14:20:56 -0700
Subject: [PATCH 33/37] Add docs

---
 doc/user-guide/groupby.rst | 38 ++++++++++++++++++++++++++++++++++++++
 xarray/core/variable.py    |  1 +
 2 files changed, 39 insertions(+)

diff --git a/doc/user-guide/groupby.rst b/doc/user-guide/groupby.rst
index 98bd7b4833b..defe05e1b26 100644
--- a/doc/user-guide/groupby.rst
+++ b/doc/user-guide/groupby.rst
@@ -321,3 +321,41 @@ Different groupers can be combined to construct sophisticated GroupBy operations
     from xarray.groupers import BinGrouper

     ds.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper()).sum()
+
+
+Shuffling
+~~~~~~~~~
+
+Shuffling is a generalization of sorting a DataArray or Dataset by another DataArray, named ``label`` for example; it follows from the idea of grouping by ``label``.
+Shuffling reorders the DataArray (or the DataArrays in a Dataset) so that all members of a group occur sequentially. For example,
+
+.. ipython:: python
+
+    da = xr.DataArray(
+        dims="x",
+        data=[1, 2, 3, 4, 5, 6],
+        coords={"label": ("x", "a b c a b c".split(" "))},
+    )
+    da.shuffle_by("label")
+
+
+:py:meth:`Dataset.shuffle_by` and :py:meth:`DataArray.shuffle_by` can also take Grouper objects:
+
+.. ipython:: python
+
+    from xarray.groupers import UniqueGrouper
+
+    da.shuffle_by(label=UniqueGrouper())
+
+
+Shuffling can also be performed on :py:class:`DatasetGroupBy` and :py:class:`DataArrayGroupBy` objects.
+The :py:meth:`DatasetGroupBy.shuffle` and :py:meth:`DataArrayGroupBy.shuffle` methods return new :py:class:`DatasetGroupBy` and :py:class:`DataArrayGroupBy` objects that operate on the shuffled Dataset or DataArray respectively.
+
+
+.. ipython:: python
+
+    da.groupby(label=UniqueGrouper()).shuffle()
+
+
+For chunked array types (e.g. dask or cubed), shuffling may produce a more efficient communication pattern than directly indexing with the equivalent integer indexer.
+Shuffling also makes GroupBy operations on chunked arrays an embarrassingly parallel problem, and may significantly improve workloads that use :py:meth:`DatasetGroupBy.map` or :py:meth:`DataArrayGroupBy.map`.
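To make the chunk-alignment claim in these docs concrete: after a shuffle, every group occupies whole chunks, which is what makes per-group ``map`` calls embarrassingly parallel. A sketch against the patched API follows; the exact chunk sizes depend on dask's shuffle heuristics, though ``(1, 3, 3, 3)`` matches the expectation encoded in ``test_shuffle_by`` above:

    import dask.array
    import xarray as xr
    from xarray.groupers import UniqueGrouper

    da = xr.DataArray(
        dims="x",
        data=dask.array.arange(10, chunks=1),
        coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
        name="a",
    )
    shuffled = da.shuffle_by(x=UniqueGrouper())
    # All members of each group now sit in a single chunk,
    # so a subsequent per-group map needs no inter-chunk communication.
    print(shuffled.chunksizes)  # e.g. Frozen({'x': (1, 3, 3, 3)})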
diff --git a/xarray/core/variable.py b/xarray/core/variable.py
index 29571f935ba..402520c8b4b 100644
--- a/xarray/core/variable.py
+++ b/xarray/core/variable.py
@@ -1022,6 +1022,7 @@ def compute(self, **kwargs):
     def _shuffle(
         self, indices: list[list[int]], dim: Hashable, chunks: T_Chunks
     ) -> Self:
+        # TODO (dcherian): consider making this public API
         array = self._data
         if is_chunked_array(array):
             chunkmanager = get_chunked_array_type(array)

From ad502aa1ac5664dafb6fe8d2a99fbdf6162e3791 Mon Sep 17 00:00:00 2001
From: Deepak Cherian
Date: Fri, 1 Nov 2024 14:32:42 -0700
Subject: [PATCH 34/37] bugfix

---
 xarray/core/groupby.py       |  6 +++++-
 xarray/tests/test_groupby.py | 12 ++++++++++++
 2 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py
index 237c6f9b142..2ba861f6bc1 100644
--- a/xarray/core/groupby.py
+++ b/xarray/core/groupby.py
@@ -704,7 +704,11 @@ def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray:
             [name for name in as_dataset._variables if name not in is_chunked]
         ]

-        shuffled = subset.isel({dim: np.concatenate(no_slices)})
+        shuffled = (
+            subset
+            if dim not in subset.dims
+            else subset.isel({dim: np.concatenate(no_slices)})
+        )
         for name, var in is_chunked.items():
             shuffled[name] = var._shuffle(
                 indices=list(idx for idx in self.encoded.group_indices if idx),
diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py
index 0d6043f4e44..ad5b5d41ff7 100644
--- a/xarray/tests/test_groupby.py
+++ b/xarray/tests/test_groupby.py
@@ -3046,3 +3046,15 @@ def test_groupby_multiple_bin_grouper_missing_groups():
         },
     )
     assert_identical(actual, expected)
+
+
+@requires_dask
+def test_shuffle_by_simple() -> None:
+    da = xr.DataArray(
+        dims="x",
+        data=[1, 2, 3, 4, 5, 6],
+        coords={"label": ("x", "a b c a b c".split(" "))},
+    )
+    actual = da.chunk(x=2).shuffle_by(label=UniqueGrouper())
+    expected = da.shuffle_by(label=UniqueGrouper())
+    assert_identical(actual, expected)

From 4b0c1433760f35277b7964fb92627aaf147496dc Mon Sep 17 00:00:00 2001
From: Deepak Cherian
Date: Sat, 2 Nov 2024 12:55:31 -0600
Subject: [PATCH 35/37] Refactor out Dataset._shuffle

---
 xarray/core/dataset.py | 25 +++++++++++++++++++++++++
 xarray/core/groupby.py | 29 ++++------------------------
 2 files changed, 29 insertions(+), 25 deletions(-)

diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index bc9360a809d..13133546b7a 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -166,6 +166,7 @@
         ResampleCompatible,
         SideOptions,
         T_ChunkDimFreq,
+        T_Chunks,
         T_DatasetPadConstantValues,
         T_Xarray,
     )
@@ -3236,6 +3237,30 @@ def sel(
         result = self.isel(indexers=query_results.dim_indexers, drop=drop)
         return result._overwrite_indexes(*query_results.as_tuple()[1:])

+    def _shuffle(self, dim, *, indices: list[list[int]], chunks: T_Chunks) -> Self:
+        # Shuffling is only different from `isel` for chunked arrays.
+        # Extract them out, and treat them specially. The rest, we route through isel.
+        # This makes it easy to ensure correct handling of indexes.
+        is_chunked = {
+            name: var
+            for name, var in self._variables.items()
+            if is_chunked_array(var._data)
+        }
+        subset = self[[name for name in self._variables if name not in is_chunked]]
+
+        shuffled = (
+            subset
+            if dim not in subset.dims
+            else subset.isel({dim: np.concatenate(indices)})
+        )
+        for name, var in is_chunked.items():
+            shuffled[name] = var._shuffle(
+                indices=indices,
+                dim=dim,
+                chunks=chunks,
+            )
+        return shuffled
+
     def head(
         self,
         indexers: Mapping[Any, int] | int | None = None,
diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py
index 2ba861f6bc1..b13b5dca56f 100644
--- a/xarray/core/groupby.py
+++ b/xarray/core/groupby.py
@@ -47,7 +47,6 @@
     peek_at,
 )
 from xarray.core.variable import IndexVariable, Variable
-from xarray.namedarray.pycompat import is_chunked_array
 from xarray.util.deprecation_helpers import _deprecate_positional_args

 if TYPE_CHECKING:
@@ -678,10 +677,10 @@ def shuffle(self, chunks: T_Chunks = None):
     def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray:
         from xarray.core.dataarray import DataArray

-        dim = self._group_dim
-        size = self._obj.sizes[dim]
         was_array = isinstance(self._obj, DataArray)
         as_dataset = self._obj._to_temp_dataset() if was_array else self._obj
+
+        size = self._obj.sizes[self._group_dim]
         no_slices: list[list[int]] = [
             list(range(*idx.indices(size))) if isinstance(idx, slice) else idx
             for idx in self.encoded.group_indices
@@ -692,29 +691,9 @@ def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray:
             if grouper.name not in as_dataset._variables:
                 as_dataset.coords[grouper.name] = grouper.group

-        # Shuffling is only different from `isel` for chunked arrays.
-        # Extract them out, and treat them specially. The rest, we route through isel.
-        is_chunked = {
-            name: var
-            for name, var in as_dataset._variables.items()
-            if is_chunked_array(var._data)
-        }
-        subset = as_dataset[
-            [name for name in as_dataset._variables if name not in is_chunked]
-        ]
-
-        shuffled = (
-            subset
-            if dim not in subset.dims
-            else subset.isel({dim: np.concatenate(no_slices)})
+        shuffled = as_dataset._shuffle(
+            dim=self._group_dim, indices=no_slices, chunks=chunks
         )
-        for name, var in is_chunked.items():
-            shuffled[name] = var._shuffle(
-                indices=list(idx for idx in self.encoded.group_indices if idx),
-                dim=dim,
-                chunks=chunks,
-            )
         shuffled = self._maybe_unstack(shuffled)
         new_obj = self._obj._from_temp_dataset(shuffled) if was_array else shuffled
         return new_obj

From f624c8fc130f2e6e9990ea6da6bf6f2245417cf6 Mon Sep 17 00:00:00 2001
From: Deepak Cherian
Date: Sat, 2 Nov 2024 22:42:55 -0600
Subject: [PATCH 36/37] fix types

---
 xarray/tests/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py
index a6d033c3307..7b8795cc09e 100644
--- a/xarray/tests/__init__.py
+++ b/xarray/tests/__init__.py
@@ -152,7 +152,7 @@ def _importorskip(
     not has_numbagg_or_bottleneck, reason="requires numbagg or bottleneck"
 )
 has_numpy_2, requires_numpy_2 = _importorskip("numpy", "2.0.0")
-_, requires_flox_0_9_12 = _importorskip("flox", "0.9.12")
+has_flox_0_9_12, requires_flox_0_9_12 = _importorskip("flox", "0.9.12")

 has_array_api_strict, requires_array_api_strict = _importorskip("array_api_strict")

From fa6311a652100bb2840933185e10e88463db29f3 Mon Sep 17 00:00:00 2001
From: Deepak Cherian
Date: Sat, 2 Nov 2024 13:06:28 -0600
Subject: [PATCH 37/37] Add GroupBy.map(..., shuffle=True)

---
 xarray/core/dataarray.py     |   7 +++
 xarray/core/groupby.py       | 104 +++++++++++++++++++++++++++++++++--
 xarray/core/parallel.py      |   2 +-
 xarray/tests/test_groupby.py |  28 ++++++++--
 4 files changed, 130 insertions(+), 11 deletions(-)

diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index c6bc082f5ed..8e7c87fed72 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -117,6 +117,7 @@
         Self,
         SideOptions,
         T_ChunkDimFreq,
+        T_Chunks,
         T_ChunksFreq,
         T_Xarray,
     )
@@ -661,6 +662,12 @@ def _to_dataset_whole(
         coord_names = set(self._coords)
         return Dataset._construct_direct(variables, coord_names, indexes=indexes)

+    def _shuffle(self, dim, *, indices: list[list[int]], chunks: T_Chunks) -> Self:
+        shuffled = self._to_temp_dataset()._shuffle(
+            dim=dim, indices=indices, chunks=chunks
+        )
+        return self._from_temp_dataset(shuffled)
+
     def to_dataset(
         self,
         dim: Hashable = None,
diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py
index b13b5dca56f..082638c9206 100644
--- a/xarray/core/groupby.py
+++ b/xarray/core/groupby.py
@@ -6,6 +6,7 @@
 import warnings
 from collections.abc import Callable, Hashable, Iterator, Mapping, Sequence
 from dataclasses import dataclass, field
+from functools import partial
 from typing import TYPE_CHECKING, Any, Generic, Literal, Union, cast

 import numpy as np
@@ -29,6 +30,7 @@
 )
 from xarray.core.merge import merge_coords
 from xarray.core.options import OPTIONS, _get_keep_attrs
+from xarray.core.parallel import map_blocks
 from xarray.core.types import (
     Dims,
     QuantileMethods,
@@ -86,6 +88,24 @@ def _codes_to_group_indices(codes: np.ndarray, N: int) -> GroupIndices:
     return groups


+def _infer_map_blocks_template(shuffled: GroupBy, func: Callable, *args, **kwargs):
+    template = shuffled.map(func, *args, **kwargs)
+    name = shuffled.group1d.name
+    chunksizes = shuffled._obj.chunksizes[shuffled._group_dim]
+    output_group = template[name]
+    out_group_lens = output_group.groupby(name).count().data
+    block_ids = np.repeat(np.arange(len(chunksizes)), chunksizes)
+    frame = pd.DataFrame(
+        {"block_id": pd.Index(block_ids), "codes": shuffled.encoded.codes}
+    )
+    groups_in_chunk = frame["codes"].groupby(block_ids).unique()
+    out_chunks = tuple(
+        itertools.chain(*[out_group_lens[group].tolist() for group in groups_in_chunk])
+    )
+    template = template.chunk({name: out_chunks})
+    return template, name
+
+
 def _dummy_copy(xarray_obj):
     from xarray.core.dataarray import DataArray
     from xarray.core.dataset import Dataset
@@ -698,6 +718,64 @@ def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray:
         new_obj = self._obj._from_temp_dataset(shuffled) if was_array else shuffled
         return new_obj

+    def _map_shuffled(self, func, args, kwargs) -> T_Xarray:
+        def wrapper(x, func, groupers, renamer, *args, **kwargs):
+            return x.groupby(groupers).map(func, *args, **kwargs).rename(renamer)
+
+        shuffled = self.shuffle()
+        obj = shuffled._obj.copy(deep=False)
+        try:
+            template, concat_dim = _infer_map_blocks_template(
+                shuffled, func, *args, **kwargs
+            )
+        except Exception as e:
+            raise ValueError("Could not infer template automatically.") from e
+        group_dim = shuffled._group_dim
+
+        groupers = {}
+        for grouper in shuffled.groupers:
+            name = grouper.group.name
+            if name not in obj:
+                obj.coords[name] = grouper.group
+            groupers[name] = grouper.grouper.reset()
+
+        # map_blocks does not support adding new dimensions that are multiply-chunked
+        # For example, even renaming an existing dimension to a new name will not work.
+        # This would be needed for grouped reductions where at least one dimension is destroyed.
+        # So we engage in a renaming game.
+        result = map_blocks(
+            # 1. This renamer renames dimensions named after the grouping variable to the
+            #    dimension we are grouping over.
+            #    For example .groupby("label") where label.dims == ("x",); we rename the
+            #    output "label" dimension back to "x"
+            partial(
+                wrapper, func=func, groupers=groupers, renamer={concat_dim: group_dim}
+            ),
+            obj,
+            args=args,
+            kwargs=kwargs,
+            # 2. Again do the same renaming transform on the template
+            template=template.rename({concat_dim: group_dim}),
+        )
+
+        if (
+            group_dim == concat_dim
+            and self._obj.sizes[group_dim] == template.sizes[group_dim]
+        ):
+            # invert the shuffling
+            inverse = _inverse_permutation_indices(self.encoded.group_indices)
+            # output chunk sizes are the same as the input's
+            indices = [
+                arr.tolist()
+                for arr in np.split(
+                    inverse, np.cumsum(self._obj.chunksizes[self._group_dim])[:-1]
+                )
+            ]
+            result = result._shuffle(dim=group_dim, indices=indices, chunks="auto")
+
+        # 3. Now invert the renaming
+        return result.rename({group_dim: concat_dim})
+
     def map(
         self,
         func: Callable,
@@ -1390,6 +1468,8 @@ def map(
         func: Callable[..., DataArray],
         args: tuple[Any, ...] = (),
         shortcut: bool | None = None,
+        *,
+        shuffle: bool = False,
         **kwargs: Any,
     ) -> DataArray:
         """Apply a function to each array in the group and concatenate them
@@ -1433,9 +1513,16 @@ def map(
         applied : DataArray
             The result of splitting, applying and combining this array.
""" - grouped = self._iter_grouped_shortcut() if shortcut else self._iter_grouped() - applied = (maybe_wrap_array(arr, func(arr, *args, **kwargs)) for arr in grouped) - return self._combine(applied, shortcut=shortcut) + if shuffle and self._obj.chunksizes: + return self._map_shuffled(func, args=args, kwargs=kwargs) + else: + grouped = ( + self._iter_grouped_shortcut() if shortcut else self._iter_grouped() + ) + applied = ( + maybe_wrap_array(arr, func(arr, *args, **kwargs)) for arr in grouped + ) + return self._combine(applied, shortcut=shortcut) def apply(self, func, shortcut=False, args=(), **kwargs): """ @@ -1559,6 +1646,8 @@ def map( func: Callable[..., Dataset], args: tuple[Any, ...] = (), shortcut: bool | None = None, + *, + shuffle: bool = False, **kwargs: Any, ) -> Dataset: """Apply a function to each Dataset in the group and concatenate them @@ -1590,9 +1679,12 @@ def map( applied : Dataset The result of splitting, applying and combining this dataset. """ - # ignore shortcut if set (for now) - applied = (func(ds, *args, **kwargs) for ds in self._iter_grouped()) - return self._combine(applied) + if shuffle and self._obj.chunksizes: + return self._map_shuffled(func, args=args, kwargs=kwargs) + else: + # ignore shortcut if set (for now) + applied = (func(ds, *args, **kwargs) for ds in self._iter_grouped()) + return self._combine(applied) def apply(self, func, args=(), shortcut=None, **kwargs): """ diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index a0dfe56807b..41e334e22e1 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -488,7 +488,7 @@ def _wrapper( " Please construct a template with appropriately chunked dask arrays." ) - new_indexes = set(template.xindexes) - set(merged_coordinates) + new_indexes = set(template.xindexes) - set(merged_coordinates.xindexes) modified_indexes = set( name for name, xindex in coordinates.xindexes.items() diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index ad5b5d41ff7..1f2b85825d0 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -219,6 +219,14 @@ def test_groupby_indexvariable(use_flox: bool) -> None: assert_identical(expected, actual) +@pytest.mark.parametrize("shuffle", [True, False]) +@pytest.mark.parametrize( + "chunk", + [ + pytest.param(True, marks=pytest.mark.skipif(not has_dask, reason="no dask")), + False, + ], +) @pytest.mark.parametrize( "obj", [ @@ -226,12 +234,22 @@ def test_groupby_indexvariable(use_flox: bool) -> None: xr.Dataset({"foo": ("x", [1, 2, 3, 4, 5, 6])}, {"x": [1, 1, 1, 2, 2, 2]}), ], ) -def test_groupby_map_shrink_groups(obj) -> None: +def test_groupby_map_shrink_groups(obj, chunk: bool, shuffle: bool) -> None: expected = obj.isel(x=[0, 1, 3, 4]) - actual = obj.groupby("x").map(lambda f: f.isel(x=[0, 1])) + if chunk: + obj = obj.chunk(x=2) + actual = obj.groupby("x").map(lambda f: f.isel(x=[0, 1]), shuffle=shuffle) assert_identical(expected, actual) +@pytest.mark.parametrize("shuffle", [True, False]) +@pytest.mark.parametrize( + "chunk", + [ + pytest.param(True, marks=pytest.mark.skipif(not has_dask, reason="no dask")), + False, + ], +) @pytest.mark.parametrize( "obj", [ @@ -239,7 +257,7 @@ def test_groupby_map_shrink_groups(obj) -> None: xr.Dataset({"foo": ("x", [1, 2, 3])}, {"x": [1, 2, 2]}), ], ) -def test_groupby_map_change_group_size(obj) -> None: +def test_groupby_map_change_group_size(obj, chunk: bool, shuffle: bool) -> None: def func(group): if group.sizes["x"] == 1: result = group.isel(x=[0, 0]) @@ -248,7 +266,9 @@ def 
         if group.sizes["x"] == 1:
             result = group.isel(x=[0, 0])
         else:
             result = group.isel(x=[0])
         return result

     expected = obj.isel(x=[0, 0, 1])
-    actual = obj.groupby("x").map(func)
+    if chunk:
+        obj = obj.chunk(x=2)
+    actual = obj.groupby("x").map(func, shuffle=shuffle)
     assert_identical(expected, actual)
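With this final patch applied, a per-group function that needs all of a group's members in memory at once can be mapped over a chunked array in a single embarrassingly parallel step. A usage sketch, assuming the patched xarray with dask installed (``standardize`` and the data here are illustrative, not from the patch):

    import dask.array
    import numpy as np
    import xarray as xr

    da = xr.DataArray(
        dask.array.from_array(np.arange(6.0), chunks=2),
        dims="x",
        coords={"x": [1, 1, 1, 2, 2, 2]},
        name="foo",
    )

    def standardize(group):
        # Requires every member of the group in a single chunk;
        # map(..., shuffle=True) shuffles first to arrange exactly that.
        return (group - group.mean()) / group.std()

    result = da.groupby("x").map(standardize, shuffle=True)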