Skip to content

Commit

Permalink
Implement sort_remaining for sort_index (#14033)
Browse files Browse the repository at this point in the history
Previously, the `sort_remaining` argument to `sort_index` was effectively unsupported: passing `sort_remaining=False` raised a `NotImplementedError`. Moreover, for a multiindex, `sort_remaining=True` was not handled correctly: if only a subset of the levels was requested for sorting, `sort_index` behaved as if `sort_remaining=False` had been passed.

To fix this case, construct the sort order from the provided levels first and then, if `sort_remaining=True`, from the left-over levels (in index order).

To facilitate this, refactor the internal `_get_columns_by_label` function to always return a `Frame`-like object (previously, if we had a `Frame` we would get back a `ColumnAccessor`, and it was only for `IndexedFrame` and above that we'd get something of `Self`-like type back). This meant that calling `_get_sorted_inds` with `by != None` was not possible on an `Index` or `MultiIndex` (the code assumed we'd get a `Frame` back).

- Closes #14011

Authors:
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - GALI PREM SAGAR (https://github.com/galipremsagar)
  - Matthew Roeschke (https://github.com/mroeschke)
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #14033
  • Loading branch information
wence- authored Sep 5, 2023
1 parent 3e5f019 commit 0b01fe4
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 25 deletions.
12 changes: 7 additions & 5 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from pandas.core.dtypes.common import is_float, is_integer
from pandas.io.formats import console
from pandas.io.formats.printing import pprint_thing
from typing_extensions import assert_never
from typing_extensions import Self, assert_never

import cudf
import cudf.core.common
Expand Down Expand Up @@ -1830,25 +1830,27 @@ def _repr_latex_(self):
return self._get_renderable_dataframe().to_pandas()._repr_latex_()

@_cudf_nvtx_annotate
def _get_columns_by_label(self, labels, downcast=False):
def _get_columns_by_label(
self, labels, *, downcast=False
) -> Self | Series:
"""
Return columns of dataframe by `labels`
If downcast is True, try and downcast from a DataFrame to a Series
"""
new_data = super()._get_columns_by_label(labels, downcast)
ca = self._data.select_by_label(labels)
if downcast:
if is_scalar(labels):
nlevels = 1
elif isinstance(labels, tuple):
nlevels = len(labels)
if self._data.multiindex is False or nlevels == self._data.nlevels:
out = self._constructor_sliced._from_data(
new_data, index=self.index, name=labels
ca, index=self.index, name=labels
)
return out
out = self.__class__._from_data(
new_data, index=self.index, columns=new_data.to_pandas_index()
ca, index=self.index, columns=ca.to_pandas_index()
)
return out

Expand Down
4 changes: 2 additions & 2 deletions python/cudf/cudf/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,12 +362,12 @@ def equals(self, other):
)

@_cudf_nvtx_annotate
def _get_columns_by_label(self, labels, downcast=False):
def _get_columns_by_label(self, labels, *, downcast=False) -> Self:
"""
Returns columns of the Frame specified by `labels`
"""
return self._data.select_by_label(labels)
return self.__class__._from_data(self._data.select_by_label(labels))

@property
@_cudf_nvtx_annotate
Expand Down
31 changes: 17 additions & 14 deletions python/cudf/cudf/core/indexed_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1526,7 +1526,9 @@ def sort_index(
na_position : {'first', 'last'}, default 'last'
Puts NaNs at the beginning if first; last puts NaNs at the end.
sort_remaining : bool, default True
Not yet supported
When sorting a multiindex on a subset of its levels,
should entries be lexsorted by the remaining
(non-specified) levels as well?
ignore_index : bool, default False
if True, index will be replaced with RangeIndex.
key : callable, optional
Expand Down Expand Up @@ -1592,11 +1594,6 @@ def sort_index(
if kind is not None:
raise NotImplementedError("kind is not yet supported")

if not sort_remaining:
raise NotImplementedError(
"sort_remaining == False is not yet supported"
)

if key is not None:
raise NotImplementedError("key is not yet supported.")

Expand All @@ -1609,16 +1606,22 @@ def sort_index(
if level is not None:
# Pandas doesn't handle na_position in case of MultiIndex.
na_position = "first" if ascending is True else "last"
labels = [
idx._get_level_label(lvl)
for lvl in (level if is_list_like(level) else (level,))
]
# Explicitly construct a Frame rather than using type(self)
# to avoid constructing a SingleColumnFrame (e.g. Series).
idx = Frame._from_data(idx._data.select_by_label(labels))
if not is_list_like(level):
level = [level]
by = list(map(idx._get_level_label, level))
if sort_remaining:
handled = set(by)
by.extend(
filter(
lambda n: n not in handled,
self.index._data.names,
)
)
else:
by = list(idx._data.names)

inds = idx._get_sorted_inds(
ascending=ascending, na_position=na_position
by=by, ascending=ascending, na_position=na_position
)
out = self._gather(
GatherMap.from_column_unchecked(
Expand Down
8 changes: 4 additions & 4 deletions python/cudf/cudf/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,17 +797,17 @@ def deserialize(cls, header, frames):

return obj

def _get_columns_by_label(self, labels, downcast=False):
def _get_columns_by_label(self, labels, *, downcast=False) -> Self:
"""Return the column specified by `labels`
For cudf.Series, either the column, or an empty series is returned.
Parameter `downcast` does not have effects.
"""
new_data = super()._get_columns_by_label(labels, downcast)
ca = self._data.select_by_label(labels)

return (
self.__class__._from_data(data=new_data, index=self.index)
if len(new_data) > 0
self.__class__._from_data(data=ca, index=self.index)
if len(ca) > 0
else self.__class__(dtype=self.dtype, name=self.name)
)

Expand Down
23 changes: 23 additions & 0 deletions python/cudf/cudf/tests/test_multiindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -1897,3 +1897,26 @@ def test_multiindex_empty_slice_pandas_compatibility():
with cudf.option_context("mode.pandas_compatible", True):
actual = cudf.from_pandas(expected)
assert_eq(expected, actual, exact=False)


@pytest.mark.parametrize(
    "levels",
    [
        perm
        for size in range(1, 4)
        for perm in itertools.permutations(range(3), size)
    ],
    ids=str,
)
def test_multiindex_sort_index_partial(levels):
    """Compare cudf and pandas when sorting a MultiIndex by some levels.

    ``levels`` is every ordered selection of one, two, or three of the
    three index levels; the remaining levels are sorted as well because
    ``sort_remaining=True`` is passed to both libraries.
    """
    columns = {
        "a": [3, 3, 3, 1, 1, 1, 2, 2],
        "b": [4, 2, 7, -1, 11, -2, 7, 7],
        "c": [4, 4, 2, 3, 3, 3, 1, 1],
        "val": [1, 2, 3, 4, 5, 6, 7, 8],
    }
    # Three-level index with duplicate keys, so the order of the
    # unspecified levels is observable in the result.
    pdf = pd.DataFrame(columns).set_index(["a", "b", "c"])
    gdf = cudf.from_pandas(pdf)

    expect = pdf.sort_index(level=levels, sort_remaining=True)
    got = gdf.sort_index(level=levels, sort_remaining=True)
    assert_eq(expect, got)

0 comments on commit 0b01fe4

Please sign in to comment.