Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxiBoether committed Sep 26, 2023
1 parent ad42520 commit 2d53c0b
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,11 @@ def test_dataloader_dataset_weighted(
)
@patch("modyn.trainer_server.internal.dataset.online_dataset.grpc_connection_established", return_value=True)
@patch.object(grpc, "insecure_channel", return_value=None)
@patch.object(OnlineDataset, "_get_data_from_storage", return_value=([x.to_bytes(2, "big") for x in range(4)], [1] * 4))
@patch.object(
OnlineDataset,
"_get_data_from_storage",
return_value=[(list(range(4)), [x.to_bytes(2, "big") for x in range(4)], [1] * 4, 0)],
)
@patch.object(SelectorKeySource, "get_keys_and_weights", return_value=(list(range(4)), None))
@patch.object(SelectorKeySource, "get_num_data_partitions", return_value=1)
def test_dataloader_dataset_multi_worker(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def Get(self, request): # pylint: disable=invalid-name
@patch.object(
PerClassOnlineDataset,
"_get_data_from_storage",
return_value=([x.to_bytes(2, "big") for x in range(16)], [0, 1, 2, 3, 0, 0, 0, 1] * 2),
return_value=[(list(range(16)), [x.to_bytes(2, "big") for x in range(16)], [0, 1, 2, 3, 0, 0, 0, 1] * 2, 0)],
)
@patch.object(SelectorKeySource, "get_keys_and_weights", return_value=(list(range(16)), None))
@patch.object(SelectorKeySource, "get_num_data_partitions", return_value=1)
Expand Down
4 changes: 2 additions & 2 deletions modyn/trainer_server/internal/dataset/online_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def __iter__(self) -> Generator:
self._prefetched_partitions = min(self._prefetched_partitions, self._num_partitions)

for data_tuple in self.all_partition_generator(worker_id):
if data_tuple is not None: # Can happen in subclasses overwriting generator
yield self._get_transformed_data_tuple(*data_tuple)
if (transformed_tuple := self._get_transformed_data_tuple(*data_tuple)) is not None:
yield transformed_tuple

self._persist_log(worker_id)
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ def __init__(
assert initial_filtered_label is not None
self.filtered_label = initial_filtered_label

def _get_data_tuple(self, key: int, sample: bytes, label: int, weight: Optional[float]) -> Optional[Tuple]:
def _get_transformed_data_tuple(
self, key: int, sample: bytes, label: int, weight: Optional[float]
) -> Optional[Tuple]:
assert self.filtered_label is not None

if self.filtered_label != label:
Expand Down

0 comments on commit 2d53c0b

Please sign in to comment.