Skip to content

Commit

Permalink
Support pivot with index or column arguments as lists (#17373)
Browse files Browse the repository at this point in the history
closes #17360

Technically, I suppose this was more of an enhancement, since the documentation suggested only a single label was supported, but I'll mark it as a bug since the error message was not informative.

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)
  - Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
  - Bradley Dice (https://github.com/bdice)

URL: #17373
  • Loading branch information
mroeschke authored Nov 20, 2024
1 parent 04502c8 commit 332cc06
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 16 deletions.
60 changes: 44 additions & 16 deletions python/cudf/cudf/core/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,11 @@ def _merge_sorted(
)


def _pivot(col_accessor: ColumnAccessor, index, columns) -> cudf.DataFrame:
def _pivot(
col_accessor: ColumnAccessor,
index: cudf.Index | cudf.MultiIndex,
columns: cudf.Index | cudf.MultiIndex,
) -> cudf.DataFrame:
"""
Reorganize the values of the DataFrame according to the given
index and columns.
Expand Down Expand Up @@ -1012,12 +1016,12 @@ def as_tuple(x):
level_names=(None,) + columns._column_names,
verify=False,
)
return cudf.DataFrame._from_data(
ca, index=cudf.Index(index_labels, name=index.name)
)
return cudf.DataFrame._from_data(ca, index=index_labels)


def pivot(data, columns=None, index=no_default, values=no_default):
def pivot(
data: cudf.DataFrame, columns=None, index=no_default, values=no_default
) -> cudf.DataFrame:
"""
Return reshaped DataFrame organized by the given index and column values.
Expand All @@ -1027,10 +1031,10 @@ def pivot(data, columns=None, index=no_default, values=no_default):
Parameters
----------
columns : column name, optional
Column used to construct the columns of the result.
index : column name, optional
Column used to construct the index of the result.
columns : scalar or list of scalars, optional
Column label(s) used to construct the columns of the result.
index : scalar or list of scalars, optional
Column label(s) used to construct the index of the result.
values : column name or list of column names, optional
Column(s) whose values are rearranged to produce the result.
If not specified, all remaining columns of the DataFrame
Expand Down Expand Up @@ -1069,24 +1073,46 @@ def pivot(data, columns=None, index=no_default, values=no_default):
"""
values_is_list = True
if values is no_default:
already_selected = set(
itertools.chain(
[index] if is_scalar(index) else index,
[columns] if is_scalar(columns) else columns,
)
)
cols_to_select = [
col for col in data._column_names if col not in (index, columns)
col for col in data._column_names if col not in already_selected
]
elif not isinstance(values, (list, tuple)):
cols_to_select = [values]
values_is_list = False
else:
cols_to_select = values
cols_to_select = values # type: ignore[assignment]
if index is no_default:
index = data.index
index_data = data.index
else:
index = cudf.Index(data.loc[:, index])
columns = cudf.Index(data.loc[:, columns])
index_data = data.loc[:, index]
if index_data.ndim == 2:
index_data = cudf.MultiIndex.from_frame(index_data)
if not is_scalar(index) and len(index) == 1:
# pandas converts single level MultiIndex to Index
index_data = index_data.get_level_values(0)
else:
index_data = cudf.Index(index_data)

column_data = data.loc[:, columns]
if column_data.ndim == 2:
column_data = cudf.MultiIndex.from_frame(column_data)
else:
column_data = cudf.Index(column_data)

# Create a DataFrame composed of columns from both
# columns and index
ca = ColumnAccessor(
dict(enumerate(itertools.chain(index._columns, columns._columns))),
dict(
enumerate(
itertools.chain(index_data._columns, column_data._columns)
)
),
verify=False,
)
columns_index = cudf.DataFrame._from_data(ca)
Expand All @@ -1095,7 +1121,9 @@ def pivot(data, columns=None, index=no_default, values=no_default):
if len(columns_index) != len(columns_index.drop_duplicates()):
raise ValueError("Duplicate index-column pairs found. Cannot reshape.")

result = _pivot(data._data.select_by_label(cols_to_select), index, columns)
result = _pivot(
data._data.select_by_label(cols_to_select), index_data, column_data
)

# MultiIndex to Index
if not values_is_list:
Expand Down
17 changes: 17 additions & 0 deletions python/cudf/cudf/tests/test_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,3 +835,20 @@ def test_crosstab_simple():
expected = pd.crosstab(a, [b, c], rownames=["a"], colnames=["b", "c"])
actual = cudf.crosstab(a, [b, c], rownames=["a"], colnames=["b", "c"])
assert_eq(expected, actual, check_dtype=False)


@pytest.mark.parametrize("index", [["ix"], ["ix", "foo"]])
@pytest.mark.parametrize("columns", [["col"], ["col", "baz"]])
def test_pivot_list_like_index_columns(index, columns):
    """Check that pivot with list-valued ``index``/``columns`` matches pandas.

    Each parametrized case passes one- or two-element label lists for both
    ``index`` and ``columns`` and compares the cudf result to the pandas one.
    """
    # Key order is kept as-is: it determines the column order of both frames.
    raw = {
        "bar": ["x", "y", "z", "w"],
        "col": ["a", "b", "a", "b"],
        "foo": [1, 2, 3, 4],
        "ix": [1, 1, 2, 2],
        "baz": [0, 0, 0, 0],
    }
    pivot_kwargs = {"columns": columns, "index": index}
    gdf = cudf.DataFrame(raw)
    pdf = pd.DataFrame(raw)
    # Arguments evaluate left to right, so the cudf pivot runs first.
    assert_eq(gdf.pivot(**pivot_kwargs), pdf.pivot(**pivot_kwargs))

0 comments on commit 332cc06

Please sign in to comment.