Skip to content

Commit

Permalink
update seeds
Browse files (browse the repository at this point in the history)
  • Loading branch information
ramanishsingh committed Jan 22, 2025
1 parent 1060f3c commit ee02a46
Showing 1 changed file with 40 additions and 17 deletions.
57 changes: 40 additions & 17 deletions test/nodes/test_multi_node_weighted_sampler.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -24,7 +24,6 @@ def setUp(self) -> None:
self._num_datasets = 4
self._weights_fn = lambda i: 0.1 * (i + 1)
self._num_epochs = 3
self._seed = 42
self.datasets = {
f"ds{i}": IterableWrapper(DummyIterableDataset(self._num_samples, f"ds{i}"))
for i in range(self._num_datasets)
Expand All @@ -39,7 +38,7 @@ def test_torchdata_nodes_imports(self) -> None:
self.fail("MultiNodeWeightedSampler or StopCriteria failed to import")

def _setup_multi_node_weighted_sampler(
self, num_samples, num_datasets, weights_fn, stop_criteria, seed
self, num_samples, num_datasets, weights_fn, stop_criteria, seed=0
) -> Prefetcher:

datasets = {f"ds{i}": IterableWrapper(DummyIterableDataset(num_samples, f"ds{i}")) for i in range(num_datasets)}
Expand All @@ -54,7 +53,8 @@ def test_multi_node_weighted_sampler_weight_sampler_keys_mismatch(self) -> None:
"keys of source_nodes and weights must be the same",
):
MultiNodeWeightedSampler(
self.datasets, {f"dummy{i}": self._weights_fn(i) for i in range(self._num_datasets)}, seed=self._seed
self.datasets,
{f"dummy{i}": self._weights_fn(i) for i in range(self._num_datasets)},
)

