From 698e98f5345e0b7d31ec17e2acf8de9006d57bb6 Mon Sep 17 00:00:00 2001 From: Nikil Ravi Date: Tue, 15 Oct 2024 01:42:18 -0700 Subject: [PATCH] fixes, and another test --- src/levanter/data/sharded_datasource.py | 10 +++++--- tests/test_sharded_dataset.py | 33 ++++++++++++++++++++++++- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/src/levanter/data/sharded_datasource.py b/src/levanter/data/sharded_datasource.py index e43b12fde..2c7eb42d9 100644 --- a/src/levanter/data/sharded_datasource.py +++ b/src/levanter/data/sharded_datasource.py @@ -249,8 +249,10 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[str]: if row >= total_rows: return iter([]) + num_row_groups = parquet_file.metadata.num_row_groups + # Compute cumulative row counts - row_counts = [rg.num_rows for rg in parquet_file.metadata.row_groups] + row_counts = [parquet_file.metadata.row_group(i).num_rows for i in range(num_row_groups)] cumulative_rows = [0] for count in row_counts: cumulative_rows.append(cumulative_rows[-1] + count) @@ -465,8 +467,10 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]: if row >= total_rows: return iter([]) - # compute cumulative row counts - row_counts = [rg_meta.num_rows for rg_meta in parquet_file.metadata.row_groups] + num_row_groups = parquet_file.metadata.num_row_groups + + # Compute cumulative row counts + row_counts = [parquet_file.metadata.row_group(i).num_rows for i in range(num_row_groups)] cumulative_rows = [0] for count in row_counts: cumulative_rows.append(cumulative_rows[-1] + count) diff --git a/tests/test_sharded_dataset.py b/tests/test_sharded_dataset.py index 90ab6c34b..d93793bff 100644 --- a/tests/test_sharded_dataset.py +++ b/tests/test_sharded_dataset.py @@ -1,7 +1,12 @@ import os import tempfile -from levanter.data.sharded_datasource import AudioTextUrlDataSource, ParquetDataSource, _sniff_format_for_dataset +from levanter.data.sharded_datasource import ( + AudioTextUrlDataSource, + ParquetDataSource, + _sniff_format_for_dataset, + TextUrlDataSource, +) from test_utils import skip_if_no_soundlibs @@ -68,3 +73,29 @@ def test_basic_parquet_datasource_read_row(): assert row_data[0]["column2"] == 20 assert row_data[1]["column1"] == "value3" assert row_data[1]["column2"] == 30 + + +def test_text_url_data_source_parquet(): + import pyarrow as pa + import pyarrow.parquet as pq + + with tempfile.NamedTemporaryFile(suffix=".parquet", delete=True) as f: + data = { + "text": ["line1", "line2", "line3", "line4", "line5", "line6"], + "column2": [10, 20, 30, 40, 50, 60], + } + table = pa.Table.from_pydict(data) + pq.write_table(table, f.name) + + datasource = TextUrlDataSource([os.path.abspath(f.name)], text_key="text") + + assert len(datasource.shard_names) == 1, "Expected only one shard" + shard_name = datasource.shard_names[0] + + # Read data starting from row 2 + row_data = list(datasource.open_shard_at_row(shard_name=shard_name, row=2)) + + # Verify the output + expected_texts = ["line3", "line4", "line5", "line6"] + assert len(row_data) == len(expected_texts), f"Expected {len(expected_texts)} rows starting from index 2" + assert row_data == expected_texts, f"Expected texts {expected_texts}, got {row_data}"