diff --git a/modyn/tests/trainer_server/internal/data/test_online_dataset.py b/modyn/tests/trainer_server/internal/data/test_online_dataset.py index 3231e3a0a..fe88e7327 100644 --- a/modyn/tests/trainer_server/internal/data/test_online_dataset.py +++ b/modyn/tests/trainer_server/internal/data/test_online_dataset.py @@ -794,7 +794,9 @@ def test_iter_multi_partition_multi_workers( assert torch.equal(batch[0], torch.Tensor([0, 1, 2, 3])) assert torch.equal(batch[1], torch.Tensor([0, 1, 2, 3])) assert torch.equal(batch[2], torch.ones(4, dtype=int)) - assert idx == 7 + + # each worker gets 8 items from get_keys_and_weights; batch size 4; minus one for zero indexing + assert idx == ((min(num_workers, 1) * 32) / 4) - 1 @pytest.mark.parametrize("prefetched_partitions", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 100, 999999])