diff --git a/python/pyarrow/pandas_compat.py b/python/pyarrow/pandas_compat.py index 7fbde36bc23e9..5a930a41f0300 100644 --- a/python/pyarrow/pandas_compat.py +++ b/python/pyarrow/pandas_compat.py @@ -848,6 +848,25 @@ def _get_extension_dtypes(table, columns_metadata, types_mapper=None): if _pandas_api.extension_dtype is None: return ext_columns + # use the specified mapping of built-in arrow types to pandas dtypes + if types_mapper: + for field in table.schema: + typ = field.type + pandas_dtype = types_mapper(typ) + if pandas_dtype is not None: + ext_columns[field.name] = pandas_dtype + + # infer from extension type in the schema + for field in table.schema: + typ = field.type + if field.name not in ext_columns and isinstance(typ, pa.BaseExtensionType): + try: + pandas_dtype = typ.to_pandas_dtype() + except NotImplementedError: + pass + else: + ext_columns[field.name] = pandas_dtype + # infer the extension columns from the pandas metadata for col_meta in columns_metadata: try: @@ -856,7 +875,7 @@ def _get_extension_dtypes(table, columns_metadata, types_mapper=None): name = col_meta['name'] dtype = col_meta['numpy_type'] - if dtype not in _pandas_supported_numpy_types: + if name not in ext_columns and dtype not in _pandas_supported_numpy_types: # pandas_dtype is expensive, so avoid doing this for types # that are certainly numpy dtypes pandas_dtype = _pandas_api.pandas_dtype(dtype) @@ -864,25 +883,6 @@ def _get_extension_dtypes(table, columns_metadata, types_mapper=None): if hasattr(pandas_dtype, "__from_arrow__"): ext_columns[name] = pandas_dtype - # infer from extension type in the schema - for field in table.schema: - typ = field.type - if isinstance(typ, pa.BaseExtensionType): - try: - pandas_dtype = typ.to_pandas_dtype() - except NotImplementedError: - pass - else: - ext_columns[field.name] = pandas_dtype - - # use the specified mapping of built-in arrow types to pandas dtypes - if types_mapper: - for field in table.schema: - typ = field.type - pandas_dtype = types_mapper(typ) - if pandas_dtype is not None: - ext_columns[field.name] = pandas_dtype - return ext_columns diff --git a/python/pyarrow/tests/test_pandas.py b/python/pyarrow/tests/test_pandas.py index 178a073ed59dc..1186f87b0322a 100644 --- a/python/pyarrow/tests/test_pandas.py +++ b/python/pyarrow/tests/test_pandas.py @@ -4411,6 +4411,31 @@ def test_to_pandas_extension_dtypes_mapping(): assert isinstance(result['a'].dtype, pd.PeriodDtype) +def test_to_pandas_extension_dtypes_mapping_complex_type(): + # https://github.com/apache/arrow/pull/44720 + if Version(pd.__version__) < Version("1.5.2"): + pytest.skip("Test relies on pd.ArrowDtype") + pa_type = pa.struct( + [ + pa.field("bar", pa.bool_(), nullable=False), + pa.field("baz", pa.float32(), nullable=True), + ], + ) + pd_type = pd.ArrowDtype(pa_type) + schema = pa.schema([pa.field("foo", pa_type)]) + df0 = pd.DataFrame( + [ + {"foo": {"bar": True, "baz": np.float32(1)}}, + {"foo": {"bar": True, "baz": None}}, + ], + ).astype({"foo": pd_type}) + + # Round trip df0 into df1 + table = pa.Table.from_pandas(df0, schema=schema) + df1 = table.to_pandas(types_mapper=pd.ArrowDtype) + pd.testing.assert_frame_equal(df0, df1) + + def test_array_to_pandas(): if Version(pd.__version__) < Version("1.1"): pytest.skip("ExtensionDtype to_pandas method missing")