Multidataset round robin sampler #1370

Draft: wants to merge 2 commits into base: main
93 changes: 93 additions & 0 deletions test/nodes/test_multi_node_round_robin_sampler.py
@@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import collections

from parameterized import parameterized
from torch.testing._internal.common_utils import TestCase
from torchdata.nodes.adapters import IterableWrapper
from torchdata.nodes.batch import Batcher
from torchdata.nodes.prefetch import Prefetcher
from torchdata.nodes.samplers.multi_node_round_robin_sampler import MultiNodeRoundRobinSampler
from torchdata.nodes.samplers.stop_criteria import StopCriteria

from .utils import DummyIterableDataset, run_test_save_load_state


class TestMultiNodeRoundRobinSampler(TestCase):
def setUp(self) -> None:
super().setUp()
self._num_samples = 1
self._num_datasets = 3

def get_equal_dataset(self, num_samples, num_datasets):
datasets = {f"ds{i}": IterableWrapper(DummyIterableDataset(num_samples, f"ds{i}")) for i in range(num_datasets)}
return datasets

def get_unequal_dataset(self, num_samples, num_datasets):
datasets = {
f"ds{i}": IterableWrapper(DummyIterableDataset(num_samples + i, f"ds{i}")) for i in range(num_datasets)
}
return datasets

def test_multi_node_round_robin_sampler_equal_dataset(self) -> None:
datasets = self.get_equal_dataset(self._num_samples, self._num_datasets)
sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.FIRST_DATASET_EXHAUSTED)
batch_size = 3
num_epochs = 1
# each dataset has 1 sample, so the first and only epoch must be ['ds0', 'ds1', 'ds2']
batcher = Batcher(sampler, batch_size=batch_size)
for _ in range(num_epochs):
results = next(batcher)
self.assertGreater(len(results), 0)
datasets_in_results = [result["name"] for result in results]
dataset_counts_in_results = collections.Counter(datasets_in_results)
for key in dataset_counts_in_results:
self.assertEqual(dataset_counts_in_results[key], 1)

def test_multi_node_round_robin_sampler_unequal_dataset(self) -> None:
datasets = self.get_unequal_dataset(self._num_samples, self._num_datasets)
sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.ALL_DATASETS_EXHAUSTED)
batch_size = 3
num_epochs = 2
batcher = Batcher(sampler, batch_size=batch_size)
# In this case, first epoch must be ['ds0', 'ds1', 'ds2'] and second epoch must be ['ds1', 'ds2', 'ds2']
for epoch in range(num_epochs):
results = next(batcher)
self.assertGreater(len(results), 0)
datasets_in_results = [result["name"] for result in results]
dataset_counts_in_results = collections.Counter(datasets_in_results)
if epoch == 0:
self.assertEqual(len(dataset_counts_in_results), self._num_datasets)
for key in dataset_counts_in_results:
self.assertEqual(dataset_counts_in_results[key], 1)
            elif epoch == 1:
                self.assertEqual(len(dataset_counts_in_results), self._num_datasets - 1)
                # ds0 was exhausted after the first epoch, so only ds1 and ds2 should appear
                for key in dataset_counts_in_results:
                    if key == "ds0":
                        self.assertEqual(dataset_counts_in_results[key], 0)
                    elif key == "ds1":
                        self.assertEqual(dataset_counts_in_results[key], 1)
                    else:
                        self.assertEqual(dataset_counts_in_results[key], 2)

def test_get_state(self) -> None:
datasets = self.get_equal_dataset(self._num_samples, self._num_datasets)
sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.FIRST_DATASET_EXHAUSTED)
state = sampler.get_state()
self.assertIn("current_dataset_index", state)
self.assertIn("datasets_exhausted", state)
self.assertIn("dataset_node_states", state)

def test_multi_node_round_robin_large_sample_size(self) -> None:
num_samples = 1500
num_datasets = 3
datasets = self.get_equal_dataset(num_samples, num_datasets)
sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.ALL_DATASETS_EXHAUSTED)
prefetcher = Prefetcher(sampler, 3)
run_test_save_load_state(self, prefetcher, 400)
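
Reviewer note: `test_get_state` and `run_test_save_load_state` exercise checkpointing, so a minimal, self-contained sketch of a save/resume round trip may help when reading the tests. This is not part of the diff; it assumes `IterableWrapper` accepts a plain Python list and that `reset(state)` restores the saved dataset index and per-source positions.

```python
# Hedged sketch of checkpointing the sampler mid-stream and resuming; the
# helper construction mirrors the tests above and is an assumption, not PR code.
from torchdata.nodes.adapters import IterableWrapper
from torchdata.nodes.samplers.multi_node_round_robin_sampler import MultiNodeRoundRobinSampler
from torchdata.nodes.samplers.stop_criteria import StopCriteria


def make_sampler() -> MultiNodeRoundRobinSampler:
    # Three sources with ten items each, keyed ds0..ds2.
    datasets = {f"ds{i}": IterableWrapper(list(range(10))) for i in range(3)}
    return MultiNodeRoundRobinSampler(datasets, StopCriteria.FIRST_DATASET_EXHAUSTED)


sampler = make_sampler()
sampler.reset()
_ = [next(sampler) for _ in range(4)]  # consume a few items round-robin
state = sampler.get_state()  # current_dataset_index, datasets_exhausted, dataset_node_states

resumed = make_sampler()
resumed.reset(state)  # restore the round-robin position and per-source states
print(next(resumed))  # should continue where the original sampler left off
```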
136 changes: 136 additions & 0 deletions torchdata/nodes/samplers/multi_node_round_robin_sampler.py
@@ -0,0 +1,136 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
from typing import Any, Dict, Mapping, Optional

from torchdata.nodes.base_node import BaseNode, T
from torchdata.nodes.samplers.stop_criteria import StopCriteria


class MultiNodeRoundRobinSampler(BaseNode[T]):
"""A node that samples from multiple datasets in a round robin fashion.

This node expects to take in a dictionary of source nodes.

The node implements the state using the following keys:
- CURRENT_DATASET_INDEX_KEY: The index of the current dataset.
- DATASETS_EXHAUSTED_KEY: A dictionary of booleans indicating whether each source node is exhausted.
- DATASET_NODE_STATES_KEY: A dictionary of states for each source node.

We support multiple stopping criteria:
- CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED: Cycle through the source nodes until all datasets
are exhausted. This is the default behavior.
- FIRST_DATASET_EXHAUSTED: Stop when the first dataset is exhausted.
- ALL_DATASETS_EXHAUSTED: Stop when all datasets are exhausted.

#TODO: Add examples of usage and output
On complete exhaustion of the source nodes, the node will raise StopIteration.

Parameters:
source_nodes (Mapping[str, BaseNode[T]]): A dictionary of source nodes.
        stop_criteria (str): The stopping criteria. Default is CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED.
"""

CURRENT_DATASET_INDEX_KEY = "current_dataset_index"
DATASET_NODE_STATES_KEY = "dataset_node_states"
DATASETS_EXHAUSTED_KEY = "datasets_exhausted"

def __init__(
self,
source_nodes: Mapping[str, BaseNode[T]],
stop_criteria: str = StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED,
) -> None:
super().__init__()
        self.source_nodes = list(source_nodes.values())
self.num_datasets = len(self.source_nodes)
self.stop_criteria = stop_criteria
self.current_dataset_index = 0
self._validate_stop_criteria()
self._datasets_exhausted = [False for _ in range(self.num_datasets)]

def _validate_stop_criteria(self) -> None:
if self.stop_criteria not in [
StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED,
StopCriteria.ALL_DATASETS_EXHAUSTED,
StopCriteria.FIRST_DATASET_EXHAUSTED,
]:
raise ValueError(
f"Invalid {self.stop_criteria=}. stop_criteria must be one of: CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED, , ALL_DATASETS_EXHAUSTED"
)

def reset(self, initial_state: Optional[Dict[str, Any]] = None):
super().reset(initial_state)

        if initial_state is not None:
            self.current_dataset_index = initial_state[self.CURRENT_DATASET_INDEX_KEY]
            self._datasets_exhausted = initial_state[self.DATASETS_EXHAUSTED_KEY]
            for k in range(self.num_datasets):
                self.source_nodes[k].reset(initial_state[self.DATASET_NODE_STATES_KEY][k])
        else:
            # Force a fresh iterator from all source nodes
            self.current_dataset_index = 0
            self._datasets_exhausted = [False for _ in range(self.num_datasets)]
            for k in range(self.num_datasets):
                self.source_nodes[k].reset()

def _check_for_stop_iteration(self) -> None:
if all(self._datasets_exhausted):
# Raise StopIteration if all datasets are exhausted,
# this covers for both ALL_DATASETS_EXHAUSTED and CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED
raise StopIteration()

        # Raise StopIteration if stop_criteria is FIRST_DATASET_EXHAUSTED and
        # any dataset is exhausted. Doing this here also correctly surfaces StopIteration
        # when next() is called on an iterator that is already exhausted
if self.stop_criteria == StopCriteria.FIRST_DATASET_EXHAUSTED and any(self._datasets_exhausted):
raise StopIteration()

return

def next(self) -> T:

while True:
self._check_for_stop_iteration()

current_iterator = self.source_nodes[self.current_dataset_index]
try:
if (
self._datasets_exhausted[self.current_dataset_index]
and self.stop_criteria == StopCriteria.ALL_DATASETS_EXHAUSTED
):
                    # Before fetching a new item, check whether the current dataset is already
                    # exhausted; under ALL_DATASETS_EXHAUSTED, skip it and move to the next dataset
self.current_dataset_index = (self.current_dataset_index + 1) % self.num_datasets
continue
item = next(current_iterator)
except StopIteration:
# Mark the dataset as exhausted
self._datasets_exhausted[self.current_dataset_index] = True

# Based on updated _check_for_stop_iteration, check if we should raise StopIteration
self._check_for_stop_iteration()

# If StopCriteria is ALL_DATASETS_EXHAUSTED, move to next dataset
if self.stop_criteria == StopCriteria.ALL_DATASETS_EXHAUSTED:
continue

# If StopCriteria is CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED,
# reset the iterator and try again
self.source_nodes[self.current_dataset_index].reset()
item = next(self.source_nodes[self.current_dataset_index])
break

        # If we didn't raise StopIteration, advance to the next dataset and return the item
self.current_dataset_index = (self.current_dataset_index + 1) % self.num_datasets

return item

def get_state(self) -> Dict[str, Any]:
state = {
self.CURRENT_DATASET_INDEX_KEY: self.current_dataset_index,
self.DATASETS_EXHAUSTED_KEY: copy.deepcopy(self._datasets_exhausted),
self.DATASET_NODE_STATES_KEY: {k: self.source_nodes[k].state_dict() for k in range(self.num_datasets)},
}
return state
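
The docstring above still carries a TODO for usage examples. As a rough illustration only (not part of this diff), the three stop criteria should differ roughly as follows on two sources of length 1 and 2; the exact orderings follow from a reading of this draft's skipping and reset logic and are an assumption, as is wrapping plain lists in `IterableWrapper`.

```python
# Hedged sketch contrasting the stop criteria; the printed orderings are
# assumptions based on how this draft skips and restarts exhausted sources.
from torchdata.nodes.adapters import IterableWrapper
from torchdata.nodes.samplers.multi_node_round_robin_sampler import MultiNodeRoundRobinSampler
from torchdata.nodes.samplers.stop_criteria import StopCriteria


def collect(criteria: str, limit: int = 10) -> list:
    # Source "a" has one item, source "b" has two.
    sources = {"a": IterableWrapper(["a0"]), "b": IterableWrapper(["b0", "b1"])}
    sampler = MultiNodeRoundRobinSampler(sources, criteria)
    sampler.reset()
    out = []
    for _ in range(limit):
        try:
            out.append(next(sampler))
        except StopIteration:
            break
    return out


# FIRST_DATASET_EXHAUSTED: stop as soon as any source runs out -> ['a0', 'b0']
print(collect(StopCriteria.FIRST_DATASET_EXHAUSTED))
# ALL_DATASETS_EXHAUSTED: skip exhausted sources until all are done -> ['a0', 'b0', 'b1']
print(collect(StopCriteria.ALL_DATASETS_EXHAUSTED))
# CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED: restart exhausted sources until every one
# has run out at least once -> likely ['a0', 'b0', 'a0', 'b1', 'a0']
print(collect(StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED))
```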