diff --git a/src/levanter/data/sharded_datasource.py b/src/levanter/data/sharded_datasource.py
index 5d9a9f224..186a0d9dd 100644
--- a/src/levanter/data/sharded_datasource.py
+++ b/src/levanter/data/sharded_datasource.py
@@ -233,32 +233,59 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[str]:
         compression = "infer"
         if url.endswith(".zstd"):  # hacky way to detect zstd
             compression = "zstd"
-        with fsspec.open(url, "r", compression=compression) as f:
-            format = _sniff_format_for_dataset(url)
-            match format:
-                case ".jsonl":
+
+        format = _sniff_format_for_dataset(url)
+        match format:
+            case ".jsonl":
+                with fsspec.open(url, "r", compression=compression) as f:
                     # TODO: would be nice if we could seek faster than this. Right now, all we do is skip json parsing
                     # which is not nothing, but not ideal.
                     for line in f:
                         if i >= row:
                             yield json.loads(line)[self.text_key]
                         i += 1
-                case ".txt":
+            case ".txt":
+                with fsspec.open(url, "r", compression=compression) as f:
                     for line in f:
                         if i >= row:
                             yield line
                         i += 1
-                case ".json":
+            case ".json":
+                with fsspec.open(url, "r", compression=compression) as f:
                     data = json.load(f)
                     for doc in data[row:]:
                         yield doc[self.text_key]
-                case ".parquet":
-                    table = pq.read_table(f)
-                    sliced_table = table.slice(row)
-                    for record in sliced_table.to_pylist():
-                        yield record[self.text_key]  # assumes text_key is in record
-                case _:
-                    raise ValueError(f"Unknown format {format}")
+            case ".parquet":
+                with fsspec.open(url, "rb", compression=compression) as f:
+                    parquet_file = pq.ParquetFile(f)
+                    total_rows = parquet_file.metadata.num_rows
+                    if row >= total_rows:
+                        return iter([])
+
+                    num_row_groups = parquet_file.metadata.num_row_groups
+
+                    # Compute cumulative row counts
+                    row_counts = [parquet_file.metadata.row_group(i).num_rows for i in range(num_row_groups)]
+                    cumulative_rows = [0]
+                    for count in row_counts:
+                        cumulative_rows.append(cumulative_rows[-1] + count)
+
+                    # Find the starting row group and row within it
+                    for idx, cum_row in enumerate(cumulative_rows):
+                        if cum_row > row:
+                            row_group_index = idx - 1
+                            start_row_in_group = row - cumulative_rows[row_group_index]
+                            break
+
+                    # Read from the starting row group onwards
+                    for rg_idx in range(row_group_index, parquet_file.num_row_groups):
+                        table = parquet_file.read_row_group(rg_idx, columns=[self.text_key])
+                        if rg_idx == row_group_index:
+                            table = table.slice(start_row_in_group)
+                        for record in table.to_pylist():
+                            yield record[self.text_key]
+            case _:
+                raise ValueError(f"Unknown format {format}")
 
 
 class AudioTextUrlDataSource(ShardedDataSource[Tuple[np.ndarray, int, str]]):
@@ -448,10 +475,35 @@ def shard_names(self) -> Sequence[str]:
     def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]:
         url = self._shard_name_to_url_mapping[shard_name]
         with fsspec.open(url, "rb", compression="infer") as f:
-            table = pq.read_table(f)
-            sliced_table = table.slice(row)  # zero-copy slicing
-            for record in sliced_table.to_pylist():
-                yield record
+            parquet_file = pq.ParquetFile(f)
+            total_rows = parquet_file.metadata.num_rows
+            if row >= total_rows:
+                return iter([])
+
+            num_row_groups = parquet_file.metadata.num_row_groups
+
+            # Compute cumulative row counts
+            row_counts = [parquet_file.metadata.row_group(i).num_rows for i in range(num_row_groups)]
+            cumulative_rows = [0]
+            for count in row_counts:
+                cumulative_rows.append(cumulative_rows[-1] + count)
+
+            # find starting row group and also find the row within it
+            for idx, cum_row in enumerate(cumulative_rows):
+                if cum_row > row:
+                    row_group_index = idx - 1
+                    start_row_in_group = row - cumulative_rows[row_group_index]
+                    break
+
+            # read from the starting row group onwards
+            for rg_idx in range(row_group_index, parquet_file.num_row_groups):
+                table = parquet_file.read_row_group(rg_idx)
+
+                # if we're in the row group we want, slice the table at/from the row we want
+                if rg_idx == row_group_index:
+                    table = table.slice(start_row_in_group)
+
+                yield from table.to_pylist()
 
 
 def _mk_shard_name_mapping(urls):
diff --git a/tests/test_sharded_dataset.py b/tests/test_sharded_dataset.py
index 90ab6c34b..a375d9344 100644
--- a/tests/test_sharded_dataset.py
+++ b/tests/test_sharded_dataset.py
@@ -1,7 +1,12 @@
 import os
 import tempfile
 
-from levanter.data.sharded_datasource import AudioTextUrlDataSource, ParquetDataSource, _sniff_format_for_dataset
+from levanter.data.sharded_datasource import (
+    AudioTextUrlDataSource,
+    ParquetDataSource,
+    TextUrlDataSource,
+    _sniff_format_for_dataset,
+)
 from test_utils import skip_if_no_soundlibs
 
 
@@ -68,3 +73,29 @@ def test_basic_parquet_datasource_read_row():
     assert row_data[0]["column2"] == 20
     assert row_data[1]["column1"] == "value3"
     assert row_data[1]["column2"] == 30
+
+
+def test_text_url_data_source_parquet():
+    import pyarrow as pa
+    import pyarrow.parquet as pq
+
+    with tempfile.NamedTemporaryFile(suffix=".parquet", delete=True) as f:
+        data = {
+            "text": ["line1", "line2", "line3", "line4", "line5", "line6"],
+            "column2": [10, 20, 30, 40, 50, 60],
+        }
+        table = pa.Table.from_pydict(data)
+        pq.write_table(table, f.name)
+
+        datasource = TextUrlDataSource([os.path.abspath(f.name)], text_key="text")
+
+        assert len(datasource.shard_names) == 1, "Expected only one shard"
+        shard_name = datasource.shard_names[0]
+
+        # Read data starting from row 2
+        row_data = list(datasource.open_shard_at_row(shard_name=shard_name, row=2))
+
+        # Verify the output
+        expected_texts = ["line3", "line4", "line5", "line6"]
+        assert len(row_data) == len(expected_texts), f"Expected {len(expected_texts)} rows starting from index 2"
+        assert row_data == expected_texts, f"Expected texts {expected_texts}, got {row_data}"