From 52bff4f9980dfdb7873a1bef2995fb7e74f797ee Mon Sep 17 00:00:00 2001 From: Nikil Ravi Date: Sun, 13 Oct 2024 15:01:40 -0700 Subject: [PATCH 1/5] add parquet support --- src/levanter/data/sharded_datasource.py | 31 +++++++++++++++++- tests/test_sharded_dataset.py | 43 ++++++++++++++++++++++++- 2 files changed, 72 insertions(+), 2 deletions(-) diff --git a/src/levanter/data/sharded_datasource.py b/src/levanter/data/sharded_datasource.py index 38682616d..494bd5f05 100644 --- a/src/levanter/data/sharded_datasource.py +++ b/src/levanter/data/sharded_datasource.py @@ -20,6 +20,8 @@ import datasets import fsspec import numpy as np +import pyarrow.parquet as pq +import pandas as pd from levanter.utils import fsspec_utils @@ -149,6 +151,10 @@ def datasource_from_json(urls_or_paths: Sequence[str]) -> ShardedDataSource[dict return JsonDataSource(urls_or_paths) +def datasource_from_parquet(urls_or_paths: Sequence[str]) -> ShardedDataSource[dict]: + return ParquetDataSource(urls_or_paths) + + class WrappedHFDataSource(ShardedDataSource[dict]): """ This class is responsible for loading a dataset from HuggingFace Datasets and returning the shards. @@ -238,6 +244,11 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[str]: data = json.load(f) for doc in data[row:]: yield doc[self.text_key] + case ".parquet": + table = pq.read_table(f) + sliced_table = table.slice(row) + for record in sliced_table.to_pylist(): + yield record[self.text_key] # assumes text_key is in record case _: raise ValueError(f"Unknown format {format}") @@ -313,7 +324,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[Tuple[np.ndar def _sniff_format_for_dataset(url): - good_formats = [".jsonl", ".txt", ".json"] + good_formats = [".jsonl", ".txt", ".json", ".parquet"] format_from_url = None # try both with and without compression (could be gz, bz2, etc, so look at the "first" extension) extensions = [os.path.splitext(url)[1], os.path.splitext(os.path.splitext(url)[0])[1]] @@ -417,6 +428,24 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]: return iter(data[row:]) +class ParquetDataSource(ShardedDataSource[dict]): + def __init__(self, urls): + self.urls = urls + self._shard_name_to_url_mapping = _mk_shard_name_mapping(urls) + + @property + def shard_names(self) -> Sequence[str]: + return list(self._shard_name_to_url_mapping.keys()) + + def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]: + url = self._shard_name_to_url_mapping[shard_name] + with fsspec.open(url, "r", compression="infer") as f: + table = pq.read_table(f) + sliced_table = table.slice(row) # zero-copy slicing + for record in sliced_table.to_pylist(): + yield record + + def _mk_shard_name_mapping(urls): _shard_name_to_url_mapping = {} # remove common prefix diff --git a/tests/test_sharded_dataset.py b/tests/test_sharded_dataset.py index b3c8bcc8d..3cf0d78e6 100644 --- a/tests/test_sharded_dataset.py +++ b/tests/test_sharded_dataset.py @@ -1,6 +1,6 @@ import tempfile -from levanter.data.sharded_datasource import AudioTextUrlDataSource, _sniff_format_for_dataset +from levanter.data.sharded_datasource import AudioTextUrlDataSource, _sniff_format_for_dataset, ParquetDataSource from test_utils import skip_if_no_soundlibs @@ -24,6 +24,47 @@ def test_sniff_format_for_json(): assert _sniff_format_for_dataset(f.name) == ".json" +def test_sniff_format_for_parquet(): + + import pyarrow as pa + import pyarrow.parquet as pq + + with tempfile.NamedTemporaryFile(suffix=".parquet") as f: + table = pa.table({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']}) + pq.write_table(table, f.name) + f.flush() + + assert _sniff_format_for_dataset(f.name) == ".parquet" + + @skip_if_no_soundlibs def test_resolve_audio_pointer(): AudioTextUrlDataSource.resolve_audio_pointer("https://ccrma.stanford.edu/~jos/mp3/trumpet.mp3", 16_000) + + +def test_basic_parquet_datasource_read_row(): + + import pyarrow as pa + import pyarrow.parquet as pq + + with tempfile.NamedTemporaryFile(suffix=".parquet") as f: + # Create a simple dataset + data = { + "column1": ["value1", "value2", "value3"], + "column2": [10, 20, 30] + } + table = pa.Table.from_pydict(data) + pq.write_table(table, f.name) + + # Instantiate the ParquetDataSource + datasource = ParquetDataSource([f.name]) + + # sanity check: Read data starting from row 1 + row_data = list(datasource.open_shard_at_row(shard_name=f.name.replace(".", "_"), row=1)) + + # Verify the output + assert len(row_data) == 2 # We expect 2 rows starting from index 1 + assert row_data[0]["column1"] == "value2" + assert row_data[0]["column2"] == 20 + assert row_data[1]["column1"] == "value3" + assert row_data[1]["column2"] == 30 \ No newline at end of file From af78281e9dbf47163c980d54981ed736000f7280 Mon Sep 17 00:00:00 2001 From: Nikil Ravi Date: Sun, 13 Oct 2024 15:15:06 -0700 Subject: [PATCH 2/5] lint, shard name fix --- tests/test_sharded_dataset.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_sharded_dataset.py b/tests/test_sharded_dataset.py index 3cf0d78e6..2fb9cf4d8 100644 --- a/tests/test_sharded_dataset.py +++ b/tests/test_sharded_dataset.py @@ -35,7 +35,7 @@ def test_sniff_format_for_parquet(): f.flush() assert _sniff_format_for_dataset(f.name) == ".parquet" - + @skip_if_no_soundlibs def test_resolve_audio_pointer(): @@ -56,15 +56,17 @@ def test_basic_parquet_datasource_read_row(): table = pa.Table.from_pydict(data) pq.write_table(table, f.name) - # Instantiate the ParquetDataSource datasource = ParquetDataSource([f.name]) + assert len(datasource.shard_names) == 1, "Expected only one shard" + shard_name = datasource.shard_names[0] + # sanity check: Read data starting from row 1 - row_data = list(datasource.open_shard_at_row(shard_name=f.name.replace(".", "_"), row=1)) + row_data = list(datasource.open_shard_at_row(shard_name=shard_name, row=1)) # Verify the output assert len(row_data) == 2 # We expect 2 rows starting from index 1 assert row_data[0]["column1"] == "value2" assert row_data[0]["column2"] == 20 assert row_data[1]["column1"] == "value3" - assert row_data[1]["column2"] == 30 \ No newline at end of file + assert row_data[1]["column2"] == 30 From 8d09cfd1c216a7bc6f302ae529ae8d6c4412b005 Mon Sep 17 00:00:00 2001 From: Nikil Ravi Date: Sun, 13 Oct 2024 15:38:47 -0700 Subject: [PATCH 3/5] pre-commit --- src/levanter/data/sharded_datasource.py | 5 ++--- tests/test_sharded_dataset.py | 23 +++++++++++++++-------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/src/levanter/data/sharded_datasource.py b/src/levanter/data/sharded_datasource.py index 494bd5f05..10eb42b1b 100644 --- a/src/levanter/data/sharded_datasource.py +++ b/src/levanter/data/sharded_datasource.py @@ -21,7 +21,6 @@ import fsspec import numpy as np import pyarrow.parquet as pq -import pandas as pd from levanter.utils import fsspec_utils @@ -248,7 +247,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[str]: table = pq.read_table(f) sliced_table = table.slice(row) for record in sliced_table.to_pylist(): - yield record[self.text_key] # assumes text_key is in record + yield record[self.text_key] # assumes text_key is in record case _: raise ValueError(f"Unknown format {format}") @@ -441,7 +440,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]: url = self._shard_name_to_url_mapping[shard_name] with fsspec.open(url, "r", compression="infer") as f: table = pq.read_table(f) - sliced_table = table.slice(row) # zero-copy slicing + sliced_table = table.slice(row) # zero-copy slicing for record in sliced_table.to_pylist(): yield record diff --git a/tests/test_sharded_dataset.py b/tests/test_sharded_dataset.py index 2fb9cf4d8..b732596e5 100644 --- a/tests/test_sharded_dataset.py +++ b/tests/test_sharded_dataset.py @@ -1,6 +1,7 @@ +import os import tempfile -from levanter.data.sharded_datasource import AudioTextUrlDataSource, _sniff_format_for_dataset, ParquetDataSource +from levanter.data.sharded_datasource import AudioTextUrlDataSource, ParquetDataSource, _sniff_format_for_dataset from test_utils import skip_if_no_soundlibs @@ -30,7 +31,7 @@ def test_sniff_format_for_parquet(): import pyarrow.parquet as pq with tempfile.NamedTemporaryFile(suffix=".parquet") as f: - table = pa.table({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']}) + table = pa.table({"col1": [1, 2, 3], "col2": ["a", "b", "c"]}) pq.write_table(table, f.name) f.flush() @@ -47,20 +48,23 @@ def test_basic_parquet_datasource_read_row(): import pyarrow as pa import pyarrow.parquet as pq - with tempfile.NamedTemporaryFile(suffix=".parquet") as f: + with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as f: # Create a simple dataset - data = { - "column1": ["value1", "value2", "value3"], - "column2": [10, 20, 30] - } + data = {"column1": ["value1", "value2", "value3"], "column2": [10, 20, 30]} table = pa.Table.from_pydict(data) pq.write_table(table, f.name) - datasource = ParquetDataSource([f.name]) + try: + + datasource = ParquetDataSource([os.path.abspath(f.name)]) assert len(datasource.shard_names) == 1, "Expected only one shard" shard_name = datasource.shard_names[0] + print(f"Shard name: {shard_name}") + print("File name: ", f.name) + print("File path: ", os.path.abspath(f.name)) + # sanity check: Read data starting from row 1 row_data = list(datasource.open_shard_at_row(shard_name=shard_name, row=1)) @@ -70,3 +74,6 @@ def test_basic_parquet_datasource_read_row(): assert row_data[0]["column2"] == 20 assert row_data[1]["column1"] == "value3" assert row_data[1]["column2"] == 30 + + finally: + os.unlink(f.name) From 50715e9bc64d05dfb655087862c826d59a377ad7 Mon Sep 17 00:00:00 2001 From: Nikil Ravi Date: Sun, 13 Oct 2024 15:49:49 -0700 Subject: [PATCH 4/5] read as binary file --- src/levanter/data/sharded_datasource.py | 2 +- tests/test_sharded_dataset.py | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/src/levanter/data/sharded_datasource.py b/src/levanter/data/sharded_datasource.py index 10eb42b1b..208116ca6 100644 --- a/src/levanter/data/sharded_datasource.py +++ b/src/levanter/data/sharded_datasource.py @@ -438,7 +438,7 @@ def shard_names(self) -> Sequence[str]: def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]: url = self._shard_name_to_url_mapping[shard_name] - with fsspec.open(url, "r", compression="infer") as f: + with fsspec.open(url, "rb", compression="infer") as f: table = pq.read_table(f) sliced_table = table.slice(row) # zero-copy slicing for record in sliced_table.to_pylist(): diff --git a/tests/test_sharded_dataset.py b/tests/test_sharded_dataset.py index b732596e5..265a70867 100644 --- a/tests/test_sharded_dataset.py +++ b/tests/test_sharded_dataset.py @@ -61,10 +61,6 @@ def test_basic_parquet_datasource_read_row(): assert len(datasource.shard_names) == 1, "Expected only one shard" shard_name = datasource.shard_names[0] - print(f"Shard name: {shard_name}") - print("File name: ", f.name) - print("File path: ", os.path.abspath(f.name)) - # sanity check: Read data starting from row 1 row_data = list(datasource.open_shard_at_row(shard_name=shard_name, row=1)) From 3fe89957b55f799c1eb42200d8152dbdecd50c21 Mon Sep 17 00:00:00 2001 From: Nikil Ravi Date: Sun, 13 Oct 2024 18:59:37 -0700 Subject: [PATCH 5/5] simplify test --- tests/test_sharded_dataset.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/test_sharded_dataset.py b/tests/test_sharded_dataset.py index 265a70867..90ab6c34b 100644 --- a/tests/test_sharded_dataset.py +++ b/tests/test_sharded_dataset.py @@ -48,14 +48,12 @@ def test_basic_parquet_datasource_read_row(): import pyarrow as pa import pyarrow.parquet as pq - with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as f: + with tempfile.NamedTemporaryFile(suffix=".parquet", delete=True) as f: # Create a simple dataset data = {"column1": ["value1", "value2", "value3"], "column2": [10, 20, 30]} table = pa.Table.from_pydict(data) pq.write_table(table, f.name) - try: - datasource = ParquetDataSource([os.path.abspath(f.name)]) assert len(datasource.shard_names) == 1, "Expected only one shard" @@ -70,6 +68,3 @@ def test_basic_parquet_datasource_read_row(): assert row_data[0]["column2"] == 20 assert row_data[1]["column1"] == "value3" assert row_data[1]["column2"] == 30 - - finally: - os.unlink(f.name)