From 3fb33f097b18862956ac15f940bfaea1bdd63e4c Mon Sep 17 00:00:00 2001
From: Ramanish Singh
Date: Wed, 20 Nov 2024 20:41:08 -0800
Subject: [PATCH 1/2] initial commit

---
 .../test_multi_node_round_robin_sampler.py    | 107 +++++++++++++
 .../multi_node_round_robin_sampler.py         | 147 ++++++++++++++++++
 2 files changed, 254 insertions(+)
 create mode 100644 test/nodes/test_multi_node_round_robin_sampler.py
 create mode 100644 torchdata/nodes/samplers/multi_node_round_robin_sampler.py

diff --git a/test/nodes/test_multi_node_round_robin_sampler.py b/test/nodes/test_multi_node_round_robin_sampler.py
new file mode 100644
index 000000000..908b8dc2b
--- /dev/null
+++ b/test/nodes/test_multi_node_round_robin_sampler.py
@@ -0,0 +1,107 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import collections
+import itertools
+from enum import unique
+
+from parameterized import parameterized
+from torch.testing._internal.common_utils import TestCase
+from torchdata.nodes.adapters import IterableWrapper
+from torchdata.nodes.batch import Batcher
+from torchdata.nodes.prefetch import Prefetcher
+from torchdata.nodes.samplers.multi_node_round_robin_sampler import (
+    MultiNodeRoundRobinSampler,
+)
+from torchdata.nodes.samplers.stop_criteria import StopCriteria
+
+from .utils import DummyIterableDataset, run_test_save_load_state
+
+
+class TestMultiNodeRoundRobinSampler(TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        self._num_samples = 1
+        self._num_datasets = 3
+
+    def get_equal_dataset(self, num_samples, num_datasets):
+        datasets = {
+            f"ds{i}": IterableWrapper(DummyIterableDataset(num_samples, f"ds{i}"))
+            for i in range(num_datasets)
+        }
+        return datasets
+
+    def get_unequal_dataset(self, num_samples, num_datasets):
+        datasets = {
+            f"ds{i}": IterableWrapper(DummyIterableDataset(num_samples + i, f"ds{i}"))
+            for i in range(num_datasets)
+        }
+        return datasets
+
+    def test_multi_node_round_robin_sampler_equal_dataset(self) -> None:
+        datasets = self.get_equal_dataset(self._num_samples, self._num_datasets)
+        sampler = MultiNodeRoundRobinSampler(
+            datasets, StopCriteria.FIRST_DATASET_EXHAUSTED
+        )
+        batch_size = 3
+        num_epochs = 1
+        # each dataset has 1 sample, so the first and only epoch must be ['ds0', 'ds1', 'ds2']
+        batcher = Batcher(sampler, batch_size=batch_size)
+        for _ in range(num_epochs):
+            results = next(batcher)
+            self.assertGreater(len(results), 0)
+            datasets_in_results = [result["name"] for result in results]
+            dataset_counts_in_results = collections.Counter(datasets_in_results)
+            for key in dataset_counts_in_results:
+                self.assertEqual(dataset_counts_in_results[key], 1)
+
+    def test_multi_node_round_robin_sampler_unequal_dataset(self) -> None:
+        datasets = self.get_unequal_dataset(self._num_samples, self._num_datasets)
+        sampler = MultiNodeRoundRobinSampler(
+            datasets, StopCriteria.ALL_DATASETS_EXHAUSTED
+        )
+        batch_size = 3
+        num_epochs = 2
+        batcher = Batcher(sampler, batch_size=batch_size)
+        # In this case, the first epoch must be ['ds0', 'ds1', 'ds2'] and the second epoch must be ['ds1', 'ds2', 'ds2']
+        for epoch in range(num_epochs):
+            results = next(batcher)
+            self.assertGreater(len(results), 0)
+            datasets_in_results = [result["name"] for result in results]
+            dataset_counts_in_results = collections.Counter(datasets_in_results)
+            if epoch == 0:
+                self.assertEqual(len(dataset_counts_in_results), self._num_datasets)
+                for key in dataset_counts_in_results:
+                    self.assertEqual(dataset_counts_in_results[key], 1)
+            elif epoch == 1:
+                self.assertEqual(len(dataset_counts_in_results), self._num_datasets - 1)
+                for key in dataset_counts_in_results:
+                    if key == "ds0":
+                        self.assertEqual(dataset_counts_in_results[key], 0)
+                    if key == "ds1":
+                        self.assertEqual(dataset_counts_in_results[key], 1)
+                    else:
+                        self.assertEqual(dataset_counts_in_results[key], 2)
+
+    def test_get_state(self) -> None:
+        datasets = self.get_equal_dataset(self._num_samples, self._num_datasets)
+        sampler = MultiNodeRoundRobinSampler(
+            datasets, StopCriteria.FIRST_DATASET_EXHAUSTED
+        )
+        state = sampler.get_state()
+        self.assertIn("current_dataset_index", state)
+        self.assertIn("datasets_exhausted", state)
+        self.assertIn("dataset_node_states", state)
+
+    def test_multi_node_round_robin_large_sample_size(self) -> None:
+        num_samples = 1500
+        num_datasets = 3
+        datasets = self.get_equal_dataset(num_samples, num_datasets)
+        sampler = MultiNodeRoundRobinSampler(
+            datasets, StopCriteria.ALL_DATASETS_EXHAUSTED
+        )
+        prefetcher = Prefetcher(sampler, 3)
+        run_test_save_load_state(self, prefetcher, 400)
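
Note on test fixtures: the assertions above depend on DummyIterableDataset and run_test_save_load_state from test/nodes/utils.py, which are not part of this patch. As a minimal sketch of the contract the tests assume — each dataset yields num_samples dicts carrying the dataset name under a "name" key — a hypothetical stand-in could look like this (any field besides "name" is illustrative, not the real helper):

    from typing import Any, Dict, Iterator

    class DummyIterableDataset:
        # Hypothetical stand-in for the helper in test/nodes/utils.py:
        # yields `num_samples` dicts tagged with `name`, which is what the
        # `result["name"]` lookups in the tests above rely on.
        def __init__(self, num_samples: int, name: str) -> None:
            self.num_samples = num_samples
            self.name = name

        def __iter__(self) -> Iterator[Dict[str, Any]]:
            for i in range(self.num_samples):
                yield {"name": self.name, "sample_idx": i}

IterableWrapper then adapts any such plain iterable into a BaseNode, which is the interface the sampler below consumes.
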
diff --git a/torchdata/nodes/samplers/multi_node_round_robin_sampler.py b/torchdata/nodes/samplers/multi_node_round_robin_sampler.py
new file mode 100644
index 000000000..32926eacd
--- /dev/null
+++ b/torchdata/nodes/samplers/multi_node_round_robin_sampler.py
@@ -0,0 +1,147 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+from typing import Any, Dict, Mapping, Optional
+
+from torchdata.nodes.base_node import BaseNode, T
+from torchdata.nodes.samplers.stop_criteria import StopCriteria
+
+
+class MultiNodeRoundRobinSampler(BaseNode[T]):
+    """A node that samples from multiple datasets in a round robin fashion.
+
+    This node expects to take in a dictionary of source nodes.
+
+    The node implements the state using the following keys:
+    - CURRENT_DATASET_INDEX_KEY: The index of the current dataset.
+    - DATASETS_EXHAUSTED_KEY: A list of booleans indicating whether each source node is exhausted.
+    - DATASET_NODE_STATES_KEY: A dictionary of states for each source node.
+
+    We support multiple stopping criteria:
+    - CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED: Cycle through the source nodes until all datasets
+      are exhausted. This is the default behavior.
+    - FIRST_DATASET_EXHAUSTED: Stop when the first dataset is exhausted.
+    - ALL_DATASETS_EXHAUSTED: Stop when all datasets are exhausted.
+
+    #TODO: Add examples of usage and output
+    On complete exhaustion of the source nodes, the node will raise StopIteration.
+
+    Parameters:
+        source_nodes (Mapping[str, BaseNode[T]]): A dictionary of source nodes.
+        stop_criteria (str): The stopping criteria. Default is CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED.
+    """
+
+    CURRENT_DATASET_INDEX_KEY = "current_dataset_index"
+    DATASET_NODE_STATES_KEY = "dataset_node_states"
+    DATASETS_EXHAUSTED_KEY = "datasets_exhausted"
+
+    def __init__(
+        self,
+        source_nodes: Mapping[str, BaseNode[T]],
+        stop_criteria: str = StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED,
+    ) -> None:
+        super().__init__()
+        self.source_nodes = [source_nodes[k] for k in source_nodes.keys()]
+        self.num_datasets = len(self.source_nodes)
+        self.stop_criteria = stop_criteria
+        self.current_dataset_index = 0
+        self._validate_stop_criteria()
+        self._datasets_exhausted = [False for _ in range(self.num_datasets)]
+
+    def _validate_stop_criteria(self) -> None:
+        if self.stop_criteria not in [
+            StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED,
+            StopCriteria.ALL_DATASETS_EXHAUSTED,
+            StopCriteria.FIRST_DATASET_EXHAUSTED,
+        ]:
+            raise ValueError(
+                f"Invalid {self.stop_criteria=}. stop_criteria must be one of: CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED, FIRST_DATASET_EXHAUSTED, ALL_DATASETS_EXHAUSTED"
+            )
+ """ + + CURRENT_DATASET_INDEX_KEY = "current_dataset_index" + DATASET_NODE_STATES_KEY = "dataset_node_states" + DATASETS_EXHAUSTED_KEY = "datasets_exhausted" + + def __init__( + self, + source_nodes: Mapping[str, BaseNode[T]], + stop_criteria: str = StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED, + ) -> None: + super().__init__() + self.source_nodes = [source_nodes[k] for k in source_nodes.keys()] + self.num_datasets = len(self.source_nodes) + self.stop_criteria = stop_criteria + self.current_dataset_index = 0 + self._validate_stop_criteria() + self._datasets_exhausted = [False for _ in range(self.num_datasets)] + + def _validate_stop_criteria(self) -> None: + if self.stop_criteria not in [ + StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED, + StopCriteria.ALL_DATASETS_EXHAUSTED, + StopCriteria.FIRST_DATASET_EXHAUSTED, + ]: + raise ValueError( + f"Invalid {self.stop_criteria=}. stop_criteria must be one of: CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED, , ALL_DATASETS_EXHAUSTED" + ) + + def reset(self, initial_state: Optional[Dict[str, Any]] = None): + super().reset(initial_state) + + if initial_state is not None: + self._datasets_exhausted = initial_state[self.DATASETS_EXHAUSTED_KEY] + for k in range(self.num_datasets): + self.source_nodes[k].reset( + initial_state[self.DATASET_NODE_STATES_KEY][k] + ) + else: + # Force a fresh iterator from all source nodes + self._datasets_exhausted = [False for _ in range(self.num_datasets)] + for k in range(self.num_datasets): + self.source_nodes[k].reset() + + def _check_for_stop_iteration(self) -> None: + if all(self._datasets_exhausted): + # Raise StopIteration if all datasets are exhausted, + # this covers for both ALL_DATASETS_EXHAUSTED and CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED + raise StopIteration() + + # Raise StopIteration is StopCriteria is FIRST_DATASET_EXHAUSTED and + # the first dataset is exhausted. 
+
+    def next(self) -> T:
+
+        while True:
+            self._check_for_stop_iteration()
+
+            current_iterator = self.source_nodes[self.current_dataset_index]
+            try:
+                if (
+                    self._datasets_exhausted[self.current_dataset_index]
+                    and self.stop_criteria == StopCriteria.ALL_DATASETS_EXHAUSTED
+                ):
+                    # Before fetching a new item, check if the current dataset is already
+                    # exhausted; if StopCriteria is ALL_DATASETS_EXHAUSTED, move to the next dataset
+                    self.current_dataset_index = (
+                        self.current_dataset_index + 1
+                    ) % self.num_datasets
+                    continue
+                item = next(current_iterator)
+            except StopIteration:
+                # Mark the dataset as exhausted
+                self._datasets_exhausted[self.current_dataset_index] = True
+
+                # With the exhaustion flags updated, check if we should raise StopIteration
+                self._check_for_stop_iteration()
+
+                # If StopCriteria is ALL_DATASETS_EXHAUSTED, move to the next dataset
+                if self.stop_criteria == StopCriteria.ALL_DATASETS_EXHAUSTED:
+                    continue
+
+                # If StopCriteria is CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED,
+                # reset the iterator and try again
+                self.source_nodes[self.current_dataset_index].reset()
+                item = next(self.source_nodes[self.current_dataset_index])
+            break
+
+        # If we didn't raise StopIteration, advance to the next dataset and return the item
+        self.current_dataset_index = (
+            self.current_dataset_index + 1
+        ) % self.num_datasets
+
+        return item
+
+    def get_state(self) -> Dict[str, Any]:
+        state = {
+            self.CURRENT_DATASET_INDEX_KEY: self.current_dataset_index,
+            self.DATASETS_EXHAUSTED_KEY: copy.deepcopy(self._datasets_exhausted),
+            self.DATASET_NODE_STATES_KEY: {
+                k: self.source_nodes[k].state_dict() for k in range(self.num_datasets)
+            },
+        }
+        return state
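
The docstring above leaves usage and output examples as a TODO. As a rough sketch of the intended behavior (not part of this patch, and assuming IterableWrapper can wrap any plain iterable, as the tests do), draining a sampler over two unequal sources with ALL_DATASETS_EXHAUSTED visits the sources in turn and skips a source once it runs dry:

    from torchdata.nodes.adapters import IterableWrapper
    from torchdata.nodes.samplers.multi_node_round_robin_sampler import MultiNodeRoundRobinSampler
    from torchdata.nodes.samplers.stop_criteria import StopCriteria

    source_nodes = {
        "a": IterableWrapper(["a0", "a1"]),  # two items
        "b": IterableWrapper(["b0"]),  # one item
    }
    sampler = MultiNodeRoundRobinSampler(source_nodes, StopCriteria.ALL_DATASETS_EXHAUSTED)
    sampler.reset()

    items = []
    try:
        while True:
            items.append(next(sampler))
    except StopIteration:
        pass
    # Round robin visits a, b, a; "b" is exhausted after one item and is skipped afterwards.
    print(items)  # ['a0', 'b0', 'a1']

With CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED (the default), the shorter source would instead be reset and cycled, yielding ['a0', 'b0', 'a1', 'b0'] before stopping.
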
f"ds{i}")) - for i in range(num_datasets) + f"ds{i}": IterableWrapper(DummyIterableDataset(num_samples + i, f"ds{i}")) for i in range(num_datasets) } return datasets def test_multi_node_round_robin_sampler_equal_dataset(self) -> None: datasets = self.get_equal_dataset(self._num_samples, self._num_datasets) - sampler = MultiNodeRoundRobinSampler( - datasets, StopCriteria.FIRST_DATASET_EXHAUSTED - ) + sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.FIRST_DATASET_EXHAUSTED) batch_size = 3 num_epochs = 1 # each dataset has 1 sample, so the first and only epoch must be ['ds0', 'ds1', 'ds2'] @@ -60,9 +52,7 @@ def test_multi_node_round_robin_sampler_equal_dataset(self) -> None: def test_multi_node_round_robin_sampler_unequal_dataset(self) -> None: datasets = self.get_unequal_dataset(self._num_samples, self._num_datasets) - sampler = MultiNodeRoundRobinSampler( - datasets, StopCriteria.ALL_DATASETS_EXHAUSTED - ) + sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.ALL_DATASETS_EXHAUSTED) batch_size = 3 num_epochs = 2 batcher = Batcher(sampler, batch_size=batch_size) @@ -88,9 +78,7 @@ def test_multi_node_round_robin_sampler_unequal_dataset(self) -> None: def test_get_state(self) -> None: datasets = self.get_equal_dataset(self._num_samples, self._num_datasets) - sampler = MultiNodeRoundRobinSampler( - datasets, StopCriteria.FIRST_DATASET_EXHAUSTED - ) + sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.FIRST_DATASET_EXHAUSTED) state = sampler.get_state() self.assertIn("current_dataset_index", state) self.assertIn("datasets_exhausted", state) @@ -100,8 +88,6 @@ def test_multi_node_round_robin_large_sample_size(self) -> None: num_samples = 1500 num_datasets = 3 datasets = self.get_equal_dataset(num_samples, num_datasets) - sampler = MultiNodeRoundRobinSampler( - datasets, StopCriteria.ALL_DATASETS_EXHAUSTED - ) + sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.ALL_DATASETS_EXHAUSTED) prefetcher = Prefetcher(sampler, 3) run_test_save_load_state(self, prefetcher, 400) diff --git a/torchdata/nodes/samplers/multi_node_round_robin_sampler.py b/torchdata/nodes/samplers/multi_node_round_robin_sampler.py index 32926eacd..9e0f58710 100644 --- a/torchdata/nodes/samplers/multi_node_round_robin_sampler.py +++ b/torchdata/nodes/samplers/multi_node_round_robin_sampler.py @@ -68,9 +68,7 @@ def reset(self, initial_state: Optional[Dict[str, Any]] = None): if initial_state is not None: self._datasets_exhausted = initial_state[self.DATASETS_EXHAUSTED_KEY] for k in range(self.num_datasets): - self.source_nodes[k].reset( - initial_state[self.DATASET_NODE_STATES_KEY][k] - ) + self.source_nodes[k].reset(initial_state[self.DATASET_NODE_STATES_KEY][k]) else: # Force a fresh iterator from all source nodes self._datasets_exhausted = [False for _ in range(self.num_datasets)] @@ -86,9 +84,7 @@ def _check_for_stop_iteration(self) -> None: # Raise StopIteration is StopCriteria is FIRST_DATASET_EXHAUSTED and # the first dataset is exhausted. 
@@ -87,9 +85,7 @@ def _check_for_stop_iteration(self) -> None:
         # Raise StopIteration if StopCriteria is FIRST_DATASET_EXHAUSTED and
         # any dataset is exhausted. Doing this here also correctly raises StopIteration
         # when next(it) is called on an already exhausted iterator
-        if self.stop_criteria == StopCriteria.FIRST_DATASET_EXHAUSTED and any(
-            self._datasets_exhausted
-        ):
+        if self.stop_criteria == StopCriteria.FIRST_DATASET_EXHAUSTED and any(self._datasets_exhausted):
             raise StopIteration()
 
         return
@@ -107,9 +103,7 @@ def next(self) -> T:
                 ):
                     # Before fetching a new item, check if the current dataset is already
                     # exhausted; if StopCriteria is ALL_DATASETS_EXHAUSTED, move to the next dataset
-                    self.current_dataset_index = (
-                        self.current_dataset_index + 1
-                    ) % self.num_datasets
+                    self.current_dataset_index = (self.current_dataset_index + 1) % self.num_datasets
                     continue
                 item = next(current_iterator)
             except StopIteration:
@@ -130,9 +124,7 @@ def next(self) -> T:
             break
 
         # If we didn't raise StopIteration, advance to the next dataset and return the item
-        self.current_dataset_index = (
-            self.current_dataset_index + 1
-        ) % self.num_datasets
+        self.current_dataset_index = (self.current_dataset_index + 1) % self.num_datasets
 
         return item
 
@@ -140,8 +132,6 @@ def get_state(self) -> Dict[str, Any]:
         state = {
             self.CURRENT_DATASET_INDEX_KEY: self.current_dataset_index,
             self.DATASETS_EXHAUSTED_KEY: copy.deepcopy(self._datasets_exhausted),
-            self.DATASET_NODE_STATES_KEY: {
-                k: self.source_nodes[k].state_dict() for k in range(self.num_datasets)
-            },
+            self.DATASET_NODE_STATES_KEY: {k: self.source_nodes[k].state_dict() for k in range(self.num_datasets)},
         }
         return state
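
The second commit only reflows code, so behavior is unchanged. One thing worth seeing end to end is the state protocol that test_get_state exercises: get_state() snapshots the current dataset index, the per-dataset exhaustion flags, and each source node's state, and reset(state) restores all three. A minimal checkpoint/resume sketch, again under the toy IterableWrapper-source assumptions used above (not part of the patch):

    from torchdata.nodes.adapters import IterableWrapper
    from torchdata.nodes.samplers.multi_node_round_robin_sampler import MultiNodeRoundRobinSampler

    source_nodes = {
        "a": IterableWrapper(["a0", "a1"]),
        "b": IterableWrapper(["b0", "b1"]),
    }
    # Default stop criteria: CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED
    sampler = MultiNodeRoundRobinSampler(source_nodes)
    sampler.reset()

    assert next(sampler) == "a0"
    checkpoint = sampler.get_state()  # snapshot taken; the next item would be "b0"

    assert next(sampler) == "b0"  # keep consuming past the snapshot

    sampler.reset(checkpoint)  # rewind to the snapshot
    assert next(sampler) == "b0"  # iteration resumes where the snapshot was taken
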