Commit

add onlinedataset integrationtest
MaxiBoether committed Sep 26, 2023
1 parent 6700421 commit da4f60f
Showing 4 changed files with 379 additions and 5 deletions.
366 changes: 366 additions & 0 deletions integrationtests/online_dataset/test_online_dataset.py
@@ -0,0 +1,366 @@
import io
import json
import math
import os
import pathlib
import random
import shutil
import time
from typing import Iterable, Tuple

import grpc
import modyn.storage.internal.grpc.generated.storage_pb2 as storage_pb2
import torch
import yaml
from modyn.selector.internal.grpc.generated.selector_pb2 import (
    DataInformRequest,
    JsonString,
    RegisterPipelineRequest,
)
from modyn.selector.internal.grpc.generated.selector_pb2_grpc import SelectorStub
from modyn.storage.internal.grpc.generated.storage_pb2 import (
    DatasetAvailableRequest,
    GetDatasetSizeRequest,
    GetDatasetSizeResponse,
    GetNewDataSinceRequest,
    GetNewDataSinceResponse,
    RegisterNewDatasetRequest,
)
from modyn.storage.internal.grpc.generated.storage_pb2_grpc import StorageStub
from modyn.trainer_server.internal.dataset.data_utils import prepare_dataloaders
from modyn.utils import grpc_connection_established
from PIL import Image
from torchvision import transforms

SCRIPT_PATH = pathlib.Path(os.path.realpath(__file__))

TIMEOUT = 120 # seconds
CONFIG_FILE = SCRIPT_PATH.parent.parent.parent / "modyn" / "config" / "examples" / "modyn_config.yaml"
# The following path leads to a directory that is mounted into the docker container and shared with the
# storage container.
DATASET_PATH = pathlib.Path("/app") / "storage" / "datasets" / "test_dataset"

# Because we have no mapping of file to key (happens in the storage service), we have to keep
# track of the images we added to the dataset ourselves and compare them to the images we get
# from the storage service.
FIRST_ADDED_IMAGES = []
SECOND_ADDED_IMAGES = []
IMAGE_UPDATED_TIME_STAMPS = []


def get_modyn_config() -> dict:
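    """Load the Modyn example configuration used by the integration tests."""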
    with open(CONFIG_FILE, "r", encoding="utf-8") as config_file:
        config = yaml.safe_load(config_file)

    return config


def connect_to_selector_servicer() -> grpc.Channel:
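    """Open an insecure gRPC channel to the selector and fail fast if it cannot be reached."""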
    selector_address = get_selector_address()
    selector_channel = grpc.insecure_channel(selector_address)

    if not grpc_connection_established(selector_channel):
        raise ConnectionError(f"Could not establish gRPC connection to selector at {selector_address}.")

    return selector_channel


def get_storage_address() -> str:
    config = get_modyn_config()
    return f"{config['storage']['hostname']}:{config['storage']['port']}"


def get_selector_address() -> str:
    config = get_modyn_config()
    return f"{config['selector']['hostname']}:{config['selector']['port']}"


def connect_to_storage() -> grpc.Channel:
    storage_address = get_storage_address()
    storage_channel = grpc.insecure_channel(storage_address)

    if not grpc_connection_established(storage_channel) or storage_channel is None:
        raise ConnectionError(f"Could not establish gRPC connection to storage at {storage_address}.")

    return storage_channel


def register_new_dataset() -> None:
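    """Register the test dataset at the storage service, using single-sample PNG files with .txt label files."""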
    storage_channel = connect_to_storage()

    storage = StorageStub(storage_channel)

    request = RegisterNewDatasetRequest(
        base_path=str(DATASET_PATH),
        dataset_id="test_dataset",
        description="Test dataset for integration tests.",
        file_wrapper_config=json.dumps({"file_extension": ".png", "label_file_extension": ".txt"}),
        file_wrapper_type="SingleSampleFileWrapper",
        filesystem_wrapper_type="LocalFilesystemWrapper",
        version="0.1.0",
    )

    response = storage.RegisterNewDataset(request)

    assert response.success, "Could not register new dataset."


def check_dataset_availability() -> None:
    storage_channel = connect_to_storage()

    storage = StorageStub(storage_channel)

    request = DatasetAvailableRequest(dataset_id="test_dataset")
    response = storage.CheckAvailability(request)

    assert response.available, "Dataset is not available."


def check_dataset_size(expected_size: int) -> None:
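    """Assert that the storage service reports the expected number of keys for the test dataset."""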
    storage_channel = connect_to_storage()

    storage = StorageStub(storage_channel)
    request = GetDatasetSizeRequest(dataset_id="test_dataset")
    response: GetDatasetSizeResponse = storage.GetDatasetSize(request)

    assert response.success, "Dataset is not available."
    assert response.num_keys == expected_size


def check_dataset_size_invalid() -> None:
    storage_channel = connect_to_storage()

    storage = StorageStub(storage_channel)
    request = GetDatasetSizeRequest(dataset_id="unknown_dataset")
    response: GetDatasetSizeResponse = storage.GetDatasetSize(request)

    assert not response.success, "Dataset is available (even though it should not be)."


def check_get_current_timestamp() -> None:
    storage_channel = connect_to_storage()
    storage = StorageStub(storage_channel)
    empty = storage_pb2.google_dot_protobuf_dot_empty__pb2.Empty()
    response = storage.GetCurrentTimestamp(empty)

    assert response.timestamp > 0, "Timestamp is not valid."


def create_dataset_dir() -> None:
    pathlib.Path(DATASET_PATH).mkdir(parents=True, exist_ok=True)


def cleanup_dataset_dir() -> None:
    shutil.rmtree(DATASET_PATH)


def cleanup_storage_database() -> None:
    storage_channel = connect_to_storage()
    storage = StorageStub(storage_channel)
    request = DatasetAvailableRequest(dataset_id="test_dataset")
    response = storage.DeleteDataset(request)

    assert response.success, "Could not cleanup storage database."


def add_image_to_dataset(image: Image, name: str) -> None:
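    """Save the image into the dataset directory and record its modification timestamp in milliseconds."""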
    image.save(DATASET_PATH / name)
    IMAGE_UPDATED_TIME_STAMPS.append(int(round(os.path.getmtime(DATASET_PATH / name) * 1000)))


def create_random_image() -> Image:
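    """Create a 100x100 black image with one randomly colored pixel so that every image is unique."""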
    image = Image.new("RGB", (100, 100))
    random_x = random.randint(0, 99)
    random_y = random.randint(0, 99)

    random_r = random.randint(0, 254)
    random_g = random.randint(0, 254)
    random_b = random.randint(0, 254)

    image.putpixel((random_x, random_y), (random_r, random_g, random_b))

    return image


def add_images_to_dataset(start_number: int, end_number: int, images_added: list[bytes]) -> None:
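    """Write PNG images and matching .txt label files (label i for image i) into the dataset directory."""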
    create_dataset_dir()

    for i in range(start_number, end_number):
        image = create_random_image()
        add_image_to_dataset(image, f"image_{i}.png")
        images_added.append(image.tobytes())
        with open(DATASET_PATH / f"image_{i}.txt", "w") as label_file:
            label_file.write(f"{i}")


def prepare_selector(num_dataworkers: int, keys: list[int]) -> Tuple[int, int]:
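    """Register a pipeline at the selector and inform it about the given keys with a single trigger.

    Returns the resulting pipeline id and trigger id.
    """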
    selector_channel = connect_to_selector_servicer()
    selector = SelectorStub(selector_channel)
    # We test the NewData strategy for finetuning on the new data, i.e., we reset without limit
    # We also enforce high partitioning (maximum_keys_in_memory == 2) to ensure that works

    strategy_config = {
        "name": "NewDataStrategy",
        "maximum_keys_in_memory": 2,
        "config": {"limit": -1, "reset_after_trigger": True},
    }

    pipeline_id = selector.register_pipeline(
        RegisterPipelineRequest(
            num_workers=max(num_dataworkers, 1), selection_strategy=JsonString(value=json.dumps(strategy_config))
        )
    ).pipeline_id

    trigger_id = selector.inform_data_and_trigger(
        DataInformRequest(
            pipeline_id=pipeline_id,
            keys=keys,
            timestamps=[2 for _ in range(len(keys))],
            labels=[3 for _ in range(len(keys))],
        )
    ).trigger_id

    return pipeline_id, trigger_id


def get_new_data_since(timestamp: int) -> Iterable[GetNewDataSinceResponse]:
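    """Query the storage service for all data added to the test dataset after the given timestamp."""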
    storage_channel = connect_to_storage()

    storage = StorageStub(storage_channel)

    request = GetNewDataSinceRequest(
        dataset_id="test_dataset",
        timestamp=timestamp,
    )

    responses = storage.GetNewDataSince(request)
    return responses


def get_data_keys() -> list[int]:
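    """Poll the storage service for up to 60 seconds until all 10 sample keys have been ingested."""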
    response = None
    keys = []
    for i in range(60):
        responses = list(get_new_data_since(0))
        assert len(responses) < 2, f"Received batched response, shouldn't happen: {responses}"
        if len(responses) == 1:
            response = responses[0]
            keys = list(response.keys)
            if len(keys) == 10:
                break
        time.sleep(1)

    assert response is not None, "Did not get any response from Storage"
    assert len(keys) == 10, f"Not all images were returned. Images returned: {response.keys}"

    return keys


def get_bytes_parser() -> str:
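    """Return the bytes parser source that the online dataset uses to decode raw sample bytes into RGB images."""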
return """
from PIL import Image
import io
def bytes_parser_function(data: bytes) -> Image:
return Image.open(io.BytesIO(data)).convert("RGB")"""


def tensor_in_list(tensor: torch.Tensor, tensor_list: list[torch.Tensor]) -> bool:
    return any([(tensor == c_).all() for c_ in tensor_list])


def test_dataset_impl(
    num_dataworkers: int,
    batch_size: int,
    prefetched_partitions: int,
    pipeline_id: int,
    trigger_id: int,
    items: list[int],
) -> None:
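    """Run one DataLoader pass over the trigger data and validate keys, labels, batch count, and image contents."""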
    dataloader, _ = prepare_dataloaders(
        pipeline_id,
        trigger_id,
        "test_dataset",
        num_dataworkers,
        batch_size,
        get_bytes_parser(),
        ["transforms.ToTensor()"],
        get_storage_address(),
        get_selector_address(),
        42,
        prefetched_partitions,
        None,
        None,
    )

    expected_batches = math.ceil(len(items) / batch_size)
    all_samples = []
    all_data = []
    all_labels = []

    for batch_number, batch in enumerate(dataloader):
        sample_ids = batch[0]
        if isinstance(sample_ids, torch.Tensor):
            sample_ids = sample_ids.tolist()
        elif isinstance(sample_ids, tuple):
            sample_ids = list(sample_ids)

        assert isinstance(sample_ids, list), "Cannot parse result from DataLoader"
        assert isinstance(batch[1], torch.Tensor) and isinstance(batch[2], torch.Tensor)

        all_samples.extend(sample_ids)
        all_data.extend(batch[1].tolist())
        all_labels.extend(batch[2].tolist())

    assert len(all_samples) == len(items)
    assert len(all_data) == len(items)
    assert len(all_labels) == len(items)
    assert batch_number + 1 == expected_batches, (
        f"[{num_dataworkers}][{batch_size}][{prefetched_partitions}]"
        + f" Wrong number of batches: {batch_number + 1}. num_items = {len(items)}"
    )

    assert set(all_samples) == set(items)
    assert set(all_labels) == set(range(len(items)))

    trans = transforms.Compose([transforms.ToTensor()])

    for idx, image in enumerate(FIRST_ADDED_IMAGES):
        parsed_image = trans(Image.frombytes("RGB", (100, 100), image))
        assert tensor_in_list(
            parsed_image, all_data
        ), f"Could not find image {idx} in all_data, all_samples = {all_samples}"


def test_dataset() -> None:
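    """End-to-end test: ingest images into storage, register a pipeline, and read the data back via the online dataset.

    The DataLoader sweep covers several worker counts, prefetched-partition settings, and batch sizes.
    """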
    NUM_IMAGES = 10

    check_get_current_timestamp()  # Check if the storage service is available.
    create_dataset_dir()
    add_images_to_dataset(0, NUM_IMAGES, FIRST_ADDED_IMAGES)  # Add images to the dataset.
    register_new_dataset()
    check_dataset_availability()  # Check if the dataset is available.
    check_dataset_size_invalid()

    keys = get_data_keys()

    for num_dataworkers in [0, 1, 2, 4, 8, 16]:
        pipeline_id, trigger_id = prepare_selector(num_dataworkers, keys)
        for prefetched_partitions in [0, 1, 2, 3, 4, 5, 999]:
            for batch_size in [1, 2, 10]:
                print(
                    f"Testing num_workers = {num_dataworkers}, partitions = {prefetched_partitions},"
                    + f" batch_size = {batch_size}"
                )
                test_dataset_impl(num_dataworkers, batch_size, prefetched_partitions, pipeline_id, trigger_id, keys)


def main() -> None:
    try:
        test_dataset()
    finally:
        cleanup_dataset_dir()
        cleanup_storage_database()


if __name__ == "__main__":
    main()
2 changes: 2 additions & 0 deletions integrationtests/run.sh
@@ -13,6 +13,8 @@ python $SCRIPT_DIR/storage/integrationtest_storage.py
python $SCRIPT_DIR/storage/integrationtest_storage_csv.py
echo "Running selector integration tests"
python $SCRIPT_DIR/selector/integrationtest_selector.py
echo "Running online datasets integration tests"
python $SCRIPT_DIR/online_dataset/test_online_dataset.py
echo "Running model storage integration tests"
python $SCRIPT_DIR/model_storage/integrationtest_model_storage.py
echo "Successfuly ran all integration tests."
2 changes: 1 addition & 1 deletion integrationtests/storage/integrationtest_storage.py
@@ -271,7 +271,7 @@ def test_storage() -> None:

    add_images_to_dataset(10, 20, SECOND_ADDED_IMAGES)  # Add more images to the dataset.

    for i in range(20):
    for i in range(60):
        responses = list(get_new_data_since(IMAGE_UPDATED_TIME_STAMPS[9] + 1))
        assert len(responses) < 2, f"Received batched response, shouldn't happen: {responses}"
        if len(responses) == 1:
