Changes in how parquet is read (#766)
Closes #763 and addresses David's comments in #764
nikil-ravi authored Oct 15, 2024
1 parent 02f34ac commit 877ca7e
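
The substance of the change: previously the parquet branch reused the shared text-mode fsspec handle and called pq.read_table(f), materializing the entire table before slicing to the requested row. The new code opens parquet shards in binary mode and seeks using footer metadata: it accumulates per-row-group row counts, locates the row group containing the target row, and reads only from that group onward (restricted to the text column in the text path). A minimal standalone sketch of the technique, assuming a local file at the hypothetical path data.parquet with a column named "text":

```python
import pyarrow.parquet as pq

# Sketch of the row-group seek this commit introduces. "data.parquet" and the
# "text" column are placeholders; assumes row < total rows (the real code
# guards this and returns an empty iterator otherwise).
pf = pq.ParquetFile("data.parquet")
row = 1234  # global row index to resume from

# Walk the footer metadata to find the row group containing `row`.
cumulative = 0
for rg_idx in range(pf.metadata.num_row_groups):
    count = pf.metadata.row_group(rg_idx).num_rows
    if cumulative + count > row:
        break
    cumulative += count

# Read from that group onward; only the first group needs slicing.
for idx in range(rg_idx, pf.metadata.num_row_groups):
    table = pf.read_row_group(idx, columns=["text"])
    if idx == rg_idx:
        table = table.slice(row - cumulative)
    for record in table.to_pylist():
        print(record["text"])
```

Only the footer and the row groups at or after the target row are read, so seeking into a large shard no longer costs a full-file parse.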
Showing 2 changed files with 101 additions and 18 deletions.
86 changes: 69 additions & 17 deletions src/levanter/data/sharded_datasource.py
@@ -233,32 +233,59 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[str]:
         compression = "infer"
         if url.endswith(".zstd"):  # hacky way to detect zstd
             compression = "zstd"
-        with fsspec.open(url, "r", compression=compression) as f:
-            format = _sniff_format_for_dataset(url)
-            match format:
-                case ".jsonl":
+
+        format = _sniff_format_for_dataset(url)
+        match format:
+            case ".jsonl":
+                with fsspec.open(url, "r", compression=compression) as f:
                     # TODO: would be nice if we could seek faster than this. Right now, all we do is skip json parsing
                     # which is not nothing, but not ideal.
                     for line in f:
                         if i >= row:
                             yield json.loads(line)[self.text_key]
                         i += 1
-                case ".txt":
+            case ".txt":
+                with fsspec.open(url, "r", compression=compression) as f:
                     for line in f:
                         if i >= row:
                             yield line
                         i += 1
-                case ".json":
+            case ".json":
+                with fsspec.open(url, "r", compression=compression) as f:
                     data = json.load(f)
                     for doc in data[row:]:
                         yield doc[self.text_key]
-                case ".parquet":
-                    table = pq.read_table(f)
-                    sliced_table = table.slice(row)
-                    for record in sliced_table.to_pylist():
-                        yield record[self.text_key]  # assumes text_key is in record
-                case _:
-                    raise ValueError(f"Unknown format {format}")
+            case ".parquet":
+                with fsspec.open(url, "rb", compression=compression) as f:
+                    parquet_file = pq.ParquetFile(f)
+                    total_rows = parquet_file.metadata.num_rows
+                    if row >= total_rows:
+                        return iter([])
+
+                    num_row_groups = parquet_file.metadata.num_row_groups
+
+                    # Compute cumulative row counts
+                    row_counts = [parquet_file.metadata.row_group(i).num_rows for i in range(num_row_groups)]
+                    cumulative_rows = [0]
+                    for count in row_counts:
+                        cumulative_rows.append(cumulative_rows[-1] + count)
+
+                    # Find the starting row group and the row within it
+                    for idx, cum_row in enumerate(cumulative_rows):
+                        if cum_row > row:
+                            row_group_index = idx - 1
+                            start_row_in_group = row - cumulative_rows[row_group_index]
+                            break
+
+                    # Read from the starting row group onwards
+                    for rg_idx in range(row_group_index, parquet_file.num_row_groups):
+                        table = parquet_file.read_row_group(rg_idx, columns=[self.text_key])
+                        if rg_idx == row_group_index:
+                            table = table.slice(start_row_in_group)
+                        for record in table.to_pylist():
+                            yield record[self.text_key]
+            case _:
+                raise ValueError(f"Unknown format {format}")


 class AudioTextUrlDataSource(ShardedDataSource[Tuple[np.ndarray, int, str]]):
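
The cumulative-count scan is easiest to see with numbers: for row groups of sizes [100, 100, 50], cumulative_rows is [0, 100, 200, 250]; for row = 120 the first entry greater than 120 is 200 at index 2, so row_group_index = 1 and start_row_in_group = 20. The same lookup can be done with the standard library's bisect; a sketch that is equivalent to, not taken from, the loop in the diff above:

```python
import bisect

row_counts = [100, 100, 50]  # per-row-group sizes, for illustration
cumulative_rows = [0]
for count in row_counts:
    cumulative_rows.append(cumulative_rows[-1] + count)

row = 120
# bisect_right returns the first index whose cumulative count exceeds `row`,
# so subtracting one gives the row group that contains it.
row_group_index = bisect.bisect_right(cumulative_rows, row) - 1
start_row_in_group = row - cumulative_rows[row_group_index]
assert (row_group_index, start_row_in_group) == (1, 20)
```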
@@ -448,10 +475,35 @@ def shard_names(self) -> Sequence[str]:
     def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]:
         url = self._shard_name_to_url_mapping[shard_name]
         with fsspec.open(url, "rb", compression="infer") as f:
-            table = pq.read_table(f)
-            sliced_table = table.slice(row)  # zero-copy slicing
-            for record in sliced_table.to_pylist():
-                yield record
+            parquet_file = pq.ParquetFile(f)
+            total_rows = parquet_file.metadata.num_rows
+            if row >= total_rows:
+                return iter([])
+
+            num_row_groups = parquet_file.metadata.num_row_groups
+
+            # Compute cumulative row counts
+            row_counts = [parquet_file.metadata.row_group(i).num_rows for i in range(num_row_groups)]
+            cumulative_rows = [0]
+            for count in row_counts:
+                cumulative_rows.append(cumulative_rows[-1] + count)
+
+            # Find the starting row group and the row within it
+            for idx, cum_row in enumerate(cumulative_rows):
+                if cum_row > row:
+                    row_group_index = idx - 1
+                    start_row_in_group = row - cumulative_rows[row_group_index]
+                    break
+
+            # Read from the starting row group onwards
+            for rg_idx in range(row_group_index, parquet_file.num_row_groups):
+                table = parquet_file.read_row_group(rg_idx)
+
+                # If we're in the starting row group, slice it from the target row
+                if rg_idx == row_group_index:
+                    table = table.slice(start_row_in_group)
+
+                yield from table.to_pylist()


 def _mk_shard_name_mapping(urls):
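
Both open_shard_at_row implementations now repeat the same seek arithmetic; a follow-up could factor it into a shared helper. A sketch under that assumption (the name _seek_to_row_group is hypothetical, not in the commit):

```python
import pyarrow.parquet as pq


def _seek_to_row_group(parquet_file: pq.ParquetFile, row: int) -> tuple[int, int]:
    """Return (row_group_index, start_row_in_group) for a global row index.

    Hypothetical helper mirroring the cumulative-count scan duplicated in the
    two open_shard_at_row implementations above.
    """
    cumulative = 0
    for rg_idx in range(parquet_file.metadata.num_row_groups):
        count = parquet_file.metadata.row_group(rg_idx).num_rows
        if cumulative + count > row:
            return rg_idx, row - cumulative
        cumulative += count
    raise IndexError(f"row {row} out of range ({cumulative} rows total)")
```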
33 changes: 32 additions & 1 deletion tests/test_sharded_dataset.py
@@ -1,7 +1,12 @@
 import os
 import tempfile

-from levanter.data.sharded_datasource import AudioTextUrlDataSource, ParquetDataSource, _sniff_format_for_dataset
+from levanter.data.sharded_datasource import (
+    AudioTextUrlDataSource,
+    ParquetDataSource,
+    TextUrlDataSource,
+    _sniff_format_for_dataset,
+)
 from test_utils import skip_if_no_soundlibs


@@ -68,3 +73,29 @@ def test_basic_parquet_datasource_read_row():
     assert row_data[0]["column2"] == 20
     assert row_data[1]["column1"] == "value3"
     assert row_data[1]["column2"] == 30
+
+
+def test_text_url_data_source_parquet():
+    import pyarrow as pa
+    import pyarrow.parquet as pq
+
+    with tempfile.NamedTemporaryFile(suffix=".parquet", delete=True) as f:
+        data = {
+            "text": ["line1", "line2", "line3", "line4", "line5", "line6"],
+            "column2": [10, 20, 30, 40, 50, 60],
+        }
+        table = pa.Table.from_pydict(data)
+        pq.write_table(table, f.name)
+
+        datasource = TextUrlDataSource([os.path.abspath(f.name)], text_key="text")
+
+        assert len(datasource.shard_names) == 1, "Expected only one shard"
+        shard_name = datasource.shard_names[0]
+
+        # Read data starting from row 2
+        row_data = list(datasource.open_shard_at_row(shard_name=shard_name, row=2))
+
+        # Verify the output
+        expected_texts = ["line3", "line4", "line5", "line6"]
+        assert len(row_data) == len(expected_texts), f"Expected {len(expected_texts)} rows starting from index 2"
+        assert row_data == expected_texts, f"Expected texts {expected_texts}, got {row_data}"
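
The new test covers mid-file seeking; the row >= total_rows branch, which returns an empty iterator, is not exercised here. A sketch of a companion test under that assumption (hypothetical, not part of the commit):

```python
def test_text_url_data_source_parquet_row_out_of_range():
    # Hypothetical companion test: seeking past the last row should yield
    # nothing rather than raising.
    import pyarrow as pa
    import pyarrow.parquet as pq

    with tempfile.NamedTemporaryFile(suffix=".parquet", delete=True) as f:
        pq.write_table(pa.Table.from_pydict({"text": ["a", "b"]}), f.name)

        datasource = TextUrlDataSource([os.path.abspath(f.name)], text_key="text")
        shard_name = datasource.shard_names[0]

        assert list(datasource.open_shard_at_row(shard_name=shard_name, row=2)) == []
```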
