diff --git a/integrationtests/online_dataset/test_online_dataset.py b/integrationtests/online_dataset/test_online_dataset.py
index 7b83480f9..9b11339e4 100644
--- a/integrationtests/online_dataset/test_online_dataset.py
+++ b/integrationtests/online_dataset/test_online_dataset.py
@@ -1,4 +1,3 @@
-import io
 import json
 import math
 import os
@@ -12,11 +11,7 @@
 import modyn.storage.internal.grpc.generated.storage_pb2 as storage_pb2
 import torch
 import yaml
-from modyn.selector.internal.grpc.generated.selector_pb2 import (
-    DataInformRequest,
-    JsonString,
-    RegisterPipelineRequest,
-)
+from modyn.selector.internal.grpc.generated.selector_pb2 import DataInformRequest, JsonString, RegisterPipelineRequest
 from modyn.selector.internal.grpc.generated.selector_pb2_grpc import SelectorStub
 from modyn.storage.internal.grpc.generated.storage_pb2 import (
     DatasetAvailableRequest,
@@ -258,10 +253,10 @@ def get_data_keys() -> list[int]:
 
 def get_bytes_parser() -> str:
     return """
-    from PIL import Image
-    import io
-    def bytes_parser_function(data: bytes) -> Image:
-        return Image.open(io.BytesIO(data)).convert("RGB")"""
+from PIL import Image
+import io
+def bytes_parser_function(data: bytes) -> Image:
+    return Image.open(io.BytesIO(data)).convert("RGB")"""
 
 
 def tensor_in_list(tensor: torch.Tensor, tensor_list: list[torch.Tensor]) -> bool:
@@ -292,7 +287,10 @@ def test_dataset_impl(
         None,
     )
 
-    expected_batches = math.ceil(len(items) / batch_size)
+    expected_min_batches = math.floor(len(items) / batch_size)
+    # max one excess batch per worker
+    expected_max_batches = expected_min_batches if num_dataworkers <= 1 else expected_min_batches + num_dataworkers
+
     all_samples = []
     all_data = []
     all_labels = []
@@ -308,27 +306,32 @@ def test_dataset_impl(
         assert isinstance(batch[1], torch.Tensor) and isinstance(batch[2], torch.Tensor)
 
         all_samples.extend(sample_ids)
-        all_data.extend(batch[1].tolist())
+        for sample in batch[1]:
+            all_data.append(sample)  # iterate over batch dimension to extract samples
         all_labels.extend(batch[2].tolist())
 
     assert len(all_samples) == len(items)
+    assert len(all_labels) == len(items)
     assert len(all_data) == len(items)
-    assert len(all_data) == len(items)
-    assert batch_number + 1 == expected_batches, (
+
+    assert expected_min_batches <= batch_number + 1 <= expected_max_batches, (
         f"[{num_dataworkers}][{batch_size}][{prefetched_partitions}]"
-        + f"Wrong number of batches: {batch_number + 1}. num_items = {len(items)}"
+        + f"Wrong number of batches: {batch_number + 1}. num_items = {len(items)}."
+        + f"expected_min = {expected_min_batches}, expected_max = {expected_max_batches}"
    )
 
     assert set(all_samples) == set(items)
     assert set(all_labels) == set(range(len(items)))
 
-    trans = transforms.Compose([transforms.ToTensor()])
+    trans = transforms.Compose([transforms.ToPILImage()])
+
+    assert len(FIRST_ADDED_IMAGES) == len(all_data)
 
-    for idx, image in enumerate(FIRST_ADDED_IMAGES):
-        parsed_image = trans(Image.open(io.BytesIO(image)))
-        assert tensor_in_list(
-            parsed_image, all_data
-        ), f"Could not find image {idx} in all_data, all_samples = {all_samples}"
+    for idx, image_tensor in enumerate(all_data):
+        pil_image = trans(image_tensor).convert("RGB")
+        image_bytes = pil_image.tobytes()
+        if image_bytes not in FIRST_ADDED_IMAGES:
+            raise ValueError(f"Could not find image {idx} in created images, all_samples = {all_samples}")
 
 
 def test_dataset() -> None:
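
Two notes on the patch, for reviewers. First, `get_bytes_parser` is dedented to column 0 likely because the string is compiled with `exec`/`compile`, which rejects source whose first statement is indented. Second, the exact batch-count assertion becomes a range check because, with multiple dataloader workers over an iterable-style dataset, each worker batches its own shard independently and may emit one trailing partial batch. Below is a minimal, self-contained sketch of that second effect; `ShardedRange` and the constants are hypothetical illustration code, not Modyn's implementation.

```python
# Sketch: why the test accepts [min_batches, min_batches + num_workers] batches.
import math

import torch
from torch.utils.data import DataLoader, IterableDataset, get_worker_info


class ShardedRange(IterableDataset):
    """Yields 0..n-1, split contiguously across DataLoader workers."""

    def __init__(self, n: int) -> None:
        self.n = n

    def __iter__(self):
        info = get_worker_info()
        if info is None:  # single-process loading: one shard, no excess batches
            return iter(range(self.n))
        per_worker = math.ceil(self.n / info.num_workers)
        start = info.id * per_worker
        return iter(range(start, min(start + per_worker, self.n)))


if __name__ == "__main__":
    n, batch_size, num_workers = 10, 4, 2
    loader = DataLoader(ShardedRange(n), batch_size=batch_size, num_workers=num_workers)
    batches = list(loader)
    # floor(10 / 4) = 2, but each of the 2 workers batches its 5 items as 4 + 1,
    # so 4 batches arrive -- each worker contributed one trailing partial batch.
    min_batches = n // batch_size
    assert min_batches <= len(batches) <= min_batches + num_workers
    print(len(batches))  # 4
```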