
changes in how parquet is read
nikil-ravi committed Oct 15, 2024
1 parent 3fe8995 commit 8871fcc
Showing 1 changed file with 52 additions and 8 deletions: src/levanter/data/sharded_datasource.py
@@ -244,10 +244,31 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[str]:
                 for doc in data[row:]:
                     yield doc[self.text_key]
             case ".parquet":
-                table = pq.read_table(f)
-                sliced_table = table.slice(row)
-                for record in sliced_table.to_pylist():
-                    yield record[self.text_key]  # assumes text_key is in record
+                parquet_file = pq.ParquetFile(f)
+                total_rows = parquet_file.metadata.num_rows
+                if row >= total_rows:
+                    return  # a bare return ends a generator; there is nothing left to yield
+
+                # Compute cumulative row counts (FileMetaData exposes row groups via row_group(i))
+                row_counts = [parquet_file.metadata.row_group(i).num_rows for i in range(parquet_file.num_row_groups)]
+                cumulative_rows = [0]
+                for count in row_counts:
+                    cumulative_rows.append(cumulative_rows[-1] + count)
+
+                # Find the starting row group and the row within it
+                for idx, cum_row in enumerate(cumulative_rows):
+                    if cum_row > row:
+                        row_group_index = idx - 1
+                        start_row_in_group = row - cumulative_rows[row_group_index]
+                        break
+
+                # Read from the starting row group onwards, fetching only the text column
+                for rg_idx in range(row_group_index, parquet_file.num_row_groups):
+                    table = parquet_file.read_row_group(rg_idx, columns=[self.text_key])
+                    if rg_idx == row_group_index:
+                        table = table.slice(start_row_in_group)
+                    for record in table.to_pylist():
+                        yield record[self.text_key]  # assumes text_key is present in the file schema
             case _:
                 raise ValueError(f"Unknown format {format}")
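
The change swaps pq.read_table(f), which decodes the entire file before slice(row) can drop the skipped prefix, for pq.ParquetFile, which reads only the footer metadata up front and then decodes just the row groups at or after the requested row (in this hunk, only the text column as well). Below is a minimal standalone sketch of the same seek; the helper name iter_text_from_row and the default text_key are illustrative, not part of the commit.

import pyarrow.parquet as pq
from typing import Iterator


def iter_text_from_row(path, row: int, text_key: str = "text") -> Iterator[str]:
    # ParquetFile reads only footer metadata here; no row data is decoded yet.
    parquet_file = pq.ParquetFile(path)
    if row >= parquet_file.metadata.num_rows:
        return  # requested row is past the end of the file

    # Walk row-group sizes until the group containing `row` is found.
    cumulative = 0
    start_group, start_in_group = 0, 0
    for i in range(parquet_file.num_row_groups):
        n = parquet_file.metadata.row_group(i).num_rows
        if cumulative + n > row:
            start_group, start_in_group = i, row - cumulative
            break
        cumulative += n

    # Decode from that group onwards, reading only the text column.
    for i in range(start_group, parquet_file.num_row_groups):
        table = parquet_file.read_row_group(i, columns=[text_key])
        if i == start_group:
            table = table.slice(start_in_group)  # zero-copy slice within the first group
        for record in table.to_pylist():
            yield record[text_key]

Resuming near the end of a large shard thus costs footer metadata plus the remaining row groups, rather than a full-file read.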

@@ -439,10 +460,33 @@ def shard_names(self) -> Sequence[str]:
     def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]:
         url = self._shard_name_to_url_mapping[shard_name]
         with fsspec.open(url, "rb", compression="infer") as f:
-            table = pq.read_table(f)
-            sliced_table = table.slice(row)  # zero-copy slicing
-            for record in sliced_table.to_pylist():
-                yield record
+            parquet_file = pq.ParquetFile(f)
+            total_rows = parquet_file.metadata.num_rows
+            if row >= total_rows:
+                return  # a bare return ends a generator; there is nothing left to yield
+
+            # compute cumulative row counts
+            row_counts = [parquet_file.metadata.row_group(i).num_rows for i in range(parquet_file.num_row_groups)]
+            cumulative_rows = [0]
+            for count in row_counts:
+                cumulative_rows.append(cumulative_rows[-1] + count)
+
+            # find the starting row group and the row within it
+            for idx, cum_row in enumerate(cumulative_rows):
+                if cum_row > row:
+                    row_group_index = idx - 1
+                    start_row_in_group = row - cumulative_rows[row_group_index]
+                    break
+
+            # read from the starting row group onwards; all columns this time, since rows are yielded as dicts
+            for rg_idx in range(row_group_index, parquet_file.num_row_groups):
+                table = parquet_file.read_row_group(rg_idx)
+
+                # if this is the starting row group, slice from the target row
+                if rg_idx == row_group_index:
+                    table = table.slice(start_row_in_group)
+
+                yield from table.to_pylist()


 def _mk_shard_name_mapping(urls):
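A quick way to exercise the sketch above; the file name demo.parquet is hypothetical, and write_table's row_group_size argument forces multiple small row groups so the seek actually skips data.

import pyarrow as pa
import pyarrow.parquet as pq

# Six rows split into three row groups of two rows each.
table = pa.table({"text": [f"doc {i}" for i in range(6)]})
pq.write_table(table, "demo.parquet", row_group_size=2)

# Resuming at row 3 decodes only the second and third row groups.
print(list(iter_text_from_row("demo.parquet", 3)))  # ['doc 3', 'doc 4', 'doc 5']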
