diff --git a/pyproject.toml b/pyproject.toml index 19fb077bf..0831605cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ dependencies = [ "draccus>=0.8.0", "pyarrow>=11.0.0", "zstandard>=0.20.0", - "datasets>=2.18,<4.0", + "datasets>=3.1.0,<4.0", "gcsfs>=2024.2,<2024.10", "braceexpand>=0.1.7", "jmp>=0.0.3", diff --git a/src/levanter/data/sharded_datasource.py b/src/levanter/data/sharded_datasource.py index 186a0d9dd..90803df3e 100644 --- a/src/levanter/data/sharded_datasource.py +++ b/src/levanter/data/sharded_datasource.py @@ -197,7 +197,10 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]: dataset = self._load_dataset() if isinstance(dataset, datasets.IterableDataset) and shard_name != "data": # ex_iterable has a key that gets discarded typically - shard = map(lambda t: t[1], dataset._ex_iterable.shard_data_sources(int(shard_name), dataset.n_shards)) + shard = map( + lambda t: t[1], + dataset._ex_iterable.shard_data_sources(index=int(shard_name), num_shards=dataset.n_shards), + ) else: shard = dataset