Skip to content

Commit

Permalink
Fix hf datasets for new version (#784)
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh authored Nov 4, 2024
1 parent 5ebf8ce commit c823b7d
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ dependencies = [
"draccus>=0.8.0",
"pyarrow>=11.0.0",
"zstandard>=0.20.0",
"datasets>=2.18,<4.0",
"datasets>=3.1.0,<4.0",
"gcsfs>=2024.2,<2024.10",
"braceexpand>=0.1.7",
"jmp>=0.0.3",
Expand Down
5 changes: 4 additions & 1 deletion src/levanter/data/sharded_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,10 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]:
dataset = self._load_dataset()
if isinstance(dataset, datasets.IterableDataset) and shard_name != "data":
# ex_iterable has a key that gets discarded typically
shard = map(lambda t: t[1], dataset._ex_iterable.shard_data_sources(int(shard_name), dataset.n_shards))
shard = map(
lambda t: t[1],
dataset._ex_iterable.shard_data_sources(index=int(shard_name), num_shards=dataset.n_shards),
)
else:
shard = dataset

Expand Down

0 comments on commit c823b7d

Please sign in to comment.