Commit da4f60f (1 parent: 6700421), showing 4 changed files with 379 additions and 5 deletions.
@@ -0,0 +1,366 @@
import io
import json
import math
import os
import pathlib
import random
import shutil
import time
from typing import Iterable, Tuple

import grpc
import modyn.storage.internal.grpc.generated.storage_pb2 as storage_pb2
import torch
import yaml
from modyn.selector.internal.grpc.generated.selector_pb2 import (
    DataInformRequest,
    JsonString,
    RegisterPipelineRequest,
)
from modyn.selector.internal.grpc.generated.selector_pb2_grpc import SelectorStub
from modyn.storage.internal.grpc.generated.storage_pb2 import (
    DatasetAvailableRequest,
    GetDatasetSizeRequest,
    GetDatasetSizeResponse,
    GetNewDataSinceRequest,
    GetNewDataSinceResponse,
    RegisterNewDatasetRequest,
)
from modyn.storage.internal.grpc.generated.storage_pb2_grpc import StorageStub
from modyn.trainer_server.internal.dataset.data_utils import prepare_dataloaders
from modyn.utils import grpc_connection_established
from PIL import Image
from torchvision import transforms

SCRIPT_PATH = pathlib.Path(os.path.realpath(__file__))

TIMEOUT = 120  # seconds
CONFIG_FILE = SCRIPT_PATH.parent.parent.parent / "modyn" / "config" / "examples" / "modyn_config.yaml"
# The following path leads to a directory that is mounted into the docker container and shared with the
# storage container.
DATASET_PATH = pathlib.Path("/app") / "storage" / "datasets" / "test_dataset"
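
# For illustration only: the volume sharing described in the comment above might look
# roughly like the following docker-compose fragment (hypothetical service and volume
# names, not taken from this repository's actual compose file):
#
#   services:
#     storage:
#       volumes:
#         - datasets:/app/storage/datasets
#     tests:
#       volumes:
#         - datasets:/app/storage/datasets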

# Because the mapping from file to key happens inside the storage service, we have to keep
# track of the images we add to the dataset ourselves and compare them to the images we get
# back from the storage service.
FIRST_ADDED_IMAGES = []
SECOND_ADDED_IMAGES = []
IMAGE_UPDATED_TIME_STAMPS = []


def get_modyn_config() -> dict:
    with open(CONFIG_FILE, "r", encoding="utf-8") as config_file:
        config = yaml.safe_load(config_file)

    return config


def connect_to_selector_servicer() -> grpc.Channel:
    selector_address = get_selector_address()
    selector_channel = grpc.insecure_channel(selector_address)

    if not grpc_connection_established(selector_channel):
        raise ConnectionError(f"Could not establish gRPC connection to selector at {selector_address}.")

    return selector_channel


def get_storage_address() -> str:
    config = get_modyn_config()
    return f"{config['storage']['hostname']}:{config['storage']['port']}"


def get_selector_address() -> str:
    config = get_modyn_config()
    return f"{config['selector']['hostname']}:{config['selector']['port']}"
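
# The two helpers above only read the storage and selector endpoints. A minimal sketch of
# the modyn_config.yaml fragment they assume (illustrative values, not the shipped
# example config):
#
#   storage:
#     hostname: storage
#     port: 50051
#   selector:
#     hostname: selector
#     port: 50056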


def connect_to_storage() -> grpc.Channel:
    storage_address = get_storage_address()
    storage_channel = grpc.insecure_channel(storage_address)

    # grpc.insecure_channel never returns None, so checking the connection is sufficient.
    if not grpc_connection_established(storage_channel):
        raise ConnectionError(f"Could not establish gRPC connection to storage at {storage_address}.")

    return storage_channel


def register_new_dataset() -> None:
    storage_channel = connect_to_storage()

    storage = StorageStub(storage_channel)

    request = RegisterNewDatasetRequest(
        base_path=str(DATASET_PATH),
        dataset_id="test_dataset",
        description="Test dataset for integration tests.",
        file_wrapper_config=json.dumps({"file_extension": ".png", "label_file_extension": ".txt"}),
        file_wrapper_type="SingleSampleFileWrapper",
        filesystem_wrapper_type="LocalFilesystemWrapper",
        version="0.1.0",
    )

    response = storage.RegisterNewDataset(request)

    assert response.success, "Could not register new dataset."


def check_dataset_availability() -> None:
    storage_channel = connect_to_storage()

    storage = StorageStub(storage_channel)

    request = DatasetAvailableRequest(dataset_id="test_dataset")
    response = storage.CheckAvailability(request)

    assert response.available, "Dataset is not available."


def check_dataset_size(expected_size: int) -> None:
    storage_channel = connect_to_storage()

    storage = StorageStub(storage_channel)
    request = GetDatasetSizeRequest(dataset_id="test_dataset")
    response: GetDatasetSizeResponse = storage.GetDatasetSize(request)

    assert response.success, "Dataset is not available."
    assert response.num_keys == expected_size


def check_dataset_size_invalid() -> None:
    storage_channel = connect_to_storage()

    storage = StorageStub(storage_channel)
    request = GetDatasetSizeRequest(dataset_id="unknown_dataset")
    response: GetDatasetSizeResponse = storage.GetDatasetSize(request)

    assert not response.success, "Dataset is available (even though it should not be)."


def check_get_current_timestamp() -> None:
    storage_channel = connect_to_storage()
    storage = StorageStub(storage_channel)
    empty = storage_pb2.google_dot_protobuf_dot_empty__pb2.Empty()
    response = storage.GetCurrentTimestamp(empty)

    assert response.timestamp > 0, "Timestamp is not valid."


def create_dataset_dir() -> None:
    pathlib.Path(DATASET_PATH).mkdir(parents=True, exist_ok=True)


def cleanup_dataset_dir() -> None:
    shutil.rmtree(DATASET_PATH)


def cleanup_storage_database() -> None:
    storage_channel = connect_to_storage()
    storage = StorageStub(storage_channel)
    request = DatasetAvailableRequest(dataset_id="test_dataset")
    response = storage.DeleteDataset(request)

    assert response.success, "Could not cleanup storage database."


def add_image_to_dataset(image: Image.Image, name: str) -> None:
    image.save(DATASET_PATH / name)
    IMAGE_UPDATED_TIME_STAMPS.append(int(round(os.path.getmtime(DATASET_PATH / name) * 1000)))


def create_random_image() -> Image.Image:
    image = Image.new("RGB", (100, 100))
    random_x = random.randint(0, 99)
    random_y = random.randint(0, 99)

    random_r = random.randint(0, 254)
    random_g = random.randint(0, 254)
    random_b = random.randint(0, 254)

    image.putpixel((random_x, random_y), (random_r, random_g, random_b))

    return image


def add_images_to_dataset(start_number: int, end_number: int, images_added: list[bytes]) -> None:
    create_dataset_dir()

    for i in range(start_number, end_number):
        image = create_random_image()
        add_image_to_dataset(image, f"image_{i}.png")
        # Track the PNG-encoded bytes rather than image.tobytes(): tobytes() yields raw
        # pixel data without a header, which Image.open() in test_dataset_impl cannot parse.
        with open(DATASET_PATH / f"image_{i}.png", "rb") as image_file:
            images_added.append(image_file.read())
        with open(DATASET_PATH / f"image_{i}.txt", "w", encoding="utf-8") as label_file:
            label_file.write(f"{i}")
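
# For reference, the on-disk layout produced above, which the SingleSampleFileWrapper is
# assumed to consume (one sample per .png file, with the label in a sidecar .txt file):
#
#   /app/storage/datasets/test_dataset/
#       image_0.png  image_0.txt  (contains "0")
#       image_1.png  image_1.txt  (contains "1")
#       ...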


def prepare_selector(num_dataworkers: int, keys: list[int]) -> Tuple[int, int]:
    selector_channel = connect_to_selector_servicer()
    selector = SelectorStub(selector_channel)
    # We test the NewDataStrategy for finetuning on the new data, i.e., we reset without limit.
    # We also enforce high partitioning (maximum_keys_in_memory == 2) to ensure that partitioning works.

    strategy_config = {
        "name": "NewDataStrategy",
        "maximum_keys_in_memory": 2,
        "config": {"limit": -1, "reset_after_trigger": True},
    }

    pipeline_id = selector.register_pipeline(
        RegisterPipelineRequest(
            num_workers=max(num_dataworkers, 1), selection_strategy=JsonString(value=json.dumps(strategy_config))
        )
    ).pipeline_id

    trigger_id = selector.inform_data_and_trigger(
        DataInformRequest(
            pipeline_id=pipeline_id,
            keys=keys,
            timestamps=[2 for _ in range(len(keys))],
            labels=[3 for _ in range(len(keys))],
        )
    ).trigger_id

    return pipeline_id, trigger_id
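
# Worked example of the partitioning implied above (an inference from the strategy config,
# not verified against the selector internals): with maximum_keys_in_memory == 2 and the
# 10 keys this test registers, each trigger is served in ceil(10 / 2) = 5 partitions.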


def get_new_data_since(timestamp: int) -> Iterable[GetNewDataSinceResponse]:
    storage_channel = connect_to_storage()

    storage = StorageStub(storage_channel)

    request = GetNewDataSinceRequest(
        dataset_id="test_dataset",
        timestamp=timestamp,
    )

    responses = storage.GetNewDataSince(request)
    return responses


def get_data_keys() -> list[int]:
    response = None
    keys = []
    # Poll the storage service for up to 60 seconds until all 10 images are ingested.
    for _ in range(60):
        responses = list(get_new_data_since(0))
        assert len(responses) < 2, f"Received batched response, shouldn't happen: {responses}"
        if len(responses) == 1:
            response = responses[0]
            keys = list(response.keys)
            if len(keys) == 10:
                break
        time.sleep(1)

    assert response is not None, "Did not get any response from Storage"
    assert len(keys) == 10, f"Not all images were returned. Images returned: {response.keys}"

    return keys


def get_bytes_parser() -> str:
    return """
from PIL import Image
import io
def bytes_parser_function(data: bytes) -> Image:
    return Image.open(io.BytesIO(data)).convert("RGB")"""
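
# The string above is shipped as source code and evaluated by the consumer. A minimal
# sketch of how such a snippet can be turned into a callable (an assumption for
# illustration; the actual mechanism lives in the trainer server, not in this test):
#
#   namespace: dict = {}
#   exec(get_bytes_parser(), namespace)
#   image = namespace["bytes_parser_function"](raw_png_bytes)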


def tensor_in_list(tensor: torch.Tensor, tensor_list: list[torch.Tensor]) -> bool:
    return any((tensor == c_).all() for c_ in tensor_list)


def test_dataset_impl(
    num_dataworkers: int,
    batch_size: int,
    prefetched_partitions: int,
    pipeline_id: int,
    trigger_id: int,
    items: list[int],
) -> None:
    dataloader, _ = prepare_dataloaders(
        pipeline_id,
        trigger_id,
        "test_dataset",
        num_dataworkers,
        batch_size,
        get_bytes_parser(),
        ["transforms.ToTensor()"],
        get_storage_address(),
        get_selector_address(),
        42,
        prefetched_partitions,
        None,
        None,
    )

    expected_batches = math.ceil(len(items) / batch_size)
    all_samples = []
    all_data = []
    all_labels = []

    for batch_number, batch in enumerate(dataloader):
        # Each batch is a triple of (sample keys, transformed images, labels).
        sample_ids = batch[0]
        if isinstance(sample_ids, torch.Tensor):
            sample_ids = sample_ids.tolist()
        elif isinstance(sample_ids, tuple):
            sample_ids = list(sample_ids)

        assert isinstance(sample_ids, list), "Cannot parse result from DataLoader"
        assert isinstance(batch[1], torch.Tensor) and isinstance(batch[2], torch.Tensor)

        all_samples.extend(sample_ids)
        all_data.extend(batch[1].tolist())
        all_labels.extend(batch[2].tolist())

    assert len(all_samples) == len(items)
    assert len(all_data) == len(items)
    assert len(all_labels) == len(items)
    assert batch_number + 1 == expected_batches, (
        f"[{num_dataworkers}][{batch_size}][{prefetched_partitions}] "
        + f"Wrong number of batches: {batch_number + 1}. num_items = {len(items)}"
    )

    assert set(all_samples) == set(items)
    assert set(all_labels) == set(range(len(items)))

    trans = transforms.Compose([transforms.ToTensor()])

    for idx, image in enumerate(FIRST_ADDED_IMAGES):
        parsed_image = trans(Image.open(io.BytesIO(image)))
        assert tensor_in_list(
            parsed_image, all_data
        ), f"Could not find image {idx} in all_data, all_samples = {all_samples}"


def test_dataset() -> None:
    NUM_IMAGES = 10

    check_get_current_timestamp()  # Check if the storage service is available.
    create_dataset_dir()
    add_images_to_dataset(0, NUM_IMAGES, FIRST_ADDED_IMAGES)  # Add images to the dataset.
    register_new_dataset()
    check_dataset_availability()  # Check if the dataset is available.
    check_dataset_size_invalid()

    keys = get_data_keys()
    check_dataset_size(NUM_IMAGES)  # All images have been ingested at this point.

    for num_dataworkers in [0, 1, 2, 4, 8, 16]:
        pipeline_id, trigger_id = prepare_selector(num_dataworkers, keys)
        for prefetched_partitions in [0, 1, 2, 3, 4, 5, 999]:
            for batch_size in [1, 2, 10]:
                print(
                    f"Testing num_workers = {num_dataworkers}, partitions = {prefetched_partitions}, "
                    + f"batch_size = {batch_size}"
                )
                test_dataset_impl(num_dataworkers, batch_size, prefetched_partitions, pipeline_id, trigger_id, keys)


def main() -> None:
    try:
        test_dataset()
    finally:
        cleanup_dataset_dir()
        cleanup_storage_database()


if __name__ == "__main__":
    main()