Changes in how parquet is read #766
Merged
Changes from all commits · 3 commits
@@ -224,32 +224,59 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[str]:
         compression = "infer"
         if url.endswith(".zstd"):  # hacky way to detect zstd
             compression = "zstd"
-        with fsspec.open(url, "r", compression=compression) as f:
-            format = _sniff_format_for_dataset(url)
-            match format:
-                case ".jsonl":
+        format = _sniff_format_for_dataset(url)
+        match format:
+            case ".jsonl":
+                with fsspec.open(url, "r", compression=compression) as f:
                     # TODO: would be nice if we could seek faster than this. Right now, all we do is skip json parsing
                     # which is not nothing, but not ideal.
                     for line in f:
                         if i >= row:
                             yield json.loads(line)[self.text_key]
                         i += 1
-                case ".txt":
+            case ".txt":
+                with fsspec.open(url, "r", compression=compression) as f:
                     for line in f:
                         if i >= row:
                             yield line
                         i += 1
-                case ".json":
+            case ".json":
+                with fsspec.open(url, "r", compression=compression) as f:
                     data = json.load(f)
                     for doc in data[row:]:
                         yield doc[self.text_key]
-                case ".parquet":
-                    table = pq.read_table(f)
-                    sliced_table = table.slice(row)
-                    for record in sliced_table.to_pylist():
-                        yield record[self.text_key]  # assumes text_key is in record
-                case _:
-                    raise ValueError(f"Unknown format {format}")
+            case ".parquet":
+                with fsspec.open(url, "rb", compression=compression) as f:
+                    parquet_file = pq.ParquetFile(f)
+                    total_rows = parquet_file.metadata.num_rows
+                    if row >= total_rows:
+                        return iter([])
+
+                    num_row_groups = parquet_file.metadata.num_row_groups
+
+                    # Compute cumulative row counts
+                    row_counts = [parquet_file.metadata.row_group(i).num_rows for i in range(num_row_groups)]
+                    cumulative_rows = [0]
+                    for count in row_counts:
+                        cumulative_rows.append(cumulative_rows[-1] + count)
+
+                    # Find the starting row group and row within it
+                    for idx, cum_row in enumerate(cumulative_rows):
+                        if cum_row > row:
+                            row_group_index = idx - 1
+                            start_row_in_group = row - cumulative_rows[row_group_index]
+                            break
+
+                    # Read from the starting row group onwards
+                    for rg_idx in range(row_group_index, parquet_file.num_row_groups):
+                        table = parquet_file.read_row_group(rg_idx, columns=[self.text_key])
+                        if rg_idx == row_group_index:
+                            table = table.slice(start_row_in_group)
+                        for record in table.to_pylist():
+                            yield record[self.text_key]
+            case _:
+                raise ValueError(f"Unknown format {format}")


 class AudioTextUrlDataSource(ShardedDataSource[Tuple[np.ndarray, int, str]]):
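For readers skimming the diff, the parquet branch above leans on three pieces of pyarrow's metadata API: the total row count, the per-row-group row counts, and the ability to read a single row group. A minimal illustration of those calls, with a made-up file path and sizes (not part of this PR):

```python
# Illustration only: the file path, row counts, and row_group_size are made up.
import pyarrow as pa
import pyarrow.parquet as pq

pq.write_table(
    pa.table({"text": [f"doc {i}" for i in range(25)]}),
    "/tmp/example.parquet",
    row_group_size=10,  # forces 3 row groups of 10, 10, and 5 rows
)

pf = pq.ParquetFile("/tmp/example.parquet")
print(pf.metadata.num_rows)        # 25 -> used for the early-return check
print(pf.metadata.num_row_groups)  # 3
print([pf.metadata.row_group(i).num_rows for i in range(pf.metadata.num_row_groups)])
# [10, 10, 5] -> these per-group counts feed the cumulative_rows list

# Reading one row group materializes only that group's rows, not the whole file.
table = pf.read_row_group(2, columns=["text"])
print(table.num_rows)              # 5
```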
@@ -439,10 +466,35 @@ def shard_names(self) -> Sequence[str]:
     def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]:
         url = self._shard_name_to_url_mapping[shard_name]
         with fsspec.open(url, "rb", compression="infer") as f:
-            table = pq.read_table(f)
-            sliced_table = table.slice(row)  # zero-copy slicing
-            for record in sliced_table.to_pylist():
-                yield record
+            parquet_file = pq.ParquetFile(f)
+            total_rows = parquet_file.metadata.num_rows
+            if row >= total_rows:
+                return iter([])
+
+            num_row_groups = parquet_file.metadata.num_row_groups
+
+            # Compute cumulative row counts
+            row_counts = [parquet_file.metadata.row_group(i).num_rows for i in range(num_row_groups)]
+            cumulative_rows = [0]
+            for count in row_counts:
+                cumulative_rows.append(cumulative_rows[-1] + count)
+
+            # find starting row group and also find the row within it
+            for idx, cum_row in enumerate(cumulative_rows):
+                if cum_row > row:
+                    row_group_index = idx - 1
+                    start_row_in_group = row - cumulative_rows[row_group_index]
+                    break
+
+            # read from the starting row group onwards
+            for rg_idx in range(row_group_index, parquet_file.num_row_groups):
+                table = parquet_file.read_row_group(rg_idx)

[Inline review comment on the line above: "Now I'm concerned this is a disk seek but probably not worth worrying about"]

+
+                # if we're in the row group we want, slice the table at/from the row we want
+                if rg_idx == row_group_index:
+                    table = table.slice(start_row_in_group)
+
+                yield from table.to_pylist()


 def _mk_shard_name_mapping(urls):
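Side note on the lookup itself: the linear scan over cumulative_rows in both hunks could equivalently be written with the standard library's bisect. A small sketch under that assumption (not what the PR does), using the example sizes from above:

```python
# Equivalent lookup using bisect instead of the linear scan; values are made up.
import bisect
from itertools import accumulate

row_counts = [10, 10, 5]                        # per-row-group sizes, as in the example above
cumulative_rows = [0, *accumulate(row_counts)]  # [0, 10, 20, 25]

row = 17
# The first boundary strictly greater than `row` marks the group after the one we want.
row_group_index = bisect.bisect_right(cumulative_rows, row) - 1
start_row_in_group = row - cumulative_rows[row_group_index]
assert (row_group_index, start_row_in_group) == (1, 7)
```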
Review comment: The logic in this block is complex enough and duplicated enough that I'd prefer if you extracted a method.
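One possible shape for that extraction, sketched here under the assumption that both call sites can share a generator; the name _iter_parquet_from_row and its exact signature are illustrative, not necessarily what was merged:

```python
# Hypothetical helper consolidating the duplicated row-group seek; names are illustrative.
import bisect
from itertools import accumulate
from typing import Iterator, Optional

import pyarrow.parquet as pq


def _iter_parquet_from_row(
    parquet_file: pq.ParquetFile, row: int, columns: Optional[list] = None
) -> Iterator[dict]:
    """Yield records from `parquet_file` starting at global row index `row`."""
    if row >= parquet_file.metadata.num_rows:
        return

    counts = [
        parquet_file.metadata.row_group(i).num_rows
        for i in range(parquet_file.metadata.num_row_groups)
    ]
    cumulative = [0, *accumulate(counts)]
    row_group_index = bisect.bisect_right(cumulative, row) - 1
    start_row_in_group = row - cumulative[row_group_index]

    # Read only from the starting row group onwards, slicing away the leading rows
    # of the first group so iteration begins exactly at `row`.
    for rg_idx in range(row_group_index, parquet_file.metadata.num_row_groups):
        table = parquet_file.read_row_group(rg_idx, columns=columns)
        if rg_idx == row_group_index:
            table = table.slice(start_row_in_group)
        yield from table.to_pylist()
```

With something like this, the text source's parquet case could yield record[self.text_key] for each record returned by _iter_parquet_from_row(parquet_file, row, columns=[self.text_key]), and the dict-yielding source could simply yield from _iter_parquet_from_row(parquet_file, row).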