Changes in how parquet is read #766

Merged 3 commits on Oct 15, 2024
86 changes: 69 additions & 17 deletions src/levanter/data/sharded_datasource.py
```diff
@@ -224,32 +224,59 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[str]:
         compression = "infer"
         if url.endswith(".zstd"):  # hacky way to detect zstd
             compression = "zstd"
-        with fsspec.open(url, "r", compression=compression) as f:
-            format = _sniff_format_for_dataset(url)
-            match format:
-                case ".jsonl":
+
+        format = _sniff_format_for_dataset(url)
+        match format:
+            case ".jsonl":
+                with fsspec.open(url, "r", compression=compression) as f:
                     # TODO: would be nice if we could seek faster than this. Right now, all we do is skip json parsing
                     # which is not nothing, but not ideal.
                     for line in f:
                         if i >= row:
                             yield json.loads(line)[self.text_key]
                         i += 1
-                case ".txt":
+            case ".txt":
+                with fsspec.open(url, "r", compression=compression) as f:
                     for line in f:
                         if i >= row:
                             yield line
                         i += 1
-                case ".json":
+            case ".json":
+                with fsspec.open(url, "r", compression=compression) as f:
                     data = json.load(f)
                     for doc in data[row:]:
                         yield doc[self.text_key]
-                case ".parquet":
-                    table = pq.read_table(f)
-                    sliced_table = table.slice(row)
-                    for record in sliced_table.to_pylist():
-                        yield record[self.text_key]  # assumes text_key is in record
-                case _:
-                    raise ValueError(f"Unknown format {format}")
+            case ".parquet":
+                with fsspec.open(url, "rb", compression=compression) as f:
+                    parquet_file = pq.ParquetFile(f)
```
> **Member:** The logic in this block is complex enough, and duplicated enough, that I'd prefer you extracted a method.
```diff
+                    total_rows = parquet_file.metadata.num_rows
+                    if row >= total_rows:
+                        return iter([])
+
+                    num_row_groups = parquet_file.metadata.num_row_groups
+
+                    # Compute cumulative row counts
+                    row_counts = [parquet_file.metadata.row_group(i).num_rows for i in range(num_row_groups)]
+                    cumulative_rows = [0]
+                    for count in row_counts:
+                        cumulative_rows.append(cumulative_rows[-1] + count)
+
+                    # Find the starting row group and row within it
+                    for idx, cum_row in enumerate(cumulative_rows):
+                        if cum_row > row:
+                            row_group_index = idx - 1
+                            start_row_in_group = row - cumulative_rows[row_group_index]
+                            break
+
+                    # Read from the starting row group onwards
+                    for rg_idx in range(row_group_index, parquet_file.num_row_groups):
+                        table = parquet_file.read_row_group(rg_idx, columns=[self.text_key])
+                        if rg_idx == row_group_index:
+                            table = table.slice(start_row_in_group)
+                        for record in table.to_pylist():
+                            yield record[self.text_key]
+            case _:
+                raise ValueError(f"Unknown format {format}")


 class AudioTextUrlDataSource(ShardedDataSource[Tuple[np.ndarray, int, str]]):
```
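The row-group seek logic above is repeated almost verbatim in `ParquetDataSource.open_shard_at_row` below, which is what the review comment is getting at. A minimal sketch of the extracted helper, assuming only the pyarrow `ParquetFile` API already used in this diff (the name `_iterate_parquet_from_row` is hypothetical, not something the PR adds):

```python
from typing import Iterator, Optional, Sequence

import pyarrow.parquet as pq


def _iterate_parquet_from_row(
    parquet_file: pq.ParquetFile, row: int, columns: Optional[Sequence[str]] = None
) -> Iterator[dict]:
    """Yield rows of an open ParquetFile as dicts, starting at global row index `row`."""
    if row >= parquet_file.metadata.num_rows:
        return

    # cumulative_rows[i] == number of rows in all row groups before group i
    cumulative_rows = [0]
    for i in range(parquet_file.metadata.num_row_groups):
        cumulative_rows.append(cumulative_rows[-1] + parquet_file.metadata.row_group(i).num_rows)

    # Find the row group containing `row`, and the offset of `row` within that group
    for idx, cum_row in enumerate(cumulative_rows):
        if cum_row > row:
            row_group_index = idx - 1
            start_row_in_group = row - cumulative_rows[row_group_index]
            break

    # Read one row group at a time; only the first group needs slicing
    for rg_idx in range(row_group_index, parquet_file.num_row_groups):
        table = parquet_file.read_row_group(rg_idx, columns=columns)
        if rg_idx == row_group_index:
            table = table.slice(start_row_in_group)
        yield from table.to_pylist()
```

With a helper like this, the `.parquet` branch above would reduce to yielding `record[self.text_key]` from `_iterate_parquet_from_row(parquet_file, row, columns=[self.text_key])`, and the method below to `yield from _iterate_parquet_from_row(parquet_file, row)`.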
```diff
@@ -439,10 +466,35 @@ def shard_names(self) -> Sequence[str]:
     def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]:
         url = self._shard_name_to_url_mapping[shard_name]
         with fsspec.open(url, "rb", compression="infer") as f:
-            table = pq.read_table(f)
-            sliced_table = table.slice(row)  # zero-copy slicing
-            for record in sliced_table.to_pylist():
-                yield record
+            parquet_file = pq.ParquetFile(f)
+            total_rows = parquet_file.metadata.num_rows
+            if row >= total_rows:
+                return iter([])
+
+            num_row_groups = parquet_file.metadata.num_row_groups
+
+            # Compute cumulative row counts
+            row_counts = [parquet_file.metadata.row_group(i).num_rows for i in range(num_row_groups)]
+            cumulative_rows = [0]
+            for count in row_counts:
+                cumulative_rows.append(cumulative_rows[-1] + count)
+
+            # find starting row group and also find the row within it
+            for idx, cum_row in enumerate(cumulative_rows):
+                if cum_row > row:
+                    row_group_index = idx - 1
+                    start_row_in_group = row - cumulative_rows[row_group_index]
+                    break
+
+            # read from the starting row group onwards
+            for rg_idx in range(row_group_index, parquet_file.num_row_groups):
+                table = parquet_file.read_row_group(rg_idx)
```
> **@dlwh (Member), Oct 15, 2024:** Now I'm concerned this is a disk seek, but probably not worth worrying about.
```diff
+
+                # if we're in the row group we want, slice the table at/from the row we want
+                if rg_idx == row_group_index:
+                    table = table.slice(start_row_in_group)
+
+                yield from table.to_pylist()


 def _mk_shard_name_mapping(urls):
```
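A minor note on the lookup itself: the linear scan over `cumulative_rows` could be replaced by a standard-library binary search, which stays cheap even for files with many row groups. A sketch with made-up group sizes:

```python
import bisect

# Example: three row groups of sizes 5, 5, 3 give these cumulative counts.
cumulative_rows = [0, 5, 10, 13]
row = 7

# bisect_right returns the first index whose entry is strictly greater than `row`,
# so the row group containing `row` is one slot earlier.
row_group_index = bisect.bisect_right(cumulative_rows, row) - 1
start_row_in_group = row - cumulative_rows[row_group_index]

assert (row_group_index, start_row_in_group) == (1, 2)  # row 7 is the 3rd row of group 1
```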
33 changes: 32 additions & 1 deletion tests/test_sharded_dataset.py
```diff
@@ -1,7 +1,12 @@
 import os
 import tempfile
 
-from levanter.data.sharded_datasource import AudioTextUrlDataSource, ParquetDataSource, _sniff_format_for_dataset
+from levanter.data.sharded_datasource import (
+    AudioTextUrlDataSource,
+    ParquetDataSource,
+    TextUrlDataSource,
+    _sniff_format_for_dataset,
+)
 from test_utils import skip_if_no_soundlibs
@@ -68,3 +73,29 @@ def test_basic_parquet_datasource_read_row():
     assert row_data[0]["column2"] == 20
     assert row_data[1]["column1"] == "value3"
     assert row_data[1]["column2"] == 30
+
+
+def test_text_url_data_source_parquet():
+    import pyarrow as pa
+    import pyarrow.parquet as pq
+
+    with tempfile.NamedTemporaryFile(suffix=".parquet", delete=True) as f:
+        data = {
+            "text": ["line1", "line2", "line3", "line4", "line5", "line6"],
+            "column2": [10, 20, 30, 40, 50, 60],
+        }
+        table = pa.Table.from_pydict(data)
+        pq.write_table(table, f.name)
+
+        datasource = TextUrlDataSource([os.path.abspath(f.name)], text_key="text")
+
+        assert len(datasource.shard_names) == 1, "Expected only one shard"
+        shard_name = datasource.shard_names[0]
+
+        # Read data starting from row 2
+        row_data = list(datasource.open_shard_at_row(shard_name=shard_name, row=2))
+
+        # Verify the output
+        expected_texts = ["line3", "line4", "line5", "line6"]
+        assert len(row_data) == len(expected_texts), f"Expected {len(expected_texts)} rows starting from index 2"
+        assert row_data == expected_texts, f"Expected texts {expected_texts}, got {row_data}"
```
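One caveat about this test: six rows written with pyarrow's defaults end up in a single row group, so it exercises the slice-within-a-group path but not the skipping of whole row groups. A sketch of a fixture that would cover the multi-group path, using pyarrow's `row_group_size` argument to force small groups (the file name here is arbitrary):

```python
import pyarrow as pa
import pyarrow.parquet as pq

# Same six-row fixture, but forced into three row groups of two rows each,
# so open_shard_at_row(row=2) must skip row group 0 and start reading at group 1.
table = pa.Table.from_pydict({
    "text": ["line1", "line2", "line3", "line4", "line5", "line6"],
    "column2": [10, 20, 30, 40, 50, 60],
})
pq.write_table(table, "multi_group.parquet", row_group_size=2)

assert pq.ParquetFile("multi_group.parquet").metadata.num_row_groups == 3
```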