From df3a5dc265333d966b5f12128fd7534054307a50 Mon Sep 17 00:00:00 2001 From: alex-hh Date: Fri, 4 Oct 2024 10:08:12 +0100 Subject: [PATCH 01/10] implement repeat method for iterable dataset --- src/datasets/iterable_dataset.py | 74 ++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 5f5c49f1556..352a2486a0d 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -1483,6 +1483,57 @@ def n_shards(self) -> int: return self.ex_iterable.n_shards +class RepeatExamplesIterable(_BaseExamplesIterable): + """ + Iterable that repeats the underlying iterable a given number of times. + It does not duplicate shards, so that duplicate shards are not seen in + the same iteration. + """ + + def __init__( + self, + ex_iterable: _BaseExamplesIterable, + num_times: int, + ): + super().__init__() + self.ex_iterable = ex_iterable + self.num_times = num_times + + def _init_state_dict(self) -> dict: + self._state_dict = { + "repeat_index": 0, + "ex_iterable": self.ex_iterable._init_state_dict(), + } + return self._state_dict + + def __iter__(self): + if self.num_times is None: + while True: + yield from self.ex_iterable + if self._state_dict: + self._state_dict["repeat_index"] += 1 + else: + for _ in range(self.num_times): + yield from self.ex_iterable + if self._state_dict: + self._state_dict["repeat_index"] += 1 + + def shuffle_data_sources(self, generator: np.random.Generator) -> "RepeatExamplesIterable": + """Shuffle the underlying iterable, then repeat.""" + return RepeatExamplesIterable(self.ex_iterable.shuffle_data_sources(generator), num_times=self.num_times) + + def shard_data_sources(self, worker_id: int, num_workers: int) -> "RepeatExamplesIterable": + """Shard, then repeat shards.""" + return RepeatExamplesIterable( + self.ex_iterable.shard_data_sources(worker_id, num_workers), + num_times=self.num_times, + ) + + @property + def n_shards(self) -> 
int: + return self.ex_iterable.n_shards + + class TakeExamplesIterable(_BaseExamplesIterable): def __init__( self, @@ -2513,6 +2564,29 @@ def skip(self, n: int) -> "IterableDataset": token_per_repo_id=self._token_per_repo_id, ) + def repeat(self, num_times: Optional[int] = None) -> "IterableDataset": + """ + Create a new [`IterableDataset`] that repeats the underlying dataset `num_times` times. + + N.B. Duplicate data is never seen in the same iteration, even after shuffling: + ds.repeat(n).shuffle(seed=42) is equivalent to ds.shuffle(seed=42).repeat(n) + + Args: + num_times (`int`) or (`None`): + Number of times to repeat the dataset. If `None`, the dataset will be repeated indefinitely. + + Example: + """ + return IterableDataset( + ex_iterable=RepeatExamplesIterable(self._ex_iterable, num_times=num_times), + info=self._info, + split=self._split, + formatting=self._formatting, + shuffling=copy.deepcopy(self._shuffling), + distributed=copy.deepcopy(self._distributed), + token_per_repo_id=self._token_per_repo_id, + ) + def take(self, n: int) -> "IterableDataset": """ Create a new [`IterableDataset`] with only the first `n` elements. From ab3e9ab987e616631a02dcc67628317c40bea735 Mon Sep 17 00:00:00 2001 From: alex-hh Date: Fri, 4 Oct 2024 10:30:03 +0100 Subject: [PATCH 02/10] implement repeat method for map-style dataset --- src/datasets/arrow_dataset.py | 34 ++++++++++++++++++++++++++++++++ src/datasets/iterable_dataset.py | 2 +- 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index e9e074e0e97..4135f331f72 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -4040,6 +4040,40 @@ def skip(self, n: int) -> "Dataset": """ return self.select(range(n, len(self))) + def repeat(self, num_times: int) -> "Dataset": + """ + Create a new [`Dataset`] that repeats the underlying dataset `num_times` times. 
+ + Like itertools.repeat, repeating once just returns the full dataset. + + Args: + num_times (`int`): + Number of times to repeat the dataset. + + Example: + ```py + >>> from datasets import load_dataset + >>> ds = load_dataset("rotten_tomatoes", split="train") + >>> ds = ds.take(3).repeat(2) + >>> list(ds) + [{'label': 1, + 'text': 'the rock is destined to be the 21st century\'s new " conan " and that he\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .'}, + {'label': 1, + 'text': 'the gorgeously elaborate continuation of " the lord of the rings " trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson\'s expanded vision of j . r . r . tolkien\'s middle-earth .'}, + {'label': 1, 'text': 'effective but too-tepid biopic'}, + {'label': 1, + 'text': 'the rock is destined to be the 21st century\'s new " conan " and that he\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .'}, + {'label': 1, + 'text': 'the gorgeously elaborate continuation of " the lord of the rings " trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson\'s expanded vision of j . r . r . tolkien\'s middle-earth .'}, + {'label': 1, 'text': 'effective but too-tepid biopic'}] + ``` + """ + if num_times is None: + raise ValueError("Map style datasets do not support indefinite repetition.") + num_times = max(num_times, 0) + indices = list(range(len(self))) * num_times + return self.select(indices) + def take(self, n: int) -> "Dataset": """ Create a new [`Dataset`] with only the first `n` elements.
diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 352a2486a0d..5604f8dfba2 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -2564,7 +2564,7 @@ def skip(self, n: int) -> "IterableDataset": token_per_repo_id=self._token_per_repo_id, ) - def repeat(self, num_times: Optional[int] = None) -> "IterableDataset": + def repeat(self, num_times: Optional[int]) -> "IterableDataset": """ Create a new [`IterableDataset`] that repeats the underlying dataset `num_times` times. From 85ee92f892cc61f790cd537898b87a2a29147892 Mon Sep 17 00:00:00 2001 From: alex-hh Date: Fri, 4 Oct 2024 11:36:00 +0100 Subject: [PATCH 03/10] fix iterable dataset repeat --- src/datasets/iterable_dataset.py | 45 ++++++++++++++++++++++---------- 1 file changed, 31 insertions(+), 14 deletions(-) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 5604f8dfba2..6fcae604be1 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -1486,8 +1486,6 @@ def n_shards(self) -> int: class RepeatExamplesIterable(_BaseExamplesIterable): """ Iterable that repeats the underlying iterable a given number of times. - It does not duplicate shards, so that duplicate shards are not seen in - the same iteration. 
""" def __init__( @@ -1507,16 +1505,15 @@ def _init_state_dict(self) -> dict: return self._state_dict def __iter__(self): - if self.num_times is None: - while True: - yield from self.ex_iterable - if self._state_dict: - self._state_dict["repeat_index"] += 1 - else: - for _ in range(self.num_times): - yield from self.ex_iterable - if self._state_dict: - self._state_dict["repeat_index"] += 1 + repeat_index = self._state_dict["repeat_index"] if self._state_dict else 0 + while True: + if self.num_times and repeat_index >= max(self.num_times, 0): + break + yield from self.ex_iterable + repeat_index += 1 + if self._state_dict: + self._state_dict["repeat_index"] = repeat_index + self._state_dict["ex_iterable"] = self.ex_iterable._init_state_dict() def shuffle_data_sources(self, generator: np.random.Generator) -> "RepeatExamplesIterable": """Shuffle the underlying iterable, then repeat.""" @@ -2568,14 +2565,34 @@ def repeat(self, num_times: Optional[int]) -> "IterableDataset": """ Create a new [`IterableDataset`] that repeats the underlying dataset `num_times` times. - N.B. Duplicate data is never seen in the same iteration, even after shuffling: - ds.repeat(n).shuffle(seed=42) is equivalent to ds.shuffle(seed=42).repeat(n) + N.B. The effect of calling shuffle after repeat depends significantly on buffer size. + With buffer_size 1, duplicate data is never seen in the same iteration, even after shuffling: + ds.repeat(n).shuffle(seed=42, buffer_size=1) is equivalent to ds.shuffle(seed=42, buffer_size=1).repeat(n), + and only shuffles shard orders within each iteration. + With buffer size >= (num samples in the dataset * num_times), we get full shuffling of the repeated data, i.e. we can observe duplicates in + the same iteration. Args: num_times (`int`) or (`None`): Number of times to repeat the dataset. If `None`, the dataset will be repeated indefinitely. 
Example: + ```py + >>> from datasets import load_dataset + >>> ds = load_dataset("rotten_tomatoes", split="train") + >>> ds = ds.take(3).repeat(2) + >>> list(ds) + [{'label': 1, + 'text': 'the rock is destined to be the 21st century\'s new " conan " and that he\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .'}, + {'label': 1, + 'text': 'the gorgeously elaborate continuation of " the lord of the rings " trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson\'s expanded vision of j . r . r . tolkien\'s middle-earth .'}, + {'label': 1, 'text': 'effective but too-tepid biopic'}, + {'label': 1, + 'text': 'the rock is destined to be the 21st century\'s new " conan " and that he\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .'}, + {'label': 1, + 'text': 'the gorgeously elaborate continuation of " the lord of the rings " trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson\'s expanded vision of j . r . r . 
tolkien\'s middle-earth .'}, + {'label': 1, 'text': 'effective but too-tepid biopic'}] + ``` """ return IterableDataset( ex_iterable=RepeatExamplesIterable(self._ex_iterable, num_times=num_times), From 273c0fe6cd22228e11c5293faa300c7068b2a98c Mon Sep 17 00:00:00 2001 From: alex-hh Date: Wed, 29 Jan 2025 21:14:17 +0100 Subject: [PATCH 04/10] address pr comments --- src/datasets/arrow_dataset.py | 4 +--- src/datasets/iterable_dataset.py | 4 ++-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 4135f331f72..528ad040307 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -4070,9 +4070,7 @@ def repeat(self, num_times: int) -> "Dataset": """ if num_times is None: raise ValueError("Map style datasets do not support indefinite repetition.") - num_times = max(num_times, 0) - indices = list(range(len(self))) * num_times - return self.select(indices) + return _concatenate_map_style_datasets([self] * num_times) if num_times > 0 else self.select([]) def take(self, n: int) -> "Dataset": """ diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 6fcae604be1..4e9fd704d81 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -1491,7 +1491,7 @@ class RepeatExamplesIterable(_BaseExamplesIterable): def __init__( self, ex_iterable: _BaseExamplesIterable, - num_times: int, + num_times: Optional[int], ): super().__init__() self.ex_iterable = ex_iterable @@ -1507,7 +1507,7 @@ def _init_state_dict(self) -> dict: def __iter__(self): repeat_index = self._state_dict["repeat_index"] if self._state_dict else 0 while True: - if self.num_times and repeat_index >= max(self.num_times, 0): + if self.num_times is not None and repeat_index >= max(self.num_times, 0): break yield from self.ex_iterable repeat_index += 1 From c3faf4c2eafd95eaabffb3e5237877e0d6bd3117 Mon Sep 17 00:00:00 2001 From: alex-hh Date: Wed, 29 Jan 2025 
21:31:20 +0100 Subject: [PATCH 05/10] add test case for map-style dataset --- tests/test_arrow_dataset.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index ffa048644e2..e67f8bdc7b9 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -869,6 +869,29 @@ def test_concatenate_pickle(self, in_memory): self.assertEqual(dset_concat.info.description, "Dataset2\n\nDataset1") del dset1, dset2, dset3 + def test_repeat(self, in_memory): + with tempfile.TemporaryDirectory() as tmp_dir: + with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset: + repeated_dset = dset.repeat(3) + column_values_dict = {col: repeated_dset[col] for col in repeated_dset.column_names} + for col, single_values in column_values_dict.items(): + self.assertListEqual(repeated_dset[col], single_values * 3) + + with tempfile.TemporaryDirectory() as tmp_dir: + with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset: + with pytest.raises(ValueError): + dset.repeat(None) + + with tempfile.TemporaryDirectory() as tmp_dir: + with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset: + repeated_dset = dset.repeat(0) + self.assertEqual(len(repeated_dset), 0) + + with tempfile.TemporaryDirectory() as tmp_dir: + with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset: + repeated_dset = dset.repeat(-1) + self.assertEqual(len(repeated_dset), 0) + def test_flatten(self, in_memory): with tempfile.TemporaryDirectory() as tmp_dir: with Dataset.from_dict( From b5047d7591bdc53b87a7e5fd2033ac93008457db Mon Sep 17 00:00:00 2001 From: alex-hh Date: Wed, 29 Jan 2025 22:58:10 +0100 Subject: [PATCH 06/10] add test cases for iterable datasets --- tests/test_iterable_dataset.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/test_iterable_dataset.py 
b/tests/test_iterable_dataset.py index 232652f1fa3..3747bb4815f 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -31,6 +31,7 @@ MappedExamplesIterable, RandomlyCyclingMultiSourcesExamplesIterable, RebatchedArrowExamplesIterable, + RepeatExamplesIterable, SelectColumnsIterable, ShuffledDataSourcesArrowExamplesIterable, ShuffledDataSourcesExamplesIterable, @@ -1111,6 +1112,28 @@ def test_take_examples_iterable(): assert_load_state_dict_resumes_iteration(take_ex_iterable) +@pytest.mark.parametrize( + "n, num_times", + [ + (3, None), + (3, 3), + (3, 0), + ], +) +def test_repeat_examples_iterable(n, num_times): + base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n}) + ex_iterable = RepeatExamplesIterable(base_ex_iterable, num_times=num_times) + all_examples = [x for _, x in generate_examples_fn(n=n)] + if num_times is not None: + expected = all_examples * max(num_times, 0) + assert [x for _, x in ex_iterable] == expected + else: + max_iters = 135 + iterator = iter(ex_iterable) + for i in range(max_iters): + assert next(iterator)[1] == all_examples[i % len(all_examples)], f"iteration {i} failed," + + def test_vertically_concatenated_examples_iterable(): ex_iterable1 = ExamplesIterable(generate_examples_fn, {"label": 10}) ex_iterable2 = ExamplesIterable(generate_examples_fn, {"label": 5}) @@ -1681,6 +1704,14 @@ def test_iterable_dataset_take(dataset: IterableDataset, n): assert list(take_dataset) == list(dataset)[:n] +@pytest.mark.parametrize("n", [0, 2]) +def test_iterable_dataset_repeat(dataset: IterableDataset, n): + repeat_dataset = dataset.repeat(n) + assert isinstance(repeat_dataset._ex_iterable, RepeatExamplesIterable) + assert repeat_dataset._ex_iterable.num_times == n + assert list(repeat_dataset) == list(dataset) * n + + @pytest.mark.parametrize("method", ["skip", "take"]) @pytest.mark.parametrize("after_shuffle", [False, True]) @pytest.mark.parametrize("count", [2, 5, 11]) From 
79de1989a9514b2c7a86e0008ec690e0d141b4ae Mon Sep 17 00:00:00 2001 From: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> Date: Thu, 30 Jan 2025 11:52:35 +0100 Subject: [PATCH 07/10] fix code formatting --- tests/test_iterable_dataset.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index c7a0002312d..8a972ec9cd3 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -1758,7 +1758,6 @@ def test_iterable_dataset_take(dataset: IterableDataset, n): assert list(take_dataset) == list(dataset)[:n] - @pytest.mark.parametrize("n", [0, 2]) def test_iterable_dataset_repeat(dataset: IterableDataset, n): repeat_dataset = dataset.repeat(n) From 6b51fd5deff31f4c22814bc9d4cd1aba419b286f Mon Sep 17 00:00:00 2001 From: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> Date: Thu, 30 Jan 2025 12:04:02 +0100 Subject: [PATCH 08/10] Update test_arrow_dataset.py --- tests/test_arrow_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 1654234b51b..93d62beb143 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -873,7 +873,7 @@ def test_repeat(self, in_memory): with tempfile.TemporaryDirectory() as tmp_dir: with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset: repeated_dset = dset.repeat(3) - column_values_dict = {col: repeated_dset[col] for col in repeated_dset.column_names} + column_values_dict = {col: dset[col] for col in dset.column_names} for col, single_values in column_values_dict.items(): self.assertListEqual(repeated_dset[col], single_values * 3) From bee285a9453783bc91eaa69dc65505a84a6a388a Mon Sep 17 00:00:00 2001 From: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> Date: Fri, 31 Jan 2025 12:01:26 +0100 Subject: [PATCH 09/10] Update test_arrow_dataset.py --- tests/test_arrow_dataset.py | 3 +++ 1 file changed, 3 
insertions(+) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 93d62beb143..766174b15a1 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -876,6 +876,7 @@ def test_repeat(self, in_memory): column_values_dict = {col: dset[col] for col in dset.column_names} for col, single_values in column_values_dict.items(): self.assertListEqual(repeated_dset[col], single_values * 3) + del repeated_dset with tempfile.TemporaryDirectory() as tmp_dir: with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset: @@ -886,11 +887,13 @@ def test_repeat(self, in_memory): with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset: repeated_dset = dset.repeat(0) self.assertEqual(len(repeated_dset), 0) + del repeated_dset with tempfile.TemporaryDirectory() as tmp_dir: with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset: repeated_dset = dset.repeat(-1) self.assertEqual(len(repeated_dset), 0) + del repeated_dset def test_flatten(self, in_memory): with tempfile.TemporaryDirectory() as tmp_dir: From e3f8f30773ae192163a960fa9471096bb586af02 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> Date: Wed, 5 Feb 2025 15:54:24 +0100 Subject: [PATCH 10/10] Update main_classes.mdx --- docs/source/package_reference/main_classes.mdx | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/package_reference/main_classes.mdx b/docs/source/package_reference/main_classes.mdx index 185bde10d72..62dc9127d4b 100644 --- a/docs/source/package_reference/main_classes.mdx +++ b/docs/source/package_reference/main_classes.mdx @@ -52,6 +52,7 @@ The base class [`Dataset`] implements a Dataset backed by an Apache Arrow table. 
- take - train_test_split - shard + - repeat - to_tf_dataset - push_to_hub - save_to_disk @@ -172,6 +173,7 @@ The base class [`IterableDataset`] implements an iterable Dataset backed by pyth - skip - take - shard + - repeat - load_state_dict - state_dict - info