From 8871fccaf40f7bec23543d3da725f9180c27b8ef Mon Sep 17 00:00:00 2001
From: Nikil Ravi <nravi@stanford.edu>
Date: Tue, 15 Oct 2024 01:27:50 -0700
Subject: [PATCH 1/3] changes in how parquet is read

---
 src/levanter/data/sharded_datasource.py | 60 +++++++++++++++++++++----
 1 file changed, 52 insertions(+), 8 deletions(-)

diff --git a/src/levanter/data/sharded_datasource.py b/src/levanter/data/sharded_datasource.py
index 208116ca6..e43b12fde 100644
--- a/src/levanter/data/sharded_datasource.py
+++ b/src/levanter/data/sharded_datasource.py
@@ -244,10 +244,31 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[str]:
                     for doc in data[row:]:
                         yield doc[self.text_key]
                 case ".parquet":
-                    table = pq.read_table(f)
-                    sliced_table = table.slice(row)
-                    for record in sliced_table.to_pylist():
-                        yield record[self.text_key]  # assumes text_key is in record
+                    parquet_file = pq.ParquetFile(f)
+                    total_rows = parquet_file.metadata.num_rows
+                    if row >= total_rows:
+                        return iter([])
+
+                    # Compute cumulative row counts
+                    row_counts = [rg.num_rows for rg in parquet_file.metadata.row_groups]
+                    cumulative_rows = [0]
+                    for count in row_counts:
+                        cumulative_rows.append(cumulative_rows[-1] + count)
+
+                    # Find the starting row group and row within it
+                    for idx, cum_row in enumerate(cumulative_rows):
+                        if cum_row > row:
+                            row_group_index = idx - 1
+                            start_row_in_group = row - cumulative_rows[row_group_index]
+                            break
+
+                    # Read from the starting row group onwards
+                    for rg_idx in range(row_group_index, parquet_file.num_row_groups):
+                        table = parquet_file.read_row_group(rg_idx, columns=[self.text_key])
+                        if rg_idx == row_group_index:
+                            table = table.slice(start_row_in_group)
+                        for record in table.to_pylist():
+                            yield record[self.text_key]
                 case _:
                     raise ValueError(f"Unknown format {format}")
 
@@ -439,10 +460,33 @@ def shard_names(self) -> Sequence[str]:
     def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]:
         url = self._shard_name_to_url_mapping[shard_name]
         with fsspec.open(url, "rb", compression="infer") as f:
-            table = pq.read_table(f)
-            sliced_table = table.slice(row)  # zero-copy slicing
-            for record in sliced_table.to_pylist():
-                yield record
+            parquet_file = pq.ParquetFile(f)
+            total_rows = parquet_file.metadata.num_rows
+            if row >= total_rows:
+                return iter([])
+
+            # compute cumulative row counts
+            row_counts = [rg_meta.num_rows for rg_meta in parquet_file.metadata.row_groups]
+            cumulative_rows = [0]
+            for count in row_counts:
+                cumulative_rows.append(cumulative_rows[-1] + count)
+
+            # find starting row group and also find the row within it
+            for idx, cum_row in enumerate(cumulative_rows):
+                if cum_row > row:
+                    row_group_index = idx - 1
+                    start_row_in_group = row - cumulative_rows[row_group_index]
+                    break
+
+            # read from the starting row group onwards
+            for rg_idx in range(row_group_index, parquet_file.num_row_groups):
+                table = parquet_file.read_row_group(rg_idx)
+
+                # if we're in the row group we want, slice the table at/from the row we want
+                if rg_idx == row_group_index:
+                    table = table.slice(start_row_in_group)
+
+                yield from table.to_pylist()
 
 
 def _mk_shard_name_mapping(urls):
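
The parquet branch above no longer calls pq.read_table(f), which decodes the whole file before slicing; instead it seeks to the row group containing the requested row and reads only from that point on. A minimal standalone sketch of that seek step, assuming only pyarrow is installed; the helper name seek_row_group is illustrative and not part of the patch, and it uses FileMetaData.row_group(i), the accessor the next patch switches to:

import pyarrow.parquet as pq


def seek_row_group(parquet_file: pq.ParquetFile, row: int) -> tuple[int, int]:
    """Return (row_group_index, offset_within_that_group) for a global row index."""
    cumulative = 0
    for rg_idx in range(parquet_file.metadata.num_row_groups):
        group_rows = parquet_file.metadata.row_group(rg_idx).num_rows
        if cumulative + group_rows > row:
            return rg_idx, row - cumulative
        cumulative += group_rows
    raise IndexError(f"row {row} is out of range for a file with {cumulative} rows")
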
From 698e98f5345e0b7d31ec17e2acf8de9006d57bb6 Mon Sep 17 00:00:00 2001
From: Nikil Ravi <nravi@stanford.edu>
Date: Tue, 15 Oct 2024 01:42:18 -0700
Subject: [PATCH 2/3] fixes, and another test

---
 src/levanter/data/sharded_datasource.py | 10 +++++---
 tests/test_sharded_dataset.py           | 33 ++++++++++++++++++++++++-
 2 files changed, 39 insertions(+), 4 deletions(-)

diff --git a/src/levanter/data/sharded_datasource.py b/src/levanter/data/sharded_datasource.py
index e43b12fde..2c7eb42d9 100644
--- a/src/levanter/data/sharded_datasource.py
+++ b/src/levanter/data/sharded_datasource.py
@@ -249,8 +249,10 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[str]:
                     if row >= total_rows:
                         return iter([])
 
+                    num_row_groups = parquet_file.metadata.num_row_groups
+
                     # Compute cumulative row counts
-                    row_counts = [rg.num_rows for rg in parquet_file.metadata.row_groups]
+                    row_counts = [parquet_file.metadata.row_group(i).num_rows for i in range(num_row_groups)]
                     cumulative_rows = [0]
                     for count in row_counts:
                         cumulative_rows.append(cumulative_rows[-1] + count)
@@ -465,8 +467,10 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]:
             if row >= total_rows:
                 return iter([])
 
-            # compute cumulative row counts
-            row_counts = [rg_meta.num_rows for rg_meta in parquet_file.metadata.row_groups]
+            num_row_groups = parquet_file.metadata.num_row_groups
+
+            # Compute cumulative row counts
+            row_counts = [parquet_file.metadata.row_group(i).num_rows for i in range(num_row_groups)]
             cumulative_rows = [0]
             for count in row_counts:
                 cumulative_rows.append(cumulative_rows[-1] + count)
diff --git a/tests/test_sharded_dataset.py b/tests/test_sharded_dataset.py
index 90ab6c34b..d93793bff 100644
--- a/tests/test_sharded_dataset.py
+++ b/tests/test_sharded_dataset.py
@@ -1,7 +1,12 @@
 import os
 import tempfile
 
-from levanter.data.sharded_datasource import AudioTextUrlDataSource, ParquetDataSource, _sniff_format_for_dataset
+from levanter.data.sharded_datasource import (
+    AudioTextUrlDataSource,
+    ParquetDataSource,
+    _sniff_format_for_dataset,
+    TextUrlDataSource,
+)
 from test_utils import skip_if_no_soundlibs
 
 
@@ -68,3 +73,29 @@ def test_basic_parquet_datasource_read_row():
         assert row_data[0]["column2"] == 20
         assert row_data[1]["column1"] == "value3"
         assert row_data[1]["column2"] == 30
+
+
+def test_text_url_data_source_parquet():
+    import pyarrow as pa
+    import pyarrow.parquet as pq
+
+    with tempfile.NamedTemporaryFile(suffix=".parquet", delete=True) as f:
+        data = {
+            "text": ["line1", "line2", "line3", "line4", "line5", "line6"],
+            "column2": [10, 20, 30, 40, 50, 60],
+        }
+        table = pa.Table.from_pydict(data)
+        pq.write_table(table, f.name)
+
+        datasource = TextUrlDataSource([os.path.abspath(f.name)], text_key="text")
+
+        assert len(datasource.shard_names) == 1, "Expected only one shard"
+        shard_name = datasource.shard_names[0]
+
+        # Read data starting from row 2
+        row_data = list(datasource.open_shard_at_row(shard_name=shard_name, row=2))
+
+        # Verify the output
+        expected_texts = ["line3", "line4", "line5", "line6"]
+        assert len(row_data) == len(expected_texts), f"Expected {len(expected_texts)} rows starting from index 2"
+        assert row_data == expected_texts, f"Expected texts {expected_texts}, got {row_data}"
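
The metadata fix above is needed because pyarrow's FileMetaData has no row_groups attribute; per-group metadata comes from metadata.row_group(i). A short sketch of the read path the new test exercises, written against a throwaway file (the file name example.parquet, row_group_size=3, and the starting row are arbitrary example values, not part of the patch):

import pyarrow as pa
import pyarrow.parquet as pq

# Write a small file with two row groups so the seek actually has a group to skip.
table = pa.table({"text": [f"line{i}" for i in range(1, 7)], "column2": [10, 20, 30, 40, 50, 60]})
pq.write_table(table, "example.parquet", row_group_size=3)

pf = pq.ParquetFile("example.parquet")
row = 4  # start reading from the fifth row
cumulative = [0]
for i in range(pf.metadata.num_row_groups):
    cumulative.append(cumulative[-1] + pf.metadata.row_group(i).num_rows)

start_group = max(i for i, total in enumerate(cumulative) if total <= row)
texts = []
for rg_idx in range(start_group, pf.num_row_groups):
    t = pf.read_row_group(rg_idx, columns=["text"])
    if rg_idx == start_group:
        t = t.slice(row - cumulative[start_group])
    texts.extend(record["text"] for record in t.to_pylist())

assert texts == ["line5", "line6"]
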
From 38a7e688b1f3fb61ea77a2f98060a8e99071e3b5 Mon Sep 17 00:00:00 2001
From: Nikil Ravi <nravi@stanford.edu>
Date: Tue, 15 Oct 2024 01:57:18 -0700
Subject: [PATCH 3/3] open in binary mode

---
 src/levanter/data/sharded_datasource.py | 22 +++++++++++++---------
 tests/test_sharded_dataset.py           |  2 +-
 2 files changed, 14 insertions(+), 10 deletions(-)

diff --git a/src/levanter/data/sharded_datasource.py b/src/levanter/data/sharded_datasource.py
index 2c7eb42d9..80dd1b4b2 100644
--- a/src/levanter/data/sharded_datasource.py
+++ b/src/levanter/data/sharded_datasource.py
@@ -224,26 +224,30 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[str]:
         compression = "infer"
         if url.endswith(".zstd"):  # hacky way to detect zstd
             compression = "zstd"
-        with fsspec.open(url, "r", compression=compression) as f:
-            format = _sniff_format_for_dataset(url)
-            match format:
-                case ".jsonl":
+
+        format = _sniff_format_for_dataset(url)
+        match format:
+            case ".jsonl":
+                with fsspec.open(url, "r", compression=compression) as f:
                     # TODO: would be nice if we could seek faster than this. Right now, all we do is skip json parsing
                     # which is not nothing, but not ideal.
                     for line in f:
                         if i >= row:
                             yield json.loads(line)[self.text_key]
                         i += 1
-                case ".txt":
+            case ".txt":
+                with fsspec.open(url, "r", compression=compression) as f:
                     for line in f:
                         if i >= row:
                             yield line
                         i += 1
-                case ".json":
+            case ".json":
+                with fsspec.open(url, "r", compression=compression) as f:
                     data = json.load(f)
                     for doc in data[row:]:
                         yield doc[self.text_key]
-                case ".parquet":
+            case ".parquet":
+                with fsspec.open(url, "rb", compression=compression) as f:
                     parquet_file = pq.ParquetFile(f)
                     total_rows = parquet_file.metadata.num_rows
                     if row >= total_rows:
@@ -271,8 +275,8 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[str]:
                         table = table.slice(start_row_in_group)
                     for record in table.to_pylist():
                         yield record[self.text_key]
-                case _:
-                    raise ValueError(f"Unknown format {format}")
+            case _:
+                raise ValueError(f"Unknown format {format}")
 
 
 class AudioTextUrlDataSource(ShardedDataSource[Tuple[np.ndarray, int, str]]):
diff --git a/tests/test_sharded_dataset.py b/tests/test_sharded_dataset.py
index d93793bff..a375d9344 100644
--- a/tests/test_sharded_dataset.py
+++ b/tests/test_sharded_dataset.py
@@ -4,8 +4,8 @@
 from levanter.data.sharded_datasource import (
     AudioTextUrlDataSource,
     ParquetDataSource,
-    _sniff_format_for_dataset,
     TextUrlDataSource,
+    _sniff_format_for_dataset,
 )
 from test_utils import skip_if_no_soundlibs
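
Parquet is a binary format, so pq.ParquetFile needs a byte stream; the text-mode handle used by the jsonl/txt/json branches would not work for it, hence the per-case open with "rb" only in the parquet branch. A minimal sketch of that open pattern in isolation (example.parquet is a placeholder for any fsspec-resolvable URL, not part of the patch):

import fsspec
import pyarrow.parquet as pq

# Open in binary mode so pyarrow receives bytes, not decoded text.
with fsspec.open("example.parquet", "rb", compression="infer") as f:
    pf = pq.ParquetFile(f)
    print(pf.metadata.num_rows, pf.metadata.num_row_groups)

Moving the fsspec.open call into each case keeps the text branches on mode "r" and lets only the parquet branch switch to "rb".
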