diff --git a/packages/vaex-core/vaex/__init__.py b/packages/vaex-core/vaex/__init__.py index 50bbb0b921..673e179d18 100644 --- a/packages/vaex-core/vaex/__init__.py +++ b/packages/vaex-core/vaex/__init__.py @@ -379,7 +379,7 @@ def from_arrow_table(table) -> vaex.dataframe.DataFrame: def from_arrow_dataset(arrow_dataset) -> vaex.dataframe.DataFrame: - '''Create a DataFrame from an Apache Arrow dataset''' + '''Create a DataFrame from an Apache Arrow dataset.''' import vaex.arrow.dataset return from_dataset(vaex.arrow.dataset.DatasetArrow(arrow_dataset)) diff --git a/packages/vaex-core/vaex/arrow/dataset.py b/packages/vaex-core/vaex/arrow/dataset.py index 1c08c130f6..57a6ee482c 100644 --- a/packages/vaex-core/vaex/arrow/dataset.py +++ b/packages/vaex-core/vaex/arrow/dataset.py @@ -250,6 +250,39 @@ def __getstate__(self): return state + +class DatasetArrow(DatasetArrowBase): + snake_name = "arrow-dataset" + def __init__(self, ds, max_rows_read=1024**2*10): + self._arrow_ds = ds + super().__init__(max_rows_read=max_rows_read) + + @property + def _fingerprint(self): + return self._id + + def hashed(self): + raise NotImplementedError + + def _create_columns(self): + super()._create_columns() + # self._ids = frozendict({name: vaex.cache.fingerprint(self._fingerprint, name) for name in self._columns}) + self._ids = frozendict() + + def _create_dataset(self): + self._partitions = defaultdict(list) # path -> list (which will be an arrow array later on) + self._partition_keys = defaultdict(dict) # path -> key -> int/index + + for fragment in self._arrow_ds.get_fragments(): + keys = pa.dataset._get_partition_keys(fragment.partition_expression) + for name, value in keys.items(): + if value not in self._partitions[name]: + self._partitions[name].append(value) + self._partition_keys[fragment.path][name] = self._partitions[name].index(value) + self._partitions = {name: pa.array(values) for name, values in self._partitions.items()} + + + class DatasetArrowFileBase(vaex.dataset.Dataset): def __init__(self, path, fs_options, fs=None): super().__init__() diff --git a/tests/from_test.py b/tests/from_test.py index 6319f8dfd1..2562b0004e 100644 --- a/tests/from_test.py +++ b/tests/from_test.py @@ -1,4 +1,9 @@ +import pytest import vaex +from pathlib import Path + +HERE = Path(__file__).parent + def test_from_records(): df = vaex.from_records([ @@ -26,3 +31,12 @@ def test_from_records(): ], array_type="numpy") assert df.a.tolist() == [[1, 1], [11, 12], [13, 14]] assert df.a.shape == (3, 2) + + +def test_from_arrow_dataset(): + import pyarrow.dataset + path = HERE / 'data' / 'sample_arrow_dict.parquet' + ds = pyarrow.dataset.dataset(path) + df = vaex.from_arrow_dataset(ds) + assert df.col1.sum() == 45 + assert df.fingerprint() == df.fingerprint()