diff --git a/modyn/tests/trainer_server/internal/data/test_online_dataset.py b/modyn/tests/trainer_server/internal/data/test_online_dataset.py index 1fe9ed74e..3231e3a0a 100644 --- a/modyn/tests/trainer_server/internal/data/test_online_dataset.py +++ b/modyn/tests/trainer_server/internal/data/test_online_dataset.py @@ -447,7 +447,11 @@ def test_dataloader_dataset_weighted( ) @patch("modyn.trainer_server.internal.dataset.online_dataset.grpc_connection_established", return_value=True) @patch.object(grpc, "insecure_channel", return_value=None) -@patch.object(OnlineDataset, "_get_data_from_storage", return_value=([x.to_bytes(2, "big") for x in range(4)], [1] * 4)) +@patch.object( + OnlineDataset, + "_get_data_from_storage", + return_value=[(list(range(4)), [x.to_bytes(2, "big") for x in range(4)], [1] * 4, 0)], +) @patch.object(SelectorKeySource, "get_keys_and_weights", return_value=(list(range(4)), None)) @patch.object(SelectorKeySource, "get_num_data_partitions", return_value=1) def test_dataloader_dataset_multi_worker( diff --git a/modyn/tests/trainer_server/internal/data/test_per_class_online_dataset.py b/modyn/tests/trainer_server/internal/data/test_per_class_online_dataset.py index 5c9cf10d2..26b3dd3fc 100644 --- a/modyn/tests/trainer_server/internal/data/test_per_class_online_dataset.py +++ b/modyn/tests/trainer_server/internal/data/test_per_class_online_dataset.py @@ -47,7 +47,7 @@ def Get(self, request): # pylint: disable=invalid-name @patch.object( PerClassOnlineDataset, "_get_data_from_storage", - return_value=([x.to_bytes(2, "big") for x in range(16)], [0, 1, 2, 3, 0, 0, 0, 1] * 2), + return_value=[(list(range(16)), [x.to_bytes(2, "big") for x in range(16)], [0, 1, 2, 3, 0, 0, 0, 1] * 2, 0)], ) @patch.object(SelectorKeySource, "get_keys_and_weights", return_value=(list(range(16)), None)) @patch.object(SelectorKeySource, "get_num_data_partitions", return_value=1) diff --git a/modyn/trainer_server/internal/dataset/online_dataset.py b/modyn/trainer_server/internal/dataset/online_dataset.py index 442debaac..0d6f2957c 100644 --- a/modyn/trainer_server/internal/dataset/online_dataset.py +++ b/modyn/trainer_server/internal/dataset/online_dataset.py @@ -377,7 +377,7 @@ def __iter__(self) -> Generator: self._prefetched_partitions = min(self._prefetched_partitions, self._num_partitions) for data_tuple in self.all_partition_generator(worker_id): - if data_tuple is not None: # Can happen in subclasses overwriting generator - yield self._get_transformed_data_tuple(*data_tuple) + if (transformed_tuple := self._get_transformed_data_tuple(*data_tuple)) is not None: + yield transformed_tuple self._persist_log(worker_id) diff --git a/modyn/trainer_server/internal/dataset/per_class_online_dataset.py b/modyn/trainer_server/internal/dataset/per_class_online_dataset.py index 40404c1a5..5a3c6fbec 100644 --- a/modyn/trainer_server/internal/dataset/per_class_online_dataset.py +++ b/modyn/trainer_server/internal/dataset/per_class_online_dataset.py @@ -39,7 +39,9 @@ def __init__( assert initial_filtered_label is not None self.filtered_label = initial_filtered_label - def _get_data_tuple(self, key: int, sample: bytes, label: int, weight: Optional[float]) -> Optional[Tuple]: + def _get_transformed_data_tuple( + self, key: int, sample: bytes, label: int, weight: Optional[float] + ) -> Optional[Tuple]: assert self.filtered_label is not None if self.filtered_label != label: