diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py index 99e4588d608..7b7fc87a6dc 100644 --- a/python/cudf/cudf/core/dataframe.py +++ b/python/cudf/cudf/core/dataframe.py @@ -5450,9 +5450,11 @@ def from_arrow(cls, table): """ index_col = None col_index_names = None + physical_column_md = [] if isinstance(table, pa.Table) and isinstance( table.schema.pandas_metadata, dict ): + physical_column_md = table.schema.pandas_metadata["columns"] index_col = table.schema.pandas_metadata["index_columns"] if "column_indexes" in table.schema.pandas_metadata: col_index_names = [] @@ -5480,7 +5482,18 @@ def from_arrow(cls, table): # https://github.com/apache/arrow/issues/15178 out = out.set_index(idx) else: - out = out.set_index(index_col[0]) + out = out.set_index(index_col) + + if ( + "__index_level_0__" in out.index.names + and len(out.index.names) == 1 + ): + real_index_name = None + for md in physical_column_md: + if md["field_name"] == "__index_level_0__": + real_index_name = md["name"] + break + out.index.name = real_index_name return out @@ -5530,42 +5543,43 @@ def to_arrow(self, preserve_index=None): write_index = preserve_index is not False keep_range_index = write_index and preserve_index is None index = self.index + index_levels = [self.index] if write_index: if isinstance(index, cudf.RangeIndex) and keep_range_index: - descr = { - "kind": "range", - "name": index.name, - "start": index._start, - "stop": index._stop, - "step": 1, - } + index_descr = [ + { + "kind": "range", + "name": index.name, + "start": index._start, + "stop": index._stop, + "step": 1, + } + ] else: if isinstance(index, cudf.RangeIndex): index = index._as_int_index() index.name = "__index_level_0__" if isinstance(index, MultiIndex): - gen_names = tuple( - f"level_{i}" for i, _ in enumerate(index._data.names) - ) + index_descr = list(index._data.names) + index_levels = index.levels else: - gen_names = ( + index_descr = ( index.names if index.name is not None else ("index",) ) - for gen_name, col_name in zip(gen_names, index._data.names): + for gen_name, col_name in zip(index_descr, index._data.names): data._insert( data.shape[1], gen_name, index._data[col_name], ) - descr = gen_names[0] - index_descr.append(descr) out = super(DataFrame, data).to_arrow() + # import pdb; pdb.set_trace() metadata = pa.pandas_compat.construct_metadata( columns_to_convert=[self[col] for col in self._data.names], df=self, column_names=out.schema.names, - index_levels=[index], + index_levels=index_levels, index_descriptors=index_descr, preserve_index=preserve_index, types=out.schema.types, diff --git a/python/cudf/cudf/tests/test_dataframe.py b/python/cudf/cudf/tests/test_dataframe.py index ead1ab2da6c..df0e22c5e43 100644 --- a/python/cudf/cudf/tests/test_dataframe.py +++ b/python/cudf/cudf/tests/test_dataframe.py @@ -2769,6 +2769,28 @@ def test_arrow_pandas_compat(pdf, gdf, preserve_index): assert_eq(pdf2, gdf2) +@pytest.mark.parametrize( + "index", [None, cudf.RangeIndex(3, name="a"), "a", "b", ["a", "b"]] +) +@pytest.mark.parametrize("preserve_index", [True, False, None]) +def test_arrow_round_trip(preserve_index, index): + data = {"a": [4, 5, 6], "b": ["cat", "dog", "bird"]} + if isinstance(index, (list, str)): + gdf = cudf.DataFrame(data).set_index(index) + else: + gdf = cudf.DataFrame(data, index=index) + + table = gdf.to_arrow(preserve_index=preserve_index) + table_pd = pa.Table.from_pandas( + gdf.to_pandas(), preserve_index=preserve_index + ) + + gdf_out = cudf.DataFrame.from_arrow(table) + pdf_out = table_pd.to_pandas() + + assert_eq(gdf_out, pdf_out) + + @pytest.mark.parametrize("dtype", NUMERIC_TYPES + ["bool"]) def test_cuda_array_interface(dtype): np_data = np.arange(10).astype(dtype) diff --git a/python/dask_cudf/dask_cudf/backends.py b/python/dask_cudf/dask_cudf/backends.py index 5401bcd3767..94528325aea 100644 --- a/python/dask_cudf/dask_cudf/backends.py +++ b/python/dask_cudf/dask_cudf/backends.py @@ -384,18 +384,6 @@ def _cudf_to_table(obj, preserve_index=None, **kwargs): "Ignoring the following arguments to " f"`to_pyarrow_table_dispatch`: {list(kwargs)}" ) - - # TODO: Remove this logic when cudf#14159 is resolved - # (see: https://github.com/rapidsai/cudf/issues/14159) - if preserve_index and isinstance(obj.index, cudf.RangeIndex): - obj = obj.copy() - obj.index.name = ( - obj.index.name - if obj.index.name is not None - else "__index_level_0__" - ) - obj.index = obj.index._as_int_index() - return obj.to_arrow(preserve_index=preserve_index) @@ -408,15 +396,7 @@ def _table_to_cudf(obj, table, self_destruct=None, **kwargs): f"Ignoring the following arguments to " f"`from_pyarrow_table_dispatch`: {list(kwargs)}" ) - result = obj.from_arrow(table) - - # TODO: Remove this logic when cudf#14159 is resolved - # (see: https://github.com/rapidsai/cudf/issues/14159) - if "__index_level_0__" in result.index.names: - assert len(result.index.names) == 1 - result.index.name = None - - return result + return obj.from_arrow(table) @union_categoricals_dispatch.register((cudf.Series, cudf.BaseIndex)) diff --git a/python/dask_cudf/dask_cudf/tests/test_dispatch.py b/python/dask_cudf/dask_cudf/tests/test_dispatch.py index 76703206726..a12481a7bb4 100644 --- a/python/dask_cudf/dask_cudf/tests/test_dispatch.py +++ b/python/dask_cudf/dask_cudf/tests/test_dispatch.py @@ -25,17 +25,24 @@ def test_is_categorical_dispatch(): @pytest.mark.parametrize("preserve_index", [True, False]) -def test_pyarrow_conversion_dispatch(preserve_index): +@pytest.mark.parametrize("index", [None, cudf.RangeIndex(10, name="foo")]) +def test_pyarrow_conversion_dispatch(preserve_index, index): from dask.dataframe.dispatch import ( from_pyarrow_table_dispatch, to_pyarrow_table_dispatch, ) - df1 = cudf.DataFrame(np.random.randn(10, 3), columns=list("abc")) + df1 = cudf.DataFrame( + np.random.randn(10, 3), columns=list("abc"), index=index + ) df2 = from_pyarrow_table_dispatch( df1, to_pyarrow_table_dispatch(df1, preserve_index=preserve_index) ) + # preserve_index=False doesn't retain index metadata + if not preserve_index and index is not None: + df1.index.name = None + assert type(df1) == type(df2) assert_eq(df1, df2)