diff --git a/src/levanter/data/sharded_datasource.py b/src/levanter/data/sharded_datasource.py
index 208116ca6..e43b12fde 100644
--- a/src/levanter/data/sharded_datasource.py
+++ b/src/levanter/data/sharded_datasource.py
@@ -244,10 +244,31 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[str]:
                     for doc in data[row:]:
                         yield doc[self.text_key]
                 case ".parquet":
-                    table = pq.read_table(f)
-                    sliced_table = table.slice(row)
-                    for record in sliced_table.to_pylist():
-                        yield record[self.text_key]  # assumes text_key is in record
+                    parquet_file = pq.ParquetFile(f)
+                    total_rows = parquet_file.metadata.num_rows
+                    if row >= total_rows:
+                        return
+
+                    # Compute cumulative row counts so we can find the row group containing `row`
+                    row_counts = [parquet_file.metadata.row_group(i).num_rows for i in range(parquet_file.num_row_groups)]
+                    cumulative_rows = [0]
+                    for count in row_counts:
+                        cumulative_rows.append(cumulative_rows[-1] + count)
+
+                    # Find the starting row group and the row offset within it
+                    for idx, cum_row in enumerate(cumulative_rows):
+                        if cum_row > row:
+                            row_group_index = idx - 1
+                            start_row_in_group = row - cumulative_rows[row_group_index]
+                            break
+
+                    # Read from the starting row group onwards, fetching only the text column
+                    for rg_idx in range(row_group_index, parquet_file.num_row_groups):
+                        table = parquet_file.read_row_group(rg_idx, columns=[self.text_key])
+                        if rg_idx == row_group_index:
+                            table = table.slice(start_row_in_group)
+                        for record in table.to_pylist():
+                            yield record[self.text_key]
                 case _:
                     raise ValueError(f"Unknown format {format}")
 
@@ -439,10 +460,33 @@ def shard_names(self) -> Sequence[str]:
     def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]:
         url = self._shard_name_to_url_mapping[shard_name]
         with fsspec.open(url, "rb", compression="infer") as f:
-            table = pq.read_table(f)
-            sliced_table = table.slice(row)  # zero-copy slicing
-            for record in sliced_table.to_pylist():
-                yield record
+            parquet_file = pq.ParquetFile(f)
+            total_rows = parquet_file.metadata.num_rows
+            if row >= total_rows:
+                return
+
+            # Compute cumulative row counts so we can find the row group containing `row`
+            row_counts = [parquet_file.metadata.row_group(i).num_rows for i in range(parquet_file.num_row_groups)]
+            cumulative_rows = [0]
+            for count in row_counts:
+                cumulative_rows.append(cumulative_rows[-1] + count)
+
+            # Find the starting row group and the row offset within it
+            for idx, cum_row in enumerate(cumulative_rows):
+                if cum_row > row:
+                    row_group_index = idx - 1
+                    start_row_in_group = row - cumulative_rows[row_group_index]
+                    break
+
+            # Read from the starting row group onwards
+            for rg_idx in range(row_group_index, parquet_file.num_row_groups):
+                table = parquet_file.read_row_group(rg_idx)
+
+                # Slice off the rows that precede `row` in the first row group
+                if rg_idx == row_group_index:
+                    table = table.slice(start_row_in_group)
+
+                yield from table.to_pylist()
 
 
 def _mk_shard_name_mapping(urls):
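
Note: both hunks implement the same cumulative-count seek, so the logic could be factored into a shared helper. A minimal sketch of that refactor, assuming a module-level function (the name `seek_row_group` and the use of `bisect` are illustrative, not part of this patch):

    import bisect

    import pyarrow.parquet as pq


    def seek_row_group(parquet_file: pq.ParquetFile, row: int) -> tuple[int, int]:
        """Return (row_group_index, row_offset_within_group) for absolute `row`.

        Assumes 0 <= row < parquet_file.metadata.num_rows.
        """
        metadata = parquet_file.metadata
        # cumulative[i] = number of rows in all row groups before group i
        cumulative = [0]
        for i in range(metadata.num_row_groups):
            cumulative.append(cumulative[-1] + metadata.row_group(i).num_rows)
        # The first cumulative count greater than `row` marks the group after
        # the one containing it; bisect_right replaces the linear scan.
        group = bisect.bisect_right(cumulative, row) - 1
        return group, row - cumulative[group]

With such a helper, each open_shard_at_row would reduce to the read/slice/yield loop, and the binary search drops the seek from O(num_row_groups) to O(log num_row_groups), though for files with few row groups either is fine.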