def test_multi_node_weighted_batch_sampler_invalid_weights_tensor_shape(
Expand All @@ -63,7 +63,8 @@ def test_multi_node_weighted_batch_sampler_invalid_weights_tensor_shape(
"""Test validation logic for MultiNodeWeightedSampler if the shape of the weights tensor is invalid"""
with self.assertRaisesRegex(ValueError, " weights must be a 1d sequence, non-negative, and non-zero"):
MultiNodeWeightedSampler(
self.datasets, weights={f"ds{i}": [[1.0]] for i in range(self._num_datasets)}, seed=self._seed
self.datasets,
weights={f"ds{i}": [[1.0]] for i in range(self._num_datasets)},
)

def test_multi_node_weighted_batch_sampler_negative_weights(
Expand All @@ -72,7 +73,8 @@ def test_multi_node_weighted_batch_sampler_negative_weights(
"""Test validation logic for MultiNodeWeightedSampler if the value of the weights tensor is invalid"""
with self.assertRaisesRegex(ValueError, " weights must be a 1d sequence, non-negative, and non-zero"):
MultiNodeWeightedSampler(
self.datasets, weights={f"ds{i}": -1 for i in range(self._num_datasets)}, seed=self._seed
self.datasets,
weights={f"ds{i}": -1 for i in range(self._num_datasets)},
)

def test_multi_node_weighted_batch_sampler_zero_weights(
Expand All @@ -81,10 +83,11 @@ def test_multi_node_weighted_batch_sampler_zero_weights(
"""Test validation logic for MultiNodeWeightedSampler if the value of the weights tensor is invalid"""
with self.assertRaisesRegex(ValueError, " weights must be a 1d sequence, non-negative, and non-zero"):
MultiNodeWeightedSampler(
self.datasets, weights={f"ds{i}": 10 * i for i in range(self._num_datasets)}, seed=self._seed
self.datasets,
weights={f"ds{i}": 10 * i for i in range(self._num_datasets)},
)

@parameterized.expand([0, 1, 42])
@parameterized.expand(range(10))
def test_multi_node_weighted_sampler_first_exhausted(self, seed) -> None:
"""Test MultiNodeWeightedSampler with stop criteria FIRST_DATASET_EXHAUSTED"""
mixer = self._setup_multi_node_weighted_sampler(
Expand All @@ -108,7 +111,7 @@ def test_multi_node_weighted_sampler_first_exhausted(self, seed) -> None:
self.assertGreaterEqual(dataset_counts_in_results.count(self._num_samples), 1)
mixer.reset()

@parameterized.expand([0, 1, 42])
@parameterized.expand(range(10))
def test_multi_node_weighted_sampler_all_dataset_exhausted(self, seed) -> None:
"""Test MultiNodeWeightedSampler with stop criteria ALL_DATASETS_EXHAUSTED"""
mixer = self._setup_multi_node_weighted_sampler(
Expand All @@ -135,7 +138,7 @@ def test_multi_node_weighted_sampler_all_dataset_exhausted(self, seed) -> None:
self.assertEqual(sorted(set(datasets_in_results)), ["ds0", "ds1", "ds2", "ds3"])
mixer.reset()

@parameterized.expand([0, 1, 42])
@parameterized.expand(range(10))
def test_multi_node_weighted_sampler_cycle_until_all_exhausted(self, seed) -> None:
"""Test MultiNodeWeightedSampler with stop criteria CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED"""
mixer = self._setup_multi_node_weighted_sampler(
Expand All @@ -154,7 +157,7 @@ def test_multi_node_weighted_sampler_cycle_until_all_exhausted(self, seed) -> No
self.assertEqual(sorted(datasets_in_results), ["ds0", "ds1", "ds2", "ds3"])
mixer.reset()

@parameterized.expand([0, 1, 42])
@parameterized.expand(range(10))
def test_multi_node_weighted_sampler_cycle_forever(self, seed) -> None:
"""Test MultiNodeWeightedSampler with stop criteria CYCLE_FOREVER"""
mixer = MultiNodeWeightedSampler(
Expand All @@ -173,7 +176,12 @@ def test_multi_node_weighted_sampler_cycle_forever(self, seed) -> None:
@parameterized.expand([(1, 8), (8, 32)])
def test_multi_node_weighted_batch_sampler_set_rank_world_size(self, rank, world_size):
"""Test MultiNodeWeightedSampler with different rank and world size"""
mixer = MultiNodeWeightedSampler(self.datasets, self.weights, rank=rank, world_size=world_size, seed=self._seed)
mixer = MultiNodeWeightedSampler(
self.datasets,
self.weights,
rank=rank,
world_size=world_size,
)
self.assertEqual(mixer.rank, rank)
self.assertEqual(mixer.world_size, world_size)

Expand All @@ -183,7 +191,10 @@ def test_multi_node_weighted_batch_sampler_results_for_ranks(self):
global_results = []
for rank in range(world_size):
mixer = MultiNodeWeightedSampler(
self.datasets, self.weights, rank=rank, world_size=world_size, seed=self._seed
self.datasets,
self.weights,
rank=rank,
world_size=world_size,
)
results = list(mixer)
global_results.append(results)
Expand All @@ -198,7 +209,10 @@ def test_multi_node_weighted_batch_sampler_results_for_multiple_epochs(self):
"""Test MultiNodeWeightedSampler with different results in each epoch"""

# Check for the mixer node only
mixer = MultiNodeWeightedSampler(self.datasets, self.weights, seed=self._seed)
mixer = MultiNodeWeightedSampler(
self.datasets,
self.weights,
)

overall_results = []
for _ in range(self._num_epochs):
Expand All @@ -219,7 +233,6 @@ def test_multi_node_weighted_batch_sampler_results_for_multiple_epochs(self):
self._num_datasets,
self._weights_fn,
stop_criteria=StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED,
seed=self._seed,
)

overall_results = []
Expand Down Expand Up @@ -247,7 +260,11 @@ def test_multi_node_weighted_batch_sampler_results_for_multiple_epochs(self):
)
def test_save_load_state_mixer_over_multiple_epochs(self, midpoint: int, stop_criteria: str):
"""Test MultiNodeWeightedSampler with saving and loading of state across multiple epochs"""
node = MultiNodeWeightedSampler(self.datasets, self.weights, stop_criteria, seed=self._seed)
node = MultiNodeWeightedSampler(
self.datasets,
self.weights,
stop_criteria,
)
run_test_save_load_state(self, node, midpoint)

@parameterized.expand(
Expand All @@ -262,7 +279,10 @@ def test_save_load_state_mixer_over_multiple_epochs(self, midpoint: int, stop_cr
)
def test_save_load_state_mixer_over_multiple_epochs_with_prefetcher(self, midpoint: int, stop_criteria: str):
node = self._setup_multi_node_weighted_sampler(
self._num_samples, self._num_datasets, self._weights_fn, stop_criteria=stop_criteria, seed=self._seed
self._num_samples,
self._num_datasets,
self._weights_fn,
stop_criteria=stop_criteria,
)
run_test_save_load_state(self, node, midpoint)

Expand All @@ -282,6 +302,9 @@ def test_multi_node_weighted_large_sample_size_with_prefetcher(self, midpoint, s
num_datasets = 5

mixer = self._setup_multi_node_weighted_sampler(
num_samples, num_datasets, self._weights_fn, stop_criteria, seed=self._seed
num_samples,
num_datasets,
self._weights_fn,
stop_criteria,
)
run_test_save_load_state(self, mixer, midpoint)

0 comments on commit ee02a46

Please sign in to comment.