From 332cc06cdbe2b66d39d96e4ff36e142a84750717 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 20 Nov 2024 12:02:21 -0800 Subject: [PATCH] Support pivot with index or column arguments as lists (#17373) closes https://github.com/rapidsai/cudf/issues/17360 Technically I suppose this was more of an enhancement since the documentation suggested only a single label was supported, but I'll mark as a bug since the error message was not informative. Authors: - Matthew Roeschke (https://github.com/mroeschke) - Vyas Ramasubramani (https://github.com/vyasr) Approvers: - Bradley Dice (https://github.com/bdice) URL: https://github.com/rapidsai/cudf/pull/17373 --- python/cudf/cudf/core/reshape.py | 60 +++++++++++++++++++------- python/cudf/cudf/tests/test_reshape.py | 17 ++++++++ 2 files changed, 61 insertions(+), 16 deletions(-) diff --git a/python/cudf/cudf/core/reshape.py b/python/cudf/cudf/core/reshape.py index 3d132c92d54..016bd1225cd 100644 --- a/python/cudf/cudf/core/reshape.py +++ b/python/cudf/cudf/core/reshape.py @@ -961,7 +961,11 @@ def _merge_sorted( ) -def _pivot(col_accessor: ColumnAccessor, index, columns) -> cudf.DataFrame: +def _pivot( + col_accessor: ColumnAccessor, + index: cudf.Index | cudf.MultiIndex, + columns: cudf.Index | cudf.MultiIndex, +) -> cudf.DataFrame: """ Reorganize the values of the DataFrame according to the given index and columns. @@ -1012,12 +1016,12 @@ def as_tuple(x): level_names=(None,) + columns._column_names, verify=False, ) - return cudf.DataFrame._from_data( - ca, index=cudf.Index(index_labels, name=index.name) - ) + return cudf.DataFrame._from_data(ca, index=index_labels) -def pivot(data, columns=None, index=no_default, values=no_default): +def pivot( + data: cudf.DataFrame, columns=None, index=no_default, values=no_default +) -> cudf.DataFrame: """ Return reshaped DataFrame organized by the given index and column values. @@ -1027,10 +1031,10 @@ def pivot(data, columns=None, index=no_default, values=no_default): Parameters ---------- - columns : column name, optional - Column used to construct the columns of the result. - index : column name, optional - Column used to construct the index of the result. + columns : scalar or list of scalars, optional + Column label(s) used to construct the columns of the result. + index : scalar or list of scalars, optional + Column label(s) used to construct the index of the result. values : column name or list of column names, optional Column(s) whose values are rearranged to produce the result. If not specified, all remaining columns of the DataFrame @@ -1069,24 +1073,46 @@ def pivot(data, columns=None, index=no_default, values=no_default): """ values_is_list = True if values is no_default: + already_selected = set( + itertools.chain( + [index] if is_scalar(index) else index, + [columns] if is_scalar(columns) else columns, + ) + ) cols_to_select = [ - col for col in data._column_names if col not in (index, columns) + col for col in data._column_names if col not in already_selected ] elif not isinstance(values, (list, tuple)): cols_to_select = [values] values_is_list = False else: - cols_to_select = values + cols_to_select = values # type: ignore[assignment] if index is no_default: - index = data.index + index_data = data.index else: - index = cudf.Index(data.loc[:, index]) - columns = cudf.Index(data.loc[:, columns]) + index_data = data.loc[:, index] + if index_data.ndim == 2: + index_data = cudf.MultiIndex.from_frame(index_data) + if not is_scalar(index) and len(index) == 1: + # pandas converts single level MultiIndex to Index + index_data = index_data.get_level_values(0) + else: + index_data = cudf.Index(index_data) + + column_data = data.loc[:, columns] + if column_data.ndim == 2: + column_data = cudf.MultiIndex.from_frame(column_data) + else: + column_data = cudf.Index(column_data) # Create a DataFrame composed of columns from both # columns and index ca = ColumnAccessor( - dict(enumerate(itertools.chain(index._columns, columns._columns))), + dict( + enumerate( + itertools.chain(index_data._columns, column_data._columns) + ) + ), verify=False, ) columns_index = cudf.DataFrame._from_data(ca) @@ -1095,7 +1121,9 @@ def pivot(data, columns=None, index=no_default, values=no_default): if len(columns_index) != len(columns_index.drop_duplicates()): raise ValueError("Duplicate index-column pairs found. Cannot reshape.") - result = _pivot(data._data.select_by_label(cols_to_select), index, columns) + result = _pivot( + data._data.select_by_label(cols_to_select), index_data, column_data + ) # MultiIndex to Index if not values_is_list: diff --git a/python/cudf/cudf/tests/test_reshape.py b/python/cudf/cudf/tests/test_reshape.py index 26386abb05d..53fe5f7f30d 100644 --- a/python/cudf/cudf/tests/test_reshape.py +++ b/python/cudf/cudf/tests/test_reshape.py @@ -835,3 +835,20 @@ def test_crosstab_simple(): expected = pd.crosstab(a, [b, c], rownames=["a"], colnames=["b", "c"]) actual = cudf.crosstab(a, [b, c], rownames=["a"], colnames=["b", "c"]) assert_eq(expected, actual, check_dtype=False) + + +@pytest.mark.parametrize("index", [["ix"], ["ix", "foo"]]) +@pytest.mark.parametrize("columns", [["col"], ["col", "baz"]]) +def test_pivot_list_like_index_columns(index, columns): + data = { + "bar": ["x", "y", "z", "w"], + "col": ["a", "b", "a", "b"], + "foo": [1, 2, 3, 4], + "ix": [1, 1, 2, 2], + "baz": [0, 0, 0, 0], + } + pd_df = pd.DataFrame(data) + cudf_df = cudf.DataFrame(data) + result = cudf_df.pivot(columns=columns, index=index) + expected = pd_df.pivot(columns=columns, index=index) + assert_eq(result, expected)