Skip to content

Commit

Permalink
fix hf data loading for datasets>=3.1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Nov 4, 2024
1 parent 94dcda1 commit 7b2241f
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/levanter/data/sharded_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,10 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]:
dataset = self._load_dataset()
if isinstance(dataset, datasets.IterableDataset) and shard_name != "data":
# ex_iterable has a key that gets discarded typically
shard = map(lambda t: t[1], dataset._ex_iterable.shard_data_sources(int(shard_name), dataset.n_shards))
shard = map(
lambda t: t[1],
dataset._ex_iterable.shard_data_sources(index=int(shard_name), num_shards=dataset.n_shards),
)
else:
shard = dataset

Expand Down

0 comments on commit 7b2241f

Please sign in to comment.